oracle-ads 2.12.7__py3-none-any.whl → 2.12.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ads/aqua/app.py +12 -2
  2. ads/aqua/evaluation/entities.py +6 -0
  3. ads/aqua/evaluation/evaluation.py +25 -3
  4. ads/aqua/extension/deployment_handler.py +8 -4
  5. ads/aqua/extension/model_handler.py +9 -7
  6. ads/aqua/extension/ui_handler.py +13 -1
  7. ads/aqua/finetuning/entities.py +6 -0
  8. ads/aqua/finetuning/finetuning.py +47 -34
  9. ads/aqua/model/entities.py +2 -0
  10. ads/aqua/model/model.py +34 -6
  11. ads/aqua/modeldeployment/deployment.py +28 -10
  12. ads/aqua/modeldeployment/entities.py +7 -4
  13. ads/aqua/ui.py +24 -2
  14. ads/llm/guardrails/base.py +6 -5
  15. ads/llm/langchain/plugins/chat_models/oci_data_science.py +34 -9
  16. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
  17. ads/opctl/operator/common/utils.py +6 -4
  18. ads/opctl/operator/lowcode/anomaly/model/base_model.py +2 -3
  19. ads/opctl/operator/lowcode/anomaly/model/factory.py +2 -2
  20. ads/opctl/operator/lowcode/common/transformations.py +14 -10
  21. ads/opctl/operator/lowcode/common/utils.py +37 -37
  22. ads/opctl/operator/lowcode/forecast/const.py +1 -0
  23. ads/opctl/operator/lowcode/forecast/model/automlx.py +10 -2
  24. ads/opctl/operator/lowcode/forecast/model/base_model.py +10 -15
  25. ads/opctl/operator/lowcode/forecast/model/factory.py +3 -2
  26. ads/opctl/operator/lowcode/forecast/model/prophet.py +4 -1
  27. ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
  28. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  29. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  30. ads/opctl/operator/lowcode/pii/model/factory.py +7 -5
  31. ads/opctl/operator/lowcode/recommender/model/base_model.py +2 -1
  32. ads/opctl/operator/lowcode/recommender/model/factory.py +4 -6
  33. ads/opctl/operator/lowcode/recommender/model/svd.py +5 -5
  34. {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/METADATA +3 -3
  35. {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/RECORD +38 -38
  36. {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/LICENSE.txt +0 -0
  37. {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/WHEEL +0 -0
  38. {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/entry_points.txt +0 -0
ads/aqua/modeldeployment/deployment.py CHANGED
@@ -110,6 +110,8 @@ class AquaDeploymentApp(AquaApp):
110
110
  private_endpoint_id: Optional[str] = None,
111
111
  container_image_uri: Optional[None] = None,
112
112
  cmd_var: List[str] = None,
113
+ freeform_tags: Optional[dict] = None,
114
+ defined_tags: Optional[dict] = None,
113
115
  ) -> "AquaDeployment":
114
116
  """
115
117
  Creates a new Aqua deployment
@@ -163,6 +165,10 @@ class AquaDeploymentApp(AquaApp):
163
165
  Required parameter for BYOC based deployments if this parameter was not set during model registration.
164
166
  cmd_var: List[str]
165
167
  The cmd of model deployment container runtime.
168
+ freeform_tags: dict
169
+ Freeform tags for the model deployment
170
+ defined_tags: dict
171
+ Defined tags for the model deployment
166
172
  Returns
167
173
  -------
168
174
  AquaDeployment
@@ -172,7 +178,11 @@ class AquaDeploymentApp(AquaApp):
172
178
  # TODO validate if the service model has no artifact and if it requires import step before deployment.
173
179
  # Create a model catalog entry in the user compartment
174
180
  aqua_model = AquaModelApp().create(
175
- model_id=model_id, compartment_id=compartment_id, project_id=project_id
181
+ model_id=model_id,
182
+ compartment_id=compartment_id,
183
+ project_id=project_id,
184
+ freeform_tags=freeform_tags,
185
+ defined_tags=defined_tags,
176
186
  )
177
187
 
178
188
  tags = {}
@@ -185,7 +195,7 @@ class AquaDeploymentApp(AquaApp):
185
195
  tags[tag] = aqua_model.freeform_tags[tag]
186
196
 
187
197
  tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
188
- tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})
198
+ tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)})
189
199
 
190
200
  # Set up info to get deployment config
191
201
  config_source_id = model_id
@@ -418,12 +428,14 @@ class AquaDeploymentApp(AquaApp):
418
428
  if cmd_var:
419
429
  container_runtime.with_cmd(cmd_var)
420
430
 
431
+ tags = {**tags, **(freeform_tags or {})}
421
432
  # configure model deployment and deploy model on container runtime
422
433
  deployment = (
423
434
  ModelDeployment()
424
435
  .with_display_name(display_name)
425
436
  .with_description(description)
426
437
  .with_freeform_tags(**tags)
438
+ .with_defined_tags(**(defined_tags or {}))
427
439
  .with_infrastructure(infrastructure)
428
440
  .with_runtime(container_runtime)
429
441
  ).deploy(wait_for_completion=False)
@@ -533,16 +545,22 @@ class AquaDeploymentApp(AquaApp):
533
545
  return results
534
546
 
535
547
  @telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
536
- def delete(self,model_deployment_id:str):
537
- return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data
548
+ def delete(self, model_deployment_id: str):
549
+ return self.ds_client.delete_model_deployment(
550
+ model_deployment_id=model_deployment_id
551
+ ).data
538
552
 
539
- @telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua")
540
- def deactivate(self,model_deployment_id:str):
541
- return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data
553
+ @telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
554
+ def deactivate(self, model_deployment_id: str):
555
+ return self.ds_client.deactivate_model_deployment(
556
+ model_deployment_id=model_deployment_id
557
+ ).data
542
558
 
543
- @telemetry(entry_point="plugin=deployment&action=activate",name="aqua")
544
- def activate(self,model_deployment_id:str):
545
- return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data
559
+ @telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
560
+ def activate(self, model_deployment_id: str):
561
+ return self.ds_client.activate_model_deployment(
562
+ model_deployment_id=model_deployment_id
563
+ ).data
546
564
 
547
565
  @telemetry(entry_point="plugin=deployment&action=get", name="aqua")
548
566
  def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
ads/aqua/modeldeployment/entities.py CHANGED
@@ -98,9 +98,12 @@ class AquaDeployment(DataClassSerializable):
98
98
  ),
99
99
  )
100
100
 
101
- freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT
102
- aqua_service_model_tag = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
103
- aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN)
101
+ tags = {}
102
+ tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
103
+ tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)
104
+
105
+ aqua_service_model_tag = tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
106
+ aqua_model_name = tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN)
104
107
  private_endpoint_id = getattr(
105
108
  instance_configuration, "private_endpoint_id", UNKNOWN
106
109
  )
@@ -125,7 +128,7 @@ class AquaDeployment(DataClassSerializable):
125
128
  ocid=oci_model_deployment.id,
126
129
  region=region,
127
130
  ),
128
- tags=freeform_tags,
131
+ tags=tags,
129
132
  environment_variables=environment_variables,
130
133
  cmd=cmd,
131
134
  )
ads/aqua/ui.py CHANGED
@@ -481,12 +481,12 @@ class AquaUIApp(AquaApp):
481
481
 
482
482
  @telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua")
483
483
  def list_job_shapes(self, **kwargs) -> list:
484
- """Lists all availiable job shapes for the specified compartment.
484
+ """Lists all available job shapes for the specified compartment.
485
485
 
486
486
  Parameters
487
487
  ----------
488
488
  **kwargs
489
- Addtional arguments, such as `compartment_id`,
489
+ Additional arguments, such as `compartment_id`,
490
490
  for `list_job_shapes <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_job_shapes>`_
491
491
 
492
492
  Returns
@@ -500,6 +500,28 @@ class AquaUIApp(AquaApp):
500
500
  ).data
501
501
  return sanitize_response(oci_client=self.ds_client, response=res)
502
502
 
503
+ @telemetry(entry_point="plugin=ui&action=list_model_deployment_shapes", name="aqua")
504
+ def list_model_deployment_shapes(self, **kwargs) -> list:
505
+ """Lists all available shapes for model deployment in the specified compartment.
506
+
507
+ Parameters
508
+ ----------
509
+ **kwargs
510
+ Additional arguments, such as `compartment_id`,
511
+ for `list_model_deployment_shapes <https://docs.oracle.com/en-us/iaas/api/#/en/data-science/20190101/ModelDeploymentShapeSummary/ListModelDeploymentShapes>`_
512
+
513
+ Returns
514
+ -------
515
+ str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`."""
516
+ compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
517
+ logger.info(
518
+ f"Loading model deployment shape summary from compartment: {compartment_id}"
519
+ )
520
+ res = self.ds_client.list_model_deployment_shapes(
521
+ compartment_id=compartment_id, **kwargs
522
+ ).data
523
+ return sanitize_response(oci_client=self.ds_client, response=res)
524
+
503
525
  @telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua")
504
526
  def list_vcn(self, **kwargs) -> list:
505
527
  """Lists the virtual cloud networks (VCNs) in the specified compartment.
ads/llm/guardrails/base.py CHANGED
@@ -1,17 +1,16 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
8
7
  import datetime
9
8
  import functools
10
- import operator
11
9
  import importlib.util
10
+ import operator
12
11
  import sys
12
+ from typing import Any, List, Optional, Union
13
13
 
14
- from typing import Any, List, Dict, Tuple
15
14
  from langchain.schema.prompt import PromptValue
16
15
  from langchain.tools.base import BaseTool, ToolException
17
16
  from pydantic import BaseModel, model_validator
@@ -207,7 +206,9 @@ class Guardrail(BaseTool):
207
206
  return input.to_string()
208
207
  return str(input)
209
208
 
210
- def _to_args_and_kwargs(self, tool_input: Any) -> Tuple[Tuple, Dict]:
209
+ def _to_args_and_kwargs(
210
+ self, tool_input: Union[str, dict], tool_call_id: Optional[str]
211
+ ) -> tuple[tuple, dict]:
211
212
  if isinstance(tool_input, dict):
212
213
  return (), tool_input
213
214
  else:
ads/llm/langchain/plugins/chat_models/oci_data_science.py CHANGED
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
  """Chat model for OCI data science model deployment endpoint."""
7
6
 
@@ -50,6 +49,7 @@ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint i
50
49
  )
51
50
 
52
51
  logger = logging.getLogger(__name__)
52
+ DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
53
53
 
54
54
 
55
55
  def _is_pydantic_class(obj: Any) -> bool:
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
93
93
  Key init args — client params:
94
94
  auth: dict
95
95
  ADS auth dictionary for OCI authentication.
96
+ default_headers: Optional[Dict]
97
+ The headers to be added to the Model Deployment request.
96
98
 
97
99
  Instantiate:
98
100
  .. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
109
111
  "temperature": 0.2,
110
112
  # other model parameters ...
111
113
  },
114
+ default_headers={
115
+ "route": "/v1/chat/completions",
116
+ # other request headers ...
117
+ },
112
118
  )
113
119
 
114
120
  Invocation:
@@ -291,6 +297,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
291
297
  "stream": self.streaming,
292
298
  }
293
299
 
300
+ def _headers(
301
+ self, is_async: Optional[bool] = False, body: Optional[dict] = None
302
+ ) -> Dict:
303
+ """Construct and return the headers for a request.
304
+
305
+ Args:
306
+ is_async (bool, optional): Indicates if the request is asynchronous.
307
+ Defaults to `False`.
308
+ body (optional): The request body to be included in the headers if
309
+ the request is asynchronous.
310
+
311
+ Returns:
312
+ Dict: A dictionary containing the appropriate headers for the request.
313
+ """
314
+ return {
315
+ "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
316
+ **super()._headers(is_async=is_async, body=body),
317
+ }
318
+
294
319
  def _generate(
295
320
  self,
296
321
  messages: List[BaseMessage],
@@ -704,7 +729,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
704
729
 
705
730
  for choice in choices:
706
731
  message = _convert_dict_to_message(choice["message"])
707
- generation_info = dict(finish_reason=choice.get("finish_reason"))
732
+ generation_info = {"finish_reason": choice.get("finish_reason")}
708
733
  if "logprobs" in choice:
709
734
  generation_info["logprobs"] = choice["logprobs"]
710
735
 
@@ -794,7 +819,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
794
819
  """Number of most likely tokens to consider at each step."""
795
820
 
796
821
  min_p: Optional[float] = 0.0
797
- """Float that represents the minimum probability for a token to be considered.
822
+ """Float that represents the minimum probability for a token to be considered.
798
823
  Must be in [0,1]. 0 to disable this."""
799
824
 
800
825
  repetition_penalty: Optional[float] = 1.0
@@ -818,7 +843,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
818
843
  the EOS token is generated."""
819
844
 
820
845
  min_tokens: Optional[int] = 0
821
- """Minimum number of tokens to generate per output sequence before
846
+ """Minimum number of tokens to generate per output sequence before
822
847
  EOS or stop_token_ids can be generated"""
823
848
 
824
849
  stop_token_ids: Optional[List[int]] = None
@@ -836,7 +861,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
836
861
  tool_choice: Optional[str] = None
837
862
  """Whether to use tool calling.
838
863
  Defaults to None, tool calling is disabled.
839
- Tool calling requires model support and the vLLM to be configured
864
+ Tool calling requires model support and the vLLM to be configured
840
865
  with `--tool-call-parser`.
841
866
  Set this to `auto` for the model to make tool calls automatically.
842
867
  Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +981,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
956
981
  """Total probability mass of tokens to consider at each step."""
957
982
 
958
983
  top_logprobs: Optional[int] = None
959
- """An integer between 0 and 5 specifying the number of most
960
- likely tokens to return at each token position, each with an
961
- associated log probability. logprobs must be set to true if
984
+ """An integer between 0 and 5 specifying the number of most
985
+ likely tokens to return at each token position, each with an
986
+ associated log probability. logprobs must be set to true if
962
987
  this parameter is used."""
963
988
 
964
989
  @property
ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py CHANGED
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
@@ -24,6 +23,7 @@ from typing import (
24
23
 
25
24
  import aiohttp
26
25
  import requests
26
+ from langchain_community.utilities.requests import Requests
27
27
  from langchain_core.callbacks import (
28
28
  AsyncCallbackManagerForLLMRun,
29
29
  CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
34
34
  from langchain_core.utils import get_from_dict_or_env
35
35
  from pydantic import Field, model_validator
36
36
 
37
- from langchain_community.utilities.requests import Requests
38
-
39
37
  logger = logging.getLogger(__name__)
40
38
 
41
39
 
42
40
  DEFAULT_TIME_OUT = 300
43
41
  DEFAULT_CONTENT_TYPE_JSON = "application/json"
44
42
  DEFAULT_MODEL_NAME = "odsc-llm"
43
+ DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
45
44
 
46
45
 
47
46
  class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
86
85
  max_retries: int = 3
87
86
  """Maximum number of retries to make when generating."""
88
87
 
88
+ default_headers: Optional[Dict[str, Any]] = None
89
+ """The headers to be added to the Model Deployment request."""
90
+
89
91
  @model_validator(mode="before")
90
92
  @classmethod
91
93
  def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ class BaseOCIModelDeployment(Serializable):
101
103
  "Please install it with `pip install oracle_ads`."
102
104
  ) from ex
103
105
 
104
- if not values.get("auth", None):
106
+ if not values.get("auth"):
105
107
  values["auth"] = ads.common.auth.default_signer()
106
108
 
107
109
  values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ class BaseOCIModelDeployment(Serializable):
125
127
  Returns:
126
128
  Dict: A dictionary containing the appropriate headers for the request.
127
129
  """
130
+ headers = self.default_headers or {}
128
131
  if is_async:
129
132
  signer = self.auth["signer"]
130
133
  _req = requests.Request("POST", self.endpoint, json=body)
131
134
  req = _req.prepare()
132
135
  req = signer(req)
133
- headers = {}
134
136
  for key, value in req.headers.items():
135
137
  headers[key] = value
136
138
 
@@ -140,7 +142,7 @@ class BaseOCIModelDeployment(Serializable):
140
142
  )
141
143
  return headers
142
144
 
143
- return (
145
+ headers.update(
144
146
  {
145
147
  "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
146
148
  "enable-streaming": "true",
@@ -152,6 +154,8 @@ class BaseOCIModelDeployment(Serializable):
152
154
  }
153
155
  )
154
156
 
157
+ return headers
158
+
155
159
  def completion_with_retry(
156
160
  self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
157
161
  ) -> Any:
@@ -357,7 +361,7 @@ class BaseOCIModelDeployment(Serializable):
357
361
  self.auth["signer"].refresh_security_token()
358
362
  return True
359
363
  return False
360
-
364
+
361
365
  @classmethod
362
366
  def is_lc_serializable(cls) -> bool:
363
367
  """Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
388
392
  model="odsc-llm",
389
393
  streaming=True,
390
394
  model_kwargs={"frequency_penalty": 1.0},
395
+ headers={
396
+ "route": "/v1/completions",
397
+ # other request headers ...
398
+ }
391
399
  )
392
400
  llm.invoke("tell me a joke.")
393
401
 
@@ -477,6 +485,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
477
485
  **self._default_params,
478
486
  }
479
487
 
488
+ def _headers(
489
+ self, is_async: Optional[bool] = False, body: Optional[dict] = None
490
+ ) -> Dict:
491
+ """Construct and return the headers for a request.
492
+
493
+ Args:
494
+ is_async (bool, optional): Indicates if the request is asynchronous.
495
+ Defaults to `False`.
496
+ body (optional): The request body to be included in the headers if
497
+ the request is asynchronous.
498
+
499
+ Returns:
500
+ Dict: A dictionary containing the appropriate headers for the request.
501
+ """
502
+ return {
503
+ "route": DEFAULT_INFERENCE_ENDPOINT,
504
+ **super()._headers(is_async=is_async, body=body),
505
+ }
506
+
480
507
  def _generate(
481
508
  self,
482
509
  prompts: List[str],
@@ -712,9 +739,9 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
712
739
  def _generate_info(self, choice: dict) -> Any:
713
740
  """Extracts generation info from the response."""
714
741
  gen_info = {}
715
- finish_reason = choice.get("finish_reason", None)
716
- logprobs = choice.get("logprobs", None)
717
- index = choice.get("index", None)
742
+ finish_reason = choice.get("finish_reason")
743
+ logprobs = choice.get("logprobs")
744
+ index = choice.get("index")
718
745
  if finish_reason:
719
746
  gen_info.update({"finish_reason": finish_reason})
720
747
  if logprobs is not None:
ads/opctl/operator/common/utils.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
3
  # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -18,7 +17,6 @@ import yaml
18
17
  from cerberus import Validator
19
18
 
20
19
  from ads.opctl import logger, utils
21
- from ads.opctl.operator import __operators__
22
20
 
23
21
  CONTAINER_NETWORK = "CONTAINER_NETWORK"
24
22
 
@@ -26,7 +24,11 @@ CONTAINER_NETWORK = "CONTAINER_NETWORK"
26
24
  class OperatorValidator(Validator):
27
25
  """The custom validator class."""
28
26
 
29
- pass
27
+ def validate(self, obj_dict, **kwargs):
28
+ # Model should be case insensitive
29
+ if "model" in obj_dict["spec"]:
30
+ obj_dict["spec"]["model"] = str(obj_dict["spec"]["model"]).lower()
31
+ return super().validate(obj_dict, **kwargs)
30
32
 
31
33
 
32
34
  def create_output_folder(name):
@@ -34,7 +36,7 @@ def create_output_folder(name):
34
36
  protocol = fsspec.utils.get_protocol(output_folder)
35
37
  storage_options = {}
36
38
  if protocol != "file":
37
- storage_options = auth or default_signer()
39
+ storage_options = default_signer()
38
40
 
39
41
  fs = fsspec.filesystem(protocol, **storage_options)
40
42
  name_suffix = 1
ads/opctl/operator/lowcode/anomaly/model/base_model.py CHANGED
@@ -166,9 +166,8 @@ class AnomalyOperatorBaseModel(ABC):
166
166
  yaml_appendix = rc.Yaml(self.config.to_dict())
167
167
  summary = rc.Block(
168
168
  rc.Group(
169
- rc.Text(
170
- f"You selected the **`{self.spec.model}`** model.\n{model_description.text}\n"
171
- ),
169
+ rc.Text(f"You selected the **`{self.spec.model}`** model.\n"),
170
+ model_description,
172
171
  rc.Text(
173
172
  "Based on your dataset, you could have also selected "
174
173
  f"any of the models: `{'`, `'.join(SupportedModels.keys() if self.spec.datetime_column else NonTimeADSupportedModels.keys())}`."
ads/opctl/operator/lowcode/anomaly/model/factory.py CHANGED
@@ -26,9 +26,9 @@ class UnSupportedModelError(Exception):
26
26
 
27
27
  def __init__(self, operator_config: AnomalyOperatorConfig, model_type: str):
28
28
  supported_models = (
29
- SupportedModels.values
29
+ SupportedModels.values()
30
30
  if operator_config.spec.datetime_column
31
- else NonTimeADSupportedModels.values
31
+ else NonTimeADSupportedModels.values()
32
32
  )
33
33
  message = (
34
34
  f"Model: `{model_type}` is not supported. "
ads/opctl/operator/lowcode/common/transformations.py CHANGED
@@ -1,18 +1,19 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
6
+ from abc import ABC
7
+
8
+ import pandas as pd
9
+
7
10
  from ads.opctl import logger
11
+ from ads.opctl.operator.lowcode.common.const import DataColumns
8
12
  from ads.opctl.operator.lowcode.common.errors import (
9
- InvalidParameterError,
10
13
  DataMismatchError,
14
+ InvalidParameterError,
11
15
  )
12
- from ads.opctl.operator.lowcode.common.const import DataColumns
13
16
  from ads.opctl.operator.lowcode.common.utils import merge_category_columns
14
- import pandas as pd
15
- from abc import ABC
16
17
 
17
18
 
18
19
  class Transformations(ABC):
@@ -58,6 +59,7 @@ class Transformations(ABC):
58
59
 
59
60
  """
60
61
  clean_df = self._remove_trailing_whitespace(data)
62
+ # clean_df = self._normalize_column_names(clean_df)
61
63
  if self.name == "historical_data":
62
64
  self._check_historical_dataset(clean_df)
63
65
  clean_df = self._set_series_id_column(clean_df)
@@ -95,8 +97,11 @@ class Transformations(ABC):
95
97
  def _remove_trailing_whitespace(self, df):
96
98
  return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
97
99
 
100
+ # def _normalize_column_names(self, df):
101
+ # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
102
+
98
103
  def _set_series_id_column(self, df):
99
- self._target_category_columns_map = dict()
104
+ self._target_category_columns_map = {}
100
105
  if not self.target_category_columns:
101
106
  df[DataColumns.Series] = "Series 1"
102
107
  self.has_artificial_series = True
@@ -125,10 +130,10 @@ class Transformations(ABC):
125
130
  df[self.dt_column_name] = pd.to_datetime(
126
131
  df[self.dt_column_name], format=self.dt_column_format
127
132
  )
128
- except:
133
+ except Exception as ee:
129
134
  raise InvalidParameterError(
130
135
  f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
131
- )
136
+ ) from ee
132
137
  return df
133
138
 
134
139
  def _set_multi_index(self, df):
@@ -242,7 +247,6 @@ class Transformations(ABC):
242
247
  "Class": "A",
243
248
  "Num": 2
244
249
  },
245
-
246
250
  }
247
251
  """
248
252