Use a regression neural network made in Jax to predict Boston house prices
In a recent blog post I discussed how to make a regression neural network with numpy. That post can be read here: https://medium.com/@tracyrenee61/use-a-regression-neural-network-made-in-numpy-to-predict-on-house-prices-304c47ae527d
In this post I will discuss how to make a regression neural network with Jax. Although Jax's jax.numpy module closely mirrors the numpy API, the two are not identical. For instance, numpy's random number generator keeps hidden global state, whereas Jax requires an explicit random key to be passed to every random function. In addition, Jax arrays are immutable and cannot be updated in place, which makes them more awkward to program with than numpy arrays.
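The two differences mentioned above can be seen side by side in the short sketch that follows. It is illustrative only and assumes a recent Jax release in which jax.random.PRNGKey and the .at[...].set(...) update syntax are available; the variable names are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# numpy: a global, stateful generator; each call silently advances hidden state
np.random.seed(0)
a = np.random.normal(size=3)

# Jax: randomness comes from an explicit key, which is split to get fresh keys
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
w = jax.random.normal(subkey, (3,))

# Reusing the same key reproduces the same numbers (there is no hidden state)
same_w = jax.random.normal(subkey, (3,))

# numpy arrays can be modified in place
a[0] = 10.0

# Jax arrays are immutable, so item assignment raises a TypeError
try:
    w[0] = 10.0
except TypeError:
    print("Jax arrays do not support in-place assignment")

# Instead, .at[...].set(...) returns a new array and leaves w unchanged
w_updated = w.at[0].set(10.0)
print(w)
print(w_updated)
```

Splitting keys is more verbose than calling np.random, but it makes every random draw reproducible and safe to use inside jit-compiled functions.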
I wrote the program in Google Colab, a free, hosted Jupyter Notebook environment provided by Google. Colab is a great platform for writing Python code, although its undo support is limited, so care needs to be taken not to inadvertently overwrite or delete valuable code.
Once the notebook was created, I imported the libraries I needed to run the program:
- Jax to create the neural network,
- Pandas to create the dataframe and process the data,
- Sklearn to provide machine learning functionality,
- Seaborn to statistically visualise the…
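As a rough illustration of the Jax side of such a program, here is a minimal sketch of a small regression network trained with plain gradient descent. It is not the author's code: the layer sizes, learning rate, number of steps and the helper names init_params, predict, mse_loss and update are all hypothetical. It also uses synthetic data rather than the Boston dataset, because scikit-learn's load_boston function was removed in version 1.2, so that data now has to be loaded from another source.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical value, not tuned

def init_params(key, sizes):
    """Return a list of (weights, bias) pairs for a fully connected network."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style scaling of random normal weights, zero biases
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def predict(params, x):
    # ReLU hidden layers followed by a single linear output unit
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return (x @ w + b).squeeze(-1)

def mse_loss(params, x, y):
    # Mean squared error between predictions and targets
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
    grads = jax.grad(mse_loss)(params, x, y)
    # Build new parameters instead of modifying the old arrays in place
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Synthetic regression data standing in for the house-price features
key = jax.random.PRNGKey(42)
x_key, noise_key, init_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (200, 3))
y = x @ jnp.array([1.5, -2.0, 0.5]) + 3.0 + 0.1 * jax.random.normal(noise_key, (200,))

params = init_params(init_key, [3, 16, 1])
initial_loss = mse_loss(params, x, y)
for _ in range(1000):
    params = update(params, x, y)
final_loss = mse_loss(params, x, y)
print(f"MSE before training: {float(initial_loss):.3f}, after: {float(final_loss):.3f}")
```

Keeping the parameters in a plain list of tuples keeps the example short: Jax treats such nested containers as pytrees, so jax.grad and jax.tree_util.tree_map handle them without any extra code.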