trainml 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_checkpoints_integration.py +7 -5
- tests/integration/test_datasets_integration.py +4 -5
- tests/integration/test_jobs_integration.py +40 -2
- tests/integration/test_models_integration.py +8 -10
- tests/integration/test_projects_integration.py +2 -6
- tests/integration/test_volumes_integration.py +100 -0
- tests/unit/cli/cloudbender/test_cli_reservation_unit.py +10 -14
- tests/unit/cli/test_cli_project_unit.py +5 -9
- tests/unit/cli/test_cli_volume_unit.py +20 -0
- tests/unit/cloudbender/test_services_unit.py +161 -0
- tests/unit/conftest.py +94 -21
- tests/unit/test_projects_unit.py +34 -48
- tests/unit/test_volumes_unit.py +447 -0
- trainml/__init__.py +1 -1
- trainml/cli/__init__.py +3 -6
- trainml/cli/cloudbender/__init__.py +1 -1
- trainml/cli/cloudbender/service.py +129 -0
- trainml/cli/project.py +10 -15
- trainml/cli/volume.py +235 -0
- trainml/cloudbender/cloudbender.py +2 -2
- trainml/cloudbender/services.py +115 -0
- trainml/exceptions.py +21 -12
- trainml/jobs.py +36 -39
- trainml/projects.py +19 -30
- trainml/trainml.py +7 -15
- trainml/volumes.py +255 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/METADATA +1 -1
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/RECORD +32 -29
- tests/integration/test_providers_integration.py +0 -46
- tests/unit/test_providers_unit.py +0 -125
- trainml/cli/job.py +0 -173
- trainml/cli/provider.py +0 -75
- trainml/providers.py +0 -63
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/LICENSE +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/WHEEL +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/top_level.txt +0 -0
trainml/cli/volume.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
import click
|
|
2
|
+
from trainml.cli import cli, pass_config, search_by_id_name
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def pretty_size(num):
    """
    Format a byte count as a human-readable string using binary units.

    :param num: size in bytes; any falsy value (``None``, ``0``) is treated
        as zero.
    :returns: the value scaled by powers of 1024, rendered with two decimals
        and followed by the unit label, e.g. ``"1.50 KiB"``.
    """
    units = (" B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB")
    size = num if num else 0.0
    index = 0
    last = len(units) - 1
    # Stop scaling at the largest known unit so that enormous values are
    # reported in YiB instead of indexing past the end of the unit table.
    while size > 1023 and index < last:
        size = size / 1024
        index += 1
    return f"{size:.2f} {units[index]}"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@cli.group()
@pass_config
def volume(config):
    """trainML volume commands."""
    # Pure command group: the sub-commands registered on it do the work.
    return None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@volume.command()
