天天看點

簡述isodata算法的原理_【機器學習】聚類算法:ISODATA算法

在之前的K-Means算法中,有兩大缺陷:

(1)K值是事先選好的固定的值

(2)随機種子選取可能對結果有影響

針對缺陷(2),我們提出了K-Means++算法,它使得随機種子選取非常合理,進而使得算法更加完美。但是缺

陷(1)始終沒有解決,也就是說在K-Means算法中K值得選取是事先選好固定的一個值,當時也提出ISODATA算

法可以找到合适的K,現在就來詳細講述ISODATA算法的原理,并會給出C++代碼。

Contents

1. ISODATA算法的認識

2. ISODATA的參數介紹

3. ISODATA的C++實作

1. ISODATA算法的認識

ISODATA算法全稱為Iterative

Self Organizing Data Analysis Techniques Algorithm,即疊代

自組織資料分析方法。ISODATA算法通過設定初始參數而引入人機對話環節,并使用歸并和分裂等機制,當兩類

聚中心小于某個閥值時,将它們合并為一類。當某類的标準差大于某一閥值時或其樣本數目超過某一閥值時,将其

分裂為兩類,在某類樣本數目小于某一閥值時,将其取消。這樣根據初始類聚中心和設定的類别數目等參數疊代,

最終得到一個比較理想的分類結果。ISODATA算法是一種常用的聚類分析方法,是一種非監督學習方法。

2. ISODATA的參數介紹

上面介紹了ISODATA算法的大緻原理,在ISODATA算法中有6個重要的參數。

expClusters預期的類聚中心數

thetaN     一個類别至少應該具有的樣本數目,小于此數目就不作為一個獨立的聚類

thetaS       一個類别樣本的标準差閥值

thetaC類聚中心之間距離的閥值,即歸并系數,若小于此數,則兩個類進行合并

maxIts允許疊代的最多次數

combL在一次疊代中可以歸并的類别的最多對數

有了如上參數,接下來就開始進行疊代了。

3. ISODATA的C++實作

ISODATA算法的詳細步驟可以參考如下代碼

#include

#include

#include

#include

#include

#include

#include

#define iniClusters 5 //初始類聚的個數

using namespace std;

//定義6個使用的參數

struct Args

{

int expClusters; //期望得到的聚類數

int thetaN; //聚類中最少樣本數

int maxIts; //最大疊代次數

int combL; //每次疊代允許合并的最大聚類對數

double thetaS; //标準偏差參數

double thetaC; //合并參數

}args;

//定義二維點,這裡假設是二維的特征,當然可以推廣到多元

struct Point

{

double x, y;

};

//需要合并的兩個類聚的資訊,包括兩個類聚的id和距離

struct MergeInfo

{

int u, v;

double d; //類聚u中心與類聚v中心的距離

};

//定義比較函數

bool cmp(MergeInfo a, MergeInfo b)

{

return a.d < b.d;

}

//計算兩點之間距離

double dist(Point A, Point B)

{

return sqrt((A.x - B.x) * (A.x - B.x) + (A.y - B.y) * (A.y - B.y));

}

struct Cluster

{

int nSamples; //樣本點的個數

double avgDist; //樣本點到樣本中心的平均距離

Point center; //樣本中心

Point sigma; //樣本與中心的标準差

vector data; //聚類的資料

//計算該聚類的中心,即該類的均值

void calMean()

{

assert(nSamples == data.size());

for(int i = 0; i < nSamples; i++)

{

center.x += data.at(i)->x;

center.y += data.at(i)->y;

}

center.x /= nSamples;

center.y /= nSamples;

}

//計算該類樣本點到該聚類中心得平均距離

void calDist()

{

avgDist = 0;

for(int i = 0; i < nSamples; i++)

avgDist += dist(*(data.at(i)), center);

avgDist /= nSamples;

}

//計算樣本與中心的标準差

void calStErr()

{

assert(nSamples == data.size());

double attr1 = 0;

double attr2 = 0; //樣本的兩個次元

for(int i = 0; i < nSamples; i++)

{

attr1 += (data.at(i)->x - center.x) * (data.at(i)->x - center.x);

attr2 += (data.at(i)->y - center.y) * (data.at(i)->y - center.y);

}

sigma.x = sqrt(attr1 / nSamples);

sigma.y = sqrt(attr2 / nSamples);

}

};

//擷取資料

void getData(Point p[], int n)

