distributed

class ParallelEmbedding(num_embeddings, embedding_dim, rank, world_size, weight_attr=None, name=None)

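Embedding layer for model parallelism. The `num_embeddings` × `embedding_dim` lookup table is partitioned across `world_size` ranks, and `rank` identifies the shard held by the current process.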

forward(x)

Looks up the embedding vectors for the indices in `x`. This method runs each time the layer instance is called.

Parameters
  • x (Tensor) -- integer tensor of indices into the embedding table.
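The sharding idea behind `ParallelEmbedding` can be illustrated with a single-process simulation. This is a minimal sketch, not Paddle's implementation: it assumes the table is split by rows (vocabulary ids) into contiguous shards of size ceil(`num_embeddings` / `world_size`), that each rank returns zero vectors for ids it does not own, and that summing the per-rank results stands in for an all-reduce. All function names are hypothetical.

```python
# Single-process simulation of a vocabulary-sharded embedding lookup.
# Assumptions (not Paddle's implementation): rows are split into contiguous
# ceil(num_embeddings / world_size)-sized shards, and the all-reduce across
# ranks is simulated by an element-wise sum. All names are hypothetical.


def make_table(num_embeddings, embedding_dim):
    # Deterministic toy weights so results are easy to check.
    return [[float(i * embedding_dim + j) for j in range(embedding_dim)]
            for i in range(num_embeddings)]


def shard_rows(table, rank, world_size):
    # Return (first global id owned by this rank, the rows it stores).
    per_part = (len(table) + world_size - 1) // world_size  # ceil division
    start = rank * per_part
    return start, table[start:start + per_part]


def local_lookup(ids, start, shard, embedding_dim):
    # Ids owned by this rank get their row; all other ids get a zero vector.
    out = []
    for i in ids:
        if start <= i < start + len(shard):
            out.append(list(shard[i - start]))
        else:
            out.append([0.0] * embedding_dim)
    return out


def all_reduce_sum(parts):
    # Element-wise sum over ranks, standing in for a collective all-reduce.
    result = [[0.0] * len(parts[0][0]) for _ in parts[0]]
    for part in parts:
        for r, row in enumerate(part):
            for c, value in enumerate(row):
                result[r][c] += value
    return result


def parallel_embedding(ids, table, world_size):
    embedding_dim = len(table[0])
    partials = []
    for rank in range(world_size):
        start, shard = shard_rows(table, rank, world_size)
        partials.append(local_lookup(ids, start, shard, embedding_dim))
    return all_reduce_sum(partials)


table = make_table(num_embeddings=10, embedding_dim=4)
ids = [0, 3, 7, 9]
assert parallel_embedding(ids, table, world_size=4) == [table[i] for i in ids]
```

Because exactly one rank owns each id in this scheme, the sum over ranks recovers the full lookup while each rank stores only about `num_embeddings / world_size` rows.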

class ColumnParallelLiner(size, num_partitions=1, gather_out=True, param_attr=None, bias_attr=None, name=None)

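Linear layer for model parallelism whose weight matrix, of shape `size`, is split along axis 1 (the output dimension) into `num_partitions` parts. When `gather_out` is True, the partial outputs of all partitions are gathered into the full output; otherwise each rank keeps only its own slice of the output.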

forward(x)

Applies the column-parallel linear transformation to `x`. This method runs each time the layer instance is called.

Parameters
  • x (Tensor) -- the input tensor.
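Splitting along axis 1 works because each output column depends only on the matching weight column, so every partition can compute its slice of the output independently. A minimal pure-Python sketch, assuming the weight is stored as [in_features, out_features] (the layout of Paddle's `nn.Linear`), that each list entry plays the role of one rank, and that the gather implied by `gather_out=True` is a concatenation along the last axis. Function names are hypothetical and this is not Paddle's implementation.

```python
# Single-process simulation of a column-parallel linear layer.
# Assumptions: weight layout is [in_features, out_features], each list entry
# plays the role of one rank, and the gather is a concatenation along the
# last axis. Function names are hypothetical.


def matmul(x, w):
    # x: [batch][in_features], w: [in_features][out_features]
    return [[sum(x[b][k] * w[k][j] for k in range(len(w)))
             for j in range(len(w[0]))]
            for b in range(len(x))]


def split_columns(w, num_partitions):
    # Split the weight along axis 1 (output features) into equal slices.
    cols = len(w[0])
    assert cols % num_partitions == 0, "out_features must divide evenly"
    per = cols // num_partitions
    return [[row[p * per:(p + 1) * per] for row in w]
            for p in range(num_partitions)]


def column_parallel(x, w, num_partitions, gather_out=True):
    # Each "rank" computes the output columns for its own weight slice.
    partial = [matmul(x, shard) for shard in split_columns(w, num_partitions)]
    if not gather_out:
        return partial  # one [batch][out_features / num_partitions] per rank
    # Gather: concatenate the per-rank slices back into full rows.
    return [[v for p in partial for v in p[b]] for b in range(len(x))]


x = [[1, 2, 3], [4, 5, 6]]
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1]]
print(column_parallel(x, w, num_partitions=2))  # [[7, 5, 4, 10], [16, 11, 13, 25]]
```

Setting `gather_out=False` in the sketch skips the concatenation and leaves each rank with only its slice of the output.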

class RowParallelLiner(size, num_partitions=1, input_is_parallel=False, param_attr=None, bias_attr=None, name=None)

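Linear layer for model parallelism whose weight matrix, of shape `size`, is split along axis 0 (the input dimension) into `num_partitions` parts. Each partition multiplies its slice of the input by its slice of the weight, and the partial results are summed across partitions. Set `input_is_parallel` to True when `x` has already been split along its last dimension; otherwise the input is split before the multiplication.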

forward(x)

Applies the row-parallel linear transformation to `x`. This method runs each time the layer instance is called.

Parameters
  • x (Tensor) -- the input tensor, already split along its last dimension if `input_is_parallel` is True.
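Splitting along axis 0 means each partition sees only part of every dot product, so the partial results must be summed. A minimal pure-Python sketch, assuming a [in_features, out_features] weight layout, one list entry per rank, and an element-wise sum in place of the all-reduce. When `input_is_parallel=True`, the sketch expects `x` to already be a list of per-rank input slices. Function names are hypothetical and this is not Paddle's implementation.

```python
# Single-process simulation of a row-parallel linear layer.
# Assumptions: weight layout is [in_features, out_features], each list entry
# plays the role of one rank, and the all-reduce is an element-wise sum.
# Function names are hypothetical.


def matmul(x, w):
    # x: [batch][in_features], w: [in_features][out_features]
    return [[sum(x[b][k] * w[k][j] for k in range(len(w)))
             for j in range(len(w[0]))]
            for b in range(len(x))]


def split_rows(w, num_partitions):
    # Split the weight along axis 0 (input features) into equal slices.
    rows = len(w)
    assert rows % num_partitions == 0, "in_features must divide evenly"
    per = rows // num_partitions
    return [w[p * per:(p + 1) * per] for p in range(num_partitions)]


def split_last_dim(x, num_partitions):
    # Split the input along its last axis to match the weight slices.
    cols = len(x[0])
    assert cols % num_partitions == 0, "in_features must divide evenly"
    per = cols // num_partitions
    return [[row[p * per:(p + 1) * per] for row in x]
            for p in range(num_partitions)]


def row_parallel(x, w, num_partitions, input_is_parallel=False):
    w_shards = split_rows(w, num_partitions)
    # When input_is_parallel is True, x is already a list of per-rank slices.
    x_shards = x if input_is_parallel else split_last_dim(x, num_partitions)
    partial = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    batch, out_features = len(partial[0]), len(w[0])
    # Simulated all-reduce: element-wise sum of the partial products.
    return [[sum(p[b][j] for p in partial) for j in range(out_features)]
            for b in range(batch)]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(row_parallel(x, w, num_partitions=2))  # [[50, 60], [114, 140]]
```

In this sketch, passing already-split input with `input_is_parallel=True` skips the split step, which is useful when the input comes from a step that left each rank with only its slice, such as a column split without a gather.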