truefoundry 0.3.3__py3-none-any.whl → 0.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truefoundry might be problematic. Click here for more details.
- truefoundry/cli/__main__.py +3 -17
- truefoundry/common/__init__.py +0 -0
- truefoundry/common/request_utils.py +56 -0
- truefoundry/deploy/cli/cli.py +1 -1
- truefoundry/deploy/lib/auth/credential_provider.py +2 -12
- truefoundry/deploy/lib/clients/servicefoundry_client.py +0 -9
- truefoundry/deploy/lib/exceptions.py +1 -6
- truefoundry/deploy/lib/session.py +1 -16
- truefoundry/langchain/truefoundry_chat.py +1 -1
- truefoundry/langchain/truefoundry_embeddings.py +1 -1
- truefoundry/langchain/truefoundry_llm.py +1 -1
- truefoundry/langchain/utils.py +0 -41
- truefoundry/ml/__init__.py +46 -6
- truefoundry/ml/artifact/__init__.py +0 -0
- truefoundry/ml/artifact/truefoundry_artifact_repo.py +1120 -0
- truefoundry/ml/autogen/__init__.py +0 -0
- truefoundry/ml/autogen/client/__init__.py +373 -0
- truefoundry/ml/autogen/client/api/__init__.py +16 -0
- truefoundry/ml/autogen/client/api/auth_api.py +184 -0
- truefoundry/ml/autogen/client/api/deprecated_api.py +605 -0
- truefoundry/ml/autogen/client/api/experiments_api.py +2109 -0
- truefoundry/ml/autogen/client/api/health_api.py +299 -0
- truefoundry/ml/autogen/client/api/metrics_api.py +371 -0
- truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +7213 -0
- truefoundry/ml/autogen/client/api/python_deployment_config_api.py +201 -0
- truefoundry/ml/autogen/client/api/run_artifacts_api.py +231 -0
- truefoundry/ml/autogen/client/api/runs_api.py +2919 -0
- truefoundry/ml/autogen/client/api_client.py +822 -0
- truefoundry/ml/autogen/client/api_response.py +30 -0
- truefoundry/ml/autogen/client/configuration.py +489 -0
- truefoundry/ml/autogen/client/exceptions.py +161 -0
- truefoundry/ml/autogen/client/models/__init__.py +344 -0
- truefoundry/ml/autogen/client/models/add_custom_metrics_to_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/add_features_to_model_version_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/agent.py +125 -0
- truefoundry/ml/autogen/client/models/agent_app.py +118 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool.py +143 -0
- truefoundry/ml/autogen/client/models/agent_open_api_tool_with_fqn.py +144 -0
- truefoundry/ml/autogen/client/models/agent_with_fqn.py +127 -0
- truefoundry/ml/autogen/client/models/artifact_dto.py +115 -0
- truefoundry/ml/autogen/client/models/artifact_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/artifact_type.py +39 -0
- truefoundry/ml/autogen/client/models/artifact_version_dto.py +141 -0
- truefoundry/ml/autogen/client/models/artifact_version_response_dto.py +77 -0
- truefoundry/ml/autogen/client/models/artifact_version_status.py +35 -0
- truefoundry/ml/autogen/client/models/assistant_message.py +89 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/authorize_user_for_model_version_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/backfill_default_storage_integration_id_request_dto.py +67 -0
- truefoundry/ml/autogen/client/models/blob_storage_reference.py +93 -0
- truefoundry/ml/autogen/client/models/body_get_search_runs_get.py +72 -0
- truefoundry/ml/autogen/client/models/chat_prompt.py +156 -0
- truefoundry/ml/autogen/client/models/chat_prompt_messages_inner.py +171 -0
- truefoundry/ml/autogen/client/models/columns_dto.py +73 -0
- truefoundry/ml/autogen/client/models/content.py +153 -0
- truefoundry/ml/autogen/client/models/content1.py +153 -0
- truefoundry/ml/autogen/client/models/content2.py +174 -0
- truefoundry/ml/autogen/client/models/content2_any_of_inner.py +150 -0
- truefoundry/ml/autogen/client/models/create_artifact_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_response_dto.py +66 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/create_artifact_version_response_dto.py +66 -0
- truefoundry/ml/autogen/client/models/create_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/create_experiment_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/create_experiment_response_dto.py +67 -0
- truefoundry/ml/autogen/client/models/create_model_version_request_dto.py +95 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_for_dataset_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/create_multi_part_upload_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_request_dto.py +72 -0
- truefoundry/ml/autogen/client/models/create_python_deployment_config_response_dto.py +68 -0
- truefoundry/ml/autogen/client/models/create_run_request_dto.py +97 -0
- truefoundry/ml/autogen/client/models/create_run_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/dataset_dto.py +112 -0
- truefoundry/ml/autogen/client/models/dataset_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/delete_artifact_versions_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/delete_model_version_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/delete_run_request.py +65 -0
- truefoundry/ml/autogen/client/models/delete_tag_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/experiment_dto.py +127 -0
- truefoundry/ml/autogen/client/models/experiment_id_request_dto.py +67 -0
- truefoundry/ml/autogen/client/models/experiment_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/experiment_tag_dto.py +69 -0
- truefoundry/ml/autogen/client/models/feature_dto.py +68 -0
- truefoundry/ml/autogen/client/models/feature_value_type.py +35 -0
- truefoundry/ml/autogen/client/models/file_info_dto.py +76 -0
- truefoundry/ml/autogen/client/models/finalize_artifact_version_request_dto.py +101 -0
- truefoundry/ml/autogen/client/models/get_experiment_response_dto.py +88 -0
- truefoundry/ml/autogen/client/models/get_latest_run_log_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/get_metric_history_response.py +79 -0
- truefoundry/ml/autogen/client/models/get_signed_url_for_dataset_write_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_artifact_version_write_response_dto.py +83 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_request_dto.py +68 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_read_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_signed_urls_for_dataset_write_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/get_tenant_id_response_dto.py +74 -0
- truefoundry/ml/autogen/client/models/http_validation_error.py +82 -0
- truefoundry/ml/autogen/client/models/image_content_part.py +87 -0
- truefoundry/ml/autogen/client/models/image_url.py +75 -0
- truefoundry/ml/autogen/client/models/internal_metadata.py +180 -0
- truefoundry/ml/autogen/client/models/latest_run_log_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_request_dto.py +107 -0
- truefoundry/ml/autogen/client/models/list_artifact_versions_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_artifacts_request_dto.py +96 -0
- truefoundry/ml/autogen/client/models/list_artifacts_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_colums_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/list_datasets_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/list_datasets_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_experiments_response_dto.py +86 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_version_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_artifact_versions_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/list_files_for_dataset_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_latest_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_metric_history_request_dto.py +69 -0
- truefoundry/ml/autogen/client/models/list_metric_history_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_model_version_response_dto.py +87 -0
- truefoundry/ml/autogen/client/models/list_model_versions_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/list_models_request_dto.py +89 -0
- truefoundry/ml/autogen/client/models/list_models_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_artifacts_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/list_run_logs_response_dto.py +82 -0
- truefoundry/ml/autogen/client/models/list_seed_experiments_response_dto.py +81 -0
- truefoundry/ml/autogen/client/models/log_batch_request_dto.py +106 -0
- truefoundry/ml/autogen/client/models/log_metric_request_dto.py +80 -0
- truefoundry/ml/autogen/client/models/log_param_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/method.py +37 -0
- truefoundry/ml/autogen/client/models/metric_collection_dto.py +82 -0
- truefoundry/ml/autogen/client/models/metric_dto.py +76 -0
- truefoundry/ml/autogen/client/models/mime_type.py +37 -0
- truefoundry/ml/autogen/client/models/model_configuration.py +103 -0
- truefoundry/ml/autogen/client/models/model_dto.py +122 -0
- truefoundry/ml/autogen/client/models/model_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/model_schema_dto.py +85 -0
- truefoundry/ml/autogen/client/models/model_version_dto.py +163 -0
- truefoundry/ml/autogen/client/models/model_version_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_dto.py +107 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_response_dto.py +79 -0
- truefoundry/ml/autogen/client/models/multi_part_upload_storage_provider.py +34 -0
- truefoundry/ml/autogen/client/models/notify_artifact_version_failure_dto.py +65 -0
- truefoundry/ml/autogen/client/models/openapi_spec.py +152 -0
- truefoundry/ml/autogen/client/models/param_dto.py +66 -0
- truefoundry/ml/autogen/client/models/parameters.py +84 -0
- truefoundry/ml/autogen/client/models/prediction_type.py +34 -0
- truefoundry/ml/autogen/client/models/resolve_agent_app_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/restore_run_request_dto.py +65 -0
- truefoundry/ml/autogen/client/models/run_data_dto.py +104 -0
- truefoundry/ml/autogen/client/models/run_dto.py +84 -0
- truefoundry/ml/autogen/client/models/run_info_dto.py +105 -0
- truefoundry/ml/autogen/client/models/run_log_dto.py +90 -0
- truefoundry/ml/autogen/client/models/run_log_input_dto.py +80 -0
- truefoundry/ml/autogen/client/models/run_response_dto.py +75 -0
- truefoundry/ml/autogen/client/models/run_tag_dto.py +66 -0
- truefoundry/ml/autogen/client/models/search_runs_request_dto.py +94 -0
- truefoundry/ml/autogen/client/models/search_runs_response_dto.py +84 -0
- truefoundry/ml/autogen/client/models/set_experiment_tag_request_dto.py +73 -0
- truefoundry/ml/autogen/client/models/set_tag_request_dto.py +76 -0
- truefoundry/ml/autogen/client/models/signed_url_dto.py +69 -0
- truefoundry/ml/autogen/client/models/stop.py +152 -0
- truefoundry/ml/autogen/client/models/store_run_logs_request_dto.py +83 -0
- truefoundry/ml/autogen/client/models/system_message.py +89 -0
- truefoundry/ml/autogen/client/models/text.py +153 -0
- truefoundry/ml/autogen/client/models/text_content_part.py +84 -0
- truefoundry/ml/autogen/client/models/update_artifact_version_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_dataset_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_experiment_request_dto.py +74 -0
- truefoundry/ml/autogen/client/models/update_model_version_request_dto.py +93 -0
- truefoundry/ml/autogen/client/models/update_run_request_dto.py +78 -0
- truefoundry/ml/autogen/client/models/update_run_response_dto.py +76 -0
- truefoundry/ml/autogen/client/models/url.py +153 -0
- truefoundry/ml/autogen/client/models/user_message.py +89 -0
- truefoundry/ml/autogen/client/models/validation_error.py +87 -0
- truefoundry/ml/autogen/client/models/validation_error_loc_inner.py +154 -0
- truefoundry/ml/autogen/client/rest.py +426 -0
- truefoundry/ml/autogen/client_README.md +322 -0
- truefoundry/ml/cli/__init__.py +0 -0
- truefoundry/ml/cli/cli.py +18 -0
- truefoundry/ml/cli/commands/__init__.py +3 -0
- truefoundry/ml/cli/commands/download.py +87 -0
- truefoundry/ml/constants.py +84 -0
- truefoundry/ml/enums.py +70 -0
- truefoundry/ml/env_vars.py +13 -0
- truefoundry/ml/exceptions.py +8 -0
- truefoundry/ml/git_info.py +60 -0
- truefoundry/ml/internal_namespace.py +52 -0
- truefoundry/ml/log_types/__init__.py +4 -0
- truefoundry/ml/log_types/artifacts/artifact.py +427 -0
- truefoundry/ml/log_types/artifacts/constants.py +33 -0
- truefoundry/ml/log_types/artifacts/dataset.py +383 -0
- truefoundry/ml/log_types/artifacts/general_artifact.py +110 -0
- truefoundry/ml/log_types/artifacts/model.py +628 -0
- truefoundry/ml/log_types/artifacts/model_extras.py +48 -0
- truefoundry/ml/log_types/artifacts/utils.py +161 -0
- truefoundry/ml/log_types/image/__init__.py +3 -0
- truefoundry/ml/log_types/image/constants.py +8 -0
- truefoundry/ml/log_types/image/image.py +358 -0
- truefoundry/ml/log_types/image/image_normalizer.py +101 -0
- truefoundry/ml/log_types/image/types.py +68 -0
- truefoundry/ml/log_types/plot.py +281 -0
- truefoundry/ml/log_types/pydantic_base.py +10 -0
- truefoundry/ml/log_types/utils.py +12 -0
- truefoundry/ml/logger.py +17 -0
- truefoundry/ml/login.py +241 -0
- truefoundry/ml/mlfoundry_api.py +1620 -0
- truefoundry/ml/mlfoundry_run.py +1238 -0
- truefoundry/ml/run_utils.py +102 -0
- truefoundry/ml/services/__init__.py +0 -0
- truefoundry/ml/services/auth_service.py +109 -0
- truefoundry/ml/services/entities.py +108 -0
- truefoundry/ml/services/servicefoundry_service.py +35 -0
- truefoundry/ml/services/utils.py +122 -0
- truefoundry/ml/session.py +271 -0
- truefoundry/ml/validation_utils.py +346 -0
- truefoundry/pydantic_v1.py +5 -1
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/METADATA +19 -12
- truefoundry-0.4.0.dev0.dist-info/RECORD +342 -0
- truefoundry-0.3.3.dist-info/RECORD +0 -136
- /truefoundry/{python_deploy_codegen.py → deploy/python_deploy_codegen.py} +0 -0
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/WHEEL +0 -0
- {truefoundry-0.3.3.dist-info → truefoundry-0.4.0.dev0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1238 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
import platform
|
|
4
|
+
import re
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import (
|
|
8
|
+
TYPE_CHECKING,
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterable,
|
|
12
|
+
Iterator,
|
|
13
|
+
List,
|
|
14
|
+
Optional,
|
|
15
|
+
Sequence,
|
|
16
|
+
Tuple,
|
|
17
|
+
Union,
|
|
18
|
+
)
|
|
19
|
+
from urllib.parse import urljoin, urlsplit
|
|
20
|
+
|
|
21
|
+
from truefoundry import version
|
|
22
|
+
from truefoundry.ml import constants, enums
|
|
23
|
+
from truefoundry.ml.autogen.client import ( # type: ignore[attr-defined]
|
|
24
|
+
ArtifactType,
|
|
25
|
+
DeleteRunRequest,
|
|
26
|
+
ExperimentsApi,
|
|
27
|
+
ListArtifactVersionsRequestDto,
|
|
28
|
+
ListModelVersionsRequestDto,
|
|
29
|
+
LogBatchRequestDto,
|
|
30
|
+
MetricDto,
|
|
31
|
+
MetricsApi,
|
|
32
|
+
MlfoundryArtifactsApi,
|
|
33
|
+
ParamDto,
|
|
34
|
+
RunDataDto,
|
|
35
|
+
RunDto,
|
|
36
|
+
RunInfoDto,
|
|
37
|
+
RunsApi,
|
|
38
|
+
RunTagDto,
|
|
39
|
+
UpdateRunRequestDto,
|
|
40
|
+
)
|
|
41
|
+
from truefoundry.ml.enums import RunStatus
|
|
42
|
+
from truefoundry.ml.exceptions import MlFoundryException
|
|
43
|
+
from truefoundry.ml.internal_namespace import NAMESPACE
|
|
44
|
+
from truefoundry.ml.log_types import Image, Plot
|
|
45
|
+
from truefoundry.ml.log_types.artifacts.artifact import ArtifactPath, ArtifactVersion
|
|
46
|
+
from truefoundry.ml.log_types.artifacts.general_artifact import _log_artifact_version
|
|
47
|
+
from truefoundry.ml.log_types.artifacts.model import ModelVersion, _log_model_version
|
|
48
|
+
from truefoundry.ml.log_types.artifacts.model_extras import CustomMetric, ModelSchema
|
|
49
|
+
from truefoundry.ml.logger import logger
|
|
50
|
+
from truefoundry.ml.run_utils import ParamsType, flatten_dict, process_params
|
|
51
|
+
from truefoundry.ml.session import ACTIVE_RUNS, _get_api_client, get_active_session
|
|
52
|
+
from truefoundry.ml.validation_utils import (
|
|
53
|
+
MAX_ENTITY_KEY_LENGTH,
|
|
54
|
+
_validate_batch_log_data,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
if TYPE_CHECKING:
|
|
58
|
+
import matplotlib
|
|
59
|
+
import plotly
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _ensure_not_deleted(method):
|
|
63
|
+
@functools.wraps(method)
|
|
64
|
+
def _check_deleted_or_not(self, *args, **kwargs):
|
|
65
|
+
if self._deleted:
|
|
66
|
+
raise MlFoundryException("Run was deleted, cannot access a deleted Run")
|
|
67
|
+
else:
|
|
68
|
+
return method(self, *args, **kwargs)
|
|
69
|
+
|
|
70
|
+
return _check_deleted_or_not
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class MlFoundryRun:
|
|
74
|
+
"""MlFoundryRun."""
|
|
75
|
+
|
|
76
|
+
VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[A-Za-z0-9_\-\. /]+$")
|
|
77
|
+
|
|
78
|
+
def __init__(
    self,
    experiment_id: str,
    run_id: str,
    auto_end: bool = False,
    **kwargs,
):
    """Bind this object to a run and build the API clients it talks through.

    Args:
        experiment_id (str): id of the experiment the run belongs to. It is
            stored as `str(experiment_id)`.
        run_id (str): id of the run.
        auto_end (bool): If to end the run at garbage collection or process
            end (atexit). When true the instance is registered with
            `ACTIVE_RUNS`.
        **kwargs: accepted for call compatibility; not used by this method.
    """
    self._experiment_id = str(experiment_id)
    self._run_id = run_id
    self._auto_end = auto_end
    # Filled lazily by `_get_run_info` or eagerly by `_from_dto`.
    self._run_info: Optional[RunInfoDto] = None
    self._run_data: Optional[RunDataDto] = None
    self._deleted = False
    self._terminate_called = False

    self._api_client = _get_api_client()
    self._experiments_api = ExperimentsApi(api_client=self._api_client)
    self._runs_api = RunsApi(api_client=self._api_client)
    self._metrics_api = MetricsApi(api_client=self._api_client)
    self._mlfoundry_artifacts_api = MlfoundryArtifactsApi(
        api_client=self._api_client
    )

    # Register only once every attribute exists, so that a failure while
    # building the clients above cannot leave a half-initialised run
    # registered in ACTIVE_RUNS.
    if self._auto_end:
        ACTIVE_RUNS.add_run(self)
|
|
109
|
+
|
|
110
|
+
@classmethod
def _from_dto(cls, run_dto: RunDto) -> "MlFoundryRun":
    """Alternate constructor: build a run object from a `RunDto`.

    The dto's `info` and `data` are stored on the new instance, so the run
    info does not have to be fetched again later.
    """
    info = run_dto.info
    assert info.experiment_id is not None
    assert info.run_id is not None
    instance = cls(info.experiment_id, info.run_id)
    instance._run_info = info
    instance._run_data = run_dto.data
    return instance
|
|
119
|
+
|
|
120
|
+
def _get_run_info(self) -> RunInfoDto:
    """Return the run's info, fetching it through the runs API on first use.

    The fetched value is kept on the instance and reused by later calls.
    """
    if self._run_info is None:
        response = self._runs_api.get_run_get(run_id=self.run_id)
        self._run_info = response.run.info
    return self._run_info
|
|
127
|
+
|
|
128
|
+
@property
@_ensure_not_deleted
def run_id(self) -> str:
    """Identifier of this `run`, as given at construction time."""
    run_identifier = self._run_id
    return run_identifier
|
|
133
|
+
|
|
134
|
+
@property
@_ensure_not_deleted
def run_name(self) -> str:
    """Name of this `run`, read from the run info (fetched if not cached)."""
    name = self._get_run_info().name
    assert name is not None
    return name
|
|
141
|
+
|
|
142
|
+
@property
@_ensure_not_deleted
def fqn(self) -> str:
    """Fqn of this `run`, read from the run info (fetched if not cached)."""
    fully_qualified_name = self._get_run_info().fqn
    assert fully_qualified_name is not None
    return fully_qualified_name
|
|
149
|
+
|
|
150
|
+
@property
@_ensure_not_deleted
def status(self) -> RunStatus:
    """Status of this `run`.

    The run is requested from the runs API on every access; the cached run
    info is not consulted.
    """
    info = self._runs_api.get_run_get(run_id=self.run_id).run.info
    assert info.status is not None
    return RunStatus(info.status)
|
|
157
|
+
|
|
158
|
+
@property
@_ensure_not_deleted
def ml_repo(self) -> str:
    """Name of the ml_repo this `run` belongs to.

    The name is taken from the experiment returned by the experiments API
    for this run's experiment id.
    """
    response = self._experiments_api.get_experiment_get(
        experiment_id=self._experiment_id
    )
    experiment = response.experiment
    assert experiment.name is not None
    return experiment.name
|
|
167
|
+
|
|
168
|
+
@property
@_ensure_not_deleted
def auto_end(self) -> bool:
    """Whether automatic ending is enabled for this `run`."""
    auto_end_enabled = self._auto_end
    return auto_end_enabled
|
|
173
|
+
|
|
174
|
+
def __repr__(self) -> str:
    """Return a debugging representation of this run.

    For a live run the text contains the run's fqn. For a deleted run the
    `fqn` property raises, so the stored run id is shown instead; this keeps
    `repr()` usable in logs, debuggers and interactive sessions after
    `delete()` has been called.
    """
    prefix = f"<{type(self).__name__} at 0x{id(self):x}: "
    if self._deleted:
        # Read the private attribute directly: the public properties are
        # guarded and would raise for a deleted run.
        return f"{prefix}run_id={self._run_id!r} (deleted)>"
    return f"{prefix}run={self.fqn!r}>"
|
|
177
|
+
|
|
178
|
+
@_ensure_not_deleted
def __enter__(self):
    """Enter the `with` block; the run itself is the bound value."""
    return self
|
|
181
|
+
|
|
182
|
+
def _terminate_run_if_running(self, termination_status: RunStatus):
    """Set a final status on the run if it is currently `RUNNING`.

    Only the first call does any work; later calls return immediately. The
    run is removed from `ACTIVE_RUNS`, and when its current status is
    `RUNNING` the runs API is asked to record `termination_status` together
    with the current time as end time. A failure of that update is logged
    as a warning and not re-raised.

    Args:
        termination_status (RunStatus): status to record for the run.
    """
    if self._terminate_called:
        return

    # Prevent double execution for termination
    self._terminate_called = True
    ACTIVE_RUNS.remove_run(self)

    if self._deleted:
        # `__exit__` is not guarded against deleted runs, so calling
        # `delete()` inside a `with` block used to reach `self.status`
        # below, which raises for a deleted run. There is nothing left to
        # terminate in that case.
        return

    # NOTE(review): this lookup is outside the try block, so an error while
    # fetching the status propagates to the caller.
    current_status = self.status
    try:
        # we do not need to set any termination status unless the run was in RUNNING state
        if current_status != RunStatus.RUNNING:
            return
        logger.info("Setting run status of %r to %r", self.fqn, termination_status)
        _run_update = self._runs_api.update_run_post(
            update_run_request_dto=UpdateRunRequestDto(
                run_id=self.run_id,
                status=termination_status.value,
                # current time in milliseconds
                end_time=int(time.time() * 1000),
            )
        )
        self._run_info = _run_update.run_info
    except Exception as e:
        logger.warning(
            f"failed to set termination status {termination_status} due to {e}"
        )
    logger.info(f"Finished run: {self.fqn!r}, Dashboard: {self.dashboard_link}")
|
|
214
|
+
|
|
215
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the `with` block and terminate the run.

    The run is terminated as `FINISHED` when the block exited without an
    exception and as `FAILED` otherwise. Nothing is returned, so an
    exception raised inside the block is not suppressed.
    """
    if exc_type is None:
        final_status = RunStatus.FINISHED
    else:
        final_status = RunStatus.FAILED
    self._terminate_run_if_running(final_status)
|
|
218
|
+
|
|
219
|
+
def __del__(self):
    """Finalizer: terminate the run as `FINISHED` when auto-end is enabled."""
    if not self._auto_end:
        return
    self._terminate_run_if_running(RunStatus.FINISHED)
|
|
222
|
+
|
|
223
|
+
@property
@_ensure_not_deleted
def dashboard_link(self) -> str:
    """Get Mlfoundry dashboard link for a `run`.

    The link is built from the scheme and host of the active session's
    tracking URI plus the path `mlfoundry/<experiment_id>/run/<run_id>/`.

    Raises:
        MlFoundryException: if there is no active session.
    """
    session = get_active_session()
    if session is None:
        # The closing backtick around the command was missing in this message.
        raise MlFoundryException(
            "No active session found. Perhaps you are not logged in?\n"
            "Please log in using `tfy login [--host HOST] --relogin`"
        )
    # Keep only scheme and host; any path on the tracking URI is dropped.
    tracking = urlsplit(session.tracking_uri)
    base_url = f"{tracking.scheme}://{tracking.netloc}/"

    return urljoin(base_url, f"mlfoundry/{self._experiment_id}/run/{self.run_id}/")
|
|
238
|
+
|
|
239
|
+
@_ensure_not_deleted
def end(self, status: RunStatus = RunStatus.FINISHED):
    """Finish this `run`.

    The run is handed to the termination routine with the requested
    `status`, which defaults to `FINISHED`.

    Args:
        status (RunStatus): status to record for the run.

    Examples:

    Ending a run explicitly:

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    run = client.create_run(
        ml_repo="my-classification-project", run_name="svm-with-rbf-kernel"
    )
    # ...
    # Model training code
    # ...
    run.end()
    ```

    When the run is used as a context manager there is no need to call
    this function; leaving the `with` block ends the run.

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    with client.create_run(
        ml_repo="my-classification-project", run_name="svm-with-rbf-kernel"
    ) as run:
        # ...
        # Model training code
        ...
    # `run` will be automatically marked as `FINISHED` or `FAILED`.
    ```
    """
    self._terminate_run_if_running(termination_status=status)
|
|
277
|
+
|
|
278
|
+
@_ensure_not_deleted
def delete(self) -> None:
    """Permanently delete this run.

    On success the run is removed from `ACTIVE_RUNS`, marked as deleted and
    its auto-end behaviour is switched off. On failure a warning is logged
    and the original exception is re-raised.

    Examples:

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    client.create_ml_repo('iris-learning')
    run = client.create_run(ml_repo="iris-learning", run_name="svm-model1")
    run.log_params({"learning_rate": 0.001})
    run.log_metrics({"accuracy": 0.7, "loss": 0.6})

    run.delete()
    ```

    Calling or accessing any other function of the run after it has been
    deleted throws an MlFoundryException.

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    client.create_ml_repo('iris-learning')
    run = client.create_run(ml_repo="iris-learning", run_name="svm-model1")
    run.log_params({"learning_rate": 0.001})
    run.log_metrics({"accuracy": 0.7, "loss": 0.6})

    run.delete()
    run.log_params({"learning_rate": 0.001}) # raises MlFoundryException
    ```
    """
    # Read the name up front: the property is unavailable once `_deleted` is set.
    name = self.run_name
    try:
        request = DeleteRunRequest(run_id=self.run_id)
        self._runs_api.hard_delete_run_post(delete_run_request=request)
        logger.info(f"Run {name} was deleted successfully")
        ACTIVE_RUNS.remove_run(self)
        self._deleted = True
        self._auto_end = False
    except Exception as error:
        logger.warning(f"Failed to delete the run {name} because of {error}")
        raise
|
|
325
|
+
|
|
326
|
+
@_ensure_not_deleted
def list_artifact_versions(
    self,
    artifact_type: Optional[ArtifactType] = ArtifactType.ARTIFACT,
) -> Iterator[ArtifactVersion]:
    """
    Iterate over the artifact versions logged under this run.

    Results are requested from the artifacts API one page (25 entries) at a
    time, and each entry is turned into an `ArtifactVersion` via its fqn.

    Args:
        artifact_type: Type of the artifact you want. A falsy value sends
            no type filter with the request.

    Returns:
        Iterator[ArtifactVersion]: An iterator that yields non deleted artifact-versions
            of an artifact under a given run sorted reverse by the version number

    Examples:

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    run = client.create_run(ml_repo="iris-learning", run_name="svm-model1")
    artifact_versions = run.list_artifact_versions()

    for artifact_version in artifact_versions:
        print(artifact_version)

    run.end()
    ```
    """
    page_size = 25
    page_token = None
    while True:
        request = ListArtifactVersionsRequestDto(
            run_ids=[self.run_id],
            artifact_types=[artifact_type] if artifact_type else None,
            max_results=page_size,
            page_token=page_token,
        )
        response = self._mlfoundry_artifacts_api.list_artifact_versions_post(
            list_artifact_versions_request_dto=request
        )
        versions = response.artifact_versions
        page_token = response.next_page_token
        for entry in versions:
            yield ArtifactVersion.from_fqn(entry.fqn)
        # Stop on an empty page or when no further page is advertised.
        if not versions or page_token is None:
            break
|
|
374
|
+
|
|
375
|
+
@_ensure_not_deleted
def list_model_versions(
    self,
) -> Iterator[ModelVersion]:
    """
    Iterate over the model versions logged under this run.

    Results are requested from the artifacts API one page (25 entries) at a
    time, and each entry is turned into a `ModelVersion` via its fqn.

    Returns:
        Iterator[ModelVersion]: An iterator that yields non deleted model-versions
            under a given run sorted reverse by the version number

    Examples:

    ```python
    from truefoundry.ml import get_client

    client = get_client()
    run = client.get_run(run_id="<your-run-id>")
    model_versions = run.list_model_versions()

    for model_version in model_versions:
        print(model_version)

    run.end()
    ```
    """
    page_size = 25
    page_token = None
    while True:
        request = ListModelVersionsRequestDto(
            run_ids=[self.run_id],
            max_results=page_size,
            page_token=page_token,
        )
        response = self._mlfoundry_artifacts_api.list_model_versions_post(
            list_model_versions_request_dto=request
        )
        versions = response.model_versions
        page_token = response.next_page_token
        for entry in versions:
            yield ModelVersion.from_fqn(fqn=entry.fqn)
        # Stop on an empty page or when no further page is advertised.
        if not versions or page_token is None:
            break
|
|
416
|
+
|
|
417
|
+
def _add_git_info(self, root_path: Optional[str] = None):
    """Log git details of a working tree as tags on this run.

    Tags are written for the current commit sha, the current branch name,
    the dirty flag and, when one is available, the remote url. Any failure
    is logged as a warning and otherwise ignored.

    Args:
        root_path (Optional[str]): directory handed to `GitInfo`. A falsy
            value means the current working directory.
    """
    if not root_path:
        root_path = os.getcwd()
    try:
        from truefoundry.ml.git_info import GitInfo

        repo = GitInfo(root_path)
        git_tags = [
            RunTagDto(
                key=constants.GIT_COMMIT_TAG_NAME,
                value=repo.current_commit_sha,
            ),
            RunTagDto(
                key=constants.GIT_BRANCH_TAG_NAME,
                value=repo.current_branch_name,
            ),
            RunTagDto(key=constants.GIT_DIRTY_TAG_NAME, value=str(repo.is_dirty)),
        ]
        origin_url = repo.remote_url
        if origin_url is not None:
            git_tags.append(
                RunTagDto(key=constants.GIT_REMOTE_URL_NAME, value=origin_url)
            )
        _validate_batch_log_data(metrics=[], params=[], tags=git_tags)
        batch = LogBatchRequestDto(run_id=self.run_id, tags=git_tags)
        self._runs_api.log_run_batch_post(log_batch_request_dto=batch)
    except Exception as ex:
        # no-blocking
        logger.warning(f"failed to log git info because {ex}")
|
|
453
|
+
|
|
454
|
+
def _add_python_truefoundry_version(self):
    """Tag the current run with the Python and truefoundry versions.

    The Python version tag is always logged. When `version.__version__` is
    truthy it is logged under both the MLFoundry and the TrueFoundry version
    tag names; otherwise a warning is emitted and only the Python version
    tag is sent.
    """
    sdk_version = version.__version__
    version_tags = [
        RunTagDto(
            key=constants.PYTHON_VERSION_TAG_NAME,
            value=platform.python_version(),
        )
    ]

    if not sdk_version:
        logger.warning("Failed to get MLFoundry version.")
    else:
        # The same version string is written under both tag names.
        version_tags.append(
            RunTagDto(key=constants.MLFOUNDRY_VERSION_TAG_NAME, value=sdk_version)
        )
        version_tags.append(
            RunTagDto(key=constants.TRUEFOUNDRY_VERSION_TAG_NAME, value=sdk_version)
        )

    _validate_batch_log_data(metrics=[], params=[], tags=version_tags)
    request = LogBatchRequestDto(run_id=self.run_id, tags=version_tags)
    self._runs_api.log_run_batch_post(log_batch_request_dto=request)
|
|
483
|
+
|
|
484
|
+
@_ensure_not_deleted
def log_artifact(
    self,
    name: str,
    artifact_paths: List[
        Union[Tuple[str], Tuple[str, Optional[str]], ArtifactPath]
    ],
    description: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    step: Optional[int] = 0,
    progress: Optional[bool] = None,
) -> ArtifactVersion:
    """Logs an artifact for the current ML Repo.

    An `artifact` is a list of local files and directories.
    This function packs the files and directories mentioned in `artifact_paths`
    and uploads them to the remote storage linked to the experiment.

    Args:
        name (str): Name of the Artifact. If an artifact with this name already
            exists under the current ML Repo, the logged artifact will be added
            as a new version under that `name`. If no artifact exists with the
            given `name`, the given artifact will be logged as version 1.
        artifact_paths (List[truefoundry.ml.ArtifactPath]): A non-empty list of
            pairs of (source path, destination path) to add files and folders
            to the artifact version contents. The first member of the pair
            should be a file or directory path and the second member should be
            the path inside the artifact contents to upload to.

            ```python
            from truefoundry.ml import ArtifactPath

            ...
            run.log_artifact(
                name="xyz",
                artifact_paths=[
                    ArtifactPath("foo.txt", "foo/bar/foo.txt"),
                    ArtifactPath("tokenizer/", "foo/tokenizer/"),
                    ArtifactPath('bar.text'),
                    ('bar.txt', ),
                    ('foo.txt', 'a/foo.txt')
                ]
            )
            ```

            would result in

            ```
            .
            └── foo/
                ├── bar/
                │   └── foo.txt
                └── tokenizer/
                    └── # contents of tokenizer/ directory will be uploaded here
            ```
        description (Optional[str], optional): arbitrary text up to 1024
            characters to store as description.
            This field can be updated at any time after logging. Defaults to `None`
        metadata (Optional[Dict[str, Any]], optional): arbitrary json
            serializable dictionary to store metadata.
            For example, you can use this to store metrics, params, notes.
            This field can be updated at any time after logging. Defaults to `None`
        step (int): step/iteration at which the version is being logged,
            defaults to 0.
        progress (bool): value to show progress bar, defaults to None.

    Returns:
        truefoundry.ml.ArtifactVersion: an instance of `ArtifactVersion` that
            can be used to download the files, or update attributes like
            description, metadata.

    Raises:
        MlFoundryException: if `artifact_paths` is empty.

    Examples:

        ```python
        import os
        from truefoundry.ml import get_client, ArtifactPath

        with open("artifact.txt", "w") as f:
            f.write("hello-world")

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project", run_name="svm-with-rbf-kernel"
        )

        run.log_artifact(
            name="hello-world-file",
            artifact_paths=[ArtifactPath('artifact.txt', 'a/b/')]
        )

        run.end()
        ```
    """
    if artifact_paths:
        return _log_artifact_version(
            self,
            name=name,
            artifact_paths=artifact_paths,
            description=description,
            metadata=metadata,
            step=step,
            progress=progress,
        )

    # Nothing to upload: an artifact version needs at least one path.
    raise MlFoundryException(
        "artifact_paths cannot be empty, atleast one artifact_path must be passed"
    )
|
|
585
|
+
|
|
586
|
+
@_ensure_not_deleted
def log_metrics(self, metric_dict: Dict[str, Union[int, float]], step: int = 0):
    """Log metrics for the current `run`.

    A metric is defined by a metric name (such as "training-loss") and a
    floating point or integral value (such as `1.2`). A metric is associated
    with a `step` which is the training iteration at which the metric was
    calculated. All metrics of one call share a single millisecond timestamp
    taken when the call starts.

    Entries with a string value or with a name that does not match
    `VALID_PARAM_AND_METRIC_NAMES` are discarded with a warning. If nothing
    is left after filtering, a warning is logged and no request is made.

    Args:
        metric_dict (Dict[str, Union[int, float]]): A metric name to metric value map.
            metric value should be either `float` or `int`. This should be
            a non-empty dictionary.
        step (int, optional): Training step/iteration at which the metrics
            present in `metric_dict` were calculated. If not passed, `0` is
            set as the `step`.

    Raises:
        MlFoundryException: if validating or sending the batch fails.

    Examples:

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.log_metrics(metric_dict={"accuracy": 0.7, "loss": 0.6}, step=0)
        run.log_metrics(metric_dict={"accuracy": 0.8, "loss": 0.4}, step=1)

        run.end()
        ```
    """
    logged_at_ms = int(time.time() * 1000)
    accepted = []
    for metric_name, metric_value in metric_dict.items():
        if isinstance(metric_value, str):
            logger.warning(
                f"Cannot log metric with string value. Discarding metric {metric_name}={metric_value}"
            )
        elif not self.VALID_PARAM_AND_METRIC_NAMES.match(metric_name):
            logger.warning(
                f"Invalid metric name: {metric_name}. Names may only contain alphanumerics, "
                f"underscores (_), dashes (-), periods (.), spaces ( ), and slashes (/). "
                f"Discarding metric {metric_name}={metric_value}"
            )
        else:
            accepted.append(
                MetricDto(
                    key=metric_name,
                    value=metric_value,
                    timestamp=logged_at_ms,
                    step=step,
                )
            )

    if not accepted:
        logger.warning("Cannot log empty metrics dictionary")
        return

    try:
        _validate_batch_log_data(metrics=accepted, params=[], tags=[])
        request = LogBatchRequestDto(
            run_id=self.run_id, metrics=accepted, params=[], tags=[]
        )
        self._runs_api.log_run_batch_post(log_batch_request_dto=request)
    except Exception as err:
        raise MlFoundryException(str(err)) from err

    logger.info("Metrics logged successfully")
|
|
654
|
+
|
|
655
|
+
@_ensure_not_deleted
def log_params(self, param_dict: ParamsType, flatten_params: bool = False):
    """Logs parameters for the run.

    Parameters or Hyperparameters can be thought of as configurations for a run.
    For example, the type of kernel used in a SVM model is a parameter.
    A Parameter is defined by a name and a string value. Parameters are
    also immutable, we cannot overwrite parameter value for a parameter
    name.

    A parameter is discarded with a warning when the string form of its name
    or of its value is longer than `MAX_ENTITY_KEY_LENGTH`, or when its name
    does not match `VALID_PARAM_AND_METRIC_NAMES`. If no parameter is left
    after filtering, a warning is logged and nothing is sent.

    Args:
        param_dict (ParamsType): A parameter name to parameter value map.
            Parameter values are converted to `str`.
        flatten_params (bool): Flatten hierarchical dict, e.g. `{'a': {'b': 'c'}} -> {'a.b': 'c'}`.
            All the keys will be converted to `str`. Defaults to False

    Raises:
        MlFoundryException: if processing, validating or sending the
            parameters fails.

    Examples:

        ### Logging parameters using a `dict`.

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.log_params({"learning_rate": 0.01, "epochs": 10})

        run.end()
        ```

        ### Logging parameters using `argparse` Namespace object

        ```python
        import argparse
        from truefoundry.ml import get_client

        parser = argparse.ArgumentParser()
        parser.add_argument("-batch_size", type=int, required=True)
        args = parser.parse_args()

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.log_params(args)
        ```
    """
    try:
        param_dict = process_params(param_dict)
        param_dict = flatten_dict(param_dict) if flatten_params else param_dict

        params = []
        for param_key, param_value in param_dict.items():
            if (
                len(str(param_key)) > MAX_ENTITY_KEY_LENGTH
                or len(str(param_value)) > MAX_ENTITY_KEY_LENGTH
            ):
                logger.warning(
                    f"MlFoundry can't log parameters with length greater than {MAX_ENTITY_KEY_LENGTH} characters. "
                    f"Discarding {param_key}:{param_value}."
                )
                continue
            if not self.VALID_PARAM_AND_METRIC_NAMES.match(param_key):
                logger.warning(
                    f"Invalid param name: {param_key}. Names may only contain alphanumerics, "
                    f"underscores (_), dashes (-), periods (.), spaces ( ), and slashes (/). "
                    f"Discarding param {param_key}={param_value}"
                )
                continue
            params.append(ParamDto(key=param_key, value=str(param_value)))

        if not params:
            logger.warning("Cannot log empty params dictionary")
            # Mirror `log_metrics`: with nothing to log, skip the request
            # instead of posting an empty batch and reporting success.
            return

        _validate_batch_log_data(metrics=[], params=params, tags=[])
        self._runs_api.log_run_batch_post(
            log_batch_request_dto=LogBatchRequestDto(
                run_id=self.run_id, metrics=[], params=params, tags=[]
            )
        )
    except Exception as e:
        raise MlFoundryException(str(e)) from e
    logger.info("Parameters logged successfully")
|
|
740
|
+
|
|
741
|
+
@_ensure_not_deleted
def set_tags(self, tags: Dict[str, str]):
    """Set tags for the current `run`.

    Tags are "labels" for a run. A tag is represented by a tag name and value.

    Args:
        tags (Dict[str, str]): A tag name to value map.
            Tag name cannot start with `mlf.`, `mlf.` prefix
            is reserved for truefoundry. Tag values will be converted
            to `str`. A falsy value is treated as an empty map.

    Raises:
        MlFoundryException: if the namespace check or the request fails.

    Examples:

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.set_tags({"nlp.framework": "Spark NLP"})

        run.end()
        ```
    """
    tag_map = tags if tags else {}
    try:
        # Reject tag names that fall into the reserved namespace.
        NAMESPACE.validate_namespace_not_used(names=tag_map.keys())
        tag_dtos = []
        for tag_name, tag_value in tag_map.items():
            tag_dtos.append(RunTagDto(key=tag_name, value=str(tag_value)))
        request = LogBatchRequestDto(
            run_id=self.run_id, metrics=[], params=[], tags=tag_dtos
        )
        self._runs_api.log_run_batch_post(log_batch_request_dto=request)
    except Exception as err:
        raise MlFoundryException(str(err)) from err
    else:
        logger.info("Tags set successfully")
|
|
781
|
+
|
|
782
|
+
@_ensure_not_deleted
def get_tags(self, no_cache=False) -> Dict[str, str]:
    """Returns all the tags set for the current `run`.

    The run data held on this object is reused when present; it is fetched
    again when `no_cache` is true or when no run data is held yet.

    Args:
        no_cache (bool): force fetching the run data again. Defaults to False.

    Returns:
        Dict[str, str]: A dictionary containing tags. The keys in the dictionary
            are tag names and the values are corresponding tag values.

    Examples:

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.set_tags({"nlp.framework": "Spark NLP"})
        print(run.get_tags())

        run.end()
        ```
    """
    must_fetch = no_cache or not self._run_data
    if must_fetch:
        response = self._runs_api.get_run_get(run_id=self.run_id)
        self._run_data = response.run.data
        assert self._run_data is not None

    tags_by_name: Dict[str, str] = {}
    for run_tag in self._run_data.tags or []:
        tags_by_name[run_tag.key] = run_tag.value
    return tags_by_name
|
|
811
|
+
|
|
812
|
+
@_ensure_not_deleted
def get_metrics(
    self, metric_names: Optional[Iterable[str]] = None
) -> Dict[str, List[MetricDto]]:
    """Get metrics logged for the current `run` grouped by metric name.

    The run is fetched once to learn which metric names exist, then the
    history of each requested and existing metric is fetched separately.
    Requested names that are not present in the run produce a warning and
    map to an empty list in the result.

    Args:
        metric_names (Optional[Iterable[str]], optional): A list of metric names
            for which the logged metrics will be fetched. If not passed, then all
            metrics logged under the `run` are returned.

    Returns:
        Dict[str, List[MetricDto]]: A dictionary containing metric name to list
            of metrics map.

    Examples:

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project", run_name="svm-with-rbf-kernel"
        )
        run.log_metrics(metric_dict={"accuracy": 0.7, "loss": 0.6}, step=0)
        run.log_metrics(metric_dict={"accuracy": 0.8, "loss": 0.4}, step=1)

        metrics = run.get_metrics()
        for metric_name, metric_history in metrics.items():
            print(f"logged metrics for metric {metric_name}:")
            for metric in metric_history:
                print(f"value: {metric.value}")
                print(f"step: {metric.step}")
                print(f"timestamp_ms: {metric.timestamp}")
                print("--")

        run.end()
        ```
    """
    response = self._runs_api.get_run_get(run_id=self.run_id)
    assert response.run.data is not None
    logged_metrics = response.run.data.metrics or []
    run_metric_names = {logged.key for logged in logged_metrics}

    if metric_names is not None:
        metric_names = set(metric_names)
    else:
        metric_names = run_metric_names

    unknown_metrics = metric_names - run_metric_names
    if unknown_metrics:
        logger.warning(f"{unknown_metrics} metrics not present in the run")

    # Unknown names still appear in the result, each with its own empty list.
    history_by_name: Dict[str, List[MetricDto]] = {}
    for missing_name in unknown_metrics:
        history_by_name[missing_name] = []

    for known_name in metric_names - unknown_metrics:
        history = self._metrics_api.get_metric_history_get(
            run_id=self.run_id, metric_key=known_name
        )
        history_by_name[known_name] = history.metrics
    return history_by_name
|
|
873
|
+
|
|
874
|
+
@_ensure_not_deleted
def get_params(self, no_cache=False) -> Dict[str, str]:
    """Get all the params logged for the current `run`.

    The run data held on this object is reused when present; it is fetched
    again when `no_cache` is true or when no run data is held yet.

    Args:
        no_cache (bool): force fetching the run data again. Defaults to False.

    Returns:
        Dict[str, str]: A dictionary containing the parameters. The keys in the dictionary
            are parameter names and the values are corresponding parameter values.

    Examples:

        ```python
        from truefoundry.ml import get_client

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        run.log_params({"learning_rate": 0.01, "epochs": 10})
        print(run.get_params())

        run.end()
        ```
    """
    must_fetch = no_cache or not self._run_data
    if must_fetch:
        response = self._runs_api.get_run_get(run_id=self.run_id)
        self._run_data = response.run.data
        assert self._run_data is not None

    params_by_name: Dict[str, str] = {}
    for run_param in self._run_data.params or []:
        params_by_name[run_param.key] = run_param.value
    return params_by_name
|
|
903
|
+
|
|
904
|
+
@_ensure_not_deleted
def log_model(
    self,
    *,
    name: str,
    model_file_or_folder: str,
    framework: Optional[Union[enums.ModelFramework, str]],
    additional_files: Sequence[Tuple[Union[str, Path], Optional[str]]] = (),
    description: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    model_schema: Optional[Union[ModelSchema, Dict[str, Any]]] = None,
    custom_metrics: Optional[List[Union[CustomMetric, Dict[str, Any]]]] = None,
    step: int = 0,
    progress: Optional[bool] = None,
) -> ModelVersion:
    # TODO (chiragjn): Document mapping of framework to list of valid model save kwargs
    # TODO (chiragjn): Add more examples
    """
    Log a versioned model under the current ML Repo. Each logged model generates a new version
    associated with the given `name` and linked to the current run. Multiple versions of the model can be
    logged as separate versions under the same `name`.

    Args:
        name (str): Name of the model. If a model with this name already exists under the current ML Repo,
            the logged model will be added as a new version under that `name`. If no models exist with the given
            `name`, the given model will be logged as version 1.
        model_file_or_folder (str): Path to either a single file or a folder containing model files. This folder
            is usually created using serialization methods of libraries or frameworks e.g. `joblib.dump`,
            `model.save_pretrained(...)`, `torch.save(...)`, `model.save(...)`
        framework (Optional[Union[enums.ModelFramework, str]]): Model Framework. Ex:- pytorch, sklearn,
            tensorflow etc. The full list of supported frameworks can be found in
            `truefoundry.ml.enums.ModelFramework`. May be `None`.
        additional_files (Sequence[Tuple[Union[str, Path], Optional[str]]], optional): A list of pairs
            of (source path, destination path) to add additional files and folders
            to the model version contents. The first member of the pair should be a file or directory path
            and the second member should be the path inside the model versions contents to upload to.
            The model version contents are arranged as follows

            ```
            .
            └── model/
                └── # model files are serialized here
            └── # any additional files and folders can be added here.
            ```

            You can also add additional files to model/ subdirectory by specifying the destination path as model/

            ```
            E.g. >>> run.log_model(
                 ...     name="xyz", model_file_or_folder="clf.joblib", framework="sklearn",
                 ...     additional_files=[("foo.txt", "foo/bar/foo.txt"), ("tokenizer/", "foo/tokenizer/")]
                 ... )
            would result in
            .
            ├── model/
            │   └── clf.joblib # if `model_file_or_folder` is a folder, contents will be added here
            └── foo/
                ├── bar/
                │   └── foo.txt
                └── tokenizer/
                    └── # contents of tokenizer/ directory will be uploaded here
            ```
        description (Optional[str], optional): arbitrary text up to 1024 characters to store as description.
            This field can be updated at any time after logging. Defaults to `None`
        metadata (Optional[Dict[str, Any]], optional): arbitrary json serializable dictionary to store metadata.
            For example, you can use this to store metrics, params, notes.
            This field can be updated at any time after logging. Defaults to `None`
        model_schema (Optional[Union[Dict[str, Any], ModelSchema]], optional):
            instance of `truefoundry.ml.ModelSchema`.
            This schema needs to be consistent with older versions of the model under the given `name` i.e.
            a feature's value type and model's prediction type cannot be changed in the schema of new version.
            Features can be removed or added between versions.

            E.g. if a v1 exists whose schema declares feature `feat1` as `int` and a `categorical`
            prediction, then a new schema that declares `feat1` as `string`, or that declares a
            `numerical` prediction, is invalid because it changes the type of an existing feature or
            of the prediction. A new schema that adds a `string` feature `feat2` next to `feat1`, or
            that keeps only `feat2`, is valid.

            This field can be updated at any time after logging. Defaults to None
        custom_metrics: (Optional[List[Union[CustomMetric, Dict[str, Any]]]], optional): list of instances of
            `truefoundry.ml.CustomMetric`
            The custom metrics must be added according to the prediction type of schema.
            custom_metrics = [{
                "name": "mean_square_error",
                "type": "metric",
                "value_type": "float"
            }]
        step (int): step/iteration at which the model is being logged, defaults to 0.
        progress (bool): value to show progress bar, defaults to None.

    Returns:
        truefoundry.ml.ModelVersion: an instance of `ModelVersion` that can be used to download the files,
            load the model, or update attributes like description, metadata, schema.

    Examples:

        ### Sklearn

        ```python
        from truefoundry.ml import get_client
        from truefoundry.ml.enums import ModelFramework

        import joblib
        import numpy as np
        from sklearn.pipeline import make_pipeline
        from sklearn.preprocessing import StandardScaler
        from sklearn.svm import SVC

        X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
        y = np.array([1, 1, 2, 2])
        clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        clf.fit(X, y)
        joblib.dump(clf, "sklearn-pipeline.joblib")

        client = get_client()
        client.create_ml_repo(  # This is only required once
            ml_repo="my-classification-project",
            # This controls which bucket is used.
            # You can get this from Integrations > Blob Storage. `None` picks the default
            storage_integration_fqn=None
        )
        run = client.create_run(
            ml_repo="my-classification-project"
        )
        model_version = run.log_model(
            name="my-sklearn-model",
            model_file_or_folder="sklearn-pipeline.joblib",
            framework=ModelFramework.SKLEARN,
            metadata={"accuracy": 0.99, "f1": 0.80},
            step=1,  # step number, useful when using iterative algorithms like SGD
        )
        print(model_version.fqn)
        ```

        ### Huggingface Transformers

        ```python
        from truefoundry.ml import get_client
        from truefoundry.ml.enums import ModelFramework

        import torch
        from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM
        pln = pipeline(
            "text-generation",
            model="EleutherAI/pythia-70m",
            tokenizer="EleutherAI/pythia-70m",
            torch_dtype=torch.float16
        )
        pln.model.save_pretrained("my-transformers-model")
        pln.tokenizer.save_pretrained("my-transformers-model")

        client = get_client()
        client.create_ml_repo(  # This is only required once
            ml_repo="my-llm-project",
            # This controls which bucket is used.
            # You can get this from Integrations > Blob Storage. `None` picks the default
            storage_integration_fqn=None
        )
        run = client.create_run(
            ml_repo="my-llm-project"
        )
        model_version = run.log_model(
            name="my-transformers-model",
            model_file_or_folder="my-transformers-model/",
            framework=ModelFramework.TRANSFORMERS
        )
        print(model_version.fqn)
        ```
    """

    # Everything is forwarded unchanged to the module-level helper.
    version_kwargs = dict(
        name=name,
        model_file_or_folder=model_file_or_folder,
        framework=framework,
        additional_files=additional_files,
        description=description,
        metadata=metadata,
        model_schema=model_schema,
        custom_metrics=custom_metrics,
        step=step,
        progress=progress,
    )
    logged_version = _log_model_version(run=self, **version_kwargs)
    logger.info(f"Logged model successfully with fqn {logged_version.fqn!r}")
    return logged_version
|
|
1094
|
+
|
|
1095
|
+
@_ensure_not_deleted
def log_images(self, images: Dict[str, Image], step: int = 0):
    """Log images under the current `run` at the given `step`.

    Use this function to log images for a `run`. `PIL` package is needed to log images.
    To install the `PIL` package, run `pip install pillow`.

    Every value is type-checked before anything is saved, so an invalid
    entry never leaves a partially logged set of images behind.

    Args:
        images (Dict[str, "truefoundry.ml.Image"]): A map of string image key to instance of
            `truefoundry.ml.Image` class. The image key should only contain alphanumeric,
            hyphens(-) or underscores(_). For a single key and step pair, we can log only
            one image.
        step (int, optional): Training step/iteration for which the `images` should be
            logged. Default is `0`.

    Raises:
        MlFoundryException: if any value in `images` is not an instance of
            `truefoundry.ml.Image`.

    Examples:

        ### Logging images from different sources

        ```python
        from truefoundry.ml import get_client, Image
        import numpy as np
        import PIL.Image

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project",
        )

        imarray = np.random.randint(low=0, high=256, size=(100, 100, 3))
        im = PIL.Image.fromarray(imarray.astype("uint8")).convert("RGB")
        im.save("result_image.jpeg")

        images_to_log = {
            "logged-image-array": Image(data_or_path=imarray),
            "logged-pil-image": Image(data_or_path=im),
            "logged-image-from-path": Image(data_or_path="result_image.jpeg"),
        }

        run.log_images(images_to_log, step=1)
        run.end()
        ```
    """
    # Validate up front: failing midway would leave earlier images saved
    # while the call as a whole reports an error.
    if any(not isinstance(image, Image) for image in images.values()):
        raise MlFoundryException(
            "image should be of type `truefoundry.ml.Image`"
        )

    for key, image in images.items():
        image.save(run=self, key=key, step=step)
|
|
1144
|
+
|
|
1145
|
+
@_ensure_not_deleted
def log_plots(
    self,
    plots: Dict[
        str,
        Union[
            "matplotlib.pyplot",
            "matplotlib.figure.Figure",
            "plotly.graph_objects.Figure",
            Plot,
        ],
    ],
    step: int = 0,
):
    """Log custom plots under the current `run` at the given `step`.

    Use this function to log custom matplotlib, plotly plots. Values that
    are not already `Plot` instances are wrapped in `Plot` before saving.

    Args:
        plots (Dict[str, "matplotlib.pyplot", "matplotlib.figure.Figure", "plotly.graph_objects.Figure", Plot]):
            A map of string plot key to the plot or figure object.
            The plot key should only contain alphanumeric, hyphens(-) or
            underscores(_). For a single key and step pair, we can log only
            one image.
        step (int, optional): Training step/iteration for which the `plots` should be
            logged. Default is `0`.


    Examples:

        ### Logging a plotly figure

        ```python
        from truefoundry.ml import get_client
        import plotly.express as px

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project",
        )

        df = px.data.tips()
        fig = px.histogram(
            df,
            x="total_bill",
            y="tip",
            color="sex",
            marginal="rug",
            hover_data=df.columns,
        )

        plots_to_log = {
            "distribution-plot": fig,
        }

        run.log_plots(plots_to_log, step=1)
        run.end()
        ```


        ### Logging a matplotlib plt or figure

        ```python
        from truefoundry.ml import get_client
        from matplotlib import pyplot as plt
        import numpy as np

        client = get_client()
        run = client.create_run(
            ml_repo="my-classification-project",
        )

        t = np.arange(0.0, 5.0, 0.01)
        s = np.cos(2 * np.pi * t)
        (line,) = plt.plot(t, s, lw=2)

        plt.annotate(
            "local max",
            xy=(2, 1),
            xytext=(3, 1.5),
            arrowprops=dict(facecolor="black", shrink=0.05),
        )

        plt.ylim(-2, 2)

        plots_to_log = {"cos-plot": plt, "cos-plot-using-figure": plt.gcf()}

        run.log_plots(plots_to_log, step=1)
        run.end()
        ```
    """
    for plot_key, figure in plots.items():
        if isinstance(figure, Plot):
            wrapped = figure
        else:
            # Raw matplotlib/plotly objects get wrapped first.
            wrapped = Plot(figure)
        wrapped.save(run=self, key=plot_key, step=step)
|