Xiao Hei's NLP Baseline Growth Diary 1: A PyTorch Exercise on Skip-Gram + NEG
import numpy as np
from collections import deque
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
def ArgumentParser():  # basic hyperparameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='skip-gram', help='skip-gram or cbow')
    parser.add_argument('--window_size', type=int, default=3, help='window size in word2vec')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size during training phase')
    parser.add_argument('--min_count', type=int, default=3, help='min count of training word')
    parser.add_argument('--embed_dimension', type=int, default=100, help='embedding dimension of word embedding')
    parser.add_argument('--learning_rate', type=float, default=0.02, help='learning rate during training phase')
    parser.add_argument('--neg_count', type=int, default=5, help='number of negative samples per positive pair')
    return parser.parse_known_args()[0]
args = ArgumentParser()
WINDOW_SIZE = args.window_size        # context window size c
BATCH_SIZE = args.batch_size          # mini-batch size
MIN_COUNT = args.min_count            # frequency threshold: rarer words are dropped
EMB_DIMENSION = args.embed_dimension  # embedding dimension
LR = args.learning_rate               # learning rate
NEG_COUNT = args.neg_count            # number of negative samples
# Data input class
class InputData:
    def __init__(self, input_file_name, min_count):
        self.input_file_name = input_file_name
        self.index = 0  # index of the current center word
        self.input_file = open(self.input_file_name, 'r', encoding='utf-8')  # input document
        self.min_count = min_count  # minimum word frequency
        self.wordid_frequency_dict = dict()
        self.word_count = 0  # vocabulary size
        self.word_count_sum = 0  # total number of tokens in the document
        self.sentence_count = 0
        self.id2word_dict = dict()
        self.word2id_dict = dict()
        self._init_dict()  # build the dictionaries
        self.sample_table = []  # table of word ids used for negative sampling
        self._init_sample_table()  # build the negative-sampling table
        self.get_wordId_list()  # map the document to a list of word ids
        self.word_pairs_queue = deque()  # queue of word pairs, for batching
        # summary
        print('Word Count is:', self.word_count)
        print('Word Count Sum is:', self.word_count_sum)
        print('Sentence Count is:', self.sentence_count)
    def _init_dict(self):
        print('Initializing word_freq...')
        word_freq = dict()
        for line in self.input_file:
            line = line.strip().split()
            self.word_count_sum += len(line)
            self.sentence_count += 1
            for i, word in enumerate(line):
                if i % 1000000 == 0:
                    print(i, len(line))
                if word_freq.get(word) is None:
                    word_freq[word] = 1
                else:
                    word_freq[word] += 1
        print('Initializing word2id_dict, id2word_dict, wordid_frequency_dict...')
        for i, word in enumerate(word_freq):
            if i % 100000 == 0:
                print(i, len(word_freq))
            if word_freq[word] < self.min_count:  # drop words below the frequency threshold
                self.word_count_sum -= word_freq[word]
                continue
            word_id = len(self.word2id_dict)  # assign ids consecutively from 0
            self.word2id_dict[word] = word_id
            self.id2word_dict[word_id] = word
            self.wordid_frequency_dict[word_id] = word_freq[word]
        self.word_count = len(self.word2id_dict)
    def _init_sample_table(self):
        sample_table_size = 1e8
        # raise raw frequencies to the 3/4 power before normalizing
        pow_frequency = np.array(list(self.wordid_frequency_dict.values())) ** 0.75
        word_pow_sum = sum(pow_frequency)
        ratio_array = pow_frequency / word_pow_sum
        word_count_list = np.round(ratio_array * sample_table_size)
        for word_index, word_freq in enumerate(word_count_list):
            self.sample_table += [word_index] * int(word_freq)
        self.sample_table = np.array(self.sample_table)
        np.random.shuffle(self.sample_table)
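    # Each word id occupies a slice of the table proportional to its smoothed
    # unigram probability, the noise distribution from the original word2vec paper:
    #     P(w_i) = f(w_i)^0.75 / sum_j f(w_j)^0.75
    # Drawing uniformly from the table is then equivalent to drawing from this
    # distribution; the 0.75 exponent boosts rare words relative to raw frequency.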
    def get_wordId_list(self):
        self.input_file = open(self.input_file_name, encoding='utf-8')
        sentence = self.input_file.readline()  # text8 is a single line, so one readline() covers the corpus
        wordId_list = []  # ids of every word in the sentence
        sentence = sentence.strip().split(' ')
        print('Building wordId_list...')
        for i, word in enumerate(sentence):
            if i % 1000000 == 0:
                print(i, len(sentence))
            try:
                word_id = self.word2id_dict[word]
                wordId_list.append(word_id)
            except KeyError:  # word was filtered out by min_count
                continue
        self.wordId_list = wordId_list
    def get_batch_pairs(self, batch_size, window_size):
        while len(self.word_pairs_queue) < batch_size:
            for _ in range(1000):  # refill the queue 1000 center words at a time
                if self.index == len(self.wordId_list):
                    self.index = 0  # wrap around to the start of the corpus
                wordId_w = self.wordId_list[self.index]
                # +1 so the window covers the full range [index - c, index + c]
                for i in range(max(self.index - window_size, 0),
                               min(self.index + window_size + 1, len(self.wordId_list))):
                    wordId_v = self.wordId_list[i]
                    if self.index == i:  # skip when context == center word
                        continue
                    self.word_pairs_queue.append((wordId_w, wordId_v))
                self.index += 1
        result_pairs = []  # a mini-batch of positive pairs
        for _ in range(batch_size):
            result_pairs.append(self.word_pairs_queue.popleft())
        return result_pairs
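    # As a worked example (toy ids, window_size = 1): the id sequence [0, 1, 2, 3]
    # fed through the loop above yields the positive pairs
    # (0,1), (1,0), (1,2), (2,1), (2,3), (3,2) --
    # every (center, context) combination within one position of each other.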
    # Negative sampling: given the array of positive pairs positive_pairs and the number
    # of negatives per positive pair neg_count, draw negative word ids from the sampling table.
    # (Assuming the corpus is large, we ignore the small chance that a negative equals the positive.)
    def get_negative_sampling(self, positive_pairs, neg_count):
        neg_v = np.random.choice(self.sample_table, size=(len(positive_pairs), neg_count)).tolist()
        return neg_v
    # Estimate the number of positive pairs in the corpus, used to size the training loop
    def evaluate_pairs_count(self, window_size):
        return self.word_count_sum * (2 * window_size) - self.sentence_count * (1 + window_size) * window_size
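# The estimate above is a simple count: every token contributes up to 2c context
# pairs, while each sentence loses about c(c+1) pairs at its two boundaries
# (1 + 2 + ... + c on each side), so with N tokens, S sentences and window size c:
#     pairs ~= 2cN - c(c+1)S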
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(SkipGramModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.w_embeddings = nn.Embedding(vocab_size, embed_size)  # center-word vectors
        self.v_embeddings = nn.Embedding(vocab_size, embed_size)  # context-word vectors
        self._init_emb()
    def _init_emb(self):
        initrange = 0.5 / self.embed_size
        self.w_embeddings.weight.data.uniform_(-initrange, initrange)
        self.v_embeddings.weight.data.zero_()  # context vectors start at zero, as in word2vec
    def forward(self, pos_w, pos_v, neg_v):
        emb_w = self.w_embeddings(torch.LongTensor(pos_w))      # [batch_size, emb_dim]
        emb_v = self.v_embeddings(torch.LongTensor(pos_v))      # [batch_size, emb_dim]
        neg_emb_v = self.v_embeddings(torch.LongTensor(neg_v))  # [batch_size, neg_count, emb_dim]
        score = torch.mul(emb_w, emb_v)
        score = torch.sum(score, dim=1)
        score = torch.clamp(score, max=10, min=-10)  # clamp logits for numerical stability
        score = F.logsigmoid(score)
        neg_score = torch.bmm(neg_emb_v, emb_w.unsqueeze(2))  # [batch_size, neg_count, 1]
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = F.logsigmoid(-1 * neg_score)
        loss = -torch.sum(score) - torch.sum(neg_score)
        return loss
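    # The forward pass above is the skip-gram negative-sampling (SGNS) objective for a
    # batch of positive pairs (w, v), each with K sampled negatives n_1 .. n_K:
    #     L = -sum_{(w,v)} [ log sigmoid(u_w . v_v) + sum_{k=1..K} log sigmoid(-u_w . v_{n_k}) ]
    # where u are the center vectors (w_embeddings) and v the context vectors
    # (v_embeddings); the clamp to [-10, 10] only guards against sigmoid saturation.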
    def save_embedding(self, id2word, file_name):
        # average the center and context vectors and write them in word2vec text format
        embedding_1 = self.w_embeddings.weight.data.cpu().numpy()
        embedding_2 = self.v_embeddings.weight.data.cpu().numpy()
        embedding = (embedding_1 + embedding_2) / 2
        fout = open(file_name, 'w')
        fout.write('%d %d\n' % (len(id2word), self.embed_size))
        for wid, w in id2word.items():
            e = embedding[wid]
            e = ' '.join(map(lambda x: str(x), e))
            fout.write('%s %s\n' % (w, e))
        fout.close()  # make sure everything is flushed to disk
class Word2Vec:
    def __init__(self, input_file_name, output_file_name):
        self.output_file_name = output_file_name
        self.data = InputData(input_file_name, MIN_COUNT)
        self.model = SkipGramModel(self.data.word_count, EMB_DIMENSION).cpu()
        self.lr = LR
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
    def train(self):
        print('SkipGram Training......')
        pairs_count = self.data.evaluate_pairs_count(WINDOW_SIZE)
        print('pairs_count', pairs_count)
        batch_count = pairs_count / BATCH_SIZE
        print('batch_count', batch_count)
        process_bar = tqdm(range(int(5 * batch_count)))  # roughly 5 epochs over the estimated pairs
        for i in process_bar:
            pos_pairs = self.data.get_batch_pairs(BATCH_SIZE, WINDOW_SIZE)
            pos_w = [int(pair[0]) for pair in pos_pairs]
            pos_v = [int(pair[1]) for pair in pos_pairs]
            neg_v = self.data.get_negative_sampling(pos_pairs, NEG_COUNT)
            self.optimizer.zero_grad()
            loss = self.model.forward(pos_w, pos_v, neg_v)
            loss.backward()
            self.optimizer.step()
            process_bar.set_postfix(loss=loss.item())  # iterating the tqdm object already advances the bar
        torch.save(self.model.state_dict(), './test_skipgram_neg.pkl')
        self.model.save_embedding(self.data.id2word_dict, self.output_file_name)
w2v = Word2Vec(input_file_name='./word2vec/data/text8.txt', output_file_name="../results/skip_gram_neg.txt")
w2v.train()
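The output file uses the word2vec text format (a "vocab_size dim" header line followed by one "word vector" line per word), so the trained vectors can be inspected with any compatible tool. A minimal sketch, assuming gensim is installed and using the output path from the call above ('king' is just an example query; any word that survived the min_count filter works):

from gensim.models import KeyedVectors

# Load the embeddings written by save_embedding (word2vec text format).
vectors = KeyedVectors.load_word2vec_format('../results/skip_gram_neg.txt', binary=False)
# Print the five nearest neighbours of a query word.
print(vectors.most_similar('king', topn=5))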
Original post: https://blog.csdn.net/qq_37418807/article/details/109955261