Use Jax to create a neural network and predict on the Iris dataset
In my last blog post I discussed how to use Python's numerical library, numpy, to create a neural network that makes predictions on the Iris dataset. That post can be found here: https://medium.com/@tracyrenee61/predict-on-the-iris-dataset-using-a-neural-network-made-in-numpy-794c31e4ddef
Neural networks, also known as artificial neural networks (ANNs), are a subset of machine learning and are at the heart of deep learning algorithms. Their name and structure are inspired by the human brain, mimicking the way that biological neurons signal to one another.
ANNs are composed of layers of nodes: an input layer, one or more hidden layers, and an output layer. Each node, or artificial neuron, connects to other nodes through weighted connections and has an associated threshold. If the output of an individual node is above the specified threshold value, that node is activated and sends data to the next layer of the network; otherwise, no data is passed along. In practice, most networks replace this hard threshold with a smooth activation function such as ReLU or the sigmoid, so that the weights can be learned by gradient descent.
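To make the layer structure concrete, here is a minimal sketch of a forward pass in Jax for an Iris-shaped network. The four inputs (the flower's measurements) and three outputs (the species) follow from the dataset; the hidden layer size of 8, the He-style random initialisation, and the use of ReLU in place of a hard threshold are my own assumptions for illustration, not necessarily the choices made in the program discussed here. Untrained, the network's predictions are meaningless; the point is only to show data flowing through the layers.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(4, 8, 3)):
    """Create one (weights, bias) pair per layer. Hidden size 8 is an arbitrary assumption."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)  # one PRNG key per layer
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params

def forward(params, x):
    """Pass a batch through the network: ReLU on hidden layers, softmax on the output."""
    for w, b in params[:-1]:
        # ReLU outputs 0 when the weighted sum is negative, i.e. the node "passes nothing along"
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    logits = x @ w + b
    return jax.nn.softmax(logits, axis=-1)  # probabilities over the 3 Iris species

key = jax.random.PRNGKey(0)
params = init_params(key)
x = jnp.array([[5.1, 3.5, 1.4, 0.2]])  # sepal length/width, petal length/width of one flower
probs = forward(params, x)
print(probs.shape)  # (1, 3)
```

Each row of `probs` sums to 1, and the predicted class would be `jnp.argmax(probs, axis=-1)` once the weights have been trained.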
I chose the Iris dataset because it is simple and can yield high accuracy with the right model.
I have converted the program from numpy to Jax, Google Brain's up-and-coming library intended primarily for research use. Jax's API closely mirrors numpy's, but there are some differences: for example, Jax arrays are immutable, which makes them less flexible to work with. In addition, Jax's random…
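The immutability difference is easiest to see side by side. The short sketch below (the array names are hypothetical) shows that numpy lets you assign into an array in place, while the same assignment on a Jax array raises a `TypeError`; Jax instead offers the `.at[...].set(...)` syntax, which returns a new, updated array and leaves the original untouched.

```python
import numpy as np
import jax.numpy as jnp

# numpy: in-place assignment mutates the existing array
np_arr = np.zeros(3)
np_arr[0] = 1.0

# Jax: the same API for creating arrays, but item assignment is not allowed
jax_arr = jnp.zeros(3)
mutation_failed = False
try:
    jax_arr[0] = 1.0
except TypeError:
    mutation_failed = True  # Jax arrays are immutable

# The functional alternative: build a new array with the update applied
new_arr = jax_arr.at[0].set(1.0)

print(np_arr, jax_arr, new_arr)
```

Because updates produce new arrays, code ported from numpy has to reassign results (for example `weights = weights.at[i].set(value)`) rather than modify arrays in place.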