From personalized recommendations to scientific advances, AI models are helping to improve lives and transform industries. But the impact and accuracy of these AI models are often determined by the quality of the data they use. Large, high-quality datasets are crucial for developing accurate and representative AI models; however, they must be used in ways that preserve individual privacy.
That’s where JAX and JAX-Privacy come in. Introduced in 2020, JAX is a high-performance numerical computing library designed for large-scale machine learning (ML). Its core features, including automatic differentiation, just-in-time compilation, and seamless scaling across multiple accelerators, make it an ideal platform for building and training complex models efficiently. JAX has become a cornerstone for researchers and engineers pushing the boundaries of AI. Its ecosystem includes a robust set of domain-specific libraries, such as Flax, which simplifies the implementation of neural network architectures, and Optax, which implements state-of-the-art optimizers.
Built on JAX, JAX-Privacy is a toolkit for building and auditing differentially private (DP) models. It enables researchers and developers to quickly and efficiently implement DP algorithms for training deep learning models on large datasets, and it provides the core tools needed to integrate private training into modern distributed training workflows. The original version of JAX-Privacy was introduced in 2022 to enable external researchers to reproduce and validate some of our advances in private training. It has since evolved into a hub where research teams across Google integrate their novel research insights into DP training and auditing algorithms.
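To give a concrete sense of what a DP training algorithm involves, here is a minimal sketch of a single DP-SGD step (per-example gradient clipping followed by Gaussian noise), a standard DP training technique, written directly in JAX. It also shows how JAX's `grad`, `vmap`, and `jit` compose. This is not JAX-Privacy's API: the names `loss_fn`, `clip_to_norm`, and `dp_sgd_step`, the toy linear model, and the hyperparameter values are all hypothetical, and the sketch omits privacy accounting (computing the resulting privacy budget), which a real pipeline needs.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model: squared error of a linear regressor on ONE example.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def clip_to_norm(grad, max_norm):
    # Rescale a single example's gradient pytree so its global L2 norm is <= max_norm.
    leaves = jax.tree_util.tree_leaves(grad)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grad)


@jax.jit
def dp_sgd_step(params, xs, ys, key, clip_norm, noise_multiplier, lr):
    # Per-example gradients: vmap jax.grad over the batch axis of xs and ys.
    per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    # Clip each example's gradient independently to bound its influence.
    clipped = jax.vmap(clip_to_norm, in_axes=(0, None))(per_example, clip_norm)
    # Sum the clipped gradients over the batch.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)
    # Add Gaussian noise with std = noise_multiplier * clip_norm to every leaf.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy = [
        g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape)
        for g, k in zip(leaves, keys)
    ]
    noisy = jax.tree_util.tree_unflatten(treedef, noisy)
    # Average over the batch and take a plain SGD step.
    batch_size = xs.shape[0]
    return jax.tree_util.tree_map(lambda p, g: p - lr * g / batch_size, params, noisy)


# Usage on random toy data (values chosen arbitrarily for illustration).
data_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
xs = jax.random.normal(data_key, (8, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5])
new_params = dp_sgd_step(params, xs, ys, noise_key, 1.0, 1.1, 0.1)
```

Clipping bounds how much any single example can change the update, and the noise scale is tied to that bound; together they are what make the step differentially private once combined with an appropriate privacy accountant. Setting `noise_multiplier` to zero and `clip_norm` very large recovers ordinary minibatch SGD, which is a useful sanity check.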
Today, we are proud to announce the release of JAX-Privacy 1.0. Redesigned for modularity and integrating our latest research advances, this new version makes it easier than ever for researchers and developers to build DP training pipelines that combine state-of-the-art DP algorithms with the scalability provided by JAX.

