gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/ops.py ADDED
@@ -0,0 +1,1494 @@
1
+ # type: ignore
2
+
3
+ import builtins
4
+ import functools
5
+ import operator as _bt_ops_mod # bt for builtin
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+
10
+ from gstaichi._lib import core as _ti_core
11
+ from gstaichi.lang import expr, impl
12
+ from gstaichi.lang.exception import GsTaichiSyntaxError
13
+ from gstaichi.lang.field import Field
14
+ from gstaichi.lang.util import (
15
+ cook_dtype,
16
+ gstaichi_scope,
17
+ is_gstaichi_class,
18
+ is_matrix_class,
19
+ )
20
+
21
+
22
def stack_info():
    """Return the runtime's current source-location info.

    Thin accessor over ``impl.get_runtime().get_current_src_info()``; the
    operation helpers in this module wrap its result in ``DebugInfo``.
    """
    runtime = impl.get_runtime()
    return runtime.get_current_src_info()
24
+
25
+
26
def is_gstaichi_expr(a):
    """Tell whether ``a`` is a GsTaichi expression (an ``expr.Expr`` instance)."""
    expr_type = expr.Expr
    return isinstance(a, expr_type)
28
+
29
+
30
def wrap_if_not_expr(a):
    """Return ``a`` unchanged if it is already an ``Expr``; otherwise wrap it.

    The new ``Expr`` is given a ``DebugInfo`` built from the runtime's current
    source-location info.
    """
    if is_gstaichi_expr(a):
        return a
    src_info = impl.get_runtime().get_current_src_info()
    return expr.Expr(a, dbg_info=_ti_core.DebugInfo(src_info))
36
+
37
+
38
def _read_matrix_or_scalar(x):
    """Convert a matrix operand to its numpy array; pass any other value through."""
    return x.to_numpy() if is_matrix_class(x) else x
42
+
43
+
44
def writeback_binary(foo):
    """Decorate a binary op whose first operand is a write target.

    The returned wrapper:
      * returns ``NotImplemented`` when either operand is a ``Field``;
      * raises ``GsTaichiSyntaxError`` unless ``a`` is an ``Expr`` whose
        pointer reports ``is_lvalue()``;
      * otherwise calls ``foo(a, b)`` with ``b`` wrapped into an ``Expr`` if needed.
    """

    @functools.wraps(foo)
    def wrapped(a, b):
        if isinstance(a, Field) or isinstance(b, Field):
            return NotImplemented
        writable = is_gstaichi_expr(a) and a.ptr.is_lvalue()
        if not writable:
            raise GsTaichiSyntaxError(f"cannot use a non-writable target as the first operand of '{foo.__name__}'")
        rhs = wrap_if_not_expr(b)
        return foo(a, rhs)

    return wrapped
54
+
55
+
56
def cast(obj, dtype):
    """Copy and cast a scalar or a matrix to a specified data type.
    Must be called in GsTaichi scope.

    Args:
        obj (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

        dtype (:mod:`~gstaichi.types.primitive_types`): A primitive type defined in
            :mod:`~gstaichi.types.primitive_types`.

    Returns:
        A copy of `obj`, cast to the specified data type `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([0, 1, 2], ti.i32)
        >>>     y = ti.cast(x, ti.f32)
        >>>     print(y)
        >>>
        >>> test()
        [0.0, 1.0, 2.0]
    """
    target = cook_dtype(dtype)
    if not is_gstaichi_class(obj):
        source = expr.Expr(obj)
        return expr.Expr(_ti_core.value_cast(source.ptr, target))
    # GsTaichi classes (e.g. matrices) provide their own element-wise cast.
    # TODO: unify with element_wise_unary
    return obj.cast(target)
85
+
86
+
87
def bit_cast(obj, dtype):
    """Copy and cast a scalar to a specified data type, preserving its underlying
    bits. Must be called in GsTaichi scope.

    This is the equivalent of `reinterpret_cast` in C++.

    Args:
        obj (:mod:`~gstaichi.types.primitive_types`): Input scalar.

        dtype (:mod:`~gstaichi.types.primitive_types`): Target data type. It must have
            the same bit width as the input (so `f32` -> `f64` is not allowed).

    Returns:
        A copy of `obj`, reinterpreted as the data type `dtype`.

    Raises:
        ValueError: If `obj` is a GsTaichi class (e.g. a matrix).

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 3.14
        >>>     y = ti.bit_cast(x, ti.i32)
        >>>     print(y)  # 1078523331
        >>>
        >>>     z = ti.bit_cast(y, ti.f32)
        >>>     print(z)  # 3.14
    """
    target = cook_dtype(dtype)
    if is_gstaichi_class(obj):
        raise ValueError("Cannot apply bit_cast on GsTaichi classes")
    source = expr.Expr(obj)
    return expr.Expr(_ti_core.bits_cast(source.ptr, target))
118
+
119
+
120
def _unary_operation(gstaichi_op, python_op, a):
    """Apply a unary operation in either GsTaichi or Python scope.

    - ``Field`` operand: returns ``NotImplemented``.
    - ``Expr`` operand: builds a new ``Expr`` from ``gstaichi_op(a.ptr)``, tagged
      with the current source location.
    - ``Matrix`` operand: applies ``python_op`` to ``a.to_numpy()`` and wraps the
      result back into a ``Matrix``.
    - anything else: returns ``python_op(a)``.
    """
    if isinstance(a, Field):
        return NotImplemented
    if not is_gstaichi_expr(a):
        from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

        if isinstance(a, Matrix):
            return Matrix(python_op(a.to_numpy()))
        return python_op(a)
    node = gstaichi_op(a.ptr)
    return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))
