gstaichi 1.0.1__cp311-cp311-macosx_15_0_arm64.whl → 2.0.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +3 -3
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +11 -41
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -1
- gstaichi/examples/minimal.py +1 -1
- gstaichi/lang/__init__.py +1 -1
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_template_mapper.py +16 -20
- gstaichi/lang/_wrap_inspect.py +27 -1
- gstaichi/lang/ast/ast_transformer.py +7 -2
- gstaichi/lang/ast/ast_transformer_utils.py +18 -13
- gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
- gstaichi/lang/field.py +0 -38
- gstaichi/lang/impl.py +25 -24
- gstaichi/lang/kernel_arguments.py +28 -30
- gstaichi/lang/kernel_impl.py +154 -200
- gstaichi/lang/matrix.py +0 -46
- gstaichi/lang/struct.py +0 -45
- gstaichi/lang/util.py +11 -80
- gstaichi/types/annotations.py +10 -5
- gstaichi/types/compound_types.py +1 -20
- gstaichi/types/ndarray_type.py +31 -11
- gstaichi/types/utils.py +0 -2
- {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/METADATA +2 -1
- {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/RECORD +104 -93
- gstaichi/lang/argpack.py +0 -411
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/WHEEL +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.0.0.dist-info}/top_level.txt +0 -0
gstaichi/lang/kernel_impl.py
CHANGED
@@ -29,11 +29,11 @@ from gstaichi._lib.core.gstaichi_python import (
|
|
29
29
|
KernelCxx,
|
30
30
|
KernelLaunchContext,
|
31
31
|
)
|
32
|
-
from gstaichi.lang import impl, ops, runtime_ops
|
33
|
-
from gstaichi.lang.
|
34
|
-
from gstaichi.lang.
|
32
|
+
from gstaichi.lang import _kernel_impl_dataclass, impl, ops, runtime_ops
|
33
|
+
from gstaichi.lang._fast_caching import src_hasher
|
34
|
+
from gstaichi.lang._template_mapper import TemplateMapper
|
35
|
+
from gstaichi.lang._wrap_inspect import FunctionSourceInfo, get_source_info_and_src
|
35
36
|
from gstaichi.lang.any_array import AnyArray
|
36
|
-
from gstaichi.lang.argpack import ArgPack, ArgPackType
|
37
37
|
from gstaichi.lang.ast import (
|
38
38
|
ASTTransformerContext,
|
39
39
|
KernelSimplicityASTChecker,
|
@@ -49,11 +49,11 @@ from gstaichi.lang.exception import (
|
|
49
49
|
handle_exception_from_cpp,
|
50
50
|
)
|
51
51
|
from gstaichi.lang.expr import Expr
|
52
|
-
from gstaichi.lang.kernel_arguments import
|
52
|
+
from gstaichi.lang.kernel_arguments import ArgMetadata
|
53
53
|
from gstaichi.lang.matrix import MatrixType
|
54
54
|
from gstaichi.lang.shell import _shell_pop_print
|
55
55
|
from gstaichi.lang.struct import StructType
|
56
|
-
from gstaichi.lang.util import cook_dtype,
|
56
|
+
from gstaichi.lang.util import cook_dtype, has_pytorch
|
57
57
|
from gstaichi.types import (
|
58
58
|
ndarray_type,
|
59
59
|
primitive_types,
|
@@ -152,6 +152,7 @@ class GsTaichiCallable:
|
|
152
152
|
self._adjoint: Kernel | None = None
|
153
153
|
self.grad: Kernel | None = None
|
154
154
|
self._is_staticmethod: bool = False
|
155
|
+
self.is_pure = False
|
155
156
|
functools.update_wrapper(self, fn)
|
156
157
|
|
157
158
|
def __call__(self, *args, **kwargs):
|
@@ -243,17 +244,45 @@ def pyfunc(fn: Callable) -> GsTaichiCallable:
|
|
243
244
|
return gstaichi_callable
|
244
245
|
|
245
246
|
|
247
|
+
def _populate_global_vars_for_templates(
|
248
|
+
template_slot_locations: list[int],
|
249
|
+
argument_metas: list[ArgMetadata],
|
250
|
+
global_vars: dict[str, Any],
|
251
|
+
fn: Callable,
|
252
|
+
py_args: tuple[Any, ...],
|
253
|
+
):
|
254
|
+
"""
|
255
|
+
Inject template parameters into globals
|
256
|
+
|
257
|
+
Globals are being abused to store the python objects associated
|
258
|
+
with templates. We continue this approach, and in addition this function
|
259
|
+
handles injecting expanded python variables from dataclasses.
|
260
|
+
"""
|
261
|
+
for i in template_slot_locations:
|
262
|
+
template_var_name = argument_metas[i].name
|
263
|
+
global_vars[template_var_name] = py_args[i]
|
264
|
+
parameters = inspect.signature(fn).parameters
|
265
|
+
for i, (parameter_name, parameter) in enumerate(parameters.items()):
|
266
|
+
if dataclasses.is_dataclass(parameter.annotation):
|
267
|
+
_kernel_impl_dataclass.populate_global_vars_from_dataclass(
|
268
|
+
parameter_name,
|
269
|
+
parameter.annotation,
|
270
|
+
py_args[i],
|
271
|
+
global_vars=global_vars,
|
272
|
+
)
|
273
|
+
|
274
|
+
|
246
275
|
def _get_tree_and_ctx(
|
247
276
|
self: "Func | Kernel",
|
248
277
|
args: tuple[Any, ...],
|
249
278
|
excluded_parameters=(),
|
250
279
|
is_kernel: bool = True,
|
251
280
|
arg_features=None,
|
252
|
-
ast_builder: ASTBuilder | None = None,
|
281
|
+
ast_builder: "ASTBuilder | None" = None,
|
253
282
|
is_real_function: bool = False,
|
283
|
+
current_kernel: "Kernel | None" = None,
|
254
284
|
) -> tuple[ast.Module, ASTTransformerContext]:
|
255
|
-
|
256
|
-
src, start_lineno = getsourcelines(self.func)
|
285
|
+
function_source_info, src = get_source_info_and_src(self.func)
|
257
286
|
src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
|
258
287
|
tree = ast.parse(textwrap.dedent("\n".join(src)))
|
259
288
|
|
@@ -263,17 +292,20 @@ def _get_tree_and_ctx(
|
|
263
292
|
global_vars = _get_global_vars(self.func)
|
264
293
|
|
265
294
|
if is_kernel or is_real_function:
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
global_vars
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
295
|
+
_populate_global_vars_for_templates(
|
296
|
+
template_slot_locations=self.template_slot_locations,
|
297
|
+
argument_metas=self.arg_metas,
|
298
|
+
global_vars=global_vars,
|
299
|
+
fn=self.func,
|
300
|
+
py_args=args,
|
301
|
+
)
|
302
|
+
|
303
|
+
if current_kernel is not None: # Kernel
|
304
|
+
current_kernel.kernel_function_info = function_source_info
|
305
|
+
if current_kernel is None:
|
306
|
+
current_kernel = impl.get_runtime()._current_kernel
|
307
|
+
assert current_kernel is not None
|
308
|
+
current_kernel.visited_functions.add(function_source_info)
|
277
309
|
|
278
310
|
return tree, ASTTransformerContext(
|
279
311
|
excluded_parameters=excluded_parameters,
|
@@ -283,38 +315,24 @@ def _get_tree_and_ctx(
|
|
283
315
|
global_vars=global_vars,
|
284
316
|
argument_data=args,
|
285
317
|
src=src,
|
286
|
-
start_lineno=start_lineno,
|
287
|
-
|
318
|
+
start_lineno=function_source_info.start_lineno,
|
319
|
+
end_lineno=function_source_info.end_lineno,
|
320
|
+
file=function_source_info.filepath,
|
288
321
|
ast_builder=ast_builder,
|
289
322
|
is_real_function=is_real_function,
|
290
323
|
)
|
291
324
|
|
292
325
|
|
293
|
-
def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgument]:
|
294
|
-
new_arguments = []
|
295
|
-
for argument in arguments:
|
296
|
-
if dataclasses.is_dataclass(argument.annotation):
|
297
|
-
for field in dataclasses.fields(argument.annotation):
|
298
|
-
new_argument = KernelArgument(
|
299
|
-
_annotation=field.type,
|
300
|
-
_name=f"__ti_{argument.name}_{field.name}",
|
301
|
-
)
|
302
|
-
new_arguments.append(new_argument)
|
303
|
-
else:
|
304
|
-
new_arguments.append(argument)
|
305
|
-
return new_arguments
|
306
|
-
|
307
|
-
|
308
326
|
def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
|
309
327
|
if is_func:
|
310
|
-
self.
|
311
|
-
|
312
|
-
|
328
|
+
self.arg_metas = _kernel_impl_dataclass.expand_func_arguments(self.arg_metas)
|
329
|
+
|
330
|
+
fused_args: list[Any] = [arg_meta.default for arg_meta in self.arg_metas]
|
313
331
|
len_args = len(args)
|
314
332
|
|
315
333
|
if len_args > len(fused_args):
|
316
|
-
arg_str = ", ".join(
|
317
|
-
expected_str = ", ".join(
|
334
|
+
arg_str = ", ".join(map(str, args))
|
335
|
+
expected_str = ", ".join(f"{arg.name} : {arg.annotation}" for arg in self.arg_metas)
|
318
336
|
msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
|
319
337
|
raise GsTaichiSyntaxError(msg)
|
320
338
|
|
@@ -322,69 +340,27 @@ def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], k
|
|
322
340
|
fused_args[i] = arg
|
323
341
|
|
324
342
|
for key, value in kwargs.items():
|
325
|
-
|
326
|
-
for i, arg in enumerate(self.arguments):
|
343
|
+
for i, arg in enumerate(self.arg_metas):
|
327
344
|
if key == arg.name:
|
328
345
|
if i < len_args:
|
329
346
|
raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
|
330
347
|
fused_args[i] = value
|
331
|
-
found = True
|
332
348
|
break
|
333
|
-
|
349
|
+
else:
|
334
350
|
raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
|
335
351
|
|
336
352
|
for i, arg in enumerate(fused_args):
|
337
353
|
if arg is inspect.Parameter.empty:
|
338
|
-
if self.
|
339
|
-
raise GsTaichiSyntaxError(f"Parameter `{self.
|
354
|
+
if self.arg_metas[i].annotation is inspect._empty:
|
355
|
+
raise GsTaichiSyntaxError(f"Parameter `{self.arg_metas[i].name}` missing.")
|
340
356
|
else:
|
341
357
|
raise GsTaichiSyntaxError(
|
342
|
-
f"Parameter `{self.
|
358
|
+
f"Parameter `{self.arg_metas[i].name} : {self.arg_metas[i].annotation}` missing."
|
343
359
|
)
|
344
360
|
|
345
361
|
return tuple(fused_args)
|
346
362
|
|
347
363
|
|
348
|
-
def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
|
349
|
-
class AttributeToNameTransformer(ast.NodeTransformer):
|
350
|
-
def visit_Attribute(self, node: ast.Attribute):
|
351
|
-
if isinstance(node.value, ast.Attribute):
|
352
|
-
return node
|
353
|
-
if not isinstance(node.value, ast.Name):
|
354
|
-
return node
|
355
|
-
base_id = node.value.id
|
356
|
-
attr_name = node.attr
|
357
|
-
new_id = f"__ti_{base_id}_{attr_name}"
|
358
|
-
if new_id not in struct_locals:
|
359
|
-
return node
|
360
|
-
return ast.copy_location(ast.Name(id=new_id, ctx=node.ctx), node)
|
361
|
-
|
362
|
-
transformer = AttributeToNameTransformer()
|
363
|
-
new_tree = transformer.visit(tree)
|
364
|
-
ast.fix_missing_locations(new_tree)
|
365
|
-
return new_tree
|
366
|
-
|
367
|
-
|
368
|
-
def extract_struct_locals_from_context(ctx: ASTTransformerContext):
|
369
|
-
"""
|
370
|
-
- Uses ctx.func.func to get the function signature.
|
371
|
-
- Searches this for any dataclasses:
|
372
|
-
- If it finds any dataclasses, then converts them into expanded names.
|
373
|
-
- E.g. my_struct: MyStruct, and MyStruct contains a, b, c would become:
|
374
|
-
{"__ti_my_struct_a", "__ti_my_struct_b, "__ti_my_struct_c"}
|
375
|
-
"""
|
376
|
-
assert ctx.func is not None
|
377
|
-
sig = inspect.signature(ctx.func.func)
|
378
|
-
parameters = sig.parameters
|
379
|
-
struct_locals = set()
|
380
|
-
for param_name, parameter in parameters.items():
|
381
|
-
if dataclasses.is_dataclass(parameter.annotation):
|
382
|
-
for field in dataclasses.fields(parameter.annotation):
|
383
|
-
child_name = f"__ti_{param_name}_{field.name}"
|
384
|
-
struct_locals.add(child_name)
|
385
|
-
return struct_locals
|
386
|
-
|
387
|
-
|
388
364
|
class Func:
|
389
365
|
function_counter = 0
|
390
366
|
|
@@ -396,19 +372,19 @@ class Func:
|
|
396
372
|
self.classfunc = _classfunc
|
397
373
|
self.pyfunc = _pyfunc
|
398
374
|
self.is_real_function = is_real_function
|
399
|
-
self.
|
400
|
-
self.orig_arguments: list[
|
375
|
+
self.arg_metas: list[ArgMetadata] = []
|
376
|
+
self.orig_arguments: list[ArgMetadata] = []
|
401
377
|
self.return_type: tuple[Type, ...] | None = None
|
402
378
|
self.extract_arguments()
|
403
379
|
self.template_slot_locations: list[int] = []
|
404
|
-
for i, arg in enumerate(self.
|
380
|
+
for i, arg in enumerate(self.arg_metas):
|
405
381
|
if arg.annotation == template or isinstance(arg.annotation, template):
|
406
382
|
self.template_slot_locations.append(i)
|
407
|
-
self.mapper =
|
383
|
+
self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
|
408
384
|
self.gstaichi_functions = {} # The |Function| class in C++
|
409
385
|
self.has_print = False
|
410
386
|
|
411
|
-
def __call__(self, *args, **kwargs) -> Any:
|
387
|
+
def __call__(self: "Func", *args, **kwargs) -> Any:
|
412
388
|
args = _process_args(self, is_func=True, args=args, kwargs=kwargs)
|
413
389
|
|
414
390
|
if not impl.inside_kernel():
|
@@ -433,8 +409,9 @@ class Func:
|
|
433
409
|
is_real_function=self.is_real_function,
|
434
410
|
)
|
435
411
|
|
436
|
-
struct_locals = extract_struct_locals_from_context(ctx)
|
437
|
-
|
412
|
+
struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
|
413
|
+
|
414
|
+
tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
|
438
415
|
ret = transform_tree(tree, ctx)
|
439
416
|
if not self.is_real_function:
|
440
417
|
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
|
@@ -446,7 +423,7 @@ class Func:
|
|
446
423
|
assert self.is_real_function
|
447
424
|
non_template_args = []
|
448
425
|
dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
449
|
-
for i, kernel_arg in enumerate(self.
|
426
|
+
for i, kernel_arg in enumerate(self.arg_metas):
|
450
427
|
anno = kernel_arg.annotation
|
451
428
|
if not isinstance(anno, template):
|
452
429
|
if id(anno) in primitive_types.type_ids:
|
@@ -497,10 +474,10 @@ class Func:
|
|
497
474
|
|
498
475
|
def func_body():
|
499
476
|
old_callable = impl.get_runtime().compiling_callable
|
500
|
-
impl.get_runtime().
|
477
|
+
impl.get_runtime()._compiling_callable = fn
|
501
478
|
ctx.ast_builder = fn.ast_builder()
|
502
479
|
transform_tree(tree, ctx)
|
503
|
-
impl.get_runtime().
|
480
|
+
impl.get_runtime()._compiling_callable = old_callable
|
504
481
|
|
505
482
|
self.gstaichi_functions[key.instance_id] = fn
|
506
483
|
self.compiled[key.instance_id] = func_body
|
@@ -569,8 +546,8 @@ class Func:
|
|
569
546
|
raise GsTaichiSyntaxError(
|
570
547
|
f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
|
571
548
|
)
|
572
|
-
self.
|
573
|
-
self.orig_arguments.append(
|
549
|
+
self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
|
550
|
+
self.orig_arguments.append(ArgMetadata(annotation, param.name, param.default))
|
574
551
|
|
575
552
|
|
576
553
|
def _get_global_vars(_func: Callable) -> dict[str, Any]:
|
@@ -587,6 +564,14 @@ def _get_global_vars(_func: Callable) -> dict[str, Any]:
|
|
587
564
|
return global_vars
|
588
565
|
|
589
566
|
|
567
|
+
@dataclasses.dataclass
|
568
|
+
class SrcLlCacheObservations:
|
569
|
+
cache_key_generated: bool = False
|
570
|
+
cache_validated: bool = False
|
571
|
+
cache_loaded: bool = False
|
572
|
+
cache_stored: bool = False
|
573
|
+
|
574
|
+
|
590
575
|
class Kernel:
|
591
576
|
counter = 0
|
592
577
|
|
@@ -601,21 +586,29 @@ class Kernel:
|
|
601
586
|
AutodiffMode.REVERSE,
|
602
587
|
)
|
603
588
|
self.autodiff_mode = autodiff_mode
|
604
|
-
self.grad: Kernel | None = None
|
605
|
-
self.
|
589
|
+
self.grad: "Kernel | None" = None
|
590
|
+
self.arg_metas: list[ArgMetadata] = []
|
606
591
|
self.return_type = None
|
607
592
|
self.classkernel = _classkernel
|
608
593
|
self.extract_arguments()
|
609
594
|
self.template_slot_locations = []
|
610
|
-
for i, arg in enumerate(self.
|
595
|
+
for i, arg in enumerate(self.arg_metas):
|
611
596
|
if arg.annotation == template or isinstance(arg.annotation, template):
|
612
597
|
self.template_slot_locations.append(i)
|
613
|
-
self.mapper =
|
598
|
+
self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
|
614
599
|
impl.get_runtime().kernels.append(self)
|
615
600
|
self.reset()
|
616
601
|
self.kernel_cpp = None
|
617
|
-
|
602
|
+
# A materialized kernel is a KernelCxx object which may or may not have
|
603
|
+
# been compiled. It generally has been converted at least as far as AST
|
604
|
+
# and front-end IR, but not necessarily any further.
|
605
|
+
self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
|
618
606
|
self.has_print = False
|
607
|
+
self.gstaichi_callable: GsTaichiCallable | None = None
|
608
|
+
self.visited_functions: set[FunctionSourceInfo] = set()
|
609
|
+
self.kernel_function_info: FunctionSourceInfo | None = None
|
610
|
+
|
611
|
+
self.src_ll_cache_observations: SrcLlCacheObservations = SrcLlCacheObservations()
|
619
612
|
|
620
613
|
def ast_builder(self) -> ASTBuilder:
|
621
614
|
assert self.kernel_cpp is not None
|
@@ -623,7 +616,7 @@ class Kernel:
|
|
623
616
|
|
624
617
|
def reset(self) -> None:
|
625
618
|
self.runtime = impl.get_runtime()
|
626
|
-
self.
|
619
|
+
self.materialized_kernels = {}
|
627
620
|
|
628
621
|
def extract_arguments(self) -> None:
|
629
622
|
sig = inspect.signature(self.func)
|
@@ -639,7 +632,7 @@ class Kernel:
|
|
639
632
|
for return_type in self.return_type:
|
640
633
|
if return_type is Ellipsis:
|
641
634
|
raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
|
642
|
-
params = sig.parameters
|
635
|
+
params = dict(sig.parameters)
|
643
636
|
arg_names = params.keys()
|
644
637
|
for i, arg_name in enumerate(arg_names):
|
645
638
|
param = params[arg_name]
|
@@ -682,34 +675,50 @@ class Kernel:
|
|
682
675
|
pass
|
683
676
|
elif isinstance(annotation, StructType):
|
684
677
|
pass
|
685
|
-
elif isinstance(annotation, ArgPackType):
|
686
|
-
pass
|
687
678
|
elif annotation == template:
|
688
679
|
pass
|
689
680
|
elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
|
690
681
|
pass
|
691
682
|
else:
|
692
|
-
raise GsTaichiSyntaxError(
|
693
|
-
|
694
|
-
)
|
695
|
-
self.arguments.append(KernelArgument(annotation, param.name, param.default))
|
683
|
+
raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
|
684
|
+
self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
|
696
685
|
|
697
|
-
def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features):
|
686
|
+
def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features=None):
|
698
687
|
if key is None:
|
699
688
|
key = (self.func, 0, self.autodiff_mode)
|
700
689
|
self.runtime.materialize()
|
690
|
+
self.compiled_kernel_data = None
|
691
|
+
self.fast_checksum = None
|
701
692
|
|
702
|
-
if key in self.
|
693
|
+
if key in self.materialized_kernels:
|
703
694
|
return
|
704
695
|
|
696
|
+
if self.gstaichi_callable and self.gstaichi_callable.is_pure:
|
697
|
+
kernel_source_info, _src = get_source_info_and_src(self.func)
|
698
|
+
self.fast_checksum = src_hasher.create_cache_key(kernel_source_info, args)
|
699
|
+
if self.fast_checksum:
|
700
|
+
self.src_ll_cache_observations.cache_key_generated = True
|
701
|
+
if self.fast_checksum and src_hasher.validate_cache_key(self.fast_checksum):
|
702
|
+
self.src_ll_cache_observations.cache_validated = True
|
703
|
+
prog = impl.get_runtime().prog
|
704
|
+
self.compiled_kernel_data = prog.load_fast_cache(
|
705
|
+
self.fast_checksum,
|
706
|
+
self.func.__name__,
|
707
|
+
prog.config(),
|
708
|
+
prog.get_device_caps(),
|
709
|
+
)
|
710
|
+
if self.compiled_kernel_data:
|
711
|
+
self.src_ll_cache_observations.cache_loaded = True
|
712
|
+
|
705
713
|
kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
|
706
|
-
_logging.trace(f"
|
714
|
+
_logging.trace(f"Materializing kernel {kernel_name} in {self.autodiff_mode}...")
|
707
715
|
|
708
716
|
tree, ctx = _get_tree_and_ctx(
|
709
717
|
self,
|
710
718
|
args=args,
|
711
719
|
excluded_parameters=self.template_slot_locations,
|
712
720
|
arg_features=arg_features,
|
721
|
+
current_kernel=self,
|
713
722
|
)
|
714
723
|
|
715
724
|
if self.autodiff_mode != AutodiffMode.NONE:
|
@@ -717,7 +726,7 @@ class Kernel:
|
|
717
726
|
|
718
727
|
# Do not change the name of 'gstaichi_ast_generator'
|
719
728
|
# The warning system needs this identifier to remove unnecessary messages
|
720
|
-
def gstaichi_ast_generator(kernel_cxx:
|
729
|
+
def gstaichi_ast_generator(kernel_cxx: KernelCxx):
|
721
730
|
nonlocal tree
|
722
731
|
if self.runtime.inside_kernel:
|
723
732
|
raise GsTaichiSyntaxError(
|
@@ -729,8 +738,8 @@ class Kernel:
|
|
729
738
|
self.kernel_cpp = kernel_cxx
|
730
739
|
self.runtime.inside_kernel = True
|
731
740
|
self.runtime._current_kernel = self
|
732
|
-
assert self.runtime.
|
733
|
-
self.runtime.
|
741
|
+
assert self.runtime._compiling_callable is None
|
742
|
+
self.runtime._compiling_callable = kernel_cxx
|
734
743
|
try:
|
735
744
|
ctx.ast_builder = kernel_cxx.ast_builder()
|
736
745
|
|
@@ -767,8 +776,9 @@ class Kernel:
|
|
767
776
|
output_file.write_text(
|
768
777
|
json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
|
769
778
|
)
|
770
|
-
struct_locals = extract_struct_locals_from_context(ctx)
|
771
|
-
tree =
|
779
|
+
struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
|
780
|
+
tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
|
781
|
+
ctx.only_parse_function_def = self.compiled_kernel_data is not None
|
772
782
|
transform_tree(tree, ctx)
|
773
783
|
if not ctx.is_real_function:
|
774
784
|
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
|
@@ -776,14 +786,14 @@ class Kernel:
|
|
776
786
|
finally:
|
777
787
|
self.runtime.inside_kernel = False
|
778
788
|
self.runtime._current_kernel = None
|
779
|
-
self.runtime.
|
789
|
+
self.runtime._compiling_callable = None
|
780
790
|
|
781
791
|
gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
|
782
|
-
assert key not in self.
|
783
|
-
self.
|
792
|
+
assert key not in self.materialized_kernels
|
793
|
+
self.materialized_kernels[key] = gstaichi_kernel
|
784
794
|
|
785
795
|
def launch_kernel(self, t_kernel: KernelCxx, *args) -> Any:
|
786
|
-
assert len(args) == len(self.
|
796
|
+
assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"
|
787
797
|
|
788
798
|
tmps = []
|
789
799
|
callbacks = []
|
@@ -897,43 +907,8 @@ class Kernel:
|
|
897
907
|
)
|
898
908
|
else:
|
899
909
|
raise GsTaichiRuntimeTypeError(
|
900
|
-
f"Argument {
|
910
|
+
f"Argument of type {type(v)} cannot be converted into required type {needed}"
|
901
911
|
)
|
902
|
-
elif has_paddle():
|
903
|
-
# Do we want to continue to support paddle? :thinking_face:
|
904
|
-
# #maybeprunable
|
905
|
-
import paddle # pylint: disable=C0415 # type: ignore
|
906
|
-
|
907
|
-
if isinstance(v, paddle.Tensor):
|
908
|
-
# For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
|
909
|
-
def get_call_back(u, v):
|
910
|
-
def call_back():
|
911
|
-
u.copy_(v, False)
|
912
|
-
|
913
|
-
return call_back
|
914
|
-
|
915
|
-
tmp = v.value().get_tensor()
|
916
|
-
gstaichi_arch = self.runtime.prog.config().arch
|
917
|
-
if v.place.is_gpu_place():
|
918
|
-
if gstaichi_arch != _ti_core.Arch.cuda:
|
919
|
-
# Paddle cuda tensor on GsTaichi non-cuda arch
|
920
|
-
host_v = v.cpu()
|
921
|
-
tmp = host_v.value().get_tensor()
|
922
|
-
callbacks.append(get_call_back(v, host_v))
|
923
|
-
elif v.place.is_cpu_place():
|
924
|
-
if gstaichi_arch == _ti_core.Arch.cuda:
|
925
|
-
# Paddle cpu tensor on GsTaichi cuda arch
|
926
|
-
gpu_v = v.cuda()
|
927
|
-
tmp = gpu_v.value().get_tensor()
|
928
|
-
callbacks.append(get_call_back(v, gpu_v))
|
929
|
-
else:
|
930
|
-
# Paddle do support many other backends like XPU, NPU, MLU, IPU
|
931
|
-
raise GsTaichiRuntimeTypeError(f"GsTaichi do not support backend {v.place} that Paddle support")
|
932
|
-
launch_ctx.set_arg_external_array_with_shape(
|
933
|
-
indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
|
934
|
-
)
|
935
|
-
else:
|
936
|
-
raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
|
937
912
|
else:
|
938
913
|
raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
|
939
914
|
|
@@ -979,43 +954,28 @@ class Kernel:
|
|
979
954
|
e.g. templates don't set kernel args, so returns 0
|
980
955
|
a single ndarray is 1 kernel arg, so returns 1
|
981
956
|
a struct of 3 ndarrays would set 3 kernel args, so return 3
|
957
|
+
note: len(indices) > 1 only happens with argpack (which we are removing support for)
|
982
958
|
"""
|
983
|
-
in_argpack = len(indices) > 1
|
984
959
|
nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
|
985
960
|
if actual_argument_slot >= max_arg_num:
|
986
961
|
exceed_max_arg_num = True
|
987
962
|
return 0
|
988
963
|
actual_argument_slot += 1
|
989
|
-
if isinstance(needed_arg_type, ArgPackType):
|
990
|
-
if not isinstance(v, ArgPack):
|
991
|
-
raise GsTaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
|
992
|
-
idx_new = 0
|
993
|
-
for j, (name, anno) in enumerate(needed_arg_type.members.items()):
|
994
|
-
idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
|
995
|
-
launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack) # type: ignore
|
996
|
-
return 1
|
997
964
|
# Note: do not use sth like "needed == f32". That would be slow.
|
998
965
|
if id(needed_arg_type) in primitive_types.real_type_ids:
|
999
966
|
if not isinstance(v, (float, int, np.floating, np.integer)):
|
1000
967
|
raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
|
1001
|
-
if in_argpack:
|
1002
|
-
return 1
|
1003
968
|
launch_ctx.set_arg_float(indices, float(v))
|
1004
969
|
return 1
|
1005
970
|
if id(needed_arg_type) in primitive_types.integer_type_ids:
|
1006
971
|
if not isinstance(v, (int, np.integer)):
|
1007
972
|
raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
|
1008
|
-
if in_argpack:
|
1009
|
-
return 1
|
1010
973
|
if is_signed(cook_dtype(needed_arg_type)):
|
1011
974
|
launch_ctx.set_arg_int(indices, int(v))
|
1012
975
|
else:
|
1013
976
|
launch_ctx.set_arg_uint(indices, int(v))
|
1014
977
|
return 1
|
1015
978
|
if isinstance(needed_arg_type, sparse_matrix_builder):
|
1016
|
-
if in_argpack:
|
1017
|
-
set_later_list.append((set_arg_sparse_matrix_builder, (v,)))
|
1018
|
-
return 0
|
1019
979
|
set_arg_sparse_matrix_builder(indices, v)
|
1020
980
|
return 1
|
1021
981
|
if dataclasses.is_dataclass(needed_arg_type):
|
@@ -1027,39 +987,23 @@ class Kernel:
|
|
1027
987
|
idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
|
1028
988
|
return idx
|
1029
989
|
if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
|
1030
|
-
if in_argpack:
|
1031
|
-
set_later_list.append((set_arg_ndarray, (v,)))
|
1032
|
-
return 0
|
1033
990
|
set_arg_ndarray(indices, v)
|
1034
991
|
return 1
|
1035
992
|
if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
|
1036
|
-
if in_argpack:
|
1037
|
-
set_later_list.append((set_arg_texture, (v,)))
|
1038
|
-
return 0
|
1039
993
|
set_arg_texture(indices, v)
|
1040
994
|
return 1
|
1041
995
|
if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
|
1042
996
|
v, gstaichi.lang._texture.Texture
|
1043
997
|
):
|
1044
|
-
if in_argpack:
|
1045
|
-
set_later_list.append((set_arg_rw_texture, (v,)))
|
1046
|
-
return 0
|
1047
998
|
set_arg_rw_texture(indices, v)
|
1048
999
|
return 1
|
1049
1000
|
if isinstance(needed_arg_type, ndarray_type.NdarrayType):
|
1050
|
-
if in_argpack:
|
1051
|
-
set_later_list.append((set_arg_ext_array, (v, needed_arg_type)))
|
1052
|
-
return 0
|
1053
1001
|
set_arg_ext_array(indices, v, needed_arg_type)
|
1054
1002
|
return 1
|
1055
1003
|
if isinstance(needed_arg_type, MatrixType):
|
1056
|
-
if in_argpack:
|
1057
|
-
return 1
|
1058
1004
|
set_arg_matrix(indices, v, needed_arg_type)
|
1059
1005
|
return 1
|
1060
1006
|
if isinstance(needed_arg_type, StructType):
|
1061
|
-
if in_argpack:
|
1062
|
-
return 1
|
1063
1007
|
# Unclear how to make the following pass typing checks
|
1064
1008
|
# StructType implements __instancecheck__, which should be a classmethod, but
|
1065
1009
|
# is currently an instance method
|
@@ -1077,7 +1021,7 @@ class Kernel:
|
|
1077
1021
|
template_num = 0
|
1078
1022
|
i_out = 0
|
1079
1023
|
for i_in, val in enumerate(args):
|
1080
|
-
needed_ = self.
|
1024
|
+
needed_ = self.arg_metas[i_in].annotation
|
1081
1025
|
if needed_ == template or isinstance(needed_, template):
|
1082
1026
|
template_num += 1
|
1083
1027
|
i_out += 1
|
@@ -1094,10 +1038,19 @@ class Kernel:
|
|
1094
1038
|
|
1095
1039
|
try:
|
1096
1040
|
prog = impl.get_runtime().prog
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1041
|
+
if not self.compiled_kernel_data:
|
1042
|
+
self.compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
|
1043
|
+
if self.fast_checksum:
|
1044
|
+
src_hasher.store(self.fast_checksum, self.visited_functions)
|
1045
|
+
prog.store_fast_cache(
|
1046
|
+
self.fast_checksum,
|
1047
|
+
self.kernel_cpp,
|
1048
|
+
prog.config(),
|
1049
|
+
prog.get_device_caps(),
|
1050
|
+
self.compiled_kernel_data,
|
1051
|
+
)
|
1052
|
+
self.src_ll_cache_observations.cache_stored = True
|
1053
|
+
prog.launch_kernel(self.compiled_kernel_data, launch_ctx)
|
1101
1054
|
except Exception as e:
|
1102
1055
|
e = handle_exception_from_cpp(e)
|
1103
1056
|
if impl.get_runtime().print_full_traceback:
|
@@ -1170,7 +1123,7 @@ class Kernel:
|
|
1170
1123
|
_logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
|
1171
1124
|
impl.current_cfg().opt_level = 1
|
1172
1125
|
key = self.ensure_compiled(*args)
|
1173
|
-
kernel_cpp = self.
|
1126
|
+
kernel_cpp = self.materialized_kernels[key]
|
1174
1127
|
return self.launch_kernel(kernel_cpp, *args)
|
1175
1128
|
|
1176
1129
|
|
@@ -1256,6 +1209,7 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
|
|
1256
1209
|
wrapped._is_classkernel = is_classkernel
|
1257
1210
|
wrapped._primal = primal
|
1258
1211
|
wrapped._adjoint = adjoint
|
1212
|
+
primal.gstaichi_callable = wrapped
|
1259
1213
|
return wrapped
|
1260
1214
|
|
1261
1215
|
|