gstaichi 0.0.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +51 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +5 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cp313-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +122 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +83 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +366 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +7 -0
- gstaichi/lang/ast/ast_transformer.py +1351 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1259 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1386 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +784 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +10 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +21 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.0.0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.0.0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.0.0.dist-info/METADATA +97 -0
- gstaichi-0.0.0.dist-info/RECORD +154 -0
- gstaichi-0.0.0.dist-info/WHEEL +5 -0
- gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,83 @@
|
|
1
|
+
from typing import Any, Iterable, Sequence
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from gstaichi import _logging
|
6
|
+
|
7
|
+
from .._wrap_inspect import FunctionSourceInfo
|
8
|
+
from . import args_hasher, config_hasher, function_hasher
|
9
|
+
from .fast_caching_types import HashedFunctionSourceInfo
|
10
|
+
from .hash_utils import hash_iterable_strings
|
11
|
+
from .python_side_cache import PythonSideCache
|
12
|
+
|
13
|
+
|
14
|
+
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Build the fast-cache key for one kernel invocation.

    The key folds together:
    - the types of the arguments,
    - the values of arguments that args_hasher hashes by value,
    - the kernel function's own source (sub functions are NOT included),
    - the compile config (which includes arch and debug).

    Returns None, after logging a warning, when args_hasher cannot hash the arguments.
    """
    args_hash = args_hasher.hash_args(args)
    if args_hash is not None:
        kernel_hash = function_hasher.hash_kernel(kernel_source_info)
        config_hash = config_hasher.hash_compile_config()
        return hash_iterable_strings((kernel_hash, args_hash, config_hash))

    # The "[FASTCACHE][INVALID_FUNC]" prefix must stay in sync with whatever matches on it;
    # the free text after it may be edited freely.
    _logging.warn(
        f"[FASTCACHE][INVALID_FUNC] The pure function {kernel_source_info.function_name} could not be "
        "fast cached, because one or more parameter types were invalid"
    )
    return None
|
35
|
+
|
36
|
+
|
37
|
+
class CacheValue(BaseModel):
    """
    Payload stored in the python-side cache under a fast-cache key.

    It holds the hashed source infos of the functions involved. These are later
    re-validated against the current source to decide whether the key is still valid.
    """

    # one entry per function source info passed to `store`
    hashed_function_source_infos: list[HashedFunctionSourceInfo]
|
39
|
+
|
40
|
+
|
41
|
+
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Record, under `cache_key`, the hashed source infos needed to validate that key later.

    Unlike other caches, this one does not hold the value we ultimately want. The
    compiled LLVM IR lives in the C++ side cache. What is stored here lets us verify,
    before looking that IR up, that the source code behind the key has not changed,
    i.e. that the cache key is still valid.

    Args:
        cache_key: key produced by `create_cache_key`; a falsy key is a no-op.
        function_source_infos: source infos of the functions to hash and store.
    """
    if not cache_key:
        return
    side_cache = PythonSideCache()
    hashed_infos = list(function_hasher.hash_functions(function_source_infos))
    payload = CacheValue(hashed_function_source_infos=hashed_infos).json()
    side_cache.store(cache_key, payload)
|
58
|
+
|
59
|
+
|
60
|
+
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """
    Fetch the hashed function source infos stored under `cache_key`.

    Returns None when the python-side cache has no entry for the key.
    """
    raw_json = PythonSideCache().try_load(cache_key)
    if raw_json is not None:
        return CacheValue.parse_raw(raw_json).hashed_function_source_infos
    return None
|
67
|
+
|
68
|
+
|
69
|
+
def validate_cache_key(cache_key: str) -> bool:
    """
    Check whether `cache_key` is still usable.

    This loads the hashed function source infos stored for the key and checks them
    against the current source code. Returns False when nothing (or an empty list)
    is stored for the key.
    """
    hashed_infos = _try_load(cache_key)
    if hashed_infos:
        return function_hasher.validate_hashed_function_infos(hashed_infos)
    return False
