ml-dash 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1298 @@
1
+ """Upload command implementation for ML-Dash CLI."""
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import List, Dict, Any, Optional
7
+ from dataclasses import dataclass, field
8
+ import threading
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
+ from rich.console import Console
12
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
13
+ from rich.table import Table
14
+ from rich.panel import Panel
15
+
16
+ from ..storage import LocalStorage
17
+ from ..client import RemoteClient
18
+ from ..config import Config
19
+
20
+ # Initialize rich console
21
+ console = Console()
22
+
23
+
24
@dataclass
class ExperimentInfo:
    """Information about an experiment to upload."""
    # Project the experiment belongs to (second-to-last path component).
    project: str
    # Experiment name (last path component).
    experiment: str
    # Directory containing experiment.json and the experiment's data.
    path: Path
    # Optional folder path read from experiment.json metadata.
    folder: Optional[str] = None
    # True when logs/logs.jsonl exists under `path`.
    has_logs: bool = False
    # True when parameters.json exists under `path`.
    has_params: bool = False
    # Names of metrics/<name>/ subdirectories that contain a data.jsonl.
    metric_names: List[str] = field(default_factory=list)
    # Number of regular files under files/ (recursive).
    file_count: int = 0
    estimated_size: int = 0  # in bytes
36
+
37
+
38
@dataclass
class ValidationResult:
    """Result of experiment validation."""
    # False when a fatal error was found (or, in strict mode, any warning).
    is_valid: bool = True
    # Non-fatal issues; the offending data is skipped rather than aborting.
    warnings: List[str] = field(default_factory=list)
    # Fatal issues that prevent uploading the experiment.
    errors: List[str] = field(default_factory=list)
    # Parsed data that passed validation (keys seen: "metadata", "parameters").
    valid_data: Dict[str, Any] = field(default_factory=dict)
45
+
46
+
47
@dataclass
class UploadResult:
    """Result of uploading an experiment."""
    # "project/experiment" identifier used for reporting.
    experiment: str
    # True only when all upload steps completed without raising.
    success: bool = False
    uploaded: Dict[str, int] = field(default_factory=dict)  # {"logs": 100, "metrics": 3}
    failed: Dict[str, List[str]] = field(default_factory=dict)  # {"files": ["error msg"]}
    # Top-level exception messages that aborted the upload.
    errors: List[str] = field(default_factory=list)
    bytes_uploaded: int = 0  # Total bytes uploaded
