b10-transfer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ """PyTorch compilation cache management using the generic transfer system.
2
+
3
+ This module provides torch-specific cache operations (save/load) that use the
4
+ generic transfer infrastructure from core.py. It handles the torch-specific
5
+ logic like compression, extraction, and file naming while delegating the
6
+ robust transfer operations to the core transfer function.
7
+ """
8
+
9
+ import os
10
+ import logging
11
+ import tempfile
12
+ import shutil
13
+ from pathlib import Path
14
+
15
+ from .core import transfer
16
+ from .environment import get_cache_filename
17
+ from .archive import create_archive, extract_archive
18
+ from .utils import (
19
+ timed_fn,
20
+ critical_section_b10fs_file_lock,
21
+ safe_execute,
22
+ temp_file_cleanup,
23
+ safe_unlink,
24
+ )
25
+ from .space_monitor import worker_process
26
+ from .constants import (
27
+ TORCH_CACHE_DIR,
28
+ B10FS_CACHE_DIR,
29
+ LOCAL_WORK_DIR,
30
+ MAX_CACHE_SIZE_MB,
31
+ CACHE_FILE_EXTENSION,
32
+ CACHE_LATEST_SUFFIX,
33
+ CACHE_INCOMPLETE_SUFFIX,
34
+ LoadStatus,
35
+ SaveStatus,
36
+ TransferStatus,
37
+ )
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
def torch_cache_save_callback(
    source_dir: Path, dest_file: Path, max_size_mb: int
) -> None:
    """Callback function for saving torch cache: compress then copy to b10fs.

    This function handles the torch-specific save logic:
    1. Compress the torch cache directory to a temporary archive
    2. Copy the archive to b10fs using atomic operations (temp file + rename)

    Args:
        source_dir: Path to the torch cache directory to compress.
        dest_file: Path to the final cache file in b10fs.
        max_size_mb: Maximum allowed archive size in megabytes.

    Raises:
        Exception: Re-raises any compression/copy failure after best-effort
            cleanup of incomplete files left behind in b10fs.
    """
    work_dir = Path(LOCAL_WORK_DIR)

    # Create temporary archive in local work directory; delete=False because
    # we only use the context manager to reserve a unique filename.
    with tempfile.NamedTemporaryFile(
        suffix=CACHE_FILE_EXTENSION, dir=work_dir, delete=False
    ) as f:
        temp_archive = Path(f.name)

    logger.debug(f"Created temporary archive: {temp_archive}")

    try:
        with temp_file_cleanup(temp_archive):
            # Step 1: Compress torch cache to temporary archive
            logger.info(f"Compressing torch cache: {source_dir} -> {temp_archive}")
            create_archive(source_dir, temp_archive, max_size_mb)

            # Step 2: Atomic copy to b10fs (temp file + rename)
            b10fs_dir = dest_file.parent
            b10fs_dir.mkdir(parents=True, exist_ok=True)

            # Use incomplete suffix so readers never see a half-written file;
            # the rename below makes the file visible atomically.
            cache_filename = get_cache_filename()
            temp_dest = (
                b10fs_dir
                / f"{cache_filename}{CACHE_INCOMPLETE_SUFFIX}{CACHE_FILE_EXTENSION}"
            )

            logger.info(f"Copying to b10fs: {temp_archive} -> {temp_dest}")

            @critical_section_b10fs_file_lock("copy_in")
            def _atomic_copy_to_b10fs():
                shutil.copy2(temp_archive, temp_dest)
                # Atomic rename to final destination
                logger.info(f"Atomic rename: {temp_dest} -> {dest_file}")
                temp_dest.rename(dest_file)

            _atomic_copy_to_b10fs()

    except Exception:
        # Best-effort cleanup of any partial *incomplete* files in b10fs.
        # Guard on existence: the failure may have happened before the
        # parent directory was ever created (e.g. during compression).
        if dest_file.parent.exists():
            for temp_file in dest_file.parent.glob(f"*{CACHE_INCOMPLETE_SUFFIX}*"):
                safe_unlink(temp_file, f"Failed to cleanup incomplete file {temp_file}")
        raise
