gstaichi 0.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  11. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  12. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  13. gstaichi/_lib/utils.py +243 -0
  14. gstaichi/_logging.py +131 -0
  15. gstaichi/_snode/__init__.py +5 -0
  16. gstaichi/_snode/fields_builder.py +187 -0
  17. gstaichi/_snode/snode_tree.py +34 -0
  18. gstaichi/_test_tools/__init__.py +18 -0
  19. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  20. gstaichi/_test_tools/load_kernel_string.py +30 -0
  21. gstaichi/_test_tools/textwrap2.py +6 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +366 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +7 -0
  52. gstaichi/lang/ast/ast_transformer.py +1351 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1259 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1386 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +784 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +10 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +21 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  113. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  114. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  115. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  116. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  117. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  118. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  119. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  120. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  121. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  122. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  123. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  124. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  125. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  126. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  127. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  128. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  129. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  130. gstaichi-0.0.0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  131. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  132. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  133. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  134. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  135. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  136. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  137. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  138. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  139. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  140. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  141. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  142. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  143. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  144. gstaichi-0.0.0.data/data/lib/SPIRV-Tools.lib +0 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  149. gstaichi-0.0.0.data/data/lib/glfw3.lib +0 -0
  150. gstaichi-0.0.0.dist-info/METADATA +97 -0
  151. gstaichi-0.0.0.dist-info/RECORD +154 -0
  152. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  153. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  154. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,83 @@
1
+ from typing import Any, Iterable, Sequence
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from gstaichi import _logging
6
+
7
+ from .._wrap_inspect import FunctionSourceInfo
8
+ from . import args_hasher, config_hasher, function_hasher
9
+ from .fast_caching_types import HashedFunctionSourceInfo
10
+ from .hash_utils import hash_iterable_strings
11
+ from .python_side_cache import PythonSideCache
12
+
13
+
14
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Build the fast-cache key for one kernel invocation.

    The key combines three hashes:
    - the argument hash (argument types, plus the values of arguments whose values take part in caching),
    - the hash of the top-level kernel function (sub-functions are not included),
    - the hash of the current compile config (which includes arch and debug).

    Returns:
        The combined key, or None (after logging a warning) when the arguments could not be hashed.
    """
    hashed_args = args_hasher.hash_args(args)
    if hashed_args is None:
        # The upper-case prefix at the start of this message must not be changed without also changing
        # the corresponding text that relies on it; the free text after it can be edited freely.
        _logging.warn(
            f"[FASTCACHE][INVALID_FUNC] The pure function {kernel_source_info.function_name} could not be "
            "fast cached, because one or more parameter types were invalid"
        )
        return None
    # Tuple items are evaluated left to right: kernel hash, then the (already computed) args hash,
    # then the config hash.
    return hash_iterable_strings(
        (
            function_hasher.hash_kernel(kernel_source_info),
            hashed_args,
            config_hasher.hash_compile_config(),
        )
    )
35
+
36
+
37
class CacheValue(BaseModel):
    """
    Payload stored in the Python-side cache for one cache key.

    It holds the hashed function source infos recorded by `store()`. `validate_cache_key()` reads
    them back to check that the cache key is still valid for the current source code.
    """

    hashed_function_source_infos: list[HashedFunctionSourceInfo]
39
+
40
+
41
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Record, under `cache_key`, the hashed source infos of the given functions.

    Unlike other caches, this one does not hold the value we ultimately want. It only exists to verify
    that a cache key is still valid. The overall flow is:
    - a cache key is built from the arguments and the top-level kernel function;
    - that key is used to look up LLVM IR in the C++-side cache;
    - before doing so, we check that the source code has not changed, i.e. that the key is still valid;
    - this Python-side cache holds what that check needs: the list of hashed function source infos.

    A falsy `cache_key` (None or empty) is ignored and nothing is stored.
    """
    if cache_key:
        side_cache = PythonSideCache()
        hashed_infos = list(function_hasher.hash_functions(function_source_infos))
        side_cache.store(cache_key, CacheValue(hashed_function_source_infos=hashed_infos).json())
58
+
59
+
60
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """
    Fetch and deserialize the Python-side cache entry for `cache_key`.

    Returns:
        The stored hashed function source infos, or None when there is no entry for the key.
    """
    raw_json = PythonSideCache().try_load(cache_key)
    if raw_json is None:
        return None
    return CacheValue.parse_raw(raw_json).hashed_function_source_infos
67
+
68
+
69
def validate_cache_key(cache_key: str) -> bool:
    """
    Check whether `cache_key` is still valid for the current source code.

    The hashed function source infos stored for the key are loaded (if any) and re-checked against the
    current sources. A missing or empty entry yields False; otherwise the result of
    `function_hasher.validate_hashed_function_infos` is returned.
    """
    stored_infos = _try_load(cache_key)
    return bool(stored_infos) and function_hasher.validate_hashed_function_infos(stored_infos)
78
+
79
+
80
def dump_stats() -> None:
    """Print a header line, then delegate to the args hasher's and the function hasher's stats dumps."""
    print("dump stats")
    for hasher_module in (args_hasher, function_hasher):
        hasher_module.dump_stats()
@@ -0,0 +1,212 @@
1
+ import ast
2
+ import dataclasses
3
+ import inspect
4
+ from typing import Any
5
+
6
+ from gstaichi.lang import util
7
+ from gstaichi.lang._dataclass_util import create_flat_name
8
+ from gstaichi.lang.ast import (
9
+ ASTTransformerContext,
10
+ )
11
+ from gstaichi.lang.kernel_arguments import ArgMetadata
12
+
13
+
14
def _populate_struct_locals_from_params_dict(basename: str, struct_locals, struct_type) -> None:
    """
    Add to `struct_locals` the flattened name of every leaf member of the dataclass `struct_type`.

    We are populating struct locals from a type included in function parameters, or one of its subtypes.
    `basename` is the already-flattened name of the parent (the parameter name itself at the top
    level). It is combined with each field name through create_flat_name. Fields whose type is itself
    a dataclass are descended into; every other field contributes exactly one name.

    For example, given:

        @dataclasses.dataclass
        class StructAB:
            a: ...
            b: ...
            struct_cd: StructCD

        @dataclasses.dataclass
        class StructCD:
            c: ...
            d: ...
            struct_ef: StructEF

        @dataclasses.dataclass
        class StructEF:
            e: ...
            f: ...

    and function parameters `def foo(struct_ab: StructAB)`, the member expressions that can be formed are:
        - struct_ab.a
        - struct_ab.b
        - struct_ab.struct_cd.c
        - struct_ab.struct_cd.d
        - struct_ab.struct_cd.struct_ef.e
        - struct_ab.struct_cd.struct_ef.f

    and the members added to struct_locals should be:
        - __ti_struct_ab__ti_a
        - __ti_struct_ab__ti_b
        - __ti_struct_ab__ti_struct_cd__ti_c
        - __ti_struct_ab__ti_struct_cd__ti_d
        - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_e
        - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_f

    Args:
        basename: Flattened name of the parent struct.
        struct_locals (set[str]): Set of flattened names, updated in place.
        struct_type: Dataclass type whose fields are expanded.
    """
    # Explicit work stack of (flattened prefix, dataclass type) pairs still to expand; the result is a
    # set, so the visiting order does not matter.
    pending = [(basename, struct_type)]
    while pending:
        prefix, dataclass_type = pending.pop()
        for member in dataclasses.fields(dataclass_type):
            flat_name = create_flat_name(prefix, member.name)
            if dataclasses.is_dataclass(member.type):
                pending.append((flat_name, member.type))
            else:
                struct_locals.add(flat_name)
62
+
63
+
64
def extract_struct_locals_from_context(ctx: ASTTransformerContext) -> set[str]:
    """
    Provide meta information for the later transformation of nodes in the AST.

    - Uses ctx.func.func to get the function signature.
    - Searches its parameters for any annotated with a dataclass type.
    - Expands each such parameter into the flattened names of its leaf members. This recurses through
      nested dataclasses via _populate_struct_locals_from_params_dict.
    - E.g. `my_struct: MyStruct`, where MyStruct contains a, b, c, would become:
      {"__ti_my_struct__ti_a", "__ti_my_struct__ti_b", "__ti_my_struct__ti_c"}
      (the exact spelling is produced by create_flat_name).

    Returns:
        The set of flattened member names for all dataclass-typed parameters.
    """
    struct_locals: set[str] = set()
    assert ctx.func is not None
    parameters = inspect.signature(ctx.func.func).parameters
    for param_name, parameter in parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            # The helper combines param_name with each field name and recurses into nested dataclasses;
            # this is exactly what was previously duplicated inline here.
            _populate_struct_locals_from_params_dict(param_name, struct_locals, parameter.annotation)
    return struct_locals
