How to create a seaborn correlation heatmap in Python?
A correlation heatmap is a graphical representation that displays the correlation matrix of a dataset using colors to show the strength and direction of relationships between variables. It's an effective tool for identifying patterns and connections in large datasets.
Seaborn, a Python data visualization library, provides simple utilities for creating statistical visualizations including correlation heatmaps. The process involves importing your dataset, computing the correlation matrix, and using Seaborn's heatmap function to generate the visualization.
Using the heatmap() Function
The heatmap() function draws any rectangular dataset as a color-coded matrix. For a correlation heatmap, pass it a correlation matrix, which can be calculated using the corr() method of a Pandas DataFrame.
Syntax
import seaborn as sns
sns.heatmap(data, cmap=None, annot=None)
Parameters:
- data: The correlation matrix (the input dataset)
- cmap: The colormap used to color the heatmap
- annot: Whether to annotate cells with correlation values
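As a minimal sketch of how these parameters fit together, the snippet below builds a small synthetic DataFrame, computes its correlation matrix, and passes it to heatmap(). The column names a, b, c and the random data are assumptions made purely for illustration; they do not come from any real dataset.

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Hypothetical toy data: "b" tracks "a" closely, "c" moves opposite to "a"
rng = np.random.default_rng(0)
a = rng.normal(size=200)
toy_df = pd.DataFrame({
    "a": a,
    "b": a + rng.normal(scale=0.3, size=200),
    "c": -a + rng.normal(scale=0.3, size=200),
})

# Step 1: compute the pairwise Pearson correlation matrix
toy_corr = toy_df.corr()

# Step 2: pass the matrix as `data`; annot=True writes each value in its cell
fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(toy_corr, cmap="coolwarm", annot=True, ax=ax)
ax.set_title("Toy correlation heatmap")
```

In an interactive session, calling plt.show() afterwards displays the figure. Passing ax explicitly is a design choice that keeps the plot attached to a known Axes object, which is convenient when several heatmaps share one figure.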
Example 1: Iris Dataset Correlation Heatmap
Let's create a correlation heatmap using the famous iris dataset, which contains measurements of sepal and petal dimensions for different iris species:
# Required libraries
import seaborn as sns
import matplotlib.pyplot as plt
# Load the iris dataset
iris_data = sns.load_dataset('iris')
# Display first few rows
print("First 5 rows of iris dataset:")
print(iris_data.head())
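# Create correlation matrix (numeric columns only; 'species' holds text,
# and recent pandas versions raise an error if it is included)
iris_corr_matrix = iris_data.corr(numeric_only=True)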
print("\nCorrelation Matrix:")
print(iris_corr_matrix)
First 5 rows of iris dataset:
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
Correlation Matrix:
sepal_length sepal_width petal_length petal_width
sepal_length 1.000000 -0.117570 0.871754 0.817941
sepal_width -0.117570 1.000000 -0.428440 -0.366126
petal_length 0.871754 -0.428440 1.000000 0.962865
petal_width 0.817941 -0.366126 0.962865 1.000000
Now let's create the heatmap with annotations:
import seaborn as sns
import matplotlib.pyplot as plt
# Load dataset and create correlation matrix
iris_data = sns.load_dataset('iris')
iris_corr_matrix = iris_data.corr(numeric_only=True)  # skip the text 'species' column
# Create the heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(iris_corr_matrix, cmap='coolwarm', annot=True, center=0)
plt.title('Iris Dataset Correlation Heatmap')
plt.show()
[Creates a correlation heatmap showing strong positive correlations between petal_length and petal_width (0.96), and between sepal_length and petal measurements]
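The strongest relationships can also be read off programmatically rather than by eye. The sketch below re-enters the iris correlation values printed earlier and ranks each variable pair by absolute correlation. The helper name strongest_pairs is hypothetical and not part of Seaborn or Pandas.

```python
from itertools import combinations

import pandas as pd

# Correlation values copied from the iris matrix printed earlier in this article
cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
iris_corr = pd.DataFrame(
    [
        [1.000000, -0.117570, 0.871754, 0.817941],
        [-0.117570, 1.000000, -0.428440, -0.366126],
        [0.871754, -0.428440, 1.000000, 0.962865],
        [0.817941, -0.366126, 0.962865, 1.000000],
    ],
    index=cols,
    columns=cols,
)


def strongest_pairs(corr, top=3):
    """Hypothetical helper: rank variable pairs by absolute correlation."""
    # combinations() yields each unordered pair once and skips the diagonal
    pairs = {(x, y): corr.loc[x, y] for x, y in combinations(corr.columns, 2)}
    ranked = pd.Series(pairs).sort_values(key=lambda s: s.abs(), ascending=False)
    return ranked.head(top)


top_pairs = strongest_pairs(iris_corr)
print(top_pairs)
```

Ranking by absolute value means a strong negative correlation would surface alongside strong positive ones, which matches how the eye picks out the most saturated cells in the heatmap.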
Example 2: Diamonds Dataset Correlation Heatmap
The diamonds dataset contains information about diamond prices and characteristics. Let's examine correlations between numerical variables:
# Load the diamonds dataset
diamonds_data = sns.load_dataset('diamonds')
# Display dataset info
print("Diamonds dataset shape:", diamonds_data.shape)
print("\nNumerical columns:")
print(diamonds_data.select_dtypes(include=['float64', 'int64']).columns.tolist())
# Create correlation matrix for numerical columns only
# (cut, color, and clarity are categorical and must be excluded)
diamonds_corr_matrix = diamonds_data.corr(numeric_only=True)
print("\nCorrelation Matrix:")
print(diamonds_corr_matrix.round(3))
Diamonds dataset shape: (53940, 10)
Numerical columns:
['carat', 'depth', 'table', 'price', 'x', 'y', 'z']
Correlation Matrix:
carat depth table price x y z
carat 1.000 0.028 0.182 0.922 0.975 0.952 0.953
depth 0.028 1.000 -0.296 -0.011 -0.025 -0.029 0.095
table 0.182 -0.296 1.000 0.127 0.195 0.184 0.151
price 0.922 -0.011 0.127 1.000 0.884 0.865 0.861
x 0.975 -0.025 0.195 0.884 1.000 0.975 0.971
y 0.952 -0.029 0.184 0.865 0.975 1.000 0.952
z 0.953 0.095 0.151 0.861 0.971 0.952 1.000
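Restricting the calculation to numerical columns can be done in more than one way. The sketch below uses a tiny mixed-type DataFrame with made-up values (an assumption for illustration, not real diamonds data) to show that selecting numeric columns first and passing numeric_only=True give the same matrix.

```python
import pandas as pd

# Hypothetical mixed-type frame standing in for a dataset like diamonds;
# the values are invented for illustration only
mixed_df = pd.DataFrame({
    "carat": [0.23, 0.31, 0.70, 1.01, 1.50],
    "price": [326, 335, 2757, 4500, 9000],
    "cut": ["Ideal", "Good", "Premium", "Ideal", "Fair"],
})

# Option A: drop non-numeric columns explicitly, then correlate
numeric_df = mixed_df.select_dtypes(include="number")
corr_a = numeric_df.corr()

# Option B: let corr() ignore non-numeric columns itself
corr_b = mixed_df.corr(numeric_only=True)

print(corr_a.columns.tolist())
```

Option A makes the selected columns visible before plotting, which can help when a column that looks numeric was loaded as text.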
Now let's create an enhanced heatmap with better formatting:
import seaborn as sns
import matplotlib.pyplot as plt
# Load dataset
diamonds_data = sns.load_dataset('diamonds')
diamonds_corr_matrix = diamonds_data.corr(numeric_only=True)  # skip categorical columns
# Create enhanced heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(diamonds_corr_matrix,
            cmap='coolwarm',
            annot=True,
            center=0,
            square=True,
            fmt='.2f')
plt.title('Diamonds Dataset Correlation Heatmap')
plt.tight_layout()
plt.show()
[Creates a correlation heatmap showing strong positive correlations between carat and price (0.92), and very high correlations between diamond dimensions (x, y, z)]
Key Insights from Heatmaps
| Correlation Value | Color (coolwarm) | Interpretation |
|---|---|---|
| 1.0 | Dark Red | Perfect positive correlation |
| 0.0 | Light gray (near white) | No linear correlation |
| -1.0 | Dark Blue | Perfect negative correlation |
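This color mapping holds only when the color scale is anchored, for example with center=0 as in the examples above or with explicit limits. Without such arguments, heatmap() scales the colors to the range of the data. The sketch below compares the two on a small hypothetical matrix (the variables u and v are assumptions for illustration) by reading back the color limits of each plot.

```python
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Hypothetical correlation matrix with no negative values
corr = pd.DataFrame(
    [[1.0, 0.6], [0.6, 1.0]],
    index=["u", "v"],
    columns=["u", "v"],
)

fig, (ax_auto, ax_fixed) = plt.subplots(1, 2, figsize=(9, 4))

# Default: the color limits follow the data, so 0.6 gets the "coldest" color
sns.heatmap(corr, cmap="coolwarm", annot=True, ax=ax_auto)

# Anchored: -1 maps to dark blue and +1 to dark red regardless of the data
sns.heatmap(corr, cmap="coolwarm", annot=True, vmin=-1, vmax=1, ax=ax_fixed)

# The first collection on each Axes is the colored mesh drawn by heatmap()
auto_limits = ax_auto.collections[0].get_clim()
fixed_limits = ax_fixed.collections[0].get_clim()
print("auto:", auto_limits, "fixed:", fixed_limits)
```

Fixing vmin and vmax also keeps colors comparable across heatmaps of different datasets, since the same shade always stands for the same correlation value.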
Conclusion
Seaborn correlation heatmaps provide an intuitive way to visualize relationships between variables in your dataset. Use annot=True to display correlation coefficients and choose appropriate colormaps like 'coolwarm' to distinguish positive and negative correlations effectively.
