跳到主要内容

MUSA Tensor Core GEMM 实践手记:把“能跑”变成“能交付”

刚接手一个 GEMM 热点模块时,最常见的状态是:
“知道 Tensor Core 很重要,但不知道该从哪一层开始落地。”

这篇文章不追求“最全 API 手册”,而是按真实工程节奏,把路径拆成 6 步:

  1. 先写一个慢但可解释的 baseline
  2. 用 WMMA 建立 Tensor Core 编程直觉
  3. 用 muBLAS 快速获得稳定实现
  4. 在 muBLASLt 上做算法与 workspace 控制
  5. 在 muDNN 场景里对接算子接口
  6. 用 MUTLASS 做更可控的模板化定制

每一步都保留 example code。文中不含敏感环境信息,也不放性能数据。

开始前:先把“跑不起来”的问题排干净

建议先确认工具链,而不是先写代码。

which mcc
mcc --version

# 以下路径按你的安装位置调整
ls /usr/local/musa/include/musa_runtime.h
ls /usr/local/musa/include/mublas.h
ls /usr/local/musa/lib64/libmusart.so
ls /usr/local/musa/lib64/libmublas.so

如果这些都正常,再开始正文会省很多时间。


一、先讲最容易出错的点:布局

大多数 GEMM 的“结果不对”,根因都不是算子,而是布局解释。

  • row-major: A[i, k] -> i * ld + k
  • col-major: A[i, k] -> k * ld + i

BLAS 默认按 col-major。业务里常见 row-major,这时常用映射:

$C = AB \Leftrightarrow C^T = B^T A^T$

这个映射的价值很实际:不写显式 transpose kernel,也能复用库接口


二、Step 1:先有 baseline(哪怕它很慢)

baseline 的意义不是快,而是“可解释、可对拍、可回归”。

2.1 一个最小 baseline kernel

与文末附录中的基准程序 bench_gemm_2048.mu 一致:布局用 MemLayout,维度在附录里为编译期常量 M_DIM/N_DIM/K_DIM;下面给出通用 M、N、K 参数的写法,便于阅读。

// Logical storage order of a matrix; selects the linear-index formula below.
enum class MemLayout { RowMajor, ColMajor };

// Naive FP16-in / FP32-out GEMM baseline: one thread computes one C(i, j).
// The template parameter L picks the indexing formula for A/B/C, so the same
// kernel serves both logical layouts without moving data.
// Expected launch: 2-D grid/block where x covers columns (N) and y covers
// rows (M); out-of-range threads exit early, so any grid that covers MxN works.
template <MemLayout L>
__global__ void gemm_fp16_scalar_kernel(const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C,
int M, int N, int K) {
int i = (int)(blockIdx.y * blockDim.y + threadIdx.y);
int j = (int)(blockIdx.x * blockDim.x + threadIdx.x);
if (i >= M || j >= N) return;

// Accumulate in FP32 to limit rounding error from the half-precision inputs.
float acc = 0.0f;
#pragma unroll 4
for (int k = 0; k < K; ++k) {
// row-major: A[i,k] -> i*K + k; col-major: A[i,k] -> k*M + i (ld = M).
int a_idx = (L == MemLayout::RowMajor) ? (i * K + k) : (k * M + i);
int b_idx = (L == MemLayout::RowMajor) ? (k * N + j) : (j * K + k);
acc += __half2float(A[a_idx]) * __half2float(B[b_idx]);
}
int c_idx = (L == MemLayout::RowMajor) ? (i * N + j) : (j * M + i);
C[c_idx] = acc;
}

2.2 launch 示例

// Launch config for the scalar baseline: 16x16 threads per block and a
// ceil-divided grid so every (i, j) of the MxN output is covered; the
// in-kernel bounds check handles the ragged edge. Check launch errors at once.
dim3 block(16, 16);
dim3 grid((unsigned)((N + block.x - 1) / block.x), (unsigned)((M + block.y - 1) / block.y), 1u);
gemm_fp16_scalar_kernel<MemLayout::RowMajor><<<grid, block>>>(dA, dB, dC, M, N, K);
CHECK_MUSA(musaGetLastError());

三、Step 2:WMMA 入门,建立 Tensor Core 直觉

WMMA 是“你亲手写 Tensor Core kernel”的入口,适合理解 tile 和 fragment。

3.1 row-major WMMA 示例

MUSA 下请在包含 <mma.h> 后使用 using namespace mtmusa;(与附录中 WMMA 部分一致)。下面内核结构与附录基准相同:用 blockIdx 表示 tile,在 store 前做越界保护(M、N 为 WMMA tile 尺寸的整数倍时,可与附录中固定 2048 的版本完全对齐)。

#include <mma.h>
using namespace mtmusa;

// WMMA fragment shape used throughout this article: 16x16 output tiles,
// advancing 16 along K per mma step.
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// Tensor Core GEMM via WMMA, row-major operands, FP16 inputs / FP32 accum.
// One warp (launched as a 32-thread block) computes one 16x16 tile of C:
// blockIdx.x indexes the tile column, blockIdx.y the tile row.
// NOTE(review): only the final store is guarded; the fragment loads and the
// K loop are not. For M/N/K that are not multiples of 16 the loads read out
// of range -- confirm callers restrict this kernel to padded shapes.
__global__ void wmma_gemm_kernel_rm(const half* __restrict__ A_rm,
const half* __restrict__ B_rm,
float* __restrict__ C_rm,
int M, int N, int K) {
int tile_m = (int)blockIdx.y;
int tile_n = (int)blockIdx.x;
int block_row = tile_m * WMMA_M;
int block_col = tile_n * WMMA_N;

wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

// Zero the accumulator fragment element-wise before the K loop.
for (int i = 0; i < (int)(sizeof(c_frag.x) / sizeof(c_frag.x[0])); ++i) c_frag.x[i] = 0.0f;

// March along K in WMMA_K steps; leading dimensions are K for A_rm and N
// for B_rm because both are row-major.
for (int k0 = 0; k0 < K; k0 += WMMA_K) {
const half* tile_a = A_rm + block_row * K + k0;
const half* tile_b = B_rm + k0 * N + block_col;
wmma::load_matrix_sync(a_frag, tile_a, K);
wmma::load_matrix_sync(b_frag, tile_b, N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// Store only fully in-bounds tiles; partial edge tiles are skipped entirely.
if (block_row + WMMA_M <= M && block_col + WMMA_N <= N) {
wmma::store_matrix_sync(C_rm + block_row * N + block_col, c_frag, N,
wmma::mem_row_major);
}
}

col-major 的 WMMA 内核、以及与 row-major 共用的 launch 分支,见本文附录的 wmma_gemm_kernel_cm 与 launch_wmma。

3.2 launch 示例

// One 16x16 output tile per block, one warp (32 threads) per block.
// The grid is ceil-divided, but tiles that would overhang M or N are
// skipped by the in-kernel store guard rather than computed partially.
dim3 grid_wmma((unsigned)((N + WMMA_N - 1) / WMMA_N),
(unsigned)((M + WMMA_M - 1) / WMMA_M), 1u);
dim3 block_wmma(32u, 1u, 1u); // one warp
wmma_gemm_kernel_rm<<<grid_wmma, block_wmma>>>(dA, dB, dC, M, N, K);
CHECK_MUSA(musaGetLastError());

这一步跑通后,你对 Tensor Core 编程模型会有非常具体的认知。


四、Step 3:muBLAS,先把生产可用性拿到手

真实项目里,纯 GEMM 大多先落在 muBLAS。

4.1 col-major 直接调用

// Abort on any muBLAS failure, reporting file/line and the raw status code.
#define CHECK_MUBLAS(call) do { \
mublasStatus_t st = (call); \
if (st != MUBLAS_STATUS_SUCCESS) { \
std::fprintf(stderr, "muBLAS error %s:%d status=%d\n", __FILE__, __LINE__, (int)st); \
std::exit(EXIT_FAILURE); \
} \
} while (0)

// Direct col-major GEMM: C(MxN) = alpha * A(MxK) * B(KxN) + beta * C.
// FP16 inputs, FP32 output/compute, Tensor Core algorithm requested.
// Leading dimensions equal each matrix's row count because storage is
// col-major (lda = M, ldb = K, ldc = M).
void run_mublas_col_major(mublasHandle_t handle, const half* dA, const half* dB, float* dC,
int M, int N, int K) {
const float alpha = 1.0f;
const float beta = 0.0f;
CHECK_MUBLAS(mublasGemmEx(
handle,
MUBLAS_OP_N, MUBLAS_OP_N,
M, N, K,
&alpha,
dA, MUSA_R_16F, M,
dB, MUSA_R_16F, K,
&beta,
dC, MUSA_R_32F, M,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

4.2 row-major 无拷贝映射调用

// Row-major GEMM with zero data movement: a row-major X (r x c) buffer is
// bit-identical to a col-major X^T (c x r) buffer, so we ask the col-major
// library for C^T(NxM) = B^T(NxK) * A^T(KxM). The result lands in dC_rm
// already laid out as row-major C(MxN). Leading dims are the row lengths of
// the original row-major matrices: N for B^T, K for A^T, N for C^T.
void run_mublas_row_major_map(mublasHandle_t handle, const half* dA_rm, const half* dB_rm,
float* dC_rm, int M, int N, int K) {
// C^T(N, M) = B^T(N, K) * A^T(K, M)
const float alpha = 1.0f;
const float beta = 0.0f;
int m = N, n = M, k = K;
CHECK_MUBLAS(mublasGemmEx(
handle,
MUBLAS_OP_N, MUBLAS_OP_N,
m, n, k,
&alpha,
dB_rm, MUSA_R_16F, N,
dA_rm, MUSA_R_16F, K,
&beta,
dC_rm, MUSA_R_32F, N,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

五、Step 4:muBLASLt,把“可调性”加进来

当你需要更细控制(算法选择、workspace 预算、布局描述)时,就该上 muBLASLt。下面是与附录基准中 col-major 分支一致的要点:MUBLASLT_ORDER_COL 显式设到 A/B/C/D 的 layout 上;row-major 时矩阵形状与 muBLAS 无拷贝映射一致,并在 mublasLtMatmul 时交换 A/B 指针(实现见附录 mublaslt_build_ctx / launch_mublaslt)。

#include <mublasLt.h>

// One-shot muBLASLt col-major GEMM demonstrating the full descriptor flow:
// matmul desc -> matrix layouts (with explicit MUBLASLT_ORDER_COL) ->
// workspace-bounded heuristic algo query -> mublasLtMatmul -> cleanup.
// D aliases C (dC passed for both), so this computes D = alpha*A*B + beta*C
// in place. Caller owns workspace_ptr/workspace_bytes.
// NOTE(review): handle/descriptor creation is inside this function, so it is
// unsuitable for tight timing loops as-is (the appendix splits setup out).
void run_mublaslt_col_major(const half* dA, const half* dB, float* dC,
int M, int N, int K, void* workspace_ptr, size_t workspace_bytes) {
mublasLtHandle_t lt = nullptr;
CHECK_MUBLAS(mublasLtCreate(&lt));

// FP32 compute and FP32 scaling factors; both operands non-transposed.
mublasLtMatmulDesc_t op_desc = nullptr;
CHECK_MUBLAS(mublasLtMatmulDescCreate(&op_desc, MUBLAS_COMPUTE_32F, MUSA_R_32F));
mublasOperation_t trans_n = MUBLAS_OP_N;
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSA, &trans_n, sizeof(trans_n)));
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSB, &trans_n, sizeof(trans_n)));

// Col-major shapes: A(MxK) ld=M, B(KxN) ld=K, C/D(MxN) ld=M.
mublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&a_desc, MUSA_R_16F, M, K, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&b_desc, MUSA_R_16F, K, N, K));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&c_desc, MUSA_R_32F, M, N, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&d_desc, MUSA_R_32F, M, N, M));

