torchada:一行代码让 PyTorch CUDA 项目运行在摩尔线程 GPU 上
在开源社区中,SGLang、vLLM、ComfyUI、LightLLM 等热门 AI 项目都是基于 PyTorch 开发的,它们通过 torch.cuda.* API 来管理 GPU 设备、分配显存、调度计算。当开发者希望将这些项目迁移到摩尔线程 GPU 上运行时,传统做法是将代码中每一处 cuda 引用改为 musa —— 这对于动辄数十万行代码的大型项目来说,工作量巨大且极易出错。
torchada 为此而生。只需在代码入口添加一行 import torchada,它就能在运行时自动将所有 CUDA API 调用转换为 MUSA 等效调用, 让现有代码零改动运行在摩尔线程 GPU 上。
本文将详细介绍 torchada 的设计理念、核心特性、使用方式、性能表现,以及已成功使用 torchada 适配摩尔线程 GPU 的开源项目。
什么是 torchada?

torchada 是一个适配器包,让 torch_musa(摩尔线程 GPU 的 PyTorch 支持)兼容标准的 PyTorch CUDA API。它的核心理念是:零代码改动,即可将 CUDA 项目迁移到摩尔线程 GPU 上运行。
为什么需要 torchada?
以 SGLang 为例,其代码库中有大量 torch.cuda.* 的调用,从设备管理到显存分配,从混合精度训练到分布式通信,几乎无处不在。如果要逐一将这些调用改为 torch.musa.*,不仅需要深入理解项目的每个模块,还要确保修改后的代码在功能上完全等价 —— 类似的情况在 vLLM、ComfyUI、LightLLM 等项目中同样存在。
torchada 从根本上解决了这个问题:它在 Python 运行时层面自动拦截所有 torch.cuda.* API 调用,并将其透明地转换为对应的 torch.musa.* 调用。开发者无需修改任何业务代码,只需在入口处添加一行 import torchada 即可。
核心特性
- 零代码改动:只需
import torchada,现有的 CUDA 代码即可在摩尔线程 GPU 上运行。 - 全面的 API 覆盖:支持设备操作、显存管理、同步、混合精度训练、CUDA Graphs、分布式训练、torch.compile、C++ 扩展等。
- 极低的运行时开销:通过激进的缓存策略,频繁调用的操作开销低于 200 纳秒。
- 自动符号映射:构建 C++ 扩展时,自动将 CUDA 符号(如
cudaMalloc、cudaStream_t)转换为 MUSA 等效符号(如musaMalloc、musaStream_t),包含 380+ 条映射规则。 - ctypes 库加载支持:自动转换 ctypes 加载的动态库中的 CUDA 函数名为 MUSA 函数名。
支持的功能一览
| 功能 | 示例 |
|---|---|
| 设备操作 | tensor.cuda(), model.cuda(), torch.device("cuda") |
| 显存管理 | torch.cuda.memory_allocated(), empty_cache() |
| 同步 | torch.cuda.synchronize(), Stream, Event |
| 混合精度 | torch.cuda.amp.autocast(), GradScaler() |
| CUDA Graphs | torch.cuda.CUDAGraph, torch.cuda.graph() |
| CUDA 运行时 | torch.cuda.cudart() → 使用 MUSA 运行时 |
| 性能分析 | ProfilerActivity.CUDA → 使用 PrivateUse1 |
| 自定义算子 | Library.impl(..., "CUDA") → 使用 PrivateUse1 |
| 分布式训练 | dist.init_process_group(backend='nccl') → 使用 MCCL |
| torch.compile | torch.compile(model) 支持所有后端 |
| C++ 扩展 | CUDAExtension, BuildExtension, load() |
| ctypes 库加载 | ctypes.CDLL 使用 CUDA 函数名 → 自动转换为 MUSA |
安装与快速开始
前置条件
- torch_musa:必须安装 torch_musa(提供 PyTorch 的 MUSA 支持)
- 摩尔线程 GPU:已安装正确驱动的摩尔线程 GPU
安装 torchada
pip install torchada
# 或从源码安装
git clone https://github.com/MooreThreads/torchada.git
cd torchada
pip install -e .
快速开始
只需在代码顶部添加一行 import torchada,你现有的 CUDA 代码即可在摩尔线程 GPU 上运行:
import torchada # ← 在文件顶部添加这一行
import torch
# 你现有的 CUDA 代码无需改动:
x = torch.randn(10, 10).cuda()
print(torch.cuda.device_count())
torch.cuda.synchronize()
就这么简单!所有 torch.cuda.* API 会自动重定向到 torch.musa.*。
使用示例
下面通过几个典型场景来展示 torchada 的使用方式。在所有场景中,你会发现代码和在 NVIDIA GPU 上编写时完全一样 —— 唯一的区别是在入口处多了一行 import torchada。
训练场景:混合精度 + 分布式
在实际的模型训练中,混合精度和分布式是两个最常用的特性。torchada 对二者都提供了完整的支持:
import torchada
import torch
import torch.distributed as dist
# 分布式初始化:'nccl' 会自动映射到摩尔线程的 'mccl'
dist.init_process_group(backend='nccl')
model = MyModel().cuda()
scaler = torch.cuda.amp.GradScaler()
# 混合精度训练:autocast 和 GradScaler 正常工作
with torch.cuda.amp.autocast():
output = model(data.cuda())
loss = criterion(output, target.cuda())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
推理优化:CUDA Graphs + torch.compile
对于追求极致推理性能的场景,CUDA Graphs 和 torch.compile 同样可以无缝使用:
import torchada
import torch
# 使用 torch.compile 加速
compiled_model = torch.compile(model.cuda(), backend='inductor')
# 使用 CUDA Graphs 减少内核启动开销
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g):
y = compiled_model(x)
扩展开发:C++ 扩展与自定义算子
很多高性能项目会使用 C++ 扩展来实现自定义 CUDA 内核。torchada 在构建阶段自动将 CUDA 符号(如 cudaMalloc、cudaStream_t、at::cuda)转换为 MUSA 等效符号,覆盖 380+ 条映射规则:
import torchada # 必须在 torch.utils.cpp_extension 之前导入
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# 标准 CUDAExtension 可直接使用 — torchada 自动处理 CUDA→MUSA 转换
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
通过 torch.library 注册的自定义算子也能正常工作:
import torchada
import torch
my_lib = torch.library.Library("my_lib", "DEF")
my_lib.define("my_op(Tensor x) -> Tensor")
my_lib.impl("my_op", my_func, "CUDA") # 在 MUSA 上也能工作
调试与分析:Profiler + ctypes
在性能调优阶段,ProfilerActivity.CUDA 和 ctypes 库加载同样被 torchada 覆盖:
import torchada
import torch
# 性能分析:ProfilerActivity.CUDA 在 MUSA 上也能工作
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
) as prof:
model(x)
import torchada
import ctypes
# ctypes 加载动态库时,CUDA 函数名自动转换为 MUSA 函数名
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaMalloc # 自动转换为 musaMalloc
平台检测
当你需要在代码中区分当前运行在 NVIDIA GPU 还是摩尔线程 GPU 上时,torchada 提供了便捷的 API:
import torchada
from torchada import detect_platform, Platform
platform = detect_platform()
if platform == Platform.MUSA:
print("在摩尔线程 GPU 上运行")
elif platform == Platform.CUDA:
print("在 NVIDIA GPU 上运行")
性能
torchada 使用激进的缓存策略来最小化运行时开销。所有频繁调用的操作都在 200 纳秒内完成:
| 操作 | 开销 |
|---|---|
torch.cuda.device_count() | ~140ns |
torch.cuda.Stream(属性访问) | ~130ns |
torch.cuda.Event(属性访问) | ~130ns |
_translate_device('cuda') | ~140ns |
torch.backends.cuda.is_built() | ~155ns |
作为对比,典型的 GPU 内核启动耗时 5,000-20,000ns。torchada 的补丁开销对于实际应用来说可以忽略不计。
使用注意事项
设备类型字符串比较
由于 torchada 将 cuda 设备映射为 musa,直接进行设备类型字符串比较时需要注意:
device = torch.device("cuda:0") # 在 MUSA 上实际变成 musa:0
device.type == "cuda" # 返回 False
推荐使用 torchada.is_gpu_device() 来进行设备类型判断:
import torchada
if torchada.is_gpu_device(device): # 在 CUDA 和 MUSA 上都能正确工作
...
# 或者: device.type in ("cuda", "musa")
关于 torch.cuda.is_available()
torch.cuda.is_available() 是 torchada 有意不进行重定向的 API —— 在 MUSA 环境下它会返回 False。这是一个经过深思熟虑的设计 决策:保留此 API 的原始语义,使得开发者可以准确判断当前是否处于 NVIDIA CUDA 环境,从而实现正确的平台检测逻辑。如果需要通用的 GPU 可用性检查,推荐使用以下模式:
def has_gpu():
import torch
return torch.cuda.is_available() or (hasattr(torch, 'musa') and torch.musa.is_available())
更多细节请参见 torchada 迁移指南。
将 torchada 集成到你的项目
将 torchada 集成到现有项目只需以下几个步骤:
步骤 1:添加依赖
# pyproject.toml 或 requirements.txt
torchada>=0.1.35
步骤 2:条件导入
# 在应用入口处
def is_musa():
import torch
return hasattr(torch.version, "musa") and torch.version.musa is not None
if is_musa():
import torchada # noqa: F401
# 其余代码正常使用 torch.cuda.*
步骤 3:扩展功能标志(如适用)
# 在 GPU 能力检查中包含 MUSA
if is_nvidia() or is_musa():
ENABLE_FLASH_ATTENTION = True
步骤 4:修复设备类型检查(如适用)
# 不要用: device.type == "cuda"
# 改用: device.type in ("cuda", "musa")
# 或者: torchada.is_gpu_device(device)
已成功适配的开源项目
以下是已经使用 torchada 成功适配摩尔线程 GPU 的开源项目,它们涵盖了模型服务和图像/视频生成等多个领域:
| 项目 | 类别 | 状态 |
|---|---|---|
| Xinference | 模型服务 | ✅ 已合并 |
| LightLLM | 模型服务 | ✅ 已合并 |
| LightX2V | 图像/视频生成 | ✅ 已合并 |
| 赤兔 (Chitu) | 模型服务 | ✅ 已合并 |
| SGLang | 模型服务 | 进行中(Issue #16565) |
| ComfyUI | 图像/视频生成 | 进行中(PR #11618) |
这些项目的成功适 配证明了 torchada 在实际生产环境中的可靠性和通用性。无论是大语言模型推理服务(如 Xinference、LightLLM、赤兔),还是图像/视频生成工具(如 LightX2V、ComfyUI),torchada 都能帮助这些项目以最小的代码改动在摩尔线程 GPU 上运行。
实战案例:ComfyUI 适配摩尔线程 GPU
ComfyUI 是目前最流行的 Stable Diffusion 工作流工具之一,拥有超过 10 万 GitHub Star。下面以 PR #11618 为例,展示如何使用 torchada 将一个大型 CUDA 项目适配到摩尔线程 GPU 上 —— 整个改动非常直观。
第一步:添加 torchada 依赖
在 requirements.txt 中添加一行:
torchada>=0.1.35
第二步:导入 torchada 并检测 MUSA 平台
在核心模块 model_management.py 中,添加 MUSA 平台检测:
try:
import torchada # noqa: F401
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except:
musa_available = False
def is_musa():
global musa_available
return musa_available
这就是 torchada 的核心用法:导入后,所有 torch.cuda.* API 自动可用。is_musa() 函数则用于需要针对 MUSA 平台做特殊处理的场景。
第三步:扩展 GPU 能力标志
ComfyUI 中有许多针对 NVIDIA GPU 的功能开关,只需在条件判断中加入 is_musa() 即可让摩尔线程 GPU 享受同样的优化路径:
# PyTorch 原生注意力机制
if is_nvidia() or is_musa():
if torch_version_numeric[0] >= 2:
ENABLE_PYTORCH_ATTENTION = True
# 异步权重卸载
if is_nvidia() or is_amd() or is_musa():
NUM_STREAMS = 2
# 固定内存
if is_nvidia() or is_amd() or is_musa():
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * ...
# FP16/BF16 支持
if is_musa():
return True # 摩尔线程 GPU 支持 FP16 和 BF16
第四步:设置设备可见性环境变量
在入口文件 main.py 中,添加 MUSA_VISIBLE_DEVICES 环境变量的设置:
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['MUSA_VISIBLE_DEVICES'] = str(args.cuda_device) # ← 新增
效果
完成以上改动后,ComfyUI 即可在摩尔线程 GPU 上运行。从 PR 中的测试日志可以看到,ComfyUI 成功识别了 81838 MB 显存的摩尔线程 GPU,并完成了 Flux 图像生成和 Wan2.1 文生视频任务:
Total VRAM 81838 MB, total RAM 2063756 MB
Device: musa
Prompt executed in 55.36 seconds
整个适配过程只修改了 3 个文件,核心改动不到 50 行代码 —— 这正是 torchada 的价值所在。