diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 6b21abed540..9668ece0f98 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -696,6 +696,13 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se get(WORKER_FLUSH_REUSE_COPY_BUFFER_ENABLED) def workerDfsReplicationFactor: Int = get(WORKER_DFS_REPLICATION_FACTOR) + def workerReserveSlotsIoThreadPoolSize: Int = { + val configured = get(RESERVE_SLOTS_IO_THREAD_POOL_SIZE); + // reserve slots creates files, locally or on DFS, parallelism can be high + // compared to the number of CPUs + if (configured == 0) Runtime.getRuntime.availableProcessors() * 8 + else configured + } def clusterName: String = get(CLUSTER_NAME) @@ -6801,6 +6808,17 @@ object CelebornConf extends Logging { .intConf .createWithDefault(2) + val RESERVE_SLOTS_IO_THREAD_POOL_SIZE: ConfigEntry[Int] = + buildConf("celeborn.worker.reserve.slots.io.threads") + .categories("worker") + .version("0.7.0") + .doc("The number of threads used to create PartitionDataWriter in parallel in handleReserveSlots.") + .intConf + .checkValue( + v => v >= 0, + "The number of threads must be positive or zero. Setting to zero lets the worker compute the optimal size automatically") + .createWithDefault(1) + val CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.client.shuffleDataLostOnUnknownWorker.enabled") .categories("client") diff --git a/common/src/test/scala/org/apache/celeborn/common/CelebornConfSuite.scala b/common/src/test/scala/org/apache/celeborn/common/CelebornConfSuite.scala index ff2584ac652..6630c4a3fd3 100644 --- a/common/src/test/scala/org/apache/celeborn/common/CelebornConfSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/CelebornConfSuite.scala @@ -483,4 +483,18 @@ class CelebornConfSuite extends CelebornFunSuite { } } + test("test workerReserveSlotsIoThreadPoolSize") { + // Test default value + val conf1 = new CelebornConf() + assert(conf1.workerReserveSlotsIoThreadPoolSize == 1) + + // Test configured value + val conf2 = new CelebornConf().set(RESERVE_SLOTS_IO_THREAD_POOL_SIZE.key, "10") + assert(conf2.workerReserveSlotsIoThreadPoolSize == 10) + + // Test configured value with 0 + val conf3 = new CelebornConf().set(RESERVE_SLOTS_IO_THREAD_POOL_SIZE.key, "0") + assert(conf3.workerReserveSlotsIoThreadPoolSize == Runtime.getRuntime.availableProcessors() * 8) + } + } diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index b34c5bd6652..78b2f8b4461 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -177,6 +177,7 @@ license: | | celeborn.worker.replicate.port | 0 | false | Server port for Worker to receive replicate data request from other Workers. | 0.2.0 | | | celeborn.worker.replicate.randomConnection.enabled | true | false | Whether worker will create random connection to peer when replicate data. When false, worker tend to reuse the same cached TransportClient to a specific replicate worker; when true, worker tend to use different cached TransportClient. Netty will use the same thread to serve the same connection, so with more connections replicate server can leverage more netty threads | 0.2.1 | | | celeborn.worker.replicate.threads | 64 | false | Thread number of worker to replicate shuffle data. 
| 0.2.0 | | +| celeborn.worker.reserve.slots.io.threads | 1 | false | The number of threads used to create PartitionDataWriter in parallel in handleReserveSlots. | 0.7.0 | | | celeborn.worker.reuse.hdfs.outputStream.enabled | false | false | Whether to enable reuse output stream on hdfs. | 0.7.0 | | | celeborn.worker.rpc.port | 0 | false | Server port for Worker to receive RPC request. | 0.2.0 | | | celeborn.worker.shuffle.partitionSplit.enabled | true | false | enable the partition split on worker side | 0.3.0 | celeborn.worker.partition.split.enabled | diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 25ce6532cfb..5ccdd47a81f 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -19,12 +19,13 @@ package org.apache.celeborn.service.deploy.worker import java.io.IOException import java.util.{ArrayList => jArrayList, HashMap => jHashMap, List => jList, Set => jSet} -import java.util.concurrent._ +import java.util.concurrent.{CopyOnWriteArrayList, _} import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray, AtomicReference} -import java.util.function.BiFunction +import java.util.function.{BiConsumer, BiFunction, Consumer, Supplier} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.util.Try import io.netty.util.{HashedWheelTimer, Timeout, TimerTask} import org.roaringbitmap.RoaringBitmap @@ -62,6 +63,7 @@ private[deploy] class Controller( var partitionLocationInfo: WorkerPartitionLocationInfo = _ var timer: HashedWheelTimer = _ var commitThreadPool: ThreadPoolExecutor = _ + var reserveSlotsThreadPool: ThreadPoolExecutor = _ var commitFinishedChecker: ScheduledExecutorService = _ var asyncReplyPool: ScheduledExecutorService = _ val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate @@ -86,6 +88,10 @@ private[deploy] class Controller( asyncReplyPool = worker.asyncReplyPool shutdown = worker.shutdown + reserveSlotsThreadPool = + Executors.newFixedThreadPool(conf.workerReserveSlotsIoThreadPoolSize).asInstanceOf[ + ThreadPoolExecutor] + commitFinishedChecker = worker.commitFinishedChecker commitFinishedChecker.scheduleWithFixedDelay( new Runnable { @@ -193,111 +199,145 @@ private[deploy] class Controller( context.reply(ReserveSlotsResponse(StatusCode.NO_AVAILABLE_WORKING_DIR, msg)) return } - val primaryLocs = new jArrayList[PartitionLocation]() - try { - for (ind <- 0 until requestPrimaryLocs.size()) { - var location = partitionLocationInfo.getPrimaryLocation( - shuffleKey, - requestPrimaryLocs.get(ind).getUniqueId) - if (location == null) { - location = requestPrimaryLocs.get(ind) - val writer = storageManager.createPartitionDataWriter( - applicationId, - shuffleId, - location, - splitThreshold, - splitMode, - partitionType, - rangeReadFilter, - userIdentifier, - partitionSplitEnabled, - isSegmentGranularityVisible) - primaryLocs.add(new WorkingPartition(location, writer)) - } else { - primaryLocs.add(location) - } - } - } catch { - case e: Exception => - logError(s"CreateWriter for $shuffleKey failed.", e) + + def collectResults( + tasks: ArrayBuffer[CompletableFuture[PartitionLocation]], + createdWriters: CopyOnWriteArrayList[PartitionDataWriter], + startTime: Long) = { + val primaryFuture = CompletableFuture.allOf(tasks.toSeq: _*) + .whenComplete(new 
BiConsumer[Void, Throwable] { + override def accept(ignore: Void, error: Throwable): Unit = { + if (error != null) { + createdWriters.forEach(new Consumer[PartitionDataWriter] { + override def accept(fileWriter: PartitionDataWriter) { + fileWriter.destroy(new IOException( + s"Destroy FileWriter $fileWriter caused by " + + s"reserving slots failed for $shuffleKey.", + error)) + } + }) + } else { + val timeToReserveLocations = System.currentTimeMillis() - startTime; + logInfo( + s"Reserved ${tasks.size} slots for $shuffleKey in $timeToReserveLocations ms (with ${conf.workerReserveSlotsIoThreadPoolSize} threads)") + } + createdWriters.clear() + } + }) + primaryFuture } - if (primaryLocs.size() < requestPrimaryLocs.size()) { - val msg = s"Not all primary partition satisfied for $shuffleKey" - logWarning(s"[handleReserveSlots] $msg, will destroy writers.") - primaryLocs.asScala.foreach { partitionLocation => - val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter - fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " + - s"reserving slots failed for $shuffleKey.")) - } - context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg)) - return + + logInfo(s"Reserving ${requestPrimaryLocs.size()} primary slots for $shuffleKey") + val startReservePrimaryLocs = System.currentTimeMillis + val primaryCfTasks = ArrayBuffer[CompletableFuture[PartitionLocation]]() + val primaryWriters: CopyOnWriteArrayList[PartitionDataWriter] = + new CopyOnWriteArrayList[PartitionDataWriter] + (0 until requestPrimaryLocs.size()).foreach { ind => + primaryCfTasks.append(CompletableFuture.supplyAsync[PartitionLocation]( + new Supplier[PartitionLocation] { + override def get(): PartitionLocation = { + var location: PartitionLocation = partitionLocationInfo.getPrimaryLocation( + shuffleKey, + requestPrimaryLocs.get(ind).getUniqueId) + if (location == null) { + location = requestPrimaryLocs.get(ind) + val writer = storageManager.createPartitionDataWriter( + applicationId, + shuffleId, + location, + splitThreshold, + splitMode, + partitionType, + rangeReadFilter, + userIdentifier, + partitionSplitEnabled, + isSegmentGranularityVisible) + primaryWriters.add(writer) + new WorkingPartition(location, writer) + } else { + location + } + } + }, + reserveSlotsThreadPool)) } - val replicaLocs = new jArrayList[PartitionLocation]() - try { - for (ind <- 0 until requestReplicaLocs.size()) { - var location = - partitionLocationInfo.getReplicaLocation( - shuffleKey, - requestReplicaLocs.get(ind).getUniqueId) - if (location == null) { - location = requestReplicaLocs.get(ind) - val writer = storageManager.createPartitionDataWriter( - applicationId, - shuffleId, - location, - splitThreshold, - splitMode, - partitionType, - rangeReadFilter, - userIdentifier, - partitionSplitEnabled, - isSegmentGranularityVisible) - replicaLocs.add(new WorkingPartition(location, writer)) + logInfo(s"Reserving ${requestReplicaLocs.size()} replica slots for $shuffleKey") + val startReserveReplicLocs = System.currentTimeMillis + val replicaCfTasks = ArrayBuffer[CompletableFuture[PartitionLocation]]() + val replicaWriters: CopyOnWriteArrayList[PartitionDataWriter] = + new CopyOnWriteArrayList[PartitionDataWriter] + (0 until requestReplicaLocs.size()).foreach { ind => + replicaCfTasks.append(CompletableFuture.supplyAsync[PartitionLocation]( + new Supplier[PartitionLocation] { + override def get(): PartitionLocation = { + var location = + partitionLocationInfo.getReplicaLocation( + shuffleKey, + 
requestReplicaLocs.get(ind).getUniqueId) + if (location == null) { + location = requestReplicaLocs.get(ind) + val writer = storageManager.createPartitionDataWriter( + applicationId, + shuffleId, + location, + splitThreshold, + splitMode, + partitionType, + rangeReadFilter, + userIdentifier, + partitionSplitEnabled, + isSegmentGranularityVisible) + replicaWriters.add(writer) + new WorkingPartition(location, writer) + } else { + location + } + } + }, + reserveSlotsThreadPool)) + } + + // collect results + val primaryFuture: CompletableFuture[Void] = + collectResults(primaryCfTasks, primaryWriters, startReservePrimaryLocs) + val replicaFuture: CompletableFuture[Void] = + collectResults(replicaCfTasks, replicaWriters, startReserveReplicLocs) + + val future = CompletableFuture.allOf(primaryFuture, replicaFuture) + future.whenComplete(new BiConsumer[Void, Throwable] { + override def accept(ignore: Void, error: Throwable): Unit = { + if (error != null) { + logError(s"An error occurred while reserving slots for $shuffleKey", error) + val msg = s"An error occurred while reserving slots for $shuffleKey: $error"; + context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg)) } else { - replicaLocs.add(location) + val primaryLocs = primaryCfTasks.map(cf => cf.join()).asJava + val replicaLocs = replicaCfTasks.map(cf => cf.join()).asJava + // reserve success, update status + partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaryLocs) + partitionLocationInfo.addReplicaPartitions(shuffleKey, replicaLocs) + shufflePartitionType.put(shuffleKey, partitionType) + shufflePushDataTimeout.put( + shuffleKey, + if (pushDataTimeout <= 0) defaultPushdataTimeout else pushDataTimeout) + workerInfo.allocateSlots( + shuffleKey, + Utils.getSlotsPerDisk(requestPrimaryLocs, requestReplicaLocs)) + workerSource.incCounter( + WorkerSource.SLOTS_ALLOCATED, + primaryLocs.size() + replicaLocs.size()) + + logInfo(s"Reserved ${primaryLocs.size()} primary location " + + s"${primaryLocs.asScala.map(_.getUniqueId).mkString(",")} and ${replicaLocs.size()} replica location " + + s"${replicaLocs.asScala.map(_.getUniqueId).mkString(",")} for $shuffleKey ") + if (log.isDebugEnabled()) { + logDebug(s"primary: $primaryLocs\nreplica: $replicaLocs.") + } + context.reply(ReserveSlotsResponse(StatusCode.SUCCESS)) } } - } catch { - case e: Exception => - logError(s"CreateWriter for $shuffleKey failed.", e) - } - if (replicaLocs.size() < requestReplicaLocs.size()) { - val msg = s"Not all replica partition satisfied for $shuffleKey" - logWarning(s"[handleReserveSlots] $msg, destroy writers.") - primaryLocs.asScala.foreach { partitionLocation => - val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter - fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " + - s"reserving slots failed for $shuffleKey.")) - } - replicaLocs.asScala.foreach { partitionLocation => - val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter - fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " + - s"reserving slots failed for $shuffleKey.")) - } - context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg)) - return - } - - // reserve success, update status - partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaryLocs) - partitionLocationInfo.addReplicaPartitions(shuffleKey, replicaLocs) - shufflePartitionType.put(shuffleKey, partitionType) - shufflePushDataTimeout.put( - shuffleKey, - if (pushDataTimeout <= 0) 
defaultPushdataTimeout else pushDataTimeout) - workerInfo.allocateSlots( - shuffleKey, - Utils.getSlotsPerDisk(requestPrimaryLocs, requestReplicaLocs)) - workerSource.incCounter(WorkerSource.SLOTS_ALLOCATED, primaryLocs.size() + replicaLocs.size()) - - logInfo(s"Reserved ${primaryLocs.size()} primary location " + - s"${primaryLocs.asScala.map(_.getUniqueId).mkString(",")} and ${replicaLocs.size()} replica location " + - s"${replicaLocs.asScala.map(_.getUniqueId).mkString(",")} for $shuffleKey ") - if (log.isDebugEnabled()) { - logDebug(s"primary: $primaryLocs\nreplica: $replicaLocs.") - } - context.reply(ReserveSlotsResponse(StatusCode.SUCCESS)) + }) } private def commitFiles( @@ -828,4 +868,24 @@ private[deploy] class Controller( mapIdx += 1 } } + + override def onStop(): Unit = { + if (reserveSlotsThreadPool != null) { + reserveSlotsThreadPool.shutdown() + try { + if (!reserveSlotsThreadPool.awaitTermination( + conf.workerGracefulShutdownTimeoutMs, + TimeUnit.MILLISECONDS)) { + logWarning("ReserveSlotsThreadPool shutdown timeout, forcing shutdown.") + reserveSlotsThreadPool.shutdownNow() + } + } catch { + case e: InterruptedException => + logWarning("ReserveSlotsThreadPool shutdown interrupted, forcing shutdown.", e) + reserveSlotsThreadPool.shutdownNow() + Thread.currentThread().interrupt() + } + } + super.onStop() + } }
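
Note on the reservation fan-out introduced above: the sketch below is a minimal, self-contained Scala illustration of the pattern, not the Controller code itself. `FakeLocation`, `FakeWriter`, and the inline writer creation are hypothetical stand-ins for `PartitionLocation`, `PartitionDataWriter`, and `StorageManager.createPartitionDataWriter`. It sizes a fixed pool the same way `workerReserveSlotsIoThreadPoolSize` does (auto-sizing to `availableProcessors * 8` when `celeborn.worker.reserve.slots.io.threads` is 0), submits one `CompletableFuture` per requested location, and uses `CompletableFuture.allOf(...).whenComplete(...)` to either destroy every writer created so far on failure or report success.

import java.util.concurrent.{CompletableFuture, CopyOnWriteArrayList, Executors}
import java.util.function.{BiConsumer, Consumer, Supplier}

import scala.collection.mutable.ArrayBuffer

object ReserveSlotsSketch {
  // Hypothetical stand-ins for PartitionLocation and PartitionDataWriter.
  final case class FakeLocation(uniqueId: String)
  final class FakeWriter(val location: FakeLocation) {
    def destroy(cause: Exception): Unit =
      println(s"destroy writer ${location.uniqueId}: ${cause.getMessage}")
  }

  def main(args: Array[String]): Unit = {
    // Mirrors workerReserveSlotsIoThreadPoolSize: 0 means auto-size; writer
    // creation is I/O bound, so far more threads than cores is reasonable.
    val configured = 0
    val poolSize =
      if (configured == 0) Runtime.getRuntime.availableProcessors() * 8 else configured
    val pool = Executors.newFixedThreadPool(poolSize)

    val requested = (0 until 4).map(i => FakeLocation(s"0-$i"))
    val createdWriters = new CopyOnWriteArrayList[FakeWriter]()
    val tasks = ArrayBuffer[CompletableFuture[FakeLocation]]()

    // One asynchronous task per requested location, as in handleReserveSlots.
    requested.foreach { loc =>
      tasks.append(CompletableFuture.supplyAsync[FakeLocation](
        new Supplier[FakeLocation] {
          override def get(): FakeLocation = {
            val writer = new FakeWriter(loc) // stand-in for createPartitionDataWriter
            createdWriters.add(writer)
            loc
          }
        },
        pool))
    }

    // Collect the whole batch: destroy every writer created so far if any task
    // failed, otherwise report success; then release the bookkeeping list.
    CompletableFuture.allOf(tasks.toSeq: _*)
      .whenComplete(new BiConsumer[Void, Throwable] {
        override def accept(ignore: Void, error: Throwable): Unit = {
          if (error != null) {
            createdWriters.forEach(new Consumer[FakeWriter] {
              override def accept(writer: FakeWriter): Unit =
                writer.destroy(new Exception("reserving slots failed", error))
            })
          } else {
            println(s"reserved ${tasks.size} slots with $poolSize threads")
          }
          createdWriters.clear()
        }
      })
      .join()

    pool.shutdown()
  }
}

The point of the design, as the patch applies it: slow file or DFS stream creation moves off the RPC handler onto the dedicated pool, the reply is sent from the whenComplete callback, and a single failed location tears down the whole batch, preserving the all-or-nothing semantics of handleReserveSlots.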