# paddlenlp.metrics.sighan — source code

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from paddle.metric import Metric

# Public names exported by `from <this module> import *`.
__all__ = ["DetectionF1", "CorrectionF1"]


class DetectionF1(Metric):
    """Sentence-level error-detection F1 metric for Chinese spelling check.

    For every sequence in a batch, the first token (e.g. ``[CLS]``) is skipped
    and the following ``length[i]`` positions are compared:

    * gold labels contain ``pos_label`` (erroneous sequence) and the predicted
      labels match the gold labels at every position -> true positive;
    * erroneous sequence with any mismatch -> false negative;
    * error-free sequence whose predicted labels differ anywhere -> false
      positive.

    Args:
        pos_label (int, optional): Label value marking an erroneous token.
            Defaults to 1.
        name (str, optional): Name of the metric instance.
            Defaults to "DetectionF1".
    """

    def __init__(self, pos_label=1, name="DetectionF1", *args, **kwargs):
        super(DetectionF1, self).__init__(*args, **kwargs)
        self.pos_label = pos_label
        self._name = name
        self.reset()

    def update(self, preds, labels, length, *args):
        """Accumulate TP / FP / FN counts for one batch.

        Args:
            preds: Detection scores; ``argmax`` is taken over the last axis.
                Presumably shaped [B, T, 2] — TODO confirm against caller.
            labels: Gold detection labels, indexed as ``labels[i][t]``.
            length: Per-sequence number of valid tokens after the leading
                token. Each entry may be a Python int, a NumPy scalar, or a
                single-element array/tensor (e.g. a [B, 1] shaped ``length``).
        """
        pred_labels = preds.argmax(axis=-1)
        for i, label_length in enumerate(length):
            # Reduce each entry to a plain int so it is a valid slice bound
            # even when `length` is [B, 1] shaped or holds 1-element tensors.
            label_length = int(np.asarray(label_length).item())
            # Ignore the leading token (e.g. [CLS]), so start at position 1.
            pred_label = pred_labels[i][1 : 1 + label_length]
            label = labels[i][1 : 1 + label_length]
            if (label == self.pos_label).any():
                # The sequence has errors: it counts only if fully matched.
                if (pred_label == label).all():
                    self.tp += 1
                else:
                    self.fn += 1
            elif (label != pred_label).any():
                # Error-free sequence, but the prediction differs somewhere.
                self.fp += 1

    def reset(self):
        """
        Resets all of the metric state.
        """
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def accumulate(self):
        """Compute F1, precision and recall from the accumulated counts.

        Returns:
            tuple: ``(f1, precision, recall)``. ``precision`` / ``recall`` are
            ``np.nan`` when their denominator is zero; ``f1`` is ``0.0`` when
            there are no true positives.
        """
        precision = np.nan
        if self.tp + self.fp > 0:
            precision = self.tp / (self.tp + self.fp)
        recall = np.nan
        if self.tp + self.fn > 0:
            recall = self.tp / (self.tp + self.fn)
        if self.tp == 0:
            f1 = 0.0
        else:
            # tp > 0 guarantees both precision and recall are finite and > 0.
            f1 = 2 * precision * recall / (precision + recall)
        return f1, precision, recall

    def name(self):
        """
        Returns name of the metric instance.

        Returns:
           str: The name of the metric instance.

        """
        return self._name
class CorrectionF1(DetectionF1):
    """Sentence-level error-correction F1 metric for Chinese spelling check.

    Counting rules per sequence (leading token such as ``[CLS]`` skipped,
    next ``length[i]`` positions compared):

    * gold detection labels contain ``pos_label`` (erroneous sequence): the
      predicted correction ids at positions the model flags as erroneous are
      compared with the gold correction ids at positions the gold labels flag
      as erroneous (all other positions are zeroed). A full match is a true
      positive, otherwise a false negative;
    * error-free sequence whose predicted detection labels differ anywhere
      from the gold labels -> false positive.

    Args:
        pos_label (int, optional): Label value marking an erroneous token.
            Defaults to 1.
        name (str, optional): Name of the metric instance.
            Defaults to "CorrectionF1".
    """

    def __init__(self, pos_label=1, name="CorrectionF1", *args, **kwargs):
        super(CorrectionF1, self).__init__(pos_label, name, *args, **kwargs)

    def update(self, det_preds, det_labels, corr_preds, corr_labels, length, *args):
        """Accumulate TP / FP / FN counts for one batch.

        Args:
            det_preds: Detection scores; ``argmax`` over the last axis.
                Presumably [B, T, 2] — TODO confirm against caller.
            det_labels: Gold detection labels, indexed as ``det_labels[i][t]``.
            corr_preds: Correction scores; ``argmax`` over the last axis gives
                the predicted token ids.
            corr_labels: Gold correction token ids, indexed like ``det_labels``.
            length: Per-sequence number of valid tokens after the leading
                token. Each entry may be a Python int, a NumPy scalar, or a
                single-element array/tensor (e.g. a [B, 1] shaped ``length``).
        """
        det_preds_labels = det_preds.argmax(axis=-1)
        corr_preds_labels = corr_preds.argmax(axis=-1)
        for i, label_length in enumerate(length):
            # Reduce each entry to a plain int so it is a valid slice bound
            # even when `length` is [B, 1] shaped or holds 1-element tensors.
            label_length = int(np.asarray(label_length).item())
            # Ignore the [CLS] token, so calculate from position 1.
            det_preds_label = det_preds_labels[i][1 : 1 + label_length]
            det_label = det_labels[i][1 : 1 + label_length]
            corr_preds_label = corr_preds_labels[i][1 : 1 + label_length]
            corr_label = corr_labels[i][1 : 1 + label_length]
            if (det_label == self.pos_label).any():
                # The sequence has errors. Keep correction ids only where an
                # error is flagged (by the model for predictions, by the gold
                # labels for targets); compare against `pos_label` rather
                # than multiplying so non-default `pos_label` values work.
                corr_pred_label = np.where(det_preds_label == self.pos_label, corr_preds_label, 0)
                corr_label = np.where(det_label == self.pos_label, corr_label, 0)
                if (corr_pred_label == corr_label).all():
                    self.tp += 1
                else:
                    self.fn += 1
            elif (det_label != det_preds_label).any():
                # Error-free sequence, but the detection prediction differs.
                self.fp += 1