gstaichi 1.0.1__cp312-cp312-macosx_15_0_arm64.whl → 2.1.0__cp312-cp312-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108):
  1. gstaichi/CHANGELOG.md +1 -3
  2. gstaichi/_lib/core/gstaichi_python.cpython-312-darwin.so +0 -0
  3. gstaichi/_lib/core/gstaichi_python.pyi +11 -41
  4. gstaichi/_lib/utils.py +1 -7
  5. gstaichi/_test_tools/__init__.py +18 -0
  6. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  7. gstaichi/_test_tools/textwrap2.py +6 -0
  8. gstaichi/_version.py +1 -1
  9. gstaichi/examples/lcg_python.py +26 -0
  10. gstaichi/examples/lcg_taichi.py +34 -0
  11. gstaichi/examples/minimal.py +1 -1
  12. gstaichi/lang/__init__.py +1 -1
  13. gstaichi/lang/_dataclass_util.py +31 -0
  14. gstaichi/lang/_fast_caching/__init__.py +3 -0
  15. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  16. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  17. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  18. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  19. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  20. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  21. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  22. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  23. gstaichi/lang/_template_mapper.py +16 -20
  24. gstaichi/lang/_wrap_inspect.py +27 -1
  25. gstaichi/lang/ast/ast_transformer.py +7 -2
  26. gstaichi/lang/ast/ast_transformer_utils.py +18 -13
  27. gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
  28. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
  29. gstaichi/lang/field.py +0 -38
  30. gstaichi/lang/impl.py +25 -24
  31. gstaichi/lang/kernel_arguments.py +28 -30
  32. gstaichi/lang/kernel_impl.py +154 -200
  33. gstaichi/lang/matrix.py +0 -46
  34. gstaichi/lang/struct.py +0 -45
  35. gstaichi/lang/util.py +11 -80
  36. gstaichi/types/annotations.py +10 -5
  37. gstaichi/types/compound_types.py +1 -20
  38. gstaichi/types/ndarray_type.py +33 -11
  39. gstaichi/types/utils.py +0 -2
  40. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/METADATA +4 -3
  41. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/RECORD +107 -94
  42. gstaichi/lang/argpack.py +0 -411
  43. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/GLFW/glfw3.h +0 -0
  44. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/GLFW/glfw3native.h +0 -0
  45. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  46. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.h +0 -0
  47. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  48. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/linker.hpp +0 -0
  49. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  50. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
  51. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv.h +0 -0
  52. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
  53. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
  54. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
  55. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
  56. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
  57. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
  58. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
  59. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
  60. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
  61. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
  62. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
  63. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
  64. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
  65. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
  66. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
  67. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
  68. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
  69. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
  70. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
  71. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
  72. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
  73. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
  74. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
  75. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
  76. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
  77. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
  78. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
  79. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
  80. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
  81. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
  82. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
  83. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  84. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
  85. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
  86. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
  87. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
  88. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
  89. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  90. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
  91. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
  92. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
  93. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
  94. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
  95. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
  96. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
  97. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
  98. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
  99. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
  100. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
  101. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
  102. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
  103. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
  104. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
  105. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
  106. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/WHEEL +0 -0
  107. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/licenses/LICENSE +0 -0
  108. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ import re
9
9
  import warnings
10
10
  from ast import unparse
11
11
  from collections import ChainMap
12
+ from typing import Any
12
13
 
13
14
  import numpy as np
14
15
 
@@ -18,6 +19,7 @@ from gstaichi.lang import (
18
19
  matrix,
19
20
  )
20
21
  from gstaichi.lang import ops as ti_ops
