gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1835 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+ import numbers
5
+ from collections.abc import Iterable
6
+ from itertools import product
7
+
8
+ import numpy as np
9
+
10
+ from gstaichi._lib import core as ti_python_core
11
+ from gstaichi._lib.utils import ti_python_core as _ti_python_core
12
+ from gstaichi.lang import expr, impl, runtime_ops
13
+ from gstaichi.lang import ops as ops_mod
14
+ from gstaichi.lang._ndarray import Ndarray, NdarrayHostAccess
15
+ from gstaichi.lang.common_ops import GsTaichiOperations
16
+ from gstaichi.lang.exception import (
17
+ GsTaichiRuntimeError,
18
+ GsTaichiRuntimeTypeError,
19
+ GsTaichiSyntaxError,
20
+ GsTaichiTypeError,
21
+ )
22
+ from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
23
+ from gstaichi.lang.util import (
24
+ cook_dtype,
25
+ get_traceback,
26
+ gstaichi_scope,
27
+ in_python_scope,
28
+ python_scope,
29
+ to_numpy_type,
30
+ to_pytorch_type,
31
+ warning,
32
+ )
33
+ from gstaichi.types import primitive_types
34
+ from gstaichi.types.compound_types import CompoundType
35
+ from gstaichi.types.enums import Layout
36
+ from gstaichi.types.utils import is_signed
37
+
38
+ _type_factory = _ti_python_core.get_type_factory_instance()
39
+
40
+
41
def _generate_swizzle_patterns(key_group: str, required_length=4):
    """Generate vector swizzle patterns from a given set of characters.

    Characters may repeat within a pattern (``xx`` and ``xxyy`` are valid).
    Patterns are ordered by length first, then in ``itertools.product`` order
    over the characters of ``key_group``.

    Example:
        For ``key_group="xyzw"`` and ``required_length=4`` this returns every
        string of length 1 to 4 over ``x``, ``y``, ``z``, ``w``::

            ["x", "y", "z", "w", "xx", "xy", ..., "xxxx", "xxxy", ..., "wwww"]

        i.e. 4 + 4**2 + 4**3 + 4**4 = 340 patterns.

    Args:
        key_group: The characters to draw from.
        required_length: The maximum pattern length (inclusive).

    Returns:
        list[str]: All patterns of length ``1..required_length``; empty when
        ``required_length`` is less than 1.
    """
    return [
        "".join(chars)
        for length in range(1, required_length + 1)
        for chars in product(key_group, repeat=length)
    ]
57
+
58
+
59
def _gen_swizzles(cls):
    """Class decorator that installs GLSL-style swizzle properties on ``cls``.

    For each key group ``xyzw``, ``rgba`` and ``stpq`` it adds:

    * one property per single character (e.g. ``v.x``), reading/writing the
      component at that character's position in the group;
    * one property per multi-character pattern of length 2-4 (e.g. ``v.xy``),
      whose getter returns a new ``Vector`` of the selected components and
      whose setter assigns them element-wise.

    Every access first checks that the pattern only uses the first
    ``instance.n`` characters of its key group and raises ``GsTaichiSyntaxError``
    otherwise. Setters are wrapped with ``python_scope``. The decorator also
    fills ``cls._swizzle_to_keygroup`` (pattern -> key group) and
    ``cls._keygroup_to_checker`` (key group -> validation function).

    See https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
    """
    key_groups = ("xyzw", "rgba", "stpq")
    cls._swizzle_to_keygroup = {}
    cls._keygroup_to_checker = {}

    def build_checker(group):
        # Only the first ``instance.n`` characters of the group are valid
        # for an n-component vector.
        def check(instance, pattern):
            allowed = set(group[: instance.n])
            if set(pattern) - allowed:
                raise GsTaichiSyntaxError(
                    f"vec{instance.n} only has attributes={tuple(sorted(allowed))}, got={tuple(pattern)}"
                )

        return check

    # Factories bind their arguments per call, avoiding late-binding closures.
    def component_property(char, position, group):
        check = cls._keygroup_to_checker[group]

        def getter(instance):
            check(instance, char)
            return instance[position]

        @python_scope
        def setter(instance, value):
            check(instance, char)
            instance[position] = value

        return property(getter, setter)

    def swizzle_property(pattern, group):
        check = cls._keygroup_to_checker[group]

        def getter(instance):
            check(instance, pattern)
            return Vector([instance[group.index(ch)] for ch in pattern])

        @python_scope
        def setter(instance, value):
            if len(pattern) != len(value):
                raise GsTaichiRuntimeError(f"value len does not match the swizzle pattern={pattern}")
            check(instance, pattern)
            for ch, val in zip(pattern, value):
                instance[group.index(ch)] = val

        return property(getter, setter)

    for group in key_groups:
        cls._keygroup_to_checker[group] = build_checker(group)
        for position, char in enumerate(group):
            setattr(cls, char, component_property(char, position, group))
            cls._swizzle_to_keygroup[char] = group

    for group in key_groups:
        for pattern in _generate_swizzle_patterns(group, required_length=4):
            if len(pattern) < 2:
                continue  # single-character accessors were installed above
            setattr(cls, pattern, swizzle_property(pattern, group))
            cls._swizzle_to_keygroup[pattern] = group
    return cls
130
+
131
+
132
def _infer_entry_dt(entry):
    """Return the data type a single matrix entry should have.

    Python/NumPy integers map to the runtime's default integer type, and
    Python/NumPy floats to its default float type. ``expr.Expr`` entries use
    their rvalue type.

    Raises:
        GsTaichiTypeError: If an expression's type is unknown, or the entry
            is of any other kind.
    """
    if isinstance(entry, (int, np.integer)):
        return impl.get_runtime().default_ip
    elif isinstance(entry, (float, np.floating)):
        return impl.get_runtime().default_fp
    elif not isinstance(entry, expr.Expr):
        raise GsTaichiTypeError("Element type of the matrix is invalid.")
    inferred = entry.ptr.get_rvalue_type()
    if inferred == ti_python_core.DataType_unknown:
        raise GsTaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
    return inferred
143
+
144
+
145
def _infer_array_dt(arr):
    """Fold the per-entry types of a non-empty ``arr`` into one promoted type."""
    assert len(arr) > 0
    entry_types = (_infer_entry_dt(entry) for entry in arr)
    return functools.reduce(ti_python_core.promoted_type, entry_types)
148
+
149
+
150
def make_matrix_with_shape(arr, shape, dt):
    """Build a matrix expression of the given ``shape`` and element type ``dt``.

    ``arr`` is a flat sequence of entries. Each entry is wrapped in
    ``expr.Expr``, and the resulting pointers are passed to the current AST
    builder's ``make_matrix_expr``, together with debug info for the current
    source location.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    element_ptrs = [expr.Expr(elt).ptr for elt in arr]
    debug_info = ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info())
    matrix_expr = builder.make_matrix_expr(shape, dt, element_ptrs, debug_info)
    return expr.Expr(matrix_expr)
161
+
162
+
163
def make_matrix(arr, dt=None):
    """Build a matrix/vector expression from a (possibly nested) sequence.

    Args:
        arr: A flat sequence (vector) or a sequence of row sequences (matrix).
            Matrix rows are flattened row-major. An empty ``arr`` yields a
            shape-``[0]`` ``i32`` vector.
        dt: Optional element type. When ``None``, it is inferred from the
            entries; otherwise it is normalized with ``cook_dtype``.

    Returns:
        expr.Expr: The expression built by ``make_matrix_with_shape``.
    """
    if len(arr) == 0:
        # The only usage of an empty vector is to serve as field indices,
        # so the element type is fixed to i32 regardless of ``dt``.
        return make_matrix_with_shape(arr, [0], primitive_types.i32)
    if isinstance(arr[0], Iterable):  # matrix
        shape = [len(arr), len(arr[0])]
        arr = [elt for row in arr for elt in row]
    else:  # vector
        shape = [len(arr)]
    dt = _infer_array_dt(arr) if dt is None else cook_dtype(dt)
    # Delegate to the shared builder instead of duplicating the AST call.
    return make_matrix_with_shape(arr, shape, dt)
188
+
189
+
190
def _read_host_access(x):
    """Read the current value behind a host-access handle.

    SNode handles are read with ``accessor.getter(*key)``. Otherwise ``x``
    must be an ``NdarrayHostAccess`` and is read with ``getter()``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        return x.getter()
    return x.accessor.getter(*x.key)
