trainml 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +16 -0
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +89 -67
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +269 -69
  41. trainml/models.py +51 -55
  42. trainml/trainml.py +159 -114
  43. trainml/utils/__init__.py +1 -0
  44. trainml/utils/auth.py +641 -0
  45. trainml/utils/transfer.py +647 -0
  46. trainml/volumes.py +48 -53
  47. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/METADATA +3 -3
  48. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/RECORD +52 -46
  49. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/LICENSE +0 -0
  50. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/WHEEL +0 -0
  51. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/entry_points.txt +0 -0
  52. {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/top_level.txt +0 -0
examples/local_storage.py CHANGED
@@ -19,7 +19,6 @@ async def create_dataset():
19
19
  attach_task = asyncio.create_task(dataset.attach())
20
20
  connect_task = asyncio.create_task(dataset.connect())
21
21
  await asyncio.gather(attach_task, connect_task)
22
- await dataset.disconnect()
23
22
  return dataset
24
23
 
25
24
 
@@ -55,7 +54,6 @@ async def run_job(dataset):
55
54
  await asyncio.gather(attach_task, connect_task)
56
55
 
57
56
  # Cleanup job
58
- await job.disconnect()
59
57
  await job.remove()
60
58
 
61
59
 
@@ -23,6 +23,7 @@ class GetCheckpointTests:
23
23
  checkpoint = await checkpoint.wait_for("archived", 60)
24
24
 
25
25
  async def test_get_checkpoints(self, trainml, checkpoint):
26
+ _ = checkpoint
26
27
  checkpoints = await trainml.checkpoints.list()
27
28
  assert len(checkpoints) > 0
28
29
 
@@ -55,7 +56,7 @@ class GetCheckpointTests:
55
56
 
56
57
  @mark.create
57
58
  @mark.asyncio
58
- async def test_checkpoint_wasabi(trainml, capsys):
59
+ async def test_checkpoint_wasabi(trainml):
59
60
  checkpoint = await trainml.checkpoints.create(
60
61
  name="CLI Automated Wasabi",
61
62
  source_type="wasabi",
@@ -72,6 +73,7 @@ async def test_checkpoint_wasabi(trainml, capsys):
72
73
 
73
74
  @mark.create
74
75
  @mark.asyncio
76
+ @mark.local
75
77
  async def test_checkpoint_local(trainml, capsys):
76
78
  checkpoint = await trainml.checkpoints.create(
77
79
  name="CLI Automated Local",
@@ -81,7 +83,6 @@ async def test_checkpoint_local(trainml, capsys):
81
83
  attach_task = asyncio.create_task(checkpoint.attach())
82
84
  connect_task = asyncio.create_task(checkpoint.connect())
83
85
  await asyncio.gather(attach_task, connect_task)
84
- await checkpoint.disconnect()
85
86
  await checkpoint.refresh()
86
87
  status = checkpoint.status
87
88
  size = checkpoint.size
@@ -92,5 +93,5 @@ async def test_checkpoint_local(trainml, capsys):
92
93
  sys.stdout.write(captured.out)
93
94
  sys.stderr.write(captured.err)
94
95
  assert "Starting data upload from local" in captured.out
95
- assert "official/LICENSE 11456 bytes" in captured.out
96
+ assert "official/LICENSE" in captured.out
96
97
  assert "Upload complete" in captured.out
@@ -50,7 +50,9 @@ class GetDatasetTests:
50
50
  async def test_dataset_repr(self, dataset):
51
51
  string = repr(dataset)
52
52
  regex = (
53
- r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '" + dataset.id + r"'.*}\)$"
53
+ r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '"
54
+ + dataset.id
55
+ + r"'.*}\)$"
54
56
  )
55
57
  assert isinstance(string, str)
56
58
  assert re.match(regex, string)
@@ -79,6 +81,7 @@ class GetDatasetTests:
79
81
 
80
82
  @mark.create
81
83
  @mark.asyncio
84
+ @mark.local
82
85
  async def test_dataset_local(trainml, capsys):
83
86
  dataset = await trainml.datasets.create(
84
87
  name="CLI Automated Local",
@@ -88,7 +91,6 @@ async def test_dataset_local(trainml, capsys):
88
91
  attach_task = asyncio.create_task(dataset.attach())
89
92
  connect_task = asyncio.create_task(dataset.connect())
90
93
  await asyncio.gather(attach_task, connect_task)
91
- await dataset.disconnect()
92
94
  await dataset.refresh()
93
95
  status = dataset.status
94
96
  size = dataset.size
@@ -99,5 +101,5 @@ async def test_dataset_local(trainml, capsys):
99
101
  sys.stdout.write(captured.out)
100
102
  sys.stderr.write(captured.err)
101
103
  assert "Starting data upload from local" in captured.out
102
- assert "data_batch_1.bin 30733788 bytes" in captured.out
104
+ assert "data_batch_1.bin" in captured.out
103
105
  assert "Upload complete" in captured.out
@@ -46,7 +46,10 @@ class JobLifeCycleTests:
46
46
  job = await job.wait_for("running")
47
47
  assert job.status == "running"
48
48
  assert job.url
49
- assert extract_domain_suffix(urlparse(job.url).hostname) == "proximl.cloud"
49
+ assert (
50
+ extract_domain_suffix(urlparse(job.url).hostname)
51
+ == "proximl.cloud"
52
+ )
50
53
 
51
54
  async def test_stop_job(self, job):
52
55
  assert job.status == "running"
@@ -204,7 +207,8 @@ class JobAPIResourceValidationTests:
204
207
  disk_size=10,
205
208
  )
206
209
  assert (
207
- "Invalid Request - CPU Count must be a multiple of 4" in error.value.message
210
+ "Invalid Request - CPU Count must be a multiple of 4"
211
+ in error.value.message
208
212
  )
209
213
 
210
214
  async def test_invalid_gpu_count_for_cpu(self, trainml):
@@ -417,6 +421,7 @@ class JobAPIWorkerValidationTests:
417
421
  @mark.asyncio
418
422
  @mark.xdist_group("job_io")
419
423
  class JobIOTests:
424
+ @mark.local
420
425
  async def test_job_local_output(self, trainml, capsys):
421
426
  temp_dir = tempfile.TemporaryDirectory()
422
427
  job = await trainml.jobs.create(
@@ -426,7 +431,7 @@ class JobIOTests:
426
431
  disk_size=10,
427
432
  workers=["python $ML_MODEL_PATH/tensorflow/main.py"],
428
433
  environment=dict(
429
- type="DEEPLEARNING_PY312",
434
+ type="DEEPLEARNING_PY313",
430
435
  env=[
431
436
  dict(
432
437
  key="CHECKPOINT_FILE",
@@ -452,13 +457,13 @@ class JobIOTests:
452
457
  ],
453
458
  ),
454
459
  )
455
- await job.wait_for("waiting for data/model download")
460
+ # Wait for job to reach running status since only output_type is local
461
+ await job.wait_for("running")
456
462
  attach_task = asyncio.create_task(job.attach())
457
463
  connect_task = asyncio.create_task(job.connect())
458
464
  await asyncio.gather(attach_task, connect_task)
459
465
  await job.refresh()
460
466
  assert job.status == "finished"
461
- await job.disconnect()
462
467
  await job.remove()
463
468
  upload_contents = os.listdir(temp_dir.name)
464
469
  temp_dir.cleanup()
@@ -470,9 +475,8 @@ class JobIOTests:
470
475
  captured = capsys.readouterr()
471
476
  sys.stdout.write(captured.out)
472
477
  sys.stderr.write(captured.err)
473
- assert "Epoch 1/2" in captured.out
474
- assert "Epoch 2/2" in captured.out
475
- assert "adding: model.ckpt-0001" in captured.out
478
+ assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
479
+ assert "model.ckpt-0001" in captured.out
476
480
  assert "Send complete" in captured.out
477
481
 
478
482
  async def test_job_model_input_and_output(self, trainml, capsys):
@@ -513,8 +517,7 @@ class JobIOTests:
513
517
  captured = capsys.readouterr()
514
518
  sys.stdout.write(captured.out)
515
519
  sys.stderr.write(captured.err)
516
- assert "Epoch 1/2" in captured.out
517
- assert "Epoch 2/2" in captured.out
520
+ assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
518
521
 
519
522
  new_model = await trainml.models.get(workers[0].get("output_uuid"))
520
523
  assert new_model.id
@@ -560,9 +563,12 @@ class JobTypeTests:
560
563
  await job.wait_for("running")
561
564
  await job.refresh()
562
565
  assert job.url
563
- assert extract_domain_suffix(urlparse(job.url).hostname) == "proximl.cloud"
566
+ assert (
567
+ extract_domain_suffix(urlparse(job.url).hostname)
568
+ == "proximl.cloud"
569
+ )
564
570
  tries = 0
565
- await asyncio.sleep(180) ## downloading weights can be slow
571
+ await asyncio.sleep(180) ## downloading weights can be slow
566
572
  async with aiohttp.ClientSession() as session:
567
573
  retry = True
568
574
  while retry:
@@ -640,9 +646,11 @@ class JobTypeTests:
640
646
  captured = capsys.readouterr()
641
647
  sys.stdout.write(captured.out)
642
648
  sys.stderr.write(captured.err)
643
- assert "Epoch 1/2" in captured.out
644
- assert "Epoch 2/2" in captured.out
645
- assert "Uploading s3://trainml-example/output/resnet_cifar10" in captured.out
649
+ assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
650
+ assert (
651
+ "Uploading s3://trainml-example/output/resnet_cifar10"
652
+ in captured.out
653
+ )
646
654
  assert (
647
655
  "upload: ./model.ckpt-0002.weights.h5 to s3://trainml-example/output/resnet_cifar10/model.ckpt-0002.weights.h5"
648
656
  in captured.out
@@ -680,9 +688,12 @@ class JobFeatureTests:
680
688
  captured = capsys.readouterr()
681
689
  sys.stdout.write(captured.out)
682
690
  sys.stderr.write(captured.err)
683
- assert "Train Epoch: 1 [0/60000 (0%)]" in captured.out
684
- assert "Train Epoch: 1 [59520/60000 (99%)]" in captured.out
691
+ assert (
692
+ "Train Epoch: 1 [0/60000 (0%)]" in captured.out
693
+ or "Train Epoch: 1 [59520/60000 (99%)]" in captured.out
694
+ )
685
695
 
696
+ @mark.local
686
697
  async def test_inference_job(self, trainml, capsys):
687
698
  temp_dir = tempfile.TemporaryDirectory()
688
699
  job = await trainml.jobs.create(
@@ -706,11 +717,11 @@ class JobFeatureTests:
706
717
  )
707
718
  assert job.id
708
719
  await job.wait_for("running")
709
- await job.connect()
710
- await job.attach()
720
+ attach_task = asyncio.create_task(job.attach())
721
+ connect_task = asyncio.create_task(job.connect())
722
+ await asyncio.gather(attach_task, connect_task)
711
723
  await job.refresh()
712
724
  assert job.status == "finished"
713
- await job.disconnect()
714
725
  await job.remove()
715
726
  await job.wait_for("archived")
716
727
  captured = capsys.readouterr()
@@ -719,15 +730,10 @@ class JobFeatureTests:
719
730
  upload_contents = os.listdir(temp_dir.name)
720
731
  temp_dir.cleanup()
721
732
  assert len(upload_contents) >= 3
722
- assert any(
723
- "model.ckpt-0002" in content
724
- for content in upload_contents
725
- )
733
+ assert any("model.ckpt-0002" in content for content in upload_contents)
726
734
 
727
735
  captured = capsys.readouterr()
728
736
  sys.stdout.write(captured.out)
729
737
  sys.stderr.write(captured.err)
730
- assert "Epoch 1/2" in captured.out
731
- assert "Epoch 2/2" in captured.out
732
- assert "Number of regular files transferred: 4" in captured.out
738
+ assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
733
739
  assert "Send complete" in captured.out
@@ -44,7 +44,11 @@ class GetModelTests:
44
44
 
45
45
  async def test_model_repr(self, model):
46
46
  string = repr(model)
47
- regex = r"^Model\( trainml , \*\*{.*'model_uuid': '" + model.id + r"'.*}\)$"
47
+ regex = (
48
+ r"^Model\( trainml , \*\*{.*'model_uuid': '"
49
+ + model.id
50
+ + r"'.*}\)$"
51
+ )
48
52
  assert isinstance(string, str)
49
53
  assert re.match(regex, string)
50
54
 
@@ -68,6 +72,7 @@ async def test_model_wasabi(trainml, capsys):
68
72
 
69
73
  @mark.create
70
74
  @mark.asyncio
75
+ @mark.local
71
76
  async def test_model_local(trainml, capsys):
72
77
  model = await trainml.models.create(
73
78
  name="CLI Automated Local",
@@ -77,7 +82,6 @@ async def test_model_local(trainml, capsys):
77
82
  attach_task = asyncio.create_task(model.attach())
78
83
  connect_task = asyncio.create_task(model.connect())
79
84
  await asyncio.gather(attach_task, connect_task)
80
- await model.disconnect()
81
85
  await model.refresh()
82
86
  status = model.status
83
87
  size = model.size
@@ -88,5 +92,5 @@ async def test_model_local(trainml, capsys):
88
92
  sys.stdout.write(captured.out)
89
93
  sys.stderr.write(captured.err)
90
94
  assert "Starting data upload from local" in captured.out
91
- assert "official/LICENSE 11456 bytes" in captured.out
95
+ assert "official/LICENSE" in captured.out
92
96
  assert "Upload complete" in captured.out
@@ -74,6 +74,7 @@ async def test_volume_wasabi(trainml, capsys):
74
74
 
75
75
  @mark.create
76
76
  @mark.asyncio
77
+ @mark.local
77
78
  async def test_volume_local(trainml, capsys):
78
79
  volume = await trainml.volumes.create(
79
80
  name="CLI Automated Local",
@@ -84,7 +85,6 @@ async def test_volume_local(trainml, capsys):
84
85
  attach_task = asyncio.create_task(volume.attach())
85
86
  connect_task = asyncio.create_task(volume.connect())
86
87
  await asyncio.gather(attach_task, connect_task)
87
- await volume.disconnect()
88
88
  await volume.refresh()
89
89
  status = volume.status
90
90
  billed_size = volume.billed_size
@@ -97,5 +97,5 @@ async def test_volume_local(trainml, capsys):
97
97
  sys.stdout.write(captured.out)
98
98
  sys.stderr.write(captured.err)
99
99
  assert "Starting data upload from local" in captured.out
100
- assert "official/LICENSE 11456 bytes" in captured.out
100
+ assert "official/LICENSE" in captured.out
101
101
  assert "Upload complete" in captured.out
@@ -1,15 +1,24 @@
1
1
  import re
2
2
  import json
3
3
  import click
4
- from unittest.mock import AsyncMock, patch
4
+ from unittest.mock import AsyncMock, patch, Mock
5
5
  from pytest import mark, fixture, raises
6
6
 
7
7
  pytestmark = [mark.cli, mark.unit, mark.checkpoints]
8
8
 
9
9
  from trainml.cli import checkpoint as specimen
10
+ from trainml.cli.checkpoint import pretty_size
10
11
  from trainml.checkpoints import Checkpoint
11
12
 
12
13
 
14
+ def test_pretty_size_zero():
15
+ """Test pretty_size with zero/None (line 7)."""
16
+ result = pretty_size(None)
17
+ assert result == "0.00 B"
18
+ result = pretty_size(0)
19
+ assert result == "0.00 B"
20
+
21
+
13
22
  def test_list(runner, mock_my_checkpoints):
14
23
  with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
15
24
  mock_trainml.checkpoints = AsyncMock()
@@ -20,3 +29,305 @@ def test_list(runner, mock_my_checkpoints):
20
29
  print(result)
21
30
  assert result.exit_code == 0
22
31
  mock_trainml.checkpoints.list.assert_called_once()
32
+
33
+
34
+ def test_attach_success(runner, mock_my_checkpoints):
35
+ """Test attach command success (lines 32-38)."""
36
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
37
+
38
+ async def list_async():
39
+ return mock_my_checkpoints
40
+
41
+ mock_trainml.checkpoints = AsyncMock()
42
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
43
+
44
+ # Use the first checkpoint from the list
45
+ checkpoint = mock_my_checkpoints[0]
46
+
47
+ async def attach_async():
48
+ return None
49
+
50
+ checkpoint.attach = Mock(return_value=attach_async())
51
+
52
+ with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
53
+ result = runner.invoke(specimen, ["attach", "1"])
54
+ assert result.exit_code == 0
55
+ checkpoint.attach.assert_called_once()
56
+
57
+
58
+ def test_attach_not_found(runner, mock_my_checkpoints):
59
+ """Test attach command when checkpoint not found (line 36)."""
60
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
61
+
62
+ async def list_async():
63
+ return mock_my_checkpoints
64
+
65
+ mock_trainml.checkpoints = AsyncMock()
66
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
67
+
68
+ with patch("trainml.cli.search_by_id_name", return_value=None):
69
+ result = runner.invoke(specimen, ["attach", "nonexistent"])
70
+ assert result.exit_code != 0
71
+ assert "Cannot find specified checkpoint" in result.output
72
+
73
+
74
+ def test_connect_with_attach(runner, mock_my_checkpoints):
75
+ """Test connect command with attach (lines 56-65, attach=True)."""
76
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
77
+
78
+ async def list_async():
79
+ return mock_my_checkpoints
80
+
81
+ mock_trainml.checkpoints = AsyncMock()
82
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
83
+
84
+ checkpoint = mock_my_checkpoints[0]
85
+
86
+ async def connect_async():
87
+ return None
88
+
89
+ async def attach_async():
90
+ return None
91
+
92
+ checkpoint.connect = Mock(return_value=connect_async())
93
+ checkpoint.attach = Mock(return_value=attach_async())
94
+
95
+ with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
96
+ result = runner.invoke(specimen, ["connect", "1"])
97
+ assert result.exit_code == 0
98
+ checkpoint.connect.assert_called_once()
99
+ checkpoint.attach.assert_called_once()
100
+
101
+
102
+ def test_connect_no_attach(runner, mock_my_checkpoints):
103
+ """Test connect command without attach (lines 56-65, attach=False)."""
104
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
105
+
106
+ async def list_async():
107
+ return mock_my_checkpoints
108
+
109
+ mock_trainml.checkpoints = AsyncMock()
110
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
111
+
112
+ checkpoint = mock_my_checkpoints[0]
113
+
114
+ async def connect_async():
115
+ return None
116
+
117
+ checkpoint.connect = Mock(return_value=connect_async())
118
+
119
+ with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
120
+ result = runner.invoke(specimen, ["connect", "--no-attach", "1"])
121
+ assert result.exit_code == 0
122
+ checkpoint.connect.assert_called_once()
123
+
124
+
125
+ def test_connect_not_found(runner, mock_my_checkpoints):
126
+ """Test connect command when checkpoint not found (line 60)."""
127
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
128
+
129
+ async def list_async():
130
+ return mock_my_checkpoints
131
+
132
+ mock_trainml.checkpoints = AsyncMock()
133
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
134
+
135
+ with patch("trainml.cli.search_by_id_name", return_value=None):
136
+ result = runner.invoke(specimen, ["connect", "nonexistent"])
137
+ assert result.exit_code != 0
138
+ assert "Cannot find specified checkpoint" in result.output
139
+
140
+
141
+ def test_create_with_connect_and_attach(runner, tmp_path, mock_my_checkpoints):
142
+ """Test create command with connect and attach (lines 103-115)."""
143
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
144
+ checkpoint = mock_my_checkpoints[0]
145
+
146
+ async def connect_async():
147
+ return None
148
+
149
+ async def attach_async():
150
+ return None
151
+
152
+ checkpoint.connect = Mock(return_value=connect_async())
153
+ checkpoint.attach = Mock(return_value=attach_async())
154
+
155
+ async def create_async(**kwargs):
156
+ return checkpoint
157
+
158
+ mock_trainml.checkpoints = AsyncMock()
159
+ mock_trainml.checkpoints.create = Mock(
160
+ side_effect=lambda **kwargs: create_async(**kwargs)
161
+ )
162
+
163
+ test_dir = tmp_path / "test_checkpoint"
164
+ test_dir.mkdir()
165
+ result = runner.invoke(
166
+ specimen, ["create", "test-checkpoint", str(test_dir)]
167
+ )
168
+ assert result.exit_code == 0
169
+ mock_trainml.checkpoints.create.assert_called_once()
170
+ checkpoint.connect.assert_called_once()
171
+ checkpoint.attach.assert_called_once()
172
+
173
+
174
+ def test_create_with_connect_no_attach(runner, tmp_path, mock_my_checkpoints):
175
+ """Test create command with connect but no attach (lines 103-115)."""
176
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
177
+ checkpoint = mock_my_checkpoints[0]
178
+
179
+ async def connect_async():
180
+ return None
181
+
182
+ checkpoint.connect = Mock(return_value=connect_async())
183
+
184
+ async def create_async(**kwargs):
185
+ return checkpoint
186
+
187
+ mock_trainml.checkpoints = AsyncMock()
188
+ mock_trainml.checkpoints.create = Mock(
189
+ side_effect=lambda **kwargs: create_async(**kwargs)
190
+ )
191
+
192
+ test_dir = tmp_path / "test_checkpoint"
193
+ test_dir.mkdir()
194
+ result = runner.invoke(
195
+ specimen,
196
+ ["create", "--no-attach", "test-checkpoint", str(test_dir)],
197
+ )
198
+ assert result.exit_code == 0
199
+ checkpoint.connect.assert_called_once()
200
+
201
+
202
+ def test_create_no_connect(runner, tmp_path):
203
+ """Test create command without connect (lines 103-115, line 115)."""
204
+ mock_checkpoint = Mock(spec=Checkpoint)
205
+
206
+ mock_trainml_runner = Mock()
207
+ mock_trainml_runner.client = Mock()
208
+ mock_trainml_runner.client.checkpoints = Mock()
209
+ mock_trainml_runner.client.checkpoints.create = AsyncMock(
210
+ return_value=mock_checkpoint
211
+ )
212
+ mock_trainml_runner.run = Mock(
213
+ side_effect=lambda x: x if not hasattr(x, "__call__") else x()
214
+ )
215
+
216
+ with patch("trainml.cli.TrainMLRunner", return_value=mock_trainml_runner):
217
+ test_dir = tmp_path / "test_checkpoint"
218
+ test_dir.mkdir()
219
+ result = runner.invoke(
220
+ specimen,
221
+ ["create", "--no-connect", "test-checkpoint", str(test_dir)],
222
+ )
223
+ assert result.exit_code != 0
224
+ assert "No logs to show" in result.output
225
+
226
+
227
+ def test_list_public(runner, mock_my_checkpoints):
228
+ """Test list_public command (lines 152-171)."""
229
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
230
+ mock_trainml.checkpoints = AsyncMock()
231
+ mock_trainml.checkpoints.list_public = AsyncMock(
232
+ return_value=mock_my_checkpoints
233
+ )
234
+
235
+ result = runner.invoke(specimen, ["list-public"])
236
+ assert result.exit_code == 0
237
+ mock_trainml.checkpoints.list_public.assert_called_once()
238
+
239
+
240
+ def test_remove_success(runner, mock_my_checkpoints):
241
+ """Test remove command success (lines 192-201)."""
242
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
243
+
244
+ async def list_async():
245
+ return mock_my_checkpoints
246
+
247
+ mock_trainml.checkpoints = AsyncMock()
248
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
249
+
250
+ checkpoint = mock_my_checkpoints[0]
251
+
252
+ async def remove_async():
253
+ return None
254
+
255
+ checkpoint.remove = Mock(return_value=remove_async())
256
+
257
+ with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
258
+ result = runner.invoke(specimen, ["remove", "1"])
259
+ assert result.exit_code == 0
260
+ checkpoint.remove.assert_called_once_with(force=False)
261
+
262
+
263
+ def test_remove_not_found(runner, mock_my_checkpoints):
264
+ """Test remove command when checkpoint not found (lines 192-201)."""
265
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
266
+
267
+ async def list_async():
268
+ return mock_my_checkpoints
269
+
270
+ mock_trainml.checkpoints = AsyncMock()
271
+ mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
272
+
273
+ with patch("trainml.cli.search_by_id_name", return_value=None):
274
+ result = runner.invoke(specimen, ["remove", "nonexistent"])
275
+ assert result.exit_code != 0
276
+ assert "Cannot find specified checkpoint" in result.output
277
+
278
+
279
+ def test_rename_success(runner, mock_my_checkpoints):
280
+ """Test rename command success (lines 214-223)."""
281
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
282
+ checkpoint = mock_my_checkpoints[0]
283
+
284
+ async def rename_async():
285
+ return None
286
+
287
+ checkpoint.rename = Mock(return_value=rename_async())
288
+
289
+ async def get_async(checkpoint_id):
290
+ return checkpoint
291
+
292
+ mock_trainml.checkpoints = AsyncMock()
293
+ mock_trainml.checkpoints.get = Mock(
294
+ side_effect=lambda checkpoint_id: get_async(checkpoint_id)
295
+ )
296
+
297
+ result = runner.invoke(specimen, ["rename", "1", "new-name"])
298
+ assert result.exit_code == 0
299
+ checkpoint.rename.assert_called_once_with(name="new-name")
300
+
301
+
302
+ def test_rename_not_found_none(runner):
303
+ """Test rename command when checkpoint is None (lines 214-223, line 219)."""
304
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
305
+
306
+ async def get_async(checkpoint_id):
307
+ return None
308
+
309
+ mock_trainml.checkpoints = AsyncMock()
310
+ mock_trainml.checkpoints.get = Mock(
311
+ side_effect=lambda checkpoint_id: get_async(checkpoint_id)
312
+ )
313
+
314
+ result = runner.invoke(specimen, ["rename", "nonexistent", "new-name"])
315
+ assert result.exit_code != 0
316
+ assert "Cannot find specified checkpoint" in result.output
317
+
318
+
319
+ def test_rename_not_found_exception(runner):
320
+ """Test rename command when exception occurs (lines 214-223, line 221)."""
321
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
322
+
323
+ async def get_async(checkpoint_id):
324
+ raise Exception("Not found")
325
+
326
+ mock_trainml.checkpoints = AsyncMock()
327
+ mock_trainml.checkpoints.get = Mock(
328
+ side_effect=lambda checkpoint_id: get_async(checkpoint_id)
329
+ )
330
+
331
+ result = runner.invoke(specimen, ["rename", "nonexistent", "new-name"])
332
+ assert result.exit_code != 0
333
+ assert "Cannot find specified checkpoint" in result.output