gstaichi 1.0.1__cp312-cp312-win_amd64.whl → 2.0.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. gstaichi/CHANGELOG.md +3 -3
  2. gstaichi/_lib/core/gstaichi_python.cp312-win_amd64.pyd +0 -0
  3. gstaichi/_lib/core/gstaichi_python.pyi +13 -41
  4. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  5. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  6. gstaichi/_test_tools/__init__.py +18 -0
  7. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  8. gstaichi/_test_tools/textwrap2.py +6 -0
  9. gstaichi/_version.py +1 -1
  10. gstaichi/examples/minimal.py +1 -1
  11. gstaichi/lang/__init__.py +1 -1
  12. gstaichi/lang/_dataclass_util.py +31 -0
  13. gstaichi/lang/_fast_caching/__init__.py +3 -0
  14. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  15. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  16. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  17. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  18. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  19. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  20. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  21. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  22. gstaichi/lang/_template_mapper.py +16 -20
  23. gstaichi/lang/_wrap_inspect.py +27 -1
  24. gstaichi/lang/ast/ast_transformer.py +7 -2
  25. gstaichi/lang/ast/ast_transformer_utils.py +18 -13
  26. gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
  27. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
  28. gstaichi/lang/field.py +0 -38
  29. gstaichi/lang/impl.py +25 -24
  30. gstaichi/lang/kernel_arguments.py +28 -30
  31. gstaichi/lang/kernel_impl.py +154 -200
  32. gstaichi/lang/matrix.py +0 -46
  33. gstaichi/lang/struct.py +0 -45
  34. gstaichi/lang/util.py +11 -80
  35. gstaichi/types/annotations.py +10 -5
  36. gstaichi/types/compound_types.py +1 -20
  37. gstaichi/types/ndarray_type.py +31 -11
  38. gstaichi/types/utils.py +0 -2
  39. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  40. gstaichi-2.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  41. gstaichi-2.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  42. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  43. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  44. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  45. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  46. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  47. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  48. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/SPIRV-Tools.lib +0 -0
  49. gstaichi-2.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  50. gstaichi-2.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  51. gstaichi-2.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  52. gstaichi-2.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  53. gstaichi-2.0.0.data/data/lib/glfw3.lib +0 -0
  54. {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/METADATA +2 -1
  55. {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/RECORD +81 -63
  56. gstaichi/lang/argpack.py +0 -411
  57. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  58. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  59. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  60. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  61. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  62. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  63. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  64. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  65. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  66. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  67. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  68. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  69. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  70. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  71. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  72. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  73. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  74. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  75. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  76. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.h +0 -0
  77. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  78. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/linker.hpp +0 -0
  79. {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  80. {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/WHEEL +0 -0
  81. {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/licenses/LICENSE +0 -0
  82. {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,75 @@
1
+ from typing import Any, Iterable, Sequence
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from .._wrap_inspect import FunctionSourceInfo
6
+ from . import args_hasher, config_hasher, function_hasher
7
+ from .fast_caching_types import HashedFunctionSourceInfo
8
+ from .hash_utils import hash_iterable_strings
9
+ from .python_side_cache import PythonSideCache
10
+
11
+
12
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Build the fast-cache key for one kernel call.

    The key is a hash over three component hashes, combined in this order:
    - the kernel function hash (the kernel itself, but not its sub functions)
    - the args hash (arg types, and cache value arg values)
    - the compilation config hash (which includes arch, and debug)

    Returns None when `args_hasher.hash_args` returns None for `args`; in that
    case neither the kernel nor the config is hashed.
    """
    hashed_args = args_hasher.hash_args(args)
    if hashed_args is None:
        return None
    key_parts = (
        function_hasher.hash_kernel(kernel_source_info),
        hashed_args,
        config_hasher.hash_compile_config(),
    )
    return hash_iterable_strings(key_parts)
27
+
28
+
29
class CacheValue(BaseModel):
    """
    Payload serialized to JSON and kept in the python side cache under a cache key.
    """

    # Hashed form of the function source infos that were handed to `store`.
    hashed_function_source_infos: list[HashedFunctionSourceInfo]
31
+
32
+
33
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Record, under `cache_key`, the hashed source infos of the given functions.

    Unlike other caches, this one does not hold the value we ultimately want; it
    only holds what is needed to check that a cache key is still valid:
    - the cache key is based on the args and the top level kernel function
    - that key is meant to be used to look up LLVM IR in the C++ side cache
    - before doing so, we want to confirm the source code did not change,
      i.e. that the cache key is still valid
    - the python side cache holds the information used for that check,
      i.e. the list of function source infos

    A falsy `cache_key` makes this a no-op.
    """
    if cache_key:
        side_cache = PythonSideCache()
        hashed_infos = list(function_hasher.hash_functions(function_source_infos))
        payload = CacheValue(hashed_function_source_infos=hashed_infos).json()
        side_cache.store(cache_key, payload)
50
+
51
+
52
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """
    Fetch the hashed function source infos stored under `cache_key`.

    Returns None when the python side cache has no entry for the key; otherwise
    the stored JSON is parsed back into a CacheValue and its list is returned.
    """
    raw_json = PythonSideCache().try_load(cache_key)
    if raw_json is not None:
        return CacheValue.parse_raw(raw_json).hashed_function_source_infos
    return None
59
+
60
+
61
def validate_cache_key(cache_key: str) -> bool:
    """
    Check whether `cache_key` is still backed by unchanged source code.

    Loads the hashed function source infos stored for the key, if any, and has
    `function_hasher` check them against the current source code. A missing
    entry, or an entry with an empty list, yields False.
    """
    hashed_infos = _try_load(cache_key)
    if hashed_infos:
        return function_hasher.validate_hashed_function_infos(hashed_infos)
    return False
70
+
71
+
72
def dump_stats() -> None:
    """
    Print a header line, then let the args hasher and the function hasher dump their stats.
    """
    print("dump stats")
    for hasher_module in (args_hasher, function_hasher):
        hasher_module.dump_stats()
@@ -0,0 +1,212 @@
1
+ import ast
2
+ import dataclasses
3
+ import inspect
4
+ from typing import Any
5
+
6
+ from gstaichi.lang import util
7
+ from gstaichi.lang._dataclass_util import create_flat_name
8
+ from gstaichi.lang.ast import (
9
+ ASTTransformerContext,
10
+ )
11
+ from gstaichi.lang.kernel_arguments import ArgMetadata
12
+
13
+
14
def _populate_struct_locals_from_params_dict(basename: str, struct_locals, struct_type) -> None:
    """
    Add to struct_locals the flat name of every non-dataclass member reachable from struct_type.

    struct_type is a dataclass type that appears in the function parameters, or
    one of the dataclass types nested inside such a type. basename carries the
    name of the parent, so that nested members get the full path in their name.
    struct_locals is mutated in place (via `.add`).

    For example, given:

        @dataclasses.dataclass
        class StructAB:
            a:
            b:
            struct_cd: StructCD

        @dataclasses.dataclass
        class StructCD:
            c:
            d:
            struct_ef: StructEF

        @dataclasses.dataclass
        class StructEF:
            e:
            f:

    ... and function parameters that look like: `def foo(struct_ab: StructAB)`

    then all the variables we could form from this are:
    - struct_ab.a
    - struct_ab.b
    - struct_ab.struct_cd.c
    - struct_ab.struct_cd.d
    - struct_ab.struct_cd.struct_ef.e
    - struct_ab.struct_cd.struct_ef.f

    And the members of struct_locals should be:
    - __ti_struct_ab__ti_a
    - __ti_struct_ab__ti_b
    - __ti_struct_ab__ti_struct_cd__ti_c
    - __ti_struct_ab__ti_struct_cd__ti_d
    - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_e
    - __ti_struct_ab__ti_struct_cd__ti_struct_ef__ti_f
    """
    for member in dataclasses.fields(struct_type):
        flat_name = create_flat_name(basename, member.name)
        if not dataclasses.is_dataclass(member.type):
            # Leaf member: record its flattened name.
            struct_locals.add(flat_name)
            continue
        # Nested dataclass: descend, using the flattened name as the new base.
        _populate_struct_locals_from_params_dict(flat_name, struct_locals, member.type)
62
+
63
+
64
def extract_struct_locals_from_context(ctx: ASTTransformerContext) -> set[str]:
    """
    Provides meta information for later transformation of nodes in the AST.

    - Uses ctx.func.func to get the function signature.
    - Searches the parameters for any annotated with a dataclass type.
    - Each such parameter is expanded into the flat names of its members,
      descending into nested dataclasses.
      - E.g. my_struct: MyStruct, where MyStruct contains a, b, c, would become:
        {"__ti_my_struct__ti_a", "__ti_my_struct__ti_b", "__ti_my_struct__ti_c"}

    Returns the set of flat names; empty when no parameter is a dataclass.
    ctx.func must not be None.
    """
    assert ctx.func is not None
    struct_locals: set[str] = set()
    signature = inspect.signature(ctx.func.func)
    for param_name, parameter in signature.parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            # The per-field walk (including nested dataclasses) is exactly what the
            # shared helper does, so delegate instead of duplicating the loop here.
            _populate_struct_locals_from_params_dict(param_name, struct_locals, parameter.annotation)
    return struct_locals
88
+
89
+
90
def expand_func_arguments(arguments: list[ArgMetadata]) -> list[ArgMetadata]:
    """
    Used to expand arguments for @ti.func

    Every argument whose annotation is a dataclass is replaced by one argument
    per member of that dataclass, named with the flattened name of the member;
    nested dataclasses are expanded recursively. Arguments whose annotation is
    not a dataclass are passed through unchanged, in their original position.

    Returns a new list; the input list is not modified.
    """
    expanded_arguments: list[ArgMetadata] = []
    for argument in arguments:
        if not dataclasses.is_dataclass(argument.annotation):
            expanded_arguments.append(argument)
            continue
        for field in dataclasses.fields(argument.annotation):
            child_name = create_flat_name(argument.name, field.name)
            if dataclasses.is_dataclass(field.type):
                # Wrap the nested dataclass as an argument of its own, then let the
                # recursive call break it down into its members.
                nested_argument = ArgMetadata(
                    annotation=field.type,
                    name=child_name,
                    default=argument.default,
                )
                expanded_arguments.extend(expand_func_arguments([nested_argument]))
            else:
                expanded_arguments.append(
                    ArgMetadata(
                        annotation=field.type,
                        name=child_name,
                    )
                )
    return expanded_arguments
116
+
117
+
118
class FlattenAttributeNameTransformer(ast.NodeTransformer):
    """
    Rewrites attribute accesses whose flattened name is in `struct_locals` into plain names.
    """

    def __init__(self, struct_locals: set[str]) -> None:
        # Flat names that are eligible for replacement.
        self.struct_locals = struct_locals

    def visit_Attribute(self, node):
        """
        Replace `node` with a Name node when its flattened name is a known struct local;
        otherwise keep the node and continue visiting its children.
        """
        flat_name = FlattenAttributeNameTransformer._flatten_attribute_name(node)
        if flat_name and flat_name in self.struct_locals:
            replacement = ast.Name(id=flat_name, ctx=node.ctx)
            return ast.copy_location(replacement, node)
        return self.generic_visit(node)

    @staticmethod
    def _flatten_attribute_name(node: ast.Attribute) -> str | None:
        """
        Flatten a chain of attribute accesses rooted at a plain name into one flat name.

        Returns None when the chain is not rooted at a plain name.
        See the unpack_ast_struct_expressions docstring for more explanation.
        """
        base = node.value
        if isinstance(base, ast.Name):
            prefix = base.id
        elif isinstance(base, ast.Attribute):
            # Flatten the inner part of the chain first; give up if that fails.
            prefix = FlattenAttributeNameTransformer._flatten_attribute_name(base)
            if not prefix:
                return None
        else:
            return None
        return create_flat_name(prefix, node.attr)
141
+
142
+
143
def unpack_ast_struct_expressions(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """
    Transform nodes in the AST, to flatten access to struct members.

    Only attribute chains whose flattened name is in struct_locals are rewritten.
    Missing locations on the resulting tree are filled in before it is returned.

    Examples of things we will transform/flatten:

    # my_struct_ab.a
    # Attribute(value=Name())
    Attribute(
        value=Name(id='my_struct_ab', ctx=Load()),
        attr='a',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_a
    Name(id='__ti_my_struct_ab__ti_a', ctx=Load())

    # my_struct_ab.struct_cd.d
    # Attribute(value=Attribute(value=Name()))
    Attribute(
        value=Attribute(
            value=Name(id='my_struct_ab', ctx=Load()),
            attr='struct_cd',
            ctx=Load()),
        attr='d',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_struct_cd__ti_d
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_d', ctx=Load())

    # my_struct_ab.struct_cd.struct_ef.f
    # Attribute(value=Attribute(value=Attribute(value=Name())))
    Attribute(
        value=Attribute(
            value=Attribute(
                value=Name(id='my_struct_ab', ctx=Load()),
                attr='struct_cd',
                ctx=Load()),
            attr='struct_ef',
            ctx=Load()),
        attr='f',
        ctx=Load())
    =>
    # __ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f
    Name(id='__ti_my_struct_ab__ti_struct_cd__ti_struct_ef__ti_f', ctx=Load())
    """
    flattened_tree = FlattenAttributeNameTransformer(struct_locals=struct_locals).visit(tree)
    # fix_missing_locations hands back the node it was given.
    return ast.fix_missing_locations(flattened_tree)
193
+
194
+
195
def populate_global_vars_from_dataclass(
    param_name: str,
    param_type: Any,
    py_arg: Any,
    global_vars: dict[str, Any],
):
    """
    Copy the template-typed members of a dataclass argument into global_vars.

    Walks the fields of param_type, reading each value from py_arg. Members
    whose type is itself a dataclass are walked recursively; members whose type
    is a ti template (per `util.is_ti_template`) are stored in global_vars under
    their flattened name. Members of any other type are left out.
    global_vars is mutated in place.
    """
    for member in dataclasses.fields(param_type):
        member_value = getattr(py_arg, member.name)
        member_flat_name = create_flat_name(param_name, member.name)
        member_type = member.type
        if dataclasses.is_dataclass(member_type):
            populate_global_vars_from_dataclass(
                param_name=member_flat_name,
                param_type=member_type,
                py_arg=member_value,
                global_vars=global_vars,
            )
            continue
        if util.is_ti_template(member_type):
            global_vars[member_flat_name] = member_value
@@ -1,6 +1,6 @@
1
1
  import dataclasses
2
2
  import weakref
3
- from typing import Any, Union
3
+ from typing import Any, Callable, Union
4
4
 
5
5
  import gstaichi.lang
6
6
  import gstaichi.lang._ndarray
@@ -8,24 +8,27 @@ import gstaichi.lang._texture
8
8
  import gstaichi.lang.expr
9
9
  import gstaichi.lang.snode
10
10
  from gstaichi._lib import core as _ti_core
11
+ from gstaichi.lang import _dataclass_util
11
12
  from gstaichi.lang.any_array import AnyArray
12
- from gstaichi.lang.argpack import ArgPack, ArgPackType
13
13
  from gstaichi.lang.exception import (
14
14
  GsTaichiRuntimeTypeError,
15
15
  )
16
- from gstaichi.lang.kernel_arguments import KernelArgument
16
+ from gstaichi.lang.kernel_arguments import ArgMetadata
17
17
  from gstaichi.lang.matrix import MatrixType
18
- from gstaichi.lang.util import to_gstaichi_type
18
+ from gstaichi.lang.util import is_ti_template, to_gstaichi_type
19
19
  from gstaichi.types import (
20
20
  ndarray_type,
21
21
  sparse_matrix_builder,
22
22
  template,
23
23
  texture_type,
24
24
  )
25
+ from gstaichi.types.enums import AutodiffMode
26
+
27
+ CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
28
+
25
29
 
26
30
  AnnotationType = Union[
27
31
  template,
28
- ArgPackType,
29
32
  "texture_type.TextureType",
30
33
  "texture_type.RWTextureType",
31
34
  ndarray_type.NdarrayType,
@@ -34,7 +37,7 @@ AnnotationType = Union[
34
37
  ]
35
38
 
36
39
 
37
- class GsTaichiCallableTemplateMapper:
40
+ class TemplateMapper:
38
41
  """
39
42
  This should probably be renamed to something like FeatureMapper, or
40
43
  FeatureExtractor, since:
@@ -46,15 +49,15 @@ class GsTaichiCallableTemplateMapper:
46
49
  - these are returned as a heterogeneous tuple, whose contents depends on the type
47
50
  """
48
51
 
49
- def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
50
- self.arguments: list[KernelArgument] = arguments
52
+ def __init__(self, arguments: list[ArgMetadata], template_slot_locations: list[int]) -> None:
53
+ self.arguments: list[ArgMetadata] = arguments
51
54
  self.num_args: int = len(arguments)
52
55
  self.template_slot_locations: list[int] = template_slot_locations
53
56
  self.mapping: dict[tuple[Any, ...], int] = {}
54
57
 
55
58
  @staticmethod
56
- def extract_arg(arg, annotation: AnnotationType, arg_name: str) -> Any:
57
- if annotation == template or isinstance(annotation, template):
59
+ def extract_arg(arg: Any, annotation: AnnotationType, arg_name: str) -> Any:
60
+ if is_ti_template(annotation):
58
61
  if isinstance(arg, gstaichi.lang.snode.SNode):
59
62
  return arg.ptr
60
63
  if isinstance(arg, gstaichi.lang.expr.Expr):
@@ -62,7 +65,7 @@ class GsTaichiCallableTemplateMapper:
62
65
  if isinstance(arg, _ti_core.ExprCxx):
63
66
  return arg.get_underlying_ptr_address()
64
67
  if isinstance(arg, tuple):
65
- return tuple(GsTaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
68
+ return tuple(TemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
66
69
  if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
67
70
  raise GsTaichiRuntimeTypeError(
68
71
  "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
@@ -81,19 +84,12 @@ class GsTaichiCallableTemplateMapper:
81
84
 
82
85
  # [Primitive arguments] Return the value
83
86
  return arg
84
- if isinstance(annotation, ArgPackType):
85
- if not isinstance(arg, ArgPack):
86
- raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
87
- return tuple(
88
- GsTaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
89
- for index, (name, dtype) in enumerate(annotation.members.items())
90
- )
91
87
  if dataclasses.is_dataclass(annotation):
92
88
  _res_l = []
93
89
  for field in dataclasses.fields(annotation):
94
90
  field_value = getattr(arg, field.name)
95
- arg_name = f"__ti_{arg_name}_{field.name}"
96
- field_extracted = GsTaichiCallableTemplateMapper.extract_arg(field_value, field.type, arg_name)
91
+ child_name = _dataclass_util.create_flat_name(arg_name, field.name)
92
+ field_extracted = TemplateMapper.extract_arg(field_value, field.type, child_name)
97
93
  _res_l.append(field_extracted)
98
94
  return tuple(_res_l)
99
95
  if isinstance(annotation, texture_type.TextureType):
@@ -19,8 +19,10 @@ import atexit
19
19
  import inspect
20
20
  import os
21
21
  import tempfile
22
+ from typing import Callable
22
23
 
23
24
  import dill
25
+ from pydantic import BaseModel
24
26
 
25
27
  _builtin_getfile = inspect.getfile
26
28
  _builtin_findsource = inspect.findsource
@@ -186,4 +188,28 @@ def getsourcefile(obj):
186
188
  return ret
187
189
 
188
190
 
189
- __all__ = ["getsourcelines", "getsourcefile"]
191
class FunctionSourceInfo(BaseModel):
    """
    Identifies where a function's source lives: its name, file, and line span.
    """

    # func.__name__ of the function
    function_name: str
    # path of the source file, as reported by getsourcefile
    filepath: str
    # line number of the first source line, as reported by getsourcelines
    start_lineno: int
    # line number of the last source line (start_lineno + number of source lines - 1)
    end_lineno: int

    class Config:
        # instances are frozen after construction
        frozen = True
199
+
200
+
201
def get_source_info_and_src(func: Callable) -> tuple[FunctionSourceInfo, list[str]]:
    """
    Return the source location info of `func`, together with its source lines.

    The end line number is derived from the start line number and the number of
    source lines, so it is the line number of the last source line.
    """
    source_path = getsourcefile(func)
    func_name = func.__name__
    src_lines, first_lineno = getsourcelines(func)
    source_info = FunctionSourceInfo(
        function_name=func_name,
        filepath=source_path,
        start_lineno=first_lineno,
        end_lineno=first_lineno + len(src_lines) - 1,
    )
    return source_info, src_lines
213
+
214
+
215
+ __all__ = ["getsourcelines", "getsourcefile", "get_source_info_and_src"]
@@ -2,10 +2,11 @@
2
2
 
3
3
  import ast
4
4
  import collections.abc
5
+ import dataclasses
5
6
  import itertools
6
7
  import warnings
7
8
  from ast import unparse
8
- from typing import Any, Iterable, Type
9
+ from typing import Any, Sequence, Type
9
10
 
10
11
  import numpy as np
11
12
 
@@ -40,7 +41,7 @@ from gstaichi.types import primitive_types
40
41
  from gstaichi.types.utils import is_integral
41
42
 
42
43
 
43
- def reshape_list(flat_list: list[Any], target_shape: Iterable[int]) -> list[Any]:
44
+ def reshape_list(flat_list: list[Any], target_shape: Sequence[int]) -> list[Any]:
44
45
  if len(target_shape) < 2:
45
46
  return flat_list
46
47
 
@@ -645,6 +646,8 @@ class ASTTransformer(Builder):
645
646
 
646
647
  node.ptr = getattr(tensor_ops, node.attr)
647
648
  setattr(node, "caller", node.value.ptr)
649
+ elif dataclasses.is_dataclass(node.value.ptr):
650
+ node.ptr = next(field.type for field in dataclasses.fields(node.value.ptr))
648
651
  else:
649
652
  node.ptr = getattr(node.value.ptr, node.attr)
650
653
  return node.ptr
@@ -1309,6 +1312,8 @@ build_stmt = ASTTransformer()
1309
1312
 
1310
1313
 
1311
1314
  def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
1315
+ # TODO: Should we just make this part of ASTTransformer? Then, easier to pass around (just
1316
+ # pass the ASTTransformer object around)
1312
1317
  with ctx.variable_scope_guard():
1313
1318
  for stmt in stmts:
1314
1319
  if ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal:
@@ -27,7 +27,8 @@ if TYPE_CHECKING:
27
27
 
28
28
  class Builder:
29
29
  def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
30
- method = getattr(self, "build_" + node.__class__.__name__, None)
30
+ method_name = "build_" + node.__class__.__name__
31
+ method = getattr(self, method_name, None)
31
32
  try:
32
33
  if method is None:
33
34
  error_msg = f'Unsupported node "{node.__class__.__name__}"'
@@ -155,17 +156,18 @@ class ReturnStatus(Enum):
155
156
  class ASTTransformerContext:
156
157
  def __init__(
157
158
  self,
158
- excluded_parameters=(),
159
- is_kernel: bool = True,
160
- func: "Func | Kernel | None" = None,
161
- arg_features=None,
162
- global_vars: dict[str, Any] | None = None,
163
- argument_data=None,
164
- file: str | None = None,
165
- src: list[str] | None = None,
166
- start_lineno: int | None = None,
167
- ast_builder: ASTBuilder | None = None,
168
- is_real_function: bool = False,
159
+ excluded_parameters,
160
+ end_lineno: int,
161
+ is_kernel: bool,
162
+ func: "Func | Kernel",
163
+ arg_features: list[tuple[Any, ...]] | None,
164
+ global_vars: dict[str, Any],
165
+ argument_data,
166
+ file: str,
167
+ src: list[str],
168
+ start_lineno: int,
169
+ ast_builder: ASTBuilder | None,
170
+ is_real_function: bool,
169
171
  ):
170
172
  self.func = func
171
173
  self.local_scopes: list[dict[str, Any]] = []
@@ -176,7 +178,7 @@ class ASTTransformerContext:
176
178
  self.returns = None
177
179
  self.global_vars = global_vars
178
180
  self.argument_data = argument_data
179
- self.return_data = None
181
+ self.return_data: tuple[Any, ...] | Any | None = None
180
182
  self.file = file
181
183
  self.src = src
182
184
  self.indent = 0
@@ -186,6 +188,8 @@ class ASTTransformerContext:
186
188
  else:
187
189
  break
188
190
  self.lineno_offset = start_lineno - 1
191
+ self.start_lineno = start_lineno
192
+ self.end_lineno = end_lineno
189
193
  self.raised = False
190
194
  self.non_static_control_flow_status = NonStaticControlFlowStatus()
191
195
  self.static_scope_status = StaticScopeStatus()
@@ -194,6 +198,7 @@ class ASTTransformerContext:
194
198
  self.visited_funcdef = False
195
199
  self.is_real_function = is_real_function
196
200
  self.kernel_args: list = []
201
+ self.only_parse_function_def: bool = False
197
202
 
198
203
  # e.g.: FunctionDef, Module, Global
199
204
  def variable_scope_guard(self):