pytorch - ✅ (Solved) Fix: bag boundary is not calculated correctly in EmbeddingBag when setting include_last_offset=True [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#178614 (fetched 2026-04-08 01:40:31)
Comments: 1
Participants: 2
Timeline events: 33
Reactions: 0
Timeline (top): mentioned ×11, subscribed ×11, labeled ×8, commented ×1

Fix Action

Fixed

PR fix notes

PR #178546: Fix: Respect include_last_offset in CUDA EmbeddingBag

Description (problem / solution / changelog)

Both EmbeddingBag_updateOutputKernel_max and EmbeddingBag_updateOutputKernel_sum_mean computed the end index of the last bag by unconditionally using numIndices. When include_last_offset=true, the offsets tensor has an extra sentinel element (often equal to numIndices, but it may be smaller). The kernels should therefore read offsets[bag+1] for the last bag, just as they do for every other bag.

Add include_last_offset as a parameter to both kernels and update the end-index computation accordingly, passing the flag through from _embedding_bag_cuda at the call sites.

fix: https://github.com/pytorch/pytorch/issues/178614
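The end-index rule described above can be modelled in plain Python. This is a hypothetical sketch, not the kernel source. `bag_bounds` and its `fixed` flag are illustrative names, and the sketch assumes the caller has already computed the bag count (one fewer than len(offsets) when include_last_offset is True).

```python
from typing import List, Tuple


def bag_bounds(offsets: List[int], num_indices: int,
               include_last_offset: bool, fixed: bool = True) -> List[Tuple[int, int]]:
    """Hypothetical model of the per-bag [start, end) range in the CUDA forward kernels.

    fixed=False mimics the pre-fix behaviour described in PR #178546,
    where the last bag always ended at num_indices.
    """
    # With include_last_offset, the final offset is a sentinel, not the start of a bag.
    num_bags = len(offsets) - 1 if include_last_offset else len(offsets)
    # After the fix, the kernels know the sentinel exists and may read it.
    use_sentinel = include_last_offset and fixed
    bounds = []
    for bag in range(num_bags):
        start = offsets[bag]
        if bag < num_bags - 1 or use_sentinel:
            end = offsets[bag + 1]   # next offset (or the sentinel, for the last bag)
        else:
            end = num_indices        # last bag runs to the end of indices
        bounds.append((start, end))
    return bounds


offsets = [0, 2, 4]  # sentinel 4 is smaller than num_indices = 5
print(bag_bounds(offsets, 5, include_last_offset=True, fixed=False))  # [(0, 2), (2, 5)]
print(bag_bounds(offsets, 5, include_last_offset=True))               # [(0, 2), (2, 4)]
```

The two variants disagree only when include_last_offset=True and the sentinel is smaller than num_indices, which is exactly the case in the issue.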

Changed files

  • aten/src/ATen/native/cuda/EmbeddingBag.cu (modified, +6/-6)
  • test/nn/test_embedding.py (modified, +20/-0)

🐛 Describe the bug

When include_last_offset=True, the offsets tensor has one extra sentinel element at the end that defines where the last bag ends. If the sentinel is less than len(indices), the remaining indices after the sentinel should be excluded from all bags.

On CPU this works correctly. On CUDA, the forward kernel did not receive include_last_offset as a parameter, so it always computed the last bag's end as len(indices) and incorrectly pulled in the trailing indices.
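The two results in this report differ only in the index range that bag 1 covers. The following minimal pure-Python sketch recomputes both from the weight table used in the reproduction. It does not need PyTorch, and the helper `bag_sum` is hypothetical, standing in for mode="sum".

```python
# Weight rows and indices copied from the reproduction script.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [99.0, 99.0]]
indices = [0, 1, 2, 3, 4]


def bag_sum(weight, indices, start, end):
    """mode="sum": element-wise sum of the weight rows looked up by indices[start:end]."""
    dim = len(weight[0])
    return [sum(weight[i][d] for i in indices[start:end]) for d in range(dim)]


print(bag_sum(weight, indices, 2, 4))  # [12.0, 14.0]   bag 1 ends at the sentinel (CPU)
print(bag_sum(weight, indices, 2, 5))  # [111.0, 113.0] bag 1 runs to len(indices) (CUDA before the fix)
```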

import torch
weight = torch.tensor(
    [[1.,  2.],
     [3.,  4.],
     [5.,  6.],
     [7.,  8.],
     [99., 99.]],  # row 4 — should NOT appear in any bag
    dtype=torch.float32,
)
# Two bags: bag 0 = rows 0,1 | bag 1 = rows 2,3
# Sentinel=4 means last bag ends at position 4 (row 4 is excluded)
indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
offsets = torch.tensor([0, 2, 4],       dtype=torch.long)
eb_cpu = torch.nn.EmbeddingBag(5, 2, mode="sum", include_last_offset=True)
eb_cpu.weight = torch.nn.Parameter(weight.clone())
out_cpu = eb_cpu(indices, offsets)
eb_cuda = torch.nn.EmbeddingBag(5, 2, mode="sum", include_last_offset=True).cuda()
eb_cuda.weight = torch.nn.Parameter(weight.clone().cuda())
out_cuda = eb_cuda(indices.cuda(), offsets.cuda())
print(f"CPU  bag1: {out_cpu[1].tolist()}")   # [12.0, 14.0]  ✓
print(f"CUDA bag1: {out_cuda[1].tolist()}")  # [111.0, 113.0]  ✗ includes row 4

Expected output

CPU  bag1: [12.0, 14.0]
CUDA bag1: [12.0, 14.0]

Actual output

CPU  bag1: [12.0, 14.0]
CUDA bag1: [111.0, 113.0]

Versions

version: 2.10.0

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia

extent analysis

Fix Plan

To fix the issue, the CUDA forward kernels behind EmbeddingBag must honor include_last_offset when computing where the last bag ends.

Here are the steps:

  • Pass include_last_offset from _embedding_bag_cuda into EmbeddingBag_updateOutputKernel_max and EmbeddingBag_updateOutputKernel_sum_mean.
  • In both kernels, end the last bag at offsets[bag+1] (the sentinel) when include_last_offset is True, instead of unconditionally at numIndices.

Example Code

import torch
import torch.nn as nn

class CustomEmbeddingBag(nn.Module):
    """Python-level EmbeddingBag(mode="sum") reference/workaround; not the CUDA kernel fix."""

    def __init__(self, num_embeddings, embedding_dim, mode, include_last_offset):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if mode != "sum":
            raise NotImplementedError("this example only implements mode='sum'")
        self.mode = mode
        self.include_last_offset = include_last_offset

    def forward(self, indices, offsets):
        offsets = offsets.tolist()
        if self.include_last_offset:
            # The sentinel offsets[-1] marks where the last bag ends;
            # indices at or after it belong to no bag.
            indices = indices[:offsets[-1]]
        else:
            # Without a sentinel, the last bag runs to the end of indices.
            offsets = offsets + [len(indices)]
        embeddings = self.embedding(indices)
        num_bags = len(offsets) - 1
        output = embeddings.new_zeros((num_bags, embeddings.shape[1]))
        for i in range(num_bags):
            start, end = offsets[i], offsets[i + 1]
            output[i] = embeddings[start:end].sum(dim=0)
        return output

# Usage
weight = torch.tensor(
    [[1.,  2.],
     [3.,  4.],
     [5.,  6.],
     [7.,  8.],
     [99., 99.]],  # row 4 — should NOT appear in any bag
    dtype=torch.float32,
)
indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
offsets = torch.tensor([0, 2, 4],       dtype=torch.long)

eb_cpu = CustomEmbeddingBag(5, 2, "sum", True)
eb_cpu.embedding.weight = torch.nn.Parameter(weight.clone())
out_cpu = eb_cpu(indices, offsets)

print(f"CPU  bag1: {out_cpu[1].tolist()}")   # [12.0, 14.0]

if torch.cuda.is_available():
    eb_cuda = CustomEmbeddingBag(5, 2, "sum", True).cuda()
    eb_cuda.embedding.weight = torch.nn.Parameter(weight.clone().cuda())
    out_cuda = eb_cuda(indices.cuda(), offsets.cuda())
    print(f"CUDA bag1: {out_cuda[1].tolist()}")  # [12.0, 14.0]
