pyoco 0.3.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoco/cli/main.py +182 -23
- pyoco/client.py +29 -9
- pyoco/core/context.py +81 -1
- pyoco/core/engine.py +182 -3
- pyoco/core/exceptions.py +15 -0
- pyoco/core/models.py +130 -1
- pyoco/discovery/loader.py +32 -1
- pyoco/discovery/plugins.py +148 -0
- pyoco/dsl/expressions.py +160 -0
- pyoco/dsl/nodes.py +56 -0
- pyoco/dsl/syntax.py +241 -95
- pyoco/dsl/validator.py +104 -0
- pyoco/server/api.py +59 -18
- pyoco/server/metrics.py +113 -0
- pyoco/server/models.py +2 -0
- pyoco/server/store.py +153 -16
- pyoco/server/webhook.py +108 -0
- pyoco/socketless_reset.py +7 -0
- pyoco/worker/runner.py +3 -8
- {pyoco-0.3.0.dist-info → pyoco-0.5.1.dist-info}/METADATA +16 -1
- pyoco-0.5.1.dist-info/RECORD +33 -0
- pyoco-0.3.0.dist-info/RECORD +0 -25
- {pyoco-0.3.0.dist-info → pyoco-0.5.1.dist-info}/WHEEL +0 -0
- {pyoco-0.3.0.dist-info → pyoco-0.5.1.dist-info}/top_level.txt +0 -0
pyoco/cli/main.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
import json
|
|
2
3
|
import sys
|
|
3
4
|
import os
|
|
4
5
|
import signal
|
|
6
|
+
import time
|
|
7
|
+
from types import SimpleNamespace
|
|
5
8
|
from ..schemas.config import PyocoConfig
|
|
6
9
|
from ..discovery.loader import TaskLoader
|
|
7
10
|
from ..core.models import Flow
|
|
@@ -28,6 +31,8 @@ def main():
|
|
|
28
31
|
check_parser = subparsers.add_parser("check", help="Verify a workflow")
|
|
29
32
|
check_parser.add_argument("--config", required=True, help="Path to flow.yaml")
|
|
30
33
|
check_parser.add_argument("--flow", default="main", help="Flow name to check")
|
|
34
|
+
check_parser.add_argument("--dry-run", action="store_true", help="Traverse flow without executing tasks")
|
|
35
|
+
check_parser.add_argument("--json", action="store_true", help="Output report as JSON")
|
|
31
36
|
|
|
32
37
|
# List tasks command
|
|
33
38
|
list_parser = subparsers.add_parser("list-tasks", help="List available tasks")
|
|
@@ -55,6 +60,8 @@ def main():
|
|
|
55
60
|
runs_list = runs_subparsers.add_parser("list", help="List runs")
|
|
56
61
|
runs_list.add_argument("--server", default="http://localhost:8000", help="Server URL")
|
|
57
62
|
runs_list.add_argument("--status", help="Filter by status")
|
|
63
|
+
runs_list.add_argument("--flow", help="Filter by flow name")
|
|
64
|
+
runs_list.add_argument("--limit", type=int, help="Maximum number of runs to show")
|
|
58
65
|
|
|
59
66
|
runs_show = runs_subparsers.add_parser("show", help="Show run details")
|
|
60
67
|
runs_show.add_argument("run_id", help="Run ID")
|
|
@@ -63,6 +70,26 @@ def main():
|
|
|
63
70
|
runs_cancel = runs_subparsers.add_parser("cancel", help="Cancel a run")
|
|
64
71
|
runs_cancel.add_argument("run_id", help="Run ID")
|
|
65
72
|
runs_cancel.add_argument("--server", default="http://localhost:8000", help="Server URL")
|
|
73
|
+
|
|
74
|
+
runs_inspect = runs_subparsers.add_parser("inspect", help="Inspect run details")
|
|
75
|
+
runs_inspect.add_argument("run_id", help="Run ID")
|
|
76
|
+
runs_inspect.add_argument("--server", default="http://localhost:8000", help="Server URL")
|
|
77
|
+
runs_inspect.add_argument("--json", action="store_true", help="Output JSON payload")
|
|
78
|
+
|
|
79
|
+
runs_logs = runs_subparsers.add_parser("logs", help="Show run logs")
|
|
80
|
+
runs_logs.add_argument("run_id", help="Run ID")
|
|
81
|
+
runs_logs.add_argument("--server", default="http://localhost:8000", help="Server URL")
|
|
82
|
+
runs_logs.add_argument("--task", help="Filter logs by task")
|
|
83
|
+
runs_logs.add_argument("--tail", type=int, help="Show last N log entries")
|
|
84
|
+
runs_logs.add_argument("--follow", action="store_true", help="Stream logs until completion")
|
|
85
|
+
runs_logs.add_argument("--allow-failure", action="store_true", help="Don't exit non-zero when run failed")
|
|
86
|
+
|
|
87
|
+
plugins_parser = subparsers.add_parser("plugins", help="Inspect plug-in entry points")
|
|
88
|
+
plugins_sub = plugins_parser.add_subparsers(dest="plugins_command")
|
|
89
|
+
plugins_list = plugins_sub.add_parser("list", help="List discovered plug-ins")
|
|
90
|
+
plugins_list.add_argument("--json", action="store_true", help="Output JSON payload")
|
|
91
|
+
plugins_lint = plugins_sub.add_parser("lint", help="Validate plug-ins for upcoming requirements")
|
|
92
|
+
plugins_lint.add_argument("--json", action="store_true", help="Output JSON payload")
|
|
66
93
|
|
|
67
94
|
args = parser.parse_args()
|
|
68
95
|
|
|
@@ -94,6 +121,55 @@ def main():
|
|
|
94
121
|
print(f" - {name}")
|
|
95
122
|
return
|
|
96
123
|
|
|
124
|
+
if args.command == "plugins":
|
|
125
|
+
reports = _collect_plugin_reports()
|
|
126
|
+
if args.plugins_command == "list":
|
|
127
|
+
if getattr(args, "json", False):
|
|
128
|
+
print(json.dumps(reports, indent=2))
|
|
129
|
+
else:
|
|
130
|
+
if not reports:
|
|
131
|
+
print("No plug-ins registered under group 'pyoco.tasks'.")
|
|
132
|
+
else:
|
|
133
|
+
print("Discovered plug-ins:")
|
|
134
|
+
for info in reports:
|
|
135
|
+
mod = info.get("module") or info.get("value")
|
|
136
|
+
print(f" - {info.get('name')} ({mod})")
|
|
137
|
+
if info.get("error"):
|
|
138
|
+
print(f" ⚠️ error: {info['error']}")
|
|
139
|
+
continue
|
|
140
|
+
for task in info.get("tasks", []):
|
|
141
|
+
warn_msg = "; ".join(task.get("warnings", [])) or "ok"
|
|
142
|
+
print(f" • {task['name']} [{task['origin']}] ({warn_msg})")
|
|
143
|
+
for warn in info.get("warnings", []):
|
|
144
|
+
print(f" ⚠️ {warn}")
|
|
145
|
+
elif args.plugins_command == "lint":
|
|
146
|
+
issues = []
|
|
147
|
+
for info in reports:
|
|
148
|
+
prefix = info["name"]
|
|
149
|
+
if info.get("error"):
|
|
150
|
+
issues.append(f"{prefix}: {info['error']}")
|
|
151
|
+
continue
|
|
152
|
+
for warn in info.get("warnings", []):
|
|
153
|
+
issues.append(f"{prefix}: {warn}")
|
|
154
|
+
for task in info.get("tasks", []):
|
|
155
|
+
for warn in task.get("warnings", []):
|
|
156
|
+
issues.append(f"{prefix}.{task['name']}: {warn}")
|
|
157
|
+
payload = {"issues": issues, "reports": reports}
|
|
158
|
+
if getattr(args, "json", False):
|
|
159
|
+
print(json.dumps(payload, indent=2))
|
|
160
|
+
else:
|
|
161
|
+
if not issues:
|
|
162
|
+
print("✅ All plug-ins look good.")
|
|
163
|
+
else:
|
|
164
|
+
print("⚠️ Plug-in issues found:")
|
|
165
|
+
for issue in issues:
|
|
166
|
+
print(f" - {issue}")
|
|
167
|
+
if issues:
|
|
168
|
+
sys.exit(1)
|
|
169
|
+
else:
|
|
170
|
+
plugins_parser.print_help()
|
|
171
|
+
return
|
|
172
|
+
|
|
97
173
|
if args.command == "server":
|
|
98
174
|
if args.server_command == "start":
|
|
99
175
|
import uvicorn
|
|
@@ -113,7 +189,7 @@ def main():
|
|
|
113
189
|
client = Client(args.server)
|
|
114
190
|
try:
|
|
115
191
|
if args.runs_command == "list":
|
|
116
|
-
runs = client.list_runs(status=args.status)
|
|
192
|
+
runs = client.list_runs(status=args.status, flow=args.flow, limit=args.limit)
|
|
117
193
|
print(f"🐇 Active Runs ({len(runs)}):")
|
|
118
194
|
print(f"{'ID':<36} | {'Status':<12} | {'Flow':<15}")
|
|
119
195
|
print("-" * 70)
|
|
@@ -134,6 +210,31 @@ def main():
|
|
|
134
210
|
elif args.runs_command == "cancel":
|
|
135
211
|
client.cancel_run(args.run_id)
|
|
136
212
|
print(f"🛑 Cancellation requested for run {args.run_id}")
|
|
213
|
+
elif args.runs_command == "inspect":
|
|
214
|
+
run = client.get_run(args.run_id)
|
|
215
|
+
if args.json:
|
|
216
|
+
print(json.dumps(run, indent=2))
|
|
217
|
+
else:
|
|
218
|
+
print(f"🐇 Run: {run['run_id']} ({run.get('flow_name', 'n/a')})")
|
|
219
|
+
print(f"Status: {run['status']}")
|
|
220
|
+
if run.get("start_time"):
|
|
221
|
+
print(f"Started: {run['start_time']}")
|
|
222
|
+
if run.get("end_time"):
|
|
223
|
+
print(f"Ended: {run['end_time']}")
|
|
224
|
+
print("Tasks:")
|
|
225
|
+
records = run.get("task_records", {})
|
|
226
|
+
for name, info in records.items():
|
|
227
|
+
state = info.get("state", run["tasks"].get(name))
|
|
228
|
+
duration = info.get("duration_ms")
|
|
229
|
+
duration_str = f"{duration:.2f} ms" if duration else "-"
|
|
230
|
+
print(f" - {name}: {state} ({duration_str})")
|
|
231
|
+
if info.get("error"):
|
|
232
|
+
print(f" error: {info['error']}")
|
|
233
|
+
if not records:
|
|
234
|
+
for t_name, t_state in run.get("tasks", {}).items():
|
|
235
|
+
print(f" - {t_name}: {t_state}")
|
|
236
|
+
elif args.runs_command == "logs":
|
|
237
|
+
_stream_logs(client, args)
|
|
137
238
|
except Exception as e:
|
|
138
239
|
print(f"Error: {e}")
|
|
139
240
|
return
|
|
@@ -164,14 +265,16 @@ def main():
|
|
|
164
265
|
sys.exit(1)
|
|
165
266
|
return
|
|
166
267
|
# Build Flow from graph string
|
|
167
|
-
from ..dsl.syntax import TaskWrapper
|
|
268
|
+
from ..dsl.syntax import TaskWrapper, switch
|
|
168
269
|
eval_context = {name: TaskWrapper(task) for name, task in loader.tasks.items()}
|
|
270
|
+
eval_context["switch"] = switch
|
|
169
271
|
|
|
170
272
|
try:
|
|
171
273
|
# Create Flow and add all loaded tasks
|
|
172
274
|
flow = Flow(name=args.flow)
|
|
173
275
|
for t in loader.tasks.values():
|
|
174
276
|
flow.add_task(t)
|
|
277
|
+
eval_context["flow"] = flow
|
|
175
278
|
|
|
176
279
|
# Evaluate graph to set up dependencies
|
|
177
280
|
exec(flow_conf.graph, {}, eval_context)
|
|
@@ -210,36 +313,36 @@ def main():
|
|
|
210
313
|
|
|
211
314
|
# 1. Check imports (already done by loader.load(), but we can check for missing tasks in graph)
|
|
212
315
|
# 2. Build flow to check graph
|
|
213
|
-
from ..dsl.syntax import TaskWrapper
|
|
316
|
+
from ..dsl.syntax import TaskWrapper, switch
|
|
214
317
|
eval_context = {name: TaskWrapper(task) for name, task in loader.tasks.items()}
|
|
318
|
+
eval_context["switch"] = switch
|
|
215
319
|
|
|
216
320
|
try:
|
|
217
321
|
flow = Flow(name=args.flow)
|
|
218
322
|
for t in loader.tasks.values():
|
|
219
323
|
flow.add_task(t)
|
|
324
|
+
eval_context["flow"] = flow
|
|
220
325
|
|
|
221
326
|
eval(flow_conf.graph, {}, eval_context)
|
|
222
327
|
|
|
223
328
|
# 3. Reachability / Orphans
|
|
224
|
-
# Nodes with no deps and no dependents (except if single node flow)
|
|
225
329
|
if len(flow.tasks) > 1:
|
|
226
330
|
for t in flow.tasks:
|
|
227
331
|
if not t.dependencies and not t.dependents:
|
|
228
332
|
warnings.append(f"Task '{t.name}' is orphaned (no dependencies or dependents).")
|
|
229
333
|
|
|
230
334
|
# 4. Cycles
|
|
231
|
-
# Simple DFS for cycle detection
|
|
232
335
|
visited = set()
|
|
233
336
|
path = set()
|
|
337
|
+
|
|
234
338
|
def visit(node):
|
|
235
339
|
if node in path:
|
|
236
|
-
return True
|
|
340
|
+
return True
|
|
237
341
|
if node in visited:
|
|
238
342
|
return False
|
|
239
|
-
|
|
240
343
|
visited.add(node)
|
|
241
344
|
path.add(node)
|
|
242
|
-
for dep in node.dependencies:
|
|
345
|
+
for dep in node.dependencies:
|
|
243
346
|
if visit(dep):
|
|
244
347
|
return True
|
|
245
348
|
path.remove(node)
|
|
@@ -255,29 +358,85 @@ def main():
|
|
|
255
358
|
for t in flow.tasks:
|
|
256
359
|
sig = inspect.signature(t.func)
|
|
257
360
|
for name, param in sig.parameters.items():
|
|
258
|
-
if name == 'ctx':
|
|
259
|
-
|
|
260
|
-
# This is hard because inputs are resolved at runtime.
|
|
261
|
-
# But we can check if 'inputs' mapping exists for it.
|
|
361
|
+
if name == 'ctx':
|
|
362
|
+
continue
|
|
262
363
|
if name not in t.inputs and name not in flow_conf.defaults:
|
|
263
|
-
# Warning: might be missing input
|
|
264
364
|
warnings.append(f"Task '{t.name}' argument '{name}' might be missing input (not in inputs or defaults).")
|
|
265
365
|
|
|
266
366
|
except Exception as e:
|
|
267
367
|
errors.append(f"Graph evaluation failed: {e}")
|
|
268
368
|
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
369
|
+
if args.dry_run:
|
|
370
|
+
from ..dsl.validator import FlowValidator
|
|
371
|
+
try:
|
|
372
|
+
validator = FlowValidator(flow)
|
|
373
|
+
dr_report = validator.validate()
|
|
374
|
+
warnings.extend(dr_report.warnings)
|
|
375
|
+
errors.extend(dr_report.errors)
|
|
376
|
+
except Exception as exc:
|
|
377
|
+
print(f"❌ Dry run internal error: {exc}")
|
|
378
|
+
import traceback
|
|
379
|
+
traceback.print_exc()
|
|
380
|
+
sys.exit(3)
|
|
381
|
+
|
|
382
|
+
status = "ok"
|
|
383
|
+
if errors:
|
|
384
|
+
status = "error"
|
|
385
|
+
elif warnings:
|
|
386
|
+
status = "warning"
|
|
387
|
+
|
|
388
|
+
report = {"status": status, "warnings": warnings, "errors": errors}
|
|
389
|
+
|
|
390
|
+
if args.json:
|
|
391
|
+
print(json.dumps(report, indent=2))
|
|
273
392
|
else:
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
print(
|
|
278
|
-
|
|
279
|
-
|
|
393
|
+
print("\n--- Check Report ---")
|
|
394
|
+
print(f"Status: {status}")
|
|
395
|
+
if not errors and not warnings:
|
|
396
|
+
print("✅ All checks passed!")
|
|
397
|
+
else:
|
|
398
|
+
for w in warnings:
|
|
399
|
+
print(f"⚠️ {w}")
|
|
400
|
+
for e in errors:
|
|
401
|
+
print(f"❌ {e}")
|
|
402
|
+
|
|
403
|
+
if errors:
|
|
404
|
+
sys.exit(2 if args.dry_run else 1)
|
|
405
|
+
return
|
|
406
|
+
|
|
407
|
+
def _collect_plugin_reports():
    """Run TaskLoader against an empty stand-in config and return its plug-in reports.

    The stand-in config carries no tasks and empty ``entry_points``, ``packages``
    and ``glob_modules`` discovery lists, so the returned reports reflect only what
    ``TaskLoader.load()`` finds on its own (presumably the ``pyoco.tasks``
    entry-point group mentioned by the ``plugins list`` output — TODO confirm).

    Returns:
        ``loader.plugin_reports`` after ``load()`` has run.
    """
    empty_discovery = SimpleNamespace(entry_points=[], packages=[], glob_modules=[])
    stub_config = SimpleNamespace(tasks={}, discovery=empty_discovery)
    task_loader = TaskLoader(stub_config)
    task_loader.load()
    return task_loader.plugin_reports
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _stream_logs(client, args):
|
|
418
|
+
seen_seq = -1
|
|
419
|
+
follow = args.follow
|
|
420
|
+
while True:
|
|
421
|
+
tail = args.tail if (args.tail and seen_seq == -1 and not follow) else None
|
|
422
|
+
data = client.get_run_logs(args.run_id, task=args.task, tail=tail)
|
|
423
|
+
logs = data.get("logs", [])
|
|
424
|
+
logs.sort(key=lambda entry: entry.get("seq", 0))
|
|
425
|
+
for entry in logs:
|
|
426
|
+
seq = entry.get("seq", 0)
|
|
427
|
+
if seq <= seen_seq:
|
|
428
|
+
continue
|
|
429
|
+
line = entry.get("text", "")
|
|
430
|
+
line = line.rstrip("\n")
|
|
431
|
+
print(f"[{entry.get('task', 'unknown')}][{entry.get('stream', '')}] {line}")
|
|
432
|
+
seen_seq = seq
|
|
433
|
+
status = data.get("run_status", "UNKNOWN")
|
|
434
|
+
if not follow or status in ("COMPLETED", "FAILED", "CANCELLED"):
|
|
435
|
+
if status == "FAILED" and not args.allow_failure:
|
|
280
436
|
sys.exit(1)
|
|
437
|
+
break
|
|
438
|
+
time.sleep(1)
|
|
439
|
+
|
|
281
440
|
|
|
282
441
|
if __name__ == "__main__":
|
|
283
442
|
main()
|
pyoco/client.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import httpx
|
|
2
2
|
from typing import Dict, List, Optional, Any
|
|
3
|
-
from .core.models import RunStatus, TaskState
|
|
3
|
+
from .core.models import RunStatus, TaskState, RunContext
|
|
4
4
|
|
|
5
5
|
class Client:
|
|
6
6
|
def __init__(self, server_url: str, client_id: str = "cli"):
|
|
@@ -17,10 +17,19 @@ class Client:
|
|
|
17
17
|
resp.raise_for_status()
|
|
18
18
|
return resp.json()["run_id"]
|
|
19
19
|
|
|
20
|
-
def list_runs(
|
|
20
|
+
def list_runs(
    self,
    status: Optional[str] = None,
    flow: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Dict]:
    """Return runs from ``GET /runs``, optionally filtered.

    Only truthy filters are sent as query parameters, so ``limit=0`` or an empty
    string behaves the same as omitting the filter. HTTP error responses raise via
    ``raise_for_status()``.
    """
    candidates = (("status", status), ("flow", flow), ("limit", limit))
    query = {key: value for key, value in candidates if value}
    response = self.client.get("/runs", params=query)
    response.raise_for_status()
    return response.json()
|
|
@@ -49,21 +58,32 @@ class Client:
|
|
|
49
58
|
# print(f"Poll failed: {e}")
|
|
50
59
|
return None
|
|
51
60
|
|
|
52
|
-
def heartbeat(self,
|
|
61
|
+
def heartbeat(self, run_ctx: "RunContext") -> bool:
    """Send a best-effort heartbeat for ``run_ctx`` to the server.

    POSTs the task states, serialized task records, drained log entries and run
    status to ``/runs/{run_id}/heartbeat``. Any exception is printed and swallowed.

    Returns:
        The server's ``cancel_requested`` flag; ``False`` when absent or on error.
    """
    def _plain(value):
        # Enum-like values are sent by their ``.value``; anything else as-is.
        return value.value if hasattr(value, "value") else value

    try:
        task_states = {name: _plain(state) for name, state in run_ctx.tasks.items()}
        run_status = _plain(run_ctx.status)
        body = {
            "task_states": task_states,
            "task_records": run_ctx.serialize_task_records(),
            # NOTE(review): if drain_logs() empties the buffer, these entries appear to
            # be dropped when the POST below fails — confirm and consider re-queueing.
            "logs": run_ctx.drain_logs(),
            "run_status": run_status,
        }
        reply = self.client.post(f"/runs/{run_ctx.run_id}/heartbeat", json=body)
        reply.raise_for_status()
        return reply.json().get("cancel_requested", False)
    except Exception as exc:
        print(f"Heartbeat failed: {exc}")
        return False
|
|
80
|
+
|
|
81
|
+
def get_run_logs(self, run_id: str, task: Optional[str] = None, tail: Optional[int] = None) -> Dict[str, Any]:
    """Fetch log entries for ``run_id`` from ``GET /runs/{run_id}/logs``.

    ``task`` and ``tail`` are forwarded as query parameters only when truthy.
    HTTP error responses raise via ``raise_for_status()``.
    """
    query: Dict[str, Any] = {}
    for key, value in (("task", task), ("tail", tail)):
        if value:
            query[key] = value
    response = self.client.get(f"/runs/{run_id}/logs", params=query)
    response.raise_for_status()
    return response.json()
|
pyoco/core/context.py
CHANGED
|
@@ -1,8 +1,46 @@
|
|
|
1
1
|
import threading
|
|
2
|
-
from typing import Any, Dict, List, Optional
|
|
2
|
+
from typing import Any, Dict, List, Optional, Sequence
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from .models import RunContext
|
|
5
5
|
|
|
6
|
+
|
|
7
|
+
@dataclass
class LoopFrame:
    """One level of loop/branch nesting tracked by ``LoopStack``.

    ``name`` and ``type`` are required; every other field defaults to ``None``.
    ``path`` is normally filled in by ``LoopStack.push``.
    """

    # Identifier used as the path segment for this frame.
    name: str
    # Kind of frame; the set of values is defined by callers (not visible here).
    type: str
    # When set, the path segment becomes ``name[index]``.
    index: Optional[int] = field(default=None)
    # Per-frame counters/values; their meaning is set by callers — TODO confirm.
    iteration: Optional[int] = field(default=None)
    count: Optional[int] = field(default=None)
    item: Any = field(default=None)
    condition: Optional[bool] = field(default=None)
    # Dotted path from the outermost frame, e.g. ``outer.inner[2]``.
    path: Optional[str] = field(default=None)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LoopStack:
    """LIFO stack of ``LoopFrame`` objects that assigns each pushed frame a dotted path."""

    def __init__(self):
        self._frames: List["LoopFrame"] = []

    def push(self, frame: "LoopFrame") -> "LoopFrame":
        """Set ``frame.path``, push the frame and return it.

        The path is this frame's segment (``name``, or ``name[index]`` when
        ``index`` is set), prefixed by the enclosing frame's path and a ``.``
        when an enclosing frame with a non-empty path exists.
        """
        segment = frame.name if frame.index is None else f"{frame.name}[{frame.index}]"
        enclosing = self.current
        parent_path = enclosing.path if enclosing is not None else ""
        frame.path = f"{parent_path}.{segment}" if parent_path else segment
        self._frames.append(frame)
        return frame

    def pop(self) -> "LoopFrame":
        """Remove and return the innermost frame.

        Raises:
            RuntimeError: if the stack is empty.
        """
        if self._frames:
            return self._frames.pop()
        raise RuntimeError("Loop stack underflow")

    @property
    def current(self) -> Optional["LoopFrame"]:
        """The innermost frame, or ``None`` when the stack is empty."""
        try:
            return self._frames[-1]
        except IndexError:
            return None

    def snapshot(self) -> Sequence["LoopFrame"]:
        """Return the frames as a tuple, outermost first."""
        return (*self._frames,)
|
|
43
|
+
|
|
6
44
|
@dataclass
|
|
7
45
|
class Context:
|
|
8
46
|
"""
|
|
@@ -14,11 +52,13 @@ class Context:
|
|
|
14
52
|
artifacts: Dict[str, Any] = field(default_factory=dict)
|
|
15
53
|
env: Dict[str, str] = field(default_factory=dict)
|
|
16
54
|
artifact_dir: Optional[str] = None
|
|
55
|
+
_vars: Dict[str, Any] = field(default_factory=dict, repr=False)
|
|
17
56
|
|
|
18
57
|
# Reference to the parent run context (v0.2.0+)
|
|
19
58
|
run_context: Optional[RunContext] = None
|
|
20
59
|
|
|
21
60
|
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
|
61
|
+
_loop_stack: LoopStack = field(default_factory=LoopStack, repr=False)
|
|
22
62
|
|
|
23
63
|
@property
|
|
24
64
|
def is_cancelled(self) -> bool:
|
|
@@ -27,6 +67,29 @@ class Context:
|
|
|
27
67
|
return self.run_context.status in [RunStatus.CANCELLING, RunStatus.CANCELLED]
|
|
28
68
|
return False
|
|
29
69
|
|
|
70
|
+
@property
def loop(self) -> Optional["LoopFrame"]:
    """The innermost active loop frame, or ``None`` when no frame is pushed."""
    stack = self._loop_stack
    return stack.current
|
|
73
|
+
|
|
74
|
+
@property
def loops(self) -> Sequence["LoopFrame"]:
    """An immutable snapshot of every active loop frame, outermost first."""
    stack = self._loop_stack
    return stack.snapshot()
|
|
77
|
+
|
|
78
|
+
def push_loop(self, frame: "LoopFrame") -> "LoopFrame":
    """Push ``frame`` onto this context's loop stack and return it (with its path set)."""
    stack = self._loop_stack
    return stack.push(frame)
|
|
80
|
+
|
|
81
|
+
def pop_loop(self) -> "LoopFrame":
    """Pop and return the innermost loop frame (the stack raises RuntimeError when empty)."""
    stack = self._loop_stack
    return stack.pop()
|
|
83
|
+
|
|
84
|
+
def set_var(self, name: str, value: Any):
    """Store ``value`` under ``name`` in the context's variable table, replacing any previous value."""
    variables = self._vars
    variables[name] = value
|
|
86
|
+
|
|
87
|
+
def get_var(self, name: str, default=None):
    """Return the variable stored under ``name``, or ``default`` when it is not set."""
    try:
        return self._vars[name]
    except KeyError:
        return default
|
|
89
|
+
|
|
90
|
+
def clear_var(self, name: str):
    """Remove the variable ``name`` if present; unknown names are ignored."""
    variables = self._vars
    if name in variables:
        del variables[name]
|
|
92
|
+
|
|
30
93
|
def __post_init__(self):
|
|
31
94
|
# Ensure artifact directory exists
|
|
32
95
|
if self.artifact_dir is None:
|
|
@@ -124,3 +187,20 @@ class Context:
|
|
|
124
187
|
|
|
125
188
|
return value
|
|
126
189
|
|
|
190
|
+
def expression_data(self) -> Dict[str, Any]:
    """Return a fresh dict combining user variables with the context's core data.

    Variables stored via ``set_var`` are included first. The reserved keys
    ``params``, ``results``, ``scratch``, ``artifacts``, ``loop`` and ``loops``
    are applied afterwards, so they override any variable of the same name.
    ``loops`` is returned as a list copy of the frame snapshot.
    """
    reserved = {
        "params": self.params,
        "results": self.results,
        "scratch": self.scratch,
        "artifacts": self.artifacts,
        "loop": self.loop,
        "loops": list(self.loops),
    }
    return {**self._vars, **reserved}
|
|
200
|
+
|
|
201
|
+
def env_data(self) -> Dict[str, str]:
    """Return a new dict of the process environment overlaid with ``self.env``.

    Entries in ``self.env`` take precedence over same-named ``os.environ`` values.
    Neither source is modified.
    """
    import os

    return {**os.environ, **self.env}
|