195
+
196
+
197
def _write_host_access(x, value):
    """Write ``value`` through a host-access handle.

    SNode handles use ``accessor.setter(value, *key)``. Otherwise ``x`` must
    be an ``NdarrayHostAccess`` and is written with ``setter(value)``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        x.setter(value)
        return
    x.accessor.setter(value, *x.key)
203
+
204
+
205
+ @_gen_swizzles
206
+ class Matrix(GsTaichiOperations):
207
+ """The matrix class.
208
+
209
+ A matrix is a 2-D rectangular array with scalar entries, it's row-majored, and is
210
+ aligned continuously. We recommend only use matrix with no more than 32 elements for
211
+ efficiency considerations.
212
+
213
+ Note: in gstaichi a matrix is strictly two-dimensional and only stores scalars.
214
+
215
+ Args:
216
+ arr (Union[list, tuple, np.ndarray]): the initial values of a matrix.
217
+ dt (:mod:`~gstaichi.types.primitive_types`): the element data type.
218
+ ndim (int optional): the number of dimensions of the matrix; forced reshape if given.
219
+
220
+ Example::
221
+
222
+ use a 2d list to initialize a matrix
223
+
224
+ >>> @ti.kernel
225
+ >>> def test():
226
+ >>> n = 5
227
+ >>> M = ti.Matrix([[0] * n for _ in range(n)], ti.i32)
228
+ >>> print(M) # a 5x5 matrix with integer elements
229
+
230
+ get the number of rows and columns via the `n`, `m` property:
231
+
232
+ >>> M = ti.Matrix([[0, 1], [2, 3], [4, 5]], ti.i32)
233
+ >>> M.n # number of rows
234
+ 3
235
+ >>> M.m # number of cols
236
+ >>> 2
237
+
238
+ you can even initialize a matrix with an empty list:
239
+
240
+ >>> M = ti.Matrix([[], []], ti.i32)
241
+ >>> M.n
242
+ 2
243
+ >>> M.m
244
+ 0
245
+ """
246
+
247
+ _is_gstaichi_class = True
248
+ _is_matrix_class = True
249
+ __array_priority__ = 1000
250
+
251
def __init__(self, arr, dt=None):
    """Create a host-side matrix or vector from an array-like object.

    Args:
        arr: A ``list``, ``tuple`` or ``np.ndarray``.

            * An empty sequence gives ``ndim == 0``.
            * A flat sequence gives a vector (``ndim == 1``, ``m == 1``).
            * A sequence of rows gives a matrix (``ndim == 2``). Rows may be
              empty, e.g. ``Matrix([[], []])`` has ``n == 2`` and ``m == 0``.

            If the first entry is an ``SNodeHostAccess``/``NdarrayHostAccess``,
            ``arr`` is stored as-is and ``is_host_access`` is set. Otherwise the
            entries are copied into a NumPy array.
        dt: Optional element type, converted with ``to_numpy_type`` for the
            NumPy array.

    Raises:
        GsTaichiTypeError: If ``arr`` is not a list, tuple or NumPy array.
        Exception: If ``arr`` is a list of ``Matrix`` objects.
    """
    if not isinstance(arr, (list, tuple, np.ndarray)):
        raise GsTaichiTypeError("An Matrix/Vector can only be initialized with an array-like object")
    host_access_types = (SNodeHostAccess, NdarrayHostAccess)
    if len(arr) == 0:
        self.ndim = 0
        self.n, self.m = 0, 0
        self.entries = np.array([])
        self.is_host_access = False
    elif isinstance(arr[0], Matrix):
        raise Exception("cols/rows required when using list of vectors")
    elif isinstance(arr[0], Iterable):  # matrix
        self.ndim = 2
        self.n, self.m = len(arr), len(arr[0])
        # Rows may be empty (e.g. ``Matrix([[], []])``), in which case there
        # is no first entry to inspect; indexing ``arr[0][0]`` would raise.
        if self.m > 0 and isinstance(arr[0][0], host_access_types):
            self.entries = arr
            self.is_host_access = True
        else:
            self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
            self.is_host_access = False
    else:  # vector
        self.ndim = 1
        self.n, self.m = len(arr), 1
        if isinstance(arr[0], host_access_types):
            self.entries = arr
            self.is_host_access = True
        else:
            self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
            self.is_host_access = False

    if self.n * self.m > 32:
        warning(
            f"GsTaichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested."
            " Matrices/vectors will be automatically unrolled at compile-time for performance."
            " So the compilation time could be extremely long if the matrix size is too big."
            " You may use a field to store a large matrix like this, e.g.:\n"
            f" x = ti.field(ti.f32, ({self.n}, {self.m})).\n"
            " See https://docs.taichi-lang.org/docs/field#matrix-size"
            " for more details.",
            UserWarning,
            stacklevel=2,
        )
292
+
293
def get_shape(self):
    """Return ``(n,)`` for a vector, ``(n, m)`` for a matrix, else ``None``."""
    if self.ndim == 2:
        return (self.n, self.m)
    if self.ndim == 1:
        return (self.n,)
    return None
299
+
300
def __matmul__(self, other):
    """Implement ``self @ other`` (matrix-matrix or matrix-vector product).

    Args:
        other (Union[Matrix, Vector]): The right-hand operand.

    Returns:
        The product, as computed by ``matrix_ops.matmul``.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    product_result = _ops.matmul(self, other)
    return product_result
313
+
314
+ # host access & python scope operation
315
+ def __len__(self):
316
+ """Get the length of each row of a matrix"""
317
+ # TODO: When this is a vector, should return its dimension?
318
+ return self.n
319
+
320
+ def __iter__(self):
321
+ if self.ndim == 1:
322
+ return (self[i] for i in range(self.n))
323
+ return ([self[i, j] for j in range(self.m)] for i in range(self.n))
324
+
325
+ def __getitem__(self, indices):
326
+ """Access to the element at the given indices in a matrix.
327
+
328
+ Args:
329
+ indices (Sequence[Expr]): the indices of the element.
330
+
331
+ Returns:
332
+ The value of the element at a specific position of a matrix.
333
+
334
+ """
335
+ entry = self._get_entry(indices)
336
+ if self.is_host_access:
337
+ return _read_host_access(entry)
338
+ return entry
339
+
340
@python_scope
def __setitem__(self, indices, item):
    """Store ``item`` at ``indices`` (a single index, or an ``(i, j)`` pair).

    Host-access matrices write through the entry's accessor. Otherwise the
    backing array is updated in place.
    """
    if not self.is_host_access:
        idx = indices if isinstance(indices, (list, tuple)) else [indices]
        assert len(idx) in [1, 2]
        assert len(idx) == self.ndim, f"Expected {self.ndim} indices, got {len(idx)}"
        if self.ndim == 1:
            self.entries[idx[0]] = item
        else:
            self.entries[idx[0]][idx[1]] = item
        return
    _write_host_access(self._get_entry(indices), item)