130
+
131
+
132
def _binary_operation(gstaichi_op, python_op, a, b):
    """Apply a binary operation in either GsTaichi or Python scope.

    - either operand is a ``Field``: returns ``NotImplemented``.
    - either operand is an ``Expr``: both are wrapped into ``Expr`` and combined
      with ``gstaichi_op``, tagged with the current source location.
    - either operand is a ``Matrix``: ``python_op`` runs on the numpy views and
      the result is wrapped back into a ``Matrix``.
    - otherwise: returns ``python_op(a, b)``.
    """
    if isinstance(a, Field) or isinstance(b, Field):
        return NotImplemented
    if not (is_gstaichi_expr(a) or is_gstaichi_expr(b)):
        from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

        if isinstance(a, Matrix) or isinstance(b, Matrix):
            lhs = _read_matrix_or_scalar(a)
            rhs = _read_matrix_or_scalar(b)
            return Matrix(python_op(lhs, rhs))
        return python_op(a, b)
    lhs_expr = wrap_if_not_expr(a)
    rhs_expr = wrap_if_not_expr(b)
    node = gstaichi_op(lhs_expr.ptr, rhs_expr.ptr)
    return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))
143
+
144
+
145
def _ternary_operation(gstaichi_op, python_op, a, b, c):
    """Apply a three-operand operation in either GsTaichi or Python scope.

    Dispatch mirrors ``_binary_operation``: ``Field`` operands yield
    ``NotImplemented``, any ``Expr`` operand routes all three through
    ``gstaichi_op``, any ``Matrix`` operand routes the numpy views through
    ``python_op`` and rewraps the result, and plain values go to ``python_op``.
    """
    operands = (a, b, c)
    if any(isinstance(v, Field) for v in operands):
        return NotImplemented
    if not any(is_gstaichi_expr(v) for v in operands):
        from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

        if any(isinstance(v, Matrix) for v in operands):
            values = [_read_matrix_or_scalar(v) for v in operands]
            return Matrix(python_op(*values))
        return python_op(a, b, c)
    wrapped = [wrap_if_not_expr(v) for v in operands]
    node = gstaichi_op(wrapped[0].ptr, wrapped[1].ptr, wrapped[2].ptr)
    return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))
162
+
163
+
164
def neg(x):
    """Numerical negative, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        Matrix or scalar `y`, so that `y = -x`. `y` has the same type as `x`.

    Example::

        >>> x = ti.Matrix([1, -1])
        >>> y = ti.neg(x)
        >>> y
        [-1, 1]
    """
    return _unary_operation(_ti_core.expr_neg, _bt_ops_mod.neg, x)
181
+
182
+
183
def sin(x):
    """Compute the trigonometric sine of `x`, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Angle(s) in radians.

    Returns:
        The sine of every element of `x`.

    Example::

        >>> from math import pi
        >>> x = ti.Matrix([-pi/2., 0, pi/2.])
        >>> ti.sin(x)
        [-1., 0., 1.]
    """
    kernel_op, host_op = _ti_core.expr_sin, np.sin
    return _unary_operation(kernel_op, host_op, x)
201
+
202
+
203
def cos(x):
    """Trigonometric cosine, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Angle, in radians.

    Returns:
        The cosine of each element of `x`.

    Example::

        >>> from math import pi
        >>> x = ti.Matrix([-pi, 0, pi/2.])
        >>> ti.cos(x)
        [-1., 1., 0.]
    """
    return _unary_operation(_ti_core.expr_cos, np.cos, x)
221
+
222
+
223
def asin(x):
    """Trigonometric inverse sine, element-wise.

    The inverse of `sin` so that, if `y = sin(x)`, then `x = asin(y)`.

    For input `x` outside the domain `[-1, 1]` the result is not a real number.
    This function does not check the domain itself. In Python scope the value
    comes from ``numpy.arcsin``, which returns `nan` and emits a RuntimeWarning
    rather than raising. In GsTaichi scope the result is `nan`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix with elements in [-1, 1].

    Returns:
        The inverse sine of each element in `x`, in radians and in the closed
        interval `[-pi/2, pi/2]`.

    Example::

        >>> from math import pi
        >>> ti.asin(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi
        [-90., 0., 90.]
    """
    return _unary_operation(_ti_core.expr_asin, np.arcsin, x)
246
+
247
+
248
def acos(x):
    """Trigonometric inverse cosine, element-wise.

    The inverse of `cos` so that, if `y = cos(x)`, then `x = acos(y)`.

    For input `x` outside the domain `[-1, 1]` the result is not a real number.
    This function does not check the domain itself. In Python scope the value
    comes from ``numpy.arccos``, which returns `nan` and emits a RuntimeWarning
    rather than raising. In GsTaichi scope the result is `nan`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix with elements in [-1, 1].

    Returns:
        The inverse cosine of each element in `x`, in radians and in the closed
        interval `[0, pi]`. This is a scalar if `x` is a scalar.

    Example::

        >>> from math import pi
        >>> ti.acos(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi
        [180., 90., 0.]
    """
    return _unary_operation(_ti_core.expr_acos, np.arccos, x)
