How to verify PySpark DataFrame column types?


PySpark, the Python API for Apache Spark, provides a powerful and scalable framework for big data processing and analytics. When working with PySpark DataFrames, it is essential to understand and verify the data type of each column. Accurate column-type verification ensures data integrity and lets you apply operations and transformations correctly. In this article, we will explore various methods to verify PySpark DataFrame column types and provide examples for better understanding.

Overview of PySpark DataFrame Column Types

In PySpark, a DataFrame represents a distributed data collection organized into named columns. Each column has a specific data type, which can be any valid PySpark data type, such as IntegerType, StringType, BooleanType, etc. Understanding the column types is crucial as it allows you to perform operations based on the expected data types.
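
As a quick illustration (a minimal sketch, independent of the examples below), the type of any column can also be read directly from the DataFrame's schema: df.schema returns a StructType whose fields expose a dataType attribute that can be compared against the PySpark type classes.

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType

spark = SparkSession.builder.getOrCreate()

# Without an explicit schema, Python ints are inferred as LongType (bigint)
df = spark.createDataFrame([(1, "John"), (2, "Jane")], ["id", "name"])

# Each field of df.schema carries a name, a dataType and a nullable flag
print(df.schema["id"].dataType)                             # e.g. LongType()
print(isinstance(df.schema["id"].dataType, LongType))       # True
print(isinstance(df.schema["name"].dataType, StringType))   # True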

Using the printSchema() Method

The printSchema() method provides a concise and structured representation of the DataFrame's schema, including the column names and their corresponding data types. It is one of the easiest ways to verify the column types.

Syntax

df.printSchema()

Here, df.printSchema() displays the schema of a PySpark DataFrame. It prints each column name along with its data type and whether the column allows null values.

Example

In the example below, we create a SparkSession and define an explicit schema (a StructType of StructFields) for a PySpark DataFrame. The sample data is then used to create a DataFrame with columns named "col1", "col2", and "col3" having the data types IntegerType, StringType, and DoubleType. Finally, the schema of the DataFrame is printed with the printSchema() method, which displays the column names, their data types, and whether they are nullable.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Sample data
data = [
    (1, "John", 3.14),
    (2, "Jane", 2.71),
    (3, "Alice", 1.23)
]

# Define the schema as a StructType of StructFields
schema = StructType([
    StructField("col1", IntegerType(), True),
    StructField("col2", StringType(), True),
    StructField("col3", DoubleType(), True)
])

# Create a DataFrame with the provided data and schema
df = spark.createDataFrame(data, schema)

# Print the DataFrame schema
df.printSchema()

Output

root
 |-- col1: integer (nullable = true)
 |-- col2: string (nullable = true)
 |-- col3: double (nullable = true)
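
If you need the schema as a value rather than printed output, the same information is available programmatically; printSchema() is only a rendering of df.schema. A brief sketch, reusing the df from the example above:

# df.schema returns the StructType that printSchema() renders
print(df.schema.simpleString())
# struct<col1:int,col2:string,col3:double>

# A direct equality check against the expected type is also possible
from pyspark.sql.types import DoubleType
print(df.schema["col3"].dataType == DoubleType())   # True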

Inspecting Column Types with dtypes

The dtypes attribute returns a list of tuples, where each tuple contains the column name and its corresponding data type. This method allows for a programmatic way of accessing the column types.

Syntax

column_types = df.dtypes
for column_name, data_type in column_types:
    print(f"Column '{column_name}' has data type: {data_type}")

Here, df.dtypes retrieves the column names and their corresponding data types as a list of tuples from the PySpark DataFrame. The for loop iterates over each tuple, extracting the column name and data type, and then prints them using f-string formatting.

Example

In the example below, we create a PySpark DataFrame using a SparkSession. It defines sample data as a list of tuples and creates a DataFrame named df with columns "col1", "col2", and "col3". The df.dtypes attribute retrieves the column names and their corresponding data types as a list of tuples. The for loop iterates over each tuple, extracting the column name and data type, and then prints them using f-string formatting.

from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Sample data
data = [
    (1, "John", 3.14),
    (2, "Jane", 2.71),
    (3, "Alice", 1.23)
]

# Create a DataFrame
df = spark.createDataFrame(data, ["col1", "col2", "col3"])

# Get the column types
column_types = df.dtypes

# Display the column types
for column_name, data_type in column_types:
    print(f"Column '{column_name}' has data type: {data_type}")

Output

The output displays the column names (col1, col2, col3) and their corresponding data types (bigint, string, double). Because the DataFrame was created without an explicit schema, the Python integers in col1 are inferred as bigint (LongType). This information is obtained from the dtypes attribute of the DataFrame, which returns a list of tuples, where each tuple contains a column name and its data type.

Column 'col1' has data type: bigint
Column 'col2' has data type: string
Column 'col3' has data type: double
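
Because dtypes is simply a list of (name, type) tuples, it can be turned into a dictionary for direct lookups or quick assertions. A short sketch, assuming the df from the example above:

# Build a {column name: type string} mapping from dtypes
type_map = dict(df.dtypes)

print(type_map["col2"])                # string
print(type_map["col1"] == "bigint")    # True

# Fail fast if a column does not have the expected type
assert type_map["col3"] == "double", "col3 is not a double column"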

Verifying Column Types with selectExpr()

The selectExpr() method allows us to select columns and apply SQL expressions or transformations to them. Combining it with the SQL typeof() function (available since Spark 3.0) lets you directly check the data types of specific columns.

Syntax

column_names = ["col1", "col2", "col3"]
exprs = [f"typeof({col}) as {col}_type" for col in column_names]
df.selectExpr(*exprs).show()

Here, selectExpr() takes SQL expression strings. Each expression calls typeof() on a column and aliases the result with a new column name ending in "_type". df.selectExpr(*exprs).show() applies these expressions to the DataFrame and displays the resulting columns.

Example

In the example below, we create a SparkSession and define a PySpark DataFrame named df with three columns: "col1", "col2", and "col3". To verify the column types, the code builds a list of SQL expressions using a list comprehension, where each expression applies the typeof() function to a column and aliases the result with a new column name ending in "_type". Finally, df.selectExpr(*exprs).show() applies these expressions to the DataFrame, selecting the dynamically created columns, and the show() method displays the resulting DataFrame.

from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Sample data
data = [
    (1, "John", 3.14),
    (2, "Jane", 2.71),
    (3, "Alice", 1.23)
]

# Create a DataFrame
df = spark.createDataFrame(data, ["col1", "col2", "col3"])

# Verify column types using selectExpr()
column_names = ["col1", "col2", "col3"]
exprs = [f"typeof({col}) as {col}_type" for col in column_names]
df.selectExpr(*exprs).show()

Output

+---------+---------+---------+
|col1_type|col2_type|col3_type|
+---------+---------+---------+
|   bigint|   string|   double|
|   bigint|   string|   double|
|   bigint|   string|   double|
+---------+---------+---------+
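
Note that col1_type is reported as bigint because Python integers are inferred as LongType when no explicit schema is supplied. The same typeof() check can also be written with select() and expr() Column objects instead of raw SQL strings; a small sketch using the same df:

from pyspark.sql.functions import expr

# typeof() via select() and expr(); show(1) is enough because every row is identical
df.select(
    expr("typeof(col1)").alias("col1_type"),
    expr("typeof(col2)").alias("col2_type"),
    expr("typeof(col3)").alias("col3_type")
).show(1)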

Checking Column Types using cast()

The cast() function allows us to explicitly cast a column to a different data type. By comparing the original values with the cast values, you can verify whether the cast preserves every value, which indicates that the column's contents are compatible with the expected data type.

Example

In the example below, we create a SparkSession and define a PySpark DataFrame named df with three columns: "col1", "col2", and "col3", along with sample data. The code defines a dictionary expected_data_types that specifies the expected data type for each column. A for loop iterates over each item in the expected_data_types dictionary. Within the loop, the code uses the cast() function to cast the column to the expected data type and keeps only the rows where the cast value equals the original value, that is, the rows whose values are compatible with the expected type.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Sample data
data = [
    (1, "John", 3.14),
    (2, "Jane", 2.71),
    (3, "Alice", 1.23)
]

# Create a DataFrame
df = spark.createDataFrame(data, ["col1", "col2", "col3"])

# Define the expected data types
expected_data_types = {
    "col1": "integer",
    "col2": "string",
    "col3": "double"
}

# Check column types using cast()
for column_name, expected_type in expected_data_types.items():
    # Keep only the rows where casting to the expected type preserves the value
    matched_rows = df.filter(col(column_name).cast(expected_type) == col(column_name))
    print(f"Rows where column '{column_name}' is compatible with type '{expected_type}':")
    matched_rows.show()

Output

The output verifies each column by attempting to cast its values to the expected type using the cast() function. The DataFrame is filtered to the rows where the cast value equals the original value; if every row is kept, the column's contents are compatible with the expected data type.

Rows where column 'col1' is compatible with type 'integer':
+----+-----+----+
|col1| col2|col3|
+----+-----+----+
|   1| John|3.14|
|   2| Jane|2.71|
|   3|Alice|1.23|
+----+-----+----+

Rows where column 'col2' is compatible with type 'string':
+----+-----+----+
|col1| col2|col3|
+----+-----+----+
|   1| John|3.14|
|   2| Jane|2.71|
|   3|Alice|1.23|
+----+-----+----+

Rows where column 'col3' is compatible with type 'double':
+----+-----+----+
|col1| col2|col3|
+----+-----+----+
|   1| John|3.14|
|   2| Jane|2.71|
|   3|Alice|1.23|
+----+-----+----+
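
Keep in mind that this cast()-based check verifies that the values are compatible with the expected type, not that the column is declared with that type: col1 above is actually declared as bigint (inferred from the Python integers), yet every value still survives a cast to integer. To compare the declared types themselves, the dtypes mapping can be checked directly; a short sketch reusing df and expected_data_types from the example:

# Compare the declared column types against the expected type strings
actual_types = dict(df.dtypes)

for column_name, expected_type in expected_data_types.items():
    actual_type = actual_types[column_name]
    if actual_type == expected_type:
        print(f"Column '{column_name}': OK ({actual_type})")
    else:
        print(f"Column '{column_name}': expected {expected_type}, found {actual_type}")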

Conclusion

In this article, we discussed how to verify PySpark DataFrame column types. Verifying column types is crucial for ensuring data accuracy and performing meaningful operations. We explored several methods for verifying column types, including printSchema(), the dtypes attribute, selectExpr() with typeof(), and cast().
