gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/matrix.py
ADDED
@@ -0,0 +1,1835 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import functools
|
4
|
+
import numbers
|
5
|
+
from collections.abc import Iterable
|
6
|
+
from itertools import product
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from gstaichi._lib import core as ti_python_core
|
11
|
+
from gstaichi._lib.utils import ti_python_core as _ti_python_core
|
12
|
+
from gstaichi.lang import expr, impl, runtime_ops
|
13
|
+
from gstaichi.lang import ops as ops_mod
|
14
|
+
from gstaichi.lang._ndarray import Ndarray, NdarrayHostAccess
|
15
|
+
from gstaichi.lang.common_ops import GsTaichiOperations
|
16
|
+
from gstaichi.lang.exception import (
|
17
|
+
GsTaichiRuntimeError,
|
18
|
+
GsTaichiRuntimeTypeError,
|
19
|
+
GsTaichiSyntaxError,
|
20
|
+
GsTaichiTypeError,
|
21
|
+
)
|
22
|
+
from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
|
23
|
+
from gstaichi.lang.util import (
|
24
|
+
cook_dtype,
|
25
|
+
get_traceback,
|
26
|
+
gstaichi_scope,
|
27
|
+
in_python_scope,
|
28
|
+
python_scope,
|
29
|
+
to_numpy_type,
|
30
|
+
to_pytorch_type,
|
31
|
+
warning,
|
32
|
+
)
|
33
|
+
from gstaichi.types import primitive_types
|
34
|
+
from gstaichi.types.compound_types import CompoundType
|
35
|
+
from gstaichi.types.enums import Layout
|
36
|
+
from gstaichi.types.utils import is_signed
|
37
|
+
|
38
|
+
_type_factory = _ti_python_core.get_type_factory_instance()
|
39
|
+
|
40
|
+
|
41
|
+
def _generate_swizzle_patterns(key_group: str, required_length=4):
|
42
|
+
"""Generate vector swizzle patterns from a given set of characters.
|
43
|
+
|
44
|
+
Example:
|
45
|
+
|
46
|
+
For `key_group=xyzw` and `required_length=4`, this function will return a
|
47
|
+
list consists of all possible strings (no repeats) in characters
|
48
|
+
`x`, `y`, `z`, `w` and of length<=4:
|
49
|
+
[`x`, `y`, `z`, `w`, `xx`, `xy`, `yx`, ..., `xxxx`, `xxxy`, `xyzw`, ...]
|
50
|
+
The length of the list will be 4 + 4x4 + 4x4x4 + 4x4x4x4 = 340.
|
51
|
+
"""
|
52
|
+
result = []
|
53
|
+
for k in range(1, required_length + 1):
|
54
|
+
result.extend(product(key_group, repeat=k))
|
55
|
+
result = ["".join(pat) for pat in result]
|
56
|
+
return result
|
57
|
+
|
58
|
+
|
59
|
+
def _gen_swizzles(cls):
    """Class decorator that installs GLSL-style swizzle properties on `cls`.

    For each key group (`xyzw`, `rgba`, `stpq`) a property is attached for
    every single character and for every pattern of length 2..4. It also
    fills two lookup tables on the class:

    * `cls._keygroup_to_checker`: key group -> validation callable.
    * `cls._swizzle_to_keygroup`: attribute name -> owning key group.

    A checker raises `GsTaichiSyntaxError` when a pattern uses a character
    outside the first `instance.n` characters of its key group.

    Args:
        cls: the class to decorate.

    Returns:
        The same class object, with the properties attached.
    """
    # https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
    KEYGROUP_SET = ["xyzw", "rgba", "stpq"]
    cls._swizzle_to_keygroup = {}
    cls._keygroup_to_checker = {}

    def make_valid_attribs_checker(key_group):
        def check(instance, pattern):
            allowed = set(key_group[: instance.n])
            unknown = set(pattern) - allowed
            if unknown:
                valid_attribs = tuple(sorted(allowed))
                pattern = tuple(pattern)
                raise GsTaichiSyntaxError(f"vec{instance.n} only has " f"attributes={valid_attribs}, got={pattern}")

        return check

    # Factories give every property its own captured attr/pattern values.
    def make_component_property(attr, attr_idx, key_group):
        checker = cls._keygroup_to_checker[key_group]

        def prop_getter(instance):
            checker(instance, attr)
            return instance[attr_idx]

        @python_scope
        def prop_setter(instance, value):
            checker(instance, attr)
            instance[attr_idx] = value

        return property(prop_getter, prop_setter)

    def make_swizzle_property(pattern, key_group):
        checker = cls._keygroup_to_checker[key_group]

        def prop_getter(instance):
            checker(instance, pattern)
            return Vector([instance[key_group.index(ch)] for ch in pattern])

        @python_scope
        def prop_setter(instance, value):
            if len(pattern) != len(value):
                raise GsTaichiRuntimeError(f"value len does not match the swizzle pattern={pattern}")
            checker(instance, pattern)
            for ch, val in zip(pattern, value):
                instance[key_group.index(ch)] = val

        return property(prop_getter, prop_setter)

    # First pass: checkers and single-character accessors (x, y, ..., q).
    for key_group in KEYGROUP_SET:
        cls._keygroup_to_checker[key_group] = make_valid_attribs_checker(key_group)
        for index, attr in enumerate(key_group):
            setattr(cls, attr, make_component_property(attr, index, key_group))
            cls._swizzle_to_keygroup[attr] = key_group

    # Second pass: multi-character patterns; len=1 accessors were handled above.
    for key_group in KEYGROUP_SET:
        for prop_key in _generate_swizzle_patterns(key_group, required_length=4):
            if len(prop_key) <= 1:
                continue
            setattr(cls, prop_key, make_swizzle_property(prop_key, key_group))
            cls._swizzle_to_keygroup[prop_key] = key_group
    return cls
|
130
|
+
|
131
|
+
|
132
|
+
def _infer_entry_dt(entry):
    """Infer the element data type of a single matrix entry.

    Args:
        entry: a Python/numpy integer or float, or an `expr.Expr`.

    Returns:
        The runtime's `default_ip` for integers, `default_fp` for floats, or
        the rvalue type of the expression.

    Raises:
        GsTaichiTypeError: if the expression type is unknown, or `entry` is
            none of the supported kinds.
    """
    if isinstance(entry, (int, np.integer)):
        inferred = impl.get_runtime().default_ip
    elif isinstance(entry, (float, np.floating)):
        inferred = impl.get_runtime().default_fp
    elif isinstance(entry, expr.Expr):
        inferred = entry.ptr.get_rvalue_type()
        if inferred == ti_python_core.DataType_unknown:
            raise GsTaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
    else:
        raise GsTaichiTypeError("Element type of the matrix is invalid.")
    return inferred
|
143
|
+
|
144
|
+
|
145
|
+
def _infer_array_dt(arr):
    """Infer one element data type for a non-empty flat sequence of entries.

    Each entry's type comes from `_infer_entry_dt`; the types are folded
    pairwise with `ti_python_core.promoted_type`.
    """
    assert len(arr) > 0
    entry_types = (_infer_entry_dt(entry) for entry in arr)
    return functools.reduce(ti_python_core.promoted_type, entry_types)
|
148
|
+
|
149
|
+
|
150
|
+
def make_matrix_with_shape(arr, shape, dt):
    """Build a matrix expression with an explicit shape and data type.

    Args:
        arr: flat iterable of entries; each is wrapped in `expr.Expr`.
        shape: shape passed through to `make_matrix_expr`.
        dt: element data type passed through to `make_matrix_expr`.

    Returns:
        expr.Expr: the expression created by the AST builder of the callable
        currently being compiled.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    element_ptrs = [expr.Expr(elt).ptr for elt in arr]
    debug_info = ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(builder.make_matrix_expr(shape, dt, element_ptrs, debug_info))
|
161
|
+
|
162
|
+
|
163
|
+
def make_matrix(arr, dt=None):
    """Build a matrix/vector expression, inferring the shape from `arr`.

    Args:
        arr: a sequence of entries (vector) or a sequence of rows (matrix);
            the first element decides which. An empty sequence yields a
            shape-[0] vector of `i32`.
        dt: optional element type. When omitted it is inferred from the
            entries, otherwise it is normalised with `cook_dtype`.

    Returns:
        expr.Expr: the expression created by the AST builder of the callable
        currently being compiled.
    """
    if len(arr) == 0:
        # the only usage of an empty vector is to serve as field indices
        shape, dt = [0], primitive_types.i32
    else:
        if isinstance(arr[0], Iterable):  # matrix: flatten row by row
            shape = [len(arr), len(arr[0])]
            arr = [elt for row in arr for elt in row]
        else:  # vector
            shape = [len(arr)]
        dt = _infer_array_dt(arr) if dt is None else cook_dtype(dt)

    builder = impl.get_runtime().compiling_callable.ast_builder()
    element_ptrs = [expr.Expr(elt).ptr for elt in arr]
    debug_info = ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(builder.make_matrix_expr(shape, dt, element_ptrs, debug_info))
|
188
|
+
|
189
|
+
|
190
|
+
def _read_host_access(x):
    """Read the value behind a host-access proxy.

    `SNodeHostAccess` objects are read through `x.accessor.getter(*x.key)`;
    anything else must be an `NdarrayHostAccess` and is read via `x.getter()`.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        return x.getter()
    return x.accessor.getter(*x.key)
