streamtrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
1
+ """
2
+ Streamtrace SDK — decorator-based ML pipeline contracts for clinical AI.
3
+
4
+ Usage:
5
+ import streamtrace as st
6
+
7
+ @st.app(title="Cardiac Segmentation", version="1.0.0")
8
+ class MyPipeline:
9
+
10
+ @st.input(title="CT Scan", widget=st.FileInput(...), returns="scan")
11
+ def load_scan(self, path): ...
12
+
13
+ @st.preprocess(title="Normalize", returns="normalized")
14
+ def normalize(self, scan): ...
15
+
16
+ @st.infer(title="Segment", returns="mask", weights_path="model.pt")
17
+ def segment(self, normalized): ...
18
+
19
+ @st.postprocess(title="Clean Mask", returns="clean_mask")
20
+ def clean(self, mask): ...
21
+
22
+ @st.output(title="Result", widget=st.FileOutput(...))
23
+ def save(self, clean_mask): ...
24
+ """
25
+
26
+ from .decorators.app_decorator import app, get_app_metadata, is_app
27
+ from .decorators.input_decorator import input, get_input_metadata, is_input
28
+ from .decorators.preprocess_decorator import preprocess, get_preprocess_metadata, is_preprocess
29
+ from .decorators.infer_decorator import infer, get_infer_metadata, is_infer
30
+ from .decorators.postprocess_decorator import postprocess, get_postprocess_metadata, is_postprocess
31
+ from .decorators.output_decorator import output, get_output_metadata, is_output
32
+
33
+ from .widgets.input_widgets import FileInput, TextInput, InputDataType
34
+ from .widgets.output_widgets import FileOutput, ImageOutput, PlotOutput, MetricOutput, OutputDataType
35
+
36
+ from .dag import build_dag, DAG, Node
37
+
38
# Package version — single source of truth; keep in sync with the wheel metadata.
__version__ = "0.1.0"

# Public API of the streamtrace package. Names listed here are what
# `from streamtrace import *` exposes and what the docs consider stable.
__all__ = [
    # App
    "app", "get_app_metadata", "is_app",
    # Node decorators
    "input", "get_input_metadata", "is_input",
    "preprocess", "get_preprocess_metadata", "is_preprocess",
    "infer", "get_infer_metadata", "is_infer",
    "postprocess", "get_postprocess_metadata", "is_postprocess",
    "output", "get_output_metadata", "is_output",
    # Input widgets
    "FileInput", "TextInput", "InputDataType",
    # Output widgets
    "FileOutput", "ImageOutput", "PlotOutput", "MetricOutput", "OutputDataType",
    # DAG
    "build_dag", "DAG", "Node",
    # Version
    "__version__",
]
streamtrace/dag.py ADDED
@@ -0,0 +1,386 @@
1
+ """
2
+ streamtrace/dag.py
3
+
4
+ Builds a dependency graph from @st.app decorated methods,
5
+ validates the wiring, and provides topological execution order.
6
+
7
+ The core contract: each decorated method's **parameter names** (excluding `self`)
8
+ are matched against other methods' **return aliases** (the `returns=` kwarg on the
9
+ decorator) or, if no alias is set, the method name itself.
10
+
11
+ @st.input(returns="scan")
12
+ def load_scan(self, path): ...
13
+
14
+ @st.preprocess(returns="normalized")
15
+ def normalize(self, scan): ... # ← "scan" resolves to load_scan
16
+
17
+ @st.infer(returns="mask")
18
+ def segment(self, normalized): ... # ← "normalized" resolves to normalize
19
+ """
20
+
21
from __future__ import annotations

import heapq
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Node: one decorated method in the pipeline
30
+ # ---------------------------------------------------------------------------
31
+
32
# ---------------------------------------------------------------------------
# Node: one decorated method in the pipeline
# ---------------------------------------------------------------------------

# Pipeline phases in canonical execution order. A phase's index in this tuple
# defines its priority for validation and for topological tie-breaking.
PHASES = ("input", "preprocess", "infer", "postprocess", "output")