// Make the storage order explicit on every layout descriptor.
mublasLtOrder_t order = MUBLASLT_ORDER_COL;
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
a_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
b_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
c_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
d_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));

// Bound the algorithm search by the caller-provided workspace budget.
mublasLtMatmulPreference_t pref = nullptr;
CHECK_MUBLAS(mublasLtMatmulPreferenceCreate(&pref));
CHECK_MUBLAS(mublasLtMatmulPreferenceSetAttribute(
pref, MUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_bytes, sizeof(workspace_bytes)));

// Ask for at most one heuristic candidate; fall back to the library
// default (nullptr algo) if none is returned.
mublasLtMatmulHeuristicResult_t hres{};
int algo_count = 0;
CHECK_MUBLAS(mublasLtMatmulAlgoGetHeuristic(
lt, op_desc, a_desc, b_desc, c_desc, d_desc, pref, 1, &hres, &algo_count));

const float alpha = 1.0f, beta = 0.0f;
CHECK_MUBLAS(mublasLtMatmul(
lt, op_desc, &alpha,
dA, a_desc, dB, b_desc, &beta,
dC, c_desc, dC, d_desc,
(algo_count > 0 ? &hres.algo : nullptr),
workspace_ptr, workspace_bytes, nullptr));

// Tear down in reverse order of creation; errors here are not checked.
mublasLtMatmulPreferenceDestroy(pref);
mublasLtMatrixLayoutDestroy(a_desc);
mublasLtMatrixLayoutDestroy(b_desc);
mublasLtMatrixLayoutDestroy(c_desc);
mublasLtMatrixLayoutDestroy(d_desc);
mublasLtMatmulDescDestroy(op_desc);
mublasLtDestroy(lt);
}

六、Step 5:muDNN MatMul,接入算子式调用

如果你的上层是算子图或算子融合流程,muDNN 往往更自然。row-major 时 stride 取 (ld, 1) 即可。当前部分驱动对 tensor 的 stride 校验较严(附录 main 与 make_tensor 的 MemLayout::ColMajor 分支前有英文注释说明),col-major 物理存储时可在不搬数据的前提下用 $C^\top = B^\top A^\top$ 映射到 MatMul 的 row-major 描述,与附录 bench_one_layout 的 MemLayout::ColMajor 分支一致。

#include <mudnn.h>

// muDNN MatMul on row-major operands: C(MxN) = A(MxK) * B(KxN).
// FP16 inputs, FP32 output, Tensor Core compute mode, alpha=1 / beta=0.
// Row-major 2-D tensors are described via SetNdInfo with stride (ld, 1).
// NOTE(review): ws is allocated from the queried size but never handed to
// op.Run -- confirm whether this Run overload needs an explicit workspace
// argument/allocator; as written the query and allocation look like dead
// code. The GetWorkspaceSize return status is also not checked.
void run_mudnn_row_major(const half* dA, const half* dB, float* dC, int M, int N, int K) {
musa::dnn::Handle handle;
musa::dnn::MatMul op;
op.SetComputeMode(musa::dnn::MatMul::ComputeMode::TENSOR);
op.SetTranspose(false, false);
op.SetAlpha(1.0);
op.SetBeta(0.0);

// Bind device pointers and element types to the tensor descriptors.
musa::dnn::Tensor tA, tB, tC;
tA.SetAddr(dA); tA.SetType(musa::dnn::Tensor::Type::HALF);
tB.SetAddr(dB); tB.SetType(musa::dnn::Tensor::Type::HALF);
tC.SetAddr(dC); tC.SetType(musa::dnn::Tensor::Type::FLOAT);

// Row-major shapes with stride (row length, 1).
int64_t dimA[2] = {M, K}, strideA[2] = {K, 1};
int64_t dimB[2] = {K, N}, strideB[2] = {N, 1};
int64_t dimC[2] = {M, N}, strideC[2] = {N, 1};
tA.SetNdInfo(2, dimA, strideA);
tB.SetNdInfo(2, dimB, strideB);
tC.SetNdInfo(2, dimC, strideC);

size_t ws_bytes = 0;
op.GetWorkspaceSize(handle, ws_bytes, tC, tA, tB);
void* ws = nullptr;
if (ws_bytes > 0) CHECK_MUSA(musaMalloc(&ws, ws_bytes));

auto st = op.Run(handle, tC, tA, tB);
if (st != musa::dnn::Status::SUCCESS) {
std::fprintf(stderr, "muDNN MatMul run failed\n");
std::exit(EXIT_FAILURE);
}
if (ws) CHECK_MUSA(musaFree(ws));
}

七、Step 6:MUTLASS,做模板化可控定制

当你需要“比库更可控,但不想全手写 WMMA 调度”时,MUTLASS 是很实用的中间层。下面类型骨架与附录基准一致:Tile 为 256×128×K,含 Epilogue CollectiveBuilder,alignment 按 16 字节对齐;设备侧完整 Arguments / initialize / run 见附录 launch_mutlass 与 bench_one_layout 中的 MUTLASS 段落。

#include "mute/tensor.hpp"
#include "mutlass/mutlass.h"
#include "mutlass/numeric_types.h"
#include "mutlass/gemm/device/gemm_universal_adapter.h"
#include "mutlass/gemm/collective/collective_builder.hpp"
#include "mutlass/epilogue/collective/collective_builder.hpp"

// All four operands are described as row-major here; the appendix maps
// col-major storage onto these row-major descriptions via C^T = B^T A^T.
using MutlassLayoutA = mutlass::layout::RowMajor;
using MutlassLayoutB = mutlass::layout::RowMajor;
using MutlassLayoutC = mutlass::layout::RowMajor;
using MutlassLayoutD = mutlass::layout::RowMajor;
// FP16 inputs, FP32 source/output and FP32 accumulate/compute.
using MutlassTypeA = mutlass::half_t;
using MutlassTypeB = mutlass::half_t;
using MutlassTypeC = float;
using MutlassTypeD = float;
using MutlassTypeAccumulator = float;
using MutlassTypeCompute = float;
// Target architecture tag and Tensor Core op class.
using MutlassArchTag = mutlass::arch::Mp22;
using MutlassOpClass = mutlass::arch::OpClassTensorOp;
// Threadblock tile: 256x128 output, 32-wide K slice per mainloop stage.
using MutlassTileShape = mute::Shape<mute::_256, mute::_128, mute::_32>;
using MutlassClusterShape = mute::Shape<mute::_1, mute::_1, mute::_1>;
// Alignments expressed in elements, derived from a 16-byte access width.
static constexpr int MutlassAlignmentA = 16 / sizeof(MutlassTypeA);
static constexpr int MutlassAlignmentB = 16 / sizeof(MutlassTypeB);
static constexpr int MutlassAlignmentC = 16 / sizeof(MutlassTypeC);
static constexpr int MutlassAlignmentD = 16 / sizeof(MutlassTypeD);

