
Use Jax to predict whether an image is a pizza or not

4 min read · Nov 27, 2023

In my last post I described how to use Python’s numpy library to classify an image as being pizza or not pizza. That blog post can be found here:- https://medium.com/@tracyrenee61/use-numpy-to-classify-an-image-as-being-pizza-or-not-pizza-f25c7c3fe9b7

In this blog post I have translated the numpy code to Jax, a Python library for high-performance numerical computing that is widely used in machine learning research. Although Jax code is similar to numpy in many ways, it is not a direct translation. For instance, Jax arrays are immutable, so in-place updates have to be written differently. Jax also handles random numbers differently from numpy, using explicit PRNG keys, which requires a change in the coding.

The dataset that I have used is Kaggle’s pizza or not pizza dataset and it can be found here:- https://www.kaggle.com/datasets/carlosrunner/pizza-not-pizza/data

I have written the classification program in Python using Kaggle’s free online Jupyter Notebook.

Once I created the Jupyter Notebook, I imported the libraries that I would need to execute the program (sketched in the code after this list), being:-

  1. Jax to create the neural network and perform numerical computations,
  2. PIL to process and view the images,
  3. pathlib to establish the paths to the images,
  4. cv2 to carry out computer vision functionality,
  5. os to interact with the operating system, and
  6. matplotlib to visualise the data
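
In outline, the import cell looks something like this:

```python
# Core numerical computing with Jax
import jax
import jax.numpy as jnp

# Image handling and visualisation
from PIL import Image
from pathlib import Path
import cv2
import os
import matplotlib.pyplot as plt
```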

I used Jax to set the random seed for the program, in the form of a PRNG key.
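
Jax handles randomness through an explicit PRNG key rather than a global seed. A minimal sketch (the seed value 42 here is arbitrary, not necessarily the one used in the original notebook):

```python
# Jax uses explicit PRNG keys instead of a global random seed
key = jax.random.PRNGKey(42)
```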

I then used the os library to retrieve the images used in the program:-

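Something along these lines, assuming the dataset is mounted at the usual Kaggle input path and split into pizza and not_pizza folders (the exact directory names are an assumption based on the dataset's structure):

```python
# Hypothetical dataset location on Kaggle; adjust if your paths differ
base_dir = "/kaggle/input/pizza-not-pizza/pizza_not_pizza"

pizza_dir = os.path.join(base_dir, "pizza")
not_pizza_dir = os.path.join(base_dir, "not_pizza")

# List the image files in each class directory
pizza_files = os.listdir(pizza_dir)
not_pizza_files = os.listdir(not_pizza_dir)

print(len(pizza_files), "pizza images")
print(len(not_pizza_files), "not pizza images")
```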

I used the PIL library to visualise an image of pizza and an image of not pizza:-

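An illustrative sketch, simply opening and displaying the first file found in each folder:

```python
# Open the first image from each class with PIL
pizza_example = Image.open(os.path.join(pizza_dir, pizza_files[0]))
not_pizza_example = Image.open(os.path.join(not_pizza_dir, not_pizza_files[0]))

# Show the two images side by side with matplotlib
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(pizza_example)
plt.title("pizza")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(not_pizza_example)
plt.title("not pizza")
plt.axis("off")
plt.show()
```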

I defined the function, load_and_process_images, which loads and processes all of the images in the dataset:-

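A rough sketch of such a function, assuming each image is read with cv2, resized to a fixed size (64 x 64 here, an assumed value), scaled to the range 0 to 1 and flattened into a vector:

```python
IMG_SIZE = 64  # assumed image size; the original may use a different value

def load_and_process_images(folder, label):
    """Read every image in a folder, resize, normalise and flatten it,
    and return the image vectors together with their labels."""
    images, labels = [], []
    for file_name in os.listdir(folder):
        img = cv2.imread(os.path.join(folder, file_name))
        if img is None:          # skip unreadable files
            continue
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = img.astype("float32") / 255.0   # scale pixels to [0, 1]
        images.append(img.flatten())
        labels.append(label)
    return images, labels
```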

I created the path for the pizza images and the path for the not pizza images.

I then loaded and processed the images and labels for the images of pizza and not pizza.

I stacked the data and concatenated the labels:-

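Putting those steps together, roughly as follows (labelling pizza as 1 and not pizza as 0 is my assumption about the encoding):

```python
# Paths for the two classes (assumed layout, see above)
pizza_path = Path(base_dir) / "pizza"
not_pizza_path = Path(base_dir) / "not_pizza"

# Load and process both classes: pizza = 1, not pizza = 0
pizza_images, pizza_labels = load_and_process_images(str(pizza_path), 1)
not_pizza_images, not_pizza_labels = load_and_process_images(str(not_pizza_path), 0)

# Stack the image vectors into one Jax array and concatenate the labels
X = jnp.stack(pizza_images + not_pizza_images)
y = jnp.concatenate([jnp.array(pizza_labels), jnp.array(not_pizza_labels)])

print(X.shape, y.shape)
```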

I shuffled the dataset and reshaped the dependent variable to make it compatible with the neural network.

I then split the dataset into training and testing sets:-

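A sketch of the shuffle, reshape and split, using Jax's permutation function instead of numpy's in-place shuffle (the 80/20 split ratio is an assumption):

```python
# Shuffle with a Jax permutation rather than an in-place numpy shuffle
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, X.shape[0])
X, y = X[perm], y[perm]

# Reshape the labels to a column vector so they match the network's output
y = y.reshape(-1, 1)

# Simple 80/20 train/test split
split = int(0.8 * X.shape[0])
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
```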

I defined the neural network architecture by initialising the values of the input_size, hidden_size, output_size, learning_rate, and epochs.

I then initialised the weights and bias of the neural network to matrices of zero:-

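Roughly as follows (the hidden size, learning rate and number of epochs shown are assumed values, not necessarily those used in the notebook):

```python
# Network dimensions and training settings (values are assumptions)
input_size = X_train.shape[1]   # number of flattened pixel values per image
hidden_size = 64
output_size = 1
learning_rate = 0.01
epochs = 100

# Initialise the weights and biases to zeros, as described above
W1 = jnp.zeros((input_size, hidden_size))
b1 = jnp.zeros((1, hidden_size))
W2 = jnp.zeros((hidden_size, output_size))
b2 = jnp.zeros((1, output_size))
```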

I defined the two functions that are used in the neural network, being the sigmoid function and the sigmoid derivative function:-
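These are the standard definitions:

```python
def sigmoid(x):
    """Logistic activation: squashes values into (0, 1)."""
    return 1.0 / (1.0 + jnp.exp(-x))

def sigmoid_derivative(a):
    """Derivative of the sigmoid, expressed in terms of its output a."""
    return a * (1.0 - a)
```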

I then trained the network on the training data:-

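A sketch of the training loop, with the forward and backward passes written out by hand in the same style as the numpy version; because Jax arrays are immutable, each update reassigns the weight arrays rather than modifying them in place:

```python
for epoch in range(epochs):
    # Forward pass
    a1 = sigmoid(X_train @ W1 + b1)
    a2 = sigmoid(a1 @ W2 + b2)

    # Backward pass (gradient of the squared error)
    error = y_train - a2
    delta2 = error * sigmoid_derivative(a2)
    delta1 = (delta2 @ W2.T) * sigmoid_derivative(a1)

    # Gradient-descent updates; reassignment instead of in-place += 
    W2 = W2 + learning_rate * a1.T @ delta2
    b2 = b2 + learning_rate * jnp.sum(delta2, axis=0, keepdims=True)
    W1 = W1 + learning_rate * X_train.T @ delta1
    b1 = b1 + learning_rate * jnp.sum(delta1, axis=0, keepdims=True)

    if epoch % 10 == 0:
        loss = jnp.mean(error ** 2)
        print(f"epoch {epoch}: loss {float(loss):.4f}")
```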

And finally, I evaluated the model using the test dataset:-

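A sketch of the evaluation step, thresholding the sigmoid output at 0.5 and comparing the predictions against the test labels:

```python
# Forward pass on the test set
a1_test = sigmoid(X_test @ W1 + b1)
a2_test = sigmoid(a1_test @ W2 + b2)

# Threshold the probabilities at 0.5 to get class predictions
predictions = (a2_test > 0.5).astype(jnp.float32)

accuracy = jnp.mean(predictions == y_test)
print(f"Test accuracy: {float(accuracy) * 100:.2f}%")
```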

In this instance, I achieved an accuracy of 58.12%. This accuracy is not very high, and there are other models, such as pretrained models, that will achieve a higher score. The point, however, is to understand the inner workings of a neural network in order to progress to more complex machine learning models.

I have created a code review to accompany this blog post, which can be viewed here:- https://youtu.be/6WKuDBh0Hgc


Written by Crystal X

I have over five decades experience in the world of work, being in fast food, the military, business, non-profits, and the healthcare sector.
