```# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import os
import sys
import math
from functools import partial

import numpy as np

__all__ = ['DetectionF1', 'CorrectionF1']

[文档]class DetectionF1(Metric):

def __init__(self, pos_label=1, name='DetectionF1', *args, **kwargs):
super(DetectionF1, self).__init__(*args, **kwargs)
self.pos_label = pos_label
self._name = name
self.reset()

[文档]    def update(self, preds, labels, length, *args):
# [B, T, 2]
pred_labels = preds.argmax(axis=-1)
for i, label_length in enumerate(length):
pred_label = pred_labels[i][1:1 + label_length]
label = labels[i][1:1 + label_length]
# the sequence has errors
if (label == self.pos_label).any():
if (pred_label == label).all():
self.tp += 1
else:
self.fn += 1
else:
if (label != pred_label).any():
self.fp += 1

[文档]    def reset(self):
"""
Resets all of the metric state.
"""
self.tp = 0
self.fp = 0
self.fn = 0

[文档]    def accumulate(self):
precision = np.nan
if self.tp + self.fp > 0:
precision = self.tp / (self.tp + self.fp)
recall = np.nan
if self.tp + self.fn > 0:
recall = self.tp / (self.tp + self.fn)
if self.tp == 0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
return f1, precision, recall

[文档]    def name(self):
"""
Returns name of the metric instance.

Returns:
str: The name of the metric instance.

"""
return self._name

[文档]class CorrectionF1(DetectionF1):

def __init__(self, pos_label=1, name='CorrectionF1', *args, **kwargs):
super(CorrectionF1, self).__init__(pos_label, name, *args, **kwargs)

[文档]    def update(self, det_preds, det_labels, corr_preds, corr_labels, length,
*args):
# [B, T, 2]
det_preds_labels = det_preds.argmax(axis=-1)
corr_preds_labels = corr_preds.argmax(axis=-1)

for i, label_length in enumerate(length):
# Ignore [CLS] token, so calculate from position 1.
det_preds_label = det_preds_labels[i][1:1 + label_length]
det_label = det_labels[i][1:1 + label_length]
corr_preds_label = corr_preds_labels[i][1:1 + label_length]
corr_label = corr_labels[i][1:1 + label_length]

# The sequence has any errors.
if (det_label == self.pos_label).any():
corr_pred_label = corr_preds_label * det_preds_label
corr_label = det_label * corr_label
if (corr_pred_label == corr_label).all():
self.tp += 1
else:
self.fn += 1
else:
if (det_label != det_preds_label).any():
self.fp += 1
```