gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 1.0.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +40 -0
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
- {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
- {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools.lib +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
- gstaichi-1.0.1.dist-info/RECORD +135 -0
- gstaichi-1.0.1.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
- gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/CHANGELOG.md +0 -20
- taichi/__init__.py +0 -44
- taichi/__main__.py +0 -5
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_main.py +0 -552
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,613 +0,0 @@
|
|
1
|
-
# type: ignore
|
2
|
-
|
3
|
-
"""
|
4
|
-
Structured representation of all JSON data structures following the
|
5
|
-
GfxRuntime140.
|
6
|
-
"""
|
7
|
-
|
8
|
-
from abc import ABC
|
9
|
-
from enum import Enum
|
10
|
-
from typing import Any, Dict, List, Optional
|
11
|
-
|
12
|
-
from taichi.aot.conventions.gfxruntime140 import dr
|
13
|
-
from taichi.types.enums import DeviceCapability, Format
|
14
|
-
|
15
|
-
|
16
|
-
class DataType(Enum):
    """Primitive element types, keyed by their serialized integer ``dtype`` id.

    Note that the id 7 is not mapped to any member here.
    """

    f16 = 0
    f32 = 1
    f64 = 2
    i8 = 3
    i16 = 4
    i32 = 5
    i64 = 6
    u8 = 8
    u16 = 9
    u32 = 10
    u64 = 11


def get_data_type_size(dtype: DataType) -> int:
    """Return the size in bytes of a single element of ``dtype``.

    Args:
        dtype: a ``DataType`` member.

    Returns:
        1, 2, 4 or 8.

    Raises:
        ValueError: if ``dtype`` is not a ``DataType`` member.
    """
    # 1-byte integer types are valid DataType members and must be sized too.
    if dtype in (DataType.i8, DataType.u8):
        return 1
    if dtype in (DataType.f16, DataType.i16, DataType.u16):
        return 2
    if dtype in (DataType.f32, DataType.i32, DataType.u32):
        return 4
    if dtype in (DataType.f64, DataType.i64, DataType.u64):
        return 8
    raise ValueError(f"unsupported data type: {dtype!r}")
|
38
|
-
|
39
|
-
|
40
|
-
class Argument(ABC):
    """Common base of every kernel argument kind; ``name`` may be ``None``."""

    def __init__(self, name: Optional[str]):
        self.name = name


class ArgumentScalar(Argument):
    """A scalar kernel argument holding one value of primitive type ``dtype``."""

    def __init__(self, name: Optional[str], dtype: DataType):
        super().__init__(name)
        self.dtype: DataType = dtype


class ParameterType(Enum):
    """Serialized parameter kind (the ``ptype`` id) of a kernel argument."""

    Scalar = 0
    Ndarray = 1
    Texture = 2
    RwTexture = 3
    Unknown = 4


class NdArrayAccess(Enum):
    """Access mode of an ndarray argument, as serialized in ``arr_access``."""

    NoAccess = 0
    Read = 1
    Write = 2
    ReadWrite = 3


class ArgumentNdArray(Argument):
    """An ndarray argument.

    ``dtype`` and ``element_shape`` describe each element, ``ndim`` is the
    array's number of dimensions, and ``access`` is its access mode.
    """

    def __init__(
        self,
        name: Optional[str],
        dtype: DataType,
        element_shape: List[int],
        ndim: int,
        access: NdArrayAccess,
    ):
        super().__init__(name)
        self.dtype: DataType = dtype
        self.element_shape: List[int] = element_shape
        self.ndim: int = ndim
        self.access: NdArrayAccess = access


class ArgumentTexture(Argument):
    """A sampled (read-only) texture argument with ``ndim`` dimensions."""

    def __init__(self, name: Optional[str], ndim: int):
        super().__init__(name)
        self.ndim: int = ndim


class ArgumentRwTexture(Argument):
    """A storage (read/write) texture argument with pixel format ``fmt``."""

    def __init__(self, name: Optional[str], fmt: Format, ndim: int):
        super().__init__(name)
        self.fmt: Format = fmt
        self.ndim: int = ndim
|
94
|
-
|
95
|
-
|
96
|
-
class ReturnValue:
    """Describes a kernel's return value by its element ``dtype``."""

    def __init__(self, dtype: DataType):
        self.dtype: DataType = dtype


class Context:
    """A kernel's calling context: ordered ``args`` plus an optional ``ret``."""

    def __init__(self, args: List[Argument], ret: Optional[ReturnValue]):
        self.args: List[Argument] = args
        self.ret: Optional[ReturnValue] = ret
|
105
|
-
|
106
|
-
|
107
|
-
class BufferBindingType(Enum):
    """Kind of buffer bound to a task (serialized as ``buffer.type``)."""

    Root = 0
    GlobalTmps = 1
    Args = 2
    Rets = 3
    ListGen = 4
    ExtArr = 5


class BufferBinding:
    """A buffer attached to shader binding slot ``binding``.

    ``iarg`` carries the buffer's ``root_id``; for ``ExtArr`` buffers it is the
    index of the ndarray argument backing the buffer.
    """

    def __init__(self, binding: int, iarg: int, buffer_bind_ty: BufferBindingType):
        self.binding: int = binding
        self.iarg: int = iarg
        self.buffer_bind_ty: BufferBindingType = buffer_bind_ty


class TextureBindingType(Enum):
    """Whether a bound texture is sampled (``Texture``) or storage (``RwTexture``)."""

    Texture = 0
    RwTexture = 1


class TextureBinding:
    """A texture argument (index ``iarg``) attached to binding slot ``binding``."""

    def __init__(self, binding: int, iarg: int, texture_bind_ty: TextureBindingType):
        self.binding: int = binding
        self.iarg: int = iarg
        self.texture_bind_ty: TextureBindingType = texture_bind_ty
|
133
|
-
|
134
|
-
|
135
|
-
class TaskType(Enum):
    """Serialized kind of an offloaded task (``task_type``)."""

    Serial = 0
    RangeFor = 1
    StructFor = 2
    MeshFor = 3
    ListGen = 4
    Gc = 5
    GcRc = 6


class LaunchGrid:
    """Launch configuration of a task.

    ``block_size`` is the threads-per-group count. ``grid_size`` is populated
    from the advisory *total* thread count, not from a number of groups.
    """

    def __init__(self, block_size: int, grid_size: int):
        self.block_size: int = block_size
        self.grid_size: int = grid_size


class Task:
    """One task of a kernel: its kind, resource bindings and launch grid."""

    def __init__(
        self,
        name: str,
        task_ty: TaskType,
        buffer_binds: List[BufferBinding],
        texture_binds: List[TextureBinding],
        launch_grid: LaunchGrid,
    ):
        self.name: str = name
        self.task_ty: TaskType = task_ty
        self.buffer_binds: List[BufferBinding] = buffer_binds
        self.texture_binds: List[TextureBinding] = texture_binds
        self.launch_grid: LaunchGrid = launch_grid
|
165
|
-
|
166
|
-
|
167
|
-
class Field:
    """A field stored in the root buffer at byte offset ``offset``."""

    def __init__(
        self,
        name: str,
        dtype: DataType,
        element_shape: List[int],
        shape: List[int],
        offset: int,
    ):
        self.name: str = name
        self.dtype: DataType = dtype
        self.element_shape: List[int] = element_shape
        self.shape: List[int] = shape
        self.offset: int = offset


class Kernel:
    """A compiled kernel: its name, calling context and ordered tasks."""

    def __init__(self, name: str, context: Context, tasks: List[Task]):
        self.name: str = name
        self.context: Context = context
        self.tasks: List[Task] = tasks


class Metadata:
    """Module-level metadata with fields and kernels indexed by name.

    When two entries share a name, the later one wins.
    """

    def __init__(
        self,
        fields: List[Field],
        kernels: List[Kernel],
        required_caps: List[DeviceCapability],
        root_buffer_size: int,
    ):
        self.fields: Dict[str, Field] = {field.name: field for field in fields}
        self.kernels: Dict[str, Kernel] = {kernel.name: kernel for kernel in kernels}
        self.required_caps: List[DeviceCapability] = required_caps
        self.root_buffer_size: int = root_buffer_size
|
202
|
-
|
203
|
-
|
204
|
-
def from_dr_field(d: dr.FieldAttributes) -> Field:
    """Build a structured ``Field`` from deserialized field attributes ``d``."""
    return Field(
        name=d.field_name,
        dtype=DataType(d.dtype),
        element_shape=d.element_shape,
        shape=d.shape,
        offset=d.mem_offset_in_parent,
    )
|
212
|
-
|
213
|
-
|
214
|
-
def from_dr_kernel(d: dr.KernelAttributes) -> Kernel:
    """Convert deserialized kernel attributes ``d`` into a structured ``Kernel``.

    Tasks are rebuilt from ``d.tasks_attribs`` and the argument list from
    ``d.ctx_attribs``. Each argument's kind comes from its ``ptype``. When the
    ptype is missing, ``None`` or ``ParameterType.Unknown``, the kind is inferred
    the legacy way instead: from ``is_array`` and from how the tasks bind that
    argument index.

    Raises:
        AssertionError: if ``d`` is a JIT evaluator, if argument indices are not
            0, 1, 2, ... in order, or if more than one return value is declared.
        KeyError: if a legacy (ptype-less) array argument is never bound by any task.
    """
    assert d.is_jit_evaluator is False

    name = d.name

    class OpaqueArgumentType(Enum):
        NdArray = 0
        Texture = 1
        RwTexture = 2

    tasks = []
    # Argument index -> resource kind, as observed from the task bindings. Only
    # consulted by the legacy fallback for arguments without a usable ptype.
    iarg2arg_ty: Dict[int, OpaqueArgumentType] = {}
    for task in d.tasks_attribs:
        # Collect buffer bindings. BufferBindingType(...) already rejects
        # unknown values, so only ExtArr buffers need extra bookkeeping.
        buffer_binds = []
        for buffer_bind in task.buffer_binds:
            iarg = buffer_bind.buffer.root_id
            buffer_ty = BufferBindingType(buffer_bind.buffer.type)
            buffer_binds.append(BufferBinding(buffer_bind.binding, iarg, buffer_ty))
            if buffer_ty == BufferBindingType.ExtArr:
                iarg2arg_ty[iarg] = OpaqueArgumentType.NdArray

        # Collect texture bindings.
        texture_binds = []
        for texture_bind in task.texture_binds:
            iarg = texture_bind.arg_id
            if texture_bind.is_storage:
                bind_ty, arg_ty = TextureBindingType.RwTexture, OpaqueArgumentType.RwTexture
            else:
                bind_ty, arg_ty = TextureBindingType.Texture, OpaqueArgumentType.Texture
            texture_binds.append(TextureBinding(texture_bind.binding, iarg, bind_ty))
            iarg2arg_ty[iarg] = arg_ty

        # grid_size receives the advisory *total* thread count.
        launch_grid = LaunchGrid(task.advisory_num_threads_per_group, task.advisory_total_num_threads)
        tasks.append(
            Task(
                task.name,
                TaskType(task.task_type),
                buffer_binds,
                texture_binds,
                launch_grid,
            )
        )

    def ndarray_arg(i, arg):
        # Access modes are stored in a parallel list indexed by argument position.
        return ArgumentNdArray(
            arg.name,
            DataType(arg.dtype),
            arg.element_shape,
            arg.field_dim,
            NdArrayAccess(d.ctx_attribs.arr_access[i]),
        )

    args = []
    for i, arg in enumerate(d.ctx_attribs.arg_attribs_vec_):
        assert i == arg.index
        # ParameterType(...) never returns None (it returns a member or raises),
        # so a missing ptype has to be detected before conversion.
        raw_ptype = getattr(arg, "ptype", None)
        ptype = ParameterType.Unknown if raw_ptype is None else ParameterType(raw_ptype)
        if ptype == ParameterType.Scalar:
            args.append(ArgumentScalar(arg.name, DataType(arg.dtype)))
        elif ptype == ParameterType.Ndarray:
            args.append(ndarray_arg(i, arg))
        elif ptype == ParameterType.Texture:
            args.append(ArgumentTexture(arg.name, arg.field_dim))
        elif ptype == ParameterType.RwTexture:
            args.append(ArgumentRwTexture(arg.name, Format(arg.format), arg.field_dim))
        elif not arg.is_array:
            # Legacy fallback (kept for backward compatibility): plain scalar.
            args.append(ArgumentScalar(arg.name, DataType(arg.dtype)))
        else:
            # Legacy fallback: opaque argument whose kind is taken from the
            # task bindings collected above.
            binding_ty = iarg2arg_ty[arg.index]
            if binding_ty == OpaqueArgumentType.NdArray:
                args.append(ndarray_arg(i, arg))
            elif binding_ty == OpaqueArgumentType.Texture:
                args.append(ArgumentTexture(arg.name, arg.field_dim))
            else:  # OpaqueArgumentType.RwTexture
                args.append(ArgumentRwTexture(arg.name, Format(arg.format), arg.field_dim))

    ret_attribs = d.ctx_attribs.ret_attribs_vec_
    assert len(ret_attribs) <= 1
    rv = ReturnValue(DataType(ret_attribs[0].dtype)) if len(ret_attribs) != 0 else None

    return Kernel(name, Context(args, rv), tasks)
|
330
|
-
|
331
|
-
|
332
|
-
def from_dr_metadata(d: dr.Metadata) -> Metadata:
    """Convert deserialized module metadata ``d`` into a structured ``Metadata``.

    A required capability whose value is 1 is kept as its bare key. Any other
    value is encoded as the string ``"key=value"``, which ``to_dr_metadata``
    splits back apart.
    """
    fields = [from_dr_field(field_attribs) for field_attribs in d.fields]
    kernels = [from_dr_kernel(kernel_attribs) for kernel_attribs in d.kernels]
    required_caps = [
        cap.key if cap.value == 1 else f"{cap.key}={cap.value}"
        for cap in d.required_caps
    ]
    return Metadata(fields, kernels, required_caps, d.root_buffer_size)
|
344
|
-
|
345
|
-
|
346
|
-
def to_dr_field(f: Field) -> Dict[str, Any]:
    """Serialize a ``Field`` into its ``dr`` dict form.

    Not implemented: this always raises ``NotImplementedError``. As a result,
    ``to_dr_metadata`` cannot serialize metadata that contains any fields.
    """
    raise NotImplementedError
|
348
|
-
|
349
|
-
|
350
|
-
def to_dr_kernel(s: Kernel) -> Dict[str, Any]:
    """Serialize a structured ``Kernel`` into the dict layout consumed by ``dr``.

    Argument records are packed back to back. Each record's ``offset_in_mem`` is
    the sum of the preceding strides: 4 bytes for ndarrays and textures, the
    element size for scalars. At most one return value is emitted, at offset 0.
    """

    def arg_record(dtype, element_shape, field_dim, fmt, index, is_array, offset, stride):
        # Key order matches the serialized argument attribute layout.
        return {
            "dtype": dtype,
            "element_shape": element_shape,
            "field_dim": field_dim,
            "format": fmt,
            "index": index,
            "is_array": is_array,
            "offset_in_mem": offset,
            "stride": stride,
        }

    tasks_attribs = []
    for task in s.tasks:
        buffer_binds = [
            {
                "binding": bb.binding,
                "buffer": {
                    "root_id": bb.iarg,
                    "type": bb.buffer_bind_ty.value,
                },
            }
            for bb in task.buffer_binds
        ]
        texture_binds = [
            {
                "arg_id": tb.iarg,
                "binding": tb.binding,
                "is_storage": tb.texture_bind_ty == TextureBindingType.RwTexture,
            }
            for tb in task.texture_binds
        ]
        grid = task.launch_grid
        if task.task_ty == TaskType.RangeFor:
            # Range-for tasks are emitted as the constant range [0, grid_size).
            range_for_attribs = {
                "begin": 0,
                "const_begin": True,
                "const_end": True,
                "end": grid.grid_size,
            }
        else:
            range_for_attribs = None
        tasks_attribs.append(
            {
                "advisory_num_threads_per_group": grid.block_size,
                "advisory_total_num_threads": grid.grid_size,
                "buffer_binds": buffer_binds,
                "name": task.name,
                "range_for_attribs": range_for_attribs,
                "task_type": task.task_ty.value,
                "texture_binds": texture_binds,
            }
        )

    arg_records = []
    arr_access = []
    offset = 0
    args_bytes = 0
    for index, arg in enumerate(s.context.args):
        if isinstance(arg, ArgumentNdArray):
            record = arg_record(arg.dtype.value, arg.element_shape, arg.ndim, Format.unknown, index, True, offset, 4)
            access = arg.access.value
        elif isinstance(arg, ArgumentTexture):
            # Textures are recorded with dtype id 1 and no element shape.
            record = arg_record(1, [], arg.ndim, Format.unknown, index, True, offset, 4)
            access = 0
        elif isinstance(arg, ArgumentRwTexture):
            record = arg_record(1, [], arg.ndim, arg.fmt, index, True, offset, 4)
            access = 0
        elif isinstance(arg, ArgumentScalar):
            record = arg_record(
                arg.dtype.value,
                [],
                0,
                Format.unknown,
                index,
                False,
                offset,
                get_data_type_size(arg.dtype),
            )
            access = 0
        else:
            assert False
        arg_records.append(record)
        arr_access.append(access)
        offset += record["stride"]
        args_bytes = max(args_bytes, record["offset_in_mem"] + record["stride"])

    ret_records = []
    rets_bytes = 0
    ret = s.context.ret
    if ret is not None:
        record = arg_record(ret.dtype.value, [], 0, Format.unknown, 0, False, 0, get_data_type_size(ret.dtype))
        ret_records.append(record)
        rets_bytes = max(rets_bytes, record["offset_in_mem"] + record["stride"])

    ctx_attribs = {
        "arg_attribs_vec_": arg_records,
        "args_bytes_": args_bytes,
        "arr_access": arr_access,
        # Fixed value; its meaning is not established anywhere in this module.
        "extra_args_bytes_": 1536,
        "ret_attribs_vec_": ret_records,
        "rets_bytes_": rets_bytes,
    }

    return {
        "is_jit_evaluator": False,
        "ctx_attribs": ctx_attribs,
        "name": s.name,
        "tasks_attribs": tasks_attribs,
    }
|
489
|
-
|
490
|
-
|
491
|
-
def to_dr_metadata(s: Metadata) -> dr.Metadata:
    """Serialize structured ``Metadata`` back into a ``dr.Metadata``.

    Each required capability is stringified. A capability of the form
    ``"key=value"`` becomes ``{"key": key, "value": int(value)}``; any other
    string becomes ``{"key": cap, "value": 1}``.
    """
    fields = [to_dr_field(field) for field in s.fields.values()]
    kernels = [to_dr_kernel(kernel) for kernel in s.kernels.values()]

    required_caps = []
    for cap in s.required_caps:
        key, sep, value = str(cap).partition("=")
        required_caps.append(
            {
                "key": key,
                "value": int(value) if sep else 1,
            }
        )

    return dr.Metadata(
        {
            "fields": fields,
            "kernels": kernels,
            "required_caps": required_caps,
            "root_buffer_size": s.root_buffer_size,
        }
    )
|
518
|
-
|
519
|
-
|
520
|
-
class NamedArgument:
    """Binds a graph-level symbolic ``name`` to a kernel ``arg``."""

    def __init__(self, name: str, arg: Argument):
        self.name = name
        self.arg = arg


class Dispatch:
    """One kernel launch inside a graph, with its named arguments."""

    def __init__(self, kernel: Kernel, args: List[NamedArgument]):
        self.kernel = kernel
        self.args = args


class Graph:
    """A named sequence of kernel dispatches.

    ``args`` lists each distinct argument name once, in order of first
    appearance. The bound argument is the one from the last dispatch that
    uses that name.
    """

    def __init__(self, name: str, dispatches: List[Dispatch]):
        self.name = name
        self.dispatches = dispatches
        unique: Dict[str, Argument] = {}
        for dispatch in dispatches:
            for named in dispatch.args:
                unique[named.name] = named.arg
        self.args: List[NamedArgument] = [NamedArgument(arg_name, arg) for arg_name, arg in unique.items()]
|
538
|
-
|
539
|
-
|
540
|
-
def from_dr_graph(meta: Metadata, j: dr.Graph) -> Graph:
    """Build a ``Graph`` from its serialized form ``j``.

    Each dispatch's kernel is looked up by name in ``meta``. Symbolic arguments
    are paired positionally with that kernel's context arguments.
    """
    dispatches = []
    for dispatch in j.value.dispatches:
        kernel = meta.kernels[dispatch.kernel_name]
        kernel_args = kernel.context.args
        named_args = []
        for position, symbolic_arg in enumerate(dispatch.symbolic_args):
            bound = kernel_args[position]
            named_args.append(NamedArgument(symbolic_arg.name, bound))
        dispatches.append(Dispatch(kernel, named_args))
    return Graph(j.key, dispatches)
|
550
|
-
|
551
|
-
|
552
|
-
def to_dr_graph(s: Graph) -> dr.Graph:
    """Serialize a ``Graph`` into a ``dr.Graph``.

    Each named argument becomes a symbolic-argument record. Its ``tag`` is 0
    for scalars, 2 for ndarrays, 3 for textures and 4 for storage textures.
    """

    def symbolic(dtype_id, element_shape, field_dim, name, tag):
        # Key order matches the serialized symbolic-argument layout.
        return {
            "dtype_id": dtype_id,
            "element_shape": element_shape,
            "field_dim": field_dim,
            "name": name,
            "num_channels": 0,
            "tag": tag,
        }

    dispatches = []
    for dispatch in s.dispatches:
        symbolic_args = []
        for named in dispatch.args:
            arg = named.arg
            if isinstance(arg, ArgumentScalar):
                entry = symbolic(arg.dtype.value, [], 0, named.name, 0)
            elif isinstance(arg, ArgumentNdArray):
                entry = symbolic(arg.dtype.value, arg.element_shape, arg.ndim, named.name, 2)
            elif isinstance(arg, ArgumentTexture):
                # NOTE(review): textures serialize field_dim as 0 although the
                # argument carries ndim — confirm this is intended.
                entry = symbolic(DataType.f32.value, [], 0, named.name, 3)
            elif isinstance(arg, ArgumentRwTexture):
                entry = symbolic(DataType.f32.value, [], 0, named.name, 4)
            else:
                assert False
            symbolic_args.append(entry)

        dispatches.append(
            {
                "kernel_name": dispatch.kernel.name,
                "symbolic_args": symbolic_args,
            }
        )

    return dr.Graph(
        {
            "key": s.name,
            "value": {
                "dispatches": dispatches,
            },
        }
    )
|