utils#
- convert_ndarray_dtype(np_array: ndarray, target_dtype: str) ndarray [source]#
Convert a numpy ndarray to the target dtype.
- Parameters:
np_array (np.ndarray) -- numpy ndarray instance
target_dtype (str) -- the target dtype
- Returns:
converted numpy ndarray instance
- Return type:
np.ndarray
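A minimal sketch of what such a conversion can look like, assuming plain numpy `astype` semantics (the function and its name here are hypothetical stand-ins, not the library's implementation):

```python
import numpy as np

def convert_ndarray_dtype_sketch(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
    # Hypothetical stand-in: cast the array to the dtype named by the string.
    return np_array.astype(np.dtype(target_dtype))

arr = np.array([1.5, 2.5], dtype="float32")
out = convert_ndarray_dtype_sketch(arr, "int64")  # float -> int truncates toward zero
```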
- get_scale_by_dtype(dtype: str | None = None, return_positive: bool = True) float [source]#
Get the scale value for the given dtype.
- Parameters:
dtype (str) -- the string dtype value
return_positive (bool) -- whether to return the scale value as a positive number
- Returns:
the scale value
- Return type:
float
- fn_args_to_dict(func, *args, **kwargs)[source]#
Inspect func and the arguments it is being called with, and build a dict mapping argument names to their values.
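A minimal sketch of this kind of argument-to-name mapping, using the standard-library `inspect` module (a hypothetical stand-in, not the library's implementation):

```python
import inspect

def fn_args_to_dict_sketch(func, *args, **kwargs):
    # Hypothetical stand-in: bind a call's positional and keyword
    # arguments onto the parameter names of `func`.
    sig = inspect.signature(func)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()  # include parameters left at their defaults
    return dict(bound.arguments)

# Hypothetical example function.
def make_model(hidden_size, num_layers=12, dropout=0.1):
    pass

cfg = fn_args_to_dict_sketch(make_model, 768, num_layers=24)
```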
- adapt_stale_fwd_patch(self, name, value)[source]#
Since there are monkey patches to the forward method of PretrainedModel, such as model compression, this makes those patches compatible with the latest forward method.
- class InitTrackerMeta(name, bases, attrs)[source]#
Bases:
type
This metaclass wraps the __init__ method of a class to add an init_config attribute for instances of that class, and init_config uses a dict to track the initial configuration. If the class has a _pre_init or _post_init method, it is hooked before or after __init__ and called as _pre_init(self, init_fn, init_args) or _post_init(self, init_fn, init_args). Since InitTrackerMeta is used as the metaclass for pretrained model classes, which are always Layer, and type(Layer) is not type, type(Layer) rather than type is used as its base class to avoid metaclass inheritance conflicts.
- static init_and_track_conf(init_func, pre_init_func=None, post_init_func=None)[source]#
Wraps init_func, which is the __init__ method of a class, to add an init_config attribute for instances of that class.
- Parameters:
init_func (callable) -- It should be the __init__ method of a class. Warning: self is always the class type of the downstream model, e.g. BertForTokenClassification.
pre_init_func (callable, optional) -- If provided, it is hooked before init_func and called as pre_init_func(self, init_func, *init_args, **init_kwargs). Default None.
post_init_func (callable, optional) -- If provided, it is hooked after init_func and called as post_init_func(self, init_func, *init_args, **init_kwargs). Default None.
- Returns:
the wrapped function
- Return type:
function
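A minimal sketch of such an __init__ wrapper with pre/post hooks and an init_config record (names and dict layout here are hypothetical, not the library's exact implementation):

```python
import functools

def init_and_track_conf_sketch(init_func, pre_init_func=None, post_init_func=None):
    # Hypothetical stand-in: wrap __init__ so each instance records the
    # arguments it was constructed with in an `init_config` attribute.
    @functools.wraps(init_func)
    def __impl__(self, *args, **kwargs):
        if pre_init_func is not None:
            pre_init_func(self, init_func, *args, **kwargs)
        init_func(self, *args, **kwargs)
        if post_init_func is not None:
            post_init_func(self, init_func, *args, **kwargs)
        self.init_config = {
            "init_class": type(self).__name__,
            "args": args,
            "kwargs": kwargs,
        }
    return __impl__

class Model:
    def __init__(self, hidden_size, dropout=0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout
    # Apply the wrapper manually; the metaclass would do this automatically.
    __init__ = init_and_track_conf_sketch(__init__)

m = Model(768, dropout=0.2)
```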
- param_in_func(func, param_field: str) bool [source]#
Check whether param_field is a parameter of func, e.g. whether the bert param is in the __init__ method.
- Parameters:
func (callable) -- the function or method to inspect
param_field (str) -- the name of the field
- Returns:
whether the parameter exists
- Return type:
bool
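A minimal sketch of such a check with the standard-library `inspect` module (a hypothetical stand-in; the example class is made up):

```python
import inspect

def param_in_func_sketch(func, param_field: str) -> bool:
    # Hypothetical stand-in: check whether `param_field` names a
    # parameter in `func`'s signature.
    return param_field in inspect.signature(func).parameters

# Hypothetical downstream model class.
class BertForTokenClassification:
    def __init__(self, bert, num_classes=2):
        pass

has_bert = param_in_func_sketch(BertForTokenClassification.__init__, "bert")
has_gpt = param_in_func_sketch(BertForTokenClassification.__init__, "gpt")
```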
- resolve_cache_dir(from_hf_hub: bool, from_aistudio: bool, cache_dir: str | None = None) str [source]#
Resolve the cache dir for PretrainedModel and PretrainedConfig.
- Parameters:
from_hf_hub (bool) -- whether to load from the Hugging Face Hub
from_aistudio (bool) -- whether to load from AI Studio
cache_dir (str) -- cache_dir for models
- find_transformer_model_type(model_class: Type) str [source]#
Get the model type from the module name, e.g. BertModel -> bert, RobertaForTokenClassification -> roberta.
- Parameters:
model_class (Type) -- the class of the model
- Returns:
the type string
- Return type:
str
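One heuristic way to derive such a type string is to strip common head suffixes from the class name; this sketch is an assumption about the mapping shown in the examples above, not the library's actual lookup:

```python
import re

def find_transformer_model_type_sketch(model_class: type) -> str:
    # Hypothetical stand-in: drop task-head suffixes such as
    # "Model" or "ForTokenClassification" and lower-case the prefix.
    name = model_class.__name__
    name = re.sub(r"(Model|For\w+|Pretrained\w*)$", "", name)
    return name.lower()

# Hypothetical model classes for illustration.
class BertModel: ...
class RobertaForTokenClassification: ...

bert_type = find_transformer_model_type_sketch(BertModel)
roberta_type = find_transformer_model_type_sketch(RobertaForTokenClassification)
```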
- find_transformer_model_class_by_name(model_name: str) Type[PretrainedModel] | None [source]#
Find a transformer model class by name.
- Parameters:
model_name (str) -- the string of the class name
- Returns:
the pretrained-model class, if found
- Return type:
Optional[Type[PretrainedModel]]
- convert_file_size_to_int(size: int | str)[source]#
Converts a size expressed as a string with digits and a unit (like "5MB") to an integer number of bytes.
- Parameters:
size (int or str) -- The size to convert. Will be returned directly if already an int.
Example:

```python
>>> convert_file_size_to_int("1MiB")
1048576
```
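A minimal sketch of this size parsing, assuming decimal units (KB/MB/GB) and binary units (KiB/MiB/GiB); the exact set of units the library accepts is not shown here:

```python
def convert_file_size_to_int_sketch(size):
    # Hypothetical stand-in: parse sizes like "5MB" or "1MiB" into bytes.
    if isinstance(size, int):
        return size
    units = {
        "KIB": 2**10, "MIB": 2**20, "GIB": 2**30,  # binary units
        "KB": 10**3, "MB": 10**6, "GB": 10**9,     # decimal units
    }
    upper = size.upper()
    for suffix, factor in units.items():
        if upper.endswith(suffix):
            return int(size[: -len(suffix)]) * factor
    raise ValueError(f"size {size!r} is not in a valid format")

n_binary = convert_file_size_to_int_sketch("1MiB")
n_decimal = convert_file_size_to_int_sketch("5MB")
```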
- cached_file(path_or_repo_id: str | PathLike, filename: str, cache_dir: str | PathLike | None = None, subfolder: str = '', from_aistudio: bool = False, _raise_exceptions_for_missing_entries: bool = True, _raise_exceptions_for_connection_errors: bool = True, pretrained_model_name_or_path=None) str [source]#
Tries to locate a file in a local folder or repo, and downloads and caches it if necessary.
- Parameters:
path_or_repo_id (str or os.PathLike) -- This can be either: a string, the model id of a model repo on huggingface.co; or a path to a directory potentially containing the file.
filename (str) -- The name of the file to locate in path_or_repo.
cache_dir (str or os.PathLike, optional) -- Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
subfolder (str, optional, defaults to "") -- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
- Returns:
Returns the resolved file path (to the cache folder if downloaded from a repo).
- Return type:
Optional[str]
Examples:

```python
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```
- get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, cache_dir=None, subfolder='', from_aistudio=False, from_hf_hub=False)[source]#
For a given model, downloads and caches all the shards of a sharded checkpoint if pretrained_model_name_or_path is a model ID on the Hub, and returns the list of paths to all the shards, as well as some metadata.
For the description of each argument, see [PretrainedModel.from_pretrained]. index_filename is the full path to the index (downloaded and cached if pretrained_model_name_or_path is a model ID on the Hub).
- class ContextManagers(context_managers: List[AbstractContextManager])[source]#
Bases:
object
Wrapper around contextlib.ExitStack which enters a collection of context managers. An adaptation of ContextManagers in the fastcore library.
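A minimal sketch of such a wrapper over contextlib.ExitStack (a hypothetical stand-in with the same idea, not the library's exact class):

```python
from contextlib import ExitStack, contextmanager

class ContextManagersSketch:
    # Hypothetical stand-in: enter a list of context managers as one.
    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for cm in self.context_managers:
            self.stack.enter_context(cm)
        return self

    def __exit__(self, *exc):
        return self.stack.__exit__(*exc)

# Trace enter/exit order with a small tracked context manager.
events = []

@contextmanager
def tracked(name):
    events.append(f"enter {name}")
    yield
    events.append(f"exit {name}")

with ContextManagersSketch([tracked("a"), tracked("b")]):
    events.append("body")
```

Note that ExitStack unwinds in reverse order, so the last-entered manager exits first.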
- dtype_byte_size(dtype)[source]#
Returns the size (in bytes) occupied by one parameter of type dtype.
Example:

```python
>>> dtype_byte_size(paddle.float32)
4
```
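One common way to compute this is to read the bit width out of the dtype name and divide by 8; this sketch works on dtype name strings rather than paddle dtype objects, so it is an assumption about the approach, not the library's implementation:

```python
import re

def dtype_byte_size_sketch(dtype_name: str) -> float:
    # Hypothetical stand-in: "float32" -> 4, "int8" -> 1, etc.
    if dtype_name == "bool":
        return 1 / 8  # bools are commonly counted as one bit
    match = re.search(r"[^\d](\d+)$", dtype_name)
    if match is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype_name}")
    return int(match.group(1)) // 8

size = dtype_byte_size_sketch("float32")
```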
- class CaptureStd(out=True, err=True, replay=True)[source]#
Bases:
object
Context manager to capture:
stdout: replay it, clean it up and make it available via obj.out
stderr: replay it and make it available via obj.err
- Parameters:
out (bool, optional, defaults to True) -- Whether to capture stdout or not.
err (bool, optional, defaults to True) -- Whether to capture stderr or not.
replay (bool, optional, defaults to True) -- Whether to replay or not. By default each captured stream gets replayed back on the context's exit, so that one can see what the test was doing. If that behavior is not wanted and the captured data shouldn't be replayed, pass replay=False to disable this feature.
Examples:

```python
# to capture stdout only with auto-replay
with CaptureStdout() as cs:
    print("Secret message")
assert "message" in cs.out

# to capture stderr only with auto-replay
import sys

with CaptureStderr() as cs:
    print("Warning: ", file=sys.stderr)
assert "Warning" in cs.err

# to capture both streams with auto-replay
with CaptureStd() as cs:
    print("Secret message")
    print("Warning: ", file=sys.stderr)
assert "message" in cs.out
assert "Warning" in cs.err

# to capture just one of the streams, and not the other, with auto-replay
with CaptureStd(err=False) as cs:
    print("Secret message")
assert "message" in cs.out
# but best use the stream-specific subclasses

# to capture without auto-replay
with CaptureStd(replay=False) as cs:
    print("Secret message")
```
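A minimal sketch of the capture-and-replay idea for stdout only, built on the standard-library contextlib.redirect_stdout (a hypothetical stand-in, not the library's CaptureStd implementation):

```python
import io
import sys
from contextlib import redirect_stdout

class CaptureStdoutSketch:
    # Hypothetical stand-in: capture stdout into `self.out` and
    # optionally replay it when the context exits.
    def __init__(self, replay=True):
        self.replay = replay
        self.out = ""
        self._buf = io.StringIO()
        self._redirect = redirect_stdout(self._buf)

    def __enter__(self):
        self._redirect.__enter__()
        return self

    def __exit__(self, *exc):
        self._redirect.__exit__(*exc)
        self.out = self._buf.getvalue()
        if self.replay:
            sys.stdout.write(self.out)  # replay to the real stdout

with CaptureStdoutSketch(replay=False) as cs:
    print("Secret message")
```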