truefoundry 0.4.0rc1__py3-none-any.whl → 0.4.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truefoundry might be problematic. See the package's registry page for more details.

Files changed (83):
  1. truefoundry/__init__.py +2 -0
  2. truefoundry/autodeploy/agents/developer.py +1 -1
  3. truefoundry/autodeploy/agents/project_identifier.py +2 -2
  4. truefoundry/autodeploy/agents/tester.py +1 -1
  5. truefoundry/autodeploy/cli.py +1 -1
  6. truefoundry/autodeploy/tools/list_files.py +1 -1
  7. truefoundry/{deploy/lib/auth → common}/auth_service_client.py +50 -40
  8. truefoundry/common/constants.py +12 -0
  9. truefoundry/{deploy/lib/auth → common}/credential_file_manager.py +7 -7
  10. truefoundry/{deploy/lib/auth → common}/credential_provider.py +9 -12
  11. truefoundry/{ml/services → common}/entities.py +57 -41
  12. truefoundry/common/exceptions.py +12 -0
  13. truefoundry/common/request_utils.py +36 -8
  14. truefoundry/common/servicefoundry_client.py +91 -0
  15. truefoundry/common/utils.py +56 -0
  16. truefoundry/deploy/auto_gen/models.py +4 -6
  17. truefoundry/deploy/cli/cli.py +2 -0
  18. truefoundry/deploy/cli/commands/apply_command.py +1 -1
  19. truefoundry/deploy/cli/commands/build_command.py +1 -1
  20. truefoundry/deploy/cli/commands/deploy_command.py +1 -1
  21. truefoundry/deploy/cli/commands/login_command.py +2 -2
  22. truefoundry/deploy/cli/commands/patch_application_command.py +1 -1
  23. truefoundry/deploy/cli/commands/patch_command.py +1 -1
  24. truefoundry/deploy/cli/commands/terminate_comand.py +1 -1
  25. truefoundry/deploy/cli/util.py +1 -1
  26. truefoundry/deploy/function_service/remote/remote.py +1 -1
  27. truefoundry/deploy/lib/auth/servicefoundry_session.py +2 -2
  28. truefoundry/deploy/lib/clients/servicefoundry_client.py +120 -150
  29. truefoundry/deploy/lib/const.py +1 -35
  30. truefoundry/deploy/lib/exceptions.py +0 -11
  31. truefoundry/deploy/lib/model/entity.py +1 -112
  32. truefoundry/deploy/lib/session.py +13 -26
  33. truefoundry/deploy/lib/util.py +0 -37
  34. truefoundry/deploy/v2/lib/deploy.py +3 -3
  35. truefoundry/deploy/v2/lib/deployable_patched_models.py +1 -1
  36. truefoundry/ml/__init__.py +0 -9
  37. truefoundry/ml/artifact/truefoundry_artifact_repo.py +63 -22
  38. truefoundry/ml/autogen/client/__init__.py +0 -3
  39. truefoundry/ml/autogen/client/api/experiments_api.py +0 -165
  40. truefoundry/ml/autogen/client/models/__init__.py +0 -3
  41. truefoundry/ml/autogen/client/models/artifact_dto.py +6 -6
  42. truefoundry/ml/autogen/client/models/artifact_version_dto.py +8 -8
  43. truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +2 -3
  44. truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +2 -3
  45. truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +2 -2
  46. truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +2 -3
  47. truefoundry/ml/autogen/client/models/create_run_request_dto.py +5 -5
  48. truefoundry/ml/autogen/client/models/create_run_response_dto.py +2 -3
  49. truefoundry/ml/autogen/client/models/dataset_dto.py +10 -10
  50. truefoundry/ml/autogen/client/models/experiment_dto.py +18 -18
  51. truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +2 -3
  52. truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +4 -5
  53. truefoundry/ml/autogen/client/models/model_dto.py +6 -6
  54. truefoundry/ml/autogen/client/models/model_version_dto.py +15 -8
  55. truefoundry/ml/autogen/client/models/run_info_dto.py +10 -10
  56. truefoundry/ml/autogen/client/models/update_run_response_dto.py +2 -3
  57. truefoundry/ml/autogen/client_README.md +0 -2
  58. truefoundry/ml/clients/entities.py +8 -0
  59. truefoundry/ml/{services/servicefoundry_service.py → clients/servicefoundry_client.py} +20 -10
  60. truefoundry/ml/{services → clients}/utils.py +2 -2
  61. truefoundry/ml/entities.py +62 -0
  62. truefoundry/ml/env_vars.py +1 -5
  63. truefoundry/ml/internal_namespace.py +8 -8
  64. truefoundry/ml/log_types/artifacts/artifact.py +7 -3
  65. truefoundry/ml/log_types/artifacts/dataset.py +5 -4
  66. truefoundry/ml/log_types/artifacts/model.py +8 -25
  67. truefoundry/ml/log_types/image/image.py +7 -8
  68. truefoundry/ml/log_types/image/image_normalizer.py +7 -6
  69. truefoundry/ml/mlfoundry_api.py +10 -55
  70. truefoundry/ml/mlfoundry_run.py +6 -41
  71. truefoundry/ml/run_utils.py +1 -10
  72. truefoundry/ml/session.py +14 -117
  73. truefoundry/pydantic_v1.py +1 -1
  74. truefoundry/workflow/__init__.py +16 -1
  75. {truefoundry-0.4.0rc1.dist-info → truefoundry-0.4.0rc3.dist-info}/METADATA +1 -1
  76. {truefoundry-0.4.0rc1.dist-info → truefoundry-0.4.0rc3.dist-info}/RECORD +79 -77
  77. truefoundry/deploy/lib/clients/utils.py +0 -41
  78. truefoundry/ml/autogen/client/models/backfill_default_storage_integration_id_request_dto.py +0 -67
  79. truefoundry/ml/login.py +0 -241
  80. truefoundry/ml/services/auth_service.py +0 -109
  81. truefoundry/ml/{services → clients}/__init__.py +0 -0
  82. {truefoundry-0.4.0rc1.dist-info → truefoundry-0.4.0rc3.dist-info}/WHEEL +0 -0
  83. {truefoundry-0.4.0rc1.dist-info → truefoundry-0.4.0rc3.dist-info}/entry_points.txt +0 -0
