gstaichi 0.1.25.dev0__cp313-cp313-macosx_15_0_arm64.whl → 2.0.0__cp313-cp313-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +1 -1
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +11 -41
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -1
- gstaichi/examples/minimal.py +1 -1
- gstaichi/lang/__init__.py +1 -1
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_template_mapper.py +16 -20
- gstaichi/lang/_wrap_inspect.py +27 -1
- gstaichi/lang/ast/ast_transformer.py +7 -2
- gstaichi/lang/ast/ast_transformer_utils.py +18 -13
- gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
- gstaichi/lang/field.py +0 -38
- gstaichi/lang/impl.py +25 -24
- gstaichi/lang/kernel_arguments.py +28 -30
- gstaichi/lang/kernel_impl.py +154 -200
- gstaichi/lang/matrix.py +0 -46
- gstaichi/lang/struct.py +0 -45
- gstaichi/lang/util.py +11 -80
- gstaichi/types/annotations.py +10 -5
- gstaichi/types/compound_types.py +1 -20
- gstaichi/types/ndarray_type.py +31 -11
- gstaichi/types/utils.py +0 -2
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/METADATA +2 -1
- gstaichi-2.0.0.dist-info/RECORD +177 -0
- gstaichi/__main__.py +0 -5
- gstaichi/_main.py +0 -545
- gstaichi/lang/argpack.py +0 -411
- gstaichi-0.1.25.dev0.dist-info/RECORD +0 -168
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +0 -2
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,75 @@
|
|
1
|
+
from typing import Any, Iterable, Sequence
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from .._wrap_inspect import FunctionSourceInfo
|
6
|
+
from . import args_hasher, config_hasher, function_hasher
|
7
|
+
from .fast_caching_types import HashedFunctionSourceInfo
|
8
|
+
from .hash_utils import hash_iterable_strings
|
9
|
+
from .python_side_cache import PythonSideCache
|
10
|
+
|
11
|
+
|
12
|
+
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Build the fast-cache key for one kernel invocation.

    The key combines three hashes:
    - the arguments (arg types, and the values of args that args_hasher treats as cache values)
    - the top-level kernel function source (sub functions are NOT included)
    - the compilation config (which includes arch, and debug)

    Returns None when args_hasher cannot hash the arguments, i.e. the call is not cacheable.
    """
    hashed_args = args_hasher.hash_args(args)
    if hashed_args is None:
        # Arguments could not be hashed: no cache key can be formed for this call.
        return None
    component_hashes = (
        function_hasher.hash_kernel(kernel_source_info),
        hashed_args,
        config_hasher.hash_compile_config(),
    )
    return hash_iterable_strings(component_hashes)
|
27
|
+
|
28
|
+
|
29
|
+
class CacheValue(BaseModel):
    """
    Payload stored in the Python-side cache under a cache key.

    It holds the hashed source infos that `store()` recorded, which `validate_cache_key()`
    later compares against the current source code.
    """

    # Hashed source info for each function passed to store().
    hashed_function_source_infos: list[HashedFunctionSourceInfo]
|
31
|
+
|
32
|
+
|
33
|
+
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Save, under cache_key, the hashed source infos of the given functions.

    Unlike other caches, this one does not hold the value we ultimately want. It only holds
    what is needed to verify that a cache key is still valid. Big picture:
    - the cache key is derived from the args and the top-level kernel function only
    - we want to use that key to look up LLVM IR in the C++ side cache
    - before doing that, we must check that the source code has not changed,
      i.e. that the cache key is still valid
    - the list of hashed function source infos stored here is what that check uses

    A falsy (empty) cache_key makes this a no-op.
    """
    if not cache_key:
        return
    side_cache = PythonSideCache()
    hashed_infos = list(function_hasher.hash_functions(function_source_infos))
    payload_json = CacheValue(hashed_function_source_infos=hashed_infos).json()
    side_cache.store(cache_key, payload_json)
|
50
|
+
|
51
|
+
|
52
|
+
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """
    Fetch the hashed function source infos stored under cache_key.

    Returns None when the Python-side cache has no entry for the key.
    """
    raw_json = PythonSideCache().try_load(cache_key)
    if raw_json is None:
        return None
    return CacheValue.parse_raw(raw_json).hashed_function_source_infos
|
59
|
+
|
60
|
+
|
61
|
+
def validate_cache_key(cache_key: str) -> bool:
    """
    Tell whether cache_key can still be trusted.

    Loads the hashed function source infos stored for the key, if any, and checks their
    hashes against the current source code. Returns False when nothing, or an empty list,
    is stored for the key.
    """
    stored_infos = _try_load(cache_key)
    if stored_infos:
        return function_hasher.validate_hashed_function_infos(stored_infos)
    return False
