oracle-ads 2.12.10rc0__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ads/aqua/__init__.py +2 -1
  2. ads/aqua/app.py +46 -19
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +19 -14
  6. ads/aqua/common/errors.py +3 -4
  7. ads/aqua/common/utils.py +2 -2
  8. ads/aqua/constants.py +1 -0
  9. ads/aqua/evaluation/constants.py +7 -7
  10. ads/aqua/evaluation/errors.py +3 -4
  11. ads/aqua/evaluation/evaluation.py +20 -12
  12. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  13. ads/aqua/extension/base_handler.py +12 -9
  14. ads/aqua/extension/model_handler.py +29 -1
  15. ads/aqua/extension/models/ws_models.py +5 -6
  16. ads/aqua/finetuning/constants.py +3 -3
  17. ads/aqua/finetuning/entities.py +3 -0
  18. ads/aqua/finetuning/finetuning.py +32 -1
  19. ads/aqua/model/constants.py +7 -7
  20. ads/aqua/model/entities.py +2 -1
  21. ads/aqua/model/enums.py +4 -5
  22. ads/aqua/model/model.py +158 -76
  23. ads/aqua/modeldeployment/deployment.py +22 -10
  24. ads/aqua/modeldeployment/entities.py +3 -1
  25. ads/cli.py +16 -8
  26. ads/common/auth.py +33 -20
  27. ads/common/extended_enum.py +52 -44
  28. ads/llm/__init__.py +11 -8
  29. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  30. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  31. ads/model/artifact_downloader.py +3 -4
  32. ads/model/datascience_model.py +84 -64
  33. ads/model/generic_model.py +3 -3
  34. ads/model/model_metadata.py +17 -11
  35. ads/model/service/oci_datascience_model.py +12 -14
  36. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  37. ads/opctl/cli.py +4 -5
  38. ads/opctl/cmds.py +28 -32
  39. ads/opctl/config/merger.py +8 -11
  40. ads/opctl/config/resolver.py +25 -30
  41. ads/opctl/operator/cli.py +9 -9
  42. ads/opctl/operator/common/backend_factory.py +56 -60
  43. ads/opctl/operator/common/const.py +5 -5
  44. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  45. ads/opctl/operator/lowcode/common/transformations.py +38 -3
  46. ads/opctl/operator/lowcode/common/utils.py +11 -1
  47. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  48. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  49. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  50. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  51. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  52. ads/opctl/operator/lowcode/forecast/schema.yaml +63 -0
  53. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  54. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  55. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  56. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  57. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  58. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  59. ads/opctl/operator/runtime/runtime.py +4 -6
  60. ads/pipeline/ads_pipeline_run.py +13 -25
  61. ads/pipeline/visualizer/graph_renderer.py +3 -4
  62. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/METADATA +4 -2
  63. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/RECORD +66 -59
  64. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/LICENSE.txt +0 -0
  65. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/WHEEL +0 -0
  66. {oracle_ads-2.12.10rc0.dist-info → oracle_ads-2.13.0.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  """
@@ -12,11 +11,6 @@ The factory validates the backend type and runtime type before creating the back
12
11
  from typing import Dict, List, Tuple, Union
13
12
 
14
13
  import yaml
15
- from ads.opctl.operator.common.utils import print_traceback
16
-
17
- from ads.opctl.backend.marketplace.local_marketplace import (
18
- LocalMarketplaceOperatorBackend,
19
- )
20
14
 
21
15
  from ads.opctl import logger
22
16
  from ads.opctl.backend.ads_dataflow import DataFlowOperatorBackend
@@ -25,6 +19,9 @@ from ads.opctl.backend.base import Backend
25
19
  from ads.opctl.backend.local import (
26
20
  LocalOperatorBackend,
27
21
  )
22
+ from ads.opctl.backend.marketplace.local_marketplace import (
23
+ LocalMarketplaceOperatorBackend,
24
+ )
28
25
  from ads.opctl.config.base import ConfigProcessor
29
26
  from ads.opctl.config.merger import ConfigMerger
30
27
  from ads.opctl.constants import (
@@ -34,9 +31,10 @@ from ads.opctl.constants import (
34
31
  RESOURCE_TYPE,
35
32
  RUNTIME_TYPE,
36
33
  )
37
- from ads.opctl.operator.common.const import PACK_TYPE, OPERATOR_BACKEND_SECTION_NAME
34
+ from ads.opctl.operator.common.const import OPERATOR_BACKEND_SECTION_NAME, PACK_TYPE
38
35
  from ads.opctl.operator.common.dictionary_merger import DictionaryMerger
39
36
  from ads.opctl.operator.common.operator_loader import OperatorInfo, OperatorLoader
37
+ from ads.opctl.operator.common.utils import print_traceback
40
38
 
41
39
 
42
40
  class BackendFactory:
@@ -46,57 +44,57 @@ class BackendFactory:
46
44
  """
47
45
 
48
46
  BACKENDS = (
49
- BACKEND_NAME.JOB.value,
50
- BACKEND_NAME.DATAFLOW.value,
51
- BACKEND_NAME.MARKETPLACE.value,
47
+ BACKEND_NAME.JOB,
48
+ BACKEND_NAME.DATAFLOW,
49
+ BACKEND_NAME.MARKETPLACE,
52
50
  )
53
51
 
54
52
  LOCAL_BACKENDS = (
55
- BACKEND_NAME.OPERATOR_LOCAL.value,
56
- BACKEND_NAME.LOCAL.value,
53
+ BACKEND_NAME.OPERATOR_LOCAL,
54
+ BACKEND_NAME.LOCAL,
57
55
  )
58
56
 
59
57
  BACKEND_RUNTIME_MAP = {
60
- BACKEND_NAME.JOB.value.lower(): {
61
- RUNTIME_TYPE.PYTHON.value.lower(): (
62
- BACKEND_NAME.JOB.value.lower(),
63
- RUNTIME_TYPE.PYTHON.value.lower(),
58
+ BACKEND_NAME.JOB.lower(): {
59
+ RUNTIME_TYPE.PYTHON.lower(): (
60
+ BACKEND_NAME.JOB.lower(),
61
+ RUNTIME_TYPE.PYTHON.lower(),
64
62
  ),
65
- RUNTIME_TYPE.CONTAINER.value.lower(): (
66
- BACKEND_NAME.JOB.value.lower(),
67
- RUNTIME_TYPE.CONTAINER.value.lower(),
63
+ RUNTIME_TYPE.CONTAINER.lower(): (
64
+ BACKEND_NAME.JOB.lower(),
65
+ RUNTIME_TYPE.CONTAINER.lower(),
68
66
  ),
69
67
  },
70
- BACKEND_NAME.DATAFLOW.value.lower(): {
71
- RUNTIME_TYPE.DATAFLOW.value.lower(): (
72
- BACKEND_NAME.DATAFLOW.value.lower(),
73
- RUNTIME_TYPE.DATAFLOW.value.lower(),
68
+ BACKEND_NAME.DATAFLOW.lower(): {
69
+ RUNTIME_TYPE.DATAFLOW.lower(): (
70
+ BACKEND_NAME.DATAFLOW.lower(),
71
+ RUNTIME_TYPE.DATAFLOW.lower(),
74
72
  )
75
73
  },
76
- BACKEND_NAME.OPERATOR_LOCAL.value.lower(): {
77
- RUNTIME_TYPE.PYTHON.value.lower(): (
78
- BACKEND_NAME.OPERATOR_LOCAL.value.lower(),
79
- RUNTIME_TYPE.PYTHON.value.lower(),
74
+ BACKEND_NAME.OPERATOR_LOCAL.lower(): {
75
+ RUNTIME_TYPE.PYTHON.lower(): (
76
+ BACKEND_NAME.OPERATOR_LOCAL.lower(),
77
+ RUNTIME_TYPE.PYTHON.lower(),
80
78
  ),
81
- RUNTIME_TYPE.CONTAINER.value.lower(): (
82
- BACKEND_NAME.OPERATOR_LOCAL.value.lower(),
83
- RUNTIME_TYPE.CONTAINER.value.lower(),
79
+ RUNTIME_TYPE.CONTAINER.lower(): (
80
+ BACKEND_NAME.OPERATOR_LOCAL.lower(),
81
+ RUNTIME_TYPE.CONTAINER.lower(),
84
82
  ),
85
83
  },
86
- BACKEND_NAME.MARKETPLACE.value.lower(): {
87
- RUNTIME_TYPE.PYTHON.value.lower(): (
88
- BACKEND_NAME.MARKETPLACE.value.lower(),
89
- RUNTIME_TYPE.PYTHON.value.lower(),
84
+ BACKEND_NAME.MARKETPLACE.lower(): {
85
+ RUNTIME_TYPE.PYTHON.lower(): (
86
+ BACKEND_NAME.MARKETPLACE.lower(),
87
+ RUNTIME_TYPE.PYTHON.lower(),
90
88
  )
91
89
  },
92
90
  }
93
91
 
94
92
  BACKEND_MAP = {
95
- BACKEND_NAME.JOB.value.lower(): MLJobOperatorBackend,
96
- BACKEND_NAME.DATAFLOW.value.lower(): DataFlowOperatorBackend,
97
- BACKEND_NAME.OPERATOR_LOCAL.value.lower(): LocalOperatorBackend,
98
- BACKEND_NAME.LOCAL.value.lower(): LocalOperatorBackend,
99
- BACKEND_NAME.MARKETPLACE.value.lower(): LocalMarketplaceOperatorBackend,
93
+ BACKEND_NAME.JOB.lower(): MLJobOperatorBackend,
94
+ BACKEND_NAME.DATAFLOW.lower(): DataFlowOperatorBackend,
95
+ BACKEND_NAME.OPERATOR_LOCAL.lower(): LocalOperatorBackend,
96
+ BACKEND_NAME.LOCAL.lower(): LocalOperatorBackend,
97
+ BACKEND_NAME.MARKETPLACE.lower(): LocalMarketplaceOperatorBackend,
100
98
  }
101
99
 
102
100
  @classmethod
@@ -135,15 +133,15 @@ class BackendFactory:
135
133
  # validation
136
134
  if not operator_type:
137
135
  raise RuntimeError(
138
- f"The `type` attribute must be specified in the operator's config."
136
+ "The `type` attribute must be specified in the operator's config."
139
137
  )
140
138
 
141
139
  if not backend and not config.config.get(OPERATOR_BACKEND_SECTION_NAME):
142
140
  logger.info(
143
- f"Backend config is not provided, the {BACKEND_NAME.LOCAL.value} "
141
+ f"Backend config is not provided, the {BACKEND_NAME.LOCAL} "
144
142
  "will be used by default. "
145
143
  )
146
- backend = BACKEND_NAME.LOCAL.value
144
+ backend = BACKEND_NAME.LOCAL
147
145
  elif not backend:
148
146
  backend = config.config.get(OPERATOR_BACKEND_SECTION_NAME)
149
147
 
@@ -164,8 +162,8 @@ class BackendFactory:
164
162
  backend = {"kind": backend_kind}
165
163
 
166
164
  backend_kind = (
167
- BACKEND_NAME.OPERATOR_LOCAL.value
168
- if backend.get("kind").lower() == BACKEND_NAME.LOCAL.value
165
+ BACKEND_NAME.OPERATOR_LOCAL
166
+ if backend.get("kind").lower() == BACKEND_NAME.LOCAL
169
167
  else backend.get("kind").lower()
170
168
  )
171
169
  backend["kind"] = backend_kind
@@ -174,11 +172,11 @@ class BackendFactory:
174
172
  # This is necessary, because Jobs and DataFlow have similar kind,
175
173
  # The only difference would be in the infrastructure kind.
176
174
  # This is a temporary solution, the logic needs to be placed in the ConfigMerger instead.
177
- if backend_kind == BACKEND_NAME.JOB.value:
175
+ if backend_kind == BACKEND_NAME.JOB:
178
176
  if (backend.get("spec", {}) or {}).get("infrastructure", {}).get(
179
177
  "type", ""
180
- ).lower() == BACKEND_NAME.DATAFLOW.value:
181
- backend_kind = BACKEND_NAME.DATAFLOW.value
178
+ ).lower() == BACKEND_NAME.DATAFLOW:
179
+ backend_kind = BACKEND_NAME.DATAFLOW
182
180
 
183
181
  runtime_type = runtime_type or (
184
182
  backend.get("type")
@@ -247,17 +245,17 @@ class BackendFactory:
247
245
  If the backend type is not supported.
248
246
  """
249
247
  supported_backends = supported_backends or (cls.BACKENDS + cls.LOCAL_BACKENDS)
250
- backend = (backend or BACKEND_NAME.OPERATOR_LOCAL.value).lower()
248
+ backend = (backend or BACKEND_NAME.OPERATOR_LOCAL).lower()
251
249
  backend_kind, runtime_type = backend, None
252
250
 
253
- if backend.lower() != BACKEND_NAME.OPERATOR_LOCAL.value and "." in backend:
251
+ if backend.lower() != BACKEND_NAME.OPERATOR_LOCAL and "." in backend:
254
252
  backend_kind, runtime_type = backend.split(".")
255
253
  else:
256
254
  backend_kind = backend
257
255
 
258
256
  backend_kind = (
259
- BACKEND_NAME.OPERATOR_LOCAL.value
260
- if backend_kind == BACKEND_NAME.LOCAL.value
257
+ BACKEND_NAME.OPERATOR_LOCAL
258
+ if backend_kind == BACKEND_NAME.LOCAL
261
259
  else backend_kind
262
260
  )
263
261
 
@@ -357,7 +355,7 @@ class BackendFactory:
357
355
 
358
356
  # generate supported backend specifications templates YAML
359
357
  RUNTIME_TYPE_MAP = {
360
- RESOURCE_TYPE.JOB.value: [
358
+ RESOURCE_TYPE.JOB: [
361
359
  {
362
360
  RUNTIME_TYPE.PYTHON: {
363
361
  "conda_slug": operator_info.conda
@@ -373,7 +371,7 @@ class BackendFactory:
373
371
  }
374
372
  },
375
373
  ],
376
- RESOURCE_TYPE.DATAFLOW.value: [
374
+ RESOURCE_TYPE.DATAFLOW: [
377
375
  {
378
376
  RUNTIME_TYPE.DATAFLOW: {
379
377
  "conda_slug": operator_info.conda_prefix,
@@ -381,7 +379,7 @@ class BackendFactory:
381
379
  }
382
380
  }
383
381
  ],
384
- BACKEND_NAME.OPERATOR_LOCAL.value: [
382
+ BACKEND_NAME.OPERATOR_LOCAL: [
385
383
  {
386
384
  RUNTIME_TYPE.CONTAINER: {
387
385
  "kind": "operator",
@@ -397,7 +395,7 @@ class BackendFactory:
397
395
  }
398
396
  },
399
397
  ],
400
- BACKEND_NAME.MARKETPLACE.value: [
398
+ BACKEND_NAME.MARKETPLACE: [
401
399
  {
402
400
  RUNTIME_TYPE.PYTHON: {
403
401
  "kind": "marketplace",
@@ -445,11 +443,9 @@ class BackendFactory:
445
443
  )
446
444
 
447
445
  # generate YAML specification template
448
- result[
449
- (resource_type.lower(), runtime_type.value.lower())
450
- ] = yaml.load(
446
+ result[(resource_type.lower(), runtime_type.lower())] = yaml.load(
451
447
  _BackendFactory(p.config).backend.init(
452
- runtime_type=runtime_type.value,
448
+ runtime_type=runtime_type,
453
449
  **{**kwargs, **runtime_kwargs},
454
450
  ),
455
451
  Loader=yaml.FullLoader,
@@ -1,10 +1,9 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
8
7
 
9
8
  # Env variable representing the operator input arguments.
10
9
  # This variable is used when operator run on the OCI resources.
@@ -17,11 +16,12 @@ OPERATOR_BASE_DOCKER_GPU_FILE = "Dockerfile.gpu"
17
16
 
18
17
  OPERATOR_BACKEND_SECTION_NAME = "backend"
19
18
 
20
- class PACK_TYPE(str, metaclass=ExtendedEnumMeta):
19
+
20
+ class PACK_TYPE(ExtendedEnum):
21
21
  SERVICE = "service"
22
22
  CUSTOM = "published"
23
23
 
24
24
 
25
- class ARCH_TYPE(str, metaclass=ExtendedEnumMeta):
25
+ class ARCH_TYPE(ExtendedEnum):
26
26
  CPU = "cpu"
27
27
  GPU = "gpu"
@@ -1,14 +1,13 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ads.common.extended_enum import ExtendedEnumMeta
6
+ from ads.common.extended_enum import ExtendedEnum
8
7
  from ads.opctl.operator.lowcode.common.const import DataColumns
9
8
 
10
9
 
11
- class SupportedModels(str, metaclass=ExtendedEnumMeta):
10
+ class SupportedModels(ExtendedEnum):
12
11
  """Supported anomaly models."""
13
12
 
14
13
  AutoTS = "autots"
@@ -38,7 +37,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
38
37
  BOCPD = "bocpd"
39
38
 
40
39
 
41
- class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
40
+ class NonTimeADSupportedModels(ExtendedEnum):
42
41
  """Supported non time-based anomaly detection models."""
43
42
 
44
43
  OneClassSVM = "oneclasssvm"
@@ -48,7 +47,7 @@ class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
48
47
  # DBScan = "dbscan"
49
48
 
50
49
 
51
- class TODSSubModels(str, metaclass=ExtendedEnumMeta):
50
+ class TODSSubModels(ExtendedEnum):
52
51
  """Supported TODS sub models."""
53
52
 
54
53
  OCSVM = "ocsvm"
@@ -78,7 +77,7 @@ TODS_MODEL_MAP = {
78
77
  }
79
78
 
80
79
 
81
- class MerlionADModels(str, metaclass=ExtendedEnumMeta):
80
+ class MerlionADModels(ExtendedEnum):
82
81
  """Supported Merlion AD sub models."""
83
82
 
84
83
  # point anomaly
@@ -126,7 +125,7 @@ MERLIONAD_MODEL_MAP = {
126
125
  }
127
126
 
128
127
 
129
- class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
128
+ class SupportedMetrics(ExtendedEnum):
130
129
  UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
131
130
  UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
132
131
  UNSUPERVISED_N1_EXPERTS = "unsupervised_n-1_experts"
@@ -158,7 +157,7 @@ class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
158
157
  ELAPSED_TIME = "Elapsed Time"
159
158
 
160
159
 
161
- class OutputColumns(str, metaclass=ExtendedEnumMeta):
160
+ class OutputColumns(ExtendedEnum):
162
161
  ANOMALY_COL = "anomaly"
163
162
  SCORE_COL = "score"
164
163
  Series = DataColumns.Series
@@ -15,6 +15,7 @@ from ads.opctl.operator.lowcode.common.errors import (
15
15
  InvalidParameterError,
16
16
  )
17
17
  from ads.opctl.operator.lowcode.common.utils import merge_category_columns
18
+ from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
18
19
 
19
20
 
20
21
  class Transformations(ABC):
@@ -34,6 +35,7 @@ class Transformations(ABC):
34
35
  self.dataset_info = dataset_info
35
36
  self.target_category_columns = dataset_info.target_category_columns
36
37
  self.target_column_name = dataset_info.target_column
38
+ self.raw_column_names = None
37
39
  self.dt_column_name = (
38
40
  dataset_info.datetime_column.name if dataset_info.datetime_column else None
39
41
  )
@@ -60,7 +62,8 @@ class Transformations(ABC):
60
62
 
61
63
  """
62
64
  clean_df = self._remove_trailing_whitespace(data)
63
- # clean_df = self._normalize_column_names(clean_df)
65
+ if isinstance(self.dataset_info, ForecastOperatorSpec):
66
+ clean_df = self._clean_column_names(clean_df)
64
67
  if self.name == "historical_data":
65
68
  self._check_historical_dataset(clean_df)
66
69
  clean_df = self._set_series_id_column(clean_df)
@@ -98,8 +101,36 @@ class Transformations(ABC):
98
101
  def _remove_trailing_whitespace(self, df):
99
102
  return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
100
103
 
101
- # def _normalize_column_names(self, df):
102
- # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
104
+ def _clean_column_names(self, df):
105
+ """
106
+ Remove all whitespaces from column names in a DataFrame and store the original names.
107
+
108
+ Parameters:
109
+ df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
110
+
111
+ Returns:
112
+ pd.DataFrame: The DataFrame with cleaned column names.
113
+ """
114
+
115
+ self.raw_column_names = {
116
+ col: col.replace(" ", "") for col in df.columns if " " in col
117
+ }
118
+ df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
119
+
120
+ if self.target_column_name:
121
+ self.target_column_name = self.raw_column_names.get(
122
+ self.target_column_name, self.target_column_name
123
+ )
124
+ self.dt_column_name = self.raw_column_names.get(
125
+ self.dt_column_name, self.dt_column_name
126
+ )
127
+
128
+ if self.target_category_columns:
129
+ self.target_category_columns = [
130
+ self.raw_column_names.get(col, col)
131
+ for col in self.target_category_columns
132
+ ]
133
+ return df
103
134
 
104
135
  def _set_series_id_column(self, df):
105
136
  self._target_category_columns_map = {}
@@ -233,6 +264,10 @@ class Transformations(ABC):
233
264
  expected_names = [self.target_column_name, self.dt_column_name] + (
234
265
  self.target_category_columns if self.target_category_columns else []
235
266
  )
267
+
268
+ if self.raw_column_names:
269
+ expected_names.extend(list(self.raw_column_names.values()))
270
+
236
271
  if set(df.columns) != set(expected_names):
237
272
  raise DataMismatchError(
238
273
  f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"
@@ -12,6 +12,7 @@ from typing import List, Union
12
12
 
13
13
  import fsspec
14
14
  import oracledb
15
+ import json
15
16
  import pandas as pd
16
17
 
17
18
  from ads.common.object_storage_details import ObjectStorageDetails
@@ -125,7 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
125
126
  return data
126
127
 
127
128
 
128
- def write_data(data, filename, format, storage_options, index=False, **kwargs):
129
+ def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
129
130
  disable_print()
130
131
  if not format:
131
132
  _, format = os.path.splitext(filename)
@@ -141,6 +142,15 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
141
142
  )
142
143
 
143
144
 
145
+ def write_simple_json(data, path):
146
+ if ObjectStorageDetails.is_oci_path(path):
147
+ storage_options = default_signer()
148
+ else:
149
+ storage_options = {}
150
+ with fsspec.open(path, mode="w", **storage_options) as f:
151
+ json.dump(data, f, indent=4)
152
+
153
+
144
154
  def merge_category_columns(data, target_category_columns):
145
155
  result = data.apply(
146
156
  lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
@@ -1,71 +1,66 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import ast
8
7
  import base64
9
- from typing import Optional, List, Dict
8
+ from typing import TYPE_CHECKING, Dict, List, Optional
10
9
 
11
10
  import oci
12
11
  import requests
13
- from typing import TYPE_CHECKING
14
12
 
15
13
  try:
16
14
  from kubernetes.client import (
17
- V1ServiceStatus,
18
- V1Service,
19
- V1LoadBalancerStatus,
20
15
  V1LoadBalancerIngress,
16
+ V1LoadBalancerStatus,
17
+ V1Service,
18
+ V1ServiceStatus,
21
19
  )
22
20
  except ImportError:
23
21
  if TYPE_CHECKING:
24
22
  from kubernetes.client import (
25
- V1ServiceStatus,
26
- V1Service,
27
- V1LoadBalancerStatus,
28
23
  V1LoadBalancerIngress,
24
+ V1LoadBalancerStatus,
25
+ V1Service,
26
+ V1ServiceStatus,
29
27
  )
30
28
 
31
- from oci.resource_manager.models import StackSummary, AssociatedResourceSummary
32
-
33
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
34
- APIGatewayConfig,
35
- )
29
+ import click
30
+ from oci.resource_manager.models import AssociatedResourceSummary, StackSummary
36
31
 
37
- from ads.common.oci_client import OCIClientFactory
38
- from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
39
- LISTING_ID,
40
- APIGW_STACK_NAME,
41
- STACK_URL,
42
- NLB_RULES_ADDRESS,
43
- NODES_RULES_ADDRESS,
44
- )
45
32
  from ads import logger
46
- import click
33
+ from ads.common import auth as authutil
34
+ from ads.common.oci_client import OCIClientFactory
47
35
  from ads.opctl import logger
48
-
49
36
  from ads.opctl.backend.marketplace.marketplace_utils import (
50
37
  Color,
51
38
  print_heading,
52
39
  print_ticker,
53
40
  )
54
- from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
- MySqlConfig,
41
+ from ads.opctl.operator.lowcode.feature_store_marketplace.const import (
42
+ APIGW_STACK_NAME,
43
+ LISTING_ID,
44
+ NLB_RULES_ADDRESS,
45
+ NODES_RULES_ADDRESS,
46
+ STACK_URL,
47
+ )
48
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
49
+ APIGatewayConfig,
56
50
  )
57
-
58
51
  from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
59
52
  DBConfig,
60
53
  )
61
- from ads.common import auth as authutil
54
+ from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
55
+ MySqlConfig,
56
+ )
62
57
 
63
58
 
64
59
  def get_db_details() -> DBConfig:
65
60
  jdbc_url = "jdbc:mysql://{}/{}?createDatabaseIfNotExist=true"
66
61
  mysql_db_config = MySqlConfig()
67
62
  print_heading(
68
- f"MySQL database configuration",
63
+ "MySQL database configuration",
69
64
  colors=[Color.BOLD, Color.BLUE],
70
65
  prefix_newline_count=2,
71
66
  )
@@ -76,12 +71,12 @@ def get_db_details() -> DBConfig:
76
71
  "Is password provided as plain-text or via a Vault secret?\n"
77
72
  "(https://docs.oracle.com/en-us/iaas/Content/KeyManagement/Concepts/keyoverview.htm)",
78
73
  type=click.Choice(MySqlConfig.MySQLAuthType.values()),
79
- default=MySqlConfig.MySQLAuthType.BASIC.value,
74
+ default=MySqlConfig.MySQLAuthType.BASIC,
80
75
  )
81
76
  )
82
77
  if mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.BASIC:
83
78
  basic_auth_config = MySqlConfig.BasicConfig()
84
- basic_auth_config.password = click.prompt(f"Password", hide_input=True)
79
+ basic_auth_config.password = click.prompt("Password", hide_input=True)
85
80
  mysql_db_config.basic_config = basic_auth_config
86
81
 
87
82
  elif mysql_db_config.auth_type == MySqlConfig.MySQLAuthType.VAULT:
@@ -176,12 +171,12 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
176
171
  ).data
177
172
 
178
173
  if len(stacks) >= 1:
179
- print(f"Auto-detected feature store stack(s) in tenancy:")
174
+ print("Auto-detected feature store stack(s) in tenancy:")
180
175
  for stack in stacks:
181
176
  _print_stack_detail(stack)
182
177
  choices = {"1": "new", "2": "existing"}
183
178
  stack_provision_method = click.prompt(
184
- f"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
179
+ "Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
185
180
  type=click.Choice(list(choices.keys())),
186
181
  show_choices=False,
187
182
  )
@@ -240,20 +235,20 @@ def get_api_gw_details(compartment_id: str) -> APIGatewayConfig:
240
235
 
241
236
 
242
237
  def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig):
243
- status: "V1ServiceStatus" = service.status
244
- lb_status: "V1LoadBalancerStatus" = status.load_balancer
245
- lb_ingress: "V1LoadBalancerIngress" = lb_status.ingress[0]
238
+ status: V1ServiceStatus = service.status
239
+ lb_status: V1LoadBalancerStatus = status.load_balancer
240
+ lb_ingress: V1LoadBalancerIngress = lb_status.ingress[0]
246
241
  resource_client = OCIClientFactory(**authutil.default_signer()).create_client(
247
242
  oci.resource_search.ResourceSearchClient
248
243
  )
249
244
  search_details = oci.resource_search.models.FreeTextSearchDetails()
250
245
  search_details.matching_context_type = "NONE"
251
246
  search_details.text = lb_ingress.ip
252
- resources: List[
253
- oci.resource_search.models.ResourceSummary
254
- ] = resource_client.search_resources(
255
- search_details, tenant_id=apigw_config.root_compartment_id
256
- ).data.items
247
+ resources: List[oci.resource_search.models.ResourceSummary] = (
248
+ resource_client.search_resources(
249
+ search_details, tenant_id=apigw_config.root_compartment_id
250
+ ).data.items
251
+ )
257
252
  private_ips = list(filter(lambda obj: obj.resource_type == "PrivateIp", resources))
258
253
  if len(private_ips) != 1:
259
254
  return click.prompt(
@@ -264,12 +259,12 @@ def get_nlb_id_from_service(service: "V1Service", apigw_config: APIGatewayConfig
264
259
  nlb_client = OCIClientFactory(**authutil.default_signer()).create_client(
265
260
  oci.network_load_balancer.NetworkLoadBalancerClient
266
261
  )
267
- nlbs: List[
268
- oci.network_load_balancer.models.NetworkLoadBalancerSummary
269
- ] = nlb_client.list_network_load_balancers(
270
- compartment_id=nlb_private_ip.compartment_id,
271
- display_name=nlb_private_ip.display_name,
272
- ).data.items
262
+ nlbs: List[oci.network_load_balancer.models.NetworkLoadBalancerSummary] = (
263
+ nlb_client.list_network_load_balancers(
264
+ compartment_id=nlb_private_ip.compartment_id,
265
+ display_name=nlb_private_ip.display_name,
266
+ ).data.items
267
+ )
273
268
  if len(nlbs) != 1:
274
269
  return click.prompt(
275
270
  f"Please enter OCID of load balancer associated with ip: {lb_ingress.ip}"
@@ -17,6 +17,7 @@ from ads.opctl.operator.common.utils import _parse_input_args
17
17
 
18
18
  from .operator_config import ForecastOperatorConfig
19
19
  from .model.forecast_datasets import ForecastDatasets
20
+ from .whatifserve import ModelDeploymentManager
20
21
 
21
22
 
22
23
  def operate(operator_config: ForecastOperatorConfig) -> None:
@@ -27,6 +28,15 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
27
28
  ForecastOperatorModelFactory.get_model(
28
29
  operator_config, datasets
29
30
  ).generate_report()
31
+ # saving to model catalog
32
+ spec = operator_config.spec
33
+ if spec.what_if_analysis and datasets.additional_data:
34
+ mdm = ModelDeploymentManager(spec, datasets.additional_data)
35
+ mdm.save_to_catalog()
36
+ if spec.what_if_analysis.model_deployment:
37
+ mdm.create_deployment()
38
+ mdm.save_deployment_info()
39
+
30
40
 
31
41
  def verify(spec: Dict, **kwargs: Dict) -> bool:
32
42
  """Verifies the forecasting operator config."""