news 2026/7/3 5:18:42

Triton Puzzles(Demo1-4)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Triton Puzzles(Demo1-4)

Triton Puzzles

之前做tilelang puzzles的时候,发现readme里提到是仿照triton puzzles的,但当时感觉triton没有学的必要,就没做

最近发现triton的设计思想和tilelang差异很大,感觉可以开拓一下视野,就找到这个https://github.com/SiriusNEO/Triton-Puzzles-Lite项目看看,这是改进过的轻量版,不是原版triton puzzles,题目内容没变,只是减少了依赖,原本的可视化和jupyter notebook都去掉了,就在.py文件运行,并且附上了作者写的答案,可以对比学习。

环境需要,别的可能也行,但是作者建议这个,这个是肯定可以跑通的,高版本可能报错。

pipinstalltorch==2.5.0# Check triton version: triton==3.1.0

安装时,如果能访问外网,实测最快的是直接用pytorch库,其他源都可能会把torch下载限速(因为下载的人太多了,还大,可能把带宽占满了),然后注意这里的cu124,根据显卡的驱动版本来安装,可以先检查cuda驱动版本。然后选一个低于驱动版本的torch whl,这种库都是可以向下兼容,但不能向上兼容。

python3-mpipinstalltorch==2.5.0 --index-url https://download.pytorch.org/whl/cu124 --no-cache-dir--isolated

推荐使用<2.0.0的numpy,结果正确性验证时会用numpy,check脚本用了低版本接口,版本高了会出错。

python3-mpipinstallnumpy==1.26.4--isolated

运行时设置环境变量,1表示用cpu模式py解释器运行,0则是gpu模式。gpu模式由于显卡版本不同可能出现各种bug,推荐先cpu模式跑通,这也是原始triton puzzles的推荐运行方式。当前仓库的答案,gpu模式下case 11会运行出错。

TRITON_INTERPRET=1python3 puzzles.py-a

最后的参数部分

-a#运行全部puzzles-px#运行第x个puzzle-i#运行四个demo-h#显示帮助文档

clone下来可以先跑以下指令,验证所有答案cpu模式下是不是都能跑通,能的话说明基础环境配置没问题。

TRITON_INTERPRET=1python3 puzzles_ans.py-a

Triton简介

Triton 是由 OpenAI 开源的一种专为深度学习加速设计的编程语言和编译器。

如果你写过 CUDA,你可能会觉得它太底层、开发周期太长;如果你只用 PyTorch,你可能会发现很多自定义的算子(比如各种新型的 Attention 或量化算子)无法获得极致的性能。Triton 的诞生,正是为了在“开发效率”与“极致性能”之间取得完美的平衡。

1. Triton 解决的核心痛点

在传统的 GPU 算子开发中,通常面临两极分化:

高端玩家(写 CUDA C++): 可以手动控制线程块、共享内存(Shared Memory)和寄存器,性能毁天灭地,但开发极其痛苦,且代码很难跨硬件(比如从 NVIDIA 转到 AMD)复用。

普通玩家(写 PyTorch/TensorFlow): 拼凑现有的 API(如 torch.relu + torch.matmul),开发极快,但会在显存中产生大量中间变量,造成频繁的显存读写(Memory Bound),浪费算力。

Triton 的核心思想是:让没有 CUDA 经验的深度学习研究员,也能用类似 Python 的语法,写出性能媲美甚至超越专家级 CUDA 的硬件加速算子。

2. Triton 的核心设计理念:基于块(Block-based)的编程

这也是 Triton 与 CUDA 最本质的区别:

CUDA 是“基于线程(Thread-based)”的: 你需要精确计算每个 Thread 的 ID,去算它该读哪一个具体的显存地址,还要手动处理线程之间的同步(__syncthreads())和数据共享。

Triton 是“基于块(Block-based)”的: 它把张量块(Block)作为一等公民(First-class citizen)。你不需要操心单个线程,而是直接对一个分块进行加载(tl.load)、计算(tl.dot)和存储(tl.store)。

并且,triton除了是基于数据块的编程,还是声明式编程,而不是CUDA的过程式编程,也就是你只用写要对这个数据块做什么,而不需要写怎么做,编译器会把做什么转化成怎么做的机器码。

3. 编译器在幕后做了什么?

既然写起来像 Python 一样简单,那极致的性能是怎么来的?这全靠 Triton 编译器。它会把你的 Python 风格代码编译成高效的机器码(通过 LLVM IR 到 PTX/AMDGCN),自动帮你做好以下最头疼的硬件优化:

自动内存合并(Memory Coalescing): 自动优化全局显存(Global Memory)的访问模式,确保带宽跑满。

