How to classify dog and cat images using a neural network written in Jax

6 min read | Nov 18, 2023

In my last blog post, I discussed how to classify dog and cat images using a neural network built with numpy. In this post, I translate that numpy code into Jax, Google’s research library.

Jax’s API closely mirrors numpy’s, but there are a few differences, so the translation is not straightforward. For instance, random numbers are generated differently in Jax: there is no global random state, and every random call takes an explicit key. In addition, Jax arrays are immutable, so changing an element means creating an updated copy of the array rather than assigning to it in place.
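To make these two differences concrete, here is a minimal sketch. The array shapes, the 0.01 scale, and the variable names are illustrative assumptions, not the values used in my network.

```python
import jax
import jax.numpy as jnp

# Jax has no global random state: every random call takes an explicit key.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # split before use so keys are not reused

# Hypothetical 3x2 weight matrix, scaled down as is common for initialisation.
weights = jax.random.normal(subkey, (3, 2)) * 0.01

# Passing the same key again reproduces exactly the same numbers.
same_weights = jax.random.normal(subkey, (3, 2)) * 0.01

# Jax arrays are immutable: `biases[0] = 1.0` would raise a TypeError.
biases = jnp.zeros(3)

# Instead, .at[...].set(...) returns a new array and leaves `biases` unchanged.
updated = biases.at[0].set(1.0)
```

In numpy the equivalent code would call np.random.randn(3, 2) and assign to biases[0] directly, which is why the translation needs some rewriting rather than a simple change of import.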

A neural network is a method in artificial intelligence that teaches computers to process data in a way inspired by the human brain. It is a type of machine learning, called deep learning, that uses interconnected nodes, or neurons, arranged in layers. The result is an adaptive system that learns from its mistakes and improves with training. Artificial neural networks can therefore tackle complicated problems, such as summarising documents or recognising faces, with greater accuracy. They can also help computers make intelligent decisions with limited human assistance, because they can learn and model relationships between input and output data that are nonlinear and complex.

Computer vision is the ability of computers to extract information and insights from images and videos. With neural networks, computers can distinguish and recognise images in much the same way humans do.

Neural networks are composed of three types of layers:

  1. Input layer. Information from the outside world enters the artificial neural network from the input layer. Input nodes process the data, analyse or categorise it, and pass it on to the next layer.
  2. Hidden layer. Hidden layers take their input from the input layer or other hidden layers. Artificial neural networks can have a large number of hidden layers. Each hidden layer analyses the output from the previous layer, processes it further, and passes it on to the next layer.
  3. Output layer. The output layer gives the final result of all the data processing by the artificial neural network. It can have single or multiple nodes. For instance, if we have a binary (yes/no)…
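
The three kinds of layer described above can be sketched in Jax as a single forward pass. This is only an illustration under stated assumptions: the 64x64 RGB input size, the 16 hidden nodes, the ReLU and sigmoid activations, the single output node, and the names init_params and forward are all hypothetical choices, not necessarily those of the network in this post.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_hidden, n_out):
    """Create small random weights and zero biases for one hidden layer."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_in, n_hidden)) * 0.01,
        "b1": jnp.zeros(n_hidden),
        "W2": jax.random.normal(k2, (n_hidden, n_out)) * 0.01,
        "b2": jnp.zeros(n_out),
    }


def forward(params, x):
    """Input layer -> hidden layer (ReLU) -> output layer (sigmoid)."""
    # Hidden layer: weighted sum of the inputs followed by a nonlinearity.
    hidden = jax.nn.relu(x @ params["W1"] + params["b1"])
    # Output layer: one node whose value is squashed into the range (0, 1).
    logits = hidden @ params["W2"] + params["b2"]
    return jax.nn.sigmoid(logits)


key = jax.random.PRNGKey(0)
n_pixels = 64 * 64 * 3  # assumed size of a flattened 64x64 RGB image
params = init_params(key, n_in=n_pixels, n_hidden=16, n_out=1)

# Placeholder batch of 4 "images"; real data would be scaled pixel values.
batch = jnp.ones((4, n_pixels))
probs = forward(params, batch)  # one value per image, shape (4, 1)
```

Each row of probs can be read as the network's confidence in one of the two classes. With untrained weights the values carry no meaning; training would adjust the entries of params.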




I have close to five decades of experience in the world of work, spanning fast food, the military, business, non-profits, and the healthcare sector.