gstaichi 0.1.25.dev0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp312-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1881 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+ import numbers
5
+ from collections.abc import Iterable
6
+ from itertools import product
7
+
8
+ import numpy as np
9
+
10
+ from gstaichi._lib import core as ti_python_core
11
+ from gstaichi._lib.utils import ti_python_core as _ti_python_core
12
+ from gstaichi.lang import expr, impl, runtime_ops
13
+ from gstaichi.lang import ops as ops_mod
14
+ from gstaichi.lang._ndarray import Ndarray, NdarrayHostAccess
15
+ from gstaichi.lang.common_ops import GsTaichiOperations
16
+ from gstaichi.lang.exception import (
17
+ GsTaichiRuntimeError,
18
+ GsTaichiRuntimeTypeError,
19
+ GsTaichiSyntaxError,
20
+ GsTaichiTypeError,
21
+ )
22
+ from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
23
+ from gstaichi.lang.util import (
24
+ cook_dtype,
25
+ get_traceback,
26
+ gstaichi_scope,
27
+ in_python_scope,
28
+ python_scope,
29
+ to_numpy_type,
30
+ to_paddle_type,
31
+ to_pytorch_type,
32
+ warning,
33
+ )
34
+ from gstaichi.types import primitive_types
35
+ from gstaichi.types.compound_types import CompoundType
36
+ from gstaichi.types.enums import Layout
37
+ from gstaichi.types.utils import is_signed
38
+
39
# Type factory instance obtained from the native core once, when this module is imported.
# NOTE(review): not referenced in the visible part of this module — confirm it is still used.
_type_factory = _ti_python_core.get_type_factory_instance()
40
+
41
+
42
+ def _generate_swizzle_patterns(key_group: str, required_length=4):
43
+ """Generate vector swizzle patterns from a given set of characters.
44
+
45
+ Example:
46
+
47
+ For `key_group=xyzw` and `required_length=4`, this function will return a
48
+ list consists of all possible strings (no repeats) in characters
49
+ `x`, `y`, `z`, `w` and of length<=4:
50
+ [`x`, `y`, `z`, `w`, `xx`, `xy`, `yx`, ..., `xxxx`, `xxxy`, `xyzw`, ...]
51
+ The length of the list will be 4 + 4x4 + 4x4x4 + 4x4x4x4 = 340.
52
+ """
53
+ result = []
54
+ for k in range(1, required_length + 1):
55
+ result.extend(product(key_group, repeat=k))
56
+ result = ["".join(pat) for pat in result]
57
+ return result
58
+
59
+
60
def _gen_swizzles(cls):
    """Class decorator that installs GLSL-style swizzle properties on ``cls``.

    For each key group ("xyzw", "rgba", "stpq") it adds:
      * one property per single character (e.g. ``v.x``), reading/writing
        ``instance[position]``;
      * one property per multi-character pattern of length 2..4 (e.g. ``v.xy``),
        whose getter returns a new ``Vector`` of the selected components and
        whose setter assigns the components one by one.

    Every accessor validates the pattern against ``instance.n`` and raises
    ``GsTaichiSyntaxError`` when it references a component beyond the vector
    length. Multi-character setters raise ``GsTaichiRuntimeError`` when the
    value length differs from the pattern length. All setters are wrapped with
    ``python_scope``.

    Side effects on ``cls``:
      * ``cls._swizzle_to_keygroup`` maps every installed property name to its
        key group;
      * ``cls._keygroup_to_checker`` maps each key group to its validator.

    Returns:
        The same ``cls``, so it can be used as a decorator.
    """
    # https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
    key_groups = ["xyzw", "rgba", "stpq"]
    cls._swizzle_to_keygroup = {}
    cls._keygroup_to_checker = {}

    def _make_checker(group):
        def check(instance, pattern):
            allowed = set(group[: instance.n])
            if set(pattern) - allowed:
                raise GsTaichiSyntaxError(
                    f"vec{instance.n} only has attributes={tuple(sorted(allowed))}, got={tuple(pattern)}"
                )

        return check

    def _make_scalar_property(group, position):
        check = cls._keygroup_to_checker[group]
        name = group[position]

        def getter(instance):
            check(instance, name)
            return instance[position]

        @python_scope
        def setter(instance, value):
            check(instance, name)
            instance[position] = value

        return property(getter, setter)

    def _make_vector_property(group, pattern):
        check = cls._keygroup_to_checker[group]
        # Component positions are fixed by the pattern, so resolve them once.
        positions = [group.index(ch) for ch in pattern]

        def getter(instance):
            check(instance, pattern)
            return Vector([instance[p] for p in positions])

        @python_scope
        def setter(instance, value):
            if len(pattern) != len(value):
                raise GsTaichiRuntimeError(f"value len does not match the swizzle pattern={pattern}")
            check(instance, pattern)
            for p, component in zip(positions, value):
                instance[p] = component

        return property(getter, setter)

    for group in key_groups:
        cls._keygroup_to_checker[group] = _make_checker(group)
        for position, attr in enumerate(group):
            setattr(cls, attr, _make_scalar_property(group, position))
            cls._swizzle_to_keygroup[attr] = group

    for group in key_groups:
        for pattern in _generate_swizzle_patterns(group, required_length=4):
            if len(pattern) < 2:
                # single-character accessors were installed by the loop above
                continue
            setattr(cls, pattern, _make_vector_property(group, pattern))
            cls._swizzle_to_keygroup[pattern] = group
    return cls
131
+
132
+
133
def _infer_entry_dt(entry):
    """Infer the GsTaichi data type of a single matrix entry.

    Python/NumPy integers map to the runtime's ``default_ip`` and Python/NumPy
    floats to ``default_fp`` (integers are checked first, so ``bool`` counts as
    an integer). ``expr.Expr`` entries use the rvalue type of their pointer.

    Raises:
        GsTaichiTypeError: if an ``Expr`` reports an unknown type, or the entry
            is of any other kind.
    """
    if isinstance(entry, (int, np.integer)):
        return impl.get_runtime().default_ip
    if isinstance(entry, (float, np.floating)):
        return impl.get_runtime().default_fp
    if not isinstance(entry, expr.Expr):
        raise GsTaichiTypeError("Element type of the matrix is invalid.")
    inferred = entry.ptr.get_rvalue_type()
    if inferred == ti_python_core.DataType_unknown:
        raise GsTaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
    return inferred
144
+
145
+
146
def _infer_array_dt(arr):
    """Infer one element type for a non-empty flat sequence of entries.

    Each entry's type comes from ``_infer_entry_dt``; the types are folded
    left-to-right with ``ti_python_core.promoted_type``.
    """
    assert len(arr) > 0
    # Lazy generator keeps inference and promotion interleaved, entry by entry.
    entry_types = (_infer_entry_dt(entry) for entry in arr)
    return functools.reduce(ti_python_core.promoted_type, entry_types)
149
+
150
+
151
def make_matrix_with_shape(arr, shape, dt):
    """Build a matrix expression with an explicit shape and element type.

    Args:
        arr: flat sequence of entries; each one is wrapped in ``expr.Expr``.
        shape: shape forwarded to the AST builder's ``make_matrix_expr``.
        dt: element data type, forwarded unchanged.

    Returns:
        expr.Expr: the matrix expression emitted by the compiling callable's
        AST builder, tagged with the current source location.
    """
    runtime = impl.get_runtime()
    builder = runtime.compiling_callable.ast_builder()
    element_ptrs = [expr.Expr(elt).ptr for elt in arr]
    debug_info = ti_python_core.DebugInfo(runtime.get_current_src_info())
    return expr.Expr(builder.make_matrix_expr(shape, dt, element_ptrs, debug_info))
