gstaichi 1.0.1__cp310-cp310-macosx_15_0_arm64.whl → 2.1.0__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. gstaichi/CHANGELOG.md +1 -3
  2. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  3. gstaichi/_lib/core/gstaichi_python.pyi +11 -41
  4. gstaichi/_lib/utils.py +1 -7
  5. gstaichi/_test_tools/__init__.py +18 -0
  6. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  7. gstaichi/_test_tools/textwrap2.py +6 -0
  8. gstaichi/_version.py +1 -1
  9. gstaichi/examples/lcg_python.py +26 -0
  10. gstaichi/examples/lcg_taichi.py +34 -0
  11. gstaichi/examples/minimal.py +1 -1
  12. gstaichi/lang/__init__.py +1 -1
  13. gstaichi/lang/_dataclass_util.py +31 -0
  14. gstaichi/lang/_fast_caching/__init__.py +3 -0
  15. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  16. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  17. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  18. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  19. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  20. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  21. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  22. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  23. gstaichi/lang/_template_mapper.py +16 -20
  24. gstaichi/lang/_wrap_inspect.py +27 -1
  25. gstaichi/lang/ast/ast_transformer.py +7 -2
  26. gstaichi/lang/ast/ast_transformer_utils.py +18 -13
  27. gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
  28. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
  29. gstaichi/lang/field.py +0 -38
  30. gstaichi/lang/impl.py +25 -24
  31. gstaichi/lang/kernel_arguments.py +28 -30
  32. gstaichi/lang/kernel_impl.py +154 -200
  33. gstaichi/lang/matrix.py +0 -46
  34. gstaichi/lang/struct.py +0 -45
  35. gstaichi/lang/util.py +11 -80
  36. gstaichi/types/annotations.py +10 -5
  37. gstaichi/types/compound_types.py +1 -20
  38. gstaichi/types/ndarray_type.py +33 -11
  39. gstaichi/types/utils.py +0 -2
  40. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/METADATA +4 -3
  41. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/RECORD +107 -94
  42. gstaichi/lang/argpack.py +0 -411
  43. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/GLFW/glfw3.h +0 -0
  44. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/GLFW/glfw3native.h +0 -0
  45. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  46. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.h +0 -0
  47. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  48. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/linker.hpp +0 -0
  49. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  50. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
  51. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv.h +0 -0
  52. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
  53. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
  54. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
  55. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
  56. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
  57. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
  58. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
  59. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
  60. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
  61. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
  62. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
  63. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
  64. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
  65. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
  66. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
  67. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
  68. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
  69. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
  70. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
  71. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
  72. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
  73. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
  74. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
  75. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
  76. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
  77. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
  78. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
  79. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
  80. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
  81. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
  82. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
  83. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  84. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
  85. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
  86. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
  87. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
  88. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
  89. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  90. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
  91. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
  92. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
  93. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
  94. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
  95. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
  96. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
  97. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
  98. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
  99. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
  100. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
  101. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
  102. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
  103. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
  104. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
  105. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
  106. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/WHEEL +0 -0
  107. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/licenses/LICENSE +0 -0
  108. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/top_level.txt +0 -0
@@ -29,11 +29,11 @@ from gstaichi._lib.core.gstaichi_python import (
29
29
  KernelCxx,
30
30
  KernelLaunchContext,
31
31
  )
32
- from gstaichi.lang import impl, ops, runtime_ops
33
- from gstaichi.lang._template_mapper import GsTaichiCallableTemplateMapper
34
- from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
32
+ from gstaichi.lang import _kernel_impl_dataclass, impl, ops, runtime_ops
33
+ from gstaichi.lang._fast_caching import src_hasher
34
+ from gstaichi.lang._template_mapper import TemplateMapper
35
+ from gstaichi.lang._wrap_inspect import FunctionSourceInfo, get_source_info_and_src
35
36
  from gstaichi.lang.any_array import AnyArray
