gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,341 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import gstaichi.lang.ops as ops_mod
|
4
|
+
from gstaichi.lang.impl import static
|
5
|
+
from gstaichi.lang.kernel_impl import func, pyfunc
|
6
|
+
from gstaichi.lang.matrix import Matrix, Vector
|
7
|
+
from gstaichi.lang.matrix_ops_utils import (
|
8
|
+
arg_at,
|
9
|
+
arg_foreach_check,
|
10
|
+
assert_list,
|
11
|
+
assert_tensor,
|
12
|
+
assert_vector,
|
13
|
+
check_matmul,
|
14
|
+
check_transpose,
|
15
|
+
dim_lt,
|
16
|
+
is_int_const,
|
17
|
+
preconditions,
|
18
|
+
same_shapes,
|
19
|
+
square_matrix,
|
20
|
+
)
|
21
|
+
from gstaichi.types.annotations import template
|
22
|
+
|
23
|
+
|
24
|
+
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def _reduce(mat, fun: template()):
    """Fold the binary function ``fun`` over all elements of ``mat``.

    The accumulator starts at the first element. Every remaining element is
    combined as ``fun(accumulator, element)``, in row-major order. The shape is
    read through ``static(...)``, and the loops iterate over
    ``static(range(...))``.

    Args:
        mat: tensor to reduce (validated by ``assert_tensor``). A 1-D shape is
            treated as a vector; any other shape is indexed as 2-D ``mat[i, j]``.
        fun: binary combining function, passed as a template argument.

    Returns:
        The reduced value.
    """
    shape = static(mat.get_shape())
    if static(len(shape) == 1):
        # Vector case: seed with element 0, then fold in elements 1..n-1.
        result = mat[0]
        for i in static(range(1, shape[0])):
            result = fun(result, mat[i])
        return result
    # Matrix case: seed with mat[0, 0] and skip that entry in the double loop.
    result = mat[0, 0]
    for i in static(range(shape[0])):
        for j in static(range(shape[1])):
            if static(i != 0 or j != 0):
                result = fun(result, mat[i, j])
    return result
|
39
|
+
|
40
|
+
|
41
|
+
@pyfunc
def _filled_vector(n: template(), dtype: template(), val: template()):
    """Return a length-``n`` ``Vector`` whose every entry is ``val``.

    Args:
        n: number of entries (template argument).
        dtype: element dtype forwarded to ``Vector`` (may be ``None``).
        val: fill value (template argument).
    """
    return Vector([val for _ in static(range(n))], dtype)
|
44
|
+
|
45
|
+
|
46
|
+
@pyfunc
def _filled_matrix(n: template(), m: template(), dtype: template(), val: template()):
    """Build an ``n`` x ``m`` ``Matrix`` in which every entry equals ``val``.

    ``dtype`` is handed straight to the ``Matrix`` constructor (may be ``None``).
    """
    return Matrix(
        [[val for _col in static(range(m))] for _row in static(range(n))],
        dtype,
    )
|
49
|
+
|
50
|
+
|
51
|
+
@pyfunc
def _unit_vector(n: template(), i: template(), dtype: template()):
    """Return the ``i``-th standard basis vector of length ``n``.

    Entry ``k`` is the compile-time comparison ``i == k``. The resulting list
    is passed to ``Vector`` together with ``dtype``.
    """
    return Vector([i == k for k in static(range(n))], dtype)
|
54
|
+
|
55
|
+
|
56
|
+
@pyfunc
def _identity_matrix(n: template(), dtype: template()):
    """Return the ``n`` x ``n`` identity as a ``Matrix`` built with ``dtype``.

    Each entry is the compile-time comparison ``row == col``.
    """
    return Matrix(
        [[row == col for col in static(range(n))] for row in static(range(n))],
        dtype,
    )
|
59
|
+
|
60
|
+
|
61
|
+
@preconditions(
    arg_at(0, lambda xs: same_shapes(*xs)),
    arg_foreach_check(
        0,
        fns=[assert_vector(), assert_list],
        logic="or",
        msg="Cols/rows must be a list of lists, or a list of vectors",
    ),
)
@pyfunc
def rows(rows):  # pylint: disable=W0621
    """Stack the given rows into a ``Matrix``.

    Args:
        rows: a list whose items are each a vector or a Python list; the
            preconditions require all items to share one shape.

    Returns:
        A ``Matrix`` whose ``k``-th row holds the entries of ``rows[k]``.
    """
    return Matrix([[entry for entry in row_items] for row_items in rows])
