gstaichi 2.1.1__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178)
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/_funcs.py +706 -0
  3. gstaichi/_kernels.py +420 -0
  4. gstaichi/_lib/__init__.py +3 -0
  5. gstaichi/_lib/core/__init__.py +0 -0
  6. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  7. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  8. gstaichi/_lib/core/py.typed +0 -0
  9. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  10. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  11. gstaichi/_lib/utils.py +243 -0
  12. gstaichi/_logging.py +131 -0
  13. gstaichi/_snode/__init__.py +5 -0
  14. gstaichi/_snode/fields_builder.py +187 -0
  15. gstaichi/_snode/snode_tree.py +34 -0
  16. gstaichi/_test_tools/__init__.py +18 -0
  17. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  18. gstaichi/_test_tools/load_kernel_string.py +30 -0
  19. gstaichi/_test_tools/textwrap2.py +6 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +352 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +5 -0
  51. gstaichi/lang/ast/ast_transformer.py +1323 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1245 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1341 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +780 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +8 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +19 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-2.1.1.dist-info/METADATA +106 -0
  175. gstaichi-2.1.1.dist-info/RECORD +178 -0
  176. gstaichi-2.1.1.dist-info/WHEEL +5 -0
  177. gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-2.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,212 @@
1
+ import ast
2
+ import dataclasses
3
+ import inspect
4
+ from typing import Any
5
+
6
+ from gstaichi.lang import util
7
+ from gstaichi.lang._dataclass_util import create_flat_name
8
+ from gstaichi.lang.ast import (
9
+ ASTTransformerContext,
10
+ )
11
+ from gstaichi.lang.kernel_arguments import ArgMetadata
12
+
13
+
14
def _populate_struct_locals_from_params_dict(basename: str, struct_locals, struct_type) -> None:
    """
    Recursively add the flattened name of every leaf field of ``struct_type`` to ``struct_locals``.

    We are populating struct locals from a type included in function parameters, or one of their
    subtypes.

    ``struct_locals`` is a set (mutated in place) that collects all possible unpacked variable names
    we can form from the struct. ``basename`` is the already-flattened name of the parent struct, so
    that nested fields are prefixed correctly. Fields whose type is itself a dataclass are not added;
    only their leaf fields are. For example, let's say we have:

        @dataclasses.dataclass
        class StructAB:
            a: ...
            b: ...
            struct_cd: StructCD

        @dataclasses.dataclass
        class StructCD:
            c: ...
            d: ...
            struct_ef: StructEF

        @dataclasses.dataclass
        class StructEF:
            e: ...
            f: ...

    ... and the function parameters look like: `def foo(struct_ab: StructAB)`

    then all possible variables we could form from this are:
    - struct_ab.a
    - struct_ab.b
    - struct_ab.struct_cd.c
    - struct_ab.struct_cd.d
    - struct_ab.struct_cd.struct_ef.e
    - struct_ab.struct_cd.struct_ef.f

    And the members of struct_locals should be:
    - __ti_struct_ab__ti_a
    - __ti_struct_ab__ti_b
    - __ti_struct_ab__ti_struct_cd__ti_c
    - __ti_struct_ab__ti_struct_cd__ti_d
    - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_e
    - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_f

    Args:
        basename: Flattened name of the enclosing struct (or the parameter name at top level).
        struct_locals: Set that receives the flattened leaf names.
        struct_type: Dataclass type whose fields are walked.
    """
    for field in dataclasses.fields(struct_type):
        child_name = create_flat_name(basename, field.name)
        # NOTE(review): this relies on field.type being the actual class object; if the defining
        # module uses `from __future__ import annotations`, field.type is a string and nested
        # dataclasses would be treated as leaves — confirm this cannot happen.
        if dataclasses.is_dataclass(field.type):
            _populate_struct_locals_from_params_dict(child_name, struct_locals, field.type)
        else:
            struct_locals.add(child_name)
62
+
63
+
64
def extract_struct_locals_from_context(ctx: ASTTransformerContext) -> set[str]:
    """
    Provides meta information for later transformation of nodes in the AST.

    - Uses ctx.func.func to get the function signature.
    - Searches its parameters for any annotated with a dataclass.
    - For each such parameter, collects the flattened names (built with ``create_flat_name``) of all
      leaf fields, recursing into nested dataclasses.
    - E.g. ``my_struct: MyStruct``, where MyStruct contains a, b, c, would become:
      {"__ti_my_struct__ti_a", "__ti_my_struct__ti_b", "__ti_my_struct__ti_c"}

    Args:
        ctx: Transformer context; ``ctx.func`` must be set.

    Returns:
        The set of flattened struct-member names.
    """
    assert ctx.func is not None
    struct_locals: set[str] = set()
    parameters = inspect.signature(ctx.func.func).parameters
    for param_name, parameter in parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            # A top-level dataclass parameter is flattened exactly like a nested struct field,
            # using the parameter name as the prefix.
            _populate_struct_locals_from_params_dict(param_name, struct_locals, parameter.annotation)
    return struct_locals
