gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
taichi/lang/matrix.py ADDED
@@ -0,0 +1,1877 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+ import numbers
5
+ from collections.abc import Iterable
6
+ from itertools import product
7
+
8
+ import numpy as np
9
+
10
+ from taichi._lib import core as ti_python_core
11
+ from taichi._lib.utils import ti_python_core as _ti_python_core
12
+ from taichi.lang import expr, impl, runtime_ops
13
+ from taichi.lang import ops as ops_mod
14
+ from taichi.lang._ndarray import Ndarray, NdarrayHostAccess
15
+ from taichi.lang.common_ops import TaichiOperations
16
+ from taichi.lang.exception import (
17
+ TaichiRuntimeError,
18
+ TaichiRuntimeTypeError,
19
+ TaichiSyntaxError,
20
+ TaichiTypeError,
21
+ )
22
+ from taichi.lang.field import Field, ScalarField, SNodeHostAccess
23
+ from taichi.lang.util import (
24
+ cook_dtype,
25
+ get_traceback,
26
+ in_python_scope,
27
+ python_scope,
28
+ taichi_scope,
29
+ to_numpy_type,
30
+ to_paddle_type,
31
+ to_pytorch_type,
32
+ warning,
33
+ )
34
+ from taichi.types import primitive_types
35
+ from taichi.types.compound_types import CompoundType
36
+ from taichi.types.enums import Layout
37
+ from taichi.types.utils import is_signed
38
+
39
# Type-factory instance fetched once, at module import time, from the core
# bindings exposed via `taichi._lib.utils`.
# NOTE(review): not referenced in the visible part of this module; presumably
# used further down the file — confirm before removing.
_type_factory = _ti_python_core.get_type_factory_instance()
40
+
41
+
42
+ def _generate_swizzle_patterns(key_group: str, required_length=4):
43
+ """Generate vector swizzle patterns from a given set of characters.
44
+
45
+ Example:
46
+
47
+ For `key_group=xyzw` and `required_length=4`, this function will return a
48
+ list consists of all possible strings (no repeats) in characters
49
+ `x`, `y`, `z`, `w` and of length<=4:
50
+ [`x`, `y`, `z`, `w`, `xx`, `xy`, `yx`, ..., `xxxx`, `xxxy`, `xyzw`, ...]
51
+ The length of the list will be 4 + 4x4 + 4x4x4 + 4x4x4x4 = 340.
52
+ """
53
+ result = []
54
+ for k in range(1, required_length + 1):
55
+ result.extend(product(key_group, repeat=k))
56
+ result = ["".join(pat) for pat in result]
57
+ return result
58
+
59
+
60
def _gen_swizzles(cls):
    """Class decorator that installs GLSL-style swizzle properties on ``cls``.

    For each key group ``xyzw``, ``rgba`` and ``stpq`` it installs:

    * one single-component property per character (``v.x``, ``v.g``, ...),
      reading/writing ``instance[index]``;
    * one multi-component property for every pattern of length 2..4 over the
      group (``v.xy``, ``v.rgba``, ``v.xxz``, ...). Reading returns a new
      ``Vector``; writing assigns component by component.

    Each accessor first checks that the pattern only uses components that exist
    for ``instance.n`` and raises ``TaichiSyntaxError`` otherwise. Setters are
    wrapped with ``python_scope``. Two lookup tables are attached to ``cls``:
    ``_swizzle_to_keygroup`` (pattern -> key group) and ``_keygroup_to_checker``
    (key group -> validator).

    Reference: https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling

    Returns:
        ``cls`` itself, so this can be used as a decorator.
    """
    key_groups = ["xyzw", "rgba", "stpq"]
    cls._swizzle_to_keygroup = {}
    cls._keygroup_to_checker = {}

    def build_checker(group):
        # Reject any component letter beyond the first `instance.n` of `group`.
        def validate(instance, pattern):
            allowed = set(group[: instance.n])
            if set(pattern) - allowed:
                raise TaichiSyntaxError(
                    f"vec{instance.n} only has " f"attributes={tuple(sorted(allowed))}, got={tuple(pattern)}"
                )

        return validate

    def single_component(letter, position, group):
        # Factory so every property captures its own letter/position
        # (avoids the late-binding closure pitfall inside the loops below).
        validate = cls._keygroup_to_checker[group]

        def getter(instance):
            validate(instance, letter)
            return instance[position]

        @python_scope
        def setter(instance, value):
            validate(instance, letter)
            instance[position] = value

        return property(getter, setter)

    def multi_component(pattern, group):
        validate = cls._keygroup_to_checker[group]

        def getter(instance):
            validate(instance, pattern)
            return Vector([instance[group.index(letter)] for letter in pattern])

        @python_scope
        def setter(instance, value):
            # The length check happens before component validation.
            if len(pattern) != len(value):
                raise TaichiRuntimeError(f"value len does not match the swizzle pattern={pattern}")
            validate(instance, pattern)
            for letter, component in zip(pattern, value):
                instance[group.index(letter)] = component

        return property(getter, setter)

    # Pass 1: validators plus single-letter accessors for every group.
    for group in key_groups:
        cls._keygroup_to_checker[group] = build_checker(group)
        for position, letter in enumerate(group):
            setattr(cls, letter, single_component(letter, position, group))
            cls._swizzle_to_keygroup[letter] = group

    # Pass 2: multi-letter accessors (length-1 patterns were installed above).
    for group in key_groups:
        for pattern in _generate_swizzle_patterns(group, required_length=4):
            if len(pattern) <= 1:
                continue
            setattr(cls, pattern, multi_component(pattern, group))
            cls._swizzle_to_keygroup[pattern] = group

    return cls
131
+
132
+
133
def _infer_entry_dt(entry):
    """Infer the Taichi data type of a single matrix entry.

    Python/NumPy integers (including ``bool``, an ``int`` subclass) map to the
    runtime's ``default_ip``; Python/NumPy floats map to ``default_fp``. For a
    Taichi ``Expr`` the expression's rvalue type is returned.

    Raises:
        TaichiTypeError: if ``entry`` is an ``Expr`` whose type is unknown, or
            if ``entry`` is of any other kind.
    """
    if isinstance(entry, (int, np.integer)):
        return impl.get_runtime().default_ip
    if isinstance(entry, (float, np.floating)):
        return impl.get_runtime().default_fp
    if not isinstance(entry, expr.Expr):
        raise TaichiTypeError("Element type of the matrix is invalid.")
    rvalue_dt = entry.ptr.get_rvalue_type()
    if rvalue_dt == ti_python_core.DataType_unknown:
        raise TaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
    return rvalue_dt
144
+
145
+
146
def _infer_array_dt(arr):
    """Infer one common data type for a non-empty flat sequence of entries.

    Each entry's type comes from ``_infer_entry_dt``. The types are folded
    left-to-right with ``ti_python_core.promoted_type``, lazily, one entry at a time.
    """
    assert len(arr) > 0
    entry_types = (_infer_entry_dt(entry) for entry in arr)
    return functools.reduce(ti_python_core.promoted_type, entry_types)
149
+
150
+
151
def make_matrix_with_shape(arr, shape, dt):
    """Build a matrix expression of an explicit shape in the callable being compiled.

    Args:
        arr: flat sequence of entries; each is wrapped in ``expr.Expr``.
            Presumably row-major for 2-D shapes (as ``make_matrix`` flattens
            rows) — confirm with callers.
        shape: shape list, passed through unchanged.
        dt: element data type, passed through unchanged (not cooked here).

    Returns:
        expr.Expr: wraps the node created by the AST builder's ``make_matrix_expr``.
    """
    runtime = impl.get_runtime()
    builder = runtime.compiling_callable.ast_builder()
    element_ptrs = [expr.Expr(element).ptr for element in arr]
    debug_info = ti_python_core.DebugInfo(runtime.get_current_src_info())
    return expr.Expr(builder.make_matrix_expr(shape, dt, element_ptrs, debug_info))