|
195
|
+
|
196
|
+
|
197
|
+
def _write_host_access(x, value):
    """Write `value` through a host-access proxy.

    `SNodeHostAccess` objects are written with
    `x.accessor.setter(value, *x.key)`; anything else must be an
    `NdarrayHostAccess` and is written via `x.setter(value)`.
    """
    if isinstance(x, SNodeHostAccess):
        x.accessor.setter(value, *x.key)
        return
    assert isinstance(x, NdarrayHostAccess)
    x.setter(value)
|
203
|
+
|
204
|
+
|
205
|
+
@_gen_swizzles
|
206
|
+
class Matrix(GsTaichiOperations):
|
207
|
+
"""The matrix class.
|
208
|
+
|
209
|
+
A matrix is a 2-D rectangular array with scalar entries, it's row-majored, and is
|
210
|
+
aligned continuously. We recommend only use matrix with no more than 32 elements for
|
211
|
+
efficiency considerations.
|
212
|
+
|
213
|
+
Note: in gstaichi a matrix is strictly two-dimensional and only stores scalars.
|
214
|
+
|
215
|
+
Args:
|
216
|
+
arr (Union[list, tuple, np.ndarray]): the initial values of a matrix.
|
217
|
+
dt (:mod:`~gstaichi.types.primitive_types`): the element data type.
|
218
|
+
ndim (int optional): the number of dimensions of the matrix; forced reshape if given.
|
219
|
+
|
220
|
+
Example::
|
221
|
+
|
222
|
+
use a 2d list to initialize a matrix
|
223
|
+
|
224
|
+
>>> @ti.kernel
|
225
|
+
>>> def test():
|
226
|
+
>>> n = 5
|
227
|
+
>>> M = ti.Matrix([[0] * n for _ in range(n)], ti.i32)
|
228
|
+
>>> print(M) # a 5x5 matrix with integer elements
|
229
|
+
|
230
|
+
get the number of rows and columns via the `n`, `m` property:
|
231
|
+
|
232
|
+
>>> M = ti.Matrix([[0, 1], [2, 3], [4, 5]], ti.i32)
|
233
|
+
>>> M.n # number of rows
|
234
|
+
3
|
235
|
+
>>> M.m # number of cols
|
236
|
+
>>> 2
|
237
|
+
|
238
|
+
you can even initialize a matrix with an empty list:
|
239
|
+
|
240
|
+
>>> M = ti.Matrix([[], []], ti.i32)
|
241
|
+
>>> M.n
|
242
|
+
2
|
243
|
+
>>> M.m
|
244
|
+
0
|
245
|
+
"""
|
246
|
+
|
247
|
+
_is_gstaichi_class = True
|
248
|
+
_is_matrix_class = True
|
249
|
+
__array_priority__ = 1000
|
250
|
+
|
251
|
+
def __init__(self, arr, dt=None):
|
252
|
+
if not isinstance(arr, (list, tuple, np.ndarray)):
|
253
|
+
raise GsTaichiTypeError("An Matrix/Vector can only be initialized with an array-like object")
|
254
|
+
if len(arr) == 0:
|
255
|
+
self.ndim = 0
|
256
|
+
self.n, self.m = 0, 0
|
257
|
+
self.entries = np.array([])
|
258
|
+
self.is_host_access = False
|
259
|
+
elif isinstance(arr[0], Matrix):
|
260
|
+
raise Exception("cols/rows required when using list of vectors")
|
261
|
+
elif isinstance(arr[0], Iterable): # matrix
|
262
|
+
self.ndim = 2
|
263
|
+
self.n, self.m = len(arr), len(arr[0])
|
264
|
+
if isinstance(arr[0][0], (SNodeHostAccess, NdarrayHostAccess)):
|
265
|
+
self.entries = arr
|
266
|
+
self.is_host_access = True
|
267
|
+
else:
|
268
|
+
self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
|
269
|
+
self.is_host_access = False
|
270
|
+
else: # vector
|
271
|
+
self.ndim = 1
|
272
|
+
self.n, self.m = len(arr), 1
|
273
|
+
if isinstance(arr[0], (SNodeHostAccess, NdarrayHostAccess)):
|
274
|
+
self.entries = arr
|
275
|
+
self.is_host_access = True
|
276
|
+
else:
|
277
|
+
self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
|
278
|
+
self.is_host_access = False
|
279
|
+
|
280
|
+
if self.n * self.m > 32:
|
281
|
+
warning(
|
282
|
+
f"GsTaichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested."
|
283
|
+
" Matrices/vectors will be automatically unrolled at compile-time for performance."
|
284
|
+
" So the compilation time could be extremely long if the matrix size is too big."
|
285
|
+
" You may use a field to store a large matrix like this, e.g.:\n"
|
286
|
+
f" x = ti.field(ti.f32, ({self.n}, {self.m})).\n"
|
287
|
+
" See https://docs.taichi-lang.org/docs/field#matrix-size"
|
288
|
+
" for more details.",
|
289
|
+
UserWarning,
|
290
|
+
stacklevel=2,
|
291
|
+
)
|
292
|
+
|
293
|
+
def get_shape(self):
|
294
|
+
if self.ndim == 1:
|
295
|
+
return (self.n,)
|
296
|
+
if self.ndim == 2:
|
297
|
+
return (self.n, self.m)
|
298
|
+
return None
|
299
|
+
|
300
|
+
def __matmul__(self, other):
    """Matrix-matrix or matrix-vector multiply (the `@` operator).

    Delegates to `matrix_ops.matmul(self, other)`.

    Args:
        other (Union[Matrix, Vector]): a matrix or a vector.

    Returns:
        The matrix-matrix product or matrix-vector product.

    """
    # Function-scope import; presumably avoids a circular import with
    # matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    return matrix_ops.matmul(self, other)
|
313
|
+
|
314
|
+
# host access & python scope operation
|
315
|
+
def __len__(self):
    """Return `self.n`.

    `__init__` sets `n` to `len(arr)`, i.e. the number of rows of a matrix
    or the number of entries of a vector.
    """
    # TODO: When this is a vector, should return its dimension?
    return self.n
|
319
|
+
|
320
|
+
def __iter__(self):
|
321
|
+
if self.ndim == 1:
|
322
|
+
return (self[i] for i in range(self.n))
|
323
|
+
return ([self[i, j] for j in range(self.m)] for i in range(self.n))
|
324
|
+
|
325
|
+
def __getitem__(self, indices):
|
326
|
+
"""Access to the element at the given indices in a matrix.
|
327
|
+
|
328
|
+
Args:
|
329
|
+
indices (Sequence[Expr]): the indices of the element.
|
330
|
+
|
331
|
+
Returns:
|
332
|
+
The value of the element at a specific position of a matrix.
|
333
|
+
|
334
|
+
"""
|
335
|
+
entry = self._get_entry(indices)
|
336
|
+
if self.is_host_access:
|
337
|
+
return _read_host_access(entry)
|
338
|
+
return entry
|
339
|
+
|
340
|
+
@python_scope
def __setitem__(self, indices, item):
    """Store `item` at `indices`.

    Args:
        indices (Sequence[Expr]): the indices of the element; a single
            index may be passed without wrapping it in a sequence.
        item: the value to store.

    """
    if self.is_host_access:
        # Entries are proxies: write through the proxy instead of replacing it.
        _write_host_access(self._get_entry(indices), item)
        return

    idx = indices if isinstance(indices, (list, tuple)) else [indices]
    assert len(idx) in [1, 2]
    assert len(idx) == self.ndim, f"Expected {self.ndim} indices, got {len(idx)}"
    if self.ndim == 1:
        self.entries[idx[0]] = item
    else:
        self.entries[idx[0]][idx[1]] = item
