Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6e53791
Support chunk compressed data writer
saurabhd336 May 19, 2026
53a4e81
Add read support
saurabhd336 May 23, 2026
5520acc
Add tests
saurabhd336 May 24, 2026
f917a57
lint fix
saurabhd336 May 24, 2026
cde5467
Add e2e test
saurabhd336 May 25, 2026
4f76be5
Fix lint
saurabhd336 May 25, 2026
c3ed330
Fix sbt build
saurabhd336 May 26, 2026
485dd23
Fix compression level compilation
saurabhd336 May 26, 2026
3d1daa4
Fix client.md
saurabhd336 May 26, 2026
c3d4bc8
Fix test
saurabhd336 May 26, 2026
7c2cbff
Fix test
saurabhd336 May 26, 2026
78611fd
Fix chunk compressed writer
saurabhd336 Jun 1, 2026
187906a
Avoid chunk compression during large records
saurabhd336 Jun 3, 2026
3533e52
Move to chunk compression context message
saurabhd336 Jun 3, 2026
174db3d
Fix compression
saurabhd336 Jun 4, 2026
91a49e7
Don't compress large records
saurabhd336 Jun 4, 2026
032f7aa
Fix ChunkCompressedFileChannelWriterSuiteJ to handle uncompressed lar…
saurabhd336 Jun 4, 2026
39595da
Lint fix
saurabhd336 Jun 4, 2026
6d8a640
Fix compilation
saurabhd336 Jun 4, 2026
3b1ea63
Fix test
saurabhd336 Jun 4, 2026
c9b4785
Fix lint
saurabhd336 Jun 4, 2026
18f3ac7
Address comments
saurabhd336 Jun 9, 2026
e6e45ac
Address comments
saurabhd336 Jun 9, 2026
709b05d
Address comments
saurabhd336 Jun 9, 2026
563b5b9
Lint fix
saurabhd336 Jun 9, 2026
0ddb927
Address comments
saurabhd336 Jun 10, 2026
6f03c94
Address comments
saurabhd336 Jun 10, 2026
7b503bc
address comments
saurabhd336 Jun 10, 2026
c96df79
Address
saurabhd336 Jun 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import scala.Tuple2;

import com.github.luben.zstd.ZstdException;
import com.github.luben.zstd.ZstdInputStream;
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import net.jpountz.lz4.LZ4Exception;
import org.apache.commons.lang3.tuple.Pair;
import org.roaringbitmap.RoaringBitmap;
Expand Down Expand Up @@ -193,6 +195,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private Decompressor decompressor;

private ByteBuf currentChunk;
private boolean currentChunkCompressed = false;
private boolean firstChunk = true;
private PartitionReader currentReader;
private final int fetchChunkMaxRetry;
Expand All @@ -213,6 +216,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private final String localHostAddress;

