diff --git a/.gitignore b/.gitignore index fda3e64..c615a79 100644 --- a/.gitignore +++ b/.gitignore @@ -215,4 +215,8 @@ __marimo__/ .vscode/settings.json # Ignore logs -logs/ \ No newline at end of file +logs/ +*/logs/ +*ckpt +.nfs** +examples/cifar-10-batches-py/* diff --git a/examples/batch_accumulation.ipynb b/examples/batch_accumulation.ipynb new file mode 100644 index 0000000..b7881d8 --- /dev/null +++ b/examples/batch_accumulation.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "630912d7", + "metadata": {}, + "source": [ + "# Batch Accumulation with Perspic\n", + "\n", + "This notebook demonstrates how to use **gradient accumulation** with the `perspic` analyzer.\n", + "Gradient accumulation lets you simulate a large effective batch size while only fitting a small micro-batch in GPU memory.\n", + "\n", + "We train a small Vision Transformer on CIFAR-10 and compare training **with** and **without** batch accumulation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e2d7ffb7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 7\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from pytorch_lightning.callbacks import LearningRateMonitor\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import CIFAR10\n", + "from torchvision.models import VisionTransformer\n", + "\n", + "from perspic.analyzer import analyzer\n", + "from examples.models import ClassificationModule\n", + "\n", + "pl.seed_everything(7)\n", + "\n", + "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", + "MICRO_BATCH_SIZE = 64\n", + "EFFECTIVE_BATCH_SIZE = 256 # 4x accumulation\n", + "NUM_WORKERS = int(os.cpu_count() / 2)" + ] + }, + { + "cell_type": "markdown", + "id": "d0818a3e", + "metadata": {}, + "source": [ + "## Data Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7ef5f919", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 170M/170M [00:07<00:00, 23.0MB/s] \n" + ] + } + ], + "source": [ + "stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n", + "train_transform = torchvision.transforms.Compose([\n", + " torchvision.transforms.RandomCrop(32, padding=4),\n", + " torchvision.transforms.RandomHorizontalFlip(),\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize(*stats),\n", + "])\n", + "test_transform = torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize(*stats),\n", + "])\n", + "\n", + "train_dataset_full = CIFAR10(PATH_DATASETS, train=True, download=True, transform=train_transform)\n", + "val_dataset_full = CIFAR10(PATH_DATASETS, train=True, download=True, transform=test_transform)\n", + "test_set = CIFAR10(PATH_DATASETS, train=False, download=True, transform=test_transform)\n", + "\n", + "generator = torch.Generator().manual_seed(42)\n", + "train_set, _ = random_split(train_dataset_full, [45000, 5000], generator=generator)\n", + "_, val_set = random_split(val_dataset_full, [45000, 5000], generator=generator)\n", + "\n", + "# Use micro-batch size for the DataLoader — accumulation handles the rest\n", + "train_dataloader = DataLoader(train_set, batch_size=MICRO_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)\n", + "val_dataloader = DataLoader(val_set, batch_size=MICRO_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, drop_last=True)\n", + "test_dataloader = DataLoader(test_set, batch_size=MICRO_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, drop_last=True)" + ] + }, + { + "cell_type": "markdown", + "id": "321b9a72", + "metadata": {}, + "source": [ + "## Model Definition\n", + "\n", + "A small Vision Transformer suitable for CIFAR-10." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "06369a8f", + "metadata": {}, + "outputs": [], + "source": [ + "model_vit = VisionTransformer(\n", + " image_size=32,\n", + " patch_size=8,\n", + " num_layers=2,\n", + " num_heads=4,\n", + " hidden_dim=128,\n", + " mlp_dim=256,\n", + " num_classes=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "563f9db6", + "metadata": {}, + "source": [ + "## Training with Batch Accumulation\n", + "\n", + "The key parameters are `micro_batch_size` and `effective_batch_size`. The analyzer will:\n", + "- Zero gradients only at the start of each accumulation cycle\n", + "- Scale the loss by `1 / accumulation_steps` during backward\n", + "- Step the optimizer only after `accumulation_steps = effective_batch_size // micro_batch_size` micro-batches\n", + "- Accumulate analysis metrics across micro-batches and log once per effective step" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "998cb095", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accumulation steps: 4\n", + "Micro-batch size: 64\n", + "Effective batch: 256\n" + ] + } + ], + "source": [ + "vit_accum = analyzer(\n", + " lightning_module=ClassificationModule,\n", + " sample_wise_engine=\"opacus\",\n", + " micro_batch_size=MICRO_BATCH_SIZE,\n", + " effective_batch_size=EFFECTIVE_BATCH_SIZE,\n", + " model=model_vit,\n", + " lr=0.005,\n", + ")\n", + "\n", + "print(f\"Accumulation steps: {vit_accum.accumulation_steps}\")\n", + "print(f\"Micro-batch size: {MICRO_BATCH_SIZE}\")\n", + "print(f\"Effective batch: {EFFECTIVE_BATCH_SIZE}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "930b6ba5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "----------------------------------------------------\n", + "0 | model | VisionTransformer | 293 K | train\n", + "----------------------------------------------------\n", + "293 K Trainable params\n", + "0 Non-trainable params\n", + "293 K Total params\n", + "1.174 Total estimated model params size (MB)\n", + "32 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c212575e0e8c49fe97c41ce521e73e22", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00:0: UserWarning: Full backward hook is firing when gradients are computed with respect to module outputs since no inputs require gradients. See https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook for more details.\n", + "/tikhome/jscheunemann/usr/miniconda3/envs/mast/lib/python3.13/site-packages/torch/autograd/graph.py:829: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_scaled_dot_product_efficient_attention_backward. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:81.)\n", + " return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "887a7dcd3bdd43b9a9ba1c78500c4c7a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | | 0/? [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n", + "\n", + "fig, axes = plt.subplots(2, 2, figsize=(12, 8))\n", + "fig.suptitle(f\"ViT with Batch Accumulation ({MICRO_BATCH_SIZE} → {EFFECTIVE_BATCH_SIZE})\", fontsize=14)\n", + "\n", + "# Loss\n", + "ax = axes[0, 0]\n", + "subset = metrics[[\"train_loss\", \"step\"]].dropna()\n", + "ax.plot(subset[\"step\"], subset[\"train_loss\"], alpha=0.3, label=\"Raw\")\n", + "ax.plot(subset[\"step\"], subset[\"train_loss\"].ewm(span=50).mean(), label=\"EMA\")\n", + "ax.set_title(\"Train Loss\")\n", + "ax.set_xlabel(\"Step\")\n", + "ax.legend()\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# chi_net\n", + "ax = axes[0, 1]\n", + "step_col = \"analysis_step\" if \"analysis_step\" in metrics.columns else \"step\"\n", + "subset = metrics[[\"chi_net\", step_col]].dropna()\n", + "if len(subset) > 0:\n", + " ax.plot(subset[step_col], subset[\"chi_net\"], alpha=0.3, label=\"Raw\")\n", + " ax.plot(subset[step_col], subset[\"chi_net\"].ewm(span=50).mean(), label=\"EMA\")\n", + "ax.set_title(\"χ_net\")\n", + "ax.set_xlabel(\"Step\")\n", + "ax.set_yscale(\"log\")\n", + "ax.legend()\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# chi_loss\n", + "ax = axes[1, 0]\n", + "subset = metrics[[\"chi_loss\", step_col]].dropna()\n", + "if len(subset) > 0:\n", + " ax.plot(subset[step_col], subset[\"chi_loss\"], alpha=0.3, label=\"Raw\")\n", + " ax.plot(subset[step_col], subset[\"chi_loss\"].ewm(span=50).mean(), label=\"EMA\")\n", + "ax.set_title(\"χ_loss\")\n", + "ax.set_xlabel(\"Step\")\n", + "ax.legend()\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "# Coupling\n", + "ax = axes[1, 1]\n", + "subset = metrics[[\"chi_coup\", step_col]].dropna()\n", + "if len(subset) > 0:\n", + " ax.plot(subset[step_col], subset[\"chi_coup\"], alpha=0.3, label=\"Raw\")\n", + " ax.plot(subset[step_col], subset[\"chi_coup\"].ewm(span=50).mean(), label=\"EMA\")\n", + "ax.set_title(\"Coupling Coefficient (χ_coup)\")\n", + "ax.set_xlabel(\"Step\")\n", + "ax.set_yscale(\"log\")\n", + "ax.legend()\n", + "ax.grid(True, alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0c915b2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chi_netchi_losschi_coupstep
0NaNNaNNaN0
1NaNNaNNaN0
2NaNNaNNaN1
3NaNNaNNaN1
4NaNNaNNaN2
...............
7031242.86471613.599990.0000413513
7032NaNNaNNaN3514
7033242.86471613.599990.0000413514
7034NaNNaNNaN3514
7035NaNNaNNaN3515
\n", + "

