Use Jax to predict whether an image is a pizza or not

Nov 27, 2023

In my last post, I described how to use Python’s numpy library to classify an image as pizza or not pizza. That blog post can be found here:-

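In this blog post, I have translated the numpy code into Jax, a Python library for high-performance numerical computing that is widely used in machine learning research. Although Jax’s API closely mirrors numpy’s, the code is not a direct translation. For instance, Jax arrays are immutable: an element cannot be changed in place, so every update has to produce a modified copy of the array. Jax’s random number generation also differs from numpy’s. Instead of relying on a hidden global random state, it requires an explicit random key to be passed around, so the code that generates random values has to change.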
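As a minimal illustration of the immutability point, the sketch below contrasts an in-place numpy update with Jax’s functional `.at[...].set(...)` syntax. The array contents are arbitrary and chosen only for the example.

```python
import numpy as np
import jax.numpy as jnp

# numpy arrays are mutable: item assignment changes the array in place.
np_weights = np.zeros(3)
np_weights[0] = 1.0

# Jax arrays are immutable: item assignment raises a TypeError.
jnp_weights = jnp.zeros(3)
try:
    jnp_weights[0] = 1.0
except TypeError:
    print("Jax arrays do not support in-place item assignment")

# .at[index].set(value) returns a NEW array with the update applied;
# the original array is left unchanged.
updated = jnp_weights.at[0].set(1.0)

print(np_weights)   # [1. 0. 0.]
print(jnp_weights)  # [0. 0. 0.]
print(updated)      # [1. 0. 0.]
```

The copy is less costly than it looks: inside a `jax.jit`-compiled function, the compiler can often perform such updates in place behind the scenes.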

The dataset that I used is Kaggle’s pizza or not pizza dataset, which can be found here:-

I wrote the classification program in Python using Kaggle’s free online Jupyter Notebook environment.

Once I had created the Jupyter Notebook, I imported the libraries needed to run the program:-

  1. Jax, to create the neural network and perform numerical computations,
  2. PIL, to process and view the images,
  3. pathlib, to build the paths to the images,
  4. cv2 (OpenCV), to carry out computer vision tasks,
  5. os, to interact with the operating system, and
  6. matplotlib, to visualise the data.

I used Jax to set the random seed for the program.
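Unlike numpy’s `np.random.seed`, Jax has no global random state: a seed is turned into an explicit key, and that key is split whenever fresh randomness is needed. A minimal sketch, assuming a seed of 42 and hypothetical weight and bias shapes chosen purely for illustration:

```python
import jax

# Create the root random key from an integer seed (42 is an arbitrary choice).
key = jax.random.PRNGKey(42)

# Split the key instead of reusing it: reusing a key reproduces the same numbers.
key, w_key, b_key = jax.random.split(key, 3)

# Hypothetical layer parameters; the shapes are for illustration only.
weights = jax.random.normal(w_key, (4, 1)) * 0.01
bias = jax.random.normal(b_key, (1,)) * 0.01

# The same key always yields the same values, which makes runs reproducible.
again = jax.random.normal(w_key, (4, 1)) * 0.01
print(bool((weights == again).all()))  # True
```

Keeping the leftover `key` from the split means it can be split again later in the program without ever reusing a key that has already been consumed.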

I then used the os library to retrieve the images used in the program:-
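A minimal sketch of this step follows. The directory layout (one sub-folder per class, such as `pizza/` and `not_pizza/`) and the helper name `list_images` are assumptions for illustration. On Kaggle, attached datasets are mounted under `/kaggle/input/`, but the exact folder names should be checked in the notebook’s data panel. The demonstration builds a throwaway directory so the code runs anywhere.

```python
import os
import tempfile

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_images(directory):
    """Return the sorted full paths of image files directly inside `directory`.

    `list_images` is a hypothetical helper, not part of the original notebook.
    """
    paths = []
    for name in os.listdir(directory):
        # Case-insensitive check so that files like "photo.JPG" are included.
        if name.lower().endswith(IMAGE_EXTENSIONS):
            paths.append(os.path.join(directory, name))
    return sorted(paths)


# Demonstration on a temporary folder with hypothetical class sub-folders.
with tempfile.TemporaryDirectory() as root:
    for cls in ("pizza", "not_pizza"):
        os.makedirs(os.path.join(root, cls))
    for parts in (("pizza", "a.jpg"), ("pizza", "b.JPG"),
                  ("pizza", "notes.txt"), ("not_pizza", "c.png")):
        open(os.path.join(root, *parts), "w").close()  # empty placeholder file

    pizza_files = [os.path.basename(p) for p in list_images(os.path.join(root, "pizza"))]
    not_pizza_files = [os.path.basename(p) for p in list_images(os.path.join(root, "not_pizza"))]

print(pizza_files)      # ['a.jpg', 'b.JPG']
print(not_pizza_files)  # ['c.png']
```

Filtering by extension guards against stray non-image files in a dataset folder, which would otherwise make later image-loading code fail.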

I used the PIL library to visualise one image of pizza and one image of not pizza:-
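Displaying the two images side by side could be sketched as follows. PIL opens each file and matplotlib draws both images in one figure. The function name `show_pair` is hypothetical, and the demonstration uses solid-colour placeholder images in place of real dataset files.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def show_pair(pizza_path, not_pizza_path):
    """Plot one pizza image and one not-pizza image next to each other."""
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, path, title in zip(axes, (pizza_path, not_pizza_path), ("Pizza", "Not pizza")):
        with Image.open(path) as img:
            # Convert to RGB and copy into an array before the file is closed.
            ax.imshow(np.asarray(img.convert("RGB")))
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig


with tempfile.TemporaryDirectory() as tmp:
    # Placeholder images stand in for real files from the dataset.
    pizza_path = os.path.join(tmp, "pizza.jpg")
    other_path = os.path.join(tmp, "not_pizza.jpg")
    Image.new("RGB", (64, 64), (200, 120, 40)).save(pizza_path)
    Image.new("RGB", (64, 64), (40, 120, 200)).save(other_path)
    fig = show_pair(pizza_path, other_path)

titles = [ax.get_title() for ax in fig.axes]
print(titles)  # ['Pizza', 'Not pizza']
```

In a notebook, calling `plt.show()` after `show_pair(...)` renders the figure below the cell.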




