diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 401b99e2f27..35887135eee 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -16,6 +16,7 @@ */ #include "celeborn/client/ShuffleClient.h" +#include #include "celeborn/utils/CelebornUtils.h" @@ -57,7 +58,16 @@ ShuffleClientImpl::ShuffleClientImpl( : appUniqueId_(appUniqueId), conf_(conf), clientFactory_(clientEndpoint.clientFactory()), - pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) { + pushDataRetryPool_(clientEndpoint.pushDataRetryPool()), + shuffleCompressionEnabled_( + conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE), + compressorFactory_( + shuffleCompressionEnabled_ + ? std::function()>( + [conf]() { + return compress::Compressor::createCompressor(*conf); + }) + : std::function()>()) { CELEBORN_CHECK_NOT_NULL(clientFactory_); CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_); } @@ -154,23 +164,74 @@ int ShuffleClientImpl::pushData( auto pushState = getPushState(mapKey); const int nextBatchId = pushState->nextBatchId(); - // TODO: compression in writing is not supported. + // Validate input size fits in 32-bit int since it is required by compressor + // API and wire protocol + CELEBORN_CHECK( + length <= static_cast(std::numeric_limits::max()), + fmt::format( + "Data length {} exceeds maximum supported size {}", + length, + std::numeric_limits::max())); + + // Compression support: compress data if compression is enabled + const uint8_t* dataToWrite = data + offset; + int lengthToWrite = static_cast(length); + std::unique_ptr compressedBuffer; + + if (shuffleCompressionEnabled_ && compressorFactory_) { + // Create a new compressor instance for thread-safety + auto compressor = compressorFactory_(); + // Allocate buffer for compressed data + const size_t compressedCapacity = + compressor->getDstCapacity(static_cast(length)); + compressedBuffer = std::make_unique(compressedCapacity); + + // Compress the data + const size_t compressedSize = compressor->compress( + dataToWrite, 0, static_cast(length), compressedBuffer.get(), 0); + + CELEBORN_CHECK( + compressedSize <= static_cast(std::numeric_limits::max()), + fmt::format( + "Compressed size {} exceeds maximum supported size {}", + compressedSize, + std::numeric_limits::max())); + + lengthToWrite = static_cast(compressedSize); + dataToWrite = compressedBuffer.get(); + } - auto writeBuffer = - memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length); + // Validate final buffer size fits in size_t and int + CELEBORN_CHECK( + static_cast(lengthToWrite) <= + std::numeric_limits::max() - kBatchHeaderSize, + fmt::format( + "Buffer size {} + header {} would overflow", + lengthToWrite, + kBatchHeaderSize)); + + auto writeBuffer = memory::ByteBuffer::createWriteOnly( + kBatchHeaderSize + static_cast(lengthToWrite)); // TODO: the java side uses Platform to write the data. We simply assume // littleEndian here. writeBuffer->writeLE(mapId); writeBuffer->writeLE(attemptId); writeBuffer->writeLE(nextBatchId); - writeBuffer->writeLE(length); - writeBuffer->writeFromBuffer(data, offset, length); + writeBuffer->writeLE(lengthToWrite); + writeBuffer->writeFromBuffer( + dataToWrite, 0, static_cast(lengthToWrite)); auto hostAndPushPort = partitionLocation->hostAndPushPort(); // Check limit. limitMaxInFlight(mapKey, *pushState, hostAndPushPort); // Add inFlight requests. - const int batchBytesSize = length + kBatchHeaderSize; + CELEBORN_CHECK( + lengthToWrite <= std::numeric_limits::max() - kBatchHeaderSize, + fmt::format( + "Batch bytes size {} + header {} would overflow int", + lengthToWrite, + kBatchHeaderSize)); + const int batchBytesSize = lengthToWrite + kBatchHeaderSize; pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort); // Build pushData request. const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId); diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h index 3e8cb9d3787..e899ba0cdaf 100644 --- a/cpp/celeborn/client/ShuffleClient.h +++ b/cpp/celeborn/client/ShuffleClient.h @@ -17,6 +17,8 @@ #pragma once +#include +#include "celeborn/client/compress/Compressor.h" #include "celeborn/client/reader/CelebornInputStream.h" #include "celeborn/client/writer/PushDataCallback.h" #include "celeborn/client/writer/PushState.h" @@ -249,6 +251,7 @@ class ShuffleClientImpl static constexpr size_t kBatchHeaderSize = 4 * 4; const std::string appUniqueId_; + const bool shuffleCompressionEnabled_; std::shared_ptr conf_; std::shared_ptr lifecycleManagerRef_; std::shared_ptr clientFactory_; @@ -266,6 +269,9 @@ class ShuffleClientImpl mapperEndSets_; utils::ConcurrentHashSet stageEndShuffleSet_; + // Factory for creating compressor instances on demand to avoid sharing a + // single non-thread-safe compressor across concurrent operations. + std::function()> compressorFactory_; // TODO: pushExcludedWorker is not supported yet }; } // namespace client diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt index e19703f314f..63c37c6e186 100644 --- a/cpp/celeborn/client/tests/CMakeLists.txt +++ b/cpp/celeborn/client/tests/CMakeLists.txt @@ -22,7 +22,8 @@ add_executable( Lz4DecompressorTest.cpp ZstdDecompressorTest.cpp Lz4CompressorTest.cpp - ZstdCompressorTest.cpp) + ZstdCompressorTest.cpp + CompressorFactoryTest.cpp) add_test(NAME celeborn_client_test COMMAND celeborn_client_test) diff --git a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp new file mode 100644 index 00000000000..cbfaa27e1f1 --- /dev/null +++ b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "celeborn/client/compress/Compressor.h" +#include "celeborn/client/compress/Decompressor.h" +#include "celeborn/conf/CelebornConf.h" + +using namespace celeborn; +using namespace celeborn::client; +using namespace celeborn::conf; +using namespace celeborn::protocol; + +TEST(CompressorFactoryTest, CreateLz4CompressorFromConf) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string testData = "Test data for compression"; + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 8); + EXPECT_EQ(compressedData[0], 'L'); + EXPECT_EQ(compressedData[1], 'Z'); + EXPECT_EQ(compressedData[2], '4'); + EXPECT_EQ(compressedData[3], 'B'); + EXPECT_EQ(compressedData[4], 'l'); + EXPECT_EQ(compressedData[5], 'o'); + EXPECT_EQ(compressedData[6], 'c'); + EXPECT_EQ(compressedData[7], 'k'); +} + +TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "3"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string testData = "Test data for compression"; + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 9); + EXPECT_EQ(compressedData[0], 'Z'); + EXPECT_EQ(compressedData[1], 'S'); + EXPECT_EQ(compressedData[2], 'T'); + EXPECT_EQ(compressedData[3], 'D'); + EXPECT_EQ(compressedData[4], 'B'); + EXPECT_EQ(compressedData[5], 'l'); + EXPECT_EQ(compressedData[6], 'o'); + EXPECT_EQ(compressedData[7], 'c'); + EXPECT_EQ(compressedData[8], 'k'); +} + +TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) { + CelebornConf conf; + // Verify default is NONE + EXPECT_EQ(conf.shuffleCompressionCodec(), CompressionCodec::NONE); +} + +TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) { + // Test that configuration correctly reads ZSTD compression levels + const std::string testData = "Test data for compression"; + + for (int level = -5; level <= 10; level++) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, + std::to_string(level)); + + // Verify the compression level is set correctly + EXPECT_EQ(conf.shuffleCompressionZstdCompressLevel(), level); + + // Verify the compressor is created correctly and produces ZSTD output + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 9); + EXPECT_EQ(compressedData[0], 'Z'); + EXPECT_EQ(compressedData[1], 'S'); + EXPECT_EQ(compressedData[2], 'T'); + EXPECT_EQ(compressedData[3], 'D'); + } +} + +TEST(CompressorFactoryTest, CompressWithOffsetLz4) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string prefix = "SKIP_THIS_PREFIX"; + const std::string testData = + "Celeborn compression offset test with structured data: " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4"; + std::string fullData = prefix + testData; + + const auto maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + // Compress with offset (simulating pushData usage pattern) + const size_t compressedSize = compressor->compress( + reinterpret_cast(fullData.data()), + prefix.size(), + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 0); + ASSERT_LE(compressedSize, maxLength); + + auto decompressor = + compress::Decompressor::createDecompressor(CompressionCodec::LZ4); + ASSERT_NE(decompressor, nullptr); + + const int originalLen = decompressor->getOriginalLen(compressedData.data()); + EXPECT_EQ(originalLen, testData.size()); + + std::vector decompressedData(originalLen); + const int decompressedSize = decompressor->decompress( + compressedData.data(), decompressedData.data(), 0); + EXPECT_EQ(decompressedSize, originalLen); + + const std::string decompressedStr( + reinterpret_cast(decompressedData.data()), decompressedSize); + EXPECT_EQ(decompressedStr, testData); + EXPECT_NE(decompressedStr, fullData); +} + +TEST(CompressorFactoryTest, CompressWithOffsetZstd) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "3"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string prefix = "SKIP_THIS_PREFIX"; + const std::string testData = + "Celeborn compression offset test with structured data: " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4"; + std::string fullData = prefix + testData; + + const auto maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + // Compress with offset (simulating pushData usage pattern) + const size_t compressedSize = compressor->compress( + reinterpret_cast(fullData.data()), + prefix.size(), + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 0); + ASSERT_LE(compressedSize, maxLength); + + auto decompressor = + compress::Decompressor::createDecompressor(CompressionCodec::ZSTD); + ASSERT_NE(decompressor, nullptr); + + const int originalLen = decompressor->getOriginalLen(compressedData.data()); + EXPECT_EQ(originalLen, testData.size()); + + std::vector decompressedData(originalLen); + const int decompressedSize = decompressor->decompress( + compressedData.data(), decompressedData.data(), 0); + EXPECT_EQ(decompressedSize, originalLen); + + const std::string decompressedStr( + reinterpret_cast(decompressedData.data()), decompressedSize); + EXPECT_EQ(decompressedStr, testData); + EXPECT_NE(decompressedStr, fullData); +} diff --git a/cpp/celeborn/conf/tests/CelebornConfTest.cpp b/cpp/celeborn/conf/tests/CelebornConfTest.cpp index 41efdbc000c..79619d888ac 100644 --- a/cpp/celeborn/conf/tests/CelebornConfTest.cpp +++ b/cpp/celeborn/conf/tests/CelebornConfTest.cpp @@ -19,8 +19,10 @@ #include #include "celeborn/conf/CelebornConf.h" +#include "celeborn/protocol/CompressionCodec.h" using namespace celeborn::conf; +using namespace celeborn::protocol; using CelebornUserError = celeborn::utils::CelebornUserError; using SECOND = std::chrono::seconds; @@ -47,6 +49,8 @@ void testDefaultValues(CelebornConf* conf) { EXPECT_EQ(conf->networkIoNumConnectionsPerPeer(), 1); EXPECT_EQ(conf->networkIoClientThreads(), 0); EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 3); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE); + EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 1); } TEST(CelebornConfTest, defaultValues) { @@ -73,6 +77,15 @@ TEST(CelebornConfTest, setValues) { EXPECT_EQ(conf->networkIoClientThreads(), 10); conf->registerProperty(CelebornConf::kClientFetchMaxReqsInFlight, "10"); EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 10); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::LZ4); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::ZSTD); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "NONE"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE); + conf->registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "5"); + EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 5); EXPECT_THROW( conf->registerProperty("non-exist-key", "non-exist-value"),