前言
cuda reduce 操作是非常常见的一个操作。在自己写kernel的时候也经常会有需要。
其中以 ReduceMax 和 ReduceSum 最为常见,其他操作原理也是完全一致的。
reduce kernel是很多cuda新手经常会遇到的课题,基本都是拿这个课题进行一步一步优化的。
那么实际工业界如何写一个简洁高效的kernel呢?这里总结2个写法。
不过这里需要注意的是,标题说的是两个,其实原理都是一样的,只是写法不一样。
这里用的都是warp原语写法。
以warp原语有很自然的好处:
- 这个kernel基本上就是一个warp级kernel,再在外层包一层使其成为一个block级kernel。是一种很规范的kernel写法。
- warp级kernel能更方便的规避bank conflict。
1. __shfl_down_sync
warp kernel
第一种是pytorch的写法
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset, warpSize);
}
return val;
}
__shfl_down_sync原语,其定义为
T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
这个函数是将,被 mask 指定的线程A,再会往后偏移 delta,这样得出来的线程B,
线程A会拿到线程B的var,其他如果偏移过头的线程会返回0。
例如线程是30,偏移10,明显超出了warp size。
以下是大概的图片解释
第一次reduce,会对长度warpSize的数据进行reduce,会将 warpSize / 2 右边的数值规约到左边。
第二次则会对长度 warpSize / 2的数据进行reduce,会将 warpSize / 4 右边的数值规约到左边。
以此类推,最后结果会被规约到 warp[0] 的位置。
block kernel
下面按照计划编写block级的kernel。
block级的kernel其实只是执行多次warp kernel。
因为执行一次warp kernel之后,整个数列,以wrap为单位,有效数据会变得很散。
以reduceSum 为例子。数据格式大概会变成这个样子:
warp_sum_0,0,0,0,0,0....warp_sum_1,0,0,0,0.....warp_sum2,0,0,0....
因此需要将散列的数据集中起来,以便继续使用warp kernel。
template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
const int laneid = threadIdx.x % warpSize;
const int warpid = threadIdx.x / warpSize;
val = WarpReduceSum(val);
__syncthreads();
if (laneid == 0) {
shared[warpid] = val;
}
__syncthreads();
val = (threadIdx.x < blockDim.x / warpSize) ? shared[laneid] : T(0);
if (warpid == 0) {
val = WarpReduceSum(val);
}
return val;
}
当然上面的写法也是有缺陷的,这里只是提供一个简单的思路罢了。
- 会有很多空转的线程
- 如果数据数据量很大,就得多执行几次 warp kernel,因为一次warp kernel只能执行 32,两次则是 32 * 32,以此类推。在这个时候其实已经不是很适合用本文提到的写法了。可以直接考虑使用nvidia的cub reduce
2. __shfl_xor_sync
第二种是使用 __shfl_xor_sync
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
val += __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
__shfl_xor_sync原语,其定义为
T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);
这个函数是将,被 mask 指定的线程A,将他的lane id 与 laneMask进行异或,这样得出来的线程B,线程A会获得线程B的值。
注: 但实际上,线程A必定与线程B交换值,因为线程B的lane id对同样的lane mask进行异或,结果也必定是线程A。
两者的区别我以下图进行展示。如果只是reduce sum的话没有区别,但是在其他地方也有一定的用处。