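# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.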

# This file is modified from
#  https://github.com/huggingface/transformers/blob/main/src/transformers

import collections
import copy
import os
from typing import Any, Optional

import numpy as np

# Public API exposed to ``from <this module> import *``.
__all__ = [
    "distributed_concat", "nested_concat", "nested_detach",
    "nested_numpify", "nested_truncate",
]

def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    """Gather ``tensor`` from every rank with ``dist.all_gather`` and concatenate on axis 0.

    Tuples and lists are handled recursively; each level keeps its container type.

    Args:
        tensor: A tensor, or a (possibly nested) tuple/list of tensors.
        num_total_examples: When given, the concatenated result is truncated to
            this many rows. This drops the dummy elements added by a
            SequentialDistributedSampler.

    Returns:
        The concatenated tensor, or a container of the same type holding the
        concatenated tensors.

    Raises:
        AssertionError: ``dist.all_gather`` raised an AssertionError. It is re-raised
            with a message saying distributed training is not active.
    """
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)

    output_tensors = []
    try:
        dist.all_gather(output_tensors, tensor)
    except AssertionError as err:
        raise AssertionError("Not currently using distributed training") from err

    # Promote 0-d tensors to shape [1] so every piece can be concatenated along axis 0.
    output_tensors = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in output_tensors]
    concat = paddle.concat(output_tensors, axis=0)

    # truncate the dummy elements added by SequentialDistributedSampler
    if num_total_examples is not None:
        concat = concat[:num_total_examples]
    return concat

def paddle_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
    """Concatenate ``tensor1`` and ``tensor2`` on the first axis, padding the second axis if needed.

    If the tensors are 1-D or already agree on their second dimension, they are
    concatenated directly. Otherwise a buffer of shape
    ``(n1 + n2, max(m1, m2), *tensor1.shape[2:])`` is filled with ``padding_index``
    and each input is copied into its own rows.

    Args:
        tensor1: First paddle tensor.
        tensor2: Second paddle tensor; dims beyond the second must match ``tensor1``.
        padding_index: Fill value for positions that neither input covers.

    Returns:
        The concatenated (and possibly padded) paddle tensor, in ``tensor1``'s dtype.
    """
    if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
        return paddle.concat((tensor1, tensor2), axis=0)

    # Let's figure out the new shape
    new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tuple(
        tensor1.shape[2:]
    )

    # Pre-fill with the padding value, then copy each input into its rows.
    result = paddle.full(new_shape, padding_index, dtype=tensor1.dtype)
    result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
    result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
    return result


def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
    """Concatenate ``array1`` and ``array2`` on the first axis, padding the second axis if needed.

    If the arrays are 1-D or already agree on their second dimension, this is a
    plain ``np.concatenate``. Otherwise the result has shape
    ``(n1 + n2, max(m1, m2), *array1.shape[2:])`` and ``array1``'s dtype. Positions
    not covered by either input hold ``padding_index``.
    """
    if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
        return np.concatenate((array1, array2), axis=0)

    # Let's figure out the new shape
    new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

    # Pre-fill with the padding value (dtype follows array1), then copy each input in.
    result = np.full_like(array1, padding_index, shape=new_shape)
    result[: array1.shape[0], : array1.shape[1]] = array1
    result[array1.shape[0] :, : array2.shape[1]] = array2
    return result


def nested_concat(tensors, new_tensors, padding_index=-100):
    """
    Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
    nested list/tuples of tensors.

    Args:
        tensors: A numpy array, a paddle tensor, or a (possibly nested) list/tuple of them.
        new_tensors: Same structure and container types as ``tensors``.
        padding_index: Fill value used when the second dimensions differ.

    Raises:
        AssertionError: ``tensors`` and ``new_tensors`` have different types.
        TypeError: A leaf is neither a numpy array nor a paddle tensor.
    """
    assert type(tensors) is type(
        new_tensors
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(
            nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors)
        )
    elif isinstance(tensors, np.ndarray):
        return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    elif isinstance(tensors, paddle.Tensor):
        return paddle_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    else:
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

def nested_detach(tensors):
    """Detach `tensors` (even if it's a nested list/tuple of tensors).

    Every leaf has ``.detach()`` called on it. Each list/tuple level is rebuilt
    with its original container type.
    """
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    return tensors.detach()

def nested_numpify(tensors):
    """Numpify `tensors` (even if it's a nested list/tuple of tensors).

    Every leaf is converted with ``.cpu().numpy()``. Each list/tuple level is
    rebuilt with its original container type.
    """
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_numpify(t) for t in tensors)
    # A single .cpu() is enough; the original moved the tensor to CPU twice.
    return tensors.cpu().numpy()

def nested_truncate(tensors, limit):
    """Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors).

    Every leaf is sliced with ``[:limit]``. Each list/tuple level is rebuilt with
    its original container type, so lists and tuples are recursed into rather
    than sliced themselves.
    """
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_truncate(t, limit) for t in tensors)
    return tensors[:limit]

def distributed_isfile(filename):
    """Check all machine nodes. return False if no machine have such file.

    With a single trainer, this is a plain ``os.path.isfile`` check. Otherwise the
    local-rank-0 process of each node contributes 1 if it can see ``filename``.
    The counts are summed across ranks with ``all_reduce``, and the result is True
    when at least one node has the file. With multiple trainers this is a
    collective, so every rank must call it.

    Returns:
        bool: whether the file exists on at least one node.
    """
    # NOTE(review): env var names follow paddle.distributed.launch conventions — confirm.
    trainers_num = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))
    if trainers_num <= 1:
        return os.path.isfile(filename)

    local_rank = max(int(os.environ.get("PADDLE_RANK_IN_NODE", 0)), 0)
    file_count = paddle.zeros([1], dtype="int64")
    if local_rank == 0 and os.path.isfile(filename):
        file_count += 1

    # Sum the per-node flags so every rank sees the global count.
    paddle.distributed.all_reduce(file_count)
    return bool(file_count.item() >= 1)

def distributed_file(filename):
    """Make ``filename`` present on every node, copying it from a rank that already has it.

    With a single trainer, ``filename`` is returned unchanged. Otherwise:

    1. Each local-rank-0 process that can see the file reports its global rank.
    2. The lowest such rank reads the file and broadcasts its bytes.
    3. Every local rank 0 lacking the file writes it out, creating parent
       directories as needed.

    All ranks then synchronise on a barrier. With multiple trainers this is a
    collective, so every rank must call it.

    Returns:
        ``filename`` (unchanged).

    Raises:
        FileNotFoundError: No node has the file (multi-trainer case only).
    """
    # NOTE(review): env var names follow paddle.distributed.launch conventions — confirm.
    trainers_num = int(os.environ.get("PADDLE_TRAINERS_NUM", 1))
    if trainers_num <= 1:
        return filename

    local_rank = max(int(os.environ.get("PADDLE_RANK_IN_NODE", 0)), 0)
    global_rank = paddle.distributed.get_rank()

    # Ranks that cannot serve the file report a sentinel, assumed to be larger than
    # any real global rank, so the minimum selects a real owner when one exists.
    missing_rank = 2**20
    found_file = paddle.to_tensor([missing_rank], dtype="int64")
    if local_rank == 0 and os.path.isfile(filename):
        found_file = paddle.to_tensor([global_rank], dtype="int64")

    tensor_list = []
    paddle.distributed.all_gather(tensor_list, found_file)
    src = int(paddle.min(paddle.concat(tensor_list)).item())
    if src == missing_rank:
        # Every rank computes the same src, so all of them raise consistently.
        raise FileNotFoundError(f"{filename} was not found on any node.")

    file_object_list = [None]
    if global_rank == src:
        with open(filename, "rb") as f:
            file_object_list = [f.read()]
    paddle.distributed.broadcast_object_list(file_object_list, src=src)
    file_object = file_object_list[0]

    if local_rank == 0 and not os.path.isfile(filename):
        dirname = os.path.dirname(filename)
        if dirname:  # os.makedirs("") raises, so skip it for bare filenames
            os.makedirs(dirname, exist_ok=True)

        with open(filename, "wb") as f:
            f.write(file_object)

    # Keep other local ranks from touching the file before local rank 0 has written it.
    paddle.distributed.barrier()
    return filename

# Metadata-only stand-in for a tensor: records its shape, dtype and name, but no data.
TensorHolder = collections.namedtuple("TensorHolder", "shape dtype name")

def nested_reduce_tensor(tensor):
    """Replace every paddle tensor in a nested structure with a ``TensorHolder`` of its metadata.

    How each kind of value is handled:

    - Dicts are shallow-copied and their values reduced recursively, so the
      caller's dicts are never modified.
    - Tuples and lists are rebuilt with the same container type.
    - Paddle tensors become ``TensorHolder(shape, dtype, name)``.
    - Any other value is returned unchanged.
    """
    if isinstance(tensor, dict):
        # Copy so that values can be replaced without mutating the caller's dict.
        reduced = copy.copy(tensor)
        for key, value in tensor.items():
            reduced[key] = nested_reduce_tensor(value)
        return reduced
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(nested_reduce_tensor(t) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        return TensorHolder(tensor.shape, tensor.dtype, tensor.name)
    return tensor

def nested_empty_tensor(tensor):
    """Allocate an uninitialised paddle tensor for every ``TensorHolder`` in a nested structure.

    This turns the output of ``nested_reduce_tensor`` back into real tensors:

    - Dict values are replaced in place and the same dict is returned.
    - Lists and plain tuples are rebuilt with their container type.
    - Each ``TensorHolder`` becomes ``paddle.empty(shape, dtype)``, carrying the
      holder's name.
    - Any other value is returned unchanged.
    """
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_empty_tensor(tensor[key])
        return tensor

    # TensorHolder is itself a tuple, so it must be matched before the generic
    # list/tuple branch; otherwise it would be iterated field by field.
    if isinstance(tensor, TensorHolder):
        t = paddle.empty(tensor.shape, dtype=tensor.dtype, name=tensor.name)
        t.name = tensor.name
        return t

    # Tuples are included so that tuples produced by nested_reduce_tensor are expanded too.
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(nested_empty_tensor(t) for t in tensor)

    return tensor

def nested_broadcast_tensor(tensor, src=0, group=None):
    """Broadcast every paddle tensor in a nested structure from rank ``src`` within ``group``.

    How each kind of value is handled:

    - Dict values are processed and replaced in place.
    - Lists and tuples are rebuilt with their container type.
    - Paddle tensors are overwritten in place by a synchronous
      ``paddle.distributed.broadcast``.
    - Any other value is returned unchanged.

    This is a collective, so every rank in ``group`` must call it with the same
    structure.
    """
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_broadcast_tensor(tensor[key], src=src, group=group)
        return tensor
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        paddle.distributed.broadcast(tensor, src=src, group=group, sync_op=True)
    return tensor

return state_dict

logger.info("Start broadcast optimizer in data parallel group.")
try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
# Don't broadcast optimizer for dp rank is 1.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0

if process_rank == src_rank:
if state_dict is None:
logger.warning(
)
fake_state_dict = [nested_reduce_tensor(state_dict)]
else:
if state_dict is not None:
logger.warning(
f"Your local rank {paddle.distributed.get_rank()}  are forbidden to have a state_dict. dp_rank:{process_rank}, src_rank:{src_rank}"
)
fake_state_dict = [None]