Skip to content

Commit 312f4dd

Browse files
Shuangping Liumeta-codesync[bot]
authored andcommitted
Clean up KJT validator killswitch (#3615)
Summary: Pull Request resolved: #3615 The killswitch `pytorch/torchrec:enable_kjt_validation` has been switched ON for a couple of months and it is working normally, so it should be safe to clean it up. Reviewed By: TroyGarden Differential Revision: D89088119 fbshipit-source-id: 9f339867f192d76224155bceb14619e98f35ff0e
1 parent cc8df00 commit 312f4dd

File tree

2 files changed

+4
-72
lines changed

2 files changed

+4
-72
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,11 +1539,8 @@ def input_dist(
15391539
ctx.inverse_indices = features.inverse_indices_or_none()
15401540

15411541
if self._has_uninitialized_input_dist:
1542-
if torch._utils_internal.justknobs_check(
1543-
"pytorch/torchrec:enable_kjt_validation"
1544-
):
1545-
logger.info("Validating input features...")
1546-
validate_keyed_jagged_tensor(features, self._embedding_bag_configs)
1542+
logger.info("Validating input features...")
1543+
validate_keyed_jagged_tensor(features, self._embedding_bag_configs)
15471544

15481545
self._create_input_dist(features.keys())
15491546
self._has_uninitialized_input_dist = False

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -240,102 +240,37 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
240240

241241
self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
242242

243-
@patch("torch._utils_internal.justknobs_check")
244-
def test_sharding_ebc_input_validation_enabled(self, mock_jk: Mock) -> None:
243+
def test_sharding_ebc_input_validation_enabled(self) -> None:
245244
model = self._create_sharded_model()
246245
kjt = KeyedJaggedTensor(
247246
keys=["my_feature", "my_feature"],
248247
values=torch.tensor([1, 2, 3, 4, 5]),
249248
lengths=torch.tensor([1, 2, 0, 2]),
250249
offsets=torch.tensor([0, 1, 3, 3, 5]),
251250
)
252-
mock_jk.return_value = True
253251

254252
with self.assertRaisesRegex(ValueError, "keys must be unique"):
255253
model(kjt)
256254

257-
# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
258-
# This ignores any other calls to justknobs_check() with other inputs
259-
# and protects the test from breaking when new JK checks are added.
260-
validation_calls = [
261-
call
262-
for call in mock_jk.call_args_list
263-
if len(call[0]) > 0
264-
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
265-
]
266-
self.assertEqual(
267-
1,
268-
len(validation_calls),
269-
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
270-
)
271-
272-
@patch("torch._utils_internal.justknobs_check")
273-
def test_sharding_ebc_validate_input_only_once(self, mock_jk: Mock) -> None:
255+
def test_sharding_ebc_validate_input_only_once(self) -> None:
274256
model = self._create_sharded_model()
275257
kjt = KeyedJaggedTensor(
276258
keys=["my_feature"],
277259
values=torch.tensor([1, 2, 3, 4, 5]),
278260
lengths=torch.tensor([1, 2, 0, 2]),
279261
offsets=torch.tensor([0, 1, 3, 3, 5]),
280262
).to(self.device)
281-
mock_jk.return_value = True
282263

283264
with self.assertLogs(embeddingbag_logger, level="INFO") as logs:
284265
model(kjt)
285266
model(kjt)
286267
model(kjt)
287268

288-
# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
289-
# This ignores any other calls to justknobs_check() with other inputs
290-
# and protects the test from breaking when new JK checks are added.
291-
validation_calls = [
292-
call
293-
for call in mock_jk.call_args_list
294-
if len(call[0]) > 0
295-
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
296-
]
297-
self.assertEqual(
298-
1,
299-
len(validation_calls),
300-
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
301-
)
302269
matched_logs = list(
303270
filter(lambda s: "Validating input features..." in s, logs.output)
304271
)
305272
self.assertEqual(1, len(matched_logs))
306273

307-
@patch("torch._utils_internal.justknobs_check")
308-
def test_sharding_ebc_input_validation_disabled(self, mock_jk: Mock) -> None:
309-
model = self._create_sharded_model()
310-
kjt = KeyedJaggedTensor(
311-
keys=["my_feature", "my_feature"],
312-
values=torch.tensor([1, 2, 3, 4, 5]),
313-
lengths=torch.tensor([1, 2, 0, 2]),
314-
offsets=torch.tensor([0, 1, 3, 3, 5]),
315-
).to(self.device)
316-
mock_jk.return_value = False
317-
318-
# Without KJT validation, input_dist will not raise exceptions
319-
try:
320-
model(kjt)
321-
except ValueError:
322-
self.fail("Input validation should not be enabled.")
323-
324-
# Count only calls with the input "pytorch/torchrec:enable_kjt_validation"
325-
# This ignores any other calls to justknobs_check() with other inputs
326-
# and protects the test from breaking when new JK checks are added.
327-
validation_calls = [
328-
call
329-
for call in mock_jk.call_args_list
330-
if len(call[0]) > 0
331-
and call[0][0] == "pytorch/torchrec:enable_kjt_validation"
332-
]
333-
self.assertEqual(
334-
1,
335-
len(validation_calls),
336-
"There should be exactly one call to JK with pytorch/torchrec:enable_kjt_validation",
337-
)
338-
339274
def _create_sharded_model(
340275
self, embedding_dim: int = 128, num_embeddings: int = 256
341276
) -> DistributedModelParallel:

0 commit comments

Comments
 (0)