gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,310 @@
1
+ # type: ignore
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from gstaichi.lang import ops
6
+ from gstaichi.lang.util import in_python_scope
7
+ from gstaichi.types import primitive_types
8
+
9
+
10
class GsTaichiOperations:
    """The base class of gstaichi operations of expressions. Subclasses: :class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`

    Each operator method forwards to the function of the corresponding name in
    :mod:`gstaichi.lang.ops`. Because ``__eq__`` is overridden here without
    ``__hash__``, instances are unhashable unless a subclass defines
    ``__hash__`` itself.
    """

    if TYPE_CHECKING:
        # Make pylint happy
        def __getattr__(self, item):
            pass

    # ---- unary operators -------------------------------------------------

    def __neg__(self):
        return ops.neg(self)

    def __abs__(self):
        return ops.abs(self)

    # ---- binary arithmetic (and reflected variants) ------------------------

    def __add__(self, other):
        return ops.add(self, other)

    def __radd__(self, other):
        return ops.add(other, self)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __rsub__(self, other):
        return ops.sub(other, self)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __rmul__(self, other):
        return ops.mul(other, self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __rtruediv__(self, other):
        return ops.truediv(other, self)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __rfloordiv__(self, other):
        return ops.floordiv(other, self)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __rmod__(self, other):
        return ops.mod(other, self)

    def __pow__(self, other, modulo=None):
        # `modulo` is accepted so the signature matches Python's three-argument
        # pow protocol, but it is ignored.
        return ops.pow(self, other)

    def __rpow__(self, other, modulo=None):
        # `modulo` is ignored, see __pow__.
        return ops.pow(other, self)

    # ---- comparisons -------------------------------------------------------

    def __le__(self, other):
        return ops.cmp_le(self, other)

    def __lt__(self, other):
        return ops.cmp_lt(self, other)

    def __ge__(self, other):
        return ops.cmp_ge(self, other)

    def __gt__(self, other):
        return ops.cmp_gt(self, other)

    def __eq__(self, other):
        return ops.cmp_eq(self, other)

    def __ne__(self, other):
        return ops.cmp_ne(self, other)

    # ---- bitwise operators (and reflected variants) ------------------------

    def __and__(self, other):
        return ops.bit_and(self, other)

    def __rand__(self, other):
        return ops.bit_and(other, self)

    def __or__(self, other):
        return ops.bit_or(self, other)

    def __ror__(self, other):
        return ops.bit_or(other, self)

    def __xor__(self, other):
        return ops.bit_xor(self, other)

    def __rxor__(self, other):
        return ops.bit_xor(other, self)

    def __lshift__(self, other):
        return ops.bit_shl(self, other)

    def __rlshift__(self, other):
        return ops.bit_shl(other, self)

    def __rshift__(self, other):
        return ops.bit_sar(self, other)

    def __rrshift__(self, other):
        return ops.bit_sar(other, self)

    def __invert__(self):  # ~a => a.__invert__()
        return ops.bit_not(self)

    def __not__(self):
        # NOTE: Python's `not` operator never dispatches to a `__not__` method,
        # so this is only reached when called explicitly by name
        # (presumably from the AST transformer - TODO confirm).
        return ops.logical_not(self)

    # ---- atomic helpers ----------------------------------------------------

    def _atomic_add(self, other):
        """Return the new expression of computing atomic add between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic add."""
        return ops.atomic_add(self, other)

    def _atomic_mul(self, other):
        """Return the new expression of computing atomic mul between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic mul."""
        return ops.atomic_mul(self, other)

    def _atomic_sub(self, other):
        """Return the new expression of computing atomic sub between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic sub."""
        return ops.atomic_sub(self, other)

    def _atomic_and(self, other):
        """Return the new expression of computing atomic and between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic and."""
        return ops.atomic_and(self, other)

    def _atomic_xor(self, other):
        """Return the new expression of computing atomic xor between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic xor."""
        return ops.atomic_xor(self, other)

    def _atomic_or(self, other):
        """Return the new expression of computing atomic or between self and a given operand.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic or."""
        return ops.atomic_or(self, other)

    # ---- in-place operators ------------------------------------------------
    # In Python scope every in-place operator returns NotImplemented so that
    # Python falls back to the normal binary operator.

    # Lowered to the matching `_atomic_*` helper:
    def __iadd__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_add(other)
        return self

    def __isub__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_sub(other)
        return self

    def __iand__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_and(other)
        return self

    def __ixor__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_xor(other)
        return self

    def __ior__(self, other):
        if in_python_scope():
            return NotImplemented
        self._atomic_or(other)
        return self

    # Lowered to `_assign(ops.<op>(self, other))` instead of an atomic helper.
    # NOTE(review): this class used to define `__imul__` twice; the earlier
    # definition (which called `_atomic_mul`) was always overridden by this one
    # and has been removed as dead code. Runtime behavior is unchanged.
    def __imul__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.mul(self, other))
        return self

    def __itruediv__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.truediv(self, other))
        return self

    def __ifloordiv__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.floordiv(self, other))
        return self

    def __imod__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.mod(self, other))
        return self

    def __ilshift__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.bit_shl(self, other))
        return self

    def __irshift__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.bit_sar(self, other))
        return self

    def __ipow__(self, other):
        if in_python_scope():
            return NotImplemented
        self._assign(ops.pow(self, other))
        return self

    def _assign(self, other):
        """Assign the expression of the given operand to self.

        Args:
            other (Any): Given operand.

        Returns:
            :class:`~gstaichi.lang.expr.Expr`: The expression after assigning."""
        return ops.assign(self, other)

    def _augassign(self, x, op):
        """Generate the computing expression between self and the given operand of given operator and assigned to self.

        Args:
            x (Any): Given operand.
            op (str): The name of operator.

        Raises:
            AssertionError: If ``op`` is not one of the supported operator names."""
        # Each branch rebinds only the local name `self`; the effect comes from
        # the in-place dunder that the augmented assignment invokes.
        if op == "Add":
            self += x
        elif op == "Sub":
            self -= x
        elif op == "Mult":
            self *= x
        elif op == "Div":
            self /= x
        elif op == "FloorDiv":
            self //= x
        elif op == "Mod":
            self %= x
        elif op == "BitAnd":
            self &= x
        elif op == "BitOr":
            self |= x
        elif op == "BitXor":
            self ^= x
        elif op == "RShift":
            self >>= x
        elif op == "LShift":
            self <<= x
        elif op == "Pow":
            self **= x
        else:
            # Raised explicitly (instead of `assert False, op`) so that an
            # unsupported operator is still reported when running under -O.
            raise AssertionError(op)

    def __ti_int__(self):
        """Return ``ops.cast(self, int)``."""
        return ops.cast(self, int)

    def __ti_bool__(self):
        """Return ``ops.cast(self, primitive_types.u1)``."""
        return ops.cast(self, primitive_types.u1)

    def __ti_float__(self):
        """Return ``ops.cast(self, float)``."""
        return ops.cast(self, float)
