procfunc 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- procfunc/__init__.py +87 -0
- procfunc/color.py +57 -0
- procfunc/compute_graph/__init__.py +28 -0
- procfunc/compute_graph/compute_graph.py +115 -0
- procfunc/compute_graph/node.py +200 -0
- procfunc/compute_graph/operators_info.py +92 -0
- procfunc/compute_graph/proxy.py +173 -0
- procfunc/compute_graph/util.py +282 -0
- procfunc/context.py +115 -0
- procfunc/control.py +174 -0
- procfunc/nodes/__init__.py +66 -0
- procfunc/nodes/bindings_util.py +196 -0
- procfunc/nodes/bpy_node_info.py +280 -0
- procfunc/nodes/compositor.py +2242 -0
- procfunc/nodes/execute/construct_nodes.py +571 -0
- procfunc/nodes/execute/construct_special_cases.py +246 -0
- procfunc/nodes/execute/execute.py +548 -0
- procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
- procfunc/nodes/execute/util.py +247 -0
- procfunc/nodes/func.py +1417 -0
- procfunc/nodes/geo.py +4240 -0
- procfunc/nodes/manifest.json +8769 -0
- procfunc/nodes/math.py +644 -0
- procfunc/nodes/node_function.py +160 -0
- procfunc/nodes/shader.py +2359 -0
- procfunc/nodes/types.py +347 -0
- procfunc/ops/__init__.py +35 -0
- procfunc/ops/_util.py +275 -0
- procfunc/ops/addons.py +59 -0
- procfunc/ops/attr.py +426 -0
- procfunc/ops/collection.py +90 -0
- procfunc/ops/curve.py +18 -0
- procfunc/ops/file.py +126 -0
- procfunc/ops/manifest.json +39149 -0
- procfunc/ops/mesh.py +1510 -0
- procfunc/ops/modifier.py +603 -0
- procfunc/ops/object.py +258 -0
- procfunc/ops/primitives/__init__.py +31 -0
- procfunc/ops/primitives/camera.py +45 -0
- procfunc/ops/primitives/curve.py +71 -0
- procfunc/ops/primitives/light.py +114 -0
- procfunc/ops/primitives/mesh.py +358 -0
- procfunc/ops/uv.py +271 -0
- procfunc/random.py +247 -0
- procfunc/tracer/__init__.py +43 -0
- procfunc/tracer/decorator.py +121 -0
- procfunc/tracer/patch.py +494 -0
- procfunc/tracer/proxy.py +127 -0
- procfunc/tracer/trace.py +222 -0
- procfunc/transforms/__init__.py +49 -0
- procfunc/transforms/cleanup.py +214 -0
- procfunc/transforms/convert.py +20 -0
- procfunc/transforms/distribution.py +191 -0
- procfunc/transforms/extract_materials.py +116 -0
- procfunc/transforms/infer_distribution.py +326 -0
- procfunc/transforms/parameters.py +15 -0
- procfunc/transforms/util.py +35 -0
- procfunc/transpiler/__init__.py +24 -0
- procfunc/transpiler/bpy_to_computegraph.py +1348 -0
- procfunc/transpiler/codegen.py +919 -0
- procfunc/transpiler/identifiers.py +595 -0
- procfunc/transpiler/main.py +299 -0
- procfunc/types.py +380 -0
- procfunc/util/__init__.py +0 -0
- procfunc/util/bpy_info.py +145 -0
- procfunc/util/camera.py +0 -0
- procfunc/util/keyframe.py +70 -0
- procfunc/util/log.py +96 -0
- procfunc/util/manifest.py +121 -0
- procfunc/util/pytree.py +343 -0
- procfunc/util/teardown.py +37 -0
- procfunc-0.30.0.dist-info/METADATA +120 -0
- procfunc-0.30.0.dist-info/RECORD +76 -0
- procfunc-0.30.0.dist-info/WHEEL +5 -0
- procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
- procfunc-0.30.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,919 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import enum
|
|
3
|
+
import inspect
|
|
4
|
+
import itertools
|
|
5
|
+
import logging
|
|
6
|
+
from collections import OrderedDict, defaultdict
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Callable, Generator, Union, get_args, get_origin
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import procfunc as pf
|
|
13
|
+
from procfunc import compute_graph as cg
|
|
14
|
+
from procfunc.compute_graph.operators_info import (
|
|
15
|
+
FUNCTIONS_TO_OPERATORS,
|
|
16
|
+
OPERATOR_TEMPLATES,
|
|
17
|
+
OperatorType,
|
|
18
|
+
)
|
|
19
|
+
from procfunc.nodes import types as nt
|
|
20
|
+
from procfunc.transpiler import identifiers
|
|
21
|
+
from procfunc.util import pytree
|
|
22
|
+
|
|
23
|
+
# Module-level logger for codegen diagnostics: warnings about unresolved proxies /
# ProcNodes and renamed inputs, plus debug traces of generated function inputs.
logger = logging.getLogger(__name__)

# Indentation unit that indent_lines() prepends to nested generated source lines.
# NOTE(review): this file was recovered from a rendering that may collapse runs of
# whitespace; confirm the intended width of this literal (conventionally four spaces).
INDENT = " "
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def indent_lines(lines: list[str], indent: str = INDENT) -> list[str]:
    """Prefix each generated source line with ``indent``.

    Empty lines are returned unchanged so that blank paragraph separators inside
    generated bodies do not turn into whitespace-only lines.

    Args:
        lines: Lines of generated source code.
        indent: Prefix to prepend; defaults to the module-level ``INDENT``.

    Returns:
        A new list; ``lines`` itself is not modified.
    """
    return [indent + line if line else line for line in lines]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _repr_type(x: Any) -> str:
|
|
33
|
+
# TODO: make the user pass in special resolutions for types, or else we will just do verbose types
|
|
34
|
+
|
|
35
|
+
if isinstance(x, str):
|
|
36
|
+
return x
|
|
37
|
+
|
|
38
|
+
if x.__name__ == "NoneType":
|
|
39
|
+
return "None"
|
|
40
|
+
|
|
41
|
+
origin = get_origin(x)
|
|
42
|
+
args = get_args(x)
|
|
43
|
+
|
|
44
|
+
if x.__name__ == "ProcNode":
|
|
45
|
+
if len(args) == 1:
|
|
46
|
+
return f"pf.ProcNode[{_repr_type(args[0])}]"
|
|
47
|
+
elif len(args) == 0:
|
|
48
|
+
return "pf.ProcNode"
|
|
49
|
+
else:
|
|
50
|
+
raise ValueError(f"Unsupported ProcNode type: {x} {args=}")
|
|
51
|
+
|
|
52
|
+
if hasattr(pf, x.__name__):
|
|
53
|
+
if len(args):
|
|
54
|
+
raise ValueError(f"procfunc type had unhandled annotations: {x} {args=}")
|
|
55
|
+
return f"pf.{x.__name__}"
|
|
56
|
+
|
|
57
|
+
if x.__module__ == "builtins":
|
|
58
|
+
return x.__name__
|
|
59
|
+
|
|
60
|
+
origin = get_origin(x)
|
|
61
|
+
args = get_args(x)
|
|
62
|
+
|
|
63
|
+
if origin is Union:
|
|
64
|
+
args_0 = get_args(args[0])
|
|
65
|
+
if get_origin(args[0]) is nt.ProcNode and args_0[0] is args[1]:
|
|
66
|
+
return f"t.SocketOrVal[{_repr_type(args_0[0])}]"
|
|
67
|
+
else:
|
|
68
|
+
return " | ".join([_repr_type(a) for a in args])
|
|
69
|
+
|
|
70
|
+
if getattr(x, "__module__", None) == "procfunc.nodes.types":
|
|
71
|
+
return f"t.{x.__name__}"
|
|
72
|
+
|
|
73
|
+
return x.__name__
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _repr_value(value: Any) -> str:
    """Render a plain (non-graph) Python value as a source-code expression.

    Handling, in order:
      * objects exposing ``__wrapped__`` are unwrapped one level first;
      * ``cg.Proxy`` / ``nt.ProcNode`` values only log a warning, since graph values
        should have been resolved to variables before reaching this point;
      * a numpy ``Generator`` becomes ``np.random.default_rng()`` (its state/seed is
        not reproduced);
      * classes are rendered with ``_repr_type``;
      * numeric/boolean numpy arrays become ``np.array(<nested list>, dtype='<dtype>')``;
        other arrays fall back to numpy's own repr prefixed with ``np.``;
      * numpy dtypes become ``np.dtype('<name>')``;
      * ``pf.Color`` / ``Vector`` / ``Euler`` / ``Quaternion`` / ``Matrix`` become
        ``pf.<Class>(<tuple rounded to 6 decimals>)``;
      * enum members become ``<EnumClass>.<MEMBER>`` (the enum class must be in
        scope where the generated code runs);
      * ``Path`` becomes ``Path('<str>')``;
      * dataclass instances become ``<Class>(field=..., ...)`` with recursively
        rendered fields;
      * lists and plain (non-subclassed) tuples are rendered element-wise;
      * anything else falls back to ``repr``.
    """
    if hasattr(value, "__wrapped__"):
        value = value.__wrapped__

    if isinstance(value, cg.Proxy):
        logger.warning(
            f"Proxy object {value} should never appear as a raw value in codegen - "
            f"its underlying node {value.node} was not resolved to a variable"
        )
    if isinstance(value, nt.ProcNode):
        logger.warning(
            f"Procnode object {value} should never be treated as a raw value in codegen"
        )

    if isinstance(value, np.random.Generator):
        return "np.random.default_rng()"
    elif isinstance(value, type):
        return _repr_type(value)
    elif isinstance(value, np.ndarray):
        if value.dtype.kind in "biufc":
            # Build from tolist() rather than repr(): repr() emits bare dtype names
            # (``dtype=float32``) that are undefined in generated code, summarizes
            # large arrays with "...", and rounds floats to the print precision.
            return f"np.array({value.tolist()!r}, dtype={str(value.dtype)!r})"
        # Non-numeric dtypes (strings, datetimes, objects, ...) keep numpy's repr,
        # which quotes those dtypes.
        nprepr = repr(value).replace("\n", "")
        return f"np.{nprepr}"
    elif isinstance(value, np.dtype):
        return f"np.dtype('{value}')"
    elif isinstance(value, (pf.Color, pf.Vector, pf.Euler, pf.Quaternion, pf.Matrix)):
        components = tuple(round(c, 6) for c in value)
        return f"pf.{value.__class__.__name__}({components})"
    elif isinstance(value, enum.Enum):
        return f"{type(value).__name__}.{value.name}"
    elif isinstance(value, Path):
        return f"Path({str(value)!r})"
    elif dataclasses.is_dataclass(value) and not isinstance(value, type):
        args_str = ", ".join(
            f"{f.name}={_repr_value(getattr(value, f.name))}"
            for f in dataclasses.fields(value)
        )
        return f"{type(value).__name__}({args_str})"
    elif isinstance(value, list):
        return f"[{', '.join([_repr_value(x) for x in value])}]"
    elif type(value) is tuple:
        # Render elements like lists do, so enums/paths/arrays inside tuples become
        # valid code; namedtuples and other tuple subclasses keep their repr.
        items = [_repr_value(x) for x in value]
        if len(items) == 1:
            return f"({items[0]},)"
        return f"({', '.join(items)})"
    else:
        return repr(value)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _has_toplevel_space(expr: str) -> bool:
    """Return True if ``expr`` has a space outside all brackets and string literals."""
    # A top-level space signals a top-level operator or keyword (``a + b``,
    # ``x if c else y``) whose precedence may change when embedded in a larger expression.
    depth = 0
    quote: str | None = None
    escaped = False
    for ch in expr:
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in "'\"":
            quote = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == " " and depth == 0:
            return True
    return False


