gstaichi 0.1.18.dev1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  2. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  3. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  4. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  5. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  6. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  7. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  8. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  9. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  10. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  11. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  12. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  13. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  14. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  15. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  16. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  17. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  18. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  19. gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  20. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  21. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  25. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  26. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
  39. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  40. gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
  41. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  42. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  43. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  44. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  45. taichi/CHANGELOG.md +15 -0
  46. taichi/__init__.py +44 -0
  47. taichi/__main__.py +5 -0
  48. taichi/_funcs.py +706 -0
  49. taichi/_kernels.py +420 -0
  50. taichi/_lib/__init__.py +3 -0
  51. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  52. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  53. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  54. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  55. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  56. taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
  57. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  58. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  59. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  60. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  61. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  62. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  63. taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
  64. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  65. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  66. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  67. taichi/_lib/core/__init__.py +0 -0
  68. taichi/_lib/core/py.typed +0 -0
  69. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
  70. taichi/_lib/core/taichi_python.pyi +3077 -0
  71. taichi/_lib/runtime/runtime_cuda.bc +0 -0
  72. taichi/_lib/runtime/runtime_x64.bc +0 -0
  73. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  74. taichi/_lib/utils.py +249 -0
  75. taichi/_logging.py +131 -0
  76. taichi/_main.py +552 -0
  77. taichi/_snode/__init__.py +5 -0
  78. taichi/_snode/fields_builder.py +189 -0
  79. taichi/_snode/snode_tree.py +34 -0
  80. taichi/_ti_module/__init__.py +3 -0
  81. taichi/_ti_module/cppgen.py +309 -0
  82. taichi/_ti_module/module.py +145 -0
  83. taichi/_version.py +1 -0
  84. taichi/_version_check.py +100 -0
  85. taichi/ad/__init__.py +3 -0
  86. taichi/ad/_ad.py +530 -0
  87. taichi/algorithms/__init__.py +3 -0
  88. taichi/algorithms/_algorithms.py +117 -0
  89. taichi/aot/__init__.py +12 -0
  90. taichi/aot/_export.py +28 -0
  91. taichi/aot/conventions/__init__.py +3 -0
  92. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  93. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  94. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  95. taichi/aot/module.py +253 -0
  96. taichi/aot/utils.py +151 -0
  97. taichi/assets/.git +1 -0
  98. taichi/assets/Go-Regular.ttf +0 -0
  99. taichi/assets/static/imgs/ti_gallery.png +0 -0
  100. taichi/examples/minimal.py +28 -0
  101. taichi/experimental.py +16 -0
  102. taichi/graph/__init__.py +3 -0
  103. taichi/graph/_graph.py +292 -0
  104. taichi/lang/__init__.py +50 -0
  105. taichi/lang/_ndarray.py +348 -0
  106. taichi/lang/_ndrange.py +152 -0
  107. taichi/lang/_texture.py +172 -0
  108. taichi/lang/_wrap_inspect.py +189 -0
  109. taichi/lang/any_array.py +99 -0
  110. taichi/lang/argpack.py +411 -0
  111. taichi/lang/ast/__init__.py +5 -0
  112. taichi/lang/ast/ast_transformer.py +1806 -0
  113. taichi/lang/ast/ast_transformer_utils.py +328 -0
  114. taichi/lang/ast/checkers.py +106 -0
  115. taichi/lang/ast/symbol_resolver.py +57 -0
  116. taichi/lang/ast/transform.py +9 -0
  117. taichi/lang/common_ops.py +310 -0
  118. taichi/lang/exception.py +80 -0
  119. taichi/lang/expr.py +180 -0
  120. taichi/lang/field.py +464 -0
  121. taichi/lang/impl.py +1246 -0
  122. taichi/lang/kernel_arguments.py +157 -0
  123. taichi/lang/kernel_impl.py +1415 -0
  124. taichi/lang/matrix.py +1877 -0
  125. taichi/lang/matrix_ops.py +341 -0
  126. taichi/lang/matrix_ops_utils.py +190 -0
  127. taichi/lang/mesh.py +687 -0
  128. taichi/lang/misc.py +807 -0
  129. taichi/lang/ops.py +1489 -0
  130. taichi/lang/runtime_ops.py +13 -0
  131. taichi/lang/shell.py +35 -0
  132. taichi/lang/simt/__init__.py +5 -0
  133. taichi/lang/simt/block.py +94 -0
  134. taichi/lang/simt/grid.py +7 -0
  135. taichi/lang/simt/subgroup.py +191 -0
  136. taichi/lang/simt/warp.py +96 -0
  137. taichi/lang/snode.py +487 -0
  138. taichi/lang/source_builder.py +150 -0
  139. taichi/lang/struct.py +855 -0
  140. taichi/lang/util.py +381 -0
  141. taichi/linalg/__init__.py +8 -0
  142. taichi/linalg/matrixfree_cg.py +310 -0
  143. taichi/linalg/sparse_cg.py +59 -0
  144. taichi/linalg/sparse_matrix.py +303 -0
  145. taichi/linalg/sparse_solver.py +123 -0
  146. taichi/math/__init__.py +11 -0
  147. taichi/math/_complex.py +204 -0
  148. taichi/math/mathimpl.py +886 -0
  149. taichi/profiler/__init__.py +6 -0
  150. taichi/profiler/kernel_metrics.py +260 -0
  151. taichi/profiler/kernel_profiler.py +592 -0
  152. taichi/profiler/memory_profiler.py +15 -0
  153. taichi/profiler/scoped_profiler.py +36 -0
  154. taichi/shaders/Circles_vk.frag +29 -0
  155. taichi/shaders/Circles_vk.vert +45 -0
  156. taichi/shaders/Circles_vk_frag.spv +0 -0
  157. taichi/shaders/Circles_vk_vert.spv +0 -0
  158. taichi/shaders/Lines_vk.frag +9 -0
  159. taichi/shaders/Lines_vk.vert +11 -0
  160. taichi/shaders/Lines_vk_frag.spv +0 -0
  161. taichi/shaders/Lines_vk_vert.spv +0 -0
  162. taichi/shaders/Mesh_vk.frag +71 -0
  163. taichi/shaders/Mesh_vk.vert +68 -0
  164. taichi/shaders/Mesh_vk_frag.spv +0 -0
  165. taichi/shaders/Mesh_vk_vert.spv +0 -0
  166. taichi/shaders/Particles_vk.frag +95 -0
  167. taichi/shaders/Particles_vk.vert +73 -0
  168. taichi/shaders/Particles_vk_frag.spv +0 -0
  169. taichi/shaders/Particles_vk_vert.spv +0 -0
  170. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  171. taichi/shaders/SceneLines_vk.frag +9 -0
  172. taichi/shaders/SceneLines_vk.vert +12 -0
  173. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  174. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  175. taichi/shaders/SetImage_vk.frag +21 -0
  176. taichi/shaders/SetImage_vk.vert +15 -0
  177. taichi/shaders/SetImage_vk_frag.spv +0 -0
  178. taichi/shaders/SetImage_vk_vert.spv +0 -0
  179. taichi/shaders/Triangles_vk.frag +16 -0
  180. taichi/shaders/Triangles_vk.vert +29 -0
  181. taichi/shaders/Triangles_vk_frag.spv +0 -0
  182. taichi/shaders/Triangles_vk_vert.spv +0 -0
  183. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  184. taichi/sparse/__init__.py +3 -0
  185. taichi/sparse/_sparse_grid.py +77 -0
  186. taichi/tools/__init__.py +12 -0
  187. taichi/tools/diagnose.py +124 -0
  188. taichi/tools/np2ply.py +364 -0
  189. taichi/tools/vtk.py +38 -0
  190. taichi/types/__init__.py +19 -0
  191. taichi/types/annotations.py +47 -0
  192. taichi/types/compound_types.py +90 -0
  193. taichi/types/enums.py +49 -0
  194. taichi/types/ndarray_type.py +147 -0
  195. taichi/types/primitive_types.py +203 -0
  196. taichi/types/quant.py +88 -0
  197. taichi/types/texture_type.py +85 -0
  198. taichi/types/utils.py +13 -0
