procfunc 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- procfunc/__init__.py +87 -0
- procfunc/color.py +57 -0
- procfunc/compute_graph/__init__.py +28 -0
- procfunc/compute_graph/compute_graph.py +115 -0
- procfunc/compute_graph/node.py +200 -0
- procfunc/compute_graph/operators_info.py +92 -0
- procfunc/compute_graph/proxy.py +173 -0
- procfunc/compute_graph/util.py +282 -0
- procfunc/context.py +115 -0
- procfunc/control.py +174 -0
- procfunc/nodes/__init__.py +66 -0
- procfunc/nodes/bindings_util.py +196 -0
- procfunc/nodes/bpy_node_info.py +280 -0
- procfunc/nodes/compositor.py +2242 -0
- procfunc/nodes/execute/construct_nodes.py +571 -0
- procfunc/nodes/execute/construct_special_cases.py +246 -0
- procfunc/nodes/execute/execute.py +548 -0
- procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
- procfunc/nodes/execute/util.py +247 -0
- procfunc/nodes/func.py +1417 -0
- procfunc/nodes/geo.py +4240 -0
- procfunc/nodes/manifest.json +8769 -0
- procfunc/nodes/math.py +644 -0
- procfunc/nodes/node_function.py +160 -0
- procfunc/nodes/shader.py +2359 -0
- procfunc/nodes/types.py +347 -0
- procfunc/ops/__init__.py +35 -0
- procfunc/ops/_util.py +275 -0
- procfunc/ops/addons.py +59 -0
- procfunc/ops/attr.py +426 -0
- procfunc/ops/collection.py +90 -0
- procfunc/ops/curve.py +18 -0
- procfunc/ops/file.py +126 -0
- procfunc/ops/manifest.json +39149 -0
- procfunc/ops/mesh.py +1510 -0
- procfunc/ops/modifier.py +603 -0
- procfunc/ops/object.py +258 -0
- procfunc/ops/primitives/__init__.py +31 -0
- procfunc/ops/primitives/camera.py +45 -0
- procfunc/ops/primitives/curve.py +71 -0
- procfunc/ops/primitives/light.py +114 -0
- procfunc/ops/primitives/mesh.py +358 -0
- procfunc/ops/uv.py +271 -0
- procfunc/random.py +247 -0
- procfunc/tracer/__init__.py +43 -0
- procfunc/tracer/decorator.py +121 -0
- procfunc/tracer/patch.py +494 -0
- procfunc/tracer/proxy.py +127 -0
- procfunc/tracer/trace.py +222 -0
- procfunc/transforms/__init__.py +49 -0
- procfunc/transforms/cleanup.py +214 -0
- procfunc/transforms/convert.py +20 -0
- procfunc/transforms/distribution.py +191 -0
- procfunc/transforms/extract_materials.py +116 -0
- procfunc/transforms/infer_distribution.py +326 -0
- procfunc/transforms/parameters.py +15 -0
- procfunc/transforms/util.py +35 -0
- procfunc/transpiler/__init__.py +24 -0
- procfunc/transpiler/bpy_to_computegraph.py +1348 -0
- procfunc/transpiler/codegen.py +919 -0
- procfunc/transpiler/identifiers.py +595 -0
- procfunc/transpiler/main.py +299 -0
- procfunc/types.py +380 -0
- procfunc/util/__init__.py +0 -0
- procfunc/util/bpy_info.py +145 -0
- procfunc/util/camera.py +0 -0
- procfunc/util/keyframe.py +70 -0
- procfunc/util/log.py +96 -0
- procfunc/util/manifest.py +121 -0
- procfunc/util/pytree.py +343 -0
- procfunc/util/teardown.py +37 -0
- procfunc-0.30.0.dist-info/METADATA +120 -0
- procfunc-0.30.0.dist-info/RECORD +76 -0
- procfunc-0.30.0.dist-info/WHEEL +5 -0
- procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
- procfunc-0.30.0.dist-info/top_level.txt +1 -0
procfunc/tracer/trace.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torch.fx / jax-like function-compute-graph tracing tool, but specially designed for procedural generation functions
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
This tool was heavily inspired by torch.fx.symbolic_trace https://docs.pytorch.org/docs/2.6/fx.html
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import builtins
|
|
9
|
+
import inspect
|
|
10
|
+
import logging
|
|
11
|
+
import math
|
|
12
|
+
import random
|
|
13
|
+
from types import ModuleType
|
|
14
|
+
from typing import Any, Callable
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import numpy.linalg
|
|
18
|
+
|
|
19
|
+
import procfunc as pf
|
|
20
|
+
from procfunc import compute_graph as cg
|
|
21
|
+
from procfunc.util import pytree
|
|
22
|
+
|
|
23
|
+
from .patch import (
|
|
24
|
+
Patcher,
|
|
25
|
+
PatchFunctionTarget,
|
|
26
|
+
TraceLevel,
|
|
27
|
+
)
|
|
28
|
+
from .proxy import RngProxy
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
# Registry of modules whose functions the Patcher auto-wraps while tracing.
# Each entry is a (module, allow_exec, trace_level) triple, passed to the
# Patcher as ``autopatch_wrap_modules``.
_autowrap_modules: list[tuple[ModuleType, bool, TraceLevel]] = []