162
+
163
+
164
def make_matrix(arr, dt=None):
    """Build a vector or matrix expression from (nested) Python sequences.

    Args:
        arr: a flat sequence of entries (vector) or a sequence of rows
            (matrix, flattened row by row). An empty ``arr`` yields a
            zero-length ``i32`` vector.
        dt: element data type. When ``None`` it is inferred from the entries
            with ``_infer_array_dt``; otherwise it is normalized with
            ``cook_dtype``. Ignored for an empty ``arr``.

    Returns:
        expr.Expr: the expression built by ``make_matrix_with_shape``.
    """
    if len(arr) == 0:
        # The only usage of an empty vector is to serve as field indices, so the
        # element type is fixed to i32 regardless of ``dt``.
        return make_matrix_with_shape(arr, [0], primitive_types.i32)
    if isinstance(arr[0], Iterable):  # matrix: flatten rows in row-major order
        shape = [len(arr), len(arr[0])]
        arr = [elt for row in arr for elt in row]
    else:  # vector
        shape = [len(arr)]
    dt = _infer_array_dt(arr) if dt is None else cook_dtype(dt)
    # Delegate to the shared builder instead of duplicating the AST call.
    return make_matrix_with_shape(arr, shape, dt)
189
+
190
+
191
def _read_host_access(x):
    """Read the current value behind a host-side accessor.

    ``SNodeHostAccess`` values are read with ``x.accessor.getter(*x.key)``;
    anything else is asserted to be an ``NdarrayHostAccess`` and read with
    ``x.getter()``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        return x.getter()
    return x.accessor.getter(*x.key)
196
+
197
+
198
def _write_host_access(x, value):
    """Store ``value`` through a host-side accessor.

    ``SNodeHostAccess`` values are written with
    ``x.accessor.setter(value, *x.key)``; anything else is asserted to be an
    ``NdarrayHostAccess`` and written with ``x.setter(value)``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        x.setter(value)
        return
    x.accessor.setter(value, *x.key)
204
+
205
+
206
+ @_gen_swizzles
207
+ class Matrix(GsTaichiOperations):
208
+ """The matrix class.
209
+
210
+ A matrix is a 2-D rectangular array with scalar entries, it's row-majored, and is
211
+ aligned continuously. We recommend only use matrix with no more than 32 elements for
212
+ efficiency considerations.
213
+
214
+ Note: in gstaichi a matrix is strictly two-dimensional and only stores scalars.
215
+
216
+ Args:
217
+ arr (Union[list, tuple, np.ndarray]): the initial values of a matrix.
218
+ dt (:mod:`~gstaichi.types.primitive_types`): the element data type.
219
+ ndim (int optional): the number of dimensions of the matrix; forced reshape if given.
220
+
221
+ Example::
222
+
223
+ use a 2d list to initialize a matrix
224
+
225
+ >>> @ti.kernel
226
+ >>> def test():
227
+ >>> n = 5
228
+ >>> M = ti.Matrix([[0] * n for _ in range(n)], ti.i32)
229
+ >>> print(M) # a 5x5 matrix with integer elements
230
+
231
+ get the number of rows and columns via the `n`, `m` property:
232
+
233
+ >>> M = ti.Matrix([[0, 1], [2, 3], [4, 5]], ti.i32)
234
+ >>> M.n # number of rows
235
+ 3
236
+ >>> M.m # number of cols
237
+ >>> 2
238
+
239
+ you can even initialize a matrix with an empty list:
240
+
241
+ >>> M = ti.Matrix([[], []], ti.i32)
242
+ >>> M.n
243
+ 2
244
+ >>> M.m
245
+ 0
246
+ """
247
+
248
+ _is_gstaichi_class = True
249
+ _is_matrix_class = True
250
+ __array_priority__ = 1000
251
+
252
def __init__(self, arr, dt=None):
    """Initialize a matrix (2-D) or vector (1-D) from an array-like object.

    Args:
        arr (Union[list, tuple, np.ndarray]): the initial values. A sequence of
            sequences creates a matrix (rows may be empty, e.g. ``[[], []]``),
            a flat sequence creates a vector, and an empty sequence creates an
            empty matrix with ``ndim == 0``.
        dt (optional): element data type; converted with ``to_numpy_type`` and
            used as the numpy dtype of the stored entries. Not used when the
            entries are host accessors.

    Raises:
        GsTaichiTypeError: if ``arr`` is not a list, tuple or numpy array.
        Exception: if ``arr`` is a sequence of ``Matrix`` objects.
    """
    if not isinstance(arr, (list, tuple, np.ndarray)):
        raise GsTaichiTypeError("An Matrix/Vector can only be initialized with an array-like object")
    if len(arr) == 0:
        self.ndim = 0
        self.n, self.m = 0, 0
        self.entries = np.array([])
        self.is_host_access = False
    elif isinstance(arr[0], Matrix):
        raise Exception("cols/rows required when using list of vectors")
    else:
        if isinstance(arr[0], Iterable):  # matrix
            self.ndim = 2
            self.n, self.m = len(arr), len(arr[0])
            # Rows may be empty (e.g. ``Matrix([[], []])``); only peek at the
            # first entry when there is one, otherwise this raised IndexError.
            first = arr[0][0] if self.m > 0 else None
        else:  # vector
            self.ndim = 1
            self.n, self.m = len(arr), 1
            first = arr[0]
        # Host accessors are kept as-is so reads/writes go through them;
        # plain values are copied into a numpy array.
        self.is_host_access = isinstance(first, (SNodeHostAccess, NdarrayHostAccess))
        if self.is_host_access:
            self.entries = arr
        else:
            self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))

    if self.n * self.m > 32:
        warning(
            f"GsTaichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested."
            " Matrices/vectors will be automatically unrolled at compile-time for performance."
            " So the compilation time could be extremely long if the matrix size is too big."
            " You may use a field to store a large matrix like this, e.g.:\n"
            f" x = ti.field(ti.f32, ({self.n}, {self.m})).\n"
            " See https://docs.taichi-lang.org/docs/field#matrix-size"
            " for more details.",
            UserWarning,
            stacklevel=2,
        )
293
+
294
def get_shape(self):
    """Return ``(n,)`` for a vector, ``(n, m)`` for a matrix, and ``None`` otherwise."""
    if self.ndim not in (1, 2):
        return None
    return (self.n,) if self.ndim == 1 else (self.n, self.m)
300
+
301
def __matmul__(self, other):
    """Implement the ``@`` operator: matrix-matrix or matrix-vector product.

    Args:
        other (Union[Matrix, Vector]): the right-hand operand.

    Returns:
        The product computed by ``matrix_ops.matmul``.
    """
    from gstaichi.lang.matrix_ops import matmul  # pylint: disable=C0415

    return matmul(self, other)
314
+
315
+ # host access & python scope operation
316
def __len__(self):
    """Return the number of rows ``n``.

    For a vector this is its length, because vectors are stored with ``n``
    entries and ``m == 1``.
    """
    return self.n
320
+
321
def __iter__(self):
    """Iterate over a vector's entries, or over a matrix's rows.

    Matrix rows are produced as plain lists of ``self[i, j]`` values.
    """
    if self.ndim != 1:
        return ([self[i, j] for j in range(self.m)] for i in range(self.n))
    return (self[i] for i in range(self.n))
325
+
326
def __getitem__(self, indices):
    """Return the element at ``indices``.

    Args:
        indices (Sequence[Expr]): one index for a vector, two for a matrix.

    Returns:
        The stored element; for host-access matrices the value is read through
        its accessor.
    """
    entry = self._get_entry(indices)
    return _read_host_access(entry) if self.is_host_access else entry
340
+
341
@python_scope
def __setitem__(self, indices, item):
    """Assign ``item`` to the element at ``indices`` (Python scope only).

    Args:
        indices (Sequence[Expr]): one index for a vector, two for a matrix.
        item: the new value; written through the accessor for host-access
            matrices, otherwise stored into ``self.entries``.
    """
    if self.is_host_access:
        _write_host_access(self._get_entry(indices), item)
        return
    if not isinstance(indices, (list, tuple)):
        indices = [indices]
    assert len(indices) in [1, 2]
    assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
    row = indices[0]
    if self.ndim == 1:
        self.entries[row] = item
    else:
        self.entries[row][indices[1]] = item
