gstaichi 0.1.25.dev0__cp311-cp311-macosx_15_0_arm64.whl → 2.0.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +1 -1
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +11 -41
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -1
- gstaichi/examples/minimal.py +1 -1
- gstaichi/lang/__init__.py +1 -1
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_template_mapper.py +16 -20
- gstaichi/lang/_wrap_inspect.py +27 -1
- gstaichi/lang/ast/ast_transformer.py +7 -2
- gstaichi/lang/ast/ast_transformer_utils.py +18 -13
- gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
- gstaichi/lang/field.py +0 -38
- gstaichi/lang/impl.py +25 -24
- gstaichi/lang/kernel_arguments.py +28 -30
- gstaichi/lang/kernel_impl.py +154 -200
- gstaichi/lang/matrix.py +0 -46
- gstaichi/lang/struct.py +0 -45
- gstaichi/lang/util.py +11 -80
- gstaichi/types/annotations.py +10 -5
- gstaichi/types/compound_types.py +1 -20
- gstaichi/types/ndarray_type.py +31 -11
- gstaichi/types/utils.py +0 -2
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/METADATA +2 -1
- gstaichi-2.0.0.dist-info/RECORD +177 -0
- gstaichi/__main__.py +0 -5
- gstaichi/_main.py +0 -545
- gstaichi/lang/argpack.py +0 -411
- gstaichi-0.1.25.dev0.dist-info/RECORD +0 -168
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +0 -2
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/top_level.txt +0 -0
gstaichi/lang/impl.py
CHANGED
@@ -8,6 +8,7 @@ from gstaichi._lib import core as _ti_core
|
|
8
8
|
from gstaichi._lib.core.gstaichi_python import (
|
9
9
|
DataTypeCxx,
|
10
10
|
Function,
|
11
|
+
KernelCxx,
|
11
12
|
Program,
|
12
13
|
)
|
13
14
|
from gstaichi._snode.fields_builder import FieldsBuilder
|
@@ -70,17 +71,14 @@ from gstaichi.types.primitive_types import (
|
|
70
71
|
|
71
72
|
@gstaichi_scope
|
72
73
|
def expr_init_shared_array(shape, element_type):
|
73
|
-
|
74
|
-
|
75
|
-
return
|
76
|
-
shape, element_type, _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
77
|
-
)
|
74
|
+
ast_builder = get_runtime().compiling_callable.ast_builder()
|
75
|
+
debug_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
76
|
+
return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info)
|
78
77
|
|
79
78
|
|
80
79
|
@gstaichi_scope
|
81
80
|
def expr_init(rhs):
|
82
81
|
compiling_callable = get_runtime().compiling_callable
|
83
|
-
assert compiling_callable is not None
|
84
82
|
if rhs is None:
|
85
83
|
return Expr(
|
86
84
|
compiling_callable.ast_builder().expr_alloca(_ti_core.DebugInfo(get_runtime().get_current_src_info()))
|
@@ -167,7 +165,7 @@ def _calc_slice(index, default_stop):
|
|
167
165
|
"GsTaichi does not support variables in slice now, please use constant instead of it."
|
168
166
|
)
|
169
167
|
|
170
|
-
check_validity(start), check_validity(stop), check_validity(step)
|
168
|
+
_ = check_validity(start), check_validity(stop), check_validity(step)
|
171
169
|
return [_ for _ in range(start, stop, step)]
|
172
170
|
|
173
171
|
|
@@ -194,9 +192,7 @@ def validate_subscript_index(value, index):
|
|
194
192
|
@gstaichi_scope
|
195
193
|
def subscript(ast_builder, value, *_indices, skip_reordered=False):
|
196
194
|
dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
197
|
-
|
198
|
-
assert compiling_callable is not None
|
199
|
-
ast_builder = compiling_callable.ast_builder()
|
195
|
+
ast_builder = get_runtime().compiling_callable.ast_builder()
|
200
196
|
# Directly evaluate in Python for non-GsTaichi types
|
201
197
|
if not isinstance(
|
202
198
|
value,
|
@@ -337,8 +333,8 @@ class PyGsTaichi:
|
|
337
333
|
self._prog: Program | None = None
|
338
334
|
self.src_info_stack = []
|
339
335
|
self.inside_kernel: bool = False
|
340
|
-
self.
|
341
|
-
self._current_kernel: Kernel | None = None
|
336
|
+
self._compiling_callable: KernelCxx | Kernel | Function | None = None
|
337
|
+
self._current_kernel: "Kernel | None" = None
|
342
338
|
self.global_vars = []
|
343
339
|
self.grad_vars = []
|
344
340
|
self.dual_vars = []
|
@@ -350,10 +346,18 @@ class PyGsTaichi:
|
|
350
346
|
self.target_tape = None
|
351
347
|
self.fwd_mode_manager = None
|
352
348
|
self.grad_replaced = False
|
353
|
-
self.kernels = kernels or []
|
349
|
+
self.kernels: list[Kernel] = kernels or []
|
354
350
|
self._signal_handler_registry = None
|
355
351
|
self.unfinalized_fields_builder = {}
|
356
352
|
|
353
|
+
@property
|
354
|
+
def compiling_callable(self) -> KernelCxx | Kernel | Function:
|
355
|
+
if self._compiling_callable is None:
|
356
|
+
raise GsTaichiRuntimeError(
|
357
|
+
"_compiling_callable attribute not initialized. Maybe you forgot to call `ti.init()` first?"
|
358
|
+
)
|
359
|
+
return self._compiling_callable
|
360
|
+
|
357
361
|
@property
|
358
362
|
def prog(self) -> Program:
|
359
363
|
if self._prog is None:
|
@@ -364,7 +368,7 @@ class PyGsTaichi:
|
|
364
368
|
def current_kernel(self) -> Kernel:
|
365
369
|
if self._current_kernel is None:
|
366
370
|
raise GsTaichiRuntimeError(
|
367
|
-
"
|
371
|
+
"_current_kernel attribute not initialized. Maybe you forgot to call `ti.init()` first?"
|
368
372
|
)
|
369
373
|
return self._current_kernel
|
370
374
|
|
@@ -373,7 +377,7 @@ class PyGsTaichi:
|
|
373
377
|
|
374
378
|
def clear_compiled_functions(self):
|
375
379
|
for k in self.kernels:
|
376
|
-
k.
|
380
|
+
k.materialized_kernels.clear()
|
377
381
|
|
378
382
|
def finalize_fields_builder(self, builder):
|
379
383
|
self.unfinalized_fields_builder.pop(builder)
|
@@ -390,7 +394,7 @@ class PyGsTaichi:
|
|
390
394
|
def get_num_compiled_functions(self):
|
391
395
|
count = 0
|
392
396
|
for k in self.kernels:
|
393
|
-
count += len(k.
|
397
|
+
count += len(k.materialized_kernels)
|
394
398
|
return count
|
395
399
|
|
396
400
|
def src_info_guard(self, info):
|
@@ -962,11 +966,9 @@ def ti_print(*_vars, sep=" ", end="\n"):
|
|
962
966
|
|
963
967
|
_vars = add_separators(_vars)
|
964
968
|
contents, formats = ti_format_list_to_content_entries(_vars)
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
contents, formats, _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
969
|
-
)
|
969
|
+
ast_builder = get_runtime().compiling_callable.ast_builder()
|
970
|
+
debug_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
971
|
+
ast_builder.create_print(contents, formats, debug_info)
|
970
972
|
|
971
973
|
|
972
974
|
@gstaichi_scope
|
@@ -996,9 +998,8 @@ def ti_format(*args):
|
|
996
998
|
def ti_assert(cond, msg, extra_args, dbg_info):
|
997
999
|
# Mostly a wrapper to help us convert from Expr (defined in Python) to
|
998
1000
|
# _ti_core.Expr (defined in C++)
|
999
|
-
|
1000
|
-
|
1001
|
-
compiling_callable.ast_builder().create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
|
1001
|
+
ast_builder = get_runtime().compiling_callable.ast_builder()
|
1002
|
+
ast_builder.create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
|
1002
1003
|
|
1003
1004
|
|
1004
1005
|
@gstaichi_scope
|
@@ -4,6 +4,10 @@ import inspect
|
|
4
4
|
|
5
5
|
import gstaichi.lang
|
6
6
|
from gstaichi._lib import core as _ti_core
|
7
|
+
from gstaichi._lib.core.gstaichi_python import (
|
8
|
+
BoundaryMode,
|
9
|
+
DataTypeCxx,
|
10
|
+
)
|
7
11
|
from gstaichi.lang import impl, ops
|
8
12
|
from gstaichi.lang._texture import RWTextureAccessor, TextureSampler
|
9
13
|
from gstaichi.lang.any_array import AnyArray
|
@@ -15,11 +19,18 @@ from gstaichi.types.compound_types import CompoundType
|
|
15
19
|
from gstaichi.types.primitive_types import RefType, u64
|
16
20
|
|
17
21
|
|
18
|
-
class
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
22
|
+
class ArgMetadata:
|
23
|
+
"""
|
24
|
+
Metadata about an argument to a function
|
25
|
+
"""
|
26
|
+
|
27
|
+
def __init__(self, annotation, name, default=inspect.Parameter.empty):
|
28
|
+
self.annotation = annotation
|
29
|
+
self.name = name
|
30
|
+
self.default = default
|
31
|
+
|
32
|
+
def __repr__(self) -> str:
|
33
|
+
return f"{self.__class__.__name__}(annotation={self.annotation}, name={self.name}, default={self.default})"
|
23
34
|
|
24
35
|
|
25
36
|
class SparseMatrixEntry:
|
@@ -48,7 +59,7 @@ class SparseMatrixProxy:
|
|
48
59
|
return SparseMatrixEntry(self.ptr, i, j, self.dtype)
|
49
60
|
|
50
61
|
|
51
|
-
def decl_scalar_arg(dtype, name
|
62
|
+
def decl_scalar_arg(dtype, name):
|
52
63
|
is_ref = False
|
53
64
|
if isinstance(dtype, RefType):
|
54
65
|
is_ref = True
|
@@ -60,9 +71,7 @@ def decl_scalar_arg(dtype, name, arg_depth):
|
|
60
71
|
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name)
|
61
72
|
|
62
73
|
argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
63
|
-
return Expr(
|
64
|
-
_ti_core.make_arg_load_expr(arg_id, dtype, is_ref, create_load=True, arg_depth=arg_depth, dbg_info=argload_di)
|
65
|
-
)
|
74
|
+
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref, create_load=True, dbg_info=argload_di))
|
66
75
|
|
67
76
|
|
68
77
|
def get_type_for_kernel_args(dtype, name):
|
@@ -86,35 +95,22 @@ def get_type_for_kernel_args(dtype, name):
|
|
86
95
|
return dtype
|
87
96
|
|
88
97
|
|
89
|
-
def decl_matrix_arg(matrixtype, name
|
98
|
+
def decl_matrix_arg(matrixtype, name):
|
90
99
|
arg_type = get_type_for_kernel_args(matrixtype, name)
|
91
100
|
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
|
92
101
|
argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
93
|
-
arg_load = Expr(
|
94
|
-
_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, arg_depth=arg_depth, dbg_info=argload_di)
|
95
|
-
)
|
102
|
+
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, dbg_info=argload_di))
|
96
103
|
return matrixtype.from_gstaichi_object(arg_load)
|
97
104
|
|
98
105
|
|
99
|
-
def decl_struct_arg(structtype, name
|
106
|
+
def decl_struct_arg(structtype, name):
|
100
107
|
arg_type = get_type_for_kernel_args(structtype, name)
|
101
108
|
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
|
102
109
|
argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
103
|
-
arg_load = Expr(
|
104
|
-
_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, arg_depth=arg_depth, dbg_info=argload_di)
|
105
|
-
)
|
110
|
+
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, dbg_info=argload_di))
|
106
111
|
return structtype.from_gstaichi_object(arg_load)
|
107
112
|
|
108
113
|
|
109
|
-
def push_argpack_arg(name):
|
110
|
-
impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name)
|
111
|
-
|
112
|
-
|
113
|
-
def decl_argpack_arg(argpacktype, member_dict):
|
114
|
-
impl.get_runtime().compiling_callable.pop_argpack_stack()
|
115
|
-
return argpacktype.from_gstaichi_object(member_dict)
|
116
|
-
|
117
|
-
|
118
114
|
def decl_sparse_matrix(dtype, name):
|
119
115
|
value_type = cook_dtype(dtype)
|
120
116
|
ptr_type = cook_dtype(u64)
|
@@ -126,16 +122,18 @@ def decl_sparse_matrix(dtype, name):
|
|
126
122
|
)
|
127
123
|
|
128
124
|
|
129
|
-
def decl_ndarray_arg(
|
125
|
+
def decl_ndarray_arg(
|
126
|
+
element_type: DataTypeCxx, ndim: int, name: str, needs_grad: bool, boundary: BoundaryMode
|
127
|
+
) -> AnyArray:
|
130
128
|
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
|
131
|
-
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad,
|
129
|
+
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
|
132
130
|
|
133
131
|
|
134
132
|
def decl_texture_arg(num_dimensions, name):
|
135
133
|
# FIXME: texture_arg doesn't have element_shape so better separate them
|
136
134
|
arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
|
137
135
|
dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
138
|
-
return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions,
|
136
|
+
return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions, dbg_info), num_dimensions)
|
139
137
|
|
140
138
|
|
141
139
|
def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
|
@@ -143,7 +141,7 @@ def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
|
|
143
141
|
arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name)
|
144
142
|
dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
145
143
|
return RWTextureAccessor(
|
146
|
-
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions,
|
144
|
+
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod, dbg_info), num_dimensions
|
147
145
|
)
|
148
146
|
|
149
147
|
|