2022年3月29日火曜日

diag関数なしで高速な多次元ガウス関数を実装する

この式で表される多次元ガウス関数をscikit-learnなど使わずに自分で実装したい. 下の通りだけど実装するにあたりWeb上のいろんな解説を見ていると ガウス関数のPythonコードみたいに何故かexp関数に代入する前にdiag関数を呼んでいる. なんでdiag関数が必要なのか?これが本当にいい方法なのか?実はこのdiag関数を使うと異常にメモリを消費して計算速度が遅くなる. 今回はそれらを確認してみた.

(1)f(x)=1(2π)n|Σ|exp{12dTΣ1d},  d=xμ
import numpy as np
def gaussian(x,mu,sigma):
    det = np.linalg.det(sigma)
    inv = np.linalg.inv(sigma)
    d = x - mu
    n = x.ndim
    norm = (np.sqrt((2 * np.pi) ** n * det))
    return np.exp(-np.diag(d@inv@d.T)/2.0) / norm

そもそもの疑問

そもそも何が疑問なのかというと,exp関数の入力はスカラーなのだからnp.diagなんて要らないのでは?というもの. diag関数は行列が入力されたときに対角成分を出力する関数で例えば次のような行列Aに対してはベクトルが出力になる.

A = np.array([
    [1, 2],
    [3, 4]])
print(np.diag(A))
[1 4]

そこでdiagにどのような次元の行列が入力されることになるのかを見てみる. 多次元ガウス関数は入力が多次元なのであって,出力は1次元. なのでexp関数は,いわゆる高校数学で習う指数関数なので入力も出力も1次元. 指数部分を見ると

diff@inv@diff.T

は数式ではdTΣ1dd=xμn×1のベクトルどうしの引き算なのでその結果もn×1のベクトル. Σ1は分散共分散行列でありn×nの対称行列. n×1ベクトルとn×n行列の積はn×1ベクトル.つまりΣ1dn×1ベクトル. それとdTの積は,内積そのものだからdTΣ1dはスカラーになる. つまり,行列を入力とするdiag関数なんて不要じゃねーか?というのが疑問.

diag関数の意味

実際にdiag関数はスカラーでは次のようなエラーがでる.

np.diag(100)
ValueError: Input must be 1- or 2-d.

そのため上のガウス関数gaussianはデータ点1つでは正しく動作しない. やはりdiag関数にスカラーを入力したときと同じエラーがでる.

mu = np.array([10,10])
sigma = np.array([
    [4**2, 2],
    [2, 3**2]])
x = np.array([1,1])
gaussian(x, mu, sigma)
ValueError: Input must be 1- or 2-d.

要するにgaussian関数は,入力のベクトルが多次元というだけでなく, そのベクトルが多数ある場合のbroadcastを前提とした設計だった. そこで2次元んベクトルが3つの場合を確認してみる.

x = np.array([
    [1,1],
    [2,2],
    [3,3])
diff@inv@dif.T
[[12.15 10.8   9.45]
 [10.8   9.6   8.4 ]
 [ 9.45  8.4   7.35]]

なんと3×3の行列になってしまう. gaussian関数の入力が2次元ベクトルのとき結果がスカラーなら, 入力が2次元bベクトル3個分の3×2行列なら結果は,スカラー3個分の3×1などになりそうなものだけどそうではなかった.

一体どうなっているか考えてみた.まずdTΣ1dの部分を成分表示してみる.

(2)dTΣ1d=[d11d12][σ11σ21σ21σ22][d11d12]

変数の中身はともかくこんな形になっているはず.データ点が3つあれば変数d(つまり数式上のdT)は3行になるので↓のようになるはず.

(3)[d11d12d21d22d31d32][σ11σ21σ21σ22][d11d21d31d12d22d32]

この行列のまま計算すると数式が複雑になりすぎるので成分のみの計算にする. 左2つのベクトルの積をb=dTΣ1とすると,その成分は

(4)bij=ldilσlj

と表せる.これは行列の積の公式そのもの.これを用いて残りも含めてた全体を成分で表示してみる. a=dTΣ1d=bdとすると,

(5)aij=kbikdjk(6)=kldilσlkdjk

と表せる.この成分をよく見るとdikiがデータ数に対応する. 最後の式でdildjkとあるけど,もしデータが1つしか無かったらi=1の一行だけになりこのときd1xの形の成分しか存在しない. ということはi=j=1に限られることになる.2つ目のデータしかないときはi=j=2に限られ,以下同様. 結局aijの対角成分がdTΣ1dの結果を表していることになる.

(7)[a11a22a33]

だからdiag関数が必要だった.ブロードキャストが原因だった.

diag関数は非効率?

ここまで見てすぐに分かるのは行列aijの対角成分以外は必要ない. なのにデータがN個ある場合,一旦N×N行列が生成されていることになる. これはデータ数が多いとすぐに破綻するメモリ的にも計算時間的にも無駄が多い.

def gaussian2(x, mu, sigma):
    det = np.linalg.det(sigma)
    inv = np.linalg.inv(sigma)
    d = x - mu
    n = x.ndim
    norm = np.sqrt((2 * np.pi) ** n * det)
    power = - np.einsum('il,lk,ik->i', d, inv, d) /2.0
    return np.exp(power) / norm

einsum関数は,添え字を指定するだけでの行列(テンソル)どおしの必要な積が定義できる関数. 式6の添え字をそのまま記入すればいい.i=jなのでjiに置き換えられている.

x = np.linspace(0, 40,40)
y = np.linspace(0, 40,40)
X, Y = np.meshgrid(x, y)
XY = np.c_[X.ravel(), Y.ravel()]

mu = np.array([10,10])
sigma = np.array([
    [4**2, 2],
    [2, 3**2]])

これらのパラメータでgaussianとgaussian2の計算時間をテストしてみた.

timeit gaussian(XY, mu, sigma)
16.5 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
timeit gaussian2(XY, mu, sigma)
198 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

0.198/16.5=0.012とおよそ100倍のスピード. loopの回数は自動的に決まるみたい. とにかく圧倒的にdiag関数は無駄無駄無駄.