# Interactive scatterplot with Plotly

This post describes how to create an interactive scatterplot in Python using the plotly library.
Interactivity is very handy on scatterplots: it lets you zoom in on a specific area and shows a tooltip describing each data point.
This tutorial starts with a simple example and then shows how to add features such as a grouping variable or a trend line.

## What is a scatter plot?

A scatter plot is a simple yet powerful way to visually represent data relationships by displaying points on a graph.

Imagine a coordinate system where the horizontal (x) axis represents one variable and the vertical (y) axis represents another. Each data point corresponds to a specific observation and is represented by its coordinates on the graph. By looking at the distribution of points, we can easily identify patterns or trends, such as correlations, clusters, or outliers.

Essentially, a scatter plot allows us to see how two variables might be related or influenced by each other, making it a valuable tool for understanding data and drawing insights in a visual and intuitive manner.
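
To make the idea of a "correlation" concrete, here is a minimal sketch that measures the linear relationship a scatter plot would show. The data is simulated and the seed is arbitrary; it only illustrates the concept.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed, arbitrary seed so the numbers are reproducible
x = rng.uniform(20, 30, 200)
y = x * 10 + rng.normal(0, 10, 200)  # y depends linearly on x, plus some noise

# Pearson correlation: close to +1 or -1 for a strong linear relationship, near 0 for none
r = np.corrcoef(x, y)[0, 1]
print(f"Pearson r = {r:.2f}")
```

A value close to 1 means the points would line up along an increasing diagonal on the chart.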

## Libraries

Plotly is a library designed for interactive visualization with Python.

If you want to use plotly, you can either use the `plotly.express` module (`px`) or the `plotly.graph_objects` module (`go`). The main difference between these modules is the "level": the `px` module is easier to use (high level), but if you want to do something very specific and need flexibility, I recommend you use the `go` module (lower level).

Don't forget to install plotly with the `pip install plotly` command.

In our case, we'll use the `go` module (or `graph_objects`). We will also use `numpy` and `pandas` to generate data and put it into a dataframe.

``````
# Libraries
import plotly.graph_objects as go
import pandas as pd
import numpy as np
``````

## Dataset

Since scatterplots are intended to represent 2 continuous variables, let's generate a sample of 200 random observations using numpy and its functions `random.uniform()` and `random.normal()`. Here, `x` is drawn uniformly between 20 and 30, and `y` is a linear function of `x` plus Gaussian noise.

``````
# Generate a sample of 200 observations
sample_size = 200
x = np.random.uniform(20, 30, sample_size)  # low bound first, then high bound
y = x * 10 + np.random.normal(0, 10, sample_size)

# Put the data into a pandas DataFrame
df = pd.DataFrame({'x': x,
                   'y': y})
``````
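
Because the data is random, the chart changes on every run. If you need the same points each time, one option is a seeded generator. This is a sketch: the seed value and the `df_seeded` name are arbitrary choices, not part of the original example.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=42)  # the seed value is arbitrary
sample_size = 200
x = rng.uniform(20, 30, sample_size)
y = x * 10 + rng.normal(0, 10, sample_size)

# Same structure as the tutorial's dataframe, but reproducible across runs
df_seeded = pd.DataFrame({"x": x, "y": y})
print(df_seeded.describe())
```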

## Basic scatter plot

The following code displays a simple scatter plot, with a title and axis names, thanks to `go.Scatter()`.

The `fig.add_trace(go.Scatter(...))` call tells the program to add a new trace (data series) to a figure (`fig`) that we initiated just before. In this case, the trace is a scatter plot.

``````
# Create the figure (for the moment: a blank graph)
fig = go.Figure()

# Add the scatter trace
fig.add_trace(go.Scatter(
    x=df['x'], # Variable in the x-axis
    y=df['y'], # Variable in the y-axis
    mode='markers', # This explicitly states that we want our observations to be represented by points

    # Properties associated with points
    marker=dict(
        size=12, # Size
        color='#cb1dd1', # Color
        opacity=0.8, # Point transparency
        line=dict(width=1, color='black') # Properties of the edges
    ),
))

# Customize the layout
fig.update_layout(
    title='Interactive Scatter Plot', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800,  # Set the width of the figure to 800 pixels
    height=600,  # Set the height of the figure to 600 pixels
)
``````

Now, let's save the graph as an HTML file and display it on this website using an `iframe`.

``````
# Save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-basic.html")
``````
``````
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-basic.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
``````

That's it, a first interactive scatterplot! 🔥 Try hovering over the markers or zooming in on a specific area to see the full potential of this chart.

## Add a grouping variable

If you want to display a grouping variable on a scatter plot, you can change the color of the points according to the labels of the categorical variable.

To do this with plotly, we'll need to define the colors we want for each label and then iterate over each label to add the associated points.

We need to create and add a categorical variable to our initial data set. To do this, we use a list comprehension that assigns the value `'Group1'` to observations where `first_variable` is less than 25, and `'Group2'` otherwise:

``````
# Generate a sample of 200 observations
sample_size = 200
first_variable = np.random.uniform(20, 30, sample_size)  # low bound first, then high bound
second_variable = first_variable * 10 + np.random.normal(0, 10, sample_size)
categorical_variable = ['Group1' if i < 25 else 'Group2' for i in first_variable]

# Put the data into a pandas DataFrame
df = pd.DataFrame({'first_variable': first_variable,
                   'second_variable': second_variable,
                   'categorical_variable': categorical_variable})
``````
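
The same labels can also be computed without a Python loop. The sketch below compares the list comprehension with numpy's `np.where()` on a few hand-picked values, including the boundary value 25.

```python
import numpy as np

# Hand-picked values, including the threshold itself
first_variable = np.array([21.0, 24.9, 25.0, 29.5])

# List comprehension, as in the tutorial
labels_loop = ['Group1' if v < 25 else 'Group2' for v in first_variable]

# Vectorized equivalent
labels_where = np.where(first_variable < 25, 'Group1', 'Group2')

print(labels_loop)
print(labels_where.tolist())
```

Both approaches put the value 25.0 in `'Group2'`, since the comparison is strict.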

Now, let's create a scatterplot using the `categorical_variable` column to color the markers:

``````
# Create a dictionary to map categories to colors
category_colors = {
    'Group1': 'orange',
    'Group2': 'purple'
}

# Create the scatter plot (for the moment: a blank graph)
fig = go.Figure()

# Add one scatter trace per category, colored according to the dictionary
for category, color in category_colors.items():
    category_data = df[df['categorical_variable'] == category]
    fig.add_trace(go.Scatter(
        x=category_data['first_variable'], # Variable in the x-axis
        y=category_data['second_variable'], # Variable in the y-axis
        mode='markers', # This explicitly states that we want our observations to be represented by points
        name=category, # Label shown in the legend

        # Properties associated with points
        marker=dict(
            size=12,
            color=color,
            opacity=0.7,
            line=dict(width=2, color='black') # Properties of the edges
        ),
    ))

# Customize the layout and change the figure size
fig.update_layout(
    title='Interactive Scatter Plot with a Categorical Variable', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800,  # Set the width of the figure to 800 pixels
    height=600  # Set the height of the figure to 600 pixels
)
``````

Save in HTML and display:

``````
# Save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-grouping.html")
``````
``````
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-grouping.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
``````

## Add a trendline (linear or polynomial)

Adding a trendline to a scatter plot serves the purpose of visually representing the overall trend or pattern in the data. It helps in understanding the general relationship between two variables and provides insights into how they might be related.

For this purpose, we've created a function `fit_trendline()` that fits a polynomial of a given degree to 2 quantitative variables. A degree of 1 is equivalent to fitting a linear relationship.

Once the data has been fitted, simply add the line to the graph and define a few styling parameters.

``````
# Function to fit linear or polynomial trendline
def fit_trendline(x, y, degree):
    coeffs = np.polyfit(x, y, degree)
    return np.polyval(coeffs, x)

# Create a dictionary to map categories to colors
category_colors = {
    'Group1': 'orange',
    'Group2': 'purple'
}

# Create the scatter plot (for the moment: a blank graph)
fig = go.Figure()

# Add one scatter trace per category, colored according to the dictionary
for category, color in category_colors.items():
    category_data = df[df['categorical_variable'] == category]
    fig.add_trace(go.Scatter(
        x=category_data['first_variable'], # Variable in the x-axis
        y=category_data['second_variable'], # Variable in the y-axis
        mode='markers', # This explicitly states that we want our observations to be represented by points
        name=category, # Label shown in the legend

        # Properties associated with points
        marker=dict(
            size=12,
            color=color,
            opacity=0.7,
            line=dict(width=2, color='black') # Properties of the edges
        ),
    ))

# Fit the data with our function (sorted by x so the line is drawn from left to right)
sorted_df = df.sort_values('first_variable')
trendline_y = fit_trendline(sorted_df['first_variable'], sorted_df['second_variable'], degree=1)

# Add the trendline as a line trace
fig.add_trace(go.Scatter(
    x=sorted_df['first_variable'],
    y=trendline_y,
    mode='lines',
    line=dict(color='black', dash='solid', width=3), # Solid black line for the trendline
    showlegend=False  # Remove trendline from the legend
))

# Customize the layout and change the figure size
fig.update_layout(
    title='Interactive Scatter Plot with a Categorical Variable', # Title
    xaxis_title='First Variable', # x-axis name
    yaxis_title='Second Variable', # y-axis name
    width=800,  # Set the width of the figure to 800 pixels
    height=600  # Set the height of the figure to 600 pixels
)
``````
``````
# Save this figure as a standalone HTML file:
fig.write_html("../../static/interactiveCharts/scatterplot-plotly-trendline.html")
``````
``````
%%html
<iframe
src="../../interactiveCharts/scatterplot-plotly-trendline.html"
width="800"
height="600"
title="scatterplot with plotly"
style="border:none">
</iframe>
``````
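
A line trace connects points in the order they are given, so a curved trendline (degree 2 or higher) is easiest to draw from sorted x values. The sketch below uses a hypothetical helper, `fit_trendline_on_grid()`, which is not part of the original code: it evaluates the fitted polynomial on an evenly spaced grid instead of on the raw observations.

```python
import numpy as np

def fit_trendline_on_grid(x, y, degree, n_points=100):
    """Fit a polynomial, then evaluate it on an evenly spaced, sorted grid."""
    coeffs = np.polyfit(x, y, degree)
    x_grid = np.linspace(np.min(x), np.max(x), n_points)
    return x_grid, np.polyval(coeffs, x_grid)

# Noise-free quadratic, given in shuffled order on purpose
x = np.array([3.0, -1.0, 2.0, 0.0, -2.0, 1.0])
y = 2 * x**2 - x + 1

x_grid, y_grid = fit_trendline_on_grid(x, y, degree=2)

# x_grid is strictly increasing, so a 'lines' trace would draw a smooth curve
print(x_grid[:3], y_grid[:3])
```

The returned `x_grid` and `y_grid` can be passed as `x` and `y` to a `go.Scatter` trace with `mode='lines'`.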

## Going further

This article explains how to create an interactive scatter plot with plotly with various customization features, such as adding a categorical variable or a trendline.

For more examples of how to create or customize your scatter plots with Python, see the scatter plot section. You may also be interested in creating a scatter plot with marginal distribution.

## Contact & Edit

👋 This document is a work by Yan Holtz. You can contribute on GitHub, send me feedback on Twitter or subscribe to the newsletter to know when new examples are published! 🔥