Defining Statistical Models in JAX?

hackandthink | 102 points

I'm very excited by the work being put into making Bayesian inference more manageable. It's in a spot that feels very similar to deep learning circa the mid-2010s, when Caffe, Torch, and hand-written gradients were the options. We can do it, but doing anything more complicated than common model structures like hierarchical Gaussian linear models requires dropping out of the nice places and into the guts.

I've had a lot of success with NumPyro (a JAX library), and I've used quite a few tools that are simpler interfaces to Stan. I've also had to write quite a few model-specific things from scratch by hand (more for sequential Monte Carlo than MCMC). I'm very excited for a world where PPLs become scalable and easier to use and customize.

> I think there is a good chance that normalizing flow-based variational inference will displace MCMC as the go-to method for Bayesian posterior inference as soon as everyone gets access to good GPUs.

Wow. This is incredibly surprising. I'm only tangentially aware of normalizing flows, but apparently I need to look at the intersection of them and Bayesian statistics now! Any sources from anyone would be most appreciated!

JHonaker | 14 hours ago

Reading this post and reviewing the NumPyro/Pyro documentation, I'm not sure I understand the crucial difference between NumPyro and Pyro. I understand that Pyro uses PyTorch as its backend and NumPyro uses JAX, but beyond that I'm not sure what the critical differences are. If their frontends are about the same (which seems to be the case here), why is JAX mentioned in this post? Couldn't we simply replace Pyro with Stan for statistical modelling (whether with a PyTorch or JAX backend)?
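A partial answer to "why JAX?": a log density written in JAX can be differentiated, compiled, and vectorized with `jax.grad`, `jax.jit`, and `jax.vmap`, and gradient-based samplers like NUTS need the log density and its gradient at every step. The sketch below writes a linear-regression log posterior by hand in plain JAX to show those transforms. The function `log_posterior`, its priors, and the log-scale parameterization of `sigma` are assumptions for illustration; this is not how NumPyro or Pyro is implemented internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_posterior(params, x, y):
    # params = (a, b, log_sigma); sigma is kept on the log scale so every
    # parameter is unconstrained, as HMC-style samplers expect
    a, b, log_sigma = params
    sigma = jnp.exp(log_sigma)
    # Assumed priors: N(0, 5) on a and b, N(0, 1) on log_sigma
    log_prior = (norm.logpdf(a, 0.0, 5.0) + norm.logpdf(b, 0.0, 5.0)
                 + norm.logpdf(log_sigma, 0.0, 1.0))
    log_lik = jnp.sum(norm.logpdf(y, a + b * x, sigma))
    return log_prior + log_lik


# Compiled value-and-gradient: what a gradient-based sampler calls each step
value_and_grad = jax.jit(jax.value_and_grad(log_posterior))
# vmap evaluates the density for a whole batch of parameter vectors at once
batched_log_post = jax.jit(jax.vmap(log_posterior, in_axes=(0, None, None)))

x = jnp.linspace(-1.0, 1.0, 50)
y = 1.0 + 2.0 * x  # noiseless data, so the residuals at the true a, b are zero
params = jnp.array([1.0, 2.0, jnp.log(0.5)])
lp, grad = value_and_grad(params, x, y)
lps = batched_log_post(jnp.stack([params, params + 0.1]), x, y)
print(lp, grad, lps)
```

PyTorch offers autograd too, so the difference here is less about the modelling frontend and more about which array library and set of program transformations the inference code is built on.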

gnulinux | 10 hours ago

I'm curious about the involvement of tech companies here. Obviously, approximating posterior distributions of explicit statistical models via simulation techniques is common in the academic scientific literature, but I'd like to hear about examples of it being done in "production" settings, i.e. not just as a one-off analysis. I have long had a vague belief that in production settings people usually opt for heuristics, point estimates, etc., but I haven't had much involvement with this sort of thing for a while.

Myrmornis | 7 hours ago

This is coming at the perfect time! I was recently trying to decide whether to implement a model in Stan or Pyro/NumPyro, and I've been eyeing implementing it in JAX. I would love to write a tutorial comparing Stan to JAX.

techwizrd | 13 hours ago

Off topic: I think there are some opportunities for making Bayesian inference technology more accessible, and I'd love to chat with other people in this space. Email in my profile.

helltone | 12 hours ago

This is a great development!

Iwan-Zotow | 9 hours ago