gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/ops.py
ADDED
@@ -0,0 +1,1494 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import builtins
|
4
|
+
import functools
|
5
|
+
import operator as _bt_ops_mod # bt for builtin
|
6
|
+
from typing import Union
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from gstaichi._lib import core as _ti_core
|
11
|
+
from gstaichi.lang import expr, impl
|
12
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
13
|
+
from gstaichi.lang.field import Field
|
14
|
+
from gstaichi.lang.util import (
|
15
|
+
cook_dtype,
|
16
|
+
gstaichi_scope,
|
17
|
+
is_gstaichi_class,
|
18
|
+
is_matrix_class,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
def stack_info():
    """Return the source-location info for the code currently being compiled.

    The value comes from the active runtime and is what this module feeds into
    ``_ti_core.DebugInfo`` when it builds new expressions.
    """
    runtime = impl.get_runtime()
    return runtime.get_current_src_info()
|
24
|
+
|
25
|
+
|
26
|
+
def is_gstaichi_expr(a):
    """Tell whether `a` is a GsTaichi expression, i.e. an instance of ``expr.Expr``."""
    expr_cls = expr.Expr
    return isinstance(a, expr_cls)
|
28
|
+
|
29
|
+
|
30
|
+
def wrap_if_not_expr(a):
    """Return `a` as a GsTaichi expression.

    An existing ``Expr`` is returned unchanged. Any other value is wrapped in a
    new ``Expr`` that carries debug info for the current source location.
    """
    if is_gstaichi_expr(a):
        return a
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(a, dbg_info=debug_info)
|
36
|
+
|
37
|
+
|
38
|
+
def _read_matrix_or_scalar(x):
    """Return a Matrix as its numpy array (via ``to_numpy``); pass anything else through."""
    return x.to_numpy() if is_matrix_class(x) else x
|
42
|
+
|
43
|
+
|
44
|
+
def writeback_binary(foo):
    """Decorate a binary op whose first operand is the write-back target.

    The wrapped function is called as ``foo(a, wrap_if_not_expr(b))``. `a` is
    passed through untouched; `b` is coerced to an ``Expr``.

    Behavior of the returned wrapper:
        * returns ``NotImplemented`` when either operand is a ``Field``;
        * raises ``GsTaichiSyntaxError`` when `a` is not an ``Expr`` whose
          ``ptr.is_lvalue()`` is true, i.e. it cannot be written to.
    """

    @functools.wraps(foo)
    def wrapped(a, b):
        for operand in (a, b):
            if isinstance(operand, Field):
                return NotImplemented
        target_is_writable = is_gstaichi_expr(a) and a.ptr.is_lvalue()
        if not target_is_writable:
            raise GsTaichiSyntaxError(f"cannot use a non-writable target as the first operand of '{foo.__name__}'")
        return foo(a, wrap_if_not_expr(b))

    return wrapped
|
54
|
+
|
55
|
+
|
56
|
+
def cast(obj, dtype):
    """Copy a scalar or matrix, converting it to the data type `dtype`.

    Must be called in GsTaichi scope.

    GsTaichi classes (such as matrices) are delegated to their own ``cast``
    method. Any other value is wrapped in an ``Expr`` and converted with
    ``_ti_core.value_cast``.

    Args:
        obj (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The scalar or matrix to convert.
        dtype (:mod:`~gstaichi.types.primitive_types`): The target primitive type.

    Returns:
        A copy of `obj` converted to `dtype`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([0, 1, 2], ti.i32)
        >>>     y = ti.cast(x, ti.f32)
        >>>     print(y)
        >>>
        >>> test()
        [0.0, 1.0, 2.0]
    """
    dtype = cook_dtype(dtype)
    if not is_gstaichi_class(obj):
        converted = _ti_core.value_cast(expr.Expr(obj).ptr, dtype)
        return expr.Expr(converted)
    # TODO: unify with element_wise_unary
    return obj.cast(dtype)
|
85
|
+
|
86
|
+
|
87
|
+
def bit_cast(obj, dtype):
    """Reinterpret the bits of a scalar as another data type.

    Must be called in GsTaichi scope. This is the counterpart of C++
    ``reinterpret_cast``: the underlying bits are kept and only the type
    changes.

    Args:
        obj (:mod:`~gstaichi.types.primitive_types`): The scalar to reinterpret.
        dtype (:mod:`~gstaichi.types.primitive_types`): The target type. It must
            have the same bit width as `obj` (so `f32` -> `f64` is not allowed).

    Returns:
        A copy of `obj` whose bits are viewed as `dtype`.

    Raises:
        ValueError: If `obj` is a GsTaichi class (e.g. a matrix).

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 3.14
        >>>     y = ti.bit_cast(x, ti.i32)
        >>>     print(y)  # 1078523331
        >>>
        >>>     z = ti.bit_cast(y, ti.f32)
        >>>     print(z)  # 3.14
    """
    dtype = cook_dtype(dtype)
    if is_gstaichi_class(obj):
        raise ValueError("Cannot apply bit_cast on GsTaichi classes")
    source = expr.Expr(obj)
    return expr.Expr(_ti_core.bits_cast(source.ptr, dtype))
|
118
|
+
|
119
|
+
|
120
|
+
def _unary_operation(gstaichi_op, python_op, a):
    """Apply a unary op to `a`, choosing the implementation from the operand's type.

    Dispatch order:
        * ``Field``: return ``NotImplemented``.
        * ``Expr``: build a node with ``gstaichi_op(a.ptr)``, tagged with the
          current debug info.
        * ``Matrix`` (Python scope): apply ``python_op`` to ``a.to_numpy()`` and
          wrap the result in a new ``Matrix``.
        * anything else: return ``python_op(a)``.
    """
    if isinstance(a, Field):
        return NotImplemented
    if is_gstaichi_expr(a):
        node = gstaichi_op(a.ptr)
        return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))

    from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

    if not isinstance(a, Matrix):
        return python_op(a)
    return Matrix(python_op(a.to_numpy()))
