Kei Minagawa's Blog

皆川圭(@keimina)のブログ、Pythonで試したことを書いていきます

sympy でニューラルネットワークの重みの更新に使用する式を計算グラフにしてみる

1. はじめに

sympy(https://www.sympy.org/en/index.html) を使用し、ニューラルネットワークの重みの更新に使用する式を計算グラフにしてみます。具体的には、 sympy を使って、損失関数の勾配の計算を計算グラフとして表し、それを graphviz で画像に出力します。sympy と graphviz を使用すると、 数式を簡単に計算グラフに変換できます。ちなみに計算グラフとは、以下のようなものです。(実際にこれらを使用し作成したものです)

f:id:keimina:20190325203008p:plain
図1. c = a + b の計算グラフ

このように計算式をグラフとして表現したものを計算グラフと呼びます。

2. 使用するモジュールと、そのインストール

本記事では、以下の二つのモジュールを使用します。グラフの描画に graphviz というモジュールを使用します。sympy は計算グラフをテキストで出力できますが、単体では画像で出力することができないためグラフ描画アプリ(graphviz)を使用します。

  1. sympy
  2. graphviz

graphviz は以下のコマンドでインストールしました。

pip install graphviz

コマンドの出力結果

Collecting graphviz
  Downloading https://files.pythonhosted.org/packages/1f/e2/ef2581b5b86625657afd32030f90cf2717456c1d2b711ba074bf007c0f1a/graphviz-0.10.1-py2.py3-none-any.whl
distributed 1.21.8 requires msgpack, which is not installed.
Installing collected packages: graphviz
Successfully installed graphviz-0.10.1
You are using pip version 10.0.1, however version 19.0.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

コマンド実行後、"requires msgpack" と出力されていますが、 最終的に"Successfully installed" と出力されているため、問題なしとみなし気にせず先に進みました。

今回は、計算グラフを画像で出力したいため、 さらに以下の手順行う必要がありました。

  1. MacPorts のインストール(インストール手順は以下のサイトを参照)
  2. 以下のコマンドの実行しアプリをインストール
sudo port install graphviz

sympy のインストール手順について説明は省略します。

3. ニューラルネットワークのモデルと、順伝播の式

本記事で扱うニューラルネットワークのモデルと、順伝播の式は以下の通りです。

f:id:keimina:20190325203628p:plain
図2. ニューラルネットワークのモデルと、順伝播の式

4. ソースコード

3. のニューラルネットワークの損失関数の勾配を計算するコードにすると以下のようになります。
※本記事は、重みの更新を行う計算を、計算グラフに出力することを目的としているため、説明をわかりやすくするため損失関数 l は o1 としています。本来であれば、l = (o1 - y)**2 のような定義をすべきですが、パラメータが増えると、計算グラフが大きくなり見づらくなるため、簡略化しています。ご了承ください。
ソースコードで numpy を使用して行列計算していますが、sympy での行列計算の仕方がわからなかったためです。あまり真似しないように。。。

from sympy import symbols, Function, tanh
from sympy.printing import dot
import graphviz
import numpy as np

# sympy で使用する変数(シンボル)を定義する
# 入力値、中間層の値
i1, i2 = symbols("i1, i2")
h11, h12 = symbols("h11, h12")

# 重み
w111, w112, w121, w122 = symbols("w111, w112, w121, w122")
w211, w212 = symbols("w211, w212")

# バイアス
b11, b12 = symbols("b11, b12")
b21 = symbols("b21")

# numpy を利用して行列形式にする
I = np.array([[i1], [i2]])
W1 = np.array([[w111, w112], [w121, w122]])
W2 = np.array([[w211, w212]])
B1 = np.array([[b11], [b12]])
B2 = np.array([[b21]])

# 活性化関数
TANH = np.vectorize(tanh)

# 順伝播の数式を作成する
H1 = TANH(W1.dot(I) + B1)
O = TANH(W2.dot(H1) + B2)

# 行列からスカラを取り出す
o1 = O[0, 0]

# 最急降下法 によって重み w1 を更新する式(学習係数は 1 とする)
# 損失関数を定義する、重み更新を行う計算グラフを出したいだけなので l = o1  としてしまう
l = o1
dl_dw111 = l.diff(w111)

expr_forward = o1
src_forward = graphviz.Source(dot.dotprint(expr_forward))

expr_backward = dl_dw111
src_backward = graphviz.Source(dot.dotprint(expr_backward))

# 私の環境ではExecutableNotFound エラーが発生したので、以下のコードで回避しました
app_path = '/opt/local/bin/dot'
graphviz.ENGINES.add(app_path)
src_forward._engine = app_path
src_backward._engine = app_path

# 画像ファイルを作成する
filename_forward = "calc_forward_graph"
src_forward.render(filename_forward)
filename_backward = "calc_backward_graph"
src_backward.render(filename_backward)

順伝播の式が o1 です。勾配降下法によって損失関数 l を最小にするため重みパラメータ w111 を更新するときに行われる計算が dl_dw111 です。コードの最後の方で、sympy と graphviz を使用して、o1 と dl_dw111 の計算を行う計算グラフを出力しています。sympy.printing.dot で計算グラフを graphviz で使用可能な dot フォーマット形式のテキストに変換しています。そして、 dot フォーマット形式のテキストを graphviz.Source で graphviz に取り込み、 render メソッドを呼びだし pdf ファイルを作成しています。

5. 出力結果

出力結果は以下になります。図は上から順に、順伝播の計算を行うための計算グラフ(o1 の計算を行うグラフ)、損失関数lを最小化すべく w111 を勾配降下法により更新するための計算グラフ(dl_dw111 の計算を行うグラフ)、となっています

f:id:keimina:20190325204132p:plain
図3. 順伝播の計算グラフ("calc_forward_graph")

f:id:keimina:20190325204331p:plain
図4. w111 の重みを更新する数式の計算グラフ("calc_backward_graph")

楕円の中に書かれている Mul 、 Add 、 Pow 、 tanh はどのような計算を行うかを表していて、それぞれ、 乗算、加算、累乗、ハイパボリックタンジェントを表しています。ニューラルネットの学習時には、図中の w111、w112 、w121、 b11 、、、 などの変数のように見えるものは全て値が割り振られているため、 この計算グラフの計算を行うと、何かしらの値(スカラ値)が求まります。そしてこの値を使用して重み w111 を更新します。

6. まとめ

sympy と graphviz を使用して、ニューラルネットワークの重みの更新に使用する式を計算グラフにしました。このように、ニューラルネットワークのモデルが決まると、重みを更新するための式、すなわち計算グラフも決まります。

7. 参考文献

  1. ゼロから作るDeep Learning――Pythonで学ぶディープラーニングの理論と実装