Description
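Hello, I noticed that `__include_gt__` does not behave correctly when the source indices of the ground truth, `y[0]`, are not sorted: the ground-truth targets in `y[1]` are not reliably inserted into the candidate list of their own source node.
It might be worth mentioning this requirement in the documentation.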
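A likely cause (an assumption, not confirmed against the library source here): if `__include_gt__` writes the missing targets with `Tensor.masked_scatter`, the source values are consumed in the row-major order of the mask, not in the order of `y[0]`. The minimal sketch below uses hypothetical names `rows` and `cols` to stand in for `y[0]` and `y[1]`:

```python
import torch

# masked_scatter copies values from `source` into the True positions of `mask`
# in row-major order, regardless of which row each value was meant for.
target = torch.zeros(4, dtype=torch.long)
mask = torch.tensor([True, True, True, False])
rows = torch.tensor([2, 0, 1])  # unsorted source indices, like y[0] in the report
cols = torch.tensor([3, 4, 5])  # intended targets, like y[1]

out = target.masked_scatter(mask, cols)
print(out)  # tensor([3, 4, 5, 0]): row 2 received 5 instead of 3

# Intended placement: each row gets its own target.
expected = target.clone()
expected[rows] = cols
print(expected)  # tensor([4, 5, 3, 0])
```

The two results only coincide when `rows` is already sorted.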
Code to reproduce:
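```python
import torch

# `self` is assumed to be a DGMC model instance (provides __top_k__ and __include_gt__).
def test(self):
    h_s = torch.randn(1, 10, 20)
    h_t = torch.randn(1, 10, 20)
    s_mask = torch.ones(1, 10, dtype=torch.bool)
    # Ground-truth pairs (source, target); the source indices are NOT sorted.
    y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
    # make sure top k doesn't include ground truth
    h_s[0, y[0]] = 100
    h_t[0, y[1]] = -100
    self.k = 1
    S_idx = self.__top_k__(h_s, h_t)
    # Append one extra candidate column (all zeros) for __include_gt__ to overwrite.
    S_rnd_idx = torch.zeros(1, 10, 1, dtype=torch.long)
    S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1)
    S_idx = self.__include_gt__(S_idx, s_mask, y)
    # Check that every source node now has its ground truth among its candidates.
    mask = S_idx[0, y[0]] == y[1].view(-1, 1)
    print(mask.any(dim=-1))  # should be all True, but is not for unsorted y[0]
```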
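One possible fix, sketched under the assumption that the library's `__include_gt__` follows the mask-and-`masked_scatter` structure described above: sort the ground-truth pairs by source index before scattering. `include_gt` is a hypothetical standalone function for illustration, not the library's API.

```python
import torch


def include_gt(S_idx: torch.Tensor, s_mask: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical variant of __include_gt__ that sorts y by source index first."""
    B, N_s = s_mask.size()
    k = S_idx.size(-1)

    # Sort correspondences by source node so their order matches the
    # row-major order in which masked_scatter fills the mask.
    perm = y[0].argsort()
    row, col = y[0, perm], y[1, perm]

    # True for every source node whose candidate list misses its ground truth.
    gt_mask = (S_idx[s_mask][row] != col.view(-1, 1)).all(dim=-1)

    # Spread the per-pair mask back onto the (masked) source nodes.
    sparse_mask = gt_mask.new_zeros((int(s_mask.sum()),))
    sparse_mask[row] = gt_mask
    dense_mask = sparse_mask.new_zeros((B, N_s))
    dense_mask[s_mask] = sparse_mask

    # Only the last candidate slot of an affected node is overwritten.
    last_entry = torch.zeros(k, dtype=torch.bool, device=S_idx.device)
    last_entry[-1] = True
    dense_mask = dense_mask.view(B, N_s, 1) & last_entry.view(1, 1, k)

    return S_idx.masked_scatter(dense_mask, col[gt_mask])
```

Without touching the library, an equivalent workaround is to sort `y` before the call, e.g. `y = y[:, y[0].argsort()]`; with that, `mask.any(dim=-1)` in the reproduction should print all `True`.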