|
130
|
+
|
131
|
+
|
132
|
+
def _binary_operation(gstaichi_op, python_op, a, b):
    """Apply a binary op to `a` and `b`, choosing the implementation from their types.

    Dispatch order:
        * either operand a ``Field``: return ``NotImplemented``.
        * either operand an ``Expr``: wrap both as ``Expr`` and build a node with
          ``gstaichi_op(lhs.ptr, rhs.ptr)``, tagged with the current debug info.
        * either operand a ``Matrix`` (Python scope): convert matrices to numpy,
          apply ``python_op`` and wrap the result in a new ``Matrix``.
        * otherwise: return ``python_op(a, b)``.
    """
    for operand in (a, b):
        if isinstance(operand, Field):
            return NotImplemented

    if builtins.any(is_gstaichi_expr(operand) for operand in (a, b)):
        lhs = wrap_if_not_expr(a)
        rhs = wrap_if_not_expr(b)
        node = gstaichi_op(lhs.ptr, rhs.ptr)
        return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))

    from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

    if builtins.any(isinstance(operand, Matrix) for operand in (a, b)):
        lhs_value = _read_matrix_or_scalar(a)
        rhs_value = _read_matrix_or_scalar(b)
        return Matrix(python_op(lhs_value, rhs_value))
    return python_op(a, b)
|
143
|
+
|
144
|
+
|
145
|
+
def _ternary_operation(gstaichi_op, python_op, a, b, c):
    """Apply a three-operand op, choosing the implementation from the operand types.

    Dispatch order:
        * any operand a ``Field``: return ``NotImplemented``.
        * any operand an ``Expr``: wrap all three as ``Expr`` (in order a, b, c)
          and build a node with ``gstaichi_op`` on their ``ptr`` values, tagged
          with the current debug info.
        * any operand a ``Matrix`` (Python scope): convert matrices to numpy,
          apply ``python_op`` and wrap the result in a new ``Matrix``.
        * otherwise: return ``python_op(a, b, c)``.
    """
    operands = (a, b, c)
    for operand in operands:
        if isinstance(operand, Field):
            return NotImplemented

    if builtins.any(is_gstaichi_expr(operand) for operand in operands):
        wrapped = [wrap_if_not_expr(operand) for operand in operands]
        node = gstaichi_op(*[item.ptr for item in wrapped])
        return expr.Expr(node, dbg_info=_ti_core.DebugInfo(stack_info()))

    from gstaichi.lang.matrix import Matrix  # pylint: disable-msg=C0415

    if builtins.any(isinstance(operand, Matrix) for operand in operands):
        values = [_read_matrix_or_scalar(operand) for operand in operands]
        return Matrix(python_op(*values))
    return python_op(a, b, c)
|
162
|
+
|
163
|
+
|
164
|
+
def neg(x):
    """Numerical negative, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        Matrix or scalar `y`, so that `y = -x`. `y` has the same type as `x`.

    Example::

        >>> x = ti.Matrix([1, -1])
        >>> y = ti.neg(x)
        >>> y
        [-1, 1]
    """
    # Python-scope fallback is operator.neg, applied to the numpy form of a Matrix.
    return _unary_operation(_ti_core.expr_neg, _bt_ops_mod.neg, x)
|
181
|
+
|
182
|
+
|
183
|
+
def sin(x):
    """Trigonometric sine, applied to each element.

    In Python scope the value is computed with ``numpy.sin``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Angle(s) in radians.

    Returns:
        The sine of every element of `x`.

    Example::

        >>> from math import pi
        >>> x = ti.Matrix([-pi/2., 0, pi/2.])
        >>> ti.sin(x)
        [-1., 0., 1.]
    """
    return _unary_operation(gstaichi_op=_ti_core.expr_sin, python_op=np.sin, a=x)
|
201
|
+
|
202
|
+
|
203
|
+
def cos(x):
    """Trigonometric cosine, element-wise.

    In Python scope the value is computed with ``numpy.cos``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Angle, in radians.

    Returns:
        The cosine of each element of `x`.

    Example::

        >>> from math import pi
        >>> x = ti.Matrix([-pi, 0, pi/2.])
        >>> ti.cos(x)
        [-1., 1., 0.]
    """
    return _unary_operation(_ti_core.expr_cos, np.cos, x)
|
221
|
+
|
222
|
+
|
223
|
+
def asin(x):
    """Trigonometric inverse sine, element-wise.

    The inverse of `sin` so that, if `y = sin(x)`, then `x = asin(y)`.

    For inputs outside the domain `[-1, 1]` the result is `nan`. In gstaichi
    scope `nan` is returned. In Python scope the value comes from
    ``numpy.arcsin``, which by default does not raise: it returns `nan` and
    emits a ``RuntimeWarning``. Whether it raises is controlled by
    ``numpy.seterr``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix with elements in [-1, 1].

    Returns:
        The inverse sine of each element in `x`, in radians and in the closed
        interval `[-pi/2, pi/2]`.

    Example::

        >>> from math import pi
        >>> ti.asin(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi
        [-90., 0., 90.]
    """
    return _unary_operation(_ti_core.expr_asin, np.arcsin, x)
|
246
|
+
|
247
|
+
|
248
|
+
def acos(x):
    """Trigonometric inverse cosine, element-wise.

    The inverse of `cos` so that, if `y = cos(x)`, then `x = acos(y)`.

    For inputs outside the domain `[-1, 1]` the result is `nan`. In gstaichi
    scope `nan` is returned. In Python scope the value comes from
    ``numpy.arccos``, which by default does not raise: it returns `nan` and
    emits a ``RuntimeWarning``. Whether it raises is controlled by
    ``numpy.seterr``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix with elements in [-1, 1].

    Returns:
        The inverse cosine of each element in `x`, in radians and in the closed
        interval `[0, pi]`. This is a scalar if `x` is a scalar.

    Example::

        >>> from math import pi
        >>> ti.acos(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi
        [180., 90., 0.]
    """
    return _unary_operation(_ti_core.expr_acos, np.arccos, x)
