vespaembed 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.2.dist-info/METADATA +325 -0
  43. vespaembed-0.0.2.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1485 @@
// VespaEmbed Web UI
//
// Browser-side controller for the training dashboard: run list, live log and
// metric polling, the loss chart, and the multi-step "New Training" wizard.

// ----- Run selection -------------------------------------------------------
let runs = [];
let activeRunId = null;
let selectedRunId = null;

// ----- Live polling --------------------------------------------------------
let pollInterval = null; // setInterval handle while a run is being polled
let pollLine = 0; // `since_line` cursor sent to /runs/<id>/updates

// ----- Chart / metrics -----------------------------------------------------
let chart = null; // Chart.js instance created by initChart()
let lossHistory = []; // {step, loss} points appended by updateChart()
let metricsData = {}; // cached metrics (original note: read from TensorBoard)
let currentMetric = 'loss'; // metric currently shown in the chart

// ----- Training form -------------------------------------------------------
let currentDataSource = 'file'; // active data tab id: 'file' or 'huggingface'
let tasksData = []; // task definitions cached from /api/tasks

// ----- DOM elements (looked up once, when the script runs) -----------------
const newTrainingBtn = document.getElementById('new-training-btn');
const newTrainingModal = document.getElementById('new-training-modal');
const closeNewTraining = document.getElementById('close-new-training');
const trainForm = document.getElementById('train-form');
const runList = document.getElementById('run-list');
const projectSummary = document.getElementById('project-summary');
const logContent = document.getElementById('log-content');
const chartPlaceholder = document.getElementById('chart-placeholder');
const artifactsBtn = document.getElementById('artifacts-btn');
const artifactsModal = document.getElementById('artifacts-modal');
const closeArtifactsModal = document.getElementById('close-artifacts-modal');
28
+
// Bootstrap once the DOM is ready: build the chart, load tasks and then runs,
// and finally attach every UI handler in a fixed order.
document.addEventListener('DOMContentLoaded', async () => {
  initChart();

  // Tasks are fetched before runs.
  await loadTasks();
  // `true`: auto-select the latest run on this first load.
  await loadRuns(true);

  const setupSteps = [
    setupEventListeners,
    setupFileUploads,
    setupTabs,
    setupHubToggle,
    setupLoraToggle,
    setupUnslothToggle,
    setupMatryoshkaToggle,
    setupTaskSelector,
    setupWizard,
    setupArtifactsModal,
  ];
  for (const setup of setupSteps) {
    setup();
  }
});
45
+
// Build a random 8-character project name from lowercase ASCII letters and digits.
function generateProjectName() {
  const alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789';
  const pickChar = () => alphabet.charAt(Math.floor(Math.random() * alphabet.length));
  return Array.from({ length: 8 }, pickChar).join('');
}
55
+
/**
 * Fetch the task definitions from `/api/tasks`, cache them in `tasksData`
 * and rebuild the task dropdown.
 *
 * Non-2xx responses and non-array payloads are treated as failures: they are
 * logged and `tasksData` keeps its previous (array) value, so later lookups
 * such as `tasksData.find(...)` keep working.
 */
async function loadTasks() {
  try {
    const response = await fetch('/api/tasks');
    if (!response.ok) {
      throw new Error(`HTTP ${response.status}`);
    }
    const tasks = await response.json();
    if (!Array.isArray(tasks)) {
      throw new Error('Unexpected /api/tasks response: expected an array');
    }
    tasksData = tasks;
    populateTaskDropdown();
  } catch (error) {
    console.error('Failed to load tasks:', error);
  }
}
66
+
/**
 * Rebuild the `#task` <select> from `tasksData` (value = task name, label =
 * upper-cased name) and apply the first task's description and defaults.
 *
 * Options are created as DOM nodes rather than an HTML string, so a task name
 * containing `<`, `&` or quotes is shown literally instead of being parsed.
 */
function populateTaskDropdown() {
  const taskSelect = document.getElementById('task');
  const options = tasksData.map((task) => {
    const option = document.createElement('option');
    option.value = task.name;
    option.textContent = task.name.toUpperCase();
    return option;
  });
  taskSelect.replaceChildren(...options);

  // Set initial task description and defaults
  if (tasksData.length > 0) {
    updateTaskUI(tasksData[0]);
  }
}
79
+
// Re-apply the task UI whenever a different task is picked in #task.
// Selections that match no cached task are ignored.
function setupTaskSelector() {
  const select = document.getElementById('task');
  const onTaskChange = () => {
    const chosen = tasksData.find((candidate) => candidate.name === select.value);
    if (!chosen) return;
    updateTaskUI(chosen);
  };
  select.addEventListener('change', onTaskChange);
}
90
+
/**
 * Apply a task definition (one entry of `tasksData`) to the training form.
 *
 * Updates, in order: the description, the required-column tags, the
 * loss-variant dropdown, the sample-data preview, Matryoshka visibility
 * (hidden for TSDAE), hyperparameter defaults, and the task-specific
 * parameter inputs (`#task_param_<key>`).
 *
 * Task-provided strings are written through `textContent` / `setAttribute`,
 * never interpolated into HTML, so `<`, `&` or quotes render literally and
 * cannot break the markup.
 *
 * @param {object} task - task definition as returned by /api/tasks.
 */
function updateTaskUI(task) {
  const descEl = document.getElementById('task-description');
  if (descEl) {
    descEl.textContent = task.description;
  }

  // One tag per required column. Cleared when the task lists none, so the
  // previous task's columns are not left on screen.
  const columnsEl = document.getElementById('required-columns-list');
  if (columnsEl) {
    const tags = (task.expected_columns || []).map((col) => {
      const tag = document.createElement('span');
      tag.className = 'required-column-tag';
      tag.textContent = col;
      return tag;
    });
    columnsEl.replaceChildren(...tags);
  }

  updateLossVariantUI(task);
  updateSampleData(task);

  // Matryoshka is not offered for TSDAE.
  const matryoshkaSection = document.getElementById('matryoshka-section');
  if (matryoshkaSection) {
    if (task.name === 'tsdae') {
      matryoshkaSection.style.display = 'none';
      // Also uncheck and collapse it: the submit handler reads this checkbox,
      // so a hidden-but-checked box would still send Matryoshka dims.
      const enabledBox = document.getElementById('matryoshka_enabled');
      const fields = document.getElementById('matryoshka-fields');
      if (enabledBox) enabledBox.checked = false;
      if (fields) fields.style.display = 'none';
    } else {
      matryoshkaSection.style.display = 'block';
    }
  }

  // Hyperparameter defaults supplied by the task.
  const hyper = task.hyperparameters;
  if (hyper) {
    setValueIfExists('epochs', hyper.epochs);
    setValueIfExists('batch_size', hyper.batch_size);
    setValueIfExists('learning_rate', hyper.learning_rate);
    setValueIfExists('warmup_ratio', hyper.warmup_ratio);
    setValueIfExists('weight_decay', hyper.weight_decay);
    setValueIfExists('eval_steps', hyper.eval_steps);
    setValueIfExists('save_steps', hyper.save_steps);
    setValueIfExists('logging_steps', hyper.logging_steps);
    setValueIfExists('gradient_accumulation_steps', hyper.gradient_accumulation_steps);
    setValueIfExists('optimizer', hyper.optimizer || 'adamw_torch');
    setValueIfExists('scheduler', hyper.scheduler || 'linear');

    // bf16 takes precedence over fp16; neither means full precision.
    let precision = 'fp32';
    if (hyper.bf16) {
      precision = 'bf16';
    } else if (hyper.fp16) {
      precision = 'fp16';
    }
    setValueIfExists('precision', precision);
  }

  // Task-specific parameter inputs, read back later via their `task_param_<key>` ids.
  const paramsContainer = document.getElementById('task-specific-params');
  if (!paramsContainer) return;
  paramsContainer.replaceChildren();

  const params = task.task_specific_params;
  if (!params || Object.keys(params).length === 0) return;

  const section = document.createElement('div');
  section.className = 'form-section';
  const heading = document.createElement('label');
  heading.className = 'section-label';
  heading.textContent = 'Task Settings';
  section.append(heading);

  for (const [key, config] of Object.entries(params)) {
    const field = document.createElement('div');
    field.className = 'form-field';

    const label = document.createElement('label');
    label.textContent = config.label;

    const input = document.createElement('input');
    input.type = config.type;
    input.id = `task_param_${key}`;
    input.name = `task_param_${key}`;
    // `== null` keeps falsy defaults such as 0, which `|| ''` used to drop.
    input.setAttribute('value', config.default == null ? '' : String(config.default));

    field.append(label, input);

    if (config.description) {
      const hint = document.createElement('span');
      hint.className = 'field-hint';
      hint.textContent = config.description;
      field.append(hint);
    }
    section.append(field);
  }
  paramsContainer.append(section);
}
175
+
/**
 * Show the task's first sample row in `#sample-data-content` as
 * `column: "value"` pairs. String values longer than 50 characters are
 * truncated with "...".
 *
 * The preview is cleared when the task has no sample data, so the previous
 * task's example is not left on screen. Values are written as text, never
 * parsed as HTML. When the task lists no `expected_columns`, the sample's own
 * keys are shown instead.
 *
 * @param {object} task - task definition from /api/tasks.
 */
