vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.3.dist-info/METADATA +325 -0
- vespaembed-0.0.3.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
vespaembed/web/app.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import subprocess
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
11
|
+
from fastapi.responses import HTMLResponse
|
|
12
|
+
from fastapi.staticfiles import StaticFiles
|
|
13
|
+
from fastapi.templating import Jinja2Templates
|
|
14
|
+
from pydantic import BaseModel, Field, field_validator
|
|
15
|
+
from starlette.requests import Request
|
|
16
|
+
|
|
17
|
+
from vespaembed.db import create_run, delete_run, get_active_run, get_all_runs, get_run, update_run_status
|
|
18
|
+
from vespaembed.enums import RunStatus
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Helper: Check if process is alive
|
|
22
|
+
def is_process_alive(pid: int | None) -> bool:
|
|
23
|
+
"""Check if a process with the given PID is still running."""
|
|
24
|
+
if pid is None:
|
|
25
|
+
return False
|
|
26
|
+
try:
|
|
27
|
+
os.kill(pid, 0) # Signal 0 checks if process exists without killing it
|
|
28
|
+
return True
|
|
29
|
+
except (ProcessLookupError, OSError):
|
|
30
|
+
return False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Helper: Check if final model exists for a run
|
|
34
|
+
def has_final_model(output_dir: str | None) -> bool:
|
|
35
|
+
"""Check if training completed successfully by looking for final model."""
|
|
36
|
+
if not output_dir:
|
|
37
|
+
return False
|
|
38
|
+
final_path = Path(output_dir) / "final"
|
|
39
|
+
return final_path.exists() and final_path.is_dir()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Startup: Sync run statuses with actual process states
|
|
43
|
+
def sync_run_statuses():
    """Reconcile persisted RUNNING/PENDING runs with the processes actually alive.

    Every RUNNING or PENDING run whose recorded PID is no longer a live process
    is finalized:

    - COMPLETED, if its output directory contains a ``final`` model folder;
    - ERROR ("Process terminated unexpectedly"), otherwise.

    This covers a server restart during training, a crashed worker, or a worker
    that finished without updating its own status.
    """
    in_flight = (RunStatus.RUNNING.value, RunStatus.PENDING.value)

    # Lazily filtered so each run is checked and updated in turn, as before.
    orphaned = (
        run
        for run in get_all_runs()
        if run["status"] in in_flight and not is_process_alive(run.get("pid"))
    )

    for run in orphaned:
        run_id = run["id"]
        if not has_final_model(run.get("output_dir")):
            update_run_status(run_id, RunStatus.ERROR, error_message="Process terminated unexpectedly")
            print(f"[sync] Run {run_id}: Marked as error (no final model)")
            continue
        update_run_status(run_id, RunStatus.COMPLETED)
        print(f"[sync] Run {run_id}: Marked as completed (final model exists)")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Paths: bundled assets ship inside the package; user data lives under ~/.vespaembed
PACKAGE_DIR = Path(__file__).parent.parent
STATIC_DIR = PACKAGE_DIR / "static"
TEMPLATES_DIR = PACKAGE_DIR / "templates"
BASE_DIR = Path.home() / ".vespaembed"
UPLOAD_DIR = BASE_DIR / "uploads"  # uploaded train/eval data files
UPDATE_DIR = BASE_DIR / "updates"  # per-run update streams (run_<id>.jsonl)
PROJECTS_DIR = BASE_DIR / "projects"  # per-project training output directories

# Ensure the writable data directories exist before any request touches them.
for _data_dir in (UPLOAD_DIR, UPDATE_DIR, PROJECTS_DIR):
    _data_dir.mkdir(parents=True, exist_ok=True)
del _data_dir

# FastAPI app
# NOTE(review): the version string is hard-coded while this wheel is 0.0.3 —
# confirm whether it should track the package version.
app = FastAPI(title="VespaEmbed", version="0.0.1")

# Bundled CSS/JS, served under /static
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Jinja2 templates for the HTML UI
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))

# Reconcile stored run statuses with live processes once, when the module loads.
sync_run_statuses()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Pydantic models
|
|
96
|
+
class TrainRequest(BaseModel):
    """Request body for ``POST /train``.

    Data comes either from an uploaded file (``train_filename``) or from a
    HuggingFace dataset (``hf_dataset``).
    """

    # Project name (required) - used to create output directory
    project_name: str = Field(..., description="Project name (alphanumeric and hyphens only)")

    # Data source - either file upload OR HuggingFace dataset
    train_filename: Optional[str] = Field(None, description="Path to uploaded training file")
    eval_filename: Optional[str] = Field(None, description="Path to uploaded evaluation file (optional)")

    # HuggingFace dataset (alternative to file upload)
    hf_dataset: Optional[str] = Field(
        None, description="HuggingFace dataset name (e.g., 'sentence-transformers/all-nli')"
    )
    hf_subset: Optional[str] = Field(None, description="Dataset subset/config name")
    hf_train_split: str = Field("train", description="Training split name")
    hf_eval_split: Optional[str] = Field(None, description="Evaluation split name (optional)")

    # Required
    task: str
    base_model: str

    # Basic hyperparameters
    epochs: int = 3
    batch_size: int = 32
    learning_rate: float = 2e-5

    # Advanced hyperparameters
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    fp16: bool = False
    bf16: bool = False
    eval_steps: int = 500
    save_steps: int = 500
    logging_steps: int = 100
    gradient_accumulation_steps: int = 1

    # Optimizer and scheduler
    optimizer: str = Field("adamw_torch", description="Optimizer type")
    scheduler: str = Field("linear", description="Learning rate scheduler")

    # LoRA/PEFT parameters
    lora_enabled: bool = False
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.1
    lora_target_modules: str = Field("query, key, value, dense", description="Comma-separated list of target modules")

    # Model configuration
    max_seq_length: Optional[int] = None  # Auto-detect from model if not specified
    gradient_checkpointing: bool = False  # Saves VRAM, uses Unsloth optimization when Unsloth is enabled

    # Unsloth parameters
    unsloth_enabled: bool = False
    unsloth_save_method: str = "merged_16bit"  # "lora", "merged_16bit", "merged_4bit"

    # Hub push
    push_to_hub: bool = False
    hf_username: Optional[str] = None

    # Task-specific parameters
    matryoshka_dims: Optional[str] = Field(
        None, description="Matryoshka dimensions as comma-separated string (e.g., '768,512,256,128')"
    )
    loss_variant: Optional[str] = Field(
        None, description="Loss function variant (task-specific, uses default if not specified)"
    )

    @field_validator("project_name")
    @classmethod
    def validate_project_name(cls, v):
        """Ensure the project name is safe to use as a single directory name.

        The whole string must match; ``re.fullmatch`` is used because with
        ``re.match`` the ``$`` anchor also matches just before a trailing
        newline, which let a name ending in a newline character through.

        Raises:
            ValueError: If the name is empty, starts with a hyphen, or contains
                anything other than ASCII letters, digits and hyphens.
        """
        if not re.fullmatch(r"[a-zA-Z0-9][a-zA-Z0-9-]*", v):
            raise ValueError(
                "Project name must start with alphanumeric and contain only alphanumeric characters and hyphens"
            )
        return v
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class StopRequest(BaseModel):
    """Request body for ``POST /stop``."""

    run_id: int  # ID of the run to stop
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class UploadResponse(BaseModel):
    """Response model for ``POST /upload``: where the file was stored plus a short preview."""

    filename: str  # name of the uploaded file as reported back to the client
    filepath: str  # server-side path of the stored copy
    columns: list[str]  # column names (CSV header, or keys of the first JSONL record)
    preview: list[dict]  # first few rows/records of the file
    row_count: int  # number of data rows/records in the file
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# Routes
|
|
185
|
+
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the web UI page (``index.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _preview_csv(filepath: Path, limit: int = 5) -> tuple[list[str], list[dict], int]:
    """Read a CSV file and return ``(columns, first `limit` rows, total row count)``."""
    import pandas as pd

    df = pd.read_csv(filepath)
    head = df.head(limit)
    # Empty cells are read as NaN, which is not valid JSON; send them as null instead.
    preview = head.astype(object).where(head.notna(), None).to_dict("records")
    return head.columns.tolist(), preview, len(df)


def _preview_jsonl(filepath: Path, limit: int = 5) -> tuple[list[str], list[dict], int]:
    """Read a JSONL file and return ``(columns, first `limit` records, record count)``.

    Blank lines are skipped. Columns are the keys of the first record that has
    at least one key.

    Raises:
        HTTPException: 400 if a line is not valid JSON or is not a JSON object.
    """
    columns: list[str] = []
    preview: list[dict] = []
    row_count = 0
    with open(filepath, encoding="utf-8") as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                raise HTTPException(status_code=400, detail=f"Invalid JSONL format (line {line_no}): {e}") from e
            if not isinstance(record, dict):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid JSONL format (line {line_no}): expected a JSON object",
                )
            row_count += 1
            if len(preview) < limit:
                preview.append(record)
            if not columns:
                columns = list(record.keys())
    return columns, preview, row_count