88
+
89
+
90
def expand_func_arguments(arguments: list[ArgMetadata]) -> list[ArgMetadata]:
    """
    Expand the arguments of a @ti.func, replacing dataclass-typed arguments by their leaf members.

    Non-dataclass arguments are kept as-is. Each dataclass-typed argument is replaced, in field order, by
    one ArgMetadata per leaf field, named with create_flat_name; nested dataclass fields are expanded
    recursively.

    Returns:
        A new list of ArgMetadata; the input list is not modified.
    """
    expanded: list[ArgMetadata] = []
    for arg in arguments:
        if not dataclasses.is_dataclass(arg.annotation):
            expanded.append(arg)
            continue
        for member in dataclasses.fields(arg.annotation):
            flat_name = create_flat_name(arg.name, member.name)
            if dataclasses.is_dataclass(member.type):
                # The intermediate argument carries the parent's default, but the leaf arguments produced
                # by the recursive call are built without a default.
                nested_arg = ArgMetadata(
                    annotation=member.type,
                    name=flat_name,
                    default=arg.default,
                )
                expanded.extend(expand_func_arguments([nested_arg]))
            else:
                expanded.append(
                    ArgMetadata(
                        annotation=member.type,
                        name=flat_name,
                    )
                )
    return expanded
116
+
117
+
118
class FlattenAttributeNameTransformer(ast.NodeTransformer):
    """
    Rewrite attribute chains on struct parameters into flat variable names.

    An attribute chain rooted at a plain name (e.g. `a.b.c`) is flattened with create_flat_name. When the
    flattened name is in `struct_locals`, the whole Attribute node is replaced by a Name node with that id
    and the original context (load/store). Otherwise the node is kept and its children are visited, so an
    inner prefix of the chain can still be flattened.
    """

    def __init__(self, struct_locals: set[str]) -> None:
        super().__init__()
        self.struct_locals = struct_locals

    def visit_Attribute(self, node):
        flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node)
        if flat_name and flat_name in self.struct_locals:
            replacement = ast.Name(id=flat_name, ctx=node.ctx)
            return ast.copy_location(replacement, node)
        return self.generic_visit(node)

    @staticmethod
    def _flatten_attribute_name(node: ast.Attribute) -> str | None:
        """
        Return the flattened name of an attribute chain, or None if the chain is not rooted at a Name.

        For example `my_struct_ab.struct_cd.d` becomes the flattened name of ('my_struct_ab', 'struct_cd', 'd').
        """
        base = node.value
        if isinstance(base, ast.Name):
            return create_flat_name(base.id, node.attr)
        if not isinstance(base, ast.Attribute):
            # e.g. a call or subscript at the root: nothing to flatten.
            return None
        prefix = FlattenAttributeNameTransformer._flatten_attribute_name(base)
        return create_flat_name(prefix, node.attr) if prefix else None
