The solution you provide in the "Defining the DB-VAE loss function" part might be wrong:
total_loss = torch.mean(classification_loss * face_indicator + vae_loss)
From your definition, for non-face images the loss function is solely the classification loss, so the face indicator should mask the VAE loss rather than the classification loss. It should be
total_loss = torch.mean(classification_loss + vae_loss * face_indicator)
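To make the difference concrete, here is a toy check in plain Python (so it runs without PyTorch). The per-sample numbers are made up for illustration and are not taken from the lab; only the two formulas come from the lines above.

```python
# Toy per-sample values (made-up numbers, NOT from the lab).
classification_loss = [0.7, 0.2, 0.5, 0.1]
vae_loss = [3.0, 4.0, 5.0, 6.0]
face_indicator = [1.0, 0.0, 1.0, 0.0]  # 1 = face, 0 = non-face

samples = list(zip(classification_loss, vae_loss, face_indicator))

# Formula currently in the solution: the indicator masks the classification loss.
current = [c * f + v for c, v, f in samples]

# Proposed formula: the indicator masks the VAE loss instead.
proposed = [c + v * f for c, v, f in samples]

for i, (c, v, f) in enumerate(samples):
    kind = "face" if f == 1.0 else "non-face"
    print(f"sample {i} ({kind}): current={current[i]:.2f} proposed={proposed[i]:.2f}")

# Batch mean, the plain-Python equivalent of torch.mean over per-sample losses.
total_current = sum(current) / len(current)
total_proposed = sum(proposed) / len(proposed)
```

With the current formula, a non-face sample contributes only its VAE loss and its classification loss is dropped. With the proposed formula, a non-face sample contributes only its classification loss, which matches the definition.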
Btw, why do we use the training data as test data for the first CNN model? I don't have the API for your loader, but if I am right, the samples returned by loader.get_batch(5000) should already have been used during training?
# set the model to eval mode
standard_classifier.eval()
# TRAINING DATA
# Evaluate on a subset of CelebA+Imagenet
(batch_x, batch_y) = loader.get_batch(5000)
batch_x = torch.from_numpy(batch_x).float().to(device)
batch_y = torch.from_numpy(batch_y).float().to(device)
with torch.inference_mode():
    y_pred_logits = standard_classifier(batch_x)
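If the loader can select samples by index (an assumption on my side, since I don't know its API), one way to keep evaluation separate would be to split the indices once before training. This is only a sketch: `split_indices` is a hypothetical helper of mine, not part of the lab's loader.

```python
import random


def split_indices(num_samples, holdout_fraction=0.1, seed=0):
    """Shuffle indices once and split them into disjoint train / held-out lists.

    Hypothetical helper for illustration; not part of the lab's loader.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    num_holdout = int(num_samples * holdout_fraction)
    # The first `num_holdout` shuffled indices are held out, the rest are for training.
    return indices[num_holdout:], indices[:num_holdout]


train_idx, holdout_idx = split_indices(1000, holdout_fraction=0.2)
```

Training batches would then be drawn only from `train_idx`, and the evaluation batch only from `holdout_idx`, so the reported accuracy is measured on samples the model has not seen.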