ml-dash 0.6.6__py3-none-any.whl → 0.6.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ml_dash/buffer.py ADDED
@@ -0,0 +1,735 @@
1
+ """
2
+ Background buffering system for ML-Dash time-series resources.
3
+
4
+ Provides non-blocking writes for logs, metrics, and files by batching
5
+ operations in a background thread.
6
+ """
7
+
8
+ import os
9
+ import threading
10
+ import time
11
+ import warnings
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ from datetime import datetime
14
+ from queue import Empty, Queue
15
+ from typing import Any, Dict, List, Optional
16
+
17
+
18
+ def _serialize_value(value: Any) -> Any:
19
+ """
20
+ Convert value to JSON-serializable format.
21
+
22
+ Handles numpy arrays, nested dicts/lists, etc.
23
+
24
+ Args:
25
+ value: Value to serialize
26
+
27
+ Returns:
28
+ JSON-serializable value
29
+ """
30
+ # Check for numpy array
31
+ if hasattr(value, '__array__') or (hasattr(value, 'tolist') and hasattr(value, 'dtype')):
32
+ # It's a numpy array
33
+ try:
34
+ return value.tolist()
35
+ except AttributeError:
36
+ pass
37
+
38
+ # Check for numpy scalar types
39
+ if hasattr(value, 'item'):
40
+ try:
41
+ return value.item()
42
+ except (AttributeError, ValueError):
43
+ pass
44
+
45
+ # Recursively handle dicts
46
+ if isinstance(value, dict):
47
+ return {k: _serialize_value(v) for k, v in value.items()}
48
+
49
+ # Recursively handle lists
50
+ if isinstance(value, (list, tuple)):
51
+ return [_serialize_value(v) for v in value]
52
+
53
+ # Return as-is for other types (int, float, str, bool, None)
54
+ return value
55
+
56
+
57
class BufferConfig:
    """Configuration for buffering behavior."""

    def __init__(
        self,
        flush_interval: float = 5.0,
        log_batch_size: int = 100,
        metric_batch_size: int = 100,
        track_batch_size: int = 100,
        file_upload_workers: int = 4,
        buffer_enabled: bool = True,
    ):
        """
        Initialize buffer configuration.

        Args:
            flush_interval: Time-based flush interval in seconds (default: 5.0)
            log_batch_size: Max logs per batch (default: 100)
            metric_batch_size: Max metric points per batch (default: 100)
            track_batch_size: Max track entries per batch (default: 100)
            file_upload_workers: Number of parallel file upload threads (default: 4)
            buffer_enabled: Enable/disable buffering (default: True)
        """
        self.flush_interval = flush_interval
        self.log_batch_size = log_batch_size
        self.metric_batch_size = metric_batch_size
        self.track_batch_size = track_batch_size
        self.file_upload_workers = file_upload_workers
        self.buffer_enabled = buffer_enabled

    @classmethod
    def from_env(cls) -> "BufferConfig":
        """Build a configuration from ML_DASH_* environment variables."""
        env = os.environ.get
        # Accept a small set of truthy spellings for the enable flag.
        enabled_flag = env("ML_DASH_BUFFER_ENABLED", "true").lower()
        return cls(
            flush_interval=float(env("ML_DASH_FLUSH_INTERVAL", "5.0")),
            log_batch_size=int(env("ML_DASH_LOG_BATCH_SIZE", "100")),
            metric_batch_size=int(env("ML_DASH_METRIC_BATCH_SIZE", "100")),
            track_batch_size=int(env("ML_DASH_TRACK_BATCH_SIZE", "100")),
            file_upload_workers=int(env("ML_DASH_FILE_UPLOAD_WORKERS", "4")),
            buffer_enabled=enabled_flag in ("true", "1", "yes"),
        )
