# Source code for paddlenlp.transformers.fnet.modeling

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling classes for FNet model."""

import paddle
import paddle.nn as nn
from paddle.nn import Layer

from .. import PretrainedModel, register_base_model
from ..activations import ACT2FN
from .configuration import (
    FNET_PRETRAINED_INIT_CONFIGURATION,
    FNET_PRETRAINED_RESOURCE_FILES_MAP,
    FNetConfig,
)

# Public API of this module: the FNet pretrained-model base class, the bare
# FNet encoder, and the task-specific models built on top of it.
__all__ = [
    "FNetPretrainedModel",
    "FNetModel",
    "FNetForSequenceClassification",
    "FNetForPreTraining",
    "FNetForMaskedLM",
    "FNetForNextSentencePrediction",
    "FNetForMultipleChoice",
    "FNetForTokenClassification",
    "FNetForQuestionAnswering",
]


class FNetBasicOutput(Layer):
    """Residual connection followed by layer normalization.

    Adds the sub-layer result back onto the sub-layer input and normalizes
    the sum over the hidden dimension.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(input_tensor + hidden_states)``."""
        residual_sum = input_tensor + hidden_states
        return self.layer_norm(residual_sum)


class FNetOutput(Layer):
    """Feed-forward output block: projection, dropout, residual add and layer norm.

    Projects activations from ``intermediate_size`` back to ``hidden_size``,
    applies dropout, adds the residual ``input_tensor`` and normalizes.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.layer_norm(input_tensor + projected)


class FNetIntermediate(Layer):
    """Feed-forward expansion from ``hidden_size`` to ``intermediate_size`` plus activation.

    ``config.hidden_act`` may be a key into ``ACT2FN`` or a callable that is
    used as-is.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)


class FNetLayer(Layer):
    """One FNet encoder layer: Fourier token mixing followed by a feed-forward block.

    ``forward`` returns a 1-tuple ``(layer_output,)``.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.fourier = FNetFourierTransform(config)
        self.intermediate = FNetIntermediate(config)
        self.output = FNetOutput(config)

    def forward(self, hidden_states):
        mixed = self.fourier(hidden_states)[0]
        expanded = self.intermediate(mixed)
        # The mixing output doubles as the residual input of the feed-forward block.
        return (self.output(expanded, mixed),)


class FNetEncoder(Layer):
    """Stack of ``config.num_hidden_layers`` :class:`FNetLayer` modules applied in sequence."""

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.layers = nn.LayerList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
        # NOTE(review): set here but not read anywhere in this class.
        self.gradient_checkpointing = False

    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
        """Run every layer on ``hidden_states``.

        Args:
            hidden_states (Tensor): Input of the first layer.
            output_hidden_states (bool): If true, collect the input of every layer
                plus the final output (``num_hidden_layers + 1`` tensors in total).
            return_dict (bool): Return a dict instead of a tuple.

        Returns:
            dict or tuple: With ``return_dict``, a dict with keys
            ``"last_hidden_state"`` and ``"all_hidden_states"`` (``None`` unless
            requested). Otherwise ``(last_hidden_state,)``, or
            ``(last_hidden_state, all_hidden_states)`` when
            ``output_hidden_states`` is true.
        """
        all_hidden_states = () if output_hidden_states else None
        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states = layer_module(hidden_states)[0]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict:
            return {"last_hidden_state": hidden_states, "all_hidden_states": all_hidden_states}
        # Keep the collected states in tuple mode too, so callers that forward
        # ``encoder_outputs[1:]`` actually receive them.
        if output_hidden_states:
            return (hidden_states, all_hidden_states)
        return (hidden_states,)


class FNetPooler(Layer):
    """Pool a sequence through the hidden state of its first token.

    The first-token vector is passed through a dense layer and a tanh activation.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))


class FNetEmbeddings(Layer):
    """Sum word, token-type and position embeddings, then normalize, project and apply dropout."""

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        # hidden_size -> hidden_size projection. It is kept because the original FNet
        # code allows the embedding width to differ from the model width.
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Default positions 0 .. max_position_embeddings - 1 with a leading axis of size 1,
        # registered as a buffer so it is serialized with the layer.
        default_positions = paddle.arange(config.max_position_embeddings, dtype="int64").expand((1, -1))
        self.register_buffer("position_ids", default_positions)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed a batch given either ``input_ids`` or precomputed ``inputs_embeds``.

        Missing ``position_ids`` default to ``0 .. seq_length - 1``; missing
        ``token_type_ids`` default to all zeros.
        """
        input_shape = inputs_embeds.shape[:-1] if input_ids is None else input_ids.shape
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = paddle.zeros(input_shape, dtype="int64")
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        embeddings = embeddings + self.position_embeddings(position_ids)
        normalized = self.layer_norm(embeddings)
        return self.dropout(self.projection(normalized))


