gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,341 @@
1
+ # type: ignore
2
+
3
+ import gstaichi.lang.ops as ops_mod
4
+ from gstaichi.lang.impl import static
5
+ from gstaichi.lang.kernel_impl import func, pyfunc
6
+ from gstaichi.lang.matrix import Matrix, Vector
7
+ from gstaichi.lang.matrix_ops_utils import (
8
+ arg_at,
9
+ arg_foreach_check,
10
+ assert_list,
11
+ assert_tensor,
12
+ assert_vector,
13
+ check_matmul,
14
+ check_transpose,
15
+ dim_lt,
16
+ is_int_const,
17
+ preconditions,
18
+ same_shapes,
19
+ square_matrix,
20
+ )
21
+ from gstaichi.types.annotations import template
22
+
23
+
24
# Fold the binary function `fun` over every element of a vector or matrix.
# The first element is the initial accumulator and the remaining elements are
# combined in row-major order (mat[0, 1], mat[0, 2], ..., mat[1, 0], ...).
# All loops are wrapped in static(...), so they are resolved at compile time
# when this is traced inside a kernel.
# Raises GsTaichiCompilationError (via `preconditions`) if `mat` is not a tensor.
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def _reduce(mat, fun: template()):
    shape = static(mat.get_shape())
    if static(len(shape) == 1):
        # Vector: fold mat[1:] into mat[0].
        result = mat[0]
        for i in static(range(1, shape[0])):
            result = fun(result, mat[i])
        return result
    # Matrix (assumed 2-D, since only shape[0] and shape[1] are used):
    # fold every element except mat[0, 0] into it.
    result = mat[0, 0]
    for i in static(range(shape[0])):
        for j in static(range(shape[1])):
            if static(i != 0 or j != 0):
                result = fun(result, mat[i, j])
    return result
39
+
40
+
41
# Build an n-element Vector whose entries are all `val`, passing `dtype` to
# Vector. Callers in this module pass dtype=None, presumably meaning "infer the
# element type from `val`" -- confirm against Vector's constructor.
@pyfunc
def _filled_vector(n: template(), dtype: template(), val: template()):
    return Vector([val for _ in static(range(n))], dtype)
44
+
45
+
46
# Build an n x m Matrix (n rows, m columns) whose entries are all `val`,
# passing `dtype` to Matrix (callers in this module pass None).
@pyfunc
def _filled_matrix(n: template(), m: template(), dtype: template(), val: template()):
    return Matrix([[val for _ in static(range(m))] for _ in static(range(n))], dtype)
49
+
50
+
51
# Build the n-element i-th unit vector: entry j is the boolean (i == j), which
# is handed to Vector with `dtype` (presumably converted to 1/0 of that type).
@pyfunc
def _unit_vector(n: template(), i: template(), dtype: template()):
    return Vector([i == j for j in static(range(n))], dtype)
54
+
55
+
56
# Build the n x n identity matrix: entry [i][j] is the boolean (i == j), which
# is handed to Matrix with `dtype` (presumably converted to 1/0 of that type).
@pyfunc
def _identity_matrix(n: template(), dtype: template()):
    return Matrix([[i == j for j in static(range(n))] for i in static(range(n))], dtype)
59
+
60
+
61
# Build a Matrix from a list of rows; row k of the result is rows[k].
# Preconditions, checked in this order, raise GsTaichiCompilationError:
#   1. all rows have the same shape (same_shapes);
#   2. every row is a 1-D tensor or a Python list (assert_vector OR assert_list).
@preconditions(
    arg_at(0, lambda xs: same_shapes(*xs)),
    arg_foreach_check(
        0,
        fns=[assert_vector(), assert_list],
        logic="or",
        msg="Cols/rows must be a list of lists, or a list of vectors",
    ),
)
@pyfunc
def rows(rows):  # pylint: disable=W0621
    return Matrix([[x for x in row] for row in rows])
73
+
74
+
75
# Build a Matrix whose k-th column is cols[k]: builds it row-wise with rows()
# (which also performs the argument checks) and transposes the result.
@pyfunc
def cols(cols):  # pylint: disable=W0621
    return rows(cols).transpose()