|
360
|
+
|
361
|
+
def _get_entry(self, indices):
|
362
|
+
if not isinstance(indices, (list, tuple)):
|
363
|
+
indices = [indices]
|
364
|
+
assert len(indices) in [1, 2]
|
365
|
+
assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
|
366
|
+
if self.ndim == 1:
|
367
|
+
return self.entries[indices[0]]
|
368
|
+
return self.entries[indices[0]][indices[1]]
|
369
|
+
|
370
|
+
def _get_slice(self, a, b):
    """Extract a sub-matrix or sub-vector from a 2-D matrix.

    Args:
        a: row selector -- an int, a `range`, or a `slice` (resolved against
            `self.n`).
        b: column selector -- an int, a `range`, or a `slice` (resolved
            against `self.m`).

    Returns:
        A `Matrix` when both selectors are ranges, otherwise a `Vector`
        taken along whichever selector is a range. Entries come from
        `_get_entry`, so host-access proxies are passed on unread.
    """
    # slice.indices() applies standard Python slice semantics (None bounds,
    # negative indices, clamping). The previous `stop or n` form turned an
    # explicit stop of 0 into "until the end".
    if isinstance(a, slice):
        a = range(*a.indices(self.n))
    if isinstance(b, slice):
        b = range(*b.indices(self.m))
    # _get_entry takes ONE `indices` argument, so (row, col) goes in as a tuple.
    if isinstance(a, range) and isinstance(b, range):
        return Matrix([[self._get_entry((i, j)) for j in b] for i in a])
    if isinstance(a, range):  # b is not range
        return Vector([self._get_entry((i, b)) for i in a])
    # a is not range while b is range
    return Vector([self._get_entry((a, j)) for j in b])
|
381
|
+
|
382
|
+
@python_scope
def _set_entries(self, value):
    """Overwrite every entry with the matching element of `value`.

    Args:
        value: a `Matrix` (converted with `to_list()` first) or an indexable
            object laid out like this object: `value[i]` for a 1-D object,
            `value[i][j]` otherwise.
    """
    if isinstance(value, Matrix):
        value = value.to_list()
    through_proxy = self.is_host_access

    if self.ndim == 1:
        for i in range(self.n):
            if through_proxy:
                _write_host_access(self.entries[i], value[i])
            else:
                self.entries[i] = value[i]
        return

    for i in range(self.n):
        for j in range(self.m):
            if through_proxy:
                _write_host_access(self.entries[i][j], value[i][j])
            else:
                self.entries[i][j] = value[i][j]
|
402
|
+
|
403
|
+
@property
def _members(self):
    """Read-only alias for `self.entries` (the object itself, not a copy)."""
    return self.entries
|
406
|
+
|
407
|
+
def to_list(self):
|
408
|
+
"""Return this matrix as a 1D `list`.
|
409
|
+
|
410
|
+
This is similar to `numpy.ndarray`'s `flatten` and `ravel` methods,
|
411
|
+
the difference is that this function always returns a new list.
|
412
|
+
"""
|
413
|
+
if self.is_host_access:
|
414
|
+
if self.ndim == 1:
|
415
|
+
return [_read_host_access(self.entries[i]) for i in range(self.n)]
|
416
|
+
assert self.ndim == 2
|
417
|
+
return [[_read_host_access(self.entries[i][j]) for j in range(self.m)] for i in range(self.n)]
|
418
|
+
return self.entries.tolist()
|
419
|
+
|
420
|
+
@gstaichi_scope
def cast(self, dtype):
    """Cast the matrix elements to a specified data type.

    Args:
        dtype (:mod:`~gstaichi.types.primitive_types`): data type of the
            returned matrix.

    Returns:
        :class:`gstaichi.Matrix`: A new matrix (a `Vector` for 1-D input)
        whose elements are `ops_mod.cast(element, dtype)`.

    Example::

        >>> A = ti.Matrix([0, 1, 2], ti.i32)
        >>> B = A.cast(ti.f32)
        >>> B
        [0.0, 1.0, 2.0]
    """
    if self.ndim != 1:
        rows = [[ops_mod.cast(self[i, j], dtype) for j in range(self.m)] for i in range(self.n)]
        return Matrix(rows)
    return Vector([ops_mod.cast(self[i], dtype) for i in range(self.n)])
|
441
|
+
|
442
|
+
def trace(self):
    """Return the sum of the diagonal elements of this matrix.

    The matrix must be square. Delegates to `matrix_ops.trace`.

    Returns:
        The sum of the diagonal elements.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.trace()
        5
    """
    # pylint: disable-msg=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.trace(self)
|
460
|
+
|
461
|
+
def inverse(self):
    """Return the inverse of this matrix.

    Delegates to `matrix_ops.inverse`.

    Note:
        The matrix dimension should be less than or equal to 4.

    Returns:
        :class:`~gstaichi.Matrix`: The inverse of this matrix.

    Raises:
        Exception: Inversions of matrices with sizes >= 5 are not supported.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    return matrix_ops.inverse(self)
|
476
|
+
|
477
|
+
def normalized(self, eps=0):
    """Normalize a vector, i.e. a matrix whose second dimension equals one.

    The normalization of a vector `v` is a vector of length 1 that has the
    same direction as `v`; it equals `v/|v|`. Delegates to
    `matrix_ops.normalized`.

    Args:
        eps (float): a safe-guard value for sqrt, usually 0.

    Example::

        >>> a = ti.Vector([3, 4], ti.f32)
        >>> a.normalized()
        [0.6, 0.8]
    """
    # pylint: disable-msg=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.normalized(self, eps)
|
497
|
+
|
498
|
+
def transpose(self):
    """Return the transpose of this matrix.

    Delegates to `matrix_ops.transpose`.

    Returns:
        :class:`~gstaichi.Matrix`: The transpose of this matrix.

    Example::

        >>> A = ti.Matrix([[0, 1], [2, 3]])
        >>> A.transpose()
        [[0, 2], [1, 3]]
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.transpose(self)
|
514
|
+
|
515
|
+
@gstaichi_scope
def determinant(a):
    """Return the determinant of this matrix.

    The instance is received as `a` (this method's name for `self`) and is
    handed to `matrix_ops.determinant`.

    Note:
        The matrix dimension should be less than or equal to 4.

    Returns:
        dtype: The determinant of this matrix.

    Raises:
        Exception: Determinants of matrices with sizes >= 5 are not supported.
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.determinant(a)
|
532
|
+
|
533
|
+
@staticmethod
def diag(dim, val):
    """Return a square diagonal matrix whose diagonal is filled with `val`.

    Delegates to `matrix_ops.diag`.

    Args:
        dim (int): the dimension of the wanted square matrix.
        val (TypeVar): value for the diagonal elements.

    Returns:
        :class:`~gstaichi.Matrix`: The wanted diagonal matrix.

    Example::

        >>> m = ti.Matrix.diag(3, 1)
        >>> m
        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.diag(dim, val)
