瀏覽代碼

make batch loss calculation more efficient

rasbt 1 年之前
父節點
當前提交
88b2dd780a

+ 2 - 0
appendix-D/01_main-chapter-code/previous_chapters.py

@@ -259,6 +259,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

File diff suppressed because it is too large
+ 38 - 34
ch05/01_main-chapter-code/ch05.ipynb


+ 2 - 0
ch05/01_main-chapter-code/train.py

@@ -34,6 +34,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

+ 2 - 0
ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py

@@ -253,6 +253,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

+ 2 - 0
ch05/05_bonus_hparam_tuning/hparam_search.py

@@ -27,6 +27,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

Some files were not shown because too many files changed in this diff