MTT S4000 张量核心 GEMM 优化的推导、编译器调优与异步边界探索
1. GEMM背景
矩阵乘法是深度学习中最常用且最耗时的算法之一。在神经网络的训练和推理中,矩阵乘法作为基础的算子会被高频次反复调用。因此,在GPU上实现高效的矩阵乘法对于深度学习任务的性能至关重要。
一般而言,在GPU上实现矩阵乘法主要有两种途径。一是使用通用计算单元完成矩阵乘法运算,二是使用专用的张量计算核心完成矩阵乘法运算。具体而言,在英伟达最早推出张量计算核心之前,矩阵乘法通常是在通用计算单元上完成的,此种实现方法不依赖芯片上额外的张量核心,但是计算效率较为一般。而在张量核心问世之后,我们会更倾向将大规模矩阵乘法运算放在更高效的张量计算核心上完成。因此在本文中,我们将使用MUSA SDK中的WMMA库调用张量计算核心完成8192×8192×16384超大规模的矩阵乘法运算。
2. 摩尔线程s4000显卡硬件架构解读
截至目前(2026/3/16),摩尔线程显卡目前一共有4款架构,分别是MP_10、MP_21、MP_22、MP_31,本次使用的s4000属于其中的MP_22架构。
如图1所示,S4000一共拥有8个MUSA处理器簇(MUSA Processor Cluster,MPC),每个MPC包含4个MUSA 处理器执行引擎(MUSA Processor eXecution engine, MPX),每个MPX包含2个MUSA处理器(MUSA Processor, MP),即S4000一共拥有8×4×2=64个MP。值得注意的是,S4000拥有1MB大小由2个MPX即4个MP所共享的L2缓存和全局共享24MB大小的L3缓存。

