distill_utils

to_distill(self, return_qkv=False, return_attentions=False, return_layer_outputs=False, layer_index=-1)

Can be bound to an object that contains Transformer encoder layers, so that the model exposes the attributes outputs.q, outputs.k, outputs.v, outputs.scaled_qks, outputs.hidden_states and outputs.attentions for distillation. This makes it possible to return the intermediate tensors used by the MiniLM and TinyBERT distillation strategies.
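The binding idea can be illustrated with a minimal pure-Python sketch. Everything in it is hypothetical: TinyEncoder, AddOne and to_distill_sketch are stand-ins invented for illustration, not PaddleNLP classes, and the sketch records only hidden states rather than q, k, v or attention scores. It assumes that layer_index counts encoder layers, with -1 selecting the last one.

```python
import types


class AddOne:
    """Hypothetical encoder layer: adds 1 to every element of its input."""

    def __call__(self, x):
        return [v + 1 for v in x]


class TinyEncoder:
    """Hypothetical model holding a list of encoder layers."""

    def __init__(self, num_layers):
        self.layers = [AddOne() for _ in range(num_layers)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def to_distill_sketch(self, return_layer_outputs=False, layer_index=-1):
    """Hypothetical: bind a new forward that exposes intermediate outputs."""

    def forward(model, x):
        hidden_states = [x]  # the input, then one entry per encoder layer
        for layer in model.layers:
            x = layer(x)
            hidden_states.append(x)
        # Expose intermediate tensors as attributes, as to_distill does.
        model.outputs = types.SimpleNamespace(hidden_states=hidden_states)
        if return_layer_outputs:
            # Only encoder-layer outputs are indexed, so -1 is the last layer.
            model.outputs.layer_output = hidden_states[1:][layer_index]
        return x

    # Bind the function to this object so it is called like a method.
    self.forward = types.MethodType(forward, self)
    return self


model = to_distill_sketch(TinyEncoder(3), return_layer_outputs=True, layer_index=1)
print(model.forward([0, 0]))        # final output
print(model.outputs.layer_output)   # output of encoder layer 1
```

Rebinding forward on the instance leaves the class untouched, so other models built from the same class keep their original behaviour.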

calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0)

Calculates the loss for the Q-Q, K-K and V-V relations from MiniLMv2.

参数
  • loss_fct (callable) -- Loss function for distillation. Currently only kl_div loss is supported.

  • s (Tensor) -- Q, K, V of the student.

  • t (Tensor) -- Q, K, V of the teacher.

  • attn_mask (Tensor) -- Attention mask for the relation.

  • num_relation_heads (int) -- The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.

Returns

MiniLM loss value.

Return type

Tensor
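The relation that calc_minilm_loss compares can be sketched in plain Python for a single sample. This is an illustrative approximation, not the library's implementation: minilm_relation_loss, relation and split_heads are hypothetical helpers, the inputs are per-token vectors with all heads concatenated, num_relation_heads defaults to 1 here instead of 0, and KL(teacher || student) is averaged over heads and unpadded rows. The additive mask (0.0 for real tokens, a large negative value for padding) is an assumption about how attn_mask is encoded.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def relation(vectors, mask):
    """Row-wise softmax of scaled dot products (a Q-Q, K-K or V-V relation).

    `mask` is additive: 0.0 keeps a position, a large negative value drops it.
    """
    d = len(vectors[0])
    return [softmax([sum(x * y for x, y in zip(a, b)) / math.sqrt(d) + m
                     for b, m in zip(vectors, mask)])
            for a in vectors]


def split_heads(token_vectors, num_relation_heads):
    """Hypothetical helper: regroup concatenated features into relation heads."""
    size = len(token_vectors[0]) // num_relation_heads
    return [[vec[h * size:(h + 1) * size] for vec in token_vectors]
            for h in range(num_relation_heads)]


def minilm_relation_loss(s, t, mask, num_relation_heads=1):
    """Mean KL(teacher || student) over relation heads and unpadded rows."""
    total, count = 0.0, 0
    for s_h, t_h in zip(split_heads(s, num_relation_heads),
                        split_heads(t, num_relation_heads)):
        for p_s, p_t, m in zip(relation(s_h, mask), relation(t_h, mask), mask):
            if m < 0:  # skip query rows at padded positions
                continue
            total += sum(pt * math.log(pt / ps)
                         for ps, pt in zip(p_s, p_t) if pt > 0)
            count += 1
    return total / count


student = [[1.0, 0.0], [0.0, 1.0]]
teacher = [[1.0, 0.0], [0.0, 2.0]]
print(minilm_relation_loss(student, teacher, [0.0, 0.0]))
```

Because each relation is an N x N matrix over tokens, the student and teacher hidden sizes may differ as long as each splits evenly into num_relation_heads parts.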

calc_multi_relation_loss(loss_fct, s, t, attn_mask, num_relation_heads=0, alpha=0.0, beta=0.0)

Calculates the loss for multiple Q-Q, K-K and V-V relations. It supports head-head relations, sample-sample relations and the original token-token relations. The final loss value is balanced by the weights alpha and beta.

Parameters
  • loss_fct (callable) -- Loss function for distillation. Currently only kl_div loss is supported.

  • s (Tensor) -- Q, K, V of the student.

  • t (Tensor) -- Q, K, V of the teacher.

  • attn_mask (Tensor) -- Attention mask for the relation.

  • num_relation_heads (int) -- The number of relation heads. 0 means num_relation_heads equals the original number of heads. Defaults to 0.

  • alpha (float) -- The weight for head-head relation. Defaults to 0.0.

  • beta (float) -- The weight for sample-sample relation. Defaults to 0.0.

Returns

Weighted combination of the token-token, head-head and sample-sample losses.

Return type

Tensor
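The three relation types can be sketched by grouping the same vectors along different axes: tokens within one head (token-token), heads at one token position (head-head) and samples at one head and token (sample-sample). This is a hedged, mask-free sketch on nested lists shaped [batch][heads][seq][dim], and it needs the student and teacher to have the same number of heads. multi_relation_loss and relation_kl are hypothetical names, and the weighting (1 - alpha - beta) * token + alpha * head + beta * sample is an assumption that matches the defaults alpha=0.0 and beta=0.0 reducing to the plain token-token loss, not a confirmed formula.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def relation_kl(s_vectors, t_vectors):
    """Mean row-wise KL(teacher || student) between scaled-dot-product relations."""

    def rel(vectors):
        scale = math.sqrt(len(vectors[0]))
        return [softmax([sum(x * y for x, y in zip(a, b)) / scale for b in vectors])
                for a in vectors]

    rows = list(zip(rel(s_vectors), rel(t_vectors)))
    return sum(sum(pt * math.log(pt / ps) for ps, pt in zip(r_s, r_t) if pt > 0)
               for r_s, r_t in rows) / len(rows)


def multi_relation_loss(s, t, alpha=0.0, beta=0.0):
    """Hypothetical: s, t are nested lists [batch][heads][seq][dim], no padding."""
    batch, heads, seq = len(s), len(s[0]), len(s[0][0])

    def mean(values):
        values = list(values)
        return sum(values) / len(values)

    # Token-token: relate the tokens of one head in one sample.
    token = mean(relation_kl(s[b][h], t[b][h])
                 for b in range(batch) for h in range(heads))
    # Head-head: relate the heads at one token position of one sample.
    head = mean(relation_kl([s[b][h][i] for h in range(heads)],
                            [t[b][h][i] for h in range(heads)])
                for b in range(batch) for i in range(seq))
    # Sample-sample: relate the samples of the batch at one head and token.
    sample = mean(relation_kl([s[b][h][i] for b in range(batch)],
                              [t[b][h][i] for b in range(batch)])
                  for h in range(heads) for i in range(seq))
    # Assumed weighting: token-token keeps whatever weight alpha and beta leave.
    return (1 - alpha - beta) * token + alpha * head + beta * sample


s = [[[[1.0, 0.0], [0.0, 1.0]]]]  # batch=1, heads=1, seq=2, dim=2
t = [[[[1.0, 0.0], [0.0, 2.0]]]]
print(multi_relation_loss(s, t, alpha=0.3, beta=0.2))
```

With a single head and a single sample, the head-head and sample-sample relations are 1 x 1 matrices and contribute nothing, so the result is just the scaled token-token term.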