decider 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- decider/__init__.py +7 -0
- decider/_ext.py +106 -0
- decider/cli/__init__.py +15 -0
- decider/cli/_graph.py +310 -0
- decider/cli/_visualise_app.py +236 -0
- decider/cli/serve.py +92 -0
- decider/cli/template.py +86 -0
- decider/cli/visualise.py +92 -0
- decider/config/__init__.py +22 -0
- decider/config/_ext.py +17 -0
- decider/config/_ext.pyi +16 -0
- decider/config/base.py +113 -0
- decider/config/core.py +156 -0
- decider/config/file.py +189 -0
- decider/config/versioned.py +57 -0
- decider/exceptions.py +113 -0
- decider/executor.py +146 -0
- decider/graphutil.py +40 -0
- decider/initialization.py +52 -0
- decider/magics/__init__.py +162 -0
- decider/modules/__init__.py +22 -0
- decider/modules/_ext.py +15 -0
- decider/modules/_ext.pyi +16 -0
- decider/modules/core.py +92 -0
- decider/modules/credit/__init__.py +24 -0
- decider/modules/credit/decision_table/__init__.py +27 -0
- decider/modules/credit/decision_table/config.py +290 -0
- decider/modules/credit/decision_table/impl.py +66 -0
- decider/modules/credit/decision_table/module.py +71 -0
- decider/modules/credit/scorecard/__init__.py +42 -0
- decider/modules/credit/scorecard/impl.py +312 -0
- decider/modules/credit/scorecard/module.py +359 -0
- decider/modules/expression.py +245 -0
- decider/modules/functional.py +202 -0
- decider/modules/primitives/__init__.py +5 -0
- decider/modules/primitives/join.py +90 -0
- decider/modules/primitives/sequential.py +105 -0
- decider/modules/primitives/union.py +22 -0
- decider/modules/rules/__init__.py +162 -0
- decider/modules/rules/common/__init__.py +0 -0
- decider/modules/rules/common/feature.py +128 -0
- decider/modules/rules/common/nodes/__init__.py +82 -0
- decider/modules/rules/common/nodes/cases.py +208 -0
- decider/modules/rules/common/nodes/composite.py +40 -0
- decider/modules/rules/common/nodes/conditions.py +301 -0
- decider/modules/rules/common/nodes/operators.py +323 -0
- decider/modules/rules/common/nodes/unary.py +25 -0
- decider/modules/rules/common/nodetypes.py +128 -0
- decider/modules/rules/common/parameters.py +143 -0
- decider/modules/rules/common/shared.py +184 -0
- decider/modules/rules/flat_rules/__init__.py +0 -0
- decider/modules/rules/flat_rules/impl.py +542 -0
- decider/modules/rules/flat_rules/module.py +287 -0
- decider/modules/rules/flat_rules/nodes.py +582 -0
- decider/modules/rules/modules.py +61 -0
- decider/modules/rules/tree/__init__.py +31 -0
- decider/modules/rules/tree/tree.py +117 -0
- decider/modules/rules/tree/v1/__init__.py +0 -0
- decider/modules/rules/tree/v1/edges.py +37 -0
- decider/modules/rules/tree/v1/nodes.py +122 -0
- decider/modules/rules/tree/v1/schema.py +383 -0
- decider/modules/rules/tree/v1/tree.py +586 -0
- decider/modules/rules/tree/v1/variables.py +71 -0
- decider/modules/rules/tree/v2/__init__.py +0 -0
- decider/modules/rules/tree/v2/nodes.py +230 -0
- decider/modules/rules/tree/v2/tree.py +113 -0
- decider/modules/rules/tree/v3/__init__.py +94 -0
- decider/modules/rules/tree/v3/nodes_ui.py +307 -0
- decider/modules/rules/tree/v3/tree.py +177 -0
- decider/modules/util.py +88 -0
- decider/serializable/__init__.py +0 -0
- decider/serializable/dataframe.py +61 -0
- decider/serializable/dtypes.py +148 -0
- decider/serializable/function.py +24 -0
- decider/serializable/schema.py +368 -0
- decider/serving/__init__.py +0 -0
- decider/serving/format.py +47 -0
- decider/serving/handler.py +113 -0
- decider/serving/media_types.py +11 -0
- decider/serving/parse.py +49 -0
- decider/serving/servers/__init__.py +0 -0
- decider/serving/servers/core.py +19 -0
- decider/serving/servers/sanic.py +56 -0
- decider/serving/servers/starlette.py +68 -0
- decider/settings.py +93 -0
- decider/templates/__init__.py +3 -0
- decider/templates/scaffold.py +154 -0
- decider/templates/static/extension_module.py +13 -0
- decider/templates/static/extension_package/module.py +13 -0
- decider/templates/static/extension_package/pyproject.toml +19 -0
- decider/templates/static/project/generate.py +65 -0
- decider/templates/static/project/speedtest.py +66 -0
- decider/types.py +6 -0
- decider-0.0.1.dist-info/METADATA +268 -0
- decider-0.0.1.dist-info/RECORD +98 -0
- decider-0.0.1.dist-info/WHEEL +4 -0
- decider-0.0.1.dist-info/entry_points.txt +2 -0
- decider-0.0.1.dist-info/licenses/LICENSE +21 -0
decider/__init__.py
ADDED
decider/_ext.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import typing as t
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from pydantic import create_model, RootModel, BaseModel, Field, model_validator
|
|
5
|
+
from warnings import warn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TypeDiscriminatedBaseModule(BaseModel, ABC):
    """Base model for classes identified by a single-valued ``Literal`` ``type``.

    Concrete (non-abstract) subclasses are checked at class-creation time and
    the literal value of their ``type`` annotation is stored in
    ``_CLASS_TYPE_IDENTIFIER``.
    """

    type: str

    # Literal value of the subclass's ``type`` annotation; only set on
    # concrete subclasses (see ``__init_subclass__``).
    _CLASS_TYPE_IDENTIFIER: t.ClassVar[str]

    def __init_subclass__(cls, **kwargs):
        """
        Validate a concrete subclass's ``type`` declaration and record its value.

        We are using this hook to ensure that:

        1. The class declares ``type: Literal['value']`` so the field can be
           used as the discriminator for the union of all implementations of
           this class.
        2. The class does NOT declare a default
           (``type: Literal['value'] = 'value'``). Modules are saved with
           ``model_dump(exclude_defaults=True)``; a defaulted ``type`` would
           be left out of the dumped dict, which breaks loading it back in.
        3. The value of the ``Literal`` is stored on the class so ``Model()``
           can fill it in automatically instead of needing
           ``Model(type='value')`` every time.

        Raises:
            TypeError: if a concrete subclass has no ``type`` annotation, the
                annotation is not a single-value ``typing.Literal``, or a
                default value is assigned to ``type``.
        """
        super().__init_subclass__(**kwargs)

        # Skip abstract classes, as they are bases for multiple implementations.
        if inspect.isabstract(cls):
            return

        # Ensure `type` is declared on this class itself.
        if "type" not in cls.__annotations__:
            raise TypeError(f"{cls.__name__} must define a 'type' annotation")

        annotation = cls.__annotations__["type"]

        if t.get_origin(annotation) is not t.Literal:
            raise TypeError(
                f"{cls.__name__}.type must be typing.Literal[...]"
            )

        literal_values = t.get_args(annotation)

        if len(literal_values) != 1:
            raise TypeError(
                f"{cls.__name__}.type must be a single-value Literal"
            )

        # A class-level `type = ...` assignment shows up in the class dict.
        if "type" in cls.__dict__:
            raise TypeError(
                f"{cls.__name__}.type must not define a default value"
            )

        cls._CLASS_TYPE_IDENTIFIER = literal_values[0]

    @model_validator(mode="before")
    @classmethod
    def auto_set_type(cls, values):
        """Fill in ``type`` from ``_CLASS_TYPE_IDENTIFIER`` when it is missing.

        Only dict input on concrete classes is touched; anything else is
        returned unchanged. The input dict is never modified: when ``type``
        has to be added, a shallow copy carrying the extra key is returned.
        """
        if isinstance(values, dict) and not inspect.isabstract(cls):
            if "type" not in values:
                # Copy rather than setdefault() so the caller's dict is left
                # exactly as it was passed in.
                values = {**values, "type": cls._CLASS_TYPE_IDENTIFIER}
        return values
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Type variable standing for the root payload type of an extendable model.
_TExtenableRootType = t.TypeVar("_TExtenableRootType")


