MUSA Tensor Core GEMM 实践手记:把“能跑”变成“能交付”
刚接手一个 GEMM 热点模块时,最常见的状态是:
“知道 Tensor Core 很重要,但不知道该从哪一层开始落地。”
这篇文章不追求“最全 API 手册”,而是按真实工程节奏,把路径拆成 6 步:
- 先写一个慢但可解释的 baseline
- 用 WMMA 建立 Tensor Core 编程直觉
- 用 muBLAS 快速获得稳定实现
- 在 muBLASLt 上做算法与 workspace 控制
- 在 muDNN 场景里对接算子接口
- 用 MUTLASS 做更可控的模板化定制
每一步都保留 example code。文中不含敏感环境信息,也不放性能数据。
开始前:先把“跑不起来”的问题排干净
建议先确认工具链,而不是先写代码。
which mcc
mcc --version
# 以下路径按你的安装位置调整
ls /usr/local/musa/include/musa_runtime.h
ls /usr/local/musa/include/mublas.h
ls /usr/local/musa/lib64/libmusart.so
ls /usr/local/musa/lib64/libmublas.so
如果这些都正常,再开始正文会省很多时间。
一、先讲最容易出错的点:布局
大多数 GEMM 的“结果不对”,根因都不是算子,而是布局解释。
row-major:A[i, k] -> i * ld + k;col-major:A[i, k] -> k * ld + i
BLAS 默认按 col-major。业务里常见 row-major,这时常用映射:C^T(N, M) = B^T(N, K) · A^T(K, M),即把两个 row-major 矩阵按原地址直接当作 col-major 的 B^T、A^T 传入,并交换操作数顺序。
这个映射的价值很实际:不写显式 transpose kernel,也能复用库接口。
二、Step 1:先有 baseline(哪怕它很慢)
baseline 的意义不是快,而是“可解释、可对拍、可回归”。
2.1 一个最小 baseline kernel
与文末附录中的基准程序 bench_gemm_2048.mu 一致:布局用 MemLayout,维度在附录里为编译期常量 M_DIM/N_DIM/K_DIM;下面给出通用带 M,N,K 参数的写法,便于阅读。
// Memory layout tag shared by all layout-templated kernels in this article.
enum class MemLayout { RowMajor, ColMajor };

// Naive FP16-in / FP32-out GEMM baseline: one thread computes one element of
// C = A * B. Template parameter L selects how A, B and C are indexed in
// memory; no shared memory, no tiling — this kernel exists to be explainable
// and to serve as a reference for correctness checks, not to be fast.
template <MemLayout L>
__global__ void gemm_fp16_scalar_kernel(const half* __restrict__ A,
                                        const half* __restrict__ B,
                                        float* __restrict__ C,
                                        int M, int N, int K) {
  const int row = static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y);
  const int col = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  // Grid is rounded up by the launcher; drop threads past the matrix edge.
  if (row >= M || col >= N) return;
  constexpr bool kRowMajor = (L == MemLayout::RowMajor);
  float sum = 0.0f;
