utils
-
fn_args_to_dict(func, *args, **kwargs)
Inspect the function func and the arguments it will be run with, and return a dict that maps argument names to argument values.
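A minimal sketch of the same idea, assuming the standard library's inspect.signature is an acceptable stand-in for the library's own argument inspection. It is a hypothetical reimplementation, not the PaddleNLP source, and it may group *args and **kwargs differently from the real function.

```python
import inspect
from typing import Any, Callable, Dict


def fn_args_to_dict(func: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Sketch: map each parameter name of func to the value it would receive."""
    sig = inspect.signature(func)
    # bind_partial tolerates missing required arguments instead of raising
    bound = sig.bind_partial(*args, **kwargs)
    # Fill in declared defaults for parameters that were not passed
    bound.apply_defaults()
    return dict(bound.arguments)


def demo(a, b=2, *, c=3):
    return a + b + c


print(fn_args_to_dict(demo, 1, c=5))  # {'a': 1, 'b': 2, 'c': 5}
```

Defaults are merged in so that the resulting dict describes the full call, not only what the caller typed.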
-
adapt_stale_fwd_patch(self, name, value)
Some code, such as model compression, monkey-patches the forward method of PretrainedModel. This function makes such patches compatible with the latest forward method.
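One way such an adapter could work is sketched below. It assumes the adapter is called from __setattr__ and that "compatible" means silently dropping keyword arguments a stale patch does not declare. The policy, the Model class and the return_dict argument are all hypothetical illustrations, not PaddleNLP's actual behavior.

```python
import functools
import inspect
from typing import Any


def adapt_stale_fwd_patch(self: Any, name: str, value: Any) -> Any:
    """Hypothetical sketch: wrap a stale forward patch so newer kwargs are dropped."""
    if name != "forward" or not callable(value):
        return value
    patch_params = inspect.signature(value).parameters
    # A patch that already accepts **kwargs can absorb any new argument
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in patch_params.values()):
        return value

    @functools.wraps(value)
    def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
        # Keep only the keyword arguments the stale patch declares
        kept = {k: v for k, v in kwargs.items() if k in patch_params}
        return value(*args, **kept)

    return wrapped_forward


class Model:
    """Hypothetical model whose forward gained a newer return_dict argument."""

    def forward(self, x, return_dict=False):
        return {"out": x} if return_dict else x

    def __setattr__(self, name, value):
        # Route every attribute assignment through the adapter
        super().__setattr__(name, adapt_stale_fwd_patch(self, name, value))


m = Model()
m.forward = lambda x: x * 10  # stale patch that knows nothing about return_dict
print(m.forward(3, return_dict=True))  # 30
```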
-
class InitTrackerMeta(name, bases, attrs)
Bases: type
This metaclass wraps the __init__ method of a class to add an init_config attribute to instances of that class; init_config is a dict that tracks the initial configuration. If the class has a _pre_init or _post_init method, it is hooked in before or after __init__, respectively, and called as _pre_init(self, init_fn, init_args) or _post_init(self, init_fn, init_args). Because InitTrackerMeta is used as the metaclass for pretrained model classes, which are always Layer subclasses, and type(Layer) is not type, it uses type(Layer) rather than type as its base class to avoid metaclass conflicts under inheritance.
-
static init_and_track_conf(init_func, pre_init_func=None, post_init_func=None)
Wraps init_func, the __init__ method of a class, to add an init_config attribute to instances of that class.
Warning: self is always an instance of the downstream model class, e.g. BertForTokenClassification.
- Parameters
init_func (callable) – the __init__ method of a class.
pre_init_func (callable, optional) – If provided, it is hooked before init_func and called as pre_init_func(self, init_func, *init_args, **init_kwargs). Default None.
post_init_func (callable, optional) – If provided, it is hooked after init_func and called as post_init_func(self, init_func, *init_args, **init_kwargs). Default None.
- Returns
the wrapped function
- Return type
function
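The metaclass and init_and_track_conf described above can be sketched together in plain Python. This is a hypothetical reimplementation: it uses type as the base because there is no Layer here, and the layout of init_config (keyword arguments plus positional arguments under an init_args key) and the TinyModel class are assumptions for illustration.

```python
import functools
from typing import Any, Callable, Optional


def init_and_track_conf(
    init_func: Callable,
    pre_init_func: Optional[Callable] = None,
    post_init_func: Optional[Callable] = None,
) -> Callable:
    """Sketch: wrap __init__ so each instance records how it was built."""

    @functools.wraps(init_func)
    def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None:
        if pre_init_func is not None:
            pre_init_func(self, init_func, *args, **kwargs)
        init_func(self, *args, **kwargs)
        if post_init_func is not None:
            post_init_func(self, init_func, *args, **kwargs)
        # Hypothetical storage scheme: kwargs plus positional args under "init_args"
        self.init_config = dict(kwargs, init_args=args)

    return wrapped_init


class InitTrackerMeta(type):
    """Sketch: apply init_and_track_conf to the __init__ each class defines."""

    def __init__(cls, name, bases, attrs):
        init_func = attrs.get("__init__")
        if init_func is not None:
            cls.__init__ = init_and_track_conf(
                init_func, attrs.get("_pre_init"), attrs.get("_post_init")
            )
        super().__init__(name, bases, attrs)


class TinyModel(metaclass=InitTrackerMeta):
    def __init__(self, hidden_size, dropout=0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout

    def _post_init(self, init_fn, *args, **kwargs):
        print("built", type(self).__name__)


m = TinyModel(64, dropout=0.2)  # prints: built TinyModel
print(m.init_config)  # {'dropout': 0.2, 'init_args': (64,)}
```

With inheritance, a subclass's wrapper runs after the parent's wrapped __init__ returns, so the recorded init_config reflects the outermost (downstream) constructor call.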
-
static param_in_func(func, param_field: str) → bool
Check whether param_field is a parameter of func, e.g. whether the bert parameter appears in an __init__ method.
- Parameters
func (callable) – the function to inspect, e.g. the __init__ method of a PretrainedModel class
param_field (str) – the name of the parameter to look for
- Returns
whether param_field is a parameter of func
- Return type
bool
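A minimal sketch of this check, assuming inspect.signature is used for introspection. Note that signature(...).parameters also lists keyword-only and variadic parameters, which may differ from how the library itself inspects func; BertForTokenClassification below is a hypothetical stand-in class.

```python
import inspect
from typing import Callable


def param_in_func(func: Callable, param_field: str) -> bool:
    """Sketch: True if param_field is a named parameter of func."""
    return param_field in inspect.signature(func).parameters


class BertForTokenClassification:
    # Hypothetical stand-in for a downstream model class
    def __init__(self, bert, num_classes=2):
        self.bert = bert
        self.num_classes = num_classes


print(param_in_func(BertForTokenClassification.__init__, "bert"))     # True
print(param_in_func(BertForTokenClassification.__init__, "dropout"))  # False
```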
-
resolve_cache_dir(pretrained_model_name_or_path: str, from_hf_hub: bool, cache_dir: Optional[str] = None) → str
Resolve the cache directory for PretrainedModel and PretrainedConfig.
- Parameters
pretrained_model_name_or_path (str) – the name or path of the pretrained model
from_hf_hub (bool) – whether to load from the Hugging Face Hub
cache_dir (str, optional) – the cache directory for models. Default None.
- Returns
the resolved cache directory
- Return type
str
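A sketch of one plausible resolution policy follows. It assumes a per-model subdirectory under either a user-supplied cache_dir or one of two base directories. MODEL_HOME and HF_CACHE_HOME are hypothetical placeholders, not the library's actual defaults, and the rule that avoids appending the model name twice is an assumed design choice.

```python
import os
from typing import Optional

# Hypothetical base directories; the library's real defaults may differ
MODEL_HOME = os.path.join(os.path.expanduser("~"), ".cache", "models")
HF_CACHE_HOME = os.path.join(os.path.expanduser("~"), ".cache", "hf")


def resolve_cache_dir(
    pretrained_model_name_or_path: str,
    from_hf_hub: bool,
    cache_dir: Optional[str] = None,
) -> str:
    """Sketch: pick a per-model cache directory."""
    if cache_dir is not None:
        # Assumed rule: do not append the model name if it is already there
        if cache_dir.endswith(pretrained_model_name_or_path):
            return cache_dir
        return os.path.join(cache_dir, pretrained_model_name_or_path)
    # No explicit cache_dir: choose a base directory by download source
    base = HF_CACHE_HOME if from_hf_hub else MODEL_HOME
    return os.path.join(base, pretrained_model_name_or_path)


print(resolve_cache_dir("bert-base-uncased", from_hf_hub=False, cache_dir="/tmp/cache"))
```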
-
find_transformer_model_type(model_class: Type) → str
Get the model type from the name of the module that defines model_class, e.g. BertModel -> bert, RobertaForTokenClassification -> roberta.
- Parameters
model_class (Type) – the class of model
- Returns
the type string
- Return type
str
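A minimal sketch of deriving the type from the module name, assuming model modules follow a layout like <package>.transformers.<model_type>.modeling. That layout, the "mylib" package and the empty-string fallback are assumptions for illustration only.

```python
from typing import Type


def find_transformer_model_type(model_class: Type) -> str:
    """Sketch: derive 'roberta' from a module path like 'mylib.transformers.roberta.modeling'."""
    parts = model_class.__module__.split(".")
    # Assumed layout: the model type is the package right after 'transformers'
    if "transformers" in parts:
        idx = parts.index("transformers")
        if idx + 1 < len(parts):
            return parts[idx + 1]
    return ""  # unknown layout: no model type found


class RobertaForTokenClassification:
    pass


# Pretend the class lives in a hypothetical package module
RobertaForTokenClassification.__module__ = "mylib.transformers.roberta.modeling"
print(find_transformer_model_type(RobertaForTokenClassification))  # roberta
```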
-
find_transformer_model_class_by_name(model_name: str) → Optional[Type[PretrainedModel]]
Find a transformer model class by its class name.
- Parameters
model_name (str) – the class name, e.g. BertModel
- Returns
the matching pretrained model class, or None if no class matches
- Return type
Optional[Type[PretrainedModel]]
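A sketch of a name lookup, assuming the model classes are exported from a single module that can be scanned with dir. The models module, the PretrainedModel stand-in and the extra module parameter are hypothetical; the real function searches the library's own package.

```python
import types
from typing import Optional


class PretrainedModel:
    """Hypothetical stand-in for the library's base class."""


class BertModel(PretrainedModel):
    pass


# Hypothetical module playing the role of the package that exports model classes
models = types.ModuleType("models")
models.BertModel = BertModel
models.helper = len  # non-model objects are skipped by the search


def find_transformer_model_class_by_name(
    model_name: str, module: types.ModuleType = models
) -> Optional[type]:
    """Sketch: return the PretrainedModel subclass named model_name, else None."""
    for attr in dir(module):
        obj = getattr(module, attr)
        # Only consider classes derived from PretrainedModel
        if isinstance(obj, type) and issubclass(obj, PretrainedModel) and obj.__name__ == model_name:
            return obj
    return None


print(find_transformer_model_class_by_name("BertModel") is BertModel)  # True
print(find_transformer_model_class_by_name("GPTModel"))  # None
```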