两个简洁高效的cuda reduce写法

前言

cuda reduce 操作是非常常见的一个操作。在自己写kernel的时候也经常会有需要。

其中以 ReduceMax 和 ReduceSum 最为常见,其他操作原理也是完全一致的。

reduce kernel是很多cuda新手经常会遇到的课题,基本都是拿这个课题进行一步一步优化的。
那么实际工业界如何写一个简洁高效的kernel呢?这里总结2个写法。

不过这里需要注意的是,标题说的是两个,其实原理都是一样的,只是写法不一样
这里用的都是warp原语写法。

以warp原语有很自然的好处:

  1. 这个kernel基本上就是一个warp级kernel,再在外层包一层使其成为一个block级kernel。是一种很规范的kernel写法。
  2. warp级kernel能更方便的规避bank conflict。

1. __shfl_down_sync

warp kernel

第一种是pytorch的写法

template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset, warpSize);
  }
  return val;
}

__shfl_down_sync原语,其定义为

T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);

这个函数是将,被 mask 指定的线程A,再会往后偏移 delta,这样得出来的线程B,
线程A会拿到线程B的var,其他如果偏移过头的线程会返回0。
例如线程是30,偏移10,明显超出了warp size。

以下是大概的图片解释
a

第一次reduce,会对长度warpSize的数据进行reduce,会将 warpSize / 2 右边的数值规约到左边。

第二次则会对长度 warpSize / 2的数据进行reduce,会将 warpSize / 4 右边的数值规约到左边。

以此类推,最后结果会被规约到 warp[0] 的位置。

block kernel

下面按照计划编写block级的kernel。

block级的kernel其实只是执行多次warp kernel。
因为执行一次warp kernel之后,整个数列,以wrap为单位,有效数据会变得很散。

以reduceSum 为例子。数据格式大概会变成这个样子:
warp_sum_0,0,0,0,0,0....warp_sum_1,0,0,0,0.....warp_sum2,0,0,0....

因此需要将散列的数据集中起来,以便继续使用warp kernel。

template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int laneid = threadIdx.x % warpSize;
  const int warpid = threadIdx.x / warpSize;
  val = WarpReduceSum(val);
  __syncthreads();
  if (laneid == 0) {
    shared[warpid] = val;
  }
  __syncthreads();
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[laneid] : T(0);
  if (warpid == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}

当然上面的写法也是有缺陷的,这里只是提供一个简单的思路罢了。

  1. 会有很多空转的线程
  2. 如果数据数据量很大,就得多执行几次 warp kernel,因为一次warp kernel只能执行 32,两次则是 32 * 32,以此类推。在这个时候其实已经不是很适合用本文提到的写法了。可以直接考虑使用nvidia的cub reduce

2. __shfl_xor_sync

第二种是使用 __shfl_xor_sync

template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}

__shfl_xor_sync原语,其定义为

T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);

这个函数是将,被 mask 指定的线程A,将他的lane id 与 laneMask进行异或,这样得出来的线程B,线程A会获得线程B的值。
注: 但实际上,线程A必定与线程B交换值,因为线程B的lane id对同样的lane mask进行异或,结果也必定是线程A。

两者的区别我以下图进行展示。如果只是reduce sum的话没有区别,但是在其他地方也有一定的用处。

b