b10_transfer-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,51 @@
+ """B10 Transfer - Lock-free PyTorch compilation cache for Baseten."""
+
+ from .core import transfer
+ from .torch_cache import load_compile_cache, save_compile_cache, clear_local_cache
+ from .async_transfers import (
+     start_transfer_async,
+     get_transfer_status,
+     is_transfer_complete,
+     wait_for_completion,
+     cancel_transfer,
+     list_active_transfers,
+     TransferProgress,
+ )
+ from .async_torch_cache import (
+     load_compile_cache_async,
+     save_compile_cache_async,
+ )
+ from .utils import CacheError, CacheValidationError
+ from .space_monitor import CacheOperationInterrupted
+ from .info import get_cache_info, list_available_caches
+ from .constants import SaveStatus, LoadStatus, TransferStatus, AsyncTransferStatus
+
+ # Version
+ __version__ = "0.0.1"
+
+ __all__ = [
+     "CacheError",
+     "CacheValidationError",
+     "CacheOperationInterrupted",
+     "SaveStatus",
+     "LoadStatus",
+     "TransferStatus",
+     "AsyncTransferStatus",
+     "transfer",
+     "load_compile_cache",
+     "save_compile_cache",
+     "clear_local_cache",
+     "get_cache_info",
+     "list_available_caches",
+     # Generic async operations
+     "start_transfer_async",
+     "get_transfer_status",
+     "is_transfer_complete",
+     "wait_for_completion",
+     "cancel_transfer",
+     "list_active_transfers",
+     "TransferProgress",
+     # Torch-specific async operations
+     "load_compile_cache_async",
+     "save_compile_cache_async",
+ ]
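
For orientation, a minimal sketch of the synchronous API exported above, as it might appear in a serving script (the surrounding deployment details are assumptions, not part of the package):

    import b10_transfer

    # Try to restore a previously saved torch.compile cache before warmup.
    load_status = b10_transfer.load_compile_cache()

    # ... compile / warm up the model here ...

    # Persist the freshly populated cache for future replicas.
    save_status = b10_transfer.save_compile_cache()

    # Presumably reports outcomes via the LoadStatus / SaveStatus enums
    # re-exported from .constants.
    print(load_status, save_status)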
@@ -0,0 +1,182 @@
+ import os
+ import logging
+ import subprocess
+ from pathlib import Path
+
+ from .utils import timed_fn, safe_unlink, CacheValidationError, validate_path_security
+ from .constants import MAX_CACHE_SIZE_MB
+
+ logger = logging.getLogger(__name__)
+
+
+ class ArchiveError(Exception):
+     """Archive operation failed."""
+
+
+ def get_file_size_mb(file_path: Path) -> float:
+     """Get the size of a file in megabytes.
+
+     Args:
+         file_path: Path to the file to measure.
+
+     Returns:
+         float: File size in megabytes, or 0.0 if the file doesn't exist
+             or can't be accessed.
+
+     Raises:
+         Nothing; OSError is caught and 0.0 is returned instead.
+     """
+     try:
+         return file_path.stat().st_size / (1024 * 1024)
+     except OSError:
+         return 0.0
+
+
+ def _compress_directory_to_tar(source_dir: Path, target_file: Path) -> None:
+     """Compress directory contents into a gzipped tar archive using system tar.
+
+     This function recursively compresses all files in the source directory
+     into a gzipped tar archive, shelling out to the system tar command for
+     better performance.
+
+     Args:
+         source_dir: Path to the directory to compress.
+         target_file: Path where the compressed archive will be created.
+
+     Raises:
+         OSError: If the tar command fails (the underlying
+             subprocess.CalledProcessError is chained as its cause) or the
+             target file can't be written.
+     """
+     # Use the system tar command for better performance.
+     # -czf: create, gzip, file
+     # -C: change to directory before archiving
+     cmd = ["tar", "-czf", str(target_file), "-C", str(source_dir), "."]
+
+     try:
+         subprocess.run(cmd, check=True, capture_output=True, text=True)
+     except subprocess.CalledProcessError as e:
+         raise OSError(f"tar compression failed: {e.stderr}") from e
+
+
+ @timed_fn(logger=logger, name="Creating archive")
+ def create_archive(
+     source_dir: Path, target_file: Path, max_size_mb: int = MAX_CACHE_SIZE_MB
+ ) -> None:
+     """Create a compressed archive with path validation and size limits.
+
+     This function safely creates a gzipped tar archive from a source directory
+     with security validation and size constraints. It validates paths to prevent
+     directory traversal attacks and enforces a maximum archive size limit.
+
+     Args:
+         source_dir: Path to the directory to archive. Must exist and be within
+             the allowed directories (/tmp/ or its own parent).
+         target_file: Path where the archive will be created. Must be within
+             the allowed directories (/app or /cache).
+         max_size_mb: Maximum allowed archive size in megabytes. Defaults to
+             MAX_CACHE_SIZE_MB.
+
+     Raises:
+         CacheValidationError: If paths are outside allowed directories.
+         ArchiveError: If the source directory doesn't exist, archive creation
+             fails, or the archive exceeds the size limit.
+     """
+     # Validate paths
+     validate_path_security(
+         str(source_dir),
+         ["/tmp/", str(source_dir.parent)],
+         f"Source directory {source_dir}",
+         CacheValidationError,
+     )
+     validate_path_security(
+         str(target_file),
+         ["/app", "/cache"],
+         f"Target file {target_file}",
+         CacheValidationError,
+     )
+
+     if not source_dir.exists():
+         raise ArchiveError(f"Source directory missing: {source_dir}")
+
+     target_file.parent.mkdir(parents=True, exist_ok=True)
+
+     try:
+         _compress_directory_to_tar(source_dir, target_file)
+         size_mb = get_file_size_mb(target_file)
+
+         if size_mb > max_size_mb:
+             safe_unlink(
+                 target_file, f"Failed to delete oversized archive {target_file}"
+             )
+             raise ArchiveError(f"Archive too large: {size_mb:.1f}MB > {max_size_mb}MB")
+
+     except ArchiveError:
+         # Already cleaned up above; don't re-wrap the size-limit error.
+         raise
+     except Exception as e:
+         safe_unlink(target_file, f"Failed to cleanup failed archive {target_file}")
+         raise ArchiveError(f"Archive creation failed: {e}") from e
+
+
+ @timed_fn(logger=logger, name="Extracting archive")
+ def extract_archive(archive_file: Path, target_dir: Path) -> None:
+     """Extract a compressed archive with security validation.
+
+     This function safely extracts a gzipped tar archive to a target directory
+     with security checks to prevent directory traversal attacks. It validates
+     both the archive and target paths, and inspects archive contents for
+     malicious paths before extraction.
+
+     Args:
+         archive_file: Path to the archive file to extract. Must exist and be
+             within the allowed directories (/app or /cache).
+         target_dir: Path to the directory where files will be extracted. Must
+             be within the allowed directories (/tmp/ or its own parent).
+
+     Raises:
+         CacheValidationError: If paths are outside allowed directories or the
+             archive contains unsafe paths (absolute paths or paths with '..'
+             components).
+         ArchiveError: If the archive file doesn't exist or extraction fails.
+     """
+     # Validate paths
+     validate_path_security(
+         str(archive_file),
+         ["/app", "/cache"],
+         f"Archive file {archive_file}",
+         CacheValidationError,
+     )
+     validate_path_security(
+         str(target_dir),
+         ["/tmp/", str(target_dir.parent)],
+         f"Target directory {target_dir}",
+         CacheValidationError,
+     )
+
+     if not archive_file.exists():
+         raise ArchiveError(f"Archive missing: {archive_file}")
+
+     try:
+         target_dir.mkdir(parents=True, exist_ok=True)
+
+         # First, perform a security check by listing archive contents.
+         list_cmd = ["tar", "-tzf", str(archive_file)]
+         result = subprocess.run(list_cmd, check=True, capture_output=True, text=True)
+
+         # Reject absolute paths and any '..' path component.
+         for path in result.stdout.strip().split("\n"):
+             if path and (os.path.isabs(path) or ".." in Path(path).parts):
+                 raise CacheValidationError(f"Unsafe path in archive: {path}")
+
+         # Extract using the system tar command for better performance.
+         extract_cmd = ["tar", "-xzf", str(archive_file), "-C", str(target_dir)]
+         subprocess.run(extract_cmd, check=True, capture_output=True, text=True)
+
+     except subprocess.CalledProcessError as e:
+         raise ArchiveError(f"tar extraction failed: {e.stderr}") from e
+     except CacheValidationError:
+         # Propagate unsafe-path errors as documented instead of re-wrapping.
+         raise
+     except Exception as e:
+         raise ArchiveError(f"Extraction failed: {e}") from e
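
A round-trip sketch of the two helpers above (the module path b10_transfer.archive is an assumption from the wheel layout, and the concrete paths are illustrative; they must satisfy the allowed-directory checks: sources and extraction targets under /tmp/, archives under /app or /cache):

    from pathlib import Path
    # Assumed module path; this hunk does not name the file.
    from b10_transfer.archive import create_archive, extract_archive, ArchiveError

    try:
        # Pack a local cache directory into a size-capped gzipped tar.
        create_archive(Path("/tmp/torch_cache"), Path("/cache/model_cache.tgz"))
        # Later, or on another replica: unpack it back onto local disk.
        extract_archive(Path("/cache/model_cache.tgz"), Path("/tmp/torch_cache"))
    except ArchiveError as e:
        print(f"Archive round trip failed: {e}")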
@@ -0,0 +1,62 @@
+ """Torch-specific async cache operations using the generic async transfer system."""
+
+ from pathlib import Path
+ from typing import Optional, Callable
+
+ from .async_transfers import start_transfer_async
+ from .torch_cache import torch_cache_load_callback, torch_cache_save_callback
+ from .environment import get_cache_filename
+ from .constants import (
+     TORCH_CACHE_DIR,
+     B10FS_CACHE_DIR,
+     MAX_CACHE_SIZE_MB,
+     CACHE_FILE_EXTENSION,
+     CACHE_LATEST_SUFFIX,
+ )
+
+
+ def load_compile_cache_async(
+     progress_callback: Optional[Callable[[str], None]] = None,
+ ) -> str:
+     """Start async PyTorch compilation cache load operation."""
+     b10fs_dir = Path(B10FS_CACHE_DIR)
+     torch_dir = Path(TORCH_CACHE_DIR)
+
+     cache_filename = get_cache_filename()
+     cache_file = (
+         b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
+     )
+
+     return start_transfer_async(
+         source=cache_file,
+         dest=torch_dir,
+         callback=torch_cache_load_callback,
+         operation_name="torch_cache_load",
+         progress_callback=progress_callback,
+         monitor_local=True,
+         monitor_b10fs=False,  # No need to monitor b10fs for read operations
+     )
+
+
+ def save_compile_cache_async(
+     progress_callback: Optional[Callable[[str], None]] = None,
+ ) -> str:
+     """Start async PyTorch compilation cache save operation."""
+     b10fs_dir = Path(B10FS_CACHE_DIR)
+     torch_dir = Path(TORCH_CACHE_DIR)
+
+     cache_filename = get_cache_filename()
+     final_file = (
+         b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
+     )
+
+     return start_transfer_async(
+         source=torch_dir,
+         dest=final_file,
+         callback=torch_cache_save_callback,
+         operation_name="torch_cache_save",
+         progress_callback=progress_callback,
+         monitor_local=True,
+         monitor_b10fs=True,
+         max_size_mb=MAX_CACHE_SIZE_MB,
+     )
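
The two wrappers above reduce to "build the b10fs cache path, then hand off to start_transfer_async". A polling sketch using the operation ID they return (the timeout value is illustrative):

    import b10_transfer

    op_id = b10_transfer.load_compile_cache_async()

    # Block until the background load reaches a terminal status, up to 300s.
    if b10_transfer.wait_for_completion(op_id, timeout=300):
        progress = b10_transfer.get_transfer_status(op_id)
        print(f"Load finished with status: {progress.status}")
    else:
        # Timed out; flag the record as cancelled (the worker keeps running).
        b10_transfer.cancel_transfer(op_id)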
@@ -0,0 +1,282 @@
+ """Generic async transfer operations with progress tracking."""
+
+ import logging
+ import threading
+ import time
+ from pathlib import Path
+ from typing import Optional, Dict, Any, Callable
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta
+ from concurrent.futures import ThreadPoolExecutor
+
+ from .core import transfer
+ from .constants import AsyncTransferStatus, TransferStatus
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class TransferProgress:
+     operation_id: str
+     operation_name: str
+     status: AsyncTransferStatus
+     started_at: Optional[datetime] = None
+     completed_at: Optional[datetime] = None
+     error_message: Optional[str] = None
+     progress_callback: Optional[Callable[[str], None]] = None
+
+     def update_status(
+         self, status: AsyncTransferStatus, error_message: Optional[str] = None
+     ):
+         self.status = status
+         if error_message:
+             self.error_message = error_message
+         if status == AsyncTransferStatus.IN_PROGRESS and self.started_at is None:
+             self.started_at = datetime.now()
+         elif status in [
+             AsyncTransferStatus.SUCCESS,
+             AsyncTransferStatus.ERROR,
+             AsyncTransferStatus.INTERRUPTED,
+             AsyncTransferStatus.CANCELLED,
+         ]:
+             self.completed_at = datetime.now()
+
+         # Notify callback if provided
+         if self.progress_callback:
+             try:
+                 self.progress_callback(self.operation_id)
+             except Exception as e:
+                 logger.warning(f"Progress callback failed for {self.operation_id}: {e}")
+
+
+ class AsyncTransferManager:
+     def __init__(self):
+         self._transfers: Dict[str, TransferProgress] = {}
+         self._executor = ThreadPoolExecutor(
+             max_workers=2, thread_name_prefix="b10-async-transfer"
+         )
+         # Reentrant lock: list_active_transfers() and
+         # cleanup_completed_transfers() call is_transfer_complete() while
+         # holding it, which would deadlock on a plain Lock.
+         self._lock = threading.RLock()
+         self._operation_counter = 0
+
+     def _generate_operation_id(self, operation_name: str) -> str:
+         with self._lock:
+             self._operation_counter += 1
+             return f"{operation_name}_{self._operation_counter}_{int(datetime.now().timestamp())}"
+
+     def start_transfer_async(
+         self,
+         source: Path,
+         dest: Path,
+         callback: Callable,
+         operation_name: str,
+         progress_callback: Optional[Callable[[str], None]] = None,
+         monitor_local: bool = True,
+         monitor_b10fs: bool = True,
+         **callback_kwargs,
+     ) -> str:
+         operation_id = self._generate_operation_id(operation_name)
+
+         progress = TransferProgress(
+             operation_id=operation_id,
+             operation_name=operation_name,
+             status=AsyncTransferStatus.NOT_STARTED,
+             progress_callback=progress_callback,
+         )
+
+         with self._lock:
+             self._transfers[operation_id] = progress
+
+         # Submit the transfer operation to the thread pool.
+         self._executor.submit(
+             self._execute_transfer,
+             operation_id,
+             source,
+             dest,
+             callback,
+             monitor_local,
+             monitor_b10fs,
+             callback_kwargs,
+         )
+
+         logger.info(f"Started async transfer operation: {operation_id}")
+         return operation_id
+
+     def _execute_transfer(
+         self,
+         operation_id: str,
+         source: Path,
+         dest: Path,
+         callback: Callable,
+         monitor_local: bool,
+         monitor_b10fs: bool,
+         callback_kwargs: Dict[str, Any],
+     ) -> None:
+         progress = self.get_transfer_status(operation_id)
+         if not progress:
+             logger.error(f"Progress tracking lost for operation {operation_id}")
+             return
+
+         try:
+             progress.update_status(AsyncTransferStatus.IN_PROGRESS)
+             logger.info(f"Starting transfer for operation {operation_id}")
+
+             result = transfer(
+                 source=source,
+                 dest=dest,
+                 callback=callback,
+                 monitor_local=monitor_local,
+                 monitor_b10fs=monitor_b10fs,
+                 **callback_kwargs,
+             )
+
+             # Convert TransferStatus to AsyncTransferStatus
+             if result == TransferStatus.SUCCESS:
+                 progress.update_status(AsyncTransferStatus.SUCCESS)
+                 logger.info(f"Transfer completed successfully: {operation_id}")
+             elif result == TransferStatus.INTERRUPTED:
+                 progress.update_status(
+                     AsyncTransferStatus.INTERRUPTED,
+                     "Transfer interrupted due to insufficient disk space",
+                 )
+                 logger.warning(f"Transfer interrupted: {operation_id}")
+             else:
+                 progress.update_status(
+                     AsyncTransferStatus.ERROR, "Transfer operation failed"
+                 )
+                 logger.error(f"Transfer failed: {operation_id}")
+
+         except Exception as e:
+             progress.update_status(AsyncTransferStatus.ERROR, str(e))
+             logger.error(
+                 f"Transfer operation {operation_id} failed with exception: {e}"
+             )
+
+     def get_transfer_status(self, operation_id: str) -> Optional[TransferProgress]:
+         with self._lock:
+             return self._transfers.get(operation_id)
+
+     def is_transfer_complete(self, operation_id: str) -> bool:
+         progress = self.get_transfer_status(operation_id)
+         if not progress:
+             return False
+
+         return progress.status in [
+             AsyncTransferStatus.SUCCESS,
+             AsyncTransferStatus.ERROR,
+             AsyncTransferStatus.INTERRUPTED,
+             AsyncTransferStatus.CANCELLED,
+         ]
+
+     def wait_for_completion(
+         self, operation_id: str, timeout: Optional[float] = None
+     ) -> bool:
+         start_time = datetime.now()
+
+         while not self.is_transfer_complete(operation_id):
+             if (
+                 timeout is not None
+                 and (datetime.now() - start_time).total_seconds() > timeout
+             ):
+                 return False
+             time.sleep(0.1)  # Small delay to avoid busy waiting
+
+         return True
+
+     def cancel_transfer(self, operation_id: str) -> bool:
+         progress = self.get_transfer_status(operation_id)
+         if not progress:
+             return False
+
+         if progress.status == AsyncTransferStatus.IN_PROGRESS:
+             # Note: this only flags the record; the worker thread itself
+             # is not interrupted.
+             progress.update_status(AsyncTransferStatus.CANCELLED)
+             logger.info(f"Marked transfer operation as cancelled: {operation_id}")
+             return True
+
+         return False
+
+     def cleanup_completed_transfers(self, max_age_hours: int = 24) -> int:
+         cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
+         cleaned_count = 0
+
+         with self._lock:
+             to_remove = []
+             for operation_id, progress in self._transfers.items():
+                 if (
+                     progress.completed_at
+                     and progress.completed_at < cutoff_time
+                     and self.is_transfer_complete(operation_id)
+                 ):
+                     to_remove.append(operation_id)
+
+             for operation_id in to_remove:
+                 del self._transfers[operation_id]
+                 cleaned_count += 1
+
+         if cleaned_count > 0:
+             logger.info(f"Cleaned up {cleaned_count} completed transfer records")
+
+         return cleaned_count
+
+     def list_active_transfers(self) -> Dict[str, TransferProgress]:
+         with self._lock:
+             return {
+                 op_id: progress
+                 for op_id, progress in self._transfers.items()
+                 if not self.is_transfer_complete(op_id)
+             }
+
+     def shutdown(self) -> None:
+         logger.info("Shutting down async transfer manager...")
+         self._executor.shutdown(wait=True)
+
+
+ # Global instance for easy access
+ _transfer_manager = AsyncTransferManager()
+
+
+ # Generic public API functions
+ def start_transfer_async(
+     source: Path,
+     dest: Path,
+     callback: Callable,
+     operation_name: str,
+     progress_callback: Optional[Callable[[str], None]] = None,
+     monitor_local: bool = True,
+     monitor_b10fs: bool = True,
+     **callback_kwargs,
+ ) -> str:
+     return _transfer_manager.start_transfer_async(
+         source=source,
+         dest=dest,
+         callback=callback,
+         operation_name=operation_name,
+         progress_callback=progress_callback,
+         monitor_local=monitor_local,
+         monitor_b10fs=monitor_b10fs,
+         **callback_kwargs,
+     )
+
+
+ def get_transfer_status(operation_id: str) -> Optional[TransferProgress]:
+     return _transfer_manager.get_transfer_status(operation_id)
+
+
+ def is_transfer_complete(operation_id: str) -> bool:
+     return _transfer_manager.is_transfer_complete(operation_id)
+
+
+ def wait_for_completion(operation_id: str, timeout: Optional[float] = None) -> bool:
+     return _transfer_manager.wait_for_completion(operation_id, timeout)
+
+
+ def cancel_transfer(operation_id: str) -> bool:
+     return _transfer_manager.cancel_transfer(operation_id)
+
+
+ def list_active_transfers() -> Dict[str, TransferProgress]:
+     return _transfer_manager.list_active_transfers()
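
Beyond the torch-specific wrappers, the module-level functions above expose the manager for arbitrary transfers. A sketch with a custom callback, assuming core.transfer() invokes callback(source, dest, **callback_kwargs) (the exact callback contract lives in .core and is not shown in this diff):

    from pathlib import Path
    import b10_transfer

    def restore_callback(source: Path, dest: Path, **kwargs) -> None:
        # Hypothetical work unit: copy/unpack source into dest.
        ...

    op_id = b10_transfer.start_transfer_async(
        source=Path("/cache/data.tgz"),
        dest=Path("/tmp/data"),
        callback=restore_callback,
        operation_name="custom_restore",
        monitor_b10fs=False,
    )
    for active_id, progress in b10_transfer.list_active_transfers().items():
        print(active_id, progress.status, progress.started_at)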