{

cout << "getting data..." << endl;

for(int i = 0; i < n; i++)

scanf("%lf %lf", &p[i].x, &p[i].y);

cout << "get data done!" << endl;

}

//設定參數的值

void setArgs()

{

args.expClusters = 5;

args.thetaN = 3;

args.maxIts = 10000;

args.combL = 10;

args.thetaS = 3;

args.thetaC = 0.001;

}

//尋找點t距離最近的類的中心對應的id

int FindIdx(vector &c, Point &t)

{

int nClusters = c.size();

assert(nClusters >= 1);

double ans = dist(c.at(0).center, t);

int idx = 0;

for(int i = 1; i < nClusters; i++)

{

double tmp = dist(c.at(i).center, t);

if(ans > tmp)

{

idx = i;

ans = tmp;

}

}

return idx;

}

//二分法尋找距離剛好小于thetaC的兩個類聚的index

int FindPos(MergeInfo *info, int n, double thetaC)

{

int l = 0;

int r = n - 1;

while(l <= r)

{

int mid = (l + r) >> 1;

if(info[mid].d < thetaC)

{

l = mid + 1;

if(l < n && info[l].d >= thetaC)

return mid;

}

else

{

r = mid - 1;

if(r >= 0 && info[r].d < thetaC)

return r;

}

}

if(info[n - 1].d < thetaC)

return n - 1;

else

return -1;

}

void Print(const vector c)

{

int n = c.size();

for(int i = 0; i < n; i++)

{

cout << "------------------------------------" << endl;

cout << "第" << i + 1 << "個聚類是:" << endl;

for(int j = 0; j < c.at(i).data.size(); j++)

cout << "(" << c[i].data[j]->x << "," << c[i].data[j]->y << ") ";

cout << endl;

cout << endl;

}

}

void ISOData(Point p[], int n)

