diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 4872dc171b..987497cef5 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -259,6 +259,15 @@ public Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter(
         // 0, 1, 2
         int idx = (int) (taskAttemptId % (serverSize - 1)) + 1;
         candidate = servers.get(idx);
+      } else {
+        // fall back to a non-excluded server if none is available in load-balanced mode
+        servers =
+            replicaServerEntry.getValue().stream()
+                .filter(x -> !excludedServerToReplacements.containsKey(x.getId()))
+                .collect(Collectors.toList());
+        serverSize = servers.size();
+        int idx = (int) (taskAttemptId % serverSize);
+        candidate = servers.get(idx);
       }
     }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index df13e0f390..963000e8c0 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -124,27 +124,26 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
     String taskId = event.getTaskId();
     List<ShuffleBlockInfo> blocks = event.getShuffleDataInfoList();
     List<ShuffleBlockInfo> validBlocks = filterOutStaleAssignmentBlocks(taskId, blocks);
-    if (CollectionUtils.isEmpty(validBlocks)) {
-      return 0L;
-    }
     SendShuffleDataResult result = null;
     try {
-      result =
-          shuffleWriteClient.sendShuffleData(
-              rssAppId,
-              event.getStageAttemptNumber(),
-              validBlocks,
-              () -> !isValidTask(taskId));
-      // completionCallback should be executed before updating taskToSuccessBlockIds
-      // structure to avoid side effect
-      Set<Long> succeedBlockIds = getSucceedBlockIds(result);
-      for (ShuffleBlockInfo block : validBlocks) {
-        block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+      if (CollectionUtils.isNotEmpty(validBlocks)) {
+        result =
+            shuffleWriteClient.sendShuffleData(
+                rssAppId,
+                event.getStageAttemptNumber(),
+                validBlocks,
+                () -> !isValidTask(taskId));
+        // completionCallback should be executed before updating taskToSuccessBlockIds
+        // structure to avoid side effect
+        Set<Long> succeedBlockIds = getSucceedBlockIds(result);
+        for (ShuffleBlockInfo block : validBlocks) {
+          block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+        }
+        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
+        putFailedBlockSendTracker(
+            taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
       }
-      putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
-      putFailedBlockSendTracker(
-          taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
     } finally {
       WriteBufferManager bufferManager = event.getBufferManager();
       if (bufferManager != null && result != null) {
@@ -159,6 +158,9 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
         runnable.run();
       }
     }
+    if (CollectionUtils.isEmpty(validBlocks)) {
+      return 0L;
+    }
     Set<Long> succeedBlockIds = getSucceedBlockIds(result);
     return validBlocks.stream()
         .filter(x -> succeedBlockIds.contains(x.getBlockId()))
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 63fac0c12a..2faac350ba 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -91,10 +91,11 @@ private boolean hasBeenLoadBalanced(int partitionId) {
   public boolean tryNextServerForSplitPartition(
       int partitionId, List<ShuffleServerInfo> exclusiveServers) {
     if (hasBeenLoadBalanced(partitionId)) {
-      Set<ShuffleServerInfo> servers =
-          this.exclusiveServersForPartition.computeIfAbsent(
-              partitionId, k -> new ConcurrentSkipListSet<>());
-      servers.addAll(exclusiveServers);
+      // update the exclusive servers
+      this.exclusiveServersForPartition
+          .computeIfAbsent(partitionId, k -> new ConcurrentSkipListSet<>())
+          .addAll(exclusiveServers);
+      // refresh the assignment because the exclusive servers changed above
       update(this.handle);
       return true;
     }
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
index c380e3ee78..57bd256414 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/ReassignExecutor.java
@@ -33,6 +33,7 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
@@ -306,19 +307,22 @@ private void reassignOnPartitionNeedSplit(FailedBlockSendTracker failedTracker)
       }
     }
 
+    String readableMessage = readableResult(fastSwitchList);
     if (reassignList.isEmpty()) {
-      LOG.info(
-          "[partition-split] All fast switch to another servers successfully for taskId[{}]. list: {}",
-          taskId,
-          readableResult(fastSwitchList));
-      return;
-    } else {
-      if (!fastSwitchList.isEmpty()) {
+      if (StringUtils.isNotEmpty(readableMessage)) {
         LOG.info(
-            "[partition-split] Partial fast switch to another servers for taskId[{}]. list: {}",
+            "[partition-split] All partitions fast-switched successfully for taskId[{}]. list: {}",
             taskId,
-            readableResult(fastSwitchList));
+            readableMessage);
       }
+      return;
+    }
+
+    if (StringUtils.isNotEmpty(readableMessage)) {
+      LOG.info(
+          "[partition-split] Partial partitions fast-switched for taskId[{}]. list: {}",
+          taskId,
+          readableMessage);
     }
 
     @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
@@ -385,6 +389,7 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
     List<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
     Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks =
         blocks.stream()
+            .filter(x -> x.getStatusCode() != null)
             .collect(Collectors.groupingBy(d -> d.getShuffleBlockInfo().getPartitionId()));
     Map<Integer, List<ShuffleServerInfo>> failurePartitionToServers = new HashMap<>();
@@ -429,8 +434,12 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
           readableResult(constructUpdateList(failurePartitionToServers)));
     }
 
+    int staleCnt = 0;
     for (TrackingBlockStatus blockStatus : blocks) {
       ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+      if (blockStatus.getStatusCode() == null) {
+        staleCnt += 1;
+      }
       // todo: getting the replacement should support multi replica.
       List<ShuffleServerInfo> servers = taskAttemptAssignment.retrieve(block.getPartitionId());
       // Gets the first replica for this partition for now.
@@ -459,8 +468,10 @@ private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
     }
     resendBlocksFunction.accept(resendCandidates);
     LOG.info(
-        "[partition-reassign] All {} blocks have been resent to queue successfully in {} ms.",
+        "[partition-reassign] {} blocks (failed/stale: {}/{}) have been resent to queue successfully in {} ms.",
         blocks.size(),
+        blocks.size() - staleCnt,
+        staleCnt,
         System.currentTimeMillis() - start);
   }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