@@ -15,23 +15,22 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
15
15
  CreateExperimentRequestDto,
16
16
  CreateRunRequestDto,
17
17
  DatasetDto,
18
+ ExperimentsApi,
18
19
  ListArtifactsRequestDto,
19
20
  ListArtifactVersionsRequestDto,
20
21
  ListDatasetsRequestDto,
21
22
  ListModelVersionsRequestDto,
23
+ MlfoundryArtifactsApi,
22
24
  ModelDto,
25
+ RunsApi,
23
26
  RunTagDto,
24
27
  SearchRunsRequestDto,
25
28
  )
26
- from truefoundry.ml.autogen.client.api import ( # type: ignore[attr-defined]
27
- ExperimentsApi,
28
- MlfoundryArtifactsApi,
29
- RunsApi,
30
- )
31
29
  from truefoundry.ml.autogen.client.exceptions import (
32
30
  ApiException,
33
31
  NotFoundException,
34
32
  )
33
+ from truefoundry.ml.clients.servicefoundry_client import ServiceFoundryServiceClient
35
34
  from truefoundry.ml.enums import ModelFramework, ViewType
36
35
  from truefoundry.ml.exceptions import MlFoundryException
37
36
  from truefoundry.ml.internal_namespace import NAMESPACE
@@ -39,10 +38,8 @@ from truefoundry.ml.log_types.artifacts.artifact import ArtifactPath, ArtifactVe
39
38
  from truefoundry.ml.log_types.artifacts.dataset import DataDirectory
40
39
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
41
40
  from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
42
- from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
43
41
  from truefoundry.ml.logger import logger
44
42
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
45
- from truefoundry.ml.services.servicefoundry_service import ServicefoundryService
46
43
  from truefoundry.ml.session import (
47
44
  Session,
48
45
  _get_api_client,
@@ -126,7 +123,6 @@ class MlFoundry:
126
123
  )
127
124
  raise MlFoundryException(err_msg) from e
128
125
 
129
- assert ml_repo_instance.experiment_id is not None
130
126
  return ml_repo_instance.experiment_id
131
127
 
132
128
  def list_ml_repos(self) -> List[str]:
@@ -216,7 +212,7 @@ class MlFoundry:
216
212
  "No active session found. Perhaps you are not logged in?\n"
217
213
  "Please log in using `tfy login [--host HOST] --relogin"
218
214
  )
219
- servicefoundry_service = ServicefoundryService(
215
+ servicefoundry_client = ServiceFoundryServiceClient(
220
216
  tracking_uri=self.get_tracking_uri(),
221
217
  token=session.token.access_token,
222
218
  )
@@ -224,7 +220,7 @@ class MlFoundry:
224
220
  assert existing_ml_repo.storage_integration_id is not None
225
221
  try:
226
222
  existing_storage_integration = (
227
- servicefoundry_service.get_integration_from_id(
223
+ servicefoundry_client.get_integration_from_id(
228
224
  existing_ml_repo.storage_integration_id
229
225
  )
230
226
  )
@@ -338,11 +334,6 @@ class MlFoundry:
338
334
  )
339
335
  )
340
336
  run = _run.run
341
-
342
- assert run is not None
343
- assert run.info.run_id is not None
344
- assert run.info.fqn is not None
345
-
346
337
  mlf_run_id = run.info.run_id
347
338
  kwargs.setdefault("auto_end", True)
348
339
  mlf_run = MlFoundryRun(experiment_id=ml_repo_id, run_id=mlf_run_id, **kwargs)
@@ -381,9 +372,6 @@ class MlFoundry:
381
372
  return self.get_run_by_fqn(run_id)
382
373
  _run = self._runs_api.get_run_get(run_id=run_id)
383
374
  run = _run.run
384
-
385
- assert run is not None
386
-
387
375
  mlfoundry_run = MlFoundryRun._from_dto(run)
388
376
  logger.info(
389
377
  f"Link to the dashboard for the run: {mlfoundry_run.dashboard_link}"
@@ -725,7 +713,6 @@ class MlFoundry:
725
713
  version=resolved_version,
726
714
  )
727
715
  model_version = _model_version.model_version
728
- assert model_version.model_id is not None
729
716
  _model = self._mlfoundry_artifacts_api.get_model_get(id=model_version.model_id)
730
717
  model = _model.model
731
718
 
@@ -868,7 +855,6 @@ class MlFoundry:
868
855
  model: Optional[ModelDto] = None,
869
856
  ) -> Iterator[ModelVersion]:
870
857
  if model and not model_id:
871
- assert model.id is not None
872
858
  model_id = model.id
873
859
  elif not model and model_id:
874
860
  _model = self._mlfoundry_artifacts_api.get_model_get(id=str(model_id))
@@ -943,7 +929,6 @@ class MlFoundry:
943
929
  )
944
930
  )
945
931
  artifact_version = _artifact_version.artifact_version
946
- assert artifact_version.artifact_id is not None
947
932
  _artifact = self._mlfoundry_artifacts_api.get_artifact_by_id_get(
948
933
  id=artifact_version.artifact_id
949
934
  )
@@ -1072,7 +1057,6 @@ class MlFoundry:
1072
1057
  artifact: Optional[ArtifactDto] = None,
1073
1058
  ) -> Iterator[ArtifactVersion]:
1074
1059
  if artifact and not artifact_id:
1075
- assert artifact.id is not None
1076
1060
  artifact_id = artifact.id
1077
1061
  elif not artifact and artifact_id:
1078
1062
  _artifact = self._mlfoundry_artifacts_api.get_artifact_by_id_get(
@@ -1220,8 +1204,7 @@ class MlFoundry:
1220
1204
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
1221
1205
  description: Optional[str] = None,
1222
1206
  metadata: Optional[Dict[str, Any]] = None,
1223
- model_schema: Optional[Union[ModelSchema, Dict[str, Any]]] = None,
1224
- custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
1207
+ progress: Optional[bool] = None,
1225
1208
  ) -> ModelVersion:
