在华为算子开发中,形状(Shape)是张量(Tensor)的一个核心属性,用于描述张量在各个维度上的大小。 它以元组或列表的形式表示,例如 (3, 4) 或 (2, 3, 4),其中每个数字对应一个维度的元素个数。
形状的维度数量由元组中元素的个数决定,例如形状 (4, 20, 20, 3) 表示一个四维张量。 在物理意义上,形状定义了数据的布局,例如在图像处理中,形状 (4, 20, 20, 3) 可以解释为包含4张20x20像素的彩色图片(每个像素由3个颜色通道组成)。
在算子开发中,正确理解形状对于实现算子逻辑至关重要,因为它影响数据访问模式和内存布局。 例如,在华为CANN框架中,形状与数据排布格式(如NCHW或NC1HWC0)结合使用,以优化计算性能。
形状(Shape)
张量的形状,以(D0, D1, … ,Dn-1)的形式表示,D0到Dn是任意的正整数。
如形状(3,4)表示第一维有3个元素,第二维有4个元素,(3,4)表示一个3行4列的矩阵数组。
形状的第一个元素对应张量最外层中括号中的元素个数,形状的第二个元素对应张量中从左边开始数第二个中括号中的元素个数,依此类推。例如:
物理含义我们应该怎么理解呢?假设我们有这样一个shape=(4, 20, 20, 3)。
假设有一些照片,每个像素点都由红/绿/蓝3色组成,即shape里面3的含义,照片的宽和高都是20,也就是20*20=400个像素,总共有4张的照片,这就是shape=(4, 20, 20, 3)的物理含义。
如果体现在编程上,可以简单把shape理解为操作Tensor的各层循环,比如我们要对shape=(4, 20, 20, 3)的A tensor进行操作,循环语句如下:
produceA{for(i,0,4){for(j,0,20){for(p,0,20){for(q,0,3){A[((((((i*20)+j)*20)+p)*3)+q)]=a_tensor[((((((i*20)+j)*20)+p)*3)+q)]}}}}}Shape推导(Shape Inference)
Shape推导是图模式下的核心环节。开发者可以通过两种方式实现:
Follow模式
若输出Shape与某输入Shape完全一致,可使用Follow接口快速表达:
this->Output("y1").ParamType(REQUIRED).Follow("x1",FollowType::SHAPE);自定义InferShape函数
对于输出Shape与输入Shape存在复杂关系的算子,如Reshape,需编写自定义InferShape函数:
ge::graphStatusInferShapeForReshape(InferShapeContext*context){constgert::Shape*x_shape=context->GetInputShape(0);constgert::Tensor*shape_tensor=context->GetInputTensor(1);gert::Shape*output_shape=context->GetOutputShape(0);if(!x_shape||!shape_tensor||!output_shape)returnge::GRAPH_FAILED;auto reshape_size=static_cast<int32_t>(shape_tensor->GetShapeSize());if(shape_tensor->GetDataType()==ge::DT_INT32){int32_t*reshape_data=shape_tensor->GetData<int32_t>();returnReshapeInferShapeImpl<int32_t>(reshape_data,*x_shape,*output_shape,reshape_size);}else{int64_t*reshape_data=shape_tensor->GetData<int64_t>();returnReshapeInferShapeImpl<int64_t>(reshape_data,*x_shape,*output_shape,reshape_size);}}数据依赖算子
部分算子在Shape推导时,需要依赖输入的真实值,如Reshape依赖shape输入。此类输入需通过ValueDepend(REQUIRED)声明:
this->Input("shape").ParamType(REQUIRED).ValueDepend(REQUIRED);动态Shape与ShapeRange推导
有些算子(如Unique)的输出Shape在编译阶段无法确定,必须在执行时才能得出。这时需要ShapeRange推导,用于预估最大输出内存:
ge::graphStatusUniqueInferShapeRangeFunc(gert::InferShapeRangeContext*context){auto x_shape_range=context->GetInputShapeRange(0U);auto y_shape_range=context->GetOutputShapeRange(0U);y_shape_range->GetMax()->SetDim(0,x_shape_range->GetMax()->GetDim(0));y_shape_range->GetMin()->SetDim(0,x_shape_range->GetMin()->GetDim(0));returnge::GRAPH_SUCCESS;}通过ShapeRange推导,框架可以安全地为动态输出分配内存,保证算子执行的正确性。
昇腾训练营报名链接:
https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro
训练营简介:2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