AnyTech Engineer Blog

AnyTech Engineer Blogは、AnyTechのエンジニアたちによる調査や成果、Tipsなどを公開するブログです。

matplotlibの高速化手法

matplotlibの高速化手法

こんにちは、AnyTechの岩井です。

今回は推論結果の可視化などで使うことが多いにも関わらず推論より遅いじゃないか!となることがあるmatplotlibでのplotを高速化する手法を備忘録もかねてここに紹介したいと思います。

実行環境

  • OS : Ubuntu20.04
  • CPU: Intel® Core™ i9-10850K
  • メモリ: 64GB

各種手法比較

検証用設定

ノイズののった正弦波と余弦波をplotし、その上を点が動くという5分の動画を作成してみます。 30fpsで5分ですので各波は9000個のデータで構成されます。 plot数は各波2つと点2つの計4つになります。

高速化手法なし

まずは普通に書くとどのくらい遅いかを確認するため、点を動かす度に全部plotし直すコードを書きました。

検証用コード

import tqdm
import cv2
import numpy as np
import matplotlib.pyplot as plt


def get_plot_data():
    """
    sin, cosの波にノイズを加えたデータを作成
    """
    freq = 0.01
    x_array = np.linspace(start=0.0, stop=300.0, num=9000)
    sin_array = np.sin(2*np.pi*freq*x_array) + 0.05*np.random.normal(0.0, 1.0, x_array.shape)
    cos_array = np.cos(2*np.pi*freq*x_array) + 0.05*np.random.normal(0.0, 1.0, x_array.shape)

    return x_array, sin_array, cos_array


def plot_setting():
    """
    グラフの初期設定
    """
    plt.clf()
    plt.close()
    fig, ax = plt.subplots(figsize=(12.8, 7.2))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim([0, 300])
    ax.set_ylim([-1.5, 1.5])

    return fig, ax


def main():
    np.random.seed(234)
    vid = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (1280, 720))

    x_array, sin_array,  cos_array = get_plot_data()
    for i in tqdm.tqdm(range(x_array.shape[0])):
        fig, ax = plot_setting()

        # sin, cosのplot
        ax.plot(x_array, sin_array, linestyle='solid', linewidth=2, color='aqua', zorder=1, alpha=0.6)
        ax.plot(x_array, cos_array, linestyle='solid', linewidth=2, color='darkkhaki', zorder=1, alpha=0.6)

        # 動く点のplot
        ax.plot(x_array[i], sin_array[i], marker='.', color='blue', zorder=2, markersize=20)
        ax.plot(x_array[i], cos_array[i], marker='.', color='olivedrab', zorder=2, markersize=20)

        # 描画とnumpy array化
        fig.canvas.draw()
        plt_img = np.array(fig.canvas.renderer.buffer_rgba())[:,:,:3]
        plt_img = cv2.cvtColor(plt_img, cv2.COLOR_RGB2BGR)
        
        vid.write(plt_img)
    vid.release()


if __name__ == '__main__':
    main()

これを実行すると私の環境では5分の動画を作成するのに10分8秒かかるという結果になりました。 動画時間の約2倍かかってますね。 気軽に動画6個作るかなんて思って実行したら1時間待たされます。

mplstyle.use('fast')を書く

先程の検証用コードに少し追記します。

import matplotlib.style as mplstyle
mplstyle.use('fast')

これだと5分の動画に8分51秒かかりました。 まだ動画時間よりも時間がかかってますが、少し追記するだけで早くなるのは魅力的ですね。とりあえず書いとけばよさそうです。

更新するべきものだけ更新する

plotにこんなに時間がかかっているのは毎度全てのplotを描画し直しているからです。 今回作成している動画では、実線で描かれた正弦波と余弦波は更新する必要はなく、動く点だけを更新するだけでいいことに気が付きます。 それを実現するにはコードを以下のように書きます

import tqdm
import cv2
import numpy as np
import matplotlib.pyplot as plt


def get_plot_data():
    """
    sin, cosの波にノイズを加えたデータを作成
    """
    freq = 0.01
    x_array = np.linspace(start=0.0, stop=300.0, num=9000)
    sin_array = np.sin(2*np.pi*freq*x_array) + 0.05*np.random.normal(0.0, 1.0, x_array.shape)
    cos_array = np.cos(2*np.pi*freq*x_array) + 0.05*np.random.normal(0.0, 1.0, x_array.shape)

    return x_array, sin_array, cos_array


def plot_setting(x_array, sin_array, cos_array):
    """
    グラフの初期設定
    """
    plt.clf()
    plt.close()
    fig, ax = plt.subplots(figsize=(12.8, 7.2))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim([0, 300])
    ax.set_ylim([-1.5, 1.5])

    # sin, cosのplot
    ax.plot(x_array, sin_array, linestyle='solid', linewidth=2, color='aqua', zorder=1, alpha=0.6)
    ax.plot(x_array, cos_array, linestyle='solid', linewidth=2, color='darkkhaki', zorder=1, alpha=0.6)

    # 動く点のplot
    ax.plot([], [], marker='.', color='blue', zorder=2, markersize=20)
    ax.plot([], [], marker='.', color='olivedrab', zorder=2, markersize=20)

    # 描画と背景の保存
    fig.canvas.draw()
    bg = fig.canvas.copy_from_bbox(ax.bbox)

    return fig, ax, bg


def update_plot(fig, ax, bg, x_data, sin_data, cos_data):
    """
    動く点の位置を更新
    """
    lines = ax.get_lines()
    fig.canvas.restore_region(bg)

    # sin側の動く点のplot
    lines[2].set_xdata(x_data)
    lines[2].set_ydata(sin_data)

    # cos側の動く点のplot
    lines[3].set_xdata(x_data)
    lines[3].set_ydata(cos_data)

    # 更新
    ax.draw_artist(lines[2])
    ax.draw_artist(lines[3])
    fig.canvas.blit(ax.bbox)


def main():
    np.random.seed(234)
    vid = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (1280, 720))
    
    x_array, sin_array,  cos_array = get_plot_data()
    fig, ax, bg = plot_setting(x_array, sin_array, cos_array)
    for i in tqdm.tqdm(range(x_array.shape[0])):
        update_plot(fig, ax, bg, x_array[i], sin_array[i], cos_array[i])

        # numpy array化
        plt_img = np.array(fig.canvas.renderer.buffer_rgba())[:,:,:3]
        plt_img = cv2.cvtColor(plt_img, cv2.COLOR_RGB2BGR)
        vid.write(plt_img)
    vid.release()


if __name__ == '__main__':
    main()

これを実行すると5分の動画に58秒かかるという結果になりました。 圧倒的に早くなりましたしこれなら30fpsのリアルタイム描画にも十分耐えられます。 ちなみにmplstyle.use('fast')を付けても58秒であり差はありませんでした。

まとめ

今回はmatplotlibのplotを高速化する手法を紹介しました。 本当はアニメーションを作るならmatplotlib.animation.FuncAnimation()を使う手法もあるのですが、私は結構1フレームずつ処理してnumpy arrayにすることが多いので今回は省いてます。 とにかく早くするなら更新すべきものだけ更新する手法が最も早くなりますがコード量が少し増加してしまいます。コード変えたくないけど少し早くしてほしいという場合はmplstyle.use('fast')だけ書いておくのも手ですね。 この結果はあくまで9000個のデータを2つplotし、その上を点が動くという設定における結果ですので、データ数がもっと少ない場合やplotが少ない場合あまり効果が感じられない可能性はあります。 しかしながら、matplotlibの遅さに悩んでいる方はぜひ試してみてください。