Computer Vision News - April 2023
# MONAI Generative Models -- VQ-VAE training loop
# (excerpt from "Computer Vision News", April 2023, p. 35; the validation
# loop is truncated by the page extraction right after the image-capture
# step, so the val_loss accumulation that presumably follows is not shown).

n_epochs = 100
val_interval = 10          # run validation every 10 epochs
epoch_recon_loss_list = []  # per-epoch mean L1 reconstruction loss
epoch_quant_loss_list = []  # per-epoch mean quantization (codebook) loss
val_recon_epoch_loss_list = []
intermediary_images = []    # one batch of example reconstructions per validation round
n_example_images = 4
total_start = time.time()

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    epoch_quant_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        optimizer.zero_grad(set_to_none=True)

        # model outputs reconstruction and the quantization error
        reconstruction, quantization_loss = model(images=images)
        recons_loss = l1_loss(reconstruction.float(), images.float())

        loss = recons_loss + quantization_loss
        loss.backward()
        optimizer.step()

        epoch_loss += recons_loss.item()
        # BUGFIX: accumulate the quantization loss across batches before
        # averaging. The original divided only the *current* batch's value by
        # (step + 1), which is neither the latest value nor a running mean.
        epoch_quant_loss += quantization_loss.item()

        progress_bar.set_postfix(
            {
                "recons_loss": epoch_loss / (step + 1),
                "quantization_loss": epoch_quant_loss / (step + 1),
            }
        )
    epoch_recon_loss_list.append(epoch_loss / (step + 1))
    epoch_quant_loss_list.append(epoch_quant_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_step, batch in enumerate(val_loader, start=1):
                images = batch["image"].to(device)
                reconstruction, quantization_loss = model(images=images)

                # get the first sample from the first validation batch for
                # visualizing how the training evolves
                if val_step == 1:
                    # NOTE: the extraction had "n_example_ images" -- a
                    # line-break artifact; the identifier is n_example_images.
                    intermediary_images.append(reconstruction[:n_example_images, 0])
Made with FlippingBook
RkJQdWJsaXNoZXIy NTc3NzU=