自动管理共享内存(Shared Memory Allocation): 你不需要像写 CUDA 那样手动声明shared数组,编译器会自己决定什么时候把数据缓存在片上高速存储里。

指令流水线与排程(Instruction Scheduling): 自动隐藏访存延迟,让计算单元(Tensor Cores)和访存单元能够高效并发。

注意这里和tilelang的设计思想不同,并不会先映射到CUDA代码,再编译。而是自定义了TTIR(Triton IR),生成TTIR后,下一步就会映射到PTX、SaaS代码了,不会经过CUDA,也就是triton可以被视为一个独立的语言,有自己的编译路径,而不是CUDA语法糖。

4. 谁在用 Triton?

如今 Triton 已经成为大模型时代基础设施的绝对主力:

PyTorch 2.0+ 的核心: PyTorch 2.0 引入的重磅编译功能 TorchInductor,其后端默认就是将 PyTorch 代码自动生成为 Triton 内核,这也是其实现图编译加速的秘密武器。

FlashAttention: 著名的闪电注意力机制,其后续的很多高效变体和工程实现(如 FlashAttention-3)都大量采用了 Triton 进行快速迭代。

大模型推理加速: 比如 vLLM、DeepSpeed 以及各类轻量级量化插件,里面普遍包含大量用 Triton 编写的定制化算子(如上面我们聊到的量化 GEMM)。

如果你想深入 AI 芯片底层硬件加速,或者想为自己的大模型设计专属的奇门遁甲算子,Triton 是目前投产比(ROI)最高、最值得学习的技术。

Demos

demo 1

数据搬运是GPU编程中最核心的概念,第一个示例主要熟悉tl.load搬运数据

tl.load(ptr, mask)参数是两个张量,ptr是一个指针数组,表示数据搬运源地址,数组内每个指针对应一个要搬运的元素。mask是一个掩码数组,数据类型是bool,用0/1表示ptr数组中传入的每个指针,是否搬运。

需要额外引入mask的原因是triton里的所有张量(数据块)的大小都是二的幂次,如果我们想灵活搬运一个大小不对齐的张量时,比如大小5,可以传入一个刚好大于这个张量大小的指针数组,长度对齐2的幂次,然后用mask来约束搬运范围,比如mask就是[1,1,1,1,1,0,0,0],表示前五个位置利用指针地址搬运,后三个位置不进行操作。

需要注意的是,

  • 这里传入的x_ptr,已经不是torch tensor了,而是底层数据的首地址,类似c的数组首地址指针,这也是命名上带一个ptr的原因,因此我们传入指针ptr数组和mask,需要人为避免越界,如果x_ptr对应的tensor只有八个元素,那么就不能访问大于8的位置,否则会运行错误或者读到垃圾值。编译器不会阻止你,编译时的思路是类C的,允许你直接用指针寻址。
  • 如果指针数组大小超过tensor了,但是mask限制了读取范围,不会出问题,因为mask为0的位置,不会真的去读内存,而是直接返回一个值表示不操作,可以在tl.load(ptr, mask,0)操作时传入第三个参数,表示mask为0的位置填充什么值,如果不传入第三个参数,默认填充0

定义讲完了,来看这个算子的具体事项。range = tl.arange(0, 8)类似torch.arrange,生成一个公差为1的等差数列,左闭右开。

x = tl.load(x_ptr + range, range < 5, 0),这一行有很多看点。

  • x_ptr + range,这里的x_ptr本身是一个指针,也就是一个标量,但是range是刚才生成的数据块,两者相加,这里triton规定,遵循torch/numpy的广播规则,把标量广播到和张量一样的shape,再执行相加。也就是此时形成了一个
[x_ptr,x_ptr+1,x_ptr+2,...,x_ptr+7]

的指针数组,接下来会去这个数组内的位置搬运数据。

  • range < 5类似,5是一个标量,会广播到和range一样大,然后<操作会返回一个bool数组,用这个方式就构造了一个[1 1 1 1 1 0 0 0]的mask
  • x = tl.load(x_ptr + range, range < 5, 0)最后load返回的是一个triton数据块,需要把它复制给一个变量保存下来。

demo1[(1, 1, 1)](torch.ones(4, 3))最后是triton内核的启动方法,triton设计时DSL还没这么多,很多设计师对齐CUDA,比如这里(1, 1, 1)就是CUDA启动时传入的launch参数dim3,表示grid shape,或者说三个维度的block个数。

传递给函数的直接参数则在后面圆括号内,这里传入一个二维张量,(torch.ones(4, 3))。可能会好奇,这里传入的是二维张量,但kernel内看起来是把他当成一维数组用的?这也是类C设计带来的,CUDA编程时,多维数组不管几维,都是当成一维数组使用,用的时候再多次寻址实现多维数组的效果,triton继承了这一点,这个张量4*3=12个元素,在triton kernel内会看成一个长度12的连续内存。