361
+
362
def _get_entry(self, indices):
    """Return the raw stored entry at ``indices`` (host accessors are not read).

    A non-sequence index is treated as a one-element index list; the number of
    indices must equal ``self.ndim``.
    """
    if not isinstance(indices, (list, tuple)):
        indices = [indices]
    assert len(indices) in [1, 2]
    assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
    row = self.entries[indices[0]]
    return row if self.ndim == 1 else row[indices[1]]
370
+
371
def _get_slice(self, a, b):
    """Return the part of a 2-D matrix selected by row selector ``a`` and column selector ``b``.

    Args:
        a (Union[int, slice, range]): row selector.
        b (Union[int, slice, range]): column selector.
            At least one of ``a``/``b`` must be a slice or range.

    Returns:
        A ``Matrix`` when both selectors are ranges/slices, otherwise a
        ``Vector`` along the sliced axis. Entries are the raw stored entries
        returned by ``_get_entry``.
    """
    if isinstance(a, slice):
        # slice.indices resolves None bounds and negative indices against n,
        # unlike ``a.stop or self.n`` which turned ``stop=0`` into ``n``.
        a = range(*a.indices(self.n))
    if isinstance(b, slice):
        b = range(*b.indices(self.m))
    # _get_entry takes a single index sequence, so pass (row, col) as a tuple.
    if isinstance(a, range) and isinstance(b, range):
        return Matrix([[self._get_entry((i, j)) for j in b] for i in a])
    if isinstance(a, range):  # b is a single column index
        return Vector([self._get_entry((i, b)) for i in a])
    # a is a single row index while b is a range
    return Vector([self._get_entry((a, j)) for j in b])
382
+
383
@python_scope
def _set_entries(self, value):
    """Overwrite every element with the matching element of ``value``.

    Args:
        value: a ``Matrix`` (converted with ``to_list``) or a nested sequence
            indexed as ``value[i]`` for vectors and ``value[i][j]`` otherwise.
            Host-access matrices write through their accessors.
    """
    if isinstance(value, Matrix):
        value = value.to_list()
    host = self.is_host_access
    vector = self.ndim == 1
    for i in range(self.n):
        if vector:
            if host:
                _write_host_access(self.entries[i], value[i])
            else:
                self.entries[i] = value[i]
            continue
        for j in range(self.m):
            if host:
                _write_host_access(self.entries[i][j], value[i][j])
            else:
                self.entries[i][j] = value[i][j]
403
+
404
@property
def _members(self):
    """The underlying entry storage: a numpy array, or the raw host-accessor sequence."""
    storage = self.entries
    return storage
407
+
408
def to_list(self):
    """Return the entries as a new Python list.

    Vectors give a flat list of ``n`` values; matrices give a nested list of
    ``n`` rows with ``m`` values each (it is NOT flattened). For host-access
    matrices every value is read through its accessor.
    """
    if self.is_host_access:
        if self.ndim == 1:
            return [_read_host_access(self.entries[i]) for i in range(self.n)]
        assert self.ndim == 2
        return [[_read_host_access(self.entries[i][j]) for j in range(self.m)] for i in range(self.n)]
    return self.entries.tolist()
420
+
421
@gstaichi_scope
def cast(self, dtype):
    """Return a new matrix whose elements are cast to ``dtype``.

    Args:
        dtype (:mod:`~gstaichi.types.primitive_types`): target element type.

    Returns:
        :class:`gstaichi.Matrix`: a ``Vector`` for 1-D input, otherwise a
        ``Matrix`` with the same shape.

    Example::

        >>> A = ti.Matrix([0, 1, 2], ti.i32)
        >>> B = A.cast(ti.f32)
        >>> B
        [0.0, 1.0, 2.0]
    """
    if self.ndim != 1:
        return Matrix([[ops_mod.cast(self[i, j], dtype) for j in range(self.m)] for i in range(self.n)])
    return Vector([ops_mod.cast(self[i], dtype) for i in range(self.n)])
442
+
443
def trace(self):
    """Return the sum of the diagonal elements of a square matrix.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.trace()
        5
    """
    from gstaichi.lang.matrix_ops import trace as _trace  # pylint: disable=C0415

    return _trace(self)
461
+
462
def inverse(self):
    """Return the inverse of this matrix.

    Note:
        The matrix dimension should be less than or equal to 4.

    Returns:
        :class:`~gstaichi.Matrix`: the inverse computed by ``matrix_ops.inverse``.

    Raises:
        Exception: Inversions of matrices with sizes >= 5 are not supported.
    """
    from gstaichi.lang.matrix_ops import inverse as _inverse  # pylint: disable=C0415

    return _inverse(self)
477
+
478
def normalized(self, eps=0):
    """Return the unit vector pointing in the same direction, i.e. ``v / |v|``.

    Args:
        eps (float): a safe-guard value for sqrt, usually 0.

    Example::

        >>> a = ti.Vector([3, 4], ti.f32)
        >>> a.normalized()
        [0.6, 0.8]
    """
    from gstaichi.lang.matrix_ops import normalized as _normalized  # pylint: disable=C0415

    return _normalized(self, eps)
498
+
499
def transpose(self):
    """Return the transpose of this matrix.

    Example::

        >>> A = ti.Matrix([[0, 1], [2, 3]])
        >>> A.transpose()
        [[0, 2], [1, 3]]
    """
    from gstaichi.lang.matrix_ops import transpose as _transpose  # pylint: disable=C0415

    return _transpose(self)
515
+
516
@gstaichi_scope
def determinant(a):
    """Return the determinant of this matrix.

    Note:
        The matrix dimension should be less than or equal to 4.

    Raises:
        Exception: Determinants of matrices with sizes >= 5 are not supported.
    """
    from gstaichi.lang.matrix_ops import determinant as _determinant  # pylint: disable=C0415

    return _determinant(a)
533
+
534
@staticmethod
def diag(dim, val):
    """Return a ``dim x dim`` matrix with ``val`` on the diagonal.

    Args:
        dim (int): the dimension of the square matrix.
        val (TypeVar): value for the diagonal elements.

    Example::

        >>> ti.Matrix.diag(3, 1)
        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]]
    """
    from gstaichi.lang.matrix_ops import diag as _diag  # pylint: disable=C0415

    return _diag(dim, val)
557
+
558
def sum(self):
    """Return the sum of all elements.

    Example::

        >>> m = ti.Matrix([[1, 2], [3, 4]])
        >>> m.sum()
        10
    """
    from gstaichi.lang.matrix_ops import sum as _sum  # pylint: disable=C0415

    return _sum(self)
571
+
572
def norm(self, eps=0):
    """Return the square root of the sum of the absolute squares of the elements.

    Args:
        eps (Number): a safe-guard value for sqrt, usually 0.

    Example::

        >>> a = ti.Vector([3, 4])
        >>> a.norm()
        5
    """
    from gstaichi.lang.matrix_ops import norm as _norm  # pylint: disable=C0415

    return _norm(self, eps=eps)
592
+
593
def norm_inv(self, eps=0):
    """Return the reciprocal of :func:`~gstaichi.lang.matrix.Matrix.norm`.

    Args:
        eps (float): a safe-guard value for sqrt, usually 0.
    """
    from gstaichi.lang.matrix_ops import norm_inv as _norm_inv  # pylint: disable=C0415

    return _norm_inv(self, eps=eps)
606
+
607
def norm_sqr(self):
    """Return the sum of the absolute squares of the elements."""
    from gstaichi.lang.matrix_ops import norm_sqr as _norm_sqr  # pylint: disable=C0415

    return _norm_sqr(self)
613
+
614
def max(self):
    """Return the largest element value."""
    from gstaichi.lang.matrix_ops import max as _max  # pylint: disable=C0415

    return _max(self)
620
+
621
def min(self):
    """Return the smallest element value."""
    from gstaichi.lang.matrix_ops import min as _min  # pylint: disable=C0415

    return _min(self)
627
+
628
def any(self):
    """Test whether at least one element is non-zero.

    Returns:
        bool: ``True`` if any element is non-zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.any()
        True
    """
    from gstaichi.lang.matrix_ops import any as _any  # pylint: disable=C0415

    return _any(self)