|
78
|
+
|
79
|
+
|
80
|
+
def dump_stats() -> None:
    """Print a header, then ask the argument hasher and the function hasher to dump their stats."""
    print("dump stats")
    for hasher_module in (args_hasher, function_hasher):
        hasher_module.dump_stats()
|
@@ -0,0 +1,212 @@
|
|
1
|
+
import ast
|
2
|
+
import dataclasses
|
3
|
+
import inspect
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from gstaichi.lang import util
|
7
|
+
from gstaichi.lang._dataclass_util import create_flat_name
|
8
|
+
from gstaichi.lang.ast import (
|
9
|
+
ASTTransformerContext,
|
10
|
+
)
|
11
|
+
from gstaichi.lang.kernel_arguments import ArgMetadata
|
12
|
+
|
13
|
+
|
14
|
+
def _populate_struct_locals_from_params_dict(basename: str, struct_locals, struct_type) -> None:
    """
    Add to `struct_locals` every flattened leaf name reachable from dataclass `struct_type`.

    `struct_type` is a dataclass appearing in the function parameters, or one of its
    nested dataclasses. `basename` is the already-flattened name of the parent, so
    nesting is reflected in the produced names. For example, given:

        @dataclasses.dataclass
        class StructAB:
            a: int
            b: int
            struct_cd: StructCD

        @dataclasses.dataclass
        class StructCD:
            c: int
            d: int
            struct_ef: StructEF

        @dataclasses.dataclass
        class StructEF:
            e: int
            f: int

    and the parameters `def foo(struct_ab: StructAB)`, the reachable accesses are
    struct_ab.a, struct_ab.b, struct_ab.struct_cd.c, struct_ab.struct_cd.d,
    struct_ab.struct_cd.struct_ef.e and struct_ab.struct_cd.struct_ef.f.
    `struct_locals` then receives:

        __ti_struct_ab__ti_a
        __ti_struct_ab__ti_b
        __ti_struct_ab__ti_struct_cd__ti_c
        __ti_struct_ab__ti_struct_cd__ti_d
        __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_e
        __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_f

    Only leaves (non-dataclass fields) are added; nested dataclasses are recursed into.
    """
    for fld in dataclasses.fields(struct_type):
        flat = create_flat_name(basename, fld.name)
        if not dataclasses.is_dataclass(fld.type):
            struct_locals.add(flat)
            continue
        _populate_struct_locals_from_params_dict(flat, struct_locals, fld.type)
|
62
|
+
|
63
|
+
|
64
|
+
def extract_struct_locals_from_context(ctx: ASTTransformerContext) -> set[str]:
    """
    Collect the flattened names of all dataclass leaf fields reachable from the
    function's parameters, for later transformation of nodes in the AST.

    - Reads the signature of ctx.func.func.
    - For every parameter annotated with a dataclass, expands it (recursively through
      nested dataclasses) into flattened leaf names.
    - E.g. `my_struct: MyStruct`, where MyStruct contains a, b, c, becomes:
      {"__ti_my_struct__ti_a", "__ti_my_struct__ti_b", "__ti_my_struct__ti_c"}

    Parameters whose annotation is not a dataclass contribute nothing.

    Args:
        ctx: transformer context; ctx.func must not be None.

    Returns:
        The set of flattened leaf names.
    """
    struct_locals: set[str] = set()
    assert ctx.func is not None
    parameters = inspect.signature(ctx.func.func).parameters
    for param_name, parameter in parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            # Same expansion as for nested structs: the parameter name is the base of the flat names.
            _populate_struct_locals_from_params_dict(param_name, struct_locals, parameter.annotation)
    return struct_locals
