// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"

namespace onnxruntime {{
namespace contrib {{
namespace cuda {{
namespace sparse_attention_v2 {{

// This file is generated by compile_sparse_attention_v2.py; do not edit it by hand.
// {kernel_docstring}
// cubin_size = {bin_size}
// shared_mem_bytes = {shared}
// threads_per_cta = {num_warps} * 32
// kernel_name = {triton_kernel_name}

// Compiled kernel image; passed to cuModuleLoadData by load_{kernel_name}().
const unsigned char {kernel_name}_cubin[] = {{ {bin_data} }};

// Driver handles for the loaded module and its kernel entry point.
// Both are nullptr until load_{kernel_name}() runs.
// unload_{kernel_name}() resets both to nullptr.
CUmodule {kernel_name}_mod = nullptr;
CUfunction {kernel_name}_func = nullptr;

// Unloads the module created by load_{kernel_name}() and clears both handles.
// Calling it again, or without a prior load, is a no-op. This avoids handing a
// stale or null module to cuModuleUnload. It also ensures a later launch does
// not use a CUfunction that belongs to an unloaded module.
void unload_{kernel_name}(void) {{
    if ({kernel_name}_mod == nullptr) {{
      return;
    }}
    const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
    CU_CHECK(driver->cuModuleUnload({kernel_name}_mod), driver);
    {kernel_name}_mod = nullptr;
    {kernel_name}_func = nullptr;
}}

// Loads the embedded cubin into the current CUDA context and resolves the
// kernel entry point "{triton_kernel_name}". If the kernel needs more shared
// memory than the default per-block limit, it also calls SetKernelSharedMemory.
// NOTE(review): calling this twice without an unload overwrites the handles and
// leaks the previous module; callers are expected to pair load/unload.
void load_{kernel_name}(void) {{
    // cuModuleLoadData only reads the image (its parameter is const void*).
    // The cast just adapts the const array to the wrapper's argument.
    void *bin = (void *)&{kernel_name}_cubin;
    const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
    CU_CHECK(driver->cuModuleLoadData(&{kernel_name}_mod, bin), driver);
    CU_CHECK(driver->cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"), driver);

    constexpr int shared = {shared};
    // 48 KB is the default dynamic shared-memory limit per block. Kernels that
    // need more must opt in explicitly before launch.
    // presumably SetKernelSharedMemory performs that opt-in — confirm in sparse_attention_v2_common.h.
    constexpr int kDefaultMaxSharedMemBytes = 48 * 1024;
    if constexpr (shared > kDefaultMaxSharedMemBytes) {{
      SetKernelSharedMemory(driver, {kernel_name}_func);
    }}
}}

// Launches the kernel through SparseAttentionParams::LaunchKernel.
// It uses {num_warps} * 32 threads per CTA and passes {shared} as the
// shared-memory size. load_{kernel_name}() must have run first so that
// {kernel_name}_func is valid.
Status {kernel_name}(SparseAttentionParams& params) {{
    constexpr int threads_per_cta = {num_warps} * 32;
    constexpr int shared_mem_bytes = {shared};
    return params.LaunchKernel({kernel_name}_func, threads_per_cta, shared_mem_bytes);
}}

}}  // namespace sparse_attention_v2
}}  // namespace cuda
}}  // namespace contrib
}}  // namespace onnxruntime
