trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +17 -1
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/projects/projects.py +2 -2
  43. trainml/trainml.py +50 -16
  44. trainml/utils/__init__.py +1 -0
  45. trainml/utils/auth.py +641 -0
  46. trainml/utils/transfer.py +587 -0
  47. trainml/volumes.py +48 -53
  48. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  49. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
  50. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/jobs.py CHANGED
@@ -12,7 +12,7 @@ from trainml.exceptions import (
12
12
  SpecificationError,
13
13
  TrainMLException,
14
14
  )
15
- from trainml.connections import Connection
15
+ from trainml.utils.transfer import upload, download
16
16
 
17
17
 
18
18
  class Jobs(object):
@@ -77,7 +77,8 @@ class Jobs(object):
77
77
  model=model,
78
78
  endpoint=endpoint,
79
79
  source_job_uuid=kwargs.get("source_job_uuid"),
80
- project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
80
+ project_uuid=kwargs.get("project_uuid")
81
+ or self.trainml.active_project,
81
82
  )
82
83
  payload = {
83
84
  k: v
@@ -102,7 +103,9 @@ class Jobs(object):
102
103
  return job
103
104
 
104
105
  async def remove(self, id, **kwargs):
105
- await self.trainml._query(f"/job/{id}", "DELETE", dict(**kwargs, force=True))
106
+ await self.trainml._query(
107
+ f"/job/{id}", "DELETE", dict(**kwargs, force=True)
108
+ )
106
109
 
107
110
 
108
111
  class Job:
@@ -292,42 +295,6 @@ class Job:
292
295
  )
293
296
  return resp
294
297
 
295
- async def get_connection_utility_url(self):
296
- resp = await self.trainml._query(
297
- f"/job/{self._id}/download",
298
- "GET",
299
- dict(project_uuid=self._project_uuid),
300
- )
301
- return resp
302
-
303
- def get_connection_details(self):
304
- details = dict(
305
- entity_type="job",
306
- project_uuid=self._job.get("project_uuid"),
307
- cidr=self.dict.get("vpn").get("cidr"),
308
- ssh_port=(
309
- self._job.get("vpn").get("client").get("ssh_port")
310
- if self._job.get("vpn").get("client")
311
- else None
312
- ),
313
- model_path=(
314
- self._job.get("model").get("source_uri")
315
- if self._job.get("model").get("source_type") == "local"
316
- else None
317
- ),
318
- input_path=(
319
- self._job.get("data").get("input_uri")
320
- if self._job.get("data").get("input_type") == "local"
321
- else None
322
- ),
323
- output_path=(
324
- self._job.get("data").get("output_uri")
325
- if self._job.get("data").get("output_type") == "local"
326
- else None
327
- ),
328
- )
329
- return details
330
-
331
298
  async def open(self):
332
299
  if self.type != "notebook":
