KMeans聚类算法实现

news/2024/5/20 6:59:04 标签: 机器学习, 人工智能, 聚类, kmeans

目录

1. K-Means的工作原理

2.Kmeans损失函数

3.Kmeans优缺点

4.编写KMeans算法实现类

5.KMeans算法测试

6.结果


       Kmeans是一种无监督的基于距离的聚类算法,其变种还有Kmeans++。其中,sklearn中KMeans的默认使用的即为KMeans++。使用sklearn相关算法API的调用案例可参考博主另一篇文章:KMeans算法实现图像分割。本文主要通过纯手写的方式,帮助学习理解KMeans算法的数据处理过程。

1. K-Means的工作原理

       在K-Means算法中,簇的个数K是一个超参数,需要人为输入来确定。K-Means的核心任务就是根据设定好的K,找出K个最优的质心,并将离这些质心最近的数据分别分配到这些质心代表的簇中去。具体过程可以总结如下:

  • 首先随机选取样本中的K个点作为聚类中心;
  • 分别算出样本中其他样本距离这K个聚类中心的距离,并把这些样本分别作为自己最近的那个聚类中心的类别;
  • 对上述分类完的样本再进行每个类别求平均值,求解出新的聚类质心;
  • 与前一次计算得到的K个聚类质心比较,如果聚类质心发生变化,转过程b,否则转过程e;
  • 当质心不发生变化时(当我们找到一个质心,在每次迭代中被分配到这个质心上的样本都是一致的,即每次新生成的簇都是一致的,所有的样本点都不会再从一个簇转移到另一个簇,质心就不会变化了),停止并输出聚类结果。

综上,K-Means 的算法步骤能够简单概括为:

1-分配:样本分配到簇。

2-移动:移动聚类中心到簇中样本的平均位置。

2.Kmeans损失函数

和其他机器学习算法一样,K-Means 也要评估并且最小化聚类代价,在引入 K-Means 的代价函数之前,先引入如下定义:

引入代价函数:

3.Kmeans优缺点

优点:
1.容易理解,聚类效果不错,虽然是局部最优, 但往往局部最优就够了;
2.处理大数据集的时候,该算法可以保证较好的伸缩性;
3.当簇近似高斯分布的时候,效果非常不错;
4.算法复杂度低。

缺点:
1.K 值需要人为设定,不同 K 值得到的结果不一样;
2.对初始的簇中心敏感,不同选取方式会得到不同结果;
3.对异常值敏感;
4.样本只能归为一类,不适合多分类任务;
5.不适合太离散的分类、样本类别不平衡的分类、非凸形状的分类。

4.编写KMeans算法实现类

import numpy as np


class KMeans:
    def __init__(self, data, num_clusters):
        self.data = data
        self.num_clusters = num_clusters

    def train(self, max_iterations):
        centerids = KMeans.centerids_init(self.data, self.num_clusters)        
        num_examples = self.data.shape[0]        
        closest_centerids_ids = np.empty((num_examples, 1))        
        for _ in range(max_iterations):
            closest_centerids_ids = KMeans.centerids_find_closest(self.data, centerids)            
            centerids = KMeans.centerids_compute(self.data, closest_centerids_ids, self.num_clusters)        
        return centerids, closest_centerids_ids

    @staticmethod    
    def centerids_init(data, num_clusters):
        num_examples = data.shape[0]        
        random_ids = np.random.permutation(num_examples)        
        centerids = data[random_ids[:num_clusters], :]        
    return centerids

    @staticmethod    
    def centerids_find_closest(data, centerids):
        num_examples = data.shape[0]        
        num_centerids = centerids.shape[0]        
        closest_centerids_ids = np.zeros((num_examples, 1))        
        for example_index in range(num_examples):
            distance = np.zeros((num_centerids, 1))            
            for centerid_index in range(num_centerids):
                distance_diff = data[example_index, :] - centerids[centerid_index, :]                
                distance[centerid_index] = np.sum((distance_diff ** 2))            
                closest_centerids_ids[example_index] = np.argmin(distance)        
        return closest_centerids_ids

    @staticmethod    
    def centerids_compute(data, closest_centerids_ids, num_clusters):
        num_features = data.shape[1]        
        centerids = np.zeros((num_clusters, num_features))        
        for centerid in range(num_clusters):
            closest_ids = closest_centerids_ids == centerid
            centerids[centerid] = np.mean(data[closest_ids.flatten(), :], axis=0)        
        return centerids

5.KMeans算法测试

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

from cls_kmeans.k_means import KMeans