@@ -0,0 +1,80 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core
4
+
5
+
6
class GsTaichiCompilationError(Exception):
    """Common ancestor of the compilation exceptions defined in this module."""
10
+
11
+
12
class GsTaichiSyntaxError(GsTaichiCompilationError, SyntaxError):
    """Compilation error signalling that a syntax error was found."""
16
+
17
+
18
class GsTaichiNameError(GsTaichiCompilationError, NameError):
    """Compilation error signalling that an undefined name was found."""
22
+
23
+
24
class GsTaichiIndexError(GsTaichiCompilationError, IndexError):
    """Compilation error signalling that an index error was found."""
28
+
29
+
30
class GsTaichiTypeError(GsTaichiCompilationError, TypeError):
    """Compilation error signalling that a type mismatch was found."""
34
+
35
+
36
class GsTaichiRuntimeError(RuntimeError):
    """Raised when the compiled program cannot be executed due to unspecified reasons."""
40
+
41
+
42
class GsTaichiAssertionError(GsTaichiRuntimeError, AssertionError):
    """Runtime error signalling that an assertion failed."""
46
+
47
+
48
class GsTaichiRuntimeTypeError(GsTaichiRuntimeError, TypeError):
    """Runtime type error with factory helpers for conversion failures."""

    @staticmethod
    def get(pos, needed, provided):
        """Build the error for an argument that cannot be converted.

        Args:
            pos: Value shown as the argument identifier in the message.
            needed: The required type, as shown in the message.
            provided: The type that was actually supplied.

        Returns:
            GsTaichiRuntimeTypeError: The (not yet raised) exception instance.
        """
        message = f"Argument {pos} (type={provided}) cannot be converted into required type {needed}"
        return GsTaichiRuntimeTypeError(message)

    @staticmethod
    def get_ret(needed, provided):
        """Build the error for a return value that cannot be converted.

        Args:
            needed: The required type, as shown in the message.
            provided: The type that was actually supplied.

        Returns:
            GsTaichiRuntimeTypeError: The (not yet raised) exception instance.
        """
        message = f"Return (type={provided}) cannot be converted into required type {needed}"
        return GsTaichiRuntimeTypeError(message)