141
+
142
+
143
def unpack_ast_struct_expressions(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """
    Transform nodes in the AST to flatten access to struct members.

    Every attribute chain rooted at a plain name whose flattened form (see create_flat_name) is in
    `struct_locals` is replaced by a single Name node. Location information missing from the new nodes
    is then filled in with ast.fix_missing_locations.

    Examples of things we will transform/flatten:

        # my_struct_ab.a
        Attribute(
            value=Name(id='my_struct_ab', ctx=Load()),
            attr='a',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_a
        Name(id='__ti_my_struct_ab__ti_a', ctx=Load())

        # my_struct_ab.struct_cd.d
        Attribute(
            value=Attribute(
                value=Name(id='my_struct_ab', ctx=Load()),
                attr='struct_cd',
                ctx=Load()),
            attr='d',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_struct_cd__ti_d
        Name(id='__ti_my_struct_ab__ti_struct_cd__ti_d', ctx=Load())

        # my_struct_ab.struct_cd.struct_ef.f
        Attribute(
            value=Attribute(
                value=Attribute(
                    value=Name(id='my_struct_ab', ctx=Load()),
                    attr='struct_cd',
                    ctx=Load()),
                attr='struct_ef',
                ctx=Load()),
            attr='f',
            ctx=Load())
        =>
        # __ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f
        Name(id='__ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f', ctx=Load())

    Args:
        tree: Module AST to transform (NodeTransformer may modify it in place).
        struct_locals: Flattened member names eligible for replacement.

    Returns:
        The transformed module.
    """
    flattened_tree = FlattenAttributeNameTransformer(struct_locals=struct_locals).visit(tree)
    # fix_missing_locations fills in lineno/col_offset on new nodes and returns the same tree.
    return ast.fix_missing_locations(flattened_tree)
193
+
194
+
195
def populate_global_vars_from_dataclass(
    param_name: str,
    param_type: Any,
    py_arg: Any,
    global_vars: dict[str, Any],
):
    """
    Store, in `global_vars`, the values of the template-annotated members of a dataclass argument.

    Walks the fields of `param_type` in order, reading each field's value from `py_arg`. Fields whose
    type is a dataclass are descended into recursively. A leaf field is recorded under its flattened
    name (built with create_flat_name) only when util.is_ti_template accepts its annotation. Other leaf
    fields are skipped.

    Args:
        param_name: Flattened name of the argument (or parent struct).
        param_type: Dataclass type of `py_arg`.
        py_arg: The Python value passed for this argument.
        global_vars: Mapping updated in place.
    """
    for member in dataclasses.fields(param_type):
        member_value = getattr(py_arg, member.name)
        member_flat_name = create_flat_name(param_name, member.name)
        if not dataclasses.is_dataclass(member.type):
            if util.is_ti_template(member.type):
                global_vars[member_flat_name] = member_value
            continue
        populate_global_vars_from_dataclass(
            param_name=member_flat_name,
            param_type=member.type,
            py_arg=member_value,
            global_vars=global_vars,
        )
@@ -0,0 +1,366 @@
1
+ # type: ignore
2
+
3
+ from typing import TYPE_CHECKING, Union
4
+
5
+ import numpy as np
6
+
7
+ from gstaichi._lib import core as _ti_core
8
+ from gstaichi.lang import impl
9
+ from gstaichi.lang.exception import GsTaichiIndexError
10
+ from gstaichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
11
+ from gstaichi.types import primitive_types
12
+ from gstaichi.types.enums import Layout
13
+ from gstaichi.types.ndarray_type import NdarrayTypeMetadata
14
+ from gstaichi.types.utils import is_real, is_signed
15
+
16
+ if TYPE_CHECKING:
17
+ from gstaichi.lang.matrix import MatrixNdarray, VectorNdarray
18
+
19
+ TensorNdarray = Union["ScalarNdarray", VectorNdarray, MatrixNdarray]
20
+
21
+
22
class Ndarray:
    """GsTaichi ndarray class.

    Base class holding the state shared by every ndarray flavour. It does not allocate storage itself.
    Subclasses (e.g. ScalarNdarray) create the runtime-side array and fill in the attributes below. They
    also implement element access, the fill kernel and deep copies.

    Attributes:
        arr: Runtime-side ndarray object (None until a subclass creates it).
        dtype (DataType): Data type of each value.
        element_type: Type of each element, as reported by get_type().
        shape (Tuple[int]): Shape of the ndarray.
        layout (Layout): Memory layout; Layout.AOS after construction.
        grad (TensorNdarray | None): Gradient ndarray attached via _set_grad(), if any.
        host_accessor: Lazily created NdarrayHostAccessor used for Python-scope element access.
    """

    def __init__(self):
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.arr = None
        self.layout = Layout.AOS
        self.grad: "TensorNdarray | None" = None
        # We register with the runtime so that ti.reset() can later call _reset() on this ndarray.
        impl.get_runtime().ndarrays.add(self)

    def _reset(self):
        """
        Called by runtime, when we call ti.reset()

        Drops every reference to runtime-side state so this object no longer points at the old program.
        """
        self.arr = None
        self.grad = None
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.layout = None

    def get_type(self):
        """Returns the NdarrayTypeMetadata built from element type, shape and whether a grad is attached."""
        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)

    @property
    def element_shape(self):
        """Gets ndarray element shape.

        Returns:
            Tuple[Int]: Ndarray element shape.
        """
        raise NotImplementedError()

    @python_scope
    def __setitem__(self, key, value):
        """Sets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    @python_scope
    def fill(self, val):
        """Fills ndarray with a specific scalar value.

        On the CUDA and x64 backends, scalar f32/i32/u32 ndarrays use the program's native fill routines.
        Every other case (other backends, tensor element types, other dtypes) goes through
        _fill_by_kernel.

        Args:
            val (Union[int, float]): Value to fill.
        """
        arch = impl.current_cfg().arch
        if arch != _ti_core.Arch.cuda and arch != _ti_core.Arch.x64:
            self._fill_by_kernel(val)
        elif _ti_core.is_tensor(self.element_type):
            self._fill_by_kernel(val)
        elif self.dtype == primitive_types.f32:
            impl.get_runtime().prog.fill_float(self.arr, val)
        elif self.dtype == primitive_types.i32:
            impl.get_runtime().prog.fill_int(self.arr, val)
        elif self.dtype == primitive_types.u32:
            impl.get_runtime().prog.fill_uint(self.arr, val)
        else:
            self._fill_by_kernel(val)

    @python_scope
    def _ndarray_to_numpy(self):
        """Converts ndarray to a numpy array.

        The result has shape `self.arr.total_shape()`; the runtime is synced before returning.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415

        ndarray_to_ext_arr(self, arr)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_matrix_to_numpy(self, as_vector):
        """Converts matrix ndarray to a numpy array.

        Args:
            as_vector: Forwarded to the ndarray_matrix_to_ext_arr kernel (presumably selects vector
                rather than matrix element handling — TODO confirm).

        Returns:
            numpy.ndarray: The result numpy array, of shape `self.arr.total_shape()`.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415

        # The layout flag passed to the conversion kernel is always AOS (1).
        layout_is_aos = 1
        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_from_numpy(self, arr):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal `self.arr.total_shape()`.
                Non C-contiguous arrays are copied to a contiguous buffer first.

        Raises:
            TypeError: If `arr` is not a numpy.ndarray.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            # Report the same shape that is actually compared against (total_shape, which includes element
            # dimensions), matching _ndarray_matrix_from_numpy.
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415

        ext_arr_to_ndarray(arr, self)
        impl.get_runtime().sync()

    @python_scope
    def _ndarray_matrix_from_numpy(self, arr, as_vector):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array; its shape must equal `self.arr.total_shape()`.
                Non C-contiguous arrays are copied to a contiguous buffer first.
            as_vector: Forwarded to the ext_arr_to_ndarray_matrix kernel (presumably selects vector rather
                than matrix element handling — TODO confirm).

        Raises:
            TypeError: If `arr` is not a numpy.ndarray.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415

        # The layout flag passed to the conversion kernel is always AOS (1).
        layout_is_aos = 1
        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
        impl.get_runtime().sync()

    @python_scope
    def _get_element_size(self):
        """Returns the size of one element in bytes.

        Returns:
            Size in bytes.
        """
        return self.arr.element_size()

    @python_scope
    def _get_nelement(self):
        """Returns the total number of elements.

        Returns:
            Total number of elements.
        """
        return self.arr.nelement()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another ndarray.

        The shape of the other ndarray needs to be the same as `self`.

        Args:
            other (Ndarray): The source ndarray.
        """
        assert isinstance(other, Ndarray)
        assert tuple(self.arr.shape) == tuple(other.arr.shape)
        from gstaichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415

        ndarray_to_ndarray(self, other)
        impl.get_runtime().sync()

    def _set_grad(self, grad: "TensorNdarray"):
        """Sets the gradient ndarray.

        Args:
            grad (Ndarray): The gradient ndarray.
        """
        self.grad = grad

    def __deepcopy__(self, memo=None):
        """Copies all elements to a new ndarray.

        Returns:
            Ndarray: The result ndarray.
        """
        raise NotImplementedError()

    def _fill_by_kernel(self, val):
        """Fills ndarray with a specific scalar value using a ti.kernel.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def _pad_key(self, key):
        """Normalizes `key` to a tuple of indices (None -> (), scalar -> 1-tuple) and checks its length.

        Raises:
            GsTaichiIndexError: If the number of indices differs from the rank of `self.arr.total_shape()`.
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        if len(key) != len(self.arr.total_shape()):
            raise GsTaichiIndexError(f"{len(self.arr.total_shape())}d ndarray indexed with {len(key)}d indices: {key}")
        return key

    @python_scope
    def _initialize_host_accessor(self):
        """Creates the host accessor on first use, materializing the runtime beforehand."""
        if self.host_accessor:
            return
        impl.get_runtime().materialize()
        self.host_accessor = NdarrayHostAccessor(self.arr)
