Add ARO (Adaptively Rotated Optimization) optimizer#36
JohnLangford wants to merge 18 commits into microsoft:main from
Conversation
Implements the ARO algorithm from https://arxiv.org/abs/2602.09006. ARO maintains a per-parameter rotation matrix R ∈ SO(m) that is updated each step via QR decomposition of a cross-alignment matrix coupling the gradient to the base optimizer's transformation.

- Subclasses DistributedOrthoBase for shared infrastructure
- R stored in float32 for QR stability, adding O(m²) memory per param
- DDP megabatch distributes QR+matmul across ranks via all-gather
- Base optimizer functions: row_norm (default) and sign
- FSDP not supported (R requires full row dimension)
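A loose sketch of the per-parameter flow for a single [m, n] gradient. The intermediate names (`rotated`, `f_rotated`, `cross`) follow the commit messages below, but the exact coupling and learning-rate handling are assumptions, not the PR's code:

```python
import torch

def row_norm(X, eps=1e-8):
    # simple base transform: L2-normalize each row
    return X / (X.norm(dim=1, keepdim=True) + eps)

def aro_step(R, grad, base_fn=row_norm, lr=0.02):
    """One ARO-style step (illustrative sketch, not the PR's exact code)."""
    rotated = R.mT @ grad                      # bring gradient into R's frame
    f_rotated = base_fn(rotated)               # apply the base optimizer transform
    cross = grad @ f_rotated.mT                # cross-alignment matrix, [m, m]
    R_new, _ = torch.linalg.qr(cross.float())  # re-orthogonalize the rotation
    update = R_new @ f_rotated                 # update direction in the original frame
    return R_new, -lr * update
```

The QR of the square [m, m] cross matrix is the only step that needs distribution; everything else is elementwise or a plain matmul, which is why the megabatch plumbing below only ships the orthogonalization.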
Instead of reimplementing the pad/stack/assign/all-gather pattern, ARO now plugs QR decomposition into the shared megabatch infrastructure via yield-from, matching how NorMuon and Dion2 plug in Newton-Schulz. Pre-compute (cross-alignment) and post-compute (rotation → update) are done locally; only the QR orthogonalization is distributed.
Implements SR-Sinkhorn normalization (alternating L2 row/column normalization), the recommended base optimizer from the ARO paper. Stateless, 5 iterations by default.
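Stateless alternating normalization is short enough to show in full; this is a sketch of the transform described above (function and argument names are illustrative):

```python
import torch

def sr_sinkhorn(G, iters=5, eps=1e-8):
    """Alternating L2 row/column normalization, 5 iterations by default."""
    X = G
    for _ in range(iters):
        X = X / (X.norm(dim=1, keepdim=True) + eps)  # L2-normalize rows
        X = X / (X.norm(dim=0, keepdim=True) + eps)  # L2-normalize columns
    return X
```

Because the transform carries no state, it slots into the base-optimizer hook the same way `row_norm` and `sign` do.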
Move the full ARO computation (rotation, cross-alignment, QR, update direction) into a closure passed to megabatch_orthogonalize_async. For FSDP, the all-to-all reassembles full matrices before the closure runs; for DDP, each rank runs the closure on its assigned chunk. The closure captures R for the assigned params, so rotation state stays consistent without a separate all-gather.
Delete rotated, f_rotated, and cross tensors as soon as they're no longer needed, reducing peak memory during the QR decomposition. cusolver's cusolverDnCreate fails with CUSOLVER_STATUS_INTERNAL_ERROR when GPU memory is exhausted by the preceding all-to-all + float32 intermediates.
cusolver allocates workspace outside PyTorch's caching allocator, so freed-but-cached tensor blocks aren't visible to it. Two changes:

- del M_f32 before QR (recompute it after from M_batch for phase 3)
- torch.cuda.empty_cache() before QR to release cached blocks
cusolverDnCreate fails with INTERNAL_ERROR under FSDP memory pressure regardless of available memory. Switch to magma backend for the QR call, restoring the previous backend after.
Both cusolver and magma fail under FSDP memory pressure because the all-to-all reassembly of full matrices leaves no room for linalg workspace. CPU QR sidesteps this — the cross matrix is square [m, m] so the transfer and computation are cheap relative to the GPU matmuls. The paper's approach (Shifted Cholesky QR / fully distributed rotation) avoids full reassembly entirely but requires a larger architectural change.
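The CPU round-trip described above is a few lines; this is a sketch (the helper name is hypothetical, and the real code sits inside the megabatch closure):

```python
import torch

def cpu_qr(cross):
    """QR on CPU via LAPACK, so no GPU linalg workspace is needed.
    The [m, m] cross matrix is small, so the transfer cost is minor."""
    Q, _ = torch.linalg.qr(cross.float().cpu())
    return Q.to(cross.device, non_blocking=True)
```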
The padding block guarded by N > 1 skipped setting per_rank, but the FSDP all-to-all path still referenced it when N=1, causing a NameError. The corrupted async task then left CUDA state dirty, making subsequent QR calls fail with cusolver INTERNAL_ERROR. Fix: also enter the padding block when comm_dim is not None (FSDP), regardless of N.
Use matmul + Cholesky + triangular solve instead of torch.linalg.qr. This matches the ARO paper's recommended implementation and dion.py's existing orthogonalize() pattern:

    G = A^T A + shift*I
    R = cholesky(G, upper=True)
    Q = solve_triangular(R, A, upper=True, left=False)

Cholesky QR uses far less GPU workspace than Householder QR, avoiding the cusolver/magma OOM issues under FSDP memory pressure. Falls back to Householder QR if the Cholesky factorization fails.
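The shifted Cholesky QR pattern described above, as a runnable sketch (the shift value and fallback handling are illustrative, not the PR's exact choices):

```python
import torch

def shifted_cholesky_qr(A, shift=1e-6):
    """Q = A @ R^{-1} where R^T R = A^T A + shift*I.
    Falls back to Householder QR if the factorization fails."""
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    G = A.mT @ A + shift * eye                       # Gram matrix, [m, m]
    R, info = torch.linalg.cholesky_ex(G, upper=True)
    if info.any():                                   # not positive definite
        Q, _ = torch.linalg.qr(A)                    # fallback: Householder QR
        return Q
    # solve X @ R = A for X, i.e. Q = A @ R^{-1}
    return torch.linalg.solve_triangular(R, A, upper=True, left=False)
```

The two matmuls dominate and run through cuBLAS; only the small [m, m] Cholesky touches cusolver, which is what shrinks the workspace footprint.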
After forward+backward, PyTorch's caching allocator holds ~164GB reserved but only ~2GB allocated. cusolver allocates outside the caching allocator and only sees ~28GB free, causing cusolverDnCreate to fail. torch.cuda.empty_cache() releases cached blocks back to CUDA before the Cholesky/QR calls. Also free M_f32 before the QR (recompute it after for phase 3) to reduce peak memory during decomposition.
The earlier empty_cache before _shifted_cholesky_qr is ineffective because G = A.mT @ A re-fills the caching allocator. The cache must be released after G is computed but before cholesky_ex calls cusolver.
empty_cache cannot release blocks with pending operations on other CUDA streams. The megabatch all-to-all and torch.compile ops may run on separate streams, preventing cache release. synchronize() ensures all pending ops complete before releasing cached blocks.
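A minimal sketch of the ordering the last three commits converge on (hypothetical helper name; the real code sits inside the megabatch path, after G is computed):

```python
import torch

def release_and_cholesky(G):
    """Drain all streams, then release cached blocks, then call the
    linalg routine that allocates workspace outside the caching allocator."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # pending ops on other streams pin cached blocks
        torch.cuda.empty_cache()   # return freed-but-cached blocks to the driver
    return torch.linalg.cholesky_ex(G, upper=True)
```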
torch.compile on B200/CUDA 13.0 corrupts cusolver state after the first compiled forward/backward pass, making ALL GPU linalg calls fail (even 64x64 identity matrices). This is not memory-related — 190GB free and cusolver still broken. Move the Cholesky decomposition to CPU. The Gram matrix is small ([batch, m, m]) so the CPU overhead is minimal. The matmuls (A^T A and the triangular solve) stay on GPU via cuBLAS which is unaffected.
The shifted Cholesky QR moved cholesky to CPU but left solve_triangular on GPU, which still hits the cusolver corruption. Revert to the proven CPU QR approach that ran 569 steps successfully.
FSDP MixedPrecisionPolicy wraps the optimizer step in bf16 autocast, causing the R_my.mT @ M_f32 matmul to fail with a dtype mismatch. Wrap aro_ortho_fn in torch.autocast("cuda", enabled=False) and explicitly cast R_my to float32.
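The fix described above, sketched in isolation (the names follow the commit message; the surrounding closure is assumed):

```python
import torch

def aro_ortho_fn(R_my, M_f32):
    """Run the rotation matmul outside bf16 autocast."""
    with torch.autocast("cuda", enabled=False):
        R32 = R_my.float()      # explicit cast: R may arrive as bf16 under FSDP
        return R32.mT @ M_f32   # float32 matmul, immune to the autocast policy
```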
Summary
Implements the ARO algorithm (arxiv:2602.09006), which replaces Newton-Schulz orthogonalization with an adaptive rotation policy based on QR decomposition.
Algorithm per step:
Implementation:
- DistributedOrthoBase for distributed setup, step orchestration, Lion/AdamW tasks
- row_norm (default), sign

Tested:
Test plan