How to Split a Dataset With scikit-learn's train_test_split() Function

Machine learning models require proper data splitting to evaluate performance accurately. Scikit-learn's train_test_split() function provides a simple way to divide your dataset into training and testing portions, ensuring your model can be validated on unseen data.

Syntax

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Parameters

  • X, y: Feature matrix and target vector, respectively

  • test_size: Proportion of the dataset to reserve for testing (e.g. 0.2 for 20%); an integer is treated as an absolute number of test samples

  • random_state: Seed that makes the random split reproducible

  • stratify: Array (typically y) whose class proportions are preserved in both splits

  • shuffle: Whether to shuffle the data before splitting (default True)
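
Two of the less common options are worth a quick illustration: test_size also accepts an absolute sample count, and shuffle=False disables shuffling, so the split becomes a simple cut that preserves the original order (useful for time-ordered data). A minimal sketch:

```python
from sklearn.model_selection import train_test_split
import numpy as np

# Ten ordered samples
X = np.arange(10).reshape(10, 1)
y = np.arange(10)

# test_size as an integer: exactly 3 samples go to the test set;
# shuffle=False keeps the original order, so the test set is the tail
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3, shuffle=False)

print(X_test.ravel())  # [7 8 9]
```

With shuffle=False the first seven samples become the training set and the last three the test set, which is exactly what you want when rows are ordered in time.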

Basic Train-Test Split

Here's a simple example using sample data:

from sklearn.model_selection import train_test_split
import numpy as np

# Sample data
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])

# Split data (80% training, 20% testing)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("Training features shape:", X_train.shape)
print("Testing features shape:", X_test.shape)
print("Training targets shape:", y_train.shape)
print("Testing targets shape:", y_test.shape)

Output:

Training features shape: (6, 2)
Testing features shape: (2, 2)
Training targets shape: (6,)
Testing targets shape: (2,)

Stratified Split

For classification tasks, use stratified splitting to maintain class proportions:

from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import numpy as np

# Create imbalanced dataset
X, y = make_classification(n_samples=100, n_classes=2, weights=[0.8, 0.2], random_state=42)

# Stratified split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

print("Original class distribution:", np.bincount(y) / len(y))
print("Training class distribution:", np.bincount(y_train) / len(y_train))
print("Testing class distribution:", np.bincount(y_test) / len(y_test))

Output:

Original class distribution: [0.8 0.2]
Training class distribution: [0.8 0.2]
Testing class distribution: [0.8 0.2]
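
For contrast, omitting stratify=y gives a plain shuffled split, which makes no guarantee about the class mix in either portion. A minimal side-by-side check on the same kind of data (the exact printed proportions depend on the random seed):

```python
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import numpy as np

# Same imbalanced dataset as above: roughly 80% class 0, 20% class 1
X, y = make_classification(n_samples=100, n_classes=2, weights=[0.8, 0.2], random_state=42)

# Plain shuffled split: class proportions are NOT guaranteed to match
_, _, _, y_test_plain = train_test_split(X, y, test_size=0.3, random_state=0)

# Stratified split: test-set proportions track the full dataset
_, _, _, y_test_strat = train_test_split(X, y, test_size=0.3, stratify=y, random_state=0)

print("Plain:     ", np.bincount(y_test_plain) / len(y_test_plain))
print("Stratified:", np.bincount(y_test_strat) / len(y_test_strat))
```

On small or heavily imbalanced datasets the unstratified test set can easily end up with too few minority-class samples to evaluate on, which is why stratify=y is the safer default for classification.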

Three-Way Split (Train-Validation-Test)

For complex models requiring validation sets, split the data three ways:

from sklearn.model_selection import train_test_split
import numpy as np

# Sample data
X = np.random.rand(100, 4)
y = np.random.randint(0, 2, 100)

# First split: 60% train, 40% temp
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)

# Second split: 20% validation, 20% test (from the 40% temp)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f"Training set: {X_train.shape[0]} samples ({X_train.shape[0]/len(X)*100:.0f}%)")
print(f"Validation set: {X_val.shape[0]} samples ({X_val.shape[0]/len(X)*100:.0f}%)")
print(f"Test set: {X_test.shape[0]} samples ({X_test.shape[0]/len(X)*100:.0f}%)")

Output:

Training set: 60 samples (60%)
Validation set: 20 samples (20%)
Test set: 20 samples (20%)
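
When the labels are imbalanced, the same two-step pattern can be stratified at both stages: pass stratify=y on the first split and stratify=y_temp on the second, so all three sets keep representative class proportions. A sketch combining this with the make_classification data used earlier:

```python
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import numpy as np

# Imbalanced dataset: roughly 80% class 0, 20% class 1
X, y = make_classification(n_samples=100, n_classes=2, weights=[0.8, 0.2], random_state=42)

# First split: 60% train, 40% held out, stratified on y
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=42)

# Second split: halve the held-out 40%, stratified on its own labels
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

for name, labels in [("train", y_train), ("val", y_val), ("test", y_test)]:
    print(name, np.bincount(labels) / len(labels))
```

Each of the three printed distributions stays close to the original 80/20 split, which an unstratified two-step split would not guarantee.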

Comparison of Split Methods

Method              Use Case                      Key Parameter
Basic split         Simple train-test evaluation  test_size
Stratified split    Imbalanced classification     stratify=y
Three-way split     Model tuning with validation  Two-step splitting
Fixed random state  Reproducible results          random_state

Conclusion

The train_test_split() function is essential for proper model evaluation in machine learning. Use stratified splits for classification tasks and set random_state for reproducible results. For complex workflows, consider three-way splits with dedicated validation sets.

Updated on: 2026-03-27T13:56:49+05:30