7036 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " chi_net chi_loss chi_coup step\n", + "0 NaN NaN NaN 0\n", + "1 NaN NaN NaN 0\n", + "2 NaN NaN NaN 1\n", + "3 NaN NaN NaN 1\n", + "4 NaN NaN NaN 2\n", + "... ... ... ... ...\n", + "7031 242.864716 13.59999 0.000041 3513\n", + "7032 NaN NaN NaN 3514\n", + "7033 242.864716 13.59999 0.000041 3514\n", + "7034 NaN NaN NaN 3514\n", + "7035 NaN NaN NaN 3515\n", + "\n", + "[7036 rows x 4 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics.keys()\n", + "metrics\n", + "metrics[[\"chi_net\", \"chi_loss\", \"chi_coup\", \"step\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65578b50", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mast", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/perspic/analyzer.py b/perspic/analyzer.py index 1c842e9..0ffafc4 100644 --- a/perspic/analyzer.py +++ b/perspic/analyzer.py @@ -22,6 +22,8 @@ def analyzer( analyze_every: Optional[int] = None, analysis_schedule: Optional[LogarithmicWindowSchedule] = None, cross_response: bool = False, + micro_batch_size: Optional[int] = None, + effective_batch_size: Optional[int] = None, **model_kwargs, ): """Factory function that wraps a LightningModule with analysis capabilities. @@ -52,11 +54,19 @@ def analyzer( analysis runs only at the scheduled steps. If both analyze_every and analysis_schedule are provided, analysis_schedule takes precedence. - cross_response: If True, enables cross-batch response analysis and assumes - the training batch is a dict with 'train' and 'measure' keys. - Defaults to False. - **model_kwargs: Additional keyword arguments passed to the - LightningModule constructor. + cross_response: If True, enables cross-batch response + analysis and assumes the training batch is a dict + with 'train' and 'measure' keys. Defaults to False. + micro_batch_size: The actual micro-batch size used by the + DataLoader. Required when effective_batch_size is + set. Can be provided alone (no accumulation). + effective_batch_size: The desired simulated batch size + achieved through gradient accumulation. Must be + divisible by micro_batch_size. When set, the optimizer + step is only performed every + (effective_batch_size // micro_batch_size) micro-batches. + **model_kwargs: Additional keyword arguments passed to + the LightningModule constructor. Returns: An initialized Analyzer instance that wraps the provided @@ -117,6 +127,8 @@ def __init__( analyze_every=analyze_every, analysis_schedule=analysis_schedule, cross_response=cross_response, + micro_batch_size=micro_batch_size, + effective_batch_size=effective_batch_size, **model_kwargs, ): super().__init__(**model_kwargs) @@ -178,6 +190,67 @@ def __init__( self.delegate_optimization = False self.automatic_optimization = False # We handle optimization manually + # Gradient accumulation setup + self.micro_batch_size = micro_batch_size + self.effective_batch_size = effective_batch_size + + if ( + effective_batch_size is not None + and micro_batch_size is None + ): + raise ValueError( + "micro_batch_size must be specified when " + "effective_batch_size is set." + ) + + if ( + micro_batch_size is not None + and effective_batch_size is not None + ): + if effective_batch_size < micro_batch_size: + raise ValueError( + f"effective_batch_size " + f"({effective_batch_size}) must be " + f">= micro_batch_size ({micro_batch_size})." + ) + if effective_batch_size % micro_batch_size != 0: + raise ValueError( + f"effective_batch_size " + f"({effective_batch_size}) must be " + f"divisible by micro_batch_size " + f"({micro_batch_size})." + ) + self.accumulation_steps = ( + effective_batch_size // micro_batch_size + ) + else: + self.accumulation_steps = 1 + + if ( + self.accumulation_steps > 1 + and self.delegate_optimization + ): + raise ValueError( + "Gradient accumulation is not supported " + "when the wrapped model uses manual " + "optimization (delegate_optimization=True)." + ) + + self._accumulation_count = 0 + self._optimizer_step_count = 0 + + # Analysis accumulation buffers + self._accum_chi_net = [] + self._accum_chi_loss = [] + self._accum_cross_chi_net = [] + self._accum_cross_chi_loss = [] + self._accum_grad_train = None + self._accum_grad_measure = None + self._accum_train_loss = 0.0 + self._accum_measure_loss = 0.0 + # Track whether analysis is active for this cycle + self._analysis_active = False + # Check if model has criterion attribute if not hasattr(self, "criterion"): raise AttributeError( @@ -223,29 +296,53 @@ def training_step(self, batch, batch_idx): # Initializing manual optimization opt = self.optimizers() - opt.zero_grad() + + # Zero gradients only at start of accumulation cycle + if self._accumulation_count == 0: + opt.zero_grad() # BEFORE logic if not self.disable_analyzer: - self._before_training_step(batch, batch_idx, batch_measure) + self._before_training_step( + batch, batch_idx, batch_measure + ) # Original training step output = super().training_step(batch, batch_idx) if not self.delegate_optimization: - # Backward pass - self.manual_backward(output) - # Optimizer step - opt.step() + # Scale loss for gradient accumulation + scaled_output = ( + output / self.accumulation_steps + ) + self.manual_backward(scaled_output) + + self._accumulation_count += 1 - # Step schedulers with interval='step' - if self._trainer is not None and self.trainer.lr_scheduler_configs: - for config in self.trainer.lr_scheduler_configs: - if config.interval == "step": - config.scheduler.step() + # Step optimizer only at end of accumulation cycle + if ( + self._accumulation_count + >= self.accumulation_steps + ): + opt.step() + self._optimizer_step_count += 1 + self._accumulation_count = 0 + + # Step schedulers with interval='step' + if ( + self._trainer is not None + and self.trainer.lr_scheduler_configs + ): + for config in ( + self.trainer.lr_scheduler_configs + ): + if config.interval == "step": + config.scheduler.step() # AFTER logic if not self.disable_analyzer: - self._after_training_step(batch, batch_idx, output) + self._after_training_step( + batch, batch_idx, output + ) return output @@ -263,6 +360,13 @@ def on_train_epoch_end(self): super().on_train_epoch_end() + @property + def effective_step(self): + """Return the effective optimizer step count.""" + if self.delegate_optimization: + return self.global_step + return self._optimizer_step_count + def _should_analyze(self, step: int) -> bool: """Determine if analysis should run at the given step.""" # If schedule provided, use it @@ -278,8 +382,9 @@ def _before_training_step(self, batch, batch_idx, cross_response_batch=None): """Hook executed before the wrapped training step. Computes analysis metrics including sample-wise gradients and - linearization probes. Only runs if the scheduler determines - this step should be analyzed. + linearization probes. When gradient accumulation is active, + metrics are accumulated across micro-batches and only logged + after the full accumulation cycle. Args: batch: Training batch containing input data and labels. @@ -289,67 +394,64 @@ def _before_training_step(self, batch, batch_idx, cross_response_batch=None): Returns: None """ - # Check if we should run analysis at this step - if not self._should_analyze(self.global_step): + if self.accumulation_steps == 1: + return self._analyze_single_step( + batch, batch_idx, cross_response_batch + ) + else: + return self._analyze_accumulated_step( + batch, batch_idx, cross_response_batch + ) + + def _analyze_single_step(self, batch, batch_idx, cross_response_batch=None): + """Run analysis for a single step (no accumulation).""" + if not self._should_analyze(self.effective_step): return None x, y = batch - - # Get cross-response batch if available + # Get cross-response batch if applicable x2, y2 = None, None if self.cross_response: x2, y2 = cross_response_batch samples_results = {} with BatchStatSnapshot(self.model, x): - # Compute samplewise metrics for the training batch + # Compute sample-wise metrics and self response samples_results["self"] = self.sample_calc.compute( - self.model, - self.criterion, - x, - y, + self.model, self.criterion, x, y, ) - # Compute samplewise metrics for the cross batch if available + # Compute sample-wise metrics and cross response if applicable if x2 is not None and y2 is not None: - probe_results_cross_preliminary = self.sample_calc.compute( - self.model, - self.criterion, - x2, - y2, + cross_preliminary = self.sample_calc.compute( + self.model, self.criterion, x2, y2, ) - samples_results["cross"] = self.sample_calc.compute_cross_metrics( - sample_wise_metrics_self=samples_results["self"], - sample_wise_metrics_cross=probe_results_cross_preliminary, + samples_results["cross"] = ( + self.sample_calc.compute_cross_metrics( + sample_wise_metrics_self=samples_results["self"], + sample_wise_metrics_cross=cross_preliminary, + ) ) - - # Linearizer compute + # Linearizer probe probe_results = self.linearizer.compute( - model=self.model, - criterion=self.criterion, - x1=x, - y1=y, - x2=x2, - y2=y2, + model=self.model, criterion=self.criterion, + x1=x, y1=y, x2=x2, y2=y2, ) # Get "self" result for coupling calculation loss_self, _, delta_loss_self = probe_results["self"] - - # Compute coupling value (using self response) chi_coup = self.coupling_calc.calculate( delta_loss=delta_loss_self, chi_loss=samples_results["self"]["batch_grad_norms_loss"], chi_net=samples_results["self"]["batch_grad_norms_network"], ) + chi_coup_cross = None if self.cross_response and "cross" in probe_results: - # Optionally, compute coupling for cross response as well - loss_cross, _, delta_loss_cross = probe_results["cross"] + _, _, delta_loss_cross = probe_results["cross"] chi_coup_cross = self.coupling_calc.calculate( delta_loss=delta_loss_cross, chi_loss=samples_results["cross"]["batch_grad_norms_loss"], chi_net=samples_results["cross"]["batch_grad_norms_network"], ) - # Log results with fixed metric names if self.log_metrics: self._log_analysis_results( @@ -359,7 +461,6 @@ def _before_training_step(self, batch, batch_idx, cross_response_batch=None): chi_coup=chi_coup, batch_size=x.shape[0], ) - # Log cross response if available if "cross" in samples_results and samples_results["cross"] is not None: self._log_analysis_results( @@ -369,11 +470,10 @@ def _before_training_step(self, batch, batch_idx, cross_response_batch=None): chi_coup=chi_coup_cross, batch_size=x2.shape[0] if x2 is not None else 0, ) - # Log window tracking info if using logarithmic schedule if self.analysis_schedule is not None: window_info = self.analysis_schedule.get_window_info( - self.global_step + self.effective_step ) if window_info is not None: self.log("window_id", window_info["window_id"]) @@ -382,6 +482,219 @@ def _before_training_step(self, batch, batch_idx, cross_response_batch=None): return None + def _analyze_accumulated_step(self, batch, batch_idx, cross_response_batch=None): + """Run analysis with gradient accumulation across micro-batches. + + On each micro-batch: accumulate sample-wise metrics and linearizer + gradients. On the last micro-batch of the cycle: finalize, log, clear. + """ + # On first micro-batch of cycle, decide whether to analyze + if self._accumulation_count == 0: + self._analysis_active = self._should_analyze( + self.effective_step + ) + if self._analysis_active: + self._clear_accumulation_buffers() + + if not self._analysis_active: + return None + + x, y = batch + x2, y2 = None, None + if self.cross_response: + x2, y2 = cross_response_batch + + with BatchStatSnapshot(self.model, x): + # Accumulate sample-wise metrics + self_metrics = self.sample_calc.compute( + self.model, self.criterion, x, y, + ) + self._accum_chi_net.append( + self_metrics["batch_grad_norms_network"] + ) + self._accum_chi_loss.append( + self_metrics["batch_grad_norms_loss"] + ) + + if x2 is not None and y2 is not None: + cross_preliminary = self.sample_calc.compute( + self.model, self.criterion, x2, y2, + ) + cross_metrics = self.sample_calc.compute_cross_metrics( + sample_wise_metrics_self=self_metrics, + sample_wise_metrics_cross=cross_preliminary, + ) + self._accum_cross_chi_net.append( + cross_metrics["batch_grad_norms_network"] + ) + self._accum_cross_chi_loss.append( + cross_metrics["batch_grad_norms_loss"] + ) + + # Accumulate linearizer gradients (train side) + self._accumulate_linearizer_grads( + x, y, is_train=True + ) + # Accumulate linearizer gradients (measure side) + # TODO: How would a measure batchsize different to the effective batch size work here? + # !!! We would need to accumulate separately and then combine at the end. + if x2 is not None and y2 is not None: + self._accumulate_linearizer_grads( + x2, y2, is_train=False + ) + + # On last micro-batch: finalize and log + is_last = ( + self._accumulation_count == self.accumulation_steps - 1 + ) + if is_last: + self._finalize_accumulated_analysis(x, x2) + + return None + + def _accumulate_linearizer_grads(self, x, y, is_train=True): + """Forward+backward on a micro-batch and add grads to accumulator.""" + self.model.zero_grad() + loss = self.criterion(self.model(x), y) + loss.backward() + + loss_val = loss.detach().item() + if is_train: + self._accum_train_loss += loss_val + if self._accum_grad_train is None: + self._accum_grad_train = [ + p.grad.clone() if p.grad is not None else None + for p in self.model.parameters() + ] + else: + for acc, p in zip( + self._accum_grad_train, self.model.parameters() + ): + if acc is not None and p.grad is not None: + acc.add_(p.grad) + else: + self._accum_measure_loss += loss_val + if self._accum_grad_measure is None: + self._accum_grad_measure = [ + p.grad.clone() if p.grad is not None else None + for p in self.model.parameters() + ] + else: + for acc, p in zip( + self._accum_grad_measure, self.model.parameters() + ): + if acc is not None and p.grad is not None: + acc.add_(p.grad) + + self.model.zero_grad() + + def _finalize_accumulated_analysis(self, x, x2): + """Combine accumulated metrics and log results.""" + K = self.accumulation_steps + B = x.shape[0] + + # Combine sample-wise metrics + chi_net_eff = sum(self._accum_chi_net) / K + chi_loss_eff = K * sum(self._accum_chi_loss) + + samples_result_self = { + "batch_grad_norms_network": chi_net_eff, + "batch_grad_norms_loss": chi_loss_eff, + } + + # Compute self linearizer result from accumulated grads + grad_norm_sq = sum( + (g ** 2).sum().item() + for g in self._accum_grad_train if g is not None + ) / (K ** 2) + + avg_train_loss = self._accum_train_loss / K + delta_loss_self = -grad_norm_sq + probe_result_self = ( + avg_train_loss, + avg_train_loss + delta_loss_self, + delta_loss_self, + ) + + chi_coup = self.coupling_calc.calculate( + delta_loss=delta_loss_self, + chi_loss=chi_loss_eff, + chi_net=chi_net_eff, + ) + + # Cross response + samples_result_cross = None + probe_result_cross = None + chi_coup_cross = None + if self._accum_grad_measure is not None: + chi_net_cross_eff = sum(self._accum_cross_chi_net) / K + chi_loss_cross_eff = K * sum(self._accum_cross_chi_loss) + samples_result_cross = { + "batch_grad_norms_network": chi_net_cross_eff, + "batch_grad_norms_loss": chi_loss_cross_eff, + } + + cross_dot = sum( + (g1 * g2).sum().item() + for g1, g2 in zip( + self._accum_grad_train, + self._accum_grad_measure, + ) + if g1 is not None and g2 is not None + ) / (K ** 2) + + avg_measure_loss = self._accum_measure_loss / K + delta_loss_cross = -cross_dot + probe_result_cross = ( + avg_measure_loss, + avg_measure_loss + delta_loss_cross, + delta_loss_cross, + ) + chi_coup_cross = self.coupling_calc.calculate( + delta_loss=delta_loss_cross, + chi_loss=chi_loss_cross_eff, + chi_net=chi_net_cross_eff, + ) + + # Log results + if self.log_metrics: + self._log_analysis_results( + prefix="", + samples_result=samples_result_self, + probe_result=probe_result_self, + chi_coup=chi_coup, + batch_size=B, + ) + if samples_result_cross is not None: + self._log_analysis_results( + prefix="cross_", + samples_result=samples_result_cross, + probe_result=probe_result_cross, + chi_coup=chi_coup_cross, + batch_size=x2.shape[0] if x2 is not None else 0, + ) + if self.analysis_schedule is not None: + window_info = self.analysis_schedule.get_window_info( + self.effective_step + ) + if window_info is not None: + self.log("window_id", window_info["window_id"]) + self.log("window_center", window_info["window_center"]) + self.log("window_width", window_info["window_width"]) + + self._clear_accumulation_buffers() + + def _clear_accumulation_buffers(self): + """Reset all accumulation buffers.""" + self._accum_chi_net.clear() + self._accum_chi_loss.clear() + self._accum_cross_chi_net.clear() + self._accum_cross_chi_loss.clear() + self._accum_grad_train = None + self._accum_grad_measure = None + self._accum_train_loss = 0.0 + self._accum_measure_loss = 0.0 + def _log_analysis_results( self, prefix: str, @@ -402,10 +715,15 @@ def _log_analysis_results( self.log(f"{prefix}chi_coup", chi_coup) self.log(f"{prefix}batch_size", batch_size) + if self.accumulation_steps > 1: + self.log( + f"{prefix}effective_batch_size", + batch_size * self.accumulation_steps, + ) # Only log analysis_step once (usually with empty prefix) if prefix == "": - self.log("analysis_step", self.global_step) + self.log("analysis_step", self.effective_step) # Log probe results (linearization) if probe_result is not None: @@ -443,5 +761,7 @@ def _after_training_step(self, batch, batch_idx, output): analyze_every=analyze_every, analysis_schedule=analysis_schedule, cross_response=cross_response, + micro_batch_size=micro_batch_size, + effective_batch_size=effective_batch_size, **model_kwargs, ) diff --git a/tests/unit/test_analyzer.py b/tests/unit/test_analyzer.py index 6046012..4dff146 100644 --- a/tests/unit/test_analyzer.py +++ b/tests/unit/test_analyzer.py @@ -717,10 +717,7 @@ def test_before_hook_skipped_when_not_scheduled( ): """Test _before_training_step skips analysis when not scheduled.""" model = analyzer(simple_lightning_module, analyze_every=10) - model._global_step = 5 # Not a multiple of 10 - - # Mock global_step property - type(model).global_step = property(lambda self: 5) + model._optimizer_step_count = 5 # Not a multiple of 10 model._before_training_step(sample_batch, 0) @@ -753,7 +750,6 @@ def test_logs_window_info_with_schedule( simple_lightning_module, analysis_schedule=schedule, log_metrics=True ) model.log = Mock() - type(model).global_step = property(lambda self: 0) model._before_training_step(sample_batch, 0) @@ -779,7 +775,6 @@ def test_no_window_info_without_schedule( model = analyzer(simple_lightning_module, log_metrics=True) model.log = Mock() - type(model).global_step = property(lambda self: 0) model._before_training_step(sample_batch, 0) @@ -787,3 +782,710 @@ def test_no_window_info_without_schedule( assert "window_id" not in logged_names assert "window_center" not in logged_names assert "window_width" not in logged_names + + +class TestGradientAccumulation: + """Test gradient accumulation functionality.""" + # The tests are categorized into sections A-J for clarity. + + # --- A. Parameter validation --- + + def test_accumulation_steps_default( + self, simple_lightning_module + ): + """No params → accumulation_steps=1.""" + model = analyzer(simple_lightning_module) + assert model.accumulation_steps == 1 + + def test_batch_size_only_no_accumulation( + self, simple_lightning_module + ): + """micro_batch_size alone → no accumulation, value stored.""" + model = analyzer( + simple_lightning_module, micro_batch_size=8 + ) + assert model.accumulation_steps == 1 + assert model.micro_batch_size == 8 + assert model.effective_batch_size is None + + def test_accumulation_steps_computed( + self, simple_lightning_module + ): + """micro=8, effective=32 → accumulation_steps=4.""" + model = analyzer( + simple_lightning_module, + micro_batch_size=8, + effective_batch_size=32, + ) + assert model.accumulation_steps == 4 + + def test_effective_without_micro_batch_raises( + self, simple_lightning_module + ): + """effective_batch_size alone → ValueError.""" + with pytest.raises( + ValueError, match="micro_batch_size must be specified" + ): + analyzer( + simple_lightning_module, + effective_batch_size=32, + ) + + def test_effective_less_than_micro_batch_raises( + self, simple_lightning_module + ): + """effective=8, micro=32 → ValueError.""" + with pytest.raises( + ValueError, match="must be >= micro_batch_size" + ): + analyzer( + simple_lightning_module, + micro_batch_size=32, + effective_batch_size=8, + ) + + def test_not_divisible_raises( + self, simple_lightning_module + ): + """effective=30, micro=8 → ValueError (not divisible).""" + with pytest.raises( + ValueError, match="must be divisible" + ): + analyzer( + simple_lightning_module, + micro_batch_size=8, + effective_batch_size=30, + ) + + def test_accumulation_with_delegate_raises( + self, manual_optimization_module + ): + """Accumulation + manual optimization module → ValueError.""" + with pytest.raises( + ValueError, + match="Gradient accumulation is not supported", + ): + with pytest.warns( + UserWarning, match="manual optimization" + ): + analyzer( + manual_optimization_module, + micro_batch_size=8, + effective_batch_size=32, + ) + + # --- B. Optimizer behavior --- + + def test_zero_grad_once_per_cycle( + self, simple_lightning_module, sample_batch + ): + """4 micro-steps, accum=4: zero_grad called exactly 1x.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=16, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + for i in range(4): + model.training_step((x, y), i) + + assert mock_opt.zero_grad.call_count == 1 + + def test_step_once_per_cycle( + self, simple_lightning_module, sample_batch + ): + """4 micro-steps, accum=4: opt.step() called exactly 1x.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=16, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + for i in range(4): + model.training_step((x, y), i) + + assert mock_opt.step.call_count == 1 + + def test_step_not_called_mid_cycle( + self, simple_lightning_module, sample_batch + ): + """3 of 4 micro-steps done: opt.step() never called.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=16, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + for i in range(3): + model.training_step((x, y), i) + + mock_opt.step.assert_not_called() + + def test_loss_scaled_for_backward( + self, simple_lightning_module, sample_batch + ): + """manual_backward receives loss / accumulation_steps.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=16, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + output = model.training_step((x, y), 0) + + backward_arg = model.manual_backward.call_args[0][0] + expected = output / 4 + assert torch.allclose(backward_arg, expected) + + def test_unscaled_loss_returned( + self, simple_lightning_module, sample_batch + ): + """training_step returns the original unscaled loss.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=16, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + output_accum = model.training_step((x, y), 0) + assert isinstance(output_accum, torch.Tensor) + + def test_two_full_cycles( + self, simple_lightning_module, sample_batch + ): + """4 steps, accum=2: zero_grad 2x, opt.step() 2x.""" + model = analyzer( + simple_lightning_module, + disable_analyzer=True, + micro_batch_size=4, + effective_batch_size=8, + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + for i in range(4): + model.training_step((x, y), i) + + assert mock_opt.zero_grad.call_count == 2 + assert mock_opt.step.call_count == 2 + + # --- C. Backwards compatibility --- + + def test_no_accumulation_backwards_compatible( + self, simple_lightning_module, sample_batch + ): + """No accum params: every call does zero_grad + step.""" + model = analyzer( + simple_lightning_module, disable_analyzer=True + ) + mock_opt = Mock(zero_grad=Mock(), step=Mock()) + model.optimizers = Mock(return_value=mock_opt) + model.manual_backward = Mock() + + x, y = sample_batch + for i in range(4): + model.training_step((x, y), i) + + assert mock_opt.zero_grad.call_count == 4 + assert mock_opt.step.call_count == 4 + + @patch.object(SamplewiseCalculatorOpacus, "compute") + @patch.object(Linearizer, "compute") + def test_single_step_path_unchanged( + self, mock_probe, mock_compute, simple_lightning_module, sample_batch + ): + """Without accumulation, _analyze_single_step produces same results.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.5), + "batch_grad_norms_loss": torch.tensor(2.5), + } + mock_probe.return_value = { + "self": (1.0, 0.0, -1.0), + "cross": None, + } + + model = analyzer(simple_lightning_module, log_metrics=True) + model.log = Mock() + x, y = sample_batch + + model._before_training_step((x, y), 0) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + assert torch.allclose(logged["chi_net"], torch.tensor(1.5)) + assert torch.allclose(logged["chi_loss"], torch.tensor(2.5)) + assert logged["loss"] == 1.0 + assert logged["grad_norm_squared"] == 1.0 + assert logged["batch_size"] == 4 + + # --- D. Effective step --- + + def test_effective_step_with_accumulation( + self, simple_lightning_module + ): + """_optimizer_step_count=2 → effective_step=2.""" + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=16, + ) + model._optimizer_step_count = 2 + assert model.effective_step == 2 + + model._optimizer_step_count = 0 + assert model.effective_step == 0 + + def test_effective_step_without_accumulation( + self, simple_lightning_module + ): + """_optimizer_step_count=42 → effective_step=42.""" + model = analyzer(simple_lightning_module) + model._optimizer_step_count = 42 + assert model.effective_step == 42 + + # --- E. Analysis scheduling with accumulation --- + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_analysis_uses_effective_step_for_scheduling( + self, mock_compute, simple_lightning_module, sample_batch + ): + """_should_analyze uses effective_step, activates on first micro-batch.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + model = analyzer( + simple_lightning_module, + analyze_every=2, + micro_batch_size=4, + effective_batch_size=16, + ) + model.log = Mock() + x, y = sample_batch + + # effective_step=0, analyze_every=2 → 0 % 2 == 0 → analyze + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + assert model._analysis_active is True + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_analysis_skipped_when_schedule_says_no( + self, mock_compute, simple_lightning_module, sample_batch + ): + """When _should_analyze returns False, no accumulation or logging happens.""" + model = analyzer( + simple_lightning_module, + analyze_every=10, + micro_batch_size=4, + effective_batch_size=8, + ) + model.log = Mock() + x, y = sample_batch + + # effective_step=1, analyze_every=10 → 1 % 10 != 0 → skip + model._optimizer_step_count = 1 + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + + assert model._analysis_active is False + mock_compute.assert_not_called() + model.log.assert_not_called() + + # Second micro-batch also skipped (flag persists) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + mock_compute.assert_not_called() + model.log.assert_not_called() + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_analysis_step_logs_effective_step( + self, mock_compute, simple_lightning_module, sample_batch + ): + """Logged analysis_step equals effective_step, not global_step.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + # _optimizer_step_count=3 → effective_step=3 + model._optimizer_step_count = 3 + + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + assert logged["analysis_step"] == 3 + + # --- F. Sample-wise metric accumulation --- + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_accumulated_chi_net_is_mean( + self, mock_compute, simple_lightning_module, sample_batch + ): + """chi_net_eff = mean of per-micro-batch chi_net values.""" + mock_compute.side_effect = [ + { + "batch_grad_norms_network": torch.tensor(2.0), + "batch_grad_norms_loss": torch.tensor(3.0), + }, + { + "batch_grad_norms_network": torch.tensor(4.0), + "batch_grad_norms_loss": torch.tensor(5.0), + }, + ] + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + + # chi_net_eff = mean([2.0, 4.0]) = 3.0 + assert torch.allclose(logged["chi_net"], torch.tensor(3.0)) + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_accumulated_chi_loss_is_k_times_sum( + self, mock_compute, simple_lightning_module, sample_batch + ): + """chi_loss_eff = K * sum of per-micro-batch chi_loss values.""" + mock_compute.side_effect = [ + { + "batch_grad_norms_network": torch.tensor(2.0), + "batch_grad_norms_loss": torch.tensor(3.0), + }, + { + "batch_grad_norms_network": torch.tensor(4.0), + "batch_grad_norms_loss": torch.tensor(5.0), + }, + ] + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + + # chi_loss_eff = K * sum([3.0, 5.0]) = 2 * 8.0 = 16.0 + assert torch.allclose(logged["chi_loss"], torch.tensor(16.0)) + + # --- G. Linearizer gradient accumulation --- + + def test_linearizer_accumulated_grad_norm( + self, simple_lightning_module, sample_batch + ): + """grad_norm_squared = ||Σ∇L_k||² / K² from accumulated grads.""" + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + # Run a full accumulation cycle (K=2) + with patch.object( + SamplewiseCalculatorOpacus, "compute", + return_value={ + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + }, + ): + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + + # grad_norm_squared should be a positive float + assert "grad_norm_squared" in logged + assert logged["grad_norm_squared"] > 0 + + # Verify manually: compute the expected value + # Do two forward+backward passes, sum grads, compute ||sum||²/K² + model.model.zero_grad() + loss0 = model.criterion(model.model(x), y) + loss0.backward() + grads_0 = [ + p.grad.clone() for p in model.model.parameters() + if p.grad is not None + ] + + model.model.zero_grad() + loss1 = model.criterion(model.model(x), y) + loss1.backward() + grads_1 = [ + p.grad.clone() for p in model.model.parameters() + if p.grad is not None + ] + model.model.zero_grad() + + expected_norm_sq = sum( + ((g0 + g1) ** 2).sum().item() + for g0, g1 in zip(grads_0, grads_1) + ) / 4 # K² = 2² = 4 + + assert abs(logged["grad_norm_squared"] - expected_norm_sq) < 1e-4 + + # --- H. Coupling with accumulated values --- + + def test_coupling_from_accumulated_values( + self, simple_lightning_module, sample_batch + ): + """coupling = grad_norm_sq / (chi_loss_eff * chi_net_eff).""" + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + with patch.object( + SamplewiseCalculatorOpacus, "compute", + return_value={ + "batch_grad_norms_network": torch.tensor(2.0), + "batch_grad_norms_loss": torch.tensor(3.0), + }, + ): + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + + # chi_net_eff = mean([2.0, 2.0]) = 2.0 + # chi_loss_eff = 2 * sum([3.0, 3.0]) = 12.0 + # coupling = grad_norm_sq / (12.0 * 2.0) + assert "chi_coup" in logged + expected_coupling = ( + logged["grad_norm_squared"] / (logged["chi_loss"] * logged["chi_net"]) + ) + assert abs(logged["chi_coup"] - expected_coupling) < 1e-5 + + # --- I. Logging behavior --- + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_logging_only_on_last_microbatch( + self, mock_compute, simple_lightning_module, sample_batch + ): + """Metrics logged once per cycle on the last micro-batch only.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + # First micro-batch — should NOT log yet + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + assert model.log.call_count == 0 + + # Second micro-batch (last) — should log + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + assert model.log.call_count > 0 + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_effective_batch_size_logged( + self, mock_compute, simple_lightning_module, sample_batch + ): + """effective_batch_size is logged when accumulation is active.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + logged = { + call[0][0]: call[0][1] + for call in model.log.call_args_list + } + + assert "effective_batch_size" in logged + # micro_batch_size=4, accumulation_steps=2 → 4*2=8 + assert logged["effective_batch_size"] == 8 + + # --- J. Buffer cleanup --- + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_buffers_cleared_after_cycle( + self, mock_compute, simple_lightning_module, sample_batch + ): + """Accumulation buffers are reset after a full cycle.""" + mock_compute.return_value = { + "batch_grad_norms_network": torch.tensor(1.0), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + # Run one full cycle + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + # Buffers should be cleared + assert len(model._accum_chi_net) == 0 + assert len(model._accum_chi_loss) == 0 + assert model._accum_grad_train is None + assert model._accum_grad_measure is None + assert model._accum_train_loss == 0.0 + + @patch.object(SamplewiseCalculatorOpacus, "compute") + def test_buffers_dont_leak_between_cycles( + self, mock_compute, simple_lightning_module, sample_batch + ): + """Second cycle doesn't contain data from the first cycle.""" + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + # Cycle 1: return 10.0, Cycle 2: return 20.0 + val = 10.0 if call_count[0] <= 2 else 20.0 + return { + "batch_grad_norms_network": torch.tensor(val), + "batch_grad_norms_loss": torch.tensor(1.0), + } + + mock_compute.side_effect = side_effect + + model = analyzer( + simple_lightning_module, + micro_batch_size=4, + effective_batch_size=8, + log_metrics=True, + ) + model.log = Mock() + x, y = sample_batch + + # Cycle 1 + model._accumulation_count = 0 + model._before_training_step((x, y), 0) + model._accumulation_count = 1 + model._before_training_step((x, y), 1) + + # Cycle 2 + model._accumulation_count = 0 + model._before_training_step((x, y), 2) + model._accumulation_count = 1 + model._before_training_step((x, y), 3) + + # Get chi_net from cycle 2 (the last logged value) + chi_net_calls = [ + call[0][1] + for call in model.log.call_args_list + if call[0][0] == "chi_net" + ] + # Cycle 1: mean([10, 10]) = 10, Cycle 2: mean([20, 20]) = 20 + assert len(chi_net_calls) == 2 + assert torch.allclose(chi_net_calls[0], torch.tensor(10.0)) + assert torch.allclose(chi_net_calls[1], torch.tensor(20.0))