天天看点

SLAM--ceres库--曲线逼近示例--啰里啰嗦的代码解析

#include <ceres/ceres.h>
#include <iostream>
#include <chrono>
#include <opencv2/core/core.hpp>

using namespace std;

//代价函数的计算模型
struct CURVE_FITTING_COST
{
    // 仿函数,使用方式 CURVE_FITTING_COST(const T* const abc, T* residual)
    // 传入的是真实值(x-x_data[i],y-y_data[i])
    CURVE_FITTING_COST (double x, double y) : _x (x), _y (y) {}
    //残差计算
    //操作符()重载
    template <typename T>
    bool operator() (const T* const abc, T* residual) const  //residual残差
    {
        //y-exp(ax^2+bx+c)
        residual[0] = T(_y) - ceres::exp(abc[0]*T(_x)*T(_x)+abc[1]*T(_x)+abc[2]);
        return true;
    }
    const double _x, _y;  //x,y数据
};

int main(int argc, char** argv)
{
    //a,b,c是曲线真实模型的参数值,要求一条与该曲线逼近的曲线的参数a,b,c估计值
    double a = 1.0, b = 2.0, c = 1.0;  //计算结果: estimate a, b, c = 0.891943 2.17039 0.944142
    //double a = 1.0, b = 5.0, c = 2.0; // 计算结果: estimate a, b, c = 0.989318 5.0169 1.99377
    int N = 100;  //100个数据点
    double w_sigma = 1.0;  //噪声值
    cv::RNG rng;         //OpenCV随机数产生器
    double abc[3] = {0.0, 0.0, 0.0}; //abc的参数估计值,都初始化为0.0

    vector<double> x_data, y_data; //根据方程y-exp(ax^2+bx+c)+w 生成100个(x,y)数据,传入到x_data, y_data中。
    //假设有一条满足该方程的曲线
    cout << "generating data: " << endl;
    for(int i = 0; i < N; i++)
    {
        double x = i/100.0;
        //把x(取0~1中的100个x坐标)压入vector队列的x_data[]中
        x_data.push_back(x);
        //对应的找到与x对应的100个y,压入vector队列的y_data[]中
        y_data.push_back(exp(a*x*x+b*x+c)+rng.gaussian(w_sigma));


        cout << x_data[i] << " " << y_data[i] << endl;

    }

    //================================================
    //构建最小二乘问题
    ceres::Problem problem;
    for(int i = 0; i < N; i++)
    {
        problem.AddResidualBlock(//向问题中添加误差项。
            // 使用自动求导ceres::AutoDiffCostFunction,
            // 将之前的代价函数结构体<CURVE_FITTING_COST, 1, 3>传入:
            // 解释:
            // 第一个1是误差项(是标量)输出维度,即残差的维度,第二个3是输入维度,即待寻优参数abc的维度
            new ceres::AutoDiffCostFunction <CURVE_FITTING_COST, 1, 3> (
                    new CURVE_FITTING_COST (x_data[i], y_data[i])
                    ),  //输入数据
                    nullptr,  //核函数,这里不使用,为空
                    abc //待估计参数
         );
    }

    //配置求解器
    ceres::Solver::Options options;
    //有很多配置向可以填
    options.linear_solver_type = ceres::DENSE_QR; //增量方程如何求解
    options.minimizer_progress_to_stdout = true;  //输出到cout

    ceres::Solver::Summary summary;                                    //优化信息
    chrono::steady_clock::time_point t1 = chrono::steady_clock::now(); //时间同步
    ceres::Solve(options, &problem, &summary);                         //开始优化
    chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
    chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2-t1);
    cout << "solve time cost = " << time_used.count() << endl;

    //输出结果
    cout << summary.BriefReport() << endl;
    cout << "estimate a, b, c = ";
    for (auto a:abc)
        cout << a << " ";
    cout << endl;

    return 0;
}
           

继续阅读