gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 1.0.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +40 -0
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
- {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
- {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools.lib +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
- gstaichi-1.0.1.dist-info/RECORD +135 -0
- gstaichi-1.0.1.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
- gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/CHANGELOG.md +0 -20
- taichi/__init__.py +0 -44
- taichi/__main__.py +0 -5
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_main.py +0 -552
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -7,34 +7,34 @@ from itertools import product
|
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
10
|
+
from gstaichi._lib import core as ti_python_core
|
11
|
+
from gstaichi._lib.utils import ti_python_core as _ti_python_core
|
12
|
+
from gstaichi.lang import expr, impl, runtime_ops
|
13
|
+
from gstaichi.lang import ops as ops_mod
|
14
|
+
from gstaichi.lang._ndarray import Ndarray, NdarrayHostAccess
|
15
|
+
from gstaichi.lang.common_ops import GsTaichiOperations
|
16
|
+
from gstaichi.lang.exception import (
|
17
|
+
GsTaichiRuntimeError,
|
18
|
+
GsTaichiRuntimeTypeError,
|
19
|
+
GsTaichiSyntaxError,
|
20
|
+
GsTaichiTypeError,
|
21
21
|
)
|
22
|
-
from
|
23
|
-
from
|
22
|
+
from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
|
23
|
+
from gstaichi.lang.util import (
|
24
24
|
cook_dtype,
|
25
25
|
get_traceback,
|
26
|
+
gstaichi_scope,
|
26
27
|
in_python_scope,
|
27
28
|
python_scope,
|
28
|
-
taichi_scope,
|
29
29
|
to_numpy_type,
|
30
30
|
to_paddle_type,
|
31
31
|
to_pytorch_type,
|
32
32
|
warning,
|
33
33
|
)
|
34
|
-
from
|
35
|
-
from
|
36
|
-
from
|
37
|
-
from
|
34
|
+
from gstaichi.types import primitive_types
|
35
|
+
from gstaichi.types.compound_types import CompoundType
|
36
|
+
from gstaichi.types.enums import Layout
|
37
|
+
from gstaichi.types.utils import is_signed
|
38
38
|
|
39
39
|
_type_factory = _ti_python_core.get_type_factory_instance()
|
40
40
|
|
@@ -71,7 +71,7 @@ def _gen_swizzles(cls):
|
|
71
71
|
if len(diff):
|
72
72
|
valid_attribs = tuple(sorted(valid_attribs))
|
73
73
|
pattern = tuple(pattern)
|
74
|
-
raise
|
74
|
+
raise GsTaichiSyntaxError(f"vec{instance.n} only has " f"attributes={valid_attribs}, got={pattern}")
|
75
75
|
|
76
76
|
return check
|
77
77
|
|
@@ -116,7 +116,7 @@ def _gen_swizzles(cls):
|
|
116
116
|
@python_scope
|
117
117
|
def prop_setter(instance, value):
|
118
118
|
if len(pattern) != len(value):
|
119
|
-
raise
|
119
|
+
raise GsTaichiRuntimeError(f"value len does not match the swizzle pattern={pattern}")
|
120
120
|
checker(instance, pattern)
|
121
121
|
for ch, val in zip(pattern, value):
|
122
122
|
instance[key_group.index(ch)] = val
|
@@ -138,9 +138,9 @@ def _infer_entry_dt(entry):
|
|
138
138
|
if isinstance(entry, expr.Expr):
|
139
139
|
dt = entry.ptr.get_rvalue_type()
|
140
140
|
if dt == ti_python_core.DataType_unknown:
|
141
|
-
raise
|
141
|
+
raise GsTaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
|
142
142
|
return dt
|
143
|
-
raise
|
143
|
+
raise GsTaichiTypeError("Element type of the matrix is invalid.")
|
144
144
|
|
145
145
|
|
146
146
|
def _infer_array_dt(arr):
|
@@ -204,18 +204,18 @@ def _write_host_access(x, value):
|
|
204
204
|
|
205
205
|
|
206
206
|
@_gen_swizzles
|
207
|
-
class Matrix(
|
207
|
+
class Matrix(GsTaichiOperations):
|
208
208
|
"""The matrix class.
|
209
209
|
|
210
210
|
A matrix is a 2-D rectangular array with scalar entries, it's row-majored, and is
|
211
211
|
aligned continuously. We recommend only use matrix with no more than 32 elements for
|
212
212
|
efficiency considerations.
|
213
213
|
|
214
|
-
Note: in
|
214
|
+
Note: in gstaichi a matrix is strictly two-dimensional and only stores scalars.
|
215
215
|
|
216
216
|
Args:
|
217
217
|
arr (Union[list, tuple, np.ndarray]): the initial values of a matrix.
|
218
|
-
dt (:mod:`~
|
218
|
+
dt (:mod:`~gstaichi.types.primitive_types`): the element data type.
|
219
219
|
ndim (int optional): the number of dimensions of the matrix; forced reshape if given.
|
220
220
|
|
221
221
|
Example::
|
@@ -245,13 +245,13 @@ class Matrix(TaichiOperations):
|
|
245
245
|
0
|
246
246
|
"""
|
247
247
|
|
248
|
-
|
248
|
+
_is_gstaichi_class = True
|
249
249
|
_is_matrix_class = True
|
250
250
|
__array_priority__ = 1000
|
251
251
|
|
252
252
|
def __init__(self, arr, dt=None):
|
253
253
|
if not isinstance(arr, (list, tuple, np.ndarray)):
|
254
|
-
raise
|
254
|
+
raise GsTaichiTypeError("An Matrix/Vector can only be initialized with an array-like object")
|
255
255
|
if len(arr) == 0:
|
256
256
|
self.ndim = 0
|
257
257
|
self.n, self.m = 0, 0
|
@@ -280,7 +280,7 @@ class Matrix(TaichiOperations):
|
|
280
280
|
|
281
281
|
if self.n * self.m > 32:
|
282
282
|
warning(
|
283
|
-
f"
|
283
|
+
f"GsTaichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested."
|
284
284
|
" Matrices/vectors will be automatically unrolled at compile-time for performance."
|
285
285
|
" So the compilation time could be extremely long if the matrix size is too big."
|
286
286
|
" You may use a field to store a large matrix like this, e.g.:\n"
|
@@ -308,7 +308,7 @@ class Matrix(TaichiOperations):
|
|
308
308
|
The matrix-matrix product or matrix-vector product.
|
309
309
|
|
310
310
|
"""
|
311
|
-
from
|
311
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
312
312
|
|
313
313
|
return matrix_ops.matmul(self, other)
|
314
314
|
|
@@ -418,16 +418,16 @@ class Matrix(TaichiOperations):
|
|
418
418
|
return [[_read_host_access(self.entries[i][j]) for j in range(self.m)] for i in range(self.n)]
|
419
419
|
return self.entries.tolist()
|
420
420
|
|
421
|
-
@
|
421
|
+
@gstaichi_scope
|
422
422
|
def cast(self, dtype):
|
423
423
|
"""Cast the matrix elements to a specified data type.
|
424
424
|
|
425
425
|
Args:
|
426
|
-
dtype (:mod:`~
|
426
|
+
dtype (:mod:`~gstaichi.types.primitive_types`): data type of the
|
427
427
|
returned matrix.
|
428
428
|
|
429
429
|
Returns:
|
430
|
-
:class:`
|
430
|
+
:class:`gstaichi.Matrix`: A new matrix with the specified data dtype.
|
431
431
|
|
432
432
|
Example::
|
433
433
|
|
@@ -455,7 +455,7 @@ class Matrix(TaichiOperations):
|
|
455
455
|
5
|
456
456
|
"""
|
457
457
|
# pylint: disable-msg=C0415
|
458
|
-
from
|
458
|
+
from gstaichi.lang import matrix_ops
|
459
459
|
|
460
460
|
return matrix_ops.trace(self)
|
461
461
|
|
@@ -466,12 +466,12 @@ class Matrix(TaichiOperations):
|
|
466
466
|
The matrix dimension should be less than or equal to 4.
|
467
467
|
|
468
468
|
Returns:
|
469
|
-
:class:`~
|
469
|
+
:class:`~gstaichi.Matrix`: The inverse of a matrix.
|
470
470
|
|
471
471
|
Raises:
|
472
472
|
Exception: Inversions of matrices with sizes >= 5 are not supported.
|
473
473
|
"""
|
474
|
-
from
|
474
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
475
475
|
|
476
476
|
return matrix_ops.inverse(self)
|
477
477
|
|
@@ -492,7 +492,7 @@ class Matrix(TaichiOperations):
|
|
492
492
|
[0.6, 0.8]
|
493
493
|
"""
|
494
494
|
# pylint: disable-msg=C0415
|
495
|
-
from
|
495
|
+
from gstaichi.lang import matrix_ops
|
496
496
|
|
497
497
|
return matrix_ops.normalized(self, eps)
|
498
498
|
|
@@ -500,7 +500,7 @@ class Matrix(TaichiOperations):
|
|
500
500
|
"""Returns the transpose of a matrix.
|
501
501
|
|
502
502
|
Returns:
|
503
|
-
:class:`~
|
503
|
+
:class:`~gstaichi.Matrix`: The transpose of this matrix.
|
504
504
|
|
505
505
|
Example::
|
506
506
|
|
@@ -509,11 +509,11 @@ class Matrix(TaichiOperations):
|
|
509
509
|
[[0, 2], [1, 3]]
|
510
510
|
"""
|
511
511
|
# pylint: disable=C0415
|
512
|
-
from
|
512
|
+
from gstaichi.lang import matrix_ops
|
513
513
|
|
514
514
|
return matrix_ops.transpose(self)
|
515
515
|
|
516
|
-
@
|
516
|
+
@gstaichi_scope
|
517
517
|
def determinant(a):
|
518
518
|
"""Returns the determinant of this matrix.
|
519
519
|
|
@@ -527,7 +527,7 @@ class Matrix(TaichiOperations):
|
|
527
527
|
Exception: Determinants of matrices with sizes >= 5 are not supported.
|
528
528
|
"""
|
529
529
|
# pylint: disable=C0415
|
530
|
-
from
|
530
|
+
from gstaichi.lang import matrix_ops
|
531
531
|
|
532
532
|
return matrix_ops.determinant(a)
|
533
533
|
|
@@ -541,7 +541,7 @@ class Matrix(TaichiOperations):
|
|
541
541
|
val (TypeVar): value for the diagonal elements.
|
542
542
|
|
543
543
|
Returns:
|
544
|
-
:class:`~
|
544
|
+
:class:`~gstaichi.Matrix`: The wanted diagonal matrix.
|
545
545
|
|
546
546
|
Example::
|
547
547
|
|
@@ -551,7 +551,7 @@ class Matrix(TaichiOperations):
|
|
551
551
|
[0, 0, 1]]
|
552
552
|
"""
|
553
553
|
# pylint: disable=C0415
|
554
|
-
from
|
554
|
+
from gstaichi.lang import matrix_ops
|
555
555
|
|
556
556
|
return matrix_ops.diag(dim, val)
|
557
557
|
|
@@ -565,7 +565,7 @@ class Matrix(TaichiOperations):
|
|
565
565
|
10
|
566
566
|
"""
|
567
567
|
# pylint: disable=C0415
|
568
|
-
from
|
568
|
+
from gstaichi.lang import matrix_ops
|
569
569
|
|
570
570
|
return matrix_ops.sum(self)
|
571
571
|
|
@@ -586,12 +586,12 @@ class Matrix(TaichiOperations):
|
|
586
586
|
The square root of the sum of the absolute squares of its elements.
|
587
587
|
"""
|
588
588
|
# pylint: disable=C0415
|
589
|
-
from
|
589
|
+
from gstaichi.lang import matrix_ops
|
590
590
|
|
591
591
|
return matrix_ops.norm(self, eps=eps)
|
592
592
|
|
593
593
|
def norm_inv(self, eps=0):
|
594
|
-
"""The inverse of the matrix :func:`~
|
594
|
+
"""The inverse of the matrix :func:`~gstaichi.lang.matrix.Matrix.norm`.
|
595
595
|
|
596
596
|
Args:
|
597
597
|
eps (float): a safe-guard value for sqrt, usually 0.
|
@@ -600,28 +600,28 @@ class Matrix(TaichiOperations):
|
|
600
600
|
The inverse of the matrix/vector `norm`.
|
601
601
|
"""
|
602
602
|
# pylint: disable=C0415
|
603
|
-
from
|
603
|
+
from gstaichi.lang import matrix_ops
|
604
604
|
|
605
605
|
return matrix_ops.norm_inv(self, eps=eps)
|
606
606
|
|
607
607
|
def norm_sqr(self):
|
608
608
|
"""Returns the sum of the absolute squares of its elements."""
|
609
609
|
# pylint: disable=C0415
|
610
|
-
from
|
610
|
+
from gstaichi.lang import matrix_ops
|
611
611
|
|
612
612
|
return matrix_ops.norm_sqr(self)
|
613
613
|
|
614
614
|
def max(self):
|
615
615
|
"""Returns the maximum element value."""
|
616
616
|
# pylint: disable=C0415
|
617
|
-
from
|
617
|
+
from gstaichi.lang import matrix_ops
|
618
618
|
|
619
619
|
return matrix_ops.max(self)
|
620
620
|
|
621
621
|
def min(self):
|
622
622
|
"""Returns the minimum element value."""
|
623
623
|
# pylint: disable=C0415
|
624
|
-
from
|
624
|
+
from gstaichi.lang import matrix_ops
|
625
625
|
|
626
626
|
return matrix_ops.min(self)
|
627
627
|
|
@@ -638,7 +638,7 @@ class Matrix(TaichiOperations):
|
|
638
638
|
True
|
639
639
|
"""
|
640
640
|
# pylint: disable=C0415
|
641
|
-
from
|
641
|
+
from gstaichi.lang import matrix_ops
|
642
642
|
|
643
643
|
return matrix_ops.any(self)
|
644
644
|
|
@@ -655,7 +655,7 @@ class Matrix(TaichiOperations):
|
|
655
655
|
False
|
656
656
|
"""
|
657
657
|
# pylint: disable=C0415
|
658
|
-
from
|
658
|
+
from gstaichi.lang import matrix_ops
|
659
659
|
|
660
660
|
return matrix_ops.all(self)
|
661
661
|
|
@@ -673,7 +673,7 @@ class Matrix(TaichiOperations):
|
|
673
673
|
[-1, -1, -1, -1]
|
674
674
|
"""
|
675
675
|
# pylint: disable=C0415
|
676
|
-
from
|
676
|
+
from gstaichi.lang import matrix_ops
|
677
677
|
|
678
678
|
return matrix_ops.fill(self, val)
|
679
679
|
|
@@ -694,7 +694,7 @@ class Matrix(TaichiOperations):
|
|
694
694
|
return np.array(self.to_list())
|
695
695
|
return self.entries
|
696
696
|
|
697
|
-
@
|
697
|
+
@gstaichi_scope
|
698
698
|
def __ti_repr__(self):
|
699
699
|
yield "["
|
700
700
|
for i in range(self.n):
|
@@ -718,9 +718,9 @@ class Matrix(TaichiOperations):
|
|
718
718
|
to invoke `repr` to show the object... e.g.:
|
719
719
|
|
720
720
|
TypeError: make_const_expr_f32(): incompatible function arguments. The following argument types are supported:
|
721
|
-
1. (arg0: float) ->
|
721
|
+
1. (arg0: float) -> gstaichi_python.Expr
|
722
722
|
|
723
|
-
Invoked with: <
|
723
|
+
Invoked with: <GsTaichi 2x1 Matrix>
|
724
724
|
|
725
725
|
So we have to make it happy with a dummy string...
|
726
726
|
"""
|
@@ -731,7 +731,7 @@ class Matrix(TaichiOperations):
|
|
731
731
|
return str(self.to_numpy())
|
732
732
|
|
733
733
|
@staticmethod
|
734
|
-
@
|
734
|
+
@gstaichi_scope
|
735
735
|
def zero(dt, n, m=None):
|
736
736
|
"""Constructs a Matrix filled with zeros.
|
737
737
|
|
@@ -741,17 +741,17 @@ class Matrix(TaichiOperations):
|
|
741
741
|
m (int, optional): The second dimension (column) of the matrix.
|
742
742
|
|
743
743
|
Returns:
|
744
|
-
:class:`~
|
744
|
+
:class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with zeros.
|
745
745
|
|
746
746
|
"""
|
747
|
-
from
|
747
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
748
748
|
|
749
749
|
if m is None:
|
750
750
|
return matrix_ops._filled_vector(n, dt, 0)
|
751
751
|
return matrix_ops._filled_matrix(n, m, dt, 0)
|
752
752
|
|
753
753
|
@staticmethod
|
754
|
-
@
|
754
|
+
@gstaichi_scope
|
755
755
|
def one(dt, n, m=None):
|
756
756
|
"""Constructs a Matrix filled with ones.
|
757
757
|
|
@@ -761,17 +761,17 @@ class Matrix(TaichiOperations):
|
|
761
761
|
m (int, optional): The second dimension (column) of the matrix.
|
762
762
|
|
763
763
|
Returns:
|
764
|
-
:class:`~
|
764
|
+
:class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with ones.
|
765
765
|
|
766
766
|
"""
|
767
|
-
from
|
767
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
768
768
|
|
769
769
|
if m is None:
|
770
770
|
return matrix_ops._filled_vector(n, dt, 1)
|
771
771
|
return matrix_ops._filled_matrix(n, m, dt, 1)
|
772
772
|
|
773
773
|
@staticmethod
|
774
|
-
@
|
774
|
+
@gstaichi_scope
|
775
775
|
def unit(n, i, dt=None):
|
776
776
|
"""Constructs a n-D vector with the `i`-th entry being equal to one and
|
777
777
|
the remaining entries are all zeros.
|
@@ -779,10 +779,10 @@ class Matrix(TaichiOperations):
|
|
779
779
|
Args:
|
780
780
|
n (int): The length of the vector.
|
781
781
|
i (int): The index of the entry that will be filled with one.
|
782
|
-
dt (:mod:`~
|
782
|
+
dt (:mod:`~gstaichi.types.primitive_types`, optional): The desired data type.
|
783
783
|
|
784
784
|
Returns:
|
785
|
-
:class:`~
|
785
|
+
:class:`~gstaichi.Matrix`: The returned vector.
|
786
786
|
|
787
787
|
Example::
|
788
788
|
|
@@ -790,7 +790,7 @@ class Matrix(TaichiOperations):
|
|
790
790
|
>>> A
|
791
791
|
[0, 1, 0]
|
792
792
|
"""
|
793
|
-
from
|
793
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
794
794
|
|
795
795
|
if dt is None:
|
796
796
|
dt = int
|
@@ -798,7 +798,7 @@ class Matrix(TaichiOperations):
|
|
798
798
|
return matrix_ops._unit_vector(n, i, dt)
|
799
799
|
|
800
800
|
@staticmethod
|
801
|
-
@
|
801
|
+
@gstaichi_scope
|
802
802
|
def identity(dt, n):
|
803
803
|
"""Constructs an identity Matrix with shape (n, n).
|
804
804
|
|
@@ -807,9 +807,9 @@ class Matrix(TaichiOperations):
|
|
807
807
|
n (int): The number of rows/columns.
|
808
808
|
|
809
809
|
Returns:
|
810
|
-
:class:`~
|
810
|
+
:class:`~gstaichi.Matrix`: An `n x n` identity matrix.
|
811
811
|
"""
|
812
|
-
from
|
812
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
813
813
|
|
814
814
|
return matrix_ops._identity_matrix(n, dt)
|
815
815
|
|
@@ -846,7 +846,7 @@ class Matrix(TaichiOperations):
|
|
846
846
|
Structure (AOS) or Structure Of Array (SOA).
|
847
847
|
|
848
848
|
Returns:
|
849
|
-
:class:`~
|
849
|
+
:class:`~gstaichi.Matrix`: A matrix.
|
850
850
|
"""
|
851
851
|
entries = []
|
852
852
|
element_dim = ndim if ndim is not None else 2
|
@@ -897,9 +897,9 @@ class Matrix(TaichiOperations):
|
|
897
897
|
|
898
898
|
if shape is None:
|
899
899
|
if offset is not None:
|
900
|
-
raise
|
900
|
+
raise GsTaichiSyntaxError("shape cannot be None when offset is set")
|
901
901
|
if order is not None:
|
902
|
-
raise
|
902
|
+
raise GsTaichiSyntaxError("shape cannot be None when order is set")
|
903
903
|
else:
|
904
904
|
if isinstance(shape, numbers.Number):
|
905
905
|
shape = (shape,)
|
@@ -907,22 +907,22 @@ class Matrix(TaichiOperations):
|
|
907
907
|
offset = (offset,)
|
908
908
|
dim = len(shape)
|
909
909
|
if offset is not None and dim != len(offset):
|
910
|
-
raise
|
910
|
+
raise GsTaichiSyntaxError(
|
911
911
|
f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
|
912
912
|
)
|
913
913
|
axis_seq = []
|
914
914
|
shape_seq = []
|
915
915
|
if order is not None:
|
916
916
|
if dim != len(order):
|
917
|
-
raise
|
917
|
+
raise GsTaichiSyntaxError(
|
918
918
|
f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
|
919
919
|
)
|
920
920
|
if dim != len(set(order)):
|
921
|
-
raise
|
921
|
+
raise GsTaichiSyntaxError("The axes in order must be different")
|
922
922
|
for ch in order:
|
923
923
|
axis = ord(ch) - ord("i")
|
924
924
|
if axis < 0 or axis >= dim:
|
925
|
-
raise
|
925
|
+
raise GsTaichiSyntaxError(f"Invalid axis {ch}")
|
926
926
|
axis_seq.append(axis)
|
927
927
|
shape_seq.append(shape[axis])
|
928
928
|
else:
|
@@ -949,7 +949,7 @@ class Matrix(TaichiOperations):
|
|
949
949
|
@classmethod
|
950
950
|
@python_scope
|
951
951
|
def ndarray(cls, n, m, dtype, shape):
|
952
|
-
"""Defines a
|
952
|
+
"""Defines a GsTaichi ndarray with matrix elements.
|
953
953
|
This function must be called in Python scope, and after `ti.init` is called.
|
954
954
|
|
955
955
|
Args:
|
@@ -960,7 +960,7 @@ class Matrix(TaichiOperations):
|
|
960
960
|
|
961
961
|
Example::
|
962
962
|
|
963
|
-
The code below shows how a
|
963
|
+
The code below shows how a GsTaichi ndarray with matrix elements \
|
964
964
|
can be declared and defined::
|
965
965
|
|
966
966
|
>>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
|
@@ -972,13 +972,13 @@ class Matrix(TaichiOperations):
|
|
972
972
|
@staticmethod
|
973
973
|
def rows(rows):
|
974
974
|
"""Constructs a matrix by concatenating a list of
|
975
|
-
vectors/lists row by row. Must be called in
|
975
|
+
vectors/lists row by row. Must be called in GsTaichi scope.
|
976
976
|
|
977
977
|
Args:
|
978
978
|
rows (List): A list of Vector (1-D Matrix) or a list of list.
|
979
979
|
|
980
980
|
Returns:
|
981
|
-
:class:`~
|
981
|
+
:class:`~gstaichi.Matrix`: A matrix.
|
982
982
|
|
983
983
|
Example::
|
984
984
|
|
@@ -992,7 +992,7 @@ class Matrix(TaichiOperations):
|
|
992
992
|
>>> test()
|
993
993
|
[[1, 2, 3], [4, 5, 6]]
|
994
994
|
"""
|
995
|
-
from
|
995
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
996
996
|
|
997
997
|
return matrix_ops.rows(rows)
|
998
998
|
|
@@ -1004,7 +1004,7 @@ class Matrix(TaichiOperations):
|
|
1004
1004
|
cols (List): A list of Vector (1-D Matrix) or a list of list.
|
1005
1005
|
|
1006
1006
|
Returns:
|
1007
|
-
:class:`~
|
1007
|
+
:class:`~gstaichi.Matrix`: A matrix.
|
1008
1008
|
|
1009
1009
|
Example::
|
1010
1010
|
|
@@ -1018,7 +1018,7 @@ class Matrix(TaichiOperations):
|
|
1018
1018
|
>>> test()
|
1019
1019
|
[[1, 4], [2, 5], [3, 6]]
|
1020
1020
|
"""
|
1021
|
-
from
|
1021
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
1022
1022
|
|
1023
1023
|
return matrix_ops.cols(cols)
|
1024
1024
|
|
@@ -1034,7 +1034,7 @@ class Matrix(TaichiOperations):
|
|
1034
1034
|
To call this method, both multiplicatives must be vectors.
|
1035
1035
|
|
1036
1036
|
Args:
|
1037
|
-
other (:class:`~
|
1037
|
+
other (:class:`~gstaichi.Matrix`): The input Vector.
|
1038
1038
|
|
1039
1039
|
Returns:
|
1040
1040
|
DataType: The dot product result (scalar) of the two Vectors.
|
@@ -1046,7 +1046,7 @@ class Matrix(TaichiOperations):
|
|
1046
1046
|
>>> v1.dot(v2)
|
1047
1047
|
26
|
1048
1048
|
"""
|
1049
|
-
from
|
1049
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
1050
1050
|
|
1051
1051
|
return matrix_ops.dot(self, other)
|
1052
1052
|
|
@@ -1062,12 +1062,12 @@ class Matrix(TaichiOperations):
|
|
1062
1062
|
`v x w`.
|
1063
1063
|
|
1064
1064
|
Args:
|
1065
|
-
other (:class:`~
|
1065
|
+
other (:class:`~gstaichi.Matrix`): The input Vector.
|
1066
1066
|
|
1067
1067
|
Returns:
|
1068
|
-
:class:`~
|
1068
|
+
:class:`~gstaichi.Matrix`: The cross product of the two Vectors.
|
1069
1069
|
"""
|
1070
|
-
from
|
1070
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
1071
1071
|
|
1072
1072
|
return matrix_ops.cross(self, other)
|
1073
1073
|
|
@@ -1079,12 +1079,12 @@ class Matrix(TaichiOperations):
|
|
1079
1079
|
entry is equal to `xi*yj`.
|
1080
1080
|
|
1081
1081
|
Args:
|
1082
|
-
other (:class:`~
|
1082
|
+
other (:class:`~gstaichi.Matrix`): The input Vector.
|
1083
1083
|
|
1084
1084
|
Returns:
|
1085
|
-
:class:`~
|
1085
|
+
:class:`~gstaichi.Matrix`: The outer product of the two Vectors.
|
1086
1086
|
"""
|
1087
|
-
from
|
1087
|
+
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
1088
1088
|
|
1089
1089
|
return matrix_ops.outer_product(self, other)
|
1090
1090
|
|
@@ -1097,10 +1097,10 @@ class Vector(Matrix):
|
|
1097
1097
|
|
1098
1098
|
Args:
|
1099
1099
|
arr (Union[list, tuple, np.ndarray]): The initial values of the Vector.
|
1100
|
-
dt (:mod:`~
|
1100
|
+
dt (:mod:`~gstaichi.types.primitive_types`): data type of the vector.
|
1101
1101
|
|
1102
1102
|
Returns:
|
1103
|
-
:class:`~
|
1103
|
+
:class:`~gstaichi.Matrix`: A vector instance.
|
1104
1104
|
Example::
|
1105
1105
|
>>> u = ti.Vector([1, 2])
|
1106
1106
|
>>> print(u.m, u.n) # verify a vector is a matrix of shape (n, 1)
|
@@ -1125,7 +1125,7 @@ class Vector(Matrix):
|
|
1125
1125
|
@classmethod
|
1126
1126
|
@python_scope
|
1127
1127
|
def ndarray(cls, n, dtype, shape):
|
1128
|
-
"""Defines a
|
1128
|
+
"""Defines a GsTaichi ndarray with vector elements.
|
1129
1129
|
|
1130
1130
|
Args:
|
1131
1131
|
n (int): Size of the vector.
|
@@ -1133,7 +1133,7 @@ class Vector(Matrix):
|
|
1133
1133
|
shape (Union[int, tuple[int]]): Shape of the ndarray.
|
1134
1134
|
|
1135
1135
|
Example:
|
1136
|
-
The code below shows how a
|
1136
|
+
The code below shows how a GsTaichi ndarray with vector elements can be declared and defined::
|
1137
1137
|
|
1138
1138
|
>>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
|
1139
1139
|
"""
|
@@ -1143,7 +1143,7 @@ class Vector(Matrix):
|
|
1143
1143
|
|
1144
1144
|
|
1145
1145
|
class MatrixField(Field):
|
1146
|
-
"""
|
1146
|
+
"""GsTaichi matrix field with SNode implementation.
|
1147
1147
|
|
1148
1148
|
Args:
|
1149
1149
|
vars (List[Expr]): Field members.
|
@@ -1181,7 +1181,7 @@ class MatrixField(Field):
|
|
1181
1181
|
return None
|
1182
1182
|
|
1183
1183
|
def _calc_dynamic_index_stride(self):
|
1184
|
-
# Algorithm: https://github.com/taichi-dev/
|
1184
|
+
# Algorithm: https://github.com/taichi-dev/gstaichi/issues/3810
|
1185
1185
|
paths = [ScalarField(var).snode._path_from_root() for var in self.vars]
|
1186
1186
|
num_members = len(paths)
|
1187
1187
|
if num_members == 1:
|
@@ -1241,13 +1241,17 @@ class MatrixField(Field):
|
|
1241
1241
|
if self.ndim != 1:
|
1242
1242
|
assert len(val[0]) == self.m
|
1243
1243
|
if in_python_scope():
|
1244
|
-
from
|
1244
|
+
from gstaichi._kernels import ( # pylint: disable=C0415
|
1245
|
+
field_fill_python_scope, # pylint: disable=C0415
|
1246
|
+
)
|
1245
1247
|
|
1246
1248
|
field_fill_python_scope(self, val)
|
1247
1249
|
else:
|
1248
|
-
from
|
1250
|
+
from gstaichi._funcs import ( # pylint: disable=C0415
|
1251
|
+
field_fill_gstaichi_scope, # pylint: disable=C0415
|
1252
|
+
)
|
1249
1253
|
|
1250
|
-
|
1254
|
+
field_fill_gstaichi_scope(self, val)
|
1251
1255
|
|
1252
1256
|
@python_scope
|
1253
1257
|
def to_numpy(self, keep_dims=False, dtype=None):
|
@@ -1268,7 +1272,7 @@ class MatrixField(Field):
|
|
1268
1272
|
as_vector = self.m == 1 and not keep_dims
|
1269
1273
|
shape_ext = (self.n,) if as_vector else (self.n, self.m)
|
1270
1274
|
arr = np.zeros(self.shape + shape_ext, dtype=dtype)
|
1271
|
-
from
|
1275
|
+
from gstaichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
|
1272
1276
|
|
1273
1277
|
matrix_to_ext_arr(self, arr, as_vector)
|
1274
1278
|
runtime_ops.sync()
|
@@ -1280,7 +1284,7 @@ class MatrixField(Field):
|
|
1280
1284
|
Args:
|
1281
1285
|
device (torch.device, optional): The desired device of returned tensor.
|
1282
1286
|
keep_dims (bool, optional): Whether to keep the dimension after conversion.
|
1283
|
-
See :meth:`~
|
1287
|
+
See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.
|
1284
1288
|
|
1285
1289
|
Returns:
|
1286
1290
|
torch.tensor: The result torch tensor.
|
@@ -1291,7 +1295,7 @@ class MatrixField(Field):
|
|
1291
1295
|
shape_ext = (self.n,) if as_vector else (self.n, self.m)
|
1292
1296
|
# pylint: disable=E1101
|
1293
1297
|
arr = torch.empty(self.shape + shape_ext, dtype=to_pytorch_type(self.dtype), device=device)
|
1294
|
-
from
|
1298
|
+
from gstaichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
|
1295
1299
|
|
1296
1300
|
matrix_to_ext_arr(self, arr, as_vector)
|
1297
1301
|
runtime_ops.sync()
|
@@ -1303,7 +1307,7 @@ class MatrixField(Field):
|
|
1303
1307
|
Args:
|
1304
1308
|
place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
|
1305
1309
|
keep_dims (bool, optional): Whether to keep the dimension after conversion.
|
1306
|
-
See :meth:`~
|
1310
|
+
See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.
|
1307
1311
|
|
1308
1312
|
Returns:
|
1309
1313
|
paddle.Tensor: The result paddle tensor.
|
@@ -1318,7 +1322,7 @@ class MatrixField(Field):
|
|
1318
1322
|
paddle.empty(self.shape + shape_ext, to_paddle_type(self.dtype)),
|
1319
1323
|
place=place,
|
1320
1324
|
)
|
1321
|
-
from
|
1325
|
+
from gstaichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
|
1322
1326
|
|
1323
1327
|
matrix_to_ext_arr(self, arr, as_vector)
|
1324
1328
|
runtime_ops.sync()
|
@@ -1334,7 +1338,7 @@ class MatrixField(Field):
|
|
1334
1338
|
assert len(arr.shape) == len(self.shape) + 2
|
1335
1339
|
dim_ext = 1 if as_vector else 2
|
1336
1340
|
assert len(arr.shape) == len(self.shape) + dim_ext
|
1337
|
-
from
|
1341
|
+
from gstaichi._kernels import ext_arr_to_matrix # pylint: disable=C0415
|
1338
1342
|
|
1339
1343
|
ext_arr_to_matrix(arr, self, as_vector)
|
1340
1344
|
runtime_ops.sync()
|
@@ -1420,7 +1424,7 @@ class MatrixType(CompoundType):
|
|
1420
1424
|
|
1421
1425
|
"""
|
1422
1426
|
if len(args) == 0:
|
1423
|
-
raise
|
1427
|
+
raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
|
1424
1428
|
if len(args) == 1:
|
1425
1429
|
# Init from a real Matrix
|
1426
1430
|
if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
|
@@ -1452,7 +1456,7 @@ class MatrixType(CompoundType):
|
|
1452
1456
|
|
1453
1457
|
return self._instantiate(entries)
|
1454
1458
|
|
1455
|
-
def
|
1459
|
+
def from_gstaichi_object(self, func_ret, ret_index=()):
|
1456
1460
|
return self(
|
1457
1461
|
[
|
1458
1462
|
expr.Expr(
|
@@ -1475,7 +1479,7 @@ class MatrixType(CompoundType):
|
|
1475
1479
|
elif self.dtype in primitive_types.real_types:
|
1476
1480
|
get_ret_func = launch_ctx.get_struct_ret_float
|
1477
1481
|
else:
|
1478
|
-
raise
|
1482
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
|
1479
1483
|
return self([get_ret_func(ret_index + (i,)) for i in range(self.m * self.n)])
|
1480
1484
|
|
1481
1485
|
def set_kernel_struct_args(self, mat, launch_ctx, ret_index=()):
|
@@ -1487,7 +1491,7 @@ class MatrixType(CompoundType):
|
|
1487
1491
|
elif self.dtype in primitive_types.real_types:
|
1488
1492
|
set_arg_func = launch_ctx.set_struct_arg_float
|
1489
1493
|
else:
|
1490
|
-
raise
|
1494
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
|
1491
1495
|
if self.ndim == 1:
|
1492
1496
|
for i in range(self.n):
|
1493
1497
|
set_arg_func(ret_index + (i,), mat[i])
|
@@ -1505,7 +1509,7 @@ class MatrixType(CompoundType):
|
|
1505
1509
|
elif self.dtype in primitive_types.real_types:
|
1506
1510
|
set_arg_func = argpack.set_arg_float
|
1507
1511
|
else:
|
1508
|
-
raise
|
1512
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
|
1509
1513
|
if self.ndim == 1:
|
1510
1514
|
for i in range(self.n):
|
1511
1515
|
set_arg_func(ret_index + (i,), mat[i])
|
@@ -1591,7 +1595,7 @@ class VectorType(MatrixType):
|
|
1591
1595
|
|
1592
1596
|
"""
|
1593
1597
|
if len(args) == 0:
|
1594
|
-
raise
|
1598
|
+
raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
|
1595
1599
|
if len(args) == 1:
|
1596
1600
|
# Init from a real Matrix
|
1597
1601
|
if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
|
@@ -1649,7 +1653,7 @@ class VectorType(MatrixType):
|
|
1649
1653
|
|
1650
1654
|
|
1651
1655
|
class MatrixNdarray(Ndarray):
|
1652
|
-
"""
|
1656
|
+
"""GsTaichi ndarray with matrix elements.
|
1653
1657
|
|
1654
1658
|
Args:
|
1655
1659
|
n (int): Number of rows of the matrix.
|
@@ -1744,7 +1748,7 @@ class MatrixNdarray(Ndarray):
|
|
1744
1748
|
|
1745
1749
|
@python_scope
|
1746
1750
|
def _fill_by_kernel(self, val):
|
1747
|
-
from
|
1751
|
+
from gstaichi._kernels import fill_ndarray_matrix # pylint: disable=C0415
|
1748
1752
|
|
1749
1753
|
shape = self.element_type.shape()
|
1750
1754
|
n = shape[0]
|
@@ -1766,7 +1770,7 @@ class MatrixNdarray(Ndarray):
|
|
1766
1770
|
|
1767
1771
|
|
1768
1772
|
class VectorNdarray(Ndarray):
|
1769
|
-
"""
|
1773
|
+
"""GsTaichi ndarray with vector elements.
|
1770
1774
|
|
1771
1775
|
Args:
|
1772
1776
|
n (int): Size of the vector.
|
@@ -1858,7 +1862,7 @@ class VectorNdarray(Ndarray):
|
|
1858
1862
|
|
1859
1863
|
@python_scope
|
1860
1864
|
def _fill_by_kernel(self, val):
|
1861
|
-
from
|
1865
|
+
from gstaichi._kernels import fill_ndarray_matrix # pylint: disable=C0415
|
1862
1866
|
|
1863
1867
|
shape = self.element_type.shape()
|
1864
1868
|
prim_dtype = self.element_type.element_type()
|