欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

seq2seq:LSTM+attention的生成式文本概要

程序员文章站 2022-07-01 18:10:51
...

seq2seq:LSTM+attention的生成式文本概要

最近做的利用seq2seq模型的生成式文本概要,参考了这位大佬的源码:
https://spaces.ac.cn/archives/5861/comment-page-1

数据集准备及预处理

我直接拿的新闻数据集的内容(content)和标题(title),根据内容概括标题。
一般想要达到比较能看的结果的话需要10w左右的数据集,跑50次迭代左右。
这种数据集网上很多,自己去找然后处理一下就好了。
数据集的预处理我是只保留了中文,去空格,最后所有文本都是连在一起的:

#正则表达式去除非中文字符
delCop = re.compile("[^\u4e00-\u9fa5]")
changeCop=re.compile("[^\u4e00-\u9fa5]")
for i in range(0, len(trainSet)):
    trainSet.iloc[i,1] = changeCop.sub(' ', delCop.sub('', trainSet.iloc[i,1]))
    trainSet.iloc[i,2] = changeCop.sub(' ', delCop.sub('', trainSet.iloc[i,2]))

生成式文本摘要与seq2seq

sequence2sequence就是利用一个encoder与一个decoder,将需要处理的原始文本投进encoder生成一个理论上的“中间码”,再有decoder解码输出为结果:
seq2seq:LSTM+attention的生成式文本概要

# encoder,双层双向LSTM
x = LayerNormalization()(x)
x = OurBidirectional(CuDNNLSTM(z_dim // 2, return_sequences=True))([x, x_mask])
x = LayerNormalization()(x)
x = OurBidirectional(CuDNNLSTM(z_dim // 2, return_sequences=True))([x, x_mask])
x_max = Lambda(seq_maxpool)([x, x_mask])

# decoder,双层单向LSTM
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])
y = CuDNNLSTM(z_dim, return_sequences=True)(y)
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])
y = CuDNNLSTM(z_dim, return_sequences=True)(y)
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])

最后也是在评价器中放了两个句子进行的调用解码进行的输出:

s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医 。'
s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。'

class Evaluate(Callback):
    def __init__(self):
        self.lowest = 1e10
    def on_epoch_end(self, epoch, logs=None):
        # 训练过程中观察一两个例子,显示标题质量提高的过程
        resStr=s1+'\n输出:'+gen_sent(s1)+'\n'+s2+'\n输出:'+gen_sent(s2)+'\n'
        with open('output', 'a',encoding='utf-8') as file_obj:
            file_obj.write(resStr)
        print(resStr)
        # 保存最优结果
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model.weights')

attention

attention是一种编码机制,用于形容词与词之间的注意力关系,比如下面这句话:
The animal didn’t cross the street because it was too tired

这句话中的"it"指的是什么?它指的是“animal”还是“street”?对于人来说,这其实是一个很简单的问题,但是对于一个算法来说,处理这个问题其实并不容易。self attention的出现就是为了解决这个问题,通过self attention,我们能将“it”与“animal”联系起来。
由于有时候一个词可能与多个词有较大关联,所以我们采用了一种叫做“多头”的策略。
比如上面的句子,it的注意力会集中在animal和tired身上。
具体可以参照这篇博文:https://blog.csdn.net/qq_43012160/article/details/100782291
著名的transformer和bert的词编码就是基于attention机制的

fit_generator

为什么那么多人训练模型的时候不用fit用fit_generator?
fit_generator中传入的不是数据集,而是一个数据生成器,如果你数据量非常大无法读入内存,fit就用不了了,但用fit_generator就只要传一个生成器(一个函数)进去。一般生成器每次选取batch_size个数据进行处理,处理完抛进模型训练,再处理后batch_size个数据。
生成器的数据生成与模型的训练还是并行的。

def data_generator():
    # 数据生成器
    X,Y = [],[]
    i=0
    while True:
        sentence=data.loc[i%dataLen]
        X.append(str2id(sentence['content']))
        Y.append(str2id(sentence['title'], start_end=True))
        i=i+1
        if len(X) == batch_size:
            X = np.array(padding(X))
            Y = np.array(padding(Y))
            yield [X,Y], None
            X,Y = [],[]
#模型训练
evaluator = Evaluate()

model.fit_generator(data_generator(),
                    steps_per_epoch=int(dataLen/batch_size),
                    epochs=epochs,
                    callbacks=[evaluator])

dataLen是数据的总长度,steps_per_epoch就是指每次迭代fit_generator执行的步数,dataLen=batch_size* steps_per_epoch。就是每步执行batch_size条数据。

放两条比较好的结果:
1.夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医 。
输出1:紫外线照射成长期伤害长期伤害
输出2:夏天来临天后护肤品要慎重要

2.8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。
输出1:连锁酒店用户数据泄露泄露
输出1:客户数据泄密牌酒店用户数据泄密店用户数据泄露

现在这个模型的vocab是单字的,后面打算用jieba分一下词,分完词就可以判断词之间的相似性,能判断相似性这种“数据泄密牌酒店用户数据泄密店用户数据泄露”的情况我就能把他检测出来然后做处理了。