def _repr_inp(
    arg: Any,
    scope_expressions: dict[int, str | list[str]],
    extra_parens: bool = False,
) -> str:
    """Render an input to an expression: a graph node's name/inlined expression or a literal value.

    Args:
        arg: A ``cg.Node`` (looked up by ``id`` in ``scope_expressions``) or a plain
            value (rendered with ``_repr_value``).
        scope_expressions: Maps ``id(node)`` to a variable name or to a one-element list
            holding an inlined expression.
        extra_parens: If True, wrap the result in parentheses when it contains a space
            outside brackets/strings, so it can be embedded safely in a larger expression
            (e.g. as an operator operand or a method-call receiver).

    Raises:
        ValueError: if a node is missing from ``scope_expressions`` or resolves to
            zero or more than one line.
    """
    if isinstance(arg, cg.Node):
        if id(arg) not in scope_expressions:
            raise ValueError(
                f"Scope expressions {scope_expressions} did not contain {arg=} possibly due to bad visit ordering"
            )
        expr = scope_expressions[id(arg)]
    else:
        expr = _repr_value(arg)

    if isinstance(expr, list):
        if len(expr) > 1:
            raise ValueError(
                "Inlined values should not resolve to more than one line in current implementation, "
                f"got {expr=} for {arg=}"
            )
        if len(expr) == 0:
            raise ValueError(f"Inlined value resolved to no code at all: {expr=} for {arg=}")
        expr = expr[0]
    assert isinstance(expr, str)

    # Checking only the first/last character is not enough: "a + f(b)" ends with ")"
    # and "(a + b) * c" starts with "(", yet both need wrapping.
    if extra_parens and _has_toplevel_space(expr):
        return f"({expr})"
    return expr
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _kwarg_matches_default(sig: inspect.Signature, key: str, value: Any) -> bool:
    """Report whether passing ``key=value`` would be redundant given ``sig``'s defaults.

    Graph values (nodes and proxies) never count as matching a default. An unknown
    parameter, a parameter without a default, or a comparison that raises all yield False.
    """
    if isinstance(value, (cg.Node, cg.Proxy)):
        return False

    param = sig.parameters.get(key)
    lacks_default = param is None or param.default is inspect.Parameter.empty
    if lacks_default:
        return False

    try:
        # Both the comparison and its truth value may raise for exotic types
        # (e.g. elementwise comparisons with no single truth value).
        matches = value == param.default
        return bool(matches)
    except Exception:
        return False
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _repr_args(
    func: Callable[..., Any] | None,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    scope_expressions: dict[int, str | list[str]],
) -> list[str]:
    """
    Create string for arg and kwarg def for function inputs.

    When ``func`` has an introspectable signature, keyword arguments equal to their
    default are dropped, keyword arguments are ordered by the signature, and a lone
    keyword argument naming the first positional-capable parameter is emitted
    positionally instead.

    Returns:
        Positional argument strings followed by ``"name=value"`` strings.
    """
    try:
        sig = inspect.signature(func) if func is not None else None
    except (ValueError, TypeError):
        # Some callables (certain builtins / extension objects) cannot be introspected.
        sig = None

    if sig is not None:
        kwargs = {
            k: v for k, v in kwargs.items() if not _kwarg_matches_default(sig, k, v)
        }

    # common specialcase: nodes with a single output which would unnecessarily be a kwarg can just be a positional arg instead
    if len(args) == 0 and len(kwargs) == 1 and sig is not None:
        (only_key,) = kwargs
        first_param = next(iter(sig.parameters.values()), None)
        # Only convert when the parameter can actually be passed positionally;
        # keyword-only or *args parameters would make the generated call invalid.
        if (
            first_param is not None
            and first_param.name == only_key
            and first_param.kind
            in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
        ):
            args = (kwargs[only_key],)
            kwargs = {}

    argreprs_tree = pytree.PyTree(args).map(lambda x: _repr_inp(x, scope_expressions))
    argreprs = [
        pytree.repr_tree_to_str(v, type_namer=_repr_type)
        for v in argreprs_tree.unflatten_one_level()
    ]

    kwargreprs_tree = (
        pytree.PyTree(kwargs)
        .map(lambda x: _repr_inp(x, scope_expressions))
        .unflatten_one_level()
    )
    kwargreprs = {
        k: pytree.repr_tree_to_str(v, type_namer=_repr_type)
        for k, v in kwargreprs_tree.items()
    }

    # use func sig to sort kwargs
    if sig is not None:
        kwargkeys = list(sig.parameters.keys())
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )
        if not has_var_keyword:
            assert set(kwargreprs.keys()).issubset(set(kwargkeys)), (
                f"{kwargreprs.keys()=} {kwargkeys=}"
            )
        else:
            # extra keywords absorbed by **kwargs keep their insertion order after named params
            kwargkeys = kwargkeys + [k for k in kwargreprs.keys() if k not in kwargkeys]
    else:
        kwargkeys = list(kwargs.keys())

    kwarglist = [f"{k}={kwargreprs[k]}" for k in kwargkeys if k in kwargreprs]

    return argreprs + kwarglist
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _repr_function_call(
|
|
222
|
+
node: cg.FunctionCallNode | cg.MethodCallNode | cg.SubgraphCallNode,
|
|
223
|
+
scope_expressions: dict[int, str | list[str]],
|
|
224
|
+
line_limit: int = 80,
|
|
225
|
+
) -> list[str]:
|
|
226
|
+
match node:
|
|
227
|
+
case cg.FunctionCallNode():
|
|
228
|
+
func = node.func
|
|
229
|
+
func_str = scope_expressions[id(func)]
|
|
230
|
+
case cg.MethodCallNode(args=(target, *_), method_name=method_name):
|
|
231
|
+
if not isinstance(target, cg.Node):
|
|
232
|
+
raise ValueError(f"Method call {node=} has non-node target {target=}")
|
|
233
|
+
func = None
|
|
234
|
+
func_str = f"{_repr_inp(target, scope_expressions)}.{method_name}"
|
|
235
|
+
case cg.SubgraphCallNode(subgraph=subgraph):
|
|
236
|
+
func = None
|
|
237
|
+
func_str = scope_expressions.get(id(subgraph))
|
|
238
|
+
if func_str is None:
|
|
239
|
+
raise ValueError(
|
|
240
|
+
f"Scope expressions did not contain definition for {subgraph=}"
|
|
241
|
+
)
|
|
242
|
+
assert isinstance(func_str, str), func_str
|
|
243
|
+
case _:
|
|
244
|
+
raise TypeError(f"Unsupported {node=}")
|
|
245
|
+
|
|
246
|
+
args = node.args[1:] if isinstance(node, cg.MethodCallNode) else node.args
|
|
247
|
+
arg_reprs = _repr_args(func, args, node.kwargs, scope_expressions) # type: ignore
|
|
248
|
+
|
|
249
|
+
if len(arg_reprs) == 0:
|
|
250
|
+
return [f"{func_str}()"]
|
|
251
|
+
|
|
252
|
+
total_len = len(func_str) + sum(len(arg) for arg in arg_reprs)
|
|
253
|
+
multiline = total_len > line_limit
|
|
254
|
+
|
|
255
|
+
if len(arg_reprs) > 1 and multiline:
|
|
256
|
+
arg_reprs = [line + "," for line in arg_reprs]
|
|
257
|
+
|
|
258
|
+
if multiline:
|
|
259
|
+
return [f"{func_str}("] + indent_lines(arg_reprs) + [")"]
|
|
260
|
+
else:
|
|
261
|
+
return [f"{func_str}({', '.join(arg_reprs)})"]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _repr_operator_call(
    node: cg.FunctionCallNode,
    scope_expressions: dict[int, str | list[str]],
) -> list[str]:
    """Render a call whose function resolves to an operator template such as ``"{} + {}"``.

    Positional arguments fill the template's ``{}`` placeholders first, followed by the
    keyword-argument values in insertion order. Every operand is rendered with
    ``extra_parens=True`` so compound operands may be parenthesized.
    """
    assert isinstance(node, cg.FunctionCallNode), node

    operands = itertools.chain(node.args, node.kwargs.values())
    rendered = [
        _repr_inp(operand, scope_expressions, extra_parens=True) for operand in operands
    ]

    template = scope_expressions[id(node.func)]
    assert isinstance(template, str), template
    return [template.format(*rendered)]
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _codegen_for_node(
|
|
283
|
+
node: cg.Node,
|
|
284
|
+
scope_expressions: dict[int, str | list[str]],
|
|
285
|
+
) -> list[str]:
|
|
286
|
+
match node:
|
|
287
|
+
case cg.FunctionCallNode(func=func):
|
|
288
|
+
funcres = scope_expressions[id(func)]
|
|
289
|
+
if isinstance(funcres, list):
|
|
290
|
+
raise ValueError(
|
|
291
|
+
f"{node} resolved to {funcres} but functions should always resolve to names, not expressions"
|
|
292
|
+
)
|
|
293
|
+
elif funcres == OperatorType.NOOP:
|
|
294
|
+
return [] # no code needed
|
|
295
|
+
elif "{}" in funcres:
|
|
296
|
+
return _repr_operator_call(node, scope_expressions)
|
|
297
|
+
else:
|
|
298
|
+
return _repr_function_call(node, scope_expressions)
|
|
299
|
+
case cg.MethodCallNode() if node.method_name == "__getitem__":
|
|
300
|
+
callee_expr = _repr_inp(node.args[0], scope_expressions)
|
|
301
|
+
idx_expr = _repr_inp(node.args[1], scope_expressions)
|
|
302
|
+
return [f"{callee_expr}[{idx_expr}]"]
|
|
303
|
+
case cg.MethodCallNode():
|
|
304
|
+
return _repr_function_call(node, scope_expressions)
|
|
305
|
+
case cg.SubgraphCallNode():
|
|
306
|
+
return _repr_function_call(node, scope_expressions)
|
|
307
|
+
case cg.GetAttributeNode(args=(source,), attribute_name=attribute_name):
|
|
308
|
+
arg_expr = scope_expressions[id(source)]
|
|
309
|
+
if isinstance(arg_expr, list) and len(arg_expr) == 1:
|
|
310
|
+
arg_expr = arg_expr[0]
|
|
311
|
+
if not isinstance(arg_expr, str):
|
|
312
|
+
raise ValueError(
|
|
313
|
+
f"Attribute access {attribute_name!r} on {source!r} resolved to {arg_expr} but should be a string"
|
|
314
|
+
)
|
|
315
|
+
if " " in arg_expr:
|
|
316
|
+
raise ValueError(
|
|
317
|
+
f"f{_codegen_for_node.__name__} got would attempt to create getattr expression "
|
|
318
|
+
f"{arg_expr!r}.{attribute_name} due to space in {arg_expr=} "
|
|
319
|
+
f"for {id(node)=} {node=} {id(source)=} {source=}"
|
|
320
|
+
)
|
|
321
|
+
return [f"{arg_expr}.{attribute_name}"]
|
|
322
|
+
case cg.ConstantNode(value=value):
|
|
323
|
+
return [_repr_value(value)]
|
|
324
|
+
case _:
|
|
325
|
+
raise TypeError(f"Unsupported {node=}")
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _codegen_graph_inputs(
    graph: cg.ComputeGraph,
    node_names: dict[int, str],
    typename: str | None,
    func_name: str | None = None,
) -> list[str]:
    """Generate the ``def`` header line(s) for a graph's Python function.

    Inputs without a ``default_value`` kwarg come first (the sort is stable, so the
    relative order within each group is preserved). Each parameter is annotated with
    its ``known_value_type`` metadata when present and given ``= <default>`` when it
    has a non-None ``default_value``.

    Args:
        graph: Graph whose inputs become parameters.
        node_names: Maps ``id(input_node)`` to the parameter name.
        typename: Return annotation source, or None for no annotation.
        func_name: Function name; defaults to ``graph.name``.

    Raises:
        ValueError: if an input node has no entry in ``node_names``.
    """
    args = sorted(
        graph.inputs.values(),
        key=lambda x: x.kwargs.get("default_value", None) is not None,
    )

    func_name = func_name or graph.name
    # Shared by both header forms, so zero-argument functions keep their return annotation.
    returns = "" if typename is None else f" -> {typename}"

    if logger.isEnabledFor(logging.DEBUG):
        argnames = [node_names.get(id(node)) for node in args]
        logger.debug(f"Codegen inputs for {func_name} {argnames=}")

    if len(args) == 0:
        return [f"def {func_name}(){returns}:"]

    args_lines = []
    for node in args:
        if id(node) not in node_names:
            raise ValueError(f"Node {node} has no name in {node_names}")
        name = node_names[id(node)]

        known_value_type = node.metadata.get("known_value_type", None)
        line = (
            f"{name}: {_repr_type(known_value_type)}"
            if known_value_type is not None
            else f"{name}"
        )

        if (default := node.kwargs.get("default_value")) is not None:
            line += f" = {_repr_value(default)}"

        args_lines.append(line + ",")

    return [f"def {func_name}("] + indent_lines(args_lines) + [f"){returns}:"]
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _codegen_namedtuple_def(outputs: pytree.PyTree) -> list[str]:
    """Generate a ``NamedTuple`` class definition for a graph's top-level output type.

    One field line is emitted per non-None output node, annotated with the node's
    ``known_value_type`` metadata when present and ``Any`` otherwise. The generated
    module must have ``NamedTuple`` and ``Any`` in scope.
    """
    tupletype = outputs.toplevel_type()

    type_lines = []
    for name, node in outputs.items():
        if node is None:
            continue
        vt = node.metadata.get("known_value_type", None)
        type_lines.append(f"{name}: Any" if vt is None else f"{name}: {_repr_type(vt)}")

    if not type_lines:
        # A class statement with no body is a SyntaxError in the generated module.
        type_lines = ["pass"]

    return [f"class {tupletype.__name__}(NamedTuple):"] + indent_lines(type_lines)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _codegen_for_outputs(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
) -> tuple[str | None, list[str], list[str]]:
    """Generate the return statement, return annotation, and any output type definition.

    A graph with no outputs returns nothing; a single output is returned directly;
    multiple outputs are returned as their top-level container type. A namedtuple
    output type that is not yet in ``scope_expressions`` gets a class definition and
    is then registered there (this mutates ``scope_expressions``).

    Returns:
        ``(type_name, type_def_lines, return_lines)``; ``type_name`` is None when no
        return annotation is known.

    Raises:
        ValueError: if the output container type is not a procfunc type, a builtin,
            an already-defined type, or a namedtuple.
    """
    if len(graph.outputs) == 0:
        return None, [], []

    if len(graph.outputs) == 1:
        # iter() makes this work whether .values() yields an iterator or a view/list.
        single_output = next(iter(graph.outputs.values()))
        vt = (
            single_output.metadata.get("known_value_type", None)
            if single_output is not None
            else None
        )
        type_name = _repr_type(vt) if vt is not None else None
        return type_name, [], [f"return {_repr_inp(single_output, scope_expressions)}"]

    graph_output_type = graph.outputs.toplevel_type()
    type_name = _repr_type(graph_output_type)

    is_pf_type = hasattr(pf, graph_output_type.__name__)
    if is_pf_type:
        type_def = []
    elif graph_output_type.__module__ == "builtins":
        type_def = []
    elif id(graph_output_type) in scope_expressions:
        assert scope_expressions[id(graph_output_type)] == type_name
        logger.debug(f"Skipping redefinition of {graph_output_type}")
        type_def = []
    elif pytree.is_type_namedtuple(graph_output_type):
        type_def = _codegen_namedtuple_def(graph.outputs)
        scope_expressions[id(graph_output_type)] = type_name
    else:
        raise ValueError(f"Unhandled graph output type: {graph_output_type}")

    reprs_tree = graph.outputs.map(lambda node: _repr_inp(node, scope_expressions))
    return_lines = [
        f"return {pytree.repr_tree_to_str(reprs_tree, type_namer=_repr_type)}"
    ]

    return type_name, type_def, return_lines
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _check_graph_input_names(
    graph: cg.ComputeGraph,
    scope_names: dict[int, str],
) -> dict[int, str]:
    """Map each graph input node (by ``id``) to the identifier used for it in generated code.

    An input whose name collides with a name already in ``scope_names`` is renamed by
    appending ``"_val"`` (repeatedly, until the result collides with neither another
    input nor a scope name), and a warning is logged.

    Raises:
        ValueError: if the collected input names contain duplicates, or a final
            identifier is not a valid snake_case identifier.
    """
    input_names = {id(node): name for name, node in graph.inputs.items()}
    if len(input_names.values()) != len(set(input_names.values())):
        raise ValueError(
            f"Input names for {graph.name} had duplicate values. {input_names.values()=}"
        )

    scope_values = set(scope_names.values())
    overlap = set(input_names.values()).intersection(scope_values)
    # Every name a renamed input must avoid: other inputs and everything already in scope.
    taken = set(input_names.values()) | scope_values
    for k, v in list(input_names.items()):
        if v not in overlap:
            continue

        newname = v + "_val"
        while newname in taken:
            newname += "_val"
        taken.add(newname)
        input_names[k] = newname
        logger.warning(
            f"Renaming input {k=} of {graph.name=} from {v} to {newname} to avoid "
            f"collision, since {v} is also the name of a util function"
        )

    for orig_name, node in graph.inputs.items():
        identifier = input_names[id(node)]
        if not identifiers.is_valid_snake_identifier(identifier):
            raise ValueError(
                f"{graph.name=} had input {orig_name=} {node=} which recieved invalid identifier {identifier=}"
            )

    return input_names
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _codegen_graph_decorator(graph: cg.ComputeGraph) -> list[str]:
    """Return the decorator lines to emit above the generated function for ``graph``.

    Graphs whose metadata has a truthy ``"is_node_function"`` entry are decorated with
    ``@pf.nodes.node_function``; every other graph gets no decorator.
    """
    is_node_function = graph.metadata.get("is_node_function")
    return ["@pf.nodes.node_function"] if is_node_function else []
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def _should_fold_node(
    node: cg.Node,
    parent: cg.Node | None,
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
    fold_map: dict[int, bool],
) -> bool:
    """Decide whether ``node`` is inlined into its user's expression instead of getting a variable.

    Rules, checked in order:
      * ``astype`` / ``__getitem__`` method calls and attribute accesses are always folded;
      * a node read by an attribute access is never folded (attribute codegen rejects
        source expressions containing spaces);
      * nodes with more than one usage, or with no parent, are not folded;
      * calls resolving to an operator template (a string containing ``"{}"``) are folded
        unless their parent is itself marked as folded in ``fold_map``;
      * everything else is not folded.
    """
    if isinstance(node, cg.MethodCallNode) and node.method_name in (
        "astype",
        "__getitem__",
    ):
        return True

    if isinstance(node, cg.GetAttributeNode):
        return True

    node_usages = usages.get(id(node), [])
    if any(isinstance(u, cg.GetAttributeNode) for u in node_usages):
        return False

    if len(node_usages) > 1:
        return False

    if parent is None:
        return False

    if isinstance(node, cg.FunctionCallNode):
        func_expr = scope_expressions[id(node.func)]
        # Resolutions may also be non-string markers such as OperatorType.NOOP;
        # only strings can be operator templates.
        if isinstance(func_expr, str) and "{}" in func_expr:
            return not fold_map.get(id(parent), False)

    return False
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def _expression_fold_map(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
) -> dict[int, bool]:
    """Compute, for each node of ``graph`` keyed by ``id``, whether it is folded inline.

    Graph outputs are decided first and never overridden: ``None`` outputs, constants,
    and attribute accesses are folded, other outputs are not. Every remaining node is
    then decided by ``_should_fold_node`` in ``cg.traverse_breadth_first`` order, with
    access to the decisions made so far.
    """
    fold_map: dict[int, bool] = {
        id(output): (
            output is None
            or isinstance(output, (cg.ConstantNode, cg.GetAttributeNode))
        )
        for output in graph.outputs.values()
    }

    for parent, node in cg.traverse_breadth_first(graph, yield_parent=True):
        if id(node) not in fold_map:  # output decisions are kept as-is
            fold_map[id(node)] = _should_fold_node(
                node, parent, scope_expressions, usages, fold_map
            )

    return fold_map
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def traverse_chunks(
    graph: cg.ComputeGraph,
    pred: Callable[[cg.Node, list[cg.Node]], bool],
) -> Generator[list[cg.Node], None, None]:
    """Partition the nodes of ``graph`` into greedily grown chunks.

    Each chunk starts at the first not-yet-visited node in ``cg.traverse_breadth_first``
    order and then absorbs argument nodes depth-first for as long as
    ``pred(arg, chunk)`` accepts them. ``chunk`` holds the nodes already collected for
    the current chunk, so ``pred`` can reason about chunk membership. Every node ends up
    in exactly one chunk.

    Yields:
        Lists of nodes, one per chunk, in discovery order.
    """
    visited: set[int] = set()

    def _greedy_singleuses(node: cg.Node, chunk: list[cg.Node]):
        if id(node) in visited:
            return
        visited.add(id(node))
        # Record membership so pred() sees the chunk built so far.
        chunk.append(node)
        yield node

        for arg in itertools.chain(node.args, node.kwargs.values()):
            if not isinstance(arg, cg.Node):
                continue
            if id(arg) in visited:
                continue
            if not pred(arg, chunk):
                continue
            yield from _greedy_singleuses(arg, chunk)

    for node in cg.traverse_breadth_first(graph):
        if id(node) in visited:
            continue
        chunk: list[cg.Node] = []
        yield list(_greedy_singleuses(node, chunk))
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def _code_paragraphing_predicate(
    node: cg.Node,
    chunk: list[cg.Node],
    scope_expressions: dict[int, str | list[str]],
    usages: dict[int, list[cg.Node]],
) -> bool:
    """Decide whether ``node`` may join the code paragraph ``chunk``.

    Its shape matches the ``pred(node, chunk)`` callback of ``traverse_chunks`` once the
    extra arguments are bound (presumably via a partial; confirm at the call site).
    Only function-call nodes qualify, and only when they have exactly one usage or
    every one of their usages is already in ``chunk``.
    """
    if not isinstance(node, cg.FunctionCallNode):
        return False

    target_expr = scope_expressions[id(node.func)]
    # Rejects only resolutions containing BOTH ".math." and a "{}" placeholder.
    # NOTE(review): that combination looks unlikely to ever occur; the intent may have
    # been "reject unless it is a math function or an operator template" — confirm.
    if ".math." in target_expr and "{}" in target_expr:
        return False

    uses = usages[id(node)]
    if len(uses) == 1:
        return True

    chunk_ids = {id(member) for member in chunk}
    return all(id(use) in chunk_ids for use in uses)
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def _codegen_for_assignment(
    assign_varname: str,
    node_code: list[str] | str,
    add_line_comments: bool,
    node: Any = None,
) -> list[str]:
    """Bind generated code to a variable by prefixing it with ``assign_varname = ``.

    Args:
        assign_varname: Target variable; must be a valid snake_case identifier.
        node_code: Right-hand side, either a single string or a list of lines
            (only the first line receives the prefix).
        add_line_comments: If True and ``node`` is given, append ``# <node>`` to the
            first line.
        node: Graph node being assigned, used only for the optional comment.

    Returns:
        A new list of lines; a list passed as ``node_code`` is left unmodified.
    """
    assert isinstance(assign_varname, str)
    assert identifiers.is_valid_snake_identifier(assign_varname)

    if isinstance(node_code, list):
        lines = list(node_code)
        lines[0] = f"{assign_varname} = " + lines[0]
    else:
        lines = [f"{assign_varname} = {node_code}"]

    if add_line_comments and node is not None:
        # Keep the trailing comment on a single line even if str(node) spans several.
        lines[0] += " # " + str(node).replace("\n", " ")

    return lines
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def _expressions_scope_for_graph(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str | list[str]],
) -> tuple[dict[int, str | list[str]], dict[int, bool]]:
    """Build the ``id -> name/expression`` table used while generating code for ``graph``.

    Starts from a copy of ``scope_expressions``, adds the graph's input names, decides
    which nodes are folded inline, and assigns variable names to the remaining nodes.

    Returns:
        ``(expressions, fold_map)``.

    Raises:
        ValueError: on duplicate node names, or when a node name collides with a name
            already present in the table.
    """
    # when we want to refer to a value, what string should we insert?
    # - for most nodes: refer to a variable name
    # - for inlined expressions: emplace a expression string
    expressions: dict[int, str | list[str]] = dict(scope_expressions)
    expressions.update(_check_graph_input_names(graph, scope_expressions))

    usages = cg.usages_per_node(graph)
    fold_map = _expression_fold_map(graph, expressions, usages=usages)

    node_names = identifiers.nodenames_from_fixed_and_infill(
        graph,
        fold_map=fold_map,
        scope_expressions=expressions,
    )

    duplicates = identifiers.duplicate_names(node_names)
    if duplicates:
        raise ValueError(f"Duplicate node names: {duplicates}")

    intersection = set(expressions.values()).intersection(set(node_names.values()))
    if intersection:
        raise ValueError(f"Scope and node names had overlap: {intersection=}")

    expressions.update(node_names)
    return expressions, fold_map
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def _with_pass_if_empty(lines: list[str]) -> list[str]:
    """Append ``pass`` when ``lines`` holds only blank lines and comments (an invalid block body)."""
    has_statement = any(
        line.strip() and not line.lstrip().startswith("#") for line in lines
    )
    return lines if has_statement else lines + ["pass"]


