trainml 0.5.9__py3-none-any.whl → 0.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/projects/conftest.py +3 -1
- tests/integration/projects/test_projects_credentials_integration.py +45 -0
- tests/integration/projects/test_projects_data_connectors_integration.py +44 -0
- tests/integration/projects/test_projects_datastores_integration.py +42 -0
- tests/integration/projects/test_projects_secrets_integration.py +1 -1
- tests/integration/projects/test_projects_services_integration.py +44 -0
- tests/integration/test_checkpoints_integration.py +1 -2
- tests/integration/test_models_integration.py +0 -1
- tests/unit/cli/projects/__init__.py +0 -0
- tests/unit/cli/projects/test_cli_project_credential_unit.py +26 -0
- tests/unit/cli/projects/test_cli_project_data_connector_unit.py +28 -0
- tests/unit/cli/projects/test_cli_project_datastore_unit.py +26 -0
- tests/unit/cli/projects/test_cli_project_key_unit.py +26 -0
- tests/unit/cli/projects/test_cli_project_secret_unit.py +26 -0
- tests/unit/cli/projects/test_cli_project_service_unit.py +26 -0
- tests/unit/cli/projects/test_cli_project_unit.py +19 -0
- tests/unit/cloudbender/test_datastores_unit.py +1 -5
- tests/unit/conftest.py +79 -6
- tests/unit/projects/test_project_credentials_unit.py +100 -0
- tests/unit/projects/test_projects_unit.py +1 -1
- tests/unit/test_checkpoints_unit.py +15 -23
- tests/unit/test_datasets_unit.py +15 -20
- tests/unit/test_models_unit.py +13 -16
- tests/unit/test_volumes_unit.py +3 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +14 -3
- trainml/cli/cloudbender/datastore.py +2 -7
- trainml/cli/job/create.py +16 -16
- trainml/cli/project/__init__.py +4 -73
- trainml/cli/project/credential.py +128 -0
- trainml/cli/project/data_connector.py +61 -0
- trainml/cli/project/datastore.py +61 -0
- trainml/cli/project/secret.py +12 -3
- trainml/cli/project/service.py +61 -0
- trainml/cloudbender/data_connectors.py +8 -0
- trainml/cloudbender/datastores.py +9 -19
- trainml/cloudbender/nodes.py +44 -1
- trainml/cloudbender/providers.py +53 -0
- trainml/cloudbender/regions.py +48 -0
- trainml/datasets.py +14 -3
- trainml/exceptions.py +51 -0
- trainml/jobs.py +2 -13
- trainml/models.py +14 -3
- trainml/projects/credentials.py +71 -0
- trainml/projects/projects.py +7 -4
- trainml/projects/secrets.py +1 -1
- trainml/volumes.py +15 -3
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/METADATA +1 -1
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/RECORD +53 -46
- tests/integration/test_projects_integration.py +0 -44
- tests/unit/cli/cloudbender/test_cli_reservation_unit.py +0 -34
- tests/unit/cli/test_cli_project_unit.py +0 -42
- tests/unit/cloudbender/test_reservations_unit.py +0 -173
- tests/unit/test_auth.py +0 -30
- tests/unit/test_projects_unit.py +0 -320
- tests/unit/test_trainml.py +0 -54
- trainml/cli/cloudbender/reservation.py +0 -159
- trainml/cli/project.py +0 -149
- trainml/cloudbender/reservations.py +0 -126
- trainml/projects.py +0 -228
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/LICENSE +0 -0
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/WHEEL +0 -0
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.9.dist-info → trainml-0.5.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import base64
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from trainml.cli import pass_config
|
|
7
|
+
from trainml.cli.project import project
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@project.group()
@pass_config
def credential(config):
    """trainML project credential commands."""
    # Container group only: the docstring is the entire body and
    # subcommands attach themselves via @credential.command().
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@credential.command()
@pass_config
def list(config):
    """List credentials."""
    # Fixed-width, right-aligned columns; each field is truncated to fit.
    row_format = "{: >13.11} {: >37.35} {: >28.26}"
    divider = "-" * 80
    rows = [["TYPE", "KEY ID", "UPDATED AT"], [divider] * 3]

    current_project = config.trainml.run(
        config.trainml.client.projects.get_current()
    )
    stored = config.trainml.run(current_project.credentials.list())

    # `cred` avoids shadowing the `credential` click group inside this body.
    rows.extend(
        [cred.type, cred.key_id, cred.updated_at.isoformat(timespec="seconds")]
        for cred in stored
    )

    for row in rows:
        click.echo(row_format.format(*row), file=config.stdout)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@credential.command()
@click.argument(
    "type",
    type=click.Choice(
        [
            "aws",
            "azure",
            "docker",
            "gcp",
            "huggingface",
            "kaggle",
            "ngc",
            "wasabi",
        ],
        case_sensitive=False,
    ),
)
@pass_config
def put(config, type):
    """
    Set a credential.

    A credential is uploaded.
    """
    # NOTE: `type` shadows the builtin, but click binds the CLI argument by
    # this parameter name, so it is kept for compatibility.
    project = config.trainml.run(config.trainml.client.projects.get_current())

    # Only the Azure branch collects a tenant; every other type sends None.
    tenant = None

    if type in ["aws", "wasabi"]:
        credential_id = click.prompt(
            "Enter the credential ID", type=str, hide_input=False
        )
        secret = click.prompt("Enter the secret credential", type=str, hide_input=True)
    elif type == "azure":
        credential_id = click.prompt(
            "Enter the Application (client) ID", type=str, hide_input=False
        )
        tenant = click.prompt(
            "Enter the Directory (tenant) ID", type=str, hide_input=False
        )
        secret = click.prompt("Enter the client secret", type=str, hide_input=True)
    elif type in ["docker", "huggingface"]:
        credential_id = click.prompt("Enter the username", type=str, hide_input=False)
        secret = click.prompt("Enter the access token", type=str, hide_input=True)
    elif type in ["gcp", "kaggle"]:
        file_name = click.prompt(
            "Enter the path of the credentials file",
            type=click.Path(
                exists=True, file_okay=True, dir_okay=False, resolve_path=True
            ),
            hide_input=False,
        )
        credential_id = os.path.basename(file_name)
        # Read raw bytes so json detects UTF-8/16/32 (including a UTF-8 BOM)
        # instead of depending on the platform's default text encoding.
        # Parsing also validates the file before it is re-serialized.
        with open(file_name, "rb") as f:
            secret = json.dumps(json.load(f))
    elif type == "ngc":
        # NGC uses a fixed credential ID; only the token is prompted for.
        credential_id = "$oauthtoken"
        secret = click.prompt("Enter the access token", type=str, hide_input=True)
    else:
        # Defensive: click.Choice should already reject anything else.
        raise click.UsageError("Unsupported credential type")

    return config.trainml.run(
        project.credentials.put(
            type=type, credential_id=credential_id, secret=secret, tenant=tenant
        )
    )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@credential.command()
@click.argument("name", type=click.STRING)
@pass_config
def remove(config, name):
    """
    Remove a credential.
    """
    project = config.trainml.run(config.trainml.client.projects.get_current())

    # Use the same `credentials` collection that the `list` and `put`
    # commands operate on.
    return config.trainml.run(project.credentials.remove(name))
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import base64
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from trainml.cli import pass_config
|
|
7
|
+
from trainml.cli.project import project
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@project.group()
@pass_config
def data_connector(config):
    """trainML project data_connector commands."""
    # Container group only: the docstring is the entire body and
    # subcommands attach themselves via @data_connector.command().
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@data_connector.command()
@pass_config
def list(config):
    """List project data_connectors."""
    # Fixed-width, right-aligned columns; each field is truncated to fit.
    row_format = "{: >38.36} {: >30.28} {: >15.13} {: >38.36}"
    table = [["ID", "NAME", "TYPE", "REGION_UUID"], ["-" * 80] * 4]

    current_project = config.trainml.run(
        config.trainml.client.projects.get(config.trainml.client.project)
    )
    connectors = config.trainml.run(current_project.data_connectors.list())

    table += [
        [conn.id, conn.name, conn.type, conn.region_uuid] for conn in connectors
    ]

    for row in table:
        click.echo(row_format.format(*row), file=config.stdout)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@data_connector.command()