@app.post("/upload", response_model=UploadResponse)
async def upload_file(
    file: UploadFile = File(...),
    file_type: str = Form("train"),  # "train" or "eval"
):
    """Upload a training or evaluation data file and return a short preview.

    The file is stored in UPLOAD_DIR as ``<file_type>_<filename>``.

    Args:
        file: The file to upload (CSV or JSONL, detected by extension).
        file_type: Either "train" or "eval" to indicate file purpose.

    Returns:
        UploadResponse with the stored path, column names, up to 5 preview rows
        and the total row count.

    Raises:
        HTTPException: 400 for an invalid ``file_type``, an unsupported file
            extension, or a file that cannot be parsed.
    """
    if file_type not in ("train", "eval"):
        raise HTTPException(status_code=400, detail="file_type must be 'train' or 'eval'")

    # The client controls the filename: keep only its final path component so the
    # upload always lands directly in UPLOAD_DIR, and tolerate a missing name.
    original_filename = Path(file.filename or "").name

    # Reject unsupported formats before anything is written to disk.
    if not original_filename.endswith((".csv", ".jsonl")):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file format. Please upload CSV or JSONL files.",
        )

    # Prefix with the file type so train/eval uploads of the same name don't collide.
    filepath = UPLOAD_DIR / f"{file_type}_{original_filename}"
    filepath.write_bytes(await file.read())

    try:
        if original_filename.endswith(".csv"):
            columns, preview, row_count = _preview_csv(filepath)
        else:
            columns, preview, row_count = _preview_jsonl(filepath)
    except HTTPException:
        # Already a client-facing error; don't re-wrap it as "Failed to parse file".
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to parse file: {e}") from e

    return UploadResponse(
        filename=original_filename,
        filepath=str(filepath),
        columns=columns,
        preview=preview,
        row_count=row_count,
    )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@app.post("/train")