@click.argument("volume", type=click.STRING)
@pass_config
def attach(config, volume):
    """
    Attach to volume and show creation logs.

    VOLUME may be specified by name or ID, but ID is preferred.
    """
    # Resolve the user-supplied identifier against the full volume listing.
    candidates = config.trainml.run(config.trainml.client.volumes.list())
    target = search_by_id_name(volume, candidates)
    if target is None:
        raise click.UsageError("Cannot find specified volume.")

    try:
        config.trainml.run(target.attach())
        outcome = config.trainml.run(target.disconnect())
    except BaseException:
        # Whatever went wrong (including Ctrl-C), make one best-effort
        # attempt to disconnect, then let the original error propagate.
        try:
            config.trainml.run(target.disconnect())
        except BaseException:
            pass
        raise
    return outcome
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@volume.command()
@click.option(
    "--attach/--no-attach",
    default=True,
    show_default=True,
    help="Auto attach to volume and show creation logs.",
)
@click.argument("volume", type=click.STRING)
@pass_config
def connect(config, volume, attach):
    """
    Connect local source to volume and begin upload.

    VOLUME may be specified by name or ID, but ID is preferred.
    """
    # Resolve the user-supplied identifier against the full volume listing.
    candidates = config.trainml.run(config.trainml.client.volumes.list())
    target = search_by_id_name(volume, candidates)
    if target is None:
        raise click.UsageError("Cannot find specified volume.")

    try:
        if not attach:
            # Connect only; the caller is responsible for disconnecting later.
            return config.trainml.run(target.connect())
        # Run the connection and the log attachment together, then clean up.
        config.trainml.run(target.connect(), target.attach())
        return config.trainml.run(target.disconnect())
    except BaseException:
        # Best-effort disconnect on any failure; the original error wins.
        try:
            config.trainml.run(target.disconnect())
        except BaseException:
            pass
        raise
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@volume.command()
@click.option(
    "--attach/--no-attach",
    default=True,
    show_default=True,
    help="Auto attach to volume and show creation logs.",
)
@click.option(
    "--connect/--no-connect",
    default=True,
    show_default=True,
    help="Auto connect source and start volume creation.",
)
@click.option(
    "--source",
    "-s",
    type=click.Choice(["local"], case_sensitive=False),
    default="local",
    show_default=True,
    help="Volume source type.",
)
@click.argument("name", type=click.STRING)
@click.argument("capacity", type=click.INT)
@click.argument(
    "path", type=click.Path(exists=True, file_okay=False, resolve_path=True)
)
@pass_config
def create(config, attach, connect, source, name, capacity, path):
    """
    Create a volume.

    A volume with maximum size CAPACITY is created with the specified NAME using a local source at the PATH
    specified. PATH should be a local directory containing the source data for
    a local source or a URI for all other source types.
    """
    # Only the "local" source type is implemented. Reject anything else up
    # front so the code below never runs with an unbound `volume`.
    if source.lower() != "local":
        raise click.UsageError(f"Unsupported volume source type: {source}")

    volume = config.trainml.run(
        config.trainml.client.volumes.create(
            name=name, source_type="local", source_uri=path, capacity=capacity
        )
    )

    try:
        if connect and attach:
            # Upload and log streaming run together; disconnect afterwards.
            config.trainml.run(volume.attach(), volume.connect())
            return config.trainml.run(volume.disconnect())
        elif connect:
            return config.trainml.run(volume.connect())
        else:
            raise click.UsageError(
                "Abort!\n"
                "No logs to show for local sourced volume without connect."
            )
    except BaseException:
        # Best-effort disconnect on any failure (including Ctrl-C); errors
        # during cleanup are ignored so the original error is the one raised.
        try:
            config.trainml.run(volume.disconnect())
        except BaseException:
            pass
        raise
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@volume.command()
@click.argument("volume", type=click.STRING)
@pass_config
def disconnect(config, volume):
    """
    Disconnect and clean-up volume upload.

    VOLUME may be specified by name or ID, but ID is preferred.
    """
    # Resolve the user-supplied identifier against the full volume listing.
    candidates = config.trainml.run(config.trainml.client.volumes.list())
    target = search_by_id_name(volume, candidates)
    if target is None:
        raise click.UsageError("Cannot find specified volume.")

    outcome = config.trainml.run(target.disconnect())
    return outcome
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@volume.command()
@pass_config
def list(config):
    """List volumes."""
    # Each column is right-aligned to a fixed width and truncated to the
    # precision given in the format spec.
    row_format = "{: >38.36} {: >13.11} {: >40.38} {: >14.12}"
    divider = "-" * 80
    table = [
        ["ID", "STATUS", "NAME", "CAPACITY"],
        [divider, divider, divider, divider],
    ]

    # Collect every row before printing anything, header and divider first.
    table.extend(
        [item.id, item.status, item.name, item.capacity]
        for item in config.trainml.run(config.trainml.client.volumes.list())
    )

    for cells in table:
        click.echo(row_format.format(*cells), file=config.stdout)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@volume.command()
@click.option(
    "--force/--no-force",
    default=False,
    show_default=True,
    help="Force removal.",
)
@click.argument("volume", type=click.STRING)
@pass_config
def remove(config, volume, force):
    """
    Remove a volume.

    VOLUME may be specified by name or ID, but ID is preferred.
    """
    volumes = config.trainml.run(config.trainml.client.volumes.list())

    found = search_by_id_name(volume, volumes)
    if found is None:
        if not force:
            raise click.UsageError("Cannot find specified volume.")
        # The volume is not in the listing, so there is no object to call
        # remove() on. Issue the removal through the collection client using
        # the identifier exactly as the user typed it.
        # NOTE(review): assumes volumes.remove(id) exists with the same shape
        # as the other collection clients' remove(id) - confirm in volumes.py.
        return config.trainml.run(config.trainml.client.volumes.remove(volume))

    return config.trainml.run(found.remove(force=force))
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@volume.command()
@click.argument("volume", type=click.STRING)
@click.argument("name", type=click.STRING)
@pass_config
def rename(config, volume, name):
    """
    Renames a volume.

    VOLUME may be specified by name or ID, but ID is preferred.
    """
    # NOTE(review): the sibling commands resolve VOLUME with
    # search_by_id_name over the listing, while this one calls volumes.get()
    # directly - confirm that get() accepts a name as the help text promises.
    try:
        found = config.trainml.run(config.trainml.client.volumes.get(volume))
    except Exception as err:
        # Only ordinary lookup failures become a usage error. Ctrl-C and
        # SystemExit propagate instead of being reported as "not found".
        raise click.UsageError("Cannot find specified volume.") from err

    if found is None:
        raise click.UsageError("Cannot find specified volume.")

    return config.trainml.run(found.rename(name=name))
