From 605757e2057570bf34fa3afa540daeae8b2596d0 Mon Sep 17 00:00:00 2001
From: Sanskar Modi
Date: Tue, 19 May 2026 14:18:04 +0530
Subject: [PATCH 1/3] Fix Netty direct memory stuck after high shuffle workload on worker

---
// NOTE(review): The worker loop now waits on its queue with
// poll(1000, TimeUnit.MILLISECONDS) instead of a blocking take(). A non-null
// task goes through the same flush / exception-handling / returnBuffer /
// decrementAndGet steps as before, only nested under `if (task != null)`.
// A null result (nothing queued for the whole timeout) takes the new else
// branch, which calls trimCurrentThreadCache when the allocator is a
// PooledByteBufAllocator and does nothing for any other allocator.
// NOTE(review): this hunk accounts for all 29 insertions in the diffstat, so
// no import lines are added here; TimeUnit and PooledByteBufAllocator are
// presumably already imported in Flusher.scala - confirm.
// NOTE(review): the 1000 ms idle timeout is a literal in this hunk, not a
// named constant or configuration value.

 .../deploy/worker/storage/Flusher.scala       | 49 +++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
index f5bd4b0d779..10d3cd8442d 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
@@ -73,29 +73,38 @@ abstract private[worker] class Flusher(
             copyBytes = new Array[Byte](maxTaskSize.toInt)
           }
           while (!stopFlag.get()) {
-            val task = workingQueues(index).take()
-            val key = s"Flusher-$this-${Random.nextInt()}"
-            workerSource.sample(getFlushTimeMetric(), key) {
-              if (!task.notifier.hasException) {
-                try {
-                  val flushBeginTime = System.nanoTime()
-                  lastBeginFlushTime.set(index, flushBeginTime)
-                  task.flush(copyBytes)
-                  if (flushTimeMetric != null) {
-                    val delta = System.nanoTime() - flushBeginTime
-                    flushTimeMetric.update(delta)
+            val task = workingQueues(index).poll(1000, TimeUnit.MILLISECONDS)
+            if (task != null) {
+              val key = s"Flusher-$this-${Random.nextInt()}"
+              workerSource.sample(getFlushTimeMetric(), key) {
+                if (!task.notifier.hasException) {
+                  try {
+                    val flushBeginTime = System.nanoTime()
+                    lastBeginFlushTime.set(index, flushBeginTime)
+                    task.flush(copyBytes)
+                    if (flushTimeMetric != null) {
+                      val delta = System.nanoTime() - flushBeginTime
+                      flushTimeMetric.update(delta)
+                    }
+                  } catch {
+                    case t: Throwable =>
+                      val e = ExceptionUtils.wrapThrowableToIOException(t)
+                      task.notifier.setException(e)
+                      processIOException(e, DiskStatus.READ_OR_WRITE_FAILURE)
+                      logWarning(s"Flusher-$this-thread-$index encounter exception.", t)
                   }
-                } catch {
-                  case t: Throwable =>
-                    val e = ExceptionUtils.wrapThrowableToIOException(t)
-                    task.notifier.setException(e)
-                    processIOException(e, DiskStatus.READ_OR_WRITE_FAILURE)
-                    logWarning(s"Flusher-$this-thread-$index encounter exception.", t)
+                  lastBeginFlushTime.set(index, -1)
                 }
-                lastBeginFlushTime.set(index, -1)
+                Utils.tryLogNonFatalError(returnBuffer(task.buffer, task.keepBuffer))
+                task.notifier.numPendingFlushes.decrementAndGet()
+              }
+            } else {
+              allocator match {
+                case alloc: PooledByteBufAllocator =>
+                  // Free buffer pool memory to main direct memory when flush thread is idle.
+                  alloc.trimCurrentThreadCache
+                case _ =>
               }
-              Utils.tryLogNonFatalError(returnBuffer(task.buffer, task.keepBuffer))
-              task.notifier.numPendingFlushes.decrementAndGet()
             }
           }
         }

From bffcc45271b2cb851f1f683fe6d0a81224ad9d9c Mon Sep 17 00:00:00 2001
From: Sanskar Modi
Date: Wed, 27 May 2026 16:28:46 +0530
Subject: [PATCH 2/3] Copilot suggestions

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
---
// NOTE(review): Moves the returnBuffer(...) call and the
// numPendingFlushes.decrementAndGet() call from inside the
// workerSource.sample { ... } block to directly after it. Both still sit
// inside `if (task != null)`, so they run only when a task was polled.

 .../celeborn/service/deploy/worker/storage/Flusher.scala      | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
index 10d3cd8442d..bfd48e6e5e2 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/Flusher.scala
@@ -95,9 +95,9 @@ abstract private[worker] class Flusher(
                   }
                   lastBeginFlushTime.set(index, -1)
                 }
-                Utils.tryLogNonFatalError(returnBuffer(task.buffer, task.keepBuffer))
-                task.notifier.numPendingFlushes.decrementAndGet()
               }
+              Utils.tryLogNonFatalError(returnBuffer(task.buffer, task.keepBuffer))
+              task.notifier.numPendingFlushes.decrementAndGet()
             } else {
               allocator match {
                 case alloc: PooledByteBufAllocator =>

From e84bf67b22a8802c5686104eb27473d519f3de9a Mon Sep 17 00:00:00 2001
From: Sanskar Modi
Date: Fri, 12 Jun 2026 17:05:07 +0530
Subject: [PATCH 3/3] Add tests

---
// NOTE(review): Adds FlusherSuite, a new test file with four tests built on a
// private TestFlusher subclass: trim is called when the queue stays empty,
// trim is called with three worker threads, an UnpooledByteBufAllocator still
// flushes a task after idling, and a task submitted to a pooled-allocator
// flusher is flushed with its pending count back at 0.
// NOTE(review): "each idle flusher worker thread trims its own thread cache"
// verifies atLeast(threadCount = 3) calls within VERIFY_TIMEOUT_MS = 5000 ms.
// An idle worker reaches the trim branch about once per 1000 ms poll timeout,
// so a single thread can produce three calls inside that window; the
// assertion therefore does not by itself show that every thread trimmed.
// Recording the calling thread per invocation would close that gap.
// NOTE(review): the non-pooled test sleeps 2500 ms before submitting work, so
// it adds at least that much wall-clock time to the suite.

 .../deploy/worker/storage/FlusherSuite.scala  | 163 ++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlusherSuite.scala

diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlusherSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlusherSuite.scala
new file mode 100644
index 00000000000..ea88b9e075c
--- /dev/null
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlusherSuite.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage
+
+import java.io.IOException
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.mutable.ArrayBuffer
+
+import io.netty.buffer.{ByteBufAllocator, CompositeByteBuf, PooledByteBufAllocator, UnpooledByteBufAllocator}
+import org.mockito.Mockito.{timeout, verify}
+import org.mockito.MockitoSugar.{mock, spy}
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.{interval, timeout => patienceTimeout}
+import org.scalatest.time.SpanSugar._
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.meta.DiskStatus
+import org.apache.celeborn.common.metrics.source.AbstractSource
+import org.apache.celeborn.service.deploy.worker.WorkerSource
+import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
+
+class FlusherSuite extends CelebornFunSuite {
+
+  // The flusher worker threads block in poll() for up to 1000ms (see Flusher)
+  // before falling through to the idle/trim branch, so verifications must allow
+  // at least that long plus scheduling slack.
+  private val VERIFY_TIMEOUT_MS = 5000
+
+  private val flushers = new ArrayBuffer[TestFlusher]()
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    MemoryManager.reset()
+  }
+
+  override protected def afterEach(): Unit = {
+    flushers.foreach(_.shutdown())
+    flushers.clear()
+    MemoryManager.reset()
+    super.afterEach()
+  }
+
+  /**
+   * Minimal concrete [[Flusher]] used to exercise the worker loop in isolation.
+   * The base class starts its worker threads in `init()` on construction.
+   */
+  private class TestFlusher(
+      workerSource: AbstractSource,
+      allocator: ByteBufAllocator,
+      threadCount: Int)
+    extends Flusher(
+      workerSource,
+      threadCount,
+      allocator,
+      16,
+      null,
+      "test-mount",
+      false,
+      1024L) {
+
+    override def processIOException(e: IOException, deviceErrorType: DiskStatus): Unit = {}
+
+    override def getFlushTimeMetric(): String = WorkerSource.FLUSH_LOCAL_DATA_TIME
+
+    def shutdown(): Unit = {
+      stopFlag.set(true)
+      // Interrupt the threads blocked in poll() so they exit promptly and the
+      // suite's ThreadAudit does not report a leak.
+      workers.foreach(worker => if (worker != null) worker.shutdownNow())
+    }
+  }
+
+  private def newFlusher(
+      allocator: ByteBufAllocator,
+      threadCount: Int = 1,
+      workerSource: AbstractSource = mock[AbstractSource]): TestFlusher = {
+    val flusher = new TestFlusher(workerSource, allocator, threadCount)
+    flushers += flusher
+    flusher
+  }
+
+  test("trimCurrentThreadCache is invoked when the working queue stays empty") {
+    val allocator = spy(new PooledByteBufAllocator(true))
+    newFlusher(allocator)
+
+    // With no tasks ever submitted, the idle branch must trim the pooled
+    // allocator's thread-local cache, returning that memory to the shared pool.
+    verify(allocator, timeout(VERIFY_TIMEOUT_MS).atLeastOnce()).trimCurrentThreadCache
+  }
+
+  test("each idle flusher worker thread trims its own thread cache") {
+    val threadCount = 3
+    val allocator = spy(new PooledByteBufAllocator(true))
+    newFlusher(allocator, threadCount)
+
+    // trimCurrentThreadCache trims only the calling thread's cache, so every
+    // idle worker thread must call it - expect at least one call per thread.
+    verify(allocator, timeout(VERIFY_TIMEOUT_MS).atLeast(threadCount)).trimCurrentThreadCache
+  }
+
+  test("non-pooled allocator skips trim and the worker loop keeps processing tasks") {
+    MemoryManager.initialize(new CelebornConf())
+    val workerSource = new WorkerSource(new CelebornConf())
+    // A non-pooled allocator has no thread cache: the idle branch must take the
+    // `case _ =>` no-op path. UnpooledByteBufAllocator has no trimCurrentThreadCache
+    // to call, so the only way this test passes is if that path is a clean no-op.
+    val flusher = newFlusher(UnpooledByteBufAllocator.DEFAULT, workerSource = workerSource)
+
+    // Let the worker spin through several idle poll cycles before submitting work,
+    // so we know the idle branch ran without disrupting the loop.
+    Thread.sleep(2500)
+    assertTaskIsFlushed(flusher, workerSource)
+  }
+
+  test("submitted task is still flushed after switching take() to poll()") {
+    MemoryManager.initialize(new CelebornConf())
+    val workerSource = new WorkerSource(new CelebornConf())
+    val allocator = spy(new PooledByteBufAllocator(true))
+    val flusher = newFlusher(allocator, workerSource = workerSource)
+
+    assertTaskIsFlushed(flusher, workerSource)
+  }
+
+  /** Submits one task and asserts it is flushed and its pending count drained. */
+  private def assertTaskIsFlushed(flusher: TestFlusher, source: AbstractSource): Unit = {
+    val bytes = "flush-after-poll".getBytes("UTF-8")
+    val buffer: CompositeByteBuf = UnpooledByteBufAllocator.DEFAULT.compositeBuffer()
+    buffer.writeBytes(bytes)
+
+    val notifier = new FlushNotifier()
+    notifier.numPendingFlushes.incrementAndGet()
+    val flushed = new AtomicBoolean(false)
+    val task = new FlushTask(buffer, notifier, false, source) {
+      override def flush(copyBytes: Array[Byte]): Unit = flushed.set(true)
+    }
+
+    assert(flusher.addTask(task, 1000, 0))
+
+    eventually(patienceTimeout(VERIFY_TIMEOUT_MS.millis), interval(50.millis)) {
+      assert(flushed.get(), "flush() should have been invoked via the poll() path")
+      assert(
+        notifier.numPendingFlushes.get() == 0,
+        "pending flush count should be decremented after the task is processed")
+    }
+  }
+}