gstaichi 0.1.18.dev1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  2. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  3. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  4. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  5. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  6. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  7. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  8. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  9. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  10. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  11. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  12. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  13. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  14. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  15. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  16. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  17. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  18. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  19. gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  20. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  21. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  25. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  26. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
  39. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  40. gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
  41. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  42. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  43. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  44. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  45. taichi/CHANGELOG.md +15 -0
  46. taichi/__init__.py +44 -0
  47. taichi/__main__.py +5 -0
  48. taichi/_funcs.py +706 -0
  49. taichi/_kernels.py +420 -0
  50. taichi/_lib/__init__.py +3 -0
  51. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  52. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  53. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  54. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  55. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  56. taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
  57. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  58. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  59. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  60. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  61. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  62. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  63. taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
  64. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  65. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  66. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  67. taichi/_lib/core/__init__.py +0 -0
  68. taichi/_lib/core/py.typed +0 -0
  69. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
  70. taichi/_lib/core/taichi_python.pyi +3077 -0
  71. taichi/_lib/runtime/runtime_cuda.bc +0 -0
  72. taichi/_lib/runtime/runtime_x64.bc +0 -0
  73. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  74. taichi/_lib/utils.py +249 -0
  75. taichi/_logging.py +131 -0
  76. taichi/_main.py +552 -0
  77. taichi/_snode/__init__.py +5 -0
  78. taichi/_snode/fields_builder.py +189 -0
  79. taichi/_snode/snode_tree.py +34 -0
  80. taichi/_ti_module/__init__.py +3 -0
  81. taichi/_ti_module/cppgen.py +309 -0
  82. taichi/_ti_module/module.py +145 -0
  83. taichi/_version.py +1 -0
  84. taichi/_version_check.py +100 -0
  85. taichi/ad/__init__.py +3 -0
  86. taichi/ad/_ad.py +530 -0
  87. taichi/algorithms/__init__.py +3 -0
  88. taichi/algorithms/_algorithms.py +117 -0
  89. taichi/aot/__init__.py +12 -0
  90. taichi/aot/_export.py +28 -0
  91. taichi/aot/conventions/__init__.py +3 -0
  92. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  93. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  94. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  95. taichi/aot/module.py +253 -0
  96. taichi/aot/utils.py +151 -0
  97. taichi/assets/.git +1 -0
  98. taichi/assets/Go-Regular.ttf +0 -0
  99. taichi/assets/static/imgs/ti_gallery.png +0 -0
  100. taichi/examples/minimal.py +28 -0
  101. taichi/experimental.py +16 -0
  102. taichi/graph/__init__.py +3 -0
  103. taichi/graph/_graph.py +292 -0
  104. taichi/lang/__init__.py +50 -0
  105. taichi/lang/_ndarray.py +348 -0
  106. taichi/lang/_ndrange.py +152 -0
  107. taichi/lang/_texture.py +172 -0
  108. taichi/lang/_wrap_inspect.py +189 -0
  109. taichi/lang/any_array.py +99 -0
  110. taichi/lang/argpack.py +411 -0
  111. taichi/lang/ast/__init__.py +5 -0
  112. taichi/lang/ast/ast_transformer.py +1806 -0
  113. taichi/lang/ast/ast_transformer_utils.py +328 -0
  114. taichi/lang/ast/checkers.py +106 -0
  115. taichi/lang/ast/symbol_resolver.py +57 -0
  116. taichi/lang/ast/transform.py +9 -0
  117. taichi/lang/common_ops.py +310 -0
  118. taichi/lang/exception.py +80 -0
  119. taichi/lang/expr.py +180 -0
  120. taichi/lang/field.py +464 -0
  121. taichi/lang/impl.py +1246 -0
  122. taichi/lang/kernel_arguments.py +157 -0
  123. taichi/lang/kernel_impl.py +1415 -0
  124. taichi/lang/matrix.py +1877 -0
  125. taichi/lang/matrix_ops.py +341 -0
  126. taichi/lang/matrix_ops_utils.py +190 -0
  127. taichi/lang/mesh.py +687 -0
  128. taichi/lang/misc.py +807 -0
  129. taichi/lang/ops.py +1489 -0
  130. taichi/lang/runtime_ops.py +13 -0
  131. taichi/lang/shell.py +35 -0
  132. taichi/lang/simt/__init__.py +5 -0
  133. taichi/lang/simt/block.py +94 -0
  134. taichi/lang/simt/grid.py +7 -0
  135. taichi/lang/simt/subgroup.py +191 -0
  136. taichi/lang/simt/warp.py +96 -0
  137. taichi/lang/snode.py +487 -0
  138. taichi/lang/source_builder.py +150 -0
  139. taichi/lang/struct.py +855 -0
  140. taichi/lang/util.py +381 -0
  141. taichi/linalg/__init__.py +8 -0
  142. taichi/linalg/matrixfree_cg.py +310 -0
  143. taichi/linalg/sparse_cg.py +59 -0
  144. taichi/linalg/sparse_matrix.py +303 -0
  145. taichi/linalg/sparse_solver.py +123 -0
  146. taichi/math/__init__.py +11 -0
  147. taichi/math/_complex.py +204 -0
  148. taichi/math/mathimpl.py +886 -0
  149. taichi/profiler/__init__.py +6 -0
  150. taichi/profiler/kernel_metrics.py +260 -0
  151. taichi/profiler/kernel_profiler.py +592 -0
  152. taichi/profiler/memory_profiler.py +15 -0
  153. taichi/profiler/scoped_profiler.py +36 -0
  154. taichi/shaders/Circles_vk.frag +29 -0
  155. taichi/shaders/Circles_vk.vert +45 -0
  156. taichi/shaders/Circles_vk_frag.spv +0 -0
  157. taichi/shaders/Circles_vk_vert.spv +0 -0
  158. taichi/shaders/Lines_vk.frag +9 -0
  159. taichi/shaders/Lines_vk.vert +11 -0
  160. taichi/shaders/Lines_vk_frag.spv +0 -0
  161. taichi/shaders/Lines_vk_vert.spv +0 -0
  162. taichi/shaders/Mesh_vk.frag +71 -0
  163. taichi/shaders/Mesh_vk.vert +68 -0
  164. taichi/shaders/Mesh_vk_frag.spv +0 -0
  165. taichi/shaders/Mesh_vk_vert.spv +0 -0
  166. taichi/shaders/Particles_vk.frag +95 -0
  167. taichi/shaders/Particles_vk.vert +73 -0
  168. taichi/shaders/Particles_vk_frag.spv +0 -0
  169. taichi/shaders/Particles_vk_vert.spv +0 -0
  170. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  171. taichi/shaders/SceneLines_vk.frag +9 -0
  172. taichi/shaders/SceneLines_vk.vert +12 -0
  173. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  174. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  175. taichi/shaders/SetImage_vk.frag +21 -0
  176. taichi/shaders/SetImage_vk.vert +15 -0
  177. taichi/shaders/SetImage_vk_frag.spv +0 -0
  178. taichi/shaders/SetImage_vk_vert.spv +0 -0
  179. taichi/shaders/Triangles_vk.frag +16 -0
  180. taichi/shaders/Triangles_vk.vert +29 -0
  181. taichi/shaders/Triangles_vk_frag.spv +0 -0
  182. taichi/shaders/Triangles_vk_vert.spv +0 -0
  183. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  184. taichi/sparse/__init__.py +3 -0
  185. taichi/sparse/_sparse_grid.py +77 -0
  186. taichi/tools/__init__.py +12 -0
  187. taichi/tools/diagnose.py +124 -0
  188. taichi/tools/np2ply.py +364 -0
  189. taichi/tools/vtk.py +38 -0
  190. taichi/types/__init__.py +19 -0
  191. taichi/types/annotations.py +47 -0
  192. taichi/types/compound_types.py +90 -0
  193. taichi/types/enums.py +49 -0
  194. taichi/types/ndarray_type.py +147 -0
  195. taichi/types/primitive_types.py +203 -0
  196. taichi/types/quant.py +88 -0
  197. taichi/types/texture_type.py +85 -0
  198. taichi/types/utils.py +13 -0