function updateSampleData(task) {
  const container = document.getElementById('sample-data-content');
  if (!container) return;

  const samples = task.sample_data;
  if (!samples || samples.length === 0) {
    container.replaceChildren();
    return;
  }

  const sample = samples[0];
  const columns = task.expected_columns || Object.keys(sample);

  const rows = columns.map((col) => {
    const value = sample[col];
    const displayValue = typeof value === 'string' && value.length > 50
      ? `${value.substring(0, 50)}...`
      : value;

    const row = document.createElement('span');
    row.className = 'sample-data-row';
    const keyEl = document.createElement('span');
    keyEl.className = 'sample-data-key';
    keyEl.textContent = `${col}:`;
    const valueEl = document.createElement('span');
    valueEl.className = 'sample-data-value';
    valueEl.textContent = `"${displayValue}"`;
    row.append(keyEl, ' ', valueEl);
    return row;
  });

  container.replaceChildren(...rows);
}
197
+
// Configure the loss-variant field for `task`:
//  - tasks with loss_options: selectable list with the default pre-selected,
//    plus a task-specific hint;
//  - TSDAE without options: one fixed, disabled entry;
//  - anything else: the field is hidden.
function updateLossVariantUI(task) {
  const fieldEl = document.getElementById('loss-variant-field');
  const selectEl = document.getElementById('loss_variant');
  const hintEl = document.getElementById('loss-variant-hint');
  if (!fieldEl || !selectEl) return;

  const hasOptions = Boolean(task.loss_options && task.loss_options.length > 0);
  if (!hasOptions && task.name !== 'tsdae') {
    fieldEl.style.display = 'none';
    return;
  }

  let hint;
  if (hasOptions) {
    const optionTags = [];
    for (const loss of task.loss_options) {
      const isDefault = loss === task.default_loss;
      const suffix = isDefault ? ' (default)' : '';
      optionTags.push(`<option value="${loss}" ${isDefault ? 'selected' : ''}>${formatLossLabel(loss)}${suffix}</option>`);
    }
    selectEl.innerHTML = optionTags.join('');
    selectEl.disabled = false;

    switch (task.name) {
      case 'pairs':
      case 'triplets':
      case 'matryoshka':
        hint = 'MNR uses in-batch negatives. GIST uses a guide model for better negatives.';
        break;
      case 'similarity':
        hint = 'CoSENT and AnglE often outperform Cosine on STS benchmarks.';
        break;
      default:
        hint = '';
    }
  } else {
    // TSDAE has a single fixed loss: shown for information, not selectable.
    selectEl.innerHTML = '<option value="">Denoising Auto-Encoder Loss</option>';
    selectEl.disabled = true;
    hint = 'TSDAE uses unsupervised denoising to learn embeddings from unlabeled text.';
  }

  if (hintEl) {
    hintEl.textContent = hint;
  }
  fieldEl.style.display = 'block';
}
243
+
/**
 * Format a loss-variant identifier for display in the dropdown.
 *
 * Known variants map to a descriptive label; any other value is upper-cased.
 * A Map is used instead of a plain object, so identifiers that collide with
 * Object.prototype members (e.g. "constructor", "toString") fall through to
 * the upper-case form instead of returning a function.
 *
 * @param {string} loss - loss variant id, e.g. "mnr".
 * @returns {string} display label.
 */
function formatLossLabel(loss) {
  const labels = new Map([
    ['mnr', 'MNR (Multiple Negatives Ranking)'],
    ['mnr_symmetric', 'MNR Symmetric'],
    ['gist', 'GIST (Guided In-Sample Triplet)'],
    ['cached_mnr', 'Cached MNR'],
    ['cached_gist', 'Cached GIST'],
    ['cosine', 'Cosine Similarity'],
    ['cosent', 'CoSENT'],
    ['angle', 'AnglE'],
  ]);
  return labels.get(loss) || String(loss).toUpperCase();
}
258
+
// Assign `value` to the element with the given id. Nothing happens when the
// element is missing or the value is `undefined` (null is still assigned).
function setValueIfExists(id, value) {
  const target = document.getElementById(id);
  const shouldAssign = Boolean(target) && value !== undefined;
  if (shouldAssign) target.value = value;
}
266
+
/**
 * Check the first <input> named `name` whose `value` attribute equals `value`.
 *
 * Matching is done in JavaScript rather than by building a CSS selector, so
 * names or values containing quotes, brackets or backslashes cannot produce an
 * invalid selector (which made `querySelector` throw). Does nothing when no
 * input matches.
 *
 * @param {string} name - radio group name.
 * @param {string|number} value - value attribute to match (compared as a string).
 */
function setRadioChecked(name, value) {
  const wanted = String(value);
  const radio = Array.from(document.getElementsByName(name))
    .find((el) => el.localName === 'input' && el.getAttribute('value') === wanted);
  if (radio) {
    radio.checked = true;
  }
}
274
+
// Collect the non-empty values of the selected task's task-specific inputs
// (`#task_param_<key>`) into an object keyed by parameter name. Values stay
// strings, exactly as read from the inputs.
function getTaskSpecificParams() {
  const collected = {};
  const task = tasksData.find((t) => t.name === document.getElementById('task').value);
  if (!task || !task.task_specific_params) return collected;

  Object.keys(task.task_specific_params).forEach((key) => {
    const input = document.getElementById(`task_param_${key}`);
    if (input && input.value) collected[key] = input.value;
  });
  return collected;
}
289
+
// Polling for updates
let pollingRunId = null; // run whose updates are currently polled (null when idle)

/**
 * Poll `/runs/<runId>/updates` every 2 s, feeding new updates to the UI while
 * that run is selected and advancing the `pollLine` cursor.
 *
 * `pollLine` is expected to have been positioned already (by
 * loadHistoricalUpdates, per the original note), so it is not reset here.
 * When the server reports `has_more: false`, polling stops, the run list is
 * reloaded, and metrics are refreshed one last time if the run is still
 * selected.
 *
 * The tick uses the captured `runId`, not the shared `pollingRunId`. The old
 * code read `pollingRunId` after stopPolling() had nulled it, so the final
 * metrics refresh never ran.
 *
 * @param {string} runId - id of the run to poll.
 */
