菜菜的深度学习笔记 | 基于Python的理论与实现(六)—>简单两层网络的实现

news/2024/5/20 6:59:11 标签: 算法, 聚类, 机器学习, 人工智能, 深度学习

在这里插入图片描述

系列索引:菜菜的深度学习笔记 | 基于Python的理论与实现

文章目录

  • 一、学习算法的实现
    • (1)神经网络的学习步骤
    • (2)神经网络的类
    • (3)mini-batch的实现
    • (4)基于测试数据的评价

一、学习算法的实现

(1)神经网络的学习步骤

前提:调整权重和偏置以便拟合训练数据的过程称为学习,主要有以下4个步骤:

步骤1(mini-batch)

从训练数据中随机选出一部分数据,这部分数据称为mini-batch。我们的目标是减小mini-batch的损失函数的值。

步骤2(计算梯度)

为了减小损失函数的值,需要求出各个权重参数的梯度。梯度表示损失函数的值减小最多的方向。

步骤3(更新参数)

将权重参数沿梯度方向进行微小更新。

步骤4(重复)

重复步骤1~3

神经网络的学习按照按照上面4个步骤进行,这个方法通过梯度下降法更新参数,不过因为这里使用的数据是随机选择的mini-batch数据,所以又被称为随机梯度下降法,随机指的是“随机选择的”,因此随机梯度下降法是“对随机选择的数据进行的梯度下降法”。在许多深度学习框架中,随机梯度下降法是由一个名为SGD的函数来实现。

(2)神经网络的类

先来看一段代码:

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
from common.functions import *
from common.gradient import numerical_gradient


class TwoLayerNet:

    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # 初始化权重
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
        self.params['b2'] = np.zeros(output_size)

    def predict(self, x):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
    
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        
        return y
        
    # x:输入数据, t:监督数据
    def loss(self, x, t):
        y = self.predict(x)
        
        return cross_entropy_error(y, t)
    
    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        t = np.argmax(t, axis=1)
        
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy
        
    # x:输入数据, t:监督数据
    def numerical_gradient(self, x, t):
        loss_W = lambda W: self.loss(x, t)
        
        grads = {}
        grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
        grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
        grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
        grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
        
        return grads
        
    def gradient(self, x, t):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        grads = {}
        
        batch_num = x.shape[0]
        
        # forward
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        
        # backward
        dy = (y - t) / batch_num
        grads['W2'] = np.dot(z1.T, dy)
        grads['b2'] = np.sum(dy, axis=0)
        
        da1 = np.dot(dy, W2.T)
        dz1 = sigmoid_grad(a1) * da1
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)

        return grads

TwoLayerNetparamsgrads两个字典型实例变量,params变量中保存了权重参数,grads变量中保存了各个参数的梯度。

初始化方法会对权重参数进行初始化,非常重要,权重使用符合高斯分布的随机数进行初始化,偏置使用0进行初始化。

numerical_graduent方法会基于数值微分计算各个参数的梯度,之后会介绍一种更高速的计算梯度的方法称为误差反向传播法。

(3)mini-batch的实现

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from two_layer_net import TwoLayerNet

# 读入数据
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

iters_num = 10000  # 适当设定循环的次数
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1

train_loss_list = []
train_acc_list = []
test_acc_list = []

iter_per_epoch = max(train_size / batch_size, 1)

for i in range(iters_num):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    # 计算梯度
    #grad = network.numerical_gradient(x_batch, t_batch)
    grad = network.gradient(x_batch, t_batch)
    
    # 更新参数
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)
    
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))

# 绘制图形
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

这里mini-batch的大小为100,需要每次从60000个数据中随机取100个,然后求梯度,使用SGD进行更新参数。

如需完整源码包可关注并私信我。

(4)基于测试数据的评价

神经网络的学习中,必须确认是否能够正确识别训练数据以外的其他数据,即确认是否会发生过拟合。神经网络的目的是掌握泛化能力,因此,要评价神经网络的泛化能力就必须使用不包含在训练数据中的数据。

基于Python的理论与实现 系列持续更新,欢迎点赞收藏关注

上一篇:菜菜的深度学习笔记 | 基于Python的理论与实现(五)
下一篇:菜菜的深度学习笔记 | 基于Python的理论与实现(七)—>误差反向传播

本人水平有限,文章中不足之处欢迎下方👇评论区批评指正~

如果感觉对你有帮助,点个赞👍 支持一下吧 ~

不定期分享 有趣、有料、有营养内容,欢迎 订阅关注 🤝 我的博客 ,期待在这与你相遇 ~


http://www.niftyadmin.cn/n/1425767.html

相关文章

Hibernate Annotation1

1.概述 值得期待的Hibernate Annotation 式配置终于随着Hibernate 3.2GA 版本的发布而宣布正式被支持了! 只要数据库以及字段名称设计合适,我们甚至只需要在原来程序上加上3行代码,就可以配置完成一个Bean。 这依稀看到了Rails 的影子...…

菜菜的深度学习笔记 | 基于Python的理论与实现(七)—>误差反向传播

系列索引:菜菜的深度学习笔记 | 基于Python的理论与实现 文章目录一、误差反向传播法(1)基础概念(2)计算图(3)链式法则(4)反向传播一、误差反向传播法 (1&am…

菜菜的深度学习笔记 | 基于Python的理论与实现(八)—>简单层的实现

系列索引:菜菜的深度学习笔记 | 基于Python的理论与实现 文章目录(1)乘法层、加法层的实现(2)激活函数层的实现1.ReLU层2.Sigmoid层我们以购买苹果和橘子的例子来了解一下计算图是如何应用的 (1&#xff0…

数据分析入门 | kaggle泰坦尼克任务(一)—>数据加载和初步观察

系列索引:数据分析入门 | kaggle泰坦尼克任务 文章目录一、数据加载(1)载入数据(2)逐块读取二、初步观察一、数据加载 本次主要以实战的方式了解数据分析的流程和熟悉数据分析python的基本操作,完成kaggle…

Hibernate Annotation2

用EJB3注释进行映射   现在EJB3实体Bean是纯粹的POJO。实际上表达了和Hibernate持久化实体对象同样的概念。他们的映射都通过JDK5.0注释来定义(EJB3规范中的XML描述语法至今还没有定下来)。注释分为两个部分,分别是逻辑映射注释和物理映射注释,通过逻辑…

数据分析入门 | kaggle泰坦尼克任务

这个章节主要是参加DataWhale的数据分析项目过程中的记录,希望能对感兴趣的同学有一些帮助。 目录索引一、章节导航二、其他集合一、章节导航 数据分析入门 | kaggle泰坦尼克任务(一)—>数据加载和初步观察数据分析入门 | kaggle…

数据分析入门 | kaggle泰坦尼克任务(二)—>pandas基础

系列索引:数据分析入门 | kaggle泰坦尼克任务 文章目录一、pandas基础(1)数据类型(2)基本操作(3)筛选的逻辑(4)loc函数和iloc函数:一、pandas基础 &#xff…

hibernate annotation 3 使用

1 总共需要以下的jar mysql-connector-java-3.0.7-stable-bin.jar Hibernate 3.2、 Hibernate Annotations 3.2(hibernate-annotations.jar、ejb3-persistence.jar) 2 mysql中新建数据库demo 表 CREATE TABLE user ( id INT(11) NOT NULL auto_increment PRIMARY KEY, …