Use adaptive per-method batch sizing, drop startup benchmark #390

Open
kali113 wants to merge 1 commit into p-e-w:master from kali113:fix/adaptive-batch-sizing

Conversation

kali113 commented Jun 17, 2026

Contributor

Dropped the old startup benchmark: looping through batch sizes from 1 to 128 wasted time on warmups and forced a single, inefficient global batch size on every method.

Now responses, residuals, and logprobs each track their own batch size independently, defaulting to 128. If a method hits a CUDA OOM, it catches the error, halves its batch size, and retries. The reduced size is kept for that method's subsequent runs.

This fixes three main things. First, there is no startup lag, so inference starts immediately. Second, VRAM efficiency improves: responses take far more VRAM than residuals, so each method now finds its own limit. Finally, it adds resilience, since the system adapts dynamically if memory pressure changes mid-run.
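
The per-method backoff described above can be sketched in isolation. This is a minimal simulation rather than the code from this PR: `SimulatedOOM`, `AdaptiveBatcher`, and `make_worker` are hypothetical names, and a fake capacity limit stands in for real VRAM pressure so the sketch runs without a GPU.

```python
from typing import Any, Callable


class SimulatedOOM(RuntimeError):
    """Hypothetical stand-in for an accelerator out-of-memory error."""


class AdaptiveBatcher:
    """Keeps one batch size per method and halves it when a batch fails."""

    def __init__(self, initial: int = 128) -> None:
        self.batch_sizes = {
            key: initial for key in ("responses", "residuals", "logprobs")
        }

    def run(
        self, items: list[Any], key: str, fn: Callable[[list[Any]], Any]
    ) -> list[Any]:
        results = []
        i = 0
        while i < len(items):
            n = self.batch_sizes[key]
            try:
                results.append(fn(items[i : i + n]))
                i += n  # advance only after the batch succeeded
            except SimulatedOOM:
                if n == 1:
                    raise  # cannot shrink any further
                self.batch_sizes[key] = n // 2  # remembered for later calls
        return results


def make_worker(capacity: int) -> Callable[[list[Any]], int]:
    """Build a fake workload that fails when a batch exceeds `capacity`."""

    def worker(batch: list[Any]) -> int:
        if len(batch) > capacity:
            raise SimulatedOOM(f"batch of {len(batch)} exceeds {capacity}")
        return len(batch)

    return worker


batcher = AdaptiveBatcher()
items = list(range(300))
response_chunks = batcher.run(items, "responses", make_worker(40))
residual_chunks = batcher.run(items, "residuals", make_worker(200))
print(batcher.batch_sizes)
```

With these made-up capacities, the "responses" size backs off from 128 to 64 to 32, while "residuals" and "logprobs" stay at 128, which is the per-method independence the description refers to.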

Fixes #248.

Copilot AI review requested due to automatic review settings June 17, 2026 14:40

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR replaces the previous “auto batch size benchmarking” flow with adaptive batching that automatically reduces batch size on CUDA OOM, and updates configuration accordingly.

Changes:

  • Add an internal batching helper that halves batch size on CUDA out-of-memory errors
  • Refactor *_batched methods to use adaptive batching instead of batchify
  • Remove the CLI-time batch size autotuning logic and simplify settings (batch_size no longer supports 0=auto)

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/heretic/model.py | Introduces _batched() with OOM backoff and rewires batched inference methods to use it |
| src/heretic/main.py | Removes runtime batch-size autotuning benchmark loop |
| src/heretic/config.py | Updates batch_size semantics/docs and removes max_batch_size setting |

Comment thread src/heretic/model.py
Comment on lines +689 to +696
            try:
                results.append(fn(prompts[i : i + n]))
                i += n
            except torch.cuda.OutOfMemoryError:
                if n == 1:
                    raise
                self._batch_sizes[key] = n // 2
                empty_cache()
Comment thread src/heretic/config.py
Comment on lines 184 to 188
     batch_size: int = Field(
-        default=0, # auto
-        description="Number of input sequences to process in parallel (0 = auto).",
-    )
-
-    max_batch_size: int = Field(
         default=128,
-        description="Maximum batch size to try when automatically determining the optimal batch size.",
-        # When storing a settings object, the batch size is already fixed,
-        # either determined by the automatic mechanism or by explicit user choice.
-        exclude=True,
+        description="Number of input sequences to process in parallel. "
+        "If an out-of-memory error occurs, the batch size is halved automatically.",
     )

Owner

No problem. We're in the 2.0 development cycle, where breaking changes are expected.

Comment thread src/heretic/model.py
Comment on lines +102 to +105
        self._batch_sizes: dict[str, int] = {
            key: max(1, settings.batch_size)
            for key in ("responses", "residuals", "logprobs")
        }
Comment thread src/heretic/config.py
Comment on lines 184 to 188
     batch_size: int = Field(
-        default=0, # auto
-        description="Number of input sequences to process in parallel (0 = auto).",
-    )
-
-    max_batch_size: int = Field(
         default=128,
-        description="Maximum batch size to try when automatically determining the optimal batch size.",
-        # When storing a settings object, the batch size is already fixed,
-        # either determined by the automatic mechanism or by explicit user choice.
-        exclude=True,
+        description="Number of input sequences to process in parallel. "
+        "If an out-of-memory error occurs, the batch size is halved automatically.",
     )

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces the static, upfront automatic batch size determination with a dynamic batching mechanism that automatically halves the batch size upon encountering an out-of-memory error. It updates the default batch_size to 128 and removes the max_batch_size setting. The review feedback points out a style guide violation regarding updating config.default.toml to match the new configuration settings. Additionally, suggestions are provided to improve backward compatibility for users with batch_size = 0 and to optimize the dynamic batching helper by handling non-CUDA OOM errors and avoiding traceback-related memory leaks.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/heretic/config.py
Comment on lines 184 to 188
     batch_size: int = Field(
-        default=0, # auto
-        description="Number of input sequences to process in parallel (0 = auto).",
-    )
-
-    max_batch_size: int = Field(
         default=128,
-        description="Maximum batch size to try when automatically determining the optimal batch size.",
-        # When storing a settings object, the batch size is already fixed,
-        # either determined by the automatic mechanism or by explicit user choice.
-        exclude=True,
+        description="Number of input sequences to process in parallel. "
+        "If an out-of-memory error occurs, the batch size is halved automatically.",
     )

Contributor

high

Since the default value of batch_size has been changed to 128 and max_batch_size has been removed, please ensure that config.default.toml is updated accordingly to keep them in sync. If config.default.toml is left with batch_size = 0, any user running with the default configuration will have settings.batch_size set to 0, which will resolve to an initial batch size of 1 in Model.__init__ and severely degrade performance.

References
  1. When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)

Comment thread src/heretic/model.py
Comment on lines +102 to +105
        self._batch_sizes: dict[str, int] = {
            key: max(1, settings.batch_size)
            for key in ("responses", "residuals", "logprobs")
        }

Contributor

high

To ensure backward compatibility with existing user configurations that have batch_size = 0 (the previous default for 'auto'), it is safer to map 0 (or any non-positive value) to the new default of 128 instead of silently falling back to 1 via max(1, settings.batch_size). A batch size of 1 will cause extremely slow inference without any warning to the user.

Suggested change
-        self._batch_sizes: dict[str, int] = {
-            key: max(1, settings.batch_size)
-            for key in ("responses", "residuals", "logprobs")
-        }
+        self._batch_sizes: dict[str, int] = {
+            key: settings.batch_size if settings.batch_size > 0 else 128
+            for key in ("responses", "residuals", "logprobs")
+        }
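
The suggested change maps non-positive values to 128 without telling the user. One possible variant, sketched here under the assumption that a small helper is acceptable, also emits a warning so the fallback is visible. `resolve_batch_size` and `DEFAULT_BATCH_SIZE` are hypothetical names, not part of Heretic.

```python
import warnings

DEFAULT_BATCH_SIZE = 128  # assumed to mirror the new default proposed in this PR


def resolve_batch_size(configured: int, default: int = DEFAULT_BATCH_SIZE) -> int:
    """Return a usable starting batch size for a possibly legacy setting."""
    if configured > 0:
        return configured
    # Older configs used 0 to mean "auto"; fall back instead of running at 1.
    warnings.warn(
        f"batch_size={configured} is not a valid batch size "
        f"(0 used to mean 'auto'); starting at {default} instead.",
        stacklevel=2,
    )
    return default


batch_sizes = {
    key: resolve_batch_size(0) for key in ("responses", "residuals", "logprobs")
}
print(batch_sizes)
```

Whether a warning or a hard validation error is the better choice during the 2.0 cycle is a design decision for the maintainers.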

Comment thread src/heretic/model.py
Comment on lines +681 to +697
    def _batched(
        self, prompts: list[Prompt], key: str, fn: Callable[[list[Prompt]], Any]
    ) -> list[Any]:
        """Run fn on prompt batches, halving the batch size on OOM."""
        results = []
        i = 0
        while i < len(prompts):
            n = self._batch_sizes[key]
            try:
                results.append(fn(prompts[i : i + n]))
                i += n
            except torch.cuda.OutOfMemoryError:
                if n == 1:
                    raise
                self._batch_sizes[key] = n // 2
                empty_cache()
        return results

Contributor

medium

There are two key improvements we can make to the _batched helper method:

  1. Traceback Memory Leak / GC Delay: In Python, when an exception is caught, the active exception and its traceback are kept alive during the execution of the except block. This traceback holds references to the stack frame, keeping any local tensors allocated during the failed forward pass alive. Calling empty_cache() inside the except block will therefore not be able to reclaim that memory. Moving empty_cache() outside/after the except block ensures the exception is cleared and the memory is successfully freed before retrying.
  2. Non-CUDA Support: The codebase supports non-CUDA accelerators (like MPS and XPU) as seen in utils.py. Catching only torch.cuda.OutOfMemoryError means OOMs on other platforms (which typically raise RuntimeError with an "out of memory" message) won't be caught, causing crashes instead of adapting the batch size. We can catch both and check the error message for robustness.
    def _batched(
        self, prompts: list[Prompt], key: str, fn: Callable[[list[Prompt]], Any]
    ) -> list[Any]:
        """Run fn on prompt batches, halving the batch size on OOM."""
        results = []
        i = 0
        while i < len(prompts):
            n = self._batch_sizes[key]
            oom = False
            try:
                results.append(fn(prompts[i : i + n]))
                i += n
            except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
                if isinstance(error, RuntimeError) and "out of memory" not in str(error):
                    raise
                if n == 1:
                    raise
                self._batch_sizes[key] = n // 2
                oom = True
            if oom:
                empty_cache()
        return results
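
The traceback point in item 1 can be observed without a GPU. The sketch below is an illustration on CPython, not code from this PR: `FakeTensor`, `failing_forward`, and `probe` are hypothetical names, and a weak reference is used to watch whether a local object from the failed call is still alive inside and after the `except` block.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large tensor allocated in a forward pass."""


def failing_forward(refs: list) -> None:
    activations = FakeTensor()
    # A weak reference lets us observe the object without keeping it alive.
    refs.append(weakref.ref(activations))
    raise MemoryError("simulated out of memory")


def probe() -> tuple[bool, bool]:
    refs: list = []
    alive_inside = False
    try:
        failing_forward(refs)
    except MemoryError:
        # The exception being handled still carries a traceback that
        # references the failed frame and therefore its local variables.
        alive_inside = refs[0]() is not None
    # Once the except block has ended, the exception is cleared and the
    # failed frame (with its locals) can be released.
    alive_after = refs[0]() is not None
    return alive_inside, alive_after


inside, after = probe()
print(inside, after)
```

On CPython, which frees objects by reference counting, this prints `True False`: the object survives for the duration of the `except` block and is released once the block ends. Other interpreters may release it later.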

@kali113 kali113 force-pushed the fix/adaptive-batch-sizing branch from 32bd2ba to 09bbdff Compare June 17, 2026 14:42

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces the startup-time automatic batch size determination with a dynamic, adaptive batching mechanism that halves the batch size upon encountering an out-of-memory error during execution. Feedback on these changes highlights three key areas for improvement: addressing backward compatibility for existing configurations where batch_size is set to 0 (which would currently fall back to an inefficient batch size of 1), extending the OOM exception handling to support non-CUDA platforms (such as MPS and XPU) by catching relevant RuntimeErrors, and updating config.default.toml to align with the configuration changes in config.py as required by the repository style guide.

Comment thread src/heretic/model.py
Comment on lines +102 to +105
        self._batch_sizes: dict[str, int] = {
            key: max(1, settings.batch_size)
            for key in ("responses", "residuals", "logprobs")
        }

Contributor

high

In previous versions of Heretic, batch_size = 0 was the default value representing "auto" batch sizing. If an existing user upgrades and runs Heretic with their old config.toml containing batch_size = 0, max(1, settings.batch_size) will evaluate to 1. This will silently force a batch size of 1 for all operations, causing severe performance degradation without any warning.

To preserve backward compatibility and ensure a smooth upgrade path, please default to 128 if settings.batch_size is 0 or negative.

        self._batch_sizes: dict[str, int] = {
            key: settings.batch_size if settings.batch_size > 0 else 128
            for key in ('responses', 'residuals', 'logprobs')
        }

Comment thread src/heretic/model.py Outdated
Comment on lines +692 to +696
            except torch.cuda.OutOfMemoryError:
                if n == 1:
                    raise
                self._batch_sizes[key] = n // 2
                empty_cache()

Contributor

high

Catching only torch.cuda.OutOfMemoryError is CUDA-specific. Since Heretic explicitly supports other accelerators like MPS (Apple Silicon) and XPU (Intel GPUs), which raise a generic RuntimeError on out-of-memory conditions, catching only the CUDA-specific exception will cause the program to crash on those platforms instead of dynamically halving the batch size.

To ensure the adaptive batch sizing is resilient across all supported platforms, please catch both torch.cuda.OutOfMemoryError and RuntimeError (verifying that "out of memory" is in the error message).

            except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
                if isinstance(error, RuntimeError) and 'out of memory' not in str(error).lower():
                    raise
                if n == 1:
                    raise
                self._batch_sizes[key] = n // 2
                empty_cache()
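
The message check above could be factored into a small predicate so it is testable on its own. This is a hedged sketch: `is_oom_error` is a hypothetical helper, matching on the class name and on the substring "out of memory" is a heuristic, and the exact messages raised by each backend are assumptions rather than documented guarantees.

```python
def is_oom_error(error: BaseException) -> bool:
    """Heuristically decide whether an exception signals an accelerator OOM."""
    # Match dedicated OOM exception classes by name so that this sketch
    # does not need to import torch.
    if any(cls.__name__ == "OutOfMemoryError" for cls in type(error).__mro__):
        return True
    # Fall back to the message check suggested in the review comment.
    return isinstance(error, RuntimeError) and "out of memory" in str(error).lower()


# Illustrative inputs; the message texts are made up for the example.
print(is_oom_error(RuntimeError("MPS backend out of memory")))
print(is_oom_error(RuntimeError("shape mismatch")))
```

A caller would then use `except RuntimeError as error:` and re-raise when `is_oom_error(error)` is false. String matching remains fragile, so a backend that words its error differently would not be detected.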

Comment thread src/heretic/config.py
Comment on lines 184 to 188
     batch_size: int = Field(
-        default=0, # auto
-        description="Number of input sequences to process in parallel (0 = auto).",
-    )
-
-    max_batch_size: int = Field(
         default=128,
-        description="Maximum batch size to try when automatically determining the optimal batch size.",
-        # When storing a settings object, the batch size is already fixed,
-        # either determined by the automatic mechanism or by explicit user choice.
-        exclude=True,
+        description="Number of input sequences to process in parallel. "
+        "If an out-of-memory error occurs, the batch size is halved automatically.",
     )

Contributor

medium

The default value of batch_size has been changed to 128 and max_batch_size has been removed. To keep the configuration files in sync and adhere to the repository style guide, please update config.default.toml to reflect these changes (updating the default value of batch_size and removing max_batch_size).

References
  1. When settings are added or modified in config.py, they should also be updated in config.default.toml to match their default values and descriptions. (link)

@kali113 kali113 force-pushed the fix/adaptive-batch-sizing branch from 09bbdff to c9d0b0b Compare June 17, 2026 14:49
@kali113 kali113 force-pushed the fix/adaptive-batch-sizing branch from c9d0b0b to 043473e Compare June 17, 2026 14:57
p-e-w commented Jun 18, 2026

Owner

Thanks, this is looking good at first glance!

Please note that we're currently in a feature freeze until #53 is merged, so it might take me a while to merge this, but I definitely want this change as per my design in #248.

Out of interest, how come Copilot is commenting here?

Comment thread src/heretic/config.py
-    )
-
-    max_batch_size: int = Field(
         default=128,

Owner

100 is a better default per #248.

p-e-w commented Jun 18, 2026

Owner

Please note that this also needs to work on backends other than CUDA, of which we support several. Detecting OOM reliably might be a challenge.

Also, with some backends (notably those that use system RAM), an OOM might result in the entire Python process being collected by the OOM reaper on Linux. Not sure what the best way to deal with that is.

Development

Successfully merging this pull request may close these issues.

[RFC] A better automatic batch size mechanism

3 participants