pytorch - [Inductor] Filter the extreme bad triton config by device information for triton heuristics

🚀 The feature, motivation and pitch

The current workflow for the Triton heuristics is to generate a series of Triton configs and then run autotuning to pick the best one. These configs are usually based on well-known practices while trying to cover as many cases as possible, because autotuning is a one-time effort and evaluating more configs does not seem to hurt much. With max_autotune, Inductor also tends to add or keep more configs to benchmark.
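As a rough illustration of this generate-then-benchmark workflow, here is a minimal sketch in plain Python. It is not Inductor's implementation: `autotune`, `fake_bench`, and the timing table are hypothetical, and the timings are placeholder numbers rather than measurements.

```python
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]


def autotune(
    candidate_configs: List[Config],
    bench: Callable[[Config], float],
) -> Tuple[Config, Dict[int, float]]:
    """Benchmark every candidate and return the fastest one plus all timings.

    `bench` is a hypothetical callable returning a runtime (lower is better).
    Every candidate is measured, so each extra config adds tuning time.
    """
    if not candidate_configs:
        raise ValueError("need at least one candidate config")
    timings = {i: bench(cfg) for i, cfg in enumerate(candidate_configs)}
    best_index = min(timings, key=timings.get)
    return candidate_configs[best_index], timings


# Placeholder timings keyed by XBLOCK; these are NOT real measurements.
PLACEHOLDER_TIMINGS = {1: 4.0, 8: 0.6, 32: 0.2, 128: 0.1}


def fake_bench(cfg: Config) -> float:
    """Stand-in for a real kernel benchmark (assumption for illustration)."""
    return PLACEHOLDER_TIMINGS[cfg["XBLOCK"]]


candidates = [{"XBLOCK": 1}, {"XBLOCK": 8}, {"XBLOCK": 32}, {"XBLOCK": 128}]
best, timings = autotune(candidates, fake_bench)
```

The point of the sketch is that the cost of tuning is the sum of all candidate runtimes, so a single very slow candidate can dominate the total.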

This workflow is simple and straightforward. However, the configs it includes do not take device capabilities into account, such as the number of CTAs. This matters more as the GPUs on the market become more diverse: a low-end GPU may have only a few CTAs, while a high-end GPU has hundreds. For example, consider the following case:

```python
import torch
import time

start_time = time.perf_counter()

class TestModule(torch.nn.Module):
    # Input tensors that are generated randomly
    def __init__(self, sel_device, training):
        torch.manual_seed(777)
        super(TestModule, self).__init__()
        self.inp = torch.empty_strided(size=[16, 4096, 1024], stride=[4194304, 1024, 1], dtype=torch.bfloat16, device='cpu').uniform_(-1, 1).to(device=sel_device).requires_grad_(training)

    def forward(self):
        r = torch.sum(self.inp, [0], False) # traced line: 1010
        return r

device = "xpu"
m = TestModule(device, False)
m = torch.compile(m)

out = m()
out = out.to("cpu")

end_time = time.perf_counter()
print("Total time:", str(end_time - start_time))
```

For this case, with a reduction size of 16, Inductor's persistent_reduction heuristic generates and autotunes four XBLOCK configurations:

1. {'XBLOCK': 1}
2. {'XBLOCK': 8}  
3. {'XBLOCK': 32}  
4. {'XBLOCK': 128}

These correspond to four launch grids to benchmark. With 4096 × 1024 = 4194304 output elements, the grid size is 4194304 / XBLOCK:

1. {'XBLOCK': 1} -> grid (4194304, 1, 1)
2. {'XBLOCK': 8}  -> grid (524288, 1, 1)
3. {'XBLOCK': 32}  -> grid (131072, 1, 1)
4. {'XBLOCK': 128} -> grid (32768, 1, 1)
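
The grid sizes listed above follow from dividing the number of output elements by XBLOCK and rounding up. A minimal sketch of that arithmetic, assuming a one-dimensional launch grid; `grid_for` is a hypothetical helper, not an Inductor function:

```python
def grid_for(xnumel: int, xblock: int) -> tuple:
    """Return a 1-D launch grid: one program per XBLOCK output elements."""
    return (-(-xnumel // xblock), 1, 1)  # ceiling division


# Summing a [16, 4096, 1024] tensor over dim 0 leaves 4096 * 1024 outputs,
# each of which reduces 16 elements.
xnumel = 4096 * 1024
grids = {xblock: grid_for(xnumel, xblock) for xblock in (1, 8, 32, 128)}
for xblock, grid in grids.items():
    print(f"XBLOCK={xblock:<3} -> grid {grid}")
```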

Extreme cases such as 4194304 thread blocks (work groups), each reducing only 16 elements, are obviously bad and will almost always perform worst. Autotuning then spends much of its time benchmarking configs that have no chance of winning. Worse, on a weak GPU or in a simulator, such an extreme config takes so long that the total benchmark time becomes unacceptable.

One way to deal with this is to avoid generating such configs in the first place, by using knowledge of the problem size and the device capabilities. However, this would make the config-generation logic too complex and dynamic, considering how many cases already need handling for pointwise, reduction, persistent reduction, and so on. The existing approach of doing limited checks while adding configs remains preferable for its simplicity.

The suggested approach is simply to add a scrub step that filters out extreme Triton configs based on device information after all configs have been created. This keeps the filtering logic centralized and non-intrusive to the config-generation logic. The filtering can be controlled by a flag so that users can turn it on or off.
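
One way such a scrub pass could look is sketched below. This is only an illustration under stated assumptions, not the proposed implementation: `DeviceInfo`, `scrub_configs`, the `max_blocks_per_unit` threshold, and the `enabled` flag are all hypothetical names, and the threshold value is arbitrary. The sketch drops configs whose grid exceeds a device-dependent limit and always keeps at least one config so autotuning has something to run.

```python
from dataclasses import dataclass
from typing import Dict, List

Config = Dict[str, int]


@dataclass(frozen=True)
class DeviceInfo:
    # Hypothetical device description; a real one would be queried at runtime.
    num_compute_units: int


def scrub_configs(
    configs: List[Config],
    xnumel: int,
    device: DeviceInfo,
    max_blocks_per_unit: int = 4096,  # arbitrary illustrative threshold
    enabled: bool = True,  # hypothetical on/off flag
) -> List[Config]:
    """Drop configs whose grid size is extreme for the given device."""
    if not enabled or not configs:
        return list(configs)

    def num_blocks(cfg: Config) -> int:
        return -(-xnumel // cfg["XBLOCK"])  # ceiling division

    limit = device.num_compute_units * max_blocks_per_unit
    kept = [cfg for cfg in configs if num_blocks(cfg) <= limit]
    # Never return an empty list: fall back to the config with the smallest grid.
    return kept or [min(configs, key=num_blocks)]


configs = [{"XBLOCK": 1}, {"XBLOCK": 8}, {"XBLOCK": 32}, {"XBLOCK": 128}]
xnumel = 4096 * 1024
small_gpu = scrub_configs(configs, xnumel, DeviceInfo(num_compute_units=16))
large_gpu = scrub_configs(configs, xnumel, DeviceInfo(num_compute_units=128))
unfiltered = scrub_configs(configs, xnumel, DeviceInfo(16), enabled=False)
```

With these placeholder numbers, the 16-unit device keeps only XBLOCK = 128, the 128-unit device keeps XBLOCK = 8, 32, and 128, and disabling the flag leaves all four configs in place. Because the pass runs on the finished config list, the generation logic for pointwise, reduction, and persistent reduction does not need to change.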

@ezyang @EikanWang

Alternatives

No response

Additional context

No response