162
+
163
+
164
def make_matrix(arr, dt=None):
    """Build a vector or matrix expression from a (nested) Python sequence.

    Args:
        arr: a flat sequence for a vector, or a sequence of rows for a matrix
            (detected by ``arr[0]`` being iterable). Rows are flattened in
            row-major order. An empty ``arr`` produces a shape-``[0]`` ``i32``
            vector.
        dt: element type. When ``None`` it is inferred from the entries via
            ``_infer_array_dt``; otherwise it is normalized with ``cook_dtype``.
            Ignored when ``arr`` is empty.

    Returns:
        expr.Expr: the matrix expression (built by ``make_matrix_with_shape``).
    """
    if len(arr) == 0:
        # The only usage of an empty vector is to serve as field indices.
        return make_matrix_with_shape(arr, [0], primitive_types.i32)
    if isinstance(arr[0], Iterable):  # matrix: flatten rows, row-major
        shape = [len(arr), len(arr[0])]
        arr = [elt for row in arr for elt in row]
    else:  # vector
        shape = [len(arr)]
    dt = _infer_array_dt(arr) if dt is None else cook_dtype(dt)
    # Share the expression-construction code with make_matrix_with_shape
    # instead of duplicating the AST-builder call.
    return make_matrix_with_shape(arr, shape, dt)
