メモ: PyTorch TensorDataset、DataLoader について
はじめに
PyTorchのtorch.utils.data.TensorDataset
、torch.utils.data.DataLoader
の使い方についてのメモを記す。
torch.utils.data.TensorDataset
import numpy import torch import torch.utils.data x = numpy.array([[1], [2], [3], [4], [5], [6]]) y = numpy.array([[10], [20], [30], [40], [50], [60]]) x = torch.tensor(x) y = torch.tensor(y) dataset = torch.utils.data.TensorDataset(x, y)
中身をそのまま表示させると、TensorDatasetというオブジェクトであることが示される。
>>> dataset <torch.utils.data.dataset.TensorDataset object at 0x7f4e91fdca20>
[ ]で要素を取り出せる。Tensorの組がタプルになっている。
>>> dataset[0] (tensor([1.]), tensor([10.])) >>> for x, y in dataset: ... print(x, y) ... tensor([1.]) tensor([10.]) tensor([2.]) tensor([20.]) tensor([3.]) tensor([30.]) tensor([4.]) tensor([40.]) tensor([5.]) tensor([50.]) tensor([6.]) tensor([60.])
torch.utils.data.DataLoader
torch.utils.data.DataLoader
にTensorDatasetを渡すと得られ、ミニバッチを返すiterableなオブジェクトとなる。
batch_size
でミニバッチのデータの数を指定できる(defaultは1)。shuffle
をTrueとするとTensorDatasetの中身の
組の順序がシャッフルされる(defaultはFalse)。
batch_size = 3 # ミニバッチのデータの数 data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
for
で中身を取り出すと次のようにミニバッチサイズ数ごとにデータが取り出される(この例では3つ)。順序が
シャッフルされていることが確認できる。
>>> for x, y in data_loader: ... print(x, y) ... tensor([[5.], [4.], [1.]]) tensor([[50.], [40.], [10.]]) tensor([[3.], [6.], [2.]]) tensor([[30.], [60.], [20.]])
ミニバッチ数を4、shuffle=False
とすると次の通りとなる。データは全部で6組なので、2回めに取り出す際は残りの2組だけとなっている。
データの順序はシャッフルされずにもとのままとなっている。
>>> data_loader = torch.utils.data.DataLoader(dataset, ... batch_size=4, shuffle=False) >>> for x, y in data_loader: ... print(x, y) ... tensor([[1.], [2.], [3.], [4.]]) tensor([[10.], [20.], [30.], [40.]]) tensor([[5.], [6.]]) tensor([[50.], [60.]])
DataLoaderを用いた例
ニューラルネットワークによる関数近似 - 化学系エンジニアがAIを学ぶ