Hacker News Re-Imagined

Useful Algorithms That Are Not Optimized by Jax, PyTorch, or TensorFlow

135 points · posted 1 day ago by @ChrisRackauckas


@cjv 4 hours

Replying to @ChrisRackauckas

...doesn't the JAX example just need the argument marked as static via static_argnums, and then it will work?
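
For concreteness, something like this rough sketch, where the function is a hypothetical stand-in for the post's example rather than its actual code:

    import functools
    import jax
    import jax.numpy as jnp

    # Hypothetical stand-in for the post's example: the loop bound is a plain
    # Python int, so it is marked static for jit.
    @functools.partial(jax.jit, static_argnums=(1,))
    def f(x, n):
        for _ in range(n):        # n is static, so the loop unrolls at trace time
            x = jnp.sin(x) + x
        return x

    print(f(jnp.arange(3.0), 4))

The trade-off is that each new value of the static argument triggers a recompile.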



@ssivark 5 hours

There are many interesting threads in this post, one of which is using "non-standard interpretations" of programs, enabling the compiler to augment the human-written code with the extra pieces necessary to get gradients, propagate uncertainties, etc. I wonder whether there's a more unified discussion of the potential of these methods. I suspect that a lot of "solvers" (each typically with its own DSL for specifying the problem) might be nicely formulated in such a framework. (Particularly in the case of autodiff, I found recent work/talks by Conal Elliott and Tom Minka quite enlightening.)

Tangentially, thinking about Julia: while one is initially awed by the speed, and then by the multiple dispatch, I wonder whether its deepest superpower (that we're still discovering) might be the expressiveness to augment the compiler to do interesting things with a piece of code. Generic programming then acts as a lever that applies these improvements to a variety of use cases, and the speed is merely the icing on the cake!
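
As a toy illustration of such a non-standard interpretation, here is a dual-number sketch (my example, not from the post): the same generic user code, evaluated on a value type that overloads arithmetic, yields derivatives alongside values.

    # Toy forward-mode AD via dual numbers: ordinary arithmetic is re-interpreted
    # over a richer value type so derivatives come along for the ride.
    class Dual:
        def __init__(self, val, eps=0.0):
            self.val, self.eps = val, eps
        def __add__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            return Dual(self.val + other.val, self.eps + other.eps)
        __radd__ = __add__
        def __mul__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            return Dual(self.val * other.val,
                        self.val * other.eps + self.eps * other.val)
        __rmul__ = __mul__

    def f(x):                      # generic user code, unaware of Dual
        return 3 * x * x + 2 * x + 1

    y = f(Dual(2.0, 1.0))          # seed dx/dx = 1
    print(y.val, y.eps)            # 17.0 and f'(2) = 14.0

This is essentially what ForwardDiff.jl does in Julia, with multiple dispatch and generic code playing the role of the operator overloading.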



@marcle 4 hours

There is no free lunch:).

I remember spending a summer using Template Model Builder (TMB), a useful R/C++ automatic differentiation (AD) framework, to work with accelerated failure time models. For these models, the survival beyond time t given covariates X is defined by S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for a baseline survival function S_0(t). I wanted to use splines for the baseline survival and then use AD for gradients and random effects. Unfortunately, after implementing the splines in templated C++, I found a web page entitled "Things you should NOT do in TMB" (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...), which includes using if statements that branch on model coefficients. In this case, the splines for S_0 depend on beta, which is exactly the excluded case :(. An older framework (ADMB) did not have this constraint, but disseminating code was more difficult. Finally, PyTorch did not have an implementation of B-splines or of Laplace's approximation. Returning to my opening comment: there is no free lunch.
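
To make the model concrete, here is a rough PyTorch sketch of that survival function, with a hypothetical Weibull baseline standing in for the splines I actually wanted (so it sidesteps, rather than solves, the B-spline issue):

    import torch

    # Rough sketch (not the TMB code): an accelerated failure time model
    # S(t|X) = S_0(t * exp(-beta^T X)), with a hypothetical Weibull baseline
    # S_0(u) = exp(-(u / lam)^k) so everything stays differentiable.
    def aft_survival(t, X, beta, lam, k):
        scaled_t = t * torch.exp(-X @ beta)      # covariates accelerate/decelerate time
        return torch.exp(-(scaled_t / lam) ** k)

    beta = torch.tensor([0.5, -0.2], requires_grad=True)
    lam = torch.tensor(2.0, requires_grad=True)
    k = torch.tensor(1.5, requires_grad=True)

    t = torch.tensor([1.0, 3.0])
    X = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

    loss = -torch.log(aft_survival(t, X, beta, lam, k)).sum()  # toy negative log-survival
    loss.backward()
    print(beta.grad, lam.grad, k.grad)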



@_hl_ 1 hour

Tangentially related: faster training of Neural ODEs is super exciting! There are a lot of promising applications (although personally I believe the intuition of "magically choosing the number of layers" is misguided, but I'm not an expert and might be wrong). Right now it takes forever to train them even on toy problems, but I'm sure that enough work in this direction will eventually lead to more practical methods.
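
For readers new to the idea, a toy sketch of the continuous-depth view (my illustration, with a crude fixed-step Euler loop standing in for the adaptive solvers used in practice):

    import torch

    # Toy sketch of the continuous-depth idea: instead of stacking N discrete
    # layers, integrate dx/dt = f_theta(x) forward in time with a solver.
    f = torch.nn.Sequential(torch.nn.Linear(2, 16), torch.nn.Tanh(),
                            torch.nn.Linear(16, 2))

    def neural_ode_forward(x, n_steps=20, t1=1.0):
        dt = t1 / n_steps
        for _ in range(n_steps):
            x = x + dt * f(x)        # Euler step: "depth" becomes integration time
        return x

    out = neural_ode_forward(torch.randn(8, 2))
    out.sum().backward()             # backprop through the unrolled solver
    print(f[0].weight.grad.shape)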



@6gvONxR4sf7o 4 hours

This is a really cool post.

It seems like you can't solve this kind of thing with a new JAX primitive for the algorithm, but what prevents new function transformations from doing what the mentioned Julia libraries do? Between new function transformations and new primitives, you ought to be able to do just about anything. Is XLA the issue, so that you could run the result but not jit it?
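
For reference, the extension hook that does exist today is attaching custom differentiation rules to a function, e.g. via the public jax.custom_vjp API. The sketch below (with a made-up stand-in function) shows the flavor, though it is a custom rule rather than a whole new transformation:

    import jax
    import jax.numpy as jnp

    # Attach a hand-written reverse rule instead of letting JAX trace the
    # internals; log1p_exp is a made-up stand-in for a routine with a known adjoint.
    @jax.custom_vjp
    def log1p_exp(x):
        return jnp.log1p(jnp.exp(x))

    def log1p_exp_fwd(x):
        return log1p_exp(x), x            # save residuals for the backward pass

    def log1p_exp_bwd(x, g):
        return (g * jax.nn.sigmoid(x),)   # hand-derived adjoint of log(1 + e^x)

    log1p_exp.defvjp(log1p_exp_fwd, log1p_exp_bwd)

    print(jax.grad(log1p_exp)(2.0))       # sigmoid(2.0) ≈ 0.8808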



@ipsum2 2 hours

The example that fails in JAX would work fine in PyTorch. If you're purely training the model, TorchScript doesn't give many benefits, if any.
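
A toy illustration (hypothetical example, not the post's exact code): value-dependent Python control flow that jit tracing chokes on runs unchanged under eager PyTorch autograd.

    import torch

    # Branching on a runtime tensor value is fine in eager mode.
    def f(x):
        if x.sum() > 0:
            return (x ** 2).sum()
        return (-x).sum()

    x = torch.randn(5, requires_grad=True)
    f(x).backward()
    print(x.grad)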


