Stanford MLSys Seminar

11/19/20 #6 Roy Frostig - The Story Behind JAX

Roy Frostig - JAX: accelerating machine learning research by composing function transformations in Python

JAX is a system for high-performance machine learning research and numerical computing. It offers the familiarity of Python+NumPy together with hardware acceleration, plus a set of composable function transformations: automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelizing over multiple accelerators, and more. JAX's core strength is its guarantee that these user-wielded transformations can be composed arbitrarily, so that programmers can write math (e.g. a loss function) and transform it into pieces of an ML program (e.g. a vectorized, compiled, batch gradient function for that loss).
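The composition described above can be illustrated with a minimal sketch. The loss function, shapes, and variable names here are hypothetical, chosen only to show how `grad`, `vmap`, and `jit` stack on a plain Python function:

```python
import jax
import jax.numpy as jnp

# Plain math: squared-error loss for a single example.
def loss(params, x, y):
    pred = jnp.dot(params, x)
    return (pred - y) ** 2

# Compose transformations: differentiate, then vectorize, then compile.
grad_loss = jax.grad(loss)                                 # gradient w.r.t. params
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))   # map over a batch of (x, y)
fast_batched_grad = jax.jit(batched_grad)                  # end-to-end XLA compilation

params = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 inputs
ys = jnp.ones(4)
grads = fast_batched_grad(params, xs, ys)
print(grads.shape)  # (4, 3): one per-example gradient in the batch
```

Each transformation returns an ordinary Python function, which is what makes arbitrary composition possible: `jit(vmap(grad(loss)))` is just nested function application.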

JAX had its open-source release in December 2018 (https://github.com/google/jax). It's used by researchers for a wide range of applications, from studying training dynamics of neural networks, to probabilistic programming, to scientific applications in physics and biology.


Stanford MLSys Seminar, by Dan Fu, Karan Goel, Fiodar Kazhamiaka, Piero Molino, Matei Zaharia, Chris Ré
