gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
@@ -0,0 +1,613 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
"""
|
4
|
+
Structured representation of all JSON data structures following the
|
5
|
+
GfxRuntime140 convention.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from abc import ABC
|
9
|
+
from enum import Enum
|
10
|
+
from typing import Any, Dict, List, Optional
|
11
|
+
|
12
|
+
from taichi.aot.conventions.gfxruntime140 import dr
|
13
|
+
from taichi.types.enums import DeviceCapability, Format
|
14
|
+
|
15
|
+
|
16
|
+
class DataType(Enum):
    """Primitive element types, identified by an integer id.

    The ids are not contiguous: no member uses the value 7.
    """

    # Floating-point types.
    f16 = 0
    f32 = 1
    f64 = 2

    # Signed integer types.
    i8 = 3
    i16 = 4
    i32 = 5
    i64 = 6

    # Unsigned integer types.
    u8 = 8
    u16 = 9
    u32 = 10
    u64 = 11
|
28
|
+
|
29
|
+
|
30
|
+
def get_data_type_size(dtype: DataType) -> int:
    """Return the size in bytes of a single value of ``dtype``.

    Args:
        dtype: The primitive data type to measure.

    Returns:
        1, 2, 4 or 8 depending on the bit width of ``dtype``.

    Raises:
        ValueError: If ``dtype`` is not one of the known :class:`DataType`
            members.
    """
    # 8-bit types were previously unhandled and fell through to the failure
    # path even though they are valid DataType members.
    if dtype in (DataType.i8, DataType.u8):
        return 1
    if dtype in (DataType.f16, DataType.i16, DataType.u16):
        return 2
    if dtype in (DataType.f32, DataType.i32, DataType.u32):
        return 4
    if dtype in (DataType.f64, DataType.i64, DataType.u64):
        return 8
    # Raise explicitly: a bare ``assert`` is removed under ``python -O`` and
    # would let the function silently return None.
    raise ValueError(f"Unsupported data type: {dtype!r}")
|
38
|
+
|
39
|
+
|
40
|
+
class Argument(ABC):
    """Base class shared by all kernel argument descriptions."""

    def __init__(self, name: Optional[str]):
        # ``name`` is optional; callers may pass None.
        self.name = name
|
44
|
+
|
45
|
+
|
46
|
+
class ArgumentScalar(Argument):
    """A scalar kernel argument described by its primitive data type."""

    def __init__(self, name: Optional[str], dtype: DataType):
        # The base class stores the (optional) argument name.
        super().__init__(name)
        self.dtype: DataType = dtype
|
50
|
+
|
51
|
+
|
52
|
+
class ParameterType(Enum):
    """Kind of a kernel parameter, identified by an integer id."""

    Scalar = 0
    Ndarray = 1
    Texture = 2
    RwTexture = 3
    # Placeholder kind; no argument class in this module corresponds to it.
    Unknown = 4
|
58
|
+
|
59
|
+
|
60
|
+
class NdArrayAccess(Enum):
    """Access mode of an ndarray argument.

    The values line up like bit flags: ``ReadWrite`` (3) equals
    ``Read`` (1) combined with ``Write`` (2).
    """

    NoAccess = 0
    Read = 1
    Write = 2
    ReadWrite = 3
|
65
|
+
|
66
|
+
|
67
|
+
class ArgumentNdArray(Argument):
    """An ndarray kernel argument.

    Stores the element data type, the shape of a single element, the number
    of array dimensions and the access mode.
    """

    def __init__(
        self,
        name: Optional[str],
        dtype: DataType,
        element_shape: List[int],
        ndim: int,
        access: NdArrayAccess,
    ):
        # The base class stores the (optional) argument name.
        super().__init__(name)
        self.dtype: DataType = dtype
        self.element_shape: List[int] = element_shape
        self.ndim: int = ndim
        self.access: NdArrayAccess = access
|
81
|
+
|
82
|
+
|
83
|
+
class ArgumentTexture(Argument):
    """A texture kernel argument described by its number of dimensions."""

    def __init__(self, name: Optional[str], ndim: int):
        # The base class stores the (optional) argument name.
        super().__init__(name)
        self.ndim: int = ndim
|
87
|
+
|
88
|
+
|
89
|
+
class ArgumentRwTexture(Argument):
    """A read-write texture kernel argument with a format and dimension count."""

    def __init__(self, name: Optional[str], fmt: Format, ndim: int):
        # The base class stores the (optional) argument name.
        super().__init__(name)
        self.fmt: Format = fmt
        self.ndim: int = ndim
|
94
|
+
|
95
|
+
|
96
|
+
class ReturnValue:
    """Description of a kernel return value: just its data type."""

    def __init__(self, dtype: DataType):
        self.dtype: DataType = dtype
|
99
|
+
|
100
|
+
|
101
|
+
class Context:
    """Call context of a kernel: its arguments and optional return value."""

    def __init__(self, args: List[Argument], ret: Optional[ReturnValue]):
        # ``ret`` is None when the kernel returns nothing.
        self.ret: Optional[ReturnValue] = ret
        self.args: List[Argument] = args
|
105
|
+
|
106
|
+
|
107
|
+
class BufferBindingType(Enum):
    """Kind of buffer a task binding refers to, identified by an integer id."""

    Root = 0
    GlobalTmps = 1
    Args = 2
    Rets = 3
    ListGen = 4
    ExtArr = 5
|
114
|
+
|
115
|
+
|
116
|
+
class BufferBinding:
    """A buffer bound to a task: binding slot, argument index and buffer kind."""

    def __init__(self, binding: int, iarg: int, buffer_bind_ty: BufferBindingType):
        self.buffer_bind_ty: BufferBindingType = buffer_bind_ty
        self.iarg: int = iarg
        self.binding: int = binding
|
121
|
+
|
122
|
+
|
123
|
+
class TextureBindingType(Enum):
    """Kind of texture binding: plain texture or read-write texture."""

    Texture = 0
    RwTexture = 1
|
126
|
+
|
127
|
+
|
128
|
+
class TextureBinding:
    """A texture bound to a task: binding slot, argument index and texture kind."""

    def __init__(self, binding: int, iarg: int, texture_bind_ty: TextureBindingType):
        self.texture_bind_ty: TextureBindingType = texture_bind_ty
        self.iarg: int = iarg
        self.binding: int = binding
|
133
|
+
|
134
|
+
|
135
|
+
class TaskType(Enum):
    """Kind of an offloaded task, identified by an integer id."""

    Serial = 0
    RangeFor = 1
    StructFor = 2
    MeshFor = 3
    ListGen = 4
    Gc = 5
    GcRc = 6
|
143
|
+
|
144
|
+
|
145
|
+
class LaunchGrid:
    """Launch configuration of a task: a block size and a grid size."""

    def __init__(self, block_size: int, grid_size: int):
        self.grid_size: int = grid_size
        self.block_size: int = block_size
|
149
|
+
|
150
|
+
|
151
|
+
class Task:
    """One task of a kernel.

    Holds the task name, its type, the buffers and textures bound to it and
    its launch grid.
    """

    def __init__(
        self,
        name: str,
        task_ty: TaskType,
        buffer_binds: List[BufferBinding],
        texture_binds: List[TextureBinding],
        launch_grid: LaunchGrid,
    ):
        # Identity of the task.
        self.name: str = name
        self.task_ty: TaskType = task_ty
        # Resources bound to the task.
        self.buffer_binds: List[BufferBinding] = buffer_binds
        self.texture_binds: List[TextureBinding] = texture_binds
        # Launch configuration.
        self.launch_grid: LaunchGrid = launch_grid
|
165
|
+
|
166
|
+
|
167
|
+
class Field:
    """Description of a field: name, data type, shapes and memory offset."""

    def __init__(
        self,
        name: str,
        dtype: DataType,
        element_shape: List[int],
        shape: List[int],
        offset: int,
    ):
        self.name: str = name
        self.dtype: DataType = dtype
        # Shape of one element, and shape of the field itself.
        self.element_shape: List[int] = element_shape
        self.shape: List[int] = shape
        self.offset: int = offset
|
181
|
+
|
182
|
+
|
183
|
+
class Kernel:
    """A kernel: its name, its call context and the tasks it consists of."""

    def __init__(self, name: str, context: Context, tasks: List[Task]):
        self.name = name
        self.tasks: List[Task] = tasks
        self.context: Context = context
|
188
|
+
|
189
|
+
|
190
|
+
class Metadata:
    """Module metadata: fields, kernels, required capabilities, root buffer size.

    Fields and kernels are given as lists but stored as dictionaries keyed by
    their ``name`` attribute; when two entries share a name the later one
    replaces the earlier one.
    """

    def __init__(
        self,
        fields: List[Field],
        kernels: List[Kernel],
        required_caps: List[DeviceCapability],
        root_buffer_size: int,
    ):
        self.fields: Dict[str, Field] = {field.name: field for field in fields}
        self.kernels: Dict[str, Kernel] = {kernel.name: kernel for kernel in kernels}
        self.required_caps: List[DeviceCapability] = required_caps
        self.root_buffer_size: int = root_buffer_size
|
202
|
+
|
203
|
+
|
204
|
+
def from_dr_field(d: dr.FieldAttributes) -> Field:
    """Convert a ``dr.FieldAttributes`` record into a structured :class:`Field`.

    The raw ``dtype`` id is turned into a :class:`DataType` member; the other
    attributes are passed through unchanged.
    """
    return Field(
        name=d.field_name,
        dtype=DataType(d.dtype),
        element_shape=d.element_shape,
        shape=d.shape,
        offset=d.mem_offset_in_parent,
    )
|
212
|
+
|
213
|
+
|
214
|
+
def from_dr_kernel(d: dr.KernelAttributes) -> Kernel:
    """Convert a ``dr.KernelAttributes`` record into a structured :class:`Kernel`.

    The buffer and texture bindings of every task are collected, and each
    entry of ``d.ctx_attribs.arg_attribs_vec_`` becomes the matching
    :class:`Argument` subclass.  An argument with an explicit ``ptype`` is
    classified by it.  When ``ptype`` is ``None`` (or absent), the
    backward-compatible path is used: array arguments are classified by the
    bindings that reference them, and everything else is a scalar.

    Args:
        d: Kernel attributes; must not describe a JIT evaluator.

    Returns:
        The kernel with its context (arguments and optional return value)
        and its tasks.

    Raises:
        ValueError: If a raw id does not map to a member of the respective
            enum.
        KeyError: On the backward-compatible path, if an array argument is
            not referenced by any task binding.
    """
    assert d.is_jit_evaluator is False

    name = d.name

    class OpaqueArgumentType(Enum):
        NdArray = 0
        Texture = 1
        RwTexture = 2

    tasks = []
    # Maps an argument index to the kind inferred from the bindings that use it.
    iarg2arg_ty: Dict[int, OpaqueArgumentType] = {}
    for task in d.tasks_attribs:
        # Collect buffer bindings.
        buffer_binds = []
        for buffer_bind in task.buffer_binds:
            binding = buffer_bind.binding
            iarg = buffer_bind.buffer.root_id
            buffer_ty = BufferBindingType(buffer_bind.buffer.type)
            buffer_binds.append(BufferBinding(binding, iarg, buffer_ty))
            # Only external arrays identify an argument kind; the other
            # buffer kinds need no bookkeeping here.
            if buffer_ty == BufferBindingType.ExtArr:
                iarg2arg_ty[iarg] = OpaqueArgumentType.NdArray

        # Collect texture bindings.
        texture_binds = []
        for texture_bind in task.texture_binds:
            binding = texture_bind.binding
            iarg = texture_bind.arg_id
            if texture_bind.is_storage:
                texture_binds.append(TextureBinding(binding, iarg, TextureBindingType.RwTexture))
                iarg2arg_ty[iarg] = OpaqueArgumentType.RwTexture
            else:
                texture_binds.append(TextureBinding(binding, iarg, TextureBindingType.Texture))
                iarg2arg_ty[iarg] = OpaqueArgumentType.Texture

        launch_grid = LaunchGrid(task.advisory_num_threads_per_group, task.advisory_total_num_threads)

        tasks.append(
            Task(
                task.name,
                TaskType(task.task_type),
                buffer_binds,
                texture_binds,
                launch_grid,
            )
        )

    args = []
    for i, arg in enumerate(d.ctx_attribs.arg_attribs_vec_):
        assert i == arg.index
        # The enum constructor never returns None (it raises ValueError for
        # an unknown value), so the raw value has to be checked *before*
        # conversion for the backward-compatible branch to be reachable.
        raw_ptype = getattr(arg, "ptype", None)
        ptype = None if raw_ptype is None else ParameterType(raw_ptype)
        if ptype is not None:
            if ptype == ParameterType.Scalar:
                args.append(ArgumentScalar(arg.name, DataType(arg.dtype)))
            elif ptype == ParameterType.Ndarray:
                args.append(
                    ArgumentNdArray(
                        arg.name,
                        DataType(arg.dtype),
                        arg.element_shape,
                        arg.field_dim,
                        NdArrayAccess(d.ctx_attribs.arr_access[i]),
                    )
                )
            elif ptype == ParameterType.Texture:
                args.append(ArgumentTexture(arg.name, arg.field_dim))
            elif ptype == ParameterType.RwTexture:
                args.append(ArgumentRwTexture(arg.name, Format(arg.format), arg.field_dim))
            else:
                assert False, f"unsupported parameter type: {ptype}"
        elif arg.is_array:
            # TODO: Keeping this for BC but feel free to break it if necessary
            # Opaque binding types.
            binding_ty = iarg2arg_ty[arg.index]
            if binding_ty == OpaqueArgumentType.NdArray:
                args.append(
                    ArgumentNdArray(
                        arg.name,
                        DataType(arg.dtype),
                        arg.element_shape,
                        arg.field_dim,
                        NdArrayAccess(d.ctx_attribs.arr_access[i]),
                    )
                )
            elif binding_ty == OpaqueArgumentType.Texture:
                args.append(ArgumentTexture(arg.name, arg.field_dim))
            elif binding_ty == OpaqueArgumentType.RwTexture:
                args.append(ArgumentRwTexture(arg.name, Format(arg.format), arg.field_dim))
            else:
                assert False, f"unsupported binding type: {binding_ty}"
        else:
            args.append(ArgumentScalar(arg.name, DataType(arg.dtype)))

    # At most one return value is supported.
    assert len(d.ctx_attribs.ret_attribs_vec_) <= 1
    if len(d.ctx_attribs.ret_attribs_vec_) != 0:
        dtype = d.ctx_attribs.ret_attribs_vec_[0].dtype
        rv = ReturnValue(DataType(dtype))
    else:
        rv = None

    context = Context(args, rv)

    return Kernel(name, context, tasks)
|
330
|
+
|
331
|
+
|
332
|
+
def from_dr_metadata(d: dr.Metadata) -> Metadata:
    """Convert a ``dr.Metadata`` record into a structured :class:`Metadata`.

    A required capability whose value is 1 is stored as its bare key; any
    other value is stored as the string ``"<key>=<value>"``.
    """
    fields = [from_dr_field(field) for field in d.fields]
    kernels = [from_dr_kernel(kernel) for kernel in d.kernels]
    required_caps = [
        cap.key if cap.value == 1 else f"{cap.key}={cap.value}"
        for cap in d.required_caps
    ]
    return Metadata(fields, kernels, required_caps, d.root_buffer_size)
|
344
|
+
|
345
|
+
|
346
|
+
def to_dr_field(f: Field) -> Dict[str, Any]:
    """Convert a :class:`Field` back to its dictionary form.

    Not implemented: calling this always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