r""" ## Introduction To begin with, we will only use `tl.load` and `tl.store` in order to build simple programs. """""" ### Demo 1 Here's an example of load. It takes an `arange` over the memory. By default the indexing of torch tensors with column, rows, depths or right-to-left. It also takes in a mask as the second argument. Mask is critically important because all shapes in Triton need to be powers of two. Expected Results: [0 1 2 3 4 5 6 7] [1. 1. 1. 1. 1. 0. 0. 0.] Explanation: tl.load(ptr, mask) tl.load use mask: [0 1 2 3 4 5 6 7] < 5 = [1 1 1 1 1 0 0 0] """@triton.jitdefdemo1(x_ptr):range=tl.arange(0,8)# print works in the interpreterprint(range)x=tl.load(x_ptr+range,range<5,0)print(x)defrun_demo1():print("Demo1 Output: ")demo1[(1,1,1)](torch.ones(4,3))print_end_line()"""

demo 2

仍然是load,只是这次需要load一个复杂一点的二维区域i < 4 and j < 3

那么用一个range mask就有点难做到了,可以用两个。

首先构造两个等差数列,一个对应行,一个对应列。然后给他们升维,类似torch.unsqueeze,弄完之后两个mask的shape分别是(8,1)(1,4)

i_range=tl.arange(0,8)[:,None]j_range=tl.arange(0,4)[None,:]

range = i_range * 4 + j_range让这两个mask做加法,遵循torch/numpy广播规则,会都先变成(8,4)再执行加法。并且加之前,先把行张量乘上每一行的元素个数,这样最后得到的结果,每个位置的值都等于,把这个张量展开到一维后这个位置的编号,可以用来构造mask数组了

(i_range < 4) & (j_range < 3)构造mask时可以把两个条件取and,这里重载了&的规则,不是py里的按位与,而是表示and。这样我们就限制了只拷贝i < 4 and j < 3的区域

""" ### Demo 2: You can also use this trick to read in a 2d array. Expected Results: [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15] [16 17 18 19] [20 21 22 23] [24 25 26 27] [28 29 30 31]] [[1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.]] Explanation: tl.load use mask: i < 4 and j < 3. """@triton.jitdefdemo2(x_ptr):i_range=tl.arange(0,8)[:,None]j_range=tl.arange(0,4)[None,:]range=i_range*4+j_range# print works in the interpreterprint(range)x=tl.load(x_ptr+range,(i_range<4)&(j_range<3),0)print(x)defrun_demo2():print("Demo2 Output: ")demo2[(1,1,1)](torch.ones(4,4))print_end_line()

demo 3

这节主要是学习tl.store写入操作,和读取tl.load一起构成了完整的数据搬运。

tl.store(ptr, value, mask)参数和tl.load类似,也是传入一个指针数组,一个mask,只不过这是个无返回值的函数,所以ptr就是目的地址,源则是value。ptr类似前面load的规则,类C的指针数组,手动寻址。但value是类似py张量,可以传入一个标量进行广播,也可以传入一个前面load进来的张量,不能传入和ptr类似的指针数组,也就是源不是给传指针,寻址,而是直接给出值。

一般的范式是,读取到一个张量,做想做的操作,然后再写入,也就是读取写入之间一定有一个张量来倒手。

x=tl.load(x_ptr,mask)tl.store(y_ptr,x,mask)

来看具体实现,z = tl.store(z_ptr + range, 10, range < 5)这里用z接受了返回值,其实是一个陷阱,tl.store无返回值,所以尝试print(z)会报错。想看结果,数据已经被写入z_ptr为首地址的张量了,在kernel内只有首地址指针,没有z_ptr对应的张量对象,看不了,必须从kernel里返回后host侧才能看。

