
Comparing Pure Python and PyTorch Implementations of the SGD, Momentum, RMSprop, and Adam Gradient Descent Algorithms


Abstract

This article implements the SGD, Momentum, RMSprop, and Adam gradient descent algorithms in pure Python (NumPy) and checks them against their PyTorch counterparts.

Related

For the underlying principles and detailed explanations, please refer to:

Detailed Explanation of the Common Gradient Descent Algorithms SGD, Momentum, RMSprop, and Adam

Article index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
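
For quick reference, the update rules implemented in section 1 are summarised below (learning rate $lr$, gradient $g_t$, parameter $\theta_t$); this is only a condensed restatement of the code that follows, see the article above for the derivations:

SGD: $\theta_{t+1} = \theta_t - lr \cdot g_t$
Momentum: $v_t = \mu\, v_{t-1} + g_t, \quad \theta_{t+1} = \theta_t - lr \cdot v_t$
RMSProp: $v_t = \alpha\, v_{t-1} + (1-\alpha)\, g_t^2, \quad \theta_{t+1} = \theta_t - \dfrac{lr}{\sqrt{v_t} + \epsilon}\, g_t$
Adam: $m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \quad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \quad \theta_{t+1} = \theta_t - lr\, \dfrac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \dfrac{m_t}{\sqrt{v_t} + \epsilon}$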

Main text

1. Optimizer classes

File: vanilla_nn/optim.py

import numpy as np


class SGD:
    """Vanilla stochastic gradient descent: params -= lr * grads."""
    def __init__(self, lr=0.01):
        self.lr = lr

    def __call__(self, params, grads):
        params -= self.lr * grads


class Momentum:
    """SGD with momentum, as in torch.optim.SGD(momentum=...): v = momentum * v + grad."""
    def __init__(self, lr=0.01, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.v = None

    def __call__(self, params, grads):
        if self.v is None:
            self.v = np.zeros_like(params)

        self.v = self.momentum * self.v + grads
        params -= self.lr * self.v


class RMSProp:
    """RMSProp, as in torch.optim.RMSprop: v = alpha * v + (1 - alpha) * grad ** 2."""
    def __init__(self, lr=0.01, alpha=0.9, eps=1e-08):
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        self.v = None

    def __call__(self, params, grads):
        if self.v is None:
            self.v = np.zeros_like(params)

        self.v = self.alpha * self.v
        self.v += (1 - self.alpha) * np.square(grads)
        eta = self.lr / (np.sqrt(self.v) + self.eps)
        params -= eta * grads


class Adam:
    """Adam, with both bias corrections folded into the step size."""
    def __init__(self, lr=0.01, betas=(0.9, 0.999), eps=1e-08):
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.m = None
        self.v = None
        self.n = 0

    def __call__(self, params, grads):
        if self.m is None:
            self.m = np.zeros_like(params)
        if self.v is None:
            self.v = np.zeros_like(params)

        self.n += 1

        self.m = self.beta1 * self.m + (1 - self.beta1) * grads
        self.v = self.beta2 * self.v + (1 - self.beta2) * np.square(grads)

        # Fold both bias corrections into the step size (the Adam paper's
        # efficient form): alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
        alpha = self.lr * np.sqrt(1 - np.power(self.beta2, self.n))
        alpha = alpha / (1 - np.power(self.beta1, self.n))

        params -= alpha * self.m / (np.sqrt(self.v) + self.eps)
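
Before the check against PyTorch, here is a minimal usage sketch of the classes above: it minimises the quadratic f(w) = (w - 3)^2 with each optimizer, computing the gradient df/dw = 2(w - 3) by hand. The helper name run, the target value 3.0, and the step count are illustrative choices, not part of the original code.

import numpy as np
from vanilla_nn.optim import SGD, Momentum, RMSProp, Adam


def run(optimizer, steps=200):
    # Minimize f(w) = (w - 3)^2; the analytic gradient is df/dw = 2 * (w - 3).
    w = np.array(0.0)
    for _ in range(steps):
        grad = 2 * (w - 3.0)
        optimizer(w, grad)  # each optimizer updates w in place
    return w


for opt in (SGD(0.1), Momentum(0.1, 0.9), RMSProp(0.1, 0.9), Adam(lr=0.1)):
    print(type(opt).__name__, run(opt))

All four runs should end up roughly at 3.0, which serves as a quick sanity check independent of PyTorch.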

2. Checking the optimizers

The check below optimizes a single parameter p of y = p * x^2, updating p after every sample x[i] with both the NumPy optimizer and the corresponding PyTorch optimizer, and then prints the gradients dy/dx = 2 * p * x collected along the way. The two printed rows match only if both optimizers produce the same sequence of parameter values.

import torch
import numpy as np
from vanilla_nn.optim import SGD, Momentum, RMSProp, Adam


def check_optim(optim_numpy, optim_torch, p, p_torch):
    """
    Check an optimizer pair on y = p * x^2, where p is the parameter being optimized.
    The recorded gradients dy/dx = 2 * p * x match only if both optimizers
    update p identically at every step.
    """
    x_size = 5
    x = np.random.random(x_size)
    x_torch = torch.tensor(x, requires_grad=True)

    dxi_numpy_list = []
    for i in range(x_size):
        yi_numpy = p * x[i] ** 2
        dxi_numpy = 2 * p * x[i]
        dxi_numpy_list.append(dxi_numpy)

        da = x[i] ** 2  # dy/dp, used to update p in place
        optim_numpy(p, da)

    for i in range(x_size):
        yi_torch = p_torch * x_torch[i] ** 2
        optim_torch.zero_grad()  # clears p_torch.grad only; x_torch.grad keeps one entry per i
        yi_torch.backward()
        optim_torch.step()

    print(np.array(dxi_numpy_list))
    print(x_torch.grad.data.numpy())


np.random.seed(123)
np.set_printoptions(precision=12, suppress=True, linewidth=80)

print("--- 检查SGD ---")
a_numpy = np.array(1.2)
a_torch = torch.tensor(a_numpy, requires_grad=True)
sgd_numpy = SGD(0.1)
sgd_torch = torch.optim.SGD([a_torch], lr=0.1)
check_optim(sgd_numpy, sgd_torch, a_numpy, a_torch)

print("--- 检查Momentum ---")
a_numpy = np.array(1.2)
a_torch = torch.tensor(a_numpy, requires_grad=True)
momentum_numpy = Momentum(0.1, 0.9)
momentum_torch = torch.optim.SGD([a_torch], lr=0.1, momentum=0.9)
check_optim(momentum_numpy, momentum_torch, a_numpy, a_torch)

print("--- 检查RMSProp ---")
a_numpy = np.array(1.2)
a_torch = torch.tensor(a_numpy, requires_grad=True)
rms_numpy = RMSProp(0.1, 0.9, eps=1e-08)
rms_torch = torch.optim.RMSprop([a_torch], lr=0.1, alpha=0.9)
check_optim(rms_numpy, rms_torch, a_numpy, a_torch)

print("--- 检查Adam ---")
a_numpy = np.array(1.2)
a_torch = torch.tensor(a_numpy, requires_grad=True)
adam_numpy = Adam(lr=0.1, betas=(0.9, 0.99), eps=0.001)
adam_torch = torch.optim.Adam([a_torch], lr=0.1, betas=(0.9, 0.99), eps=0.001)
check_optim(adam_numpy, adam_torch, a_numpy, a_torch)

"""
--- Check SGD ---
[ 1.671526045435  0.658974920984  0.518721027022  1.254968104394  1.594004424417]
[ 1.671526045435  0.658974920984  0.518721027022  1.254968104394  1.594004424417]
--- Check Momentum ---
[ 1.015455504299  2.318718975892  1.46525696166   0.886671018481  0.600349866658]
[ 1.015455504299  2.318718975892  1.46525696166   0.886671018481  0.600349866658]
--- Check RMSProp ---
[ 0.823627238762  1.288627900967  0.503750904131  0.055347017535  0.367439505846]
[ 0.823627238762  1.288627900967  0.503750904131  0.055347017535  0.367439505846]
--- Check Adam ---
[ 1.771188973757  0.402139865001  0.361961132239  1.035028538823  0.96262110244 ]
[ 1.771188973757  0.402139865001  0.361961132239  1.035028538823  0.96262110244 ]
"""