78
+
79
+
80
# Wrap-around element access: returns mat[x mod n, y mod n]. determinant() and
# inverse() use it to address minors with cyclic row/column offsets.
@pyfunc
def E(mat: template(), x: template(), y: template(), n: template()):
    return mat[x % n, y % n]
83
+
84
+
85
# Determinant of a square matrix of size 1 to 4 (square_matrix and dim_lt(0, 5)
# reject anything else), using closed-form expressions:
# - 1x1 and 2x2: direct formulas;
# - 3x3: cofactor expansion along column 0;
# - 4x4: cofactor expansion along column 0. The 3x3 minor for row i uses rows
#   i+1, i+2, i+3 (mod 4) via E(). Those rows are a cyclic rotation of the
#   remaining rows in ascending order, i.e. an even permutation, so the minor
#   keeps its sign and only the (-1) ** i cofactor sign is needed.
@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def determinant(mat):
    shape = static(mat.get_shape())
    if static(shape[0] == 1):
        return mat[0, 0]
    if static(shape[0] == 2):
        return mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]
    if static(shape[0] == 3):
        return (
            mat[0, 0] * (mat[1, 1] * mat[2, 2] - mat[2, 1] * mat[1, 2])
            - mat[1, 0] * (mat[0, 1] * mat[2, 2] - mat[2, 1] * mat[0, 2])
            + mat[2, 0] * (mat[0, 1] * mat[1, 2] - mat[1, 1] * mat[0, 2])
        )
    if static(shape[0] == 4):
        det = mat[0, 0] * 0  # keep type
        for i in static(range(4)):
            det = det + (-1) ** i * (
                mat[i, 0]
                * (
                    E(mat, i + 1, 1, 4)
                    * (E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) - E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4))
                    - E(mat, i + 2, 1, 4)
                    * (E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) - E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4))
                    + E(mat, i + 3, 1, 4)
                    * (E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) - E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))
                )
            )
        return det
    # unreachable: the preconditions limit the size to < 5
    return None
116
+
117
+
118
# Inverse of a square matrix of size 1 to 4, computed as adjugate / determinant.
# The 2x2 .. 4x4 results are built with rows indexed by j and columns by i, so
# entry [j][i] holds the (i, j) cofactor, i.e. the transposed cofactor matrix.
# - 3x3: the cyclic minor E(i+1, j+1)E(i+2, j+2) - E(i+2, j+1)E(i+1, j+2)
#   already equals the signed cofactor, so no explicit sign is applied.
# - 4x4: the minor uses cyclically following rows and columns (an even
#   permutation, so its sign is preserved, as in determinant()), and the
#   cofactor sign (-1) ** (i + j) is applied explicitly.
# No check is made for a zero determinant.
@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def inverse(mat):
    shape = static(mat.get_shape())
    if static(shape[0] == 1):
        return Matrix([[1.0 / mat[0, 0]]])
    inv_determinant = 1.0 / determinant(mat)
    if static(shape[0] == 2):
        return inv_determinant * Matrix([[mat[1, 1], -mat[0, 1]], [-mat[1, 0], mat[0, 0]]])
    if static(shape[0] == 3):
        return inv_determinant * Matrix(
            [
                [
                    E(mat, i + 1, j + 1, 3) * E(mat, i + 2, j + 2, 3)
                    - E(mat, i + 2, j + 1, 3) * E(mat, i + 1, j + 2, 3)
                    for i in static(range(3))
                ]
                for j in static(range(3))
            ]
        )
    if static(shape[0] == 4):
        return inv_determinant * Matrix(
            [
                [
                    (-1) ** (i + j)
                    * (
                        (
                            E(mat, i + 1, j + 1, 4)
                            * (
                                E(mat, i + 2, j + 2, 4) * E(mat, i + 3, j + 3, 4)
                                - E(mat, i + 3, j + 2, 4) * E(mat, i + 2, j + 3, 4)
                            )
                            - E(mat, i + 2, j + 1, 4)
                            * (
                                E(mat, i + 1, j + 2, 4) * E(mat, i + 3, j + 3, 4)
                                - E(mat, i + 3, j + 2, 4) * E(mat, i + 1, j + 3, 4)
                            )
                            + E(mat, i + 3, j + 1, 4)
                            * (
                                E(mat, i + 1, j + 2, 4) * E(mat, i + 2, j + 3, 4)
                                - E(mat, i + 2, j + 2, 4) * E(mat, i + 1, j + 3, 4)
                            )
                        )
                    )
                    for i in static(range(4))
                ]
                for j in static(range(4))
            ]
        )
    # unreachable: the preconditions limit the size to < 5
    return None
