truefoundry 0.5.3rc4__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. (The original page linked to further details, which are not included here.)

Files changed (57)
  1. truefoundry/__init__.py +10 -1
  2. truefoundry/autodeploy/cli.py +2 -2
  3. truefoundry/cli/__main__.py +0 -4
  4. truefoundry/cli/util.py +12 -3
  5. truefoundry/common/auth_service_client.py +7 -4
  6. truefoundry/common/constants.py +3 -1
  7. truefoundry/common/credential_provider.py +7 -8
  8. truefoundry/common/exceptions.py +11 -7
  9. truefoundry/common/request_utils.py +96 -14
  10. truefoundry/common/servicefoundry_client.py +31 -29
  11. truefoundry/common/session.py +93 -0
  12. truefoundry/common/storage_provider_utils.py +331 -0
  13. truefoundry/common/utils.py +9 -9
  14. truefoundry/common/warnings.py +21 -0
  15. truefoundry/deploy/builder/builders/tfy_python_buildpack/dockerfile_template.py +8 -20
  16. truefoundry/deploy/cli/commands/deploy_command.py +4 -4
  17. truefoundry/deploy/lib/clients/servicefoundry_client.py +14 -38
  18. truefoundry/deploy/lib/dao/application.py +2 -2
  19. truefoundry/deploy/lib/dao/workspace.py +1 -1
  20. truefoundry/deploy/lib/session.py +8 -1
  21. truefoundry/deploy/v2/lib/deploy.py +2 -2
  22. truefoundry/deploy/v2/lib/deploy_workflow.py +1 -1
  23. truefoundry/deploy/v2/lib/patched_models.py +70 -4
  24. truefoundry/deploy/v2/lib/source.py +2 -1
  25. truefoundry/ml/artifact/truefoundry_artifact_repo.py +33 -297
  26. truefoundry/ml/autogen/client/__init__.py +2 -2
  27. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +18 -16
  28. truefoundry/ml/autogen/client/models/__init__.py +2 -2
  29. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +23 -5
  30. truefoundry/ml/autogen/client/models/{get_artifact_tags_response_dto.py → get_artifact_version_aliases_response_dto.py} +12 -10
  31. truefoundry/ml/autogen/client/models/model_version_manifest.py +16 -5
  32. truefoundry/ml/autogen/client_README.md +2 -2
  33. truefoundry/ml/autogen/entities/artifacts.py +4 -9
  34. truefoundry/ml/clients/servicefoundry_client.py +36 -15
  35. truefoundry/ml/exceptions.py +2 -1
  36. truefoundry/ml/log_types/artifacts/artifact.py +16 -15
  37. truefoundry/ml/log_types/artifacts/model.py +20 -19
  38. truefoundry/ml/log_types/artifacts/utils.py +2 -2
  39. truefoundry/ml/mlfoundry_api.py +6 -38
  40. truefoundry/ml/mlfoundry_run.py +6 -15
  41. truefoundry/ml/model_framework.py +2 -1
  42. truefoundry/ml/session.py +69 -97
  43. truefoundry/workflow/remote_filesystem/tfy_signed_url_client.py +42 -9
  44. truefoundry/workflow/remote_filesystem/tfy_signed_url_fs.py +126 -7
  45. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.4.dist-info}/METADATA +2 -2
  46. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.4.dist-info}/RECORD +48 -54
  47. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.4.dist-info}/WHEEL +1 -1
  48. truefoundry/cli/commands/pat.py +0 -24
  49. truefoundry/deploy/lib/auth/servicefoundry_session.py +0 -61
  50. truefoundry/gateway/__init__.py +0 -1
  51. truefoundry/gateway/cli/cli.py +0 -51
  52. truefoundry/gateway/lib/client.py +0 -51
  53. truefoundry/gateway/lib/entities.py +0 -33
  54. truefoundry/gateway/lib/models.py +0 -67
  55. truefoundry/ml/clients/entities.py +0 -8
  56. truefoundry/ml/clients/utils.py +0 -122
  57. {truefoundry-0.5.3rc4.dist-info → truefoundry-0.5.4.dist-info}/entry_points.txt +0 -0
