procfunc 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- procfunc/__init__.py +87 -0
- procfunc/color.py +57 -0
- procfunc/compute_graph/__init__.py +28 -0
- procfunc/compute_graph/compute_graph.py +115 -0
- procfunc/compute_graph/node.py +200 -0
- procfunc/compute_graph/operators_info.py +92 -0
- procfunc/compute_graph/proxy.py +173 -0
- procfunc/compute_graph/util.py +282 -0
- procfunc/context.py +115 -0
- procfunc/control.py +174 -0
- procfunc/nodes/__init__.py +66 -0
- procfunc/nodes/bindings_util.py +196 -0
- procfunc/nodes/bpy_node_info.py +280 -0
- procfunc/nodes/compositor.py +2242 -0
- procfunc/nodes/execute/construct_nodes.py +571 -0
- procfunc/nodes/execute/construct_special_cases.py +246 -0
- procfunc/nodes/execute/execute.py +548 -0
- procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
- procfunc/nodes/execute/util.py +247 -0
- procfunc/nodes/func.py +1417 -0
- procfunc/nodes/geo.py +4240 -0
- procfunc/nodes/manifest.json +8769 -0
- procfunc/nodes/math.py +644 -0
- procfunc/nodes/node_function.py +160 -0
- procfunc/nodes/shader.py +2359 -0
- procfunc/nodes/types.py +347 -0
- procfunc/ops/__init__.py +35 -0
- procfunc/ops/_util.py +275 -0
- procfunc/ops/addons.py +59 -0
- procfunc/ops/attr.py +426 -0
- procfunc/ops/collection.py +90 -0
- procfunc/ops/curve.py +18 -0
- procfunc/ops/file.py +126 -0
- procfunc/ops/manifest.json +39149 -0
- procfunc/ops/mesh.py +1510 -0
- procfunc/ops/modifier.py +603 -0
- procfunc/ops/object.py +258 -0
- procfunc/ops/primitives/__init__.py +31 -0
- procfunc/ops/primitives/camera.py +45 -0
- procfunc/ops/primitives/curve.py +71 -0
- procfunc/ops/primitives/light.py +114 -0
- procfunc/ops/primitives/mesh.py +358 -0
- procfunc/ops/uv.py +271 -0
- procfunc/random.py +247 -0
- procfunc/tracer/__init__.py +43 -0
- procfunc/tracer/decorator.py +121 -0
- procfunc/tracer/patch.py +494 -0
- procfunc/tracer/proxy.py +127 -0
- procfunc/tracer/trace.py +222 -0
- procfunc/transforms/__init__.py +49 -0
- procfunc/transforms/cleanup.py +214 -0
- procfunc/transforms/convert.py +20 -0
- procfunc/transforms/distribution.py +191 -0
- procfunc/transforms/extract_materials.py +116 -0
- procfunc/transforms/infer_distribution.py +326 -0
- procfunc/transforms/parameters.py +15 -0
- procfunc/transforms/util.py +35 -0
- procfunc/transpiler/__init__.py +24 -0
- procfunc/transpiler/bpy_to_computegraph.py +1348 -0
- procfunc/transpiler/codegen.py +919 -0
- procfunc/transpiler/identifiers.py +595 -0
- procfunc/transpiler/main.py +299 -0
- procfunc/types.py +380 -0
- procfunc/util/__init__.py +0 -0
- procfunc/util/bpy_info.py +145 -0
- procfunc/util/camera.py +0 -0
- procfunc/util/keyframe.py +70 -0
- procfunc/util/log.py +96 -0
- procfunc/util/manifest.py +121 -0
- procfunc/util/pytree.py +343 -0
- procfunc/util/teardown.py +37 -0
- procfunc-0.30.0.dist-info/METADATA +120 -0
- procfunc-0.30.0.dist-info/RECORD +76 -0
- procfunc-0.30.0.dist-info/WHEEL +5 -0
- procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
- procfunc-0.30.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""General-purpose Proxy wrapper for Node with all dunders."""
|
|
2
|
+
|
|
3
|
+
import operator
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Generic, TypeVar
|
|
6
|
+
|
|
7
|
+
from .node import FunctionCallNode, GetAttributeNode, MethodCallNode, Node
|
|
8
|
+
|
|
9
|
+
# Generic type parameter for ``Proxy(Generic[T])`` below.
T = TypeVar("T")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Proxy(Generic[T]):
    """General-purpose wrapper for Node that provides all dunder methods.

    Operations on a Proxy do not compute anything. Attribute access, item access
    and (once installed below) operators each record a new Node and return a new
    Proxy wrapping it. Operations that would need a concrete value (``len``,
    iteration, truthiness, direct calls) raise instead.
    """

    node: Node

    def __repr__(self):
        return f"Proxy({self.node!r})"

    def __getattr__(self, attr: str) -> "AttributeProxy":
        # __getattr__ only runs when normal attribute lookup fails. Two cases must not
        # be traced:
        # - "node" itself: reached on instances created without __init__ (e.g. copy or
        #   pickle reconstruct via cls.__new__). Tracing it would evaluate self.node,
        #   re-enter __getattr__("node") and recurse forever.
        # - dunder names: protocols probe them with getattr/hasattr (copy.deepcopy
        #   looks up "__deepcopy__", copy reconstruction checks "__setstate__"). Python
        #   never dispatches special methods through instance __getattr__, so tracing
        #   them only yields bogus nodes.
        if attr == "node" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        node = GetAttributeNode(source=self.node, attribute_name=attr)
        return AttributeProxy(node)

    def __call__(self, *args, **kwargs) -> "Proxy":
        raise NotImplementedError("Proxy.__call__ is not implemented")

    def __len__(self) -> int:
        raise ValueError(
            "Tracing does not allow __len__ since real values are not evaluated"
        )

    def __iter__(self):
        raise ValueError(
            "Proxy does not support __iter__. Use explicit indexing instead."
        )

    def __getitem__(self, idx) -> "Proxy":
        """Record ``self[idx]`` as a call to ``operator.getitem``."""
        if isinstance(idx, tuple) and any(isinstance(i, Proxy) for i in idx):
            # Multi-axis indices such as p[a, b] arrive as a tuple. Unwrap any Proxy
            # members so the recorded args hold Nodes rather than Proxy wrappers.
            idx_node = tuple(i.node if isinstance(i, Proxy) else i for i in idx)
        else:
            idx_node = idx.node if isinstance(idx, Proxy) else idx
        getitem_node = FunctionCallNode(
            func=operator.getitem,
            args=(self.node, idx_node),
            kwargs={},
        )
        return Proxy(getitem_node)

    def __bool__(self):
        raise ValueError(
            "Base Proxy does not allow __bool__ during tracing since real values are unknown"
        )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Forward dunder methods installed on Proxy. Each one records a FunctionCallNode
# calling the mapped ``operator`` function with ``self`` as the first argument.
# "__matmul__" is included so that ``a @ b`` is traced. Without it, the "matmul"
# entry in NODE_REFLECTABLE_METHODS has no forward operator and is skipped.
NODE_DUNDER_METHODS = {
    "__add__": operator.add,
    "__sub__": operator.sub,
    "__mul__": operator.mul,
    "__matmul__": operator.matmul,
    "__truediv__": operator.truediv,
    "__floordiv__": operator.floordiv,
    "__mod__": operator.mod,
    "__pow__": operator.pow,
    "__lshift__": operator.lshift,
    "__rshift__": operator.rshift,
    "__and__": operator.and_,
    "__xor__": operator.xor,
    "__or__": operator.or_,
    "__neg__": operator.neg,
    "__pos__": operator.pos,
    "__abs__": operator.abs,
    "__invert__": operator.invert,
    "__eq__": operator.eq,
    "__ne__": operator.ne,
    "__lt__": operator.lt,
    "__le__": operator.le,
    "__gt__": operator.gt,
    "__ge__": operator.ge,
}
|
|
77
|
+
|
|
78
|
+
# Binary operator names for which _add_proxy_reflection tries to install a
# reflected ``__r<op>__`` method. A trailing underscore ("and_", "or_") follows the
# ``operator`` module's spelling. Names without a matching forward operator in
# NODE_DUNDER_METHODS (e.g. "div", "getitem") are skipped at registration time.
NODE_REFLECTABLE_METHODS = (
    "add sub mul floordiv truediv div mod pow lshift rshift and_ or_ xor getitem matmul"
).split()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _add_proxy_operator(cls, name, operator_func):
|
|
98
|
+
def proxy_method(self, *args, **kwargs):
|
|
99
|
+
# Convert any Proxy args to their underlying nodes
|
|
100
|
+
node_args = tuple(arg.node if isinstance(arg, Proxy) else arg for arg in args)
|
|
101
|
+
node_kwargs = {
|
|
102
|
+
k: v.node if isinstance(v, Proxy) else v for k, v in kwargs.items()
|
|
103
|
+
}
|
|
104
|
+
node = FunctionCallNode(
|
|
105
|
+
func=operator_func,
|
|
106
|
+
args=(self.node, *node_args),
|
|
107
|
+
kwargs=node_kwargs,
|
|
108
|
+
)
|
|
109
|
+
return Proxy(node)
|
|
110
|
+
|
|
111
|
+
setattr(cls, name, proxy_method)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _add_proxy_reflection(cls, name: str):
    """Install the reflected operator ``__r<op>__`` on ``cls``.

    ``name`` is an entry of NODE_REFLECTABLE_METHODS such as ``"mul"`` or
    ``"and_"``. Python calls ``x.__rmul__(y)`` for ``y * x`` when ``y`` cannot
    handle it. The recorded call therefore uses the forward operator with the
    operands swapped. Names with no forward entry in NODE_DUNDER_METHODS are
    skipped.
    """
    # Strip the operator-module trailing underscore ("and_" -> "and") and use the
    # stripped name for BOTH the forward lookup and the installed method. Using the
    # raw name for the installed method would create "__rand___" / "__ror___",
    # which Python never calls.
    op_name = name.rstrip("_")
    operator_func = NODE_DUNDER_METHODS.get(f"__{op_name}__")
    if operator_func is None:
        return  # no matching forward op, skip

    def proxy_method(self, rhs):
        rhs_node = rhs.node if isinstance(rhs, Proxy) else rhs
        node = FunctionCallNode(
            func=operator_func,
            args=(rhs_node, self.node),
            kwargs={},
        )
        return Proxy(node)

    setattr(cls, f"__r{op_name}__", proxy_method)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _install_proxy_operators() -> None:
    """Attach the traced forward operators and their reflected counterparts to Proxy.

    This runs inside a function so the loop variables stay local. As module-level
    loops they leaked ``name`` and ``operator_func`` as public module attributes
    (also exported by star-imports).
    """
    # Add all dunder methods to Proxy
    for dunder_name, op_func in NODE_DUNDER_METHODS.items():
        _add_proxy_operator(Proxy, dunder_name, op_func)

    # Add reflected methods
    for reflectable_name in NODE_REFLECTABLE_METHODS:
        _add_proxy_reflection(Proxy, reflectable_name)


_install_proxy_operators()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# eq=False matters. With the default eq=True, @dataclass generates an __eq__ on
# this subclass that shadows the traced Proxy.__eq__ installed by the operator
# loop. ``proxy.attr == x`` would then return a plain bool (usually False) instead
# of recording a node. The hand-written __init__ below is kept because dataclass
# does not overwrite methods defined in the class body.
@dataclass(eq=False)
class AttributeProxy(Proxy):
    """Special proxy for attribute access that supports peekthrough optimization.

    Wraps a GetAttributeNode. Calling it (``proxy.method(...)``) records a single
    MethodCallNode on the source object instead of a call on the attribute node.
    """

    def __init__(self, node: Node):
        super().__init__(node)
        assert isinstance(node, GetAttributeNode), node

    def __call__(self, *args, **kwargs) -> Proxy:
        """
        Someone did ``func = proxy.xyz`` and then ``func()``, or equivalently just
        ``proxy.xyz()``. We convert that into a single node: a method call on the
        object node.

        torch.fx.symbolic_trace calls this a _peekthrough optimization_

        TODO: this means that `self` is very often dropped from the graph. need to account for this if we check for drops.
        """
        assert isinstance(self.node, GetAttributeNode), self.node

        # Convert any Proxy args to their underlying nodes
        node_args = tuple(arg.node if isinstance(arg, Proxy) else arg for arg in args)
        node_kwargs = {
            k: v.node if isinstance(v, Proxy) else v for k, v in kwargs.items()
        }
        # self.node.args[0] is the source node that we're calling the method on
        call_node = MethodCallNode(
            callee=self.node.args[0],
            method_name=self.node.attribute_name,
            args=node_args,
            kwargs=node_kwargs,
        )
        return Proxy(call_node)
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
from collections import defaultdict, deque
|
|
4
|
+
from typing import Any, Callable, Generator, Literal, TypeVar
|
|
5
|
+
|
|
6
|
+
from procfunc.util import pytree
|
|
7
|
+
|
|
8
|
+
from .compute_graph import ComputeGraph
|
|
9
|
+
from .node import Node, SubgraphCallNode
|
|
10
|
+
|
|
11
|
+
# Module logger. In this module it is referenced only by a commented-out debug line in traverse_breadth_first.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# NOTE(review): module-level TypeVar; not referenced by any definition in this module — possibly unused.
T = TypeVar("T")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LiteralConstant:
    """Wrap a value whose ``repr`` should be its plain text form.

    ``repr(LiteralConstant("np.pi"))`` is ``np.pi`` (no quotes). Presumably this
    lets code generation emit the value verbatim as source text; confirm against
    callers.
    """

    def __init__(self, value: Any):
        self.value = value

    def __repr__(self) -> str:
        # __repr__ must return a str. Returning a non-str value directly (e.g. an
        # int or float) makes repr() raise TypeError. str() keeps string values
        # unchanged.
        return str(self.value)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def traverse_breadth_first(
    graph: ComputeGraph,
    yield_parent: bool = False,
    yield_name: bool = False,
    yield_consts: bool = False,
) -> Generator[Any, None, None]:
    """
    Walk every node of ``graph`` breadth-first, starting from ``graph.outputs``.

    Args:
        graph: The compute graph to traverse
        yield_parent: If True, yield (parent, child), with output nodes having parent=None
        yield_name: If True, yield (name, child) or (name, parent, child) if yield_parent is also True
        yield_consts: If True, yield child arguments of nodes even if they are not Nodes

    Each Node is produced at most once (deduplicated by ``id``). ``None`` outputs
    are skipped unless ``yield_consts`` is set, in which case they are produced
    like any other non-Node value.
    """

    def pack(parent, name, child):
        # Order of the produced tuple is (name, parent, child), filtered by the flags.
        parts = []
        if yield_name:
            parts.append(name)
        if yield_parent:
            parts.append(parent)
        parts.append(child)
        return parts[0] if len(parts) == 1 else tuple(parts)

    seen = set()
    queue = deque(
        (None, out_name, out_node) for out_name, out_node in graph.outputs.items()
    )
    # logger.debug(f"{traverse_breadth_first.__name__} {graph.name} {len(queue)=}")

    while queue:
        parent, name, node = queue.popleft()

        if yield_consts and not isinstance(node, Node):
            yield pack(parent, name, node)
            continue

        if node is None or id(node) in seen:
            continue
        seen.add(id(node))

        yield pack(parent, name, node)

        arg_items = list(pytree.PyTree(node.args).items())
        kwarg_items = list(pytree.PyTree(node.kwargs).items())
        for key, arg in arg_items + kwarg_items:
            wanted = yield_consts or isinstance(arg, Node)
            if wanted and id(arg) not in seen:
                queue.append((node, key, arg))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _traverse_depth_first_node(
    node: Node,
    visited: set[int],
    parent: Node | None,
    name: str,
    order: Literal["preorder", "postorder"],
    yield_parent: bool,
    yield_name: bool,
    yield_consts: bool,
) -> Generator[Any, None, None]:
    """Depth-first traversal of the subtree rooted at ``node``.

    Args:
        node: Root Node of the traversal.
        visited: Shared set of ``id``s of already-visited Nodes. It is updated in
            place, so callers can share it across several roots.
        parent: Parent reported for ``node`` itself.
        name: Name/key reported for ``node`` itself.
        order: "preorder" yields a node before its children; "postorder" yields it after.
        yield_parent: If True, include the parent in each yielded item.
        yield_name: If True, include the name/key in each yielded item.
        yield_consts: If True, also yield non-Node arguments, each once per
            occurrence and in argument order.

    Each item is ``child``, ``(parent, child)``, ``(name, child)`` or
    ``(name, parent, child)`` depending on the flags.

    This uses an explicit stack instead of recursive ``yield from``, so long chains
    of nodes do not hit Python's recursion limit. The visiting order matches the
    recursive formulation: children in args-then-kwargs pytree order, and a Node
    already in ``visited`` is skipped when it is reached.
    """

    def res(parent, name, child):
        out = (child,)
        if yield_parent:
            out = (parent,) + out
        if yield_name:
            out = (name,) + out
        return out[0] if len(out) == 1 else out

    def children_of(n: Node) -> list:
        # args first, then kwargs, each flattened through pytree
        return list(pytree.PyTree(n.args).items()) + list(
            pytree.PyTree(n.kwargs).items()
        )

    assert isinstance(node, Node), node
    if id(node) in visited:
        return
    visited.add(id(node))

    if order == "preorder":
        yield res(parent, name, node)

    # Each frame: (node, its parent, its name in the parent, iterator over its
    # not-yet-processed children). Children are listed when the node is entered,
    # which is after its preorder yield, as in the recursive version.
    stack = [(node, parent, name, iter(children_of(node)))]
    while stack:
        current, current_parent, current_name, remaining = stack[-1]
        descended = False
        for key, arg in remaining:
            if not isinstance(arg, Node):
                if yield_consts:
                    yield res(current, key, arg)
                continue
            if id(arg) in visited:
                continue
            visited.add(id(arg))
            if order == "preorder":
                yield res(current, key, arg)
            stack.append((arg, current, key, iter(children_of(arg))))
            descended = True
            break  # finish the child's subtree before resuming this iterator
        if not descended:
            stack.pop()
            if order == "postorder":
                yield res(current_parent, current_name, current)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def traverse_depth_first_node(
    node: Node,
    yield_consts: bool = False,
    order: Literal["preorder", "postorder"] = "postorder",
) -> Generator[Any, None, None]:
    """Depth-first traversal of the subtree under a single root ``node``.

    Items are the bare nodes (and, when ``yield_consts`` is set, non-Node
    arguments), with no parent or name information. A fresh visited set is used
    for each call.
    """
    fresh_visited: set[int] = set()
    traversal = _traverse_depth_first_node(
        node=node,
        visited=fresh_visited,
        parent=None,
        name="",
        order=order,
        yield_parent=False,
        yield_name=False,
        yield_consts=yield_consts,
    )
    return traversal
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def traverse_depth_first(
    graph: ComputeGraph,
    yield_parent: bool = False,
    yield_name: bool = False,
    yield_consts: bool = False,
    order: Literal["preorder", "postorder"] = "postorder",
) -> Generator[Any, None, None]:
    """Depth-first traversal of every node reachable from ``graph.outputs``.

    Outputs are processed in ``graph.outputs`` order and share one visited set,
    so a node reachable from several outputs is produced once. ``None`` outputs
    are skipped. Output nodes are reported with parent ``None`` and their output
    name.
    """
    shared_visited: set[int] = set()
    for out_name, out_node in graph.outputs.items():
        if out_node is not None:
            yield from _traverse_depth_first_node(
                node=out_node,
                visited=shared_visited,
                parent=None,
                name=out_name,
                order=order,
                yield_parent=yield_parent,
                yield_name=yield_name,
                yield_consts=yield_consts,
            )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def traverse_nested_graphs(
    graph: ComputeGraph,
    yield_call_nodes: bool = False,
) -> Generator[tuple[Node | None, ComputeGraph] | ComputeGraph, None, None]:
    """Breadth-first walk over ``graph`` and every subgraph reachable through SubgraphCallNodes.

    Args:
        graph: The root compute graph.
        yield_call_nodes: If True, yield ``(call_node, subgraph)`` pairs, with
            ``call_node`` being None for the root. If False, yield only the graphs.

    Each graph object is produced once (deduplicated by ``id``), even if several
    SubgraphCallNodes reference it.
    """
    visited: set[int] = set()
    frontier: deque = deque([(None, graph)])

    while frontier:
        call_node, current = frontier.popleft()

        if id(current) in visited:
            continue
        visited.add(id(current))

        if yield_call_nodes:
            yield call_node, current
        else:
            yield current

        frontier.extend(
            (sub_call, sub_call.subgraph)
            for sub_call in traverse_depth_first(current)
            if isinstance(sub_call, SubgraphCallNode)
        )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def usages_per_node(
    graph: ComputeGraph,
) -> dict[int, list[Node]]:
    """Map ``id(node)`` to the list of nodes in ``graph`` that take it as an argument.

    Arguments nested inside containers in args/kwargs are found via pytree. A
    consumer appears once per occurrence of the argument, in depth-first
    (postorder) consumer order. Nodes that are never used as an argument are
    absent from the result.
    """
    consumers = defaultdict(list)
    for user in traverse_depth_first(graph):
        flat_args = pytree.PyTree((user.args, user.kwargs)).values()
        for value in flat_args:
            if not isinstance(value, Node):
                continue
            consumers[id(value)].append(user)
    return dict(consumers)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def graph_nodes_equal(graph1: ComputeGraph, graph2: ComputeGraph) -> bool:
    """Return True if two graphs have pairwise-equal nodes in depth-first order.

    Nodes are compared by exact type and by ``args``/``kwargs`` equality. For
    SubgraphCallNodes the called subgraphs are also compared recursively. The
    args/kwargs check now applies to them too, so two calls of the same subgraph
    with different inputs are no longer reported as equal.
    """
    nodes1 = list(traverse_depth_first(graph1))
    nodes2 = list(traverse_depth_first(graph2))
    if len(nodes1) != len(nodes2):
        return False
    for node1, node2 in zip(nodes1, nodes2):
        if type(node1) is not type(node2):
            return False
        if isinstance(node1, SubgraphCallNode) and not graph_nodes_equal(
            node1.subgraph, node2.subgraph
        ):
            return False
        if node1.args != node2.args or node1.kwargs != node2.kwargs:
            return False
    return True
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def transform_nodetree(
    root: Node,
    transform_fn: Callable[[Node], Any],
    memo: dict[int, Node] | None = None,
):
    """Apply ``transform_fn`` to every node reachable from ``root`` (not implemented).

    Always raises NotImplementedError. ``memo`` defaults to None instead of a
    shared mutable ``{}`` default.

    TODO: implement. The removed unreachable draft called
    ``traverse_breadth_first(root, parent_child=True)``, which cannot work.
    ``traverse_breadth_first`` takes a ComputeGraph and has no ``parent_child``
    argument (use ``yield_parent=True, yield_name=True``).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Not implemented")
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def transform_compute_graph(
    compute_graph: ComputeGraph,
    transform_fn: Callable[[Node], Any],
    graph_name: str | None = None,
):
    """Build a new ComputeGraph by applying ``transform_fn`` to every node (not implemented).

    Always raises NotImplementedError.

    TODO: implement once ``transform_nodetree`` works. The removed unreachable
    draft depended on it and recorded the transform in
    ``metadata["operations"]``. It also read ``compute_graph.output.spec``, which
    does not match the ``outputs`` attribute used elsewhere in this module.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Not implemented")
|
procfunc/context.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import multiprocessing
|
|
3
|
+
import os
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from dataclasses import asdict, dataclass
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class ProcfuncContext:
|
|
11
|
+
"""Global context for Procfunc configuration and system information."""
|
|
12
|
+
|
|
13
|
+
num_cpu_cores: int
|
|
14
|
+
current_trace_level: int | None # compared against to TraceLevel int values
|
|
15
|
+
|
|
16
|
+
warn_mode_avoid_normal_bump: Literal["ignore", "warn", "throw"]
|
|
17
|
+
"""
|
|
18
|
+
Set to 'throw' for infinigen best practices. Using a normal map / bump map / BSDF vector input _can_ be useful,
|
|
19
|
+
but its always better to do it via the Displacement output of the shader (so that it at least can also be done as displacement)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
warn_mode_avoid_implicit_vector: Literal["ignore", "warn", "throw"]
|
|
23
|
+
"""
|
|
24
|
+
Set to 'throw' for default infinigen usage - we force the user to specify exactly what vector to sample
|
|
25
|
+
This prevents materials from being incorrect when on a moving object, or ignoring uv coordinates
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
warn_mode_avoid_io_nodes: Literal["ignore", "warn", "throw"]
|
|
29
|
+
"""
|
|
30
|
+
Set to 'throw' for infinigen, prevents floating Value Vector Color nodes, or strange output nodes like AOV in shader
|
|
31
|
+
Intent is to force nodegroups to be more functional - all inputs and outputs go through nodegroup interface
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
warn_mode_empty_geonodes: Literal["ignore", "warn", "throw"]
|
|
35
|
+
"""
|
|
36
|
+
Controls behavior when a geometry node graph produces no mesh geometry (e.g. unconnected inputs).
|
|
37
|
+
'ignore' silently returns an empty mesh, 'warn' logs a warning, 'throw' raises an error.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __post_init__(self):
|
|
41
|
+
"""Initialize computed fields after dataclass creation."""
|
|
42
|
+
if self.num_cpu_cores <= 0:
|
|
43
|
+
self.num_cpu_cores = multiprocessing.cpu_count()
|
|
44
|
+
|
|
45
|
+
def set_strict(self):
|
|
46
|
+
"""Set all warning modes to 'throw'"""
|
|
47
|
+
self.warn_mode_avoid_normal_bump = "throw"
|
|
48
|
+
self.warn_mode_avoid_implicit_vector = "throw"
|
|
49
|
+
self.warn_mode_avoid_io_nodes = "throw"
|
|
50
|
+
self.warn_mode_empty_geonodes = "throw"
|
|
51
|
+
|
|
52
|
+
def set_warn(self):
|
|
53
|
+
"""Set all warning modes to 'warn'"""
|
|
54
|
+
self.warn_mode_avoid_normal_bump = "warn"
|
|
55
|
+
self.warn_mode_avoid_implicit_vector = "warn"
|
|
56
|
+
self.warn_mode_avoid_io_nodes = "warn"
|
|
57
|
+
self.warn_mode_empty_geonodes = "warn"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Valid values for every warn_mode_* setting.
warn_modes = ["ignore", "warn", "throw"]


def _read_warn_mode(env_var: str, default: str) -> str:
    """Read a warn mode from environment variable ``env_var``, falling back to ``default``.

    This uses an explicit check rather than ``assert`` so invalid settings are
    still rejected when Python runs with ``-O``.

    Raises:
        ValueError: if the variable is set to a value not in ``warn_modes``.
    """
    value = os.environ.get(env_var, default)
    if value not in warn_modes:
        raise ValueError(
            f"{env_var}={value!r} is not a valid warn mode; expected one of {warn_modes}"
        )
    return value


_warn_mode_avoid_normal_bump = _read_warn_mode(
    "PROCFUNC_WARN_MODE_AVOID_NORMAL_BUMP", "ignore"
)
warn_mode_avoid_implicit_vector = _read_warn_mode(
    "PROCFUNC_WARN_MODE_AVOID_IMPLICIT_VECTOR", "ignore"
)
warn_mode_avoid_io_nodes = _read_warn_mode("PROCFUNC_WARN_MODE_AVOID_IO_NODES", "ignore")
_warn_mode_empty_geonodes = _read_warn_mode("PROCFUNC_WARN_MODE_EMPTY_GEONODES", "warn")

# Global context instance, initialized from PROCFUNC_* environment variables.
# Note: the public name ``globals`` shadows the builtin of the same name within
# this module.
globals = ProcfuncContext(
    num_cpu_cores=int(os.environ.get("PROCFUNC_NUM_CPU_CORES", 0)),
    warn_mode_avoid_normal_bump=_warn_mode_avoid_normal_bump,  # type: ignore[invalid-assignment]
    warn_mode_avoid_implicit_vector=warn_mode_avoid_implicit_vector,  # type: ignore[invalid-assignment]
    warn_mode_avoid_io_nodes=warn_mode_avoid_io_nodes,  # type: ignore[invalid-assignment]
    warn_mode_empty_geonodes=_warn_mode_empty_geonodes,  # type: ignore[invalid-assignment]
    current_trace_level=None,
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@contextmanager
def override_globals(
    new_context: ProcfuncContext | None = None,
    **overrides: Any,
):
    """
    Override the context for a block of code

    Args:
        new_context: If provided, will override the entire context with this new context
        overrides: If provided, will override specific keys with these values
            (applied after ``new_context``)

    The original field values are restored when the block exits, including when
    it exits by raising an exception.
    """
    orig = copy.deepcopy(globals)
    # Restoring from ``orig`` only covers dataclass fields. Override keys that add a
    # brand-new attribute (e.g. a misspelled field name) are removed explicitly on
    # exit so they do not persist.
    added_keys = [key for key in overrides if not hasattr(globals, key)]

    try:
        if new_context is not None:
            for key, value in asdict(new_context).items():
                setattr(globals, key, value)

        for key, value in overrides.items():
            setattr(globals, key, value)

        yield
    finally:
        for key, value in asdict(orig).items():
            setattr(globals, key, value)
        for key in added_keys:
            if key in vars(globals):
                delattr(globals, key)
|