truefoundry 0.5.3rc4__py3-none-any.whl → 0.5.3rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/__init__.py +10 -1
- truefoundry/autodeploy/cli.py +2 -2
- truefoundry/cli/__main__.py +0 -4
- truefoundry/cli/util.py +12 -3
- truefoundry/common/auth_service_client.py +7 -4
- truefoundry/common/constants.py +3 -1
- truefoundry/common/credential_provider.py +7 -8
- truefoundry/common/exceptions.py +11 -7
- truefoundry/common/request_utils.py +96 -14
- truefoundry/common/servicefoundry_client.py +31 -29
- truefoundry/common/session.py +93 -0
- truefoundry/common/storage_provider_utils.py +331 -0
- truefoundry/common/utils.py +9 -9
- truefoundry/common/warnings.py +21 -0
- truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +8 -20
- truefoundry/deploy/cli/commands/deploy_command.py +4 -4
- truefoundry/deploy/lib/clients/servicefoundry_client.py +13 -14
- truefoundry/deploy/lib/dao/application.py +2 -2
- truefoundry/deploy/lib/dao/workspace.py +1 -1
- truefoundry/deploy/lib/session.py +1 -1
- truefoundry/deploy/v2/lib/deploy.py +2 -2
- truefoundry/deploy/v2/lib/deploy_workflow.py +1 -1
- truefoundry/deploy/v2/lib/patched_models.py +70 -4
- truefoundry/deploy/v2/lib/source.py +2 -1
- truefoundry/gateway/cli/cli.py +1 -22
- truefoundry/gateway/lib/entities.py +3 -8
- truefoundry/gateway/lib/models.py +0 -38
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +33 -297
- truefoundry/ml/clients/servicefoundry_client.py +36 -15
- truefoundry/ml/exceptions.py +2 -1
- truefoundry/ml/log_types/artifacts/artifact.py +3 -2
- truefoundry/ml/log_types/artifacts/model.py +6 -5
- truefoundry/ml/log_types/artifacts/utils.py +2 -2
- truefoundry/ml/mlfoundry_api.py +6 -38
- truefoundry/ml/mlfoundry_run.py +6 -15
- truefoundry/ml/model_framework.py +2 -1
- truefoundry/ml/session.py +69 -97
- truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py +42 -9
- truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py +126 -7
- {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/METADATA +1 -1
- {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/RECORD +43 -44
- truefoundry/deploy/lib/auth/servicefoundry_session.py +0 -61
- truefoundry/gateway/lib/client.py +0 -51
- truefoundry/ml/clients/entities.py +0 -8
- truefoundry/ml/clients/utils.py +0 -122
- {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/WHEEL +0 -0
- {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.3rc5.dist-info}/entry_points.txt +0 -0
|
@@ -1,36 +1,57 @@
|
|
|
1
|
-
|
|
1
|
+
import functools
|
|
2
2
|
|
|
3
3
|
from truefoundry.common.constants import (
|
|
4
4
|
SERVICEFOUNDRY_CLIENT_MAX_RETRIES,
|
|
5
5
|
VERSION_PREFIX,
|
|
6
6
|
)
|
|
7
|
+
from truefoundry.common.exceptions import HttpRequestException
|
|
8
|
+
from truefoundry.common.request_utils import (
|
|
9
|
+
http_request,
|
|
10
|
+
request_handling,
|
|
11
|
+
requests_retry_session,
|
|
12
|
+
)
|
|
7
13
|
from truefoundry.common.servicefoundry_client import (
|
|
8
14
|
ServiceFoundryServiceClient as BaseServiceFoundryServiceClient,
|
|
9
15
|
)
|
|
10
|
-
from truefoundry.ml.clients.entities import (
|
|
11
|
-
HostCreds,
|
|
12
|
-
)
|
|
13
|
-
from truefoundry.ml.clients.utils import http_request_safe
|
|
14
16
|
from truefoundry.ml.exceptions import MlFoundryException
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
class ServiceFoundryServiceClient(BaseServiceFoundryServiceClient):
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
def __init__(self, tfy_host: str, token: str):
|
|
21
|
+
super().__init__(tfy_host=tfy_host)
|
|
22
|
+
self._token = token
|
|
23
|
+
|
|
24
|
+
@functools.cached_property
|
|
25
|
+
def _min_cli_version_required(self) -> str:
|
|
26
|
+
# TODO (chiragjn): read the mlfoundry min cli version from the config?
|
|
27
|
+
return self.python_sdk_config.truefoundry_cli_min_version
|
|
22
28
|
|
|
23
29
|
def get_integration_from_id(self, integration_id: str):
|
|
24
30
|
integration_id = integration_id or ""
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
endpoint=f"{VERSION_PREFIX}/provider-accounts/provider-integrations",
|
|
28
|
-
params={"id": integration_id, "type": "blob-storage"},
|
|
31
|
+
session = requests_retry_session(retries=SERVICEFOUNDRY_CLIENT_MAX_RETRIES)
|
|
32
|
+
response = http_request(
|
|
29
33
|
method="get",
|
|
34
|
+
url=f"{self._api_server_url}/{VERSION_PREFIX}/provider-accounts/provider-integrations",
|
|
35
|
+
token=self._token,
|
|
30
36
|
timeout=3,
|
|
31
|
-
|
|
37
|
+
params={"id": integration_id, "type": "blob-storage"},
|
|
38
|
+
session=session,
|
|
32
39
|
)
|
|
33
|
-
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
data = request_handling(response)
|
|
43
|
+
assert isinstance(data, dict)
|
|
44
|
+
except HttpRequestException as he:
|
|
45
|
+
raise MlFoundryException(
|
|
46
|
+
f"Failed to get storage integration from id: {integration_id}. Error: {he.message}",
|
|
47
|
+
status_code=he.status_code,
|
|
48
|
+
) from None
|
|
49
|
+
except Exception as e:
|
|
50
|
+
raise MlFoundryException(
|
|
51
|
+
f"Failed to get storage integration from id: {integration_id}. Error: {str(e)}"
|
|
52
|
+
) from None
|
|
53
|
+
|
|
54
|
+
# TODO (chiragjn): Parse this using Pydantic
|
|
34
55
|
if (
|
|
35
56
|
data.get("providerIntegrations")
|
|
36
57
|
and len(data["providerIntegrations"]) > 0
|
truefoundry/ml/exceptions.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
|
|
4
|
+
# TODO (chiragjn): We need to establish uniform exception handling across codebase
|
|
4
5
|
class MlFoundryException(Exception):
|
|
5
|
-
def __init__(self, message, status_code: Optional[int] = None):
|
|
6
|
+
def __init__(self, message: str, status_code: Optional[int] = None):
|
|
6
7
|
self.message = str(message)
|
|
7
8
|
self.status_code = status_code
|
|
8
9
|
super().__init__(message)
|
|
@@ -9,6 +9,7 @@ import warnings
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
|
11
11
|
|
|
12
|
+
from truefoundry.common.warnings import TrueFoundryDeprecationWarning
|
|
12
13
|
from truefoundry.ml.artifact.truefoundry_artifact_repo import (
|
|
13
14
|
ArtifactIdentifier,
|
|
14
15
|
MlFoundryArtifactsRepository,
|
|
@@ -217,7 +218,7 @@ class ArtifactVersion:
|
|
|
217
218
|
if not self._artifact_version.manifest:
|
|
218
219
|
warnings.warn(
|
|
219
220
|
message="This model version was created using an older serialization format. tags do not exist, returning empty list",
|
|
220
|
-
category=
|
|
221
|
+
category=TrueFoundryDeprecationWarning,
|
|
221
222
|
stacklevel=2,
|
|
222
223
|
)
|
|
223
224
|
return self._tags
|
|
@@ -230,7 +231,7 @@ class ArtifactVersion:
|
|
|
230
231
|
if not self._artifact_version.manifest:
|
|
231
232
|
warnings.warn(
|
|
232
233
|
message="This model version was created using an older serialization format. Tags will not be updated",
|
|
233
|
-
category=
|
|
234
|
+
category=TrueFoundryDeprecationWarning,
|
|
234
235
|
stacklevel=2,
|
|
235
236
|
)
|
|
236
237
|
return
|
|
@@ -10,6 +10,7 @@ import warnings
|
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
12
12
|
|
|
13
|
+
from truefoundry.common.warnings import TrueFoundryDeprecationWarning
|
|
13
14
|
from truefoundry.ml.artifact.truefoundry_artifact_repo import (
|
|
14
15
|
ArtifactIdentifier,
|
|
15
16
|
MlFoundryArtifactsRepository,
|
|
@@ -245,7 +246,7 @@ class ModelVersion:
|
|
|
245
246
|
if not self._model_version.manifest:
|
|
246
247
|
warnings.warn(
|
|
247
248
|
message="This model version was created using an older serialization format. tags do not exist, returning empty list",
|
|
248
|
-
category=
|
|
249
|
+
category=TrueFoundryDeprecationWarning,
|
|
249
250
|
stacklevel=2,
|
|
250
251
|
)
|
|
251
252
|
return self._tags
|
|
@@ -258,7 +259,7 @@ class ModelVersion:
|
|
|
258
259
|
if not self._model_version.manifest:
|
|
259
260
|
warnings.warn(
|
|
260
261
|
message="This model version was created using an older serialization format. Tags will not be updated",
|
|
261
|
-
category=
|
|
262
|
+
category=TrueFoundryDeprecationWarning,
|
|
262
263
|
stacklevel=2,
|
|
263
264
|
)
|
|
264
265
|
return
|
|
@@ -270,7 +271,7 @@ class ModelVersion:
|
|
|
270
271
|
if not self._model_version.manifest:
|
|
271
272
|
warnings.warn(
|
|
272
273
|
message="This model version was created using an older serialization format. environment does not exist, returning None",
|
|
273
|
-
category=
|
|
274
|
+
category=TrueFoundryDeprecationWarning,
|
|
274
275
|
stacklevel=2,
|
|
275
276
|
)
|
|
276
277
|
return self._environment
|
|
@@ -281,7 +282,7 @@ class ModelVersion:
|
|
|
281
282
|
if not self._model_version.manifest:
|
|
282
283
|
warnings.warn(
|
|
283
284
|
message="This model version was created using an older serialization format. Environment will not be updated",
|
|
284
|
-
category=
|
|
285
|
+
category=TrueFoundryDeprecationWarning,
|
|
285
286
|
stacklevel=2,
|
|
286
287
|
)
|
|
287
288
|
return
|
|
@@ -300,7 +301,7 @@ class ModelVersion:
|
|
|
300
301
|
if not self._model_version.manifest:
|
|
301
302
|
warnings.warn(
|
|
302
303
|
message="This model version was created using an older serialization format. Framework will not be updated",
|
|
303
|
-
category=
|
|
304
|
+
category=TrueFoundryDeprecationWarning,
|
|
304
305
|
stacklevel=2,
|
|
305
306
|
)
|
|
306
307
|
return
|
|
@@ -3,7 +3,7 @@ import logging
|
|
|
3
3
|
import os
|
|
4
4
|
import posixpath
|
|
5
5
|
from pathlib import Path, PureWindowsPath
|
|
6
|
-
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
|
6
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
7
7
|
|
|
8
8
|
from truefoundry.ml.exceptions import MlFoundryException
|
|
9
9
|
from truefoundry.ml.log_types.artifacts.constants import DESCRIPTION_MAX_LENGTH
|
|
@@ -63,7 +63,7 @@ def get_single_file_path_if_only_one_in_directory(path: str) -> Optional[str]:
|
|
|
63
63
|
|
|
64
64
|
# If it's a directory, check if it contains a single file
|
|
65
65
|
if is_destination_path_dirlike(path):
|
|
66
|
-
all_files = []
|
|
66
|
+
all_files: List[str] = []
|
|
67
67
|
for root, _, files in os.walk(path):
|
|
68
68
|
# Collect all files found in any subdirectory
|
|
69
69
|
all_files.extend(os.path.join(root, f) for f in files)
|
truefoundry/ml/mlfoundry_api.py
CHANGED
|
@@ -17,7 +17,7 @@ from typing import (
|
|
|
17
17
|
|
|
18
18
|
import coolname
|
|
19
19
|
|
|
20
|
-
from truefoundry.common.utils import ContextualDirectoryManager
|
|
20
|
+
from truefoundry.common.utils import ContextualDirectoryManager
|
|
21
21
|
from truefoundry.ml import constants
|
|
22
22
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
23
23
|
ArtifactDto,
|
|
@@ -63,10 +63,8 @@ from truefoundry.ml.log_types.artifacts.model import (
|
|
|
63
63
|
from truefoundry.ml.logger import logger
|
|
64
64
|
from truefoundry.ml.mlfoundry_run import MlFoundryRun
|
|
65
65
|
from truefoundry.ml.session import (
|
|
66
|
-
|
|
66
|
+
MLFoundrySession,
|
|
67
67
|
_get_api_client,
|
|
68
|
-
get_active_session,
|
|
69
|
-
init_session,
|
|
70
68
|
)
|
|
71
69
|
from truefoundry.ml.validation_utils import (
|
|
72
70
|
_validate_ml_repo_description,
|
|
@@ -112,13 +110,12 @@ class MlFoundry:
|
|
|
112
110
|
"""MlFoundry."""
|
|
113
111
|
|
|
114
112
|
# TODO (chiragjn): Don't allow session as None here!
|
|
115
|
-
def __init__(self, session:
|
|
113
|
+
def __init__(self, session: MLFoundrySession):
|
|
116
114
|
"""__init__
|
|
117
115
|
|
|
118
116
|
Args:
|
|
119
117
|
session (Optional[Session], optional): Session instance to get auth credentials from
|
|
120
118
|
"""
|
|
121
|
-
self._tracking_uri: str = session.tracking_uri
|
|
122
119
|
self._api_client = _get_api_client(session=session)
|
|
123
120
|
self._experiments_api = ExperimentsApi(api_client=self._api_client)
|
|
124
121
|
self._runs_api = RunsApi(api_client=self._api_client)
|
|
@@ -236,16 +233,9 @@ class MlFoundry:
|
|
|
236
233
|
raise MlFoundryException(err_msg) from e
|
|
237
234
|
return
|
|
238
235
|
|
|
239
|
-
session = get_active_session()
|
|
240
|
-
if session is None:
|
|
241
|
-
raise MlFoundryException(
|
|
242
|
-
relogin_error_message(
|
|
243
|
-
"No active session found. Perhaps you are not logged in?",
|
|
244
|
-
)
|
|
245
|
-
)
|
|
246
236
|
servicefoundry_client = ServiceFoundryServiceClient(
|
|
247
|
-
|
|
248
|
-
token=
|
|
237
|
+
tfy_host=self._api_client.tfy_host,
|
|
238
|
+
token=self._api_client.access_token,
|
|
249
239
|
)
|
|
250
240
|
|
|
251
241
|
assert existing_ml_repo.storage_integration_id is not None
|
|
@@ -604,26 +594,6 @@ class MlFoundry:
|
|
|
604
594
|
if not runs or page_token is None:
|
|
605
595
|
done = True
|
|
606
596
|
|
|
607
|
-
def get_tracking_uri(self) -> str:
|
|
608
|
-
"""
|
|
609
|
-
Get the current tracking URI.
|
|
610
|
-
|
|
611
|
-
Returns:
|
|
612
|
-
The tracking URI.
|
|
613
|
-
|
|
614
|
-
Examples:
|
|
615
|
-
|
|
616
|
-
```python
|
|
617
|
-
import tempfile
|
|
618
|
-
from truefoundry.ml import get_client
|
|
619
|
-
|
|
620
|
-
client = get_client()
|
|
621
|
-
tracking_uri = client.get_tracking_uri()
|
|
622
|
-
print("Current tracking uri: {}".format(tracking_uri))
|
|
623
|
-
```
|
|
624
|
-
"""
|
|
625
|
-
return self._tracking_uri
|
|
626
|
-
|
|
627
597
|
def _initialize_model_server(
|
|
628
598
|
self,
|
|
629
599
|
name: str,
|
|
@@ -1239,7 +1209,6 @@ class MlFoundry:
|
|
|
1239
1209
|
raise MlFoundryException(
|
|
1240
1210
|
"artifact_paths cannot be empty, atleast one artifact_path must be passed"
|
|
1241
1211
|
)
|
|
1242
|
-
|
|
1243
1212
|
ml_repo_id = self._get_ml_repo_id(ml_repo=ml_repo)
|
|
1244
1213
|
artifact_version = _log_artifact_version(
|
|
1245
1214
|
run=None,
|
|
@@ -1377,7 +1346,6 @@ class MlFoundry:
|
|
|
1377
1346
|
|
|
1378
1347
|
"""
|
|
1379
1348
|
ml_repo_id = self._get_ml_repo_id(ml_repo=ml_repo)
|
|
1380
|
-
|
|
1381
1349
|
model_version = _log_model_version(
|
|
1382
1350
|
run=None,
|
|
1383
1351
|
mlfoundry_artifacts_api=self._mlfoundry_artifacts_api,
|
|
@@ -1648,5 +1616,5 @@ def get_client() -> MlFoundry:
|
|
|
1648
1616
|
client = get_client()
|
|
1649
1617
|
```
|
|
1650
1618
|
"""
|
|
1651
|
-
session =
|
|
1619
|
+
session = MLFoundrySession.new()
|
|
1652
1620
|
return MlFoundry(session=session)
|
truefoundry/ml/mlfoundry_run.py
CHANGED
|
@@ -17,7 +17,6 @@ from typing import (
|
|
|
17
17
|
from urllib.parse import urljoin, urlsplit
|
|
18
18
|
|
|
19
19
|
from truefoundry import version
|
|
20
|
-
from truefoundry.common.utils import relogin_error_message
|
|
21
20
|
from truefoundry.ml import constants
|
|
22
21
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
23
22
|
ArtifactType,
|
|
@@ -55,7 +54,7 @@ from truefoundry.ml.log_types.artifacts.model import (
|
|
|
55
54
|
)
|
|
56
55
|
from truefoundry.ml.logger import logger
|
|
57
56
|
from truefoundry.ml.run_utils import ParamsType, flatten_dict, process_params
|
|
58
|
-
from truefoundry.ml.session import ACTIVE_RUNS, _get_api_client
|
|
57
|
+
from truefoundry.ml.session import ACTIVE_RUNS, _get_api_client
|
|
59
58
|
from truefoundry.ml.validation_utils import (
|
|
60
59
|
MAX_ENTITY_KEY_LENGTH,
|
|
61
60
|
MAX_METRICS_PER_BATCH,
|
|
@@ -72,7 +71,7 @@ if TYPE_CHECKING:
|
|
|
72
71
|
|
|
73
72
|
def _ensure_not_deleted(method):
|
|
74
73
|
@functools.wraps(method)
|
|
75
|
-
def _check_deleted_or_not(self, *args, **kwargs):
|
|
74
|
+
def _check_deleted_or_not(self: "MlFoundryRun", *args, **kwargs):
|
|
76
75
|
if self._deleted:
|
|
77
76
|
raise MlFoundryException("Run was deleted, cannot access a deleted Run")
|
|
78
77
|
else:
|
|
@@ -230,18 +229,10 @@ class MlFoundryRun:
|
|
|
230
229
|
@_ensure_not_deleted
|
|
231
230
|
def dashboard_link(self) -> str:
|
|
232
231
|
"""Get Mlfoundry dashboard link for a `run`"""
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
raise MlFoundryException(
|
|
236
|
-
relogin_error_message(
|
|
237
|
-
"No active session found. Perhaps you are not logged in?",
|
|
238
|
-
)
|
|
239
|
-
)
|
|
240
|
-
base_url = "{uri.scheme}://{uri.netloc}/".format(
|
|
241
|
-
uri=urlsplit(session.tracking_uri)
|
|
232
|
+
tfy_host = "{uri.scheme}://{uri.netloc}/".format(
|
|
233
|
+
uri=urlsplit(self._api_client.tfy_host)
|
|
242
234
|
)
|
|
243
|
-
|
|
244
|
-
return urljoin(base_url, f"mlfoundry/{self._experiment_id}/run/{self.run_id}/")
|
|
235
|
+
return urljoin(tfy_host, f"mlfoundry/{self._experiment_id}/run/{self.run_id}/")
|
|
245
236
|
|
|
246
237
|
@_ensure_not_deleted
|
|
247
238
|
def end(self, status: RunStatus = RunStatus.FINISHED):
|
|
@@ -581,7 +572,7 @@ class MlFoundryRun:
|
|
|
581
572
|
)
|
|
582
573
|
|
|
583
574
|
return _log_artifact_version(
|
|
584
|
-
self,
|
|
575
|
+
run=self,
|
|
585
576
|
name=name,
|
|
586
577
|
artifact_paths=artifact_paths,
|
|
587
578
|
description=description,
|
|
@@ -20,6 +20,7 @@ from truefoundry.common.utils import (
|
|
|
20
20
|
get_python_version_major_minor,
|
|
21
21
|
list_pip_packages_installed,
|
|
22
22
|
)
|
|
23
|
+
from truefoundry.common.warnings import TrueFoundryDeprecationWarning
|
|
23
24
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
24
25
|
ModelVersionEnvironment,
|
|
25
26
|
SklearnSerializationFormat,
|
|
@@ -260,7 +261,7 @@ class _ModelFramework(BaseModel):
|
|
|
260
261
|
if isinstance(framework, (str, ModelFramework)):
|
|
261
262
|
warnings.warn(
|
|
262
263
|
"Passing a string or ModelFramework Enum is deprecated. Please use a ModelFrameworkType object.",
|
|
263
|
-
|
|
264
|
+
category=TrueFoundryDeprecationWarning,
|
|
264
265
|
stacklevel=2,
|
|
265
266
|
)
|
|
266
267
|
|
truefoundry/ml/session.py
CHANGED
|
@@ -3,26 +3,22 @@ import threading
|
|
|
3
3
|
import weakref
|
|
4
4
|
from typing import TYPE_CHECKING, Dict, Optional
|
|
5
5
|
|
|
6
|
-
from truefoundry.common.credential_provider import (
|
|
7
|
-
CredentialProvider,
|
|
8
|
-
EnvCredentialProvider,
|
|
9
|
-
FileCredentialProvider,
|
|
10
|
-
)
|
|
11
|
-
from truefoundry.common.entities import Token, UserInfo
|
|
12
6
|
from truefoundry.common.request_utils import urllib3_retry
|
|
7
|
+
from truefoundry.common.session import Session
|
|
13
8
|
from truefoundry.common.utils import get_tfy_servers_config, relogin_error_message
|
|
14
9
|
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
15
10
|
ApiClient,
|
|
16
11
|
Configuration,
|
|
17
12
|
)
|
|
18
|
-
from truefoundry.ml.clients.entities import HostCreds
|
|
19
13
|
from truefoundry.ml.exceptions import MlFoundryException
|
|
20
14
|
from truefoundry.ml.logger import logger
|
|
15
|
+
from truefoundry.version import __version__
|
|
21
16
|
|
|
22
17
|
if TYPE_CHECKING:
|
|
23
18
|
from truefoundry.ml.mlfoundry_run import MlFoundryRun
|
|
24
19
|
|
|
25
20
|
SESSION_LOCK = threading.RLock()
|
|
21
|
+
ACTIVE_SESSION: Optional["MLFoundrySession"] = None
|
|
26
22
|
|
|
27
23
|
|
|
28
24
|
class ActiveRuns:
|
|
@@ -51,19 +47,7 @@ ACTIVE_RUNS = ActiveRuns()
|
|
|
51
47
|
atexit.register(ACTIVE_RUNS.close_active_runs)
|
|
52
48
|
|
|
53
49
|
|
|
54
|
-
class Session:
|
|
55
|
-
def __init__(self, cred_provider: CredentialProvider):
|
|
56
|
-
# Note: Whenever a new session is initialized all the active runs are ended
|
|
57
|
-
self._closed = False
|
|
58
|
-
self._cred_provider: Optional[CredentialProvider] = cred_provider
|
|
59
|
-
self._user_info: Optional[UserInfo] = self._cred_provider.token.to_user_info()
|
|
60
|
-
|
|
61
|
-
def close(self):
|
|
62
|
-
logger.debug("Closing existing session")
|
|
63
|
-
self._closed = True
|
|
64
|
-
self._user_info = None
|
|
65
|
-
self._cred_provider = None
|
|
66
|
-
|
|
50
|
+
class MLFoundrySession(Session):
|
|
67
51
|
def _assert_not_closed(self):
|
|
68
52
|
if self._closed:
|
|
69
53
|
raise MlFoundryException(
|
|
@@ -72,100 +56,88 @@ class Session:
|
|
|
72
56
|
"`truefoundry.ml.get_client()` function call) can be used"
|
|
73
57
|
)
|
|
74
58
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
59
|
+
def close(self):
|
|
60
|
+
global ACTIVE_RUNS
|
|
61
|
+
logger.debug("Closing existing session")
|
|
62
|
+
ACTIVE_RUNS.close_active_runs()
|
|
63
|
+
super().close()
|
|
78
64
|
|
|
79
|
-
@
|
|
80
|
-
def
|
|
81
|
-
|
|
82
|
-
|
|
65
|
+
@classmethod
|
|
66
|
+
def new(cls) -> "MLFoundrySession":
|
|
67
|
+
global ACTIVE_SESSION
|
|
68
|
+
with SESSION_LOCK:
|
|
69
|
+
new_session = cls()
|
|
70
|
+
if ACTIVE_SESSION and ACTIVE_SESSION == new_session:
|
|
71
|
+
return ACTIVE_SESSION
|
|
72
|
+
|
|
73
|
+
if ACTIVE_SESSION:
|
|
74
|
+
ACTIVE_SESSION.close()
|
|
75
|
+
|
|
76
|
+
ACTIVE_SESSION = new_session
|
|
77
|
+
logger.info(
|
|
78
|
+
"Logged in to %r as %r (%s)",
|
|
79
|
+
new_session.tfy_host,
|
|
80
|
+
new_session.user_info.user_id,
|
|
81
|
+
new_session.user_info.email or new_session.user_info.user_type.value,
|
|
82
|
+
)
|
|
83
83
|
|
|
84
|
-
|
|
85
|
-
@property
|
|
86
|
-
def tracking_uri(self) -> str:
|
|
87
|
-
return self._cred_provider.base_url
|
|
88
|
-
|
|
89
|
-
def __eq__(self, other: object) -> bool:
|
|
90
|
-
if not isinstance(other, Session):
|
|
91
|
-
return False
|
|
92
|
-
return (
|
|
93
|
-
type(self._cred_provider) == type(other._cred_provider) # noqa: E721
|
|
94
|
-
and self.user_info == other.user_info
|
|
95
|
-
and self.tracking_uri == other.tracking_uri
|
|
96
|
-
)
|
|
84
|
+
return ACTIVE_SESSION
|
|
97
85
|
|
|
98
|
-
def get_host_creds(self) -> HostCreds:
|
|
99
|
-
tracking_uri = get_tfy_servers_config(self.tracking_uri).mlfoundry_server_url
|
|
100
|
-
return HostCreds(
|
|
101
|
-
host=tracking_uri, token=self._cred_provider.token.access_token
|
|
102
|
-
)
|
|
103
86
|
|
|
87
|
+
class MLFoundryServerApiClient(ApiClient):
|
|
88
|
+
def __init__(self, session: Optional[MLFoundrySession] = None, *args, **kwargs):
|
|
89
|
+
self.session = session
|
|
90
|
+
super().__init__(*args, **kwargs)
|
|
104
91
|
|
|
105
|
-
|
|
92
|
+
@classmethod
|
|
93
|
+
def from_session(cls, session: MLFoundrySession) -> "MLFoundryServerApiClient":
|
|
94
|
+
mlfoundry_server_url = get_tfy_servers_config(
|
|
95
|
+
session.tfy_host
|
|
96
|
+
).mlfoundry_server_url
|
|
97
|
+
configuration = Configuration(
|
|
98
|
+
host=mlfoundry_server_url.rstrip("/"),
|
|
99
|
+
access_token=session.access_token,
|
|
100
|
+
)
|
|
101
|
+
configuration.retries = urllib3_retry(retries=2)
|
|
102
|
+
api_client = cls(session=session, configuration=configuration)
|
|
103
|
+
api_client.user_agent = f"truefoundry-cli/{__version__}"
|
|
104
|
+
return api_client
|
|
105
|
+
|
|
106
|
+
def _ensure_session(self):
|
|
107
|
+
if self.session is None:
|
|
108
|
+
raise MlFoundryException(
|
|
109
|
+
relogin_error_message(
|
|
110
|
+
"No active session found. Perhaps you are not logged in?",
|
|
111
|
+
)
|
|
112
|
+
)
|
|
106
113
|
|
|
114
|
+
@property
|
|
115
|
+
def tfy_host(self) -> str:
|
|
116
|
+
self._ensure_session()
|
|
117
|
+
assert self.session is not None
|
|
118
|
+
return self.session.tfy_host
|
|
107
119
|
|
|
108
|
-
|
|
109
|
-
|
|
120
|
+
@property
|
|
121
|
+
def access_token(self) -> str:
|
|
122
|
+
self._ensure_session()
|
|
123
|
+
assert self.session is not None
|
|
124
|
+
return self.session.access_token
|
|
110
125
|
|
|
111
126
|
|
|
112
127
|
def _get_api_client(
|
|
113
|
-
session: Optional[
|
|
128
|
+
session: Optional[MLFoundrySession] = None,
|
|
114
129
|
allow_anonymous: bool = False,
|
|
115
|
-
) ->
|
|
116
|
-
|
|
130
|
+
) -> MLFoundryServerApiClient:
|
|
131
|
+
global ACTIVE_SESSION
|
|
117
132
|
|
|
118
|
-
session = session or
|
|
133
|
+
session = session or ACTIVE_SESSION
|
|
119
134
|
if session is None:
|
|
120
135
|
if allow_anonymous:
|
|
121
|
-
return
|
|
136
|
+
return MLFoundryServerApiClient(session=None)
|
|
122
137
|
else:
|
|
123
138
|
raise MlFoundryException(
|
|
124
139
|
relogin_error_message(
|
|
125
140
|
"No active session found. Perhaps you are not logged in?",
|
|
126
141
|
)
|
|
127
142
|
)
|
|
128
|
-
|
|
129
|
-
creds = session.get_host_creds()
|
|
130
|
-
configuration = Configuration(
|
|
131
|
-
host=creds.host.rstrip("/"),
|
|
132
|
-
access_token=creds.token,
|
|
133
|
-
)
|
|
134
|
-
configuration.retries = urllib3_retry(retries=2)
|
|
135
|
-
api_client = ApiClient(configuration=configuration)
|
|
136
|
-
api_client.user_agent = f"truefoundry-cli/{__version__}"
|
|
137
|
-
return api_client
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def init_session() -> Session:
|
|
141
|
-
with SESSION_LOCK:
|
|
142
|
-
final_cred_provider = None
|
|
143
|
-
for cred_provider in [EnvCredentialProvider, FileCredentialProvider]:
|
|
144
|
-
if cred_provider.can_provide():
|
|
145
|
-
final_cred_provider = cred_provider()
|
|
146
|
-
break
|
|
147
|
-
if final_cred_provider is None:
|
|
148
|
-
raise MlFoundryException(
|
|
149
|
-
relogin_error_message(
|
|
150
|
-
"No active session found. Perhaps you are not logged in?",
|
|
151
|
-
)
|
|
152
|
-
)
|
|
153
|
-
new_session = Session(cred_provider=final_cred_provider)
|
|
154
|
-
|
|
155
|
-
global ACTIVE_SESSION
|
|
156
|
-
if ACTIVE_SESSION and ACTIVE_SESSION == new_session:
|
|
157
|
-
return ACTIVE_SESSION
|
|
158
|
-
|
|
159
|
-
ACTIVE_RUNS.close_active_runs()
|
|
160
|
-
|
|
161
|
-
if ACTIVE_SESSION:
|
|
162
|
-
ACTIVE_SESSION.close()
|
|
163
|
-
ACTIVE_SESSION = new_session
|
|
164
|
-
|
|
165
|
-
logger.info(
|
|
166
|
-
"Logged in to %r as %r (%s)",
|
|
167
|
-
ACTIVE_SESSION.tracking_uri,
|
|
168
|
-
ACTIVE_SESSION.user_info.user_id,
|
|
169
|
-
ACTIVE_SESSION.user_info.email or ACTIVE_SESSION.user_info.user_type.value,
|
|
170
|
-
)
|
|
171
|
-
return ACTIVE_SESSION
|
|
143
|
+
return MLFoundryServerApiClient.from_session(session)
|