Use adaptive per-method batch sizing, drop startup benchmark #390
kali113 wants to merge 1 commit into
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR replaces the previous “auto batch size benchmarking” flow with adaptive batching that automatically reduces batch size on CUDA OOM, and updates configuration accordingly.
Changes:
- Add an internal batching helper that halves batch size on CUDA out-of-memory errors
- Refactor *_batched methods to use adaptive batching instead of batchify
- Remove the CLI-time batch size autotuning logic and simplify settings (batch_size no longer supports 0 = auto)
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/heretic/model.py | Introduces _batched() with OOM backoff and rewires batched inference methods to use it |
| src/heretic/main.py | Removes runtime batch-size autotuning benchmark loop |
| src/heretic/config.py | Updates batch_size semantics/docs and removes max_batch_size setting |
| try: | ||
| results.append(fn(prompts[i : i + n])) | ||
| i += n | ||
| except torch.cuda.OutOfMemoryError: | ||
| if n == 1: | ||
| raise | ||
| self._batch_sizes[key] = n // 2 | ||
| empty_cache() |
| batch_size: int = Field( | ||
| default=0, # auto | ||
| description="Number of input sequences to process in parallel (0 = auto).", | ||
| ) | ||
|
|
||
| max_batch_size: int = Field( | ||
| default=128, | ||
| description="Maximum batch size to try when automatically determining the optimal batch size.", | ||
| # When storing a settings object, the batch size is already fixed, | ||
| # either determined by the automatic mechanism or by explicit user choice. | ||
| exclude=True, | ||
| description="Number of input sequences to process in parallel. " | ||
| "If an out-of-memory error occurs, the batch size is halved automatically.", | ||
| ) |
No problem. We're in the 2.0 development cycle, where breaking changes are expected.
| self._batch_sizes: dict[str, int] = { | ||
| key: max(1, settings.batch_size) | ||
| for key in ("responses", "residuals", "logprobs") | ||
| } |
| batch_size: int = Field( | ||
| default=0, # auto | ||
| description="Number of input sequences to process in parallel (0 = auto).", | ||
| ) | ||
|
|
||
| max_batch_size: int = Field( | ||
| default=128, | ||
| description="Maximum batch size to try when automatically determining the optimal batch size.", | ||
| # When storing a settings object, the batch size is already fixed, | ||
| # either determined by the automatic mechanism or by explicit user choice. | ||
| exclude=True, | ||
| description="Number of input sequences to process in parallel. " | ||
| "If an out-of-memory error occurs, the batch size is halved automatically.", | ||
| ) |
Code Review
This pull request replaces the static, upfront automatic batch size determination with a dynamic batching mechanism that automatically halves the batch size upon encountering an out-of-memory error. It updates the default batch_size to 128 and removes the max_batch_size setting. The review feedback points out a style guide violation regarding updating config.default.toml to match the new configuration settings. Additionally, suggestions are provided to improve backward compatibility for users with batch_size = 0 and to optimize the dynamic batching helper by handling non-CUDA OOM errors and avoiding traceback-related memory leaks.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| batch_size: int = Field( | ||
| default=0, # auto | ||
| description="Number of input sequences to process in parallel (0 = auto).", | ||
| ) | ||
|
|
||
| max_batch_size: int = Field( | ||
| default=128, | ||
| description="Maximum batch size to try when automatically determining the optimal batch size.", | ||
| # When storing a settings object, the batch size is already fixed, | ||
| # either determined by the automatic mechanism or by explicit user choice. | ||
| exclude=True, | ||
| description="Number of input sequences to process in parallel. " | ||
| "If an out-of-memory error occurs, the batch size is halved automatically.", | ||
| ) |
Since the default value of batch_size has been changed to 128 and max_batch_size has been removed, please ensure that config.default.toml is updated accordingly to keep them in sync. If config.default.toml is left with batch_size = 0, any user running with the default configuration will have settings.batch_size set to 0, which will resolve to an initial batch size of 1 in Model.__init__ and severely degrade performance.
References
- When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)
| self._batch_sizes: dict[str, int] = { | ||
| key: max(1, settings.batch_size) | ||
| for key in ("responses", "residuals", "logprobs") | ||
| } |
To ensure backward compatibility with existing user configurations that have batch_size = 0 (the previous default for 'auto'), it is safer to map 0 (or any non-positive value) to the new default of 128 instead of silently falling back to 1 via max(1, settings.batch_size). A batch size of 1 will cause extremely slow inference without any warning to the user.
| self._batch_sizes: dict[str, int] = { | |
| key: max(1, settings.batch_size) | |
| for key in ("responses", "residuals", "logprobs") | |
| } | |
| self._batch_sizes: dict[str, int] = { | |
| key: settings.batch_size if settings.batch_size > 0 else 128 | |
| for key in ("responses", "residuals", "logprobs") | |
| } |
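The comment above worries that a legacy value is remapped "without any warning to the user". A minimal sketch of one way to combine the suggested fallback with a visible warning follows; `resolve_batch_size` and `DEFAULT_BATCH_SIZE` are hypothetical names for illustration, not part of Heretic.

```python
import warnings

# Assumption: mirrors the new default of 128 discussed in this review.
DEFAULT_BATCH_SIZE = 128


def resolve_batch_size(configured: int, default: int = DEFAULT_BATCH_SIZE) -> int:
    """Return a usable starting batch size, warning about legacy values."""
    if configured > 0:
        return configured
    # 0 used to mean "auto"; fall back to the default instead of silently using 1.
    warnings.warn(
        f"batch_size = {configured} is no longer supported; "
        f"starting at {default} instead.",
        stacklevel=2,
    )
    return default


# Resolve once, then let every method adapt from the same starting point.
initial = resolve_batch_size(0)
batch_sizes = {key: initial for key in ("responses", "residuals", "logprobs")}
print(batch_sizes)
```

Resolving the value once keeps the warning from being raised separately for each method.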
| def _batched( | ||
| self, prompts: list[Prompt], key: str, fn: Callable[[list[Prompt]], Any] | ||
| ) -> list[Any]: | ||
| """Run fn on prompt batches, halving the batch size on OOM.""" | ||
| results = [] | ||
| i = 0 | ||
| while i < len(prompts): | ||
| n = self._batch_sizes[key] | ||
| try: | ||
| results.append(fn(prompts[i : i + n])) | ||
| i += n | ||
| except torch.cuda.OutOfMemoryError: | ||
| if n == 1: | ||
| raise | ||
| self._batch_sizes[key] = n // 2 | ||
| empty_cache() | ||
| return results |
There are two key improvements we can make to the _batched helper method:
- Traceback Memory Leak / GC Delay: In Python, when an exception is caught, the active exception and its traceback are kept alive during the execution of the except block. This traceback holds references to the stack frame, keeping any local tensors allocated during the failed forward pass alive. Calling empty_cache() inside the except block will therefore not be able to reclaim that memory. Moving empty_cache() outside/after the except block ensures the exception is cleared and the memory is successfully freed before retrying.
- Non-CUDA Support: The codebase supports non-CUDA accelerators (like MPS and XPU) as seen in utils.py. Catching only torch.cuda.OutOfMemoryError means OOMs on other platforms (which typically raise RuntimeError with an "out of memory" message) won't be caught, causing crashes instead of adapting the batch size. We can catch both and check the error message for robustness.
def _batched(
    self, prompts: list[Prompt], key: str, fn: Callable[[list[Prompt]], Any]
) -> list[Any]:
    """Run fn on prompt batches, halving the batch size on OOM."""
    results = []
    i = 0
    while i < len(prompts):
        n = self._batch_sizes[key]
        oom = False
        try:
            results.append(fn(prompts[i : i + n]))
            i += n
        except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
            if isinstance(error, RuntimeError) and "out of memory" not in str(error):
                raise
            if n == 1:
                raise
            self._batch_sizes[key] = n // 2
            oom = True
        if oom:
            empty_cache()
    return results

32bd2ba to 09bbdff
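The traceback point in this review can be checked without a GPU. The sketch below is a stand-alone illustration, assuming CPython's reference counting; `FakeTensor`, `failing_forward`, and `probe` are hypothetical names, and a plain object stands in for a tensor allocated during the failed forward pass.

```python
import gc
import weakref


class FakeTensor:
    """Stand-in for a large tensor allocated during a forward pass."""


def failing_forward(refs: list) -> None:
    activations = FakeTensor()  # local variable owned by this stack frame
    refs.append(weakref.ref(activations))
    raise RuntimeError("simulated out of memory")


def probe() -> tuple[bool, bool]:
    """Report whether the local is alive inside and after the except block."""
    refs: list = []
    alive_inside = False
    try:
        failing_forward(refs)
    except RuntimeError:
        gc.collect()
        # The exception being handled still references the failed frame
        # through its traceback, so the frame's locals cannot be freed yet.
        alive_inside = refs[0]() is not None
    gc.collect()
    # Leaving the except block clears the exception and with it the frame.
    alive_after = refs[0]() is not None
    return alive_inside, alive_after


print(probe())  # expected on CPython: (True, False)
```

If the same holds for real tensors, a cache flush issued inside the except block would run while the failed batch's memory is still referenced, which is the motivation for the `oom` flag in the suggestion above.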
Code Review
This pull request replaces the startup-time automatic batch size determination with a dynamic, adaptive batching mechanism that halves the batch size upon encountering an out-of-memory error during execution. Feedback on these changes highlights three key areas for improvement: addressing backward compatibility for existing configurations where batch_size is set to 0 (which would currently fallback to an inefficient batch size of 1), extending the OOM exception handling to support non-CUDA platforms (such as MPS and XPU) by catching relevant RuntimeErrors, and updating config.default.toml to align with the configuration changes in config.py as required by the repository style guide.
| self._batch_sizes: dict[str, int] = { | ||
| key: max(1, settings.batch_size) | ||
| for key in ("responses", "residuals", "logprobs") | ||
| } |
In previous versions of Heretic, batch_size = 0 was the default value representing "auto" batch sizing. If an existing user upgrades and runs Heretic with their old config.toml containing batch_size = 0, max(1, settings.batch_size) will evaluate to 1. This will silently force a batch size of 1 for all operations, causing severe performance degradation without any warning.

To preserve backward compatibility and ensure a smooth upgrade path, please default to 128 if settings.batch_size is 0 or negative.
self._batch_sizes: dict[str, int] = {
    key: settings.batch_size if settings.batch_size > 0 else 128
    for key in ('responses', 'residuals', 'logprobs')
}
| except torch.cuda.OutOfMemoryError: | ||
| if n == 1: | ||
| raise | ||
| self._batch_sizes[key] = n // 2 | ||
| empty_cache() |
Catching only torch.cuda.OutOfMemoryError is CUDA-specific. Since Heretic explicitly supports other accelerators like MPS (Apple Silicon) and XPU (Intel GPUs), which raise a generic RuntimeError on out-of-memory conditions, catching only the CUDA-specific exception will cause the program to crash on those platforms instead of dynamically halving the batch size.

To ensure the adaptive batch sizing is resilient across all supported platforms, please catch both torch.cuda.OutOfMemoryError and RuntimeError (verifying that "out of memory" is in the error message).
except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
    if isinstance(error, RuntimeError) and 'out of memory' not in str(error).lower():
        raise
    if n == 1:
        raise
    self._batch_sizes[key] = n // 2
    empty_cache()
| batch_size: int = Field( | ||
| default=0, # auto | ||
| description="Number of input sequences to process in parallel (0 = auto).", | ||
| ) | ||
|
|
||
| max_batch_size: int = Field( | ||
| default=128, | ||
| description="Maximum batch size to try when automatically determining the optimal batch size.", | ||
| # When storing a settings object, the batch size is already fixed, | ||
| # either determined by the automatic mechanism or by explicit user choice. | ||
| exclude=True, | ||
| description="Number of input sequences to process in parallel. " | ||
| "If an out-of-memory error occurs, the batch size is halved automatically.", | ||
| ) |
The default value of batch_size has been changed to 128 and max_batch_size has been removed. To keep the configuration files in sync and adhere to the repository style guide, please update config.default.toml to reflect these changes (updating the default value of batch_size and removing max_batch_size).
References
- When settings are added or modified in config.py, they should also be updated in config.default.toml to match their default values and descriptions. (link)
09bbdff to c9d0b0b
c9d0b0b to 043473e
| ) | ||
|
|
||
| max_batch_size: int = Field( | ||
| default=128, |
|
Please note that this also needs to work on backends other than CUDA, of which we support several. Detecting OOM reliably might be a challenge. Also, with some backends (notably those that use system RAM), an OOM might result in the entire Python process being collected by the OOM reaper on Linux. Not sure what the best way to deal with that is.
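One possible shape for a backend-agnostic check is a message-based heuristic, sketched below. `looks_like_oom` is a hypothetical helper, and the assumption that other backends report exhaustion as a RuntimeError containing "out of memory" comes from the review comments in this thread rather than from verified backend behaviour.

```python
def looks_like_oom(error: BaseException) -> bool:
    """Heuristically decide whether an exception signals memory exhaustion."""
    if isinstance(error, MemoryError):  # raised by Python itself for host RAM
        return True
    # torch.cuda.OutOfMemoryError subclasses RuntimeError, so this branch is
    # assumed to cover it along with message-only errors from other backends.
    return isinstance(error, RuntimeError) and "out of memory" in str(error).lower()


print(looks_like_oom(RuntimeError("MPS backend out of memory")))
print(looks_like_oom(RuntimeError("shape mismatch")))
```

A check like this only helps when the backend raises an exception at all; a process killed by the kernel never reaches the handler, so that case would need a different mitigation.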
Dropped the old startup benchmark because looping through batch sizes from 1 to 128 was just wasting time on warmups and forcing an inefficient global size on everything.
Now, responses, residuals, and logprobs track their own batch sizes independently, defaulting to 128. If a method hits a CUDA OOM, it catches it, halves its batch size, and retries. That new size is saved for its next runs.
This fixes three main things. First, there is no startup lag so inference starts immediately. Second, it improves VRAM efficiency because responses take way more VRAM than residuals, so now each method actually finds its own limits. Finally, it adds resilience since the system dynamically adapts if memory pressure changes mid-run.
Fixes #248.
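The per-method behaviour described above can be simulated without a GPU. This is a simplified sketch, not the code from the PR: `AdaptiveBatcher`, `SimulatedOOM`, and `limited` are hypothetical names, and a fixed item capacity stands in for available VRAM.

```python
from typing import Callable


class SimulatedOOM(RuntimeError):
    """Stand-in for a backend out-of-memory error."""


class AdaptiveBatcher:
    """Tracks one batch size per method key and halves it on failure."""

    def __init__(self, initial: int, keys: tuple[str, ...]) -> None:
        self.batch_sizes = {key: initial for key in keys}

    def run(self, items: list, key: str, fn: Callable[[list], object]) -> list:
        results, i = [], 0
        while i < len(items):
            n = self.batch_sizes[key]
            try:
                results.append(fn(items[i : i + n]))
                i += n
            except SimulatedOOM:
                if n == 1:
                    raise  # cannot shrink further
                self.batch_sizes[key] = n // 2  # remembered for later calls
        return results


def limited(capacity: int) -> Callable[[list], int]:
    """Build a fake workload that fails when a batch exceeds capacity."""

    def fn(batch: list) -> int:
        if len(batch) > capacity:
            raise SimulatedOOM("simulated out of memory")
        return len(batch)

    return fn


batcher = AdaptiveBatcher(128, ("responses", "residuals"))
items = list(range(300))
responses = batcher.run(items, "responses", limited(20))
residuals = batcher.run(items, "residuals", limited(100))
print(batcher.batch_sizes)  # {'responses': 16, 'residuals': 64}
```

Halving settles on the largest power-of-two fraction of the starting size that fits (16 rather than 20 here), so some headroom is left unused in exchange for a very simple search.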