How to make a scatter plot for clustering in Python?
A scatter plot for clustering visualizes data points grouped by cluster membership, with cluster centers marked distinctly. This helps analyze clustering algorithms like K-means by showing how data points are distributed across different clusters.
Basic Clustering Scatter Plot
Here's how to create a scatter plot that shows clustered data points together with their cluster centers:

```python
import numpy as np
import matplotlib.pyplot as plt

# Set figure size
plt.rcParams["figure.figsize"] = [8.00, 6.00]
plt.rcParams["figure.autolayout"] = True

# Generate sample data points
x = np.random.randn(15)
y = np.random.randn(15)

# Assign cluster labels (0, 1, 2, 3 for 4 clusters)
cluster_labels = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2, 0, 1, 3, 2, 1])

# Define cluster centers (illustrative values, not computed from the points above)
centers = np.array([[0.5, 0.2], [-0.3, 0.8], [1.2, -0.5], [-1.0, -0.7]])

# Create the plot
fig, ax = plt.subplots()

# Plot data points colored by cluster
scatter = ax.scatter(x, y, c=cluster_labels, s=80, alpha=0.7, cmap='viridis')

# Plot all cluster centers in one call so the legend gets exactly one entry
ax.scatter(centers[:, 0], centers[:, 1], s=200, c='red', marker='X',
           edgecolors='black', linewidths=2, label='Cluster Centers')

# Add labels, title, grid, and legend
ax.set_xlabel('X coordinate')
ax.set_ylabel('Y coordinate')
ax.set_title('Scatter Plot for Clustering')
ax.grid(True, alpha=0.3)
ax.legend()

plt.colorbar(scatter, label='Cluster ID')
plt.show()
```
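In the basic example the center coordinates are typed in by hand, so they do not necessarily sit in the middle of their clusters. A minimal sketch, assuming you have only the points and their label array (no fitted model), computes each center as the mean of the points assigned to that label, which is the same definition K-means uses for its centroids. The helper name `compute_centers` is hypothetical.

```python
import numpy as np


def compute_centers(points, labels):
    """Return one center per cluster label (hypothetical helper).

    points: array of shape (n_samples, 2)
    labels: integer cluster IDs of shape (n_samples,)
    """
    points = np.asarray(points, dtype=float)
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)  # sorted cluster IDs
    # The center of a cluster is the mean of the points assigned to it
    return np.array([points[labels == k].mean(axis=0) for k in unique_labels])


# Same kind of data as the basic example
rng = np.random.default_rng(0)
x = rng.standard_normal(15)
y = rng.standard_normal(15)
cluster_labels = np.array([0, 1, 1, 1, 3, 2, 2, 3, 0, 2, 0, 1, 3, 2, 1])

centers = compute_centers(np.column_stack([x, y]), cluster_labels)
print(centers.shape)  # (4, 2): one row per cluster
```

The resulting `centers` array has the same shape as the hand-written one, so it can be plotted with the same `ax.scatter(centers[:, 0], centers[:, 1], ...)` call.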
K-means Clustering Example
The following example generates data with three natural groups using make_blobs, clusters it with K-means, and plots the true grouping above the K-means result:
```python
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Generate sample data with 3 natural clusters
data, true_labels = make_blobs(n_samples=150, centers=3,
                               cluster_std=1.5, random_state=42)

# Apply K-means clustering (n_init set explicitly; its default differs across scikit-learn versions)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42)
predicted_labels = kmeans.fit_predict(data)
centers = kmeans.cluster_centers_

# Create the figure
plt.figure(figsize=(10, 8))

# Top panel: data points colored by their true cluster
plt.subplot(2, 1, 1)
plt.scatter(data[:, 0], data[:, 1], c=true_labels, cmap='viridis', alpha=0.7)
plt.title('True Clusters')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')

# Bottom panel: data points colored by K-means label, plus the fitted centroids
plt.subplot(2, 1, 2)
plt.scatter(data[:, 0], data[:, 1], c=predicted_labels, cmap='viridis', alpha=0.7)
plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='X',
            s=200, edgecolors='black', linewidths=2, label='Centroids')
plt.title('K-means Clustering Results')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()

plt.tight_layout()
plt.show()
```
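K-means numbers its clusters arbitrarily, so cluster 0 in `predicted_labels` need not correspond to cluster 0 in `true_labels`, and the two panels can show the same group in different colors. A minimal sketch, assuming SciPy is installed (scikit-learn already depends on it), relabels the predictions with `scipy.optimize.linear_sum_assignment` so each predicted cluster takes the true label it overlaps with most. The helper name `align_labels` is hypothetical, and it assumes both label arrays use non-negative integer IDs.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def align_labels(true_labels, predicted_labels):
    """Rename predicted cluster IDs to best match the true IDs (hypothetical helper)."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    n = max(true_labels.max(), predicted_labels.max()) + 1
    # overlap[p, t] counts points predicted as p whose true label is t
    overlap = np.zeros((n, n), dtype=int)
    np.add.at(overlap, (predicted_labels, true_labels), 1)
    # linear_sum_assignment minimizes cost, so negate to maximize total overlap
    pred_ids, true_ids = linear_sum_assignment(-overlap)
    mapping = dict(zip(pred_ids, true_ids))
    return np.array([mapping[p] for p in predicted_labels])


true = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])  # same grouping, different IDs
print(align_labels(true, pred))  # [0 0 1 1 2 2]
```

In the K-means example, passing `c=align_labels(true_labels, predicted_labels)` to the bottom `plt.scatter` call would make matching clusters share a color in both panels. The centroids are drawn in red either way, so they need no relabeling.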
Key Components
Essential elements of a clustering scatter plot include:
- Data Points − Colored by cluster assignment using the `c` parameter
- Cluster Centers − Marked with distinct symbols such as 'X' or '+'
- Color Map − Assigns a different color to each cluster via the `cmap` parameter
- Legend/Colorbar − Shows which color belongs to which cluster
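Because cluster IDs are categories rather than a continuous quantity, a legend with one entry per cluster is often easier to read than a colorbar. A minimal sketch using Matplotlib's `PathCollection.legend_elements()` (available since Matplotlib 3.1), with placeholder points and labels standing in for real clustering output:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
points = rng.standard_normal((60, 2))  # placeholder data
labels = np.repeat([0, 1, 2], 20)      # placeholder cluster IDs

fig, ax = plt.subplots()
# Data points colored by cluster ID
scatter = ax.scatter(points[:, 0], points[:, 1], c=labels, cmap='viridis', s=80, alpha=0.7)

# One legend handle per unique label, in sorted label order
handles, _ = scatter.legend_elements()
ax.legend(handles, [f"Cluster {k}" for k in np.unique(labels)], title="Clusters")

plt.show()
```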
Customization Options
| Parameter | Purpose | Example Values |
|---|---|---|
| `s` | Marker size (in points²) | 50, 100, 200 |
| `alpha` | Transparency (0 = invisible, 1 = opaque) | 0.5, 0.7, 1.0 |
| `cmap` | Colormap used to map cluster IDs to colors | 'viridis', 'plasma', 'tab10' |
| `marker` | Marker shape | 'o', 'X', '+', 's' |
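The `marker` parameter accepts a single shape per `scatter()` call, so giving each cluster its own shape (useful when the plot may be printed in grayscale) means calling `scatter()` once per cluster. A minimal sketch with placeholder data; the list of marker shapes is an arbitrary choice:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
points = rng.standard_normal((45, 2))  # placeholder data
labels = np.repeat([0, 1, 2], 15)      # placeholder cluster IDs
markers = ['o', 's', '^']              # one marker shape per cluster (arbitrary choice)

fig, ax = plt.subplots()
for k, marker in zip(np.unique(labels), markers):
    members = points[labels == k]
    # One call per cluster: its own marker; color comes from the default color cycle
    ax.scatter(members[:, 0], members[:, 1], marker=marker, s=80, alpha=0.7,
               label=f"Cluster {k}")

ax.legend()
plt.show()
```

Each call also advances Matplotlib's default color cycle, so every cluster gets a distinct color as well as a distinct shape.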
Conclusion
Scatter plots effectively visualize clustering results by color-coding data points and highlighting cluster centers. Use different markers and colors to distinguish between data points and centroids for clear cluster visualization.
