b10_transfer-0.0.1-py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
- b10_transfer/__init__.py +51 -0
- b10_transfer/archive.py +175 -0
- b10_transfer/async_torch_cache.py +62 -0
- b10_transfer/async_transfers.py +275 -0
- b10_transfer/cleanup.py +179 -0
- b10_transfer/constants.py +149 -0
- b10_transfer/core.py +160 -0
- b10_transfer/environment.py +134 -0
- b10_transfer/info.py +172 -0
- b10_transfer/space_monitor.py +299 -0
- b10_transfer/torch_cache.py +376 -0
- b10_transfer/utils.py +355 -0
- b10_transfer-0.0.1.dist-info/METADATA +219 -0
- b10_transfer-0.0.1.dist-info/RECORD +15 -0
- b10_transfer-0.0.1.dist-info/WHEEL +4 -0
b10_transfer/__init__.py
ADDED
@@ -0,0 +1,51 @@
"""B10 Transfer - Lock-free PyTorch compilation cache for Baseten."""

from .core import transfer
from .torch_cache import load_compile_cache, save_compile_cache, clear_local_cache
from .async_transfers import (
    start_transfer_async,
    get_transfer_status,
    is_transfer_complete,
    wait_for_completion,
    cancel_transfer,
    list_active_transfers,
    TransferProgress,
)
from .async_torch_cache import (
    load_compile_cache_async,
    save_compile_cache_async,
)
from .utils import CacheError, CacheValidationError
from .space_monitor import CacheOperationInterrupted
from .info import get_cache_info, list_available_caches
from .constants import SaveStatus, LoadStatus, TransferStatus, AsyncTransferStatus

# Version
__version__ = "0.0.1"

__all__ = [
    "CacheError",
    "CacheValidationError",
    "CacheOperationInterrupted",
    "SaveStatus",
    "LoadStatus",
    "TransferStatus",
    "AsyncTransferStatus",
    "transfer",
    "load_compile_cache",
    "save_compile_cache",
    "clear_local_cache",
    "get_cache_info",
    "list_available_caches",
    # Generic async operations
    "start_transfer_async",
    "get_transfer_status",
    "is_transfer_complete",
    "wait_for_completion",
    "cancel_transfer",
    "list_active_transfers",
    "TransferProgress",
    # Torch-specific async operations
    "load_compile_cache_async",
    "save_compile_cache_async",
]
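Taken together, the exports above imply a simple synchronous usage pattern. This is a minimal sketch only: the signatures of load_compile_cache and save_compile_cache live in torch_cache.py, whose body is not shown in this section, so the no-argument calls and the toy Linear model here are assumptions.

import torch
import b10_transfer

# Assumed no-arg call: restore any shared compile cache from b10fs.
b10_transfer.load_compile_cache()

# Compile and run a toy module to populate the local torch cache.
model = torch.compile(torch.nn.Linear(8, 8))
model(torch.randn(1, 8))

# Assumed no-arg call: persist the populated cache back to b10fs.
b10_transfer.save_compile_cache()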
b10_transfer/archive.py
ADDED
@@ -0,0 +1,175 @@
import os
import logging
import subprocess
from pathlib import Path

from .utils import timed_fn, safe_unlink, CacheValidationError, validate_path_security
from .constants import MAX_CACHE_SIZE_MB

logger = logging.getLogger(__name__)


class ArchiveError(Exception):
    """Archive operation failed."""

    pass


def get_file_size_mb(file_path: Path) -> float:
    """Get the size of a file in megabytes.

    Args:
        file_path: Path to the file to measure.

    Returns:
        float: File size in megabytes, or 0.0 if the file doesn't exist or
            can't be accessed.

    Raises:
        No exceptions are raised; OSError is caught and returns 0.0.
    """
    try:
        return file_path.stat().st_size / (1024 * 1024)
    except OSError:
        return 0.0


def _compress_directory_to_tar(source_dir: Path, target_file: Path) -> None:
    """Compress directory contents to a gzipped tar archive using system tar.

    This function recursively compresses all files in the source directory
    into a gzipped tar archive using the system tar command for better performance.

    Args:
        source_dir: Path to the directory to compress.
        target_file: Path where the compressed archive will be created.

    Raises:
        OSError: If the tar command fails, the source directory can't be read,
            or the target file can't be written.
    """
    # Use system tar command for better performance
    # -czf: create, gzip, file
    # -C: change to directory before archiving
    cmd = ["tar", "-czf", str(target_file), "-C", str(source_dir), "."]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise OSError(f"tar compression failed: {e.stderr}") from e


@timed_fn(logger=logger, name="Creating archive")
def create_archive(
    source_dir: Path, target_file: Path, max_size_mb: int = MAX_CACHE_SIZE_MB
) -> None:
    """Create a compressed archive with path validation and size limits.

    This function safely creates a gzipped tar archive from a source directory
    with security validation and size constraints. It validates paths to prevent
    directory traversal attacks and enforces maximum archive size limits.

    Args:
        source_dir: Path to the directory to archive. Must exist and be within
            allowed directories (/tmp/ or its parent).
        target_file: Path where the archive will be created. Must be within
            allowed directories (/app or /cache).
        max_size_mb: Maximum allowed archive size in megabytes. Defaults to MAX_CACHE_SIZE_MB.

    Raises:
        CacheValidationError: If paths are outside allowed directories.
        ArchiveError: If the source directory doesn't exist, archive creation fails,
            or the archive exceeds the size limit.
    """
    # Validate paths
    validate_path_security(
        str(source_dir),
        ["/tmp/", str(source_dir.parent)],
        f"Source directory {source_dir}",
        CacheValidationError,
    )
    validate_path_security(
        str(target_file),
        ["/app", "/cache"],
        f"Target file {target_file}",
        CacheValidationError,
    )

    if not source_dir.exists():
        raise ArchiveError(f"Source directory missing: {source_dir}")

    target_file.parent.mkdir(parents=True, exist_ok=True)

    try:
        _compress_directory_to_tar(source_dir, target_file)
        size_mb = get_file_size_mb(target_file)

        if size_mb > max_size_mb:
            safe_unlink(
                target_file, f"Failed to delete oversized archive {target_file}"
            )
            raise ArchiveError(f"Archive too large: {size_mb:.1f}MB > {max_size_mb}MB")

    except ArchiveError:
        # Re-raise as-is so the size-limit error isn't wrapped a second time;
        # the oversized archive was already removed above.
        raise
    except Exception as e:
        safe_unlink(target_file, f"Failed to cleanup failed archive {target_file}")
        raise ArchiveError(f"Archive creation failed: {e}") from e


@timed_fn(logger=logger, name="Extracting archive")
def extract_archive(archive_file: Path, target_dir: Path) -> None:
    """Extract a compressed archive with security validation.

    This function safely extracts a gzipped tar archive to a target directory
    with security checks to prevent directory traversal attacks. It validates
    both the archive and target paths, and inspects archive contents for
    malicious paths before extraction.

    Args:
        archive_file: Path to the archive file to extract. Must exist and be
            within allowed directories (/app or /cache).
        target_dir: Path to the directory where files will be extracted. Must
            be within allowed directories (/tmp/ or its parent).

    Raises:
        CacheValidationError: If paths are outside allowed directories or if
            the archive contains unsafe paths (absolute paths or
            paths with '..' components).
        ArchiveError: If the archive file doesn't exist or extraction fails.
    """
    # Validate paths
    validate_path_security(
        str(archive_file),
        ["/app", "/cache"],
        f"Archive file {archive_file}",
        CacheValidationError,
    )
    validate_path_security(
        str(target_dir),
        ["/tmp/", str(target_dir.parent)],
        f"Target directory {target_dir}",
        CacheValidationError,
    )

    if not archive_file.exists():
        raise ArchiveError(f"Archive missing: {archive_file}")

    try:
        target_dir.mkdir(parents=True, exist_ok=True)

        # First, perform security check by listing archive contents
        list_cmd = ["tar", "-tzf", str(archive_file)]
        result = subprocess.run(list_cmd, check=True, capture_output=True, text=True)

        # Security check on all paths in the archive
        for path in result.stdout.strip().split("\n"):
            if path and (os.path.isabs(path) or ".." in path):
                raise CacheValidationError(f"Unsafe path in archive: {path}")

        # Extract using system tar command for better performance
        extract_cmd = ["tar", "-xzf", str(archive_file), "-C", str(target_dir)]
        subprocess.run(extract_cmd, check=True, capture_output=True, text=True)

    except subprocess.CalledProcessError as e:
        raise ArchiveError(f"tar extraction failed: {e.stderr}") from e
    except Exception as e:
        raise ArchiveError(f"Extraction failed: {e}") from e
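As a usage illustration for the two public helpers above (a sketch only; the archive file name and directory layout are hypothetical, but the path constraints come straight from the docstrings: sources and extraction targets under /tmp/, archives under /app or /cache):

from pathlib import Path
from b10_transfer.archive import create_archive, extract_archive

src = Path("/tmp/torchinductor_cache")   # hypothetical source directory
archive = Path("/cache/model.tar.gz")    # hypothetical archive path

create_archive(src, archive)             # raises ArchiveError on failure
extract_archive(archive, Path("/tmp/restore"))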
b10_transfer/async_torch_cache.py
ADDED
@@ -0,0 +1,62 @@
"""Torch-specific async cache operations using the generic async transfer system."""

from pathlib import Path
from typing import Optional, Callable

from .async_transfers import start_transfer_async
from .torch_cache import torch_cache_load_callback, torch_cache_save_callback
from .environment import get_cache_filename
from .constants import (
    TORCH_CACHE_DIR,
    B10FS_CACHE_DIR,
    MAX_CACHE_SIZE_MB,
    CACHE_FILE_EXTENSION,
    CACHE_LATEST_SUFFIX,
)


def load_compile_cache_async(
    progress_callback: Optional[Callable[[str], None]] = None,
) -> str:
    """Start async PyTorch compilation cache load operation."""
    b10fs_dir = Path(B10FS_CACHE_DIR)
    torch_dir = Path(TORCH_CACHE_DIR)

    cache_filename = get_cache_filename()
    cache_file = (
        b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
    )

    return start_transfer_async(
        source=cache_file,
        dest=torch_dir,
        callback=torch_cache_load_callback,
        operation_name="torch_cache_load",
        progress_callback=progress_callback,
        monitor_local=True,
        monitor_b10fs=False,  # No need to monitor b10fs for read operations
    )


def save_compile_cache_async(
    progress_callback: Optional[Callable[[str], None]] = None,
) -> str:
    """Start async PyTorch compilation cache save operation."""
    b10fs_dir = Path(B10FS_CACHE_DIR)
    torch_dir = Path(TORCH_CACHE_DIR)

    cache_filename = get_cache_filename()
    final_file = (
        b10fs_dir / f"{cache_filename}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
    )

    return start_transfer_async(
        source=torch_dir,
        dest=final_file,
        callback=torch_cache_save_callback,
        operation_name="torch_cache_save",
        progress_callback=progress_callback,
        monitor_local=True,
        monitor_b10fs=True,
        max_size_mb=MAX_CACHE_SIZE_MB,
    )
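These two wrappers return an operation ID that plugs into the generic helpers exported from async_transfers. A minimal sketch of the intended flow, using only names defined in this package:

import b10_transfer

def on_progress(operation_id: str) -> None:
    # Called whenever the transfer's status changes.
    progress = b10_transfer.get_transfer_status(operation_id)
    if progress:
        print(f"{operation_id}: {progress.status}")

op_id = b10_transfer.load_compile_cache_async(progress_callback=on_progress)

# ... do other startup work while the cache loads in the background ...

if not b10_transfer.wait_for_completion(op_id, timeout=300.0):
    b10_transfer.cancel_transfer(op_id)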
b10_transfer/async_transfers.py
ADDED
@@ -0,0 +1,275 @@
"""Generic async transfer operations with progress tracking."""

import logging
import threading
import time
from pathlib import Path
from typing import Optional, Dict, Any, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor

from .core import transfer
from .constants import AsyncTransferStatus, TransferStatus

logger = logging.getLogger(__name__)


@dataclass
class TransferProgress:
    operation_id: str
    operation_name: str
    status: AsyncTransferStatus
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    progress_callback: Optional[Callable[[str], None]] = None

    def update_status(
        self, status: AsyncTransferStatus, error_message: Optional[str] = None
    ):
        self.status = status
        if error_message:
            self.error_message = error_message
        if status == AsyncTransferStatus.IN_PROGRESS and self.started_at is None:
            self.started_at = datetime.now()
        elif status in [
            AsyncTransferStatus.SUCCESS,
            AsyncTransferStatus.ERROR,
            AsyncTransferStatus.INTERRUPTED,
            AsyncTransferStatus.CANCELLED,
        ]:
            self.completed_at = datetime.now()

        # Notify callback if provided
        if self.progress_callback:
            try:
                self.progress_callback(self.operation_id)
            except Exception as e:
                logger.warning(f"Progress callback failed for {self.operation_id}: {e}")


class AsyncTransferManager:
    def __init__(self):
        self._transfers: Dict[str, TransferProgress] = {}
        self._executor = ThreadPoolExecutor(
            max_workers=2, thread_name_prefix="b10-async-transfer"
        )
        self._lock = threading.Lock()
        self._operation_counter = 0

    @staticmethod
    def _is_terminal_status(status: AsyncTransferStatus) -> bool:
        """Return True if the status represents a finished transfer."""
        return status in (
            AsyncTransferStatus.SUCCESS,
            AsyncTransferStatus.ERROR,
            AsyncTransferStatus.INTERRUPTED,
            AsyncTransferStatus.CANCELLED,
        )

    def _generate_operation_id(self, operation_name: str) -> str:
        with self._lock:
            self._operation_counter += 1
            return f"{operation_name}_{self._operation_counter}_{int(datetime.now().timestamp())}"

    def start_transfer_async(
        self,
        source: Path,
        dest: Path,
        callback: Callable,
        operation_name: str,
        progress_callback: Optional[Callable[[str], None]] = None,
        monitor_local: bool = True,
        monitor_b10fs: bool = True,
        **callback_kwargs,
    ) -> str:
        operation_id = self._generate_operation_id(operation_name)

        progress = TransferProgress(
            operation_id=operation_id,
            operation_name=operation_name,
            status=AsyncTransferStatus.NOT_STARTED,
            progress_callback=progress_callback,
        )

        with self._lock:
            self._transfers[operation_id] = progress

        # Submit the transfer operation to the thread pool
        self._executor.submit(
            self._execute_transfer,
            operation_id,
            source,
            dest,
            callback,
            monitor_local,
            monitor_b10fs,
            callback_kwargs,
        )

        logger.info(f"Started async transfer operation: {operation_id}")
        return operation_id

    def _execute_transfer(
        self,
        operation_id: str,
        source: Path,
        dest: Path,
        callback: Callable,
        monitor_local: bool,
        monitor_b10fs: bool,
        callback_kwargs: Dict[str, Any],
    ) -> None:
        progress = self._transfers.get(operation_id)
        if not progress:
            logger.error(f"Progress tracking lost for operation {operation_id}")
            return

        try:
            progress.update_status(AsyncTransferStatus.IN_PROGRESS)
            logger.info(f"Starting transfer for operation {operation_id}")

            result = transfer(
                source=source,
                dest=dest,
                callback=callback,
                monitor_local=monitor_local,
                monitor_b10fs=monitor_b10fs,
                **callback_kwargs,
            )

            # Convert TransferStatus to AsyncTransferStatus
            if result == TransferStatus.SUCCESS:
                progress.update_status(AsyncTransferStatus.SUCCESS)
                logger.info(f"Transfer completed successfully: {operation_id}")
            elif result == TransferStatus.INTERRUPTED:
                progress.update_status(
                    AsyncTransferStatus.INTERRUPTED,
                    "Transfer interrupted due to insufficient disk space",
                )
                logger.warning(f"Transfer interrupted: {operation_id}")
            else:
                progress.update_status(
                    AsyncTransferStatus.ERROR, "Transfer operation failed"
                )
                logger.error(f"Transfer failed: {operation_id}")

        except Exception as e:
            progress.update_status(AsyncTransferStatus.ERROR, str(e))
            logger.error(
                f"Transfer operation {operation_id} failed with exception: {e}"
            )

    def get_transfer_status(self, operation_id: str) -> Optional[TransferProgress]:
        with self._lock:
            return self._transfers.get(operation_id)

    def is_transfer_complete(self, operation_id: str) -> bool:
        progress = self.get_transfer_status(operation_id)
        if not progress:
            return False

        return self._is_terminal_status(progress.status)

    def wait_for_completion(
        self, operation_id: str, timeout: Optional[float] = None
    ) -> bool:
        start_time = datetime.now()

        while not self.is_transfer_complete(operation_id):
            if timeout and (datetime.now() - start_time).total_seconds() > timeout:
                return False
            time.sleep(0.1)  # Small delay to avoid busy waiting

        return True

    def cancel_transfer(self, operation_id: str) -> bool:
        progress = self.get_transfer_status(operation_id)
        if not progress:
            return False

        if progress.status == AsyncTransferStatus.IN_PROGRESS:
            progress.update_status(AsyncTransferStatus.CANCELLED)
            logger.info(f"Marked transfer operation as cancelled: {operation_id}")
            return True

        return False

    def cleanup_completed_transfers(self, max_age_hours: int = 24) -> int:
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        cleaned_count = 0

        with self._lock:
            to_remove = []
            for operation_id, progress in self._transfers.items():
                # Check the status directly: calling is_transfer_complete()
                # here would try to re-acquire the non-reentrant lock and
                # deadlock.
                if (
                    progress.completed_at
                    and progress.completed_at < cutoff_time
                    and self._is_terminal_status(progress.status)
                ):
                    to_remove.append(operation_id)

            for operation_id in to_remove:
                del self._transfers[operation_id]
                cleaned_count += 1

        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} completed transfer records")

        return cleaned_count

    def list_active_transfers(self) -> Dict[str, TransferProgress]:
        with self._lock:
            # As above, check statuses directly instead of re-acquiring the
            # lock via is_transfer_complete().
            return {
                op_id: progress
                for op_id, progress in self._transfers.items()
                if not self._is_terminal_status(progress.status)
            }

    def shutdown(self) -> None:
        logger.info("Shutting down async transfer manager...")
        self._executor.shutdown(wait=True)


# Global instance for easy access
_transfer_manager = AsyncTransferManager()


# Generic Public API functions
def start_transfer_async(
    source: Path,
    dest: Path,
    callback: Callable,
    operation_name: str,
    progress_callback: Optional[Callable[[str], None]] = None,
    monitor_local: bool = True,
    monitor_b10fs: bool = True,
    **callback_kwargs,
) -> str:
    return _transfer_manager.start_transfer_async(
        source=source,
        dest=dest,
        callback=callback,
        operation_name=operation_name,
        progress_callback=progress_callback,
        monitor_local=monitor_local,
        monitor_b10fs=monitor_b10fs,
        **callback_kwargs,
    )


def get_transfer_status(operation_id: str) -> Optional[TransferProgress]:
    return _transfer_manager.get_transfer_status(operation_id)


def is_transfer_complete(operation_id: str) -> bool:
    return _transfer_manager.is_transfer_complete(operation_id)


def wait_for_completion(operation_id: str, timeout: Optional[float] = None) -> bool:
    return _transfer_manager.wait_for_completion(operation_id, timeout)


def cancel_transfer(operation_id: str) -> bool:
    return _transfer_manager.cancel_transfer(operation_id)


def list_active_transfers() -> Dict[str, TransferProgress]:
    return _transfer_manager.list_active_transfers()
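To round out the module: active transfers can be inspected through the exported helpers, while record cleanup is only reachable through the module-level manager, since no wrapper function is exported for it. A sketch:

from b10_transfer.async_transfers import _transfer_manager, list_active_transfers

# Inspect whatever is still running.
for op_id, progress in list_active_transfers().items():
    print(op_id, progress.status, progress.started_at)

# Prune completed-transfer records older than a day
# (reaches into the private manager; no public wrapper exists).
_transfer_manager.cleanup_completed_transfers(max_age_hours=24)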