数据分析大佬用Python代码教会你Mean Shift聚类
MeanShift算法可以称之为均值漂移聚类,是基于聚类中心的聚类算法,但和k-means聚类不同的是,不需要提前设定类别的个数k。在MeanShift算法中聚类中心是通过一定范围内样本密度来确定的,通过不断更新聚类中心,直到最终的聚类中心达到终止条件。整个过程可以看下图,我觉得还是比较形象的。
MeanShift向量
MeanShift向量是指对于样本X1,在以样本点X1为中心,半径为h的高维球区域内的所有样本点X的加权平均值,如下所示,同时也是样本点X1更新后的坐标。
而终止条件则是指 | Mh(X) - X |<ε,满足条件则样本点X1停止更新,否则将以Mh(X)为新的样本中心重复上述步骤。
核函数
核函数在机器学习(SVM,LR)中出现的频率是非常高的,你可以把它看做是一种映射,是计算映射到高维空间之后的内积的一种简便方法。在这个算法中将使用高斯核,其函数形式如下。
h表示带宽,当带宽h一定时,两个样本点距离越近,其核函数值越大;当两个样本点距离一定时,h越大,核函数值越小。核函数代码如下,gaosi_value为以样本点X1为中心,半径为h的高维球范围内所有样本点与X1的高斯核函数值,是一个(m,1)的矩阵。
def gaussian_kernel(self,distant): m=shape(distant)[1]#样本数 gaosi=mat(zeros((m,1))) for i in range(m): gaosi[i][0]=(distant.tolist()[0][i]*distant.tolist()[0][i]*(-0.5)/(self.bandwidth*self.bandwidth)) gaosi[i][0]=exp(gaosi[i][0]) q=1/(sqrt(2*pi)*self.bandwidth) gaosi_value=q*gaosi return gaosi_value
MeanShift向量与核函数
在01中有提到MeanShift向量是指对于样本X1,在以样本点X1为中心,半径为h的高维球区域内的所有样本点X的加权平均值。但事实上是不同点对于样本X1的贡献程度是不一样的,因此将权值(1/k)更改为每个样本与样本点X1的核函数值。改进后的MeanShift向量如下所示。
其中
就是指高斯核函数,Sh表示在半径h内的所有样本点集合。
MeanShift算法原理
在MeanShift算法中实际上利用了概率密度,求得概率密度的局部最优解。
对于一个概率密度函数f(x),已知一个概率密度函数f(X),其核密度估计为
其中K(X)是单位核,概率密度函数f(X)的梯度估计为
其中G(X)=-K'(X)。第一个中括号是以G(X)为核函数对概率密度的估计,第二个中括号是MeanShift 向量。因此MeanShift向量是与概率密度函数的梯度成正比的,总是指向概率密度增加的方向。
而对于MeanShift向量,可以将其变形为下列形式,其中mh(x)为样本点X更新后的位置。
MeanShift算法流程
在未被标记的数据点中随机选择一个点作为起始中心点X;
找出以X为中心半径为radius的区域中出现的所有数据点,认为这些点同属于一个聚类C。同时在该聚类中记录数据点出现的次数加1。
以X为中心点,计算从X开始到集合M中每个元素的向量,将这些向量相加,得到向量Mh(X)。
mh(x) =Mh(X) + X。即X沿着Mh(X)的方向移动,移动距离是||Mh(X)||。
重复步骤2、3、4,直到Mh(X)的很小(就是迭代到收敛),记住此时的X。注意,这个迭代过程中遇到的点都应该归类到簇C。
如果收敛时当前簇C的center与其它已经存在的簇C2中心的距离小于阈值,那么把C2和C合并,数据点出现次数也对应合并。否则,把C作为新的聚类。
重复1、2、3、4、5直到所有的点都被标记为已访问。
分类:根据每个类,对每个点的访问频率,取访问频率最大的那个类,作为当前点集的所属类。
TIPS:每一个样本点都需要计算其漂移均值,并根据计算出的漂移均值进行移动,直至满足终止条件,最终得到的均值漂移点为该点的聚类中心点。
MeanShift算法代码
from numpy import *from matplotlib import pyplot as plt
class mean_shift(): def __init__(self): #带宽 self.bandwidth=2 #漂移点收敛条件 self.mindistance=0.001 #簇心距离,小于该值则两簇心合并 self.cudistance=2.5
def gaussian_kernel(self,distant): m=shape(distant)[1]#样本数 gaosi=mat(zeros((m,1))) for i in range(m): gaosi[i][0]=(distant.tolist()[0][i]*distant.tolist()[0][i]*(-0.5)/(self.bandwidth*self.bandwidth)) gaosi[i][0]=exp(gaosi[i][0]) q=1/(sqrt(2*pi)*self.bandwidth) gaosi_value=q*gaosi return gaosi_value
def load_data(self): X =array([ [-4, -3.5], [-3.5, -5], [-2.7, -4.5], [-2, -4.5], [-2.9, -2.9], [-0.4, -4.5], [-1.4, -2.5], [-1.6, -2], [-1.5, -1.3], [-0.5, -2.1], [-0.6, -1], [0, -1.6], [-2.8, -1], [-2.4, -0.6], [-3.5, 0], [-0.2, 4], [0.9, 1.8], [1, 2.2], [1.1, 2.8], [1.1, 3.4], [1, 4.5], [1.8, 0.3], [2.2, 1.3], [2.9, 0], [2.7, 1.2], [3, 3], [3.4, 2.8], [3, 5], [5.4, 1.2], [6.3, 2],[0,0],[0.2,0.2],[0.1, 0.1],[-4, -3.5]]) x,y=[],[] for i in range(shape(X)[0]): x.append(X[i][0]) y.append(X[i][1]) plt.scatter(x,y,c='r') # plt.plot(x, y) plt.show() classlable=mat(zeros((shape(X)[0],1))) return X,classlable
def distance(self,a,b): v=a-b return sqrt(v*mat(v).T).tolist()[0][0] def shift_point(self,point,data,clusterfrequency): sum=0 n=shape(data)[0] ou=mat(zeros((n,1))) t=mat(zeros((n,1))) newdata=[] for i in range(n): # print(self.distance(point,data[i])) d=self.distance(point,data[i]) if d<self.bandwidth: ou[i][0] =d t[i][0]=1 newdata.append(data[i]) clusterfrequency[i]=clusterfrequency[i]+1 gaosi=self.gaussian_kernel(ou[t==1]) meanshift=gaosi.T*mat(newdata) return meanshift/gaosi.sum(),clusterfrequency
def group2(self,dataset,clusters,m): data=[] fre=[] for i in clusters: i['data']=[] fre.append(i['frequnecy']) for j in range(m): n=where(array(fre)[:,j]==max(array(fre)[:,j]))[0][0] data.append(n) clusters[n]['data'].append(dataset[j]) print("一共有%d个簇心" % len(set(data))) # print(clusters) # print(data) return clusters
def plot(self,dataset,clust): colors = 10 * ['r', 'g', 'b', 'k', 'y','orange','purple'] plt.figure(figsize=(5, 5)) plt.xlim((-8, 8)) plt.ylim((-8, 8)) plt.scatter(dataset[:, 0],dataset[:, 1], s=20) theta = linspace(0, 2 * pi, 800) for i in range(len(clust)): cluster = clust[i] data = array(cluster['data']) if len(data): plt.scatter(data[:, 0], data[:, 1], color=colors[i], s=20) centroid =cluster['centroid'].tolist()[0] plt.scatter(centroid[0], centroid[1], color=colors[i], marker='x', s=30) x, y = cos(theta) * self.bandwidth + centroid[0], sin(theta) * self.bandwidth + centroid[1] plt.plot(x, y, linewidth=1, color=colors[i]) plt.show()
def mean_shift_train(self): dataset, classlable = self.load_data() m = shape(dataset)[0] clusters = [] for i in range(m): max_distance = inf cluster_centroid = dataset[i] # print(cluster_centroid) cluster_frequency =zeros((m,1)) while max_distance>self.mindistance: w,cluster_frequency = self.shift_point(cluster_centroid,dataset,cluster_frequency) dis = self.distance(cluster_centroid, w) if dis < max_distance: max_distance = dis # print(max_distance) cluster_centroid = w has_same_cluster = False for cluster in clusters: if self.distance(cluster['centroid'],cluster_centroid)<self.cudistance: cluster['frequnecy']=cluster['frequnecy']+cluster_frequency has_same_cluster=True break if not has_same_cluster: clusters.append({'frequnecy':cluster_frequency,'centroid':cluster_centroid}) clusters=self.group2(dataset,clusters,m) print(clusters) self.plot(dataset,clusters)
if __name__=="__main__": shift=mean_shift() shift.mean_shift_train()
得到的结果图如下。
之后还会详细解说K-means聚类以及DBSCAN聚类,敬请关注。
最新活动更多
-
即日-1.24立即参与>>> 【限时免费】安森美:Treo 平台带来出色的精密模拟
-
2月28日火热报名中>> 【免费试用】东集技术年终福利——免费试用活动
-
即日-3.21立即报名 >> 【深圳 IEAE】2025 消费新场景创新与实践论坛
-
4日10日立即报名>> OFweek 2025(第十四届)中国机器人产业大会
-
7.30-8.1火热报名中>> 全数会2025(第六届)机器人及智能工厂展
-
即日-2025.8.1立即下载>> 《2024智能制造产业高端化、智能化、绿色化发展蓝皮书》
推荐专题
发表评论
请输入评论内容...
请输入评论/评论长度6~500个字
暂无评论
暂无评论