Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations #422

Open
Perseus14 wants to merge 1 commit into main from ltx2-improvements

@Perseus14

Collaborator

Description

This PR is a refactor and optimization sweep across the LTX2 pipeline. It targets XLA compilation time, memory usage (HBM), and architectural hygiene by stripping out redundant compute, unifying duplicated logic, and optimizing JAX tracing.

🧹 Architectural Hygiene & Code Quality

  • Unified Transformer Block Application: Solved the "quadruple code duplication" hazard in LTX2VideoTransformer3DModel.__call__. The 4 separate block execution paths (scan vs. loop, perturbation vs. no-perturbation) have been consolidated using a new TransformerContext container and a single apply_block helper function.
  • Removed Dead Code in prepare_video_coords: Deleted the wasteful 5D latent_coords block in attention_ltx2.py that was computing an unused, wrongly-shaped tensor only to immediately overwrite it.
  • Simplified apply_split_rotary_emb: Cleaned up the convoluted reshape/broadcast logic for split RoPE. Removed the redundant expand_dims and squeeze operations, executing the rotation directly (first_x * cos - second_x * sin) to avoid allocating unnecessary intermediate 5D tensors.
  • Removed Redundant Guards: Dropped the unnecessary hasattr(self, "rope_type") check in LTX2Attention.
  • PRNG Fallback Warning: Added a previously missing max_logging.log warning that fires when noise generation falls back to the default zero-seed key, jax.random.key(0).
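
For reviewers, the simplified split-RoPE rotation described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the function name `apply_split_rotary_emb_sketch`, the argument shapes, and the formula for the second half (the standard rotation counterpart of `first_x * cos - second_x * sin`) are assumptions.

```python
import jax.numpy as jnp


def apply_split_rotary_emb_sketch(x, cos, sin):
    """Rotate the two halves of the last axis of x (hypothetical helper).

    Assumed shapes:
      x:   (..., dim) with even dim
      cos: (..., dim // 2), broadcastable against each half of x
      sin: (..., dim // 2)
    """
    # Split the feature axis into its two halves; no extra 5D reshape needed.
    first_x, second_x = jnp.split(x, 2, axis=-1)
    # First half: the expression quoted in the PR description.
    out_first = first_x * cos - second_x * sin
    # Second half: assumed to be the matching 2D rotation component.
    out_second = second_x * cos + first_x * sin
    return jnp.concatenate([out_first, out_second], axis=-1)
```

Because each pair `(first_x[i], second_x[i])` is rotated by a plain 2D rotation, the per-pair norm is preserved, which makes for a cheap sanity check.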

⚡ XLA & JAX Compilation Optimizations

  • Dynamic Guidance Scales: Removed continuous hyperparameter floats (guidance_scale, stg_scale, audio_guidance_scale, etc.) from static_argnames in run_diffusion_loop(). Tweaking these generation scales will no longer trigger expensive 10-30 minute JAX recompilations!
  • Dynamic JAX Control Flow: Replaced the static Python if guidance_rescale > 0: check inside the compiled diffusion loop with jax.lax.cond. This enables the CFG rescaling logic to be fully dynamic, complementing the removal of the static scales and fixing formulation inconsistencies.
  • Standardized Scan Loop: Replaced nnx.scan with standard jax.lax.scan for the primary denoising timestep loop to ensure predictable compilation.
  • Fixed RuntimeProgramInputMismatch for scan_layers=False: Resolved an issue where XLA would fail during warmup compilation due to unrolled layer layout mismatches. Added explicit @jax.jit wrappers with jax.lax.with_sharding_constraint to enforce layout transpositions before crossing into run_diffusion_loop.
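
The interaction between the dynamic guidance scales and `jax.lax.cond` can be sketched as below. This is a hedged illustration only: `guided_noise`, `_rescale`, `_identity`, and `TRACE_COUNT` are hypothetical names, and the std-ratio rescale is one common form of CFG rescaling, assumed here rather than taken from this PR. The point is that the scales are ordinary traced arguments (not listed in `static_argnames`), so the branch on `guidance_rescale > 0` has to be expressed with `jax.lax.cond` instead of a Python `if`.

```python
import jax
import jax.numpy as jnp

# Counts how often the function body is traced (Python side effects only run
# during tracing, not when a cached executable is reused).
TRACE_COUNT = {"n": 0}


def _rescale(args):
    noise_cfg, noise_text, guidance_rescale = args
    # Assumed formulation: match the std of the text-conditioned prediction,
    # then blend with the unrescaled result.
    rescaled = noise_cfg * (jnp.std(noise_text) / jnp.std(noise_cfg))
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg


def _identity(args):
    noise_cfg, _, _ = args
    return noise_cfg


@jax.jit  # no static_argnames: both scales are traced values
def guided_noise(noise_uncond, noise_text, guidance_scale, guidance_rescale):
    TRACE_COUNT["n"] += 1
    noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # A Python `if guidance_rescale > 0:` would fail on a traced value;
    # lax.cond selects the branch at run time instead.
    return jax.lax.cond(
        guidance_rescale > 0.0,
        _rescale,
        _identity,
        (noise_cfg, noise_text, guidance_rescale),
    )
```

Calling `guided_noise` twice with different Python float scales but identically shaped arrays should reuse the same compiled program, so `TRACE_COUNT["n"]` stays at 1. With the scales in `static_argnames`, each new value would force a retrace and recompile.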

🧠 Memory (HBM) Optimizations

  • Direct Dtype Typecasting: Streamlined text encoder state extraction by evaluating target_dtype upfront and mapping it directly, avoiding a redundant double-casting pipeline that was passing through bfloat16.
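
As a rough sketch of the direct-cast idea, assuming the extracted state is a pytree of arrays: resolve the target dtype once and cast each floating-point leaf in a single step. The helper name `cast_state` and the state layout are hypothetical and not taken from this PR.

```python
import jax
import jax.numpy as jnp


def cast_state(state, target_dtype):
    """Cast every floating-point leaf straight to target_dtype (hypothetical helper)."""
    target_dtype = jnp.dtype(target_dtype)  # resolve the dtype once, up front

    def _cast(leaf):
        # Leave integer leaves (e.g. token ids) untouched.
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(target_dtype)
        return leaf

    return jax.tree_util.tree_map(_cast, state)


# Usage: one cast per leaf, with no intermediate bfloat16 copy.
state = {"w": jnp.array([1.001], dtype=jnp.float32), "ids": jnp.arange(3)}
direct = cast_state(state, jnp.float16)
```

Besides allocating an extra intermediate buffer, a round trip through bfloat16 can also lose mantissa bits that the final dtype would otherwise have kept, so the direct cast is not only cheaper but may differ numerically from the old path.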

@Perseus14 Perseus14 requested a review from entrpn as a code owner June 19, 2026 21:51
@Perseus14 Perseus14 self-assigned this Jun 19, 2026
@Perseus14 Perseus14 requested a review from prishajain1 June 19, 2026 21:51
@Perseus14 Perseus14 changed the title from "LTX2.3 improvements and bug fixes" to "Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations" Jun 19, 2026