MTT S4000 GEMM 优化中的分块、尾处理与并发实战复盘
1. 规则分析
赛题有以下约束:
| 项目 | 要求 |
|---|---|
| 输入输出 | FP16 输入,FP32 累加,FP16 输出 |
| 固定接口 | extern "C" void gemm_optimized(...),签名不能改 |
| 数据布局 | A/B/C 都按列主序访问 |
| 同步要求 | 返回前必须完成计算,提交版需要 musaDeviceSynchronize() |
| 禁止行为 | 禁止 muBLAS/muDNN 等高层库,禁止修改输入矩阵,禁止投机跳过计算 |
| 评测重点 | correctness 先过,再看 8192 x 8192 x 16384 的性能 |
这组规则对后面的设计有两个直接影响:
- 不能把问题当成普通 CUDA GEMM 直接套模板。
列主序、MUSA WMMA 语义、提交接口和同步位置,都必须一一对齐。 - 性能优化必须和正确性绑定。
任何没有经过完整 correctness 约束的高分版本,最后都可能只是表面上跑得快。
2. GEMM入门
在前期准备阶段, 我主要做了三件事: 先理解 GEMM 的计算本质,再跑通提交代码的最小链路,最后去看 S4000 这张卡真正吃什么。
GEMM 本质上就是一个三重循环:
for (int n = 0; n < N; ++n)
for (int m = 0; m < M; ++m) {
float acc = 0.0f;
for (int k = 0; k < K; ++k)
acc += A[m, k] * B[k, n];
C[m, n] = acc;
}
在固定规模是 8192 x 8192 x 16384 下,单次计算量就是 2 x M x N x K,大约 2.2e12 FLOPs。
由此可以快速得出两个关键结论:
- 不能按教科书式逐元素直算。
如果每次乘加都直接从全局内存读取数据,带宽会迅速成为瓶颈。 - 数据复用是第一原则。
同一块A和B子块会参与很多输出元素的计算,所以 tile、shared memory 和寄存器缓存不是锦上添花,而是基本前提。
这道题还有一个很容易踩坑的地方,就是 列主序。很多人平时更熟悉行主序,但这里地址计算必须统一为:
A(m, k) -> d_A[m + k * M]
B(k, n) -> d_B[k + n * K]
C(m, n) -> d_C[m + n * M]

结合摩尔线程 MTT S4000 的公开规格:
- 48GB GDDR6
- 768GB/s 显存带宽
- FP16/BF16 峰值 100 TFLOPS
这几个数字直接告诉我,想把这道题做好,必须同时满足两件事:
- 主计算必须尽量走 Tensor Core 路径。
否则纯标量或普通向量 FMA 很难逼近这张卡的张量算力上限。 - 访存必须做分层复用。
既然带宽是硬上限,就不能让每个 wave 都反复从 global memory 取同一块B。
MUSA 编程和性能文档里还有几条对这题非常关键的硬件事实:
- MUSA 的存储层次是
register->shared memory->global memory。GEMM 天然适合把数据按global -> shared -> register -> mma这条路径搬运。 - QY 架构下
warpSize=128。block size 通常要取 warp size 的整数倍,128/256/512/1024都是合理候选。 - active warps 并不是越多越好,它会同时受寄存器和 shared memory 用量约束,所以 tile 变大、unroll 变深,未必一定更快。
- double buffer、向量化加载、多 stream 并发,都是有价值的工具,但前提是主核结构已经稳定。
这一阶段最大的收获,是将优化方向从经验判断转变为基于硬件约束的推导。
3. 基础框架搭建
在理清 GEMM 计算和硬件特性后,我开始搭建开发闭环,把代码怎么组织、实验怎么推进、结果怎么判断这三件事做顺。
我把基础框架拆成两层:
- 代码框架:把提交接口、路由、fallback、epilogue 这些骨架搭对。
- 实验框架:把评分、监控、profile、日志、版本快照这套闭环搭起来。
3.1 搭代码框架,再填快路径
Host 侧骨架大致如下:
extern "C" void gemm_optimized(
half const* d_A,
half const* d_B,
half* d_C,
int M, int N, int K) {
if (!fast_ok32(M, N, K)) {
gemm_fallback_kernel<<<grid, block>>>(d_A, d_B, d_C, M , N, K);
musaDeviceSynchronize();
return;
}
if (M == 8192 && N == 8192 && K == 16384) {
gemm_wmma32_wave128x8n128_k16384_k32_reuseB_kernel<<<grid, block, 0, s_epi_stream0>>>(...);
musaEventRecord(s_evt_gemm_done, s_epi_stream0);
musaStreamWaitEvent(s_epi_stream1, s_evt_gemm_done, 0);
fp32_to_fp16_8192_half2x8_oneshot_kernel<<<..., s_epi_stream0>>>(...);
fp32_to_fp16_8192_half2x8_oneshot_kernel<<<..., s_epi_stream1>>>(...);
musaDeviceSynchronize();
return;
}
// 其它尺寸按 x1/x2/x4/x8 路由
...
}
这套代码框架的作用是:
- 避免某个专用 kernel 失败导致整份代码不可用。
- 将优化集中在固定形状的专用路径,不会反复污染通用逻辑。
- 将 correctness 和 performance 解耦分析。
3.2 搭实验框架
每轮实验都固定成下面 6 步:
- 备份当前代码版本。
- 运行
grader,看 correctness 和 GFLOPS。 - 运行
mthreads-gmi,看 utilization / clock / temperature。 - 运行
msys profile,看 API 和 kernel 时间线。 - 保存报告、日志、profile。
- 给这次实验贴标签,方便后面回滚和对比。
这个流程帮我省掉了很多无效尝试。因为当版本一多,最怕的不是性能没涨,而是根本说不清为什么涨、为什么跌。
4. 环境搭建
比赛镜像里提供了评分脚本、编译环境和 profiling 工具。对参赛过程最有用的命令有:
# 相对评分
python3 grader.py <submission>.mu --verbose
# 绝对性能展示
python3 grader_absolute.py <submission>.mu
# 设备状态监控
mthreads-gmi -q -d UTILIZATION,CLOCK,TEMPERATURE -l 1
# 时间线分析
msys profile -d 0 -t musa,osrt --gpu-metrics-set=1 --duration=20 \
-o gemm.msys-rep \
python3 grader.py <submission>.mu --report gemm.json
如果要单独手工编译某个 probe 或最小复现版本,可以跟比赛环境保持一致,固定使用:
mcc --offload-arch=mp_22 -O3 -march=native -lmusart -lmublas