100
+
101
+
102
def torch_cache_load_callback(source_file: Path, dest_dir: Path) -> None:
    """Callback function for loading torch cache: copy from b10fs then extract.

    This function handles the torch-specific load logic:
    1. Copy the cache file from b10fs to a temporary local file
    2. Extract the archive to the torch cache directory

    Args:
        source_file: Path to the cache file in b10fs.
        dest_dir: Path to the torch cache directory where files will be extracted.

    Raises:
        Exception: Re-raises any copy/extraction failure after removing the
            partially-extracted torch directory.
    """
    work_dir = Path(LOCAL_WORK_DIR)

    # Create temporary file for local copy; delete=False because the context
    # manager is only used to reserve a unique filename.
    with tempfile.NamedTemporaryFile(
        suffix=CACHE_FILE_EXTENSION, dir=work_dir, delete=False
    ) as f:
        temp_archive = Path(f.name)

    logger.debug(f"Created temporary file for cache copy: {temp_archive}")

    try:
        with temp_file_cleanup(temp_archive):
            # Step 1: Copy from b10fs to local temp file under the b10fs lock
            @critical_section_b10fs_file_lock("copy_out")
            def _copy_from_b10fs():
                logger.info(f"Copying from b10fs: {source_file} -> {temp_archive}")
                shutil.copy2(source_file, temp_archive)

            _copy_from_b10fs()

            # Step 2: Extract archive to torch cache directory
            logger.info(f"Extracting archive: {temp_archive} -> {dest_dir}")
            extract_archive(temp_archive, dest_dir)

    except Exception:
        # Remove the partially-extracted torch directory; the shared helper
        # keeps this consistent with the interrupted-extraction cleanup path.
        _cleanup_torch_dir(dest_dir)
        raise
148
+
149
+
150
@timed_fn(logger=logger, name="Loading compile cache")
@safe_execute("Load failed", LoadStatus.ERROR)
def load_compile_cache() -> LoadStatus:
    """Load PyTorch compilation cache from b10fs into the local torch cache.

    Looks up the latest cache archive in b10fs and, if present and not
    already loaded locally, hands the copy-and-extract work to the generic
    ``transfer`` pipeline with the torch-specific load callback.

    Returns:
        LoadStatus:
            LoadStatus.SUCCESS if the cache was loaded.
            LoadStatus.SKIPPED if a non-empty local cache already exists.
            LoadStatus.DOES_NOT_EXIST if no cache file was found in b10fs.
            LoadStatus.ERROR for any failure (b10fs unavailable, insufficient
            local disk space, or a failed transfer).

    Raises:
        CacheValidationError: If b10fs is not enabled (caught by decorator,
            returns LoadStatus.ERROR).
        CacheOperationInterrupted: If interrupted by insufficient local disk
            space (caught by decorator, returns LoadStatus.ERROR).
        Exception: Any other load error (caught by decorator, returns
            LoadStatus.ERROR).
    """
    b10fs_root = Path(B10FS_CACHE_DIR)
    local_torch_dir = Path(TORCH_CACHE_DIR)

    filename = f"{get_cache_filename()}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
    cache_file = b10fs_root / filename
    logger.debug(f"Looking for cache file: {cache_file}")

    # Guard: nothing to load.
    if not cache_file.exists():
        logger.info("No cache file found in b10fs")
        return LoadStatus.DOES_NOT_EXIST

    # Guard: local cache already populated, skip the transfer entirely.
    if local_torch_dir.exists() and any(local_torch_dir.iterdir()):
        logger.info("Torch cache already loaded, skipping extraction")
        return LoadStatus.SKIPPED

    # Delegate the robust copy/extract to the generic transfer system.
    status = transfer(
        source=cache_file,
        dest=local_torch_dir,
        callback=torch_cache_load_callback,
        monitor_local=True,
        monitor_b10fs=False,  # read-only against b10fs, no need to monitor it
    )

    # Map TransferStatus onto LoadStatus.
    if status != TransferStatus.SUCCESS:
        logger.error(f"Cache load failed with status: {status}")
        return LoadStatus.ERROR

    logger.info("Cache load complete")
    return LoadStatus.SUCCESS
207
+
208
+
209
@timed_fn(logger=logger, name="Saving compile cache")
@safe_execute("Save failed", SaveStatus.ERROR)
def save_compile_cache() -> SaveStatus:
    """Save the local PyTorch compilation cache to b10fs atomically.

    Archives the local torch cache directory and publishes it to b10fs via
    the generic ``transfer`` pipeline with the torch-specific save callback.
    Exits early when there is nothing to save or when the cache is already
    present in b10fs.

    Returns:
        SaveStatus:
            SaveStatus.SUCCESS if the cache was saved.
            SaveStatus.SKIPPED if there is no local cache to save or the
            cache already exists in b10fs.
            SaveStatus.ERROR for any failure (b10fs unavailable, insufficient
            disk space, or a failed transfer).

    Raises:
        CacheValidationError: If b10fs is not enabled (caught by decorator,
            returns SaveStatus.ERROR).
        CacheOperationInterrupted: If interrupted by insufficient disk space
            (caught by decorator, returns SaveStatus.ERROR).
        ArchiveError: If archive creation fails (caught by decorator,
            returns SaveStatus.ERROR).
        Exception: Any other save error (caught by decorator, returns
            SaveStatus.ERROR).
    """
    b10fs_root = Path(B10FS_CACHE_DIR)
    local_torch_dir = Path(TORCH_CACHE_DIR)

    # Guard: nothing to save when the local cache is absent or empty.
    has_local_cache = local_torch_dir.exists() and any(local_torch_dir.iterdir())
    if not has_local_cache:
        logger.info("No torch cache to save")
        return SaveStatus.SKIPPED

    filename = f"{get_cache_filename()}{CACHE_LATEST_SUFFIX}{CACHE_FILE_EXTENSION}"
    final_file = b10fs_root / filename

    # Guard: cache already published, skip the expensive compress+copy.
    if final_file.exists():
        logger.info("Cache already exists in b10fs, skipping save")
        return SaveStatus.SKIPPED

    # Delegate compression and atomic publication to the transfer system.
    status = transfer(
        source=local_torch_dir,
        dest=final_file,
        callback=torch_cache_save_callback,
        max_size_mb=MAX_CACHE_SIZE_MB,
        monitor_local=True,
        monitor_b10fs=True,
    )

    # Map TransferStatus onto SaveStatus with early returns.
    if status == TransferStatus.SUCCESS:
        logger.info("Cache save complete")
        return SaveStatus.SUCCESS
    if status == TransferStatus.INTERRUPTED:
        logger.warning("Cache save interrupted due to insufficient disk space")
        return SaveStatus.ERROR
    logger.error(f"Cache save failed with status: {status}")
    return SaveStatus.ERROR