|
73
|
+
|
74
|
+
|
75
|
+
@pyfunc
def cols(cols):  # pylint: disable=W0621
    """Build a ``Matrix`` whose columns are the items of ``cols``.

    Implemented as ``rows(cols).transpose()``, so the input is validated by
    the preconditions attached to ``rows``.
    """
    return rows(cols).transpose()
|
78
|
+
|
79
|
+
|
80
|
+
@pyfunc
def E(mat: template(), x: template(), y: template(), n: template()):
    """Index ``mat`` with both indices wrapped modulo ``n``.

    ``determinant`` and ``inverse`` use this helper to address cofactor minors
    with cyclically shifted row/column indices. All four parameters are
    declared as template arguments.
    """
    return mat[x % n, y % n]
|
83
|
+
|
84
|
+
|
85
|
+
@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def determinant(mat):
    """Return the determinant of a square matrix of size 1 to 4.

    Sizes 1-3 use closed-form expressions. Size 4 uses a cofactor expansion
    along column 0.

    Args:
        mat: square matrix; the preconditions reject non-square input and
            sizes >= 5.
    """
    shape = static(mat.get_shape())
    if static(shape[0] == 1):
        return mat[0, 0]
    if static(shape[0] == 2):
        return mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]
    if static(shape[0] == 3):
        # Cofactor expansion along column 0.
        return (
            mat[0, 0] * (mat[1, 1] * mat[2, 2] - mat[2, 1] * mat[1, 2])
            - mat[1, 0] * (mat[0, 1] * mat[2, 2] - mat[2, 1] * mat[0, 2])
            + mat[2, 0] * (mat[0, 1] * mat[1, 2] - mat[1, 1] * mat[0, 2])
        )
    if static(shape[0] == 4):
        det = mat[0, 0] * 0  # keep type
        # Expand along column 0: det = sum_i (-1)**i * mat[i, 0] * minor_i.
        # minor_i is the 3x3 determinant over rows i+1, i+2, i+3 (wrapped
        # mod 4 by E) and columns 1..3. A cyclic shift of three rows is an
        # even permutation, so the wrapped row order needs no extra sign.
        for i in static(range(4)):
            det = det + (-1) ** i * (
                mat[i, 0]
                * (
                    E(mat, i + 1, 1, 4)
                    * (E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) - E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4))
                    - E(mat, i + 2, 1, 4)
                    * (E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) - E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4))
                    + E(mat, i + 3, 1, 4)
                    * (E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) - E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))
                )
            )
        return det
    # unreachable: the preconditions reject sizes >= 5
    return None
