
Conversation

@cuichenx
Contributor

@cuichenx cuichenx commented Jan 31, 2026

Description

Problem

Using Float8BlockQuantizer with sequence parallel fails with `AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer` when the local tensor dimensions are not divisible by 128 (for example, a sequence-parallel shard of 125 rows).

Solution

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path
  • Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather (a sketch of this flow follows this list)
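
A minimal sketch of the fallback flow referenced above, assuming PyTorch's torch.distributed.all_gather_into_tensor for the high-precision gather; the function name gather_fallback_sketch and the duck-typed dequantize check are illustrative stand-ins, not the actual _start_all_gather_fp8_blockwise code.

```python
# Illustrative sketch only -- not the Transformer Engine implementation.
# Fallback when the local shard is not quantizable (e.g., first dim not
# divisible by 128): dequantize any already-quantized input, all-gather in
# high precision, then quantize the gathered result.
import torch
import torch.distributed as dist


def gather_fallback_sketch(inp, quantizer, process_group):
    # In TE the check is an isinstance test against Float8BlockwiseQTensorStorage
    # (see the sequence diagram in the review below); a duck-typed check stands
    # in for it here.
    if hasattr(inp, "dequantize"):
        inp = inp.dequantize()  # back to a high-precision tensor

    world_size = dist.get_world_size(process_group)
    # Allocate the gathered buffer using the dtype of the (possibly
    # dequantized) tensor, not the dtype guessed before dequantization.
    out = torch.empty(
        (inp.shape[0] * world_size, *inp.shape[1:]),
        dtype=inp.dtype,
        device=inp.device,
    )
    dist.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

    # Quantize the full gathered tensor with the blockwise quantizer.
    return quantizer(out)
```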

Note

The fallback path (high-precision all-gather → quantize) may increase communication overhead, since the gather moves high-precision values (e.g., 2 bytes per bf16 element) instead of 1-byte FP8 data plus block scales.

Verification

The code change does not alter convergence behavior.

When sequence parallel is enabled, the previous code path did not run at all (it hit the assertion described above); when sequence parallel is disabled, this change does not affect anything.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • transformer_engine/pytorch/utils.py: skip the assert_dim_for_all_gather check for Float8BlockQuantizer
  • transformer_engine/pytorch/distributed.py: dequantize already-quantized inputs in the _start_all_gather_fp8_blockwise fallback before the high-precision all-gather

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Chen Cui <chcui@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Jan 31, 2026

Greptile Overview

Greptile Summary

Fixed Float8BlockQuantizer usage with sequence parallel when local tensor dimensions aren't divisible by 128.

Key Changes:

  • Skipped the assert_dim_for_all_gather check for Float8BlockQuantizer since gather_along_first_dim has a fallback path (a sketch of this check follows this list)
  • Enhanced the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before high-precision all-gather
  • Corrected dtype usage to use the actual tensor dtype after dequantization rather than the initial guess
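
A rough sketch of what the early return mentioned in the first bullet could look like; this is not the actual transformer_engine/pytorch/utils.py code, and the function signature, class-name comparison, and hard-coded modulo-128 check are assumptions standing in for the real quantizability test.

```python
# Illustrative sketch only -- not the actual utils.py code.
def assert_dim_for_all_gather_sketch(tensor, quantizer):
    # Float8BlockQuantizer is exempt: gather_along_first_dim falls back to a
    # high-precision all-gather followed by quantization when the local shard
    # is not quantizable, so no assertion is needed here.
    if type(quantizer).__name__ == "Float8BlockQuantizer":
        return
    # Stand-in for the real per-quantizer quantizability check.
    assert tensor.shape[0] % 128 == 0, (
        "All-gather requires quantizable tensor for quantizer "
        f"{type(quantizer).__name__}"
    )
```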

Impact:
The changes enable Float8BlockQuantizer to work with sequence parallel in cases where tensor dimensions aren't quantizable (not divisible by 128), though this uses a fallback path with higher communication overhead (high-precision all-gather followed by quantization).

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - it fixes a specific crash scenario and adds proper fallback handling
  • The changes are focused and address a specific issue with sequence parallel when dimensions aren't divisible by 128. The logic correctly handles already-quantized inputs in the fallback path, and the early return for Float8BlockQuantizer in the assertion function is well-justified by the existing fallback mechanism
  • No files require special attention - both changes are straightforward fixes

Important Files Changed

  • transformer_engine/pytorch/distributed.py: Fixed fallback path in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before high-precision all-gather
  • transformer_engine/pytorch/utils.py: Added early return in assert_dim_for_all_gather for Float8BlockQuantizer to skip the assertion check, since a fallback path exists

Sequence Diagram

sequenceDiagram
    participant Caller
    participant gather_along_first_dim
    participant assert_dim_for_all_gather
    participant _start_all_gather_fp8_blockwise
    participant Float8BlockwiseQTensorStorage
    participant all_gather_into_tensor

    Caller->>gather_along_first_dim: inp (tensor), quantizer (Float8BlockQuantizer)
    gather_along_first_dim->>_start_all_gather_fp8_blockwise: inp, quantizer
    
    alt Input not quantizable (dims not divisible by 128) OR block_scaling_dim != 1
        _start_all_gather_fp8_blockwise->>_start_all_gather_fp8_blockwise: Check if inp is Float8BlockwiseQTensorStorage
        alt Input is already quantized
            _start_all_gather_fp8_blockwise->>Float8BlockwiseQTensorStorage: dequantize()
            Float8BlockwiseQTensorStorage-->>_start_all_gather_fp8_blockwise: high-precision tensor (float32)
        end
        _start_all_gather_fp8_blockwise->>all_gather_into_tensor: Gather in high precision
        all_gather_into_tensor-->>_start_all_gather_fp8_blockwise: gathered tensor
        _start_all_gather_fp8_blockwise->>_start_all_gather_fp8_blockwise: quantizer(out)
        _start_all_gather_fp8_blockwise-->>gather_along_first_dim: quantized result
    else Input is quantizable
        _start_all_gather_fp8_blockwise->>_start_all_gather_fp8_blockwise: Quantize if needed
        _start_all_gather_fp8_blockwise->>all_gather_into_tensor: Gather FP8 data
        all_gather_into_tensor-->>_start_all_gather_fp8_blockwise: gathered FP8 tensor
        _start_all_gather_fp8_blockwise-->>gather_along_first_dim: FP8 result
    end
    
    gather_along_first_dim-->>Caller: result
    
    Note over assert_dim_for_all_gather: Skip assertion for Float8BlockQuantizer<br/>since fallback path handles non-quantizable tensors

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments


