utils
-
fn_args_to_dict(func, *args, **kwargs)
Inspect function func and its arguments for running, and extract a dict mapping argument names to their values.
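A minimal sketch of the idea, assuming the mapping should contain the positional arguments, the keyword arguments, and the defaults of any parameters that were not passed. It uses the standard library's inspect.signature and is not PaddleNLP's own implementation; the name fn_args_to_dict_sketch is hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


def fn_args_to_dict_sketch(func: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Map each parameter name of `func` to the value it would receive (hypothetical helper)."""
    # bind() raises TypeError if args/kwargs do not fit func's signature.
    bound = inspect.signature(func).bind(*args, **kwargs)
    # Fill in defaults for parameters the caller did not pass.
    bound.apply_defaults()
    return dict(bound.arguments)


def build(vocab_size, hidden_size=768, dropout=0.1):
    return None


print(fn_args_to_dict_sketch(build, 30522, dropout=0.2))
# {'vocab_size': 30522, 'hidden_size': 768, 'dropout': 0.2}
```

For functions with *args or **kwargs parameters, apply_defaults() adds them as an empty tuple or dict, so this sketch's output for such functions may differ from the original helper.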
-
adapt_stale_fwd_patch(self, name, value)
Some features, such as model compression, monkey-patch the forward method of PretrainedModel. This function makes such patches compatible with the latest forward method.
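The (self, name, value) signature suggests a __setattr__-style hook. The sketch below assumes that reading and shows one way such a hook could adapt a stale patch: if an assigned forward does not accept a keyword argument that the class's own forward accepts, the patch is wrapped so that argument is dropped before the call. Which arguments count as "newer" is an assumption here (every parameter the patch lacks), and the names adapt_stale_fwd_patch_sketch, Model, and return_dict are illustrative, not PaddleNLP's code.

```python
import functools
import inspect
from typing import Any


def adapt_stale_fwd_patch_sketch(self: Any, name: str, value: Any) -> Any:
    """Hypothetical: wrap a patched `forward` that lacks newer keyword arguments."""
    if name != "forward" or not callable(value):
        return value
    patched = inspect.signature(value).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in patched.values()):
        return value  # the patch already accepts arbitrary keyword arguments
    # Parameters of the class's own (latest) forward, excluding `self`.
    latest = [p for p in inspect.signature(type(self).forward).parameters if p != "self"]
    missing = [p for p in latest if p not in patched]
    if not missing:
        return value

    @functools.wraps(value)
    def _adapted(*args: Any, **kwargs: Any) -> Any:
        # Drop keyword arguments that the stale patch does not understand.
        for arg in missing:
            kwargs.pop(arg, None)
        return value(*args, **kwargs)

    return _adapted


class Model:
    def forward(self, x, return_dict=False):
        return {"out": x} if return_dict else x

    def __setattr__(self, name, value):
        # Route every attribute assignment through the adapter.
        super().__setattr__(name, adapt_stale_fwd_patch_sketch(self, name, value))


model = Model()
model.forward = lambda x: x * 2  # a patch written before return_dict existed
print(model.forward(3, return_dict=True))  # 6
```

Callers can keep passing the newer argument, and the old patch keeps working because the wrapper silently discards what it cannot accept.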
-
class InitTrackerMeta(name, bases, attrs)
Bases: type
This metaclass wraps the __init__ method of a class to add an init_config attribute to instances of that class; init_config is a dict that tracks the initial configuration. If the class has a _pre_init or _post_init method, it is hooked before or after __init__ and called as _pre_init(self, init_fn, init_args) or _post_init(self, init_fn, init_args). Because InitTrackerMeta is used as the metaclass for pretrained model classes, which are always Layer subclasses, and type(Layer) is not type, it uses type(Layer) rather than type as its base class to avoid metaclass conflicts in inheritance.
-
static init_and_track_conf(init_func, pre_init_func=None, post_init_func=None)
Wraps init_func, the __init__ method of a class, to add an init_config attribute to instances of that class.
- Parameters
init_func (callable) -- the __init__ method of a class. Warning: self is always an instance of the downstream model class, e.g. BertForTokenClassification.
pre_init_func (callable, optional) -- If provided, it is hooked before init_func and called as pre_init_func(self, init_func, *init_args, **init_kwargs). Default None.
post_init_func (callable, optional) -- If provided, it is hooked after init_func and called as post_init_func(self, init_func, *init_args, **init_kwargs). Default None.
- Returns
the wrapped function
- Return type
function
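A minimal, framework-free sketch of how a metaclass can use a wrapper like init_and_track_conf. It assumes the wrapper records the keyword arguments as init_config, stores positional arguments under an "init_args" key, and calls the optional _pre_init/_post_init hooks before and after __init__. The names TrackerMetaSketch and DemoModel are hypothetical, and the metaclass derives from plain type because no framework Layer is involved.

```python
import functools


class TrackerMetaSketch(type):
    """Hypothetical stand-in for InitTrackerMeta, with no Paddle dependency."""

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        # Only wrap an __init__ defined in this class body; an inherited one
        # was already wrapped when the parent class was created.
        if "__init__" in attrs:
            cls.__init__ = TrackerMetaSketch.init_and_track_conf(
                attrs["__init__"],
                getattr(cls, "_pre_init", None),
                getattr(cls, "_post_init", None),
            )

    @staticmethod
    def init_and_track_conf(init_func, pre_init_func=None, post_init_func=None):
        @functools.wraps(init_func)
        def __impl__(self, *args, **kwargs):
            if pre_init_func:
                pre_init_func(self, init_func, *args, **kwargs)  # hooked before __init__
            init_func(self, *args, **kwargs)
            if post_init_func:
                post_init_func(self, init_func, *args, **kwargs)  # hooked after __init__
            # Record the keyword arguments; positional ones go under "init_args".
            config = dict(kwargs)
            if args:
                config["init_args"] = args
            self.init_config = config

        return __impl__


class DemoModel(metaclass=TrackerMetaSketch):
    def __init__(self, hidden_size, dropout=0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout

    def _pre_init(self, init_func, *args, **kwargs):
        self.calls = ["pre"]

    def _post_init(self, init_func, *args, **kwargs):
        self.calls.append("post")


m = DemoModel(256, dropout=0.2)
print(m.init_config)  # {'dropout': 0.2, 'init_args': (256,)}
print(m.calls)  # ['pre', 'post']
```

Only arguments the caller actually passed are recorded; defaults such as dropout=0.1 do not appear in init_config unless given explicitly.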
-
static param_in_func(func, param_field: str) → bool
Check whether param_field is a parameter of func, e.g. whether the bert parameter is in the __init__ method.
- Parameters
func (callable) -- the function to inspect, e.g. the __init__ method of a PretrainedModel class
param_field (str) -- the name of the parameter
- Returns
whether param_field exists in func
- Return type
bool
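A minimal sketch of this check using inspect.signature, assuming "is in func" means param_field is one of func's named parameters. The names param_in_func_sketch and BertForDemo are hypothetical.

```python
import inspect
from typing import Callable


def param_in_func_sketch(func: Callable, param_field: str) -> bool:
    """Return True if `param_field` names a parameter of `func` (hypothetical helper)."""
    return param_field in inspect.signature(func).parameters


class BertForDemo:
    def __init__(self, bert, num_classes=2):
        self.bert = bert
        self.num_classes = num_classes


print(param_in_func_sketch(BertForDemo.__init__, "bert"))     # True
print(param_in_func_sketch(BertForDemo.__init__, "dropout"))  # False
```

signature().parameters also lists keyword-only parameters and the names of *args/**kwargs, so this sketch may accept more names than a check restricted to positional parameters.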
-
resolve_cache_dir(pretrained_model_name_or_path: str, from_hf_hub: bool, cache_dir: Optional[str] = None) → str
Resolve the cache directory for PretrainedModel and PretrainedConfig.
- Parameters
pretrained_model_name_or_path (str) -- the name or path of the pretrained model
from_hf_hub (bool) -- whether to load from the Hugging Face Hub
cache_dir (str, optional) -- the cache directory for models. Default None.
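The entry does not spell out the resolution rules, so the sketch below shows one plausible policy under stated assumptions: an existing local directory is used as-is; for Hugging Face Hub downloads, cache_dir (or a default Hub root) is returned without appending the model name, on the assumption that the Hub client manages its own per-model layout; otherwise the model name is appended to cache_dir (or a default model root). The function name and the model_home/hf_home parameters are hypothetical.

```python
import os
from typing import Optional


def resolve_cache_dir_sketch(
    pretrained_model_name_or_path: str,
    from_hf_hub: bool,
    cache_dir: Optional[str] = None,
    *,
    model_home: str,  # hypothetical default root for non-Hub models
    hf_home: str,  # hypothetical default root for Hub downloads
) -> str:
    # A local model directory needs no cache: use it directly.
    if os.path.isdir(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    if from_hf_hub:
        # Assumed: the Hub client creates per-model subdirectories itself.
        return cache_dir if cache_dir is not None else hf_home
    root = cache_dir if cache_dir is not None else model_home
    return os.path.join(root, pretrained_model_name_or_path)


print(resolve_cache_dir_sketch("bert-base-uncased", False, model_home="/data/models", hf_home="/data/hf"))
```

On a POSIX system where no local directory named bert-base-uncased exists, the call above prints /data/models/bert-base-uncased.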
-
find_transformer_model_type(model_class: Type) → str
Get the model type from the module name of the model class, e.g. BertModel -> bert, RobertaForTokenClassification -> roberta.
- Parameters
model_class (Type) -- the model class
- Returns
the model type string
- Return type
str
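A sketch of deriving the type from the module path, assuming model classes live in modules named like paddlenlp.transformers.<model_type>.<file>, so the component right after the package prefix is the model type. It omits any check that the class is a PretrainedModel subclass and returns an empty string when the module does not match; the function name is hypothetical.

```python
from typing import Type


def find_transformer_model_type_sketch(model_class: Type, package: str = "paddlenlp.transformers") -> str:
    """Return the sub-package name right after `package` in model_class.__module__."""
    module_name = model_class.__module__
    prefix = package + "."
    if not module_name.startswith(prefix):
        return ""
    # "paddlenlp.transformers.bert.modeling" -> "bert"
    return module_name[len(prefix):].split(".")[0]


# Simulate a class defined in paddlenlp/transformers/bert/modeling.py.
BertModel = type("BertModel", (), {"__module__": "paddlenlp.transformers.bert.modeling"})
print(find_transformer_model_type_sketch(BertModel))  # bert
```

Because the type comes from the module path rather than the class name, RobertaForTokenClassification maps to roberta without any string parsing of the class name itself.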
-
find_transformer_model_class_by_name(model_name: str) → Optional[Type[PretrainedModel]]
Find a transformer model class by its class name.
- Parameters
model_name (str) -- the class name of the model
- Returns
the matching pretrained model class, or None if no class matches
- Return type
Optional[Type[PretrainedModel]]
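A sketch of a by-name lookup, assuming the candidate classes are exposed as public attributes of a single module. The module is passed in as a parameter so the example runs without PaddleNLP; the function name is hypothetical, and the standard library's collections module stands in for the model package.

```python
import collections
import inspect
from types import ModuleType
from typing import Optional, Type


def find_class_by_name_sketch(module: ModuleType, model_name: str) -> Optional[Type]:
    """Return the public class in `module` whose __name__ equals model_name, else None."""
    for attr_name in dir(module):
        if attr_name.startswith("_"):
            continue  # skip private and dunder attributes
        obj = getattr(module, attr_name, None)
        if inspect.isclass(obj) and obj.__name__ == model_name:
            return obj
    return None


print(find_class_by_name_sketch(collections, "OrderedDict"))  # <class 'collections.OrderedDict'>
print(find_class_by_name_sketch(collections, "NoSuchModel"))  # None
```

Comparing obj.__name__ rather than the attribute name also finds a class that is exported under an alias.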