|
70
|
+
|
71
|
+
|
72
|
+
def dump_stats() -> None:
    """Print a header line, then let the args hasher and the function hasher dump their stats."""
    print("dump stats")
    for hasher_module in (args_hasher, function_hasher):
        hasher_module.dump_stats()
|
@@ -0,0 +1,212 @@
|
|
1
|
+
import ast
|
2
|
+
import dataclasses
|
3
|
+
import inspect
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from gstaichi.lang import util
|
7
|
+
from gstaichi.lang._dataclass_util import create_flat_name
|
8
|
+
from gstaichi.lang.ast import (
|
9
|
+
ASTTransformerContext,
|
10
|
+
)
|
11
|
+
from gstaichi.lang.kernel_arguments import ArgMetadata
|
12
|
+
|
13
|
+
|
14
|
+
def _populate_struct_locals_from_params_dict(basename: str, struct_locals: set[str], struct_type: Any) -> None:
    """
    Recursively add the flattened name of every leaf field of a dataclass to struct_locals.

    This is used for dataclass types that appear in function parameters, and for the
    dataclass types nested inside them.

    Every leaf (non-dataclass) field is added to struct_locals under a name built by
    create_flat_name from basename and the field name. Fields that are themselves
    dataclasses are descended into, using their flattened name as the new basename.
    basename therefore carries the names of the enclosing parameter/structs. For example,
    given (field types are illustrative):

        @dataclasses.dataclass
        class StructEF:
            e: int
            f: int

        @dataclasses.dataclass
        class StructCD:
            c: int
            d: int
            struct_ef: StructEF

        @dataclasses.dataclass
        class StructAB:
            a: int
            b: int
            struct_cd: StructCD

    ... and function parameters that look like `def foo(struct_ab: StructAB)`, all the
    member accesses that can be formed are:
        - struct_ab.a
        - struct_ab.b
        - struct_ab.struct_cd.c
        - struct_ab.struct_cd.d
        - struct_ab.struct_cd.struct_ef.e
        - struct_ab.struct_cd.struct_ef.f

    and the members added to struct_locals should be:
        - __ti_struct_ab__ti_a
        - __ti_struct_ab__ti_b
        - __ti_struct_ab__ti_struct_cd__ti_c
        - __ti_struct_ab__ti_struct_cd__ti_d
        - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_e
        - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_f

    Args:
        basename: name prefix for the fields of struct_type (the parameter name, or the
            flattened name of the enclosing struct field).
        struct_locals: set that is updated in place with the flattened leaf names.
        struct_type: the dataclass type whose fields are enumerated.
    """
    for field in dataclasses.fields(struct_type):
        child_name = create_flat_name(basename, field.name)
        if dataclasses.is_dataclass(field.type):
            # Nested struct: descend, using the flattened name as the new prefix.
            _populate_struct_locals_from_params_dict(child_name, struct_locals, field.type)
        else:
            struct_locals.add(child_name)
|
62
|
+
|
63
|
+
|
64
|
+
def extract_struct_locals_from_context(ctx: ASTTransformerContext) -> set[str]:
    """
    Collect the flattened names of all dataclass members reachable from the function's parameters.

    Provides meta information for the later transformation of nodes in the AST
    (see unpack_ast_struct_expressions).

    - Uses ctx.func.func to get the function signature.
    - Searches this for any parameters annotated with a dataclass type.
    - For each one, every leaf field (descending into nested dataclasses) is converted into
      an expanded name built with create_flat_name.
    - E.g. `my_struct: MyStruct`, where MyStruct contains a, b and c, would become:
      {"__ti_my_struct__ti_a", "__ti_my_struct__ti_b", "__ti_my_struct__ti_c"}

    Parameters whose annotation is not a dataclass contribute nothing.
    """
    assert ctx.func is not None
    struct_locals: set[str] = set()
    parameters = inspect.signature(ctx.func.func).parameters
    for param_name, parameter in parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            # The recursive helper flattens each field as create_flat_name(param_name, field.name)
            # and descends into nested dataclasses, exactly as needed here.
            _populate_struct_locals_from_params_dict(param_name, struct_locals, parameter.annotation)
    return struct_locals