def autowrap_module(
    module: ModuleType,
    allow_exec: bool = False,
    trace_level: TraceLevel = TraceLevel.PRIMITIVES,
):
    """
    Register ``module`` so its functions are auto-wrapped during tracing.

    Args:
        module: Module whose functions should be wrapped.
        allow_exec: Flag stored alongside the module and forwarded to the Patcher.
        trace_level: Trace level recorded for this module's functions.
    """
    registration = (module, allow_exec, trace_level)
    _autowrap_modules.append(registration)


# Numeric modules that are wrapped by default.
for _default_module in (math, np, numpy.linalg):
    autowrap_module(_default_module, allow_exec=True)
del _default_module
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Modules handed to the Patcher as ``autopatch_remove_modules`` while tracing.
# The global random sources are listed by default; NOTE(review): presumably so
# randomness must flow through explicit np.random.Generator inputs (see
# RngProxy handling in _map_args) — confirm.
_banned_modules: list[ModuleType] = [np.random, random]


def add_banned_module(module: ModuleType):
    """Append ``module`` to the modules passed to the Patcher as ``autopatch_remove_modules``."""
    banned = _banned_modules
    banned.append(module)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Individual callables (as opposed to whole modules) that the Patcher wraps;
# passed to it as ``patch_functions``.
_patch_function_targets: list[PatchFunctionTarget] = []


def add_wrap_target(target: PatchFunctionTarget):
    """Register one function target to be wrapped while tracing."""
    _patch_function_targets.append(target)


def _register_primitive_targets(frame: dict, names: list[str], source_name: str):
    """Register every name in ``names`` (looked up in ``frame``) as a PRIMITIVES-level, non-normalized, executable target."""
    for target_name in names:
        add_wrap_target(
            PatchFunctionTarget(
                frame=frame,
                name=target_name,
                trace_level=TraceLevel.PRIMITIVES,
                normalize=False,
                allow_exec=True,
                source_name=source_name,
            )
        )


# Python builtins wrapped by default.
WRAP_BUILTINS = ["min", "max", "abs", "round", "sum"]
_register_primitive_targets(builtins.__dict__, WRAP_BUILTINS, "builtins")

# Constructors looked up in the top-level procfunc namespace, recorded with
# source_name "mathutils".
WRAP_CONSTRUCTORS = ["Vector", "Color", "Euler"]
_register_primitive_targets(pf.__dict__, WRAP_CONSTRUCTORS, "mathutils")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Namespace dicts (``module.__dict__``) that the Patcher additionally scans for
# references to existing patch targets; passed to it as ``search_scopes``.
_search_scopes: list[dict] = []


