Parcourir la source

fix learning rate scheduler

rasbt il y a 1 an
Parent
commit
6f0a5c320b

Fichier diff supprimé car celui-ci est trop grand
+ 1 - 1
appendix-D/01_main-chapter-code/appendix-D.ipynb


+ 2 - 2
ch05/05_bonus_hparam_tuning/hparam_search.py

@@ -65,13 +65,13 @@ def train_model(model, train_loader, val_loader, optimizer, device,
                 initial_lr=3e-05, min_lr=1e-6):
     global_step = 0
 
-    max_lr = optimizer.defaults["lr"]
+    max_lr = optimizer.param_groups[0]["lr"]
 
     # Calculate total number of iterations
     total_training_iters = len(train_loader) * n_epochs
 
     # Calculate the learning rate increment at each step during warmup
-    lr_increment = (optimizer.defaults["lr"] - initial_lr) / warmup_iters
+    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters
 
     for epoch in range(n_epochs):
         model.train()

Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff