trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +17 -1
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/projects/projects.py +2 -2
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import asyncio
|
|
4
|
+
import aiohttp
|
|
5
|
+
import aiofiles
|
|
6
|
+
import hashlib
|
|
7
|
+
import logging
|
|
8
|
+
from aiohttp.client_exceptions import (
|
|
9
|
+
ClientResponseError,
|
|
10
|
+
ClientConnectorError,
|
|
11
|
+
ServerTimeoutError,
|
|
12
|
+
ServerDisconnectedError,
|
|
13
|
+
ClientOSError,
|
|
14
|
+
ClientPayloadError,
|
|
15
|
+
InvalidURL,
|
|
16
|
+
)
|
|
17
|
+
from trainml.exceptions import ConnectionError, TrainMLException
|
|
18
|
+
|
|
19
|
+
# Maximum number of attempts made by retry_request() and ping_endpoint()
MAX_RETRIES = 5
# Base of the exponential backoff; retries sleep roughly RETRY_BACKOFF ** attempt seconds
RETRY_BACKOFF = 2
# Intended cap on concurrent chunk uploads.
# NOTE(review): upload() sends chunks one at a time, so this does not
# currently limit anything.
PARALLEL_UPLOADS = 10
# Size of each upload/download chunk: 5 MiB
CHUNK_SIZE = 5 << 20
# Gateway/server errors (Bad Gateway, Service Unavailable, Gateway Timeout)
# that are retried during upload/download
RETRY_STATUSES = {502, 503, 504}
# DNS/connection failures (ClientConnectorError) get a larger attempt budget
DNS_MAX_RETRIES = 7
# Seconds to wait before the first retry after a DNS/connection failure
DNS_INITIAL_DELAY = 1
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def normalize_endpoint(endpoint):
    """
    Normalize endpoint URL to ensure it has a protocol.

    Surrounding whitespace and trailing slashes are removed, and ``https://``
    is prepended when the URL does not already start with ``http://`` or
    ``https://``. The scheme check is case-insensitive, because URL schemes
    are case-insensitive (``HTTPS://host`` is left alone rather than being
    turned into ``https://HTTPS://host``).

    Args:
        endpoint: Endpoint URL (with or without protocol)

    Returns:
        Normalized endpoint URL with an explicit protocol; ``https://`` is
        used when none was given.

    Raises:
        ValueError: If endpoint is empty, or contains only whitespace/slashes
    """
    if not endpoint:
        raise ValueError("Endpoint URL cannot be empty")

    # Remove surrounding whitespace and trailing slashes
    endpoint = endpoint.strip().rstrip("/")
    # Inputs such as "   " or "/" are empty once cleaned; reject them instead
    # of producing the meaningless URL "https://"
    if not endpoint:
        raise ValueError("Endpoint URL cannot be empty")

    # Add https:// if no protocol is specified
    if not endpoint.lower().startswith(("http://", "https://")):
        endpoint = f"https://{endpoint}"

    return endpoint
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def ping_endpoint(
    endpoint, auth_token, max_retries=MAX_RETRIES, retry_backoff=RETRY_BACKOFF
):
    """
    Ping the endpoint to ensure it's ready before upload/download operations.

    Sends ``GET {endpoint}/ping`` until it answers 200. Any non-200 status and
    any transient transport error (disconnect, OS error, timeout, payload
    error) is retried with exponential backoff (``retry_backoff ** attempt``
    seconds). This handles startup timing issues. DNS/connection failures
    (``ClientConnectorError``) raise the attempt budget to at least
    ``DNS_MAX_RETRIES``. The first DNS retry waits ``DNS_INITIAL_DELAY``
    seconds and later ones wait ``retry_backoff ** (attempt - 1)``. Other
    errors, e.g. ``InvalidURL``, propagate immediately.

    A fresh ``TCPConnector`` is created for each attempt and closed before
    the backoff sleep, so a cached DNS result from a failed attempt is not
    reused.

    Args:
        endpoint: Server endpoint URL (protocol optional)
        auth_token: Authentication token, sent as a Bearer token
        max_retries: Maximum number of attempts; values below 1 are treated as 1
        retry_backoff: Exponential backoff base

    Raises:
        ValueError: If endpoint is empty
        ConnectionError: If the ping never returns 200 within the attempt
            budget. The last underlying aiohttp error is chained as
            ``__cause__``.
    """
    endpoint = normalize_endpoint(endpoint)
    # With max_retries <= 0 the loop would never run and the endpoint would be
    # reported as ready without ever being contacted: always try at least once.
    max_retries = max(1, max_retries)
    effective_max_retries = max_retries
    attempt = 1

    while True:
        # Create a fresh connector for each attempt to force DNS re-resolution
        # and avoid stale DNS cache issues.
        connector = None
        try:
            connector = aiohttp.TCPConnector(limit=1, limit_per_host=1)
            async with aiohttp.ClientSession(connector=connector) as session:
                async with session.get(
                    f"{endpoint}/ping",
                    headers={"Authorization": f"Bearer {auth_token}"},
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    if response.status == 200:
                        logging.debug(
                            f"Endpoint {endpoint} is ready (ping successful)"
                        )
                        return
                    # For any non-200 status, retry
                    text = await response.text()
                    raise ClientResponseError(
                        request_info=response.request_info,
                        history=response.history,
                        status=response.status,
                        message=text,
                    )
        except ClientResponseError as e:
            # Retry on any HTTP error status
            if attempt >= effective_max_retries:
                raise ConnectionError(
                    f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts. "
                    f"Last error: HTTP {e.status} - {str(e)}"
                ) from e
            logging.debug(
                f"Ping attempt {attempt}/{effective_max_retries} failed with status {e.status}: {str(e)}"
            )
            delay = retry_backoff**attempt
        except ClientConnectorError as e:
            # DNS resolution errors need more retries and a short initial delay
            effective_max_retries = max(effective_max_retries, DNS_MAX_RETRIES)
            if attempt >= effective_max_retries:
                raise ConnectionError(
                    f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts due to DNS/connection error: {str(e)}"
                ) from e
            logging.debug(
                f"Ping attempt {attempt}/{effective_max_retries} failed due to DNS/connection error: {str(e)}"
            )
            # Use initial delay for first retry, then exponential backoff
            if attempt == 1:
                delay = DNS_INITIAL_DELAY
            else:
                delay = retry_backoff ** (attempt - 1)
        except (
            ServerDisconnectedError,
            ClientOSError,
            ServerTimeoutError,
            ClientPayloadError,
            asyncio.TimeoutError,
        ) as e:
            if attempt >= effective_max_retries:
                raise ConnectionError(
                    f"Endpoint {endpoint} ping failed after {effective_max_retries} attempts: {str(e)}"
                ) from e
            logging.debug(
                f"Ping attempt {attempt}/{effective_max_retries} failed: {str(e)}"
            )
            delay = retry_backoff**attempt
        finally:
            # Close the per-attempt connector (before any backoff sleep) so the
            # next attempt starts with a clean connector and DNS cache.
            if connector is not None:
                try:
                    await connector.close()
                except Exception as close_error:
                    # Best-effort cleanup: must not mask the attempt's outcome
                    logging.debug(
                        f"Ignoring error while closing ping connector: {close_error}"
                    )

        # Only reached after a retryable failure; every such path sets delay
        await asyncio.sleep(delay)
        attempt += 1
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
async def retry_request(
    func, *args, max_retries=MAX_RETRIES, retry_backoff=RETRY_BACKOFF, **kwargs
):
    """
    Shared retry logic for network requests.

    Awaits ``func(*args, **kwargs)`` and returns its result. The call is
    retried with exponential backoff (``retry_backoff ** attempt`` seconds)
    when it raises:

    * ``ClientResponseError`` whose status is in ``RETRY_STATUSES``;
    * a transient transport error (``ServerDisconnectedError``,
      ``ClientOSError``, ``ServerTimeoutError``, ``ClientPayloadError``,
      ``asyncio.TimeoutError``).

    For DNS/connection errors (``ClientConnectorError``) the attempt budget is
    raised to at least ``DNS_MAX_RETRIES``. The first retry waits only
    ``DNS_INITIAL_DELAY`` seconds, and later ones wait
    ``retry_backoff ** (attempt - 1)``. This handles transient DNS
    resolution issues.

    Args:
        func: Coroutine function performing a single request attempt
        *args: Positional arguments forwarded to ``func``
        max_retries: Maximum number of attempts; values below 1 are treated as 1
        retry_backoff: Exponential backoff base
        **kwargs: Keyword arguments forwarded to ``func``

    Returns:
        Whatever ``func`` returns on its first successful attempt.

    Raises:
        The exception from ``func``, either immediately when it is not
        retryable or after the last allowed attempt.
    """
    # With max_retries <= 0 the loop below would never run: func would not be
    # called at all and callers would silently get None as if it succeeded.
    max_retries = max(1, max_retries)
    attempt = 1
    effective_max_retries = max_retries

    while attempt <= effective_max_retries:
        try:
            return await func(*args, **kwargs)
        except ClientResponseError as e:
            if e.status in RETRY_STATUSES and attempt < max_retries:
                logging.debug(
                    f"Retry {attempt}/{max_retries} due to {e.status}: {str(e)}"
                )
                await asyncio.sleep(retry_backoff**attempt)
                attempt += 1
                continue
            raise
        except ClientConnectorError as e:
            # DNS resolution errors need more retries and a short initial delay
            effective_max_retries = max(effective_max_retries, DNS_MAX_RETRIES)
            if attempt < effective_max_retries:
                # Use initial delay for first retry, then exponential backoff
                if attempt == 1:
                    delay = DNS_INITIAL_DELAY
                else:
                    delay = retry_backoff ** (attempt - 1)
                logging.debug(
                    f"Retry {attempt}/{effective_max_retries} due to DNS/connection error: {str(e)}"
                )
                await asyncio.sleep(delay)
                attempt += 1
                continue
            raise
        except (
            ServerDisconnectedError,
            ClientOSError,
            ServerTimeoutError,
            ClientPayloadError,
            asyncio.TimeoutError,
        ) as e:
            if attempt < max_retries:
                logging.debug(f"Retry {attempt}/{max_retries} due to {str(e)}")
                await asyncio.sleep(retry_backoff**attempt)
                attempt += 1
                continue
            raise
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
async def upload_chunk(
    session,
    endpoint,
    auth_token,
    total_size,
    data,
    offset,
):
    """
    PUT one chunk of an upload stream to ``{endpoint}/upload``, with retries.

    The chunk covers bytes ``offset`` .. ``offset + len(data) - 1`` and is
    described by a ``Content-Range: bytes <first>-<last>/<total_size>`` header.

    Args:
        session: aiohttp client session used for the request
        endpoint: Server endpoint URL (already normalized)
        auth_token: Authentication token, sent as a Bearer token
        total_size: Value reported as the Content-Range total
        data: Chunk payload
        offset: Byte offset of the chunk's first byte in the stream

    Raises:
        ConnectionError: For a non-200 status that is not in RETRY_STATUSES
        ClientResponseError: For a RETRY_STATUSES status that persists after
            retry_request's retries
    """
    last_byte = offset + len(data) - 1
    request_headers = {
        "Content-Range": f"bytes {offset}-{last_byte}/{total_size}",
        "Authorization": f"Bearer {auth_token}",
    }

    async def _put_once():
        async with session.put(
            f"{endpoint}/upload",
            headers=request_headers,
            data=data,
            timeout=30,
        ) as resp:
            status = resp.status
            if status == 200:
                await resp.release()
                return resp

            body = await resp.text()
            if status not in RETRY_STATUSES:
                # Permanent failure: not retried by retry_request
                raise ConnectionError(
                    f"Chunk {offset}-{last_byte} failed with status {status}: {body}"
                )
            # Transient server error: raised as ClientResponseError so that
            # retry_request backs off and tries again
            raise ClientResponseError(
                request_info=resp.request_info,
                history=resp.history,
                status=status,
                message=body,
            )

    await retry_request(_put_once)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
async def upload(endpoint, auth_token, path):
    """
    Upload a local file or directory as a TAR stream to the server.

    A local ``tar -c`` subprocess writes the archive to stdout. The stream is
    read in ``CHUNK_SIZE`` pieces, and each piece is sent with
    ``upload_chunk`` one at a time, in order. A SHA-512 digest of the whole
    stream is then POSTed to ``{endpoint}/finalize``.

    A single file is archived at the root of the tar. For a directory, its
    contents (not the directory itself) are archived at the root.

    Args:
        endpoint: Server endpoint URL
        auth_token: Authentication token
        path: Local file or directory path to upload (``~`` is expanded)

    Raises:
        ValueError: If path doesn't exist or is invalid
        ConnectionError: If upload fails or endpoint ping fails
        TrainMLException: If the local tar command fails
    """
    # Normalize endpoint URL to ensure it has a protocol
    endpoint = normalize_endpoint(endpoint)

    # Validate the local path before any network traffic. A bad path then
    # fails immediately instead of after the ping's retry/backoff cycle.
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        raise ValueError(f"Path not found: {path}")

    abs_path = os.path.abspath(path)
    if os.path.isfile(path):
        # -C <parent> <name>: the single file appears at the root of the tar
        command = [
            "tar",
            "-c",
            "-C",
            os.path.dirname(abs_path),
            os.path.basename(abs_path),
        ]
    elif os.path.isdir(path):
        # -C <dir> .: the directory's contents appear at the root of the tar
        command = ["tar", "-c", "-C", abs_path, "."]
    else:
        raise ValueError(f"Path is neither a file nor directory: {path}")

    # Ping endpoint to ensure it's ready before starting upload
    await ping_endpoint(endpoint, auth_token)

    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # Drain stderr concurrently. If tar writes more warnings than the OS pipe
    # buffer holds while nobody reads them, tar blocks, never closes stdout,
    # and the read loop below would hang forever.
    stderr_task = asyncio.create_task(process.stderr.read())

    sha512 = hashlib.sha512()
    offset = 0
    try:
        async with aiohttp.ClientSession() as session:
            while True:
                chunk = await process.stdout.read(CHUNK_SIZE)
                if not chunk:
                    break  # End of stream

                sha512.update(chunk)
                chunk_offset = offset
                offset += len(chunk)

                # The final archive size is unknown while streaming, so the
                # running byte count is sent as the Content-Range total.
                # NOTE(review): relies on the server tolerating this
                # provisional total; confirm against the server implementation.
                await upload_chunk(
                    session,
                    endpoint,
                    auth_token,
                    offset,
                    chunk,
                    chunk_offset,
                )

            # Wait for tar to finish and collect its diagnostics
            await process.wait()
            stderr = await stderr_task
            if process.returncode != 0:
                message = (
                    stderr.decode(errors="replace") if stderr else "Unknown error"
                )
                raise TrainMLException(f"tar command failed: {message}")

            # Finalize upload
            file_hash = sha512.hexdigest()

            async def _finalize():
                async with session.post(
                    f"{endpoint}/finalize",
                    headers={"Authorization": f"Bearer {auth_token}"},
                    json={"hash": file_hash},
                ) as response:
                    if response.status != 200:
                        text = await response.text()
                        raise ConnectionError(f"Finalize failed: {text}")
                    return await response.json()

            data = await retry_request(_finalize)
            logging.debug(f"Upload finalized: {data}")
    finally:
        # If anything failed mid-stream (e.g. a chunk upload), do not leave
        # the tar subprocess running behind us.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # exited between the check and the kill
            await process.wait()
        if not stderr_task.done():
            stderr_task.cancel()
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _filename_from_content_disposition(content_disposition):
    """
    Extract a safe file name from a Content-Disposition header value.

    Handles both ``filename="name.zip"`` and unquoted ``filename=name.zip``.
    An unquoted value stops at the next ``;``. Only the final path component
    is kept, so a server-supplied value such as ``../../x.zip`` or
    ``/tmp/x.zip`` cannot make the caller write outside its target directory.

    Returns:
        The file name, or None if the header carries no usable name.
    """
    match = re.search(r'filename="([^"]+)"', content_disposition)
    if match is None:
        match = re.search(r"filename=([^;]+)", content_disposition)
    if match is None:
        return None
    name = match.group(1).strip().strip('"')
    # Treat backslashes as separators too, then drop every directory component
    name = os.path.basename(name.replace("\\", "/"))
    return name or None


async def _save_zip_response(response, target_directory, file_name):
    """
    Stream a ZIP response body to a file in target_directory.

    The name is taken from file_name if given. Otherwise it comes from the
    Content-Disposition header, falling back to ``archive.zip``. A ``.zip``
    suffix is added if missing (the check is case-insensitive).

    Raises:
        ConnectionError: If the response body is empty.
    """
    if file_name is None:
        file_name = _filename_from_content_disposition(
            response.headers.get("Content-Disposition", "")
        )
    if not file_name:
        file_name = "archive.zip"
    if not file_name.lower().endswith(".zip"):
        file_name = file_name + ".zip"

    output_path = os.path.join(target_directory, file_name)

    total_bytes = 0
    async with aiofiles.open(output_path, "wb") as f:
        # Stream the response content in chunks
        async for chunk in response.content.iter_chunked(CHUNK_SIZE):
            await f.write(chunk)
            total_bytes += len(chunk)

    if total_bytes == 0:
        # Do not leave a 0-byte "archive" behind that looks like a result
        try:
            os.remove(output_path)
        except OSError:
            pass  # best effort; the error below is what matters
        raise ConnectionError(
            "Downloaded file is empty (0 bytes). "
            "The server may not have any files to download, or there was an error streaming the response."
        )

    logging.info(f"Archive saved to: {output_path} ({total_bytes} bytes)")


async def _extract_tar_response(response, target_directory):
    """
    Pipe a TAR response body into ``tar -x -C target_directory``.

    Raises:
        TrainMLException: If tar exits with a non-zero status.
    """
    process = await asyncio.create_subprocess_exec(
        "tar",
        "-x",
        "-C",
        target_directory,
        stdin=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # Read stderr concurrently so a chatty tar cannot fill the pipe and stall
    stderr_task = asyncio.create_task(process.stderr.read())
    try:
        try:
            async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                try:
                    process.stdin.write(chunk)
                    await process.stdin.drain()
                except (BrokenPipeError, ConnectionResetError):
                    # tar exited before consuming the whole stream; its exit
                    # status and stderr (checked below) explain why.
                    break
        finally:
            process.stdin.close()
        await process.wait()
        stderr = await stderr_task
    finally:
        # On any failure (e.g. the download stream breaking) stop tar
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # exited between the check and the kill
            await process.wait()
        if not stderr_task.done():
            stderr_task.cancel()

    if process.returncode != 0:
        message = stderr.decode(errors="replace") if stderr else "Unknown error"
        raise TrainMLException(f"tar extraction failed: {message}")

    logging.info(f"Files extracted to: {target_directory}")


async def download(endpoint, auth_token, target_directory, file_name=None):
    """
    Download a directory archive from the server and extract it.

    The mode is chosen from the server's ``/info`` response. If its
    ``archive`` value is truthy, or the download's Content-Type mentions
    zip, the body is saved as a ZIP file in ``target_directory``. Otherwise
    the body is treated as a TAR stream and extracted into
    ``target_directory`` with a local ``tar -x``. A missing ``/info``
    endpoint (HTTP 404) selects TAR mode. Afterwards an empty-JSON POST is
    sent to ``{endpoint}/finalize``.

    Args:
        endpoint: Server endpoint URL
        auth_token: Authentication token
        target_directory: Directory to extract files to (or save zip file);
            created if missing, ``~`` is expanded
        file_name: Optional filename override for zip archive (if ARCHIVE=true).
            If not provided, filename is extracted from Content-Disposition
            header (directory components are discarded for safety).

    Raises:
        ConnectionError: If download fails, endpoint ping fails, the endpoint
            URL is invalid, /info returns a non-200/404 status, or a ZIP
            download is empty
        ClientResponseError: If /download answers with a non-200 status
            (statuses in RETRY_STATUSES are retried first)
        TrainMLException: If local TAR extraction fails
    """
    # Normalize endpoint URL to ensure it has a protocol
    endpoint = normalize_endpoint(endpoint)

    # Ping endpoint to ensure it's ready before starting download
    await ping_endpoint(endpoint, auth_token)

    # Expand user home directory (~) in target_directory
    target_directory = os.path.expanduser(target_directory)
    if not os.path.isdir(target_directory):
        os.makedirs(target_directory, exist_ok=True)

    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    async with aiohttp.ClientSession() as session:

        async def _get_info():
            async with session.get(
                f"{endpoint}/info",
                headers=auth_headers,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as response:
                if response.status == 404:
                    # No /info endpoint on this server: mode is unknown
                    return None
                if response.status != 200:
                    try:
                        error_text = await response.text()
                    except Exception:
                        error_text = f"Unable to read response body (status: {response.status})"
                    raise ConnectionError(
                        f"Failed to get server info (status {response.status}): {error_text}"
                    )
                return await response.json()

        try:
            info = await retry_request(_get_info)
        except InvalidURL as e:
            raise ConnectionError(
                f"Invalid endpoint URL: {endpoint}. "
                f"Please ensure the URL includes a protocol (http:// or https://). "
                f"Error: {str(e)}"
            ) from e

        if info is None:
            logging.debug(
                "Warning: /info endpoint not available, defaulting to TAR stream mode"
            )
            use_archive = False
        else:
            use_archive = info.get("archive", False)

        # Do NOT use "async with session.get() as response" here: leaving the
        # context manager would release the connection before the body is
        # streamed below. The response is closed explicitly once consumed.
        async def _download():
            response = await session.get(
                f"{endpoint}/download",
                headers=auth_headers,
                # No overall time limit: archives can be very large
                timeout=aiohttp.ClientTimeout(total=None),
            )
            if response.status != 200:
                try:
                    text = await response.text()
                finally:
                    response.close()
                raise ClientResponseError(
                    request_info=response.request_info,
                    history=response.history,
                    status=response.status,
                    message=(
                        text
                        if text
                        else f"Download endpoint returned status {response.status}"
                    ),
                )
            return response

        response = await retry_request(_download)
        try:
            # Content-Type acts as a fallback signal for ZIP mode
            content_type = response.headers.get("Content-Type", "").lower()
            content_length = response.headers.get("Content-Length")
            if "zip" in content_type and not use_archive:
                logging.debug(
                    "Warning: Server returned zip content but /info indicated TAR mode. Using zip mode."
                )
                use_archive = True

            if content_length:
                logging.debug(f"Response Content-Length: {content_length} bytes")
            logging.debug(f"Response Content-Type: {content_type}")

            if use_archive:
                await _save_zip_response(response, target_directory, file_name)
            else:
                await _extract_tar_response(response, target_directory)
        finally:
            response.close()

        # Finalize download
        async def _finalize():
            async with session.post(
                f"{endpoint}/finalize",
                headers=auth_headers,
                json={},
            ) as response:
                if response.status != 200:
                    text = await response.text()
                    raise ConnectionError(f"Finalize failed: {text}")
                return await response.json()

        data = await retry_request(_finalize)
        logging.debug(f"Download finalized: {data}")
|