How to generate and plot classification dataset using Python Scikit-learn?
Scikit-learn provides the make_classification() function to generate synthetic classification datasets with configurable parameters like informative features, clusters per class, and number of classes. This is useful for testing machine learning algorithms and understanding data patterns.
Understanding make_classification() Parameters
The key parameters for controlling dataset generation are:
- n_features: Total number of features
- n_informative: Number of informative features
- n_redundant: Number of redundant features
- n_clusters_per_class: Number of clusters per class
- n_classes: Number of classes (default is 2)
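These parameters constrain one another. The sketch below illustrates two points: the returned arrays have one row per sample (100 by default), and scikit-learn rejects a configuration where n_classes * n_clusters_per_class exceeds 2**n_informative. The variable name raised is only a helper for this illustration.

```python
from sklearn.datasets import make_classification

# Valid: 2 classes * 1 cluster per class = 2, which is <= 2**1
X, y = make_classification(n_features=2, n_redundant=0, n_informative=1,
                           n_clusters_per_class=1, random_state=42)
print(X.shape, y.shape)  # n_samples defaults to 100

# Invalid: 2 classes * 2 clusters per class = 4, which is > 2**1
try:
    make_classification(n_features=2, n_redundant=0, n_informative=1,
                        n_clusters_per_class=2, random_state=42)
    raised = False
except ValueError as err:
    raised = True
    print("Rejected:", err)
```

Because n_clusters_per_class defaults to 2, a dataset with a single informative feature needs n_clusters_per_class=1 set explicitly.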
Dataset with One Informative Feature
Here's how to create a classification dataset with one informative feature and one cluster per class:
# Importing libraries
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# Creating the classification dataset with one informative feature
X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, random_state=42)
# Plotting the dataset
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k", cmap='viridis')
plt.title("Classification Dataset: One Informative Feature", fontsize=14)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()
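To check what "one informative feature" means numerically, the following sketch compares the class means of each column. It passes shuffle=False, which is not used in the example above and is assumed here only so that the informative column stays first; with the default shuffle=True the column order is randomized.

```python
import numpy as np
from sklearn.datasets import make_classification

# shuffle=False keeps the informative feature in column 0 and the
# noise feature in column 1 (chosen here for illustration only)
X, y = make_classification(n_features=2, n_redundant=0, n_informative=1,
                           n_clusters_per_class=1, shuffle=False,
                           random_state=42)

# Absolute difference between the two class means, computed per feature
gap = np.abs(X[y == 0].mean(axis=0) - X[y == 1].mean(axis=0))
print("Class-mean gap per feature:", gap)
```

The informative column should show a clearly larger gap than the noise column, which is why the classes separate along only one axis in the plot.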
Dataset with Two Informative Features
With two informative features, both axes carry class information, so the classes usually separate more clearly in the scatter plot:
# Importing libraries
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# Creating the classification dataset with two informative features
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=1, random_state=42)
# Plotting the dataset
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k", cmap='viridis')
plt.title("Classification Dataset: Two Informative Features", fontsize=14)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()
Dataset with Multiple Clusters per Class
Increase complexity by adding multiple clusters per class:
# Importing libraries
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# Creating dataset with two clusters per class
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=2, random_state=42)
# Plotting the dataset
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k", cmap='viridis')
plt.title("Classification Dataset: Two Clusters per Class", fontsize=14)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()
Multi-class Classification Dataset
Generate datasets with more than two classes for multi-class classification:
# Importing libraries
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# Creating multi-class classification dataset
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           n_clusters_per_class=1, n_classes=3, random_state=42)
# Plotting the dataset
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k", cmap='viridis')
plt.title("Multi-class Classification Dataset", fontsize=14)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label='Class')
plt.show()
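Before plotting a multi-class dataset, it can help to confirm which labels were generated and how many samples each class received. This is a minimal sketch using NumPy; the names labels and counts are introduced here for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           n_clusters_per_class=1, n_classes=3,
                           random_state=42)

labels = np.unique(y)                  # distinct class labels
counts = np.bincount(y, minlength=3)   # number of samples per class
print("Labels:", labels)
print("Samples per class:", counts)
```

The classes come out roughly balanced by default. The split may not be perfectly even, since 100 samples do not divide evenly by 3 and the default flip_y=0.01 reassigns a small fraction of labels at random.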
Comparison of Dataset Types
| Dataset Type | Informative Features | Clusters per Class | Classes | Best For |
|---|---|---|---|---|
| Simple Binary | 1 | 1 | 2 | Basic classification |
| Complex Binary | 2 | 1 | 2 | Roughly linearly separable data |
| Clustered Binary | 2 | 2 | 2 | Non-linear problems |
| Multi-class | 2 | 1 | 3+ | Multi-class algorithms |
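The four rows of the table can be generated in one loop, which makes the configurations easy to compare side by side. In this sketch the dictionary keys simply reuse the table's row names, 3 classes are assumed for the "3+" row, and configs and datasets are illustrative names.

```python
from sklearn.datasets import make_classification

# One entry per table row (parameter values taken from the table)
configs = {
    "Simple Binary":    dict(n_informative=1, n_clusters_per_class=1, n_classes=2),
    "Complex Binary":   dict(n_informative=2, n_clusters_per_class=1, n_classes=2),
    "Clustered Binary": dict(n_informative=2, n_clusters_per_class=2, n_classes=2),
    "Multi-class":      dict(n_informative=2, n_clusters_per_class=1, n_classes=3),
}

datasets = {}
for name, params in configs.items():
    X, y = make_classification(n_features=2, n_redundant=0,
                               random_state=42, **params)
    datasets[name] = (X, y)
    print(f"{name}: X shape {X.shape}, classes {sorted(set(y.tolist()))}")
```

Each (X, y) pair in datasets can then be passed to the same scatter-plot code used in the earlier examples.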
Conclusion
The make_classification() function is a convenient way to create synthetic datasets for testing classification algorithms. Vary n_informative, n_clusters_per_class, and n_classes to generate data that matches the difficulty and structure you want to test against.