@pass_config
def refresh(config):
    """
    Refresh project data_connector list.
    """
    run = config.trainml.run
    current_project = run(config.trainml.client.projects.get_current())
    return run(current_project.data_connectors.refresh())
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import base64
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from trainml.cli import pass_config
|
|
7
|
+
from trainml.cli.project import project
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@project.group()
@pass_config
def datastore(config):
    """trainML project datastore commands."""
    # Container group only: the docstring is the entire body and
    # subcommands attach themselves via @datastore.command().
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@datastore.command()
@pass_config
def list(config):
    """List project datastores."""
    # Fixed-width, right-aligned columns; each field is truncated to fit.
    row_format = "{: >38.36} {: >30.28} {: >15.13} {: >38.36}"
    table = [["ID", "NAME", "TYPE", "REGION_UUID"], ["-" * 80] * 4]

    current_project = config.trainml.run(
        config.trainml.client.projects.get(config.trainml.client.project)
    )
    stores = config.trainml.run(current_project.datastores.list())

    table += [[ds.id, ds.name, ds.type, ds.region_uuid] for ds in stores]

    for row in table:
        click.echo(row_format.format(*row), file=config.stdout)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@datastore.command()
@pass_config
def refresh(config):
    """
    Refresh project datastore list.
    """
    run = config.trainml.run
    current_project = run(config.trainml.client.projects.get_current())
    return run(current_project.datastores.refresh())
|
trainml/cli/project/secret.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import click
|
|
2
|
+
import os
|
|
2
3
|
from trainml.cli import pass_config
|
|
3
4
|
from trainml.cli.project import project
|
|
4
5
|
|
|
@@ -42,17 +43,25 @@ def list(config):
|
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
@secret.command()
@click.option(
    "--file",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
    help="Load the secret value from the file at the provided path",
)
@click.argument("name", type=click.STRING)
@pass_config
def put(config, file, name):
    """
    Set a secret value.

    Secret is created with the specified NAME.
    """
    current_project = config.trainml.run(
        config.trainml.client.projects.get_current()
    )

    # Without --file, fall back to an interactive prompt that hides the input.
    if not file:
        value = click.prompt("Enter the secret value", type=str, hide_input=True)
    else:
        with open(os.path.expanduser(file)) as handle:
            value = handle.read()

    return config.trainml.run(current_project.secrets.put(name=name, value=value))
|
|
58
67
|
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import base64
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from trainml.cli import pass_config
|
|
7
|
+
from trainml.cli.project import project
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@project.group()
@pass_config
def service(config):
    """trainML project service commands."""
    # Container group only: the docstring is the entire body and
    # subcommands attach themselves via @service.command().
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@service.command()
@pass_config
def list(config):
    """List project services."""
    # The third column shows each service's hostname, so the header says so
    # (it previously read "TYPE", which did not match the row contents).
    data = [
        ["ID", "NAME", "HOSTNAME", "REGION_UUID"],
        [
            "-" * 80,
            "-" * 80,
            "-" * 80,
            "-" * 80,
        ],
    ]
    project = config.trainml.run(
        config.trainml.client.projects.get(config.trainml.client.project)
    )

    services = config.trainml.run(project.services.list())

    for svc in services:
        data.append(
            [
                svc.id,
                svc.name,
                svc.hostname,
                svc.region_uuid,
            ]
        )

    # Fixed-width, right-aligned columns; each field is truncated to fit.
    for row in data:
        click.echo(
            "{: >38.36} {: >30.28} {: >15.13} {: >38.36}".format(*row),
            file=config.stdout,
        )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@service.command()
@pass_config
def refresh(config):
    """
    Refresh project service list.
    """
    run = config.trainml.run
    current_project = run(config.trainml.client.projects.get_current())
    return run(current_project.services.refresh())
