streamtrace 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- streamtrace/__init__.py +57 -0
- streamtrace/dag.py +386 -0
- streamtrace/decorators/__init__.py +0 -0
- streamtrace/decorators/app_decorator.py +77 -0
- streamtrace/decorators/infer_decorator.py +88 -0
- streamtrace/decorators/input_decorator.py +103 -0
- streamtrace/decorators/output_decorator.py +106 -0
- streamtrace/decorators/postprocess_decorator.py +64 -0
- streamtrace/decorators/preprocess_decorator.py +69 -0
- streamtrace/serializers.py +583 -0
- streamtrace/widgets/__init__.py +0 -0
- streamtrace/widgets/input_widgets.py +114 -0
- streamtrace/widgets/output_widgets.py +139 -0
- streamtrace-0.1.0.dist-info/METADATA +218 -0
- streamtrace-0.1.0.dist-info/RECORD +17 -0
- streamtrace-0.1.0.dist-info/WHEEL +5 -0
- streamtrace-0.1.0.dist-info/top_level.txt +1 -0
streamtrace/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Streamtrace SDK — decorator-based ML pipeline contracts for clinical AI.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
import streamtrace as st
|
|
6
|
+
|
|
7
|
+
@st.app(title="Cardiac Segmentation", version="1.0.0")
|
|
8
|
+
class MyPipeline:
|
|
9
|
+
|
|
10
|
+
@st.input(title="CT Scan", widget=st.FileInput(...), returns="scan")
|
|
11
|
+
def load_scan(self, path): ...
|
|
12
|
+
|
|
13
|
+
@st.preprocess(title="Normalize", returns="normalized")
|
|
14
|
+
def normalize(self, scan): ...
|
|
15
|
+
|
|
16
|
+
@st.infer(title="Segment", returns="mask", weights_path="model.pt")
|
|
17
|
+
def segment(self, normalized): ...
|
|
18
|
+
|
|
19
|
+
@st.postprocess(title="Clean Mask", returns="clean_mask")
|
|
20
|
+
def clean(self, mask): ...
|
|
21
|
+
|
|
22
|
+
@st.output(title="Result", widget=st.FileOutput(...))
|
|
23
|
+
def save(self, clean_mask): ...
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from .decorators.app_decorator import app, get_app_metadata, is_app
|
|
27
|
+
from .decorators.input_decorator import input, get_input_metadata, is_input
|
|
28
|
+
from .decorators.preprocess_decorator import preprocess, get_preprocess_metadata, is_preprocess
|
|
29
|
+
from .decorators.infer_decorator import infer, get_infer_metadata, is_infer
|
|
30
|
+
from .decorators.postprocess_decorator import postprocess, get_postprocess_metadata, is_postprocess
|
|
31
|
+
from .decorators.output_decorator import output, get_output_metadata, is_output
|
|
32
|
+
|
|
33
|
+
from .widgets.input_widgets import FileInput, TextInput, InputDataType
|
|
34
|
+
from .widgets.output_widgets import FileOutput, ImageOutput, PlotOutput, MetricOutput, OutputDataType
|
|
35
|
+
|
|
36
|
+
from .dag import build_dag, DAG, Node
|
|
37
|
+
|
|
38
|
+
__version__ = "0.1.0"
|
|
39
|
+
|
|
40
|
+
__all__ = [
|
|
41
|
+
# App
|
|
42
|
+
"app", "get_app_metadata", "is_app",
|
|
43
|
+
# Node decorators
|
|
44
|
+
"input", "get_input_metadata", "is_input",
|
|
45
|
+
"preprocess", "get_preprocess_metadata", "is_preprocess",
|
|
46
|
+
"infer", "get_infer_metadata", "is_infer",
|
|
47
|
+
"postprocess", "get_postprocess_metadata", "is_postprocess",
|
|
48
|
+
"output", "get_output_metadata", "is_output",
|
|
49
|
+
# Input widgets
|
|
50
|
+
"FileInput", "TextInput", "InputDataType",
|
|
51
|
+
# Output widgets
|
|
52
|
+
"FileOutput", "ImageOutput", "PlotOutput", "MetricOutput", "OutputDataType",
|
|
53
|
+
# DAG
|
|
54
|
+
"build_dag", "DAG", "Node",
|
|
55
|
+
# Version
|
|
56
|
+
"__version__",
|
|
57
|
+
]
|
streamtrace/dag.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
"""
|
|
2
|
+
streamtrace/dag.py
|
|
3
|
+
|
|
4
|
+
Builds a dependency graph from @st.app decorated methods,
|
|
5
|
+
validates the wiring, and provides topological execution order.
|
|
6
|
+
|
|
7
|
+
The core contract: each decorated method's **parameter names** (excluding `self`)
|
|
8
|
+
are matched against other methods' **return aliases** (the `returns=` kwarg on the
|
|
9
|
+
decorator) or, if no alias is set, the method name itself.
|
|
10
|
+
|
|
11
|
+
@st.input(returns="scan")
|
|
12
|
+
def load_scan(self, path): ...
|
|
13
|
+
|
|
14
|
+
@st.preprocess(returns="normalized")
|
|
15
|
+
def normalize(self, scan): ... # ← "scan" resolves to load_scan
|
|
16
|
+
|
|
17
|
+
@st.infer(returns="mask")
|
|
18
|
+
def segment(self, normalized): ... # ← "normalized" resolves to normalize
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import inspect
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from typing import Any, Callable
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# ---------------------------------------------------------------------------
|
|
29
|
+
# Node: one decorated method in the pipeline
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
# Pipeline phases in canonical execution order.  The tuple order matters:
# it drives phase-ordering validation (a node may not depend on a later
# phase) and tie-breaking in topological sorting.
PHASES = ("input", "preprocess", "infer", "postprocess", "output")

# Map from decorator metadata attribute → phase name
_META_ATTR_TO_PHASE: dict[str, str] = {
    "__input_metadata__": "input",
    "__preprocess_metadata__": "preprocess",
    "__infer_metadata__": "infer",
    "__postprocess_metadata__": "postprocess",
    "__output_metadata__": "output",
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class Node:
    """One decorated method of the pipeline, as a vertex of the execution DAG."""

    name: str                  # method name on the class
    phase: str                 # one of PHASES
    method: Callable           # the unbound method
    params: list[str]          # parameter names excluding 'self'
    returns_alias: str | None  # the `returns=` value from the decorator, or None
    metadata: Any = None       # the full decorator metadata object

    @property
    def output_key(self) -> str:
        """Name under which downstream nodes reference this node's result.

        The explicit `returns=` alias wins; otherwise the method name is used.
        """
        alias = self.returns_alias
        return alias if alias else self.name
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
# DAG: the full dependency graph
|
|
63
|
+
# ---------------------------------------------------------------------------
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class DAG:
    """
    Directed acyclic graph of pipeline nodes.

    Nodes are discovered from class decorator metadata.
    Edges are inferred from parameter name → output key matching.
    """

    nodes: dict[str, Node] = field(default_factory=dict)

    # output key → name of the node producing it; kept in sync by add_node().
    _output_key_to_node: dict[str, str] = field(
        default_factory=dict, repr=False
    )
    # Lazily computed dependency map; None means "recompute on next access".
    _edges: dict[str, list[str]] | None = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def add_node(self, node: Node) -> None:
        """Register *node* in the graph.

        Raises:
            ValueError: if the node name, or the output key it produces,
                is already taken by a previously added node.
        """
        if node.name in self.nodes:
            raise ValueError(f"Duplicate node name: '{node.name}'")

        output_key = node.output_key
        if output_key in self._output_key_to_node:
            existing = self._output_key_to_node[output_key]
            raise ValueError(
                f"Duplicate output key '{output_key}': "
                f"both '{existing}' and '{node.name}' produce it"
            )

        self.nodes[node.name] = node
        self._output_key_to_node[output_key] = node.name
        self._edges = None  # invalidate cache

    # ------------------------------------------------------------------
    # Edge resolution
    # ------------------------------------------------------------------

    @property
    def edges(self) -> dict[str, list[str]]:
        """
        Map of node_name → list of node_names it depends on.
        Edges are resolved by matching parameter names to output keys.

        Input nodes are excluded from edge resolution because their
        parameters come from the user, not from upstream nodes.
        """
        if self._edges is not None:
            return self._edges

        self._edges = {}
        for name, node in self.nodes.items():
            if node.phase == "input":
                # Input nodes receive user-provided values, not upstream outputs
                self._edges[name] = []
            else:
                # Parameters with no matching output key are silently skipped
                # here; validate() reports them as errors.
                self._edges[name] = [
                    self._output_key_to_node[param]
                    for param in node.params
                    if param in self._output_key_to_node
                ]

        return self._edges

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate(self) -> list[str]:
        """
        Return a list of error strings. Empty list = valid pipeline.

        Checks:
          1. At least one input node exists
          2. At least one output node exists
          3. Every non-input parameter resolves to an upstream output key
          4. No dependency cycles
          5. Phase ordering is respected (no preprocess depending on infer, etc.)
        """
        errors: list[str] = []

        # --- Check: required phases ---
        phases_present = {n.phase for n in self.nodes.values()}
        if "input" not in phases_present:
            errors.append("Pipeline has no @input nodes")
        if "output" not in phases_present:
            errors.append("Pipeline has no @output nodes")

        # --- Check: all parameters resolve ---
        for name, node in self.nodes.items():
            if node.phase == "input":
                continue  # input params are user-supplied by design
            for param in node.params:
                if param not in self._output_key_to_node:
                    errors.append(
                        f"Node '{name}' expects parameter '{param}' "
                        f"but no node produces output key '{param}'. "
                        f"Available keys: {sorted(self._output_key_to_node.keys())}"
                    )

        # --- Check: phase ordering ---
        phase_order = {phase: i for i, phase in enumerate(PHASES)}
        for name, dep_names in self.edges.items():
            node_phase = phase_order[self.nodes[name].phase]
            for dep_name in dep_names:
                dep_phase = phase_order[self.nodes[dep_name].phase]
                if dep_phase > node_phase:
                    errors.append(
                        f"Node '{name}' ({self.nodes[name].phase}) depends on "
                        f"'{dep_name}' ({self.nodes[dep_name].phase}), "
                        f"which is a later phase"
                    )

        # --- Check: cycle detection ---
        cycle = self._detect_cycle()
        if cycle:
            errors.append(f"Dependency cycle detected: {' → '.join(cycle)}")

        return errors

    def _detect_cycle(self) -> list[str] | None:
        """DFS-based cycle detection. Returns the cycle path or None."""
        # Standard three-color DFS: WHITE = unvisited, GRAY = on the current
        # DFS stack, BLACK = fully explored.  A GRAY→GRAY edge is a cycle.
        WHITE, GRAY, BLACK = 0, 1, 2
        color = {name: WHITE for name in self.nodes}
        parent: dict[str, str | None] = {}

        def dfs(u: str) -> list[str] | None:
            color[u] = GRAY
            for v in self.edges.get(u, []):
                if color[v] == GRAY:
                    # Reconstruct cycle by walking parents back from u to v.
                    cycle = [v, u]
                    curr = u
                    while curr != v:
                        curr = parent.get(curr)
                        if curr is None:
                            break
                        cycle.append(curr)
                    cycle.reverse()
                    return cycle
                if color[v] == WHITE:
                    parent[v] = u
                    result = dfs(v)
                    if result:
                        return result
            color[u] = BLACK
            return None

        for name in self.nodes:
            if color[name] == WHITE:
                result = dfs(name)
                if result:
                    return result
        return None

    # ------------------------------------------------------------------
    # Topological sort
    # ------------------------------------------------------------------

    def topological_sort(self) -> list[Node]:
        """
        Return nodes in execution order via Kahn's algorithm.

        Ties within the same depth are broken by phase order first,
        then by definition order. This ensures inputs run before
        preprocessing even if there's no explicit dependency.

        Raises:
            RuntimeError: if not every node can be scheduled (i.e. the
                graph contains a cycle that validate() would have caught).
        """
        phase_priority = {phase: i for i, phase in enumerate(PHASES)}

        def rank(name: str) -> tuple[int, str]:
            # Sort key: phase first, then name for deterministic ordering.
            return (phase_priority[self.nodes[name].phase], name)

        # In-degree = number of direct dependencies.  Build the reverse
        # adjacency (dependents) once so each edge is touched O(1) times
        # overall, instead of rescanning every edge list per popped node.
        in_degree: dict[str, int] = {name: 0 for name in self.nodes}
        dependents: dict[str, list[str]] = {name: [] for name in self.nodes}
        for name, deps in self.edges.items():
            in_degree[name] = len(deps)
            for dep in deps:
                dependents[dep].append(name)

        # Seed with zero-degree nodes, sorted by phase then name
        queue: list[str] = sorted(
            [n for n, d in in_degree.items() if d == 0], key=rank
        )

        order: list[Node] = []
        while queue:
            current = queue.pop(0)
            order.append(self.nodes[current])

            # Release nodes whose last unmet dependency was `current`.
            for name in dependents[current]:
                in_degree[name] -= 1
                if in_degree[name] == 0:
                    queue.append(name)
            # Re-sort to maintain phase priority among ready nodes.
            queue.sort(key=rank)

        if len(order) != len(self.nodes):
            # Should not happen if validate() passed, but safety net
            missing = set(self.nodes.keys()) - {n.name for n in order}
            raise RuntimeError(
                f"Topological sort incomplete — stuck nodes: {missing}. "
                f"This usually indicates a cycle."
            )

        return order

    # ------------------------------------------------------------------
    # Introspection helpers
    # ------------------------------------------------------------------

    def get_input_nodes(self) -> list[Node]:
        """All nodes in the 'input' phase (user-facing entry points)."""
        return [n for n in self.nodes.values() if n.phase == "input"]

    def get_output_nodes(self) -> list[Node]:
        """All nodes in the 'output' phase."""
        return [n for n in self.nodes.values() if n.phase == "output"]

    def get_intermediate_outputs(self) -> list[Node]:
        """Output nodes flagged `intermediate` in their decorator metadata."""
        return [
            n
            for n in self.nodes.values()
            if n.phase == "output"
            and getattr(n.metadata, "intermediate", False)
        ]

    def get_dependencies(self, node_name: str) -> list[Node]:
        """Get the direct upstream dependencies of a node."""
        return [self.nodes[dep] for dep in self.edges.get(node_name, [])]

    def to_schema(self) -> dict:
        """
        Emit a JSON-serializable dict describing the full pipeline.
        Used by `streamtrace push` to register with the backend,
        and by the frontend to render the pipeline visualization.
        """
        nodes_schema = []
        for node in self.topological_sort():
            entry = {
                "name": node.name,
                "phase": node.phase,
                "output_key": node.output_key,
                "depends_on": [
                    self.nodes[d].output_key
                    for d in self.edges.get(node.name, [])
                ],
            }
            # Include widget schema for input nodes
            if node.phase == "input" and node.metadata:
                if hasattr(node.metadata, "to_schema"):
                    entry["widget"] = node.metadata.to_schema()

            # Include output metadata
            if node.phase == "output" and node.metadata:
                entry["intermediate"] = getattr(node.metadata, "intermediate", False)
                widget = getattr(node.metadata, "widget", None)
                if widget is not None:
                    entry["widget"] = widget.to_schema()

            nodes_schema.append(entry)

        return {"nodes": nodes_schema}
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# ---------------------------------------------------------------------------
|
|
336
|
+
# build_dag: extract a DAG from an @st.app decorated class
|
|
337
|
+
# ---------------------------------------------------------------------------
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def build_dag(app_cls: type) -> DAG:
    """
    Inspect an @st.app class and build the execution DAG from
    its decorated methods.

    Each method decorated with @st.input, @st.preprocess, @st.infer,
    or @st.output becomes a node. Parameter names on non-input nodes
    are matched against the `returns` alias (or method name) of other
    nodes to form edges.
    """
    dag = DAG()

    # vars(app_cls) yields only attributes defined on the class itself
    # (not inherited ones) and preserves definition order on Python 3.7+.
    for method_name, member in vars(app_cls).items():
        if not callable(member):
            continue

        for marker_attr, phase_name in _META_ATTR_TO_PHASE.items():
            meta = getattr(member, marker_attr, None)
            if meta is None:
                continue

            # Parameter names, minus the implicit 'self'.
            arg_names = [
                p for p in inspect.signature(member).parameters if p != "self"
            ]

            dag.add_node(
                Node(
                    name=method_name,
                    phase=phase_name,
                    method=member,
                    params=arg_names,
                    # The decorator's `returns=` alias, if any.
                    returns_alias=getattr(meta, "returns", None),
                    metadata=meta,
                )
            )
            break  # a method can only have one phase

    return dag
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# ===================================================================
|
|
8
|
+
# @st.app — class decorator
|
|
9
|
+
# ===================================================================
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class AppMetadata:
    """Immutable metadata attached to a pipeline class by @st.app."""

    title: str = "Streamtrace App"
    version: str = "0.1.0"
    description: str = ""
    # Docker base image hint — used by `push` to select the right
    # container. The agent picks this based on what the model needs.
    docker_base: str | None = None
    # Python dependencies beyond what's in the base image
    requirements: list[str] = field(default_factory=list)

    def to_schema(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict, omitting empty optional fields."""
        schema: dict[str, Any] = {
            "title": self.title,
            "version": self.version,
            "description": self.description,
        }
        # Optional fields are included only when they carry a value.
        for key, value in (
            ("docker_base", self.docker_base),
            ("requirements", self.requirements),
        ):
            if value:
                schema[key] = value
        return schema
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def app(
    target=None,
    *,
    title: str = "Streamtrace App",
    version: str = "0.1.0",
    description: str = "",
    docker_base: str | None = None,
    requirements: list[str] | None = None,
):
    """
    Class decorator that marks a class as a Streamtrace app.

    Works bare or parameterized:

        @st.app
        class MyPipeline: ...

        @st.app(title="Cardiac Seg", version="1.0.0",
                docker_base="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
                requirements=["nibabel", "scipy"])
        class CardiacSegmentation: ...
    """
    def wrap(cls):
        # Attach frozen metadata; build_dag and the push tooling read it later.
        cls.__app_metadata__ = AppMetadata(
            title=title,
            version=version,
            description=description,
            docker_base=docker_base,
            requirements=requirements or [],
        )
        return cls

    # Bare @st.app usage: `target` is the class itself.
    return wrap(target) if target is not None else wrap
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_app_metadata(cls) -> AppMetadata | None:
    """Return the metadata attached by @app, or None if cls is not an app."""
    try:
        return cls.__app_metadata__
    except AttributeError:
        return None
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def is_app(cls) -> bool:
    """True when *cls* carries @app metadata."""
    try:
        cls.__app_metadata__
    except AttributeError:
        return False
    return True
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# ===================================================================
|
|
8
|
+
# @st.infer — inference/model step decorator
|
|
9
|
+
# ===================================================================
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class InferMetadata:
    """Immutable metadata attached to a method by @st.infer."""

    title: str = "Inference"
    returns: str | None = None
    # Hint for the runtime about GPU requirements
    device: str = "auto"  # "auto", "cpu", "cuda", "cuda:0", etc.
    # Optional path to model weights — used by `push` to
    # ensure weights are available in the container
    weights_path: str | None = None

    def to_schema(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict, omitting unset optional fields."""
        schema: dict[str, Any] = {"title": self.title, "device": self.device}
        # `returns` is exposed to the backend as "output_key".
        optional = {"output_key": self.returns, "weights_path": self.weights_path}
        schema.update({k: v for k, v in optional.items() if v})
        return schema
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def infer(
    target=None,
    *,
    title: str | None = None,
    returns: str | None = None,
    device: str = "auto",
    weights_path: str | None = None,
):
    """
    Decorator that marks a method as an inference step.

    Inference nodes are where the model runs: they receive preprocessed
    data and return predictions. The `device` hint tells the runtime
    where to run — "auto" picks GPU if available. `weights_path` is
    informational: `push` uses it to verify weights are bundled or
    accessible in the container.

    Works bare or parameterized:

        @st.infer(returns="mask", device="cuda",
                  weights_path="checkpoints/best.pth")
        def segment(self, resampled, threshold): ...

        @st.infer(returns="embedding", device="cpu")
        def encode(self, text):
            return self.encoder.encode(text)
    """
    def attach(fn):
        fn.__infer_metadata__ = InferMetadata(
            # Default title: method name, humanized ("load_scan" → "Load Scan").
            title=title or fn.__name__.replace("_", " ").title(),
            returns=returns,
            device=device,
            weights_path=weights_path,
        )
        return fn

    # Bare @st.infer usage: `target` is the function itself.
    return attach(target) if target is not None else attach
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_infer_metadata(obj) -> InferMetadata | None:
    """Return the metadata attached by @infer, or None if absent."""
    try:
        return obj.__infer_metadata__
    except AttributeError:
        return None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def is_infer(obj) -> bool:
    """True when *obj* carries @infer metadata."""
    try:
        obj.__infer_metadata__
    except AttributeError:
        return False
    return True
|