function startPolling(runId) {
  stopPolling();
  // pollLine is set by loadHistoricalUpdates, so don't reset it here
  pollingRunId = runId;

  let requestInFlight = false;
  // True once polling was restarted for a different run; a response that
  // arrives afterwards belongs to the old run and must not touch shared state.
  const superseded = () => pollingRunId !== null && pollingRunId !== runId;

  pollInterval = setInterval(async () => {
    // Skip this tick while the previous request is pending. Otherwise two
    // requests would read from the same pollLine and replay the same updates.
    if (requestInFlight) return;
    requestInFlight = true;
    try {
      const response = await fetch(`/runs/${runId}/updates?since_line=${pollLine}`);
      if (superseded()) return;
      if (!response.ok) {
        stopPolling();
        return;
      }

      const data = await response.json();
      if (superseded()) return;

      // Only update UI if we're viewing the run being polled
      if (selectedRunId === runId) {
        (data.updates || []).forEach((update) => handleUpdate(update));
        // Refresh metrics from TensorBoard files
        await loadMetrics(runId);
        if (superseded()) return;
      }

      pollLine = data.next_line;

      // Stop polling if run is no longer active
      if (!data.has_more) {
        stopPolling();
        loadRuns(); // Refresh run list to show final status
        // Final metrics refresh if still viewing this run
        if (selectedRunId === runId) {
          await loadMetrics(runId);
        }
      }
    } catch (error) {
      console.error('Polling error:', error);
    } finally {
      requestInFlight = false;
    }
  }, 2000); // Poll every 2 seconds
}
335
+
// Cancel the active polling interval (if any) and forget which run was polled.
function stopPolling() {
  const handle = pollInterval;
  if (handle) {
    pollInterval = null;
    clearInterval(handle);
  }
  pollingRunId = null;
}
343
+
// Route one server update to its UI handler by `type`
// (log / progress / status / complete / error); other types are ignored.
function handleUpdate(data) {
  const routes = {
    log: (update) => appendLog(update.message),
    progress: (update) => updateProgress(update),
    status: (update) => updateRunStatus(update.run_id, update.status),
    complete: (update) => handleTrainingComplete(update),
    error: (update) => handleTrainingError(update),
  };
  const route = Object.prototype.hasOwnProperty.call(routes, data.type) ? routes[data.type] : null;
  if (route) {
    route(data);
  }
}
363
+
// Chart
// Create the Chart.js line chart on the #loss-chart canvas and store it in
// `chart`. The single dataset starts empty and is filled by updateChart().
function initChart() {
  const accent = '#22c55e';
  const makeAxis = () => ({
    display: true,
    grid: { color: 'rgba(255,255,255,0.05)' },
    ticks: { color: '#999' },
  });

  const lossDataset = {
    label: 'Loss',
    data: [],
    borderColor: accent,
    backgroundColor: 'rgba(34, 197, 94, 0.1)', // accent colour at 10% opacity
    borderWidth: 2,
    fill: true,
    tension: 0.3,
    pointRadius: 3,
    pointBackgroundColor: accent,
    pointBorderColor: accent,
    pointHoverRadius: 6,
    pointHoverBackgroundColor: '#fff',
    pointHoverBorderColor: accent,
    pointHoverBorderWidth: 2,
  };

  const tooltip = {
    backgroundColor: 'rgba(0, 0, 0, 0.8)',
    titleColor: '#fff',
    bodyColor: '#fff',
    borderColor: accent,
    borderWidth: 1,
    displayColors: false,
    callbacks: {
      title: (items) => `Step ${items[0].label}`,
      label: (item) => `${item.dataset.label}: ${item.parsed.y.toFixed(4)}`,
    },
  };

  const canvas = document.getElementById('loss-chart');
  chart = new Chart(canvas.getContext('2d'), {
    type: 'line',
    data: { labels: [], datasets: [lossDataset] },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      interaction: { mode: 'nearest', axis: 'x', intersect: false },
      plugins: { legend: { display: false }, tooltip },
      scales: { x: makeAxis(), y: makeAxis() },
    },
  });
}
426
+
// Append one (step, loss) point to `lossHistory` and to the chart, redraw,
// and hide the "no data" placeholder.
function updateChart(step, loss) {
  lossHistory.push({ step, loss });
  const { labels, datasets: [lossSeries] } = chart.data;
  labels.push(step);
  lossSeries.data.push(loss);
  chart.update('none'); // Chart.js update mode 'none': redraw without animation
  chartPlaceholder.style.display = 'none';
}
434
+
// Drop all plotted points (and the loss history), redraw the empty chart and
// show the placeholder again.
function resetChart() {
  lossHistory = [];
  const { data } = chart;
  data.labels = [];
  data.datasets[0].data = [];
  chart.update();
  chartPlaceholder.style.display = 'flex';
}
442
+
// Reset the New Training form to its initial state: clear uploads and Hub
// dataset fields, restore LoRA/Unsloth/Matryoshka/model defaults, collapse
// the optional panels, then re-apply the selected task's defaults. The task
// step runs last so it has the final say on task-dependent sections.
function resetFormToTaskDefaults() {
  const byId = (id) => document.getElementById(id);

  // Inputs/selects and the value each starts with.
  const defaultValues = {
    train_filename: '',
    eval_filename: '',
    hf_dataset: '',
    hf_subset: '',
    hf_train_split: 'train',
    hf_eval_split: '',
    hf_username: '',
    lora_r: '64',
    lora_alpha: '128',
    lora_dropout: '0.1',
    lora_target_preset: 'query, key, value, dense',
    lora_target_modules: 'query, key, value, dense',
    max_seq_length: '', // empty = auto-detect
    unsloth_save_method: 'merged_16bit',
    matryoshka_dims: '768,512,256,128,64',
  };
  // Checkboxes that start unchecked.
  const uncheckedBoxes = ['push_to_hub', 'lora_enabled', 'gradient_checkpointing', 'unsloth_enabled', 'matryoshka_enabled'];
  // Collapsible option panels that start hidden.
  const hiddenPanels = ['hub-fields', 'lora-fields', 'unsloth-fields', 'matryoshka-fields'];

  for (const [id, value] of Object.entries(defaultValues)) {
    byId(id).value = value;
  }
  for (const id of uncheckedBoxes) {
    byId(id).checked = false;
  }
  for (const id of hiddenPanels) {
    byId(id).style.display = 'none';
  }
  for (const kind of ['train', 'eval']) {
    byId(`${kind}_file_info`).textContent = 'CSV or JSONL';
    byId(`${kind}-upload`).classList.remove('uploaded');
  }

  const taskSelect = byId('task');
  const task = tasksData.find((t) => t.name === taskSelect.value);
  if (task) {
    updateTaskUI(task);
  }
}
494
+
/**
 * Wire up the top-level controls: the New Training modal (open, close button,
 * backdrop click, Escape), form submit, the refresh/stop/delete buttons, and
 * the metric selector.
 */
function setupEventListeners() {
  // Opening the modal: fresh random project name, form reset to the selected
  // task's defaults, wizard back on step 1.
  newTrainingBtn.addEventListener('click', () => {
    document.getElementById('project_name').value = generateProjectName();
    resetFormToTaskDefaults();
    resetWizard();
    newTrainingModal.style.display = 'flex';
  });

  const hideNewTrainingModal = () => {
    newTrainingModal.style.display = 'none';
  };
  closeNewTraining.addEventListener('click', hideNewTrainingModal);
  // Only clicks on the backdrop itself (not inside the dialog) close it.
  newTrainingModal.addEventListener('click', (e) => {
    if (e.target === newTrainingModal) {
      hideNewTrainingModal();
    }
  });

  trainForm.addEventListener('submit', handleTrainSubmit);

  // Call loadRuns with no arguments. On startup loadRuns(true) means "auto-select
  // the latest run", so passing the (truthy) click event straight through
  // would jump to the newest run on every refresh.
  // NOTE(review): confirm against loadRuns' signature.
  document.getElementById('refresh-btn').addEventListener('click', () => loadRuns());

  document.getElementById('stop-btn').addEventListener('click', stopTraining);
  document.getElementById('delete-btn').addEventListener('click', deleteRun);

  // Escape closes the New Training modal.
  document.addEventListener('keydown', (e) => {
    if (e.key === 'Escape') {
      hideNewTrainingModal();
    }
  });

  document.getElementById('metric-select').addEventListener('change', (e) => {
    currentMetric = e.target.value;
    updateChartWithMetric(currentMetric);
  });
}
546
+
// Tabs: clicking a `.tab` makes it and its `#tab-<data-tab>` panel the only
// active ones, and records its id in `currentDataSource`.
function setupTabs() {
  const allTabs = document.querySelectorAll('.tab');

  const activate = (clicked) => {
    allTabs.forEach((t) => t.classList.remove('active'));
    clicked.classList.add('active');

    const tabId = clicked.dataset.tab;
    document.querySelectorAll('.tab-content').forEach((panel) => panel.classList.remove('active'));
    document.getElementById(`tab-${tabId}`).classList.add('active');

    currentDataSource = tabId;
  };

  for (const tab of allTabs) {
    tab.addEventListener('click', () => activate(tab));
  }
}
568
+
// Hub Push Toggle: show the Hub options only while "push to hub" is checked.
function setupHubToggle() {
  const checkbox = document.getElementById('push_to_hub');
  const panel = document.getElementById('hub-fields');

  function syncPanel() {
    panel.style.display = checkbox.checked ? 'block' : 'none';
  }
  checkbox.addEventListener('change', syncPanel);
}
578
+
// LoRA Toggle: show the LoRA panel while LoRA is enabled, and copy the chosen
// target-module preset into the editable target-modules input.
function setupLoraToggle() {
  const byId = (id) => document.getElementById(id);
  const enabledBox = byId('lora_enabled');
  const fieldsPanel = byId('lora-fields');

  function onLoraToggle() {
    fieldsPanel.style.display = enabledBox.checked ? 'block' : 'none';
  }
  enabledBox.addEventListener('change', onLoraToggle);

  const presetSelect = byId('lora_target_preset');
  const targetInput = byId('lora_target_modules');

  function onPresetChange() {
    targetInput.value = presetSelect.value;
  }
  presetSelect.addEventListener('change', onPresetChange);
}
596
+
// Unsloth Toggle: show the Unsloth options only while Unsloth is enabled.
function setupUnslothToggle() {
  const [toggle, panel] = ['unsloth_enabled', 'unsloth-fields'].map((id) => document.getElementById(id));
  toggle.addEventListener('change', () => {
    const visible = toggle.checked;
    panel.style.display = visible ? 'block' : 'none';
  });
}
606
+
// Matryoshka Toggle: show the dimension settings only while Matryoshka is enabled.
function setupMatryoshkaToggle() {
  const toggle = document.getElementById('matryoshka_enabled');
  const panel = document.getElementById('matryoshka-fields');

  const showPanelWhenChecked = () => {
    if (toggle.checked) {
      panel.style.display = 'block';
    } else {
      panel.style.display = 'none';
    }
  };
  toggle.addEventListener('change', showPanelWhenChecked);
}
616
+
// Wizard Navigation
// Wizard state: the step currently shown (1-based) and the total step count.
let currentWizardStep = 1;
const totalWizardSteps = 3;

