均值漂移聚类Mean Shift Cluster用于数据聚类,代码注释详细,适合新手学习~
均值漂移聚类这名字听起来挺学术,但说白了就是个"随大流"的算法。它的核心思想特有意思:想象你站在一片丘陵地带,蒙着眼睛往四周扔石子,听到石子落地声最大的方向就迈一步,重复这操作直到爬上山顶——这就是密度最高的地方。搞机器学习的人把这过程叫做"密度梯度上升",咱们普通人理解成"哪人多往哪挤"就行了。
先看个实际场景:假设我要把淘宝买家的消费行为数据分成不同群体。咱们不整那些虚的,直接上代码:
from sklearn.datasets import make_blobs import matplotlib.pyplot as plt # 生成3簇数据,特意加了些重叠区域更真实 X, _ = make_blobs(n_samples=500, centers=3, cluster_std=1.8, random_state=11) # 可视化原始数据分布 plt.scatter(X[:,0], X[:,1], s=15, edgecolor='k') plt.title('原始数据分布') plt.show()!散点图显示三个有部分重叠的簇
这段代码的关键在makeblobs,它像捏面团似的给我们生成测试数据。clusterstd参数控制着数据点的分散程度,调大这个值会让各簇之间产生更多重叠,更接近真实场景中的数据分布。
接下来上主菜:
from sklearn.cluster import MeanShift # 创建模型时要注意带宽选择,相当于决定搜索范围的大小 # 这里让算法自动估算,实际项目需要交叉验证 ms = MeanShift(bandwidth=2.5, bin_seeding=True) ms.fit(X) # 提取聚类结果 labels = ms.labels_ cluster_centers = ms.cluster_centers_ # 可视化魔法 plt.scatter(X[:,0], X[:,1], c=labels, s=15, edgecolor='k') plt.scatter(cluster_centers[:,0], cluster_centers[:,1], c='red', s=250, marker='X', edgecolor='w') plt.title('聚类结果(红叉为簇中心)') plt.show()!显示三个被正确划分的簇,中心标记明显
重点说下bandwidth参数,这相当于算法中的"社交距离"。设置太大容易把不同群体合并,太小又会产生过多小群体。新手可以试着把这个值从1改到5,观察聚类结果的变化,比看论文管用多了。
这个算法有个特点,不用预先指定簇的数量。它像扫地机器人似的在数据空间里游走,自动发现高密度区域。但代价就是计算量较大,当数据量超过十万级时可能得换其他方法。
遇到异常值怎么办?咱们举个极端例子:
# 故意添加离群点 import numpy as np outliers = np.array([[15, 15], [-10, -5], [20, -8]]) X_with_outliers = np.vstack([X, outliers]) # 重新训练模型 ms_out = MeanShift(bandwidth=3).fit(X_with_outliers) # 可视化时用不同符号标记离群点 labels_out = ms_out.labels_ plt.scatter(X_with_outliers[:,0], X_with_outliers[:,1], c=labels_out, s=15, edgecolor='k', alpha=0.7) plt.scatter(ms_out.cluster_centers_[:,0], ms_out.cluster_centers_[:,1], c='red', s=200, marker='X') plt.title('存在离群点时的聚类效果') plt.show()!显示三个主簇外存在孤立点
这时候会发现那些孤立的点要么自成一组,要么被忽略——具体取决于带宽设置。这种特性在欺诈检测场景中反而成了优点,可以把异常交易自动筛出来。
最后给新人提个醒:虽然sklearn的实现很方便,但处理大规模数据时记得开启cluster_all=False参数,否则内存可能爆炸。工业级应用通常会结合空间索引进行优化,不过那是进阶玩法了。下次遇到形状不规则的客户分群需求,不妨先试试这个"随大流"的方法,说不定比K-means更管用。