天天看點

Python Logistic 回歸分類

Logistic回歸可以認為是線性回歸的延伸,其作用是對二分類樣本進行訓練,進而對達到預測新樣本分類的目的。

假設有一組已知分類的MxN維樣本X,M為樣本數,N為特征次元,其相應的已知分類标簽為Mx1維矩陣Y。那麼Logistic回歸的實作思路如下:

(1)用一組權重值W(Nx1)對X的特征進行線性變換,得到變換後的樣本X’(Mx1),其目标是使屬于不同分類的樣本X’存在一個明顯的一維邊界。

(2)然後再對樣本X’進一步做函數變換,進而使處于一維邊界兩測的值變換到相應的範圍之内。

(3)訓練過程就是通過改變W盡可能使得到的值位于一維邊界兩側,并且與已知分類相符。

(4)對于Logistic回歸,就是将原樣本的邊界變換到x=0這個邊界。

下面是Logistic回歸的典型代碼:

# -*- coding: utf-8 -*-

"""

Created on Wed Nov 09 15:21:48 2016

Logistic回歸分類

"""

import numpy  as np

class LogisticRegressionClassifier ( ):

     def  __init__ ( self ):

         self._alpha  =  None

     #定義一個sigmoid函數

     def _sigmoid ( self , fx ):

         return  1.0/ ( 1 + np. exp (-fx ) )

     #alpha為步長(學習率);maxCycles最大疊代次數

     def _gradDescent ( self , featData , labelData , alpha , maxCycles ):

        dataMat  = np. mat (featData )                       #size: m*n

        labelMat  = np. mat (labelData ). transpose ( )         #size: m*1

        m , n  = np. shape (dataMat )

        weigh  = np. ones ( (n ,  1 ) ) 

         for i  in  range (maxCycles ):

            hx  =  self._sigmoid (dataMat * weigh )

            error  = labelMat - hx        #size:m*1

            weigh  = weigh + alpha * dataMat. transpose ( ) * error #根據誤差修改回歸系數

         return weigh

     #使用梯度下降方法訓練模型,如果使用其它的尋參方法,此處可以做相應修改

     def fit ( self , train_x , train_y , alpha = 0.01 , maxCycles = 100 ):

         return  self._gradDescent (train_x , train_y , alpha , maxCycles )

     #使用學習得到的參數進行分類

     def predict ( self , test_X , test_y , weigh ):

        dataMat  = np. mat (test_X )

        labelMat  = np. mat (test_y ). transpose ( )   #使用transpose()轉置

        hx  =  self._sigmoid (dataMat*weigh )   #size:m*1

        m  =  len (hx )

        error  =  0.0

         for i  in  range (m ):

             if  int (hx [i ] )  >  0.5:

                 print  str (i+ 1 )+ '-th sample ' ,  int (labelMat [i ] ) ,  'is classfied as: 1' 

                 if  int (labelMat [i ] )  !=  1:

                    error + =  1.0

                     print  "classify error."

             else:

                 print  str (i+ 1 )+ '-th sample ' ,  int (labelMat [i ] ) ,  'is classfied as: 0' 

                 if  int (labelMat [i ] )  !=  0:

                    error + =  1.0

                     print  "classify error."

        error_rate  = error/m

         print  "error rate is:" ,  "%.4f" %error_rate

         return error_rate