中文纠错

一、任务

中文拼写检测(Chinese Spelling Check )旨在将句子中的错误拼写检测出来并进行纠正。这个任务是自然处理领域一个重要的任务。比如:当新闻中出现一些拼写错误或者不合适的词语使用,这个任务能帮助自动检测。常见的可能出现的错误包括相似字形出错、相似语义混淆出错等。

输入:一个包含错误的自然语言句子

输出:一个正确的自然语言句子

二、实现过程

2.1 seq2seq(加入注意力机制和中文预训练模型初始化的embedding)

使用eq2seq模型,将文本纠错任务转换为翻译任务,即输入带有错字的序列,输出纠正后的序列。

首先输入文本序列经过编码器变换为一个定长的背景变量,作为解码器的初始隐藏状态,该背景变量中包含了输入序列的所有编码信息,在该实验中,背景变量取值为输入序列有效长度最后一个隐藏状态。在给定框架的基础上,将编码器由LSTM变为双向的GRU,主要原因由两点:一是GRU只含有两个门控结构,在超参数全部调优的情况下,GRU的参数比LSTM少,计算速度快,二是采用双向的GRU,每个时间步的隐藏状态不再只取决于该时间步之前的状态,而是同时取决于该时间步的之前和之后的子序列,因此最后选用的背景变量的编码信息会更丰富。

self.encoder = nn.GRU(
            self.embedding_size, self.hidden_dim, batch_first=True, bidirectional=True)

解码器同样使用GRU,训练时,使用强制学习,解码器使用输入句子的编码信息和上个时间步的真实标签以及隐藏状态作为输入,在测试时,解码器使用输入句子的编码信息和当前时间步的输出以及隐藏状态作为输入。

self.decoder = nn.GRU(self.embedding_size,
                              self.hidden_dim, batch_first=True)

2.1.1 加入注意力机制

该实验中,在解码器中加入注意力机制,解码器在每一时间步调整对编码器所有时间步的隐藏状态做加权的权重,从而在不同时间步分别关注输入序列中的不同部分并编码进相应时间步的背景变量。

class Attention(nn.Module):
    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super(Attention, self).__init__()
        self.enc_hidden_size = encoder_hidden_size
        self.dec_hidden_size = decoder_hidden_size
        self.fc_in = nn.Linear(
            encoder_hidden_size*2, decoder_hidden_size, bias=False)
        self.fc_out = nn.Linear(
            encoder_hidden_size*2 + decoder_hidden_size, decoder_hidden_size)

    def forward(self, output, context):
        # output [batch, target_len, dec_hidden_size]
        # context [batch, source_len, enc_hidden_size*2]
        batch_size = output.size(0)
        y_len = output.size(1)
        x_len = context.size(1)
        # [batch_size * x_sentence_len, enc_hidden_size*2]
        x = context.contiguous().view(batch_size*x_len, -1)
        # [batch_size * x_len, dec_hidden_size]
        x = self.fc_in(x)  
        # [batch_size, x_sentence_len, dec_hidden_size]
        context_in = x.view(batch_size, x_len, -1)
        # [batch_size, y_sentence_len, x_sentence_len]
        atten = torch.bmm(output, context_in.transpose(1, 2))
        # [batch_size, y_sentence_len, x_sentence_len]
        atten = F.softmax(atten, dim=2)
        # [batch_size, y_sentence_len, enc_hidden_size*2]
        context = torch.bmm(atten, context)
        # [batch_size, y_sentence_len, enc_hidden_size*2+dec_hidden_size]
        output = torch.cat((context, output), dim=2)
        # [batch_size * y_sentence_len, enc_hidden_size*2+dec_hidden_size]
        output = output.contiguous().view(batch_size*y_len, -1)
        output = torch.tanh(self.fc_out(output))
        # [batch_size, y_sentence_len, dec_hidden_size]
        output = output.view(batch_size, y_len, -1)
        return output, atten

2.1.2 加入中文预训练模型初始化的embedding

预训练词嵌入在大数据集上训练时捕获单词的语义和句法意义,它们能够提高自然语言处理模型的性能,故考虑将原本模型中的embedding初始化替换为“bert-base-chinese”预训练模型的embedding。

self.in_tok_embed = self.bert.embeddings.to(self.device)
self.out_tok_embed = nn.Linear(self.embedding_size, dataset_size)
self.out_tok_embed.weight = copy.deepcopy(
    self.in_tok_embed.word_embeddings.weight)

