truefoundry 0.3.3__py3-none-any.whl → 0.4.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- truefoundry/cli/__main__.py +3 -17
- truefoundry/common/__init__.py +0 -0
- truefoundry/common/request_utils.py +56 -0
- truefoundry/deploy/cli/cli.py +1 -1
- truefoundry/deploy/cli/util.py +3 -1
- truefoundry/deploy/lib/auth/credential_provider.py +2 -12
- truefoundry/deploy/lib/clients/servicefoundry_client.py +0 -9
- truefoundry/deploy/lib/exceptions.py +1 -6
- truefoundry/deploy/lib/session.py +1 -16
- truefoundry/langchain/truefoundry_chat.py +1 -1
- truefoundry/langchain/truefoundry_embeddings.py +1 -1
- truefoundry/langchain/truefoundry_llm.py +1 -1
- truefoundry/langchain/utils.py +0 -41
- truefoundry/ml/__init__.py +46 -6
- truefoundry/ml/artifact/__init__.py +0 -0
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +1120 -0
- truefoundry/ml/autogen/__init__.py +0 -0
- truefoundry/ml/autogen/client/__init__.py +373 -0
- truefoundry/ml/autogen/client/api/__init__.py +16 -0
- truefoundry/ml/autogen/client/api/auth_api.py +184 -0
- truefoundry/ml/autogen/client/api/deprecated_api.py +605 -0
- truefoundry/ml/autogen/client/api/experiments_api.py +2109 -0
- truefoundry/ml/autogen/client/api/health_api.py +299 -0
- truefoundry/ml/autogen/client/api/metrics_api.py +371 -0
- truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +7213 -0
- truefoundry/ml/autogen/client/api/python_deployment_config_api.py +201 -0
- truefoundry/ml/autogen/client/api/run_artifacts_api.py +231 -0
- truefoundry/ml/autogen/client/api/runs_api.py +2919 -0
- truefoundry/ml/autogen/client/api_client.py +822 -0
- truefoundry/ml/autogen/client/api_response.py +30 -0
- truefoundry/ml/autogen/client/configuration.py +489 -0
- truefoundry/ml/autogen/client/exceptions.py +161 -0
- truefoundry/ml/autogen/client/models/__init__.py +344 -0
- truefoundry/ml/autogen/client/models/add_custom_metrics_to_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/agent.py +125 -0
- truefoundry/ml/autogen/client/models/agent_app.py +118 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool.py +143 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +144 -0
- truefoundry/ml/autogen/client/models/agent_with_fqn.py +127 -0
- truefoundry/ml/autogen/client/models/artifact_dto.py +115 -0
- truefoundry/ml/autogen/client/models/artifact_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/artifact_type.py +39 -0
- truefoundry/ml/autogen/client/models/artifact_version_dto.py +141 -0
- truefoundry/ml/autogen/client/models/artifact_version_response_dto.py +77 -0
- truefoundry/ml/autogen/client/models/artifact_version_status.py +35 -0
- truefoundry/ml/autogen/client/models/assistant_message.py +89 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/backfill_default_storage_integration_id_request_dto.py +67 -0
- truefoundry/ml/autogen/client/models/blob_storage_reference.py +93 -0
- truefoundry/ml/autogen/client/models/body_get_search_runs_get.py +72 -0
- truefoundry/ml/autogen/client/models/chat_prompt.py +156 -0
- truefoundry/ml/autogen/client/models/chat_prompt_messages_inner.py +171 -0
- truefoundry/ml/autogen/client/models/columns_dto.py +73 -0
- truefoundry/ml/autogen/client/models/content.py +153 -0
- truefoundry/ml/autogen/client/models/content1.py +153 -0
- truefoundry/ml/autogen/client/models/content2.py +174 -0
- truefoundry/ml/autogen/client/models/content2_any_of_inner.py +150 -0
- truefoundry/ml/autogen/client/models/create_artifact_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +66 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +66 -0
- truefoundry/ml/autogen/client/models/create_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/create_experiment_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/create_experiment_response_dto.py +67 -0
- truefoundry/ml/autogen/client/models/create_model_version_request_dto.py +95 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +72 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +68 -0
- truefoundry/ml/autogen/client/models/create_run_request_dto.py +97 -0
- truefoundry/ml/autogen/client/models/create_run_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/dataset_dto.py +112 -0
- truefoundry/ml/autogen/client/models/dataset_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/delete_artifact_versions_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/delete_model_version_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_run_request.py +65 -0
- truefoundry/ml/autogen/client/models/delete_tag_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/experiment_dto.py +127 -0
- truefoundry/ml/autogen/client/models/experiment_id_request_dto.py +67 -0
- truefoundry/ml/autogen/client/models/experiment_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/experiment_tag_dto.py +69 -0
- truefoundry/ml/autogen/client/models/feature_dto.py +68 -0
- truefoundry/ml/autogen/client/models/feature_value_type.py +35 -0
- truefoundry/ml/autogen/client/models/file_info_dto.py +76 -0
- truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +101 -0
- truefoundry/ml/autogen/client/models/get_experiment_response_dto.py +88 -0
- truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/get_metric_history_response.py +79 -0
- truefoundry/ml/autogen/client/models/get_signed_url_for_dataset_write_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_response_dto.py +83 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_write_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +74 -0
- truefoundry/ml/autogen/client/models/http_validation_error.py +82 -0
- truefoundry/ml/autogen/client/models/image_content_part.py +87 -0
- truefoundry/ml/autogen/client/models/image_url.py +75 -0
- truefoundry/ml/autogen/client/models/internal_metadata.py +180 -0
- truefoundry/ml/autogen/client/models/latest_run_log_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_request_dto.py +107 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_artifacts_request_dto.py +96 -0
- truefoundry/ml/autogen/client/models/list_artifacts_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_colums_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/list_datasets_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_datasets_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_experiments_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_version_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_versions_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_latest_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_metric_history_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/list_metric_history_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_model_version_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_model_versions_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/list_models_request_dto.py +89 -0
- truefoundry/ml/autogen/client/models/list_models_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_artifacts_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/log_batch_request_dto.py +106 -0
- truefoundry/ml/autogen/client/models/log_metric_request_dto.py +80 -0
- truefoundry/ml/autogen/client/models/log_param_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/method.py +37 -0
- truefoundry/ml/autogen/client/models/metric_collection_dto.py +82 -0
- truefoundry/ml/autogen/client/models/metric_dto.py +76 -0
- truefoundry/ml/autogen/client/models/mime_type.py +37 -0
- truefoundry/ml/autogen/client/models/model_configuration.py +103 -0
- truefoundry/ml/autogen/client/models/model_dto.py +122 -0
- truefoundry/ml/autogen/client/models/model_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/model_schema_dto.py +85 -0
- truefoundry/ml/autogen/client/models/model_version_dto.py +163 -0
- truefoundry/ml/autogen/client/models/model_version_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_dto.py +107 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_storage_provider.py +34 -0
- truefoundry/ml/autogen/client/models/notify_artifact_version_failure_dto.py +65 -0
- truefoundry/ml/autogen/client/models/openapi_spec.py +152 -0
- truefoundry/ml/autogen/client/models/param_dto.py +66 -0
- truefoundry/ml/autogen/client/models/parameters.py +84 -0
- truefoundry/ml/autogen/client/models/prediction_type.py +34 -0
- truefoundry/ml/autogen/client/models/resolve_agent_app_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/restore_run_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/run_data_dto.py +104 -0
- truefoundry/ml/autogen/client/models/run_dto.py +84 -0
- truefoundry/ml/autogen/client/models/run_info_dto.py +105 -0
- truefoundry/ml/autogen/client/models/run_log_dto.py +90 -0
- truefoundry/ml/autogen/client/models/run_log_input_dto.py +80 -0
- truefoundry/ml/autogen/client/models/run_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/run_tag_dto.py +66 -0
- truefoundry/ml/autogen/client/models/search_runs_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/search_runs_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/set_experiment_tag_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/set_tag_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/signed_url_dto.py +69 -0
- truefoundry/ml/autogen/client/models/stop.py +152 -0
- truefoundry/ml/autogen/client/models/store_run_logs_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/system_message.py +89 -0
- truefoundry/ml/autogen/client/models/text.py +153 -0
- truefoundry/ml/autogen/client/models/text_content_part.py +84 -0
- truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_experiment_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/update_run_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/update_run_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/url.py +153 -0
- truefoundry/ml/autogen/client/models/user_message.py +89 -0
- truefoundry/ml/autogen/client/models/validation_error.py +87 -0
- truefoundry/ml/autogen/client/models/validation_error_loc_inner.py +154 -0
- truefoundry/ml/autogen/client/rest.py +426 -0
- truefoundry/ml/autogen/client_README.md +322 -0
- truefoundry/ml/cli/__init__.py +0 -0
- truefoundry/ml/cli/cli.py +18 -0
- truefoundry/ml/cli/commands/__init__.py +3 -0
- truefoundry/ml/cli/commands/download.py +87 -0
- truefoundry/ml/constants.py +84 -0
- truefoundry/ml/enums.py +70 -0
- truefoundry/ml/env_vars.py +13 -0
- truefoundry/ml/exceptions.py +8 -0
- truefoundry/ml/git_info.py +60 -0
- truefoundry/ml/internal_namespace.py +52 -0
- truefoundry/ml/log_types/__init__.py +4 -0
- truefoundry/ml/log_types/artifacts/artifact.py +427 -0
- truefoundry/ml/log_types/artifacts/constants.py +33 -0
- truefoundry/ml/log_types/artifacts/dataset.py +383 -0
- truefoundry/ml/log_types/artifacts/general_artifact.py +110 -0
- truefoundry/ml/log_types/artifacts/model.py +628 -0
- truefoundry/ml/log_types/artifacts/model_extras.py +48 -0
- truefoundry/ml/log_types/artifacts/utils.py +161 -0
- truefoundry/ml/log_types/image/__init__.py +3 -0
- truefoundry/ml/log_types/image/constants.py +8 -0
- truefoundry/ml/log_types/image/image.py +358 -0
- truefoundry/ml/log_types/image/image_normalizer.py +101 -0
- truefoundry/ml/log_types/image/types.py +68 -0
- truefoundry/ml/log_types/plot.py +281 -0
- truefoundry/ml/log_types/pydantic_base.py +10 -0
- truefoundry/ml/log_types/utils.py +12 -0
- truefoundry/ml/logger.py +17 -0
- truefoundry/ml/login.py +241 -0
- truefoundry/ml/mlfoundry_api.py +1620 -0
- truefoundry/ml/mlfoundry_run.py +1238 -0
- truefoundry/ml/run_utils.py +102 -0
- truefoundry/ml/services/__init__.py +0 -0
- truefoundry/ml/services/auth_service.py +109 -0
- truefoundry/ml/services/entities.py +108 -0
- truefoundry/ml/services/servicefoundry_service.py +35 -0
- truefoundry/ml/services/utils.py +122 -0
- truefoundry/ml/session.py +271 -0
- truefoundry/ml/validation_utils.py +346 -0
- truefoundry/pydantic_v1.py +5 -1
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev1.dist-info}/METADATA +18 -11
- truefoundry-0.4.0.dev1.dist-info/RECORD +342 -0
- truefoundry-0.3.3.dist-info/RECORD +0 -136
- /truefoundry/{python_deploy_codegen.py → deploy/python_deploy_codegen.py} +0 -0
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev1.dist-info}/WHEEL +0 -0
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import importlib
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
|
6
|
+
from urllib.parse import urljoin, urlsplit
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from truefoundry.ml import env_vars
|
|
11
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_module(
    module_name: str, error_message: Optional[str] = None, required: bool = False
):
    """Import ``module_name`` dynamically, tolerating failure unless required.

    Args:
        module_name: Dotted module path handed to ``importlib.import_module``.
        error_message: Message for the raised error; a generic message naming
            the module is used when omitted.
        required: When True, an import failure raises ``MlFoundryException``
            chained to the original error; when False the failure yields ``None``.

    Returns:
        The imported module, or ``None`` if the import failed and ``required``
        is False.
    """
    try:
        module = importlib.import_module(module_name)
    except Exception as import_error:
        # Any failure while importing (not only ImportError) is handled alike.
        if not required:
            return None
        raise MlFoundryException(
            error_message or f"Error importing module {module_name}"
        ) from import_error
    return module
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def resolve_tracking_uri(tracking_uri: Optional[str]):
    """Return the tracking host to use, preferring the explicit argument.

    When ``tracking_uri`` is falsy, the value of the environment variable named
    by ``env_vars.TRACKING_HOST_GLOBAL`` is used instead.

    Raises:
        ValueError: If neither ``tracking_uri`` nor that environment variable
            provides a non-empty value.
    """
    if tracking_uri:
        return tracking_uri
    env_host = os.getenv(env_vars.TRACKING_HOST_GLOBAL)
    if not env_host:
        raise ValueError(
            f"Either `host` should be provided by --host <value>, or `{env_vars.TRACKING_HOST_GLOBAL}` env must be set"
        )
    return env_host
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def append_path_to_rest_tracking_uri(tracking_uri: str):
    """Return the ML REST API base URL for ``tracking_uri``.

    URLs whose network location starts with ``localhost`` are returned as-is.
    For every other host the path becomes ``/api/ml``; ``urljoin`` with an
    absolute path discards whatever path ``tracking_uri`` already had.
    """
    is_local = urlsplit(tracking_uri).netloc.startswith("localhost")
    return tracking_uri if is_local else urljoin(tracking_uri, "/api/ml")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def append_servicefoundry_path_to_tracking_ui(tracking_uri: str):
    """Return the ServiceFoundry API base URL for ``tracking_uri``.

    For non-``localhost`` hosts the path is replaced by ``/api/svc``. For
    ``localhost`` hosts the ``SERVICEFOUNDRY_SERVER_URL`` environment variable
    is returned instead, which is ``None`` when unset.
    """
    if not urlsplit(tracking_uri).netloc.startswith("localhost"):
        return urljoin(tracking_uri, "/api/svc")
    return os.getenv("SERVICEFOUNDRY_SERVER_URL")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars and arrays to native Python types.

    Converts ``np.integer`` -> ``int``, ``np.floating`` -> ``float``,
    ``np.bool_`` -> ``bool`` and ``np.ndarray`` -> nested ``list``. Anything
    else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ subclasses neither np.integer nor Python bool, so without this
        # branch json.dumps raised TypeError on numpy booleans.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            # tolist() also converts the contained numpy scalars to Python ones.
            return obj.tolist()
        return super().default(obj)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Accepted shapes for run params: any mapping, or parsed argparse arguments.