22
+ from gstaichi.lang._dataclass_util import create_flat_name
21
23
  from gstaichi.lang.ast.ast_transformer_utils import (
22
24
  ASTTransformerContext,
23
25
  get_decorator,
@@ -34,7 +36,7 @@ from gstaichi.types import primitive_types
34
36
 
35
37
  class CallTransformer:
36
38
  @staticmethod
37
- def build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
39
+ def _build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
38
40
  from gstaichi.lang import matrix_ops # pylint: disable=C0415
39
41
 
40
42
  func = node.func.ptr
@@ -64,7 +66,7 @@ class CallTransformer:
64
66
  return False
65
67
 
66
68
  @staticmethod
67
- def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
69
+ def _build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
68
70
  func = node.func.ptr
69
71
  if id(func) in primitive_types.type_ids:
70
72
  if len(args) != 1 or keywords:
@@ -82,7 +84,7 @@ class CallTransformer:
82
84
  return False
83
85
 
84
86
  @staticmethod
85
- def is_external_func(ctx: ASTTransformerContext, func) -> bool:
87
+ def _is_external_func(ctx: ASTTransformerContext, func) -> bool:
86
88
  if ctx.is_in_static_scope(): # allow external function in static scope
87
89
  return False
88
90
  if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"): # gstaichi func/kernel
@@ -92,9 +94,9 @@ class CallTransformer:
92
94
  return True
93
95
 
94
96
  @staticmethod
95
- def warn_if_is_external_func(ctx: ASTTransformerContext, node):
97
+ def _warn_if_is_external_func(ctx: ASTTransformerContext, node):
96
98
  func = node.func.ptr
97
- if not CallTransformer.is_external_func(ctx, func):
99
+ if not CallTransformer._is_external_func(ctx, func):
98
100
  return
99
101
  name = unparse(node.func).strip()
100
102
  warnings.warn_explicit(
@@ -120,7 +122,7 @@ class CallTransformer:
120
122
  # raw_args: [1.0, 2.0]
121
123
  # raw_keywords: {'k': <ti.Expr>}
122
124
  # return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'], ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]
123
- def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
125
+ def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
124
126
  raw_brackets = re.findall(r"{(.*?)}", raw_string)
125
127
  brackets = []
126
128
  unnamed = 0
@@ -164,14 +166,18 @@ class CallTransformer:
164
166
  return args
165
167
 
166
168
  @staticmethod
167
- def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
169
+ def _expand_Call_dataclass_args(args: tuple[ast.stmt]) -> tuple[ast.stmt]:
170
+ """
171
+ We require that each node has a .ptr attribute added to it, that contains
172
+ the associated Python object
173
+ """
168
174
  args_new = []
169
175
  for arg in args:
170
176
  val = arg.ptr
171
177
  if dataclasses.is_dataclass(val):
172
178
  dataclass_type = val
173
179
  for field in dataclasses.fields(dataclass_type):
174
- child_name = f"__ti_{arg.id}_{field.name}"
180
+ child_name = create_flat_name(arg.id, field.name)
175
181
  load_ctx = ast.Load()
176
182
  arg_node = ast.Name(
177
183
  id=child_name,
@@ -181,13 +187,62 @@ class CallTransformer:
181
187
  col_offset=arg.col_offset,
182
188
  end_col_offset=arg.end_col_offset,
183
189
  )
184
- args_new.append(arg_node)
190
+ if dataclasses.is_dataclass(field.type):
191
+ arg_node.ptr = field.type
192
+ args_new.extend(CallTransformer._expand_Call_dataclass_args((arg_node,)))
193
+ else:
194
+ args_new.append(arg_node)
185
195
  else:
186
196
  args_new.append(arg)
187
197
  return tuple(args_new)
188
198
 
189
199
  @staticmethod
190
- def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts):
200
+ def _expand_Call_dataclass_kwargs(kwargs: list[ast.keyword]) -> list[ast.keyword]:
201
+ """
202
+ We require that each node has a .ptr attribute added to it, that contains
203
+ the associated Python object
204
+ """
205
+ kwargs_new = []
206
+ for i, kwarg in enumerate(kwargs):
207
+ val = kwarg.ptr[kwarg.arg]
208
+ if dataclasses.is_dataclass(val):
209
+ dataclass_type = val
210
+ for field in dataclasses.fields(dataclass_type):
211
+ src_name = create_flat_name(kwarg.value.id, field.name)
212
+ child_name = create_flat_name(kwarg.arg, field.name)
213
+ load_ctx = ast.Load()
214
+ src_node = ast.Name(
215
+ id=src_name,
216
+ ctx=load_ctx,
217
+ lineno=kwarg.lineno,
218
+ end_lineno=kwarg.end_lineno,
219
+ col_offset=kwarg.col_offset,
220
+ end_col_offset=kwarg.end_col_offset,
221
+ )
222
+ kwarg_node = ast.keyword(
223
+ arg=child_name,
224
+ value=src_node,
225
+ ctx=load_ctx,
226
+ lineno=kwarg.lineno,
227
+ end_lineno=kwarg.end_lineno,
228
+ col_offset=kwarg.col_offset,
229
+ end_col_offset=kwarg.end_col_offset,
230
+ )
231
+ if dataclasses.is_dataclass(field.type):
232
+ kwarg_node.ptr = {child_name: field.type}
233
+ kwargs_new.extend(CallTransformer._expand_Call_dataclass_kwargs([kwarg_node]))
234
+ else:
235
+ kwargs_new.append(kwarg_node)
236
+ else:
237
+ kwargs_new.append(kwarg)
238
+ return kwargs_new
239
+
240
+ @staticmethod
241
+ def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts) -> Any | None:
242
+ """
243
+ example ast:
244
+ Call(func=Name(id='f2', ctx=Load()), args=[Name(id='my_struct_ab', ctx=Load())], keywords=[])
245
+ """
191
246
  if get_decorator(ctx, node) in ["static", "static_assert"]:
