GridSampler3DGrad
【免费下载链接】ops-cv本项目是CANN提供的图像处理、目标检测相关的算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-cv
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | √ |
功能说明
算子功能:GridSampler中3D场景的反向传播,完成张量input与张量grid的梯度计算。
计算公式:
计算流程:
- 根据grid存储的(x, y, z)值,计算出映射到input上的坐标,这些坐标和align_corners、padding_mode有关。
- 坐标根据输入的interpolation_mode,选择使用bilinear、nearest不同插值模式计算输出值。
- 根据grad存储的梯度值乘上对应点的权重值,计算出最终dx、dgrid的结果。
其中:
grad、input、grid、dx、dgrid的尺寸如下:
$$ grad: (N, C, D_{out}, H_{out}, W_{out})\ input: (N, C, D_{in}, H_{in}, W_{in})\ grid: (N, D_{out}, H_{out}, W_{out}, 3)\ dx: (N, C, D_{in}, H_{in}, W_{in})\ dgrid: (N, D_{out}, H_{out}, W_{out}, 3) $$
其中grad、input、grid、dx、dgrid中的N是一致的,grad、input和dx中的C是一致的,input和dx中的$D_{in}$、$H_{in}$、$W_{in}$是一致的,grad、grid和dgrid中的$D_{out}$、$H_{out}$、$W_{out}$是一致的,grid最后一维大小为3,表示input像素位置信息为(x, y, z),会将x、y、z的取值范围归一化到[-1, 1]之间。
对于超出范围的坐标,会根据padding_mode进行不同处理:
- padding_mode="zeros",表示对越界位置用0填充。
- padding_mode="border",表示对越界位置用边界值填充。
- padding_mode="reflection",表示对越界位置用边界值的对称值填充。
对input采样时,会根据interpolation_mode进行不同处理:
- interpolation_mode="bilinear",表示取input中(x, y, z)周围八个坐标的加权平均值。
- interpolation_mode="nearest",表示取input中距离(x, y, z)最近的坐标值。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| grad | 输入 | 表示反向传播过程中上一层的输出梯度,对应公式描述中的`grad`。数据类型与`x`的数据类型一致。当数据类型DOUBLE时,数据格式不支持NDHWC。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| x | 输入 | 表示反向传播的输入张量,对应公式描述中的`input`。shape仅支持五维,且需满足`x`和`grad`的N轴和C轴的值保持一致,x的D,H,W值不可为0。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| grid | 输入 | 表示采用像素位置的张量,对应公式描述中的`grid`。shape仅支持五维,且需满足`grid`和`grad`的N轴、D轴、H轴、W轴的值保持一致,最后一维的值等于3。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NDHWC |
| interpolation_mode | 可选属性 |
| STRING | - |
| padding_mode | 可选属性 |
| STRING | - |
| align_corners | 可选属性 |
| BOOL | - |
| dx | 输出 | 表示反向传播的输出梯度,对应公式描述中的`dx`。数据类型、数据格式和shape与`x`的数据类型、数据格式和shape保持一致。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| dgrid | 输出 | 表示`grid`梯度,对应公式描述中的`dgrid`。数据类型、数据格式和shape与`grid`的数据类型、数据格式和shape保持一致。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NDHWC |
Atlas 训练系列产品 :输入参数和输出参数的数据类型不支持DOUBLE、BFLOAT16。
约束说明
无
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_grid_sampler3_d_backward | 通过aclnnGridSampler3DBackward接口方式调用GridSampler3DGrad算子。 |
| 图模式 | - | 通过算子IR构图方式调用GridSampler3DGrad算子。 |
【免费下载链接】ops-cv本项目是CANN提供的图像处理、目标检测相关的算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-cv
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考