Quellcode durchsuchen

make batch loss calculatution more efficient

rasbt vor 1 Jahr
Ursprung
Commit
88b2dd780a

+ 2 - 0
appendix-D/01_main-chapter-code/previous_chapters.py

@@ -259,6 +259,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

Datei-Diff unterdrückt, da er zu groß ist
+ 38 - 34
ch05/01_main-chapter-code/ch05.ipynb


+ 2 - 0
ch05/01_main-chapter-code/train.py

@@ -34,6 +34,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

+ 2 - 0
ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py

@@ -253,6 +253,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

+ 2 - 0
ch05/05_bonus_hparam_tuning/hparam_search.py

@@ -27,6 +27,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
     total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
+    else:
+        num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.