169
+
170
+
171
# Transpose of a 2-D matrix: result[j][i] = mat[i, j], giving shape
# (shape[1], shape[0]). check_transpose rejects non-tensors and vectors.
@preconditions(check_transpose)
@pyfunc
def transpose(mat):
    shape = static(mat.get_shape())
    return Matrix([[mat[i, j] for i in static(range(shape[0]))] for j in static(range(shape[1]))])
176
+
177
+
178
# dim x dim matrix with `val` on the diagonal and 0 elsewhere. `dim` must be an
# integer constant (is_int_const), otherwise GsTaichiCompilationError is raised.
@preconditions(arg_at(0, is_int_const))
@pyfunc
def diag(dim: template(), val: template()):
    return Matrix([[val if i == j else 0 for j in static(range(dim))] for i in static(range(dim))])
182
+
183
+
184
# Sum of all elements of a vector or matrix (a _reduce with ops_mod.add).
@preconditions(assert_tensor)
@pyfunc
def sum(mat):  # pylint: disable=W0622
    return _reduce(mat, ops_mod.add)
188
+
189
+
190
# Squared Euclidean (Frobenius) norm: sum of the squares of all elements.
# `mat * mat` is the element-wise product here; matrix products go through matmul().
@preconditions(assert_tensor)
@pyfunc
def norm_sqr(mat):
    return sum(mat * mat)
194
+
195
+
196
# Euclidean (Frobenius) norm, sqrt(norm_sqr(mat) + eps). `eps` is added inside
# the square root, e.g. to keep the result and its gradient finite at zero.
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm(mat, eps=0.0):
    return ops_mod.sqrt(norm_sqr(mat) + eps)
200
+
201
+
202
# Reciprocal norm, rsqrt(norm_sqr(mat) + eps); `eps` is added inside the root.
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm_inv(mat, eps=0.0):
    return ops_mod.rsqrt(norm_sqr(mat) + eps)
206
+
207
+
208
# Return `vec` scaled to unit length: vec / (norm(vec) + eps).
# Unlike norm(), `eps` is added to the norm itself, not inside the square root.
@preconditions(arg_at(0, assert_vector()))
@pyfunc
def normalized(vec, eps=0.0):
    invlen = 1 / (norm(vec) + eps)
    return invlen * vec
213
+
214
+
215
# True if any element is non-zero: logical-OR reduction over (mat != 0).
# The trailing `and True` presumably coerces the reduced value to a boolean
# result when compiled; kept as-is.
@preconditions(assert_tensor)
@pyfunc
def any(mat):  # pylint: disable=W0622
    return _reduce(mat != 0, ops_mod.logical_or) and True
219
+
220
+
221
# True if every element is non-zero: logical-AND reduction over (mat != 0).
# The trailing `and True` presumably coerces the reduced value to a boolean
# result when compiled; kept as-is.
@preconditions(assert_tensor)
@pyfunc
def all(mat):  # pylint: disable=W0622
    return _reduce(mat != 0, ops_mod.logical_and) and True
225
+
226
+
227
# Largest element of a vector or matrix (a _reduce with ops_mod.max_impl).
@preconditions(assert_tensor)
@pyfunc
def max(mat):  # pylint: disable=W0622
    return _reduce(mat, ops_mod.max_impl)
231
+
232
+
233
# Smallest element of a vector or matrix (a _reduce with ops_mod.min_impl).
@preconditions(assert_tensor)
@pyfunc
def min(mat):  # pylint: disable=W0622
    return _reduce(mat, ops_mod.min_impl)
237
+
238
+
239
# Trace of a square matrix: sum of the diagonal elements mat[i, i].
@preconditions(square_matrix)
@pyfunc
def trace(mat):
    shape = static(mat.get_shape())
    result = mat[0, 0]
    # TODO: get rid of static when
    # CHI IR Tensor repr is ready stable
    for i in static(range(1, shape[0])):
        result = result + mat[i, i]
    return result