58
+
59
+
60
def handle_exception_from_cpp(exc):
    """Translate an exception coming from ``core`` into its Python-side class.

    Args:
        exc: The exception object to translate.

    Returns:
        A new Python-side exception carrying ``str(exc)`` as its message when
        ``exc`` is an instance of one of the mapped ``core`` exception classes;
        otherwise ``exc`` itself, unchanged.
    """
    # Checked in this order; each `core` attribute is looked up only when its
    # entry is reached, exactly like a chain of `if isinstance(...)` tests.
    translations = (
        ("GsTaichiTypeError", GsTaichiTypeError),
        ("GsTaichiSyntaxError", GsTaichiSyntaxError),
        ("GsTaichiIndexError", GsTaichiIndexError),
        ("GsTaichiAssertionError", GsTaichiAssertionError),
    )
    for core_name, python_cls in translations:
        if isinstance(exc, getattr(core, core_name)):
            return python_cls(str(exc))
    return exc
70
+
71
+
72
# Names exported by `from ... import *`. GsTaichiIndexError is listed alongside
# the other public exception classes defined in this module.
__all__ = [
    "GsTaichiSyntaxError",
    "GsTaichiTypeError",
    "GsTaichiCompilationError",
    "GsTaichiNameError",
    "GsTaichiIndexError",
    "GsTaichiRuntimeError",
    "GsTaichiRuntimeTypeError",
    "GsTaichiAssertionError",
]
gstaichi/lang/expr.py ADDED
@@ -0,0 +1,180 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import numpy as np
4
+
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi.lang import impl
7
+ from gstaichi.lang.common_ops import GsTaichiOperations
8
+ from gstaichi.lang.exception import GsTaichiCompilationError, GsTaichiTypeError
9
+ from gstaichi.lang.matrix import make_matrix
10
+ from gstaichi.lang.util import is_gstaichi_class, is_matrix_class, to_numpy_type
11
+ from gstaichi.types import primitive_types
12
+ from gstaichi.types.primitive_types import integer_types, real_types
13
+
14
+ if TYPE_CHECKING:
15
+ from gstaichi.lang.ast.ast_transformer_utils import ASTBuilder
16
+
17
+
18
+ # Scalar, basic data type
19
class Expr(GsTaichiOperations):
    """Python-side wrapper whose ``ptr`` member is an instance of the C++ Expr class.

    The C++ Expr object in turn has a member ``expr`` holding an instance of
    the C++ Expression class.
    """

    def __init__(self, *args, dbg_info=None, dtype=None):
        """Wrap a single positional value as an expression.

        Args:
            *args: Exactly one value: a C++ expr, another :class:`Expr`, a
                matrix-class object, a list/tuple, or a constant scalar
                (0-dimensional numpy arrays are accepted as scalars).
            dbg_info: Optional debug info attached to the underlying pointer.
            dtype: Optional data type forwarded to :func:`make_constant_expr`
                when the value is treated as a constant.

        Raises:
            GsTaichiTypeError: If a numpy array with a non-empty shape is given.
        """
        self.dbg_info = dbg_info
        self.ptr_type_checked = False
        self.declaration_tb: str = ""
        if len(args) != 1:
            assert False
        else:
            source = args[0]
            if isinstance(source, _ti_core.ExprCxx):
                self.ptr = source
            elif isinstance(source, Expr):
                # Share the pointer and inherit the source's bookkeeping.
                self.ptr = source.ptr
                self.ptr_type_checked = source.ptr_type_checked
                self.dbg_info = source.dbg_info
            elif is_matrix_class(source):
                self.ptr = make_matrix(source.to_list()).ptr
            elif isinstance(source, (list, tuple)):
                self.ptr = make_matrix(source).ptr
            else:
                # Anything else is assumed to be a constant.
                if isinstance(source, np.ndarray):
                    if source.shape:
                        raise GsTaichiTypeError(
                            "Only 0-dimensional numpy array can be used to initialize a scalar expression"
                        )
                    # Unwrap the 0-d array into a numpy scalar of its own dtype.
                    source = source.dtype.type(source)
                self.ptr = make_constant_expr(source, dtype).ptr
        if self.dbg_info:
            self.ptr.set_dbg_info(self.dbg_info)
        if not self.ptr_type_checked:
            self.ptr.type_check(impl.get_runtime().prog.config())
            self.ptr_type_checked = True

    def is_tensor(self):
        """Return ``self.ptr.is_tensor()``."""
        return self.ptr.is_tensor()

    def is_struct(self):
        """Return ``self.ptr.is_struct()``."""
        return self.ptr.is_struct()

    def element_type(self):
        """Return the element type of the pointer's rvalue type."""
        rvalue_type = self.ptr.get_rvalue_type()
        return rvalue_type.element_type()

    def get_shape(self):
        """Return the tensor shape as a tuple.

        Raises:
            GsTaichiCompilationError: If this expression is not a tensor.
        """
        if self.is_tensor():
            dims = self.ptr.get_shape()
            assert dims is not None
            return tuple(dims)
        raise GsTaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")

    @property
    def n(self):
        """First entry of the shape; raises if the shape is empty."""
        dims = self.get_shape()
        if len(dims) >= 1:
            return dims[0]
        raise GsTaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")

    @property
    def m(self):
        """Second entry of the shape; raises if the shape has fewer than two entries."""
        dims = self.get_shape()
        if len(dims) >= 2:
            return dims[1]
        raise GsTaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")

    def __hash__(self):
        # Hash by the raw address reported by the underlying pointer.
        return self.ptr.get_raw_address()

    def __str__(self):
        return "<ti.Expr>"

    def __repr__(self):
        return "<ti.Expr>"