@@ -0,0 +1,613 @@
1
+ # type: ignore
2
+
3
+ """
4
+ Structured representation of all JSON data structures following the
5
+ GfxRuntime140.
6
+ """
7
+
8
+ from abc import ABC
9
+ from enum import Enum
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ from taichi.aot.conventions.gfxruntime140 import dr
13
+ from taichi.types.enums import DeviceCapability, Format
14
+
15
+
16
class DataType(Enum):
    """Scalar element types, keyed by the integer dtype id used in dr metadata.

    There is no member with value 7, so ``DataType(7)`` raises ``ValueError``.
    """

    # Floating point.
    f16 = 0
    f32 = 1
    f64 = 2
    # Signed integers.
    i8 = 3
    i16 = 4
    i32 = 5
    i64 = 6
    # Unsigned integers.
    u8 = 8
    u16 = 9
    u32 = 10
    u64 = 11
28
+
29
+
30
def get_data_type_size(dtype: DataType) -> int:
    """Return the size in bytes of a single scalar of type ``dtype``.

    Every ``DataType`` member is covered, including the 8-bit integer types.

    Raises:
        ValueError: if ``dtype`` is not a recognised ``DataType`` member.
    """
    if dtype in (DataType.i8, DataType.u8):
        return 1
    if dtype in (DataType.f16, DataType.i16, DataType.u16):
        return 2
    if dtype in (DataType.f32, DataType.i32, DataType.u32):
        return 4
    if dtype in (DataType.f64, DataType.i64, DataType.u64):
        return 8
    # An explicit raise (rather than `assert False`) still fires under `python -O`.
    raise ValueError(f"Unsupported data type: {dtype!r}")