88
+
89
+
90
def expand_func_arguments(arguments: list[ArgMetadata]) -> list[ArgMetadata]:
    """
    Used to expand arguments for @ti.func.

    Every argument whose annotation is a dataclass is replaced, in place in the ordering, by one
    ``ArgMetadata`` per leaf field (recursing into nested dataclasses), named with
    ``create_flat_name``. All other arguments are passed through unchanged.

    Args:
        arguments: Argument metadata in declaration order.

    Returns:
        A new list with dataclass arguments expanded into their leaf fields.
    """
    expanded_arguments: list[ArgMetadata] = []
    for argument in arguments:
        if not dataclasses.is_dataclass(argument.annotation):
            expanded_arguments.append(argument)
            continue
        for field in dataclasses.fields(argument.annotation):
            child_name = create_flat_name(argument.name, field.name)
            if dataclasses.is_dataclass(field.type):
                # Nested struct: wrap it as an argument of its own and expand it recursively.
                nested_arg = ArgMetadata(
                    annotation=field.type,
                    name=child_name,
                    default=argument.default,
                )
                expanded_arguments.extend(expand_func_arguments([nested_arg]))
            else:
                # NOTE(review): leaf fields are created without `default`, so a default given for the
                # enclosing dataclass argument does not reach them — confirm this is intended.
                expanded_arguments.append(
                    ArgMetadata(
                        annotation=field.type,
                        name=child_name,
                    )
                )
    return expanded_arguments