taichi/graph/_graph.py ADDED
@@ -0,0 +1,292 @@
1
+ # type: ignore
2
+
3
+ import warnings
4
+ from typing import Any, Dict, List
5
+
6
+ from taichi._lib import core as _ti_core
7
+ from taichi.aot.utils import produce_injected_args
8
+ from taichi.lang import impl, kernel_impl
9
+ from taichi.lang._ndarray import Ndarray
10
+ from taichi.lang._texture import Texture
11
+ from taichi.lang.exception import TaichiRuntimeError
12
+ from taichi.lang.matrix import Matrix, MatrixType
13
+ from taichi.types import enums
14
+ from taichi.types.texture_type import FORMAT2TY_CH, TY_CH2FORMAT
15
+
16
+ # Alias of the native argument-kind enum. Its variants used in this module
+ # (SCALAR, NDARRAY, MATRIX, TEXTURE, RWTEXTURE) serve as the ``tag`` of graph arguments.
+ ArgKind = _ti_core.ArgKind
17
+
18
+
19
def gen_cpp_kernel(kernel_fn, args):
    """Compile a Taichi kernel against symbolic graph arguments.

    Args:
        kernel_fn: A decorated kernel; its ``_primal`` attribute must be a
            ``kernel_impl.Kernel``.
        args: Symbolic arguments, turned into concrete call arguments by
            ``produce_injected_args``.

    Returns:
        The entry of ``Kernel.compiled_kernels`` keyed by the value that
        ``ensure_compiled`` returns for those arguments.
    """
    primal = kernel_fn._primal
    assert isinstance(primal, kernel_impl.Kernel)
    compile_key = primal.ensure_compiled(*produce_injected_args(primal, symbolic_args=args))
    return primal.compiled_kernels[compile_key]