class TExtendableModel(RootModel[_TExtenableRootType], t.Generic[_TExtenableRootType]):
    """Generic root model whose ``root`` value has the parametrised type.

    Used in this module as the return annotation of
    ``create_extendable_model``.
    """

    root: _TExtenableRootType
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def create_extendable_model(
    base_class: t.Type,
    discriminator_field: str = "type",
    model_name: str = "ExtendableModel"
) -> t.Tuple[t.Type[TExtendableModel], t.Callable[[t.Type], None]]:
    """
    Build a root model whose accepted types can be extended after creation.

    External packages call the returned register function to add new
    implementations, so no hard dependency on them is needed here.

    Args:
        base_class: class every registered provider must subclass.
        discriminator_field: field name passed to ``Field(discriminator=...)``.
        model_name: name given to the generated model class.

    Returns:
        A tuple of (model class, register function).
    """

    ExtendableModel = create_model(
        model_name,
        __base__=RootModel,
        # Forward reference; resolved to the real union on every rebuild.
        root=("RootType", ...),
    )

    # Registered provider classes keyed by their type identifier.
    providers: t.Dict[str, t.Type] = {}

    def _rebuild():
        """Point ``root`` at the union of all providers and rebuild the model."""
        if not providers:
            return

        members = tuple(providers.values())
        if len(members) == 1:
            root_type = members[0]
        else:
            root_type = t.Union[members]

        discriminated = t.Annotated[root_type, Field(discriminator=discriminator_field)]
        ExtendableModel.__annotations__["root"] = discriminated
        ExtendableModel.model_fields["root"].annotation = discriminated

        rebuilt = ExtendableModel.model_rebuild(
            force=True,
            _types_namespace={"RootType": discriminated},
        )
        if rebuilt is not True:
            warn(f"model_rebuild did not return True for {ExtendableModel.__name__}")

    def register_provider(provider_class: t.Type):
        """Register ``provider_class`` and rebuild the model to accept it."""
        assert issubclass(provider_class, base_class), f"Provider must be a subclass of {base_class.__name__}"

        # Prefer the recorded literal identifier; fall back to the class name.
        identifier = getattr(provider_class, "_CLASS_TYPE_IDENTIFIER", None)
        if not identifier:
            identifier = provider_class.__name__

        providers[identifier] = provider_class
        _rebuild()

    return ExtendableModel, register_provider
