gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
@@ -0,0 +1,117 @@
1
+ # type: ignore
2
+
3
+ from taichi._kernels import (
4
+ blit_from_field_to_field,
5
+ scan_add_inclusive,
6
+ sort_stage,
7
+ uniform_add,
8
+ warp_shfl_up_i32,
9
+ )
10
+ from taichi.lang.impl import current_cfg, field
11
+ from taichi.lang.kernel_impl import data_oriented
12
+ from taichi.lang.misc import cuda, vulkan
13
+ from taichi.lang.runtime_ops import sync
14
+ from taichi.lang.simt import subgroup
15
+ from taichi.types.primitive_types import i32
16
+
17
+
18
def parallel_sort(keys, values=None):
    """Odd-even merge sort

    Launches one ``sort_stage`` kernel per (p, k) stage and calls ``sync()``
    after each launch.

    Args:
        keys: Container to sort; only ``keys.shape[0]`` is read in this
            function, everything else is left to ``sort_stage``.
        values: Optional companion container. When given, it is forwarded to
            ``sort_stage`` with flag ``1``; otherwise ``keys`` is passed a
            second time with flag ``0``.

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    # The flag/payload pair is the same for every stage, so pick it once.
    if values is None:
        has_values, payload = 0, keys
    else:
        has_values, payload = 1, values

    p = 1
    while p < N:
        k = p
        while k >= 1:
            # Inside the loops 2 * k <= p < N (or k == p < N), so the numerator
            # is non-negative and floor division equals the truncating
            # int(a / b) it replaces, without a round-trip through float.
            invocations = (N - k - k % p) // (2 * k) + 1
            sort_stage(keys, has_values, payload, N, p, k, invocations)
            sync()
            k //= 2
        p *= 2
41
+
42
+
43
@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) Helper

    Use this helper to perform an inclusive, in-place parallel prefix sum.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """

    def __init__(self, length):
        """Precompute the per-level layout and allocate the scratch buffer.

        Args:
            length: Number of elements copied in and out by :meth:`run`.
        """
        self.sorting_length = length

        BLOCK_SZ = 64

        # Buffer position and length
        # This is a single buffer implementation for ease of aot usage
        ele_num = length
        self.ele_nums = [ele_num]
        start_pos = 0
        self.ele_nums_pos = [start_pos]

        while ele_num > 1:
            # Ceiling division done in integer arithmetic; int() keeps the
            # result a plain Python int whatever numeric type `length` has.
            ele_num = int((ele_num + BLOCK_SZ - 1) // BLOCK_SZ)
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        # NOTE(review): for length <= 1 the loop above never runs, so the
        # scratch field is created with shape=0 -- confirm this is intended.
        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        """Run the prefix sum on ``input_arr`` and write the result back into it.

        Args:
            input_arr: Array whose ``dtype`` must equal ``i32``.

        Raises:
            RuntimeError: If ``input_arr.dtype`` is not ``i32``, or if the
                current arch is neither ``cuda`` nor ``vulkan``.
        """
        length = self.sorting_length
        ele_nums_pos = self.ele_nums_pos
        num_levels = len(self.ele_nums) - 1

        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        arch = current_cfg().arch
        if arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(f"{str(arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        for i in range(num_levels):
            scan_add_inclusive(
                self.large_arr,
                ele_nums_pos[i],
                ele_nums_pos[i + 1],
                i == num_levels - 1,  # True only for the final level
                inclusive_add,
            )

        # Walk the levels back down, skipping the final one.
        for i in range(num_levels - 2, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)
115
+
116
+
117
+ __all__ = ["parallel_sort", "PrefixSumExecutor"]
taichi/aot/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # type: ignore
2
+
3
+ """Taichi's AOT (ahead of time) module.
4
+
5
+ Users can use Taichi as a GPU compute shader/kernel compiler by compiling their
6
+ Taichi kernels into an AOT module.
7
+ """
8
+
9
+ import taichi.aot.conventions
10
+ from taichi.aot._export import export, export_as
11
+ from taichi.aot.conventions.gfxruntime140 import GfxRuntime140
12
+ from taichi.aot.module import Module
taichi/aot/_export.py ADDED
@@ -0,0 +1,28 @@
1
+ # type: ignore
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+
6
class AotExportKernel:
    """Record tying a kernel object to its export name and template types."""

    def __init__(self, f, name: str, template_types: Dict[str, Any]) -> None:
        # Attributes are bound in the order kernel, name, template_types.
        self.kernel, self.name, self.template_types = f, name, template_types
11
+
12
+
13
+ _aot_kernels: List[AotExportKernel] = []
14
+
15
+
16
def export_as(name: str, *, template_types: Optional[Dict[str, Any]] = None):
    """Return a decorator that registers a kernel for AOT export as ``name``.

    The decorator appends an ``AotExportKernel`` record to the module-level
    ``_aot_kernels`` list and returns the decorated object unchanged.

    Args:
        name: Export name stored on the record.
        template_types: Optional mapping stored on the record; an empty dict
            is used when omitted or empty.

    Raises:
        AssertionError: (from the decorator) if the decorated object has no
            ``_is_wrapped_kernel`` attribute.
    """

    def inner(f):
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if not hasattr(f, "_is_wrapped_kernel"):
            raise AssertionError("Only Taichi kernels can be exported")

        _aot_kernels.append(AotExportKernel(f, name, template_types or {}))
        return f

    return inner
25
+
26
+
27
def export(f):
    """Register ``f`` for AOT export under its own ``__name__`` and return it."""
    register = export_as(f.__name__)
    return register(f)
@@ -0,0 +1,3 @@
1
+ # type: ignore
2
+
3
+ from taichi.aot.conventions import gfxruntime140
@@ -0,0 +1,38 @@
1
+ # type: ignore
2
+
3
+ import json
4
+ import zipfile
5
+ from pathlib import Path
6
+ from typing import Any, List
7
+
8
+ from taichi.aot.conventions.gfxruntime140 import dr, sr
9
+
10
+
11
class GfxRuntime140:
    """In-memory form of a module's ``metadata.json`` and ``graphs.json``.

    The JSON payloads are parsed with the ``dr`` module and then converted
    with the ``sr`` module; the results are kept in ``self.metadata`` and
    ``self.graphs``.
    """

    def __init__(self, metadata_json: Any, graphs_json: Any) -> None:
        """Build the runtime description from already-decoded JSON values.

        Args:
            metadata_json: Decoded content of ``metadata.json``.
            graphs_json: Decoded content of ``graphs.json``; iterated to
                build one graph per element.
        """
        metadata = dr.from_json_metadata(metadata_json)
        graphs = [dr.from_json_graph(x) for x in graphs_json]
        self.metadata = sr.from_dr_metadata(metadata)
        # Graph conversion is given the converted metadata.
        self.graphs = [sr.from_dr_graph(self.metadata, x) for x in graphs]

    @staticmethod
    def from_module(module_path: str) -> "GfxRuntime140":
        """Load a module from a zip archive or from a directory.

        Args:
            module_path: Path to a file (opened as a zip archive) or to a
                directory containing ``metadata.json`` and ``graphs.json``.

        Returns:
            A new ``GfxRuntime140`` built from the two JSON documents.
        """
        root = Path(module_path)
        if root.is_file():
            with zipfile.ZipFile(root) as z:
                with z.open("metadata.json") as f:
                    metadata_json = json.load(f)
                with z.open("graphs.json") as f:
                    graphs_json = json.load(f)
        else:
            # Read bytes, as the archive branch does: json.load detects the
            # UTF-8/16/32 encoding itself, so decoding no longer depends on
            # the platform's locale default.
            with open(root / "metadata.json", "rb") as f:
                metadata_json = json.load(f)
            with open(root / "graphs.json", "rb") as f:
                graphs_json = json.load(f)

        return GfxRuntime140(metadata_json, graphs_json)

    def to_metadata_json(self) -> Any:
        """Convert ``self.metadata`` back to its JSON form via ``sr`` then ``dr``."""
        return dr.to_json_metadata(sr.to_dr_metadata(self.metadata))

    def to_graphs_json(self) -> List[Any]:
        """Convert every graph in ``self.graphs`` back to its JSON form."""
        return [dr.to_json_graph(sr.to_dr_graph(x)) for x in self.graphs]
@@ -0,0 +1,244 @@
1
+ # type: ignore
2
+
3
+ """
4
+ Data representation of all JSON data structures following the GfxRuntime140
5
+ convention.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from taichi.aot.utils import dump_json_data_model, json_data_model
11
+
12
+
13
@json_data_model
class FieldAttributes:
    """JSON data model for one field entry.

    All seven keys are required; values are coerced to the annotated types.
    """

    def __init__(self, j: Dict[str, Any]) -> None:
        # Fetch every required key first, in order, so a missing key raises
        # KeyError before any value is converted.
        (
            dtype,
            dtype_name,
            element_shape,
            field_name,
            is_scalar,
            mem_offset_in_parent,
            shape,
        ) = [
            j[key]
            for key in (
                "dtype",
                "dtype_name",
                "element_shape",
                "field_name",
                "is_scalar",
                "mem_offset_in_parent",
                "shape",
            )
        ]

        self.dtype: int = int(dtype)
        self.dtype_name: str = str(dtype_name)
        self.element_shape: List[int] = list(map(int, element_shape))
        self.field_name: str = str(field_name)
        self.is_scalar: bool = bool(is_scalar)
        self.mem_offset_in_parent: int = int(mem_offset_in_parent)
        self.shape: List[int] = list(map(int, shape))
31
+
32
+
33
@json_data_model
class ArgumentAttributes:
    """JSON data model for one argument entry.

    The entry is a ``{"key": [...], "value": {...}}`` pair: the index is the
    first element of ``"key"`` and all other attributes come from ``"value"``.
    """

    def __init__(self, j: Dict[str, Any]) -> None:
        index = j["key"][0]
        value = j["value"]
        # Required keys are fetched in order before any conversion happens.
        dtype, element_shape, field_dim, fmt, is_array, offset_in_mem, stride = [
            value[key]
            for key in (
                "dtype",
                "element_shape",
                "field_dim",
                "format",
                "is_array",
                "offset_in_mem",
                "stride",
            )
        ]
        # (penguinliong) Note that the name field is optional for kernels.
        # Kernels are always launched by indexed arguments and this is for
        # debugging and header generation only.
        if "name" in value and len(value["name"]) > 0:
            name = value["name"]
        else:
            name = None
        # "ptype" is optional as well; absent means None.
        ptype = value.get("ptype")

        self.dtype: int = int(dtype)
        self.element_shape: List[int] = list(map(int, element_shape))
        self.field_dim: int = int(field_dim)
        self.format: int = int(fmt)
        self.index: int = int(index)
        self.is_array: bool = bool(is_array)
        self.offset_in_mem: int = int(offset_in_mem)
        self.stride: int = int(stride)
        self.name: Optional[str] = None if name is None else str(name)
        self.ptype: Optional[int] = None if ptype is None else int(ptype)
60
+
61
+
62
@json_data_model
class ContextAttributes:
    """JSON data model for a context: argument and return attribute lists.

    Argument attributes are stored sorted by their ``index``; return
    attributes keep the order found in the JSON.
    """

    def __init__(self, j: Dict[str, Any]) -> None:
        # All five keys are required and are fetched before any conversion.
        arg_attribs_vec_, args_bytes_, arr_access, ret_attribs_vec_, rets_bytes_ = [
            j[key]
            for key in (
                "arg_attribs_vec_",
                "args_bytes_",
                "arr_access",
                "ret_attribs_vec_",
                "rets_bytes_",
            )
        ]

        self.arg_attribs_vec_: List[ArgumentAttributes] = sorted(
            (ArgumentAttributes(entry) for entry in arg_attribs_vec_),
            key=lambda attribs: attribs.index,
        )
        self.args_bytes_: int = int(args_bytes_)
        # Each arr_access entry is a mapping; only its "value" is kept.
        self.arr_access: List[int] = [int(entry["value"]) for entry in arr_access]
        self.ret_attribs_vec_: List[ArgumentAttributes] = [
            ArgumentAttributes(entry) for entry in ret_attribs_vec_
        ]
        self.rets_bytes_: int = int(rets_bytes_)
77
+
78
+
79
@json_data_model
class Buffer:
    """JSON data model for a buffer: a root id and a type code."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # "root_id" is a sequence in the JSON; only its first element is used.
        root_id, ty = j["root_id"][0], j["type"]

        self.root_id: int = int(root_id)
        self.type: int = int(ty)
87
+
88
+
89
@json_data_model
class BufferBinding:
    """JSON data model pairing a binding number with a :class:`Buffer`."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # Both keys are required and are fetched before any conversion.
        binding, buffer = j["binding"], j["buffer"]

        self.binding: int = int(binding)
        self.buffer: Buffer = Buffer(buffer)
97
+
98
+
99
@json_data_model
class TextureBinding:
    """JSON data model for a texture binding (arg id, binding, storage flag)."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # All three keys are required and are fetched before any conversion.
        arg_id, binding, is_storage = [
            j[key] for key in ("arg_id", "binding", "is_storage")
        ]

        self.arg_id: int = int(arg_id)
        self.binding: int = int(binding)
        self.is_storage: bool = bool(is_storage)
109
+
110
+
111
@json_data_model
class RangeForAttributes:
    """JSON data model for a range: begin/end values and their const flags."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # All four keys are required and are fetched before any conversion.
        begin, const_begin, const_end, end = [
            j[key] for key in ("begin", "const_begin", "const_end", "end")
        ]

        self.begin: int = int(begin)
        self.const_begin: bool = bool(const_begin)
        self.const_end: bool = bool(const_end)
        self.end: int = int(end)
123
+
124
+
125
@json_data_model
class TaskAttributes:
    """JSON data model for one task of a kernel.

    ``range_for_attribs`` is optional in the JSON and is stored as ``None``
    when absent; every other key is required.
    """

    def __init__(self, j: Dict[str, Any]) -> None:
        # Keys are read in this fixed order before any conversion.
        advisory_num_threads_per_group = j["advisory_num_threads_per_group"]
        advisory_total_num_threads = j["advisory_total_num_threads"]
        buffer_binds = j["buffer_binds"]
        name = j["name"]
        range_for_attribs = j.get("range_for_attribs")
        task_type = j["task_type"]
        texture_binds = j["texture_binds"]

        self.advisory_num_threads_per_group: int = int(advisory_num_threads_per_group)
        self.advisory_total_num_threads: int = int(advisory_total_num_threads)
        self.buffer_binds: List[BufferBinding] = [
            BufferBinding(bind) for bind in buffer_binds
        ]
        self.name: str = str(name)
        self.range_for_attribs: Optional[RangeForAttributes] = (
            None if range_for_attribs is None else RangeForAttributes(range_for_attribs)
        )
        self.task_type: int = int(task_type)
        self.texture_binds: List[TextureBinding] = [
            TextureBinding(bind) for bind in texture_binds
        ]
145
+
146
+
147
@json_data_model
class DeviceCapabilityLevel:
    """JSON data model for a ``{"key": ..., "value": ...}`` capability entry."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # Both keys are required and are fetched before any conversion.
        cap_key, cap_value = j["key"], j["value"]

        self.key: str = str(cap_key)
        self.value: int = int(cap_value)
155
+
156
+
157
@json_data_model
class KernelAttributes:
    """JSON data model for a kernel: its context, name and list of tasks."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # All four keys are required and are fetched before any conversion.
        ctx_attribs, is_jit_evaluator, name, tasks_attribs = [
            j[key] for key in ("ctx_attribs", "is_jit_evaluator", "name", "tasks_attribs")
        ]

        self.ctx_attribs: ContextAttributes = ContextAttributes(ctx_attribs)
        self.is_jit_evaluator: bool = bool(is_jit_evaluator)
        self.name: str = str(name)
        self.tasks_attribs: List[TaskAttributes] = [
            TaskAttributes(task) for task in tasks_attribs
        ]
169
+
170
+
171
@json_data_model
class Metadata:
    """JSON data model for the top-level metadata document.

    Holds the field list, kernel list, required capabilities and the root
    buffer size.
    """

    def __init__(self, j: Dict[str, Any]) -> None:
        # All four keys are required and are fetched before any conversion.
        fields, kernels, required_caps, root_buffer_size = [
            j[key] for key in ("fields", "kernels", "required_caps", "root_buffer_size")
        ]

        self.fields: List[FieldAttributes] = [FieldAttributes(entry) for entry in fields]
        self.kernels: List[KernelAttributes] = [KernelAttributes(entry) for entry in kernels]
        self.required_caps: List[DeviceCapabilityLevel] = [
            DeviceCapabilityLevel(entry) for entry in required_caps
        ]
        self.root_buffer_size: int = int(root_buffer_size)
183
+
184
+
185
def from_json_metadata(j: Dict[str, Any]) -> Metadata:
    """Build a :class:`Metadata` model from its decoded JSON dictionary."""
    metadata = Metadata(j)
    return metadata
187
+
188
+
189
def to_json_metadata(meta_data: Metadata) -> Dict[str, Any]:
    """Dump a :class:`Metadata` model through ``dump_json_data_model``."""
    dumped = dump_json_data_model(meta_data)
    return dumped
191
+
192
+
193
@json_data_model
class SymbolicArgument:
    """JSON data model for one symbolic argument of a dispatch."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # All six keys are required and are fetched before any conversion.
        dtype_id, element_shape, field_dim, name, num_channels, tag = [
            j[key]
            for key in (
                "dtype_id",
                "element_shape",
                "field_dim",
                "name",
                "num_channels",
                "tag",
            )
        ]

        self.dtype_id: int = int(dtype_id)
        self.element_shape: List[int] = list(map(int, element_shape))
        self.field_dim: int = int(field_dim)
        self.name: str = str(name)
        self.num_channels: int = int(num_channels)
        self.tag: int = int(tag)
209
+
210
+
211
@json_data_model
class Dispatch:
    """JSON data model for a dispatch: a kernel name plus symbolic arguments."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # Both keys are required and are fetched before any conversion.
        kernel_name, symbolic_args = j["kernel_name"], j["symbolic_args"]

        self.kernel_name: str = str(kernel_name)
        self.symbolic_args: List[SymbolicArgument] = [
            SymbolicArgument(arg) for arg in symbolic_args
        ]
219
+
220
+
221
@json_data_model
class GraphData:
    """JSON data model for a graph body: a list of :class:`Dispatch` items."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # "dispatches" is the only key read; it is required.
        self.dispatches: List[Dispatch] = [Dispatch(entry) for entry in j["dispatches"]]
227
+
228
+
229
@json_data_model
class Graph:
    """JSON data model for a ``{"key": ..., "value": ...}`` graph entry."""

    def __init__(self, j: Dict[str, Any]) -> None:
        # Both keys are required and are fetched before any conversion.
        graph_key, graph_value = j["key"], j["value"]

        self.key: str = str(graph_key)
        self.value: GraphData = GraphData(graph_value)
237
+
238
+
239
def from_json_graph(j: Dict[str, Any]) -> Graph:
    """Build a :class:`Graph` model from its decoded JSON dictionary."""
    graph = Graph(j)
    return graph
241
+
242
+
243
def to_json_graph(graph: Graph) -> Dict[str, Any]:
    """Dump a :class:`Graph` model through ``dump_json_data_model``."""
    dumped = dump_json_data_model(graph)
    return dumped