271
+
272
+
273
@safe_execute("Clear failed", False)
def clear_local_cache() -> bool:
    """Remove the local PyTorch compilation cache directory entirely.

    Useful for reclaiming disk space or forcing a full recompilation.
    A missing directory is treated as already cleared.

    Returns:
        bool: True when the cache is gone (removed or never existed);
        False when removal fails (decorator converts the exception).

    Raises:
        Exception: Filesystem errors during removal (caught by decorator,
            returns False).
    """
    cache_dir = Path(TORCH_CACHE_DIR)
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
    return True
292
+
293
+
294
+ # Worker functions for backward compatibility with existing monitored process system
295
+ # These are used if someone wants to use the old worker-based approach
296
+
297
+
298
@worker_process("Compression was cancelled before starting")
def _cache_compression_worker(
    torch_dir_str: str, local_temp_str: str, max_size_mb: int
) -> None:
    """Compress the torch cache directory in a separate process.

    Runs as a killable worker so the monitor can terminate it if disk
    space runs out mid-compression.

    Args:
        torch_dir_str: String path of the torch cache directory to compress.
        local_temp_str: String path of the archive file to create.
        max_size_mb: Maximum allowed archive size in megabytes.
    """
    # Paths cross the process boundary as strings; convert at the edge.
    create_archive(Path(torch_dir_str), Path(local_temp_str), max_size_mb)
316
+
317
+
318
@worker_process("Copy was cancelled before starting")
def _cache_copy_worker(source_path_str: str, dest_path_str: str) -> None:
    """Copy the compressed cache file to b10fs in a separate process.

    Runs as a killable worker so the monitor can terminate it if disk
    space becomes insufficient during the copy.

    Args:
        source_path_str: String path of the file to copy.
        dest_path_str: String path of the copy destination.
    """
    # Paths cross the process boundary as strings; convert at the edge.
    shutil.copy2(Path(source_path_str), Path(dest_path_str))
333
+
334
+
335
@worker_process("Copy from b10fs was cancelled before starting")
def _cache_copy_from_b10fs_worker(source_path_str: str, dest_path_str: str) -> None:
    """Copy the cache file from b10fs to the local machine in a separate process.

    Runs as a killable worker so the monitor can terminate it if local
    disk space becomes insufficient during the copy.

    Args:
        source_path_str: String path of the source file in b10fs.
        dest_path_str: String path of the local copy destination.
    """
    # Paths cross the process boundary as strings; convert at the edge.
    shutil.copy2(Path(source_path_str), Path(dest_path_str))
350
+
351
+
352
@worker_process("Extraction was cancelled before starting")
def _cache_extract_worker(archive_path_str: str, dest_dir_str: str) -> None:
    """Extract the cache archive into the torch directory in a separate process.

    Runs as a killable worker so the monitor can terminate it if local
    disk space becomes insufficient during extraction.

    Args:
        archive_path_str: String path of the archive file to extract.
        dest_dir_str: String path of the extraction target directory.
    """
    # Paths cross the process boundary as strings; convert at the edge.
    extract_archive(Path(archive_path_str), Path(dest_dir_str))
367
+
368
+
369
def _cleanup_torch_dir(torch_dir: Path) -> None:
    """Best-effort removal of the torch directory after an interrupted extraction.

    Never raises: failures are logged and swallowed so cleanup cannot mask
    the original error being handled by the caller.
    """
    try:
        if torch_dir.exists():
            shutil.rmtree(torch_dir)
            logger.debug(f"Cleaned up torch directory: {torch_dir}")
    except Exception as e:
        logger.error(f"Failed to cleanup torch directory {torch_dir}: {e}")