249
+
250
+
251
# Assign `val` to every element of `mat` in place; returns nothing.
# Handles vectors (1-D) and, otherwise, 2-D matrices. `mat` is a template()
# argument, presumably so the writes land in the caller's value (confirm
# against gstaichi's template semantics).
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def fill(mat: template(), val):
    shape = static(mat.get_shape())
    if static(len(shape) == 1):
        for i in static(range(shape[0])):
            mat[i] = val
    else:
        for i in static(range(shape[0])):
            for j in static(range(shape[1])):
                mat[i, j] = val
262
+
263
+
264
# Core of matmul(). Computes `mat_x @ mat_y` for:
# - vector, vector  -> scalar dot product;
# - matrix, vector  -> Vector of length shape_x[0];
# - matrix, matrix  -> Matrix of shape (shape_x[0], shape_y[1]).
# Shape compatibility is validated by check_matmul. A vector left operand with
# a matrix right operand is NOT handled here (it would reach the matrix-matrix
# branch); matmul() rewrites that case before calling this helper.
@preconditions(check_matmul)
@pyfunc
def _matmul_helper(mat_x, mat_y):
    shape_x = static(mat_x.get_shape())
    shape_y = static(mat_y.get_shape())
    if static(len(shape_x) == 1 and len(shape_y) == 1):
        return dot(mat_x, mat_y)
    if static(len(shape_y) == 1):
        zero_elem = mat_x[0, 0] * mat_y[0] * 0  # for correct return type
        vec_z = _filled_vector(shape_x[0], None, zero_elem)
        for i in static(range(shape_x[0])):
            for j in static(range(shape_x[1])):
                vec_z[i] = vec_z[i] + mat_x[i, j] * mat_y[j]
        return vec_z
    zero_elem = mat_x[0, 0] * mat_y[0, 0] * 0  # for correct return type
    mat_z = _filled_matrix(shape_x[0], shape_y[1], None, zero_elem)
    for i in static(range(shape_x[0])):
        for j in static(range(shape_y[1])):
            for k in static(range(shape_x[1])):
                mat_z[i, j] = mat_z[i, j] + mat_x[i, k] * mat_y[k, j]
    return mat_z
285
+
286
+
287
# Matrix product `mat_x @ mat_y`.
# A vector on the left of a matrix is computed as transpose(mat_y) @ mat_x
# (x^T M == M^T x), because _matmul_helper only supports vector operands on the
# right. All other combinations go straight to _matmul_helper, whose
# preconditions validate the shapes.
@pyfunc
def matmul(mat_x, mat_y):
    shape_x = static(mat_x.get_shape())
    shape_y = static(mat_y.get_shape())
    if static(len(shape_x) == 1 and len(shape_y) == 2):
        return _matmul_helper(transpose(mat_y), mat_x)
    return _matmul_helper(mat_x, mat_y)
294
+
295
+
296
# Dot product of two vectors: sum of the element-wise products.
# The preconditions only check that both arguments are vectors; equal lengths
# are not checked here.
@preconditions(
    arg_at(0, assert_vector("lhs for dot is not a vector")),
    arg_at(1, assert_vector("rhs for dot is not a vector")),
)
@pyfunc
def dot(vec_x, vec_y):
    return sum(vec_x * vec_y)
303
+
304
+
305
# Cross product of two vectors of the same length.
# - 2 elements: returns the scalar vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0]
#   (the z-component of the 3-D cross product with both z-components set to 0).
# - 3 elements: returns the 3-D cross product as a Vector.
# Any other input raises GsTaichiCompilationError from `preconditions`.
@preconditions(
    arg_at(0, assert_vector("lhs for cross is not a vector")),
    arg_at(1, assert_vector("rhs for cross is not a vector")),
    same_shapes,
    arg_at(0, dim_lt(0, 4)),
    # dim_lt(0, 4) still admits 1-element vectors, which would otherwise fall
    # through to `return None` below; reject them with a clear message.
    arg_at(
        0,
        lambda v: (
            v.get_shape()[0] >= 2,
            f"cross() only supports vectors of size 2 or 3, got shape {v.get_shape()}",
        ),
    ),
)
@pyfunc
def cross(vec_x, vec_y):
    shape = static(vec_x.get_shape())
    if static(shape[0] == 2):
        return vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0]
    if static(shape[0] == 3):
        return Vector(
            [
                vec_x[1] * vec_y[2] - vec_x[2] * vec_y[1],
                vec_x[2] * vec_y[0] - vec_x[0] * vec_y[2],
                vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0],
            ]
        )
    # unreachable: the preconditions restrict the size to 2 or 3
    return None