truefoundry/__init__.py CHANGED
@@ -1,3 +1,12 @@
1
+ from truefoundry.common.warnings import (
2
+ suppress_truefoundry_deprecation_warnings,
3
+ surface_truefoundry_deprecation_warnings,
4
+ )
1
5
  from truefoundry.deploy.core import login, logout
2
6
 
3
- __all__ = ["login", "logout"]
7
+ surface_truefoundry_deprecation_warnings()
8
+ __all__ = [
9
+ "login",
10
+ "logout",
11
+ "suppress_truefoundry_deprecation_warnings",
12
+ ]
truefoundry/autodeploy/cli.py CHANGED
@@ -32,8 +32,8 @@ from truefoundry.autodeploy.tools.commit import CommitConfirmation
32
32
  from truefoundry.autodeploy.tools.docker_run import DockerRun, DockerRunLog
33
33
  from truefoundry.autodeploy.utils.client import get_git_binary
34
34
  from truefoundry.cli.const import COMMAND_CLS
35
+ from truefoundry.common.session import Session
35
36
  from truefoundry.deploy import Build, DockerFileBuild, Job, LocalSource, Port, Service
36
- from truefoundry.deploy.lib.auth.servicefoundry_session import ServiceFoundrySession
37
37
  from truefoundry.deploy.lib.clients.servicefoundry_client import (
38
38
  ServiceFoundryServiceClient,
39
39
  )
@@ -45,7 +45,7 @@ def _get_openai_client() -> OpenAI:
45
45
  api_key=AUTODEPLOY_OPENAI_API_KEY, base_url=AUTODEPLOY_OPENAI_BASE_URL
46
46
  )
47
47
  try:
