From d9be058a73ebd93846f03a74d43eadc8846ba7ca Mon Sep 17 00:00:00 2001 From: afterincomparableyum Date: Sat, 17 Jan 2026 16:53:59 -0600 Subject: [PATCH 1/5] [CELEBORN-2221] Support writing with compression in C++ client Integrate existing compression infrastructure (LZ4 and ZSTD) into the C++ client write path. This enables compression during pushData operations, matching the functionality available in the Java client. Changes: - Add compression support to ShuffleClientImpl: * Add shuffleCompressionEnabled_ flag and compressor_ member * Initialize compressor from CelebornConf in constructor * Compress data in pushData() when compression is enabled * Use compressed size for batchBytesSize tracking - Configuration integration: * Read compression codec from celeborn.client.shuffle.compression.codec * Read ZSTD compression level from celeborn.client.shuffle.compression.zstd.level * Default to NONE (compression disabled) - Retry/revive support: * Retry path correctly uses pre-compressed body buffer * No re-compression needed during retries - Testing: * Add CompressorFactoryTest for factory pattern and config integration * Add compression config tests to CelebornConfTest * Test offset compression support for both LZ4 and ZSTD --- cpp/celeborn/client/ShuffleClient.cpp | 32 +++- cpp/celeborn/client/ShuffleClient.h | 3 + cpp/celeborn/client/tests/CMakeLists.txt | 3 +- .../client/tests/CompressorFactoryTest.cpp | 139 ++++++++++++++++++ cpp/celeborn/conf/tests/CelebornConfTest.cpp | 13 ++ 5 files changed, 183 insertions(+), 7 deletions(-) create mode 100644 cpp/celeborn/client/tests/CompressorFactoryTest.cpp diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 401b99e2f27..0679bd3d277 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -57,7 +57,13 @@ ShuffleClientImpl::ShuffleClientImpl( : appUniqueId_(appUniqueId), conf_(conf), clientFactory_(clientEndpoint.clientFactory()), - pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) { + pushDataRetryPool_(clientEndpoint.pushDataRetryPool()), + shuffleCompressionEnabled_( + conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE), + compressor_( + shuffleCompressionEnabled_ + ? compress::Compressor::createCompressor(*conf) + : nullptr) { CELEBORN_CHECK_NOT_NULL(clientFactory_); CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_); } @@ -154,23 +160,37 @@ int ShuffleClientImpl::pushData( auto pushState = getPushState(mapKey); const int nextBatchId = pushState->nextBatchId(); - // TODO: compression in writing is not supported. + // Compression support: compress data if compression is enabled + const uint8_t* dataToWrite = data + offset; + size_t lengthToWrite = length; + std::unique_ptr compressedBuffer; + + if (shuffleCompressionEnabled_ && compressor_) { + // Allocate buffer for compressed data + const size_t compressedCapacity = compressor_->getDstCapacity(length); + compressedBuffer = std::make_unique(compressedCapacity); + + // Compress the data + lengthToWrite = + compressor_->compress(data, offset, length, compressedBuffer.get(), 0); + dataToWrite = compressedBuffer.get(); + } auto writeBuffer = - memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length); + memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + lengthToWrite); // TODO: the java side uses Platform to write the data. We simply assume // littleEndian here. writeBuffer->writeLE(mapId); writeBuffer->writeLE(attemptId); writeBuffer->writeLE(nextBatchId); - writeBuffer->writeLE(length); - writeBuffer->writeFromBuffer(data, offset, length); + writeBuffer->writeLE(lengthToWrite); + writeBuffer->writeFromBuffer(dataToWrite, 0, lengthToWrite); auto hostAndPushPort = partitionLocation->hostAndPushPort(); // Check limit. limitMaxInFlight(mapKey, *pushState, hostAndPushPort); // Add inFlight requests. - const int batchBytesSize = length + kBatchHeaderSize; + const int batchBytesSize = lengthToWrite + kBatchHeaderSize; pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort); // Build pushData request. const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId); diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h index 3e8cb9d3787..ecc0354b8ce 100644 --- a/cpp/celeborn/client/ShuffleClient.h +++ b/cpp/celeborn/client/ShuffleClient.h @@ -17,6 +17,7 @@ #pragma once +#include "celeborn/client/compress/Compressor.h" #include "celeborn/client/reader/CelebornInputStream.h" #include "celeborn/client/writer/PushDataCallback.h" #include "celeborn/client/writer/PushState.h" @@ -249,6 +250,8 @@ class ShuffleClientImpl static constexpr size_t kBatchHeaderSize = 4 * 4; const std::string appUniqueId_; + const bool shuffleCompressionEnabled_; + std::unique_ptr compressor_; std::shared_ptr conf_; std::shared_ptr lifecycleManagerRef_; std::shared_ptr clientFactory_; diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt index e19703f314f..63c37c6e186 100644 --- a/cpp/celeborn/client/tests/CMakeLists.txt +++ b/cpp/celeborn/client/tests/CMakeLists.txt @@ -22,7 +22,8 @@ add_executable( Lz4DecompressorTest.cpp ZstdDecompressorTest.cpp Lz4CompressorTest.cpp - ZstdCompressorTest.cpp) + ZstdCompressorTest.cpp + CompressorFactoryTest.cpp) add_test(NAME celeborn_client_test COMMAND celeborn_client_test) diff --git a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp new file mode 100644 index 00000000000..af1092a28f5 --- /dev/null +++ b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "celeborn/client/compress/Compressor.h" +#include "celeborn/conf/CelebornConf.h" + +using namespace celeborn; +using namespace celeborn::client; +using namespace celeborn::conf; +using namespace celeborn::protocol; + +TEST(CompressorFactoryTest, CreateLz4CompressorFromConf) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + // Verify it's an LZ4 compressor + EXPECT_GT(compressor->getDstCapacity(100), 0); +} + +TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "3"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + // Verify it's a ZSTD compressor + EXPECT_GT(compressor->getDstCapacity(100), 0); +} + +TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) { + CelebornConf conf; + // Verify default is NONE + EXPECT_EQ(conf.shuffleCompressionCodec(), CompressionCodec::NONE); +} + +TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) { + // Test that configuration correctly reads ZSTD compression levels + for (int level = -5; level <= 10; level++) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, + std::to_string(level)); + + // Verify the compression level is set correctly + EXPECT_EQ(conf.shuffleCompressionZstdCompressLevel(), level); + + // Verify the compressor is created correctly + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + EXPECT_GT(compressor->getDstCapacity(100), 0); + } +} + +TEST(CompressorFactoryTest, CompressWithOffsetLz4) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string prefix = "SKIP_THIS_PREFIX"; + const std::string testData = + "Celeborn compression offset test with structured data: " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4"; + std::string fullData = prefix + testData; + + const auto maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + // Compress with offset (simulating pushData usage pattern) + const size_t compressedSize = compressor->compress( + reinterpret_cast(fullData.data()), + prefix.size(), + testData.size(), + compressedData.data(), + 0); + + // Verify compression succeeded with offset + EXPECT_GT(compressedSize, 0); + EXPECT_LE(compressedSize, maxLength); +} + +TEST(CompressorFactoryTest, CompressWithOffsetZstd) { + CelebornConf conf; + conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + conf.registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "3"); + + auto compressor = compress::Compressor::createCompressor(conf); + ASSERT_NE(compressor, nullptr); + + const std::string prefix = "SKIP_THIS_PREFIX"; + const std::string testData = + "Celeborn compression offset test with structured data: " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4 " + "partition_0:shuffle_1:map_2:attempt_0:batch_3:data_block_4"; + std::string fullData = prefix + testData; + + const auto maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + // Compress with offset (simulating pushData usage pattern) + const size_t compressedSize = compressor->compress( + reinterpret_cast(fullData.data()), + prefix.size(), + testData.size(), + compressedData.data(), + 0); + + // Verify compression succeeded with offset + EXPECT_GT(compressedSize, 0); + EXPECT_LE(compressedSize, maxLength); +} diff --git a/cpp/celeborn/conf/tests/CelebornConfTest.cpp b/cpp/celeborn/conf/tests/CelebornConfTest.cpp index 41efdbc000c..79619d888ac 100644 --- a/cpp/celeborn/conf/tests/CelebornConfTest.cpp +++ b/cpp/celeborn/conf/tests/CelebornConfTest.cpp @@ -19,8 +19,10 @@ #include #include "celeborn/conf/CelebornConf.h" +#include "celeborn/protocol/CompressionCodec.h" using namespace celeborn::conf; +using namespace celeborn::protocol; using CelebornUserError = celeborn::utils::CelebornUserError; using SECOND = std::chrono::seconds; @@ -47,6 +49,8 @@ void testDefaultValues(CelebornConf* conf) { EXPECT_EQ(conf->networkIoNumConnectionsPerPeer(), 1); EXPECT_EQ(conf->networkIoClientThreads(), 0); EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 3); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE); + EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 1); } TEST(CelebornConfTest, defaultValues) { @@ -73,6 +77,15 @@ TEST(CelebornConfTest, setValues) { EXPECT_EQ(conf->networkIoClientThreads(), 10); conf->registerProperty(CelebornConf::kClientFetchMaxReqsInFlight, "10"); EXPECT_EQ(conf->clientFetchMaxReqsInFlight(), 10); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "LZ4"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::LZ4); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::ZSTD); + conf->registerProperty(CelebornConf::kShuffleCompressionCodec, "NONE"); + EXPECT_EQ(conf->shuffleCompressionCodec(), CompressionCodec::NONE); + conf->registerProperty( + CelebornConf::kShuffleCompressionZstdCompressLevel, "5"); + EXPECT_EQ(conf->shuffleCompressionZstdCompressLevel(), 5); EXPECT_THROW( conf->registerProperty("non-exist-key", "non-exist-value"), From 10db9e6b475dda4b1ba0a084a562e68888918649 Mon Sep 17 00:00:00 2001 From: afterincomparableyum Date: Sun, 25 Jan 2026 16:44:01 -0600 Subject: [PATCH 2/5] Change offset to 0 --- cpp/celeborn/client/ShuffleClient.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 0679bd3d277..86d6e6d14e9 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -172,7 +172,7 @@ int ShuffleClientImpl::pushData( // Compress the data lengthToWrite = - compressor_->compress(data, offset, length, compressedBuffer.get(), 0); + compressor_->compress(dataToWrite, 0, length, compressedBuffer.get(), 0); dataToWrite = compressedBuffer.get(); } From 68de6c0560cd8054a416a1874dfe23262acadaae Mon Sep 17 00:00:00 2001 From: afterincomparableyum Date: Mon, 26 Jan 2026 15:49:15 +0000 Subject: [PATCH 3/5] fix clang format issues --- cpp/celeborn/client/ShuffleClient.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 86d6e6d14e9..317e8421497 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -171,8 +171,8 @@ int ShuffleClientImpl::pushData( compressedBuffer = std::make_unique(compressedCapacity); // Compress the data - lengthToWrite = - compressor_->compress(dataToWrite, 0, length, compressedBuffer.get(), 0); + lengthToWrite = compressor_->compress( + dataToWrite, 0, length, compressedBuffer.get(), 0); dataToWrite = compressedBuffer.get(); } From 36d3a41a3fb1ce12bc7920e1c482b3c9edebf852 Mon Sep 17 00:00:00 2001 From: afterincomparableyum Date: Sat, 31 Jan 2026 22:48:48 -0600 Subject: [PATCH 4/5] address comments generated by AI Copilot --- cpp/celeborn/client/ShuffleClient.cpp | 57 +++++++-- cpp/celeborn/client/ShuffleClient.h | 4 +- .../client/tests/CompressorFactoryTest.cpp | 111 ++++++++++++++++-- 3 files changed, 148 insertions(+), 24 deletions(-) diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 317e8421497..0e9eb39ea6e 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -16,6 +16,7 @@ */ #include "celeborn/client/ShuffleClient.h" +#include #include "celeborn/utils/CelebornUtils.h" @@ -60,10 +61,13 @@ ShuffleClientImpl::ShuffleClientImpl( pushDataRetryPool_(clientEndpoint.pushDataRetryPool()), shuffleCompressionEnabled_( conf->shuffleCompressionCodec() != protocol::CompressionCodec::NONE), - compressor_( + compressorFactory_( shuffleCompressionEnabled_ - ? compress::Compressor::createCompressor(*conf) - : nullptr) { + ? std::function()>( + [conf]() { + return compress::Compressor::createCompressor(*conf); + }) + : std::function()>()) { CELEBORN_CHECK_NOT_NULL(clientFactory_); CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_); } @@ -160,31 +164,62 @@ int ShuffleClientImpl::pushData( auto pushState = getPushState(mapKey); const int nextBatchId = pushState->nextBatchId(); + // Validate input size fits in 32-bit int since it is required by compressor + // API and wire protocol + CELEBORN_CHECK( + length <= static_cast(std::numeric_limits::max()), + fmt::format( + "Data length {} exceeds maximum supported size {}", + length, + std::numeric_limits::max())); + // Compression support: compress data if compression is enabled const uint8_t* dataToWrite = data + offset; - size_t lengthToWrite = length; + int lengthToWrite = static_cast(length); std::unique_ptr compressedBuffer; - if (shuffleCompressionEnabled_ && compressor_) { + if (shuffleCompressionEnabled_ && compressorFactory_) { + // Create a new compressor instance for thread-safety + auto compressor = compressorFactory_(); // Allocate buffer for compressed data - const size_t compressedCapacity = compressor_->getDstCapacity(length); + const size_t compressedCapacity = + compressor->getDstCapacity(static_cast(length)); compressedBuffer = std::make_unique(compressedCapacity); // Compress the data - lengthToWrite = compressor_->compress( - dataToWrite, 0, length, compressedBuffer.get(), 0); + const size_t compressedSize = compressor->compress( + dataToWrite, 0, static_cast(length), compressedBuffer.get(), 0); + + CELEBORN_CHECK( + compressedSize <= static_cast(std::numeric_limits::max()), + fmt::format( + "Compressed size {} exceeds maximum supported size {}", + compressedSize, + std::numeric_limits::max())); + + lengthToWrite = static_cast(compressedSize); dataToWrite = compressedBuffer.get(); } - auto writeBuffer = - memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + lengthToWrite); + // Validate final buffer size fits in size_t and int + CELEBORN_CHECK( + static_cast(lengthToWrite) <= + std::numeric_limits::max() - kBatchHeaderSize, + fmt::format( + "Buffer size {} + header {} would overflow", + lengthToWrite, + kBatchHeaderSize)); + + auto writeBuffer = memory::ByteBuffer::createWriteOnly( + kBatchHeaderSize + static_cast(lengthToWrite)); // TODO: the java side uses Platform to write the data. We simply assume // littleEndian here. writeBuffer->writeLE(mapId); writeBuffer->writeLE(attemptId); writeBuffer->writeLE(nextBatchId); writeBuffer->writeLE(lengthToWrite); - writeBuffer->writeFromBuffer(dataToWrite, 0, lengthToWrite); + writeBuffer->writeFromBuffer( + dataToWrite, 0, static_cast(lengthToWrite)); auto hostAndPushPort = partitionLocation->hostAndPushPort(); // Check limit. diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h index ecc0354b8ce..d1e099b2e97 100644 --- a/cpp/celeborn/client/ShuffleClient.h +++ b/cpp/celeborn/client/ShuffleClient.h @@ -251,7 +251,6 @@ class ShuffleClientImpl const std::string appUniqueId_; const bool shuffleCompressionEnabled_; - std::unique_ptr compressor_; std::shared_ptr conf_; std::shared_ptr lifecycleManagerRef_; std::shared_ptr clientFactory_; @@ -269,6 +268,9 @@ class ShuffleClientImpl mapperEndSets_; utils::ConcurrentHashSet stageEndShuffleSet_; + // Factory for creating compressor instances on demand to avoid sharing a + // single non-thread-safe compressor across concurrent operations. + std::function()> compressorFactory_; // TODO: pushExcludedWorker is not supported yet }; } // namespace client diff --git a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp index af1092a28f5..cbfaa27e1f1 100644 --- a/cpp/celeborn/client/tests/CompressorFactoryTest.cpp +++ b/cpp/celeborn/client/tests/CompressorFactoryTest.cpp @@ -18,6 +18,7 @@ #include #include "celeborn/client/compress/Compressor.h" +#include "celeborn/client/compress/Decompressor.h" #include "celeborn/conf/CelebornConf.h" using namespace celeborn; @@ -32,8 +33,26 @@ TEST(CompressorFactoryTest, CreateLz4CompressorFromConf) { auto compressor = compress::Compressor::createCompressor(conf); ASSERT_NE(compressor, nullptr); - // Verify it's an LZ4 compressor - EXPECT_GT(compressor->getDstCapacity(100), 0); + const std::string testData = "Test data for compression"; + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 8); + EXPECT_EQ(compressedData[0], 'L'); + EXPECT_EQ(compressedData[1], 'Z'); + EXPECT_EQ(compressedData[2], '4'); + EXPECT_EQ(compressedData[3], 'B'); + EXPECT_EQ(compressedData[4], 'l'); + EXPECT_EQ(compressedData[5], 'o'); + EXPECT_EQ(compressedData[6], 'c'); + EXPECT_EQ(compressedData[7], 'k'); } TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) { @@ -45,8 +64,27 @@ TEST(CompressorFactoryTest, CreateZstdCompressorFromConf) { auto compressor = compress::Compressor::createCompressor(conf); ASSERT_NE(compressor, nullptr); - // Verify it's a ZSTD compressor - EXPECT_GT(compressor->getDstCapacity(100), 0); + const std::string testData = "Test data for compression"; + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 9); + EXPECT_EQ(compressedData[0], 'Z'); + EXPECT_EQ(compressedData[1], 'S'); + EXPECT_EQ(compressedData[2], 'T'); + EXPECT_EQ(compressedData[3], 'D'); + EXPECT_EQ(compressedData[4], 'B'); + EXPECT_EQ(compressedData[5], 'l'); + EXPECT_EQ(compressedData[6], 'o'); + EXPECT_EQ(compressedData[7], 'c'); + EXPECT_EQ(compressedData[8], 'k'); } TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) { @@ -57,6 +95,8 @@ TEST(CompressorFactoryTest, CompressionCodecNoneDisablesCompression) { TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) { // Test that configuration correctly reads ZSTD compression levels + const std::string testData = "Test data for compression"; + for (int level = -5; level <= 10; level++) { CelebornConf conf; conf.registerProperty(CelebornConf::kShuffleCompressionCodec, "ZSTD"); @@ -67,10 +107,25 @@ TEST(CompressorFactoryTest, ZstdCompressionLevelFromConf) { // Verify the compression level is set correctly EXPECT_EQ(conf.shuffleCompressionZstdCompressLevel(), level); - // Verify the compressor is created correctly + // Verify the compressor is created correctly and produces ZSTD output auto compressor = compress::Compressor::createCompressor(conf); ASSERT_NE(compressor, nullptr); - EXPECT_GT(compressor->getDstCapacity(100), 0); + + const size_t maxLength = compressor->getDstCapacity(testData.size()); + std::vector compressedData(maxLength); + + const size_t compressedSize = compressor->compress( + reinterpret_cast(testData.data()), + 0, + testData.size(), + compressedData.data(), + 0); + + ASSERT_GT(compressedSize, 9); + EXPECT_EQ(compressedData[0], 'Z'); + EXPECT_EQ(compressedData[1], 'S'); + EXPECT_EQ(compressedData[2], 'T'); + EXPECT_EQ(compressedData[3], 'D'); } } @@ -100,9 +155,25 @@ TEST(CompressorFactoryTest, CompressWithOffsetLz4) { compressedData.data(), 0); - // Verify compression succeeded with offset - EXPECT_GT(compressedSize, 0); - EXPECT_LE(compressedSize, maxLength); + ASSERT_GT(compressedSize, 0); + ASSERT_LE(compressedSize, maxLength); + + auto decompressor = + compress::Decompressor::createDecompressor(CompressionCodec::LZ4); + ASSERT_NE(decompressor, nullptr); + + const int originalLen = decompressor->getOriginalLen(compressedData.data()); + EXPECT_EQ(originalLen, testData.size()); + + std::vector decompressedData(originalLen); + const int decompressedSize = decompressor->decompress( + compressedData.data(), decompressedData.data(), 0); + EXPECT_EQ(decompressedSize, originalLen); + + const std::string decompressedStr( + reinterpret_cast(decompressedData.data()), decompressedSize); + EXPECT_EQ(decompressedStr, testData); + EXPECT_NE(decompressedStr, fullData); } TEST(CompressorFactoryTest, CompressWithOffsetZstd) { @@ -133,7 +204,23 @@ TEST(CompressorFactoryTest, CompressWithOffsetZstd) { compressedData.data(), 0); - // Verify compression succeeded with offset - EXPECT_GT(compressedSize, 0); - EXPECT_LE(compressedSize, maxLength); + ASSERT_GT(compressedSize, 0); + ASSERT_LE(compressedSize, maxLength); + + auto decompressor = + compress::Decompressor::createDecompressor(CompressionCodec::ZSTD); + ASSERT_NE(decompressor, nullptr); + + const int originalLen = decompressor->getOriginalLen(compressedData.data()); + EXPECT_EQ(originalLen, testData.size()); + + std::vector decompressedData(originalLen); + const int decompressedSize = decompressor->decompress( + compressedData.data(), decompressedData.data(), 0); + EXPECT_EQ(decompressedSize, originalLen); + + const std::string decompressedStr( + reinterpret_cast(decompressedData.data()), decompressedSize); + EXPECT_EQ(decompressedStr, testData); + EXPECT_NE(decompressedStr, fullData); } From 3977b3d820ca07085ac8341719857e8d7d4f0be9 Mon Sep 17 00:00:00 2001 From: afterincomparableyum Date: Wed, 4 Feb 2026 23:10:26 -0800 Subject: [PATCH 5/5] address more of copilot PR suggestions --- cpp/celeborn/client/ShuffleClient.cpp | 6 ++++++ cpp/celeborn/client/ShuffleClient.h | 1 + 2 files changed, 7 insertions(+) diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp index 0e9eb39ea6e..35887135eee 100644 --- a/cpp/celeborn/client/ShuffleClient.cpp +++ b/cpp/celeborn/client/ShuffleClient.cpp @@ -225,6 +225,12 @@ int ShuffleClientImpl::pushData( // Check limit. limitMaxInFlight(mapKey, *pushState, hostAndPushPort); // Add inFlight requests. + CELEBORN_CHECK( + lengthToWrite <= std::numeric_limits::max() - kBatchHeaderSize, + fmt::format( + "Batch bytes size {} + header {} would overflow int", + lengthToWrite, + kBatchHeaderSize)); const int batchBytesSize = lengthToWrite + kBatchHeaderSize; pushState->addBatch(nextBatchId, batchBytesSize, hostAndPushPort); // Build pushData request. diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h index d1e099b2e97..e899ba0cdaf 100644 --- a/cpp/celeborn/client/ShuffleClient.h +++ b/cpp/celeborn/client/ShuffleClient.h @@ -17,6 +17,7 @@ #pragma once +#include #include "celeborn/client/compress/Compressor.h" #include "celeborn/client/reader/CelebornInputStream.h" #include "celeborn/client/writer/PushDataCallback.h"