// Mainloop builder: stage count and kernel schedule left to the library.
using MutlassCollectiveMainloop = typename mutlass::gemm::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTypeA, MutlassLayoutA, MutlassAlignmentA,
MutlassTypeB, MutlassLayoutB, MutlassAlignmentB,
MutlassTypeAccumulator,
MutlassTileShape,
MutlassClusterShape,
mutlass::gemm::collective::StageCountAuto,
mutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

// Epilogue builder: auto tile/schedule, FP32 accumulate and compute types.
using MutlassCollectiveEpilogue = typename mutlass::epilogue::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTileShape,
MutlassClusterShape,
mutlass::epilogue::collective::EpilogueTileAuto,
MutlassTypeAccumulator, MutlassTypeCompute,
MutlassTypeC, MutlassLayoutC, MutlassAlignmentC,
MutlassTypeD, MutlassLayoutD, MutlassAlignmentD,
mutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;

// Compose kernel (problem shape is MxNxKxL as runtime ints) and the
// host-side device adapter used to initialize and run it.
using MutlassGemmKernel = mutlass::gemm::kernel::GemmUniversal<
mute::Shape<int, int, int, int>, MutlassCollectiveMainloop, MutlassCollectiveEpilogue>;
using MutlassGemm = mutlass::gemm::device::GemmUniversalAdapter<MutlassGemmKernel>;

用于筛 tile 配置的 profiler 调用示例(附录中 MutlassTileShape 的注释说明其来源;需在已安装 MUTLASS 工具链的环境中执行):

mutlass_profiler \
--operation=Gemm --providers=mutlass \
--verification-enabled=false \
--m=2048 --n=2048 --k=2048 \
--op_class=tensorop --accum=f32 \
--A=f16:row --B=f16:row --C=f32:row --D=f32:row

八、把多个后端组织到同一 benchmark 框架

这一节的重点不是“谁快”,而是“如何可持续比较”。片段中的 CHECK_MUSA 与第九节「最小可运行 main.cpp」中的宏定义相同(附录开头亦有一份)。

// Average device-side time (ms) of `launch(ctx)` over `iters` runs using
// musaEvent timing. `warmup` untimed runs go first, and a full device sync
// separates warmup from the timed window so earlier work is excluded.
// Only the calls enqueued between the two event records are measured.
static float time_kernel_ms(void (*launch)(void*), void* ctx, int warmup, int iters) {
musaEvent_t start, stop;
CHECK_MUSA(musaEventCreate(&start));
CHECK_MUSA(musaEventCreate(&stop));
for (int i = 0; i < warmup; ++i) launch(ctx);
CHECK_MUSA(musaDeviceSynchronize());
CHECK_MUSA(musaEventRecord(start));
for (int i = 0; i < iters; ++i) launch(ctx);
CHECK_MUSA(musaEventRecord(stop));
CHECK_MUSA(musaEventSynchronize(stop));
float ms = 0.0f;
CHECK_MUSA(musaEventElapsedTime(&ms, start, stop));
CHECK_MUSA(musaEventDestroy(start));
CHECK_MUSA(musaEventDestroy(stop));
return ms / iters;
}

spot-check 示例(与附录中 check_close 相同思想:按布局取线性下标,采样若干点比全量对拍更轻;MemLayout 与第二节一致,若已定义可删 enum):

// Logical storage order; determines the linear-index formula below.
enum class MemLayout { RowMajor, ColMajor };

// Sample a handful of fixed (i, j) positions in two host-side result
// buffers and print the maximum absolute difference. Sampling keeps
// cross-backend spot checks cheap; points outside MxN are skipped.
static void check_close(const float* hC0, const float* hC1, MemLayout layout,
int M, int N, const char* name0, const char* name1) {
static const int kPoints[][2] = {{0, 0}, {1, 2}, {31, 7}, {127, 511}};
const int num_points = (int)(sizeof(kPoints) / sizeof(kPoints[0]));
float max_err = 0.0f;
for (int p = 0; p < num_points; ++p) {
const int i = kPoints[p][0];
const int j = kPoints[p][1];
if (i >= M || j >= N) continue;
const int idx = (layout == MemLayout::RowMajor) ? (i * N + j) : (j * M + i);
const float err = std::fabs(hC0[idx] - hC1[idx]);
if (err > max_err) max_err = err;
}
std::printf("Spot-check %s vs %s: max_abs_err=%e\n", name0, name1, max_err);
}

九、一份最小可运行 main.cpp(先跑通再扩展)

如果你只想先验证 muBLAS 链路,下面这份最小程序可直接编译运行。多后端统一计时、双布局、muBLASLt / muDNN / MUTLASS / WMMA / baseline 的完整实现见文末附录(将附录保存为 bench_gemm_2048.mu 后按附录小节的编译说明链接即可)。

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>

#include <musa_runtime.h>
#include <musa_fp16.h>
#include <mublas.h>

// Abort on any MUSA runtime failure, printing file/line and the driver's
// human-readable error string.
#define CHECK_MUSA(call) do { \
musaError_t err = (call); \
if (err != musaSuccess) { \
std::fprintf(stderr, "MUSA error %s:%d -> %s\n", __FILE__, __LINE__, musaGetErrorString(err)); \
std::exit(EXIT_FAILURE); \
} \
} while (0)

// Abort on any muBLAS failure, printing file/line and the raw status code.
#define CHECK_MUBLAS(call) do { \
mublasStatus_t st = (call); \
if (st != MUBLAS_STATUS_SUCCESS) { \
std::fprintf(stderr, "muBLAS error %s:%d status=%d\n", __FILE__, __LINE__, (int)st); \
std::exit(EXIT_FAILURE); \
} \
} while (0)

// Thin host-side wrappers over the MUSA half<->float conversion intrinsics.
static inline __half f32_to_half(float x) { return __float2half(x); }
static inline float half_to_f32(__half x) { return __half2float(x); }

// Minimal end-to-end muBLAS check: fill 64x64x64 col-major FP16 operands
// with random values, run one Tensor Core GemmEx, then validate against a
// host FP32 reference and report the max absolute error.
int main() {
constexpr int M = 64, N = 64, K = 64; // col-major
std::vector<__half> hA(M * K), hB(K * N);
std::vector<float> hC(M * N, 0.0f), hRef(M * N, 0.0f);

// Fixed seed so the run (and its error bound) is reproducible.
std::mt19937 rng(123);
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
for (auto& x : hA) x = f32_to_half(dist(rng));
for (auto& x : hB) x = f32_to_half(dist(rng));

__half* dA = nullptr;
__half* dB = nullptr;
float* dC = nullptr;
CHECK_MUSA(musaMalloc(&dA, sizeof(__half) * hA.size()));
CHECK_MUSA(musaMalloc(&dB, sizeof(__half) * hB.size()));
CHECK_MUSA(musaMalloc(&dC, sizeof(float) * hC.size()));

CHECK_MUSA(musaMemcpy(dA, hA.data(), sizeof(__half) * hA.size(), musaMemcpyHostToDevice));
CHECK_MUSA(musaMemcpy(dB, hB.data(), sizeof(__half) * hB.size(), musaMemcpyHostToDevice));
CHECK_MUSA(musaMemset(dC, 0, sizeof(float) * hC.size()));

mublasHandle_t handle = nullptr;
CHECK_MUBLAS(mublasCreate(&handle));

// Col-major C = A * B: lda = M, ldb = K, ldc = M; beta = 0 so the
// memset above is not strictly required for correctness.
const float alpha = 1.0f, beta = 0.0f;
CHECK_MUBLAS(mublasGemmEx(
handle, MUBLAS_OP_N, MUBLAS_OP_N,
M, N, K, &alpha,
dA, MUSA_R_16F, M,
dB, MUSA_R_16F, K,
&beta, dC, MUSA_R_32F, M,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));

// Blocking copy also synchronizes with the GEMM on the default stream.
CHECK_MUSA(musaMemcpy(hC.data(), dC, sizeof(float) * hC.size(), musaMemcpyDeviceToHost));

// Host FP32 reference, col-major indexing: A(i,k) at i + k*M, B(k,j) at
// k + j*K, C(i,j) at i + j*M.
for (int col = 0; col < N; ++col) {
for (int row = 0; row < M; ++row) {
float acc = 0.0f;
for (int kk = 0; kk < K; ++kk) {
acc += half_to_f32(hA[row + kk * M]) * half_to_f32(hB[kk + col * K]);
}
hRef[row + col * M] = acc;
}
}

// Tolerance accounts for FP16 input rounding; 1e-2 is loose on purpose.
float max_abs_err = 0.0f;
for (size_t i = 0; i < hC.size(); ++i) {
max_abs_err = std::max(max_abs_err, std::fabs(hC[i] - hRef[i]));
}
std::cout << "max_abs_err = " << max_abs_err << "\n";
std::cout << ((max_abs_err < 1e-2f) ? "PASS" : "WARN") << "\n";

CHECK_MUBLAS(mublasDestroy(handle));
CHECK_MUSA(musaFree(dA));
CHECK_MUSA(musaFree(dB));
CHECK_MUSA(musaFree(dC));
return 0;
}

