Created a post 1 day ago • 135 points • @ChrisRackauckas
...doesn't the JAX example just need the argument added to static_argnums, and then it will work?
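A minimal sketch of what the comment suggests, assuming the failing example involves Python-level control flow on an argument (the function `f` here is illustrative, not the post's actual example):

```python
import jax
import jax.numpy as jnp

# Hypothetical function whose control flow branches on the integer `n`.
# Under plain jax.jit this branch would fail, because `n` is a tracer.
def f(x, n):
    if n > 0:
        return jnp.sin(x) * n
    return x

# Marking `n` as static makes jit re-trace (and recompile) for each
# concrete value of `n`, so the Python `if` runs on a real integer.
f_jit = jax.jit(f, static_argnums=(1,))

print(f_jit(1.0, 2))  # traces the n > 0 branch
print(f_jit(1.0, 0))  # retraces for n = 0, takes the other branch
```

The trade-off is recompilation: every new value of a static argument triggers a fresh trace, which is fine for a handful of values but not for data-dependent branching.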
There are many interesting threads in this post, one of which is using “non-standard interpretations” of programs: enabling the compiler to augment the human-written code with the extra pieces necessary to get gradients, propagate uncertainties, etc. I wonder whether there’s a more unified discussion of the potential of these methods. I suspect that a lot of “solvers” (each typically with their own DSL for specifying the problem) might be nicely formulated in such a framework. (Particularly in the case of autodiff, I found recent work/talks by Conal Elliott and Tom Minka quite enlightening.)
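The classic small example of a non-standard interpretation is forward-mode AD via dual numbers: by redefining arithmetic over (value, derivative) pairs, unchanged user code computes its own derivative. A minimal sketch (class and function names are my own, not from any library mentioned in the thread):

```python
# Forward-mode AD as a non-standard interpretation: arithmetic is
# reinterpreted over dual numbers (value, derivative).
class Dual:
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (fg)' = f'g + fg'
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def derivative(f, x):
    # Seed the input with derivative 1 and read the derivative back out.
    return f(Dual(x, 1.0)).der

# f(x) = x^2 + 3x  ->  f'(x) = 2x + 3, so f'(2) = 7
print(derivative(lambda x: x * x + 3 * x, 2.0))  # 7.0
```

The same pattern generalizes: swap the interpretation of the primitives (intervals for uncertainty, distributions for probabilistic programs) and generic code picks up the new semantics for free.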
Tangentially, thinking about Julia: while one initially gets awed by the speed, and then by multiple dispatch, I wonder whether its deepest superpower (that we’re still discovering) might be the expressiveness to augment the compiler to do interesting things with a piece of code. Generic programming then acts as a lever to apply these improvements to a variety of use cases, and the speed is merely the icing on the cake!
There is no free lunch :).
I remember spending a summer using Template Model Builder (TMB), a useful R/C++ automatic differentiation (AD) framework, to work with accelerated failure time models. For these models, survival to time T given covariates X is defined by S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for baseline survival S_0(t). I wanted to use splines for the baseline survival and then use AD for gradients and random effects. Unfortunately, after implementing the splines in template C++, I found a web page entitled "Things you should NOT do in TMB" (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...), which includes using if statements that depend on coefficients. In this case, the splines for S_0 depend on beta, which is exactly the excluded case :(. An older framework (ADMB) did not have this constraint, but disseminating code was more difficult. Finally, PyTorch did not have an implementation of B-splines or of Laplace's approximation. Returning to my opening comment: there is no free lunch.
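For concreteness, the AFT relation from the comment can be sketched in a few lines. The Weibull baseline and the coefficient values here are made up for illustration; only the formula S(t|X) = S_0(t exp(-beta^T X)) comes from the comment:

```python
import math

def S0(t):
    # Assumed Weibull-type baseline survival, purely illustrative.
    return math.exp(-(t ** 1.5))

def aft_survival(t, x, beta):
    # Accelerated failure time: covariates rescale time before the
    # baseline survival is evaluated.
    scale = math.exp(-sum(b * xi for b, xi in zip(beta, x)))
    return S0(t * scale)

# With beta^T x = 0.2*0.5 + (-0.1)*1.0 = 0, this reduces to S0(1) = e^-1.
print(aft_survival(1.0, [0.5, 1.0], [0.2, -0.1]))
```

The commenter's difficulty arises one step further: when S_0 itself is a spline whose knot placement depends on beta, the evaluation branches on parameters, which is what TMB's tape-based AD disallows.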
Tangentially related: faster training of Neural ODEs is super exciting! There are a lot of promising applications (although personally I believe the intuition of "magically choosing the number of layers" is misguided — but I'm not an expert and might be wrong). Right now it takes forever to train even on toy problems, but I'm sure that enough work in this direction will eventually lead to more practical methods.
This is a really cool post.
It seems like you can't solve this kind of thing with a new JAX primitive for the algorithm, but what prevents new function transformations from doing what the mentioned Julia libraries do? It seems like between new function transformations and new primitives, you ought to be able to do just about anything. Is XLA the issue — could you run but not jit the result?
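One existing extension point along these lines is attaching a custom differentiation rule to a function without writing a new interpreter. A minimal sketch using jax.custom_vjp (the function and its rule here are invented for illustration; a real use case would wrap something like an ODE solver):

```python
import jax
import jax.numpy as jnp

# A function whose automatic gradient we override with a hand-written
# reverse-mode rule, the way a solver library might.
@jax.custom_vjp
def clipped_square(x):
    return jnp.square(jnp.clip(x, -1.0, 1.0))

def fwd(x):
    # Return the primal output plus residuals needed by the backward pass.
    return clipped_square(x), x

def bwd(x, g):
    # Hand-written gradient: 2x inside the clip region, 0 outside.
    return (jnp.where(jnp.abs(x) <= 1.0, 2.0 * x, 0.0) * g,)

clipped_square.defvjp(fwd, bwd)

print(jax.grad(clipped_square)(0.5))  # 1.0
print(jax.grad(clipped_square)(2.0))  # 0.0
```

This composes with jit and vmap, but everything inside still has to lower to XLA, which is where the limitations the post discusses come back in.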
The example that fails in JAX would work fine in PyTorch. If you're purely training the model, TorchScript doesn't give many benefits, if any.
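A sketch of why this is, assuming the failing example involves value-dependent control flow (this toy function is mine, not the post's):

```python
import torch

# Data-dependent Python branching that would trip up jax.jit runs as-is
# in eager PyTorch, because each call simply executes Python.
def f(x):
    if x.sum() > 0:          # branch on the tensor's *value* — fine eagerly
        return torch.sin(x)
    return x * 2.0

x = torch.tensor([0.5], requires_grad=True)
y = f(x)
y.backward()
print(x.grad)  # gradient of sin at 0.5, i.e. cos(0.5)
```

The cost of this flexibility is that the graph is rebuilt every call, which is exactly the overhead that tracing/compiling systems like jit or TorchScript try to eliminate.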