|
|
@@ -3,7 +3,7 @@ from .regions import Regions
|
|
|
3
3
|
from .nodes import Nodes
|
|
4
4
|
from .devices import Devices
|
|
5
5
|
from .datastores import Datastores
|
|
6
|
-
from .
|
|
6
|
+
from .services import Services
|
|
7
7
|
from .device_configs import DeviceConfigs
|
|
8
8
|
|
|
9
9
|
|
|
@@ -15,5 +15,5 @@ class Cloudbender(object):
|
|
|
15
15
|
self.nodes = Nodes(trainml)
|
|
16
16
|
self.devices = Devices(trainml)
|
|
17
17
|
self.datastores = Datastores(trainml)
|
|
18
|
-
self.
|
|
18
|
+
self.services = Services(trainml)
|
|
19
19
|
self.device_configs = DeviceConfigs(trainml)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Services(object):
    """Collection client for the per-region ``service`` endpoints."""

    def __init__(self, trainml):
        # Client object used for every request via its ``_query`` coroutine.
        self.trainml = trainml

    async def get(self, provider_uuid, region_uuid, id, **kwargs):
        """Fetch one service by id and wrap the response in a ``Service``."""
        endpoint = f"/provider/{provider_uuid}/region/{region_uuid}/service/{id}"
        record = await self.trainml._query(endpoint, "GET", kwargs)
        return Service(self.trainml, **record)

    async def list(self, provider_uuid, region_uuid, **kwargs):
        """Return a ``Service`` for every entry the region endpoint reports."""
        endpoint = f"/provider/{provider_uuid}/region/{region_uuid}/service"
        records = await self.trainml._query(endpoint, "GET", kwargs)
        return [Service(self.trainml, **record) for record in records]

    async def create(
        self,
        provider_uuid,
        region_uuid,
        name,
        public,
        **kwargs,
    ):
        """
        Create a service in the given region and return it as a ``Service``.

        ``name``, ``public`` and any extra keyword arguments form the request
        body; entries whose value is ``None`` are left out.
        """
        logging.info(f"Creating Service {name}")
        fields = {"name": name, "public": public, **kwargs}
        body = {key: value for key, value in fields.items() if value is not None}
        endpoint = f"/provider/{provider_uuid}/region/{region_uuid}/service"
        record = await self.trainml._query(endpoint, "POST", None, body)
        created = Service(self.trainml, **record)
        logging.info(f"Created Service {name} with id {created.id}")
        return created

    async def remove(self, provider_uuid, region_uuid, id, **kwargs):
        """Issue a DELETE for the service with the given id."""
        endpoint = f"/provider/{provider_uuid}/region/{region_uuid}/service/{id}"
        await self.trainml._query(endpoint, "DELETE", kwargs)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Service:
    """A single service record, built from the keyword data it is given."""

    def __init__(self, trainml, **kwargs):
        self.trainml = trainml
        # Keep the raw record; the typed accessors below read cached copies.
        self._service = kwargs
        record = self._service
        self._id = record.get("service_id")
        self._provider_uuid = record.get("provider_uuid")
        self._region_uuid = record.get("region_uuid")
        self._public = record.get("public")
        self._name = record.get("name")
        self._hostname = record.get("hostname")

    @property
    def id(self) -> str:
        return self._id

    @property
    def provider_uuid(self) -> str:
        return self._provider_uuid

    @property
    def region_uuid(self) -> str:
        return self._region_uuid

    @property
    def public(self) -> bool:
        return self._public

    @property
    def name(self) -> str:
        return self._name

    @property
    def hostname(self) -> str:
        return self._hostname

    def __str__(self):
        # JSON rendering of the raw record.
        return json.dumps(dict(self._service))

    def __repr__(self):
        return f"Service( trainml , **{self._service!r})"

    def __bool__(self):
        # A record without a service id counts as empty/falsy.
        return bool(self._id)

    async def remove(self):
        """Issue a DELETE for this service's own endpoint."""
        endpoint = (
            f"/provider/{self._provider_uuid}"
            f"/region/{self._region_uuid}"
            f"/service/{self._id}"
        )
        await self.trainml._query(endpoint, "DELETE")

    async def refresh(self):
        """Re-fetch this service and re-initialise the instance in place."""
        endpoint = (
            f"/provider/{self._provider_uuid}"
            f"/region/{self._region_uuid}"
            f"/service/{self._id}"
        )
        record = await self.trainml._query(endpoint, "GET")
        self.__init__(self.trainml, **record)
        return self
