How to plot statsmodels linear regression (OLS) cleanly in Matplotlib?
Plotting statsmodels linear regression (OLS) results cleanly in Matplotlib involves fitting a regression model, computing the fitted values and prediction interval bounds, then drawing the data points, fitted line, and interval bands on the same axes.
Steps to Plot OLS Regression
Set up figure size and random seed for reproducible results
Create sample data with linear and non-linear features
Fit an OLS regression model using statsmodels
Calculate prediction standard errors and prediction interval bounds
Plot the original data points, true relationship, fitted values, and interval bands
Add legend and display the plot
Example
Here's how to create an OLS regression plot with prediction interval bands:
import numpy as np
from matplotlib import pyplot as plt
from statsmodels import api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std
# Set figure size and random seed
plt.rcParams["figure.figsize"] = [7.50, 3.50]
plt.rcParams["figure.autolayout"] = True
np.random.seed(9876789)
# Generate sample data
nsample = 50
sig = 0.5
x = np.linspace(0, 20, nsample)
X = np.column_stack((x, np.sin(x), (x - 5) ** 2, np.ones(nsample)))
beta = [0.5, 0.5, -0.02, 5.]
# Create true relationship and add noise
y_true = np.dot(X, beta)
y = y_true + sig * np.random.normal(size=nsample)
# Fit OLS regression model
res = sm.OLS(y, X).fit()
# Calculate prediction standard errors and 95% prediction interval bounds
prstd, iv_l, iv_u = wls_prediction_std(res)
# Create the plot
fig, ax = plt.subplots()
ax.plot(x, y, 'o', label="Data points")
ax.plot(x, y_true, 'b-', label="True relationship")
ax.plot(x, res.fittedvalues, 'r--.', label="OLS fitted")
ax.plot(x, iv_u, 'r--', alpha=0.7, label="Prediction interval")
ax.plot(x, iv_l, 'r--', alpha=0.7)
ax.legend(loc='best')
ax.set_title("OLS Regression with Prediction Intervals")
ax.set_xlabel("X values")
ax.set_ylabel("Y values")
plt.show()
The output plot shows the following:
- Blue dots: original data points with noise
- Blue solid line: true underlying relationship
- Red dashed line with dots: OLS fitted values
- Red dashed lines: upper and lower prediction interval bounds
Understanding the Components
Model Features
The design matrix X combines four columns: a linear term, a sine term, a quadratic term, and a constant column for the intercept:
# Feature matrix breakdown
print("Features in the model:")
print("1. Linear term: x")
print("2. Sine term: sin(x)")
print("3. Quadratic term: (x-5)²")
print("4. Intercept: constant term")
# Check the goodness of fit and sample size
print(f"\nR-squared: {res.rsquared:.3f}")
# res.nobs is stored as a float, so cast it for a clean integer display
print(f"Number of observations: {int(res.nobs)}")
Features in the model:
1. Linear term: x
2. Sine term: sin(x)
3. Quadratic term: (x-5)²
4. Intercept: constant term

R-squared: 0.999
Number of observations: 50
Key Components Explained
| Component | Purpose | Visualization |
|---|---|---|
| y (data points) | Original observations | Blue circles |
| y_true | True underlying relationship | Blue solid line |
| res.fittedvalues | OLS fitted values | Red dashed line with dots |
| iv_l, iv_u | Prediction interval bounds | Red dashed lines |
Conclusion
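This approach produces an OLS regression plot that shows the data points, the fitted values, and the interval bands on one set of axes. The wls_prediction_std function returns the prediction standard errors together with the lower and upper bounds of the prediction interval (95% by default). That interval describes where new observations are expected to fall, so its bands are wider than a confidence interval for the fitted mean.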