图1. S4000内核架构图
目前,摩尔线程所有GPU的 warp size 均为128,即一个线程束中存在128个线程,是传统英伟达显卡的4倍,具有更宽的执行宽度。
S4000的 local memory 拥有32个 bank,一次内存事务产生的读写的数据量最大是32 float即64 half。因此在读写local memory时,我们需要保证一个warp中不同的thread的地址要命中在连续且不同的32个bank中,从而获得最大的读写速度。
此外,S4000每一个MP单元拥有131072个32 bit 寄存器,73728 byte 本地内存,允许最多被分配6144个 thread 即6144/128=48个warp。每一个block最多允许使用1024个thread,而每一个thread最多拥有256个寄存器。
注意,只有一个block为512个thread的时候,每个thread才能拥有256个寄存器,如果有1024个thread,那么每个thread只能拥有128个寄存器。
最后,通过查阅MUSA SDK的 wmma 库,我们得知,MP_22 架构的张量核心在 FP16 数值精度上主要支持16×8×8、16×16×8、16×16×16和32×32×16这四种尺寸的MMA。
至此,我们获得了我们优化HGEMM算子所需的全部硬件规格,留意上述硬件指标对于后续的优化很重要。
3. S4000上HGEMM算子具体优化实践
先简单捋一下整体代码的思路,看看每个block都需要做什么。
- 先声明寄存器(Frag A、B、C)和本地内存(Local memory A、B、C),并将Frag C置零。
- 读取全局内存(Global memory A、B),存入本地内存(Local memory A、B)。
- 读取本地内存(Local memory A、B),存入寄存器(Frag A、B)。
- 调用 wmma 库,完成 mma 累加计算,结果落回 Frag C(使用单精度做累加保证计算精度,因此这里的Frag C的数值类型是单精度)。
- 循环重复Step2、3、4。
- 将 Frag C 落回本地内存 Local memory C。
- 将 Local memory C 的数值类型转换为半精度,落回全局内存。
完成上述基础的代码编写,我们可以获得大约10TFLOPS的基线计算性能。GEMM一般被认为是 compute bound 的任务。也就是说,经过良好设计后的程序是可以完全吃满计算核心而不受显存带宽的影响。而为了实现此目标,我们往往会从以下8个角度开展HGEMM算子的优化。
3.1 Block分块和warp分块
这是最基础也是最重要的性能优化,分块的好坏对性能的影响立竿见影,并且直接关系到最终算子的性能上限。
在此项优化中,我们需要确认每个 block 和 warp 需要完成多大规模的矩阵乘法运算,以及每个 MP 单元分配多少个 block。具体而言,每个 block 完成的矩阵乘法运算的规模越 大,在本地内存中共享的数据就会越多,从而能够提高计算访存比;每个 warp 需要完成的矩阵乘法运算规模越小,block 内 warp 数越多,则能够提高运算并行度;每个 MP 单元能够分配更多的 block,则有更多的满足发射的 warp 可供调度,掩盖读写延迟。
因此,这过程可谓是戴着镣铐跳舞。既要需要寻求一个尺寸上的良好平衡,又需要检查程序是否满足各项约束条件。例如寄存器、本地内存的用量不能超过MP硬件支持的上限,单个 Block 不能使用超过1024个 thread 等等。
最直接的方法是列个大表格,把所有可能的情况全部列举出来做分析。一个简化示意版如表1所示。该表固定了 MMA 子块大小为16×16×16,MMA per warp为8即每个 warp 做8次 MMA 计算,通过修改每个 MP 上贮存的 block 数量分析寄存器和本地内存的消耗数量。
表1. 每个warp做8次16×16×16 MMA的情况下可行性分析
| MMA_M | MMA_N | MMA_K | MMA per warp | Block per MP | Thread per block | Warp per block | Reg per block | Local mem per block | 可行性? |
|---|---|---|---|---|---|---|---|---|---|
| 16 | 16 | 16 | 8 | 1 | 1024 | 8 | 22528 | 8192 | TRUE |
| 16 | 16 | 16 | 8 | 2 | 1024 | 8 | 22528 | 8192 | TRUE |
| 16 | 16 | 16 | 8 | 3 | 1024 | 8 | 22528 | 8192 | TRUE |
| 16 | 16 | 16 | 8 | 4 | 1024 | 8 | 22528 | 8192 | TRUE |
| 16 | 16 | 16 | 8 | 5 | 1024 | 8 | 22528 | 8192 | FALSE |
| 16 | 16 | 16 | 8 | 6 | 1024 | 8 | 22528 | 8192 | FALSE |
| 16 | 16 | 16 | 8 | 12 | 512 | 4 | 14336 | 6144 | FALSE |
| 16 | 16 | 16 | 8 | 24 | 256 | 2 | 10240 | 4096 | FALSE |
| 16 | 16 | 16 | 8 | 48 | 128 | 1 | 8192 | 3072 | FALSE |
首先,每个block上的thread数量不能超过1024,每个MP上的thread数量不能超过6144,根据这两个约束条件我们可以得到每个block在每种情况下应该分配多少个thread。
例如在每个MP分配6个block的情况下,我们还可以为每个block分配1024个thread,但是在为每个MP分配12个block的情况下,我们只能为每个block分配512个thread,否则会超出每个MP最多6144(=12×512)个thread的约束上限。
在确定输入参数后,便可以根据程序的初步设计完成寄存器和本地内存需求数计算。HGEMM中主要需要用到Frag A、B、C三块寄存器。其中无论数值精度是FP16还是FP32,每个数字都需要占用一个完整寄存器。以表1第4行标蓝的情况举例,如图2所示,一个warp内做8次16×16×16大小的MMA计算,那么每一个warp就需要处理(16×4)×(16×2)大小的矩阵。对于Frag C来说就需要64×32=2048大小的寄存器来承接结果。
Frag A和B的情况则有些许不同。从理想情况上来说,我们希望数据只从本地内存中取一次,并且使得存放在寄存器中的数据最少。而最符合理想情况的做法是:
- 首先取B1、A1,计算C1_1;
- 接着取B2,计算C1_2;
- 再释放A1,取A2,计算B2_1和B2_2;
- 然后是释放A2,取A3,计算A3_1和A3_2
以此类推。因此对于Frag A来说需要16×16/2=128大小的寄存器,对于Frag B来说需要16×16×2/2=256大小的寄存器。
最后经过计算得知每个block中的8个warp一共需要(2048+256+512)×8=22528个寄存器用于Frag A、B和C存储矩阵数值。此外,程序运行过程中还需要一些寄存器存储地址、中间变量等信息,寄存器的额外使用量会受到编译器影响,这部分寄存器数量我们无法做精确计算,但可以做简单粗略地估计。每个MP单元拥有131072个寄存器,131072/22528≈5.8,所以在寄存器约束下考虑,放下5个block是没问题的。