36
- from gstaichi.lang.argpack import ArgPack, ArgPackType
37
37
  from gstaichi.lang.ast import (
38
38
  ASTTransformerContext,
39
39
  KernelSimplicityASTChecker,
@@ -49,11 +49,11 @@ from gstaichi.lang.exception import (
49
49
  handle_exception_from_cpp,
50
50
  )
51
51
  from gstaichi.lang.expr import Expr
52
- from gstaichi.lang.kernel_arguments import KernelArgument
52
+ from gstaichi.lang.kernel_arguments import ArgMetadata
53
53
  from gstaichi.lang.matrix import MatrixType
54
54
  from gstaichi.lang.shell import _shell_pop_print
55
55
  from gstaichi.lang.struct import StructType
56
- from gstaichi.lang.util import cook_dtype, has_paddle, has_pytorch
56
+ from gstaichi.lang.util import cook_dtype, has_pytorch
57
57
  from gstaichi.types import (
58
58
  ndarray_type,
59
59
  primitive_types,
@@ -152,6 +152,7 @@ class GsTaichiCallable:
152
152
  self._adjoint: Kernel | None = None
153
153
  self.grad: Kernel | None = None
154
154
  self._is_staticmethod: bool = False
155
+ self.is_pure = False
155
156
  functools.update_wrapper(self, fn)
156
157
 
157
158
  def __call__(self, *args, **kwargs):
@@ -243,17 +244,45 @@ def pyfunc(fn: Callable) -> GsTaichiCallable:
243
244
  return gstaichi_callable
244
245
 
245
246
 
247
+ def _populate_global_vars_for_templates(
248
+ template_slot_locations: list[int],
249
+ argument_metas: list[ArgMetadata],
250
+ global_vars: dict[str, Any],
251
+ fn: Callable,
252
+ py_args: tuple[Any, ...],
253
+ ):
254
+ """
255
+ Inject template parameters into globals
256
+
257
+ Globals are being abused to store the python objects associated
258
+ with templates. We continue this approach, and in addition this function
259
+ handles injecting expanded python variables from dataclasses.
260
+ """
261
+ for i in template_slot_locations:
262
+ template_var_name = argument_metas[i].name
263
+ global_vars[template_var_name] = py_args[i]
264
+ parameters = inspect.signature(fn).parameters
265
+ for i, (parameter_name, parameter) in enumerate(parameters.items()):
266
+ if dataclasses.is_dataclass(parameter.annotation):
267
+ _kernel_impl_dataclass.populate_global_vars_from_dataclass(
268
+ parameter_name,
269
+ parameter.annotation,
270
+ py_args[i],
271
+ global_vars=global_vars,
272
+ )
273
+
274
+
246
275
  def _get_tree_and_ctx(
247
276
  self: "Func | Kernel",
248
277
  args: tuple[Any, ...],
249
278
  excluded_parameters=(),
250
279
  is_kernel: bool = True,
251
280
  arg_features=None,
252
- ast_builder: ASTBuilder | None = None,
281
+ ast_builder: "ASTBuilder | None" = None,
253
282
  is_real_function: bool = False,
283
+ current_kernel: "Kernel | None" = None,
254
284
  ) -> tuple[ast.Module, ASTTransformerContext]:
255
- file = getsourcefile(self.func)
256
- src, start_lineno = getsourcelines(self.func)
285
+ function_source_info, src = get_source_info_and_src(self.func)
257
286
  src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
258
287
  tree = ast.parse(textwrap.dedent("\n".join(src)))
259
288
 
@@ -263,17 +292,20 @@ def _get_tree_and_ctx(
263
292
  global_vars = _get_global_vars(self.func)
264
293
 
265
294
  if is_kernel or is_real_function:
266
- # inject template parameters into globals
267
- for i in self.template_slot_locations:
268
- template_var_name = self.arguments[i].name
269
- global_vars[template_var_name] = args[i]
270
- parameters = inspect.signature(self.func).parameters
271
- for arg_i, (param_name, param) in enumerate(parameters.items()):
272
- if dataclasses.is_dataclass(param.annotation):
273
- for member_field in dataclasses.fields(param.annotation):
274
- child_value = getattr(args[arg_i], member_field.name)
275
- flat_name = f"__ti_{param_name}_{member_field.name}"
276
- global_vars[flat_name] = child_value
295
+ _populate_global_vars_for_templates(
296
+ template_slot_locations=self.template_slot_locations,
297
+ argument_metas=self.arg_metas,
298
+ global_vars=global_vars,
299
+ fn=self.func,
300
+ py_args=args,
301
+ )
302
+
303
+ if current_kernel is not None: # Kernel
304
+ current_kernel.kernel_function_info = function_source_info
305
+ if current_kernel is None:
306
+ current_kernel = impl.get_runtime()._current_kernel
307
+ assert current_kernel is not None
308
+ current_kernel.visited_functions.add(function_source_info)
277
309
 
278
310
  return tree, ASTTransformerContext(
279
311
  excluded_parameters=excluded_parameters,
@@ -283,38 +315,24 @@ def _get_tree_and_ctx(
283
315
  global_vars=global_vars,
284
316
  argument_data=args,
285
317
  src=src,
286
- start_lineno=start_lineno,
287
- file=file,
318
+ start_lineno=function_source_info.start_lineno,
319
+ end_lineno=function_source_info.end_lineno,
320
+ file=function_source_info.filepath,
288
321
  ast_builder=ast_builder,
289
322
  is_real_function=is_real_function,
290
323
  )
291
324
 
292
325
 
293
- def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgument]:
294
- new_arguments = []
295
- for argument in arguments:
296
- if dataclasses.is_dataclass(argument.annotation):
297
- for field in dataclasses.fields(argument.annotation):
298
- new_argument = KernelArgument(
299
- _annotation=field.type,
300
- _name=f"__ti_{argument.name}_{field.name}",
301
- )
302
- new_arguments.append(new_argument)
303
- else:
304
- new_arguments.append(argument)
305
- return new_arguments
306
-
307
-
308
326
  def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
309
327
  if is_func:
310
- self.arguments = expand_func_arguments(self.arguments)
311
- fused_args = [argument.default for argument in self.arguments]
312
- ret: list[Any] = [argument.default for argument in self.arguments]
328
+ self.arg_metas = _kernel_impl_dataclass.expand_func_arguments(self.arg_metas)
329
+
330
+ fused_args: list[Any] = [arg_meta.default for arg_meta in self.arg_metas]
313
331
  len_args = len(args)
314
332
 
315
333
  if len_args > len(fused_args):
316
- arg_str = ", ".join([str(arg) for arg in args])
317
- expected_str = ", ".join([f"{arg.name} : {arg.annotation}" for arg in self.arguments])
334
+ arg_str = ", ".join(map(str, args))
335
+ expected_str = ", ".join(f"{arg.name} : {arg.annotation}" for arg in self.arg_metas)
318
336
  msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
319
337
  raise GsTaichiSyntaxError(msg)
320
338
 
@@ -322,69 +340,27 @@ def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], k
322
340
  fused_args[i] = arg
323
341
 
324
342
  for key, value in kwargs.items():
325
- found = False
326
- for i, arg in enumerate(self.arguments):
343
+ for i, arg in enumerate(self.arg_metas):
327
344
  if key == arg.name:
328
345
  if i < len_args:
329
346
  raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
330
347
  fused_args[i] = value
331
- found = True
332
348
  break
333
- if not found:
349
+ else:
334
350
  raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
335
351
 
336
352
  for i, arg in enumerate(fused_args):
337
353
  if arg is inspect.Parameter.empty:
338
- if self.arguments[i].annotation is inspect._empty:
339
- raise GsTaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
354
+ if self.arg_metas[i].annotation is inspect._empty:
355
+ raise GsTaichiSyntaxError(f"Parameter `{self.arg_metas[i].name}` missing.")
340
356
  else:
341
357
  raise GsTaichiSyntaxError(
342
- f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
358
+ f"Parameter `{self.arg_metas[i].name} : {self.arg_metas[i].annotation}` missing."
343
359
  )
344
360
 
345
361
  return tuple(fused_args)
346
362
 
347
363
 
348
- def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
349
- class AttributeToNameTransformer(ast.NodeTransformer):
350
- def visit_Attribute(self, node: ast.Attribute):
351
- if isinstance(node.value, ast.Attribute):
352
- return node
353
- if not isinstance(node.value, ast.Name):
354
- return node
355
- base_id = node.value.id
356
- attr_name = node.attr
357
- new_id = f"__ti_{base_id}_{attr_name}"
358
- if new_id not in struct_locals:
359
- return node
360
- return ast.copy_location(ast.Name(id=new_id, ctx=node.ctx), node)
361
-
362
- transformer = AttributeToNameTransformer()
363
- new_tree = transformer.visit(tree)
364
- ast.fix_missing_locations(new_tree)
365
- return new_tree
366
-
367
-
368
- def extract_struct_locals_from_context(ctx: ASTTransformerContext):
369
- """
370
- - Uses ctx.func.func to get the function signature.
371
- - Searches this for any dataclasses:
372
- - If it finds any dataclasses, then converts them into expanded names.
373
- - E.g. my_struct: MyStruct, and MyStruct contains a, b, c would become:
374
- {"__ti_my_struct_a", "__ti_my_struct_b, "__ti_my_struct_c"}
375
- """
376
- assert ctx.func is not None
377
- sig = inspect.signature(ctx.func.func)
378
- parameters = sig.parameters
379
- struct_locals = set()
380
- for param_name, parameter in parameters.items():
381
- if dataclasses.is_dataclass(parameter.annotation):
382
- for field in dataclasses.fields(parameter.annotation):
383
- child_name = f"__ti_{param_name}_{field.name}"
384
- struct_locals.add(child_name)
385
- return struct_locals
386
-
387
-
388
364
  class Func:
389
365
  function_counter = 0
390
366
 
@@ -396,19 +372,19 @@ class Func:
396
372
  self.classfunc = _classfunc
397
373
  self.pyfunc = _pyfunc
398
374
  self.is_real_function = is_real_function
399
- self.arguments: list[KernelArgument] = []
400
- self.orig_arguments: list[KernelArgument] = []
375
+ self.arg_metas: list[ArgMetadata] = []
376
+ self.orig_arguments: list[ArgMetadata] = []
401
377
  self.return_type: tuple[Type, ...] | None = None
402
378
  self.extract_arguments()
403
379
  self.template_slot_locations: list[int] = []
404
- for i, arg in enumerate(self.arguments):
380
+ for i, arg in enumerate(self.arg_metas):
405
381
  if arg.annotation == template or isinstance(arg.annotation, template):
406
382
  self.template_slot_locations.append(i)
407
- self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
383
+ self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
408
384
  self.gstaichi_functions = {} # The |Function| class in C++
409
385
  self.has_print = False
410
386
 
411
- def __call__(self, *args, **kwargs) -> Any:
387
+ def __call__(self: "Func", *args, **kwargs) -> Any:
412
388
  args = _process_args(self, is_func=True, args=args, kwargs=kwargs)
413
389
 
414
390
  if not impl.inside_kernel():
@@ -433,8 +409,9 @@ class Func:
433
409
  is_real_function=self.is_real_function,
434
410
  )
435
411
 
436
- struct_locals = extract_struct_locals_from_context(ctx)
437
- tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
412
+ struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
413
+
414
+ tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
438
415
  ret = transform_tree(tree, ctx)
439
416
  if not self.is_real_function:
440
417
  if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
@@ -446,7 +423,7 @@ class Func:
446
423
  assert self.is_real_function
447
424
  non_template_args = []
448
425
  dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
449
- for i, kernel_arg in enumerate(self.arguments):
426
+ for i, kernel_arg in enumerate(self.arg_metas):
450
427
  anno = kernel_arg.annotation
451
428
  if not isinstance(anno, template):
452
429
  if id(anno) in primitive_types.type_ids:
@@ -497,10 +474,10 @@ class Func:
497
474
 
498
475
  def func_body():
499
476
  old_callable = impl.get_runtime().compiling_callable
500
- impl.get_runtime().compiling_callable = fn
477
+ impl.get_runtime()._compiling_callable = fn
501
478
  ctx.ast_builder = fn.ast_builder()
502
479
  transform_tree(tree, ctx)
503
- impl.get_runtime().compiling_callable = old_callable
480
+ impl.get_runtime()._compiling_callable = old_callable
504
481
 
505
482
  self.gstaichi_functions[key.instance_id] = fn
506
483
  self.compiled[key.instance_id] = func_body
@@ -569,8 +546,8 @@ class Func:
569
546
  raise GsTaichiSyntaxError(
570
547
  f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
571
548
  )
572
- self.arguments.append(KernelArgument(annotation, param.name, param.default))
573
- self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))
549
+ self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
550
+ self.orig_arguments.append(ArgMetadata(annotation, param.name, param.default))
574
551
 
