Use Jax to create a neural network and predict on the Iris dataset

Crystal X
5 min read · Oct 14, 2023

In my last blog post, I discussed how to use Python’s numerical library, NumPy, to create a neural network that makes predictions on the Iris dataset. That post can be found here: https://medium.com/@tracyrenee61/predict-on-the-iris-dataset-using-a-neural-network-made-in-numpy-794c31e4ddef

Neural networks, also known as artificial neural networks (ANNs), are a subset of machine learning and are at the heart of deep learning algorithms. Their name and structure are inspired by the human brain, mimicking the way that biological neurons signal to one another.

ANNs consist of layers of nodes: an input layer, one or more hidden layers, and an output layer. Each node, or artificial neuron, connects to nodes in the next layer and has associated weights and a threshold. If the output of an individual node is above the specified threshold value, that node is activated and sends data to the next layer of the network; otherwise, no data is passed along.
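A single layer of the threshold units described above can be sketched in JAX as a weighted sum plus a bias for every node, followed by an activation function. The `step` and `dense_layer` helpers and the example weights are hypothetical illustrations, not code from this post. In practice, trainable networks usually swap the hard threshold for a smooth activation such as `jax.nn.sigmoid` or `jax.nn.relu`, because a step function has zero gradient almost everywhere.

```python
import jax
import jax.numpy as jnp

def step(z, threshold=0.0):
    # Classic threshold unit: outputs 1.0 ("fires") when z exceeds the threshold, else 0.0
    return jnp.where(z > threshold, 1.0, 0.0)

def dense_layer(x, W, b, activation):
    # Weighted sum of the inputs plus a bias for every node, then the activation
    return activation(x @ W + b)

# Hypothetical example: one sample with 2 input features feeding a layer of 2 nodes
x = jnp.array([1.0, 2.0])
W = jnp.array([[0.5, -1.0],
               [0.25, 0.5]])
b = jnp.array([-0.5, 0.0])

print(dense_layer(x, W, b, step))  # → [1. 0.]  (only the first node exceeds 0)

# The same layer with a smooth, differentiable activation instead of a hard threshold
print(dense_layer(x, W, b, jax.nn.sigmoid))
```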

I chose the Iris dataset because it is small and simple (150 samples, four numeric features and three species), and the right model can classify it with high accuracy.
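As a minimal sketch of what "the right model" could look like in JAX, assuming a small fully connected network (4 inputs, 8 hidden ReLU units, 3 output logits) trained with plain full-batch gradient descent on a softmax cross-entropy loss. The layer sizes, learning rate, number of steps and the function names `init_params`, `forward`, `loss_fn`, `sgd_step` and `train` are hypothetical choices, not the author's code.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(4, 8, 3)):
    # Hypothetical sizes: 4 Iris features -> 8 hidden units -> 3 species
    params = []
    keys = jax.random.split(key, len(sizes) - 1)  # one PRNG key per layer
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def forward(params, X):
    # ReLU on the hidden layers; the output layer returns raw logits
    for W, b in params[:-1]:
        X = jax.nn.relu(X @ W + b)
    W, b = params[-1]
    return X @ W + b

def loss_fn(params, X, Y):
    # Mean cross-entropy between softmax(logits) and one-hot targets Y
    log_probs = jax.nn.log_softmax(forward(params, X))
    return -jnp.mean(jnp.sum(Y * log_probs, axis=1))

@jax.jit
def sgd_step(params, X, Y, lr=0.1):
    grads = jax.grad(loss_fn)(params, X, Y)
    # Build new parameters rather than updating in place (JAX arrays are immutable)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

def train(X, y, num_classes=3, steps=500, seed=0):
    # Standardize features; the returned params therefore expect standardized inputs
    X = (X - X.mean(axis=0)) / X.std(axis=0)
    Y = jax.nn.one_hot(y, num_classes)  # integer labels -> one-hot rows
    params = init_params(jax.random.PRNGKey(seed), (X.shape[1], 8, num_classes))
    for _ in range(steps):
        params = sgd_step(params, X, Y)
    preds = jnp.argmax(forward(params, X), axis=1)
    return params, float(jnp.mean(preds == y))  # accuracy on the training data
```

With scikit-learn installed, the real data can be loaded with `from sklearn.datasets import load_iris` and `X, y = load_iris(return_X_y=True)`, then passed as `train(jnp.asarray(X), jnp.asarray(y))`. The accuracy returned here is measured on the training data; holding out a test split would give a fairer estimate.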

I have converted the program from NumPy to JAX, Google Brain’s up-and-coming library that is intended primarily for research. JAX’s API closely mirrors NumPy’s, but there are some differences: for example, JAX arrays are immutable, which makes them less flexible to work with. In addition, JAX’s random…
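The immutability point, along with one well-documented difference in random number generation (JAX uses explicit PRNG keys instead of NumPy's hidden global state), can be illustrated with a short sketch. This shows general JAX behaviour only and is not taken from the original program; the variable names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy arrays can be modified in place
a_np = np.zeros(3)
a_np[0] = 1.0

# JAX arrays are immutable: item assignment raises a TypeError,
# so an indexed update returns a new array and leaves the original untouched
a_jax = jnp.zeros(3)
try:
    a_jax[0] = 1.0
except TypeError:
    pass  # expected: JAX arrays do not support item assignment
b_jax = a_jax.at[0].set(1.0)

# JAX's random functions take an explicit key; split it to get fresh randomness
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
w = jax.random.normal(subkey, (4, 3))
w_again = jax.random.normal(subkey, (4, 3))  # same key -> identical numbers
```

Because the same key always produces the same numbers, code that needs new random values must split the key and use a fresh subkey each time.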

Written by Crystal X

I have over five decades of experience in the world of work, in fast food, the military, business, non-profits, and the healthcare sector.