|
88
|
+
|
89
|
+
|
90
|
+
def expand_func_arguments(arguments: list[ArgMetadata]) -> list[ArgMetadata]:
    """
    Expand the dataclass-typed arguments of a @ti.func into one argument per leaf field.

    Arguments whose annotation is not a dataclass are passed through unchanged. Each field of
    a dataclass-typed argument becomes an ArgMetadata named with
    create_flat_name(argument.name, field.name). Fields that are themselves dataclasses are
    expanded recursively, so no dataclass-annotated argument remains in the result.

    Ordering is preserved: the expanded fields take the place of their parent argument, in
    field declaration order.
    """
    expanded_arguments: list[ArgMetadata] = []
    for argument in arguments:
        if not dataclasses.is_dataclass(argument.annotation):
            expanded_arguments.append(argument)
            continue
        for field in dataclasses.fields(argument.annotation):
            child_name = create_flat_name(argument.name, field.name)
            if dataclasses.is_dataclass(field.type):
                # Nested struct: wrap it as an intermediate argument and expand it recursively.
                # NOTE(review): the default is only attached to this intermediate argument; the
                # leaf fields produced below carry no default, so it never reaches the returned
                # list. Confirm that this is intended.
                nested_argument = ArgMetadata(
                    annotation=field.type,
                    name=child_name,
                    default=argument.default,
                )
                expanded_arguments.extend(expand_func_arguments([nested_argument]))
            else:
                expanded_arguments.append(ArgMetadata(annotation=field.type, name=child_name))
    return expanded_arguments