|
88
|
+
|
89
|
+
|
90
|
+
def expand_func_arguments(arguments: list[ArgMetadata]) -> list[ArgMetadata]:
    """
    Used to expand arguments for @ti.func: replace dataclass-typed arguments by their
    flattened leaf fields.

    Arguments whose annotation is not a dataclass are passed through unchanged. For a
    dataclass-typed argument, each field becomes its own ArgMetadata named
    create_flat_name(parent_name, field_name). Nested dataclass fields are expanded
    recursively, so only non-dataclass leaves remain in the result.

    Args:
        arguments: argument metadata in declaration order.

    Returns:
        A new list with dataclass arguments replaced by their leaves, order preserved.
    """
    expanded_arguments: list[ArgMetadata] = []
    for argument in arguments:
        if not dataclasses.is_dataclass(argument.annotation):
            expanded_arguments.append(argument)
            continue
        for field in dataclasses.fields(argument.annotation):
            child_name = create_flat_name(argument.name, field.name)
            if dataclasses.is_dataclass(field.type):
                # NOTE(review): this default is not carried onto the leaves produced by the
                # recursive call (leaf ArgMetadata below is built without a default) — confirm intended.
                nested_arg = ArgMetadata(
                    annotation=field.type,
                    name=child_name,
                    default=argument.default,
                )
                expanded_arguments += expand_func_arguments([nested_arg])
            else:
                expanded_arguments.append(ArgMetadata(annotation=field.type, name=child_name))
    return expanded_arguments
