pytorch: DISABLED test_lazy_template_fusion_multiple_candidates_use_async_compile_True (__main__.TestPrologueFusion) [1 comment, 1 participant]

pytorch/pytorch#180413, fetched 2026-04-16 06:34:47

Error Message

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3444, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 589, in instantiated_test
    test(self, **param_kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2167, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/var/lib/jenkins/workspace/test/inductor/test_max_autotune.py", line 4602, in test_lazy_template_fusion_multiple_candidates
    ).run(code[0])
RuntimeError: Expected to find "to_copy_add_div_mm_mul_relu_sub_tanh_1" but did not find it
Searched string: the generated C++ wrapper source, listed in full under Code Example.
From CHECK: to_copy_add_div_mm_mul_relu_sub_tanh_1

To execute this test, run the following from the base repo dir:
    python test/inductor/test_max_autotune.py TestPrologueFusion.test_lazy_template_fusion_multiple_candidates_use_async_compile_True

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

Root Cause

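This test was disabled because it is failing in CI. See recent examples and the most recent trunk workflow logs. Over the past 6 hours it was flaky in 3 workflows, with 3 failures and 3 successes.

The log shows the immediate cause of the failure: the test's check expects the generated code to contain the kernel name to_copy_add_div_mm_mul_relu_sub_tanh_1, but the searched wrapper source registers two separate template kernels, triton_tem_fused__to_copy_div_mm_sub_0 and triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1, so the expected name never appears.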
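The name mismatch in the error message can be reproduced on a small excerpt of the searched string. The sketch below is a stand-in written in plain Python, not the check the test itself uses; find_template_kernels is a hypothetical helper, and the sample text is copied from the two startKernelCompile lines in the log.

```python
import re

# Hypothetical helper, not part of PyTorch: collect the quoted Triton template
# kernel names that appear in a generated wrapper source string.
def find_template_kernels(source: str) -> list[str]:
    quoted = re.findall(r"[\"'](triton_tem_fused_\w+)[\"']", source)
    # dict.fromkeys removes repeats while keeping first-seen order.
    return list(dict.fromkeys(quoted))

# Excerpt copied from the searched string in the failure log.
sample = '''
startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_div_mm_sub_0", triton_tem_fused__to_copy_div_mm_sub_0_source);
startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source);
'''

# The substring the failing check was looking for.
expected = "to_copy_add_div_mm_mul_relu_sub_tanh_1"

kernels = find_template_kernels(sample)
found = expected in sample

print(kernels)
print(found)
```

On this excerpt the helper lists the two kernels that were generated, and the plain substring test shows that neither of them contains the expected name.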

Code Example

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3444, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 589, in instantiated_test
    test(self, **param_kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2167, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/var/lib/jenkins/workspace/test/inductor/test_max_autotune.py", line 4602, in test_lazy_template_fusion_multiple_candidates
    ).run(code[0])
RuntimeError: Expected to find "to_copy_add_div_mm_mul_relu_sub_tanh_1" but did not find it
Searched string:
}

static CUfunction triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 = nullptr;
static const char* triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source = R"TRITON(
async_compile.triton('triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

num_stages=1,
num_warps=2,
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1', 'backend_hash': 'B85604529715172C5D29CA34D0DA6CC837951A340BC53A8AB299DA3BA883F0D1', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': True, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'OUT_DTYPE': 'tl.float32', 'BLOCK_M': 32, 'BLOCK_N': 32, 'BLOCK_K': 16, 'GROUP_M': 8, 'ALLOW_TF32': True}},

)
@triton.jit
def triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(in_ptr0, in_ptr1, in_ptr2, out_ptr1):
    EVEN_K : tl.constexpr = True
    USE_FAST_ACCUM : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    OUT_DTYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 32
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 16
    GROUP_M : tl.constexpr = 8
    ALLOW_TF32 : tl.constexpr = True
    INDEX_DTYPE : tl.constexpr = tl.int32

    M = 64
    N = 64
    K = 128
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 128
    stride_ak = 1
    stride_bk = 64
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0).to(INDEX_DTYPE)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE)
    if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1):
        offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        offs_a_m = rm % M
    if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1):
        offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        offs_b_n = rn % N
    offs_k = tl.arange(0, BLOCK_K).to(INDEX_DTYPE)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):

        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
        b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)

        idx_m = offs_a_m[:, None]
        idx_n = a_k_idx_vals
        xindex = idx_n + 128*idx_m
        tmp9 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_K])), None, eviction_policy='evict_last')
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tl.full([1], 1.0, tl.float32)
        tmp12 = tmp10 + tmp11
        a = tmp12.broadcast_to(xindex.shape)

        idx_m = b_k_idx_vals
        idx_n = offs_b_n[None, :]
        xindex = idx_n + 64*idx_m
        tmp13 = tl.load(in_ptr2 + (tl.broadcast_to(xindex, [BLOCK_K, BLOCK_N])), None, eviction_policy='evict_last')
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tl.full([1], 2.0, tl.float32)
        tmp16 = tmp14 * tmp15
        b = tmp16.broadcast_to(xindex.shape)


        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)


    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + 64*idx_m
    tmp0 = tl.load(in_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), mask, eviction_policy='evict_last')
    tmp1 = acc * tmp0
    tmp2 = tl.full([1], 0.5, tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = tl.full([1], 0, tl.int32)
    tmp7 = triton_helpers.maximum(tmp6, tmp5)
    tmp8 = libdevice.tanh(tmp7)
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), tmp8, mask)
''', device_str='cuda')

)TRITON";
static LazyKernelCompileResult triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result;
template <typename in_ptr0_type_, typename in_ptr1_type_, typename in_ptr2_type_, typename out_ptr1_type_>
static __attribute__((noinline)) void call_triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(
    const in_ptr0_type_& in_ptr0,
    const in_ptr1_type_& in_ptr1,
    const in_ptr2_type_& in_ptr2,
    const out_ptr1_type_& out_ptr1,
    int64_t _grid_0,
    int64_t _grid_1,
    int64_t _grid_2,
    int32_t device_idx_,
    cudaStream_t stream_,
    const std::optional<std::string>& cubin_dir_ = std::nullopt
){
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 == nullptr) {
        triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result = runTritonKernelWithAutotune(
            _module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", stream_, in_ptr0, in_ptr1, in_ptr2, out_ptr1, _grid_0, _grid_1, _grid_2);

        triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 = loadKernel(
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.cubin_path,
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.mangled_name,
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.shared_mem);

        // First invocation already ran the kernel, so return early
        return;
    }
    uint32_t grid_0 = _grid_0;
    uint32_t grid_1 = _grid_1;
    uint32_t grid_2 = _grid_2;
    if (grid_0 == 0) return;
    CUdeviceptr var_3 = reinterpret_cast<CUdeviceptr>(in_ptr0.data_ptr());
    CUdeviceptr var_4 = reinterpret_cast<CUdeviceptr>(in_ptr1.data_ptr());
    CUdeviceptr var_5 = reinterpret_cast<CUdeviceptr>(in_ptr2.data_ptr());
    CUdeviceptr var_6 = reinterpret_cast<CUdeviceptr>(out_ptr1.data_ptr());
    CUdeviceptr global_scratch_ptr = 0;
    RAIIAtenTensorHandle global_scratch_ptr_tensor;
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.global_scratch > 0) {
        int64_t global_scratch_ptr_size[] = {triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.global_scratch};
        int64_t global_scratch_ptr_stride[] = {1};
        AtenTensorHandle global_scratch_ptr_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
            1, global_scratch_ptr_size, global_scratch_ptr_stride, cached_torch_dtype_uint8,
            cached_torch_device_type_cuda, device_idx_, &global_scratch_ptr_handle));
        global_scratch_ptr_tensor = RAIIAtenTensorHandle(global_scratch_ptr_handle);
        global_scratch_ptr = reinterpret_cast<CUdeviceptr>(global_scratch_ptr_tensor.data_ptr());
    }
    CUdeviceptr profile_scratch_ptr = 0;
    RAIIAtenTensorHandle profile_scratch_ptr_tensor;
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.profile_scratch > 0) {
        int64_t profile_scratch_ptr_size[] = {triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.profile_scratch};
        int64_t profile_scratch_ptr_stride[] = {1};
        AtenTensorHandle profile_scratch_ptr_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
            1, profile_scratch_ptr_size, profile_scratch_ptr_stride, cached_torch_dtype_uint8,
            cached_torch_device_type_cuda, device_idx_, &profile_scratch_ptr_handle));
        profile_scratch_ptr_tensor = RAIIAtenTensorHandle(profile_scratch_ptr_handle);
        profile_scratch_ptr = reinterpret_cast<CUdeviceptr>(profile_scratch_ptr_tensor.data_ptr());
    }
    void* kernel_args_[] = {&var_3, &var_4, &var_5, &var_6, &global_scratch_ptr, &profile_scratch_ptr};
    launchKernel(triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1, grid_0, grid_1, grid_2, triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.num_warps, triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.shared_mem, kernel_args_, stream_);
}
// Start parallel compilation of all Triton kernels
static inline void start_all_triton_kernel_compiles() {
    loadLazyCompileFuncs();
    _module_pending_kernels = PyDict_New();
    AOTI_TORCH_CHECK(_module_pending_kernels, "Failed to create pending kernels dict");
    startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_div_mm_sub_0", triton_tem_fused__to_copy_div_mm_sub_0_source);
    startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source);
}

// Static initializer to start kernel compilation on module load
static struct TritonKernelCompileInit {
    TritonKernelCompileInit() {
        start_all_triton_kernel_compiles();
    }
} __triton_kernel_compile_init;


void inductor_entry_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles  // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed)
) {
    py::gil_scoped_release_simple release;

    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg0_1 = std::move(inputs[0]);
    auto arg1_1 = std::move(inputs[1]);
    auto arg2_1 = std::move(inputs[2]);
    auto arg3_1 = std::move(inputs[3]);

    AOTICudaGuard device_guard(0);
    static constexpr int64_t int_array_0[] = {64L, 64L};
    static constexpr int64_t int_array_1[] = {64L, 1L};
    AtenTensorHandle buf5_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cuda,  0, &buf5_handle));
    RAIIAtenTensorHandle buf5(buf5_handle);
    // Topologically Sorted Source Nodes: [to_2, c_transformed, to_3, d_transformed, mm2], Original ATen: [aten._to_copy, aten.sub, aten.div, aten.mm]
    cudaStream_t stream0;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(0, (void**)&stream0));
    call_triton_tem_fused__to_copy_div_mm_sub_0(arg2_1, arg3_1, buf5, 4L, 1, 1, 0, stream0);
    arg2_1.reset();
    arg3_1.reset();
    static constexpr int64_t int_array_2[] = {64L, 64L};
    static constexpr int64_t int_array_3[] = {64L, 1L};
    AtenTensorHandle buf6_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_2, int_array_3, cached_torch_dtype_float32, cached_torch_device_type_cuda,  0, &buf6_handle));
    RAIIAtenTensorHandle buf6(buf6_handle);
    // Topologically Sorted Source Nodes: [to, a_transformed, to_1, b_transformed, mm1, combined, mul_2, add_1, relu, normalized], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.mm, aten.relu, aten.tanh]
    call_triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(buf5, arg0_1, arg1_1, buf6, 4L, 1, 1, 0, stream0);
    arg0_1.reset();
    arg1_1.reset();
    buf5.reset();
    output_handles[0] = buf6.release();
} // inductor_entry_impl
"""
)

inductor_entry = CppWrapperCodeCache.load_pybinding(
    argtypes=["std::vector<AtenTensorHandle>"],
    main_code=cpp_wrapper_src,
    device_type="cuda",
    num_outputs=1,
    kernel_code=None,
)

def _wrap_func(f):
    def g(args):
        input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
        input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)

        args.clear()
        del input_tensors

        output_handles = f(input_handles)
        output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
        return output_tensors

    return g

call = _wrap_func(inductor_entry)


def get_args():
    from torch._dynamo.testing import rand_strided
    arg0_1 = rand_strided((64, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg1_1 = rand_strided((128, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((64, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((128, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    return [arg0_1, arg1_1, arg2_1, arg3_1]


def benchmark_compiled_module(args, times=10, repeat=10):
    from torch._inductor.utils import print_performance
    fn = lambda: call(list(args))
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
From CHECK: to_copy_add_div_mm_mul_relu_sub_tanh_1


To execute this test, run the following from the base repo dir:
    python test/inductor/test_max_autotune.py TestPrologueFusion.test_lazy_template_fusion_multiple_candidates_use_async_compile_True

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
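Because the failure is intermittent, a single local run may pass. A minimal sketch for repeating the command above and counting failures follows. The function names and the injectable runner parameter are assumptions made for illustration, and it assumes the test script exits with a nonzero status when the test fails; only the script path, the test ID, and the PYTORCH_PRINT_REPRO_ON_FAILURE variable come from the log.

```python
import os
import subprocess
import sys

# Test ID taken verbatim from the failure message.
TEST_ID = (
    "TestPrologueFusion."
    "test_lazy_template_fusion_multiple_candidates_use_async_compile_True"
)

def build_command(python: str = sys.executable) -> list[str]:
    # Same command as in the log; run it from the base repo dir.
    return [python, "test/inductor/test_max_autotune.py", TEST_ID]

def run_repeatedly(n: int, runner=subprocess.run) -> int:
    """Run the test n times and return how many runs failed.

    `runner` is a hypothetical injection point so the loop can be
    exercised without a PyTorch checkout.
    """
    # Suppress the long repro message, as the log says is possible.
    env = dict(os.environ, PYTORCH_PRINT_REPRO_ON_FAILURE="0")
    failures = 0
    for _ in range(n):
        result = runner(build_command(), env=env)
        if result.returncode != 0:
            failures += 1
    return failures
```

Calling run_repeatedly(10) from the base repo dir on a CUDA machine would give a rough failure count to compare against the CI figures.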

Platforms: inductor

This test was disabled because it is failing in CI. See recent examples and the most recent trunk workflow logs.

Over the past 6 hours, it has been determined flaky in 3 workflow(s) with 3 failures and 3 successes.

Debugging instructions (after clicking on the recent samples link): DO NOT ASSUME THINGS ARE OKAY IF THE CI IS GREEN. We now shield flaky tests from developers so CI will thus be green but it will be harder to parse the logs. To find relevant log snippets:

  1. Click on the workflow logs linked above
  2. Click on the Test step of the job so that it is expanded. Otherwise, the grepping will not work.
  3. Grep for test_lazy_template_fusion_multiple_candidates_use_async_compile_True
  4. There should be several instances run (as flaky tests are rerun in CI) from which you can study the logs.
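Step 3 above can also be done offline on a downloaded log. The following is a minimal sketch, assuming the Test step log has been saved as plain text; the helper name `matching_lines` and its `context` parameter are hypothetical and not part of any PyTorch tooling.

```python
def matching_lines(log_text, needle, context=2):
    """Return (line_number, surrounding_lines) for every log line containing needle.

    line_number is 1-based; surrounding_lines holds up to `context` lines
    before and after the hit, so each rerun can be inspected in place.
    """
    lines = log_text.splitlines()
    hits = []
    for index, line in enumerate(lines):
        if needle in line:
            start = max(0, index - context)
            hits.append((index + 1, lines[start:index + context + 1]))
    return hits
```

Calling `matching_lines(open(path).read(), "test_lazy_template_fusion_multiple_candidates_use_async_compile_True")` on a saved log should return one entry per mention of the test, which makes it easier to compare the reruns side by side.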
<details><summary>Sample error message</summary>
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3444, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 589, in instantiated_test
    test(self, **param_kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2167, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/var/lib/jenkins/workspace/test/inductor/test_max_autotune.py", line 4602, in test_lazy_template_fusion_multiple_candidates
    ).run(code[0])
RuntimeError: Expected to find "to_copy_add_div_mm_mul_relu_sub_tanh_1" but did not find it
Searched string:
}

static CUfunction triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 = nullptr;
static const char* triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source = R"TRITON(
async_compile.triton('triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

num_stages=1,
num_warps=2,
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1', 'backend_hash': 'B85604529715172C5D29CA34D0DA6CC837951A340BC53A8AB299DA3BA883F0D1', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': True, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'OUT_DTYPE': 'tl.float32', 'BLOCK_M': 32, 'BLOCK_N': 32, 'BLOCK_K': 16, 'GROUP_M': 8, 'ALLOW_TF32': True}},

)
@triton.jit
def triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(in_ptr0, in_ptr1, in_ptr2, out_ptr1):
    EVEN_K : tl.constexpr = True
    USE_FAST_ACCUM : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    OUT_DTYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 32
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 16
    GROUP_M : tl.constexpr = 8
    ALLOW_TF32 : tl.constexpr = True
    INDEX_DTYPE : tl.constexpr = tl.int32

    M = 64
    N = 64
    K = 128
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 128
    stride_ak = 1
    stride_bk = 64
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0).to(INDEX_DTYPE)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE)
    if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1):
        offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        offs_a_m = rm % M
    if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1):
        offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        offs_b_n = rn % N
    offs_k = tl.arange(0, BLOCK_K).to(INDEX_DTYPE)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):

        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
        b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)

        idx_m = offs_a_m[:, None]
        idx_n = a_k_idx_vals
        xindex = idx_n + 128*idx_m
        tmp9 = tl.load(in_ptr1 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_K])), None, eviction_policy='evict_last')
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tl.full([1], 1.0, tl.float32)
        tmp12 = tmp10 + tmp11
        a = tmp12.broadcast_to(xindex.shape)

        idx_m = b_k_idx_vals
        idx_n = offs_b_n[None, :]
        xindex = idx_n + 64*idx_m
        tmp13 = tl.load(in_ptr2 + (tl.broadcast_to(xindex, [BLOCK_K, BLOCK_N])), None, eviction_policy='evict_last')
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tl.full([1], 2.0, tl.float32)
        tmp16 = tmp14 * tmp15
        b = tmp16.broadcast_to(xindex.shape)


        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE)


    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + 64*idx_m
    tmp0 = tl.load(in_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), mask, eviction_policy='evict_last')
    tmp1 = acc * tmp0
    tmp2 = tl.full([1], 0.5, tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = tl.full([1], 0, tl.int32)
    tmp7 = triton_helpers.maximum(tmp6, tmp5)
    tmp8 = libdevice.tanh(tmp7)
    tl.store(out_ptr1 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), tmp8, mask)
''', device_str='cuda')

)TRITON";
static LazyKernelCompileResult triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result;
template <typename in_ptr0_type_, typename in_ptr1_type_, typename in_ptr2_type_, typename out_ptr1_type_>
static __attribute__((noinline)) void call_triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(
    const in_ptr0_type_& in_ptr0,
    const in_ptr1_type_& in_ptr1,
    const in_ptr2_type_& in_ptr2,
    const out_ptr1_type_& out_ptr1,
    int64_t _grid_0,
    int64_t _grid_1,
    int64_t _grid_2,
    int32_t device_idx_,
    cudaStream_t stream_,
    const std::optional<std::string>& cubin_dir_ = std::nullopt
){
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 == nullptr) {
        triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result = runTritonKernelWithAutotune(
            _module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", stream_, in_ptr0, in_ptr1, in_ptr2, out_ptr1, _grid_0, _grid_1, _grid_2);

        triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1 = loadKernel(
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.cubin_path,
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.mangled_name,
            triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.shared_mem);

        // First invocation already ran the kernel, so return early
        return;
    }
    uint32_t grid_0 = _grid_0;
    uint32_t grid_1 = _grid_1;
    uint32_t grid_2 = _grid_2;
    if (grid_0 == 0) return;
    CUdeviceptr var_3 = reinterpret_cast<CUdeviceptr>(in_ptr0.data_ptr());
    CUdeviceptr var_4 = reinterpret_cast<CUdeviceptr>(in_ptr1.data_ptr());
    CUdeviceptr var_5 = reinterpret_cast<CUdeviceptr>(in_ptr2.data_ptr());
    CUdeviceptr var_6 = reinterpret_cast<CUdeviceptr>(out_ptr1.data_ptr());
    CUdeviceptr global_scratch_ptr = 0;
    RAIIAtenTensorHandle global_scratch_ptr_tensor;
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.global_scratch > 0) {
        int64_t global_scratch_ptr_size[] = {triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.global_scratch};
        int64_t global_scratch_ptr_stride[] = {1};
        AtenTensorHandle global_scratch_ptr_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
            1, global_scratch_ptr_size, global_scratch_ptr_stride, cached_torch_dtype_uint8,
            cached_torch_device_type_cuda, device_idx_, &global_scratch_ptr_handle));
        global_scratch_ptr_tensor = RAIIAtenTensorHandle(global_scratch_ptr_handle);
        global_scratch_ptr = reinterpret_cast<CUdeviceptr>(global_scratch_ptr_tensor.data_ptr());
    }
    CUdeviceptr profile_scratch_ptr = 0;
    RAIIAtenTensorHandle profile_scratch_ptr_tensor;
    if (triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.profile_scratch > 0) {
        int64_t profile_scratch_ptr_size[] = {triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.profile_scratch};
        int64_t profile_scratch_ptr_stride[] = {1};
        AtenTensorHandle profile_scratch_ptr_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
            1, profile_scratch_ptr_size, profile_scratch_ptr_stride, cached_torch_dtype_uint8,
            cached_torch_device_type_cuda, device_idx_, &profile_scratch_ptr_handle));
        profile_scratch_ptr_tensor = RAIIAtenTensorHandle(profile_scratch_ptr_handle);
        profile_scratch_ptr = reinterpret_cast<CUdeviceptr>(profile_scratch_ptr_tensor.data_ptr());
    }
    void* kernel_args_[] = {&var_3, &var_4, &var_5, &var_6, &global_scratch_ptr, &profile_scratch_ptr};
    launchKernel(triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1, grid_0, grid_1, grid_2, triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.num_warps, triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_result.shared_mem, kernel_args_, stream_);
}
// Start parallel compilation of all Triton kernels
static inline void start_all_triton_kernel_compiles() {
    loadLazyCompileFuncs();
    _module_pending_kernels = PyDict_New();
    AOTI_TORCH_CHECK(_module_pending_kernels, "Failed to create pending kernels dict");
    startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_div_mm_sub_0", triton_tem_fused__to_copy_div_mm_sub_0_source);
    startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source);
}

// Static initializer to start kernel compilation on module load
static struct TritonKernelCompileInit {
    TritonKernelCompileInit() {
        start_all_triton_kernel_compiles();
    }
} __triton_kernel_compile_init;


void inductor_entry_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles  // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed)
) {
    py::gil_scoped_release_simple release;

    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg0_1 = std::move(inputs[0]);
    auto arg1_1 = std::move(inputs[1]);
    auto arg2_1 = std::move(inputs[2]);
    auto arg3_1 = std::move(inputs[3]);

    AOTICudaGuard device_guard(0);
    static constexpr int64_t int_array_0[] = {64L, 64L};
    static constexpr int64_t int_array_1[] = {64L, 1L};
    AtenTensorHandle buf5_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cuda,  0, &buf5_handle));
    RAIIAtenTensorHandle buf5(buf5_handle);
    // Topologically Sorted Source Nodes: [to_2, c_transformed, to_3, d_transformed, mm2], Original ATen: [aten._to_copy, aten.sub, aten.div, aten.mm]
    cudaStream_t stream0;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(0, (void**)&stream0));
    call_triton_tem_fused__to_copy_div_mm_sub_0(arg2_1, arg3_1, buf5, 4L, 1, 1, 0, stream0);
    arg2_1.reset();
    arg3_1.reset();
    static constexpr int64_t int_array_2[] = {64L, 64L};
    static constexpr int64_t int_array_3[] = {64L, 1L};
    AtenTensorHandle buf6_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_2, int_array_3, cached_torch_dtype_float32, cached_torch_device_type_cuda,  0, &buf6_handle));
    RAIIAtenTensorHandle buf6(buf6_handle);
    // Topologically Sorted Source Nodes: [to, a_transformed, to_1, b_transformed, mm1, combined, mul_2, add_1, relu, normalized], Original ATen: [aten._to_copy, aten.add, aten.mul, aten.mm, aten.relu, aten.tanh]
    call_triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1(buf5, arg0_1, arg1_1, buf6, 4L, 1, 1, 0, stream0);
    arg0_1.reset();
    arg1_1.reset();
    buf5.reset();
    output_handles[0] = buf6.release();
} // inductor_entry_impl
"""
)

inductor_entry = CppWrapperCodeCache.load_pybinding(
    argtypes=["std::vector<AtenTensorHandle>"],
    main_code=cpp_wrapper_src,
    device_type="cuda",
    num_outputs=1,
    kernel_code=None,
)

def _wrap_func(f):
    def g(args):
        input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
        input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)

        args.clear()
        del input_tensors

        output_handles = f(input_handles)
        output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
        return output_tensors

    return g

call = _wrap_func(inductor_entry)


def get_args():
    from torch._dynamo.testing import rand_strided
    arg0_1 = rand_strided((64, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg1_1 = rand_strided((128, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((64, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((128, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    return [arg0_1, arg1_1, arg2_1, arg3_1]


def benchmark_compiled_module(args, times=10, repeat=10):
    from torch._inductor.utils import print_performance
    fn = lambda: call(list(args))
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
From CHECK: to_copy_add_div_mm_mul_relu_sub_tanh_1


To execute this test, run the following from the base repo dir:
    python test/inductor/test_max_autotune.py TestPrologueFusion.test_lazy_template_fusion_multiple_candidates_use_async_compile_True

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
</details>

Test file path: inductor/test_max_autotune.py

For all disabled tests (by GitHub issue), see https://hud.pytorch.org/disabled.

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

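The test test_lazy_template_fusion_multiple_candidates_use_async_compile_True fails intermittently with a RuntimeError. The check at test/inductor/test_max_autotune.py, line 4602, expects the generated code to contain to_copy_add_div_mm_mul_relu_sub_tanh_1, but the searched output only references the kernels triton_tem_fused__to_copy_div_mm_sub_0 and triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1.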

Guidance

  1. Compare the kernel names: The expected string to_copy_add_div_mm_mul_relu_sub_tanh_1 includes div and sub, while the generated kernel with the _1 suffix is named triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1. In the generated code, div and sub appear only in triton_tem_fused__to_copy_div_mm_sub_0.
  2. Check the compilation process: The failing variant carries the use_async_compile_True suffix, so look at how the Triton template kernels are compiled and autotuned in failing runs versus passing runs.
  3. Review the test configuration: Confirm that failing and passing runs use the same setup (platform inductor, the same GPU, and the same settings recorded in inductor_meta, such as max_autotune), so that a configuration difference can be ruled out.
  4. Follow the debugging instructions: Open the workflow logs, expand the Test step, and grep for the test name to compare the reruns of the same job.
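The name comparison in step 1 can be sketched in a few lines of Python. This is an illustration only: the two `startKernelCompile` lines are copied from the searched string in the sample error message, and the helpers `kernel_names`, `op_tokens`, and `missing_tokens` are hypothetical names introduced here. Splitting on underscores is a rough approximation of the op names (for example, `to_copy` becomes `to` and `copy`).

```python
import re

# Name the test expects to find in the generated code (from the error message).
EXPECTED = "to_copy_add_div_mm_mul_relu_sub_tanh_1"

# Two lines copied from the searched string in the sample error message.
GENERATED_SNIPPET = '''
startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_div_mm_sub_0", triton_tem_fused__to_copy_div_mm_sub_0_source);
startKernelCompile(_module_pending_kernels, "triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1", triton_tem_fused__to_copy_add_mm_mul_relu_tanh_1_source);
'''

PREFIX = "triton_tem_fused__"


def kernel_names(code):
    """Return the distinct quoted triton_tem_fused__* names, in order of first use."""
    seen = []
    for name in re.findall(r'"(triton_tem_fused__\w+)"', code):
        if name not in seen:
            seen.append(name)
    return seen


def op_tokens(name):
    """Split a kernel name on underscores, dropping the prefix and numeric suffix."""
    if name.startswith(PREFIX):
        name = name[len(PREFIX):]
    return [tok for tok in name.split("_") if tok and not tok.isdigit()]


def missing_tokens(expected, generated):
    """List the tokens of the expected name that the generated name lacks."""
    have = set(op_tokens(generated))
    return [tok for tok in op_tokens(expected) if tok not in have]


# Map each generated kernel to the expected tokens it does not contain.
report = {name: missing_tokens(EXPECTED, name) for name in kernel_names(GENERATED_SNIPPET)}
for name, missing in report.items():
    print(name, "is missing", missing)
```

For the kernel ending in `_1`, the missing tokens are `div` and `sub`, which are exactly the tokens that appear in the `_0` kernel's name.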

Example

The issue does not include a minimal reproduction. The full generated wrapper code is reproduced in the sample error message above, and the failure can be triggered with the test command given there.

Notes

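The failing variant runs with async_compile enabled (per the use_async_compile_True suffix), and the test is flaky rather than consistently failing (3 failures and 3 successes over 6 hours). A timing-dependent difference in the generated kernels is therefore a plausible cause, though the issue does not confirm it. It is also possible that the expected kernel name in the test does not match what the current code generates in every run.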

Recommendation

Start by comparing the expected and generated kernel names, then run the test locally with the command from the error message to see whether the failure reproduces. If it does, check whether the mismatch also occurs with async_compile disabled. If the issue persists, it may be necessary to adjust the test's expected kernel name or its configuration.
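Because the failure is intermittent, a single local run says little. The sketch below shows one way to repeat the repro command and count failures. It assumes it is run from the base repo dir with a working CUDA build; `run_once` and `count_failures` are hypothetical helper names, and only the test id and file path come from the error message.

```python
import subprocess
import sys

# Test id copied from the repro command in the error message.
TEST_ID = (
    "TestPrologueFusion."
    "test_lazy_template_fusion_multiple_candidates_use_async_compile_True"
)
COMMAND = [sys.executable, "test/inductor/test_max_autotune.py", TEST_ID]


def run_once(command=COMMAND):
    """Run the test once; return True if the process exited with status 0."""
    result = subprocess.run(command, capture_output=True, text=True)
    return result.returncode == 0


def count_failures(runner, attempts):
    """Call runner() `attempts` times and return how many calls reported failure."""
    return sum(0 if runner() else 1 for _ in range(attempts))
```

For example, `count_failures(run_once, 10)` reports how many of ten runs failed. Passing the runner as an argument keeps the counting logic testable without a GPU.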