271
+
272
+
273
def sqrt(x):
    """Return the non-negative square-root of a scalar or a matrix, element-wise.

    This function does not check the sign of its input. For `x < 0`, Python
    scope uses ``numpy.sqrt``, which returns `nan` and emits a RuntimeWarning
    instead of raising. In GsTaichi scope the result for negative inputs is
    backend-defined (NOTE(review): presumably `nan` — confirm).

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The scalar or matrix whose square-roots are required.

    Returns:
        The square-root `y` so that `y >= 0` and `y^2 = x`. `y` has the same type as `x`.

    Example::

        >>> x = ti.Matrix([1., 4., 9.])
        >>> y = ti.sqrt(x)
        >>> y
        [1.0, 2.0, 3.0]
    """
    return _unary_operation(_ti_core.expr_sqrt, np.sqrt, x)
292
+
293
+
294
def rsqrt(x):
    """Compute `1 / sqrt(x)`, the reciprocal square root.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix.

    Returns:
        The reciprocal of `sqrt(x)`, element-wise.
    """

    def _host_rsqrt(value):
        root = np.sqrt(value)
        return 1 / root

    return _unary_operation(_ti_core.expr_rsqrt, _host_rsqrt, x)
309
+
310
+
311
def _round(x):
    """Element-wise rounding without a dtype cast (backs ``round``)."""
    kernel_op, host_op = _ti_core.expr_round, np.round
    return _unary_operation(kernel_op, host_op, x)