图2. Warp内8次MMA计算示意图
本地内存占用计算也是类似的。HGEMM中一共用到了3块本地内存,分别是local memory A、B、C。
我们首先注意到一点,local memory A、B和local memory C可以分时复用,即本地内存同一时刻只会存在local memory A、B或者local memory C,不会同时占用本地内存。所以本地内存的用量将会是max( A+B,C)。在编写程序的时候我们可以使用联合体union告知编译器来保证这点,写作:
__shared__ union {
struct {half a[?]; half b[?];} ab;
float c[?];
} local_memory;
接着,我们分析local memory C所需的内存空间。由于目前还不支持在寄存器中直接把Frag C数值精度转换为 half ,因此local memory C的数值精度需要保留在 float。但是如果完整保存整个计算结果所需的空间过大。如图3所示,一个block 完成的计算尺寸为128×128=16384,也就是说local memory C需要存储16384个float,这样的大小显然是不可接受的。实际上,我们并不需要一次性把所有结果全部写到local memory中,而是只暂存一部分(图3中的橙色部分),写回全局内存之后再从寄存器中存下一部分,这样在不损失性能的情况下只需要原来1/8的存储空间即16384/8=2048个float。
Local memory A、B的情况也是类似的,在不考虑双、多缓冲的情况下,A和B每次各自需要存储16×128=2048个 half。而表格中的计算结果是在考虑了一次性取4次的情况。