def add_search_scope(module: ModuleType):
    """
    Register ``module`` as an intermediate namespace to scan for references to existing patch targets.

    Example: ``procfunc.nodes.to_mesh_object`` — the original ``to_mesh_object``
    is already a target, but ``add_search_scope`` must also be called on the
    ``nodes`` module so that the reference held in that module gets wrapped too.
    """
    namespace = module.__dict__
    _search_scopes.append(namespace)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _map_args(
    func: Callable,
    **inputs: Any,
) -> dict[str, cg.Proxy | Any]:
    """
    Build the keyword arguments used to call ``func`` while tracing.

    For each named parameter of ``func``:
      * a value supplied in ``inputs`` is used. An ``np.random.Generator``, bare
        or wrapped in a ``cg.ConstantNode``, is wrapped in an ``RngProxy``
        backed by that constant node;
      * otherwise the parameter's default is used as-is;
      * otherwise a ``cg.InputPlaceholderNode`` wrapped in a ``cg.Proxy`` is
        created, which becomes an input of the traced graph.

    Variadic parameters (``*args`` / ``**kwargs``) never become placeholders:
    the result is passed as ``func(**result)``, where a key named after a
    variadic parameter would raise or be misrouted. Inputs that match no named
    parameter are forwarded when ``func`` accepts ``**kwargs`` and are
    otherwise ignored with a warning.

    Returns:
        Mapping of parameter name -> value/proxy, suitable for ``func(**result)``.
    """
    params = inspect.signature(func).parameters
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def _wrap_input(val: Any) -> Any:
        # Explicit generators are wrapped so the tracer can track them.
        if isinstance(val, np.random.Generator):
            return RngProxy(cg.ConstantNode(value=val), val, dirty=False)
        if isinstance(val, cg.ConstantNode) and isinstance(
            val.value, np.random.Generator
        ):
            return RngProxy(val, val.value, dirty=False)
        return val

    res = {}
    named: set[str] = set()

    for name, param in params.items():
        if param.kind in variadic_kinds:
            continue
        named.add(name)

        if name in inputs:
            res[name] = _wrap_input(inputs[name])
        elif param.default is not param.empty:
            res[name] = param.default
        else:
            node = cg.InputPlaceholderNode(
                name=name, default_value=None, metadata={"varname": name}
            )
            res[name] = cg.Proxy(node)

    extra_names = [k for k in inputs if k not in named]
    if extra_names:
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            for k in extra_names:
                res[k] = _wrap_input(inputs[k])
        else:
            logger.warning(
                f"Ignoring inputs {extra_names} which are not parameters of {func}"
            )

    return res
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def trace(
    func: Callable,
    trace_level: TraceLevel = TraceLevel.GENERATORS,
    name: str | None = None,
    **inputs: Any,
):
    """
    Turn a python function into a graph datastructure.

    Using this datastructure is (usually) equivalent to executing the function.

    Args:
        func: The function to trace
        trace_level: Granularity of the graph. Functions at this level become leaves;
            finer functions are traced through. choice() peeks through all options when
            trace_level >= RANDOM_CONTROL, or resolves to the chosen branch when finer.
        name: Name of the resulting graph; defaults to ``func.__name__``.
        **inputs: Values for ``func``'s parameters. Parameters that are neither
            supplied here nor have a default become graph inputs (placeholders).

    Returns:
        A ``cg.ComputeGraph``. Its inputs are the placeholder parameters. Its
        outputs mirror the structure of ``func``'s return value, with non-proxy
        leaves stored as ``cg.ConstantNode``.

    Raises:
        RuntimeError: if another trace is already in progress (nested tracing is
            not supported).
    """

    logger.debug(f"Tracing {func} {id(func)=} with {inputs=} {trace_level=}")

    if name is None:
        assert hasattr(func, "__name__")
        assert isinstance(func.__name__, str)
        name = func.__name__

    if pf.context.globals.current_trace_level is not None:
        # TODO we can lift this restriction fairly(?) easily by having a global patcher & saving/restoring this state
        # rather than setting to false at end of function. see fx.trace for an example
        raise RuntimeError(
            f"Can't trace {name}, tracing is already in progress for another function. "
            "Nested tracing is not yet supported - contact the developers to request"
        )

    proxy_args = _map_args(func, **inputs)

    # `func` is rebound to the patched version below; the recorded operation
    # should refer to the function the caller passed in.
    original_func = func

    pf.context.globals.current_trace_level = trace_level.value

    patcher = None
    try:
        # Built inside the try so that a failure here still resets the global
        # trace level; otherwise every later trace() would report a nested trace.
        patcher = Patcher(
            trace_level=trace_level,
            autopatch_wrap_modules=_autowrap_modules,
            autopatch_remove_modules=_banned_modules,
            search_scopes=_search_scopes,
            patch_functions=_patch_function_targets,
        )
        func = patcher.apply_preexecute_patches(func, trace_level)
        logger.debug(f"Executing {func.__name__} {id(func)=}")
        func_result = func(**proxy_args)
    finally:
        pf.context.globals.current_trace_level = None
        if patcher is not None:
            patcher.unpatch_all()

    metadata = {
        "operations": [
            (trace, {"func": original_func, "trace_level": trace_level}),
        ],
    }

    def extract_node(v):
        # Proxies carry the node they recorded; anything else is a literal result.
        if isinstance(v, cg.Proxy):
            return v.node
        return cg.ConstantNode(value=v)

    outputs = pytree.PyTree(func_result)
    outputs = outputs.map(extract_node)

    # Only parameters that became placeholders are graph inputs; supplied
    # values and defaults are baked into the graph.
    input_nodes = {
        k: v.node
        for k, v in proxy_args.items()
        if isinstance(v, cg.Proxy) and isinstance(v.node, cg.InputPlaceholderNode)
    }

    compgraph = cg.ComputeGraph(
        inputs=pytree.PyTree(input_nodes),
        outputs=outputs,
        name=name,
        metadata=metadata,
    )

    return compgraph
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .cleanup import (
|
|
2
|
+
coerce_shaders_to_materialresult,
|
|
3
|
+
eliminate_duplicate_result_types,
|
|
4
|
+
eliminate_duplicate_subgraphs,
|
|
5
|
+
extract_shader_vectors_as_inputs,
|
|
6
|
+
fill_graph_defaults_with_call_node,
|
|
7
|
+
remove_v1_name_from_graph,
|
|
8
|
+
replace_ids,
|
|
9
|
+
)
|
|
10
|
+
from .convert import (
|
|
11
|
+
colors_to_hsv_definition,
|
|
12
|
+
)
|
|
13
|
+
from .distribution import (
|
|
14
|
+
distribution_to_mode,
|
|
15
|
+
outlier_distribution,
|
|
16
|
+
)
|
|
17
|
+
from .extract_materials import (
|
|
18
|
+
extract_materials_from_graph,
|
|
19
|
+
extract_materials_from_graphs,
|
|
20
|
+
)
|
|
21
|
+
from .infer_distribution import (
|
|
22
|
+
infer_distribution_hypercube,
|
|
23
|
+
infer_nodegroup_distributions,
|
|
24
|
+
)
|
|
25
|
+
from .parameters import (
|
|
26
|
+
extract_parameter_distributions,
|
|
27
|
+
)
|
|
28
|
+
from .util import (
|
|
29
|
+
map_graph_list,
|
|
30
|
+
map_subgraphs,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Public API of procfunc.transforms, grouped by defining submodule.
__all__ = [
    # .cleanup
    "coerce_shaders_to_materialresult",
    "eliminate_duplicate_result_types",
    "eliminate_duplicate_subgraphs",
    "extract_shader_vectors_as_inputs",
    "fill_graph_defaults_with_call_node",
    "remove_v1_name_from_graph",
    "replace_ids",
    # .convert
    "colors_to_hsv_definition",
    # .distribution
    "distribution_to_mode",
    "outlier_distribution",
    # .infer_distribution
    "infer_distribution_hypercube",
    "infer_nodegroup_distributions",
    # .parameters
    "extract_parameter_distributions",
    # .util
    "map_graph_list",
    "map_subgraphs",
]
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from procfunc import compute_graph as cg
|
|
6
|
+
from procfunc import types as t
|
|
7
|
+
from procfunc.nodes import types as nt
|
|
8
|
+
from procfunc.nodes.shader import coord, geometry
|
|
9
|
+
from procfunc.util import pytree
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def remove_v1_name_from_graph(
    _call_node: cg.Node, graph: cg.ComputeGraph
) -> cg.ComputeGraph:
    """
    Strip the legacy ``nodegroup_`` and ``shader_`` prefixes from ``graph.name``, in place.

    The prefixes are removed in that order, so ``"nodegroup_shader_x"`` becomes
    ``"x"``. Only a leading prefix is removed; the same text appearing later in
    the name is preserved (``"nodegroup_a_nodegroup_b"`` -> ``"a_nodegroup_b"``).

    Args:
        _call_node: Unused. Accepted so the signature matches the
            ``(call_node, graph)`` form of ``fill_graph_defaults_with_call_node``
            (presumably for use with ``map_subgraphs`` — TODO confirm).
        graph: Graph to rename; mutated.

    Returns:
        The same ``graph``.
    """
    graph.name = graph.name.removeprefix("nodegroup_").removeprefix("shader_")
    return graph
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def eliminate_duplicate_subgraphs(
    graphs: list[cg.ComputeGraph],
) -> list[cg.ComputeGraph]:
    """
    Merge structurally-equal nested subgraphs across ``graphs``, in place.

    Subgraphs are compared with ``cg.graph_nodes_equal``. Each top-level
    graph's nested graphs are visited in reverse traversal order. The first
    subgraph seen with a given structure becomes canonical and takes the
    shortest name among its duplicates. Every ``cg.SubgraphCallNode`` that
    referenced a duplicate is then redirected to the canonical subgraph.

    Args:
        graphs: Top-level graphs to deduplicate; mutated.

    Returns:
        The same ``graphs`` list.
    """
    unique: list[cg.ComputeGraph] = []
    removed: list[cg.ComputeGraph] = []
    # maps id(duplicate_subgraph) -> canonical subgraph to replace it with.
    # Every visited subgraph stays referenced from `unique` or `removed`, so these
    # ids cannot be recycled while this function runs.
    replacements: dict[int, cg.ComputeGraph] = {}
    # ids of subgraph objects already classified. The same object can be reached
    # more than once (presumably when several call nodes share it); without this
    # it would be matched against itself and recorded as its own duplicate.
    visited: set[int] = set()

    for topgraph in graphs:
        subgraphs = reversed(
            list(cg.traverse_nested_graphs(topgraph, yield_call_nodes=True))
        )
        for _call_node, subgraph in subgraphs:
            if id(subgraph) in visited:
                continue
            visited.add(id(subgraph))

            canonical = next(
                (g for g in unique if cg.graph_nodes_equal(subgraph, g)), None
            )
            if canonical is None:
                unique.append(subgraph)
                continue

            if len(subgraph.name) < len(canonical.name):
                canonical.name = subgraph.name
            replacements[id(subgraph)] = canonical
            removed.append(subgraph)

    # second pass: update ALL call nodes (in all nested subgraphs) that reference a replaced subgraph
    for topgraph in graphs:
        for subgraph in cg.traverse_nested_graphs(topgraph):
            for node in cg.traverse_depth_first(subgraph):
                if (
                    isinstance(node, cg.SubgraphCallNode)
                    and id(node.subgraph) in replacements
                ):
                    node.subgraph = replacements[id(node.subgraph)]

    logger.debug(f"Eliminated duplicated subgraphs {[g.name for g in removed]}")

    return graphs
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def eliminate_duplicate_result_types(
    graphs: list[cg.ComputeGraph],
    uses_threshold: int = 1,
) -> list[cg.ComputeGraph]:
    """
    Make subgraphs whose namedtuple outputs have identical field names share a single result type.

    All nested subgraphs of ``graphs`` with a namedtuple top-level output type
    are grouped by field names (the first type seen with a given field list
    represents its group). Groups with more than ``uses_threshold`` members get
    every member's output container set to the first member's type.

    Args:
        graphs: Top-level graphs to process; their subgraphs are mutated.
        uses_threshold: Groups with at most this many members are left alone.

    Returns:
        The same ``graphs`` list.
    """
    groups: dict[type, list[cg.ComputeGraph]] = defaultdict(list)

    for graph in graphs:
        for subgraph in cg.traverse_nested_graphs(graph):
            rtype = subgraph.outputs.toplevel_type()
            if rtype is None or not pytree.is_type_namedtuple(rtype):
                continue

            fields = tuple(rtype._fields)
            representative = next(
                (known for known in groups if tuple(known._fields) == fields),
                rtype,
            )
            groups[representative].append(subgraph)

    for members in groups.values():
        if len(members) <= uses_threshold:
            continue
        shared_type = members[0].outputs.toplevel_type()
        for member in members[1:]:
            member.outputs.spec.container = shared_type

    return graphs
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def fill_graph_defaults_with_call_node(
    call_node: cg.SubgraphCallNode,
    graph: cg.ComputeGraph,
) -> cg.ComputeGraph:
    """
    Copy literal keyword arguments of ``call_node`` into ``graph``'s input defaults, in place.

    For every graph input whose name appears in ``call_node.kwargs`` with a value
    that is neither ``None`` nor a ``cg.Node``, that value becomes the input's
    ``default_value``. Nothing is changed when ``call_node`` is ``None``, or when
    any input already has a nonzero float default.

    Returns:
        The same ``graph``.
    """
    if call_node is None:
        return graph

    has_nonzero_float_default = any(
        isinstance(inp.default_value, float) and inp.default_value != 0.0
        for inp in graph.inputs.values()
    )
    if has_nonzero_float_default:
        logger.debug(
            f"Skipping {graph.name} because it has nondefault existing default args"
        )
        return graph

    for input_name, placeholder in graph.inputs.items():
        value = call_node.kwargs.get(input_name, None)
        # Node-valued kwargs are computed at call time and cannot be a default.
        if value is None or isinstance(value, cg.Node):
            continue
        placeholder.kwargs["default_value"] = value

    return graph
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def coerce_shaders_to_materialresult(
    _call_node: cg.Node, subgraph: cg.ComputeGraph
) -> cg.ComputeGraph:
    """
    Convert a shader subgraph's outputs into a single ``t.Material`` result, in place.

    The surface is taken from the ``"surface"`` output, falling back to ``"bsdf"``.
    ``"displacement"`` and ``"volume"`` are carried over when present. The
    subgraph is left unchanged when:
      * its output is already a ``t.Material``;
      * it has no surface/bsdf output;
      * it has any other non-None output, which the conversion would drop.

    Args:
        _call_node: Unused; accepted for the ``(call_node, graph)`` signature.
        subgraph: Graph to convert; mutated.

    Returns:
        The same ``subgraph``.
    """
    if subgraph.outputs.toplevel_type() is t.Material:
        return subgraph

    outputs = subgraph.outputs.dict()

    # Explicit None checks rather than `or`, so a falsy-but-valid value is not skipped.
    surface_key = "surface" if outputs.get("surface") is not None else "bsdf"
    surface = outputs.get(surface_key)
    if surface is None:
        return subgraph

    shader_outputs = {
        "surface": surface,
        "displacement": outputs.get("displacement"),
        "volume": outputs.get("volume"),
    }

    # Compare key sets rather than counts: e.g. {"bsdf", "extra"} or
    # {"surface", "bsdf"} are no larger than shader_outputs but would lose an output.
    extra_keys = {k for k, v in outputs.items() if v is not None} - {
        surface_key,
        "displacement",
        "volume",
    }
    if extra_keys:
        logger.warning(
            f"{coerce_shaders_to_materialresult.__name__} skipping due to extra outputs: {outputs.keys()}"
        )
        return subgraph

    logger.debug(
        f"{coerce_shaders_to_materialresult.__name__} converted {subgraph.name} output"
    )
    subgraph.outputs = pytree.PyTree(t.Material(**shader_outputs))

    return subgraph
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def replace_ids(
    graph: cg.ComputeGraph,
    ids: set[int],
    val: Any,
):
    """
    Replace every child whose ``id()`` is in ``ids`` with ``val``, in place.

    The graph is walked with ``cg.traverse_depth_first`` (constants included).
    Each matching child is swapped out on its parent. Positional args are
    rebuilt as a new tuple; keyword args are overwritten by name.

    Args:
        graph: The graph to modify.
        ids: ``id()`` values of the children (nodes or constant values) to replace.
        val: The replacement, e.g. a new ``cg.InputPlaceholderNode`` to turn
            hardcoded arguments into graph inputs.

    Returns:
        The same ``graph``.

    Note:
        CPython shares one object for small ints and interned strings, so an
        ``id()`` taken from such a constant also matches its other occurrences.
        NOTE(review): only children reached through a parent are replaced.
        Presumably a node referenced directly by the graph outputs is not —
        confirm.
    """

    assert isinstance(graph, cg.ComputeGraph)

    for name, parent, child in cg.traverse_depth_first(
        graph, yield_consts=True, yield_name=True, yield_parent=True
    ):
        if id(child) not in ids:
            continue
        if isinstance(name, int):
            # Positional slot: args are stored as a tuple, so rebuild it.
            args = list(parent.args)
            args[name] = val
            parent.args = tuple(args)
        else:
            parent.kwargs[name] = val

    return graph
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def extract_as_input(
    graph: cg.ComputeGraph,
    nodes: set[int],
    name: str,
    arg_type: type,
):
    """
    Add a new graph input ``name`` and route every child whose ``id()`` is in ``nodes`` to it.

    The new ``cg.InputPlaceholderNode`` has no default and carries ``arg_type`` as
    its ``known_value_type`` metadata. The graph's inputs must be a dict-shaped
    PyTree. An existing input with the same ``name`` is overwritten.

    Returns:
        The (mutated) ``graph``, as returned by ``replace_ids``.
    """
    placeholder = cg.InputPlaceholderNode(
        default_value=None,
        metadata={"known_value_type": arg_type, "varname": name},
    )

    graph_inputs = graph.inputs.obj()
    assert isinstance(graph_inputs, dict), graph_inputs
    graph_inputs[name] = placeholder
    graph.inputs = pytree.PyTree(graph_inputs)

    return replace_ids(graph, nodes, placeholder)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def extract_shader_vectors_as_inputs(
    graph: cg.ComputeGraph,
    extract_funcs: list[Callable[..., Any]] | None = None,
):
    """
    Replace calls to vector-producing shader functions with a new ``"vector"`` graph input.

    Matches ``cg.FunctionCallNode`` calls to any function in ``extract_funcs``
    (default: ``coord`` and ``geometry``), plus ``cg.GetAttributeNode`` accesses
    on such calls. All matches are rerouted to one new input named ``"vector"``,
    typed ``nt.ProcNode[t.Vector]``. Graphs with no match are returned untouched.

    Returns:
        The same ``graph`` (mutated when matches were found).
    """
    targets = [coord, geometry] if extract_funcs is None else extract_funcs

    def _calls_target(candidate) -> bool:
        return isinstance(candidate, cg.FunctionCallNode) and candidate.func in targets

    matched_ids: set[int] = set()
    for node in cg.traverse_depth_first(graph):
        if _calls_target(node) or (
            isinstance(node, cg.GetAttributeNode) and _calls_target(node.args[0])
        ):
            matched_ids.add(id(node))

    if not matched_ids:
        return graph

    extract_as_input(graph, matched_ids, "vector", nt.ProcNode[t.Vector])
    return graph
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import procfunc as pf
|
|
2
|
+
from procfunc import compute_graph as cg
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def colors_to_hsv_definition(graph: cg.ComputeGraph) -> cg.ComputeGraph:
    """
    Rewrite literal ``pf.Color`` arguments as ``pf.color.hsv_to_rgba(hsv=...)`` call nodes, in place.

    Every node reached by ``cg.traverse_depth_first`` has each ``pf.Color`` found
    in its positional or keyword arguments replaced by a ``cg.FunctionCallNode``.
    That node calls ``pf.color.hsv_to_rgba`` with the colour's HSV components
    rounded to 4 decimals.

    Returns:
        The same ``graph``.
    """

    def _as_hsv_call(color) -> cg.FunctionCallNode:
        hsv = tuple(round(x, 4) for x in color.hsv)
        return cg.FunctionCallNode(
            pf.color.hsv_to_rgba, args=(), kwargs={"hsv": hsv}
        )

    for node in cg.traverse_depth_first(graph):
        if any(isinstance(arg, pf.Color) for arg in node.args):
            new_args = [
                _as_hsv_call(arg) if isinstance(arg, pf.Color) else arg
                for arg in node.args
            ]
            # Node args are tuples in this codebase (e.g. FunctionCallNode(args=())),
            # which reject item assignment, so rebuild the container instead.
            if isinstance(node.args, tuple):
                node.args = tuple(new_args)
            else:
                node.args[:] = new_args

        for key, arg in list(node.kwargs.items()):
            if isinstance(arg, pf.Color):
                node.kwargs[key] = _as_hsv_call(arg)

    return graph
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
import procfunc as pf
|
|
8
|
+
from procfunc import compute_graph as cg
|
|
9
|
+
from procfunc.compute_graph import transform_compute_graph
|
|
10
|
+
from procfunc.random import random_distrib_funcs
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NumpyRandomDistrib(enum.Enum):
    """
    Sampling method names of ``numpy.random.Generator``.

    Each member's value is the method name, e.g. ``rng.uniform(...)`` maps to
    ``UNIFORM``. The list is based on
    https://numpy.org/doc/2.2/reference/random/generator.html#distributions

    NOTE(review): "gaussian", "standard_logistic", "standard_laplace" and
    "standard_pareto" do not appear to be ``Generator`` methods. "shuffle" and
    "integers" are not in numpy's distributions list. Confirm these are intentional.
    """

    BETA = "beta"
    BINOMIAL = "binomial"
    CHISQUARE = "chisquare"
    EXPONENTIAL = "exponential"
    F = "f"
    GAMMA = "gamma"
    GAUSSIAN = "gaussian"
    GEOMETRIC = "geometric"
    GUMBEL = "gumbel"
    HYPERGEOMETRIC = "hypergeometric"
    LAPLACE = "laplace"
    LOGISTIC = "logistic"
    LOGNORMAL = "lognormal"
    LOGSERIES = "logseries"
    MULTINOMIAL = "multinomial"
    MULTIVARIATE_NORMAL = "multivariate_normal"
    NEGATIVE_BINOMIAL = "negative_binomial"
    NORMAL = "normal"
    NONCENTRAL_CHISQUARE = "noncentral_chisquare"
    NONCENTRAL_F = "noncentral_f"
    PARETO = "pareto"
    POISSON = "poisson"
    POWER = "power"
    RAYLEIGH = "rayleigh"
    SHUFFLE = "shuffle"
    STANDARD_CAUCHY = "standard_cauchy"
    STANDARD_EXPONENTIAL = "standard_exponential"
    STANDARD_GAMMA = "standard_gamma"
    STANDARD_NORMAL = "standard_normal"
    STANDARD_LOGISTIC = "standard_logistic"
    STANDARD_LAPLACE = "standard_laplace"
    STANDARD_PARETO = "standard_pareto"
    STANDARD_T = "standard_t"
    TRIANGULAR = "triangular"
    UNIFORM = "uniform"
    VONMISES = "vonmises"
    WALD = "wald"
    WEIBULL = "weibull"
    ZIPF = "zipf"
    INTEGER = "integers"


# Reverse lookup used by as_distribution: numpy method name -> enum member.
# All values are distinct, so iterating the enum visits every member exactly once.
FUNCNAME_TO_DISTRIB = {member.value: member for member in NumpyRandomDistrib}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def as_distribution(
|
|
63
|
+
node: cg.Node,
|
|
64
|
+
) -> NumpyRandomDistrib | Callable[[np.random.Generator, ...], Any] | None:
|
|
65
|
+
match node:
|
|
66
|
+
case cg.FunctionCallNode(func=x) if x in random_distrib_funcs:
|
|
67
|
+
return x
|
|
68
|
+
case cg.MethodCallNode(method_name=method, args=(arg_0,)) if (
|
|
69
|
+
method in FUNCNAME_TO_DISTRIB
|
|
70
|
+
and isinstance(arg_0, cg.Node)
|
|
71
|
+
and arg_0.metadata.get("known_value_type") is np.random.Generator
|
|
72
|
+
):
|
|
73
|
+
return FUNCNAME_TO_DISTRIB[method]
|
|
74
|
+
case _:
|
|
75
|
+
return None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def distribution_to_mode(
|
|
79
|
+
compute_graph: cg.ComputeGraph,
|
|
80
|
+
graph_name: str | None = None,
|
|
81
|
+
) -> cg.ComputeGraph:
|
|
82
|
+
"""
|
|
83
|
+
Transform a generator to use the mode of the distribution instead of random samples
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
def map_to_mode(node: cg.Node) -> cg.Node | float:
|
|
87
|
+
match as_distribution(node):
|
|
88
|
+
case NumpyRandomDistrib.UNIFORM:
|
|
89
|
+
assert len(node.args) == 1 or len(node.kwargs) == 0, (
|
|
90
|
+
"implementation may be errored for mix of args and kwargs"
|
|
91
|
+
)
|
|
92
|
+
low = node.kwargs.get("low", node.args[1])
|
|
93
|
+
high = node.kwargs.get("high", node.args[2])
|
|
94
|
+
if isinstance(low, cg.Node) or isinstance(high, cg.Node):
|
|
95
|
+
logger.warning(
|
|
96
|
+
f"Uniform mode not implemented for {node=} with non-constant {low=} {high=}"
|
|
97
|
+
)
|
|
98
|
+
return node
|
|
99
|
+
return (low + high) / 2
|
|
100
|
+
case NumpyRandomDistrib.NORMAL:
|
|
101
|
+
assert len(node.args) == 1 or len(node.kwargs) == 0, (
|
|
102
|
+
"implementation may be errored for mix of args and kwargs"
|
|
103
|
+
)
|
|
104
|
+
mean = node.kwargs.get("mean", node.args[1])
|
|
105
|
+
_std = node.kwargs.get("std", node.args[2])
|
|
106
|
+
if isinstance(mean, cg.Node):
|
|
107
|
+
logger.warning(
|
|
108
|
+
f"Normal mode not implemented for {node=} with non-constant {mean=}"
|
|
109
|
+
)
|
|
110
|
+
return node
|
|
111
|
+
return mean
|
|
112
|
+
case _:
|
|
113
|
+
return node
|
|
114
|
+
|
|
115
|
+
return transform_compute_graph(
|
|
116
|
+
compute_graph,
|
|
117
|
+
map_to_mode,
|
|
118
|
+
graph_name=graph_name or compute_graph.name + "_mode",
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def map_to_outlier(
|
|
123
|
+
node: cg.Node, pct: float = 0.05, normal_clip_std: float = 3.0
|
|
124
|
+
) -> cg.Node:
|
|
125
|
+
varname = node.metadata.get("varname", None)
|
|
126
|
+
varname = (varname + "_outlier") if varname else None
|
|
127
|
+
|
|
128
|
+
match as_distribution(node):
|
|
129
|
+
case NumpyRandomDistrib.UNIFORM:
|
|
130
|
+
assert len(node.args) == 1 or len(node.kwargs) == 0, (
|
|
131
|
+
"cant handle arg/kwarg mix"
|
|
132
|
+
)
|
|
133
|
+
low = node.kwargs.get("low", node.args[1])
|
|
134
|
+
high = node.kwargs.get("high", node.args[2])
|
|
135
|
+
if isinstance(low, cg.Node) or isinstance(high, cg.Node):
|
|
136
|
+
logger.warning(
|
|
137
|
+
f"outlier not implemented for {node=} with {min=} {max=}"
|
|
138
|
+
)
|
|
139
|
+
return node
|
|
140
|
+
rng = node.args[0]
|
|
141
|
+
assert isinstance(rng, cg.Node), f"got {node.args[0]=}"
|
|
142
|
+
return cg.FunctionCallNode(
|
|
143
|
+
func=pf.random.uniform_tails,
|
|
144
|
+
args=(rng,),
|
|
145
|
+
kwargs=dict(low=low, high=high, tail_pct=pct),
|
|
146
|
+
varname=varname,
|
|
147
|
+
)
|
|
148
|
+
case NumpyRandomDistrib.NORMAL:
|
|
149
|
+
assert len(node.args) == 1 or len(node.kwargs) == 0, (
|
|
150
|
+
"cant handle arg/kwarg mix"
|
|
151
|
+
)
|
|
152
|
+
mean = node.kwargs.get("mean", node.args[1])
|
|
153
|
+
std = node.kwargs.get("std", node.args[2])
|
|
154
|
+
if isinstance(mean, cg.Node):
|
|
155
|
+
logger.warning(f"outlier not implemented for {node=} with {mean=}")
|
|
156
|
+
return node
|
|
157
|
+
rng = node.args[0]
|
|
158
|
+
return cg.FunctionCallNode(
|
|
159
|
+
func=pf.random.uniform_tails,
|
|
160
|
+
args=(rng,),
|
|
161
|
+
kwargs=dict(
|
|
162
|
+
tail_pct=pct,
|
|
163
|
+
low=mean - normal_clip_std * std,
|
|
164
|
+
high=mean + normal_clip_std * std,
|
|
165
|
+
),
|
|
166
|
+
varname=varname,
|
|
167
|
+
)
|
|
168
|
+
case func if func in pf.random.random_distrib_funcs:
|
|
169
|
+
logger.warning(
|
|
170
|
+
f"{outlier_distribution.__name__} not implemented for {func}"
|
|
171
|
+
)
|
|
172
|
+
return node
|
|
173
|
+
case _:
|
|
174
|
+
return node
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def outlier_distribution(
    compute_graph: cg.ComputeGraph,
    pct: float = 0.05,
    graph_name: str | None = None,
    normal_clip_std: float = 3.0,
) -> cg.ComputeGraph:
    """
    Transform a generator to generate outliers with a given probability

    Applies ``map_to_outlier`` (with ``pct`` and ``normal_clip_std``) to every
    node via ``transform_compute_graph``.

    Args:
        compute_graph: Graph to transform.
        pct: Tail probability forwarded to ``map_to_outlier``.
        graph_name: Name of the new graph; defaults to ``compute_graph.name + "_outlier"``.
        normal_clip_std: Half-width, in standard deviations, used for normal draws.

    Returns:
        The transformed graph produced by ``transform_compute_graph``.
    """

    def _to_outlier(node):
        return map_to_outlier(node, pct, normal_clip_std)

    result_name = graph_name or compute_graph.name + "_outlier"
    return transform_compute_graph(compute_graph, _to_outlier, graph_name=result_name)
|