Bayesian Approximated Neural Network Example via JAX

Flexible Parameter Distributions for ODEs

Benjamin Murders
Towards Data Science


Although several established methods exist for approximating parameter distributions of ODEs (notably PyMC3), this proof-of-concept approach estimates an unknown distribution without requiring a prior distribution. Using dropout in neural networks as a form of Bayesian approximation for model uncertainty, flexible parameter distributions can be approximated by sampling the trained neural network.

Because a neural network can learn flexible output distributions, predictions can theoretically better fit real-life data by carrying model uncertainty instead of hard, discrete numbers. Having this information is imperative for many business scenarios, particularly those relating to patient care in the healthcare industry. After all, why should a physician consider a robust model’s recommended medication for a given patient’s chart when the model’s interpretability or confidence is unknown? We certainly want some understanding of a model’s predictions when making critical decisions. Combining this approach with an ordinary differential equation (ODE), or even a system of ODEs, makes it easier to model messy, complex real-world dynamics while still applying some rigid structure (for increased interpretability). The model uncertainty and flexible ODE parameter distributions can be derived without having to make influential prior assumptions about what the parameters “could be”.

What inspired me was this paper:

[1506.02142] Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (arxiv.org)

The example in this post approximates the two parameters needed for a simple ODE of the form dy/dt = -r * y: one parameter for the rate, r, and the other for the initial value, y(0), to integrate from over t. One expectation from this proof-of-concept is visualizing the output distribution over the t axis, with the possibility of the distribution varying across the t axis. By sampling the trained Bayesian neural network (via dropout), we can expect the approximated ODE parameter distributions to be flexible and therefore not guaranteed to be uniform. I find this extremely powerful and useful for more complex scenarios: no assumed prior distribution has to be provided, the neural network approximates the parameter distributions, and the approach is scalable and supports continuous/online training.
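To make the setup concrete, here is a minimal sketch (not the article’s exact script) of the target ODE’s closed-form solution and a noise-added training set; the specific values chosen for r, y(0), the t grid, and the noise scale are illustrative assumptions.

```python
import numpy as np

def ode_solution(t, r, y0):
    # Closed-form solution of dy/dt = -r * y integrated from y(0) = y0.
    return y0 * np.exp(-r * t)

rng = np.random.default_rng(0)
t_train = np.linspace(0.0, 10.0, 200)                 # assumed time grid
y_clean = ode_solution(t_train, r=0.5, y0=5.0)        # assumed "true" parameters
y_noisy = y_clean + rng.normal(scale=0.1, size=t_train.shape)  # noise-added targets
```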

I prefer variable dependencies to be kept externally rather than embedded in-line, but I wanted to provide a working single file script to make it easier to see all referenced functions and variables. Here is the full code in a single file (file: bayesian_appx_nn_ex_jax.py):

The Python file depends on the following packages:

  • numpy
  • pandas
  • matplotlib
  • seaborn
  • jax[cpu] (JAX)

My approach for executing the Python script relies on the TensorFlow GPU Docker image (version 2.6.0, Python 3.6.9) for CUDA-accelerated GPU training. However, neither TensorFlow nor a GPU is required to run this script, as the standard CPU-based JAX package will work (successfully tested with the listed packages and Python 3.9). JAX’s use of XLA and JIT is amazing even on CPUs! If you happen to be using VS Code when executing this Python file, you will see the character string #%%, which is part of VS Code’s Python interactive feature for a Jupyter-like experience while working with Python files. If this Python file is exported as a Jupyter notebook, those markers will automatically create their corresponding cells (in case running the file with interactive cells directly in a Jupyter notebook is preferred over VS Code’s Interactive Window feature).
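For reference, here is a tiny illustration of how those #%% markers delimit interactive cells (the cell contents are placeholders, not the script’s actual code):

```python
#%%
# First interactive cell: imports and device check.
import jax
print(jax.devices())

#%%
# Second interactive cell: anything else you want to run in isolation.
print("hello from a separate cell")
```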

JAX’s experimental package includes Stax for building and training neural networks directly within JAX. Other powerful JAX-based packages, such as Trax, can be used to create very robust neural networks, but I believe Stax is easier to prototype with for this demonstration without requiring additional packages beyond JAX. The Python file uses the following neural network hyperparameters (a minimal Stax sketch follows the list below):

  • Dropout rate value: 0.1
  • Units per layer: 2048
  • Number of hidden layers: 1
  • Primary activation function: Mish
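Below is a minimal Stax sketch of such a network, assuming the model maps a constant dummy input to the two ODE parameters r and y(0); that input representation is my assumption, not necessarily the script’s. Note that Stax’s Dropout layer takes the keep probability, so a 0.1 dropout rate corresponds to passing 0.9.

```python
import jax.numpy as jnp
from jax.nn import softplus
from jax.experimental import stax  # newer JAX versions: from jax.example_libraries import stax

def mish(x):
    # Mish activation: x * tanh(softplus(x)).
    return x * jnp.tanh(softplus(x))

Mish = stax.elementwise(mish)

# One "very wide" hidden layer with dropout; two outputs interpreted as (r, y(0)).
init_fun, apply_fun = stax.serial(
    stax.Dense(2048),
    Mish,
    stax.Dropout(0.9, mode='train'),  # keep probability 0.9 = 0.1 dropout rate
    stax.Dense(2),
)
```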

Although a “deep” neural network allows for more robust fitting, this example ODE is simple enough for a single but “very wide” hidden layer (an important concept when attempting to relate or “approximate” a Gaussian process with this Bayesian approximation approach via neural network dropout). For training, it’s critical to have a “good” guess before fully training on the training data. This is similar to most optimizers or minimization algorithms, where a “bad” guess of the initial parameter values can prevent effective convergence to an appropriate fit. So, the training is broken into two parts: the first “primes” the neural network to output the “good” guess parameter values, and the second trains against the noise-added training data. This approach helps significantly in avoiding NaN values during training, as a given function can have very steep gradients depending on where its parameters reside. The optimizer of choice for both training stages is JAX’s Adamax, with a step size of 2e-3 for the priming session and 2e-5 for the training against the noise-added training data. The Huber loss is used as the training loss function.
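Continuing the sketches above (this reuses init_fun, apply_fun, t_train, and y_noisy), here is a hedged outline of that two-stage training. The Adamax step sizes and Huber loss follow the description; the “good” guess values, step counts, and exact loss wiring are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import optimizers  # newer JAX versions: jax.example_libraries.optimizers

def huber(residual, delta=1.0):
    # Huber loss: quadratic near zero, linear in the tails.
    abs_r = jnp.abs(residual)
    return jnp.where(abs_r <= delta, 0.5 * residual ** 2, delta * (abs_r - 0.5 * delta))

def predict_params(params, rng):
    # Forward pass with dropout active; assumed constant dummy input of ones.
    return apply_fun(params, jnp.ones((1, 1)), rng=rng)[0]

def prime_loss(params, rng, guess):
    # Stage 1: push the raw outputs toward the "good" guess parameter values.
    return jnp.mean(huber(predict_params(params, rng) - guess))

def fit_loss(params, rng, t, y_target):
    # Stage 2: fit the ODE curve implied by the predicted parameters to the noisy data.
    r, y0 = predict_params(params, rng)
    return jnp.mean(huber(y0 * jnp.exp(-r * t) - y_target))

def train(loss, params, step_size, n_steps, rng, *args):
    opt_init, opt_update, get_params = optimizers.adamax(step_size)
    opt_state = opt_init(params)
    for i in range(n_steps):
        rng, sub = jax.random.split(rng)
        grads = jax.grad(loss)(get_params(opt_state), sub, *args)
        opt_state = opt_update(i, grads, opt_state)
    return get_params(opt_state)

key = jax.random.PRNGKey(42)
key, k_init, k_prime, k_fit = jax.random.split(key, 4)
_, params = init_fun(k_init, (-1, 1))

guess = jnp.array([0.4, 4.0])  # assumed "good" guess for (r, y(0))
params = train(prime_loss, params, 2e-3, 2_000, k_prime, guess)     # priming stage
params = train(fit_loss, params, 2e-5, 5_000, k_fit,
               jnp.asarray(t_train), jnp.asarray(y_noisy))           # noisy-data stage
```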

After training on the noise-added training data, 200 samples are drawn to determine the output distributions, using quantiles for the arbitrary prediction intervals. Using quantiles rather than the mean plus standard deviations for plotting a prediction interval is more useful in this approach, as you don’t have to worry about the interval disregarding important aspects of the ODE, such as crossing over the horizontal asymptote of the target ODE function. Seaborn and Matplotlib are used for the following plots (a sampling sketch follows the plot list below):

  • A density visual of the output samples across the t axis.
  • A contour visual of the samples.
  • A plot of the individual samples with the prediction interval.
  • Distributions of the approximated ODE parameters.
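As referenced above, here is a rough sketch of the Monte Carlo dropout sampling and the quantile summary, continuing from the training sketch; the 5%/50%/95% quantile levels are an illustrative choice rather than the script’s exact settings.

```python
import numpy as np
import jax
import jax.numpy as jnp

n_samples = 200
t_grid = jnp.asarray(t_train)
keys = jax.random.split(jax.random.PRNGKey(7), n_samples)

curves, r_samples, y0_samples = [], [], []
for key in keys:
    # Dropout stays active at prediction time, so each forward pass is one
    # draw from the approximate posterior over (r, y(0)).
    r, y0 = apply_fun(params, jnp.ones((1, 1)), rng=key)[0]
    r_samples.append(float(r))
    y0_samples.append(float(y0))
    curves.append(np.asarray(y0 * jnp.exp(-r * t_grid)))

curves = np.stack(curves)  # shape: (200, len(t_grid))
lower, median, upper = np.quantile(curves, [0.05, 0.5, 0.95], axis=0)
```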
Image by Author — Density Plot of Output Distributions across Axis t

Image by Author — Contour View of Output

Image by Author — Individual Samples View with Prediction Interval

Image by Author — Sampled Distribution for Parameter y(0)

Image by Author — Sampled Distribution for Parameter r
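For completeness, here is a minimal Matplotlib/Seaborn plotting pass roughly corresponding to the figures above, continuing from the sampling sketch; the “90% prediction interval” label follows the assumed 5%/95% quantiles.

```python
import matplotlib.pyplot as plt
import seaborn as sns

# Individual samples with the quantile-based prediction interval.
fig, ax = plt.subplots()
for curve in curves:
    ax.plot(t_train, curve, color="gray", alpha=0.05)
ax.plot(t_train, median, color="black", label="median")
ax.fill_between(t_train, lower, upper, alpha=0.3, label="90% prediction interval")
ax.set_xlabel("t")
ax.set_ylabel("y(t)")
ax.legend()

# Sampled distributions of the approximated ODE parameters.
fig, axes = plt.subplots(1, 2)
sns.kdeplot(x=y0_samples, ax=axes[0]).set_title("y(0)")
sns.kdeplot(x=r_samples, ax=axes[1]).set_title("r")
plt.show()
```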

We can see from the density plots of the approximated ODE parameters that they’re not completely uniform, which is what I wanted to see after training. The target ODE isn’t overly complicated, but the goal is to demonstrate how this approach can flexibly derive non-uniform distributions. Importantly, drawing more samples can change the visualized density plots of the approximated ODE parameter estimates, but 200 samples should be sufficient to characterize what the approximated ODE parameter density “is”.

We could have approximated the function directly using some form of augmented neural ordinary differential equation with the Bayesian twist presented here. But if we happen to have reasonable knowledge of, or assumptions about, the underlying dynamics, it is better to leverage that knowledge. One likely outcome of modeling the underlying dynamics directly with a black-box model is failing to appropriately capture the horizontal asymptote that exists for the referenced ODE. This approach, however, allows neural networks to “learn” the target ODE parameters’ distributions and provide a useful prediction interval and median output. This approach has been fascinating to me, and I hope that others will be inspired to leverage probabilistic outputs for data modeling (especially with black-box models such as neural networks).