/**
 * Attach the wizard's Next/Back handlers. Next advances only when
 * validateCurrentStep() passes. Range checks on the target step are left to
 * goToStep().
 */
function setupWizard() {
  const nextBtn = document.getElementById('wizard-next');
  const backBtn = document.getElementById('wizard-back');

  nextBtn.addEventListener('click', () => {
    if (validateCurrentStep()) {
      goToStep(currentWizardStep + 1);
    }
  });

  backBtn.addEventListener('click', () => {
    goToStep(currentWizardStep - 1);
  });
}
636
+
// Move the wizard to `step`; values outside 1..totalWizardSteps are ignored.
// Hides the current panel, marks its indicator completed when moving forward,
// shows the target panel as active, and toggles Back/Next/Start for the new
// position.
function goToStep(step) {
  const outOfRange = step < 1 || step > totalWizardSteps;
  if (outOfRange) return;

  const panelFor = (n) => document.getElementById(`wizard-step-${n}`);
  const indicatorFor = (n) => document.querySelector(`.wizard-step[data-step="${n}"]`);

  const leaving = currentWizardStep;
  panelFor(leaving).style.display = 'none';
  const leavingIndicator = indicatorFor(leaving);
  leavingIndicator.classList.remove('active');
  if (step > leaving) {
    leavingIndicator.classList.add('completed');
  }

  currentWizardStep = step;
  panelFor(step).style.display = 'block';
  const arrivingIndicator = indicatorFor(step);
  arrivingIndicator.classList.add('active');
  arrivingIndicator.classList.remove('completed');

  document.getElementById('wizard-back').style.display = step > 1 ? 'block' : 'none';
  document.getElementById('wizard-next').style.display = step < totalWizardSteps ? 'block' : 'none';
  document.getElementById('start-btn').style.display = step === totalWizardSteps ? 'block' : 'none';
}
664
+
// Validate the wizard's current step before advancing. Only step 1 has rules:
// a well-formed project name and a training data source. On the first failure
// the user is alerted and false is returned; otherwise true.
function validateCurrentStep() {
  if (currentWizardStep !== 1) return true;

  const fail = (message) => {
    alert(message);
    return false;
  };

  const projectName = document.getElementById('project_name').value.trim();
  if (!projectName) return fail('Please enter a project name');
  if (!/^[a-zA-Z0-9][a-zA-Z0-9-]*$/.test(projectName)) {
    return fail('Project name must start with alphanumeric and contain only alphanumeric characters and hyphens');
  }

  if (currentDataSource === 'file') {
    if (!document.getElementById('train_filename').value) {
      return fail('Please upload training data');
    }
  } else if (!document.getElementById('hf_dataset').value.trim()) {
    return fail('Please enter a HuggingFace dataset name');
  }
  return true;
}
695
+
// Return the wizard to step 1: only the first indicator is active (none are
// completed), only the first panel is visible, and only Next is shown.
function resetWizard() {
  currentWizardStep = 1;

  let position = 0;
  for (const indicator of document.querySelectorAll('.wizard-step')) {
    indicator.classList.remove('active', 'completed');
    if (position === 0) indicator.classList.add('active');
    position += 1;
  }

  for (let n = 1; n <= totalWizardSteps; n += 1) {
    const panel = document.getElementById(`wizard-step-${n}`);
    panel.style.display = n === 1 ? 'block' : 'none';
  }

  const buttonDisplay = { 'wizard-back': 'none', 'wizard-next': 'block', 'start-btn': 'none' };
  for (const [id, display] of Object.entries(buttonDisplay)) {
    document.getElementById(id).style.display = display;
  }
}
715
+
// File Uploads
// Wire the training and evaluation drop zones. Each kind uses the element ids
// `<kind>-upload`, `<kind>_file`, `<kind>_file_info` and `<kind>_filename`.
function setupFileUploads() {
  for (const kind of ['train', 'eval']) {
    setupSingleUpload({
      uploadBox: document.getElementById(`${kind}-upload`),
      fileInput: document.getElementById(`${kind}_file`),
      fileInfo: document.getElementById(`${kind}_file_info`),
      hiddenInput: document.getElementById(`${kind}_filename`),
      fileType: kind,
    });
  }
}
736
+
// Make one drop zone interactive:
//  - click opens the hidden file picker;
//  - drag-over highlights the border;
//  - dropping or picking a file uploads the first file via handleFileUpload().
function setupSingleUpload({ uploadBox, fileInput, fileInfo, hiddenInput, fileType }) {
  const upload = (file) => handleFileUpload(file, fileType, uploadBox, fileInfo, hiddenInput);
  const setBorder = (color) => {
    uploadBox.style.borderColor = color;
  };

  uploadBox.addEventListener('click', () => fileInput.click());

  uploadBox.addEventListener('dragover', (event) => {
    event.preventDefault(); // cancelling dragover lets the browser accept a drop here
    setBorder('var(--heather-light)');
  });

  uploadBox.addEventListener('dragleave', () => setBorder(''));

  uploadBox.addEventListener('drop', (event) => {
    event.preventDefault();
    setBorder('');
    const dropped = event.dataTransfer.files;
    if (!dropped.length) return;
    fileInput.files = dropped;
    upload(dropped[0]);
  });

  fileInput.addEventListener('change', () => {
    if (fileInput.files.length) upload(fileInput.files[0]);
  });
}
764
+
/**
 * Upload a data file to `/upload` (multipart fields `file` and `file_type`)
 * and record the result in the form.
 *
 * On success, `hiddenInput` receives the server-side `filepath`, `fileInfo`
 * shows "<name> (<row_count> rows)", and the drop zone gets the `uploaded`
 * class. On any failure the user is alerted and the previous upload state is
 * left untouched.
 *
 * @param {File} file - file chosen or dropped by the user.
 * @param {string} fileType - 'train' or 'eval'; sent as `file_type`.
 * @param {HTMLElement} uploadBox - drop zone element.
 * @param {HTMLElement} fileInfo - element showing the file summary.
 * @param {HTMLInputElement} hiddenInput - hidden input holding the server path.
 */
