b10-transfer 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- b10_transfer-0.0.1/PKG-INFO +219 -0
- b10_transfer-0.0.1/README.md +189 -0
- b10_transfer-0.0.1/pyproject.toml +35 -0
- b10_transfer-0.0.1/src/b10_transfer/__init__.py +51 -0
- b10_transfer-0.0.1/src/b10_transfer/archive.py +175 -0
- b10_transfer-0.0.1/src/b10_transfer/async_torch_cache.py +62 -0
- b10_transfer-0.0.1/src/b10_transfer/async_transfers.py +275 -0
- b10_transfer-0.0.1/src/b10_transfer/cleanup.py +179 -0
- b10_transfer-0.0.1/src/b10_transfer/constants.py +149 -0
- b10_transfer-0.0.1/src/b10_transfer/core.py +160 -0
- b10_transfer-0.0.1/src/b10_transfer/environment.py +134 -0
- b10_transfer-0.0.1/src/b10_transfer/info.py +172 -0
- b10_transfer-0.0.1/src/b10_transfer/space_monitor.py +299 -0
- b10_transfer-0.0.1/src/b10_transfer/torch_cache.py +376 -0
- b10_transfer-0.0.1/src/b10_transfer/utils.py +355 -0
@@ -0,0 +1,219 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: b10-transfer
|
3
|
+
Version: 0.0.1
|
4
|
+
Summary: Distributed PyTorch compilation cache for Baseten - Environment-aware, lock-free compilation cache management
|
5
|
+
License: MIT
|
6
|
+
Keywords: pytorch,torch.compile,cache,machine-learning,inference
|
7
|
+
Author: Shounak Ray
|
8
|
+
Author-email: shounak.noreply@baseten.co
|
9
|
+
Maintainer: Fred Liu
|
10
|
+
Maintainer-email: fred.liu.noreply@baseten.co
|
11
|
+
Requires-Python: >=3.9,<4.0
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
13
|
+
Classifier: Intended Audience :: Developers
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
16
|
+
Classifier: Programming Language :: Python :: 3.9
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
20
|
+
Classifier: Programming Language :: Python :: 3.13
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
22
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
+
Requires-Dist: torch (>=2.0.0)
|
24
|
+
Requires-Dist: triton (>=2.0.0)
|
25
|
+
Project-URL: Documentation, https://docs.baseten.co/development/model/b10-transfer
|
26
|
+
Project-URL: Homepage, https://docs.baseten.co/development/model/b10-transfer
|
27
|
+
Project-URL: Repository, https://pypi.org/project/b10-transfer/
|
28
|
+
Description-Content-Type: text/markdown
|
29
|
+
|
30
|
+
https://www.notion.so/ml-infra/mega-base-cache-24291d247273805b8e20fe26677b7b0f
|
31
|
+
|
32
|
+
# B10 Transfer
|
33
|
+
|
34
|
+
PyTorch compilation cache for Baseten deployments.
|
35
|
+
|
36
|
+
## Usage
|
37
|
+
|
38
|
+
### Synchronous Operations (Blocking)
|
39
|
+
|
40
|
+
```python
|
41
|
+
import b10_transfer
|
42
|
+
|
43
|
+
# Inside model.load() function
|
44
|
+
def load():
|
45
|
+
# Load cache before torch.compile()
|
46
|
+
status = b10_transfer.load_compile_cache()
|
47
|
+
|
48
|
+
# ...
|
49
|
+
|
50
|
+
# Your model compilation
|
51
|
+
model = torch.compile(model)
|
52
|
+
# Warm up the model with dummy prompts, and arguments that would typically be used in your requests (e.g. resolutions)
|
53
|
+
dummy_input = "What is the capital of France?"
|
54
|
+
model(dummy_input)
|
55
|
+
|
56
|
+
# ...
|
57
|
+
|
58
|
+
# Save cache after compilation
|
59
|
+
if status != b10_transfer.LoadStatus.SUCCESS:
|
60
|
+
b10_transfer.save_compile_cache()
|
61
|
+
```
|
62
|
+
|
63
|
+
### Asynchronous Operations (Non-blocking)
|
64
|
+
|
65
|
+
```python
|
66
|
+
import b10_transfer
|
67
|
+
|
68
|
+
def load_with_async_cache():
|
69
|
+
# Start async cache load (returns immediately with operation ID)
|
70
|
+
operation_id = b10_transfer.load_compile_cache_async()
|
71
|
+
|
72
|
+
# Check status periodically
|
73
|
+
while not b10_transfer.is_transfer_complete(operation_id):
|
74
|
+
status = b10_transfer.get_transfer_status(operation_id)
|
75
|
+
print(f"Cache load status: {status.status}")
|
76
|
+
time.sleep(1)
|
77
|
+
|
78
|
+
# Get final status
|
79
|
+
final_status = b10_transfer.get_transfer_status(operation_id)
|
80
|
+
if final_status.status == b10_transfer.AsyncTransferStatus.SUCCESS:
|
81
|
+
print("Cache loaded successfully!")
|
82
|
+
|
83
|
+
# Your model compilation...
|
84
|
+
model = torch.compile(model)
|
85
|
+
|
86
|
+
# Async save
|
87
|
+
save_op_id = b10_transfer.save_compile_cache_async()
|
88
|
+
|
89
|
+
# You can continue with other work while save happens in background
|
90
|
+
# Or wait for completion if needed
|
91
|
+
b10_transfer.wait_for_completion(save_op_id, timeout=300) # 5 minute timeout
|
92
|
+
|
93
|
+
# With progress callback
|
94
|
+
def on_progress(operation_id: str):
|
95
|
+
status = b10_transfer.get_transfer_status(operation_id)
|
96
|
+
print(f"Transfer {operation_id}: {status.status}")
|
97
|
+
|
98
|
+
operation_id = b10_transfer.load_compile_cache_async(progress_callback=on_progress)
|
99
|
+
```
|
100
|
+
|
101
|
+
### Generic Async Operations
|
102
|
+
|
103
|
+
You can also use the generic async system for custom transfer operations:
|
104
|
+
|
105
|
+
```python
|
106
|
+
import b10_transfer
|
107
|
+
from pathlib import Path
|
108
|
+
|
109
|
+
def my_custom_callback(source: Path, dest: Path):
|
110
|
+
# Your custom transfer logic here
|
111
|
+
# This could be any file operation, compression, etc.
|
112
|
+
shutil.copy2(source, dest)
|
113
|
+
|
114
|
+
# Start a generic async transfer
|
115
|
+
operation_id = b10_transfer.start_transfer_async(
|
116
|
+
source=Path("/source/file.txt"),
|
117
|
+
dest=Path("/dest/file.txt"),
|
118
|
+
callback=my_custom_callback,
|
119
|
+
operation_name="custom_file_copy",
|
120
|
+
monitor_local=True,
|
121
|
+
monitor_b10fs=False
|
122
|
+
)
|
123
|
+
|
124
|
+
# Use the same progress tracking as torch cache operations
|
125
|
+
b10_transfer.wait_for_completion(operation_id)
|
126
|
+
```
|
127
|
+
|
128
|
+
## Configuration
|
129
|
+
|
130
|
+
Configure via environment variables:
|
131
|
+
|
132
|
+
```bash
|
133
|
+
# Cache directories
|
134
|
+
export TORCH_CACHE_DIR="/tmp/torchinductor_root" # Default
|
135
|
+
export B10FS_CACHE_DIR="/cache/model/compile_cache" # Default
|
136
|
+
export LOCAL_WORK_DIR="/app" # Default
|
137
|
+
|
138
|
+
# Cache limits
|
139
|
+
export MAX_CACHE_SIZE_MB="1024" # 1GB default
|
140
|
+
```
|
141
|
+
|
142
|
+
## How It Works
|
143
|
+
|
144
|
+
### Environment-Specific Caching
|
145
|
+
|
146
|
+
The library automatically creates unique cache keys based on your environment:
|
147
|
+
|
148
|
+
```
|
149
|
+
torch-2.1.0_cuda-12.1_cc-8.6_triton-2.1.0 → cache_a1b2c3d4e5f6.latest.tar.gz
|
150
|
+
torch-2.0.1_cuda-11.8_cc-7.5_triton-2.0.1 → cache_x9y8z7w6v5u4.latest.tar.gz
|
151
|
+
torch-2.1.0_cpu_triton-none → cache_m1n2o3p4q5r6.latest.tar.gz
|
152
|
+
```
|
153
|
+
|
154
|
+
**Components used:**
|
155
|
+
- **PyTorch version** (e.g., `torch-2.1.0`)
|
156
|
+
- **CUDA version** (e.g., `cuda-12.1` or `cpu`)
|
157
|
+
- **GPU compute capability** (e.g., `cc-8.6` for A100)
|
158
|
+
- **Triton version** (e.g., `triton-2.1.0` or `triton-none`)
|
159
|
+
|
160
|
+
### Cache Workflow
|
161
|
+
|
162
|
+
1. **Load Phase** (startup): Generate environment key, check for matching cache in B10FS, extract to local directory
|
163
|
+
2. **Save Phase** (after compilation): Create archive, atomic copy to B10FS with environment-specific filename
|
164
|
+
|
165
|
+
### Lock-Free Race Prevention
|
166
|
+
|
167
|
+
Uses journal pattern with atomic filesystem operations for parallel-safe cache saves.
|
168
|
+
|
169
|
+
## API Reference
|
170
|
+
|
171
|
+
### Synchronous Functions
|
172
|
+
|
173
|
+
- `load_compile_cache() -> LoadStatus`: Load cache from B10FS for current environment
|
174
|
+
- `save_compile_cache() -> SaveStatus`: Save cache to B10FS with environment-specific filename
|
175
|
+
- `clear_local_cache() -> bool`: Clear local cache directory
|
176
|
+
- `get_cache_info() -> Dict[str, Any]`: Get cache status information for current environment
|
177
|
+
- `list_available_caches() -> Dict[str, Any]`: List all cache files with environment details
|
178
|
+
|
179
|
+
### Generic Asynchronous Functions
|
180
|
+
|
181
|
+
- `start_transfer_async(source, dest, callback, operation_name, **kwargs) -> str`: Start any async transfer operation
|
182
|
+
- `get_transfer_status(operation_id: str) -> TransferProgress`: Get current status of async operation
|
183
|
+
- `is_transfer_complete(operation_id: str) -> bool`: Check if async operation has completed
|
184
|
+
- `wait_for_completion(operation_id: str, timeout=None) -> bool`: Wait for async operation to complete
|
185
|
+
- `cancel_transfer(operation_id: str) -> bool`: Attempt to cancel running operation
|
186
|
+
- `list_active_transfers() -> Dict[str, TransferProgress]`: Get all active transfer operations
|
187
|
+
|
188
|
+
### Torch Cache Async Functions
|
189
|
+
|
190
|
+
- `load_compile_cache_async(progress_callback=None) -> str`: Start async cache load, returns operation ID
|
191
|
+
- `save_compile_cache_async(progress_callback=None) -> str`: Start async cache save, returns operation ID
|
192
|
+
|
193
|
+
### Status Enums
|
194
|
+
|
195
|
+
- `LoadStatus`: SUCCESS, ERROR, DOES_NOT_EXIST, SKIPPED
|
196
|
+
- `SaveStatus`: SUCCESS, ERROR, SKIPPED
|
197
|
+
- `AsyncTransferStatus`: NOT_STARTED, IN_PROGRESS, SUCCESS, ERROR, INTERRUPTED, CANCELLED
|
198
|
+
|
199
|
+
### Data Classes
|
200
|
+
|
201
|
+
- `TransferProgress`: Contains operation_id, status, started_at, completed_at, error_message
|
202
|
+
|
203
|
+
### Exceptions
|
204
|
+
|
205
|
+
- `CacheError`: Base exception for cache operations
|
206
|
+
- `CacheValidationError`: Path validation or compatibility check failed
|
207
|
+
- `CacheOperationInterrupted`: Operation interrupted due to insufficient disk space
|
208
|
+
|
209
|
+
## Performance Impact
|
210
|
+
|
211
|
+
### Debugging
|
212
|
+
|
213
|
+
Enable debug logging:
|
214
|
+
|
215
|
+
```python
|
216
|
+
import logging
|
217
|
+
logging.getLogger('b10_transfer').setLevel(logging.DEBUG)
|
218
|
+
```
|
219
|
+
|
@@ -0,0 +1,189 @@
|
|
1
|
+
https://www.notion.so/ml-infra/mega-base-cache-24291d247273805b8e20fe26677b7b0f
|
2
|
+
|
3
|
+
# B10 Transfer
|
4
|
+
|
5
|
+
PyTorch compilation cache for Baseten deployments.
|
6
|
+
|
7
|
+
## Usage
|
8
|
+
|
9
|
+
### Synchronous Operations (Blocking)
|
10
|
+
|
11
|
+
```python
|
12
|
+
import b10_transfer
|
13
|
+
|
14
|
+
# Inside model.load() function
|
15
|
+
def load():
|
16
|
+
# Load cache before torch.compile()
|
17
|
+
status = b10_transfer.load_compile_cache()
|
18
|
+
|
19
|
+
# ...
|
20
|
+
|
21
|
+
# Your model compilation
|
22
|
+
model = torch.compile(model)
|
23
|
+
# Warm up the model with dummy prompts, and arguments that would typically be used in your requests (e.g. resolutions)
|
24
|
+
dummy_input = "What is the capital of France?"
|
25
|
+
model(dummy_input)
|
26
|
+
|
27
|
+
# ...
|
28
|
+
|
29
|
+
# Save cache after compilation
|
30
|
+
if status != b10_transfer.LoadStatus.SUCCESS:
|
31
|
+
b10_transfer.save_compile_cache()
|
32
|
+
```
|
33
|
+
|
34
|
+
### Asynchronous Operations (Non-blocking)
|
35
|
+
|
36
|
+
```python
|
37
|
+
import b10_transfer
|
38
|
+
|
39
|
+
def load_with_async_cache():
|
40
|
+
# Start async cache load (returns immediately with operation ID)
|
41
|
+
operation_id = b10_transfer.load_compile_cache_async()
|
42
|
+
|
43
|
+
# Check status periodically
|
44
|
+
while not b10_transfer.is_transfer_complete(operation_id):
|
45
|
+
status = b10_transfer.get_transfer_status(operation_id)
|
46
|
+
print(f"Cache load status: {status.status}")
|
47
|
+
time.sleep(1)
|
48
|
+
|
49
|
+
# Get final status
|
50
|
+
final_status = b10_transfer.get_transfer_status(operation_id)
|
51
|
+
if final_status.status == b10_transfer.AsyncTransferStatus.SUCCESS:
|
52
|
+
print("Cache loaded successfully!")
|
53
|
+
|
54
|
+
# Your model compilation...
|
55
|
+
model = torch.compile(model)
|
56
|
+
|
57
|
+
# Async save
|
58
|
+
save_op_id = b10_transfer.save_compile_cache_async()
|
59
|
+
|
60
|
+
# You can continue with other work while save happens in background
|
61
|
+
# Or wait for completion if needed
|
62
|
+
b10_transfer.wait_for_completion(save_op_id, timeout=300) # 5 minute timeout
|
63
|
+
|
64
|
+
# With progress callback
|
65
|
+
def on_progress(operation_id: str):
|
66
|
+
status = b10_transfer.get_transfer_status(operation_id)
|
67
|
+
print(f"Transfer {operation_id}: {status.status}")
|
68
|
+
|
69
|
+
operation_id = b10_transfer.load_compile_cache_async(progress_callback=on_progress)
|
70
|
+
```
|
71
|
+
|
72
|
+
### Generic Async Operations
|
73
|
+
|
74
|
+
You can also use the generic async system for custom transfer operations:
|
75
|
+
|
76
|
+
```python
|
77
|
+
import b10_transfer
|
78
|
+
from pathlib import Path
|
79
|
+
|
80
|
+
def my_custom_callback(source: Path, dest: Path):
|
81
|
+
# Your custom transfer logic here
|
82
|
+
# This could be any file operation, compression, etc.
|
83
|
+
shutil.copy2(source, dest)
|
84
|
+
|
85
|
+
# Start a generic async transfer
|
86
|
+
operation_id = b10_transfer.start_transfer_async(
|
87
|
+
source=Path("/source/file.txt"),
|
88
|
+
dest=Path("/dest/file.txt"),
|
89
|
+
callback=my_custom_callback,
|
90
|
+
operation_name="custom_file_copy",
|
91
|
+
monitor_local=True,
|
92
|
+
monitor_b10fs=False
|
93
|
+
)
|
94
|
+
|
95
|
+
# Use the same progress tracking as torch cache operations
|
96
|
+
b10_transfer.wait_for_completion(operation_id)
|
97
|
+
```
|
98
|
+
|
99
|
+
## Configuration
|
100
|
+
|
101
|
+
Configure via environment variables:
|
102
|
+
|
103
|
+
```bash
|
104
|
+
# Cache directories
|
105
|
+
export TORCH_CACHE_DIR="/tmp/torchinductor_root" # Default
|
106
|
+
export B10FS_CACHE_DIR="/cache/model/compile_cache" # Default
|
107
|
+
export LOCAL_WORK_DIR="/app" # Default
|
108
|
+
|
109
|
+
# Cache limits
|
110
|
+
export MAX_CACHE_SIZE_MB="1024" # 1GB default
|
111
|
+
```
|
112
|
+
|
113
|
+
## How It Works
|
114
|
+
|
115
|
+
### Environment-Specific Caching
|
116
|
+
|
117
|
+
The library automatically creates unique cache keys based on your environment:
|
118
|
+
|
119
|
+
```
|
120
|
+
torch-2.1.0_cuda-12.1_cc-8.6_triton-2.1.0 → cache_a1b2c3d4e5f6.latest.tar.gz
|
121
|
+
torch-2.0.1_cuda-11.8_cc-7.5_triton-2.0.1 → cache_x9y8z7w6v5u4.latest.tar.gz
|
122
|
+
torch-2.1.0_cpu_triton-none → cache_m1n2o3p4q5r6.latest.tar.gz
|
123
|
+
```
|
124
|
+
|
125
|
+
**Components used:**
|
126
|
+
- **PyTorch version** (e.g., `torch-2.1.0`)
|
127
|
+
- **CUDA version** (e.g., `cuda-12.1` or `cpu`)
|
128
|
+
- **GPU compute capability** (e.g., `cc-8.6` for A100)
|
129
|
+
- **Triton version** (e.g., `triton-2.1.0` or `triton-none`)
|
130
|
+
|
131
|
+
### Cache Workflow
|
132
|
+
|
133
|
+
1. **Load Phase** (startup): Generate environment key, check for matching cache in B10FS, extract to local directory
|
134
|
+
2. **Save Phase** (after compilation): Create archive, atomic copy to B10FS with environment-specific filename
|
135
|
+
|
136
|
+
### Lock-Free Race Prevention
|
137
|
+
|
138
|
+
Uses journal pattern with atomic filesystem operations for parallel-safe cache saves.
|
139
|
+
|
140
|
+
## API Reference
|
141
|
+
|
142
|
+
### Synchronous Functions
|
143
|
+
|
144
|
+
- `load_compile_cache() -> LoadStatus`: Load cache from B10FS for current environment
|
145
|
+
- `save_compile_cache() -> SaveStatus`: Save cache to B10FS with environment-specific filename
|
146
|
+
- `clear_local_cache() -> bool`: Clear local cache directory
|
147
|
+
- `get_cache_info() -> Dict[str, Any]`: Get cache status information for current environment
|
148
|
+
- `list_available_caches() -> Dict[str, Any]`: List all cache files with environment details
|
149
|
+
|
150
|
+
### Generic Asynchronous Functions
|
151
|
+
|
152
|
+
- `start_transfer_async(source, dest, callback, operation_name, **kwargs) -> str`: Start any async transfer operation
|
153
|
+
- `get_transfer_status(operation_id: str) -> TransferProgress`: Get current status of async operation
|
154
|
+
- `is_transfer_complete(operation_id: str) -> bool`: Check if async operation has completed
|
155
|
+
- `wait_for_completion(operation_id: str, timeout=None) -> bool`: Wait for async operation to complete
|
156
|
+
- `cancel_transfer(operation_id: str) -> bool`: Attempt to cancel running operation
|
157
|
+
- `list_active_transfers() -> Dict[str, TransferProgress]`: Get all active transfer operations
|
158
|
+
|
159
|
+
### Torch Cache Async Functions
|
160
|
+
|
161
|
+
- `load_compile_cache_async(progress_callback=None) -> str`: Start async cache load, returns operation ID
|
162
|
+
- `save_compile_cache_async(progress_callback=None) -> str`: Start async cache save, returns operation ID
|
163
|
+
|
164
|
+
### Status Enums
|
165
|
+
|
166
|
+
- `LoadStatus`: SUCCESS, ERROR, DOES_NOT_EXIST, SKIPPED
|
167
|
+
- `SaveStatus`: SUCCESS, ERROR, SKIPPED
|
168
|
+
- `AsyncTransferStatus`: NOT_STARTED, IN_PROGRESS, SUCCESS, ERROR, INTERRUPTED, CANCELLED
|
169
|
+
|
170
|
+
### Data Classes
|
171
|
+
|
172
|
+
- `TransferProgress`: Contains operation_id, status, started_at, completed_at, error_message
|
173
|
+
|
174
|
+
### Exceptions
|
175
|
+
|
176
|
+
- `CacheError`: Base exception for cache operations
|
177
|
+
- `CacheValidationError`: Path validation or compatibility check failed
|
178
|
+
- `CacheOperationInterrupted`: Operation interrupted due to insufficient disk space
|
179
|
+
|
180
|
+
## Performance Impact
|
181
|
+
|
182
|
+
### Debugging
|
183
|
+
|
184
|
+
Enable debug logging:
|
185
|
+
|
186
|
+
```python
|
187
|
+
import logging
|
188
|
+
logging.getLogger('b10_transfer').setLevel(logging.DEBUG)
|
189
|
+
```
|
@@ -0,0 +1,35 @@
|
|
1
|
+
[build-system]
|
2
|
+
requires = ["poetry-core"]
|
3
|
+
build-backend = "poetry.core.masonry.api"
|
4
|
+
|
5
|
+
[tool.poetry]
|
6
|
+
name = "b10-transfer"
|
7
|
+
version = "0.0.1"
|
8
|
+
description = "Distributed PyTorch compilation cache for Baseten - Environment-aware, lock-free compilation cache management"
|
9
|
+
authors = ["Shounak Ray <shounak.noreply@baseten.co>", "Fred Liu <fred.liu.noreply@baseten.co>"]
|
10
|
+
maintainers = ["Fred Liu <fred.liu.noreply@baseten.co>", "Shounak Ray <shounak.noreply@baseten.co>"]
|
11
|
+
readme = "README.md"
|
12
|
+
homepage = "https://docs.baseten.co/development/model/b10-transfer"
|
13
|
+
documentation = "https://docs.baseten.co/development/model/b10-transfer"
|
14
|
+
repository = "https://pypi.org/project/b10-transfer/"
|
15
|
+
license = "MIT"
|
16
|
+
keywords = ["pytorch", "torch.compile", "cache", "machine-learning", "inference"]
|
17
|
+
classifiers = [
|
18
|
+
"Development Status :: 4 - Beta",
|
19
|
+
"Intended Audience :: Developers",
|
20
|
+
"License :: OSI Approved :: MIT License",
|
21
|
+
"Programming Language :: Python :: 3",
|
22
|
+
"Programming Language :: Python :: 3.9",
|
23
|
+
"Programming Language :: Python :: 3.10",
|
24
|
+
"Programming Language :: Python :: 3.11",
|
25
|
+
"Programming Language :: Python :: 3.12",
|
26
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
27
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
28
|
+
]
|
29
|
+
packages = [{include = "b10_transfer", from = "src"}]
|
30
|
+
|
31
|
+
[tool.poetry.dependencies]
|
32
|
+
python = "^3.9"
|
33
|
+
torch = ">=2.0.0"
|
34
|
+
triton = ">=2.0.0"
|
35
|
+
|
@@ -0,0 +1,51 @@
|
|
1
|
+
"""B10 Transfer - Lock-free PyTorch compilation cache for Baseten."""
|
2
|
+
|
3
|
+
from .core import transfer
|
4
|
+
from .torch_cache import load_compile_cache, save_compile_cache, clear_local_cache
|
5
|
+
from .async_transfers import (
|
6
|
+
start_transfer_async,
|
7
|
+
get_transfer_status,
|
8
|
+
is_transfer_complete,
|
9
|
+
wait_for_completion,
|
10
|
+
cancel_transfer,
|
11
|
+
list_active_transfers,
|
12
|
+
TransferProgress,
|
13
|
+
)
|
14
|
+
from .async_torch_cache import (
|
15
|
+
load_compile_cache_async,
|
16
|
+
save_compile_cache_async,
|
17
|
+
)
|
18
|
+
from .utils import CacheError, CacheValidationError
|
19
|
+
from .space_monitor import CacheOperationInterrupted
|
20
|
+
from .info import get_cache_info, list_available_caches
|
21
|
+
from .constants import SaveStatus, LoadStatus, TransferStatus, AsyncTransferStatus
|
22
|
+
|
23
|
+
# Version
|
24
|
+
__version__ = "0.0.1"
|
25
|
+
|
26
|
+
__all__ = [
|
27
|
+
"CacheError",
|
28
|
+
"CacheValidationError",
|
29
|
+
"CacheOperationInterrupted",
|
30
|
+
"SaveStatus",
|
31
|
+
"LoadStatus",
|
32
|
+
"TransferStatus",
|
33
|
+
"AsyncTransferStatus",
|
34
|
+
"transfer",
|
35
|
+
"load_compile_cache",
|
36
|
+
"save_compile_cache",
|
37
|
+
"clear_local_cache",
|
38
|
+
"get_cache_info",
|
39
|
+
"list_available_caches",
|
40
|
+
# Generic async operations
|
41
|
+
"start_transfer_async",
|
42
|
+
"get_transfer_status",
|
43
|
+
"is_transfer_complete",
|
44
|
+
"wait_for_completion",
|
45
|
+
"cancel_transfer",
|
46
|
+
"list_active_transfers",
|
47
|
+
"TransferProgress",
|
48
|
+
# Torch-specific async operations
|
49
|
+
"load_compile_cache_async",
|
50
|
+
"save_compile_cache_async",
|
51
|
+
]
|
@@ -0,0 +1,175 @@
|
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
import subprocess
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
from .utils import timed_fn, safe_unlink, CacheValidationError, validate_path_security
|
7
|
+
from .constants import MAX_CACHE_SIZE_MB
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
|
12
|
+
class ArchiveError(Exception):
    """Raised when an archive operation (creation or extraction) fails."""
