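# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.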
import math
from typing import Optional, Tuple

from .. import PretrainedModel, register_base_model
from ..activations import ACT2FN
from .configuration import (
PROPHETNET_PRETRAINED_INIT_CONFIGURATION,
PROPHETNET_PRETRAINED_RESOURCE_FILES_MAP,
ProphetNetConfig,
)

# Names exported by `from <this module> import *`; this is the module's declared public API.
__all__ = [
"ProphetNetModel",
"ProphetNetPretrainedModel",
"ProphetNetEncoder",
"ProphetNetDecoder",
"ProphetNetForConditionalGeneration",
]

def ngram_attention_bias(sequence_length, ngram, dtype):
    """
    This function computes the bias for the predict stream.

    Args:
        sequence_length (int): Side length of each square bias block.
        ngram (int): Number of predict streams; one bias slice is built per stream.
        dtype: dtype forwarded to `paddle.ones` when allocating the blocks.

    Returns:
        Tensor of shape `(ngram, sequence_length, 2 * sequence_length)` holding the
        left and right blocks concatenated along the last axis. Entries are `0`
        for unmasked positions and `-inf` for masked ones.
    """
    # Start with everything masked (-inf) in both halves.
    left_block = paddle.ones((ngram, sequence_length, sequence_length), dtype=dtype) * float("-inf")
    right_block = left_block.detach().clone()
    # create bias
    for stream_idx in range(ngram):
        # Right half: only the main diagonal is unmasked.
        right_block[stream_idx] = right_block[stream_idx].fill_diagonal_(0, wrap=False)
        # Left half: `triu` zeroes (unmasks) everything below diagonal
        # `-stream_idx + 1`, so later streams see a shorter prefix.
        left_block[stream_idx] = paddle.triu(left_block[stream_idx], diagonal=-stream_idx + 1)

    # The first column is always unmasked.
    left_block[:, :, 0] = 0
    # Previously the function fell off the end and returned None, discarding
    # both blocks; return them joined along the last axis.
    return paddle.concat([left_block, right_block], axis=2)

def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
    """
    This function computes individual parts of the relative position buckets. For more detail, see paper.

    Args:
        num_buckets (int): Total number of buckets available.
        max_distance (int): Distance at and beyond which positions share the last bucket.
        relative_positions (Tensor): Integer tensor of relative positions.
        is_bidirectional (bool): When True, half of the buckets are reserved for
            each sign of the distance; when False, negative (inverted) distances
            are clamped to zero.

    Returns:
        Tensor with the same shape as `relative_positions` containing int32 bucket ids.
    """
    inv_relative_positions = -relative_positions
    rel_positions_bucket = 0
    zeros = paddle.zeros_like(inv_relative_positions)

    if is_bidirectional:
        num_buckets = num_buckets // 2
        # Offset by `num_buckets` wherever the inverted distance is negative,
        # so each direction gets its own half of the bucket range.
        rel_positions_bucket = (
            rel_positions_bucket
            + paddle.cast(
                paddle.less_than(inv_relative_positions, zeros),
                dtype=paddle.int32,
            )
            * num_buckets
        )
        inv_relative_positions = paddle.abs(inv_relative_positions)
    else:
        # Equivalent to max(inv_relative_positions, 0): negatives become zero.
        inv_relative_positions = (
            paddle.cast(
                paddle.less_than(zeros, inv_relative_positions),
                dtype=inv_relative_positions.dtype,
            )
            * inv_relative_positions
        )

    # The first half of the buckets maps one-to-one onto small distances.
    max_exact = num_buckets // 2
    is_small = paddle.less_than(
        inv_relative_positions, paddle.full_like(inv_relative_positions, max_exact)
    )
    # The remaining buckets cover distances up to `max_distance` on a log scale.
    val_if_large = max_exact + paddle.log(
        paddle.cast(inv_relative_positions, dtype=paddle.float32) / max_exact
    ) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    # Cap at the last bucket: keep the value where it is below the cap,
    # otherwise substitute `num_buckets - 1`.
    val_if_large_num_buckets = paddle.ones_like(val_if_large) * (num_buckets - 1)
    val_if_large_lt = paddle.cast(paddle.less_than(val_if_large, val_if_large_num_buckets), dtype=paddle.int32)
    val_if_large = (
        paddle.cast(val_if_large_lt * val_if_large, dtype=paddle.int32)
        + (1 - val_if_large_lt) * val_if_large_num_buckets
    )
    rel_positions_bucket = rel_positions_bucket + paddle.where(
        is_small, paddle.cast(inv_relative_positions, dtype=paddle.int32), val_if_large
    )
    return rel_positions_bucket

def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
    """
    This function computes both main and predict relative position buckets. For more detail, see paper.

    Args:
        num_buckets (int): Number of buckets, forwarded to `compute_relative_buckets`.
        max_distance (int): Maximum distance, forwarded to `compute_relative_buckets`.
        position_ids (Tensor): Position ids; the last axis is the sequence axis.
            NOTE(review): the indexing below assumes a 2-D `(batch, seq)` tensor — confirm against callers.

    Returns:
        Tuple `(main_relative_position_buckets, predict_relative_position_buckets)`.
    """
    seq_len = position_ids.shape[-1]
    # Column vector of positions, subtracted from each row to get pairwise offsets.
    query_positions = paddle.unsqueeze(position_ids, axis=-1)

    # main stream
    main_stream_relative_positions = paddle.tile(
        paddle.unsqueeze(position_ids, axis=1), repeat_times=[1, seq_len, 1]
    )
    main_stream_relative_positions = main_stream_relative_positions - query_positions

    # predicting stream
    # Keys are the positions shifted back by one followed by the positions
    # themselves, so the last axis is twice as long as in the main stream.
    predicting_stream_relative_positions = paddle.unsqueeze(
        paddle.concat([position_ids - 1, position_ids], axis=-1), axis=1
    )
    predicting_stream_relative_positions = paddle.tile(
        predicting_stream_relative_positions, repeat_times=[1, seq_len, 1]
    )
    predicting_stream_relative_positions = predicting_stream_relative_positions - query_positions

    # get both position buckets
    main_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
    )
    predict_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
    )
    return main_relative_position_buckets, predict_relative_position_buckets