async function handleFileUpload(file, fileType, uploadBox, fileInfo, hiddenInput) {
  // Build a readable message from an error response. The body may not be JSON
  // (e.g. an HTML error page from a proxy), and `detail` may be a plain string
  // or a structured value such as a list of validation errors.
  const readErrorMessage = async (response) => {
    let body = null;
    try {
      body = await response.json();
    } catch (parseError) {
      body = null;
    }
    const detail = body && body.detail;
    if (typeof detail === 'string' && detail) return detail;
    if (detail) return JSON.stringify(detail);
    return `Upload failed (HTTP ${response.status})`;
  };

  const formData = new FormData();
  formData.append('file', file);
  formData.append('file_type', fileType);

  try {
    const response = await fetch('/upload', {
      method: 'POST',
      body: formData,
    });

    if (!response.ok) {
      throw new Error(await readErrorMessage(response));
    }

    const data = await response.json();
    hiddenInput.value = data.filepath;
    fileInfo.textContent = `${file.name} (${data.row_count} rows)`;
    uploadBox.classList.add('uploaded');
  } catch (error) {
    console.error('Upload error:', error);
    alert(`Failed to upload file: ${error.message}`);
  }
}
790
+
791
+ // Training
792
+ async function handleTrainSubmit(e) {
793
+ e.preventDefault();
794
+
795
+ const projectName = document.getElementById('project_name').value.trim();
796
+
797
+ // Validate project name
798
+ if (!projectName) {
799
+ alert('Please enter a project name');
800
+ return;
801
+ }
802
+
803
+ if (!/^[a-zA-Z0-9][a-zA-Z0-9-]*$/.test(projectName)) {
804
+ alert('Project name must start with alphanumeric and contain only alphanumeric characters and hyphens');
805
+ return;
806
+ }
807
+
808
+ // Get selected precision mode
809
+ const precision = document.getElementById('precision').value || 'fp32';
810
+
811
+ // Get loss variant (only if the field is visible/applicable)
812
+ const lossVariantField = document.getElementById('loss-variant-field');
813
+ const lossVariant = lossVariantField && lossVariantField.style.display !== 'none'
814
+ ? document.getElementById('loss_variant').value
815
+ : null;
816
+
817
+ const formData = {
818
+ project_name: projectName,
819
+ task: document.getElementById('task').value,
820
+ base_model: document.getElementById('base_model').value,
821
+ loss_variant: lossVariant,
822
+ epochs: parseInt(document.getElementById('epochs').value),
823
+ batch_size: parseInt(document.getElementById('batch_size').value),
824
+ learning_rate: parseFloat(document.getElementById('learning_rate').value),
825
+
826
+ // Advanced settings
827
+ warmup_ratio: parseFloat(document.getElementById('warmup_ratio').value),
828
+ weight_decay: parseFloat(document.getElementById('weight_decay').value),
829
+ gradient_accumulation_steps: parseInt(document.getElementById('gradient_accumulation_steps').value),
830
+ logging_steps: parseInt(document.getElementById('logging_steps').value),
831
+ eval_steps: parseInt(document.getElementById('eval_steps').value),
832
+ save_steps: parseInt(document.getElementById('save_steps').value),
833
+ fp16: precision === 'fp16',
834
+ bf16: precision === 'bf16',
835
+ optimizer: document.getElementById('optimizer').value,
836
+ scheduler: document.getElementById('scheduler').value,
837
+
838
+ // LoRA settings
839
+ lora_enabled: document.getElementById('lora_enabled').checked,
840
+ lora_r: parseInt(document.getElementById('lora_r').value),
841
+ lora_alpha: parseInt(document.getElementById('lora_alpha').value),
842
+ lora_dropout: parseFloat(document.getElementById('lora_dropout').value),
843
+ lora_target_modules: document.getElementById('lora_target_modules').value.trim(),
844
+
845
+ // Model settings
846
+ max_seq_length: document.getElementById('max_seq_length').value
847
+ ? parseInt(document.getElementById('max_seq_length').value)
848
+ : null, // null = auto-detect from model
849
+ gradient_checkpointing: document.getElementById('gradient_checkpointing').checked,
850
+
851
+ // Unsloth settings
852
+ unsloth_enabled: document.getElementById('unsloth_enabled').checked,
853
+ unsloth_save_method: document.getElementById('unsloth_save_method').value,
854
+
855
+ // Matryoshka settings (only send dims if enabled)
856
+ matryoshka_dims: document.getElementById('matryoshka_enabled').checked
857
+ ? document.getElementById('matryoshka_dims').value.trim()
858
+ : null,
859
+
860
+ // Hub settings
861
+ push_to_hub: document.getElementById('push_to_hub').checked,
862
+ hf_username: document.getElementById('hf_username').value.trim() || null,
863
+
864
+ // Task-specific parameters
865
+ ...getTaskSpecificParams(),
866
+ };
867
+
868
+ // Add data source based on selected tab
869
+ if (currentDataSource === 'file') {
870
+ formData.train_filename = document.getElementById('train_filename').value;
871
+ formData.eval_filename = document.getElementById('eval_filename').value || null;
872
+
873
+ if (!formData.train_filename) {
874
+ alert('Please upload training data');
875
+ return;
876
+ }
877
+ } else {
878
+ formData.hf_dataset = document.getElementById('hf_dataset').value.trim();
879
+ formData.hf_subset = document.getElementById('hf_subset').value.trim() || null;
880
+ formData.hf_train_split = document.getElementById('hf_train_split').value.trim() || 'train';
881
+ formData.hf_eval_split = document.getElementById('hf_eval_split').value.trim() || null;
882
+
883
+ if (!formData.hf_dataset) {
884
+ alert('Please enter a HuggingFace dataset name');
885
+ return;
886
+ }
887
+ }
888
+
889
+ // Validate hub settings
890
+ if (formData.push_to_hub && !formData.hf_username) {
891
+ alert('Please enter your HuggingFace username');
892
+ return;
893
+ }
894
+
895
+ // Show loading state
896
+ const submitBtn = trainForm.querySelector('button[type="submit"]');
897
+ const originalBtnText = submitBtn.textContent;
898
+ submitBtn.disabled = true;
899
+ submitBtn.textContent = 'Starting...';
900
+
901
+ try {
902
+ const response = await fetch('/train', {
903
+ method: 'POST',
904
+ headers: { 'Content-Type': 'application/json' },
905
+ body: JSON.stringify(formData)
906
+ });
907
+
908
+ if (response.ok) {
909
+ const data = await response.json();
910
+ newTrainingModal.style.display = 'none';
911
+ resetChart();
912
+ clearLogs();
913
+
914
+ // Show status banner
915
+ showStatusBanner('Initializing training...');
916
+
917
+ loadRuns();
918
+ // selectRun will load historical updates and start polling if active
919
+ selectRun(data.run_id);
920
+ } else {
921
+ const error = await response.json();
922
+ alert(`Training failed: ${error.detail || 'Unknown error'}`);
923
+ }
924
+ } catch (error) {
925
+ console.error('Training error:', error);
926
+ alert('Failed to start training');
927
+ } finally {
928
+ submitBtn.disabled = false;
929
+ submitBtn.textContent = originalBtnText;
930
+ }
931
+ }
932
+
933
/**
 * Ask the server to stop the currently selected run and update the UI.
 *
 * Sends `POST /stop` with `{ run_id }`. On success it:
 * - stops polling;
 * - hides the stop button, status banner and progress bar, and marks the
 *   summary chip "stopped" (only if that run is still the one on screen);
 * - reloads the run list.
 *
 * On a non-OK response the server's `detail` message (if any) is shown
 * instead of the request failing silently. The stop button is re-enabled
 * in every case.
 */
async function stopTraining() {
    if (!selectedRunId) return;

    // Remember which run we are stopping: the selection can change while the request is in flight.
    const runId = selectedRunId;
    const stopBtn = document.getElementById('stop-btn');
    stopBtn.classList.add('loading');
    stopBtn.disabled = true;

    try {
        const response = await fetch('/stop', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ run_id: runId })
        });

        if (!response.ok) {
            // Previously a rejected stop request gave the user no feedback at all.
            const error = await response.json().catch(() => ({}));
            alert(`Failed to stop training: ${error.detail || 'Unknown error'}`);
            return;
        }

        stopPolling();

        // Only rewrite the summary panel if it is still showing the run we stopped.
        if (selectedRunId === runId) {
            stopBtn.style.display = 'none';
            hideStatusBanner();
            const statusEl = document.getElementById('summary-status');
            statusEl.textContent = 'stopped';
            statusEl.className = 'status-chip small stopped';
            document.getElementById('progress-container').style.display = 'none';
        }

        // Refresh run list to get the server-side status
        await loadRuns();
    } catch (error) {
        console.error('Stop error:', error);
    } finally {
        stopBtn.classList.remove('loading');
        stopBtn.disabled = false;
    }
}
968
+
969
/**
 * Delete the currently selected run after user confirmation.
 *
 * Sends `DELETE /runs/<id>`. Only when the server confirms the deletion does
 * it:
 * - stop polling;
 * - clear the selection and hide the summary panel (if that run is still
 *   selected);
 * - reload the run list.
 *
 * On a non-OK response the server's `detail` message is shown and the UI is
 * left untouched; previously the panel was cleared even when nothing was
 * deleted.
 */
async function deleteRun() {
    if (!selectedRunId) return;

    if (!confirm('Are you sure you want to delete this run?')) return;

    const runId = selectedRunId;
    try {
        const response = await fetch(`/runs/${runId}`, { method: 'DELETE' });
        if (!response.ok) {
            const error = await response.json().catch(() => ({}));
            alert(`Failed to delete run: ${error.detail || 'Unknown error'}`);
            return;
        }

        stopPolling();
        // The user may have picked another run while the request was pending.
        if (selectedRunId === runId) {
            selectedRunId = null;
            projectSummary.style.display = 'none';
        }
        await loadRuns();
    } catch (error) {
        console.error('Delete error:', error);
    }
}
984
+
985
+ // Runs
986
/**
 * Fetch the run list and the active run id from the server and refresh the UI.
 *
 * - Re-renders the sidebar from `GET /runs`.
 * - Records `GET /active_run_id` in `activeRunId`.
 * - Starts polling that run if nothing is being polled yet (resetting
 *   `pollLine` to 0, since no historical updates have been loaded for it).
 * - With `autoSelectLatest`, selects `runs[0]` when no run is selected.
 *
 * A non-OK or malformed response is logged and leaves the current `runs`
 * untouched; previously an error payload was stored in `runs` and crashed
 * rendering.
 *
 * @param {boolean} [autoSelectLatest=false] - select the first listed run if none is selected.
 */