iris = load_iris()data = pd.DataFrame(data=iris.data, columns=iris.feature_names)
data["species"] = iris.target_names[iris.target]

# print(data.head())
# print(iris.feature_names)

x_axis = iris.feature_names[2]
y_axis = iris.feature_names[3]

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)  # 一行两列,第一个图
for iris_type in iris.target_names:
    plt.scatter(data[x_axis][data["species"] == iris_type],                
    data[y_axis][data["species"] == iris_type],                
    label=iris_type)
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label Known")
    plt.legend()
    
    plt.subplot(1, 2, 2)  # 一行两列,第二个图
    plt.scatter(data[x_axis][:], data[y_axis][:], label="all_type")
    plt.title("Label Unknown")
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.show()
    
    # print(np.unique(iris.target).shape[0])
    num_examples = data.shape[0]
    x_train = data[[x_axis, y_axis]].values.reshape(num_examples, 2)
    max_iterations = 50
    num_clusters = 3
    
    kmeans = KMeans(data=x_train, num_clusters=num_clusters)
    (centerids, closest_centerids_ids) = kmeans.train(max_iterations=max_iterations)
    
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)  # 一行两列,第一个图
    for iris_type in iris.target_names:
        plt.scatter(data[x_axis][data["species"] == iris_type],                
                    data[y_axis][data["species"] == iris_type],                
                    label=iris_type)
    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label Known")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    for centerid_id, centerid in enumerate(centerids):
        current_example_index = (closest_centerids_ids == centerid_id).flatten()    
        plt.scatter(data[x_axis][current_example_index],                
                    data[y_axis][current_example_index],                
                    label=centerid_id)

    for centerid_id, centerid in enumerate(centerids):
        plt.scatter(centerid[0], centerid[1], c="black", marker="x")

    plt.xlabel(x_axis)
    plt.ylabel(y_axis)
    plt.title("Label KMeans")
    plt.legend()
    plt.show()

6.结果


http://www.niftyadmin.cn/n/1008667.html

相关文章

【Python 随练】插入元素到已排序数组

题目: 有一个已经排好序的数组。现输入一个数,要求按原来的规律将它插入数组中。 简介: 在本篇博客中,我们将解决一个数组操作问题:将一个数插入已经排好序的数组中,同时保持原有的排序规律。我们将介绍…

Linux的权限管理操作(权限设置chmod、属主chown与所组设置chgrp)

Linux的权限管理 权限概述权限介绍身份介绍Owner身份(文件所有者,默认为文档的创建者)Group身份(与文件所有者同组的用户)Others身份(其他人,相对于所有者)Root用户(超级…

SELECT * 会导致查询效率低的原因

SELECT * 会导致查询效率低的原因 前言一、适合SELECT * 的使用场景二、SELECT * 会导致查询效率低的原因2.1、数据库引擎的查询流程2.2、SELECT * 的实际执行过程2.3、使用 SELECT * 查询语句带来的不良影响 三、优化查询效率的方法四、总结 前言 因为 SELECT * 查询语句会查…

JavaSE学习总结(五)

聊一聊你对泛型的理解 泛型就是将类型参数化,在编译时才确定具体的参数,泛型可以用于类、接口和方法,即泛型类、泛型接口和泛型方法。使用泛型主要有两个好处,第一是提高Java程序的类型安全,这也是泛型的主要目的&…

Ai绘画-Midjourney常用关键词

一、视角关键词 视角关键词近距离景Tight Shot两人/物景Two Shot (2S), Three Shot (3S), Group Shot (GS)三人/物景Three Shot (3S), Group Shot (GS)风景照Scenery Shot背景虚化Bokeh前景Foreground背景Background细节镜头Detail Shot (ECU)面部拍摄Face Shot (VCU)膝景Knee …

HOT31-K个一组翻转链表

leetcode原题链接:K个一组翻转链表 题目描述 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那么请将最后剩余…

Radzen Blazor Studio 1.12 Crack

Radzen Blazor Studio 是一款桌面工具,使 开发人员 能够创建精美的商业 Blazor 应用程序。快速地。 开放技术栈 没有供应商锁定。生成的源代码是人类可读的,您可以使用免费工具构建它。 Radzen 由流行的开源技术 - ASP.NET Core、Blazor、Bootstrap 提供…

基于大数据技术对基金分析-python

提示:本文为个人原创,仅供技术探讨与交流,对实际投资并不造成建议。 基于大数据技术对基金分析-python 前言一、数据获取:python爬虫1).从天天基金数据接口获取数据2).爬虫前期准备3).爬虫具体实现 二、数据清洗及计算指标1.过滤数…