""" ### Demo 3 The `tl.store` function is quite similar. It allows you to write to a tensor. Expected Results: tensor([[10., 10., 10.], [10., 10., 1.], [ 1., 1., 1.], [ 1., 1., 1.]]) Explanation: tl.store(ptr, value, mask) here range < 5 corresponds to the 2D-mask [[1. 1. 1.] [1. 1. 0.] [0. 0. 0.] [0. 0. 0.]] """@triton.jitdefdemo3(z_ptr):range=tl.arange(0,8)z=tl.store(z_ptr+range,10,range<5)defrun_demo3():print("Demo3 Output: ")z=torch.ones(4,3)demo3[(1,1,1)](z)print(z)print_end_line()"""

demo 4

前三个都是单线程的,但作为GPU编程当然可以根据数据块编号不同,做不同的操作,这节来看如何利用tl.program_id确定所在块号,然后执行不同操作。

tl.program_id(0)这里的0,1,2分别是取出这个数据块的三个维度编号,三个维度是我们启动内核时传入的,比如这里就是demo4[(3, 1, 1)](x),表示0维度长度3,另外1,2维度长度1,也就是有3 * 1 * 1 = 3个block。

x = torch.ones(2, 4, 4)传入的张量展平后有32个元素,想要搬运前20个。均分给三个block实现,考虑到每次搬运操作的长度都是二的幂次,最少的搬运方式是,每个block搬8个元素,前两个block都全搬,最后一个block只用搬前四个,设一个mask实现这一点。

kernel内,range = tl.arange(0, 8) + pid * 8实现了每个block搬运的位置不同,也就是根据block id进行偏移。每个都搬长度为8的区间,所以生成一个长度8的等差数列,然后累加上块偏移,就是这个块负责的地址范围

range < 20为了只搬前20个,增加一个mask限制,这个限制只会让最后一个block的mask是前四个1,后四个0,对前两个block无影响。

""" ### Demo 4 You can only load in relatively small `blocks` at a time in Triton. To work with larger tensors you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis with 3 blocks. Expected Results: Print for each [0] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [1] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [2] [1. 1. 1. 1. 0. 0. 0. 0.] Explanation: This program launch 3 blocks in parallel. For each block (pid=0, 1, 2), it loads 8 elements. Note that similar to demo3, multi-dimensional tensors are flattened when we use pointer (i.e. continuous in memory). """@triton.jitdefdemo4(x_ptr):pid=tl.program_id(0)range=tl.arange(0,8)+pid*8x=tl.load(x_ptr+range,range<20)print("Print for each",pid,x)defrun_demo4():print("Demo4 Output: ")x=torch.ones(2,4,4)demo4[(3,1,1)](x)print_end_line()
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/7/3 5:18:04

OpenAI Residency:顶尖AI工程师的实战破壁实验室

1. 项目概述&#xff1a;这不是“AI夏令营”&#xff0c;而是一场高强度、高门槛的实战淬炼“Top AI Adventure: OpenAI Residency”——光看标题&#xff0c;很多人第一反应是“哦&#xff0c;OpenAI办的AI训练营”或者“类似谷歌AI residency那种实习项目”。但实话讲&#x…

作者头像 李华
网站建设 2026/7/3 5:15:07

软件审计风暴下,企业如何用自动化工具守住合规底线?

近年来&#xff0c;软件供应商的审计力度正在以前所未有的速度收紧。数据显示&#xff0c;过去三年中多达73%的企业遭遇过Oracle发起的Java合规性审计。而审计一旦启动&#xff0c;代价往往超出预期——超过四分之一的受访企业每年在解决非合规许可问题上的花费超过50万美元&am…

作者头像 李华
网站建设 2026/7/3 5:11:51

ArgoCD从内网的GitLab Repo部署应用

K8s Cluster 我有一个3节点的k8s集群&#xff1a; 一个maste&#xff0c;2个node ❯ kubectl cluster-info Kubernetes control plane is running at https://192.168.1.101:6443 CoreDNS is running at https://192.168.1.101:6443/api/v1/namespaces/kube-system/services/…

作者头像 李华
网站建设 2026/7/3 5:11:12

信创深水区,企业即时通讯如何走出替代陷阱

信创深水区&#xff0c;企业即时通讯如何走出“替代陷阱” 当信创从党政机关向金融、能源、电信等八大行业全面铺开&#xff0c;企业即时通讯&#xff08;IM&#xff09;的选型逻辑正在被彻底改写。过去&#xff0c;许多组织将“替换微信/钉钉”等同于完成国产化任务&#xff0…

作者头像 李华
网站建设 2026/7/3 5:10:03

领导给我一台麒麟V10:你去用 nginx 部署一个前端项目

第一步&#xff1a;安装 nginx 1. 确定系统信息 用root用户执行nkvers命令查看系统信息&#xff1a; ############## Kylin Linux Version ################# Release: Kylin Linux Advanced Server release V10 (Sword)Kernel: 4.19.90-24.4.v2101.ky10.aarch64Build: Kyli…

作者头像 李华
网站建设 2026/7/3 5:07:12

VBA 宏编辑

体且垂直居中&#xff0c;区域内容为微软雅黑不加粗10号字体且垂直居中。Sub 一键处理JKLM()Dim ws As WorksheetSet ws ActiveSheetDim lastRowB As Long, lastRowC As LongDim lastRowD As Long, lastRowE As LonglastRowB ws.Cells(ws.Rows.Count, "B").End(xlU…

作者头像 李华