# Map from decorator metadata attribute → phase name
_META_ATTR_TO_PHASE: dict[str, str] = {
    "__input_metadata__": "input",
    "__preprocess_metadata__": "preprocess",
    "__infer_metadata__": "infer",
    "__postprocess_metadata__": "postprocess",
    "__output_metadata__": "output",
}


@dataclass
class Node:
    """A single node in the execution DAG."""

    name: str                  # method name on the class
    phase: str                 # one of PHASES
    method: Callable           # the unbound method
    params: list[str]          # parameter names excluding 'self'
    returns_alias: str | None  # the `returns=` value from the decorator, or None
    metadata: Any = None       # the full decorator metadata object

    @property
    def output_key(self) -> str:
        """The name downstream nodes use to reference this node's return value."""
        return self.returns_alias or self.name


# ---------------------------------------------------------------------------
# DAG: the full dependency graph
# ---------------------------------------------------------------------------


@dataclass
class DAG:
    """
    Directed acyclic graph of pipeline nodes.

    Nodes are discovered from class decorator metadata.
    Edges are inferred from parameter name → output key matching.
    """

    nodes: dict[str, Node] = field(default_factory=dict)

    # Computed incrementally (output-key index) / lazily and cached (edges)
    _output_key_to_node: dict[str, str] = field(
        default_factory=dict, repr=False
    )
    _edges: dict[str, list[str]] | None = field(default=None, repr=False)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def add_node(self, node: Node) -> None:
        """
        Register *node* in the graph.

        Raises:
            ValueError: if the node name or its output key is already taken.
                Both checks run before any state is mutated, so a failed add
                leaves the DAG unchanged.
        """
        if node.name in self.nodes:
            raise ValueError(f"Duplicate node name: '{node.name}'")

        output_key = node.output_key
        if output_key in self._output_key_to_node:
            existing = self._output_key_to_node[output_key]
            raise ValueError(
                f"Duplicate output key '{output_key}': "
                f"both '{existing}' and '{node.name}' produce it"
            )

        self.nodes[node.name] = node
        self._output_key_to_node[output_key] = node.name
        self._edges = None  # invalidate cached edges

    # ------------------------------------------------------------------
    # Edge resolution
    # ------------------------------------------------------------------

    @property
    def edges(self) -> dict[str, list[str]]:
        """
        Map of node_name → list of node_names it depends on.
        Edges are resolved by matching parameter names to output keys.

        Input nodes are excluded from edge resolution because their
        parameters come from the user, not from upstream nodes.
        """
        if self._edges is not None:
            return self._edges

        self._edges = {}
        for name, node in self.nodes.items():
            if node.phase == "input":
                # Input nodes receive user-provided values, not upstream outputs
                self._edges[name] = []
            else:
                deps = []
                for param in node.params:
                    if param in self._output_key_to_node:
                        deps.append(self._output_key_to_node[param])
                    # If param doesn't match any output key, validation catches it
                self._edges[name] = deps

        return self._edges

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate(self) -> list[str]:
        """
        Return a list of error strings. Empty list = valid pipeline.

        Checks:
            1. At least one input node exists
            2. At least one output node exists
            3. Every non-input parameter resolves to an upstream output key
            4. No dependency cycles
            5. Phase ordering is respected (no preprocess depending on infer, etc.)
        """
        errors: list[str] = []

        # --- Check: required phases ---
        phases_present = {n.phase for n in self.nodes.values()}
        if "input" not in phases_present:
            errors.append("Pipeline has no @input nodes")
        if "output" not in phases_present:
            errors.append("Pipeline has no @output nodes")

        # --- Check: all parameters resolve ---
        for name, node in self.nodes.items():
            if node.phase == "input":
                continue
            for param in node.params:
                if param not in self._output_key_to_node:
                    errors.append(
                        f"Node '{name}' expects parameter '{param}' "
                        f"but no node produces output key '{param}'. "
                        f"Available keys: {sorted(self._output_key_to_node.keys())}"
                    )

        # --- Check: phase ordering ---
        phase_order = {phase: i for i, phase in enumerate(PHASES)}
        for name, dep_names in self.edges.items():
            node_phase = phase_order[self.nodes[name].phase]
            for dep_name in dep_names:
                dep_phase = phase_order[self.nodes[dep_name].phase]
                if dep_phase > node_phase:
                    errors.append(
                        f"Node '{name}' ({self.nodes[name].phase}) depends on "
                        f"'{dep_name}' ({self.nodes[dep_name].phase}), "
                        f"which is a later phase"
                    )

        # --- Check: cycle detection ---
        cycle = self._detect_cycle()
        if cycle:
            errors.append(f"Dependency cycle detected: {' → '.join(cycle)}")

        return errors

    def _detect_cycle(self) -> list[str] | None:
        """DFS-based cycle detection. Returns the cycle path or None."""
        WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on current DFS path / done
        color = {name: WHITE for name in self.nodes}
        parent: dict[str, str | None] = {}

        def dfs(u: str) -> list[str] | None:
            color[u] = GRAY
            for v in self.edges.get(u, []):
                if color[v] == GRAY:
                    # Back edge found: walk the parent chain from u back to v
                    # to reconstruct the cycle path for the error message.
                    cycle = [v, u]
                    curr = u
                    while curr != v:
                        curr = parent.get(curr)
                        if curr is None:
                            # curr was the DFS root — chain ends here
                            break
                        cycle.append(curr)
                    cycle.reverse()
                    return cycle
                if color[v] == WHITE:
                    parent[v] = u
                    result = dfs(v)
                    if result:
                        return result
            color[u] = BLACK
            return None

        for name in self.nodes:
            if color[name] == WHITE:
                result = dfs(name)
                if result:
                    return result
        return None

    # ------------------------------------------------------------------
    # Topological sort
    # ------------------------------------------------------------------

    def topological_sort(self) -> list[Node]:
        """
        Return nodes in execution order via Kahn's algorithm.

        Ties within the same depth are broken by phase order first,
        then by name. This ensures inputs run before preprocessing
        even if there's no explicit dependency.

        Implementation note: uses a reverse-adjacency (producer → consumers)
        map and a min-heap keyed on (phase priority, name), giving
        O(E + V log V) instead of rescanning every edge per popped node.
        """
        phase_priority = {phase: i for i, phase in enumerate(PHASES)}

        def heap_key(name: str) -> tuple[int, str]:
            return (phase_priority[self.nodes[name].phase], name)

        # In-degree per node, and producer → consumers reverse adjacency
        in_degree: dict[str, int] = {
            name: len(deps) for name, deps in self.edges.items()
        }
        dependents: dict[str, list[str]] = {name: [] for name in self.nodes}
        for name, deps in self.edges.items():
            for dep in deps:
                dependents[dep].append(name)

        # Seed with zero-degree nodes; heap order == (phase, name)
        heap: list[tuple[int, str]] = [
            heap_key(n) for n, d in in_degree.items() if d == 0
        ]
        heapq.heapify(heap)

        order: list[Node] = []
        while heap:
            _, current = heapq.heappop(heap)
            order.append(self.nodes[current])

            for consumer in dependents[current]:
                in_degree[consumer] -= 1
                if in_degree[consumer] == 0:
                    heapq.heappush(heap, heap_key(consumer))

        if len(order) != len(self.nodes):
            # Should not happen if validate() passed, but safety net
            missing = set(self.nodes.keys()) - {n.name for n in order}
            raise RuntimeError(
                f"Topological sort incomplete — stuck nodes: {missing}. "
                f"This usually indicates a cycle."
            )

        return order

    # ------------------------------------------------------------------
    # Introspection helpers
    # ------------------------------------------------------------------

    def get_input_nodes(self) -> list[Node]:
        """All nodes in the 'input' phase."""
        return [n for n in self.nodes.values() if n.phase == "input"]

    def get_output_nodes(self) -> list[Node]:
        """All nodes in the 'output' phase."""
        return [n for n in self.nodes.values() if n.phase == "output"]

    def get_intermediate_outputs(self) -> list[Node]:
        """Output nodes whose metadata flags them as intermediate results."""
        return [
            n
            for n in self.nodes.values()
            if n.phase == "output"
            and getattr(n.metadata, "intermediate", False)
        ]

    def get_dependencies(self, node_name: str) -> list[Node]:
        """Get the direct upstream dependencies of a node."""
        return [self.nodes[dep] for dep in self.edges.get(node_name, [])]

    def to_schema(self) -> dict:
        """
        Emit a JSON-serializable dict describing the full pipeline.
        Used by `streamtrace push` to register with the backend,
        and by the frontend to render the pipeline visualization.
        """
        nodes_schema = []
        for node in self.topological_sort():
            entry = {
                "name": node.name,
                "phase": node.phase,
                "output_key": node.output_key,
                # depends_on lists upstream *output keys* (not node names) —
                # the frontend wires the graph by key.
                "depends_on": [
                    self.nodes[d].output_key
                    for d in self.edges.get(node.name, [])
                ],
            }
            # Include widget schema for input nodes
            if node.phase == "input" and node.metadata:
                if hasattr(node.metadata, "to_schema"):
                    entry["widget"] = node.metadata.to_schema()

            # Include output metadata
            if node.phase == "output" and node.metadata:
                entry["intermediate"] = getattr(node.metadata, "intermediate", False)
                widget = getattr(node.metadata, "widget", None)
                if widget is not None:
                    entry["widget"] = widget.to_schema()

            nodes_schema.append(entry)

        return {"nodes": nodes_schema}