93
+
94
+
95
+ def _check_in_range(npty, val):
96
+ iif = np.iinfo(npty)
97
+ return iif.min <= val <= iif.max
98
+
99
+
100
+ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int:
101
+ # npty: np.int32 or np.int64
102
+ iif = np.iinfo(npty)
103
+ if iif.min <= val <= iif.max:
104
+ return val
105
+ cap = 1 << iif.bits
106
+ assert 0 <= val < cap
107
+ new_val = val - cap
108
+ return new_val
109
+
110
+
111
def make_constant_expr(val, dtype):
    """Create a constant :class:`Expr` from a Python or numpy scalar.

    Args:
        val: A bool, float or int (Python or numpy flavour).
        dtype: Requested data type, or ``None`` to use the runtime default
            (``default_fp`` for floats, ``default_ip`` for ints). Ignored for
            booleans, which always use ``primitive_types.u1``.

    Returns:
        Expr: The constant expression.

    Raises:
        GsTaichiTypeError: If the dtype does not fit the literal kind, an
            integer literal is out of range for the dtype, or ``val`` is not a
            supported scalar type.
    """
    # bool is tested first because Python's bool is a subclass of int.
    if isinstance(val, (bool, np.bool_)):
        return Expr(_ti_core.make_const_expr_bool(primitive_types.u1, val))

    if isinstance(val, (float, np.floating)):
        constant_dtype = dtype if dtype is not None else impl.get_runtime().default_fp
        if constant_dtype in real_types:
            return Expr(_ti_core.make_const_expr_fp(constant_dtype, val))
        raise GsTaichiTypeError(
            "Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`."
        )

    if isinstance(val, (int, np.integer)):
        constant_dtype = dtype if dtype is not None else impl.get_runtime().default_ip
        if constant_dtype not in integer_types:
            raise GsTaichiTypeError(
                "Integer literals must be annotated with a integer type. For type casting, use `ti.cast`."
            )
        if not _check_in_range(to_numpy_type(constant_dtype), val):
            if dtype is None:
                raise GsTaichiTypeError(
                    f"Integer literal {val} exceeded the range of default_ip: {impl.get_runtime().default_ip}, please specify the dtype via e.g. `ti.u64({val})` or set a different `default_ip` in `ti.init()`"
                )
            raise GsTaichiTypeError(f"Integer literal {val} exceeded the range of specified dtype: {dtype}")
        # The value is passed through the np.int64 range mapping before use.
        as_int64 = _clamp_unsigned_to_range(np.int64, val)
        return Expr(_ti_core.make_const_expr_int(constant_dtype, as_int64))

    raise GsTaichiTypeError(f"Invalid constant scalar data type: {type(val)}")
