As before, everything we need is in the sklearn module, and in particular its cluster submodule:

- AgglomerativeClustering: performs hierarchical agglomerative clustering (HAC, or CAH in French)
- KMeans: performs $k$-means

import pandas
import numpy
import matplotlib.pyplot as plt
import seaborn
seaborn.set_style("white")
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
from sklearn.preprocessing import scale
The iris data are available here.
iris = pandas.read_table("https://fxjollois.github.io/donnees/Iris.txt", sep = "\t")
iris.head()
| | Sepal Length | Sepal Width | Petal Length | Petal Width | Species |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
As with PCA, clustering with HAC or $k$-means can only be performed on quantitative variables.
iris2 = iris.drop("Species", axis = 1)
iris2.head()
| | Sepal Length | Sepal Width | Petal Length | Petal Width |
|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 |
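Every model below is fitted on scale(iris2). As a quick sketch of what scale() does, it standardizes each variable to mean 0 and standard deviation 1:

# scale() returns a numpy array of standardized (centered-scaled) columns
iris2_std = scale(iris2)
print(iris2_std.mean(axis = 0).round(10))  # numerically 0 for each variable
print(iris2_std.std(axis = 0).round(10))   # 1 for each variable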
hac = AgglomerativeClustering(distance_threshold = 0, n_clusters = None)
hac.fit(scale(iris2))
AgglomerativeClustering(distance_threshold=0, n_clusters=None)
We create a dendrogram-plotting function based on this page (with a few modifications).
from scipy.cluster.hierarchy import dendrogram
def plot_dendrogram(model, **kwargs):
    # Create the linkage matrix and then plot the dendrogram
    # First, count the samples under each node of the tree
    counts = numpy.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
    linkage_matrix = numpy.column_stack([model.children_, model.distances_, counts]).astype(float)
    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)
plt.figure(figsize = (16, 8))
plt.title("CAH (Ward)")
plot_dendrogram(hac)
plt.axhline(y = 20, linewidth = .5, color = "dimgray", linestyle = "--")
plt.axhline(y = 10, linewidth = .5, color = "dimgray", linestyle = "--")
plt.show()
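The dashed lines mark two possible heights at which to cut the tree. As a sketch, we can obtain the corresponding flat partitions by refitting with a distance_threshold (the resulting number of clusters is stored in n_clusters_):

# cut the tree at the dashed heights instead of fixing n_clusters
cut20 = AgglomerativeClustering(distance_threshold = 20, n_clusters = None).fit(scale(iris2))
cut10 = AgglomerativeClustering(distance_threshold = 10, n_clusters = None).fit(scale(iris2))
print(cut20.n_clusters_, cut10.n_clusters_)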
Fitted with its default settings, the method returns a partition into a preset number of clusters (n_clusters = 2 by default).
hac2 = AgglomerativeClustering()
hac2.fit(scale(iris2))
AgglomerativeClustering()
hac2.labels_
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
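A quick count of the cluster sizes (a small sketch):

pandas.Series(hac2.labels_).value_counts()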
But we can of course choose our own number of clusters.
hac3 = AgglomerativeClustering(n_clusters = 3)
hac3.fit(scale(iris2))
AgglomerativeClustering(n_clusters=3)
hac3.labels_
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 2, 0, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Very generally, to understand and comment on the clusters, we compute their centers (the mean value of each variable).
iris2.assign(classe = hac3.labels_).groupby("classe").mean()
| classe | Sepal Length | Sepal Width | Petal Length | Petal Width |
|---|---|---|---|---|
| 0 | 6.546479 | 2.992958 | 5.267606 | 1.854930 |
| 1 | 5.016327 | 3.451020 | 1.465306 | 0.244898 |
| 2 | 5.530000 | 2.566667 | 3.930000 | 1.206667 |
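Since the species are known here, a cross-tabulation gives a quick external check of the partition; a small sketch:

# compare the HAC clusters with the known species
pandas.crosstab(iris["Species"], hac3.labels_, colnames = ["classe"])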
We now run $k$-means on the same standardized data, with 3 clusters and 20 initializations.

kmeans = KMeans(n_clusters = 3, n_init = 20)
kmeans.fit(scale(iris2))
KMeans(n_clusters=3, n_init=20)
pandas.Series(kmeans.labels_).value_counts()
2    53
1    50
0    47
Name: count, dtype: int64
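The $k$-means partition can also be compared directly with the HAC one from above; a quick sketch (cluster numbers are arbitrary, only the cross-counts matter):

# cross-count the two partitions; labels 0/1/2 need not coincide
pandas.crosstab(hac3.labels_, kmeans.labels_, rownames = ["cah"], colnames = ["kmeans"])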
The cluster centers are obtained automatically. Since we used the centered and scaled data, they are easy for an informed reader to interpret (positive value $\rightarrow$ above the mean, and conversely).
kmeans.cluster_centers_
array([[ 1.13597027,  0.08842168,  0.99615451,  1.01752612],
       [-1.01457897,  0.85326268, -1.30498732, -1.25489349],
       [-0.05021989, -0.88337647,  0.34773781,  0.2815273 ]])
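These standardized centers can also be mapped back to the original units by hand; a minimal sketch, recalling that scale() standardizes with the population standard deviation (ddof = 0):

# un-standardize the centers: x = z * sd + mean (population sd, as in scale)
centers = kmeans.cluster_centers_ * iris2.std(ddof = 0).values + iris2.mean().values
pandas.DataFrame(centers, columns = iris2.columns)

This should match the table computed below.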
But to present the clusters, we will prefer to recompute these centers on the original data.
iris2.assign(classe = kmeans.labels_).groupby("classe").mean()
| classe | Sepal Length | Sepal Width | Petal Length | Petal Width |
|---|---|---|---|---|
| 0 | 6.780851 | 3.095745 | 5.510638 | 1.972340 |
| 1 | 5.006000 | 3.428000 | 1.462000 | 0.246000 |
| 2 | 5.801887 | 2.673585 | 4.369811 | 1.413208 |
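To guide the choice of $k$, we compare several partitions with three criteria, computed below: the within-cluster inertia $I(k)$ (the inertia_ attribute), the proportion of explained inertia $R^2(k) = \frac{I(1) - I(k)}{I(1)}$, and a pseudo-F statistic, taken here as $\frac{R^2(k)/k}{(1 - R^2(k))/(n - k)}$ with $n = 150$ observations.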
inertia = []
for k in range(1, 11):
    # for each k: 20 random initializations, fitted on the standardized data
    kmeans = KMeans(n_clusters = k, init = "random", n_init = 20).fit(scale(iris2))
    inertia.append(kmeans.inertia_)
# share of the total inertia explained by each partition
rsquare = [(inertia[0] - i) / inertia[0] for i in inertia]
criteres = pandas.DataFrame({
    "k": range(1, 11),
    "inertia": inertia,
    "rsquare": rsquare,
    # pseudo-F criterion, undefined for k = 1
    "pseudof": [(rsquare[k-1] / k) / ((1 - rsquare[k-1]) / (150 - k)) if k > 1 else None for k in range(1, 11)]
})
print(criteres)
    k     inertia   rsquare     pseudof
0   1  600.000000  0.000000         NaN
1   2  222.361705  0.629397  125.674670
2   3  139.820496  0.766966  161.269601
3   4  114.092547  0.809846  155.449436
4   5   91.047670  0.848254  162.108680
5   6   81.744006  0.863760  152.159706
6   7   70.749389  0.882084  152.818761
7   8   62.616113  0.895640  152.334017
8   9   54.531599  0.909114  156.710453
9  10   47.223240  0.921295  163.878520
We then plot each criterion against $k$; the red points mark candidate values (an elbow for the inertia and the $R^2$, peaks for the pseudo-F).

seaborn.lineplot(data = criteres, x = "k", y = "inertia")
plt.scatter(2, criteres.query('k == 2')["inertia"], c = "red")
plt.scatter(3, criteres.query('k == 3')["inertia"], c = "red")
plt.show()
seaborn.lineplot(data = criteres, x = "k", y = "rsquare")
plt.scatter(2, criteres.query('k == 2')["rsquare"], c = "red")
plt.scatter(3, criteres.query('k == 3')["rsquare"], c = "red")
plt.show()
seaborn.lineplot(data = criteres, x = "k", y = "pseudof")
plt.scatter(3, criteres.query('k == 3')["pseudof"], c = "red")
plt.scatter(5, criteres.query('k == 5')["pseudof"], c = "red")
plt.show()
We now turn to the wine data available on this page of the UCI MLR site. Here is the code to retrieve the data.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data"
wine = pandas.read_csv(url, header = None, sep = ",")
wine.columns = ["class", "Alcohol", "Malic acid", "Ash", "Alcalinity of ash", "Magnesium",
"Total phenols", "Flavanoids", "Nonflavanoid phenols", "Proanthocyanins",
"Color intensity", "Hue", "OD280/OD315 of diluted wines", "Proline"]
wine
| | class | Alcohol | Malic acid | Ash | Alcalinity of ash | Magnesium | Total phenols | Flavanoids | Nonflavanoid phenols | Proanthocyanins | Color intensity | Hue | OD280/OD315 of diluted wines | Proline |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 14.23 | 1.71 | 2.43 | 15.6 | 127 | 2.80 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065 |
| 1 | 1 | 13.20 | 1.78 | 2.14 | 11.2 | 100 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.40 | 1050 |
| 2 | 1 | 13.16 | 2.36 | 2.67 | 18.6 | 101 | 2.80 | 3.24 | 0.30 | 2.81 | 5.68 | 1.03 | 3.17 | 1185 |
| 3 | 1 | 14.37 | 1.95 | 2.50 | 16.8 | 113 | 3.85 | 3.49 | 0.24 | 2.18 | 7.80 | 0.86 | 3.45 | 1480 |
| 4 | 1 | 13.24 | 2.59 | 2.87 | 21.0 | 118 | 2.80 | 2.69 | 0.39 | 1.82 | 4.32 | 1.04 | 2.93 | 735 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 173 | 3 | 13.71 | 5.65 | 2.45 | 20.5 | 95 | 1.68 | 0.61 | 0.52 | 1.06 | 7.70 | 0.64 | 1.74 | 740 |
| 174 | 3 | 13.40 | 3.91 | 2.48 | 23.0 | 102 | 1.80 | 0.75 | 0.43 | 1.41 | 7.30 | 0.70 | 1.56 | 750 |
| 175 | 3 | 13.27 | 4.28 | 2.26 | 20.0 | 120 | 1.59 | 0.69 | 0.43 | 1.35 | 10.20 | 0.59 | 1.56 | 835 |
| 176 | 3 | 13.17 | 2.59 | 2.37 | 20.0 | 120 | 1.65 | 0.68 | 0.53 | 1.46 | 9.30 | 0.60 | 1.62 | 840 |
| 177 | 3 | 14.13 | 4.10 | 2.74 | 24.5 | 96 | 2.05 | 0.76 | 0.56 | 1.35 | 9.20 | 0.61 | 1.60 | 560 |

178 rows × 14 columns
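The same workflow as above can be applied here; a minimal sketch, leaving the known class column out of the clustering and then comparing the partition with it:

# k-means on the standardized wine variables, then an external check
wine2 = wine.drop("class", axis = 1)
km_wine = KMeans(n_clusters = 3, n_init = 20).fit(scale(wine2))
print(pandas.crosstab(wine["class"], km_wine.labels_, colnames = ["classe"]))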