333
+
334
+
335
+ # ---------------------------------------------------------------------------
336
+ # build_dag: extract a DAG from an @st.app decorated class
337
+ # ---------------------------------------------------------------------------
338
+
339
+
340
def build_dag(app_cls: type) -> DAG:
    """
    Inspect an @st.app class and build the execution DAG from
    its decorated methods.

    Each method decorated with @st.input, @st.preprocess, @st.infer,
    or @st.output becomes a node. Parameter names on non-input nodes
    are matched against the `returns` alias (or method name) of other
    nodes to form edges.
    """
    dag = DAG()

    # vars() yields only attributes defined on the class itself (not
    # inherited ones) and preserves definition order on Python 3.7+.
    for method_name, member in vars(app_cls).items():
        if not callable(member):
            continue

        # A method carries at most one phase marker — stop at the first.
        found: tuple[str, Any] | None = None
        for meta_attr, phase_name in _META_ATTR_TO_PHASE.items():
            candidate = getattr(member, meta_attr, None)
            if candidate is not None:
                found = (phase_name, candidate)
                break

        if found is None:
            continue  # undecorated helper method — not part of the pipeline

        node_phase, node_meta = found

        # Parameter names (minus 'self') drive downstream edge resolution.
        arg_names = [
            p for p in inspect.signature(member).parameters if p != "self"
        ]

        dag.add_node(
            Node(
                name=method_name,
                phase=node_phase,
                method=member,
                params=arg_names,
                returns_alias=getattr(node_meta, "returns", None),
                metadata=node_meta,
            )
        )

    return dag
