
Implementing an RNN cell

程序员文章站 2024-03-25 10:45:16


flyfish

The code below compiles and runs as-is.

Steps:
1. Compute the next hidden state with the tanh activation function:
   a_t = tanh(W_aa · a_{t-1} + W_ax · x_t + b_a)
2. Use the new hidden state a_t to compute the prediction:
   ŷ_t = softmax(W_ya · a_t + b_y)
   (the softmax function is provided below)
3. Store (a_t, a_{t-1}, x_t, parameters) in cache
4. Return a_t, ŷ_t, cache

import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def rnn_cell_forward(xt, a_prev, parameters):

    # Retrieve parameters from "parameters",
    # a dictionary mapping parameter names to their values
    Wax = parameters["Wax"]
    Waa = parameters["Waa"]
    Wya = parameters["Wya"]
    ba = parameters["ba"]
    by = parameters["by"]

    # compute the next hidden state with np.tanh, following the formula above
    a_next = np.tanh(np.dot(Wax, xt) + np.dot(Waa, a_prev) + ba)
    # compute the prediction from the new hidden state
    yt_pred = softmax(np.dot(Wya, a_next) + by)


    # store the values needed for backward propagation in cache
    cache = (a_next, a_prev, xt, parameters)

    return a_next, yt_pred, cache


np.random.seed(1)
xt = np.random.randn(3,10)
a_prev = np.random.randn(5,10)
Waa = np.random.randn(5,5)
Wax = np.random.randn(5,3)
Wya = np.random.randn(2,5)
ba = np.random.randn(5,1)
by = np.random.randn(2,1)
parameters = {"Waa": Waa, "Wax": Wax, "Wya": Wya, "ba": ba, "by": by}

a_next, yt_pred, cache = rnn_cell_forward(xt, a_prev, parameters)
print("a_next = ", a_next)
print("a_next.shape = ", a_next.shape)
print("yt_pred[1] =", yt_pred[1])
print("yt_pred.shape = ", yt_pred.shape)