Custom legends in Matplotlib


This post explains how to customize the legend on a chart with matplotlib.

It provides many examples covering the most common use cases, like controlling the legend location, adding a legend title, or customizing the legend markers and labels.

The data

Let's get started by importing the libraries and loading the data.

import palmerpenguins

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

The Penguins data set used here was collected and made available by Dr. Kristen Gorman at the Palmer Station, Antarctica LTER. This dataset was popularized by Allison Horst in her R package palmerpenguins, with the goal of offering an alternative to the iris dataset for data exploration and visualization.

data = palmerpenguins.load_penguins()
data.head()
   species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex  year
0   Adelie  Torgersen            39.1           18.7              181.0       3750.0    male  2007
1   Adelie  Torgersen            39.5           17.4              186.0       3800.0  female  2007
2   Adelie  Torgersen            40.3           18.0              195.0       3250.0  female  2007
3   Adelie  Torgersen             NaN            NaN                NaN          NaN     NaN  2007
4   Adelie  Torgersen            36.7           19.3              193.0       3450.0  female  2007

Today's goal is to show you many examples of how to customize different aspects of a legend. Hopefully, after reading this post, you will be able to pick the pieces you need and build your own custom legend.

The base plot is going to be a scatterplot of flipper length vs bill length, colored by species.

FLIPPER_LENGTH = data["flipper_length_mm"].values
BILL_LENGTH = data["bill_length_mm"].values

SPECIES = data["species"].values
SPECIES_ = np.unique(SPECIES)

COLORS = ["#1B9E77", "#D95F02", "#7570B3"]

Default legend

Let's start by creating the chart and calling ax.legend() to see Matplotlib's default behavior when it comes to adding legends.

fig, ax = plt.subplots(figsize=(8,8))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    # No legend will be generated if we don't pass label=species
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )
    
ax.legend()
plt.show()

By default, Matplotlib automatically generates a legend that correctly reflects the colors and labels we passed. It also chooses the location automatically (the default is loc="best", which tries to avoid covering the data), and that usually works well. That's not the case here, though, since the legend overlaps one of the dots.
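Only artists that carry a label take part in the automatic legend. Matplotlib also skips labels that start with an underscore, which is a convenient way to keep a helper layer (a reference line, say) out of the legend. The sketch below uses made-up random points in place of the penguins measurements and inspects what ax.get_legend_handles_labels() collects; the "_reference" label is only an illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Made-up stand-in data, NOT the real penguins measurements
rng = np.random.default_rng(0)
fig, ax = plt.subplots(figsize=(6, 6))
for name, color in zip(["Adelie", "Chinstrap", "Gentoo"],
                       ["#1B9E77", "#D95F02", "#7570B3"]):
    ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20),
               label=name, color=color, alpha=0.7)

# No label: Matplotlib leaves this line out of the legend
ax.axhline(45, color="grey")
# Label starting with "_": also left out of the legend
ax.axvline(200, color="grey", label="_reference")

# The handles and labels that ax.legend() would use
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['Adelie', 'Chinstrap', 'Gentoo']
legend = ax.legend()
```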

Markers are automatically accurate

In addition, Matplotlib reflects the different markers used in the chart. So you only need to pick nice markers, and the legend updates for free:

# The markers we use in the scatterplot
MARKERS = ["o", "^", "s"] # circle, triangle, square

fig, ax = plt.subplots(figsize=(8,8))

for species, color, marker in zip(SPECIES_, COLORS, MARKERS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, marker=marker, alpha=0.7
    )
    
ax.legend()
plt.show()
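The legend does not have to mirror the plotted artists exactly. If you want legend markers that differ from the ones in the plot (larger, opaque, with an edge), one common approach is to build "proxy" handles from empty matplotlib.lines.Line2D objects and pass them through the handles argument. A sketch on random stand-in data; the marker size of 10 and the black edge are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

SPECIES_ = ["Adelie", "Chinstrap", "Gentoo"]
COLORS = ["#1B9E77", "#D95F02", "#7570B3"]
MARKERS = ["o", "^", "s"]  # circle, triangle, square

rng = np.random.default_rng(1)  # random stand-in data
fig, ax = plt.subplots(figsize=(6, 6))
for color, marker in zip(COLORS, MARKERS):
    # No label here: the legend entries come from the proxies below
    ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20),
               s=50, color=color, marker=marker, alpha=0.7)

# Empty lines (no data, no line style) that exist only to be drawn in the legend
handles = [
    Line2D([], [], linestyle="none", marker=marker, markersize=10,
           markerfacecolor=color, markeredgecolor="black", label=species)
    for species, color, marker in zip(SPECIES_, COLORS, MARKERS)
]
legend = ax.legend(handles=handles, loc="lower right")
```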

Now, let's look at several approaches for placing the legend in different positions.

Adjust the legend position with loc

The first option is the loc argument. It accepts either a string in plain English describing the position of the legend, such as "upper left" or "lower right", or the equivalent numeric code. The full list of codes is in the Matplotlib documentation of the loc argument.

fig, ax = plt.subplots(figsize=(8,8))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )

# Lower right corner is a better place for this legend
ax.legend(loc="lower right")
plt.show()

# Same result with the numeric code:
# ax.legend(loc=4)
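To compare the named locations side by side, one option is a 3x3 grid of small axes with one legend per location string. The sketch uses nine of the location strings accepted by Axes.legend (it leaves out "best" and "right") and random stand-in data; the legend title simply echoes the string that was used.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

LOCATIONS = [
    "upper left", "upper center", "upper right",
    "center left", "center", "center right",
    "lower left", "lower center", "lower right",
]

rng = np.random.default_rng(2)
x, y = rng.normal(size=(2, 30))  # random stand-in data

fig, axes = plt.subplots(3, 3, figsize=(9, 9), sharex=True, sharey=True)
for ax, loc in zip(axes.flat, LOCATIONS):
    ax.scatter(x, y, s=10, label="points")
    # The legend title shows which location string this panel uses
    ax.legend(loc=loc, title=loc)
fig.tight_layout()
```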

Legend outside the plot area with subplots_adjust

It's also possible to position the legend outside the plotting region (i.e., outside the Axes). To do so, we first make room in the figure with the subplots_adjust() function, and then combine the same loc argument described above with bbox_to_anchor and bbox_transform:

fig, ax = plt.subplots(figsize=(8,8))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )

# Let's say we want it on the right side. 
# First, make room on the right side of the figure.
fig.subplots_adjust(right=0.8)


# Add the legend
# Pass `fig.transFigure` as the bounding box transformation 'bbox_transform'
# loc="center left" and bbox_to_anchor=(0.8, 0.5) indicate the left border
# of the legend is placed at 0.8 in the x axis of the figure, and it is
# vertically centered at y = 0.5
ax.legend(
    loc="center left",
    bbox_to_anchor=(0.8, 0.5),
    bbox_transform=fig.transFigure
)
plt.show()
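A variant that skips fig.transFigure is to anchor the legend in axes coordinates, just past the right edge of the axes (x = 1.02 here is an arbitrary small gap). The anchor then follows the axes, so it does not have to be kept in sync with the subplots_adjust value, although you still need to leave room for the legend in the figure. The sketch below uses random stand-in data and checks, after drawing, that the legend really ends up to the right of the axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(3)  # random stand-in data
fig, ax = plt.subplots(figsize=(8, 8))
for species in ["Adelie", "Chinstrap", "Gentoo"]:
    ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20), label=species)

fig.subplots_adjust(right=0.8)  # leave room for the legend
# Without bbox_transform, bbox_to_anchor is read in axes coordinates:
# (1.02, 0.5) is just right of the axes, halfway up
legend = ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))

# Draw once so positions are resolved, then compare boxes in pixels
fig.canvas.draw()
renderer = fig.canvas.get_renderer()
legend_box = legend.get_window_extent(renderer)
axes_box = ax.get_window_extent(renderer)
```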

The same approach (subplots_adjust plus bbox_to_anchor in figure coordinates) places the legend on the left side of the plot:

fig, ax = plt.subplots(figsize=(8,8))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )


# Make room on the left side of the figure.
fig.subplots_adjust(left=0.2)

# Add the legend
# Anchor the legend's right edge at x = 0.16 (not 0.2) to leave room for the y-axis tick labels
ax.legend(
    loc="center right",
    bbox_to_anchor=(0.16, 0.5),
    bbox_transform=fig.transFigure 
)
plt.show()

Can we put the legend on top or at the bottom? Of course we can! In this case, it makes much more sense to lay the entries out side by side, so we pass ncol=3 to tell Matplotlib the legend has three columns.

fig, ax = plt.subplots(figsize=(8, 6))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )

# Make room on top now
fig.subplots_adjust(top=0.8)

ax.legend(
    loc="lower center",  # legend sits just above y = 0.8; "upper center" would put it below that line
    ncol=3,
    bbox_to_anchor=(0.5, 0.8),
    bbox_transform=fig.transFigure 
)
plt.show()

The same idea puts the legend below the plot:

fig, ax = plt.subplots(figsize=(8, 6))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )

# Make room below
fig.subplots_adjust(bottom=0.2)

# Again, leave some extra space for the axis tick labels
ax.legend(
    loc="upper center",
    ncol=3,
    bbox_to_anchor=(0.5, 0.16),
    bbox_transform=fig.transFigure 
)
plt.show()
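When a figure holds several subplots, a figure-level legend via fig.legend() is an alternative to attaching the legend to one axis. Here both panels label the same species, so collecting labels from every axis would list each name twice; the sketch therefore takes the handles from the first panel only. Random stand-in data again, and the layout numbers (top=0.85, anchor at y = 0.87) are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

SPECIES_ = ["Adelie", "Chinstrap", "Gentoo"]
COLORS = ["#1B9E77", "#D95F02", "#7570B3"]

rng = np.random.default_rng(4)  # random stand-in data
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for species, color in zip(SPECIES_, COLORS):
    for ax in axes:
        ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20),
                   label=species, color=color, alpha=0.7)

fig.subplots_adjust(top=0.85)  # make room on top for the shared legend
# Both panels carry the same labels, so take the handles from one panel only
handles, labels = axes[0].get_legend_handles_labels()
# For a figure legend, bbox_to_anchor is in figure coordinates by default
legend = fig.legend(handles, labels, loc="lower center", ncol=3,
                    bbox_to_anchor=(0.5, 0.87))
```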

Align legend with axis

Is it possible to align the legend with the left or right edge of the axis? Sure!

One difference from the plots above is that here we don't use bbox_transform=fig.transFigure. To align the border of the legend with the border of the axis, it's easier to keep the default transform, which uses axes coordinates (0 is the left or bottom edge of the axis, 1 the right or top edge). Note the value 1.05 in bbox_to_anchor: combined with loc="center left" or loc="center right", it places the vertical center of the legend 5% of the axis height above its top border.

fig, axes = plt.subplots(2, 1, figsize=(8, 12), tight_layout=True)
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    axes[0].scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )
    
    axes[1].scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )

# borderaxespad = 0 means there's no padding between the border
# of the legend and the axis
axes[0].legend(
    loc="center left", 
    ncol=3,
    bbox_to_anchor=[0, 1.05],
    borderaxespad=0,
)

axes[1].legend(
    loc="center right", 
    ncol=3,
    bbox_to_anchor=[1, 1.05],
    borderaxespad=0,
)
plt.show()
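Because the anchor is given in axes coordinates and borderaxespad=0 removes the padding, the edge of the legend should line up with the edge of the axis. One way to confirm this is to draw the figure and compare the two bounding boxes in pixels; the sketch below does that for the left-aligned case on random stand-in data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(5)  # random stand-in data
fig, ax = plt.subplots(figsize=(8, 6))
for species in ["Adelie", "Chinstrap", "Gentoo"]:
    ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20), label=species)

# Same settings as the left-aligned legend above
legend = ax.legend(loc="center left", ncol=3,
                   bbox_to_anchor=(0, 1.05), borderaxespad=0)

# Draw once, then compare the left edges of legend and axes in pixels
fig.canvas.draw()
renderer = fig.canvas.get_renderer()
legend_box = legend.get_window_extent(renderer)
axes_box = ax.get_window_extent(renderer)
left_gap = legend_box.x0 - axes_box.x0  # expected to be (close to) zero
```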

Customize labels

One can store the object returned by ax.legend() and do a lot of interesting things with it.

The returned legend has a .get_texts() method that returns a list of matplotlib.text.Text objects. These objects have many methods one can use to customize the appearance of the text.

fig, ax = plt.subplots(figsize=(8, 6))
for species, color in zip(SPECIES_, COLORS):
    idxs = np.where(SPECIES == species)
    ax.scatter(
        FLIPPER_LENGTH[idxs], BILL_LENGTH[idxs], label=species,
        s=50, color=color, alpha=0.7
    )
    
legend = ax.legend(loc="lower right")

# Iterate over texts.
# Method names are quite self-describing
for text in legend.get_texts():
    text.set_color("#b13f64")
    text.set_fontstyle("italic")
    text.set_fontweight("bold")
    text.set_fontsize(14)

plt.show()
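If every label gets the same font settings, an alternative to the loop is to pass them when the legend is created: the prop argument of ax.legend() accepts a dict of font properties (style, weight, size, ...). prop has no color entry, so the sketch still sets the color per text (newer Matplotlib releases also offer a labelcolor argument; check your version before relying on it). Random stand-in data again.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(6)  # random stand-in data
fig, ax = plt.subplots(figsize=(8, 6))
for species in ["Adelie", "Chinstrap", "Gentoo"]:
    ax.scatter(rng.normal(200, 10, 20), rng.normal(45, 3, 20), label=species)

# Font style, weight and size applied to all legend labels at once
legend = ax.legend(
    loc="lower right",
    prop={"style": "italic", "weight": "bold", "size": 14},
)
# prop does not cover color, so color each text individually
for text in legend.get_texts():
    text.set_color("#b13f64")
```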