gstaichi 1.0.1__cp310-cp310-win_amd64.whl → 2.1.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +1 -3
- gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +13 -41
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/utils.py +1 -7
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -1
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +1 -1
- gstaichi/lang/__init__.py +1 -1
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_template_mapper.py +16 -20
- gstaichi/lang/_wrap_inspect.py +27 -1
- gstaichi/lang/ast/ast_transformer.py +7 -2
- gstaichi/lang/ast/ast_transformer_utils.py +18 -13
- gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
- gstaichi/lang/field.py +0 -38
- gstaichi/lang/impl.py +25 -24
- gstaichi/lang/kernel_arguments.py +28 -30
- gstaichi/lang/kernel_impl.py +154 -200
- gstaichi/lang/matrix.py +0 -46
- gstaichi/lang/struct.py +0 -45
- gstaichi/lang/util.py +11 -80
- gstaichi/types/annotations.py +10 -5
- gstaichi/types/compound_types.py +1 -20
- gstaichi/types/ndarray_type.py +33 -11
- gstaichi/types/utils.py +0 -2
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-2.1.0.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.0.data/data/include/GLFW/glfw3native.h +594 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.0.data/data/lib/glfw3.lib +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/METADATA +4 -3
- {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/RECORD +84 -64
- gstaichi/lang/argpack.py +0 -411
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/WHEEL +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ import re
|
|
9
9
|
import warnings
|
10
10
|
from ast import unparse
|
11
11
|
from collections import ChainMap
|
12
|
+
from typing import Any
|
12
13
|
|
13
14
|
import numpy as np
|
14
15
|
|
@@ -18,6 +19,7 @@ from gstaichi.lang import (
|
|
18
19
|
matrix,
|
19
20
|
)
|
20
21
|
from gstaichi.lang import ops as ti_ops
|
22
|
+
from gstaichi.lang._dataclass_util import create_flat_name
|
21
23
|
from gstaichi.lang.ast.ast_transformer_utils import (
|
22
24
|
ASTTransformerContext,
|
23
25
|
get_decorator,
|
@@ -34,7 +36,7 @@ from gstaichi.types import primitive_types
|
|
34
36
|
|
35
37
|
class CallTransformer:
|
36
38
|
@staticmethod
|
37
|
-
def
|
39
|
+
def _build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
|
38
40
|
from gstaichi.lang import matrix_ops # pylint: disable=C0415
|
39
41
|
|
40
42
|
func = node.func.ptr
|
@@ -64,7 +66,7 @@ class CallTransformer:
|
|
64
66
|
return False
|
65
67
|
|
66
68
|
@staticmethod
|
67
|
-
def
|
69
|
+
def _build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
|
68
70
|
func = node.func.ptr
|
69
71
|
if id(func) in primitive_types.type_ids:
|
70
72
|
if len(args) != 1 or keywords:
|
@@ -82,7 +84,7 @@ class CallTransformer:
|
|
82
84
|
return False
|
83
85
|
|
84
86
|
@staticmethod
|
85
|
-
def
|
87
|
+
def _is_external_func(ctx: ASTTransformerContext, func) -> bool:
|
86
88
|
if ctx.is_in_static_scope(): # allow external function in static scope
|
87
89
|
return False
|
88
90
|
if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"): # gstaichi func/kernel
|
@@ -92,9 +94,9 @@ class CallTransformer:
|
|
92
94
|
return True
|
93
95
|
|
94
96
|
@staticmethod
|
95
|
-
def
|
97
|
+
def _warn_if_is_external_func(ctx: ASTTransformerContext, node):
|
96
98
|
func = node.func.ptr
|
97
|
-
if not CallTransformer.
|
99
|
+
if not CallTransformer._is_external_func(ctx, func):
|
98
100
|
return
|
99
101
|
name = unparse(node.func).strip()
|
100
102
|
warnings.warn_explicit(
|
@@ -120,7 +122,7 @@ class CallTransformer:
|
|
120
122
|
# raw_args: [1.0, 2.0]
|
121
123
|
# raw_keywords: {'k': <ti.Expr>}
|
122
124
|
# return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'], ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]
|
123
|
-
def
|
125
|
+
def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
|
124
126
|
raw_brackets = re.findall(r"{(.*?)}", raw_string)
|
125
127
|
brackets = []
|
126
128
|
unnamed = 0
|
@@ -164,14 +166,18 @@ class CallTransformer:
|
|
164
166
|
return args
|
165
167
|
|
166
168
|
@staticmethod
|
167
|
-
def
|
169
|
+
def _expand_Call_dataclass_args(args: tuple[ast.stmt]) -> tuple[ast.stmt]:
|
170
|
+
"""
|
171
|
+
We require that each node has a .ptr attribute added to it, that contains
|
172
|
+
the associated Python object
|
173
|
+
"""
|
168
174
|
args_new = []
|
169
175
|
for arg in args:
|
170
176
|
val = arg.ptr
|
171
177
|
if dataclasses.is_dataclass(val):
|
172
178
|
dataclass_type = val
|
173
179
|
for field in dataclasses.fields(dataclass_type):
|
174
|
-
child_name =
|
180
|
+
child_name = create_flat_name(arg.id, field.name)
|
175
181
|
load_ctx = ast.Load()
|
176
182
|
arg_node = ast.Name(
|
177
183
|
id=child_name,
|
@@ -181,13 +187,62 @@ class CallTransformer:
|
|
181
187
|
col_offset=arg.col_offset,
|
182
188
|
end_col_offset=arg.end_col_offset,
|
183
189
|
)
|
184
|
-
|
190
|
+
if dataclasses.is_dataclass(field.type):
|
191
|
+
arg_node.ptr = field.type
|
192
|
+
args_new.extend(CallTransformer._expand_Call_dataclass_args((arg_node,)))
|
193
|
+
else:
|
194
|
+
args_new.append(arg_node)
|
185
195
|
else:
|
186
196
|
args_new.append(arg)
|
187
197
|
return tuple(args_new)
|
188
198
|
|
189
199
|
@staticmethod
|
190
|
-
def
|
200
|
+
def _expand_Call_dataclass_kwargs(kwargs: list[ast.keyword]) -> list[ast.keyword]:
|
201
|
+
"""
|
202
|
+
We require that each node has a .ptr attribute added to it, that contains
|
203
|
+
the associated Python object
|
204
|
+
"""
|
205
|
+
kwargs_new = []
|
206
|
+
for i, kwarg in enumerate(kwargs):
|
207
|
+
val = kwarg.ptr[kwarg.arg]
|
208
|
+
if dataclasses.is_dataclass(val):
|
209
|
+
dataclass_type = val
|
210
|
+
for field in dataclasses.fields(dataclass_type):
|
211
|
+
src_name = create_flat_name(kwarg.value.id, field.name)
|
212
|
+
child_name = create_flat_name(kwarg.arg, field.name)
|
213
|
+
load_ctx = ast.Load()
|
214
|
+
src_node = ast.Name(
|
215
|
+
id=src_name,
|
216
|
+
ctx=load_ctx,
|
217
|
+
lineno=kwarg.lineno,
|
218
|
+
end_lineno=kwarg.end_lineno,
|
219
|
+
col_offset=kwarg.col_offset,
|
220
|
+
end_col_offset=kwarg.end_col_offset,
|
221
|
+
)
|
222
|
+
kwarg_node = ast.keyword(
|
223
|
+
arg=child_name,
|
224
|
+
value=src_node,
|
225
|
+
ctx=load_ctx,
|
226
|
+
lineno=kwarg.lineno,
|
227
|
+
end_lineno=kwarg.end_lineno,
|
228
|
+
col_offset=kwarg.col_offset,
|
229
|
+
end_col_offset=kwarg.end_col_offset,
|
230
|
+
)
|
231
|
+
if dataclasses.is_dataclass(field.type):
|
232
|
+
kwarg_node.ptr = {child_name: field.type}
|
233
|
+
kwargs_new.extend(CallTransformer._expand_Call_dataclass_kwargs([kwarg_node]))
|
234
|
+
else:
|
235
|
+
kwargs_new.append(kwarg_node)
|
236
|
+
else:
|
237
|
+
kwargs_new.append(kwarg)
|
238
|
+
return kwargs_new
|
239
|
+
|
240
|
+
@staticmethod
|
241
|
+
def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts) -> Any | None:
|
242
|
+
"""
|
243
|
+
example ast:
|
244
|
+
Call(func=Name(id='f2', ctx=Load()), args=[Name(id='my_struct_ab', ctx=Load())], keywords=[])
|
245
|
+
"""
|
191
246
|
if get_decorator(ctx, node) in ["static", "static_assert"]:
|
192
247
|
with ctx.static_scope_guard():
|
193
248
|
build_stmt(ctx, node.func)
|
@@ -198,7 +253,9 @@ class CallTransformer:
|
|
198
253
|
# creates variable for the dataclass itself (as well as other variables,
|
199
254
|
# not related to dataclasses). Necessary for calling further child functions
|
200
255
|
build_stmts(ctx, node.args)
|
201
|
-
|
256
|
+
build_stmts(ctx, node.keywords)
|
257
|
+
node.args = CallTransformer._expand_Call_dataclass_args(node.args)
|
258
|
+
node.keywords = CallTransformer._expand_Call_dataclass_kwargs(node.keywords)
|
202
259
|
# create variables for the now-expanded dataclass members
|
203
260
|
build_stmts(ctx, node.args)
|
204
261
|
build_stmts(ctx, node.keywords)
|
@@ -223,7 +280,7 @@ class CallTransformer:
|
|
223
280
|
|
224
281
|
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
|
225
282
|
raw_string = node.func.value.ptr
|
226
|
-
args = CallTransformer.
|
283
|
+
args = CallTransformer._canonicalize_formatted_string(raw_string, *args, **keywords)
|
227
284
|
node.ptr = impl.ti_format(*args)
|
228
285
|
return node.ptr
|
229
286
|
|
@@ -231,17 +288,17 @@ class CallTransformer:
|
|
231
288
|
node.ptr = matrix.make_matrix(*args, **keywords)
|
232
289
|
return node.ptr
|
233
290
|
|
234
|
-
if CallTransformer.
|
291
|
+
if CallTransformer._build_call_if_is_builtin(ctx, node, args, keywords):
|
235
292
|
return node.ptr
|
236
293
|
|
237
|
-
if CallTransformer.
|
294
|
+
if CallTransformer._build_call_if_is_type(ctx, node, args, keywords):
|
238
295
|
return node.ptr
|
239
296
|
|
240
297
|
if hasattr(node.func, "caller"):
|
241
298
|
node.ptr = func(node.func.caller, *args, **keywords)
|
242
299
|
return node.ptr
|
243
300
|
|
244
|
-
CallTransformer.
|
301
|
+
CallTransformer._warn_if_is_external_func(ctx, node)
|
245
302
|
try:
|
246
303
|
node.ptr = func(*args, **keywords)
|
247
304
|
except TypeError as e:
|
@@ -249,7 +306,7 @@ class CallTransformer:
|
|
249
306
|
error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
|
250
307
|
func_name = getattr(func, "__name__", func.__class__.__name__)
|
251
308
|
msg = f"TypeError when calling `{func_name}`: {error_msg}."
|
252
|
-
if CallTransformer.
|
309
|
+
if CallTransformer._is_external_func(ctx, node.func.ptr):
|
253
310
|
args_has_expr = any([isinstance(arg, Expr) for arg in args])
|
254
311
|
if args_has_expr and (module == math or module == np):
|
255
312
|
exec_str = f"from gstaichi import {func.__name__}"
|
@@ -4,6 +4,10 @@ import ast
|
|
4
4
|
import dataclasses
|
5
5
|
from typing import Any, Callable
|
6
6
|
|
7
|
+
from gstaichi._lib.core.gstaichi_python import (
|
8
|
+
BoundaryMode,
|
9
|
+
DataTypeCxx,
|
10
|
+
)
|
7
11
|
from gstaichi.lang import (
|
8
12
|
_ndarray,
|
9
13
|
any_array,
|
@@ -13,7 +17,7 @@ from gstaichi.lang import (
|
|
13
17
|
matrix,
|
14
18
|
)
|
15
19
|
from gstaichi.lang import ops as ti_ops
|
16
|
-
from gstaichi.lang.
|
20
|
+
from gstaichi.lang._dataclass_util import create_flat_name
|
17
21
|
from gstaichi.lang.ast.ast_transformer_utils import (
|
18
22
|
ASTTransformerContext,
|
19
23
|
)
|
@@ -29,153 +33,127 @@ from gstaichi.types import annotations, ndarray_type, primitive_types, texture_t
|
|
29
33
|
class FunctionDefTransformer:
|
30
34
|
@staticmethod
|
31
35
|
def _decl_and_create_variable(
|
32
|
-
ctx: ASTTransformerContext,
|
36
|
+
ctx: ASTTransformerContext,
|
37
|
+
annotation: Any,
|
38
|
+
name: str,
|
39
|
+
this_arg_features: tuple[tuple[Any, ...], ...] | None,
|
40
|
+
prefix_name: str,
|
33
41
|
) -> tuple[bool, Any]:
|
34
42
|
full_name = prefix_name + "_" + name
|
35
43
|
if not isinstance(annotation, primitive_types.RefType):
|
36
44
|
ctx.kernel_args.append(name)
|
37
|
-
if isinstance(annotation, ArgPackType):
|
38
|
-
kernel_arguments.push_argpack_arg(name)
|
39
|
-
d = {}
|
40
|
-
items_to_put_in_dict = []
|
41
|
-
for j, (_name, anno) in enumerate(annotation.members.items()):
|
42
|
-
result, obj = FunctionDefTransformer._decl_and_create_variable(
|
43
|
-
ctx, anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
|
44
|
-
)
|
45
|
-
if not result:
|
46
|
-
d[_name] = None
|
47
|
-
items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
|
48
|
-
else:
|
49
|
-
d[_name] = obj
|
50
|
-
argpack = kernel_arguments.decl_argpack_arg(annotation, d)
|
51
|
-
for item in items_to_put_in_dict:
|
52
|
-
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
|
53
|
-
return True, argpack
|
54
45
|
if annotation == annotations.template or isinstance(annotation, annotations.template):
|
46
|
+
assert ctx.global_vars is not None
|
55
47
|
return True, ctx.global_vars[name]
|
56
48
|
if isinstance(annotation, annotations.sparse_matrix_builder):
|
57
49
|
return False, (
|
58
50
|
kernel_arguments.decl_sparse_matrix,
|
59
51
|
(
|
60
|
-
to_gstaichi_type(
|
52
|
+
to_gstaichi_type(this_arg_features),
|
61
53
|
full_name,
|
62
54
|
),
|
63
55
|
)
|
64
56
|
if isinstance(annotation, ndarray_type.NdarrayType):
|
57
|
+
assert this_arg_features is not None
|
58
|
+
raw_element_type: DataTypeCxx
|
59
|
+
ndim: int
|
60
|
+
needs_grad: bool
|
61
|
+
boundary: BoundaryMode
|
62
|
+
raw_element_type, ndim, needs_grad, boundary = this_arg_features
|
65
63
|
return False, (
|
66
64
|
kernel_arguments.decl_ndarray_arg,
|
67
65
|
(
|
68
|
-
to_gstaichi_type(
|
69
|
-
|
66
|
+
to_gstaichi_type(raw_element_type),
|
67
|
+
ndim,
|
70
68
|
full_name,
|
71
|
-
|
72
|
-
|
69
|
+
needs_grad,
|
70
|
+
boundary,
|
73
71
|
),
|
74
72
|
)
|
75
73
|
if isinstance(annotation, texture_type.TextureType):
|
76
|
-
|
74
|
+
assert this_arg_features is not None
|
75
|
+
return False, (kernel_arguments.decl_texture_arg, (this_arg_features[0], full_name))
|
77
76
|
if isinstance(annotation, texture_type.RWTextureType):
|
77
|
+
assert this_arg_features is not None
|
78
78
|
return False, (
|
79
79
|
kernel_arguments.decl_rw_texture_arg,
|
80
|
-
(
|
80
|
+
(this_arg_features[0], this_arg_features[1], this_arg_features[2], full_name),
|
81
81
|
)
|
82
82
|
if isinstance(annotation, MatrixType):
|
83
|
-
return True, kernel_arguments.decl_matrix_arg(annotation, name
|
83
|
+
return True, kernel_arguments.decl_matrix_arg(annotation, name)
|
84
84
|
if isinstance(annotation, StructType):
|
85
|
-
return True, kernel_arguments.decl_struct_arg(annotation, name
|
86
|
-
return True, kernel_arguments.decl_scalar_arg(annotation, name
|
85
|
+
return True, kernel_arguments.decl_struct_arg(annotation, name)
|
86
|
+
return True, kernel_arguments.decl_scalar_arg(annotation, name)
|
87
87
|
|
88
88
|
@staticmethod
|
89
89
|
def _transform_kernel_arg(
|
90
90
|
ctx: ASTTransformerContext,
|
91
|
-
invoke_later_dict: dict[str, tuple[Any, str, Callable, list[Any]]],
|
92
|
-
create_variable_later: dict[str, Any],
|
93
91
|
argument_name: str,
|
94
92
|
argument_type: Any,
|
95
93
|
this_arg_features: tuple[Any, ...],
|
96
94
|
) -> None:
|
97
|
-
if
|
98
|
-
kernel_arguments.push_argpack_arg(argument_name)
|
99
|
-
d = {}
|
100
|
-
items_to_put_in_dict: list[tuple[str, str, Any]] = []
|
101
|
-
for j, (name, anno) in enumerate(argument_type.members.items()):
|
102
|
-
result, obj = FunctionDefTransformer._decl_and_create_variable(
|
103
|
-
ctx, anno, name, this_arg_features[j], invoke_later_dict, "__argpack_" + name, 1
|
104
|
-
)
|
105
|
-
if not result:
|
106
|
-
d[name] = None
|
107
|
-
items_to_put_in_dict.append(("__argpack_" + name, name, obj))
|
108
|
-
else:
|
109
|
-
d[name] = obj
|
110
|
-
argpack = kernel_arguments.decl_argpack_arg(argument_type, d)
|
111
|
-
for item in items_to_put_in_dict:
|
112
|
-
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
|
113
|
-
create_variable_later[argument_name] = argpack
|
114
|
-
elif dataclasses.is_dataclass(argument_type):
|
115
|
-
arg_features = this_arg_features
|
95
|
+
if dataclasses.is_dataclass(argument_type):
|
116
96
|
ctx.create_variable(argument_name, argument_type)
|
117
97
|
for field_idx, field in enumerate(dataclasses.fields(argument_type)):
|
118
|
-
flat_name =
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
)
|
128
|
-
if result:
|
129
|
-
ctx.create_variable(flat_name, obj)
|
98
|
+
flat_name = create_flat_name(argument_name, field.name)
|
99
|
+
# if a field is a dataclass, then feed back into process_kernel_arg recursively
|
100
|
+
if dataclasses.is_dataclass(field.type):
|
101
|
+
FunctionDefTransformer._transform_kernel_arg(
|
102
|
+
ctx,
|
103
|
+
flat_name,
|
104
|
+
field.type,
|
105
|
+
this_arg_features[field_idx],
|
106
|
+
)
|
130
107
|
else:
|
131
|
-
|
132
|
-
|
133
|
-
|
108
|
+
result, obj = FunctionDefTransformer._decl_and_create_variable(
|
109
|
+
ctx,
|
110
|
+
field.type,
|
111
|
+
flat_name,
|
112
|
+
this_arg_features[field_idx],
|
113
|
+
"",
|
114
|
+
)
|
115
|
+
if result:
|
116
|
+
ctx.create_variable(flat_name, obj)
|
117
|
+
else:
|
118
|
+
decl_type_func, type_args = obj
|
119
|
+
obj = decl_type_func(*type_args)
|
120
|
+
ctx.create_variable(flat_name, obj)
|
134
121
|
else:
|
135
122
|
result, obj = FunctionDefTransformer._decl_and_create_variable(
|
136
123
|
ctx,
|
137
124
|
argument_type,
|
138
125
|
argument_name,
|
139
126
|
this_arg_features if ctx.arg_features is not None else None,
|
140
|
-
invoke_later_dict,
|
141
127
|
"",
|
142
|
-
0,
|
143
128
|
)
|
144
|
-
if result:
|
145
|
-
ctx.create_variable(argument_name, obj)
|
146
|
-
else:
|
129
|
+
if not result:
|
147
130
|
decl_type_func, type_args = obj
|
148
131
|
obj = decl_type_func(*type_args)
|
149
|
-
|
132
|
+
ctx.create_variable(argument_name, obj)
|
150
133
|
|
151
134
|
@staticmethod
|
152
135
|
def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
|
136
|
+
assert ctx.func is not None
|
137
|
+
assert ctx.arg_features is not None
|
153
138
|
if node.returns is not None:
|
154
139
|
if not isinstance(node.returns, ast.Constant):
|
140
|
+
assert ctx.func.return_type is not None
|
155
141
|
for return_type in ctx.func.return_type:
|
156
142
|
kernel_arguments.decl_ret(return_type)
|
157
|
-
impl.get_runtime().compiling_callable
|
143
|
+
compiling_callable = impl.get_runtime().compiling_callable
|
144
|
+
assert compiling_callable is not None
|
145
|
+
compiling_callable.finalize_rets()
|
158
146
|
|
159
|
-
|
160
|
-
|
161
|
-
for i, arg in enumerate(args.args):
|
162
|
-
argument = ctx.func.arguments[i]
|
147
|
+
for i in range(len(args.args)):
|
148
|
+
arg_meta = ctx.func.arg_metas[i]
|
163
149
|
FunctionDefTransformer._transform_kernel_arg(
|
164
150
|
ctx,
|
165
|
-
|
166
|
-
|
167
|
-
argument.name,
|
168
|
-
argument.annotation,
|
151
|
+
arg_meta.name,
|
152
|
+
arg_meta.annotation,
|
169
153
|
ctx.arg_features[i] if ctx.arg_features is not None else (),
|
170
154
|
)
|
171
155
|
|
172
|
-
|
173
|
-
argpack, name, func, params = v
|
174
|
-
argpack[name] = func(*params)
|
175
|
-
for k, v in create_variable_later.items():
|
176
|
-
ctx.create_variable(k, v)
|
177
|
-
|
178
|
-
impl.get_runtime().compiling_callable.finalize_params()
|
156
|
+
compiling_callable.finalize_params()
|
179
157
|
# remove original args
|
180
158
|
node.args.args = []
|
181
159
|
|
@@ -186,15 +164,16 @@ class FunctionDefTransformer:
|
|
186
164
|
argument_type: Any,
|
187
165
|
data: Any,
|
188
166
|
) -> None:
|
167
|
+
# Template arguments are passed by reference.
|
189
168
|
if isinstance(argument_type, annotations.template):
|
190
169
|
ctx.create_variable(argument_name, data)
|
191
170
|
return None
|
192
171
|
|
193
172
|
if dataclasses.is_dataclass(argument_type):
|
194
|
-
|
195
|
-
|
173
|
+
for field in dataclasses.fields(argument_type):
|
174
|
+
flat_name = create_flat_name(argument_name, field.name)
|
196
175
|
data_child = getattr(data, field.name)
|
197
|
-
if
|
176
|
+
if isinstance(
|
198
177
|
data_child,
|
199
178
|
(
|
200
179
|
_ndarray.ScalarNdarray,
|
@@ -203,33 +182,33 @@ class FunctionDefTransformer:
|
|
203
182
|
any_array.AnyArray,
|
204
183
|
),
|
205
184
|
):
|
185
|
+
field.type.check_matched(data_child.get_type(), field.name)
|
186
|
+
ctx.create_variable(flat_name, data_child)
|
187
|
+
elif dataclasses.is_dataclass(data_child):
|
188
|
+
FunctionDefTransformer._transform_func_arg(
|
189
|
+
ctx,
|
190
|
+
flat_name,
|
191
|
+
field.type,
|
192
|
+
getattr(data, field.name),
|
193
|
+
)
|
194
|
+
else:
|
206
195
|
raise GsTaichiSyntaxError(
|
207
|
-
f"Argument {
|
196
|
+
f"Argument {field.name} of type {argument_type} {field.type} is not recognized."
|
208
197
|
)
|
209
|
-
field.type.check_matched(data_child.get_type(), field.name)
|
210
|
-
var_name = f"__ti_{argument_name}_{field.name}"
|
211
|
-
ctx.create_variable(var_name, data_child)
|
212
198
|
return None
|
213
199
|
|
214
200
|
# Ndarray arguments are passed by reference.
|
215
201
|
if isinstance(argument_type, (ndarray_type.NdarrayType)):
|
216
202
|
if not isinstance(
|
217
|
-
data,
|
218
|
-
(
|
219
|
-
_ndarray.ScalarNdarray,
|
220
|
-
matrix.VectorNdarray,
|
221
|
-
matrix.MatrixNdarray,
|
222
|
-
any_array.AnyArray,
|
223
|
-
),
|
203
|
+
data, (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray)
|
224
204
|
):
|
225
|
-
raise GsTaichiSyntaxError(f"Argument {
|
205
|
+
raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
|
226
206
|
argument_type.check_matched(data.get_type(), argument_name)
|
227
207
|
ctx.create_variable(argument_name, data)
|
228
208
|
return None
|
229
209
|
|
230
210
|
# Matrix arguments are passed by value.
|
231
211
|
if isinstance(argument_type, (MatrixType)):
|
232
|
-
var_name = argument_name
|
233
212
|
# "data" is expected to be an Expr here,
|
234
213
|
# so we simply call "impl.expr_init_func(data)" to perform:
|
235
214
|
#
|
@@ -239,32 +218,31 @@ class FunctionDefTransformer:
|
|
239
218
|
# We created local variable "t" - a copy of the passed-in argument "data"
|
240
219
|
if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
|
241
220
|
raise GsTaichiSyntaxError(
|
242
|
-
f"Argument {
|
221
|
+
f"Argument {argument_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
|
243
222
|
)
|
244
223
|
|
245
224
|
element_shape = data.ptr.get_rvalue_type().shape()
|
246
225
|
if len(element_shape) != argument_type.ndim:
|
247
226
|
raise GsTaichiSyntaxError(
|
248
|
-
f"Argument {
|
227
|
+
f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
|
249
228
|
)
|
250
229
|
|
251
230
|
assert argument_type.ndim > 0
|
252
231
|
if element_shape[0] != argument_type.n:
|
253
232
|
raise GsTaichiSyntaxError(
|
254
|
-
f"Argument {
|
233
|
+
f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
|
255
234
|
)
|
256
235
|
|
257
236
|
if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
|
258
237
|
raise GsTaichiSyntaxError(
|
259
|
-
f"Argument {
|
238
|
+
f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[0]}."
|
260
239
|
)
|
261
240
|
|
262
|
-
ctx.create_variable(
|
241
|
+
ctx.create_variable(argument_name, impl.expr_init_func(data))
|
263
242
|
return None
|
264
243
|
|
265
244
|
if id(argument_type) in primitive_types.type_ids:
|
266
|
-
|
267
|
-
ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
|
245
|
+
ctx.create_variable(argument_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
|
268
246
|
return None
|
269
247
|
# Create a copy for non-template arguments,
|
270
248
|
# so that they are passed by value.
|
@@ -274,15 +252,16 @@ class FunctionDefTransformer:
|
|
274
252
|
|
275
253
|
@staticmethod
|
276
254
|
def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
|
255
|
+
# pylint: disable=import-outside-toplevel
|
256
|
+
from gstaichi.lang.kernel_impl import Func
|
257
|
+
|
258
|
+
assert isinstance(ctx.func, Func)
|
259
|
+
assert ctx.argument_data is not None
|
277
260
|
for data_i, data in enumerate(ctx.argument_data):
|
278
|
-
argument = ctx.func.
|
279
|
-
FunctionDefTransformer._transform_func_arg(
|
280
|
-
ctx,
|
281
|
-
argument.name,
|
282
|
-
argument.annotation,
|
283
|
-
data,
|
284
|
-
)
|
261
|
+
argument = ctx.func.arg_metas[data_i]
|
262
|
+
FunctionDefTransformer._transform_func_arg(ctx, argument.name, argument.annotation, data)
|
285
263
|
|
264
|
+
# deal with dataclasses
|
286
265
|
for v in ctx.func.orig_arguments:
|
287
266
|
if dataclasses.is_dataclass(v.annotation):
|
288
267
|
ctx.create_variable(v.name, v.annotation)
|
@@ -308,7 +287,12 @@ class FunctionDefTransformer:
|
|
308
287
|
if ctx.is_kernel: # ti.kernel
|
309
288
|
FunctionDefTransformer._transform_as_kernel(ctx, node, args)
|
310
289
|
|
311
|
-
|
290
|
+
if ctx.only_parse_function_def:
|
291
|
+
return None
|
292
|
+
|
293
|
+
if not ctx.is_kernel: # ti.func
|
294
|
+
assert ctx.argument_data is not None
|
295
|
+
assert ctx.func is not None
|
312
296
|
if ctx.is_real_function:
|
313
297
|
FunctionDefTransformer._transform_as_kernel(ctx, node, args)
|
314
298
|
else:
|
gstaichi/lang/field.py
CHANGED
@@ -9,7 +9,6 @@ from gstaichi.lang.util import (
|
|
9
9
|
in_python_scope,
|
10
10
|
python_scope,
|
11
11
|
to_numpy_type,
|
12
|
-
to_paddle_type,
|
13
12
|
to_pytorch_type,
|
14
13
|
)
|
15
14
|
|
@@ -152,18 +151,6 @@ class Field:
|
|
152
151
|
"""
|
153
152
|
raise NotImplementedError()
|
154
153
|
|
155
|
-
@python_scope
|
156
|
-
def to_paddle(self, place=None):
|
157
|
-
"""Converts `self` to a paddle tensor.
|
158
|
-
|
159
|
-
Args:
|
160
|
-
place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
|
161
|
-
|
162
|
-
Returns:
|
163
|
-
paddle.Tensor: The result paddle tensor.
|
164
|
-
"""
|
165
|
-
raise NotImplementedError()
|
166
|
-
|
167
154
|
@python_scope
|
168
155
|
def from_numpy(self, arr):
|
169
156
|
"""Loads all elements from a numpy array.
|
@@ -190,17 +177,6 @@ class Field:
|
|
190
177
|
"""
|
191
178
|
self._from_external_arr(arr.contiguous())
|
192
179
|
|
193
|
-
@python_scope
|
194
|
-
def from_paddle(self, arr):
|
195
|
-
"""Loads all elements from a paddle tensor.
|
196
|
-
|
197
|
-
The shape of the paddle tensor needs to be the same as `self`.
|
198
|
-
|
199
|
-
Args:
|
200
|
-
arr (paddle.Tensor): The source paddle tensor.
|
201
|
-
"""
|
202
|
-
self.from_numpy(arr)
|
203
|
-
|
204
180
|
@python_scope
|
205
181
|
def copy_from(self, other):
|
206
182
|
"""Copies all elements from another field.
|
@@ -325,20 +301,6 @@ class ScalarField(Field):
|
|
325
301
|
gstaichi.lang.runtime_ops.sync()
|
326
302
|
return arr
|
327
303
|
|
328
|
-
@python_scope
|
329
|
-
def to_paddle(self, place=None):
|
330
|
-
"""Converts this field to a `paddle.Tensor`."""
|
331
|
-
import paddle # pylint: disable=C0415
|
332
|
-
|
333
|
-
# pylint: disable=E1101
|
334
|
-
# paddle.empty() doesn't support argument `place``
|
335
|
-
arr = paddle.to_tensor(paddle.zeros(self.shape, to_paddle_type(self.dtype)), place=place)
|
336
|
-
from gstaichi._kernels import tensor_to_ext_arr # pylint: disable=C0415
|
337
|
-
|
338
|
-
tensor_to_ext_arr(self, arr)
|
339
|
-
gstaichi.lang.runtime_ops.sync()
|
340
|
-
return arr
|
341
|
-
|
342
304
|
@python_scope
|
343
305
|
def _from_external_arr(self, arr):
|
344
306
|
if len(self.shape) != len(arr.shape):
|