|
556
|
+
|
557
|
+
def sum(self):
    """Return the sum of all elements.

    Delegates to `matrix_ops.sum`.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.sum()
        10
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.sum(self)
|
570
|
+
|
571
|
+
def norm(self, eps=0):
    """Return the square root of the sum of the absolute squares of the
    elements.

    Delegates to `matrix_ops.norm`.

    Args:
        eps (Number): a safe-guard value for sqrt, usually 0.

    Example::

        >>> a = ti.Vector([3, 4])
        >>> a.norm()
        5

    Returns:
        The square root of the sum of the absolute squares of its elements.
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.norm(self, eps=eps)
|
591
|
+
|
592
|
+
def norm_inv(self, eps=0):
    """The inverse of the matrix :func:`~gstaichi.lang.matrix.Matrix.norm`.

    Delegates to `matrix_ops.norm_inv`.

    Args:
        eps (float): a safe-guard value for sqrt, usually 0.

    Returns:
        The inverse of the matrix/vector `norm`.
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.norm_inv(self, eps=eps)
|
605
|
+
|
606
|
+
def norm_sqr(self):
    """Return the sum of the absolute squares of the elements.

    Delegates to `matrix_ops.norm_sqr`.
    """
    # pylint: disable=C0415
    from gstaichi.lang import matrix_ops

    return matrix_ops.norm_sqr(self)
|
612
|
+
|
613
|
+
def max(self):
    """Return the maximum element value (delegates to ``matrix_ops.max``)."""
    # Function-level import (hence the C0415 suppression); presumably this
    # sidesteps a circular import with matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    largest = _ops.max(self)
    return largest
|
619
|
+
|
620
|
+
def min(self):
    """Return the minimum element value (delegates to ``matrix_ops.min``)."""
    # Function-level import (hence the C0415 suppression); presumably this
    # sidesteps a circular import with matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    smallest = _ops.min(self)
    return smallest
|
626
|
+
|
627
|
+
def any(self):
    """Test whether at least one element is not equal to zero.

    The work is delegated to ``matrix_ops.any``.

    Returns:
        bool: `True` if any element is not equal to zero, `False` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.any()
        True
    """
    # Function-level import (hence the C0415 suppression); presumably this
    # sidesteps a circular import with matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    has_nonzero = _ops.any(self)
    return has_nonzero
|
643
|
+
|
644
|
+
def all(self):
    """Test whether every element is not equal to zero.

    The work is delegated to ``matrix_ops.all``.

    Returns:
        bool: `True` if all elements are not equal to zero, `False` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.all()
        False
    """
    # Function-level import (hence the C0415 suppression); presumably this
    # sidesteps a circular import with matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    all_nonzero = _ops.all(self)
    return all_nonzero
|
660
|
+
|
661
|
+
def fill(self, val):
    """Fill the matrix with a specified value.

    The work is delegated to ``matrix_ops.fill``, whose result is returned.

    Args:
        val (Union[int, float]): Value to fill.

    Example::

        >>> A = ti.Matrix([0, 1, 2, 3])
        >>> A.fill(-1)
        >>> A
        [-1, -1, -1, -1]
    """
    # Function-level import (hence the C0415 suppression); presumably this
    # sidesteps a circular import with matrix_ops -- TODO confirm.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    outcome = _ops.fill(self, val)
    return outcome
|
678
|
+
|
679
|
+
def to_numpy(self):
    """Convert this matrix to a NumPy array.

    Returns:
        numpy.ndarray: Built from ``self.to_list()`` when
        ``self.is_host_access`` is truthy. Otherwise ``self.entries`` is
        returned as-is.

    Example::

        >>> A = ti.Matrix([[0], [1], [2], [3]])
        >>> A.to_numpy()
        array([[0], [1], [2], [3]])
    """
    # Without host access the stored entries are handed back untouched.
    if not self.is_host_access:
        return self.entries
    return np.array(self.to_list())
|
695
|
+
|
696
|
+
@gstaichi_scope
|
697
|
+
def __ti_repr__(self):
    """Yield the printable pieces of this matrix, one token at a time.

    Brackets and ``", "`` separators are yielded as strings and every
    element is yielded as ``self(row, col)``. Rows get their own bracket
    pair unless the matrix has exactly one column.
    """
    bracket_rows = self.m != 1
    yield "["
    for row in range(self.n):
        if row:
            yield ", "
        if bracket_rows:
            yield "["
        for col in range(self.m):
            if col:
                yield ", "
            yield self(row, col)
        if bracket_rows:
            yield "]"
    yield "]"
|
711
|
+
|
712
|
+
def __str__(self):
    """Python scope matrix print support.

    Outside a kernel the string form of ``self.to_numpy()`` is returned.
    Inside a kernel a short ``<NxM ti.Matrix>`` placeholder is returned.
    """
    if not impl.inside_kernel():
        return str(self.to_numpy())

    # It seems that when pybind11 got an type mismatch, it will try
    # to invoke `repr` to show the object... e.g.:
    #
    # TypeError: make_const_expr_f32(): incompatible function arguments. The following argument types are supported:
    #     1. (arg0: float) -> gstaichi_python.Expr
    #
    # Invoked with: <GsTaichi 2x1 Matrix>
    #
    # So we have to make it happy with a dummy string...
    return f"<{self.n}x{self.m} ti.Matrix>"
|
728
|
+
|
729
|
+
def __repr__(self):
    """Return the string form of ``self.to_numpy()``."""
    as_array = self.to_numpy()
    return str(as_array)
|
731
|
+
|
732
|
+
@staticmethod
|
733
|
+
@gstaichi_scope
|
734
|
+
def zero(dt, n, m=None):
    """Construct a Matrix filled with zeros.

    When ``m`` is omitted a vector of length ``n`` is built via
    ``matrix_ops._filled_vector``; otherwise an ``n x m`` matrix is built via
    ``matrix_ops._filled_matrix``.

    Args:
        dt (DataType): The desired data type.
        n (int): The first dimension (row) of the matrix.
        m (int, optional): The second dimension (column) of the matrix.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with zeros.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    return _ops._filled_vector(n, dt, 0) if m is None else _ops._filled_matrix(n, m, dt, 0)
|
751
|
+
|
752
|
+
@staticmethod
|
753
|
+
@gstaichi_scope
|
754
|
+
def one(dt, n, m=None):
    """Construct a Matrix filled with ones.

    When ``m`` is omitted a vector of length ``n`` is built via
    ``matrix_ops._filled_vector``; otherwise an ``n x m`` matrix is built via
    ``matrix_ops._filled_matrix``.

    Args:
        dt (DataType): The desired data type.
        n (int): The first dimension (row) of the matrix.
        m (int, optional): The second dimension (column) of the matrix.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with ones.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    return _ops._filled_vector(n, dt, 1) if m is None else _ops._filled_matrix(n, m, dt, 1)
|
771
|
+
|
772
|
+
@staticmethod
|
773
|
+
@gstaichi_scope
|
774
|
+
def unit(n, i, dt=None):
    """Construct an n-D vector whose `i`-th entry is one and whose other
    entries are all zeros.

    Args:
        n (int): The length of the vector.
        i (int): The index of the entry that will be filled with one.
            Must satisfy ``0 <= i < n`` (checked with an assertion).
        dt (:mod:`~gstaichi.types.primitive_types`, optional): The desired
            data type. Falls back to ``int`` when omitted.

    Returns:
        :class:`~gstaichi.Matrix`: The returned vector.

    Example::

        >>> A = ti.Matrix.unit(3, 1)
        >>> A
        [0, 1, 0]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    element_type = int if dt is None else dt
    assert 0 <= i < n
    return _ops._unit_vector(n, i, element_type)
|
798
|
+
|
799
|
+
@staticmethod
|
800
|
+
@gstaichi_scope
|
801
|
+
def identity(dt, n):
    """Construct an identity Matrix with shape (n, n).

    The work is delegated to ``matrix_ops._identity_matrix``.

    Args:
        dt (DataType): The desired data type.
        n (int): The number of rows/columns.

    Returns:
        :class:`~gstaichi.Matrix`: An `n x n` identity matrix.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    eye = _ops._identity_matrix(n, dt)
    return eye
|
814
|
+
|
815
|
+
@classmethod
|
816
|
+
@python_scope
|
817
|
+
def field(
    cls,
    n,
    m,
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
    layout=Layout.AOS,
    ndim=None,
):
    """Construct a data container to hold all elements of the Matrix.

    One field member is created per matrix entry and wrapped in a
    :class:`MatrixField`, which is registered with the runtime. When
    ``shape`` is given, the members are also placed under SNodes.

    Args:
        n (int): The desired number of rows of the Matrix.
        m (int): The desired number of columns of the Matrix.
        dtype (DataType, optional): The desired data type of the Matrix.
            A list/tuple/ndarray gives every entry its own dtype; it must
            have shape ``(n,)`` when ``m == 1`` and ``(n, m)`` otherwise.
        shape (Union[int, tuple of int], optional): The desired shape of the Matrix.
        order (str, optional): order of the shape laid out in memory.
            Axes are named by letters starting at ``"i"``.
        name (string, optional): The custom name of the field.
        offset (Union[int, tuple of int], optional): The coordinate offset
            of all elements in a field.
        needs_grad (bool, optional): Whether the Matrix need grad field (reverse mode autodiff).
        needs_dual (bool, optional): Whether the Matrix need dual field (forward mode autodiff).
        layout (Layout, optional): The field layout, either Array Of
            Structure (AOS) or Structure Of Array (SOA).
        ndim (int, optional): Element dimensionality passed on to
            :class:`MatrixField`; 2 is used when omitted.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Raises:
        GsTaichiSyntaxError: If ``offset`` or ``order`` is given without
            ``shape``, or their dimensionality does not match ``shape``, or
            ``order`` repeats an axis or names an invalid one.
    """
    element_dim = 2 if ndim is None else ndim

    # Work out the dtype of every member first, in row-major order.
    if isinstance(dtype, (list, tuple, np.ndarray)):
        # set different dtype for each element in Matrix
        # see #2135
        if m == 1:
            assert (
                len(np.shape(dtype)) == 1 and len(dtype) == n
            ), f"Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}"
            member_dtypes = [dtype[row] for row in range(n)]
        else:
            assert (
                len(np.shape(dtype)) == 2 and len(dtype) == n and len(dtype[0]) == m
            ), f"Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}"
            member_dtypes = [dtype[row][col] for row in range(n) for col in range(m)]
    else:
        member_dtypes = [dtype for _ in range(n * m)]

    members = [
        impl.create_field_member(member_dtype, name=name, needs_grad=needs_grad, needs_dual=needs_dual)
        for member_dtype in member_dtypes
    ]
    # Each created member is unpacked as a (primal, grad, dual) triple.
    entries, entries_grad, entries_dual = zip(*members)

    entries = MatrixField(entries, n, m, element_dim)
    if all(entries_grad):
        entries_grad = MatrixField(entries_grad, n, m, element_dim)
        entries._set_grad(entries_grad)
    if all(entries_dual):
        entries_dual = MatrixField(entries_dual, n, m, element_dim)
        entries._set_dual(entries_dual)

    impl.get_runtime().matrix_fields.append(entries)

    # Without a shape nothing is placed; offset/order make no sense then.
    if shape is None:
        if offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise GsTaichiSyntaxError("shape cannot be None when order is set")
        return entries

    if isinstance(shape, numbers.Number):
        shape = (shape,)
    if isinstance(offset, numbers.Number):
        offset = (offset,)
    dim = len(shape)
    if offset is not None and dim != len(offset):
        raise GsTaichiSyntaxError(
            f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
        )

    if order is None:
        axis_seq = list(range(dim))
        shape_seq = list(shape)
    else:
        if dim != len(order):
            raise GsTaichiSyntaxError(
                f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
            )
        if dim != len(set(order)):
            raise GsTaichiSyntaxError("The axes in order must be different")
        axis_seq = []
        shape_seq = []
        for ch in order:
            # Axis letters are counted from "i": i -> 0, j -> 1, ...
            axis = ord(ch) - ord("i")
            if not 0 <= axis < dim:
                raise GsTaichiSyntaxError(f"Invalid axis {ch}")
            axis_seq.append(axis)
            shape_seq.append(shape[axis])
    same_level = order is None

    def _place(target):
        # Every placement gets its own freshly created SNode.
        impl._create_snode(axis_seq, shape_seq, same_level).place(target, offset=offset)

    # Placement order: primal first, then grad, then dual.
    groups = [entries]
    if needs_grad:
        groups.append(entries_grad)
    if needs_dual:
        groups.append(entries_dual)

    split_members = layout == Layout.SOA
    for group in groups:
        if split_members:
            # SOA: each scalar member is placed separately.
            for member in group._get_field_members():
                _place(ScalarField(member))
        else:
            # AOS: the matrix field is placed as one unit.
            _place(group)
    return entries
|
947
|
+
|
948
|
+
@classmethod
|
949
|
+
@python_scope
|
950
|
+
def ndarray(cls, n, m, dtype, shape):
    """Define a GsTaichi ndarray with matrix elements.

    This function must be called in Python scope, and after `ti.init` is called.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray. A bare number
            is wrapped into a one-element tuple.

    Returns:
        MatrixNdarray: The newly constructed ndarray.

    Example::

        The code below shows how a GsTaichi ndarray with matrix elements
        can be declared and defined::

            >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
    """
    dims = (shape,) if isinstance(shape, numbers.Number) else shape
    return MatrixNdarray(n, m, dtype, dims)
|
970
|
+
|
971
|
+
@staticmethod
|
972
|
+
def rows(rows):
    """Construct a matrix by concatenating a list of vectors/lists row by
    row. Must be called in GsTaichi scope.

    The work is delegated to ``matrix_ops.rows``.

    Args:
        rows (List): A list of Vector (1-D Matrix) or a list of list.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.rows([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 2, 3], [4, 5, 6]]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    stacked = _ops.rows(rows)
    return stacked
