What is the MNIST dataset?
The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of 70,000 images of handwritten digits. It is commonly used for training and testing machine learning classification algorithms.
Each image is a labeled 28 x 28 pixel grayscale image of a handwritten digit from 0 to 9.
The MNIST dataset is a reliable starting point for image classification problems.
Prerequisites for this post:
- Elementary knowledge of Python
- Fundamental knowledge of machine learning
- Any IDE of your choice (PyCharm preferred)
To get started, download the dataset first.
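```python
""" Classification algorithm on MNIST dataset """
from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt

# Download the data from OpenML (cached locally after the first run).
# as_frame=False returns NumPy arrays instead of a pandas DataFrame,
# so individual images can later be indexed directly as x[index].
mnist_data = fetch_openml('mnist_784', version=1, as_frame=False)
```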
Run the program to download the dataset. scikit-learn caches it under $HOME/scikit_learn_data/, so later runs load it from disk instead of downloading it again.
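If you are not sure where that cache folder lives on your machine, scikit-learn can report it. A minimal sketch, assuming a scikit-learn version that exposes get_data_home() in sklearn.datasets; the folder defaults to ~/scikit_learn_data unless the SCIKIT_LEARN_DATA environment variable overrides it:

```python
import os

from sklearn.datasets import get_data_home

# get_data_home() returns scikit-learn's download cache folder
# and creates it if it does not exist yet.
cache_dir = get_data_home()
print("scikit-learn data cache:", cache_dir)

# fetch_openml keeps its downloads in an "openml" subfolder of the cache.
openml_dir = os.path.join(cache_dir, "openml")
print("OpenML downloads present:", os.path.isdir(openml_dir))
```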
To view the feature metadata, unzip the file $HOME/scikit_learn_data/openml/openml.org/api/v1/json/data/features/554.gz.
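Alternatively, the compressed file can be read straight from Python without unzipping it on disk. A minimal sketch using only the standard library; load_openml_json is a hypothetical helper name, and the path assumes the default cache location:

```python
import gzip
import json
import os


def load_openml_json(path):
    """Read a gzip-compressed JSON file, such as the cached OpenML features file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    features_path = os.path.expanduser(
        "~/scikit_learn_data/openml/openml.org/api/v1/json/data/features/554.gz"
    )
    if os.path.exists(features_path):
        features = load_openml_json(features_path)
        # Show only the top-level keys; the nested layout is defined by OpenML.
        print(list(features.keys()))
    else:
        print("Features file not found; run fetch_openml first.")
```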
To print the length of the mnist_data variable, use print(len(mnist_data)).
This prints 9, which means mnist_data (a dictionary-like Bunch object) has 9 keys. These keys describe the dataset as a whole, not each image. To print the keys, use print(mnist_data.keys()).
By executing the program, you will see the following output on the screen.
dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'DESCR', 'details', 'categories', 'url'])
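Because fetch_openml returns a scikit-learn Bunch, which is a dictionary subclass, len() counts its keys rather than its images. The toy example below uses made-up data (not MNIST) to show the difference without downloading anything:

```python
import numpy as np
from sklearn.utils import Bunch

# A tiny stand-in for the real MNIST Bunch: 3 fake "images" of 784 pixels each.
toy = Bunch(
    data=np.zeros((3, 784)),
    target=np.array(["5", "0", "4"]),
)

# len() on the Bunch counts its keys, not images.
print(len(toy))          # 2
print(list(toy.keys()))  # ['data', 'target']

# The number of images is the number of rows in the data array.
print(len(toy.data))     # 3

# Attribute access returns the same object as key access.
print(toy.data is toy["data"])  # True
```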
Next, separate the data (pixel values) from the target (labels).
```python
def get_mnist_data_and_target():
    # 'data' holds the pixel values, 'target' holds the digit labels
    return mnist_data['data'], mnist_data['target']

x, y = get_mnist_data_and_target()
```
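With as_frame=False, the labels in y come back as strings such as '5' rather than integers. Many classifiers accept string labels, but converting them to integers makes comparisons like y[0] == 5 behave as expected. A minimal sketch; to_numeric_labels is a hypothetical helper name and the sample labels are made up:

```python
import numpy as np


def to_numeric_labels(target):
    """Convert string class labels such as '5' to small unsigned integers."""
    # np.asarray also accepts a list or a pandas Series of labels.
    return np.asarray(target).astype(np.uint8)


# Made-up labels in the string form returned by OpenML.
sample_target = np.array(["5", "0", "4", "1"], dtype=object)
labels = to_numeric_labels(sample_target)
print(labels)        # [5 0 4 1]
print(labels.dtype)  # uint8
```

For the real data, the same call would be y = to_numeric_labels(y).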
We have now downloaded the dataset and separated the data from the target. Before going further, let's look at the data. What does a handwritten image look like?
Go to the IDE and add the following:
```python
# Shape of x
print(x.shape)  # Output: (70000, 784)
```
Here, 70,000 is the total number of images (samples) and 784 is the number of features per image, one per pixel. As mentioned earlier, each MNIST image is 28 x 28 pixels, and 28 x 28 = 784.
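To see how the 784 values map back onto the 28 x 28 grid, here is a small NumPy sketch using a made-up vector whose values equal their own positions:

```python
import numpy as np

# A made-up flat row standing in for one MNIST sample: 784 values.
flat = np.arange(784)

# reshape(28, 28) fills the image row by row (C order), so the pixel at
# (row, col) comes from flat index row * 28 + col.
image = flat.reshape(28, 28)

row, col = 3, 10
print(image.shape)                              # (28, 28)
print(image[row, col] == flat[row * 28 + col])  # True
print(image[row, col])                          # 94
```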
The next step is to import matplotlib and random at the top of the program.
The random module picks a random image index, and Matplotlib's imshow() function displays that image on screen.
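One detail worth knowing: random.randint(a, b) includes both endpoints, so randint(0, 70000) can return 70000, one past the last valid index (69,999). A sketch of choosing a valid row index instead; random_index is a hypothetical helper name:

```python
import random

n_images = 70000  # number of rows in x for MNIST


def random_index(n):
    """Return a random index i with 0 <= i < n."""
    # randrange(n) excludes n itself, so it always matches valid row indices.
    return random.randrange(n)


index = random_index(n_images)
print("Index is:", index)
```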
```python
import matplotlib.pyplot as plt
import random

# Generate a random valid index between 0 and len(x) - 1 (69,999 for MNIST)
index = random.randint(0, len(x) - 1)
print("Index is: " + str(index))

# Each row of x is a flat vector of 784 pixels; reshape it back to 28 x 28
random_digit = x[index]
random_digit_image = random_digit.reshape(28, 28)

plt.imshow(random_digit_image, cmap='gray')
plt.axis('off')
plt.show()
```
Run the program, and a randomly selected handwritten digit is displayed on screen.
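To compare several digits at once, the same imshow() approach extends to a row of subplots. A minimal sketch; show_random_digits is a hypothetical helper, and it assumes x is a NumPy array of shape (num_images, 784) with matching labels in y:

```python
import random

import matplotlib.pyplot as plt
import numpy as np


def show_random_digits(x, y, n=5, seed=None):
    """Plot n randomly chosen digits side by side with their labels.

    Returns the figure and the list of chosen indices.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    rng = random.Random(seed)
    # sample() picks n distinct indices from 0 .. len(x) - 1
    indices = rng.sample(range(len(x)), n)

    # squeeze=False keeps axes as a 2-D array even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2.5), squeeze=False)
    for ax, idx in zip(axes[0], indices):
        ax.imshow(x[idx].reshape(28, 28), cmap="gray")
        ax.set_title(f"label: {y[idx]}")
        ax.axis("off")
    return fig, indices
```

For example, fig, indices = show_random_digits(x, y, n=5) followed by plt.show() displays five random digits with their labels.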
In the next tutorial, we will learn more about training a model on the MNIST data.