38
+
39
+
40
class Argument(ABC):
    """Common base for every kernel argument kind.

    It derives from ``ABC`` but declares no abstract methods, so it can still be
    instantiated directly.

    Attributes:
        name: Argument name; may be ``None``.
    """

    def __init__(self, name: Optional[str]):
        self.name = name
44
+
45
+
46
class ArgumentScalar(Argument):
    """A kernel argument holding a single scalar of type ``dtype``."""

    def __init__(self, name: Optional[str], dtype: DataType):
        Argument.__init__(self, name)
        self.dtype: DataType = dtype
50
+
51
+
52
class ParameterType(Enum):
    """Explicit kernel parameter kind, as stored in a dr argument's ``ptype``."""

    Scalar = 0
    Ndarray = 1
    Texture = 2
    RwTexture = 3
    Unknown = 4  # from_dr_kernel has no Argument subclass for this kind
58
+
59
+
60
class NdArrayAccess(Enum):
    """How a kernel accesses an ndarray argument."""

    NoAccess = 0
    Read = 1
    Write = 2
    ReadWrite = 3  # numerically equal to Read | Write
65
+
66
+
67
class ArgumentNdArray(Argument):
    """An ndarray kernel argument.

    Attributes:
        dtype: Scalar type of each element component.
        element_shape: Shape of each element (empty for scalar elements).
        ndim: Number of array dimensions.
        access: How the kernel reads/writes the array.
    """

    def __init__(
        self,
        name: Optional[str],
        dtype: DataType,
        element_shape: List[int],
        ndim: int,
        access: NdArrayAccess,
    ):
        super().__init__(name)
        self.access: NdArrayAccess = access
        self.ndim: int = ndim
        self.element_shape: List[int] = element_shape
        self.dtype: DataType = dtype
81
+
82
+
83
class ArgumentTexture(Argument):
    """A sampled (read-only) texture kernel argument with ``ndim`` dimensions."""

    def __init__(self, name: Optional[str], ndim: int):
        Argument.__init__(self, name)
        self.ndim: int = ndim