|
997
|
+
|
998
|
+
@staticmethod
|
999
|
+
def cols(cols):
    """Construct a Matrix instance by concatenating Vectors/lists column by column.

    The work is delegated to ``matrix_ops.cols``.

    Args:
        cols (List): A list of Vector (1-D Matrix) or a list of list.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.cols([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 4], [2, 5], [3, 6]]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    stacked = _ops.cols(cols)
    return stacked
|
1023
|
+
|
1024
|
+
def __hash__(self):
    """Hash by object identity.

    TODO: refactor KernelTemplateMapper. Until then this identity hash is
    required, otherwise using matrices as template arguments fails with
    `unhashable type: Matrix`.
    """
    identity = id(self)
    return identity
|
1029
|
+
|
1030
|
+
def dot(self, other):
    """Perform the dot product of two vectors.

    To call this method, both multiplicatives must be vectors.
    The work is delegated to ``matrix_ops.dot``.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        DataType: The dot product result (scalar) of the two Vectors.

    Example::

        >>> v1 = ti.Vector([1, 2, 3])
        >>> v2 = ti.Vector([3, 4, 5])
        >>> v1.dot(v2)
        26
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    product = _ops.dot(self, other)
    return product
|
1051
|
+
|
1052
|
+
def cross(self, other):
    """Perform the cross product with the input vector (1-D Matrix).

    Both vectors must have the same dimension <= 3.

    For two 2d vectors (x1, y1) and (x2, y2), the return value is the
    scalar `x1*y2 - x2*y1`.

    For two 3d vectors `v` and `w`, the return value is the 3d vector
    `v x w`.

    The work is delegated to ``matrix_ops.cross``.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The cross product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    product = _ops.cross(self, other)
    return product
|
1072
|
+
|
1073
|
+
def outer_product(self, other):
    """Perform the outer product with the input Vector (1-D Matrix).

    The outer_product of two vectors `v = (x1, x2, ..., xn)`,
    `w = (y1, y2, ..., yn)` is a `n` times `n` square matrix, and its `(i, j)`
    entry is equal to `xi*yj`.

    The work is delegated to ``matrix_ops.outer_product``.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The outer product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    product = _ops.outer_product(self, other)
    return product
|
1089
|
+
|
1090
|
+
|
1091
|
+
class Vector(Matrix):
    def __init__(self, arr, dt=None, **kwargs):
        """Construct a vector from the given array.

        A vector is an instance of a 2-D matrix with the second dimension being equal to 1.

        Args:
            arr (Union[list, tuple, np.ndarray]): The initial values of the Vector.
            dt (:mod:`~gstaichi.types.primitive_types`): data type of the vector.
            **kwargs: Forwarded unchanged to ``Matrix.__init__``.

        Returns:
            :class:`~gstaichi.Matrix`: A vector instance.

        Example::

            >>> u = ti.Vector([1, 2])
            >>> print(u.n, u.m)  # verify a vector is a matrix of shape (n, 1)
            2 1
            >>> v = ti.Vector([3, 4])
            >>> u + v
            [4 6]
        """
        super().__init__(arr, dt=dt, **kwargs)

    def get_shape(self):
        """Return the shape of the vector as the one-element tuple ``(n,)``."""
        length = self.n
        return (length,)

    @classmethod
    def field(cls, n, dtype, *args, **kwargs):
        """ti.Vector.field

        Builds a field of ``n``-vectors by calling ``Matrix.field`` with one
        column and ``ndim`` fixed to 1. An ``ndim`` keyword, if supplied,
        must equal 1 (checked with an assertion).
        """
        requested_ndim = kwargs.get("ndim", 1)
        assert requested_ndim == 1
        forwarded = {**kwargs, "ndim": 1}
        return super().field(n, 1, dtype, *args, **forwarded)

    @classmethod
    @python_scope
    def ndarray(cls, n, dtype, shape):
        """Define a GsTaichi ndarray with vector elements.

        Args:
            n (int): Size of the vector.
            dtype (DataType): Data type of each value.
            shape (Union[int, tuple[int]]): Shape of the ndarray. A bare
                number is wrapped into a one-element tuple.

        Returns:
            VectorNdarray: The newly constructed ndarray.

        Example:
            The code below shows how a GsTaichi ndarray with vector elements can be declared and defined::

                >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
        """
        dims = (shape,) if isinstance(shape, numbers.Number) else shape
        return VectorNdarray(n, dtype, dims)
