2022年3月29日火曜日

diag関数なしで高速な多次元ガウス関数を実装する

この式で表される多次元ガウス関数をscikit-learnなど使わずに自分で実装したい. 下の通りだけど実装するにあたりWeb上のいろんな解説を見ていると ガウス関数のPythonコードみたいに何故かexp関数に代入する前にdiag関数を呼んでいる. なんでdiag関数が必要なのか?これが本当にいい方法なのか?実はこのdiag関数を使うと異常にメモリを消費して計算速度が遅くなる. 今回はそれらを確認してみた.

\begin{equation} \newcommand{\vec}{\boldsymbol} f(\vec{x}) = \frac{1}{\sqrt{(2\pi)^n \left|\Sigma\right|}}\exp \left \{-\frac{1}{2}{} \vec{d}^T\,{\Sigma}^{-1} \vec{d} \right\},\ \ \vec{d}=\vec{x}-\vec{\mu} \end{equation}
import numpy as np
def gaussian(x,mu,sigma):
    det = np.linalg.det(sigma)
    inv = np.linalg.inv(sigma)
    d = x - mu
    n = x.ndim
    norm = (np.sqrt((2 * np.pi) ** n * det))
    return np.exp(-np.diag(d@inv@d.T)/2.0) / norm

そもそもの疑問

そもそも何が疑問なのかというと,exp関数の入力はスカラーなのだからnp.diagなんて要らないのでは?というもの. diag関数は行列が入力されたときに対角成分を出力する関数で例えば次のような行列Aに対してはベクトルが出力になる.

A = np.array([
    [1, 2],
    [3, 4]])
print(np.diag(A))
[1 4]

そこでdiagにどのような次元の行列が入力されることになるのかを見てみる. 多次元ガウス関数は入力が多次元なのであって,出力は1次元. なのでexp関数は,いわゆる高校数学で習う指数関数なので入力も出力も1次元. 指数部分を見ると

diff@inv@diff.T

は数式では$\vec{d}^T\,{\Sigma}^{-1} \vec{d}$. $\vec{d}=\vec{x}-\vec{\mu}$は$n\times 1$のベクトルどうしの引き算なのでその結果も$n\times 1$のベクトル. ${\Sigma}^{-1}$は分散共分散行列であり$n\times n$の対称行列. $n\times 1$ベクトルと$n\times n$行列の積は$n\times 1$ベクトル.つまり${\Sigma}^{-1}\vec{d}$は$n\times 1$ベクトル. それと$\vec{d}^T$の積は,内積そのものだから$\vec{d}^T\,{\Sigma}^{-1} \vec{d}$はスカラーになる. つまり,行列を入力とするdiag関数なんて不要じゃねーか?というのが疑問.

diag関数の意味

実際にdiag関数はスカラーでは次のようなエラーがでる.

np.diag(100)
ValueError: Input must be 1- or 2-d.

そのため上のガウス関数gaussianはデータ点1つでは正しく動作しない. やはりdiag関数にスカラーを入力したときと同じエラーがでる.

mu = np.array([10,10])
sigma = np.array([
    [4**2, 2],
    [2, 3**2]])
x = np.array([1,1])
gaussian(x, mu, sigma)
ValueError: Input must be 1- or 2-d.

要するにgaussian関数は,入力のベクトルが多次元というだけでなく, そのベクトルが多数ある場合のbroadcastを前提とした設計だった. そこで2次元んベクトルが3つの場合を確認してみる.

