```# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

__all__ = ["einsum"]

[docs]def einsum(equation, *operands):
r"""
Computes the sum of products of the provided operands based on the Einstein summation convention.
Einsum can be used to express a variety of operations, such as sum, transpose, and
batch matrix multiplication.

Args:
equation (`str`):
Uses case-insensitive letters to label the dimensions of the operands and the result. The input
equation is on the left-hand side of `->` while the output equation is on the right-hand side.
Einsum can infer the result shape, so the `->` and the result label letters can be omitted.
Operands in the input equation are separated by commas (','), e.g. 'abc,cde' describes two 3D
operands. Dimensions labeled with the same letter should have the same size or have size 1.
Ellipsis ('...') can be used to specify the broadcast dimensions.

operands (`Tensor`):
The operands to compute the Einstein sum of. The number of operands should be the same as the
number of operands described in the input equation.

Returns:
`Tensor`: The result of Einstein sum product.

Example:
.. code-block::

import numpy as np

np.random.seed(102)

# sum
# Tensor(shape=[], dtype=float64, place=CUDAPlace(0), stop_gradient=True, 2.30369050)

# dot
# Tensor(shape=[], dtype=float64, place=CUDAPlace(0), stop_gradient=True, 1.43773247)

# outer
# Tensor(shape=[4, 5], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#         [[0.34590188, 0.48353496, 0.09996135, 0.18656330, 0.21392910],
#         [0.39122025, 0.54688535, 0.11305780, 0.21100591, 0.24195704],
#         [0.17320613, 0.24212422, 0.05005442, 0.09341929, 0.10712238],
#         [0.42290818, 0.59118179, 0.12221522, 0.22809690, 0.26155500]])

# transpose
#  Tensor(shape=[2, 3, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#        [[[0.49174730, 0.33344683],
#          [0.89440989, 0.26162022],
#          [0.36116209, 0.12241719]],

#         [[0.49019824, 0.51895050],
#          [0.18241053, 0.13092809],
#          [0.81059146, 0.55165734]]])

# batch matrix multiplication
# Tensor(shape=[2, 3, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#     [[[0.13654339, 0.39331432, 0.65059661],
#      [0.07171420, 0.57518653, 0.77629221],
#      [0.21250688, 0.37793541, 0.73643411]],

#     [[0.56925339, 0.65859030, 0.57509818],
#      [0.30368265, 0.25778348, 0.21630400],
#      [0.39587265, 0.58031243, 0.51824755]]])

# Ellipsis transpose
# Tensor(shape=[2, 2, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#     [[[0.49174730, 0.89440989, 0.36116209],
#         [0.49019824, 0.18241053, 0.81059146]],

#         [[0.33344683, 0.26162022, 0.12241719],
#         [0.51895050, 0.13092809, 0.55165734]]])

# Ellipsis batch matrix multiplication
# Tensor(shape=[2, 3, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[[0.13654339, 0.39331432, 0.65059661],
#     [0.07171420, 0.57518653, 0.77629221],
#     [0.21250688, 0.37793541, 0.73643411]],

#     [[0.56925339, 0.65859030, 0.57509818],
#     [0.30368265, 0.25778348, 0.21630400],
#     [0.39587265, 0.58031243, 0.51824755]]])
"""

def _mul_sum(left, right, sum_dims):
# Helper that combines `left` and `right` while reducing over the dimension
# indices listed in `sum_dims`.
#
# Args:
#   left, right: operands that expose `.rank()`, `.shape`, `*` and
#     `.sum(axis=..., keepdim=...)`; both must report the same rank.
#   sum_dims: indices of the dimensions to be summed over.
#
# Returns:
#   The element-wise product `left * right` when `sum_dims` is empty,
#   otherwise `result` (see the NOTE(review) before the final return).
#
# NOTE(review): indentation has been stripped from this copy of the code, so
# nesting must be read from the if/elif/else keywords rather than from columns.
assert left.rank() == right.rank(), "number of rank should be equal."
# Nothing to reduce: a plain (broadcasting) product is enough.
if len(sum_dims) == 0:
return left * right
# Set copy of sum_dims for O(1) membership tests inside the loop below.
sum_dims_set = set(sum_dims)
# Every dimension index is sorted into one of: summed, batch (present in both
# operands), left-only, or right-only. The *_size variables accumulate the
# product of the extents in each group.
batch_dims = []
left_out_dims = []
right_out_dims = []
batch_size = summed_size = left_size = right_size = 1
dim = len(left.shape)
for i in range(dim):
# Despite the names, these flags only say whether the operand has extent > 1
# along dimension i, i.e. the dimension is not a size-1 broadcast dimension.
is_left_summed_dim = left.shape[i] > 1  # not broadcast dim
is_right_summed_dim = right.shape[i] > 1
if i in sum_dims_set:
if is_left_summed_dim and is_right_summed_dim:
# Both operands carry the summed dimension: it contributes to summed_size.
assert left.shape[i] == right.shape[i], "Non-broadcast dim should be equal."
summed_size *= left.shape[i]
elif is_left_summed_dim:
# Only `left` carries it: reduce it right away, keeping a size-1 axis so
# dimension indices stay aligned.
left = left.sum(axis=i, keepdim=True)
elif is_right_summed_dim:
# Only `right` carries it: same early reduction on the other operand.
right = right.sum(axis=i, keepdim=True)
# If both extents are 1 no action is taken for this summed dimension.
elif is_left_summed_dim and is_right_summed_dim:
# Non-summed dimension present in both operands: a batch dimension.
assert left.shape[i] == right.shape[i], "Non-broadcast dim should be equal."
batch_dims.append(i)
batch_size *= left.shape[i]
elif is_left_summed_dim:
# Non-summed dimension with extent > 1 only in `left`.
left_out_dims.append(i)
left_size *= left.shape[i]
else:
# Remaining case: `left` has extent 1 here, so the dimension is attributed
# to `right` (this also covers dimensions where both extents are 1).
right_out_dims.append(i)
right_size *= right.shape[i]
# Shape laid out as [batch..., left-only..., one size-1 axis per summed dim,
# right-only...].
out_shape = [left.shape[i] for i in batch_dims + left_out_dims]
out_shape.extend([1] * len(sum_dims))
out_shape.extend([right.shape[i] for i in right_out_dims])

# Dimension order for `left`: batch, left-only, summed, right-only.
left_perm = list(batch_dims)
left_perm.extend(left_out_dims)
left_perm.extend(sum_dims)
left_perm.extend(right_out_dims)

# Dimension order for `right`: batch, summed, right-only, left-only.
right_perm = list(batch_dims)
right_perm.extend(sum_dims)
right_perm.extend(right_out_dims)
right_perm.extend(left_out_dims)

# Inverse permutation: output_perm[j] is the position that original dimension j
# occupies in the [batch, left-only, summed, right-only] layout.
output_perm = [-1] * (len(batch_dims) + len(left_out_dims) + len(sum_dims) + len(right_out_dims))
for i, j in enumerate(batch_dims + left_out_dims + sum_dims + right_out_dims):
output_perm[j] = i

# NOTE(review): `result` is never assigned in the visible body, and left_perm,
# right_perm, out_shape, output_perm and the *_size values are never used.
# Presumably the transpose / reshape / matrix-multiply steps that produce
# `result` were lost from this copy -- restore them from the upstream source
# before relying on this function.
return result

if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
operands = operands[0]
# Equation is case insensitive
num_letters = 26
letters_to_idx = [-1] * num_letters
equation = equation.lower().replace(" ", "")
# 1. Parse the equation
eqns = equation.split("->")
num_eqns_size = len(eqns)
assert num_eqns_size <= 2, "The '->' should exist at most only once"

input_eqn = eqns[0]
output_eqn = None if num_eqns_size <= 1 else eqns[1]
operand_eqns = input_eqn.split(",")
assert len(operand_eqns) == len(
operands
), "Number of operands in equation and the tensors provided should be equal."

# Parse input equation
num_total_idxes = 0
input_operand_idxes = []
letter_frequence = [0] * num_letters
idxes_last_operand = []
num_ell_idxes = -1
first_ell_idx = 0
for i, term in enumerate(operand_eqns):
ell_char_count = 0
operand_rank = int(operands[i].rank().cpu().numpy())
curr_num_ell_idxes = operand_rank - len(term) + 3
dims_in_terms = 0
curr_operand_idxes = []
for ch in term:
if ch == ".":
ell_char_count += 1
assert ell_char_count <= 3, "The '.' should only exist in one ellipsis '...' in term {}".format(term)
if ell_char_count == 3:
if num_ell_idxes == -1:
num_ell_idxes = curr_num_ell_idxes
first_ell_idx = num_total_idxes
num_total_idxes += num_ell_idxes
else:
assert (
curr_num_ell_idxes == num_ell_idxes
), "Ellipsis in all terms should represent same dimensions ({}).".format(num_ell_idxes)

for j in range(num_ell_idxes):
curr_operand_idxes.append(j + first_ell_idx)
idxes_last_operand.append(i)
dims_in_terms += num_ell_idxes
else:
assert (ell_char_count == 0) or (
ell_char_count == 3
), "'.' must only occur in ellipsis, operand {}".format(term)
assert ord("a") <= ord(ch) and ord(ch) <= ord("z"), "only accept alphabet (a-zA-Z)"
letter_num = ord(ch) - ord("a")
if letters_to_idx[letter_num] == -1:
letters_to_idx[letter_num] = num_total_idxes
num_total_idxes += 1
idxes_last_operand.append(i)
else:
idxes_last_operand[letters_to_idx[letter_num]] = i
letter_frequence[letter_num] += 1
curr_operand_idxes.append(letters_to_idx[letter_num])
dims_in_terms += 1

assert dims_in_terms == operand_rank, "Dimension dismatch for operand {}: equation {}, tensor {}".format(
i, dims_in_terms, operand_rank
)
input_operand_idxes.append(curr_operand_idxes)
# Parse output equation
idxes_to_output_dims = [-1] * num_total_idxes
num_output_dims = 0
if num_eqns_size == 2:
ell_char_count = 0
for ch in output_eqn:
if ch == ".":
ell_char_count += 1
assert ell_char_count <= 3, "The '.' should only exist in one ellipsis '...' in term {}".format(
output_eqn
)
if ell_char_count == 3:
assert num_ell_idxes > -1, "Input equation '{}' don't have ellipsis.".format(input_eqn)
for j in range(num_ell_idxes):
idxes_to_output_dims[first_ell_idx + j] = num_output_dims
num_output_dims += 1

else:
assert (ell_char_count == 0) or (
ell_char_count == 3
), "'.' must only occur in ellipsis, operand {}".format(output_eqn)
assert ord("a") <= ord(ch) and ord(ch) <= ord("z"), "only accept alphabet (a-zA-Z)"
letter_num = ord(ch) - ord("a")
assert letters_to_idx[letter_num] != -1, "character {} doesn't exist in input".format(ch)
assert (
idxes_to_output_dims[letters_to_idx[letter_num]] == -1
), "character {} occurs twice in output".format(ch)

idxes_to_output_dims[letters_to_idx[letter_num]] = num_output_dims
num_output_dims += 1
else:  # num_eqns_size == 1
# Infer the output dims
if num_ell_idxes >= 0:
for j in range(num_ell_idxes):
idxes_to_output_dims[first_ell_idx + j] = num_output_dims
num_output_dims += 1
for j in range(num_letters):
if letter_frequence[j] == 1:
idxes_to_output_dims[letters_to_idx[j]] = num_output_dims
num_output_dims += 1

# Mark sum index
sum_dim = num_output_dims
for i in range(num_total_idxes):
if idxes_to_output_dims[i] == -1:
idxes_to_output_dims[i] = sum_dim
sum_dim += 1

preprocessed_operands = []
size_dims = [-1] * num_total_idxes
for i, preprocessed_operand in enumerate(operands):
idx_to_dims = [-1] * num_total_idxes
curr_operand_idxes = input_operand_idxes[i]
dim = 0
for j, idx in enumerate(curr_operand_idxes):
output_dim = idxes_to_output_dims[idx]
if idx_to_dims[output_dim] == -1:
idx_to_dims[output_dim] = dim
if size_dims[idx] == -1:
size_dims[idx] = preprocessed_operand.shape[dim]
else:
assert (
size_dims[idx] == preprocessed_operand.shape[dim]
), "Dimension size does not match previous size. "
dim += 1
else:
# Diagonal repeated index
# TODO(zhoushunjie): Need to develop a paddle.diagonal api
raise NotImplementedError("Can't support diagonal.")
perm = []
for input_dim in idx_to_dims:
if input_dim > -1:
perm.append(input_dim)
# Transpose the tensor by perm

for dim, input_dim in enumerate(idx_to_dims):
if input_dim == -1:

preprocessed_operands.append(preprocessed_operand)

# 2. Execute the mul_sum
sum_dims = []
result = preprocessed_operands[0]
for i in range(num_total_idxes):
if idxes_last_operand[i] == 0 and idxes_to_output_dims[i] >= num_output_dims:
result = result.sum(axis=idxes_to_output_dims[i], keepdim=True)
for i in range(1, len(preprocessed_operands)):
for j in range(num_total_idxes):
if idxes_last_operand[j] == i and idxes_to_output_dims[j] >= num_output_dims:
sum_dims.append(idxes_to_output_dims[j])
result = _mul_sum(result, preprocessed_operands[i], sum_dims)

squeeze_dims = [i for i in range(len(result.shape) - 1, num_output_dims - 1, -1)]
if len(squeeze_dims) != 0: