gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
taichi/lang/impl.py ADDED
@@ -0,0 +1,1246 @@
1
+ # type: ignore
2
+
3
+ import numbers
4
+ from types import FunctionType, MethodType
5
+ from typing import Any, Iterable, Sequence
6
+
7
+ import numpy as np
8
+
9
+ from taichi._lib import core as _ti_core
10
+ from taichi._lib.core.taichi_python import (
11
+ DataType,
12
+ Function,
13
+ Program,
14
+ )
15
+ from taichi._snode.fields_builder import FieldsBuilder
16
+ from taichi.lang._ndarray import ScalarNdarray
17
+ from taichi.lang._ndrange import GroupedNDRange, _Ndrange
18
+ from taichi.lang._texture import RWTextureAccessor
19
+ from taichi.lang.any_array import AnyArray
20
+ from taichi.lang.exception import (
21
+ TaichiCompilationError,
22
+ TaichiRuntimeError,
23
+ TaichiSyntaxError,
24
+ TaichiTypeError,
25
+ )
26
+ from taichi.lang.expr import Expr, make_expr_group
27
+ from taichi.lang.field import Field, ScalarField
28
+ from taichi.lang.kernel_arguments import SparseMatrixProxy
29
+ from taichi.lang.kernel_impl import Kernel
30
+ from taichi.lang.matrix import (
31
+ Matrix,
32
+ MatrixField,
33
+ MatrixNdarray,
34
+ MatrixType,
35
+ Vector,
36
+ VectorNdarray,
37
+ make_matrix,
38
+ )
39
+ from taichi.lang.mesh import (
40
+ ConvType,
41
+ MeshElementFieldProxy,
42
+ MeshInstance,
43
+ MeshRelationAccessProxy,
44
+ MeshReorderedMatrixFieldProxy,
45
+ MeshReorderedScalarFieldProxy,
46
+ element_type_name,
47
+ )
48
+ from taichi.lang.simt.block import SharedArray
49
+ from taichi.lang.snode import SNode
50
+ from taichi.lang.struct import Struct, StructField, _IntermediateStruct
51
+ from taichi.lang.util import (
52
+ cook_dtype,
53
+ get_traceback,
54
+ is_taichi_class,
55
+ python_scope,
56
+ taichi_scope,
57
+ warning,
58
+ )
59
+ from taichi.types.enums import SNodeGradType
60
+ from taichi.types.primitive_types import (
61
+ all_types,
62
+ f16,
63
+ f32,
64
+ f64,
65
+ i32,
66
+ i64,
67
+ u8,
68
+ u32,
69
+ u64,
70
+ )
71
+
72
+
73
@taichi_scope
def expr_init_shared_array(shape, element_type):
    """Emit a shared-array allocation in the callable currently being compiled.

    Args:
        shape: shape forwarded unchanged to the AST builder.
        element_type: element type forwarded unchanged to the AST builder.

    Returns:
        Whatever ``ast_builder().expr_alloca_shared_array`` returns.
    """
    runtime = get_runtime()
    compiling_callable = runtime.compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
    return builder.expr_alloca_shared_array(shape, element_type, debug_info)
80
+
81
+
82
@taichi_scope
def expr_init(rhs):
    """Bind `rhs` to a fresh local value in the callable currently being compiled.

    - ``None`` allocates a new variable through ``expr_alloca``.
    - Matrices and structs are rebuilt from their contents.
    - Lists, tuples and dicts are processed element-wise (recursively).
    - Shared arrays, data types, archs, ndranges, mesh proxies and
      ``_data_oriented`` objects are returned unchanged.
    - Anything else is wrapped in an ``Expr`` and passed to ``expr_var``.
    """
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None

    if rhs is None:
        return Expr(
            compiling_callable.ast_builder().expr_alloca(_ti_core.DebugInfo(get_runtime().get_current_src_info()))
        )

    if isinstance(rhs, Matrix):
        # Matrices carrying `_DIM` keep their `ndim` when rebuilt.
        if hasattr(rhs, "_DIM"):
            return Matrix(*rhs.to_list(), ndim=rhs.ndim)
        return make_matrix(rhs.to_list())

    if isinstance(rhs, SharedArray):
        return rhs
    if isinstance(rhs, Struct):
        return Struct(rhs.to_dict(include_methods=True, include_ndim=True))

    # Python containers: initialize every element individually.
    if isinstance(rhs, list):
        return [expr_init(item) for item in rhs]
    if isinstance(rhs, tuple):
        return tuple(expr_init(item) for item in rhs)
    if isinstance(rhs, dict):
        return {key: expr_init(val) for key, val in rhs.items()}

    # Compile-time objects that are used as-is.
    passthrough_types = (
        _ti_core.DataType,
        _ti_core.Arch,
        _Ndrange,
        MeshElementFieldProxy,
        MeshRelationAccessProxy,
    )
    if isinstance(rhs, passthrough_types) or hasattr(rhs, "_data_oriented"):
        return rhs

    return Expr(
        compiling_callable.ast_builder().expr_var(
            Expr(rhs).ptr, _ti_core.DebugInfo(get_runtime().get_current_src_info())
        )
    )
121
+
122
+
123
@taichi_scope
def expr_init_func(rhs):  # temporary solution to allow passing in fields as arguments
    """Like ``expr_init``, except that ``Field`` arguments are passed through untouched."""
    return rhs if isinstance(rhs, Field) else expr_init(rhs)
128
+
129
+
130
def begin_frontend_struct_for(ast_builder, group, loop_range):
    """Open a struct-for loop over `loop_range` on `ast_builder`.

    Args:
        ast_builder: builder receiving the ``begin_frontend_struct_for_*`` call.
        group: loop index group; its ``size()`` must equal ``len(loop_range.shape)``.
        loop_range: an ``AnyArray``, ``Field``, ``SNode``, ``RWTextureAccessor`` or ``_Root``.

    Raises:
        TypeError: if `loop_range` is not one of the supported types.
        IndexError: if the number of loop indices does not match the range's dimensionality.
    """
    supported_ranges = (AnyArray, Field, SNode, RWTextureAccessor, _Root)
    if not isinstance(loop_range, supported_ranges):
        raise TypeError(
            f"Cannot loop over the object {type(loop_range)} in Taichi scope. Only Taichi fields (via template) or dense arrays (via types.ndarray) are supported."
        )

    num_loop_vars = group.size()
    range_dim = len(loop_range.shape)
    if num_loop_vars != range_dim:
        raise IndexError(
            "Number of struct-for indices does not match loop variable dimensionality "
            f"({num_loop_vars} != {range_dim}). Maybe you wanted to "
            'use "for I in ti.grouped(x)" to group all indices into a single vector I?'
        )

    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    # External arrays and textures use the external-tensor entry point; fields/SNodes use the SNode one.
    over_external = isinstance(loop_range, (AnyArray, RWTextureAccessor))
    begin_loop = (
        ast_builder.begin_frontend_struct_for_on_external_tensor
        if over_external
        else ast_builder.begin_frontend_struct_for_on_snode
    )
    begin_loop(group, loop_range._loop_range(), dbg_info)
146
+
147
+
148
def begin_frontend_if(ast_builder, cond, stmt_dbg_info):
    """Open an ``if`` statement on `ast_builder` using `cond` as the condition.

    Raises:
        ValueError: if `cond` is a Taichi class (e.g. a vector/matrix), whose
            truth value is ambiguous.
    """
    assert ast_builder is not None
    if not is_taichi_class(cond):
        ast_builder.begin_frontend_if(Expr(cond).ptr, stmt_dbg_info)
        return
    raise ValueError(
        "The truth value of vectors/matrices is ambiguous.\n"
        "Consider using `any` or `all` when comparing vectors/matrices:\n"
        " if all(x == y):\n"
        "or\n"
        " if any(x != y):\n"
    )
159
+
160
+
161
@taichi_scope
def _calc_slice(index, default_stop):
    """Expand a compile-time ``slice`` into the explicit list of positions it selects.

    Args:
        index (slice): the slice used in a subscript. Its start/stop/step must
            be compile-time Python values, not Taichi expressions.
        default_stop (int): length of the axis being sliced.

    Returns:
        list[int]: the positions selected by `index` on an axis of length
        `default_stop`. Python slicing rules apply: omitted bounds, an explicit
        stop of 0, negative steps, and clamping to the axis length.

    Raises:
        TaichiCompilationError: if any slice component is a Taichi ``Expr``.
        ValueError: if the slice step is 0 (as in plain Python slicing).
    """

    def check_validity(x):
        # TODO(mzmzm): support variable in slice
        if isinstance(x, Expr):
            raise TaichiCompilationError(
                "Taichi does not support variables in slice now, please use constant instead of it."
            )

    # Validate the raw components before evaluating them in any way.
    for component in (index.start, index.stop, index.step):
        check_validity(component)

    # slice.indices resolves omitted bounds and clamps to the axis length.
    # Unlike `stop or default_stop`, it keeps an explicit stop of 0 (so a[0:0]
    # is empty). It also walks backwards for a negative step (so a[::-1] is
    # reversed rather than empty).
    return list(range(*index.indices(default_stop)))
174
+
175
+
176
def validate_subscript_index(value, index):
    """Reject negative constant indices when subscripting non-field values.

    Iterables are checked element by element, and slices have their start and
    stop checked. ``Field`` values and ``Expr`` indices are accepted without
    further checks.

    Raises:
        TaichiSyntaxError: if a negative ``int`` index is found.
    """
    # Fields support negative indices; Expr indices cannot be checked statically.
    if isinstance(value, Field) or isinstance(index, Expr):
        return

    if isinstance(index, Iterable):
        for element in index:
            validate_subscript_index(value, element)

    if isinstance(index, slice):
        for bound in (index.start, index.stop):
            validate_subscript_index(value, bound)

    if isinstance(index, int) and index < 0:
        raise TaichiSyntaxError("Negative indices are not supported in Taichi kernels.")
194
+
195
+
196
@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
    """Lower ``value[_indices]`` while compiling a Taichi kernel or function.

    Non-Taichi objects are subscripted directly in Python. Taichi values
    (``Expr``, fields, external arrays, sparse-matrix and mesh proxies, shared
    arrays) are lowered through the AST builder of the callable being compiled.
    Matrix indices are flattened into their components. Slice indices are only
    accepted on tensor-typed ``Expr`` values.

    Args:
        ast_builder: ignored; the builder of ``get_runtime().compiling_callable``
            is always used instead.
        value: the object being subscripted.
        *_indices: the raw subscript components.
        skip_reordered: when True, mesh reordered-field proxies are indexed
            directly, without first converting the index with ``ConvType.g2r``.

    Raises:
        TaichiSyntaxError: on negative constant indices (via
            ``validate_subscript_index``) or when slicing a non-tensor value.
        RuntimeError: when subscripting a field whose SNode has not been placed.
    """
    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    ast_builder = compiling_callable.ast_builder()

    lowered_kinds = (
        Expr,
        Field,
        AnyArray,
        SparseMatrixProxy,
        MeshElementFieldProxy,
        MeshRelationAccessProxy,
        SharedArray,
    )
    if not isinstance(value, lowered_kinds):
        # Not a Taichi object: evaluate the subscript eagerly in Python.
        key = _indices[0] if len(_indices) == 1 else _indices
        return value.__getitem__(key)

    # Flatten vector/matrix indices into their components and remember slices.
    flat_indices = []
    has_slice = False
    for component in _indices:
        if isinstance(component, Matrix):
            flat_indices.extend(component.to_list())
            continue
        if isinstance(component, slice):
            has_slice = True
        flat_indices.append(component)
    indices = tuple(flat_indices)
    validate_subscript_index(value, indices)

    # A lone `None` index is treated as "no indices" (presumably `x[None]` on 0-D values).
    if len(indices) == 1 and indices[0] is None:
        indices = ()

    indices_expr_group = None
    if not has_slice:
        indices_expr_group = make_expr_group(*indices)
    elif not (isinstance(value, Expr) and value.is_tensor()):
        raise TaichiSyntaxError(f"The type {type(value)} do not support index of slice type")

    # These proxies implement their own subscript logic.
    if isinstance(value, (SharedArray, MeshElementFieldProxy, MeshRelationAccessProxy)):
        return value.subscript(*indices)

    reordered_proxies = (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)
    if isinstance(value, reordered_proxies) and not skip_reordered:
        assert len(indices) > 0
        # Convert the leading index with ConvType.g2r, then subscript again
        # with skip_reordered=True so the conversion is not applied twice.
        converted_index = Expr(
            ast_builder.mesh_index_conversion(
                value.mesh_ptr, value.element_type, Expr(indices[0]).ptr, ConvType.g2r, dbg_info
            )
        )
        return subscript(ast_builder, value, converted_index, skip_reordered=True)

    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*indices)

    if isinstance(value, Field):
        member_ptr = value._get_field_members()[0].ptr
        if member_ptr.snode() is None:
            if member_ptr.is_primal():
                raise RuntimeError(f"{member_ptr.get_expr_name()} has not been placed.")
            raise RuntimeError(
                f"Gradient {member_ptr.get_expr_name()} has not been placed, check whether `needs_grad=True`"
            )

        assert indices_expr_group is not None
        if isinstance(value, MatrixField):
            return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
        if isinstance(value, StructField):
            entries = {name: subscript(ast_builder, member, *indices) for name, member in value._items}
            entries["__struct_methods"] = value.struct_methods
            return _IntermediateStruct(entries)
        return Expr(ast_builder.expr_subscript(member_ptr, indices_expr_group, dbg_info))

    if isinstance(value, AnyArray):
        assert indices_expr_group is not None
        return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))

    # Index into TensorType
    # value: IndexExpression with ret_type = TensorType
    assert isinstance(value, Expr)
    assert value.is_tensor()

    if not has_slice:
        return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))

    # Slicing: expand every slice into explicit positions, then gather them.
    shape = value.get_shape()
    dim = len(shape)
    assert dim == len(indices)
    expanded = [
        _calc_slice(idx, shape[axis]) if isinstance(idx, slice) else idx for axis, idx in enumerate(indices)
    ]
    if dim == 1:
        assert isinstance(expanded[0], list)
        multiple_indices = [make_expr_group(i) for i in expanded[0]]
        return_shape = (len(expanded[0]),)
    else:
        assert dim == 2
        row_sel, col_sel = expanded
        rows_sliced = isinstance(row_sel, list)
        cols_sliced = isinstance(col_sel, list)
        if rows_sliced and cols_sliced:
            multiple_indices = [make_expr_group(i, j) for i in row_sel for j in col_sel]
            return_shape = (len(row_sel), len(col_sel))
        elif rows_sliced:
            multiple_indices = [make_expr_group(i, col_sel) for i in row_sel]
            return_shape = (len(row_sel),)
        else:
            # Only the column index is a slice.
            multiple_indices = [make_expr_group(row_sel, j) for j in col_sel]
            return_shape = (len(col_sel),)
    return Expr(
        _ti_core.subscript_with_multiple_indices(
            value.ptr,
            multiple_indices,
            return_shape,
            dbg_info,
        )
    )
322
+
323
+
324
class SrcInfoGuard:
    """Context manager that pushes `info` onto `info_stack` for the duration of a ``with`` block.

    On exit the last entry of the stack is popped, whether or not the block
    raised. Exceptions are never suppressed, and ``with ... as`` binds ``None``.
    """

    def __init__(self, info_stack, info):
        self.info, self.info_stack = info, info_stack

    def __enter__(self):
        stack = self.info_stack
        stack.append(self.info)

    def __exit__(self, exc_type, exc_val, exc_tb):
        stack = self.info_stack
        stack.pop()
        # Implicitly returning None lets any in-flight exception propagate.
334
+
335
+
336
class PyTaichi:
    """Python-side state of a Taichi runtime session.

    Holds:
    - the native ``Program`` (created by ``create_program``);
    - the kernel/function currently being compiled;
    - the source-info stack used to build debug info;
    - the default primitive types;
    - the fields registered since the last ``materialize`` call, which that
      call validates and then clears.
    """

    def __init__(self, kernels=None):
        self.materialized = False
        self._prog: Program | None = None
        self.src_info_stack = []
        self.inside_kernel: bool = False
        self.compiling_callable: Kernel | Function | None = None  # pointer to instance of lang::Kernel/Function
        self._current_kernel: Kernel | None = None
        # Fields registered since the last materialize(); validated and reset there.
        self.global_vars = []
        self.grad_vars = []
        self.dual_vars = []
        self.matrix_fields = []
        self.default_fp = f32
        self.default_ip = i32
        self.default_up = u32
        self.print_full_traceback: bool = False
        self.target_tape = None
        self.fwd_mode_manager = None
        self.grad_replaced = False
        self.kernels = kernels or []
        self._signal_handler_registry = None
        # FieldsBuilder -> traceback captured by initialize_fields_builder().
        self.unfinalized_fields_builder = {}

    @property
    def prog(self) -> Program:
        """The native Program.

        Raises:
            TaichiRuntimeError: if the program has not been created yet.
        """
        if self._prog is None:
            raise TaichiRuntimeError("_prog attribute not initialized. Maybe you forgot to call `ti.init()` first?")
        return self._prog

    @property
    def current_kernel(self) -> Kernel:
        """The kernel stored in ``_current_kernel``.

        Raises:
            TaichiRuntimeError: if no current kernel is set.
        """
        if self._current_kernel is None:
            raise TaichiRuntimeError(
                "_current_kernel attribute not initialized. Maybe you forgot to call `ti.init()` first?"
            )
        return self._current_kernel

    def initialize_fields_builder(self, builder):
        """Record `builder` as unfinalized, remembering where it was created."""
        self.unfinalized_fields_builder[builder] = get_traceback(2)

    def clear_compiled_functions(self):
        """Drop the compiled variants cached by every known kernel."""
        for k in self.kernels:
            k.compiled_kernels.clear()

    def finalize_fields_builder(self, builder):
        """Forget `builder` once it has been finalized."""
        self.unfinalized_fields_builder.pop(builder)

    def validate_fields_builder(self):
        """Raise if a FieldsBuilder other than the root one was never finalized."""
        for builder, tb in self.unfinalized_fields_builder.items():
            if builder == _root_fb:
                continue
            raise TaichiRuntimeError(
                f"Field builder {builder} is not finalized. Please call finalize() on it. Traceback:\n{tb}"
            )

    def get_num_compiled_functions(self):
        """Return the total number of compiled variants across all known kernels."""
        return sum(len(k.compiled_kernels) for k in self.kernels)

    def src_info_guard(self, info):
        """Return a context manager that keeps `info` on the source-info stack."""
        return SrcInfoGuard(self.src_info_stack, info)

    def get_current_src_info(self):
        """Return the innermost source info (raises IndexError if the stack is empty)."""
        return self.src_info_stack[-1]

    def set_default_fp(self, fp):
        """Set the default floating-point type (f16, f32 or f64) here and in the compile config."""
        assert fp in [f16, f32, f64]
        self.default_fp = fp
        default_cfg().default_fp = self.default_fp

    def set_default_ip(self, ip):
        """Set the default signed integer type (i32 or i64) and the matching unsigned type."""
        assert ip in [i32, i64]
        self.default_ip = ip
        self.default_up = u32 if ip == i32 else u64
        default_cfg().default_ip = self.default_ip
        default_cfg().default_up = self.default_up

    def create_program(self):
        """Create the native Program if it does not exist yet."""
        if self._prog is None:
            self._prog = _ti_core.Program()

    @staticmethod
    def materialize_root_fb(is_first_call):
        """Finalize the global `root` FieldsBuilder and install a fresh `_root_fb`.

        Does nothing if `root` is already finalized. On calls other than the
        first, it also does nothing when `root` is empty.
        """
        global _root_fb
        if root.finalized:
            return
        if not is_first_call and root.empty:
            # We have to forcefully finalize when `is_first_call` is True (even
            # if the root itself is empty), so that there is a valid struct
            # llvm::Module, if no field has been declared before the first kernel
            # invocation. Example case:
            # https://github.com/taichi-dev/taichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
            return

        # `root` is known to be unfinalized here (early return above).
        if get_runtime().prog.config().debug:
            root._allocate_adjoint_checkbit()

        root.finalize(raise_warning=not is_first_call)
        _root_fb = FieldsBuilder()

    @staticmethod
    def _finalize_root_fb_for_aot():
        """Finalize the root FieldsBuilder for AOT; allowed only once."""
        if _root_fb.finalized:
            raise RuntimeError("AOT: can only finalize the root FieldsBuilder once")
        assert isinstance(_root_fb, FieldsBuilder)
        _root_fb._finalize_for_aot()

    @staticmethod
    def _get_tb(_var):
        """Return the declaration traceback of `_var`, or its pointer as a string."""
        return getattr(_var, "declaration_tb", str(_var.ptr))

    def _check_field_not_placed(self):
        """Raise if any registered global field has no SNode."""
        not_placed = [self._get_tb(_var) for _var in self.global_vars if _var.ptr.snode() is None]
        if not_placed:
            bar = "=" * 44 + "\n"
            raise RuntimeError(
                f"These field(s) are not placed:\n{bar}"
                + bar.join(not_placed)
                + f"{bar}Please consider specifying a shape for them. E.g.,"
                + "\n\n x = ti.field(float, shape=(2, 3))"
            )

    def _check_gradient_field_not_placed(self, gradient_type):
        """Raise if any registered ``grad``/``dual`` field has no SNode.

        Args:
            gradient_type (str): ``"grad"`` or ``"dual"``. Any other value
                checks nothing.
        """
        if gradient_type == "grad":
            gradient_vars = self.grad_vars
        elif gradient_type == "dual":
            gradient_vars = self.dual_vars
        else:
            gradient_vars = []

        # dict.fromkeys de-duplicates while keeping registration order, so the
        # error message is deterministic.
        not_placed = dict.fromkeys(self._get_tb(_var) for _var in gradient_vars if _var.ptr.snode() is None)
        if not_placed:
            bar = "=" * 44 + "\n"
            raise RuntimeError(
                f"These field(s) require `needs_{gradient_type}=True`, however their {gradient_type} field(s) are not placed:\n{bar}"
                + bar.join(not_placed)
                + f"{bar}Please consider placing the {gradient_type} field(s). E.g.,"
                + f"\n\n ti.root.dense(ti.i, 1).place(x.{gradient_type})"
                + "\n\n Or specify a shape for the field(s). E.g.,"
                + f"\n\n x = ti.field(float, shape=(2, 3), needs_{gradient_type}=True)"
            )

    def _check_matrix_field_member_shape(self):
        """Raise if the scalar members of a registered matrix field differ in shape."""
        for _field in self.matrix_fields:
            shapes = [_field.get_scalar_field(i, j).shape for i in range(_field.n) for j in range(_field.m)]
            if any(shape != shapes[0] for shape in shapes):
                raise RuntimeError(
                    "Members of the following field have different shapes "
                    + f"{shapes}:\n{self._get_tb(_field._get_field_members()[0])}"
                )

    def _calc_matrix_field_dynamic_index_stride(self):
        """Ask every registered matrix field to compute its dynamic-index stride."""
        for _field in self.matrix_fields:
            _field._calc_dynamic_index_stride()

    def materialize(self):
        """Finalize root fields, validate registered fields, then clear the registrations."""
        self.materialize_root_fb(not self.materialized)
        self.materialized = True

        self.validate_fields_builder()

        self._check_field_not_placed()
        self._check_gradient_field_not_placed("grad")
        self._check_gradient_field_not_placed("dual")
        self._check_matrix_field_member_shape()
        self._calc_matrix_field_dynamic_index_stride()
        self.global_vars = []
        self.grad_vars = []
        self.dual_vars = []
        self.matrix_fields = []

    def _register_signal_handlers(self):
        """Create the native signal-handler registry once."""
        if self._signal_handler_registry is None:
            self._signal_handler_registry = _ti_core.HackedSignalRegister()

    def clear(self):
        """Finalize and drop the native Program and reset the materialized flag."""
        if self._prog:
            self._prog.finalize()
            self._prog = None
        self._signal_handler_registry = None
        self.materialized = False

    def sync(self):
        """Materialize pending fields, then synchronize the native Program.

        Raises:
            TaichiRuntimeError: if no Program exists (via the ``prog`` property,
                which, unlike an ``assert``, is not stripped under ``-O``).
        """
        self.materialize()
        self.prog.synchronize()
532
+
533
+
534
# Module-level runtime instance; `reset()` rebinds this name to a fresh PyTaichi.
pytaichi = PyTaichi()


def get_runtime() -> PyTaichi:
    """Return the module-level PyTaichi instance currently bound to ``pytaichi``."""
    runtime = pytaichi
    return runtime
539
+
540
+
541
def reset():
    """Replace the runtime with a fresh PyTaichi that keeps the known kernels.

    The old runtime is cleared first (finalizing its Program). Each kept
    kernel then has ``reset()`` called on it, and finally the default compile
    config is reset.
    """
    global pytaichi
    preserved_kernels = pytaichi.kernels
    pytaichi.clear()
    pytaichi = PyTaichi(preserved_kernels)
    for kernel in preserved_kernels:
        kernel.reset()
    _ti_core.reset_default_compile_config()
549
+
550
+
551
@taichi_scope
def static_print(*args, __p=print, **kwargs):
    """Print from Taichi scope at compile time.

    The call runs while the kernel is being compiled, so it adds no runtime
    overhead. All arguments are forwarded to ``__p``, which defaults to the
    built-in :func:`print`.
    """
    printer = __p
    printer(*args, **kwargs)
558
+
559
+
560
# we don't add @taichi_scope decorator for @ti.pyfunc to work
def static_assert(cond, msg=None):
    """Raise AssertionError when `cond` is False.

    This function is evaluated at compile time and has no runtime overhead.
    The truth value of `cond` must be determinable at compile time. Unlike a
    bare ``assert`` statement, the check is still performed when Python runs
    with optimizations enabled (``python -O``).

    Args:
        cond (bool): an expression with a compile-time bool value.
        msg (str, optional): assertion message.

    Raises:
        TaichiTypeError: if `cond` is a Taichi expression (not a static value).
        AssertionError: if `cond` is falsy; carries `msg` when one is given.

    Example::

        >>> year = 2001
        >>> @ti.kernel
        >>> def test():
        >>>     ti.static_assert(year % 4 == 0, "the year must be a leap year")
        AssertionError: the year must be a leap year
    """
    if isinstance(cond, Expr):
        raise TaichiTypeError("Static assert with non-static condition")
    if not cond:
        # Raise explicitly instead of using the ``assert`` statement, which is
        # stripped under ``python -O`` and would silently skip the check.
        if msg is not None:
            raise AssertionError(msg)
        raise AssertionError()
585
+
586
+
587
def inside_kernel():
    """Return the ``inside_kernel`` flag of the current module-wide runtime."""
    runtime = pytaichi
    return runtime.inside_kernel
589
+
590
+
591
def index_nd(dim):
    """Return ``axes(0, 1, ..., dim - 1)``: the first `dim` axes."""
    leading_axes = range(dim)
    return axes(*leading_axes)
593
+
594
+
595
+ class _UninitializedRootFieldsBuilder:
596
+ def __getattr__(self, item):
597
+ if item == "__qualname__":
598
+ # For sphinx docstring extraction.
599
+ return "_UninitializedRootFieldsBuilder"
600
+ raise TaichiRuntimeError("Please call init() first")
601
+
602
+
603
+ # `root` initialization must be delayed until after the program is
604
+ # created. Unfortunately, `root` exists in both taichi.lang.impl module and
605
+ # the top-level taichi module at this point; so if `root` itself is written, we
606
+ # would have to make sure that `root` in all the modules get updated to the same
607
+ # instance. This is an error-prone process.
608
+ #
609
+ # To avoid this situation, we create `root` once during the import time, and
610
+ # never write to it. The core part, `_root_fb`, is the one whose initialization
611
+ # gets delayed. `_root_fb` will only exist in the taichi.lang.impl module, so
612
+ # writing to it results in lower maintenance cost.
613
+ #
614
+ # `_root_fb` will be overridden inside :func:`taichi.lang.init`.
615
# Placeholder that raises on use until :func:`taichi.lang.init` overrides it
# with a real FieldsBuilder.
_root_fb = _UninitializedRootFieldsBuilder()
616
+
617
+
618
def deactivate_all_snodes():
    """Recursively deactivate all SNodes of every finalized root FieldsBuilder."""
    finalized = FieldsBuilder._finalized_roots()
    for fb in finalized:
        fb.deactivate_all()
622
+
623
+
624
+ class _Root:
625
+ """Wrapper around the default root FieldsBuilder instance."""
626
+
627
+ @staticmethod
628
+ def parent(n=1):
629
+ """Same as :func:`taichi.SNode.parent`"""
630
+ assert isinstance(_root_fb, FieldsBuilder)
631
+ return _root_fb.root.parent(n)
632
+
633
+ @staticmethod
634
+ def _loop_range():
635
+ """Same as :func:`taichi.SNode.loop_range`"""
636
+ assert isinstance(_root_fb, FieldsBuilder)
637
+ return _root_fb.root._loop_range()
638
+
639
+ @staticmethod
640
+ def _get_children():
641
+ """Same as :func:`taichi.SNode.get_children`"""
642
+ assert isinstance(_root_fb, FieldsBuilder)
643
+ return _root_fb.root._get_children()
644
+
645
+ # TODO: Record all of the SNodeTrees that finalized under 'ti.root'
646
+ @staticmethod
647
+ def deactivate_all():
648
+ warning("""'ti.root.deactivate_all()' would deactivate all finalized snodes.""")
649
+ deactivate_all_snodes()
650
+
651
+ @property
652
+ def shape(self):
653
+ """Same as :func:`taichi.SNode.shape`"""
654
+ assert isinstance(_root_fb, FieldsBuilder)
655
+ return _root_fb.root.shape
656
+
657
+ @property
658
+ def _id(self):
659
+ assert isinstance(_root_fb, FieldsBuilder)
660
+ return _root_fb.root._id
661
+
662
+ def __getattr__(self, item):
663
+ return getattr(_root_fb, item)
664
+
665
+ def __repr__(self):
666
+ return "ti.root"
667
+
668
+
669
# Created once at import time and never rebound; the builder behind it
# (``_root_fb``) is what gets replaced during initialization.
root = _Root()
"""Root of the declared Taichi :func:`~taichi.lang.impl.field`s.

See also https://docs.taichi-lang.org/docs/layout

Example::

    >>> x = ti.field(ti.f32)
    >>> ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(x)
"""
679
+
680
+
681
def _create_snode(axis_seq: Sequence[int], shape_seq: Sequence[numbers.Number], same_level: bool):
    """Build dense SNode(s) under ``root`` for the given axes and extents.

    Args:
        axis_seq: axis indices in layout order.
        shape_seq: extent for each entry of ``axis_seq`` (same length).
        same_level: if True, create a single multi-axis ``dense`` node.
            Otherwise nest one single-axis ``dense`` node per axis, following
            the order of ``axis_seq``.

    Returns:
        The innermost SNode created (callers ``place`` fields on it).
    """
    assert len(axis_seq) == len(shape_seq)
    if same_level:
        return root.dense(axes(*axis_seq), shape_seq)
    node = root
    for axis, extent in zip(axis_seq, shape_seq):
        node = node.dense(axes(axis), (extent,))
    return node
691
+
692
+
693
@python_scope
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the primal field expression and, for real dtypes, its adjoint/dual companions.

    The primal is always appended to ``pytaichi.global_vars``. When
    ``_ti_core.is_real(dtype)`` holds:

    - an adjoint (``<name>.grad``) and a dual (``<name>.dual``) expression are
      created and linked to the primal;
    - they are queued for placement only when ``needs_grad`` / ``needs_dual``
      is set;
    - in debug mode an adjoint checkbit field (``<name>.grad_checkbit``) is
      also attached, for the global data access rule checker.

    Args:
        dtype: element data type (anything accepted by ``cook_dtype``).
        name (str): name of the primal field.
        needs_grad (bool): queue the adjoint field for placement.
        needs_dual (bool): queue the dual field for placement.

    Returns:
        tuple: ``(x, x_grad, x_dual)``. ``x_grad`` and ``x_dual`` are ``None``
        when the dtype is not real.

    Raises:
        TaichiRuntimeError: if ``needs_grad`` or ``needs_dual`` is requested
            for a non-real dtype.
    """
    dtype = cook_dtype(dtype)

    # primal
    prog = get_runtime().prog

    x = Expr(prog.make_id_expr(""))
    x.declaration_tb = get_traceback(stacklevel=4)
    x.ptr = _ti_core.expr_field(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_grad_type(SNodeGradType.PRIMAL)
    pytaichi.global_vars.append(x)

    x_grad = None
    x_dual = None
    # The x_grad_checkbit is used for global data access rule checker
    x_grad_checkbit = None
    if _ti_core.is_real(dtype):
        # adjoint
        x_grad = Expr(prog.make_id_expr(""))
        x_grad.declaration_tb = get_traceback(stacklevel=4)
        x_grad.ptr = _ti_core.expr_field(x_grad.ptr, dtype)
        x_grad.ptr.set_name(name + ".grad")
        x_grad.ptr.set_grad_type(SNodeGradType.ADJOINT)
        x.ptr.set_adjoint(x_grad.ptr)
        if needs_grad:
            pytaichi.grad_vars.append(x_grad)

        if prog.config().debug:
            # adjoint checkbit. Use a dedicated variable: reassigning ``dtype``
            # here would make the dual field below use u8/i32 instead of the
            # field's real dtype.
            x_grad_checkbit = Expr(prog.make_id_expr(""))
            checkbit_dtype = u8
            if prog.config().arch in (_ti_core.opengl, _ti_core.vulkan, _ti_core.gles):
                # NOTE(review): presumably these backends lack u8 field support — confirm.
                checkbit_dtype = i32
            x_grad_checkbit.ptr = _ti_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
            x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
            x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
            x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)

        # dual
        x_dual = Expr(prog.make_id_expr(""))
        # Record where the dual was declared, like the primal and adjoint, so
        # "not placed" diagnostics can point at the user's declaration.
        x_dual.declaration_tb = get_traceback(stacklevel=4)
        x_dual.ptr = _ti_core.expr_field(x_dual.ptr, dtype)
        x_dual.ptr.set_name(name + ".dual")
        x_dual.ptr.set_grad_type(SNodeGradType.DUAL)
        x.ptr.set_dual(x_dual.ptr)
        if needs_dual:
            pytaichi.dual_vars.append(x_dual)
    elif needs_grad or needs_dual:
        raise TaichiRuntimeError(f"{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.")

    return x, x_grad, x_dual
745
+
746
+
747
@python_scope
def _field(
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
):
    """Declare a scalar field and, when ``shape`` is given, place it under ``root``.

    With ``shape=None`` the field (and any grad/dual companion) is only
    declared; it must be placed explicitly before materialization.

    Args:
        dtype: scalar element type.
        shape (int | tuple[int], optional): extent of each axis.
        order (str, optional): memory layout order written as axis letters
            counted from ``"i"`` (e.g. ``"ji"``). Each axis then gets its own
            nested ``dense`` node in that order. When omitted, all axes share
            a single ``dense`` node.
        name (str): field name.
        offset (int | tuple[int], optional): index offset per axis.
        needs_grad (bool): also place the adjoint (grad) field.
        needs_dual (bool): also place the dual field.

    Returns:
        ScalarField: the primal field.

    Raises:
        TaichiSyntaxError: when ``offset``/``order`` is given without
            ``shape``, or their dimensionality or axes are inconsistent with
            ``shape``.
        TaichiRuntimeError: (from ``create_field_member``) when grad/dual is
            requested for a non-real dtype.
    """
    x, x_grad, x_dual = create_field_member(dtype, name, needs_grad, needs_dual)
    x = ScalarField(x)
    # Test for presence with ``is not None``: the companions are Expr objects,
    # and their truthiness is not the intended existence check.
    if x_grad is not None:
        x_grad = ScalarField(x_grad)
        x._set_grad(x_grad)
    if x_dual is not None:
        x_dual = ScalarField(x_dual)
        x._set_dual(x_dual)

    if shape is None:
        if offset is not None:
            raise TaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise TaichiSyntaxError("shape cannot be None when order is set")
    else:
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        if isinstance(offset, numbers.Number):
            offset = (offset,)
        dim = len(shape)
        if offset is not None and dim != len(offset):
            raise TaichiSyntaxError(f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})")
        axis_seq = []
        shape_seq = []
        if order is not None:
            if dim != len(order):
                raise TaichiSyntaxError(
                    f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
                )
            if dim != len(set(order)):
                raise TaichiSyntaxError("The axes in order must be different")
            for ch in order:
                # Axis letters are counted from "i": "i" -> 0, "j" -> 1, ...
                axis = ord(ch) - ord("i")
                if axis < 0 or axis >= dim:
                    raise TaichiSyntaxError(f"Invalid axis {ch}")
                axis_seq.append(axis)
                shape_seq.append(shape[axis])
        else:
            axis_seq = list(range(dim))
            shape_seq = list(shape)
        same_level = order is None
        _create_snode(axis_seq, shape_seq, same_level).place(x, offset=offset)
        # Each requested companion gets its own identically laid-out SNode(s).
        if needs_grad:
            _create_snode(axis_seq, shape_seq, same_level).place(x_grad, offset=offset)
        if needs_dual:
            _create_snode(axis_seq, shape_seq, same_level).place(x_dual, offset=offset)
    return x
804
+
805
+
806
@python_scope
def field(dtype, *args, **kwargs):
    """Defines a Taichi field.

    A Taichi field can be viewed as an abstract N-dimensional array that hides
    how its underlying :class:`~taichi.lang.snode.SNode` tree is actually
    defined. The data in a Taichi field can be accessed directly by a Taichi
    :func:`~taichi.lang.kernel_impl.kernel`.

    Dispatch:
    - vector types (``MatrixType`` with ``ndim == 1``) go to ``Vector.field``;
    - other matrix types go to ``Matrix.field``;
    - scalar dtypes go to ``_field``.

    The remaining arguments are forwarded unchanged.

    See also https://docs.taichi-lang.org/docs/field

    Args:
        dtype (DataType): data type of the field. Note it can be vector or matrix types as well.
        shape (Union[int, tuple[int]], optional): shape of the field.
        order (str, optional): order of the shape laid out in memory.
        name (str, optional): name of the field.
        offset (Union[int, tuple[int]], optional): offset of the field domain.
        needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
            and thus needs an adjoint field to store the gradients.
        needs_dual (bool, optional): whether this field participates in autodiff (forward mode)
            and thus needs a dual field to store the gradients.

    Example::

        The code below shows how a Taichi field can be declared and defined::

            >>> x1 = ti.field(ti.f32, shape=(16, 8))
            >>> # Equivalently
            >>> x2 = ti.field(ti.f32)
            >>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2)
            >>>
            >>> x3 = ti.field(ti.f32, shape=(16, 8), order='ji')
            >>> # Equivalently
            >>> x4 = ti.field(ti.f32)
            >>> ti.root.dense(ti.j, shape=8).dense(ti.i, shape=16).place(x4)
            >>>
            >>> x5 = ti.field(ti.math.vec3, shape=(16, 8))

    """
    if not isinstance(dtype, MatrixType):
        return _field(dtype, *args, **kwargs)
    if dtype.ndim == 1:
        return Vector.field(dtype.n, dtype.dtype, *args, **kwargs)
    return Matrix.field(dtype.n, dtype.m, dtype.dtype, *args, **kwargs)
850
+
851
+
852
@python_scope
def ndarray(dtype, shape, needs_grad=False):
    """Defines a Taichi ndarray with scalar, vector or matrix elements.

    Args:
        dtype (Union[DataType, MatrixType]): Data type of each element. This can be either a scalar type like ti.f32 or a compound type like ti.types.vector(3, ti.i32).
        shape (Union[int, tuple[int]]): Shape of the ndarray. Every extent must
            be a Python or NumPy integer in ``[1, 2**31 - 1]``.
        needs_grad (bool, optional): also allocate a gradient ndarray with the
            same dtype and shape and attach it via ``_set_grad``. Only real
            element types are accepted.

    Raises:
        TaichiRuntimeError: on an invalid shape, an unsupported element type,
            or ``needs_grad=True`` with a non-real element type.

    Example:
        The code below shows how a Taichi ndarray can be declared and defined::

            >>> x = ti.ndarray(ti.f32, shape=(16, 8))  # ndarray of shape (16, 8), each element is ti.f32 scalar.
            >>> vec3 = ti.types.vector(3, ti.i32)
            >>> y = ti.ndarray(vec3, shape=(10, 2))  # ndarray of shape (10, 2), each element is a vector of 3 ti.i32 scalars.
            >>> matrix_ty = ti.types.matrix(3, 4, float)
            >>> z = ti.ndarray(matrix_ty, shape=(4, 5))  # ndarray of shape (4, 5), each element is a matrix of (3, 4) ti.float scalars.
    """
    if isinstance(shape, numbers.Number):
        shape = (shape,)

    def _is_valid_extent(extent):
        return (isinstance(extent, int) or isinstance(extent, np.integer)) and 0 < extent <= 2**31 - 1

    if not all(_is_valid_extent(extent) for extent in shape):
        raise TaichiRuntimeError(f"{shape} is not a valid shape for ndarray")

    # primal
    if dtype in all_types:
        elem_dtype = cook_dtype(dtype)
        arr = ScalarNdarray(elem_dtype, shape)
    elif isinstance(dtype, MatrixType):
        if dtype.ndim == 1:
            arr = VectorNdarray(dtype.n, dtype.dtype, shape)
        else:
            arr = MatrixNdarray(dtype.n, dtype.m, dtype.dtype, shape)
        elem_dtype = dtype.dtype
    else:
        raise TaichiRuntimeError(f"{dtype} is not supported as ndarray element type")

    if not needs_grad:
        return arr
    assert isinstance(elem_dtype, DataType)
    if not _ti_core.is_real(elem_dtype):
        raise TaichiRuntimeError(f"{elem_dtype} is not supported for ndarray with `needs_grad=True` or `needs_dual=True`.")
    arr._set_grad(ndarray(dtype, shape, needs_grad=False))
    return arr
892
+
893
+
894
@taichi_scope
def ti_format_list_to_content_entries(raw):
    """Flatten print/format arguments into parallel ``(contents, formats)`` lists.

    ``raw`` is an iterable of print arguments. Accepted kinds:

    - plain strings;
    - Taichi expressions;
    - ``["__ti_fmt_value__", expr, fmt]`` triples;
    - :func:`ti_format` results (lists starting with ``"__ti_format__"``);
    - objects with ``__ti_repr__``;
    - (nested) lists/tuples.

    Arguments are expanded recursively and adjacent string pieces are merged.
    Each remaining entry becomes a ``[content, format]`` pair, where
    ``content`` is a string or a C++ expression pointer and ``format`` is a
    format string or ``None``.

    Returns:
        tuple[list, list]: ``(contents, formats)``. Both are empty when
        nothing printable remains (e.g. ``print("", end="")``).
    """

    def entry2content(_var):
        # Convert one fused entry into a [content, format] pair.
        if isinstance(_var, str):
            return [_var, None]
        if isinstance(_var, list):
            # Expected to be the [expr, format] tail of a "__ti_fmt_value__" triple.
            assert len(_var) == 2 and (isinstance(_var[1], str) or _var[1] is None)
            _var[0] = Expr(_var[0]).ptr
            return _var
        return [Expr(_var).ptr, None]

    def list_ti_repr(_var):
        # Render a Python list/tuple as "[a, b, ...]" one token at a time.
        yield "["  # distinguishing tuple & list will increase maintenance cost
        for i, v in enumerate(_var):
            if i:
                yield ", "
            yield v
        yield "]"

    def vars2entries(_vars):
        # Recursively expand composite arguments into flat entries.
        for _var in _vars:
            # If the first element is '__ti_fmt_value__', this list is an Expr and its format.
            if isinstance(_var, list) and len(_var) == 3 and isinstance(_var[0], str) and _var[0] == "__ti_fmt_value__":
                # yield [Expr, format] as a whole and don't pass it to vars2entries() again
                yield _var[1:]
                continue
            elif hasattr(_var, "__ti_repr__"):
                res = _var.__ti_repr__()
            elif isinstance(_var, (list, tuple)):
                # If the first element is '__ti_format__', this list is the result of ti_format.
                if len(_var) > 0 and isinstance(_var[0], str) and _var[0] == "__ti_format__":
                    res = _var[1:]
                else:
                    res = list_ti_repr(_var)
            else:
                yield _var
                continue

            yield from vars2entries(res)

    def fused_string(entries):
        # Merge runs of adjacent strings into one; empty strings vanish.
        accumulated = ""
        for entry in entries:
            if isinstance(entry, str):
                accumulated += entry
            else:
                if accumulated:
                    yield accumulated
                    accumulated = ""
                yield entry
        if accumulated:
            yield accumulated

    def extract_formats(entries):
        # ``zip(*[])`` yields nothing, so unpacking it into two names would
        # raise ValueError; treat "nothing to print" as two empty lists.
        if not entries:
            return [], []
        contents, formats = zip(*entries)
        return list(contents), list(formats)

    entries = vars2entries(raw)
    entries = fused_string(entries)
    entries = [entry2content(entry) for entry in entries]
    return extract_formats(entries)
957
+
958
+
959
@taichi_scope
def ti_print(*_vars, sep=" ", end="\n"):
    """Emit a print statement into the callable currently being compiled.

    Mirrors Python's :func:`print`: ``sep`` goes between arguments and
    ``end`` after the last one.
    """

    def with_separators(items):
        first = True
        for item in items:
            if not first:
                yield sep
            first = False
            yield item
        yield end

    contents, formats = ti_format_list_to_content_entries(with_separators(_vars))
    runtime = get_runtime()
    compiling_callable = runtime.compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    builder.create_print(contents, formats, _ti_core.DebugInfo(runtime.get_current_src_info()))
975
+
976
+
977
@taichi_scope
def ti_format(*args):
    """Build the token list for a Taichi-scope ``str.format`` call.

    ``args[0]`` is the format string and the rest are positional values.
    Values that are not Taichi expressions are substituted into the string
    eagerly. Each :class:`Expr` (or ``["__ti_fmt_value__", expr, fmt]``
    triple) is kept as a placeholder instead.

    Returns:
        list: ``["__ti_format__", text0, expr0, text1, expr1, ..., textN]``.
    """
    template = args[0]
    values = args[1:]
    substitutions = []
    exprs = []
    for value in values:
        # value is a (formatted) Expr
        is_expr = isinstance(value, Expr) or (
            isinstance(value, list) and len(value) == 3 and value[0] == "__ti_fmt_value__"
        )
        if is_expr:
            substitutions.append("{}")
            exprs.append(value)
        else:
            substitutions.append(value)
    pieces = template.format(*substitutions).split("{}")
    assert len(pieces) == len(exprs) + 1, "Number of args is different from number of positions provided in string"

    result = ["__ti_format__"]
    for idx, piece in enumerate(pieces):
        if idx:
            result.append(exprs[idx - 1])
        result.append(piece)
    return result
998
+
999
+
1000
@taichi_scope
def ti_assert(cond, msg, extra_args, dbg_info):
    """Emit a runtime assertion into the callable currently being compiled.

    Mostly a wrapper that converts `cond` from a Python-side :class:`Expr` to
    its C++ ``_ti_core.Expr`` pointer before handing it to the AST builder.
    """
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    builder.create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
1007
+
1008
+
1009
@taichi_scope
def ti_int(_var):
    """Taichi-scope ``int()``: prefer ``_var.__ti_int__()``, else the builtin."""
    if not hasattr(_var, "__ti_int__"):
        return int(_var)
    return _var.__ti_int__()
1014
+
1015
+
1016
@taichi_scope
def ti_bool(_var):
    """Taichi-scope ``bool()``: prefer ``_var.__ti_bool__()``, else the builtin."""
    if not hasattr(_var, "__ti_bool__"):
        return bool(_var)
    return _var.__ti_bool__()
1021
+
1022
+
1023
@taichi_scope
def ti_float(_var):
    """Taichi-scope ``float()``: prefer ``_var.__ti_float__()``, else the builtin."""
    if not hasattr(_var, "__ti_float__"):
        return float(_var)
    return _var.__ti_float__()
1028
+
1029
+
1030
@taichi_scope
def zero(x):
    """Return ``x * 0``: zeros with the same shape and type as the input.

    The result is a scalar when the input is a scalar.

    Args:
        x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with zeros.

    Example::

        >>> x = ti.Vector([1, 1])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.zero(x)
        >>>     print(y)
        [0, 0]
    """
    # TODO: get dtype from Expr and Matrix:
    zeros = x * 0
    return zeros
1052
+
1053
+
1054
@taichi_scope
def one(x):
    """Return ``zero(x) + 1``: ones with the same shape and type as the input.

    The result is a scalar when the input is a scalar.

    Args:
        x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with ones.

    Example::

        >>> x = ti.Vector([0, 0])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.one(x)
        >>>     print(y)
        [1, 1]
    """
    zeros = zero(x)
    return zeros + 1
1075
+
1076
+
1077
def axes(*x: int):
    """Build a list of axes to be used by a field.

    Args:
        *x: integer indices of the axes to be activated.

    Returns:
        list: one ``_ti_core.Axis`` per index, in the given order.

    Note that Taichi already provides a set of commonly used axes. For
    example, `ti.ij` is just `axes(0, 1)` under the hood.
    """
    return list(map(_ti_core.Axis, x))
1087
+
1088
+
1089
# Module-level alias for the native axis type; ``axes()`` returns lists of these.
Axis = _ti_core.Axis
1090
+
1091
+
1092
def static(x, *xs) -> Any:
    """Evaluates a Taichi-scope expression at compile time.

    `static()` is what enables the so-called metaprogramming in Taichi. It is
    in many ways similar to ``constexpr`` in C++. Compile-time values pass
    through unchanged. Accepted values:

    - Python scalars and ``None``;
    - ranges, lists and tuples;
    - the ``enumerate``/``zip``/``filter``/``map`` iterators;
    - Taichi ndranges;
    - NumPy scalars;
    - ndarrays and fields;
    - functions and methods.

    Anything else raises :class:`ValueError`.

    See also https://docs.taichi-lang.org/docs/meta.

    Args:
        x (Any): an expression to be evaluated
        *xs (Any): for Python-ish swapping assignment

    Example:
        The most common usage of `static()` is for compile-time evaluation::

            >>> cond = False
            >>>
            >>> @ti.kernel
            >>> def run():
            >>>     if ti.static(cond):
            >>>         do_a()
            >>>     else:
            >>>         do_b()

        Depending on the value of ``cond``, ``run()`` will be directly compiled
        into either ``do_a()`` or ``do_b()``. Thus there won't be a runtime
        condition check.

        Another common usage is for compile-time loop unrolling::

            >>> @ti.kernel
            >>> def run():
            >>>     for i in ti.static(range(3)):
            >>>         print(i)
            >>>
            >>> # The above will be unrolled to:
            >>> @ti.kernel
            >>> def run():
            >>>     print(0)
            >>>     print(1)
            >>>     print(2)
    """
    if xs:  # for python-ish pointer assign: x, y = ti.static(y, x)
        return [static(x)] + [static(extra) for extra in xs]

    compile_time_types = (
        bool,
        int,
        float,
        range,
        list,
        tuple,
        enumerate,
        GroupedNDRange,
        _Ndrange,
        zip,
        filter,
        map,
    )
    if isinstance(x, compile_time_types) or x is None:
        return x

    passthrough_types = (np.bool_, np.integer, np.floating, AnyArray, Field, FunctionType, MethodType)
    if isinstance(x, passthrough_types):
        return x
    raise ValueError(f"Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}")
1168
+
1169
+
1170
@taichi_scope
def grouped(x):
    """Groups the indices in the iterator returned by `ndrange()` into a 1-D vector.

    This is often used when you want to iterate over all indices returned by
    `ndrange()` in one `for` loop with a single index. Arguments that are not
    ndrange objects are returned unchanged.

    Args:
        x (:func:`~taichi.ndrange`): an iterator object returned by `ti.ndrange`.

    Example::

        >>> # without ti.grouped
        >>> for I in ti.ndrange(2, 3):
        >>>     print(I)
        prints 0, 1, 2, 3, 4, 5

        >>> # with ti.grouped
        >>> for I in ti.grouped(ti.ndrange(2, 3)):
        >>>     print(I)
        prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
    """
    if not isinstance(x, _Ndrange):
        return x
    return x.grouped()
1194
+
1195
+
1196
def stop_grad(x):
    """Stops computing gradients during back propagation.

    Emits a ``stop_grad`` on ``x.snode`` into the callable being compiled.

    Args:
        x (:class:`~taichi.Field`): A field.
    """
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    builder.stop_grad(x.snode.ptr)
1205
+
1206
+
1207
def current_cfg():
    """Return the compile config of the active program (``prog.config()``)."""
    program = get_runtime().prog
    return program.config()
1209
+
1210
+
1211
def default_cfg():
    """Return the default compile config reported by the native core."""
    cfg = _ti_core.default_compile_config()
    return cfg
1213
+
1214
+
1215
def call_internal(name, *args, with_runtime_context=True):
    """Insert a call to the internal op ``_ti_core.InternalOp.<name>``.

    Args:
        name (str): attribute name on ``_ti_core.InternalOp``.
        *args: call arguments, bundled with ``make_expr_group``.
        with_runtime_context (bool): accepted for compatibility; this
            implementation does not use it.

    Returns:
        The result of ``expr_init`` applied to the inserted call expression.
    """
    internal_op = getattr(_ti_core.InternalOp, name)
    call_expr = _ti_core.insert_internal_func_call(internal_op, make_expr_group(args))
    return expr_init(call_expr)
1217
+
1218
+
1219
def get_cuda_compute_capability():
    """Return the native ``cuda_compute_capability`` int64 query result."""
    capability = _ti_core.query_int64("cuda_compute_capability")
    return capability
1221
+
1222
+
1223
@taichi_scope
def mesh_relation_access(mesh, from_index, to_element_type):
    """Resolve a mesh relation access in Taichi scope.

    - If `from_index` is itself a :class:`MeshInstance`, return its attribute
      named after `to_element_type`.
    - Otherwise, if `mesh` is a :class:`MeshInstance`, return a
      :class:`MeshRelationAccessProxy`.

    Raises:
        RuntimeError: if neither argument is a mesh instance.
    """
    # to support ti.mesh_local and access mesh attribute as field
    if isinstance(from_index, MeshInstance):
        return getattr(from_index, element_type_name(to_element_type))
    if not isinstance(mesh, MeshInstance):
        raise RuntimeError("Relation access should be with a mesh instance!")
    return MeshRelationAccessProxy(mesh, from_index, to_element_type)
1231
+
1232
+
1233
# Names exported by a star-import of this module.
__all__ = [
    "axes",
    "deactivate_all_snodes",
    "field",
    "grouped",
    "ndarray",
    "one",
    "root",
    "static",
    "static_assert",
    "static_print",
    "stop_grad",
    "zero",
]