distill_utils
- to_distill(self, return_qkv=False, return_attentions=False, return_layer_outputs=False, layer_index=-1)[source]
Can be bound to an object with transformer encoder layers, making the object expose the attributes `outputs.q`, `outputs.k`, `outputs.v`, `outputs.scaled_qks`, `outputs.hidden_states` and `outputs.attentions` for distillation. These intermediate tensors are the ones used by the MiniLM and TinyBERT distillation strategies (see the sketch below).
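For example, a minimal sketch of binding the hook to a teacher and a student encoder. The `BertModel` class, checkpoint name, layer index and dummy input below are illustrative assumptions; only `to_distill` and the exposed `outputs.*` attributes come from this page:

```python
import paddle
from paddlenlp.transformers import BertModel
from paddlenlp.transformers.distill_utils import to_distill

# Illustrative encoders; any model built from transformer encoder layers should work.
teacher = BertModel.from_pretrained("bert-base-uncased")
student = BertModel.from_pretrained("bert-base-uncased")

# Bind the hook so a forward pass records Q, K, V of the chosen layer (-1 = last layer).
teacher = to_distill(teacher, return_qkv=True, layer_index=-1)
student = to_distill(student, return_qkv=True, layer_index=-1)

input_ids = paddle.randint(low=1, high=1000, shape=[2, 16])
with paddle.no_grad():
    teacher(input_ids)
student(input_ids)

# Intermediate tensors exposed after the forward pass.
q_t, k_t, v_t = teacher.outputs.q, teacher.outputs.k, teacher.outputs.v
q_s, k_s, v_s = student.outputs.q, student.outputs.k, student.outputs.v
```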
- calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0)[source]
Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2 (a usage sketch follows this entry).
- Parameters:
loss_fct (callable) – Loss function for distillation. Only kl_div loss is supported now.
s (Tensor) – Q, K or V of the student.
t (Tensor) – Q, K or V of the teacher.
attn_mask (Tensor) – Attention mask for the relation.
num_relation_heads (int) – The number of relation heads. 0 means num_relation_heads equals the original number of attention heads. Defaults to 0.
- Returns:
MiniLM loss value.
- Return type:
Tensor
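Continuing the sketch above, the per-relation MiniLMv2 losses could be computed roughly as follows. The `KLDivLoss` reduction, the additive padding-mask construction and the `num_relation_heads` value are assumptions, not part of this API reference:

```python
import paddle
import paddle.nn as nn
from paddlenlp.transformers.distill_utils import calc_minilm_loss

# q_s/q_t, k_s/k_t, v_s/v_t and input_ids come from the to_distill sketch above.
kl_loss_fct = nn.KLDivLoss(reduction="sum")  # only kl_div loss is supported

# Additive attention mask over padding positions (shape and scale are assumptions).
pad_token_id = 0
attn_mask = paddle.unsqueeze(
    (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * -1e4, axis=[1, 2])

# One MiniLMv2 relation loss per Q-Q, K-K and V-V pair; 48 relation heads is illustrative.
loss_q = calc_minilm_loss(kl_loss_fct, q_s, q_t, attn_mask, num_relation_heads=48)
loss_k = calc_minilm_loss(kl_loss_fct, k_s, k_t, attn_mask, num_relation_heads=48)
loss_v = calc_minilm_loss(kl_loss_fct, v_s, v_t, attn_mask, num_relation_heads=48)
loss = loss_q + loss_k + loss_v
```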
- calc_multi_relation_loss(loss_fct, s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0)[source]
Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports the head-head relation, the sample-sample relation and the original token-token relation. The final loss value can be balanced by the weights alpha and beta (a usage sketch follows this entry).
- Parameters:
loss_fct (callable) – Loss function for distillation. Only kl_div loss is supported now.
s (Tensor) – Q, K or V of the student.
t (Tensor) – Q, K or V of the teacher.
attn_mask (Tensor) – Attention mask for the relation.
num_relation_heads (int) – The number of relation heads. 0 means num_relation_heads equals the original number of attention heads. Defaults to 0.
alpha (float) – The weight for the head-head relation. Defaults to 0.0.
beta (float) – The weight for the sample-sample relation. Defaults to 0.0.
- Returns:
Weighted sum of the token-token, head-head and sample-sample losses.
- Return type:
Tensor
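A similar hedged sketch for the multi-relation variant, reusing the student/teacher tensors and the mask from the sketches above; the `alpha`, `beta` and `num_relation_heads` values are arbitrary illustrations:

```python
import paddle.nn as nn
from paddlenlp.transformers.distill_utils import calc_multi_relation_loss

kl_loss_fct = nn.KLDivLoss(reduction="sum")  # only kl_div loss is supported

# Per this page, alpha weights the head-head relation and beta the sample-sample
# relation; the result also includes the original token-token relation loss.
loss_q = calc_multi_relation_loss(
    kl_loss_fct, q_s, q_t, attn_mask, num_relation_heads=48, alpha=0.3, beta=0.3)
loss_k = calc_multi_relation_loss(
    kl_loss_fct, k_s, k_t, attn_mask, num_relation_heads=48, alpha=0.3, beta=0.3)
loss_v = calc_multi_relation_loss(
    kl_loss_fct, v_s, v_t, attn_mask, num_relation_heads=48, alpha=0.3, beta=0.3)
loss = loss_q + loss_k + loss_v
```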