index cf75152e75..63dc247d37 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfoTest.java
@@ -276,4 +276,46 @@ public void testUpdateAssignmentOnPartitionSplit() {
     // All the servers that were selected as writers are available as readers
     assertEquals(6, assignment.get(1).size());
   }
+
+  @Test
+  public void testLoadBalanceFallbackToNonExcludedServers() {
+    // prepare servers
+    ShuffleServerInfo a = createFakeServerInfo("a");
+    ShuffleServerInfo b = createFakeServerInfo("b");
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(1, Arrays.asList(a, b));
+
+    // create a handle with LOAD_BALANCE mode
+    MutableShuffleHandleInfo handleInfo =
+        new MutableShuffleHandleInfo(
+            1,
+            partitionToServers,
+            new RemoteStorageInfo(""),
+            org.apache.uniffle.common.PartitionSplitMode.LOAD_BALANCE);
+
+    int partitionId = 1;
+
+    // mark the partition as split by excluding server "a"
+    Set<ShuffleServerInfo> replacements = Sets.newHashSet(createFakeServerInfo("c"));
+    handleInfo.updateAssignmentOnPartitionSplit(partitionId, "a", replacements);
+
+    // also make sure excludedServerToReplacements contains "b",
+    // so that the first filtering (excluding problem nodes) removes all servers
+    handleInfo.updateAssignment(partitionId, "b", Sets.newHashSet(createFakeServerInfo("d")));
+
+    // now call the writer assignment
+    Map<Integer, List<ShuffleServerInfo>> available =
+        handleInfo.getAvailablePartitionServersForWriter(null);
+
+    // the fallback branch should be triggered and still return valid candidates
+    // ensure both remaining candidates are assigned for the partition
+    assertTrue(available.containsKey(partitionId));
+    assertEquals(2, available.get(partitionId).size());
+
+    // the candidate must be one of the original servers or appended replacements, rather than
+    // always the last one
+    ShuffleServerInfo candidate = available.get(partitionId).get(0);
+    assertEquals("c", candidate.getId());
+  }
 }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index eb357d9da9..720bad4e39 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -25,6 +25,7 @@
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
 
 import com.google.common.collect.Maps;
@@ -130,6 +131,70 @@ public void testFilterOutStaleAssignmentBlocks() {
     assertEquals(3, failedBlockIds.stream().findFirst().get());
   }
 
