gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,310 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from gstaichi.lang import ops
|
6
|
+
from gstaichi.lang.util import in_python_scope
|
7
|
+
from gstaichi.types import primitive_types
|
8
|
+
|
9
|
+
|
10
|
+
class GsTaichiOperations:
    """The base class of gstaichi operations of expressions. Subclasses: :class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`

    Every operator method forwards to the matching function in
    ``gstaichi.lang.ops``. Reflected variants (``__radd__`` and friends) pass
    the operands in swapped order. In-place operators return
    ``NotImplemented`` when ``in_python_scope()`` is true, so Python falls
    back to the corresponding binary operator.
    """

    if TYPE_CHECKING:
        # Make pylint happy
        def __getattr__(self, item):
            pass

    # ------------------------------------------------------------------
    # Unary and binary arithmetic
    # ------------------------------------------------------------------
    def __neg__(self):
        return ops.neg(self)

    def __abs__(self):
        return ops.abs(self)

    def __add__(self, other):
        return ops.add(self, other)

    def __radd__(self, other):
        return ops.add(other, self)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __rsub__(self, other):
        return ops.sub(other, self)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __rmul__(self, other):
        return ops.mul(other, self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __rtruediv__(self, other):
        return ops.truediv(other, self)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __rfloordiv__(self, other):
        return ops.floordiv(other, self)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __rmod__(self, other):
        return ops.mod(other, self)

    # ``modulo`` is accepted for signature compatibility with the three-argument
    # form of ``pow`` but is not used.
    def __pow__(self, other, modulo=None):
        return ops.pow(self, other)

    def __rpow__(self, other, modulo=None):
        return ops.pow(other, self)

    # ------------------------------------------------------------------
    # Comparisons (each returns whatever the ``ops.cmp_*`` function returns)
    # ------------------------------------------------------------------
    def __le__(self, other):
        return ops.cmp_le(self, other)

    def __lt__(self, other):
        return ops.cmp_lt(self, other)

    def __ge__(self, other):
        return ops.cmp_ge(self, other)

    def __gt__(self, other):
        return ops.cmp_gt(self, other)

    def __eq__(self, other):
        return ops.cmp_eq(self, other)

    def __ne__(self, other):
        return ops.cmp_ne(self, other)

    # ------------------------------------------------------------------
    # Bitwise and logical operators
    # ------------------------------------------------------------------
    def __and__(self, other):
        return ops.bit_and(self, other)

    def __rand__(self, other):
        return ops.bit_and(other, self)

    def __or__(self, other):
        return ops.bit_or(self, other)

    def __ror__(self, other):
        return ops.bit_or(other, self)

    def __xor__(self, other):
        return ops.bit_xor(self, other)

    def __rxor__(self, other):
        return ops.bit_xor(other, self)

    def __lshift__(self, other):
        return ops.bit_shl(self, other)

    def __rlshift__(self, other):
        return ops.bit_shl(other, self)

    def __rshift__(self, other):
        return ops.bit_sar(self, other)

    def __rrshift__(self, other):
        return ops.bit_sar(other, self)

    def __invert__(self):  # ~a => a.__invert__()
        return ops.bit_not(self)

    def __not__(self):  # not a => a.__not__()
        return ops.logical_not(self)

    # ------------------------------------------------------------------
    # Atomic helpers
    # ------------------------------------------------------------------
    def _atomic_add(self, other):
        """Return the new expression of computing atomic add between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic add."""
        return ops.atomic_add(self, other)

    def _atomic_mul(self, other):
        """Return the new expression of computing atomic mul between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic mul."""
        return ops.atomic_mul(self, other)

    def _atomic_sub(self, other):
        """Return the new expression of computing atomic sub between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic sub."""
        return ops.atomic_sub(self, other)

    def _atomic_and(self, other):
        """Return the new expression of computing atomic and between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic and."""
        return ops.atomic_and(self, other)

    def _atomic_xor(self, other):
        """Return the new expression of computing atomic xor between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic xor."""
        return ops.atomic_xor(self, other)

    def _atomic_or(self, other):
        """Return the new expression of computing atomic or between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic or."""
        return ops.atomic_or(self, other)

    # ------------------------------------------------------------------
    # In-place operators backed by an atomic operation.
    # In Python scope they return NotImplemented to fall back to the normal
    # binary operators.
    # ------------------------------------------------------------------
    def __iadd__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_add(other)
        return self

    def __isub__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_sub(other)
        return self

    def __iand__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_and(other)
        return self

    def __ixor__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_xor(other)
        return self

    def __ior__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_or(other)
        return self

    # ------------------------------------------------------------------
    # In-place operators implemented as "compute, then assign".
    # we don't support atomic_mul/truediv/floordiv/mod yet:
    #
    # NOTE: this class used to define ``__imul__`` twice; the earlier,
    # ``_atomic_mul``-based definition was shadowed by the one below and never
    # ran, so only the effective definition is kept.
    # ------------------------------------------------------------------
    def __imul__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.mul(self, other))
        return self

    def __itruediv__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.truediv(self, other))
        return self

    def __ifloordiv__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.floordiv(self, other))
        return self

    def __imod__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.mod(self, other))
        return self

    def __ilshift__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.bit_shl(self, other))
        return self

    def __irshift__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.bit_sar(self, other))
        return self

    def __ipow__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.pow(self, other))
        return self

    def _assign(self, other):
        """Assign the expression of the given operand to self.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The expression after assigning."""
        return ops.assign(self, other)

    def _augassign(self, x, op):
        """Generate the computing expression between self and the given operand of given operator and assigned to self.

        Args:
            x (Any): Given operand.
            op (str): The name of operator.

        Raises:
            AssertionError: If ``op`` is not one of the recognised operator
                names; the offending name is the exception argument."""
        if op == "Add":
            self += x
        elif op == "Sub":
            self -= x
        elif op == "Mult":
            self *= x
        elif op == "Div":
            self /= x
        elif op == "FloorDiv":
            self //= x
        elif op == "Mod":
            self %= x
        elif op == "BitAnd":
            self &= x
        elif op == "BitOr":
            self |= x
        elif op == "BitXor":
            self ^= x
        elif op == "RShift":
            self >>= x
        elif op == "LShift":
            self <<= x
        elif op == "Pow":
            self **= x
        else:
            # Raised explicitly (not via ``assert``) so the check still fires
            # when Python runs with -O.
            raise AssertionError(op)

    # ------------------------------------------------------------------
    # Casting hooks
    # ------------------------------------------------------------------
    def __ti_int__(self):
        return ops.cast(self, int)

    def __ti_bool__(self):
        return ops.cast(self, primitive_types.u1)

    def __ti_float__(self):
        return ops.cast(self, float)
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._lib import core
|
4
|
+
|
5
|
+
|
6
|
+
class GsTaichiCompilationError(Exception):
    """Root of the compilation-time exception hierarchy.

    The syntax, name, index and type error classes in this module all derive
    from it, so catching this class catches any of them.
    """
