scipy.cluster.vq.kmeans2¶
- scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, seed=None)[源代码]¶
使用k-均值算法将一组观测数据分类为k个簇。
该算法试图最小化观测值和质心之间的欧几里德距离。其中包括几种初始化方法。
- 参数
- datandarray
‘N’维的‘M’בN’观测的‘M’阵列或‘M’一维观测的长度‘M’阵列。
- k整数或ndarray
要形成的簇数以及要生成的质心数。如果 minit 初始化字符串是‘Matrix’,或者如果给出了ndarray,则会将其解释为要使用的初始簇。
- iter整型,可选
要运行的k-Means算法的迭代次数。请注意,这与iters参数和kmeans函数的含义不同。
- thresh浮动,可选
(尚未使用)
- minit字符串,可选
初始化的方法。可用的方法有‘随机’、‘点’、‘++’和‘矩阵’:
“随机”:根据数据估计的均值和方差,从高斯函数生成k个质心。
‘点’:从初始质心的数据中随机选择k个观测值(行)。
‘++’:根据kmeans++方法选择k个观测值(仔细播种)
‘Matrix’:将k参数解释为初始质心的k×M(或1-D数据的长度k数组)数组。
- missing字符串,可选
方法来处理空簇。可用方法有‘WARN’和‘RAISE’:
“WARN”:发出警告并继续。
‘raise’:引发ClusterError并终止算法。
- check_finite布尔值,可选
是否检查输入矩阵是否仅包含有限个数字。禁用可能会带来性能提升,但如果输入确实包含无穷大或NAN,则可能会导致问题(崩溃、非终止)。默认值:true
- seed :{无,整型,
numpy.random.Generator
,{无,整型, 用于初始化伪随机数生成器的种子。如果 seed 为无(或
numpy.random
)、numpy.random.RandomState
使用的是Singleton。如果 seed 是一个整型、一个新的RandomState
实例,其种子设定为 seed 。如果 seed 已经是一个Generator
或RandomState
实例,则使用该实例。默认值为None。
- 退货
- centroidndarray
在上次迭代k-Means时找到的‘k’בN’质心数组。
- labelndarray
标签 [i] 是第i个观测最接近的质心的代码或索引。
参见
参考文献
- 1
D.Arthur和S.Vassilvitskii,“k-Means++:谨慎播种的优势”,第18届ACM-SIAM离散算法研讨会论文集,2007年。
示例
>>> from scipy.cluster.vq import kmeans2 >>> import matplotlib.pyplot as plt
创建z,这是一个形状为(100,2)的数组,包含来自三个多变量正态分布的样本的混合。
>>> rng = np.random.default_rng() >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45) >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30) >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25) >>> z = np.concatenate((a, b, c)) >>> rng.shuffle(z)
计算三个群集。
>>> centroid, label = kmeans2(z, 3, minit='points') >>> centroid array([[ 2.22274463, -0.61666946], # may vary [ 0.54069047, 5.86541444], [ 6.73846769, 4.01991898]])
每个群集中有多少个点?
>>> counts = np.bincount(label) >>> counts array([29, 51, 20]) # may vary
绘制群集。
>>> w0 = z[label == 0] >>> w1 = z[label == 1] >>> w2 = z[label == 2] >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0') >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1') >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2') >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids') >>> plt.axis('equal') >>> plt.legend(shadow=True) >>> plt.show()