575
552
 
576
553
  def _get_global_vars(_func: Callable) -> dict[str, Any]:
@@ -587,6 +564,14 @@ def _get_global_vars(_func: Callable) -> dict[str, Any]:
587
564
  return global_vars
588
565
 
589
566
 
567
+ @dataclasses.dataclass
568
+ class SrcLlCacheObservations:
569
+ cache_key_generated: bool = False
570
+ cache_validated: bool = False
571
+ cache_loaded: bool = False
572
+ cache_stored: bool = False
573
+
574
+
590
575
  class Kernel:
591
576
  counter = 0
592
577
 
@@ -601,21 +586,29 @@ class Kernel:
601
586
  AutodiffMode.REVERSE,
602
587
  )
603
588
  self.autodiff_mode = autodiff_mode
604
- self.grad: Kernel | None = None
605
- self.arguments: list[KernelArgument] = []
589
+ self.grad: "Kernel | None" = None
590
+ self.arg_metas: list[ArgMetadata] = []
606
591
  self.return_type = None
607
592
  self.classkernel = _classkernel
608
593
  self.extract_arguments()
609
594
  self.template_slot_locations = []
610
- for i, arg in enumerate(self.arguments):
595
+ for i, arg in enumerate(self.arg_metas):
611
596
  if arg.annotation == template or isinstance(arg.annotation, template):