140
+
141
+
142
def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
    """Create ``size`` id expressions, each made with an empty name.

    Args:
        size: Number of id expressions to create.
        ast_builder: When given, its ``make_id_expr`` is used; otherwise the
            runtime program's ``make_id_expr`` is used.

    Returns:
        list: The created id expressions, in creation order.
    """
    prog = impl.get_runtime().prog
    factory = prog if ast_builder is None else ast_builder
    return [factory.make_id_expr("") for _ in range(size)]
151
+
152
+
153
def make_expr_group(*exprs):
    """Collect the given expressions into a ``_ti_core.ExprGroup``.

    A single list/tuple argument is unpacked into its items, and a single
    ``Matrix`` argument (which must have ``m == 1``) into its entries. Every
    item is flattened with ``_get_flattened_ptrs`` before being appended.

    Returns:
        The populated ``ExprGroup``.
    """
    from gstaichi.lang.matrix import Matrix  # pylint: disable=C0415

    if len(exprs) == 1:
        only = exprs[0]
        if isinstance(only, (list, tuple)):
            exprs = only
        elif isinstance(only, Matrix):
            assert only.m == 1
            exprs = only.entries
    group = _ti_core.ExprGroup()
    for element in exprs:
        for ptr in _get_flattened_ptrs(element):
            group.push_back(ptr)
    return group
169
+
170
+
171
def _get_flattened_ptrs(val):
    """Return a flat list of expression pointers for ``val``.

    Values for which ``is_gstaichi_class`` is true are expanded recursively
    through their ``_members``; any other value is wrapped in :class:`Expr`
    and contributes its single ``ptr``.
    """
    if not is_gstaichi_class(val):
        return [Expr(val).ptr]
    return [ptr for member in val._members for ptr in _get_flattened_ptrs(member)]
178
+
179
+
180
+ __all__ = []