|
decider/cli/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from .template import template
|
|
4
|
+
from .serve import serve
|
|
5
|
+
from .visualise import visualise
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@click.group()
def cli():
    """Decider — build, serve and inspect decision pipelines."""


# Attach the sub-commands imported above, in their original order.
for _command in (template, serve, visualise):
    cli.add_command(_command)
|
decider/cli/_graph.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Walk a BaseModule tree and produce graph structures for the visualiser.
|
|
3
|
+
|
|
4
|
+
Two graph types:
|
|
5
|
+
build_graph(module) — module-level structural graph (pipeline view)
|
|
6
|
+
build_expression_graph(module) — expression node DAG inside one ExpressionModule
|
|
7
|
+
|
|
8
|
+
Both return ModuleGraph. The module_ref on each GraphNode holds the live
|
|
9
|
+
module object so the app can drill into it.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import typing as t
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class GraphNode:
    """One vertex of a ``ModuleGraph``."""

    # Unique identifier used as the graphviz node name and by edges.
    id: str
    # Text displayed on the node.
    label: str
    # "expression" | "sequential" | "join" | "union" | "col" | "config" | "unknown"
    kind: str
    # Type identifier of whatever the node represents.
    type_id: str
    # Id of the parent node, if any.
    parent: t.Optional[str] = None
    # Extra key/value details for the node.
    fields: t.Dict[str, t.Any] = field(default_factory=dict)
    # Live BaseModule if drillable.
    module_ref: t.Any = None
    # Whether the node is flagged as drillable.
    drillable: bool = False
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class GraphEdge:
    """Directed, optionally labelled connection between two graph nodes."""

    # Id of the node the edge starts at.
    source: str
    # Id of the node the edge points to.
    target: str
    # Text drawn on the edge; empty when there is none.
    label: str = ""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ModuleGraph:
    """Container for the nodes and edges handed to the visualiser."""

    nodes: t.List[GraphNode] = field(default_factory=list)
    edges: t.List[GraphEdge] = field(default_factory=list)

    def to_graphviz(self) -> "graphviz.Digraph":
        """Render the graph as a top-to-bottom ``graphviz.Digraph``.

        Nodes are filled with a colour chosen by ``kind``; "col" and "config"
        nodes are ellipses, everything else is a box; drillable nodes get a
        bold outline. The tooltip lists the node's fields, or its ``type_id``
        when there are none.
        """
        import graphviz

        palette = {
            "expression": "#4C9BE8",
            "sequential": "#E8884C",
            "join": "#4CE8A0",
            "union": "#9B4CE8",
            "col": "#888888",
            "config": "#C8A850",
            "unknown": "#AAAAAA",
        }

        dot = graphviz.Digraph(graph_attr={"rankdir": "TB", "splines": "ortho"})

        for node in self.nodes:
            style_parts = ["filled", "rounded"]
            if node.drillable:
                style_parts.append("bold")

            field_lines = [f"{k}: {v}" for k, v in node.fields.items()]

            dot.node(
                node.id,
                label=node.label,
                shape="ellipse" if node.kind in ("col", "config") else "box",
                style=",".join(style_parts),
                # Kinds missing from the palette fall back to grey.
                fillcolor=palette.get(node.kind, "#AAAAAA"),
                fontcolor="white",
                tooltip="\n".join(field_lines) or node.type_id,
            )

        for edge in self.edges:
            dot.edge(edge.source, edge.target, label=edge.label)

        return dot
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
74
|
+
|
|
75
|
+
def _kind(module) -> str:
    """Classify ``module`` into one of the visualiser's node kinds.

    The module's ``type`` attribute decides for "sequential", "join" and
    "union". Otherwise anything exposing ``expand_nodes`` counts as
    "expression", and everything else is "unknown".
    """
    declared = getattr(module, "type", "")

    for structural in ("sequential", "join", "union"):
        if declared == structural:
            return structural

    return "expression" if hasattr(module, "expand_nodes") else "unknown"
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _config_fields(module) -> t.Dict[str, t.Any]:
    """Return the dumped fields of ``module`` worth showing on a graph node.

    Keys in the skip set below and keys starting with an underscore are left
    out. If ``model_dump`` fails for any reason an empty dict is returned.
    """
    hidden = frozenset(("type", "name", "steps", "modules", "left", "right", "on", "how"))

    try:
        dumped = module.model_dump(exclude_defaults=False)
    except Exception:
        # Best effort: a module that cannot be dumped simply shows no fields.
        return {}

    shown = {}
    for key, value in dumped.items():
        if key in hidden or key.startswith("_"):
            continue
        shown[key] = value
    return shown
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ── module-level graph ────────────────────────────────────────────────────────
|
|
98
|
+
|
|
99
|
+
def _walk(
    module,
    graph: ModuleGraph,
    parent_id: t.Optional[str] = None,
    counter: t.Optional[t.List[int]] = None,
) -> str:
    """Add ``module`` and, recursively, its children to ``graph``.

    Args:
        module: module to add; its ``type`` / ``name`` attributes are read
            when present.
        graph: graph that receives the new nodes and edges (mutated in place).
        parent_id: id of the node this module hangs under, or ``None`` for
            the root. When given, a ``parent_id -> node`` edge is appended
            right after the node itself, before any descendants are walked.
        counter: single-element list used as a shared mutable counter for
            generating ids; created on the first call.

    Returns:
        The id of the node created for ``module``.
    """
    if counter is None:
        counter = [0]

    counter[0] += 1
    node_id = f"node_{counter[0]}"
    type_id = getattr(module, "type", type(module).__name__)
    name = getattr(module, "name", type_id)
    kind = _kind(module)
    drillable = kind in ("expression", "sequential", "join", "union")

    graph.nodes.append(GraphNode(
        id=node_id,
        label=name,
        kind=kind,
        type_id=type_id,
        parent=parent_id,
        fields=_config_fields(module),
        module_ref=module,
        drillable=drillable,
    ))

    if parent_id is not None:
        graph.edges.append(GraphEdge(source=parent_id, target=node_id))

    if kind == "sequential":
        # Chain the steps: container -> step1 -> step2 -> ... with "then" edges.
        prev = node_id
        for step in module.steps:
            # The child's own parent edge is the first edge the recursive
            # call appends, so it will sit at this index. Looking at
            # graph.edges[-1] afterwards would instead hit a descendant's
            # edge whenever the step has children of its own.
            edge_index = len(graph.edges)
            child_id = _walk(step, graph, parent_id=node_id, counter=counter)
            graph.edges[edge_index] = GraphEdge(source=prev, target=child_id, label="then")
            prev = child_id

    elif kind == "join":
        for side, ref in (("left", module.left), ("right", module.right)):
            if hasattr(ref, "type"):
                # Nested module: label its parent edge (same index reasoning
                # as in the sequential branch).
                edge_index = len(graph.edges)
                _walk(ref, graph, parent_id=node_id, counter=counter)
                graph.edges[edge_index].label = side
            else:
                # Anything without a `type` attribute is drawn as a frame
                # ("col" kind) leaf labelled with the quoted reference.
                fid = f"frame_{ref}_{counter[0]}"
                counter[0] += 1
                graph.nodes.append(GraphNode(
                    id=fid, label=f'"{ref}"', kind="col", type_id="frame", parent=node_id,
                ))
                graph.edges.append(GraphEdge(source=node_id, target=fid, label=side))

    elif kind == "union":
        for child_mod in module.modules:
            _walk(child_mod, graph, parent_id=node_id, counter=counter)

    return node_id
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def build_graph(module) -> ModuleGraph:
    """Build the module-level structural graph rooted at ``module``.

    Starts from an empty ``ModuleGraph`` and lets ``_walk`` populate it.
    """
    graph = ModuleGraph()
    _walk(module, graph)
    return graph
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# ── expression node DAG ───────────────────────────────────────────────────────
|
|
166
|
+
|
|
167
|
+
def build_expression_graph(module) -> ModuleGraph:
    """
    Return a computation DAG for an ExpressionModule showing individual
    function nodes, their column inputs, and config injections.

    Every entry of ``module.expand_nodes()`` becomes an ``fn_<name>`` node.
    Each item of a node's ``input_map`` then adds an edge into that node,
    labelled with the parameter name, coming from:

    - another function node, when the reference is an expression ``Node``;
    - a shared ``col_<input_name>`` node for an ``ExternalInputNode``;
    - a ``cfg_<name>_<param>`` node for a ``StaticValueNode``.

    References of any other type are ignored.
    """
    from decider.modules.expression import ExternalInputNode, StaticValueNode, Node as ExprNode

    g = ModuleGraph()
    nodes = module.expand_nodes()

    # add a function node for every expression node
    for name, _ in nodes.items():
        g.nodes.append(GraphNode(
            id=f"fn_{name}",
            label=name,
            kind="expression",
            type_id="function",
            drillable=False,
        ))

    # Ids already present in the graph. A set keeps the duplicate checks
    # below constant-time instead of rescanning g.nodes for every input.
    seen_ids = {n.id for n in g.nodes}

    # add edges: inputs → function nodes
    for name, expr_node in nodes.items():
        target_id = f"fn_{name}"
        for param, ref in expr_node.input_map.items():
            if isinstance(ref, ExprNode):
                g.edges.append(GraphEdge(source=f"fn_{ref.name}", target=target_id, label=param))
            elif isinstance(ref, ExternalInputNode):
                col_id = f"col_{ref.input_name}"
                if col_id not in seen_ids:
                    seen_ids.add(col_id)
                    g.nodes.append(GraphNode(
                        id=col_id,
                        label=ref.input_name,
                        kind="col",
                        type_id="column",
                        drillable=False,
                    ))
                g.edges.append(GraphEdge(source=col_id, target=target_id, label=param))
            elif isinstance(ref, StaticValueNode):
                val = ref.value
                cfg_id = f"cfg_{name}_{param}"
                if cfg_id not in seen_ids:
                    seen_ids.add(cfg_id)
                    g.nodes.append(GraphNode(
                        id=cfg_id,
                        # show the config type name, not the full repr
                        label=type(val).__name__,
                        kind="config",
                        type_id="config",
                        drillable=False,
                        fields=val.model_dump() if hasattr(val, "model_dump") else {},
                    ))
                g.edges.append(GraphEdge(source=cfg_id, target=target_id, label=param))

    return g
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# ── intermediate value extraction ─────────────────────────────────────────────
|
|
223
|
+
|
|
224
|
+
def run_with_intermediates(
    module,
    inputs: t.Dict[str, "pl.DataFrame"],
) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
    """
    Run ``module`` and return (label, DataFrame) pairs in execution order.

    The work is delegated according to ``_kind(module)``:

    - "expression" → ``_run_expression_intermediates``
    - "sequential" → ``_run_sequential_intermediates``
    - "join"       → ``_run_join_intermediates``
    - anything else → the module is called once and its (collected) output is
      returned as a single entry labelled with the module's ``name``, or
      "output" when it has none.
    """
    import polars as pl

    runners = {
        "expression": _run_expression_intermediates,
        "sequential": _run_sequential_intermediates,
        "join": _run_join_intermediates,
    }
    runner = runners.get(_kind(module))
    if runner is not None:
        return runner(module, inputs)

    result = module(inputs)
    if isinstance(result, pl.LazyFrame):
        result = result.collect()
    return [(getattr(module, "name", "output"), result)]
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _run_expression_intermediates(
    module,
    inputs: t.Dict[str, "pl.DataFrame"],
) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
    """Return one collected snapshot per compiled expression column.

    The expressions in ``module._compiled_expressions.expressions`` are
    applied one after another to the input frame; after each one the
    accumulated frame is collected and stored under that column's name.
    The frame named by ``input_frame`` is used, falling back to the first
    frame in ``inputs`` when that key is absent.
    """
    import polars as pl

    module.compile_expressions()
    compiled = module._compiled_expressions

    source = inputs.get(compiled.input_frame)
    if source is None:
        source = next(iter(inputs.values()))
    if isinstance(source, pl.DataFrame):
        source = source.lazy()

    snapshots = []
    running = source
    for column, expression in compiled.expressions.items():
        running = running.with_columns(expression.alias(column))
        snapshots.append((column, running.collect()))

    return snapshots
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _run_sequential_intermediates(
    module,
    inputs: t.Dict[str, "pl.DataFrame"],
) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
    """Run each step of ``module.steps`` in turn, recording every output.

    All input DataFrames are converted to lazy frames. The running frame
    starts as the "input" frame (or the first frame when there is no such
    key) and is stored back under "input" before every step, so each step
    sees the previous step's output. Each result is labelled with the step's
    ``name``, defaulting to its ``type``.
    """
    import polars as pl

    frames = {}
    for key, frame in inputs.items():
        frames[key] = frame.lazy() if isinstance(frame, pl.DataFrame) else frame

    if "input" in frames:
        current = frames["input"]
    else:
        current = next(iter(frames.values()))

    outputs = []
    for step in module.steps:
        frames["input"] = current
        produced = step(frames)
        if isinstance(produced, pl.LazyFrame):
            produced = produced.collect()
        current = produced.lazy()
        outputs.append((getattr(step, "name", step.type), produced))

    return outputs
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _run_join_intermediates(
    module,
    inputs: t.Dict[str, "pl.DataFrame"],
) -> t.List[t.Tuple[str, "pl.DataFrame"]]:
    """Run a join module once and return its output as a single entry.

    A lazy result is collected first. The entry is labelled with the
    module's ``name``, or "join" when it has none.
    """
    import polars as pl

    joined = module(inputs)
    if isinstance(joined, pl.LazyFrame):
        joined = joined.collect()

    label = getattr(module, "name", "join")
    return [(label, joined)]
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Streamlit app — launched by `decider visualise`.
|
|
3
|
+
|
|
4
|
+
Env vars:
|
|
5
|
+
DECIDER_VISUALISE_PROJECT_DIR
|
|
6
|
+
DECIDER_VISUALISE_EXT_DIR (optional)
|
|
7
|
+
DECIDER_VISUALISE_CONFIG_DIR (optional)
|
|
8
|
+
DECIDER_VISUALISE_ROOT_MODULE (optional, default "main")
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import polars as pl
|
|
17
|
+
import streamlit as st
|
|
18
|
+
|
|
19
|
+
# ── bootstrap ─────────────────────────────────────────────────────────────────
|
|
20
|
+
|
|
21
|
+
# Project directory comes from the environment; defaults to the current
# working directory.
_project_dir = Path(os.environ.get("DECIDER_VISUALISE_PROJECT_DIR", ".")).resolve()
# Two levels above the project directory is put on sys.path.
# NOTE(review): presumably the repository root, so project-local packages
# can be imported — confirm against the expected project layout.
_repo_root = _project_dir.parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