|
10
|
+
|
11
|
+
|
12
|
+
class GsTaichiSyntaxError(GsTaichiCompilationError, SyntaxError):
    """Raised when compilation encounters a syntax error.

    Also a builtin ``SyntaxError``, so handlers for either type catch it.
    """
|
16
|
+
|
17
|
+
|
18
|
+
class GsTaichiNameError(GsTaichiCompilationError, NameError):
    """Raised when compilation encounters an undefined name.

    Also a builtin ``NameError``, so handlers for either type catch it.
    """
|
22
|
+
|
23
|
+
|
24
|
+
class GsTaichiIndexError(GsTaichiCompilationError, IndexError):
    """Raised when compilation encounters an index error.

    Also a builtin ``IndexError``, so handlers for either type catch it.
    """
|
28
|
+
|
29
|
+
|
30
|
+
class GsTaichiTypeError(GsTaichiCompilationError, TypeError):
    """Raised when compilation encounters a type mismatch.

    Also a builtin ``TypeError``, so handlers for either type catch it.
    """
|
34
|
+
|
35
|
+
|
36
|
+
class GsTaichiRuntimeError(RuntimeError):
    """Raised when the compiled program cannot be executed due to unspecified reasons.

    Derives from the builtin ``RuntimeError`` and is not part of the
    compilation-error hierarchy.
    """
|
40
|
+
|
41
|
+
|
42
|
+
class GsTaichiAssertionError(GsTaichiRuntimeError, AssertionError):
    """Raised when an assertion fails at runtime.

    Also a builtin ``AssertionError``, so handlers for either type catch it.
    """