|
|
@@ -1,5 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from trainml.exceptions import (
|
|
7
|
+
ApiError,
|
|
8
|
+
SpecificationError,
|
|
9
|
+
TrainMLException,
|
|
10
|
+
)
|
|
3
11
|
|
|
4
12
|
|
|
5
13
|
class Datastores(object):
|
|
@@ -20,9 +28,7 @@ class Datastores(object):
|
|
|
20
28
|
"GET",
|
|
21
29
|
kwargs,
|
|
22
30
|
)
|
|
23
|
-
datastores = [
|
|
24
|
-
Datastore(self.trainml, **datastore) for datastore in resp
|
|
25
|
-
]
|
|
31
|
+
datastores = [Datastore(self.trainml, **datastore) for datastore in resp]
|
|
26
32
|
return datastores
|
|
27
33
|
|
|
28
34
|
async def create(
|
|
@@ -31,18 +37,12 @@ class Datastores(object):
|
|
|
31
37
|
region_uuid,
|
|
32
38
|
name,
|
|
33
39
|
type,
|
|
34
|
-
uri,
|
|
35
|
-
root,
|
|
36
|
-
options=None,
|
|
37
40
|
**kwargs,
|
|
38
41
|
):
|
|
39
42
|
logging.info(f"Creating Datastore {name}")
|
|
40
43
|
data = dict(
|
|
41
44
|
name=name,
|
|
42
45
|
type=type,
|
|
43
|
-
uri=uri,
|
|
44
|
-
root=root,
|
|
45
|
-
options=options,
|
|
46
46
|
**kwargs,
|
|
47
47
|
)
|
|
48
48
|
payload = {k: v for k, v in data.items() if v is not None}
|
|
@@ -73,8 +73,6 @@ class Datastore:
|
|
|
73
73
|
self._region_uuid = self._datastore.get("region_uuid")
|
|
74
74
|
self._type = self._datastore.get("type")
|
|
75
75
|
self._name = self._datastore.get("name")
|
|
76
|
-
self._uri = self._datastore.get("uri")
|
|
77
|
-
self._root = self._datastore.get("root")
|
|
78
76
|
|
|
79
77
|
@property
|
|
80
78
|
def id(self) -> str:
|
|
@@ -96,14 +94,6 @@ class Datastore:
|
|
|
96
94
|
def name(self) -> str:
|
|
97
95
|
return self._name
|
|
98
96
|
|
|
99
|
-
@property
|
|
100
|
-
def uri(self) -> str:
|
|
101
|
-
return self._uri
|
|
102
|
-
|
|
103
|
-
@property
|
|
104
|
-
def root(self) -> str:
|
|
105
|
-
return self._root
|
|
106
|
-
|
|
107
97
|
def __str__(self):
|
|
108
98
|
return json.dumps({k: v for k, v in self._datastore.items()})
|
|
109
99
|
|
trainml/cloudbender/nodes.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from trainml.exceptions import ApiError, SpecificationError, TrainMLException, NodeError
|
|
3
7
|
|
|
4
8
|
|
|
5
9
|
class Nodes(object):
|
|
@@ -29,7 +33,7 @@ class Nodes(object):
|
|
|
29
33
|
region_uuid,
|
|
30
34
|
friendly_name,
|
|
31
35
|
hostname,
|
|
32
|
-
minion_id,
|
|
36
|
+
minion_id=None,
|
|
33
37
|
type="permanent",
|
|
34
38
|
service="compute",
|
|
35
39
|
**kwargs,
|
|
@@ -153,3 +157,42 @@ class Node:
|
|
|
153
157
|
None,
|
|
154
158
|
dict(command=command),
|
|
155
159
|
)
|
|
160
|
+
|
|
161
|
+
async def wait_for(self, status, timeout=300):
    """Poll this node until it reports ``status``.

    Args:
        status: Target status; one of "active", "maintenance", "offline",
            "stopped", or "archived".
        timeout: Maximum number of seconds to wait (at most 24 hours).

    Returns:
        This node once it reports ``status`` (including when it already does
        on entry). Returns ``None`` when waiting for "archived" and the
        refresh fails with HTTP 404.

    Raises:
        SpecificationError: If ``status`` is not a valid target or
            ``timeout`` exceeds 24 hours.
        NodeError: If the node reports "errored" or "failed".
        ApiError: For refresh failures other than the archived/404 case.
        TrainMLException: If ``status`` is not reached within ``timeout``.
    """
    # Return self here too, so callers get the same value as the polling path.
    if self.status == status:
        return self
    valid_statuses = ["active", "maintenance", "offline", "stopped", "archived"]
    if status not in valid_statuses:
        raise SpecificationError(
            "status",
            f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
        )
    MAX_TIMEOUT = 24 * 60 * 60
    if timeout > MAX_TIMEOUT:
        raise SpecificationError(
            "timeout",
            f"timeout must be less than {MAX_TIMEOUT} seconds.",
        )

    POLL_INTERVAL_MIN = 5
    POLL_INTERVAL_MAX = 60
    # Spread roughly 60 polls over the timeout, clamped to [5, 60] seconds.
    POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
    retry_count = math.ceil(timeout / POLL_INTERVAL)
    for attempt in range(1, retry_count + 1):
        await asyncio.sleep(POLL_INTERVAL)
        try:
            await self.refresh()
        except ApiError as e:
            # When waiting for archival, a 404 is treated as completion.
            if status == "archived" and e.status == 404:
                return
            raise
        if self.status in ["errored", "failed"]:
            raise NodeError(self.status, self)
        if self.status == status:
            return self
        logging.debug(f"self: {self}, retry count {attempt}")

    raise TrainMLException(f"Timeout waiting for {status}")