ParamsType = Union[Mapping[str, Any], argparse.Namespace]


def process_params(params: ParamsType) -> Mapping[str, Any]:
    """Normalize ``params`` to a mapping.

    A Mapping is returned unchanged (the same object). An
    ``argparse.Namespace`` is converted with ``vars()``, which returns the
    namespace's own attribute dict.

    Raises:
        MlFoundryException: If ``params`` is neither a Mapping nor an
            ``argparse.Namespace``.
    """
    # TODO: add absl support if required (and move this to a different file then)
    if not isinstance(params, (Mapping, argparse.Namespace)):
        raise MlFoundryException(
            "params should be either argparse.Namespace or a Mapping (dict) type"
        )
    return params if isinstance(params, Mapping) else vars(params)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def flatten_dict(
    input_dict: Mapping[Any, Any], parent_key: str = "", sep: str = "."
) -> Dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a.b': 'c'}``.

    All the keys will be converted to str, at every nesting level.

    Args:
        input_dict: Dictionary containing the keys
        parent_key: Prefix to add to the keys. Defaults to ``''`` (no prefix).
        sep: Delimiter to express the hierarchy. Defaults to ``'.'``.

    Returns:
        Flattened dict.

    Examples:
        >>> flatten_dict({'a': {'b': 'c'}})
        {'a.b': 'c'}
        >>> flatten_dict({'a': {'b': 123}})
        {'a.b': 123}
        >>> flatten_dict({'a': {'b': 'c'}}, parent_key="param")
        {'param.a.b': 'c'}
        >>> flatten_dict({1: {2: 'c'}})
        {'1.2': 'c'}
    """
    new_dict_items: List[Tuple[str, Any]] = []
    for k, v in input_dict.items():
        # Top-level keys are str()-converted too. Previously they were kept
        # as-is, so a non-str key with a nested mapping made the recursive call
        # fail with TypeError on ``parent_key + sep``.
        new_key = parent_key + sep + str(k) if parent_key else str(k)
        if isinstance(v, Mapping):
            new_dict_items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            new_dict_items.append((new_key, v))
    return dict(new_dict_items)
|
|
File without changes
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from urllib.parse import urlparse
|
|
3
|
+
|
|
4
|
+
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
5
|
+
ApiClient,
|
|
6
|
+
AuthApi,
|
|
7
|
+
Configuration,
|
|
8
|
+
)
|
|
9
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
10
|
+
from truefoundry.ml.logger import logger
|
|
11
|
+
from truefoundry.ml.run_utils import append_path_to_rest_tracking_uri
|
|
12
|
+
from truefoundry.ml.services.entities import (
|
|
13
|
+
AuthServerInfo,
|
|
14
|
+
DeviceCode,
|
|
15
|
+
HostCreds,
|
|
16
|
+
Token,
|
|
17
|
+
)
|
|
18
|
+
from truefoundry.ml.services.utils import http_request, http_request_safe
|
|
19
|
+
|
|
20
|
+
# TODO: This will eventually go away, this is duplicate of AuthServiceClient


class AuthService:
    """Minimal client for the auth server's OAuth endpoints.

    Requests go to ``url`` (with trailing slashes stripped) without a bearer
    token. The tenant name given here is sent in the device-code request bodies.
    """

    def __init__(self, url: str, tenant_name: str):
        self._host_creds = HostCreds(host=url.rstrip("/"), token=None)
        self._tenant_name = tenant_name

    def refresh_token(self, token: Token) -> Token:
        """Exchange ``token``'s refresh token for a new ``Token``.

        Raises:
            MlFoundryException: If ``token`` has no refresh token, or the server
                rejects the refresh with a 4xx status (both ask the user to
                log in again). Other request failures are re-raised unchanged.
        """
        relogin_message = (
            "Unable to resume login session. "
            "Please log in again using `tfy login [--host HOST] --relogin`"
        )
        if not token.refresh_token:
            # TODO: Add a way to propagate error messages without traceback to the output interface side
            raise MlFoundryException(relogin_message)
        payload = {
            "tenantName": token.tenant_name,
            "refreshToken": token.refresh_token,
        }
        try:
            response = http_request_safe(
                method="post",
                host_creds=self._host_creds,
                endpoint="api/v1/oauth/token/refresh",
                json=payload,
                timeout=3,
            )
        except MlFoundryException as e:
            status_code = e.status_code
            if status_code and 400 <= status_code < 500:
                raise MlFoundryException(relogin_message) from None
            raise
        return Token.parse_obj(response)

    def get_device_code(self) -> DeviceCode:
        """Start the device-authorization flow and return the issued codes."""
        return DeviceCode.parse_obj(
            http_request_safe(
                method="post",
                host_creds=self._host_creds,
                endpoint="api/v1/oauth/device",
                json={"tenantName": self._tenant_name},
                timeout=3,
            )
        )

    def get_token_from_device_code(
        self, device_code: str, timeout: float = 60
    ) -> Token:
        """Poll until ``device_code`` is authorized, then return the token.

        A 202 response means authorization is still pending; polling resumes
        after one second. A 201 response carries the token.

        Raises:
            MlFoundryException: On any other status code, or if no token was
                obtained within ``timeout`` seconds.
        """
        poll_started = time.monotonic()
        while time.monotonic() - poll_started <= timeout:
            response = http_request(
                method="post",
                host_creds=self._host_creds,
                endpoint="api/v1/oauth/device/token",
                json={"tenantName": self._tenant_name, "deviceCode": device_code},
                timeout=3,
            )
            status = response.status_code
            if status == 201:
                return Token.parse_obj(response.json())
            if status != 202:
                raise MlFoundryException(
                    "Failed to get token using device code.\n"
                    f"Status Code: {response.status_code},\nResponse: {response.text}"
                )
            logger.debug("User has not authorized yet. Checking again.")
            time.sleep(1.0)
        raise MlFoundryException(f"Did not get authorized within {timeout} seconds.")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_auth_service(tracking_uri: str) -> AuthService:
    """Build an ``AuthService`` for the tenant that owns ``tracking_uri``.

    Calls the ML API's tenant-id endpoint (without an access token), passing
    the network location of the ML API URL. This discovers the tenant name and
    the auth server URL.
    """
    api_base_url = append_path_to_rest_tracking_uri(tracking_uri)
    host_name = urlparse(api_base_url).netloc
    # Anonymous api: no access token is configured on this client.
    anonymous_client = ApiClient(
        configuration=Configuration(host=api_base_url.rstrip("/"), access_token=None)
    )
    server_info = AuthApi(api_client=anonymous_client).get_tenant_id_get(
        host_name=host_name,
        _request_timeout=3,
    )
    tenant_info = AuthServerInfo.parse_obj(server_info.dict())
    return AuthService(
        url=tenant_info.auth_server_url, tenant_name=tenant_info.tenant_name
    )
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import time
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
|
|
6
|
+
import jwt
|
|
7
|
+
from typing_extensions import TypedDict
|
|
8
|
+
|
|
9
|
+
from truefoundry.pydantic_v1 import BaseModel, Field, NonEmptyStr, validator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class HostCreds(BaseModel):
    """Base URL of a service together with an optional bearer token."""

    host: str
    # None for anonymous access (e.g. AuthService builds its creds with token=None).
    token: Optional[str] = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AuthServerInfo(BaseModel):
    """Tenant name and auth server URL returned for a tracking host."""

    tenant_name: NonEmptyStr
    auth_server_url: str

    class Config:
        # Parsed instances are read-only.
        allow_mutation = False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ArtifactCredential(BaseModel):
    """A signed URI paired with the run id and artifact path it belongs to."""

    run_id: str
    path: str
    signed_uri: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class UserType(enum.Enum):
    """Kind of principal an access token was issued to."""

    user = "user"
    serviceaccount = "serviceaccount"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UserInfo(BaseModel):
    """Identity details derived from an access token (see ``Token.to_user_info``)."""

    user_id: NonEmptyStr
    user_type: UserType = UserType.user
    email: Optional[str] = None
    # Read from "tenantName" in input data; allow_population_by_field_name
    # also accepts "tenant_name".
    tenant_name: NonEmptyStr = Field(alias="tenantName")

    class Config:
        allow_population_by_field_name = True
        allow_mutation = False
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class _DecodedTokenRequired(TypedDict):
    # Claims every decoded access token must contain.
    tenantName: str
    exp: int
    username: str


class _DecodedToken(_DecodedTokenRequired, total=False):
    """Claims read from a decoded access-token JWT.

    ``email`` and ``userType`` are optional keys (``total=False``), matching
    ``Token.to_user_info``, which already handles both being absent. With every
    key declared required, pydantic v1's TypedDict validation rejected tokens
    lacking either claim.
    """

    email: str
    userType: UserType
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Token(BaseModel):
    """Access/refresh token pair as returned by the auth server.

    ``decoded_value`` is filled in automatically from the access token's JWT
    claims. It is decoded without verifying signature, audience or expiry, and
    is excluded from serialization.
    """

    access_token: NonEmptyStr = Field(alias="accessToken", repr=False)
    refresh_token: Optional[NonEmptyStr] = Field(alias="refreshToken", repr=False)
    decoded_value: Optional[_DecodedToken] = Field(exclude=True, repr=False)

    class Config:
        allow_population_by_field_name = True
        allow_mutation = False

    @validator("decoded_value", always=True, pre=True)
    def _decode_jwt(cls, v, values, **kwargs):
        # pydantic v1 omits a field from ``values`` when that field failed its
        # own validation (e.g. missing or empty access token). Indexing would
        # then raise a bare KeyError, which pydantic does not convert into a
        # ValidationError. Return the input untouched so the access_token
        # validation error is what the caller sees.
        access_token = values.get("access_token")
        if access_token is None:
            return v
        return jwt.decode(
            access_token,
            options={
                "verify_signature": False,
                "verify_aud": False,
                "verify_exp": False,
            },
        )

    @property
    def tenant_name(self) -> str:
        """Tenant name taken from the ``tenantName`` claim."""
        assert self.decoded_value is not None
        return self.decoded_value["tenantName"]

    def is_going_to_be_expired(self, buffer_in_seconds: int = 120) -> bool:
        """Return True if the ``exp`` claim is less than ``buffer_in_seconds`` away."""
        assert self.decoded_value is not None
        exp = int(self.decoded_value["exp"])
        return (exp - time.time()) < buffer_in_seconds

    def to_user_info(self) -> UserInfo:
        """Build a ``UserInfo`` from the decoded claims.

        A missing ``email`` claim becomes None, and a missing ``userType`` claim
        defaults to ``UserType.user``.
        """
        assert self.decoded_value is not None
        return UserInfo(
            user_id=self.decoded_value["username"],
            email=self.decoded_value.get("email"),
            user_type=UserType(self.decoded_value.get("userType", UserType.user.value)),
            tenant_name=self.tenant_name,  # type: ignore[call-arg]
        )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class DeviceCode(BaseModel):
    """Codes issued by the auth server when a device-authorization login starts."""

    user_code: str = Field(alias="userCode")
    device_code: str = Field(alias="deviceCode")

    class Config:
        allow_population_by_field_name = True
        allow_mutation = False

    def get_user_clickable_url(self, tracking_uri: str) -> str:
        """Return the ``/authorize/device`` URL on ``tracking_uri``'s host.

        Only the scheme and network location of ``tracking_uri`` are kept. The
        user code is appended as the ``userCode`` query parameter.
        """
        parts = urlparse(tracking_uri)
        base_url = f"{parts.scheme}://{parts.netloc}"
        return f"{base_url}/authorize/device?userCode={self.user_code}"
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from truefoundry.common.request_utils import requests_retry_session
|
|
4
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
5
|
+
from truefoundry.ml.run_utils import append_servicefoundry_path_to_tracking_ui
|
|
6
|
+
from truefoundry.ml.services.entities import HostCreds
|
|
7
|
+
from truefoundry.ml.services.utils import http_request_safe
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ServicefoundryService:
    """Small client for the ServiceFoundry endpoints used by the ML library."""

    def __init__(self, tracking_uri: str, token: Optional[str] = None):
        servicefoundry_url = append_servicefoundry_path_to_tracking_ui(tracking_uri)
        self.host_creds = HostCreds(host=servicefoundry_url, token=token)

    def get_integration_from_id(self, integration_id: str):
        """Fetch the blob-storage provider integration with id ``integration_id``.

        A falsy ``integration_id`` is sent as an empty string. The request uses
        a session from ``requests_retry_session(retries=1)``.

        Returns:
            The first entry of the response's ``providerIntegrations`` list.

        Raises:
            MlFoundryException: If the request fails or no integration is
                returned for the id.
        """
        normalized_id = integration_id or ""
        response_data = http_request_safe(
            method="get",
            host_creds=self.host_creds,
            endpoint="v1/provider-accounts/provider-integrations",
            session=requests_retry_session(retries=1),
            params={"id": normalized_id, "type": "blob-storage"},
            timeout=3,
        )
        integrations = response_data.get("providerIntegrations")
        if not (integrations and len(integrations) > 0 and integrations[0]):
            raise MlFoundryException(
                f"Invalid storage integration id: {normalized_id}"
            )
        return integrations[0]
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
from urllib.parse import urljoin
|
|
5
|
+
|
|
6
|
+
import requests
|
|
7
|
+
|
|
8
|
+
from truefoundry.common.request_utils import requests_retry_session
|
|
9
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
10
|
+
from truefoundry.ml.services.entities import HostCreds
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# TODO: This will be moved later to truefoundry.common.request_utils
|
|
14
|
+
def _make_url(host: str, endpoint: str) -> str:
|
|
15
|
+
if endpoint.startswith("/"):
|
|
16
|
+
raise ValueError("`endpoint` must not start with a leading slash (/)")
|
|
17
|
+
return urljoin(host, endpoint)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _http_request(
    *, method: str, url: str, token: Optional[str] = None, session=requests, **kwargs
) -> requests.Response:
    """Send an HTTP request through ``session``, adding a bearer token if given.

    Args:
        method: HTTP method name passed to ``session.request``.
        url: Absolute URL to request.
        token: When not None, sent as ``Authorization: Bearer <token>``.
        session: Object with a ``requests``-style ``request`` method; defaults
            to the ``requests`` module itself.
        **kwargs: Forwarded to ``session.request``. An optional ``headers``
            mapping is merged with the Authorization header.

    Returns:
        The response; its status code is not checked here.
    """
    # Copy so the caller's headers dict is not mutated.
    headers = dict(kwargs.pop("headers", None) or {})
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    # headers was popped from kwargs above, so it must be forwarded explicitly.
    # Previously it was dropped, so the bearer token and any caller headers
    # were never sent.
    return session.request(method=method, url=url, headers=headers, **kwargs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def http_request(
    *, method: str, host_creds: HostCreds, endpoint: str, session=requests, **kwargs
) -> requests.Response:
    """Send a request to ``endpoint`` on ``host_creds.host`` using its token.

    ``endpoint`` must be relative (no leading ``/``). The raw response is
    returned without checking its status; ``http_request_safe`` is the checked,
    JSON-decoding variant.
    """
    return _http_request(
        method=method,
        url=_make_url(host=host_creds.host, endpoint=endpoint),
        token=host_creds.token,
        session=session,
        **kwargs,
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def http_request_safe(
    *, method: str, host_creds: HostCreds, endpoint: str, session=requests, **kwargs
) -> Any:
    """Send a request like ``http_request`` and return the decoded JSON body.

    Raises:
        MlFoundryException: On connection failure, timeout, a non-2xx status
            (with ``status_code`` set so callers can branch on it), any other
            error while sending the request, or a body that is not valid JSON.
    """
    url = _make_url(host=host_creds.host, endpoint=endpoint)
    try:
        response = _http_request(
            method=method, url=url, token=host_creds.token, session=session, **kwargs
        )
        response.raise_for_status()
    except requests.exceptions.ConnectionError as ce:
        raise MlFoundryException("Failed to connect to TrueFoundry") from ce
    except requests.exceptions.Timeout as te:
        raise MlFoundryException(f"Request to {url} timed out") from te
    except requests.exceptions.HTTPError as he:
        raise MlFoundryException(
            f"Request to {url} with status code {he.response.status_code}. Response: {he.response.text}",
            status_code=he.response.status_code,
        ) from he
    except Exception as e:
        raise MlFoundryException(
            f"Request to {url} failed with an unknown error"
        ) from e

    # JSON decoding happens outside the try above. Previously the parse-error
    # MlFoundryException raised here was caught by ``except Exception`` and
    # re-wrapped as an "unknown error", hiding the response text.
    try:
        return response.json()
    except ValueError as je:
        # ValueError covers json.JSONDecodeError as well as requests'
        # JSONDecodeError when requests is backed by simplejson.
        raise MlFoundryException(
            f"Failed to parse response as json. Response: {response.text}"
        ) from je
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@contextmanager
def cloud_storage_http_request(
    *,
    method,
    url,
    session=None,
    timeout=None,
    **kwargs,
):
    """
    Performs an HTTP PUT/GET request using Python's `requests` module with automatic retry.

    Yields the response without checking its status, and closes it when the
    ``with`` block exits.

    Args:
        method: ``"put"`` or ``"get"`` (case-insensitive).
        url: Target URL. Azure Blob URLs (containing ``blob.core.windows.net``)
            get the ``x-ms-blob-type: BlockBlob`` header.
        session: Session to use; defaults to
            ``requests_retry_session(retries=5, backoff_factor=2)``.
        timeout: Forwarded to the request.
        **kwargs: Forwarded to the request. ``headers`` is copied, not mutated.

    Raises:
        ValueError: For any other HTTP method.
        MlFoundryException: If sending the request fails, or if a non-
            MlFoundryException error is raised inside the ``with`` body.
    """
    session = session or requests_retry_session(retries=5, backoff_factor=2)
    # Copy so the caller's headers dict is not mutated by the update below.
    headers = dict(kwargs.get("headers") or {})
    if "blob.core.windows.net" in url:
        headers["x-ms-blob-type"] = "BlockBlob"
    kwargs["headers"] = headers
    if method.lower() not in ("put", "get"):
        raise ValueError(f"Illegal http method: {method}")
    try:
        response = _http_request(
            method=method, url=url, session=session, timeout=timeout, **kwargs
        )
    except Exception as e:
        raise MlFoundryException(
            f"API request failed with exception {str(e)}"
        ) from None
    try:
        yield response
    except MlFoundryException:
        # Already a domain error raised in the caller's ``with`` body (e.g. by
        # augmented_raise_for_status). Re-raise it as-is instead of re-wrapping
        # it, which lost its message detail and status_code.
        raise
    except Exception as e:
        raise MlFoundryException(
            f"API request failed with exception {str(e)}"
        ) from None
    finally:
        # Release the connection; with stream=True it is otherwise held open.
        response.close()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def augmented_raise_for_status(response):
    """Call ``response.raise_for_status()``, converting HTTP errors.

    A ``requests.exceptions.HTTPError`` is re-raised as ``MlFoundryException``.
    Its message includes the status code and response body, and it carries
    ``status_code``. The original HTTPError is chained as the cause.
    """
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as http_error:
        failed = http_error.response
        raise MlFoundryException(
            f"Request with status code {failed.status_code}. Response: {failed.text}",
            status_code=failed.status_code,
        ) from http_error
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def download_file_using_http_uri(http_uri, download_path, chunk_size=100_000_000):
    """
    Downloads a file specified using the `http_uri` to a local `download_path`. The body is
    streamed and written in chunks (sized by `chunk_size`) so that an OOM error is not raised
    when a large file is downloaded.

    Note : This function is meant to download files using presigned urls from various cloud
    providers.

    Raises:
        MlFoundryException: If the request fails or the response has an error status.
    """
    with cloud_storage_http_request(
        method="get", url=http_uri, stream=True
    ) as response:
        augmented_raise_for_status(response)
        with open(download_path, "wb") as output_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                # Skip empty (keep-alive) chunks rather than stopping. An empty
                # chunk does not mean end-of-body, so the previous `break` could
                # silently truncate the file.
                if chunk:
                    output_file.write(chunk)
|