644
+
645
def all(self):
    """Test whether every element is non-zero.

    Returns:
        bool: ``True`` if all elements are non-zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.all()
        False
    """
    from gstaichi.lang.matrix_ops import all as _all  # pylint: disable=C0415

    return _all(self)
661
+
662
def fill(self, val):
    """Set every element of the matrix to ``val``.

    Args:
        val (Union[int, float]): the fill value.

    Example::

        >>> A = ti.Matrix([0, 1, 2, 3])
        >>> A.fill(-1)
        >>> A
        [-1, -1, -1, -1]
    """
    from gstaichi.lang.matrix_ops import fill as _fill  # pylint: disable=C0415

    return _fill(self, val)
679
+
680
def to_numpy(self):
    """Convert this matrix to a numpy array.

    Returns:
        numpy.ndarray: for host-access matrices, a new array built from
        ``to_list()``. Otherwise the internal ``entries`` array itself is
        returned (not a copy), so mutating the result also mutates the matrix.

    Example::

        >>> A = ti.Matrix([[0], [1], [2], [3]])
        >>> A.to_numpy()
        array([[0], [1], [2], [3]])
    """
    if self.is_host_access:
        return np.array(self.to_list())
    # NOTE: returns the backing storage; callers that mutate it should copy first.
    return self.entries
696
+
697
@gstaichi_scope
def __ti_repr__(self):
    """Yield the pieces used to print this matrix inside a kernel.

    Rows are separated by ``", "``; when ``m != 1`` each row is additionally
    wrapped in its own brackets.
    """
    # NOTE(review): elements are fetched with ``self(i, j)``, which relies on a
    # ``__call__`` not visible in this part of the class — confirm it exists.
    yield "["
    for i in range(self.n):
        if i > 0:
            yield ", "
        bracket_row = self.m != 1
        if bracket_row:
            yield "["
        for j in range(self.m):
            if j > 0:
                yield ", "
            yield self(i, j)
        if bracket_row:
            yield "]"
    yield "]"
712
+
713
def __str__(self):
    """Python scope matrix print support."""
    if not impl.inside_kernel():
        return str(self.to_numpy())
    # When pybind11 hits an argument type mismatch it invokes ``repr``/``str``
    # on the object to build the error message, e.g.
    #   TypeError: make_const_expr_f32(): incompatible function arguments. ...
    #   Invoked with: <GsTaichi 2x1 Matrix>
    # so inside a kernel return a placeholder instead of touching the entries.
    return f"<{self.n}x{self.m} ti.Matrix>"
729
+
730
def __repr__(self):
    """Return the string form of ``self.to_numpy()``."""
    as_array = self.to_numpy()
    return str(as_array)
732
+
733
@staticmethod
@gstaichi_scope
def zero(dt, n, m=None):
    """Construct a vector (``m is None``) or an ``n x m`` matrix filled with zeros.

    Args:
        dt (DataType): the element data type.
        n (int): the number of rows (vector length).
        m (int, optional): the number of columns.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: the zero-filled result.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 0)
    return matrix_ops._filled_vector(n, dt, 0)
752
+
753
@staticmethod
@gstaichi_scope
def one(dt, n, m=None):
    """Construct a vector (``m is None``) or an ``n x m`` matrix filled with ones.

    Args:
        dt (DataType): the element data type.
        n (int): the number of rows (vector length).
        m (int, optional): the number of columns.

    Returns:
        :class:`~gstaichi.lang.matrix.Matrix`: the one-filled result.
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 1)
    return matrix_ops._filled_vector(n, dt, 1)
772
+
773
@staticmethod
@gstaichi_scope
def unit(n, i, dt=None):
    """Construct a length-``n`` vector whose ``i``-th entry is one and the rest zeros.

    Args:
        n (int): the length of the vector.
        i (int): the index of the entry set to one; must satisfy ``0 <= i < n``.
        dt (:mod:`~gstaichi.types.primitive_types`, optional): the element data
            type; ``int`` when omitted.

    Example::

        >>> A = ti.Matrix.unit(3, 1)
        >>> A
        [0, 1, 0]
    """
    from gstaichi.lang import matrix_ops  # pylint: disable=C0415

    dt = int if dt is None else dt
    assert 0 <= i < n
    return matrix_ops._unit_vector(n, i, dt)
799
+
800
@staticmethod
@gstaichi_scope
def identity(dt, n):
    """Construct the ``n x n`` identity matrix with element type ``dt``.

    Args:
        dt (DataType): the element data type.
        n (int): the number of rows/columns.
    """
    from gstaichi.lang.matrix_ops import _identity_matrix  # pylint: disable=C0415

    return _identity_matrix(n, dt)
815
+
816
@classmethod
@python_scope
def field(
    cls,
    n,
    m,
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
    layout=Layout.AOS,
    ndim=None,
):
    """Construct a data container to hold all elements of the Matrix.

    Args:
        n (int): The desired number of rows of the Matrix.
        m (int): The desired number of columns of the Matrix.
        dtype (Union[DataType, list, tuple, np.ndarray]): The data type of the
            Matrix elements (required). A list/tuple/ndarray of data types gives
            each element its own type; its shape must be ``(n,)`` when ``m == 1``
            and ``(n, m)`` otherwise.
        shape (Union[int, tuple of int], optional): The desired shape of the field.
            When None, the members are created and registered but not placed.
        order (str, optional): order of the shape laid out in memory. Each
            character names an axis: ``'i'`` is axis 0, ``'j'`` axis 1, and so on.
        name (string, optional): The custom name of the field.
        offset (Union[int, tuple of int], optional): The coordinate offset
            of all elements in a field.
        needs_grad (bool, optional): Whether the Matrix need grad field (reverse mode autodiff).
        needs_dual (bool, optional): Whether the Matrix need dual field (forward mode autodiff).
        layout (Layout, optional): The field layout, either Array Of
            Structure (AOS) or Structure Of Array (SOA).
        ndim (int, optional): Element dimensionality recorded on the resulting
            MatrixField; 2 when not given.

    Returns:
        :class:`~gstaichi.lang.matrix.MatrixField`: The created matrix field.

    Raises:
        GsTaichiSyntaxError: If ``offset`` or ``order`` is given without ``shape``,
            if their lengths do not match ``shape``, or if ``order`` repeats an
            axis or names an axis outside the shape.
    """
    entries = []
    element_dim = ndim if ndim is not None else 2
    if isinstance(dtype, (list, tuple, np.ndarray)):
        # set different dtype for each element in Matrix
        # see #2135
        if m == 1:
            assert (
                len(np.shape(dtype)) == 1 and len(dtype) == n
            ), f"Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}"
            for i in range(n):
                entries.append(
                    impl.create_field_member(
                        dtype[i],
                        name=name,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                )
        else:
            assert (
                len(np.shape(dtype)) == 2 and len(dtype) == n and len(dtype[0]) == m
            ), f"Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}"
            for i in range(n):
                for j in range(m):
                    entries.append(
                        impl.create_field_member(
                            dtype[i][j],
                            name=name,
                            needs_grad=needs_grad,
                            needs_dual=needs_dual,
                        )
                    )
    else:
        for _ in range(n * m):
            entries.append(impl.create_field_member(dtype, name=name, needs_grad=needs_grad, needs_dual=needs_dual))
    # Each create_field_member result unpacks as a (value, grad, dual) triple;
    # regroup them into three parallel tuples.
    entries, entries_grad, entries_dual = zip(*entries)

    entries = MatrixField(entries, n, m, element_dim)
    # Companion grad/dual fields are built only when every member produced one.
    if all(entries_grad):
        entries_grad = MatrixField(entries_grad, n, m, element_dim)
        entries._set_grad(entries_grad)
    if all(entries_dual):
        entries_dual = MatrixField(entries_dual, n, m, element_dim)
        entries._set_dual(entries_dual)

    impl.get_runtime().matrix_fields.append(entries)

    if shape is None:
        # Unplaced field: offset/order only make sense together with a shape.
        if offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise GsTaichiSyntaxError("shape cannot be None when order is set")
    else:
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        if isinstance(offset, numbers.Number):
            offset = (offset,)
        dim = len(shape)
        if offset is not None and dim != len(offset):
            raise GsTaichiSyntaxError(
                f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
            )
        axis_seq = []
        shape_seq = []
        if order is not None:
            if dim != len(order):
                raise GsTaichiSyntaxError(
                    f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
                )
            if dim != len(set(order)):
                raise GsTaichiSyntaxError("The axes in order must be different")
            # Translate axis letters ('i', 'j', ...) into axis indices and
            # permute the extents accordingly.
            for ch in order:
                axis = ord(ch) - ord("i")
                if axis < 0 or axis >= dim:
                    raise GsTaichiSyntaxError(f"Invalid axis {ch}")
                axis_seq.append(axis)
                shape_seq.append(shape[axis])
        else:
            axis_seq = list(range(dim))
            shape_seq = list(shape)
        # Passed to _create_snode; presumably requests all axes on one SNode
        # level when no explicit order is given — confirm against impl.
        same_level = order is None
        # NOTE(review): the grad/dual placement below assumes entries_grad /
        # entries_dual were converted to MatrixField above (i.e. every member
        # produced one) whenever needs_grad / needs_dual is set — confirm.
        if layout == Layout.SOA:
            # SOA: every scalar member gets its own SNode.
            for e in entries._get_field_members():
                impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_grad:
                for e in entries_grad._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_dual:
                for e in entries_dual._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
        else:
            # AOS: the whole matrix is placed on a single SNode.
            impl._create_snode(axis_seq, shape_seq, same_level).place(entries, offset=offset)
            if needs_grad:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_grad, offset=offset)
            if needs_dual:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_dual, offset=offset)
    return entries