25
+
26
+
27
def flatten_args(args):
    """Flatten nested list arguments into a single flat argument list.

    Every non-list argument is kept as-is. A list argument is treated as a
    sequence of sub-sequences (presumably the rows of a matrix argument) and
    its entries are spliced into the result in order.

    Args:
        args: Iterable of kernel arguments.

    Returns:
        list: The flattened arguments.
    """
    # FIXME remove this when native Matrix type is ready
    flat = []
    for item in args:
        if not isinstance(item, list):
            flat.append(item)
            continue
        for row in item:
            flat.extend(row)
    return flat
38
+
39
+
40
class Sequential:
    """Python wrapper around a native sequential graph node.

    Args:
        seq: The native node, e.g. as returned by
            ``_ti_core.GraphBuilder.create_sequential``; stored as ``seq_``.
    """

    def __init__(self, seq):
        self.seq_ = seq

    def dispatch(self, kernel_fn, *args):
        """Compile ``kernel_fn`` for ``args`` and append the launch to this node."""
        # Arguments evaluate left to right: compile first, then flatten.
        self.seq_.dispatch(gen_cpp_kernel(kernel_fn, args), flatten_args(args))
48
+
49
+
50
class GraphBuilder:
    """Records kernel launches and compiles them into a :class:`Graph`."""

    def __init__(self):
        self._graph_builder = _ti_core.GraphBuilder()

    def dispatch(self, kernel_fn, *args):
        """Record a launch of ``kernel_fn`` with symbolic ``args`` on the builder."""
        compiled = gen_cpp_kernel(kernel_fn, args)
        self._graph_builder.dispatch(compiled, flatten_args(args))

    def create_sequential(self):
        """Return a :class:`Sequential` wrapping a new native sequential node."""
        native_seq = self._graph_builder.create_sequential()
        return Sequential(native_seq)

    def append(self, node):
        """Append ``node`` to the sequence returned by the native builder's ``seq()``.

        Only :class:`Sequential` nodes are accepted.
        """
        # TODO: support appending dispatch node as well.
        assert isinstance(node, Sequential)
        root_seq = self._graph_builder.seq()
        root_seq.append(node.seq_)

    def compile(self):
        """Compile everything recorded so far and wrap it in a :class:`Graph`."""
        compiled_graph = self._graph_builder.compile()
        return Graph(compiled_graph)
69
+
70
+
71
class Graph:
    """A compiled graph that can be launched with runtime arguments.

    Args:
        compiled_graph: The native object returned by
            ``_ti_core.GraphBuilder.compile``.
    """

    def __init__(self, compiled_graph) -> None:
        self._compiled_graph = compiled_graph

    def run(self, args):
        """Launch the compiled graph with the given runtime arguments.

        Args:
            args (Dict[str, Any]): Maps each graph argument key to its value.
                Accepted values are Python ``int``/``float``, ``ti.Matrix``,
                ``ti.Ndarray`` and ``ti.Texture``.

        Raises:
            TaichiRuntimeError: If a value has an unsupported type.
        """
        # Ndarrays and textures are passed as their native handles; Matrix
        # values are flattened into their entries.
        # TODO diminish the flatten behavior when Matrix becomes a Taichi native type.
        flattened = {}
        for k, v in args.items():
            if isinstance(v, Ndarray):
                flattened[k] = v.arr
            elif isinstance(v, Texture):
                flattened[k] = v.tex
            elif isinstance(v, Matrix):
                flattened[k] = v.entries
            elif isinstance(v, (int, float)):
                flattened[k] = v
            else:
                # Texture is accepted above, so the message must list it too.
                raise TaichiRuntimeError(
                    "Only python int, float, ti.Matrix, ti.Ndarray and ti.Texture are supported "
                    f"as runtime arguments but got {type(v)}"
                )
        self._compiled_graph.jit_run(impl.get_runtime().prog.config(), flattened)