1226
1209
  """
1227
1210
  Serialize and log a versioned model under the current ml_repo. Each logged model generates a new version
@@ -1277,34 +1260,7 @@ class MlFoundry:
1277
1260
  metadata (Optional[Dict[str, Any]], optional): arbitrary json serializable dictionary to store metadata.
1278
1261
  For example, you can use this to store metrics, params, notes.
1279
1262
  This field can be updated at any time after logging. Defaults to `None`
1280
- model_schema (Optional[Union[Dict[str, Any], ModelSchema]], optional):
1281
- instance of `truefoundry.ml.ModelSchema`.
1282
- This schema needs to be consistent with older versions of the model under the given `name` i.e.
1283
- a feature's value type and model's prediction type cannot be changed in the schema of new version.
1284
- Features can be removed or added between versions.
1285
- ```
1286
- E.g. if there exists a v1 with
1287
- schema = {"features": {"name": "feat1": "int"}, "prediction": "categorical"}, then
1288
-
1289
- schema = {"features": {"name": "feat1": "string"}, "prediction": "categorical"} or
1290
- schema = {"features": {"name": "feat1": "int"}, "prediction": "numerical"}
1291
- are invalid because they change the types of existing features and prediction
1292
-
1293
- while
1294
- schema = {"features": {"name": "feat1": "int", "feat2": "string"}, "prediction": "categorical"} or
1295
- schema = {"features": {"feat2": "string"}, "prediction": "categorical"}
1296
- are valid
1297
-
1298
- This field can be updated at any time after logging. Defaults to None
1299
- ```
1300
- custom_metrics: (Optional[Union[List[Dict[str, Any]], CustomMetric]], optional): list of instances of
1301
- `truefoundry.ml.CustomMetric`
1302
- The custom metrics must be added according to the prediction type of schema.
1303
- custom_metrics = [{
1304
- "name": "mean_square_error",
1305
- "type": "metric",
1306
- "value_type": "float"
1307
- }]
1263
+ progress (bool): value to show progress bar, defaults to None.
1308
1264
 
1309
1265
  Returns:
1310
1266
  truefoundry.ml.ModelVersion: an instance of `ModelVersion` that can be used to download the files,
@@ -1394,9 +1350,8 @@ class MlFoundry:
1394
1350
  additional_files=additional_files,
1395
1351
  description=description,
1396
1352
  metadata=metadata,
1397
- model_schema=model_schema,
1398
- custom_metrics=custom_metrics,
1399
1353
  step=None,
1354
+ progress=progress,
1400
1355
  )
1401
1356
  logger.info(f"Logged model successfully with fqn {model_version.fqn!r}")
1402
1357
  return model_version
@@ -1415,7 +1370,7 @@ class MlFoundry:
1415
1370
  Args:
1416
1371
  ml_repo (str): Name of the ML Repo in which you want to create data_directory
1417
1372
  name (str): Name of the DataDirectory to be created.
1418
- description (str): Description of the Datset
1373
+ description (str): Description for the DataDirectory.
1419
1374
  metadata (Dict <str>: Any): Metadata about the data_directory in Dictionary form.
1420
1375
 
1421
1376
  Returns:
@@ -38,6 +38,7 @@ from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
38
38
  RunTagDto,
39
39
  UpdateRunRequestDto,
40
40
  )
41
+ from truefoundry.ml.entities import Metric
41
42
  from truefoundry.ml.enums import RunStatus
42
43
  from truefoundry.ml.exceptions import MlFoundryException
43
44
  from truefoundry.ml.internal_namespace import NAMESPACE
@@ -45,7 +46,6 @@ from truefoundry.ml.log_types import Image, Plot
45
46
  from truefoundry.ml.log_types.artifacts.artifact import ArtifactPath, ArtifactVersion
46
47
  from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
47
48
  from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
48
- from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
49
49
  from truefoundry.ml.logger import logger
50
50
  from truefoundry.ml.run_utils import ParamsType, flatten_dict, process_params
51
51
  from truefoundry.ml.session import ACTIVE_RUNS, _get_api_client, get_active_session
@@ -110,8 +110,6 @@ class MlFoundryRun:
110
110
  @classmethod
111
111
  def _from_dto(cls, run_dto: RunDto) -> "MlFoundryRun":
112
112
  """classmethod to get MLfoundry run from dto instance"""
113
- assert run_dto.info.experiment_id is not None
114
- assert run_dto.info.run_id is not None
115
113
  run = cls(run_dto.info.experiment_id, run_dto.info.run_id)
116
114
  run._run_info = run_dto.info
117
115
  run._run_data = run_dto.data
@@ -136,7 +134,6 @@ class MlFoundryRun:
136
134
  def run_name(self) -> str:
137
135
  """Get run_name for the current `run`"""
138
136
  run_info = self._get_run_info()
139
- assert run_info.name is not None
140
137
  return run_info.name
141
138
 
142
139
  @property
@@ -144,7 +141,6 @@ class MlFoundryRun:
144
141
  def fqn(self) -> str:
145
142
  """Get fqn for the current `run`"""
146
143
  run_info = self._get_run_info()
147
- assert run_info.fqn is not None
148
144
  return run_info.fqn
149
145
 
150
146
  @property
@@ -162,7 +158,6 @@ class MlFoundryRun:
162
158
  _experiment = self._experiments_api.get_experiment_get(
163
159
  experiment_id=self._experiment_id
164
160
  )
165
- assert _experiment.experiment.name is not None
166
161
  return _experiment.experiment.name
167
162
 
168
163
  @property
@@ -812,7 +807,7 @@ class MlFoundryRun:
812
807
  @_ensure_not_deleted
813
808
  def get_metrics(
814
809
  self, metric_names: Optional[Iterable[str]] = None
815
- ) -> Dict[str, List[MetricDto]]:
810
+ ) -> Dict[str, List[Metric]]:
816
811
  """Get metrics logged for the current `run` grouped by metric name.