360
+
361
+ def _get_entry(self, indices):
362
+ if not isinstance(indices, (list, tuple)):
363
+ indices = [indices]
364
+ assert len(indices) in [1, 2]
365
+ assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
366
+ if self.ndim == 1:
367
+ return self.entries[indices[0]]
368
+ return self.entries[indices[0]][indices[1]]
369
+
370
def _get_slice(self, a, b):
    """Extract a sub-matrix or sub-vector of a 2-D matrix.

    Args:
        a: Row selector: an ``int``, a ``range``, or a ``slice``.
        b: Column selector: an ``int``, a ``range``, or a ``slice``.

    Slices are resolved against ``self.n`` / ``self.m`` with standard Python
    slice semantics (``slice.indices``). Negative bounds and an explicit
    ``stop=0`` therefore behave as they do for lists.

    Returns:
        A ``Matrix`` when both selectors are ranges, otherwise a ``Vector`` of
        the selected row or column entries.
    """
    if isinstance(a, slice):
        a = range(*a.indices(self.n))
    if isinstance(b, slice):
        b = range(*b.indices(self.m))
    # ``_get_entry`` takes a single ``indices`` argument, so index pairs must
    # be passed as one tuple.
    if isinstance(a, range) and isinstance(b, range):
        return Matrix([[self._get_entry((i, j)) for j in b] for i in a])
    if isinstance(a, range):  # b is not a range
        return Vector([self._get_entry((i, b)) for i in a])
    # a is not a range while b is a range
    return Vector([self._get_entry((a, j)) for j in b])
381
+
382
@python_scope
def _set_entries(self, value):
    """Overwrite every entry of this matrix from ``value``.

    ``value`` may be a ``Matrix`` (converted with ``to_list``) or any indexable
    object: ``value[i]`` for vectors, ``value[i][j]`` for matrices. Host-access
    entries are written through their accessors.
    """
    if isinstance(value, Matrix):
        value = value.to_list()
    is_vector = self.ndim == 1
    for i in range(self.n):
        if is_vector:
            if self.is_host_access:
                _write_host_access(self.entries[i], value[i])
            else:
                self.entries[i] = value[i]
            continue
        for j in range(self.m):
            if self.is_host_access:
                _write_host_access(self.entries[i][j], value[i][j])
            else:
                self.entries[i][j] = value[i][j]
402
+
403
+ @property
404
+ def _members(self):
405
+ return self.entries
406
+
407
def to_list(self):
    """Return the matrix values as a new Python ``list``.

    A vector (``ndim == 1``) gives a flat list of ``n`` values. A matrix
    (``ndim == 2``) gives a nested list of ``n`` rows with ``m`` values each,
    like ``numpy.ndarray.tolist``. Host-access entries are read through their
    accessors, so the result holds values rather than accessor objects. A new
    list is built on every call.
    """
    if not self.is_host_access:
        return self.entries.tolist()
    if self.ndim == 1:
        return [_read_host_access(self.entries[i]) for i in range(self.n)]
    assert self.ndim == 2
    return [[_read_host_access(self.entries[i][j]) for j in range(self.m)] for i in range(self.n)]
419
+
420
@gstaichi_scope
def cast(self, dtype):
    """Return a new matrix/vector with every element cast to ``dtype``.

    Args:
        dtype (:mod:`~gstaichi.types.primitive_types`): Target element type.

    Returns:
        :class:`gstaichi.Matrix`: A ``Vector`` when ``ndim == 1``, otherwise a
        ``Matrix``, whose elements are ``ops.cast(element, dtype)``.

    Example::

        >>> A = ti.Matrix([0, 1, 2], ti.i32)
        >>> B = A.cast(ti.f32)
        >>> B
        [0.0, 1.0, 2.0]
    """
    if self.ndim != 1:
        return Matrix([[ops_mod.cast(self[i, j], dtype) for j in range(self.m)] for i in range(self.n)])
    return Vector([ops_mod.cast(self[i], dtype) for i in range(self.n)])
441
+
442
def trace(self):
    """Return the sum of the main-diagonal elements.

    The matrix is expected to be square.

    Returns:
        The trace, as computed by ``matrix_ops.trace``.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.trace()
        5
    """
    # Local import (pylint C0415); presumably avoids an import cycle with matrix_ops.
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    return _ops.trace(self)
460
+
461
def inverse(self):
    """Return the inverse of this matrix, via ``matrix_ops.inverse``.

    Note:
        The matrix dimension should be less than or equal to 4.

    Returns:
        :class:`~gstaichi.Matrix`: The inverse matrix.

    Raises:
        Exception: Inversions of matrices with sizes >= 5 are not supported.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    inv = _ops.inverse(self)
    return inv
476
+
477
def normalized(self, eps=0):
    """Return the unit vector pointing in the same direction, i.e. ``v / |v|``.

    Intended for vectors (matrices whose second dimension is one).

    Args:
        eps (float): A safe-guard value for sqrt, usually 0.

    Example::

        >>> a = ti.Vector([3, 4], ti.f32)
        >>> a.normalized()
        [0.6, 0.8]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    unit = _ops.normalized(self, eps)
    return unit
497
+
498
def transpose(self):
    """Return the transpose of this matrix, via ``matrix_ops.transpose``.

    Returns:
        :class:`~gstaichi.Matrix`: The transposed matrix.

    Example::

        >>> A = ti.Matrix([[0, 1], [2, 3]])
        >>> A.transpose()
        [[0, 2], [1, 3]]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    transposed = _ops.transpose(self)
    return transposed
514
+
515
@gstaichi_scope
def determinant(a):
    """Return the determinant of this matrix, via ``matrix_ops.determinant``.

    Note:
        The matrix dimension should be less than or equal to 4.

    Returns:
        dtype: The determinant.

    Raises:
        Exception: Determinants of matrices with sizes >= 5 are not supported.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    det = _ops.determinant(a)
    return det
532
+
533
@staticmethod
def diag(dim, val):
    """Return a ``dim`` x ``dim`` square matrix with ``val`` on the diagonal.

    Args:
        dim (int): The size of the square matrix.
        val (TypeVar): The value for the diagonal elements.

    Returns:
        :class:`~gstaichi.Matrix`: The diagonal matrix built by ``matrix_ops.diag``.

    Example::

        >>> ti.Matrix.diag(3, 1)
        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    diagonal = _ops.diag(dim, val)
    return diagonal
556
+
557
def sum(self):
    """Return the sum of all elements, via ``matrix_ops.sum``.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.sum()
        10
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    total = _ops.sum(self)
    return total
570
+
571
def norm(self, eps=0):
    """Return the square root of the sum of the absolute squares of the elements.

    Args:
        eps (Number): A safe-guard value for sqrt, usually 0.

    Returns:
        The norm, as computed by ``matrix_ops.norm``.

    Example::

        >>> a = ti.Vector([3, 4])
        >>> a.norm()
        5
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    length = _ops.norm(self, eps=eps)
    return length
591
+
592
def norm_inv(self, eps=0):
    """Return the inverse of :func:`~gstaichi.lang.matrix.Matrix.norm`.

    Args:
        eps (float): A safe-guard value for sqrt, usually 0.

    Returns:
        The inverse norm, as computed by ``matrix_ops.norm_inv``.
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    inv_length = _ops.norm_inv(self, eps=eps)
    return inv_length
