Get specific row from PySpark dataframe
PySpark is a powerful tool for data processing and analysis. It lets users manipulate and access data in a distributed and parallel manner, making it ideal for big data applications. When working with a PySpark DataFrame, you may sometimes need to get a specific row from it. In this article, we will explore how to get specific rows from a PySpark DataFrame using various methods, covering each approach in a functional programming style with PySpark's DataFrame APIs.
Before moving forward, let's create a sample DataFrame from which we will get the rows.
from pyspark.sql import SparkSession

# Building a SparkSession named "column_sum"
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Creating the Spark DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Printing the schema of the DataFrame
df.printSchema()

# Showing the DataFrame
df.show()
Output
This Python script first prints the schema of the DataFrame we created and then the DataFrame itself.
root
 |-- __: string (nullable = true)
 |-- Col1: long (nullable = true)
 |-- Col2: long (nullable = true)
 |-- Col3: long (nullable = true)

+----+----+----+----+
|  __|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
|Row3|   7|   8|   9|
+----+----+----+----+
The approaches that can be used for completing the task are mentioned below:
Approaches
Using collect()
Using first()
Using show()
Using head()
Using tail()
Using select() and collect()
Using filter() and collect()
Using where() and collect()
Using take()
Now let's discuss each approach and how it can be used to get specific rows.
Method 1: Using collect()
In PySpark, the collect() method retrieves all the data from a DataFrame and returns it as a list of Row objects. It is typically used when you want to view or manipulate the data locally, so it should only be applied to data small enough to fit in the driver's memory. Below is the syntax used:
dataframe.collect()[index]
Here
dataframe is the one on which we apply the method
index is the position of the row we want to get.
After getting the dataframe in the form of a list, we can pass the index to the list representing the row we want.
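Once a Row has been retrieved this way, its individual fields can be read by column name or by attribute. A minimal sketch, assuming the df created above:

# The Row returned by collect()[index] behaves like a named tuple
second_row = df.collect()[1]      # Row(__='Row2', Col1=4, Col2=5, Col3=6)
print(second_row['Col1'])         # 4, access by column name
print(second_row.Col3)            # 6, access as an attribute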
Algorithm
First, create a dataframe using the above code.
Use the collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.
Print the values of the variables containing the desired rows to the console.
Example
# Retrieving the first row of the DataFrame using collect() function
Row1 = df.collect()[0]
print(Row1)

# Retrieving the last row of the DataFrame using collect() function
Row2 = df.collect()[-1]
print(Row2)

# Retrieving the second row of the DataFrame using collect() function
Row3 = df.collect()[1]
print(Row3)
Output
Row(__='Row1', Col1=1, Col2=2, Col3=3)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row2', Col1=4, Col2=5, Col3=6)
Method 2: Using first()
The first() function in PySpark returns the first row of a DataFrame or RDD as a Row object. It is useful when you only need the first row, for example to quickly inspect the data. Below is the syntax used:
dataframe.first()
Here
dataframe is the one on which we apply the method
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Retrieve the first row of the DataFrame using the first() function
Print the first row to the console
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
Row1 = df.first()
print(Row1)
Output
Row(__='Row1', Col1=1, Col2=2, Col3=3)
Method 3: Using show()
In PySpark, the show() function displays the top n rows of the DataFrame in the console. It prints the rows as a formatted table and returns None, so it is meant for viewing data rather than for storing rows in a variable. Below is the syntax used:
dataframe.show(n)
Here
dataframe is the one on which we apply the method
n is the number of rows
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Display the first row of the DataFrame by calling the show() function with the row parameter set to 1
Display the first two rows by calling the show() function with the row parameter set to 2
Display the first three rows by calling the show() function with the row parameter set to 3
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Display the first row (show() prints to the console and returns None)
df.show(1)

# Display the first two rows
df.show(2)

# Display the first three rows
df.show(3)
Output
+----+----+----+----+
|  __|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
+----+----+----+----+

+----+----+----+----+
|  __|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
+----+----+----+----+

+----+----+----+----+
|  __|Col1|Col2|Col3|
+----+----+----+----+
|Row1|   1|   2|   3|
|Row2|   4|   5|   6|
|Row3|   7|   8|   9|
+----+----+----+----+
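show() also accepts a truncate parameter that controls whether long column values are shortened in the printed table; a brief sketch, assuming the same df:

# Print the first two rows without truncating long values; show() returns None
df.show(2, truncate=False)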
Method 4: Using head()
In PySpark, the head() function returns the first n rows of the DataFrame. Its return value is a Python list of Row objects, so the retrieved rows can be stored and inspected locally. Below is the syntax used:
dataframe.head(n)
Here
dataframe is the one on which we apply the method
n is the number of rows
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Retrieve the first row of the DataFrame using the head() function by passing the row parameter as 1
Print the first row to the console
Retrieve the first two rows of the DataFrame using the head() function by passing the row parameter as 2
Print the first two rows to the console
Retrieve the first three rows of the DataFrame using the head() function by passing the row parameter as 3
Print the first three rows to the console
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
df1 = df.head(1)
print(df1)

# Retrieve the first two rows
df2 = df.head(2)
print(df2)

# Retrieve the first three rows
df3 = df.head(3)
print(df3)
Output
[Row(__='Row1', Col1=1, Col2=2, Col3=3)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
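Note that head() called without an argument returns a single Row (equivalent to first()) rather than a list; a small sketch, assuming the same df:

# head() with no argument returns the first Row directly, like first()
print(df.head())    # Row(__='Row1', Col1=1, Col2=2, Col3=3)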
Method 5: Using tail()
In PySpark, the tail() function returns the last n rows of the DataFrame as a Python list of Row objects. Below is the syntax used:
dataframe.tail(n)
Here
dataframe is the one on which we apply the method
n is the number of rows
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Retrieve the last row of the DataFrame using the tail() function by passing the row parameter as 1
Print the last row to the console
Retrieve the last two rows of the DataFrame using the tail() function by passing the row parameter as 2
Print the last two rows to the console
Retrieve the last three rows of the DataFrame using the tail() function by passing the row parameter as 3
Print the last three rows to the console
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the last row
df1 = df.tail(1)
print(df1)

# Retrieve the last two rows
df2 = df.tail(2)
print(df2)

# Retrieve the last three rows
df3 = df.tail(3)
print(df3)
Output
[Row(__='Row3', Col1=7, Col2=8, Col3=9)]
[Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
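Like collect(), tail() brings the returned rows to the driver, so keep n small on large DataFrames. A short sketch, assuming the same df, of picking a specific row counted from the end:

# tail(2) returns the last two rows; indexing the list gives the second-to-last row
second_last = df.tail(2)[0]
print(second_last)    # Row(__='Row2', Col1=4, Col2=5, Col3=6)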
Method 6: Using select() and collect()
We can use the select() function along with the collect() method to get specific rows, restricted to specific columns, from the PySpark DataFrame. Below is the syntax used:
dataframe.select([columns]).collect()[index]
Here
dataframe is the one on which we apply the method
columns is the list of columns we want in the output.
index is the position of the row we want in the output.
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Use the combination of select() function and collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.
Print the values of the variables containing the desired rows to the console.
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row of the selected columns
df1 = df.select(['Col1', 'Col2', 'Col3']).collect()[0]
print(df1)

# Retrieve the last row of the selected columns
df2 = df.select(['Col1', 'Col2', 'Col3']).collect()[-1]
print(df2)

# Retrieve the second row of the selected columns
df3 = df.select(['Col1', 'Col2', 'Col3']).collect()[1]
print(df3)
Output
Row(Col1=1, Col2=2, Col3=3)
Row(Col1=7, Col2=8, Col3=9)
Row(Col1=4, Col2=5, Col3=6)
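A related sketch, assuming the same df: combining select() with collect() also makes it easy to read a single value, here the Col2 entry of the third row.

# collect() returns a list of Rows; index the list for the row, then the field
value = df.select('Col2').collect()[2][0]
print(value)    # 8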
Method 7: Using filter() and collect()
We can use the filter() function along with the collect() method to get specific rows from the PySpark DataFrame. filter() keeps only the rows that satisfy a condition, and collect() then brings them to the driver so a specific one can be picked by index. Below is the syntax used:
dataframe.filter(condition).collect()[index]
Here
dataframe is the one on which we apply the method
condition is the expression used to filter the DataFrame rows.
index is the position, within the filtered result, of the row we want.
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Use the combination of filter() function and collect() function to retrieve the desired rows from the DataFrame, storing each row in a separate variable.
Print the values of the variables containing the desired rows to the console.
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("filter_collect_example").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Filter the DataFrame and take the first matching row
df1 = df.filter(df.Col1 > 1).collect()[0]
print(df1)

# Take the second matching row
df2 = df.filter(df.Col1 > 1).collect()[1]
print(df2)

# Take the last matching row
df3 = df.filter(df.Col1 > 1).collect()[-1]
print(df3)
Output
Row(__='Row2', Col1=4, Col2=5, Col3=6)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
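A variant sketch, assuming the same df: since this DataFrame carries a label column ('__'), filter() can also fetch a row by its name rather than by a numeric condition.

# Keep only the row whose label is 'Row2', then pull it to the driver
row2 = df.filter(df['__'] == 'Row2').collect()[0]
print(row2)    # Row(__='Row2', Col1=4, Col2=5, Col3=6)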
Method 8: Using where() and collect()
We can use the where() function along with the collect() method to get specific rows from the PySpark DataFrame. where() is an alias of filter(): it keeps the rows that satisfy the condition passed to it, and collect() then brings the result to the driver so a specific row can be stored in a variable. Below is the syntax used:
dataframe.where(condition).collect()[index]
Here:
dataframe is the one on which we apply the method
condition is the expression used to filter the DataFrame rows.
index is the position, within the filtered result, of the row we want.
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Use the combination of where() function and collect() function to retrieve the desired row from the DataFrame, storing each row in a separate variable.
Print the values of the variables containing the desired rows to the console.
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("filter_collect_example").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Filter the DataFrame and take the first matching row
df1 = df.where(df.Col1 > 1).collect()[0]
print(df1)

# Take the second matching row
df2 = df.where(df.Col1 > 1).collect()[1]
print(df2)

# Take the last matching row
df3 = df.where(df.Col1 > 1).collect()[-1]
print(df3)
Output
Row(__='Row2', Col1=4, Col2=5, Col3=6)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
Row(__='Row3', Col1=7, Col2=8, Col3=9)
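The condition passed to where() can also be written with pyspark.sql.functions.col, which some find more readable; a brief sketch, assuming the same df:

from pyspark.sql.functions import col

# where() is an alias of filter(); col() refers to a column by name
row = df.where(col('Col1') == 7).collect()[0]
print(row)    # Row(__='Row3', Col1=7, Col2=8, Col3=9)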
Method 9: Using take()
In PySpark, the take() function also returns the first n rows of the DataFrame, as a Python list of Row objects. Below is the syntax used:
dataframe.take(n)
Here
dataframe is the one on which we apply the method
n is the number of rows
Algorithm
Import the necessary libraries
Create a SparkSession
Create a DataFrame
Retrieve the first row of the DataFrame using the take() function by passing the row parameter as 1
Print the first row to the console
Retrieve the first two rows of the DataFrame using the take() function by passing the row parameter as 2
Print the first two rows to the console
Retrieve the first three rows of the DataFrame using the take() function by passing the row parameter as 3
Print the first three rows to the console
Example
# Import necessary libraries
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("column_sum").getOrCreate()

# Create the DataFrame
df = spark.createDataFrame(
   [('Row1', 1, 2, 3), ('Row2', 4, 5, 6), ('Row3', 7, 8, 9)],
   ['__', 'Col1', 'Col2', 'Col3'])

# Retrieve the first row
df1 = df.take(1)
print(df1)

# Retrieve the first two rows
df2 = df.take(2)
print(df2)

# Retrieve the first three rows
df3 = df.take(3)
print(df3)
Output
[Row(__='Row1', Col1=1, Col2=2, Col3=3)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6)]
[Row(__='Row1', Col1=1, Col2=2, Col3=3), Row(__='Row2', Col1=4, Col2=5, Col3=6), Row(__='Row3', Col1=7, Col2=8, Col3=9)]
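A related sketch, assuming the same df: limit(n) is the counterpart of take(n) that keeps the result as a DataFrame, which is useful when further transformations are planned rather than local inspection.

# limit(2) returns a DataFrame of the first two rows instead of a Python list
first_two = df.limit(2)
first_two.show()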
Conclusion
Depending upon the use case, each approach can be more or less efficient than the others, and each has its own advantages and disadvantages; the key is to choose the one that best fits the specific task. Keep in mind that collect(), head(), tail(), and take() bring rows back to the driver, so on large datasets it is usually better to narrow the data first with select(), filter(), or where() before collecting anything.
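For larger datasets, where collecting everything to the driver is not practical, one option is to assign row numbers with a window function and filter on them; a minimal sketch, assuming the df defined above. Note that a window without partitioning moves all rows into a single partition, so this too has a cost.

from pyspark.sql.functions import row_number
from pyspark.sql.window import Window

# The ordering defines what "the nth row" means
w = Window.orderBy('__')
numbered = df.withColumn('rn', row_number().over(w))

# Keep only the third row, then drop the helper column
third_row = numbered.filter(numbered.rn == 3).drop('rn').collect()[0]
print(third_row)    # Row(__='Row3', Col1=7, Col2=8, Col3=9)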