Hierarchical clustering: structured vs unstructured Ward#

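This example builds a Swiss roll dataset and runs hierarchical clustering on the positions of its points.

For more information, see Hierarchical clustering.

In a first step, the hierarchical clustering is performed without connectivity constraints on the structure and is based solely on distance. In a second step, the clustering is restricted to the k-nearest neighbors graph: it is a hierarchical clustering with a structure prior.

Some of the clusters learned without connectivity constraints do not respect the structure of the Swiss roll and extend across different folds of the manifold. In contrast, when connectivity constraints are imposed, the clusters form a nice parcellation of the Swiss roll.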

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import time

# The following import is required
# for 3D projection to work with matplotlib < 3.2
import mpl_toolkits.mplot3d  # noqa: F401
import numpy as np

Generate data#

We start by generating the Swiss roll dataset.

from sklearn.datasets import make_swiss_roll

n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise=noise)
# Make it thinner
X[:, 1] *= 0.5
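
To see what `make_swiss_roll` returns, the following sketch regenerates the data and inspects both return values. The `random_state=0` argument and the names `X_demo` and `t_demo` are assumptions added here for reproducibility and illustration; the example above does not set a seed. The second return value, discarded above as `_`, is the position of each sample along the main dimension of the manifold.

```python
import numpy as np
from sklearn.datasets import make_swiss_roll

# random_state=0 is an assumption for reproducibility, not part of the example above.
X_demo, t_demo = make_swiss_roll(1500, noise=0.05, random_state=0)
X_demo[:, 1] *= 0.5  # same thinning step as above

print(X_demo.shape)  # one row per sample, three coordinates
print(t_demo.shape)  # one manifold position per sample
print(X_demo.min(axis=0), X_demo.max(axis=0))  # per-axis extent of the roll
```

Only the second coordinate is rescaled, so the roll keeps its spiral cross-section in the first and third coordinates and simply becomes narrower.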

Compute clustering#

We perform AgglomerativeClustering, a form of hierarchical clustering, without any connectivity constraints.

from sklearn.cluster import AgglomerativeClustering

print("Compute unstructured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
Compute unstructured hierarchical clustering...
Elapsed time: 0.02s
Number of points: 1500
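
The fitted estimator exposes more than the labels. As a rough sketch, the snippet below refits the same unstructured model on a freshly generated roll and looks at the cluster sizes and the merge tree stored in `children_`. The seed `random_state=0` and the names `X_demo`, `model` and `cluster_sizes` are assumptions made for this illustration.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll

# random_state=0 is an assumption for reproducibility.
X_demo, _ = make_swiss_roll(1500, noise=0.05, random_state=0)
X_demo[:, 1] *= 0.5

model = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X_demo)

# Number of points assigned to each of the 6 clusters.
cluster_sizes = np.bincount(model.labels_)
print(cluster_sizes)

# children_ holds one row per merge of the hierarchy.
print(model.children_.shape)
```

The cluster sizes always sum to the number of samples, and each row of `children_` names the two nodes joined at that step of the agglomeration.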

Plot result#

Plotting the unstructured hierarchical clusters.

import matplotlib.pyplot as plt

fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection="3d", elev=7, azim=-80)
ax1.set_position([0, 0, 0.95, 1])
for l in np.unique(label):
    ax1.scatter(
        X[label == l, 0],
        X[label == l, 1],
        X[label == l, 2],
        color=plt.cm.jet(float(l) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
_ = fig1.suptitle(f"Without connectivity constraints (time {elapsed_time:.2f}s)")
Without connectivity constraints (time 0.02s)

Define the k-nearest neighbors graph with 10 neighbors#

from sklearn.neighbors import kneighbors_graph

connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
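
It can help to look at what the connectivity object contains before passing it to the estimator. The sketch below rebuilds the graph on a seeded copy of the data and checks its shape, the number of stored neighbors per row, its symmetry, and how many connected pieces it has. The seed and the names `X_demo`, `graph`, `neighbors_per_row`, `asymmetric_entries` and `n_components` are assumptions for illustration.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

# random_state=0 is an assumption for reproducibility.
X_demo, _ = make_swiss_roll(1500, noise=0.05, random_state=0)
X_demo[:, 1] *= 0.5

graph = kneighbors_graph(X_demo, n_neighbors=10, include_self=False)

print(graph.shape)  # square: one row and one column per sample
neighbors_per_row = np.diff(graph.indptr)
print(neighbors_per_row.min(), neighbors_per_row.max())  # 10 stored neighbors in every row

# A k-NN graph is directed: i may list j as a neighbor without the reverse.
asymmetric_entries = (graph != graph.T).nnz
print(asymmetric_entries)

# Number of connected pieces when edges are treated as undirected.
n_components, _ = connected_components(graph, directed=False)
print(n_components)
```

The graph is sparse, with only `n_samples * n_neighbors` stored entries, so the constraint limits candidate merges to clusters that are adjacent in this graph.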

Compute clustering#

We perform AgglomerativeClustering again, this time with connectivity constraints.

print("Compute structured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(
    n_clusters=6, connectivity=connectivity, linkage="ward"
).fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
Compute structured hierarchical clustering...
Elapsed time: 0.04s
Number of points: 1500

Plot result#

Plotting the structured hierarchical clusters.

fig2 = plt.figure()
ax2 = fig2.add_subplot(111, projection="3d", elev=7, azim=-80)
ax2.set_position([0, 0, 0.95, 1])
for l in np.unique(label):
    ax2.scatter(
        X[label == l, 0],
        X[label == l, 1],
        X[label == l, 2],
        color=plt.cm.jet(float(l) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
fig2.suptitle(f"With connectivity constraints (time {elapsed_time:.2f}s)")

plt.show()
With connectivity constraints (time 0.04s)
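
One possible way to put a number on the visual difference between the two figures is to count, for each cluster, how many connected pieces it breaks into inside the k-nearest neighbors graph. This is a sketch under stated assumptions, not part of the original example: the helper `pieces_per_cluster` is hypothetical, and the data is regenerated with `random_state=0` for reproducibility.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

# random_state=0 is an assumption for reproducibility.
X_demo, _ = make_swiss_roll(1500, noise=0.05, random_state=0)
X_demo[:, 1] *= 0.5
graph = kneighbors_graph(X_demo, n_neighbors=10, include_self=False)


def pieces_per_cluster(labels, graph):
    """Count connected pieces of each cluster inside the neighbors graph.

    Hypothetical helper written for this illustration.
    """
    counts = []
    for k in np.unique(labels):
        members = np.flatnonzero(labels == k)
        subgraph = graph[members][:, members]  # edges between cluster members only
        n_pieces, _ = connected_components(subgraph, directed=False)
        counts.append(int(n_pieces))
    return counts


unstructured = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X_demo)
structured = AgglomerativeClustering(
    n_clusters=6, connectivity=graph, linkage="ward"
).fit(X_demo)

print("unstructured:", pieces_per_cluster(unstructured.labels_, graph))
print("structured:  ", pieces_per_cluster(structured.labels_, graph))
```

A cluster that counts more than one piece contains points that are not linked to each other through neighbors inside that cluster, which is what a cluster spanning two folds of the roll would look like. The exact counts depend on the generated data, so none are quoted here.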

Total running time of the script: (0 minutes 0.286 seconds)

Related examples

A demo of structured Ward hierarchical clustering on an image of coins

Agglomerative clustering with and without structure

Plot Hierarchical Clustering Dendrogram

Comparing different clustering algorithms on toy datasets