325
+
326
+
327
# Outer product of two vectors: result[i][j] = vec_x[i] * vec_y[j], giving a
# matrix of shape (len(vec_x), len(vec_y)). The lengths may differ.
@preconditions(
    arg_at(0, assert_vector("lhs for outer_product is not a vector")),
    arg_at(1, assert_vector("rhs for outer_product is not a vector")),
)
@pyfunc
def outer_product(vec_x, vec_y):
    shape_x = static(vec_x.get_shape())
    shape_y = static(vec_y.get_shape())
    return Matrix([[vec_x[i] * vec_y[j] for j in static(range(shape_y[0]))] for i in static(range(shape_x[0]))])
336
+
337
+
338
# Cast every element of a tensor to `dtype` by delegating to ops_mod.cast.
# Note: decorated with @func rather than @pyfunc, unlike the rest of this module.
@preconditions(assert_tensor)
@func
def cast(mat, dtype: template()):
    return ops_mod.cast(mat, dtype)
@@ -0,0 +1,190 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+
5
+ from gstaichi.lang.exception import GsTaichiCompilationError
6
+ from gstaichi.lang.expr import Expr
7
+ from gstaichi.lang.matrix import Matrix
8
+
9
+
10
def do_check(checker_fns, *args, **kwargs):
    """Run each checker in order, stopping at the first failure.

    Every checker is called as ``checker(*args, **kwargs)`` and must return an
    ``(ok, msg)`` pair.

    Returns:
        ``(False, msg)`` from the first checker that reports failure, or
        ``(True, None)`` when all of them pass (including when there are none).
    """
    for checker in checker_fns:
        passed, message = checker(*args, **kwargs)
        if passed:
            continue
        return False, message
    return True, None