|
116
|
+
|
117
|
+
|
118
|
+
@preconditions(square_matrix, dim_lt(0, 5))
@pyfunc
def inverse(mat):
    """Return the inverse of a square matrix of size 1 to 4.

    Size 1 is ``1.0 / mat[0, 0]``. Sizes 2-4 are computed as
    ``adjugate(mat) * (1.0 / determinant(mat))``. No check for a singular
    (zero-determinant) matrix is made here.

    Args:
        mat: square matrix; the preconditions reject non-square input and
            sizes >= 5.
    """
    shape = static(mat.get_shape())
    if static(shape[0] == 1):
        return Matrix([[1.0 / mat[0, 0]]])
    inv_determinant = 1.0 / determinant(mat)
    if static(shape[0] == 2):
        return inv_determinant * Matrix([[mat[1, 1], -mat[0, 1]], [-mat[1, 0], mat[0, 0]]])
    if static(shape[0] == 3):
        # The outer comprehension runs over j and the inner over i. Result
        # entry [j][i] therefore holds the cofactor of mat[i, j], which makes
        # the built matrix the adjugate (transposed cofactor matrix). With
        # indices wrapped mod 3 by E, the 2x2 minor already carries the
        # cofactor sign.
        return inv_determinant * Matrix(
            [
                [
                    E(mat, i + 1, j + 1, 3) * E(mat, i + 2, j + 2, 3)
                    - E(mat, i + 2, j + 1, 3) * E(mat, i + 1, j + 2, 3)
                    for i in static(range(3))
                ]
                for j in static(range(3))
            ]
        )
    if static(shape[0] == 4):
        # Same layout as the 3x3 case: entry [j][i] is the cofactor of
        # mat[i, j]. The 3x3 minor uses rows i+1..i+3 and columns j+1..j+3,
        # both wrapped mod 4. Cyclic shifts of three indices are even
        # permutations, so the minor's value is unaffected. The checkerboard
        # sign is applied explicitly as (-1) ** (i + j).
        return inv_determinant * Matrix(
            [
                [
                    (-1) ** (i + j)
                    * (
                        (
                            E(mat, i + 1, j + 1, 4)
                            * (
                                E(mat, i + 2, j + 2, 4) * E(mat, i + 3, j + 3, 4)
                                - E(mat, i + 3, j + 2, 4) * E(mat, i + 2, j + 3, 4)
                            )
                            - E(mat, i + 2, j + 1, 4)
                            * (
                                E(mat, i + 1, j + 2, 4) * E(mat, i + 3, j + 3, 4)
                                - E(mat, i + 3, j + 2, 4) * E(mat, i + 1, j + 3, 4)
                            )
                            + E(mat, i + 3, j + 1, 4)
                            * (
                                E(mat, i + 1, j + 2, 4) * E(mat, i + 2, j + 3, 4)
                                - E(mat, i + 2, j + 2, 4) * E(mat, i + 1, j + 3, 4)
                            )
                        )
                    )
                    for i in static(range(4))
                ]
                for j in static(range(4))
            ]
        )
    # unreachable: the preconditions reject sizes >= 5
    return None
|
169
|
+
|
170
|
+
|
171
|
+
@preconditions(check_transpose)
@pyfunc
def transpose(mat):
    """Return the transpose of a 2-D matrix.

    Vectors are rejected by ``check_transpose``. Entry ``[c][r]`` of the
    result is ``mat[r, c]``.
    """
    dims = static(mat.get_shape())
    return Matrix(
        [
            [mat[r, c] for r in static(range(dims[0]))]
            for c in static(range(dims[1]))
        ]
    )
|
176
|
+
|
177
|
+
|
178
|
+
@preconditions(arg_at(0, is_int_const))
@pyfunc
def diag(dim: template(), val: template()):
    """Return a ``dim`` x ``dim`` matrix with ``val`` on the diagonal and 0 elsewhere.

    ``dim`` must be an integer constant (checked by ``is_int_const``).
    """
    return Matrix(
        [
            [val if row == col else 0 for col in static(range(dim))]
            for row in static(range(dim))
        ]
    )
|
182
|
+
|
183
|
+
|
184
|
+
@preconditions(assert_tensor)
@pyfunc
def sum(mat):  # pylint: disable=W0622
    """Return the sum of all elements of ``mat`` (a fold of ``ops_mod.add``).

    Deliberately shadows the builtin ``sum`` inside this module.
    """
    return _reduce(mat, ops_mod.add)
|
188
|
+
|
189
|
+
|
190
|
+
@preconditions(assert_tensor)
@pyfunc
def norm_sqr(mat):
    """Return the sum of the squared elements of ``mat``.

    ``sum`` here is this module's tensor reduction, not the Python builtin.
    """
    return sum(mat * mat)
|
194
|
+
|
195
|
+
|
196
|
+
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm(mat, eps=0.0):
    """Return the Euclidean (Frobenius) norm ``sqrt(norm_sqr(mat) + eps)``.

    ``eps`` is added inside the square root.
    """
    return ops_mod.sqrt(norm_sqr(mat) + eps)
|
200
|
+
|
201
|
+
|
202
|
+
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm_inv(mat, eps=0.0):
    """Return the reciprocal norm ``rsqrt(norm_sqr(mat) + eps)``.

    ``eps`` is added inside the reciprocal square root.
    """
    return ops_mod.rsqrt(norm_sqr(mat) + eps)
