MUSA Tensor Core GEMM 实践手记:把“能跑”变成“能交付”
刚接手一个 GEMM 热点模块时,最常见的状态是:
“知道 Tensor Core 很重要,但不知道该从哪一层开始落地。”
这篇文章不追求“最全 API 手册”,而是按真实工程节奏,把路径拆成 6 步:
- 先写一个慢但可解释的 baseline
- 用 WMMA 建立 Tensor Core 编程直觉
- 用 muBLAS 快速获得稳定实现
- 在 muBLASLt 上做算法与 workspace 控制
- 在 muDNN 场景里对接算子接口
- 用 MUTLASS 做更可控的模板化定制
每一步都保留 example code。文中不含敏感环境信息,也不放性能数据。
开始前:先把“跑不起来”的问题排干净
建议先确认工具链,而不是先写代码。
which mcc
mcc --version
# 以下路径按你的安装位置调整
ls /usr/local/musa/include/musa_runtime.h
ls /usr/local/musa/include/mublas.h
ls /usr/local/musa/lib64/libmusart.so
ls /usr/local/musa/lib64/libmublas.so
如果这些都正常,再开始正文会省很多时间。
一、先讲最容易出错的点:布局
大多数 GEMM 的“结果不对”,根因都不是算子,而是布局解释。
row-major:A[i, k] -> i * ld + kcol-major:A[i, k] -> k * ld + i
BLAS 默认按 col-major。业务里常见 row-major,这时常用映射:
这个映射的价值很实际:不写显式 transpose kernel,也能复用库接口。
二、Step 1:先有 baseline(哪怕它很慢)
baseline 的意义不是快,而是“可解释、可对拍、可回归”。
2.1 一个最小 baseline kernel
与文末附录中的基准程序 bench_gemm_2048.mu 一致:布局用 MemLayout,维度在附录里为编译期常量 M_DIM/N_DIM/K_DIM;下面给出通用带 M,N,K 参数的写法,便于阅读。
enum class MemLayout { RowMajor, ColMajor };
template <MemLayout L>
__global__ void gemm_fp16_scalar_kernel(const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C,
int M, int N, int K) {
int i = (int)(blockIdx.y * blockDim.y + threadIdx.y);
int j = (int)(blockIdx.x * blockDim.x + threadIdx.x);
if (i >= M || j >= N) return;
float acc = 0.0f;
#pragma unroll 4
for (int k = 0; k < K; ++k) {
int a_idx = (L == MemLayout::RowMajor) ? (i * K + k) : (k * M + i);
int b_idx = (L == MemLayout::RowMajor) ? (k * N + j) : (j * K + k);
acc += __half2float(A[a_idx]) * __half2float(B[b_idx]);
}
int c_idx = (L == MemLayout::RowMajor) ? (i * N + j) : (j * M + i);
C[c_idx] = acc;
}
2.2 launch 示例
dim3 block(16, 16);
dim3 grid((unsigned)((N + block.x - 1) / block.x), (unsigned)((M + block.y - 1) / block.y), 1u);
gemm_fp16_scalar_kernel<MemLayout::RowMajor><<<grid, block>>>(dA, dB, dC, M, N, K);
CHECK_MUSA(musaGetLastError());
三、Step 2:WMMA 入门,建立 Tensor Core 直觉
WMMA 是“你亲手写 Tensor Core kernel”的入口,适合理解 tile 和 fragment。
3.1 row-major WMMA 示例
MUSA 下请在包含 <mma.h> 后使用 using namespace mtmusa;(与附录中 WMMA 部分一致)。下面内核结构与附录基准相同:用 blockIdx 表示 tile,在 store 前做越界保护(M、N 为 WMMA 整数倍时可与附录中固定 2048 版本完全对齐)。
#include <mma.h>
using namespace mtmusa;
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
__global__ void wmma_gemm_kernel_rm(const half* __restrict__ A_rm,
const half* __restrict__ B_rm,
float* __restrict__ C_rm,
int M, int N, int K) {
int tile_m = (int)blockIdx.y;
int tile_n = (int)blockIdx.x;
int block_row = tile_m * WMMA_M;
int block_col = tile_n * WMMA_N;
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
for (int i = 0; i < (int)(sizeof(c_frag.x) / sizeof(c_frag.x[0])); ++i) c_frag.x[i] = 0.0f;
for (int k0 = 0; k0 < K; k0 += WMMA_K) {
const half* tile_a = A_rm + block_row * K + k0;
const half* tile_b = B_rm + k0 * N + block_col;
wmma::load_matrix_sync(a_frag, tile_a, K);
wmma::load_matrix_sync(b_frag, tile_b, N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
if (block_row + WMMA_M <= M && block_col + WMMA_N <= N) {
wmma::store_matrix_sync(C_rm + block_row * N + block_col, c_frag, N,
wmma::mem_row_major);
}
}
col-major 的 WMMA 内核、以及与 row-major 共用的 launch 分支,见本文附录中 wmma_gemm_kernel_cm 与 launch_wmma。
3.2 launch 示例
dim3 grid_wmma((unsigned)((N + WMMA_N - 1) / WMMA_N),
(unsigned)((M + WMMA_M - 1) / WMMA_M), 1u);
dim3 block_wmma(32u, 1u, 1u); // one warp
wmma_gemm_kernel_rm<<<grid_wmma, block_wmma>>>(dA, dB, dC, M, N, K);
CHECK_MUSA(musaGetLastError());
这一步跑通后,你对 Tensor Core 编程模型会有非常具体的认知。
四、Step 3:muBLAS,先把生产可用性拿到手
真实项目里,纯 GEMM 大多先落在 muBLAS。
4.1 col-major 直接调用
#define CHECK_MUBLAS(call) do { \
mublasStatus_t st = (call); \
if (st != MUBLAS_STATUS_SUCCESS) { \
std::fprintf(stderr, "muBLAS error %s:%d status=%d\n", __FILE__, __LINE__, (int)st); \
std::exit(EXIT_FAILURE); \
} \
} while (0)
void run_mublas_col_major(mublasHandle_t handle, const half* dA, const half* dB, float* dC,
int M, int N, int K) {
const float alpha = 1.0f;
const float beta = 0.0f;
CHECK_MUBLAS(mublasGemmEx(
handle,
MUBLAS_OP_N, MUBLAS_OP_N,
M, N, K,
&alpha,
dA, MUSA_R_16F, M,
dB, MUSA_R_16F, K,
&beta,
dC, MUSA_R_32F, M,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
4.2 row-major 无拷贝映射调用
void run_mublas_row_major_map(mublasHandle_t handle, const half* dA_rm, const half* dB_rm,
float* dC_rm, int M, int N, int K) {
// C^T(N, M) = B^T(N, K) * A^T(K, M)
const float alpha = 1.0f;
const float beta = 0.0f;
int m = N, n = M, k = K;
CHECK_MUBLAS(mublasGemmEx(
handle,
MUBLAS_OP_N, MUBLAS_OP_N,
m, n, k,
&alpha,
dB_rm, MUSA_R_16F, N,
dA_rm, MUSA_R_16F, K,
&beta,
dC_rm, MUSA_R_32F, N,
MUBLAS_COMPUTE_32F,
MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
五、Step 4:muBLASLt,把“可调性”加进来
当你需要更细控制(算法选择、workspace 预算、布局描述)时,就该上 muBLASLt。下面是与附录基准中 col-major 分支一致的要点:MUBLASLT_ORDER_COL 显式设到 A/B/C/D 的 layout 上;row-major 时矩阵形状与 muBLAS 无拷贝映射一致,并在 mublasLtMatmul 时交换 A/B 指针(实现见附录 mublaslt_build_ctx / launch_mublaslt)。
#include <mublasLt.h>
void run_mublaslt_col_major(const half* dA, const half* dB, float* dC,
int M, int N, int K, void* workspace_ptr, size_t workspace_bytes) {
mublasLtHandle_t lt = nullptr;
CHECK_MUBLAS(mublasLtCreate(<));
mublasLtMatmulDesc_t op_desc = nullptr;
CHECK_MUBLAS(mublasLtMatmulDescCreate(&op_desc, MUBLAS_COMPUTE_32F, MUSA_R_32F));
mublasOperation_t trans_n = MUBLAS_OP_N;
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSA, &trans_n, sizeof(trans_n)));
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSB, &trans_n, sizeof(trans_n)));
mublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&a_desc, MUSA_R_16F, M, K, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&b_desc, MUSA_R_16F, K, N, K));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&c_desc, MUSA_R_32F, M, N, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&d_desc, MUSA_R_32F, M, N, M));
mublasLtOrder_t order = MUBLASLT_ORDER_COL;
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
a_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
b_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
c_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
d_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
mublasLtMatmulPreference_t pref = nullptr;
CHECK_MUBLAS(mublasLtMatmulPreferenceCreate(&pref));
CHECK_MUBLAS(mublasLtMatmulPreferenceSetAttribute(
pref, MUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_bytes, sizeof(workspace_bytes)));
mublasLtMatmulHeuristicResult_t hres{};
int algo_count = 0;
CHECK_MUBLAS(mublasLtMatmulAlgoGetHeuristic(
lt, op_desc, a_desc, b_desc, c_desc, d_desc, pref, 1, &hres, &algo_count));
const float alpha = 1.0f, beta = 0.0f;
CHECK_MUBLAS(mublasLtMatmul(
lt, op_desc, &alpha,
dA, a_desc, dB, b_desc, &beta,
dC, c_desc, dC, d_desc,
(algo_count > 0 ? &hres.algo : nullptr),
workspace_ptr, workspace_bytes, nullptr));
mublasLtMatmulPreferenceDestroy(pref);
mublasLtMatrixLayoutDestroy(a_desc);
mublasLtMatrixLayoutDestroy(b_desc);
mublasLtMatrixLayoutDestroy(c_desc);
mublasLtMatrixLayoutDestroy(d_desc);
mublasLtMatmulDescDestroy(op_desc);
mublasLtDestroy(lt);
}