
Sorting requirement for __include_gt__ #21

Open
@Nifury


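Hello, I noticed that if the ground-truth source indices y[0] are not sorted, __include_gt__ does not behave properly: the ground-truth targets are not inserted into the candidate lists of their corresponding source nodes.
It might be worth mentioning this sorting requirement in the documentation.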
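A minimal pure-Python sketch of why the order of y[0] could matter. It assumes (from reading the code, not confirmed) that __include_gt__ writes the missing targets with Tensor.masked_scatter on a mask over the last candidate slot. masked_scatter fills the masked positions in row-major order and consumes its source values in the order given, so the targets would land in ascending row order rather than in the rows listed in y[0]. Both helper names below are hypothetical and only model that behavior for a single batch in which every ground-truth target is missing.

```python
def masked_scatter_rows(num_rows, rows, cols):
    """Model masked_scatter: masked rows are visited in ascending order and
    the values in `cols` are consumed in the order they are given."""
    mask = [False] * num_rows
    for r in rows:
        mask[r] = True
    out = [None] * num_rows
    values = iter(cols)
    for r in range(num_rows):
        if mask[r]:
            out[r] = next(values)
    return out


def scatter_by_row(num_rows, rows, cols):
    """Order-independent alternative: each target goes to its own source row."""
    out = [None] * num_rows
    for r, c in zip(rows, cols):
        out[r] = c
    return out


rows, cols = [2, 0, 1], [3, 4, 5]  # same ground truth as the repro below
print(masked_scatter_rows(3, rows, cols))  # [3, 4, 5]: row 2 receives 5, not 3
print(scatter_by_row(3, rows, cols))       # [4, 5, 3]: row 2 receives 3
```

When the pairs are sorted by source index first (rows [0, 1, 2], cols [4, 5, 3]), both functions produce the same result, which is consistent with the bug only appearing for unsorted y[0].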

Code to reproduce:

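import torch

def test(self):
    # `self` is the model instance that defines __top_k__ and __include_gt__.
    h_s = torch.randn(1, 10, 20)  # source node features, shape [B, N_s, C]
    h_t = torch.randn(1, 10, 20)  # target node features, shape [B, N_t, C]
    s_mask = torch.ones(1, 10, dtype=torch.bool)  # all source nodes are valid
    # Ground truth: source 2 -> target 3, 0 -> 4, 1 -> 5 (y[0] is NOT sorted).
    y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
    # Make sure the top-k candidates don't include the ground truth.
    h_s[0, y[0]] = 100
    h_t[0, y[1]] = -100
    self.k = 1
    S_idx = self.__top_k__(h_s, h_t)  # shape [1, 10, 1]
    # Append one dummy "random" candidate column (all zeros).
    S_rnd_idx = torch.zeros(1, 10, 1, dtype=torch.long)
    S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1)  # shape [1, 10, 2]
    S_idx = self.__include_gt__(S_idx, s_mask, y)
    # For each ground-truth pair, check that its target is now a candidate.
    mask = S_idx[0, y[0]] == y[1].view(-1, 1)
    print(mask.any(dim=-1))  # expected: tensor([True, True, True])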
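Until this is documented or fixed, a possible workaround is to reorder the ground-truth pairs by source index before calling __include_gt__; in PyTorch this would be y[:, y[0].argsort()]. The sketch below does the same reordering on a plain 2 x N nested list. sort_ground_truth is a hypothetical helper, not part of the library, and the assumption that sorted input is sufficient comes from reading the code rather than from the documentation.

```python
def sort_ground_truth(y):
    """Reorder the columns of a 2 x N ground-truth list so that y[0] is ascending.

    Hypothetical helper mirroring the PyTorch expression y[:, y[0].argsort()].
    """
    rows, cols = y
    # Argsort of the source indices; each pair (rows[i], cols[i]) stays together.
    perm = sorted(range(len(rows)), key=lambda i: rows[i])
    return [[rows[i] for i in perm], [cols[i] for i in perm]]


y = [[2, 0, 1], [3, 4, 5]]
print(sort_ground_truth(y))  # [[0, 1, 2], [4, 5, 3]]
```

Reordering whole columns rather than sorting y[0] alone keeps every source index paired with its own target.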