|
206
|
+
|
207
|
+
|
208
|
+
@preconditions(arg_at(0, assert_vector()))
@pyfunc
def normalized(vec, eps=0.0):
    """Scale ``vec`` by the reciprocal of its length.

    ``eps`` is added to the norm (outside the square root) before inverting,
    so a positive ``eps`` keeps the denominator nonzero for a zero vector.

    Returns:
        ``(1 / (norm(vec) + eps)) * vec``.
    """
    inv_length = 1 / (norm(vec) + eps)
    return inv_length * vec
|
213
|
+
|
214
|
+
|
215
|
+
@preconditions(assert_tensor)
@pyfunc
def any(mat):  # pylint: disable=W0622
    """Return whether any element of ``mat`` is nonzero.

    Folds ``ops_mod.logical_or`` over the element-wise comparison ``mat != 0``.
    """
    # NOTE(review): the trailing ``and True`` presumably normalizes the result
    # to a boolean. In plain Python, ``x and True`` evaluates to ``x`` itself
    # when ``x`` is falsy. Confirm this is the intended result type.
    return _reduce(mat != 0, ops_mod.logical_or) and True
|
219
|
+
|
220
|
+
|
221
|
+
@preconditions(assert_tensor)
@pyfunc
def all(mat):  # pylint: disable=W0622
    """Return whether every element of ``mat`` is nonzero.

    Folds ``ops_mod.logical_and`` over the element-wise comparison ``mat != 0``.
    """
    # NOTE(review): see ``any`` -- ``and True`` presumably normalizes to a
    # boolean; in plain Python a falsy reduction is returned unchanged.
    return _reduce(mat != 0, ops_mod.logical_and) and True
|
225
|
+
|
226
|
+
|
227
|
+
@preconditions(assert_tensor)
@pyfunc
def max(mat):  # pylint: disable=W0622
    """Return the largest element of ``mat`` (a fold of ``ops_mod.max_impl``)."""
    return _reduce(mat, ops_mod.max_impl)
|
231
|
+
|
232
|
+
|
233
|
+
@preconditions(assert_tensor)
@pyfunc
def min(mat):  # pylint: disable=W0622
    """Return the smallest element of ``mat`` (a fold of ``ops_mod.min_impl``)."""
    return _reduce(mat, ops_mod.min_impl)
|
237
|
+
|
238
|
+
|
239
|
+
@preconditions(square_matrix)
@pyfunc
def trace(mat):
    """Return the sum of the diagonal entries of a square matrix."""
    shape = static(mat.get_shape())
    # Seed with mat[0, 0] so the accumulator takes the element type.
    result = mat[0, 0]
    # TODO: drop ``static`` once the CHI IR tensor representation is stable.
    for i in static(range(1, shape[0])):
        result = result + mat[i, i]
    return result
|
249
|
+
|
250
|
+
|
251
|
+
@preconditions(arg_at(0, assert_tensor))
@pyfunc
def fill(mat: template(), val):
    """Assign ``val`` to every element of ``mat`` in place; returns nothing.

    A 1-D shape is treated as a vector; any other shape is indexed as 2-D.
    """
    shape = static(mat.get_shape())
    if static(len(shape) == 1):
        for i in static(range(shape[0])):
            mat[i] = val
    else:
        for i in static(range(shape[0])):
            for j in static(range(shape[1])):
                mat[i, j] = val
