How can I write unit tests against code that uses Matplotlib?

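Unit tests for Matplotlib code check the data and properties of the plotted objects rather than the rendered image, so nothing needs to be displayed. The key is to get hold of the artists that a plotting call creates (for example, the Line2D objects returned by plt.plot()), extract their data with methods like get_data(), and compare it with the expected values.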

Creating a Testable Function

First, write the plotting code as a function that returns what it draws. plt.plot() returns a list of Line2D objects, which a test can inspect:

import numpy as np
from matplotlib import pyplot as plt

def plot_sqr_curve(x):
    """
    Plot x against y = x^2 and return the list of Line2D objects.
    """
    return plt.plot(x, np.square(x))
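To see what a test has to work with, the following sketch calls the function once and unpacks the result. It assumes Matplotlib's non-interactive Agg backend (selected with matplotlib.use) so that no window is opened, which is also a sensible choice on CI machines without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, nothing is shown on screen
import numpy as np
from matplotlib import pyplot as plt


def plot_sqr_curve(x):
    """Plot x against y = x^2 and return the list of Line2D objects."""
    return plt.plot(x, np.square(x))


lines = plot_sqr_curve(np.array([1, 3, 4]))
print(type(lines).__name__, len(lines))  # list 1 (one Line2D per plotted line)

# get_data() returns the (xdata, ydata) pair stored on the line
x_data, y_data = lines[0].get_data()
print(x_data, y_data)

plt.close("all")  # release the figure
```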

Writing Unit Tests

Use unittest.TestCase to compare the plotted data against the expected values:

import unittest

import matplotlib
matplotlib.use("Agg")  # Non-interactive backend, so no window is opened
import numpy as np
from matplotlib import pyplot as plt

def plot_sqr_curve(x):
    """
    Plot x against y = x^2 and return the list of Line2D objects.
    """
    return plt.plot(x, np.square(x))

class TestSqrCurve(unittest.TestCase):
    def tearDown(self):
        plt.close('all')  # Close figures created by the test

    def test_curve_sqr_plot(self):
        # Create test data
        x = np.array([1, 3, 4])
        y = np.square(x)  # Expected y values

        # Generate plot and unpack the single Line2D object
        pt, = plot_sqr_curve(x)

        # Extract actual data from the plot
        x_data, y_data = pt.get_data()

        # Assert that actual data matches expected (reports mismatches element by element)
        np.testing.assert_array_equal(x_data, x)
        np.testing.assert_array_equal(y_data, y)

if __name__ == '__main__':
    unittest.main()

Output

.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
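Exact equality works for the integer data above, but it is fragile for floating-point data: 0.1 ** 2 evaluates to 0.010000000000000002, not 0.01. A sketch of the same kind of test using NumPy's tolerance-based np.testing.assert_allclose (the class name TestSqrCurveFloat is illustrative):

```python
import unittest

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import numpy as np
from matplotlib import pyplot as plt


def plot_sqr_curve(x):
    """Plot x against y = x^2 and return the list of Line2D objects."""
    return plt.plot(x, np.square(x))


class TestSqrCurveFloat(unittest.TestCase):
    def tearDown(self):
        plt.close("all")

    def test_float_input(self):
        x = np.array([0.1, 0.2, 0.3])
        expected = np.array([0.01, 0.04, 0.09])  # written as literals

        line, = plot_sqr_curve(x)
        x_data, y_data = line.get_data()

        # An exact == comparison would fail on rounding error;
        # assert_allclose compares with a relative tolerance (default rtol=1e-7)
        np.testing.assert_allclose(x_data, x)
        np.testing.assert_allclose(y_data, expected)
```

Run it with python -m unittest like any other test module.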

Testing Plot Properties

You can also test other plot properties like labels, colors, and line styles:

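import unittest

import matplotlib
matplotlib.use("Agg")
import numpy as np
from matplotlib import pyplot as plt

class TestPlotProperties(unittest.TestCase):
    def tearDown(self):
        plt.close('all')

    def test_plot_styling(self):
        x = np.array([1, 2, 3])
        line, = plt.plot(x, x**2, color='red', label='Square')

        # Test line color (returned as it was set)
        self.assertEqual(line.get_color(), 'red')

        # Test line label
        self.assertEqual(line.get_label(), 'Square')

        # Test line style (the default rcParams give a solid line, '-')
        self.assertEqual(line.get_linestyle(), '-')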
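Titles, axis labels, and legend entries belong to the Axes rather than to a line, so they are read from the Axes object. A sketch, assuming a hypothetical helper plot_labelled_sqr_curve that builds its own Axes and returns it:

```python
import unittest

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import numpy as np
from matplotlib import pyplot as plt


def plot_labelled_sqr_curve(x):
    """Hypothetical helper: draw y = x^2 with a title, axis labels
    and a legend on a new Axes, and return that Axes."""
    fig, ax = plt.subplots()
    ax.plot(x, np.square(x), label="Square")
    ax.set_title("y = x^2")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()
    return ax


class TestAxesProperties(unittest.TestCase):
    def tearDown(self):
        plt.close("all")

    def test_axes_text(self):
        ax = plot_labelled_sqr_curve(np.array([1, 2, 3]))

        self.assertEqual(ax.get_title(), "y = x^2")
        self.assertEqual(ax.get_xlabel(), "x")
        self.assertEqual(ax.get_ylabel(), "y")

        # Legend entries are Text objects; compare their strings
        legend_texts = [t.get_text() for t in ax.get_legend().get_texts()]
        self.assertEqual(legend_texts, ["Square"])
```

Returning the Axes from the helper gives the test one object from which every label can be reached.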

Key Testing Methods

| Method | Purpose | Returns |
|---|---|---|
| get_data() | Extract x, y coordinates | Tuple of arrays (xdata, ydata) |
| get_color() | Get line color | Color string or tuple |
| get_label() | Get line label | Label string |
| get_linestyle() | Get line style | Style string, e.g. '-' |
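Because get_color() returns the color in whatever form it was set, 'red', '#ff0000', and (1.0, 0.0, 0.0) do not compare equal as raw values. A small sketch that normalizes them with matplotlib.colors.to_rgba before comparing:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
from matplotlib import pyplot as plt
from matplotlib.colors import to_rgba

# The same color specified three different ways
line_a, = plt.plot([1, 2, 3], color="red")
line_b, = plt.plot([1, 2, 3], color="#ff0000")
line_c, = plt.plot([1, 2, 3], color=(1.0, 0.0, 0.0))

# The raw values differ...
print(line_a.get_color(), line_b.get_color())

# ...but converting each to an RGBA tuple makes them comparable
same = to_rgba(line_a.get_color()) == to_rgba(line_b.get_color()) == to_rgba(line_c.get_color())
print(same)  # True

plt.close("all")
```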

Best Practices

Call plt.close('all') in tearDown so figures do not accumulate across tests (pyplot keeps every figure open until it is closed):

# Inside your unittest.TestCase subclass
def tearDown(self):
    plt.close('all')  # Close all figures after each test
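A fuller version of the same pattern creates a fresh Figure and Axes in setUp and closes everything in tearDown. This sketch assumes a hypothetical variant, plot_sqr_curve_on, that draws on a caller-supplied Axes instead of the implicit current one, which makes it easy to check exactly what was added:

```python
import unittest

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import numpy as np
from matplotlib import pyplot as plt


def plot_sqr_curve_on(ax, x):
    """Hypothetical variant: draw y = x^2 on the given Axes."""
    return ax.plot(x, np.square(x))


class TestWithCleanup(unittest.TestCase):
    def setUp(self):
        # A fresh Figure/Axes for every test
        self.fig, self.ax = plt.subplots()

    def tearDown(self):
        # Close every open figure so none leak into the next test
        plt.close("all")

    def test_single_line_on_axes(self):
        x = np.array([1, 3, 4])
        plot_sqr_curve_on(self.ax, x)

        # Exactly one line was added to this Axes
        self.assertEqual(len(self.ax.get_lines()), 1)
        np.testing.assert_array_equal(self.ax.get_lines()[0].get_ydata(), [1, 9, 16])
```

After the run, plt.get_fignums() returns an empty list, confirming that no figures were left open.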

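Test data ranges and edge cases. Note that plotting an empty array does not raise an error; it produces a line with no points:

def test_empty_array(self):
    x = np.array([])
    line, = plot_sqr_curve(x)
    # The line exists but holds no data points
    self.assertEqual(len(line.get_xdata()), 0)
    self.assertEqual(len(line.get_ydata()), 0)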
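Several edge cases can share one test method through unittest's subTest, which reports each failing case separately. A sketch covering negative, zero, and mixed inputs (the case names are illustrative):

```python
import unittest

import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import numpy as np
from matplotlib import pyplot as plt


def plot_sqr_curve(x):
    """Plot x against y = x^2 and return the list of Line2D objects."""
    return plt.plot(x, np.square(x))


class TestSqrCurveEdgeCases(unittest.TestCase):
    def tearDown(self):
        plt.close("all")

    def test_various_inputs(self):
        cases = {
            "negative": np.array([-3, -1]),
            "zero": np.array([0]),
            "mixed": np.array([-2, 0, 2]),
        }
        for name, x in cases.items():
            with self.subTest(case=name):
                # Each call adds a new line and returns only that line
                line, = plot_sqr_curve(x)
                _, y_data = line.get_data()
                np.testing.assert_array_equal(y_data, x * x)
                # Squares are never negative
                self.assertTrue((y_data >= 0).all())
```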

Conclusion

Test Matplotlib code by having plotting functions return the artists they create, extracting their data with get_data() and comparing it against expected values. Call plt.close('all') in tearDown so open figures do not accumulate during testing.

Updated on: 2026-03-25T20:14:01+05:30