101
+
102
+
103
class BackgroundBufferManager:
    """Unified buffer manager with background flushing thread.

    Buffers logs, metrics, track entries, and file uploads in memory and
    flushes them to the remote client and/or local storage from a daemon
    thread, so caller-side writes never block on network or disk I/O.

    Fix over the previous revision: the shutdown path and ``flush_all`` now
    drain the log/metric queues completely. ``_flush_logs`` and
    ``_flush_metric`` consume at most one batch per call, so a single call on
    exit silently dropped anything beyond one batch size.
    """

    def __init__(self, experiment: "Experiment", config: "BufferConfig"):
        """
        Initialize background buffer manager.

        Args:
            experiment: Parent experiment instance (provides the
                ``.run._client`` and ``.run._storage`` backends used on flush)
            config: Buffer configuration (batch sizes, flush interval,
                upload parallelism)
        """
        self._experiment = experiment
        self._config = config

        # Resource-specific buffers. Queue instances are thread-safe; the
        # plain dicts below rely on the GIL for atomicity of single operations.
        self._log_queue: Queue = Queue()
        self._metric_queues: Dict[Optional[str], Queue] = {}  # one queue per metric name
        # Per-topic: {timestamp: merged_data}. Entries sharing a timestamp
        # are merged on insert (see buffer_track).
        self._track_buffers: Dict[str, Dict[float, Dict[str, Any]]] = {}
        self._file_queue: Queue = Queue()

        # Last flush time per resource, used by the time-based trigger.
        self._last_log_flush = time.time()
        self._last_metric_flush: Dict[Optional[str], float] = {}
        self._last_track_flush: Dict[str, float] = {}  # per-topic flush times

        # Background thread control.
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._flush_event = threading.Event()  # manual flush trigger

    def start(self) -> None:
        """Start the background flushing thread (no-op if already running)."""
        if self._thread is not None:
            return  # Already started

        self._stop_event.clear()
        self._flush_event.clear()
        self._thread = threading.Thread(target=self._flush_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """
        Stop thread and flush remaining items.

        Waits indefinitely for all buffered data to be flushed to ensure data
        integrity. This is important for large file uploads which may take
        significant time.
        """
        if self._thread is None:
            return  # Not started

        # Snapshot pending counts for the user-facing message. Informational
        # only: producers may still enqueue until the thread exits.
        log_count = self._log_queue.qsize()
        metric_count = sum(q.qsize() for q in self._metric_queues.values())
        track_count = sum(len(entries) for entries in self._track_buffers.values())
        file_count = self._file_queue.qsize()
        has_pending = (
            log_count > 0 or metric_count > 0 or track_count > 0 or file_count > 0
        )

        if has_pending:
            print("\n[ML-Dash] Flushing buffered data...", flush=True)

            items = []
            if log_count > 0:
                items.append(f"{log_count} log(s)")
            if metric_count > 0:
                items.append(f"{metric_count} metric point(s)")
            if track_count > 0:
                items.append(f"{track_count} track entry(ies)")
            if file_count > 0:
                items.append(f"{file_count} file(s)")

            if items:
                print(f"[ML-Dash] - {', '.join(items)}", flush=True)

        # Signal stop and trigger one last flush cycle.
        self._stop_event.set()
        self._flush_event.set()

        # Wait for thread to finish (no timeout - ensure all data is flushed).
        self._thread.join()

        if has_pending:
            print("[ML-Dash] ✓ All data flushed successfully", flush=True)

        self._thread = None

    def buffer_log(
        self,
        message: str,
        level: str,
        metadata: Optional[Dict[str, Any]],
        timestamp: Optional[datetime],
    ) -> None:
        """
        Add log to buffer (non-blocking).

        Args:
            message: Log message
            level: Log level
            metadata: Optional metadata
            timestamp: Optional timestamp (defaults to current UTC time)
        """
        # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12+;
        # kept here to preserve the existing "...Z"-suffixed wire format.
        log_entry = {
            "timestamp": (timestamp or datetime.utcnow()).isoformat() + "Z",
            "level": level,
            "message": message,
        }

        if metadata:
            log_entry["metadata"] = metadata

        self._log_queue.put(log_entry)

    def buffer_metric(
        self,
        metric_name: Optional[str],
        data: Dict[str, Any],
        description: Optional[str],
        tags: Optional[List[str]],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """
        Add metric datapoint to buffer (non-blocking).

        Args:
            metric_name: Metric name (can be None for unnamed metrics)
            data: Data point
            description: Optional description
            tags: Optional tags
            metadata: Optional metadata
        """
        # Lazily create the per-metric queue and its flush timestamp.
        if metric_name not in self._metric_queues:
            self._metric_queues[metric_name] = Queue()
            self._last_metric_flush[metric_name] = time.time()

        metric_entry = {
            "data": data,
            "description": description,
            "tags": tags,
            "metadata": metadata,
        }

        self._metric_queues[metric_name].put(metric_entry)

    def buffer_track(
        self,
        topic: str,
        timestamp: float,
        data: Dict[str, Any],
    ) -> None:
        """
        Add track entry to buffer (non-blocking) with timestamp-based merging.

        Entries with the same timestamp are automatically merged.

        Args:
            topic: Track topic (e.g., "robot/position")
            timestamp: Entry timestamp
            data: Data fields
        """
        # Lazily create the per-topic buffer and its flush timestamp.
        if topic not in self._track_buffers:
            self._track_buffers[topic] = {}
            self._last_track_flush[topic] = time.time()

        # Serialize data up front to handle numpy arrays and other
        # non-JSON types before they reach a backend.
        serialized_data = _serialize_value(data)

        # Merge with an existing entry at the same timestamp, if any.
        if timestamp in self._track_buffers[topic]:
            self._track_buffers[topic][timestamp].update(serialized_data)
        else:
            self._track_buffers[topic][timestamp] = serialized_data

    def buffer_file(
        self,
        file_path: str,
        prefix: str,
        filename: str,
        description: Optional[str],
        tags: Optional[List[str]],
        metadata: Optional[Dict[str, Any]],
        checksum: str,
        content_type: str,
        size_bytes: int,
    ) -> None:
        """
        Add file upload to queue (non-blocking).

        Args:
            file_path: Local file path
            prefix: Logical path prefix
            filename: Original filename
            description: Optional description
            tags: Optional tags
            metadata: Optional metadata
            checksum: SHA256 checksum
            content_type: MIME type
            size_bytes: File size in bytes
        """
        file_entry = {
            "file_path": file_path,
            "prefix": prefix,
            "filename": filename,
            "description": description,
            "tags": tags,
            "metadata": metadata,
            "checksum": checksum,
            "content_type": content_type,
            "size_bytes": size_bytes,
        }

        self._file_queue.put(file_entry)

    def flush_all(self) -> None:
        """
        Manually flush all buffered data immediately.

        Drains every queued log, metric, track, and file without waiting for
        time or size triggers. Loops until the log/metric queues are fully
        empty, since each internal flush call consumes at most one batch.
        """
        # Snapshot counts for the user-facing message.
        log_count = self._log_queue.qsize()
        metric_count = sum(q.qsize() for q in self._metric_queues.values())
        track_count = sum(len(entries) for entries in self._track_buffers.values())
        file_count = self._file_queue.qsize()
        has_pending = (
            log_count > 0 or metric_count > 0 or track_count > 0 or file_count > 0
        )

        if has_pending:
            items = []
            if log_count > 0:
                items.append(f"{log_count} log(s)")
            if metric_count > 0:
                items.append(f"{metric_count} metric point(s)")
            if track_count > 0:
                items.append(f"{track_count} track entry(ies)")
            if file_count > 0:
                items.append(f"{file_count} file(s)")

            if items:
                print(f"[ML-Dash] Flushing {', '.join(items)}...", flush=True)

        # Drain logs completely (each _flush_logs call consumes one batch).
        while not self._log_queue.empty():
            self._flush_logs()

        # Drain every metric queue completely.
        for metric_name, queue in list(self._metric_queues.items()):
            while not queue.empty():
                self._flush_metric(metric_name)

        # Flush all tracks immediately (track buffers flush whole).
        self.flush_tracks()

        # Flush files immediately.
        self._flush_files()

        if has_pending:
            print("[ML-Dash] ✓ Flush complete", flush=True)

    def flush_tracks(self) -> None:
        """
        Flush all track topics immediately.

        This is called by TracksManager.flush() for global track flush.
        """
        for topic in list(self._track_buffers.keys()):
            self.flush_track(topic)

    def flush_track(self, topic: str) -> None:
        """
        Flush specific track topic immediately.

        Args:
            topic: Track topic to flush
        """
        if topic not in self._track_buffers or not self._track_buffers[topic]:
            return

        self._flush_track(topic)

    def _flush_loop(self) -> None:
        """Background thread main loop."""
        while not self._stop_event.is_set():
            # Wait for a manual flush or poll every 100ms so time/size
            # triggers are evaluated promptly.
            triggered = self._flush_event.wait(timeout=0.1)

            current_time = time.time()

            # Logs: flush on manual trigger, elapsed interval, or full batch.
            if not self._log_queue.empty() and (
                triggered
                or current_time - self._last_log_flush >= self._config.flush_interval
                or self._log_queue.qsize() >= self._config.log_batch_size
            ):
                self._flush_logs()

            # Metrics: same triggers, evaluated per metric queue.
            for metric_name, queue in list(self._metric_queues.items()):
                if not queue.empty() and (
                    triggered
                    or current_time - self._last_metric_flush.get(metric_name, 0)
                    >= self._config.flush_interval
                    or queue.qsize() >= self._config.metric_batch_size
                ):
                    self._flush_metric(metric_name)

            # Tracks: same triggers, evaluated per topic.
            for topic, entries in list(self._track_buffers.items()):
                if entries and (
                    triggered
                    or current_time - self._last_track_flush.get(topic, 0)
                    >= self._config.flush_interval
                    or len(entries) >= self._config.track_batch_size
                ):
                    self._flush_track(topic)

            # Files are uploaded as soon as any are queued.
            if not self._file_queue.empty():
                self._flush_files()

            # Clear the manual-flush event after servicing it.
            if triggered:
                self._flush_event.clear()

        # Final flush on shutdown. _flush_logs/_flush_metric consume at most
        # one batch per call, so loop until each queue is fully drained;
        # a single call here used to drop everything beyond one batch.
        while not self._log_queue.empty():
            self._flush_logs()
        for metric_name, queue in list(self._metric_queues.items()):
            while not queue.empty():
                self._flush_metric(metric_name)
        for topic in list(self._track_buffers.keys()):
            self._flush_track(topic)
        self._flush_files()

    def _flush_logs(self) -> None:
        """Batch flush logs using client.create_log_entries().

        Consumes at most one batch (log_batch_size entries) per call.
        """
        if self._log_queue.empty():
            return

        # Collect up to one batch from the queue.
        batch = []
        try:
            while len(batch) < self._config.log_batch_size:
                log_entry = self._log_queue.get_nowait()
                batch.append(log_entry)
        except Empty:
            pass  # Queue exhausted

        if not batch:
            return

        # Remote backend: single batched call; failures warn but never raise
        # into the training loop.
        if self._experiment.run._client:
            try:
                self._experiment.run._client.create_log_entries(
                    experiment_id=self._experiment._experiment_id,
                    logs=batch,
                )
            except Exception as e:
                warnings.warn(
                    f"Failed to flush {len(batch)} logs to remote server: {e}. "
                    f"Training will continue.",
                    RuntimeWarning,
                    stacklevel=3,
                )

        # Local storage writes one at a time (no batch API).
        if self._experiment.run._storage:
            for log_entry in batch:
                try:
                    self._experiment.run._storage.write_log(
                        owner=self._experiment.run.owner,
                        project=self._experiment.run.project,
                        prefix=self._experiment.run._folder_path,
                        message=log_entry["message"],
                        level=log_entry["level"],
                        metadata=log_entry.get("metadata"),
                        timestamp=log_entry["timestamp"],
                    )
                except Exception as e:
                    warnings.warn(
                        f"Failed to write log to local storage: {e}",
                        RuntimeWarning,
                        stacklevel=3,
                    )

        self._last_log_flush = time.time()

    def _flush_metric(self, metric_name: Optional[str]) -> None:
        """
        Batch flush metrics using client.append_batch_to_metric().

        Consumes at most one batch (metric_batch_size points) per call.

        Args:
            metric_name: Metric name (can be None for unnamed metrics)
        """
        queue = self._metric_queues.get(metric_name)
        if queue is None or queue.empty():
            return

        # Collect up to one batch; keep the first non-None description /
        # tags / metadata seen in the batch as the metric-level attributes.
        batch = []
        description = None
        tags = None
        metadata = None

        try:
            while len(batch) < self._config.metric_batch_size:
                metric_entry = queue.get_nowait()
                batch.append(metric_entry["data"])

                if description is None and metric_entry["description"]:
                    description = metric_entry["description"]
                if tags is None and metric_entry["tags"]:
                    tags = metric_entry["tags"]
                if metadata is None and metric_entry["metadata"]:
                    metadata = metric_entry["metadata"]
        except Empty:
            pass  # Queue exhausted

        if not batch:
            return

        # Remote backend; failures warn but never raise into the caller.
        if self._experiment.run._client:
            try:
                self._experiment.run._client.append_batch_to_metric(
                    experiment_id=self._experiment._experiment_id,
                    metric_name=metric_name,
                    data_points=batch,
                    description=description,
                    tags=tags,
                    metadata=metadata,
                )
            except Exception as e:
                metric_display = f"'{metric_name}'" if metric_name else "unnamed metric"
                warnings.warn(
                    f"Failed to flush {len(batch)} points to {metric_display} on remote server: {e}. "
                    f"Training will continue.",
                    RuntimeWarning,
                    stacklevel=3,
                )

        # Local storage backend.
        if self._experiment.run._storage:
            try:
                self._experiment.run._storage.append_batch_to_metric(
                    owner=self._experiment.run.owner,
                    project=self._experiment.run.project,
                    prefix=self._experiment.run._folder_path,
                    metric_name=metric_name,
                    data_points=batch,
                    description=description,
                    tags=tags,
                    metadata=metadata,
                )
            except Exception as e:
                metric_display = f"'{metric_name}'" if metric_name else "unnamed metric"
                warnings.warn(
                    f"Failed to flush {len(batch)} points to {metric_display} in local storage: {e}",
                    RuntimeWarning,
                    stacklevel=3,
                )

        self._last_metric_flush[metric_name] = time.time()

    def _flush_track(self, topic: str) -> None:
        """
        Batch flush track entries using client.append_batch_to_track().

        Flushes the topic's entire buffer (tracks are not batch-capped).

        Args:
            topic: Track topic
        """
        entries_dict = self._track_buffers.get(topic)
        if not entries_dict:
            return

        # Convert the timestamp-indexed dict to batch entries, oldest first.
        batch = []
        for timestamp, data in sorted(entries_dict.items()):
            entry = {"timestamp": timestamp}
            entry.update(data)
            batch.append(entry)

        if not batch:
            return

        # Reset the buffer for this topic.
        # NOTE(review): not synchronized with buffer_track(); a merge into the
        # old dict between building `batch` and this reassignment would be
        # lost. Confirm whether tracks are written from threads other than
        # the one calling flush before adding a lock.
        self._track_buffers[topic] = {}

        # Remote backend; failures warn but never raise into the caller.
        if self._experiment.run._client:
            try:
                self._experiment.run._client.append_batch_to_track(
                    experiment_id=self._experiment._experiment_id,
                    topic=topic,
                    entries=batch,
                )
            except Exception as e:
                warnings.warn(
                    f"Failed to flush {len(batch)} entries to track '{topic}' on remote server: {e}. "
                    f"Training will continue.",
                    RuntimeWarning,
                    stacklevel=3,
                )

        # Local storage backend.
        if self._experiment.run._storage:
            try:
                self._experiment.run._storage.append_batch_to_track(
                    owner=self._experiment.run.owner,
                    project=self._experiment.run.project,
                    prefix=self._experiment.run._folder_path,
                    topic=topic,
                    entries=batch,
                )
            except Exception as e:
                warnings.warn(
                    f"Failed to flush {len(batch)} entries to track '{topic}' in local storage: {e}",
                    RuntimeWarning,
                    stacklevel=3,
                )

        self._last_track_flush[topic] = time.time()

    def _flush_files(self) -> None:
        """Upload all queued files in parallel using ThreadPoolExecutor."""
        if self._file_queue.empty():
            return

        # Drain the whole file queue.
        files_to_upload = []
        try:
            while not self._file_queue.empty():
                file_entry = self._file_queue.get_nowait()
                files_to_upload.append(file_entry)
        except Empty:
            pass  # Queue exhausted

        if not files_to_upload:
            return

        total_files = len(files_to_upload)
        print(f"[ML-Dash] Uploading {total_files} file(s)...", flush=True)

        # Upload in parallel; each future owns one file entry.
        completed = 0
        with ThreadPoolExecutor(max_workers=self._config.file_upload_workers) as executor:
            future_to_file = {
                executor.submit(self._upload_single_file, file_entry): file_entry
                for file_entry in files_to_upload
            }

            # Report per-file progress as uploads complete; failures warn
            # but do not abort the remaining uploads.
            for future in as_completed(future_to_file):
                file_entry = future_to_file[future]
                try:
                    future.result()
                    completed += 1
                    if total_files > 1:
                        print(f"[ML-Dash] [{completed}/{total_files}] Uploaded {file_entry['filename']}", flush=True)
                except Exception as e:
                    completed += 1
                    warnings.warn(
                        f"Failed to upload file {file_entry['filename']}: {e}",
                        RuntimeWarning,
                        stacklevel=3,
                    )

    def _upload_single_file(self, file_entry: Dict[str, Any]) -> None:
        """
        Upload a single file to the configured backends.

        Temp files produced by save helpers are deleted (and their directory
        removed if empty) after the upload attempt, success or failure.

        Args:
            file_entry: File metadata dict as built by buffer_file()
        """
        import tempfile

        file_path = file_entry["file_path"]
        temp_dir = None

        # Files placed under the system temp root were created by save
        # methods and must be cleaned up after upload.
        temp_root = tempfile.gettempdir()
        is_temp_file = file_path.startswith(temp_root)
        if is_temp_file:
            temp_dir = os.path.dirname(file_path)

        try:
            # Exceptions propagate to the executor, where _flush_files
            # converts them into warnings.
            if self._experiment.run._client:
                self._experiment.run._client.upload_file(
                    experiment_id=self._experiment._experiment_id,
                    file_path=file_entry["file_path"],
                    prefix=file_entry["prefix"],
                    filename=file_entry["filename"],
                    description=file_entry["description"],
                    tags=file_entry["tags"],
                    metadata=file_entry["metadata"],
                    checksum=file_entry["checksum"],
                    content_type=file_entry["content_type"],
                    size_bytes=file_entry["size_bytes"],
                )

            if self._experiment.run._storage:
                self._experiment.run._storage.write_file(
                    owner=self._experiment.run.owner,
                    project=self._experiment.run.project,
                    prefix=self._experiment.run._folder_path,
                    file_path=file_entry["file_path"],
                    path=file_entry["prefix"],
                    filename=file_entry["filename"],
                    description=file_entry["description"],
                    tags=file_entry["tags"],
                    metadata=file_entry["metadata"],
                    checksum=file_entry["checksum"],
                    content_type=file_entry["content_type"],
                    size_bytes=file_entry["size_bytes"],
                )
        finally:
            # Clean up temp file and directory if this was a temp file.
            if is_temp_file and temp_dir:
                try:
                    if os.path.exists(file_path):
                        os.unlink(file_path)
                    if os.path.exists(temp_dir) and not os.listdir(temp_dir):
                        os.rmdir(temp_dir)
                except Exception:
                    pass  # Ignore cleanup errors