Skip to content

Commit e40775f

Browse files
Joey Yang authored and facebook-github-bot committed
Extend raw_id_tracker to support hash_zch_runtime_meta (#3598)
Summary: See https://fb.workplace.com/groups/1404957374198553/permalink/1610214197006202/ This diff extends `raw_id_tracker` to store `hash_zch_runtime_meta`, which is kept alongside `hash_zch_identities` when present. Note that an mpzch table may have only `hash_zch_identities` without `hash_zch_runtime_meta`, but the reverse is not possible: `hash_zch_runtime_meta` never appears without `hash_zch_identities`. Reviewed By: chouxi. Differential Revision: D88600497
1 parent 7b3effd commit e40775f

File tree

3 files changed

+60
-14
lines changed

3 files changed

+60
-14
lines changed

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def append(
9393
ids: torch.Tensor,
9494
states: Optional[torch.Tensor] = None,
9595
raw_ids: Optional[torch.Tensor] = None,
96+
runtime_meta: Optional[torch.Tensor] = None,
9697
) -> None:
9798
"""
9899
Append a batch of ids and states to the store for a specific table.
@@ -165,6 +166,7 @@ def append(
165166
ids: torch.Tensor,
166167
states: Optional[torch.Tensor] = None,
167168
raw_ids: Optional[torch.Tensor] = None,
169+
runtime_meta: Optional[torch.Tensor] = None,
168170
) -> None:
169171
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
170172
table_fqn_lookup.append(
@@ -284,10 +286,13 @@ def append(
284286
ids: torch.Tensor,
285287
states: Optional[torch.Tensor] = None,
286288
raw_ids: Optional[torch.Tensor] = None,
289+
runtime_meta: Optional[torch.Tensor] = None,
287290
) -> None:
288291
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
289292
table_fqn_lookup.append(
290-
RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
293+
RawIndexedLookup(
294+
batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta
295+
)
291296
)
292297
self.per_fqn_lookups[fqn] = table_fqn_lookup
293298

torchrec/distributed/model_tracker/trackers/raw_id_tracker.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -185,40 +185,74 @@ def record_lookup(
185185
states: torch.Tensor,
186186
emb_module: Optional[nn.Module] = None,
187187
raw_ids: Optional[torch.Tensor] = None,
188+
runtime_meta: Optional[torch.Tensor] = None,
188189
) -> None:
189190
per_table_ids: Dict[str, List[torch.Tensor]] = {}
190191
per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
192+
per_table_runtime_meta: Dict[str, List[torch.Tensor]] = {}
191193

192-
# Skip storing invalid input or raw ids
193-
if (
194-
raw_ids is None
195-
or (kjt.values().numel() == 0)
196-
or not (raw_ids.numel() % kjt.values().numel() == 0)
194+
# Skip storing invalid input or raw ids, note that runtime_meta will only exist if raw_ids exists so we can return early
195+
if raw_ids is None:
196+
logger.debug("Skipping record_lookup: raw_ids is None")
197+
return
198+
199+
if kjt.values().numel() == 0:
200+
logger.debug("Skipping record_lookup: kjt.values() is empty")
201+
return
202+
203+
if not (raw_ids.numel() % kjt.values().numel() == 0):
204+
logger.warning(
205+
f"Skipping record_lookup. Raw_ids has invalid shape {raw_ids.shape}, expected multiple of {kjt.values().numel()}"
206+
)
207+
return
208+
209+
# Skip storing if runtime_meta is provided but has invalid shape
210+
if runtime_meta is not None and not (
211+
runtime_meta.numel() % kjt.values().numel() == 0
197212
):
213+
logger.warning(
214+
f"Skipping record_lookup. Runtime_meta has invalid shape {runtime_meta.shape}, expected multiple of {kjt.values().numel()}"
215+
)
198216
return
199217

200-
embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
218+
raw_ids_2d = raw_ids.view(kjt.values().numel(), -1)
219+
runtime_meta_2d = None
220+
# It is possible that runtime_meta is None while raw_ids is not None so we will proceed
221+
if runtime_meta is not None:
222+
runtime_meta_2d = runtime_meta.view(kjt.values().numel(), -1)
201223

202224
offset: int = 0
203225
for key in kjt.keys():
204226
table_fqn = self.table_to_fqn[key]
205227
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
206-
emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
228+
raw_ids_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
229+
runtime_meta_list: List[torch.Tensor] = per_table_runtime_meta.get(
230+
table_fqn, []
231+
)
207232

208233
ids = kjt[key].values()
209234
ids_list.append(ids)
210-
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
235+
raw_ids_list.append(raw_ids_2d[offset : offset + ids.numel()])
236+
if runtime_meta_2d is not None:
237+
runtime_meta_list.append(runtime_meta_2d[offset : offset + ids.numel()])
211238
offset += ids.numel()
212239

213240
per_table_ids[table_fqn] = ids_list
214-
per_table_raw_ids[table_fqn] = emb_list
241+
per_table_raw_ids[table_fqn] = raw_ids_list
242+
if runtime_meta_2d is not None:
243+
per_table_runtime_meta[table_fqn] = runtime_meta_list
215244

216245
for table_fqn, ids_list in per_table_ids.items():
217246
self.store.append(
218247
batch_idx=self.curr_batch_idx,
219248
fqn=table_fqn,
220249
ids=torch.cat(ids_list),
221250
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
251+
runtime_meta=(
252+
torch.cat(per_table_runtime_meta[table_fqn])
253+
if table_fqn in per_table_runtime_meta
254+
else None
255+
),
222256
)
223257

224258
def _clean_fqn_fn(self, fqn: str) -> str:
@@ -277,8 +311,8 @@ def get_indexed_lookups(
277311
self,
278312
tables: List[str],
279313
consumer: Optional[str] = None,
280-
) -> Dict[str, List[torch.Tensor]]:
281-
raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
314+
) -> Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
315+
result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {}
282316
consumer = consumer or self.DEFAULT_CONSUMER
283317
assert (
284318
consumer in self.per_consumer_batch_idx
@@ -293,17 +327,23 @@ def get_indexed_lookups(
293327

294328
for table in tables:
295329
raw_ids_list = []
330+
runtime_meta_list = []
296331
fqn = self.table_to_fqn[table]
297332
if fqn in indexed_lookups:
298333
for indexed_lookup in indexed_lookups[fqn]:
299334
if indexed_lookup.raw_ids is not None:
300335
raw_ids_list.append(indexed_lookup.raw_ids)
301-
raw_id_per_table[table] = raw_ids_list
336+
if indexed_lookup.runtime_meta is not None:
337+
runtime_meta_list.append(indexed_lookup.runtime_meta)
338+
if (
339+
raw_ids_list
340+
): # if raw_ids doesn't exist runtime_meta will not exist so no need to check for runtime_meta
341+
result[table] = (raw_ids_list, runtime_meta_list)
302342

303343
if self._delete_on_read:
304344
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
305345

306-
return raw_id_per_table
346+
return result
307347

308348
def delete(self, up_to_idx: Optional[int]) -> None:
309349
self.store.delete(up_to_idx)

torchrec/distributed/model_tracker/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class RawIndexedLookup:
3535
batch_idx: int
3636
ids: torch.Tensor
3737
raw_ids: Optional[torch.Tensor] = None
38+
runtime_meta: Optional[torch.Tensor] = None
3839

3940

4041
@dataclass

0 commit comments

Comments
 (0)