oracle-ads 2.12.7__py3-none-any.whl → 2.12.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/app.py +12 -2
- ads/aqua/evaluation/entities.py +6 -0
- ads/aqua/evaluation/evaluation.py +25 -3
- ads/aqua/extension/deployment_handler.py +8 -4
- ads/aqua/extension/model_handler.py +9 -7
- ads/aqua/extension/ui_handler.py +13 -1
- ads/aqua/finetuning/entities.py +6 -0
- ads/aqua/finetuning/finetuning.py +47 -34
- ads/aqua/model/entities.py +2 -0
- ads/aqua/model/model.py +34 -6
- ads/aqua/modeldeployment/deployment.py +28 -10
- ads/aqua/modeldeployment/entities.py +7 -4
- ads/aqua/ui.py +24 -2
- ads/llm/guardrails/base.py +6 -5
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +34 -9
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
- ads/opctl/operator/common/utils.py +6 -4
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +2 -3
- ads/opctl/operator/lowcode/anomaly/model/factory.py +2 -2
- ads/opctl/operator/lowcode/common/transformations.py +14 -10
- ads/opctl/operator/lowcode/common/utils.py +37 -37
- ads/opctl/operator/lowcode/forecast/const.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +10 -2
- ads/opctl/operator/lowcode/forecast/model/base_model.py +10 -15
- ads/opctl/operator/lowcode/forecast/model/factory.py +3 -2
- ads/opctl/operator/lowcode/forecast/model/prophet.py +4 -1
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/opctl/operator/lowcode/pii/model/factory.py +7 -5
- ads/opctl/operator/lowcode/recommender/model/base_model.py +2 -1
- ads/opctl/operator/lowcode/recommender/model/factory.py +4 -6
- ads/opctl/operator/lowcode/recommender/model/svd.py +5 -5
- {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/METADATA +3 -3
- {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/RECORD +38 -38
- {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.7.dist-info → oracle_ads-2.12.9.dist-info}/entry_points.txt +0 -0
@@ -110,6 +110,8 @@ class AquaDeploymentApp(AquaApp):
|
|
110
110
|
private_endpoint_id: Optional[str] = None,
|
111
111
|
container_image_uri: Optional[None] = None,
|
112
112
|
cmd_var: List[str] = None,
|
113
|
+
freeform_tags: Optional[dict] = None,
|
114
|
+
defined_tags: Optional[dict] = None,
|
113
115
|
) -> "AquaDeployment":
|
114
116
|
"""
|
115
117
|
Creates a new Aqua deployment
|
@@ -163,6 +165,10 @@ class AquaDeploymentApp(AquaApp):
|
|
163
165
|
Required parameter for BYOC based deployments if this parameter was not set during model registration.
|
164
166
|
cmd_var: List[str]
|
165
167
|
The cmd of model deployment container runtime.
|
168
|
+
freeform_tags: dict
|
169
|
+
Freeform tags for the model deployment
|
170
|
+
defined_tags: dict
|
171
|
+
Defined tags for the model deployment
|
166
172
|
Returns
|
167
173
|
-------
|
168
174
|
AquaDeployment
|
@@ -172,7 +178,11 @@ class AquaDeploymentApp(AquaApp):
|
|
172
178
|
# TODO validate if the service model has no artifact and if it requires import step before deployment.
|
173
179
|
# Create a model catalog entry in the user compartment
|
174
180
|
aqua_model = AquaModelApp().create(
|
175
|
-
model_id=model_id,
|
181
|
+
model_id=model_id,
|
182
|
+
compartment_id=compartment_id,
|
183
|
+
project_id=project_id,
|
184
|
+
freeform_tags=freeform_tags,
|
185
|
+
defined_tags=defined_tags,
|
176
186
|
)
|
177
187
|
|
178
188
|
tags = {}
|
@@ -185,7 +195,7 @@ class AquaDeploymentApp(AquaApp):
|
|
185
195
|
tags[tag] = aqua_model.freeform_tags[tag]
|
186
196
|
|
187
197
|
tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
|
188
|
-
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK,
|
198
|
+
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)})
|
189
199
|
|
190
200
|
# Set up info to get deployment config
|
191
201
|
config_source_id = model_id
|
@@ -418,12 +428,14 @@ class AquaDeploymentApp(AquaApp):
|
|
418
428
|
if cmd_var:
|
419
429
|
container_runtime.with_cmd(cmd_var)
|
420
430
|
|
431
|
+
tags = {**tags, **(freeform_tags or {})}
|
421
432
|
# configure model deployment and deploy model on container runtime
|
422
433
|
deployment = (
|
423
434
|
ModelDeployment()
|
424
435
|
.with_display_name(display_name)
|
425
436
|
.with_description(description)
|
426
437
|
.with_freeform_tags(**tags)
|
438
|
+
.with_defined_tags(**(defined_tags or {}))
|
427
439
|
.with_infrastructure(infrastructure)
|
428
440
|
.with_runtime(container_runtime)
|
429
441
|
).deploy(wait_for_completion=False)
|
@@ -533,16 +545,22 @@ class AquaDeploymentApp(AquaApp):
|
|
533
545
|
return results
|
534
546
|
|
535
547
|
@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
|
536
|
-
def delete(self,model_deployment_id:str):
|
537
|
-
return self.ds_client.delete_model_deployment(
|
548
|
+
def delete(self, model_deployment_id: str):
|
549
|
+
return self.ds_client.delete_model_deployment(
|
550
|
+
model_deployment_id=model_deployment_id
|
551
|
+
).data
|
538
552
|
|
539
|
-
@telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua")
|
540
|
-
def deactivate(self,model_deployment_id:str):
|
541
|
-
return self.ds_client.deactivate_model_deployment(
|
553
|
+
@telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua")
|
554
|
+
def deactivate(self, model_deployment_id: str):
|
555
|
+
return self.ds_client.deactivate_model_deployment(
|
556
|
+
model_deployment_id=model_deployment_id
|
557
|
+
).data
|
542
558
|
|
543
|
-
@telemetry(entry_point="plugin=deployment&action=activate",name="aqua")
|
544
|
-
def activate(self,model_deployment_id:str):
|
545
|
-
return self.ds_client.activate_model_deployment(
|
559
|
+
@telemetry(entry_point="plugin=deployment&action=activate", name="aqua")
|
560
|
+
def activate(self, model_deployment_id: str):
|
561
|
+
return self.ds_client.activate_model_deployment(
|
562
|
+
model_deployment_id=model_deployment_id
|
563
|
+
).data
|
546
564
|
|
547
565
|
@telemetry(entry_point="plugin=deployment&action=get", name="aqua")
|
548
566
|
def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
|
@@ -98,9 +98,12 @@ class AquaDeployment(DataClassSerializable):
|
|
98
98
|
),
|
99
99
|
)
|
100
100
|
|
101
|
-
|
102
|
-
|
103
|
-
|
101
|
+
tags = {}
|
102
|
+
tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
|
103
|
+
tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)
|
104
|
+
|
105
|
+
aqua_service_model_tag = tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
|
106
|
+
aqua_model_name = tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN)
|
104
107
|
private_endpoint_id = getattr(
|
105
108
|
instance_configuration, "private_endpoint_id", UNKNOWN
|
106
109
|
)
|
@@ -125,7 +128,7 @@ class AquaDeployment(DataClassSerializable):
|
|
125
128
|
ocid=oci_model_deployment.id,
|
126
129
|
region=region,
|
127
130
|
),
|
128
|
-
tags=
|
131
|
+
tags=tags,
|
129
132
|
environment_variables=environment_variables,
|
130
133
|
cmd=cmd,
|
131
134
|
)
|
ads/aqua/ui.py
CHANGED
@@ -481,12 +481,12 @@ class AquaUIApp(AquaApp):
|
|
481
481
|
|
482
482
|
@telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua")
|
483
483
|
def list_job_shapes(self, **kwargs) -> list:
|
484
|
-
"""Lists all
|
484
|
+
"""Lists all available job shapes for the specified compartment.
|
485
485
|
|
486
486
|
Parameters
|
487
487
|
----------
|
488
488
|
**kwargs
|
489
|
-
|
489
|
+
Additional arguments, such as `compartment_id`,
|
490
490
|
for `list_job_shapes <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_job_shapes>`_
|
491
491
|
|
492
492
|
Returns
|
@@ -500,6 +500,28 @@ class AquaUIApp(AquaApp):
|
|
500
500
|
).data
|
501
501
|
return sanitize_response(oci_client=self.ds_client, response=res)
|
502
502
|
|
503
|
+
@telemetry(entry_point="plugin=ui&action=list_model_deployment_shapes", name="aqua")
|
504
|
+
def list_model_deployment_shapes(self, **kwargs) -> list:
|
505
|
+
"""Lists all available shapes for model deployment in the specified compartment.
|
506
|
+
|
507
|
+
Parameters
|
508
|
+
----------
|
509
|
+
**kwargs
|
510
|
+
Additional arguments, such as `compartment_id`,
|
511
|
+
for `list_model_deployment_shapes <https://docs.oracle.com/en-us/iaas/api/#/en/data-science/20190101/ModelDeploymentShapeSummary/ListModelDeploymentShapes>`_
|
512
|
+
|
513
|
+
Returns
|
514
|
+
-------
|
515
|
+
str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`."""
|
516
|
+
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
|
517
|
+
logger.info(
|
518
|
+
f"Loading model deployment shape summary from compartment: {compartment_id}"
|
519
|
+
)
|
520
|
+
res = self.ds_client.list_model_deployment_shapes(
|
521
|
+
compartment_id=compartment_id, **kwargs
|
522
|
+
).data
|
523
|
+
return sanitize_response(oci_client=self.ds_client, response=res)
|
524
|
+
|
503
525
|
@telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua")
|
504
526
|
def list_vcn(self, **kwargs) -> list:
|
505
527
|
"""Lists the virtual cloud networks (VCNs) in the specified compartment.
|
ads/llm/guardrails/base.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
8
7
|
import datetime
|
9
8
|
import functools
|
10
|
-
import operator
|
11
9
|
import importlib.util
|
10
|
+
import operator
|
12
11
|
import sys
|
12
|
+
from typing import Any, List, Optional, Union
|
13
13
|
|
14
|
-
from typing import Any, List, Dict, Tuple
|
15
14
|
from langchain.schema.prompt import PromptValue
|
16
15
|
from langchain.tools.base import BaseTool, ToolException
|
17
16
|
from pydantic import BaseModel, model_validator
|
@@ -207,7 +206,9 @@ class Guardrail(BaseTool):
|
|
207
206
|
return input.to_string()
|
208
207
|
return str(input)
|
209
208
|
|
210
|
-
def _to_args_and_kwargs(
|
209
|
+
def _to_args_and_kwargs(
|
210
|
+
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
211
|
+
) -> tuple[tuple, dict]:
|
211
212
|
if isinstance(tool_input, dict):
|
212
213
|
return (), tool_input
|
213
214
|
else:
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
"""Chat model for OCI data science model deployment endpoint."""
|
7
6
|
|
@@ -50,6 +49,7 @@ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint i
|
|
50
49
|
)
|
51
50
|
|
52
51
|
logger = logging.getLogger(__name__)
|
52
|
+
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
|
53
53
|
|
54
54
|
|
55
55
|
def _is_pydantic_class(obj: Any) -> bool:
|
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
93
93
|
Key init args — client params:
|
94
94
|
auth: dict
|
95
95
|
ADS auth dictionary for OCI authentication.
|
96
|
+
default_headers: Optional[Dict]
|
97
|
+
The headers to be added to the Model Deployment request.
|
96
98
|
|
97
99
|
Instantiate:
|
98
100
|
.. code-block:: python
|
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
109
111
|
"temperature": 0.2,
|
110
112
|
# other model parameters ...
|
111
113
|
},
|
114
|
+
default_headers={
|
115
|
+
"route": "/v1/chat/completions",
|
116
|
+
# other request headers ...
|
117
|
+
},
|
112
118
|
)
|
113
119
|
|
114
120
|
Invocation:
|
@@ -291,6 +297,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
291
297
|
"stream": self.streaming,
|
292
298
|
}
|
293
299
|
|
300
|
+
def _headers(
|
301
|
+
self, is_async: Optional[bool] = False, body: Optional[dict] = None
|
302
|
+
) -> Dict:
|
303
|
+
"""Construct and return the headers for a request.
|
304
|
+
|
305
|
+
Args:
|
306
|
+
is_async (bool, optional): Indicates if the request is asynchronous.
|
307
|
+
Defaults to `False`.
|
308
|
+
body (optional): The request body to be included in the headers if
|
309
|
+
the request is asynchronous.
|
310
|
+
|
311
|
+
Returns:
|
312
|
+
Dict: A dictionary containing the appropriate headers for the request.
|
313
|
+
"""
|
314
|
+
return {
|
315
|
+
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
|
316
|
+
**super()._headers(is_async=is_async, body=body),
|
317
|
+
}
|
318
|
+
|
294
319
|
def _generate(
|
295
320
|
self,
|
296
321
|
messages: List[BaseMessage],
|
@@ -704,7 +729,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
704
729
|
|
705
730
|
for choice in choices:
|
706
731
|
message = _convert_dict_to_message(choice["message"])
|
707
|
-
generation_info =
|
732
|
+
generation_info = {"finish_reason": choice.get("finish_reason")}
|
708
733
|
if "logprobs" in choice:
|
709
734
|
generation_info["logprobs"] = choice["logprobs"]
|
710
735
|
|
@@ -794,7 +819,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
794
819
|
"""Number of most likely tokens to consider at each step."""
|
795
820
|
|
796
821
|
min_p: Optional[float] = 0.0
|
797
|
-
"""Float that represents the minimum probability for a token to be considered.
|
822
|
+
"""Float that represents the minimum probability for a token to be considered.
|
798
823
|
Must be in [0,1]. 0 to disable this."""
|
799
824
|
|
800
825
|
repetition_penalty: Optional[float] = 1.0
|
@@ -818,7 +843,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
818
843
|
the EOS token is generated."""
|
819
844
|
|
820
845
|
min_tokens: Optional[int] = 0
|
821
|
-
"""Minimum number of tokens to generate per output sequence before
|
846
|
+
"""Minimum number of tokens to generate per output sequence before
|
822
847
|
EOS or stop_token_ids can be generated"""
|
823
848
|
|
824
849
|
stop_token_ids: Optional[List[int]] = None
|
@@ -836,7 +861,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
836
861
|
tool_choice: Optional[str] = None
|
837
862
|
"""Whether to use tool calling.
|
838
863
|
Defaults to None, tool calling is disabled.
|
839
|
-
Tool calling requires model support and the vLLM to be configured
|
864
|
+
Tool calling requires model support and the vLLM to be configured
|
840
865
|
with `--tool-call-parser`.
|
841
866
|
Set this to `auto` for the model to make tool calls automatically.
|
842
867
|
Set this to `required` to force the model to always call one or more tools.
|
@@ -956,9 +981,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
|
|
956
981
|
"""Total probability mass of tokens to consider at each step."""
|
957
982
|
|
958
983
|
top_logprobs: Optional[int] = None
|
959
|
-
"""An integer between 0 and 5 specifying the number of most
|
960
|
-
likely tokens to return at each token position, each with an
|
961
|
-
associated log probability. logprobs must be set to true if
|
984
|
+
"""An integer between 0 and 5 specifying the number of most
|
985
|
+
likely tokens to return at each token position, each with an
|
986
|
+
associated log probability. logprobs must be set to true if
|
962
987
|
this parameter is used."""
|
963
988
|
|
964
989
|
@property
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
@@ -24,6 +23,7 @@ from typing import (
|
|
24
23
|
|
25
24
|
import aiohttp
|
26
25
|
import requests
|
26
|
+
from langchain_community.utilities.requests import Requests
|
27
27
|
from langchain_core.callbacks import (
|
28
28
|
AsyncCallbackManagerForLLMRun,
|
29
29
|
CallbackManagerForLLMRun,
|
@@ -34,14 +34,13 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
34
34
|
from langchain_core.utils import get_from_dict_or_env
|
35
35
|
from pydantic import Field, model_validator
|
36
36
|
|
37
|
-
from langchain_community.utilities.requests import Requests
|
38
|
-
|
39
37
|
logger = logging.getLogger(__name__)
|
40
38
|
|
41
39
|
|
42
40
|
DEFAULT_TIME_OUT = 300
|
43
41
|
DEFAULT_CONTENT_TYPE_JSON = "application/json"
|
44
42
|
DEFAULT_MODEL_NAME = "odsc-llm"
|
43
|
+
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
|
45
44
|
|
46
45
|
|
47
46
|
class TokenExpiredError(Exception):
|
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
|
|
86
85
|
max_retries: int = 3
|
87
86
|
"""Maximum number of retries to make when generating."""
|
88
87
|
|
88
|
+
default_headers: Optional[Dict[str, Any]] = None
|
89
|
+
"""The headers to be added to the Model Deployment request."""
|
90
|
+
|
89
91
|
@model_validator(mode="before")
|
90
92
|
@classmethod
|
91
93
|
def validate_environment(cls, values: Dict) -> Dict:
|
@@ -101,7 +103,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
101
103
|
"Please install it with `pip install oracle_ads`."
|
102
104
|
) from ex
|
103
105
|
|
104
|
-
if not values.get("auth"
|
106
|
+
if not values.get("auth"):
|
105
107
|
values["auth"] = ads.common.auth.default_signer()
|
106
108
|
|
107
109
|
values["endpoint"] = get_from_dict_or_env(
|
@@ -125,12 +127,12 @@ class BaseOCIModelDeployment(Serializable):
|
|
125
127
|
Returns:
|
126
128
|
Dict: A dictionary containing the appropriate headers for the request.
|
127
129
|
"""
|
130
|
+
headers = self.default_headers or {}
|
128
131
|
if is_async:
|
129
132
|
signer = self.auth["signer"]
|
130
133
|
_req = requests.Request("POST", self.endpoint, json=body)
|
131
134
|
req = _req.prepare()
|
132
135
|
req = signer(req)
|
133
|
-
headers = {}
|
134
136
|
for key, value in req.headers.items():
|
135
137
|
headers[key] = value
|
136
138
|
|
@@ -140,7 +142,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
140
142
|
)
|
141
143
|
return headers
|
142
144
|
|
143
|
-
|
145
|
+
headers.update(
|
144
146
|
{
|
145
147
|
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
|
146
148
|
"enable-streaming": "true",
|
@@ -152,6 +154,8 @@ class BaseOCIModelDeployment(Serializable):
|
|
152
154
|
}
|
153
155
|
)
|
154
156
|
|
157
|
+
return headers
|
158
|
+
|
155
159
|
def completion_with_retry(
|
156
160
|
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
157
161
|
) -> Any:
|
@@ -357,7 +361,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
357
361
|
self.auth["signer"].refresh_security_token()
|
358
362
|
return True
|
359
363
|
return False
|
360
|
-
|
364
|
+
|
361
365
|
@classmethod
|
362
366
|
def is_lc_serializable(cls) -> bool:
|
363
367
|
"""Return whether this model can be serialized by LangChain."""
|
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
388
392
|
model="odsc-llm",
|
389
393
|
streaming=True,
|
390
394
|
model_kwargs={"frequency_penalty": 1.0},
|
395
|
+
headers={
|
396
|
+
"route": "/v1/completions",
|
397
|
+
# other request headers ...
|
398
|
+
}
|
391
399
|
)
|
392
400
|
llm.invoke("tell me a joke.")
|
393
401
|
|
@@ -477,6 +485,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
477
485
|
**self._default_params,
|
478
486
|
}
|
479
487
|
|
488
|
+
def _headers(
|
489
|
+
self, is_async: Optional[bool] = False, body: Optional[dict] = None
|
490
|
+
) -> Dict:
|
491
|
+
"""Construct and return the headers for a request.
|
492
|
+
|
493
|
+
Args:
|
494
|
+
is_async (bool, optional): Indicates if the request is asynchronous.
|
495
|
+
Defaults to `False`.
|
496
|
+
body (optional): The request body to be included in the headers if
|
497
|
+
the request is asynchronous.
|
498
|
+
|
499
|
+
Returns:
|
500
|
+
Dict: A dictionary containing the appropriate headers for the request.
|
501
|
+
"""
|
502
|
+
return {
|
503
|
+
"route": DEFAULT_INFERENCE_ENDPOINT,
|
504
|
+
**super()._headers(is_async=is_async, body=body),
|
505
|
+
}
|
506
|
+
|
480
507
|
def _generate(
|
481
508
|
self,
|
482
509
|
prompts: List[str],
|
@@ -712,9 +739,9 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
712
739
|
def _generate_info(self, choice: dict) -> Any:
|
713
740
|
"""Extracts generation info from the response."""
|
714
741
|
gen_info = {}
|
715
|
-
finish_reason = choice.get("finish_reason"
|
716
|
-
logprobs = choice.get("logprobs"
|
717
|
-
index = choice.get("index"
|
742
|
+
finish_reason = choice.get("finish_reason")
|
743
|
+
logprobs = choice.get("logprobs")
|
744
|
+
index = choice.get("index")
|
718
745
|
if finish_reason:
|
719
746
|
gen_info.update({"finish_reason": finish_reason})
|
720
747
|
if logprobs is not None:
|
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
3
|
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
@@ -18,7 +17,6 @@ import yaml
|
|
18
17
|
from cerberus import Validator
|
19
18
|
|
20
19
|
from ads.opctl import logger, utils
|
21
|
-
from ads.opctl.operator import __operators__
|
22
20
|
|
23
21
|
CONTAINER_NETWORK = "CONTAINER_NETWORK"
|
24
22
|
|
@@ -26,7 +24,11 @@ CONTAINER_NETWORK = "CONTAINER_NETWORK"
|
|
26
24
|
class OperatorValidator(Validator):
|
27
25
|
"""The custom validator class."""
|
28
26
|
|
29
|
-
|
27
|
+
def validate(self, obj_dict, **kwargs):
|
28
|
+
# Model should be case insensitive
|
29
|
+
if "model" in obj_dict["spec"]:
|
30
|
+
obj_dict["spec"]["model"] = str(obj_dict["spec"]["model"]).lower()
|
31
|
+
return super().validate(obj_dict, **kwargs)
|
30
32
|
|
31
33
|
|
32
34
|
def create_output_folder(name):
|
@@ -34,7 +36,7 @@ def create_output_folder(name):
|
|
34
36
|
protocol = fsspec.utils.get_protocol(output_folder)
|
35
37
|
storage_options = {}
|
36
38
|
if protocol != "file":
|
37
|
-
storage_options =
|
39
|
+
storage_options = default_signer()
|
38
40
|
|
39
41
|
fs = fsspec.filesystem(protocol, **storage_options)
|
40
42
|
name_suffix = 1
|
@@ -166,9 +166,8 @@ class AnomalyOperatorBaseModel(ABC):
|
|
166
166
|
yaml_appendix = rc.Yaml(self.config.to_dict())
|
167
167
|
summary = rc.Block(
|
168
168
|
rc.Group(
|
169
|
-
rc.Text(
|
170
|
-
|
171
|
-
),
|
169
|
+
rc.Text(f"You selected the **`{self.spec.model}`** model.\n"),
|
170
|
+
model_description,
|
172
171
|
rc.Text(
|
173
172
|
"Based on your dataset, you could have also selected "
|
174
173
|
f"any of the models: `{'`, `'.join(SupportedModels.keys() if self.spec.datetime_column else NonTimeADSupportedModels.keys())}`."
|
@@ -26,9 +26,9 @@ class UnSupportedModelError(Exception):
|
|
26
26
|
|
27
27
|
def __init__(self, operator_config: AnomalyOperatorConfig, model_type: str):
|
28
28
|
supported_models = (
|
29
|
-
SupportedModels.values
|
29
|
+
SupportedModels.values()
|
30
30
|
if operator_config.spec.datetime_column
|
31
|
-
else NonTimeADSupportedModels.values
|
31
|
+
else NonTimeADSupportedModels.values()
|
32
32
|
)
|
33
33
|
message = (
|
34
34
|
f"Model: `{model_type}` is not supported. "
|
@@ -1,18 +1,19 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2023 Oracle and/or its affiliates.
|
3
|
+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
6
|
+
from abc import ABC
|
7
|
+
|
8
|
+
import pandas as pd
|
9
|
+
|
7
10
|
from ads.opctl import logger
|
11
|
+
from ads.opctl.operator.lowcode.common.const import DataColumns
|
8
12
|
from ads.opctl.operator.lowcode.common.errors import (
|
9
|
-
InvalidParameterError,
|
10
13
|
DataMismatchError,
|
14
|
+
InvalidParameterError,
|
11
15
|
)
|
12
|
-
from ads.opctl.operator.lowcode.common.const import DataColumns
|
13
16
|
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
|
14
|
-
import pandas as pd
|
15
|
-
from abc import ABC
|
16
17
|
|
17
18
|
|
18
19
|
class Transformations(ABC):
|
@@ -58,6 +59,7 @@ class Transformations(ABC):
|
|
58
59
|
|
59
60
|
"""
|
60
61
|
clean_df = self._remove_trailing_whitespace(data)
|
62
|
+
# clean_df = self._normalize_column_names(clean_df)
|
61
63
|
if self.name == "historical_data":
|
62
64
|
self._check_historical_dataset(clean_df)
|
63
65
|
clean_df = self._set_series_id_column(clean_df)
|
@@ -95,8 +97,11 @@ class Transformations(ABC):
|
|
95
97
|
def _remove_trailing_whitespace(self, df):
|
96
98
|
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
|
97
99
|
|
100
|
+
# def _normalize_column_names(self, df):
|
101
|
+
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
|
102
|
+
|
98
103
|
def _set_series_id_column(self, df):
|
99
|
-
self._target_category_columns_map =
|
104
|
+
self._target_category_columns_map = {}
|
100
105
|
if not self.target_category_columns:
|
101
106
|
df[DataColumns.Series] = "Series 1"
|
102
107
|
self.has_artificial_series = True
|
@@ -125,10 +130,10 @@ class Transformations(ABC):
|
|
125
130
|
df[self.dt_column_name] = pd.to_datetime(
|
126
131
|
df[self.dt_column_name], format=self.dt_column_format
|
127
132
|
)
|
128
|
-
except:
|
133
|
+
except Exception as ee:
|
129
134
|
raise InvalidParameterError(
|
130
135
|
f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
|
131
|
-
)
|
136
|
+
) from ee
|
132
137
|
return df
|
133
138
|
|
134
139
|
def _set_multi_index(self, df):
|
@@ -242,7 +247,6 @@ class Transformations(ABC):
|
|
242
247
|
"Class": "A",
|
243
248
|
"Num": 2
|
244
249
|
},
|
245
|
-
|
246
250
|
}
|
247
251
|
"""
|
248
252
|
|