612
597
  self.template_slot_locations.append(i)
613
- self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
598
+ self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
614
599
  impl.get_runtime().kernels.append(self)
615
600
  self.reset()
616
601
  self.kernel_cpp = None
617
- self.compiled_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
602
+ # A materialized kernel is a KernelCxx object which may or may not have
603
+ # been compiled. It generally has been converted at least as far as AST
604
+ # and front-end IR, but not necessarily any further.
605
+ self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
618
606
  self.has_print = False
607
+ self.gstaichi_callable: GsTaichiCallable | None = None
608
+ self.visited_functions: set[FunctionSourceInfo] = set()
609
+ self.kernel_function_info: FunctionSourceInfo | None = None
610
+
611
+ self.src_ll_cache_observations: SrcLlCacheObservations = SrcLlCacheObservations()
619
612
 
620
613
  def ast_builder(self) -> ASTBuilder:
621
614
  assert self.kernel_cpp is not None
@@ -623,7 +616,7 @@ class Kernel:
623
616
 
624
617
  def reset(self) -> None:
625
618
  self.runtime = impl.get_runtime()
626
- self.compiled_kernels = {}
619
+ self.materialized_kernels = {}
627
620
 
628
621
  def extract_arguments(self) -> None:
629
622
  sig = inspect.signature(self.func)
