gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 1.0.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi/CHANGELOG.md +6 -0
  2. gstaichi/__init__.py +40 -0
  3. {taichi → gstaichi}/_funcs.py +8 -8
  4. {taichi → gstaichi}/_kernels.py +19 -19
  5. gstaichi/_lib/__init__.py +3 -0
  6. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
  7. taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
  8. {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
  9. {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
  10. {taichi → gstaichi}/_lib/utils.py +15 -15
  11. {taichi → gstaichi}/_logging.py +1 -1
  12. gstaichi/_snode/__init__.py +5 -0
  13. {taichi → gstaichi}/_snode/fields_builder.py +27 -29
  14. {taichi → gstaichi}/_snode/snode_tree.py +5 -5
  15. gstaichi/_test_tools/__init__.py +0 -0
  16. gstaichi/_test_tools/load_kernel_string.py +30 -0
  17. gstaichi/_version.py +1 -0
  18. {taichi → gstaichi}/_version_check.py +8 -5
  19. gstaichi/ad/__init__.py +3 -0
  20. {taichi → gstaichi}/ad/_ad.py +26 -26
  21. {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
  22. {taichi → gstaichi}/examples/minimal.py +1 -1
  23. {taichi → gstaichi}/experimental.py +1 -1
  24. gstaichi/lang/__init__.py +50 -0
  25. {taichi → gstaichi}/lang/_ndarray.py +30 -26
  26. {taichi → gstaichi}/lang/_ndrange.py +8 -8
  27. gstaichi/lang/_template_mapper.py +199 -0
  28. {taichi → gstaichi}/lang/_texture.py +19 -19
  29. {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
  30. {taichi → gstaichi}/lang/any_array.py +13 -13
  31. {taichi → gstaichi}/lang/argpack.py +29 -29
  32. gstaichi/lang/ast/__init__.py +5 -0
  33. {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
  34. {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
  35. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  36. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  37. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  38. {taichi → gstaichi}/lang/ast/checkers.py +5 -5
  39. gstaichi/lang/ast/transform.py +9 -0
  40. {taichi → gstaichi}/lang/common_ops.py +12 -12
  41. gstaichi/lang/exception.py +80 -0
  42. {taichi → gstaichi}/lang/expr.py +22 -22
  43. {taichi → gstaichi}/lang/field.py +29 -27
  44. {taichi → gstaichi}/lang/impl.py +116 -121
  45. {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
  46. {taichi → gstaichi}/lang/kernel_impl.py +330 -363
  47. {taichi → gstaichi}/lang/matrix.py +119 -115
  48. {taichi → gstaichi}/lang/matrix_ops.py +6 -6
  49. {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
  50. {taichi → gstaichi}/lang/mesh.py +22 -22
  51. {taichi → gstaichi}/lang/misc.py +39 -68
  52. {taichi → gstaichi}/lang/ops.py +146 -141
  53. {taichi → gstaichi}/lang/runtime_ops.py +2 -2
  54. {taichi → gstaichi}/lang/shell.py +3 -3
  55. {taichi → gstaichi}/lang/simt/__init__.py +1 -1
  56. {taichi → gstaichi}/lang/simt/block.py +7 -7
  57. {taichi → gstaichi}/lang/simt/grid.py +1 -1
  58. {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
  59. {taichi → gstaichi}/lang/simt/warp.py +1 -1
  60. {taichi → gstaichi}/lang/snode.py +46 -44
  61. {taichi → gstaichi}/lang/source_builder.py +13 -13
  62. {taichi → gstaichi}/lang/struct.py +33 -33
  63. {taichi → gstaichi}/lang/util.py +24 -24
  64. gstaichi/linalg/__init__.py +8 -0
  65. {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
  66. {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
  67. {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
  68. {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
  69. {taichi → gstaichi}/math/__init__.py +1 -1
  70. {taichi → gstaichi}/math/_complex.py +21 -20
  71. {taichi → gstaichi}/math/mathimpl.py +56 -56
  72. gstaichi/profiler/__init__.py +6 -0
  73. {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
  74. {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
  75. {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
  76. {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
  77. {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
  78. {taichi → gstaichi}/tools/__init__.py +4 -4
  79. {taichi → gstaichi}/tools/diagnose.py +10 -17
  80. gstaichi/types/__init__.py +19 -0
  81. {taichi → gstaichi}/types/annotations.py +1 -1
  82. {taichi → gstaichi}/types/compound_types.py +8 -8
  83. {taichi → gstaichi}/types/enums.py +1 -1
  84. {taichi → gstaichi}/types/ndarray_type.py +7 -7
  85. {taichi → gstaichi}/types/primitive_types.py +17 -14
  86. {taichi → gstaichi}/types/quant.py +9 -9
  87. {taichi → gstaichi}/types/texture_type.py +5 -5
  88. {taichi → gstaichi}/types/utils.py +1 -1
  89. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  90. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  91. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  92. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  93. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  94. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  95. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  96. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools.lib +0 -0
  97. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
  98. gstaichi-1.0.1.dist-info/RECORD +135 -0
  99. gstaichi-1.0.1.dist-info/top_level.txt +1 -0
  100. gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
  101. gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
  102. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
  103. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
  104. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
  105. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
  106. gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
  107. gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
  108. gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
  109. gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
  110. taichi/CHANGELOG.md +0 -20
  111. taichi/__init__.py +0 -44
  112. taichi/__main__.py +0 -5
  113. taichi/_lib/__init__.py +0 -3
  114. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  115. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
  116. taichi/_lib/c_api/include/taichi/taichi.h +0 -29
  117. taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
  118. taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
  119. taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
  120. taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
  121. taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
  122. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
  123. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  124. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  125. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  126. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
  127. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
  128. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
  129. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  130. taichi/_main.py +0 -552
  131. taichi/_snode/__init__.py +0 -5
  132. taichi/_ti_module/__init__.py +0 -3
  133. taichi/_ti_module/cppgen.py +0 -309
  134. taichi/_ti_module/module.py +0 -145
  135. taichi/_version.py +0 -1
  136. taichi/ad/__init__.py +0 -3
  137. taichi/aot/__init__.py +0 -12
  138. taichi/aot/_export.py +0 -28
  139. taichi/aot/conventions/__init__.py +0 -3
  140. taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
  141. taichi/aot/conventions/gfxruntime140/dr.py +0 -244
  142. taichi/aot/conventions/gfxruntime140/sr.py +0 -613
  143. taichi/aot/module.py +0 -253
  144. taichi/aot/utils.py +0 -151
  145. taichi/graph/__init__.py +0 -3
  146. taichi/graph/_graph.py +0 -292
  147. taichi/lang/__init__.py +0 -50
  148. taichi/lang/ast/__init__.py +0 -5
  149. taichi/lang/ast/transform.py +0 -9
  150. taichi/lang/exception.py +0 -80
  151. taichi/linalg/__init__.py +0 -8
  152. taichi/profiler/__init__.py +0 -6
  153. taichi/shaders/Circles_vk.frag +0 -29
  154. taichi/shaders/Circles_vk.vert +0 -45
  155. taichi/shaders/Circles_vk_frag.spv +0 -0
  156. taichi/shaders/Circles_vk_vert.spv +0 -0
  157. taichi/shaders/Lines_vk.frag +0 -9
  158. taichi/shaders/Lines_vk.vert +0 -11
  159. taichi/shaders/Lines_vk_frag.spv +0 -0
  160. taichi/shaders/Lines_vk_vert.spv +0 -0
  161. taichi/shaders/Mesh_vk.frag +0 -71
  162. taichi/shaders/Mesh_vk.vert +0 -68
  163. taichi/shaders/Mesh_vk_frag.spv +0 -0
  164. taichi/shaders/Mesh_vk_vert.spv +0 -0
  165. taichi/shaders/Particles_vk.frag +0 -95
  166. taichi/shaders/Particles_vk.vert +0 -73
  167. taichi/shaders/Particles_vk_frag.spv +0 -0
  168. taichi/shaders/Particles_vk_vert.spv +0 -0
  169. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  170. taichi/shaders/SceneLines_vk.frag +0 -9
  171. taichi/shaders/SceneLines_vk.vert +0 -12
  172. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  173. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  174. taichi/shaders/SetImage_vk.frag +0 -21
  175. taichi/shaders/SetImage_vk.vert +0 -15
  176. taichi/shaders/SetImage_vk_frag.spv +0 -0
  177. taichi/shaders/SetImage_vk_vert.spv +0 -0
  178. taichi/shaders/Triangles_vk.frag +0 -16
  179. taichi/shaders/Triangles_vk.vert +0 -29
  180. taichi/shaders/Triangles_vk_frag.spv +0 -0
  181. taichi/shaders/Triangles_vk_vert.spv +0 -0
  182. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  183. taichi/types/__init__.py +0 -19
  184. {taichi → gstaichi}/_lib/core/__init__.py +0 -0
  185. {taichi → gstaichi}/_lib/core/py.typed +0 -0
  186. {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
  187. {taichi → gstaichi}/algorithms/__init__.py +0 -0
  188. {taichi → gstaichi}/assets/.git +0 -0
  189. {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
  190. {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
  191. {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
  192. {taichi → gstaichi}/sparse/__init__.py +0 -0
  193. {taichi → gstaichi}/tools/np2ply.py +0 -0
  194. {taichi → gstaichi}/tools/vtk.py +0 -0
  195. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  196. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  197. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  198. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  199. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  200. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  201. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  202. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  203. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  204. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  205. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  206. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  207. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  208. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  209. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  210. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  211. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  212. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  213. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
  214. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
  215. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  216. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
  217. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  218. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
  219. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
taichi/aot/module.py DELETED
@@ -1,253 +0,0 @@
1
- # type: ignore
2
-
3
- import datetime
4
- import os
5
- import warnings
6
- from contextlib import contextmanager
7
- from glob import glob
8
- from pathlib import Path, PurePosixPath
9
- from shutil import rmtree
10
- from tempfile import mkdtemp
11
- from zipfile import ZipFile
12
-
13
- import taichi
14
- from taichi.aot.utils import produce_injected_args, produce_injected_args_from_template
15
- from taichi.lang import impl, kernel_impl
16
- from taichi.lang.field import ScalarField
17
- from taichi.lang.matrix import MatrixField
18
- from taichi.types.annotations import template
19
-
20
-
21
class KernelTemplate:
    """Instantiates a kernel that has template parameters into an AOT module.

    Instances are produced by ``Module.add_kernel_template``. Each call to
    :meth:`instantiate` compiles one specialization of the kernel and registers
    it with the owning module's AOT builder under a key derived from the
    template argument names and values.
    """

    def __init__(self, kernel_fn, aot_module):
        self._kernel_fn = kernel_fn
        self._aot_module = aot_module

    @staticmethod
    def keygen(v, key_p, fields):
        """Append the key fragment ``"=<value>,"`` for one template argument.

        Args:
            v: The template argument value. Python ``int``/``float``/``bool``
                values are rendered with ``str``. Any other value must be
                (by identity) one of the fields registered on the module.
            key_p (str): The key built so far.
            fields: Iterable of ``(name, field)`` pairs registered on the module.

        Returns:
            str: ``key_p`` extended with this argument's fragment.

        Raises:
            RuntimeError: If ``v`` is neither a scalar nor a registered field.
        """
        if isinstance(v, (int, float, bool)):
            return key_p + "=" + str(v) + ","
        for ky, val in fields:
            # Identity, not equality: the argument must be the very field
            # object that was registered with ``Module.add_field``.
            if val is v:
                return key_p + "=" + ky + ","
        raise RuntimeError(
            "Arg type must be of type int/float/boolean" f" or taichi field. Type {str(type(v))}" " is not supported"
        )

    def instantiate(self, **kwargs):
        """Compile one specialization of the kernel and add it to the module.

        Args:
            **kwargs: Values for the kernel's template parameters. When the
                keyword names are exactly the template parameter names, values
                are bound by name. Otherwise they are bound positionally in
                keyword order (legacy behavior).

        Raises:
            KeyError: If fewer values than template parameters are supplied.
            RuntimeError: If a value is neither a scalar nor a registered field
                (see :meth:`keygen`).
        """
        name = self._kernel_fn.__name__
        kernel = self._kernel_fn._primal
        assert isinstance(kernel, kernel_impl.Kernel)

        def _is_template(arg):
            return arg.annotation == template or isinstance(arg.annotation, template)

        template_names = [arg.name for arg in kernel.arguments if _is_template(arg)]
        if set(kwargs) == set(template_names):
            # Bind by parameter name so keyword order does not matter.
            ordered = [(param, kwargs[param]) for param in template_names]
        else:
            # Legacy behavior: bind in keyword order.
            ordered = list(kwargs.items())
        template_args = dict(enumerate(ordered))

        fields = self._aot_module._fields.items()
        injected_args = []
        key_p = ""
        anno_index = 0
        for arg in kernel.arguments:
            if _is_template(arg):
                (k, v) = template_args[anno_index]
                key_p = self.keygen(v, key_p + k, fields)
                injected_args.append(v)
                anno_index += 1
            else:
                # Non-template parameters only need a placeholder to compile.
                injected_args.append(0)
        kernel.ensure_compiled(*injected_args)
        self._aot_module._aot_builder.add_kernel_template(name, key_p, kernel.kernel_cpp)

        # kernel AOT
        self._aot_module._kernels.append(kernel)
65
-
66
-
67
class Module:
    """An AOT module to save and load Taichi kernels.

    This module serializes the Taichi kernels for a specific arch. The
    serialized module can later be loaded to run on that backend, without the
    Python environment.

    Example:
        Usage::

            m = ti.aot.Module(ti.metal)
            m.add_kernel(foo)
            m.add_kernel(bar)

            m.save('/path/to/module')

            # Now the module file '/path/to/module' contains the Metal kernels
            # for running ``foo`` and ``bar``.
    """

    def __init__(self, arch=None, caps=None):
        """Create a new AOT module instance.

        Args:
            arch: Target backend architecture. Defaults to the one initialized in
                :func:`~taichi.lang.init` if not specified. A different arch is
                not supported yet: a warning is emitted and the current arch is
                used instead.
            caps (List[str]): Enabled device capabilities.
        """
        if caps is None:
            caps = []
        curr_arch = impl.current_cfg().arch
        if arch is None:
            arch = curr_arch
        elif arch != curr_arch:
            # TODO: we'll support this eventually but not yet...
            warnings.warn(
                f"AOT compilation to a different arch than the current one is not yet supported, switching to {curr_arch}"
            )
            arch = curr_arch

        self._arch = arch
        self._kernels = []
        self._fields = {}
        rtm = impl.get_runtime()
        rtm._finalize_root_fb_for_aot()
        self._aot_builder = rtm.prog.make_aot_module_builder(arch, caps)
        # Entries written to ``__content__`` by :meth:`save`.
        self._content = []

    def add_field(self, name, field):
        """Add a taichi field to the AOT module.

        The field is also remembered under ``name`` so that it can be used as a
        template argument in :meth:`add_kernel_template`.

        Args:
            name: name of taichi field
            field: taichi field (a scalar field or a matrix field)

        Example::

            >>> a = ti.field(ti.f32, shape=(4,4))
            >>> b = ti.Vector.field(3, ti.f32, shape=(4,4))
            >>>
            >>> m.add_field("a", a)
            >>> m.add_field("b", b)
            >>>
            >>> # Must add in sequence
        """
        is_scalar = True
        self._fields[name] = field
        column_num = 1
        row_num = 1
        if isinstance(field, MatrixField):
            is_scalar = False
            # NOTE(review): row_num takes ``field.m`` and column_num takes
            # ``field.n``; confirm this matches the builder's expected order.
            row_num = field.m
            column_num = field.n
        else:
            assert isinstance(field, ScalarField)
        self._aot_builder.add_field(
            name,
            field.snode.ptr,
            is_scalar,
            field.dtype,
            field.snode.shape,
            row_num,
            column_num,
        )

    def add_kernel(self, kernel_fn, template_args=None, name=None):
        """Add a taichi kernel to the AOT module.

        Args:
            kernel_fn (Function): the function decorated by taichi `kernel`.
            template_args (Dict[str, Any]): a dict where key is the template
                parameter name, and value is the instantiating arg. Note that this
                works for both :class:`~taichi.types.template` and for
                `:class:`~taichi.types.ndarray`.
            name (str): Name to identify this kernel in the module. If not
                provided, uses the built-in ``__name__`` attribute of `kernel_fn`.
        """
        kernel_name = name or kernel_fn.__name__
        kernel = kernel_fn._primal
        assert isinstance(kernel, kernel_impl.Kernel)
        if template_args is not None:
            injected_args = produce_injected_args_from_template(kernel, template_args)
        else:
            injected_args = produce_injected_args(kernel)
        kernel.ensure_compiled(*injected_args)
        self._aot_builder.add(kernel_name, kernel.kernel_cpp)

        # kernel AOT
        self._kernels.append(kernel)

        self._content += ["kernel:" + kernel_name]

    def add_graph(self, name, graph):
        """Add a graph to the AOT module under ``name``.

        Args:
            name (str): Name to identify this graph in the module.
            graph: An object exposing ``_compiled_graph``, which is handed to
                the AOT builder.
        """
        self._aot_builder.add_graph(name, graph._compiled_graph)
        self._content += ["cgraph:" + name]

    @contextmanager
    def add_kernel_template(self, kernel_fn):
        """Add a taichi kernel (with template parameters) to the AOT module.

        Template values passed to ``instantiate`` must be ``int``/``float``/
        ``bool`` or a field previously registered with :meth:`add_field`.

        Args:
            kernel_fn (Function): the function decorated by taichi `kernel`.

        Example::

            >>> @ti.kernel
            >>> def bar_tmpl(a: ti.template()):
            >>>     x = a
            >>>     # or y = a
            >>>     # do something with `x` or `y`
            >>>
            >>> m = ti.aot.Module(arch)
            >>> with m.add_kernel_template(bar_tmpl) as kt:
            >>>     kt.instantiate(a=x)
            >>>     kt.instantiate(a=y)
            >>>
            >>> @ti.kernel
            >>> def bar_tmpl_multiple_args(a: ti.template(), b: ti.template()):
            >>>     x = a
            >>>     y = b
            >>>     # do something with `x` and `y`
            >>>
            >>> with m.add_kernel_template(bar_tmpl_multiple_args) as kt:
            >>>     kt.instantiate(a=x, b=y)

        TODO:
            * Support external array
        """
        kt = KernelTemplate(kernel_fn, self)
        yield kt

    def save(self, filepath):
        """Write the module's artifacts into a folder.

        Besides the builder's own output, two text files are written:
        ``__content__`` (one ``kernel:<name>`` / ``cgraph:<name>`` entry per
        line) and ``__version__`` (the dotted package version).

        Args:
            filepath (str): path to a folder to store aot files.
        """
        filepath = str(PurePosixPath(Path(filepath)))
        self._aot_builder.dump(filepath, "")
        with open(f"{filepath}/__content__", "w", encoding="utf-8") as f:
            f.write("\n".join(self._content))
        version = taichi.__version__
        # The version is normally a tuple of ints; accept a preformatted string too.
        version_str = version if isinstance(version, str) else ".".join(str(x) for x in version)
        with open(f"{filepath}/__version__", "w", encoding="utf-8") as f:
            f.write(version_str)

    def archive(self, filepath: str):
        """Save the module and pack its artifacts into a single zip archive.

        Args:
            filepath (str): path to the stored archive of aot artifacts, MUST
                end with `.tcm`.
        """
        assert filepath.endswith(".tcm"), "AOT module artifact archive must ends with .tcm"
        tcm_path = Path(filepath).absolute()
        assert tcm_path.parent.exists(), "Output directory doesn't exist"

        temp_dir = mkdtemp(prefix="tcm_")
        try:
            # Save first as usual.
            self.save(temp_dir)

            # A fixed mtime plus a sorted entry order keep the archive bytes
            # independent of when and where it was produced.
            fixed_time = datetime.datetime(2000, 12, 1).timestamp()

            # Package all artifacts into a zip archive and attach contend data.
            with ZipFile(tcm_path, "w") as z:
                for path in sorted(glob(f"{temp_dir}/*")):
                    os.utime(path, (fixed_time, fixed_time))
                    z.write(path, Path(path).relative_to(temp_dir))
        finally:
            # Remove cached files even if saving or zipping failed.
            rmtree(temp_dir)
taichi/aot/utils.py DELETED
@@ -1,151 +0,0 @@
1
- # type: ignore
2
-
3
- from typing import Any
4
-
5
- from taichi.lang._ndarray import ScalarNdarray
6
- from taichi.lang._texture import Texture
7
- from taichi.lang.exception import TaichiCompilationError
8
- from taichi.lang.matrix import (
9
- Matrix,
10
- MatrixNdarray,
11
- MatrixType,
12
- VectorNdarray,
13
- VectorType,
14
- )
15
- from taichi.lang.util import cook_dtype
16
- from taichi.types.annotations import template
17
- from taichi.types.enums import Format
18
- from taichi.types.ndarray_type import NdarrayType
19
- from taichi.types.texture_type import RWTextureType, TextureType
20
-
21
# Annotation types for which ``produce_injected_args_from_template`` takes
# the argument value from the caller-supplied ``template_args`` mapping.
template_types = (
    NdarrayType,
    TextureType,
    template,
)
22
-
23
-
24
def check_type_match(lhs, rhs):
    """Return whether two argument types are compatible.

    Two matrix types match when their ``n`` and ``m`` agree and their dtypes
    agree (a ``None`` dtype on either side acts as a wildcard). A matrix type
    never matches a non-matrix type. Any other pair is compared after
    normalizing both sides with ``cook_dtype``.
    """
    lhs_is_matrix = isinstance(lhs, MatrixType)
    rhs_is_matrix = isinstance(rhs, MatrixType)

    if lhs_is_matrix and rhs_is_matrix:
        same_shape = lhs.n == rhs.n and lhs.m == rhs.m
        return same_shape and (lhs.dtype == rhs.dtype or lhs.dtype is None or rhs.dtype is None)

    if lhs_is_matrix or rhs_is_matrix:
        # Exactly one side is a matrix type.
        return False

    return cook_dtype(lhs) == cook_dtype(rhs)
31
-
32
-
33
def produce_injected_args_from_template(kernel, template_args):
    """Build placeholder call arguments for compiling ``kernel`` ahead of time.

    Args:
        kernel: The kernel whose ``arguments`` are inspected.
        template_args (dict): Maps parameter name to value for every parameter
            whose annotation is one of ``template_types``.

    Returns:
        list: One entry per kernel argument. Template-typed parameters get the
        value from ``template_args``, ``RWTextureType`` parameters get a small
        ``Texture``, and everything else gets ``0``.
    """
    expected = sum(1 for arg in kernel.arguments if isinstance(arg.annotation, template_types))
    assert expected == len(
        template_args
    ), f"Need {expected} inputs to instantiate the template parameters, got {len(template_args)}"

    placeholders = []
    for arg in kernel.arguments:
        anno = arg.annotation
        if isinstance(anno, template_types):
            value = template_args[arg.name]
        elif isinstance(anno, RWTextureType):
            dummy_shape = (2,) * anno.num_dimensions
            value = Texture(anno.fmt, dummy_shape)
        else:
            value = 0
        placeholders.append(value)
    return placeholders
50
-
51
-
52
def _symbolic_ndarray_dtype(sym_arg):
    """Rebuild a Python-side element dtype for one symbolic ndarray argument."""
    # TODO: reconstruct dtype to be TensorType from taichi_core instead of the Python ones
    elem_shape = sym_arg.element_shape
    rank = len(elem_shape)
    if rank == 0 or elem_shape == (1,):
        return sym_arg.dtype()
    if rank == 1:
        return VectorType(elem_shape[0], sym_arg.dtype())
    if rank == 2:
        return MatrixType(elem_shape[0], elem_shape[1], 2, sym_arg.dtype())
    raise TaichiCompilationError("Not supported")


def _dummy_ndarray(dtype, ndim):
    """Create a placeholder ndarray of extent 2 in each of ``ndim`` dimensions."""
    dummy_shape = (2,) * ndim
    if isinstance(dtype, VectorType):
        return VectorNdarray(dtype.n, dtype=dtype.dtype, shape=dummy_shape)
    if isinstance(dtype, MatrixType):
        return MatrixNdarray(dtype.n, dtype.m, dtype=dtype.dtype, shape=dummy_shape)
    return ScalarNdarray(dtype, dummy_shape)


def produce_injected_args(kernel, symbolic_args=None):
    """Build placeholder call arguments for compiling ``kernel`` ahead of time.

    Args:
        kernel: The kernel whose ``arguments`` are inspected.
        symbolic_args: Optional sequence aligned index-by-index with
            ``kernel.arguments``. When given, its entries supply the
            dtype/shape information that is checked against each annotation.

    Returns:
        list: One placeholder per kernel argument.

    Raises:
        TaichiCompilationError: On an unsupported element rank, an ndim
            mismatch, or a dtype mismatch.
        RuntimeError: On a matrix dimension mismatch.
    """
    placeholders = []
    for idx, arg in enumerate(kernel.arguments):
        anno = arg.annotation

        if isinstance(anno, NdarrayType):
            if symbolic_args is None:
                ndim = anno.ndim
                dtype = anno.dtype
            else:
                sym = symbolic_args[idx]
                dtype = _symbolic_ndarray_dtype(sym)
                ndim = sym.field_dim

            if anno.ndim is not None and ndim != anno.ndim:
                raise TaichiCompilationError(
                    f"{ndim} from Arg {arg.name} doesn't match kernel's annotated ndim={anno.ndim}"
                )
            if anno.dtype is not None and not check_type_match(dtype, anno.dtype):
                raise TaichiCompilationError(
                    f" Arg {arg.name}'s dtype {dtype.to_string()} doesn't match kernel's annotated dtype={anno.dtype.to_string()}"
                )
            placeholders.append(_dummy_ndarray(dtype, ndim))

        elif isinstance(anno, RWTextureType):
            dummy_shape = (2,) * anno.num_dimensions
            placeholders.append(Texture(anno.fmt, dummy_shape))

        elif isinstance(anno, TextureType):
            dummy_shape = (2,) * anno.num_dimensions
            placeholders.append(Texture(Format.rgba8, dummy_shape))

        elif isinstance(anno, MatrixType):
            if symbolic_args is not None:
                sym_n, sym_m = symbolic_args[idx].element_shape[0], symbolic_args[idx].element_shape[1]
                if sym_m != anno.m or sym_n != anno.n:
                    raise RuntimeError(
                        f"Matrix dimension mismatch, expected ({anno.n}, {anno.m}) "
                        f"but dispatched shape ({sym_n}, {sym_m})."
                    )
            placeholders.append(Matrix([0] * anno.n * anno.m, dt=anno.dtype))

        else:
            dtype = anno if symbolic_args is None else symbolic_args[idx].dtype()
            if not check_type_match(dtype, anno):
                raise TaichiCompilationError(
                    f" Arg {arg.name}'s dtype {dtype.to_string()} doesn't match kernel's annotated dtype={anno.to_string()}"
                )
            # For primitive types, we can just inject a dummy value.
            placeholders.append(0)

    return placeholders
125
-
126
-
127
def json_data_model(f):
    """Mark ``f`` as a JSON data model and return it unchanged.

    A JSON data model MUST NOT have any member functions and it MUST be
    constructible from a JSON object. The decorator only sets the
    ``_is_json_data_model`` attribute; it enforces none of these rules.
    """
    setattr(f, "_is_json_data_model", True)
    return f
136
-
137
-
138
def is_json_data_model(cls) -> bool:
    """Return whether ``cls`` (a class or an instance) carries the JSON-model marker.

    Only the presence of ``_is_json_data_model`` is checked, not its value.
    """
    absent = object()
    return getattr(cls, "_is_json_data_model", absent) is not absent
140
-
141
-
142
def dump_json_data_model(x: object) -> Any:
    """Recursively convert ``x`` into plain JSON-compatible Python values.

    Scalars and ``None`` pass through. Lists and tuples become lists. Dicts are
    converted value by value. Marked JSON data models become dicts of their
    instance attributes, minus the marker attribute. Anything else is returned
    as is.
    """
    if isinstance(x, (int, float, str, bool, type(None))):
        return x
    if isinstance(x, (list, tuple)):
        return list(map(dump_json_data_model, x))
    if isinstance(x, dict):
        return {key: dump_json_data_model(value) for key, value in x.items()}
    if not is_json_data_model(x):
        return x
    return {
        key: dump_json_data_model(value)
        for key, value in x.__dict__.items()
        if key != "_is_json_data_model"
    }
taichi/graph/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- # type: ignore
2
-
3
- from ._graph import *
taichi/graph/_graph.py DELETED
@@ -1,292 +0,0 @@
1
- # type: ignore
2
-
3
- import warnings
4
- from typing import Any, Dict, List
5
-
6
- from taichi._lib import core as _ti_core
7
- from taichi.aot.utils import produce_injected_args
8
- from taichi.lang import impl, kernel_impl
9
- from taichi.lang._ndarray import Ndarray
10
- from taichi.lang._texture import Texture
11
- from taichi.lang.exception import TaichiRuntimeError
12
- from taichi.lang.matrix import Matrix, MatrixType
13
- from taichi.types import enums
14
- from taichi.types.texture_type import FORMAT2TY_CH, TY_CH2FORMAT
15
-
16
# Re-export of the native argument-kind enum. Its members (SCALAR, NDARRAY,
# MATRIX, TEXTURE, RWTEXTURE) are used below as ``tag`` values, and the name is
# part of ``__all__``.
ArgKind = _ti_core.ArgKind
17
-
18
-
19
def gen_cpp_kernel(kernel_fn, args):
    """Compile *kernel_fn* for graph dispatch and return the compiled kernel.

    Placeholder arguments are derived from the symbolic graph *args* via
    ``produce_injected_args``. The kernel is compiled with them, and the matching
    entry of its ``compiled_kernels`` mapping is returned.
    """
    primal = kernel_fn._primal
    assert isinstance(primal, kernel_impl.Kernel)
    placeholder_args = produce_injected_args(primal, symbolic_args=args)
    compiled_key = primal.ensure_compiled(*placeholder_args)
    return primal.compiled_kernels[compiled_key]
25
-
26
-
27
def flatten_args(args):
    """Flatten matrix-style arguments into a flat list for native dispatch.

    Each ``list`` argument is treated as a sequence of rows, and every row's
    items are spliced into the result. All other arguments (tuples included)
    are kept as single entries.
    """
    # FIXME remove this when native Matrix type is ready
    flat = []
    for item in args:
        if not isinstance(item, list):
            flat.append(item)
            continue
        for row in item:
            flat.extend(row)
    return flat
38
-
39
-
40
class Sequential:
    """Python wrapper around a native sequential node of a graph.

    The wrapped native object is stored in ``seq_``.
    """

    def __init__(self, seq):
        self.seq_ = seq

    def dispatch(self, kernel_fn, *args):
        """Compile *kernel_fn* for *args* and record its dispatch on this sequence."""
        compiled = gen_cpp_kernel(kernel_fn, args)
        self.seq_.dispatch(compiled, flatten_args(args))
48
-
49
-
50
class GraphBuilder:
    """Builds a compute graph by recording kernel dispatches.

    Wraps a native ``_ti_core.GraphBuilder``. ``compile`` turns what was
    recorded into a runnable ``Graph``.
    """

    def __init__(self):
        self._graph_builder = _ti_core.GraphBuilder()

    def dispatch(self, kernel_fn, *args):
        """Compile *kernel_fn* for *args* and record the dispatch in the builder."""
        compiled = gen_cpp_kernel(kernel_fn, args)
        flat_args = flatten_args(args)
        self._graph_builder.dispatch(compiled, flat_args)

    def create_sequential(self):
        """Return a new ``Sequential`` wrapping a native sequential node."""
        native_seq = self._graph_builder.create_sequential()
        return Sequential(native_seq)

    def append(self, node):
        """Append a ``Sequential`` node to the builder's own sequence (``seq()``)."""
        # TODO: support appending dispatch node as well.
        assert isinstance(node, Sequential)
        builder_seq = self._graph_builder.seq()
        builder_seq.append(node.seq_)

    def compile(self):
        """Compile the recorded graph and wrap the result in a ``Graph``."""
        compiled_graph = self._graph_builder.compile()
        return Graph(compiled_graph)
69
-
70
-
71
class Graph:
    """A compiled compute graph that can be executed with runtime arguments."""

    def __init__(self, compiled_graph) -> None:
        self._compiled_graph = compiled_graph

    def run(self, args):
        """Run the compiled graph.

        Args:
            args: mapping from graph argument name to value. Supported values:
                ``Ndarray`` (passed as its ``arr``), ``Texture`` (its ``tex``),
                ``Matrix`` (its ``entries``), and Python ``int``/``float``.

        Raises:
            TaichiRuntimeError: if a value has an unsupported type.
        """
        # Taichi Matrix values are flattened into their entries.
        # TODO diminish the flatten behavior when Matrix becomes a Taichi native type.
        flattened = {}
        for name, value in args.items():
            if isinstance(value, Ndarray):
                flattened[name] = value.arr
            elif isinstance(value, Texture):
                flattened[name] = value.tex
            elif isinstance(value, Matrix):
                flattened[name] = value.entries
            elif isinstance(value, (int, float)):
                flattened[name] = value
            else:
                # The message lists every accepted type (Texture was previously
                # missing) and names the offending argument.
                raise TaichiRuntimeError(
                    "Only python int, float, ti.Matrix, ti.Ndarray and ti.Texture are supported "
                    f"as runtime arguments but got {type(value)} for argument '{name}'"
                )
        self._compiled_graph.jit_run(impl.get_runtime().prog.config(), flattened)
94
-
95
-
96
def _deprecate_arg_args(kwargs: Dict[str, Any]):
    """Rewrite or reject deprecated ``Arg`` keyword arguments, in place.

    - ``field_dim`` is renamed to ``ndim`` with a DeprecationWarning. Giving both
      is an error.
    - SCALAR: ``element_shape`` is rejected.
    - NDARRAY: an explicit ``element_shape`` is rejected. Otherwise, when ``dtype``
      is present, ``element_shape`` is derived from it: a ``MatrixType`` is split
      into its scalar ``dtype`` and ``get_shape()``, and any other dtype gets ``()``.
    - TEXTURE / RWTEXTURE: ``dtype`` is dropped with a DeprecationWarning.
      ``shape``, ``channel_format`` and ``num_channels`` are rejected.

    Args:
        kwargs: keyword arguments collected by ``Arg``; must contain ``tag``.

    Raises:
        TaichiRuntimeError: for any removed or conflicting argument.
    """
    if "field_dim" in kwargs:
        warnings.warn(
            "The field_dim argument for ndarray will be deprecated in v1.6.0, use ndim instead.",
            DeprecationWarning,
        )
        if "ndim" in kwargs:
            raise TaichiRuntimeError(
                "field_dim is deprecated, please do not specify field_dim and ndim at the same time."
            )
        kwargs["ndim"] = kwargs.pop("field_dim")
    tag = kwargs["tag"]

    if tag == ArgKind.SCALAR:
        if "element_shape" in kwargs:
            raise TaichiRuntimeError(
                "The element_shape argument for scalar is deprecated in v1.6.0, and is removed in v1.7.0. "
                "Please remove them."
            )

    if tag == ArgKind.NDARRAY:
        if "element_shape" in kwargs:
            raise TaichiRuntimeError(
                "The element_shape argument for ndarray is deprecated in v1.6.0, and it is removed in v1.7.0. "
                "Please use vector or matrix data type instead."
            )
        if "dtype" in kwargs:
            dtype = kwargs["dtype"]
            if isinstance(dtype, MatrixType):
                kwargs["dtype"] = dtype.dtype
                kwargs["element_shape"] = dtype.get_shape()
            else:
                kwargs["element_shape"] = ()

    if tag == ArgKind.RWTEXTURE or tag == ArgKind.TEXTURE:
        if "dtype" in kwargs:
            warnings.warn(
                "The dtype argument for texture will be deprecated in v1.6.0, use format instead.",
                DeprecationWarning,
            )
            del kwargs["dtype"]

        if "shape" in kwargs:
            raise TaichiRuntimeError(
                "The shape argument for texture is deprecated in v1.6.0, and it is removed in v1.7.0. "
                "Please use ndim instead. (Note that you no longer need the exact texture size.)"
            )

        if "channel_format" in kwargs or "num_channels" in kwargs:
            if "fmt" in kwargs:
                raise TaichiRuntimeError(
                    "channel_format and num_channels are deprecated, please do not specify channel_format/num_channels and fmt at the same time."
                )
            # No (channel_format, num_channels) -> fmt translation is attempted here.
            # These arguments are always rejected, and a TY_CH2FORMAT lookup would
            # raise a bare KeyError (hiding the messages below) when only one of the
            # two is given or the pair is not a known format.
            if tag == ArgKind.RWTEXTURE:
                raise TaichiRuntimeError(
                    "The channel_format and num_channels arguments for texture are deprecated in v1.6.0, "
                    "and they are removed in v1.7.0. Please use fmt instead."
                )
            raise TaichiRuntimeError(
                "The channel_format and num_channels arguments are no longer required for non-RW textures "
                "since v1.6.0, and they are removed in v1.7.0. Please remove them."
            )
163
-
164
-
165
- def _check_args(kwargs: Dict[str, Any], allowed_kwargs: List[str]):
166
- for k, v in kwargs.items():
167
- if k not in allowed_kwargs:
168
- raise TaichiRuntimeError(
169
- f"Invalid argument: {k}, you can only create a graph argument with: {allowed_kwargs}"
170
- )
171
- if k == "tag":
172
- if not isinstance(v, ArgKind):
173
- raise TaichiRuntimeError(f"tag must be a ArgKind variant, but found {type(v)}.")
174
- if k == "name":
175
- if not isinstance(v, str):
176
- raise TaichiRuntimeError(f"name must be a string, but found {type(v)}.")
177
-
178
-
179
def _make_arg_scalar(kwargs: Dict[str, Any]):
    """Build a native ``Arg`` describing a scalar graph argument.

    Accepts ``tag``, ``name`` and ``dtype``. ``dtype`` must not be a ``MatrixType``.
    """
    _check_args(kwargs, ["tag", "name", "dtype"])
    arg_name, arg_dtype = kwargs["name"], kwargs["dtype"]
    if isinstance(arg_dtype, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.SCALAR must specify a scalar type, but found {type(arg_dtype)}.")
    return _ti_core.Arg(ArgKind.SCALAR, arg_name, arg_dtype, 0, [])
191
-
192
-
193
def _make_arg_ndarray(kwargs: Dict[str, Any]):
    """Build a native ``Arg`` describing an ndarray graph argument.

    Accepts ``tag``, ``name``, ``dtype``, ``ndim`` and ``element_shape``.
    ``dtype`` must be a scalar type here. When called through ``_make_arg``,
    a ``MatrixType`` dtype has already been split into ``dtype`` plus
    ``element_shape`` by ``_deprecate_arg_args``.
    """
    _check_args(kwargs, ["tag", "name", "dtype", "ndim", "element_shape"])
    arg_name = kwargs["name"]
    arg_ndim = kwargs["ndim"]
    arg_dtype = kwargs["dtype"]
    elem_shape = kwargs["element_shape"]
    if isinstance(arg_dtype, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.NDARRAY must specify a scalar type, but found {arg_dtype}.")
    return _ti_core.Arg(ArgKind.NDARRAY, arg_name, arg_dtype, arg_ndim, elem_shape)
209
-
210
-
211
def _make_arg_matrix(kwargs: Dict[str, Any]):
    """Build a native ``Arg`` describing a matrix graph argument.

    Accepts ``tag``, ``name`` and ``dtype``. ``dtype`` must be a ``MatrixType``;
    its scalar dtype and ``[n, m]`` shape are passed to the native ``Arg``.
    """
    _check_args(kwargs, ["tag", "name", "dtype"])
    arg_name = kwargs["name"]
    mat_type = kwargs["dtype"]
    if not isinstance(mat_type, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.MATRIX must specify matrix type, but got {mat_type}.")
    shape = [mat_type.n, mat_type.m]
    return _ti_core.Arg(ArgKind.MATRIX, f"{arg_name}", mat_type.dtype, 0, shape)
223
-
224
-
225
def _make_arg_texture(kwargs: Dict[str, Any]):
    """Build a native ``Arg`` describing a (read-only) texture graph argument.

    Accepts ``tag``, ``name`` and ``ndim``.
    """
    _check_args(kwargs, ["tag", "name", "ndim"])
    arg_name, arg_ndim = kwargs["name"], kwargs["ndim"]
    # Fixed f32 element type and the constant 4 in the slot that carries
    # num_channels for RWTEXTURE. The extent of 2 per dimension is presumably a
    # placeholder, since the exact texture size is not required.
    dummy_shape = [2] * arg_ndim
    return _ti_core.Arg(ArgKind.TEXTURE, arg_name, impl.f32, 4, dummy_shape)
235
-
236
-
237
def _make_arg_rwtexture(kwargs: Dict[str, Any]):
    """Build a native ``Arg`` describing a read-write texture graph argument.

    Accepts ``tag``, ``name``, ``ndim`` and ``fmt``. ``fmt`` must not be
    ``Format.unknown``; it is translated to a (channel_format, num_channels)
    pair via ``FORMAT2TY_CH``.
    """
    _check_args(kwargs, ["tag", "name", "ndim", "fmt"])
    arg_name = kwargs["name"]
    arg_ndim = kwargs["ndim"]
    color_fmt = kwargs["fmt"]
    if color_fmt == enums.Format.unknown:
        raise TaichiRuntimeError(f"Tag ArgKind.RWTEXTURE must specify a valid color format, but found {color_fmt}.")
    channel_format, num_channels = FORMAT2TY_CH[color_fmt]
    return _ti_core.Arg(ArgKind.RWTEXTURE, arg_name, channel_format, num_channels, [2] * arg_ndim)
252
-
253
-
254
def _make_arg(kwargs: Dict[str, Any]):
    """Normalize deprecated arguments and build the native ``Arg`` for ``kwargs["tag"]``.

    The builder is chosen by tag: SCALAR, NDARRAY, MATRIX, TEXTURE or RWTEXTURE.
    """
    assert "tag" in kwargs
    _deprecate_arg_args(kwargs)
    builders = {
        ArgKind.SCALAR: _make_arg_scalar,
        ArgKind.NDARRAY: _make_arg_ndarray,
        ArgKind.MATRIX: _make_arg_matrix,
        ArgKind.TEXTURE: _make_arg_texture,
        ArgKind.RWTEXTURE: _make_arg_rwtexture,
    }
    build = builders[kwargs["tag"]]
    return build(kwargs)
266
-
267
-
268
- def _kwarg_rewriter(args, kwargs):
269
- for i, arg in enumerate(args):
270
- rewrite_map = {
271
- 0: "tag",
272
- 1: "name",
273
- 2: "dtype",
274
- 3: "ndim",
275
- 4: "field_dim",
276
- 5: "element_shape",
277
- 6: "channel_format",
278
- 7: "shape",
279
- 8: "num_channels",
280
- }
281
- if i in rewrite_map:
282
- kwargs[rewrite_map[i]] = arg
283
- else:
284
- raise TaichiRuntimeError(f"Unexpected {i}th positional argument")
285
-
286
-
287
def Arg(*args, **kwargs):
    """Create a native graph argument description.

    Positional arguments map, in order, to ``tag``, ``name``, ``dtype``, ``ndim``,
    ``field_dim``, ``element_shape``, ``channel_format``, ``shape`` and
    ``num_channels``. The combined keyword arguments are then normalized and
    passed to the builder selected by ``tag``.
    """
    _kwarg_rewriter(args, kwargs)
    graph_arg = _make_arg(kwargs)
    return graph_arg
290
-
291
-
292
# Public names exported by ``from ._graph import *`` (used by the package ``__init__``).
__all__ = ["GraphBuilder", "Graph", "Arg", "ArgKind"]