|
271
|
+
|
272
|
+
|
273
|
+
def sqrt(x):
    """Return the non-negative square root of a scalar or a matrix, element-wise.

    Negative inputs are not rejected. In Python scope the value comes from
    ``numpy.sqrt``, which returns `nan` for negative inputs and, under numpy's
    default error settings, emits a ``RuntimeWarning`` instead of raising.
    NOTE(review): the gstaichi-scope result for negative inputs is decided by
    ``_ti_core.expr_sqrt`` and is presumably also `nan`; confirm per backend.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The scalar or matrix whose square roots are required.

    Returns:
        The square root `y` of each element, so that `y >= 0` and `y^2 = x` for
        non-negative `x`. In Python scope, integer inputs yield floating-point
        results, because that is what ``numpy.sqrt`` returns.

    Example::

        >>> x = ti.Matrix([1., 4., 9.])
        >>> y = ti.sqrt(x)
        >>> y
        [1.0, 2.0, 3.0]
    """
    return _unary_operation(_ti_core.expr_sqrt, np.sqrt, x)
|
292
|
+
|
293
|
+
|
294
|
+
def rsqrt(x):
    """Reciprocal square root, ``1 / sqrt(x)``, applied element-wise.

    In Python scope the value is computed as ``1 / numpy.sqrt(x)``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix.

    Returns:
        The reciprocal of `sqrt(x)`.
    """

    def _reciprocal_sqrt(value):
        return 1 / np.sqrt(value)

    return _unary_operation(gstaichi_op=_ti_core.expr_rsqrt, python_op=_reciprocal_sqrt, a=x)
|
309
|
+
|
310
|
+
|
311
|
+
def _round(x):
    """Element-wise rounding without a result-type cast (``expr_round`` / ``numpy.round``)."""
    return _unary_operation(gstaichi_op=_ti_core.expr_round, python_op=np.round, a=x)
