PySpark randomSplit() and sample() Methods

PySpark, the Python API for Apache Spark, an open-source framework for big data processing and analytics, offers powerful methods for working with large datasets. When dealing with massive amounts of data, it is often impractical to process everything at once. Data sampling, which involves selecting a representative subset of data, becomes crucial for efficient analysis. In PySpark, two commonly used methods for data sampling are randomSplit() and sample().

These methods allow us to extract subsets of data for different purposes like testing models or exploring data patterns. Let's explore how to use them effectively for data sampling in big data analytics.

Understanding Data Sampling

Data sampling is essential in many data analysis tasks. We can work with a manageable subset of the data while still capturing the essential characteristics of the entire dataset. Sampling significantly reduces computational overhead, accelerates analysis, and helps gain insights into the underlying data distribution.

PySpark randomSplit() Method

The randomSplit() method splits a DataFrame or RDD into multiple parts based on the provided weights. Each weight represents the proportion of data allocated to the corresponding split. Because each row is assigned to a split independently at random, the resulting split sizes are approximate rather than exact.

Syntax

randomSplit(weights, seed=None)

Parameters

  • weights: A list of doubles indicating the relative size of each split. If the weights do not sum to 1.0, they are normalized.

  • seed (optional): A random seed for reproducibility.

Example

Here's how to split data into training and testing sets:

from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName("DataSampling").getOrCreate()

# Create sample data
data = [(1, "Alice", 25), (2, "Bob", 30), (3, "Charlie", 35), 
        (4, "Diana", 28), (5, "Eve", 32), (6, "Frank", 27)]
columns = ["id", "name", "age"]
df = spark.createDataFrame(data, columns)

# Split data into 70% training and 30% testing
train_data, test_data = df.randomSplit([0.7, 0.3], seed=42)

print("Training data count:", train_data.count())
print("Testing data count:", test_data.count())
train_data.show()
Output

Training data count: 4
Testing data count: 2
+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alice| 25|
|  2|    Bob| 30|
|  4|  Diana| 28|
|  6|  Frank| 27|
+---+-------+---+

PySpark sample() Method

The sample() method extracts a random sample from a DataFrame or RDD based on a specified fraction. Unlike randomSplit(), it returns a single subset of the original data.

Syntax

sample(withReplacement, fraction, seed=None)

Parameters

  • withReplacement: Boolean indicating whether sampling allows duplicates (True) or not (False).

  • fraction: The expected fraction of rows to include (0.0 to 1.0 without replacement; with replacement it is the expected number of times each row is chosen and may exceed 1.0). The actual sample size is approximate, not guaranteed.

  • seed (optional): Random seed for reproducibility.

Example

Here's how to extract a 50% sample without replacement:

from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName("DataSampling").getOrCreate()

# Create sample data
data = [(1, "Alice", 25), (2, "Bob", 30), (3, "Charlie", 35), 
        (4, "Diana", 28), (5, "Eve", 32), (6, "Frank", 27)]
columns = ["id", "name", "age"]
df = spark.createDataFrame(data, columns)

# Extract 50% sample without replacement
sample_data = df.sample(withReplacement=False, fraction=0.5, seed=42)

print("Original data count:", df.count())
print("Sample data count:", sample_data.count())
sample_data.show()
Output

Original data count: 6
Sample data count: 3
+---+-------+---+
| id|   name|age|
+---+-------+---+
|  2|    Bob| 30|
|  4|  Diana| 28|
|  6|  Frank| 27|
+---+-------+---+

Key Differences

Aspect    | randomSplit()                  | sample()
Purpose   | Split into multiple datasets   | Create a single subset
Output    | List of DataFrames             | Single DataFrame
Use Case  | Train-test splits              | Data exploration, prototyping
Control   | Multiple proportions           | Single fraction

Common Use Cases

randomSplit() is ideal for:

  • Creating train-validation-test splits for machine learning

  • Partitioning data for parallel processing

  • A/B testing scenarios

sample() is perfect for:

  • Exploratory data analysis with large datasets

  • Creating smaller datasets for algorithm prototyping

  • Performance testing with representative data subsets

Conclusion

PySpark's randomSplit() and sample() methods provide essential functionality for data sampling. Use randomSplit() when you need multiple data partitions and sample() when you need a single representative subset. Both methods help reduce computational overhead while maintaining data characteristics for efficient big data analysis.

Updated on: 2026-03-27T09:59:46+05:30
