Open
Description
Hi, could you explain this part of the code? I don't understand how your triplet loss is computed.
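# input[i] / target[i] presumably hold the (anchor, positive, negative) images / labels of triplet i
temp_x = [torch.stack(input[i], dim=0) for i in range(len(input))]    # each (3, C, H, W), assuming image tensors
temp_y = [torch.stack(target[i], dim=0) for i in range(len(target))]  # each (3,)
new_x = torch.stack(temp_x, dim=0)  # (B, 3, C, H, W), B = number of triplets
new_y = torch.stack(temp_y, dim=0)  # (B, 3)
# split by role: new_x[0] = all anchors, new_x[1] = all positives, new_x[2] = all negatives
new_x = [new_x[:, i] for i in range(3)]  # 3 tensors of shape (B, C, H, W)
new_y = [new_y[:, i] for i in range(3)]  # 3 tensors of shape (B,)
# concatenate along the batch dim: rows [0, B) anchors, [B, 2B) positives, [2B, 3B) negatives
sample_input = torch.cat(new_x, 0)   # (3B, C, H, W)
sample_target = torch.cat(new_y, 0)  # (3B,)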
# print (sample_target)
# print (sample_target[:batch_size])
# print (sample_target[batch_size:(batch_size * 2)])
# print (sample_target[-batch_size:])
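target = sample_target.cuda(non_blocking=True)  # asynchronous host-to-device copy (effective with pinned memory)
input_var = sample_input.cuda()  # plain tensors work directly; no Variable wrapper needed
target_var = target              # already on the GPU
# compute output: one forward pass embeds all 3B images
output = model(input_var)
# temp_batch_size: presumably the number of triplets B in this batch (len(input));
# the three consecutive B-row blocks of output are the anchor, positive and negative embeddings
anchor = output[:temp_batch_size]
positive = output[temp_batch_size:(temp_batch_size * 2)]
negative = output[-temp_batch_size:]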
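The snippet stops once `anchor`, `positive` and `negative` are sliced out; the loss itself is not shown. The sketch below reproduces the packing on toy data and then applies a triplet loss. It assumes the standard triplet margin loss, max(||a - p|| - ||a - n|| + margin, 0) averaged over the batch, via PyTorch's `nn.TripletMarginLoss`. The repository may use a different loss. `pack_triplets` and `split_triplets` are hypothetical helper names, and `nn.Linear` stands in for the real embedding network.

```python
import torch
import torch.nn as nn


def pack_triplets(batch):
    """Flatten a list of B triplets [(a, p, n), ...] into one (3B, ...) tensor.

    Rows 0..B-1 are anchors, B..2B-1 positives, 2B..3B-1 negatives,
    mirroring the stack/cat steps in the snippet.
    """
    per_triplet = torch.stack([torch.stack(t, dim=0) for t in batch], dim=0)  # (B, 3, ...)
    return torch.cat([per_triplet[:, i] for i in range(3)], dim=0)            # (3B, ...)


def split_triplets(output, batch_size):
    """Slice the model output back into its three consecutive B-row blocks."""
    anchor = output[:batch_size]
    positive = output[batch_size:2 * batch_size]
    negative = output[-batch_size:]
    return anchor, positive, negative


torch.manual_seed(0)
B, D = 4, 8
# hypothetical toy data: each "image" is just a vector of length D
batch = [[torch.randn(D) for _ in range(3)] for _ in range(B)]
x = pack_triplets(batch)
print(x.shape)  # torch.Size([12, 8])

model = nn.Linear(D, 16)  # stand-in for the real embedding network
output = model(x)         # single forward pass over all 3B inputs
anchor, positive, negative = split_triplets(output, B)

# assumed loss: built-in triplet margin loss (Euclidean distance, mean over triplets)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)
loss = criterion(anchor, positive, negative)

# the same quantity written out explicitly
d_ap = (anchor - positive).norm(p=2, dim=1)  # anchor-positive distance per triplet
d_an = (anchor - negative).norm(p=2, dim=1)  # anchor-negative distance per triplet
manual = torch.clamp(d_ap - d_an + 1.0, min=0).mean()
print(loss.item(), manual.item())
```

Because all three roles share one forward pass, the only thing tying an embedding to its role is its row position, which is why the slicing must use the same B as the packing.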