|
16
|
+
|
17
|
+
|
18
|
+
def get_file_size_mb(file_path: Path) -> float:
    """Return the size of *file_path* in megabytes.

    Args:
        file_path: Path to the file to measure.

    Returns:
        float: File size in megabytes, or 0.0 if the file does not exist
        or cannot be stat'ed.

    Raises:
        Nothing; OSError from stat() is swallowed and reported as 0.0.
    """
    try:
        size_bytes = file_path.stat().st_size
    except OSError:
        return 0.0
    return size_bytes / (1024 * 1024)
|
35
|
+
|
36
|
+
|
37
|
+
def _compress_directory_to_tar(source_dir: Path, target_file: Path) -> None:
|
38
|
+
"""Compress directory contents to a gzipped tar archive using system tar.
|
39
|
+
|
40
|
+
This function recursively compresses all files in the source directory
|
41
|
+
into a gzipped tar archive using the system tar command for better performance.
|
42
|
+
|
43
|
+
Args:
|
44
|
+
source_dir: Path to the directory to compress.
|
45
|
+
target_file: Path where the compressed archive will be created.
|
46
|
+
|
47
|
+
Raises:
|
48
|
+
subprocess.CalledProcessError: If tar command fails.
|
49
|
+
OSError: If source directory can't be read or target file can't be written.
|
50
|
+
"""
|
51
|
+
# Use system tar command for better performance
|
52
|
+
# -czf: create, gzip, file
|
53
|
+
# -C: change to directory before archiving
|
54
|
+
cmd = ["tar", "-czf", str(target_file), "-C", str(source_dir), "."]
|
55
|
+
|
56
|
+
try:
|
57
|
+
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
58
|
+
except subprocess.CalledProcessError as e:
|
59
|
+
raise OSError(f"tar compression failed: {e.stderr}") from e
|
60
|
+
|
61
|
+
|
62
|
+
@timed_fn(logger=logger, name="Creating archive")
def create_archive(
    source_dir: Path, target_file: Path, max_size_mb: int = MAX_CACHE_SIZE_MB
) -> None:
    """Create a compressed archive with path validation and size limits.

    Safely creates a gzipped tar archive from a source directory with
    security validation and size constraints. Paths are validated to
    prevent directory traversal, and the resulting archive is deleted
    and rejected if it exceeds the size limit.

    Args:
        source_dir: Path to the directory to archive. Must exist and be within
            allowed directories (/tmp/ or its own parent).
        target_file: Path where the archive will be created. Must be within
            allowed directories (/app or /cache).
        max_size_mb: Maximum allowed archive size in megabytes. Defaults to
            MAX_CACHE_SIZE_MB.

    Raises:
        CacheValidationError: If paths are outside allowed directories.
        ArchiveError: If the source directory doesn't exist, archive creation
            fails, or the archive exceeds the size limit.
    """
    # Validate paths before touching the filesystem.
    validate_path_security(
        str(source_dir),
        ["/tmp/", str(source_dir.parent)],
        f"Source directory {source_dir}",
        CacheValidationError,
    )
    validate_path_security(
        str(target_file),
        ["/app", "/cache"],
        f"Target file {target_file}",
        CacheValidationError,
    )

    if not source_dir.exists():
        raise ArchiveError(f"Source directory missing: {source_dir}")

    target_file.parent.mkdir(parents=True, exist_ok=True)

    try:
        _compress_directory_to_tar(source_dir, target_file)
        size_mb = get_file_size_mb(target_file)

        if size_mb > max_size_mb:
            # Remove the oversized archive before rejecting it.
            safe_unlink(
                target_file, f"Failed to delete oversized archive {target_file}"
            )
            raise ArchiveError(f"Archive too large: {size_mb:.1f}MB > {max_size_mb}MB")

    except ArchiveError:
        # BUGFIX: previously the size-limit ArchiveError fell through to the
        # generic handler below, which unlinked the (already-deleted) file a
        # second time and double-wrapped the message as
        # "Archive creation failed: Archive too large: ...". Re-raise as-is.
        raise
    except Exception as e:
        safe_unlink(target_file, f"Failed to cleanup failed archive {target_file}")
        raise ArchiveError(f"Archive creation failed: {e}") from e