817
812
 
818
813
  Args:
@@ -860,7 +855,7 @@ class MlFoundryRun:
860
855
  unknown_metrics = metric_names - run_metric_names
861
856
  if len(unknown_metrics) > 0:
862
857
  logger.warning(f"{unknown_metrics} metrics not present in the run")
863
- metrics_dict: Dict[str, List[MetricDto]] = {
858
+ metrics_dict: Dict[str, List[Metric]] = {
864
859
  metric_name: [] for metric_name in unknown_metrics
865
860
  }
866
861
  valid_metrics = metric_names - unknown_metrics
@@ -868,7 +863,9 @@ class MlFoundryRun:
868
863
  _metric_history = self._metrics_api.get_metric_history_get(
869
864
  run_id=self.run_id, metric_key=metric_name
870
865
  )
871
- metrics_dict[metric_name] = _metric_history.metrics
866
+ metrics_dict[metric_name] = [
867
+ Metric.from_dto(metric) for metric in _metric_history.metrics
868
+ ]
872
869
  return metrics_dict
873
870
 
874
871
  @_ensure_not_deleted
@@ -911,8 +908,6 @@ class MlFoundryRun:
911
908
  additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
912
909
  description: Optional[str] = None,
913
910
  metadata: Optional[Dict[str, Any]] = None,
914
- model_schema: Optional[Union[ModelSchema, Dict[str, Any]]] = None,
915
- custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
916
911
  step: int = 0,
917
912
  progress: Optional[bool] = None,
918
913
  ) -> ModelVersion:
@@ -965,34 +960,6 @@ class MlFoundryRun:
965
960
  metadata (Optional[Dict[str, Any]], optional): arbitrary json serializable dictionary to store metadata.
966
961
  For example, you can use this to store metrics, params, notes.
967
962
  This field can be updated at any time after logging. Defaults to `None`
