-
Notifications
You must be signed in to change notification settings - Fork 29
Debias #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Debias #23
Conversation
recstudio/model/mf/dice.py
Outdated
| def _get_query_encoder(self, train_data): | ||
| int = torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0) | ||
| pop = torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0) | ||
| class DICEQueryEncoder(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The int and pop in query encoder could be defined as one Embedding, with dimension 2*self.embed_dim? @pepsi2222
| class DICEQueryEncoder(torch.nn.Module): | |
| torch.nn.Embedding(train_data.num_users, 2*self.embed_dim, padding_idx=0) |
recstudio/model/mf/dice.py
Outdated
| self.pop = pop | ||
| def forward(self, batch): | ||
| return torch.cat((self.int(batch), self.pop(batch)), dim=-1) | ||
| return DICEItemEncoder(int, pop) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly as comment in query encoder above.
XuHwang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job! But some changes should be token according to the comments.
recstudio/model/mf/dice.py
Outdated
| return output | ||
|
|
||
| def _get_sampler(self, train_data): | ||
| class PopularSamplerWithMargin(Sampler): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be discussed. I'm confusing about what the pool means and why negative items are sampled only in pop or unpop items when their size are smaller than pool. @pepsi2222
recstudio/model/mf/dice.py
Outdated
| output['mask'] = mask | ||
| output['score'] = {'pos_int_score': pos_int_score, 'pos_pop_score': pos_pop_score, 'pos_click_score': pos_click_score, | ||
| 'neg_int_score': neg_int_score, 'neg_pop_score': neg_pop_score, 'neg_click_score': neg_click_score} | ||
| output['query'] = {'query_int': query.chunk(2, -1)[0], 'query_pop': query.chunk(2, -1)[1]} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The chunk() operations are duplicated here, query_int and query_pop are defined in line 117.
| output['query'] = {'query_int': query.chunk(2, -1)[0], 'query_pop': query.chunk(2, -1)[1]} | |
| output['query'] = {'query_int': query_int, 'query_pop': query_pop} |
recstudio/model/mf/dice.py
Outdated
| from recstudio.model.mf.bpr import BPR | ||
| from recstudio.model import basemodel, loss_func | ||
| import time | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
u'd better attach the title and url of the paper corresponding to the model with a comment here. @pepsi2222
recstudio/data/advance_dataset.py
Outdated
|
|
||
| def build(self, split_ratio, shuffle=True, split_mode='user_entry', **kwargs): | ||
| def build(self, split_ratio, shuffle=True, split_mode='user_entry', excluding_hist=False, **kwargs): | ||
| self.excluding_hist = excluding_hist |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think excluding_hist is not a good name here. How about return_hist? @pepsi2222 @Xiuchen519
recstudio/model/mf/dice.py
Outdated
|
|
||
| if num_pop_items < self.pool: | ||
| for cnt in range(num_neg): | ||
| idx = torch.randint(num_unpop_items, (1,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not using torch.randint(num_unpop_items, (num_neg,)) instead of for?
| idx = torch.randint(num_unpop_items, (1,)) | |
| idx = torch.randint(num_unpop_items, (num_neg,)) |
recstudio/model/mf/expomf.py
Outdated
| # data to device | ||
| batch = self._to_device(batch, self.device) | ||
| # update latent user/item factors | ||
| a = self._expectation(batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe the method is not required to be overrided. U can add some conditional statement in training_step method to achieve the EM alg as below:
if batch_idx % 2 == 0:
do expectation
else:
do maximization
recstudio/model/mf/pda.py
Outdated
| excluding_hist=self.config.get('excluding_hist', False), | ||
| method=self.config.get('sampling_method', 'none'), return_query=True) | ||
| pos_score = self.score_func(query, pos_item_vec) | ||
| pos_score = pos_item_vec.split([pos_item_vec.shape[-1]-1, 1], dim=-1)[1] ** self.config['gamma'] * pos_score # |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder the line is duplicated, becasue the operation has been done in scorer in line 53. @pepsi2222
If so, the method don't need to be overrided.
recstudio/model/mf/pmf.py
Outdated
| def _get_query_encoder(self, train_data): | ||
| return torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0) | ||
|
|
||
| def _init_parameter(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can define init_method: normal and init_range:0.1 in pmf.yaml without overriding the method.
[feat&fix] add ExpoMF, PDA, DICE