|
116
|
+
|
117
|
+
|
118
|
+
class FlattenAttributeNameTransformer(ast.NodeTransformer):
|
119
|
+
def __init__(self, struct_locals: set[str]) -> None:
|
120
|
+
self.struct_locals = struct_locals
|
121
|
+
|
122
|
+
def visit_Attribute(self, node):
|
123
|
+
flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node)
|
124
|
+
if not flat_name or flat_name not in self.struct_locals:
|
125
|
+
return self.generic_visit(node)
|
126
|
+
return ast.copy_location(ast.Name(id=flat_name, ctx=node.ctx), node)
|
127
|
+
|
128
|
+
@staticmethod
|
129
|
+
def _flatten_attribute_name(node: ast.Attribute) -> str | None:
|
130
|
+
"""
|
131
|
+
see unpack_ast_struct_expressions docstring for more explanation
|
132
|
+
"""
|
133
|
+
if isinstance(node.value, ast.Name):
|
134
|
+
return create_flat_name(node.value.id, node.attr)
|
135
|
+
if isinstance(node.value, ast.Attribute):
|
136
|
+
child_flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node.value)
|
137
|
+
if not child_flat_name:
|
138
|
+
return None
|
139
|
+
return create_flat_name(child_flat_name, node.attr)
|
140
|
+
return None
|
141
|
+
|
142
|
+
|
143
|
+
def unpack_ast_struct_expressions(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """
    Transform nodes in the AST to flatten access to struct members.

    Each attribute chain whose flattened name is in `struct_locals` becomes a single
    ast.Name. Examples:

    # my_struct_ab.a
    Attribute(value=Name(id='my_struct_ab', ctx=Load()), attr='a', ctx=Load())
    =>
    Name(id='__ti_my_struct_ab__ti_a', ctx=Load())

    # my_struct_ab.struct_cd.d
    Attribute(
        value=Attribute(value=Name(id='my_struct_ab', ctx=Load()), attr='struct_cd', ctx=Load()),
        attr='d',
        ctx=Load())
    =>
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_d', ctx=Load())

    # my_struct_ab.struct_cd.struct_ef.f
    Attribute(
        value=Attribute(
            value=Attribute(value=Name(id='my_struct_ab', ctx=Load()), attr='struct_cd', ctx=Load()),
            attr='struct_ef',
            ctx=Load()),
        attr='f',
        ctx=Load())
    =>
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f', ctx=Load())

    Returns the transformed tree, with missing locations filled in.
    """
    flattened = FlattenAttributeNameTransformer(struct_locals=struct_locals).visit(tree)
    ast.fix_missing_locations(flattened)
    return flattened
|
193
|
+
|
194
|
+
|
195
|
+
def populate_global_vars_from_dataclass(
    param_name: str,
    param_type: Any,
    py_arg: Any,
    global_vars: dict[str, Any],
):
    """
    Store the template-typed leaf values of dataclass instance `py_arg` into `global_vars`.

    Each field is keyed by its flattened name (create_flat_name(param_name, field)).
    Nested dataclass fields are recursed into. Fields that are neither dataclasses nor
    ti templates are read but not stored.

    Args:
        param_name: flattened name of `py_arg` (the parameter name at the top level).
        param_type: dataclass type of `py_arg`.
        py_arg: the dataclass instance whose field values are read.
        global_vars: mapping updated in place.
    """
    for fld in dataclasses.fields(param_type):
        value = getattr(py_arg, fld.name)
        leaf_name = create_flat_name(param_name, fld.name)
        if dataclasses.is_dataclass(fld.type):
            populate_global_vars_from_dataclass(
                param_name=leaf_name,
                param_type=fld.type,
                py_arg=value,
                global_vars=global_vars,
            )
        elif util.is_ti_template(fld.type):
            global_vars[leaf_name] = value
|
@@ -0,0 +1,366 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Union
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi.lang import impl
|
9
|
+
from gstaichi.lang.exception import GsTaichiIndexError
|
10
|
+
from gstaichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
|
11
|
+
from gstaichi.types import primitive_types
|
12
|
+
from gstaichi.types.enums import Layout
|
13
|
+
from gstaichi.types.ndarray_type import NdarrayTypeMetadata
|
14
|
+
from gstaichi.types.utils import is_real, is_signed
|
15
|
+
|
16
|
+
# Type-checking-only names. Presumably imported lazily like this to avoid an import
# cycle with gstaichi.lang.matrix — confirm.
if TYPE_CHECKING:
    from gstaichi.lang.matrix import MatrixNdarray, VectorNdarray

    # Any concrete ndarray flavor (used in the `grad` annotations).
    TensorNdarray = Union["ScalarNdarray", VectorNdarray, MatrixNdarray]
|
20
|
+
|
21
|
+
|
22
|
+
class Ndarray:
    """GsTaichi ndarray base class.

    Subclasses (e.g. ScalarNdarray) create the backing `arr` and implement element
    access. This base class holds the shared state plus the numpy / copy / fill helpers.

    Attributes:
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the Ndarray.
        element_type: Element type; may be a scalar dtype or a tensor type.
        arr: Backing ndarray handle, set by subclasses.
        layout (Layout): Memory layout, initially Layout.AOS.
        grad: Optional gradient ndarray.
        host_accessor: Lazily created NdarrayHostAccessor for Python-scope element access.
    """

    def __init__(self):
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.arr = None
        self.layout = Layout.AOS
        self.grad: "TensorNdarray | None" = None
        # we register with runtime, in order to enable reset to work later
        impl.get_runtime().ndarrays.add(self)

    def _reset(self):
        """
        Called by runtime, when we call ti.reset(). Clears every attribute to None.
        """
        self.arr = None
        self.grad = None
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.layout = None

    def get_type(self):
        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)

    @property
    def element_shape(self):
        """Gets ndarray element shape.

        Returns:
            Tuple[Int]: Ndarray element shape.
        """
        raise NotImplementedError()

    @python_scope
    def __setitem__(self, key, value):
        """Sets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    @python_scope
    def fill(self, val):
        """Fills ndarray with a specific scalar value.

        The program's native fill_float / fill_int / fill_uint is used on CUDA and x64
        for f32 / i32 / u32 scalar elements. Every other case (other arches, tensor
        element types, other dtypes) falls back to `_fill_by_kernel`.

        Args:
            val (Union[int, float]): Value to fill.
        """
        if impl.current_cfg().arch != _ti_core.Arch.cuda and impl.current_cfg().arch != _ti_core.Arch.x64:
            self._fill_by_kernel(val)
        elif _ti_core.is_tensor(self.element_type):
            self._fill_by_kernel(val)
        elif self.dtype == primitive_types.f32:
            impl.get_runtime().prog.fill_float(self.arr, val)
        elif self.dtype == primitive_types.i32:
            impl.get_runtime().prog.fill_int(self.arr, val)
        elif self.dtype == primitive_types.u32:
            impl.get_runtime().prog.fill_uint(self.arr, val)
        else:
            self._fill_by_kernel(val)

    @python_scope
    def _ndarray_to_numpy(self):
        """Converts ndarray to a numpy array.

        Returns:
            numpy.ndarray: The result numpy array, shaped like `self.arr.total_shape()`.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415

        ndarray_to_ext_arr(self, arr)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_matrix_to_numpy(self, as_vector):
        """Converts matrix ndarray to a numpy array.

        Args:
            as_vector: Forwarded to the `ndarray_matrix_to_ext_arr` kernel.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415

        layout_is_aos = 1
        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
        impl.get_runtime().sync()
        return arr

    def _check_numpy_source(self, arr):
        """Validates `arr` as a source for the from-numpy helpers.

        Args:
            arr (numpy.ndarray): The candidate source array.

        Returns:
            numpy.ndarray: `arr` itself, or a C-contiguous copy if it was not C-contiguous.

        Raises:
            TypeError: If `arr` is not a numpy.ndarray.
            ValueError: If `arr.shape` differs from `self.arr.total_shape()`.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            # Report the same shape we compared against (total_shape, not arr.shape).
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        return arr

    @python_scope
    def _ndarray_from_numpy(self, arr):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array.

        Raises:
            TypeError: If `arr` is not a numpy.ndarray.
            ValueError: If the shape does not match `self.arr.total_shape()`.
        """
        arr = self._check_numpy_source(arr)

        from gstaichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415

        ext_arr_to_ndarray(arr, self)
        impl.get_runtime().sync()

    @python_scope
    def _ndarray_matrix_from_numpy(self, arr, as_vector):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array.
            as_vector: Forwarded to the `ext_arr_to_ndarray_matrix` kernel.

        Raises:
            TypeError: If `arr` is not a numpy.ndarray.
            ValueError: If the shape does not match `self.arr.total_shape()`.
        """
        arr = self._check_numpy_source(arr)

        from gstaichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415

        layout_is_aos = 1
        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
        impl.get_runtime().sync()

    @python_scope
    def _get_element_size(self):
        """Returns the size of one element in bytes.

        Returns:
            Size in bytes.
        """
        return self.arr.element_size()

    @python_scope
    def _get_nelement(self):
        """Returns the total number of elements.

        Returns:
            Total number of elements.
        """
        return self.arr.nelement()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another ndarray.

        The shape of the other ndarray needs to be the same as `self`.

        Args:
            other (Ndarray): The source ndarray.
        """
        assert isinstance(other, Ndarray)
        assert tuple(self.arr.shape) == tuple(other.arr.shape)
        from gstaichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415

        ndarray_to_ndarray(self, other)
        impl.get_runtime().sync()

    def _set_grad(self, grad: "TensorNdarray"):
        """Sets the gradient ndarray.

        Args:
            grad (Ndarray): The gradient ndarray.
        """
        self.grad = grad

    def __deepcopy__(self, memo=None):
        """Copies all elements to a new ndarray.

        Returns:
            Ndarray: The result ndarray.
        """
        raise NotImplementedError()

    def _fill_by_kernel(self, val):
        """Fills ndarray with a specific scalar value using a ti.kernel.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def _pad_key(self, key):
        """Normalizes `key` (None, scalar, tuple or list) to a sequence of indices.

        Raises:
            GsTaichiIndexError: If the number of indices differs from the ndarray's rank.
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        if len(key) != len(self.arr.total_shape()):
            raise GsTaichiIndexError(f"{len(self.arr.total_shape())}d ndarray indexed with {len(key)}d indices: {key}")
        return key

    @python_scope
    def _initialize_host_accessor(self):
        """Creates `self.host_accessor` on first use (materializing the runtime first)."""
        if self.host_accessor:
            return
        impl.get_runtime().materialize()
        self.host_accessor = NdarrayHostAccessor(self.arr)
|
251
|
+
|
252
|
+
|
253
|
+
class ScalarNdarray(Ndarray):
    """GsTaichi ndarray whose elements are scalars.

    Args:
        dtype (DataType): Data type of each value (normalized with cook_dtype).
        arr_shape (Tuple[int]): Shape of the ndarray.
    """

    def __init__(self, dtype, arr_shape):
        super().__init__()
        self.dtype = cook_dtype(dtype)
        program = impl.get_runtime().prog
        self.arr = program.create_ndarray(
            self.dtype,
            arr_shape,
            layout=Layout.NULL,
            zero_fill=True,
            dbg_info=_ti_core.DebugInfo(get_traceback()),
        )
        self.shape = tuple(self.arr.shape)
        # element_type keeps the caller-supplied dtype, not the cooked one
        self.element_type = dtype

    def __del__(self):
        # Guard against interpreter shutdown / torn-down runtime before freeing the array.
        if impl is None or impl.get_runtime is None:
            return
        runtime = impl.get_runtime()
        if runtime is None:
            return
        program = runtime._prog
        if program is not None:
            program.delete_ndarray(self.arr)

    @property
    def element_shape(self):
        return ()

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessor()
        indices = self._pad_key(key)
        self.host_accessor.setter(value, *indices)

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessor()
        indices = self._pad_key(key)
        return self.host_accessor.getter(*indices)

    @python_scope
    def to_numpy(self):
        return self._ndarray_to_numpy()

    @python_scope
    def from_numpy(self, arr):
        self._ndarray_from_numpy(arr)

    def __deepcopy__(self, memo=None):
        clone = ScalarNdarray(self.dtype, self.shape)
        clone.copy_from(self)
        return clone

    def _fill_by_kernel(self, val):
        from gstaichi._kernels import fill_ndarray  # pylint: disable=C0415

        fill_ndarray(self, val)

    def __repr__(self):
        return "<ti.ndarray>"
|
310
|
+
|
311
|
+
|
312
|
+
class NdarrayHostAccessor:
    """Python-scope element reader/writer for a runtime ndarray handle.

    The read method is chosen from the element data type: float for real types, int for
    signed integers, and uint otherwise. Signed and unsigned integers both write through
    write_int. `getter(*key)` returns one element; `setter(value, *key)` writes one.
    """

    def __init__(self, ndarray):
        dtype = ndarray.element_data_type()
        if is_real(dtype):
            read_name, write_name = "read_float", "write_float"
        elif is_signed(dtype):
            read_name, write_name = "read_int", "write_int"
        else:
            read_name, write_name = "read_uint", "write_int"

        # Methods are looked up on each call, so the accessors stay late-bound.
        def getter(*key):
            return getattr(ndarray, read_name)(key)

        def setter(value, *key):
            getattr(ndarray, write_name)(key, value)

        self.getter = getter
        self.setter = setter
|
339
|
+
|
340
|
+
|
341
|
+
class NdarrayHostAccess:
    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.

    Both index groups are concatenated into one index tuple. The `getter()` /
    `setter(value)` closures read / write that single element through the owning
    ndarray's host accessor.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The owning ndarray.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).
    """

    def __init__(self, arr, indices_first, indices_second):
        self.ndarr = arr
        self.arr = arr.arr
        self.indices = indices_first + indices_second

        def getter():
            owner = self.ndarr
            owner._initialize_host_accessor()
            return owner.host_accessor.getter(*owner._pad_key(self.indices))

        def setter(value):
            owner = self.ndarr
            owner._initialize_host_accessor()
            owner.host_accessor.setter(value, *owner._pad_key(self.indices))

        self.getter = getter
        self.setter = setter
|
364
|
+
|
365
|
+
|
366
|
+
# Public names exported by `from gstaichi.lang._ndarray import *`.
__all__ = [
    "Ndarray",
    "ScalarNdarray",
]
|