605
+
606
def norm_sqr(self):
    """Return the sum of the absolute squares of the elements (via ``matrix_ops.norm_sqr``)."""
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    squared = _ops.norm_sqr(self)
    return squared
612
+
613
def max(self):
    """Return the largest element value (via ``matrix_ops.max``)."""
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    largest = _ops.max(self)
    return largest
619
+
620
def min(self):
    """Return the smallest element value (via ``matrix_ops.min``)."""
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    smallest = _ops.min(self)
    return smallest
626
+
627
def any(self):
    """Return whether at least one element is non-zero (via ``matrix_ops.any``).

    Returns:
        bool: ``True`` if any element is not equal to zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.any()
        True
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    found = _ops.any(self)
    return found
643
+
644
def all(self):
    """Return whether every element is non-zero (via ``matrix_ops.all``).

    Returns:
        bool: ``True`` if all elements are not equal to zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.all()
        False
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    every = _ops.all(self)
    return every
660
+
661
def fill(self, val):
    """Fill the matrix with ``val`` and return the result of ``matrix_ops.fill``.

    Args:
        val (Union[int, float]): Value to fill.

    Example::

        >>> A = ti.Matrix([0, 1, 2, 3])
        >>> A.fill(-1)
        >>> A
        [-1, -1, -1, -1]
    """
    from gstaichi.lang import matrix_ops as _ops  # pylint: disable=C0415

    filled = _ops.fill(self, val)
    return filled
678
+
679
def to_numpy(self):
    """Converts this matrix to a numpy array.

    Returns:
        numpy.ndarray: For host-access matrices, a newly built array of the
        values read through the accessors (``np.array(self.to_list())``).
        Otherwise, the matrix's own backing array is returned, not a copy, so
        in-place writes to the result are visible through this matrix.

    Example::

        >>> A = ti.Matrix([[0], [1], [2], [3]])
        >>> A.to_numpy()
        array([[0],
               [1],
               [2],
               [3]])
    """
    if not self.is_host_access:
        # Deliberately returns the internal array (no copy), as before.
        return self.entries
    return np.array(self.to_list())
695
+
696
@gstaichi_scope
def __ti_repr__(self):
    """Yield the printable pieces of this matrix for kernel-side printing.

    Rows are separated by ", "; each row is bracketed only when the matrix
    has more than one column.
    """
    yield "["
    bracket_rows = self.m != 1
    for row in range(self.n):
        if row > 0:
            yield ", "
        if bracket_rows:
            yield "["
        for col in range(self.m):
            if col > 0:
                yield ", "
            yield self(row, col)
        if bracket_rows:
            yield "]"
    yield "]"
711
+
712
def __str__(self):
    """Python scope matrix print support."""
    if not impl.inside_kernel():
        return str(self.to_numpy())
    # When pybind11 hits a type mismatch it calls `repr` on the offending
    # argument to build its error message, e.g.:
    #
    #   TypeError: make_const_expr_f32(): incompatible function arguments.
    #   The following argument types are supported:
    #       1. (arg0: float) -> gstaichi_python.Expr
    #
    #   Invoked with: <GsTaichi 2x1 Matrix>
    #
    # Inside a kernel we therefore return a placeholder string instead of
    # trying to convert the matrix to numpy.
    return f"<{self.n}x{self.m} ti.Matrix>"
728
+
729
+ def __repr__(self):
730
+ return str(self.to_numpy())
731
+
732
@staticmethod
@gstaichi_scope
def zero(dt, n, m=None):
    """Build a matrix (or vector when ``m`` is omitted) whose entries are all zero.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows (vector length when ``m`` is None).
        m (int, optional): Number of columns.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with zeros.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 0)
    return matrix_ops._filled_vector(n, dt, 0)
751
+
752
@staticmethod
@gstaichi_scope
def one(dt, n, m=None):
    """Build a matrix (or vector when ``m`` is omitted) whose entries are all one.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows (vector length when ``m`` is None).
        m (int, optional): Number of columns.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: A :class:`~gstaichi.lang.matrix.Matrix` instance filled with ones.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 1)
    return matrix_ops._filled_vector(n, dt, 1)
771
+
772
@staticmethod
@gstaichi_scope
def unit(n, i, dt=None):
    """Build an n-D vector whose ``i``-th entry is one and all others are zero.

    Args:
        n (int): The length of the vector.
        i (int): Index of the entry set to one; must satisfy ``0 <= i < n``.
        dt (:mod:`~gstaichi.types.primitive_types`, optional): Element data
            type; ``int`` is used when omitted.

    Returns:
        :class:`~gstaichi.Matrix`: The resulting vector.

    Example::

        >>> A = ti.Matrix.unit(3, 1)
        >>> A
        [0, 1, 0]
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    element_type = int if dt is None else dt
    assert 0 <= i < n
    return matrix_ops._unit_vector(n, i, element_type)
798
+
799
@staticmethod
@gstaichi_scope
def identity(dt, n):
    """Build the ``n x n`` identity matrix.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows and columns.

    Returns:
        :class:`~gstaichi.Matrix`: An `n x n` identity matrix.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    eye = matrix_ops._identity_matrix(n, dt)
    return eye
814
+
815
@classmethod
@python_scope
def field(
    cls,
    n,
    m,
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
    layout=Layout.AOS,
    ndim=None,
):
    """Construct a data container to hold all elements of the Matrix.

    Args:
        n (int): The desired number of rows of the Matrix.
        m (int): The desired number of columns of the Matrix.
        dtype (Union[DataType, list, tuple, np.ndarray]): The data type of the
            entries. A nested sequence of shape ``(n,)`` (when ``m == 1``) or
            ``(n, m)`` assigns a separate data type to each entry.
        shape (Union[int, tuple of int], optional): The desired shape of the field.
            When None, no SNode is created here and `offset`/`order` must be None too.
        order (str, optional): Order of the shape laid out in memory, written as
            axis letters starting at ``"i"`` (e.g. ``"ji"``).
        name (string, optional): The custom name of the field.
        offset (Union[int, tuple of int], optional): The coordinate offset
            of all elements in a field.
        needs_grad (bool, optional): Whether the Matrix need grad field (reverse mode autodiff).
        needs_dual (bool, optional): Whether the Matrix need dual field (forward mode autodiff).
        layout (Layout, optional): The field layout, either Array Of
            Structure (AOS) or Structure Of Array (SOA).
        ndim (int, optional): Element dimensionality forwarded to
            :class:`MatrixField`; 2 when omitted.

    Returns:
        :class:`~gstaichi.lang.matrix.MatrixField`: The created matrix field.

    Raises:
        GsTaichiSyntaxError: If `offset` or `order` is given without `shape`,
            or their length / axes do not match `shape`.
    """
    entries = []
    element_dim = ndim if ndim is not None else 2
    if isinstance(dtype, (list, tuple, np.ndarray)):
        # set different dtype for each element in Matrix
        # see #2135
        if m == 1:
            assert (
                len(np.shape(dtype)) == 1 and len(dtype) == n
            ), f"Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}"
            for i in range(n):
                entries.append(
                    impl.create_field_member(
                        dtype[i],
                        name=name,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                )
        else:
            assert (
                len(np.shape(dtype)) == 2 and len(dtype) == n and len(dtype[0]) == m
            ), f"Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}"
            for i in range(n):
                for j in range(m):
                    entries.append(
                        impl.create_field_member(
                            dtype[i][j],
                            name=name,
                            needs_grad=needs_grad,
                            needs_dual=needs_dual,
                        )
                    )
    else:
        for _ in range(n * m):
            entries.append(impl.create_field_member(dtype, name=name, needs_grad=needs_grad, needs_dual=needs_dual))
    # Each member is a (primal, grad, dual) triple; unzip into three tuples.
    entries, entries_grad, entries_dual = zip(*entries)

    entries = MatrixField(entries, n, m, element_dim)
    # Grad/dual fields are wrapped only when every member has one; otherwise
    # entries_grad / entries_dual remain plain tuples.
    # NOTE(review): if needs_grad/needs_dual is True but some member lacks a
    # grad/dual (presumably non-real dtypes), the placement code below will be
    # handed a tuple and fail -- confirm this is the intended error path.
    if all(entries_grad):
        entries_grad = MatrixField(entries_grad, n, m, element_dim)
        entries._set_grad(entries_grad)
    if all(entries_dual):
        entries_dual = MatrixField(entries_dual, n, m, element_dim)
        entries._set_dual(entries_dual)

    # Record the new field in the runtime's list of matrix fields.
    impl.get_runtime().matrix_fields.append(entries)

    if shape is None:
        if offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise GsTaichiSyntaxError("shape cannot be None when order is set")
    else:
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        if isinstance(offset, numbers.Number):
            offset = (offset,)
        dim = len(shape)
        if offset is not None and dim != len(offset):
            raise GsTaichiSyntaxError(
                f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
            )
        axis_seq = []
        shape_seq = []
        if order is not None:
            if dim != len(order):
                raise GsTaichiSyntaxError(
                    f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
                )
            if dim != len(set(order)):
                raise GsTaichiSyntaxError("The axes in order must be different")
            # Axis letters are relative to "i": "i" -> 0, "j" -> 1, ...
            for ch in order:
                axis = ord(ch) - ord("i")
                if axis < 0 or axis >= dim:
                    raise GsTaichiSyntaxError(f"Invalid axis {ch}")
                axis_seq.append(axis)
                shape_seq.append(shape[axis])
        else:
            axis_seq = list(range(dim))
            shape_seq = list(shape)
        # same_level is True only when no custom axis order was requested.
        same_level = order is None
        if layout == Layout.SOA:
            # SOA: every matrix entry gets its own SNode.
            for e in entries._get_field_members():
                impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_grad:
                for e in entries_grad._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_dual:
                for e in entries_dual._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
        else:
            # AOS: the whole matrix is placed under a single SNode.
            impl._create_snode(axis_seq, shape_seq, same_level).place(entries, offset=offset)
            if needs_grad:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_grad, offset=offset)
            if needs_dual:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_dual, offset=offset)
    return entries