87
+
88
+
89
class ArgumentRwTexture(Argument):
    """A storage (read/write) texture kernel argument.

    Attributes:
        fmt: Texel format of the texture.
        ndim: Number of texture dimensions.
    """

    def __init__(self, name: Optional[str], fmt: Format, ndim: int):
        super().__init__(name)
        self.ndim: int = ndim
        self.fmt: Format = fmt
94
+
95
+
96
class ReturnValue:
    """A kernel's scalar return value, described only by its ``dtype``."""

    def __init__(self, dtype: DataType):
        self.dtype: DataType = dtype
99
+
100
+
101
class Context:
    """A kernel's calling context: its ordered arguments and optional return value."""

    def __init__(self, args: List[Argument], ret: Optional[ReturnValue]):
        self.ret: Optional[ReturnValue] = ret
        self.args: List[Argument] = args
105
+
106
+
107
class BufferBindingType(Enum):
    """Kind of buffer bound to a task.

    ``from_dr_kernel`` treats ``ExtArr`` bindings as ndarray arguments.
    """

    Root = 0
    GlobalTmps = 1
    Args = 2
    Rets = 3
    ListGen = 4
    ExtArr = 5
114
+
115
+
116
class BufferBinding:
    """One buffer bound to a task.

    Attributes:
        binding: Binding slot index.
        iarg: Value taken from the dr buffer's ``root_id``.
        buffer_bind_ty: Kind of buffer bound.
    """

    def __init__(self, binding: int, iarg: int, buffer_bind_ty: BufferBindingType):
        self.buffer_bind_ty: BufferBindingType = buffer_bind_ty
        self.iarg: int = iarg
        self.binding: int = binding
121
+
122
+
123
class TextureBindingType(Enum):
    """Whether a bound texture is sampled (``Texture``) or storage (``RwTexture``)."""

    Texture = 0
    RwTexture = 1
126
+
127
+
128
class TextureBinding:
    """One texture bound to a task.

    Attributes:
        binding: Binding slot index.
        iarg: Index of the kernel argument supplying the texture.
        texture_bind_ty: Sampled vs. storage texture.
    """

    def __init__(self, binding: int, iarg: int, texture_bind_ty: TextureBindingType):
        self.texture_bind_ty: TextureBindingType = texture_bind_ty
        self.iarg: int = iarg
        self.binding: int = binding
133
+
134
+
135
class TaskType(Enum):
    """Kind of offloaded task, stored as ``task_type`` in dr metadata."""

    Serial = 0
    RangeFor = 1
    StructFor = 2
    MeshFor = 3
    ListGen = 4
    Gc = 5
    GcRc = 6
143
+
144
+
145
class LaunchGrid:
    """Launch dimensions of a task.

    Attributes:
        block_size: Threads per group (dr ``advisory_num_threads_per_group``).
        grid_size: Total thread count, not a group count
            (dr ``advisory_total_num_threads``).
    """

    def __init__(self, block_size: int, grid_size: int):
        self.grid_size: int = grid_size
        self.block_size: int = block_size
149
+
150
+
151
class Task:
    """A single offloaded task of a kernel.

    Attributes:
        name: Task name.
        task_ty: Kind of task.
        buffer_binds: Buffers bound to the task.
        texture_binds: Textures bound to the task.
        launch_grid: Launch dimensions.
    """

    def __init__(
        self,
        name: str,
        task_ty: TaskType,
        buffer_binds: List[BufferBinding],
        texture_binds: List[TextureBinding],
        launch_grid: LaunchGrid,
    ):
        self.name: str = name
        self.task_ty: TaskType = task_ty
        self.launch_grid: LaunchGrid = launch_grid
        self.buffer_binds: List[BufferBinding] = buffer_binds
        self.texture_binds: List[TextureBinding] = texture_binds
