CUTE Learn-2





CUTE Layout 系统基础分析


CUTE Layout 系统基础分析

概述

CUTE 的核心是一个强类型的布局系统,用于表达多维张量在内存中的索引和映射关系。Layout 类是整个系统的基石,它通过 Shape 和 Stride 两个组件定义了从逻辑坐标到线性内存地址的映射关系。

Layout 类的定义

template <class Shape, class Stride = LayoutLeft::Apply<Shape> >
struct Layout
    : private cute::tuple<Shape, Stride>   // EBO for static layouts

Layout 类使用私有继承自 cute::tuple<Shape, Stride>,这样做的好处是:

  • 利用 EBO (Empty Base Optimization) 优化静态布局的存储
  • 将 Shape 和 Stride 作为一个统一的数据结构管理

核心类型别名

CUTE 定义了一系列类型别名,本质上都是 cute::tuple 的特化:

template <class... Shapes>
using Shape = cute::tuple<Shapes...>;

template <class... Strides>
using Stride = cute::tuple<Strides...>;

template <class... Coords>
using Coord = cute::tuple<Coords...>;

template <class... Layouts>
using Tile = cute::tuple<Layouts...>;

这些别名提供了强类型的语义区分,使代码更易读且类型安全。

构造函数和工厂函数

构造函数

CUTE_HOST_DEVICE constexpr
Layout(Shape  const& shape  = {}, Stride const& stride = {})
    : cute::tuple<Shape, Stride>(shape, stride)
{}

工厂函数

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Shape<Ts...> make_shape(Ts const&... t) {
  return {t...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Stride<Ts...> make_stride(Ts const&... t) {
  return {t...};
}

template <class... Ts>
CUTE_HOST_DEVICE constexpr
Coord<Ts...> make_coord(Ts const&... t) {
  return {t...};
}

工厂函数提供了更方便的构造方式,避免了显式模板参数的指定。

基本访问器

Rank(维度数)

static constexpr int rank = rank_v<Shape>;

Layout 的维度由 Shape 的维度决定,这是一个编译时常量。

Shape 和 Stride 访问

template <int... I>
CUTE_HOST_DEVICE constexpr
decltype(auto) shape() const {
  return get<0,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
}

template <int... I>
CUTE_HOST_DEVICE constexpr
decltype(auto) stride() const {
  return get<1,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
}

这些访问器支持多级索引,可以获取嵌套 tuple 中的特定元素。

核心映射操作:operator()

Layout 的核心功能是 operator() 函数,它根据输入坐标的类型执行不同操作:

template <class Coord>
CUTE_HOST_DEVICE constexpr
auto operator()(Coord const& coord) const {
  if constexpr (has_underscore<Coord>::value) {
    return slice(coord, *this);        // 切片操作
  } else {
    return crd2idx(coord, shape(), stride());  // 索引计算
  }
}

两种操作模式

1. 索引计算模式

当坐标中没有下划线(_)时,执行标准的坐标到索引的映射:

逻辑坐标 → 线性索引

2. 切片模式

当坐标中包含下划线(_)时,执行切片操作:

带下划线的坐标 → 新的 Layout(子布局)

这种设计允许统一的接口同时支持索引访问和张量切片操作。

坐标到索引的映射 (crd2idx)

crd2idx 函数实现了多维坐标到一维索引的核心算法:

index = coord[0] * stride[0] + coord[1] * stride[1] + ... + coord[n] * stride[n]

这是标准的线性索引计算公式,支持任意维度的张量布局。

Layout 的组合操作

compose() 函数

template <class OtherLayout>
CUTE_HOST_DEVICE constexpr
auto compose(OtherLayout const& other) const {
  return composition(*this, other);
}

Layout 的组合允许将两个布局函数复合,创建更复杂的映射关系:

composed_layout = layout_a ∘ layout_b

这在实现复杂的张量变换时非常有用。

设计原则总结

  1. 强类型设计:通过类型别名区分不同语义的 tuple
  2. 编译时优化:大量使用 constexpr 和模板元编程
  3. 零成本抽象:利用 EBO 优化存储,静态布局几乎零开销
  4. 统一接口:operator() 同时支持索引和切片操作
  5. 可组合性:Layout 可以通过 compose 等操作进行组合和变换

示例用法

// 创建一个 2x3 的行主序布局
auto shape = make_shape(2, 3);
auto stride = make_stride(3, 1);
auto layout = Layout{shape, stride};

// 索引访问:layout(1, 2) = 1*3 + 2*1 = 5
auto idx = layout(1, 2);

// 切片操作:layout(1, _) 获取第二行
auto row_layout = layout(1, _);

这个基础设计为 CUTE 的整个张量系统奠定了坚实的基础。



已发布

分类

来自

标签:

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注