|
262
|
+
|
263
|
+
|
264
|
+
@preconditions(check_matmul)
@pyfunc
def _matmul_helper(mat_x, mat_y):
    """Multiply two tensors whose shapes were validated by ``check_matmul``.

    - vector, vector -> ``dot(mat_x, mat_y)``
    - matrix, vector -> vector of length ``shape_x[0]``
    - matrix, matrix -> matrix of shape ``(shape_x[0], shape_y[1])``

    All loops run over ``static(range(...))``.
    """
    # NOTE(review): a vector ``mat_x`` with a matrix ``mat_y`` passes
    # ``check_matmul``, but then falls through to the matrix-matrix branch,
    # which indexes ``mat_x`` as 2-D. ``matmul`` avoids this by transposing
    # first; call this helper only through ``matmul``.
    shape_x = static(mat_x.get_shape())
    shape_y = static(mat_y.get_shape())
    if static(len(shape_x) == 1 and len(shape_y) == 1):
        return dot(mat_x, mat_y)
    if static(len(shape_y) == 1):
        zero_elem = mat_x[0, 0] * mat_y[0] * 0  # for correct return type
        vec_z = _filled_vector(shape_x[0], None, zero_elem)
        for i in static(range(shape_x[0])):
            for j in static(range(shape_x[1])):
                vec_z[i] = vec_z[i] + mat_x[i, j] * mat_y[j]
        return vec_z
    zero_elem = mat_x[0, 0] * mat_y[0, 0] * 0  # for correct return type
    mat_z = _filled_matrix(shape_x[0], shape_y[1], None, zero_elem)
    for i in static(range(shape_x[0])):
        for j in static(range(shape_y[1])):
            for k in static(range(shape_x[1])):
                mat_z[i, j] = mat_z[i, j] + mat_x[i, k] * mat_y[k, j]
    return mat_z
|
285
|
+
|
286
|
+
|
287
|
+
@pyfunc
def matmul(mat_x, mat_y):
    """Compute ``mat_x @ mat_y`` for any mix of vectors and matrices.

    A vector on the left of a matrix is computed as
    ``transpose(mat_y) @ mat_x``. That gives the same entries as the
    row-vector product. Every other combination goes straight to
    ``_matmul_helper``.
    """
    lhs_shape = static(mat_x.get_shape())
    rhs_shape = static(mat_y.get_shape())
    if static(len(lhs_shape) == 1 and len(rhs_shape) == 2):
        return _matmul_helper(transpose(mat_y), mat_x)
    return _matmul_helper(mat_x, mat_y)
|
294
|
+
|
295
|
+
|
296
|
+
@preconditions(
    arg_at(0, assert_vector("lhs for dot is not a vector")),
    arg_at(1, assert_vector("rhs for dot is not a vector")),
)
@pyfunc
def dot(vec_x, vec_y):
    """Return the inner product ``sum(vec_x * vec_y)`` of two vectors.

    ``sum`` is this module's tensor reduction. The preconditions only check
    that both arguments are vectors; they do not compare lengths.
    """
    return sum(vec_x * vec_y)
|
303
|
+
|
304
|
+
|
305
|
+
@preconditions(
    arg_at(0, assert_vector("lhs for cross is not a vector")),
    arg_at(1, assert_vector("rhs for cross is not a vector")),
    same_shapes,
    arg_at(0, dim_lt(0, 4)),
    # dim_lt only bounds the length from above. A length-1 vector would
    # otherwise skip both branches below and silently return None.
    arg_at(
        0,
        lambda vec: (
            vec.get_shape()[0] >= 2,
            f"cross product requires 2D or 3D vectors, got shape {vec.get_shape()}",
        ),
    ),
)
@pyfunc
def cross(vec_x, vec_y):
    """Cross product of two vectors of equal length 2 or 3.

    For 2D vectors, returns the scalar ``x0*y1 - x1*y0``, which is the
    z-component of the 3D cross product. For 3D vectors, returns a
    3-element ``Vector``.

    Raises:
        GsTaichiCompilationError: (via ``preconditions``) if the arguments are
            not vectors of the same shape with length 2 or 3.
    """
    shape = static(vec_x.get_shape())
    if static(shape[0] == 2):
        return vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0]
    if static(shape[0] == 3):
        return Vector(
            [
                vec_x[1] * vec_y[2] - vec_x[2] * vec_y[1],
                vec_x[2] * vec_y[0] - vec_x[0] * vec_y[2],
                vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0],
            ]
        )
    # unreachable: the preconditions restrict the length to 2 or 3
    return None
