Deep learning library Trax: a simple example (Trax Quick Intro)

程序员文章站 (Programmer Article Site), 2022-07-14 14:52:19

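Trax is Google's deep learning library. Its short tutorial, Trax Quick Intro, can only be viewed via *, so I am reposting it here. It is a brief introduction to training a Transformer and using it for prediction: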

Install and import the packages:

! pip install -q -U trax
! pip install -q tensorflow

import os
import numpy as np
import trax

Generate synthetic training data:

# Construct inputs, see one batch
def copy_task(batch_size, vocab_size, length):
  """This task is to copy a random string w, so the input is 0w0w."""
  while True:
    assert length % 2 == 0
    w_length = (length // 2) - 1
    w = np.random.randint(low=1, high=vocab_size-1,
                          size=(batch_size, w_length))
    zero = np.zeros([batch_size, 1], np.int32)
    loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)),
                                   np.ones((batch_size, w_length))], axis=1)
    x = np.concatenate([zero, w, zero, w], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.
copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))

# Peek into the inputs.
data_stream = copy_inputs.train_stream(1)
inputs, targets, mask = next(data_stream)
print("Inputs[0]:  %s" % str(inputs[0]))
print("Targets[0]: %s" % str(targets[0]))
print("Mask[0]:    %s" % str(mask[0]))

Inputs[0]:  [ 0  6 13 29 22  0  6 13 29 22]
Targets[0]: [ 0  6 13 29 22  0  6 13 29 22]
Mask[0]:    [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
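Two things are implicit in copy_task: which positions the loss is computed on, and how a language model can learn from a batch whose inputs and targets are identical. As far as I know, Trax's TransformerLM shifts its input right by one position internally (a ShiftRight layer) and uses causal attention, so position t only sees the tokens before x[t]. The pure-Python sketch below illustrates both ideas on a single unbatched row; `copy_example` and `shift_right` are hypothetical helpers, not Trax functions:

```python
def copy_example(w, length):
    """One 0w0w row plus its loss-weight row, built like copy_task above
    (pure Python; w is passed in instead of being sampled)."""
    assert length % 2 == 0
    w_length = length // 2 - 1
    assert len(w) == w_length
    x = [0] + list(w) + [0] + list(w)
    # Same split as copy_task: w_length + 2 zeros, then w_length ones.
    weights = [0.0] * (w_length + 2) + [1.0] * w_length
    return x, weights


def shift_right(tokens, pad=0):
    """Shift one sequence right by a single position, padding the front."""
    return [pad] + list(tokens[:-1])


x, weights = copy_example([6, 13, 29, 22], length=10)
model_inputs = shift_right(x)
# Position t sees model_inputs[:t + 1] (causal attention) and must predict x[t];
# only positions whose weight is 1.0 contribute to the loss.
scored = [(model_inputs[: t + 1], x[t]) for t in range(len(x)) if weights[t] > 0]
for seen, target in scored:
    print(seen, "->", target)
# first line printed: [0, 0, 6, 13, 29, 22, 0] -> 6
```

With length=10 this split gives six zeros followed by four ones, so only the second copy of w is scored, and each of those tokens is predictable because the first copy has already been seen. The Mask[0] printed above has five zeros and five ones instead, which corresponds to a w_length + 1 / w_length + 1 split (also scoring the second 0); check which split the version of copy_task you run actually uses.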

Training the Transformer:

# Transformer LM
def tiny_transformer_lm(mode):
  return trax.models.TransformerLM(   # You can try trax.models.ReformerLM too.
    d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

# Train tiny model with Trainer.
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model.
trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.
    inputs=copy_inputs,
    output_dir=output_dir,
    has_weights=True)  # Each batch carries a loss mask as its 3rd element; this API may change.

# Train for 3 epochs each consisting of 500 train batches, eval on 2 batches.
n_epochs = 3
train_steps = 500
eval_steps = 2
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)

Step 500: Ran 500 train steps in 16.51 secs
Step 500: Evaluation
Step 500: train accuracy | 0.53125000
Step 500: train loss | 1.83887446
Step 500: train neg_log_perplexity | -1.83887446
Step 500: train weights_per_batch_per_core | 80.00000000
Step 500: eval accuracy | 0.52500004
Step 500: eval loss | 1.92791247
Step 500: eval neg_log_perplexity | -1.92791247
Step 500: eval weights_per_batch_per_core | 80.00000000
Step 500: Finished evaluation
Step 1000: Ran 500 train steps in 2.54 secs
Step 1000: Evaluation
Step 1000: train accuracy | 1.00000000
Step 1000: train loss | 0.00707983
Step 1000: train neg_log_perplexity | -0.00707983
Step 1000: train weights_per_batch_per_core | 80.00000000
Step 1000: eval accuracy | 1.00000000
Step 1000: eval loss | 0.01029818
Step 1000: eval neg_log_perplexity | -0.01029818
Step 1000: eval weights_per_batch_per_core | 80.00000000
Step 1000: Finished evaluation
Step 1500: Ran 500 train steps in 2.46 secs
Step 1500: Evaluation
Step 1500: train accuracy | 1.00000000
Step 1500: train loss | 0.00037777
Step 1500: train neg_log_perplexity | -0.00037777
Step 1500: train weights_per_batch_per_core | 80.00000000
Step 1500: eval accuracy | 1.00000000
Step 1500: eval loss | 0.00037660
Step 1500: eval neg_log_perplexity | -0.00037660
Step 1500: eval weights_per_batch_per_core | 80.00000000
Step 1500: Finished evaluation
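In the log, neg_log_perplexity is exactly the negative of loss, and because has_weights=True the metrics are presumably averaged only over positions whose mask weight is nonzero. The sketch below computes a weighted cross-entropy and a weighted accuracy in pure Python under that reading; `masked_metrics` is a hypothetical helper that approximates what the log reports, not Trax's implementation:

```python
import math


def masked_metrics(logits, targets, weights):
    """Weighted cross-entropy and accuracy over a flat list of positions.

    logits:  per-position lists of unnormalized scores (one per vocab id)
    targets: per-position target ids
    weights: per-position loss weights (0.0 = ignored, like the loss mask)
    """
    total_weight = sum(weights)
    loss = 0.0
    correct = 0.0
    for scores, target, weight in zip(logits, targets, weights):
        # Numerically stable log-sum-exp for the softmax normalizer.
        peak = max(scores)
        log_z = peak + math.log(sum(math.exp(s - peak) for s in scores))
        loss += weight * (log_z - scores[target])  # -log p(target)
        predicted = scores.index(peak)  # argmax; ties go to the lowest id
        correct += weight * (predicted == target)
    loss /= total_weight
    return {
        "loss": loss,
        "neg_log_perplexity": -loss,
        "perplexity": math.exp(loss),
        "accuracy": correct / total_weight,
        "total_weight": total_weight,
    }


# Three positions; the first is masked out, so its (correct) prediction
# does not count towards accuracy.
metrics = masked_metrics(
    logits=[[5.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]],
    targets=[0, 1, 0],
    weights=[0.0, 1.0, 1.0],
)
print(metrics)
```

Perplexity is exp(loss), so the final eval loss of about 0.00038 corresponds to a perplexity of about 1.0004: the model is almost certain of every scored token.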

Running prediction with the model:

# Initialize model for inference.
predict_model = tiny_transformer_lm(mode='predict')
predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)
predict_model.init(predict_signature)
predict_model.init_from_file(os.path.join(output_dir, "model.pkl"),
                             weights_only=True)
# You can also do: predict_model.weights = trainer.model_weights
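tiny_transformer_lm(mode='predict') builds the same architecture for step-by-step decoding: it is initialized with a (1, 1) int32 signature, i.e. batch size 1 and one token per call. My understanding is that in this mode Trax keeps a cache of earlier positions between calls instead of re-reading the whole prefix. The toy sketch below only shows the property this relies on, namely that feeding tokens one at a time with a cache gives the same outputs as processing the whole sequence at once; `ToyCausalModel` is entirely hypothetical, and a running sum stands in for attention over the prefix:

```python
class ToyCausalModel:
    """Hypothetical stand-in for a causal LM: its 'output' at position t is the
    running sum of tokens 0..t, a crude proxy for attending over the prefix."""

    def __init__(self):
        self.cache = 0  # state carried between calls, like a predict-mode cache

    def full(self, tokens):
        """Whole sequence in one call (how training/eval see the data)."""
        outputs, running = [], 0
        for token in tokens:
            running += token
            outputs.append(running)
        return outputs

    def step(self, token):
        """One token per call (like an input of shape (1, 1) in predict mode)."""
        self.cache += token
        return self.cache


sequence = [0, 1, 2, 3, 4]
model = ToyCausalModel()
incremental = [model.step(token) for token in sequence]
print(incremental, model.full(sequence))  # both [0, 1, 3, 6, 10]
```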

# Run inference
prefix = [0, 1, 2, 3, 4, 0]   # Change non-0 digits to see if it's copying
cur_input = np.array([[0]])
result = []
for i in range(10):
  logits = predict_model(cur_input)
  next_input = np.argmax(logits[0, 0, :], axis=-1)
  if i < len(prefix) - 1:
    next_input = prefix[i]
  cur_input = np.array([[next_input]])
  result.append(int(next_input))  # Append to the result
print(result)

[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
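The inference loop is greedy decoding with a forced prefix: for the first len(prefix) - 1 steps the model's argmax is discarded and the next prefix token is fed instead; after that, the model's own predictions are fed back in. The last 0 in prefix is never forced, so the model itself has to produce the separator before copying. Below is a pure-Python sketch of that control flow only; `decode` and `copy_oracle` are hypothetical helpers, and the oracle is a hand-written perfect copier standing in for the trained Transformer:

```python
def decode(next_token_fn, prefix, steps, start=0):
    """Greedy decoding with a forced prefix, mirroring the Trax loop above.

    next_token_fn(history) -> int is a hypothetical stand-in for
    np.argmax(predict_model(cur_input)[0, 0, :]); history is every token fed
    so far. It is called on every step, as in the Trax loop (presumably so the
    predict-mode cache sees every token), even when its answer is overridden.
    """
    history = [start]  # the lone 0 fed in before anything is generated
    result = []
    for i in range(steps):
        token = next_token_fn(history)
        if i < len(prefix) - 1:  # still inside the forced prefix
            token = prefix[i]
        history.append(token)
        result.append(token)
    return result


def copy_oracle(history):
    """Hypothetical perfect copier for 0w0w sequences (not a trained model)."""
    seq = history[1:]  # drop the start token
    if 0 in seq[1:]:
        second_zero = seq.index(0, 1)
        return seq[len(seq) - second_zero]  # repeat w from its beginning
    # Assumes the forced prefix already supplied 0 and all of w,
    # so the next token is the separating 0.
    return 0


print(decode(copy_oracle, [0, 1, 2, 3, 4, 0], steps=10))
# [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```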

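This is a very simple example. The Trax library is still under active development and its code is clear and readable; if you are interested, it is worth following.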