94
+
95
+
96
def _deprecate_arg_args(kwargs: Dict[str, Any]):
    """Normalize or reject deprecated ``Arg`` keyword arguments, in place.

    - ``field_dim`` is renamed to ``ndim``, with a ``DeprecationWarning``.
    - For NDARRAY tags without ``element_shape``, a ``MatrixType`` dtype is
      split into its scalar ``dtype`` plus ``element_shape``; a scalar dtype
      gets ``element_shape=()``.
    - For TEXTURE/RWTEXTURE tags, ``dtype`` is dropped with a warning.

    Args:
        kwargs: Keyword arguments of the ``Arg`` call; must contain ``tag``.

    Raises:
        TaichiRuntimeError: If removed arguments (``element_shape``, texture
            ``shape``, ``channel_format``/``num_channels``) are used, or if
            ``field_dim`` is combined with ``ndim``.
    """
    if "field_dim" in kwargs:
        warnings.warn(
            "The field_dim argument for ndarray will be deprecated in v1.6.0, use ndim instead.",
            DeprecationWarning,
        )
        if "ndim" in kwargs:
            raise TaichiRuntimeError(
                "field_dim is deprecated, please do not specify field_dim and ndim at the same time."
            )
        kwargs["ndim"] = kwargs.pop("field_dim")
    tag = kwargs["tag"]

    if tag == ArgKind.SCALAR:
        if "element_shape" in kwargs:
            raise TaichiRuntimeError(
                "The element_shape argument for scalar is deprecated in v1.6.0, and is removed in v1.7.0. "
                "Please remove them."
            )

    if tag == ArgKind.NDARRAY:
        if "element_shape" not in kwargs:
            if "dtype" in kwargs:
                dtype = kwargs["dtype"]
                if isinstance(dtype, MatrixType):
                    kwargs["dtype"] = dtype.dtype
                    kwargs["element_shape"] = dtype.get_shape()
                else:
                    kwargs["element_shape"] = ()
        else:
            raise TaichiRuntimeError(
                "The element_shape argument for ndarray is deprecated in v1.6.0, and it is removed in v1.7.0. "
                "Please use vector or matrix data type instead."
            )

    if tag == ArgKind.RWTEXTURE or tag == ArgKind.TEXTURE:
        if "dtype" in kwargs:
            warnings.warn(
                "The dtype argument for texture will be deprecated in v1.6.0, use format instead.",
                DeprecationWarning,
            )
            del kwargs["dtype"]

        if "shape" in kwargs:
            raise TaichiRuntimeError(
                "The shape argument for texture is deprecated in v1.6.0, and it is removed in v1.7.0. "
                "Please use ndim instead. (Note that you no longer need the exact texture size.)"
            )

        if "channel_format" in kwargs or "num_channels" in kwargs:
            if "fmt" in kwargs:
                raise TaichiRuntimeError(
                    "channel_format and num_channels are deprecated, please do not specify channel_format/num_channels and fmt at the same time."
                )
            if tag == ArgKind.RWTEXTURE:
                # Raise directly. The former TY_CH2FORMAT lookup here was dead
                # (its result was discarded by this raise) and turned a missing
                # key or an unknown pair into a bare KeyError.
                raise TaichiRuntimeError(
                    "The channel_format and num_channels arguments for texture are deprecated in v1.6.0, "
                    "and they are removed in v1.7.0. Please use fmt instead."
                )
            raise TaichiRuntimeError(
                "The channel_format and num_channels arguments are no longer required for non-RW textures "
                "since v1.6.0, and they are removed in v1.7.0. Please remove them."
            )