|
325
|
+
|
326
|
+
|
327
|
+
@preconditions(
    arg_at(0, assert_vector("lhs for outer_product is not a vector")),
    arg_at(1, assert_vector("rhs for outer_product is not a vector")),
)
@pyfunc
def outer_product(vec_x, vec_y):
    """Return the outer product of two vectors as a ``Matrix``.

    Entry ``[r][c]`` is ``vec_x[r] * vec_y[c]``. The result has ``len(vec_x)``
    rows and ``len(vec_y)`` columns.
    """
    x_dims = static(vec_x.get_shape())
    y_dims = static(vec_y.get_shape())
    return Matrix(
        [
            [vec_x[r] * vec_y[c] for c in static(range(y_dims[0]))]
            for r in static(range(x_dims[0]))
        ]
    )
|
336
|
+
|
337
|
+
|
338
|
+
@preconditions(assert_tensor)
@func
def cast(mat, dtype: template()):
    """Return ``mat`` converted to ``dtype`` via ``ops_mod.cast``.

    Unlike the surrounding helpers, which use ``pyfunc``, this one is
    decorated with ``func``.
    """
    return ops_mod.cast(mat, dtype)
|
@@ -0,0 +1,190 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import functools
|
4
|
+
|
5
|
+
from gstaichi.lang.exception import GsTaichiCompilationError
|
6
|
+
from gstaichi.lang.expr import Expr
|
7
|
+
from gstaichi.lang.matrix import Matrix
|
8
|
+
|
9
|
+
|
10
|
+
def do_check(checker_fns, *args, **kwargs):
    """Run checkers in order and report the first failure.

    Each checker is called as ``checker(*args, **kwargs)`` and must return an
    ``(ok, msg)`` pair. Evaluation stops at the first checker whose ``ok`` is
    falsy.

    Returns:
        ``(False, msg)`` for the first failing checker, else ``(True, None)``.
    """
    outcomes = (checker(*args, **kwargs) for checker in checker_fns)
    return next(
        ((False, reason) for passed, reason in outcomes if not passed),
        (True, None),
    )
|
16
|
+
|
17
|
+
|
18
|
+
def preconditions(*checker_funcs):
    """Decorator factory that validates call arguments before running a function.

    Every call to the decorated function first runs ``checker_funcs`` through
    ``do_check``. On the first failure, ``GsTaichiCompilationError`` is raised
    with the checker's message. Otherwise the wrapped function is called and
    its result returned. ``functools.wraps`` keeps the wrapped function's
    metadata.
    """

    def decorator(target):
        @functools.wraps(target)
        def checked(*args, **kwargs):
            passed, failure_msg = do_check(checker_funcs, *args, **kwargs)
            if passed:
                return target(*args, **kwargs)
            raise GsTaichiCompilationError(failure_msg)

        return checked

    return decorator
|
30
|
+
|
31
|
+
|
32
|
+
def arg_at(indices, *fns):
    """Build a precondition checker that applies ``fns`` to selected arguments.

    Args:
        indices: a single positional index or keyword name, or an iterable of
            them. Names present in the call's keyword arguments are looked up
            there; anything else indexes the positional arguments.
        *fns: checkers, each taking one argument and returning ``(ok, msg)``.

    Returns:
        A checker ``check(*args, **kwargs) -> (ok, msg)`` for ``preconditions``.
    """
    # Normalize once, at construction time, instead of rebinding the closure
    # variable on the first call. A bare str is a single keyword name; it must
    # not be iterated character by character. Materializing iterables means
    # a generator is not exhausted after the first call.
    if isinstance(indices, (int, str)):
        selected = [indices]
    else:
        selected = list(indices)

    def check(*args, **kwargs):
        for i in selected:
            arg = kwargs[i] if i in kwargs else args[i]
            ok, msg = do_check(fns, arg)
            if not ok:
                return False, msg
        return True, None

    return check
|
48
|
+
|
49
|
+
|
50
|
+
def assert_tensor(m, msg="not tensor type: {}"):
    """Check that ``m`` is a ``Matrix`` or a tensor-valued ``Expr``.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, msg.format(type(m)))``.
    """
    is_tensor = isinstance(m, Matrix) or (isinstance(m, Expr) and m.is_tensor())
    if is_tensor:
        return True, None
    return False, msg.format(type(m))