|
313
|
+
|
314
|
+
|
315
|
+
def round(x, dtype=None):  # pylint: disable=redefined-builtin
    """Round to the nearest integer, element-wise.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            A scalar or a matrix.
        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type,
            defaulting to `None`. If `None`, no cast is applied and the returned
            value has the same type as `x`.

    Returns:
        The nearest integer of each element of `x`, cast to `dtype` when given.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.5, 1.2, 2.7])
        >>>     print(ti.round(x))
        >>>
        >>> test()
        [-2., 1., 3.]
    """
    result = _round(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
|
340
|
+
|
341
|
+
|
342
|
+
def _floor(x):
    """Element-wise floor without a result-type cast (``expr_floor`` / ``numpy.floor``)."""
    return _unary_operation(gstaichi_op=_ti_core.expr_floor, python_op=np.floor, a=x)
|
344
|
+
|
345
|
+
|
346
|
+
def floor(x, dtype=None):
    """Return the floor of the input, element-wise.

    The floor of the scalar `x` is the largest integer `k` such that `k <= x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.
        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type,
            defaulting to `None`. If `None`, no cast is applied and the returned
            value has the same type as `x`.

    Returns:
        The floor of each element in `x`, cast to `dtype` when given.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.1, 2.2, 3.])
        >>>     y = ti.floor(x, ti.f64)
        >>>     print(y)
        >>>
        >>> test()
        [-2.000000000000, 2.000000000000, 3.000000000000]
    """
    result = _floor(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
|
371
|
+
|
372
|
+
|
373
|
+
def _ceil(x):
    """Element-wise ceiling without a result-type cast (``expr_ceil`` / ``numpy.ceil``)."""
    return _unary_operation(gstaichi_op=_ti_core.expr_ceil, python_op=np.ceil, a=x)
|
375
|
+
|
376
|
+
|
377
|
+
def ceil(x, dtype=None):
    """Return the ceiling of the input, element-wise.

    The ceiling of the scalar `x` is the smallest integer `k` such that `k >= x`.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.
        dtype (:mod:`~gstaichi.types.primitive_types`): The returned type,
            defaulting to `None`. If `None`, no cast is applied and the returned
            value has the same type as `x`.

    Returns:
        The ceiling of each element in `x`, cast to `dtype` when given.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([3.14, -1.5])
        >>>     y = ti.ceil(x)
        >>>     print(y)  # [4.0, -1.0]
    """
    result = _ceil(x)
    if dtype is not None:
        result = cast(result, dtype)
    return result
|
404
|
+
|
405
|
+
|
406
|
+
def frexp(x):
    """Split `x` into mantissa and binary exponent, element-wise.

    In Python scope this defers to ``numpy.frexp``, which returns a
    ``(mantissa, exponent)`` pair with ``x == mantissa * 2**exponent``.

    NOTE(review): the gstaichi-scope result comes from ``_ti_core.expr_frexp``,
    whose return structure is not visible here; confirm before relying on it.
    For a Python-scope Matrix, the numpy pair is handed to ``Matrix`` as-is;
    verify that this is intended.
    """
    return _unary_operation(gstaichi_op=_ti_core.expr_frexp, python_op=np.frexp, a=x)
|
408
|
+
|
409
|
+
|
410
|
+
def tan(x):
    """Trigonometric tangent, applied to each element.

    Mathematically equal to `ti.sin(x) / ti.cos(x)`. In Python scope the value
    is computed with ``numpy.tan``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix (angles in radians).

    Returns:
        The tangent of every element of `x`.

    Example::

        >>> from math import pi
        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-pi, pi/2, pi])
        >>>     y = ti.tan(x)
        >>>     print(y)
        >>>
        >>> test()
        [-0.0, -22877334.0, 0.0]
    """
    gstaichi_impl = _ti_core.expr_tan
    return _unary_operation(gstaichi_impl, np.tan, x)
|
435
|
+
|
436
|
+
|
437
|
+
def tanh(x):
    """Hyperbolic tangent, applied to each element of `x`.

    In Python scope the value is computed with ``numpy.tanh``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The hyperbolic tangent of every element.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 0.0, 1.0])
        >>>     y = ti.tanh(x)
        >>>     print(y)
        >>>
        >>> test()
        [-0.761594, 0.000000, 0.761594]
    """
    gstaichi_impl = _ti_core.expr_tanh
    return _unary_operation(gstaichi_impl, np.tanh, x)
|
459
|
+
|
460
|
+
|
461
|
+
def exp(x):
    """Natural exponential ``e**x``, applied to each element of `x`.

    In Python scope the value is computed with ``numpy.exp``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The exponential of every element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 0.0, 1.0])
        >>>     y = ti.exp(x)
        >>>     print(y)
        >>>
        >>> test()
        [0.367879, 1.000000, 2.718282]
    """
    gstaichi_impl = _ti_core.expr_exp
    return _unary_operation(gstaichi_impl, np.exp, x)
|
483
|
+
|
484
|
+
|
485
|
+
def log(x):
    """Natural (base-`e`) logarithm, applied to each element of `x`.

    This is the inverse of the exponential, so `log(exp(x)) = x`. In Python
    scope the value is computed with ``numpy.log``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The natural logarithm of every element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.0, 0.0, 1.0])
        >>>     y = ti.log(x)
        >>>     print(y)
        >>>
        >>> test()
        [-nan, -inf, 0.000000]
    """
    gstaichi_impl = _ti_core.expr_log
    return _unary_operation(gstaichi_impl, np.log, x)
|
510
|
+
|
511
|
+
|
512
|
+
def abs(x):  # pylint: disable=W0622
    """Absolute value :math:`|x|`, applied to each element of `x`.

    The Python-scope fallback is ``builtins.abs``, called through the
    ``builtins`` module because this function shadows the builtin name.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Input scalar or matrix.

    Returns:
        The absolute value of every element of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1.0, 0.0, 1.0])
        >>>     y = ti.abs(x)
        >>>     print(y)
        >>>
        >>> test()
        [1.0, 0.0, 1.0]
    """
    python_impl = builtins.abs
    return _unary_operation(_ti_core.expr_abs, python_impl, x)
|
534
|
+
|
535
|
+
|
536
|
+
def bit_not(a):
    """Bitwise complement (``~a``), applied element-wise.

    The Python-scope fallback is ``operator.invert``.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            A number or a matrix.

    Returns:
        The bitwise NOT of `a`.
    """
    return _unary_operation(gstaichi_op=_ti_core.expr_bit_not, python_op=_bt_ops_mod.invert, a=a)
|
546
|
+
|
547
|
+
|
548
|
+
def popcnt(a):
    """Count the set bits (population count) of an integer, element-wise.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            An integer scalar or matrix.

    Returns:
        The number of `1` bits in each element of `a`.

    Note:
        In Python scope the count is ``bin(x).count("1")``, so a negative
        integer is counted by the bits of its magnitude (``bin(-1)`` is
        ``'-0b1'``). NOTE(review): gstaichi scope presumably counts the bits
        of the fixed-width value, so negative inputs may give different
        results in the two scopes; confirm whether callers depend on this.
    """

    def _popcnt(x):
        if isinstance(x, np.ndarray):
            # A Python-scope Matrix reaches this function as a numpy array
            # (``_unary_operation`` calls ``python_op(a.to_numpy())``). ``bin``
            # only accepts a single integer, so count each element separately.
            # ``otypes`` keeps empty arrays working.
            return np.vectorize(_popcnt, otypes=[int])(x)
        return bin(x).count("1")

    return _unary_operation(_ti_core.expr_popcnt, _popcnt, a)
|
553
|
+
|
554
|
+
|
555
|
+
def logical_not(a):
    """Logical negation, applied element-wise.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]):
            A number or a matrix.

    Returns:
        In gstaichi scope, `1` where `a == 0` and `0` elsewhere. In Python scope
        the result comes from ``numpy.logical_not`` and is therefore boolean.
    """
    python_impl = np.logical_not
    return _unary_operation(_ti_core.expr_logic_not, python_impl, a)
|
565
|
+
|
566
|
+
|
567
|
+
def random(dtype=float) -> Union[float, int]:
    """Return a single random float/integer according to the specified data type.

    Must be called in gstaichi scope.

    If `dtype` is a float type, the result is sampled from the uniform
    distribution on the half-open interval [0, 1).

    For integer types the result is a random integer of the requested width
    (32- or 64-bit). NOTE(review): earlier docs gave the ranges [0, 2^32) and
    [0, 2^64), which only fit unsigned types. For the signed `ti.i32`/`ti.i64`
    the example below prints negative values, so the result presumably spans
    the type's full signed range; confirm against the backend.

    Args:
        dtype (:mod:`~gstaichi.types.primitive_types`): Type of the required random value.

    Returns:
        A random value with type `dtype`, materialized via ``impl.expr_init``.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.random(float)
        >>>     print(x)  # 0.090257
        >>>
        >>>     y = ti.random(ti.f64)
        >>>     print(y)  # 0.716101627301
        >>>
        >>>     i = ti.random(ti.i32)
        >>>     print(i)  # -963722261
        >>>
        >>>     j = ti.random(ti.i64)
        >>>     print(j)  # 73412986184350777
    """
    dtype = cook_dtype(dtype)
    # stack_info() is the same runtime source-location lookup used by the other ops here.
    dbg_info = _ti_core.DebugInfo(stack_info())
    x = expr.Expr(_ti_core.make_rand_expr(dtype, dbg_info))
    return impl.expr_init(x)
|
604
|
+
|
605
|
+
|
606
|
+
# NEXT: add matpow(self, power)
|
607
|
+
|
608
|
+
|
609
|
+
def add(a, b):
    """Element-wise addition, ``a + b``.

    The Python-scope fallback is ``operator.add``.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The sum of `a` and `b`.
    """
    return _binary_operation(gstaichi_op=_ti_core.expr_add, python_op=_bt_ops_mod.add, a=a, b=b)
|
620
|
+
|
621
|
+
|
622
|
+
def sub(a, b):
    """Element-wise subtraction, ``a - b``.

    The Python-scope fallback is ``operator.sub``.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        `a` minus `b`.
    """
    return _binary_operation(gstaichi_op=_ti_core.expr_sub, python_op=_bt_ops_mod.sub, a=a, b=b)
|
633
|
+
|
634
|
+
|
635
|
+
def mul(a, b):
    """Element-wise multiplication, ``a * b``.

    The Python-scope fallback is ``operator.mul``.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        `a` multiplied by `b`.
    """
    return _binary_operation(gstaichi_op=_ti_core.expr_mul, python_op=_bt_ops_mod.mul, a=a, b=b)
|
646
|
+
|
647
|
+
|
648
|
+
def mod(x1, x2):
    """Returns the element-wise remainder of division.

    This is equivalent to the Python modulus operator `x1 % x2`: the result
    has the same sign as the divisor `x2`.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Dividend scalar or matrix.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            Divisor scalar or matrix. When both `x1` and `x2` are matrices they
            must have the same shape.

    Returns:
        The element-wise remainder of the quotient `floordiv(x1, x2)`. This is a
        scalar if both `x1` and `x2` are scalars.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([3.0, 4.0, 5.0])
        >>>     y = 3
        >>>     z = ti.mod(y, x)
        >>>     print(z)
        >>>
        >>> test()
        [0.0, 3.0, 3.0]
    """

    def expr_python_mod(a, b):
        # Python-style modulo built from core primitives:
        # a % b = a - (a // b) * b, with // being floor division.
        quotient = expr.Expr(_ti_core.expr_floordiv(a, b))
        multiply = expr.Expr(_ti_core.expr_mul(b, quotient.ptr))
        return _ti_core.expr_sub(a, multiply.ptr)

    return _binary_operation(expr_python_mod, _bt_ops_mod.mod, x1, x2)
|
685
|
+
|
686
|
+
|
687
|
+
def pow(base, exponent):  # pylint: disable=W0622
    """Raise `base` to the power `exponent`, element-wise: :math:`{base}^{exponent}`.

    Result type for two scalar operands:

    - If the exponent is an integral value, the result takes the type of the base.
    - Otherwise the type follows
      [Implicit type casting in binary operations](https://docs.taichi-lang.org/docs/type#implicit-type-casting-in-binary-operations).

    Under these rules an integral base raised to a negative integral exponent
    has no feasible type. An exception is raised in that case if debug mode or
    optimization passes are on; otherwise 1 is returned.

    The result is undefined for:

    - a negative base raised to a non-integral exponent;
    - a zero base raised to a non-positive exponent.

    Args:
        base (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The bases.
        exponent (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The exponents.

    Returns:
        `base` raised to `exponent`; a scalar when both operands are scalars.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-2.0, 2.0])
        >>>     y = -3
        >>>     z = ti.pow(x, y)
        >>>     print(z)
        >>>
        >>> test()
        [-0.125000, 0.125000]
    """
    expr_op = _ti_core.expr_pow
    py_op = _bt_ops_mod.pow
    return _binary_operation(expr_op, py_op, base, exponent)
|
725
|
+
|
726
|
+
|
727
|
+
def floordiv(a, b):
    """Floor-divide `a` by `b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Dividend, a number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Divisor, a number or a matrix whose elements are non-zero.

    Returns:
        The floor of `a` divided by `b`.
    """
    expr_op = _ti_core.expr_floordiv
    py_op = _bt_ops_mod.floordiv
    return _binary_operation(expr_op, py_op, a, b)
|
738
|
+
|
739
|
+
|
740
|
+
def truediv(a, b):
    """True-divide `a` by `b`.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Dividend, a number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Divisor, a number or a matrix whose elements are non-zero.

    Returns:
        The true (non-truncated) quotient of `a` divided by `b`.
    """
    expr_op = _ti_core.expr_truediv
    py_op = _bt_ops_mod.truediv
    return _binary_operation(expr_op, py_op, a, b)
|
751
|
+
|
752
|
+
|
753
|
+
def max_impl(a, b):
    """The maximum function for exactly two operands.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The maximum of `a` and `b`.
    """
    # np.maximum is the fallback used when neither operand is a gstaichi
    # expression.
    return _binary_operation(_ti_core.expr_max, np.maximum, a, b)
|
764
|
+
|
765
|
+
|
766
|
+
def min_impl(a, b):
    """Two-operand minimum.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): A number or a matrix.

    Returns:
        The smaller of `a` and `b`.
    """
    expr_op = _ti_core.expr_min
    py_op = np.minimum
    return _binary_operation(expr_op, py_op, a, b)
|
777
|
+
|
778
|
+
|
779
|
+
def atan2(x1, x2):
    """Element-wise arc tangent of `x1 / x2`, using the signs of both to pick the quadrant.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The y-coordinates.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The x-coordinates.

    Returns:
        Angles in radians within `[-pi, pi]`; a scalar when both inputs are scalars.

    Example::

        >>> from math import pi
        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Matrix([-1.0, 1.0, -1.0, 1.0])
        >>>     y = ti.Matrix([-1.0, -1.0, 1.0, 1.0])
        >>>     z = ti.atan2(y, x) * 180 / pi
        >>>     print(z)
        >>>
        >>> test()
        [-135.0, -45.0, 135.0, 45.0]
    """
    expr_op = _ti_core.expr_atan2
    py_op = np.arctan2
    return _binary_operation(expr_op, py_op, x1, x2)
|
806
|
+
|
807
|
+
|
808
|
+
def _raw_div_python(a, b):
    """Python-side fallback for :func:`raw_div` on non-expression operands.

    Two integers are divided C-style: the quotient is truncated toward zero
    (``-5, 3 -> -1``), consistent with the truncating remainder of
    :func:`raw_mod`. Any other operands use true division ``a / b``.
    """
    if isinstance(a, int) and isinstance(b, int):
        # Python's // rounds toward -inf; when the exact quotient is negative
        # and inexact, step one back toward zero. Stays in exact integer
        # arithmetic (no float round-trip). Avoids `abs`, which this module
        # may shadow.
        quotient = a // b
        if quotient < 0 and quotient * b != a:
            quotient += 1
        return quotient
    return a / b


def raw_div(x1, x2):
    """Return the C-style quotient of `x1` and `x2`.

    If both `x1` and `x2` are integers, the quotient is an integer truncated
    toward zero (as in C). Otherwise the result is the true division `x1 / x2`.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Dividend.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Divisor.

    Returns:
        The truncated integer quotient if both `x1` and `x2` are integers,
        otherwise `x1 / x2`.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     x = 5
        >>>     y = 3
        >>>     print(raw_div(x, y))  # 1
        >>>     z = 4.0
        >>>     print(raw_div(x, z))  # 1.25
        >>>     print(raw_div(-5, 3))  # -1
    """
    # NOTE(review): the fallback truncates so that constant operands agree with
    # the C-style integer division presumably emitted by `_ti_core.expr_div`;
    # confirm against the backend.
    return _binary_operation(_ti_core.expr_div, _raw_div_python, x1, x2)
|
835
|
+
|
836
|
+
|
837
|
+
def _raw_mod_python(x, y):
    """Python-side fallback for :func:`raw_mod`: C-style (truncated) remainder.

    The result has the sign of the dividend `x` (or is zero), e.g. ``-4, 3 -> -1``.
    """
    if isinstance(x, int) and isinstance(y, int):
        # Exact integer path: the generic float-based formula below loses
        # precision once |x| exceeds 2**53. Compute the truncated quotient from
        # Python's floor division (step toward zero when negative and inexact).
        quotient = x // y
        if quotient < 0 and quotient * y != x:
            quotient += 1
        return x - y * quotient
    return x - y * int(float(x) / y)


def raw_mod(x1, x2):
    """Return the remainder of `x1/x2`, element-wise.
    This is the C-style `mod` function: the result takes the sign of `x1`.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The dividend.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The divisor.

    Returns:
        The remainder of `x1` divided by `x2`.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     print(ti.mod(-4, 3))  # 2
        >>>     print(ti.raw_mod(-4, 3))  # -1
    """
    return _binary_operation(_ti_core.expr_mod, _raw_mod_python, x1, x2)
|
862
|
+
|
863
|
+
|
864
|
+
def cmp_lt(a, b):
    """Less-than comparison.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if `a` is strictly smaller than `b`, False otherwise.
    """
    expr_op = _ti_core.expr_cmp_lt
    py_op = _bt_ops_mod.lt
    return _binary_operation(expr_op, py_op, a, b)
|
876
|
+
|
877
|
+
|
878
|
+
def cmp_le(a, b):
    """Less-than-or-equal comparison.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if `a` is smaller than or equal to `b`, False otherwise.
    """
    expr_op = _ti_core.expr_cmp_le
    py_op = _bt_ops_mod.le
    return _binary_operation(expr_op, py_op, a, b)
|
890
|
+
|
891
|
+
|
892
|
+
def cmp_gt(a, b):
    """Greater-than comparison.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if `a` is strictly larger than `b`, False otherwise.
    """
    expr_op = _ti_core.expr_cmp_gt
    py_op = _bt_ops_mod.gt
    return _binary_operation(expr_op, py_op, a, b)
|
904
|
+
|
905
|
+
|
906
|
+
def cmp_ge(a, b):
    """Compare two values (greater than or equal to).

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value LHS
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value RHS

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if LHS is greater than or equal to RHS, False otherwise.
    """
    return _binary_operation(_ti_core.expr_cmp_ge, _bt_ops_mod.ge, a, b)
|
918
|
+
|
919
|
+
|
920
|
+
def cmp_eq(a, b):
    """Equality comparison.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if `a` equals `b`, False otherwise.
    """
    expr_op = _ti_core.expr_cmp_eq
    py_op = _bt_ops_mod.eq
    return _binary_operation(expr_op, py_op, a, b)
|
932
|
+
|
933
|
+
|
934
|
+
def cmp_ne(a, b):
    """Inequality comparison.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: True if `a` differs from `b`, False otherwise.
    """
    expr_op = _ti_core.expr_cmp_ne
    py_op = _bt_ops_mod.ne
    return _binary_operation(expr_op, py_op, a, b)
|
946
|
+
|
947
|
+
|
948
|
+
def bit_or(a, b):
    """Computes bitwise-or.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value LHS
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value RHS

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: LHS bitwise-or with RHS
    """
    return _binary_operation(_ti_core.expr_bit_or, _bt_ops_mod.or_, a, b)
|
960
|
+
|
961
|
+
|
962
|
+
def bit_and(a, b):
    """Compute bitwise-and.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value LHS
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value RHS

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: LHS bitwise-and with RHS
    """
    return _binary_operation(_ti_core.expr_bit_and, _bt_ops_mod.and_, a, b)
|
974
|
+
|
975
|
+
|
976
|
+
def bit_xor(a, b):
    """Compute bitwise-xor.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value LHS
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): value RHS

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: LHS bitwise-xor with RHS
    """
    return _binary_operation(_ti_core.expr_bit_xor, _bt_ops_mod.xor, a, b)
|
988
|
+
|
989
|
+
|
990
|
+
def bit_shl(a, b):
    """Shift `a` left by `b` bits.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Value to shift.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Shift amount.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: `a << b`.
    """
    expr_op = _ti_core.expr_bit_shl
    py_op = _bt_ops_mod.lshift
    return _binary_operation(expr_op, py_op, a, b)
|
1002
|
+
|
1003
|
+
|
1004
|
+
def bit_sar(a, b):
    """Shift `a` right by `b` bits.

    The `sar` name presumably denotes an arithmetic (sign-preserving) shift;
    see :func:`bit_shr` for the other right-shift variant.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Value to shift.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Shift amount.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, int]: `a >> b`.
    """
    expr_op = _ti_core.expr_bit_sar
    py_op = _bt_ops_mod.rshift
    return _binary_operation(expr_op, py_op, a, b)
|
1016
|
+
|
1017
|
+
|
1018
|
+
@gstaichi_scope
def bit_shr(x1, x2):
    """Shift the elements of `x1` right by the number of bits given in `x2`.

    Both `x1` and `x2` must have integer types.

    Args:
        x1 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Input data.
        x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): Number of low bits of `x1` to drop.

    Returns:
        `x1` shifted right `x2` times; a scalar when both inputs are scalars.

    Example::
        >>> @ti.kernel
        >>> def main():
        >>>     x = ti.Matrix([7, 8])
        >>>     y = ti.Matrix([1, 2])
        >>>     print(ti.bit_shr(x, y))
        >>>
        >>> main()
        [3, 2]
    """
    # NOTE(review): the Python-side fallback is `_bt_ops_mod.rshift`; if that is
    # `operator.rshift`, negative ints are sign-extended (arithmetic shift)
    # rather than logically shifted. Confirm this matches the kernel path.
    expr_op = _ti_core.expr_bit_shr
    py_op = _bt_ops_mod.rshift
    return _binary_operation(expr_op, py_op, x1, x2)
|
1044
|
+
|
1045
|
+
|
1046
|
+
def logical_and(a, b):
    """Logical AND of two values.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` logical-and `b` (with short-circuit semantics).
    """

    def py_and(lhs, rhs):
        # Python `and`: yields `lhs` when it is falsy, otherwise `rhs`.
        return lhs and rhs

    return _binary_operation(_ti_core.expr_logical_and, py_and, a, b)
|
1058
|
+
|
1059
|
+
|
1060
|
+
def logical_or(a, b):
    """Logical OR of two values.

    Args:
        a (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Left-hand operand.
        b (Union[:class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`]): Right-hand operand.

    Returns:
        Union[:class:`~gstaichi.lang.expr.Expr`, bool]: `a` logical-or `b` (with short-circuit semantics).
    """

    def py_or(lhs, rhs):
        # Python `or`: yields `lhs` when it is truthy, otherwise `rhs`.
        return lhs or rhs

    return _binary_operation(_ti_core.expr_logical_or, py_or, a, b)
|
1072
|
+
|
1073
|
+
|
1074
|
+
def select(cond, x1, x2):
    """Pick each output element from `x1` or `x2` according to `cond`.

    Args:
        cond (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The conditions.
        x1, x2 (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The candidate values.

    Returns:
        An array whose k-th element is the k-th element of `x1` when the k-th
        element of `cond` is true, and the k-th element of `x2` otherwise.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     cond = ti.Matrix([0, 1, 0, 1])
        >>>     x = ti.Matrix([1, 2, 3, 4])
        >>>     y = ti.Matrix([-1, -2, -3, -4])
        >>>     print(ti.select(cond, x, y))
        >>>
        >>> main()
        [-1, 2, -3, 4]
    """

    def py_select(mask, on_true, on_false):
        # Arithmetic blend of both branches. NOTE(review): this promotes the
        # result to the common type of the branches, and an infinite value in
        # the unselected branch produces NaN (inf * 0).
        return on_true * mask + on_false * (1 - mask)

    # TODO: systematically resolve `-1 = True` problem by introducing u1:
    normalized = logical_not(logical_not(cond))
    return _ternary_operation(_ti_core.expr_select, py_select, normalized, x1, x2)
|
1107
|
+
|
1108
|
+
|
1109
|
+
def ifte(cond, x1, x2):
    """Return `x1` when `cond` is true and `x2` otherwise, with short-circuit semantics.

    This operator guarantees that exactly one of `x1` or `x2` is evaluated.

    Args:
        cond (:mod:`~gstaichi.types.primitive_types`): The condition.
        x1, x2 (:mod:`~gstaichi.types.primitive_types`): The candidate outputs.

    Returns:
        `x1` if `cond` is true, else `x2`.
    """

    def py_ifte(flag, when_true, when_false):
        if flag:
            return when_true
        return when_false

    # TODO: systematically resolve `-1 = True` problem by introducing u1:
    normalized = logical_not(logical_not(cond))
    return _ternary_operation(_ti_core.expr_ifte, py_ifte, normalized, x1, x2)
|
1129
|
+
|
1130
|
+
|
1131
|
+
def _clz_python(x):
    """Python-side fallback for :func:`clz`: leading zeros of `x` as a 32-bit integer.

    Non-negative values give ``32 - bit_length`` (so 0 gives 32), and values of
    ``2**31`` or more give 0. Negative values are two's-complement 32-bit
    integers with the sign bit set, so they also give 0.
    """
    if x < 0:
        return 0
    # Avoid `max` here: this module defines its own `max` that shadows the builtin.
    leading = 32 - int(x).bit_length()
    return leading if leading > 0 else 0


def clz(a):
    """Count the number of leading zeros for a 32bit integer."""
    return _unary_operation(_ti_core.expr_clz, _clz_python, a)
|
1141
|
+
|
1142
|
+
|
1143
|
+
@writeback_binary
def atomic_add(x, y):
    """Atomically add `y` to `x` and return the value `x` held before the update.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.

    Returns:
        The value of `x` prior to the addition.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([0, 0, 0])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_add(x, y)
        >>>     print(x)  # [1, 2, 3], the new value of x
        >>>     print(z)  # [0, 0, 0], the old value of x
        >>>
        >>> ti.atomic_add(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_add(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1171
|
+
|
1172
|
+
|
1173
|
+
@writeback_binary
def atomic_mul(x, y):
    """Atomically compute `x * y`, store the result in `x`,
    and return the old value of `x`.

    `x` must be a writable target, constant expressions or scalars
    are not allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]):
            The input.

    Returns:
        The old value of `x`.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([1, 2, 3])
        >>>     y = ti.Vector([4, 5, 6])
        >>>     z = ti.atomic_mul(x, y)
        >>>     print(x)  # [4, 10, 18], the new value of x
        >>>     print(z)  # [1, 2, 3], the old value of x
        >>>
        >>> ti.atomic_mul(1, x)  # will raise GsTaichiSyntaxError
    """
    return impl.expr_init(expr.Expr(_ti_core.expr_atomic_mul(x.ptr, y.ptr), dbg_info=_ti_core.DebugInfo(stack_info())))
|
1201
|
+
|
1202
|
+
|
1203
|
+
@writeback_binary
def atomic_sub(x, y):
    """Atomically subtract `y` from `x` and return the value `x` held before the update.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.

    Returns:
        The value of `x` prior to the subtraction.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([0, 0, 0])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_sub(x, y)
        >>>     print(x)  # [-1, -2, -3], the new value of x
        >>>     print(z)  # [0, 0, 0], the old value of x
        >>>
        >>> ti.atomic_sub(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_sub(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1231
|
+
|
1232
|
+
|
1233
|
+
@writeback_binary
def atomic_min(x, y):
    """Atomically store the element-wise minimum of `x` and `y` into `x`; return the old `x`.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.

    Returns:
        The value of `x` prior to the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 2
        >>>     y = 1
        >>>     z = ti.atomic_min(x, y)
        >>>     print(x)  # 1, the new value of x
        >>>     print(z)  # 2, the old value of x
        >>>
        >>> ti.atomic_min(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_min(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1261
|
+
|
1262
|
+
|
1263
|
+
@writeback_binary
def atomic_max(x, y):
    """Atomically store the element-wise maximum of `x` and `y` into `x`; return the old `x`.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.

    Returns:
        The value of `x` prior to the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = 1
        >>>     y = 2
        >>>     z = ti.atomic_max(x, y)
        >>>     print(x)  # 2, the new value of x
        >>>     print(z)  # 1, the old value of x
        >>>
        >>> ti.atomic_max(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_max(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1291
|
+
|
1292
|
+
|
1293
|
+
@writeback_binary
def atomic_and(x, y):
    """Atomically store the element-wise bit-wise AND of `x` and `y` into `x`; return the old `x`.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.
            When both are matrices they must have the same shape.

    Returns:
        The value of `x` prior to the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_and(x, y)
        >>>     print(x)  # [1, 0, 1], the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>> ti.atomic_and(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_and(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1323
|
+
|
1324
|
+
|
1325
|
+
@writeback_binary
def atomic_or(x, y):
    """Atomically store the element-wise bit-wise OR of `x` and `y` into `x`; return the old `x`.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.
            When both are matrices they must have the same shape.

    Returns:
        The value of `x` prior to the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_or(x, y)
        >>>     print(x)  # [-1, 2, 3], the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>> ti.atomic_or(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_or(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1355
|
+
|
1356
|
+
|
1357
|
+
@writeback_binary
def atomic_xor(x, y):
    """Atomically store the element-wise bit-wise XOR of `x` and `y` into `x`; return the old `x`.

    `x` must be a writable target; constant expressions and scalars are not
    allowed.

    Args:
        x, y (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The operands.
            When both are matrices they must have the same shape.

    Returns:
        The value of `x` prior to the update.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     x = ti.Vector([-1, 0, 1])
        >>>     y = ti.Vector([1, 2, 3])
        >>>     z = ti.atomic_xor(x, y)
        >>>     print(x)  # [-2, 2, 2], the new value of x
        >>>     print(z)  # [-1, 0, 1], the old value of x
        >>>
        >>> ti.atomic_xor(1, x)  # will raise GsTaichiSyntaxError
    """
    atomic_expr = _ti_core.expr_atomic_bit_xor(x.ptr, y.ptr)
    debug_info = _ti_core.DebugInfo(stack_info())
    return impl.expr_init(expr.Expr(atomic_expr, dbg_info=debug_info))
|
1387
|
+
|
1388
|
+
|
1389
|
+
@writeback_binary
def assign(a, b):
    """Emit an assignment of `b` into `a` through the current AST builder and return `a`."""
    builder = impl.get_runtime().compiling_callable.ast_builder()
    builder.expr_assign(a.ptr, b.ptr, _ti_core.DebugInfo(stack_info()))
    return a
|
1393
|
+
|
1394
|
+
|
1395
|
+
def max(*args):  # pylint: disable=W0622
    """Element-wise maximum of all arguments.

    A single argument is returned unchanged, even if it is array-like. When
    scalars and matrices are mixed, all matrices must share one shape and the
    scalars are broadcast to it.

    Args:
        args: (List[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The inputs.

    Returns:
        The maximum of the inputs.

    Example::

        >>> @ti.kernel
        >>> def foo():
        >>>     x = ti.Vector([0, 1, 2])
        >>>     y = ti.Vector([3, 4, 5])
        >>>     z = ti.max(x, y, 4)
        >>>     print(z)  # [4, 4, 5]
    """
    assert len(args) >= 1
    # Fold from the right so calls nest as max_impl(a0, max_impl(a1, ...)).
    result = args[-1]
    for value in reversed(args[:-1]):
        result = max_impl(value, result)
    return result
|
1425
|
+
|
1426
|
+
|
1427
|
+
def min(*args):  # pylint: disable=W0622
    """Element-wise minimum of all arguments.

    A single argument is returned unchanged, even if it is array-like. When
    scalars and matrices are mixed, all matrices must share one shape and the
    scalars are broadcast to it.

    Args:
        args: (List[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The inputs.

    Returns:
        The minimum of the inputs.

    Example::

        >>> @ti.kernel
        >>> def foo():
        >>>     x = ti.Vector([0, 1, 2])
        >>>     y = ti.Vector([3, 4, 5])
        >>>     z = ti.min(x, y, 1)
        >>>     print(z)  # [0, 1, 1]
    """
    assert len(args) >= 1
    # Fold from the right so calls nest as min_impl(a0, min_impl(a1, ...)).
    result = args[-1]
    for value in reversed(args[:-1]):
        result = min_impl(value, result)
    return result
|
1457
|
+
|
1458
|
+
|
1459
|
+
# Public names exported on ``from ... import *``. Several of them (e.g. ``abs``,
# ``max``, ``min``, ``pow``, ``round``) shadow the Python builtins for star-importers.
__all__ = [
    "acos", "asin", "atan2",
    "atomic_and", "atomic_or", "atomic_xor",
    "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "atomic_mul",
    "bit_cast", "bit_shr",
    "cast", "ceil", "cos", "exp", "floor", "frexp", "log",
    "random", "raw_mod", "raw_div", "round", "rsqrt",
    "sin", "sqrt", "tan", "tanh",
    "max", "min", "select", "abs", "pow",
]
|