192
247
  with ctx.static_scope_guard():
193
248
  build_stmt(ctx, node.func)
@@ -198,7 +253,9 @@ class CallTransformer:
198
253
  # creates variable for the dataclass itself (as well as other variables,
199
254
  # not related to dataclasses). Necessary for calling further child functions
200
255
  build_stmts(ctx, node.args)
201
- node.args = CallTransformer.expand_node_args_dataclasses(node.args)
256
+ build_stmts(ctx, node.keywords)
257
+ node.args = CallTransformer._expand_Call_dataclass_args(node.args)
258
+ node.keywords = CallTransformer._expand_Call_dataclass_kwargs(node.keywords)
202
259
  # create variables for the now-expanded dataclass members
203
260
  build_stmts(ctx, node.args)
204
261
  build_stmts(ctx, node.keywords)
@@ -223,7 +280,7 @@ class CallTransformer:
223
280
 
224
281
  if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
225
282
  raw_string = node.func.value.ptr
226
- args = CallTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
283
+ args = CallTransformer._canonicalize_formatted_string(raw_string, *args, **keywords)
227
284
  node.ptr = impl.ti_format(*args)
228
285
  return node.ptr
229
286
 
@@ -231,17 +288,17 @@ class CallTransformer:
231
288
  node.ptr = matrix.make_matrix(*args, **keywords)
232
289
  return node.ptr
233
290
 
234
- if CallTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
291
+ if CallTransformer._build_call_if_is_builtin(ctx, node, args, keywords):
235
292
  return node.ptr
236
293
 
237
- if CallTransformer.build_call_if_is_type(ctx, node, args, keywords):
294
+ if CallTransformer._build_call_if_is_type(ctx, node, args, keywords):
238
295
  return node.ptr
239
296
 
240
297
  if hasattr(node.func, "caller"):
241
298
  node.ptr = func(node.func.caller, *args, **keywords)
242
299
  return node.ptr
243
300
 
244
- CallTransformer.warn_if_is_external_func(ctx, node)
301
+ CallTransformer._warn_if_is_external_func(ctx, node)
245
302
  try:
246
303
  node.ptr = func(*args, **keywords)
247
304
  except TypeError as e:
@@ -249,7 +306,7 @@ class CallTransformer:
249
306
  error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
250
307
  func_name = getattr(func, "__name__", func.__class__.__name__)
251
308
  msg = f"TypeError when calling `{func_name}`: {error_msg}."
252
- if CallTransformer.is_external_func(ctx, node.func.ptr):
309
+ if CallTransformer._is_external_func(ctx, node.func.ptr):
253
310
  args_has_expr = any([isinstance(arg, Expr) for arg in args])
254
311
  if args_has_expr and (module == math or module == np):
255
312
  exec_str = f"from gstaichi import {func.__name__}"
@@ -4,6 +4,10 @@ import ast
4
4
  import dataclasses
5
5
  from typing import Any, Callable
6
6
 