|
56
|
+
|
57
|
+
|
58
|
+
def assert_vector(msg="expected a vector, got {}"):
    """Build a checker accepting an ``Expr`` or ``Matrix`` with a 1-D shape.

    Args:
        msg: failure message template, formatted with the rejected value's type.

    Returns:
        A one-argument checker returning ``(ok, msg)``.
    """

    def _is_vector(candidate):
        if not isinstance(candidate, (Expr, Matrix)) or len(candidate.get_shape()) != 1:
            return False, msg.format(type(candidate))
        return True, None

    return _is_vector
|
65
|
+
|
66
|
+
|
67
|
+
def assert_list(x, msg="not a list: {}"):
    """Check that ``x`` is a Python ``list``.

    Returns:
        ``(True, None)`` for a list, otherwise ``(False, msg.format(type(x)))``.
    """
    return (True, None) if isinstance(x, list) else (False, msg.format(type(x)))
|
71
|
+
|
72
|
+
|
73
|
+
def arg_foreach_check(*arg_indices, fns=(), logic="or", msg=None):
    """Build a checker that validates every element of the selected arguments.

    Args:
        *arg_indices: positional indices (or keyword names) of the arguments to
            inspect; each selected argument must be iterable.
        fns: checkers ``fn(element) -> (ok, msg)``.
        logic: ``"or"`` -- each element must pass at least one of ``fns``;
            ``"and"`` -- each element must pass all of ``fns``.
        msg: message returned on failure. When omitted, a generic message
            naming the failing argument is returned instead of ``None``.

    Returns:
        A checker ``check(*args, **kwargs) -> (ok, msg)`` for ``preconditions``.

    Raises:
        ValueError: if ``logic`` is neither ``"or"`` nor ``"and"``.
    """
    # Validate eagerly, so a typo surfaces when the decorator is applied rather
    # than on the first call.
    if logic not in ("or", "and"):
        raise ValueError(f"Unknown logic: {logic}")
    # A tuple default avoids the shared-mutable-default pitfall. Copying also
    # protects against the caller mutating the list later, and stops a
    # generator from being exhausted after the first element.
    fns = tuple(fns)

    def element_ok(element):
        if logic == "or":
            return any(do_check([fn], element)[0] for fn in fns)
        return do_check(fns, element)[0]

    def check(*args, **kwargs):
        for i in arg_indices:
            arg = kwargs[i] if i in kwargs else args[i]
            if not all(element_ok(a) for a in arg):
                failure = msg if msg is not None else f"element-wise '{logic}' check failed for argument {i}"
                return False, failure
        return True, None

    return check
|
100
|
+
|
101
|
+
|
102
|
+
def get_list_shape(x):
    """Return the shape of a (possibly nested) Python list as a list of ints.

    Non-list elements count as scalars. ``[[1, 2], [3, 4]]`` -> ``[2, 2]``;
    ``[1, 2, 3]`` -> ``[3]``; ``[]`` -> ``[0]``.

    Raises:
        ValueError: if the nesting is ragged, i.e. sibling elements have
            different shapes (including a mix of scalars and lists).
    """
    inner_shape = None
    for element in x:
        cur_shape = get_list_shape(element) if isinstance(element, list) else []
        # Compare against ``None`` rather than relying on truthiness: a scalar
        # shape ``[]`` is falsy, but it is still a shape that siblings must match.
        if inner_shape is None:
            inner_shape = cur_shape
        elif cur_shape != inner_shape:
            raise ValueError(f"ragged nested list: element shapes {inner_shape} and {cur_shape} differ")
    # An empty list has no elements to inspect, so it has no inner dimensions.
    return [len(x)] + (inner_shape if inner_shape is not None else [])
|
117
|
+
|
118
|
+
|
119
|
+
def same_shapes(*xs):
    """Check that all arguments have the same shape.

    Accepts ``Matrix`` objects, (nested) Python lists, and ``Expr`` tensors.
    Shapes are normalized to tuples before comparison.

    Returns:
        ``(True, None)`` if all shapes agree, otherwise ``(False, message)``.
        Unsupported argument types and ragged lists also yield ``(False, message)``.
    """
    shapes = []
    for x in xs:
        if isinstance(x, Matrix):
            shapes.append(tuple(x.get_shape()))
        elif isinstance(x, list):
            try:
                shapes.append(tuple(get_list_shape(x)))
            except ValueError as err:
                # Report ragged input as a failed check instead of crashing.
                return False, f"same_shapes() received a ragged list: {err}"
        elif isinstance(x, Expr):
            shapes.append(tuple(x.ptr.get_rvalue_type().shape()))
        else:
            return False, f"same_shapes() received an unexpected argument of type: {type(x)}"

    if len(set(shapes)) != 1:
        return False, f"required shapes to be the same, got shapes {shapes}"
    return True, None
