@@ -185,40 +185,74 @@ def record_lookup(
185185 states : torch .Tensor ,
186186 emb_module : Optional [nn .Module ] = None ,
187187 raw_ids : Optional [torch .Tensor ] = None ,
188+ runtime_meta : Optional [torch .Tensor ] = None ,
188189 ) -> None :
189190 per_table_ids : Dict [str , List [torch .Tensor ]] = {}
190191 per_table_raw_ids : Dict [str , List [torch .Tensor ]] = {}
192+ per_table_runtime_meta : Dict [str , List [torch .Tensor ]] = {}
191193
192- # Skip storing invalid input or raw ids
193- if (
194- raw_ids is None
195- or (kjt .values ().numel () == 0 )
196- or not (raw_ids .numel () % kjt .values ().numel () == 0 )
194+ # Skip storing invalid input or raw ids, note that runtime_meta will only exist if raw_ids exists so we can return early
195+ if raw_ids is None :
196+ logger .debug ("Skipping record_lookup: raw_ids is None" )
197+ return
198+
199+ if kjt .values ().numel () == 0 :
200+ logger .debug ("Skipping record_lookup: kjt.values() is empty" )
201+ return
202+
203+ if not (raw_ids .numel () % kjt .values ().numel () == 0 ):
204+ logger .warning (
205+ f"Skipping record_lookup. Raw_ids has invalid shape { raw_ids .shape } , expected multiple of { kjt .values ().numel ()} "
206+ )
207+ return
208+
209+ # Skip storing if runtime_meta is provided but has invalid shape
210+ if runtime_meta is not None and not (
211+ runtime_meta .numel () % kjt .values ().numel () == 0
197212 ):
213+ logger .warning (
214+ f"Skipping record_lookup. Runtime_meta has invalid shape { runtime_meta .shape } , expected multiple of { kjt .values ().numel ()} "
215+ )
198216 return
199217
200- embeddings_2d = raw_ids .view (kjt .values ().numel (), - 1 )
218+ raw_ids_2d = raw_ids .view (kjt .values ().numel (), - 1 )
219+ runtime_meta_2d = None
220+ # It is possible that runtime_meta is None while raw_ids is not None so we will proceed
221+ if runtime_meta is not None :
222+ runtime_meta_2d = runtime_meta .view (kjt .values ().numel (), - 1 )
201223
202224 offset : int = 0
203225 for key in kjt .keys ():
204226 table_fqn = self .table_to_fqn [key ]
205227 ids_list : List [torch .Tensor ] = per_table_ids .get (table_fqn , [])
206- emb_list : List [torch .Tensor ] = per_table_raw_ids .get (table_fqn , [])
228+ raw_ids_list : List [torch .Tensor ] = per_table_raw_ids .get (table_fqn , [])
229+ runtime_meta_list : List [torch .Tensor ] = per_table_runtime_meta .get (
230+ table_fqn , []
231+ )
207232
208233 ids = kjt [key ].values ()
209234 ids_list .append (ids )
210- emb_list .append (embeddings_2d [offset : offset + ids .numel ()])
235+ raw_ids_list .append (raw_ids_2d [offset : offset + ids .numel ()])
236+ if runtime_meta_2d is not None :
237+ runtime_meta_list .append (runtime_meta_2d [offset : offset + ids .numel ()])
211238 offset += ids .numel ()
212239
213240 per_table_ids [table_fqn ] = ids_list
214- per_table_raw_ids [table_fqn ] = emb_list
241+ per_table_raw_ids [table_fqn ] = raw_ids_list
242+ if runtime_meta_2d is not None :
243+ per_table_runtime_meta [table_fqn ] = runtime_meta_list
215244
216245 for table_fqn , ids_list in per_table_ids .items ():
217246 self .store .append (
218247 batch_idx = self .curr_batch_idx ,
219248 fqn = table_fqn ,
220249 ids = torch .cat (ids_list ),
221250 raw_ids = torch .cat (per_table_raw_ids [table_fqn ]),
251+ runtime_meta = (
252+ torch .cat (per_table_runtime_meta [table_fqn ])
253+ if table_fqn in per_table_runtime_meta
254+ else None
255+ ),
222256 )
223257
224258 def _clean_fqn_fn (self , fqn : str ) -> str :
@@ -277,8 +311,8 @@ def get_indexed_lookups(
277311 self ,
278312 tables : List [str ],
279313 consumer : Optional [str ] = None ,
280- ) -> Dict [str , List [torch .Tensor ]]:
281- raw_id_per_table : Dict [str , List [torch .Tensor ]] = {}
314+ ) -> Dict [str , Tuple [ List [torch .Tensor ], List [ torch . Tensor ] ]]:
315+ result : Dict [str , Tuple [ List [torch .Tensor ], List [ torch . Tensor ] ]] = {}
282316 consumer = consumer or self .DEFAULT_CONSUMER
283317 assert (
284318 consumer in self .per_consumer_batch_idx
@@ -293,17 +327,23 @@ def get_indexed_lookups(
293327
294328 for table in tables :
295329 raw_ids_list = []
330+ runtime_meta_list = []
296331 fqn = self .table_to_fqn [table ]
297332 if fqn in indexed_lookups :
298333 for indexed_lookup in indexed_lookups [fqn ]:
299334 if indexed_lookup .raw_ids is not None :
300335 raw_ids_list .append (indexed_lookup .raw_ids )
301- raw_id_per_table [table ] = raw_ids_list
336+ if indexed_lookup .runtime_meta is not None :
337+ runtime_meta_list .append (indexed_lookup .runtime_meta )
338+ if (
339+ raw_ids_list
340+ ): # if raw_ids doesn't exist runtime_meta will not exist so no need to check for runtime_meta
341+ result [table ] = (raw_ids_list , runtime_meta_list )
302342
303343 if self ._delete_on_read :
304344 self .store .delete (up_to_idx = min (self .per_consumer_batch_idx .values ()))
305345
306- return raw_id_per_table
346+ return result
307347
308348 def delete (self , up_to_idx : Optional [int ]) -> None :
309349 self .store .delete (up_to_idx )
0 commit comments