|
348
|
+
|
349
|
+
|
350
|
+
def to_dr_kernel(s: Kernel) -> Dict[str, Any]:
    """Serialize a :class:`Kernel` into its dictionary form.

    Produces one entry per task (bindings, launch grid and, for range-for
    tasks, a constant ``[0, grid_size)`` range) and one entry per argument.
    Argument offsets are assigned sequentially: each argument starts where
    the previous one ended.  Array-like arguments (ndarray, texture,
    read-write texture) use a stride of 4; scalars use the size of their
    data type.

    Returns:
        A dictionary with the keys ``is_jit_evaluator``, ``ctx_attribs``,
        ``name`` and ``tasks_attribs``.
    """
    task_records = []
    for task in s.tasks:
        buffer_records = [
            {
                "binding": bind.binding,
                "buffer": {
                    "root_id": bind.iarg,
                    "type": bind.buffer_bind_ty.value,
                },
            }
            for bind in task.buffer_binds
        ]
        texture_records = [
            {
                "arg_id": bind.iarg,
                "binding": bind.binding,
                "is_storage": bind.texture_bind_ty == TextureBindingType.RwTexture,
            }
            for bind in task.texture_binds
        ]

        # Only range-for tasks carry range attributes.
        range_for_attribs = None
        if task.task_ty == TaskType.RangeFor:
            range_for_attribs = {
                "begin": 0,
                "const_begin": True,
                "const_end": True,
                "end": task.launch_grid.grid_size,
            }

        task_records.append(
            {
                "advisory_num_threads_per_group": task.launch_grid.block_size,
                "advisory_total_num_threads": task.launch_grid.grid_size,
                "buffer_binds": buffer_records,
                "name": task.name,
                "range_for_attribs": range_for_attribs,
                "task_type": task.task_ty.value,
                "texture_binds": texture_records,
            }
        )

    arg_records = []
    arr_access = []
    arg_bytes = 0
    arg_offset = 0
    for index, arg in enumerate(s.context.args):
        # Pick the per-kind values first, then build the record once.
        if isinstance(arg, ArgumentNdArray):
            dtype_id, element_shape, field_dim = arg.dtype.value, arg.element_shape, arg.ndim
            fmt, is_array, stride, access = Format.unknown, True, 4, arg.access.value
        elif isinstance(arg, ArgumentTexture):
            dtype_id, element_shape, field_dim = 1, [], arg.ndim
            fmt, is_array, stride, access = Format.unknown, True, 4, 0
        elif isinstance(arg, ArgumentRwTexture):
            dtype_id, element_shape, field_dim = 1, [], arg.ndim
            fmt, is_array, stride, access = arg.fmt, True, 4, 0
        elif isinstance(arg, ArgumentScalar):
            dtype_id, element_shape, field_dim = arg.dtype.value, [], 0
            fmt, is_array, stride, access = Format.unknown, False, get_data_type_size(arg.dtype), 0
        else:
            assert False

        record = {
            "dtype": dtype_id,
            "element_shape": element_shape,
            "field_dim": field_dim,
            "format": fmt,
            "index": index,
            "is_array": is_array,
            "offset_in_mem": arg_offset,
            "stride": stride,
        }
        arg_records.append(record)
        arr_access.append(access)

        arg_offset += record["stride"]
        arg_bytes = max(arg_bytes, record["offset_in_mem"] + record["stride"])

    ret_records = []
    ret_bytes = 0
    ret = s.context.ret
    if ret is not None:
        # There is at most one return value, always at index 0 and offset 0.
        record = {
            "dtype": ret.dtype.value,
            "element_shape": [],
            "field_dim": 0,
            "format": Format.unknown,
            "index": 0,
            "is_array": False,
            "offset_in_mem": 0,
            "stride": get_data_type_size(ret.dtype),
        }
        ret_records.append(record)
        ret_bytes = max(ret_bytes, record["offset_in_mem"] + record["stride"])

    return {
        "is_jit_evaluator": False,
        "ctx_attribs": {
            "arg_attribs_vec_": arg_records,
            "args_bytes_": arg_bytes,
            "arr_access": arr_access,
            "extra_args_bytes_": 1536,
            "ret_attribs_vec_": ret_records,
            "rets_bytes_": ret_bytes,
        },
        "name": s.name,
        "tasks_attribs": task_records,
    }