class FNetBasicFourierTransform(Layer):
    """Parameter-free token mixing: real part of a 2-D DFT over the last two axes.

    For ``(batch, seq_len, hidden_size)`` inputs, this mixes over the sequence
    and hidden axes only. Calling ``paddle.fft.fftn`` without ``axes`` would also
    run a DFT across the batch axis, so every example's output would depend on
    the other examples in the batch. The results only coincide for batch size 1,
    where a length-1 DFT is the identity.
    """

    def __init__(self):
        super().__init__()
        self.fourier_transform = paddle.fft.fftn

    def forward(self, hidden_states):
        # Restrict the transform to the last two axes; leading (batch) axes must stay independent.
        outputs = self.fourier_transform(hidden_states, axes=(-2, -1)).real()
        return (outputs,)


class FNetFourierTransform(Layer):
    """Fourier mixing sub-layer: DFT-based mixing, then residual add and layer norm.

    ``forward`` returns a 1-tuple ``(mixed_output,)``.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.fourier_transform = FNetBasicFourierTransform()
        self.output = FNetBasicOutput(config)

    def forward(self, hidden_states):
        mixed = self.fourier_transform(hidden_states)[0]
        return (self.output(mixed, hidden_states),)


class FNetPredictionHeadTransform(Layer):
    """Dense layer, activation and layer norm applied before the LM decoder.

    ``config.hidden_act`` may be a key into ``ACT2FN`` or a callable used as-is.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.layer_norm(activated)


class FNetLMPredictionHead(Layer):
    """Masked-LM head: transform hidden states, then score every vocabulary token.

    Logits are computed as ``transform(hidden_states) @ decoder.weight^T + bias``.
    """

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.transform = FNetPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        # NOTE: the Linear is deliberately built as (vocab_size -> hidden_size). Paddle's
        # nn.Linear stores its weight as [in_features, out_features], so decoder.weight is
        # [vocab_size, hidden_size], the same layout as an embedding table. forward() only
        # uses this weight transposed and never calls self.decoder(...) directly.
        self.decoder = nn.Linear(config.vocab_size, config.hidden_size)

        # Per-token output bias of shape [vocab_size], initialized to zero.
        self.bias = self.create_parameter(
            [config.vocab_size], is_bias=True, default_initializer=nn.initializer.Constant(value=0)
        )
        # Point the decoder's bias attribute at the [vocab_size] bias so both names share one parameter.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        # [..., hidden_size] x [hidden_size, vocab_size] -> [..., vocab_size]
        hidden_states = paddle.matmul(hidden_states, self.decoder.weight, transpose_y=True) + self.bias
        return hidden_states


class FNetOnlyMLMHead(Layer):
    """Wrapper exposing only the masked-LM prediction head (no next-sentence head)."""

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)


class FNetOnlyNSPHead(Layer):
    """Next-sentence-prediction head: a two-way linear classifier on the pooled output."""

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)


class FNetPreTrainingHeads(Layer):
    """Both pre-training heads: masked-LM scores and two-way next-sentence scores."""

    def __init__(self, config: FNetConfig):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        mlm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return mlm_scores, nsp_scores


