b10-transfer 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
@@ -0,0 +1,127 @@
+ Metadata-Version: 2.3
+ Name: b10-transfer
+ Version: 0.1.2
+ Summary: Distributed PyTorch file transfer for Baseten - Environment-aware, lock-free file transfer management
+ License: MIT
+ Keywords: pytorch,file-transfer,cache,machine-learning,inference
+ Author: Shounak Ray
+ Author-email: shounak.noreply@baseten.co
+ Maintainer: Fred Liu
+ Maintainer-email: fred.liu.noreply@baseten.co
+ Requires-Python: >=3.9,<4.0
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Dist: torch (>=2.0.0)
+ Requires-Dist: triton (>=2.0.0)
+ Project-URL: Documentation, https://docs.baseten.co/development/model/b10-transfer
+ Project-URL: Homepage, https://docs.baseten.co/development/model/b10-transfer
+ Project-URL: Repository, https://pypi.org/project/b10-transfer/
+ Description-Content-Type: text/markdown
+
+ https://www.notion.so/ml-infra/mega-base-cache-24291d247273805b8e20fe26677b7b0f
+
+ # B10 Transfer
+
+ PyTorch file transfer for Baseten deployments.
+
+ ## Usage
+
+ ```python
+ import b10_transfer
+
+ # Inside model.load() function
+ def load():
+     # Load cache before torch.compile()
+     cache_loaded = b10_transfer.load_compile_cache()
+
+     # ...
+
+     # Your model compilation
+     model = torch.compile(model)
+     # Warm up the model with dummy prompts and the arguments typically used
+     # in your requests (e.g., resolutions)
+     dummy_input = "What is the capital of France?"
+     model(dummy_input)
+
+     # ...
+
+     # Save cache after compilation
+     if not cache_loaded:
+         b10_transfer.save_compile_cache()
+ ```
+
+ ## Configuration
+
+ Configure via environment variables:
+
+ ```bash
+ # Cache directories
+ export TORCH_CACHE_DIR="/tmp/torchinductor_root" # Default
+ export B10FS_CACHE_DIR="/cache/model/compile_cache" # Default
+ export LOCAL_WORK_DIR="/app" # Default
+
+ # Cache limits
+ export MAX_CACHE_SIZE_MB="1024" # 1GB default
+ ```
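
The same settings can also be applied from Python before the library is used. This is a minimal sketch: that `b10_transfer` reads these variables at import time rather than per call is an assumption here, and the values shown are just the documented defaults.

```python
import os

# Hypothetical override of the documented defaults; adjust for your deployment.
os.environ["TORCH_CACHE_DIR"] = "/tmp/torchinductor_root"
os.environ["B10FS_CACHE_DIR"] = "/cache/model/compile_cache"
os.environ["LOCAL_WORK_DIR"] = "/app"
os.environ["MAX_CACHE_SIZE_MB"] = "2048"  # raise the limit from 1 GB to 2 GB

import b10_transfer  # imported after the environment is configured
```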
+
+ ## How It Works
+
+ ### Environment-Specific Caching
+
+ The library automatically creates unique cache keys based on your environment:
+
+ ```
+ torch-2.1.0_cuda-12.1_cc-8.6_triton-2.1.0 → cache_a1b2c3d4e5f6.latest.tar.gz
+ torch-2.0.1_cuda-11.8_cc-7.5_triton-2.0.1 → cache_x9y8z7w6v5u4.latest.tar.gz
+ torch-2.1.0_cpu_triton-none → cache_m1n2o3p4q5r6.latest.tar.gz
+ ```
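
Conceptually, deriving such a key looks like the sketch below. This is an illustration only: the helper name, the hashing scheme, and the 12-character digest length are assumptions, not the package's actual internals.

```python
import hashlib

import torch


def environment_cache_filename() -> str:
    """Illustrative sketch: build a cache filename from the runtime environment."""
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        hardware = f"cuda-{torch.version.cuda}_cc-{major}.{minor}"
    else:
        hardware = "cpu"
    try:
        import triton
        triton_part = f"triton-{triton.__version__}"
    except ImportError:
        triton_part = "triton-none"
    key = f"torch-{torch.__version__}_{hardware}_{triton_part}"
    digest = hashlib.sha256(key.encode()).hexdigest()[:12]
    return f"cache_{digest}.latest.tar.gz"
```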
+
+ **Components used:**
+ - **PyTorch version** (e.g., `torch-2.1.0`)
+ - **CUDA version** (e.g., `cuda-12.1` or `cpu`)
+ - **GPU compute capability** (e.g., `cc-8.6` for an A10G)
+ - **Triton version** (e.g., `triton-2.1.0` or `triton-none`)
+
+ ### Cache Workflow
+
+ 1. **Load Phase** (startup): Generate the environment key, check B10FS for a matching cache archive, and extract it into the local cache directory (see the sketch after this list).
+ 2. **Save Phase** (after compilation): Create an archive and copy it atomically to B10FS under an environment-specific filename.
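
The load phase, in particular, reduces to roughly the following. This is a simplified sketch under stated assumptions: `load_phase` is a hypothetical helper, the paths are the documented defaults, and the real implementation adds validation, space monitoring, and error handling.

```python
import tarfile
from pathlib import Path

B10FS_CACHE_DIR = Path("/cache/model/compile_cache")  # documented default
TORCH_CACHE_DIR = Path("/tmp/torchinductor_root")     # documented default


def load_phase(cache_filename: str) -> bool:
    """Extract a matching cache archive from B10FS into the local torch cache dir."""
    archive = B10FS_CACHE_DIR / cache_filename  # e.g. cache_a1b2c3d4e5f6.latest.tar.gz
    if not archive.exists():
        return False  # nothing to reuse; torch.compile() starts cold
    TORCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(TORCH_CACHE_DIR)
    return True
```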
+
+ ### Lock-Free Race Prevention
+
+ Cache saves use a journal pattern built on atomic filesystem operations, so parallel replicas can save safely without locks.
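
The underlying idea can be sketched as below. The `.incomplete` suffix and the helper name are assumptions made for illustration, not the package's actual convention.

```python
import shutil
import tarfile
from pathlib import Path


def save_phase(torch_cache_dir: Path, b10fs_dir: Path, cache_filename: str) -> None:
    """Archive the local cache, then publish it to B10FS with an atomic rename."""
    local_archive = Path("/tmp") / cache_filename
    with tarfile.open(local_archive, "w:gz") as tar:
        tar.add(torch_cache_dir, arcname=".")

    incomplete = b10fs_dir / (cache_filename + ".incomplete")  # journal entry: work in progress
    final = b10fs_dir / cache_filename

    shutil.copy2(local_archive, incomplete)  # the slow copy lands under a temporary name
    incomplete.rename(final)  # atomic within one filesystem: readers only ever see complete archives
```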
+
+ ## API Reference
+
+ ### Functions
+
+ - `load_compile_cache() -> bool`: Load cache from B10FS for current environment
+ - `save_compile_cache() -> bool`: Save cache to B10FS with environment-specific filename
+ - `clear_local_cache() -> bool`: Clear local cache directory
+ - `get_cache_info() -> Dict[str, Any]`: Get cache status information for current environment
+ - `list_available_caches() -> Dict[str, Any]`: List all cache files with environment details
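
The inspection helpers in this list can be exercised as follows. This is a brief sketch; the structure of the returned dictionaries is not documented here, so the code only prints them.

```python
import b10_transfer

# Status of the cache for the current torch/CUDA/Triton environment.
info = b10_transfer.get_cache_info()
print("Cache info:", info)

# Every cache archive visible in B10FS, across environments.
available = b10_transfer.list_available_caches()
print("Available caches:", available)

# Force a cold compile on the next load by dropping the local extraction.
b10_transfer.clear_local_cache()
```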
+
+ ### Exceptions
+
+ - `CacheError`: Base exception for cache operations
+ - `CacheValidationError`: Path validation or compatibility check failed
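
A defensive wrapper might treat these exceptions as non-fatal and fall back to a cold compile. This sketch assumes both exception classes are exported from the top-level `b10_transfer` namespace; treating cache failures as non-fatal is a design choice here, not a documented requirement.

```python
import logging

import b10_transfer

logger = logging.getLogger(__name__)


def load_cache_safely() -> bool:
    """Try to reuse a compile cache; never let cache problems break model startup."""
    try:
        return b10_transfer.load_compile_cache()
    except b10_transfer.CacheValidationError as exc:
        logger.warning("Cache failed validation, compiling from scratch: %s", exc)
    except b10_transfer.CacheError as exc:
        logger.warning("Cache unavailable, compiling from scratch: %s", exc)
    return False
```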
+
+ ## Performance Impact
+
+ ### Debugging
+
+ Enable debug logging:
+
+ ```python
+ import logging
+ logging.getLogger('b10_transfer').setLevel(logging.DEBUG)
+ ```
+
@@ -0,0 +1,12 @@
+ b10_transfer/__init__.py,sha256=kZXd7GHMH7PFzr4aXs19MmuFzwEvtwzh0rwyK4jHgHo,641
+ b10_transfer/archive.py,sha256=GKb0mi0-YeM7ch4FLAoOLHXw0T6LkRerYad2N2y9TYM,6400
+ b10_transfer/cleanup.py,sha256=3RnqWNGMCcko5GQdq1Gr9VPpGzAF5J6x7xjIH9SNZ78,6226
+ b10_transfer/constants.py,sha256=EmWCh9AOamCZL3KkSU6YJO6KBkh93OmAIUeEfaZxHL0,4321
+ b10_transfer/core.py,sha256=tKA1gWDEpcb_Xfr6njScAGWCxIT3htlyVh9VLg67YMg,15445
+ b10_transfer/environment.py,sha256=aC0biEMQrtHk0ke_3epdcq1X9J5fPmPpBVt0fH7XF2Y,5625
+ b10_transfer/info.py,sha256=I3iOuImZ5r6DMJTDeBtVvzlSn6IuyPJbLJYUO_OF0ks,6299
+ b10_transfer/space_monitor.py,sha256=C_CKDH43bNsWdq60WStSZ3c_nQkWvScQmqU_SYHesew,10531
+ b10_transfer/utils.py,sha256=Stee0DFK-8MRRYNIocqaK64cJvfs4jPW3Mpx7zkWV6Y,11932
+ b10_transfer-0.1.2.dist-info/METADATA,sha256=P0Yf1VzkqFV4eco8x1yAjgUiahlmjgJGTAee0NNhh6o,4108
+ b10_transfer-0.1.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ b10_transfer-0.1.2.dist-info/RECORD,,
@@ -1,62 +0,0 @@
- """Torch-specific async cache operations using the generic async transfer system."""
-
- from pathlib import Path
- from typing import Optional, Callable
-
- from .async_transfers import start_transfer_async
- from .torch_cache import torch_cache_load_callback, torch_cache_save_callback
- from .environment import get_cache_filename
- from .constants import (
-     TORCH_CACHE_DIR,
-     B10FS_CACHE_DIR,
-     MAX_CACHE_SIZE_MB,
-     CACHE_FILE_EXTENSION,
-     CACHE_LATEST_SUFFIX,
- )
-
-
- def load_compile_cache_async(
-     progress_callback: Optional[Callable[[str], None]] = None,
- ) -> str:
-     """Start async PyTorch compilation cache load operation."""
-     b10fs_dir = Path(B10FS_CACHE_DIR)
-     torch_dir = Path(TORCH_CACHE_DIR)
-
-     cache_filename = get_cache_filename()
-     cache_file = (
-         b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
-     )
-
-     return start_transfer_async(
-         source=cache_file,
-         dest=torch_dir,
-         callback=torch_cache_load_callback,
-         operation_name="torch_cache_load",
-         progress_callback=progress_callback,
-         monitor_local=True,
-         monitor_b10fs=False,  # No need to monitor b10fs for read operations
-     )
-
-
- def save_compile_cache_async(
-     progress_callback: Optional[Callable[[str], None]] = None,
- ) -> str:
-     """Start async PyTorch compilation cache save operation."""
-     b10fs_dir = Path(B10FS_CACHE_DIR)
-     torch_dir = Path(TORCH_CACHE_DIR)
-
-     cache_filename = get_cache_filename()
-     final_file = (
-         b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
-     )
-
-     return start_transfer_async(
-         source=torch_dir,
-         dest=final_file,
-         callback=torch_cache_save_callback,
-         operation_name="torch_cache_save",
-         progress_callback=progress_callback,
-         monitor_local=True,
-         monitor_b10fs=True,
-         max_size_mb=MAX_CACHE_SIZE_MB,
-     )
@@ -1,283 +0,0 @@
- """Generic async transfer operations with progress tracking."""
-
- import asyncio
- import logging
- import threading
- import time
- from pathlib import Path
- from typing import Optional, Dict, Any, Callable
- from dataclasses import dataclass
- from datetime import datetime, timedelta
- from concurrent.futures import ThreadPoolExecutor
-
- from .core import transfer
- from .constants import AsyncTransferStatus, TransferStatus
-
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class TransferProgress:
-     operation_id: str
-     operation_name: str
-     status: AsyncTransferStatus
-     started_at: Optional[datetime] = None
-     completed_at: Optional[datetime] = None
-     error_message: Optional[str] = None
-     progress_callback: Optional[Callable[[str], None]] = None
-
-     def update_status(
-         self, status: AsyncTransferStatus, error_message: Optional[str] = None
-     ):
-         self.status = status
-         if error_message:
-             self.error_message = error_message
-         if status == AsyncTransferStatus.IN_PROGRESS and self.started_at is None:
-             self.started_at = datetime.now()
-         elif status in [
-             AsyncTransferStatus.SUCCESS,
-             AsyncTransferStatus.ERROR,
-             AsyncTransferStatus.INTERRUPTED,
-             AsyncTransferStatus.CANCELLED,
-             AsyncTransferStatus.DOES_NOT_EXIST,
-         ]:
-             self.completed_at = datetime.now()
-
-         # Notify callback if provided
-         if self.progress_callback:
-             try:
-                 self.progress_callback(self.operation_id)
-             except Exception as e:
-                 logger.warning(f"Progress callback failed for {self.operation_id}: {e}")
-
-
- class AsyncTransferManager:
-     def __init__(self):
-         self._transfers: Dict[str, TransferProgress] = {}
-         self._executor = ThreadPoolExecutor(
-             max_workers=2, thread_name_prefix="b10-async-transfer"
-         )
-         self._lock = threading.Lock()
-         self._operation_counter = 0
-
-     def _generate_operation_id(self, operation_name: str) -> str:
-         with self._lock:
-             self._operation_counter += 1
-             return f"{operation_name}_{self._operation_counter}_{int(datetime.now().timestamp())}"
-
-     def start_transfer_async(
-         self,
-         source: Path,
-         dest: Path,
-         callback: Callable,
-         operation_name: str,
-         progress_callback: Optional[Callable[[str], None]] = None,
-         monitor_local: bool = True,
-         monitor_b10fs: bool = True,
-         **callback_kwargs,
-     ) -> str:
-         operation_id = self._generate_operation_id(operation_name)
-
-         progress = TransferProgress(
-             operation_id=operation_id,
-             operation_name=operation_name,
-             status=AsyncTransferStatus.NOT_STARTED,
-             progress_callback=progress_callback,
-         )
-
-         with self._lock:
-             self._transfers[operation_id] = progress
-
-         # Submit the transfer operation to thread pool
-         future = self._executor.submit(
-             self._execute_transfer,
-             operation_id,
-             source,
-             dest,
-             callback,
-             monitor_local,
-             monitor_b10fs,
-             callback_kwargs,
-         )
-
-         logger.info(f"Started async transfer operation: {operation_id}")
-         return operation_id
-
-     def _execute_transfer(
-         self,
-         operation_id: str,
-         source: Path,
-         dest: Path,
-         callback: Callable,
-         monitor_local: bool,
-         monitor_b10fs: bool,
-         callback_kwargs: Dict[str, Any],
-     ) -> None:
-         progress = self._transfers.get(operation_id)
-         if not progress:
-             logger.error(f"Progress tracking lost for operation {operation_id}")
-             return
-
-         try:
-             progress.update_status(AsyncTransferStatus.IN_PROGRESS)
-             logger.info(f"Starting transfer for operation {operation_id}")
-
-             result = transfer(
-                 source=source,
-                 dest=dest,
-                 callback=callback,
-                 monitor_local=monitor_local,
-                 monitor_b10fs=monitor_b10fs,
-                 **callback_kwargs,
-             )
-
-             # Convert TransferStatus to AsyncTransferStatus
-             if result == TransferStatus.SUCCESS:
-                 progress.update_status(AsyncTransferStatus.SUCCESS)
-                 logger.info(f"Transfer completed successfully: {operation_id}")
-             elif result == TransferStatus.INTERRUPTED:
-                 progress.update_status(
-                     AsyncTransferStatus.INTERRUPTED,
-                     "Transfer interrupted due to insufficient disk space",
-                 )
-                 logger.warning(f"Transfer interrupted: {operation_id}")
-             elif result == TransferStatus.DOES_NOT_EXIST:
-                 progress.update_status(
-                     AsyncTransferStatus.DOES_NOT_EXIST,
-                     "Cache file not found",
-                 )
-                 logger.info(f"Transfer failed - file not found: {operation_id}")
-             else:
-                 progress.update_status(
-                     AsyncTransferStatus.ERROR, "Transfer operation failed"
-                 )
-                 logger.error(f"Transfer failed: {operation_id}")
-
-         except Exception as e:
-             progress.update_status(AsyncTransferStatus.ERROR, str(e))
-             logger.error(
-                 f"Transfer operation {operation_id} failed with exception: {e}"
-             )
-
-     def get_transfer_status(self, operation_id: str) -> Optional[TransferProgress]:
-         with self._lock:
-             return self._transfers.get(operation_id)
-
-     def is_transfer_complete(self, operation_id: str) -> bool:
-         progress = self.get_transfer_status(operation_id)
-         if not progress:
-             return False
-
-         return progress.status in [
-             AsyncTransferStatus.SUCCESS,
-             AsyncTransferStatus.ERROR,
-             AsyncTransferStatus.INTERRUPTED,
-             AsyncTransferStatus.CANCELLED,
-             AsyncTransferStatus.DOES_NOT_EXIST,
-         ]
-
-     def wait_for_completion(
-         self, operation_id: str, timeout: Optional[float] = None
-     ) -> bool:
-         start_time = datetime.now()
-
-         while not self.is_transfer_complete(operation_id):
-             if timeout and (datetime.now() - start_time).total_seconds() > timeout:
-                 return False
-             time.sleep(0.1)  # Small delay to avoid busy waiting
-
-         return True
-
-     def cancel_transfer(self, operation_id: str) -> bool:
-         progress = self.get_transfer_status(operation_id)
-         if not progress:
-             return False
-
-         if progress.status == AsyncTransferStatus.IN_PROGRESS:
-             progress.update_status(AsyncTransferStatus.CANCELLED)
-             logger.info(f"Marked transfer operation as cancelled: {operation_id}")
-             return True
-
-         return False
-
-     def cleanup_completed_transfers(self, max_age_hours: int = 24) -> int:
-         cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
-         cleaned_count = 0
-
-         with self._lock:
-             to_remove = []
-             for operation_id, progress in self._transfers.items():
-                 if (
-                     progress.completed_at
-                     and progress.completed_at < cutoff_time
-                     and self.is_transfer_complete(operation_id)
-                 ):
-                     to_remove.append(operation_id)
-
-             for operation_id in to_remove:
-                 del self._transfers[operation_id]
-                 cleaned_count += 1
-
-         if cleaned_count > 0:
-             logger.info(f"Cleaned up {cleaned_count} completed transfer records")
-
-         return cleaned_count
-
-     def list_active_transfers(self) -> Dict[str, TransferProgress]:
-         with self._lock:
-             return {
-                 op_id: progress
-                 for op_id, progress in self._transfers.items()
-                 if not self.is_transfer_complete(op_id)
-             }
-
-     def shutdown(self) -> None:
-         logger.info("Shutting down async transfer manager...")
-         self._executor.shutdown(wait=True)
-
-
- # Global instance for easy access
- _transfer_manager = AsyncTransferManager()
-
-
- # Generic Public API functions
- def start_transfer_async(
-     source: Path,
-     dest: Path,
-     callback: Callable,
-     operation_name: str,
-     progress_callback: Optional[Callable[[str], None]] = None,
-     monitor_local: bool = True,
-     monitor_b10fs: bool = True,
-     **callback_kwargs,
- ) -> str:
-     return _transfer_manager.start_transfer_async(
-         source=source,
-         dest=dest,
-         callback=callback,
-         operation_name=operation_name,
-         progress_callback=progress_callback,
-         monitor_local=monitor_local,
-         monitor_b10fs=monitor_b10fs,
-         **callback_kwargs,
-     )
-
-
- def get_transfer_status(operation_id: str) -> Optional[TransferProgress]:
-     return _transfer_manager.get_transfer_status(operation_id)
-
-
- def is_transfer_complete(operation_id: str) -> bool:
-     return _transfer_manager.is_transfer_complete(operation_id)
-
-
- def wait_for_completion(operation_id: str, timeout: Optional[float] = None) -> bool:
-     return _transfer_manager.wait_for_completion(operation_id, timeout)
-
-
- def cancel_transfer(operation_id: str) -> bool:
-     return _transfer_manager.cancel_transfer(operation_id)
-
-
- def list_active_transfers() -> Dict[str, TransferProgress]:
-     return _transfer_manager.list_active_transfers()