|
489
|
+
|
490
|
+
|
491
|
+
def to_dr_metadata(s: Metadata) -> dr.Metadata:
    """Serialize a :class:`Metadata` into a ``dr.Metadata`` record.

    Each required capability is converted with ``str()``.  Text of the form
    ``"<key>=<value>"`` is split at the first ``=`` and the value parsed as
    an integer; text without ``=`` gets the value 1.
    """
    fields = [to_dr_field(field) for field in s.fields.values()]
    kernels = [to_dr_kernel(kernel) for kernel in s.kernels.values()]

    required_caps = []
    for cap in s.required_caps:
        key, separator, value = str(cap).partition("=")
        required_caps.append(
            {
                "key": key,
                # An empty separator means the text contained no "=".
                "value": int(value) if separator else 1,
            }
        )

    return dr.Metadata(
        {
            "fields": fields,
            "kernels": kernels,
            "required_caps": required_caps,
            "root_buffer_size": s.root_buffer_size,
        }
    )
|
518
|
+
|
519
|
+
|
520
|
+
class NamedArgument:
    """An argument description paired with the name it is referred to by."""

    def __init__(self, name: str, arg: Argument):
        self.arg = arg
        self.name = name
|
524
|
+
|
525
|
+
|
526
|
+
class Dispatch:
    """One kernel invocation: the kernel and the named arguments passed to it."""

    def __init__(self, kernel: Kernel, args: List[NamedArgument]):
        self.args = args
        self.kernel = kernel
