この式で表される多次元ガウス関数をscikit-learnなど使わずに自分で実装したい. 下の通りだけど実装するにあたりWeb上のいろんな解説を見ていると ガウス関数のPythonコードみたいに何故かexp関数に代入する前にdiag関数を呼んでいる. なんでdiag関数が必要なのか?これが本当にいい方法なのか?実はこのdiag関数を使うと異常にメモリを消費して計算速度が遅くなる. 今回はそれらを確認してみた.
import numpy as np def gaussian(x,mu,sigma): det = np.linalg.det(sigma) inv = np.linalg.inv(sigma) d = x - mu n = x.ndim norm = (np.sqrt((2 * np.pi) ** n * det)) return np.exp(-np.diag(d@inv@d.T)/2.0) / norm
そもそもの疑問
そもそも何が疑問なのかというと,exp関数の入力はスカラーなのだからnp.diagなんて要らないのでは?というもの. diag関数は行列が入力されたときに対角成分を出力する関数で例えば次のような行列Aに対してはベクトルが出力になる.
A = np.array([ [1, 2], [3, 4]]) print(np.diag(A))
[1 4]
そこでdiagにどのような次元の行列が入力されることになるのかを見てみる. 多次元ガウス関数は入力が多次元なのであって,出力は1次元. なのでexp関数は,いわゆる高校数学で習う指数関数なので入力も出力も1次元. 指数部分を見ると
diff@inv@diff.T
は数式では
diag関数の意味
実際にdiag関数はスカラーでは次のようなエラーがでる.
np.diag(100)
ValueError: Input must be 1- or 2-d.
そのため上のガウス関数gaussianはデータ点1つでは正しく動作しない. やはりdiag関数にスカラーを入力したときと同じエラーがでる.
mu = np.array([10,10]) sigma = np.array([ [4**2, 2], [2, 3**2]]) x = np.array([1,1]) gaussian(x, mu, sigma)
ValueError: Input must be 1- or 2-d.
要するにgaussian関数は,入力のベクトルが多次元というだけでなく, そのベクトルが多数ある場合のbroadcastを前提とした設計だった. そこで2次元んベクトルが3つの場合を確認してみる.
x = np.array([ [1,1], [2,2], [3,3]) diff@inv@dif.T
[[12.15 10.8 9.45] [10.8 9.6 8.4 ] [ 9.45 8.4 7.35]]
なんと
一体どうなっているか考えてみた.まず
変数の中身はともかくこんな形になっているはず.データ点が3つあれば変数d(つまり数式上の
この行列のまま計算すると数式が複雑になりすぎるので成分のみの計算にする.
左2つのベクトルの積を
と表せる.これは行列の積の公式そのもの.これを用いて残りも含めてた全体を成分で表示してみる.
と表せる.この成分をよく見ると
だからdiag関数が必要だった.ブロードキャストが原因だった.
diag関数は非効率?
ここまで見てすぐに分かるのは行列
def gaussian2(x, mu, sigma): det = np.linalg.det(sigma) inv = np.linalg.inv(sigma) d = x - mu n = x.ndim norm = np.sqrt((2 * np.pi) ** n * det) power = - np.einsum('il,lk,ik->i', d, inv, d) /2.0 return np.exp(power) / norm
einsum関数は,添え字を指定するだけでの行列(テンソル)どおしの必要な積が定義できる関数.
式
x = np.linspace(0, 40,40) y = np.linspace(0, 40,40) X, Y = np.meshgrid(x, y) XY = np.c_[X.ravel(), Y.ravel()] mu = np.array([10,10]) sigma = np.array([ [4**2, 2], [2, 3**2]])
これらのパラメータでgaussianとgaussian2の計算時間をテストしてみた.
timeit gaussian(XY, mu, sigma)
16.5 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
timeit gaussian2(XY, mu, sigma)
198 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
0.198/16.5=0.012とおよそ100倍のスピード. loopの回数は自動的に決まるみたい. とにかく圧倒的にdiag関数は無駄無駄無駄.