947
+
948
@classmethod
@python_scope
def ndarray(cls, n, m, dtype, shape):
    """Define a GsTaichi ndarray whose elements are ``n x m`` matrices.

    Must be called in Python scope, after `ti.init` has been called.

    Args:
        n (int): Number of rows of each matrix element.
        m (int): Number of columns of each matrix element.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray; a bare number
            is treated as a one-dimensional shape.

    Returns:
        MatrixNdarray: The newly created ndarray.

    Example::

        >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
    """
    normalized_shape = (shape,) if isinstance(shape, numbers.Number) else shape
    return MatrixNdarray(n, m, dtype, normalized_shape)
970
+
971
@staticmethod
def rows(rows):
    """Stack a list of vectors/lists as the rows of a new matrix.

    Must be called in GsTaichi scope.

    Args:
        rows (List): Vectors (1-D Matrix) or plain lists, one per row.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.rows([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 2, 3], [4, 5, 6]]
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    stacked = matrix_ops.rows(rows)
    return stacked
997
+
998
@staticmethod
def cols(cols):
    """Stack a list of vectors/lists as the columns of a new matrix.

    Args:
        cols (List): Vectors (1-D Matrix) or plain lists, one per column.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.cols([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 4], [2, 5], [3, 6]]
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    stacked = matrix_ops.cols(cols)
    return stacked
1023
+
1024
+ def __hash__(self):
1025
+ # TODO: refactor KernelTemplateMapper
1026
+ # If not, we get `unhashable type: Matrix` when
1027
+ # using matrices as template arguments.
1028
+ return id(self)
1029
+
1030
def dot(self, other):
    """Compute the dot product of this vector with ``other``.

    Both operands must be vectors.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        DataType: The scalar dot product of the two Vectors.

    Example::

        >>> v1 = ti.Vector([1, 2, 3])
        >>> v2 = ti.Vector([3, 4, 5])
        >>> v1.dot(v2)
        26
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    product = matrix_ops.dot(self, other)
    return product
1051
+
1052
def cross(self, other):
    """Compute the cross product with another vector (1-D Matrix).

    The two vectors must have the same dimension, which must not exceed 3.

    For 2-D vectors (x1, y1) and (x2, y2) the result is the scalar
    `x1*y2 - x2*y1`. For 3-D vectors `v` and `w` the result is the 3-D
    vector `v x w`.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The cross product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    product = matrix_ops.cross(self, other)
    return product