165
+
166
+
167
class Field:
    """A field in the root buffer.

    Attributes:
        name: Field name.
        dtype: Scalar element type.
        element_shape: Shape of each element.
        shape: Shape of the field itself.
        offset: Taken from dr ``mem_offset_in_parent``.
    """

    def __init__(
        self,
        name: str,
        dtype: DataType,
        element_shape: List[int],
        shape: List[int],
        offset: int,
    ):
        self.offset: int = offset
        self.shape: List[int] = shape
        self.element_shape: List[int] = element_shape
        self.dtype: DataType = dtype
        self.name: str = name
181
+
182
+
183
class Kernel:
    """A compiled kernel: its name, calling context and ordered list of tasks."""

    def __init__(self, name: str, context: Context, tasks: List[Task]):
        self.tasks: List[Task] = tasks
        self.context: Context = context
        self.name = name
188
+
189
+
190
class Metadata:
    """Top-level module metadata.

    ``fields`` and ``kernels`` are indexed by name. If two entries share a
    name, the later one wins.
    """

    def __init__(
        self,
        fields: List[Field],
        kernels: List[Kernel],
        required_caps: List[DeviceCapability],
        root_buffer_size: int,
    ):
        self.fields: Dict[str, Field] = {}
        for field in fields:
            self.fields[field.name] = field
        self.kernels: Dict[str, Kernel] = {}
        for kernel in kernels:
            self.kernels[kernel.name] = kernel
        self.required_caps: List[DeviceCapability] = required_caps
        self.root_buffer_size: int = root_buffer_size
202
+
203
+
204
def from_dr_field(d: dr.FieldAttributes) -> Field:
    """Convert a dr field record into a structured ``Field``."""
    return Field(
        name=d.field_name,
        dtype=DataType(d.dtype),
        element_shape=d.element_shape,
        shape=d.shape,
        offset=d.mem_offset_in_parent,
    )
212
+
213
+
214
class _OpaqueArgumentType(Enum):
    """Argument kind inferred from task bindings; used only for legacy metadata."""

    NdArray = 0
    Texture = 1
    RwTexture = 2


def _from_dr_task(task: Any, iarg2arg_ty: Dict[int, _OpaqueArgumentType]) -> Task:
    """Convert one dr task record to a ``Task``.

    Also records in ``iarg2arg_ty`` the kind of every argument the task binds
    (ExtArr buffers and textures), so that arguments lacking an explicit
    ``ptype`` can be classified afterwards.
    """
    buffer_binds = []
    for buffer_bind in task.buffer_binds:
        # BufferBindingType(...) raises ValueError on unknown values, so no
        # catch-all branch is needed for the remaining binding kinds.
        buffer_ty = BufferBindingType(buffer_bind.buffer.type)
        iarg = buffer_bind.buffer.root_id
        buffer_binds.append(BufferBinding(buffer_bind.binding, iarg, buffer_ty))
        if buffer_ty == BufferBindingType.ExtArr:
            iarg2arg_ty[iarg] = _OpaqueArgumentType.NdArray

    texture_binds = []
    for texture_bind in task.texture_binds:
        iarg = texture_bind.arg_id
        if texture_bind.is_storage:
            bind_ty, arg_ty = TextureBindingType.RwTexture, _OpaqueArgumentType.RwTexture
        else:
            bind_ty, arg_ty = TextureBindingType.Texture, _OpaqueArgumentType.Texture
        texture_binds.append(TextureBinding(texture_bind.binding, iarg, bind_ty))
        iarg2arg_ty[iarg] = arg_ty

    launch_grid = LaunchGrid(task.advisory_num_threads_per_group, task.advisory_total_num_threads)
    return Task(task.name, TaskType(task.task_type), buffer_binds, texture_binds, launch_grid)


