Get a Specific Row from a PySpark DataFrame

PySpark is a powerful tool for big data processing and analysis. When working with PySpark DataFrames, you often need to retrieve specific rows for analysis or debugging. This article covers the DataFrame methods for doing that: collect(), first(), head(), take(), and tail(), plus filter(), where(), and select() combined with collect().

Creating Sample DataFrame

Let's create a sample DataFrame to demonstrate all the methods:

from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName("get_specific_rows").getOrCreate()

# Create sample DataFrame
df = spark.createDataFrame([
    ('Row1', 1, 2, 3),
    ('Row2', 4, 5, 6),
    ('Row3', 7, 8, 9)
], ['Name', 'Col1', 'Col2', 'Col3'])

# Show DataFrame structure and data
df.printSchema()
df.show()

Output:

root
 |-- Name: string (nullable = true)
 |-- Col1: long (nullable = true)
 |-- Col2: long (nullable = true)
 |-- Col3: long (nullable = true)

+----+----+----+----+
|Name|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
|Row3|   7|   8|   9|
+----+----+----+----+

Using collect() Method

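The collect() method returns every row to the driver as a Python list of Row objects, which you can then index like any other list. A DataFrame has no built-in row index, so the positions reflect whatever order Spark returns the rows in; sort with orderBy() first if a stable order matters: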

# Collect once, then index into the resulting list
# (calling collect() for every lookup would re-run the job each time)
rows = df.collect()
first_row = rows[0]
second_row = rows[1]
last_row = rows[-1]

print("First row:", first_row)
print("Second row:", second_row)
print("Last row:", last_row)

Output:

First row: Row(Name='Row1', Col1=1, Col2=2, Col3=3)
Second row: Row(Name='Row2', Col1=4, Col2=5, Col3=6)
Last row: Row(Name='Row3', Col1=7, Col2=8, Col3=9)

Using first() Method

The first() method returns only the first row of the DataFrame:

# Get the first row
first_row = df.first()
print("First row:", first_row)

Output:

First row: Row(Name='Row1', Col1=1, Col2=2, Col3=3)

Using head() and take() Methods

Both head(n) and take(n) return the first n rows as a list:

# Using head() method
head_rows = df.head(2)
print("First 2 rows (head):", head_rows)

# Using take() method
take_rows = df.take(2)
print("First 2 rows (take):", take_rows)

Output:

First 2 rows (head): [Row(Name='Row1', Col1=1, Col2=2, Col3=3), Row(Name='Row2', Col1=4, Col2=5, Col3=6)]
First 2 rows (take): [Row(Name='Row1', Col1=1, Col2=2, Col3=3), Row(Name='Row2', Col1=4, Col2=5, Col3=6)]

Using tail() Method

The tail() method (available since Spark 3.0) returns the last n rows as a list:

# Get last 2 rows
last_rows = df.tail(2)
print("Last 2 rows:", last_rows)

Output:

Last 2 rows: [Row(Name='Row2', Col1=4, Col2=5, Col3=6), Row(Name='Row3', Col1=7, Col2=8, Col3=9)]

Using filter() with collect()

Filter rows based on conditions and collect specific results:

# Filter rows where Col1 > 1 and get specific rows
filtered_rows = df.filter(df.Col1 > 1).collect()

print("First filtered row:", filtered_rows[0])
print("Last filtered row:", filtered_rows[-1])

Output:

First filtered row: Row(Name='Row2', Col1=4, Col2=5, Col3=6)
Last filtered row: Row(Name='Row3', Col1=7, Col2=8, Col3=9)

Using where() with collect()

The where() method is an alias for filter():

# Filter using where() method
where_rows = df.where(df.Col2 == 5).collect()
print("Row where Col2=5:", where_rows[0])

Output:

Row where Col2=5: Row(Name='Row2', Col1=4, Col2=5, Col3=6)

Using select() with collect()

Select specific columns and then collect rows:

# Select specific columns and get rows
selected_rows = df.select('Name', 'Col1').collect()

print("Selected columns (first row):", selected_rows[0])
print("Selected columns (second row):", selected_rows[1])

Output:

Selected columns (first row): Row(Name='Row1', Col1=1)
Selected columns (second row): Row(Name='Row2', Col1=4)

Comparison of Methods

| Method | Returns | Use Case | Performance |
|---|---|---|---|
| collect() | List of all rows | Access any row by index | Expensive for large data |
| first() | Single Row object | Get only first row | Fast |
| head(n) | List of first n rows | Get top n rows | Good for small n |
| take(n) | List of first n rows | Same as head() | Good for small n |
| tail(n) | List of last n rows | Get bottom n rows | Slower than head() |
| filter().collect() | List of filtered rows | Conditional row selection | Depends on filter selectivity |

Conclusion

Use first() to get the first row efficiently, and head(n) or take(n) for the first few rows. Use collect() sparingly on large datasets, as it brings all the data to the driver. For conditional row selection, combine filter() or where() with collect() so that only the matching rows are returned.

Updated on: 2026-03-27T06:44:50+05:30
