proximl 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. examples/local_storage.py +0 -2
  2. proximl/__init__.py +1 -1
  3. proximl/checkpoints.py +56 -57
  4. proximl/cli/__init__.py +6 -3
  5. proximl/cli/checkpoint.py +18 -57
  6. proximl/cli/dataset.py +17 -57
  7. proximl/cli/job/__init__.py +89 -67
  8. proximl/cli/job/create.py +51 -24
  9. proximl/cli/model.py +14 -56
  10. proximl/cli/volume.py +18 -57
  11. proximl/datasets.py +50 -55
  12. proximl/jobs.py +269 -69
  13. proximl/models.py +51 -55
  14. proximl/proximl.py +159 -114
  15. proximl/utils/__init__.py +1 -0
  16. proximl/{auth.py → utils/auth.py} +4 -3
  17. proximl/utils/transfer.py +647 -0
  18. proximl/volumes.py +48 -53
  19. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/METADATA +3 -3
  20. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/RECORD +52 -50
  21. tests/integration/test_checkpoints_integration.py +4 -3
  22. tests/integration/test_datasets_integration.py +5 -3
  23. tests/integration/test_jobs_integration.py +33 -27
  24. tests/integration/test_models_integration.py +7 -3
  25. tests/integration/test_volumes_integration.py +2 -2
  26. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  27. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  28. tests/unit/cloudbender/test_providers_unit.py +96 -0
  29. tests/unit/cloudbender/test_regions_unit.py +106 -0
  30. tests/unit/cloudbender/test_services_unit.py +141 -0
  31. tests/unit/conftest.py +23 -10
  32. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  33. tests/unit/projects/test_project_datastores_unit.py +37 -0
  34. tests/unit/projects/test_project_members_unit.py +46 -0
  35. tests/unit/projects/test_project_services_unit.py +65 -0
  36. tests/unit/projects/test_projects_unit.py +16 -0
  37. tests/unit/test_auth_unit.py +17 -2
  38. tests/unit/test_checkpoints_unit.py +256 -71
  39. tests/unit/test_datasets_unit.py +218 -68
  40. tests/unit/test_exceptions.py +133 -0
  41. tests/unit/test_gpu_types_unit.py +11 -1
  42. tests/unit/test_jobs_unit.py +1014 -95
  43. tests/unit/test_main_unit.py +20 -0
  44. tests/unit/test_models_unit.py +218 -70
  45. tests/unit/test_proximl_unit.py +627 -3
  46. tests/unit/test_volumes_unit.py +211 -70
  47. tests/unit/utils/__init__.py +1 -0
  48. tests/unit/utils/test_transfer_unit.py +4260 -0
  49. proximl/cli/connection.py +0 -61
  50. proximl/connections.py +0 -621
  51. tests/unit/test_connections_unit.py +0 -182
  52. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/LICENSE +0 -0
  53. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/WHEEL +0 -0
  54. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/entry_points.txt +0 -0
  55. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/top_level.txt +0 -0
proximl/utils/transfer.py (new file)
@@ -0,0 +1,647 @@
+ import os
+ import re
+ import math
+ import asyncio
+ import aiohttp
+ import aiofiles
+ import hashlib
+ import logging
+ from aiohttp.client_exceptions import (
+     ClientResponseError,
+     ClientConnectorError,
+     ServerTimeoutError,
+     ServerDisconnectedError,
+     ClientOSError,
+     ClientPayloadError,
+     InvalidURL,
+ )
+ from proximl.exceptions import ConnectionError, ProxiMLException
+
+ MAX_RETRIES = 5
+ RETRY_BACKOFF = 2  # Exponential backoff base (2^attempt)
+ PARALLEL_UPLOADS = 10  # Max concurrent uploads
+ CHUNK_SIZE = 5 * 1024 * 1024  # 5MB
+ RETRY_STATUSES = {
+     502,
+     503,
+     504,
+ }  # Server errors to retry during upload/download
+ # Additional retries for DNS/connection errors (ClientConnectorError)
+ DNS_MAX_RETRIES = 7  # More retries for DNS resolution issues
+ DNS_INITIAL_DELAY = 1  # Initial delay in seconds before first DNS retry
+ # Ping warmup timeout: calculate retries so last retry is this many seconds after first try
+ PING_WARMUP_TIMEOUT = 8 * 60  # 8 minutes in seconds
+
+
+ def normalize_endpoint(endpoint):
+     """
+     Normalize endpoint URL to ensure it has a protocol.
+
+     Args:
+         endpoint: Endpoint URL (with or without protocol)
+
+     Returns:
+         Normalized endpoint URL with https:// protocol
+
+     Raises:
+         ValueError: If endpoint is empty
+     """
+     if not endpoint:
+         raise ValueError("Endpoint URL cannot be empty")
+
+     # Remove trailing slashes
+     endpoint = endpoint.rstrip("/")
+
+     # Add https:// if no protocol is specified
+     if not endpoint.startswith(("http://", "https://")):
+         endpoint = f"https://{endpoint}"
+
+     return endpoint
+
+
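A few illustrative inputs and outputs for normalize_endpoint, following the code above (the hostnames are placeholders, not values from this diff):

    normalize_endpoint("proximl.example.com/")  # -> "https://proximl.example.com"
    normalize_endpoint("http://10.0.0.5:8080")  # -> "http://10.0.0.5:8080" (existing protocol preserved)
    normalize_endpoint("")                      # raises ValueError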
+ def calculate_ping_retries(timeout_seconds, backoff_base):
+     """
+     Calculate the number of retries needed for ping warmup to reach timeout.
+
+     With exponential backoff, if we have n retries, the total wait time is:
+     backoff_base^1 + backoff_base^2 + ... + backoff_base^(n-1) = (backoff_base^n - backoff_base) / (backoff_base - 1)
+
+     For backoff_base = 2, this simplifies to: 2^n - 2
+
+     Args:
+         timeout_seconds: Total timeout in seconds
+         backoff_base: Exponential backoff base
+
+     Returns:
+         Number of retries needed to reach or exceed the timeout
+     """
+     if backoff_base == 2:
+         # Simplified calculation for base 2
+         # 2^n - 2 >= timeout_seconds
+         # 2^n >= timeout_seconds + 2
+         # n >= log2(timeout_seconds + 2)
+         n = math.ceil(math.log2(timeout_seconds + 2))
+     else:
+         # General case: (backoff_base^n - backoff_base) / (backoff_base - 1) >= timeout_seconds
+         # backoff_base^n >= timeout_seconds * (backoff_base - 1) + backoff_base
+         # n >= log_base(timeout_seconds * (backoff_base - 1) + backoff_base)
+         target = timeout_seconds * (backoff_base - 1) + backoff_base
+         n = math.ceil(math.log(target, backoff_base))
+
+     return max(1, int(n))
+
+
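As a quick check of the docstring's formula with the module defaults (PING_WARMUP_TIMEOUT = 480, RETRY_BACKOFF = 2), a worked example:

    import math
    n = math.ceil(math.log2(480 + 2))  # ceil(log2(482)) = 9 retries
    total_wait = 2**n - 2              # 2 + 4 + ... + 256 = 510 seconds >= 480

So calculate_ping_retries(480, 2) returns 9, the smallest retry count whose cumulative backoff (510 seconds) reaches the 8-minute warmup window.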
+ async def ping_endpoint(
+     endpoint, auth_token, max_retries=MAX_RETRIES, retry_backoff=RETRY_BACKOFF
+ ):
+     """
+     Ping the endpoint to ensure it's ready before upload/download operations.
+
+     Retries on all errors (404, 500, DNS errors, etc.) with exponential backoff
+     until a 200 response is received. This handles startup timing issues.
+
+     Creates a fresh TCPConnector for each attempt to force fresh DNS resolution
+     and avoid stale DNS cache issues.
+
+     For ping warmup, calculates retries dynamically to ensure the last retry
+     occurs PING_WARMUP_TIMEOUT seconds after the first try.
+
+     Args:
+         endpoint: Server endpoint URL
+         auth_token: Authentication token
+         max_retries: Maximum number of retry attempts (ignored for ping warmup)
+         retry_backoff: Exponential backoff base
+
+     Raises:
+         ConnectionError: If ping never returns 200 after max retries
+         ClientConnectorError: If DNS/connection errors persist after max retries
+     """
+     endpoint = normalize_endpoint(endpoint)
+     attempt = 1
+     # Calculate retries for ping warmup to reach PING_WARMUP_TIMEOUT
+     # Allow max_retries to override when explicitly provided (for testing)
+     if max_retries == MAX_RETRIES:
+         # Use default, calculate retries for ping warmup
+         ping_max_retries = calculate_ping_retries(
+             PING_WARMUP_TIMEOUT, retry_backoff
+         )
+         effective_max_retries = ping_max_retries
+     else:
+         # Use explicitly provided max_retries (for testing)
+         effective_max_retries = max_retries
+         ping_max_retries = max_retries  # For DNS error handling below
+
+     while attempt <= effective_max_retries:
+         # Create a fresh connector for each attempt to force DNS re-resolution
+         # This helps avoid stale DNS cache issues
+         connector = None
+         try:
+             connector = aiohttp.TCPConnector(limit=1, limit_per_host=1)
+             async with aiohttp.ClientSession(connector=connector) as session:
+                 async with session.get(
+                     f"{endpoint}/ping",
+                     headers={"Authorization": f"Bearer {auth_token}"},
+                     timeout=30,
+                 ) as response:
+                     if response.status == 200:
+                         logging.debug(
+                             f"Endpoint {endpoint} is ready (ping successful)"
+                         )
+                         return
+                     # For any non-200 status, retry
+                     text = await response.text()
+                     raise ClientResponseError(
+                         request_info=response.request_info,
+                         history=response.history,
+                         status=response.status,
+                         message=text,
+                     )
+         except ClientResponseError as e:
+             # Retry on any HTTP error status
+             if attempt < effective_max_retries:
+                 logging.debug(
+                     f"Ping attempt {attempt}/{effective_max_retries} failed with status {e.status}: {str(e)}"
+                 )
+                 await asyncio.sleep(retry_backoff**attempt)
+                 attempt += 1
+                 continue
+             raise ConnectionError(
+                 f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts. "
+                 f"Last error: HTTP {e.status} - {str(e)}"
+             )
+         except ClientConnectorError as e:
+             # DNS resolution errors need more retries and initial delay
+             # Use the higher of DNS_MAX_RETRIES or calculated ping retries
+             if effective_max_retries == ping_max_retries:
+                 effective_max_retries = max(ping_max_retries, DNS_MAX_RETRIES)
+
+             if attempt < effective_max_retries:
+                 # Use initial delay for first retry, then exponential backoff
+                 if attempt == 1:
+                     delay = DNS_INITIAL_DELAY
+                 else:
+                     delay = retry_backoff ** (attempt - 1)
+                 logging.debug(
+                     f"Ping attempt {attempt}/{effective_max_retries} failed due to DNS/connection error: {str(e)}"
+                 )
+                 await asyncio.sleep(delay)
+                 attempt += 1
+                 continue
+             raise ConnectionError(
+                 f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts due to DNS/connection error: {str(e)}"
+             )
+         except (
+             ServerDisconnectedError,
+             ClientOSError,
+             ServerTimeoutError,
+             ClientPayloadError,
+             asyncio.TimeoutError,
+         ) as e:
+             if attempt < effective_max_retries:
+                 logging.debug(
+                     f"Ping attempt {attempt}/{effective_max_retries} failed: {str(e)}"
+                 )
+                 await asyncio.sleep(retry_backoff**attempt)
+                 attempt += 1
+                 continue
+             raise ConnectionError(
+                 f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts: {str(e)}"
+             )
+         finally:
+             # Ensure connector is closed to free resources and clear DNS cache
+             # This forces fresh DNS resolution on the next attempt
+             if connector is not None:
+                 try:
+                     await connector.close()
+                 except Exception:
+                     # Ignore errors during cleanup
+                     pass
+
+
+ async def retry_request(
+     func, *args, max_retries=MAX_RETRIES, retry_backoff=RETRY_BACKOFF, **kwargs
+ ):
+     """
+     Shared retry logic for network requests.
+
+     For DNS/connection errors (ClientConnectorError), uses more retries and
+     an initial delay to handle transient DNS resolution issues.
+     """
+     attempt = 1
+     effective_max_retries = max_retries
+
+     while attempt <= effective_max_retries:
+         try:
+             return await func(*args, **kwargs)
+         except ClientResponseError as e:
+             if e.status in RETRY_STATUSES and attempt < max_retries:
+                 logging.debug(
+                     f"Retry {attempt}/{max_retries} due to {e.status}: {str(e)}"
+                 )
+                 await asyncio.sleep(retry_backoff**attempt)
+                 attempt += 1
+                 continue
+             raise
+         except ClientConnectorError as e:
+             # DNS resolution errors need more retries and initial delay
+             # Update effective_max_retries if this is the first DNS error
+             if effective_max_retries == max_retries:
+                 effective_max_retries = max(max_retries, DNS_MAX_RETRIES)
+
+             if attempt < effective_max_retries:
+                 # Use initial delay for first retry, then exponential backoff
+                 if attempt == 1:
+                     delay = DNS_INITIAL_DELAY
+                 else:
+                     delay = retry_backoff ** (attempt - 1)
+                 logging.debug(
+                     f"Retry {attempt}/{effective_max_retries} due to DNS/connection error: {str(e)}"
+                 )
+                 await asyncio.sleep(delay)
+                 attempt += 1
+                 continue
+             raise
+         except (
+             ServerDisconnectedError,
+             ClientOSError,
+             ServerTimeoutError,
+             ClientPayloadError,
+             asyncio.TimeoutError,
+         ) as e:
+             if attempt < max_retries:
+                 logging.debug(f"Retry {attempt}/{max_retries} due to {str(e)}")
+                 await asyncio.sleep(retry_backoff**attempt)
+                 attempt += 1
+                 continue
+             raise
+
+
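Since retry_request takes an awaitable factory, callers wrap each request in a small closure, as the module does with _upload, _get_info, and _finalize below. A minimal sketch of the pattern (the URL and function name are illustrative, not from this diff):

    import aiohttp
    from proximl.utils.transfer import retry_request

    async def fetch_info(session: aiohttp.ClientSession, url: str):
        async def _get():
            async with session.get(url, timeout=30) as response:
                response.raise_for_status()  # raises ClientResponseError on 4xx/5xx
                return await response.json()

        # Retried on 502/503/504, DNS failures, and transient disconnects
        return await retry_request(_get)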
+ async def upload_chunk(
+     session,
+     endpoint,
+     auth_token,
+     total_size,
+     data,
+     offset,
+ ):
+     """Uploads a single chunk with retry logic."""
+     start = offset
+     end = offset + len(data) - 1
+     headers = {
+         "Content-Range": f"bytes {start}-{end}/{total_size}",
+         "Authorization": f"Bearer {auth_token}",
+     }
+
+     async def _upload():
+         async with session.put(
+             f"{endpoint}/upload",
+             headers=headers,
+             data=data,
+             timeout=30,
+         ) as response:
+             if response.status == 200:
+                 await response.release()
+                 return response
+             elif response.status in RETRY_STATUSES:
+                 text = await response.text()
+                 raise ClientResponseError(
+                     request_info=response.request_info,
+                     history=response.history,
+                     status=response.status,
+                     message=text,
+                 )
+             else:
+                 text = await response.text()
+                 raise ConnectionError(
+                     f"Chunk {start}-{end} failed with status {response.status}: {text}"
+                 )
+
+     await retry_request(_upload)
+
+
+ async def upload(endpoint, auth_token, path):
+     """
+     Upload a local file or directory as a TAR stream to the server.
+
+     Args:
+         endpoint: Server endpoint URL
+         auth_token: Authentication token
+         path: Local file or directory path to upload
+
+     Raises:
+         ValueError: If path doesn't exist or is invalid
+         ConnectionError: If upload fails or endpoint ping fails
+         ProxiMLException: For other errors
+     """
+     # Normalize endpoint URL to ensure it has a protocol
+     endpoint = normalize_endpoint(endpoint)
+
+     # Ping endpoint to ensure it's ready before starting upload
+     await ping_endpoint(endpoint, auth_token)
+
+     # Expand user home directory (~) in path
+     path = os.path.expanduser(path)
+
+     if not os.path.exists(path):
+         raise ValueError(f"Path not found: {path}")
+
+     # Determine if it's a file or directory and build tar command accordingly
+     abs_path = os.path.abspath(path)
+
+     if os.path.isfile(path):
+         # For a single file, create a tar with just that file
+         file_name = os.path.basename(abs_path)
+         parent_dir = os.path.dirname(abs_path)
+         # Use tar -c to create archive with single file, stream to stdout
+         # -C changes to parent directory so the file appears at root of tar
+         command = ["tar", "-c", "-C", parent_dir, file_name]
+         desc = f"Uploading file {file_name}"
+     elif os.path.isdir(path):
+         # For a directory, archive its contents at the root of the tar file
+         # -C changes to the directory itself, and . archives all contents
+         command = ["tar", "-c", "-C", abs_path, "."]
+         dir_name = os.path.basename(abs_path)
+         desc = f"Uploading directory {dir_name}"
+     else:
+         raise ValueError(f"Path is neither a file nor directory: {path}")
+
+     process = await asyncio.create_subprocess_exec(
+         *command,
+         stdout=asyncio.subprocess.PIPE,
+         stderr=asyncio.subprocess.PIPE,
+     )
+
+     # We need to know the total size for progress, but tar doesn't tell us
+     # So we'll estimate or track bytes uploaded
+     sha512 = hashlib.sha512()
+     offset = 0
+     semaphore = asyncio.Semaphore(PARALLEL_UPLOADS)
+
+     async with aiohttp.ClientSession() as session:
+         async with semaphore:
+             while True:
+                 chunk = await process.stdout.read(CHUNK_SIZE)
+                 if not chunk:
+                     break  # End of stream
+
+                 sha512.update(chunk)
+                 chunk_offset = offset
+                 offset += len(chunk)
+
+                 # For total_size, we'll use a large number since we don't know the actual size
+                 # The server will handle the Content-Range correctly
+                 await upload_chunk(
+                     session,
+                     endpoint,
+                     auth_token,
+                     offset,  # Use current offset as total (will be updated)
+                     chunk,
+                     chunk_offset,
+                 )
+
+         # Wait for process to finish
+         await process.wait()
+         if process.returncode != 0:
+             stderr = await process.stderr.read()
+             raise ProxiMLException(
+                 f"tar command failed: {stderr.decode() if stderr else 'Unknown error'}"
+             )
+
+         # Finalize upload
+         file_hash = sha512.hexdigest()
+
+         async def _finalize():
+             async with session.post(
+                 f"{endpoint}/finalize",
+                 headers={"Authorization": f"Bearer {auth_token}"},
+                 json={"hash": file_hash},
+             ) as response:
+                 if response.status != 200:
+                     text = await response.text()
+                     raise ConnectionError(f"Finalize failed: {text}")
+                 return await response.json()
+
+         data = await retry_request(_finalize)
+         logging.debug(f"Upload finalized: {data}")
+
+
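A hedged usage sketch for upload (the endpoint, token, and path are placeholder values, not from this diff):

    import asyncio
    from proximl.utils.transfer import upload

    # Streams a tar of the directory to {endpoint}/upload in 5MB chunks,
    # then POSTs the SHA-512 of the stream to {endpoint}/finalize
    asyncio.run(upload("transfer.example.com", "example-token", "~/datasets/images"))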
+ async def download(endpoint, auth_token, target_directory, file_name=None):
+     """
+     Download a directory archive from the server and extract it.
+
+     Args:
+         endpoint: Server endpoint URL
+         auth_token: Authentication token
+         target_directory: Directory to extract files to (or save zip file)
+         file_name: Optional filename override for zip archive (if ARCHIVE=true).
+             If not provided, filename is extracted from Content-Disposition header.
+
+     Raises:
+         ConnectionError: If download fails or endpoint ping fails
+         ProxiMLException: For other errors
+     """
+     # Normalize endpoint URL to ensure it has a protocol
+     endpoint = normalize_endpoint(endpoint)
+
+     # Ping endpoint to ensure it's ready before starting download
+     await ping_endpoint(endpoint, auth_token)
+
+     # Expand user home directory (~) in target_directory
+     target_directory = os.path.expanduser(target_directory)
+
+     if not os.path.isdir(target_directory):
+         os.makedirs(target_directory, exist_ok=True)
+
+     # First, check server info to see if ARCHIVE is set
+     # If /info endpoint is not available, default to False (TAR stream mode)
+     async with aiohttp.ClientSession() as session:
+         use_archive = False
+         try:
+
+             async def _get_info():
+                 async with session.get(
+                     f"{endpoint}/info",
+                     headers={"Authorization": f"Bearer {auth_token}"},
+                     timeout=30,
+                 ) as response:
+                     if response.status != 200:
+                         try:
+                             error_text = await response.text()
+                         except Exception:
+                             error_text = f"Unable to read response body (status: {response.status})"
+                         raise ClientResponseError(
+                             request_info=response.request_info,
+                             history=response.history,
+                             status=response.status,
+                             message=error_text,
+                         )
+                     return await response.json()
+
+             info = await retry_request(_get_info)
+             use_archive = info.get("archive", False)
+         except InvalidURL as e:
+             raise ConnectionError(
+                 f"Invalid endpoint URL: {endpoint}. "
+                 f"Please ensure the URL includes a protocol (http:// or https://). "
+                 f"Error: {str(e)}"
+             )
+         except (ConnectionError, ClientResponseError) as e:
+             # If /info endpoint is not available (404) or other error,
+             # default to TAR stream mode and continue
+             if isinstance(e, ConnectionError) and "404" in str(e):
+                 logging.debug(
+                     "Warning: /info endpoint not available, defaulting to TAR stream mode"
+                 )
+             elif isinstance(e, ClientResponseError) and e.status == 404:
+                 logging.debug(
+                     "Warning: /info endpoint not available, defaulting to TAR stream mode"
+                 )
+             else:
+                 # For other errors, convert ClientResponseError to ConnectionError
+                 # to maintain backward compatibility
+                 if isinstance(e, ClientResponseError):
+                     error_msg = getattr(e, "message", str(e))
+                     raise ConnectionError(
+                         f"Failed to get server info (status {e.status}): {error_msg}"
+                     )
+                 # For ConnectionError, re-raise as-is
+                 raise
+
+         # Download the archive
+         # Note: Do NOT use "async with session.get() as response" - exiting the
+         # context manager would release/close the connection before we read the
+         # body. We need the response (and connection) to stay open for streaming.
+         async def _download():
+             response = await session.get(
+                 f"{endpoint}/download",
+                 headers={"Authorization": f"Bearer {auth_token}"},
+                 timeout=None,  # No timeout for large downloads
+             )
+             if response.status != 200:
+                 text = await response.text()
+                 response.close()
+                 # Raise ClientResponseError for non-200 status
+                 # Note: 404 and other errors should be rare now since ping_endpoint ensures readiness
+                 raise ClientResponseError(
+                     request_info=response.request_info,
+                     history=response.history,
+                     status=response.status,
+                     message=(
+                         text
+                         if text
+                         else f"Download endpoint returned status {response.status}"
+                     ),
+                 )
+             return response
+
+         response = await retry_request(_download)
+
+         # Check Content-Type header as fallback to determine if it's a zip file
+         content_type = response.headers.get("Content-Type", "").lower()
+         content_length = response.headers.get("Content-Length")
+         if "zip" in content_type and not use_archive:
+             logging.debug(
+                 "Warning: Server returned zip content but /info indicated TAR mode. Using zip mode."
+             )
+             use_archive = True
+
+         # Debug: Log response info
+         if content_length:
+             logging.debug(f"Response Content-Length: {content_length} bytes")
+         logging.debug(f"Response Content-Type: {content_type}")
+
+         try:
+             if use_archive:
+                 # Save as ZIP file
+                 # Extract filename from Content-Disposition header if not provided
+                 if file_name is None:
+                     content_disposition = response.headers.get(
+                         "Content-Disposition", ""
+                     )
+                     # Parse filename from Content-Disposition: attachment; filename="filename.zip"
+                     if "filename=" in content_disposition:
+                         # Extract filename from quotes
+                         match = re.search(
+                             r'filename="?([^"]+)"?', content_disposition
+                         )
+                         if match:
+                             file_name = match.group(1)
+                         else:
+                             # Fallback: try without quotes
+                             match = re.search(
+                                 r"filename=([^;]+)", content_disposition
+                             )
+                             if match:
+                                 file_name = match.group(1).strip()
+
+                 # Fallback if no filename in header
+                 if file_name is None:
+                     file_name = "archive.zip"
+
+                 # Ensure .zip extension
+                 if not file_name.endswith(".zip"):
+                     file_name = file_name + ".zip"
+
+                 output_path = os.path.join(target_directory, file_name)
+
+                 total_bytes = 0
+                 async with aiofiles.open(output_path, "wb") as f:
+                     # Stream the response content in chunks
+                     async for chunk in response.content.iter_chunked(
+                         CHUNK_SIZE
+                     ):
+                         await f.write(chunk)
+                         total_bytes += len(chunk)
+
+                 if total_bytes == 0:
+                     raise ConnectionError(
+                         "Downloaded file is empty (0 bytes). "
+                         "The server may not have any files to download, or there was an error streaming the response."
+                     )
+
+                 logging.info(
+                     f"Archive saved to: {output_path} ({total_bytes} bytes)"
+                 )
+             else:
+                 # Extract TAR stream directly
+                 # Create tar extraction process
+                 command = ["tar", "-x", "-C", target_directory]
+
+                 extract_process = await asyncio.create_subprocess_exec(
+                     *command,
+                     stdin=asyncio.subprocess.PIPE,
+                     stderr=asyncio.subprocess.PIPE,
+                 )
+
+                 # Stream response to tar process
+                 async for chunk in response.content.iter_chunked(CHUNK_SIZE):
+                     extract_process.stdin.write(chunk)
+                     await extract_process.stdin.drain()
+
+                 extract_process.stdin.close()
+                 await extract_process.wait()
+
+                 if extract_process.returncode != 0:
+                     stderr = await extract_process.stderr.read()
+                     raise ProxiMLException(
+                         f"tar extraction failed: {stderr.decode() if stderr else 'Unknown error'}"
+                     )
+
+                 logging.info(f"Files extracted to: {target_directory}")
+         finally:
+             response.close()
+
+         # Finalize download
+         async def _finalize():
+             async with session.post(
+                 f"{endpoint}/finalize",
+                 headers={"Authorization": f"Bearer {auth_token}"},
+                 json={},
+             ) as response:
+                 if response.status != 200:
+                     text = await response.text()
+                     raise ConnectionError(f"Finalize failed: {text}")
+                 return await response.json()
+
+         data = await retry_request(_finalize)
+         logging.debug(f"Download finalized: {data}")
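And a matching sketch for download (again with placeholder values). Whether the payload is saved as a zip or extracted as a tar stream is decided by the server's /info response, with the Content-Type fallback shown above:

    import asyncio
    from proximl.utils.transfer import download

    # Extracts the tar stream into the target directory, or saves the
    # zip archive there when the server reports archive mode
    asyncio.run(download("transfer.example.com", "example-token", "~/downloads/output"))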