async def train(config: TrainRequest):
    """Start a training run.

    Validates the request, creates a unique output directory under
    PROJECTS_DIR, records the run and launches ``vespaembed.worker`` as a
    separate process in its own session (stdout/stderr inherited, so worker
    logs appear in the server's terminal).

    Returns:
        dict with ``message``, ``run_id`` and the resolved ``output_dir``.

    Raises:
        HTTPException: 400 if a run is already active or the request is
            invalid; 500 if the worker process could not be started.
    """
    # Only one run may be active at a time.
    if get_active_run():
        raise HTTPException(
            status_code=400,
            detail="A training run is already in progress. Stop it first.",
        )

    # Validate data source - must have exactly one of file or HF dataset.
    # Empty strings count as "not provided": Path("") is the current directory,
    # so an empty train_filename would otherwise pass the existence check.
    has_file = bool(config.train_filename)
    has_hf = bool(config.hf_dataset)

    if not has_file and not has_hf:
        raise HTTPException(
            status_code=400,
            detail="Must provide either train_filename or hf_dataset",
        )
    if has_file and has_hf:
        raise HTTPException(
            status_code=400,
            detail="Cannot specify both train_filename and hf_dataset. Choose one.",
        )

    # Data files must be regular files; a directory path would also satisfy exists().
    # NOTE(review): any path readable by the server is accepted, not only files in UPLOAD_DIR.
    if has_file and not Path(config.train_filename).is_file():
        raise HTTPException(status_code=400, detail="Training data file not found")
    if config.eval_filename and not Path(config.eval_filename).is_file():
        raise HTTPException(status_code=400, detail="Evaluation data file not found")

    # Validate hub config
    if config.push_to_hub and not config.hf_username:
        raise HTTPException(
            status_code=400,
            detail="hf_username is required when push_to_hub is enabled",
        )

    # Output directory: ~/.vespaembed/projects/<project_name>, timestamp-suffixed if taken.
    output_dir = PROJECTS_DIR / config.project_name
    if output_dir.exists():
        output_dir = PROJECTS_DIR / f"{config.project_name}-{int(time.time())}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Config handed to the worker, with the resolved output_dir.
    config_dict = config.model_dump()
    config_dict["output_dir"] = str(output_dir)
    # Normalize empty data-source strings to None so the worker sees one "absent" value.
    for key in ("train_filename", "eval_filename", "hf_dataset"):
        config_dict[key] = config_dict[key] or None

    run_id = create_run(
        config=config_dict,
        project_name=config.project_name,
        output_dir=str(output_dir),
    )

    # Start from a clean update stream for this run id.
    update_file = UPDATE_DIR / f"run_{run_id}.jsonl"
    update_file.unlink(missing_ok=True)

    cmd = [
        sys.executable,
        "-m",
        "vespaembed.worker",
        "--run-id",
        str(run_id),
        "--config",
        json.dumps(config_dict),
    ]

    try:
        # stdout/stderr are not captured: they flow to the server's terminal for visibility.
        process = subprocess.Popen(
            cmd,
            stdout=None,
            stderr=None,
            start_new_session=True,
        )
    except OSError as e:
        # Don't leave the freshly created run record stuck before RUNNING.
        update_run_status(run_id, RunStatus.ERROR, error_message=f"Failed to start worker: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start training worker: {e}") from e

    update_run_status(run_id, RunStatus.RUNNING, pid=process.pid)

    return {"message": "Training started", "run_id": run_id, "output_dir": str(output_dir)}
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@app.post("/stop")
async def stop(request: StopRequest):
    """Ask a running training run to stop and mark it STOPPED.

    Sends SIGTERM to the recorded worker PID, if any. The status is updated
    immediately; this endpoint does not wait for the worker to exit.

    Raises:
        HTTPException: 404 if the run does not exist, 400 if it is not RUNNING.
    """
    import signal

    run = get_run(request.run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    if run["status"] != RunStatus.RUNNING.value:
        raise HTTPException(status_code=400, detail="Run is not active")

    pid = run.get("pid")
    if pid:
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # Worker already exited.
        except PermissionError:
            # The PID now belongs to another user's process, so presumably our
            # worker is gone and the PID was reused; don't signal it.
            pass

    update_run_status(request.run_id, RunStatus.STOPPED)

    return {"message": "Training stopped"}
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@app.get("/runs")
async def list_runs():
    """Return every stored training run record."""
    all_runs = get_all_runs()
    return all_runs
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
@app.get("/runs/{run_id}")
async def get_run_details(run_id: int):
    """Return the stored record for one run; 404 if there is no such run."""
    if not (run := get_run(run_id)):
        raise HTTPException(status_code=404, detail="Run not found")
    return run
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
@app.delete("/runs/{run_id}")
async def delete_run_endpoint(run_id: int):
    """Delete a training run, sending SIGTERM to its worker first if it is RUNNING.

    Removes the run via ``delete_run(run_id, delete_files=True)`` and also drops
    the run's update stream file in UPDATE_DIR.

    Raises:
        HTTPException: 404 if the run does not exist.
    """
    import signal

    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # Stop if running
    if run["status"] == RunStatus.RUNNING.value and run.get("pid"):
        try:
            os.kill(run["pid"], signal.SIGTERM)
        except (ProcessLookupError, PermissionError):
            # Already exited, or the PID was reused by another user's process.
            pass

    delete_run(run_id, delete_files=True)
    # The per-run update stream lives under UPDATE_DIR; remove it as well.
    (UPDATE_DIR / f"run_{run_id}.jsonl").unlink(missing_ok=True)
    return {"message": "Run deleted"}
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@app.get("/active_run_id")
async def get_active_run_id():
    """Return ``{"run_id": <id>}`` for the active run, or ``{"run_id": None}``."""
    active = get_active_run()
    if active:
        return {"run_id": active["id"]}
    return {"run_id": None}
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@app.get("/runs/{run_id}/updates")
async def get_run_updates(run_id: int, since_line: int = 0):
    """Poll for updates from a training run.

    Args:
        run_id: The run ID to get updates for
        since_line: Return updates after this line number (0-indexed).
            Client should track this and pass it on subsequent requests.

    Returns:
        updates: List of update objects (progress, log, status, etc.)
        next_line: The line number to use for the next poll request
        has_more: Whether there might be more updates (run still active)
        run_status: The run's current status value

    Raises:
        HTTPException: 404 if the run does not exist.
    """
    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    update_file = UPDATE_DIR / f"run_{run_id}.jsonl"
    updates = []
    next_line = since_line

    if update_file.exists():
        with open(update_file) as f:
            for i, raw in enumerate(f):
                if i < since_line:
                    continue
                line = raw.strip()
                if line:
                    try:
                        updates.append(json.loads(line))
                    except json.JSONDecodeError:
                        if not raw.endswith("\n"):
                            # Unterminated last line: most likely still being written.
                            # Stop here without advancing next_line so it is re-read on
                            # the next poll instead of being skipped forever.
                            break
                        # A complete but malformed line is skipped (best effort).
                next_line = i + 1

    # Check if run is still active
    is_active = run["status"] == RunStatus.RUNNING.value

    return {
        "updates": updates,
        "next_line": next_line,
        "has_more": is_active,
        "run_status": run["status"],
    }
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
@app.get("/api/tasks")
async def get_all_tasks():
    """Return information about every registered training task."""
    # The tasks package is imported only for its side effect of registering tasks.
    import vespaembed.tasks  # noqa: F401
    from vespaembed.core.registry import Registry

    task_info = Registry.get_task_info()
    return task_info
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
@app.get("/api/tasks/{task_name}")
async def get_task(task_name: str):
    """Return information about one training task.

    A ValueError from the registry (presumably an unknown task name) is reported
    as HTTP 404 with the error text as detail.
    """
    # The tasks package is imported only for its side effect of registering tasks.
    import vespaembed.tasks  # noqa: F401
    from vespaembed.core.registry import Registry

    try:
        info = Registry.get_task_info(task_name)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    return info
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@app.get("/runs/{run_id}/metrics")
async def get_run_metrics(run_id: int):
    """Get training metrics from TensorBoard event files.

    Returns metrics like loss, learning_rate, etc. parsed from the tfevents files
    in ``<output_dir>/logs``, as ``{"metrics": {tag: [{"step", "value"}, ...]}}``.
    A leading ``train/`` is dropped from tag names and a leading ``eval/`` becomes
    ``eval_``. NaN/Inf values are returned as null.

    If there is no log directory yet, ``{"metrics": {}, "steps": []}`` is returned.
    If parsing fails (including tensorboard not being importable), the response is
    ``{"metrics": {}, "error": <message>}``.

    Raises:
        HTTPException: 404 if the run does not exist.
    """
    import math

    def sanitize_value(value):
        """Convert NaN/Inf to None for JSON serialization."""
        if math.isnan(value) or math.isinf(value):
            return None
        return value

    def clean_tag(tag: str) -> str:
        """Shorten a tag: drop a leading 'train/', turn a leading 'eval/' into 'eval_'."""
        if tag.startswith("eval/"):
            return "eval_" + tag[len("eval/"):]
        return tag.removeprefix("train/")

    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # output_dir may be missing/None for a run record; treat that as "no logs yet".
    output_dir = run.get("output_dir")
    if not output_dir:
        return {"metrics": {}, "steps": []}

    log_dir = Path(output_dir) / "logs"
    if not log_dir.exists():
        return {"metrics": {}, "steps": []}

    try:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

        ea = EventAccumulator(str(log_dir))
        ea.Reload()

        metrics = {}
        for tag in ea.Tags().get("scalars", []):
            metrics[clean_tag(tag)] = [
                {"step": e.step, "value": sanitize_value(e.value)} for e in ea.Scalars(tag)
            ]

        return {"metrics": metrics}

    except Exception as e:
        # Best effort: if tensorboard parsing fails, return empty metrics plus the reason.
        return {"metrics": {}, "error": str(e)}
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _dir_size(path: Path) -> int:
    """Return the total size in bytes of the regular files under *path*, recursively.

    Entries that disappear or cannot be stat'ed during the scan are skipped, since
    the run may still be writing to the directory.
    """
    total = 0
    for entry in path.rglob("*"):
        try:
            if entry.is_file():
                total += entry.stat().st_size
        except OSError:
            continue
    return total


def _checkpoint_step(path: Path) -> int:
    """Return the step number from a ``checkpoint-<step>`` name, or -1 if it is not numeric."""
    suffix = path.name.rpartition("-")[2]
    return int(suffix) if suffix.isdigit() else -1


@app.get("/runs/{run_id}/artifacts")
async def get_run_artifacts(run_id: int):
    """Get list of downloadable artifacts for a training run.

    Lists, when present:
    - the ``final`` model directory,
    - ``config.json`` / ``training_config.json``,
    - the 3 most recent ``checkpoint-<step>`` directories (by step number),
    - the ``logs`` directory.

    Raises:
        HTTPException: 404 if the run does not exist.
    """
    run = get_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # output_dir may be missing/None for a run record.
    if not run.get("output_dir"):
        return {"artifacts": []}
    output_dir = Path(run["output_dir"])
    if not output_dir.exists():
        return {"artifacts": []}

    artifacts = []

    # Final model
    final_path = output_dir / "final"
    if final_path.is_dir():
        artifacts.append(
            {
                "name": "final",
                "label": "Final Model",
                "category": "model",
                "path": str(final_path),
                "size": _dir_size(final_path),
                "is_directory": True,
            }
        )

    # Config files
    for config_name in ("config.json", "training_config.json"):
        cf = output_dir / config_name
        if cf.is_file():
            artifacts.append(
                {
                    "name": cf.name,
                    "label": "Training Config",
                    "category": "config",
                    "path": str(cf),
                    "size": cf.stat().st_size,
                    "is_directory": False,
                }
            )

    # Checkpoints: order numerically by step (a plain string sort puts
    # checkpoint-1000 before checkpoint-500) and keep the 3 most recent.
    checkpoints = sorted(
        (p for p in output_dir.glob("checkpoint-*") if p.is_dir()),
        key=lambda p: (_checkpoint_step(p), p.name),
    )
    for ckpt in checkpoints[-3:]:
        artifacts.append(
            {
                "name": ckpt.name,
                "label": f"Checkpoint ({ckpt.name})",
                "category": "checkpoint",
                "path": str(ckpt),
                "size": _dir_size(ckpt),
                "is_directory": True,
            }
        )

    # TensorBoard logs
    logs_path = output_dir / "logs"
    if logs_path.is_dir():
        artifacts.append(
            {
                "name": "logs",
                "label": "TensorBoard Logs",
                "category": "logs",
                "path": str(logs_path),
                "size": _dir_size(logs_path),
                "is_directory": True,
            }
        )

    return {"artifacts": artifacts, "output_dir": str(output_dir)}
|