|
134
|
+
|
135
|
+
|
136
|
+
def square_matrix(x):
    """Check that ``x`` is a 2-D tensor whose two dimensions are equal.

    Returns:
        ``(True, None)`` on success. Otherwise ``(False, message)``, which
        includes the case where ``x`` is not a tensor at all.
    """
    # Propagate the tensor check instead of ignoring it. A non-tensor would
    # otherwise crash on ``get_shape`` with an unhelpful AttributeError.
    ok, msg = assert_tensor(x)
    if not ok:
        return False, msg
    shape = x.get_shape()
    if len(shape) != 2 or shape[0] != shape[1]:
        return False, f"expected a square matrix, got shape {shape}"
    return True, None
|
142
|
+
|
143
|
+
|
144
|
+
def dim_lt(dim, limit):
    """Build a checker requiring ``x.get_shape()[dim] < limit``.

    The returned checker first verifies that its argument is a tensor, and
    reports that failure instead of crashing on ``get_shape``.

    Returns:
        A one-argument checker returning ``(ok, msg)``.
    """

    def check(x):
        ok, msg = assert_tensor(x)
        if not ok:
            return False, msg
        shape = x.get_shape()
        return shape[dim] < limit, f"only dimension < {limit} is supported, got shape {shape}"

    return check
|
151
|
+
|
152
|
+
|
153
|
+
def is_int_const(x):
    """Check that ``x`` is a Python ``int`` or an ``Expr`` with a known integer value.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, message)``.
    """
    is_python_int = isinstance(x, int)
    if is_python_int or (isinstance(x, Expr) and x.val_int() is not None):
        return True, None
    return False, f"not an integer: {x} of type {type(x).__name__}"
|
159
|
+
|
160
|
+
|
161
|
+
def check_matmul(x, y):
    """Validate operand shapes for ``x @ y``.

    Both operands must be tensors. Rules:

    - Two vectors always pass.
    - A vector ``x`` with a matrix ``y`` requires ``len(x) == y.shape[0]``.
    - A matrix ``x`` requires ``x.shape[1] == y.shape[0]``.

    Returns:
        ``(True, None)`` if compatible, otherwise ``(False, message)``.
    """
    # Propagate tensor-check failures instead of discarding them. The message
    # is a template that ``assert_tensor`` formats with the operand's type.
    ok, msg = assert_tensor(x, "left hand side is not a matrix: {}")
    if not ok:
        return False, msg
    ok, msg = assert_tensor(y, "right hand side is not a matrix: {}")
    if not ok:
        return False, msg

    x_shape = x.get_shape()
    y_shape = y.get_shape()
    if len(x_shape) == 1:
        if len(y_shape) == 1:
            return True, None
        if x_shape[0] != y_shape[0]:
            return (
                False,
                f"dimension mismatch between {x_shape} and {y_shape} for left multiplication",
            )
    elif x_shape[1] != y_shape[0]:
        return (
            False,
            f"dimension mismatch between {x_shape} and {y_shape} for matrix multiplication",
        )
    return True, None
|
181
|
+
|
182
|
+
|
183
|
+
def check_transpose(x):
    """Precondition for ``transpose``: ``x`` must be a tensor but not a vector.

    Returns:
        ``assert_tensor``'s result for non-vectors. For vectors, returns
        ``(False, message)`` suggesting ``outer_product`` instead.
    """
    tensor_ok, tensor_msg = assert_tensor(x)
    if not (tensor_ok and len(x.get_shape()) == 1):
        return tensor_ok, tensor_msg
    return (
        False,
        "`transpose()` cannot apply to a vector. If you want something like `a @ b.transpose()`, write `a.outer_product(b)` instead.",
    )
|