@classmethod
@python_scope
def ndarray(cls, n, m, dtype, shape):
    """Create a GsTaichi ndarray whose elements are ``n x m`` matrices.

    Must be called from Python scope, after `ti.init` has run.

    Args:
        n (int): Number of rows of each matrix element.
        m (int): Number of columns of each matrix element.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray; a bare integer is
            treated as a 1-D shape.

    Example::

        The code below declares an ndarray of 4x5 matrices::

            >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
    """
    array_shape = (shape,) if isinstance(shape, numbers.Number) else shape
    return MatrixNdarray(n, m, dtype, array_shape)
@staticmethod
def rows(rows):
    """Stack vectors or lists row by row into a matrix.

    Must be called in GsTaichi scope.

    Args:
        rows (List): A list of Vector (1-D Matrix) or a list of list; each item
            becomes one row of the result.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.rows([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 2, 3], [4, 5, 6]]
    """
    from gstaichi.lang import matrix_ops as ops  # pylint: disable=C0415

    stacked = ops.rows(rows)
    return stacked
@staticmethod
def cols(cols):
    """Stack vectors or lists column by column into a matrix.

    Args:
        cols (List): A list of Vector (1-D Matrix) or a list of list; each item
            becomes one column of the result.

    Returns:
        :class:`~gstaichi.Matrix`: A matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.cols([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 4], [2, 5], [3, 6]]
    """
    from gstaichi.lang import matrix_ops as ops  # pylint: disable=C0415

    stacked = ops.cols(cols)
    return stacked
+ def __hash__(self):
1026
+ # TODO: refactor KernelTemplateMapper
1027
+ # If not, we get `unhashable type: Matrix` when
1028
+ # using matrices as template arguments.
1029
+ return id(self)
1030
+
1031
def dot(self, other):
    """Return the dot (inner) product of this vector with ``other``.

    Both operands must be vectors.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        DataType: The scalar dot product of the two Vectors.

    Example::

        >>> v1 = ti.Vector([1, 2, 3])
        >>> v2 = ti.Vector([3, 4, 5])
        >>> v1.dot(v2)
        26
    """
    from gstaichi.lang import matrix_ops as ops  # pylint: disable=C0415

    result = ops.dot(self, other)
    return result