|
530
|
+
|
531
|
+
|
532
|
+
class Graph:
    """A named sequence of kernel dispatches.

    ``args`` holds one :class:`NamedArgument` per distinct argument name
    across all dispatches, in order of first appearance.  When a name occurs
    more than once, the argument from its last occurrence is kept.
    """

    def __init__(self, name: str, dispatches: List[Dispatch]):
        self.name = name
        self.dispatches = dispatches

        # Deduplicate by name; the dict keeps first-seen order and last-seen value.
        by_name = {}
        for dispatch in dispatches:
            for named in dispatch.args:
                by_name[named.name] = named.arg
        self.args: List[NamedArgument] = [NamedArgument(key, value) for key, value in by_name.items()]
|
538
|
+
|
539
|
+
|
540
|
+
def from_dr_graph(meta: Metadata, j: dr.Graph) -> Graph:
    """Convert a ``dr.Graph`` record into a structured :class:`Graph`.

    Every dispatch is resolved to its kernel through ``meta.kernels``, and
    the i-th symbolic argument of the dispatch is paired with the i-th
    argument of that kernel's context.
    """
    dispatches = []
    for dispatch in j.value.dispatches:
        kernel = meta.kernels[dispatch.kernel_name]
        named_args = [
            NamedArgument(symbolic_arg.name, kernel.context.args[position])
            for position, symbolic_arg in enumerate(dispatch.symbolic_args)
        ]
        dispatches.append(Dispatch(kernel, named_args))
    return Graph(j.key, dispatches)