File without changes
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+
7
+ # ===================================================================
8
+ # @st.app — class decorator
9
+ # ===================================================================
10
+
11
@dataclass(frozen=True)
class AppMetadata:
    """Metadata for the top-level pipeline class."""

    title: str = "Streamtrace App"
    version: str = "0.1.0"
    description: str = ""
    # Docker base image hint — used by `push` to select the right
    # container. The agent picks this based on what the model needs.
    docker_base: str | None = None
    # Python dependencies beyond what's in the base image
    requirements: list[str] = field(default_factory=list)

    def to_schema(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict; optional fields are omitted when unset."""
        schema = {
            "title": self.title,
            "version": self.version,
            "description": self.description,
        }
        if self.docker_base:
            schema["docker_base"] = self.docker_base
        if self.requirements:
            schema["requirements"] = self.requirements
        return schema


def app(
    target=None,
    *,
    title: str = "Streamtrace App",
    version: str = "0.1.0",
    description: str = "",
    docker_base: str | None = None,
    requirements: list[str] | None = None,
):
    """
    Class decorator that marks a class as a Streamtrace app.

    Supports both the bare and the configured form:

        @st.app
        class MyPipeline: ...

        @st.app(title="Cardiac Seg", version="1.0.0",
                docker_base="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
                requirements=["nibabel", "scipy"])
        class CardiacSegmentation: ...

    Raises:
        TypeError: if called with a positional argument that is not a class
            (e.g. ``@st.app("My Title")`` — options must be keywords).
    """
    def decorator(cls):
        cls.__app_metadata__ = AppMetadata(
            title=title,
            version=version,
            description=description,
            docker_base=docker_base,
            # Copy into a fresh list so a caller-held list can't mutate
            # the (frozen) metadata after decoration.
            requirements=list(requirements) if requirements else [],
        )
        return cls

    if target is not None:
        # Guard against @st.app("Title") — a common misuse where options are
        # passed positionally. Fail fast with a clear error instead of an
        # obscure AttributeError when metadata is attached to a non-class.
        if not isinstance(target, type):
            raise TypeError(
                "@st.app positional argument must be a class; "
                "pass options as keywords, e.g. @st.app(title=...)"
            )
        return decorator(target)
    return decorator
70
+
71
+
72
def get_app_metadata(cls) -> AppMetadata | None:
    """Return the AppMetadata attached by @st.app, or None for an undecorated class."""
    try:
        return cls.__app_metadata__
    except AttributeError:
        return None
74
+
75
+
76
def is_app(cls) -> bool:
    """True when cls carries @st.app metadata (the marker attribute exists)."""
    try:
        cls.__app_metadata__
    except AttributeError:
        return False
    return True
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+
7
+ # ===================================================================
8
+ # @st.infer — inference/model step decorator
9
+ # ===================================================================
10
+
11
@dataclass(frozen=True)
class InferMetadata:
    """Metadata for an inference node."""

    title: str = "Inference"
    returns: str | None = None
    # Hint for the runtime about GPU requirements:
    # "auto", "cpu", "cuda", "cuda:0", etc. — "auto" picks GPU if available.
    device: str = "auto"
    # Optional path to model weights — used by `push` to
    # ensure weights are available in the container.
    weights_path: str | None = None

    def to_schema(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict; optional fields are omitted when unset."""
        schema: dict[str, Any] = {
            "title": self.title,
            "device": self.device,
        }
        if self.returns:
            schema["output_key"] = self.returns
        if self.weights_path:
            schema["weights_path"] = self.weights_path
        return schema


def infer(
    target=None,
    *,
    title: str | None = None,
    returns: str | None = None,
    device: str = "auto",
    weights_path: str | None = None,
):
    """
    Decorator that marks a method as an inference step.

    Inference nodes are where the model runs. They receive preprocessed
    data and return predictions. The `device` hint tells the runtime
    where to run — "auto" picks GPU if available. The `weights_path` is
    informational — `push` uses it to verify weights are bundled or
    accessible in the container. When `title` is omitted it is derived
    from the method name ("run_model" → "Run Model").

    Usage:
        @st.infer(returns="mask",
                  device="cuda",
                  weights_path="checkpoints/best.pth")
        def segment(self, resampled): ...

        @st.infer(returns="embedding", device="cpu")
        def encode(self, text): ...

    Raises:
        TypeError: if called with a non-callable positional argument
            (e.g. ``@st.infer("Segment")`` — options must be keywords).
    """
    def decorator(fn):
        fn.__infer_metadata__ = InferMetadata(
            title=title or fn.__name__.replace("_", " ").title(),
            returns=returns,
            device=device,
            weights_path=weights_path,
        )
        return fn

    if target is not None:
        # Guard against @st.infer("Title") — fail fast with a clear message
        # instead of an AttributeError when attaching metadata to a non-function.
        if not callable(target):
            raise TypeError(
                "@st.infer positional argument must be the decorated function; "
                "pass options as keywords, e.g. @st.infer(title=...)"
            )
        return decorator(target)
    return decorator
81
+
82
+
83
def get_infer_metadata(obj) -> InferMetadata | None:
    """Return the InferMetadata attached by @st.infer, or None if obj is undecorated."""
    try:
        return obj.__infer_metadata__
    except AttributeError:
        return None
85
+
86
+
87
def is_infer(obj) -> bool:
    """True when obj carries @st.infer metadata (the marker attribute exists)."""
    try:
        obj.__infer_metadata__
    except AttributeError:
        return False
    return True