@@ -240,102 +240,37 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
240240
241241 self .assertTrue (isinstance (model .module , ShardedFusedEmbeddingBagCollection ))
242242
243- @patch ("torch._utils_internal.justknobs_check" )
244- def test_sharding_ebc_input_validation_enabled (self , mock_jk : Mock ) -> None :
243+ def test_sharding_ebc_input_validation_enabled (self ) -> None :
245244 model = self ._create_sharded_model ()
246245 kjt = KeyedJaggedTensor (
247246 keys = ["my_feature" , "my_feature" ],
248247 values = torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
249248 lengths = torch .tensor ([1 , 2 , 0 , 2 ]),
250249 offsets = torch .tensor ([0 , 1 , 3 , 3 , 5 ]),
251250 )
252- mock_jk .return_value = True
253251
254252 with self .assertRaisesRegex (ValueError , "keys must be unique" ):
255253 model (kjt )
256254
257- # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
258- # This ignores any other calls to justknobs_check() with other inputs
259- # and protects the test from breaking when new JK checks are added.
260- validation_calls = [
261- call
262- for call in mock_jk .call_args_list
263- if len (call [0 ]) > 0
264- and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
265- ]
266- self .assertEqual (
267- 1 ,
268- len (validation_calls ),
269- "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
270- )
271-
272- @patch ("torch._utils_internal.justknobs_check" )
273- def test_sharding_ebc_validate_input_only_once (self , mock_jk : Mock ) -> None :
255+ def test_sharding_ebc_validate_input_only_once (self ) -> None :
274256 model = self ._create_sharded_model ()
275257 kjt = KeyedJaggedTensor (
276258 keys = ["my_feature" ],
277259 values = torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
278260 lengths = torch .tensor ([1 , 2 , 0 , 2 ]),
279261 offsets = torch .tensor ([0 , 1 , 3 , 3 , 5 ]),
280262 ).to (self .device )
281- mock_jk .return_value = True
282263
283264 with self .assertLogs (embeddingbag_logger , level = "INFO" ) as logs :
284265 model (kjt )
285266 model (kjt )
286267 model (kjt )
287268
288- # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
289- # This ignores any other calls to justknobs_check() with other inputs
290- # and protects the test from breaking when new JK checks are added.
291- validation_calls = [
292- call
293- for call in mock_jk .call_args_list
294- if len (call [0 ]) > 0
295- and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
296- ]
297- self .assertEqual (
298- 1 ,
299- len (validation_calls ),
300- "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
301- )
302269 matched_logs = list (
303270 filter (lambda s : "Validating input features..." in s , logs .output )
304271 )
305272 self .assertEqual (1 , len (matched_logs ))
306273
307- @patch ("torch._utils_internal.justknobs_check" )
308- def test_sharding_ebc_input_validation_disabled (self , mock_jk : Mock ) -> None :
309- model = self ._create_sharded_model ()
310- kjt = KeyedJaggedTensor (
311- keys = ["my_feature" , "my_feature" ],
312- values = torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
313- lengths = torch .tensor ([1 , 2 , 0 , 2 ]),
314- offsets = torch .tensor ([0 , 1 , 3 , 3 , 5 ]),
315- ).to (self .device )
316- mock_jk .return_value = False
317-
318- # Without KJT validation, input_dist will not raise exceptions
319- try :
320- model (kjt )
321- except ValueError :
322- self .fail ("Input validation should not be enabled." )
323-
324- # Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
325- # This ignores any other calls to justknobs_check() with other inputs
326- # and protects the test from breaking when new JK checks are added.
327- validation_calls = [
328- call
329- for call in mock_jk .call_args_list
330- if len (call [0 ]) > 0
331- and call [0 ][0 ] == "pytorch/torchrec:enable_kjt_validation"
332- ]
333- self .assertEqual (
334- 1 ,
335- len (validation_calls ),
336- "There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation" ,
337- )
338-
339274 def _create_sharded_model (
340275 self , embedding_dim : int = 128 , num_embeddings : int = 256
341276 ) -> DistributedModelParallel :
0 commit comments