{

cout << "ISOData is processing......." << endl;

vector c; //每個類聚的資料

const double split = 0.5; //分裂常數(0,1]

int nClusters = iniClusters; //初始化類聚個數

//初始化nClusters個類,設定相關資料

for(int i = 0; i < nClusters; i++)

{

Cluster t;

t.center = p[i];

t.nSamples = 0;

t.avgDist = 0;

c.push_back(t);

}

int iter = 0;

bool isLess = false; //标志是否有類的數目低于thetaN

while(1)

{

//先清空每一個聚類

for(int i = 0; i < nClusters; i++)

{

c.at(i).nSamples = 0;

c.at(i).data.clear();

}

//将所有樣本劃分到距離類聚中心最近的類中

for(int i = 0; i < n; i++)

{

int idx = FindIdx(c, p[i]);

c.at(idx).data.push_back(&p[i]);

c.at(idx).nSamples++;

}

int k = 0; //記錄樣本數目低于thetaN的類的index

for(int i = 0; i < nClusters; i++)

{

if(c.at(i).data.size() < args.thetaN)

{

isLess = true; //說明樣本數過少,該類應該删除

k = i;

break;

}

}

//如果有類的樣本數目小于thetaN

if(isLess)

{

nClusters--;

Cluster t = c.at(k);

vector::iterator pos = c.begin() + k;

c.erase(pos);

assert(nClusters == c.size());

for(int i = 0; i < t.data.size(); i++)

{

int idx = FindIdx(c, *(t.data.at(i)));

c.at(idx).data.push_back(t.data.at(i));

c.at(idx).nSamples++;

}

isLess = false;

}

//重新計算均值和樣本到類聚中心的平均距離

for(int i = 0; i < nClusters; i++)

{

c.at(i).calMean();

c.at(i).calDist();

}

//計算總的平均距離

double totalAvgDist = 0;

for(int i = 0; i < nClusters; i++)

totalAvgDist += c.at(i).avgDist * c.at(i).nSamples;

totalAvgDist /= n;

if(iter >= args.maxIts) break;

//分裂操作

if(nClusters <= args.expClusters / 2)

{

vector maxsigma;

for(int i = 0; i < nClusters; i++)

{

//計算該類的标準偏差

c.at(i).calStErr();

//計算該類标準差的最大分量

double mt = c.at(i).sigma.x > c.at(i).sigma.y? c.at(i).sigma.x : c.at(i).sigma.y;

maxsigma.push_back(mt);

}

for(int i = 0; i < nClusters; i++)

{

if(maxsigma.at(i) > args.thetaS)

{

if((c.at(i).avgDist > totalAvgDist && c.at(i).nSamples > 2 * (args.thetaN + 1)) || (nClusters < args.expClusters / 2))

{

nClusters++;

Cluster newCtr; //新的聚類中心

//擷取新的中心

newCtr.center.x = c.at(i).center.x - split * c.at(i).sigma.x;

newCtr.center.y = c.at(i).center.y - split * c.at(i).sigma.y;

c.push_back(newCtr);

//改變老的中心

c.at(i).center.x = c.at(i).center.x + split * c.at(i).sigma.x;

c.at(i).center.y = c.at(i).center.y + split * c.at(i).sigma.y;

break;

}

}

}

}

//合并操作

if(nClusters >= 2 * args.expClusters || (iter & 1) == 0)

{

int size = nClusters * (nClusters - 1);

//需要合并的聚類個數

int cnt = 0;

MergeInfo *info = new MergeInfo[size];

for(int i = 0; i < nClusters; i++)

{

for(int j = i + 1; j < nClusters; j++)

{

info[cnt].u = i;

info[cnt].v = j;

info[cnt].d = dist(c.at(i).center, c.at(j).center);

cnt++;

}

}

//進行排序

sort(info, info + cnt, cmp);

//找出info數組中距離剛好小于thetaC的index,那麼index更小的更應該合并

int iPos = FindPos(info, cnt, args.thetaC);

//用于訓示該位置的樣本點是否已經合并

bool *flag = new bool[nClusters];

memset(flag, false, sizeof(bool) * nClusters);

//用于标記該位置的樣本點是否已經合并删除

bool *del = new bool[nClusters];

memset(del, false, sizeof(bool) * nClusters);

//記錄合并的次數

int nTimes = 0;

for(int i = 0; i <= iPos; i++)

{

int u = info[i].u;

int v = info[i].v;

//確定同一個類聚隻合并一次

if(!flag[u] && !flag[v])

{

nTimes++;

//如果一次疊代中合并對數多于combL,則停止合并

if(nTimes > args.combL) break;

//将數目少的樣本合并到數目多的樣本中

if(c.at(u).nSamples < c.at(v).nSamples)

{

del[u] = true;

Cluster t = c.at(u);

assert(t.nSamples == t.data.size());

for(int j = 0; j < t.nSamples; j++)

c.at(v).data.push_back(t.data.at(j));

c.at(v).center.x = c.at(v).center.x * c.at(v).nSamples + t.nSamples * t.center.x;

c.at(v).center.y = c.at(v).center.y * c.at(v).nSamples + t.nSamples * t.center.y;

c.at(v).nSamples += t.nSamples;

c.at(v).center.x /= c.at(v).nSamples;

c.at(v).center.y /= c.at(v).nSamples;

}

else

{

del[v] = true;

Cluster t = c.at(v);

assert(t.nSamples == t.data.size());

for(int j = 0; j < t.nSamples; j++)

c.at(u).data.push_back(t.data.at(j));

c.at(u).center.x = c.at(u).center.x * c.at(u).nSamples + t.nSamples * t.center.x;

c.at(u).center.y = c.at(u).center.y * c.at(u).nSamples + t.nSamples * t.center.y;

c.at(u).nSamples += t.nSamples;

c.at(u).center.x /= c.at(u).nSamples;

c.at(u).center.y /= c.at(u).nSamples;

}

}

}

//删除合并後的聚類

vector::iterator id = c.begin();

for(int i = 0; i < nClusters; i++)

{

if(del[i])

id = c.erase(id);

else

id++;

}

//合并多少次就删除多少個

nClusters -= nTimes;

assert(nClusters == c.size());

delete[] info;

delete[] flag;

delete[] del;

info = NULL;

flag = NULL;

del = NULL;

}

if(iter >= args.maxIts) break;

iter++;

}

assert(nClusters == c.size());

Print(c);

}

int main()

{

int n;

scanf("%d", &n);

Point *p = new Point[n];

getData(p, n);

setArgs();

ISOData(p, n);

delete[] p;

p = NULL;

return 0;

}

還是用上次K-Means中的測試資料,如下

15

0 0

1 0

0 1

-1 0

0 -1

10 0

11 0

9 0

10 1

10 -1

-10 0

-11 0

-9 0

-10 1

-10 -1

輸入資料後得到如下結果

簡述isodata算法的原理_【機器學習】聚類算法:ISODATA算法

可以看出設定适當的參數後,得到的結果比較理想。