How to Add Legends to charts in Python?
Charts help visualize complex data effectively. When creating charts with multiple data series, legends are essential for identifying what each visual element represents. Python's matplotlib library provides flexible options for adding and customizing legends.
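Before the full sales example, here is a minimal sketch of the two common ways to get a legend in matplotlib. The line data below is made up purely for illustration. The first style passes label= to each plotting call and then calls legend() with no arguments. The second style passes a list of labels to legend(), which are matched to the artists in the order they were drawn.

```python
import matplotlib.pyplot as plt

# Hypothetical data, used only to illustrate the two legend styles
x = [1, 2, 3, 4]
series_a = [1, 4, 9, 16]
series_b = [1, 2, 3, 4]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Style 1: attach a label to each artist, then let legend() collect them
ax1.plot(x, series_a, label='Squares')
ax1.plot(x, series_b, label='Linear')
legend1 = ax1.legend()

# Style 2: draw unlabeled artists, then pass labels in drawing order
ax2.plot(x, series_a)
ax2.plot(x, series_b)
legend2 = ax2.legend(['Squares', 'Linear'])

if __name__ == "__main__":
    plt.show()
```

The first style is generally the safer choice, because each label stays attached to its artist instead of depending on drawing order. All the examples below use it.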
Basic Legend Setup
First, let's prepare the sample data and split it into one list per brand:
```python
import matplotlib.pyplot as plt

# Sample mobile sales data (in millions)
mobile_brands = ['iPhone', 'Galaxy', 'Pixel']
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)

# Split data into separate lists for each brand
iphone_sales = [data[1] for data in units_sold]
galaxy_sales = [data[2] for data in units_sold]
pixel_sales = [data[3] for data in units_sold]
years = [data[0] for data in units_sold]

print("iPhone Sales:", iphone_sales)
print("Galaxy Sales:", galaxy_sales)
print("Pixel Sales:", pixel_sales)
```

Output

```
iPhone Sales: [12, 14, 16, 18, 20]
Galaxy Sales: [8, 10, 12, 14, 16]
Pixel Sales: [6, 7, 8, 10, 5]
```
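The four list comprehensions above each walk the tuple once. A more compact, equivalent sketch transposes the rows into columns with zip(*...). It assumes every row has exactly four fields (year, iPhone, Galaxy, Pixel) and produces the same lists.

```python
# Same table as above: (year, iPhone, Galaxy, Pixel) per row
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)

# zip(*rows) transposes rows into columns; map(list, ...) turns each column tuple into a list
years, iphone_sales, galaxy_sales, pixel_sales = map(list, zip(*units_sold))

print("iPhone Sales:", iphone_sales)
print("Galaxy Sales:", galaxy_sales)
print("Pixel Sales:", pixel_sales)
```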
Creating Bar Chart with Legend
Now let's create a grouped bar chart with a legend:
```python
import matplotlib.pyplot as plt

# Data preparation
mobile_brands = ['iPhone', 'Galaxy', 'Pixel']
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)
iphone_sales = [data[1] for data in units_sold]
galaxy_sales = [data[2] for data in units_sold]
pixel_sales = [data[3] for data in units_sold]
years = [data[0] for data in units_sold]

# Set up bar positions
positions = list(range(len(units_sold)))
width = 0.25

# Create grouped bars
plt.figure(figsize=(10, 6))
plt.bar([p - width for p in positions], iphone_sales, width=width, color='green', label='iPhone')
plt.bar(positions, galaxy_sales, width=width, color='blue', label='Galaxy')
plt.bar([p + width for p in positions], pixel_sales, width=width, color='orange', label='Pixel')

# Customize the chart
plt.xticks(positions, years)
plt.xlabel('Year')
plt.ylabel('Unit Sales (Millions)')
plt.title('Mobile Phone Sales by Brand')

# Add legend
plt.legend(title='Manufacturers')
plt.tight_layout()
plt.show()
```
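The same grouped chart can also be sketched with matplotlib's object-oriented interface (fig, ax = plt.subplots()), which keeps a reference to the Axes and to the Legend object that ax.legend() returns. The brand-to-data dict and the loop are an arrangement assumed here, not part of the original example. Keeping each label next to its data means the legend cannot drift out of sync with the bars.

```python
import matplotlib.pyplot as plt

years = ['2016', '2017', '2018', '2019', '2020']
# Brand -> (yearly sales in millions, bar color); dicts keep insertion order
series = {
    'iPhone': ([12, 14, 16, 18, 20], 'green'),
    'Galaxy': ([8, 10, 12, 14, 16], 'blue'),
    'Pixel': ([6, 7, 8, 10, 5], 'orange'),
}
width = 0.25
positions = list(range(len(years)))

fig, ax = plt.subplots(figsize=(10, 6))
for i, (brand, (sales, color)) in enumerate(series.items()):
    # (i - 1) * width gives offsets -width, 0, +width for exactly three series
    offset = (i - 1) * width
    ax.bar([p + offset for p in positions], sales, width=width,
           color=color, label=brand)

ax.set_xticks(positions)
ax.set_xticklabels(years)
ax.set_xlabel('Year')
ax.set_ylabel('Unit Sales (Millions)')
ax.set_title('Mobile Phone Sales by Brand')

# ax.legend() returns the Legend artist, which can be inspected or restyled later
legend = ax.legend(title='Manufacturers')
fig.tight_layout()

if __name__ == "__main__":
    plt.show()
```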
Adding Annotations with Legends
You can combine legends with annotations to highlight specific data points:
```python
import matplotlib.pyplot as plt

# Same data setup
mobile_brands = ['iPhone', 'Galaxy', 'Pixel']
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)
iphone_sales = [data[1] for data in units_sold]
galaxy_sales = [data[2] for data in units_sold]
pixel_sales = [data[3] for data in units_sold]
years = [data[0] for data in units_sold]

positions = list(range(len(units_sold)))
width = 0.25

plt.figure(figsize=(10, 6))
plt.bar([p - width for p in positions], iphone_sales, width=width, color='green', label='iPhone')
plt.bar(positions, galaxy_sales, width=width, color='blue', label='Galaxy')
plt.bar([p + width for p in positions], pixel_sales, width=width, color='orange', label='Pixel')

# Add annotation for the drop in Pixel sales
plt.annotate('50% Drop in Sales',
             xy=(4 + width, 5),
             xytext=(3.5, 12),
             horizontalalignment='center',
             arrowprops=dict(facecolor='red', shrink=0.05))

plt.xticks(positions, years)
plt.xlabel('Year')
plt.ylabel('Unit Sales (Millions)')
plt.title('Mobile Phone Sales with Annotation')
plt.legend(title='Manufacturers')
plt.tight_layout()
plt.show()
```
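In the example above, the annotation text and the arrow target are hard-coded. The sketch below derives both from the data instead, so the label stays correct if the numbers change. It assumes the point of interest is always the last Pixel bar, and for brevity it plots only the Pixel series.

```python
import matplotlib.pyplot as plt

years = ['2016', '2017', '2018', '2019', '2020']
pixel_sales = [6, 7, 8, 10, 5]
width = 0.25
positions = list(range(len(years)))

# Percentage drop between the last two years: (old - new) / old * 100
previous, latest = pixel_sales[-2], pixel_sales[-1]
drop_pct = (previous - latest) / previous * 100
label = f'{drop_pct:.0f}% Drop in Sales'

fig, ax = plt.subplots(figsize=(10, 6))
# Pixel bars sit at p + width, as in the grouped chart above
ax.bar([p + width for p in positions], pixel_sales, width=width,
       color='orange', label='Pixel')
ax.annotate(label,
            xy=(positions[-1] + width, latest),      # top of the last Pixel bar
            xytext=(positions[-1] - 0.5, latest + 7),
            horizontalalignment='center',
            arrowprops=dict(facecolor='red', shrink=0.05))
ax.set_ylim(0, latest + 10)  # headroom so the annotation text stays inside the axes
ax.set_xticks(positions)
ax.set_xticklabels(years)
legend = ax.legend(title='Manufacturers')

if __name__ == "__main__":
    plt.show()
```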
Positioning Legends Outside the Plot
To avoid overlapping with data, you can position the legend outside the plot area:
```python
import matplotlib.pyplot as plt

# Data setup
mobile_brands = ['iPhone', 'Galaxy', 'Pixel']
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)
iphone_sales = [data[1] for data in units_sold]
galaxy_sales = [data[2] for data in units_sold]
pixel_sales = [data[3] for data in units_sold]
years = [data[0] for data in units_sold]

positions = list(range(len(units_sold)))
width = 0.25

plt.figure(figsize=(12, 6))
plt.bar([p - width for p in positions], iphone_sales, width=width, color='green', label='iPhone')
plt.bar(positions, galaxy_sales, width=width, color='blue', label='Galaxy')
plt.bar([p + width for p in positions], pixel_sales, width=width, color='orange', label='Pixel')

plt.xticks(positions, years)
plt.xlabel('Year')
plt.ylabel('Unit Sales (Millions)')
plt.title('Mobile Phone Sales - Legend Outside')

# Position legend outside the plot area
plt.legend(title='Manufacturers', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()
```
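A legend placed outside the axes can be clipped when the figure is saved to a file. One common remedy, sketched below, is fig.savefig(..., bbox_inches='tight'), which enlarges the saved area to include every drawn artist. The in-memory buffer and the PNG format are illustrative choices; a file path works the same way.

```python
import io
import matplotlib.pyplot as plt

years = ['2016', '2017', '2018', '2019', '2020']
sales = {'iPhone': [12, 14, 16, 18, 20],
         'Galaxy': [8, 10, 12, 14, 16],
         'Pixel': [6, 7, 8, 10, 5]}
width = 0.25
positions = list(range(len(years)))

fig, ax = plt.subplots(figsize=(10, 6))
for i, (brand, values) in enumerate(sales.items()):
    ax.bar([p + (i - 1) * width for p in positions], values, width=width, label=brand)
ax.set_xticks(positions)
ax.set_xticklabels(years)

# Anchor the legend's upper-left corner just right of the axes (axes coordinates)
legend = ax.legend(title='Manufacturers', bbox_to_anchor=(1.02, 1), loc='upper left')

# bbox_inches='tight' grows the saved image so the outside legend is not cut off
buffer = io.BytesIO()
fig.savefig(buffer, format='png', bbox_inches='tight')
png_bytes = buffer.getvalue()
```

The bbox_inches option only affects saved output; on-screen display with plt.show() works as in the example above.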
Legend Customization Options
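| Parameter | Purpose | Example |
|---|---|---|
| title | Adds a title above the legend entries | title='Brands' |
| loc | Legend position; combined with bbox_to_anchor, the point of the legend box placed on the anchor | loc='upper right' |
| bbox_to_anchor | Custom anchor point for the legend (axes coordinates by default for ax.legend/plt.legend) | bbox_to_anchor=(1, 0.8) |
| ncol | Number of columns the entries are spread across (also accepted as ncols in Matplotlib 3.6+) | ncol=3 |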
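Combining these parameters, a common pattern is a one-row legend centered above the plot. In the sketch below, which reuses the same sales figures, bbox_to_anchor=(0.5, 1.02) sets an anchor point just above the top edge of the axes, loc='lower center' places the bottom-center of the legend box on that point, and ncol spreads the entries into one column per brand. The chart title is omitted so it does not collide with the legend.

```python
import matplotlib.pyplot as plt

years = ['2016', '2017', '2018', '2019', '2020']
sales = {'iPhone': [12, 14, 16, 18, 20],
         'Galaxy': [8, 10, 12, 14, 16],
         'Pixel': [6, 7, 8, 10, 5]}
width = 0.25
positions = list(range(len(years)))

fig, ax = plt.subplots(figsize=(10, 6))
for i, (brand, values) in enumerate(sales.items()):
    ax.bar([p + (i - 1) * width for p in positions], values, width=width, label=brand)
ax.set_xticks(positions)
ax.set_xticklabels(years)

# loc names the legend's own reference point; bbox_to_anchor places it in axes coordinates
legend = ax.legend(title='Brands', loc='lower center',
                   bbox_to_anchor=(0.5, 1.02), ncol=len(sales))
fig.tight_layout()

if __name__ == "__main__":
    plt.show()
```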
Complete Example
```python
import matplotlib.pyplot as plt

# Complete example with customized legend
mobile_brands = ['iPhone', 'Galaxy', 'Pixel']
units_sold = (
    ('2016', 12, 8, 6),
    ('2017', 14, 10, 7),
    ('2018', 16, 12, 8),
    ('2019', 18, 14, 10),
    ('2020', 20, 16, 5)
)
iphone_sales = [data[1] for data in units_sold]
galaxy_sales = [data[2] for data in units_sold]
pixel_sales = [data[3] for data in units_sold]
years = [data[0] for data in units_sold]

positions = list(range(len(units_sold)))
width = 0.25

plt.figure(figsize=(10, 6))
plt.bar([p - width for p in positions], iphone_sales, width=width, color='green', label='iPhone')
plt.bar(positions, galaxy_sales, width=width, color='blue', label='Galaxy')
plt.bar([p + width for p in positions], pixel_sales, width=width, color='orange', label='Pixel')

plt.xticks(positions, years)
plt.xlabel('Year')
plt.ylabel('Unit Sales (Millions)')
plt.title('Mobile Phone Sales Analysis')

# Customized legend with multiple columns
plt.legend(title='Phone Manufacturers', loc='upper left', ncol=3)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
```
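By default, the legend lists entries in the order the bars were drawn. To use a different order, you can fetch the labeled artists with ax.get_legend_handles_labels() and pass them back to ax.legend() rearranged. The sketch below sorts brands alphabetically, ignoring case. The sort rule is only an illustrative assumption; any key function works.

```python
import matplotlib.pyplot as plt

years = ['2016', '2017', '2018', '2019', '2020']
# Brand -> (yearly sales in millions, bar color), in drawing order
series = {
    'iPhone': ([12, 14, 16, 18, 20], 'green'),
    'Galaxy': ([8, 10, 12, 14, 16], 'blue'),
    'Pixel': ([6, 7, 8, 10, 5], 'orange'),
}
width = 0.25
positions = list(range(len(years)))

fig, ax = plt.subplots(figsize=(10, 6))
for i, (brand, (sales, color)) in enumerate(series.items()):
    ax.bar([p + (i - 1) * width for p in positions], sales, width=width,
           color=color, label=brand)
ax.set_xticks(positions)
ax.set_xticklabels(years)
ax.grid(axis='y', alpha=0.3)

# Collect the labeled artists, then pass them back in a new order
handles, labels = ax.get_legend_handles_labels()
order = sorted(range(len(labels)), key=lambda i: labels[i].lower())  # case-insensitive A to Z
legend = ax.legend([handles[i] for i in order], [labels[i] for i in order],
                   title='Phone Manufacturers', loc='upper left', ncol=3)
fig.tight_layout()

if __name__ == "__main__":
    plt.show()
```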
Conclusion
Legends are crucial for making multi-series charts readable and informative. Use plt.legend() with parameters such as title, loc, bbox_to_anchor, and ncol to label, position, and lay out legends effectively.