def _from_dr_arg(arg: Any, i: int, ctx: Any, iarg2arg_ty: Dict[int, _OpaqueArgumentType]) -> Argument:
    """Convert the ``i``-th dr argument record of context ``ctx`` to an ``Argument``.

    The argument kind comes from ``arg.ptype`` when present. Metadata without a
    ``ptype`` (kept for backward compatibility) falls back to ``arg.is_array``
    plus the kinds recorded from task bindings.

    Raises:
        ValueError: if the parameter type has no ``Argument`` representation.
        KeyError: if a legacy array argument is not bound by any task.
    """
    # ParameterType(None) would raise, so check the raw value first; otherwise
    # the legacy fallback below could never be reached.
    raw_ptype = getattr(arg, "ptype", None)
    if raw_ptype is not None:
        ptype = ParameterType(raw_ptype)
    elif not arg.is_array:
        ptype = ParameterType.Scalar
    else:
        opaque_ty = iarg2arg_ty[arg.index]
        if opaque_ty == _OpaqueArgumentType.NdArray:
            ptype = ParameterType.Ndarray
        elif opaque_ty == _OpaqueArgumentType.Texture:
            ptype = ParameterType.Texture
        else:
            ptype = ParameterType.RwTexture

    if ptype == ParameterType.Scalar:
        return ArgumentScalar(arg.name, DataType(arg.dtype))
    if ptype == ParameterType.Ndarray:
        return ArgumentNdArray(
            arg.name,
            DataType(arg.dtype),
            arg.element_shape,
            arg.field_dim,
            NdArrayAccess(ctx.arr_access[i]),
        )
    if ptype == ParameterType.Texture:
        return ArgumentTexture(arg.name, arg.field_dim)
    if ptype == ParameterType.RwTexture:
        return ArgumentRwTexture(arg.name, Format(arg.format), arg.field_dim)
    # Explicit raise instead of `assert False`, which `-O` would strip and
    # silently misalign the argument list.
    raise ValueError(f"Unsupported parameter type {ptype!r} for kernel argument {i}")


def from_dr_kernel(d: dr.KernelAttributes) -> Kernel:
    """Convert a dr kernel record into a structured ``Kernel``.

    JIT-evaluator kernels are not supported (asserted). At most one return
    value is supported (asserted).
    """
    assert d.is_jit_evaluator is False

    # Argument kinds inferred from task bindings. Only consulted for arguments
    # without an explicit ``ptype``. Filled in while tasks are converted.
    iarg2arg_ty: Dict[int, _OpaqueArgumentType] = {}
    tasks = [_from_dr_task(task, iarg2arg_ty) for task in d.tasks_attribs]

    ctx = d.ctx_attribs
    args = []
    for i, arg in enumerate(ctx.arg_attribs_vec_):
        assert i == arg.index
        args.append(_from_dr_arg(arg, i, ctx, iarg2arg_ty))

    rets = ctx.ret_attribs_vec_
    assert len(rets) <= 1
    rv = ReturnValue(DataType(rets[0].dtype)) if len(rets) != 0 else None

    return Kernel(d.name, Context(args, rv), tasks)
330
+
331
+
332
def from_dr_metadata(d: dr.Metadata) -> Metadata:
    """Convert dr module metadata into a structured ``Metadata``.

    Each required capability becomes its bare key when its value is 1, and
    ``"key=value"`` otherwise.
    """
    fields = [from_dr_field(f) for f in d.fields]
    kernels = [from_dr_kernel(k) for k in d.kernels]
    required_caps = [cap.key if cap.value == 1 else f"{cap.key}={cap.value}" for cap in d.required_caps]
    return Metadata(fields, kernels, required_caps, d.root_buffer_size)
