oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +12 -0
- ads/aqua/base.py +324 -0
- ads/aqua/cli.py +19 -0
- ads/aqua/config/deployment_config_defaults.json +9 -0
- ads/aqua/config/resource_limit_names.json +7 -0
- ads/aqua/constants.py +45 -0
- ads/aqua/data.py +40 -0
- ads/aqua/decorator.py +101 -0
- ads/aqua/deployment.py +643 -0
- ads/aqua/dummy_data/icon.txt +1 -0
- ads/aqua/dummy_data/oci_model_deployments.json +56 -0
- ads/aqua/dummy_data/oci_models.json +1 -0
- ads/aqua/dummy_data/readme.md +26 -0
- ads/aqua/evaluation.py +1751 -0
- ads/aqua/exception.py +82 -0
- ads/aqua/extension/__init__.py +40 -0
- ads/aqua/extension/base_handler.py +138 -0
- ads/aqua/extension/common_handler.py +21 -0
- ads/aqua/extension/deployment_handler.py +202 -0
- ads/aqua/extension/evaluation_handler.py +135 -0
- ads/aqua/extension/finetune_handler.py +66 -0
- ads/aqua/extension/model_handler.py +59 -0
- ads/aqua/extension/ui_handler.py +201 -0
- ads/aqua/extension/utils.py +23 -0
- ads/aqua/finetune.py +579 -0
- ads/aqua/job.py +29 -0
- ads/aqua/model.py +819 -0
- ads/aqua/training/__init__.py +4 -0
- ads/aqua/training/exceptions.py +459 -0
- ads/aqua/ui.py +453 -0
- ads/aqua/utils.py +715 -0
- ads/cli.py +37 -6
- ads/common/decorator/__init__.py +7 -3
- ads/common/decorator/require_nonempty_arg.py +65 -0
- ads/common/object_storage_details.py +166 -7
- ads/common/oci_client.py +18 -1
- ads/common/oci_logging.py +2 -2
- ads/common/oci_mixin.py +4 -5
- ads/common/serializer.py +34 -5
- ads/common/utils.py +75 -10
- ads/config.py +40 -1
- ads/jobs/ads_job.py +43 -25
- ads/jobs/builders/infrastructure/base.py +4 -2
- ads/jobs/builders/infrastructure/dsc_job.py +49 -39
- ads/jobs/builders/runtimes/base.py +71 -1
- ads/jobs/builders/runtimes/container_runtime.py +4 -4
- ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
- ads/jobs/templates/driver_pytorch.py +27 -10
- ads/model/artifact_downloader.py +84 -14
- ads/model/artifact_uploader.py +25 -23
- ads/model/datascience_model.py +388 -38
- ads/model/deployment/model_deployment.py +10 -2
- ads/model/generic_model.py +8 -0
- ads/model/model_file_description_schema.json +68 -0
- ads/model/model_metadata.py +1 -1
- ads/model/service/oci_datascience_model.py +34 -5
- ads/opctl/operator/lowcode/anomaly/README.md +2 -1
- ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
- ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
- ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
- ads/opctl/operator/lowcode/forecast/README.md +3 -2
- ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
- ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
- ads/telemetry/base.py +62 -0
- ads/telemetry/client.py +105 -0
- ads/telemetry/telemetry.py +6 -3
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -1761,7 +1761,11 @@ class ModelDeployment(Builder):
|
|
1761
1761
|
}
|
1762
1762
|
|
1763
1763
|
logs = {}
|
1764
|
-
if
|
1764
|
+
if (
|
1765
|
+
self.infrastructure.access_log and
|
1766
|
+
self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
|
1767
|
+
and self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_ID, None)
|
1768
|
+
):
|
1765
1769
|
logs[self.infrastructure.CONST_ACCESS] = {
|
1766
1770
|
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.access_log.get(
|
1767
1771
|
"logGroupId", None
|
@@ -1770,7 +1774,11 @@ class ModelDeployment(Builder):
|
|
1770
1774
|
"logId", None
|
1771
1775
|
),
|
1772
1776
|
}
|
1773
|
-
if
|
1777
|
+
if (
|
1778
|
+
self.infrastructure.predict_log and
|
1779
|
+
self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
|
1780
|
+
and self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_ID, None)
|
1781
|
+
):
|
1774
1782
|
logs[self.infrastructure.CONST_PREDICT] = {
|
1775
1783
|
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.predict_log.get(
|
1776
1784
|
"logGroupId", None
|
ads/model/generic_model.py
CHANGED
@@ -2054,6 +2054,7 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2054
2054
|
remove_existing_artifact: Optional[bool] = True,
|
2055
2055
|
reload: Optional[bool] = True,
|
2056
2056
|
version_label: Optional[str] = None,
|
2057
|
+
model_by_reference: Optional[bool] = False,
|
2057
2058
|
**kwargs,
|
2058
2059
|
) -> str:
|
2059
2060
|
"""Saves model artifacts to the model catalog.
|
@@ -2091,6 +2092,8 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2091
2092
|
The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
|
2092
2093
|
reload: (bool, optional)
|
2093
2094
|
Whether to reload to check if `load_model()` works in `score.py`. Default to `True`.
|
2095
|
+
model_by_reference: (bool, optional)
|
2096
|
+
Whether model artifact is made available to Model Store by reference.
|
2094
2097
|
kwargs:
|
2095
2098
|
project_id: (str, optional).
|
2096
2099
|
Project OCID. If not specified, the value will be taken either
|
@@ -2220,6 +2223,7 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2220
2223
|
overwrite_existing_artifact=overwrite_existing_artifact,
|
2221
2224
|
remove_existing_artifact=remove_existing_artifact,
|
2222
2225
|
parallel_process_count=parallel_process_count,
|
2226
|
+
model_by_reference=model_by_reference,
|
2223
2227
|
**kwargs,
|
2224
2228
|
)
|
2225
2229
|
|
@@ -2620,6 +2624,7 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2620
2624
|
remove_existing_artifact: Optional[bool] = True,
|
2621
2625
|
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
|
2622
2626
|
version_label: Optional[str] = None,
|
2627
|
+
model_by_reference: Optional[bool] = False,
|
2623
2628
|
**kwargs: Dict,
|
2624
2629
|
) -> "ModelDeployment":
|
2625
2630
|
"""Shortcut for prepare, save and deploy steps.
|
@@ -2724,6 +2729,8 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2724
2729
|
The Model version set OCID, or name, or `ModelVersionSet` instance.
|
2725
2730
|
version_label: (str, optional). Defaults to None.
|
2726
2731
|
The model version lebel.
|
2732
|
+
model_by_reference: (bool, optional)
|
2733
|
+
Whether model artifact is made available to Model Store by reference.
|
2727
2734
|
kwargs:
|
2728
2735
|
impute_values: (dict, optional).
|
2729
2736
|
The dictionary where the key is the column index(or names is accepted
|
@@ -2827,6 +2834,7 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
|
|
2827
2834
|
model_version_set=model_version_set,
|
2828
2835
|
version_label=version_label,
|
2829
2836
|
region=kwargs.pop("region", None),
|
2837
|
+
model_by_reference=model_by_reference,
|
2830
2838
|
)
|
2831
2839
|
# Set default deployment_display_name if not specified - randomly generated easy to remember name generated
|
2832
2840
|
if not deployment_display_name:
|
@@ -0,0 +1,68 @@
|
|
1
|
+
{
|
2
|
+
"$schema": "http://json-schema.org/draft-07/schema#",
|
3
|
+
"properties": {
|
4
|
+
"models": {
|
5
|
+
"items": {
|
6
|
+
"properties": {
|
7
|
+
"bucketName": {
|
8
|
+
"type": "string"
|
9
|
+
},
|
10
|
+
"namespace": {
|
11
|
+
"type": "string"
|
12
|
+
},
|
13
|
+
"objects": {
|
14
|
+
"items": {
|
15
|
+
"properties": {
|
16
|
+
"name": {
|
17
|
+
"type": "string"
|
18
|
+
},
|
19
|
+
"sizeInBytes": {
|
20
|
+
"minimum": 0,
|
21
|
+
"type": "integer"
|
22
|
+
},
|
23
|
+
"version": {
|
24
|
+
"type": "string"
|
25
|
+
}
|
26
|
+
},
|
27
|
+
"required": [
|
28
|
+
"name",
|
29
|
+
"version",
|
30
|
+
"sizeInBytes"
|
31
|
+
],
|
32
|
+
"type": "object"
|
33
|
+
},
|
34
|
+
"minItems": 1,
|
35
|
+
"type": "array"
|
36
|
+
},
|
37
|
+
"prefix": {
|
38
|
+
"type": "string"
|
39
|
+
}
|
40
|
+
},
|
41
|
+
"required": [
|
42
|
+
"namespace",
|
43
|
+
"bucketName",
|
44
|
+
"prefix",
|
45
|
+
"objects"
|
46
|
+
],
|
47
|
+
"type": "object"
|
48
|
+
},
|
49
|
+
"minItems": 1,
|
50
|
+
"type": "array"
|
51
|
+
},
|
52
|
+
"type": {
|
53
|
+
"enum": [
|
54
|
+
"modelOSSReferenceDescription"
|
55
|
+
],
|
56
|
+
"type": "string"
|
57
|
+
},
|
58
|
+
"version": {
|
59
|
+
"type": "string"
|
60
|
+
}
|
61
|
+
},
|
62
|
+
"required": [
|
63
|
+
"version",
|
64
|
+
"type",
|
65
|
+
"models"
|
66
|
+
],
|
67
|
+
"type": "object"
|
68
|
+
}
|
ads/model/model_metadata.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8 -*--
|
3
3
|
|
4
|
-
# Copyright (c) 2021,
|
4
|
+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
import json
|
@@ -38,6 +38,8 @@ MODEL_NEEDS_TO_BE_SAVED = (
|
|
38
38
|
"Model needs to be saved to the Model Catalog before it can be accessed."
|
39
39
|
)
|
40
40
|
|
41
|
+
MODEL_BY_REFERENCE_DESC = "modelDescription"
|
42
|
+
|
41
43
|
|
42
44
|
class ModelProvenanceNotFoundError(Exception): # pragma: no cover
|
43
45
|
pass
|
@@ -304,18 +306,25 @@ class OCIDataScienceModel(
|
|
304
306
|
@check_for_model_id(
|
305
307
|
msg="Model needs to be saved to the Model Catalog before the artifact can be created."
|
306
308
|
)
|
307
|
-
def create_model_artifact(
|
309
|
+
def create_model_artifact(
|
310
|
+
self,
|
311
|
+
bytes_content: BytesIO,
|
312
|
+
extension: str = None,
|
313
|
+
) -> None:
|
308
314
|
"""Creates model artifact for specified model.
|
309
315
|
|
310
316
|
Parameters
|
311
317
|
----------
|
312
318
|
bytes_content: BytesIO
|
313
319
|
Model artifacts to upload.
|
320
|
+
extension: str
|
321
|
+
File extension, defaults to zip
|
314
322
|
"""
|
323
|
+
ext = ".json" if extension and extension.lower() == ".json" else ".zip"
|
315
324
|
self.client.create_model_artifact(
|
316
325
|
self.id,
|
317
326
|
bytes_content,
|
318
|
-
content_disposition=f'attachment; filename="{self.id}
|
327
|
+
content_disposition=f'attachment; filename="{self.id}{ext}"',
|
319
328
|
)
|
320
329
|
|
321
330
|
@check_for_model_id(
|
@@ -423,10 +432,14 @@ class OCIDataScienceModel(
|
|
423
432
|
OCIDataScienceModel
|
424
433
|
The `OCIDataScienceModel` instance (self).
|
425
434
|
"""
|
435
|
+
|
436
|
+
model_details = self.to_oci_model(UpdateModelDetails)
|
437
|
+
|
438
|
+
# Clean up the model version set, otherwise it throws an error that model is already
|
439
|
+
# associated with the model version set.
|
440
|
+
model_details.model_version_set_id = None
|
426
441
|
return self.update_from_oci_model(
|
427
|
-
self.client.update_model(
|
428
|
-
self.id, self.to_oci_model(UpdateModelDetails)
|
429
|
-
).data
|
442
|
+
self.client.update_model(self.id, model_details).data
|
430
443
|
)
|
431
444
|
|
432
445
|
@check_for_model_id(
|
@@ -539,3 +552,19 @@ class OCIDataScienceModel(
|
|
539
552
|
if not ocid:
|
540
553
|
raise ValueError("Model OCID not provided.")
|
541
554
|
return super().from_ocid(ocid)
|
555
|
+
|
556
|
+
def is_model_by_reference(self):
|
557
|
+
"""Checks if model is created by reference
|
558
|
+
Returns
|
559
|
+
-------
|
560
|
+
bool flag denoting whether model was created by reference.
|
561
|
+
|
562
|
+
"""
|
563
|
+
if self.custom_metadata_list:
|
564
|
+
for metadata in self.custom_metadata_list:
|
565
|
+
if (
|
566
|
+
metadata.key == MODEL_BY_REFERENCE_DESC
|
567
|
+
and metadata.value.lower() == "true"
|
568
|
+
):
|
569
|
+
return True
|
570
|
+
return False
|
@@ -37,7 +37,8 @@ To run anomaly detection locally, create and activate a new conda environment (`
|
|
37
37
|
```yaml
|
38
38
|
- datapane
|
39
39
|
- cerberus
|
40
|
-
- oracle-automlx==23.
|
40
|
+
- oracle-automlx==23.4.1
|
41
|
+
- oracle-automlx[classic]==23.4.1
|
41
42
|
- "git+https://github.com/oracle/accelerated-data-science.git@feature/anomaly#egg=oracle-ads"
|
42
43
|
```
|
43
44
|
|
@@ -36,11 +36,17 @@ def operate(operator_config: AnomalyOperatorConfig) -> None:
|
|
36
36
|
operator_config.spec.model = "autots"
|
37
37
|
operator_config.spec.model_kwargs = dict()
|
38
38
|
datasets = AnomalyDatasets(operator_config.spec)
|
39
|
-
|
40
|
-
|
41
|
-
|
39
|
+
try:
|
40
|
+
AnomalyOperatorModelFactory.get_model(
|
41
|
+
operator_config, datasets
|
42
|
+
).generate_report()
|
43
|
+
except Exception as e2:
|
44
|
+
logger.debug(
|
45
|
+
f"Failed to backup forecast with error {e2.args}. Raising original error."
|
46
|
+
)
|
47
|
+
raise e
|
42
48
|
else:
|
43
|
-
raise
|
49
|
+
raise e
|
44
50
|
|
45
51
|
|
46
52
|
def verify(spec: Dict, **kwargs: Dict) -> bool:
|
@@ -17,20 +17,26 @@ class AutoMLXOperatorModel(AnomalyOperatorBaseModel):
|
|
17
17
|
"""Class representing AutoMLX operator model."""
|
18
18
|
|
19
19
|
@runtime_dependency(
|
20
|
-
module="
|
20
|
+
module="automlx",
|
21
21
|
err_msg=(
|
22
|
-
"Please run `pip3 install oracle-automlx==23.
|
23
|
-
"install
|
22
|
+
"Please run `pip3 install oracle-automlx==23.4.1` and "
|
23
|
+
"`pip3 install oracle-automlx[classic]==23.4.1` "
|
24
|
+
"to install the required dependencies for automlx."
|
24
25
|
),
|
25
26
|
)
|
26
27
|
def _build_model(self) -> pd.DataFrame:
|
28
|
+
from automlx import init
|
29
|
+
try:
|
30
|
+
init(engine="ray", engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}})
|
31
|
+
except Exception as e:
|
32
|
+
logger.info("Ray already initialized")
|
27
33
|
date_column = self.spec.datetime_column.name
|
28
34
|
anomaly_output = AnomalyOutput(date_column=date_column)
|
35
|
+
time_budget = self.spec.model_kwargs.pop("time_budget", -1)
|
29
36
|
|
30
|
-
time_budget = self.spec.model_kwargs.pop("time_budget", None)
|
31
37
|
# Iterate over the full_data_dict items
|
32
38
|
for target, df in self.datasets.full_data_dict.items():
|
33
|
-
est =
|
39
|
+
est = automlx.Pipeline(task="anomaly_detection", **self.spec.model_kwargs)
|
34
40
|
est.fit(
|
35
41
|
X=df,
|
36
42
|
X_valid=self.X_valid_dict[target]
|
@@ -39,10 +45,10 @@ class AutoMLXOperatorModel(AnomalyOperatorBaseModel):
|
|
39
45
|
y_valid=self.y_valid_dict[target]
|
40
46
|
if self.y_valid_dict is not None
|
41
47
|
else None,
|
42
|
-
time_budget=time_budget,
|
43
48
|
contamination=self.spec.contamination
|
44
49
|
if self.y_valid_dict is not None
|
45
50
|
else None,
|
51
|
+
time_budget=time_budget,
|
46
52
|
)
|
47
53
|
y_pred = est.predict(df)
|
48
54
|
scores = est.predict_proba(df)
|
@@ -38,8 +38,9 @@ To run forecasting locally, create and activate a new conda environment (`ads-fo
|
|
38
38
|
- datapane
|
39
39
|
- cerberus
|
40
40
|
- sktime
|
41
|
-
- optuna==
|
42
|
-
- oracle-automlx==23.
|
41
|
+
- optuna==3.1.0
|
42
|
+
- oracle-automlx==23.4.1
|
43
|
+
- oracle-automlx[forecasting]==23.4.1
|
43
44
|
- oracle-ads>=2.9.0
|
44
45
|
```
|
45
46
|
|
@@ -45,7 +45,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
45
45
|
model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
|
46
46
|
)
|
47
47
|
model_kwargs_cleaned.pop("task", None)
|
48
|
-
time_budget = model_kwargs_cleaned.pop("time_budget",
|
48
|
+
time_budget = model_kwargs_cleaned.pop("time_budget", -1)
|
49
49
|
model_kwargs_cleaned[
|
50
50
|
"preprocessing"
|
51
51
|
] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
|
@@ -55,9 +55,11 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
55
55
|
return data.set_index(self.spec.datetime_column.name)
|
56
56
|
|
57
57
|
@runtime_dependency(
|
58
|
-
module="
|
58
|
+
module="automlx",
|
59
59
|
err_msg=(
|
60
|
-
"Please run `pip3 install oracle-automlx==23.
|
60
|
+
"Please run `pip3 install oracle-automlx==23.4.1` and "
|
61
|
+
"`pip3 install oracle-automlx[forecasting]==23.4.1` "
|
62
|
+
"to install the required dependencies for automlx."
|
61
63
|
),
|
62
64
|
)
|
63
65
|
@runtime_dependency(
|
@@ -67,15 +69,13 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
67
69
|
),
|
68
70
|
)
|
69
71
|
def _build_model(self) -> pd.DataFrame:
|
70
|
-
from
|
72
|
+
from automlx import init
|
71
73
|
from sktime.forecasting.model_selection import temporal_train_test_split
|
74
|
+
try:
|
75
|
+
init(engine="ray", engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}})
|
76
|
+
except Exception as e:
|
77
|
+
logger.info("Ray already initialized")
|
72
78
|
|
73
|
-
init(
|
74
|
-
engine="local",
|
75
|
-
engine_opts={"n_jobs": -1, "model_n_jobs": -1},
|
76
|
-
check_deprecation_warnings=False,
|
77
|
-
logger=50,
|
78
|
-
)
|
79
79
|
|
80
80
|
full_data_dict = self.datasets.get_data_by_series()
|
81
81
|
|
@@ -95,7 +95,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
95
95
|
|
96
96
|
for i, (s_id, df) in enumerate(full_data_dict.items()):
|
97
97
|
try:
|
98
|
-
logger.debug(f"Running
|
98
|
+
logger.debug(f"Running automlx on series {s_id}")
|
99
99
|
model_kwargs = model_kwargs_cleaned.copy()
|
100
100
|
target = self.original_target_column
|
101
101
|
self.forecast_output.init_series_output(
|
@@ -110,7 +110,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
110
110
|
if self.loaded_models is not None:
|
111
111
|
model = self.loaded_models[s_id]
|
112
112
|
else:
|
113
|
-
model =
|
113
|
+
model = automlx.Pipeline(
|
114
114
|
task="forecasting",
|
115
115
|
**model_kwargs,
|
116
116
|
)
|
@@ -149,18 +149,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
|
|
149
149
|
|
150
150
|
self.model_parameters[s_id] = {
|
151
151
|
"framework": SupportedModels.AutoMLX,
|
152
|
-
"score_metric": model.score_metric,
|
153
|
-
"random_state": model.random_state,
|
154
|
-
"model_list": model.model_list,
|
155
|
-
"n_algos_tuned": model.n_algos_tuned,
|
156
|
-
"adaptive_sampling": model.adaptive_sampling,
|
157
|
-
"min_features": model.min_features,
|
158
|
-
"optimization": model.optimization,
|
159
|
-
"preprocessing": model.preprocessing,
|
160
|
-
"search_space": model.search_space,
|
161
152
|
"time_series_period": model.time_series_period,
|
162
|
-
"min_class_instances": model.min_class_instances,
|
163
|
-
"max_tuning_trials": model.max_tuning_trials,
|
164
153
|
"selected_model": model.selected_model_,
|
165
154
|
"selected_model_params": model.selected_model_params_,
|
166
155
|
}
|
ads/telemetry/base.py
ADDED
@@ -0,0 +1,62 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
import logging
|
7
|
+
|
8
|
+
from ads import set_auth
|
9
|
+
from ads.common import oci_client as oc
|
10
|
+
from ads.common.auth import default_signer
|
11
|
+
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
|
12
|
+
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
class TelemetryBase:
|
16
|
+
"""Base class for Telemetry Client."""
|
17
|
+
|
18
|
+
def __init__(self, bucket: str, namespace: str = None) -> None:
|
19
|
+
"""Initializes the telemetry client.
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
bucket : str
|
24
|
+
OCI object storage bucket name storing the telemetry objects.
|
25
|
+
namespace : str, optional
|
26
|
+
Namespace of the OCI object storage bucket, by default None.
|
27
|
+
"""
|
28
|
+
if OCI_RESOURCE_PRINCIPAL_VERSION:
|
29
|
+
set_auth("resource_principal")
|
30
|
+
self._auth = default_signer()
|
31
|
+
self.os_client = oc.OCIClientFactory(**self._auth).object_storage
|
32
|
+
self.bucket = bucket
|
33
|
+
self._namespace = namespace
|
34
|
+
self._service_endpoint = None
|
35
|
+
logger.debug(f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}")
|
36
|
+
|
37
|
+
|
38
|
+
@property
|
39
|
+
def namespace(self) -> str:
|
40
|
+
"""Gets the namespace of the object storage from the tenancy.
|
41
|
+
|
42
|
+
Returns
|
43
|
+
-------
|
44
|
+
str
|
45
|
+
The namespace of the tenancy.
|
46
|
+
"""
|
47
|
+
if not self._namespace:
|
48
|
+
self._namespace = self.os_client.get_namespace().data
|
49
|
+
return self._namespace
|
50
|
+
|
51
|
+
@property
|
52
|
+
def service_endpoint(self):
|
53
|
+
"""Gets the tenancy-specific endpoint.
|
54
|
+
|
55
|
+
Returns
|
56
|
+
-------
|
57
|
+
str
|
58
|
+
Tenancy-specific endpoint.
|
59
|
+
"""
|
60
|
+
if not self._service_endpoint:
|
61
|
+
self._service_endpoint = self.os_client.base_client.endpoint
|
62
|
+
return self._service_endpoint
|
ads/telemetry/client.py
ADDED
@@ -0,0 +1,105 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
|
+
|
6
|
+
|
7
|
+
import logging
|
8
|
+
import threading
|
9
|
+
import urllib.parse
|
10
|
+
import requests
|
11
|
+
from requests import Response
|
12
|
+
from .base import TelemetryBase
|
13
|
+
from ads.config import DEBUG_TELEMETRY
|
14
|
+
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class TelemetryClient(TelemetryBase):
|
20
|
+
"""Represents a telemetry python client providing functions to record an event.
|
21
|
+
|
22
|
+
Methods
|
23
|
+
-------
|
24
|
+
record_event(category: str = None, action: str = None, path: str = None, **kwargs) -> None
|
25
|
+
Send a head request to generate an event record.
|
26
|
+
record_event_async(category: str = None, action: str = None, path: str = None, **kwargs)
|
27
|
+
Starts thread to send a head request to generate an event record.
|
28
|
+
|
29
|
+
Examples
|
30
|
+
--------
|
31
|
+
>>> import os
|
32
|
+
>>> import traceback
|
33
|
+
>>> from ads.telemetry.client import TelemetryClient
|
34
|
+
>>> AQUA_BUCKET = os.environ.get("AQUA_BUCKET", "service-managed-models")
|
35
|
+
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "ociodscdev")
|
36
|
+
>>> telemetry = TelemetryClient(bucket=AQUA_BUCKET, namespace=AQUA_BUCKET_NS)
|
37
|
+
>>> telemetry.record_event_async(category="aqua/service/model", action="create") # records create action
|
38
|
+
>>> telemetry.record_event_async(category="aqua/service/model/create", action="shape", detail="VM.GPU.A10.1")
|
39
|
+
"""
|
40
|
+
|
41
|
+
@staticmethod
|
42
|
+
def _encode_user_agent(**kwargs):
|
43
|
+
message = urllib.parse.urlencode(kwargs)
|
44
|
+
return message
|
45
|
+
|
46
|
+
def record_event(
|
47
|
+
self, category: str = None, action: str = None, detail: str = None, **kwargs
|
48
|
+
) -> Response:
|
49
|
+
"""Send a head request to generate an event record.
|
50
|
+
|
51
|
+
Parameters
|
52
|
+
----------
|
53
|
+
category: (str)
|
54
|
+
Category of the event, which is also the path to the directory containing the object representing the event.
|
55
|
+
action: (str)
|
56
|
+
Filename of the object representing the event.
|
57
|
+
detail: (str)
|
58
|
+
Can be used to pass additional values, if required. When set, detail is converted to an action,
|
59
|
+
category and action are grouped together for telemetry parsing in the backend.
|
60
|
+
**kwargs:
|
61
|
+
Can be used to pass additional attributes like value that will be passed in the headers of the request.
|
62
|
+
|
63
|
+
Returns
|
64
|
+
-------
|
65
|
+
Response
|
66
|
+
"""
|
67
|
+
try:
|
68
|
+
if not category or not action:
|
69
|
+
raise ValueError("Please specify the category and the action.")
|
70
|
+
if detail:
|
71
|
+
category, action = f"{category}/{action}", detail
|
72
|
+
endpoint = f"{self.service_endpoint}/n/{self.namespace}/b/{self.bucket}/o/telemetry/{category}/{action}"
|
73
|
+
headers = {"User-Agent": self._encode_user_agent(**kwargs)}
|
74
|
+
logger.debug(f"Sending telemetry to endpoint: {endpoint}")
|
75
|
+
signer = self._auth["signer"]
|
76
|
+
response = requests.head(endpoint, auth=signer, headers=headers)
|
77
|
+
logger.debug(f"Telemetry status code: {response.status_code}")
|
78
|
+
return response
|
79
|
+
except Exception as e:
|
80
|
+
if DEBUG_TELEMETRY:
|
81
|
+
logger.error(f"There is an error recording telemetry: {e}")
|
82
|
+
|
83
|
+
def record_event_async(
|
84
|
+
self, category: str = None, action: str = None, detail: str = None, **kwargs
|
85
|
+
):
|
86
|
+
"""Send a head request to generate an event record.
|
87
|
+
|
88
|
+
Parameters
|
89
|
+
----------
|
90
|
+
category (str)
|
91
|
+
Category of the event, which is also the path to the directory containing the object representing the event.
|
92
|
+
action (str)
|
93
|
+
Filename of the object representing the event.
|
94
|
+
|
95
|
+
Returns
|
96
|
+
-------
|
97
|
+
Thread
|
98
|
+
A started thread to send a head request to generate an event record.
|
99
|
+
"""
|
100
|
+
thread = threading.Thread(
|
101
|
+
target=self.record_event, args=(category, action, detail), kwargs=kwargs
|
102
|
+
)
|
103
|
+
thread.daemon = True
|
104
|
+
thread.start()
|
105
|
+
return thread
|
ads/telemetry/telemetry.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8 -*--
|
3
3
|
|
4
|
-
# Copyright (c) 2022,
|
4
|
+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
|
7
7
|
import os
|
@@ -102,7 +102,10 @@ def telemetry(
|
|
102
102
|
entry_point
|
103
103
|
)
|
104
104
|
try:
|
105
|
-
|
105
|
+
# todo: inject telemetry arg and later update all functions that use the @telemetry
|
106
|
+
# decorator to accept **kwargs. Comment the below line as some aqua apis don't support kwargs.
|
107
|
+
# return func(*args, **{**kwargs, **{TELEMETRY_ARGUMENT_NAME: telemetry}})
|
108
|
+
return func(*args, **kwargs)
|
106
109
|
except:
|
107
110
|
raise
|
108
111
|
finally:
|
@@ -178,7 +181,7 @@ class Telemetry:
|
|
178
181
|
self: Telemetry
|
179
182
|
An instance of the Telemetry.
|
180
183
|
"""
|
181
|
-
os.environ[self.environ_variable] = self._original_value
|
184
|
+
os.environ[self.environ_variable] = self._original_value or ""
|
182
185
|
return self
|
183
186
|
|
184
187
|
def clean(self) -> "Telemetry":
|