但是由实验结果发现,用大型预训练模型初始化本实验的embedding,检测错误位置的f1值和纠正错字的f1值反而变差了,考虑原因可能是本实验的数据集过小,导致预测效果不好。

2.2 Bert

复现论文Spelling Error Correction with Soft-Masked BERT,由于BERT模型以15%的概率mask句中的每个字,因此直接用BERT模型倾向于不做任何改动直接输出相同的句子,故而Soft-Masked BERT提出使用两个网络结构,先通过纠错网络得到错误的分布,然后在同纠错网络纠错。

2.2.1 数据预处理

使用预训练模型的作为词典库(dataset)。

pretrained_tokenzier = BertTokenizer.from_pretrained("bert-base-chinese")
dataset = pretrained_tokenzier

2.2.2使用bert-base-chinese预训练模型

首先将输入序列BERT模型,通过训练得到的词嵌入向量表示,可以认为其即代表单词本身及其含义。

with torch.no_grad():
            p_ = self.em_bert(input_ids=input_ids, attention_mask=input_mask)[
                'last_hidden_state']

将上述得到输出输入到检测网络(detector),检测网络由一个双向GRU、一个线性函数和一个sigmoid非线性激活构成,最终得到每个字的错误概率。

class BiGRU(nn.Module):
    def __init__(self, embedding_size=768, hidden_size=128, n_layers=2, dropout=0.0):
        super(BiGRU, self).__init__()
        self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=n_layers,
                          bidirectional=True, dropout=dropout, batch_first=True)
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(hidden_size*2, 1)

    def forward(self, x):
        gru_out, _ = self.rnn(x)
        pi = self.sigmoid(self.linear(gru_out))
        return pi

每个词的最终的词嵌入表示由[‘mask’]的embedding表示以及本身的embedding加权平均表示构成,权重分别为p和(1-p)。p是检测网络的输出,p越大表示该字越有可能是错字,即[‘mask’]embed的权重越大。

e_ = p * self.mask_e + (1-p) * e_bert

纠错网络(corrector)是一个BERT模型加上一个残差结构,即将BERT模型的输出加上字原本的embedding在进行softmax计算,最终取预测到字典中概率最大的字的index作为该字的预测输出。以下是代码详细构建过程。

h = self.corrector(e_,
                           attention_mask=encoder_extended_attention_mask,
                           head_mask=head_mask,
                           encoder_hidden_states=encoder_hidden_states,
                           encoder_attention_mask=encoder_extended_attention_mask)
h = h[0] + e
h = self.linear(h)
out = self.softmax(h)
label = torch.where(input_ids != answer_input_ids, 1, 0)
p = p.reshape(-1, p.shape[1])
loss = self.count_loss(out, answer_input_ids, p, label)
decode_result = out.max(-1)[1]

损失值的计算是由检测网络和纠错网络的损失加权平均作为整个网络的损失,检测网络的损失函数是二分类交叉熵损失,纠错网络采用多分类交叉熵损失,权重分别为(1-gama)和gama,gama为0-1之间的小数,通常设置gama大于0.5,因为多分类比二分类问题复杂。

def count_loss(self, out_v, true_v, p, true_label):
        loss_d = self.criterion_d(p, true_label.float())
        loss_c = self.criterion_c(out_v.transpose(1, 2), true_v)
        loss = self.gama * loss_c + (1-self.gama) * loss_d
        return loss

2.2.2使用shibing624/macbert4csc-base-chinese预训练模型

模型下载地址:shibing624/macbert4csc-base-chinese · Hugging Face

在本实验中,我已经将该模型下载后放在10205501458/gectoolkit/properties/model/RNN文件夹下

修改softmaskedbert的模型结构:

检测网络不再使用bert.embedding,而是使用bert全部模型,并取消其梯度,防止参数过导致过拟合,将结果输入线性层之后sigmoid激活。

with torch.no_grad():
            e_bert = self.em_bert(input_ids=input_ids, attention_mask=input_mask)[
                'last_hidden_state']
 p = self.sigmoid(self.detector(e_bert))

同样,纠错网络也是使用整个bert模型的最后一层隐藏状态作为输入。

h = self.corrector(e_bert,
attention_mask=encoder_extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=input_mask)

损失函数不变,同上softmaskedbert。

  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2023 Yuqing He

请我喝杯奶茶吧~

支付宝
微信