7
+ from gstaichi._lib.core.gstaichi_python import (
8
+ BoundaryMode,
9
+ DataTypeCxx,
10
+ )
7
11
  from gstaichi.lang import (
8
12
  _ndarray,
9
13
  any_array,
@@ -13,7 +17,7 @@ from gstaichi.lang import (
13
17
  matrix,
14
18
  )
15
19
  from gstaichi.lang import ops as ti_ops
16
- from gstaichi.lang.argpack import ArgPackType
20
+ from gstaichi.lang._dataclass_util import create_flat_name
17
21
  from gstaichi.lang.ast.ast_transformer_utils import (
18
22
  ASTTransformerContext,
19
23
  )
@@ -29,153 +33,127 @@ from gstaichi.types import annotations, ndarray_type, primitive_types, texture_t
29
33
  class FunctionDefTransformer:
30
34
  @staticmethod
31
35
  def _decl_and_create_variable(
32
- ctx: ASTTransformerContext, annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
36
+ ctx: ASTTransformerContext,
37
+ annotation: Any,
38
+ name: str,
39
+ this_arg_features: tuple[tuple[Any, ...], ...] | None,
40
+ prefix_name: str,
33
41
  ) -> tuple[bool, Any]:
34
42
  full_name = prefix_name + "_" + name
35
43
  if not isinstance(annotation, primitive_types.RefType):
36
44
  ctx.kernel_args.append(name)
37
- if isinstance(annotation, ArgPackType):
38
- kernel_arguments.push_argpack_arg(name)
39
- d = {}
40
- items_to_put_in_dict = []
41
- for j, (_name, anno) in enumerate(annotation.members.items()):
42
- result, obj = FunctionDefTransformer._decl_and_create_variable(
43
- ctx, anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
44
- )
45
- if not result:
46
- d[_name] = None
47
- items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
48
- else:
49
- d[_name] = obj
50
- argpack = kernel_arguments.decl_argpack_arg(annotation, d)
51
- for item in items_to_put_in_dict:
52
- invoke_later_dict[item[0]] = argpack, item[1], *item[2]
53
- return True, argpack
54
45
  if annotation == annotations.template or isinstance(annotation, annotations.template):
46
+ assert ctx.global_vars is not None
55
47
  return True, ctx.global_vars[name]
56
48
  if isinstance(annotation, annotations.sparse_matrix_builder):
57
49
  return False, (
58
50
  kernel_arguments.decl_sparse_matrix,
59
51
  (
60
- to_gstaichi_type(arg_features),
52
+ to_gstaichi_type(this_arg_features),
61
53
  full_name,
62
54
  ),
63
55
  )
64
56
  if isinstance(annotation, ndarray_type.NdarrayType):
57
+ assert this_arg_features is not None
58
+ raw_element_type: DataTypeCxx
59
+ ndim: int
60
+ needs_grad: bool
61
+ boundary: BoundaryMode
62
+ raw_element_type, ndim, needs_grad, boundary = this_arg_features
65
63
  return False, (
66
64
  kernel_arguments.decl_ndarray_arg,
67
65
  (
68
- to_gstaichi_type(arg_features[0]),
69
- arg_features[1],
66
+ to_gstaichi_type(raw_element_type),
67
+ ndim,
70
68
  full_name,
71
- arg_features[2],
72
- arg_features[3],
69
+ needs_grad,
70
+ boundary,
73
71
  ),
74
72
  )
75
73
  if isinstance(annotation, texture_type.TextureType):
76
- return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
74
+ assert this_arg_features is not None
75
+ return False, (kernel_arguments.decl_texture_arg, (this_arg_features[0], full_name))
77
76
  if isinstance(annotation, texture_type.RWTextureType):
77
+ assert this_arg_features is not None
78
78
  return False, (
79
79
  kernel_arguments.decl_rw_texture_arg,
80
- (arg_features[0], arg_features[1], arg_features[2], full_name),
80
+ (this_arg_features[0], this_arg_features[1], this_arg_features[2], full_name),
81
81
  )
82
82
  if isinstance(annotation, MatrixType):
83
- return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
83
+ return True, kernel_arguments.decl_matrix_arg(annotation, name)
84
84
  if isinstance(annotation, StructType):
85
- return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
86
- return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)
85
+ return True, kernel_arguments.decl_struct_arg(annotation, name)
86
+ return True, kernel_arguments.decl_scalar_arg(annotation, name)
87
87
 
88
88
  @staticmethod
89
89
  def _transform_kernel_arg(
90
90
  ctx: ASTTransformerContext,
91
- invoke_later_dict: dict[str, tuple[Any, str, Callable, list[Any]]],
92
- create_variable_later: dict[str, Any],
93
91
  argument_name: str,
94
92
  argument_type: Any,
95
93
  this_arg_features: tuple[Any, ...],
96
94
  ) -> None:
97
- if isinstance(argument_type, ArgPackType):
98
- kernel_arguments.push_argpack_arg(argument_name)
99
- d = {}
100
- items_to_put_in_dict: list[tuple[str, str, Any]] = []
101
- for j, (name, anno) in enumerate(argument_type.members.items()):
102
- result, obj = FunctionDefTransformer._decl_and_create_variable(
103
- ctx, anno, name, this_arg_features[j], invoke_later_dict, "__argpack_" + name, 1
104
- )
105
- if not result:
106
- d[name] = None
107
- items_to_put_in_dict.append(("__argpack_" + name, name, obj))
108
- else:
109
- d[name] = obj
110
- argpack = kernel_arguments.decl_argpack_arg(argument_type, d)
111
- for item in items_to_put_in_dict:
112
- invoke_later_dict[item[0]] = argpack, item[1], *item[2]
113
- create_variable_later[argument_name] = argpack
114
- elif dataclasses.is_dataclass(argument_type):
115
- arg_features = this_arg_features
95
+ if dataclasses.is_dataclass(argument_type):
116
96
  ctx.create_variable(argument_name, argument_type)
117
97
  for field_idx, field in enumerate(dataclasses.fields(argument_type)):
118
- flat_name = f"__ti_{argument_name}_{field.name}"
119
- result, obj = FunctionDefTransformer._decl_and_create_variable(
120
- ctx,
121
- field.type,
122
- flat_name,
123
- arg_features[field_idx],
124
- invoke_later_dict,
125
- "",
126
- 0,
127
- )
128
- if result:
129
- ctx.create_variable(flat_name, obj)
98
+ flat_name = create_flat_name(argument_name, field.name)
99
+ # if a field is a dataclass, then feed back into process_kernel_arg recursively
100
+ if dataclasses.is_dataclass(field.type):
101
+ FunctionDefTransformer._transform_kernel_arg(
102
+ ctx,
103
+ flat_name,
104
+ field.type,
105
+ this_arg_features[field_idx],
106
+ )
130
107
  else:
131
- decl_type_func, type_args = obj
132
- obj = decl_type_func(*type_args)
133
- ctx.create_variable(flat_name, obj)
108
+ result, obj = FunctionDefTransformer._decl_and_create_variable(
109
+ ctx,
110
+ field.type,
111
+ flat_name,
112
+ this_arg_features[field_idx],
113
+ "",
114
+ )
115
+ if result:
116
+ ctx.create_variable(flat_name, obj)
117
+ else:
118
+ decl_type_func, type_args = obj
119
+ obj = decl_type_func(*type_args)
120
+ ctx.create_variable(flat_name, obj)
134
121
  else:
135
122
  result, obj = FunctionDefTransformer._decl_and_create_variable(
136
123
  ctx,
137
124
  argument_type,
138
125
  argument_name,
139
126
  this_arg_features if ctx.arg_features is not None else None,
140
- invoke_later_dict,
141
127
  "",
142
- 0,
143
128
  )
144
- if result:
145
- ctx.create_variable(argument_name, obj)
146
- else:
129
+ if not result:
147
130
  decl_type_func, type_args = obj
148
131
  obj = decl_type_func(*type_args)
149
- ctx.create_variable(argument_name, obj)
132
+ ctx.create_variable(argument_name, obj)
150
133
 
151
134
  @staticmethod
152
135
  def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
136
+ assert ctx.func is not None
137
+ assert ctx.arg_features is not None
153
138
  if node.returns is not None:
154
139
  if not isinstance(node.returns, ast.Constant):
140
+ assert ctx.func.return_type is not None
155
141
  for return_type in ctx.func.return_type:
156
142
  kernel_arguments.decl_ret(return_type)
157
- impl.get_runtime().compiling_callable.finalize_rets()
143
+ compiling_callable = impl.get_runtime().compiling_callable
144
+ assert compiling_callable is not None
145
+ compiling_callable.finalize_rets()
158
146
 
159
- invoke_later_dict: dict[str, tuple[Any, str, Any]] = dict()
160
- create_variable_later = dict()
161
- for i, arg in enumerate(args.args):
162
- argument = ctx.func.arguments[i]
147
+ for i in range(len(args.args)):
148
+ arg_meta = ctx.func.arg_metas[i]
163
149
  FunctionDefTransformer._transform_kernel_arg(
164
150
  ctx,
165
- invoke_later_dict,
166
- create_variable_later,
167
- argument.name,
168
- argument.annotation,
151
+ arg_meta.name,
152
+ arg_meta.annotation,
169
153
  ctx.arg_features[i] if ctx.arg_features is not None else (),
170
154
  )
171
155
 