163
+
164
+
165
+ def _check_args(kwargs: Dict[str, Any], allowed_kwargs: List[str]):
166
+ for k, v in kwargs.items():
167
+ if k not in allowed_kwargs:
168
+ raise TaichiRuntimeError(
169
+ f"Invalid argument: {k}, you can only create a graph argument with: {allowed_kwargs}"
170
+ )
171
+ if k == "tag":
172
+ if not isinstance(v, ArgKind):
173
+ raise TaichiRuntimeError(f"tag must be a ArgKind variant, but found {type(v)}.")
174
+ if k == "name":
175
+ if not isinstance(v, str):
176
+ raise TaichiRuntimeError(f"name must be a string, but found {type(v)}.")
177
+
178
+
179
def _make_arg_scalar(kwargs: Dict[str, Any]):
    """Build a scalar graph argument from ``tag``, ``name`` and ``dtype``.

    Raises:
        TaichiRuntimeError: If ``dtype`` is a ``MatrixType`` or if
            ``_check_args`` rejects the keywords.
    """
    _check_args(kwargs, ["tag", "name", "dtype"])
    arg_name, arg_dtype = kwargs["name"], kwargs["dtype"]
    if isinstance(arg_dtype, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.SCALAR must specify a scalar type, but found {type(arg_dtype)}.")
    # Scalars have ndim 0 and an empty element shape.
    return _ti_core.Arg(ArgKind.SCALAR, arg_name, arg_dtype, 0, [])
191
+
192
+
193
def _make_arg_ndarray(kwargs: Dict[str, Any]):
    """Build an ndarray graph argument.

    ``dtype`` must be a scalar type; any vector/matrix shape is carried by
    ``element_shape``. Callers going through ``_make_arg`` have this split
    done by ``_deprecate_arg_args``.

    Raises:
        TaichiRuntimeError: If ``dtype`` is a ``MatrixType`` or if
            ``_check_args`` rejects the keywords.
    """
    _check_args(kwargs, ["tag", "name", "dtype", "ndim", "element_shape"])
    arg_name, arg_ndim, arg_dtype, arg_element_shape = (
        kwargs["name"],
        kwargs["ndim"],
        kwargs["dtype"],
        kwargs["element_shape"],
    )
    if isinstance(arg_dtype, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.NDARRAY must specify a scalar type, but found {arg_dtype}.")
    return _ti_core.Arg(ArgKind.NDARRAY, arg_name, arg_dtype, arg_ndim, arg_element_shape)
209
+
210
+
211
def _make_arg_matrix(kwargs: Dict[str, Any]):
    """Build a matrix graph argument from a ``MatrixType`` ``dtype``.

    The native argument receives the matrix's scalar ``dtype`` and an element
    shape of ``[n, m]``.

    Raises:
        TaichiRuntimeError: If ``dtype`` is not a ``MatrixType`` or if
            ``_check_args`` rejects the keywords.
    """
    _check_args(kwargs, ["tag", "name", "dtype"])
    arg_name = kwargs["name"]
    matrix_type = kwargs["dtype"]
    if not isinstance(matrix_type, MatrixType):
        raise TaichiRuntimeError(f"Tag ArgKind.MATRIX must specify matrix type, but got {matrix_type}.")
    element_shape = [matrix_type.n, matrix_type.m]
    return _ti_core.Arg(ArgKind.MATRIX, f"{arg_name}", matrix_type.dtype, 0, element_shape)
223
+
224
+
225
def _make_arg_texture(kwargs: Dict[str, Any]):
    """Build a sampled-texture graph argument from ``tag``, ``name`` and ``ndim``.

    Raises:
        TaichiRuntimeError: If ``_check_args`` rejects the keywords.
    """
    _check_args(kwargs, ["tag", "name", "ndim"])
    tex_name = kwargs["name"]
    tex_ndim = kwargs["ndim"]
    # The type/channel slots are fixed (f32, 4), and the shape is a placeholder
    # of 2 per dimension; presumably only the rank matters here — TODO confirm.
    return _ti_core.Arg(ArgKind.TEXTURE, tex_name, impl.f32, 4, [2] * tex_ndim)
235
+
236
+
237
def _make_arg_rwtexture(kwargs: Dict[str, Any]):
    """Build a read/write texture graph argument from ``name``, ``ndim`` and ``fmt``.

    ``fmt`` is translated into a (channel type, channel count) pair through
    ``FORMAT2TY_CH``.

    Raises:
        TaichiRuntimeError: If ``fmt`` is ``Format.unknown`` or if
            ``_check_args`` rejects the keywords.
    """
    _check_args(kwargs, ["tag", "name", "ndim", "fmt"])
    tex_name = kwargs["name"]
    tex_ndim = kwargs["ndim"]
    tex_fmt = kwargs["fmt"]
    if tex_fmt == enums.Format.unknown:
        raise TaichiRuntimeError(f"Tag ArgKind.RWTEXTURE must specify a valid color format, but found {tex_fmt}.")
    channel_type, channel_count = FORMAT2TY_CH[tex_fmt]
    return _ti_core.Arg(ArgKind.RWTEXTURE, tex_name, channel_type, channel_count, [2] * tex_ndim)
252
+
253
+
254
def _make_arg(kwargs: Dict[str, Any]):
    """Build a native graph argument from normalized keyword arguments.

    Deprecated keywords are first handled by ``_deprecate_arg_args``. The
    builder chosen by ``tag`` then validates the remaining keywords.

    Args:
        kwargs: Keyword arguments; must contain ``tag``. May be modified in place.

    Returns:
        The ``_ti_core.Arg`` produced by the tag-specific builder.

    Raises:
        TaichiRuntimeError: If ``tag`` is missing or is not a supported
            ``ArgKind`` variant, or if a builder rejects the arguments.
    """
    # Explicit check instead of ``assert``: it must survive ``python -O``
    # and give users a meaningful error.
    if "tag" not in kwargs:
        raise TaichiRuntimeError("A graph argument must specify its tag (an ArgKind variant).")
    _deprecate_arg_args(kwargs)
    builders = {
        ArgKind.SCALAR: _make_arg_scalar,
        ArgKind.NDARRAY: _make_arg_ndarray,
        ArgKind.MATRIX: _make_arg_matrix,
        ArgKind.TEXTURE: _make_arg_texture,
        ArgKind.RWTEXTURE: _make_arg_rwtexture,
    }
    tag = kwargs["tag"]
    builder = builders.get(tag)
    if builder is None:
        raise TaichiRuntimeError(f"Unsupported graph argument tag: {tag}.")
    return builder(kwargs)
266
+
267
+
268
+ def _kwarg_rewriter(args, kwargs):
269
+ for i, arg in enumerate(args):
270
+ rewrite_map = {
271
+ 0: "tag",
272
+ 1: "name",
273
+ 2: "dtype",
274
+ 3: "ndim",
275
+ 4: "field_dim",
276
+ 5: "element_shape",
277
+ 6: "channel_format",
278
+ 7: "shape",
279
+ 8: "num_channels",
280
+ }
281
+ if i in rewrite_map:
282
+ kwargs[rewrite_map[i]] = arg
283
+ else:
284
+ raise TaichiRuntimeError(f"Unexpected {i}th positional argument")
285
+
286
+
287
def Arg(*args, **kwargs):
    """Declare a graph argument.

    Positional arguments map, in order, to ``tag``, ``name``, ``dtype``,
    ``ndim``, ``field_dim``, ``element_shape``, ``channel_format``, ``shape``
    and ``num_channels``. Any of them may be given by keyword instead.

    Returns:
        The native argument built by ``_make_arg``.
    """
    merged = kwargs  # _kwarg_rewriter folds the positionals into this dict in place
    _kwarg_rewriter(args, merged)
    return _make_arg(merged)
290
+
291
+
292
+ # Public API of the graph module; the ``_make_arg*``/``_check_args`` helpers are private.
+ __all__ = ["GraphBuilder", "Graph", "Arg", "ArgKind"]
taichi/lang/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ # type: ignore
2
+
3
+ from taichi.lang import impl, simt
4
+ from taichi.lang._ndarray import *
5
+ from taichi.lang._ndrange import ndrange
6
+ from taichi.lang._texture import Texture
7
+ from taichi.lang.argpack import *
8
+ from taichi.lang.exception import *
9
+ from taichi.lang.field import *
10
+ from taichi.lang.impl import *
11
+ from taichi.lang.kernel_impl import *
12
+ from taichi.lang.matrix import *
13
+ from taichi.lang.mesh import *
14
+ from taichi.lang.misc import * # pylint: disable=W0622
15
+ from taichi.lang.ops import * # pylint: disable=W0622
16
+ from taichi.lang.runtime_ops import *
17
+ from taichi.lang.snode import *
18
+ from taichi.lang.source_builder import *
19
+ from taichi.lang.struct import *
20
+ from taichi.types.enums import DeviceCapability, Format, Layout
21
+
22
# Export every public name brought in by the star-imports above, except
# private names and a fixed set of module names (mostly this package's own
# submodules) that should not leak into ``from taichi.lang import *``.
__all__ = [
    name
    for name in dir()
    if not (
        name.startswith("_")
        or name
        in {
            "any_array",
            "ast",
            "common_ops",
            "enums",
            "exception",
            "expr",
            "impl",
            "inspect",
            "kernel_arguments",
            "kernel_impl",
            "matrix",
            "mesh",
            "misc",
            "ops",
            "platform",
            "runtime_ops",
            "shell",
            "snode",
            "source_builder",
            "struct",
            "util",
        }
    )
]
taichi/lang/_ndarray.py ADDED
@@ -0,0 +1,348 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+ from taichi._lib import core as _ti_core
6
+ from taichi.lang import impl
7
+ from taichi.lang.exception import TaichiIndexError
8
+ from taichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
9
+ from taichi.types import primitive_types
10
+ from taichi.types.enums import Layout
11
+ from taichi.types.ndarray_type import NdarrayTypeMetadata
12
+ from taichi.types.utils import is_real, is_signed
13
+
14
+
15
class Ndarray:
    """Taichi ndarray class.

    Base class for Python-side ndarrays. It keeps the native array handle in
    ``arr`` and implements the operations shared by all ndarray kinds.
    Subclasses (e.g. ``ScalarNdarray``) create ``arr`` and implement element
    access, ``element_shape``, ``_fill_by_kernel`` and ``__deepcopy__``; the
    base versions raise ``NotImplementedError``.

    Args:
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the Ndarray.
    """

    def __init__(self):
        self.host_accessor = None  # lazily created by _initialize_host_accessor
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.arr = None  # native ndarray handle; set by subclasses
        self.layout = Layout.AOS
        self.grad = None

    def get_type(self):
        """Returns the ``NdarrayTypeMetadata`` describing this ndarray (element type, shape, has-grad)."""
        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)

    @property
    def element_shape(self):
        """Gets ndarray element shape.

        Returns:
            Tuple[Int]: Ndarray element shape.
        """
        raise NotImplementedError()

    @python_scope
    def __setitem__(self, key, value):
        """Sets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    @python_scope
    def fill(self, val):
        """Fills ndarray with a specific scalar value.

        The program's native ``fill_float``/``fill_int``/``fill_uint`` are used
        only on the CUDA and x64 backends for non-tensor f32/i32/u32 elements.
        Every other case falls back to ``_fill_by_kernel``.

        Args:
            val (Union[int, float]): Value to fill.
        """
        arch = impl.current_cfg().arch
        if arch not in (_ti_core.Arch.cuda, _ti_core.Arch.x64) or _ti_core.is_tensor(self.element_type):
            self._fill_by_kernel(val)
        elif self.dtype == primitive_types.f32:
            impl.get_runtime().prog.fill_float(self.arr, val)
        elif self.dtype == primitive_types.i32:
            impl.get_runtime().prog.fill_int(self.arr, val)
        elif self.dtype == primitive_types.u32:
            impl.get_runtime().prog.fill_uint(self.arr, val)
        else:
            self._fill_by_kernel(val)

    @python_scope
    def _ndarray_to_numpy(self):
        """Converts ndarray to a numpy array.

        Returns:
            numpy.ndarray: The result numpy array, shaped like ``arr.total_shape()``.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from taichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415

        ndarray_to_ext_arr(self, arr)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_matrix_to_numpy(self, as_vector):
        """Converts matrix ndarray to a numpy array.

        Args:
            as_vector: Forwarded to the ``ndarray_matrix_to_ext_arr`` kernel.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from taichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415

        layout_is_aos = 1
        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_from_numpy(self, arr):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal
                ``self.arr.total_shape()``.

        Raises:
            TypeError: If ``arr`` is not a ``numpy.ndarray``.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            # Report the shape actually compared against (total_shape), not arr.shape.
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from taichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415

        ext_arr_to_ndarray(arr, self)
        impl.get_runtime().sync()

    @python_scope
    def _ndarray_matrix_from_numpy(self, arr, as_vector):
        """Loads all values from a numpy array into a matrix ndarray.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal
                ``self.arr.total_shape()``.
            as_vector: Forwarded to the ``ext_arr_to_ndarray_matrix`` kernel.

        Raises:
            TypeError: If ``arr`` is not a ``numpy.ndarray``.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        if tuple(self.arr.total_shape()) != tuple(arr.shape):
            raise ValueError(
                f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided"
            )
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from taichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415

        layout_is_aos = 1
        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
        impl.get_runtime().sync()

    @python_scope
    def _get_element_size(self):
        """Returns the size of one element in bytes.

        Returns:
            Size in bytes.
        """
        return self.arr.element_size()

    @python_scope
    def _get_nelement(self):
        """Returns the total number of elements.

        Returns:
            Total number of elements.
        """
        return self.arr.nelement()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another ndarray.

        The shape of the other ndarray needs to be the same as `self`.

        Args:
            other (Ndarray): The source ndarray.
        """
        assert isinstance(other, Ndarray)
        assert tuple(self.arr.shape) == tuple(other.arr.shape)
        from taichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415

        ndarray_to_ndarray(self, other)
        impl.get_runtime().sync()

    def _set_grad(self, grad):
        """Sets the gradient ndarray.

        Args:
            grad (Ndarray): The gradient ndarray.
        """
        self.grad = grad

    def __deepcopy__(self, memo=None):
        """Copies all elements to a new ndarray.

        Returns:
            Ndarray: The result ndarray.
        """
        raise NotImplementedError()

    def _fill_by_kernel(self, val):
        """Fills ndarray with a specific scalar value using a ti.kernel.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def _pad_key(self, key):
        """Normalize ``key`` to a tuple with one index per dimension of ``arr.total_shape()``.

        ``None`` becomes ``()`` and a non-sequence key is wrapped in a 1-tuple.

        Raises:
            TaichiIndexError: If the number of indices does not match the rank.
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        if len(key) != len(self.arr.total_shape()):
            raise TaichiIndexError(f"{len(self.arr.total_shape())}d ndarray indexed with {len(key)}d indices: {key}")
        return key

    @python_scope
    def _initialize_host_accessor(self):
        """Create ``host_accessor`` on first use, materializing the runtime first."""
        if self.host_accessor:
            return
        impl.get_runtime().materialize()
        self.host_accessor = NdarrayHostAccessor(self.arr)
230
+
231
+
232
class ScalarNdarray(Ndarray):
    """Taichi ndarray with scalar elements.

    Args:
        dtype (DataType): Data type of each value.
        arr_shape (Tuple[int]): Shape of the ndarray.
    """

    def __init__(self, dtype, arr_shape):
        super().__init__()
        self.dtype = cook_dtype(dtype)
        self.arr = impl.get_runtime().prog.create_ndarray(
            self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback())
        )
        self.shape = tuple(self.arr.shape)
        self.element_type = dtype

    def __del__(self):
        # If __init__ failed before create_ndarray returned, ``arr`` is still the
        # None set by Ndarray.__init__ (or missing entirely); nothing to release.
        native_arr = getattr(self, "arr", None)
        if native_arr is None:
            return
        # Presumably guards interpreter teardown, when module globals such as
        # ``impl`` may already have been cleared.
        if impl is None or impl.get_runtime is None:
            return
        runtime = impl.get_runtime()
        if runtime is None or runtime.prog is None:
            return
        runtime.prog.delete_ndarray(native_arr)

    @property
    def element_shape(self):
        """Scalar elements have an empty element shape."""
        return ()

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessor()
        self.host_accessor.setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessor()
        return self.host_accessor.getter(*self._pad_key(key))

    @python_scope
    def to_numpy(self):
        """Return the contents as a new ``numpy.ndarray``."""
        return self._ndarray_to_numpy()

    @python_scope
    def from_numpy(self, arr):
        """Load the contents from ``arr``, whose shape must match this ndarray's."""
        self._ndarray_from_numpy(arr)

    def __deepcopy__(self, memo=None):
        ret_arr = ScalarNdarray(self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    def _fill_by_kernel(self, val):
        from taichi._kernels import fill_ndarray  # pylint: disable=C0415

        fill_ndarray(self, val)

    def __repr__(self):
        return "<ti.ndarray>"
292
+
293
+
294
class NdarrayHostAccessor:
    """Element getter/setter pair for a native ndarray, used from Python scope.

    The read routine is chosen from the element data type: ``read_float`` for
    real types, ``read_int`` for signed integers, ``read_uint`` otherwise.
    Writes use ``write_float`` for real types and ``write_int`` for all
    integer types.

    Args:
        ndarray: The native ndarray handle (provides ``element_data_type`` and
            the read/write routines).
    """

    def __init__(self, ndarray):
        elem_dtype = ndarray.element_data_type()
        if is_real(elem_dtype):
            read, write = ndarray.read_float, ndarray.write_float
        elif is_signed(elem_dtype):
            read, write = ndarray.read_int, ndarray.write_int
        else:
            read, write = ndarray.read_uint, ndarray.write_int

        def getter(*key):
            return read(key)

        def setter(value, *key):
            write(key, value)

        self.getter = getter
        self.setter = setter
321
+
322
+
323
class NdarrayHostAccess:
    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.

    Binds one element position (first-level plus second-level indices) and
    exposes zero-argument ``getter`` and one-argument ``setter`` callables
    that go through the owning ndarray's host accessor.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The owning ndarray.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).
    """

    def __init__(self, arr, indices_first, indices_second):
        self.ndarr = arr
        self.arr = arr.arr
        self.indices = indices_first + indices_second
        self.getter = self._read
        self.setter = self._write

    def _read(self):
        """Read the bound element, creating the host accessor if needed."""
        owner = self.ndarr
        owner._initialize_host_accessor()
        padded = owner._pad_key(self.indices)
        return owner.host_accessor.getter(*padded)

    def _write(self, value):
        """Write ``value`` to the bound element, creating the host accessor if needed."""
        owner = self.ndarr
        owner._initialize_host_accessor()
        padded = owner._pad_key(self.indices)
        owner.host_accessor.setter(value, *padded)
346
+
347
+
348
+ # Public API of this module; NdarrayHostAccessor and NdarrayHostAccess are not exported.
+ __all__ = ["Ndarray", "ScalarNdarray"]