968
- model_schema (Optional[Union[Dict[str, Any], ModelSchema]], optional):
969
- instance of `truefoundry.ml.ModelSchema`.
970
- This schema needs to be consistent with older versions of the model under the given `name` i.e.
971
- a feature's value type and model's prediction type cannot be changed in the schema of new version.
972
- Features can be removed or added between versions.
973
- ```
974
- E.g. if there exists a v1 with
975
- schema = {"features": {"name": "feat1": "int"}, "prediction": "categorical"}, then
976
-
977
- schema = {"features": {"name": "feat1": "string"}, "prediction": "categorical"} or
978
- schema = {"features": {"name": "feat1": "int"}, "prediction": "numerical"}
979
- are invalid because they change the types of existing features and prediction
980
-
981
- while
982
- schema = {"features": {"name": "feat1": "int", "feat2": "string"}, "prediction": "categorical"} or
983
- schema = {"features": {"feat2": "string"}, "prediction": "categorical"}
984
- are valid
985
-
986
- This field can be updated at any time after logging. Defaults to None
987
- ```
988
- custom_metrics: (Optional[Union[List[Dict[str, Any]], CustomMetric]], optional): list of instances of
989
- `truefoundry.ml.CustomMetric`
990
- The custom metrics must be added according to the prediction type of schema.
991
- custom_metrics = [{
992
- "name": "mean_square_error",
993
- "type": "metric",
994
- "value_type": "float"
995
- }]
996
963
  step (int): step/iteration at which the model is being logged, defaults to 0.
997
964
  progress (bool): value to show progress bar, defaults to None.
998
965
 
@@ -1084,8 +1051,6 @@ class MlFoundryRun:
1084
1051
  additional_files=additional_files,
1085
1052
  description=description,
1086
1053
  metadata=metadata,
1087
- model_schema=model_schema,
1088
- custom_metrics=custom_metrics,
1089
1054
  step=step,
1090
1055
  progress=progress,
1091
1056
  )
@@ -7,7 +7,6 @@ from urllib.parse import urljoin, urlsplit
7
7
 
8
8
  import numpy as np
9
9
 
10
- from truefoundry.ml import env_vars
11
10
  from truefoundry.ml.exceptions import MlFoundryException
12
11
 
13
12
 
@@ -22,21 +21,13 @@ def get_module(
22
21
  raise MlFoundryException(msg) from ex
23
22
 
24
23
 
25
- def resolve_tracking_uri(tracking_uri: Optional[str]):
26
- if not tracking_uri and not os.getenv(env_vars.TRACKING_HOST_GLOBAL):
27
- raise ValueError(
28
- f"Either `host` should be provided by --host <value>, or `{env_vars.TRACKING_HOST_GLOBAL}` env must be set"
29
- )
30
- return tracking_uri or os.getenv(env_vars.TRACKING_HOST_GLOBAL)
31
-
32
-
33
24
  def append_path_to_rest_tracking_uri(tracking_uri: str):
34
25
  if urlsplit(tracking_uri).netloc.startswith("localhost"):
35
26
  return tracking_uri
36
27
  return urljoin(tracking_uri, "/api/ml")
37
28
 
38
29
 
39
- def append_servicefoundry_path_to_tracking_ui(tracking_uri: str):
30
+ def append_servicefoundry_path_to_tracking_uri(tracking_uri: str):
40
31
  if urlsplit(tracking_uri).netloc.startswith("localhost"):
41
32
  return os.getenv("SERVICEFOUNDRY_SERVER_URL")
42
33
  return urljoin(tracking_uri, "/api/svc")
truefoundry/ml/session.py CHANGED
@@ -1,131 +1,29 @@
1
- import abc
2
1
  import atexit
3
- import os
4
2
  import threading
5
3
  import weakref
6
4
  from typing import TYPE_CHECKING, Dict, Optional
7
5
 
6
+ from truefoundry.common.credential_provider import (
7
+ CredentialProvider,
8
+ EnvCredentialProvider,
9
+ FileCredentialProvider,
10
+ )
11
+ from truefoundry.common.entities import Token, UserInfo
8
12
  from truefoundry.common.request_utils import urllib3_retry
9
- from truefoundry.ml import env_vars
10
13
  from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
11
14
  ApiClient,
12
15
  Configuration,
13
16
  )
17
+ from truefoundry.ml.clients.entities import HostCreds
14
18
  from truefoundry.ml.exceptions import MlFoundryException
15
19
  from truefoundry.ml.logger import logger
