sampler

class SamplerHelper(dataset, iterable=None)

Bases: object

This class helps construct iterable samplers for paddle.io.DataLoader. It wraps a dataset and uses its __getitem__() method. Every subclass of SamplerHelper must provide an __iter__() method, which iterates over the indices of the dataset elements, and a __len__() method, which returns the length of the returned iterator.

The class can also be used as a batch iterator instead of an index iterator. In that case it is initialized with an iterable dataset, and its iterator yields samples rather than indices.

Note

The __len__() method isn't strictly required by paddle.io.DataLoader, but is expected in any calculation involving the length of a paddle.io.DataLoader.

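Parameters
  • dataset (Dataset) -- The input dataset wrapped by SamplerHelper.

  • iterable (Iterable, optional) -- An iterable over the dataset's indices or samples. If None, the sampler iterates over the indices of the dataset elements (0, 1, ..., len(dataset) - 1). Default: None.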
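
To make the index-versus-sample behaviour concrete, here is a minimal pure-Python sketch of the wrapping pattern described above. MiniSampler is a hypothetical stand-in, not the PaddleNLP class: it yields indices when iterable is None and otherwise yields whatever iterable produces. Accepting a zero-argument callable, so that a generator can be re-created on every pass, is an assumption of this sketch.

```python
from typing import Callable, Iterable, Iterator, Optional, Sequence, Union


class MiniSampler:
    """Hypothetical, simplified stand-in for SamplerHelper (not the PaddleNLP class)."""

    def __init__(self, dataset: Sequence,
                 iterable: Optional[Union[Iterable, Callable[[], Iterator]]] = None):
        self.data_source = dataset
        self.iterable = iterable

    def __iter__(self) -> Iterator:
        if self.iterable is None:
            # Default: iterate over the indices of the dataset elements.
            return iter(range(len(self.data_source)))
        if callable(self.iterable):
            # Sketch assumption: a zero-argument callable builds a fresh iterator per pass.
            return iter(self.iterable())
        # Otherwise yield whatever the iterable produces (indices or samples).
        return iter(self.iterable)

    def __len__(self) -> int:
        # Length of the wrapped dataset; a custom iterable may need a different value.
        return len(self.data_source)


data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]
print(list(MiniSampler(data)))                 # [0, 1, 2]
print(list(MiniSampler(data, [2, 0])))         # [2, 0]
print(list(MiniSampler(data, iterable=data)))  # the samples themselves
```

Keeping __len__ tied to the wrapped dataset mirrors the note below: the length is only needed when something asks how long the DataLoader is.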

property length

Returns the length of the sampler.

shuffle(buffer_size=-1, seed=None)

Shuffles the dataset according to the given buffer size and random seed.

Parameters
  • buffer_size (int, optional) -- Buffer size for shuffling. If buffer_size is negative or greater than the length of the dataset, the length of the dataset is used instead, so the whole dataset is shuffled at once. Default: -1.

  • seed (int, optional) -- Seed for the random number generator. Default: None.

Returns

A new shuffled SamplerHelper object.

Return type

SamplerHelper

Example

from paddlenlp.data import SamplerHelper
from paddle.io import Dataset

class MyDataset(Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = [
            [[1, 2, 3, 4], [1]],
            [[5, 6, 7], [0]],
            [[8, 9], [1]],
        ]

    def __getitem__(self, index):
        data = self.data[index][0]
        label = self.data[index][1]
        return data, label

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
sampler = SamplerHelper(dataset)
print(list(sampler))    # indices of dataset elements
# [0, 1, 2]

sampler = sampler.shuffle(seed=2)
print(list(sampler))    # indices of dataset elements
# [2, 1, 0]
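
The buffer_size semantics can be sketched in a few lines of plain Python. buffered_shuffle is a hypothetical helper, not part of PaddleNLP: it fills a buffer with up to buffer_size items, shuffles it, emits it, and repeats, so a small buffer only mixes nearby elements while a negative buffer_size shuffles everything at once. It uses Python's random module, so a given seed is not expected to reproduce the [2, 1, 0] order printed above.

```python
import random
from typing import Iterable, Iterator, List, Optional


def buffered_shuffle(items: Iterable, buffer_size: int = -1,
                     seed: Optional[int] = None) -> Iterator:
    """Hypothetical helper: shuffle `items` within consecutive buffers.

    A non-positive buffer_size means one buffer holding everything.
    """
    rng = random.Random(seed)  # seeded generator for a reproducible order
    buf: List = []
    for item in items:
        buf.append(item)
        if buffer_size > 0 and len(buf) >= buffer_size:
            rng.shuffle(buf)   # shuffle only the current window
            yield from buf
            buf = []
    if buf:                    # flush the last (possibly partial) buffer
        rng.shuffle(buf)
        yield from buf


out = list(buffered_shuffle(range(6), buffer_size=2, seed=0))
print(out)  # each pair {0, 1}, {2, 3}, {4, 5} stays in its own two slots
```

Because each buffer is shuffled independently, the first two outputs are always some order of 0 and 1; a bounded buffer trades how thoroughly the order is mixed for lower memory use.
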

sort(cmp=None, key=None, reverse=False, buffer_size=-1)

Sorts the dataset according to the given callable cmp() or key().

Parameters
  • cmp (callable, optional) -- The comparison function. Default: None.

  • key (callable, optional) -- The key function. As in the example below, it is called as key(x, data_source) with an index x and the dataset, and returns the value to sort by. Default: None.

  • reverse (bool, optional) -- Whether to sort in descending order. If True, samples are sorted in descending order; if False, in ascending order. Default: False.

  • buffer_size (int, optional) -- Buffer size for sorting. If buffer_size is negative or greater than the length of the data, the length of the data is used instead. Default: -1.

Returns

A new sorted SamplerHelper object.

Return type

SamplerHelper

Example

from paddlenlp.data import SamplerHelper
from paddle.io import Dataset

class MyDataset(Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = [
            [[1, 2, 3, 4], [1]],
            [[5, 6, 7], [0]],
            [[8, 9], [1]],
        ]

    def __getitem__(self, index):
        data = self.data[index][0]
        label = self.data[index][1]
        return data, label

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
sampler = SamplerHelper(dataset)
print(list(sampler))    # indices of dataset elements
# [0, 1, 2]

# Sorted in ascending order by the length of the first field
# of the sample
key = (lambda x, data_source: len(data_source[x][0]))
sampler = sampler.sort(key=key)
print(list(sampler))    # indices of dataset elements
# [2, 1, 0]
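
A buffered sort works the same way: sorting happens only inside each window of buffer_size indices. The sketch below is a hypothetical pure-Python helper, buffered_sort, that reuses the two-argument key(x, data_source) convention from the example above; it illustrates the idea and is not the library's code.

```python
from typing import Callable, Iterable, Iterator, List, Sequence


def buffered_sort(indices: Iterable[int], data_source: Sequence,
                  key: Callable[[int, Sequence], object],
                  reverse: bool = False, buffer_size: int = -1) -> Iterator[int]:
    """Hypothetical helper: sort indices within consecutive buffers."""
    def flush(buf: List[int]) -> List[int]:
        # Sort one window using the two-argument key convention.
        return sorted(buf, key=lambda x: key(x, data_source), reverse=reverse)

    buf: List[int] = []
    for idx in indices:
        buf.append(idx)
        if buffer_size > 0 and len(buf) >= buffer_size:
            yield from flush(buf)
            buf = []
    if buf:  # sort and emit whatever is left
        yield from flush(buf)


data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]
by_len = lambda x, data_source: len(data_source[x][0])
print(list(buffered_sort(range(3), data, by_len)))                 # [2, 1, 0]
print(list(buffered_sort(range(3), data, by_len, buffer_size=2)))  # [1, 0, 2]
```

With buffer_size=2 only indices 0 and 1 are compared, so index 2 stays last. Shuffling first and then sorting within moderate buffers is a common way to group samples of similar length while keeping some randomness.
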

batch(batch_size, drop_last=False, batch_size_fn=None, key=None)

Groups the dataset into mini-batches according to the given batch_size.

Parameters
  • batch_size (int) -- The batch size.

  • drop_last (bool, optional) -- Whether to drop the last mini-batch if it is incomplete. Default: False.

  • batch_size_fn (callable, optional) -- A function that accepts four arguments: the index in the data source, the current length of the mini-batch, the size of the mini-batch so far, and the data source. It returns the updated size of the mini-batch so far. The returned value can be anything, since it is only passed to key as the size_so_far argument. If None, the size is the length of the mini-batch. Default: None.

  • key (callable, optional) -- A key function that accepts the size of the mini-batch so far and the length of the mini-batch, and returns the value to compare with batch_size. If None, the size of the mini-batch so far is compared with batch_size directly. Default: None.

Returns

A new batched SamplerHelper object.

Return type

SamplerHelper

Example

from paddlenlp.data import SamplerHelper
from paddle.io import Dataset

class MyDataset(Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = [
            [[1, 2, 3, 4], [1]],
            [[5, 6, 7], [0]],
            [[8, 9], [1]],
        ]

    def __getitem__(self, index):
        data = self.data[index][0]
        label = self.data[index][1]
        return data, label

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
sampler = SamplerHelper(dataset)
print(list(sampler))    # indices of dataset elements
# [0, 1, 2]

sampler = sampler.batch(batch_size=2)
print(list(sampler))    # batches of indices
# [[0, 1], [2]]
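
The interaction between batch_size_fn and key is easier to see in code. Below is a pure-Python sketch that follows the argument descriptions above; batch_indices is a hypothetical name, and the handling of a batch that overflows batch_size (closing it before the overflowing element, which then starts the next batch) is an assumption of this sketch rather than something this page specifies.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Sequence


def batch_indices(indices: Iterable[int], data_source: Sequence, batch_size: int,
                  drop_last: bool = False,
                  batch_size_fn: Optional[Callable] = None,
                  key: Optional[Callable] = None) -> Iterator[List[int]]:
    """Hypothetical helper mirroring the documented batch() arguments."""
    if batch_size_fn is None:
        # Default: the size so far is simply the number of elements.
        batch_size_fn = lambda idx, length, size_so_far, ds: length
    if key is None:
        # Default: compare the size so far with batch_size directly.
        key = lambda size_so_far, length: size_so_far

    minibatch: List[int] = []
    size_so_far = 0
    for idx in indices:
        minibatch.append(idx)
        size_so_far = batch_size_fn(idx, len(minibatch), size_so_far, data_source)
        measure = key(size_so_far, len(minibatch))
        if measure == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif measure > batch_size:
            if len(minibatch) == 1:
                # A single element that already exceeds batch_size cannot be placed.
                raise ValueError("a single element exceeds batch_size")
            # Sketch assumption: emit the batch without the overflowing element
            # and start the next batch with it.
            yield minibatch[:-1]
            minibatch = minibatch[-1:]
            size_so_far = batch_size_fn(idx, 1, 0, data_source)
    if minibatch and not drop_last:
        yield minibatch


data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]
print(list(batch_indices(range(3), data, batch_size=2)))  # [[0, 1], [2]]

# Token budget: the size so far is the total length of the first fields.
tokens = lambda idx, length, size_so_far, ds: size_so_far + len(ds[idx][0])
print(list(batch_indices(range(3), data, batch_size=6, batch_size_fn=tokens)))  # [[0], [1, 2]]
```

In the second call batch_size acts as a token budget: indices 0 and 1 together hold 7 tokens, which exceeds 6, so index 0 forms a batch of its own.
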

shard(num_replicas=None, rank=None)

Shards the dataset for multi-GPU training, so that each training process receives its own subset of the elements.

Parameters
  • num_replicas (int, optional) -- The number of training processes, which is also the number of GPU cards used in training. If None, it is set by paddle.distributed.get_world_size(). Default: None.

  • rank (int, optional) -- The ID of the current training process, equal to the value of the environment variable PADDLE_TRAINER_ID. If None, it is initialized by paddle.distributed.get_rank(). Default: None.

Returns

A new sliced SamplerHelper object.

Return type

SamplerHelper

Example

from paddlenlp.data import SamplerHelper
from paddle.io import Dataset

class MyDataset(Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = [
            [[1, 2, 3, 4], [1]],
            [[5, 6, 7], [0]],
            [[8, 9], [1]],
        ]

    def __getitem__(self, index):
        data = self.data[index][0]
        label = self.data[index][1]
        return data, label

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
sampler = SamplerHelper(dataset)
print(list(sampler))    # indices of dataset elements
# [0, 1, 2]

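# rank is not given, so it defaults to paddle.distributed.get_rank(),
# which is 0 in a single-process run
sampler = sampler.shard(num_replicas=2)
print(list(sampler))    # indices assigned to rank 0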
# [0, 2]
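
The [0, 2] output above is consistent with a round-robin split in which rank 0 keeps every second index. Below is a sketch under that assumption; shard_indices is a hypothetical helper, and it does not try to equalize shard lengths, which the library may handle differently.

```python
from typing import Iterable, Iterator, List


def shard_indices(indices: Iterable[int], num_replicas: int, rank: int) -> Iterator[int]:
    """Hypothetical helper: keep every num_replicas-th element, starting at `rank`."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    for i, idx in enumerate(indices):
        # Element i goes to process i % num_replicas.
        if i % num_replicas == rank:
            yield idx


shards: List[List[int]] = [list(shard_indices(range(3), 2, r)) for r in range(2)]
print(shards)  # [[0, 2], [1]]
```

If sharding is combined with shuffling, every process has to use the same seed; otherwise each rank would split a different permutation and the shards could overlap or miss elements.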