async function loadRuns(autoSelectLatest = false) {
    try {
        const response = await fetch('/runs');
        if (!response.ok) {
            throw new Error(`GET /runs failed with status ${response.status}`);
        }
        const fetchedRuns = await response.json();
        if (!Array.isArray(fetchedRuns)) {
            throw new Error('GET /runs did not return a list');
        }
        runs = fetchedRuns;
        renderRunList();

        // Check for active run
        const activeResponse = await fetch('/active_run_id');
        if (!activeResponse.ok) {
            throw new Error(`GET /active_run_id failed with status ${activeResponse.status}`);
        }
        const activeData = await activeResponse.json();
        activeRunId = activeData.run_id;

        // If there's an active run and we're not polling, start polling.
        // Reset pollLine since we haven't loaded historical updates yet.
        if (activeRunId && !pollInterval) {
            pollLine = 0;
            startPolling(activeRunId);
        }

        if (autoSelectLatest && runs.length > 0 && !selectedRunId) {
            selectRun(runs[0].id);
        }
    } catch (error) {
        console.error('Failed to load runs:', error);
    }
}
1012
+
1013
/**
 * Render the sidebar list of training runs into `runList` and attach click
 * handlers that call `selectRun` with the run's numeric id.
 *
 * Project names, statuses, ids and dates originate from server data (run
 * configs may have been created outside this form — not verified here), so
 * every interpolated value is HTML-escaped before being placed in `innerHTML`.
 * A run whose `config` is not valid JSON is shown as "Run #<id>" instead of
 * aborting the whole list.
 */
function renderRunList() {
    if (runs.length === 0) {
        runList.innerHTML = '<div class="run-item placeholder">No training runs yet</div>';
        return;
    }

    const escapeHtml = (value) => String(value)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    const parseConfig = (raw) => {
        try {
            return JSON.parse(raw || '{}') || {};
        } catch {
            return {};
        }
    };

    runList.innerHTML = runs.map(run => {
        const config = parseConfig(run.config);
        const date = new Date(run.created_at).toLocaleDateString();
        const isSelected = run.id === selectedRunId;
        const projectName = config.project_name || `Run #${run.id}`;

        return `
            <div class="run-item ${isSelected ? 'active' : ''}" data-id="${escapeHtml(run.id)}">
                <span class="run-status-icon ${escapeHtml(run.status)}"></span>
                <div class="run-info">
                    <div class="run-id">${escapeHtml(projectName)}</div>
                    <div class="run-date">${escapeHtml(date)}</div>
                </div>
            </div>
        `;
    }).join('');

    // Add click handlers (ids round-trip through data-id as strings)
    runList.querySelectorAll('.run-item:not(.placeholder)').forEach(item => {
        item.addEventListener('click', () => {
            selectRun(parseInt(item.dataset.id, 10));
        });
    });
}
1043
+
1044
/**
 * Make `runId` the selected run and populate the dashboard for it.
 *
 * Steps:
 * 1. Reset the chart, logs and progress readouts.
 * 2. Fetch the run record and show its summary.
 * 3. Replay its historical updates and load its metrics.
 * 4. Running runs: start polling them (unless already polled).
 *    Finished runs: show the final loss; completed runs also show a full
 *    progress bar.
 *
 * Fixes over the previous version:
 * - A non-OK `/runs/<id>` response is logged instead of being rendered as a
 *   run summary.
 * - A final metric point without a numeric value no longer crashes
 *   `.toFixed`, which used to abort the completed-run display.
 * - After each await the function bails out if the user has selected another
 *   run in the meantime, so a slow response cannot overwrite the newer
 *   selection's summary. NOTE(review): `loadHistoricalUpdates` and
 *   `loadMetrics` still write shared state themselves and are not guarded by
 *   this check.
 *
 * @param {number} runId - id of the run to display.
 */
async function selectRun(runId) {
    selectedRunId = runId;
    renderRunList();

    // Reset chart and metrics when switching runs
    metricsData = {};
    resetChart();
    clearLogs();

    // Reset progress display
    document.getElementById('current-epoch').textContent = '0';
    document.getElementById('current-step').textContent = '0';
    document.getElementById('current-loss').textContent = '--';
    document.getElementById('current-eta').textContent = '--';
    document.getElementById('progress-container').style.display = 'none';
    document.getElementById('progress-fill').style.width = '0%';
    document.getElementById('progress-pct').textContent = '0%';
    document.getElementById('progress-speed').textContent = '-- it/s';
    document.getElementById('status-banner').style.display = 'none';

    // True once the user has picked a different run while we were awaiting.
    const superseded = () => selectedRunId !== runId;

    try {
        const response = await fetch(`/runs/${runId}`);
        if (!response.ok) {
            throw new Error(`GET /runs/${runId} failed with status ${response.status}`);
        }
        const run = await response.json();
        if (superseded()) return;
        showRunSummary(run);

        // Load historical updates (logs and progress) for this run
        await loadHistoricalUpdates(runId);
        if (superseded()) return;

        // Load metrics from TensorBoard files for this run
        await loadMetrics(runId);
        if (superseded()) return;

        // Show progress bar for all runs
        document.getElementById('progress-container').style.display = 'block';

        if (run.status === 'running') {
            if (pollingRunId !== runId) {
                startPolling(runId);
            }
            return;
        }

        // For finished/stopped/error runs, show final state
        document.getElementById('current-eta').textContent = run.status === 'completed' ? '0' : '--';

        // Show final loss from metrics if the last point carries a numeric value
        const lossData = metricsData['loss'] || metricsData['eval_loss'];
        const finalLoss = lossData?.[lossData.length - 1];
        if (typeof finalLoss?.value === 'number') {
            document.getElementById('current-loss').textContent = finalLoss.value.toFixed(4);
        }

        // For completed runs, show 100% progress
        if (run.status === 'completed') {
            document.getElementById('progress-fill').style.width = '100%';
            document.getElementById('progress-pct').textContent = '100%';
            document.getElementById('progress-speed').textContent = 'Complete';
        }
    } catch (error) {
        console.error('Failed to load run:', error);
    }
}
1107
+
1108
/**
 * Replay every recorded update for a run, starting at line 0, so the log panel
 * and progress readouts reflect its history. Afterwards, remember in
 * `pollLine` where live polling should resume.
 *
 * Updates of type `status` are deliberately skipped: the run record from the
 * API is authoritative, while status lines in the update file may be stale.
 * A non-OK response is ignored; other failures are logged.
 *
 * @param {number} runId
 */
async function loadHistoricalUpdates(runId) {
    try {
        const response = await fetch(`/runs/${runId}/updates?since_line=0`);
        if (!response.ok) {
            return;
        }

        const { updates, next_line: nextLine } = await response.json();
        const replayable = updates.filter(update => update.type !== 'status');
        for (const update of replayable) {
            handleUpdate(update);
        }

        pollLine = nextLine;
    } catch (error) {
        console.error('Failed to load historical updates:', error);
    }
}
1131
+
1132
/**
 * Fill the project summary panel from a run record and make the panel visible.
 *
 * `run.config` is a JSON string holding the training configuration.
 * - Missing values fall back to '--' (or 'default' for the loss variant).
 * - The LoRA and Matryoshka rows are shown only when those options are set.
 * - The stop button is shown only for running runs.
 * - The artifacts button is disabled while a run is running or pending.
 *
 * @param {object} run - run record with `id`, `status` and `config`.
 */
