|
@@ -259,6 +259,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
|
|
|
total_loss = 0.
|
|
total_loss = 0.
|
|
|
if num_batches is None:
|
|
if num_batches is None:
|
|
|
num_batches = len(data_loader)
|
|
num_batches = len(data_loader)
|
|
|
|
|
+ else:
|
|
|
|
|
+ num_batches = min(num_batches, len(data_loader))
|
|
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
|
|
if i < num_batches:
|
|
if i < num_batches:
|
|
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|