class FNetPretrainedModel(PretrainedModel):
    """
    An abstract class for pretrained FNet models. It provides FNet related
    `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
    `pretrained_resource_files_map`, `base_model_prefix` for downloading and
    loading pretrained models.
    See `PretrainedModel` for more details.
    """

    pretrained_init_configuration = FNET_PRETRAINED_INIT_CONFIGURATION
    pretrained_resource_files_map = FNET_PRETRAINED_RESOURCE_FILES_MAP
    base_model_prefix = "fnet"
    config_class = FNetConfig

    def _init_weights(self, layer):
        """Initialize ``layer``'s parameters in place.

        - ``nn.Linear``: weight ~ N(0, config.initializer_range); bias set to 0.
        - ``nn.Embedding``: weight ~ N(0, config.initializer_range); the padding
          row (when ``_padding_idx`` is set) is 0.
        - ``nn.LayerNorm``: weight set to 1, bias set to 0.
        """
        if isinstance(layer, nn.Linear):
            layer.weight.set_value(self._sample_normal(layer.weight.shape))
            if layer.bias is not None:
                layer.bias.set_value(paddle.zeros_like(layer.bias))
        elif isinstance(layer, nn.Embedding):
            weight = self._sample_normal(layer.weight.shape)
            if layer._padding_idx is not None:
                # Zero the padding row on the freshly sampled tensor, then write the whole table
                # once. Calling set_value on layer.weight[idx] targets the result of an indexing
                # op, which is not guaranteed to be a view of the parameter.
                weight[layer._padding_idx] = 0.0
            layer.weight.set_value(weight)
        elif isinstance(layer, nn.LayerNorm):
            layer.bias.set_value(paddle.zeros_like(layer.bias))
            layer.weight.set_value(paddle.ones_like(layer.weight))

    def _sample_normal(self, shape):
        """Draw a tensor of ``shape`` from N(0, config.initializer_range)."""
        return paddle.tensor.normal(mean=0.0, std=self.config.initializer_range, shape=shape)
