oracle-ads 2.12.2__py3-none-any.whl → 2.12.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. ads/aqua/common/enums.py +9 -0
  2. ads/aqua/common/utils.py +83 -6
  3. ads/aqua/config/config.py +0 -16
  4. ads/aqua/constants.py +2 -0
  5. ads/aqua/evaluation/entities.py +45 -50
  6. ads/aqua/evaluation/evaluation.py +26 -61
  7. ads/aqua/extension/deployment_handler.py +35 -0
  8. ads/aqua/extension/errors.py +1 -0
  9. ads/aqua/extension/evaluation_handler.py +0 -5
  10. ads/aqua/extension/finetune_handler.py +1 -2
  11. ads/aqua/extension/model_handler.py +38 -1
  12. ads/aqua/extension/ui_handler.py +22 -3
  13. ads/aqua/finetuning/entities.py +5 -4
  14. ads/aqua/finetuning/finetuning.py +13 -8
  15. ads/aqua/model/constants.py +1 -0
  16. ads/aqua/model/entities.py +2 -0
  17. ads/aqua/model/model.py +350 -140
  18. ads/aqua/modeldeployment/deployment.py +118 -62
  19. ads/aqua/modeldeployment/entities.py +10 -2
  20. ads/aqua/ui.py +29 -16
  21. ads/config.py +3 -8
  22. ads/dataset/dataset.py +2 -2
  23. ads/dataset/factory.py +1 -1
  24. ads/llm/deploy.py +6 -0
  25. ads/llm/guardrails/base.py +0 -1
  26. ads/llm/langchain/plugins/chat_models/oci_data_science.py +118 -41
  27. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +18 -14
  28. ads/llm/templates/score_chain.jinja2 +0 -1
  29. ads/model/datascience_model.py +519 -16
  30. ads/model/deployment/model_deployment.py +13 -0
  31. ads/model/deployment/model_deployment_infrastructure.py +34 -0
  32. ads/model/generic_model.py +10 -0
  33. ads/model/model_properties.py +1 -0
  34. ads/model/service/oci_datascience_model.py +28 -0
  35. ads/opctl/operator/lowcode/anomaly/const.py +66 -1
  36. ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py +161 -0
  37. ads/opctl/operator/lowcode/anomaly/model/autots.py +30 -15
  38. ads/opctl/operator/lowcode/anomaly/model/factory.py +15 -3
  39. ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py +1 -1
  40. ads/opctl/operator/lowcode/anomaly/schema.yaml +10 -0
  41. ads/opctl/operator/lowcode/anomaly/utils.py +3 -0
  42. ads/opctl/operator/lowcode/forecast/cmd.py +3 -9
  43. ads/opctl/operator/lowcode/forecast/const.py +3 -4
  44. ads/opctl/operator/lowcode/forecast/model/factory.py +13 -12
  45. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +4 -3
  46. ads/opctl/operator/lowcode/forecast/operator_config.py +17 -10
  47. ads/opctl/operator/lowcode/forecast/schema.yaml +2 -2
  48. ads/oracledb/oracle_db.py +32 -20
  49. ads/secrets/adb.py +28 -6
  50. {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/METADATA +3 -2
  51. {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/RECORD +54 -55
  52. {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/WHEEL +1 -1
  53. ads/aqua/config/deployment_config_defaults.json +0 -38
  54. ads/aqua/config/resource_limit_names.json +0 -9
  55. {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/LICENSE.txt +0 -0
  56. {oracle_ads-2.12.2.dist-info → oracle_ads-2.12.4.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,19 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ..const import SupportedModels, AUTO_SELECT
6
+ from ..const import AUTO_SELECT, SupportedModels
7
+ from ..model_evaluator import ModelEvaluator
8
8
  from ..operator_config import ForecastOperatorConfig
9
9
  from .arima import ArimaOperatorModel
10
10
  from .automlx import AutoMLXOperatorModel
11
11
  from .autots import AutoTSOperatorModel
12
12
  from .base_model import ForecastOperatorBaseModel
13
+ from .forecast_datasets import ForecastDatasets
13
14
  from .neuralprophet import NeuralProphetOperatorModel
14
15
  from .prophet import ProphetOperatorModel
15
- from .forecast_datasets import ForecastDatasets
16
- from .ml_forecast import MLForecastOperatorModel
17
- from ..model_evaluator import ModelEvaluator
16
+
18
17
 
19
18
  class UnSupportedModelError(Exception):
20
19
  def __init__(self, model_type: str):
@@ -33,9 +32,9 @@ class ForecastOperatorModelFactory:
33
32
  SupportedModels.Prophet: ProphetOperatorModel,
34
33
  SupportedModels.Arima: ArimaOperatorModel,
35
34
  SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
36
- SupportedModels.LGBForecast: MLForecastOperatorModel,
35
+ # SupportedModels.LGBForecast: MLForecastOperatorModel,
37
36
  SupportedModels.AutoMLX: AutoMLXOperatorModel,
38
- SupportedModels.AutoTS: AutoTSOperatorModel
37
+ SupportedModels.AutoTS: AutoTSOperatorModel,
39
38
  }
40
39
 
41
40
  @classmethod
@@ -65,14 +64,14 @@ class ForecastOperatorModelFactory:
65
64
  model_type = operator_config.spec.model
66
65
  if model_type == AUTO_SELECT:
67
66
  model_type = cls.auto_select_model(datasets, operator_config)
68
- operator_config.spec.model_kwargs = dict()
67
+ operator_config.spec.model_kwargs = {}
69
68
  if model_type not in cls._MAP:
70
69
  raise UnSupportedModelError(model_type)
71
70
  return cls._MAP[model_type](config=operator_config, datasets=datasets)
72
71
 
73
72
  @classmethod
74
73
  def auto_select_model(
75
- cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
74
+ cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
76
75
  ) -> str:
77
76
  """
78
77
  Selects AutoMLX or Arima model based on column count.
@@ -90,8 +89,10 @@ class ForecastOperatorModelFactory:
90
89
  str
91
90
  The type of the model.
92
91
  """
93
- all_models = operator_config.spec.model_kwargs.get("model_list", cls._MAP.keys())
92
+ all_models = operator_config.spec.model_kwargs.get(
93
+ "model_list", cls._MAP.keys()
94
+ )
94
95
  num_backtests = operator_config.spec.model_kwargs.get("num_backtests", 5)
95
96
  sample_ratio = operator_config.spec.model_kwargs.get("sample_ratio", 0.20)
96
97
  model_evaluator = ModelEvaluator(all_models, num_backtests, sample_ratio)
97
- return model_evaluator.find_best_model(datasets, operator_config)
98
+ return model_evaluator.find_best_model(datasets, operator_config)
@@ -2,7 +2,8 @@
2
2
 
3
3
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
- import numpy as np
5
+ import traceback
6
+
6
7
  import pandas as pd
7
8
 
8
9
  from ads.common.decorator import runtime_dependency
@@ -164,7 +165,7 @@ class MLForecastOperatorModel(ForecastOperatorBaseModel):
164
165
  self.errors_dict[self.spec.model] = {
165
166
  "model_name": self.spec.model,
166
167
  "error": str(e),
167
- "error_trace": traceback.format_exc()
168
+ "error_trace": traceback.format_exc(),
168
169
  }
169
170
  logger.warn(f"Encountered Error: {e}. Skipping.")
170
171
  logger.warn(traceback.format_exc())
@@ -173,7 +174,7 @@ class MLForecastOperatorModel(ForecastOperatorBaseModel):
173
174
  def _build_model(self) -> pd.DataFrame:
174
175
  data_train = self.datasets.get_all_data_long(include_horizon=False)
175
176
  data_test = self.datasets.get_all_data_long_forecast_horizon()
176
- self.models = dict()
177
+ self.models = {}
177
178
  model_kwargs = self.set_kwargs()
178
179
  self.forecast_output = ForecastOutput(
179
180
  confidence_interval_width=self.spec.confidence_interval_width,
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import os
@@ -9,13 +8,17 @@ from dataclasses import dataclass, field
9
8
  from typing import Dict, List
10
9
 
11
10
  from ads.common.serializer import DataClassSerializable
11
+ from ads.opctl.operator.common.operator_config import (
12
+ InputData,
13
+ OperatorConfig,
14
+ OutputDirectory,
15
+ )
12
16
  from ads.opctl.operator.common.utils import _load_yaml_from_uri
13
- from ads.opctl.operator.common.operator_config import OperatorConfig, OutputDirectory, InputData
14
-
15
- from .const import SupportedMetrics, SpeedAccuracyMode
16
- from .const import SupportedModels
17
17
  from ads.opctl.operator.lowcode.common.utils import find_output_dirname
18
18
 
19
+ from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
20
+
21
+
19
22
  @dataclass(repr=True)
20
23
  class TestData(InputData):
21
24
  """Class representing operator specification test data details."""
@@ -90,13 +93,17 @@ class ForecastOperatorSpec(DataClassSerializable):
90
93
 
91
94
  def __post_init__(self):
92
95
  """Adjusts the specification details."""
93
- self.output_directory = self.output_directory or OutputDirectory(url=find_output_dirname(self.output_directory))
96
+ self.output_directory = self.output_directory or OutputDirectory(
97
+ url=find_output_dirname(self.output_directory)
98
+ )
94
99
  self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
95
- self.model = self.model or SupportedModels.Auto
100
+ self.model = self.model or SupportedModels.Prophet
96
101
  self.confidence_interval_width = self.confidence_interval_width or 0.80
97
102
  self.report_filename = self.report_filename or "report.html"
98
103
  self.preprocessing = (
99
- self.preprocessing if self.preprocessing is not None else DataPreprocessor(enabled=True)
104
+ self.preprocessing
105
+ if self.preprocessing is not None
106
+ else DataPreprocessor(enabled=True)
100
107
  )
101
108
  # For Report Generation. When user doesn't specify defaults to True
102
109
  self.generate_report = (
@@ -138,7 +145,7 @@ class ForecastOperatorSpec(DataClassSerializable):
138
145
  )
139
146
  self.target_column = self.target_column or "Sales"
140
147
  self.errors_dict_filename = "errors.json"
141
- self.model_kwargs = self.model_kwargs or dict()
148
+ self.model_kwargs = self.model_kwargs or {}
142
149
 
143
150
 
144
151
  @dataclass(repr=True)
@@ -374,12 +374,12 @@ spec:
374
374
  model:
375
375
  type: string
376
376
  required: false
377
- default: auto-select
377
+ default: prophet
378
378
  allowed:
379
379
  - prophet
380
380
  - arima
381
381
  - neuralprophet
382
- - lgbforecast
382
+ # - lgbforecast
383
383
  - automlx
384
384
  - autots
385
385
  - auto-select
ads/oracledb/oracle_db.py CHANGED
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  """
@@ -17,19 +16,20 @@ If user uses DSN string copied from OCI console with OCI database setup for TLS
17
16
  Note: We need to account for cx_Oracle though oracledb can operate in thick mode. The end user may be is using one of the old conda packs or an environment where cx_Oracle is the only available driver.
18
17
  """
19
18
 
20
- from ads.common.utils import ORACLE_DEFAULT_PORT
21
-
22
19
  import logging
23
- import numpy as np
24
20
  import os
25
- import pandas as pd
26
21
  import tempfile
27
- from time import time
28
- from typing import Dict, Optional, List, Union, Iterator
29
22
  import zipfile
23
+ from time import time
24
+ from typing import Dict, Iterator, List, Optional, Union
25
+
26
+ import numpy as np
27
+ import pandas as pd
28
+
30
29
  from ads.common.decorator.runtime_dependency import (
31
30
  OptionalDependency,
32
31
  )
32
+ from ads.common.utils import ORACLE_DEFAULT_PORT
33
33
 
34
34
  logger = logging.getLogger("ads.oracle_connector")
35
35
  CX_ORACLE = "cx_Oracle"
@@ -40,17 +40,17 @@ try:
40
40
  import oracledb as oracle_driver # Both the driver share same signature for the APIs that we are using.
41
41
 
42
42
  PYTHON_DRIVER_NAME = PYTHON_ORACLEDB
43
- except:
43
+ except ModuleNotFoundError:
44
44
  logger.info("oracledb package not found. Trying to load cx_Oracle")
45
45
  try:
46
46
  import cx_Oracle as oracle_driver
47
47
 
48
48
  PYTHON_DRIVER_NAME = CX_ORACLE
49
- except ModuleNotFoundError:
49
+ except ModuleNotFoundError as err2:
50
50
  raise ModuleNotFoundError(
51
51
  f"Neither `oracledb` nor `cx_Oracle` module was not found. Please run "
52
52
  f"`pip install {OptionalDependency.DATA}`."
53
- )
53
+ ) from err2
54
54
 
55
55
 
56
56
  class OracleRDBMSConnection(oracle_driver.Connection):
@@ -75,7 +75,7 @@ class OracleRDBMSConnection(oracle_driver.Connection):
75
75
  logger.info(
76
76
  "Running oracledb driver in thick mode. For mTLS based connection, thick mode is default."
77
77
  )
78
- except:
78
+ except Exception:
79
79
  logger.info(
80
80
  "Could not use thick mode. The driver is running in thin mode. System might prompt for passphrase"
81
81
  )
@@ -154,7 +154,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
154
154
  batch_size=100000,
155
155
  encoding="utf-8",
156
156
  ):
157
-
158
157
  if if_exists not in ["fail", "replace", "append"]:
159
158
  raise ValueError(
160
159
  f"Unknown option `if_exists`={if_exists}. Valid options are 'fail', 'replace', 'append'"
@@ -173,7 +172,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
173
172
  df_orcl.columns = df_orcl.columns.str.replace(r"\W+", "_", regex=True)
174
173
  table_exist = True
175
174
  with self.cursor() as cursor:
176
-
177
175
  if if_exists != "replace":
178
176
  try:
179
177
  cursor.execute(f"SELECT 1 from {table_name} FETCH NEXT 1 ROWS ONLY")
@@ -275,7 +273,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
275
273
  yield lst[i : i + batch_size]
276
274
 
277
275
  for batch in chunks(record_data, batch_size=batch_size):
278
-
279
276
  cursor.executemany(sql, batch, batcherrors=True)
280
277
 
281
278
  for error in cursor.getbatcherrors():
@@ -304,7 +301,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
304
301
  def query(
305
302
  self, sql: str, bind_variables: Optional[Dict], chunksize=None
306
303
  ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
307
-
308
304
  start_time = time()
309
305
 
310
306
  cursor = self.cursor()
@@ -315,10 +311,8 @@ class OracleRDBMSConnection(oracle_driver.Connection):
315
311
  cursor.execute(sql, **bind_variables)
316
312
  columns = [row[0] for row in cursor.description]
317
313
  df = iter(
318
- (
319
- pd.DataFrame(data=rows, columns=columns)
320
- for rows in self._fetch_by_batch(cursor, chunksize)
321
- )
314
+ pd.DataFrame(data=rows, columns=columns)
315
+ for rows in self._fetch_by_batch(cursor, chunksize)
322
316
  )
323
317
 
324
318
  else:
@@ -332,3 +326,21 @@ class OracleRDBMSConnection(oracle_driver.Connection):
332
326
  )
333
327
 
334
328
  return df
329
+
330
+
331
+ def get_adw_connection(vault_secret_id: str) -> "oracledb.Connection":
332
+ """Creates ADW connection from the credentials stored in the vault"""
333
+ import oracledb
334
+
335
+ from ads.secrets.adb import ADBSecretKeeper
336
+
337
+ secret = vault_secret_id
338
+
339
+ logging.getLogger().debug("A secret id was used to retrieve credentials.")
340
+ creds = ADBSecretKeeper.load_secret(secret).to_dict()
341
+ user = creds.pop("user_name", None)
342
+ password = creds.pop("password", None)
343
+ if not user or not password:
344
+ raise ValueError(f"The user or password is missing in {secret}")
345
+ logging.getLogger().debug("Downloaded secrets successfully.")
346
+ return oracledb.connect(user=user, password=password, **creds)
ads/secrets/adb.py CHANGED
@@ -1,17 +1,18 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2021, 2022 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- import ads
8
- from ads.secrets import SecretKeeper, Secret
9
6
  import json
10
7
  import os
11
8
  import tempfile
12
9
  import zipfile
10
+
13
11
  from tqdm.auto import tqdm
14
12
 
13
+ import ads
14
+ from ads.secrets import Secret, SecretKeeper
15
+
15
16
  logger = ads.getLogger("ads.secrets")
16
17
 
17
18
  from dataclasses import dataclass, field
@@ -25,7 +26,7 @@ class ADBSecret(Secret):
25
26
 
26
27
  user_name: str
27
28
  password: str
28
- service_name: str
29
+ service_name: str = field(default=None)
29
30
  wallet_location: str = field(
30
31
  default=None, metadata={"serializable": False}
31
32
  ) # Not saved in vault
@@ -40,6 +41,7 @@ class ADBSecret(Secret):
40
41
  wallet_secret_ids: list = field(
41
42
  repr=False, default_factory=list
42
43
  ) # Not exposed through environment or `to_dict` function
44
+ dsn: str = field(default=None)
43
45
 
44
46
  def __post_init__(self):
45
47
  self.wallet_file_name = (
@@ -76,6 +78,22 @@ class ADBSecretKeeper(SecretKeeper):
76
78
  >>> print(adw_keeper.secret_id) # Prints the secret_id of the stored credentials
77
79
  >>> adw_keeper.export_vault_details("adw_employee_att.json", format="json") # Save the secret id and vault info to a json file
78
80
 
81
+
82
+ >>> # Saving credentials for TLS connection
83
+ >>> from ads.secrets.adw import ADBSecretKeeper
84
+ >>> vault_id = "ocid1.vault.oc1..<unique_ID>"
85
+ >>> kid = "ocid1.ke..<unique_ID>"
86
+
87
+ >>> import ads
88
+ >>> ads.set_auth("resource_principal") # If using resource principal for authentication
89
+ >>> connection_parameters={
90
+ ... "user_name":"admin",
91
+ ... "password":"<your password>",
92
+ ... "dsn":"<dsn string>"
93
+ ... }
94
+ >>> adw_keeper = ADBSecretKeeper(vault_id=vault_id, key_id=kid, **connection_parameters)
95
+ >>> adw_keeper.save("adw_employee", "My DB credentials", freeform_tags={"schema":"emp"})
96
+
79
97
  >>> # Loading credentails
80
98
  >>> import ads
81
99
  >>> ads.set_auth("resource_principal") # If using resource principal for authentication
@@ -133,6 +151,7 @@ class ADBSecretKeeper(SecretKeeper):
133
151
  wallet_dir: str = None,
134
152
  repository_path: str = None,
135
153
  repository_key: str = None,
154
+ dsn: str = None,
136
155
  **kwargs,
137
156
  ):
138
157
  """
@@ -152,6 +171,8 @@ class ADBSecretKeeper(SecretKeeper):
152
171
  Path to credentials repository. For more details refer `ads.database.connection`
153
172
  repository_key: (str, optional). Default None.
154
173
  Configuration key for loading the right configuration from repository. For more details refer `ads.database.connection`
174
+ dsn: (str, optional). Default None.
175
+ dsn string copied from the OCI console for TLS connection
155
176
  kwargs:
156
177
  vault_id: str. OCID of the vault where the secret is stored. Required for saving secret.
157
178
  key_id: str. OCID of the key used for encrypting the secret. Required for saving secret.
@@ -180,6 +201,7 @@ class ADBSecretKeeper(SecretKeeper):
180
201
  password=password,
181
202
  service_name=service_name,
182
203
  wallet_location=wallet_location,
204
+ dsn=dsn,
183
205
  )
184
206
  self.wallet_dir = wallet_dir
185
207
 
@@ -252,7 +274,7 @@ class ADBSecretKeeper(SecretKeeper):
252
274
  logger.debug(f"Setting wallet file to {self.data.wallet_location}")
253
275
  data.wallet_location = self.data.wallet_location
254
276
  elif data.wallet_secret_ids and len(data.wallet_secret_ids) > 0:
255
- logger.debug(f"Secret ids corresponding to the wallet files found.")
277
+ logger.debug("Secret ids corresponding to the wallet files found.")
256
278
  # If the secret ids for wallet files are available in secret, then we
257
279
  # can generate the wallet file.
258
280
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: oracle_ads
3
- Version: 2.12.2
3
+ Version: 2.12.4
4
4
  Summary: Oracle Accelerated Data Science SDK
5
5
  Keywords: Oracle Cloud Infrastructure,OCI,Machine Learning,ML,Artificial Intelligence,AI,Data Science,Cloud,Oracle
6
6
  Author: Oracle Data Science
@@ -40,6 +40,7 @@ Requires-Dist: oracledb ; extra == "anomaly"
40
40
  Requires-Dist: report-creator==1.0.9 ; extra == "anomaly"
41
41
  Requires-Dist: rrcf==0.4.4 ; extra == "anomaly"
42
42
  Requires-Dist: scikit-learn ; extra == "anomaly"
43
+ Requires-Dist: salesforce-merlion[all]==2.0.4 ; extra == "anomaly"
43
44
  Requires-Dist: jupyter_server ; extra == "aqua"
44
45
  Requires-Dist: hdfs[kerberos] ; extra == "bds"
45
46
  Requires-Dist: ibis-framework[impala] ; extra == "bds"