Get specific row from PySpark dataframe


PySpark is a powerful tool for data processing and analysis. It lets users manipulate and access data in a distributed, parallel manner, which makes it well suited to big data applications. When working with a PySpark DataFrame, you may sometimes need to get a specific row from it. In this article, we will explore how to get specific rows from a PySpark DataFrame using various methods, following a functional programming style with PySpark's DataFrame API.

Before moving forward, let's create a sample DataFrame from which we will retrieve the rows.

from pyspark.sql import SparkSession

# Building a SparkSession named "column_sum"
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Creating the Spark DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3),
   ('Row2', 4, 5, 6),
   ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])
   
# Printing the schema of the DataFrame
df.printSchema()

# Showing the DataFrame
df.show()

Output

This Python script first prints the schema of the DataFrame we created and then the DataFrame itself.

root
|-- __: string (nullable = true)
|-- Col1: long (nullable = true)
|-- Col2: long (nullable = true)
|-- Col3: long (nullable = true)

+----+----+----+----+
|  __|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
|Row3|   7|   8|   9|
+----+----+----+----+

The approaches that can be used to complete this task are listed below:

Approaches

  • Using collect()

  • Using first()

  • Using show()

  • Using head()

  • Using tail()

  • Using select() and collect()

  • Using filter() and collect()

  • Using where() and collect()

  • Using take()

Now let's discuss each approach and how it can be used to get specific rows.

Method 1: Using collect()

In PySpark, the collect() method retrieves all the data from a DataFrame and returns it to the driver as a list of Row objects. It is typically used when you want to view or manipulate the data locally. Below is the syntax used:

dataframe.collect()[index]

Here

  • dataframe is the one on which we apply the method

  • index is the position of the row we want to get

Once the DataFrame has been collected into a list, we can index into that list to pick out the row we want.

Algorithm

  • First, create a dataframe using the above code.

  • Use the collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.

  • Print the values of the variables containing the desired rows to the console.

Example

# Retrieving the first row of the DataFrame using collect() function
Row1 = df.collect()[0]
print(Row1)

# Retrieving the last row of the DataFrame using collect() function
Row2 = df.collect()[-1]
print(Row2)

# Retrieving the second row of the DataFrame using collect() function
Row3 = df.collect()[1]
print(Row3)

Output

Row(__='Row1', Col1=1, Col2=2, Col3=3)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row2', Col1=4, Col2=5, Col3=6)
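
Note that collect() brings the entire DataFrame back to the driver, so indexing into the collected list can be expensive for large datasets. As a minimal sketch (assuming the row you need is among the first few rows), you can cap the amount of data transferred with limit() before collecting:

# A hedged sketch: fetch only the first three rows to the driver, then index locally.
# Assumes the row of interest is within the first 3 rows of the DataFrame.
top_rows = df.limit(3).collect()
third_row = top_rows[2]
print(third_row)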

Method 2: Using first()

The first() function in PySpark returns the first row of a DataFrame (or RDD) as a Row object. It can only fetch the very first row, so it is useful when the row you need happens to be at the top of the DataFrame. Below is the syntax used:

dataframe.first()

Here

  • dataframe is the one on which we apply the method

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Retrieve the first row of the DataFrame using the first() function

  • Print the first row to the console

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
Row1 = df.first()
print(Row1)

Output

Row(__='Row1', Col1=1, Col2=2, Col3=3)

Method 3: Using show()

In PySpark, the show() function prints the top n rows of the DataFrame to the console. It does not return the rows; it only displays them and returns None. Below is the syntax used:

dataframe.show(n)

Here

  • dataframe is the one on which we apply the method

  • n is the number of rows

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Display the first row of the DataFrame using the show() function by passing the row parameter as 1

  • Display the first two rows of the DataFrame using the show() function by passing the row parameter as 2

  • Display the first three rows of the DataFrame using the show() function by passing the row parameter as 3

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Display the first row
df.show(1)

# Display the first two rows
df.show(2)

# Display the first three rows
df.show(3)

Output

+----+----+----+----+
|__  |Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
+----+----+----+----+

+----+----+----+----+
|__  |Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
+----+----+----+----+

+----+----+----+----+
|__  |Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
|Row3|   7|   8|   9|
+----+----+----+----+

Method 4: Using head()

In PySpark, the head() function returns the first n rows of the DataFrame as a list of Row objects (calling head() with no argument returns just the first Row). Below is the syntax used:

dataframe.head(n)

Here

  • dataframe is the one on which we apply the method

  • n is the number of rows

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Retrieve the first row of the DataFrame using the head() function by passing the row parameter as 1

  • Print the first row to the console

  • Retrieve the first two rows of the DataFrame using the head() function by passing the row parameter as 2

  • Print the first two rows to the console

  • Retrieve the first three rows of the DataFrame using the head() function by passing the row parameter as 3

  • Print the first three rows to the console

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
df1= df.head(1)
print(df1)

# Retrieve the first two rows
df2= df.head(2)
print(df2)

# Retrieve the first three rows
df3= df.head(3)
print(df3)

Output

[Row(__='Row1', Col1=1, Col2=2, Col3=3)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
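
One detail worth keeping in mind: calling head() with no argument returns a single Row object rather than a list, while head(n) always returns a list. A small sketch of the difference:

# A hedged sketch: head() without an argument returns a single Row (or None for an
# empty DataFrame), whereas head(n) returns a list of Row objects.
single_row = df.head()
print(single_row)   # Row(__='Row1', Col1=1, Col2=2, Col3=3)

row_list = df.head(1)
print(row_list)     # [Row(__='Row1', Col1=1, Col2=2, Col3=3)]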

Method 5: Using tail()

In PySpark, the tail() function returns the last n rows of the DataFrame as a list of Row objects (it is available in Spark 3.0 and later). Below is the syntax used:

dataframe.tail(n)

Here

  • dataframe is the one on which we apply the method

  • n is the number of rows

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Retrieve the last row of the DataFrame using the tail() function by passing the row parameter as 1

  • Print the last row to the console

  • Retrieve the last two rows of the DataFrame using the tail() function by passing the row parameter as 2

  • Print the last two rows to the console

  • Retrieve the last three rows of the DataFrame using the tail() function by passing the row parameter as 3

  • Print the last three rows to the console

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the last row
df1= df.tail(1)
print(df1)

# Retrieve the last two rows
df2= df.tail(2)
print(df2)

# Retrieve the last three rows
df3= df.tail(3)
print(df3)

Output

[Row(__='Row3', Col1=7, Col2=8, Col3=9)]
[Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
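
Since tail() returns a list, a common pattern (shown here as a small sketch) is to index into the result to get the last row as a single Row object:

# A hedged sketch: tail(1) returns a one-element list; index it to get the Row itself.
last_row = df.tail(1)[0]
print(last_row)   # Row(__='Row3', Col1=7, Col2=8, Col3=9)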

Method 6: Using select() and collect()

We can use the select() function along with the collect() method to retrieve specific rows, restricted to specific columns, from a PySpark DataFrame. Below is the syntax used:

dataframe.select([columns]).collect()[index]

Here

  • dataframe is the one on which we apply the method

  • columns is the list of columns we want to have in output.

  • Index is the row number we want to have in output.

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Use the combination of select() function and collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.

  • Print the values of the variables containing the desired rows to the console.

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row (index 0)
df1 = df.select(['Col1', 'Col2', 'Col3']).collect()[0]
print(df1)

# Retrieve the last row (index -1)
df2 = df.select(['Col1', 'Col2', 'Col3']).collect()[-1]
print(df2)

# Retrieve the second row (index 1)
df3 = df.select(['Col1', 'Col2', 'Col3']).collect()[1]
print(df3)

Output

Row(Col1=1, Col2=2, Col3=3)
Row(Col1=7, Col2=8, Col3=9)
Row(Col1=4, Col2=5, Col3=6)
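
Once a Row has been retrieved, an individual value can be pulled out by column name or by position. A minimal sketch, assuming we want the 'Col2' value of the second row:

# A hedged sketch: Row objects support access by column name, attribute, or position.
row = df.select(['Col1', 'Col2', 'Col3']).collect()[1]
print(row['Col2'])   # 5
print(row.Col2)      # 5
print(row[1])        # 5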

Method 7: Using filter() and collect()

We can use the filter() function along with the collect() method to retrieve specific rows from a PySpark DataFrame. filter() keeps only the rows that satisfy a condition, and collect() brings the filtered rows back as a list that we can index. Below is the syntax used:

dataframe.filter(condition).collect()[index]

Here

  • dataframe is the one on which we apply the method

  • Condition is the condition based on which the Dataframe rows are being filtered.

  • Index is the row number we want to have in output.

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Use the combination of filter() function and collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.

  • Print the values of the variables containing the desired rows to the console.

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("filter_collect_example").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Filter the DataFrame
df1 = df.filter(df.Col1 > 1).collect()[0]

# Print the collected data
print(df1)

# Filter the DataFrame
df2 = df.filter(df.Col1 > 1).collect()[1]

# Print the collected data
print(df2)

# Filter the DataFrame
df3 = df.filter(df.Col1 > 1).collect()[-1]

# Print the collected data
print(df3)

Output

Row(__='Row2', Col1=4, Col2=5, Col3=6)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row3', Col1=7, Col2=8, Col3=9)

Method 8: Using where() and collect()

We can use the where() function along with the collect() method to retrieve specific rows from a PySpark DataFrame. where() is an alias for filter(): it keeps the rows that satisfy the condition passed to it, and collect() then returns those rows as a list that we can index and store in a variable. Below is the syntax used:

dataframe.where(condition).collect()[index]

Here:

  • dataframe is the one on which we apply the method

  • Condition is the condition based on which the Dataframe rows are being filtered.

  • Index is the row number we want to have in output.

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Use the combination of where() function and collect() function to retrieve the desired row from the DataFrame, storing each row in a separate variable.

  • Print the values of the variables containing the desired rows to the console.

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("filter_collect_example").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Filter the DataFrame
df1 = df.where(df.Col1 > 1).collect()[0]

# Print the collected data
print(df1)

# Filter the DataFrame
df2 = df.where(df.Col1 > 1).collect()[1]

# Print the collected data
print(df2)

# Filter the DataFrame
df3 = df.where(df.Col1 > 1).collect()[-1]

# Print the collected data
print(df3)

Output

Row(__='Row2', Col1=4, Col2=5, Col3=6)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
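
If you know the value that identifies the row you want, filtering on that key is often the most direct way to pull out a single row. A minimal sketch, assuming the identifier lives in the '__' column:

# A hedged sketch: fetch the row whose identifier column equals 'Row2'.
from pyspark.sql.functions import col

row2 = df.where(col('__') == 'Row2').collect()[0]
print(row2)   # Row(__='Row2', Col1=4, Col2=5, Col3=6)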

Method 9: Using take()

In PySpark, the take() function also returns the first n rows of the DataFrame, as a list of Row objects (it behaves like head(n)). Below is the syntax used:

dataframe.take(n)

Here

  • dataframe is the one on which we apply the method

  • n is the number of rows

Algorithm

  • Import the necessary libraries

  • Create a SparkSession

  • Create a DataFrame

  • Retrieve the first row of the DataFrame using the take() function by passing the row parameter as 1

  • Print the first row to the console

  • Retrieve the first two rows of the DataFrame using the take() function by passing the row parameter as 2

  • Print the first two rows to the console

  • Retrieve the first three rows of the DataFrame using the take() function by passing the row parameter as 3

  • Print the first three rows to the console

Example

# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame([('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)], ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
df1= df.take(1)
print(df1)

# Retrieve the first two rows
df2= df.take(2)
print(df2)

# Retrieve the first three rows
df3= df.take(3)
print(df3)

Output

[Row(__='Row1', Col1=1, Col2=2, Col3=3)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]

Conclusion

Depending on the use case, each approach can be more or less efficient than the others, and each has its own advantages and disadvantages, so it is important to choose the one that best fits the task. Keep in mind that collect() brings data to the driver, so for very large datasets it is usually safer to narrow the data first with filter(), select(), or limit(), or to rely on take(), head(), or tail(), which only transfer the requested rows.