编译命令(本机可用范式):

mcc -x musa -O2 main.cpp -o gemm_mublas_demo -lmublas -lmusart
./gemm_mublas_demo

十、几个高频报错,直接给修复命令

1) musa_runtime.h not found

mcc -x musa -O2 main.cpp -o demo \
-I/usr/local/musa/include -L/usr/local/musa/lib64 -lmublas -lmusart

2) unknown type name '__half'

#include <musa_fp16.h>

3) undefined reference to musaGetErrorString

# 链接阶段补 musart
mcc -x musa -O2 main.cpp -o demo -lmublas -lmusart

结语

把 Tensor Core 落地成“可交付工程”的关键,不是单点极限优化,而是把以下三件事固定下来:

  • 布局映射写清楚(尤其 row-major 到 col-major 的解释)
  • 正确性基线长期保留(baseline + spot-check)
  • 多后端统一计时与统一入口(便于持续演进)

按这个节奏推进,后续不管你要切到 muBLASLt、muDNN 还是 MUTLASS,成本都会明显更可控。

附录:bench_gemm_2048.mu(完整可运行源码)

以下是一份自包含的基准源文件:固定 M=N=K=2048,在 row-major / col-major 两种布局下对 muBLAS、muBLASLt、muDNN、MUTLASS、WMMA 与手写 FP16 baseline 做统一计时与 spot-check。将代码保存为 bench_gemm_2048.mu 后,用 mcc 编译;除 MUSA 运行时与 muBLAS 外,还需链接 muBLASLt 与 muDNN,并为 MUTLASS 和 mute 头文件增加 -I(具体路径随 SDK 安装前缀而定,与你本机其他 MUSA 示例工程中的写法对齐即可)。下面给出结构示例(请把路径改成你机器上的实际目录):

mcc -x musa -O2 bench_gemm_2048.mu -o bench_gemm_2048 \
-I/usr/local/musa/include \
-I/path/to/mutlass/include \
-I/path/to/mute/include \
-L/usr/local/musa/lib64 \
-lmublas -lmublasLt -lmudnn -lmusart

若你手头已有带该目标的 Makefile,可直接沿用其中的 include / link 行;仅阅读本文时,以上命令足以理解依赖关系,不必访问任何外部网页或仓库链接。

