How to classify dog and cat images using a neural network written in Jax

Tracyrenee
6 min read · Nov 18, 2023

In my last blog post I discussed how to classify dog and cat images using a neural network built with numpy. In this blog post I have translated that numpy code into Google’s research library, Jax.

Jax's API closely mirrors numpy's, but there are a few differences, so the translation is not entirely straightforward. For instance, random numbers are generated differently in Jax: instead of relying on a global seed, every random function takes an explicit key. In addition, Jax arrays are immutable, so an element cannot be changed in place; an update has to produce a new array instead.
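To make these two differences concrete, here is a minimal sketch. The array names and shapes (`weights`, `bias`) are purely illustrative and are not taken from the classifier code; the point is only to show key-based random numbers and the `.at[...].set(...)` update pattern.

```python
import jax
import jax.numpy as jnp

# Random numbers: Jax has no global seed. A key is created once and
# split whenever fresh randomness is needed.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (3, 2)) * 0.01  # illustrative shape

# Immutability: numpy-style in-place assignment is rejected.
bias = jnp.zeros(3)
try:
    bias[0] = 1.0
    in_place_worked = True
except TypeError:
    in_place_worked = False

# The Jax way: .at[...].set(...) returns a new array and leaves
# the original untouched.
new_bias = bias.at[0].set(1.0)
```

Reusing the same key gives the same numbers every time, which is why the key is split before each draw.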

A neural network is a method in artificial intelligence that teaches computers to process data in a way inspired by the human brain. It is a type of machine learning, called deep learning, that arranges interconnected nodes, or neurons, in layers, loosely resembling the structure of the brain. The result is an adaptive system: the network measures its mistakes and adjusts itself to reduce them, improving as it sees more data. Because neural networks can learn and model nonlinear, complex relationships between input and output data, they can tackle complicated problems, such as summarising documents or recognising faces, and help computers make intelligent decisions with limited human assistance.
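The ideas in this description (layers of neurons, a nonlinear mapping from input to output, and learning from mistakes) can be sketched in a few lines of Jax. This is only a toy illustration under assumptions of my own: the layer sizes, the `tanh` activation, the learning rate and the random stand-in data are all hypothetical, and it is not the image classifier itself.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_in, n_hidden):
    # Small random weights and zero biases for two layers.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_in, n_hidden)) * 0.01,
        "b1": jnp.zeros(n_hidden),
        "W2": jax.random.normal(k2, (n_hidden, 1)) * 0.01,
        "b2": jnp.zeros(1),
    }

def forward(params, x):
    # Hidden layer with a nonlinear activation, then a sigmoid output
    # giving a probability between 0 and 1 (for example cat vs dog).
    hidden = jnp.tanh(x @ params["W1"] + params["b1"])
    return jax.nn.sigmoid(hidden @ params["W2"] + params["b2"])

def loss_fn(params, x, y):
    # Binary cross-entropy: the "mistake" the network learns from.
    p = forward(params, x)
    eps = 1e-7
    return -jnp.mean(y * jnp.log(p + eps) + (1 - y) * jnp.log(1 - p + eps))

key = jax.random.PRNGKey(0)
key, k_params, k_data = jax.random.split(key, 3)
params = init_params(k_params, n_in=12, n_hidden=4)

# Random stand-in data: 5 "images" of 12 features each, with made-up labels.
x = jax.random.uniform(k_data, (5, 12))
y = jnp.array([[0.0], [1.0], [0.0], [1.0], [1.0]])

probs = forward(params, x)
loss_before = loss_fn(params, x, y)

# One gradient descent step: nudge every parameter against its gradient.
grads = jax.grad(loss_fn)(params, x, y)
learning_rate = 0.1
params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
loss_after = loss_fn(params, x, y)
```

Because the parameters live in a dictionary of immutable arrays, the update builds a new dictionary rather than modifying the old one in place.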

Tracyrenee

I have five decades of experience in the world of work, having been in fast food, the military, business, non-profits, and the healthcare sector.