|
116
|
+
|
117
|
+
|
118
|
+
class FlattenAttributeNameTransformer(ast.NodeTransformer):
|
119
|
+
def __init__(self, struct_locals: set[str]) -> None:
|
120
|
+
self.struct_locals = struct_locals
|
121
|
+
|
122
|
+
def visit_Attribute(self, node):
|
123
|
+
flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node)
|
124
|
+
if not flat_name or flat_name not in self.struct_locals:
|
125
|
+
return self.generic_visit(node)
|
126
|
+
return ast.copy_location(ast.Name(id=flat_name, ctx=node.ctx), node)
|
127
|
+
|
128
|
+
@staticmethod
|
129
|
+
def _flatten_attribute_name(node: ast.Attribute) -> str | None:
|
130
|
+
"""
|
131
|
+
see unpack_ast_struct_expressions docstring for more explanation
|
132
|
+
"""
|
133
|
+
if isinstance(node.value, ast.Name):
|
134
|
+
return create_flat_name(node.value.id, node.attr)
|
135
|
+
if isinstance(node.value, ast.Attribute):
|
136
|
+
child_flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node.value)
|
137
|
+
if not child_flat_name:
|
138
|
+
return None
|
139
|
+
return create_flat_name(child_flat_name, node.attr)
|
140
|
+
return None
|
141
|
+
|
142
|
+
|
143
|
+
def unpack_ast_struct_expressions(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """
    Transform nodes in the AST to flatten access to struct members.

    Every attribute chain whose flattened name is in struct_locals is replaced with a single
    ast.Name (see FlattenAttributeNameTransformer). Missing locations are then filled in with
    ast.fix_missing_locations and the transformed tree is returned. Because
    ast.NodeTransformer rewrites child nodes in place, `tree` itself is modified.

    Examples of things we will transform/flatten:

    # my_struct_ab.a
    # Attribute(value=Name())
    Attribute(
        value=Name(id='my_struct_ab', ctx=Load()),
        attr='a',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_a
    Name(id='__ti_my_struct_ab__ti_a', ctx=Load())

    # my_struct_ab.struct_cd.d
    # Attribute(value=Attribute(value=Name()))
    Attribute(
        value=Attribute(
            value=Name(id='my_struct_ab', ctx=Load()),
            attr='struct_cd',
            ctx=Load()),
        attr='d',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_struct_cd__ti_d
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_d', ctx=Load())

    # my_struct_ab.struct_cd.struct_ef.f
    # Attribute(value=Attribute(value=Attribute(value=Name())))
    Attribute(
        value=Attribute(
            value=Attribute(
                value=Name(id='my_struct_ab', ctx=Load()),
                attr='struct_cd',
                ctx=Load()),
            attr='struct_ef',
            ctx=Load()),
        attr='f',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f', ctx=Load())
    """
    transformer = FlattenAttributeNameTransformer(struct_locals=struct_locals)
    new_tree = transformer.visit(tree)
    # Newly created Name nodes copy their location from the replaced node, but make sure
    # every node has lineno/col_offset before the tree is used further.
    ast.fix_missing_locations(new_tree)
    return new_tree
|
193
|
+
|
194
|
+
|
195
|
+
def populate_global_vars_from_dataclass(
    param_name: str,
    param_type: Any,
    py_arg: Any,
    global_vars: dict[str, Any],
):
    """
    Register the template-annotated members of a dataclass argument in global_vars.

    Walks the fields of param_type, recursing into nested dataclasses. For every field whose
    type satisfies util.is_ti_template, the value read from py_arg is stored under the
    flattened name create_flat_name(param_name, field.name). Fields that are neither
    dataclasses nor templates are skipped. global_vars is mutated in place.
    """
    for member in dataclasses.fields(param_type):
        member_value = getattr(py_arg, member.name)
        member_flat_name = create_flat_name(param_name, member.name)
        if dataclasses.is_dataclass(member.type):
            # Nested struct: its members are registered under the extended flat name.
            populate_global_vars_from_dataclass(
                param_name=member_flat_name,
                param_type=member.type,
                py_arg=member_value,
                global_vars=global_vars,
            )
            continue
        if util.is_ti_template(member.type):
            global_vars[member_flat_name] = member_value
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import weakref
|
3
|
-
from typing import Any, Union
|
3
|
+
from typing import Any, Callable, Union
|
4
4
|
|
5
5
|
import gstaichi.lang
|
6
6
|
import gstaichi.lang._ndarray
|
@@ -8,24 +8,27 @@ import gstaichi.lang._texture
|
|
8
8
|
import gstaichi.lang.expr
|
9
9
|
import gstaichi.lang.snode
|
10
10
|
from gstaichi._lib import core as _ti_core
|
11
|
+
from gstaichi.lang import _dataclass_util
|
11
12
|
from gstaichi.lang.any_array import AnyArray
|
12
|
-
from gstaichi.lang.argpack import ArgPack, ArgPackType
|
13
13
|
from gstaichi.lang.exception import (
|
14
14
|
GsTaichiRuntimeTypeError,
|
15
15
|
)
|
16
|
-
from gstaichi.lang.kernel_arguments import
|
16
|
+
from gstaichi.lang.kernel_arguments import ArgMetadata
|
17
17
|
from gstaichi.lang.matrix import MatrixType
|
18
|
-
from gstaichi.lang.util import to_gstaichi_type
|
18
|
+
from gstaichi.lang.util import is_ti_template, to_gstaichi_type
|
19
19
|
from gstaichi.types import (
|
20
20
|
ndarray_type,
|
21
21
|
sparse_matrix_builder,
|
22
22
|
template,
|
23
23
|
texture_type,
|
24
24
|
)
|
25
|
+
from gstaichi.types.enums import AutodiffMode
|
26
|
+
|
27
|
+
CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
|
28
|
+
|
25
29
|
|
26
30
|
AnnotationType = Union[
|
27
31
|
template,
|
28
|
-
ArgPackType,
|
29
32
|
"texture_type.TextureType",
|
30
33
|
"texture_type.RWTextureType",
|
31
34
|
ndarray_type.NdarrayType,
|
@@ -34,7 +37,7 @@ AnnotationType = Union[
|
|
34
37
|
]
|
35
38
|
|
36
39
|
|
37
|
-
class
|
40
|
+
class TemplateMapper:
|
38
41
|
"""
|
39
42
|
This should probably be renamed to something like FeatureMapper, or
|
40
43
|
FeatureExtractor, since:
|
@@ -46,15 +49,15 @@ class GsTaichiCallableTemplateMapper:
|
|
46
49
|
- these are returned as a heterogeneous tuple, whose contents depends on the type
|
47
50
|
"""
|
48
51
|
|
49
|
-
def __init__(self, arguments: list[
|
50
|
-
self.arguments: list[
|
52
|
+
def __init__(self, arguments: list[ArgMetadata], template_slot_locations: list[int]) -> None:
|
53
|
+
self.arguments: list[ArgMetadata] = arguments
|
51
54
|
self.num_args: int = len(arguments)
|
52
55
|
self.template_slot_locations: list[int] = template_slot_locations
|
53
56
|
self.mapping: dict[tuple[Any, ...], int] = {}
|
54
57
|
|
55
58
|
@staticmethod
|
56
|
-
def extract_arg(arg, annotation: AnnotationType, arg_name: str) -> Any:
|
57
|
-
if
|
59
|
+
def extract_arg(arg: Any, annotation: AnnotationType, arg_name: str) -> Any:
|
60
|
+
if is_ti_template(annotation):
|
58
61
|
if isinstance(arg, gstaichi.lang.snode.SNode):
|
59
62
|
return arg.ptr
|
60
63
|
if isinstance(arg, gstaichi.lang.expr.Expr):
|
@@ -62,7 +65,7 @@ class GsTaichiCallableTemplateMapper:
|
|
62
65
|
if isinstance(arg, _ti_core.ExprCxx):
|
63
66
|
return arg.get_underlying_ptr_address()
|
64
67
|
if isinstance(arg, tuple):
|
65
|
-
return tuple(
|
68
|
+
return tuple(TemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
|
66
69
|
if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
|
67
70
|
raise GsTaichiRuntimeTypeError(
|
68
71
|
"Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
|
@@ -81,19 +84,12 @@ class GsTaichiCallableTemplateMapper:
|
|
81
84
|
|
82
85
|
# [Primitive arguments] Return the value
|
83
86
|
return arg
|
84
|
-
if isinstance(annotation, ArgPackType):
|
85
|
-
if not isinstance(arg, ArgPack):
|
86
|
-
raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
|
87
|
-
return tuple(
|
88
|
-
GsTaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
|
89
|
-
for index, (name, dtype) in enumerate(annotation.members.items())
|
90
|
-
)
|
91
87
|
if dataclasses.is_dataclass(annotation):
|
92
88
|
_res_l = []
|
93
89
|
for field in dataclasses.fields(annotation):
|
94
90
|
field_value = getattr(arg, field.name)
|
95
|
-
|
96
|
-
field_extracted =
|
91
|
+
child_name = _dataclass_util.create_flat_name(arg_name, field.name)
|
92
|
+
field_extracted = TemplateMapper.extract_arg(field_value, field.type, child_name)
|
97
93
|
_res_l.append(field_extracted)
|
98
94
|
return tuple(_res_l)
|
99
95
|
if isinstance(annotation, texture_type.TextureType):
|
gstaichi/lang/_wrap_inspect.py
CHANGED
@@ -19,8 +19,10 @@ import atexit
|
|
19
19
|
import inspect
|
20
20
|
import os
|
21
21
|
import tempfile
|
22
|
+
from typing import Callable
|
22
23
|
|
23
24
|
import dill
|
25
|
+
from pydantic import BaseModel
|
24
26
|
|
25
27
|
_builtin_getfile = inspect.getfile
|
26
28
|
_builtin_findsource = inspect.findsource
|
@@ -186,4 +188,28 @@ def getsourcefile(obj):
|
|
186
188
|
return ret
|
187
189
|
|
188
190
|
|
189
|
-
|
191
|
+
class FunctionSourceInfo(BaseModel):
    """
    Immutable record of where a function's source code lives.

    get_source_info_and_src fills it with the function's __name__ and the source file path.
    It also records the inclusive range of line numbers covered by the function's source.
    """

    function_name: str
    filepath: str
    start_lineno: int  # first line of the function's source
    end_lineno: int  # last line of the function's source (inclusive)

    class Config:
        # Instances cannot be mutated after construction.
        frozen = True
|
199
|
+
|
200
|
+
|
201
|
+
def get_source_info_and_src(func: Callable) -> tuple[FunctionSourceInfo, list[str]]:
    """
    Return where func's source code lives, together with its source lines.

    The FunctionSourceInfo records func.__name__, the file reported by getsourcefile, and
    the inclusive line range covered by the lines returned by getsourcelines.
    """
    source_file = getsourcefile(func)
    func_name = func.__name__
    source_lines, first_line = getsourcelines(func)
    # Inclusive range: the last line is len(source_lines) - 1 lines after the first one.
    last_line = first_line + len(source_lines) - 1
    info = FunctionSourceInfo(
        function_name=func_name,
        filepath=source_file,
        start_lineno=first_line,
        end_lineno=last_line,
    )
    return info, source_lines
|
213
|
+
|
214
|
+
|
215
|
+
# FunctionSourceInfo is the return type of the exported get_source_info_and_src, so export it too.
__all__ = ["getsourcelines", "getsourcefile", "get_source_info_and_src", "FunctionSourceInfo"]
|
@@ -2,10 +2,11 @@
|
|
2
2
|
|
3
3
|
import ast
|
4
4
|
import collections.abc
|
5
|
+
import dataclasses
|
5
6
|
import itertools
|
6
7
|
import warnings
|
7
8
|
from ast import unparse
|
8
|
-
from typing import Any,
|
9
|
+
from typing import Any, Sequence, Type
|
9
10
|
|
10
11
|
import numpy as np
|
11
12
|
|
@@ -40,7 +41,7 @@ from gstaichi.types import primitive_types
|
|
40
41
|
from gstaichi.types.utils import is_integral
|
41
42
|
|
42
43
|
|
43
|
-
def reshape_list(flat_list: list[Any], target_shape:
|
44
|
+
def reshape_list(flat_list: list[Any], target_shape: Sequence[int]) -> list[Any]:
|
44
45
|
if len(target_shape) < 2:
|
45
46
|
return flat_list
|
46
47
|
|
@@ -645,6 +646,8 @@ class ASTTransformer(Builder):
|
|
645
646
|
|
646
647
|
node.ptr = getattr(tensor_ops, node.attr)
|
647
648
|
setattr(node, "caller", node.value.ptr)
|
649
|
+
elif dataclasses.is_dataclass(node.value.ptr):
|
650
|
+
node.ptr = next(field.type for field in dataclasses.fields(node.value.ptr))
|
648
651
|
else:
|
649
652
|
node.ptr = getattr(node.value.ptr, node.attr)
|
650
653
|
return node.ptr
|
@@ -1309,6 +1312,8 @@ build_stmt = ASTTransformer()
|
|
1309
1312
|
|
1310
1313
|
|
1311
1314
|
def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
|
1315
|
+
# TODO: Should we just make this part of ASTTransformer? Then, easier to pass around (just
|
1316
|
+
# pass the ASTTransformer object around)
|
1312
1317
|
with ctx.variable_scope_guard():
|
1313
1318
|
for stmt in stmts:
|
1314
1319
|
if ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal:
|
@@ -27,7 +27,8 @@ if TYPE_CHECKING:
|
|
27
27
|
|
28
28
|
class Builder:
|
29
29
|
def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
|
30
|
-
|
30
|
+
method_name = "build_" + node.__class__.__name__
|
31
|
+
method = getattr(self, method_name, None)
|
31
32
|
try:
|
32
33
|
if method is None:
|
33
34
|
error_msg = f'Unsupported node "{node.__class__.__name__}"'
|
@@ -155,17 +156,18 @@ class ReturnStatus(Enum):
|
|
155
156
|
class ASTTransformerContext:
|
156
157
|
def __init__(
|
157
158
|
self,
|
158
|
-
excluded_parameters
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
159
|
+
excluded_parameters,
|
160
|
+
end_lineno: int,
|
161
|
+
is_kernel: bool,
|
162
|
+
func: "Func | Kernel",
|
163
|
+
arg_features: list[tuple[Any, ...]] | None,
|
164
|
+
global_vars: dict[str, Any],
|
165
|
+
argument_data,
|
166
|
+
file: str,
|
167
|
+
src: list[str],
|
168
|
+
start_lineno: int,
|
169
|
+
ast_builder: ASTBuilder | None,
|
170
|
+
is_real_function: bool,
|
169
171
|
):
|
170
172
|
self.func = func
|
171
173
|
self.local_scopes: list[dict[str, Any]] = []
|
@@ -176,7 +178,7 @@ class ASTTransformerContext:
|
|
176
178
|
self.returns = None
|
177
179
|
self.global_vars = global_vars
|
178
180
|
self.argument_data = argument_data
|
179
|
-
self.return_data = None
|
181
|
+
self.return_data: tuple[Any, ...] | Any | None = None
|
180
182
|
self.file = file
|
181
183
|
self.src = src
|
182
184
|
self.indent = 0
|
@@ -186,6 +188,8 @@ class ASTTransformerContext:
|
|
186
188
|
else:
|
187
189
|
break
|
188
190
|
self.lineno_offset = start_lineno - 1
|
191
|
+
self.start_lineno = start_lineno
|
192
|
+
self.end_lineno = end_lineno
|
189
193
|
self.raised = False
|
190
194
|
self.non_static_control_flow_status = NonStaticControlFlowStatus()
|
191
195
|
self.static_scope_status = StaticScopeStatus()
|
@@ -194,6 +198,7 @@ class ASTTransformerContext:
|
|
194
198
|
self.visited_funcdef = False
|
195
199
|
self.is_real_function = is_real_function
|
196
200
|
self.kernel_args: list = []
|
201
|
+
self.only_parse_function_def: bool = False
|
197
202
|
|
198
203
|
# e.g.: FunctionDef, Module, Global
|
199
204
|
def variable_scope_guard(self):
|