def cross(self, other):
    """Return the cross product of this vector with ``other``.

    Both vectors must have the same dimension, which may be at most 3.

    * Two 2-D vectors ``(x1, y1)`` and ``(x2, y2)`` yield the scalar
      ``x1*y2 - x2*y1``.
    * Two 3-D vectors ``v`` and ``w`` yield the 3-D vector ``v x w``.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The cross product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops as ops  # pylint: disable=C0415

    result = ops.cross(self, other)
    return result
def outer_product(self, other):
    """Return the outer product of this vector with ``other``.

    For ``v = (x1, ..., xn)`` and ``w = (y1, ..., yn)`` the result is the
    ``n x n`` matrix whose ``(i, j)`` entry is ``xi*yj``.

    Args:
        other (:class:`~gstaichi.Matrix`): The input Vector.

    Returns:
        :class:`~gstaichi.Matrix`: The outer product of the two Vectors.
    """
    from gstaichi.lang import matrix_ops as ops  # pylint: disable=C0415

    result = ops.outer_product(self, other)
    return result
class Vector(Matrix):
    """A single-column :class:`Matrix` constructed from a flat sequence of values."""

    def __init__(self, arr, dt=None, **kwargs):
        """Constructs a vector from given array.

        A vector is an instance of a 2-D matrix with the second dimension being equal to 1.

        Args:
            arr (Union[list, tuple, np.ndarray]): The initial values of the Vector.
            dt (:mod:`~gstaichi.types.primitive_types`): data type of the vector.

        Returns:
            :class:`~gstaichi.Matrix`: A vector instance.

        Example::

            >>> u = ti.Vector([1, 2])
            >>> print(u.n, u.m)  # verify a vector is a matrix of shape (n, 1)
            2 1
            >>> v = ti.Vector([3, 4])
            >>> u + v
            [4 6]
        """
        super().__init__(arr, dt=dt, **kwargs)

    def get_shape(self):
        """Return the vector's shape as the 1-tuple ``(n,)``."""
        return (self.n,)

    @classmethod
    def field(cls, n, dtype, *args, **kwargs):
        """Declare a field whose elements are ``n``-component vectors (``ti.Vector.field``).

        This forwards to :meth:`Matrix.field` with one column and ``ndim=1``.
        All other positional and keyword arguments are passed through
        unchanged. An explicit ``ndim`` other than 1 is rejected.

        Args:
            n (int): Number of components of each vector.
            dtype (DataType): Data type of the components.

        Returns:
            MatrixField: The created vector field.
        """
        ndim = kwargs.get("ndim", 1)
        assert ndim == 1
        kwargs["ndim"] = 1
        return super().field(n, 1, dtype, *args, **kwargs)

    @classmethod
    @python_scope
    def ndarray(cls, n, dtype, shape):
        """Defines a GsTaichi ndarray with vector elements.

        Args:
            n (int): Size of the vector.
            dtype (DataType): Data type of each value.
            shape (Union[int, tuple[int]]): Shape of the ndarray; a bare integer
                is treated as a 1-D shape.

        Example:
            The code below shows how a GsTaichi ndarray with vector elements can be declared and defined::

                >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
        """
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        return VectorNdarray(n, dtype, shape)
class MatrixField(Field):
    """GsTaichi matrix field with SNode implementation.

    Args:
        _vars (List[Expr]): Field members, ``n * m`` of them in row-major order.
        n (Int): Number of rows.
        m (Int): Number of columns.
        ndim (Int): Number of element dimensions (0, 1 or 2); forced reshape if given.
    """

    def __init__(self, _vars, n, m, ndim=2):
        assert len(_vars) == n * m
        assert ndim in (0, 1, 2)
        super().__init__(_vars)
        self.n = n
        self.m = m
        self.ndim = ndim
        # Only the leading `ndim` extents are handed to the backend, so a vector
        # field (ndim == 1) is described by [n] rather than [n, m].
        self.ptr = ti_python_core.expr_matrix_field([var.ptr for var in self.vars], [n, m][:ndim])

    def get_scalar_field(self, *indices):
        """Creates a ScalarField using a specific field member.

        Args:
            indices (Tuple[Int]): Row index, optionally followed by a column
                index (defaults to column 0).

        Returns:
            ScalarField: The result ScalarField.
        """
        assert len(indices) in [1, 2]
        i = indices[0]
        j = 0 if len(indices) == 1 else indices[1]
        # Members are stored row-major.
        return ScalarField(self.vars[i * self.m + j])

    def _get_dynamic_index_stride(self):
        """Return the recorded member stride, or None if the field is not dynamically indexable."""
        if self.ptr.get_dynamic_indexable():
            return self.ptr.get_dynamic_index_stride()
        return None

    def _calc_dynamic_index_stride(self):
        """Compute and record the byte stride between consecutive members, when one exists.

        A stride is recorded only if all of the following hold:

        * every member path has the same depth;
        * no leaf is quantized;
        * every level from the first point of divergence down is a ``dense``
          SNode with matching cell sizes and child offsets;
        * consecutive members are evenly spaced.

        Otherwise nothing is recorded. A single-member field gets stride 0.
        """
        # Algorithm: https://github.com/taichi-dev/gstaichi/issues/3810
        paths = [ScalarField(var).snode._path_from_root() for var in self.vars]
        num_members = len(paths)
        if num_members == 1:
            self.ptr.set_dynamic_index_stride(0)
            return
        length = len(paths[0])
        if any(len(path) != length or ti_python_core.is_quant(path[length - 1]._dtype) for path in paths):
            return
        # Locate the first level at which the member paths diverge (just below their LCA).
        for i in range(length):
            if any(path[i] != paths[0][i] for path in paths):
                depth_below_lca = i
                break
        else:
            # All paths are identical, so there is no divergence level to derive a
            # stride from; without this guard `depth_below_lca` would be unbound below.
            return
        for i in range(depth_below_lca, length - 1):
            if any(
                path[i].ptr.type != ti_python_core.SNodeType.dense
                or path[i]._cell_size_bytes != paths[0][i]._cell_size_bytes
                or path[i + 1]._offset_bytes_in_parent_cell != paths[0][i + 1]._offset_bytes_in_parent_cell
                for path in paths
            ):
                return
        stride = (
            paths[1][depth_below_lca]._offset_bytes_in_parent_cell
            - paths[0][depth_below_lca]._offset_bytes_in_parent_cell
        )
        # Every consecutive pair must be spaced by the same amount.
        for i in range(2, num_members):
            if (
                stride
                != paths[i][depth_below_lca]._offset_bytes_in_parent_cell
                - paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell
            ):
                return
        self.ptr.set_dynamic_index_stride(stride)

    def fill(self, val):
        """Fills this matrix field with specified values.

        Args:
            val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill. A
                scalar is broadcast to every entry. A tensor expression, list,
                tuple or Matrix must have dimensions consistent with `self`.
        """
        if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr) and not val.is_tensor()):
            # Broadcast the scalar into a nested tuple matching the element shape.
            if self.ndim == 2:
                val = tuple(tuple(val for _ in range(self.m)) for _ in range(self.n))
            else:
                assert self.ndim == 1
                val = tuple(val for _ in range(self.n))
        elif isinstance(val, expr.Expr) and val.is_tensor():
            assert val.n == self.n
            if self.ndim != 1:
                assert val.m == self.m
        else:
            if isinstance(val, Matrix):
                val = val.to_list()
            assert isinstance(val, (list, tuple))
            val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
            assert len(val) == self.n
            if self.ndim != 1:
                assert len(val[0]) == self.m
        if in_python_scope():
            from gstaichi._kernels import (  # pylint: disable=C0415
                field_fill_python_scope,  # pylint: disable=C0415
            )

            field_fill_python_scope(self, val)
        else:
            from gstaichi._funcs import (  # pylint: disable=C0415
                field_fill_gstaichi_scope,  # pylint: disable=C0415
            )

            field_fill_gstaichi_scope(self, val)

    @python_scope
    def to_numpy(self, keep_dims=False, dtype=None):
        """Converts the field instance to a NumPy array.

        Args:
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                When keep_dims=True, on an n-D matrix field, the numpy array always has n+2 dims, even for 1x1, 1xn, nx1 matrix fields.
                When keep_dims=False, only the column axis is dropped, and only when it has size 1 (m == 1).
                For example, a 4x1 matrix field with 5x6x7 elements results in an array of shape 5x6x7x4,
                while a 1x4 matrix field keeps both matrix axes (5x6x7x1x4).
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result NumPy array.
        """
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        arr = np.zeros(self.shape + shape_ext, dtype=dtype)
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    def to_torch(self, device=None, keep_dims=False):
        """Converts the field instance to a PyTorch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            torch.tensor: The result torch tensor.
        """
        import torch  # pylint: disable=C0415

        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        # pylint: disable=E1101
        arr = torch.empty(self.shape + shape_ext, dtype=to_pytorch_type(self.dtype), device=device)
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    def to_paddle(self, place=None, keep_dims=False):
        """Converts the field instance to a Paddle tensor.

        Args:
            place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            paddle.Tensor: The result paddle tensor.
        """
        import paddle  # pylint: disable=C0415

        # NOTE(review): unlike to_numpy/to_torch, this also requires ndim == 1
        # before dropping the size-1 column axis — confirm whether intended.
        as_vector = self.m == 1 and not keep_dims and self.ndim == 1
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        # pylint: disable=E1101
        # paddle.empty() doesn't support argument `place``
        arr = paddle.to_tensor(
            paddle.empty(self.shape + shape_ext, to_paddle_type(self.dtype)),
            place=place,
        )
        from gstaichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copy an external array into this field.

        The array must have one extra trailing axis (vector fields only) or two
        extra trailing axes compared to the field shape.
        """
        if len(arr.shape) == len(self.shape) + 1:
            as_vector = True
            assert self.m == 1, "This is not a vector field"
        else:
            as_vector = False
            assert len(arr.shape) == len(self.shape) + 2
        dim_ext = 1 if as_vector else 2
        assert len(arr.shape) == len(self.shape) + dim_ext
        from gstaichi._kernels import ext_arr_to_matrix  # pylint: disable=C0415

        ext_arr_to_matrix(arr, self, as_vector)
        runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies an `numpy.ndarray` into this field.

        Non-C-contiguous input is copied into a contiguous array first.

        Example::

            >>> m = ti.Matrix.field(2, 2, ti.f32, shape=(3, 3))
            >>> arr = numpy.ones((3, 3, 2, 2))
            >>> m.from_numpy(arr)
        """

        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessors()
        self[key]._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessors()
        key = self._pad_key(key)
        _host_access = self._host_access(key)
        if self.ndim == 1:
            return Vector([_host_access[i] for i in range(self.n)])
        return Matrix([[_host_access[i * self.m + j] for j in range(self.m)] for i in range(self.n)])

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return f"<{self.n}x{self.m} ti.Matrix.field>"
class MatrixType(CompoundType):
    """Compound type describing ``n x m`` matrices (``ndim == 2``) or ``n``-vectors.

    Args:
        n (int): Number of rows.
        m (int): Number of columns.
        ndim (int): 2 for a matrix type; any other value gives the tensor shape ``(n,)``.
        dtype (DataType, optional): Element data type. ``None`` leaves both
            ``dtype`` and ``tensor_type`` unset.
    """

    def __init__(self, n, m, ndim, dtype):
        self.n = n
        self.m = m
        self.ndim = ndim
        # FIXME(haidong): dtypes should not be left empty for ndarray.
        # Remove the None dtype when we are ready to break legacy code.
        if dtype is not None:
            self.dtype = cook_dtype(dtype)
            shape = (n, m) if ndim == 2 else (n,)
            self.tensor_type = _type_factory.get_tensor_type(shape, self.dtype)
        else:
            self.dtype = None
            self.tensor_type = None

    def __call__(self, *args):
        """Return a matrix matching the shape and dtype.

        This function will try to convert the input to a `n x m` matrix, with n, m being
        the number of rows/cols of this matrix type.

        Example::

            >>> mat4x3 = MatrixType(4, 3, 2, float)
            >>> mat2x6 = MatrixType(2, 6, 2, float)

            Create from n x m scalars, or a 1d list of n x m scalars:

            >>> m = mat4x3([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
            >>> m = mat4x3(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)

            Create from n vectors/lists, with each one of dimension m:

            >>> m = mat4x3([1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12])

            Create from a single scalar

            >>> m = mat4x3(1)

            Create from another 2d list/matrix, as long as they have the same number of entries

            >>> m = mat4x3([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
            >>> m = mat4x3(m)
            >>> k = mat2x6(m)

        """
        if len(args) == 0:
            raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")
        if len(args) == 1:
            # Init from a real Matrix
            if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
                arg = args[0]
                shape = arg.ptr.get_rvalue_type().shape()
                assert self.ndim == len(shape)
                assert self.n == shape[0]
                if self.ndim > 1:
                    assert self.m == shape[1]
                return expr.Expr(arg.ptr)

            # initialize by a single scalar, e.g. matnxm(1)
            if isinstance(args[0], (numbers.Number, expr.Expr)):
                entries = [args[0] for _ in range(self.m) for _ in range(self.n)]
                return self._instantiate(entries)
            args = args[0]
        # collect all input entries to a 1d list and then reshape
        # this is mostly for glsl style like vec4(v.xyz, 1.)
        entries = []
        for x in args:
            if isinstance(x, (list, tuple)):
                entries += x
            elif isinstance(x, np.ndarray):
                entries += list(x.ravel())
            elif isinstance(x, Matrix):
                entries += x.to_list()
            else:
                entries.append(x)

        return self._instantiate(entries)

    def from_gstaichi_object(self, func_ret, ret_index=()):
        """Rebuild a value of this type from the ``n * m`` elements of a function's return expression."""
        return self(
            [
                expr.Expr(
                    ti_python_core.make_get_element_expr(
                        func_ret.ptr,
                        ret_index + (i,),
                        _ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
                for i in range(self.m * self.n)
            ]
        )

    def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
        """Read ``n * m`` scalars of a kernel's struct return, using the getter matching ``dtype``."""
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                get_ret_func = launch_ctx.get_struct_ret_int
            else:
                get_ret_func = launch_ctx.get_struct_ret_uint
        elif self.dtype in primitive_types.real_types:
            get_ret_func = launch_ctx.get_struct_ret_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        return self([get_ret_func(ret_index + (i,)) for i in range(self.m * self.n)])

    def set_kernel_struct_args(self, mat, launch_ctx, ret_index=()):
        """Write ``mat``'s entries as kernel struct arguments, flattened row-major."""
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                set_arg_func = launch_ctx.set_struct_arg_int
            else:
                set_arg_func = launch_ctx.set_struct_arg_uint
        elif self.dtype in primitive_types.real_types:
            set_arg_func = launch_ctx.set_struct_arg_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        if self.ndim == 1:
            for i in range(self.n):
                set_arg_func(ret_index + (i,), mat[i])
        else:
            for i in range(self.n):
                for j in range(self.m):
                    set_arg_func(ret_index + (i * self.m + j,), mat[i, j])

    def set_argpack_struct_args(self, mat, argpack, ret_index=()):
        """Write ``mat``'s entries into an argument pack, flattened row-major."""
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                set_arg_func = argpack.set_arg_int
            else:
                set_arg_func = argpack.set_arg_uint
        elif self.dtype in primitive_types.real_types:
            set_arg_func = argpack.set_arg_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        if self.ndim == 1:
            for i in range(self.n):
                set_arg_func(ret_index + (i,), mat[i])
        else:
            for i in range(self.n):
                for j in range(self.m):
                    set_arg_func(ret_index + (i * self.m + j,), mat[i, j])

    def _instantiate_in_python_scope(self, entries):
        # Reshape the flat row-major entries into rows and cast each to int/float.
        entries = [[entries[k * self.m + i] for i in range(self.m)] for k in range(self.n)]
        return Matrix(
            [
                [
                    int(entries[i][j]) if self.dtype in primitive_types.integer_types else float(entries[i][j])
                    for j in range(self.m)
                ]
                for i in range(self.n)
            ],
            dt=self.dtype,
        )

    def _instantiate(self, entries):
        if in_python_scope():
            return self._instantiate_in_python_scope(entries)

        return make_matrix_with_shape(entries, [self.n, self.m], self.dtype)

    def field(self, **kwargs):
        """Create a matrix field whose elements have this type; ``ndim`` must match ``self.ndim``."""
        assert kwargs.get("ndim", self.ndim) == self.ndim
        kwargs.update({"ndim": self.ndim})
        return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create a matrix ndarray whose elements have this type.

        ``ndim`` may be given for symmetry with :meth:`field`, but it must equal
        ``self.ndim``. It is consumed here because ``Matrix.ndarray`` does not
        accept an ``ndim`` argument.
        """
        # Pop outside the assert so the removal still happens under `python -O`.
        ndim = kwargs.pop("ndim", self.ndim)
        assert ndim == self.ndim
        return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)

    def get_shape(self):
        if self.ndim == 1:
            return (self.n,)
        return (self.n, self.m)

    def to_string(self):
        dtype_str = self.dtype.to_string() if self.dtype is not None else ""
        return f"MatrixType[{self.n},{self.m}, {dtype_str}]"

    def check_matched(self, other):
        """Return True if ``other`` (a tensor type) has this type's rank, dtype (if set) and extents."""
        if self.ndim != len(other.shape()):
            return False
        if self.dtype is not None and self.dtype != other.element_type():
            return False
        shape = self.get_shape()
        for i in range(self.ndim):
            if shape[i] is not None and shape[i] != other.shape()[i]:
                return False
        return True
class VectorType(MatrixType):
    """Compound type for ``n``-component vectors: a :class:`MatrixType` with one column and ``ndim == 1``."""

    def __init__(self, n, dtype):
        super().__init__(n, 1, 1, dtype)

    def __call__(self, *args):
        """Return a vector matching the shape and dtype.

        The input is converted to an `n`-component vector.

        Example::

            >>> vec3 = VectorType(3, float)

            Create from n scalars:

            >>> v = vec3(1, 2, 3)

            Create from a list/tuple of n scalars:

            >>> v = vec3([1, 2, 3])

            Create from a single scalar

            >>> v = vec3(1)

        """
        if not args:
            raise GsTaichiSyntaxError("Custom type instances need to be created with an initial value.")

        if len(args) == 1:
            only = args[0]
            # An existing 1-D tensor expression is passed through after a shape check.
            if isinstance(only, expr.Expr) and only.ptr.is_tensor():
                tensor_shape = only.ptr.get_rvalue_type().shape()
                assert len(tensor_shape) == 1
                assert self.n == tensor_shape[0]
                return expr.Expr(only.ptr)
            # A lone scalar is broadcast to every component, e.g. vec3(1).
            if isinstance(only, (numbers.Number, expr.Expr)):
                return self._instantiate([only] * self.n)
            args = only

        # Flatten glsl-style mixed arguments such as vec4(v.xyz, 1.) into one list.
        flat = []
        for item in args:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            elif isinstance(item, np.ndarray):
                flat.extend(item.ravel())
            elif isinstance(item, Matrix):
                flat.extend(item.to_list())
            else:
                flat.append(item)

        # type cast
        return self._instantiate(flat)

    def _instantiate_in_python_scope(self, entries):
        # Integer dtypes get int components; everything else (including None) gets float.
        cast = int if self.dtype in primitive_types.integer_types else float
        return Vector([cast(entries[i]) for i in range(self.n)], dt=self.dtype)

    def _instantiate(self, entries):
        if not in_python_scope():
            return make_matrix_with_shape(entries, [self.n], self.dtype)
        return self._instantiate_in_python_scope(entries)

    def field(self, **kwargs):
        """Create a vector field whose elements have this type."""
        return Vector.field(self.n, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create a vector ndarray whose elements have this type."""
        return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

    def to_string(self):
        dtype_text = "" if self.dtype is None else self.dtype.to_string()
        return f"VectorType[{self.n}, {dtype_text}]"
+ class MatrixNdarray(Ndarray):
1656
+ """GsTaichi ndarray with matrix elements.
1657
+
1658
+ Args:
1659
+ n (int): Number of rows of the matrix.
1660
+ m (int): Number of columns of the matrix.
1661
+ dtype (DataType): Data type of each value.
1662
+ shape (Union[int, tuple[int]]): Shape of the ndarray.
1663
+
1664
+ Example::
1665
+
1666
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
1667
+ """
1668
+
1669
+ def __init__(self, n, m, dtype, shape):
1670
+ self.n = n
1671
+ self.m = m
1672
+ super().__init__()
1673
+ # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
1674
+ self.dtype = cook_dtype(dtype)
1675
+
1676
+ self.layout = Layout.AOS
1677
+ self.shape = tuple(shape)
1678
+ self.element_type = _type_factory.get_tensor_type((self.n, self.m), self.dtype)
1679
+ # TODO: we should pass in element_type, shape, layout instead.
1680
+ self.arr = impl.get_runtime().prog.create_ndarray(
1681
+ cook_dtype(self.element_type),
1682
+ shape,
1683
+ Layout.AOS,
1684
+ zero_fill=True,
1685
+ dbg_info=ti_python_core.DebugInfo(get_traceback()),
1686
+ )
1687
+
1688
+ @property
1689
+ def element_shape(self):
1690
+ """Returns the shape of each element (a 2D matrix) in this ndarray.
1691
+
1692
+ Example::
1693
+
1694
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
1695
+ >>> arr.element_shape
1696
+ (2, 2)
1697
+ """
1698
+ return tuple(self.arr.element_shape())
1699
+
1700
+ @python_scope
1701
+ def __setitem__(self, key, value):
1702
+ if not isinstance(value, (list, tuple)):
1703
+ value = list(value)
1704
+ if not isinstance(value[0], (list, tuple)):
1705
+ value = [[i] for i in value]
1706
+ for i in range(self.n):
1707
+ for j in range(self.m):
1708
+ self[key][i, j] = value[i][j]
1709
+
1710
+ @python_scope
1711
+ def __getitem__(self, key):
1712
+ key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
1713
+ return Matrix([[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)] for i in range(self.n)])
1714
+
1715
+ @python_scope
1716
+ def to_numpy(self):
1717
+ """Converts this ndarray to a `numpy.ndarray`.
1718
+
1719
+ Example::
1720
+
1721
+ >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
1722
+ >>> arr.to_numpy()
1723
+ [[[[0. 0.]
1724
+ [0. 0.]]]
1725
+
1726
+ [[[0. 0.]
1727
+ [0. 0.]]]]
1728
+ """
1729
+ return self._ndarray_matrix_to_numpy(as_vector=0)
1730
+
1731
@python_scope
def from_numpy(self, arr):
    """Copy the data of a `numpy.ndarray` into this matrix ndarray.

    Args:
        arr (numpy.ndarray): Source data, laid out as the ndarray shape
            followed by the element shape (``(2, 1, 2, 2)`` in the example).

    Example::

        >>> import numpy as np
        >>> m = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
        >>> arr = np.ones((2, 1, 2, 2), dtype=np.float32)
        >>> m.from_numpy(arr)
    """
    # Matrix ndarrays pass as_vector=0, mirroring to_numpy; VectorNdarray passes 1.
    self._ndarray_matrix_from_numpy(arr, as_vector=0)
1742
+
1743
@python_scope
def __deepcopy__(self, memo=None):
    """Return a new ``MatrixNdarray`` with the same n, m, dtype, shape and data.

    ``memo`` is accepted for the ``copy.deepcopy`` protocol but is not used.
    """
    duplicate = MatrixNdarray(self.n, self.m, self.dtype, self.shape)
    # Copy the element data from this ndarray into the freshly allocated one.
    duplicate.copy_from(self)
    return duplicate
1748
+
1749
@python_scope
def _fill_by_kernel(self, val):
    """Fill every element of this ndarray with ``val`` via a fill kernel.

    ``val`` is used as-is when it is already a ``Matrix``; otherwise it is
    converted with a ``MatrixType`` built from this ndarray's element type.
    """
    from gstaichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

    elem_shape = self.element_type.shape()
    rows = elem_shape[0]
    # A one-axis element shape is treated as a single column.
    cols = elem_shape[1] if len(elem_shape) > 1 else 1

    scalar_type = self.element_type.element_type()
    converter = MatrixType(rows, cols, len(elem_shape), scalar_type)
    fill_value = val if isinstance(val, Matrix) else converter(val)
    fill_ndarray_matrix(self, fill_value)
1766
+
1767
@python_scope
def __repr__(self):
    """Describe this ndarray by element dimensions and layout."""
    dims = f"{self.n}x{self.m}"
    return f"<{dims} {Layout.AOS} ti.Matrix.ndarray>"
1770
+
1771
+
1772
class VectorNdarray(Ndarray):
    """GsTaichi ndarray with vector elements.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the ndarray.

    Example::

        >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
    """

    def __init__(self, n, dtype, shape):
        # self.n is assigned before the base-class initializer; the original
        # ordering is preserved.
        self.n = n
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        self.shape = tuple(shape)
        # Each element is a tensor of shape (n,) of the primitive dtype.
        self.element_type = _type_factory.get_tensor_type((n,), self.dtype)
        self.arr = impl.get_runtime().prog.create_ndarray(
            cook_dtype(self.element_type),
            shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Gets the dimension of the vector of this ndarray.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
            >>> a.element_shape
            (3,)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Assign the vector element at ``key`` from a sequence of ``n`` values.

        Args:
            key: Index into the ndarray, resolved through ``self[key]``.
            value: Any iterable; components ``0..n-1`` are written in order.
        """
        if not isinstance(value, (list, tuple)):
            value = list(value)
        # Build the host-access proxy vector once rather than once per
        # component; each assignment writes through to the ndarray.
        target = self[key]
        for i in range(self.n):
            target[i] = value[i]

    @python_scope
    def __getitem__(self, key):
        """Return a ``Vector`` of host-access proxies for the element at ``key``.

        ``key`` may be ``None`` (empty index), a single number, or an iterable
        of indices.
        """
        key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
        return Vector([NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this vector ndarray to a `numpy.ndarray`.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> a.to_numpy()
            array([[[0., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],
                    [0., 0., 0.]]], dtype=float32)
        """
        return self._ndarray_matrix_to_numpy(as_vector=1)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this ndarray.

        The shape and data type of `arr` must match this ndarray.

        Example::

            >>> import numpy as np
            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> b = np.ones((2, 2, 3), dtype=np.float32)
            >>> a.from_numpy(b)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=1)

    @python_scope
    def __deepcopy__(self, memo=None):
        """Return a new ``VectorNdarray`` with the same n, dtype, shape and data.

        ``memo`` is accepted for the ``copy.deepcopy`` protocol but is not used.
        """
        ret_arr = VectorNdarray(self.n, self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    @python_scope
    def _fill_by_kernel(self, val):
        """Fill every element with ``val`` using the matrix fill kernel.

        ``val`` is used as-is when it is already a ``Vector``; otherwise it is
        converted with a ``VectorType`` built from the element type.
        """
        from gstaichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        shape = self.element_type.shape()
        prim_dtype = self.element_type.element_type()
        vector_type = VectorType(shape[0], prim_dtype)
        if isinstance(val, Vector):
            value = val
        else:
            value = vector_type(val)
        fill_ndarray_matrix(self, value)

    @python_scope
    def __repr__(self):
        return f"<{self.n} {Layout.AOS} ti.Vector.ndarray>"
1879
+
1880
+
1881
# Public names exported by this module.
__all__ = [
    "Matrix",
    "Vector",
    "MatrixField",
    "MatrixNdarray",
    "VectorNdarray",
]