fix: SVD convergence crash via CPU LAPACK fallback #393
Code Review
This pull request introduces a fallback mechanism for low-rank SVD computation. If torch.svd_lowrank fails to converge due to an ill-conditioned matrix, the code now catches torch.linalg.LinAlgError and falls back to exact SVD on the CPU. The reviewer suggested using the existing LA alias for torch.linalg to maintain consistency and adding .detach() before moving the tensor to the CPU to prevent potential autograd issues.
except torch.linalg.LinAlgError:
    # SVD failed to converge (usually because Optuna tested an extreme
    # hyperparameter combination, resulting in an ill-conditioned matrix).
    # Fall back to CPU LAPACK exact SVD.
    U, S, Vh = torch.linalg.svd(W.cpu(), full_matrices=False)
    U = U[:, :r].to(W.device)
    S = S[:r].to(W.device)
    Vh = Vh[:r, :].to(W.device)
For better consistency and defensive programming, consider making the following improvements:
- Consistency: The file already imports torch.linalg as LA (on line 11) and uses it elsewhere (e.g., LA.vector_norm on line 562). Using LA.LinAlgError and LA.svd keeps the codebase consistent.
- Defensive Programming: If W happens to have requires_grad=True (for example, if the base model is not fully frozen in a custom setup), calling torch.linalg.svd on CPU with a gradient-tracking tensor will raise a RuntimeError because SVD backward is not supported on CPU in PyTorch. Calling .detach() before moving to CPU avoids this issue entirely and prevents unnecessary autograd overhead.
except LA.LinAlgError:
    # SVD failed to converge (usually because Optuna tested an extreme
    # hyperparameter combination, resulting in an ill-conditioned matrix).
    # Fall back to CPU LAPACK exact SVD.
    U, S, Vh = LA.svd(W.detach().cpu(), full_matrices=False)
    U = U[:, :r].to(W.device)
    S = S[:r].to(W.device)
    Vh = Vh[:r, :].to(W.device)
c91e445 to dd66d77
Are you sure? There is absolutely nothing "extreme" about the parameter values from #345, and I fail to see how any combination of values from our ranges could ever produce an ill-conditioned matrix.
dd66d77 to 1e191ec
You're right: the parameters aren't extreme at all, and my description was wrong. The actual root cause is the cuSOLVER driver's iterative SVD algorithm failing to converge on certain matrices, which the traceback itself confirms: "During SVD computation with the selected cusolver driver, batches 0 failed to converge." This is independent of the hyperparameter values. The CPU LAPACK implementation handles the same matrix without issue. I've corrected the code comment accordingly.
What are those "certain matrices", and how can this be reproduced on a smaller model?
The failure isn't in
If the problem is with CUDA rather than with
Updated to use svd_lowrank on CPU in the fallback instead of switching to LA.svd |
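A rough sketch of what retrying svd_lowrank on the CPU could look like follows. The function name svd_lowrank_with_cpu_retry, the niter default, and the injectable lowrank parameter are assumptions made for illustration, and the thread does not show the final code. Unlike LA.svd, torch.svd_lowrank returns V rather than Vh and already yields q columns, so no slicing to rank r is needed.

```python
import torch
import torch.linalg as LA


def svd_lowrank_with_cpu_retry(W, r, niter=2, lowrank=torch.svd_lowrank):
    """Return (U, S, V) with W approximately equal to U @ diag(S) @ V.T.

    `lowrank` is a hypothetical hook used only to simulate a failure of
    the first attempt; it is not taken from the PR.
    """
    try:
        # First attempt on whatever device W lives on.
        return lowrank(W, q=r, niter=niter)
    except LA.LinAlgError:
        # Retry the same randomized algorithm on the CPU, then move the
        # factors back to the original device.
        U, S, V = torch.svd_lowrank(W.detach().cpu(), q=r, niter=niter)
        return U.to(W.device), S.to(W.device), V.to(W.device)
```

This keeps the result of the fallback statistically comparable to the primary path, since both use the same randomized method, at the cost of a host round trip when the first attempt fails.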
1e191ec to 445134c
Fixes #345.
torch.svd_lowrank internally calls torch.linalg.svd on the GPU, where the cuSOLVER driver's iterative algorithm can fail to converge on certain matrices, crashing the entire study with LinAlgError. This wraps the call in a try...except block and falls back to exact CPU LAPACK SVD via LA.svd, which handles the same decomposition without issue.