Valuable Matplotlib & Seaborn Visualization Handbook, Part I

1. Introduction

This post summarizes the 50 most valuable Matplotlib and Seaborn visualizations in data science. It can serve as a handbook in which to look up a useful visualization.

The 50 visualizations are grouped into 7 application scenarios, listed below. This post mainly focuses on the first two:

  1. Correlation
  2. Deviation
  3. Ranking
  4. Distribution
  5. Composition
  6. Change
  7. Groups

Before heading into the visualizations, let's first import the necessary libraries and configure the plot settings.

In [1]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

large = 22
med = 16
small = 12
params = {'axes.titlesize': large,
          'legend.fontsize': med,
          'figure.figsize': (16, 10),
          'axes.labelsize': med,
          'xtick.labelsize': med,
          'ytick.labelsize': med,
          'figure.titlesize': large}
plt.rcParams.update(params)
plt.style.use('seaborn-v0_8-whitegrid')  # named 'seaborn-whitegrid' before Matplotlib 3.6
sns.set_style('white')
import warnings
warnings.filterwarnings(action = 'once')

%matplotlib inline
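
Matplotlib 3.6 renamed its bundled seaborn styles with a 'seaborn-v0_8-' prefix, so the right name for the whitegrid style depends on the installed version. The following is a minimal sketch of a version-tolerant lookup, not part of the original setup; pick_style is a hypothetical helper introduced only for illustration.

```python
import matplotlib.pyplot as plt

def pick_style(candidates, fallback="default"):
    """Return the first style name this Matplotlib install knows about.

    pick_style is a hypothetical helper, not a Matplotlib function.
    """
    for name in candidates:
        if name in plt.style.available:
            return name
    return fallback

# Try the newer name first, then the older one
style_name = pick_style(["seaborn-v0_8-whitegrid", "seaborn-whitegrid"])
plt.style.use(style_name)
print(style_name)
```

If neither name is available, the sketch falls back to Matplotlib's 'default' style instead of raising an error.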

2. Correlation

Correlation plots are used to visualize the relationship between two or more variables.

Scatter plot

The scatter plot is a classic way to visualize the correlation between variables. It can be drawn using plt.scatter().

In [2]:
#import dataset
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
#prepare data
#create as many colors as there are unique midwest['category']
categories = np.unique(midwest['category'])
colors = [plt.cm.tab10(i/float(len(categories) - 1)) for i in
          range(len(categories))]
#draw plot for each category
plt.figure(figsize = (16, 10), dpi = 80, facecolor = 'w', edgecolor = 'k')
for i, category in enumerate(categories):
    plt.scatter('area', 'poptotal', data = midwest.loc[midwest.category == category, :],
               s = 20, color = colors[i], label = str(category))

#decorations
plt.gca().set(xlim = (0.0, 0.1), ylim = (0, 90000), xlabel = 'Area',
             ylabel = 'Population')
plt.xticks(fontsize = 12)
plt.yticks(fontsize = 12)
plt.title("Scatterplot of Midwest Area vs Population", fontsize = 22)
plt.legend(fontsize = 12)
plt.show()
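
The color list above divides by len(categories) - 1, which breaks down when the data contains only one category. As a hedged alternative, the sketch below indexes the tab10 colormap by integer position instead. category_colors is a hypothetical helper, and the labels are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

def category_colors(labels, cmap=plt.cm.tab10):
    """Map each unique label to one RGBA color from a qualitative colormap.

    Hypothetical helper: integer input selects the i-th color of the colormap,
    and the modulo wraps around if there are more labels than colors.
    """
    unique = np.unique(labels)
    return {label: cmap(i % cmap.N) for i, label in enumerate(unique)}

# Made-up labels standing in for midwest['category']
palette = category_colors(["b", "a", "c", "a"])
single = category_colors(["only"])  # works with a single category as well
print(len(palette), len(single))
```

Each value is an RGBA tuple that can be passed as color = palette[category] in the plt.scatter() call.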

Bubble plot with Encircling

A bubble plot sizes each point by a third variable, and encircling draws a boundary around a group of points to highlight it. Although the data is the same as for the scatter plot above, it is imported again so that the cell runs on its own.

In [3]:
from matplotlib import patches
from scipy.spatial import ConvexHull

#import data
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
# As many colors as there are unique midwest['category']
categories = np.unique(midwest['category'])
colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]
#Step 2: draw scatterplot with unique color for each category
fig = plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k') 
for i, category in enumerate(categories):
    plt.scatter('area', 'poptotal', data=midwest.loc[midwest.category == category, :],
                s = 'dot_size', color = colors[i], label = str(category), edgecolors = 'black',
                linewidths = .5)
    
#Step 3: Encircling
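def encircle(x, y, ax = None, **kw):
    """Draw the convex hull of the points (x, y) as a polygon patch on ax."""
    if ax is None:
        ax = plt.gca()
    p = np.c_[x, y]                                # stack x and y into an (n, 2) array
    hull = ConvexHull(p)                           # find the outermost points
    poly = plt.Polygon(p[hull.vertices, :], **kw)  # polygon through the hull vertices
    ax.add_patch(poly)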
    
#Select the data to be encircled
midwest_encircle_data = midwest.loc[midwest.state == 'IN', :]
#Draw polygon surrounding vertices
encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal,
        ec = 'k', fc = 'gold', alpha = 0.1)
encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal,
        ec = 'firebrick', fc = 'none', linewidth = 1.5)
#decorations

plt.gca().set(xlim = (0.0, 0.1), ylim = (0, 90000), xlabel = 'Area',
             ylabel = 'Population')
plt.xticks(fontsize = 12)
plt.yticks(fontsize = 12)
plt.title("Bubble Plot with Encircling", fontsize = 22)
plt.legend(fontsize = 12)
plt.show()
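
The encircle() helper relies on the convex hull, the smallest convex polygon that contains every point. The sketch below shows what ConvexHull returns, using made-up toy points rather than the midwest data.

```python
import numpy as np
from scipy.spatial import ConvexHull

# Made-up points: four corners of a unit square plus one interior point
points = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0.5, 0.5]], dtype=float)
hull = ConvexHull(points)

# hull.vertices holds the row indices of the points on the boundary;
# the interior point (index 4) is not among them
boundary = points[hull.vertices]
print(sorted(hull.vertices.tolist()))
print(boundary.shape)
```

Only the boundary points are passed to the polygon, which is why the outline wraps the whole group regardless of how many points lie inside it.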

Scatter plot with line of the best fit

sns.lmplot() fits a regression line to the scatter plot, which shows how two variables change together. To fit a single line for the entire dataset, remove the parameter hue = 'cyl' from the sns.lmplot() call.

In [4]:
import warnings 
warnings.simplefilter('ignore')
#Import data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
df_select = df.loc[df.cyl.isin([4, 8]), :]
sns.set_style("white")
gridobj = sns.lmplot(x = 'displ', y = 'hwy', hue = 'cyl', data = df_select,
                    height = 7, aspect = 1.6, robust = True, palette = 'tab10',
                    scatter_kws = dict(s = 60, linewidths = .7, edgecolors = 'black'))
#Decorations
gridobj.set(xlim = [0.5, 7.5], ylim = [0, 50])
plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize = 20)
plt.show()

Alternatively, each group can get its own best fit line in a separate column of plots. This can be done by setting col = groupingcolumn in sns.lmplot().

In [5]:
# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
df_select = df.loc[df.cyl.isin([4, 8]), :]
# Each line in its own column
sns.set_style("white")
gridobj = sns.lmplot(x = "displ", y = "hwy",
                     data = df_select,
                     height = 7,
                     robust = True,
                     palette = 'Set1',
                     col = "cyl",
                     scatter_kws = dict(s = 60, linewidths = .7, edgecolors = 'black'))
# Decorations
gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
plt.show()
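
To make "line of best fit" concrete, here is a minimal sketch that fits an ordinary least-squares line with np.polyfit. It is a simplified stand-in: the cells above pass robust = True, which asks seaborn for a robust regression instead, and the data below is synthetic rather than taken from mpg_ggplot2.csv.

```python
import numpy as np

# Synthetic data with a known trend: hwy = 35 - 3 * displ + noise
rng = np.random.default_rng(0)
displ = np.linspace(1.5, 7.0, 50)
hwy = 35.0 - 3.0 * displ + rng.normal(scale=0.5, size=displ.size)

# A degree-1 polynomial fit is the ordinary least-squares line
slope, intercept = np.polyfit(displ, hwy, deg=1)
fitted = slope * displ + intercept

print(round(slope, 2), round(intercept, 2))
```

The recovered slope and intercept should land close to the values used to generate the data, and fitted holds the y values of the line at each displ.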

Jittering with stripplot

When many data points share the same or similar values, they are drawn on top of each other and some of them are hidden. Jittering shifts each point by a small random amount so that all overlapping points become visible. This can be done using stripplot() from seaborn.

In [6]:
# Import data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
# Draw Stripplot
fig, ax = plt.subplots(figsize = (16,10), dpi = 80)
sns.stripplot(x = df.cty, y = df.hwy, jitter = 0.25, size = 8,
             ax = ax, linewidth = 0.5)
# Decorations
plt.title("Use jittered plots to avoid overlapping of points", fontsize = 20)
plt.show()
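
Conceptually, jittering adds a small random offset to each point along one axis. The sketch below illustrates the idea with uniform noise on made-up values. It is an assumption-level illustration, not seaborn's internal code.

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up discrete x values with many exact ties
x = np.repeat([10, 11, 12], 20)
jitter = 0.25

# Shift every point by uniform noise in [-jitter, +jitter]
x_jittered = x + rng.uniform(-jitter, jitter, size=x.size)

print(len(np.unique(x)), len(np.unique(x_jittered)))
```

The offsets stay within the jitter width, so each point remains visually attached to its original value while no longer hiding its neighbors.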

Counts plot

This is another approach to visualizing overlapping points. The larger a plotted point is, the more data points share that position.

In [7]:
# Import data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
df_counts = df.groupby(['hwy', 'cty']).size().reset_index(name = 'counts')
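# Draw one marker per (cty, hwy) pair
fig, ax = plt.subplots(figsize = (16, 10), dpi = 80)
# scatter's s is the marker area in points^2, so the diameter counts*2 is squared
ax.scatter(df_counts.cty, df_counts.hwy, s = (df_counts.counts*2)**2)
ax.set(xlabel = 'cty', ylabel = 'hwy')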
# Decorations
plt.title("Counts plot - Size of circle is bigger as more points overlap", fontsize = 20)
plt.show()
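
The counting step is a groupby over both coordinates. Here is a small sketch on made-up values that shows what the resulting table looks like.

```python
import pandas as pd

# Made-up data with repeated (cty, hwy) pairs
toy = pd.DataFrame({"cty": [18, 18, 18, 21, 21, 16],
                    "hwy": [29, 29, 29, 29, 30, 25]})

# One row per distinct (hwy, cty) pair, with the number of occurrences
counts = toy.groupby(["hwy", "cty"]).size().reset_index(name="counts")
print(counts)
```

Each row of counts becomes a single marker, and the counts column drives its size.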

Marginal histogram

A marginal histogram plot adds histograms along the X and Y axes of a scatter plot. It shows not only the relationship between X and Y but also the distribution of each variable on its own. This plot is often used in exploratory data analysis (EDA).

In [8]:
# Import data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")

#Create figure and gridspec
fig = plt.figure(figsize = (16, 10), dpi = 80)
grid = plt.GridSpec(4, 4, hspace = 0.5, wspace = 0.2)

#Define the axes
ax_main = fig.add_subplot(grid[:-1, :-1])
ax_right = fig.add_subplot(grid[:-1, -1], xticklabels = [], yticklabels = [])
ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels = [], yticklabels = [])

#Scatterplot on the main ax
ax_main.scatter('displ', 'hwy', s = df.cty*4, c = df.manufacturer.astype('category').cat.codes,
               alpha = 0.9, data = df, cmap = "tab10", edgecolors = 'gray',
               linewidths = 0.5)
#histogram in the bottom
ax_bottom.hist(df.displ, 40, histtype = 'stepfilled', orientation = 'vertical',
              color = 'b')
ax_bottom.invert_yaxis()

#histogram on the right
ax_right.hist(df.hwy, 40, histtype = 'stepfilled', orientation = 'horizontal', color = 'b')

#Decorations
ax_main.set(title = "Scatterplot with Histograms \n displ vs hwy", xlabel = 'displ',
           ylabel = 'hwy')
ax_main.title.set_fontsize(20)
for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
    item.set_fontsize(14)

xlabels = ax_main.get_xticks().tolist()
ax_main.set_xticklabels(xlabels)
plt.show()
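
The two marginal histograms are simply one-dimensional bin counts of each variable taken separately. The sketch below computes them with np.histogram, using synthetic stand-ins for displ and hwy (an assumption, not the real dataset).

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=3.5, scale=1.2, size=234)   # synthetic stand-in for displ
y = rng.normal(loc=23.0, scale=6.0, size=234)  # synthetic stand-in for hwy

# 40 equal-width bins per variable, the same bin count as in the cell above
x_counts, x_edges = np.histogram(x, bins=40)
y_counts, y_edges = np.histogram(y, bins=40)

print(x_counts.sum(), len(x_edges))
```

Every observation falls into exactly one bin per axis, so each set of counts sums to the number of rows.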

Marginal boxplot

A marginal boxplot shows the median and the 25th and 75th percentiles of the two variables, X and Y.

In [9]:
# Import Data
df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")

# Create Fig and gridspec
fig = plt.figure(figsize=(16, 10), dpi= 80)
grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)

# Define the axes
ax_main = fig.add_subplot(grid[:-1, :-1])
ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])
ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])

# Scatterplot on main ax
ax_main.scatter('displ', 'hwy', s=df.cty*5, c=df.manufacturer.astype('category').cat.codes,
                alpha=.9, data=df, cmap="Set1", edgecolors='black', linewidths=.5)

# Add a boxplot in each part
sns.boxplot(y = df.hwy, ax = ax_right, orient = 'v')
sns.boxplot(x = df.displ, ax = ax_bottom, orient = 'h')

# Decorations ------------------
# Remove x axis name for the boxplot
ax_bottom.set(xlabel='')
ax_right.set(ylabel='')

# Main Title, Xlabel and YLabel
ax_main.set(title='Scatterplot with Boxplots \n displ vs hwy', xlabel='displ', ylabel='hwy')

# Set font size of different components
ax_main.title.set_fontsize(20)
for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
    item.set_fontsize(14)

plt.show()
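
The numbers a boxplot encodes can be computed directly. The sketch below uses made-up values and assumes the common 1.5 x IQR rule for flagging outliers, which is the default whisker setting in Matplotlib.

```python
import numpy as np

# Made-up mileage values with one unusually high observation
values = np.array([12, 15, 17, 20, 22, 24, 26, 29, 35, 60], dtype=float)

# The box spans q1 to q3, and the line inside it marks the median
q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are drawn individually as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(q1, median, q3, outliers)
```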

Correlogram

A correlogram shows the correlation matrix for all possible pairs of numeric variables in a given dataframe. We can use sns.heatmap() to draw it.

In [10]:
# Import Dataset
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")

# Plot
plt.figure(figsize=(12,10), dpi= 80)
corr = df.select_dtypes(include = 'number').corr()   # correlations of the numeric columns only
sns.heatmap(corr, xticklabels = corr.columns, yticklabels = corr.columns,
           cmap = 'RdYlGn', center = 0, annot = True)
# Decorations
plt.title("Correlogram of mtcars", fontsize = 20)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()
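
A correlation matrix is only defined for numeric columns, and recent pandas versions raise an error if text columns are passed to corr(). The sketch below, on made-up values, selects the numeric columns first.

```python
import pandas as pd

# Made-up values; "name" is a non-numeric column that must be excluded
toy = pd.DataFrame({
    "mpg":  [21.0, 22.8, 18.7, 14.3, 30.4],
    "wt":   [2.62, 2.32, 3.44, 3.57, 1.51],
    "name": ["a", "b", "c", "d", "e"],
})

corr = toy.select_dtypes(include="number").corr()
print(corr.round(2))
```

The result is symmetric with ones on the diagonal, and the negative mpg/wt entry reflects that heavier cars in this toy data have lower mileage.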

Pairwise plot

A pairwise plot is another way to visualize the relationship between each pair of numeric variables. We can use sns.pairplot() to generate the plot. It's easy and nice.

First, let's use the scatter plot in the pairplot.

In [11]:
# Load Dataset
df = sns.load_dataset('iris')

# Plot (pairplot creates its own figure, so no plt.figure() call is needed)
sns.pairplot(df, kind="scatter", hue="species", plot_kws=dict(s=80, edgecolor="white", linewidth=2.5))
plt.show()

Then, let's see the pairplot with regression lines.

In [12]:
# Load Dataset
df = sns.load_dataset('iris')

# Plot (pairplot creates its own figure, so no plt.figure() call is needed)
sns.pairplot(df, kind="reg", hue="species")
plt.show()

3. Deviation

Diverging bars

The diverging bars plot is useful when you want to see how observations deviate from a reference value of a single variable. Here it shows how far each car's mileage lies above or below the mean.

In [13]:
# Import Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
x = df.loc[:, ['mpg']]
# Calculate the z score of mpg
df['mpg_z'] = (x-x.mean())/x.std()
df['colors'] = ['red' if x<0 else 'green' for x in df['mpg_z']]
df.sort_values('mpg_z', inplace = True)
df.reset_index(inplace = True)   # renumber rows so the y positions follow the sorted order

# Draw plot
plt.figure(figsize = (12, 10), dpi = 80)
plt.hlines(y = df.index, xmin = 0, xmax = df.mpg_z, color = df.colors, alpha = 0.4,
          linewidth = 5)

# Decorations
plt.gca().set(ylabel = '$Model$', xlabel = '$Mileage$')
plt.yticks(df.index, df.cars, fontsize = 12)
plt.title('Diverging Bars of Car Mileage', fontdict = {'size': 20})
plt.grid(linestyle = '--', alpha = 0.5)
plt.show()
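
The bar lengths are z scores: each value minus the mean, divided by the standard deviation. The sketch below uses a handful of made-up mileage values to show the computation and the color rule.

```python
import pandas as pd

# Made-up mileage values
mpg = pd.Series([21.0, 22.8, 18.7, 14.3, 30.4], name="mpg")

# pandas' std() uses the sample standard deviation (ddof=1) by default
mpg_z = (mpg - mpg.mean()) / mpg.std()

# Below-average cars are red, the rest green
colors = ["red" if z < 0 else "green" for z in mpg_z]
print(mpg_z.round(2).tolist(), colors)
```

After standardizing, the values have mean 0 and standard deviation 1, so the bars diverge left and right from the zero line.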

Diverging texts

Diverging texts are similar to diverging bars and are preferred if you want to show the exact value of each observation within the chart in a nice and presentable way.

In [14]:
# Import Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
x = df.loc[:, ['mpg']]
df['mpg_z'] = (x - x.mean())/x.std()
df['colors'] = ['red' if x < 0 else 'green' for x in df['mpg_z']]
df.sort_values('mpg_z', inplace=True)
df.reset_index(inplace=True)

# Draw plot
plt.figure(figsize = (12, 10), dpi = 80)
plt.hlines(y = df.index, xmin = 0, xmax = df.mpg_z)
for x,y, tex in zip(df.mpg_z, df.index, df.mpg_z):
    t = plt.text(x,y, round(tex, 2), horizontalalignment = 'right' if x < 0 else 'left',
                verticalalignment = 'center', fontdict = {'color': 'red' if x<0 else 'green',
                                                         'size': 14})

# Decorations    
plt.yticks(df.index, df.cars, fontsize=12)
plt.title('Diverging Text Bars of Car Mileage', fontdict={'size':20})
plt.grid(linestyle='--', alpha=0.5)
plt.xlim(-2.5, 2.5)
plt.show()

Diverging dot plot

The dot plot changes the lines into dots. The absence of bars reduces the amount of contrast and disparity between observations. In my opinion, it is not as good as the above two plots, but it is still an option.

In [15]:
# Import Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
x = df.loc[:, ['mpg']]
df['mpg_z'] = (x - x.mean())/x.std()
df['colors'] = ['red' if x < 0 else 'darkgreen' for x in df['mpg_z']]
df.sort_values('mpg_z', inplace=True)
df.reset_index(inplace=True)

# Draw plot
plt.figure(figsize=(12,10), dpi= 80)
plt.scatter(df.mpg_z, df.index, s=450, alpha=.6, color=df.colors)
for x, y, tex in zip(df.mpg_z, df.index, df.mpg_z):
    t = plt.text(x, y, round(tex, 1), horizontalalignment='center', 
                 verticalalignment='center', fontdict={'color':'white'})

# Decorations
# Lighten borders
plt.gca().spines["top"].set_alpha(.3)
plt.gca().spines["bottom"].set_alpha(.3)
plt.gca().spines["right"].set_alpha(.3)
plt.gca().spines["left"].set_alpha(.3)

plt.yticks(df.index, df.cars, fontsize = 10)
plt.title('Diverging Dotplot of Car Mileage', fontdict={'size':20})
plt.xlabel('$Mileage$')
plt.grid(linestyle='--', alpha=0.5)
plt.xlim(-2.5, 2.5)
plt.show()

Diverging Lollipop chart with markers

A lollipop chart with markers provides a flexible way of visualizing divergence: you can emphasize any significant data points you want to draw attention to and annotate the reasoning within the chart.

In [16]:
# Prepare Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
x = df.loc[:, ['mpg']]
df['mpg_z'] = (x - x.mean())/x.std()
df['colors'] = 'black'

# Color Fiat differently than other cars
df.loc[df.cars == 'Fiat X1-9', 'colors'] = 'darkorange'
df.sort_values('mpg_z', inplace = True)
df.reset_index(inplace = True)

# Draw plot
import matplotlib.patches as patches
plt.figure(figsize = (14, 10), dpi = 80)
plt.hlines(y = df.index, xmin = 0, xmax = df.mpg_z, color = df.colors, alpha = 0.4, linewidth = 1)
plt.scatter(df.mpg_z, df.index, color = df.colors, s = [600 if x =='Fiat X1-9' else 300 for x in 
                                                       df.cars], alpha = 0.6)
plt.yticks(df.index, df.cars)
plt.xticks(fontsize = 12)

# Annotate
plt.annotate('Mercedes Models', xy=(0.0, 11.0), xytext=(1.0, 11), xycoords='data', 
            fontsize=15, ha='center', va='center',
            bbox=dict(boxstyle='square', fc='firebrick'),
            arrowprops=dict(arrowstyle='-[, widthB=2.0, lengthB=1.5', lw=2.0, color='steelblue'), color='white')

# Add Patches
p1 = patches.Rectangle((-2.0, -1), width=.3, height=3, alpha=.2, facecolor='red')
p2 = patches.Rectangle((1.5, 27), width=.8, height=5, alpha=.2, facecolor='green')
plt.gca().add_patch(p1)
plt.gca().add_patch(p2)

# Decorate
plt.title('Diverging Lollipop of Car Mileage', fontdict={'size':20})
plt.grid(linestyle='--', alpha=0.5)
plt.show()        

Area chart

By coloring the area between the axis and the line, the area chart emphasizes not just the peaks and troughs but also the duration of the highs and lows. The longer the highs last, the larger the area under the line.

In [17]:
# Import data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/economics.csv", parse_dates=['date']).head(100)
x = np.arange(df.shape[0])
y_returns = (df.psavert.diff().fillna(0)/df.psavert.shift(1)).fillna(0) * 100
# Plot
plt.figure(figsize=(16,10), dpi= 80)
plt.fill_between(x[1:], y_returns[1:], 0, where = y_returns[1:] >= 0, facecolor = 'green', 
                interpolate = True, alpha = 0.7)
plt.fill_between(x[1:], y_returns[1:], 0, where=y_returns[1:] <= 0, facecolor='red', interpolate=True, alpha=0.7)
plt.annotate('Peak \n1975', xy=(94.0, 21.0), xytext=(88.0, 28),
             bbox=dict(boxstyle='square', fc='firebrick'),
             arrowprops=dict(facecolor='steelblue', shrink=0.05), fontsize=15, color='white')

# Decorations
xtickvals = [str(m)[:3].upper()+"-"+str(y) for y,m in zip(df.date.dt.year, df.date.dt.month_name())]
plt.gca().set_xticks(x[::6])
plt.gca().set_xticklabels(xtickvals[::6], rotation=45, fontdict={'horizontalalignment': 'center', 
                                                                 'verticalalignment': 'top', 'fontsize': 12})
plt.ylim(-35,35)
plt.xlim(1,100)
plt.title("Monthly Economics Return %", fontsize=22)
plt.ylabel('Monthly returns %')
plt.grid(alpha=0.5)
plt.show()
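
The y_returns expression computes the month-over-month percentage change by hand. The sketch below, on made-up savings-rate values, checks that it matches pandas' built-in pct_change().

```python
import pandas as pd

# Made-up personal savings rate values
psavert = pd.Series([12.5, 12.5, 11.7, 12.5, 12.6])

# Hand-written form used in the cell above
manual = (psavert.diff().fillna(0) / psavert.shift(1)).fillna(0) * 100

# Equivalent built-in: (current - previous) / previous
builtin = psavert.pct_change().fillna(0) * 100

print(manual.round(2).tolist())
```

Both versions set the first month to 0, since it has no previous value to compare against.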

The rest of the top visualizations will be presented in future posts. Peace~