|
1142
|
+
|
1143
|
+
|
1144
|
+
class MatrixField(Field):
    """GsTaichi matrix field with SNode implementation.

    Args:
        vars (List[Expr]): Field members.
        n (Int): Number of rows.
        m (Int): Number of columns.
        ndim (Int): Number of dimensions; forced reshape if given.
    """

    def __init__(self, _vars, n, m, ndim=2):
        assert len(_vars) == n * m
        assert ndim in (0, 1, 2)
        super().__init__(_vars)
        self.n = n
        self.m = m
        self.ndim = ndim
        # The element shape handed to the core is [], [n] or [n, m],
        # depending on ndim.
        member_ptrs = [member.ptr for member in self.vars]
        element_shape = [n, m][:ndim]
        self.ptr = ti_python_core.expr_matrix_field(member_ptrs, element_shape)

    def get_scalar_field(self, *indices):
        """Create a ScalarField from one specific field member.

        Args:
            indices (Tuple[Int]): Specified indices of the field member.
                One index selects column 0 of that row; two indices select
                ``(row, column)``.

        Returns:
            ScalarField: The result ScalarField.
        """
        arity = len(indices)
        assert arity in [1, 2]
        row = indices[0]
        col = indices[1] if arity != 1 else 0
        # Members are stored row-major.
        return ScalarField(self.vars[row * self.m + col])

    def _get_dynamic_index_stride(self):
        """Return the dynamic index stride, or None if not dynamically indexable."""
        if not self.ptr.get_dynamic_indexable():
            return None
        return self.ptr.get_dynamic_index_stride()

    def _calc_dynamic_index_stride(self):
        """Set the dynamic index stride on ``self.ptr`` when the members'
        SNode paths allow one; otherwise return without setting it."""
        # Algorithm: https://github.com/taichi-dev/gstaichi/issues/3810
        member_paths = [ScalarField(member).snode._path_from_root() for member in self.vars]
        member_count = len(member_paths)
        if member_count == 1:
            self.ptr.set_dynamic_index_stride(0)
            return
        reference = member_paths[0]
        depth = len(reference)
        # All paths must have equal length and must not end in a quant type.
        for path in member_paths:
            if len(path) != depth or ti_python_core.is_quant(path[depth - 1]._dtype):
                return
        # First level at which the members' paths diverge.
        # NOTE(review): depth_below_lca stays unbound if the paths never
        # differ -- presumably impossible for distinct members; confirm.
        for level in range(depth):
            if any(path[level] != reference[level] for path in member_paths):
                depth_below_lca = level
                break
        # Below the divergence point every path must be dense and laid out
        # identically to the first member's path.
        for level in range(depth_below_lca, depth - 1):
            for path in member_paths:
                if (
                    path[level].ptr.type != ti_python_core.SNodeType.dense
                    or path[level]._cell_size_bytes != reference[level]._cell_size_bytes
                    or path[level + 1]._offset_bytes_in_parent_cell
                    != reference[level + 1]._offset_bytes_in_parent_cell
                ):
                    return

        def _offset_of(index):
            return member_paths[index][depth_below_lca]._offset_bytes_in_parent_cell

        # The byte distance between consecutive members must be constant.
        stride = _offset_of(1) - _offset_of(0)
        for index in range(2, member_count):
            if stride != _offset_of(index) - _offset_of(index - 1):
                return
        self.ptr.set_dynamic_index_stride(stride)

    def fill(self, val):
        """Fill this matrix field with the specified values.

        Args:
            val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill,
                should have dimension consistent with `self`. A scalar is
                broadcast to every entry.
        """
        val_is_expr = isinstance(val, expr.Expr)
        if isinstance(val, numbers.Number) or (val_is_expr and not val.is_tensor()):
            # Broadcast a scalar to the element shape.
            if self.ndim == 2:
                one_row = (val,) * self.m
                val = (one_row,) * self.n
            else:
                assert self.ndim == 1
                val = (val,) * self.n
        elif val_is_expr and val.is_tensor():
            assert val.n == self.n
            if self.ndim != 1:
                assert val.m == self.m
        else:
            if isinstance(val, Matrix):
                val = val.to_list()
            assert isinstance(val, (list, tuple))
            # Normalise to tuples (nested lists become nested tuples).
            val = tuple(tuple(item) if isinstance(item, list) else item for item in val)
            assert len(val) == self.n
            if self.ndim != 1:
                assert len(val[0]) == self.m

        if in_python_scope():
            from gstaichi._kernels import (  # pylint: disable=C0415
                field_fill_python_scope,  # pylint: disable=C0415
            )

            field_fill_python_scope(self, val)
        else:
            from gstaichi._funcs import (  # pylint: disable=C0415
                field_fill_gstaichi_scope,  # pylint: disable=C0415
            )

            field_fill_gstaichi_scope(self, val)

    @python_scope
    def to_numpy(self, keep_dims=False, dtype=None):
        """Convert the field instance to a NumPy array.

        Args:
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                When keep_dims=True, on an n-D matrix field, the numpy array always has n+2 dims, even for 1x1, 1xn, nx1 matrix fields.
                When keep_dims=False, the resulting numpy array should skip the matrix dims with size 1.
                For example, a 4x1 or 1x4 matrix field with 5x6x7 elements results in an array of shape 5x6x7x4.
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result NumPy array.
        """
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        as_vector = self.m == 1 and not keep_dims
        element_shape = (self.n,) if as_vector else (self.n, self.m)
        result = np.zeros(self.shape + element_shape, dtype=dtype)

        matrix_to_ext_arr(self, result, as_vector)
        runtime_ops.sync()
        return result

    def to_torch(self, device=None, keep_dims=False):
        """Convert the field instance to a PyTorch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            torch.tensor: The result torch tensor.
        """
        import torch  # pylint: disable=C0415

        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        as_vector = self.m == 1 and not keep_dims
        element_shape = (self.n,) if as_vector else (self.n, self.m)
        # pylint: disable=E1101
        result = torch.empty(self.shape + element_shape, dtype=to_pytorch_type(self.dtype), device=device)

        matrix_to_ext_arr(self, result, as_vector)
        runtime_ops.sync()
        return result

    @python_scope
    def _from_external_arr(self, arr):
        """Copy an external array into this field.

        The array must have either one extra trailing dimension (vector
        fields only) or two extra trailing dimensions compared with the
        field's shape.
        """
        extra_dims = len(arr.shape) - len(self.shape)
        as_vector = extra_dims == 1
        if as_vector:
            assert self.m == 1, "This is not a vector field"
        else:
            assert extra_dims == 2
        from gstaichi._kernels import ext_arr_to_matrix  # pylint: disable=C0415

        ext_arr_to_matrix(arr, self, as_vector)
        runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copy a `numpy.ndarray` into this field.

        Arrays that are not C-contiguous are first converted with
        ``np.ascontiguousarray``.

        Example::

            >>> m = ti.Matrix.field(2, 2, ti.f32, shape=(3, 3))
            >>> arr = numpy.ones((3, 3, 2, 2))
            >>> m.from_numpy(arr)
        """
        contiguous = arr if arr.flags.c_contiguous else np.ascontiguousarray(arr)
        self._from_external_arr(contiguous)

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessors()
        target = self[key]
        target._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessors()
        padded_key = self._pad_key(key)
        accessors = self._host_access(padded_key)
        if self.ndim == 1:
            return Vector([accessors[row] for row in range(self.n)])
        # Accessors are stored row-major.
        return Matrix([[accessors[row * self.m + col] for col in range(self.m)] for row in range(self.n)])

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return f"<{self.n}x{self.m} ti.Matrix.field>"
|
1350
|
+
|
1351
|
+
|
1352
|
+
class MatrixType(CompoundType):
|
1353
|
+
def __init__(self, n, m, ndim, dtype):
|
1354
|
+
self.n = n
|
1355
|
+
self.m = m
|
1356
|
+
self.ndim = ndim
|
1357
|
+
# FIXME(haidong): dtypes should not be left empty for ndarray.
|
1358
|
+
# Remove the None dtype when we are ready to break legacy code.
|
1359
|
+
if dtype is not None:
|
1360
|
+
self.dtype = cook_dtype(dtype)
|
1361
|
+
shape = (n, m) if ndim == 2 else (n,)
|
1362
|
+
self.tensor_type = _type_factory.get_tensor_type(shape, self.dtype)
|
1363
|
+
else:
|
1364
|
+
self.dtype = None
|
1365
|
+
self.tensor_type = None
|
1366
|
+
|
1367
|
+
def __call__(self, *args):
|
1368
|
+
"""Return a matrix matching the shape and dtype.
|
1369
|
+
|
1370
|
+
This function will try to convert the input to a `n x m` matrix, with n, m being
|
1371
|
+
the number of rows/cols of this matrix type.
|
1372
|
+
|
1373
|
+
Example::
|
1374
|
+
|
1375
|
+
>>> mat4x3 = MatrixType(4, 3, float)
|
1376
|
+
>>> mat2x6 = MatrixType(2, 6, float)
|
1377
|
+
|
1378
|
+
Create from n x m scalars, of a 1d list of n x m scalars:
|
1379
|
+
|
1380
|
+
>>> m = mat4x3([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
|
1381
|
+
>>> m = mat4x3(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
|
1382
|
+
|
1383
|
+
Create from n vectors/lists, with each one of dimension m:
|
1384
|
+
|
1385
|
+
>>> m = mat4x3([1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12])
|
1386
|
+
|
1387
|
+
Create from a single scalar
|
1388
|
+
|
1389
|
+
>>> m = mat4x3(1)
|
1390
|
+
|
1391
|
+
Create from another 2d list/matrix, as long as they have the same number of entries
|
1392
|
+
|
1393
|
+
>>> m = mat4x3([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
|
1394
|
+
>>> m = mat4x3(m)
|
1395
|
+
>>> k = mat2x6(m)
|
1396
|
+
|
1397
|
+
"""
|
1398
|
+
if len(args) == 0:
|
1399
|
+
raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
|
1400
|
+
if len(args) == 1:
|
1401
|
+
# Init from a real Matrix
|
1402
|
+
if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
|
1403
|
+
arg = args[0]
|
1404
|
+
shape = arg.ptr.get_rvalue_type().shape()
|
1405
|
+
assert self.ndim == len(shape)
|
1406
|
+
assert self.n == shape[0]
|
1407
|
+
if self.ndim > 1:
|
1408
|
+
assert self.m == shape[1]
|
1409
|
+
return expr.Expr(arg.ptr)
|
1410
|
+
|
1411
|
+
# initialize by a single scalar, e.g. matnxm(1)
|
1412
|
+
if isinstance(args[0], (numbers.Number, expr.Expr)):
|
1413
|
+
entries = [args[0] for _ in range(self.m) for _ in range(self.n)]
|
1414
|
+
return self._instantiate(entries)
|
1415
|
+
args = args[0]
|
1416
|
+
# collect all input entries to a 1d list and then reshape
|
1417
|
+
# this is mostly for glsl style like vec4(v.xyz, 1.)
|
1418
|
+
entries = []
|
1419
|
+
for x in args:
|
1420
|
+
if isinstance(x, (list, tuple)):
|
1421
|
+
entries += x
|
1422
|
+
elif isinstance(x, np.ndarray):
|
1423
|
+
entries += list(x.ravel())
|
1424
|
+
elif isinstance(x, Matrix):
|
1425
|
+
entries += x.to_list()
|
1426
|
+
else:
|
1427
|
+
entries.append(x)
|
1428
|
+
|
1429
|
+
return self._instantiate(entries)
|
1430
|
+
|
1431
|
+
def from_gstaichi_object(self, func_ret, ret_index=()):
|
1432
|
+
return self(
|
1433
|
+
[
|
1434
|
+
expr.Expr(
|
1435
|
+
ti_python_core.make_get_element_expr(
|
1436
|
+
func_ret.ptr,
|
1437
|
+
ret_index + (i,),
|
1438
|
+
_ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info()),
|
1439
|
+
)
|
1440
|
+
)
|
1441
|
+
for i in range(self.m * self.n)
|
1442
|
+
]
|
1443
|
+
)
|
1444
|
+
|
1445
|
+
def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
|
1446
|
+
if self.dtype in primitive_types.integer_types:
|
1447
|
+
if is_signed(cook_dtype(self.dtype)):
|
1448
|
+
get_ret_func = launch_ctx.get_struct_ret_int
|
1449
|
+
else:
|
1450
|
+
get_ret_func = launch_ctx.get_struct_ret_uint
|
1451
|
+
elif self.dtype in primitive_types.real_types:
|
1452
|
+
get_ret_func = launch_ctx.get_struct_ret_float
|
1453
|
+
else:
|
1454
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
|
1455
|
+
return self([get_ret_func(ret_index + (i,)) for i in range(self.m * self.n)])
|
1456
|
+
|
1457
|
+
def set_kernel_struct_args(self, mat, launch_ctx, ret_index=()):
|
1458
|
+
if self.dtype in primitive_types.integer_types:
|
1459
|
+
if is_signed(cook_dtype(self.dtype)):
|
1460
|
+
set_arg_func = launch_ctx.set_struct_arg_int
|
1461
|
+
else:
|
1462
|
+
set_arg_func = launch_ctx.set_struct_arg_uint
|
1463
|
+
elif self.dtype in primitive_types.real_types:
|
1464
|
+
set_arg_func = launch_ctx.set_struct_arg_float
|
1465
|
+
else:
|
1466
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
|
1467
|
+
if self.ndim == 1:
|
1468
|
+
for i in range(self.n):
|
1469
|
+
set_arg_func(ret_index + (i,), mat[i])
|
1470
|
+
else:
|
1471
|
+
for i in range(self.n):
|
1472
|
+
for j in range(self.m):
|
1473
|
+
set_arg_func(ret_index + (i * self.m + j,), mat[i, j])
|
1474
|
+
|
1475
|
+
def _instantiate_in_python_scope(self, entries):
    """Build a ``Matrix`` from a flat list of ``self.n * self.m`` entries.

    The flat list is first regrouped into ``self.n`` rows of ``self.m``
    entries (entry ``(r, c)`` comes from ``entries[r * self.m + c]``). Each
    entry is then converted with ``int`` when ``self.dtype`` is in
    ``primitive_types.integer_types`` and with ``float`` otherwise.

    Args:
        entries: Flat, indexable sequence of scalar values.

    Returns:
        Matrix: Built with ``dt=self.dtype``.
    """
    # Regroup first so that indexing errors surface before any conversion.
    grid = [[entries[r * self.m + c] for c in range(self.m)] for r in range(self.n)]

    converted = []
    for r in range(self.n):
        row = []
        for c in range(self.m):
            if self.dtype in primitive_types.integer_types:
                row.append(int(grid[r][c]))
            else:
                row.append(float(grid[r][c]))
        converted.append(row)
    return Matrix(converted, dt=self.dtype)
|
1487
|
+
|
1488
|
+
def _instantiate(self, entries):
    """Create a matrix value from a flat list of entries.

    In Python scope this delegates to ``_instantiate_in_python_scope``;
    otherwise it calls ``make_matrix_with_shape`` with shape
    ``[self.n, self.m]`` and ``self.dtype``.
    """
    if not in_python_scope():
        return make_matrix_with_shape(entries, [self.n, self.m], self.dtype)
    return self._instantiate_in_python_scope(entries)
|
1493
|
+
|
1494
|
+
def field(self, **kwargs):
    """Create a field of this matrix type via ``Matrix.field``.

    Args:
        **kwargs: Forwarded to ``Matrix.field``. If ``ndim`` is supplied it
            must equal ``self.ndim`` (checked with ``assert``); it is always
            set to ``self.ndim`` before forwarding.

    Returns:
        The result of ``Matrix.field(self.n, self.m, dtype=self.dtype, ...)``.
    """
    requested_ndim = kwargs.get("ndim", self.ndim)
    assert requested_ndim == self.ndim
    kwargs["ndim"] = self.ndim
    return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)
|
1498
|
+
|
1499
|
+
def ndarray(self, **kwargs):
    """Create an ndarray of this matrix type via ``Matrix.ndarray``.

    Args:
        **kwargs: Forwarded to ``Matrix.ndarray``. If ``ndim`` is supplied it
            must equal ``self.ndim`` (checked with ``assert``); it is always
            set to ``self.ndim`` before forwarding.

    Returns:
        The result of ``Matrix.ndarray(self.n, self.m, dtype=self.dtype, ...)``.
    """
    requested_ndim = kwargs.get("ndim", self.ndim)
    assert requested_ndim == self.ndim
    kwargs["ndim"] = self.ndim
    return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)
|
1503
|
+
|
1504
|
+
def get_shape(self):
    """Return ``(n,)`` when ``ndim`` is 1, otherwise ``(n, m)``."""
    return (self.n,) if self.ndim == 1 else (self.n, self.m)
|
1508
|
+
|
1509
|
+
def to_string(self):
    """Return a ``MatrixType[n,m, dtype]`` description of this type.

    The dtype part is ``self.dtype.to_string()``, or empty when
    ``self.dtype`` is ``None``.
    """
    if self.dtype is None:
        dtype_str = ""
    else:
        dtype_str = self.dtype.to_string()
    return f"MatrixType[{self.n},{self.m}, {dtype_str}]"
|
1512
|
+
|
1513
|
+
def check_matched(self, other):
    """Report whether ``other`` is compatible with this matrix type.

    ``other`` must expose ``shape()`` and ``element_type()``. The match
    fails when the number of dimensions differs, when ``self.dtype`` is set
    and differs from ``other.element_type()``, or when any dimension of
    ``self.get_shape()`` that is not ``None`` differs from the corresponding
    dimension of ``other.shape()``.

    Returns:
        bool: ``True`` if all checks pass, ``False`` otherwise.
    """
    if len(other.shape()) != self.ndim:
        return False
    if self.dtype is not None and self.dtype != other.element_type():
        return False
    expected = self.get_shape()
    # A ``None`` dimension acts as a wildcard and is not compared.
    mismatch = any(
        expected[axis] is not None and expected[axis] != other.shape()[axis]
        for axis in range(self.ndim)
    )
    return not mismatch
|
1523
|
+
|
1524
|
+
|
1525
|
+
class VectorType(MatrixType):
    """Matrix type specialised to vectors with ``n`` components."""

    def __init__(self, n, dtype):
        # Forwarded as MatrixType(n, 1, 1, dtype); the third argument is
        # presumably ndim (cf. MatrixType(n, m, len(shape), dtype) elsewhere
        # in this file).
        super().__init__(n, 1, 1, dtype)

    def __call__(self, *args):
        """Return a vector matching the shape and dtype.

        This function will try to convert the input to a `n`-component vector.

        Example::

            >>> vec3 = VectorType(3, float)

            Create from n scalars:

            >>> v = vec3(1, 2, 3)

            Create from a list/tuple of n scalars:

            >>> v = vec3([1, 2, 3])

            Create from a single scalar:

            >>> v = vec3(1)

        """
        if not args:
            raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")

        if len(args) == 1:
            sole = args[0]

            # A tensor-typed expression is re-wrapped after checking that it
            # is 1-D with exactly ``n`` components.
            if isinstance(sole, expr.Expr) and sole.ptr.is_tensor():
                tensor_shape = sole.ptr.get_rvalue_type().shape()
                assert len(tensor_shape) == 1
                assert self.n == tensor_shape[0]
                return expr.Expr(sole.ptr)

            # A lone scalar fills every component, e.g. vec3(1).
            if isinstance(sole, (numbers.Number, expr.Expr)):
                return self._instantiate([sole] * self.n)

            # Otherwise the single argument is the iterable of components.
            args = sole

        # Flatten all inputs into one list; this mostly serves GLSL-style
        # construction such as vec4(v.xyz, 1.).
        flat = []
        for item in args:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            elif isinstance(item, np.ndarray):
                flat.extend(list(item.ravel()))
            elif isinstance(item, Matrix):
                flat.extend(item.to_list())
            else:
                flat.append(item)

        return self._instantiate(flat)

    def _instantiate_in_python_scope(self, entries):
        """Build a ``Vector`` from the first ``self.n`` items of ``entries``.

        Each item is converted with ``int`` when ``self.dtype`` is in
        ``primitive_types.integer_types`` and with ``float`` otherwise.
        """
        components = []
        for idx in range(self.n):
            if self.dtype in primitive_types.integer_types:
                components.append(int(entries[idx]))
            else:
                components.append(float(entries[idx]))
        return Vector(components, dt=self.dtype)

    def _instantiate(self, entries):
        """Create a vector value from a flat list of entries.

        In Python scope this delegates to ``_instantiate_in_python_scope``;
        otherwise it calls ``make_matrix_with_shape`` with shape ``[self.n]``.
        """
        if not in_python_scope():
            return make_matrix_with_shape(entries, [self.n], self.dtype)
        return self._instantiate_in_python_scope(entries)

    def field(self, **kwargs):
        """Create a field of this vector type; ``kwargs`` go to ``Vector.field``."""
        return Vector.field(self.n, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create an ndarray of this vector type; ``kwargs`` go to ``Vector.ndarray``."""
        return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

    def to_string(self):
        """Return a ``VectorType[n, dtype]`` description of this type.

        The dtype part is empty when ``self.dtype`` is ``None``.
        """
        if self.dtype is None:
            dtype_str = ""
        else:
            dtype_str = self.dtype.to_string()
        return f"VectorType[{self.n}, {dtype_str}]"
|
1607
|
+
|
1608
|
+
|
1609
|
+
class MatrixNdarray(Ndarray):
    """GsTaichi ndarray with matrix elements.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray.

    Example::

        >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
    """

    def __init__(self, n, m, dtype, shape):
        self.n = n
        self.m = m
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        self.shape = tuple(shape)
        self.element_type = _type_factory.get_tensor_type((self.n, self.m), self.dtype)
        # TODO: we should pass in element_type, shape, layout instead.
        program = impl.get_runtime().prog
        self.arr = program.create_ndarray(
            cook_dtype(self.element_type),
            shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Returns the shape of each element (a 2D matrix) in this ndarray.

        Example::

            >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
            >>> arr.element_shape
            (2, 2)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Assign the ``n x m`` entries of ``value`` to the element at ``key``.

        A ``value`` that is not a list/tuple is first converted with
        ``list``. If its first item is not a list/tuple, every item is
        wrapped in a one-element list so it can be indexed as ``value[i][j]``.
        """
        if not isinstance(value, (list, tuple)):
            value = list(value)
        if not isinstance(value[0], (list, tuple)):
            value = [[item] for item in value]
        for row in range(self.n):
            for col in range(self.m):
                self[key][row, col] = value[row][col]

    @python_scope
    def __getitem__(self, key):
        """Return a ``Matrix`` of ``NdarrayHostAccess`` objects for ``key``.

        ``None`` becomes the empty tuple, a number becomes a 1-tuple, and
        anything else is converted with ``tuple``.
        """
        if key is None:
            key = ()
        elif isinstance(key, numbers.Number):
            key = (key,)
        else:
            key = tuple(key)
        return Matrix([[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)] for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this ndarray to a `numpy.ndarray`.

        Example::

            >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
            >>> arr.to_numpy()
            [[[[0. 0.]
               [0. 0.]]]

             [[[0. 0.]
               [0. 0.]]]]
        """
        return self._ndarray_matrix_to_numpy(as_vector=0)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data of a `numpy.ndarray` into this array.

        Example::

            >>> m = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
            >>> arr = np.ones((2, 1, 2, 2))
            >>> m.from_numpy(arr)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=0)

    @python_scope
    def __deepcopy__(self, memo=None):
        """Return a new ``MatrixNdarray`` of the same configuration, filled via ``copy_from``."""
        duplicate = MatrixNdarray(self.n, self.m, self.dtype, self.shape)
        duplicate.copy_from(self)
        return duplicate

    @python_scope
    def _fill_by_kernel(self, val):
        """Fill the ndarray with ``val`` using the ``fill_ndarray_matrix`` kernel.

        A ``val`` that is not already a ``Matrix`` is converted through a
        ``MatrixType`` derived from ``self.element_type``.
        """
        from gstaichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        elem_shape = self.element_type.shape()
        rows = elem_shape[0]
        # A 1-D element shape is treated as having a single column.
        cols = elem_shape[1] if len(elem_shape) > 1 else 1

        prim_dtype = self.element_type.element_type()
        matrix_type = MatrixType(rows, cols, len(elem_shape), prim_dtype)
        fill_value = val if isinstance(val, Matrix) else matrix_type(val)
        fill_ndarray_matrix(self, fill_value)

    @python_scope
    def __repr__(self):
        return f"<{self.n}x{self.m} {Layout.AOS} ti.Matrix.ndarray>"
|
1724
|
+
|
1725
|
+
|
1726
|
+
class VectorNdarray(Ndarray):
    """GsTaichi ndarray with vector elements.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the ndarray.

    Example::

        >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
    """

    def __init__(self, n, dtype, shape):
        self.n = n
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        self.shape = tuple(shape)
        self.element_type = _type_factory.get_tensor_type((n,), self.dtype)
        program = impl.get_runtime().prog
        self.arr = program.create_ndarray(
            cook_dtype(self.element_type),
            shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Gets the dimension of the vector of this ndarray.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
            >>> a.element_shape
            (3,)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Assign the first ``n`` items of ``value`` to the element at ``key``.

        A ``value`` that is not a list/tuple is first converted with ``list``.
        """
        if not isinstance(value, (list, tuple)):
            value = list(value)
        for idx in range(self.n):
            self[key][idx] = value[idx]

    @python_scope
    def __getitem__(self, key):
        """Return a ``Vector`` of ``NdarrayHostAccess`` objects for ``key``.

        ``None`` becomes the empty tuple, a number becomes a 1-tuple, and
        anything else is converted with ``tuple``.
        """
        if key is None:
            key = ()
        elif isinstance(key, numbers.Number):
            key = (key,)
        else:
            key = tuple(key)
        return Vector([NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this vector ndarray to a `numpy.ndarray`.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> a.to_numpy()
            array([[[0., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],
                    [0., 0., 0.]]], dtype=float32)
        """
        return self._ndarray_matrix_to_numpy(as_vector=1)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this ndarray.

        The shape and data type of `arr` must match this ndarray.

        Example::

            >>> import numpy as np
            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> b = np.ones((2, 2, 3), dtype=np.float32)
            >>> a.from_numpy(b)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=1)

    @python_scope
    def __deepcopy__(self, memo=None):
        """Return a new ``VectorNdarray`` of the same configuration, filled via ``copy_from``."""
        duplicate = VectorNdarray(self.n, self.dtype, self.shape)
        duplicate.copy_from(self)
        return duplicate

    @python_scope
    def _fill_by_kernel(self, val):
        """Fill the ndarray with ``val`` using the ``fill_ndarray_matrix`` kernel.

        A ``val`` that is not already a ``Vector`` is converted through a
        ``VectorType`` derived from ``self.element_type``.
        """
        from gstaichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        elem_shape = self.element_type.shape()
        prim_dtype = self.element_type.element_type()
        vector_type = VectorType(elem_shape[0], prim_dtype)
        fill_value = val if isinstance(val, Vector) else vector_type(val)
        fill_ndarray_matrix(self, fill_value)

    @python_scope
    def __repr__(self):
        return f"<{self.n} {Layout.AOS} ti.Vector.ndarray>"
|
1833
|
+
|
1834
|
+
|
1835
|
+
# Names exported when this module is star-imported (``from ... import *``).
__all__ = ["Matrix", "Vector", "MatrixField", "MatrixNdarray", "VectorNdarray"]
|