# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import tensor
import paddle.nn.functional as F
from paddle.nn import MultiHeadAttention, TransformerEncoderLayer, TransformerEncoder
from paddle.common_ops_import import convert_dtype
from paddlenlp.utils.log import logger
from paddlenlp.transformers import PPMiniLMForSequenceClassification
from paddlenlp.transformers import TinyBertForPretraining
from paddlenlp.transformers import BertForSequenceClassification
__all__ = ["to_distill", "calc_minilm_loss", "calc_multi_relation_loss"]
def calc_multi_relation_loss(loss_fct, s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0):
"""
    Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports
    the head-head relation, the sample-sample relation and the original
    token-token relation. The final loss value is balanced by the weights
    `alpha` and `beta`.

    Args:
        loss_fct (callable):
            Loss function for distillation. Only KL-divergence loss is
            supported for now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads`
            equals the original number of heads.
            Defaults to 0.
        alpha (float):
            The weight of the head-head relation loss.
            Defaults to 0.0.
        beta (float):
            The weight of the sample-sample relation loss.
            Defaults to 0.0.

    Returns:
        Tensor: Weighted sum of the token-token, head-head and
        sample-sample losses.
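
    Example:
        A minimal sketch with random tensors (shapes and values are
        illustrative only; in practice `s` and `t` are Q, K or V tensors
        collected from models patched by `to_distill`):

        .. code-block:: python

            import paddle

            loss_fct = paddle.nn.KLDivLoss(reduction="sum")
            # [batch_size, head_num, seq_len, head_dim]
            s = paddle.rand([2, 12, 8, 64])
            t = paddle.rand([2, 12, 8, 64])
            attn_mask = paddle.zeros([2, 1, 1, 8])
            loss = calc_multi_relation_loss(
                loss_fct, s, t, attn_mask, num_relation_heads=48,
                alpha=0.1, beta=0.1)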
"""
    # Regroup the heads so that the number of relation heads equals
    # `num_relation_heads` while the total hidden size stays unchanged.
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s: [bs, head_num, seq_len, head_dim] -> [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s: -> [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s1: [bs, num_relation_heads, seq_len, head_dim_new]
        s1 = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    else:
        # Keep the original heads. `s1` must still be defined, otherwise the
        # token-token branch below would hit an undefined name.
        s1 = s
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
    if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
        t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
        t1 = tensor.transpose(x=t, perm=[0, 2, 1, 3])
    else:
        t1 = t
        t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
    s_head_dim, t_head_dim = s.shape[3], t.shape[3]
    # When alpha + beta == 1.0, the token-token relation has zero weight and
    # its computation can be skipped.
    if alpha + beta == 1.0:
loss_token_token = 0.0
else:
scaled_dot_product_s1 = tensor.matmul(x=s1, y=s1, transpose_y=True) / math.sqrt(s_head_dim)
del s1
scaled_dot_product_s1 += attn_mask
scaled_dot_product_t1 = tensor.matmul(x=t1, y=t1, transpose_y=True) / math.sqrt(t_head_dim)
del t1
scaled_dot_product_t1 += attn_mask
loss_token_token = loss_fct(F.log_softmax(scaled_dot_product_s1), F.softmax(scaled_dot_product_t1))
if alpha == 0.0:
loss_head_head = 0.0
else:
scaled_dot_product_s = tensor.matmul(x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
attn_mask_head_head = tensor.transpose(x=attn_mask, perm=[0, 3, 1, 2])
scaled_dot_product_s += attn_mask_head_head
scaled_dot_product_t = tensor.matmul(x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
scaled_dot_product_t += attn_mask_head_head
loss_head_head = loss_fct(F.log_softmax(scaled_dot_product_s), F.softmax(scaled_dot_product_t))
if beta == 0.0:
loss_sample_sample = 0.0
else:
s2 = tensor.transpose(x=s, perm=[1, 2, 0, 3])
scaled_dot_product_s2 = tensor.matmul(x=s2, y=s2, transpose_y=True) / math.sqrt(s_head_dim)
del s, s2
# Shape: [seq_len, 1, batch_size, 1]
attn_mask_sample_sample = tensor.transpose(x=attn_mask, perm=[3, 1, 0, 2])
# Shape: [seq_len, head_num, batch_size, batch_size]
scaled_dot_product_s2 += attn_mask_sample_sample
t2 = tensor.transpose(x=t, perm=[1, 2, 0, 3])
scaled_dot_product_t2 = tensor.matmul(x=t2, y=t2, transpose_y=True) / math.sqrt(t_head_dim)
del t, t2
scaled_dot_product_t2 += attn_mask_sample_sample
loss_sample_sample = loss_fct(F.log_softmax(scaled_dot_product_s2), F.softmax(scaled_dot_product_t2))
return (1 - alpha - beta) * loss_token_token + alpha * loss_head_head + beta * loss_sample_sample
def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
"""
    Calculates the Q-Q, K-K and V-V relation loss from MiniLMv2.

    Args:
        loss_fct (callable):
            Loss function for distillation. Only KL-divergence loss is
            supported for now.
        s (Tensor):
            Q, K or V of the student.
        t (Tensor):
            Q, K or V of the teacher.
        attn_mask (Tensor):
            Attention mask for the relation.
        num_relation_heads (int):
            The number of relation heads. 0 means `num_relation_heads`
            equals the original number of heads.
            Defaults to 0.

Returns:
Tensor: MiniLM loss value.
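
    Example:
        A minimal sketch with random tensors (shapes are illustrative; the
        teacher may use a different number of heads and head size than the
        student, which is exactly what `num_relation_heads` reconciles):

        .. code-block:: python

            import paddle

            loss_fct = paddle.nn.KLDivLoss(reduction="sum")
            # Student: [2, 12, 8, 64], teacher: [2, 16, 8, 96].
            s = paddle.rand([2, 12, 8, 64])
            t = paddle.rand([2, 16, 8, 96])
            attn_mask = paddle.zeros([2, 1, 1, 8])
            loss = calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=48)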
"""
    # Regroup the heads so that the number of relation heads equals
    # `num_relation_heads` while the total hidden size stays unchanged.
    if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
        # s: [bs, head_num, seq_len, head_dim] -> [bs, seq_len, head_num, head_dim]
        s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
        # s: -> [bs, seq_len, num_relation_heads, head_dim_new]
        s = tensor.reshape(x=s, shape=[0, 0, num_relation_heads, -1])
        # s: -> [bs, num_relation_heads, seq_len, head_dim_new]
s = tensor.transpose(x=s, perm=[0, 2, 1, 3])
if num_relation_heads > 0 and num_relation_heads != t.shape[1]:
t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
t = tensor.reshape(x=t, shape=[0, 0, num_relation_heads, -1])
t = tensor.transpose(x=t, perm=[0, 2, 1, 3])
s_head_dim, t_head_dim = s.shape[3], t.shape[3]
scaled_dot_product_s = tensor.matmul(x=s, y=s, transpose_y=True) / math.sqrt(s_head_dim)
del s
scaled_dot_product_s += attn_mask
scaled_dot_product_t = tensor.matmul(x=t, y=t, transpose_y=True) / math.sqrt(t_head_dim)
del t
scaled_dot_product_t += attn_mask
loss = loss_fct(F.log_softmax(scaled_dot_product_s), F.softmax(scaled_dot_product_t))
return loss
def to_distill(self, return_qkv=False, return_attentions=False, return_layer_outputs=False, layer_index=-1):
"""
    Binds to an object that contains transformer encoder layers and makes the
    model expose the attributes `outputs.q`, `outputs.k`, `outputs.v`,
    `outputs.scaled_qks`, `outputs.hidden_states` and `outputs.attentions`
    for distillation.
    The intermediate tensors used by the MiniLM and TinyBERT strategies can
    then be collected from these attributes.
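
    Example:
        A minimal sketch (the checkpoint name and `layer_index` are
        illustrative; any model with a transformer encoder works the same
        way):

        .. code-block:: python

            import paddle
            from paddlenlp.transformers import BertForSequenceClassification

            teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased")
            teacher = to_distill(teacher, return_qkv=True, layer_index=8)
            input_ids = paddle.randint(0, 30522, shape=[2, 8])
            teacher(input_ids)
            q = teacher.outputs.q  # likewise `teacher.outputs.k` / `teacher.outputs.v`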
"""
logger.warning("`to_distill` is an experimental API and subject to change.")
MultiHeadAttention._forward = attention_forward
TransformerEncoderLayer._forward = transformer_encoder_layer_forward
TransformerEncoder._forward = transformer_encoder_forward
BertForSequenceClassification._forward = bert_forward
if return_qkv:
        # The `forward` function of the student class needs to be replaced
        # for distributed training (e.g. under `paddle.DataParallel`).
TinyBertForPretraining._forward = minilm_pretraining_forward
PPMiniLMForSequenceClassification._forward = minilm_pretraining_forward
else:
TinyBertForPretraining._forward = tinybert_forward
def init_func(layer):
if isinstance(
layer,
(
MultiHeadAttention,
TransformerEncoderLayer,
TransformerEncoder,
TinyBertForPretraining,
BertForSequenceClassification,
PPMiniLMForSequenceClassification,
),
):
layer.forward = layer._forward
if isinstance(layer, TransformerEncoder):
layer.return_layer_outputs = return_layer_outputs
layer.layer_index = layer_index
if isinstance(layer, MultiHeadAttention):
layer.return_attentions = return_attentions
layer.return_qkv = return_qkv
for layer in self.children():
layer.apply(init_func)
base_model_prefix = (
self._layers.base_model_prefix if isinstance(self, paddle.DataParallel) else self.base_model_prefix
)
    # For distributed training
if isinstance(self, paddle.DataParallel):
if hasattr(self._layers, base_model_prefix):
self.outputs = getattr(self._layers, base_model_prefix).encoder
else:
self.outputs = self._layers.encoder
else:
if hasattr(self, base_model_prefix):
self.outputs = getattr(self, base_model_prefix).encoder
else:
self.outputs = self.encoder
return self
def _convert_attention_mask(attn_mask, dtype):
if attn_mask is not None and attn_mask.dtype != dtype:
attn_mask_dtype = convert_dtype(attn_mask.dtype)
if attn_mask_dtype == "bool" or "int" in attn_mask_dtype:
attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
else:
attn_mask = paddle.cast(attn_mask, dtype)
return attn_mask
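

# A quick sketch of the conversion above (values are illustrative): a `bool`
# or integer mask is turned into an additive mask, so padded (False/0)
# positions become -1e9 and valid positions become 0 before the softmax.
#
#     mask = paddle.to_tensor([[True, True, False]])
#     _convert_attention_mask(mask, paddle.float32)
#     # -> [[0.0, 0.0, -1000000000.0]]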
def attention_forward(self, query, key=None, value=None, attn_mask=None, cache=None):
"""
Redefines the `forward` function of `paddle.nn.MultiHeadAttention`.
"""
key = query if key is None else key
value = query if value is None else value
    # Compute q, k and v
if cache is None:
q, k, v = self._prepare_qkv(query, key, value, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
    # Scaled dot-product attention
product = tensor.matmul(x=q, y=k, transpose_y=True)
product /= math.sqrt(self.head_dim)
if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, product.dtype)
product = product + attn_mask
self.attention_matrix = product if self.return_attentions else None
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(weights, self.dropout, training=self.training, mode="upscale_in_train")
out = tensor.matmul(weights, v)
if self.return_qkv:
self.q = q
self.k = k
self.v = v
# Combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# Project to output
out = self.out_proj(out)
outs = [out]
if self.need_weights:
outs.append(weights)
if cache is not None:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
def transformer_encoder_layer_forward(self, src, src_mask=None, cache=None):
"""
Redefines the `forward` function of `paddle.nn.TransformerEncoderLayer`.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
residual = src
if self.normalize_before:
src = self.norm1(src)
    # Add cache to the encoder for usages like UniLM
if cache is None:
src = self.self_attn(src, src, src, src_mask)
else:
src, incremental_cache = self.self_attn(src, src, src, src_mask, cache)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
if hasattr(self.self_attn, "attention_matrix"):
self.attention_matrix = self.self_attn.attention_matrix
if hasattr(self.self_attn, "q"):
self.q = self.self_attn.q
self.k = self.self_attn.k
self.v = self.self_attn.v
return src if cache is None else (src, incremental_cache)
def transformer_encoder_forward(self, src, src_mask=None, cache=None):
"""
Redefines the `forward` function of `paddle.nn.TransformerEncoder`.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
output = src
new_caches = []
self.attentions = []
self.hidden_states = []
for i, mod in enumerate(self.layers):
if self.return_layer_outputs:
self.hidden_states.append(output)
if cache is None:
output = mod(output, src_mask=src_mask)
else:
output, new_cache = mod(output, src_mask=src_mask, cache=cache[i])
new_caches.append(new_cache)
if hasattr(mod, "attention_matrix"):
self.attentions.append(mod.attention_matrix)
if i == self.layer_index and hasattr(mod, "q"):
self.q = mod.q
self.k = mod.k
self.v = mod.v
if self.norm is not None:
output = self.norm(output)
if self.return_layer_outputs:
self.hidden_states.append(output)
return output if cache is None else (output, new_caches)
def minilm_pretraining_forward(self, input_ids, token_type_ids=None, attention_mask=None):
"""
    Replaces the `forward` function when training on multiple GPUs (i.e. when
    the model is wrapped by `paddle.DataParallel`). When training on a single
    GPU, `forward` does not need to be replaced.
    The type of `self` should inherit from the base class of pretrained LMs,
    such as `TinyBertForPretraining`.
    The MiniLM strategy only needs the q, k and v tensors of the transformer
    layers.
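
    Example:
        A sketch, assuming `student` is a `TinyBertForPretraining` instance
        wrapped by `paddle.DataParallel` and patched via
        `to_distill(student, return_qkv=True)`:

        .. code-block:: python

            q, k, v = student(input_ids, token_type_ids, attention_mask)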
"""
assert hasattr(self, self.base_model_prefix), "Student class should inherit from %s" % (self.base_model_class)
model = getattr(self, self.base_model_prefix)
encoder = model.encoder
sequence_output, pooled_output = model(input_ids, token_type_ids, attention_mask)
return encoder.q, encoder.k, encoder.v
def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
"""
    Replaces the `forward` function when training on multiple GPUs.
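
    Example:
        A sketch, assuming `student` is a `TinyBertForPretraining` instance
        wrapped by `paddle.DataParallel` and patched via
        `to_distill(student, return_attentions=True, return_layer_outputs=True)`:

        .. code-block:: python

            attentions, hidden_states = student(input_ids)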
"""
assert hasattr(self, self.base_model_prefix), "Student class should inherit from %s" % (self.base_model_class)
model = getattr(self, self.base_model_prefix)
encoder = model.encoder
sequence_output, pooled_output = model(input_ids, token_type_ids, attention_mask)
for i in range(len(encoder.hidden_states)):
        # For tinybert-4l-312d, tinybert-6l-768d, tinybert-4l-312d-zh and
        # tinybert-6l-768d-zh, a single shared projection layer is used:
        # encoder.hidden_states[i] = self.tinybert.fit_dense(encoder.hidden_states[i])
        # For tinybert-4l-312d-v2 and tinybert-6l-768d-v2, one projection
        # layer per transformer layer is used:
        encoder.hidden_states[i] = self.tinybert.fit_denses[i](encoder.hidden_states[i])
return encoder.attentions, encoder.hidden_states
def bert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
"""
    Replaces the `forward` function when training on multiple GPUs.
"""
assert hasattr(self, self.base_model_prefix), "Student class should inherit from %s" % (self.base_model_class)
model = getattr(self, self.base_model_prefix)
encoder = model.encoder
sequence_output, pooled_output = model(input_ids, token_type_ids, attention_mask)
return encoder.attentions, encoder.hidden_states