CUTE Layout 系统基础分析
概述
CUTE 的核心是一个强类型的布局系统,用于表达多维张量在内存中的索引和映射关系。Layout 类是整个系统的基石,它通过 Shape 和 Stride 两个组件定义了从逻辑坐标到线性内存地址的映射关系。
Layout 类的定义
template <class Shape, class Stride = LayoutLeft::Apply<Shape> >
struct Layout
: private cute::tuple<Shape, Stride> // EBO for static layouts
Layout 类使用私有继承自 cute::tuple<Shape, Stride>,这样做的好处是:
- 利用 EBO (Empty Base Optimization) 优化静态布局的存储
- 将 Shape 和 Stride 作为一个统一的数据结构管理
核心类型别名
CUTE 定义了一系列类型别名,本质上都是 cute::tuple 的特化:
template <class... Shapes>
using Shape = cute::tuple<Shapes...>;
template <class... Strides>
using Stride = cute::tuple<Strides...>;
template <class... Coords>
using Coord = cute::tuple<Coords...>;
template <class... Layouts>
using Tile = cute::tuple<Layouts...>;
这些别名提供了强类型的语义区分,使代码更易读且类型安全。
构造函数和工厂函数
构造函数
CUTE_HOST_DEVICE constexpr
Layout(Shape const& shape = {}, Stride const& stride = {})
: cute::tuple<Shape, Stride>(shape, stride)
{}
工厂函数
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Shape<Ts...> make_shape(Ts const&... t) {
return {t...};
}
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Stride<Ts...> make_stride(Ts const&... t) {
return {t...};
}
template <class... Ts>
CUTE_HOST_DEVICE constexpr
Coord<Ts...> make_coord(Ts const&... t) {
return {t...};
}
工厂函数提供了更方便的构造方式,避免了显式模板参数的指定。
基本访问器
Rank(维度数)
static constexpr int rank = rank_v<Shape>;
Layout 的维度由 Shape 的维度决定,这是一个编译时常量。
Shape 和 Stride 访问
template <int... I>
CUTE_HOST_DEVICE constexpr
decltype(auto) shape() const {
return get<0,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
}
template <int... I>
CUTE_HOST_DEVICE constexpr
decltype(auto) stride() const {
return get<1,I...>(static_cast<cute::tuple<Shape, Stride> const&>(*this));
}
这些访问器支持多级索引,可以获取嵌套 tuple 中的特定元素。
核心映射操作:operator()
Layout 的核心功能是 operator() 函数,它根据输入坐标的类型执行不同操作:
template <class Coord>
CUTE_HOST_DEVICE constexpr
auto operator()(Coord const& coord) const {
if constexpr (has_underscore<Coord>::value) {
return slice(coord, *this); // 切片操作
} else {
return crd2idx(coord, shape(), stride()); // 索引计算
}
}
两种操作模式
1. 索引计算模式
当坐标中没有下划线(_)时,执行标准的坐标到索引的映射:
逻辑坐标 → 线性索引
2. 切片模式
当坐标中包含下划线(_)时,执行切片操作:
带下划线的坐标 → 新的 Layout(子布局)
这种设计允许统一的接口同时支持索引访问和张量切片操作。
坐标到索引的映射 (crd2idx)
crd2idx 函数实现了多维坐标到一维索引的核心算法:
index = coord[0] * stride[0] + coord[1] * stride[1] + ... + coord[n] * stride[n]
这是标准的线性索引计算公式,支持任意维度的张量布局。
Layout 的组合操作
compose() 函数
template <class OtherLayout>
CUTE_HOST_DEVICE constexpr
auto compose(OtherLayout const& other) const {
return composition(*this, other);
}
Layout 的组合允许将两个布局函数复合,创建更复杂的映射关系:
composed_layout = layout_a ∘ layout_b
这在实现复杂的张量变换时非常有用。
设计原则总结
- 强类型设计:通过类型别名区分不同语义的 tuple
- 编译时优化:大量使用 constexpr 和模板元编程
- 零成本抽象:利用 EBO 优化存储,静态布局几乎零开销
- 统一接口:operator() 同时支持索引和切片操作
- 可组合性:Layout 可以通过 compose 等操作进行组合和变换
示例用法
// 创建一个 2x3 的行主序布局
auto shape = make_shape(2, 3);
auto stride = make_stride(3, 1);
auto layout = Layout{shape, stride};
// 索引访问:layout(1, 2) = 1*3 + 2*1 = 5
auto idx = layout(1, 2);
// 切片操作:layout(1, _) 获取第二行
auto row_layout = layout(1, _);
这个基础设计为 CUTE 的整个张量系统奠定了坚实的基础。
发表回复