How to work with data iterators
When the data provider for training or validation is an iterator (infinite, or finite with known or unknown size), here are some basic examples of how to set up a trainer or evaluator.
Infinite iterator for training
Let’s use an infinite data iterator as the training dataflow:
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

def infinite_iterator(batch_size):
    while True:
        batch = torch.rand(batch_size, 3, 32, 32)
        yield batch

def train_step(trainer, batch):
    # ...
    s = trainer.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}"
    )

trainer = Engine(train_step)

# We need to specify epoch_length to define the epoch
trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)
1/3 : 1 - 63.862
1/3 : 2 - 64.042
1/3 : 3 - 63.936
1/3 : 4 - 64.141
1/3 : 5 - 64.767
2/3 : 6 - 63.791
2/3 : 7 - 64.565
2/3 : 8 - 63.602
2/3 : 9 - 63.995
2/3 : 10 - 63.943
3/3 : 11 - 63.831
3/3 : 12 - 64.276
3/3 : 13 - 64.148
3/3 : 14 - 63.920
3/3 : 15 - 64.226
State:
    iteration: 15
    epoch: 3
    epoch_length: 5
    max_epochs: 3
    output: <class 'NoneType'>
    batch: <class 'torch.Tensor'>
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>
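Because epoch_length now defines the epoch boundary, epoch-based events fire as usual even though the underlying iterator never ends. Here is a minimal sketch of hooking a handler to Events.EPOCH_COMPLETED (the logging handler is illustrative and not part of the example above):

import torch
from ignite.engine import Engine, Events

def infinite_iterator(batch_size):
    while True:
        yield torch.rand(batch_size, 3, 32, 32)

trainer = Engine(lambda engine, batch: None)  # no-op step, for illustration only

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch():
    # fires every epoch_length iterations, i.e. at each epoch boundary
    print(f"epoch {trainer.state.epoch} completed at iteration {trainer.state.iteration}")

trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)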
If we do not specify epoch_length, we can stop the training explicitly by calling terminate(). In this case, only a single epoch will be defined.
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

def infinite_iterator(batch_size):
    while True:
        batch = torch.rand(batch_size, 3, 32, 32)
        yield batch

def train_step(trainer, batch):
    # ...
    s = trainer.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}"
    )

trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(once=15))
def stop_training():
    trainer.terminate()

trainer.run(infinite_iterator(4))
1/1 : 1 - 63.862
1/1 : 2 - 64.042
1/1 : 3 - 63.936
1/1 : 4 - 64.141
1/1 : 5 - 64.767
1/1 : 6 - 63.791
1/1 : 7 - 64.565
1/1 : 8 - 63.602
1/1 : 9 - 63.995
1/1 : 10 - 63.943
1/1 : 11 - 63.831
1/1 : 12 - 64.276
1/1 : 13 - 64.148
1/1 : 14 - 63.920
1/1 : 15 - 64.226
State:
    iteration: 15
    epoch: 1
    epoch_length: <class 'NoneType'>
    max_epochs: 1
    output: <class 'NoneType'>
    batch: <class 'torch.Tensor'>
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>
The same code can be used to run validation.
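For example, an evaluator over an infinite stream can be capped in exactly the same way. A minimal sketch, assuming a placeholder val_step in place of the actual validation logic:

import torch
from ignite.engine import Engine

def infinite_iterator(batch_size):
    while True:
        yield torch.rand(batch_size, 3, 32, 32)

def val_step(evaluator, batch):
    # placeholder for the actual validation logic
    pass

evaluator = Engine(val_step)
# evaluate on exactly 10 batches drawn from the infinite stream
evaluator.run(infinite_iterator(4), epoch_length=10)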
Finite iterator with unknown length
Let’s use a finite data iterator whose length is unknown to the user. For training, we would like to perform several passes over the dataflow, so we need to restart the data iterator once it is exhausted. In the code below, we do not specify epoch_length; it will be determined automatically during the first pass over the data.
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

def finite_unk_size_data_iter():
    for i in range(11):
        yield i

def train_step(trainer, batch):
    # ...
    s = trainer.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}"
    )

trainer = Engine(train_step)

@trainer.on(Events.DATALOADER_STOP_ITERATION)
def restart_iter():
    trainer.state.dataloader = finite_unk_size_data_iter()

data_iter = finite_unk_size_data_iter()
trainer.run(data_iter, max_epochs=5)
1/5 : 1 - 0.000
1/5 : 2 - 1.000
1/5 : 3 - 2.000
1/5 : 4 - 3.000
1/5 : 5 - 4.000
1/5 : 6 - 5.000
1/5 : 7 - 6.000
1/5 : 8 - 7.000
1/5 : 9 - 8.000
1/5 : 10 - 9.000
1/5 : 11 - 10.000
2/5 : 12 - 0.000
2/5 : 13 - 1.000
2/5 : 14 - 2.000
2/5 : 15 - 3.000
2/5 : 16 - 4.000
2/5 : 17 - 5.000
2/5 : 18 - 6.000
2/5 : 19 - 7.000
2/5 : 20 - 8.000
2/5 : 21 - 9.000
2/5 : 22 - 10.000
3/5 : 23 - 0.000
3/5 : 24 - 1.000
3/5 : 25 - 2.000
3/5 : 26 - 3.000
3/5 : 27 - 4.000
3/5 : 28 - 5.000
3/5 : 29 - 6.000
3/5 : 30 - 7.000
3/5 : 31 - 8.000
3/5 : 32 - 9.000
3/5 : 33 - 10.000
4/5 : 34 - 0.000
4/5 : 35 - 1.000
4/5 : 36 - 2.000
4/5 : 37 - 3.000
4/5 : 38 - 4.000
4/5 : 39 - 5.000
4/5 : 40 - 6.000
4/5 : 41 - 7.000
4/5 : 42 - 8.000
4/5 : 43 - 9.000
4/5 : 44 - 10.000
5/5 : 45 - 0.000
5/5 : 46 - 1.000
5/5 : 47 - 2.000
5/5 : 48 - 3.000
5/5 : 49 - 4.000
5/5 : 50 - 5.000
5/5 : 51 - 6.000
5/5 : 52 - 7.000
5/5 : 53 - 8.000
5/5 : 54 - 9.000
5/5 : 55 - 10.000
State:
    iteration: 55
    epoch: 5
    epoch_length: 11
    max_epochs: 5
    output: <class 'NoneType'>
    batch: 10
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>
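An alternative not used above is to make the stream infinite up front with itertools.cycle. Two caveats: cycle caches and replays the first pass, which is fine for static data but wrong if each pass should produce fresh (e.g. re-shuffled) batches, and since the stream no longer exhausts, epoch_length must be given explicitly (any convenient value works, since cycle removes the natural pass boundary):

import itertools
from ignite.engine import Engine

def finite_unk_size_data_iter():
    for i in range(11):
        yield i

trainer = Engine(lambda engine, batch: None)  # no-op step, for illustration only

# cycle() replays the cached first pass forever, so the engine never sees
# StopIteration and no restart handler is needed; the epoch boundary must
# therefore be given explicitly via epoch_length.
trainer.run(itertools.cycle(finite_unk_size_data_iter()), epoch_length=11, max_epochs=5)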
For validation, the code is simply:
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

def finite_unk_size_data_iter():
    for i in range(11):
        yield i

def val_step(evaluator, batch):
    # ...
    s = evaluator.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}"
    )

evaluator = Engine(val_step)

data_iter = finite_unk_size_data_iter()
evaluator.run(data_iter)
1/1 : 1 - 0.000
1/1 : 2 - 1.000
1/1 : 3 - 2.000
1/1 : 4 - 3.000
1/1 : 5 - 4.000
1/1 : 6 - 5.000
1/1 : 7 - 6.000
1/1 : 8 - 7.000
1/1 : 9 - 8.000
1/1 : 10 - 9.000
1/1 : 11 - 10.000
State:
    iteration: 11
    epoch: 1
    epoch_length: 11
    max_epochs: 1
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>
Finite iterator with known length
Let’s use a finite data iterator with a known size for training or validation. If we need to restart the data iterator, we can either do it as in the unknown-size case above, by attaching a restart handler to Events.DATALOADER_STOP_ITERATION, or restart it explicitly once every size iterations, as we do here:
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

size = 11

def finite_size_data_iter(size):
    for i in range(size):
        yield i

def train_step(trainer, batch):
    # ...
    s = trainer.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}"
    )

trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(every=size))
def restart_iter():
    trainer.state.dataloader = finite_size_data_iter(size)

data_iter = finite_size_data_iter(size)
trainer.run(data_iter, max_epochs=5)
1/5 : 1 - 0.000
1/5 : 2 - 1.000
1/5 : 3 - 2.000
1/5 : 4 - 3.000
1/5 : 5 - 4.000
1/5 : 6 - 5.000
1/5 : 7 - 6.000
1/5 : 8 - 7.000
1/5 : 9 - 8.000
1/5 : 10 - 9.000
1/5 : 11 - 10.000
2/5 : 12 - 0.000
2/5 : 13 - 1.000
2/5 : 14 - 2.000
2/5 : 15 - 3.000
2/5 : 16 - 4.000
2/5 : 17 - 5.000
2/5 : 18 - 6.000
2/5 : 19 - 7.000
2/5 : 20 - 8.000
2/5 : 21 - 9.000
2/5 : 22 - 10.000
3/5 : 23 - 0.000
3/5 : 24 - 1.000
3/5 : 25 - 2.000
3/5 : 26 - 3.000
3/5 : 27 - 4.000
3/5 : 28 - 5.000
3/5 : 29 - 6.000
3/5 : 30 - 7.000
3/5 : 31 - 8.000
3/5 : 32 - 9.000
3/5 : 33 - 10.000
4/5 : 34 - 0.000
4/5 : 35 - 1.000
4/5 : 36 - 2.000
4/5 : 37 - 3.000
4/5 : 38 - 4.000
4/5 : 39 - 5.000
4/5 : 40 - 6.000
4/5 : 41 - 7.000
4/5 : 42 - 8.000
4/5 : 43 - 9.000
4/5 : 44 - 10.000
5/5 : 45 - 0.000
5/5 : 46 - 1.000
5/5 : 47 - 2.000
5/5 : 48 - 3.000
5/5 : 49 - 4.000
5/5 : 50 - 5.000
5/5 : 51 - 6.000
5/5 : 52 - 7.000
5/5 : 53 - 8.000
5/5 : 54 - 9.000
5/5 : 55 - 10.000
State:
    iteration: 55
    epoch: 5
    epoch_length: 11
    max_epochs: 5
    output: <class 'NoneType'>
    batch: 10
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>
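Since the size is known, an equivalent variant is to pass epoch_length explicitly and tie the restart to the epoch boundary instead of an iteration count. A sketch under the assumption that epoch_length equals size, so the iterator is exhausted exactly when the epoch completes:

from ignite.engine import Engine, Events

size = 11

def finite_size_data_iter(size):
    for i in range(size):
        yield i

trainer = Engine(lambda engine, batch: None)  # no-op step, for illustration only

@trainer.on(Events.EPOCH_COMPLETED)
def restart_iter():
    # with epoch_length == size, the iterator is exhausted exactly at this point
    trainer.state.dataloader = finite_size_data_iter(size)

trainer.run(finite_size_data_iter(size), epoch_length=size, max_epochs=5)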
For validation, the code is simply:
import torch
from ignite.engine import Engine, Events

torch.manual_seed(12)

size = 11

def finite_size_data_iter(size):
    for i in range(size):
        yield i

def val_step(evaluator, batch):
    # ...
    s = evaluator.state
    print(
        f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}"
    )

evaluator = Engine(val_step)

data_iter = finite_size_data_iter(size)
evaluator.run(data_iter)
1/1 : 1 - 0.000
1/1 : 2 - 1.000
1/1 : 3 - 2.000
1/1 : 4 - 3.000
1/1 : 5 - 4.000
1/1 : 6 - 5.000
1/1 : 7 - 6.000
1/1 : 8 - 7.000
1/1 : 9 - 8.000
1/1 : 10 - 9.000
1/1 : 11 - 10.000
State:
    iteration: 11
    epoch: 1
    epoch_length: 11
    max_epochs: 1
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: <class 'NoneType'>
    times: <class 'dict'>