flowyml-1.3.0-py3-none-any.whl → flowyml-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. flowyml/core/execution_status.py +1 -0
  2. flowyml/core/executor.py +175 -3
  3. flowyml/core/observability.py +7 -7
  4. flowyml/core/resources.py +12 -12
  5. flowyml/core/retry_policy.py +2 -2
  6. flowyml/core/scheduler.py +9 -9
  7. flowyml/core/scheduler_config.py +2 -3
  8. flowyml/core/submission_result.py +4 -4
  9. flowyml/stacks/bridge.py +9 -9
  10. flowyml/stacks/plugins.py +2 -2
  11. flowyml/stacks/registry.py +21 -0
  12. flowyml/storage/materializers/base.py +33 -0
  13. flowyml/storage/metadata.py +3 -1042
  14. flowyml/storage/remote.py +590 -0
  15. flowyml/storage/sql.py +951 -0
  16. flowyml/ui/backend/dependencies.py +28 -0
  17. flowyml/ui/backend/main.py +4 -79
  18. flowyml/ui/backend/routers/assets.py +170 -9
  19. flowyml/ui/backend/routers/client.py +6 -6
  20. flowyml/ui/backend/routers/execution.py +2 -2
  21. flowyml/ui/backend/routers/experiments.py +53 -6
  22. flowyml/ui/backend/routers/metrics.py +23 -68
  23. flowyml/ui/backend/routers/pipelines.py +19 -10
  24. flowyml/ui/backend/routers/runs.py +287 -9
  25. flowyml/ui/backend/routers/schedules.py +5 -21
  26. flowyml/ui/backend/routers/stats.py +14 -0
  27. flowyml/ui/backend/routers/traces.py +37 -53
  28. flowyml/ui/backend/routers/websocket.py +121 -0
  29. flowyml/ui/frontend/dist/assets/index-CBUXOWze.css +1 -0
  30. flowyml/ui/frontend/dist/assets/index-DF8dJaFL.js +629 -0
  31. flowyml/ui/frontend/dist/index.html +2 -2
  32. flowyml/ui/frontend/package-lock.json +289 -0
  33. flowyml/ui/frontend/package.json +1 -0
  34. flowyml/ui/frontend/src/app/compare/page.jsx +213 -0
  35. flowyml/ui/frontend/src/app/experiments/compare/page.jsx +289 -0
  36. flowyml/ui/frontend/src/app/experiments/page.jsx +61 -1
  37. flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +418 -203
  38. flowyml/ui/frontend/src/app/runs/page.jsx +64 -3
  39. flowyml/ui/frontend/src/app/settings/page.jsx +1 -1
  40. flowyml/ui/frontend/src/app/tokens/page.jsx +8 -6
  41. flowyml/ui/frontend/src/components/ArtifactViewer.jsx +159 -0
  42. flowyml/ui/frontend/src/components/NavigationTree.jsx +26 -9
  43. flowyml/ui/frontend/src/components/PipelineGraph.jsx +26 -24
  44. flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +42 -14
  45. flowyml/ui/frontend/src/router/index.jsx +4 -0
  46. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/METADATA +3 -1
  47. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/RECORD +50 -42
  48. flowyml/ui/frontend/dist/assets/index-DcYwrn2j.css +0 -1
  49. flowyml/ui/frontend/dist/assets/index-Dlz_ygOL.js +0 -592
  50. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/WHEEL +0 -0
  51. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/entry_points.txt +0 -0
  52. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/licenses/LICENSE +0 -0
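
The largest addition in this release is flowyml/storage/sql.py, a SQLAlchemy-backed implementation of MetadataStore that supports both SQLite and PostgreSQL (full diff below). As a quick orientation, here is a minimal usage sketch; the constructor arguments and method names come from the diff itself, while the run ID, parameter, and metric values are illustrative only:

    from flowyml.storage.sql import SQLMetadataStore

    # Default: a local SQLite file (parent directories are created automatically).
    store = SQLMetadataStore(db_path=".flowyml/metadata.db")

    # Also supported per the diff: a URL passed as db_path (backward compatible),
    # or an explicit db_url, which takes precedence.
    # store = SQLMetadataStore(db_url="postgresql://user:pass@host/flowyml")

    # save_run upserts the run row and rewrites its parameters and metrics.
    store.save_run(
        "run-001",
        {
            "pipeline_name": "training",
            "status": "completed",
            "parameters": {"lr": 0.01},
            "metrics": {"accuracy": 0.93},
        },
    )

    print(store.load_run("run-001"))     # the full metadata dict
    print(store.get_metrics("run-001"))  # [{"name": "accuracy", "value": 0.93, ...}]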
flowyml/storage/sql.py ADDED
@@ -0,0 +1,951 @@
+ """SQLAlchemy-based metadata storage backend."""
+
+ import json
+ import logging
+ from pathlib import Path
+ from datetime import datetime
+
+ # Python 3.11+ has UTC, but Python 3.10 doesn't
+ try:
+     from datetime import UTC
+ except ImportError:
+     from datetime import timezone
+
+     UTC = timezone.utc
+
+ from sqlalchemy import (
+     create_engine,
+     MetaData,
+     Table,
+     Column,
+     String,
+     Integer,
+     Float,
+     Text,
+     ForeignKey,
+     select,
+     insert,
+     update,
+     delete,
+     func,
+     text,
+     inspect,
+ )
+ from sqlalchemy.pool import StaticPool
+
+ from flowyml.storage.metadata import MetadataStore
+
+ logger = logging.getLogger(__name__)
+
+
+ class SQLMetadataStore(MetadataStore):
+     """SQLAlchemy-based metadata storage supporting SQLite and PostgreSQL."""
+
+     def __init__(self, db_path: str = ".flowyml/metadata.db", db_url: str | None = None):
+         """Initialize SQL metadata store.
+
+         Args:
+             db_path: Path to SQLite database file OR database URL (backward compatible)
+             db_url: Explicit database URL (takes precedence if provided)
+                 (e.g., sqlite:///path/to/db, postgresql://user:pass@host/db)
+         """
+         # Handle backward compatibility: if db_path looks like a URL, use it as db_url
+         if db_url:
+             self.db_url = db_url
+             # Store db_path for backward compatibility
+             if db_url.startswith("sqlite:///"):
+                 self.db_path = Path(db_url[10:])  # Remove 'sqlite:///'
+             else:
+                 self.db_path = None
+         elif db_path and ("://" in db_path or db_path.startswith("sqlite:")):
+             # db_path is actually a URL
+             self.db_url = db_path
+             if db_path.startswith("sqlite:///"):
+                 self.db_path = Path(db_path[10:])
+             else:
+                 self.db_path = None
+         else:
+             # db_path is a file path
+             self.db_path = Path(db_path)
+             self.db_path.parent.mkdir(parents=True, exist_ok=True)
+             # Ensure absolute path for SQLite URL
+             abs_path = self.db_path.resolve()
+             self.db_url = f"sqlite:///{abs_path}"
+
+         self._init_db()
+
+     def _init_db(self) -> None:
+         """Initialize database schema."""
+         # Configure engine
+         connect_args = {}
+         if self.db_url.startswith("sqlite"):
+             connect_args = {"check_same_thread": False}
+
+         self.engine = create_engine(
+             self.db_url,
+             connect_args=connect_args,
+             poolclass=StaticPool if self.db_url.startswith("sqlite") else None,
+         )
+
+         self.metadata = MetaData()
+
+         # Define tables
+         self.runs = Table(
+             "runs",
+             self.metadata,
+             Column("run_id", String, primary_key=True),
+             Column("pipeline_name", String),
+             Column("status", String),
+             Column("start_time", String),
+             Column("end_time", String),
+             Column("duration", Float),
+             Column("metadata", Text),
+             Column("project", String),
+             Column("created_at", String, server_default=func.current_timestamp()),
+         )
+
+         self.artifacts = Table(
+             "artifacts",
+             self.metadata,
+             Column("artifact_id", String, primary_key=True),
+             Column("name", String),
+             Column("type", String),
+             Column("run_id", String, ForeignKey("runs.run_id")),
+             Column("path", String),
+             Column("metadata", Text),
+             Column("project", String),
+             Column("created_at", String, server_default=func.current_timestamp()),
+         )
+
+         self.metrics = Table(
+             "metrics",
+             self.metadata,
+             Column("id", Integer, primary_key=True, autoincrement=True),
+             Column("run_id", String, ForeignKey("runs.run_id")),
+             Column("name", String),
+             Column("value", Float),
+             Column("step", Integer),
+             Column("timestamp", String, server_default=func.current_timestamp()),
+         )
+
+         self.model_metrics = Table(
+             "model_metrics",
+             self.metadata,
+             Column("id", Integer, primary_key=True, autoincrement=True),
+             Column("project", String),
+             Column("model_name", String),
+             Column("run_id", String),
+             Column("metric_name", String),
+             Column("metric_value", Float),
+             Column("environment", String),
+             Column("tags", Text),
+             Column("created_at", String, server_default=func.current_timestamp()),
+         )
+
+         self.parameters = Table(
+             "parameters",
+             self.metadata,
+             Column("id", Integer, primary_key=True, autoincrement=True),
+             Column("run_id", String, ForeignKey("runs.run_id")),
+             Column("name", String),
+             Column("value", Text),
+         )
+
+         self.experiments = Table(
+             "experiments",
+             self.metadata,
+             Column("experiment_id", String, primary_key=True),
+             Column("name", String),
+             Column("description", Text),
+             Column("tags", Text),
+             Column("project", String),
+             Column("created_at", String, server_default=func.current_timestamp()),
+         )
+
+         self.experiment_runs = Table(
+             "experiment_runs",
+             self.metadata,
+             Column("experiment_id", String, ForeignKey("experiments.experiment_id"), primary_key=True),
+             Column("run_id", String, ForeignKey("runs.run_id"), primary_key=True),
+             Column("metrics", Text),
+             Column("parameters", Text),
+             Column("timestamp", String, server_default=func.current_timestamp()),
+         )
+
+         self.traces = Table(
+             "traces",
+             self.metadata,
+             Column("event_id", String, primary_key=True),
+             Column("trace_id", String),
+             Column("parent_id", String),
+             Column("event_type", String),
+             Column("name", String),
+             Column("inputs", Text),
+             Column("outputs", Text),
+             Column("start_time", Float),
+             Column("end_time", Float),
+             Column("duration", Float),
+             Column("status", String),
+             Column("error", Text),
+             Column("metadata", Text),
+             Column("prompt_tokens", Integer),
+             Column("completion_tokens", Integer),
+             Column("total_tokens", Integer),
+             Column("cost", Float),
+             Column("model", String),
+             Column("project", String),
+             Column("created_at", String, server_default=func.current_timestamp()),
+         )
+
+         self.pipeline_definitions = Table(
+             "pipeline_definitions",
+             self.metadata,
+             Column("pipeline_name", String, primary_key=True),
+             Column("definition", Text, nullable=False),
+             Column("created_at", String, nullable=False),
+             Column("updated_at", String, nullable=False),
+         )
+
+         # Create tables
+         self.metadata.create_all(self.engine)
+
+         # Handle migrations (add missing columns if tables existed)
+         self._migrate_schema()
+
+     def _migrate_schema(self) -> None:
+         """Add missing columns if needed."""
+         inspector = inspect(self.engine)
+
+         # Check runs.project
+         columns = [c["name"] for c in inspector.get_columns("runs")]
+         if "project" not in columns:
+             with self.engine.connect() as conn:
+                 conn.execute(text("ALTER TABLE runs ADD COLUMN project VARCHAR"))
+                 conn.commit()
+
+         # Check artifacts.project
+         columns = [c["name"] for c in inspector.get_columns("artifacts")]
+         if "project" not in columns:
+             with self.engine.connect() as conn:
+                 conn.execute(text("ALTER TABLE artifacts ADD COLUMN project VARCHAR"))
+                 conn.commit()
+
+         # Check experiments.project
+         columns = [c["name"] for c in inspector.get_columns("experiments")]
+         if "project" not in columns:
+             with self.engine.connect() as conn:
+                 conn.execute(text("ALTER TABLE experiments ADD COLUMN project VARCHAR"))
+                 conn.commit()
+
+         # Check traces.project
+         columns = [c["name"] for c in inspector.get_columns("traces")]
+         if "project" not in columns:
+             with self.engine.connect() as conn:
+                 conn.execute(text("ALTER TABLE traces ADD COLUMN project VARCHAR"))
+                 conn.commit()
+
+     def save_run(self, run_id: str, metadata: dict) -> None:
+         """Save run metadata."""
+         with self.engine.connect() as conn:
+             # Upsert run
+             stmt = select(self.runs).where(self.runs.c.run_id == run_id)
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "run_id": run_id,
+                 "pipeline_name": metadata.get("pipeline_name"),
+                 "status": metadata.get("status"),
+                 "start_time": metadata.get("start_time"),
+                 "end_time": metadata.get("end_time"),
+                 "duration": metadata.get("duration"),
+                 "metadata": json.dumps(metadata),
+                 "project": metadata.get("project"),
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.runs).where(self.runs.c.run_id == run_id).values(**values),
+                 )
+             else:
+                 conn.execute(insert(self.runs).values(**values))
+
+             # Save parameters
+             if "parameters" in metadata:
+                 conn.execute(delete(self.parameters).where(self.parameters.c.run_id == run_id))
+                 if metadata["parameters"]:
+                     conn.execute(
+                         insert(self.parameters),
+                         [
+                             {"run_id": run_id, "name": k, "value": json.dumps(v)}
+                             for k, v in metadata["parameters"].items()
+                         ],
+                     )
+
+             # Save metrics
+             if "metrics" in metadata:
+                 conn.execute(delete(self.metrics).where(self.metrics.c.run_id == run_id))
+                 if metadata["metrics"]:
+                     conn.execute(
+                         insert(self.metrics),
+                         [
+                             {"run_id": run_id, "name": k, "value": float(v), "step": 0}
+                             for k, v in metadata["metrics"].items()
+                         ],
+                     )
+
+             conn.commit()
+
+     def load_run(self, run_id: str) -> dict | None:
+         """Load run metadata."""
+         with self.engine.connect() as conn:
+             stmt = select(self.runs.c.metadata).where(self.runs.c.run_id == run_id)
+             row = conn.execute(stmt).fetchone()
+             if row:
+                 return json.loads(row[0])
+             return None
+
+     def update_run_project(self, run_id: str, project_name: str) -> None:
+         """Update the project for a run."""
+         with self.engine.connect() as conn:
+             # Update column
+             conn.execute(
+                 update(self.runs).where(self.runs.c.run_id == run_id).values(project=project_name),
+             )
+
+             # Update JSON blob
+             stmt = select(self.runs.c.metadata).where(self.runs.c.run_id == run_id)
+             row = conn.execute(stmt).fetchone()
+             if row:
+                 metadata = json.loads(row[0])
+                 metadata["project"] = project_name
+                 conn.execute(
+                     update(self.runs).where(self.runs.c.run_id == run_id).values(metadata=json.dumps(metadata)),
+                 )
+
+             conn.commit()
+
+     def list_runs(self, limit: int | None = None) -> list[dict]:
+         """List all runs."""
+         with self.engine.connect() as conn:
+             stmt = select(self.runs.c.metadata).order_by(self.runs.c.created_at.desc())
+             if limit:
+                 stmt = stmt.limit(limit)
+
+             rows = conn.execute(stmt).fetchall()
+             return [json.loads(row[0]) for row in rows]
+
+     def list_pipelines(self, project: str | None = None) -> list[str]:
+         """List all unique pipeline names."""
+         with self.engine.connect() as conn:
+             stmt = select(self.runs.c.pipeline_name).distinct().order_by(self.runs.c.pipeline_name)
+             if project:
+                 stmt = stmt.where(self.runs.c.project == project)
+
+             rows = conn.execute(stmt).fetchall()
+             return [row[0] for row in rows if row[0]]
+
+     def save_artifact(self, artifact_id: str, metadata: dict) -> None:
+         """Save artifact metadata."""
+         with self.engine.connect() as conn:
+             stmt = select(self.artifacts).where(self.artifacts.c.artifact_id == artifact_id)
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "artifact_id": artifact_id,
+                 "name": metadata.get("name"),
+                 "type": metadata.get("type"),
+                 "run_id": metadata.get("run_id"),
+                 "path": metadata.get("path"),
+                 "metadata": json.dumps(metadata),
+                 "project": metadata.get("project"),
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.artifacts).where(self.artifacts.c.artifact_id == artifact_id).values(**values),
+                 )
+             else:
+                 conn.execute(insert(self.artifacts).values(**values))
+
+             conn.commit()
+
+     def load_artifact(self, artifact_id: str) -> dict | None:
+         """Load artifact metadata."""
+         with self.engine.connect() as conn:
+             stmt = select(self.artifacts.c.metadata).where(self.artifacts.c.artifact_id == artifact_id)
+             row = conn.execute(stmt).fetchone()
+             if row:
+                 return json.loads(row[0])
+             return None
+
+     def delete_artifact(self, artifact_id: str) -> None:
+         """Delete artifact metadata."""
+         with self.engine.connect() as conn:
+             conn.execute(delete(self.artifacts).where(self.artifacts.c.artifact_id == artifact_id))
+             conn.commit()
+
+     def list_assets(self, limit: int | None = None, **filters) -> list[dict]:
+         """List assets with optional filters."""
+         with self.engine.connect() as conn:
+             stmt = select(self.artifacts.c.metadata)
+
+             for key, value in filters.items():
+                 if value is not None and hasattr(self.artifacts.c, key):
+                     stmt = stmt.where(getattr(self.artifacts.c, key) == value)
+
+             stmt = stmt.order_by(self.artifacts.c.created_at.desc())
+
+             if limit:
+                 stmt = stmt.limit(limit)
+
+             rows = conn.execute(stmt).fetchall()
+             return [json.loads(row[0]) for row in rows]
+
+     def query(self, **filters) -> list[dict]:
+         """Query runs with filters."""
+         with self.engine.connect() as conn:
+             stmt = select(self.runs.c.metadata)
+
+             for key, value in filters.items():
+                 if hasattr(self.runs.c, key):
+                     stmt = stmt.where(getattr(self.runs.c, key) == value)
+
+             stmt = stmt.order_by(self.runs.c.created_at.desc())
+             rows = conn.execute(stmt).fetchall()
+             return [json.loads(row[0]) for row in rows]
+
+     def save_metric(self, run_id: str, name: str, value: float, step: int = 0) -> None:
+         """Save a single metric value."""
+         with self.engine.connect() as conn:
+             conn.execute(
+                 insert(self.metrics).values(
+                     run_id=run_id,
+                     name=name,
+                     value=value,
+                     step=step,
+                 ),
+             )
+             conn.commit()
+
+     def get_metrics(self, run_id: str, name: str | None = None) -> list[dict]:
+         """Get metrics for a run."""
+         with self.engine.connect() as conn:
+             stmt = select(
+                 self.metrics.c.name,
+                 self.metrics.c.value,
+                 self.metrics.c.step,
+                 self.metrics.c.timestamp,
+             ).where(self.metrics.c.run_id == run_id)
+
+             if name:
+                 stmt = stmt.where(self.metrics.c.name == name)
+
+             stmt = stmt.order_by(self.metrics.c.step)
+             rows = conn.execute(stmt).fetchall()
+
+             return [{"name": row[0], "value": row[1], "step": row[2], "timestamp": str(row[3])} for row in rows]
+
+     def log_model_metrics(
+         self,
+         project: str,
+         model_name: str,
+         metrics: dict[str, float],
+         run_id: str | None = None,
+         environment: str | None = None,
+         tags: dict | None = None,
+     ) -> None:
+         """Log production model metrics."""
+         if not metrics:
+             return
+
+         with self.engine.connect() as conn:
+             tags_json = json.dumps(tags or {})
+
+             values_list = []
+             for metric_name, value in metrics.items():
+                 try:
+                     metric_value = float(value)
+                     values_list.append(
+                         {
+                             "project": project,
+                             "model_name": model_name,
+                             "run_id": run_id,
+                             "metric_name": metric_name,
+                             "metric_value": metric_value,
+                             "environment": environment,
+                             "tags": tags_json,
+                         },
+                     )
+                 except (TypeError, ValueError):
+                     continue
+
+             if values_list:
+                 conn.execute(insert(self.model_metrics), values_list)
+                 conn.commit()
+
+     def list_model_metrics(
+         self,
+         project: str | None = None,
+         model_name: str | None = None,
+         limit: int = 100,
+     ) -> list[dict]:
+         """List logged model metrics."""
+         with self.engine.connect() as conn:
+             stmt = select(
+                 self.model_metrics.c.project,
+                 self.model_metrics.c.model_name,
+                 self.model_metrics.c.run_id,
+                 self.model_metrics.c.metric_name,
+                 self.model_metrics.c.metric_value,
+                 self.model_metrics.c.environment,
+                 self.model_metrics.c.tags,
+                 self.model_metrics.c.created_at,
+             )
+
+             if project:
+                 stmt = stmt.where(self.model_metrics.c.project == project)
+             if model_name:
+                 stmt = stmt.where(self.model_metrics.c.model_name == model_name)
+
+             stmt = stmt.order_by(self.model_metrics.c.created_at.desc()).limit(limit)
+             rows = conn.execute(stmt).fetchall()
+
+             results = []
+             for row in rows:
+                 results.append(
+                     {
+                         "project": row[0],
+                         "model_name": row[1],
+                         "run_id": row[2],
+                         "metric_name": row[3],
+                         "metric_value": row[4],
+                         "environment": row[5],
+                         "tags": json.loads(row[6]) if row[6] else {},
+                         "created_at": str(row[7]),
+                     },
+                 )
+             return results
+
+     def save_experiment(self, experiment_id: str, name: str, description: str = "", tags: dict | None = None) -> None:
+         """Save experiment metadata."""
+         with self.engine.connect() as conn:
+             stmt = select(self.experiments).where(self.experiments.c.experiment_id == experiment_id)
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "experiment_id": experiment_id,
+                 "name": name,
+                 "description": description,
+                 "tags": json.dumps(tags or {}),
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.experiments).where(self.experiments.c.experiment_id == experiment_id).values(**values),
+                 )
+             else:
+                 conn.execute(insert(self.experiments).values(**values))
+
+             conn.commit()
+
+     def log_experiment_run(
+         self,
+         experiment_id: str,
+         run_id: str,
+         metrics: dict | None = None,
+         parameters: dict | None = None,
+     ) -> None:
+         """Log a run to an experiment."""
+         with self.engine.connect() as conn:
+             # Check if exists (composite primary key)
+             stmt = select(self.experiment_runs).where(
+                 (self.experiment_runs.c.experiment_id == experiment_id) & (self.experiment_runs.c.run_id == run_id),
+             )
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "experiment_id": experiment_id,
+                 "run_id": run_id,
+                 "metrics": json.dumps(metrics or {}),
+                 "parameters": json.dumps(parameters or {}),
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.experiment_runs)
+                     .where(
+                         (self.experiment_runs.c.experiment_id == experiment_id)
+                         & (self.experiment_runs.c.run_id == run_id),
+                     )
+                     .values(**values),
+                 )
+             else:
+                 conn.execute(insert(self.experiment_runs).values(**values))
+
+             conn.commit()
+
+     def list_experiments(self) -> list[dict]:
+         """List all experiments."""
+         with self.engine.connect() as conn:
+             stmt = select(
+                 self.experiments.c.experiment_id,
+                 self.experiments.c.name,
+                 self.experiments.c.description,
+                 self.experiments.c.tags,
+                 self.experiments.c.created_at,
+                 self.experiments.c.project,
+             ).order_by(self.experiments.c.created_at.desc())
+
+             rows = conn.execute(stmt).fetchall()
+
+             experiments = []
+             for row in rows:
+                 # Count runs
+                 count_stmt = (
+                     select(func.count())
+                     .select_from(self.experiment_runs)
+                     .where(
+                         self.experiment_runs.c.experiment_id == row[0],
+                     )
+                 )
+                 run_count = conn.execute(count_stmt).scalar()
+
+                 experiments.append(
+                     {
+                         "experiment_id": row[0],
+                         "name": row[1],
+                         "description": row[2],
+                         "tags": json.loads(row[3]),
+                         "created_at": str(row[4]),
+                         "project": row[5],
+                         "run_count": run_count,
+                     },
+                 )
+             return experiments
+
+     def update_experiment_project(self, experiment_name: str, project_name: str) -> None:
+         """Update the project for an experiment."""
+         with self.engine.connect() as conn:
+             conn.execute(
+                 update(self.experiments).where(self.experiments.c.name == experiment_name).values(project=project_name),
+             )
+             conn.commit()
+
+     def get_experiment(self, experiment_id: str) -> dict | None:
+         """Get experiment details."""
+         with self.engine.connect() as conn:
+             stmt = select(
+                 self.experiments.c.experiment_id,
+                 self.experiments.c.name,
+                 self.experiments.c.description,
+                 self.experiments.c.tags,
+                 self.experiments.c.created_at,
+             ).where(self.experiments.c.experiment_id == experiment_id)
+
+             row = conn.execute(stmt).fetchone()
+             if not row:
+                 return None
+
+             return {
+                 "experiment_id": row[0],
+                 "name": row[1],
+                 "description": row[2],
+                 "tags": json.loads(row[3]),
+                 "created_at": str(row[4]),
+             }
+
+     def save_trace_event(self, event: dict) -> None:
+         """Save a trace event."""
+         with self.engine.connect() as conn:
+             stmt = select(self.traces).where(self.traces.c.event_id == event["event_id"])
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "event_id": event["event_id"],
+                 "trace_id": event.get("trace_id"),
+                 "parent_id": event.get("parent_id"),
+                 "event_type": event.get("event_type"),
+                 "name": event.get("name"),
+                 "inputs": json.dumps(event.get("inputs")),
+                 "outputs": json.dumps(event.get("outputs")),
+                 "start_time": event.get("start_time"),
+                 "end_time": event.get("end_time"),
+                 "duration": event.get("duration"),
+                 "status": event.get("status"),
+                 "error": json.dumps(event.get("error")),
+                 "metadata": json.dumps(event.get("metadata")),
+                 "prompt_tokens": event.get("prompt_tokens"),
+                 "completion_tokens": event.get("completion_tokens"),
+                 "total_tokens": event.get("total_tokens"),
+                 "cost": event.get("cost"),
+                 "model": event.get("model"),
+                 "project": event.get("project"),
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.traces).where(self.traces.c.event_id == event["event_id"]).values(**values),
+                 )
+             else:
+                 conn.execute(insert(self.traces).values(**values))
+
+             conn.commit()
+
+     def save_pipeline_definition(self, pipeline_name: str, definition: dict) -> None:
+         """Save pipeline definition."""
+         now = datetime.now(UTC).isoformat()
+         with self.engine.connect() as conn:
+             stmt = select(self.pipeline_definitions).where(
+                 self.pipeline_definitions.c.pipeline_name == pipeline_name,
+             )
+             existing = conn.execute(stmt).fetchone()
+
+             values = {
+                 "pipeline_name": pipeline_name,
+                 "definition": json.dumps(definition),
+                 "updated_at": now,
+             }
+
+             if existing:
+                 conn.execute(
+                     update(self.pipeline_definitions)
+                     .where(self.pipeline_definitions.c.pipeline_name == pipeline_name)
+                     .values(**values),
+                 )
+             else:
+                 values["created_at"] = now
+                 conn.execute(insert(self.pipeline_definitions).values(**values))
+
+             conn.commit()
+
+     def get_trace(self, trace_id: str) -> list[dict]:
+         """Get all events for a trace."""
+         with self.engine.connect() as conn:
+             stmt = select(self.traces).where(self.traces.c.trace_id == trace_id).order_by(self.traces.c.start_time)
+             rows = conn.execute(stmt).fetchall()
+
+             events = []
+             for row in rows:
+                 event = {
+                     "event_id": row.event_id,
+                     "trace_id": row.trace_id,
+                     "parent_id": row.parent_id,
+                     "event_type": row.event_type,
+                     "name": row.name,
+                     "inputs": json.loads(row.inputs) if row.inputs else {},
+                     "outputs": json.loads(row.outputs) if row.outputs else {},
+                     "start_time": row.start_time,
+                     "end_time": row.end_time,
+                     "duration": row.duration,
+                     "status": row.status,
+                     "error": json.loads(row.error) if row.error else None,
+                     "metadata": json.loads(row.metadata) if row.metadata else {},
+                     "prompt_tokens": row.prompt_tokens,
+                     "completion_tokens": row.completion_tokens,
+                     "total_tokens": row.total_tokens,
+                     "cost": row.cost,
+                     "model": row.model,
+                     "project": row.project,
+                 }
+                 events.append(event)
+             return events
+
+     def list_traces(
+         self,
+         limit: int = 50,
+         trace_id: str | None = None,
+         event_type: str | None = None,
+         project: str | None = None,
+     ) -> list[dict]:
+         """List traces with optional filters."""
+         with self.engine.connect() as conn:
+             stmt = select(self.traces)
+
+             if trace_id:
+                 stmt = stmt.where(self.traces.c.trace_id == trace_id)
+             if event_type:
+                 stmt = stmt.where(self.traces.c.event_type == event_type)
+             if project:
+                 stmt = stmt.where(self.traces.c.project == project)
+
+             stmt = stmt.order_by(self.traces.c.start_time.desc()).limit(limit)
+             rows = conn.execute(stmt).fetchall()
+
+             traces = []
+             for row in rows:
+                 trace = {
+                     "event_id": row.event_id,
+                     "trace_id": row.trace_id,
+                     "parent_id": row.parent_id,
+                     "event_type": row.event_type,
+                     "name": row.name,
+                     "inputs": json.loads(row.inputs) if row.inputs else {},
+                     "outputs": json.loads(row.outputs) if row.outputs else {},
+                     "start_time": row.start_time,
+                     "end_time": row.end_time,
+                     "duration": row.duration,
+                     "status": row.status,
+                     "error": json.loads(row.error) if row.error else None,
+                     "metadata": json.loads(row.metadata) if row.metadata else {},
+                     "prompt_tokens": row.prompt_tokens,
+                     "completion_tokens": row.completion_tokens,
+                     "total_tokens": row.total_tokens,
+                     "cost": row.cost,
+                     "model": row.model,
+                     "project": row.project,
+                     "created_at": row.created_at,
+                 }
+                 traces.append(trace)
+             return traces
+
+     def get_pipeline_definition(self, pipeline_name: str) -> dict | None:
+         """Get pipeline definition."""
+         with self.engine.connect() as conn:
+             stmt = select(self.pipeline_definitions.c.definition).where(
+                 self.pipeline_definitions.c.pipeline_name == pipeline_name,
+             )
+             row = conn.execute(stmt).fetchone()
+             if row:
+                 return json.loads(row[0])
+             return None
+
+     def get_orchestrator_metrics(self, days: int = 30) -> dict:
+         """Get orchestrator-level performance metrics for the last N days."""
+         from datetime import datetime, timedelta
+
+         cutoff = (datetime.now() - timedelta(days=days)).isoformat()
+
+         with self.engine.connect() as conn:
+             # Total runs
+             total_runs = conn.execute(
+                 select(func.count()).select_from(self.runs).where(self.runs.c.created_at >= cutoff),
+             ).scalar()
+
+             # Status distribution
+             status_stmt = (
+                 select(
+                     self.runs.c.status,
+                     func.count(),
+                 )
+                 .where(self.runs.c.created_at >= cutoff)
+                 .group_by(self.runs.c.status)
+             )
+             status_rows = conn.execute(status_stmt).fetchall()
+             status_counts = {row[0]: row[1] for row in status_rows if row[0]}
+
+             # Average duration
+             avg_duration = (
+                 conn.execute(
+                     select(func.avg(self.runs.c.duration)).where(
+                         (self.runs.c.created_at >= cutoff) & (self.runs.c.duration.isnot(None)),
+                     ),
+                 ).scalar()
+                 or 0
+             )
+
+             completed = status_counts.get("completed", 0)
+             success_rate = completed / total_runs if total_runs > 0 else 0
+
+             return {
+                 "total_runs": total_runs,
+                 "success_rate": success_rate,
+                 "avg_duration_seconds": avg_duration,
+                 "status_distribution": status_counts,
+                 "period_days": days,
+             }
+
+     def get_cache_metrics(self, days: int = 30) -> dict:
+         """Get cache performance metrics for the last N days."""
+         from datetime import datetime, timedelta
+
+         cutoff = (datetime.now() - timedelta(days=days)).isoformat()
+
+         with self.engine.connect() as conn:
+             stmt = select(self.runs.c.metadata).where(self.runs.c.created_at >= cutoff)
+             rows = conn.execute(stmt).fetchall()
+
+             total_steps, cached_steps = 0, 0
+             for row in rows:
+                 if not row[0]:
+                     continue
+                 try:
+                     metadata = json.loads(row[0])
+                     for step_data in metadata.get("steps", {}).values():
+                         total_steps += 1
+                         if step_data.get("cached"):
+                             cached_steps += 1
+                 except Exception:
+                     continue
+
+             cache_hit_rate = cached_steps / total_steps if total_steps > 0 else 0
+
+             return {
+                 "total_steps": total_steps,
+                 "cached_steps": cached_steps,
+                 "cache_hit_rate": cache_hit_rate,
+                 "period_days": days,
+             }
+
+     def get_statistics(self, project: str | None = None) -> dict:
+         """Get global statistics."""
+         with self.engine.connect() as conn:
+             # 1. Total runs
+             runs_stmt = select(func.count()).select_from(self.runs)
+             if project:
+                 runs_stmt = runs_stmt.where(self.runs.c.project == project)
+             total_runs = conn.execute(runs_stmt).scalar() or 0
+
+             # 2. Total pipelines (unique names)
+             pipelines_stmt = select(func.count(func.distinct(self.runs.c.pipeline_name)))
+             if project:
+                 pipelines_stmt = pipelines_stmt.where(self.runs.c.project == project)
+             total_pipelines = conn.execute(pipelines_stmt).scalar() or 0
+
+             # 3. Total artifacts
+             artifacts_stmt = select(func.count()).select_from(self.artifacts)
+             if project:
+                 artifacts_stmt = artifacts_stmt.where(self.artifacts.c.project == project)
+             total_artifacts = conn.execute(artifacts_stmt).scalar() or 0
+
+             # 4. Total experiments
+             experiments_stmt = select(func.count()).select_from(self.experiments)
+             if project:
+                 experiments_stmt = experiments_stmt.where(self.experiments.c.project == project)
+             total_experiments = conn.execute(experiments_stmt).scalar() or 0
+
+             # 5. Total models
+             models_stmt = select(func.count(func.distinct(self.model_metrics.c.model_name)))
+             if project:
+                 models_stmt = models_stmt.where(self.model_metrics.c.project == project)
+             total_models = conn.execute(models_stmt).scalar() or 0
+
+             # 6. Status counts (completed vs failed)
+             status_stmt = select(self.runs.c.status, func.count()).group_by(self.runs.c.status)
+             if project:
+                 status_stmt = status_stmt.where(self.runs.c.project == project)
+
+             status_rows = conn.execute(status_stmt).fetchall()
+             status_map = {row[0]: row[1] for row in status_rows if row[0]}
+
+             completed_runs = status_map.get("completed", 0)
+             failed_runs = status_map.get("failed", 0)
+
+             # 7. Avg duration (only completed runs)
+             dur_stmt = select(func.avg(self.runs.c.duration)).where(self.runs.c.status == "completed")
+             if project:
+                 dur_stmt = dur_stmt.where(self.runs.c.project == project)
+
+             avg_duration = conn.execute(dur_stmt).scalar() or 0.0
+
+             return {
+                 # Frontend-friendly keys
+                 "pipelines": total_pipelines,
+                 "runs": total_runs,
+                 "artifacts": total_artifacts,
+                 "completed_runs": completed_runs,
+                 "failed_runs": failed_runs,
+                 "avg_duration": avg_duration,
+                 # Backward compatibility
+                 "total_runs": total_runs,
+                 "total_pipelines": total_pipelines,
+                 "total_experiments": total_experiments,
+                 "total_models": total_models,
+             }