在之前的K-Means算法中,有兩大缺陷:
(1)K值是事先選好的固定的值
(2)随機種子選取可能對結果有影響
針對缺陷(2),我們提出了K-Means++算法,它使得随機種子選取非常合理,進而使得算法更加完美。但是缺
陷(1)始終沒有解決,也就是說在K-Means算法中K值得選取是事先選好固定的一個值,當時也提出ISODATA算
法可以找到合适的K,現在就來詳細講述ISODATA算法的原理,并會給出C++代碼。
Contents
1. ISODATA算法的認識
2. ISODATA的參數介紹
3. ISODATA的C++實作
1. ISODATA算法的認識
ISODATA算法全稱為Iterative
Self Organizing Data Analysis Techniques Algorithm,即疊代
自組織資料分析方法。ISODATA算法通過設定初始參數而引入人機對話環節,并使用歸并和分裂等機制,當兩類
聚中心小于某個閥值時,将它們合并為一類。當某類的标準差大于某一閥值時或其樣本數目超過某一閥值時,将其
分裂為兩類,在某類樣本數目小于某一閥值時,将其取消。這樣根據初始類聚中心和設定的類别數目等參數疊代,
最終得到一個比較理想的分類結果。ISODATA算法是一種常用的聚類分析方法,是一種非監督學習方法。
2. ISODATA的參數介紹
上面介紹了ISODATA算法的大緻原理,在ISODATA算法中有6個重要的參數。
expClusters預期的類聚中心數
thetaN 一個類别至少應該具有的樣本數目,小于此數目就不作為一個獨立的聚類
thetaS 一個類别樣本的标準差閥值
thetaC類聚中心之間距離的閥值,即歸并系數,若小于此數,則兩個類進行合并
maxIts允許疊代的最多次數
combL在一次疊代中可以歸并的類别的最多對數
有了如上參數,接下來就開始進行疊代了。
3. ISODATA的C++實作
ISODATA算法的詳細步驟可以參考如下代碼
#include
#include
#include
#include
#include
#include
#include
#define iniClusters 5 //初始類聚的個數
using namespace std;
//定義6個使用的參數
struct Args
{
int expClusters; //期望得到的聚類數
int thetaN; //聚類中最少樣本數
int maxIts; //最大疊代次數
int combL; //每次疊代允許合并的最大聚類對數
double thetaS; //标準偏差參數
double thetaC; //合并參數
}args;
//定義二維點,這裡假設是二維的特征,當然可以推廣到多元
struct Point
{
double x, y;
};
//需要合并的兩個類聚的資訊,包括兩個類聚的id和距離
struct MergeInfo
{
int u, v;
double d; //類聚u中心與類聚v中心的距離
};
//定義比較函數
bool cmp(MergeInfo a, MergeInfo b)
{
return a.d < b.d;
}
//計算兩點之間距離
double dist(Point A, Point B)
{
return sqrt((A.x - B.x) * (A.x - B.x) + (A.y - B.y) * (A.y - B.y));
}
struct Cluster
{
int nSamples; //樣本點的個數
double avgDist; //樣本點到樣本中心的平均距離
Point center; //樣本中心
Point sigma; //樣本與中心的标準差
vector data; //聚類的資料
//計算該聚類的中心,即該類的均值
void calMean()
{
assert(nSamples == data.size());
for(int i = 0; i < nSamples; i++)
{
center.x += data.at(i)->x;
center.y += data.at(i)->y;
}
center.x /= nSamples;
center.y /= nSamples;
}
//計算該類樣本點到該聚類中心得平均距離
void calDist()
{
avgDist = 0;
for(int i = 0; i < nSamples; i++)
avgDist += dist(*(data.at(i)), center);
avgDist /= nSamples;
}
//計算樣本與中心的标準差
void calStErr()
{
assert(nSamples == data.size());
double attr1 = 0;
double attr2 = 0; //樣本的兩個次元
for(int i = 0; i < nSamples; i++)
{
attr1 += (data.at(i)->x - center.x) * (data.at(i)->x - center.x);
attr2 += (data.at(i)->y - center.y) * (data.at(i)->y - center.y);
}
sigma.x = sqrt(attr1 / nSamples);
sigma.y = sqrt(attr2 / nSamples);
}
};
//擷取資料
void getData(Point p[], int n)
{
cout << "getting data..." << endl;
for(int i = 0; i < n; i++)
scanf("%lf %lf", &p[i].x, &p[i].y);
cout << "get data done!" << endl;
}
//設定參數的值
void setArgs()
{
args.expClusters = 5;
args.thetaN = 3;
args.maxIts = 10000;
args.combL = 10;
args.thetaS = 3;
args.thetaC = 0.001;
}
//尋找點t距離最近的類的中心對應的id
int FindIdx(vector &c, Point &t)
{
int nClusters = c.size();
assert(nClusters >= 1);
double ans = dist(c.at(0).center, t);
int idx = 0;
for(int i = 1; i < nClusters; i++)
{
double tmp = dist(c.at(i).center, t);
if(ans > tmp)
{
idx = i;
ans = tmp;
}
}
return idx;
}
//二分法尋找距離剛好小于thetaC的兩個類聚的index
int FindPos(MergeInfo *info, int n, double thetaC)
{
int l = 0;
int r = n - 1;
while(l <= r)
{
int mid = (l + r) >> 1;
if(info[mid].d < thetaC)
{
l = mid + 1;
if(l < n && info[l].d >= thetaC)
return mid;
}
else
{
r = mid - 1;
if(r >= 0 && info[r].d < thetaC)
return r;
}
}
if(info[n - 1].d < thetaC)
return n - 1;
else
return -1;
}
void Print(const vector c)
{
int n = c.size();
for(int i = 0; i < n; i++)
{
cout << "------------------------------------" << endl;
cout << "第" << i + 1 << "個聚類是:" << endl;
for(int j = 0; j < c.at(i).data.size(); j++)
cout << "(" << c[i].data[j]->x << "," << c[i].data[j]->y << ") ";
cout << endl;
cout << endl;
}
}
void ISOData(Point p[], int n)
{
cout << "ISOData is processing......." << endl;
vector c; //每個類聚的資料
const double split = 0.5; //分裂常數(0,1]
int nClusters = iniClusters; //初始化類聚個數
//初始化nClusters個類,設定相關資料
for(int i = 0; i < nClusters; i++)
{
Cluster t;
t.center = p[i];
t.nSamples = 0;
t.avgDist = 0;
c.push_back(t);
}
int iter = 0;
bool isLess = false; //标志是否有類的數目低于thetaN
while(1)
{
//先清空每一個聚類
for(int i = 0; i < nClusters; i++)
{
c.at(i).nSamples = 0;
c.at(i).data.clear();
}
//将所有樣本劃分到距離類聚中心最近的類中
for(int i = 0; i < n; i++)
{
int idx = FindIdx(c, p[i]);
c.at(idx).data.push_back(&p[i]);
c.at(idx).nSamples++;
}
int k = 0; //記錄樣本數目低于thetaN的類的index
for(int i = 0; i < nClusters; i++)
{
if(c.at(i).data.size() < args.thetaN)
{
isLess = true; //說明樣本數過少,該類應該删除
k = i;
break;
}
}
//如果有類的樣本數目小于thetaN
if(isLess)
{
nClusters--;
Cluster t = c.at(k);
vector::iterator pos = c.begin() + k;
c.erase(pos);
assert(nClusters == c.size());
for(int i = 0; i < t.data.size(); i++)
{
int idx = FindIdx(c, *(t.data.at(i)));
c.at(idx).data.push_back(t.data.at(i));
c.at(idx).nSamples++;
}
isLess = false;
}
//重新計算均值和樣本到類聚中心的平均距離
for(int i = 0; i < nClusters; i++)
{
c.at(i).calMean();
c.at(i).calDist();
}
//計算總的平均距離
double totalAvgDist = 0;
for(int i = 0; i < nClusters; i++)
totalAvgDist += c.at(i).avgDist * c.at(i).nSamples;
totalAvgDist /= n;
if(iter >= args.maxIts) break;
//分裂操作
if(nClusters <= args.expClusters / 2)
{
vector maxsigma;
for(int i = 0; i < nClusters; i++)
{
//計算該類的标準偏差
c.at(i).calStErr();
//計算該類标準差的最大分量
double mt = c.at(i).sigma.x > c.at(i).sigma.y? c.at(i).sigma.x : c.at(i).sigma.y;
maxsigma.push_back(mt);
}
for(int i = 0; i < nClusters; i++)
{
if(maxsigma.at(i) > args.thetaS)
{
if((c.at(i).avgDist > totalAvgDist && c.at(i).nSamples > 2 * (args.thetaN + 1)) || (nClusters < args.expClusters / 2))
{
nClusters++;
Cluster newCtr; //新的聚類中心
//擷取新的中心
newCtr.center.x = c.at(i).center.x - split * c.at(i).sigma.x;
newCtr.center.y = c.at(i).center.y - split * c.at(i).sigma.y;
c.push_back(newCtr);
//改變老的中心
c.at(i).center.x = c.at(i).center.x + split * c.at(i).sigma.x;
c.at(i).center.y = c.at(i).center.y + split * c.at(i).sigma.y;
break;
}
}
}
}
//合并操作
if(nClusters >= 2 * args.expClusters || (iter & 1) == 0)
{
int size = nClusters * (nClusters - 1);
//需要合并的聚類個數
int cnt = 0;
MergeInfo *info = new MergeInfo[size];
for(int i = 0; i < nClusters; i++)
{
for(int j = i + 1; j < nClusters; j++)
{
info[cnt].u = i;
info[cnt].v = j;
info[cnt].d = dist(c.at(i).center, c.at(j).center);
cnt++;
}
}
//進行排序
sort(info, info + cnt, cmp);
//找出info數組中距離剛好小于thetaC的index,那麼index更小的更應該合并
int iPos = FindPos(info, cnt, args.thetaC);
//用于訓示該位置的樣本點是否已經合并
bool *flag = new bool[nClusters];
memset(flag, false, sizeof(bool) * nClusters);
//用于标記該位置的樣本點是否已經合并删除
bool *del = new bool[nClusters];
memset(del, false, sizeof(bool) * nClusters);
//記錄合并的次數
int nTimes = 0;
for(int i = 0; i <= iPos; i++)
{
int u = info[i].u;
int v = info[i].v;
//確定同一個類聚隻合并一次
if(!flag[u] && !flag[v])
{
nTimes++;
//如果一次疊代中合并對數多于combL,則停止合并
if(nTimes > args.combL) break;
//将數目少的樣本合并到數目多的樣本中
if(c.at(u).nSamples < c.at(v).nSamples)
{
del[u] = true;
Cluster t = c.at(u);
assert(t.nSamples == t.data.size());
for(int j = 0; j < t.nSamples; j++)
c.at(v).data.push_back(t.data.at(j));
c.at(v).center.x = c.at(v).center.x * c.at(v).nSamples + t.nSamples * t.center.x;
c.at(v).center.y = c.at(v).center.y * c.at(v).nSamples + t.nSamples * t.center.y;
c.at(v).nSamples += t.nSamples;
c.at(v).center.x /= c.at(v).nSamples;
c.at(v).center.y /= c.at(v).nSamples;
}
else
{
del[v] = true;
Cluster t = c.at(v);
assert(t.nSamples == t.data.size());
for(int j = 0; j < t.nSamples; j++)
c.at(u).data.push_back(t.data.at(j));
c.at(u).center.x = c.at(u).center.x * c.at(u).nSamples + t.nSamples * t.center.x;
c.at(u).center.y = c.at(u).center.y * c.at(u).nSamples + t.nSamples * t.center.y;
c.at(u).nSamples += t.nSamples;
c.at(u).center.x /= c.at(u).nSamples;
c.at(u).center.y /= c.at(u).nSamples;
}
}
}
//删除合并後的聚類
vector::iterator id = c.begin();
for(int i = 0; i < nClusters; i++)
{
if(del[i])
id = c.erase(id);
else
id++;
}
//合并多少次就删除多少個
nClusters -= nTimes;
assert(nClusters == c.size());
delete[] info;
delete[] flag;
delete[] del;
info = NULL;
flag = NULL;
del = NULL;
}
if(iter >= args.maxIts) break;
iter++;
}
assert(nClusters == c.size());
Print(c);
}
int main()
{
int n;
scanf("%d", &n);
Point *p = new Point[n];
getData(p, n);
setArgs();
ISOData(p, n);
delete[] p;
p = NULL;
return 0;
}
還是用上次K-Means中的測試資料,如下
15
0 0
1 0
0 1
-1 0
0 -1
10 0
11 0
9 0
10 1
10 -1
-10 0
-11 0
-9 0
-10 1
-10 -1
輸入資料後得到如下結果
可以看出設定适當的參數後,得到的結果比較理想。