|
46
|
+
|
47
|
+
|
48
|
+
class GsTaichiRuntimeTypeError(GsTaichiRuntimeError, TypeError):
    """Runtime type-conversion failure for an argument or a return value.

    The static factories build an instance with a formatted message; they
    return the exception and leave raising it to the caller.
    """

    @staticmethod
    def get(pos, needed, provided):
        """Build the error for an argument that cannot be converted.

        Args:
            pos: Identifier of the argument, interpolated after "Argument".
            needed: The required type.
            provided: The type that was actually supplied.

        Returns:
            GsTaichiRuntimeTypeError: The exception instance (not raised).
        """
        message = f"Argument {pos} (type={provided}) cannot be converted into required type {needed}"
        return GsTaichiRuntimeTypeError(message)

    @staticmethod
    def get_ret(needed, provided):
        """Build the error for a return value that cannot be converted.

        Args:
            needed: The required type.
            provided: The type that was actually returned.

        Returns:
            GsTaichiRuntimeTypeError: The exception instance (not raised).
        """
        message = f"Return (type={provided}) cannot be converted into required type {needed}"
        return GsTaichiRuntimeTypeError(message)
|
58
|
+
|
59
|
+
|
60
|
+
def handle_exception_from_cpp(exc):
    """Translate an exception from the ``core`` module into its Python twin.

    ``exc`` is tested against ``core.GsTaichiTypeError``,
    ``core.GsTaichiSyntaxError``, ``core.GsTaichiIndexError`` and
    ``core.GsTaichiAssertionError``, in that order. The first match yields a
    new instance of the same-named class from this module, carrying
    ``str(exc)`` as its message.

    Args:
        exc: The exception object to translate.

    Returns:
        The translated exception, or ``exc`` itself when none of the four
        ``core`` types match. Nothing is raised here.
    """
    # Ordered (attribute on ``core``, Python class) pairs. The ``core``
    # attribute is looked up lazily, one candidate at a time.
    translations = (
        ("GsTaichiTypeError", GsTaichiTypeError),
        ("GsTaichiSyntaxError", GsTaichiSyntaxError),
        ("GsTaichiIndexError", GsTaichiIndexError),
        ("GsTaichiAssertionError", GsTaichiAssertionError),
    )
    for core_name, python_cls in translations:
        if isinstance(exc, getattr(core, core_name)):
            return python_cls(str(exc))
    return exc
