Xiaohei's NLP Baseline Diary 1: A PyTorch Exercise on Skip-Gram + NEG

import numpy as np
from collections import deque
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
def ArgumentParser():    # basic hyper-parameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',type = str,default = 'skip-gram',help = 'skip-gram or cbow')
    parser.add_argument('--window_size',type = int,default = 3,help = 'window size in word2vec')
    parser.add_argument('--batch_size',type = int,default = 256,help = 'batch size during training phase')
    parser.add_argument('--min_count',type = int,default = 3,help = 'min count of training word')
    parser.add_argument('--embed_dimension',type = int,default = 100,help = 'embedding dimension of word embedding')
    parser.add_argument('--learning_rate',type = float,default = 0.02,help = 'learning rate during training phase')
    parser.add_argument('--neg_count',type = int,default = 5,help = 'number of negative samples per positive pair')
    return parser.parse_known_args()[0]
args = ArgumentParser()
WINDOW_SIZE = args.window_size    # context window size c
BATCH_SIZE = args.batch_size    # mini-batch size
MIN_COUNT = args.min_count    # minimum word frequency; rarer words are dropped
EMB_DIMENSION = args.embed_dimension    # embedding dimension
LR = args.learning_rate    # learning rate
NEG_COUNT = args.neg_count    # number of negative samples per positive pair
# Data input class: builds the vocabulary, the negative-sampling table, and (center, context) training pairs
class InputData:
    def __init__(self,input_file_name,min_count):
        self.input_file_name = input_file_name  
        self.index = 0    # index of the current center word
        self.input_file = open(self.input_file_name,'r',encoding = 'utf-8')    # input document
        self.min_count = min_count    # minimum word frequency
        self.wordid_frequency_dict = dict()    # word id -> frequency
        self.word_count = 0    # number of words in the vocabulary
        self.word_count_sum = 0     # total number of tokens in the document
        self.sentence_count = 0
        self.id2word_dict = dict()
        self.word2id_dict = dict()
        self._init_dict()    # initialize the dictionaries
        self.sample_table = []     # negative-sampling table
        self._init_sample_table()    # initialize the negative-sampling table
        self.get_wordId_list()    # map the document to a list of word ids
        self.word_pairs_queue = deque()    # queue of (center, context) pairs, used for batching
        # summary
        print('Word Count is:', self.word_count)
        print('Word Count Sum is', self.word_count_sum)
        print('Sentence Count is:', self.sentence_count)
    def _init_dict(self):
        print('Initializing word_freq...')
        word_freq = dict()
        for line in self.input_file:
            line = line.strip().split()
            self.word_count_sum += len(line)
            self.sentence_count += 1
            for i,word in enumerate(line):
                if i % 1000000 == 0:
                    print(i,len(line))
                if word_freq.get(word) is None:
                    word_freq[word] = 1
                else:
                    word_freq[word] += 1
        print('Initializing word2id_dict, id2word_dict, wordid_frequency_dict...')
        for i,word in enumerate(word_freq):
            if i % 100000 == 0:
                print(i,len(word_freq))
            if word_freq[word] < self.min_count:
                self.word_count_sum -= word_freq[word]
                continue
            self.word2id_dict[word] = len(self.word2id_dict)
            self.id2word_dict[len(self.word2id_dict) - 1] = word    # same id as word2id_dict[word]
            self.wordid_frequency_dict[len(self.word2id_dict)-1] = word_freq[word]
        self.word_count = len(self.word2id_dict)
    def _init_sample_table(self):
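        # Negative-sampling table as in word2vec: each word id is repeated in proportion to
        # its unigram frequency raised to the 3/4 power, so a uniform draw from the table
        # samples negatives from the smoothed unigram distribution.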
        sample_table_size = 1e8
        pow_frequency = np.array(list(self.wordid_frequency_dict.values())) ** 0.75
        word_pow_sum = sum(pow_frequency)
        ratio_array = pow_frequency / word_pow_sum
        word_count_list = np.round(ratio_array * sample_table_size)
        for word_index,word_freq in enumerate(word_count_list):
            self.sample_table += [word_index] * int(word_freq)
        self.sample_table = np.array(self.sample_table)
        np.random.shuffle(self.sample_table)
    def get_wordId_list(self):
        self.input_file = open(self.input_file_name,encoding = 'utf-8')
        sentence = self.input_file.readline()    # text8 is a single line, so this reads the whole corpus
        wordId_list = []    # ids of all words in the line
        sentence = sentence.strip().split(' ')
        print('Building wordId_list...')
        for i,word in enumerate(sentence):
            if i % 1000000 == 0:
                print(i,len(sentence))
            try:
                word_id = self.word2id_dict[word]
                wordId_list.append(word_id)
            except KeyError:    # word was filtered out by min_count
                continue
        self.wordId_list = wordId_list
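    # Slide a window over wordId_list and emit (center, context) pairs; pairs are buffered
    # in word_pairs_queue so mini-batches can simply be popped off the left end.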
    def get_batch_pairs(self,batch_size,window_size):
        while len(self.word_pairs_queue) < batch_size:
            for _ in range(1000):
                if self.index == len(self.wordId_list):
                    self.index = 0
                wordId_w = self.wordId_list[self.index]
                for i in range(max(self.index-window_size,0),min(self.index+window_size+1,len(self.wordId_list))):    # symmetric window: window_size words on each side
                    wordId_v = self.wordId_list[i]
                    if self.index == i:     # skip the center word itself
                        continue
                    self.word_pairs_queue.append((wordId_w,wordId_v))
                self.index += 1
        result_pairs = []    # collect a mini-batch of positive pairs
        for _ in range(batch_size):
            result_pairs.append(self.word_pairs_queue.popleft())
        return result_pairs   
    # Negative sampling: given the array of positive pairs positive_pairs and the number of
    # negatives per positive pair neg_count, draw negative word ids from the sampling table.
    # (Assuming the corpus is large, the small chance that a negative equals the positive is ignored.)
    def get_negative_sampling(self,positive_pairs,neg_count):
        neg_v = np.random.choice(self.sample_table,size = (len(positive_pairs),neg_count)).tolist()
        return neg_v
    # Estimate the number of positive pairs in the corpus (used to set the number of batches):
    # each token contributes up to 2*window_size pairs; at each end of every sentence the
    # boundary words lose 1+2+...+window_size = window_size*(window_size+1)/2 pairs.
    def evaluate_pairs_count(self,window_size):
        return self.word_count_sum * (2 * window_size) - self.sentence_count * (1 + window_size) * window_size
class SkipGramModel(nn.Module):
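    # Two embedding tables as in word2vec: w_embeddings holds the center-word ("input")
    # vectors and v_embeddings the context ("output") vectors; save_embedding averages them.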
    def __init__(self,vocab_size,embed_size):
        super(SkipGramModel,self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.w_embeddings = nn.Embedding(vocab_size,embed_size)
        self.v_embeddings = nn.Embedding(vocab_size,embed_size)
        self._init_emb()
    def _init_emb(self):
        initrange = 0.5 / self.embed_size
        self.w_embeddings.weight.data.uniform_(-initrange,initrange)
        self.v_embeddings.weight.data.zero_()    # context vectors start at zero
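    # SGNS loss for a mini-batch: for each positive pair (w, v) with neg_count sampled
    # negatives v_neg,
    #     loss = -log sigmoid(w . v) - sum_k log sigmoid(-(w . v_neg_k))
    # summed over the batch; dot products are clamped to [-10, 10] for numerical stability.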
    def forward(self,pos_w,pos_v,neg_v):
        emb_w = self.w_embeddings(torch.LongTensor(pos_w).cpu())    # [batch_size, emb_dim]
        emb_v = self.v_embeddings(torch.LongTensor(pos_v).cpu())    # [batch_size, emb_dim]
        neg_emb_v = self.v_embeddings(torch.LongTensor(neg_v).cpu())    # [batch_size, neg_count, emb_dim]
        score = torch.mul(emb_w,emb_v)    # element-wise product
        score = torch.sum(score,dim = 1)    # dot product w . v for each positive pair
        score = torch.clamp(score,max = 10,min = -10)
        score = F.logsigmoid(score)
        neg_score = torch.bmm(neg_emb_v,emb_w.unsqueeze(2))    # [batch_size, neg_count, 1], w . v_neg
        neg_score = torch.clamp(neg_score,max = 10,min = -10)
        neg_score = F.logsigmoid(-1 * neg_score)
        loss = - torch.sum(score) - torch.sum(neg_score)
        return loss
    def save_embedding(self,id2word,file_name):
        # average the center and context embeddings and write them in word2vec text format
        embedding_1 = self.w_embeddings.weight.data.cpu().numpy()
        embedding_2 = self.v_embeddings.weight.data.cpu().numpy()
        embedding = (embedding_1 + embedding_2) / 2
        fout = open(file_name,'w')
        fout.write('%d %d\n' % (len(id2word),self.embed_size))
        for wid,w in id2word.items():
            e = embedding[wid]
            e = ' '.join(map(lambda x:str(x),e))
            fout.write('%s %s\n' % (w,e))
        fout.close()
class Word2Vec:
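    # Training driver: wires InputData and SkipGramModel together, runs plain SGD over
    # mini-batches of positive pairs plus sampled negatives, then saves the model state
    # and the averaged embeddings.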
    def __init__(self,input_file_name,output_file_name):
        self.output_file_name = output_file_name
        self.data = InputData(input_file_name,MIN_COUNT)
        self.model = SkipGramModel(self.data.word_count,EMB_DIMENSION).cpu()
        self.lr = LR
        self.optimizer = optim.SGD(self.model.parameters(),lr = self.lr)
    def train(self):
        print('SkipGram Training......')
        pairs_count = self.data.evaluate_pairs_count(WINDOW_SIZE)
        print('pairs_count',pairs_count)
        batch_count = pairs_count / BATCH_SIZE
        print('batch_count',batch_count)
        process_bar = tqdm(range(int(5*batch_count)))    # roughly 5 passes over the estimated pairs
        for i in process_bar:
            pos_pairs = self.data.get_batch_pairs(BATCH_SIZE,WINDOW_SIZE)
            pos_w = [int(pair[0]) for pair in pos_pairs]
            pos_v = [int(pair[1]) for pair in pos_pairs]
            neg_v = self.data.get_negative_sampling(pos_pairs,NEG_COUNT)
            self.optimizer.zero_grad()
            loss = self.model(pos_w,pos_v,neg_v)    # calling the module invokes forward()
            loss.backward()
            self.optimizer.step()
            process_bar.set_postfix(loss = loss.item())    # the for-loop already advances the bar, no manual update() needed
        torch.save(self.model.state_dict(),'./test_skipgram_nge.pkl')
        self.model.save_embedding(self.data.id2word_dict,self.output_file_name)
w2v = Word2Vec(input_file_name='./word2vec/data/text8.txt', output_file_name="../results/skip_gram_neg.txt")
w2v.train()
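
The saved file uses the plain word2vec text format, so it can be inspected with an off-the-shelf reader. A minimal sketch, assuming gensim is installed and the output path above was produced (the query word 'king' is only an illustrative example, not from the original post):

# Hypothetical check, not part of the original script: load the saved
# word2vec-format file with gensim and look at nearest neighbours.
from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format('../results/skip_gram_neg.txt', binary=False)
print(vectors.most_similar('king', topn=5))    # example query word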

Original post: https://blog.csdn.net/qq_37418807/article/details/109955261