|
trainml/cloudbender/providers.py
CHANGED
|
@@ -1,7 +1,16 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
import math
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
|
|
7
|
+
from trainml.exceptions import (
|
|
8
|
+
ApiError,
|
|
9
|
+
SpecificationError,
|
|
10
|
+
TrainMLException,
|
|
11
|
+
ProviderError,
|
|
12
|
+
)
|
|
13
|
+
|
|
5
14
|
|
|
6
15
|
class Providers(object):
|
|
7
16
|
def __init__(self, trainml):
|
|
@@ -36,6 +45,7 @@ class Provider:
|
|
|
36
45
|
self._provider = kwargs
|
|
37
46
|
self._id = self._provider.get("provider_uuid")
|
|
38
47
|
self._type = self._provider.get("type")
|
|
48
|
+
self._status = self._provider.get("status")
|
|
39
49
|
self._credits = self._provider.get("credits")
|
|
40
50
|
|
|
41
51
|
@property
|
|
@@ -46,6 +56,10 @@ class Provider:
|
|
|
46
56
|
def type(self) -> str:
|
|
47
57
|
return self._type
|
|
48
58
|
|
|
59
|
+
@property
def status(self) -> str:
    """The provider's ``status`` value from the most recently loaded payload."""
    current = self._status
    return current
|
|
62
|
+
|
|
49
63
|
@property
|
|
50
64
|
def credits(self) -> float:
|
|
51
65
|
return self._credits
|
|
@@ -69,3 +83,42 @@ class Provider:
|
|
|
69
83
|
)
|
|
70
84
|
self.__init__(self.trainml, **resp)
|
|
71
85
|
return self
|
|
86
|
+
|
|
87
|
+
async def wait_for(self, status, timeout=300):
    """Poll this provider until it reports ``status``.

    Args:
        status: Target status; one of "ready" or "archived".
        timeout: Maximum number of seconds to wait (at most 24 hours).

    Returns:
        This provider once it reports ``status`` (including when it already
        does on entry). Returns ``None`` when waiting for "archived" and the
        refresh fails with HTTP 404.

    Raises:
        SpecificationError: If ``status`` is not a valid target or
            ``timeout`` exceeds 24 hours.
        ProviderError: If the provider reports "errored" or "failed".
        ApiError: For refresh failures other than the archived/404 case.
        TrainMLException: If ``status`` is not reached within ``timeout``.
    """
    # Return self here too, so callers get the same value as the polling path.
    if self.status == status:
        return self
    valid_statuses = ["ready", "archived"]
    if status not in valid_statuses:
        raise SpecificationError(
            "status",
            f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
        )
    MAX_TIMEOUT = 24 * 60 * 60
    if timeout > MAX_TIMEOUT:
        raise SpecificationError(
            "timeout",
            f"timeout must be less than {MAX_TIMEOUT} seconds.",
        )

    POLL_INTERVAL_MIN = 5
    POLL_INTERVAL_MAX = 60
    # Spread roughly 60 polls over the timeout, clamped to [5, 60] seconds.
    POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
    retry_count = math.ceil(timeout / POLL_INTERVAL)
    for attempt in range(1, retry_count + 1):
        await asyncio.sleep(POLL_INTERVAL)
        try:
            await self.refresh()
        except ApiError as e:
            # When waiting for archival, a 404 is treated as completion.
            if status == "archived" and e.status == 404:
                return
            raise
        if self.status in ["errored", "failed"]:
            raise ProviderError(self.status, self)
        if self.status == status:
            return self
        logging.debug(f"self: {self}, retry count {attempt}")

    raise TrainMLException(f"Timeout waiting for {status}")