172
- for k, v in invoke_later_dict.items():
173
- argpack, name, func, params = v
174
- argpack[name] = func(*params)
175
- for k, v in create_variable_later.items():
176
- ctx.create_variable(k, v)
177
-
178
- impl.get_runtime().compiling_callable.finalize_params()
156
+ compiling_callable.finalize_params()
179
157
  # remove original args
180
158
  node.args.args = []
181
159
 
@@ -186,15 +164,16 @@ class FunctionDefTransformer:
186
164
  argument_type: Any,
187
165
  data: Any,
188
166
  ) -> None:
167
+ # Template arguments are passed by reference.
189
168
  if isinstance(argument_type, annotations.template):
190
169
  ctx.create_variable(argument_name, data)
191
170
  return None
192
171
 
193
172
  if dataclasses.is_dataclass(argument_type):
194
- dataclass_type = argument_type
195
- for field in dataclasses.fields(dataclass_type):
173
+ for field in dataclasses.fields(argument_type):
174
+ flat_name = create_flat_name(argument_name, field.name)
196
175
  data_child = getattr(data, field.name)
197
- if not isinstance(
176
+ if isinstance(
198
177
  data_child,
199
178
  (
200
179
  _ndarray.ScalarNdarray,
@@ -203,33 +182,33 @@ class FunctionDefTransformer:
203
182
  any_array.AnyArray,
204
183
  ),
205
184
  ):
185
+ field.type.check_matched(data_child.get_type(), field.name)
186
+ ctx.create_variable(flat_name, data_child)
187
+ elif dataclasses.is_dataclass(data_child):
188
+ FunctionDefTransformer._transform_func_arg(
189
+ ctx,
190
+ flat_name,
191
+ field.type,
192
+ getattr(data, field.name),
193
+ )
194
+ else:
206
195
  raise GsTaichiSyntaxError(
207
- f"Argument {argument_name} of type {dataclass_type} {field.type} is not recognized."
196
+ f"Argument {field.name} of type {argument_type} {field.type} is not recognized."
208
197
  )
209
- field.type.check_matched(data_child.get_type(), field.name)
210
- var_name = f"__ti_{argument_name}_{field.name}"
211
- ctx.create_variable(var_name, data_child)
212
198
  return None
213
199
 
214
200
  # Ndarray arguments are passed by reference.
215
201
  if isinstance(argument_type, (ndarray_type.NdarrayType)):
216
202
  if not isinstance(
217
- data,
218
- (
219
- _ndarray.ScalarNdarray,
220
- matrix.VectorNdarray,
221
- matrix.MatrixNdarray,
222
- any_array.AnyArray,
223
- ),
203
+ data, (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray)
224
204
  ):
225
- raise GsTaichiSyntaxError(f"Argument {arg.arg} of type {argument_type} is not recognized.")
205
+ raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
226
206
  argument_type.check_matched(data.get_type(), argument_name)
227
207
  ctx.create_variable(argument_name, data)
228
208
  return None
229
209
 
230
210
  # Matrix arguments are passed by value.
231
211
  if isinstance(argument_type, (MatrixType)):
232
- var_name = argument_name
233
212
  # "data" is expected to be an Expr here,
234
213
  # so we simply call "impl.expr_init_func(data)" to perform:
235
214
  #
@@ -239,32 +218,31 @@ class FunctionDefTransformer:
239
218
  # We created local variable "t" - a copy of the passed-in argument "data"
240
219
  if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
241
220
  raise GsTaichiSyntaxError(
242
- f"Argument {var_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
221
+ f"Argument {argument_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
243
222
  )
244
223
 
245
224
  element_shape = data.ptr.get_rvalue_type().shape()
246
225
  if len(element_shape) != argument_type.ndim:
247
226
  raise GsTaichiSyntaxError(
248
- f"Argument {var_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
227
+ f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
249
228
  )
250
229
 
251
230
  assert argument_type.ndim > 0
252
231
  if element_shape[0] != argument_type.n:
253
232
  raise GsTaichiSyntaxError(
254
- f"Argument {var_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
233
+ f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
255
234
  )
256
235
 
257
236
  if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
258
237
  raise GsTaichiSyntaxError(
259
- f"Argument {var_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[0]}."
238
+ f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[0]}."
260
239
  )
261
240
 
262
- ctx.create_variable(var_name, impl.expr_init_func(data))
241
+ ctx.create_variable(argument_name, impl.expr_init_func(data))
263
242
  return None