313
+
314
+
315
def round(x, dtype=None):  # pylint: disable=redefined-builtin
    """Round to the nearest integer, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix.

        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type, defaults to `None`.
            If set to `None`, the returned value has the same type as `x`.

    Returns:
        The nearest integer of `x`, with return value type `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.5, 1.2, 2.7])
        >>>     print(ti.round(x))
        [-2., 1., 3.]
    """
    result = _round(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
340
+
341
+
342
def _floor(x):
    """Element-wise floor without a dtype cast (backs ``floor``)."""
    kernel_op, host_op = _ti_core.expr_floor, np.floor
    return _unary_operation(kernel_op, host_op, x)
344
+
345
+
346
def floor(x, dtype=None):
    """Return the floor of the input, element-wise.

    The floor of the scalar `x` is the largest integer `k`, such that `k <= x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type, defaults to `None`.
            If set to `None`, the returned value has the same type as `x`.

    Returns:
        The floor of each element in `x`, with return value type `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.1, 2.2, 3.])
        >>>     y = ti.floor(x, ti.f64)
        >>>     print(y)  # [-2.000000000000, 2.000000000000, 3.000000000000]
    """
    result = _floor(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
371
+
372
+
373
def _ceil(x):
    """Element-wise ceiling without a dtype cast (backs ``ceil``)."""
    kernel_op, host_op = _ti_core.expr_ceil, np.ceil
    return _unary_operation(kernel_op, host_op, x)
375
+
376
+
377
def ceil(x, dtype=None):
    """Return the ceiling of the input, element-wise.

    The ceil of the scalar `x` is the smallest integer `k`, such that `k >= x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type, defaults to `None`.
            If set to `None`, the returned value has the same type as `x`.

    Returns:
        The ceiling of each element in `x`, with return value type `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([3.14, -1.5])
        >>>     y = ti.ceil(x)
        >>>     print(y)  # [4.0, -1.0]
    """
    result = _ceil(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
404
+
405
+
406
def frexp(x):
    """Decompose `x` into mantissa and exponent, element-wise.

    In Python scope this delegates to ``numpy.frexp``, which returns a
    ``(mantissa, exponent)`` pair. In GsTaichi scope it builds
    ``_ti_core.expr_frexp``.
    NOTE(review): the kernel-side return shape and the Matrix-in-Python-scope
    result (``Matrix`` built from numpy's tuple) are not visible here — confirm.
    """
    kernel_op, host_op = _ti_core.expr_frexp, np.frexp
    return _unary_operation(kernel_op, host_op, x)
408
+
409
+
410
def tan(x):
    """Compute the trigonometric tangent of `x`, element-wise.

    Equivalent to `ti.sin(x) / ti.cos(x)` applied element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The tangent of every element of `x`.

    Example::

        >>> from math import pi
        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-pi, pi/2, pi])
        >>>     y = ti.tan(x)
        >>>     print(y)
        >>>
        >>> test()
        [-0.0, -22877334.0, 0.0]
    """
    kernel_op, host_op = _ti_core.expr_tan, np.tan
    return _unary_operation(kernel_op, host_op, x)
435
+
436
+
437
def tanh(x):
    """Compute the hyperbolic tangent of `x`, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The hyperbolic tangent of every element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 0.0, 1.0])
        >>>     y = ti.tanh(x)
        >>>     print(y)
        >>>
        >>> test()
        [-0.761594, 0.000000, 0.761594]
    """
    kernel_op, host_op = _ti_core.expr_tanh, np.tanh
    return _unary_operation(kernel_op, host_op, x)
459
+
460
+
461
def exp(x):
    """Compute `e` raised to each element of `x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The element-wise exponential of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 0.0, 1.0])
        >>>     y = ti.exp(x)
        >>>     print(y)
        >>>
        >>> test()
        [0.367879, 1.000000, 2.718282]
    """
    kernel_op, host_op = _ti_core.expr_exp, np.exp
    return _unary_operation(kernel_op, host_op, x)
483
+
484
+
485
def log(x):
    """Compute the natural (base `e`) logarithm of `x`, element-wise.

    `log` inverts the exponential, so `log(exp(x)) = x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The natural logarithm of every element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.0, 0.0, 1.0])
        >>>     y = ti.log(x)
        >>>     print(y)
        >>>
        >>> test()
        [-nan, -inf, 0.000000]
    """
    kernel_op, host_op = _ti_core.expr_log, np.log
    return _unary_operation(kernel_op, host_op, x)
510
+
511
+
512
def abs(x):  # pylint: disable=W0622
    """Compute the absolute value :math:`|x|` of every element of `x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The absolute value of each element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.0, 0.0, 1.0])
        >>>     y = ti.abs(x)
        >>>     print(y)
        >>>
        >>> test()
        [1.0, 0.0, 1.0]
    """
    # builtins.abs is used explicitly because this function shadows the builtin.
    kernel_op, host_op = _ti_core.expr_abs, builtins.abs
    return _unary_operation(kernel_op, host_op, x)
534
+
535
+
536
def bit_not(a):
    """Bitwise NOT (`~a`), element-wise.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The bitwise complement of `a`.
    """
    kernel_op, host_op = _ti_core.expr_bit_not, _bt_ops_mod.invert
    return _unary_operation(kernel_op, host_op, a)
546
+
547
+
548
def popcnt(a):
    """Count the set bits (population count) of an integer, element-wise.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            An integer or a matrix of integers.

    Returns:
        The number of `1` bits in each element of `a`.

    Note:
        In Python scope the count is computed with ``bin(v).count("1")``. For a
        negative ``v`` this counts the bits of ``abs(v)``, which may differ from a
        fixed-width two's-complement popcount computed inside a kernel.
        NOTE(review): confirm whether negative inputs should match kernel semantics.
    """

    def _popcnt_scalar(v):
        return bin(v).count("1")

    def _popcnt(x):
        # Matrix operands arrive here as the ndarray from ``Matrix.to_numpy()``.
        # ``bin`` accepts only a single integer, so count element by element.
        if isinstance(x, np.ndarray):
            return np.vectorize(_popcnt_scalar, otypes=[int])(x)
        return _popcnt_scalar(x)

    return _unary_operation(_ti_core.expr_popcnt, _popcnt, a)
553
+
554
+
555
def logical_not(a):
    """The logical not function.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        In GsTaichi scope, `1` iff `a == 0`, otherwise `0`. In Python scope the
        result comes from ``numpy.logical_not``, so it is a boolean (`True` iff
        `a == 0`) or a boolean array for matrices.
    """
    return _unary_operation(_ti_core.expr_logic_not, np.logical_not, a)
565
+
566
+
567
def random(dtype=float) -> Union[float, int]:
    """Return a single random float/integer according to the specified data type.
    Must be called in GsTaichi scope.

    If the required `dtype` is a float type, this function returns a random number
    sampled from the uniform distribution in the half-open interval [0, 1).

    For integer types this function returns a random 32-bit or 64-bit integer,
    matching the width of `dtype`. Signed types can yield negative values, as the
    `ti.i32` example below shows. NOTE(review): presumably the result covers the
    full value range of `dtype` — confirm against the backend.

    Args:
        dtype (:mod:`~gstaichi.types.primitive_types`): Type of the required random value.

    Returns:
        A random value with type `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.random(float)
        >>>     print(x)  # 0.090257
        >>>
        >>>     y = ti.random(ti.f64)
        >>>     print(y)  # 0.716101627301
        >>>
        >>>     i = ti.random(ti.i32)
        >>>     print(i)  # -963722261
        >>>
        >>>     j = ti.random(ti.i64)
        >>>     print(j)  # 73412986184350777
    """
    dtype = cook_dtype(dtype)
    x = expr.Expr(_ti_core.make_rand_expr(dtype, _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())))
    # expr_init materializes the random draw into a variable so it is sampled once.
    return impl.expr_init(x)
604
+
605
+
606
+ # NEXT: add matpow(self, power)
607
+
608
+
609
def add(a, b):
    """Element-wise addition `a + b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The sum of `a` and `b`.
    """
    kernel_op, host_op = _ti_core.expr_add, _bt_ops_mod.add
    return _binary_operation(kernel_op, host_op, a, b)
620
+
621
+
622
def sub(a, b):
    """Element-wise subtraction `a - b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        `b` subtracted from `a`.
    """
    kernel_op, host_op = _ti_core.expr_sub, _bt_ops_mod.sub
    return _binary_operation(kernel_op, host_op, a, b)
633
+
634
+
635
def mul(a, b):
    """Element-wise multiplication `a * b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The product of `a` and `b`.
    """
    kernel_op, host_op = _ti_core.expr_mul, _bt_ops_mod.mul
    return _binary_operation(kernel_op, host_op, a, b)
646
+
647
+
648
def mod(x1, x2):
    """Returns the element-wise remainder of division.

    This is equivalent to the Python modulus operator `x1 % x2` and
    has the same sign as the divisor x2.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Dividend scalar or matrix.

        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Divisor scalar or matrix. When both `x1` and `x2` are matrices they must have the same shape.

    Returns:
        The element-wise remainder of the quotient `floordiv(x1, x2)`. This is a scalar
        if both `x1` and `x2` are scalars.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([3.0, 4.0, 5.0])
        >>>     y = 3
        >>>     z = ti.mod(y, x)
        >>>     print(z)
        >>>
        >>> test()
        [0.0, 3.0, 3.0]
    """

    def expr_python_mod(a, b):
        # Python-style modulus built from IR primitives: a % b = a - (a // b) * b,
        # which makes the result take the sign of the divisor.
        quotient = expr.Expr(_ti_core.expr_floordiv(a, b))
        multiply = expr.Expr(_ti_core.expr_mul(b, quotient.ptr))
        return _ti_core.expr_sub(a, multiply.ptr)

    return _binary_operation(expr_python_mod, _bt_ops_mod.mod, x1, x2)
685
+
686
+
687
def pow(base, exponent):  # pylint: disable=W0622
    """Raise `base` to the power `exponent` (:math:`{base}^{exponent}`), element-wise.

    The result type of two scalar operands is determined as follows:
    - If the exponent is an integral value, the result takes the type of the base.
    - Otherwise, the result type follows
      [Implicit type casting in binary operations](https://docs.taichi-lang.org/docs/type#implicit-type-casting-in-binary-operations).

    Under these rules an integral value raised to a negative integral value has no
    feasible type. Therefore an exception is raised if debug mode or optimization
    passes are on; otherwise 1 is returned.

    The result is undefined when:
    - a negative value is raised to a non-integral value;
    - a zero value is raised to a non-positive value.

    Args:
        base (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The bases.
        exponent (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The exponents.

    Returns:
        `base` raised to `exponent`. This is a scalar if both `base` and `exponent` are scalars.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-2.0, 2.0])
        >>>     y = -3
        >>>     z = ti.pow(x, y)
        >>>     print(z)
        >>>
        >>> test()
        [-0.125000, 0.125000]
    """
    kernel_op, host_op = _ti_core.expr_pow, _bt_ops_mod.pow
    return _binary_operation(kernel_op, host_op, base, exponent)
725
+
726
+
727
def floordiv(a, b):
    """Floor division of `a` by `b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            The dividend, a number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            The divisor, a number or a matrix whose elements are non-zero.

    Returns:
        The floor of `a` divided by `b`.
    """
    return _binary_operation(
        _ti_core.expr_floordiv,
        _bt_ops_mod.floordiv,
        a,
        b,
    )
738
+
739
+
740
def truediv(a, b):
    """True (non-truncating) division of `a` by `b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            The dividend, a number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            The divisor, a number or a matrix whose elements are non-zero.

    Returns:
        The true quotient of `a` divided by `b`.
    """
    kernel_div = _ti_core.expr_truediv
    python_div = _bt_ops_mod.truediv
    return _binary_operation(kernel_div, python_div, a, b)
751
+
752
+
753
def max_impl(a, b):
    """The maximum function.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The maximum of `a` and `b`.
    """
    # `np.maximum` is the Python-side implementation handed to `_binary_operation`.
    return _binary_operation(_ti_core.expr_max, np.maximum, a, b)
764
+
765
+
766
def min_impl(a, b):
    """Element-wise minimum of two operands.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The smaller of `a` and `b`.
    """
    # `np.minimum` is the Python-side implementation handed to `_binary_operation`.
    host_min = np.minimum
    return _binary_operation(_ti_core.expr_min, host_min, a, b)
777
+
778
+
779
def atan2(x1, x2):
    """Element-wise arc tangent of `x1 / x2`, using the signs of both to pick the quadrant.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            y-coordinates.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            x-coordinates.

    Returns:
        Angles in radians, in the range `[-pi, pi]`; a scalar when both
        `x1` and `x2` are scalars.

    Example::

        >>> from math import pi
        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 1.0, -1.0, 1.0])
        >>>     y = ti.Matrix([-1.0, -1.0, 1.0, 1.0])
        >>>     z = ti.atan2(y, x) * 180 / pi
        >>>     print(z)
        >>>
        >>> test()
        [-135.0, -45.0, 135.0, 45.0]
    """
    ti_op = _ti_core.expr_atan2
    np_op = np.arctan2
    return _binary_operation(ti_op, np_op, x1, x2)
806
+
807
+
808
def raw_div(x1, x2):
    """Return `x1 // x2` if both `x1`, `x2` are integers, otherwise return `x1/x2`.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Dividend.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Divisor.

    Returns:
        `x1 // x2` if both `x1`, `x2` are integers, otherwise `x1/x2`.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     x = 5
        >>>     y = 3
        >>>     print(ti.raw_div(x, y))  # 1
        >>>     z = 4.0
        >>>     print(ti.raw_div(x, z))  # 1.25
    """

    def c_div(a, b):
        # Python-side fallback: only builtin `int` operands get integer division.
        if isinstance(a, int) and isinstance(b, int):
            return a // b
        return a / b

    return _binary_operation(_ti_core.expr_div, c_div, x1, x2)
835
+
836
+
837
def raw_mod(x1, x2):
    """Return the remainder of `x1/x2`, element-wise.
    This is the C-style `mod` function: the quotient is truncated toward zero,
    so a non-zero result has the sign of the dividend `x1`.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The dividend.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The divisor.

    Returns:
        The remainder of `x1` divided by `x2`.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     print(ti.mod(-4, 3))  # 2
        >>>     print(ti.raw_mod(-4, 3))  # -1
    """

    def c_mod(x, y):
        if isinstance(x, int) and isinstance(y, int):
            # Exact integer path: going through float() would lose precision
            # once |x| exceeds 2**53, giving a wrong remainder.
            quotient = abs(x) // abs(y)  # ZeroDivisionError for y == 0, as before
            if (x < 0) != (y < 0):
                quotient = -quotient  # truncate toward zero, as C does
            return x - y * quotient
        return x - y * int(float(x) / y)

    return _binary_operation(_ti_core.expr_mod, c_mod, x1, x2)
862
+
863
+
864
def cmp_lt(a, b):
    """Less-than comparison, `a < b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True when `a` is strictly smaller than `b`, otherwise False.
    """
    kernel_lt = _ti_core.expr_cmp_lt
    return _binary_operation(kernel_lt, _bt_ops_mod.lt, a, b)
876
+
877
+
878
def cmp_le(a, b):
    """Less-than-or-equal comparison, `a <= b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True when `a` is smaller than or equal to `b`, otherwise False.
    """
    kernel_le = _ti_core.expr_cmp_le
    return _binary_operation(kernel_le, _bt_ops_mod.le, a, b)
890
+
891
+
892
def cmp_gt(a, b):
    """Greater-than comparison, `a > b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True when `a` is strictly larger than `b`, otherwise False.
    """
    kernel_gt = _ti_core.expr_cmp_gt
    return _binary_operation(kernel_gt, _bt_ops_mod.gt, a, b)
904
+
905
+
906
def cmp_ge(a, b):
    """Compare two values (greater than or equal to)

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value LHS
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value RHS

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if LHS is greater than or equal to RHS, False otherwise

    """
    return _binary_operation(_ti_core.expr_cmp_ge, _bt_ops_mod.ge, a, b)
918
+
919
+
920
def cmp_eq(a, b):
    """Equality comparison, `a == b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True when `a` equals `b`, otherwise False.
    """
    kernel_eq = _ti_core.expr_cmp_eq
    return _binary_operation(kernel_eq, _bt_ops_mod.eq, a, b)
932
+
933
+
934
def cmp_ne(a, b):
    """Inequality comparison, `a != b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True when `a` differs from `b`, otherwise False.
    """
    kernel_ne = _ti_core.expr_cmp_ne
    return _binary_operation(kernel_ne, _bt_ops_mod.ne, a, b)
946
+
947
+
948
def bit_or(a, b):
    """Bitwise OR of two operands, `a | b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` bitwise-or `b`.
    """
    return _binary_operation(
        _ti_core.expr_bit_or,
        _bt_ops_mod.or_,
        a,
        b,
    )
960
+
961
+
962
def bit_and(a, b):
    """Bitwise AND of two operands, `a & b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` bitwise-and `b`.
    """
    return _binary_operation(
        _ti_core.expr_bit_and,
        _bt_ops_mod.and_,
        a,
        b,
    )
974
+
975
+
976
def bit_xor(a, b):
    """Bitwise exclusive OR of two operands, `a ^ b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` bitwise-xor `b`.
    """
    return _binary_operation(
        _ti_core.expr_bit_xor,
        _bt_ops_mod.xor,
        a,
        b,
    )
988
+
989
+
990
def bit_shl(a, b):
    """Shift `a` left by `b` bits, `a << b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Value to shift.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Shift amount.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: `a << b`.
    """
    kernel_shl = _ti_core.expr_bit_shl
    host_shl = _bt_ops_mod.lshift
    return _binary_operation(kernel_shl, host_shl, a, b)
1002
+
1003
+
1004
def bit_sar(a, b):
    """Shift `a` right by `b` bits, `a >> b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Value to shift.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Shift amount.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: `a >> b`.
    """
    kernel_sar = _ti_core.expr_bit_sar
    host_sar = _bt_ops_mod.rshift
    return _binary_operation(kernel_sar, host_sar, a, b)
1016
+
1017
+
1018
@gstaichi_scope
def bit_shr(x1, x2):
    """Elements in `x1` shifted to the right by number of bits in `x2`.
    Both `x1`, `x2` must have integer type.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input data.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Number of bits to remove at the right of `x1`.

    Returns:
        Return `x1` with bits shifted `x2` times to the right.
        This is a scalar if both `x1` and `x2` are scalars.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     x = ti.Matrix([7, 8])
        >>>     y = ti.Matrix([1, 2])
        >>>     print(ti.bit_shr(x, y))
        >>>
        >>> main()
        [3, 2]
    """
    # NOTE(review): the Python-side fallback is `_bt_ops_mod.rshift`, the same one
    # `bit_sar` uses. If that is `operator.rshift`, it sign-extends negative ints,
    # which may differ from what `expr_bit_shr` does inside kernels -- confirm.
    return _binary_operation(_ti_core.expr_bit_shr, _bt_ops_mod.rshift, x1, x2)
1044
+
1045
+
1046
def logical_and(a, b):
    """Logical AND of two operands.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` logical-and `b` (with short-circuit semantics).
    """

    def python_and(lhs, rhs):
        # Python-side fallback: plain `and`, which returns one of its operands.
        return lhs and rhs

    return _binary_operation(_ti_core.expr_logical_and, python_and, a, b)
1058
+
1059
+
1060
def logical_or(a, b):
    """Logical OR of two operands.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` logical-or `b` (with short-circuit semantics).
    """

    def python_or(lhs, rhs):
        # Python-side fallback: plain `or`, which returns one of its operands.
        return lhs or rhs

    return _binary_operation(_ti_core.expr_logical_or, python_or, a, b)
1072
+
1073
+
1074
def select(cond, x1, x2):
    """Pick elements from `x1` or `x2` according to `cond`.

    Args:
        cond (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The array of conditions.
        x1, x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The arrays the output elements are taken from.

    Returns:
        An array whose k-th element is the k-th element of `x1` where the k-th
        element of `cond` is true, and the k-th element of `x2` otherwise.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     cond = ti.Matrix([0, 1, 0, 1])
        >>>     x = ti.Matrix([1, 2, 3, 4])
        >>>     y = ti.Matrix([-1, -2, -3, -4])
        >>>     print(ti.select(cond, x, y))
        >>>
        >>> main()
        [-1, 2, -3, 4]
    """
    # TODO: systematically resolve `-1 = True` problem by introducing u1:
    # Double negation normalizes the condition (presumably to 0/1) before use.
    normalized = logical_not(logical_not(cond))

    def py_select(flag, on_true, on_false):
        # Arithmetic blend used as the Python-side fallback.
        return on_true * flag + on_false * (1 - flag)

    return _ternary_operation(_ti_core.expr_select, py_select, normalized, x1, x2)
1107
+
1108
+
1109
def ifte(cond, x1, x2):
    """Return `x1` when `cond` is true, otherwise `x2`.

    This operator guarantees short-circuit semantics: exactly one of `x1` or `x2`
    will be evaluated.

    Args:
        cond (:mod:`~gstaichi.types.primitive_types`):
            The condition.
        x1, x2 (:mod:`~gstaichi.types.primitive_types`):
            The outputs.

    Returns:
        `x1` if `cond` is true and `x2` otherwise.
    """
    # TODO: systematically resolve `-1 = True` problem by introducing u1:
    # Double negation normalizes the condition before it is used.
    normalized = logical_not(logical_not(cond))

    def py_ifte(flag, on_true, on_false):
        if flag:
            return on_true
        return on_false

    return _ternary_operation(_ti_core.expr_ifte, py_ifte, normalized, x1, x2)
1129
+
1130
+
1131
def clz(a):
    """Count the number of leading zeros for a 32bit integer.

    Args:
        a: The 32-bit integer value(s) to inspect.

    Returns:
        The number of leading zero bits. The Python-side fallback returns 32
        for 0 and 0 for negative values (their sign bit is set).
    """

    def _clz(x):
        # In two's complement a negative 32-bit value has bit 31 set, so it has no
        # leading zeros. Without this check the loop below would return 32 on its
        # first iteration, because 2**0 > x for every negative x.
        if x < 0:
            return 0
        for i in range(32):
            if 2**i > x:
                return 32 - i
        return 0

    return _unary_operation(_ti_core.expr_clz, _clz, a)
1141
+
1142
+
1143
@writeback_binary
def atomic_add(x, y):
    """Atomically add `y` to `x` in place and return the previous value of `x`.

    `x` has to be a writable target; constant expressions and scalars
    are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The value `x` held before the addition.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([0, 0, 0])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_add(x, y)
        >>>     print(x)  # [1, 2, 3] the new value of x
        >>>     print(z)  # [0, 0, 0], the old value of x
        >>>
        >>>     ti.atomic_add(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_add(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1171
+
1172
+
1173
@writeback_binary
def atomic_mul(x, y):
    """Atomically compute `x * y`, store the result in `x`,
    and return the old value of `x`.

    `x` must be a writable target, constant expressions or scalars
    are not allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The old value of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([1, 2, 3])
        >>>     y = ti.Vector([4, 5, 6])
        >>>     z = ti.atomic_mul(x, y)
        >>>     print(x)  # [4, 10, 18] the new value of x
        >>>     print(z)  # [1, 2, 3], the old value of x
        >>>
        >>>     ti.atomic_mul(1, x)  # will raise GsTaichiSyntaxError
    """
    return impl.expr_init(expr.Expr(_ti_core.expr_atomic_mul(x.ptr, y.ptr), dbg_info=_ti_core.DebugInfo(stack_info())))
1201
+
1202
+
1203
@writeback_binary
def atomic_sub(x, y):
    """Atomically subtract `y` from `x` in place and return the previous value of `x`.

    `x` has to be a writable target; constant expressions and scalars
    are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The value `x` held before the subtraction.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([0, 0, 0])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_sub(x, y)
        >>>     print(x)  # [-1, -2, -3] the new value of x
        >>>     print(z)  # [0, 0, 0], the old value of x
        >>>
        >>>     ti.atomic_sub(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_sub(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1231
+
1232
+
1233
@writeback_binary
def atomic_min(x, y):
    """Atomically replace `x` with the element-wise minimum of `x` and `y`.

    Returns the value `x` held beforehand. `x` has to be a writable target;
    constant expressions and scalars are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The value `x` held before the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 2
        >>>     y = 1
        >>>     z = ti.atomic_min(x, y)
        >>>     print(x)  # 1 the new value of x
        >>>     print(z)  # 2, the old value of x
        >>>
        >>>     ti.atomic_min(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_min(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1261
+
1262
+
1263
@writeback_binary
def atomic_max(x, y):
    """Atomically replace `x` with the element-wise maximum of `x` and `y`.

    Returns the value `x` held beforehand. `x` has to be a writable target;
    constant expressions and scalars are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The value `x` held before the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 1
        >>>     y = 2
        >>>     z = ti.atomic_max(x, y)
        >>>     print(x)  # 2 the new value of x
        >>>     print(z)  # 1, the old value of x
        >>>
        >>>     ti.atomic_max(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_max(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1291
+
1292
+
1293
@writeback_binary
def atomic_and(x, y):
    """Atomically replace `x` with the element-wise bitwise AND of `x` and `y`.

    Returns the value `x` held beforehand. `x` has to be a writable target;
    constant expressions and scalars are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input. When both are matrices they must have the same shape.

    Returns:
        The value `x` held before the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_and(x, y)
        >>>     print(x)  # [1, 0, 1] the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>>     ti.atomic_and(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_and(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1323
+
1324
+
1325
@writeback_binary
def atomic_or(x, y):
    """Atomically replace `x` with the element-wise bitwise OR of `x` and `y`.

    Returns the value `x` held beforehand. `x` has to be a writable target;
    constant expressions and scalars are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input. When both are matrices they must have the same shape.

    Returns:
        The value `x` held before the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_or(x, y)
        >>>     print(x)  # [-1, 2, 3] the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>>     ti.atomic_or(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_or(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1355
+
1356
+
1357
@writeback_binary
def atomic_xor(x, y):
    """Atomically replace `x` with the element-wise bitwise XOR of `x` and `y`.

    Returns the value `x` held beforehand. `x` has to be a writable target;
    constant expressions and scalars are rejected.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input. When both are matrices they must have the same shape.

    Returns:
        The value `x` held before the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_xor(x, y)
        >>>     print(x)  # [-2, 2, 2] the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>>     ti.atomic_xor(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_xor(x.ptr, y.ptr)
    dbg = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=dbg))
1387
+
1388
+
1389
@writeback_binary
def assign(a, b):
    """Emit an assignment of `b` to `a` through the AST builder of the callable being compiled.

    Args:
        a: The assignment target; its `.ptr` is handed to `expr_assign`.
        b: The assigned value; its `.ptr` is handed to `expr_assign`.

    Returns:
        The target `a`.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    builder.expr_assign(a.ptr, b.ptr, _ti_core.DebugInfo(stack_info()))
    return a
1393
+
1394
+
1395
def max(*args):  # pylint: disable=W0622
    """Element-wise maximum of all arguments.

    A single argument is returned unchanged, even if it is array-like. When
    scalars and matrices are mixed, the matrices must share one shape and the
    scalars are broadcast to it.

    Args:
        args: (List[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        Maximum of the inputs.

    Example::

        >>> @ti.kernel
        >>> def foo():
        >>>     x = ti.Vector([0, 1, 2])
        >>>     y = ti.Vector([3, 4, 5])
        >>>     z = ti.max(x, y, 4)
        >>>     print(z)  # [4, 4, 5]
    """
    assert len(args) >= 1
    # Right fold: max_impl(a0, max_impl(a1, ... max_impl(a[n-2], a[n-1])))
    result = args[-1]
    for earlier in reversed(args[:-1]):
        result = max_impl(earlier, result)
    return result
1425
+
1426
+
1427
def min(*args):  # pylint: disable=W0622
    """Element-wise minimum of all arguments.

    A single argument is returned unchanged, even if it is array-like. When
    scalars and matrices are mixed, the matrices must share one shape and the
    scalars are broadcast to it.

    Args:
        args: (List[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        Minimum of the inputs.

    Example::

        >>> @ti.kernel
        >>> def foo():
        >>>     x = ti.Vector([0, 1, 2])
        >>>     y = ti.Vector([3, 4, 5])
        >>>     z = ti.min(x, y, 1)
        >>>     print(z)  # [0, 1, 1]
    """
    assert len(args) >= 1
    # Right fold: min_impl(a0, min_impl(a1, ... min_impl(a[n-2], a[n-1])))
    result = args[-1]
    for earlier in reversed(args[:-1]):
        result = min_impl(earlier, result)
    return result
1457
+
1458
+
1459
# Names exported by `from <this module> import *`.
__all__ = [
    "acos", "asin", "atan2",
    "atomic_and", "atomic_or", "atomic_xor",
    "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "atomic_mul",
    "bit_cast", "bit_shr",
    "cast", "ceil", "cos", "exp", "floor", "frexp", "log", "random",
    "raw_mod", "raw_div",
    "round", "rsqrt", "sin", "sqrt", "tan", "tanh",
    "max", "min", "select", "abs", "pow",
]