|
70
|
+
|
71
|
+
|
72
|
+
# Public exception types of this module. ``GsTaichiIndexError`` is listed
# alongside its siblings: it is defined here and can be returned by
# ``handle_exception_from_cpp``, so callers need to be able to import it the
# same way as the other compilation errors.
__all__ = [
    "GsTaichiSyntaxError",
    "GsTaichiTypeError",
    "GsTaichiCompilationError",
    "GsTaichiNameError",
    "GsTaichiIndexError",
    "GsTaichiRuntimeError",
    "GsTaichiRuntimeTypeError",
    "GsTaichiAssertionError",
]
|
gstaichi/lang/expr.py
ADDED
@@ -0,0 +1,180 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi.lang import impl
|
7
|
+
from gstaichi.lang.common_ops import GsTaichiOperations
|
8
|
+
from gstaichi.lang.exception import GsTaichiCompilationError, GsTaichiTypeError
|
9
|
+
from gstaichi.lang.matrix import make_matrix
|
10
|
+
from gstaichi.lang.util import is_gstaichi_class, is_matrix_class, to_numpy_type
|
11
|
+
from gstaichi.types import primitive_types
|
12
|
+
from gstaichi.types.primitive_types import integer_types, real_types
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from gstaichi.lang.ast.ast_transformer_utils import ASTBuilder
|
16
|
+
|
17
|
+
|
18
|
+
# Scalar, basic data type
|
19
|
+
class Expr(GsTaichiOperations):
    """A Python-side Expr wrapper, whose member variable `ptr` is an instance of C++ Expr class. A C++ Expr object contains member variable `expr` which holds an instance of C++ Expression class."""

    def __init__(self, *args, dbg_info=None, dtype=None):
        """Wrap a single value as an expression.

        Exactly one positional argument is required. It is handled as follows:

        * ``_ti_core.ExprCxx``: used directly as ``ptr``.
        * ``Expr``: its ``ptr``, ``ptr_type_checked`` and ``dbg_info`` are
          copied (replacing the ``dbg_info`` keyword value).
        * matrix-class value: converted via ``make_matrix(value.to_list())``.
        * ``list`` / ``tuple``: converted via ``make_matrix(value)``.
        * anything else: treated as a scalar constant and passed to
          ``make_constant_expr`` together with ``dtype``.

        Args:
            *args: The single value to wrap.
            dbg_info: Optional debug info; when truthy it is set on ``ptr``.
            dtype: Data type used only for the scalar-constant case.

        Raises:
            AssertionError: If the number of positional arguments is not one.
            GsTaichiTypeError: If a numpy array with a non-empty shape is
                given as a scalar constant.
        """
        self.dbg_info = dbg_info
        self.ptr_type_checked = False
        self.declaration_tb: str = ""

        if len(args) != 1:
            # Raised explicitly rather than via ``assert False`` so that the
            # check is not stripped under -O (which would leave ``self.ptr``
            # unset and fail later with an unrelated AttributeError).
            raise AssertionError(f"Expr expects exactly one positional argument, got {len(args)}")

        source = args[0]
        if isinstance(source, _ti_core.ExprCxx):
            self.ptr = source
        elif isinstance(source, Expr):
            self.ptr = source.ptr
            self.ptr_type_checked = source.ptr_type_checked
            self.dbg_info = source.dbg_info
        elif is_matrix_class(source):
            self.ptr = make_matrix(source.to_list()).ptr
        elif isinstance(source, (list, tuple)):
            self.ptr = make_matrix(source).ptr
        else:
            # assume to be constant
            arg = source
            if isinstance(arg, np.ndarray):
                if arg.shape:
                    raise GsTaichiTypeError(
                        "Only 0-dimensional numpy array can be used to initialize a scalar expression"
                    )
                # Unwrap the 0-d array into a numpy scalar of the same dtype.
                arg = arg.dtype.type(arg)
            self.ptr = make_constant_expr(arg, dtype).ptr

        if self.dbg_info:
            self.ptr.set_dbg_info(self.dbg_info)
        if not self.ptr_type_checked:
            self.ptr.type_check(impl.get_runtime().prog.config())
            self.ptr_type_checked = True

    def is_tensor(self):
        """Return ``self.ptr.is_tensor()``."""
        return self.ptr.is_tensor()

    def is_struct(self):
        """Return ``self.ptr.is_struct()``."""
        return self.ptr.is_struct()

    def element_type(self):
        """Return the element type of the expression's rvalue type."""
        return self.ptr.get_rvalue_type().element_type()

    def get_shape(self):
        """Return the shape of a tensor expression as a tuple.

        Raises:
            GsTaichiCompilationError: If the expression is not a tensor.
        """
        if not self.is_tensor():
            raise GsTaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")
        shape = self.ptr.get_shape()
        assert shape is not None
        return tuple(shape)

    @property
    def n(self):
        """First entry of the shape.

        Raises:
            GsTaichiCompilationError: If the expression is not a tensor or its
                shape has no entries.
        """
        shape = self.get_shape()
        if len(shape) < 1:
            raise GsTaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")
        return shape[0]

    @property
    def m(self):
        """Second entry of the shape.

        Raises:
            GsTaichiCompilationError: If the expression is not a tensor or its
                shape has fewer than two entries.
        """
        shape = self.get_shape()
        if len(shape) < 2:
            raise GsTaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")
        return shape[1]

    def __hash__(self):
        # Hash by the raw address reported by the wrapped C++ object.
        return self.ptr.get_raw_address()

    def __str__(self):
        return "<ti.Expr>"

    def __repr__(self):
        return "<ti.Expr>"
|
93
|
+
|
94
|
+
|
95
|
+
def _check_in_range(npty, val):
    """Tell whether ``val`` lies within the limits of integer type ``npty``.

    Args:
        npty: An integer type accepted by ``np.iinfo``.
        val: The value to test.

    Returns:
        The result of ``min <= val <= max`` for the limits of ``npty``.
    """
    limits = np.iinfo(npty)
    lowest, highest = limits.min, limits.max
    return lowest <= val <= highest