+  /**
+   * Test that when all blocks in a batch are stale (filtered out by fast-switch), the
+   * processedCallbackChain is still executed. Before the fix, if all blocks were stale, the early
+   * return skipped the finally block, causing the callback (which notifies checkBlockSendResult
+   * via finishEventQueue) to never run. This led to checkBlockSendResult blocking indefinitely on
+   * poll(), unable to call reassign() to resend the stale blocks, and ultimately timing out.
+   */
+  @Test
+  public void testProcessedCallbackChainExecutedWhenAllBlocksAreStale()
+      throws ExecutionException, InterruptedException {
+    FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+    Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+    Set<String> failedTaskIds = new HashSet<>();
+
+    RssConf rssConf = new RssConf();
+    rssConf.set(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED, true);
+    rssConf.set(RssSparkConfig.RSS_PARTITION_REASSIGN_STALE_ASSIGNMENT_FAST_SWITCH_ENABLED, true);
+    DataPusher dataPusher =
+        new DataPusher(
+            shuffleWriteClient,
+            taskToSuccessBlockIds,
+            taskToFailedBlockSendTracker,
+            failedTaskIds,
+            1,
+            2,
+            rssConf);
+    dataPusher.setRssAppId("testCallbackWhenAllStale");
+
+    String taskId = "taskId1";
+    List<ShuffleServerInfo> server1 =
+        Collections.singletonList(new ShuffleServerInfo("0", "localhost", 1234));
+    // Create a stale block: isStaleAssignment() returns true because the
+    // partitionAssignmentRetrieveFunc returns an empty list (different from the block's servers).
+    ShuffleBlockInfo staleBlock =
+        new ShuffleBlockInfo(
+            1, 1, 10, 1, 1, new byte[1], server1, 1, 100, 1, integer -> Collections.emptyList());
+
+    // Track whether processedCallbackChain is invoked
+    AtomicBoolean callbackExecuted = new AtomicBoolean(false);
+    AddBlockEvent event = new AddBlockEvent(taskId, Arrays.asList(staleBlock));
+    event.addCallback(() -> callbackExecuted.set(true));
+
+    CompletableFuture<Long> future = dataPusher.send(event);
+    long result = future.get();
+
+    // The block is stale, so no data is actually sent (0 bytes freed)
+    assertEquals(0L, result);
+
+    // The stale block should be tracked in the FailedBlockSendTracker
+    Set<Long> failedBlockIds = taskToFailedBlockSendTracker.get(taskId).getFailedBlockIds();
+    assertEquals(1, failedBlockIds.size());
+    assertEquals(10, failedBlockIds.stream().findFirst().get());
+
+    // The processedCallbackChain MUST be executed even when all blocks are stale.
+    // Before the fix, this assertion would fail because the early return (return 0L)
+    // was placed before the try-finally that executes the callback chain.
+    assertTrue(
+        callbackExecuted.get(),
+        "processedCallbackChain must be executed even when all blocks are stale, "
+            + "otherwise checkBlockSendResult will block on finishEventQueue.poll() indefinitely");
+  }
+
   @Test
   public void testSendData() throws ExecutionException, InterruptedException {
     FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 60f3ef46f3..03d92e8a30 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -81,6 +81,7 @@
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.shuffle.BlockStats;
 import org.apache.uniffle.shuffle.ReassignExecutor;
 import org.apache.uniffle.shuffle.ShuffleWriteTaskStats;
@@ -518,8 +519,7 @@ protected List<CompletableFuture<Long>> postBlockEvent(
     }
     event.addCallback(
         () -> {
-          boolean ret = finishEventQueue.add(new Object());
-          if (!ret) {
+          if (!finishEventQueue.add(new Object())) {
             LOG.error("Add event " + event + " to finishEventQueue fail");
           }
         });
@@ -572,17 +572,26 @@ protected void checkBlockSendResult(Set<Long> blockIds) {
       }
       Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
       if (currentAckValue != 0 || blockIds.size() != successBlockIds.size()) {
-        int failedBlockCount = blockIds.size() - successBlockIds.size();
-        String errorMsg =
-            "Timeout: Task["
-                + taskId
-                + "] failed because "
-                + failedBlockCount
-                + " blocks can't be sent to shuffle server in "
-                + sendCheckTimeout
-                + " ms.";
-        LOG.error(errorMsg);
-        throw new RssWaitFailedException(errorMsg);
+        int missing = blockIds.size() - successBlockIds.size();
+        int failed =
+            Optional.ofNullable(shuffleManager.getFailedBlockIds(taskId)).map(Set::size).orElse(0);
+        String message =
+            String.format(
+                "TaskId[%s] failed because %d blocks (failed: %d) can't be sent to shuffle server in %d ms",
+                taskId, missing, failed, sendCheckTimeout);
+
+        // log the detailed error message with the affected partitions
+        Set<Long> missingBlockIds = new HashSet<>(blockIds);
+        missingBlockIds.removeAll(successBlockIds);
+        BlockIdLayout layout = BlockIdLayout.from(rssConf);
+        LOG.error(
+            "{}, includes partitions: {}",
+            message,
+            missingBlockIds.stream()
+                .map(x -> layout.getPartitionId(x))
+                .collect(Collectors.toSet()));
+
+        throw new RssWaitFailedException(message);
       }
     } finally {
       if (interrupted) {
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 2e16f454c5..a054b9b2ac 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -611,7 +611,7 @@ public void checkBlockSendResultTest() {
         assertThrows(
             RuntimeException.class,
             () -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
-    assertTrue(e2.getMessage().startsWith("Timeout:"));
+    assertTrue(e2.getMessage().contains("failed because"));
     successBlocks.clear();
     // case 3: partial blocks failed to be sent, a RuntimeException will be thrown