sampler
- class SamplerHelper(dataset, iterable=None)
Bases: object

The class helps construct an iterable sampler for paddle.io.DataLoader. It wraps a dataset and uses its __getitem__() method. Every subclass of SamplerHelper has to provide an __iter__() method, which provides a way to iterate over the indices of dataset elements, and a __len__() method, which returns the length of the returned iterators.

The class can also be used as a batch iterator instead of an index iterator: when iterator is initialized with an iterable dataset, it yields samples rather than indices.

Note
The __len__() method isn't strictly required by paddle.io.DataLoader, but is expected in any calculation involving the length of a paddle.io.DataLoader.

- Parameters:
dataset (Dataset) -- Input dataset for SamplerHelper.
iterable (Iterable, optional) -- Iterator of dataset. Default: None.
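The index-iteration protocol described above can be sketched without Paddle. The block below is only an illustration, not the SamplerHelper implementation: RangeSampler is a hypothetical class, and the plain list `data` stands in for a map-style dataset.

```python
class RangeSampler:
    """Hypothetical stand-in (not the PaddleNLP class): yields indices 0..len(dataset)-1."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        # Iterate over indices of dataset elements, not over the samples themselves.
        return iter(range(len(self.dataset)))

    def __len__(self):
        # Number of indices produced by one pass of __iter__.
        return len(self.dataset)


# A plain list plays the role of a dataset with __getitem__ and __len__.
data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]

sampler = RangeSampler(data)
indices = list(sampler)
samples = [data[i] for i in indices]  # samples are looked up through __getitem__
```

A consumer such as a data loader only needs the indices; it fetches the samples itself through the dataset's __getitem__().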
- property length
Returns the length.
- shuffle(buffer_size=-1, seed=None)
Shuffles the dataset according to the given buffer size and random seed.
- Parameters:
buffer_size (int, optional) -- Buffer size for shuffle. If buffer_size < 0 or more than the length of the dataset, buffer_size is the length of the dataset. Default: -1.
seed (int, optional) -- Seed for the random number generator. Default: None.
- Returns:
A new shuffled SamplerHelper object.
- Return type:
SamplerHelper
Example
    from paddlenlp.data import SamplerHelper
    from paddle.io import Dataset

    class MyDataset(Dataset):
        def __init__(self):
            super(MyDataset, self).__init__()
            self.data = [
                [[1, 2, 3, 4], [1]],
                [[5, 6, 7], [0]],
                [[8, 9], [1]],
            ]

        def __getitem__(self, index):
            data = self.data[index][0]
            label = self.data[index][1]
            return data, label

        def __len__(self):
            return len(self.data)

    dataset = MyDataset()
    sampler = SamplerHelper(dataset)
    print(list(sampler))    # indices of dataset elements
    # [0, 1, 2]

    sampler = sampler.shuffle(seed=2)
    print(list(sampler))    # indices of dataset elements
    # [2, 1, 0]
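To make the role of buffer_size concrete, here is a pure-Python sketch of one plausible reading: indices are collected into consecutive buffers and each buffer is shuffled on its own. The function buffered_shuffle is hypothetical and uses Python's random module, so it is not guaranteed to reproduce the permutations that SamplerHelper.shuffle returns for a given seed.

```python
import random


def buffered_shuffle(indices, buffer_size=-1, seed=None):
    """Hypothetical sketch: shuffle indices within consecutive buffers."""
    rng = random.Random(seed)  # seed=None gives a non-deterministic generator
    buf = []
    for idx in indices:
        buf.append(idx)
        # A non-positive buffer_size means "one buffer for the whole dataset".
        if buffer_size > 0 and len(buf) >= buffer_size:
            rng.shuffle(buf)
            yield from buf
            buf = []
    if buf:
        # Flush the remainder (or everything, if the buffer never filled up).
        rng.shuffle(buf)
        yield from buf


windowed = list(buffered_shuffle(range(6), buffer_size=3, seed=0))
whole = list(buffered_shuffle(range(6), seed=0))
repeat = list(buffered_shuffle(range(6), seed=0))
```

With buffer_size=3, the first three outputs are always some permutation of 0, 1, 2, so elements only move within their buffer. A smaller buffer therefore trades randomness for lower memory use.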
- sort(cmp=None, key=None, reverse=False, buffer_size=-1)
Sorts the dataset according to the given callable cmp() or key().
- Parameters:
cmp (callable, optional) -- The comparison function. Default: None.
key (callable, optional) -- The key function. Default: None.
reverse (bool, optional) -- Whether to reverse the order when sorting the data samples. True means descending order and False means ascending order. Default: False.
buffer_size (int, optional) -- Buffer size for sort. If buffer_size < 0 or buffer_size is more than the length of the data, buffer_size is set to the length of the data. Default: -1.
- Returns:
A new sorted SamplerHelper object.
- Return type:
SamplerHelper
Example
    from paddlenlp.data import SamplerHelper
    from paddle.io import Dataset

    class MyDataset(Dataset):
        def __init__(self):
            super(MyDataset, self).__init__()
            self.data = [
                [[1, 2, 3, 4], [1]],
                [[5, 6, 7], [0]],
                [[8, 9], [1]],
            ]

        def __getitem__(self, index):
            data = self.data[index][0]
            label = self.data[index][1]
            return data, label

        def __len__(self):
            return len(self.data)

    dataset = MyDataset()
    sampler = SamplerHelper(dataset)
    print(list(sampler))    # indices of dataset elements
    # [0, 1, 2]

    # Sorted in ascending order by the length of the first field
    # of the sample
    key = (lambda x, data_source: len(data_source[x][0]))
    sampler = sampler.sort(key=key)
    print(list(sampler))    # indices of dataset elements
    # [2, 1, 0]
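The interplay of key, reverse, and buffer_size can be sketched in plain Python. The function buffered_sort below is a hypothetical illustration rather than the library code. It assumes, as in the example above, that key receives an index and the data source, and that sorting happens within consecutive buffers.

```python
def buffered_sort(indices, data_source, key, reverse=False, buffer_size=-1):
    """Hypothetical sketch: sort indices buffer by buffer."""

    def flush(buf):
        # key is called as key(index, data_source), mirroring the example above.
        return sorted(buf, key=lambda i: key(i, data_source), reverse=reverse)

    buf = []
    for idx in indices:
        buf.append(idx)
        # A non-positive buffer_size means "sort the whole dataset at once".
        if buffer_size > 0 and len(buf) >= buffer_size:
            yield from flush(buf)
            buf = []
    if buf:
        yield from flush(buf)


data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]


def by_length(i, src):
    # Length of the first field of sample i.
    return len(src[i][0])


full = list(buffered_sort(range(3), data, by_length))
windowed = list(buffered_sort(range(3), data, by_length, buffer_size=2))
descending = list(buffered_sort(range(3), data, by_length, reverse=True))
```

Under these assumptions, full is [2, 1, 0], while buffer_size=2 only orders indices 0 and 1 relative to each other and leaves index 2 in its own buffer, giving [1, 0, 2].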
- batch(batch_size, drop_last=False, batch_size_fn=None, key=None)
Batches the dataset according to the given batch_size.
- Parameters:
batch_size (int) -- The batch size.
drop_last (bool, optional) -- Whether to drop the last mini-batch. Default: False.
batch_size_fn (callable, optional) -- It accepts four arguments: the index into the data source, the length of the mini-batch, the size of the mini-batch so far, and the data source. It returns the size of the mini-batch so far. The returned value can actually be anything; it is passed as the size_so_far argument of key. If None, it returns the length of the mini-batch. Default: None.
key (callable, optional) -- The key function. It accepts the size of the mini-batch so far and the length of the mini-batch, and returns the value to be compared with batch_size. If None, only the size of the mini-batch so far is compared with batch_size. Default: None.
- Returns:
A new batched SamplerHelper object.
- Return type:
SamplerHelper
Example
    from paddlenlp.data import SamplerHelper
    from paddle.io import Dataset

    class MyDataset(Dataset):
        def __init__(self):
            super(MyDataset, self).__init__()
            self.data = [
                [[1, 2, 3, 4], [1]],
                [[5, 6, 7], [0]],
                [[8, 9], [1]],
            ]

        def __getitem__(self, index):
            data = self.data[index][0]
            label = self.data[index][1]
            return data, label

        def __len__(self):
            return len(self.data)

    dataset = MyDataset()
    sampler = SamplerHelper(dataset)
    print(list(sampler))    # indices of dataset elements
    # [0, 1, 2]

    sampler = sampler.batch(batch_size=2)
    print(list(sampler))    # indices of dataset elements
    # [[0, 1], [2]]
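The way batch_size_fn and key cooperate is easier to follow in code. The sketch below is a hypothetical pure-Python batcher that follows the parameter descriptions above; it is not the library implementation. Its handling of an overflowing sample (closing the current batch and carrying the sample over to the next one) is an assumption. The padded-token budget at the end is an invented usage scenario.

```python
def batch_indices(indices, data_source, batch_size, drop_last=False,
                  batch_size_fn=None, key=None):
    """Hypothetical sketch of batching driven by batch_size_fn and key."""
    if batch_size_fn is None:
        # Default: the "size so far" is simply the number of samples in the batch.
        batch_size_fn = lambda idx, count, so_far, src: count
    if key is None:
        # Default: compare the size so far directly with batch_size.
        key = lambda so_far, count: so_far

    minibatch, so_far = [], 0
    for idx in indices:
        minibatch.append(idx)
        so_far = batch_size_fn(idx, len(minibatch), so_far, data_source)
        size = key(so_far, len(minibatch))
        if size == batch_size:
            yield minibatch
            minibatch, so_far = [], 0
        elif size > batch_size:
            if len(minibatch) == 1:
                raise ValueError("a single sample exceeds batch_size")
            # Assumption: emit the batch without the newest index and carry it over.
            yield minibatch[:-1]
            minibatch = [idx]
            so_far = batch_size_fn(idx, 1, 0, data_source)
    if minibatch and not drop_last:
        yield minibatch


data = [([1, 2, 3, 4], [1]), ([5, 6, 7], [0]), ([8, 9], [1])]

by_count = list(batch_indices(range(3), data, batch_size=2))
dropped = list(batch_indices(range(3), data, batch_size=2, drop_last=True))

# Budget of 6 padded tokens: track the longest sample, multiply by the batch length.
longest = lambda idx, count, so_far, src: max(so_far, len(src[idx][0]))
padded_tokens = lambda so_far, count: so_far * count
by_tokens = list(batch_indices(range(3), data, batch_size=6,
                               batch_size_fn=longest, key=padded_tokens))
```

In the token-budget case, sample 0 (length 4) cannot share a batch with sample 1 because two rows padded to length 4 would cost 8 tokens, so the sketch yields [[0], [1, 2]].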
- shard(num_replicas=None, rank=None)
Slices the dataset for multi-GPU training.
- Parameters:
num_replicas (int, optional) -- The number of training processes, which is also the number of GPU cards used in training. If None, it is set by the paddle.distributed.get_world_size() method. Default: None.
rank (int, optional) -- The id of the current training process, equal to the value of the environment variable PADDLE_TRAINER_ID. If None, it is initialized by the paddle.distributed.get_rank() method. Default: None.
- Returns:
A new sliced SamplerHelper object.
- Return type:
SamplerHelper
Example
    from paddlenlp.data import SamplerHelper
    from paddle.io import Dataset

    class MyDataset(Dataset):
        def __init__(self):
            super(MyDataset, self).__init__()
            self.data = [
                [[1, 2, 3, 4], [1]],
                [[5, 6, 7], [0]],
                [[8, 9], [1]],
            ]

        def __getitem__(self, index):
            data = self.data[index][0]
            label = self.data[index][1]
            return data, label

        def __len__(self):
            return len(self.data)

    dataset = MyDataset()
    sampler = SamplerHelper(dataset)
    print(list(sampler))    # indices of dataset elements
    # [0, 1, 2]

    sampler = sampler.shard(num_replicas=2)
    print(list(sampler))    # indices of dataset elements
    # [0, 2]
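The slicing in the example is consistent with a round-robin split by position, which the following sketch illustrates. The function shard_indices is hypothetical and takes num_replicas and rank explicitly instead of querying paddle.distributed. It also does not pad the shards to equal length, and this page does not say whether the library does.

```python
def shard_indices(indices, num_replicas, rank):
    """Hypothetical sketch: keep every num_replicas-th index, starting at rank."""
    for position, idx in enumerate(indices):
        if position % num_replicas == rank:
            yield idx


all_indices = [0, 1, 2]
rank0 = list(shard_indices(all_indices, num_replicas=2, rank=0))
rank1 = list(shard_indices(all_indices, num_replicas=2, rank=1))
```

Here rank 0 receives [0, 2], matching the output above, and rank 1 receives [1]; together the shards cover every index exactly once.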