使用 Numpy 实现 K-Means 聚类算法

news/2024/5/20 6:03:07 标签: 机器学习, 聚类

K-Means 算法原理链接.

使用时,实例化类后,只需关注 fit(), predict(),传入数据类型为np.array,形状为 N x M。

class MyKMeans:
    labels_ = []  # fit 后每类数据的标签
    cluster_centers_ = None  # N x M, 聚类中心个数
    __cluster_centers_dict = dict()

    def __init__(self, n_clusters=3, n_init=3, max_iter=300):
        """
        :param n_clusters: 聚类中心个数
        :param n_init: 随机初始化几次聚类中心, 结果从最优中选择
        :param max_iter: 迭代次数
        """
        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter
        for i in range(n_init):
            self.__cluster_centers_dict[str(i)] = None

    # 拟合
    def fit(self, X):
        # X.shape : N x M
        for i in tqdm(range(self.n_init)):
            self.__init_cluster_centers(X)
            sum_dis = 0
            for j in range(self.max_iter):
                sum_dis, _ = self.__distance(X)
                # print(_.shape)
                self.__change_cluster_centers(X)
            self.__cluster_centers_dict[str(i)] = [sum_dis, self.cluster_centers_, self.labels_]
        # print(list(self.__cluster_centers_dict.items())[0][1][0])
        self.cluster_centers_, self.labels_ = \
            sorted(list(self.__cluster_centers_dict.items()), key=lambda item: item[1][0])[0][1][1:]

    # 预测
    def predict(self, x):
        # x.shape : N x M
        _, dis = self.__distance(x, is_fit=False)
        # print(dis.shape)
        return np.argmin(dis, axis=1)

    # 算距离
    def __distance(self, X, is_fit=True):
        distances = np.sum(np.power(X - self.cluster_centers_[:, np.newaxis, ...], 2), axis=2).T
        if is_fit:
            self.labels_ = np.argmin(distances, axis=1)
        return np.sum(distances), distances

    # 修改聚类中心
    def __change_cluster_centers(self, X):
        new_cc = np.zeros(self.cluster_centers_.shape)
        sum_l = np.zeros(self.cluster_centers_.shape[0])
        for i in range(len(self.labels_)):
            # print(new_cc.shape, X.shape, new_cc[self.labels_[i]].shape, X[i].shape, new_cc[self.labels_[i]], X[i])
            new_cc[self.labels_[i]] = new_cc[self.labels_[i]] + X[i]
            sum_l[self.labels_[i]] += 1
        sum_l = sum_l[:, np.newaxis]
        sum_l = np.repeat(sum_l, repeats=new_cc.shape[1], axis=1)
        self.cluster_centers_ = new_cc / sum_l

    # 随机初始化聚类中心
    def __init_cluster_centers(self, X):
        indexs = random.sample(range(0, X.shape[0]), self.n_clusters)
        self.cluster_centers_ = np.array([X[i] for i in indexs])

http://www.niftyadmin.cn/n/1362963.html

相关文章

操作系统作业调度-操作系统

操作系统作业调度--操作系统 一、目的和要求 1. 实验目的 (1)加深对作业调度算法的理解; (2)进行程序设计的训练。 2.实验要求 用高级语言编写一个或多个作业调度的模拟程序。 单道批处理系统的作业调度程序…

Paddle2.0实现PSPNet进行人体解析(图像分割)

Paddle2.0实现PSPNet进行人体解析(图像分割)项目背景概述前言PSPNet介绍为什么会提出PSPNet ?PSPNet 的效果为什么好 ?PSPNet 是怎样考虑上下文信息的 ?PSPNet 是怎样增大感受野的 ?PSPNet - 金字塔模块(Pyramid Pooling)PSPNet 的总体结构PSPNet - …

docker WARNING: bridge-nf-call-iptables is disabled 处理

在CentOS中 vim /etc/sysctl.conf net.bridge.bridge-nf-call-ip6tables 1 net.bridge.bridge-nf-call-iptables 1 net.bridge.bridge-nf-call-arptables 1 然后就解决了 docker info 就不显示了转载于:https://www.cnblogs.com/jackluo/p/5422243.html

基于PaddleX实现电梯电瓶车检测

基于PaddleX实现电梯电瓶车检测一、项目背景二、数据集简介1. 数据获取2. 解压数据集和安装第三方库3. COCO2VOC4. 将pascal-voc和coco2017中的摩托车数据提取出来5. 将文件路径写入.txt三、模型选择与开发1. 安装paddlex2. 数据加载与预处理3. 模型选择4. 模型训练四、效果展示…

DjangoORM一对多多对多操作

简要说明 通过操作对象的方式操作数据库 详细步骤 models.py的结构是: 如果models.py中外键定义没有写 related_name’student_teacher’, 可以直接用 studentList teacher.student_teacher.all() 可以改写成:teacher Teacher.objects.get(id 1)stude…

hadoop报错:HADOOP_HOME and hadoop.home.dir are unset. 解决方法

目录报错信息解决方法1.下载apache-hadoop-3.1.0-winutils-master2.解压到宿主机3.添加环境变量4.重启IDEA或eclipse报错信息 java.lang.RuntimeException: java.io.FileNotFoundException: java.io.FileNotFoundException: HADOOP_HOME and hadoop.home.dir are unset. java…

20145122《敏捷开发与XP实践 》实验三实验报告

实验名称 敏捷开发与XP实践 实验内容 1.团队代码要使用git在实验楼中托管,要使用结对同学中的一个同学的账号托管。 2.使用git推送代码并对结对同学的代码修改完成后再git推送。 3.掌握重构流程。 统计的PSP(Personal Software Process)时间 步骤耗时百分比需求分析…

JS:变量,typeof,数据类型,数据类型转换,for-in语句,with语句

声明当前执行环境的变量: var message; var message hi; var message hi, found false, age 29; 声明全局变量: message hi; typeof:用来检测给定变量的数据类型 返回值: undefined:这个值未定义;boolean:这个值是…