189
+
190
+
191
def _read_host_access(x):
    """Read the current value behind a host-side accessor.

    An ``SNodeHostAccess`` is read through ``x.accessor.getter`` with the stored
    key unpacked as arguments. Anything else must be an ``NdarrayHostAccess``
    (checked by assertion) and is read through ``x.getter()``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        return x.getter()
    return x.accessor.getter(*x.key)
196
+
197
+
198
def _write_host_access(x, value):
    """Write ``value`` through a host-side accessor.

    An ``SNodeHostAccess`` is written via ``x.accessor.setter(value, *x.key)``.
    Anything else must be an ``NdarrayHostAccess`` (checked by assertion) and is
    written via ``x.setter(value)``. Returns ``None``.
    """
    if not isinstance(x, SNodeHostAccess):
        assert isinstance(x, NdarrayHostAccess)
        x.setter(value)
        return
    x.accessor.setter(value, *x.key)
204
+
205
+
206
+ @_gen_swizzles
207
+ class Matrix(TaichiOperations):
208
+ """The matrix class.
209
+
210
+ A matrix is a 2-D rectangular array with scalar entries, it's row-majored, and is
211
+ aligned continuously. We recommend only use matrix with no more than 32 elements for
212
+ efficiency considerations.
213
+
214
+ Note: in taichi a matrix is strictly two-dimensional and only stores scalars.
215
+
216
+ Args:
217
+ arr (Union[list, tuple, np.ndarray]): the initial values of a matrix.
218
+ dt (:mod:`~taichi.types.primitive_types`): the element data type.
219
+ ndim (int optional): the number of dimensions of the matrix; forced reshape if given.
220
+
221
+ Example::
222
+
223
+ use a 2d list to initialize a matrix
224
+
225
+ >>> @ti.kernel
226
+ >>> def test():
227
+ >>> n = 5
228
+ >>> M = ti.Matrix([[0] * n for _ in range(n)], ti.i32)
229
+ >>> print(M) # a 5x5 matrix with integer elements
230
+
231
+ get the number of rows and columns via the `n`, `m` property:
232
+
233
+ >>> M = ti.Matrix([[0, 1], [2, 3], [4, 5]], ti.i32)
234
+ >>> M.n # number of rows
235
+ 3
236
+ >>> M.m # number of cols
237
+ >>> 2
238
+
239
+ you can even initialize a matrix with an empty list:
240
+
241
+ >>> M = ti.Matrix([[], []], ti.i32)
242
+ >>> M.n
243
+ 2
244
+ >>> M.m
245
+ 0
246
+ """
247
+
248
+ _is_taichi_class = True
249
+ _is_matrix_class = True
250
+ __array_priority__ = 1000
251
+
252
+ def __init__(self, arr, dt=None):
253
+ if not isinstance(arr, (list, tuple, np.ndarray)):
254
+ raise TaichiTypeError("An Matrix/Vector can only be initialized with an array-like object")
255
+ if len(arr) == 0:
256
+ self.ndim = 0
257
+ self.n, self.m = 0, 0
258
+ self.entries = np.array([])
259
+ self.is_host_access = False
260
+ elif isinstance(arr[0], Matrix):
261
+ raise Exception("cols/rows required when using list of vectors")
262
+ elif isinstance(arr[0], Iterable): # matrix
263
+ self.ndim = 2
264
+ self.n, self.m = len(arr), len(arr[0])
265
+ if isinstance(arr[0][0], (SNodeHostAccess, NdarrayHostAccess)):
266
+ self.entries = arr
267
+ self.is_host_access = True
268
+ else:
269
+ self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
270
+ self.is_host_access = False
271
+ else: # vector
272
+ self.ndim = 1
273
+ self.n, self.m = len(arr), 1
274
+ if isinstance(arr[0], (SNodeHostAccess, NdarrayHostAccess)):
275
+ self.entries = arr
276
+ self.is_host_access = True
277
+ else:
278
+ self.entries = np.array(arr, None if dt is None else to_numpy_type(dt))
279
+ self.is_host_access = False
280
+
281
+ if self.n * self.m > 32:
282
+ warning(
283
+ f"Taichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested."
284
+ " Matrices/vectors will be automatically unrolled at compile-time for performance."
285
+ " So the compilation time could be extremely long if the matrix size is too big."
286
+ " You may use a field to store a large matrix like this, e.g.:\n"
287
+ f" x = ti.field(ti.f32, ({self.n}, {self.m})).\n"
288
+ " See https://docs.taichi-lang.org/docs/field#matrix-size"
289
+ " for more details.",
290
+ UserWarning,
291
+ stacklevel=2,
292
+ )
293
+
294
+ def get_shape(self):
295
+ if self.ndim == 1:
296
+ return (self.n,)
297
+ if self.ndim == 2:
298
+ return (self.n, self.m)
299
+ return None
300
+
301
+ def __matmul__(self, other):
302
+ """Matrix-matrix or matrix-vector multiply.
303
+
304
+ Args:
305
+ other (Union[Matrix, Vector]): a matrix or a vector.
306
+
307
+ Returns:
308
+ The matrix-matrix product or matrix-vector product.
309
+
310
+ """
311
+ from taichi.lang import matrix_ops # pylint: disable=C0415
312
+
313
+ return matrix_ops.matmul(self, other)
314
+
315
+ # host access & python scope operation
316
+ def __len__(self):
317
+ """Get the length of each row of a matrix"""
318
+ # TODO: When this is a vector, should return its dimension?
319
+ return self.n
320
+
321
+ def __iter__(self):
322
+ if self.ndim == 1:
323
+ return (self[i] for i in range(self.n))
324
+ return ([self[i, j] for j in range(self.m)] for i in range(self.n))
325
+
326
+ def __getitem__(self, indices):
327
+ """Access to the element at the given indices in a matrix.
328
+
329
+ Args:
330
+ indices (Sequence[Expr]): the indices of the element.
331
+
332
+ Returns:
333
+ The value of the element at a specific position of a matrix.
334
+
335
+ """
336
+ entry = self._get_entry(indices)
337
+ if self.is_host_access:
338
+ return _read_host_access(entry)
339
+ return entry
340
+
341
+ @python_scope
342
+ def __setitem__(self, indices, item):
343
+ """Set the element value at the given indices in a matrix.
344
+
345
+ Args:
346
+ indices (Sequence[Expr]): the indices of a element.
347
+
348
+ """
349
+ if self.is_host_access:
350
+ entry = self._get_entry(indices)
351
+ _write_host_access(entry, item)
352
+ else:
353
+ if not isinstance(indices, (list, tuple)):
354
+ indices = [indices]
355
+ assert len(indices) in [1, 2]
356
+ assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
357
+ if self.ndim == 1:
358
+ self.entries[indices[0]] = item
359
+ else:
360
+ self.entries[indices[0]][indices[1]] = item
361
+
362
+ def _get_entry(self, indices):
363
+ if not isinstance(indices, (list, tuple)):
364
+ indices = [indices]
365
+ assert len(indices) in [1, 2]
366
+ assert len(indices) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
367
+ if self.ndim == 1:
368
+ return self.entries[indices[0]]
369
+ return self.entries[indices[0]][indices[1]]
370
+
371
+ def _get_slice(self, a, b):
372
+ if isinstance(a, slice):
373
+ a = range(a.start or 0, a.stop or self.n, a.step or 1)
374
+ if isinstance(b, slice):
375
+ b = range(b.start or 0, b.stop or self.m, b.step or 1)
376
+ if isinstance(a, range) and isinstance(b, range):
377
+ return Matrix([[self._get_entry(i, j) for j in b] for i in a])
378
+ if isinstance(a, range): # b is not range
379
+ return Vector([self._get_entry(i, b) for i in a])
380
+ # a is not range while b is range
381
+ return Vector([self._get_entry(a, j) for j in b])
382
+
383
+ @python_scope
384
+ def _set_entries(self, value):
385
+ if isinstance(value, Matrix):
386
+ value = value.to_list()
387
+ if self.is_host_access:
388
+ if self.ndim == 1:
389
+ for i in range(self.n):
390
+ _write_host_access(self.entries[i], value[i])
391
+ else:
392
+ for i in range(self.n):
393
+ for j in range(self.m):
394
+ _write_host_access(self.entries[i][j], value[i][j])
395
+ else:
396
+ if self.ndim == 1:
397
+ for i in range(self.n):
398
+ self.entries[i] = value[i]
399
+ else:
400
+ for i in range(self.n):
401
+ for j in range(self.m):
402
+ self.entries[i][j] = value[i][j]
403
+
404
+ @property
405
+ def _members(self):
406
+ return self.entries
407
+
408
+ def to_list(self):
409
+ """Return this matrix as a 1D `list`.
410
+
411
+ This is similar to `numpy.ndarray`'s `flatten` and `ravel` methods,
412
+ the difference is that this function always returns a new list.
413
+ """
414
+ if self.is_host_access:
415
+ if self.ndim == 1:
416
+ return [_read_host_access(self.entries[i]) for i in range(self.n)]
417
+ assert self.ndim == 2
418
+ return [[_read_host_access(self.entries[i][j]) for j in range(self.m)] for i in range(self.n)]
419
+ return self.entries.tolist()
420
+
421
+ @taichi_scope
422
+ def cast(self, dtype):
423
+ """Cast the matrix elements to a specified data type.
424
+
425
+ Args:
426
+ dtype (:mod:`~taichi.types.primitive_types`): data type of the
427
+ returned matrix.
428
+
429
+ Returns:
430
+ :class:`taichi.Matrix`: A new matrix with the specified data dtype.
431
+
432
+ Example::
433
+
434
+ >>> A = ti.Matrix([0, 1, 2], ti.i32)
435
+ >>> B = A.cast(ti.f32)
436
+ >>> B
437
+ [0.0, 1.0, 2.0]
438
+ """
439
+ if self.ndim == 1:
440
+ return Vector([ops_mod.cast(self[i], dtype) for i in range(self.n)])
441
+ return Matrix([[ops_mod.cast(self[i, j], dtype) for j in range(self.m)] for i in range(self.n)])
442
+
443
+ def trace(self):
444
+ """The sum of a matrix diagonal elements.
445
+
446
+ To call this method the matrix must be square-like.
447
+
448
+ Returns:
449
+ The sum of a matrix diagonal elements.
450
+
451
+ Example::
452
+
453
+ >>> m = ti.Matrix([[1, 2], [3, 4]])
454
+ >>> m.trace()
455
+ 5
456
+ """
457
+ # pylint: disable-msg=C0415
458
+ from taichi.lang import matrix_ops
459
+
460
+ return matrix_ops.trace(self)
461
+
462
+ def inverse(self):
463
+ """Returns the inverse of this matrix.
464
+
465
+ Note:
466
+ The matrix dimension should be less than or equal to 4.
467
+
468
+ Returns:
469
+ :class:`~taichi.Matrix`: The inverse of a matrix.
470
+
471
+ Raises:
472
+ Exception: Inversions of matrices with sizes >= 5 are not supported.
473
+ """
474
+ from taichi.lang import matrix_ops # pylint: disable=C0415
475
+
476
+ return matrix_ops.inverse(self)
477
+
478
+ def normalized(self, eps=0):
479
+ """Normalize a vector, i.e. matrices with the second dimension being
480
+ equal to one.
481
+
482
+ The normalization of a vector `v` is a vector of length 1
483
+ and has the same direction with `v`. It's equal to `v/|v|`.
484
+
485
+ Args:
486
+ eps (float): a safe-guard value for sqrt, usually 0.
487
+
488
+ Example::
489
+
490
+ >>> a = ti.Vector([3, 4], ti.f32)
491
+ >>> a.normalized()
492
+ [0.6, 0.8]
493
+ """
494
+ # pylint: disable-msg=C0415
495
+ from taichi.lang import matrix_ops
496
+
497
+ return matrix_ops.normalized(self, eps)
498
+
499
+ def transpose(self):
500
+ """Returns the transpose of a matrix.
501
+
502
+ Returns:
503
+ :class:`~taichi.Matrix`: The transpose of this matrix.
504
+
505
+ Example::
506
+
507
+ >>> A = ti.Matrix([[0, 1], [2, 3]])
508
+ >>> A.transpose()
509
+ [[0, 2], [1, 3]]
510
+ """
511
+ # pylint: disable=C0415
512
+ from taichi.lang import matrix_ops
513
+
514
+ return matrix_ops.transpose(self)
515
+
516
+ @taichi_scope
517
+ def determinant(a):
518
+ """Returns the determinant of this matrix.
519
+
520
+ Note:
521
+ The matrix dimension should be less than or equal to 4.
522
+
523
+ Returns:
524
+ dtype: The determinant of this matrix.
525
+
526
+ Raises:
527
+ Exception: Determinants of matrices with sizes >= 5 are not supported.
528
+ """
529
+ # pylint: disable=C0415
530
+ from taichi.lang import matrix_ops
531
+
532
+ return matrix_ops.determinant(a)
533
+
534
+ @staticmethod
535
+ def diag(dim, val):
536
+ """Returns a diagonal square matrix with the diagonals filled
537
+ with `val`.
538
+
539
+ Args:
540
+ dim (int): the dimension of the wanted square matrix.
541
+ val (TypeVar): value for the diagonal elements.
542
+
543
+ Returns:
544
+ :class:`~taichi.Matrix`: The wanted diagonal matrix.
545
+
546
+ Example::
547
+
548
+ >>> m = ti.Matrix.diag(3, 1)
549
+ [[1, 0, 0],
550
+ [0, 1, 0],
551
+ [0, 0, 1]]
552
+ """
553
+ # pylint: disable=C0415
554
+ from taichi.lang import matrix_ops
555
+
556
+ return matrix_ops.diag(dim, val)
557
+
558
+ def sum(self):
559
+ """Return the sum of all elements.
560
+
561
+ Example::
562
+
563
+ >>> m = ti.Matrix([[1, 2], [3, 4]])
564
+ >>> m.sum()
565
+ 10
566
+ """
567
+ # pylint: disable=C0415
568
+ from taichi.lang import matrix_ops
569
+
570
+ return matrix_ops.sum(self)
571
+
572
+ def norm(self, eps=0):
573
+ """Returns the square root of the sum of the absolute squares
574
+ of its elements.
575
+
576
+ Args:
577
+ eps (Number): a safe-guard value for sqrt, usually 0.
578
+
579
+ Example::
580
+
581
+ >>> a = ti.Vector([3, 4])
582
+ >>> a.norm()
583
+ 5
584
+
585
+ Returns:
586
+ The square root of the sum of the absolute squares of its elements.
587
+ """
588
+ # pylint: disable=C0415
589
+ from taichi.lang import matrix_ops
590
+
591
+ return matrix_ops.norm(self, eps=eps)
592
+
593
+ def norm_inv(self, eps=0):
594
+ """The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`.
595
+
596
+ Args:
597
+ eps (float): a safe-guard value for sqrt, usually 0.
598
+
599
+ Returns:
600
+ The inverse of the matrix/vector `norm`.
601
+ """
602
+ # pylint: disable=C0415
603
+ from taichi.lang import matrix_ops
604
+
605
+ return matrix_ops.norm_inv(self, eps=eps)
606
+
607
def norm_sqr(self):
    """Return the sum of the absolute squares of all elements (the squared norm)."""
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.norm_sqr(self)
613
+
614
def max(self):
    """Return the largest element of this matrix."""
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.max(self)
620
+
621
def min(self):
    """Return the smallest element of this matrix."""
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.min(self)
627
+
628
def any(self):
    """Check whether at least one element is non-zero.

    Returns:
        bool: ``True`` if some element is not equal to zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.any()
        True
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.any(self)
644
+
645
def all(self):
    """Check whether every element is non-zero.

    Returns:
        bool: ``True`` if no element is equal to zero, ``False`` otherwise.

    Example::

        >>> v = ti.Vector([0, 0, 1])
        >>> v.all()
        False
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.all(self)
661
+
662
def fill(self, val):
    """Overwrite every element of this matrix with ``val``.

    Args:
        val (Union[int, float]): The value written to each element.

    Example::

        >>> A = ti.Matrix([0, 1, 2, 3])
        >>> A.fill(-1)
        >>> A
        [-1, -1, -1, -1]
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.fill(self, val)
679
+
680
def to_numpy(self):
    """Converts this matrix to a numpy array.

    When ``self.is_host_access`` is true, the entries are copied through
    :meth:`to_list` into a freshly created ``numpy.ndarray``. Otherwise the
    object stored in ``self.entries`` is returned unchanged (no copy is made).

    Returns:
        numpy.ndarray: The result numpy array (in the non-host-access case,
        whatever ``self.entries`` holds).

    Example::

        >>> A = ti.Matrix([[0], [1], [2], [3]])
        >>> A.to_numpy()
        array([[0], [1], [2], [3]])
    """
    if self.is_host_access:
        return np.array(self.to_list())
    return self.entries
696
+
697
@taichi_scope
def __ti_repr__(self):
    """Yield the pieces used to print this matrix from Taichi scope.

    Rows are separated by ``", "``. When the matrix has more than one column,
    each row is wrapped in its own brackets; a single-column matrix is emitted
    as one flat bracketed list.
    """
    yield "["
    for row in range(self.n):
        if row > 0:
            yield ", "
        wrap_row = self.m != 1
        if wrap_row:
            yield "["
        for col in range(self.m):
            if col > 0:
                yield ", "
            yield self(row, col)
        if wrap_row:
            yield "]"
    yield "]"
712
+
713
def __str__(self):
    """Python scope matrix print support.

    Inside a kernel a fixed ``<NxM ti.Matrix>`` placeholder is returned;
    otherwise the string form of :meth:`to_numpy` is returned.
    """
    if impl.inside_kernel():
        # When pybind11 hits a type mismatch it invokes repr/str on the
        # offending arguments to build its error message, e.g.:
        #
        #   TypeError: make_const_expr_f32(): incompatible function arguments.
        #   The following argument types are supported:
        #       1. (arg0: float) -> taichi_python.Expr
        #
        #   Invoked with: <Taichi 2x1 Matrix>
        #
        # So we return a cheap placeholder string here instead of going
        # through to_numpy().
        return f"<{self.n}x{self.m} ti.Matrix>"
    return str(self.to_numpy())
729
+
730
def __repr__(self):
    """Return the string form of :meth:`to_numpy` for interactive display."""
    as_array = self.to_numpy()
    return str(as_array)
732
+
733
@staticmethod
@taichi_scope
def zero(dt, n, m=None):
    """Create a vector or matrix whose every entry is zero.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows (the vector length when ``m`` is omitted).
        m (int, optional): Number of columns; when omitted a length-``n`` vector is built.

    Returns:
        :class:`~taichi.lang.matrix.Matrix`: A :class:`~taichi.lang.matrix.Matrix` instance filled with zeros.
    """
    from taichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 0)
    return matrix_ops._filled_vector(n, dt, 0)
752
+
753
@staticmethod
@taichi_scope
def one(dt, n, m=None):
    """Create a vector or matrix whose every entry is one.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows (the vector length when ``m`` is omitted).
        m (int, optional): Number of columns; when omitted a length-``n`` vector is built.

    Returns:
        :class:`~taichi.lang.matrix.Matrix`: A :class:`~taichi.lang.matrix.Matrix` instance filled with ones.
    """
    from taichi.lang import matrix_ops  # pylint: disable=C0415

    if m is not None:
        return matrix_ops._filled_matrix(n, m, dt, 1)
    return matrix_ops._filled_vector(n, dt, 1)
772
+
773
@staticmethod
@taichi_scope
def unit(n, i, dt=None):
    """Create a length-``n`` vector that is one at index ``i`` and zero elsewhere.

    Args:
        n (int): The length of the vector.
        i (int): Index of the entry set to one; must satisfy ``0 <= i < n``.
        dt (:mod:`~taichi.types.primitive_types`, optional): Element data
            type; Python ``int`` is used when omitted.

    Returns:
        :class:`~taichi.Matrix`: The unit vector.

    Example::

        >>> A = ti.Matrix.unit(3, 1)
        >>> A
        [0, 1, 0]
    """
    from taichi.lang import matrix_ops  # pylint: disable=C0415

    elem_type = int if dt is None else dt
    assert 0 <= i < n
    return matrix_ops._unit_vector(n, i, elem_type)
799
+
800
@staticmethod
@taichi_scope
def identity(dt, n):
    """Create the ``n`` x ``n`` identity matrix.

    Args:
        dt (DataType): Element data type.
        n (int): Number of rows and columns.

    Returns:
        :class:`~taichi.Matrix`: An ``n x n`` identity matrix.
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops._identity_matrix(n, dt)
815
+
816
@classmethod
@python_scope
def field(
    cls,
    n,
    m,
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
    layout=Layout.AOS,
    ndim=None,
):
    """Construct a data container to hold all elements of the Matrix.

    One field member is created per matrix entry (row-major). If ``shape`` is
    given, the members (and their grad/dual members when requested) are placed
    into newly created SNodes. If ``shape`` is None, nothing is placed here.

    Args:
        n (int): The desired number of rows of the Matrix.
        m (int): The desired number of columns of the Matrix.
        dtype (Union[DataType, list, tuple, np.ndarray]): The data type of the
            elements. A nested list/tuple/array of shape ``(n,)`` (when
            ``m == 1``) or ``(n, m)`` gives each entry its own data type.
        shape (Union[int, tuple of int], optional): The desired shape of the field.
        order (str, optional): Order of the shape axes laid out in memory, one
            letter per axis starting at ``"i"`` (e.g. ``"ji"``).
        name (string, optional): The custom name of the field.
        offset (Union[int, tuple of int], optional): The coordinate offset
            of all elements in a field.
        needs_grad (bool, optional): Whether the Matrix need grad field (reverse mode autodiff).
        needs_dual (bool, optional): Whether the Matrix need dual field (forward mode autodiff).
        layout (Layout, optional): The field layout, either Array Of
            Structure (AOS) or Structure Of Array (SOA).
        ndim (int, optional): Element dimensionality passed to
            :class:`MatrixField`; defaults to 2 when None.

    Returns:
        :class:`~taichi.Matrix`: A matrix.

    Raises:
        TaichiSyntaxError: If ``offset`` or ``order`` is given without
            ``shape``; if their lengths do not match ``shape``; if ``order``
            repeats an axis or names an axis outside the shape.
    """
    entries = []
    element_dim = ndim if ndim is not None else 2
    if isinstance(dtype, (list, tuple, np.ndarray)):
        # set different dtype for each element in Matrix
        # see #2135
        if m == 1:
            assert (
                len(np.shape(dtype)) == 1 and len(dtype) == n
            ), f"Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}"
            for i in range(n):
                entries.append(
                    impl.create_field_member(
                        dtype[i],
                        name=name,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                )
        else:
            assert (
                len(np.shape(dtype)) == 2 and len(dtype) == n and len(dtype[0]) == m
            ), f"Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}"
            for i in range(n):
                for j in range(m):
                    entries.append(
                        impl.create_field_member(
                            dtype[i][j],
                            name=name,
                            needs_grad=needs_grad,
                            needs_dual=needs_dual,
                        )
                    )
    else:
        for _ in range(n * m):
            entries.append(impl.create_field_member(dtype, name=name, needs_grad=needs_grad, needs_dual=needs_dual))
    # Each create_field_member() result is unpacked as a (primal, grad, dual)
    # triple; transpose the list into three parallel tuples.
    entries, entries_grad, entries_dual = zip(*entries)

    entries = MatrixField(entries, n, m, element_dim)
    # Grad/dual MatrixFields are only built when every member produced one.
    # NOTE(review): if needs_grad/needs_dual is True but some member returned
    # a falsy grad/dual, entries_grad/entries_dual stays a plain tuple and the
    # placement code below would fail on it. Confirm that create_field_member
    # always yields grads/duals when requested.
    if all(entries_grad):
        entries_grad = MatrixField(entries_grad, n, m, element_dim)
        entries._set_grad(entries_grad)
    if all(entries_dual):
        entries_dual = MatrixField(entries_dual, n, m, element_dim)
        entries._set_dual(entries_dual)

    impl.get_runtime().matrix_fields.append(entries)

    if shape is None:
        if offset is not None:
            raise TaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise TaichiSyntaxError("shape cannot be None when order is set")
    else:
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        if isinstance(offset, numbers.Number):
            offset = (offset,)
        dim = len(shape)
        if offset is not None and dim != len(offset):
            raise TaichiSyntaxError(
                f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
            )
        axis_seq = []
        shape_seq = []
        if order is not None:
            if dim != len(order):
                raise TaichiSyntaxError(
                    f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
                )
            if dim != len(set(order)):
                raise TaichiSyntaxError("The axes in order must be different")
            # Axis letters map to indices relative to "i": "i" -> 0, "j" -> 1, ...
            for ch in order:
                axis = ord(ch) - ord("i")
                if axis < 0 or axis >= dim:
                    raise TaichiSyntaxError(f"Invalid axis {ch}")
                axis_seq.append(axis)
                shape_seq.append(shape[axis])
        else:
            axis_seq = list(range(dim))
            shape_seq = list(shape)
        same_level = order is None
        if layout == Layout.SOA:
            # SOA: every member gets its own SNode.
            for e in entries._get_field_members():
                impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_grad:
                for e in entries_grad._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
            if needs_dual:
                for e in entries_dual._get_field_members():
                    impl._create_snode(axis_seq, shape_seq, same_level).place(ScalarField(e), offset=offset)
        else:
            # AOS: the whole matrix field is placed into a single SNode.
            impl._create_snode(axis_seq, shape_seq, same_level).place(entries, offset=offset)
            if needs_grad:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_grad, offset=offset)
            if needs_dual:
                impl._create_snode(axis_seq, shape_seq, same_level).place(entries_dual, offset=offset)
    return entries
948
+
949
@classmethod
@python_scope
def ndarray(cls, n, m, dtype, shape):
    """Create a Taichi ndarray whose elements are ``n`` x ``m`` matrices.

    Must be called in Python scope, after ``ti.init`` has run.

    Args:
        n (int): Number of rows of each matrix element.
        m (int): Number of columns of each matrix element.
        dtype (DataType): Data type of each scalar value.
        shape (Union[int, tuple[int]]): Shape of the ndarray; a bare number
            is treated as a 1-D shape.

    Example::

        The code below shows how a Taichi ndarray with matrix elements \
        can be declared and defined::

            >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
    """
    normalized_shape = (shape,) if isinstance(shape, numbers.Number) else shape
    return MatrixNdarray(n, m, dtype, normalized_shape)
971
+
972
@staticmethod
def rows(rows):
    """Stack a list of vectors/lists as the rows of a new matrix.

    Must be called in Taichi scope.

    Args:
        rows (List): Vectors (1-D matrices) or plain lists; each becomes one row.

    Returns:
        :class:`~taichi.Matrix`: The assembled matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.rows([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 2, 3], [4, 5, 6]]
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.rows(rows)
998
+
999
@staticmethod
def cols(cols):
    """Stack a list of vectors/lists as the columns of a new matrix.

    Args:
        cols (List): Vectors (1-D matrices) or plain lists; each becomes one column.

    Returns:
        :class:`~taichi.Matrix`: The assembled matrix.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     v1 = ti.Vector([1, 2, 3])
        >>>     v2 = ti.Vector([4, 5, 6])
        >>>     m = ti.Matrix.cols([v1, v2])
        >>>     print(m)
        >>>
        >>> test()
        [[1, 4], [2, 5], [3, 6]]
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.cols(cols)
1024
+
1025
def __hash__(self):
    """Hash by object identity.

    TODO: refactor KernelTemplateMapper. Without this method, using matrices
    as template arguments fails with ``unhashable type: Matrix``.
    """
    identity = id(self)
    return identity
1030
+
1031
def dot(self, other):
    """Compute the dot product of this vector with ``other``.

    Both operands must be vectors.

    Args:
        other (:class:`~taichi.Matrix`): The other vector.

    Returns:
        DataType: The scalar dot product.

    Example::

        >>> v1 = ti.Vector([1, 2, 3])
        >>> v2 = ti.Vector([3, 4, 5])
        >>> v1.dot(v2)
        26
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.dot(self, other)
1052
+
1053
def cross(self, other):
    """Compute the cross product of this vector with ``other``.

    Both vectors must have the same dimension, at most 3.

    * For 2-D vectors ``(x1, y1)`` and ``(x2, y2)`` the result is the scalar
      ``x1*y2 - x2*y1``.
    * For 3-D vectors ``v`` and ``w`` the result is the 3-D vector ``v x w``.

    Args:
        other (:class:`~taichi.Matrix`): The other vector.

    Returns:
        :class:`~taichi.Matrix`: The cross product of the two Vectors.
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.cross(self, other)
1073
+
1074
def outer_product(self, other):
    """Compute the outer product of this vector with ``other``.

    For ``v = (x1, ..., xn)`` and ``w = (y1, ..., yn)`` the result is the
    ``n`` x ``n`` matrix whose ``(i, j)`` entry equals ``xi*yj``.

    Args:
        other (:class:`~taichi.Matrix`): The other vector.

    Returns:
        :class:`~taichi.Matrix`: The outer product of the two Vectors.
    """
    from taichi.lang import matrix_ops as ops  # pylint: disable=C0415

    return ops.outer_product(self, other)
1090
+
1091
+
1092
class Vector(Matrix):
    """A :class:`Matrix` with a single column, i.e. of shape ``(n, 1)``."""

    def __init__(self, arr, dt=None, **kwargs):
        """Constructs a vector from given array.

        A vector is a 2-D matrix whose second dimension is equal to 1.

        Args:
            arr (Union[list, tuple, np.ndarray]): The initial values of the Vector.
            dt (:mod:`~taichi.types.primitive_types`): data type of the vector.

        Returns:
            :class:`~taichi.Matrix`: A vector instance.

        Example::

            >>> u = ti.Vector([1, 2])
            >>> print(u.n, u.m)  # a vector is a matrix of shape (n, 1)
            2 1
            >>> v = ti.Vector([3, 4])
            >>> u + v
            [4 6]
        """
        super().__init__(arr, dt=dt, **kwargs)

    def get_shape(self):
        """Return the vector's shape as the one-element tuple ``(n,)``."""
        return (self.n,)

    @classmethod
    def field(cls, n, dtype, *args, **kwargs):
        """ti.Vector.field

        Create a field of length-``n`` vectors. This is ``Matrix.field`` with
        ``m=1`` and ``ndim=1``; passing any other ``ndim`` fails the assertion.
        Remaining arguments are forwarded to :meth:`Matrix.field`.
        """
        ndim = kwargs.get("ndim", 1)
        assert ndim == 1
        kwargs["ndim"] = 1
        return super().field(n, 1, dtype, *args, **kwargs)

    @classmethod
    @python_scope
    def ndarray(cls, n, dtype, shape):
        """Defines a Taichi ndarray with vector elements.

        Args:
            n (int): Size of the vector.
            dtype (DataType): Data type of each value.
            shape (Union[int, tuple[int]]): Shape of the ndarray; a bare
                number is treated as a 1-D shape.

        Example:
            The code below shows how a Taichi ndarray with vector elements can be declared and defined::

                >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
        """
        if isinstance(shape, numbers.Number):
            shape = (shape,)
        return VectorNdarray(n, dtype, shape)
1143
+
1144
+
1145
class MatrixField(Field):
    """Taichi matrix field with SNode implementation.

    Each matrix entry is its own field member. Members are stored in row-major
    order, so entry ``(i, j)`` is ``self.vars[i * m + j]``.

    Args:
        _vars (List[Expr]): Field members, ``n * m`` of them in row-major order.
        n (Int): Number of rows.
        m (Int): Number of columns.
        ndim (Int): Number of element dimensions (0, 1 or 2); forced reshape if given.
    """

    def __init__(self, _vars, n, m, ndim=2):
        assert len(_vars) == n * m
        assert ndim in (0, 1, 2)
        super().__init__(_vars)
        self.n = n
        self.m = m
        self.ndim = ndim
        # Only the first ``ndim`` entries of [n, m] are passed to the core
        # (presumably as the element shape).
        self.ptr = ti_python_core.expr_matrix_field([var.ptr for var in self.vars], [n, m][:ndim])

    def get_scalar_field(self, *indices):
        """Creates a ScalarField using a specific field member.

        Args:
            indices (Tuple[Int]): ``(i,)`` or ``(i, j)``; ``j`` defaults to 0.

        Returns:
            ScalarField: The result ScalarField.
        """
        assert len(indices) in [1, 2]
        i = indices[0]
        j = 0 if len(indices) == 1 else indices[1]
        return ScalarField(self.vars[i * self.m + j])

    def _get_dynamic_index_stride(self):
        """Return the recorded dynamic-index stride, or None if the field is not dynamically indexable."""
        if self.ptr.get_dynamic_indexable():
            return self.ptr.get_dynamic_index_stride()
        return None

    def _calc_dynamic_index_stride(self):
        """Compute a constant byte stride between members and record it on ``self.ptr``.

        The stride is recorded via ``set_dynamic_index_stride`` only when:
        * every member's SNode path from the root has the same length and a
          non-quantized leaf dtype;
        * below the depth where the paths first diverge, every level is dense
          with matching cell sizes and child offsets;
        * the members' offsets at the divergence depth are evenly spaced.
        A single-member field gets stride 0. Otherwise nothing is recorded.
        """
        # Algorithm: https://github.com/taichi-dev/taichi/issues/3810
        paths = [ScalarField(var).snode._path_from_root() for var in self.vars]
        num_members = len(paths)
        if num_members == 1:
            self.ptr.set_dynamic_index_stride(0)
            return
        length = len(paths[0])
        if any(len(path) != length or ti_python_core.is_quant(path[length - 1]._dtype) for path in paths):
            return
        for i in range(length):
            if any(path[i] != paths[0][i] for path in paths):
                depth_below_lca = i
                break
        else:
            # All member paths are identical at every depth, so there is no
            # divergence point to derive a stride from (previously this fell
            # through to an UnboundLocalError on ``depth_below_lca``).
            return
        for i in range(depth_below_lca, length - 1):
            if any(
                path[i].ptr.type != ti_python_core.SNodeType.dense
                or path[i]._cell_size_bytes != paths[0][i]._cell_size_bytes
                or path[i + 1]._offset_bytes_in_parent_cell != paths[0][i + 1]._offset_bytes_in_parent_cell
                for path in paths
            ):
                return
        stride = (
            paths[1][depth_below_lca]._offset_bytes_in_parent_cell
            - paths[0][depth_below_lca]._offset_bytes_in_parent_cell
        )
        for i in range(2, num_members):
            if (
                stride
                != paths[i][depth_below_lca]._offset_bytes_in_parent_cell
                - paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell
            ):
                return
        self.ptr.set_dynamic_index_stride(stride)

    def fill(self, val):
        """Fills this matrix field with specified values.

        Args:
            val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill. A
                scalar is broadcast to every entry; otherwise the value must
                have dimensions consistent with ``self``.
        """
        if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr) and not val.is_tensor()):
            # Broadcast a scalar to an n x m (or length-n) nested tuple.
            if self.ndim == 2:
                val = tuple(tuple(val for _ in range(self.m)) for _ in range(self.n))
            else:
                assert self.ndim == 1
                val = tuple(val for _ in range(self.n))
        elif isinstance(val, expr.Expr) and val.is_tensor():
            assert val.n == self.n
            if self.ndim != 1:
                assert val.m == self.m
        else:
            if isinstance(val, Matrix):
                val = val.to_list()
            assert isinstance(val, (list, tuple))
            val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
            assert len(val) == self.n
            if self.ndim != 1:
                assert len(val[0]) == self.m
        if in_python_scope():
            from taichi._kernels import field_fill_python_scope  # pylint: disable=C0415

            field_fill_python_scope(self, val)
        else:
            from taichi._funcs import field_fill_taichi_scope  # pylint: disable=C0415

            field_fill_taichi_scope(self, val)

    @python_scope
    def to_numpy(self, keep_dims=False, dtype=None):
        """Converts the field instance to a NumPy array.

        Args:
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                When keep_dims=True, on an n-D matrix field, the numpy array always has n+2 dims, even for 1x1, 1xn, nx1 matrix fields.
                When keep_dims=False, a column dimension of size 1 (``m == 1``) is dropped.
                For example, a 4x1 matrix field with 5x6x7 elements results in an array of shape 5x6x7x4,
                while a 1x4 matrix field results in 5x6x7x1x4.
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result NumPy array.
        """
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        arr = np.zeros(self.shape + shape_ext, dtype=dtype)
        from taichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    def to_torch(self, device=None, keep_dims=False):
        """Converts the field instance to a PyTorch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~taichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            torch.tensor: The result torch tensor.
        """
        import torch  # pylint: disable=C0415

        as_vector = self.m == 1 and not keep_dims
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        # pylint: disable=E1101
        arr = torch.empty(self.shape + shape_ext, dtype=to_pytorch_type(self.dtype), device=device)
        from taichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    def to_paddle(self, place=None, keep_dims=False):
        """Converts the field instance to a Paddle tensor.

        Args:
            place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
            keep_dims (bool, optional): Whether to keep the dimension after conversion.
                See :meth:`~taichi.lang.field.MatrixField.to_numpy` for more detailed explanation.

        Returns:
            paddle.Tensor: The result paddle tensor.
        """
        import paddle  # pylint: disable=C0415

        # NOTE(review): unlike to_numpy/to_torch, the size-1 column is only
        # dropped here when ndim == 1, so an n x 1 field with ndim == 2 keeps
        # shape (..., n, 1). Confirm whether this difference is intentional.
        as_vector = self.m == 1 and not keep_dims and self.ndim == 1
        shape_ext = (self.n,) if as_vector else (self.n, self.m)
        # pylint: disable=E1101
        # paddle.empty() doesn't support argument `place``
        arr = paddle.to_tensor(
            paddle.empty(self.shape + shape_ext, to_paddle_type(self.dtype)),
            place=place,
        )
        from taichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415

        matrix_to_ext_arr(self, arr, as_vector)
        runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copy an external array into this field.

        ``arr`` must have one extra trailing dim (vector layout, only for
        ``m == 1`` fields) or two extra trailing dims (matrix layout) beyond
        ``self.shape``.
        """
        if len(arr.shape) == len(self.shape) + 1:
            as_vector = True
            assert self.m == 1, "This is not a vector field"
        else:
            as_vector = False
            assert len(arr.shape) == len(self.shape) + 2
        from taichi._kernels import ext_arr_to_matrix  # pylint: disable=C0415

        ext_arr_to_matrix(arr, self, as_vector)
        runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies an `numpy.ndarray` into this field.

        Non-C-contiguous arrays are first copied into a contiguous buffer.

        Example::

            >>> m = ti.Matrix.field(2, 2, ti.f32, shape=(3, 3))
            >>> arr = numpy.ones((3, 3, 2, 2))
            >>> m.from_numpy(arr)
        """

        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @python_scope
    def __setitem__(self, key, value):
        """Write ``value`` into the element at ``key`` via the host accessors."""
        self._initialize_host_accessors()
        self[key]._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        """Read the element at ``key`` as a Vector (ndim == 1) or Matrix via the host accessors."""
        self._initialize_host_accessors()
        key = self._pad_key(key)
        _host_access = self._host_access(key)
        if self.ndim == 1:
            return Vector([_host_access[i] for i in range(self.n)])
        return Matrix([[_host_access[i * self.m + j] for j in range(self.m)] for i in range(self.n)])

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return f"<{self.n}x{self.m} ti.Matrix.field>"
1374
+
1375
+
1376
class MatrixType(CompoundType):
    """Compound type describing ``n`` x ``m`` matrices with a given element dtype.

    Instances are callable (see :meth:`__call__`) and can create fields and
    ndarrays of this element type.

    Args:
        n (int): Number of rows.
        m (int): Number of columns.
        ndim (int): Element dimensionality; 2 means shape ``(n, m)``,
            otherwise the tensor shape is ``(n,)``.
        dtype: Element data type, or ``None`` (legacy; see the FIXME in ``__init__``).
    """

    def __init__(self, n, m, ndim, dtype):
        self.n = n
        self.m = m
        self.ndim = ndim
        # FIXME(haidong): dtypes should not be left empty for ndarray.
        # Remove the None dtype when we are ready to break legacy code.
        if dtype is not None:
            self.dtype = cook_dtype(dtype)
            shape = (n, m) if ndim == 2 else (n,)
            self.tensor_type = _type_factory.get_tensor_type(shape, self.dtype)
        else:
            self.dtype = None
            self.tensor_type = None

    def __call__(self, *args):
        """Return a matrix matching the shape and dtype.

        This function will try to convert the input to a `n x m` matrix, with n, m being
        the number of rows/cols of this matrix type.

        Example::

            >>> mat4x3 = MatrixType(4, 3, 2, float)
            >>> mat2x6 = MatrixType(2, 6, 2, float)

            Create from n x m scalars, or a 1d list of n x m scalars:

                >>> m = mat4x3([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
                >>> m = mat4x3(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)

            Create from n vectors/lists, with each one of dimension m:

                >>> m = mat4x3([1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12])

            Create from a single scalar

                >>> m = mat4x3(1)

            Create from another 2d list/matrix, as long as they have the same number of entries

                >>> m = mat4x3([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
                >>> m = mat4x3(m)
                >>> k = mat2x6(m)

        """
        if len(args) == 0:
            raise TaichiSyntaxError("Custom type instances need to be created with an initial value.")
        if len(args) == 1:
            # Init from a real Matrix
            if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
                arg = args[0]
                shape = arg.ptr.get_rvalue_type().shape()
                assert self.ndim == len(shape)
                assert self.n == shape[0]
                if self.ndim > 1:
                    assert self.m == shape[1]
                return expr.Expr(arg.ptr)

            # initialize by a single scalar, e.g. matnxm(1)
            if isinstance(args[0], (numbers.Number, expr.Expr)):
                entries = [args[0] for _ in range(self.m) for _ in range(self.n)]
                return self._instantiate(entries)
            args = args[0]
        # collect all input entries to a 1d list and then reshape
        # this is mostly for glsl style like vec4(v.xyz, 1.)
        entries = []
        for x in args:
            if isinstance(x, (list, tuple)):
                entries += x
            elif isinstance(x, np.ndarray):
                entries += list(x.ravel())
            elif isinstance(x, Matrix):
                entries += x.to_list()
            else:
                entries.append(x)

        return self._instantiate(entries)

    def from_taichi_object(self, func_ret, ret_index=()):
        """Rebuild a matrix from a Taichi-scope return value by extracting each flattened element."""
        return self(
            [
                expr.Expr(
                    ti_python_core.make_get_element_expr(
                        func_ret.ptr,
                        ret_index + (i,),
                        _ti_python_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
                for i in range(self.m * self.n)
            ]
        )

    def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
        """Read this matrix out of a kernel's struct return value via ``launch_ctx``.

        Raises:
            TaichiRuntimeTypeError: If ``self.dtype`` is neither an integer nor a real type.
        """
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                get_ret_func = launch_ctx.get_struct_ret_int
            else:
                get_ret_func = launch_ctx.get_struct_ret_uint
        elif self.dtype in primitive_types.real_types:
            get_ret_func = launch_ctx.get_struct_ret_float
        else:
            raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        return self([get_ret_func(ret_index + (i,)) for i in range(self.m * self.n)])

    def _set_entries_with(self, set_arg_func, mat, ret_index):
        """Pass every entry of ``mat`` to ``set_arg_func`` under its flattened row-major index."""
        if self.ndim == 1:
            for i in range(self.n):
                set_arg_func(ret_index + (i,), mat[i])
        else:
            for i in range(self.n):
                for j in range(self.m):
                    set_arg_func(ret_index + (i * self.m + j,), mat[i, j])

    def set_kernel_struct_args(self, mat, launch_ctx, ret_index=()):
        """Write the entries of ``mat`` into a kernel struct argument via ``launch_ctx``.

        Raises:
            TaichiRuntimeTypeError: If ``self.dtype`` is neither an integer nor a real type.
        """
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                set_arg_func = launch_ctx.set_struct_arg_int
            else:
                set_arg_func = launch_ctx.set_struct_arg_uint
        elif self.dtype in primitive_types.real_types:
            set_arg_func = launch_ctx.set_struct_arg_float
        else:
            raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        self._set_entries_with(set_arg_func, mat, ret_index)

    def set_argpack_struct_args(self, mat, argpack, ret_index=()):
        """Write the entries of ``mat`` into ``argpack``.

        Raises:
            TaichiRuntimeTypeError: If ``self.dtype`` is neither an integer nor a real type.
        """
        if self.dtype in primitive_types.integer_types:
            if is_signed(cook_dtype(self.dtype)):
                set_arg_func = argpack.set_arg_int
            else:
                set_arg_func = argpack.set_arg_uint
        elif self.dtype in primitive_types.real_types:
            set_arg_func = argpack.set_arg_float
        else:
            raise TaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
        self._set_entries_with(set_arg_func, mat, ret_index)

    def _instantiate_in_python_scope(self, entries):
        """Reshape flat ``entries`` to n x m, cast each to int/float per dtype, and build a Matrix."""
        entries = [[entries[k * self.m + i] for i in range(self.m)] for k in range(self.n)]
        return Matrix(
            [
                [
                    int(entries[i][j]) if self.dtype in primitive_types.integer_types else float(entries[i][j])
                    for j in range(self.m)
                ]
                for i in range(self.n)
            ],
            dt=self.dtype,
        )

    def _instantiate(self, entries):
        """Build a matrix from flat ``entries`` in either Python or Taichi scope."""
        if in_python_scope():
            return self._instantiate_in_python_scope(entries)

        return make_matrix_with_shape(entries, [self.n, self.m], self.dtype)

    def field(self, **kwargs):
        """Create a matrix field of this type; ``ndim``, if given, must equal ``self.ndim``."""
        assert kwargs.get("ndim", self.ndim) == self.ndim
        kwargs.update({"ndim": self.ndim})
        return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create an ndarray whose elements have this type.

        ``ndim`` may be passed but must equal ``self.ndim``. It is consumed
        here rather than forwarded, because ``Matrix.ndarray`` and
        ``Vector.ndarray`` do not accept an ``ndim`` argument (forwarding it
        raised TypeError). A 1-D type produces a vector ndarray.
        """
        ndim = kwargs.pop("ndim", self.ndim)
        assert ndim == self.ndim
        if self.ndim == 1:
            return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)
        return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)

    def get_shape(self):
        """Return ``(n,)`` for 1-D types and ``(n, m)`` otherwise."""
        if self.ndim == 1:
            return (self.n,)
        return (self.n, self.m)

    def to_string(self):
        dtype_str = self.dtype.to_string() if self.dtype is not None else ""
        return f"MatrixType[{self.n},{self.m}, {dtype_str}]"

    def check_matched(self, other):
        """Return True if ``other``'s rank, element type (when set) and known dims match this type."""
        if self.ndim != len(other.shape()):
            return False
        if self.dtype is not None and self.dtype != other.element_type():
            return False
        shape = self.get_shape()
        for i in range(self.ndim):
            if shape[i] is not None and shape[i] != other.shape()[i]:
                return False
        return True
1565
+
1566
+
1567
class VectorType(MatrixType):
    """A ``MatrixType`` with ``n`` rows, one column and ``ndim == 1`` (a vector type)."""

    def __init__(self, n, dtype):
        super().__init__(n, 1, 1, dtype)

    def __call__(self, *args):
        """Return a vector matching the shape and dtype.

        This function will try to convert the input to a `n`-component vector.

        Example::

            >>> vec3 = VectorType(3, float)

            Create from n scalars:

            >>> v = vec3(1, 2, 3)

            Create from a list/tuple of n scalars:

            >>> v = vec3([1, 2, 3])

            Create from a single scalar

            >>> v = vec3(1)

        Raises:
            TaichiSyntaxError: If called without arguments, or (in Python scope)
                if the flattened arguments do not supply exactly ``n`` components.
        """
        if len(args) == 0:
            raise TaichiSyntaxError("Custom type instances need to be created with an initial value.")
        if len(args) == 1:
            # Init from a real Matrix
            if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
                arg = args[0]
                shape = arg.ptr.get_rvalue_type().shape()
                assert len(shape) == 1
                assert self.n == shape[0]
                return expr.Expr(arg.ptr)

            # initialize by broadcasting a single scalar, e.g. vec3(1)
            if isinstance(args[0], (numbers.Number, expr.Expr)):
                entries = [args[0] for _ in range(self.n)]
                return self._instantiate(entries)
            args = args[0]
        # collect all input entries to a 1d list and then reshape
        # this is mostly for glsl style like vec4(v.xyz, 1.)
        entries = []
        for x in args:
            if isinstance(x, (list, tuple)):
                entries += x
            elif isinstance(x, np.ndarray):
                entries += list(x.ravel())
            elif isinstance(x, Matrix):
                entries += x.to_list()
            else:
                entries.append(x)

        # type cast
        return self._instantiate(entries)

    def _instantiate_in_python_scope(self, entries):
        """Build a Python-scope ``Vector`` from exactly ``n`` flat entries.

        Entries are cast to ``int`` for integer dtypes and to ``float`` otherwise.

        Raises:
            TaichiSyntaxError: If ``len(entries) != n``. Previously, extra entries
                were silently dropped and missing ones raised an opaque IndexError.
        """
        # The check lives here, not in ``__call__``: in kernel scope a single
        # tensor Expr (e.g. ``v.xyz``) may stand for several components, so only
        # Python-scope entries can be counted reliably.
        if len(entries) != self.n:
            raise TaichiSyntaxError(
                f"Incompatible arguments for the custom vector type: expected {self.n} components, got {len(entries)}."
            )
        return Vector(
            [int(entries[i]) if self.dtype in primitive_types.integer_types else float(entries[i]) for i in range(self.n)],
            dt=self.dtype,
        )

    def _instantiate(self, entries):
        if in_python_scope():
            return self._instantiate_in_python_scope(entries)

        return make_matrix_with_shape(entries, [self.n], self.dtype)

    def field(self, **kwargs):
        """Create a vector field of this type (forwards ``kwargs`` to ``Vector.field``)."""
        return Vector.field(self.n, dtype=self.dtype, **kwargs)

    def ndarray(self, **kwargs):
        """Create a vector ndarray of this type (forwards ``kwargs`` to ``Vector.ndarray``)."""
        return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

    def to_string(self):
        dtype_str = self.dtype.to_string() if self.dtype is not None else ""
        return f"VectorType[{self.n}, {dtype_str}]"
1649
+
1650
+
1651
class MatrixNdarray(Ndarray):
    """Taichi ndarray with matrix elements.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray. A single integer
            is treated as the 1-D shape ``(shape,)``.

    Example::

        >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
    """

    def __init__(self, n, m, dtype, shape):
        self.n = n
        self.m = m
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        # ``tuple(int)`` raises TypeError, so normalize a bare integer shape. Using
        # the normalized tuple below also avoids handing an already-consumed
        # iterator (e.g. a generator) to ``create_ndarray``.
        self.shape = (shape,) if isinstance(shape, numbers.Integral) else tuple(shape)
        self.element_type = _type_factory.get_tensor_type((self.n, self.m), self.dtype)
        # TODO: we should pass in element_type, shape, layout instead.
        self.arr = impl.get_runtime().prog.create_ndarray(
            cook_dtype(self.element_type),
            self.shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Returns the shape of each element (a 2D matrix) in this ndarray.

        Example::

            >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3))
            >>> arr.element_shape
            (2, 2)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Write the ``n`` x ``m`` matrix ``value`` into the element at ``key``.

        ``value`` is indexed as ``value[i][j]``. Rows may be lists, tuples, numpy
        arrays or ``Matrix`` objects (so a 2-D numpy array works). A flat
        sequence of scalars is treated as a single column.
        """
        if not isinstance(value, (list, tuple)):
            value = list(value)
        # Only a flat sequence of scalars gets wrapped into one-element rows.
        # Previously, numpy/Matrix rows were wrapped too, which broke m > 1.
        if not isinstance(value[0], (list, tuple, np.ndarray, Matrix)):
            value = [[i] for i in value]
        # Build the host-access view once instead of once per element.
        target = self[key]
        for i in range(self.n):
            for j in range(self.m):
                target[i, j] = value[i][j]

    @python_scope
    def __getitem__(self, key):
        """Return a ``Matrix`` of host accessors for the element at ``key``.

        ``key`` may be ``None`` (empty index), a number, or an iterable of indices.
        """
        key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
        return Matrix([[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)] for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this ndarray to a `numpy.ndarray`.

        Example::

            >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
            >>> arr.to_numpy()
            [[[[0. 0.]
               [0. 0.]]]

             [[[0. 0.]
               [0. 0.]]]]
        """
        return self._ndarray_matrix_to_numpy(as_vector=0)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data of a `numpy.ndarray` into this array.

        Example::

            >>> m = ti.MatrixNdarray(2, 2, ti.f32, shape=(2, 1))
            >>> arr = np.ones((2, 1, 2, 2))
            >>> m.from_numpy(arr)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=0)

    @python_scope
    def __deepcopy__(self, memo=None):
        ret_arr = MatrixNdarray(self.n, self.m, self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    @python_scope
    def _fill_by_kernel(self, val):
        """Fill every element with ``val`` (a ``Matrix``, or anything ``MatrixType`` accepts) via a kernel."""
        from taichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        shape = self.element_type.shape()
        n = shape[0]
        m = 1
        if len(shape) > 1:
            m = shape[1]

        prim_dtype = self.element_type.element_type()
        matrix_type = MatrixType(n, m, len(shape), prim_dtype)
        if isinstance(val, Matrix):
            value = val
        else:
            value = matrix_type(val)
        fill_ndarray_matrix(self, value)

    @python_scope
    def __repr__(self):
        return f"<{self.n}x{self.m} {Layout.AOS} ti.Matrix.ndarray>"
1766
+
1767
+
1768
class VectorNdarray(Ndarray):
    """Taichi ndarray with vector elements.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray. A single integer
            is treated as the 1-D shape ``(shape,)``.

    Example::

        >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
    """

    def __init__(self, n, dtype, shape):
        self.n = n
        super().__init__()
        # TODO(zhanlue): remove self.dtype and migrate its usages to element_type
        self.dtype = cook_dtype(dtype)

        self.layout = Layout.AOS
        # ``tuple(int)`` raises TypeError, so normalize a bare integer shape. Using
        # the normalized tuple below also avoids handing an already-consumed
        # iterator (e.g. a generator) to ``create_ndarray``.
        self.shape = (shape,) if isinstance(shape, numbers.Integral) else tuple(shape)
        self.element_type = _type_factory.get_tensor_type((n,), self.dtype)
        self.arr = impl.get_runtime().prog.create_ndarray(
            cook_dtype(self.element_type),
            self.shape,
            Layout.AOS,
            zero_fill=True,
            dbg_info=ti_python_core.DebugInfo(get_traceback()),
        )

    @property
    def element_shape(self):
        """Returns the shape of each element (a 1D vector) in this ndarray.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (3, 3))
            >>> a.element_shape
            (3,)
        """
        return tuple(self.arr.element_shape())

    @python_scope
    def __setitem__(self, key, value):
        """Write the ``n`` components of ``value`` into the element at ``key``."""
        if not isinstance(value, (list, tuple)):
            value = list(value)
        # Build the host-access view once instead of once per component.
        target = self[key]
        for i in range(self.n):
            target[i] = value[i]

    @python_scope
    def __getitem__(self, key):
        """Return a ``Vector`` of host accessors for the element at ``key``.

        ``key`` may be ``None`` (empty index), a number, or an iterable of indices.
        """
        key = () if key is None else (key,) if isinstance(key, numbers.Number) else tuple(key)
        return Vector([NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])

    @python_scope
    def to_numpy(self):
        """Converts this vector ndarray to a `numpy.ndarray`.

        Example::

            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> a.to_numpy()
            array([[[0., 0., 0.],
                    [0., 0., 0.]],

                   [[0., 0., 0.],
                    [0., 0., 0.]]], dtype=float32)
        """
        return self._ndarray_matrix_to_numpy(as_vector=1)

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this ndarray.

        The shape and data type of `arr` must match this ndarray.

        Example::

            >>> import numpy as np
            >>> a = ti.VectorNdarray(3, ti.f32, (2, 2))
            >>> b = np.ones((2, 2, 3), dtype=np.float32)
            >>> a.from_numpy(b)
        """
        self._ndarray_matrix_from_numpy(arr, as_vector=1)

    @python_scope
    def __deepcopy__(self, memo=None):
        ret_arr = VectorNdarray(self.n, self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    @python_scope
    def _fill_by_kernel(self, val):
        """Fill every element with ``val`` (a ``Vector``, or anything ``VectorType`` accepts) via a kernel."""
        from taichi._kernels import fill_ndarray_matrix  # pylint: disable=C0415

        shape = self.element_type.shape()
        prim_dtype = self.element_type.element_type()
        vector_type = VectorType(shape[0], prim_dtype)
        if isinstance(val, Vector):
            value = val
        else:
            value = vector_type(val)
        fill_ndarray_matrix(self, value)

    @python_scope
    def __repr__(self):
        return f"<{self.n} {Layout.AOS} ti.Vector.ndarray>"
1875
+
1876
+
1877
# Public names re-exported by ``from ... import *``.
__all__ = [
    "Matrix",
    "Vector",
    "MatrixField",
    "MatrixNdarray",
    "VectorNdarray",
]