x = np.array([
    [1,1],
    [2,2],
    [3,3])
diff@inv@dif.T
[[12.15 10.8   9.45]
 [10.8   9.6   8.4 ]
 [ 9.45  8.4   7.35]]

なんと$3\times 3$の行列になってしまう. gaussian関数の入力が2次元ベクトルのとき結果がスカラーなら, 入力が2次元bベクトル3個分の$3\times 2$行列なら結果は,スカラー3個分の$3\times 1$などになりそうなものだけどそうではなかった.

一体どうなっているか考えてみた.まず$\vec{d}^T\Sigma^{-1}\vec{d}$の部分を成分表示してみる.

\begin{eqnarray} \vec{d}^T\Sigma^{-1}\vec{d} &=& \left[ \begin{array}{cc} d_{11} & d_{12}\\ \end{array}\right] \left[ \begin{array}{ccc} \sigma_{11} & \sigma_{21}\\ \sigma_{21} & \sigma_{22}\\ \end{array}\right] \left[ \begin{array}{c} d_{11} \\ d_{12} \\ \end{array}\right] \end{eqnarray}

変数の中身はともかくこんな形になっているはず.データ点が3つあれば変数d(つまり数式上の$d^T$)は3行になるので↓のようになるはず.

\begin{eqnarray} \left[ \begin{array}{cc} d_{11} & d_{12}\\ d_{21} & d_{22}\\ d_{31} & d_{32}\\ \end{array}\right] \left[ \begin{array}{cc} \sigma_{11} & \sigma_{21}\\ \sigma_{21} & \sigma_{22}\\ \end{array}\right] \left[ \begin{array}{c} d_{11} & d_{21} & d_{31}\\ d_{12} & d_{22} & d_{32}\\ \end{array}\right] \end{eqnarray}

この行列のまま計算すると数式が複雑になりすぎるので成分のみの計算にする. 左2つのベクトルの積を$\vec{b}=d^T\Sigma^{-1}$とすると,その成分は

\begin{eqnarray} b_{ij} &=&\sum_l d_{il}\,\sigma_{lj} \end{eqnarray}

と表せる.これは行列の積の公式そのもの.これを用いて残りも含めてた全体を成分で表示してみる. $\vec{a}=d^T\Sigma^{-1}\vec{d}=\vec{b}\vec{d}$とすると,

\begin{eqnarray} a_{ij} &=&\sum_k b_{ik}\,d_{jk}\\ &=&\sum_k \sum_l d_{il}\,\sigma_{lk}\,d_{jk}\label{tensor}\\ \end{eqnarray}

と表せる.この成分をよく見ると$d_{ik}$の$i$がデータ数に対応する. 最後の式で$d_{il}$と$d_{jk}$とあるけど,もしデータが1つしか無かったら$i=1$の一行だけになりこのとき$d_{1x}$の形の成分しか存在しない. ということは$i=j=1$に限られることになる.2つ目のデータしかないときは$i=j=2$に限られ,以下同様. 結局$a_{ij}$の対角成分が$d^T\Sigma^{-1}\vec{d}$の結果を表していることになる.

\begin{eqnarray} \left[ \begin{array}{ccc} a_{11} & & \\ & a_{22} & \\ & & a_{33}\\ \end{array}\right] \end{eqnarray}

だからdiag関数が必要だった.ブロードキャストが原因だった.

diag関数は非効率?

ここまで見てすぐに分かるのは行列$a_{ij}$の対角成分以外は必要ない. なのにデータが$N$個ある場合,一旦$N\times N$行列が生成されていることになる. これはデータ数が多いとすぐに破綻するメモリ的にも計算時間的にも無駄が多い.

def gaussian2(x, mu, sigma):
    det = np.linalg.det(sigma)
    inv = np.linalg.inv(sigma)
    d = x - mu
    n = x.ndim
    norm = np.sqrt((2 * np.pi) ** n * det)
    power = - np.einsum('il,lk,ik->i', d, inv, d) /2.0
    return np.exp(power) / norm

einsum関数は,添え字を指定するだけでの行列(テンソル)どおしの必要な積が定義できる関数. 式$\ref{tensor}$の添え字をそのまま記入すればいい.$i=j$なので$j$は$i$に置き換えられている.

x = np.linspace(0, 40,40)
y = np.linspace(0, 40,40)
X, Y = np.meshgrid(x, y)
XY = np.c_[X.ravel(), Y.ravel()]

mu = np.array([10,10])
sigma = np.array([
    [4**2, 2],
    [2, 3**2]])

これらのパラメータでgaussianとgaussian2の計算時間をテストしてみた.

timeit gaussian(XY, mu, sigma)
16.5 ms ± 207 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
timeit gaussian2(XY, mu, sigma)
198 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

0.198/16.5=0.012とおよそ100倍のスピード. loopの回数は自動的に決まるみたい. とにかく圧倒的にdiag関数は無駄無駄無駄.