56
+
57
+
58
@dataclass
class UploadState:
    """Tracks upload state for resume functionality."""
    local_path: str
    remote_url: str
    completed_experiments: List[str] = field(default_factory=list)  # ["project/experiment"]
    failed_experiments: List[str] = field(default_factory=list)
    in_progress_experiment: Optional[str] = None
    timestamp: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this state into a plain dictionary for JSON output."""
        keys = (
            "local_path",
            "remote_url",
            "completed_experiments",
            "failed_experiments",
            "in_progress_experiment",
            "timestamp",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UploadState":
        """Build an UploadState from a dictionary (inverse of to_dict)."""
        # local_path and remote_url are required; everything else falls back
        # to its dataclass default when missing.
        optional_defaults = {
            "completed_experiments": [],
            "failed_experiments": [],
            "in_progress_experiment": None,
            "timestamp": None,
        }
        kwargs = {name: data.get(name, default) for name, default in optional_defaults.items()}
        return cls(local_path=data["local_path"], remote_url=data["remote_url"], **kwargs)

    def save(self, path: Path):
        """Stamp the state with the current time and write it to *path* as JSON."""
        import datetime
        self.timestamp = datetime.datetime.now().isoformat()
        with open(path, "w") as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def load(cls, path: Path) -> Optional["UploadState"]:
        """Read state from *path*; return None when absent or unreadable."""
        if not path.exists():
            return None
        try:
            with open(path, "r") as handle:
                return cls.from_dict(json.load(handle))
        except (json.JSONDecodeError, IOError, KeyError):
            return None
109
+
110
+
111
def generate_api_key_from_username(
    user_name: str,
    secret: str = "your-secret-key-change-this-in-production",
) -> str:
    """
    Generate a deterministic API key (JWT) from username.

    This is a temporary solution until proper user authentication is implemented.
    Generates a unique user ID from the username and creates a JWT token.

    Args:
        user_name: Username to generate API key from
        secret: HS256 signing secret; must match the server's JWT_SECRET.
            Defaults to the historical development value for backward
            compatibility.

    Returns:
        JWT token string
    """
    import hashlib
    import time
    import jwt

    # Deterministic user ID: first 10 decimal digits derived from the leading
    # 64 bits of SHA-256(user_name).
    user_id = str(int(hashlib.sha256(user_name.encode()).hexdigest()[:16], 16))[:10]

    # Capture "now" once so iat/exp are derived from the same instant; the
    # previous version read time.time() twice and could straddle a second.
    now = int(time.time())

    # JWT payload
    payload = {
        "userId": user_id,
        "userName": user_name,
        "iat": now,
        "exp": now + (30 * 24 * 60 * 60),  # 30 days expiration
    }

    # SECURITY NOTE: a hard-coded signing secret means anyone with this source
    # can mint valid tokens — pass a real secret in production deployments.
    return jwt.encode(payload, secret, algorithm="HS256")
146
+
147
+
148
def add_parser(subparsers) -> argparse.ArgumentParser:
    """Register the `upload` subcommand and return its configured parser."""
    parser = subparsers.add_parser(
        "upload",
        help="Upload local experiments to remote server",
        description="Upload locally-stored ML-Dash experiment data to a remote server.",
    )

    # Positional: where the local storage lives.
    parser.add_argument(
        "path",
        nargs="?",
        default="./.ml-dash",
        help="Local storage directory to upload from (default: ./.ml-dash)",
    )

    # Remote configuration: all plain string options.
    for flag, help_text in [
        ("--remote", "Remote server URL (required unless set in config)"),
        ("--api-key", "JWT token for authentication (required unless --username or config is set)"),
        ("--username", "Username for authentication (generates API key automatically)"),
    ]:
        parser.add_argument(flag, type=str, help=help_text)

    # Scope control.
    parser.add_argument("--project", type=str, help="Upload only experiments from this project")
    parser.add_argument(
        "--experiment",
        type=str,
        help="Upload only this specific experiment (requires --project)",
    )

    # Data filtering: each --skip-* flag disables one upload category.
    for flag, help_text in [
        ("--skip-logs", "Don't upload logs"),
        ("--skip-metrics", "Don't upload metrics"),
        ("--skip-files", "Don't upload files"),
        ("--skip-params", "Don't upload parameters"),
    ]:
        parser.add_argument(flag, action="store_true", help=help_text)

    # Behavior control.
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be uploaded without uploading",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Fail on any validation error (default: skip invalid data)",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Show detailed progress")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Batch size for logs/metrics (default: 100)",
    )
    parser.add_argument("--resume", action="store_true", help="Resume previous interrupted upload")
    parser.add_argument(
        "--state-file",
        type=str,
        default=".ml-dash-upload-state.json",
        help="Path to state file for resume (default: .ml-dash-upload-state.json)",
    )

    return parser
250
+
251
+
252
def discover_experiments(
    local_path: Path,
    project_filter: Optional[str] = None,
    experiment_filter: Optional[str] = None,
) -> List[ExperimentInfo]:
    """
    Discover experiments in local storage directory.

    Supports both flat (local_path/project/experiment) and folder-based
    (local_path/folder/project/experiment) hierarchies.

    Args:
        local_path: Root path of local storage
        project_filter: Only discover experiments in this project
        experiment_filter: Only discover this experiment (requires project_filter)

    Returns:
        List of ExperimentInfo objects
    """
    local_path = Path(local_path)

    if not local_path.exists():
        return []

    experiments: List[ExperimentInfo] = []

    # Each experiment.json file marks an experiment directory.
    for exp_json in local_path.rglob("*/experiment.json"):
        exp_dir = exp_json.parent

        # Extract project and experiment names from path.
        # Path structure: local_path / [folder] / project / experiment
        try:
            parts = exp_dir.relative_to(local_path).parts

            if len(parts) < 2:
                continue  # Need at least project/experiment

            # Last two parts are project/experiment.
            exp_name = parts[-1]
            project_name = parts[-2]

            # Apply filters.
            if project_filter and project_name != project_filter:
                continue
            if experiment_filter and exp_name != experiment_filter:
                continue

            # Read optional folder path from experiment.json metadata.
            # (Fixed: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit.)
            folder = None
            try:
                with open(exp_json, "r") as f:
                    folder = json.load(f).get("folder")
            except (OSError, json.JSONDecodeError):
                pass

            exp_info = ExperimentInfo(
                project=project_name,
                experiment=exp_name,
                path=exp_dir,
                folder=folder,
            )
        except (ValueError, IndexError):
            # relative_to() raised: path is not under local_path — skip it.
            continue

        # Check for parameters.
        exp_info.has_params = (exp_dir / "parameters.json").exists()

        # Check for logs.
        exp_info.has_logs = (exp_dir / "logs" / "logs.jsonl").exists()

        # Check for metrics: every subdirectory holding a data.jsonl counts.
        metrics_dir = exp_dir / "metrics"
        if metrics_dir.exists():
            for metric_dir in metrics_dir.iterdir():
                if metric_dir.is_dir() and (metric_dir / "data.jsonl").exists():
                    exp_info.metric_names.append(metric_dir.name)

        # Check for files: count and total size in ONE directory walk (the
        # original traversed files/ twice, once per statistic). Best-effort:
        # partial results are kept if the walk fails midway.
        files_dir = exp_dir / "files"
        if files_dir.exists():
            try:
                for entry in files_dir.rglob("*"):
                    if entry.is_file():
                        exp_info.file_count += 1
                        exp_info.estimated_size += entry.stat().st_size
            except (OSError, PermissionError):
                pass

        experiments.append(exp_info)

    return experiments
354
+
355
+
356
class ExperimentValidator:
    """Validates local experiment data before upload."""

    def __init__(self, strict: bool = False):
        """
        Initialize validator.

        Args:
            strict: If True, fail on any validation error
        """
        self.strict = strict

    def validate_experiment(self, exp_info: ExperimentInfo) -> ValidationResult:
        """
        Validate experiment directory structure and data.

        Only the experiment metadata is mandatory; problems with parameters,
        logs, metrics or files are recorded as warnings (data to be skipped)
        unless strict mode promotes them to errors.

        Args:
            exp_info: Experiment information

        Returns:
            ValidationResult with validation status and messages
        """
        result = ValidationResult()
        result.valid_data = {}

        # 1. Validate experiment metadata (required) — bail out early on failure.
        if not self._validate_experiment_metadata(exp_info, result):
            result.is_valid = False
            return result

        # 2. Validate parameters (optional)
        self._validate_parameters(exp_info, result)

        # 3. Validate logs (optional)
        self._validate_logs(exp_info, result)

        # 4. Validate metrics (optional)
        self._validate_metrics(exp_info, result)

        # 5. Validate files (optional)
        self._validate_files(exp_info, result)

        # In strict mode, any warning becomes an error
        if self.strict and result.warnings:
            result.errors.extend(result.warnings)
            result.warnings = []
            result.is_valid = False

        return result

    def _validate_experiment_metadata(self, exp_info: ExperimentInfo, result: ValidationResult) -> bool:
        """Validate experiment.json exists and is valid.

        On success stores the parsed metadata in result.valid_data["metadata"].
        Returns False (a fatal error) on any failure.
        """
        exp_json = exp_info.path / "experiment.json"

        if not exp_json.exists():
            result.errors.append("Missing experiment.json")
            return False

        try:
            with open(exp_json, "r") as f:
                metadata = json.load(f)

            # Check required fields
            if "name" not in metadata or "project" not in metadata:
                result.errors.append("experiment.json missing required fields (name, project)")
                return False

            result.valid_data["metadata"] = metadata
            return True

        except json.JSONDecodeError as e:
            result.errors.append(f"Invalid JSON in experiment.json: {e}")
            return False
        except IOError as e:
            result.errors.append(f"Cannot read experiment.json: {e}")
            return False

    def _validate_parameters(self, exp_info: ExperimentInfo, result: ValidationResult):
        """Validate parameters.json format.

        Accepts either a flat dict of parameters or a versioned wrapper with a
        "data" key; the effective dict lands in result.valid_data["parameters"].
        Any problem is a warning (parameters are simply skipped).
        """
        if not exp_info.has_params:
            return

        params_file = exp_info.path / "parameters.json"
        try:
            with open(params_file, "r") as f:
                params = json.load(f)

            # Check if it's a dict
            if not isinstance(params, dict):
                result.warnings.append("parameters.json is not a dict (will skip)")
                return

            # Check for valid data key if using versioned format
            if "data" in params:
                if not isinstance(params["data"], dict):
                    result.warnings.append("parameters.json data is not a dict (will skip)")
                    return
                result.valid_data["parameters"] = params["data"]
            else:
                result.valid_data["parameters"] = params

        except json.JSONDecodeError as e:
            result.warnings.append(f"Invalid JSON in parameters.json: {e} (will skip)")
        except IOError as e:
            result.warnings.append(f"Cannot read parameters.json: {e} (will skip)")

    def _validate_logs(self, exp_info: ExperimentInfo, result: ValidationResult):
        """Validate logs.jsonl format.

        Each line must be valid JSON with a "message" key; invalid lines are
        counted and reported as a single warning (they will be skipped during
        upload, mirroring the filter in ExperimentUploader._upload_logs).
        """
        if not exp_info.has_logs:
            return

        logs_file = exp_info.path / "logs" / "logs.jsonl"
        invalid_lines = []

        try:
            with open(logs_file, "r") as f:
                for line_num, line in enumerate(f, start=1):
                    try:
                        log_entry = json.loads(line)
                        # Check required fields
                        if "message" not in log_entry:
                            invalid_lines.append(line_num)
                    except json.JSONDecodeError:
                        invalid_lines.append(line_num)

            if invalid_lines:
                count = len(invalid_lines)
                # Only the first five offending line numbers are shown.
                preview = invalid_lines[:5]
                result.warnings.append(
                    f"logs.jsonl has {count} invalid lines (e.g., {preview}...) - will skip these"
                )

        except IOError as e:
            result.warnings.append(f"Cannot read logs.jsonl: {e} (will skip logs)")

    def _validate_metrics(self, exp_info: ExperimentInfo, result: ValidationResult):
        """Validate metrics data.

        For every discovered metric, each data.jsonl line must be valid JSON
        containing a "data" key; invalid lines produce one warning per metric.
        """
        if not exp_info.metric_names:
            return

        for metric_name in exp_info.metric_names:
            metric_dir = exp_info.path / "metrics" / metric_name
            data_file = metric_dir / "data.jsonl"

            invalid_lines = []
            try:
                with open(data_file, "r") as f:
                    for line_num, line in enumerate(f, start=1):
                        try:
                            data_point = json.loads(line)
                            # Check for data field
                            if "data" not in data_point:
                                invalid_lines.append(line_num)
                        except json.JSONDecodeError:
                            invalid_lines.append(line_num)

                if invalid_lines:
                    count = len(invalid_lines)
                    preview = invalid_lines[:5]
                    result.warnings.append(
                        f"metric '{metric_name}' has {count} invalid lines (e.g., {preview}...) - will skip these"
                    )

            except IOError as e:
                result.warnings.append(f"Cannot read metric '{metric_name}': {e} (will skip)")

    def _validate_files(self, exp_info: ExperimentInfo, result: ValidationResult):
        """Validate files existence.

        Cross-checks .files_metadata.json against the files actually present
        on disk; referenced-but-missing files produce a single warning.
        NOTE(review): the path here is built from file_info["prefix"] while
        ExperimentUploader._upload_files uses file_info["path"] — confirm the
        metadata schema; one of the two may be looking at the wrong key.
        """
        files_dir = exp_info.path / "files"
        if not files_dir.exists():
            return

        metadata_file = files_dir / ".files_metadata.json"
        if not metadata_file.exists():
            return

        try:
            with open(metadata_file, "r") as f:
                files_metadata = json.load(f)

            missing_files = []
            for file_id, file_info in files_metadata.items():
                # Only check entries that are dicts and not soft-deleted.
                if isinstance(file_info, dict) and file_info.get("deletedAt") is None:
                    # Check if file exists
                    file_path = files_dir / file_info.get("prefix", "") / file_id / file_info.get("filename", "")
                    if not file_path.exists():
                        missing_files.append(file_info.get("filename", file_id))

            if missing_files:
                count = len(missing_files)
                preview = missing_files[:3]
                result.warnings.append(
                    f"{count} files referenced in metadata but missing on disk (e.g., {preview}...) - will skip these"
                )

        except (json.JSONDecodeError, IOError):
            pass  # If we can't read metadata, just skip file validation
553
+
554
+
555
class ExperimentUploader:
    """Handles uploading a single experiment.

    Metadata, parameters, logs and files are uploaded sequentially through
    `self.remote`; metrics are uploaded concurrently via a thread pool, each
    worker using its own thread-local RemoteClient.
    """

    def __init__(
        self,
        local_storage: LocalStorage,
        remote_client: RemoteClient,
        batch_size: int = 100,
        skip_logs: bool = False,
        skip_metrics: bool = False,
        skip_files: bool = False,
        skip_params: bool = False,
        verbose: bool = False,
        progress: Optional[Progress] = None,
        max_concurrent_metrics: int = 5,
    ):
        """
        Initialize uploader.

        Args:
            local_storage: Local storage instance
            remote_client: Remote client instance
            batch_size: Batch size for logs/metrics
            skip_logs: Skip uploading logs
            skip_metrics: Skip uploading metrics
            skip_files: Skip uploading files
            skip_params: Skip uploading parameters
            verbose: Show verbose output
            progress: Optional rich Progress instance for tracking
            max_concurrent_metrics: Maximum concurrent metric uploads (default: 5)
        """
        self.local = local_storage
        self.remote = remote_client
        self.batch_size = batch_size
        self.skip_logs = skip_logs
        self.skip_metrics = skip_metrics
        self.skip_files = skip_files
        self.skip_params = skip_params
        self.verbose = verbose
        self.progress = progress
        self.max_concurrent_metrics = max_concurrent_metrics
        # Thread-safe lock for shared state updates
        self._lock = threading.Lock()
        # Thread-local storage for remote clients (for thread-safe HTTP requests)
        self._thread_local = threading.local()

    def _get_remote_client(self) -> RemoteClient:
        """Get thread-local remote client for safe concurrent access.

        Lazily builds one RemoteClient per worker thread so concurrent metric
        uploads never share an HTTP session.
        """
        if not hasattr(self._thread_local, 'client'):
            # Create a new client for this thread
            self._thread_local.client = RemoteClient(
                base_url=self.remote.base_url,
                api_key=self.remote.api_key
            )
        return self._thread_local.client

    def upload_experiment(
        self, exp_info: ExperimentInfo, validation_result: ValidationResult, task_id=None
    ) -> UploadResult:
        """
        Upload a single experiment with all its data.

        Steps (each optional parts gated by the skip_* flags): create/update
        the experiment, then upload parameters, logs, metrics and files. Any
        exception aborts the remaining steps and is recorded in the result.

        Args:
            exp_info: Experiment information
            validation_result: Validation results
            task_id: Optional progress task ID

        Returns:
            UploadResult with upload status
        """
        result = UploadResult(experiment=f"{exp_info.project}/{exp_info.experiment}")

        # Calculate total steps for progress tracking
        total_steps = 1  # metadata
        if not self.skip_params and "parameters" in validation_result.valid_data:
            total_steps += 1
        if not self.skip_logs and exp_info.has_logs:
            total_steps += 1
        if not self.skip_metrics and exp_info.metric_names:
            total_steps += len(exp_info.metric_names)
        if not self.skip_files and exp_info.file_count > 0:
            total_steps += exp_info.file_count

        current_step = 0

        def update_progress(description: str):
            # Advance the shared progress bar by one step and relabel it.
            nonlocal current_step
            current_step += 1
            if self.progress and task_id is not None:
                self.progress.update(task_id, completed=current_step, total=total_steps, description=description)

        try:
            # 1. Create/update experiment metadata
            update_progress("Creating experiment...")
            if self.verbose:
                console.print(f"  [dim]Creating experiment...[/dim]")

            exp_data = validation_result.valid_data

            # Store folder path in metadata (not as folderId which expects Snowflake ID)
            custom_metadata = exp_data.get("metadata") or {}
            if exp_data.get("folder"):
                custom_metadata["folder"] = exp_data["folder"]

            response = self.remote.create_or_update_experiment(
                project=exp_info.project,
                name=exp_info.experiment,
                description=exp_data.get("description"),
                tags=exp_data.get("tags"),
                bindrs=exp_data.get("bindrs"),
                folder=None,  # Don't send folder path as folderId (expects Snowflake ID)
                write_protected=exp_data.get("write_protected", False),
                metadata=custom_metadata if custom_metadata else None,
            )

            # Extract experiment ID from nested response
            experiment_id = response.get("experiment", {}).get("id") or response.get("id")
            if self.verbose:
                console.print(f"  [green]✓[/green] Created experiment (id: {experiment_id})")

            # 2. Upload parameters
            if not self.skip_params and "parameters" in validation_result.valid_data:
                update_progress("Uploading parameters...")
                if self.verbose:
                    console.print(f"  [dim]Uploading parameters...[/dim]")

                params = validation_result.valid_data["parameters"]
                self.remote.set_parameters(experiment_id, params)
                result.uploaded["params"] = len(params)
                # Track bytes (approximate JSON size)
                result.bytes_uploaded += len(json.dumps(params).encode('utf-8'))

                if self.verbose:
                    console.print(f"  [green]✓[/green] Uploaded {len(params)} parameters")

            # 3. Upload logs
            if not self.skip_logs and exp_info.has_logs:
                count = self._upload_logs(experiment_id, exp_info, result, task_id, update_progress)
                result.uploaded["logs"] = count

            # 4. Upload metrics
            if not self.skip_metrics and exp_info.metric_names:
                count = self._upload_metrics(experiment_id, exp_info, result, task_id, update_progress)
                result.uploaded["metrics"] = count

            # 5. Upload files
            if not self.skip_files and exp_info.file_count > 0:
                count = self._upload_files(experiment_id, exp_info, result, task_id, update_progress)
                result.uploaded["files"] = count

            result.success = True

        except Exception as e:
            # Any failure in the steps above marks the whole experiment failed.
            result.success = False
            result.errors.append(str(e))
            if self.verbose:
                console.print(f"  [red]✗ Error: {e}[/red]")

        return result

    def _upload_logs(self, experiment_id: str, exp_info: ExperimentInfo, result: UploadResult,
                     task_id=None, update_progress=None) -> int:
        """Upload logs in batches.

        Streams logs.jsonl line by line, skipping unparseable lines and lines
        without a "message" field, and flushes to the server every
        `batch_size` entries. Returns the number of entries uploaded.
        NOTE(review): bytes_uploaded is incremented when a line is batched,
        before the batch is sent — if a batch upload raises, the byte count
        overstates what actually reached the server.
        """
        if update_progress:
            update_progress("Uploading logs...")
        if self.verbose:
            console.print(f"  [dim]Uploading logs...[/dim]")

        logs_file = exp_info.path / "logs" / "logs.jsonl"
        logs_batch = []
        total_uploaded = 0
        skipped = 0

        try:
            with open(logs_file, "r") as f:
                for line in f:
                    try:
                        log_entry = json.loads(line)

                        # Validate required fields
                        if "message" not in log_entry:
                            skipped += 1
                            continue

                        # Prepare log entry for API
                        api_log = {
                            "timestamp": log_entry.get("timestamp"),
                            "level": log_entry.get("level", "info"),
                            "message": log_entry["message"],
                        }
                        if "metadata" in log_entry:
                            api_log["metadata"] = log_entry["metadata"]

                        logs_batch.append(api_log)
                        # Track bytes
                        result.bytes_uploaded += len(line.encode('utf-8'))

                        # Upload batch
                        if len(logs_batch) >= self.batch_size:
                            self.remote.create_log_entries(experiment_id, logs_batch)
                            total_uploaded += len(logs_batch)
                            logs_batch = []

                    except json.JSONDecodeError:
                        skipped += 1
                        continue

            # Upload remaining logs
            if logs_batch:
                self.remote.create_log_entries(experiment_id, logs_batch)
                total_uploaded += len(logs_batch)

            if self.verbose:
                msg = f"  [green]✓[/green] Uploaded {total_uploaded} log entries"
                if skipped > 0:
                    msg += f" (skipped {skipped} invalid)"
                console.print(msg)

        except IOError as e:
            result.failed.setdefault("logs", []).append(str(e))

        return total_uploaded

    def _upload_single_metric(
        self,
        experiment_id: str,
        metric_name: str,
        metric_dir: Path,
        result: UploadResult
    ) -> Dict[str, Any]:
        """
        Upload a single metric (thread-safe helper).

        Runs inside a worker thread; uses a thread-local RemoteClient and
        touches no shared state (the caller merges this return value under
        a lock).
        NOTE(review): on a mid-file failure this reports uploaded=0 and
        bytes=0 even if earlier batches were already sent to the server.

        Returns:
            Dict with 'success', 'uploaded', 'skipped', 'bytes', and 'error' keys
        """
        data_file = metric_dir / "data.jsonl"
        data_batch = []
        total_uploaded = 0
        skipped = 0
        bytes_uploaded = 0

        # Get thread-local client for safe concurrent HTTP requests
        remote_client = self._get_remote_client()

        try:
            with open(data_file, "r") as f:
                for line in f:
                    try:
                        data_point = json.loads(line)

                        # Validate required fields
                        if "data" not in data_point:
                            skipped += 1
                            continue

                        data_batch.append(data_point["data"])
                        bytes_uploaded += len(line.encode('utf-8'))

                        # Upload batch using thread-local client
                        if len(data_batch) >= self.batch_size:
                            remote_client.append_batch_to_metric(
                                experiment_id, metric_name, data_batch
                            )
                            total_uploaded += len(data_batch)
                            data_batch = []

                    except json.JSONDecodeError:
                        skipped += 1
                        continue

            # Upload remaining data points using thread-local client
            if data_batch:
                remote_client.append_batch_to_metric(experiment_id, metric_name, data_batch)
                total_uploaded += len(data_batch)

            return {
                'success': True,
                'uploaded': total_uploaded,
                'skipped': skipped,
                'bytes': bytes_uploaded,
                'error': None
            }

        except Exception as e:
            return {
                'success': False,
                'uploaded': 0,
                'skipped': 0,
                'bytes': 0,
                'error': str(e)
            }

    def _upload_metrics(self, experiment_id: str, exp_info: ExperimentInfo, result: UploadResult,
                        task_id=None, update_progress=None) -> int:
        """Upload metrics in parallel with concurrency limit.

        Fans out one _upload_single_metric task per metric over a thread pool
        of `max_concurrent_metrics` workers, merging results into the shared
        UploadResult under self._lock. Returns the number of metrics that
        uploaded successfully (not the number of data points).
        """
        if not exp_info.metric_names:
            return 0

        total_metrics = 0

        # Use ThreadPoolExecutor for parallel uploads
        with ThreadPoolExecutor(max_workers=self.max_concurrent_metrics) as executor:
            # Submit all metric upload tasks
            future_to_metric = {}
            for metric_name in exp_info.metric_names:
                metric_dir = exp_info.path / "metrics" / metric_name
                future = executor.submit(
                    self._upload_single_metric,
                    experiment_id,
                    metric_name,
                    metric_dir,
                    result
                )
                future_to_metric[future] = metric_name

            # Process completed uploads as they finish
            for future in as_completed(future_to_metric):
                metric_name = future_to_metric[future]

                # Update progress
                if update_progress:
                    update_progress(f"Uploading metric '{metric_name}'...")

                try:
                    upload_result = future.result()

                    # Thread-safe update of shared state
                    with self._lock:
                        result.bytes_uploaded += upload_result['bytes']

                    if upload_result['success']:
                        total_metrics += 1

                        # Thread-safe console output
                        if self.verbose:
                            msg = f"  [green]✓[/green] Uploaded {upload_result['uploaded']} data points for '{metric_name}'"
                            if upload_result['skipped'] > 0:
                                msg += f" (skipped {upload_result['skipped']} invalid)"
                            with self._lock:
                                console.print(msg)
                    else:
                        # Record failure
                        error_msg = f"{metric_name}: {upload_result['error']}"
                        with self._lock:
                            result.failed.setdefault("metrics", []).append(error_msg)
                            if self.verbose:
                                console.print(f"  [red]✗[/red] Failed to upload '{metric_name}': {upload_result['error']}")

                except Exception as e:
                    # Handle unexpected errors
                    error_msg = f"{metric_name}: {str(e)}"
                    with self._lock:
                        result.failed.setdefault("metrics", []).append(error_msg)
                        if self.verbose:
                            console.print(f"  [red]✗[/red] Failed to upload '{metric_name}': {e}")

        return total_metrics

    def _upload_files(self, experiment_id: str, exp_info: ExperimentInfo, result: UploadResult,
                      task_id=None, update_progress=None) -> int:
        """Upload files one by one.

        Lists files via LocalStorage, skips soft-deleted entries, resolves
        each file's on-disk path, and uploads it through self.remote.
        Per-file failures are recorded in result.failed["files"] without
        aborting the remaining files. Returns the number uploaded.
        NOTE(review): file paths are built from file_info["path"] here but
        from file_info["prefix"] in ExperimentValidator._validate_files —
        confirm which key the metadata schema actually uses.
        """
        files_dir = exp_info.path / "files"
        total_uploaded = 0

        # Use LocalStorage to list files
        try:
            files_list = self.local.list_files(exp_info.project, exp_info.experiment)

            for file_info in files_list:
                # Skip deleted files
                if file_info.get("deletedAt") is not None:
                    continue

                try:
                    if update_progress:
                        update_progress(f"Uploading {file_info['filename']}...")

                    # Get file path directly from storage without copying
                    file_id = file_info["id"]
                    experiment_dir = self.local._get_experiment_dir(exp_info.project, exp_info.experiment)
                    # Rebinds the files_dir computed above to the storage layout.
                    files_dir = experiment_dir / "files"

                    # Construct file path
                    file_prefix = file_info["path"].lstrip("/") if file_info["path"] else ""
                    if file_prefix:
                        file_path = files_dir / file_prefix / file_id / file_info["filename"]
                    else:
                        file_path = files_dir / file_id / file_info["filename"]

                    # Upload to remote with correct parameters
                    self.remote.upload_file(
                        experiment_id=experiment_id,
                        file_path=str(file_path),
                        prefix=file_info.get("path", ""),
                        filename=file_info["filename"],
                        description=file_info.get("description"),
                        tags=file_info.get("tags", []),
                        metadata=file_info.get("metadata"),
                        checksum=file_info["checksum"],
                        content_type=file_info["contentType"],
                        size_bytes=file_info["sizeBytes"],
                    )

                    total_uploaded += 1
                    # Track bytes
                    result.bytes_uploaded += file_info.get("sizeBytes", 0)

                    if self.verbose:
                        size_mb = file_info.get("sizeBytes", 0) / (1024 * 1024)
                        console.print(f"  [green]✓[/green] {file_info['filename']} ({size_mb:.1f}MB)")

                except Exception as e:
                    result.failed.setdefault("files", []).append(f"{file_info['filename']}: {e}")

        except Exception as e:
            result.failed.setdefault("files", []).append(str(e))

        if self.verbose and not result.failed.get("files"):
            console.print(f"  [green]✓[/green] Uploaded {total_uploaded} files")

        return total_uploaded
977
+
978
+
979
def cmd_upload(args: argparse.Namespace) -> int:
    """
    Execute the upload command: discover local experiments, validate them,
    and push their data (params, logs, metrics, files) to a remote server.

    Resolution order for connection settings is command line > config file;
    an API key may also be derived from --username. Supports --dry-run
    (preview only), --resume (skip experiments recorded as completed in the
    state file), and per-category --skip-* flags.

    Args:
        args: Parsed command-line arguments (expects .remote, .api_key,
            .username, .project, .experiment, .path, .state_file, .resume,
            .dry_run, .strict, .verbose, .batch_size, .skip_logs,
            .skip_metrics, .skip_files, .skip_params).

    Returns:
        Exit code: 0 if every upload succeeded (or dry run), 1 on any error.
    """
    import time  # local import kept function-scoped; hoisted from mid-function

    # Load config (provides defaults for remote URL / API key)
    config = Config()

    # Get remote URL (command line > config)
    remote_url = args.remote or config.remote_url
    if not remote_url:
        console.print("[red]Error:[/red] --remote URL is required (or set in config)")
        return 1

    # Get API key (command line > config > generate from username)
    api_key = args.api_key or config.api_key
    if not api_key:
        if args.username:
            console.print(f"[dim]Generating API key from username: {args.username}[/dim]")
            api_key = generate_api_key_from_username(args.username)
        else:
            console.print("[red]Error:[/red] --api-key or --username is required (or set in config)")
            return 1

    # --experiment only makes sense scoped to a single project
    if args.experiment and not args.project:
        console.print("[red]Error:[/red] --experiment requires --project")
        return 1

    # Local storage root must exist before we scan it
    local_path = Path(args.path)
    if not local_path.exists():
        console.print(f"[red]Error:[/red] Local storage path does not exist: {local_path}")
        return 1

    # Handle state file for resume functionality
    state_file = Path(args.state_file)
    upload_state = None

    if args.resume:
        upload_state = UploadState.load(state_file)
        if upload_state:
            # A stale state file (different path or remote) must not be reused
            if upload_state.local_path != str(local_path.absolute()):
                console.print("[yellow]Warning:[/yellow] State file local path doesn't match. Starting fresh upload.")
                upload_state = None
            elif upload_state.remote_url != remote_url:
                console.print("[yellow]Warning:[/yellow] State file remote URL doesn't match. Starting fresh upload.")
                upload_state = None
            else:
                console.print(f"[green]Resuming previous upload from {upload_state.timestamp}[/green]")
                console.print(f"  Already completed: {len(upload_state.completed_experiments)} experiments")
                console.print(f"  Failed: {len(upload_state.failed_experiments)} experiments")
        else:
            console.print("[yellow]No previous upload state found. Starting fresh upload.[/yellow]")

    # Create new state if not resuming (or the old state was rejected above)
    if not upload_state:
        upload_state = UploadState(
            local_path=str(local_path.absolute()),
            remote_url=remote_url,
        )

    console.print(f"[bold]Scanning local storage:[/bold] {local_path.absolute()}")
    experiments = discover_experiments(
        local_path,
        project_filter=args.project,
        experiment_filter=args.experiment,
    )

    if not experiments:
        # Tailor the message to how narrowly the user filtered
        if args.project and args.experiment:
            console.print(f"[yellow]No experiment found:[/yellow] {args.project}/{args.experiment}")
        elif args.project:
            console.print(f"[yellow]No experiments found in project:[/yellow] {args.project}")
        else:
            console.print("[yellow]No experiments found in local storage[/yellow]")
        return 1

    # Filter out already completed experiments when resuming
    if args.resume and upload_state.completed_experiments:
        original_count = len(experiments)
        experiments = [
            exp for exp in experiments
            if f"{exp.project}/{exp.experiment}" not in upload_state.completed_experiments
        ]
        skipped_count = original_count - len(experiments)
        if skipped_count > 0:
            console.print(f"[dim]Skipping {skipped_count} already completed experiment(s)[/dim]")

    console.print(f"[green]Found {len(experiments)} experiment(s) to upload[/green]")

    # Display discovered experiments (always shown in dry-run so the user
    # can see exactly what a real run would touch)
    if args.verbose or args.dry_run:
        console.print("\n[bold]Discovered experiments:[/bold]")
        for exp in experiments:
            parts = []
            if exp.has_logs:
                parts.append("logs")
            if exp.has_params:
                parts.append("params")
            if exp.metric_names:
                parts.append(f"{len(exp.metric_names)} metrics")
            if exp.file_count:
                size_mb = exp.estimated_size / (1024 * 1024)
                parts.append(f"{exp.file_count} files ({size_mb:.1f}MB)")

            details = ", ".join(parts) if parts else "metadata only"
            console.print(f"  [cyan]•[/cyan] {exp.project}/{exp.experiment} [dim]({details})[/dim]")

    # Dry-run mode: stop here. NOTE: everything below may therefore assume
    # args.dry_run is False.
    if args.dry_run:
        console.print("\n[yellow bold]DRY RUN[/yellow bold] - No data will be uploaded")
        console.print("Run without --dry-run to proceed with upload.")
        return 0

    # Validate experiments before spending time on network transfers
    console.print("\n[bold]Validating experiments...[/bold]")
    validator = ExperimentValidator(strict=args.strict)
    validation_results = {}
    valid_experiments = []
    invalid_experiments = []

    for exp in experiments:
        validation = validator.validate_experiment(exp)
        validation_results[f"{exp.project}/{exp.experiment}"] = validation

        if validation.is_valid:
            valid_experiments.append(exp)
        else:
            invalid_experiments.append(exp)

        # Show warnings and errors (errors always; warnings only in verbose)
        if args.verbose or validation.errors:
            exp_key = f"{exp.project}/{exp.experiment}"
            if validation.errors:
                console.print(f"  [red]✗[/red] {exp_key}:")
                for error in validation.errors:
                    console.print(f"    [red]{error}[/red]")
            elif validation.warnings:
                console.print(f"  [yellow]⚠[/yellow] {exp_key}:")
                for warning in validation.warnings:
                    console.print(f"    [yellow]{warning}[/yellow]")

    if invalid_experiments:
        console.print(f"\n[yellow]{len(invalid_experiments)} experiment(s) failed validation and will be skipped[/yellow]")
        if args.strict:
            # In strict mode any validation failure aborts the whole upload
            console.print("[red]Error: Validation failed in --strict mode[/red]")
            return 1

    if not valid_experiments:
        console.print("[red]Error: No valid experiments to upload[/red]")
        return 1

    console.print(f"[green]{len(valid_experiments)} experiment(s) ready to upload[/green]")

    # Initialize remote client and local storage
    remote_client = RemoteClient(base_url=remote_url, api_key=api_key)
    local_storage = LocalStorage(root_path=local_path)

    # Upload experiments with progress tracking
    console.print(f"\n[bold]Uploading to:[/bold] {remote_url}")
    results = []

    # Track upload timing for the summary table
    start_time = time.time()

    # Create progress bar for overall upload
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        console=console,
        transient=not args.verbose,  # Keep progress visible in verbose mode
    ) as progress:
        # Create uploader with progress tracking
        uploader = ExperimentUploader(
            local_storage=local_storage,
            remote_client=remote_client,
            batch_size=args.batch_size,
            skip_logs=args.skip_logs,
            skip_metrics=args.skip_metrics,
            skip_files=args.skip_files,
            skip_params=args.skip_params,
            verbose=args.verbose,
            progress=progress,
        )

        for i, exp in enumerate(valid_experiments, start=1):
            exp_key = f"{exp.project}/{exp.experiment}"

            # Create task for this experiment
            task_id = progress.add_task(
                f"[{i}/{len(valid_experiments)}] {exp_key}",
                total=100,  # Will be updated with actual steps
            )

            # Mark as in-progress and persist so an interrupted run can
            # resume. (Dry-run already returned above, so we always save.)
            upload_state.in_progress_experiment = exp_key
            upload_state.save(state_file)

            validation = validation_results[exp_key]
            result = uploader.upload_experiment(exp, validation, task_id=task_id)
            results.append(result)

            # Update state - mark as completed or failed
            upload_state.in_progress_experiment = None
            if result.success:
                upload_state.completed_experiments.append(exp_key)
            else:
                upload_state.failed_experiments.append(exp_key)
            upload_state.save(state_file)

            # Update task to completed
            progress.update(task_id, completed=100, total=100)

            if not args.verbose:
                # Show brief status
                if result.success:
                    parts = []
                    if result.uploaded.get("params"):
                        parts.append(f"{result.uploaded['params']} params")
                    if result.uploaded.get("logs"):
                        parts.append(f"{result.uploaded['logs']} logs")
                    if result.uploaded.get("metrics"):
                        parts.append(f"{result.uploaded['metrics']} metrics")
                    if result.uploaded.get("files"):
                        parts.append(f"{result.uploaded['files']} files")
                    status = ", ".join(parts) if parts else "metadata only"
                    console.print(f"  [green]✓[/green] Uploaded ({status})")
                else:
                    console.print("  [red]✗[/red] Failed")
                    if result.errors:
                        for error in result.errors[:3]:  # Show first 3 errors
                            console.print(f"    [red]{error}[/red]")

    # Calculate timing
    end_time = time.time()
    elapsed_time = end_time - start_time
    total_bytes = sum(r.bytes_uploaded for r in results)

    # Print summary with rich Table
    console.print()

    successful = [r for r in results if r.success]
    failed = [r for r in results if not r.success]

    # Create summary table
    summary_table = Table(title="Upload Summary", show_header=True, header_style="bold")
    summary_table.add_column("Status", style="cyan")
    summary_table.add_column("Count", justify="right")

    summary_table.add_row("Successful", f"[green]{len(successful)}/{len(results)}[/green]")
    if failed:
        summary_table.add_row("Failed", f"[red]{len(failed)}/{len(results)}[/red]")

    # Add timing information
    summary_table.add_row("Total Time", f"{elapsed_time:.2f}s")

    # Calculate and display upload speed in the most readable unit
    if total_bytes > 0 and elapsed_time > 0:
        if total_bytes < 1024 * 1024:  # Less than 1 MB
            speed_kb = (total_bytes / 1024) / elapsed_time
            summary_table.add_row("Avg Speed", f"{speed_kb:.2f} KB/s")
        else:  # 1 MB or more
            speed_mb = (total_bytes / (1024 * 1024)) / elapsed_time
            summary_table.add_row("Avg Speed", f"{speed_mb:.2f} MB/s")

    console.print(summary_table)

    # Show failed experiments with their errors
    if failed:
        console.print("\n[bold red]Failed Experiments:[/bold red]")
        for result in failed:
            console.print(f"  [red]✗[/red] {result.experiment}")
            for error in result.errors:
                console.print(f"    [dim]{error}[/dim]")

    # Data statistics across all experiments
    total_logs = sum(r.uploaded.get("logs", 0) for r in results)
    total_metrics = sum(r.uploaded.get("metrics", 0) for r in results)
    total_files = sum(r.uploaded.get("files", 0) for r in results)

    if total_logs or total_metrics or total_files:
        data_table = Table(title="Data Uploaded", show_header=True, header_style="bold")
        data_table.add_column("Type", style="cyan")
        data_table.add_column("Count", justify="right", style="green")

        if total_logs:
            data_table.add_row("Logs", f"{total_logs} entries")
        if total_metrics:
            data_table.add_row("Metrics", f"{total_metrics} metrics")
        if total_files:
            data_table.add_row("Files", f"{total_files} files")

        console.print()
        console.print(data_table)

    # Clean up state file if all uploads succeeded; otherwise keep it so
    # --resume can retry. (Dry-run returned long before this point.)
    if not failed and state_file.exists():
        state_file.unlink()
        console.print("\n[dim]Upload complete. State file removed.[/dim]")
    elif failed:
        console.print(f"\n[yellow]State saved to {state_file}. Use --resume to retry failed uploads.[/yellow]")

    # Return exit code
    return 0 if not failed else 1