图3. Block中8个warp分布示意图
除了寄存器和本地内存用量约束估计,我们还能计算出寄存器/本地内存浪费数、计算访存比等能辅助我们选择的参数。
值得注意的是,上述计算中存在一些人为的设定,例如双缓冲或是多缓冲、padding之类的不同策略以及对缓存大小的各种考量都会影响最终的计算结果,还需做更多的实验才能够较好地确定参数组合。类似的,我们还可以绘制出其他情况比如每个warp做4次32×32×16 MMA的表格,如表2所示,此处不再过多赘述。
表2. 每个warp做4次32×32×16 MMA的情况下可行性分析
| MMA_M | MMA_N | MMA_K | MMA per warp | Block per MP | Thread per block | Warp per block | Reg per block | Local mem per block | 可行性? |
|---|---|---|---|---|---|---|---|---|---|
| 32 | 32 | 16 | 4 | 1 | 1024 | 8 | 38912 | 16384 | TRUE |
| 32 | 32 | 16 | 4 | 2 | 1024 | 8 | 38912 | 16384 | TRUE |
| 32 | 32 | 16 | 4 | 3 | 1024 | 8 | 38912 | 16384 | FALSE |
| 32 | 32 | 16 | 4 | 4 | 1024 | 8 | 38912 | 16384 | FALSE |
| 32 | 32 | 16 | 4 | 5 | 1024 | 8 | 38912 | 16384 | FALSE |
| 32 | 32 | 16 | 4 | 6 | 1024 | 8 | 38912 | 16384 | FALSE |
| 32 | 32 | 16 | 4 | 12 | 512 | 4 | 22528 | 8192 | FALSE |
| 32 | 32 | 16 | 4 | 24 | 256 | 2 | 14336 | 4096 | FALSE |
| 32 | 32 | 16 | 4 | 48 | 128 | 1 | 10240 | 2048 | FALSE |
经过作者实际测试,本文最后选择的是表3所示的参数组合,即每个MP贮存1个block,每个block有8个warp,每个warp做8次32×32×16 MMA。但是作者认为该参数组合并不是最优秀的选择,只是受限于作者的水平选出的一个局部最优,希望读者可以根据自己判断多多实践尝试。
表3. 最终选择的参数组合
| MMA_M | MMA_N | MMA_K | MMA per warp | Block per MP | Thread per block | Warp per block | Reg per block | Local mem per block | 可行性? |
|---|---|---|---|---|---|---|---|---|---|
| 32 | 32 | 16 | 8 | 1 | 1024 | 8 | 77824 | 16384 | TRUE |
3.2 向量化读写(40TFLOPS)
向量化读写的核心思想是用连续的、大块的批量读写代替间断的、琐碎的读写。这是一个很自然的想法,因为从硬件的角度来说,连续的读写更有利于全局内存的预读取,大块的读写则有利于缓存行的命中。
具体来说,如图4所示,从全局内存中一次读取256×16=4096个half,而一个block中存在1024个thread,因此每个thread需要搬运4096/1024=4个half。我们可以把程序写作:
#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
#define LDST64BITS_CONST(value) (reinterpret_cast<const float2*>(&(value))[0])
LDST64BITS(local_memory.ab.a[local_a_addr]) = (LDST64BITS_CONST(A[gmem_a_addr]));
这样便能够一次性读取4个half,1024个thread一次性就能将所有4096个half数据抓取回来。经过测试,实际上不定义宏,直接使用memcpy也是可行的。
此处,还需要注意的是矩阵在全局内存中的主序问题。本文使用的全部是列主序。图4中橙色的渐变部分展示了列主序矩阵在内存中的连续排布,即颜色连续代表内存空间中排布连续。可以注意到列主序的排布是利好A矩阵的读写的,因为内存空间中连续的数据块大。反观B矩阵,每次读取连续的部分都只有16个half,无法利用GPU缓存行的大小,每次读取都有部分数 据是多余连带上的,这个问题在下文中也会提及。所以GEMM中,我们会更期望B矩阵会是行主序矩阵,会得到更好的性能表现,但如果企图先做主序调整再进行计算其实是更得不偿失的,也有悖算子融合的思想。

图4. 单次读取Local memory A、B示意图
在完成前面介绍的这两步之后,我们可以得到40TFLOPS左右的性能,相比基线结果10TFLOPS相比已经获得了很大提升。
3.3 Bank conflict消解(42TFLOPS)
如果说向量化读取是为了提高全局内存读写效率,那么Bank conflict消解就是为了提高本地内存读写效率。S4000的本地内存分为32个可以并行访问的bank,每个bank可以提供32 bit 数据的读写,意味着如果能总是同时命中32个 bank,那么一次性最多能读写32×32=1024 bit 即64个half,就能提高本地内存的读写效率。
通常来说,Bank conflict消解主要有两种方式:
- 一种是padding,通过在每一行矩阵末端填充内存实现偏移使得单次读写命中不同的 bank;
- 一种是XOR,在读写本地内存之前就将数据错开放置,使得后续操作能命中全部的 bank。
但是截至目前,MUSA SDK只提供WMMA库,在从本地内存读取到寄存器的时候没法细粒度地控制读取规则。此外,在本地内存充足的情况下,也没有必要使用XOR用复杂的地址计算换取空间,所以在这里我们使用第一种padding的方法来消除bank conflict。
首先分析在什么地方涉及到了本地内存的读写。
- 全局内存A、B到本地内存A、B有一次写操作。但是这个操作过程对于本地内存的写入操作总是打满了所有bank,吃满了全部的本地内存带宽,所以不存在bank conflict的问题。
- 本地内存A、B到寄存器A、B有一次读操作。这里会存在问题,一会做具体分析。
- 寄存器C落回本地内存C有一次写操作。首先,寄存器C的数值精度是float即每个数字正好对应1个bank,对于16×16的子块来说,由于单次连续写入大小只有16个float,打不满全部的bank,会存在bank conflict问题。但是对于使用32×32的子块来说,32个float总是能打满全部的bank,因此不存在bank conflict问题。
- 最后本地内存C写回全局内存C有一次读操作。和1的情况一样,这个操作也是打满了所有的bank,不存在bank conflict的问题。
综上所述,我们只有在情况2中会遇到 bank conflict。如图5所示,在列主序的默认情况下,大小为32×16的Fragment A在读取Local memory A时总是会在Bank0-15(绿色)中进行请求,因此bank16-31(橙色)始终处于空闲状态。
而对于Fragment B读取Local memory B的过程来说,由于数据在Local memory B中总是连续排布的,因此始终能够用满全部的bank,没有bank conflict,不需要额外处理。