|
98
|
+
|
99
|
+
|
100
|
+
def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int:
|
101
|
+
# npty: np.int32 or np.int64
|
102
|
+
iif = np.iinfo(npty)
|
103
|
+
if iif.min <= val <= iif.max:
|
104
|
+
return val
|
105
|
+
cap = 1 << iif.bits
|
106
|
+
assert 0 <= val < cap
|
107
|
+
new_val = val - cap
|
108
|
+
return new_val
|
109
|
+
|
110
|
+
|
111
|
+
def make_constant_expr(val, dtype):
    """Build a constant :class:`Expr` from a Python or numpy scalar.

    Args:
        val: A bool, float or int (or the numpy equivalents).
        dtype: Requested data type, or ``None`` to use the runtime's
            ``default_fp`` / ``default_ip``. Ignored for booleans, which
            always use ``primitive_types.u1``.

    Returns:
        Expr: The constant expression.

    Raises:
        GsTaichiTypeError: If ``val`` is not a supported scalar, if the
            chosen data type does not match the kind of literal, or if an
            integer literal does not fit the chosen data type.
    """
    # Booleans are tested first because Python's ``bool`` is an ``int``
    # subclass and would otherwise be caught by the integer branch.
    if isinstance(val, (bool, np.bool_)):
        return Expr(_ti_core.make_const_expr_bool(primitive_types.u1, val))

    if isinstance(val, (float, np.floating)):
        fp_dtype = dtype if dtype is not None else impl.get_runtime().default_fp
        if fp_dtype not in real_types:
            raise GsTaichiTypeError(
                "Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`."
            )
        return Expr(_ti_core.make_const_expr_fp(fp_dtype, val))

    if not isinstance(val, (int, np.integer)):
        raise GsTaichiTypeError(f"Invalid constant scalar data type: {type(val)}")

    int_dtype = dtype if dtype is not None else impl.get_runtime().default_ip
    if int_dtype not in integer_types:
        raise GsTaichiTypeError(
            "Integer literals must be annotated with a integer type. For type casting, use `ti.cast`."
        )

    if _check_in_range(to_numpy_type(int_dtype), val):
        # The value is handed over after being mapped into the np.int64 range.
        signed_val = _clamp_unsigned_to_range(np.int64, val)
        return Expr(_ti_core.make_const_expr_int(int_dtype, signed_val))

    # Out of range: the message depends on whether the dtype was defaulted.
    if dtype is None:
        raise GsTaichiTypeError(
            f"Integer literal {val} exceeded the range of default_ip: {impl.get_runtime().default_ip}, please specify the dtype via e.g. `ti.u64({val})` or set a different `default_ip` in `ti.init()`"
        )
    raise GsTaichiTypeError(f"Integer literal {val} exceeded the range of specified dtype: {dtype}")
|
140
|
+
|
141
|
+
|
142
|
+
def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
    """Create ``size`` id expressions, each made with an empty name.

    Args:
        size: Number of expressions to create.
        ast_builder: When given, its ``make_id_expr`` is used; otherwise the
            ``make_id_expr`` of the current runtime's program is used.

    Returns:
        list: The created expressions, in creation order.
    """
    # The runtime program is fetched unconditionally, as before, even when an
    # ast_builder is supplied.
    prog = impl.get_runtime().prog
    factory = prog if ast_builder is None else ast_builder
    return [factory.make_id_expr("") for _ in range(size)]
|
151
|
+
|
152
|
+
|
153
|
+
def make_expr_group(*exprs):
    """Collect values into a ``_ti_core.ExprGroup``.

    When called with a single list or tuple, its items are used. When called
    with a single ``Matrix``, its ``entries`` are used and ``m`` must be 1.
    Each value is flattened with ``_get_flattened_ptrs`` and every resulting
    pointer is pushed onto the group in order.

    Args:
        *exprs: The values to group, or one list/tuple/Matrix holding them.

    Returns:
        The populated ``_ti_core.ExprGroup``.
    """
    from gstaichi.lang.matrix import Matrix  # pylint: disable=C0415

    if len(exprs) == 1:
        only = exprs[0]
        if isinstance(only, (list, tuple)):
            exprs = only
        elif isinstance(only, Matrix):
            assert only.m == 1
            exprs = only.entries

    group = _ti_core.ExprGroup()
    for entry in exprs:
        for ptr in _get_flattened_ptrs(entry):
            group.push_back(ptr)
    return group
|
169
|
+
|
170
|
+
|
171
|
+
def _get_flattened_ptrs(val):
    """Return the expression pointers of ``val`` as a flat list.

    A value for which ``is_gstaichi_class`` is true is expanded recursively
    through its ``_members``; any other value is wrapped in :class:`Expr` and
    contributes that expression's ``ptr``.
    """
    if not is_gstaichi_class(val):
        return [Expr(val).ptr]
    return [ptr for member in val._members for ptr in _get_flattened_ptrs(member)]
|
178
|
+
|
179
|
+
|
180
|
+
# Empty export list: a star-import of this module brings in no names.
__all__ = []
|