@@ -696,6 +696,13 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
get(WORKER_FLUSH_REUSE_COPY_BUFFER_ENABLED)
def workerDfsReplicationFactor: Int =
get(WORKER_DFS_REPLICATION_FACTOR)
def workerReserveSlotsIoThreadPoolSize: Int = {
val configured = get(RESERVE_SLOTS_IO_THREAD_POOL_SIZE)
// Reserving slots creates files, locally or on DFS, so the useful parallelism can be
// much higher than the number of CPUs.
if (configured == 0) Runtime.getRuntime.availableProcessors() * 8
else configured
}

def clusterName: String = get(CLUSTER_NAME)

@@ -6801,6 +6808,17 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(2)

val RESERVE_SLOTS_IO_THREAD_POOL_SIZE: ConfigEntry[Int] =
buildConf("celeborn.worker.reserve.slots.io.threads")
Review comment (Member): dot for namespaces, camelCase for words.

Suggested change
buildConf("celeborn.worker.reserve.slots.io.threads")
buildConf("celeborn.worker.reserveSlots.io.threads")

.categories("worker")
.version("0.7.0")
.doc("The number of threads used to create PartitionDataWriter in parallel in handleReserveSlots.")
.intConf
.checkValue(
v => v >= 0,
"The number of threads must be positive or zero. Setting to zero lets the worker compute the optimal size automatically")
.createWithDefault(1)

val CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.shuffleDataLostOnUnknownWorker.enabled")
.categories("client")
@@ -483,4 +483,18 @@ class CelebornConfSuite extends CelebornFunSuite {
}
}

test("test workerReserveSlotsIoThreadPoolSize") {
// Test default value
val conf1 = new CelebornConf()
assert(conf1.workerReserveSlotsIoThreadPoolSize == 1)

// Test configured value
val conf2 = new CelebornConf().set(RESERVE_SLOTS_IO_THREAD_POOL_SIZE.key, "10")
assert(conf2.workerReserveSlotsIoThreadPoolSize == 10)

// Test configured value with 0
val conf3 = new CelebornConf().set(RESERVE_SLOTS_IO_THREAD_POOL_SIZE.key, "0")
assert(conf3.workerReserveSlotsIoThreadPoolSize == Runtime.getRuntime.availableProcessors() * 8)
}

}
1 change: 1 addition & 0 deletions docs/configuration/worker.md
@@ -177,6 +177,7 @@ license: |
| celeborn.worker.replicate.port | 0 | false | Server port for Worker to receive replicate data request from other Workers. | 0.2.0 | |
| celeborn.worker.replicate.randomConnection.enabled | true | false | Whether worker will create random connection to peer when replicate data. When false, worker tend to reuse the same cached TransportClient to a specific replicate worker; when true, worker tend to use different cached TransportClient. Netty will use the same thread to serve the same connection, so with more connections replicate server can leverage more netty threads | 0.2.1 | |
| celeborn.worker.replicate.threads | 64 | false | Thread number of worker to replicate shuffle data. | 0.2.0 | |
| celeborn.worker.reserve.slots.io.threads | 1 | false | The number of threads used to create PartitionDataWriter in parallel in handleReserveSlots. | 0.7.0 | |
| celeborn.worker.reuse.hdfs.outputStream.enabled | false | false | Whether to enable reuse output stream on hdfs. | 0.7.0 | |
| celeborn.worker.rpc.port | 0 | false | Server port for Worker to receive RPC request. | 0.2.0 | |
| celeborn.worker.shuffle.partitionSplit.enabled | true | false | enable the partition split on worker side | 0.3.0 | celeborn.worker.partition.split.enabled |
@@ -19,12 +19,13 @@ package org.apache.celeborn.service.deploy.worker

import java.io.IOException
import java.util.{ArrayList => jArrayList, HashMap => jHashMap, List => jList, Set => jSet}
import java.util.concurrent._
import java.util.concurrent.{CopyOnWriteArrayList, _}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray, AtomicReference}
import java.util.function.BiFunction
import java.util.function.{BiConsumer, BiFunction, Consumer, Supplier}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}
import org.roaringbitmap.RoaringBitmap
@@ -62,6 +63,7 @@ private[deploy] class Controller(
var partitionLocationInfo: WorkerPartitionLocationInfo = _
var timer: HashedWheelTimer = _
var commitThreadPool: ThreadPoolExecutor = _
var reserveSlotsThreadPool: ThreadPoolExecutor = _
var commitFinishedChecker: ScheduledExecutorService = _
var asyncReplyPool: ScheduledExecutorService = _
val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
@@ -86,6 +88,10 @@
asyncReplyPool = worker.asyncReplyPool
shutdown = worker.shutdown

reserveSlotsThreadPool =
Executors.newFixedThreadPool(conf.workerReserveSlotsIoThreadPoolSize).asInstanceOf[
ThreadPoolExecutor]

Review comment (Member): use Celeborn's util to create a thread pool instead of raw juc classes, to properly set no daemon, name prefix, exception handler, etc.

Review comment (Member): maybe use ThreadUtils.sameThreadExecutionContext for threads = 1; when something goes wrong, we lose the full stacktrace if we run the task in another thread.
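A minimal sketch of what the two comments above point to, assuming Celeborn's ThreadUtils mirrors Spark's (newDaemonFixedThreadPool, sameThreadExecutorService, and the "worker-reserve-slots" name prefix are assumptions, not taken from this diff):

// Editor's sketch, not part of this PR.
import java.util.concurrent.ExecutorService
import org.apache.celeborn.common.util.ThreadUtils // assumed package for Celeborn's thread-pool helper

// Materialize the configured size once.
val reserveSlotsIoThreads: Int = conf.workerReserveSlotsIoThreadPoolSize

// With one thread, run tasks on the calling thread so a failure keeps its full stacktrace;
// otherwise build a named pool through the project's helper instead of raw java.util.concurrent.
val reserveSlotsExecutor: ExecutorService =
  if (reserveSlotsIoThreads == 1) {
    ThreadUtils.sameThreadExecutorService() // assumed helper, as in Spark's ThreadUtils
  } else {
    ThreadUtils.newDaemonFixedThreadPool(reserveSlotsIoThreads, "worker-reserve-slots") // assumed helper
  }

The resulting executor could then be passed to CompletableFuture.supplyAsync in handleReserveSlots in place of reserveSlotsThreadPool.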

commitFinishedChecker = worker.commitFinishedChecker
commitFinishedChecker.scheduleWithFixedDelay(
new Runnable {
@@ -193,111 +199,145 @@ private[deploy] class Controller(
context.reply(ReserveSlotsResponse(StatusCode.NO_AVAILABLE_WORKING_DIR, msg))
return
}
val primaryLocs = new jArrayList[PartitionLocation]()
try {
for (ind <- 0 until requestPrimaryLocs.size()) {
var location = partitionLocationInfo.getPrimaryLocation(
shuffleKey,
requestPrimaryLocs.get(ind).getUniqueId)
if (location == null) {
location = requestPrimaryLocs.get(ind)
val writer = storageManager.createPartitionDataWriter(
applicationId,
shuffleId,
location,
splitThreshold,
splitMode,
partitionType,
rangeReadFilter,
userIdentifier,
partitionSplitEnabled,
isSegmentGranularityVisible)
primaryLocs.add(new WorkingPartition(location, writer))
} else {
primaryLocs.add(location)
}
}
} catch {
case e: Exception =>
logError(s"CreateWriter for $shuffleKey failed.", e)

def collectResults(
tasks: ArrayBuffer[CompletableFuture[PartitionLocation]],
createdWriters: CopyOnWriteArrayList[PartitionDataWriter],
startTime: Long) = {
val primaryFuture = CompletableFuture.allOf(tasks.toSeq: _*)
.whenComplete(new BiConsumer[Void, Throwable] {
override def accept(ignore: Void, error: Throwable): Unit = {
if (error != null) {
createdWriters.forEach(new Consumer[PartitionDataWriter] {
override def accept(fileWriter: PartitionDataWriter) {
fileWriter.destroy(new IOException(
s"Destroy FileWriter $fileWriter caused by " +
s"reserving slots failed for $shuffleKey.",
error))
}
})
} else {
val timeToReserveLocations = System.currentTimeMillis() - startTime
logInfo(
s"Reserved ${tasks.size} slots for $shuffleKey in $timeToReserveLocations ms (with ${conf.workerReserveSlotsIoThreadPoolSize} threads)")
Review comment (Member): the calculation of ${conf.workerReserveSlotsIoThreadPoolSize} is not free, especially on the hot path, so materialize it once to avoid paying the evaluation cost every time.
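A small sketch of that suggestion (the field name reserveSlotsIoThreadPoolSize and where it lives are illustrative, not taken from this diff):

// Editor's sketch, not part of this PR: read the config once at initialization
// instead of re-evaluating it inside the hot completion callback.
private val reserveSlotsIoThreadPoolSize: Int = conf.workerReserveSlotsIoThreadPoolSize

// ...then the log line above becomes:
logInfo(
  s"Reserved ${tasks.size} slots for $shuffleKey in $timeToReserveLocations ms " +
  s"(with $reserveSlotsIoThreadPoolSize threads)")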

}
createdWriters.clear()
}
})
primaryFuture
}
if (primaryLocs.size() < requestPrimaryLocs.size()) {
val msg = s"Not all primary partition satisfied for $shuffleKey"
logWarning(s"[handleReserveSlots] $msg, will destroy writers.")
primaryLocs.asScala.foreach { partitionLocation =>
val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
s"reserving slots failed for $shuffleKey."))
}
context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
return

logInfo(s"Reserving ${requestPrimaryLocs.size()} primary slots for $shuffleKey")
val startReservePrimaryLocs = System.currentTimeMillis
val primaryCfTasks = ArrayBuffer[CompletableFuture[PartitionLocation]]()
val primaryWriters: CopyOnWriteArrayList[PartitionDataWriter] =
new CopyOnWriteArrayList[PartitionDataWriter]
(0 until requestPrimaryLocs.size()).foreach { ind =>
primaryCfTasks.append(CompletableFuture.supplyAsync[PartitionLocation](
new Supplier[PartitionLocation] {
override def get(): PartitionLocation = {
var location: PartitionLocation = partitionLocationInfo.getPrimaryLocation(
shuffleKey,
requestPrimaryLocs.get(ind).getUniqueId)
if (location == null) {
location = requestPrimaryLocs.get(ind)
val writer = storageManager.createPartitionDataWriter(
applicationId,
shuffleId,
location,
splitThreshold,
splitMode,
partitionType,
rangeReadFilter,
userIdentifier,
partitionSplitEnabled,
isSegmentGranularityVisible)
primaryWriters.add(writer)
new WorkingPartition(location, writer)
} else {
location
}
}
},
reserveSlotsThreadPool))
}

val replicaLocs = new jArrayList[PartitionLocation]()
try {
for (ind <- 0 until requestReplicaLocs.size()) {
var location =
partitionLocationInfo.getReplicaLocation(
shuffleKey,
requestReplicaLocs.get(ind).getUniqueId)
if (location == null) {
location = requestReplicaLocs.get(ind)
val writer = storageManager.createPartitionDataWriter(
applicationId,
shuffleId,
location,
splitThreshold,
splitMode,
partitionType,
rangeReadFilter,
userIdentifier,
partitionSplitEnabled,
isSegmentGranularityVisible)
replicaLocs.add(new WorkingPartition(location, writer))
logInfo(s"Reserving ${requestReplicaLocs.size()} replica slots for $shuffleKey")
val startReserveReplicaLocs = System.currentTimeMillis
val replicaCfTasks = ArrayBuffer[CompletableFuture[PartitionLocation]]()
val replicaWriters: CopyOnWriteArrayList[PartitionDataWriter] =
new CopyOnWriteArrayList[PartitionDataWriter]
(0 until requestReplicaLocs.size()).foreach { ind =>
replicaCfTasks.append(CompletableFuture.supplyAsync[PartitionLocation](
new Supplier[PartitionLocation] {
override def get(): PartitionLocation = {
var location =
partitionLocationInfo.getReplicaLocation(
shuffleKey,
requestReplicaLocs.get(ind).getUniqueId)
if (location == null) {
location = requestReplicaLocs.get(ind)
val writer = storageManager.createPartitionDataWriter(
applicationId,
shuffleId,
location,
splitThreshold,
splitMode,
partitionType,
rangeReadFilter,
userIdentifier,
partitionSplitEnabled,
isSegmentGranularityVisible)
replicaWriters.add(writer)
new WorkingPartition(location, writer)
} else {
location
}
}
},
reserveSlotsThreadPool))
}

// collect results
val primaryFuture: CompletableFuture[Void] =
collectResults(primaryCfTasks, primaryWriters, startReservePrimaryLocs)
val replicaFuture: CompletableFuture[Void] =
collectResults(replicaCfTasks, replicaWriters, startReserveReplicaLocs)

val future = CompletableFuture.allOf(primaryFuture, replicaFuture)
future.whenComplete(new BiConsumer[Void, Throwable] {
override def accept(ignore: Void, error: Throwable): Unit = {
if (error != null) {
logError(s"An error occurred while reserving slots for $shuffleKey", error)
val msg = s"An error occurred while reserving slots for $shuffleKey: $error";
context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
} else {
replicaLocs.add(location)
val primaryLocs = primaryCfTasks.map(cf => cf.join()).asJava
val replicaLocs = replicaCfTasks.map(cf => cf.join()).asJava
// reserve success, update status
partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaryLocs)
partitionLocationInfo.addReplicaPartitions(shuffleKey, replicaLocs)
shufflePartitionType.put(shuffleKey, partitionType)
shufflePushDataTimeout.put(
shuffleKey,
if (pushDataTimeout <= 0) defaultPushdataTimeout else pushDataTimeout)
workerInfo.allocateSlots(
shuffleKey,
Utils.getSlotsPerDisk(requestPrimaryLocs, requestReplicaLocs))
workerSource.incCounter(
WorkerSource.SLOTS_ALLOCATED,
primaryLocs.size() + replicaLocs.size())

logInfo(s"Reserved ${primaryLocs.size()} primary location " +
s"${primaryLocs.asScala.map(_.getUniqueId).mkString(",")} and ${replicaLocs.size()} replica location " +
s"${replicaLocs.asScala.map(_.getUniqueId).mkString(",")} for $shuffleKey ")
if (log.isDebugEnabled()) {
logDebug(s"primary: $primaryLocs\nreplica: $replicaLocs.")
}
context.reply(ReserveSlotsResponse(StatusCode.SUCCESS))
}
}
} catch {
case e: Exception =>
logError(s"CreateWriter for $shuffleKey failed.", e)
}
if (replicaLocs.size() < requestReplicaLocs.size()) {
val msg = s"Not all replica partition satisfied for $shuffleKey"
logWarning(s"[handleReserveSlots] $msg, destroy writers.")
primaryLocs.asScala.foreach { partitionLocation =>
val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
s"reserving slots failed for $shuffleKey."))
}
replicaLocs.asScala.foreach { partitionLocation =>
val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
s"reserving slots failed for $shuffleKey."))
}
context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
return
}

// reserve success, update status
partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaryLocs)
partitionLocationInfo.addReplicaPartitions(shuffleKey, replicaLocs)
shufflePartitionType.put(shuffleKey, partitionType)
shufflePushDataTimeout.put(
shuffleKey,
if (pushDataTimeout <= 0) defaultPushdataTimeout else pushDataTimeout)
workerInfo.allocateSlots(
shuffleKey,
Utils.getSlotsPerDisk(requestPrimaryLocs, requestReplicaLocs))
workerSource.incCounter(WorkerSource.SLOTS_ALLOCATED, primaryLocs.size() + replicaLocs.size())

logInfo(s"Reserved ${primaryLocs.size()} primary location " +
s"${primaryLocs.asScala.map(_.getUniqueId).mkString(",")} and ${replicaLocs.size()} replica location " +
s"${replicaLocs.asScala.map(_.getUniqueId).mkString(",")} for $shuffleKey ")
if (log.isDebugEnabled()) {
logDebug(s"primary: $primaryLocs\nreplica: $replicaLocs.")
}
context.reply(ReserveSlotsResponse(StatusCode.SUCCESS))
})
}

private def commitFiles(
@@ -828,4 +868,24 @@ private[deploy] class Controller(
mapIdx += 1
}
}

override def onStop(): Unit = {
if (reserveSlotsThreadPool != null) {
reserveSlotsThreadPool.shutdown()
try {
if (!reserveSlotsThreadPool.awaitTermination(
conf.workerGracefulShutdownTimeoutMs,
TimeUnit.MILLISECONDS)) {
logWarning("ReserveSlotsThreadPool shutdown timeout, forcing shutdown.")
reserveSlotsThreadPool.shutdownNow()
}
} catch {
case e: InterruptedException =>
logWarning("ReserveSlotsThreadPool shutdown interrupted, forcing shutdown.", e)
reserveSlotsThreadPool.shutdownNow()
Thread.currentThread().interrupt()
}
}
super.onStop()
}
}