251
+
252
+
253
class ScalarNdarray(Ndarray):
    """GsTaichi ndarray with scalar elements.

    Args:
        dtype (DataType): Data type of each value.
        arr_shape (Tuple[int]): Shape of the ndarray.
    """

    def __init__(self, dtype, arr_shape):
        super().__init__()
        self.dtype = cook_dtype(dtype)
        # Storage is requested zero-filled; the creation traceback is attached as debug info.
        self.arr = impl.get_runtime().prog.create_ndarray(
            self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback())
        )
        self.shape = tuple(self.arr.shape)
        # NOTE(review): element_type keeps the dtype as passed in, not the cooked self.dtype — confirm intended.
        self.element_type = dtype

    def __del__(self):
        # `arr` is None after _reset() or when construction failed before create_ndarray returned, and it
        # is absent entirely if __init__ never ran; in those cases there is nothing to release.
        arr = getattr(self, "arr", None)
        if arr is None:
            return
        # Module globals may already be torn down (set to None) during interpreter shutdown.
        if impl is not None and impl.get_runtime is not None and impl.get_runtime() is not None:
            prog = impl.get_runtime()._prog
            if prog is not None:
                prog.delete_ndarray(arr)

    @property
    def element_shape(self):
        """Scalar elements have an empty shape."""
        return ()

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessor()
        self.host_accessor.setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessor()
        return self.host_accessor.getter(*self._pad_key(key))

    @python_scope
    def to_numpy(self):
        """Converts this ndarray to a numpy array."""
        return self._ndarray_to_numpy()

    @python_scope
    def from_numpy(self, arr):
        """Loads all values from the numpy array `arr`."""
        self._ndarray_from_numpy(arr)

    def __deepcopy__(self, memo=None):
        ret_arr = ScalarNdarray(self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    def _fill_by_kernel(self, val):
        from gstaichi._kernels import fill_ndarray  # pylint: disable=C0415

        fill_ndarray(self, val)

    def __repr__(self):
        return "<ti.ndarray>"
310
+
311
+
312
class NdarrayHostAccessor:
    """
    Python-scope element getter/setter pair for a runtime-side ndarray.

    Which read/write methods of `ndarray` are used depends on `ndarray.element_data_type()`:
    - real types: read_float / write_float,
    - signed integers: read_int / write_int,
    - anything else: read_uint / write_int.

    `getter(*key)` returns the element at index `key`; `setter(value, *key)` stores `value` there.
    """

    def __init__(self, ndarray):
        element_dtype = ndarray.element_data_type()
        if is_real(element_dtype):
            read_name, write_name = "read_float", "write_float"
        elif is_signed(element_dtype):
            read_name, write_name = "read_int", "write_int"
        else:
            # Unsigned values are read with read_uint but still written through write_int.
            read_name, write_name = "read_uint", "write_int"

        # The method is looked up on each call, as a direct `ndarray.read_*` call would do.
        def getter(*key):
            return getattr(ndarray, read_name)(key)

        def setter(value, *key):
            getattr(ndarray, write_name)(key, value)

        self.getter = getter
        self.setter = setter
339
+
340
+
341
class NdarrayHostAccess:
    """Class for accessing a single entry of a VectorNdarray/MatrixNdarray in Python scope.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The owning ndarray.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).

    `getter()` reads and `setter(value)` writes the entry addressed by the concatenated indices. The owning
    ndarray's host accessor is initialized on every access and the key is normalized with its _pad_key.
    """

    def __init__(self, arr, indices_first, indices_second):
        self.ndarr = arr
        self.arr = arr.arr
        self.indices = indices_first + indices_second

        def _host_accessor():
            # Make sure the owning ndarray has its host accessor, then hand it back.
            self.ndarr._initialize_host_accessor()
            return self.ndarr.host_accessor

        def getter():
            return _host_accessor().getter(*self.ndarr._pad_key(self.indices))

        def setter(value):
            _host_accessor().setter(value, *self.ndarr._pad_key(self.indices))

        self.getter, self.setter = getter, setter
364
+
365
+
366
# Public names exported by a star-import of this module.
__all__ = ["Ndarray", "ScalarNdarray"]