16
+
17
+
18
def preconditions(*checker_funcs):
    """Decorator factory that validates call arguments before invoking the function.

    All ``checker_funcs`` are run through :func:`do_check` with the exact call
    arguments. The first failing checker's message is raised as
    ``GsTaichiCompilationError``; if every checker passes, the wrapped function
    is called normally and its result returned.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            passed, message = do_check(checker_funcs, *args, **kwargs)
            if passed:
                return func(*args, **kwargs)
            raise GsTaichiCompilationError(message)

        return wrapper

    return decorator
30
+
31
+
32
def arg_at(indices, *fns):
    """Build a checker that applies ``fns`` to selected call arguments.

    Args:
        indices: A single position (int) or an iterable of positions selecting
            which arguments to check.
        *fns: Checkers of the form ``fn(arg) -> (ok, msg)``, run in order via
            :func:`do_check`.

    Returns:
        A checker ``check(*args, **kwargs) -> (ok, msg)`` that returns the
        first failure, or ``(True, None)`` if every selected argument passes.
    """
    # Normalise once at construction instead of rebinding the captured variable
    # (via ``nonlocal``) on every call; this also keeps one-shot iterables usable.
    positions = [indices] if isinstance(indices, int) else list(indices)

    def check(*args, **kwargs):
        for i in positions:
            # Keyword-argument keys are always strings, so an int position is
            # normally resolved from ``args``.
            arg = kwargs[i] if i in kwargs else args[i]
            ok, msg = do_check(fns, arg)
            if not ok:
                return False, msg
        return True, None

    return check
48
+
49
+
50
def assert_tensor(m, msg="not tensor type: {}"):
    """Check that ``m`` is a tensor value.

    A ``Matrix`` instance is accepted, as is an ``Expr`` whose ``is_tensor()``
    is true.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, msg.format(type(m)))``.
    """
    is_tensor = isinstance(m, Matrix) or (isinstance(m, Expr) and m.is_tensor())
    if not is_tensor:
        return False, msg.format(type(m))
    return True, None
56
+
57
+
58
def assert_vector(msg="expected a vector, got {}"):
    """Return a checker that accepts one-dimensional ``Expr``/``Matrix`` values.

    The returned ``check(v)`` yields ``(True, None)`` when ``v`` is an ``Expr``
    or ``Matrix`` whose ``get_shape()`` has length 1, and
    ``(False, msg.format(type(v)))`` otherwise.
    """

    def check(v):
        if not isinstance(v, (Expr, Matrix)):
            return False, msg.format(type(v))
        if len(v.get_shape()) != 1:
            return False, msg.format(type(v))
        return True, None

    return check
65
+
66
+
67
def assert_list(x, msg="not a list: {}"):
    """Check that ``x`` is a ``list`` (subclasses included).

    Returns ``(True, None)`` on success, otherwise ``(False, msg.format(type(x)))``.
    """
    return (True, None) if isinstance(x, list) else (False, msg.format(type(x)))
71
+
72
+
73
def arg_foreach_check(*arg_indices, fns=(), logic="or", msg=None):
    """Build a checker that validates every element of selected iterable arguments.

    Args:
        *arg_indices: Positions of the arguments to inspect; each selected
            argument is iterated element by element.
        fns: Element checkers ``fn(element) -> (ok, msg)``.
        logic: ``"or"`` -- every element must pass at least one checker;
            ``"and"`` -- every element must pass all checkers.
        msg: Message reported on failure (individual checker messages are
            discarded). When ``None``, a generic message naming the argument
            position is used.

    Returns:
        A checker ``check(*args, **kwargs) -> (ok, msg)``.

    Raises:
        ValueError: from the returned checker, if ``logic`` is neither ``"or"``
            nor ``"and"``.
    """
    # Immutable default instead of a shared mutable list; copying also lets
    # callers pass any iterable (a list still works as before).
    fns = tuple(fns)

    def _passes(fn, element):
        # Unpack strictly, like do_check, so malformed checker results still fail loudly.
        ok, _ = fn(element)
        return ok

    def check(*args, **kwargs):
        for i in arg_indices:
            arg = kwargs[i] if i in kwargs else args[i]
            failure = msg if msg is not None else f"argument {i} contains an element that failed validation"
            if logic == "or":
                for a in arg:
                    if not any(_passes(fn, a) for fn in fns):
                        return False, failure
            elif logic == "and":
                for a in arg:
                    if not all(_passes(fn, a) for fn in fns):
                        return False, failure
            else:
                raise ValueError(f"Unknown logic: {logic}")
        return True, None

    return check
100
+
101
+
102
def get_list_shape(x):
    """Return the shape of a (possibly nested) Python list as a list of ints.

    Examples: ``[1, 2, 3]`` -> ``[3]``; ``[[1, 2], [3, 4]]`` -> ``[2, 2]``;
    ``[]`` -> ``[0]``. Non-list elements count as scalars (shape ``[]``).

    Raises:
        ValueError: if the elements do not all have the same shape (ragged input).
    """
    inner_shape = None
    for element in x:
        cur_shape = get_list_shape(element) if isinstance(element, list) else []
        # Compare against None, not truthiness: a scalar's shape [] is falsy but
        # must still be checked against later elements.
        if inner_shape is None:
            inner_shape = cur_shape
        elif cur_shape != inner_shape:
            raise ValueError(f"ragged list: element shapes {inner_shape} and {cur_shape} differ")
    # An empty list has no elements to infer an inner shape from.
    return [len(x)] + (inner_shape or [])
117
+
118
+
119
def same_shapes(*xs):
    """Check that all arguments have the same shape.

    Shapes are taken from ``Matrix.get_shape()``, from :func:`get_list_shape`
    for Python lists, and from ``x.ptr.get_rvalue_type().shape()`` for ``Expr``.

    Returns:
        ``(True, None)`` if every argument has an identical shape, otherwise
        ``(False, msg)``. Calling it with no arguments fails as well, since the
        set of shapes is then empty rather than containing exactly one shape.
    """
    shapes = []
    for x in xs:
        if isinstance(x, Matrix):
            shapes.append(x.get_shape())
        elif isinstance(x, list):
            shapes.append(tuple(get_list_shape(x)))
        elif isinstance(x, Expr):
            shapes.append(tuple(x.ptr.get_rvalue_type().shape()))
        else:
            # Report the argument's type, as the message states.
            return False, f"same_shapes() received an unexpected argument of type: {type(x)}"

    if len(set(shapes)) != 1:
        return False, f"required shapes to be the same, got shapes {shapes}"
    return True, None
134
+
135
+
136
def square_matrix(x):
    """Check that ``x`` is a tensor with a 2-D, square shape.

    Returns:
        ``(True, None)`` on success; otherwise ``(False, msg)`` with either the
        :func:`assert_tensor` message (non-tensor) or a shape message.
    """
    # Propagate the tensor check instead of discarding it; otherwise a
    # non-tensor would reach ``get_shape()``, which it may not support.
    ok, msg = assert_tensor(x)
    if not ok:
        return False, msg
    shape = x.get_shape()
    if len(shape) != 2 or shape[0] != shape[1]:
        return False, f"expected a square matrix, got shape {shape}"
    return True, None
142
+
143
+
144
def dim_lt(dim, limit):
    """Return a checker requiring ``x.get_shape()[dim] < limit`` for a tensor ``x``.

    The checker returns the :func:`assert_tensor` failure for non-tensors,
    ``(True, None)`` when the bound holds, and ``(False, msg)`` otherwise.
    """

    def check(x):
        # Propagate the tensor check instead of discarding it; otherwise a
        # non-tensor would reach ``get_shape()``, which it may not support.
        ok, msg = assert_tensor(x)
        if not ok:
            return False, msg
        shape = x.get_shape()
        if shape[dim] < limit:
            return True, None
        return False, f"only dimension < {limit} is supported, got shape {shape}"

    return check
151
+
152
+
153
def is_int_const(x):
    """Check that ``x`` is an integer constant.

    Accepts a Python ``int`` (checked first, so no ``Expr`` access is needed for
    plain ints) or an ``Expr`` whose ``val_int()`` is not ``None``.

    Returns ``(True, None)`` on success, otherwise ``(False, msg)`` naming the
    value and its type.
    """
    if isinstance(x, int) or (isinstance(x, Expr) and x.val_int() is not None):
        return True, None
    return False, f"not an integer: {x} of type {type(x).__name__}"
159
+
160
+
161
def check_matmul(x, y):
    """Check that ``x`` and ``y`` are tensors with compatible shapes for ``x @ y``.

    Supported combinations:
      * vector, vector: equal lengths;
      * vector, matrix: ``len(x) == y_shape[0]``;
      * matrix, vector / matrix, matrix: ``x_shape[1] == y_shape[0]``.

    Returns:
        ``(True, None)`` if compatible, otherwise ``(False, msg)``.
    """
    # Propagate the tensor checks instead of discarding them; otherwise a
    # non-tensor would reach ``get_shape()`` below.
    ok, msg = assert_tensor(x, f"left hand side is not a matrix: {type(x)}")
    if not ok:
        return False, msg
    ok, msg = assert_tensor(y, f"right hand side is not a matrix: {type(y)}")
    if not ok:
        return False, msg
    x_shape = x.get_shape()
    y_shape = y.get_shape()
    if len(x_shape) == 1:
        if len(y_shape) == 1:
            # A dot product needs vectors of equal length.
            if x_shape[0] != y_shape[0]:
                return (
                    False,
                    f"dimension mismatch between {x_shape} and {y_shape} for dot product",
                )
            return True, None
        if x_shape[0] != y_shape[0]:
            return (
                False,
                f"dimension mismatch between {x_shape} and {y_shape} for left multiplication",
            )
    else:
        if x_shape[1] != y_shape[0]:
            return (
                False,
                f"dimension mismatch between {x_shape} and {y_shape} for matrix multiplication",
            )
    return True, None
181
+
182
+
183
def check_transpose(x):
    """Check that ``x`` is a tensor that can be transposed (i.e. not a vector).

    Returns the :func:`assert_tensor` failure for non-tensors, a dedicated
    message suggesting ``outer_product`` for 1-D tensors, and the successful
    :func:`assert_tensor` result otherwise.
    """
    is_tensor, failure = assert_tensor(x)
    if not is_tensor:
        return is_tensor, failure
    if len(x.get_shape()) == 1:
        return (
            False,
            "`transpose()` cannot apply to a vector. If you want something like `a @ b.transpose()`, write `a.outer_product(b)` instead.",
        )
    return is_tensor, failure