1072
+
1073
def outer_product(self, other):
    """Compute the outer product with another vector (1-D Matrix).

    For `v = (x1, x2, ..., xn)` and `w = (y1, y2, ..., yn)` the result is an
    `n x n` square matrix whose `(i, j)` entry is `xi*yj`.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The outer product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    product = matrix_ops.outer_product(self, other)
    return product
1089
+
1090
+
1091
class Vector(Matrix):
    """A matrix whose second dimension is 1, exposed with a 1-D shape."""

    def __init__(self, arr, dt=None, **kwargs):
        """Constructs a vector from given array.

        A vector is a matrix whose second dimension is equal to 1.

        Args:
            arr (Union[list, tuple, np.ndarray]): The initial values of the Vector.
            dt (:mod:`~gstaichi.types.primitive_types`, optional): data type of the vector.
            **kwargs: Forwarded unchanged to :class:`Matrix`.

        Example::

            >>> u = ti.Vector([1, 2])
            >>> print(u.n, u.m)  # a vector is a matrix of shape (n, 1)
            2 1
            >>> v = ti.Vector([3, 4])
            >>> u + v
            [4 6]
        """
        super().__init__(arr, dt=dt, **kwargs)

    def get_shape(self):
        """Return the vector's shape as a one-element tuple ``(n,)``."""
        return (self.n,)

    @classmethod
    def field(cls, n, dtype, *args, **kwargs):
        """ti.Vector.field: create a field of length-``n`` vectors.

        Extra positional/keyword arguments are forwarded to
        :meth:`Matrix.field` with ``m = 1``. ``ndim`` may only be 1.
        """
        ndim = kwargs.get("ndim", 1)
        assert ndim == 1
        kwargs["ndim"] = 1
        return super().field(n, 1, dtype, *args, **kwargs)

    @classmethod
    @python_scope
    def ndarray(cls, n, dtype, shape):
        """Defines a GsTaichi ndarray with vector elements.

        Args:
            n (int): Size of the vector.
            dtype (DataType): Data type of each value.
            shape (Union[int, tuple[int]]): Shape of the ndarray; a bare number
                is treated as a one-dimensional shape.

        Example:
            The code below shows how a GsTaichi ndarray with vector elements can be declared and defined::

                >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
        """
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        return VectorNdarray(n, dtype, shape)
1142
+
1143
+
1144
class MatrixField(Field):
    """GsTaichi matrix field with SNode implementation.

    Args:
        vars (List[Expr]): Field members, one per matrix entry, in row-major order.
        n (Int): Number of rows.
        m (Int): Number of columns.
        ndim (Int): Number of dimensions; forced reshape if given.
    """

    def __init__(self, _vars, n, m, ndim=2):
        assert len(_vars) == n * m
        assert ndim in (0, 1, 2)
        super().__init__(_vars)
        self.n = n
        self.m = m
        self.ndim = ndim
        # Only the first `ndim` values of (n, m) describe the element shape.
        self.ptr = ti_python_core.expr_matrix_field([var.ptr for var in self.vars], [n, m][:ndim])

    def get_scalar_field(self, *indices):
        """Creates a ScalarField using a specific field member.

        Args:
            indices (Tuple[Int]): Row index, optionally followed by a column
                index (the column defaults to 0).

        Returns:
            ScalarField: The result ScalarField.
        """
        assert len(indices) in [1, 2]
        i = indices[0]
        j = 0 if len(indices) == 1 else indices[1]
        # Members are stored row-major.
        return ScalarField(self.vars[i * self.m + j])

    def _get_dynamic_index_stride(self):
        """Return the recorded dynamic-index stride, or None if this field is
        not dynamically indexable."""
        if self.ptr.get_dynamic_indexable():
            return self.ptr.get_dynamic_index_stride()
        return None

    def _calc_dynamic_index_stride(self):
        """Compute a uniform byte stride between consecutive members and
        record it with ``set_dynamic_index_stride``.

        Returns without recording anything when the members' SNode paths
        have different lengths, use quantized types, are not dense below
        their lowest common ancestor, or are not evenly spaced.
        """
        # Algorithm: https://github.com/taichi-dev/taichi/issues/3810 (upstream Taichi issue)
        paths = [ScalarField(var).snode._path_from_root() for var in self.vars]
        num_members = len(paths)
        if num_members == 1:
            self.ptr.set_dynamic_index_stride(0)
            return
        length = len(paths[0])
        if any(len(path) != length or ti_python_core.is_quant(path[length - 1]._dtype) for path in paths):
            return
        # First depth at which the member paths diverge.
        # NOTE(review): this assumes distinct members always diverge somewhere
        # (e.g. at their leaves); otherwise depth_below_lca would be unbound.
        for i in range(length):
            if any(path[i] != paths[0][i] for path in paths):
                depth_below_lca = i
                break
        for i in range(depth_below_lca, length - 1):
            if any(
                path[i].ptr.type != ti_python_core.SNodeType.dense
                or path[i]._cell_size_bytes != paths[0][i]._cell_size_bytes
                or path[i + 1]._offset_bytes_in_parent_cell != paths[0][i + 1]._offset_bytes_in_parent_cell
                for path in paths
            ):
                return
        stride = (
            paths[1][depth_below_lca]._offset_bytes_in_parent_cell
            - paths[0][depth_below_lca]._offset_bytes_in_parent_cell
        )
        for i in range(2, num_members):
            if (
                stride
                != paths[i][depth_below_lca]._offset_bytes_in_parent_cell
                - paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell
            ):
                return
        self.ptr.set_dynamic_index_stride(stride)

    def fill(self, val):
        """Fills this matrix field with specified values.

        Args:
            val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill.
                A scalar is broadcast to every entry; otherwise the value's
                dimensions must match `self`.
        """
        if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr) and not val.is_tensor()):
            # Broadcast a scalar into a full n x m (or length-n) tuple.
            if self.ndim == 2:
                val = tuple(tuple(val for _ in range(self.m)) for _ in range(self.n))
            else:
                assert self.ndim == 1
                val = tuple(val for _ in range(self.n))
        elif isinstance(val, expr.Expr) and val.is_tensor():
            assert val.n == self.n
            if self.ndim != 1:
                assert val.m == self.m
        else:
            if isinstance(val, Matrix):
                val = val.to_list()
            assert isinstance(val, (list, tuple))
            val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
            assert len(val) == self.n
            if self.ndim != 1:
                assert len(val[0]) == self.m
        if in_python_scope():
            from gstaichi._kernels import (  # pylint: disable=C0415
                field_fill_python_scope,  # pylint: disable=C0415
            )

            field_fill_python_scope(self, val)
        else:
            from gstaichi._funcs import (  # pylint: disable=C0415
                field_fill_gstaichi_scope,  # pylint: disable=C0415
            )

            field_fill_gstaichi_scope(self, val)

    @python_scope
    def to_numpy(self, keep_dims=False, dtype=None):
        """Converts the field instance to a NumPy array.

        Args:
            keep_dims (bool, optional): Whether to keep the matrix dimensions after conversion.
                When keep_dims=True, the result always has ``len(shape) + 2`` dims
                (the field shape followed by ``(n, m)``), even for 1x1, 1xn and nx1
                matrix fields.
                When keep_dims=False, only a column dimension of size 1 is dropped:
                an n x 1 (vector) field with 5x6x7 elements gives shape 5x6x7xn,
                while a 1 x m field still gives shape 5x6x7x1xm.
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result NumPy array.
        """
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        arr = np.zeros(self.shape + shape_ext, dtype=dtype)
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    @python_scope
    def to_torch(self, device=None, keep_dims=False):
        """Converts the field instance to a PyTorch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            torch.tensor: The result torch tensor.
        """
        import torch  # pylint: disable=C0415

        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        # Zero-initialise, matching to_numpy, so any element the copy kernel
        # does not write is deterministic instead of uninitialised memory.
        # pylint: disable=E1101
        arr = torch.zeros(self.shape + shape_ext, dtype=to_pytorch_type(self.dtype), device=device)
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copy an external array into this field.

        ``arr`` must have the field's shape followed by either ``(n,)``
        (vector fields only) or ``(n, m)``.
        """
        if len(arr.shape) == len(self.shape) + 1:
            as_vector = True
            assert self.m == 1, "This is not a vector field"
        else:
            as_vector = False
            assert len(arr.shape) == len(self.shape) + 2
        from gstaichi._kernels import ext_arr_to_matrix  # pylint: disable=C0415

        ext_arr_to_matrix(arr, self, as_vector)
        runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies an `numpy.ndarray` into this field.

        Non C-contiguous inputs are copied to a contiguous array first.

        Example::

            >>> m = ti.Matrix.field(2, 2, ti.f32, shape=(3, 3))
            >>> arr = numpy.ones((3, 3, 2, 2))
            >>> m.from_numpy(arr)
        """

        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessors()
        self[key]._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessors()
        key = self._pad_key(key)
        _host_access = self._host_access(key)
        if self.ndim == 1:
            return Vector([_host_access[i] for i in range(self.n)])
        return Matrix([[_host_access[i * self.m + j] for j in range(self.m)] for i in range(self.n)])

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return f"<{self.n}x{self.m} ti.Matrix.field>"
1350
+
1351
+
1352
class MatrixType(CompoundType):
    """Compound type describing ``n x m`` matrices (``ndim == 2``) or
    length-``n`` vectors (any other ``ndim``, normally 1).

    Args:
        n (int): Number of rows.
        m (int): Number of columns.
        ndim (int): Element dimensionality.
        dtype (DataType, optional): Element data type; may be None (legacy).
    """

    def __init__(self, n, m, ndim, dtype):
        self.n = n
        self.m = m
        self.ndim = ndim
        # FIXME(haidong): dtypes should not be left empty for ndarray.
        # Remove the None dtype when we are ready to break legacy code.
        if dtype is not None:
            self.dtype = cook_dtype(dtype)
            shape = (n, m) if ndim == 2 else (n,)
            self.tensor_type = _type_factory.get_tensor_type(shape, self.dtype)
        else:
            self.dtype = None
            self.tensor_type = None

    def __call__(self, *args):
        """Return a matrix matching the shape and dtype.

        This function will try to convert the input to a `n x m` matrix, with n, m being
        the number of rows/cols of this matrix type.

        Example::

            >>> mat4x3 = MatrixType(4, 3, 2, float)
            >>> mat2x6 = MatrixType(2, 6, 2, float)

            Create from n x m scalars, or a 1d list of n x m scalars:

            >>> m = mat4x3([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
            >>> m = mat4x3(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)

            Create from n vectors/lists, with each one of dimension m:

            >>> m = mat4x3([1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12])

            Create from a single scalar

            >>> m = mat4x3(1)

            Create from another 2d list/matrix, as long as they have the same number of entries

            >>> m = mat4x3([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
            >>> m = mat4x3(m)
            >>> k = mat2x6(m)

        """
        if len(args) == 0:
            raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
        if len(args) == 1:
            # Init from a real Matrix
            if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
                arg = args[0]
                shape = arg.ptr.get_rvalue_type().shape()
                assert self.ndim == len(shape)
                assert self.n == shape[0]
                if self.ndim > 1:
                    assert self.m == shape[1]
                return expr.Expr(arg.ptr)

            # initialize by a single scalar, e.g. matnxm(1)
            if isinstance(args[0], (numbers.Number, expr.Expr)):
                entries = [args[0] for _ in range(self.m) for _ in range(self.n)]
                return self._instantiate(entries)
            args = args[0]
        # collect all input entries to a 1d list and then reshape
        # this is mostly for glsl style like vec4(v.xyz, 1.)
        entries = []
        for x in args:
            if isinstance(x, (list, tuple)):
                entries += x
            elif isinstance(x, np.ndarray):
                entries += list(x.ravel())
            elif isinstance(x, Matrix):
                entries += x.to_list()
            else:
                entries.append(x)

        return self._instantiate(entries)

    def from_gstaichi_object(self, func_ret, ret_index=()):
        """Build a value of this type from the ``n * m`` elements of ``func_ret``
        at ``ret_index``."""
        return self(
            [
                expr.Expr(
                    ti_python_core.make_get_element_expr(
                        func_ret.ptr,
                        ret_index + (i,),
                        _ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
                for i in range(self.m * self.n)
            ]
        )

    def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
        """Read ``n * m`` struct return values via the getter matching ``self.dtype``.

        Raises:
            GsTaichiRuntimeTypeError: If ``self.dtype`` is neither an integer nor a real type.
        """
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                get_ret_func = launch_ctx.get_struct_ret_int
            else:
                get_ret_func = launch_ctx.get_struct_ret_uint
        elif self.dtype in primitive_types.real_types:
            get_ret_func = launch_ctx.get_struct_ret_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        return self([get_ret_func(ret_index + (i,)) for i in range(self.m * self.n)])

    def set_kernel_struct_args(self, mat, launch_ctx, ret_index=()):
        """Write the entries of ``mat`` into struct-argument slots (row-major).

        Raises:
            GsTaichiRuntimeTypeError: If ``self.dtype`` is neither an integer nor a real type.
        """
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                set_arg_func = launch_ctx.set_struct_arg_int
            else:
                set_arg_func = launch_ctx.set_struct_arg_uint
        elif self.dtype in primitive_types.real_types:
            set_arg_func = launch_ctx.set_struct_arg_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        if self.ndim == 1:
            for i in range(self.n):
                set_arg_func(ret_index + (i,), mat[i])
        else:
            for i in range(self.n):
                for j in range(self.m):
                    set_arg_func(ret_index + (i * self.m + j,), mat[i, j])

    def _instantiate_in_python_scope(self, entries):
        """Reshape flat ``entries`` (row-major) into a Python-scope Matrix,
        casting each value to int or float according to ``self.dtype``."""
        entries = [[entries[k * self.m + i] for i in range(self.m)] for k in range(self.n)]
        return Matrix(
            [
                [
                    int(entries[i][j]) if self.dtype in primitive_types.integer_types else float(entries[i][j])
                    for j in range(self.m)
                ]
                for i in range(self.n)
            ],
            dt=self.dtype,
        )

    def _instantiate(self, entries):
        if in_python_scope():
            return self._instantiate_in_python_scope(entries)

        return make_matrix_with_shape(entries, [self.n, self.m], self.dtype)

    def field(self, **kwargs):
        """Create a matrix field whose elements have this type."""
        assert kwargs.get("ndim", self.ndim) == self.ndim
        kwargs.update({"ndim": self.ndim})
        return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create an ndarray whose elements have this type.

        ``ndim``, if given, must equal this type's ``ndim``. It is not
        forwarded: ``Matrix.ndarray`` / ``Vector.ndarray`` take no ``ndim``
        parameter, so the element dimensionality is expressed by choosing the
        constructor instead.
        """
        ndim = kwargs.pop("ndim", self.ndim)
        assert ndim == self.ndim
        if self.ndim == 1:
            return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)
        return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)

    def get_shape(self):
        if self.ndim == 1:
            return (self.n,)
        return (self.n, self.m)

    def to_string(self):
        dtype_str = self.dtype.to_string() if self.dtype is not None else ""
        return f"MatrixType[{self.n},{self.m}, {dtype_str}]"

    def check_matched(self, other):
        """Return True if tensor type ``other`` has this type's rank, element
        type (when set) and sizes."""
        other_shape = other.shape()
        if self.ndim != len(other_shape):
            return False
        if self.dtype is not None and self.dtype != other.element_type():
            return False
        shape = self.get_shape()
        for i in range(self.ndim):
            if shape[i] is not None and shape[i] != other_shape[i]:
                return False
        return True
1523
+
1524
+
1525
class VectorType(MatrixType):
    """Compound type describing length-``n`` vectors (``m == 1``, ``ndim == 1``)."""

    def __init__(self, n, dtype):
        super().__init__(n, 1, 1, dtype)

    def __call__(self, *args):
        """Return a vector matching the shape and dtype.

        The input is converted to an `n`-component vector.

        Example::

            >>> vec3 = VectorType(3, float)

            Create from n scalars:

            >>> v = vec3(1, 2, 3)

            Create from a list/tuple of n scalars:

            >>> v = vec3([1, 2, 3])

            Create from a single scalar

            >>> v = vec3(1)

        """
        if not args:
            raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
        if len(args) == 1:
            first = args[0]
            if isinstance(first, expr.Expr) and first.ptr.is_tensor():
                # Reuse an existing tensor expression after a shape check.
                tensor_shape = first.ptr.get_rvalue_type().shape()
                assert len(tensor_shape) == 1
                assert self.n == tensor_shape[0]
                return expr.Expr(first.ptr)
            if isinstance(first, (numbers.Number, expr.Expr)):
                # Broadcast one scalar to every component, e.g. vec3(1).
                return self._instantiate([first] * self.n)
            args = first
        # Flatten all inputs into one list, glsl style (e.g. vec4(v.xyz, 1.)).
        flat = []
        for item in args:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            elif isinstance(item, np.ndarray):
                flat.extend(item.ravel())
            elif isinstance(item, Matrix):
                flat.extend(item.to_list())
            else:
                flat.append(item)
        return self._instantiate(flat)

    def _instantiate_in_python_scope(self, entries):
        """Build a Python-scope Vector, casting each entry to int or float per ``self.dtype``."""
        values = []
        for idx in range(self.n):
            raw = entries[idx]
            values.append(int(raw) if self.dtype in primitive_types.integer_types else float(raw))
        return Vector(values, dt=self.dtype)

    def _instantiate(self, entries):
        if not in_python_scope():
            return make_matrix_with_shape(entries, [self.n], self.dtype)
        return self._instantiate_in_python_scope(entries)

    def field(self, **kwargs):
        """Create a vector field whose elements have this type."""
        return Vector.field(self.n, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create an ndarray whose elements have this type."""
        return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

    def to_string(self):
        dtype_str = "" if self.dtype is None else self.dtype.to_string()
        return f"VectorType[{self.n}, {dtype_str}]"
1607
+
1608
+
1609
+ class MatrixNdarray(Ndarray):
1610
+ """GsTaichi ndarray with matrix elements.
1611
+
1612
+ Args:
1613
+ n (int): Number of rows of the matrix.
1614
+ m (int): Number of columns of the matrix.
1615
+ dtype (DataType): Data type of each value.
1616
+ shape (Union[int, tuple[int]]): Shape of the ndarray.
1617
+
1618
+ Example::
1619
+
1620
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
1621
+ """
1622
+
1623
def __init__(self, n, m, dtype, shape):
    """Allocate a zero-filled ndarray of ``n x m`` matrix elements.

    Args:
        n (int): Number of rows of each element.
        m (int): Number of columns of each element.
        dtype (DataType): Scalar data type of each entry.
        shape (Iterable[int]): Outer shape of the ndarray.
    """
    self.n = n
    self.m = m
    super().__init__()
    # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
    self.dtype = cook_dtype(dtype)

    self.layout = Layout.AOS
    self.shape = tuple(shape)
    self.element_type = _type_factory.get_tensor_type((self.n, self.m), self.dtype)
    # TODO: we should pass in element_type, shape, layout instead.
    program = impl.get_runtime().prog
    self.arr = program.create_ndarray(
        cook_dtype(self.element_type),
        shape,
        Layout.AOS,
        zero_fill=True,
        dbg_info=ti_python_core.DebugInfo(get_traceback()),
    )
1641
+
1642
+ @property
1643
+ def element_shape(self):
1644
+ """Returns the shape of each element (a 2D matrix) in this ndarray.
1645
+
1646
+ Example::
1647
+
1648
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
1649
+ >>> arr.element_shape
1650
+ (2, 2)
1651
+ """
1652
+ return tuple(self.arr.element_shape())
1653
+
1654
+ @python_scope
1655
+ def __setitem__(self, key, value):
1656
+ if not isinstance(value, (list, tuple)):
1657
+ value = list(value)
1658
+ if not isinstance(value[0], (list, tuple)):
1659
+ value = [[i] for i in value]
1660
+ for i in range(self.n):
1661
+ for j in range(self.m):
1662
+ self[key][i, j] = value[i][j]
1663
+
1664
+ @python_scope
1665
+ def __getitem__(self, key):
1666
+ key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
1667
+ return Matrix([[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)] for i in range(self.n)])
1668
+
1669
+ @python_scope
1670
+ def to_numpy(self):
1671
+ """Converts this ndarray to a `numpy.ndarray`.
1672
+
1673
+ Example::
1674
+
1675
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
1676
+ >>> arr.to_numpy()
1677
+ [[[[0. 0.]
1678
+ [0. 0.]]]
1679
+
1680
+ [[[0. 0.]
1681
+ [0. 0.]]]]
1682
+ """
1683
+ return self._ndarray_matrix_to_numpy(as_vector=0)
1684
+
1685
+ @python_scope
1686
+ def from_numpy(self, arr):
1687
+ """Copies the data of a `numpy.ndarray` into this array.
1688
+
1689
+ Example::
1690
+
1691
+ >>> m = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1), layout=0)
1692
+ >>> arr = np.ones((2, 1, 2, 2))
1693
+ >>> m.from_numpy(arr)
1694
+ """
1695
+ self._ndarray_matrix_from_numpy(arr, as_vector=0)
1696
+
1697
+ @python_scope
1698
+ def __deepcopy__(self, memo=None):
1699
+ ret_arr = MatrixNdarray(self.n, self.m, self.dtype, self.shape)
1700
+ ret_arr.copy_from(self)
1701
+ return ret_arr
1702
+
1703
+ @python_scope
1704
+ def _fill_by_kernel(self, val):
1705
+ from gstaichi._kernels import fill_ndarray_matrix # pylint: disable=C0415
1706
+
1707
+ shape = self.element_type.shape()
1708
+ n = shape[0]
1709
+ m = 1
1710
+ if len(shape) > 1:
1711
+ m = shape[1]
1712
+
1713
+ prim_dtype = self.element_type.element_type()
1714
+ matrix_type = MatrixType(n, m, len(shape), prim_dtype)
1715
+ if isinstance(val, Matrix):
1716
+ value = val
1717
+ else:
1718
+ value = matrix_type(val)
1719
+ fill_ndarray_matrix(self, value)
1720
+
1721
+ @python_scope
1722
+ def __repr__(self):
1723
+ return f"<{self.n}x{self.m} {Layout.AOS} ti.Matrix.ndarray>"
1724
+
1725
+
1726
class VectorNdarray(Ndarray):
    """GsTaichi ndarray with vector elements.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray. A bare integer
            is treated as the one-dimensional shape ``(shape,)``.

    Example::

        >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
    """

    def __init__(self, n, dtype, shape):
        self.n = n
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        # An integer shape would otherwise fail in ``tuple(shape)`` with a
        # TypeError; normalize it to a 1-tuple first.
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        self.shape = tuple(shape)
        self.element_type = _type_factory.get_tensor_type((n,), self.dtype)
        self.arr = impl.get_runtime().prog.create_ndarray(
            cook_dtype(self.element_type),
            shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Gets the dimension of the vector of this ndarray.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
            >>> a.element_shape
            (3,)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Writes the first ``n`` items of ``value`` into the element at ``key``."""
        if not isinstance(value, (list, tuple)):
            value = list(value)
        # Build the host-access view once; each assignment writes through to
        # the ndarray, so reusing it is equivalent to re-indexing per entry.
        target = self[key]
        for i in range(self.n):
            target[i] = value[i]

    @python_scope
    def __getitem__(self, key):
        """Returns a host-access ``Vector`` view of the element at ``key``."""
        key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
        return Vector([NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this vector ndarray to a `numpy.ndarray`.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> a.to_numpy()
            array([[[0., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],
                    [0., 0., 0.]]], dtype=float32)
        """
        return self._ndarray_matrix_to_numpy(as_vector=1)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this ndarray.

        The shape and data type of `arr` must match this ndarray.

        Example::

            >>> import numpy as np
            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> b = np.ones((2, 2, 3), dtype=np.float32)
            >>> a.from_numpy(b)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=1)

    @python_scope
    def __deepcopy__(self, memo=None):
        """Returns a new ``VectorNdarray`` of the same type with copied data."""
        ret_arr = VectorNdarray(self.n, self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    @python_scope
    def _fill_by_kernel(self, val):
        """Fills every element with ``val`` using the ``fill_ndarray_matrix`` kernel.

        A non-``Vector`` ``val`` is first converted with a ``VectorType``
        matching this ndarray's element length and primitive dtype.
        """
        from gstaichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        shape = self.element_type.shape()
        prim_dtype = self.element_type.element_type()
        vector_type = VectorType(shape[0], prim_dtype)
        if isinstance(val, Vector):
            value = val
        else:
            value = vector_type(val)
        fill_ndarray_matrix(self, value)

    @python_scope
    def __repr__(self):
        return f"<{self.n} {Layout.AOS} ti.Vector.ndarray>"
1833
+
1834
+
1835
# Public names exported by ``from ... import *`` for this module.
__all__ = [
    "Matrix",
    "Vector",
    "MatrixField",
    "MatrixNdarray",
    "VectorNdarray",
]