|
116
|
+
|
117
|
+
|
118
|
+
@timed_fn(logger=logger, name="Extracting archive")
def extract_archive(archive_file: Path, target_dir: Path) -> None:
    """Extract a compressed archive with security validation.

    Safely extracts a gzipped tar archive to a target directory with checks
    to prevent directory traversal attacks. Validates both the archive and
    target paths, and inspects archive contents for malicious paths before
    extraction.

    Args:
        archive_file: Path to the archive file to extract. Must exist and be
            within allowed directories (/app or /cache).
        target_dir: Path to the directory where files will be extracted. Must
            be within allowed directories (/tmp/ or its own parent).

    Raises:
        CacheValidationError: If paths are outside allowed directories or if
            the archive contains unsafe paths (absolute paths or paths with
            '..' components).
        ArchiveError: If the archive file doesn't exist or extraction fails.
    """
    # Validate paths before touching the filesystem.
    validate_path_security(
        str(archive_file),
        ["/app", "/cache"],
        f"Archive file {archive_file}",
        CacheValidationError,
    )
    validate_path_security(
        str(target_dir),
        ["/tmp/", str(target_dir.parent)],
        f"Target directory {target_dir}",
        CacheValidationError,
    )

    if not archive_file.exists():
        raise ArchiveError(f"Archive missing: {archive_file}")

    try:
        target_dir.mkdir(parents=True, exist_ok=True)

        # Security pre-pass: list the archive contents without extracting.
        list_cmd = ["tar", "-tzf", str(archive_file)]
        result = subprocess.run(list_cmd, check=True, capture_output=True, text=True)

        # Reject any absolute path or '..' component before extraction.
        for path in result.stdout.strip().split("\n"):
            if path and (os.path.isabs(path) or ".." in path):
                raise CacheValidationError(f"Unsafe path in archive: {path}")

        # Extract using system tar command for better performance.
        extract_cmd = ["tar", "-xzf", str(archive_file), "-C", str(target_dir)]
        subprocess.run(extract_cmd, check=True, capture_output=True, text=True)

    except CacheValidationError:
        # BUGFIX: previously the blanket `except Exception` below swallowed
        # this and re-raised it as ArchiveError("Extraction failed: ..."),
        # contradicting the documented contract that unsafe archive contents
        # surface as CacheValidationError. Re-raise unchanged.
        raise
    except subprocess.CalledProcessError as e:
        raise ArchiveError(f"tar extraction failed: {e.stderr}") from e
    except Exception as e:
        raise ArchiveError(f"Extraction failed: {e}") from e
|