#include <ceres/ceres.h>
#include <iostream>
#include <chrono>
#include <opencv2/core/core.hpp>
using namespace std;
//代价函数的计算模型
struct CURVE_FITTING_COST
{
// 仿函数,使用方式 CURVE_FITTING_COST(const T* const abc, T* residual)
// 传入的是真实值(x-x_data[i],y-y_data[i])
CURVE_FITTING_COST (double x, double y) : _x (x), _y (y) {}
//残差计算
//操作符()重载
template <typename T>
bool operator() (const T* const abc, T* residual) const //residual残差
{
//y-exp(ax^2+bx+c)
residual[0] = T(_y) - ceres::exp(abc[0]*T(_x)*T(_x)+abc[1]*T(_x)+abc[2]);
return true;
}
const double _x, _y; //x,y数据
};
int main(int argc, char** argv)
{
//a,b,c是曲线真实模型的参数值,要求一条与该曲线逼近的曲线的参数a,b,c估计值
double a = 1.0, b = 2.0, c = 1.0; //计算结果: estimate a, b, c = 0.891943 2.17039 0.944142
//double a = 1.0, b = 5.0, c = 2.0; // 计算结果: estimate a, b, c = 0.989318 5.0169 1.99377
int N = 100; //100个数据点
double w_sigma = 1.0; //噪声值
cv::RNG rng; //OpenCV随机数产生器
double abc[3] = {0.0, 0.0, 0.0}; //abc的参数估计值,都初始化为0.0
vector<double> x_data, y_data; //根据方程y-exp(ax^2+bx+c)+w 生成100个(x,y)数据,传入到x_data, y_data中。
//假设有一条满足该方程的曲线
cout << "generating data: " << endl;
for(int i = 0; i < N; i++)
{
double x = i/100.0;
//把x(取0~1中的100个x坐标)压入vector队列的x_data[]中
x_data.push_back(x);
//对应的找到与x对应的100个y,压入vector队列的y_data[]中
y_data.push_back(exp(a*x*x+b*x+c)+rng.gaussian(w_sigma));
cout << x_data[i] << " " << y_data[i] << endl;
}
//================================================
//构建最小二乘问题
ceres::Problem problem;
for(int i = 0; i < N; i++)
{
problem.AddResidualBlock(//向问题中添加误差项。
// 使用自动求导ceres::AutoDiffCostFunction,
// 将之前的代价函数结构体<CURVE_FITTING_COST, 1, 3>传入:
// 解释:
// 第一个1是误差项(是标量)输出维度,即残差的维度,第二个3是输入维度,即待寻优参数abc的维度
new ceres::AutoDiffCostFunction <CURVE_FITTING_COST, 1, 3> (
new CURVE_FITTING_COST (x_data[i], y_data[i])
), //输入数据
nullptr, //核函数,这里不使用,为空
abc //待估计参数
);
}
//配置求解器
ceres::Solver::Options options;
//有很多配置向可以填
options.linear_solver_type = ceres::DENSE_QR; //增量方程如何求解
options.minimizer_progress_to_stdout = true; //输出到cout
ceres::Solver::Summary summary; //优化信息
chrono::steady_clock::time_point t1 = chrono::steady_clock::now(); //时间同步
ceres::Solve(options, &problem, &summary); //开始优化
chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2-t1);
cout << "solve time cost = " << time_used.count() << endl;
//输出结果
cout << summary.BriefReport() << endl;
cout << "estimate a, b, c = ";
for (auto a:abc)
cout << a << " ";
cout << endl;
return 0;
}