distill_utils
- to_distill(self, return_qkv=False, return_attentions=False, return_layer_outputs=False, layer_index=-1)[source]
Can be bound to an object with transformer encoder layers, making the model expose the attributes outputs.q, outputs.k, outputs.v, outputs.scaled_qks, outputs.hidden_states and outputs.attentions for distillation. These intermediate tensors can be used in the MiniLM and TinyBERT strategies, as shown in the sketch below.
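The sketch below shows one plausible way to wire to_distill into a MiniLMv2-style setup. The import path paddlenlp.transformers.distill_utils, the choice of BertModel and the random input are assumptions for illustration only; reading the captured tensors from model.outputs follows the description above.

```python
# A minimal sketch, not the canonical usage: the import path and model choice
# are assumptions; the outputs.q/k/v access follows the description above.
import paddle
from paddlenlp.transformers import BertModel
from paddlenlp.transformers.distill_utils import to_distill

teacher = BertModel.from_pretrained("bert-base-uncased")
student = BertModel.from_pretrained("bert-base-uncased")  # placeholder student

# Patch both models so a forward pass also records Q, K and V of the
# selected encoder layer (layer_index=-1 picks the last layer).
teacher = to_distill(teacher, return_qkv=True, layer_index=-1)
student = to_distill(student, return_qkv=True, layer_index=-1)

# Random token ids just to trigger a forward pass (30522 = BERT-base vocab size).
input_ids = paddle.randint(low=0, high=30522, shape=[2, 16], dtype="int64")
student(input_ids)
teacher(input_ids)

# The intermediate tensors are now exposed on the patched models.
q_s, k_s, v_s = student.outputs.q, student.outputs.k, student.outputs.v
q_t, k_t, v_t = teacher.outputs.q, teacher.outputs.k, teacher.outputs.v
```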
- calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0)[source]
Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2; see the sketch below this entry.
- Parameters:
loss_fct (callable) -- Loss function for distillation. It only supports kl_div loss now.
s (Tensor) -- Q, K, V of Student.
t (Tensor) -- Q, K, V of teacher.
attn_mask (Tensor) -- Attention mask for relation.
num_relation_heads (int) -- The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.
- Returns:
MiniLM loss value.
- Return type:
Tensor
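Below is a small illustrative call on random tensors. The [batch, heads, seq_len, head_dim] layout, the additive zero attention mask, the KLDivLoss(reduction="sum") choice and num_relation_heads=48 are assumptions, not requirements stated on this page.

```python
# A minimal sketch on random tensors; shapes, mask convention and the
# KLDivLoss configuration are assumptions for illustration.
import paddle
from paddlenlp.transformers.distill_utils import calc_minilm_loss

batch_size, num_heads, seq_len, head_dim = 2, 12, 16, 64
kl_div = paddle.nn.KLDivLoss(reduction="sum")  # only kl_div is supported

q_s = paddle.randn([batch_size, num_heads, seq_len, head_dim])  # student Q
q_t = paddle.randn([batch_size, num_heads, seq_len, head_dim])  # teacher Q
attn_mask = paddle.zeros([batch_size, 1, 1, seq_len])           # additive mask, 0 = keep

# Q-Q relation loss; MiniLMv2 adds the analogous K-K and V-V terms.
# num_relation_heads=48 regroups the attention heads into 48 relation heads.
loss_qq = calc_minilm_loss(kl_div, q_s, q_t, attn_mask, num_relation_heads=48)
```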
- calc_multi_relation_loss(loss_fct, s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0)[source]
Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports the head-head relation, the sample-sample relation and the original token-token relation. The final loss value can be balanced by the weights alpha and beta; see the sketch after this entry.
- Parameters:
loss_fct (callable) -- Loss function for distillation. It only supports kl_div loss now.
s (Tensor) -- Q, K, V of Student.
t (Tensor) -- Q, K, V of teacher.
attn_mask (Tensor) -- Attention mask for relation.
num_relation_heads (int) -- The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.
alpha (float) -- The weight for head-head relation. Defaults to 0.0.
beta (float) -- The weight for sample-sample relation. Defaults to 0.0.
- Returns:
Weighted loss of token-token loss, head-head loss and sample-sample loss.
- Return type:
Tensor
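A short sketch of the weighted variant follows; it mirrors the calc_minilm_loss example above, and the alpha/beta values are arbitrary placeholders.

```python
# A minimal sketch; tensor shapes, mask convention and the weights are
# illustrative assumptions only.
import paddle
from paddlenlp.transformers.distill_utils import calc_multi_relation_loss

kl_div = paddle.nn.KLDivLoss(reduction="sum")
q_s = paddle.randn([2, 12, 16, 64])      # student Q: [batch, heads, seq, head_dim]
q_t = paddle.randn([2, 12, 16, 64])      # teacher Q
attn_mask = paddle.zeros([2, 1, 1, 16])  # additive attention mask

# Combine the token-token, head-head and sample-sample relation losses,
# weighted by alpha (head-head) and beta (sample-sample).
loss = calc_multi_relation_loss(
    kl_div, q_s, q_t, attn_mask, num_relation_heads=48, alpha=0.1, beta=0.1)
```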