|
550
|
+
|
551
|
+
|
552
|
+
def to_dr_graph(s: Graph) -> dr.Graph:
    """Serialize a :class:`Graph` into a ``dr.Graph`` record.

    Each dispatch argument becomes a symbolic-argument dictionary whose
    ``tag`` depends on the argument class: 0 for scalars, 2 for ndarrays,
    3 for textures and 4 for read-write textures.  Only ndarrays carry their
    element shape and dimension count; both texture kinds are written with
    the ``f32`` data type id.
    """
    dispatch_records = []
    for dispatch in s.dispatches:
        symbolic_args = []
        for named in dispatch.args:
            arg = named.arg
            # Choose the values that differ per argument kind.
            if isinstance(arg, ArgumentScalar):
                dtype_id, element_shape, field_dim, tag = arg.dtype.value, [], 0, 0
            elif isinstance(arg, ArgumentNdArray):
                dtype_id, element_shape, field_dim, tag = arg.dtype.value, arg.element_shape, arg.ndim, 2
            elif isinstance(arg, ArgumentTexture):
                dtype_id, element_shape, field_dim, tag = DataType.f32.value, [], 0, 3
            elif isinstance(arg, ArgumentRwTexture):
                dtype_id, element_shape, field_dim, tag = DataType.f32.value, [], 0, 4
            else:
                assert False

            symbolic_args.append(
                {
                    "dtype_id": dtype_id,
                    "element_shape": element_shape,
                    "field_dim": field_dim,
                    "name": named.name,
                    "num_channels": 0,
                    "tag": tag,
                }
            )

        dispatch_records.append(
            {
                "kernel_name": dispatch.kernel.name,
                "symbolic_args": symbolic_args,
            }
        )

    return dr.Graph(
        {
            "key": s.name,
            "value": {
                "dispatches": dispatch_records,
            },
        }
    )
|