@@ -639,7 +632,7 @@ class Kernel:
639
632
  for return_type in self.return_type:
640
633
  if return_type is Ellipsis:
641
634
  raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
642
- params = sig.parameters
635
+ params = dict(sig.parameters)
643
636
  arg_names = params.keys()
644
637
  for i, arg_name in enumerate(arg_names):
645
638
  param = params[arg_name]
@@ -682,34 +675,50 @@ class Kernel:
682
675
  pass
683
676
  elif isinstance(annotation, StructType):
684
677
  pass
685
- elif isinstance(annotation, ArgPackType):
686
- pass
687
678
  elif annotation == template:
688
679
  pass
689
680
  elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
690
681
  pass
691
682
  else:
692
- raise GsTaichiSyntaxError(
693
- f"Invalid type annotation (argument {i}) of GsTaichi kernel: {annotation}"
694
- )
695
- self.arguments.append(KernelArgument(annotation, param.name, param.default))
683
+ raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
684
+ self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
696
685
 
697
- def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features):
686
+ def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features=None):
698
687
  if key is None:
699
688
  key = (self.func, 0, self.autodiff_mode)
700
689
  self.runtime.materialize()
690
+ self.compiled_kernel_data = None
691
+ self.fast_checksum = None
701
692
 
702
- if key in self.compiled_kernels:
693
+ if key in self.materialized_kernels:
703
694
  return