344
+
345
+
346
def to_dr_field(f: Field) -> Dict[str, Any]:
    """Serialize a ``Field`` to its dr dict form. Not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
348
+
349
+
350
def to_dr_kernel(s: Kernel) -> Dict[str, Any]:
    """Serialize ``s`` into the plain-dict form of a dr kernel record.

    Arguments are packed back to back: each one's ``offset_in_mem`` is the sum
    of the preceding strides, and ``args_bytes_`` is the total. Array-like
    arguments (ndarrays and textures) use a stride of 4. Scalars use their
    data type size.
    """

    def task_to_dict(task):
        buffer_binds = [
            {
                "binding": bb.binding,
                "buffer": {
                    "root_id": bb.iarg,
                    "type": bb.buffer_bind_ty.value,
                },
            }
            for bb in task.buffer_binds
        ]
        texture_binds = [
            {
                "arg_id": tb.iarg,
                "binding": tb.binding,
                "is_storage": tb.texture_bind_ty == TextureBindingType.RwTexture,
            }
            for tb in task.texture_binds
        ]
        range_for_attribs = None
        if task.task_ty == TaskType.RangeFor:
            # Range-for tasks are serialized as a constant range [0, grid_size).
            range_for_attribs = {
                "begin": 0,
                "const_begin": True,
                "const_end": True,
                "end": task.launch_grid.grid_size,
            }
        return {
            "advisory_num_threads_per_group": task.launch_grid.block_size,
            "advisory_total_num_threads": task.launch_grid.grid_size,
            "buffer_binds": buffer_binds,
            "name": task.name,
            "range_for_attribs": range_for_attribs,
            "task_type": task.task_ty.value,
            "texture_binds": texture_binds,
        }

    tasks_attribs = [task_to_dict(task) for task in s.tasks]

    arg_attribs = []
    arr_access = []
    offset = 0
    for index, arg in enumerate(s.context.args):
        if isinstance(arg, ArgumentNdArray):
            dtype_id, shape, dim, fmt, is_array, stride, access = (
                arg.dtype.value,
                arg.element_shape,
                arg.ndim,
                Format.unknown,
                True,
                4,
                arg.access.value,
            )
        elif isinstance(arg, ArgumentTexture):
            dtype_id, shape, dim, fmt, is_array, stride, access = 1, [], arg.ndim, Format.unknown, True, 4, 0
        elif isinstance(arg, ArgumentRwTexture):
            dtype_id, shape, dim, fmt, is_array, stride, access = 1, [], arg.ndim, arg.fmt, True, 4, 0
        elif isinstance(arg, ArgumentScalar):
            dtype_id, shape, dim, fmt, is_array, stride, access = (
                arg.dtype.value,
                [],
                0,
                Format.unknown,
                False,
                get_data_type_size(arg.dtype),
                0,
            )
        else:
            assert False
        arg_attribs.append(
            {
                "dtype": dtype_id,
                "element_shape": shape,
                "field_dim": dim,
                "format": fmt,
                "index": index,
                "is_array": is_array,
                "offset_in_mem": offset,
                "stride": stride,
            }
        )
        arr_access.append(access)
        offset += stride
    # Offsets are contiguous, so the running offset is the total argument size.
    args_bytes = offset

    ret_attribs = []
    rets_bytes = 0
    ret = s.context.ret
    if ret is not None:
        ret_size = get_data_type_size(ret.dtype)
        ret_attribs.append(
            {
                "dtype": ret.dtype.value,
                "element_shape": [],
                "field_dim": 0,
                "format": Format.unknown,
                "index": 0,
                "is_array": False,
                "offset_in_mem": 0,
                "stride": ret_size,
            }
        )
        rets_bytes = ret_size

    return {
        "is_jit_evaluator": False,
        "ctx_attribs": {
            "arg_attribs_vec_": arg_attribs,
            "args_bytes_": args_bytes,
            "arr_access": arr_access,
            # NOTE(review): fixed constant carried over as-is; its origin is not
            # visible here — confirm against the runtime before changing.
            "extra_args_bytes_": 1536,
            "ret_attribs_vec_": ret_attribs,
            "rets_bytes_": rets_bytes,
        },
        "name": s.name,
        "tasks_attribs": tasks_attribs,
    }
489
+
490
+
491
def to_dr_metadata(s: Metadata) -> dr.Metadata:
    """Serialize ``s`` back into a ``dr.Metadata``.

    A capability string ``"key=value"`` becomes ``{"key": key, "value": int(value)}``.
    A bare ``"key"`` gets value 1. ``to_dr_field`` is called for every field and
    currently raises ``NotImplementedError``, so metadata with fields cannot be
    serialized.
    """

    def cap_to_dict(cap):
        key, sep, value = str(cap).partition("=")
        return {
            "key": key,
            "value": int(value) if sep else 1,
        }

    return dr.Metadata(
        {
            "fields": [to_dr_field(f) for f in s.fields.values()],
            "kernels": [to_dr_kernel(k) for k in s.kernels.values()],
            "required_caps": [cap_to_dict(c) for c in s.required_caps],
            "root_buffer_size": s.root_buffer_size,
        }
    )
518
+
519
+
520
class NamedArgument:
    """Pairs a graph-level symbolic name with the kernel ``Argument`` it feeds."""

    def __init__(self, name: str, arg: Argument):
        self.arg = arg
        self.name = name
524
+
525
+
526
class Dispatch:
    """One kernel launch inside a graph, with its named arguments in order."""

    def __init__(self, kernel: Kernel, args: List[NamedArgument]):
        self.args = args
        self.kernel = kernel
530
+
531
+
532
class Graph:
    """A named sequence of kernel dispatches.

    ``args`` lists each distinct symbolic argument name once, in order of first
    appearance. The ``Argument`` kept is the one from the last dispatch using
    that name.
    """

    def __init__(self, name: str, dispatches: List[Dispatch]):
        self.name = name
        self.dispatches = dispatches
        by_name = {}
        for dispatch in dispatches:
            for named in dispatch.args:
                by_name[named.name] = named.arg
        self.args: List[NamedArgument] = [NamedArgument(n, a) for n, a in by_name.items()]
538
+
539
+
540
def from_dr_graph(meta: Metadata, j: dr.Graph) -> Graph:
    """Rebuild a ``Graph`` from its dr form, resolving kernels by name via ``meta``.

    The i-th symbolic argument of each dispatch is bound positionally to the
    i-th argument of the dispatched kernel.
    """
    dispatches = []
    for dr_dispatch in j.value.dispatches:
        kernel = meta.kernels[dr_dispatch.kernel_name]
        kernel_args = kernel.context.args
        named_args = [NamedArgument(sym.name, kernel_args[idx]) for idx, sym in enumerate(dr_dispatch.symbolic_args)]
        dispatches.append(Dispatch(kernel, named_args))
    return Graph(j.key, dispatches)
550
+
551
+
552
def to_dr_graph(s: Graph) -> dr.Graph:
    """Serialize a ``Graph`` into a ``dr.Graph``.

    Each symbolic argument gets a ``tag`` by kind: 0 scalar, 2 ndarray,
    3 texture, 4 rw-texture. Textures are emitted with an f32 dtype id, no
    element shape and ``field_dim`` 0.
    """

    def symbolic_arg(named):
        arg = named.arg
        if isinstance(arg, ArgumentScalar):
            dtype_id, element_shape, field_dim, tag = arg.dtype.value, [], 0, 0
        elif isinstance(arg, ArgumentNdArray):
            dtype_id, element_shape, field_dim, tag = arg.dtype.value, arg.element_shape, arg.ndim, 2
        elif isinstance(arg, ArgumentTexture):
            dtype_id, element_shape, field_dim, tag = DataType.f32.value, [], 0, 3
        elif isinstance(arg, ArgumentRwTexture):
            dtype_id, element_shape, field_dim, tag = DataType.f32.value, [], 0, 4
        else:
            assert False
        return {
            "dtype_id": dtype_id,
            "element_shape": element_shape,
            "field_dim": field_dim,
            "name": named.name,
            "num_channels": 0,
            "tag": tag,
        }

    dispatches = [
        {
            "kernel_name": dispatch.kernel.name,
            "symbolic_args": [symbolic_arg(a) for a in dispatch.args],
        }
        for dispatch in s.dispatches
    ]
    return dr.Graph(
        {
            "key": s.name,
            "value": {
                "dispatches": dispatches,
            },
        }
    )