vllm: [RFC]: Expose per-parameter sharding as DTensor metadata (1 comment, 1 participant)

vllm-project/vllm#40429 · Fetched 2026-04-22 07:45:41

Motivation.

A growing class of deployments treats vLLM as an inference backend that receives weight updates from an external process — RL rollouts (GRPO/PPO), online fine-tuning, reward-model swap-in, disaggregated prefill/decode refresh. The sender must move each global tensor from its training-side layout (FSDP / HSDP / 3D parallelism / FP8) to vLLM's inference-side layout (TP + optional EP + PP + attention-DP).

To do this without duplicating data on the wire, the sender must know, for every vLLM parameter, which slice of the global tensor lives on which vLLM rank — i.e. a sharding spec. vLLM has this knowledge; the parallel primitives (ColumnParallelLinear, QKVParallelLinear, FusedMoE, ...) embed it in their constructors and weight_loader methods. But there is no queryable API. Consumers are forced into one of:

  1. Send full tensors and slice on vLLM side — wastes ~TP× bandwidth, dominates weight-sync time at scale.
  2. Hard-code vLLM's sharding rules externally — a growing isinstance(module, X) → Placement table, brittle across versions, silently broken when a new primitive lands, duplicated across every downstream (RL framework, checkpoint tool, debugger).
  3. Reverse-engineer weight_loader — not feasible without running it or reading source.

Multiple projects today carry their own private mirror of these rules. This RFC eliminates the duplication by exposing the knowledge as DTensor metadata. torch.distributed.tensor already provides the exact vocabulary we need — DeviceMesh for topology, Placement (Shard(dim), Replicate(), Partial(op)) for per-dim semantics — as stable PyTorch API that interoperates with DCP, FSDP2, distribute_tensor. Using it as the exchange format means senders compute the right local slice with no vLLM-specific logic, and new parallel primitives require no sender-side change.

Use cases served: zero-duplicate weight transfer, distributed checkpointing (save/resume at a different TP/EP size), model-parallel validation, cross-framework state bridges. Out of scope (handled elsewhere): HF↔vLLM name translation (WeightsMapper), logical decomposition of fused params (packed_modules_mapping), training-framework-specific weight conversion, KV cache sharding.

Proposed Change.

A new module vllm.distributed.sharding with four symbols:

from dataclasses import dataclass

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Placement

@dataclass(frozen=True)
class ShardingSpec:
    mesh: str                            # logical name: "tp" / "ep" / "ep_tp" / ...
    placements: tuple[Placement, ...]    # one per mesh dim

def attach(param, mesh: str, placements: tuple[Placement, ...]) -> None:
    """Called by a parallel primitive in its __init__ to declare sharding."""

def get(param) -> ShardingSpec | None:
    """Read the spec. None = replicated / unsharded."""

def get_device_meshes(parallel_config) -> dict[str, DeviceMesh]:
    """Build DeviceMeshes referenced by logical name."""

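One way attach and get could work is to store the spec as an attribute on the parameter object. The sketch below is an assumption, not the RFC's reference implementation: it uses the `_vllm_sharding` attribute name (one of the candidates listed under Open questions), and `Shard` and `FakeParam` are pure-Python stand-ins for `torch.distributed.tensor.Shard` and `torch.nn.Parameter` so that the example runs without torch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard (assumption for this sketch)."""
    dim: int


@dataclass(frozen=True)
class ShardingSpec:
    mesh: str                      # logical mesh name, e.g. "tp"
    placements: tuple[Any, ...]    # one placement per mesh dimension


_ATTR = "_vllm_sharding"  # hypothetical attribute name, still an open question


def attach(param: Any, mesh: str, placements: tuple[Any, ...]) -> None:
    """Record the sharding spec on the parameter object itself."""
    setattr(param, _ATTR, ShardingSpec(mesh, tuple(placements)))


def get(param: Any) -> ShardingSpec | None:
    """Return the attached spec, or None for replicated / unsharded params."""
    return getattr(param, _ATTR, None)


class FakeParam:
    """Stand-in for torch.nn.Parameter; any object that accepts attributes works."""


weight, norm = FakeParam(), FakeParam()
attach(weight, "tp", (Shard(0),))   # what a column-parallel layer would declare
# norm is left untouched, so get(norm) reports it as replicated
```

Because the spec travels with the parameter, a consumer only needs `named_parameters()` and `get`; no module-type inspection is involved.
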
Consumer pattern:

meshes = sharding.get_device_meshes(parallel_config)
for name, p in model.named_parameters():
    spec = sharding.get(p)
    if spec is None:
        continue
    dt = DTensor.from_local(p, meshes[spec.mesh], spec.placements)

Each parallel primitive adds one attach call in its __init__:

| Primitive | Attach call |
|---|---|
| ColumnParallelLinear | attach(self.weight, "tp", (Shard(0),)) |
| RowParallelLinear | attach(self.weight, "tp", (Shard(1),)) |
| QKVParallelLinear, MergedColumnParallelLinear | attach(self.weight, "tp", (Shard(0),)) |
| VocabParallelEmbedding, ParallelLMHead | attach(self.weight, "tp", (Shard(0),)) |
| ReplicatedLinear, RMSNorm, LayerNorm | omit; None means replicated |
| FusedMoE (EP-only) | attach(self.w13_weight, "ep", (Shard(0),)) |
| FusedMoE (TP+EP) | attach(self.w13_weight, "ep_tp", (Shard(0), Shard(1))) |
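
To make the placements above concrete, the following sketch computes the per-rank local shape that a given spec implies. It is illustrative only: `Shard` and `Replicate` are pure-Python stand-ins for the torch placement types, `local_shape` is a hypothetical helper, the tensor shapes are made up, the "ep_tp" mesh is assumed to be ordered (ep, tp), and only evenly divisible dimensions are handled (real DTensor also supports uneven shards).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""


def local_shape(global_shape, mesh_shape, placements):
    """Shape of the shard held by one rank, assuming even division."""
    if len(mesh_shape) != len(placements):
        raise ValueError("need exactly one placement per mesh dimension")
    shape = list(global_shape)
    for mesh_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            if shape[placement.dim] % mesh_size != 0:
                raise ValueError("this sketch only handles even sharding")
            shape[placement.dim] //= mesh_size
        # Replicate leaves the shape unchanged
    return tuple(shape)


# Column-parallel weight, TP=4, sharded on dim 0
col = local_shape((4096, 1024), (4,), (Shard(0),))
# Row-parallel weight, TP=4, sharded on dim 1
row = local_shape((4096, 1024), (4,), (Shard(1),))
# MoE weight on a 2x2 (ep, tp) mesh: experts split over ep, dim 1 over tp
moe = local_shape((8, 512, 64), (2, 2), (Shard(0), Shard(1)))
# A replicated tensor keeps its global shape
rep = local_shape((1024,), (4,), (Replicate(),))
```
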

Worked example — zero-duplicate weight send

Sender (external trainer):

spec_by_name, meshes = fetch_from_vllm()
for name, global_tensor in trainer_state_dict.items():
    spec = spec_by_name[name]
    dt = distribute_tensor(global_tensor, meshes[spec.mesh], spec.placements)
    send(dt.to_local(), to_rank=mesh_rank_of(dt))      # exact shape, no waste

Receiver (vLLM worker):

for name, p in model.named_parameters():
    spec = sharding.get(p)
    if spec is None: continue
    p.data.copy_(recv(from_rank=sender_for(name)))     # no slicing

No tensor is duplicated on the wire. No vLLM-internal module type is inspected by the sender. A new parallel primitive added to vLLM requires no sender-side change.
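
The slicing that the sender performs can be illustrated without torch. The sketch below is a simplified stand-in for what `distribute_tensor` followed by `to_local()` yields on each rank for a one-dimensional "tp" mesh: it uses nested lists instead of tensors, assumes even division, and all helper names are hypothetical.

```python
def shard_bounds(dim_size, num_shards, index):
    """Half-open [start, stop) range owned by shard `index`, even split only."""
    if dim_size % num_shards != 0:
        raise ValueError("this sketch only handles even sharding")
    step = dim_size // num_shards
    return index * step, (index + 1) * step


def row_shard(matrix, num_shards, index):
    """Equivalent of Shard(0): keep a contiguous block of rows."""
    start, stop = shard_bounds(len(matrix), num_shards, index)
    return matrix[start:stop]


def col_shard(matrix, num_shards, index):
    """Equivalent of Shard(1): keep a contiguous block of columns."""
    start, stop = shard_bounds(len(matrix[0]), num_shards, index)
    return [row[start:stop] for row in matrix]


# A 4x4 "global tensor" whose entries encode their own position
global_weight = [[r * 4 + c for c in range(4)] for r in range(4)]
tp = 2

row_shards = [row_shard(global_weight, tp, rank) for rank in range(tp)]
col_shards = [col_shard(global_weight, tp, rank) for rank in range(tp)]

# Total elements put on the wire across all ranks for the row-sharded case
elements_sent = sum(len(row) for shard in row_shards for row in shard)
```

Here `elements_sent` equals the size of the global tensor, whereas sending the full tensor to every rank would transfer `tp` times as many elements.
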

Design choices

  1. Spec lives on the Parameter, alongside existing output_dim/weight_loader. Zero coordination — creator declares, consumer reads named_parameters(). No walker, no isinstance, no registry.
  2. Physical sharding only. Fused-param decomposition stays in packed_modules_mapping; this API stays narrow.
  3. Mesh by string name. Keeps ShardingSpec small and serializable; get_device_meshes() is the single mesh factory.
  4. Uses torch.distributed.tensor.Placement directly. No vLLM-specific enum; plugs into DTensor / DCP / distribute_tensor out of the box.
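
Design choice 3 argues that naming the mesh by string keeps the spec serializable. As a rough illustration, the sketch below round-trips a spec through JSON, which is one way an external sender might receive it from another process. The wire format, the `spec_to_json` / `spec_from_json` helpers, and the stand-in placement classes are all assumptions made for this example; the RFC does not define a serialization format.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""


@dataclass(frozen=True)
class ShardingSpec:
    mesh: str
    placements: tuple


def spec_to_json(spec):
    """Encode a spec as a JSON string (hypothetical wire format)."""
    encoded = []
    for placement in spec.placements:
        if isinstance(placement, Shard):
            encoded.append({"type": "shard", "dim": placement.dim})
        elif isinstance(placement, Replicate):
            encoded.append({"type": "replicate"})
        else:
            raise TypeError(f"unsupported placement: {placement!r}")
    return json.dumps({"mesh": spec.mesh, "placements": encoded})


def spec_from_json(text):
    """Inverse of spec_to_json."""
    raw = json.loads(text)
    placements = tuple(
        Shard(item["dim"]) if item["type"] == "shard" else Replicate()
        for item in raw["placements"]
    )
    return ShardingSpec(raw["mesh"], placements)


original = ShardingSpec("ep_tp", (Shard(0), Shard(1)))
round_tripped = spec_from_json(spec_to_json(original))
```
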

Coverage, patch size, compatibility

Initial PR: dense + MoE primitives above. Follow-ups: quantized layers (AWQ/GPTQ/FP8), MLA / attention-DP (new "dp_attn" mesh), PP cross-rank aggregation.

Patch: vllm/distributed/sharding.py ~50 LOC, 1–2 lines per primitive (~8 classes). ~75 LOC added, 0 removed, fully additive; output_dim / input_dim / weight_loader preserved.

One CI test prevents regression:

def test_every_sharded_param_has_spec():
    for name, p in llm.model.named_parameters():
        if sharding.get(p) is None:
            assert p.shape == reference_global_shape[name]

Alternatives considered

  • model.to_dtensors() returning live DTensors. Ties public API to torch runtime, allocates wrappers, harder to serialize cross-process.
  • Central {name: spec} registry. Needs a walker with per-type dispatch — reintroduces the exact duplication this RFC eliminates.
  • Enrich weight_loader with declarative metadata. Muddles its single responsibility; imperative code reads poorly as declaration.
  • Model-level get_sharding_spec() walking internally. Pays walker cost inside vLLM; per-param attachment is strictly simpler.

Open questions

  • Mesh-name vocabulary: fix {tp, pp, dp, ep, dp_attn, ep_tp} or keep open strings? Recommendation: fix.
  • Attribute name: _vllm_sharding (private) vs sharding_spec (matching output_dim)? Recommendation: _vllm_sharding.
  • HTTP endpoint: include GET /v1/sharding_spec in this RFC or defer? Recommendation: defer.
  • Partial placements: currently unused in vLLM; vocabulary kept for forward compatibility.
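
If the mesh-name vocabulary is fixed as recommended, misspelled names can be rejected at attach time rather than surfacing later as a missing mesh. A minimal sketch, assuming the six names listed above and a hypothetical `validate_mesh_name` helper:

```python
# The candidate vocabulary from the open question above
ALLOWED_MESHES = frozenset({"tp", "pp", "dp", "ep", "dp_attn", "ep_tp"})


def validate_mesh_name(mesh: str) -> str:
    """Return the name unchanged if it is in the fixed vocabulary."""
    if mesh not in ALLOWED_MESHES:
        raise ValueError(
            f"unknown mesh {mesh!r}; expected one of {sorted(ALLOWED_MESHES)}"
        )
    return mesh
```
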

Feedback Period.

Two weeks from publication.

CC List.

(Seeking input from contributors familiar with vllm/distributed/, vllm/model_executor/layers/, and RL weight-sync integrations.)

Any Other Things.

Reference implementation will follow as a draft PR after initial alignment. Happy to split — (1) sharding module + dense primitives, (2) FusedMoE, (3) quantized — if that helps review.

extent analysis

TL;DR

To avoid sending duplicate data on the wire when an external process pushes weight updates to vLLM, the RFC proposes a new module, vllm.distributed.sharding, that exposes per-parameter sharding specifications as DTensor metadata.

Guidance

  • Implement the proposed ShardingSpec dataclass and functions (attach, get, get_device_meshes) to enable querying of sharding specifications for vLLM parameters.
  • Update parallel primitives (e.g., ColumnParallelLinear, QKVParallelLinear) to attach their sharding specifications using the attach function.
  • Modify the sender (external trainer) to fetch sharding specifications from vLLM and use them to compute the correct local slice of the global tensor for transfer.
  • Verify the change by checking that the sender computes the correct local slice for each rank and that the receiver (vLLM worker) can copy the received tensor without slicing.

Example

# In vllm/distributed/sharding.py
@dataclass(frozen=True)
class ShardingSpec:
    mesh: str
    placements: tuple[Placement, ...]

def attach(param, mesh: str, placements: tuple[Placement, ...]) -> None:
    ...  # implementation not specified in the RFC

def get(param) -> ShardingSpec | None:
    ...  # implementation not specified in the RFC

# In a parallel primitive (e.g., ColumnParallelLinear)
def __init__(self, *args, **kwargs):
    # ... existing constructor logic ...
    attach(self.weight, "tp", (Shard(0),))

Notes

The proposed solution assumes that the torch.distributed.tensor API is available and compatible with the vLLM framework. Additionally, the implementation of the attach, get, and get_device_meshes functions is not provided and will require careful consideration of the vLLM architecture and parallel primitives.

Recommendation

Adopt the proposal by implementing the vllm.distributed.sharding module and having each parallel primitive attach its sharding specification. Senders can then transfer only the slice each rank needs, which removes duplicate data on the wire without hard-coding vLLM's sharding rules externally.