704
695
 
696
+ if self.gstaichi_callable and self.gstaichi_callable.is_pure:
697
+ kernel_source_info, _src = get_source_info_and_src(self.func)
698
+ self.fast_checksum = src_hasher.create_cache_key(kernel_source_info, args)
699
+ if self.fast_checksum:
700
+ self.src_ll_cache_observations.cache_key_generated = True
701
+ if self.fast_checksum and src_hasher.validate_cache_key(self.fast_checksum):
702
+ self.src_ll_cache_observations.cache_validated = True
703
+ prog = impl.get_runtime().prog
704
+ self.compiled_kernel_data = prog.load_fast_cache(
705
+ self.fast_checksum,
706
+ self.func.__name__,
707
+ prog.config(),
708
+ prog.get_device_caps(),
709
+ )
710
+ if self.compiled_kernel_data:
711
+ self.src_ll_cache_observations.cache_loaded = True
712
+
705
713
  kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
706
- _logging.trace(f"Compiling kernel {kernel_name} in {self.autodiff_mode}...")
714
+ _logging.trace(f"Materializing kernel {kernel_name} in {self.autodiff_mode}...")
707
715
 
708
716
  tree, ctx = _get_tree_and_ctx(
709
717
  self,
710
718
  args=args,
711
719
  excluded_parameters=self.template_slot_locations,
712
720
  arg_features=arg_features,
721
+ current_kernel=self,
713
722
  )
714
723
 
715
724
  if self.autodiff_mode != AutodiffMode.NONE:
@@ -717,7 +726,7 @@ class Kernel:
717
726
 
718
727
  # Do not change the name of 'gstaichi_ast_generator'
719
728
  # The warning system needs this identifier to remove unnecessary messages
720
- def gstaichi_ast_generator(kernel_cxx: Kernel): # not sure if this type is correct, seems doubtful
729
+ def gstaichi_ast_generator(kernel_cxx: KernelCxx):
721
730
  nonlocal tree
722
731
  if self.runtime.inside_kernel:
723
732
  raise GsTaichiSyntaxError(
@@ -729,8 +738,8 @@ class Kernel:
729
738
  self.kernel_cpp = kernel_cxx
730
739
  self.runtime.inside_kernel = True
731
740
  self.runtime._current_kernel = self
732
- assert self.runtime.compiling_callable is None
733
- self.runtime.compiling_callable = kernel_cxx
741
+ assert self.runtime._compiling_callable is None
742
+ self.runtime._compiling_callable = kernel_cxx
734
743
  try:
735
744
  ctx.ast_builder = kernel_cxx.ast_builder()
736
745
 
@@ -767,8 +776,9 @@ class Kernel:
767
776
  output_file.write_text(
768
777
  json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
769
778
  )
770
- struct_locals = extract_struct_locals_from_context(ctx)
771
- tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
779
+ struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
780
+ tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
781
+ ctx.only_parse_function_def = self.compiled_kernel_data is not None
772
782
  transform_tree(tree, ctx)
773
783
  if not ctx.is_real_function:
774
784
  if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
@@ -776,14 +786,14 @@ class Kernel:
776
786
  finally:
777
787
  self.runtime.inside_kernel = False
778
788
  self.runtime._current_kernel = None
779
- self.runtime.compiling_callable = None
789
+ self.runtime._compiling_callable = None
780
790
 
781
791
  gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
782
- assert key not in self.compiled_kernels
783
- self.compiled_kernels[key] = gstaichi_kernel
792
+ assert key not in self.materialized_kernels
793
+ self.materialized_kernels[key] = gstaichi_kernel
784
794
 
785
795
  def launch_kernel(self, t_kernel: KernelCxx, *args) -> Any:
786
- assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"
796
+ assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"
787
797
 
788
798
  tmps = []
789
799
  callbacks = []
@@ -897,43 +907,8 @@ class Kernel:
897
907
  )
