Hierarchical clustering: structured vs unstructured ward
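This example builds a swiss roll dataset and runs hierarchical clustering on the positions of its points.
For more information, see Hierarchical clustering.
In a first step, the hierarchical clustering is performed without connectivity constraints on the structure and is based solely on distance. In a second step, the clustering is restricted to the k-nearest neighbors graph: this is hierarchical clustering with a structure prior.
Some of the clusters learned without connectivity constraints do not respect the structure of the swiss roll and extend across different folds of the manifold. In contrast, when connectivity constraints are imposed, the clusters form a nice parcellation of the swiss roll.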
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import time
# The following import is required
# for 3D projection to work with matplotlib < 3.2
import mpl_toolkits.mplot3d # noqa: F401
import numpy as np
Generate data
We start by generating the Swiss Roll dataset.
from sklearn.datasets import make_swiss_roll
n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise=noise)
# Make it thinner
X[:, 1] *= 0.5
Compute clustering
We perform AgglomerativeClustering, which is a form of hierarchical clustering, without any connectivity constraints.
from sklearn.cluster import AgglomerativeClustering
print("Compute unstructured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
Compute unstructured hierarchical clustering...
Elapsed time: 0.02s
Number of points: 1500
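The output above only reports the total number of points. As a minimal sketch of how one might also inspect how those points are distributed over the six clusters, the snippet below counts the members of each cluster with `np.bincount`. It regenerates the data so that it can run on its own; `random_state=0` is an assumption added here for reproducibility and is not part of the original example.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll

# Same data as in the example; random_state=0 is an assumption for reproducibility.
X, _ = make_swiss_roll(1500, noise=0.05, random_state=0)
X[:, 1] *= 0.5

# Unstructured Ward clustering, as in the step above.
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)

# labels_ holds integers 0..5, so bincount gives one size per cluster.
cluster_sizes = np.bincount(ward.labels_)
for cluster_id, size in enumerate(cluster_sizes):
    print(f"Cluster {cluster_id}: {size} points")
```

The sizes depend on the random draw of the dataset, so no particular values are shown here.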
Plot result
Plot the unstructured hierarchical clusters.
import matplotlib.pyplot as plt
fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection="3d", elev=7, azim=-80)
ax1.set_position([0, 0, 0.95, 1])
for l in np.unique(label):
    ax1.scatter(
        X[label == l, 0],
        X[label == l, 1],
        X[label == l, 2],
        color=plt.cm.jet(float(l) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
_ = fig1.suptitle(f"Without connectivity constraints (time {elapsed_time:.2f}s)")

We are defining k-Nearest Neighbors with 10 neighbors
from sklearn.neighbors import kneighbors_graph
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
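To make the connectivity constraint more concrete, here is a small sketch that inspects the graph returned by `kneighbors_graph`. It checks the shape of the sparse matrix, the number of neighbors stored per row, how asymmetric the graph is, and how many connected components it has when edges are treated as undirected. The data is regenerated so that the snippet is self-contained; `random_state=0` and the variable names are assumptions added for illustration.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

# Same data as in the example; random_state=0 is an assumption for reproducibility.
X, _ = make_swiss_roll(1500, noise=0.05, random_state=0)
X[:, 1] *= 0.5

connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)

# One row per sample, with a 1 in the column of each of its 10 nearest neighbors.
neighbors_per_row = np.asarray(connectivity.sum(axis=1)).ravel()

# "a is a neighbor of b" does not imply the reverse, so the matrix need not be
# symmetric; this counts the entries that have no mirrored counterpart.
n_one_way_edges = int(abs(connectivity - connectivity.T).sum())

# Number of connected components when edges are treated as undirected.
n_components, _ = connected_components(connectivity, directed=False)

print(f"Shape: {connectivity.shape}, stored edges: {connectivity.nnz}")
print(f"One-way edges: {n_one_way_edges}")
print(f"Connected components: {n_components}")
```

The component count is worth a look because a graph that splits into several components limits which merges are possible; as far as we are aware, AgglomerativeClustering emits a warning and completes the graph in that case.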
Compute clustering
We perform AgglomerativeClustering again, this time with connectivity constraints.
print("Compute structured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(
    n_clusters=6, connectivity=connectivity, linkage="ward"
).fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
Compute structured hierarchical clustering...
Elapsed time: 0.04s
Number of points: 1500
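The claim that unconstrained clusters extend across folds of the manifold can also be checked numerically, not only visually. The sketch below is one possible approach and not part of the original example: it uses the second return value of `make_swiss_roll`, the position `t` of each sample along the unrolled sheet, and measures how much of that coordinate each cluster covers. A cluster that jumps between folds would be expected to cover a wide range of `t`. The helper `roll_span_per_cluster` and `random_state=0` are hypothetical additions made for illustration.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

# t is the position of each sample along the roll (between 1.5*pi and 4.5*pi).
# random_state=0 is an assumption added for reproducibility.
X, t = make_swiss_roll(1500, noise=0.05, random_state=0)
X[:, 1] *= 0.5


def roll_span_per_cluster(labels, t):
    """Return, for each cluster, the range of the roll coordinate t it covers."""
    spans = []
    for k in np.unique(labels):
        t_k = t[labels == k]
        spans.append(t_k.max() - t_k.min())
    return np.array(spans)


# Ward clustering without and with the k-nearest neighbors constraint.
unstructured = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
structured = AgglomerativeClustering(
    n_clusters=6, connectivity=connectivity, linkage="ward"
).fit(X)

span_unstructured = roll_span_per_cluster(unstructured.labels_, t)
span_structured = roll_span_per_cluster(structured.labels_, t)

print("Span of t per cluster, unstructured:", np.round(span_unstructured, 2))
print("Span of t per cluster, structured:  ", np.round(span_structured, 2))
```

Since one full turn of the roll corresponds to a change of 2*pi in `t`, spans approaching or exceeding that value suggest a cluster that covers more than one fold. The actual numbers depend on the sampled data, so none are quoted here.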
Plot result
Plot the structured hierarchical clusters.
fig2 = plt.figure()
ax2 = fig2.add_subplot(121, projection="3d", elev=7, azim=-80)
ax2.set_position([0, 0, 0.95, 1])
for l in np.unique(label):
    ax2.scatter(
        X[label == l, 0],
        X[label == l, 1],
        X[label == l, 2],
        color=plt.cm.jet(float(l) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
fig2.suptitle(f"With connectivity constraints (time {elapsed_time:.2f}s)")
plt.show()

Total running time of the script: (0 minutes 0.286 seconds)