vespaembed 0.0.1-py3-none-any.whl → 0.0.3-py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (49)
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.3.dist-info/METADATA +325 -0
  43. vespaembed-0.0.3.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
vespaembed/web/app.py ADDED
@@ -0,0 +1,605 @@
+ import json
+ import os
+ import re
+ import subprocess
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.templating import Jinja2Templates
+ from pydantic import BaseModel, Field, field_validator
+ from starlette.requests import Request
+
+ from vespaembed.db import create_run, delete_run, get_active_run, get_all_runs, get_run, update_run_status
+ from vespaembed.enums import RunStatus
+
+
+ # Helper: Check if process is alive
+ def is_process_alive(pid: int | None) -> bool:
+     """Check if a process with the given PID is still running."""
+     if pid is None:
+         return False
+     try:
+         os.kill(pid, 0)  # Signal 0 checks if process exists without killing it
+         return True
+     except (ProcessLookupError, OSError):
+         return False
+
+
+ # Helper: Check if final model exists for a run
+ def has_final_model(output_dir: str | None) -> bool:
+     """Check if training completed successfully by looking for final model."""
+     if not output_dir:
+         return False
+     final_path = Path(output_dir) / "final"
+     return final_path.exists() and final_path.is_dir()
+
+
+ # Startup: Sync run statuses with actual process states
+ def sync_run_statuses():
+     """Check all running/pending runs and update status if process is dead.
+
+     This handles cases where:
+     - Server was restarted while training was in progress
+     - Training process crashed unexpectedly
+     - Training completed but status wasn't updated
+     """
+     runs = get_all_runs()
+     for run in runs:
+         if run["status"] in [RunStatus.RUNNING.value, RunStatus.PENDING.value]:
+             pid = run.get("pid")
+             run_id = run["id"]
+
+             # Check if process is still alive
+             if is_process_alive(pid):
+                 continue  # Still running, leave it
+
+             # Process is dead - determine final status
+             if has_final_model(run.get("output_dir")):
+                 update_run_status(run_id, RunStatus.COMPLETED)
+                 print(f"[sync] Run {run_id}: Marked as completed (final model exists)")
+             else:
+                 update_run_status(run_id, RunStatus.ERROR, error_message="Process terminated unexpectedly")
+                 print(f"[sync] Run {run_id}: Marked as error (no final model)")
+
+
+ # Paths
+ PACKAGE_DIR = Path(__file__).parent.parent
+ STATIC_DIR = PACKAGE_DIR / "static"
+ TEMPLATES_DIR = PACKAGE_DIR / "templates"
+ BASE_DIR = Path.home() / ".vespaembed"
+ UPLOAD_DIR = BASE_DIR / "uploads"
+ UPDATE_DIR = BASE_DIR / "updates"
+ PROJECTS_DIR = BASE_DIR / "projects"
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+ UPDATE_DIR.mkdir(parents=True, exist_ok=True)
+ PROJECTS_DIR.mkdir(parents=True, exist_ok=True)
+
+ # FastAPI app
+ app = FastAPI(title="VespaEmbed", version="0.0.1")
+
+ # Mount static files
+ app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
+
+ # Templates
+ templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
+
+ # Sync run statuses on startup
+ sync_run_statuses()
+
+
+ # Pydantic models
+ class TrainRequest(BaseModel):
+     # Project name (required) - used to create output directory
+     project_name: str = Field(..., description="Project name (alphanumeric and hyphens only)")
+
+     # Data source - either file upload OR HuggingFace dataset
+     train_filename: Optional[str] = Field(None, description="Path to uploaded training file")
+     eval_filename: Optional[str] = Field(None, description="Path to uploaded evaluation file (optional)")
+
+     # HuggingFace dataset (alternative to file upload)
+     hf_dataset: Optional[str] = Field(
+         None, description="HuggingFace dataset name (e.g., 'sentence-transformers/all-nli')"
+     )
+     hf_subset: Optional[str] = Field(None, description="Dataset subset/config name")
+     hf_train_split: str = Field("train", description="Training split name")
+     hf_eval_split: Optional[str] = Field(None, description="Evaluation split name (optional)")
+
+     # Required
+     task: str
+     base_model: str
+
+     # Basic hyperparameters
+     epochs: int = 3
+     batch_size: int = 32
+     learning_rate: float = 2e-5
+
+     # Advanced hyperparameters
+     warmup_ratio: float = 0.1
+     weight_decay: float = 0.01
+     fp16: bool = False
+     bf16: bool = False
+     eval_steps: int = 500
+     save_steps: int = 500
+     logging_steps: int = 100
+     gradient_accumulation_steps: int = 1
+
+     # Optimizer and scheduler
+     optimizer: str = Field("adamw_torch", description="Optimizer type")
+     scheduler: str = Field("linear", description="Learning rate scheduler")
+
+     # LoRA/PEFT parameters
+     lora_enabled: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 128
+     lora_dropout: float = 0.1
+     lora_target_modules: str = Field("query, key, value, dense", description="Comma-separated list of target modules")
+
+     # Model configuration
+     max_seq_length: Optional[int] = None  # Auto-detect from model if not specified
+     gradient_checkpointing: bool = False  # Saves VRAM, uses Unsloth optimization when Unsloth is enabled
+
+     # Unsloth parameters
+     unsloth_enabled: bool = False
+     unsloth_save_method: str = "merged_16bit"  # "lora", "merged_16bit", "merged_4bit"
+
+     # Hub push
+     push_to_hub: bool = False
+     hf_username: Optional[str] = None
+
+     # Task-specific parameters
+     matryoshka_dims: Optional[str] = Field(
+         None, description="Matryoshka dimensions as comma-separated string (e.g., '768,512,256,128')"
+     )
+     loss_variant: Optional[str] = Field(
+         None, description="Loss function variant (task-specific, uses default if not specified)"
+     )
+
+     @field_validator("project_name")
+     @classmethod
+     def validate_project_name(cls, v):
+         if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9-]*$", v):
+             raise ValueError(
+                 "Project name must start with alphanumeric and contain only alphanumeric characters and hyphens"
+             )
+         return v
+
+
+ class StopRequest(BaseModel):
+     run_id: int
+
+
+ class UploadResponse(BaseModel):
+     filename: str
+     filepath: str
+     columns: list[str]
+     preview: list[dict]
+     row_count: int
+
+
+ # Routes
+ @app.get("/", response_class=HTMLResponse)
+ async def index(request: Request):
+     """Serve the main page."""
+     return templates.TemplateResponse("index.html", {"request": request})
+
+
+ @app.post("/upload", response_model=UploadResponse)
+ async def upload_file(
+     file: UploadFile = File(...),
+     file_type: str = Form("train"),  # "train" or "eval"
+ ):
+     """Upload a training or evaluation data file.
+
+     Args:
+         file: The file to upload (CSV or JSONL)
+         file_type: Either "train" or "eval" to indicate file purpose
+     """
+     if file_type not in ("train", "eval"):
+         raise HTTPException(status_code=400, detail="file_type must be 'train' or 'eval'")
+
+     # Save file with prefix to distinguish train/eval
+     original_filename = file.filename
+     filename = f"{file_type}_{original_filename}"
+     filepath = UPLOAD_DIR / filename
+
+     with open(filepath, "wb") as f:
+         content = await file.read()
+         f.write(content)
+
+     # Get preview and row count
+     preview = []
+     columns = []
+     row_count = 0
+
+     try:
+         if original_filename.endswith(".csv"):
+             import pandas as pd
+
+             # Get row count
+             df_full = pd.read_csv(filepath)
+             row_count = len(df_full)
+
+             # Get preview
+             df = df_full.head(5)
+             columns = df.columns.tolist()
+             preview = df.to_dict("records")
+
+         elif original_filename.endswith(".jsonl"):
+             with open(filepath) as f:
+                 for i, line in enumerate(f):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     record = json.loads(line)
+                     row_count += 1
+                     if len(preview) < 5:
+                         preview.append(record)
+                     if not columns:
+                         columns = list(record.keys())
+         else:
+             raise HTTPException(
+                 status_code=400,
+                 detail="Unsupported file format. Please upload CSV or JSONL files.",
+             )
+     except json.JSONDecodeError as e:
+         raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {e}")
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Failed to parse file: {e}")
+
+     return UploadResponse(
+         filename=original_filename,
+         filepath=str(filepath),
+         columns=columns,
+         preview=preview,
+         row_count=row_count,
+     )
+
+
+ @app.post("/train")
+ async def train(config: TrainRequest):
+     """Start a training run."""
+     # Check for active run
+     active = get_active_run()
+     if active:
+         raise HTTPException(
+             status_code=400,
+             detail="A training run is already in progress. Stop it first.",
+         )
+
+     # Validate data source - must have either file or HF dataset
+     has_file = config.train_filename is not None
+     has_hf = config.hf_dataset is not None
+
+     if not has_file and not has_hf:
+         raise HTTPException(
+             status_code=400,
+             detail="Must provide either train_filename or hf_dataset",
+         )
+
+     if has_file and has_hf:
+         raise HTTPException(
+             status_code=400,
+             detail="Cannot specify both train_filename and hf_dataset. Choose one.",
+         )
+
+     # Validate file exists if using file upload
+     if has_file and not Path(config.train_filename).exists():
+         raise HTTPException(status_code=400, detail="Training data file not found")
+
+     if config.eval_filename and not Path(config.eval_filename).exists():
+         raise HTTPException(status_code=400, detail="Evaluation data file not found")
+
+     # Validate hub config
+     if config.push_to_hub and not config.hf_username:
+         raise HTTPException(
+             status_code=400,
+             detail="hf_username is required when push_to_hub is enabled",
+         )
+
+     # Create output directory in ~/.vespaembed/projects/
+     output_dir = PROJECTS_DIR / config.project_name
+     if output_dir.exists():
+         # Append timestamp to make unique
+         timestamp = int(time.time())
+         output_dir = PROJECTS_DIR / f"{config.project_name}-{timestamp}"
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Prepare config with resolved output_dir
+     config_dict = config.model_dump()
+     config_dict["output_dir"] = str(output_dir)
+
+     # Create run record
+     run_id = create_run(
+         config=config_dict,
+         project_name=config.project_name,
+         output_dir=str(output_dir),
+     )
+
+     # Clear any existing update file for this run
+     update_file = UPDATE_DIR / f"run_{run_id}.jsonl"
+     if update_file.exists():
+         update_file.unlink()
+
+     # Start worker process
+     cmd = [
+         sys.executable,
+         "-m",
+         "vespaembed.worker",
+         "--run-id",
+         str(run_id),
+         "--config",
+         json.dumps(config_dict),
+     ]
+
+     # Don't capture stdout/stderr - let them flow to terminal for visibility
+     process = subprocess.Popen(
+         cmd,
+         stdout=None,  # Inherit from parent - logs visible in terminal
+         stderr=None,  # Inherit from parent - errors visible in terminal
+         start_new_session=True,
+     )
+
+     update_run_status(run_id, RunStatus.RUNNING, pid=process.pid)
+
+     return {"message": "Training started", "run_id": run_id, "output_dir": str(output_dir)}
+
+
+ @app.post("/stop")
+ async def stop(request: StopRequest):
+     """Stop a training run."""
+     run = get_run(request.run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+
+     if run["status"] != RunStatus.RUNNING.value:
+         raise HTTPException(status_code=400, detail="Run is not active")
+
+     # Send SIGTERM to process
+     if run.get("pid"):
+         try:
+             os.kill(run["pid"], 15)  # SIGTERM
+         except ProcessLookupError:
+             pass
+
+     update_run_status(request.run_id, RunStatus.STOPPED)
+
+     return {"message": "Training stopped"}
+
+
+ @app.get("/runs")
+ async def list_runs():
+     """List all training runs."""
+     return get_all_runs()
+
+
+ @app.get("/runs/{run_id}")
+ async def get_run_details(run_id: int):
+     """Get details of a specific run."""
+     run = get_run(run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+     return run
+
+
+ @app.delete("/runs/{run_id}")
+ async def delete_run_endpoint(run_id: int):
+     """Delete a training run."""
+     run = get_run(run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+
+     # Stop if running
+     if run["status"] == RunStatus.RUNNING.value and run.get("pid"):
+         try:
+             os.kill(run["pid"], 15)
+         except ProcessLookupError:
+             pass
+
+     delete_run(run_id, delete_files=True)
+     return {"message": "Run deleted"}
+
+
+ @app.get("/active_run_id")
+ async def get_active_run_id():
+     """Get the currently active run ID."""
+     run = get_active_run()
+     return {"run_id": run["id"] if run else None}
+
+
+ @app.get("/runs/{run_id}/updates")
+ async def get_run_updates(run_id: int, since_line: int = 0):
+     """Poll for updates from a training run.
+
+     Args:
+         run_id: The run ID to get updates for
+         since_line: Return updates after this line number (0-indexed).
+             Client should track this and pass it on subsequent requests.
+
+     Returns:
+         updates: List of update objects (progress, log, status, etc.)
+         next_line: The line number to use for the next poll request
+         has_more: Whether there might be more updates (run still active)
+     """
+     run = get_run(run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+
+     update_file = UPDATE_DIR / f"run_{run_id}.jsonl"
+     updates = []
+     next_line = since_line
+
+     if update_file.exists():
+         with open(update_file) as f:
+             for i, line in enumerate(f):
+                 if i < since_line:
+                     continue
+                 line = line.strip()
+                 if line:
+                     try:
+                         updates.append(json.loads(line))
+                     except json.JSONDecodeError:
+                         pass
+                 next_line = i + 1
+
+     # Check if run is still active
+     is_active = run["status"] == RunStatus.RUNNING.value
+
+     return {
+         "updates": updates,
+         "next_line": next_line,
+         "has_more": is_active,
+         "run_status": run["status"],
+     }
+
+
+ @app.get("/api/tasks")
+ async def get_all_tasks():
+     """Get information about all available training tasks."""
+     # Import tasks to ensure they're registered
+     import vespaembed.tasks  # noqa: F401
+     from vespaembed.core.registry import Registry
+
+     return Registry.get_task_info()
+
+
+ @app.get("/api/tasks/{task_name}")
+ async def get_task(task_name: str):
+     """Get information about a specific training task."""
+     # Import tasks to ensure they're registered
+     import vespaembed.tasks  # noqa: F401
+     from vespaembed.core.registry import Registry
+
+     try:
+         return Registry.get_task_info(task_name)
+     except ValueError as e:
+         raise HTTPException(status_code=404, detail=str(e))
+
+
+ @app.get("/runs/{run_id}/metrics")
+ async def get_run_metrics(run_id: int):
+     """Get training metrics from TensorBoard event files.
+
+     Returns metrics like loss, learning_rate, etc. parsed from tfevents files.
+     """
+     import math
+
+     def sanitize_value(value):
+         """Convert NaN/Inf to None for JSON serialization."""
+         if math.isnan(value) or math.isinf(value):
+             return None
+         return value
+
+     run = get_run(run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+
+     log_dir = Path(run["output_dir"]) / "logs"
+     if not log_dir.exists():
+         return {"metrics": {}, "steps": []}
+
+     try:
+         from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+         # Find all event files
+         metrics = {}
+         ea = EventAccumulator(str(log_dir))
+         ea.Reload()
+
+         # Get all scalar tags
+         scalar_tags = ea.Tags().get("scalars", [])
+
+         for tag in scalar_tags:
+             events = ea.Scalars(tag)
+             # Clean up tag name (remove "train/" prefix if present)
+             clean_tag = tag.replace("train/", "").replace("eval/", "eval_")
+             # Sanitize values to handle NaN/Inf
+             metrics[clean_tag] = [{"step": e.step, "value": sanitize_value(e.value)} for e in events]
+
+         return {"metrics": metrics}
+
+     except Exception as e:
+         # If tensorboard parsing fails, return empty metrics
+         return {"metrics": {}, "error": str(e)}
+
+
+ @app.get("/runs/{run_id}/artifacts")
+ async def get_run_artifacts(run_id: int):
+     """Get list of downloadable artifacts for a training run."""
+     run = get_run(run_id)
+     if not run:
+         raise HTTPException(status_code=404, detail="Run not found")
+
+     output_dir = Path(run["output_dir"])
+     if not output_dir.exists():
+         return {"artifacts": []}
+
+     artifacts = []
+
+     # Check for final model
+     final_path = output_dir / "final"
+     if final_path.exists() and final_path.is_dir():
+         # Get total size of final directory
+         total_size = sum(f.stat().st_size for f in final_path.rglob("*") if f.is_file())
+         artifacts.append(
+             {
+                 "name": "final",
+                 "label": "Final Model",
+                 "category": "model",
+                 "path": str(final_path),
+                 "size": total_size,
+                 "is_directory": True,
+             }
+         )
+
+     # Check for config file
+     config_files = list(output_dir.glob("*.json"))
+     for cf in config_files:
+         if cf.name in ["config.json", "training_config.json"]:
+             artifacts.append(
+                 {
+                     "name": cf.name,
+                     "label": "Training Config",
+                     "category": "config",
+                     "path": str(cf),
+                     "size": cf.stat().st_size,
+                     "is_directory": False,
+                 }
+             )
+
+     # Check for checkpoints
+     checkpoints = sorted(output_dir.glob("checkpoint-*"))
+     for ckpt in checkpoints[:3]:  # Limit to 3 most recent
+         if ckpt.is_dir():
+             total_size = sum(f.stat().st_size for f in ckpt.rglob("*") if f.is_file())
+             artifacts.append(
+                 {
+                     "name": ckpt.name,
+                     "label": f"Checkpoint ({ckpt.name})",
+                     "category": "checkpoint",
+                     "path": str(ckpt),
+                     "size": total_size,
+                     "is_directory": True,
+                 }
+             )
+
+     # Check for logs
+     logs_path = output_dir / "logs"
+     if logs_path.exists():
+         total_size = sum(f.stat().st_size for f in logs_path.rglob("*") if f.is_file())
+         artifacts.append(
+             {
+                 "name": "logs",
+                 "label": "TensorBoard Logs",
+                 "category": "logs",
+                 "path": str(logs_path),
+                 "size": total_size,
+                 "is_directory": True,
+             }
+         )
+
+     return {"artifacts": artifacts, "output_dir": str(output_dir)}