function showRunSummary(run) {
    const config = JSON.parse(run.config || '{}');
    const byId = (id) => document.getElementById(id);
    const setText = (id, value) => {
        byId(id).textContent = value;
    };

    setText('summary-project-name', config.project_name || `Run #${run.id}`);
    const statusChip = byId('summary-status');
    statusChip.textContent = run.status;
    statusChip.className = `status-chip small ${run.status}`;

    setText('sum-task', config.task || '--');
    setText('sum-loss', config.loss_variant || 'default');
    // Show only the last path segment of the model id (e.g. "org/name" -> "name")
    setText('sum-model', config.base_model?.split('/').pop() || '--');

    // Prefer the uploaded file's base name; fall back to the HuggingFace dataset id
    const dataSource = config.train_filename
        ? config.train_filename.split('/').pop()
        : (config.hf_dataset || '--');
    setText('sum-data', dataSource);

    setText('sum-epochs', config.epochs || '--');
    setText('sum-batch', config.batch_size || '--');
    setText('sum-lr', config.learning_rate || '--');

    const loraRow = byId('sum-lora-row');
    loraRow.style.display = config.lora_enabled ? 'flex' : 'none';
    if (config.lora_enabled) {
        setText('sum-lora', `r=${config.lora_r}, a=${config.lora_alpha}`);
    }

    const matryoshkaRow = byId('sum-matryoshka-row');
    matryoshkaRow.style.display = config.matryoshka_dims ? 'flex' : 'none';
    if (config.matryoshka_dims) {
        setText('sum-matryoshka', config.matryoshka_dims);
    }

    const isRunning = run.status === 'running';
    byId('stop-btn').style.display = isRunning ? 'block' : 'none';
    // Finished runs (completed/stopped/error) may have checkpoints or partial models to inspect
    artifactsBtn.disabled = isRunning || run.status === 'pending';

    projectSummary.style.display = 'block';
}
1184
+
1185
/**
 * Record a new status for a run in the local `runs` cache and refresh the views.
 *
 * Re-renders the run list, and also the summary panel when the run is the one
 * currently selected. Unknown run ids are ignored.
 *
 * @param {number} runId
 * @param {string} status
 */
function updateRunStatus(runId, status) {
    const target = runs.find(({ id }) => id === runId);
    if (!target) {
        return;
    }

    target.status = status;
    renderRunList();

    if (runId === selectedRunId) {
        showRunSummary(target);
    }
}
1195
+
1196
+ // Progress
1197
/**
 * Apply one progress update from the training stream to the dashboard.
 *
 * `data.type` selects the handling:
 * - `train_start`: swap the status banner for an empty progress bar.
 * - `train_end`: fill the bar and mark it complete.
 * - anything else (default `progress`): refresh the epoch/step/loss/ETA
 *   readouts, the progress bar and the speed label, and append a chart point
 *   when both a step and a loss are present.
 *
 * @param {object} data - update payload; fields read: type, epoch, total_epochs,
 *   step, total_steps, loss, eta_seconds, progress_pct, steps_per_sec.
 */
function updateProgress(data) {
    const el = (id) => document.getElementById(id);

    switch (data.type || 'progress') {
        case 'train_start':
            hideStatusBanner();
            el('progress-container').style.display = 'block';
            el('progress-fill').style.width = '0%';
            el('progress-pct').textContent = '0%';
            el('progress-speed').textContent = '-- it/s';
            return;
        case 'train_end':
            el('progress-fill').style.width = '100%';
            el('progress-pct').textContent = '100%';
            el('progress-speed').textContent = 'Complete';
            el('current-eta').textContent = 'Done';
            return;
        default:
            break;
    }

    const {
        epoch,
        step,
        loss,
        eta_seconds: etaSeconds,
        progress_pct: progressPct,
        steps_per_sec: stepsPerSec,
    } = data;
    const totalEpochs = data.total_epochs || 0;
    const totalSteps = data.total_steps || 0;

    // Show "current/total" when the total is known, otherwise just the current value
    el('current-epoch').textContent = totalEpochs
        ? `${epoch?.toFixed(1) || 0}/${totalEpochs}`
        : (epoch?.toFixed(2) || 0);
    el('current-step').textContent = totalSteps
        ? `${step || 0}/${totalSteps}`
        : (step || 0);
    el('current-loss').textContent = loss?.toFixed(4) || '--';

    // Keep the previous ETA unless a positive estimate arrives
    if (etaSeconds > 0) {
        el('current-eta').textContent = formatTime(etaSeconds);
    }

    if (progressPct !== undefined) {
        el('progress-container').style.display = 'block';
        el('progress-fill').style.width = `${progressPct}%`;
        el('progress-pct').textContent = `${progressPct.toFixed(1)}%`;
    }

    if (stepsPerSec) {
        el('progress-speed').textContent = `${stepsPerSec.toFixed(2)} it/s`;
    }

    if (step && loss) {
        updateChart(step, loss);
    }
}
1251
+
1252
/**
 * Format a duration in seconds as a compact label: "45s", "2m 5s" or "2h 3m".
 *
 * The input is rounded to whole seconds once, before it is split into
 * components. The previous version rounded only the final seconds component,
 * which produced "60s", "1m 60s" or "59m 60s" for values just below a minute
 * boundary. Hour-scale values drop the seconds component.
 *
 * @param {number} seconds - non-negative duration in seconds (may be fractional).
 * @returns {string}
 */
function formatTime(seconds) {
    const total = Math.round(seconds);

    if (total < 60) {
        return `${total}s`;
    }
    if (total < 3600) {
        const mins = Math.floor(total / 60);
        return `${mins}m ${total % 60}s`;
    }

    const hours = Math.floor(total / 3600);
    const mins = Math.floor((total % 3600) / 60);
    return `${hours}h ${mins}m`;
}
1265
+
1266
+ // Logs
1267
/**
 * Append one line to the log panel, bump the visible line counter, and keep
 * the panel scrolled to the newest entry.
 *
 * @param {string} message - log text; a trailing newline is added.
 */
function appendLog(message) {
    const counter = document.getElementById('log-count');
    counter.textContent = parseInt(counter.textContent) + 1;

    logContent.textContent += message + '\n';
    // Auto-follow: jump to the bottom after each append
    logContent.scrollTop = logContent.scrollHeight;
}
1275
+
1276
/** Empty the log panel and reset its line counter to zero. */
function clearLogs() {
    document.getElementById('log-count').textContent = '0';
    logContent.textContent = '';
}
1280
+
1281
+ // Status Banner
1282
/**
 * Display the status banner with the given message.
 *
 * @param {string} message - text shown inside the banner.
 */
function showStatusBanner(message) {
    document.getElementById('status-message').textContent = message;
    document.getElementById('status-banner').style.display = 'flex';
}
1287
+
1288
/** Hide the status banner (its message text is left as-is). */
function hideStatusBanner() {
    const banner = document.getElementById('status-banner');
    banner.style.display = 'none';
}
1291
+
1292
+ // Handlers
1293
/**
 * Shared UI teardown once a polled run reaches a terminal state.
 *
 * Hides the status banner, stop button and progress bar, and sets the summary
 * chip's text and CSS class to `status`. Then it stops polling and refreshes
 * the run list. The two handlers below previously duplicated all of this
 * line-for-line.
 *
 * @param {string} status - terminal status name; also used as the chip's CSS class.
 */
function finishTrainingDisplay(status) {
    hideStatusBanner();
    document.getElementById('stop-btn').style.display = 'none';
    document.getElementById('progress-container').style.display = 'none';
    const statusEl = document.getElementById('summary-status');
    statusEl.textContent = status;
    statusEl.className = `status-chip small ${status}`;
    stopPolling();
    // Not awaited: loadRuns catches and logs its own errors.
    loadRuns();
}

/**
 * Handle a "training complete" update: log where the model was saved and
 * switch the dashboard to its completed state.
 *
 * @param {object} data - update payload; `output_dir` is read.
 */
function handleTrainingComplete(data) {
    appendLog(`Training completed! Model saved to: ${data.output_dir}`);
    finishTrainingDisplay('completed');
}

/**
 * Handle a "training error" update: log the error message and switch the
 * dashboard to its error state.
 *
 * @param {object} data - update payload; `message` is read.
 */
function handleTrainingError(data) {
    appendLog(`Error: ${data.message}`);
    finishTrainingDisplay('error');
}
1318
+
1319
+ // Metrics from TensorBoard files
1320
/**
 * Fetch the metric series for a run (served from its TensorBoard files) and
 * store them in `metricsData`, then rebuild the metric selector and redraw
 * the chart for `currentMetric`.
 *
 * A non-OK response leaves everything unchanged; other failures are logged.
 *
 * @param {number} runId
 */
async function loadMetrics(runId) {
    try {
        const response = await fetch(`/runs/${runId}/metrics`);
        if (!response.ok) {
            return;
        }

        const { metrics } = await response.json();
        metricsData = metrics || {};
        updateMetricSelector();
        updateChartWithMetric(currentMetric);
    } catch (error) {
        console.error('Failed to load metrics:', error);
    }
}
1333
+
1334
/**
 * Rebuild the metric <select> from `metricsData` and choose which metric to chart.
 *
 * Only series with at least two non-null points are offered, because a single
 * point cannot be drawn as a line.
 *
 * Option order:
 * 1. `loss`
 * 2. `eval_loss`
 * 3. ordinary metrics, alphabetically
 * 4. bookkeeping metrics (names containing flos, runtime or per_second)
 *
 * The current choice is kept when still available; otherwise `loss`,
 * `eval_loss` or the first option is picked (case-insensitive name match).
 * When nothing is plottable, a lone "Loss" option is rendered and
 * `currentMetric` is left untouched.
 */