16
- from truefoundry.ml.login import CredentialsFileContent, CredentialsFileManager
17
20
  from truefoundry.ml.run_utils import (
18
21
  append_path_to_rest_tracking_uri,
19
- resolve_tracking_uri,
20
22
  )
21
- from truefoundry.ml.services.auth_service import get_auth_service
22
- from truefoundry.ml.services.entities import HostCreds, Token, UserInfo
23
23
 
24
24
  if TYPE_CHECKING:
25
25
  from truefoundry.ml.mlfoundry_run import MlFoundryRun
26
26
 
27
- TOKEN_REFRESH_LOCK = threading.RLock()
28
-
29
-
30
- class CredentialProvider(abc.ABC):
31
- @property
32
- @abc.abstractmethod
33
- def token(self) -> Token: ...
34
-
35
- @staticmethod
36
- @abc.abstractmethod
37
- def can_provide() -> bool: ...
38
-
39
- @property
40
- @abc.abstractmethod
41
- def tracking_uri(self) -> str: ...
42
-
43
-
44
- class EnvCredentialProvider(CredentialProvider):
45
- def __init__(self):
46
- logger.debug("Using env var credential provider")
47
- self._tracking_uri = resolve_tracking_uri(tracking_uri=None)
48
- self._auth_service = get_auth_service(tracking_uri=self._tracking_uri)
49
- api_key = os.getenv(env_vars.API_KEY_GLOBAL)
50
- if not api_key:
51
- raise MlFoundryException(
52
- f"Value of {env_vars.API_KEY_GLOBAL} env var should be non-empty string"
53
- )
54
- self._token: Token = Token(access_token=api_key, refresh_token=None) # type: ignore[call-arg]
55
-
56
- @staticmethod
57
- def can_provide() -> bool:
58
- return env_vars.API_KEY_GLOBAL in os.environ
59
-
60
- @property
61
- def token(self) -> Token:
62
- with TOKEN_REFRESH_LOCK:
63
- if self._token.is_going_to_be_expired():
64
- logger.info("Refreshing access token")
65
- self._token = self._auth_service.refresh_token(self._token)
66
- return self._token
67
-
68
- @property
69
- def tracking_uri(self) -> str:
70
- return self._tracking_uri
71
-
72
-
73
- class FileCredentialProvider(CredentialProvider):
74
- def __init__(self):
75
- logger.debug("Using file credential provider")
76
- self._cred_file = CredentialsFileManager()
77
-
78
- with self._cred_file:
79
- self._last_cred_file_content = self._cred_file.read()
80
- self._tracking_uri = self._last_cred_file_content.host
81
- self._token = self._last_cred_file_content.to_token()
82
- self._auth_service = get_auth_service(tracking_uri=self._tracking_uri)
83
-
84
- @staticmethod
85
- def can_provide() -> bool:
86
- with CredentialsFileManager() as cred_file:
87
- return cred_file.exists()
88
-
89
- @property
90
- def token(self) -> Token:
91
- with TOKEN_REFRESH_LOCK:
92
- if not self._token.is_going_to_be_expired():
93
- return self._token
94
-
95
- logger.info("Refreshing access token")
96
- with self._cred_file:
97
- new_cred_file_content = self._cred_file.read()
98
- new_token = new_cred_file_content.to_token()
99
- new_tracking_uri = new_cred_file_content.host
100
-
101
- if new_cred_file_content == self._last_cred_file_content:
102
- self._token = self._auth_service.refresh_token(self._token)
103
- self._last_cred_file_content = CredentialsFileContent(
104
- host=self._tracking_uri,
105
- access_token=self._token.access_token,
106
- refresh_token=self._token.refresh_token,
107
- )
108
- self._cred_file.write(self._last_cred_file_content)
109
- return self._token
110
-
111
- if (
112
- new_tracking_uri == self._tracking_uri
113
- and new_token.to_user_info() == self._token.to_user_info()
114
- ):
115
- self._last_cred_file_content = new_cred_file_content
116
- self._token = new_token
117
- # recursive
118
- return self.token
119
-
120
- raise MlFoundryException(
121
- "Credentials on disk changed while mlfoundry was running."
122
- )
123
-
124
- @property
125
- def tracking_uri(self) -> str:
126
- return self._tracking_uri
127
-
128
-
129
27
  SESSION_LOCK = threading.RLock()
