to_distill(self, return_qkv=False, return_attentions=False, return_layer_outputs=False, layer_index=-1)
Can be bound to an object with transformer encoder layers, making the model expose the attributes `outputs.hidden_states` and `outputs.attentions` of the object for distillation. It can return the intermediate tensors used in the MiniLM and TinyBERT strategies.
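The binding pattern can be illustrated with a toy model. The sketch below is not PaddleNLP's implementation: `ToyAttentionLayer`, `ToyEncoder` and `to_distill_sketch` are hypothetical names, NumPy stands in for Paddle, and the real options of `to_distill` (`return_qkv`, `layer_index` and so on) are not modeled. It replaces the model's `forward` with a bound method that records every layer's output and attention map on `model.outputs` while still returning the same final result.

```python
import types

import numpy as np


class ToyAttentionLayer:
    """Hypothetical single-head self-attention layer standing in for an encoder layer."""

    def __init__(self, dim, rng):
        self.wq, self.wk, self.wv = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))

    def __call__(self, x):
        q, k, v = x @ self.wq, x @ self.wk, x @ self.wv
        scores = q @ k.swapaxes(-1, -2) / np.sqrt(x.shape[-1])
        attn = np.exp(scores - scores.max(-1, keepdims=True))
        attn /= attn.sum(-1, keepdims=True)  # row-wise softmax
        return attn @ v, attn


class ToyEncoder:
    """Hypothetical encoder whose forward only returns the final hidden state."""

    def __init__(self, dim=8, num_layers=3, seed=0):
        rng = np.random.default_rng(seed)
        self.layers = [ToyAttentionLayer(dim, rng) for _ in range(num_layers)]

    def forward(self, x):
        for layer in self.layers:
            x, _ = layer(x)
        return x


def to_distill_sketch(model):
    """Hypothetical: bind a new forward that also exposes intermediate tensors."""

    def forward(self, x):
        hidden_states, attentions = [x], []
        for layer in self.layers:
            x, attn = layer(x)
            hidden_states.append(x)
            attentions.append(attn)
        # Expose intermediates as model.outputs.hidden_states / model.outputs.attentions
        self.outputs = types.SimpleNamespace(hidden_states=hidden_states, attentions=attentions)
        return x

    model.forward = types.MethodType(forward, model)
    return model


student = to_distill_sketch(ToyEncoder(dim=8, num_layers=3))
x = np.random.default_rng(1).standard_normal((2, 5, 8))  # [batch, seq_len, dim]
student.forward(x)
print(len(student.outputs.hidden_states), student.outputs.attentions[-1].shape)  # 4 (2, 5, 5)
```

Binding with `types.MethodType` keeps the call site (`model.forward(x)`) and its return value unchanged; the extra tensors are only reachable through the new attribute, which is what a distillation loop would read.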
calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0)
Calculates loss for the Q-Q, K-K and V-V relations from MiniLMv2.
loss_fct (callable) – Loss function for distillation. It only supports kl_div loss now.
s (Tensor) – Q, K, V of the student.
t (Tensor) – Q, K, V of the teacher.
attn_mask (Tensor) – Attention mask for relation.
num_relation_heads (int) – The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.
- Returns: MiniLM loss value.
- Return type: Tensor
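A minimal NumPy sketch of one relation term, under stated assumptions: each of Q, K and V is laid out as `[batch, heads, seq_len, head_dim]`, `attn_mask` is an additive mask broadcastable from `[batch, 1, 1, seq_len]` (0 keeps a key, a large negative value hides padding), and `loss_fct` is replaced by a hard-coded KL divergence KL(teacher || student). The names `regroup_heads`, `relation_logits`, `kl_relation` and `minilm_relation_loss` are hypothetical. Regrouping both models into the same number of relation heads is what lets a student with a smaller hidden size be compared with the teacher, since both relation matrices end up as `[batch, r, seq_len, seq_len]`.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def regroup_heads(x, num_relation_heads):
    """Re-split [batch, heads, seq, head_dim] into [batch, r, seq, hidden // r]."""
    b, h, n, d = x.shape
    if num_relation_heads <= 0 or num_relation_heads == h:
        return x  # 0 means "use the model's own head count"
    merged = x.transpose(0, 2, 1, 3).reshape(b, n, h * d)  # [batch, seq, hidden]
    return merged.reshape(b, n, num_relation_heads, -1).transpose(0, 2, 1, 3)


def relation_logits(x):
    """Scaled dot-product relation x @ x^T / sqrt(d) over the last two axes."""
    return x @ x.swapaxes(-1, -2) / np.sqrt(x.shape[-1])


def kl_relation(s_logits, t_logits):
    """KL(teacher || student) over the last axis, averaged over all other axes."""
    log_p_s, log_p_t = log_softmax(s_logits), log_softmax(t_logits)
    return float((np.exp(log_p_t) * (log_p_t - log_p_s)).sum(-1).mean())


def minilm_relation_loss(s, t, attn_mask, num_relation_heads=0):
    """Relation loss for one of Q, K or V (e.g. the Q-Q relation)."""
    s = regroup_heads(s, num_relation_heads)
    t = regroup_heads(t, num_relation_heads)
    return kl_relation(relation_logits(s) + attn_mask, relation_logits(t) + attn_mask)


rng = np.random.default_rng(0)
batch, seq_len = 2, 6
attn_mask = np.zeros((batch, 1, 1, seq_len))
attn_mask[1, ..., 4:] = -1e4  # the last two tokens of sample 1 are padding
loss = 0.0
for _ in range(3):  # random stand-ins for Q, K and V
    t_x = rng.standard_normal((batch, 4, seq_len, 8))  # teacher: 4 heads x 8 = hidden 32
    s_x = rng.standard_normal((batch, 2, seq_len, 8))  # student: 2 heads x 8 = hidden 16
    loss += minilm_relation_loss(s_x, t_x, attn_mask, num_relation_heads=4)
print(loss)
```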
calc_multi_relation_loss(loss_fct, s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0)
Calculates loss for multiple Q-Q, K-K and V-V relations. It supports the head-head relation, the sample-sample relation and the original token-token relation. The final loss value can be balanced by the weights alpha and beta.
loss_fct (callable) – Loss function for distillation. It only supports kl_div loss now.
s (Tensor) – Q, K, V of the student.
t (Tensor) – Q, K, V of the teacher.
attn_mask (Tensor) – Attention mask for relation.
num_relation_heads (int) – The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.
alpha (float) – The weight for head-head relation. Defaults to 0.0.
beta (float) – The weight for sample-sample relation. Defaults to 0.0.
- Returns: Weighted combination of the token-token loss, head-head loss and sample-sample loss.
- Return type: Tensor
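The three relation types differ only in which axis is treated as the "sequence" being related. The sketch below is hypothetical NumPy, not PaddleNLP code: it assumes the layout `[batch, heads, seq_len, head_dim]`, an additive mask, and a KL divergence in place of `loss_fct`, and it permutes the axes before computing each scaled dot-product relation. Two further assumptions are labeled in the code: the mask is applied only to the token-token term, and the token-token term receives the weight `1 - alpha - beta` left over by `alpha` and `beta`; the library may combine the terms differently.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def regroup_heads(x, num_relation_heads):
    """Re-split [batch, heads, seq, head_dim] into [batch, r, seq, hidden // r]."""
    b, h, n, d = x.shape
    if num_relation_heads <= 0 or num_relation_heads == h:
        return x
    merged = x.transpose(0, 2, 1, 3).reshape(b, n, h * d)
    return merged.reshape(b, n, num_relation_heads, -1).transpose(0, 2, 1, 3)


def kl_relation(s, t, mask=0.0):
    """KL(teacher || student) between scaled dot-product relations over the last two axes."""
    s_logits = s @ s.swapaxes(-1, -2) / np.sqrt(s.shape[-1]) + mask
    t_logits = t @ t.swapaxes(-1, -2) / np.sqrt(t.shape[-1]) + mask
    log_p_s, log_p_t = log_softmax(s_logits), log_softmax(t_logits)
    return float((np.exp(log_p_t) * (log_p_t - log_p_s)).sum(-1).mean())


def multi_relation_loss(s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0):
    s = regroup_heads(s, num_relation_heads)
    t = regroup_heads(t, num_relation_heads)
    # token-token: [batch, r, seq, d] -> relation between tokens, [batch, r, seq, seq]
    token = kl_relation(s, t, attn_mask)
    # head-head: [batch, seq, r, d] -> relation between heads, [batch, seq, r, r]
    head = kl_relation(s.transpose(0, 2, 1, 3), t.transpose(0, 2, 1, 3))
    # sample-sample: [r, seq, batch, d] -> relation between samples, [r, seq, batch, batch]
    sample = kl_relation(s.transpose(1, 2, 0, 3), t.transpose(1, 2, 0, 3))
    # Assumed weighting: token-token gets whatever weight alpha and beta leave over
    return (1.0 - alpha - beta) * token + alpha * head + beta * sample


rng = np.random.default_rng(0)
attn_mask = np.zeros((3, 1, 1, 5))  # [batch, 1, 1, seq_len], no padding
t_x = rng.standard_normal((3, 4, 5, 8))  # teacher: 4 heads x 8
s_x = rng.standard_normal((3, 2, 5, 8))  # student: 2 heads x 8
print(multi_relation_loss(s_x, t_x, attn_mask, num_relation_heads=4, alpha=0.2, beta=0.1))
```

With `alpha = beta = 0` the sketch reduces to the plain token-token relation loss, which matches the defaults in the signature above.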