天天看點

模式識别 - libsvm的函數調用方法 具體解釋

libsvm的函數調用方法 具體解釋

本文位址: http://blog.csdn.net/caroline_wendy/article/details/26261173

須要載入(load)SVM的模型, 然後将結點轉換為SVM的格式, 即索引(index)+資料(value)的形式;

釋放SVM的model有專用的函數: svm_free_and_destroy_model, 否則easy記憶體洩露;

能夠預測資料的機率, 則須要模型是機率模型, 傳回的是一個類别數組(2分類, 則為2個值的數組), 即各個标簽的機率值;

注意: 标簽即機率值較大的部分, 是以在訓練時, 應注意正負樣本的順序, 

正樣本在前, 下标0, 為正樣本的機率, 下标1, 為負樣本的機率; 反之亦然.

代碼:

/*! @file
********************************************************************************
<PRE>
子產品名 : 分類器
檔案名稱 : SvmClassifier.cpp
相關檔案 : SvmClassifier.h
檔案實作功能 : SVM分類器類實作
作者 : C.L.Wang
Email: [email protected]
版本号 : 1.0
--------------------------------------------------------------------------------
多線程安全性 : 是
異常時安全性 : 是
--------------------------------------------------------------------------------
備注 : 無
--------------------------------------------------------------------------------
改動記錄 :  
日 期              版本号   改動人         改動内容 
2014/03/27  1.0    C.L.Wang        Create
</PRE>
********************************************************************************

* 版權全部(c) C.L.Wang, 保留全部權利

*******************************************************************************/

#include "stdafx.h"

#include "SvmClassifier.h"

#include <opencv.hpp>

using namespace std;
using namespace cv;
using namespace vd;

const std::string SvmClassifier::NORM_NAME = "normalization.xml"; //歸一化模型
const std::string SvmClassifier::SVM_MODEL_NAME = "hvd.model"; //Svm模型
bool SvmClassifier::m_mutex = true; //互相排斥鎖

/*! @function
********************************************************************************
<PRE>
函數名 : SvmClassifier
功能 : 參數構造函數
參數 : 
const Mat& _videoFeature, 視訊特征; 
const string& _modelPath, 模型路徑;
傳回值 : 無
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : SvmClassifier iSF(_videoFeature, _modelPath);
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
SvmClassifier::SvmClassifier (
  const cv::Mat& _videoFeature, /*特征*/
  const std::string& _modelPath /*模型路徑*/
  ) :
  Classifier(_videoFeature, _modelPath),
  m_model(nullptr),
  m_node(nullptr)
{
  return;
}

/*! @function
********************************************************************************
<PRE>
函數名 : ~SvmClassifier
功能 : 析構函數
參數 : void
傳回值 : 無
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : iSC.~SvmClassifier();
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
SvmClassifier::~SvmClassifier (void)
{
  if (m_model != nullptr) {
    svm_free_and_destroy_model(&m_model);
  }

  if (m_node != nullptr) {
    delete[] m_node;
    m_node = nullptr;
  }

  return;
}

/*! @function
********************************************************************************
<PRE>
函數名 : calculateResult
功能 : 計算分類結果
參數 : void
傳回值 : const double, 分類結果
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : result = iSC.calculateResult();
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
const double SvmClassifier::calculateResult (void)
{
  double result(0.0);

  while(1) {
    if (m_mutex == true) 
    {
      m_mutex = false;
      _initModel();

      result = _predictValue();

      if (m_model != nullptr) {
        svm_free_and_destroy_model(&m_model);
      }

      if (m_node != nullptr) {
        delete[] m_node;
        m_node = nullptr;
      }
      m_mutex = true;
      break;
    }
  }

  return result;
}

/*! @function
********************************************************************************
<PRE>
函數名 : _predictValue
功能 : 預測值
參數 : void
傳回值 : const double, 預測值;
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : result = _predictValue();
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
const double SvmClassifier::_predictValue (void) const
{
  double label (0.0);
  double prop (0.0);
  const int nr_class (2);
  double* prob_estimates = (double *) malloc(nr_class*sizeof(double));

  label = svm_predict_probability(m_model, m_node, prob_estimates);
  prop = prob_estimates[0]; //傳回預測機率值

  delete[] prob_estimates;

  return prop;
}

/*! @function
********************************************************************************
<PRE>
函數名 : _initModel
功能 : 初始化模型
參數 : void
傳回值 : void
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : _initModel();
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
void SvmClassifier::_initModel (void)
{
  /*完整路徑*/

  std::string modelName (m_modelPath); //模型名稱
  std::string normName (m_modelPath); //歸一化名稱

  const std::string slash("/");

  modelName.append(slash);
  modelName.append(SVM_MODEL_NAME);

  normName.append(slash);
  normName.append(NORM_NAME);

  std::ifstream ifs;
  ifs.open(modelName, ios::in);
  if (ifs.fail()) {
    __printLog(std::cerr, "Failed to open the model file!");
  }
  ifs.close();

  ifs.open(normName, ios::in);
  if (ifs.fail()) {
    __printLog(std::cerr, "Failed to open the model file!");
  }
  ifs.close();

  if (m_model != nullptr) {
    svm_free_and_destroy_model(&m_model);
  }
  m_model = svm_load_model(modelName.c_str());

  __transSvmNode(normName);

  return;
}

/*! @function
********************************************************************************
<PRE>
函數名 : __transSvmNode
功能 : 轉換Svm結點
參數 : const string& normName, 歸一化模型路徑
傳回值 : void
抛出異常 : 無
--------------------------------------------------------------------------------
複雜度 : 無
備注 : 無
典型使用方法 : __transSvmNode(normName);
--------------------------------------------------------------------------------
作者 : C.L.Wang
</PRE>
*******************************************************************************/ 
void SvmClassifier::__transSvmNode (const std::string& _normName)
{
  cv::FileStorage fs(_normName, FileStorage::READ);
  cv::Mat maxNorm;
  fs["normalization"] >> maxNorm;
  fs.release(); 

  /*歸一化視訊特征*/

  cv::Mat normFeature = 
    cv::Mat::zeros(1, maxNorm.cols-2, CV_64FC1);

  for (int j=2; j<m_videoFeature.cols; ++j) {
    for(int i=0; i<m_videoFeature.rows; ++i) {
      normFeature.at<double>(0, j-2) += m_videoFeature.at<double>(i, j);
    }
  }

  for (int j=0; j<normFeature.cols; ++j)
  {
    normFeature.at<double>(0, j) /= m_videoFeature.rows;
    if (maxNorm.at<double>(0, j+2) > 0.0001)
      normFeature.at<double>(0, j) /= maxNorm.at<double>(0, j+2);
  }
  normFeature.at<double>(0,0) = 0.0;

  if (m_node != nullptr) {
    delete[] m_node;
    m_node = nullptr;
  }

  m_node = new svm_node[normFeature.cols];
  for (int j=1; j < normFeature.cols; ++j) {
    m_node[j-1].index = j;
    m_node[j-1].value = normFeature.at<double>(0, j);
  }

  m_node[normFeature.cols-1].index = -1;
  m_node[normFeature.cols-1].value = 0;

  return;
}