Space of a muser

Insight

Enjoy life, relish the moment!


Plot a heatmap with side colors indicating the class of variables

Sometimes we need to draw a heatmap of matrix data whose variables belong to categorical classes. In that case, side colors are a convenient way to mark each variable's class. Neither matplotlib nor seaborn's `heatmap` offers this directly, so we have to work around it: seaborn's `clustermap` accepts `row_colors` and `col_colors`, so we can call it with clustering turned off. The code below is adapted from a Stack Overflow snippet; the original did not work for me, so I made a few changes.

```python
# %matplotlib inline  # uncomment when running inside a Jupyter notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(font="Arial")  # matplotlib warns and falls back to a default font if Arial is missing

# Load the brain networks example dataset (fetched from the seaborn-data repository on first use)
df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0)

# Select a subset of the networks
used_networks = [1, 5, 6, 7, 8, 11, 12, 13, 16, 17]
used_columns = (df.columns.get_level_values("network")
                          .astype(int)
                          .isin(used_networks))
df = df.loc[:, used_columns]

# Create a custom palette to identify the networks
network_pal = sns.cubehelix_palette(len(used_networks),
                                    light=.9, dark=.1, reverse=True,
                                    start=1, rot=-2)
network_lut = dict(zip(map(str, used_networks), network_pal))

# Convert the palette to vectors that will be drawn on the side of the matrix
network_labels = df.columns.get_level_values("network")
network_colors = pd.Series(network_labels, index=df.columns).map(network_lut)

# Create a custom colormap for the heatmap values
cmap = sns.diverging_palette(h_neg=210, h_pos=350, s=90, l=30, as_cmap=True)

# Draw the full plot
g = sns.clustermap(df.corr(),

                   # Use the custom colormap, centered at zero correlation
                   cmap=cmap, center=0,

                   # Turn off the clustering
                   row_cluster=False, col_cluster=False,

                   # Add colored class labels
                   row_colors=network_colors, col_colors=network_colors,

                   # Make the plot look better when many rows/cols
                   linewidths=0, xticklabels=False, yticklabels=False)

# Draw an invisible (zero-size) bar per class in the unused dendrogram area
# so that the legend shows one color patch per class
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label],
                            label=label, linewidth=0)
g.ax_col_dendrogram.legend(loc="center", ncol=5)

# Adjust the position of the main colorbar for the heatmap
g.cax.set_position([.97, .2, .03, .45])

plt.show()
```
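The same pattern can be wrapped into a reusable function. Below is a minimal sketch, assuming a square DataFrame (such as a correlation matrix) and a pandas Series of class labels whose index matches the DataFrame's index and columns. The function name `plot_with_side_colors`, the random data, and the classes `A`/`B`/`C` are hypothetical. It also swaps the `bar(0, 0)` legend trick for matplotlib `Patch` proxy handles, which is a stylistic choice rather than part of the original snippet. Because it uses generated data, it runs without downloading the `brain_networks` dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch


def plot_with_side_colors(data, labels, palette="Set2"):
    """Draw `data` as a heatmap with side color bars marking each variable's class.

    Assumes `data` is a square DataFrame whose index and columns both match
    the index of the `labels` Series (hypothetical helper, not a seaborn API).
    """
    classes = labels.unique()
    # Lookup table: class name -> RGB tuple from a qualitative palette
    lut = dict(zip(classes, sns.color_palette(palette, len(classes))))
    # One color per variable; the Series name becomes the side-bar label
    colors = labels.map(lut)
    g = sns.clustermap(data,
                       row_cluster=False, col_cluster=False,  # keep the original order
                       row_colors=colors, col_colors=colors,
                       cmap="vlag", center=0,
                       xticklabels=False, yticklabels=False)
    # Proxy patches give one legend entry per class, placed in the unused dendrogram area
    handles = [Patch(color=lut[c], label=c) for c in classes]
    g.ax_col_dendrogram.legend(handles=handles, loc="center", ncol=len(classes))
    return g, lut


# Hypothetical example: 12 variables in three made-up classes
rng = np.random.default_rng(0)
names = [f"v{i}" for i in range(12)]
labels = pd.Series(["A"] * 4 + ["B"] * 4 + ["C"] * 4, index=names, name="class")
values = pd.DataFrame(rng.normal(size=(50, 12)), columns=names)
g, lut = plot_with_side_colors(values.corr(), labels)
plt.show()
```

Naming the labels Series (here `"class"`) gives the side bars a readable label, since `clustermap` uses the Series name for it.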

Heatmap
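Since the side colors are just a strip of color cells aligned with the heatmap's columns, the effect can also be approximated with plain matplotlib, without `clustermap`. This is a different technique from the snippet above: a minimal sketch that encodes each column's class as an integer, draws those codes as a 1 x n image in a thin axes sharing the x axis with the heatmap, and colors them with a discrete `ListedColormap`. The function name `heatmap_with_class_bar`, the layout ratios, and the example data are hypothetical, and only a column strip is drawn.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch


def heatmap_with_class_bar(matrix, labels, class_colors, cmap="RdBu_r",
                           vmin=None, vmax=None):
    """Plot `matrix` with a thin color strip above it marking each column's class.

    labels: one class name per column of `matrix`.
    class_colors: dict mapping class name -> any matplotlib color.
    (Hypothetical helper, not part of matplotlib or seaborn.)
    """
    classes = list(class_colors)
    # Encode each column's class as an integer index into `classes` (shape 1 x n)
    codes = np.array([[classes.index(lab) for lab in labels]])
    fig, (ax_bar, ax_heat) = plt.subplots(
        2, 1, sharex=True, figsize=(6, 6.5),
        gridspec_kw={"height_ratios": [1, 20], "hspace": 0.02})
    # Class strip: integer codes drawn with a discrete colormap, one color per class
    ax_bar.imshow(codes, aspect="auto", interpolation="nearest",
                  cmap=ListedColormap([class_colors[c] for c in classes]),
                  vmin=-0.5, vmax=len(classes) - 0.5)
    ax_bar.set_axis_off()
    ax_bar.legend(handles=[Patch(color=class_colors[c], label=c) for c in classes],
                  loc="lower center", bbox_to_anchor=(0.5, 1.0),
                  ncol=len(classes), frameon=False)
    # Main heatmap; the shared x axis keeps its columns aligned with the strip
    im = ax_heat.imshow(matrix, aspect="auto", interpolation="nearest",
                        cmap=cmap, vmin=vmin, vmax=vmax)
    fig.colorbar(im, ax=[ax_bar, ax_heat], shrink=0.6)
    return fig, codes


# Hypothetical example: a 9 x 9 correlation matrix with three made-up classes
rng = np.random.default_rng(0)
corr = np.corrcoef(rng.normal(size=(40, 9)), rowvar=False)
labels = ["A"] * 3 + ["B"] * 3 + ["C"] * 3
fig, codes = heatmap_with_class_bar(
    corr, labels, {"A": "tab:blue", "B": "tab:orange", "C": "tab:green"},
    vmin=-1, vmax=1)
plt.show()
```

A row strip could be added the same way, with a narrow axes to the left of the heatmap that shares its y axis and shows the transposed codes.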


Posted in Python, data process, pandas