
Hello, could you explain this part of the code? #5

@KrystalCWT

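Hello, could you explain this part of the code? I don't understand how your triplet loss is computed.

```python
import torch

# input[i] / target[i]: sequences of three tensors (anchor, positive, negative)
# for the i-th triplet. Stack each triplet, then the batch: new_x is (B, 3, ...).
temp_x = [torch.stack(input[i], dim=0) for i in range(len(input))]
temp_y = [torch.stack(target[i], dim=0) for i in range(len(target))]
new_x = torch.stack(temp_x, dim=0)
new_y = torch.stack(temp_y, dim=0)
temp_batch_size = new_x.size(0)  # assumed: number of triplets B (not defined in the quoted snippet)

# Take the anchor, positive and negative columns (each (B, ...)) and concatenate
# them along dim 0: rows [0, B) are anchors, [B, 2B) positives, [2B, 3B) negatives.
new_x = [new_x[:, i] for i in range(3)]
new_y = [new_y[:, i] for i in range(3)]
sample_input = torch.cat(new_x, 0)
sample_target = torch.cat(new_y, 0)

# `async=True` is a syntax error on Python 3.7+; PyTorch 0.4 renamed it to
# `non_blocking`. Variable wrappers are unnecessary since PyTorch 0.4.
target_var = sample_target.cuda(non_blocking=True)
input_var = sample_input.cuda()

# compute output: one forward pass over all 3B images, then slice the
# embeddings back into the three groups.
output = model(input_var)
anchor = output[:temp_batch_size]
positive = output[temp_batch_size:(temp_batch_size * 2)]
negative = output[-temp_batch_size:]
```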
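The key to reading this code is the ordering it creates. The minimal sketch below uses strings in place of tensors; the helpers `flatten_triplets` and `split_output` are hypothetical, not from the repository. It shows how the stack, column slice (`new_x[:, i]`), and concatenation turn B triplets into one flat batch of 3B items, and how the three slices of `output` recover the groups. This works because the model runs on all 3B images in a single forward pass, so row j of `output` is the embedding of row j of `sample_input`.

```python
def flatten_triplets(batch):
    """Reorder [(a0, p0, n0), (a1, p1, n1), ...] into [a0, a1, ..., p0, p1, ..., n0, n1, ...].

    Mirrors stacking to shape (B, 3, ...), taking column i with new_x[:, i],
    and concatenating the three columns along dim 0.
    """
    columns = [[sample[i] for sample in batch] for i in range(3)]  # like new_x[:, i]
    return [item for column in columns for item in column]         # like torch.cat(new_x, 0)


def split_output(output, batch_size):
    """Recover the three groups from a flat sequence of length 3 * batch_size."""
    anchor = output[:batch_size]
    positive = output[batch_size:batch_size * 2]
    negative = output[-batch_size:]
    return anchor, positive, negative


batch = [("a0", "p0", "n0"), ("a1", "p1", "n1")]
flat = flatten_triplets(batch)
print(flat)  # ['a0', 'a1', 'p0', 'p1', 'n0', 'n1']
anchor, positive, negative = split_output(flat, len(batch))
print(anchor, positive, negative)  # ['a0', 'a1'] ['p0', 'p1'] ['n0', 'n1']
```

Batching all three groups together means a single forward pass instead of three, and every embedding comes from the same network weights.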
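The quoted snippet stops once the anchor, positive, and negative embeddings are available; the loss call itself is not shown, so the repository's exact choice of margin, distance, and reduction is unknown. Assuming the standard margin-based triplet loss, mean over i of max(d(a_i, p_i) - d(a_i, n_i) + margin, 0) with Euclidean d, a minimal pure-Python sketch is below. The function name `triplet_loss` and the default `margin=1.0` are assumptions for illustration.

```python
import math


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over the batch (Euclidean d).

    margin=1.0 is an assumed value; the repository's margin is not shown.
    """
    losses = [
        max(math.dist(a, p) - math.dist(a, n) + margin, 0.0)
        for a, p, n in zip(anchor, positive, negative)
    ]
    return sum(losses) / len(losses)


anchor = [[0.0, 0.0], [1.0, 1.0]]
positive = [[0.0, 1.0], [1.0, 1.0]]
negative = [[3.0, 4.0], [1.0, 1.5]]
# First triplet: 1 - 5 + 1 < 0, so it contributes 0.
# Second triplet: 0 - 0.5 + 1 = 0.5. Mean: 0.25.
print(triplet_loss(anchor, positive, negative))  # 0.25
```

In PyTorch, `torch.nn.TripletMarginLoss(margin=1.0, p=2)` called as `criterion(anchor, positive, negative)` computes essentially the same quantity (its pairwise distance adds a small `eps`), which is a common way to finish the snippet above.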
