utils

fn_args_to_dict(func, *args, **kwargs)

Inspects the function func together with the arguments it would be run with, and extracts a dict mapping argument names to their values.
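As a rough illustration of the mapping described above, the sketch below builds such a dict with the standard inspect module. The names args_to_dict_sketch and greet are hypothetical, and binding through inspect.signature is an assumption made for this example; the library's own implementation may differ in how it treats defaults and variadic arguments.

```python
import inspect


def args_to_dict_sketch(func, *args, **kwargs):
    """Hypothetical helper: map the parameter names of func to call values."""
    signature = inspect.signature(func)
    # bind_partial tolerates missing required arguments instead of raising.
    bound = signature.bind_partial(*args, **kwargs)
    # Fill in declared defaults for parameters the caller did not pass.
    bound.apply_defaults()
    return dict(bound.arguments)


def greet(name, punctuation="!", *, loud=False):
    """Example function, used only to have a signature to inspect."""
    text = f"Hello, {name}{punctuation}"
    return text.upper() if loud else text


result = args_to_dict_sketch(greet, "Ada", loud=True)
print(result)
```

The positional value is matched to name, the keyword value to loud, and punctuation falls back to its declared default.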

adapt_stale_fwd_patch(self, name, value)

Some monkey patches replace the forward method of PretrainedModel, for example for model compression. This function makes such patches compatible with the latest forward method.

class InitTrackerMeta(name, bases, attrs)

Bases: type

This metaclass wraps the __init__ method of a class so that each instance of that class gets an init_config attribute, a dict that tracks the initial configuration. If the class defines a _pre_init or _post_init method, it is hooked before or after __init__ respectively and called as _pre_init(self, init_fn, init_args) or _post_init(self, init_fn, init_args). InitTrackerMeta is used as the metaclass for pretrained model classes, which are always Layer subclasses. Because type(Layer) is not type, InitTrackerMeta uses type(Layer) rather than type as its base class to avoid metaclass conflicts during inheritance.

static init_and_track_conf(init_func, pre_init_func=None, post_init_func=None)

Wraps init_func, which is the __init__ method of a class, to add an init_config attribute to instances of that class.

Warning: the type of self is always the downstream model class, e.g. BertForTokenClassification.

Parameters:
  • init_func (callable) – It should be the __init__ method of a class.

  • pre_init_func (callable, optional) – If provided, it is hooked before init_func and called as pre_init_func(self, init_func, *init_args, **init_args). Default None.

  • post_init_func (callable, optional) – If provided, it is hooked after init_func and called as post_init_func(self, init_func, *init_args, **init_args). Default None.

Returns:

the wrapped function

Return type:

function
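The wrapping described for InitTrackerMeta and init_and_track_conf can be sketched in plain Python as follows. This is a minimal illustration, not the library's code: the names InitTrackerSketch, wrap_init and Demo are hypothetical, the sketch derives from type rather than type(Layer) so that it runs without Paddle installed, and the layout of the recorded dict (positional values under an "init_args" key) is an assumption.

```python
import functools


class InitTrackerSketch(type):
    """Hypothetical metaclass that records constructor arguments."""

    def __init__(cls, name, bases, attrs):
        # Only wrap when the class defines its own __init__, so an inherited
        # (already wrapped) __init__ is not wrapped a second time.
        if "__init__" in attrs:
            cls.__init__ = InitTrackerSketch.wrap_init(
                cls.__init__,
                getattr(cls, "_pre_init", None),
                getattr(cls, "_post_init", None),
            )
        super().__init__(name, bases, attrs)

    @staticmethod
    def wrap_init(init_func, pre_init_func=None, post_init_func=None):
        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            if pre_init_func is not None:
                pre_init_func(self, init_func, *args, **kwargs)
            init_func(self, *args, **kwargs)
            if post_init_func is not None:
                post_init_func(self, init_func, *args, **kwargs)
            # Assumed layout: positional values under one key, then keywords.
            self.init_config = {"init_args": args, **kwargs}

        return wrapped


class Demo(metaclass=InitTrackerSketch):
    def __init__(self, hidden_size, dropout=0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout

    def _post_init(self, init_fn, *args, **kwargs):
        # Runs after the original __init__ has finished.
        self.post_init_called = True


demo = Demo(16, dropout=0.2)
print(demo.init_config)
```

Because the metaclass runs at class-creation time, every instance of Demo records its constructor arguments without Demo.__init__ containing any tracking code itself.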

param_in_func(func, param_field: str) -> bool

Checks whether param_field is a parameter of func, e.g. whether the bert parameter is in the __init__ method.

Parameters:
  • func (callable) – the function or method to inspect, e.g. the __init__ method of a PretrainedModel class

  • param_field (str) – the name of field

Returns:

the result of existence

Return type:

bool
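A check of this kind can be sketched with the standard inspect module. The names param_in_func_sketch and ToyModel below are hypothetical, and using inspect.signature is an assumption for illustration; the library may inspect the function differently.

```python
import inspect


def param_in_func_sketch(func, param_field: str) -> bool:
    """Hypothetical helper: is param_field a declared parameter of func?"""
    return param_field in inspect.signature(func).parameters


class ToyModel:
    """Stand-in class, used only to have an __init__ to inspect."""

    def __init__(self, bert, num_classes=2):
        self.bert = bert
        self.num_classes = num_classes


has_bert = param_in_func_sketch(ToyModel.__init__, "bert")
has_roberta = param_in_func_sketch(ToyModel.__init__, "roberta")
print(has_bert, has_roberta)
```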

resolve_cache_dir(pretrained_model_name_or_path: str, from_hf_hub: bool, cache_dir: str | None = None) -> str

resolve cache dir for PretrainedModel and PretrainedConfig

Parameters:
  • pretrained_model_name_or_path (str) – the name or path of pretrained model

  • from_hf_hub (bool) – whether to load from the Hugging Face Hub

  • cache_dir (str, optional) – the cache directory for models. Default None.

find_transformer_model_type(model_class: Type) -> str

Gets the model type from the module name, e.g. BertModel -> bert, RobertaForTokenClassification -> roberta.

Parameters:

model_class (Type) – the class of model

Returns:

the type string

Return type:

str
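The idea of deriving a model type from a module name can be sketched as below. This assumes a module layout of the form <package>.transformers.<model_type>.<module>, which is not stated on this page. The function name model_type_from_module_sketch, the stand-in classes, and the empty-string fallback are likewise hypothetical.

```python
def model_type_from_module_sketch(model_class: type) -> str:
    """Hypothetical helper: read the model type out of a class's module path."""
    parts = model_class.__module__.split(".")
    # Assumed layout: the segment right after "transformers" names the model.
    if "transformers" in parts:
        index = parts.index("transformers")
        if index + 1 < len(parts):
            return parts[index + 1]
    return ""  # assumed fallback when the layout does not match


# Stand-in classes with hand-written module paths, for illustration only.
BertModel = type(
    "BertModel", (), {"__module__": "paddlenlp.transformers.bert.modeling"}
)
RobertaForTokenClassification = type(
    "RobertaForTokenClassification",
    (),
    {"__module__": "paddlenlp.transformers.roberta.modeling"},
)

bert_type = model_type_from_module_sketch(BertModel)
roberta_type = model_type_from_module_sketch(RobertaForTokenClassification)
print(bert_type, roberta_type)
```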

find_transformer_model_class_by_name(model_name: str) -> Type[PretrainedModel] | None

find transformer model_class by name

Parameters:

model_name (str) – the string of class name

Returns:

optional pretrained-model class

Return type:

Optional[Type[PretrainedModel]]

class ContextManagers(context_managers: List[AbstractContextManager])

Bases: object

Wrapper for contextlib.ExitStack which enters a collection of context managers. Adaptation of ContextManagers in the fastcore library.
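A wrapper of this kind can be sketched on top of contextlib.ExitStack as follows. The class name ContextManagersSketch and the tagged helper are hypothetical and exist only to show the entry and exit order; the library's class may differ in detail.

```python
from contextlib import ExitStack, contextmanager


class ContextManagersSketch:
    """Hypothetical wrapper that enters a list of context managers together."""

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        # Enter each manager in order; ExitStack remembers how to exit them.
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # ExitStack exits the managers in reverse order of entry.
        return self.stack.__exit__(exc_type, exc_value, traceback)


events = []


@contextmanager
def tagged(name):
    """Example context manager that logs when it is entered and exited."""
    events.append(f"enter {name}")
    try:
        yield name
    finally:
        events.append(f"exit {name}")


with ContextManagersSketch([tagged("a"), tagged("b")]):
    events.append("body")

print(events)
```

Passing a list keeps the call site flat when the number of context managers is only known at run time, which nested with statements cannot express.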