48
- session = ServiceFoundrySession()
48
+ session = Session.new()
49
49
  resp = requests.get(
50
50
  f"{AUTODEPLOY_TFY_BASE_URL}/api/svc/v1/llm-gateway/access-details",
51
51
  headers={
truefoundry/cli/__main__.py CHANGED
@@ -4,7 +4,6 @@ import sys
4
4
  import rich_click as click
5
5
 
6
6
  from truefoundry import logger
7
- from truefoundry.cli.commands.pat import get_generate_pat_command
8
7
  from truefoundry.cli.config import CliConfig
9
8
  from truefoundry.cli.const import GROUP_CLS
10
9
  from truefoundry.cli.util import setup_rich_click
@@ -22,7 +21,6 @@ from truefoundry.deploy.cli.commands import (
22
21
  get_terminate_command,
23
22
  get_trigger_command,
24
23
  )
25
- from truefoundry.gateway.cli.cli import get_gateway_cli
26
24
  from truefoundry.ml.cli.cli import get_ml_cli
27
25
  from truefoundry.version import __version__
28
26
 
@@ -90,8 +88,6 @@ def create_truefoundry_cli() -> click.MultiCommand:
90
88
  cli.add_command(get_trigger_command())
91
89
  cli.add_command(get_terminate_command())
92
90
  cli.add_command(get_ml_cli())
93
- cli.add_command(get_gateway_cli())
94
- cli.add_command(get_generate_pat_command())
95
91
 
96
92
  if not (sys.platform.startswith("win32") or sys.platform.startswith("cygwin")):
97
93
  cli.add_command(get_patch_command())
truefoundry/cli/util.py CHANGED
@@ -5,13 +5,13 @@ from typing import Dict
5
5
 
6
6
  import rich_click as click
7
7
  from packaging.version import parse as parse_version
8
- from requests.exceptions import ConnectionError
8
+ from requests.exceptions import ConnectionError, Timeout
9
9
  from rich.padding import Padding
10
10
  from rich.panel import Panel
11
11
  from rich.table import Table
12
12
 
13
13
  from truefoundry.cli.console import console
14
- from truefoundry.common.exceptions import BadRequestException
14
+ from truefoundry.common.exceptions import HttpRequestException
15
15
  from truefoundry.common.utils import is_debug_env_set
16
16
 
17
17
 
@@ -25,12 +25,21 @@ def setup_rich_click():
25
25
  def handle_exception(exception):
26
26
  if is_debug_env_set():
27
27
  console.print_exception(show_locals=True)
28
- if isinstance(exception, BadRequestException):
28
+ if isinstance(exception, HttpRequestException):
29
29
  print_dict_as_table_panel(
30
30
  {"Status Code": str(exception.status_code), "Error": exception.message},
31
31
  title="Command Failed",
32
32
  border_style="red",
33
33
  )
34
+ elif isinstance(exception, Timeout):
35
+ loc = ""
36
+ if exception.request:
37
+ loc = f" at {exception.request.url}"
38
+ print_dict_as_table_panel(
39
+ {"Error": f"Request to TrueFoundry{loc} timed out."},
40
+ title="Command Failed",
41
+ border_style="red",
42
+ )
34
43
  elif isinstance(exception, ConnectionError):
35
44
  loc = ""
36
45
  if exception.request:
truefoundry/common/auth_service_client.py CHANGED
@@ -7,7 +7,10 @@ import requests
7
7
  from truefoundry.common.constants import VERSION_PREFIX
8
8
  from truefoundry.common.entities import DeviceCode, Token
9
9
  from truefoundry.common.exceptions import BadRequestException
10
- from truefoundry.common.request_utils import request_handling, requests_retry_session
10
+ from truefoundry.common.request_utils import (
11
+ request_handling,
12
+ requests_retry_session,
13
+ )
11
14
  from truefoundry.common.utils import poll_for_function, relogin_error_message
12
15
  from truefoundry.logger import logger
13
16
 
@@ -16,14 +19,13 @@ class AuthServiceClient(ABC):
16
19
  def __init__(self, tenant_name: str):
17
20
  self._tenant_name = tenant_name
18
21
 
19
- # TODO (chiragjn): Rename base_url to tfy_host
20
22
  @classmethod
21
- def from_base_url(cls, base_url: str) -> "AuthServiceClient":
23
+ def from_tfy_host(cls, tfy_host: str) -> "AuthServiceClient":
22
24
  from truefoundry.common.servicefoundry_client import (
23
25
  ServiceFoundryServiceClient,
24
26
  )
25
27
 
26
- client = ServiceFoundryServiceClient(base_url=base_url)
28
+ client = ServiceFoundryServiceClient(tfy_host=tfy_host)
27
29
  if client.python_sdk_config.use_sfy_server_auth_apis:
28
30
  return ServiceFoundryServerAuthServiceClient(
29
31
  tenant_name=client.tenant_info.tenant_name, url=client._api_server_url
@@ -159,6 +161,7 @@ class AuthServerServiceClient(AuthServiceClient):
159
161
  session = requests_retry_session(retries=2)
160
162
  response = session.post(url, json=data)
161
163
  response_data = request_handling(response)
164
+ assert isinstance(response_data, dict)
162
165
  # TODO: temporary cleanup of incorrect attributes
163
166
  return DeviceCode.parse_obj(
164
167
  {
truefoundry/common/constants.py CHANGED
@@ -38,6 +38,9 @@ class TrueFoundrySdkEnv(BaseSettings):
38
38
  )
39
39
  TFY_INTERNAL_SIGNED_URL_SERVER_MAX_TIMEOUT: int = 5 # default: 5 seconds
40
40
  TFY_INTERNAL_SIGNED_URL_SERVER_DEFAULT_TTL: int = 3600 # default: 1 hour
41
+ TFY_INTERNAL_MULTIPART_UPLOAD_FINALIZE_SIGNED_URL_TIMEOUT: int = (
42
+ 24 * 60 * 60
43
+ ) # default: 24 hour
41
44
  TFY_INTERNAL_SIGNED_URL_REQUEST_TIMEOUT: int = 3600 # default: 1 hour
42
45
  TFY_INTERNAL_SIGNED_URL_CLIENT_LOG_LEVEL: str = "WARNING"
43
46
 
@@ -63,4 +66,3 @@ API_SERVER_RELATIVE_PATH = "api/svc"
63
66
  MLFOUNDRY_SERVER_RELATIVE_PATH = "api/ml"
64
67
  VERSION_PREFIX = "v1"
65
68
  SERVICEFOUNDRY_CLIENT_MAX_RETRIES = 2
66
- GATEWAY_SERVER_RELATIVE_PATH = "api/llm"
truefoundry/common/credential_provider.py CHANGED
@@ -17,10 +17,9 @@ class CredentialProvider(ABC):
17
17
  @abstractmethod
18
18
  def token(self) -> Token: ...
19
19
 
20
- # TODO (chiragjn): Rename base_url to tfy_host
21
20
  @property
22
21
  @abstractmethod
23
- def base_url(self) -> str: ...
22
+ def tfy_host(self) -> str: ...
24
23
 
25
24
  @staticmethod
26
25
  @abstractmethod
@@ -38,7 +37,7 @@ class EnvCredentialProvider(CredentialProvider):
38
37
  f"Value of {TFY_API_KEY_ENV_KEY} env var should be non-empty string"
39
38
  )
40
39
  self._host = resolve_tfy_host()
41
- self._auth_service = AuthServiceClient.from_base_url(base_url=self._host)
40
+ self._auth_service = AuthServiceClient.from_tfy_host(tfy_host=self._host)
42
41
  self._token: Token = Token(access_token=api_key, refresh_token=None) # type: ignore[call-arg]
43
42
 
44
43
  @staticmethod
@@ -51,12 +50,12 @@ class EnvCredentialProvider(CredentialProvider):
51
50
  if self._token.is_going_to_be_expired():
52
51
  logger.info("Refreshing access token")
53
52
  self._token = self._auth_service.refresh_token(
54
- self._token, self.base_url
53
+ self._token, self.tfy_host
55
54
  )
56
55
  return self._token
57
56
 
58
57
  @property
59
- def base_url(self) -> str:
58
+ def tfy_host(self) -> str:
60
59
  return self._host
61
60
 
62
61
 
@@ -70,7 +69,7 @@ class FileCredentialProvider(CredentialProvider):
70
69
  self._token = self._last_cred_file_content.to_token()
71
70
  self._host = self._last_cred_file_content.host
72
71
 
73
- self._auth_service = AuthServiceClient.from_base_url(base_url=self._host)
72
+ self._auth_service = AuthServiceClient.from_tfy_host(tfy_host=self._host)
74
73
 
75
74
  @staticmethod
76
75
  def can_provide() -> bool:
@@ -91,7 +90,7 @@ class FileCredentialProvider(CredentialProvider):
91
90
 
92
91
  if new_cred_file_content == self._last_cred_file_content:
93
92
  self._token = self._auth_service.refresh_token(
94
- self._token, self.base_url
93
+ self._token, self.tfy_host
95
94
  )
96
95
  self._last_cred_file_content = CredentialsFileContent(
97
96
  host=self._host,
@@ -115,5 +114,5 @@ class FileCredentialProvider(CredentialProvider):
115
114
  )
116
115
 
117
116
  @property
118
- def base_url(self) -> str:
117
+ def tfy_host(self) -> str:
119
118
  return self._host
truefoundry/common/exceptions.py CHANGED
@@ -1,12 +1,16 @@
1
1
  from typing import Optional
2
2
 
3
3
 
4
- class BadRequestException(Exception):
5
- def __init__(self, status_code: int, message: Optional[str] = None):
6
- super().__init__()
7
- self.status_code = status_code
8
- self.message = message
4
+ # TODO (chiragjn): We need to establish uniform exception handling across codebase
5
+ class HttpRequestException(Exception):
6
+ def __init__(self, message: str, status_code: Optional[int] = None):
7
+ self.message = str(message)
9
8
  self.status_code = status_code
9
+ super().__init__(message)
10
+
11
+ def __str__(self) -> str:
12
+ return self.message or ""
13
+
10
14
 
11
- def __str__(self):
12
- return self.message
15
+ class BadRequestException(HttpRequestException):
16
+ pass
truefoundry/common/request_utils.py CHANGED
@@ -1,30 +1,40 @@
1
- from typing import Optional
1
+ import json
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, Optional, Type, Union
2
4
 
3
5
  import requests
4
6
  from requests import Response
5
7
  from requests.adapters import HTTPAdapter
6
8
  from urllib3 import Retry
7
9
 
8
- from truefoundry.common.exceptions import BadRequestException
10
+ from truefoundry.common.exceptions import BadRequestException, HttpRequestException
9
11
 
10
12
 
11
- def request_handling(response: Response):
12
- try:
13
- status_code = response.status_code
14
- except Exception as e:
15
- raise Exception("Unknown error occurred. Couldn't get status code.") from e
13
+ # TODO (chiragjn): Rename this to json_response_handling
14
+ def request_handling(response: Response) -> Optional[Any]:
15
+ status_code = response.status_code
16
16
  if 200 <= status_code <= 299:
17
17
  if response.content == b"":
18
+ # TODO (chiragjn): Do we really need this empty check?
18
19
  return None
19
- return response.json()
20
- if 400 <= status_code <= 499:
21
20
  try:
22
- message = str(response.json()["message"])
23
- except Exception:
24
- message = response.text
25
- raise BadRequestException(status_code=response.status_code, message=message)
21
+ return response.json()
22
+ except json.JSONDecodeError as je:
23
+ raise ValueError(
24
+ f"Failed to parse response as json. Response: {response.text}",
25
+ ) from je
26
+
27
+ try:
28
+ message = str(response.json()["message"])
29
+ except Exception:
30
+ message = response.text
31
+
32
+ if 400 <= status_code <= 499:
33
+ raise BadRequestException(message=message, status_code=response.status_code)
26
34
  if 500 <= status_code <= 599:
27
- raise Exception(response.content)
35
+ raise HttpRequestException(message=message, status_code=response.status_code)
36
+
37
+ raise HttpRequestException("Unknown error occurred", status_code=status_code)
28
38
 
29
39
 
30
40
  def urllib3_retry(
@@ -82,3 +92,75 @@ def requests_retry_session(
82
92
  session.mount("http://", adapter)
83
93
  session.mount("https://", adapter)
84
94
  return session
95
+
96
+
97
+ def http_request(
98
+ *,
99
+ method: str,
100
+ url: str,
101
+ token: Optional[str] = None,
102
+ timeout=None,
103
+ headers: Optional[Dict[str, str]] = None,
104
+ session: Optional[requests.Session] = None,
105
+ exception_class: Type[HttpRequestException] = HttpRequestException,
106
+ **kwargs,
107
+ ) -> requests.Response:
108
+ session = session or requests.Session()
109
+ headers = headers or {}
110
+ if token is not None:
111
+ headers["Authorization"] = f"Bearer {token}"
112
+ try:
113
+ response = session.request(
114
+ method=method, url=url, headers=headers, timeout=timeout, **kwargs
115
+ )
116
+ return response
117
+ except requests.exceptions.ConnectionError as ce:
118
+ raise exception_class("Failed to connect to TrueFoundry") from ce
119
+ except requests.exceptions.Timeout as te:
120
+ raise exception_class(f"Request to {url} timed out") from te
121
+ except Exception as e:
122
+ raise exception_class(f"Request to {url} failed with error {str(e)}") from e
123
+
124
+
125
+ @contextmanager
126
+ def cloud_storage_http_request(
127
+ *,
128
+ method: str,
129
+ url: str,
130
+ session: Optional[requests.Session] = None,
131
+ timeout: Optional[Union[int, float]] = None,
132
+ exception_class: Type[HttpRequestException] = HttpRequestException,
133
+ **kwargs,
134
+ ):
135
+ """
136
+ Performs an HTTP PUT/GET request using Python's `requests` module with automatic retry.
137
+ """
138
+ # Note: This does not support auth and is only meant to be used for pre-signed URLs.
139
+ session = session or requests_retry_session(retries=5, backoff_factor=0.5)
140
+ headers = kwargs.get("headers", {}) or {}
141
+ if "blob.core.windows.net" in url:
142
+ headers.update({"x-ms-blob-type": "BlockBlob"})
143
+ if method.lower() not in ("put", "get"):
144
+ raise ValueError(f"Illegal http method: {method}")
145
+ yield http_request(
146
+ method=method,
147
+ url=url,
148
+ session=session,
149
+ timeout=timeout,
150
+ headers=headers,
151
+ exception_class=exception_class,
152
+ **kwargs,
153
+ )
154
+
155
+
156
+ # TODO: Try and eliminate this
157
+ def augmented_raise_for_status(
158
+ response, exception_class: Type[HttpRequestException] = HttpRequestException
159
+ ):
160
+ try:
161
+ response.raise_for_status()
162
+ except requests.exceptions.HTTPError as he:
163
+ raise exception_class(
164
+ message=f"Request failed with status code {he.response.status_code}. Response: {he.response.text}",
165
+ status_code=he.response.status_code,
166
+ ) from he
truefoundry/common/servicefoundry_client.py CHANGED
@@ -13,7 +13,10 @@ from truefoundry.common.entities import (
13
13
  PythonSDKConfig,
14
14
  TenantInfo,
15
15
  )
16
- from truefoundry.common.request_utils import request_handling, requests_retry_session
16
+ from truefoundry.common.request_utils import (
17
+ request_handling,
18
+ requests_retry_session,
19
+ )
17
20
  from truefoundry.common.utils import (
18
21
  get_tfy_servers_config,
19
22
  timed_lru_cache,
@@ -25,19 +28,8 @@ from truefoundry.version import __version__
25
28
  def check_min_cli_version(fn):
26
29
  @functools.wraps(fn)
27
30
  def inner(*args, **kwargs):
28
- if __version__ != "0.0.0":
29
- client: "ServiceFoundryServiceClient" = args[0]
30
- # "0.0.0" indicates dev version
31
- # noinspection PyProtectedMember
32
- min_cli_version_required = client._min_cli_version_required
33
- if version.parse(__version__) < version.parse(min_cli_version_required):
34
- raise Exception(
35
- "You are using an outdated version of `truefoundry`.\n"
36
- f"Run `pip install truefoundry>={min_cli_version_required}` to install the supported version.",
37
- )
38
- else:
39
- logger.debug("Ignoring minimum cli version check")
40
-
31
+ client: "ServiceFoundryServiceClient" = args[0]
32
+ client.check_min_cli_version()
41
33
  return fn(*args, **kwargs)
42
34
 
43
35
  return inner
@@ -51,34 +43,33 @@ def session_with_retries(
51
43
 
52
44
  @timed_lru_cache(seconds=30 * 60)
53
45
  def _cached_get_tenant_info(api_server_url: str, tenant_host: str) -> TenantInfo:
54
- res = session_with_retries().get(
46
+ response = session_with_retries().get(
55
47
  url=f"{api_server_url}/{VERSION_PREFIX}/tenant-id",
56
48
  params={"hostName": tenant_host},
57
49
  )
58
- res = request_handling(res)
59
- return TenantInfo.parse_obj(res)
50
+ response_data = request_handling(response)
51
+ return TenantInfo.parse_obj(response_data)
60
52
 
61
53
 
62
54
  @timed_lru_cache(seconds=30 * 60)
63
55
  def _cached_get_python_sdk_config(api_server_url: str) -> PythonSDKConfig:
64
- res = session_with_retries().get(
56
+ response = session_with_retries().get(
65
57
  url=f"{api_server_url}/{VERSION_PREFIX}/min-cli-version"
66
58
  )
67
- res = request_handling(res)
68
- return PythonSDKConfig.parse_obj(res)
59
+ response_data = request_handling(response)
60
+ return PythonSDKConfig.parse_obj(response_data)
69
61
 
70
62
 
71
63
  class ServiceFoundryServiceClient:
72
- # TODO (chiragjn): Rename base_url to tfy_host
73
- def __init__(self, base_url: str):
74
- self._base_url = base_url.strip("/")
75
- tfy_servers_config = get_tfy_servers_config(self._base_url)
64
+ def __init__(self, tfy_host: str):
65
+ self._tfy_host = tfy_host.strip("/")
66
+ tfy_servers_config = get_tfy_servers_config(self._tfy_host)
76
67
  self._tenant_host = tfy_servers_config.tenant_host
77
68
  self._api_server_url = tfy_servers_config.servicefoundry_server_url
78
69
 
79
70
  @property
80
- def base_url(self) -> str:
81
- return self._base_url
71
+ def tfy_host(self) -> str:
72
+ return self._tfy_host
82
73
 
83
74
  @property
84
75
  def tenant_info(self) -> TenantInfo:
@@ -95,6 +86,17 @@ class ServiceFoundryServiceClient:
95
86
 
96
87
  @functools.cached_property
97
88
  def _min_cli_version_required(self) -> str:
98
- return _cached_get_python_sdk_config(
99
- self._api_server_url
100
- ).truefoundry_cli_min_version
89
+ return self.python_sdk_config.truefoundry_cli_min_version
90
+
91
+ def check_min_cli_version(self) -> None:
92
+ if __version__ != "0.0.0":
93
+ # "0.0.0" indicates dev version
94
+ # noinspection PyProtectedMember
95
+ min_cli_version_required = self._min_cli_version_required
96
+ if version.parse(__version__) < version.parse(min_cli_version_required):
97
+ raise Exception(
98
+ "You are using an outdated version of `truefoundry`.\n"
99
+ f"Run `pip install truefoundry>={min_cli_version_required}` to install the supported version.",
100
+ )
101
+ else:
102
+ logger.debug("Ignoring minimum cli version check")
truefoundry/common/session.py ADDED
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import threading
4
+ from typing import Optional
5
+
6
+ from truefoundry.common.credential_provider import (
7
+ CredentialProvider,
8
+ EnvCredentialProvider,
9
+ FileCredentialProvider,
10
+ )
11
+ from truefoundry.common.entities import UserInfo
12
+ from truefoundry.common.utils import relogin_error_message
13
+ from truefoundry.logger import logger
14
+
15
+ SESSION_LOCK = threading.RLock()
16
+ ACTIVE_SESSION: Optional["Session"] = None
17
+
18
+
19
+ class Session:
20
+ def __init__(self) -> None:
21
+ self._closed = False
22
+ self._cred_provider: Optional[CredentialProvider] = self._get_cred_provider()
23
+ self._user_info: Optional[UserInfo] = self._cred_provider.token.to_user_info()
24
+
25
+ @staticmethod
26
+ def _get_cred_provider() -> CredentialProvider:
27
+ final_cred_provider = None
28
+ for cred_provider in [EnvCredentialProvider, FileCredentialProvider]:
29
+ if cred_provider.can_provide():
30
+ final_cred_provider = cred_provider()
31
+ break
32
+ if final_cred_provider is None:
33
+ raise Exception(
34
+ relogin_error_message(
35
+ "No active session found. Perhaps you are not logged in?",
36
+ )
37
+ )
38
+ return final_cred_provider
39
+
40
+ @classmethod
41
+ def new(cls) -> "Session":
42
+ global ACTIVE_SESSION
43
+ with SESSION_LOCK:
44
+ new_session = cls()
45
+ if ACTIVE_SESSION and ACTIVE_SESSION == new_session:
46
+ return ACTIVE_SESSION
47
+
48
+ if ACTIVE_SESSION:
49
+ ACTIVE_SESSION.close()
50
+
51
+ ACTIVE_SESSION = new_session
52
+ logger.info(
53
+ "Logged in to %r as %r (%s)",
54
+ new_session.tfy_host,
55
+ new_session.user_info.user_id,
56
+ new_session.user_info.email or new_session.user_info.user_type.value,
57
+ )
58
+
59
+ return ACTIVE_SESSION
60
+
61
+ def close(self):
62
+ self._closed = True
63
+ self._user_info = None
64
+ self._cred_provider = None
65
+
66
+ def _assert_not_closed(self):
67
+ if self._closed:
68
+ raise Exception("This session has been deactivated.")
69
+
70
+ @property
71
+ def access_token(self) -> str:
72
+ assert self._cred_provider is not None
73
+ return self._cred_provider.token.access_token
74
+
75
+ @property
76
+ def tfy_host(self) -> str:
77
+ assert self._cred_provider is not None
78
+ return self._cred_provider.tfy_host
79
+
80
+ @property
81
+ def user_info(self) -> UserInfo:
82
+ self._assert_not_closed()
83
+ assert self._user_info is not None
84
+ return self._user_info
85
+
86
+ def __eq__(self, other: object) -> bool:
87
+ if not isinstance(other, Session):
88
+ return False
89
+ return (
90
+ type(self._cred_provider) == type(other._cred_provider) # noqa: E721
91
+ and self.user_info == other.user_info
92
+ and self.tfy_host == other.tfy_host
93
+ )