Source code for paddlenlp.transformers.utils
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import inspect
import os
import warnings
from contextlib import ExitStack
from typing import TYPE_CHECKING, ContextManager, List, Optional, Type
if TYPE_CHECKING:
from paddlenlp.transformers import PretrainedModel
import paddle
from paddle.nn import Layer
from paddlenlp.utils.env import HF_CACHE_HOME, MODEL_HOME
from paddlenlp.utils.import_utils import import_module
from paddlenlp.utils.log import logger
def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it is called with, and build a
    dict mapping argument names to argument values.
    """
if hasattr(inspect, "getfullargspec"):
(spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = inspect.getfullargspec(func)
else:
(spec_args, spec_varargs, spec_varkw, spec_defaults) = inspect.getargspec(func)
# add positional argument values
init_dict = dict(zip(spec_args, args))
# add default argument values
kwargs_dict = dict(zip(spec_args[-len(spec_defaults) :], spec_defaults)) if spec_defaults else {}
for k in list(kwargs_dict.keys()):
if k in init_dict:
kwargs_dict.pop(k)
kwargs_dict.update(kwargs)
init_dict.update(kwargs_dict)
return init_dict
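# Illustrative sketch (not part of the original module): for a hypothetical
# function `def _example(a, b=1, c=2): ...`, `fn_args_to_dict` merges positional
# values, declared defaults and keyword overrides, so
# `fn_args_to_dict(_example, 10, c=3)` yields {"a": 10, "b": 1, "c": 3}.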
def adapt_stale_fwd_patch(self, name, value):
    """
    Since there are monkey patches for the `forward` method of PretrainedModel,
    such as model compression, make these patches compatible with the latest
    `forward` signature.
    """
if name == "forward":
        # NOTE(guosheng): In dygraph-to-static conversion, `layer.forward` would
        # be patched by an instance of `StaticFunction`. Use a string comparison
        # to avoid importing fluid.
if type(value).__name__.endswith("StaticFunction"):
return value
if hasattr(inspect, "getfullargspec"):
(
patch_spec_args,
patch_spec_varargs,
patch_spec_varkw,
patch_spec_defaults,
_,
_,
_,
) = inspect.getfullargspec(value)
(spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = inspect.getfullargspec(self.forward)
else:
(patch_spec_args, patch_spec_varargs, patch_spec_varkw, patch_spec_defaults) = inspect.getargspec(value)
(spec_args, spec_varargs, spec_varkw, spec_defaults) = inspect.getargspec(self.forward)
new_args = [
arg
for arg in ("output_hidden_states", "output_attentions", "return_dict")
if arg not in patch_spec_args and arg in spec_args
]
if new_args:
if self.__module__.startswith("paddlenlp"):
                warnings.warn(
                    f"The `forward` method of {self.__class__ if isinstance(self, Layer) else self} is patched and the patch "
                    "might be based on an old version which is missing some "
                    f"arguments compared with the latest, such as {new_args}. "
                    "We automatically add compatibility on the patch for "
                    "these arguments, and maybe the patch should be updated."
                )
else:
                warnings.warn(
                    f"The `forward` method of {self.__class__ if isinstance(self, Layer) else self} "
                    "is patched and the patch might conflict with patches made "
                    f"by paddlenlp which seem to have more arguments such as {new_args}. "
                    "We automatically add compatibility on the patch for "
                    "these arguments, and maybe the patch should be updated."
                )
if isinstance(self, Layer) and inspect.isfunction(value):
@functools.wraps(value)
def wrap_fwd(*args, **kwargs):
for arg in new_args:
kwargs.pop(arg, None)
return value(self, *args, **kwargs)
else:
@functools.wraps(value)
def wrap_fwd(*args, **kwargs):
for arg in new_args:
kwargs.pop(arg, None)
return value(*args, **kwargs)
return wrap_fwd
return value
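# Illustrative sketch (not part of the original module), assuming the assignment
# is routed through this helper (as `InitTrackerMeta.__setattr__` below does):
# if a stale patch such as `def old_forward(self, input_ids): ...` is assigned
# as a class's `forward` while the current `forward` also accepts `return_dict`,
# the wrapper returned here silently drops the newer keyword, so callers passing
# `return_dict=True` keep working against the old patch.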
class InitTrackerMeta(type(Layer)):
    """
    This metaclass wraps the `__init__` method of a class to add an
    `init_config` attribute to instances of that class, where `init_config`
    is a dict that tracks the initial configuration. If the class has a
    `_pre_init` or `_post_init` method, it is hooked before or after
    `__init__` and called as `_pre_init(self, init_fn, init_args)` or
    `_post_init(self, init_fn, init_args)`.
    Since InitTrackerMeta is used as the metaclass of pretrained model classes,
    which are always Layer subclasses, and `type(Layer)` is not `type`, use
    `type(Layer)` rather than `type` as its base class to avoid metaclass
    conflicts.
    """
def __init__(cls, name, bases, attrs):
init_func = cls.__init__
        # If attrs has `__init__`, wrap it using the accessible `_pre_init, _post_init`.
        # Otherwise, no need to wrap again since the super class has been wrapped.
        # TODO: remove duplicated tracker if using super class `__init__`
pre_init_func = getattr(cls, "_pre_init", None) if "__init__" in attrs else None
post_init_func = getattr(cls, "_post_init", None) if "__init__" in attrs else None
cls.__init__ = InitTrackerMeta.init_and_track_conf(init_func, pre_init_func, post_init_func)
super(InitTrackerMeta, cls).__init__(name, bases, attrs)
    @staticmethod
def init_and_track_conf(init_func, pre_init_func=None, post_init_func=None):
"""
wraps `init_func` which is `__init__` method of a class to add `init_config`
attribute for instances of that class.
Args:
init_func (callable): It should be the `__init__` method of a class.
warning: `self` always is the class type of down-stream model, eg: BertForTokenClassification
pre_init_func (callable, optional): If provided, it would be hooked after
`init_func` and called as `pre_init_func(self, init_func, *init_args, **init_args)`.
Default None.
post_init_func (callable, optional): If provided, it would be hooked after
`init_func` and called as `post_init_func(self, init_func, *init_args, **init_args)`.
Default None.
Returns:
function: the wrapped function
"""
@functools.wraps(init_func)
def __impl__(self, *args, **kwargs):
            # hook registered via `pre_init_func`
if pre_init_func:
pre_init_func(self, init_func, *args, **kwargs)
# keep full configuration
init_func(self, *args, **kwargs)
            # hook registered via `post_init_func`
if post_init_func:
post_init_func(self, init_func, *args, **kwargs)
self.init_config = kwargs
if args:
kwargs["init_args"] = args
kwargs["init_class"] = self.__class__.__name__
return __impl__
def __setattr__(self, name, value):
value = adapt_stale_fwd_patch(self, name, value)
return super(InitTrackerMeta, self).__setattr__(name, value)
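# Illustrative sketch (not part of the original module): for a hypothetical
# class built with this metaclass,
#
#     class _Toy(Layer, metaclass=InitTrackerMeta):
#         def __init__(self, hidden_size=8, name="toy"):
#             super().__init__()
#
# `_Toy(16, name="demo").init_config` would be
# {"name": "demo", "init_args": (16,), "init_class": "_Toy"}.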
def param_in_func(func, param_field: str) -> bool:
    """Check whether `param_field` is a parameter of `func`, e.g. whether the
    `bert` parameter is in the `__init__` method.
    Args:
        func (callable): the function or method to inspect
        param_field (str): the name of the parameter
    Returns:
        bool: whether the parameter exists
    """
if hasattr(inspect, "getfullargspec"):
result = inspect.getfullargspec(func)
else:
result = inspect.getargspec(func)
return param_field in result[0]
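# Illustrative sketch (not part of the original module): for a hypothetical
# `class _Toy` with `def __init__(self, bert, dropout=0.1): ...`,
# `param_in_func(_Toy.__init__, "bert")` returns True and
# `param_in_func(_Toy.__init__, "gpt")` returns False.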
def resolve_cache_dir(pretrained_model_name_or_path: str, from_hf_hub: bool, cache_dir: Optional[str] = None) -> str:
    """Resolve the cache dir for PretrainedModel and PretrainedConfig.
    Args:
        pretrained_model_name_or_path (str): the name or path of the pretrained model
        from_hf_hub (bool): whether to load from the Hugging Face Hub
        cache_dir (str): cache_dir for models
    """
if os.path.isdir(pretrained_model_name_or_path):
return pretrained_model_name_or_path
    # the hf hub library takes care of appending the model name, so we don't append it here
if from_hf_hub:
if cache_dir is not None:
return cache_dir
else:
return HF_CACHE_HOME
else:
if cache_dir is not None:
            # since model_class.from_pretrained calls config_class.from_pretrained, the model name may get appended twice
if cache_dir.endswith(pretrained_model_name_or_path):
return cache_dir
else:
return os.path.join(cache_dir, pretrained_model_name_or_path)
return os.path.join(MODEL_HOME, pretrained_model_name_or_path)
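# Illustrative sketch (not part of the original module), assuming
# "bert-base-uncased" is not a local directory and a POSIX path layout:
#     resolve_cache_dir("bert-base-uncased", from_hf_hub=False)
#         -> os.path.join(MODEL_HOME, "bert-base-uncased")
#     resolve_cache_dir("bert-base-uncased", from_hf_hub=True)
#         -> HF_CACHE_HOME
#     resolve_cache_dir("bert-base-uncased", from_hf_hub=False, cache_dir="/tmp/cache")
#         -> "/tmp/cache/bert-base-uncased"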
def find_transformer_model_type(model_class: Type) -> str:
    """Get the model type from the module name,
eg:
BertModel -> bert,
RobertaForTokenClassification -> roberta
Args:
model_class (Type): the class of model
Returns:
str: the type string
"""
from paddlenlp.transformers import PretrainedModel
default_model_type = ""
if not issubclass(model_class, PretrainedModel):
return default_model_type
module_name: str = model_class.__module__
if not module_name.startswith("paddlenlp.transformers."):
return default_model_type
tokens = module_name.split(".")
if len(tokens) < 3:
return default_model_type
return tokens[2]
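# Illustrative sketch (not part of the original module), assuming BertModel is
# defined in the module "paddlenlp.transformers.bert.modeling":
# `find_transformer_model_type(BertModel)` returns "bert", while a class that is
# not a PretrainedModel subclass yields "".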
def find_transformer_model_class_by_name(model_name: str) -> Optional[Type[PretrainedModel]]:
    """Find a transformer model class by name.
Args:
model_name (str): the string of class name
Returns:
Optional[Type[PretrainedModel]]: optional pretrained-model class
"""
transformer_module = import_module("paddlenlp.transformers")
for obj_name in dir(transformer_module):
if obj_name.startswith("_"):
continue
obj = getattr(transformer_module, obj_name, None)
if obj is None:
continue
name = getattr(obj, "__name__", None)
if name is None:
continue
if name == model_name:
return obj
logger.debug(f"can not find model_class<{model_name}>")
return None
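# Illustrative sketch (not part of the original module):
# `find_transformer_model_class_by_name("BertModel")` returns the BertModel
# class exported by paddlenlp.transformers, while an unknown name such as
# "NoSuchModel" logs a debug message and returns None.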
def is_paddle_support_lazy_init():
return hasattr(paddle, "LazyGuard")
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `fastcore` library.
"""
def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
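# Illustrative sketch (not part of the original module): entering several
# context managers at once, e.g.
#
#     import contextlib
#
#     with ContextManagers([contextlib.suppress(ValueError), paddle.no_grad()]):
#         ...  # both contexts are active here and exit together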