由BERT到Bi-LSTM的知识蒸馏#
整体原理介绍#
本例是将特定任务下BERT模型的知识蒸馏到基于Bi-LSTM的小模型中,主要参考论文 Distilling Task-Specific Knowledge from BERT into Simple Neural Networks 实现。整体原理如下:
在本例中,较大的模型是BERT被称为教师模型,Bi-LSTM被称为学生模型。
小模型学习大模型的知识,需要小模型学习蒸馏相关的损失函数。在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。
在论文的模型蒸馏阶段,作者为了能让教师模型表达出更多的“暗知识”(dark knowledge,通常指分类任务中低概率类别与高概率类别的关系)供学生模型学习,对训练数据进行了数据增强。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。本文的作者使用了三种数据增强方式,分别是:
Masking,即以一定的概率将原数据中的word token替换成
[MASK]
;POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;
n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。
模型蒸馏步骤介绍#
本实验分为三个训练过程:在特定任务上对BERT进行微调、在特定任务上对基于Bi-LSTM的小模型进行训练(用于评价蒸馏效果)、将BERT模型的知识蒸馏到基于Bi-LSTM的小模型上。
1. 基于bert-base-uncased预训练模型在特定任务上进行微调#
训练BERT的fine-tuning模型,可以去 PaddleNLP 中的 glue 目录下对bert-base-uncased做微调。
以GLUE的SST-2任务为例,用bert-base-uncased做微调之后,可以得到一个在SST-2任务上的教师模型,可以把在dev上取得最好Accuracy的模型保存下来,用于第三步的蒸馏。
2. 训练基于Bi-LSTM的小模型#
在本示例中,小模型采取的是基于双向LSTM的分类模型,网络层分别是 Embedding
、LSTM
、 带有 tanh
激活函数的 Linear
层,最后经过一个全连接的输出层得到logits。LSTM
网络层定义如下:
self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers,
'bidirectional', dropout=dropout_prob)
基于Bi-LSTM的小模型的 forward
函数定义如下:
def forward(self, x, seq_len):
x_embed = self.embedder(x)
lstm_out, (hidden, _) = self.lstm(
x_embed, sequence_length=seq_len) # 双向LSTM
out = paddle.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
out = paddle.tanh(self.fc(out))
logits = self.output_layer(out)
return logits
3.数据增强介绍#
接下来的蒸馏过程,蒸馏时使用的训练数据集并不只包含数据集中原有的数据,而是按照上文原理介绍中的A、C两种方法进行数据增强后的总数据。
在多数情况下,alpha
会被设置为0,表示无视硬标签,学生模型只利用数据增强后的无标签数据进行训练。根据教师模型提供的软标签 teacher_logits
,对比学生模型的 logits
,计算均方误差损失。由于数据增强过程产生了更多的数据,学生模型可以从教师模型中学到更多的暗知识。
数据增强的核心代码如下:
def ngram_sampling(words, words_2=None, p_ng=0.25, ngram_range=(2, 6)):
if np.random.rand() < p_ng:
ngram_len = np.random.randint(ngram_range[0], ngram_range[1] + 1)
ngram_len = min(ngram_len, len(words))
start = np.random.randint(0, len(words) - ngram_len + 1)
words = words[start:start + ngram_len]
if words_2:
words_2 = words_2[start:start + ngram_len]
return words if not words_2 else (words, words_2)
def data_augmentation(data, whole_word_mask=whole_word_mask):
# 1. Masking
words = []
if not whole_word_mask:
tokenized_list = tokenizer.tokenize(data)
words = [
tokenizer.mask_token if np.random.rand() < p_mask else word
for word in tokenized_list
]
else:
for word in data.split():
words += [[tokenizer.mask_token]] if np.random.rand(
) < p_mask else [tokenizer.tokenize(word)]
# 2. N-gram sampling
words = ngram_sampling(words, p_ng=p_ng, ngram_range=ngram_range)
words = flatten(words) if isinstance(words[0], list) else words
new_text = " ".join(words)
return words, new_text
4.蒸馏模型#
这一步是将教师模型BERT的知识蒸馏到基于Bi-LSTM的学生模型中,在本例中,主要是让学生模型(Bi-LSTM)去学习教师模型的输出logits。蒸馏时使用的训练数据集是由上一步数据增强后的数据,核心代码如下:
ce_loss = nn.CrossEntropyLoss() # 交叉熵损失函数
mse_loss = nn.MSELoss() # 均方误差损失函数
for epoch in range(args.max_epoch):
for i, batch in enumerate(train_data_loader):
bert_input_ids, bert_segment_ids, student_input_ids, seq_len, labels = batch
# Calculate teacher model's forward.
with paddle.no_grad():
teacher_logits = teacher.model(bert_input_ids, bert_segment_ids)
# Calculate student model's forward.
logits = model(student_input_ids, seq_len)
# Calculate the loss, usually args.alpha equals to 0.
loss = args.alpha * ce_loss(logits, labels) + (
1 - args.alpha) * mse_loss(logits, teacher_logits)
loss.backward()
optimizer.step()