Gen-AI with JAX for simulation data [course]
Note: all times are shown in the timezone in which each event occurs.
Date: 7 May 2026 @ 10:00 - 12:00
Timezone: Pacific Daylight Time
Language of instruction: English
Abstract: Most machine learning courses focus on standard workflows -- such as image classification -- where many pre-trained community models are readily available. In this course, we take a different approach by building a model for a custom scientific task: predicting the solution of a partial differential equation (PDE) from a given initial condition. The goal is to train a neural network that acts as a fast surrogate for traditional numerical solvers.
Using JAX, Flax, and Optax, we will train a surrogate model on 2D simulation data that we generate ourselves. Along the way, we will develop a complete machine learning pipeline from scratch: generating synthetic training data via simulations, selecting the right model architecture, training the model on an HPC cluster, making sure it runs efficiently on the cluster's GPUs, and evaluating the quality of the surrogate solutions.
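To give a flavour of the pipeline described above, here is a minimal sketch in plain JAX. It is not the course material: the choice of PDE (a 2D heat equation with periodic boundaries), the grid size, the step counts, and all function names (`heat_step`, `solve`, `make_dataset`, `predict`, `train_step`) are assumptions made for illustration. To keep the example self-contained, a single linear layer stands in for a Flax model and hand-written gradient descent stands in for an Optax optimizer.

```python
import jax
import jax.numpy as jnp


def heat_step(u, alpha=0.1):
    # One explicit finite-difference step of the 2D heat equation
    # (periodic boundaries via jnp.roll; alpha < 0.25 keeps it stable).
    lap = (jnp.roll(u, 1, 0) + jnp.roll(u, -1, 0)
           + jnp.roll(u, 1, 1) + jnp.roll(u, -1, 1) - 4.0 * u)
    return u + alpha * lap


def solve(u0, n_steps=20):
    # Stand-in for a traditional numerical solver: iterate heat_step.
    return jax.lax.fori_loop(0, n_steps, lambda _, u: heat_step(u), u0)


def make_dataset(key, n_samples=64, size=16):
    # Synthetic training data: random initial conditions and their solutions.
    u0 = jax.random.normal(key, (n_samples, size, size))
    uT = jax.vmap(solve)(u0)
    return u0, uT


def predict(W, u0):
    # Placeholder surrogate: one linear map on the flattened grid.
    n = u0.shape[0]
    return (u0.reshape(n, -1) @ W).reshape(u0.shape)


def loss_fn(W, u0, uT):
    # Mean squared error between surrogate and solver output.
    return jnp.mean((predict(W, u0) - uT) ** 2)


@jax.jit
def train_step(W, u0, uT, lr=5.0):
    # Plain gradient descent; an Optax optimizer would replace this update.
    loss, grad = jax.value_and_grad(loss_fn)(W, u0, uT)
    return W - lr * grad, loss


key = jax.random.PRNGKey(0)
u0, uT = make_dataset(key)
W = jnp.zeros((u0[0].size, u0[0].size))

initial_loss = loss_fn(W, u0, uT)
for _ in range(200):
    W, loss = train_step(W, u0, uT)
final_loss = loss_fn(W, u0, uT)

print("initial loss:", float(initial_loss))
print("final loss:  ", float(final_loss))
```

The sketch only checks the fit on the training data. A realistic evaluation would compare the surrogate against the solver on held-out initial conditions.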
Keywords: Machine Learning, AI