def _codegen_for_graph(
    graph: cg.ComputeGraph,
    scope_expressions: dict[int, str],
    as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
    func_name: str | None = None,
) -> list[str]:
    """Generate Python source lines for one compute graph.

    Nodes are visited in ``cg.traverse_depth_first`` order:
      * input placeholders are skipped (they become parameters in the ``def`` header);
      * mutated-argument nodes alias the expression of the value they mutate;
      * folded nodes are stored as inline expressions rather than emitted;
      * calls that only mutate an argument are emitted as bare statements;
      * every other node is assigned to its own variable.
    A blank line is inserted before an assignment whose variable-name prefix (text
    before the first ``_``) differs from that of the previous assignment.

    Args:
        graph: Graph to generate code for.
        scope_expressions: Names already defined in the enclosing module, keyed by ``id``.
        as_maincall: Emit the body under ``if __name__ == '__main__':`` (the graph must
            have no inputs) instead of as a function definition.
        add_version_comment: Start the body with a ``# Code generated by procfunc`` comment.
        add_line_comments: Append ``# <node>`` to each assignment line.
        func_name: Function name to use instead of ``graph.name``.

    Returns:
        Source lines, unindented at the top level.
    """
    code_lines: list[str] = []

    if add_version_comment:
        code_lines.append(f"# Code generated by procfunc v{pf.__version__}")

    expressions, fold_map = _expressions_scope_for_graph(graph, scope_expressions)

    # Collect mutator call nodes so they emit as bare statements (no assignment)
    mutator_call_ids = {
        id(node.args[1])
        for node in cg.traverse_depth_first(graph)
        if isinstance(node, cg.MutatedArgumentNode)
    }

    last_varname: str = ""
    for node in cg.traverse_depth_first(graph):
        if isinstance(node, cg.InputPlaceholderNode):
            continue  # arguments are defined in _codegen_graph_inputs
        if isinstance(node, cg.MutatedArgumentNode):
            # alias to the original node, since mutation is in-place
            original_node = node.args[0]
            expressions[id(node)] = expressions[id(original_node)]
            continue

        # Rendering only reads the table, so no per-node defensive copy (which was O(n^2)).
        node_code = _codegen_for_node(node, expressions)

        # NOTE(review): a mutator call that is also folded is stored as an expression here
        # and never emitted as a statement; confirm mutator calls can never be folded.
        if fold_map[id(node)]:
            assert id(node) not in expressions, f"{node=} {expressions[id(node)]=}"
            expressions[id(node)] = node_code
            continue

        if id(node) in mutator_call_ids:
            code_lines.extend(node_code if isinstance(node_code, list) else [node_code])
            continue

        varname = expressions[id(node)]

        # Separate groups of related variables with a blank line placed *before* the
        # first assignment of a new group.
        if last_varname and last_varname.split("_")[0] != varname.split("_")[0]:
            code_lines.append("")
        last_varname = varname

        assignment = _codegen_for_assignment(varname, node_code, False)
        if add_line_comments:
            assignment[0] += " # " + str(node).replace("\n", " ")
        code_lines.extend(assignment)

    if as_maincall:
        assert len(graph.inputs) == 0, graph.inputs
        return ["if __name__ == '__main__':"] + indent_lines(
            _with_pass_if_empty(code_lines)
        )

    typename, typedef, return_lines = _codegen_for_outputs(graph, expressions)

    return (
        typedef
        + [""]
        + _codegen_graph_decorator(graph)
        + _codegen_graph_inputs(graph, expressions, typename, func_name=func_name)
        + indent_lines(_with_pass_if_empty(code_lines + return_lines))
    )
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
def _resolve_func(func: Any) -> tuple[str | None, str]:
|
|
689
|
+
"""
|
|
690
|
+
Returns:
|
|
691
|
+
tuple[str, str]: import string, function callsite string
|
|
692
|
+
"""
|
|
693
|
+
|
|
694
|
+
if isinstance(func, np.ufunc):
|
|
695
|
+
return "import numpy as np", f"np.{func.__name__}"
|
|
696
|
+
|
|
697
|
+
module = getattr(func, "__module__", None)
|
|
698
|
+
|
|
699
|
+
if module is None:
|
|
700
|
+
raise NotImplementedError(f"Unsupported function: {func}")
|
|
701
|
+
elif module == "builtins":
|
|
702
|
+
return None, func.__name__
|
|
703
|
+
elif module.startswith("procfunc."):
|
|
704
|
+
callsite = "pf." + module[len("procfunc.") :] + "." + func.__name__
|
|
705
|
+
importstring = "import procfunc as pf"
|
|
706
|
+
return importstring, callsite
|
|
707
|
+
elif module.startswith("infinigen_v2."):
|
|
708
|
+
parent, _, mod_name = module.rpartition(".")
|
|
709
|
+
importstring = f"from {parent} import {mod_name}"
|
|
710
|
+
callsite = f"{mod_name}.{func.__name__}"
|
|
711
|
+
return importstring, callsite
|
|
712
|
+
else:
|
|
713
|
+
callsite = f"{module}.{func.__name__}"
|
|
714
|
+
importstring = f"import {module}"
|
|
715
|
+
return importstring, callsite
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def default_func_resolution_map(
    toplevel_graph: cg.ComputeGraph,
    skip_funcs: set | None = None,
) -> tuple[dict[Any, str | OperatorType], list[str]]:
    """Build the default mapping from called functions to their codegen form.

    Every ``cg.FunctionCallNode`` in ``toplevel_graph`` and its nested graphs
    is visited. Functions present in ``FUNCTIONS_TO_OPERATORS`` map to their
    operator entry. Every other function maps to the callsite string produced
    by ``_resolve_func``, and the import statement that callsite needs is
    collected.

    Args:
        toplevel_graph: Root graph; nested graphs are reached through
            ``cg.traverse_nested_graphs``.
        skip_funcs: Functions to leave out of the mapping entirely.

    Returns:
        ``(func_resolution, import_lines)``. ``import_lines`` is a sorted,
        de-duplicated list, so repeated runs produce identical output.

    Raises:
        NotImplementedError: Propagated from ``_resolve_func`` for a function
            with no ``__module__``.
    """
    func_resolution = {}
    import_lines = set()

    for graph in cg.traverse_nested_graphs(toplevel_graph):
        assert isinstance(graph, cg.ComputeGraph), graph
        for node in cg.traverse_depth_first(graph):
            if not isinstance(node, cg.FunctionCallNode):
                continue

            func = node.func
            if skip_funcs is not None and func in skip_funcs:
                continue

            if func in func_resolution:
                # Already resolved at an earlier call site; resolution is a
                # pure function of ``func``, so the result would be identical.
                continue

            if func in FUNCTIONS_TO_OPERATORS:
                func_resolution[func] = FUNCTIONS_TO_OPERATORS[func]
                continue

            importstring, callsite = _resolve_func(func)
            func_resolution[func] = callsite
            if importstring is not None:
                import_lines.add(importstring)

    # Sort rather than list(set): set iteration order of str depends on
    # per-process hash randomization, which would make codegen output unstable.
    return func_resolution, sorted(import_lines)
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def _topo_sort_subgraphs(graph: cg.ComputeGraph) -> list[cg.ComputeGraph]:
    """Order ``graph`` and every subgraph it transitively calls, callees first.

    This is a DFS post-order over ``cg.SubgraphCallNode`` edges. Graphs are
    de-duplicated by object identity, so a subgraph called from several places
    appears once, at its first post-order position. ``graph`` itself is last.
    """
    seen_ids = {id(graph)}
    ordered = []
    # Each stack entry is a graph plus the not-yet-inspected part of its node
    # traversal; resuming the iterator replaces returning from recursion.
    stack = [(graph, iter(cg.traverse_depth_first(graph)))]

    while stack:
        current, pending_nodes = stack[-1]
        for node in pending_nodes:
            if not isinstance(node, cg.SubgraphCallNode):
                continue
            child = node.subgraph
            if id(child) in seen_ids:
                continue
            seen_ids.add(id(child))
            stack.append((child, iter(cg.traverse_depth_first(child))))
            break
        else:
            # All nodes of ``current`` inspected: its callees are already placed.
            stack.pop()
            ordered.append(current)

    return ordered
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def graphs_to_python_functions(
    graph: cg.ComputeGraph,
    func_resolution: dict[Any, str | OperatorType],
    toplevel_as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
) -> OrderedDict[str, list[str]]:
    """Generate Python source lines for ``graph`` and every subgraph it calls.

    Side effect: the ``name`` of every graph yielded by
    ``cg.traverse_nested_graphs(graph)`` is rewritten in place. Each trailing
    ``_<part>`` whose part is listed in
    ``identifiers.NONDESCRIPTIVE_NODE_NAME_PARTS`` is stripped.

    Args:
        graph: Top-level graph to generate code for.
        func_resolution: Maps each called function either to a callsite
            string or to an ``OperatorType``. An ``OperatorType`` is rendered
            through ``OPERATOR_TEMPLATES``.
        toplevel_as_maincall: Emit the top-level graph as a main-guard body
            rather than as a function definition.
        add_version_comment: Forwarded to ``_codegen_for_graph``.
        add_line_comments: Forwarded to ``_codegen_for_graph``.

    Returns:
        Mapping from de-duplicated function name to that function's code
        lines. Entries are ordered so each subgraph precedes every graph that
        calls it.
    """

    def _clean_graph_name(name: str) -> str:
        # Suffixes are checked in order, so several may be stripped in turn.
        for suffix in identifiers.NONDESCRIPTIVE_NODE_NAME_PARTS:
            if name.endswith("_" + suffix):
                name = name[: -(len(suffix) + 1)]
        return name

    # Presumably widened so numpy array reprs emitted into the generated code
    # are not line-wrapped. This is global numpy state, so it must be restored
    # even when code generation raises.
    np_linewidth = np.get_printoptions()["linewidth"]
    np.set_printoptions(linewidth=100000)
    try:
        targets = _topo_sort_subgraphs(graph)

        for subgraph in cg.traverse_nested_graphs(graph):
            subgraph.name = _clean_graph_name(subgraph.name)

        subgraph_names = {
            id(subgraph): subgraph.name
            for subgraph in cg.traverse_nested_graphs(graph)
        }
        subgraph_names = identifiers.dedup_names_with_suffix(
            subgraph_names, separator="_"
        )

        # Expressions visible to generated code, keyed by object id: subgraph
        # function names plus the resolved spelling of every called function.
        scope_expressions = subgraph_names.copy()
        for func, resolved in func_resolution.items():
            if isinstance(resolved, OperatorType):
                scope_expressions[id(func)] = OPERATOR_TEMPLATES[resolved]
            else:
                scope_expressions[id(func)] = resolved

        lines_for_modules = []
        for subgraph in targets:
            func_name = subgraph_names[id(subgraph)]
            result = _codegen_for_graph(
                subgraph,
                scope_expressions=scope_expressions.copy(),
                as_maincall=(subgraph is graph and toplevel_as_maincall),
                add_version_comment=add_version_comment,
                add_line_comments=add_line_comments,
                func_name=func_name,
            )
            lines_for_modules.append((func_name, result))
    finally:
        np.set_printoptions(linewidth=np_linewidth)

    return OrderedDict(lines_for_modules)
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
def _define_multiuse_return_types(
    graph: cg.ComputeGraph,
    func_resolution: dict,
) -> list[str]:
    """Emit shared definitions for namedtuple return types used by >1 subgraph.

    Every graph from ``cg.traverse_nested_graphs(graph)`` whose
    ``outputs.toplevel_type()`` is a namedtuple type is counted once, by
    identity. Types whose ``__name__`` is already an attribute of ``pf`` are
    ignored. For each type counted more than once, the definition is generated
    from the outputs of the last subgraph seen with that type.

    Side effect: ``func_resolution[rettype]`` is set to ``rettype.__name__``
    for every type emitted here.

    Returns:
        Source lines for the definitions, each followed by a blank line.
    """
    visited_ids = set()
    usage_count = defaultdict(int)
    last_graph_for_type = {}

    for subgraph in cg.traverse_nested_graphs(graph):
        rettype = subgraph.outputs.toplevel_type()

        if id(subgraph) in visited_ids or rettype is None:
            continue
        if hasattr(pf, rettype.__name__):
            continue  # already importable via ``pf``; no local definition needed
        if not pytree.is_type_namedtuple(rettype):
            continue

        visited_ids.add(id(subgraph))
        usage_count[rettype] += 1
        last_graph_for_type[rettype] = subgraph

    definition_lines = []
    for rettype, subgraph in last_graph_for_type.items():
        if usage_count[rettype] <= 1:
            continue  # single-use types are handled by their own function
        definition_lines += _codegen_namedtuple_def(subgraph.outputs)
        definition_lines.append("")
        func_resolution[rettype] = rettype.__name__

    return definition_lines
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _collect_graph_value_imports(graph: cg.ComputeGraph) -> list[str]:
    """Collect import lines needed to spell literal argument values in code.

    The positional and keyword arguments of every node in ``graph`` and its
    nested graphs are inspected:

    * an enum member or dataclass instance imports its class from its module;
    * a ``Path`` value imports ``pathlib.Path``;
    * dataclass fields and the items of lists, tuples, sets and dicts (both
      keys and values) are inspected recursively;
    * node references are skipped.

    Returns:
        Sorted, de-duplicated list of import lines.
    """
    import_lines = set()

    def _add_class_import(t: type) -> None:
        if t.__module__ != "builtins":
            import_lines.add(f"from {t.__module__} import {t.__name__}")

    def _collect_from_value(v) -> None:
        if isinstance(v, cg.Node):
            return
        if isinstance(v, enum.Enum):
            _add_class_import(type(v))
        elif isinstance(v, Path):
            import_lines.add("from pathlib import Path")
        elif dataclasses.is_dataclass(v) and not isinstance(v, type):
            # Instances only; a dataclass *class* passed as a value is skipped.
            _add_class_import(type(v))
            for f in dataclasses.fields(v):
                _collect_from_value(getattr(v, f.name))
        elif isinstance(v, (list, tuple, set, frozenset)):
            # Previously only lists were recursed, so values nested in other
            # containers were emitted without their imports.
            for item in v:
                _collect_from_value(item)
        elif isinstance(v, dict):
            for key, item in v.items():
                _collect_from_value(key)
                _collect_from_value(item)

    for subgraph in cg.traverse_nested_graphs(graph):
        for node in cg.traverse_depth_first(subgraph):
            for arg in itertools.chain(node.args, node.kwargs.values()):
                _collect_from_value(arg)

    return sorted(import_lines)
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
def to_python(
    graph: cg.ComputeGraph,
    func_resolution: dict[Any, str | OperatorType] | None = None,
    import_lines: list[str] | None = None,
    toplevel_as_maincall: bool = True,
    add_version_comment: bool = True,
    add_line_comments: bool = False,
) -> str:
    """Render ``graph`` and all of its nested subgraphs as one Python source.

    Args:
        graph: Top-level compute graph.
        func_resolution: Maps each called function to a callsite string or an
            ``OperatorType``. When None, it is built with
            ``default_func_resolution_map``. A caller-supplied mapping is
            copied and never modified.
        import_lines: Import statements the callsites in ``func_resolution``
            need. Required when ``func_resolution`` is given; ignored
            otherwise, because the defaults are computed together.
        toplevel_as_maincall: Forwarded to ``graphs_to_python_functions``.
        add_version_comment: Forwarded to ``graphs_to_python_functions``.
        add_line_comments: Forwarded to ``graphs_to_python_functions``.

    Returns:
        The generated module source. It contains fixed header imports, sorted
        extra imports, shared namedtuple definitions, then one block per graph.

    Raises:
        ValueError: If ``func_resolution`` is given without ``import_lines``.
    """
    code_lines = [
        "from typing import NamedTuple, Annotated",
        "import numpy as np",
        "import bpy",
        "from procfunc.nodes import types as t",
        "from procfunc.nodes.types import ProcNode, SocketOrVal",
    ]

    if func_resolution is None:
        func_resolution, import_lines = default_func_resolution_map(graph)
    elif import_lines is None:
        # Explicit check rather than assert, which is stripped under -O.
        raise ValueError("import_lines must be provided when func_resolution is given")
    else:
        # _define_multiuse_return_types adds entries to this mapping; copy it
        # so those entries do not leak into the caller's dict across calls.
        func_resolution = dict(func_resolution)

    all_imports = set(import_lines) | set(_collect_graph_value_imports(graph))
    code_lines.extend(sorted(all_imports))
    code_lines.append("")

    # Namedtuple return types shared by several subgraphs are defined once here.
    code_lines.extend(_define_multiuse_return_types(graph, func_resolution))

    lines_for_modules = graphs_to_python_functions(
        graph,
        func_resolution,
        add_version_comment=add_version_comment,
        add_line_comments=add_line_comments,
        toplevel_as_maincall=toplevel_as_maincall,
    )
    for module_lines in lines_for_modules.values():
        code_lines.extend(module_lines)
        code_lines.append("")

    return "\n".join(code_lines)
|