R:實作K-means

這篇文章來自於大四系上修的一堂課,主要的練習為用R實做出K-means的概念,那首先我們就得先知道,K-means的核心是甚麼,那就讓我們來看看吧

What is k-means?

K-means是一種機器學習的演算法,其本身屬於非監督式學習,而非監督式學習又是什麼呢?簡單來說,就是你丟給電腦的資料是沒有標記的,就讓演算法自己表現吧類似的概念,關於更詳細的內容之後有機會,在做一個詳細的整理,這裡就先大概帶過,而k-means本身的概念比較像是clustering,以成語來說就是「物以類聚」。

來自https://dotblogs.com.tw/dragon229/2013/02/04/89919

好,這張圖應該有很好的幫助你了解K-means了。沒錯,就是像你看到的一樣,我們拿到一筆資料後,透過K-means演算法,成功的將資料分成了三類,而當我們分成三類後,對們在後續說故事或分析才會更有幫助,若是我們只使用原始資料(左邊),可能就難以看出什麼端倪,但當我們將資料做好cluster後,我們就可以發現,喔~原來是有三種pattern的資料在裡面,也可以幫助我們更好的理解我們的資料,好的,事不宜遲,我們就來用R建立我們的K-means 演算法吧(基本上R裡已經有套件可以使用,這裡實際建置是為了瞭解其運作,並且可以有更多的練習。

我們需要準備什麼

1.給定一個cluster的中心:這點相當的重要,我們可以random尋找,畢竟在這過程中,K-means的中心點會慢慢的移動,直到滿足我們設定的條件,但我們會發現,第一個找的中心點,如果找得不好,最後分出來的群也會受到影響。

2.開始把資料分派到最靠近的中心,對了,K就是你想要幾個中心點的意思,K=3,代表你決定了三個中心點,最終會分成三群。

3.計算新群的新的中心點(取這個群中所有資料的平均)

4.重複 2~3的步驟直到收斂為止(就是達到你要的結果為止),那要怎麼收斂呢?接著就要來看一點數學啦

這個數學式子其實非常的好懂,X_i就是每個資料點,而U_j就是每個新群的中心點(也就是該資料群的mean),我們的目標就是找到,在該群中的資料點,到中心點的距離總和的最小值。 既然已經明白了,那我們就開始吧!

R實作部分

我們使用的資料長得像這樣,資料為txt檔,資料跟資料是
用Tab隔開的

資料點的分布大概長得像這個樣子,用R分常簡單直接的plot函數

接著我們將我們的kmeans寫成函式,Let’s Go,整段的程式碼大概像底下這樣

kmeansfun = function(x,k){
  tol=0.0001
  l=10
  iter=1
  dists=matrix(0,nrow = nrow(x),ncol=k)
  ini=sample(nrow(x),k,replace = FALSE) 
  center=matrix(0,ncol=ncol(x),nrow=k)
  center=x[ini,]#第一次的中點
  nsse=vector()
  grp=vector()
  sse=vector()
  while(l>tol){
    for(i in 1 : nrow(x) ){
      for(j in 1:k){
        dists[i,j]=sqrt(sum((x[i,]-center[j,])^2))
        
      }
    }
    
    #分組
    for(i in 1:nrow(x)){
      grp[i]=which.min(dists[i,])    
      
    }
    
    newcenter=matrix(0,ncol=ncol(x),nrow=k) #指定新的中心點為分完群的平均點
    
    for(i in 1:k){             
      newcenter[i , ] = apply(x[which(grp==i),],2,mean) 
      sse[i ] = sum((x[which(grp==i),]-newcenter[ i , ])^2)
    }
    
    
    nsse[iter]=sum(sse)
    l=sum(center-newcenter)^2
    center = newcenter
    iter=iter+1
    #此處為停止的條件
  }
  plot(x=c(2:iter-1),y=nsse,type='b',main='SSE圖隨著疊代的變化')
  return(list(grp=grp, center=center, sse=nsse[iter-1],iter=iter))
}

但現在讓我們來一段一段的拆解它吧!

一開始的寫法就是R的函式的寫法,tol為我們的容忍值,因為資料不管怎麼算因為公式的關係不會為0,因此我們就必須設定一個容忍值,當我們用公式算出來的值仍然大於容忍值時,我們就繼續指派新的群中心,直到算出小於容忍值的中心點,dists為我們拿來存放跟三個中心點的距離,以我們的資料而言 nrow=100,k=3,所以我們會建出一個100X3的matrix,而ini則是我們以R內建函數sample,從nrow(100)中,挑出k(3)個數字,此處是用位置的方式操作,可以對應到center=x[ini,]這一行,因此我們的步驟1:指派起始中點在這裡就默默完成了,剩下的三個變數是指派以讓後續使用。

kmeansfun = function(x,k){
  tol=0.0001
  l=10
  iter=1
  dists=matrix(0,nrow = nrow(x),ncol=k)
  ini=sample(nrow(x),k,replace = FALSE) 
  center=matrix(0,ncol=ncol(x),nrow=k)
  center=x[ini,]#第一次的中點
  nsse=vector()
  grp=vector()
  sse=vector()

我們在剛剛將l設為了10,而tol設定0.0001,因此這個while迴圈會成立,那讓我們來看看while迴圈裡有什麼吧

一開始,我們用了兩個for迴圈,而作的事情,就是把每個資料點到起始中心點的距離算出來,並存在dist裡,而我們計算距離的方法,是採用歐幾里德距離

算好距離之後,非常的直觀,就是選擇離自己最近的中心,加入他們,如此所有的資料就會被分好群(群的數量為你設定的k),這就是步驟2:分群

接著有新的群之後,我們在把新的群的data取平均值,並將其舉派為新的值(感覺就是,喔,我們現在是新團體,啊剛剛好你到我們每個人的距離總和是最近的,那你就是新的中心啦)。於是,我們又默默的完成了步驟3:找新的中心

while(l>tol){
    for(i in 1 : nrow(x) ){
      for(j in 1:k){
        dists[i,j]=sqrt(sum((x[i,]-center[j,])^2))
        
      }
    }
    
    #分組
    for(i in 1:nrow(x)){
      grp[i]=which.min(dists[i,])    
      
    }
    
    newcenter=matrix(0,ncol=ncol(x),nrow=k) #指定新的中心點為分完群的平均點
    

接著就是要把我們計算的新的中心放入我們的newcenter裡啦,這裡有使用一個工具apply,而許多使用R語言的前輩也會建議多使用apply族函數,而原因則是效率問題,因為apply族方法是用C語言寫的,如果操作的當,就能避免使用for迴圈(在R裡用迴圈非常的慢),之後有機會會在特別寫一篇整理apply族的筆記,apply的用法通常會有三個參數apply(data,margin,function),data放資料,margin=1表示按行計算,=2表示按列計算,function就是要套用的算式啦,所以這句code的解讀為,將被分到第i組的資料,按列計算平均。而SSE就是指Sum of Square Errors。

l這個變數,代表了舊的中心點和新找的中心點的距離平方,當它小於我們的容忍值時,代表它的變動不大,我們就可以停止我們的迴圈。到這邊為止,我們的k-means就完成啦。

 for(i in 1:k){             
      newcenter[i , ] = apply(x[which(grp==i),],2,mean) 
      sse[i ] = sum((x[which(grp==i),]-newcenter[ i , ])^2)
    }
    
    
    nsse[iter]=sum(sse)
    l=sum(center-newcenter)^2
    center = newcenter
    iter=iter+1
    #此處為停止的條件
  }
  plot(x=c(2:iter-1),y=nsse,type='b',main='SSE圖隨著疊代的變化')
  return(list(grp=grp, center=center, sse=nsse[iter-1],iter=iter))
}

K-means動起來

#k=3
kmeansfun(hw8_q2_data,3)

我的設計是讓他回傳四個值,第一個是每個資料點的分組,第二個是中心點,第三個是SSE,第四個是共迭代了幾次

我們可以看到SSE隨著迭代次數增加而降低,最後在一個區間震盪

我們的資料點被很好的分組啦!!!

#將值存到cr裡(grp,center...)
cr<-kmeansfun(hw8_q2_data,3)
#割成兩個視窗
par(mfrow=c(1,2))
plot(hw8_q2_data)
plot(hw8_q2_data,col=cr$grp)
points(cr$center,col=1:3,pch=16,cex=1.5)

結論:當你多按幾次的時候,會發現分組有點不太一樣,這是因為K-means會受到選擇的第一個起始點很大的影響,因此,也有許多玩家在討論如何找到最佳起始點,大家可以自己去試試

而k-means由於在算新的中點時,使用了mean,但我們也知道,mean會受到outlier很大的影響而失準;同樣的,也使得k-means會outlier相當的敏感。

簡單的實作,當然k-means還有很多其他的應用,等待發掘,之後有機會再繼續研究~

發表留言