What is a scatter plot?
A scatter plot is a simple yet powerful way to visually represent data relationships by displaying points on a graph.
Imagine a coordinate system where the horizontal (x) axis represents one variable and the vertical (y) axis represents another. Each data point corresponds to a specific observation and is placed at its coordinates on the graph. By looking at the distribution of points, we can identify patterns or trends, such as correlations, clusters, or outliers.
Essentially, a scatter plot allows us to see how two variables might be related or influenced by each other, making it a valuable tool for understanding data and drawing insights in a visual and intuitive manner.
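A scatter plot shows a correlation visually. If you also want a number, Pearson's correlation coefficient r summarizes how closely the points follow a straight line, from -1 (perfectly falling) to 1 (perfectly rising). Here is a minimal sketch with numpy's corrcoef() on toy points that lie exactly on a line; the values are illustrative only:

```python
import numpy as np

# Toy data: every point lies exactly on a straight line
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_up = 3 * x + 2       # rising line
y_down = -2 * x + 10   # falling line

# np.corrcoef returns a 2x2 matrix; the off-diagonal entry is Pearson's r
r_up = np.corrcoef(x, y_up)[0, 1]      # close to 1.0
r_down = np.corrcoef(x, y_down)[0, 1]  # close to -1.0
print(r_up, r_down)
```

On real, noisy data r falls somewhere in between, which is why looking at the scatter plot itself remains useful: a single number can hide clusters or outliers.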
Libraries
Plotly is a library designed for interactive visualization with Python.
If you want to use plotly, you can use either the plotly.express module (px) or the plotly.graph_objects module (go). The main difference between these modules is the "level": the px module is easier to use (high level), but if you want to do something very specific and need flexibility, I recommend the go module (lower level).
Don't forget to install plotly with the pip install plotly command.
In our case, we'll use the go module (graph_objects). We will also use numpy and pandas to generate data and put it into a dataframe.
# Libraries
import plotly.graph_objects as go
import pandas as pd
import numpy as np
Dataset
Since scatter plots are intended to represent 2 continuous variables, let's generate a sample of 200 randomly distributed observations using numpy's random.uniform() (for the x values) and random.normal() (for the noise added to the y values).
# Generate a sample of 200 observations
sample_size = 200
x = np.random.uniform(20, 30, sample_size) # uniform(low, high, size): values between 20 and 30
y = x * 10 + np.random.normal(0, 10, sample_size) # linear relationship plus Gaussian noise
# Put the data into a pandas df
df = pd.DataFrame({'x': x,
'y': y})
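Because the sample is random, the chart changes on every run. If you want reproducible results, one option (our suggestion, not something plotly requires) is numpy's seeded Generator API. This sketch is a drop-in alternative to the cell above; the seed 42 is an arbitrary choice. Note the argument order of uniform(low, high, size):

```python
import numpy as np
import pandas as pd

# Seeded generator: the same seed always produces the same sample
rng = np.random.default_rng(42)
sample_size = 200

# low comes first, so x falls in the interval [20, 30)
x = rng.uniform(20, 30, sample_size)
# Linear signal (slope 10) plus Gaussian noise with standard deviation 10
y = x * 10 + rng.normal(0, 10, sample_size)

df = pd.DataFrame({'x': x, 'y': y})
```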
Basic scatter plot
The following code displays a simple scatter plot, with a title and axis names, using go.Scatter().
The fig.add_trace(go.Scatter([...])) line tells the program to add a new trace (data series) to the figure (fig) that we created just before. In this case, the trace is a scatter plot.
# Create the figure (for the moment: a blank graph)
fig = go.Figure()
# Add the scatter trace
fig.add_trace(go.Scatter(
    x=df['x'], # Variable in the x-axis
    y=df['y'], # Variable in the y-axis
    mode='markers', # This explicitly states that we want our observations to be represented by points
    # Properties associated with points
    marker=dict(
        size=12, # Size
        color='#cb1dd1', # Color
        opacity=0.8, # Point transparency
        line=dict(width=1, color='black') # Properties of the edges
    ),
))
# Customize the layout
fig.update_layout(
    title='Interactive Scatter Plot', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800, # Set the width of the figure to 800 pixels
    height=600, # Set the height of the figure to 600 pixels
)
Now, let's save the graph in an HTML file and display it on this website using an iframe:
# save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-basic.html")
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-basic.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
That's it, a first interactive scatter plot! 🔥 Try hovering over the markers or zooming in on a specific area to see the full potential of this chart.
Add a grouping variable
If you want to display a grouping variable on a scatter plot, you can color the points according to the labels of the categorical variable.
To do this with plotly, we'll define the color we want for each label and then iterate over the labels, adding one trace of points per label.
First, we need to create a categorical variable and add it to our initial dataset. To do this, we use a list comprehension that assigns the value 'Group1' to observations whose first_variable is less than 25, and 'Group2' otherwise:
# Generate a sample of 200 observations
sample_size = 200
first_variable = np.random.uniform(20, 30, sample_size) # values between 20 and 30
second_variable = first_variable * 10 + np.random.normal(0, 10, sample_size)
categorical_variable = ['Group1' if i < 25 else 'Group2' for i in first_variable]
# Put the data into a pandas df
df = pd.DataFrame({'first_variable': first_variable,
                   'second_variable': second_variable,
                   'categorical_variable': categorical_variable})
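The list comprehension above compares the values one by one in Python. As a minimal sketch of a vectorized alternative, np.where() produces the same labels in a single call; the seeded toy sample and the names values, labels_loop and labels_vec are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
values = rng.uniform(20, 30, 200)

# List comprehension, as in the cell above: one comparison per value
labels_loop = ['Group1' if i < 25 else 'Group2' for i in values]

# Vectorized equivalent: 'Group1' where the condition is True, 'Group2' elsewhere
labels_vec = np.where(values < 25, 'Group1', 'Group2')
```

For 200 observations the difference is negligible; np.where() mainly pays off on large arrays.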
Now, let's create a scatter plot using the categorical_variable column to color the markers:
# Create a dictionary to map categories to colors
category_colors = {
    'Group1': 'orange',
    'Group2': 'purple'
}
# Create the scatter plot (for the moment: a blank graph)
fig = go.Figure()
# Add one scatter trace per category, colored according to categorical_variable
for category, color in category_colors.items():
    category_data = df[df['categorical_variable'] == category]
    fig.add_trace(go.Scatter(
        x=category_data['first_variable'], # Variable in the x-axis
        y=category_data['second_variable'], # Variable in the y-axis
        mode='markers', # This explicitly states that we want our observations to be represented by points
        name=category, # Label shown in the legend
        # Properties associated with points
        marker=dict(
            size=12,
            color=color,
            opacity=0.7,
            line=dict(width=2, color='black') # Properties of the edges
        ),
    ))
# Customize the layout and change the figure size
fig.update_layout(
    title='Interactive Scatter Plot with a Categorical Variable', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800, # Set the width of the figure to 800 pixels
    height=600 # Set the height of the figure to 600 pixels
)
Save in HTML and display:
# save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-grouping.html")
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-grouping.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
Add a trendline (linear or polynomial)
Adding a trendline to a scatter plot summarizes the overall trend in the data, making the general relationship between the two variables easier to see than from the points alone.
For this purpose, we've created a function fit_trendline() that fits a polynomial relationship of a given degree between 2 quantitative variables, using numpy's polyfit() and polyval(). A degree of 1 is equivalent to fitting a linear relationship.
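To see what fit_trendline() computes, here is a minimal sketch on toy data that lies exactly on a line and on a parabola. np.polyfit() returns the coefficients from the highest power down, and np.polyval() evaluates them back at the x positions, which is what fit_trendline() returns:

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

# Exactly linear data y = 10x + 5, fitted with degree 1
coeffs_linear = np.polyfit(x, 10 * x + 5, 1)  # slope first, then intercept
print(coeffs_linear)

# Exactly quadratic data y = 2x^2 - 3x + 1, fitted with degree 2
coeffs_quad = np.polyfit(x, 2 * x**2 - 3 * x + 1, 2)  # [x^2, x, constant] coefficients
# Evaluate the fitted polynomial at the original x values
fitted = np.polyval(coeffs_quad, x)
```

On noisy data the coefficients become a least-squares estimate instead of an exact recovery.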
Once the data has been fitted, simply add the line to the graph and define a few styling parameters.
# Function to fit linear or polynomial trendline
def fit_trendline(x, y, degree):
    coeffs = np.polyfit(x, y, degree)
    return np.polyval(coeffs, x)
# Create a dictionary to map categories to colors
category_colors = {
    'Group1': 'orange',
    'Group2': 'purple'
}
# Create the scatter plot (for the moment: a blank graph)
fig = go.Figure()
# Add one scatter trace per category, colored according to categorical_variable
for category, color in category_colors.items():
    category_data = df[df['categorical_variable'] == category]
    fig.add_trace(go.Scatter(
        x=category_data['first_variable'], # Variable in the x-axis
        y=category_data['second_variable'], # Variable in the y-axis
        mode='markers', # This explicitly states that we want our observations to be represented by points
        name=category, # Label shown in the legend
        # Properties associated with points
        marker=dict(
            size=12,
            color=color,
            opacity=0.7,
            line=dict(width=2, color='black') # Properties of the edges
        ),
    ))
# Sort by x so the trendline is drawn from left to right (essential for a polynomial curve)
df_sorted = df.sort_values('first_variable')
# Fit the data with our function
trendline_y = fit_trendline(df_sorted['first_variable'], df_sorted['second_variable'], degree=1)
fig.add_trace(go.Scatter(
    x=df_sorted['first_variable'],
    y=trendline_y,
    mode='lines',
    line=dict(color='black', dash='solid', width=3), # Solid black line; use dash='dash' for a dashed trendline
    showlegend=False # Remove trendline from the legend
))
# Customize the layout and change the figure size
fig.update_layout(
    title='Interactive Scatter Plot with a Trendline', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800, # Set the width of the figure to 800 pixels
    height=600 # Set the height of the figure to 600 pixels
)
# save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-trendline.html")
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-trendline.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
Going further
This article explains how to create an interactive scatter plot with plotly, with customization features such as coloring the points by a categorical variable and adding a trendline.
For more examples of how to create or customize your scatter plots with Python, see the scatter plot section. You may also be interested in creating a scatter plot with marginal distribution.