#pragma unroll 4
  for (int kk = 0; kk < K; ++kk) {
    const int ia = kRowMajor ? (row * K + kk) : (kk * M + row);
    const int ib = kRowMajor ? (kk * N + col) : (col * K + kk);
    sum += __half2float(A[ia]) * __half2float(B[ib]);
  }
  const int ic = kRowMajor ? (row * N + col) : (col * M + row);
  C[ic] = sum;
}
2.2 launch 示例
// 16x16 threads per block; grid rounded up so every element of the M x N
// output is covered (the kernel's bounds check handles the overhang).
dim3 block(16, 16);
dim3 grid((unsigned)((N + block.x - 1) / block.x), (unsigned)((M + block.y - 1) / block.y), 1u);
gemm_fp16_scalar_kernel<MemLayout::RowMajor><<<grid, block>>>(dA, dB, dC, M, N, K);
// Launch-configuration errors surface here, not at the <<< >>> expression.
CHECK_MUSA(musaGetLastError());
三、Step 2:WMMA 入门,建立 Tensor Core 直觉
WMMA 是“你亲手写 Tensor Core kernel”的入口,适合理解 tile 和 fragment。
3.1 row-major WMMA 示例
MUSA 下请在包含 <mma.h> 后使用 using namespace mtmusa;(与附录中 WMMA 部分一致)。下面内核结构与附录基准相同:用 blockIdx 表示 tile,在 store 前做越界保护(M、N 为 WMMA 整数倍时可与附录中固定 2048 版本完全对齐)。
#include <mma.h>
using namespace mtmusa;
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// Single-warp WMMA GEMM kernel for row-major A, B and C.
// Each block (one 32-thread warp) computes one WMMA_M x WMMA_N tile of
// C = A * B with FP32 accumulation. Assumes K is a multiple of WMMA_K
// (true for the fixed-2048 benchmark in the appendix).
__global__ void wmma_gemm_kernel_rm(const half* __restrict__ A_rm,
                                    const half* __restrict__ B_rm,
                                    float* __restrict__ C_rm,
                                    int M, int N, int K) {
  const int block_row = (int)blockIdx.y * WMMA_M;
  const int block_col = (int)blockIdx.x * WMMA_N;
  // Bail out BEFORE any load. The original guarded only the store, so a
  // partial edge tile (grid is rounded up) would still issue out-of-bounds
  // loads of A/B. Skipping the whole tile keeps stored results identical.
  if (block_row + WMMA_M > M || block_col + WMMA_N > N) return;
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
  // Zero the accumulator via the WMMA API instead of poking c_frag.x[] by hand.
  wmma::fill_fragment(c_frag, 0.0f);
  for (int k0 = 0; k0 < K; k0 += WMMA_K) {
    // Leading dimensions for row-major storage: K for A, N for B.
    wmma::load_matrix_sync(a_frag, A_rm + block_row * K + k0, K);
    wmma::load_matrix_sync(b_frag, B_rm + k0 * N + block_col, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }
  wmma::store_matrix_sync(C_rm + block_row * N + block_col, c_frag, N,
                          wmma::mem_row_major);
}
col-major 的 WMMA 内核、以及与 row-major 共用的 launch 分支,见本文附录中 wmma_gemm_kernel_cm 与 launch_wmma。
3.2 launch 示例
// One block per WMMA output tile; each block runs a single 32-thread warp.
// Grid is rounded up; edge tiles are guarded inside the kernel.
dim3 grid_wmma((unsigned)((N + WMMA_N - 1) / WMMA_N),
(unsigned)((M + WMMA_M - 1) / WMMA_M), 1u);
dim3 block_wmma(32u, 1u, 1u); // one warp
wmma_gemm_kernel_rm<<<grid_wmma, block_wmma>>>(dA, dB, dC, M, N, K);
// Check for launch-time errors right after the asynchronous launch.
CHECK_MUSA(musaGetLastError());
这一步跑通后,你对 Tensor Core 编程模型会有非常具体的认知。
四、Step 3:muBLAS,先把生产可用性拿到手
真实项目里,纯 GEMM 大多先落在 muBLAS。
4.1 col-major 直接调用
// Abort with file/line context and the raw numeric status when a muBLAS call
// fails. Wrapped in do { } while (0) so it behaves as a single statement.
#define CHECK_MUBLAS(call) do { \
mublasStatus_t st = (call); \
if (st != MUBLAS_STATUS_SUCCESS) { \
std::fprintf(stderr, "muBLAS error %s:%d status=%d\n", __FILE__, __LINE__, (int)st); \
std::exit(EXIT_FAILURE); \
} \
} while (0)
// Col-major FP16 GEMM through mublasGemmEx:
//   C(M x N, f32) = A(M x K, f16) * B(K x N, f16), alpha = 1, beta = 0,
// FP32 accumulation with the Tensor Core default algorithm.
void run_mublas_col_major(mublasHandle_t handle, const half* dA, const half* dB, float* dC,
                          int M, int N, int K) {
  const float one = 1.0f;
  const float zero = 0.0f;
  // Leading dimensions for col-major storage.
  const int lda = M;
  const int ldb = K;
  const int ldc = M;
  CHECK_MUBLAS(mublasGemmEx(handle,
                            MUBLAS_OP_N, MUBLAS_OP_N,
                            M, N, K,
                            &one,
                            dA, MUSA_R_16F, lda,
                            dB, MUSA_R_16F, ldb,
                            &zero,
                            dC, MUSA_R_32F, ldc,
                            MUBLAS_COMPUTE_32F,
                            MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
4.2 row-major 无拷贝映射调用
// Row-major GEMM with no transpose kernel and no copies: in col-major terms
// we compute C^T(N, M) = B^T(N, K) * A^T(K, M), which writes exactly the
// row-major C(M, N) the caller expects. Operands are therefore passed in
// swapped order (B first, then A) with their row-major leading dimensions.
void run_mublas_row_major_map(mublasHandle_t handle, const half* dA_rm, const half* dB_rm,
                              float* dC_rm, int M, int N, int K) {
  const float one = 1.0f;
  const float zero = 0.0f;
  // Swapped problem shape: (N x M) = (N x K) * (K x M).
  CHECK_MUBLAS(mublasGemmEx(handle,
                            MUBLAS_OP_N, MUBLAS_OP_N,
                            N, M, K,
                            &one,
                            dB_rm, MUSA_R_16F, N,
                            dA_rm, MUSA_R_16F, K,
                            &zero,
                            dC_rm, MUSA_R_32F, N,
                            MUBLAS_COMPUTE_32F,
                            MUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
五、Step 4:muBLASLt,把“可调性”加进来
当你需要更细控制(算法选择、workspace 预算、布局描述)时,就该上 muBLASLt。下面是与附录基准中 col-major 分支一致的要点:MUBLASLT_ORDER_COL 显式设到 A/B/C/D 的 layout 上;row-major 时矩阵形状与 muBLAS 无拷贝映射一致,并在 mublasLtMatmul 时交换 A/B 指针(实现见附录 mublaslt_build_ctx / launch_mublaslt)。
#include <mublasLt.h>
void run_mublaslt_col_major(const half* dA, const half* dB, float* dC,
int M, int N, int K, void* workspace_ptr, size_t workspace_bytes) {
mublasLtHandle_t lt = nullptr;
CHECK_MUBLAS(mublasLtCreate(<));
mublasLtMatmulDesc_t op_desc = nullptr;
CHECK_MUBLAS(mublasLtMatmulDescCreate(&op_desc, MUBLAS_COMPUTE_32F, MUSA_R_32F));
mublasOperation_t trans_n = MUBLAS_OP_N;
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSA, &trans_n, sizeof(trans_n)));
CHECK_MUBLAS(mublasLtMatmulDescSetAttribute(op_desc, MUBLASLT_MATMUL_DESC_TRANSB, &trans_n, sizeof(trans_n)));
mublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&a_desc, MUSA_R_16F, M, K, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&b_desc, MUSA_R_16F, K, N, K));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&c_desc, MUSA_R_32F, M, N, M));
CHECK_MUBLAS(mublasLtMatrixLayoutCreate(&d_desc, MUSA_R_32F, M, N, M));
mublasLtOrder_t order = MUBLASLT_ORDER_COL;
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
a_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
b_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
c_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
CHECK_MUBLAS(mublasLtMatrixLayoutSetAttribute(
d_desc, MUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));
mublasLtMatmulPreference_t pref = nullptr;
CHECK_MUBLAS(mublasLtMatmulPreferenceCreate(&pref));
CHECK_MUBLAS(mublasLtMatmulPreferenceSetAttribute(
pref, MUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_bytes, sizeof(workspace_bytes)));
mublasLtMatmulHeuristicResult_t hres{};
int algo_count = 0;
CHECK_MUBLAS(mublasLtMatmulAlgoGetHeuristic(
lt, op_desc, a_desc, b_desc, c_desc, d_desc, pref, 1, &hres, &algo_count));
const float alpha = 1.0f, beta = 0.0f;
CHECK_MUBLAS(mublasLtMatmul(
lt, op_desc, &alpha,
dA, a_desc, dB, b_desc, &beta,
dC, c_desc, dC, d_desc,
(algo_count > 0 ? &hres.algo : nullptr),
workspace_ptr, workspace_bytes, nullptr));
mublasLtMatmulPreferenceDestroy(pref);
mublasLtMatrixLayoutDestroy(a_desc);
mublasLtMatrixLayoutDestroy(b_desc);
mublasLtMatrixLayoutDestroy(c_desc);
mublasLtMatrixLayoutDestroy(d_desc);
mublasLtMatmulDescDestroy(op_desc);
mublasLtDestroy(lt);
}
六、Step 5:muDNN MatMul,接入算子式调用
如果你的上层是算子图或算子融合流程,muDNN 往往更自然。row-major 时 stride 取 (ld, 1) 即可。当前部分驱动对 tensor 的 stride 校验较严(附录 main 中 make_tensor / MemLayout::ColMajor 分支前有英文注释说明),col-major 物理存储时可在不搬数据的前提下用 (C^\top = B^\top A^\top) 映射到 MatMul 的 row-major 描述,与附录 bench_one_layout 里 MemLayout::ColMajor 分支一致。
#include <mudnn.h>
// Row-major MatMul through the muDNN operator API:
//   C(M x N, f32) = A(M x K, f16) * B(K x N, f16), Tensor Core compute mode.
// For row-major 2-D tensors the stride is (ld, 1): (K,1), (N,1), (N,1).
void run_mudnn_row_major(const half* dA, const half* dB, float* dC, int M, int N, int K) {
musa::dnn::Handle handle;
musa::dnn::MatMul op;
// Tensor Core mode, no transposes, plain C = 1*A*B + 0*C.
op.SetComputeMode(musa::dnn::MatMul::ComputeMode::TENSOR);
op.SetTranspose(false, false);
op.SetAlpha(1.0);
op.SetBeta(0.0);
musa::dnn::Tensor tA, tB, tC;
// Device addresses and element types: FP16 inputs, FP32 output.
tA.SetAddr(dA); tA.SetType(musa::dnn::Tensor::Type::HALF);
tB.SetAddr(dB); tB.SetType(musa::dnn::Tensor::Type::HALF);
tC.SetAddr(dC); tC.SetType(musa::dnn::Tensor::Type::FLOAT);
// 2-D shapes with row-major strides; the driver validates strides strictly
// (see the note before the MemLayout::ColMajor branch in the appendix main).
int64_t dimA[2] = {M, K}, strideA[2] = {K, 1};
int64_t dimB[2] = {K, N}, strideB[2] = {N, 1};
int64_t dimC[2] = {M, N}, strideC[2] = {N, 1};
tA.SetNdInfo(2, dimA, strideA);
tB.SetNdInfo(2, dimB, strideB);
tC.SetNdInfo(2, dimC, strideC);
size_t ws_bytes = 0;
// NOTE(review): the workspace queried/allocated here is never passed to
// op.Run below — confirm whether this Run overload takes a workspace
// argument (or allocates internally); otherwise the query/alloc is dead code.
// The GetWorkspaceSize return status is also ignored — consider checking it.
op.GetWorkspaceSize(handle, ws_bytes, tC, tA, tB);
void* ws = nullptr;
if (ws_bytes > 0) CHECK_MUSA(musaMalloc(&ws, ws_bytes));
auto st = op.Run(handle, tC, tA, tB);
if (st != musa::dnn::Status::SUCCESS) {
std::fprintf(stderr, "muDNN MatMul run failed\n");
std::exit(EXIT_FAILURE);
}
if (ws) CHECK_MUSA(musaFree(ws));
}
七、Step 6:MUTLASS,做模板化可控定制
当你需要“比库更可控,但不想全手写 WMMA 调度”时,MUTLASS 是很实用的中间层。下面类型骨架与附录基准一致:Tile 为 256×128×K,含 Epilogue CollectiveBuilder,alignment 按 16 字节对齐;设备侧完整 Arguments / initialize / run 见附录 launch_mutlass 与 bench_one_layout 中的 MUTLASS 段落。
#include "mute/tensor.hpp"
#include "mutlass/mutlass.h"
#include "mutlass/numeric_types.h"
#include "mutlass/gemm/device/gemm_universal_adapter.h"
#include "mutlass/gemm/collective/collective_builder.hpp"
#include "mutlass/epilogue/collective/collective_builder.hpp"
// ---- MUTLASS type configuration (matches the appendix benchmark) ----
// Logical layouts of operands and results; all row-major here.
using MutlassLayoutA = mutlass::layout::RowMajor;
using MutlassLayoutB = mutlass::layout::RowMajor;
using MutlassLayoutC = mutlass::layout::RowMajor;
using MutlassLayoutD = mutlass::layout::RowMajor;
// FP16 inputs; FP32 source/output, accumulator and epilogue compute types.
using MutlassTypeA = mutlass::half_t;
using MutlassTypeB = mutlass::half_t;
using MutlassTypeC = float;
using MutlassTypeD = float;
using MutlassTypeAccumulator = float;
using MutlassTypeCompute = float;
// Target architecture tag and Tensor Core op class.
using MutlassArchTag = mutlass::arch::Mp22;
using MutlassOpClass = mutlass::arch::OpClassTensorOp;
// Threadblock tile: 256x128 output tile with a K-slice of 32 per mainloop
// step (configuration selected via mutlass_profiler, see the note below).
using MutlassTileShape = mute::Shape<mute::_256, mute::_128, mute::_32>;
using MutlassClusterShape = mute::Shape<mute::_1, mute::_1, mute::_1>;
// Alignments in elements, derived from 16-byte vectorized memory access.
static constexpr int MutlassAlignmentA = 16 / sizeof(MutlassTypeA);
static constexpr int MutlassAlignmentB = 16 / sizeof(MutlassTypeB);
static constexpr int MutlassAlignmentC = 16 / sizeof(MutlassTypeC);
static constexpr int MutlassAlignmentD = 16 / sizeof(MutlassTypeD);
// Mainloop built by the CollectiveBuilder; stage count and kernel schedule
// are left to the library's Auto policies.
using MutlassCollectiveMainloop = typename mutlass::gemm::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTypeA, MutlassLayoutA, MutlassAlignmentA,
MutlassTypeB, MutlassLayoutB, MutlassAlignmentB,
MutlassTypeAccumulator,
MutlassTileShape,
MutlassClusterShape,
mutlass::gemm::collective::StageCountAuto,
mutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
// Epilogue built the same way; tile and schedule chosen automatically.
using MutlassCollectiveEpilogue = typename mutlass::epilogue::collective::CollectiveBuilder<
MutlassArchTag, MutlassOpClass,
MutlassTileShape,
MutlassClusterShape,
mutlass::epilogue::collective::EpilogueTileAuto,
MutlassTypeAccumulator, MutlassTypeCompute,
MutlassTypeC, MutlassLayoutC, MutlassAlignmentC,
MutlassTypeD, MutlassLayoutD, MutlassAlignmentD,
mutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;
// Kernel over (M, N, K, batch) problem shape, wrapped in the device adapter
// that exposes Arguments / initialize / run (used in the appendix).
using MutlassGemmKernel = mutlass::gemm::kernel::GemmUniversal<
mute::Shape<int, int, int, int>, MutlassCollectiveMainloop, MutlassCollectiveEpilogue>;
using MutlassGemm = mutlass::gemm::device::GemmUniversalAdapter<MutlassGemmKernel>;
用于筛 tile 配置的 profiler 调用示例(附录中 MutlassTileShape 的注释说明其来源;需在已安装 MUTLASS 工具链的环境中执行):
mutlass_profiler \
--operation=Gemm --providers=mutlass \
--verification-enabled=false \
--m=2048 --n=2048 --k=2048 \
--op_class=tensorop --accum=f32 \
--A=f16:row --B=f16:row --C=f32:row --D=f32:row
八、把多个后端组织到同一 benchmark 框架
这一节的重点不是“谁快”,而是“如何可持续比较”。片段中的 CHECK_MUSA 与第九节「最小可运行 main.cpp」中的宏定义相同(附录开头亦有一份)。
// Times `launch(ctx)` with device events and returns the mean milliseconds
// per iteration. `warmup` launches run first and are excluded; the device is
// synchronized before the timed region so warmup work cannot bleed into it.
// Robustness fix: clamp iters to >= 1 so iters == 0 cannot divide by zero.
static float time_kernel_ms(void (*launch)(void*), void* ctx, int warmup, int iters) {
  if (iters < 1) iters = 1;  // avoid div-by-zero / empty timed region
  musaEvent_t start, stop;
  CHECK_MUSA(musaEventCreate(&start));
  CHECK_MUSA(musaEventCreate(&stop));
  for (int i = 0; i < warmup; ++i) launch(ctx);
  CHECK_MUSA(musaDeviceSynchronize());
  CHECK_MUSA(musaEventRecord(start));
  for (int i = 0; i < iters; ++i) launch(ctx);
  CHECK_MUSA(musaEventRecord(stop));
  // Event elapsed time is only valid once `stop` has completed on device.
  CHECK_MUSA(musaEventSynchronize(stop));
  float elapsed_ms = 0.0f;
  CHECK_MUSA(musaEventElapsedTime(&elapsed_ms, start, stop));
  CHECK_MUSA(musaEventDestroy(start));
  CHECK_MUSA(musaEventDestroy(stop));
  return elapsed_ms / iters;
}
spot-check 示例(与附录中 check_close 相同思想:按布局取线性下标,采样若干点比全量对拍更轻;MemLayout 与第二节一致,若已定义可删 enum):
enum class MemLayout { RowMajor, ColMajor };

// Spot-check two result buffers at a few fixed (i, j) sample positions and
// print the largest absolute difference found. Much cheaper than a full
// element-wise comparison; sample points outside the M x N range are skipped.
static void check_close(const float* hC0, const float* hC1, MemLayout layout,
                        int M, int N, const char* name0, const char* name1) {
  static const int kPoints[][2] = {{0, 0}, {1, 2}, {31, 7}, {127, 511}};
  const int num_points = (int)(sizeof(kPoints) / sizeof(kPoints[0]));
  float worst = 0.0f;
  for (int p = 0; p < num_points; ++p) {
    const int i = kPoints[p][0];
    const int j = kPoints[p][1];
    if (i >= M || j >= N) continue;
    // Linearize (i, j) according to the buffer's storage order.
    const int lin = (layout == MemLayout::RowMajor) ? (i * N + j) : (j * M + i);
    const float diff = std::fabs(hC0[lin] - hC1[lin]);
    worst = (diff > worst) ? diff : worst;
  }
  std::printf("Spot-check %s vs %s: max_abs_err=%e\n", name0, name1, worst);
}