batch_sampler#

class DistributedBatchSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False, consumed_samples=0)[source]#

Sampler that restricts data loading to a subset of the dataset.

In such a case, each process can pass a DistributedBatchSampler instance as a DataLoader sampler and load a subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

Parameters:
  • dataset (paddle.io.Dataset) -- this can be an instance of paddle.io.Dataset or any other Python object that implements __len__, from which the BatchSampler gets the number of samples in the data source.

  • batch_size (int) -- the number of sample indices in each mini-batch.

  • num_replicas (int, optional) -- the number of processes in distributed training. If num_replicas is None, it is retrieved from paddle.distributed.ParallelEnv. Default None.

  • rank (int, optional) -- the rank of the current process among num_replicas processes. If rank is None, it is retrieved from paddle.distributed.ParallelEnv. Default None.

  • shuffle (bool) -- whether to shuffle the order of indices before generating batch indices. Default False.

  • drop_last (bool) -- whether to drop the last incomplete batch when the dataset size is not divisible by the batch size. Default False.

  • consumed_samples (int, optional) -- the number of samples already consumed in the current epoch, so that sampling can resume mid-epoch. Default 0.
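
For illustration, a minimal sketch that simulates two processes by setting num_replicas and rank explicitly; a plain list stands in for the dataset, since any object implementing __len__ works:

from paddle.io import DistributedBatchSampler

# Any object implementing __len__ can serve as the data source.
dataset = list(range(10))

# Simulate two training processes by building one sampler per rank.
for rank in range(2):
    sampler = DistributedBatchSampler(
        dataset, batch_size=2, num_replicas=2, rank=rank)
    # Each rank yields mini-batches drawn from its own, mutually
    # exclusive subset of the ten indices.
    print(rank, list(sampler))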

Examples

import numpy as np

from paddle.io import Dataset, DistributedBatchSampler

# init with dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)

for batch_indices in sampler:
    # batch_indices is a list of sample indices forming one mini-batch
    break
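
In practice the sampler is usually handed to a paddle.io.DataLoader through its batch_sampler argument rather than iterated directly. A minimal sketch, reusing the dataset and sampler from the example above:

from paddle.io import DataLoader

# With batch_sampler given, the sampler controls batching and index
# selection, so batch_size, shuffle and drop_last must not also be
# passed to DataLoader.
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)

for image, label in loader:
    # Each process loads only the shard of the dataset assigned to it.
    break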

set_epoch(epoch=0, consumed_samples=0)[source]#

Sets the epoch number. When shuffle=True, this number is used as the seed for the random ordering. By default, users need not set this, and all replicas (workers) use a different random ordering for each epoch. If the same number is set at every epoch, this sampler will yield the same ordering in all epochs.

Parameters:

epoch (int) -- Epoch number.

consumed_samples (int, optional) -- the number of samples already consumed in this epoch before sampling continues. Default 0.

Examples

import numpy as np

from paddle.io import Dataset, DistributedBatchSampler

# init with dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)

for epoch in range(10):
    sampler.set_epoch(epoch)
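
Note that the example above builds the sampler with the default shuffle=False, in which case set_epoch has no visible effect on the ordering. A minimal sketch of the pattern set_epoch is meant for, reusing the RandomDataset above with shuffle=True:

sampler = DistributedBatchSampler(dataset, batch_size=64, shuffle=True)

for epoch in range(10):
    # Seeding the shuffle with the epoch number keeps the ordering
    # identical across replicas within an epoch, while still changing
    # it from one epoch to the next.
    sampler.set_epoch(epoch)
    for batch_indices in sampler:
        pass  # consume one epoch of mini-batches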