898
908
  else:
899
909
  raise GsTaichiRuntimeTypeError(
900
- f"Argument {needed} cannot be converted into required type {type(v)}"
910
+ f"Argument of type {type(v)} cannot be converted into required type {needed}"
901
911
  )
902
- elif has_paddle():
903
- # Do we want to continue to support paddle? :thinking_face:
904
- # #maybeprunable
905
- import paddle # pylint: disable=C0415 # type: ignore
906
-
907
- if isinstance(v, paddle.Tensor):
908
- # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
909
- def get_call_back(u, v):
910
- def call_back():
911
- u.copy_(v, False)
912
-
913
- return call_back
914
-
915
- tmp = v.value().get_tensor()
916
- gstaichi_arch = self.runtime.prog.config().arch
917
- if v.place.is_gpu_place():
918
- if gstaichi_arch != _ti_core.Arch.cuda:
919
- # Paddle cuda tensor on GsTaichi non-cuda arch
920
- host_v = v.cpu()
921
- tmp = host_v.value().get_tensor()
922
- callbacks.append(get_call_back(v, host_v))
923
- elif v.place.is_cpu_place():
924
- if gstaichi_arch == _ti_core.Arch.cuda:
925
- # Paddle cpu tensor on GsTaichi cuda arch
926
- gpu_v = v.cuda()
927
- tmp = gpu_v.value().get_tensor()
928
- callbacks.append(get_call_back(v, gpu_v))
929
- else:
930
- # Paddle do support many other backends like XPU, NPU, MLU, IPU
931
- raise GsTaichiRuntimeTypeError(f"GsTaichi do not support backend {v.place} that Paddle support")
932
- launch_ctx.set_arg_external_array_with_shape(
933
- indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
934
- )
935
- else:
936
- raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
937
912
  else:
938
913
  raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
939
914
 
@@ -979,43 +954,28 @@ class Kernel:
979
954
  e.g. templates don't set kernel args, so returns 0
980
955
  a single ndarray is 1 kernel arg, so returns 1
981
956
  a struct of 3 ndarrays would set 3 kernel args, so return 3