@register_base_model
class FNetModel(FNetPretrainedModel):
    """
    The model can behave as an encoder, following the architecture described in
    `FNet: Mixing Tokens with Fourier Transforms <https://arxiv.org/abs/2105.03824>`__
    by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
    """

    def __init__(self, config: FNetConfig):
        super(FNetModel, self).__init__(config)
        self.initializer_range = config.initializer_range
        self.num_hidden_layers = config.num_hidden_layers
        self.embeddings = FNetEmbeddings(config)
        self.encoder = FNetEncoder(config)
        # Optional pooler; when disabled, the pooled output is None.
        self.pooler = FNetPooler(config) if config.add_pooling_layer else None

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        The FNetModel forward method.

        Args:
            input_ids (Tensor):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
            token_type_ids (Tensor, optional):
                Segment token indices to indicate different portions of the inputs.
                Selected in the range ``[0, type_vocab_size - 1]``.
                If `type_vocab_size` is 2, the inputs have two portions.
                Indices can either be 0 or 1:

                - 0 corresponds to a *sentence A* token,
                - 1 corresponds to a *sentence B* token.

                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to `None`, in which case an all-zero tensor is used (every token is
                treated as segment 0).
            position_ids (Tensor, optional):
                Indices of positions of each input sequence token in the position embeddings.
                Selected in the range ``[0, max_position_embeddings - 1]``.
                Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`,
                in which case positions ``0 .. sequence_length - 1`` are used.
            inputs_embeds (Tensor, optional):
                If you want to control how to convert `input_ids` indices into associated vectors,
                you can pass an embedded representation directly instead of passing `input_ids`.
                Exactly one of `input_ids` and `inputs_embeds` must be given.
            output_hidden_states (bool, optional):
                Whether or not to return all hidden states. Defaults to `None`.
            return_dict (bool, optional):
                Whether or not to return a dict instead of a plain tuple. Defaults to `None`.

        Returns:
            tuple or Dict: Returns tuple (`sequence_output`, `pooled_output`) followed by any extra
            encoder outputs (`encoder_outputs[1:]`), or a dict with `last_hidden_state`,
            `pooler_output` and `all_hidden_states` fields.

            With the fields:

            - `sequence_output` (Tensor):
                Sequence of hidden-states at the last layer of the model.
                Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size].

            - `pooled_output` (Tensor):
                The hidden state of the first token (`[CLS]`) passed through the pooler's dense layer
                and tanh activation. Its data type should be float32 and has a shape of
                [batch_size, hidden_size]. Stored under the dict key `pooler_output`; `None` when the
                model was built with `config.add_pooling_layer` disabled.

            - `last_hidden_state` (Tensor):
                The output of the last encoder layer, identical to `sequence_output`.

            - `all_hidden_states` (tuple(Tensor)):
                Hidden states of all layers in the encoder. The length of `all_hidden_states` is
                `num_hidden_layers + 1`. Every element has data type float32 and shape
                [batch_size, sequence_length, hidden_size]. `None` unless `output_hidden_states` is set.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are given.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers.fnet.modeling import FNetModel
                from paddlenlp.transformers.fnet.tokenizer import FNetTokenizer

                tokenizer = FNetTokenizer.from_pretrained('fnet-base')
                model = FNetModel.from_pretrained('fnet-base')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                output = model(**inputs)
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is None:
            token_type_ids = paddle.zeros(shape=input_shape, dtype="int64")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs["last_hidden_state"] if return_dict else encoder_outputs[0]
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
        if return_dict:
            return {
                "last_hidden_state": sequence_output,
                "pooler_output": pooler_output,
                "all_hidden_states": encoder_outputs["all_hidden_states"],
            }
        return (sequence_output, pooler_output) + encoder_outputs[1:]
class FNetForSequenceClassification(FNetPretrainedModel):
    """
    FNet Model with a linear layer on top of the pooled output,
    designed for sequence classification/regression tasks like GLUE tasks.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
    """

    def __init__(self, config: FNetConfig, num_classes=2):
        super(FNetForSequenceClassification, self).__init__(config)
        self.num_classes = num_classes
        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        The FNetForSequenceClassification forward method.

        Args:
            input_ids (Tensor):
                See :class:`FNetModel`.
            token_type_ids (Tensor, optional):
                See :class:`FNetModel`.
            position_ids (Tensor, optional):
                See :class:`FNetModel`.
            inputs_embeds (Tensor, optional):
                See :class:`FNetModel`.
            labels (Tensor, optional):
                Accepted for API compatibility but not used; no loss is computed.
            output_hidden_states (bool, optional):
                See :class:`FNetModel`.
            return_dict (bool, optional):
                See :class:`FNetModel`.

        Returns:
            Tensor or Dict: Returns tensor `logits`, or a dict with `logits` and `hidden_states` fields.

            With the fields:

            - `logits` (Tensor):
                Classification logits computed from the dropped-out pooled output.
                Shape as `[batch_size, num_classes]` and dtype as float32.

            - `hidden_states` (tuple(Tensor)):
                Hidden states of all encoder layers (`num_hidden_layers + 1` tensors of shape
                [batch_size, sequence_length, hidden_size]); `None` unless `output_hidden_states` is set.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers.fnet.modeling import FNetForSequenceClassification
                from paddlenlp.transformers.fnet.tokenizer import FNetTokenizer

                tokenizer = FNetTokenizer.from_pretrained('fnet-base')
                model = FNetForSequenceClassification.from_pretrained('fnet-base')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs["pooler_output"] if return_dict else outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if return_dict:
            return {
                "logits": logits,
                "hidden_states": outputs["all_hidden_states"],
            }
        return logits
class FNetForPreTraining(FNetPretrainedModel):
    """
    FNet Model with two heads on top as done during the pretraining: a
    `masked language modeling` head and a `next sentence prediction (classification)` head.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
    """

    def __init__(self, config: FNetConfig):
        super().__init__(config)
        self.fnet = FNetModel(config)
        self.cls = FNetPreTrainingHeads(config)

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.fnet.embeddings.word_embeddings

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        The FNetForPreTraining forward method.

        Args:
            input_ids (Tensor):
                See :class:`FNetModel`.
            token_type_ids (Tensor, optional):
                See :class:`FNetModel`.
            position_ids (Tensor, optional):
                See :class:`FNetModel`.
            inputs_embeds (Tensor, optional):
                See :class:`FNetModel`.
            labels (Tensor, optional):
                Masked-LM labels. Accepted for API compatibility but not used; no loss is computed.
            next_sentence_label (Tensor, optional):
                Next-sentence labels. Accepted for API compatibility but not used; no loss is computed.
            output_hidden_states (bool, optional):
                See :class:`FNetModel`.
            return_dict (bool, optional):
                See :class:`FNetModel`.

        Returns:
            tuple or Dict: Returns tuple (`prediction_scores`, `seq_relationship_score`, `hidden_states`)
            or a dict with `prediction_logits`, `seq_relationship_logits`, `hidden_states` fields.
            `hidden_states` is `None` unless `output_hidden_states` is set.
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            sequence_output = outputs["last_hidden_state"]
            pooled_output = outputs["pooler_output"]
            all_hidden_states = outputs["all_hidden_states"]
        else:
            # Tuple outputs cannot be indexed by key; hidden states (if any) follow the first two entries.
            sequence_output, pooled_output = outputs[0], outputs[1]
            all_hidden_states = outputs[2] if len(outputs) > 2 else None

        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
        if return_dict:
            return {
                "prediction_logits": prediction_scores,
                "seq_relationship_logits": seq_relationship_score,
                "hidden_states": all_hidden_states,
            }
        return prediction_scores, seq_relationship_score, all_hidden_states
class FNetForMaskedLM(FNetPretrainedModel):
    """
    FNet Model with a `masked language modeling` head on top.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
    """

    def __init__(self, config: FNetConfig):
        super().__init__(config)
        self.fnet = FNetModel(config)
        self.cls = FNetOnlyMLMHead(config)
        self.tie_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.fnet.embeddings.word_embeddings

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        The FNetForMaskedLM forward method.

        Args:
            input_ids (Tensor):
                See :class:`FNetModel`.
            token_type_ids (Tensor, optional):
                See :class:`FNetModel`.
            position_ids (Tensor, optional):
                See :class:`FNetModel`.
            inputs_embeds (Tensor, optional):
                See :class:`FNetModel`.
            labels (Tensor, optional):
                Accepted for API compatibility but not used; no loss is computed.
            next_sentence_label (Tensor, optional):
                Accepted for API compatibility but not used.
            output_hidden_states (bool, optional):
                See :class:`FNetModel`.
            return_dict (bool, optional):
                See :class:`FNetModel`.

        Returns:
            tuple or Dict: Returns tuple (`prediction_scores`, `hidden_states`) or a dict with
            `prediction_logits`, `hidden_states` fields.

            With the fields:

            - `prediction_scores` (Tensor):
                The scores of masked token prediction. Its data type should be float32
                and its shape is [batch_size, sequence_length, vocab_size].

            - `hidden_states` (tuple(Tensor)):
                Hidden states of all encoder layers (`num_hidden_layers + 1` tensors of shape
                [batch_size, sequence_length, hidden_size]); `None` unless `output_hidden_states` is set.
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            sequence_output = outputs["last_hidden_state"]
            all_hidden_states = outputs["all_hidden_states"]
        else:
            # Tuple outputs cannot be indexed by key; hidden states (if any) follow the first two entries.
            sequence_output = outputs[0]
            all_hidden_states = outputs[2] if len(outputs) > 2 else None

        prediction_scores = self.cls(sequence_output)
        if return_dict:
            return {"prediction_logits": prediction_scores, "hidden_states": all_hidden_states}
        return prediction_scores, all_hidden_states
class FNetForNextSentencePrediction(FNetPretrainedModel):
    """
    FNet Model with a `next sentence prediction` head on top.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
    """

    def __init__(self, config: FNetConfig):
        super().__init__(config)
        self.fnet = FNetModel(config)
        self.cls = FNetOnlyNSPHead(config)

    def get_output_embeddings(self):
        # The NSP head is a 2-way classifier with no vocabulary decoder, so there is
        # nothing to expose or tie to the input embeddings.
        return None

    def set_output_embeddings(self, new_embeddings):
        raise AttributeError(
            "FNetForNextSentencePrediction has no output embeddings: its head is a 2-way classifier "
            "without a vocabulary decoder."
        )

    def get_input_embeddings(self):
        return self.fnet.embeddings.word_embeddings

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        The FNetForNextSentencePrediction forward method.

        Args:
            input_ids (Tensor):
                See :class:`FNetModel`.
            token_type_ids (Tensor, optional):
                See :class:`FNetModel`.
            position_ids (Tensor, optional):
                See :class:`FNetModel`.
            inputs_embeds (Tensor, optional):
                See :class:`FNetModel`.
            labels (Tensor, optional):
                Accepted for API compatibility but not used.
            next_sentence_label (Tensor, optional):
                Accepted for API compatibility but not used; no loss is computed.
            output_hidden_states (bool, optional):
                See :class:`FNetModel`.
            return_dict (bool, optional):
                See :class:`FNetModel`.

        Returns:
            tuple or Dict: Returns tuple (`seq_relationship_score`, `hidden_states`) or a dict with
            `seq_relationship_logits`, `hidden_states` fields. `seq_relationship_score` has shape
            [batch_size, 2]; `hidden_states` is `None` unless `output_hidden_states` is set.
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            pooled_output = outputs["pooler_output"]
            all_hidden_states = outputs["all_hidden_states"]
        else:
            # Tuple outputs cannot be indexed by key; hidden states (if any) follow the first two entries.
            pooled_output = outputs[1]
            all_hidden_states = outputs[2] if len(outputs) > 2 else None

        seq_relationship_score = self.cls(pooled_output)
        if return_dict:
            return {"seq_relationship_logits": seq_relationship_score, "hidden_states": all_hidden_states}
        return seq_relationship_score, all_hidden_states
class FNetForMultipleChoice(FNetPretrainedModel):
    """
    FNet Model with a linear layer on top of the pooled output, designed for
    multiple choice tasks like SWAG. Each (example, choice) pair is scored
    independently and the scores are regrouped per example.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
    """

    def __init__(self, config: FNetConfig):
        super().__init__(config)
        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Score every choice of every example.

        Inputs carry a choices axis at position 1 (e.g. `input_ids` of shape
        [batch_size, num_choices, sequence_length]); it is folded into the batch
        axis before calling :class:`FNetModel`. `labels` is accepted but not used.

        Returns:
            Tensor or Dict: logits of shape [batch_size, num_choices], or a dict with
            `logits` and `hidden_states` (the latter over the flattened batch).
        """
        choice_source = input_ids if input_ids is not None else inputs_embeds
        num_choices = choice_source.shape[1]

        def merge_batch_and_choices(tensor, trailing_dims):
            # [batch, num_choices, *trailing] -> [batch * num_choices, *trailing]
            if tensor is None:
                return None
            return tensor.reshape([-1] + list(tensor.shape[-trailing_dims:]))

        outputs = self.fnet(
            merge_batch_and_choices(input_ids, 1),
            token_type_ids=merge_batch_and_choices(token_type_ids, 1),
            position_ids=merge_batch_and_choices(position_ids, 1),
            inputs_embeds=merge_batch_and_choices(inputs_embeds, 2),
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            pooled_output = outputs["pooler_output"]
        else:
            pooled_output = outputs[1]

        per_choice_logits = self.classifier(self.dropout(pooled_output))
        reshaped_logits = per_choice_logits.reshape([-1, num_choices])
        if not return_dict:
            return reshaped_logits
        return {
            "logits": reshaped_logits,
            "hidden_states": outputs["all_hidden_states"],
        }
class FNetForTokenClassification(FNetPretrainedModel):
    """
    FNet Model with a linear layer on top of the hidden-states output layer,
    designed for token classification tasks like NER tasks.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
    """

    def __init__(self, config: FNetConfig, num_classes=2):
        super().__init__(config)
        self.fnet = FNetModel(config)
        self.num_classes = num_classes
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Classify every token of the input.

        `labels` is accepted but not used. Returns logits of shape
        [batch_size, sequence_length, num_classes], or a dict with `logits` and
        `hidden_states` when `return_dict` is set.
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if return_dict:
            sequence_output = outputs["last_hidden_state"]
        else:
            sequence_output = outputs[0]

        logits = self.classifier(self.dropout(sequence_output))
        if not return_dict:
            return logits
        return {
            "logits": logits,
            "hidden_states": outputs["all_hidden_states"],
        }
class FNetForQuestionAnswering(FNetPretrainedModel):
    """
    FNet Model with a linear layer on top of the hidden-states output to compute
    `span_start_logits` and `span_end_logits`, designed for question-answering
    tasks like SQuAD.

    Args:
        config (:class:`FNetConfig`):
            Configuration used to build the underlying :class:`FNetModel`.
        num_labels (int):
            Output width of the span classifier. `forward` splits it into two equal
            halves (start/end), so it is expected to be 2.
    """

    def __init__(self, config: FNetConfig, num_labels):
        super(FNetForQuestionAnswering, self).__init__(config)
        self.num_labels = num_labels
        self.fnet = FNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, self.num_labels)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Compute span start and end logits for every token.

        `start_positions` and `end_positions` are accepted but not used; no loss is computed.

        Returns:
            tuple or Dict: Returns tuple (`start_logits`, `end_logits`) or a dict with
            `start_logits`, `end_logits`, `hidden_states` fields. Both logits tensors
            have shape [batch_size, sequence_length] when `num_labels` is 2.
        """
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0] if not return_dict else outputs["last_hidden_state"]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = paddle.split(logits, num_or_sections=2, axis=-1)
        start_logits = start_logits.squeeze(axis=-1)
        # Squeeze the end half itself (previously the start logits were returned twice).
        end_logits = end_logits.squeeze(axis=-1)
        if return_dict:
            return {
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": outputs["all_hidden_states"],
            }
        return start_logits, end_logits