
Add ARO (Adaptively Rotated Optimization) optimizer #36

Open
JohnLangford wants to merge 18 commits into microsoft:main from JohnLangford:dev/aro

Conversation

@JohnLangford
Contributor

Summary

Implements the ARO algorithm (arXiv:2602.09006), which replaces Newton-Schulz orthogonalization with an adaptive rotation policy based on QR decomposition.

Algorithm per step:

  1. Update momentum: M = β·M + (1-β)·G
  2. Rotate into previous frame: R_prev^T @ M
  3. Apply base optimizer f_t (row_norm or sign)
  4. Cross-alignment + QR → new rotation R
  5. Rotate with new R, apply f_t → update direction U
  6. Apply: W -= η·U
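
The six steps can be sketched for a single 2-D parameter as follows. This is a minimal single-GPU illustration, not the PR's implementation: the exact cross-alignment product (here M @ F^T) and the hyperparameter defaults are assumptions.

```python
import torch

def row_norm(x, eps=1e-8):
    # Base optimizer f_t: scale each row to unit L2 norm.
    return x / (x.norm(dim=1, keepdim=True) + eps)

def aro_step(W, G, M, R, lr=0.02, beta=0.95, f_t=row_norm):
    """One ARO step for a single [m, n] parameter (illustrative only)."""
    M.mul_(beta).add_(G, alpha=1 - beta)   # 1. momentum update
    F = f_t(R.mT @ M)                      # 2-3. rotate into previous frame, apply f_t
    C = M @ F.mT                           # 4. cross-alignment (assumed form: M F^T)
    Q, _ = torch.linalg.qr(C)              #    QR -> new rotation R
    R.copy_(Q)
    U = f_t(R.mT @ M)                      # 5. rotate with new R, apply f_t
    W.sub_(U, alpha=lr)                    # 6. apply update
```

Since R comes out of a QR factorization every step, it stays orthogonal by construction rather than drifting as an iteratively updated rotation would.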

Implementation:

  • Subclasses DistributedOrthoBase for distributed setup, step orchestration, and Lion/AdamW tasks
  • R stored in float32 for QR stability, adding O(m²) memory per param
  • DDP megabatch distributes QR+matmul across ranks via all-gather
  • Base optimizer functions: row_norm (default), sign
  • FSDP not supported (R requires full row dimension)

Tested:

  • Single-GPU: rotation stays orthogonal (3.6e-7 error after 3 steps)
  • DDP 2×H100: exact cross-rank agreement (0.00 diff)

Test plan

  • Single-GPU smoke test (mixed shapes, megabatch, both base_opt functions)
  • DDP cross-rank consistency (2×H100)
  • Training loss comparison vs Muon/AdamW on a real model

JohnLangford and others added 18 commits March 28, 2026 17:23

Implements the ARO algorithm from https://arxiv.org/abs/2602.09006.
ARO maintains a per-parameter rotation matrix R ∈ SO(m) that is
updated each step via QR decomposition of a cross-alignment matrix
coupling the gradient to the base optimizer's transformation.

- Subclasses DistributedOrthoBase for shared infrastructure
- R stored in float32 for QR stability, adding O(m²) memory per param
- DDP megabatch distributes QR+matmul across ranks via all-gather
- Base optimizer functions: row_norm (default) and sign
- FSDP not supported (R requires full row dimension)

Instead of reimplementing the pad/stack/assign/all-gather pattern,
ARO now plugs QR decomposition into the shared megabatch infrastructure
via yield-from, matching how NorMuon and Dion2 plug in Newton-Schulz.

Pre-compute (cross-alignment) and post-compute (rotation → update)
are done locally; only the QR orthogonalization is distributed.
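
The plug-in pattern described above can be sketched with the plain generator protocol. This is a hypothetical single-process sketch: the real megabatch infrastructure is asynchronous and distributed, and the names here are invented.

```python
import torch

def megabatch_driver(batches):
    # Shared infrastructure: walk the megabatch schedule, yielding each
    # stacked batch to the caller, which supplies the orthogonalized result.
    out = []
    for batch in batches:
        result = yield batch        # plug-in point (Newton-Schulz, QR, ...)
        out.append(result)
    return out

def run_with_qr(batches):
    # ARO plugs QR in at the same point where NorMuon and Dion2 plug in
    # Newton-Schulz (the real code does this via `yield from` in a task).
    gen = megabatch_driver(batches)
    try:
        batch = next(gen)
        while True:
            q, _ = torch.linalg.qr(batch)
            batch = gen.send(q)
    except StopIteration as stop:
        return stop.value
```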

Implements SR-Sinkhorn normalization (alternating L2 row/column
normalization), the recommended base optimizer from the ARO paper.
Stateless, 5 iterations by default.
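
The alternating normalization can be sketched as below. This is a guess at the iteration from the description above; the epsilon and the column-then-row ordering are assumptions.

```python
import torch

def sr_sinkhorn(x, iters=5, eps=1e-8):
    # Stateless base optimizer: alternate L2 column and row normalization.
    for _ in range(iters):
        x = x / (x.norm(dim=0, keepdim=True) + eps)  # unit-L2 columns
        x = x / (x.norm(dim=1, keepdim=True) + eps)  # unit-L2 rows
    return x
```

Ending on the row pass leaves every row at (approximately) unit L2 norm, matching the row_norm base optimizer's output scale.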

Move the full ARO computation (rotation, cross-alignment, QR, update
direction) into a closure passed to megabatch_orthogonalize_async.
For FSDP, the all-to-all reassembles full matrices before the closure
runs; for DDP, each rank runs the closure on its assigned chunk.
The closure captures R for the assigned params, so rotation state
stays consistent without a separate all-gather.

Delete rotated, f_rotated, and cross tensors as soon as they're no
longer needed, reducing peak memory during the QR decomposition.

cusolver's cusolverDnCreate fails with CUSOLVER_STATUS_INTERNAL_ERROR
when GPU memory is exhausted by the preceding all-to-all + float32
intermediates.
cusolver allocates workspace outside PyTorch's caching allocator,
so freed-but-cached tensor blocks aren't visible to it. Two changes:
- del M_f32 before QR (recompute it after from M_batch for phase 3)
- torch.cuda.empty_cache() before QR to release cached blocks

cusolverDnCreate fails with INTERNAL_ERROR under FSDP memory
pressure regardless of available memory. Switch to the magma backend
for the QR call, restoring the previous backend after.

Both cusolver and magma fail under FSDP memory pressure because the
all-to-all reassembly of full matrices leaves no room for linalg
workspace. CPU QR sidesteps this — the cross matrix is square [m, m]
so the transfer and computation are cheap relative to the GPU matmuls.

The paper's approach (Shifted Cholesky QR / fully distributed rotation)
avoids full reassembly entirely but requires a larger architectural
change.

The padding block guarded by N > 1 skipped setting per_rank, but
the FSDP all-to-all path still referenced it when N=1, causing a
NameError. The corrupted async task then left CUDA state dirty,
making subsequent QR calls fail with cusolver INTERNAL_ERROR.

Fix: also enter the padding block when comm_dim is not None (FSDP),
regardless of N.

Use matmul + Cholesky + triangular solve instead of torch.linalg.qr.
This matches the ARO paper's recommended implementation and dion.py's
existing orthogonalize() pattern:
  G = A^T A + shift*I
  R = cholesky(G, upper=True)
  Q = solve_triangular(R, A, upper=True, left=False)

Cholesky QR uses far less GPU workspace than Householder QR, avoiding
the cusolver/magma OOM issues under FSDP memory pressure. Falls back
to Householder QR if the Cholesky factorization fails.
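
A runnable sketch of the shifted Cholesky QR with the Householder fallback, following the pattern in the commit message (the shift value and the fallback condition are assumptions):

```python
import torch

def shifted_cholesky_qr(A, shift=1e-6):
    """Orthogonalize the columns of A via shifted Cholesky QR (sketch)."""
    n = A.shape[-1]
    # G = A^T A + shift*I, with the shift keeping G positive definite.
    G = A.mT @ A + shift * torch.eye(n, dtype=A.dtype, device=A.device)
    R, info = torch.linalg.cholesky_ex(G, upper=True)  # G = R^T R
    if int(info.max()) != 0:
        # Fall back to Householder QR if the factorization fails.
        Q, _ = torch.linalg.qr(A)
        return Q
    # Solve X @ R = A, i.e. Q = A R^{-1}.
    return torch.linalg.solve_triangular(R, A, upper=True, left=False)
```

The Gram matrix G is only [n, n], so the heavy work is the two matmuls; cholesky_ex and solve_triangular need far less workspace than a Householder factorization of A itself, which is the point of the switch.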

After forward+backward, PyTorch's caching allocator holds ~164GB
reserved but only ~2GB allocated. cusolver allocates outside the
caching allocator and only sees ~28GB free, causing cusolverDnCreate
to fail. torch.cuda.empty_cache() releases cached blocks back to
CUDA before the Cholesky/QR calls.

Also free M_f32 before the QR (recompute it after for phase 3)
to reduce peak memory during decomposition.

The earlier empty_cache before _shifted_cholesky_qr is ineffective
because G = A.mT @ A re-fills the caching allocator. The cache must
be released after G is computed but before cholesky_ex calls cusolver.

empty_cache cannot release blocks with pending operations on other
CUDA streams. The megabatch all-to-all and torch.compile ops may
run on separate streams, preventing cache release. synchronize()
ensures all pending ops complete before releasing cached blocks.
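
The sequencing above amounts to a small helper like the following (a sketch; the function name is invented, and it is a no-op without CUDA):

```python
import torch

def release_cached_blocks():
    # cusolver allocates its workspace outside PyTorch's caching allocator,
    # so cached-but-free blocks must be returned to CUDA before the QR call.
    if not torch.cuda.is_available():
        return
    # Pending ops on other streams (all-to-all, torch.compile) would block
    # cache release, so synchronize first, then empty the cache.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
```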

torch.compile on B200/CUDA 13.0 corrupts cusolver state after the
first compiled forward/backward pass, making ALL GPU linalg calls
fail (even 64x64 identity matrices). This is not memory-related —
190GB free and cusolver still broken.

Move the Cholesky decomposition to CPU. The Gram matrix is small
([batch, m, m]) so the CPU overhead is minimal. The matmuls (A^T A
and the triangular solve) stay on GPU via cuBLAS which is unaffected.

The shifted Cholesky QR moved cholesky to CPU but left
solve_triangular on GPU, which still hits the cusolver corruption.
Revert to the proven CPU QR approach that ran 569 steps successfully.

FSDP MixedPrecisionPolicy wraps the optimizer step in bf16 autocast,
causing the R_my.mT @ M_f32 matmul to fail with dtype mismatch.
Wrap aro_ortho_fn in torch.autocast("cuda", enabled=False) and
explicitly cast R_my to float32.
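
The fix boils down to disabling autocast around the rotation matmul and forcing float32. A sketch (the device_type is chosen dynamically here so the snippet also runs on CPU; in the FSDP setting it is "cuda"):

```python
import torch

def rotate_fp32(R, M):
    device_type = "cuda" if R.is_cuda else "cpu"  # "cuda" in the real setup
    # Disable the bf16 autocast installed by FSDP's MixedPrecisionPolicy,
    # and cast R explicitly so the matmul runs entirely in float32.
    with torch.autocast(device_type, enabled=False):
        return R.to(torch.float32).mT @ M.to(torch.float32)
```

Without the explicit cast, a bf16 R against a float32 M raises a dtype mismatch once autocast (which would otherwise insert the casts) is disabled.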