957
+ note: len(indices) > 1 only happens with argpack (which we are removing support for)
982
958
  """
983
- in_argpack = len(indices) > 1
984
959
  nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
985
960
  if actual_argument_slot >= max_arg_num:
986
961
  exceed_max_arg_num = True
987
962
  return 0
988
963
  actual_argument_slot += 1
989
- if isinstance(needed_arg_type, ArgPackType):
990
- if not isinstance(v, ArgPack):
991
- raise GsTaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
992
- idx_new = 0
993
- for j, (name, anno) in enumerate(needed_arg_type.members.items()):
994
- idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
995
- launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack) # type: ignore
996
- return 1
997
964
  # Note: do not use sth like "needed == f32". That would be slow.
998
965
  if id(needed_arg_type) in primitive_types.real_type_ids:
999
966
  if not isinstance(v, (float, int, np.floating, np.integer)):
1000
967
  raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1001
- if in_argpack:
1002
- return 1
1003
968
  launch_ctx.set_arg_float(indices, float(v))
1004
969
  return 1
1005
970
  if id(needed_arg_type) in primitive_types.integer_type_ids:
1006
971
  if not isinstance(v, (int, np.integer)):
1007
972
  raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1008
- if in_argpack:
1009
- return 1
1010
973
  if is_signed(cook_dtype(needed_arg_type)):
1011
974
  launch_ctx.set_arg_int(indices, int(v))
1012
975
  else:
1013
976
  launch_ctx.set_arg_uint(indices, int(v))
1014
977
  return 1
1015
978
  if isinstance(needed_arg_type, sparse_matrix_builder):
1016
- if in_argpack:
1017
- set_later_list.append((set_arg_sparse_matrix_builder, (v,)))
1018
- return 0
1019
979
  set_arg_sparse_matrix_builder(indices, v)
1020
980
  return 1
1021
981
  if dataclasses.is_dataclass(needed_arg_type):
@@ -1027,39 +987,23 @@ class Kernel:
1027
987
  idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
1028
988
  return idx
1029
989
  if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
1030
- if in_argpack:
1031
- set_later_list.append((set_arg_ndarray, (v,)))
1032
- return 0
1033
990
  set_arg_ndarray(indices, v)
1034
991
  return 1
1035
992
  if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
1036
- if in_argpack:
1037
- set_later_list.append((set_arg_texture, (v,)))
1038
- return 0
1039
993
  set_arg_texture(indices, v)
1040
994
  return 1
1041
995
  if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
1042
996
  v, gstaichi.lang._texture.Texture
1043
997
  ):
1044
- if in_argpack:
1045
- set_later_list.append((set_arg_rw_texture, (v,)))
1046
- return 0
1047
998
  set_arg_rw_texture(indices, v)
1048
999
  return 1
1049
1000
  if isinstance(needed_arg_type, ndarray_type.NdarrayType):
1050
- if in_argpack:
1051
- set_later_list.append((set_arg_ext_array, (v, needed_arg_type)))
1052
- return 0
1053
1001
  set_arg_ext_array(indices, v, needed_arg_type)
1054
1002
  return 1
1055
1003
  if isinstance(needed_arg_type, MatrixType):
1056
- if in_argpack:
1057
- return 1
1058
1004
  set_arg_matrix(indices, v, needed_arg_type)
1059
1005
  return 1
1060
1006
  if isinstance(needed_arg_type, StructType):
1061
- if in_argpack:
1062
- return 1
1063
1007
  # Unclear how to make the following pass typing checks
1064
1008
  # StructType implements __instancecheck__, which should be a classmethod, but
1065
1009
  # is currently an instance method
@@ -1077,7 +1021,7 @@ class Kernel:
1077
1021
  template_num = 0
1078
1022
  i_out = 0
1079
1023
  for i_in, val in enumerate(args):
1080
- needed_ = self.arguments[i_in].annotation
1024
+ needed_ = self.arg_metas[i_in].annotation
1081
1025
  if needed_ == template or isinstance(needed_, template):
1082
1026
  template_num += 1
1083
1027
  i_out += 1
@@ -1094,10 +1038,19 @@ class Kernel:
1094
1038
 
1095
1039
  try:
1096
1040
  prog = impl.get_runtime().prog
1097
- # Compile kernel (& Online Cache & Offline Cache)
1098
- compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
1099
- # Launch kernel
1100
- prog.launch_kernel(compiled_kernel_data, launch_ctx)
1041
+ if not self.compiled_kernel_data:
1042
+ self.compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
1043
+ if self.fast_checksum:
1044
+ src_hasher.store(self.fast_checksum, self.visited_functions)
1045
+ prog.store_fast_cache(
1046
+ self.fast_checksum,
1047
+ self.kernel_cpp,
1048
+ prog.config(),
1049
+ prog.get_device_caps(),
1050
+ self.compiled_kernel_data,
1051
+ )
1052
+ self.src_ll_cache_observations.cache_stored = True
1053
+ prog.launch_kernel(self.compiled_kernel_data, launch_ctx)
1101
1054
  except Exception as e:
1102
1055
  e = handle_exception_from_cpp(e)
1103
1056
  if impl.get_runtime().print_full_traceback:
@@ -1170,7 +1123,7 @@ class Kernel:
1170
1123
  _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
1171
1124
  impl.current_cfg().opt_level = 1
1172
1125
  key = self.ensure_compiled(*args)
1173
- kernel_cpp = self.compiled_kernels[key]
1126
+ kernel_cpp = self.materialized_kernels[key]
1174
1127
  return self.launch_kernel(kernel_cpp, *args)
1175
1128
 
1176
1129
 
@@ -1256,6 +1209,7 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
1256
1209
  wrapped._is_classkernel = is_classkernel
1257
1210
  wrapped._primal = primal
1258
1211
  wrapped._adjoint = adjoint
1212
+ primal.gstaichi_callable = wrapped
1259
1213
  return wrapped
1260
1214
 
1261
1215