private boolean shouldDecompress;
private InputStream currentStream;
Comment thread
SteNicholas marked this conversation as resolved.
private boolean shuffleIntegrityCheckEnabled;
private long fetchExcludedWorkerExpireTimeout;
private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
Expand Down Expand Up @@ -526,7 +530,9 @@ private ByteBuf getNextChunk() throws IOException {
if (!currentReader.hasNext()) {
return null;
}
return currentReader.next();
Pair<ByteBuf, Boolean> result = currentReader.next();
currentChunkCompressed = result.getRight();
return result.getLeft();
} catch (Exception e) {
shuffleClient.excludeFailedFetchLocation(
currentReader.getLocation().hostAndFetchPort(), e);
Expand Down Expand Up @@ -730,6 +736,7 @@ public synchronized void close() {

compressedBuf = null;
rawDataBuf = null;
closeCurrentStream();
batchesRead = null;
locations = null;
attempts = null;
Expand Down Expand Up @@ -800,6 +807,34 @@ private void init() {
rawDataBuf = new byte[bufferSize];
}

private void closeCurrentStream() {
if (currentStream != null) {
try {
currentStream.close();
} catch (IOException ignored) {
}
currentStream = null;
}
}

private void setupCurrentStream() throws IOException {
closeCurrentStream();
if (currentChunk == null) return;
InputStream base = new ByteBufInputStream(currentChunk);
currentStream = currentChunkCompressed ? new ZstdInputStream(base) : base;
}
Comment thread
SteNicholas marked this conversation as resolved.

/** Reads exactly len bytes; returns total read (< len only on EOF). */
private static int readFully(InputStream in, byte[] buf, int off, int len) throws IOException {
int total = 0;
while (total < len) {
int n = in.read(buf, off + total, len - total);
if (n == -1) break;
total += n;
}
return total;
}

private boolean fillBuffer() throws IOException {
try {
if (firstChunk && currentReader != null) {
Expand All @@ -814,10 +849,23 @@ private boolean fillBuffer() throws IOException {
return false;
}

if (currentStream == null) {
setupCurrentStream();
}

LocationPushFailedBatches failedBatch = new LocationPushFailedBatches();
boolean hasData = false;
while (currentChunk.isReadable() || moveToNextChunk()) {
currentChunk.readBytes(sizeBuf);
while (true) {
int headerRead = readFully(currentStream, sizeBuf, 0, BATCH_HEADER_SIZE);
if (headerRead == 0) {
closeCurrentStream();
if (!moveToNextChunk()) break;
setupCurrentStream();
continue;
} else if (headerRead != BATCH_HEADER_SIZE) {
throw new IOException("Invalid EOF detected");
}
Comment thread
SteNicholas marked this conversation as resolved.
Comment thread
SteNicholas marked this conversation as resolved.

int mapId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET);
int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4);
int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8);
Expand All @@ -827,14 +875,16 @@ private boolean fillBuffer() throws IOException {
if (size > compressedBuf.length) {
compressedBuf = new byte[size];
}

currentChunk.readBytes(compressedBuf, 0, size);
if (readFully(currentStream, compressedBuf, 0, size) != size) {
throw new IOException("Invalid EOF detected");
}
Comment thread
SteNicholas marked this conversation as resolved.
} else {
if (size > rawDataBuf.length) {
rawDataBuf = new byte[size];
}

currentChunk.readBytes(rawDataBuf, 0, size);
if (readFully(currentStream, rawDataBuf, 0, size) != size) {
throw new IOException("Invalid EOF detected");
}
Comment thread
SteNicholas marked this conversation as resolved.
}

// de-duplicate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private void checkpoint() {
}

@Override
public ByteBuf next() throws Exception {
public Pair<ByteBuf, Boolean> next() throws Exception {
Pair<Integer, ByteBuf> chunk = null;
checkpoint();
if (!fetchThreadStarted) {
Expand Down Expand Up @@ -328,7 +328,7 @@ public ByteBuf next() throws Exception {
}
returnedChunks++;
lastReturnedChunkId = chunk.getLeft();
return chunk.getRight();
return Pair.of(chunk.getRight(), false);
}
Comment thread
SteNicholas marked this conversation as resolved.

private void checkException() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCounted;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -214,7 +215,7 @@ public boolean hasNext() {
}

@Override
public ByteBuf next() throws IOException, InterruptedException {
public Pair<ByteBuf, Boolean> next() throws Exception {
checkException();
if (chunkIndex <= endChunkIndex) {
fetchChunks();
Expand Down Expand Up @@ -254,8 +255,12 @@ public ByteBuf next() throws IOException, InterruptedException {
logger.error("PartitionReader thread interrupted while fetching data.");
throw e;
}
int chunkIdx = startChunkIndex + returnedChunks;
returnedChunks++;
return chunk;
boolean compressed =
streamHandler.getChunkCompressedCount() > chunkIdx
&& streamHandler.getChunkCompressed(chunkIdx);
return Pair.of(chunk, compressed);
}

private void checkException() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import java.util.Optional;

import io.netty.buffer.ByteBuf;
import org.apache.commons.lang3.tuple.Pair;

import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.protocol.PartitionLocation;

public interface PartitionReader {
boolean hasNext();

ByteBuf next() throws Exception;
Pair<ByteBuf, Boolean> next() throws Exception;

void close();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ private void checkpoint() {
}

@Override
public ByteBuf next() throws IOException, InterruptedException {
public Pair<ByteBuf, Boolean> next() throws Exception {
checkpoint();
checkException();
if (chunkIndex <= endChunkIndex) {
Expand Down Expand Up @@ -229,7 +229,11 @@ public ByteBuf next() throws IOException, InterruptedException {
returnedChunks++;
inflightRequestCount--;
lastReturnedChunkId = chunk.getLeft();
return chunk.getRight();
int chunkIdx = chunk.getLeft();
boolean compressed =
streamHandler.getChunkCompressedCount() > chunkIdx
&& streamHandler.getChunkCompressed(chunkIdx);
return Pair.of(chunk.getRight(), compressed);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.apache.celeborn.client.listener.WorkerStatusListener
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.CelebornConf.ACTIVE_STORAGE_TYPES
import org.apache.celeborn.common.client.{ApplicationInfoProvider, MasterClient}
import org.apache.celeborn.common.compression.ChunkCompressionContext
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ApplicationMeta, ShufflePartitionLocationInfo, WorkerInfo}
Expand Down Expand Up @@ -1324,7 +1325,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
userIdentifier,
conf.pushDataTimeoutMs,
partitionSplitEnabled = true,
isSegmentGranularityVisible = isSegmentGranularityVisible))
isSegmentGranularityVisible = isSegmentGranularityVisible,
chunkCompressionContext = new ChunkCompressionContext(
conf.isChunkCompressionEnabled,
conf.chunkCompressionLevel)))
futures.add((future, workerInfo))
}(ec)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.compression;

/**
* Carries chunk-level compression settings from the client through to the worker's {@code
* ChunkCompressedFileChannelWriter}. Using a context object instead of a bare boolean keeps the
* call chain stable as new compression knobs are added.
*/
public final class ChunkCompressionContext {

/** ZSTD default compression level (mirrors {@code Zstd.defaultCompressionLevel()}). */
public static final int DEFAULT_COMPRESSION_LEVEL = 3;

private static final ChunkCompressionContext DISABLED =
new ChunkCompressionContext(false, DEFAULT_COMPRESSION_LEVEL);

private final boolean enabled;
private final int compressionLevel;

public ChunkCompressionContext(boolean enabled, int compressionLevel) {
this.enabled = enabled;
this.compressionLevel = compressionLevel;
}

/** Returns a context with compression disabled and the default compression level. */
public static ChunkCompressionContext disabled() {
return DISABLED;
}

public boolean isEnabled() {
return enabled;
}

public int getCompressionLevel() {
return compressionLevel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.slf4j.LoggerFactory;

import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.compression.ChunkCompressionContext;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.protocol.StorageInfo;
import org.apache.celeborn.common.util.Utils;
Expand All @@ -39,16 +40,19 @@ public class DiskFileInfo extends FileInfo {
private static final Logger logger = LoggerFactory.getLogger(DiskFileInfo.class);
private final String filePath;
private final StorageInfo.Type storageType;
private final ChunkCompressionContext chunkCompressionContext;

public DiskFileInfo(
UserIdentifier userIdentifier,
boolean partitionSplitEnabled,
FileMeta fileMeta,
String filePath,
StorageInfo.Type storageType) {
StorageInfo.Type storageType,
ChunkCompressionContext chunkCompressionContext) {
super(userIdentifier, partitionSplitEnabled, fileMeta);
this.filePath = filePath;
this.storageType = storageType;
this.chunkCompressionContext = chunkCompressionContext;
}

// only called when restore from pb or in UT
Expand All @@ -58,9 +62,11 @@ public DiskFileInfo(
FileMeta fileMeta,
String filePath,
StorageInfo.Type storageType,
long bytesFlushed) {
long bytesFlushed,
ChunkCompressionContext chunkCompressionContext) {
super(userIdentifier, partitionSplitEnabled, fileMeta);
this.filePath = filePath;
this.chunkCompressionContext = chunkCompressionContext;
if (storageType != null) {
this.storageType = storageType;
} else {
Expand All @@ -76,13 +82,16 @@ public DiskFileInfo(File file, UserIdentifier userIdentifier, CelebornConf conf)
true,
new ReduceFileMeta(new ArrayList<>(Arrays.asList(0L)), conf.shuffleChunkSize()),
file.getAbsolutePath(),
StorageInfo.Type.HDD);
StorageInfo.Type.HDD,
ChunkCompressionContext.disabled());
}

// User only by the sorted
public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath) {
super(userIdentifier, true, fileMeta);
this.filePath = filePath;
this.storageType = StorageInfo.Type.HDD;
this.chunkCompressionContext = ChunkCompressionContext.disabled();
}

public File getFile() {
Expand Down Expand Up @@ -175,4 +184,16 @@ public boolean isDFS() {
public StorageInfo.Type getStorageType() {
return storageType;
}

public boolean isChunkCompressionEnabled() {
return chunkCompressionContext.isEnabled();
}

public int getChunkCompressionLevel() {
return chunkCompressionContext.getCompressionLevel();
}

public ChunkCompressionContext getChunkCompressionContext() {
return chunkCompressionContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public synchronized void updateBytesFlushed(long bytes) {
}
}

public synchronized void setBytesFlushed(long bytesFlushed) {
this.bytesFlushed = bytesFlushed;
}
Comment thread
SteNicholas marked this conversation as resolved.

public UserIdentifier getUserIdentifier() {
return userIdentifier;
}
Expand Down
Loading
Loading