333
300
  raise SpecificationError(
@@ -337,6 +304,7 @@ class Job:
337
304
  webbrowser.open(self.notebook_url)
338
305
 
339
306
  async def connect(self):
307
+ # Handle notebook/endpoint special cases
340
308
  if self.type == "notebook" and self.status not in [
341
309
  "new",
342
310
  "waiting for data/model download",
@@ -352,6 +320,8 @@ class Job:
352
320
  "waiting for data/model download",
353
321
  ]:
354
322
  return self.url
323
+
324
+ # Check for invalid statuses
355
325
  if self.status in [
356
326
  "failed",
357
327
  "finished",
@@ -364,26 +334,221 @@ class Job:
364
334
  "status",
365
335
  f"You can only connect to active jobs.",
366
336
  )
367
- if self._job.get("vpn").get("status") == "n/a":
368
- logging.info("Local connection not enabled for this job.")
369
- return
370
- if self.status == "new":
371
- await self.wait_for("waiting for data/model download")
372
- connection = Connection(
373
- self.trainml, entity_type="job", id=self.id, entity=self
374
- )
375
- await connection.start()
376
- return connection.status
377
337
 
378
- async def disconnect(self):
379
- if self._job.get("vpn").get("status") == "n/a":
380
- logging.info("Local connection not enabled for this job.")
381
- return
382
- connection = Connection(
383
- self.trainml, entity_type="job", id=self.id, entity=self
384
- )
385
- await connection.stop()
386
- return connection.status
338
+ # Only allow specific statuses for connect
339
+ if self.status not in [
340
+ "waiting for data/model download",
341
+ "uploading",
342
+ "running",
343
+ ]:
344
+ if self.status == "new":
345
+ await self.wait_for("waiting for data/model download")
346
+ else:
347
+ raise SpecificationError(
348
+ "status",
349
+ f"You can only connect to jobs in 'waiting for data/model download', 'uploading', or 'running' status.",
350
+ )
351
+
352
+ # Refresh to get latest job data
353
+ await self.refresh()
354
+
355
+ # Re-check status after refresh (status may have changed if attach() is running in parallel)
356
+ if self.status not in [
357
+ "waiting for data/model download",
358
+ "uploading",
359
+ "running",
360
+ ]:
361
+ raise SpecificationError(
362
+ "status",
363
+ f"Job status changed to '{self.status}'. You can only connect to jobs in 'waiting for data/model download', 'uploading', or 'running' status.",
364
+ )
365
+
366
+ if self.status == "waiting for data/model download":
367
+ # Upload model and/or data if local
368
+ model = self._job.get("model", {})
369
+ data = self._job.get("data", {})
370
+
371
+ model_local = model.get("source_type") == "local"
372
+ data_local = data.get("input_type") == "local"
373
+
374
+ if not model_local and not data_local:
375
+ raise SpecificationError(
376
+ "status",
377
+ f"Job has no local model or data to upload. Model source_type: {model.get('source_type')}, Data input_type: {data.get('input_type')}",
378
+ )
379
+
380
+ upload_tasks = []
381
+
382
+ if model_local:
383
+ model_auth_token = model.get("auth_token")
384
+ model_hostname = model.get("hostname")
385
+ model_source_uri = model.get("source_uri")
386
+
387
+ if (
388
+ not model_auth_token
389
+ or not model_hostname
390
+ or not model_source_uri
391
+ ):
392
+ raise SpecificationError(
393
+ "status",
394
+ f"Job model missing required connection properties (auth_token, hostname, source_uri).",
395
+ )
396
+
397
+ upload_tasks.append(
398
+ upload(model_hostname, model_auth_token, model_source_uri)
399
+ )
400
+
401
+ if data_local:
402
+ data_auth_token = data.get("input_auth_token")
403
+ data_hostname = data.get("input_hostname")
404
+ data_input_uri = data.get("input_uri")
405
+
406
+ if (
407
+ not data_auth_token
408
+ or not data_hostname
409
+ or not data_input_uri
410
+ ):
411
+ raise SpecificationError(
412
+ "status",
413
+ f"Job data missing required connection properties (input_auth_token, input_hostname, input_uri).",
414
+ )
415
+
416
+ upload_tasks.append(
417
+ upload(data_hostname, data_auth_token, data_input_uri)
418
+ )
419
+
420
+ # Upload both in parallel if both are local
421
+ if upload_tasks:
422
+ await asyncio.gather(*upload_tasks)
423
+
424
+ elif self.status in ["uploading", "running"]:
425
+ # Download output if local
426
+ data = self._job.get("data", {})
427
+
428
+ if data.get("output_type") != "local":
429
+ raise SpecificationError(
430
+ "status",
431
+ f"Job output_type is not 'local', cannot download output.",
432
+ )
433
+
434
+ output_uri = data.get("output_uri")
435
+ if not output_uri:
436
+ raise SpecificationError(
437
+ "status",
438
+ f"Job data missing output_uri for local output download.",
439
+ )
440
+
441
+ # Track which workers we've already started downloading
442
+ downloading_workers = set()
443
+ download_tasks = []
444
+
445
+ # Poll until all workers are finished
446
+ while True:
447
+ # Refresh job to get latest worker statuses
448
+ await self.refresh()
449
+
450
+ # Get fresh workers list
451
+ workers = self._job.get("workers", [])
452
+ if not workers:
453
+ raise SpecificationError(
454
+ "status",
455
+ f"Job has no workers.",
456
+ )
457
+
458
+ # Check if job is finished
459
+ if self.status in ["finished", "canceled", "failed"]:
460
+ break
461
+
462
+ # Check all workers for uploading status
463
+ for worker in workers:
464
+ worker_id = worker.get("job_worker_uuid") or worker.get(
465
+ "id"
466
+ )
467
+ worker_status = worker.get("status")
468
+
469
+ # Start download for any worker that enters uploading status
470
+ if (
471
+ worker_status == "uploading"
472
+ and worker_id not in downloading_workers
473
+ ):
474
+ output_auth_token = worker.get("output_auth_token")
475
+ output_hostname = worker.get("output_hostname")
476
+
477
+ if not output_auth_token or not output_hostname:
478
+ logging.warning(
479
+ f"Worker {worker_id} in uploading status missing output_auth_token or output_hostname, skipping."
480
+ )
481
+ continue
482
+
483
+ downloading_workers.add(worker_id)
484
+ # Create and start download task (runs in parallel)
485
+ logging.info(
486
+ f"Starting download for worker {worker_id} from {output_hostname} to {output_uri}"
487
+ )
488
+ try:
489
+ download_task = asyncio.create_task(
490
+ download(
491
+ output_hostname,
492
+ output_auth_token,
493
+ output_uri,
494
+ )
495
+ )
496
+ download_tasks.append(download_task)
497
+ logging.debug(
498
+ f"Download task created for worker {worker_id}, task: {download_task}"
499
+ )
500
+ except Exception as e:
501
+ logging.error(
502
+ f"Failed to create download task for worker {worker_id}: {e}",
503
+ exc_info=True,
504
+ )
505
+ raise
506
+
507
+ # Check if any download tasks have completed or failed
508
+ if download_tasks:
509
+ completed_tasks = [
510
+ task for task in download_tasks if task.done()
511
+ ]
512
+ for task in completed_tasks:
513
+ try:
514
+ await task # This will raise if the task failed
515
+ logging.info(
516
+ f"Download task completed successfully"
517
+ )
518
+ except Exception as e:
519
+ logging.error(
520
+ f"Download task failed: {e}", exc_info=True
521
+ )
522
+ raise
523
+ # Remove completed tasks
524
+ download_tasks = [
525
+ task for task in download_tasks if not task.done()
526
+ ]
527
+
528
+ # Check if all workers are finished
529
+ all_finished = all(
530
+ worker.get("status") in ["finished", "removed"]
531
+ for worker in workers
532
+ )
533
+
534
+ if all_finished:
535
+ break
536
+
537
+ # If we have active download tasks, wait a bit for them to make progress
538
+ # but don't wait the full 30 seconds - check more frequently
539
+ if download_tasks:
540
+ await asyncio.sleep(5)
541
+ else:
542
+ # Wait 30 seconds before next poll if no downloads in progress
543
+ await asyncio.sleep(30)
544
+
545
+ # Wait for all download tasks to complete
546
+ if download_tasks:
547
+ logging.info(
548
+ f"Waiting for {len(download_tasks)} download task(s) to complete"
549
+ )
550
+ await asyncio.gather(*download_tasks)
551
+ logging.info("All downloads completed")
387
552
 
388
553
  async def remove(self, force=False):
389
554
  await self.trainml._query(
@@ -401,7 +566,8 @@ class Job:
401
566
 
402
567
  def _get_msg_handler(self, msg_handler):
403
568
  worker_numbers = {
404
- w.get("job_worker_uuid"): ind + 1 for ind, w in enumerate(self._workers)
569
+ w.get("job_worker_uuid"): ind + 1
570
+ for ind, w in enumerate(self._workers)
405
571
  }
406
572
  worker_numbers["data_worker"] = 0
407
573
 
@@ -411,7 +577,9 @@ class Job:
411
577
  if msg_handler:
412
578
  msg_handler(data)
413
579
  else:
414
- timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
580
+ timestamp = datetime.fromtimestamp(
581
+ int(data.get("time")) / 1000
582
+ )
415
583
  if len(self._workers) > 1:
416
584
  print(
417
585
  f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: Worker {data.get('worker_number')} - {data.get('msg').rstrip()}"
@@ -424,7 +592,10 @@ class Job:
424
592
  return handler
425
593
 
426
594
  async def attach(self, msg_handler=None):
427
- if self.type == "notebook" and self.status != "waiting for data/model download":
595
+ if (
596
+ self.type == "notebook"
597
+ and self.status != "waiting for data/model download"
598
+ ):
428
599
  raise SpecificationError(
429
600
  "type",
430
601
  "Notebooks cannot be attached to after model download is complete. Use open() instead.",
@@ -441,7 +612,9 @@ class Job:
441
612
  async def copy(self, name, **kwargs):
442
613
  logging.debug(f"copy request - name: {name} ; kwargs: {kwargs}")
443
614
  if self.type != "notebook":
444
- raise SpecificationError("job", "Only notebook job types can be copied")
615
+ raise SpecificationError(
616
+ "job", "Only notebook job types can be copied"
617
+ )
445
618
 
446
619
  job = await self.trainml.jobs.create(
447
620
  name,
@@ -502,7 +675,9 @@ class Job:
502
675
 
503
676
  POLL_INTERVAL_MIN = 5
504
677
  POLL_INTERVAL_MAX = 60
505
- POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
678
+ POLL_INTERVAL = max(
679
+ min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
680
+ )
506
681
  retry_count = math.ceil(timeout / POLL_INTERVAL)
507
682
  count = 0
508
683
  while count < retry_count:
@@ -533,11 +708,7 @@ class Job:
533
708
  or (
534
709
  status
535
710
  == "running" ## this status could be too short for polling could miss it
536
- and self.status
537
- in [
538
- "uploading",
539
- "finished"
540
- ]
711
+ and self.status in ["uploading", "finished"]
541
712
  )
542
713
  ):
543
714
  return self
trainml/models.py CHANGED
@@ -10,7 +10,7 @@ from .exceptions import (
10
10
  SpecificationError,
11
11
  TrainMLException,
12
12
  )
13
- from .connections import Connection
13
+ from trainml.utils.transfer import upload, download
14
14
 
15
15
 
16
16
  class Models(object):
@@ -54,7 +54,9 @@ class Models(object):
54
54
  return model
55
55
 
56
56
  async def remove(self, id, **kwargs):
57
- await self.trainml._query(f"/model/{id}", "DELETE", dict(**kwargs, force=True))
57
+ await self.trainml._query(
58
+ f"/model/{id}", "DELETE", dict(**kwargs, force=True)
59
+ )
58
60
 
59
61
 
60
62
  class Model:
@@ -65,7 +67,9 @@ class Model:
65
67
  self._status = self._model.get("status")
66
68
  self._name = self._model.get("name")
67
69
  self._size = self._model.get("size") or self._model.get("used_size")
68
- self._billed_size = self._model.get("billed_size") or self._model.get("size")
70
+ self._billed_size = self._model.get("billed_size") or self._model.get(
71
+ "size"
72
+ )
69
73
  self._project_uuid = self._model.get("project_uuid")
70
74
 
71
75
  @property
@@ -113,57 +117,45 @@ class Model:
113
117
  )
114
118
  return resp
115
119
 
116
- async def get_connection_utility_url(self):
117
- resp = await self.trainml._query(
118
- f"/model/{self._id}/download",
119
- "GET",
120
- dict(project_uuid=self._project_uuid),
121
- )
122
- return resp
123
-
124
- def get_connection_details(self):
125
- if self._model.get("vpn"):
126
- details = dict(
127
- entity_type="model",
128
- project_uuid=self._model.get("project_uuid"),
129
- cidr=self._model.get("vpn").get("cidr"),
130
- ssh_port=self._model.get("vpn").get("client").get("ssh_port"),
131
- input_path=(
132
- self._model.get("source_uri")
133
- if self.status in ["new", "downloading"]
134
- else None
135
- ),
136
- output_path=(
137
- self._model.get("output_uri")
138
- if self.status == "exporting"
139
- else None
140
- ),
141
- )
142
- else:
143
- details = dict()
144
- logging.debug(f"Connection Details: {details}")
145
- return details
146
-
147
120
  async def connect(self):
148
- if self.status in ["ready", "failed"]:
149
- raise SpecificationError(
150
- "status",
151
- f"You can only connect to downloading or exporting models.",
152
- )
153
- if self.status == "new":
154
- await self.wait_for("downloading")
155
- connection = Connection(
156
- self.trainml, entity_type="model", id=self.id, entity=self
157
- )
158
- await connection.start()
159
- return connection.status
121
+ if self.status not in ["downloading", "exporting"]:
122
+ if self.status == "new":
123
+ await self.wait_for("downloading")
124
+ else:
125
+ raise SpecificationError(
126
+ "status",
127
+ f"You can only connect to downloading or exporting models.",
128
+ )
160
129
 
161
- async def disconnect(self):
162
- connection = Connection(
163
- self.trainml, entity_type="model", id=self.id, entity=self
164
- )
165
- await connection.stop()
166
- return connection.status
130
+ # Refresh to get latest entity data
131
+ await self.refresh()
132
+
133
+ if self.status == "downloading":
134
+ # Upload task - get auth_token, hostname, and source_uri from model
135
+ auth_token = self._model.get("auth_token")
136
+ hostname = self._model.get("hostname")
137
+ source_uri = self._model.get("source_uri")
138
+
139
+ if not auth_token or not hostname or not source_uri:
140
+ raise SpecificationError(
141
+ "status",
142
+ f"Model in downloading status missing required connection properties (auth_token, hostname, source_uri).",
143
+ )
144
+
145
+ await upload(hostname, auth_token, source_uri)
146
+ elif self.status == "exporting":
147
+ # Download task - get auth_token, hostname, and output_uri from model
148
+ auth_token = self._model.get("auth_token")
149
+ hostname = self._model.get("hostname")
150
+ output_uri = self._model.get("output_uri")
151
+
152
+ if not auth_token or not hostname or not output_uri:
153
+ raise SpecificationError(
154
+ "status",
155
+ f"Model in exporting status missing required connection properties (auth_token, hostname, output_uri).",
156
+ )
157
+
158
+ await download(hostname, auth_token, output_uri)
167
159
 
168
160
  async def remove(self, force=False):
169
161
  await self.trainml._query(
@@ -202,7 +194,9 @@ class Model:
202
194
  if msg_handler:
203
195
  msg_handler(data)
204
196
  else:
205
- timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
197
+ timestamp = datetime.fromtimestamp(
198
+ int(data.get("time")) / 1000
199
+ )
206
200
  print(
207
201
  f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
208
202
  )
@@ -231,7 +225,7 @@ class Model:
231
225
  async def wait_for(self, status, timeout=300):
232
226
  if self.status == status:
233
227
  return
234
- valid_statuses = ["downloading", "ready", "archived"]
228
+ valid_statuses = ["downloading", "ready","exporting", "archived"]
235
229
  if not status in valid_statuses:
236
230
  raise SpecificationError(
237
231
  "status",
@@ -245,7 +239,9 @@ class Model:
245
239
  )
246
240
  POLL_INTERVAL_MIN = 5
247
241
  POLL_INTERVAL_MAX = 60
248
- POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
242
+ POLL_INTERVAL = max(
243
+ min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
244
+ )
249
245
  retry_count = math.ceil(timeout / POLL_INTERVAL)
250
246
  count = 0
251
247
  while count < retry_count:
trainml/projects/projects.py CHANGED
@@ -27,10 +27,10 @@ class Projects(object):
27
27
  projects = [Project(self.trainml, **project) for project in resp]
28
28
  return projects
29
29
 
30
- async def create(self, name, copy_credentials=False, **kwargs):
30
+ async def create(self, name, **kwargs):
31
31
  data = dict(
32
32
  name=name,
33
- copy_credentials=copy_credentials,
33
+ **kwargs
34
34
  )
35
35
  payload = {k: v for k, v in data.items() if v is not None}
36
36
  logging.info(f"Creating Project {name}")