How to Split a Dataset With scikit-learn's train_test_split() Function
Machine learning models require proper data splitting to evaluate performance accurately. Scikit-learn's train_test_split() function provides a simple way to divide your dataset into training and testing portions, ensuring your model can be validated on unseen data.
Syntax
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
Parameters
X, y: Feature matrix and target vector, respectively
test_size: Proportion of data reserved for testing (e.g. 0.2 for 20%)
random_state: Seed for reproducible random splitting
stratify: Maintains class proportions in splits
shuffle: Whether to shuffle data before splitting (default True)
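The shuffle parameter matters for ordered data such as time series, where shuffling can leak future information into the training set. A minimal sketch with illustrative data, showing that shuffle=False splits the rows sequentially:

```python
from sklearn.model_selection import train_test_split
import numpy as np

# Ordered data, e.g. daily measurements indexed by time
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

# shuffle=False keeps the last 20% (the most recent rows) as the test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

print(X_train.ravel())  # [0 1 2 3 4 5 6 7]
print(X_test.ravel())   # [8 9]
```

Note that shuffle=False cannot be combined with stratify, since stratification requires shuffling.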
Basic Train-Test Split
Here's a simple example using sample data:
from sklearn.model_selection import train_test_split
import numpy as np
# Sample data
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
# Split data (80% training, 20% testing)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("Training features shape:", X_train.shape)
print("Testing features shape:", X_test.shape)
print("Training targets shape:", y_train.shape)
print("Testing targets shape:", y_test.shape)
Training features shape: (6, 2)
Testing features shape: (2, 2)
Training targets shape: (6,)
Testing targets shape: (2,)
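Once split, the two sets plug directly into an estimator's fit/score workflow: the model learns only from the training rows and is evaluated on the held-out test rows. A hedged sketch using LogisticRegression on synthetic data (any scikit-learn estimator would work the same way):

```python
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification

# Synthetic binary classification data
X, y = make_classification(n_samples=200, n_features=4, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = LogisticRegression()
model.fit(X_train, y_train)             # learn only from the training rows
accuracy = model.score(X_test, y_test)  # evaluate on unseen rows
print(f"Test accuracy: {accuracy:.2f}")
```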
Stratified Split
For classification tasks, use stratified splitting to maintain class proportions:
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import numpy as np
# Create imbalanced dataset
X, y = make_classification(n_samples=100, n_classes=2, weights=[0.8, 0.2], random_state=42)
# Stratified split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
print("Original class distribution:", np.bincount(y) / len(y))
print("Training class distribution:", np.bincount(y_train) / len(y_train))
print("Testing class distribution:", np.bincount(y_test) / len(y_test))
Original class distribution: [0.8 0.2]
Training class distribution: [0.8 0.2]
Testing class distribution: [0.8 0.2]
Three-Way Split (Train-Validation-Test)
For complex models requiring validation sets, split the data three ways:
from sklearn.model_selection import train_test_split
import numpy as np
# Sample data
X = np.random.rand(100, 4)
y = np.random.randint(0, 2, 100)
# First split: 60% train, 40% temp
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
# Second split: 20% validation, 20% test (from the 40% temp)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
print(f"Training set: {X_train.shape[0]} samples ({X_train.shape[0]/len(X)*100:.0f}%)")
print(f"Validation set: {X_val.shape[0]} samples ({X_val.shape[0]/len(X)*100:.0f}%)")
print(f"Test set: {X_test.shape[0]} samples ({X_test.shape[0]/len(X)*100:.0f}%)")
Training set: 60 samples (60%)
Validation set: 20 samples (20%)
Test set: 20 samples (20%)
Comparison of Split Methods
| Method | Use Case | Key Parameter |
|---|---|---|
| Basic Split | Simple train-test evaluation | test_size |
| Stratified Split | Imbalanced classification | stratify=y |
| Three-way Split | Model tuning with validation | Two-step splitting |
| Fixed Random State | Reproducible results | random_state |
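The table's last row can be demonstrated directly: two calls with the same random_state produce identical splits, which is what makes experiments reproducible. A small sketch with illustrative data:

```python
from sklearn.model_selection import train_test_split
import numpy as np

X = np.arange(20).reshape(-1, 1)
y = np.arange(20)

# Same seed on both calls -> identical splits
X_tr1, X_te1, _, _ = train_test_split(X, y, test_size=0.25, random_state=7)
X_tr2, X_te2, _, _ = train_test_split(X, y, test_size=0.25, random_state=7)

print(np.array_equal(X_te1, X_te2))  # True
```

Omitting random_state draws a fresh seed on every call, so repeated splits will generally differ.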
Conclusion
The train_test_split() function is essential for proper model evaluation in machine learning. Use stratified splits for classification tasks and set random_state for reproducible results. For complex workflows, consider three-way splits with dedicated validation sets.
