Use Flax and Jax to build an MLPClassifier to make predictions on the breast cancer dataset
In my previous two blog posts, I used Jax and Flax to create regression neural networks. Those posts can be found here:
- https://medium.com/@tracyrenee61/make-predictions-on-the-california-house-price-dataset-using-a-flax-and-jax-mlpr-8f00a228ae1a
- https://medium.com/@tracyrenee61/predict-on-the-boston-house-price-dataset-using-a-mlpregressor-made-from-jax-and-flax-90b08e36474c
In this blog post, I will review the code of a neural network built with Jax and Flax for a binary classification dataset.
I decided to create an MLPClassifier for a binary classification dataset to demonstrate the use of the sigmoid function in a neural network built with Flax. I have used sklearn's breast cancer dataset to illustrate how Flax and Jax can be used to build a binary MLPClassifier.
I have written the Python program in Google Colab, which is a free online Jupyter Notebook hosted by Google. Once I had created the notebook, I imported the libraries needed to run the program:
- Jax, to create Jax arrays and perform NumPy-like numerical computations with support for automatic differentiation and just-in-time compilation,
- Typing module, which attempts to…