|
trainml/exceptions.py
CHANGED
|
@@ -97,14 +97,27 @@ class CheckpointError(TrainMLException):
|
|
|
97
97
|
return self._status
|
|
98
98
|
|
|
99
99
|
def __repr__(self):
|
|
100
|
-
return "CheckpointError({self.status}, {self.message})".format(
|
|
101
|
-
self=self
|
|
102
|
-
)
|
|
100
|
+
return "CheckpointError({self.status}, {self.message})".format(self=self)
|
|
103
101
|
|
|
104
102
|
def __str__(self):
|
|
105
|
-
return "CheckpointError({self.status}, {self.message})".format(
|
|
106
|
-
|
|
107
|
-
|
|
103
|
+
return "CheckpointError({self.status}, {self.message})".format(self=self)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class VolumeError(TrainMLException):
    """Error carrying a volume status alongside the message data."""

    def __init__(self, status, data, *args):
        super().__init__(data, *args)
        self._status = status
        self._message = data

    @property
    def status(self) -> str:
        return self._status

    def __repr__(self):
        return f"VolumeError({self.status}, {self.message})"

    def __str__(self):
        # Same rendering as __repr__.
        return f"VolumeError({self.status}, {self.message})"
|
|
108
121
|
|
|
109
122
|
|
|
110
123
|
class ConnectionError(TrainMLException):
|
|
@@ -130,11 +143,7 @@ class SpecificationError(TrainMLException):
|
|
|
130
143
|
return self._attribute
|
|
131
144
|
|
|
132
145
|
def __repr__(self):
|
|
133
|
-
return "SpecificationError({self.attribute}, {self.message})".format(
|
|
134
|
-
self=self
|
|
135
|
-
)
|
|
146
|
+
return "SpecificationError({self.attribute}, {self.message})".format(self=self)
|
|
136
147
|
|
|
137
148
|
def __str__(self):
|
|
138
|
-
return "SpecificationError({self.attribute}, {self.message})".format(
|
|
139
|
-
self=self
|
|
140
|
-
)
|
|
149
|
+
return "SpecificationError({self.attribute}, {self.message})".format(self=self)
|
trainml/jobs.py
CHANGED
|
@@ -77,8 +77,7 @@ class Jobs(object):
|
|
|
77
77
|
model=model,
|
|
78
78
|
endpoint=endpoint,
|
|
79
79
|
source_job_uuid=kwargs.get("source_job_uuid"),
|
|
80
|
-
project_uuid=kwargs.get("project_uuid")
|
|
81
|
-
or self.trainml.active_project,
|
|
80
|
+
project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
|
|
82
81
|
)
|
|
83
82
|
payload = {
|
|
84
83
|
k: v
|
|
@@ -103,9 +102,7 @@ class Jobs(object):
|
|
|
103
102
|
return job
|
|
104
103
|
|
|
105
104
|
async def remove(self, id, **kwargs):
|
|
106
|
-
await self.trainml._query(
|
|
107
|
-
f"/job/{id}", "DELETE", dict(**kwargs, force=True)
|
|
108
|
-
)
|
|
105
|
+
await self.trainml._query(f"/job/{id}", "DELETE", dict(**kwargs, force=True))
|
|
109
106
|
|
|
110
107
|
|
|
111
108
|
class Job:
|
|
@@ -308,18 +305,26 @@ class Job:
|
|
|
308
305
|
entity_type="job",
|
|
309
306
|
project_uuid=self._job.get("project_uuid"),
|
|
310
307
|
cidr=self.dict.get("vpn").get("cidr"),
|
|
311
|
-
ssh_port=
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
308
|
+
ssh_port=(
|
|
309
|
+
self._job.get("vpn").get("client").get("ssh_port")
|
|
310
|
+
if self._job.get("vpn").get("client")
|
|
311
|
+
else None
|
|
312
|
+
),
|
|
313
|
+
model_path=(
|
|
314
|
+
self._job.get("model").get("source_uri")
|
|
315
|
+
if self._job.get("model").get("source_type") == "local"
|
|
316
|
+
else None
|
|
317
|
+
),
|
|
318
|
+
input_path=(
|
|
319
|
+
self._job.get("data").get("input_uri")
|
|
320
|
+
if self._job.get("data").get("input_type") == "local"
|
|
321
|
+
else None
|
|
322
|
+
),
|
|
323
|
+
output_path=(
|
|
324
|
+
self._job.get("data").get("output_uri")
|
|
325
|
+
if self._job.get("data").get("output_type") == "local"
|
|
326
|
+
else None
|
|
327
|
+
),
|
|
323
328
|
)
|
|
324
329
|
return details
|
|
325
330
|
|
|
@@ -396,8 +401,7 @@ class Job:
|
|
|
396
401
|
|
|
397
402
|
def _get_msg_handler(self, msg_handler):
|
|
398
403
|
worker_numbers = {
|
|
399
|
-
w.get("job_worker_uuid"): ind + 1
|
|
400
|
-
for ind, w in enumerate(self._workers)
|
|
404
|
+
w.get("job_worker_uuid"): ind + 1 for ind, w in enumerate(self._workers)
|
|
401
405
|
}
|
|
402
406
|
worker_numbers["data_worker"] = 0
|
|
403
407
|
|
|
@@ -407,9 +411,7 @@ class Job:
|
|
|
407
411
|
if msg_handler:
|
|
408
412
|
msg_handler(data)
|
|
409
413
|
else:
|
|
410
|
-
timestamp = datetime.fromtimestamp(
|
|
411
|
-
int(data.get("time")) / 1000
|
|
412
|
-
)
|
|
414
|
+
timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
|
|
413
415
|
if len(self._workers) > 1:
|
|
414
416
|
print(
|
|
415
417
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: Worker {data.get('worker_number')} - {data.get('msg').rstrip()}"
|
|
@@ -422,10 +424,7 @@ class Job:
|
|
|
422
424
|
return handler
|
|
423
425
|
|
|
424
426
|
async def attach(self, msg_handler=None):
|
|
425
|
-
if
|
|
426
|
-
self.type == "notebook"
|
|
427
|
-
and self.status != "waiting for data/model download"
|
|
428
|
-
):
|
|
427
|
+
if self.type == "notebook" and self.status != "waiting for data/model download":
|
|
429
428
|
raise SpecificationError(
|
|
430
429
|
"type",
|
|
431
430
|
"Notebooks cannot be attached to after model download is complete. Use open() instead.",
|
|
@@ -442,9 +441,7 @@ class Job:
|
|
|
442
441
|
async def copy(self, name, **kwargs):
|
|
443
442
|
logging.debug(f"copy request - name: {name} ; kwargs: {kwargs}")
|
|
444
443
|
if self.type != "notebook":
|
|
445
|
-
raise SpecificationError(
|
|
446
|
-
"job", "Only notebook job types can be copied"
|
|
447
|
-
)
|
|
444
|
+
raise SpecificationError("job", "Only notebook job types can be copied")
|
|
448
445
|
|
|
449
446
|
job = await self.trainml.jobs.create(
|
|
450
447
|
name,
|
|
@@ -504,9 +501,7 @@ class Job:
|
|
|
504
501
|
|
|
505
502
|
POLL_INTERVAL_MIN = 5
|
|
506
503
|
POLL_INTERVAL_MAX = 60
|
|
507
|
-
POLL_INTERVAL = max(
|
|
508
|
-
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
509
|
-
)
|
|
504
|
+
POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
|
|
510
505
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
511
506
|
count = 0
|
|
512
507
|
while count < retry_count:
|
|
@@ -519,23 +514,25 @@ class Job:
|
|
|
519
514
|
raise e
|
|
520
515
|
if (
|
|
521
516
|
self.status == status
|
|
522
|
-
or (
|
|
523
|
-
self.type == "training"
|
|
524
|
-
and status == "finished"
|
|
525
|
-
and self.status == "stopped"
|
|
526
|
-
)
|
|
527
517
|
or (
|
|
528
518
|
status
|
|
529
519
|
in [
|
|
530
520
|
"waiting for GPUs",
|
|
531
521
|
"waiting for resources",
|
|
532
522
|
] ## this status could be very short and the polling could miss it
|
|
533
|
-
and self.status
|
|
523
|
+
and self.status
|
|
524
|
+
not in ["new", "waiting for GPUs", "waiting for resources"]
|
|
534
525
|
)
|
|
535
526
|
or (
|
|
536
527
|
status
|
|
537
528
|
== "waiting for data/model download" ## this status could be very short and the polling could miss it
|
|
538
|
-
and self.status
|
|
529
|
+
and self.status
|
|
530
|
+
not in [
|
|
531
|
+
"new",
|
|
532
|
+
"waiting for GPUs",
|
|
533
|
+
"waiting for resources",
|
|
534
|
+
"waiting for data/model download",
|
|
535
|
+
]
|
|
539
536
|
)
|
|
540
537
|
):
|
|
541
538
|
return self
|
trainml/projects.py
CHANGED
|
@@ -72,17 +72,17 @@ class ProjectDatastore:
|
|
|
72
72
|
return bool(self._id)
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
class
|
|
75
|
+
class ProjectService:
|
|
76
76
|
def __init__(self, trainml, **kwargs):
|
|
77
77
|
self.trainml = trainml
|
|
78
|
-
self.
|
|
79
|
-
self._id = self.
|
|
80
|
-
self._project_uuid = self.
|
|
81
|
-
self._name = self.
|
|
82
|
-
self._type = self.
|
|
83
|
-
self._hostname = self.
|
|
84
|
-
self._resource = self.
|
|
85
|
-
self._region_uuid = self.
|
|
78
|
+
self._service = kwargs
|
|
79
|
+
self._id = self._service.get("id")
|
|
80
|
+
self._project_uuid = self._service.get("project_uuid")
|
|
81
|
+
self._name = self._service.get("name")
|
|
82
|
+
self._type = self._service.get("type")
|
|
83
|
+
self._hostname = self._service.get("hostname")
|
|
84
|
+
self._resource = self._service.get("resource")
|
|
85
|
+
self._region_uuid = self._service.get("region_uuid")
|
|
86
86
|
|
|
87
87
|
@property
|
|
88
88
|
def id(self) -> str:
|
|
@@ -113,12 +113,10 @@ class ProjectReservation:
|
|
|
113
113
|
return self._region_uuid
|
|
114
114
|
|
|
115
115
|
def __str__(self):
|
|
116
|
-
return json.dumps({k: v for k, v in self.
|
|
116
|
+
return json.dumps({k: v for k, v in self._service.items()})
|
|
117
117
|
|
|
118
118
|
def __repr__(self):
|
|
119
|
-
return (
|
|
120
|
-
f"ProjectReservation( trainml , **{self._reservation.__repr__()})"
|
|
121
|
-
)
|
|
119
|
+
return f"ProjectService( trainml , **{self._service.__repr__()})"
|
|
122
120
|
|
|
123
121
|
def __bool__(self):
|
|
124
122
|
return bool(self._id)
|
|
@@ -162,26 +160,17 @@ class Project:
|
|
|
162
160
|
await self.trainml._query(f"/project/{self._id}", "DELETE")
|
|
163
161
|
|
|
164
162
|
async def list_datastores(self):
|
|
165
|
-
resp = await self.trainml._query(
|
|
166
|
-
|
|
167
|
-
)
|
|
168
|
-
datastores = [
|
|
169
|
-
ProjectDatastore(self.trainml, **datastore) for datastore in resp
|
|
170
|
-
]
|
|
163
|
+
resp = await self.trainml._query(f"/project/{self._id}/datastores", "GET")
|
|
164
|
+
datastores = [ProjectDatastore(self.trainml, **datastore) for datastore in resp]
|
|
171
165
|
return datastores
|
|
172
166
|
|
|
173
|
-
async def
|
|
174
|
-
resp = await self.trainml._query(
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
reservations = [
|
|
178
|
-
ProjectReservation(self.trainml, **reservation)
|
|
179
|
-
for reservation in resp
|
|
180
|
-
]
|
|
181
|
-
return reservations
|
|
167
|
+
async def list_services(self):
|
|
168
|
+
resp = await self.trainml._query(f"/project/{self._id}/services", "GET")
|
|
169
|
+
services = [ProjectService(self.trainml, **service) for service in resp]
|
|
170
|
+
return services
|
|
182
171
|
|
|
183
172
|
async def refresh_datastores(self):
|
|
184
173
|
await self.trainml._query(f"/project/{self._id}/datastores", "PATCH")
|
|
185
174
|
|
|
186
|
-
async def
|
|
187
|
-
await self.trainml._query(f"/project/{self._id}/
|
|
175
|
+
async def refresh_services(self):
|
|
176
|
+
await self.trainml._query(f"/project/{self._id}/services", "PATCH")
|