#194 Split the graphic window with subplot

It can be really useful to split your graphic window into several parts, in order to display several charts at the same time. With matplotlib, this is done through the subplot function. You can arrange your graphics in several rows, several columns, or both, and you can also control the size of each part. The examples below show how to proceed:

  • We use the subplot function to split the graphic window into several parts. It works like this:

    plt.subplot(num1, num2, num3)

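    num1: the number of rows
    num2: the number of columns
    num3: the index of the part you are going to draw in, counted from 1, left to right and then top to bottom

    When all three numbers are below 10, they can be written as a single three-digit number. Here we want 1 row and 2 columns, so we call plt.subplot(121) and then plt.subplot(122).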

    
    # libraries and data
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    
    # Cut your window in 1 row and 2 columns, and start a plot in the first part
    plt.subplot(121)
    plt.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    plt.title("A subplot with 2 columns")
    
    
    # And now add something in the second part:
    plt.subplot(122)
    plt.plot( 'x','z', data=df, linestyle='none', marker='o', color="orange", alpha=0.3)
    
    # Optionally save the figure before showing it (the PNG/ folder must already exist):
    # plt.savefig('PNG/#194_matplotlib_subplot1.png', dpi=96)
    
    
    # Show the graph
    plt.show()
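    The numbering of num3 can be made concrete with a small pure-Python sketch. The helpers parse_shorthand and subplot_cell below are hypothetical (they are not part of matplotlib); they only mimic the way plt.subplot numbers its panels: counted from 1, left to right and then top to bottom, with a three-digit code such as 121 standing for (1, 2, 1).

```python
def parse_shorthand(code):
    """Hypothetical helper: split a three-digit subplot code such as 121 into (nrows, ncols, index)."""
    digits = str(code)
    if len(digits) != 3 or "0" in digits:
        raise ValueError("expected three non-zero digits, e.g. 121")
    nrows, ncols, index = (int(d) for d in digits)
    return nrows, ncols, index


def subplot_cell(nrows, ncols, index):
    """Hypothetical helper: return the 0-based (row, column) of panel `index`.

    Panels are numbered from 1, left to right and then top to bottom,
    which is why divmod(index - 1, ncols) gives the row and the column.
    """
    if not 1 <= index <= nrows * ncols:
        raise ValueError(f"index must be between 1 and {nrows * ncols}")
    return divmod(index - 1, ncols)


for code in (121, 122, 211, 212, 221, 224):
    print(code, "->", subplot_cell(*parse_shorthand(code)))
```

    For 212 this gives (1, 0), the bottom panel of a 2 x 1 grid, and for 224 it gives (1, 1), the bottom-right panel of a 2 x 2 grid.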
    
    

     

  • The concept is exactly the same, except that we cut the graphic window into 2 rows and 1 column: plt.subplot(211), then plt.subplot(212).

    
    # libraries and data
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    
    # first row:
    plt.subplot(211)
    plt.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    plt.title("A subplot with 2 rows")
    
    # second row:
    plt.subplot(212)
    plt.plot( 'x','z', data=df, linestyle='none', marker='o', color="orange", alpha=0.3)
    plt.show()
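    To check that the three-digit shorthand and the three-argument form describe the same panel, you can ask each Axes where it sits in its grid. The sketch below uses Axes.get_subplotspec().get_geometry(), which returns (rows, columns, first cell, last cell) with cells counted from 0 in reading order. It calls fig.add_subplot, which accepts the same arguments as plt.subplot, and selects the off-screen Agg backend only so that it also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend: no window is needed for this check
from matplotlib import pyplot as plt

# Two separate figures, so that each call creates its own Axes
fig_a = plt.figure()
fig_b = plt.figure()

top_short = fig_a.add_subplot(211)      # three-digit shorthand
bottom_short = fig_a.add_subplot(212)
top_long = fig_b.add_subplot(2, 1, 1)   # same panel, three separate arguments

for name, ax in [("211", top_short), ("2, 1, 1", top_long), ("212", bottom_short)]:
    # (rows, columns, first cell, last cell), cells counted from 0
    print(name, ax.get_subplotspec().get_geometry())
```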
    
    
  • Once you have 2 aligned graphics, you need to decide whether or not you want to share axes. If you do, the axis limits will be the same for all plots. If you don't, the limits are adjusted to the data of each plot.

    Think about it! The message is not the same in both cases. Do you want to compare the average values of your plots? Or do you want to study the trend in each plot?

    Here the 2 plots share both axes (sharex=True, sharey=True). This makes it easy to see that the blue values are lower than the orange ones, but the trend of the blue values is hard to observe!

    
    # libraries
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*10 })
    
    # initialise the figure. here we share X and Y axis
    fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)
    axes[0].plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    axes[1].plot( 'x','z', data=df, linestyle='none', marker='o', color="orange", alpha=0.3)
    axes[0].set_title('These 2 plots have the same limits for the Y axis')
    plt.show()
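    To see the other side of the trade-off, the sketch below builds the same two stacked plots twice, once with sharey=True and once with sharey=False, and reads back each panel's y-axis limits with get_ylim(). With shared axes both panels report identical limits; without sharing, each panel is scaled to its own data, which makes the trend of the blue values visible again. The helper y_limits is hypothetical, the fixed seed (np.random.default_rng(0)) only makes the comparison reproducible, and the Agg backend is selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; remove it and call plt.show() to display
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # fixed seed for a reproducible comparison
x = np.arange(1, 101)
df = pd.DataFrame({'x': x,
                   'y': rng.normal(size=100) * 15 + x,
                   'z': (rng.normal(size=100) * 15 + x) * 10})


def y_limits(sharey):
    """Hypothetical helper: draw y (top) and z (bottom), return both panels' y-axis limits."""
    fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=sharey)
    axes[0].plot(df['x'], df['y'], marker='o', alpha=0.4)
    axes[1].plot(df['x'], df['z'], linestyle='none', marker='o', color="orange", alpha=0.3)
    limits = [ax.get_ylim() for ax in axes]
    plt.close(fig)  # only the limits are needed here
    return limits


shared = y_limits(sharey=True)
independent = y_limits(sharey=False)
print("sharey=True: ", shared)
print("sharey=False:", independent)
```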
    

     

  • Following the same idea, it is straightforward to make a 2 x 2 grid:

    
    # libraries
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    
    # Cut the window with 2 rows and 2 columns:
    plt.subplot(221)
    plt.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    plt.subplot(222)
    plt.plot( 'x','z', data=df, linestyle='none', marker='o', color="orange", alpha=0.3)
    plt.subplot(223)
    plt.plot( 'x','z', data=df, linestyle='none', marker='D', color="green", alpha=0.3)
    plt.subplot(224)
    plt.plot( 'x','z', data=df, marker='o', color="grey", alpha=0.3)
    plt.show()
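    The same 2 x 2 grid can also be built with plt.subplots, which returns the four Axes at once as a 2 x 2 NumPy array. Iterating over axes.flat visits them in the same order as the subplot indices 221 to 224 (left to right, then top to bottom), so the style of each panel can be kept in a list. This is a sketch of that alternative, not the code used for the figure above; the panels list is a hypothetical name, and the Agg backend is selected only so it runs without a display (call plt.show() in an interactive session).

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; call plt.show() instead in an interactive session
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

x = np.arange(1, 101)
df = pd.DataFrame({'x': x,
                   'y': np.random.randn(100) * 15 + x,
                   'z': (np.random.randn(100) * 15 + x) * 2})

# Hypothetical list of (column, style) pairs, in reading order: 221, 222, 223, 224
panels = [
    ('y', dict(marker='o', alpha=0.4)),
    ('z', dict(linestyle='none', marker='o', color="orange", alpha=0.3)),
    ('z', dict(linestyle='none', marker='D', color="green", alpha=0.3)),
    ('z', dict(marker='o', color="grey", alpha=0.3)),
]

fig, axes = plt.subplots(nrows=2, ncols=2)  # axes has shape (2, 2)
for ax, (column, style) in zip(axes.flat, panels):
    ax.plot(df['x'], df[column], **style)
```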
    
    
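  • To add a general title to the whole figure, use the suptitle function. Setting y slightly above 1 (here y=1.02) moves it a little higher, leaving more space between it and the panel titles. Note that y is expressed in figure coordinates, so a value above 1 places the top of the title outside the figure area: it may be partly cut off in a regular window, while a figure saved with bbox_inches='tight' expands to include it. See more about titles here.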

    
    # libraries
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    
    # initialize a figure
    fig=plt.figure()
    
    # Do a 2x2 chart
    plt.subplot(221)
    plt.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    plt.title('title of fig A', fontsize=12, color='grey', loc='left', style='italic')
    plt.subplot(222)
    plt.plot( 'x','z', data=df, linestyle='none', marker='o', color="orange", alpha=0.3)
    plt.title('title of fig B', fontsize=12, color='grey', loc='left', style='italic')
    plt.subplot(223)
    plt.plot( 'x','z', data=df, linestyle='none', marker='D', color="green", alpha=0.3)
    plt.title('title of fig C', fontsize=12, color='grey', loc='left', style='italic')
    plt.subplot(224)
    plt.plot( 'x','z', data=df, marker='o', color="grey", alpha=0.3)
    plt.title('title of fig D', fontsize=12, color='grey', loc='left', style='italic')
    
    # Add a title:
    plt.suptitle('A title common to my 4 plots', y=1.02)
    plt.show()
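    With y=1.02 the top of the common title sits slightly above the top edge of the figure (suptitle positions are in figure coordinates), so the title is safest when the figure is saved with bbox_inches='tight', which enlarges the saved area to include it. The sketch below shows this together with a loop that gives each panel its own left-aligned title through Axes.set_title. It writes the PNG to an in-memory buffer (io.BytesIO) instead of a file, and the panels list is a hypothetical name used only for illustration.

```python
import io

import matplotlib
matplotlib.use("Agg")  # off-screen backend; the figure is rendered to memory below
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

x = np.arange(1, 101)
df = pd.DataFrame({'x': x,
                   'y': np.random.randn(100) * 15 + x,
                   'z': (np.random.randn(100) * 15 + x) * 2})

# Hypothetical list of (panel letter, column, style), in reading order
panels = [
    ('A', 'y', dict(marker='o', alpha=0.4)),
    ('B', 'z', dict(linestyle='none', marker='o', color="orange", alpha=0.3)),
    ('C', 'z', dict(linestyle='none', marker='D', color="green", alpha=0.3)),
    ('D', 'z', dict(marker='o', color="grey", alpha=0.3)),
]

fig, axes = plt.subplots(nrows=2, ncols=2)
for ax, (letter, column, style) in zip(axes.flat, panels):
    ax.plot(df['x'], df[column], **style)
    ax.set_title(f'title of fig {letter}', fontsize=12, color='grey', loc='left', style='italic')

# Common title, placed slightly above the top edge of the figure
common = fig.suptitle('A title common to my 4 plots', y=1.02)

# bbox_inches='tight' grows the saved image so that the common title is not cut off
buffer = io.BytesIO()
fig.savefig(buffer, format='png', bbox_inches='tight')
```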
    
    
  • It is possible to customize the proportion of each graph in your split window. This is done with the subplot2grid function, and the 3 examples below illustrate the concept. You can read more about it here. It works like this:

    plt.subplot2grid((2, 2), (0, 0), colspan=2)

    (2, 2): the shape of the grid, here 2 rows and 2 columns
    (0, 0): the position of the top-left cell of the plot, as (row, column). Unlike subplot, these positions are counted from 0, so (0, 0) is the cell in the first row and first column.
    colspan=2: the plot spreads over 2 columns (use rowspan to spread it over several rows)
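    As a pure-Python sketch of these three arguments, the hypothetical helper occupied_cells below (not part of matplotlib) lists the cells that one subplot2grid call covers. It takes the same shape, loc, rowspan and colspan arguments, counts rows and columns from 0, and refuses a plot that would stick out of the grid.

```python
def occupied_cells(shape, loc, rowspan=1, colspan=1):
    """Hypothetical helper: return the 0-based (row, column) cells covered by one plot.

    shape: (number of rows, number of columns) of the grid
    loc: (row, column) of the plot's top-left cell, counted from 0
    """
    nrows, ncols = shape
    row, col = loc
    if row < 0 or col < 0 or row + rowspan > nrows or col + colspan > ncols:
        raise ValueError("the plot does not fit inside the grid")
    return [(r, c)
            for r in range(row, row + rowspan)
            for c in range(col, col + colspan)]


# First row, spread over both columns of a 2 x 2 grid
print(occupied_cells((2, 2), (0, 0), colspan=2))
```

    It prints [(0, 0), (0, 1)]. With rowspan instead, occupied_cells((2, 2), (0, 1), rowspan=2) gives [(0, 1), (1, 1)]: a plot that fills the whole right column.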

     

    Example 1


    
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    ax1 = plt.subplot2grid((2, 2), (0, 0), colspan=2)
    ax1.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    ax2 = plt.subplot2grid((2, 2), (1, 0), colspan=1)
    ax2.plot( 'x','z', data=df, marker='o', color="grey", alpha=0.3)
    ax3 = plt.subplot2grid((2, 2), (1, 1), colspan=1)
    ax3.plot( 'x','z', data=df, marker='o', color="orange", alpha=0.3)
    plt.show()
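    subplot2grid builds on matplotlib's GridSpec. As an alternative sketch, the layout of Example 1 can be written by slicing a GridSpec directly, where gs[0, :] means "first row, all columns". The hypothetical helper cells reads back (rows, columns, first cell, last cell) through Axes.get_subplotspec().get_geometry(), with cells counted from 0, to compare the top plot with the one created by subplot2grid. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; call plt.show() instead to display
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure()
gs = GridSpec(2, 2, figure=fig)   # the same 2 x 2 grid as subplot2grid((2, 2), ...)
ax1 = fig.add_subplot(gs[0, :])   # first row, both columns (like colspan=2)
ax2 = fig.add_subplot(gs[1, 0])   # second row, first column
ax3 = fig.add_subplot(gs[1, 1])   # second row, second column

# The top plot of Example 1, created with subplot2grid on a separate figure
other = plt.figure()
top = plt.subplot2grid((2, 2), (0, 0), colspan=2, fig=other)


def cells(ax):
    """Hypothetical helper: (rows, columns, first cell, last cell), cells counted from 0."""
    return tuple(int(v) for v in ax.get_subplotspec().get_geometry())


print(cells(ax1), cells(top))  # both cover cells 0 to 1 of a 2 x 2 grid
```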
    
    

    Example 2


    
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    
    # 2 rows and 4 columns
    # The first plot is on the first row, and spreads over all 4 columns
    ax1 = plt.subplot2grid((2, 4), (0, 0), colspan=4)
    ax1.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    # The second one is on the second row, starts at the first column and spreads over 3 columns
    ax2 = plt.subplot2grid((2, 4), (1, 0), colspan=3)
    ax2.plot( 'x','z', data=df, marker='o', color="grey", alpha=0.3)
    # The last one spreads over 1 column only: the 4th column of the second row
    ax3 = plt.subplot2grid((2, 4), (1, 3), colspan=1)
    ax3.plot( 'x','z', data=df, marker='o', color="orange", alpha=0.3)
    plt.show()
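    Example 2 obtains unequal widths by spanning a plot over several columns. A different technique, named here as an alternative rather than the method used above, is to give the GridSpec columns unequal widths with width_ratios. In the sketch below the bottom-left plot gets three times the width of the bottom-right one; the proportions are close to Example 2 but not identical, because a 3-column span also includes the gaps between those columns. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; call plt.show() instead to display
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure()
gs = GridSpec(2, 2, figure=fig, width_ratios=[3, 1])  # left column 3 times wider
ax1 = fig.add_subplot(gs[0, :])   # first row, across the full width
ax2 = fig.add_subplot(gs[1, 0])   # second row, wide column
ax3 = fig.add_subplot(gs[1, 1])   # second row, narrow column

# Widths are in figure coordinates (fractions of the figure width)
ratio = ax2.get_position().width / ax3.get_position().width
print(round(ratio, 6))
```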
    

     

    Example 3


    
    from matplotlib import pyplot as plt
    import pandas as pd
    import numpy as np
    df=pd.DataFrame({'x': range(1,101), 'y': np.random.randn(100)*15+range(1,101), 'z': (np.random.randn(100)*15+range(1,101))*2 })
    ax1 = plt.subplot2grid((2, 2), (0, 0), colspan=1)
    ax1.plot( 'x', 'y', data=df, marker='o', alpha=0.4)
    ax2 = plt.subplot2grid((2, 2), (1, 0), colspan=1)
    ax2.plot( 'x','z', data=df, marker='o', color="grey", alpha=0.3)
    ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2)
    ax3.plot( 'x','z', data=df, marker='o', color="orange", alpha=0.3)
    plt.show()
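    In Example 3, rowspan=2 makes the right-hand plot as tall as the two left-hand plots together. A quick sketch to confirm this reads back each plot's position with Axes.get_position(), whose y0 and y1 are the bottom and top edges as fractions of the figure height: the tall plot should start at the bottom of the lower-left plot and end at the top of the upper-left one. The plots are left empty since only their positions matter here, and the Agg backend is selected only so it runs without a display.

```python
import math

import matplotlib
matplotlib.use("Agg")  # off-screen backend; call plt.show() instead to display
from matplotlib import pyplot as plt

fig = plt.figure()
ax1 = plt.subplot2grid((2, 2), (0, 0), fig=fig)             # top-left
ax2 = plt.subplot2grid((2, 2), (1, 0), fig=fig)             # bottom-left
ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2, fig=fig)  # right column, both rows

left_top, left_bottom, right = (ax.get_position() for ax in (ax1, ax2, ax3))
print("right plot spans the left column:",
      math.isclose(right.y0, left_bottom.y0) and math.isclose(right.y1, left_top.y1))
```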
    
    

     