130
28
 
131
29
 
@@ -159,8 +57,8 @@ class Session:
159
57
  def __init__(self, cred_provider: CredentialProvider):
160
58
  # Note: Whenever a new session is initialized all the active runs are ended
161
59
  self._closed = False
162
- self._cred_provider: CredentialProvider = cred_provider
163
- self._user_info: UserInfo = self._cred_provider.token.to_user_info()
60
+ self._cred_provider: Optional[CredentialProvider] = cred_provider
61
+ self._user_info: Optional[UserInfo] = self._cred_provider.token.to_user_info()
164
62
 
165
63
  def close(self):
166
64
  logger.debug("Closing existing session")
@@ -187,7 +85,7 @@ class Session:
187
85
 
188
86
  @property
189
87
  def tracking_uri(self) -> str:
190
- return self._cred_provider.tracking_uri
88
+ return self._cred_provider.base_url
191
89
 
192
90
  def __eq__(self, other: object) -> bool:
193
91
  if not isinstance(other, Session):
@@ -199,9 +97,7 @@ class Session:
199
97
  )
200
98
 
201
99
  def get_host_creds(self) -> HostCreds:
202
- tracking_uri = append_path_to_rest_tracking_uri(
203
- self._cred_provider.tracking_uri
204
- )
100
+ tracking_uri = append_path_to_rest_tracking_uri(self._cred_provider.base_url)
205
101
  return HostCreds(
206
102
  host=tracking_uri, token=self._cred_provider.token.access_token
207
103
  )
@@ -235,6 +131,7 @@ def _get_api_client(
235
131
  )
236
132
  configuration.retries = urllib3_retry(retries=2)
237
133
  api_client = ApiClient(configuration=configuration)
134
+ api_client.user_agent = "truefoundry-cli"
238
135
  return api_client
239
136
 
240
137
 
@@ -247,8 +144,8 @@ def init_session() -> Session:
247
144
  break
248
145
  if final_cred_provider is None:
249
146
  raise MlFoundryException(
250
- "Please login using `mlfoundry login` command "
251
- "or `truefoundry.ml.login()` function call"
147
+ "Please login using `tfy login` command "
148
+ "or `truefoundry.login()` function call"
252
149
  )
253
150
  new_session = Session(cred_provider=final_cred_provider)
254
151
 
@@ -5,7 +5,7 @@ try:
5
5
  except ImportError:
6
6
  # pydantic <=1.10.17
7
7
  from pydantic import * # noqa: F403
8
- from pydantic import utils # noqa: F401
8
+ from pydantic import ConstrainedStr, utils # noqa: F401
9
9
 
10
10
 
11
11
  class NonEmptyStr(ConstrainedStr):
@@ -1,5 +1,5 @@
1
1
  try:
2
- from flytekit import task
2
+ from flytekit import task as _
3
3
  except ImportError:
4
4
  print("To use workflows, please run 'pip install truefoundry[workflow]'.")
5
5
 
@@ -17,3 +17,18 @@ from truefoundry.workflow.map_task import map_task
17
17
  from truefoundry.workflow.python_task import PythonFunctionTask
18
18
  from truefoundry.workflow.task import task
19
19
  from truefoundry.workflow.workflow import ExecutionConfig, workflow
20
+
21
+ __all__ = [
22
+ "task",
23
+ "ContainerTask",
24
+ "PythonFunctionTask",
25
+ "map_task",
26
+ "workflow",
27
+ "conditional",
28
+ "FlyteDirectory",
29
+ "TaskDockerFileBuild",
30
+ "TaskPythonBuild",
31
+ "ContainerTaskConfig",
32
+ "PythonTaskConfig",
33
+ "ExecutionConfig",
34
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: truefoundry
3
- Version: 0.4.0rc1
3
+ Version: 0.4.0rc3
4
4
  Summary: Truefoundry CLI
5
5
  Author: Abhishek Choudhary
6
6
  Author-email: abhishek@truefoundry.com