utils
- fn_args_to_dict(func, *args, **kwargs)
Inspect the function func and the arguments it is called with, and build a dict mapping each argument name to its value.
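A minimal sketch of this behavior follows. It uses the standard library's inspect.signature, which may not be what the library does internally. The names fn_args_to_dict_sketch and build are hypothetical. The sketch assumes that parameters the caller omits should still appear in the result with their declared defaults.

```python
import inspect
from typing import Any, Callable, Dict


def fn_args_to_dict_sketch(func: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: map each parameter name of `func` to the value it would receive."""
    # bind() raises TypeError if the arguments do not match the signature.
    bound = inspect.signature(func).bind(*args, **kwargs)
    # Fill in parameters the caller omitted with their declared defaults.
    bound.apply_defaults()
    return dict(bound.arguments)


def build(hidden_size, num_layers=12, dropout=0.1):
    """Hypothetical constructor-like function used for the demo."""
    return None


print(fn_args_to_dict_sketch(build, 768, dropout=0.2))
# {'hidden_size': 768, 'num_layers': 12, 'dropout': 0.2}
```

BoundArguments.apply_defaults also records empty values for *args and **kwargs parameters when the target function declares them. A stricter version might want to filter those out.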
- adapt_stale_fwd_patch(self, name, value)
Some code, such as model compression, monkey-patches the forward method of PretrainedModel. This method makes such patches compatible with the latest forward method.
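The sketch below shows one way such compatibility could work. It assumes the incompatibility is that a stale patch lacks keyword arguments the current forward accepts, and handles it by having a wrapper drop those arguments. This is an assumption, not the library's actual code. The names adapt_stale_patch_sketch, latest_forward and stale_patch are hypothetical.

```python
import functools
import inspect
from typing import Any, Callable


def adapt_stale_patch_sketch(latest_fn: Callable, patch_fn: Callable) -> Callable:
    """Hypothetical helper: make `patch_fn` tolerate kwargs only `latest_fn` declares."""
    latest_names = set(inspect.signature(latest_fn).parameters)
    patch_params = inspect.signature(patch_fn).parameters
    # A patch that already accepts **kwargs can take any keyword; leave it alone.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in patch_params.values()):
        return patch_fn
    # Names the latest forward accepts but the stale patch does not.
    new_names = latest_names - set(patch_params)
    if not new_names:
        return patch_fn

    @functools.wraps(patch_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Silently drop keyword arguments the stale patch cannot accept.
        kept = {k: v for k, v in kwargs.items() if k not in new_names}
        return patch_fn(*args, **kept)

    return wrapper


def latest_forward(input_ids, attention_mask=None, return_dict=False):
    """Stand-in for the current forward signature."""


def stale_patch(input_ids, attention_mask=None):
    """Stand-in for a patch written against an older forward."""
    return ("patched", input_ids, attention_mask)


forward = adapt_stale_patch_sketch(latest_forward, stale_patch)
print(forward([1, 2], return_dict=True))  # ('patched', [1, 2], None)
```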
- class InitTrackerMeta(name, bases, attrs)
Bases: type
This metaclass wraps the __init__ method of a class to add an init_config attribute to instances of that class. init_config is a dict that tracks the initial configuration. If the class has a _pre_init or _post_init method, it is hooked before or after __init__, respectively, and called as _pre_init(self, init_fn, init_args) or _post_init(self, init_fn, init_args). InitTrackerMeta is used as the metaclass for pretrained model classes, which are always Layer subclasses, and type(Layer) is not type. It therefore uses type(Layer) rather than type as its base class to avoid metaclass conflicts.
- static init_and_track_conf(init_func, pre_init_func=None, post_init_func=None)
Wraps init_func, the __init__ method of a class, to add an init_config attribute to instances of that class.
Warning: self is always an instance of the downstream model class, e.g. BertForTokenClassification.
- Parameters:
init_func (callable) – the __init__ method of a class.
pre_init_func (callable, optional) – If provided, it is hooked before init_func and called as pre_init_func(self, init_func, *args, **kwargs), where args and kwargs are the arguments passed to __init__. Default None.
post_init_func (callable, optional) – If provided, it is hooked after init_func and called as post_init_func(self, init_func, *args, **kwargs). Default None.
- Returns:
the wrapped function
- Return type:
function
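To make the tracking concrete, here is a minimal sketch of a metaclass in this style. InitTrackerMetaSketch is a hypothetical, framework-free stand-in that uses plain type as its base instead of type(Layer). The sketch also records positional arguments under an init_args key and the class name under init_class. That layout is an assumption of the sketch and is not documented above.

```python
import functools
from typing import Any, Callable, Dict, Optional, Tuple


class InitTrackerMetaSketch(type):
    """Hypothetical metaclass: wraps __init__ so instances record how they were built."""

    def __init__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> None:
        # Only wrap an __init__ defined in this class body; an __init__ inherited
        # from a class built by this metaclass is already wrapped.
        if "__init__" in attrs:
            cls.__init__ = InitTrackerMetaSketch.init_and_track_conf(
                attrs["__init__"],
                getattr(cls, "_pre_init", None),
                getattr(cls, "_post_init", None),
            )
        super().__init__(name, bases, attrs)

    @staticmethod
    def init_and_track_conf(
        init_func: Callable,
        pre_init_func: Optional[Callable] = None,
        post_init_func: Optional[Callable] = None,
    ) -> Callable:
        @functools.wraps(init_func)
        def __impl__(self: Any, *args: Any, **kwargs: Any) -> None:
            if pre_init_func is not None:
                pre_init_func(self, init_func, *args, **kwargs)
            init_func(self, *args, **kwargs)
            if post_init_func is not None:
                post_init_func(self, init_func, *args, **kwargs)
            # Record the configuration passed in (assumed layout, see lead-in).
            self.init_config = dict(kwargs)
            if args:
                self.init_config["init_args"] = args
            self.init_config["init_class"] = self.__class__.__name__

        return __impl__


class TinyModel(metaclass=InitTrackerMetaSketch):
    def __init__(self, hidden_size: int = 8, num_layers: int = 2) -> None:
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def _post_init(self, init_func: Callable, *args: Any, **kwargs: Any) -> None:
        self.post_init_ran = True


model = TinyModel(hidden_size=16)
print(model.init_config)  # {'hidden_size': 16, 'init_class': 'TinyModel'}
```

Only the arguments actually passed are recorded. The default num_layers=2 does not appear in init_config in this sketch.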
- param_in_func(func, param_field: str) -> bool
Check whether param_field is a parameter of the func method, e.g. whether the bert parameter appears in an __init__ method.
- Parameters:
func (callable) – the function or method to inspect
param_field (str) – the name of the parameter
- Returns:
whether param_field is a parameter of func
- Return type:
bool
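A sketch of this check using inspect.getfullargspec follows. The names param_in_func_sketch and BertLikeModel are hypothetical. FullArgSpec.args lists positional parameter names only. Keyword-only parameters are reported separately in kwonlyargs, so this version would not find them.

```python
import inspect
from typing import Callable


def param_in_func_sketch(func: Callable, param_field: str) -> bool:
    """Hypothetical helper: is `param_field` a positional parameter of `func`?"""
    # FullArgSpec.args holds positional parameter names (including `self` when
    # `func` is accessed on the class); keyword-only names live in .kwonlyargs.
    return param_field in inspect.getfullargspec(func).args


class BertLikeModel:
    def __init__(self, bert=None, num_classes=2):
        self.bert = bert
        self.num_classes = num_classes


print(param_in_func_sketch(BertLikeModel.__init__, "bert"))     # True
print(param_in_func_sketch(BertLikeModel.__init__, "dropout"))  # False
```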
- resolve_cache_dir(pretrained_model_name_or_path: str, from_hf_hub: bool, cache_dir: str | None = None) -> str
Resolve the cache directory for PretrainedModel and PretrainedConfig.
- Parameters:
pretrained_model_name_or_path (str) – the name or path of the pretrained model
from_hf_hub (bool) – whether to load from the Hugging Face Hub
cache_dir (str, optional) – the cache directory for models
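The resolution policy is not spelled out above, so the following is only a hypothetical sketch. It assumes that an explicit cache_dir takes precedence, with the model name appended as a subdirectory, and that otherwise a default root chosen by from_hf_hub is used. Both root paths are placeholders, not the library's real locations.

```python
import os
from typing import Optional

# Placeholder roots for illustration only; the real library uses its own paths.
HYPOTHETICAL_HF_CACHE_ROOT = os.path.join("/tmp", "example_hf_cache")
HYPOTHETICAL_MODEL_HOME = os.path.join("/tmp", "example_model_home")


def resolve_cache_dir_sketch(
    pretrained_model_name_or_path: str,
    from_hf_hub: bool,
    cache_dir: Optional[str] = None,
) -> str:
    """Hypothetical resolution policy: explicit cache_dir first, then a per-source default."""
    if cache_dir is not None:
        return os.path.join(cache_dir, pretrained_model_name_or_path)
    root = HYPOTHETICAL_HF_CACHE_ROOT if from_hf_hub else HYPOTHETICAL_MODEL_HOME
    return os.path.join(root, pretrained_model_name_or_path)


print(resolve_cache_dir_sketch("bert-base-uncased", from_hf_hub=False))
print(resolve_cache_dir_sketch("bert-base-uncased", from_hf_hub=True, cache_dir="/data/cache"))
```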
- find_transformer_model_type(model_class: Type) -> str
Get the model type from the module name, e.g. BertModel -> bert, RobertaForTokenClassification -> roberta.
- Parameters:
model_class (Type) – the class of model
- Returns:
the type string
- Return type:
str
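One way to derive the type from the module name is to take the package segment that contains the model class. The sketch below assumes a module layout such as example_pkg.transformers.bert.modeling. The function find_model_type_sketch is hypothetical, and the __module__ assignments only simulate where the example classes would be defined.

```python
from typing import Type


def find_model_type_sketch(model_class: Type) -> str:
    """Hypothetical helper: read the model type off the defining module's path."""
    # Assumes a layout like "<pkg>.transformers.<model_type>.modeling".
    parts = model_class.__module__.split(".")
    return parts[-2] if len(parts) >= 2 else parts[-1]


class BertModel:
    pass


class RobertaForTokenClassification:
    pass


# Simulate where these classes would be defined in a real package.
BertModel.__module__ = "example_pkg.transformers.bert.modeling"
RobertaForTokenClassification.__module__ = "example_pkg.transformers.roberta.modeling"

print(find_model_type_sketch(BertModel))                      # bert
print(find_model_type_sketch(RobertaForTokenClassification))  # roberta
```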
- find_transformer_model_class_by_name(model_name: str) -> Type[PretrainedModel] | None
Find a transformer model class by its class name.
- Parameters:
model_name (str) – the class name
- Returns:
the matching PretrainedModel class, or None if no class with that name is found
- Return type:
Optional[Type[PretrainedModel]]
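The following hypothetical sketch shows one possible name lookup. It walks all subclasses of a stand-in PretrainedModel base class and compares class names. How the real function locates candidate classes is not described here, so this is only one assumed strategy.

```python
from typing import List, Optional, Type


class PretrainedModel:
    """Hypothetical stand-in for the real PretrainedModel base class."""


class BertModel(PretrainedModel):
    pass


class BertForTokenClassification(BertModel):
    pass


def find_model_class_by_name_sketch(model_name: str) -> Optional[Type[PretrainedModel]]:
    """Hypothetical helper: search every (transitive) subclass by class name."""
    pending: List[type] = list(PretrainedModel.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate.__name__ == model_name:
            return candidate
        # Descend into the subclasses of this candidate as well.
        pending.extend(candidate.__subclasses__())
    return None


print(find_model_class_by_name_sketch("BertForTokenClassification"))
print(find_model_class_by_name_sketch("GPTModel"))  # None
```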
- class ContextManagers(context_managers: List[AbstractContextManager])
Bases: object
Wrapper for contextlib.ExitStack that enters a collection of context managers. Adapted from ContextManagers in the fastcore library.
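Here is a minimal sketch of such a wrapper built on contextlib.ExitStack. ContextManagersSketch and tracked are illustrative names, not the library's code. Entering registers each manager on the stack in list order. Exiting delegates to the stack, which unwinds the managers in reverse order.

```python
from contextlib import AbstractContextManager, ExitStack, contextmanager
from typing import Any, Iterator, List


class ContextManagersSketch:
    """Hypothetical wrapper: enter several context managers as one."""

    def __init__(self, context_managers: List[AbstractContextManager]) -> None:
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self) -> None:
        # Enter each manager in list order, registering its exit on the stack.
        for cm in self.context_managers:
            self.stack.enter_context(cm)

    def __exit__(self, *exc_info: Any) -> bool:
        # ExitStack unwinds in reverse order of entry.
        return self.stack.__exit__(*exc_info)


events: List[str] = []


@contextmanager
def tracked(name: str) -> Iterator[None]:
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


with ContextManagersSketch([tracked("a"), tracked("b")]):
    events.append("body")

print(events)  # ['enter a', 'enter b', 'body', 'exit b', 'exit a']
```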