|
trainml/cloudbender/regions.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from trainml.exceptions import (
|
|
7
|
+
ApiError,
|
|
8
|
+
SpecificationError,
|
|
9
|
+
TrainMLException,
|
|
10
|
+
RegionError,
|
|
11
|
+
)
|
|
3
12
|
|
|
4
13
|
|
|
5
14
|
class Regions(object):
|
|
@@ -111,3 +120,42 @@ class Region:
|
|
|
111
120
|
None,
|
|
112
121
|
dict(project_uuid=project_uuid, checkpoint_uuid=checkpoint_uuid),
|
|
113
122
|
)
|
|
123
|
+
|
|
124
|
+
async def wait_for(self, status, timeout=300):
    """Poll this region until it reports ``status``.

    Args:
        status: Target status; one of "healthy", "offline", or "archived".
        timeout: Maximum number of seconds to wait (at most 24 hours).

    Returns:
        This region once it reports ``status`` (including when it already
        does on entry). Returns ``None`` when waiting for "archived" and the
        refresh fails with HTTP 404.

    Raises:
        SpecificationError: If ``status`` is not a valid target or
            ``timeout`` exceeds 24 hours.
        RegionError: If the region reports "errored" or "failed".
        ApiError: For refresh failures other than the archived/404 case.
        TrainMLException: If ``status`` is not reached within ``timeout``.
    """
    # Return self here too, so callers get the same value as the polling path.
    if self.status == status:
        return self
    valid_statuses = ["healthy", "offline", "archived"]
    if status not in valid_statuses:
        raise SpecificationError(
            "status",
            f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
        )
    MAX_TIMEOUT = 24 * 60 * 60
    if timeout > MAX_TIMEOUT:
        raise SpecificationError(
            "timeout",
            f"timeout must be less than {MAX_TIMEOUT} seconds.",
        )

    POLL_INTERVAL_MIN = 5
    POLL_INTERVAL_MAX = 60
    # Spread roughly 60 polls over the timeout, clamped to [5, 60] seconds.
    POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
    retry_count = math.ceil(timeout / POLL_INTERVAL)
    for attempt in range(1, retry_count + 1):
        await asyncio.sleep(POLL_INTERVAL)
        try:
            await self.refresh()
        except ApiError as e:
            # When waiting for archival, a 404 is treated as completion.
            if status == "archived" and e.status == 404:
                return
            raise
        if self.status in ["errored", "failed"]:
            raise RegionError(self.status, self)
        if self.status == status:
            return self
        logging.debug(f"self: {self}, retry count {attempt}")

    raise TrainMLException(f"Timeout waiting for {status}")
|
trainml/datasets.py
CHANGED
|
@@ -31,13 +31,24 @@ class Datasets(object):
|
|
|
31
31
|
datasets = [Dataset(self.trainml, **dataset) for dataset in resp]
|
|
32
32
|
return datasets
|
|
33
33
|
|
|
34
|
-
async def create(
|
|
34
|
+
async def create(
|
|
35
|
+
self,
|
|
36
|
+
name,
|
|
37
|
+
source_type,
|
|
38
|
+
source_uri,
|
|
39
|
+
type="evefs",
|
|
40
|
+
project_uuid=None,
|
|
41
|
+
**kwargs,
|
|
42
|
+
):
|
|
43
|
+
if not project_uuid:
|
|
44
|
+
project_uuid = self.trainml.active_project
|
|
35
45
|
data = dict(
|
|
36
46
|
name=name,
|
|
37
47
|
source_type=source_type,
|
|
38
48
|
source_uri=source_uri,
|
|
39
|
-
|
|
40
|
-
|
|
49
|
+
project_uuid=project_uuid,
|
|
50
|
+
type=type,
|
|
51
|
+
**kwargs,
|
|
41
52
|
)
|
|
42
53
|
payload = {k: v for k, v in data.items() if v is not None}
|
|
43
54
|
logging.info(f"Creating Dataset {name}")
|