2013年12月10日火曜日

Ceres Solverの使用例(放物線の最小二乗法)

Ceres SolverはGoogleが開発した非線形最適化のためのC++ライブラリ。 使い方は本家サイトにもチュートリアルがあるし、サンプルコードもある。 でも、どれも単純すぎたり、難しすぎたり中途半端なので、 データが復数ある最小二乗法のサンプルコードを載せる。

\[ y=a x^2 \]

この関数をフィッティングする場合、データとして$(x_i, y_i)\ (i=0,1,\cdots,n-1)$が与えれるので、 求めるパラメータは、$a$のみ。ここのデータの誤差を \[ e_i = y_i - a*x_i^2 \] とすると、単純な最小二乗法では、 \[ E=\sum_i e_i^2 \] が最小化される。 ここでは、データ$(x_i, y_i)$として次のファイル「data.txt」が与えられるとする。

0 0
1 1.2
2 3.9
-1 0.9
-2 4.1

サンプルコード

まず、ヘッダー部分は単純にceresのヘッダをインクルードすれば良い。 namespaceはライブラリのexamplesを真似して追加した。 OpenCVのヘッダはデータ入力に使うだけなので本来必要ない。

#include "ceres/ceres.h"
#include "opencv2/opencv.hpp"
 
using ceres::AutoDiffCostFunction;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solver;
using ceres::Solve;

次に、誤差関数の定義部分。 ErrorFunctionクラスのメンバ関数として定義する必要がある。 operator()関数をテンプレート関数として ここに書くのは$e_i$の部分だけ。 パラメータ$a$も残差$e_i$も配列で定義するみたい。 汎用性を考えてType型という抽象型になっているだけで今回はdoubleと思えば良い。

class ErrorFunction{
private:
    //1組のデータ
    const double x;
    const double y;
 
public:
    //データを1つずつ入力する際このコンストラクタが呼ばれる。
    ErrorFunction(double x, double y) : x(x), y(y) {}
 
    //e_iの定義。
    template <typename Type> bool operator()(const Type* const a, Type* residual) const {

        //e_i = y_i - a*x_i*x_i
        residual[0] = Type(y)  - a[0]*Type(x)*Type(x);
        return true;
    }
};

誤差関数が定義できれば、あとはデータを与えて最適化するだけ。 データの入力は面倒なので 「cv::Mat_をテキストファイルに保存・読み込み 」 で紹介した方法を使う。

void main(int argc, char *argv[]){

    //最適化するパラメータ。初期値を入力しておく。
    double a = 0.0;
 
    //データの入力
    cv::Mat_<double> data;
    readTxt("data.txt", data);
 
    //データの入力
    Problem problem;
    for (int i = 0; i < data.rows; ++i) {

        //x_i,y_iを1つずつ入力する
        problem.AddResidualBlock(
            new AutoDiffCostFunction<ErrorFunction, 1, 1>
            (
                //ここでデータを入れている
                new ErrorFunction(data(i, 0), data(i, 1))
            ),
            NULL,    //Loss function
            &a
        );
        //std::cout << data(i, 0) << ' ' << data(i, 1) << std::endl;
    }
 
    //ソルバの準備
    Solver::Options options;

    //CXSparseをインストールしていないときはDENSE_*にしないとエラーになる。
    options.linear_solver_type = ceres::DENSE_NORMAL_CHOLESKY;   
 
    //最適化
    Solver::Summary summary;
    Solve(options, &problem, &summary);
 
    std::cout << "Fit complete:   a: " << a << "\n";
    return;
}

参考