/*
* Benchmark fixed M=N=K=2048 GEMM on MUSA:
* - muBLAS (mublasGemmEx)
* - muBLASLt (mublasLtMatmul + heuristic algo)
* - muDNN MatMul (musa::dnn::MatMul)
* - MUTLASS GEMM (FP16 input, FP32 accumulate/output)
* - WMMA GEMM (FP16 input, FP32 accumulate/output)
* - Handwritten non-TC GEMM (FP16 input, FP32 output)
*
* For fairness:
* - Measure only the compute call with musaEvent timing (avg over iters)
* - One-time setup (handle creation, tensor desc, workspace alloc) is excluded
* - Compare both logical layouts: row-major and col-major
*
* Build: see the prose above this code block (mcc + include/lib paths).
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>

#include <musa_runtime.h>
#include <mma.h>

#include <mublas.h>
#include <mublasLt.h>
#include <mudnn.h>

#include "mute/tensor.hpp"
#include "mutlass/mutlass.h"
#include "mutlass/numeric_types.h"
#include "mutlass/gemm/device/gemm_universal_adapter.h"
#include "mutlass/gemm/collective/collective_builder.hpp"
#include "mutlass/epilogue/collective/collective_builder.hpp"
#include "mutlass/util/device_memory.h"
#include "mutlass/util/packed_stride.hpp"

using namespace mtmusa;

// Abort on any MUSA runtime failure, printing file/line and the driver's
// human-readable error string.
#define CHECK_MUSA(call) \
do { \
musaError_t err = call; \
if (err != musaSuccess) { \
fprintf(stderr, "MUSA Error at %s:%d - %s\n", __FILE__, __LINE__, \
musaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
} while (0)

// Abort on any muBLAS failure, printing file/line and the raw status code.
#define CHECK_MUBLAS(call) \
do { \
mublasStatus_t st = call; \
if (st != MUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "muBLAS Error at %s:%d - status=%d\n", __FILE__, \
__LINE__, (int)st); \
exit(EXIT_FAILURE); \
} \
} while (0)

// Fixed benchmark problem size: square GEMM with M = N = K = 2048.
static const int M_DIM = 2048;
static const int N_DIM = 2048;
static const int K_DIM = 2048;

// Convert an average kernel time (milliseconds) into GFLOP/s for the fixed
// 2048^3 GEMM; each output element costs one multiply plus one add.
static inline float gflops(float ms) {
double flop_count = 2.0 * (double)M_DIM * (double)N_DIM * (double)K_DIM;
double seconds = ms / 1000.0;
return (float)(flop_count / seconds / 1e9);
}

// Logical storage order of the matrices used throughout the benchmark.
enum class MemLayout { RowMajor, ColMajor };

// Human-readable label for a MemLayout, used in report lines.
static const char* layout_name(MemLayout l) {
if (l == MemLayout::RowMajor) {
return "row-major";
}
return "col-major";
}

// Fill A (MxK) and B (KxN) with deterministic, bounded FP16 values so every
// backend sees identical inputs and FP32 accumulation cannot overflow.
// The requested layout changes only the linear index, never the value at a
// given logical coordinate.
static void fill_fp16(MemLayout layout, half* A, half* B) {
const bool row_major = (layout == MemLayout::RowMajor);
for (int i = 0; i < M_DIM; ++i) {
for (int k = 0; k < K_DIM; ++k) {
float value = ((i * 131 + k * 17) % 97) * 0.01f;
int idx = row_major ? (i * K_DIM + k) : (k * M_DIM + i);
A[idx] = __float2half(value);
}
}
for (int k = 0; k < K_DIM; ++k) {
for (int j = 0; j < N_DIM; ++j) {
float value = ((k * 29 + j * 43) % 89) * 0.01f;
int idx = row_major ? (k * N_DIM + j) : (j * K_DIM + k);
B[idx] = __float2half(value);
}
}
}

// Reset an FP32 output buffer of n elements to all-zero between runs.
static void zero_fp32(float* C, size_t n) {
memset(C, 0, sizeof(float) * n);
}

// Average device-side time (ms) of `launch(ctx)` over `iters` runs using
// musaEvent timing. `warmup` untimed runs go first, and a device-wide sync
// separates warmup from the timed window so earlier work is excluded.
static float time_kernel_ms(void (*launch)(void*), void* ctx, int warmup, int iters) {
musaEvent_t start, stop;
CHECK_MUSA(musaEventCreate(&start));
CHECK_MUSA(musaEventCreate(&stop));

for (int i = 0; i < warmup; ++i) launch(ctx);
CHECK_MUSA(musaDeviceSynchronize());

// Only the work enqueued between the two event records is measured.
CHECK_MUSA(musaEventRecord(start));
for (int i = 0; i < iters; ++i) launch(ctx);
CHECK_MUSA(musaEventRecord(stop));
CHECK_MUSA(musaEventSynchronize(stop));

float ms = 0.0f;
CHECK_MUSA(musaEventElapsedTime(&ms, start, stop));
CHECK_MUSA(musaEventDestroy(start));
CHECK_MUSA(musaEventDestroy(stop));
return ms / iters;
}

// ---------------- WMMA GEMM (global memory load/store) ----------------

// WMMA fragment shape: 16x16 output tiles, advancing 16 along K per step.
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// Tensor Core GEMM via WMMA, row-major operands, FP16 in / FP32 accumulate.
// One warp (32-thread block) computes one 16x16 tile of C; blockIdx.x is
// the tile column, blockIdx.y the tile row. With M_DIM/N_DIM/K_DIM = 2048
// (multiples of 16) every tile is full, so the store guard always passes.
__global__ void wmma_gemm_kernel_rm(const half* __restrict__ A_rm,
const half* __restrict__ B_rm,
float* __restrict__ C_rm) {
int tile_m = (int)blockIdx.y;
int tile_n = (int)blockIdx.x;
int block_row = tile_m * WMMA_M;
int block_col = tile_n * WMMA_N;

wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

// Zero the accumulator fragment before the K loop.
for (int i = 0; i < (int)(sizeof(c_frag.x) / sizeof(c_frag.x[0])); ++i) c_frag.x[i] = 0.0f;

// Row-major leading dimensions: K_DIM for A, N_DIM for B.
for (int k0 = 0; k0 < K_DIM; k0 += WMMA_K) {
const half* tile_a = A_rm + block_row * K_DIM + k0;
const half* tile_b = B_rm + k0 * N_DIM + block_col;
wmma::load_matrix_sync(a_frag, tile_a, K_DIM);
wmma::load_matrix_sync(b_frag, tile_b, N_DIM);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

if (block_row + WMMA_M <= M_DIM && block_col + WMMA_N <= N_DIM) {
wmma::store_matrix_sync(C_rm + block_row * N_DIM + block_col, c_frag, N_DIM,
wmma::mem_row_major);
}
}

// Col-major twin of wmma_gemm_kernel_rm: same tiling (one warp per 16x16
// output tile), but fragments and pointer arithmetic use col-major leading
// dimensions (M_DIM for A, K_DIM for B, M_DIM for C).
__global__ void wmma_gemm_kernel_cm(const half* __restrict__ A_cm,
const half* __restrict__ B_cm,
float* __restrict__ C_cm) {
int tile_m = (int)blockIdx.y;
int tile_n = (int)blockIdx.x;
int block_row = tile_m * WMMA_M;
int block_col = tile_n * WMMA_N;

// In col-major memory, a matrix of shape (M,K) is stored with leading dim lda=M.
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::col_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

// Zero the accumulator fragment before the K loop.
for (int i = 0; i < (int)(sizeof(c_frag.x) / sizeof(c_frag.x[0])); ++i) c_frag.x[i] = 0.0f;

for (int k0 = 0; k0 < K_DIM; k0 += WMMA_K) {
const half* tile_a = A_cm + block_row + k0 * M_DIM; // A(i,k) at k*M + i
const half* tile_b = B_cm + k0 + block_col * K_DIM; // B(k,j) at j*K + k
wmma::load_matrix_sync(a_frag, tile_a, M_DIM);
wmma::load_matrix_sync(b_frag, tile_b, K_DIM);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// Fixed 2048 dims are multiples of 16, so this guard always passes here.
if (block_row + WMMA_M <= M_DIM && block_col + WMMA_N <= N_DIM) {
wmma::store_matrix_sync(C_cm + block_row + block_col * M_DIM, c_frag, M_DIM,
wmma::mem_col_major);
}
}

// ---------------- Handwritten non-TC GEMM (FP16 -> FP32) ----------------

// Naive non-Tensor-Core baseline: one thread computes one C(i, j) over the
// fixed 2048^3 problem. Template parameter L selects the indexing formula
// so the same kernel serves both logical layouts without data movement.
template <MemLayout L>
__global__ void gemm_fp16_scalar_kernel(const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C) {
int i = (int)(blockIdx.y * blockDim.y + threadIdx.y);
int j = (int)(blockIdx.x * blockDim.x + threadIdx.x);
if (i >= M_DIM || j >= N_DIM) return;

// Accumulate in FP32 to limit rounding error from half inputs.
float acc = 0.0f;
#pragma unroll 4
for (int k = 0; k < K_DIM; ++k) {
// row-major: i*K + k; col-major: k*M + i (ld = M). Same idea for B/C.
int a_idx = (L == MemLayout::RowMajor) ? (i * K_DIM + k) : (k * M_DIM + i);
int b_idx = (L == MemLayout::RowMajor) ? (k * N_DIM + j) : (j * K_DIM + k);
acc += __half2float(A[a_idx]) * __half2float(B[b_idx]);
}
int c_idx = (L == MemLayout::RowMajor) ? (i * N_DIM + j) : (j * M_DIM + i);
C[c_idx] = acc;
}

// ---------------- muBLAS GEMMEx wrapper ----------------

// State bundle passed to launch_mublas through the void* timing callback.
struct MublasCtx {
mublasHandle_t handle; // created once outside the timed region
MemLayout layout; // selects direct vs. transposed-mapping path
const half* dA; // device A (MxK in `layout`)
const half* dB; // device B (KxN in `layout`)
float* dC; // device C (MxN in `layout`)
};

// Timed muBLAS GemmEx call. Col-major runs the GEMM directly; row-major
// reuses the same col-major entry point via the no-copy C^T = B^T A^T
// mapping (operands swapped, leading dims = row lengths of the row-major
// buffers). FP16 inputs, FP32 output, Tensor Core algorithm requested.
static void launch_mublas(void* p) {
auto* ctx = (MublasCtx*)p;
const float alpha = 1.0f;
const float beta = 0.0f;

if (ctx->layout == MemLayout::ColMajor) {
// Column-major direct: C(MxN) = A(MxK) * B(KxN)
int lda = M_DIM;
int ldb = K_DIM;
int ldc = M_DIM;
CHECK_MUBLAS(mublasGemmEx(ctx->handle,
MUBLAS_OP_N, MUBLAS_OP_N,
M_DIM, N_DIM, K_DIM,
&alpha,
ctx->dA, MUSA_R_16F, lda,
ctx->dB, MUSA_R_16F, ldb,
&beta,
ctx->dC, MUSA_R_32F, ldc,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// Row-major mapping using column-major GEMM without data transform:
// If A_rm is row-major (MxK), treat it as col-major (KxM) of A^T.
// If B_rm is row-major (KxN), treat it as col-major (NxK) of B^T.
// Compute C_rm as row-major (MxN) which is col-major (NxM) of C^T.
// So compute C^T(NxM) = B^T(NxK) * A^T(KxM) in column-major.
int m = N_DIM; // rows of C^T
int n = M_DIM; // cols of C^T
int k = K_DIM;
int lda = N_DIM; // B^T is (N x K) col-major, leading dim = N
int ldb = K_DIM; // A^T is (K x M) col-major, leading dim = K
int ldc = N_DIM; // C^T is (N x M) col-major, leading dim = N
CHECK_MUBLAS(mublasGemmEx(ctx->handle,
MUBLAS_OP_N, MUBLAS_OP_N,
m, n, k,
&alpha,
ctx->dB, MUSA_R_16F, lda,
ctx->dA, MUSA_R_16F, ldb,
&beta,
ctx->dC, MUSA_R_32F, ldc,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}

// ---------------- muBLASLt Matmul wrapper ----------------

// State bundle for the muBLASLt path. Descriptors, the heuristic algo and
// the workspace are built once in mublaslt_build_ctx so the timed call in
// launch_mublaslt does only the matmul itself.
struct MublasLtCtx {
mublasLtHandle_t lt_handle;
MemLayout layout;
const half* dA;
const half* dB;
float* dC;
mublasLtMatmulDesc_t op_desc;
mublasLtMatrixLayout_t a_desc;
mublasLtMatrixLayout_t b_desc;
mublasLtMatrixLayout_t c_desc;
mublasLtMatrixLayout_t d_desc;
mublasLtMatmulAlgo_t algo; // valid only when has_algo is true
bool has_algo;
void* workspace; // device buffer sized by the heuristic
size_t workspace_bytes;
};

// Release every descriptor and the workspace owned by a MublasLtCtx and
// null the pointers so a double call is harmless.
// NOTE(review): has_algo is intentionally(?) not reset here -- confirm the
// ctx is not reused for launches after destroy.
static void mublaslt_destroy_ctx(MublasLtCtx* ctx) {
if (ctx->a_desc) CHECK_MUBLAS(mublasLtMatrixLayoutDestroy(ctx->a_desc));
if (ctx->b_desc) CHECK_MUBLAS(mublasLtMatrixLayoutDestroy(ctx->b_desc));
if (ctx->c_desc) CHECK_MUBLAS(mublasLtMatrixLayoutDestroy(ctx->c_desc));
if (ctx->d_desc) CHECK_MUBLAS(mublasLtMatrixLayoutDestroy(ctx->d_desc));
if (ctx->op_desc) CHECK_MUBLAS(mublasLtMatmulDescDestroy(ctx->op_desc));
if (ctx->workspace) CHECK_MUSA(musaFree(ctx->workspace));
ctx->a_desc = nullptr;
ctx->b_desc = nullptr;
ctx->c_desc = nullptr;
ctx->d_desc = nullptr;
ctx->op_desc = nullptr;
ctx->workspace = nullptr;
ctx->workspace_bytes = 0;
}

// One-time setup for the muBLASLt path: create matmul/layout descriptors
// matching ctx->layout, then query one heuristic algorithm under a 64 MiB
// workspace cap and allocate the workspace it asks for. If no algorithm is
// returned, has_algo stays false and launch_mublaslt passes a null algo.
static void mublaslt_build_ctx(MublasLtCtx* ctx) {
memset(&ctx->algo, 0, sizeof(ctx->algo));
ctx->has_algo = false;
ctx->workspace = nullptr;
ctx->workspace_bytes = 0;
ctx->op_desc = nullptr;
ctx->a_desc = nullptr;
ctx->b_desc = nullptr;
ctx->c_desc = nullptr;
ctx->d_desc = nullptr;

// FP32 compute and FP32 scaling; both operands non-transposed.
CHECK_MUBLAS(mublasLtMatmulDescCreate(&ctx->op_desc, MUBLAS_COMPUTE_32F, MUSA_R_32F));

mublasOperation_t trans_n = MUBLAS_OP_N;
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(
ctx->op_desc, MUBLASLT_MATMUL_DESC_TRANSA, &trans_n, sizeof(trans_n)));
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(
ctx->op_desc, MUBLASLT_MATMUL_DESC_TRANSB, &trans_n, sizeof(trans_n)));

if (ctx->layout == MemLayout::ColMajor) {
// Direct col-major shapes: A(MxK) ld=M, B(KxN) ld=K, C/D(MxN) ld=M.
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->a_desc, MUSA_R_16F, M_DIM, K_DIM, M_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->b_desc, MUSA_R_16F, K_DIM, N_DIM, K_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->c_desc, MUSA_R_32F, M_DIM, N_DIM, M_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->d_desc, MUSA_R_32F, M_DIM, N_DIM, M_DIM));
mublasLtOrder_t order = MUBLASLT_ORDER_COL;
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->a_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->b_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->c_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->d_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
} else {
// Match muBLAS row-major mapping:
// C^T(NxM) = B^T(NxK) * A^T(KxM), all interpreted in column-major.
// Note a_desc therefore describes B^T; launch_mublaslt swaps the
// device pointers to match.
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->a_desc, MUSA_R_16F, N_DIM, K_DIM, N_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->b_desc, MUSA_R_16F, K_DIM, M_DIM, K_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->c_desc, MUSA_R_32F, N_DIM, M_DIM, N_DIM));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&ctx->d_desc, MUSA_R_32F, N_DIM, M_DIM, N_DIM));
mublasLtOrder_t order = MUBLASLT_ORDER_COL;
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->a_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->b_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->c_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
ctx->d_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
}

// Cap the heuristic search at 64 MiB of workspace.
mublasLtMatmulPreference_t pref = nullptr;
CHECK_MUBLAS(mublasLtMatmulPreferenceCreate(&pref));
size_t max_workspace = 64ull * 1024ull * 1024ull;
CHECK_MUBLAS(mublasLtMatmulPreferenceSetAttribute(
pref, MUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace, sizeof(max_workspace)));

// Ask for at most one candidate algorithm.
mublasLtMatmulHeuristicResult_t heuristic;
memset(&heuristic, 0, sizeof(heuristic));
int algo_count = 0;
CHECK_MUBLAS(mublasLtMatmulAlgoGetHeuristic(ctx->lt_handle,
ctx->op_desc,
ctx->a_desc,
ctx->b_desc,
ctx->c_desc,
ctx->d_desc,
pref,
1,
&heuristic,
&algo_count));
CHECK_MUBLAS(mublasLtMatmulPreferenceDestroy(pref));

// Keep the algo and pre-allocate its workspace outside the timed region.
if (algo_count > 0 && heuristic.state == MUBLAS_STATUS_SUCCESS) {
ctx->algo = heuristic.algo;
ctx->has_algo = true;
ctx->workspace_bytes = heuristic.workspaceSize;
if (ctx->workspace_bytes > 0) CHECK_MUSA(musaMalloc(&ctx->workspace, ctx->workspace_bytes));
}
}

// Timed muBLASLt matmul. In the row-major case a_desc/b_desc were built for
// B^T/A^T (see mublaslt_build_ctx), so the device pointers are swapped here
// to match. D aliases C (dC passed twice); stream is the default (nullptr).
static void launch_mublaslt(void* p) {
auto* ctx = (MublasLtCtx*)p;
const float alpha = 1.0f;
const float beta = 0.0f;
const void* A_ptr = (ctx->layout == MemLayout::RowMajor) ? (const void*)ctx->dB : (const void*)ctx->dA;
const void* B_ptr = (ctx->layout == MemLayout::RowMajor) ? (const void*)ctx->dA : (const void*)ctx->dB;
CHECK_MUBLAS(mublasLtMatmul(ctx->lt_handle,
ctx->op_desc,
&alpha,
A_ptr,
ctx->a_desc,
B_ptr,
ctx->b_desc,
&beta,
ctx->dC,
ctx->c_desc,
ctx->dC,
ctx->d_desc,
ctx->has_algo ? &ctx->algo : nullptr,
ctx->workspace,
ctx->workspace_bytes,
nullptr));
}

// ---------------- muDNN MatMul wrapper ----------------

// State bundle for the muDNN path: pre-built handle, MatMul op and tensor
// descriptors so the timed callback only calls Run.
struct MudnnCtx {
musa::dnn::Handle* handle;
musa::dnn::MatMul* op;
musa::dnn::Tensor* tA;
musa::dnn::Tensor* tB;
musa::dnn::Tensor* tC;
MemLayout layout;
};

// Timed muDNN MatMul call: C = op(A, B) using the pre-built descriptors.
// Aborts with the raw status code on failure.
static void launch_mudnn(void* p) {
auto* ctx = (MudnnCtx*)p;
auto st = ctx->op->Run(*ctx->handle, *ctx->tC, *ctx->tA, *ctx->tB);
if (st != musa::dnn::Status::SUCCESS) {
fprintf(stderr, "muDNN MatMul Run failed, status=%d\n", (int)st);
exit(EXIT_FAILURE);
}
}

// ---------------- MUTLASS GEMM wrapper (tutorial style) ----------------

using MutlassLayoutA = mutlass::layout::RowMajor;
using MutlassLayoutB = mutlass::layout::RowMajor;
using MutlassLayoutC = mutlass::layout::RowMajor;
using MutlassLayoutD = mutlass::layout::RowMajor;
using MutlassTypeA = mutlass::half_t;
using MutlassTypeB = mutlass::half_t;
using MutlassTypeC = float;
using MutlassTypeD = float;
using MutlassTypeAccumulator = float;
using MutlassTypeCompute = float;
using MutlassArchTag = mutlass::arch::Mp22;
using MutlassOpClass = mutlass::arch::OpClassTensorOp;
// Tuned from mutlass_profiler for M=N=K=2048 FP16->FP32 on MP22.
using MutlassTileShape = mute::Shape<mute::_256, mute::_128, mute::_32>;
using MutlassClusterShape = mute::Shape<mute::_1, mute::_1, mute::_1>;
static constexpr int MutlassAlignmentA = 16 / sizeof(MutlassTypeA);
static constexpr int MutlassAlignmentB = 16 / sizeof(MutlassTypeB);
static constexpr int MutlassAlignmentC = 16 / sizeof(MutlassTypeC);
static constexpr int MutlassAlignmentD = 16 / sizeof(MutlassTypeD);

using MutlassCollectiveMainloop = typename mutlass::gemm::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTypeA, MutlassLayoutA, MutlassAlignmentA,
MutlassTypeB, MutlassLayoutB, MutlassAlignmentB,
MutlassTypeAccumulator,
MutlassTileShape,
MutlassClusterShape,
mutlass::gemm::collective::StageCountAuto,
mutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

// Epilogue collective: FP32 accumulate/compute writing FP32 C/D. The default
// linear-combination epilogue is parameterized by the alpha/beta passed in
// the Arguments built in main(); tile and schedule are chosen automatically.
using MutlassCollectiveEpilogue = typename mutlass::epilogue::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTileShape,
MutlassClusterShape,
mutlass::epilogue::collective::EpilogueTileAuto,
MutlassTypeAccumulator, MutlassTypeCompute,
MutlassTypeC, MutlassLayoutC, MutlassAlignmentC,
MutlassTypeD, MutlassLayoutD, MutlassAlignmentD,
mutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;

// Assemble the kernel — the problem shape is (M, N, K, batch); main() passes
// batch = 1 — and wrap it in the device-side adapter that exposes the
// can_implement / initialize / run interface used below.
using MutlassGemmKernel = mutlass::gemm::kernel::GemmUniversal<
mute::Shape<int, int, int, int>, MutlassCollectiveMainloop, MutlassCollectiveEpilogue>;
using MutlassGemm = mutlass::gemm::device::GemmUniversalAdapter<MutlassGemmKernel>;

// Abort the benchmark with a human-readable diagnostic when a MUTLASS call
// (can_implement / initialize / run) does not report kSuccess.
static inline void check_mutlass(mutlass::Status st, const char* where) {
  if (st == mutlass::Status::kSuccess) return;
  fprintf(stderr, "MUTLASS Error at %s: %s\n", where, mutlass::mutlassGetStatusString(st));
  exit(EXIT_FAILURE);
}

// Context for the MUTLASS timed launcher. Brace-initialized positionally in
// main(), so the member order must not change. Only `gemm` is read by
// launch_mutlass(); the remaining pointers keep the operands, arguments and
// workspace associated with this run visible for the timing duration.
struct MutlassCtx {
MemLayout layout;
const half* dA;
const half* dB;
float* dC;
MutlassGemm* gemm; // fully initialized in main(); launcher only calls run()
typename MutlassGemm::Arguments* arguments;
mutlass::device_memory::allocation<uint8_t>* workspace;
};

// Timed launcher for the MUTLASS path. The GEMM object was validated and
// initialized once in main() (can_implement / initialize), so only run()
// itself is inside the timed region.
static void launch_mutlass(void* p) {
  MutlassCtx* c = static_cast<MutlassCtx*>(p);
  mutlass::Status st = c->gemm->run();
  check_mutlass(st, "run");
  CHECK_MUSA(musaGetLastError());
}

// ---------------- WMMA launchers ----------------

// Context for the WMMA timed launcher (brace-initialized positionally in
// main(); member order must not change).
struct WmmaCtx {
MemLayout layout; // selects the row-major or col-major kernel variant
const half* dA;
const half* dB;
float* dC;
};

// Timed launcher for the WMMA kernels: one warp-sized block per output tile.
// Assumes M_DIM and N_DIM are exact multiples of WMMA_M / WMMA_N (true for
// the fixed 2048^3 problem), so no tail blocks are launched.
static void launch_wmma(void* p) {
  auto* c = static_cast<WmmaCtx*>(p);
  const unsigned tiles_x = (unsigned int)(N_DIM / WMMA_N);
  const unsigned tiles_y = (unsigned int)(M_DIM / WMMA_M);
  dim3 block(32u, 1u, 1u); // exactly one warp per block
  dim3 grid(tiles_x, tiles_y, 1u);
  if (c->layout == MemLayout::RowMajor) {
    wmma_gemm_kernel_rm<<<grid, block>>>(c->dA, c->dB, c->dC);
  } else {
    wmma_gemm_kernel_cm<<<grid, block>>>(c->dA, c->dB, c->dC);
  }
  CHECK_MUSA(musaGetLastError());
}

// ---------------- Handwritten launchers ----------------

// Context for the handwritten scalar-kernel launcher (brace-initialized
// positionally in main(); member order must not change).
struct HandCtx {
MemLayout layout; // selects the template instantiation of the scalar kernel
const half* dA;
const half* dB;
float* dC;
};

// Timed launcher for the scalar baseline kernel: one thread per C element,
// 16x16 thread blocks, ceil-divided grid (the kernel bounds-checks the tail).
static void launch_hand(void* p) {
  auto* c = static_cast<HandCtx*>(p);
  constexpr int kTile = 16;
  dim3 block((unsigned int)kTile, (unsigned int)kTile, 1u);
  dim3 grid((unsigned int)((N_DIM + kTile - 1) / kTile),
            (unsigned int)((M_DIM + kTile - 1) / kTile), 1u);
  if (c->layout == MemLayout::RowMajor) {
    gemm_fp16_scalar_kernel<MemLayout::RowMajor><<<grid, block>>>(c->dA, c->dB, c->dC);
  } else {
    gemm_fp16_scalar_kernel<MemLayout::ColMajor><<<grid, block>>>(c->dA, c->dB, c->dC);
  }
  CHECK_MUSA(musaGetLastError());
}

// Print identifying information (name, compute capability, SM count) about
// the currently active MUSA device.
static void print_device() {
  int dev = 0;
  CHECK_MUSA(musaGetDevice(&dev));
  musaDeviceProp props;
  CHECK_MUSA(musaGetDeviceProperties(&props, dev));
  printf("========================================\n");
  printf(" Device: %s\n", props.name);
  printf(" CC: %d.%d, SM: %d\n", props.major, props.minor, props.multiProcessorCount);
  printf("========================================\n");
}

// Spot-compare two result matrices at a fixed set of coordinates and print
// the maximum absolute difference. Sparse sampling avoids full-matrix
// comparison overhead in reporting; with deterministic inputs it still
// catches gross layout/transpose mistakes.
static void check_close(const float* hC0, const float* hC1, MemLayout layout, const char* name0,
const char* name1) {
  static const int kSamples[][2] = {{0,0},{1,2},{31,7},{127,511},{1023,1001},{2047,2047}};
  const int sample_count = (int)(sizeof(kSamples) / sizeof(kSamples[0]));
  float max_err = 0.0f;
  for (int s = 0; s < sample_count; ++s) {
    const int i = kSamples[s][0];
    const int j = kSamples[s][1];
    const int idx = (layout == MemLayout::RowMajor) ? (i * N_DIM + j) : (j * M_DIM + i);
    max_err = fmaxf(max_err, fabsf(hC0[idx] - hC1[idx]));
  }
  printf(" Spot-check %s vs %s: max_abs_err=%e\n", name0, name1, max_err);
}

// Entry point.
//
// Benchmarks one fixed GEMM problem (M=N=K=2048, FP16 inputs, FP32
// accumulate/output) across six implementations — muBLAS, muBLASLt, muDNN,
// MUTLASS, WMMA and a scalar handwritten kernel — for both row-major and
// col-major operand layouts. The muBLAS result (hC0) is the reference for
// all spot checks. Returns 0 on success, 1 on host allocation failure;
// device/library errors abort via the CHECK_* helpers.
//
// Fixes vs. the previous revision: removed the unused workspace locals
// (ws_bytes_rm/ws_bytes_cm/d_ws_rm/d_ws_cm) and the dead MemoryHandler
// lambda that was never passed to MatMul::Run; cleanup musaFree calls are
// now error-checked.
int main() {
  print_device();
  printf("Fixed GEMM: M=N=K=2048, A/B=FP16, accum/output=FP32\n");
  printf("Timing: avg ms over 100 iters (after 10 warmup)\n\n");

  size_t bytesA = (size_t)M_DIM * K_DIM * sizeof(half);
  size_t bytesB = (size_t)K_DIM * N_DIM * sizeof(half);
  size_t bytesC = (size_t)M_DIM * N_DIM * sizeof(float);

  // Host buffers: one A/B pair per layout, plus reference (hC0) and
  // under-test (hC1) result buffers.
  half* hA_rm = (half*)malloc(bytesA);
  half* hB_rm = (half*)malloc(bytesB);
  half* hA_cm = (half*)malloc(bytesA);
  half* hB_cm = (half*)malloc(bytesB);
  float* hC0 = (float*)malloc(bytesC);
  float* hC1 = (float*)malloc(bytesC);

  if (!hA_rm || !hB_rm || !hA_cm || !hB_cm || !hC0 || !hC1) {
    fprintf(stderr, "Host malloc failed\n");
    return 1;
  }

  // Deterministic fills so spot checks are reproducible across runs.
  fill_fp16(MemLayout::RowMajor, hA_rm, hB_rm);
  fill_fp16(MemLayout::ColMajor, hA_cm, hB_cm);

  // Device buffers per layout (keep separate to avoid confusion)
  half *dA_rm, *dB_rm, *dA_cm, *dB_cm;
  float *dC_rm, *dC_cm;
  CHECK_MUSA(musaMalloc((void**)&dA_rm, bytesA));
  CHECK_MUSA(musaMalloc((void**)&dB_rm, bytesB));
  CHECK_MUSA(musaMalloc((void**)&dC_rm, bytesC));
  CHECK_MUSA(musaMalloc((void**)&dA_cm, bytesA));
  CHECK_MUSA(musaMalloc((void**)&dB_cm, bytesB));
  CHECK_MUSA(musaMalloc((void**)&dC_cm, bytesC));

  CHECK_MUSA(musaMemcpy(dA_rm, hA_rm, bytesA, musaMemcpyHostToDevice));
  CHECK_MUSA(musaMemcpy(dB_rm, hB_rm, bytesB, musaMemcpyHostToDevice));
  CHECK_MUSA(musaMemcpy(dA_cm, hA_cm, bytesA, musaMemcpyHostToDevice));
  CHECK_MUSA(musaMemcpy(dB_cm, hB_cm, bytesB, musaMemcpyHostToDevice));

  // Library handles: created once, reused across both layouts.
  mublasHandle_t blas;
  CHECK_MUBLAS(mublasCreate(&blas));
  mublasLtHandle_t blaslt;
  CHECK_MUBLAS(mublasLtCreate(&blaslt));

  // muDNN setup: Tensor-Core MatMul, no transposes, C = 1.0*A*B + 0.0*C.
  // NOTE(review): the Set* calls presumably return a status like the other
  // muDNN entry points — confirm whether they should be checked too.
  musa::dnn::Handle dnn_handle;
  musa::dnn::MatMul dnn_op;
  dnn_op.SetComputeMode(musa::dnn::MatMul::ComputeMode::TENSOR);
  dnn_op.SetTranspose(false, false);
  dnn_op.SetAlpha(1.0);
  dnn_op.SetBeta(0.0);

  // Describe a 2-D (dim0 x dim1) muDNN tensor with leading dimension `ld`,
  // expressed through explicit strides: (ld, 1) for row-major, (1, ld) for
  // col-major.
  auto make_tensor = [](musa::dnn::Tensor& t, void* addr, musa::dnn::Tensor::Type ty,
                        int64_t dim0, int64_t dim1, int64_t ld, bool row_major) {
    t.SetAddr(addr);
    t.SetType(ty);
    if (row_major) {
      int64_t dim[2] = {dim0, dim1};
      int64_t stride[2] = {ld, 1};
      t.SetNdInfo(2, dim, stride);
    } else {
      int64_t dim[2] = {dim0, dim1};
      int64_t stride[2] = {1, ld};
      t.SetNdInfo(2, dim, stride);
    }
  };

  // Run every implementation once for the given layout, timing each and
  // spot-checking its output against the muBLAS reference in hC0.
  auto bench_one_layout = [&](MemLayout layout) {
    const half* dA = (layout == MemLayout::RowMajor) ? dA_rm : dA_cm;
    const half* dB = (layout == MemLayout::RowMajor) ? dB_rm : dB_cm;
    float* dC = (layout == MemLayout::RowMajor) ? dC_rm : dC_cm;

    printf("== Layout: %s ==\n", layout_name(layout));

    // muBLAS — produces the reference result in hC0.
    zero_fp32(hC0, (size_t)M_DIM * N_DIM);
    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    MublasCtx blas_ctx{blas, layout, dA, dB, dC};
    float ms_blas = time_kernel_ms(launch_mublas, &blas_ctx, 10, 100);
    CHECK_MUSA(musaMemcpy(hC0, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" muBLAS: %.3f ms, %.2f GFLOPS\n", ms_blas, gflops(ms_blas));

    // muBLASLt — algorithm/workspace resolved once in mublaslt_build_ctx.
    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    MublasLtCtx blaslt_ctx{};
    blaslt_ctx.lt_handle = blaslt;
    blaslt_ctx.layout = layout;
    blaslt_ctx.dA = dA;
    blaslt_ctx.dB = dB;
    blaslt_ctx.dC = dC;
    mublaslt_build_ctx(&blaslt_ctx);
    float ms_blaslt = time_kernel_ms(launch_mublaslt, &blaslt_ctx, 10, 100);
    CHECK_MUSA(musaMemcpy(hC1, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" muBLASLt: %.3f ms, %.2f GFLOPS (ws=%zu bytes)\n",
           ms_blaslt, gflops(ms_blaslt), blaslt_ctx.workspace_bytes);
    check_close(hC0, hC1, layout, "muBLAS", "muBLASLt");
    mublaslt_destroy_ctx(&blaslt_ctx);

    // muDNN (configure tensors per layout)
    musa::dnn::Tensor tA, tB, tC;
    if (layout == MemLayout::ColMajor) {
      // muDNN MatMul validates lda as the first stride component and rejects lda=1,
      // effectively requiring row-major-like (ld,1) tensors. We therefore apply the
      // standard transpose mapping without any data movement:
      //
      // C = A * B (A,B,C stored col-major)
      // <=> C^T = B^T * A^T
      //
      // For col-major storage, the raw memory can be viewed as row-major storage of
      // the transpose. So:
      // - A_cm (M×K col-major) is A^T (K×M) row-major
      // - B_cm (K×N col-major) is B^T (N×K) row-major
      // - C_cm (M×N col-major) is C^T (N×M) row-major
      make_tensor(tA, (void*)dB, musa::dnn::Tensor::Type::HALF, N_DIM, K_DIM, K_DIM, true); // B^T
      make_tensor(tB, (void*)dA, musa::dnn::Tensor::Type::HALF, K_DIM, M_DIM, M_DIM, true); // A^T
      make_tensor(tC, (void*)dC, musa::dnn::Tensor::Type::FLOAT, N_DIM, M_DIM, M_DIM, true); // C^T
    } else {
      // Direct row-major with stride (ld, 1) where ld = cols
      make_tensor(tA, (void*)dA, musa::dnn::Tensor::Type::HALF, M_DIM, K_DIM, K_DIM, true);
      make_tensor(tB, (void*)dB, musa::dnn::Tensor::Type::HALF, K_DIM, N_DIM, N_DIM, true);
      make_tensor(tC, (void*)dC, musa::dnn::Tensor::Type::FLOAT, M_DIM, N_DIM, N_DIM, true);
    }

    size_t ws_bytes = 0;
    auto st_ws = dnn_op.GetWorkspaceSize(dnn_handle, ws_bytes, tC, tA, tB);
    if (st_ws != musa::dnn::Status::SUCCESS) {
      fprintf(stderr, "muDNN GetWorkspaceSize failed, status=%d\n", (int)st_ws);
      exit(EXIT_FAILURE);
    }
    // NOTE(review): the previous revision built a musa::dnn::MemoryHandler
    // lambda around d_ws but never passed it to MatMul::Run, so the
    // workspace was allocated and freed without ever reaching muDNN. The
    // dead lambda has been removed; the allocation is kept so the reported
    // ws size stays meaningful. Confirm whether Run() should take a memory
    // handler argument so the workspace is actually used.
    void* d_ws = nullptr;
    if (ws_bytes > 0) CHECK_MUSA(musaMalloc(&d_ws, ws_bytes));

    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    MudnnCtx dnn_ctx{&dnn_handle, &dnn_op, &tA, &tB, &tC, layout};
    float ms_dnn = time_kernel_ms(launch_mudnn, &dnn_ctx, 10, 100);
    CHECK_MUSA(musaMemcpy(hC1, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" muDNN: %.3f ms, %.2f GFLOPS (ws=%zu bytes)\n", ms_dnn, gflops(ms_dnn), (size_t)ws_bytes);
    check_close(hC0, hC1, layout, "muBLAS", "muDNN");
    if (d_ws) CHECK_MUSA(musaFree(d_ws));

    // MUTLASS (FP16 input + FP32 accumulate/output). The kernel is built for
    // row-major operands only; col-major is handled by swapping A/B and M/N
    // (C^T = B^T A^T), so no data is moved.
    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    using MutlassArgs = typename MutlassGemm::Arguments;
    MutlassGemm mutlass_gemm;
    int mm = (layout == MemLayout::RowMajor) ? M_DIM : N_DIM;
    int nn = (layout == MemLayout::RowMajor) ? N_DIM : M_DIM;
    using ProblemShapeType = typename MutlassGemm::GemmKernel::ProblemShape;
    using StrideA = typename MutlassGemm::GemmKernel::StrideA;
    using StrideB = typename MutlassGemm::GemmKernel::StrideB;
    using StrideC = typename MutlassGemm::GemmKernel::StrideC;
    ProblemShapeType pshape{mm, nn, K_DIM, 1};
    StrideA sa = mutlass::make_mute_packed_stride(StrideA{}, mute::make_shape(mm, K_DIM, 1));
    StrideB sb = mutlass::make_mute_packed_stride(StrideB{}, mute::make_shape(nn, K_DIM, 1));
    StrideC sc = mutlass::make_mute_packed_stride(StrideC{}, mute::make_shape(mm, nn, 1));
    const mutlass::half_t* a_ptr = reinterpret_cast<const mutlass::half_t*>(
        (layout == MemLayout::RowMajor) ? dA : dB);
    const mutlass::half_t* b_ptr = reinterpret_cast<const mutlass::half_t*>(
        (layout == MemLayout::RowMajor) ? dB : dA);
    MutlassArgs mutlass_args{
        mutlass::gemm::GemmUniversalMode::kGemm,
        pshape,
        {a_ptr, sa, b_ptr, sb},
        {{1.0f, 0.0f}, dC, sc, dC, sc},  // alpha=1, beta=0; C and D alias
        mutlass::KernelHardwareInfo{}};
    size_t ws_size = MutlassGemm::get_workspace_size(mutlass_args);
    mutlass::device_memory::allocation<uint8_t> mutlass_workspace(ws_size);

    // Fair timing: initialize once, only benchmark run().
    check_mutlass(mutlass_gemm.can_implement(mutlass_args), "can_implement");
    check_mutlass(mutlass_gemm.initialize(mutlass_args, mutlass_workspace.get()), "initialize");

    MutlassCtx mutlass_ctx{layout, dA, dB, dC, &mutlass_gemm, &mutlass_args, &mutlass_workspace};
    float ms_mutlass = time_kernel_ms(launch_mutlass, &mutlass_ctx, 10, 100);
    CHECK_MUSA(musaMemcpy(hC1, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" MUTLASS: %.3f ms, %.2f GFLOPS (FP16->FP32)\n", ms_mutlass, gflops(ms_mutlass));
    check_close(hC0, hC1, layout, "muBLAS(fp16)", "MUTLASS(fp16)");

    // WMMA
    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    WmmaCtx wmma_ctx{layout, dA, dB, dC};
    float ms_wmma = time_kernel_ms(launch_wmma, &wmma_ctx, 10, 100);
    CHECK_MUSA(musaMemcpy(hC1, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" WMMA: %.3f ms, %.2f GFLOPS\n", ms_wmma, gflops(ms_wmma));
    check_close(hC0, hC1, layout, "muBLAS", "WMMA");

    // Handwritten
    CHECK_MUSA(musaMemset(dC, 0, bytesC));
    HandCtx hand_ctx{layout, dA, dB, dC};
    float ms_hand = time_kernel_ms(launch_hand, &hand_ctx, 2, 5); // this is intentionally small (very slow)
    CHECK_MUSA(musaMemcpy(hC1, dC, bytesC, musaMemcpyDeviceToHost));
    printf(" HAND: %.3f ms, %.2f GFLOPS (note: scalar baseline)\n", ms_hand, gflops(ms_hand));
    check_close(hC0, hC1, layout, "muBLAS", "HAND");

    printf("\n");
  };

  bench_one_layout(MemLayout::RowMajor);
  bench_one_layout(MemLayout::ColMajor);

  CHECK_MUBLAS(mublasDestroy(blas));
  CHECK_MUBLAS(mublasLtDestroy(blaslt));

  CHECK_MUSA(musaFree(dA_rm));
  CHECK_MUSA(musaFree(dB_rm));
  CHECK_MUSA(musaFree(dC_rm));
  CHECK_MUSA(musaFree(dA_cm));
  CHECK_MUSA(musaFree(dB_cm));
  CHECK_MUSA(musaFree(dC_cm));

  free(hA_rm); free(hB_rm); free(hA_cm); free(hB_cm);
  free(hC0); free(hC1);

  return 0;
}