Applying K-Means to real-world data: MNIST

The MNIST dataset, a collection of handwritten digits, is a popular benchmark for machine learning algorithms. In this guide, we will explore how to apply K-means clustering to the MNIST dataset and evaluate the results.

Understanding the MNIST Dataset

The MNIST dataset consists of 70,000 grayscale images (60,000 for training and 10,000 for testing), each showing a handwritten digit from 0 to 9. Each image is 28×28 pixels, so flattening it yields 784 features, one per pixel.

Loading the MNIST Dataset

To load the MNIST dataset, we can use the keras.datasets module:

Python

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

Preprocessing the Data

Before applying K-means, we need to preprocess the data:

  1. Flatten the images: Convert each 28×28 image into a one-dimensional vector of 784 pixels.
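  2. Normalize the pixel values: Convert the integer pixel values (0 to 255) to floats in the range [0, 1] by dividing by 255. Because every pixel is scaled by the same factor, this does not change which centroid each point is closest to; it mainly gives the features a small, consistent floating-point scale.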

Python

X_train = X_train.reshape(X_train.shape[0], -1).astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], -1).astype('float32') / 255

Applying K-Means

Create a K-means model and fit it to the training data:

Python

from sklearn.cluster import KMeans

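n_clusters = 10  # one cluster per digit (0 to 9)
# n_init=10 runs ten k-means++ initializations and keeps the best result; it is set
# explicitly because the default value changed in recent scikit-learn versions.
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
kmeans.fit(X_train)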
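Under the hood, fitting K-means alternates two steps until the centroids stop moving: assign every point to its nearest centroid, then move each centroid to the mean of the points assigned to it. The following is a minimal NumPy sketch of that loop (Lloyd's algorithm) for illustration only. It is not scikit-learn's implementation: the function names `squared_distances` and `lloyd_kmeans` are hypothetical, and it seeds the centroids with randomly chosen data points instead of the k-means++ seeding scikit-learn uses by default.

```python
import numpy as np

def squared_distances(X, centers):
    """Squared Euclidean distances, shape (n_points, n_centers).

    Uses ||x||^2 - 2 x.c + ||c||^2 so memory stays O(n * k) instead of O(n * k * d).
    """
    d2 = (X ** 2).sum(axis=1)[:, None] - 2.0 * X @ centers.T + (centers ** 2).sum(axis=1)[None, :]
    return np.maximum(d2, 0.0)  # clip tiny negative values caused by rounding

def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Hypothetical minimal K-means (Lloyd's algorithm). Returns centers, labels, inertia."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Simplified initialization: k distinct data points chosen at random
    # (scikit-learn's KMeans defaults to k-means++ seeding instead).
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: each point goes to its nearest centroid.
        labels = squared_distances(X, centers).argmin(axis=1)
        # Update step: each centroid moves to the mean of its points;
        # an empty cluster keeps its previous centroid.
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break  # converged: the centroids stopped moving
        centers = new_centers
    d2 = squared_distances(X, centers)
    labels = d2.argmin(axis=1)
    # Inertia: sum of squared distances from each point to its assigned centroid,
    # the quantity K-means minimizes (scikit-learn exposes it as kmeans.inertia_).
    inertia = d2[np.arange(len(X)), labels].sum()
    return centers, labels, inertia
```

Computing distances through the expanded form matters at MNIST scale: broadcasting all 60,000 × 10 × 784 differences at once would need several gigabytes of memory, while the expanded form only materializes a 60,000 × 10 matrix.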

Evaluating the Clustering Results

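We can evaluate the clustering with internal metrics such as the Davies-Bouldin index, which uses only the data and the cluster assignments, or with external metrics such as purity, which compare the clusters against known labels. Because MNIST comes with ground-truth labels, we can also compare the predicted clusters with the true digits directly. Cluster IDs are arbitrary, however, so each cluster must first be mapped to a digit (here, the most frequent true digit among its training points) before computing accuracy. The Adjusted Rand Index needs no such mapping because it does not depend on how the clusters are numbered.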
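For reference, the standard definitions of the two metrics named above can be written as follows, using notation introduced here: N data points, clusters C_1, ..., C_k with centroids μ_1, ..., μ_k, and true classes T_1, ..., T_m.

```latex
\mathrm{Purity} = \frac{1}{N} \sum_{i=1}^{k} \max_{j} \lvert C_i \cap T_j \rvert

\mathrm{DB} = \frac{1}{k} \sum_{i=1}^{k} \max_{j \neq i} \frac{s_i + s_j}{\lVert \mu_i - \mu_j \rVert},
\qquad s_i = \frac{1}{\lvert C_i \rvert} \sum_{x \in C_i} \lVert x - \mu_i \rVert
```

Higher purity is better (1 means every cluster contains a single class), and lower Davies-Bouldin values indicate compact, well-separated clusters. Purity reaches 1 trivially when every point is its own cluster, so it is only meaningful when comparing clusterings with the same K.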

Python

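y_pred_train = kmeans.predict(X_train)
y_pred_test = kmeans.predict(X_test)

import numpy as np
from sklearn.metrics import accuracy_score, adjusted_rand_score

# Cluster IDs are arbitrary (cluster 3 need not hold the digit 3), so map each cluster
# to the most frequent true digit among its training points. The mapping is learned
# on the training set only and reused for the test set.
cluster_to_digit = np.zeros(n_clusters, dtype=int)
for c in range(n_clusters):
    members = y_train[y_pred_train == c]
    if members.size > 0:
        cluster_to_digit[c] = np.bincount(members).argmax()

accuracy_train = accuracy_score(y_train, cluster_to_digit[y_pred_train])
accuracy_test = accuracy_score(y_test, cluster_to_digit[y_pred_test])

# The Adjusted Rand Index ignores how clusters are numbered, so it uses the raw IDs.
ari_train = adjusted_rand_score(y_train, y_pred_train)
ari_test = adjusted_rand_score(y_test, y_pred_test)

print("Accuracy (train):", accuracy_train)
print("Accuracy (test):", accuracy_test)
print("Adjusted Rand Index (train):", ari_train)
print("Adjusted Rand Index (test):", ari_test)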
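Purity and the Davies-Bouldin index, mentioned in the evaluation paragraph, can also be computed explicitly. Below is a minimal NumPy sketch; `purity_score` and `davies_bouldin` are hypothetical helper names, not library functions. Purity is built from a cluster-by-class contingency table; the Davies-Bouldin version follows the usual definition with Euclidean distances.

```python
import numpy as np

def purity_score(y_true, y_pred):
    """Purity: share of points that belong to the majority true class of their cluster."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    # Contingency table: one row per cluster, one column per true class.
    _, cluster_idx = np.unique(y_pred, return_inverse=True)
    _, class_idx = np.unique(y_true, return_inverse=True)
    table = np.zeros((cluster_idx.max() + 1, class_idx.max() + 1), dtype=int)
    np.add.at(table, (cluster_idx, class_idx), 1)
    # Each cluster contributes the size of its largest class.
    return table.max(axis=1).sum() / y_true.size

def davies_bouldin(X, labels):
    """Davies-Bouldin index: mean over clusters of the worst (s_i + s_j) / d(mu_i, mu_j)."""
    X, labels = np.asarray(X, dtype=float), np.asarray(labels)
    ks = np.unique(labels)
    centroids = np.array([X[labels == k].mean(axis=0) for k in ks])
    # s_i: mean Euclidean distance from the points of cluster i to its centroid.
    s = np.array([np.linalg.norm(X[labels == k] - mu, axis=1).mean()
                  for k, mu in zip(ks, centroids)])
    # Pairwise Euclidean distances between centroids.
    d = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=2)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratios = (s[:, None] + s[None, :]) / d
    np.fill_diagonal(ratios, -np.inf)  # exclude j == i from the max
    return ratios.max(axis=1).mean()
```

With the fitted model, `purity_score(y_train, y_pred_train)` should equal the training accuracy obtained by mapping each cluster to its most frequent digit, while `davies_bouldin(X_train, y_pred_train)` needs no labels at all. scikit-learn also provides the Davies-Bouldin index as `sklearn.metrics.davies_bouldin_score`.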

Visualization

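We can also visualize the clustering results. Because each image is a 784-dimensional vector, the data points cannot be plotted directly; a common approach is to project them to two dimensions (for example with principal component analysis) and color each point by its cluster. The cluster centroids are themselves 784-pixel vectors, so each one can be reshaped to 28×28 and displayed as an image of its cluster's "average" digit. Both views give a sense of how well separated the clusters are.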
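As a sketch of the centroid view, the helper below (a hypothetical name, `centroid_mosaic`) reshapes each 784-value centroid back to 28×28 and tiles the images into one 2×5 grid array, assuming the 10-cluster setup used here. It relies only on NumPy; displaying the result with matplotlib is shown in the trailing comments.

```python
import numpy as np

def centroid_mosaic(centers, rows=2, cols=5, side=28):
    """Tile flattened centroids (k x side*side) into one (rows*side, cols*side) image."""
    centers = np.asarray(centers)
    if centers.shape[0] > rows * cols:
        raise ValueError("grid has fewer cells than there are centroids")
    grid = np.zeros((rows * side, cols * side), dtype=centers.dtype)
    for i, center in enumerate(centers):
        r, c = divmod(i, cols)  # fill the grid row by row
        # Undo the row-major flattening used during preprocessing.
        grid[r * side:(r + 1) * side, c * side:(c + 1) * side] = center.reshape(side, side)
    return grid

# Usage with the fitted model (requires matplotlib):
# import matplotlib.pyplot as plt
# plt.imshow(centroid_mosaic(kmeans.cluster_centers_), cmap="gray")
# plt.axis("off")
# plt.show()
```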

Limitations and Considerations

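  • Number of clusters: The number of clusters (K) strongly affects the result. In this example we used one cluster per digit, but a different K may describe the data better, for example if some digits are written in several distinct styles.
  • Initialization: K-means only converges to a local optimum, so the initial centroids can change the final clustering. K-means++ seeding, which scikit-learn's KMeans uses by default, spreads the initial centroids out, and running several initializations (n_init) and keeping the best result further reduces the risk of a poor solution.
  • Scalability: Each K-means iteration computes the distance from every point to every centroid, so its cost grows with the number of points, clusters, and features (for the MNIST training set, 60,000 × 10 × 784 per iteration). For larger datasets, consider mini-batch K-means (such as scikit-learn's MiniBatchKMeans) or a distributed implementation.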
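The three points above can be combined in one sketch: the loop below fits scikit-learn's MiniBatchKMeans with k-means++ initialization for several values of K and records the inertia (the sum of squared distances from points to their nearest centroid) for each. Plotting inertia against K and looking for the point where the decrease levels off is the common "elbow" heuristic. It is only a heuristic: the best achievable inertia never increases as K grows, so the smallest value is not the goal. The helper name `inertia_curve` and the parameter values are illustrative assumptions.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

def inertia_curve(X, k_values, batch_size=1024, seed=42):
    """Fit MiniBatchKMeans for each K and return {K: inertia}."""
    inertias = {}
    for k in k_values:
        model = MiniBatchKMeans(
            n_clusters=k,
            init="k-means++",       # spread-out initial centroids
            n_init=3,               # keep the best of 3 initializations
            batch_size=batch_size,  # points used per centroid update
            random_state=seed,
        )
        model.fit(X)
        # inertia_: sum of squared distances from points to their closest centroid
        inertias[k] = model.inertia_
    return inertias
```

On MNIST, a call such as `inertia_curve(X_train, range(5, 21))` would sweep K from 5 to 20; a smaller batch size makes each update cheaper at the cost of noisier centroid estimates.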

By following these steps, you can effectively apply K-means clustering to the MNIST dataset and evaluate the performance of the model.