# Optional overrides; both default to sub-directories of the project dir.
_ext_dir = os.environ.get("DECIDER_VISUALISE_EXT_DIR",
                          str(_project_dir / "decider_extensions"))
_configs_dir = os.environ.get("DECIDER_VISUALISE_CONFIG_DIR",
                              str(_project_dir / "configs"))
# Key of the config entry offered as the root module by default.
_root_module = os.environ.get("DECIDER_VISUALISE_ROOT_MODULE", "main")
|
|
31
|
+
|
|
32
|
+
from decider.initialization import initialize_decider
|
|
33
|
+
from decider.config.file import JsonFileConfigManager
|
|
34
|
+
from decider.modules import GraphModule
|
|
35
|
+
from decider.cli._graph import (
|
|
36
|
+
build_graph,
|
|
37
|
+
build_expression_graph,
|
|
38
|
+
run_with_intermediates,
|
|
39
|
+
_kind,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# ── page setup ────────────────────────────────────────────────────────────────
|
|
43
|
+
|
|
44
|
+
st.set_page_config(
    page_title="Decider Visualise",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ── session state init ────────────────────────────────────────────────────────

# Drill-down path below the root module; empty means "showing the root".
if "breadcrumb" not in st.session_state:
    # Each entry: {"label": str, "module": BaseModule}
    st.session_state.breadcrumb = []

# Data pushed through the pipeline, or None until the user runs something.
if "run_inputs" not in st.session_state:
    st.session_state.run_inputs = None  # Dict[str, pl.DataFrame] when set
|
|
58
|
+
|
|
59
|
+
# ── load root module (cached) ─────────────────────────────────────────────────
|
|
60
|
+
|
|
61
|
+
@st.cache_resource
def _load_root(root_key: str):
    """Load the latest config and validate the entry at ``root_key``.

    Initialises decider with the configured extension path, fetches the
    latest versioned config from ``_configs_dir`` and validates
    ``config[root_key]`` as a ``GraphModule``.

    Returns:
        Tuple of (validated root module, versioned config object).
    """
    import asyncio

    initialize_decider(extension_path=_ext_dir)

    manager = JsonFileConfigManager(basepath=_configs_dir)
    # get_latest() is a coroutine function; run it to completion here.
    latest = asyncio.run(manager.get_latest())

    validated = GraphModule.model_validate(latest.config[root_key])
    return validated.root, latest
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ── sidebar ───────────────────────────────────────────────────────────────────
|
|
72
|
+
|
|
73
|
+
# Sidebar, part 1: project info, root-key selection and config reload.
with st.sidebar:
    st.header("Project")
    st.caption(str(_project_dir))

    root_key = st.text_input("Root module key", value=_root_module)
    if st.button("↺ Reload config"):
        # Drop the cached root module and all navigation/run state, then
        # restart the script so everything is loaded afresh.
        st.cache_resource.clear()
        st.session_state.breadcrumb = []
        st.session_state.run_inputs = None
        st.rerun()

# Load (or fetch from cache) the root module; on any failure show the error
# and stop rendering the rest of the page.
try:
    root_module, versioned = _load_root(root_key)
except Exception as e:
    st.error(f"Could not load module: {e}")
    st.stop()

# Sidebar, part 2: details of the loaded module and run-data entry.
with st.sidebar:
    st.divider()
    st.caption(f"version {versioned.version}")
    st.caption(f"type {root_module.type}")

    # ── input data entry ──────────────────────────────────────────────────────
    st.divider()
    st.subheader("Run data")
    st.caption("Paste JSON (column-oriented) to push data through the pipeline.")

    # NOTE(review): default_cols is never read below, and this call is not
    # guarded by hasattr() like the one that follows — confirm whether it is
    # needed (e.g. for a side effect) or can be removed.
    default_cols = root_module.get_input_frame_keys()
    # Example payload shown as the text area's placeholder: one entry per
    # input frame key, or a single "input" entry when the module does not
    # expose `_compute_input_frame_keys`.
    json_placeholder = json.dumps(
        {k: ["value1", "value2"] for k in
         (root_module._compute_input_frame_keys() if hasattr(root_module, '_compute_input_frame_keys') else ["input"])},
        indent=2,
    )
    raw_json = st.text_area("Input JSON", value="", height=180,
                            placeholder=json_placeholder)
    if st.button("▶ Run"):
        try:
            parsed = json.loads(raw_json)
            # support both {col: [...]} (single frame) and {"frame": {col: [...]}}
            if parsed and isinstance(next(iter(parsed.values())), dict):
                st.session_state.run_inputs = {
                    k: pl.DataFrame(v) for k, v in parsed.items()
                }
            else:
                st.session_state.run_inputs = {"input": pl.DataFrame(parsed)}
        except Exception as e:
            # Also reached for non-JSON failures inside the try block, such
            # as DataFrame construction errors.
            st.error(f"Invalid JSON: {e}")

    # Offer to discard the current run data once some has been set.
    if st.session_state.run_inputs is not None:
        if st.button("✕ Clear run"):
            st.session_state.run_inputs = None
            st.rerun()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# ── breadcrumb navigation ─────────────────────────────────────────────────────
|
|
128
|
+
|
|
129
|
+
# current module is root unless the user has drilled in
|
|
130
|
+
crumb_stack = st.session_state.breadcrumb
# The deepest drilled-into module, or the root when nothing is drilled into.
current_module = crumb_stack[-1]["module"] if crumb_stack else root_module

# render breadcrumb bar
# The root is always the first crumb; crumb_parts[i] for i >= 1 corresponds
# to crumb_stack[i - 1].
crumb_parts = [{"label": root_key, "module": root_module}] + crumb_stack
# One narrow column per crumb plus a wide trailing column.
cols = st.columns([1] * len(crumb_parts) + [8])
for i, crumb in enumerate(crumb_parts):
    with cols[i]:
        is_last = i == len(crumb_parts) - 1
        if is_last:
            # The current location is plain bold text, not a button.
            st.markdown(f"**{crumb['label']}**")
        else:
            if st.button(crumb["label"], key=f"crumb_{i}"):
                # Keeping the first i stack entries makes crumb i the current
                # one (i == 0 empties the stack, i.e. back to the root).
                st.session_state.breadcrumb = crumb_stack[: i]  # pop back to i
                st.rerun()

# Extra details are only shown once the user has drilled below the root.
if crumb_stack:
    st.caption(f"type: {current_module.type} · name: {current_module.name}")

st.divider()
|
|
150
|
+
|
|
151
|
+
# ── main content: tabs ────────────────────────────────────────────────────────
|
|
152
|
+
|
|
153
|
+
tab_graph, tab_run, tab_config = st.tabs(["Graph", "Run output", "Config"])

# ── TAB: Graph ────────────────────────────────────────────────────────────────

with tab_graph:
    kind = _kind(current_module)

    if kind == "expression":
        # show the intra-module expression DAG
        st.caption("Expression node DAG — functions, column inputs and config injections")
        eg = build_expression_graph(current_module)
        dot = eg.to_graphviz()
        st.graphviz_chart(dot.source, use_container_width=True)

        # node table: one row per graph node, with the node's fields spread
        # into "cfg:<key>" columns
        rows = [{"node": n.label, "kind": n.kind,
                 **{f"cfg:{k}": v for k, v in n.fields.items()}}
                for n in eg.nodes]
        if rows:
            # Rows have differing keys, so let polars look at every row when
            # inferring the schema.
            st.dataframe(pl.DataFrame(rows, infer_schema_length=len(rows)),
                         use_container_width=True)

    else:
        # show the module-level structural graph
        g = build_graph(current_module)
        col_g, col_d = st.columns([2, 1])

        with col_g:
            dot = g.to_graphviz()
            st.graphviz_chart(dot.source, use_container_width=True)

        with col_d:
            st.subheader("Modules")
            # One line per drillable node with a button to drill into it.
            for n in g.nodes:
                if not n.drillable:
                    continue
                c1, c2 = st.columns([4, 1])
                with c1:
                    tag = f"`{n.type_id}`"
                    cfg = " · " + " ".join(f"{k}={v}" for k, v in n.fields.items()) if n.fields else ""
                    st.markdown(f"**{n.label}** {tag}{cfg}")
                with c2:
                    if st.button("→", key=f"drill_{n.id}",
                                 help=f"Drill into {n.label}"):
                        # Push the chosen module onto the breadcrumb stack and
                        # rerun so it becomes the current module.
                        st.session_state.breadcrumb = crumb_stack + [
                            {"label": n.label, "module": n.module_ref}
                        ]
                        st.rerun()
|
|
201
|
+
|
|
202
|
+
# ── TAB: Run output ───────────────────────────────────────────────────────────
|
|
203
|
+
|
|
204
|
+
# Run tab: shows the pasted input frames and the output after each step of
# the current module.
with tab_run:
    if st.session_state.run_inputs is None:
        st.info("Paste input data in the sidebar and click **▶ Run** to see intermediate outputs.")
    else:
        inputs = st.session_state.run_inputs

        st.subheader("Input")
        for frame_key, df in inputs.items():
            st.caption(f"frame: `{frame_key}`")
            st.dataframe(df, use_container_width=True)

        st.subheader("Intermediates")
        try:
            intermediates = run_with_intermediates(current_module, inputs)
        except Exception as e:
            # Report the failure and render no intermediates.
            st.error(f"Execution error: {e}")
            intermediates = []

        # Columns of the first input frame, used to highlight newly-added
        # columns. The value does not depend on the intermediate being
        # rendered, so it is computed once rather than once per expander.
        input_cols = set(next(iter(inputs.values())).columns) if intermediates else set()

        for label, df in intermediates:
            with st.expander(f"after **{label}**", expanded=True):
                new_cols = [c for c in df.columns if c not in input_cols]
                st.caption(f"new columns: {', '.join(new_cols) if new_cols else '(none)'}")
                st.dataframe(df, use_container_width=True)
|
|
229
|
+
|
|
230
|
+
# ── TAB: Config ───────────────────────────────────────────────────────────────
|
|
231
|
+
|
|
232
|
+
# Config tab: show the current module's dumped config; if dumping fails for
# any reason, fall back to the whole versioned config.
with tab_config:
    try:
        st.json(current_module.model_dump())
    except Exception:
        st.json(versioned.config)
|