function updateMetricSelector() {
    const select = document.getElementById('metric-select');

    const isPlottable = (series) => {
        if (!series || series.length < 2) return false;
        return series.filter(point => point.value !== null).length >= 2;
    };
    const names = Object.keys(metricsData).filter(name => isPlottable(metricsData[name]));

    if (!names.length) {
        select.innerHTML = '<option value="loss">Loss</option>';
        return;
    }

    const utilityMarkers = ['flos', 'runtime', 'per_second'];
    const isUtility = (lowerName) => utilityMarkers.some(marker => lowerName.includes(marker));

    names.sort((a, b) => {
        const lowerA = a.toLowerCase();
        const lowerB = b.toLowerCase();
        // Pinned metrics first, in this order
        for (const pinned of ['loss', 'eval_loss']) {
            if (lowerA === pinned) return -1;
            if (lowerB === pinned) return 1;
        }
        const utilA = isUtility(lowerA);
        const utilB = isUtility(lowerB);
        if (utilA !== utilB) return utilA ? 1 : -1;
        return a.localeCompare(b);
    });

    // "eval_loss" -> "Eval Loss"
    const prettify = (name) => name.replace(/_/g, ' ').replace(/\b\w/g, ch => ch.toUpperCase());
    select.innerHTML = names
        .map(name => `<option value="${name}">${prettify(name)}</option>`)
        .join('');

    if (!names.includes(currentMetric)) {
        const findCaseless = (wanted) => names.find(name => name.toLowerCase() === wanted);
        currentMetric = findCaseless('loss') || findCaseless('eval_loss') || names[0];
    }
    select.value = currentMetric;
}
1385
+
1386
/**
 * Redraw the chart's first dataset with the series stored under `metric`.
 *
 * Steps become x-axis labels and values become data points. The dataset
 * label is a title-cased version of the metric name, and the "no data"
 * placeholder is hidden. Missing or empty series leave the chart untouched.
 *
 * @param {string} metric - key into `metricsData`.
 */
function updateChartWithMetric(metric) {
    const series = metricsData[metric];
    if (!series || series.length === 0) return;

    const [dataset] = chart.data.datasets;
    chart.data.labels = series.map(({ step }) => step);
    dataset.data = series.map(({ value }) => value);
    dataset.label = metric.replace(/_/g, ' ').replace(/\b\w/g, ch => ch.toUpperCase());

    chart.update();
    chartPlaceholder.style.display = 'none';
}
1399
+
1400
+ // Artifacts Modal
1401
/**
 * Wire up the artifacts modal.
 *
 * - The artifacts button loads the selected run's artifacts, then opens the modal.
 * - The close button hides it.
 * - A click landing on the modal element itself (the backdrop) hides it.
 * - The Escape key hides it while it is open.
 */
function setupArtifactsModal() {
    const hide = () => {
        artifactsModal.style.display = 'none';
    };

    artifactsBtn.addEventListener('click', async () => {
        // Nothing selected, or the button is disabled for pending/running runs
        if (!selectedRunId || artifactsBtn.disabled) return;
        await loadArtifacts(selectedRunId);
        artifactsModal.style.display = 'flex';
    });

    closeArtifactsModal.addEventListener('click', () => hide());

    artifactsModal.addEventListener('click', ({ target }) => {
        // Clicks inside the dialog content bubble up with a different target
        if (target === artifactsModal) hide();
    });

    document.addEventListener('keydown', ({ key }) => {
        const isOpen = artifactsModal.style.display === 'flex';
        if (key === 'Escape' && isOpen) hide();
    });
}
1428
+
1429
/**
 * Fetch a run's artifacts and render them into the artifacts modal.
 *
 * Each artifact shows its label, category and formatted size, plus a
 * "Copy Path" button. The run's output directory is shown below the list.
 * An empty list shows a placeholder; a failed request (including a non-OK
 * status) shows an error message and clears the output directory.
 *
 * Server-provided text is HTML-escaped before insertion. The artifact path is
 * handed to the button through an escaped `data-path` attribute rather than
 * a JS string literal inside `onclick`. Previously only `'` was escaped, so
 * backslashes in Windows paths were read as JS escape sequences (`\f`, `\n`, …)
 * and a `"` broke the attribute.
 *
 * @param {number} runId
 */
async function loadArtifacts(runId) {
    const listEl = document.getElementById('artifacts-list');
    const pathEl = document.getElementById('artifacts-path');

    const escapeHtml = (value) => String(value)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    try {
        const response = await fetch(`/runs/${runId}/artifacts`);
        if (!response.ok) {
            throw new Error(`GET /runs/${runId}/artifacts failed with status ${response.status}`);
        }
        const data = await response.json();
        const artifacts = Array.isArray(data.artifacts) ? data.artifacts : [];

        if (artifacts.length === 0) {
            listEl.innerHTML = '<div class="artifacts-empty">No artifacts available</div>';
            pathEl.textContent = '';
            return;
        }

        listEl.innerHTML = artifacts.map(artifact => `
            <div class="artifact-item">
                <div class="artifact-info">
                    <span class="artifact-name">${escapeHtml(artifact.label)}</span>
                    <div class="artifact-meta">
                        <span class="artifact-category">${escapeHtml(artifact.category)}</span>
                        <span>${escapeHtml(formatFileSize(artifact.size))}</span>
                    </div>
                </div>
                <button class="artifact-download" data-path="${escapeHtml(artifact.path)}" onclick="copyArtifactPath(this.dataset.path, this)">
                    Copy Path
                </button>
            </div>
        `).join('');

        pathEl.textContent = data.output_dir ?? '';
    } catch (error) {
        console.error('Failed to load artifacts:', error);
        listEl.innerHTML = '<div class="artifacts-empty">Failed to load artifacts</div>';
        pathEl.textContent = '';
    }
}
1464
+
1465
/**
 * Copy an artifact path to the clipboard and briefly show "Copied!" on the button.
 *
 * Fixes over the previous version:
 * - The previous version read the deprecated global `event` inside the
 *   promise callback. By then no event is being dispatched, so the lookup
 *   threw and the feedback never appeared. The button is now resolved
 *   synchronously: either passed explicitly, or taken from the event being
 *   dispatched at call time (e.g. from an inline `onclick`).
 * - `navigator.clipboard` only exists in secure contexts (HTTPS or
 *   localhost). When it is missing, the failure is logged instead of
 *   throwing.
 *
 * @param {string} path - filesystem path to copy.
 * @param {HTMLElement|null} [button] - element whose text is swapped for feedback.
 * @returns {Promise<void>} settles once the copy attempt has finished.
 */
function copyArtifactPath(path, button = globalThis.event?.currentTarget ?? null) {
    const clipboard = globalThis.navigator?.clipboard;
    if (!clipboard?.writeText) {
        console.error('Failed to copy: clipboard API unavailable (requires a secure context)');
        return Promise.resolve();
    }

    return clipboard.writeText(path).then(() => {
        if (!button) return;
        // Brief visual feedback
        const originalText = button.textContent;
        button.textContent = 'Copied!';
        setTimeout(() => {
            button.textContent = originalText;
        }, 1000);
    }).catch(err => {
        console.error('Failed to copy:', err);
    });
}
1478
+
1479
/**
 * Format a byte count as a short human-readable string using binary (1024)
 * multiples, e.g. 1536 -> "1.5 KB", 0 -> "0 B".
 *
 * Fixes over the previous version:
 * - Values that are not finite, non-negative numbers (missing or malformed
 *   sizes) return "--" instead of "NaN undefined".
 * - Sizes beyond GB use TB/PB instead of indexing past the unit table.
 * - Sizes below 1 byte stay in bytes.
 * - The unit is found by repeated exact division by 1024 rather than a
 *   floating-point logarithm.
 * - A value that rounds to 1024 of one unit is promoted to the next unit
 *   ("1 MB" rather than "1024 KB").
 *
 * @param {number} bytes
 * @returns {string}
 */
function formatFileSize(bytes) {
    if (bytes === 0) return '0 B';
    if (!Number.isFinite(bytes) || bytes < 0) return '--';

    const k = 1024;
    const sizes = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'];
    const lastUnit = sizes.length - 1;

    let value = bytes;
    let unit = 0;
    while (value >= k && unit < lastUnit) {
        value /= k;
        unit += 1;
    }

    let rounded = parseFloat(value.toFixed(1));
    // e.g. 1048575 bytes is 1023.999 KB, which would display as "1024 KB"
    if (rounded >= k && unit < lastUnit) {
        unit += 1;
        rounded = parseFloat((value / k).toFixed(1));
    }
    return `${rounded} ${sizes[unit]}`;
}