264
243
 
265
244
  if id(argument_type) in primitive_types.type_ids:
266
- var_name = argument_name
267
- ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
245
+ ctx.create_variable(argument_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
268
246
  return None
269
247
  # Create a copy for non-template arguments,
270
248
  # so that they are passed by value.
@@ -274,15 +252,16 @@ class FunctionDefTransformer:
274
252
 
275
253
  @staticmethod
276
254
  def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
255
+ # pylint: disable=import-outside-toplevel
256
+ from gstaichi.lang.kernel_impl import Func
257
+
258
+ assert isinstance(ctx.func, Func)
259
+ assert ctx.argument_data is not None
277
260
  for data_i, data in enumerate(ctx.argument_data):
278
- argument = ctx.func.arguments[data_i]
279
- FunctionDefTransformer._transform_func_arg(
280
- ctx,
281
- argument.name,
282
- argument.annotation,
283
- data,
284
- )
261
+ argument = ctx.func.arg_metas[data_i]
262
+ FunctionDefTransformer._transform_func_arg(ctx, argument.name, argument.annotation, data)
285
263
 
264
+ # deal with dataclasses
286
265
  for v in ctx.func.orig_arguments:
287
266
  if dataclasses.is_dataclass(v.annotation):
288
267
  ctx.create_variable(v.name, v.annotation)
@@ -308,7 +287,12 @@ class FunctionDefTransformer:
308
287
  if ctx.is_kernel: # ti.kernel
309
288
  FunctionDefTransformer._transform_as_kernel(ctx, node, args)
310
289
 
311
- else: # ti.func
290
+ if ctx.only_parse_function_def:
291
+ return None
292
+
293
+ if not ctx.is_kernel: # ti.func
294
+ assert ctx.argument_data is not None
295
+ assert ctx.func is not None
312
296
  if ctx.is_real_function:
313
297
  FunctionDefTransformer._transform_as_kernel(ctx, node, args)
314
298
  else:
gstaichi/lang/field.py CHANGED
@@ -9,7 +9,6 @@ from gstaichi.lang.util import (
9
9
  in_python_scope,
10
10
  python_scope,
11
11
  to_numpy_type,
12
- to_paddle_type,
13
12
  to_pytorch_type,
14
13
  )
15
14
 
@@ -152,18 +151,6 @@ class Field:
152
151
  """
153
152
  raise NotImplementedError()
154
153
 
155
- @python_scope
156
- def to_paddle(self, place=None):
157
- """Converts `self` to a paddle tensor.
158
-
159
- Args:
160
- place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
161
-
162
- Returns:
163
- paddle.Tensor: The result paddle tensor.
164
- """
165
- raise NotImplementedError()
166
-
167
154
  @python_scope
168
155
  def from_numpy(self, arr):
169
156
  """Loads all elements from a numpy array.
@@ -190,17 +177,6 @@ class Field:
190
177
  """
191
178
  self._from_external_arr(arr.contiguous())
192
179
 
193
- @python_scope
194
- def from_paddle(self, arr):
195
- """Loads all elements from a paddle tensor.
196
-
197
- The shape of the paddle tensor needs to be the same as `self`.
198
-
199
- Args:
200
- arr (paddle.Tensor): The source paddle tensor.
201
- """
202
- self.from_numpy(arr)
203
-
204
180
  @python_scope
205
181
  def copy_from(self, other):
206
182
  """Copies all elements from another field.
@@ -325,20 +301,6 @@ class ScalarField(Field):
325
301
  gstaichi.lang.runtime_ops.sync()
326
302
  return arr
327
303
 
328
- @python_scope
329
- def to_paddle(self, place=None):
330
- """Converts this field to a `paddle.Tensor`."""
331
- import paddle # pylint: disable=C0415
332
-
333
- # pylint: disable=E1101
334
- # paddle.empty() doesn't support argument `place``
335
- arr = paddle.to_tensor(paddle.zeros(self.shape, to_paddle_type(self.dtype)), place=place)
336
- from gstaichi._kernels import tensor_to_ext_arr # pylint: disable=C0415
337
-
338
- tensor_to_ext_arr(self, arr)
339
- gstaichi.lang.runtime_ops.sync()
340
- return arr
341
-
342
304
  @python_scope
343
305
  def _from_external_arr(self, arr):
344
306
  if len(self.shape) != len(arr.shape):