Safemotion Lib
fastreid.evaluation.query_expansion Namespace Reference

Functions

 aqe (torch.tensor query_feat, torch.tensor gallery_feat, int qe_times=1, int qe_k=10, float alpha=3.0)
 

Detailed Description

@author:  xingyu liao
@contact: sherlockliao01@gmail.com

Function Documentation

◆ aqe()

fastreid.evaluation.query_expansion.aqe ( torch.tensor query_feat,
torch.tensor gallery_feat,
int qe_times = 1,
int qe_k = 10,
float alpha = 3.0 )
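Combines the top-k retrieved nearest neighbors with the original query and performs another retrieval.
cf. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
Args:
    query_feat (torch.Tensor): query features, shape (num_query, dim).
    gallery_feat (torch.Tensor): gallery features, shape (num_gallery, dim).
    qe_times (int): number of query expansion iterations.
    qe_k (int): number of nearest neighbors to combine.
    alpha (float): exponent applied to each neighbor's similarity to form its weight.
Returns:
    tuple(torch.Tensor, torch.Tensor): the expanded query and gallery features.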
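Reading the implementation in this file, a single expansion step can be summarized as below. The notation is introduced here for illustration and is not part of the source: $x_i$ ($i = 1, \dots, N$, with $N$ = num_query + num_gallery) are the rows of the concatenated query and gallery features, $\hat{x}_i = x_i / \lVert x_i \rVert_2$, $s_{ij} = \hat{x}_i^\top \hat{x}_j$, $k$ = qe_k, and $\mathcal{N}_k(i)$ is the set of indices of the $k$ largest $s_{ij}$ (normally including $i$ itself).

```latex
x_i \leftarrow \frac{1}{k} \sum_{j \in \mathcal{N}_k(i)} s_{ij}^{\alpha} \, x_j ,
\qquad i = 1, \dots, N
```

The step is repeated qe_times times, recomputing $\hat{x}_i$ and $s_{ij}$ from the updated features each time. The $1/k$ factor (from np.mean) rescales every feature equally, so it does not change the normalized features used for the next round of similarities.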

Definition at line 15 of file query_expansion.py.

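# Imports this function relies on.
import numpy as np
import torch
import torch.nn.functional as F


def aqe(query_feat: torch.tensor, gallery_feat: torch.tensor,
        qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
    """
    Combines the top-k retrieved nearest neighbors with the original query and performs another retrieval.
    cf. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
    Args:
        query_feat (torch.Tensor): query features, shape (num_query, dim).
        gallery_feat (torch.Tensor): gallery features, shape (num_gallery, dim).
        qe_times (int): number of query expansion iterations.
        qe_k (int): number of nearest neighbors to combine.
        alpha (float): exponent applied to each neighbor's similarity to form its weight.
    Returns:
        tuple(torch.Tensor, torch.Tensor): the expanded query and gallery features.
    """
    num_query = query_feat.shape[0]
    # Queries and gallery features are expanded jointly.
    all_feat = torch.cat((query_feat, gallery_feat), dim=0)
    norm_feat = F.normalize(all_feat, p=2, dim=1)

    # Detach and move to CPU first so CUDA tensors are handled too.
    all_feat = all_feat.detach().cpu().numpy()
    for _ in range(qe_times):
        all_feat_list = []
        # Cosine similarities between all pairs of L2-normalized features.
        sims = torch.mm(norm_feat, norm_feat.t())
        sims = sims.detach().cpu().numpy()
        for sim in sims:
            # The first qe_k entries are the indices of the qe_k most similar
            # features (normally including the feature itself), most similar first.
            init_rank = np.argpartition(-sim, range(1, qe_k + 1))
            weights = sim[init_rank[:qe_k]].reshape((-1, 1))
            weights = np.power(weights, alpha)
            # Similarity-weighted mean of the unnormalized neighbor features.
            all_feat_list.append(np.mean(all_feat[init_rank[:qe_k], :] * weights, axis=0))
        all_feat = np.stack(all_feat_list, axis=0)
        norm_feat = F.normalize(torch.from_numpy(all_feat), p=2, dim=1)

    query_feat = torch.from_numpy(all_feat[:num_query])
    gallery_feat = torch.from_numpy(all_feat[num_query:])
    return query_feat, gallery_feat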
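To make the per-row loop easier to follow, the same update can be sketched in NumPy alone, vectorized over all rows at once. This is a minimal sketch under stated assumptions: `aqe_numpy` is a hypothetical helper (not part of fastreid), it expects 2-D float arrays, and it uses a full `np.argsort` instead of `np.argpartition`, so ties between equally similar neighbors may be broken differently.

```python
import numpy as np


def aqe_numpy(query_feat, gallery_feat, qe_times=1, qe_k=10, alpha=3.0):
    """Hypothetical NumPy-only sketch of aqe(); not part of fastreid."""
    num_query = query_feat.shape[0]
    # Expand queries and gallery jointly, as aqe() does.
    all_feat = np.concatenate([query_feat, gallery_feat], axis=0).astype(np.float64)
    for _ in range(qe_times):
        # L2-normalize each row (the eps guards against zero rows, like F.normalize).
        norms = np.maximum(np.linalg.norm(all_feat, axis=1, keepdims=True), 1e-12)
        norm_feat = all_feat / norms
        sims = norm_feat @ norm_feat.T  # (N, N) cosine similarities
        # Indices of the qe_k most similar rows for every row, most similar first.
        topk = np.argsort(-sims, axis=1)[:, :qe_k]  # (N, qe_k)
        weights = np.take_along_axis(sims, topk, axis=1) ** alpha  # (N, qe_k)
        # Similarity-weighted mean of the unnormalized neighbor features.
        all_feat = (all_feat[topk] * weights[..., None]).mean(axis=1)  # (N, D)
    return all_feat[:num_query], all_feat[num_query:]


rng = np.random.default_rng(0)
q, g = rng.normal(size=(3, 8)), rng.normal(size=(7, 8))
eq, eg = aqe_numpy(q, g, qe_times=2, qe_k=4)
print(eq.shape, eg.shape)  # (3, 8) (7, 8)
```

With qe_k=1 each feature's only neighbor is itself, with weight 1, so the features come back unchanged; larger qe_k blends in neighbors, with alpha controlling how strongly the most similar ones dominate.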