MNIST Dataset Analysis (Part-1)

Posted by Ashutosh on January 25, 2020

What is the MNIST dataset?

The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of 70,000 images of handwritten digits, split into 60,000 training images and 10,000 test images. It is commonly used for testing machine learning classification algorithms.

Each image shows a single handwritten digit (0 to 9), is labeled with that digit, and is 28 x 28 pixels in size.

The MNIST dataset is a good starting point for learning image classification.

Prerequisites

  • Elementary knowledge of Python
  • Fundamental knowledge of Machine Learning.
  • Any IDE you prefer (PyCharm preferred)
  • Python with scikit-learn and matplotlib installed

Getting Started

To get started, download the dataset first. The fetch_openml() function from scikit-learn downloads MNIST from OpenML for you.

Classification algorithm on MNIST dataset

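    from sklearn.datasets import fetch_openml
    import matplotlib.pyplot as plt
    import random

    # Download MNIST (cached after the first run). as_frame=False returns NumPy
    # arrays, so that x[index] later selects one image rather than a column.
    mnist_data = fetch_openml('mnist_784', version=1, as_frame=False)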

Run the program. The first run downloads the dataset and caches it in $HOME/scikit_learn_data/, so later runs load it from disk.
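If you are unsure where the cache folder is on your machine, scikit-learn's get_data_home() returns it. It defaults to ~/scikit_learn_data unless the SCIKIT_LEARN_DATA environment variable points somewhere else. A small sketch (the 'openml' subfolder name is taken from the path mentioned in this post):

```python
import os
from sklearn.datasets import get_data_home

# Folder where fetch_openml caches downloaded datasets
# (get_data_home creates it if it does not exist yet)
data_home = get_data_home()
print(data_home)

# OpenML downloads are kept in an 'openml' subfolder of that folder
openml_dir = os.path.join(data_home, 'openml')
print(openml_dir)
```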

To view the raw feature data, unzip the compressed files under $HOME/scikit_learn_data/openml/.

To print the length of the mnist_data variable, add:

    print(len(mnist_data))

In the scikit-learn version used here, this prints '9' on the screen. It means that mnist_data is a dictionary-like object (a scikit-learn Bunch) with 9 keys; the number describes the whole object, not each image. To print the keys:

    # print keys
    print(mnist_data.keys())

By executing the program, you will see the following output on the screen.

    dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'DESCR', 'details', 'categories', 'url'])
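Because len() on this object counts keys, the number of images has to be read from the 'data' entry instead. The sketch below uses a plain dictionary as a hypothetical stand-in for mnist_data, with only 3 fake 784-pixel images instead of 70,000, so it runs without downloading anything:

```python
import numpy as np

# Hypothetical stand-in for mnist_data: same 9 keys,
# but only 3 fake images of 784 pixels each
fake_mnist = {
    'data': np.zeros((3, 784)),
    'target': np.array(['5', '0', '4']),
    'frame': None,
    'feature_names': ['pixel' + str(i) for i in range(1, 785)],
    'target_names': ['class'],
    'DESCR': 'stand-in description',
    'details': {},
    'categories': {},
    'url': '',
}

print(len(fake_mnist))           # 9: the number of keys
print(len(fake_mnist['data']))   # 3: the number of images (rows)
print(fake_mnist['data'].shape)  # (3, 784): images x pixels per image
```

With the real dataset, len(mnist_data['data']) gives the number of images instead.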

The next task is to separate the data (the pixel values) from the target (the digit labels).

    def get_mnist_data_and_target():
        # 'data' holds the pixel values, 'target' holds the digit labels
        return mnist_data['data'], mnist_data['target']

    x, y = get_mnist_data_and_target()
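With fetch_openml, the labels in y typically come back as strings such as '5' rather than integers, because OpenML stores the target column as a category. Converting them to integers makes numeric comparisons easier later on. A minimal sketch, using a small hypothetical array in place of the real y:

```python
import numpy as np

# Hypothetical stand-in for y: digit labels stored as strings
y = np.array(['5', '0', '4', '1', '9'])

# Convert the string labels to small unsigned integers (digits fit in 0..255)
y_int = y.astype(np.uint8)

print(y_int)        # [5 0 4 1 9]
print(y_int.dtype)  # uint8
```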

So far, we have downloaded the dataset and separated the data from the target. Before proceeding, let's see what the data contains. What does a handwritten image look like?

Go to the IDE and add the following:

    # Shape of x
    print(x.shape)

This prints:

    (70000, 784)

70,000 is the total number of images, and 784 is the number of features (pixels) in each image. As mentioned earlier, each MNIST image is 28 x 28 pixels, and 28 x 28 = 784.
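The link between the 784 columns and the 28 x 28 image can be checked with a small sketch. It assumes, as the reshape(28, 28) call used below does, that pixels are stored row by row (NumPy's default order); the array here is synthetic, not real MNIST data:

```python
import numpy as np

# Hypothetical 28 x 28 image whose pixel values are simply 0, 1, ..., 783
image = np.arange(28 * 28).reshape(28, 28)

# Flattening row by row gives a 784-long feature vector, like one row of x
row = image.reshape(-1)
print(row.shape)  # (784,)

# Reshaping the vector back to 28 x 28 recovers the image exactly
restored = row.reshape(28, 28)
print(np.array_equal(image, restored))  # True

# Pixel (r, c) of the image sits at position r * 28 + c in the vector
r, c = 3, 5
print(row[r * 28 + c] == image[r, c])  # True
```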

The next step is to import matplotlib and random at the top of the program, as done in the first snippet.

The random module generates random numbers; we use it to pick a random image index. From Matplotlib, we use the imshow() function to display the image on screen.
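One detail matters when picking the index: random.randint(a, b) includes both endpoints, so an upper bound of 70000 can return 70000, one past the last valid index (69,999), which raises an IndexError. The check below shows this on a small hypothetical range of 5 images, along with two safe alternatives:

```python
import random

random.seed(0)  # fixed seed so the check is repeatable
n_images = 5  # small stand-in for the 70,000 images in x
valid = set(range(n_images))  # valid indices are 0..4

# randint includes BOTH ends: randint(0, n_images) can return 5, an invalid index
too_wide = {random.randint(0, n_images) for _ in range(1000)}
print(n_images in too_wide)  # True: the out-of-range index does come up

# Subtracting 1 from the upper bound keeps every draw in range
fixed = {random.randint(0, n_images - 1) for _ in range(1000)}
print(fixed <= valid)  # True

# randrange excludes its upper bound, so it takes n_images directly
also_fixed = {random.randrange(n_images) for _ in range(1000)}
print(also_fixed <= valid)  # True
```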

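    # Generate a random index between 0 and 69,999 (randint includes both ends)
    index = random.randint(0, len(x) - 1)
    print("Index is: " + str(index))

    random_digit = x[index]
    # Each row holds 784 pixels; reshape it back into a 28 x 28 image
    random_digit_image = random_digit.reshape(28, 28)

    # Display the image in grayscale
    plt.imshow(random_digit_image, cmap='binary')
    plt.axis('off')
    plt.show()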


Run the above program, and a random digit image is displayed on screen.
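Showing the label next to the image is an easy way to confirm that x and y line up, since row i of x belongs to label i of y. Below is a hypothetical helper, pick_random_digit (not part of scikit-learn or NumPy), sketched with NumPy's random Generator and synthetic arrays so it runs without downloading anything:

```python
import numpy as np

def pick_random_digit(x, y, rng):
    """Hypothetical helper: return (index, 28 x 28 image, label) for a random row."""
    # integers(low, high) excludes high, so len(x) is a safe upper bound
    index = int(rng.integers(0, len(x)))
    image = x[index].reshape(28, 28)
    return index, image, y[index]

# Synthetic stand-ins: 10 "images" where every pixel of row i equals i,
# and labels that are the matching digit strings
x_fake = np.repeat(np.arange(10), 784).reshape(10, 784)
y_fake = np.array([str(i) for i in range(10)])

rng = np.random.default_rng(42)
index, image, label = pick_random_digit(x_fake, y_fake, rng)
print(index, image.shape, label)
```

With the real data, you could pass x and y instead, then call plt.imshow(image, cmap='binary') and plt.title('Label: ' + str(label)) before plt.show().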

In the next tutorial, we will learn more about training a model on the MNIST data.