# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is modified from
#  https://github.com/huggingface/transformers/blob/main/src/transformers

from typing import Any, Optional

import numpy as np
import torch
import torch.distributed as dist

# Names exported by `from <module> import *`.
__all__ = [
    "distributed_concat",
    "nested_concat",
    "nested_detach",
    "nested_numpify",
    "nested_truncate",
]

def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    """Gather `tensor` from every distributed process and concatenate along dim 0.

    Works recursively on nested tuples/lists of tensors, preserving the
    container type.

    Args:
        tensor: a tensor, or a (possibly nested) tuple/list of tensors.
        num_total_examples: if given, the gathered result is truncated to this
            length, dropping the dummy samples a distributed sampler may have
            added to equalize shard sizes.

    Raises:
        AssertionError: if torch.distributed is not initialized.
    """
    try:
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
        # One receive buffer per rank; 0-dim tensors are promoted to 1-D so
        # that torch.cat below has an axis to concatenate on.
        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
        dist.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")

"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
if len(tensor1.shape) == 1 or tensor1.shape == tensor2.shape:

# raise ValueError("Error")
# Let's figure out the new shape
new_shape = (tensor1.shape + tensor2.shape, max(tensor1.shape, tensor2.shape)) + tuple(
tensor1.shape[2:]
)

# Now let's fill the result tensor

result[: tensor1.shape, : tensor1.shape] = tensor1
result[tensor1.shape :, : tensor2.shape] = tensor2
return result

"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
if len(array1.shape) == 1 or array1.shape == array2.shape:
return np.concatenate((array1, array2), axis=0)

# Let's figure out the new shape
new_shape = (array1.shape + array2.shape, max(array1.shape, array2.shape)) + array1.shape[2:]

# Now let's fill the result tensor
result[: array1.shape, : array1.shape] = array1
result[array1.shape :, : array2.shape] = array2
return result

"""
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
nested list/tuples of tensors.
"""
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
elif isinstance(tensors, np.ndarray):
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

def nested_detach(tensors):
    """Detach `tensors` (even if it's a nested list/tuple of tensors).

    Returns a structure of the same shape whose leaf tensors no longer track
    gradients; container types (list vs tuple) are preserved.
    """
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    return tensors.detach()

[docs]def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
t = tensors.cpu()