化学系エンジニアがAIを学ぶ

PyTorchでディープラーニング、強化学習を学び、主に化学工学の問題に取り組みます

メモ: PyTorch TensorDataset、DataLoader について

はじめに

PyTorchのtorch.utils.data.TensorDatasettorch.utils.data.DataLoaderの使い方についてのメモを記す。

torch.utils.data.TensorDataset

同じ要素数の2つのtensorを渡し、その組を得る。

import numpy
import torch
import torch.utils.data

x = numpy.array([[1], [2], [3], [4], [5], [6]])
y = numpy.array([[10], [20], [30], [40], [50], [60]])
x = torch.tensor(x)
y = torch.tensor(y)
dataset = torch.utils.data.TensorDataset(x, y)

中身をそのまま表示させると、TensorDatasetというオブジェクトであることが示される。

>>> dataset
<torch.utils.data.dataset.TensorDataset object at 0x7f4e91fdca20>

[ ]で要素を取り出せる。Tensorの組がタプルになっている。

>>> dataset[0]
(tensor([1.]), tensor([10.]))

>>> for x, y in dataset:
...     print(x, y)
... 
tensor([1.]) tensor([10.])
tensor([2.]) tensor([20.])
tensor([3.]) tensor([30.])
tensor([4.]) tensor([40.])
tensor([5.]) tensor([50.])
tensor([6.]) tensor([60.])

torch.utils.data.DataLoader

torch.utils.data.DataLoaderにTensorDatasetを渡すと得られ、ミニバッチを返すiterableなオブジェクトとなる。 batch_sizeでミニバッチのデータの数を指定できる(defaultは1)。shuffleをTrueとするとTensorDatasetの中身の 組の順序がシャッフルされる(defaultはFalse)。

batch_size  = 3 # ミニバッチのデータの数
data_loader = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=True)

forで中身を取り出すと次のようにミニバッチサイズ数ごとにデータが取り出される(この例では3つ)。順序が シャッフルされていることが確認できる。

>>> for x, y in data_loader:
...     print(x, y)
... 
tensor([[5.], [4.], [1.]]) tensor([[50.], [40.], [10.]])
tensor([[3.], [6.], [2.]]) tensor([[30.], [60.], [20.]])

ミニバッチ数を4、shuffle=Falseとすると次の通りとなる。データは全部で6組なので、2回めに取り出す際は残りの2組だけとなっている。 データの順序はシャッフルされずにもとのままとなっている。

>>> data_loader = torch.utils.data.DataLoader(dataset,
... batch_size=4, shuffle=False)
>>> for x, y in data_loader:
...     print(x, y)
... 
tensor([[1.], [2.], [3.], [4.]]) tensor([[10.], [20.], [30.], [40.]])
tensor([[5.], [6.]]) tensor([[50.], [60.]])

DataLoaderを用いた例

ニューラルネットワークによる関数近似 - 化学系エンジニアがAIを学ぶ

参考