图5. bank conflict分析
为了解决Local memory A的 bank conflict 问题,我们可以采用如图6所示的padding方法,通过往Local memory A的列末端填充占位的空数组,我们可以让bank服务的位置产生32个half的偏移。这样之后,Fragment A在读取Local memory A的过程中就总是能同时命 中全部的32个bank。实际测试下,大约能够获得2TFLOPS的提升,HGEMM的性能也来到了42TFLOPS。

图6. Padding解决bank conflict
3.4 Block swizzling(45TFLOPS)
在前面两部分,我们完成了对全局内存和本地内存读写的优化工作,但还未对针对缓存做任何处理。事实上缓存的命中效率也会极大影响HGEMM算法的工作效率。
那么如何提高缓存的命中效率呢?一言以蔽之:增加时空局部性,即共享缓存的不同MP单元要处理相邻的block块,它们之间会有公用的A、B矩阵数据可以共享。我们虽然无法手动指定让某个MP去做某个block计算,但是我们可以间接地通过发布连续相邻的block计算任务来实现这一点。
如图5所示,默认情况下编译器会按照灰色的箭头来分发block任务给MP,即以grid.x方向遍历到头,然后grid.y增1,再遍历grid.x,此种循环往复的方式完成block任务发布。但是这种发布方式存在空间局部性不强的问题,尤其是在grid.x到头之后的一个猛回头,完全破坏了空间局部性。当然,也可以尝试S型走线,获得一定的空间局部性。但实际效果最好的是swizzling这样的走线方式,空间局部性最好。
一个实现了block swizzling,可供参考的代码:
主机端:
const int N_SWIZZLE = (N + (512) - 1) / (512);
dim3 grid((N/BN)/N_SWIZZLE, M/BM, N_SWIZZLE);
设备端:
const int bx = blockIdx.z * gridDim.x + blockIdx.x;
const int by = blockIdx.y + (blockIdx.z & 1) * (gridDim.y - 1 - 2 * blockIdx.y);
其中M=N=8192,BM=BN=256,N_SWIZZLE的大小是超参数可以手动调节。实际测试下,Block swizzling策略让性能提升了大约3TFLOPS,HGEMM性能来到了45TFLOPS。
顺带一提,此处可能有读者选择使用if判断来分类讨论by的自增方向。这种整个warp走同一分支并不会造成性能影响,但是如果涉及不同的thread需要执行不同的分支就会产生分支发散,MP需要完整执行所有分支来获得结果,导致效率低下。在无法避免的情况下,可以考虑使用上面的这种“谓词执行”的写法,来获得性能提升。

图7. 正常遍历(灰色)、S型遍历(绿色) 和Block swizzling(橙色)示意图