116
+
117
+
118
+ class FlattenAttributeNameTransformer(ast.NodeTransformer):
119
+ def __init__(self, struct_locals: set[str]) -> None:
120
+ self.struct_locals = struct_locals
121
+
122
+ def visit_Attribute(self, node):
123
+ flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node)
124
+ if not flat_name or flat_name not in self.struct_locals:
125
+ return self.generic_visit(node)
126
+ return ast.copy_location(ast.Name(id=flat_name, ctx=node.ctx), node)
127
+
128
+ @staticmethod
129
+ def _flatten_attribute_name(node: ast.Attribute) -> str | None:
130
+ """
131
+ see unpack_ast_struct_expressions docstring for more explanation
132
+ """
133
+ if isinstance(node.value, ast.Name):
134
+ return create_flat_name(node.value.id, node.attr)
135
+ if isinstance(node.value, ast.Attribute):
136
+ child_flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node.value)
137
+ if not child_flat_name:
138
+ return None
139
+ return create_flat_name(child_flat_name, node.attr)
140
+ return None
141
+
142
+
143
def unpack_ast_struct_expressions(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """
    Transform nodes in the AST to flatten access to struct members.

    Every ``ast.Attribute`` chain whose flattened name is in ``struct_locals`` is replaced by an
    ``ast.Name`` with that flattened name, using ``FlattenAttributeNameTransformer``. Source locations
    are copied from the replaced nodes and any still-missing locations are then filled in.

    Examples of things we will transform/flatten:

        # my_struct_ab.a
        # Attribute(value=Name())
        Attribute(
            value=Name(id='my_struct_ab', ctx=Load()),
            attr='a',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_a
        Name(id='__ti_my_struct_ab__ti_a', ctx=Load())

        # my_struct_ab.struct_cd.d
        # Attribute(value=Attribute(value=Name()))
        Attribute(
            value=Attribute(
                value=Name(id='my_struct_ab', ctx=Load()),
                attr='struct_cd',
                ctx=Load()),
            attr='d',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_struct_cd__ti_d
        Name(id='__ti_my_struct_ab__ti_struct_cd__ti_d', ctx=Load())

        # my_struct_ab.struct_cd.struct_ef.f
        # Attribute(value=Attribute(value=Attribute(value=Name())))
        Attribute(
            value=Attribute(
                value=Attribute(
                    value=Name(id='my_struct_ab', ctx=Load()),
                    attr='struct_cd',
                    ctx=Load()),
                attr='struct_ef',
                ctx=Load()),
            attr='f',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f
        Name(id='__ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f', ctx=Load())

    Args:
        tree: Module to rewrite; ``ast.NodeTransformer`` modifies it in place.
        struct_locals: Flattened member names eligible for replacement.

    Returns:
        The transformed module.
    """
    transformer = FlattenAttributeNameTransformer(struct_locals=struct_locals)
    new_tree = transformer.visit(tree)
    ast.fix_missing_locations(new_tree)
    return new_tree
193
+
194
+
195
def populate_global_vars_from_dataclass(
    param_name: str,
    param_type: Any,
    py_arg: Any,
    global_vars: dict[str, Any],
):
    """Copy template-typed members of a dataclass argument into ``global_vars``.

    Walks the fields of ``param_type`` (recursing into nested dataclass fields). Every leaf field
    whose declared type satisfies ``util.is_ti_template`` has its value, read from ``py_arg``, stored
    in ``global_vars`` under the flattened name built by ``create_flat_name``. Other leaf fields are
    skipped.

    Args:
        param_name: Flattened name of the (possibly nested) parameter being expanded.
        param_type: Dataclass type describing ``py_arg``.
        py_arg: The dataclass instance passed at call time.
        global_vars: Mapping that is updated in place.
    """
    for member in dataclasses.fields(param_type):
        member_value = getattr(py_arg, member.name)
        member_key = create_flat_name(param_name, member.name)
        member_type = member.type
        if dataclasses.is_dataclass(member_type):
            # Nested dataclass: its flat name becomes the prefix for its own members.
            populate_global_vars_from_dataclass(
                param_name=member_key,
                param_type=member_type,
                py_arg=member_value,
                global_vars=global_vars,
            )
            continue
        if util.is_ti_template(member_type):
            global_vars[member_key] = member_value
@@ -0,0 +1,352 @@
1
+ # type: ignore
2
+
3
+ from typing import TYPE_CHECKING, Union
4
+
5
+ import numpy as np
6
+
7
+ from gstaichi._lib import core as _ti_core
8
+ from gstaichi.lang import impl
9
+ from gstaichi.lang.exception import GsTaichiIndexError
10
+ from gstaichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
11
+ from gstaichi.types import primitive_types
12
+ from gstaichi.types.enums import Layout
13
+ from gstaichi.types.ndarray_type import NdarrayTypeMetadata
14
+ from gstaichi.types.utils import is_real, is_signed
15
+
16
+ if TYPE_CHECKING:
17
+ from gstaichi.lang.matrix import MatrixNdarray, VectorNdarray
18
+
19
+ TensorNdarray = Union["ScalarNdarray", VectorNdarray, MatrixNdarray]
20
+
21
+
22
class Ndarray:
    """GsTaichi ndarray base class.

    Holds the attributes shared by all ndarray kinds and implements conversion to and from numpy.
    Subclasses (e.g. ``ScalarNdarray``) fill in the attributes and implement ``element_shape``,
    ``__setitem__``, ``__getitem__``, ``__deepcopy__`` and ``_fill_by_kernel``, which raise
    ``NotImplementedError`` here.

    Attributes:
        host_accessor: Lazily created ``NdarrayHostAccessor`` used for Python-scope element access.
        shape (Tuple[int]): Shape of the ndarray.
        element_type: Element type, reported through ``get_type``.
        dtype (DataType): Data type of each value.
        arr: Underlying runtime ndarray handle (set by subclasses).
        layout (Layout): Memory layout, ``Layout.AOS`` by default.
        grad (TensorNdarray | None): Gradient ndarray, if one has been set.
    """

    def __init__(self):
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.arr = None
        self.layout = Layout.AOS
        self.grad: "TensorNdarray | None" = None

    def get_type(self):
        """Returns the type metadata of this ndarray.

        Returns:
            NdarrayTypeMetadata: Built from the element type, the shape, and whether a gradient
            ndarray is attached.
        """
        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)

    @property
    def element_shape(self):
        """Gets ndarray element shape (implemented by subclasses).

        Returns:
            Tuple[Int]: Ndarray element shape.
        """
        raise NotImplementedError()

    @python_scope
    def __setitem__(self, key, value):
        """Sets ndarray element in Python scope (implemented by subclasses).

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets ndarray element in Python scope (implemented by subclasses).

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    @python_scope
    def fill(self, val):
        """Fills ndarray with a specific scalar value.

        On the cuda and x64 arches, scalar f32/i32/u32 ndarrays are filled with the runtime
        program's fill_float/fill_int/fill_uint; every other case uses ``_fill_by_kernel``.

        Args:
            val (Union[int, float]): Value to fill.
        """
        arch = impl.current_cfg().arch
        if arch != _ti_core.Arch.cuda and arch != _ti_core.Arch.x64:
            # The runtime fill_* calls below are only used on cuda and x64.
            self._fill_by_kernel(val)
        elif _ti_core.is_tensor(self.element_type):
            # Tensor (non-scalar) element types always go through a kernel.
            self._fill_by_kernel(val)
        elif self.dtype == primitive_types.f32:
            impl.get_runtime().prog.fill_float(self.arr, val)
        elif self.dtype == primitive_types.i32:
            impl.get_runtime().prog.fill_int(self.arr, val)
        elif self.dtype == primitive_types.u32:
            impl.get_runtime().prog.fill_uint(self.arr, val)
        else:
            self._fill_by_kernel(val)

    @python_scope
    def _ndarray_to_numpy(self):
        """Converts ndarray to a numpy array.

        Returns:
            numpy.ndarray: The result numpy array, shaped like ``self.arr.total_shape()``.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415

        ndarray_to_ext_arr(self, arr)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_matrix_to_numpy(self, as_vector):
        """Converts matrix ndarray to a numpy array.

        Args:
            as_vector: Forwarded to the ``ndarray_matrix_to_ext_arr`` kernel.

        Returns:
            numpy.ndarray: The result numpy array, shaped like ``self.arr.total_shape()``.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415

        layout_is_aos = 1
        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_from_numpy(self, arr):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal
                ``self.arr.total_shape()``. Non C-contiguous input is copied first.

        Raises:
            TypeError: If ``arr`` is not a numpy array.
            ValueError: If the shapes differ.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        if tuple(self.arr.total_shape()) != tuple(arr.shape):
            # Report the same shape that was compared against.
            raise ValueError(
                f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided"
            )
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415

        ext_arr_to_ndarray(arr, self)
        impl.get_runtime().sync()

    @python_scope
    def _ndarray_matrix_from_numpy(self, arr, as_vector):
        """Loads all values from a numpy array into a matrix ndarray.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal
                ``self.arr.total_shape()``. Non C-contiguous input is copied first.
            as_vector: Forwarded to the ``ext_arr_to_ndarray_matrix`` kernel.

        Raises:
            TypeError: If ``arr`` is not a numpy array.
            ValueError: If the shapes differ.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        if tuple(self.arr.total_shape()) != tuple(arr.shape):
            raise ValueError(
                f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided"
            )
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415

        layout_is_aos = 1
        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
        impl.get_runtime().sync()

    @python_scope
    def _get_element_size(self):
        """Returns the size of one element in bytes.

        Returns:
            Size in bytes.
        """
        return self.arr.element_size()

    @python_scope
    def _get_nelement(self):
        """Returns the total number of elements.

        Returns:
            Total number of elements.
        """
        return self.arr.nelement()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another ndarray.

        The shape of the other ndarray needs to be the same as `self`.

        Args:
            other (Ndarray): The source ndarray.
        """
        assert isinstance(other, Ndarray), f"Ndarray expected, but {type(other)} provided"
        assert tuple(self.arr.shape) == tuple(
            other.arr.shape
        ), f"Mismatch shape: {tuple(self.arr.shape)} vs {tuple(other.arr.shape)}"
        from gstaichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415

        ndarray_to_ndarray(self, other)
        impl.get_runtime().sync()

    def _set_grad(self, grad: "TensorNdarray"):
        """Sets the gradient ndarray.

        Args:
            grad (Ndarray): The gradient ndarray.
        """
        self.grad = grad

    def __deepcopy__(self, memo=None):
        """Copies all elements to a new ndarray (implemented by subclasses).

        Returns:
            Ndarray: The result ndarray.
        """
        raise NotImplementedError()

    def _fill_by_kernel(self, val):
        """Fills ndarray with a specific scalar value using a ti.kernel (implemented by subclasses).

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def _pad_key(self, key):
        """Normalizes an index key for host access.

        ``None`` becomes ``()`` and a non-sequence key becomes a one-element tuple; tuples and lists
        are returned as-is.

        Raises:
            GsTaichiIndexError: If the key length differs from the rank of ``self.arr.total_shape()``.
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        if len(key) != len(self.arr.total_shape()):
            raise GsTaichiIndexError(f"{len(self.arr.total_shape())}d ndarray indexed with {len(key)}d indices: {key}")
        return key

    @python_scope
    def _initialize_host_accessor(self):
        """Creates ``self.host_accessor`` on first use, materializing the runtime beforehand."""
        if self.host_accessor is not None:
            return
        impl.get_runtime().materialize()
        self.host_accessor = NdarrayHostAccessor(self.arr)
237
+
238
+
239
class ScalarNdarray(Ndarray):
    """GsTaichi ndarray with scalar elements.

    Args:
        dtype (DataType): Data type of each value.
        arr_shape (Tuple[int]): Shape of the ndarray.
    """

    def __init__(self, dtype, arr_shape):
        super().__init__()
        self.dtype = cook_dtype(dtype)
        # Storage is created zero-filled, with Layout.NULL.
        self.arr = impl.get_runtime().prog.create_ndarray(
            self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback())
        )
        self.shape = tuple(self.arr.shape)
        # NOTE(review): element_type stores the caller's un-cooked dtype while self.dtype is cooked;
        # confirm that is what NdarrayTypeMetadata expects.
        self.element_type = dtype

    def __del__(self):
        """Releases the underlying ndarray through the runtime program, when both still exist."""
        # self.arr stays None if __init__ failed before create_ndarray returned; nothing to release.
        arr = getattr(self, "arr", None)
        if arr is None:
            return
        # These checks presumably guard against interpreter shutdown, when module globals may
        # already have been torn down.
        if impl is not None and impl.get_runtime is not None and impl.get_runtime() is not None:
            prog = impl.get_runtime()._prog
            if prog is not None:
                prog.delete_ndarray(arr)

    @property
    def element_shape(self):
        """Scalar elements have an empty element shape."""
        return ()

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessor()
        self.host_accessor.setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessor()
        return self.host_accessor.getter(*self._pad_key(key))

    @python_scope
    def to_numpy(self):
        """Converts this ndarray to a numpy array."""
        return self._ndarray_to_numpy()

    @python_scope
    def from_numpy(self, arr):
        """Loads all values from the numpy array ``arr``."""
        self._ndarray_from_numpy(arr)

    def __deepcopy__(self, memo=None):
        """Returns a new ScalarNdarray with the same dtype and shape and copied contents."""
        ret_arr = ScalarNdarray(self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    def _fill_by_kernel(self, val):
        from gstaichi._kernels import fill_ndarray  # pylint: disable=C0415

        fill_ndarray(self, val)

    def __repr__(self):
        return "<ti.ndarray>"
296
+
297
+
298
class NdarrayHostAccessor:
    """Python-scope element reader/writer for a runtime ndarray handle.

    Exposes ``getter(*key)`` and ``setter(value, *key)``. The read/write calls are selected from
    ``ndarray.element_data_type()``:

    - real types use read_float/write_float;
    - signed integer types use read_int/write_int;
    - other (unsigned) integer types use read_uint for reads and write_int for writes.
    """

    def __init__(self, ndarray):
        dtype = ndarray.element_data_type()
        floating = is_real(dtype)

        if floating:

            def getter(*key):
                return ndarray.read_float(key)

        elif is_signed(dtype):

            def getter(*key):
                return ndarray.read_int(key)

        else:

            def getter(*key):
                return ndarray.read_uint(key)

        if floating:

            def setter(value, *key):
                ndarray.write_float(key, value)

        else:
            # Signed and unsigned integers share the same write path.
            def setter(value, *key):
                ndarray.write_int(key, value)

        self.getter = getter
        self.setter = setter
325
+
326
+
327
class NdarrayHostAccess:
    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The ndarray being accessed.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).

    Attributes:
        getter: Zero-argument callable returning the addressed element.
        setter: One-argument callable storing a value at the addressed element.
    """

    def __init__(self, arr, indices_first, indices_second):
        self.ndarr = arr
        self.arr = arr.arr
        # Full coordinates: field coordinates followed by the in-element indices.
        self.indices = indices_first + indices_second

        # Both closures read self.ndarr and self.indices when called, not when constructed, and
        # make sure the owner's host accessor exists before using it.
        def getter():
            owner = self.ndarr
            owner._initialize_host_accessor()
            read = owner.host_accessor.getter
            return read(*owner._pad_key(self.indices))

        def setter(value):
            owner = self.ndarr
            owner._initialize_host_accessor()
            write = owner.host_accessor.setter
            write(value, *owner._pad_key(self.indices))

        self.getter = getter
        self.setter = setter
350
+
351
+
352
# Public API of this module; the host-accessor helper classes are not exported.
__all__ = [
    "Ndarray",
    "ScalarNdarray",
]