oracle-ads 2.12.8__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -4
- ads/aqua/app.py +12 -2
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/evaluation/entities.py +6 -0
- ads/aqua/evaluation/evaluation.py +25 -3
- ads/aqua/extension/deployment_handler.py +8 -4
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +25 -6
- ads/aqua/extension/ui_handler.py +13 -1
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +70 -17
- ads/aqua/finetuning/finetuning.py +79 -82
- ads/aqua/model/entities.py +4 -1
- ads/aqua/model/model.py +95 -29
- ads/aqua/modeldeployment/deployment.py +13 -1
- ads/aqua/modeldeployment/entities.py +7 -4
- ads/aqua/ui.py +24 -2
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/guardrails/base.py +6 -5
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +12 -5
- ads/opctl/operator/lowcode/common/utils.py +11 -5
- ads/opctl/operator/lowcode/forecast/const.py +3 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
- ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
- ads/opctl/operator/lowcode/forecast/utils.py +8 -6
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +9 -10
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +71 -50
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
2
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
3
|
+
import logging
|
4
|
+
import traceback
|
5
|
+
from sqlite3 import Connection
|
6
|
+
from typing import Any, Dict, List, Optional
|
7
|
+
|
8
|
+
import autogen.runtime_logging
|
9
|
+
from autogen.logger.base_logger import BaseLogger
|
10
|
+
from autogen.logger.logger_factory import LoggerFactory
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class LoggerManager(BaseLogger):
    """Manages multiple AutoGen loggers.

    Every logging call received by the manager is forwarded to each logger
    stored in ``self.loggers``. An exception raised by one logger is logged
    and swallowed so that the remaining loggers are still called.
    """

    def __init__(self) -> None:
        # Loggers receiving the forwarded calls, in the order they were added.
        self.loggers: List[BaseLogger] = []
        super().__init__()

    def add_logger(self, logger: BaseLogger) -> None:
        """Adds a new AutoGen logger."""
        self.loggers.append(logger)

    def _call_loggers(self, method: str, *args, **kwargs) -> None:
        """Calls the specific method on each AutoGen logger in self.loggers.

        Parameters
        ----------
        method : str
            Name of the method to be looked up on each logger.
        *args, **kwargs
            Arguments passed through to the method unchanged.

        Returns
        -------
        None
            Return values of the individual loggers are discarded.
        """
        for autogen_logger in self.loggers:
            try:
                getattr(autogen_logger, method)(*args, **kwargs)
            except Exception as e:
                # Catch the logging exception so that the program will not be interrupted.
                logger.error(
                    "Failed to %s with %s: %s",
                    method,
                    autogen_logger.__class__.__name__,
                    str(e),
                )
                logger.debug(traceback.format_exc())

    def start(self) -> str:
        """Starts all loggers.

        The session IDs returned by the individual loggers are discarded by
        ``_call_loggers``, so this call itself evaluates to ``None``.
        """
        return self._call_loggers("start")

    def stop(self) -> None:
        """Stops all loggers and removes them from the manager."""
        self._call_loggers("stop")
        # Remove the loggers once they are stopped.
        self.loggers = []

    # ``Optional[Connection]`` is used instead of ``None | Connection``:
    # the annotation is evaluated when the class body runs, and the ``|``
    # form raises TypeError on Python versions earlier than 3.10.
    def get_connection(self) -> Optional[Connection]:
        """Forwards ``get_connection`` to all loggers; results are discarded."""
        return self._call_loggers("get_connection")

    def log_chat_completion(self, *args, **kwargs) -> None:
        """Forwards ``log_chat_completion`` to all loggers."""
        return self._call_loggers("log_chat_completion", *args, **kwargs)

    def log_new_agent(self, *args, **kwargs) -> None:
        """Forwards ``log_new_agent`` to all loggers."""
        return self._call_loggers("log_new_agent", *args, **kwargs)

    def log_event(self, *args, **kwargs) -> None:
        """Forwards ``log_event`` to all loggers."""
        return self._call_loggers("log_event", *args, **kwargs)

    def log_new_wrapper(self, *args, **kwargs) -> None:
        """Forwards ``log_new_wrapper`` to all loggers."""
        return self._call_loggers("log_new_wrapper", *args, **kwargs)

    def log_new_client(self, *args, **kwargs) -> None:
        """Forwards ``log_new_client`` to all loggers."""
        return self._call_loggers("log_new_client", *args, **kwargs)

    def log_function_use(self, *args, **kwargs) -> None:
        """Forwards ``log_function_use`` to all loggers."""
        return self._call_loggers("log_function_use", *args, **kwargs)

    def __repr__(self) -> str:
        """Returns the class and repr of every managed logger, blank-line separated."""
        return "\n\n".join(
            [
                f"{str(managed.__class__)}:\n{managed.__repr__()}"
                for managed in self.loggers
            ]
        )
|
78
|
+
|
79
|
+
|
80
|
+
def start(
    autogen_logger: Optional[BaseLogger] = None,
    logger_type: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Starts logging with AutoGen logger.
    Specify your custom autogen_logger, or the logger_type and config to use a built-in logger.

    Parameters
    ----------
    autogen_logger : BaseLogger, optional
        An AutoGen logger, which should be a subclass of autogen.logger.base_logger.BaseLogger.
    logger_type : str, optional
        Logger type, which can be a built-in AutoGen logger type ("file", or "sqlite"), by default None.
    config : dict, optional
        Configurations for the built-in AutoGen logger, by default None

    Returns
    -------
    str or None
        A unique session ID returned from starting the logger,
        or None if the logger failed to start (the failure is logged).

    Raises
    ------
    ValueError
        If both autogen_logger and logger_type are specified.
    """
    if autogen_logger and logger_type:
        # Exceptions do not interpolate %-style arguments the way logging
        # calls do, so the message is formatted explicitly.
        raise ValueError(
            f"Please specify only autogen_logger({autogen_logger}) "
            f"or logger_type({logger_type})."
        )

    # Check if a logger is already configured
    existing_logger = autogen.runtime_logging.autogen_logger
    if not existing_logger:
        # No logger is configured
        logger_manager = LoggerManager()
    elif isinstance(existing_logger, LoggerManager):
        # Logger is already configured with ADS
        logger_manager = existing_logger
    else:
        # Logger is configured but it is not via ADS
        logger.warning("AutoGen is already configured with %s", str(existing_logger))
        logger_manager = LoggerManager()
        logger_manager.add_logger(existing_logger)

    # Add AutoGen logger
    if not autogen_logger:
        autogen_logger = LoggerFactory.get_logger(
            logger_type=logger_type, config=config
        )
    logger_manager.add_logger(autogen_logger)

    # Bound before the try block so that a failed start returns None
    # instead of raising UnboundLocalError at the return statement.
    session_id = None
    try:
        session_id = autogen_logger.start()
        autogen.runtime_logging.is_logging = True
        autogen.runtime_logging.autogen_logger = logger_manager
    except Exception as e:
        logger.error("Failed to start logging: %s", e)
    return session_id
|
138
|
+
|
139
|
+
|
140
|
+
def stop(*loggers) -> BaseLogger:
    """Stops AutoGen logging, either entirely or for selected loggers.

    When the active AutoGen logger is a LoggerManager and one or more
    loggers are passed in, only those loggers are stopped, and each one
    found in the manager is removed from it.
    In every other case ``autogen.runtime_logging.stop()`` is called.

    Returns
    -------
    BaseLogger
        The AutoGen logger that was configured when this function was called.
    """
    active = autogen.runtime_logging.autogen_logger

    # Nothing selective to do: delegate to AutoGen's own stop.
    if not (isinstance(active, LoggerManager) and loggers):
        autogen.runtime_logging.stop()
        return active

    for target in loggers:
        target.stop()
        if target in active.loggers:
            active.loggers.remove(target)
    return active
|
156
|
+
|
157
|
+
|
158
|
+
def get_loggers() -> List[BaseLogger]:
    """Gets a list of existing AutoGen loggers.

    Returns
    -------
    List[BaseLogger]
        The loggers held by the LoggerManager when logging was configured
        via ADS, a single-element list when another logger is configured,
        or an empty list when no logger is configured.
    """
    autogen_logger = autogen.runtime_logging.autogen_logger
    if autogen_logger is None:
        # No logger configured: return an empty list rather than [None].
        return []
    if isinstance(autogen_logger, LoggerManager):
        return autogen_logger.loggers
    return [autogen_logger]
|
ads/llm/guardrails/base.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
8
7
|
import datetime
|
9
8
|
import functools
|
10
|
-
import operator
|
11
9
|
import importlib.util
|
10
|
+
import operator
|
12
11
|
import sys
|
12
|
+
from typing import Any, List, Optional, Union
|
13
13
|
|
14
|
-
from typing import Any, List, Dict, Tuple
|
15
14
|
from langchain.schema.prompt import PromptValue
|
16
15
|
from langchain.tools.base import BaseTool, ToolException
|
17
16
|
from pydantic import BaseModel, model_validator
|
@@ -207,7 +206,9 @@ class Guardrail(BaseTool):
|
|
207
206
|
return input.to_string()
|
208
207
|
return str(input)
|
209
208
|
|
210
|
-
def _to_args_and_kwargs(
|
209
|
+
def _to_args_and_kwargs(
|
210
|
+
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
|
211
|
+
) -> tuple[tuple, dict]:
|
211
212
|
if isinstance(tool_input, dict):
|
212
213
|
return (), tool_input
|
213
214
|
else:
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
"""Chat model for OCI data science model deployment endpoint."""
|
7
6
|
|
@@ -50,6 +49,7 @@ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint i
|
|
50
49
|
)
|
51
50
|
|
52
51
|
logger = logging.getLogger(__name__)
|
52
|
+
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
|
53
53
|
|
54
54
|
|
55
55
|
def _is_pydantic_class(obj: Any) -> bool:
|
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
93
93
|
Key init args — client params:
|
94
94
|
auth: dict
|
95
95
|
ADS auth dictionary for OCI authentication.
|
96
|
+
default_headers: Optional[Dict]
|
97
|
+
The headers to be added to the Model Deployment request.
|
96
98
|
|
97
99
|
Instantiate:
|
98
100
|
.. code-block:: python
|
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
109
111
|
"temperature": 0.2,
|
110
112
|
# other model parameters ...
|
111
113
|
},
|
114
|
+
default_headers={
|
115
|
+
"route": "/v1/chat/completions",
|
116
|
+
# other request headers ...
|
117
|
+
},
|
112
118
|
)
|
113
119
|
|
114
120
|
Invocation:
|
@@ -291,6 +297,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
291
297
|
"stream": self.streaming,
|
292
298
|
}
|
293
299
|
|
300
|
+
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Build the request headers, defaulting the route to the chat endpoint.

    Args:
        is_async (bool, optional): Passed through to the parent
            implementation. Defaults to `False`.
        body (optional): Passed through to the parent implementation.

    Returns:
        Dict: The parent's headers with a "route" entry set to
            `DEFAULT_INFERENCE_ENDPOINT_CHAT` unless the parent's
            headers already provide a "route" value.
    """
    # Seed with the default route first so that any "route" supplied by
    # the parent implementation takes precedence.
    merged = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
    merged.update(super()._headers(is_async=is_async, body=body))
    return merged
|
318
|
+
|
294
319
|
def _generate(
|
295
320
|
self,
|
296
321
|
messages: List[BaseMessage],
|
@@ -704,7 +729,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
|
|
704
729
|
|
705
730
|
for choice in choices:
|
706
731
|
message = _convert_dict_to_message(choice["message"])
|
707
|
-
generation_info =
|
732
|
+
generation_info = {"finish_reason": choice.get("finish_reason")}
|
708
733
|
if "logprobs" in choice:
|
709
734
|
generation_info["logprobs"] = choice["logprobs"]
|
710
735
|
|
@@ -744,6 +769,8 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
744
769
|
Science Model Deployment endpoint. See:
|
745
770
|
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
|
746
771
|
|
772
|
+
See https://docs.vllm.ai/en/latest/api/inference_params.html for the defaults of the parameters.
|
773
|
+
|
747
774
|
Example:
|
748
775
|
|
749
776
|
.. code-block:: python
|
@@ -761,7 +788,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
761
788
|
|
762
789
|
""" # noqa: E501
|
763
790
|
|
764
|
-
frequency_penalty: float =
|
791
|
+
frequency_penalty: Optional[float] = None
|
765
792
|
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
|
766
793
|
|
767
794
|
logit_bias: Optional[Dict[str, float]] = None
|
@@ -773,7 +800,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
773
800
|
n: int = 1
|
774
801
|
"""Number of output sequences to return for the given prompt."""
|
775
802
|
|
776
|
-
presence_penalty: float =
|
803
|
+
presence_penalty: Optional[float] = None
|
777
804
|
"""Penalizes repeated tokens. Between 0 and 1."""
|
778
805
|
|
779
806
|
temperature: float = 0.2
|
@@ -787,25 +814,25 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
787
814
|
(the one with the highest log probability per token).
|
788
815
|
"""
|
789
816
|
|
790
|
-
use_beam_search: Optional[bool] =
|
817
|
+
use_beam_search: Optional[bool] = None
|
791
818
|
"""Whether to use beam search instead of sampling."""
|
792
819
|
|
793
820
|
top_k: Optional[int] = -1
|
794
821
|
"""Number of most likely tokens to consider at each step."""
|
795
822
|
|
796
823
|
min_p: Optional[float] = 0.0
|
797
|
-
"""Float that represents the minimum probability for a token to be considered.
|
824
|
+
"""Float that represents the minimum probability for a token to be considered.
|
798
825
|
Must be in [0,1]. 0 to disable this."""
|
799
826
|
|
800
|
-
repetition_penalty: Optional[float] =
|
827
|
+
repetition_penalty: Optional[float] = None
|
801
828
|
"""Float that penalizes new tokens based on their frequency in the
|
802
829
|
generated text. Values > 1 encourage the model to use new tokens."""
|
803
830
|
|
804
|
-
length_penalty: Optional[float] =
|
831
|
+
length_penalty: Optional[float] = None
|
805
832
|
"""Float that penalizes sequences based on their length. Used only
|
806
833
|
when `use_beam_search` is True."""
|
807
834
|
|
808
|
-
early_stopping: Optional[bool] =
|
835
|
+
early_stopping: Optional[bool] = None
|
809
836
|
"""Controls the stopping condition for beam search. It accepts the
|
810
837
|
following values: `True`, where the generation stops as soon as there
|
811
838
|
are `best_of` complete candidates; `False`, where a heuristic is applied
|
@@ -817,8 +844,8 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
817
844
|
"""Whether to ignore the EOS token and continue generating tokens after
|
818
845
|
the EOS token is generated."""
|
819
846
|
|
820
|
-
min_tokens: Optional[int] =
|
821
|
-
"""Minimum number of tokens to generate per output sequence before
|
847
|
+
min_tokens: Optional[int] = None
|
848
|
+
"""Minimum number of tokens to generate per output sequence before
|
822
849
|
EOS or stop_token_ids can be generated"""
|
823
850
|
|
824
851
|
stop_token_ids: Optional[List[int]] = None
|
@@ -826,17 +853,16 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
|
|
826
853
|
The returned output will contain the stop tokens unless the stop tokens
|
827
854
|
are special tokens."""
|
828
855
|
|
829
|
-
skip_special_tokens: Optional[bool] =
|
856
|
+
skip_special_tokens: Optional[bool] = None
|
830
857
|
"""Whether to skip special tokens in the output. Defaults to True."""
|
831
858
|
|
832
|
-
spaces_between_special_tokens: Optional[bool] =
|
833
|
-
"""Whether to add spaces between special tokens in the output.
|
834
|
-
Defaults to True."""
|
859
|
+
spaces_between_special_tokens: Optional[bool] = None
|
860
|
+
"""Whether to add spaces between special tokens in the output."""
|
835
861
|
|
836
862
|
tool_choice: Optional[str] = None
|
837
863
|
"""Whether to use tool calling.
|
838
864
|
Defaults to None, tool calling is disabled.
|
839
|
-
Tool calling requires model support and the vLLM to be configured
|
865
|
+
Tool calling requires model support and the vLLM to be configured
|
840
866
|
with `--tool-call-parser`.
|
841
867
|
Set this to `auto` for the model to make tool calls automatically.
|
842
868
|
Set this to `required` to force the model to always call one or more tools.
|
@@ -956,9 +982,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
|
|
956
982
|
"""Total probability mass of tokens to consider at each step."""
|
957
983
|
|
958
984
|
top_logprobs: Optional[int] = None
|
959
|
-
"""An integer between 0 and 5 specifying the number of most
|
960
|
-
likely tokens to return at each token position, each with an
|
961
|
-
associated log probability. logprobs must be set to true if
|
985
|
+
"""An integer between 0 and 5 specifying the number of most
|
986
|
+
likely tokens to return at each token position, each with an
|
987
|
+
associated log probability. logprobs must be set to true if
|
962
988
|
this parameter is used."""
|
963
989
|
|
964
990
|
@property
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c)
|
3
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
|
@@ -24,6 +23,7 @@ from typing import (
|
|
24
23
|
|
25
24
|
import aiohttp
|
26
25
|
import requests
|
26
|
+
from langchain_community.utilities.requests import Requests
|
27
27
|
from langchain_core.callbacks import (
|
28
28
|
AsyncCallbackManagerForLLMRun,
|
29
29
|
CallbackManagerForLLMRun,
|
@@ -34,14 +34,13 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
34
34
|
from langchain_core.utils import get_from_dict_or_env
|
35
35
|
from pydantic import Field, model_validator
|
36
36
|
|
37
|
-
from langchain_community.utilities.requests import Requests
|
38
|
-
|
39
37
|
logger = logging.getLogger(__name__)
|
40
38
|
|
41
39
|
|
42
40
|
DEFAULT_TIME_OUT = 300
|
43
41
|
DEFAULT_CONTENT_TYPE_JSON = "application/json"
|
44
42
|
DEFAULT_MODEL_NAME = "odsc-llm"
|
43
|
+
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
|
45
44
|
|
46
45
|
|
47
46
|
class TokenExpiredError(Exception):
|
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
|
|
86
85
|
max_retries: int = 3
|
87
86
|
"""Maximum number of retries to make when generating."""
|
88
87
|
|
88
|
+
default_headers: Optional[Dict[str, Any]] = None
|
89
|
+
"""The headers to be added to the Model Deployment request."""
|
90
|
+
|
89
91
|
@model_validator(mode="before")
|
90
92
|
@classmethod
|
91
93
|
def validate_environment(cls, values: Dict) -> Dict:
|
@@ -101,7 +103,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
101
103
|
"Please install it with `pip install oracle_ads`."
|
102
104
|
) from ex
|
103
105
|
|
104
|
-
if not values.get("auth"
|
106
|
+
if not values.get("auth"):
|
105
107
|
values["auth"] = ads.common.auth.default_signer()
|
106
108
|
|
107
109
|
values["endpoint"] = get_from_dict_or_env(
|
@@ -125,12 +127,12 @@ class BaseOCIModelDeployment(Serializable):
|
|
125
127
|
Returns:
|
126
128
|
Dict: A dictionary containing the appropriate headers for the request.
|
127
129
|
"""
|
130
|
+
headers = self.default_headers or {}
|
128
131
|
if is_async:
|
129
132
|
signer = self.auth["signer"]
|
130
133
|
_req = requests.Request("POST", self.endpoint, json=body)
|
131
134
|
req = _req.prepare()
|
132
135
|
req = signer(req)
|
133
|
-
headers = {}
|
134
136
|
for key, value in req.headers.items():
|
135
137
|
headers[key] = value
|
136
138
|
|
@@ -140,7 +142,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
140
142
|
)
|
141
143
|
return headers
|
142
144
|
|
143
|
-
|
145
|
+
headers.update(
|
144
146
|
{
|
145
147
|
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
|
146
148
|
"enable-streaming": "true",
|
@@ -152,6 +154,8 @@ class BaseOCIModelDeployment(Serializable):
|
|
152
154
|
}
|
153
155
|
)
|
154
156
|
|
157
|
+
return headers
|
158
|
+
|
155
159
|
def completion_with_retry(
|
156
160
|
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
157
161
|
) -> Any:
|
@@ -357,7 +361,7 @@ class BaseOCIModelDeployment(Serializable):
|
|
357
361
|
self.auth["signer"].refresh_security_token()
|
358
362
|
return True
|
359
363
|
return False
|
360
|
-
|
364
|
+
|
361
365
|
@classmethod
|
362
366
|
def is_lc_serializable(cls) -> bool:
|
363
367
|
"""Return whether this model can be serialized by LangChain."""
|
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
388
392
|
model="odsc-llm",
|
389
393
|
streaming=True,
|
390
394
|
model_kwargs={"frequency_penalty": 1.0},
|
395
|
+
default_headers={
|
396
|
+
"route": "/v1/completions",
|
397
|
+
# other request headers ...
|
398
|
+
}
|
391
399
|
)
|
392
400
|
llm.invoke("tell me a joke.")
|
393
401
|
|
@@ -477,6 +485,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
477
485
|
**self._default_params,
|
478
486
|
}
|
479
487
|
|
488
|
+
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Build the request headers, defaulting the route to the completions endpoint.

    Args:
        is_async (bool, optional): Passed through to the parent
            implementation. Defaults to `False`.
        body (optional): Passed through to the parent implementation.

    Returns:
        Dict: The parent's headers with a "route" entry set to
            `DEFAULT_INFERENCE_ENDPOINT` unless the parent's headers
            already provide a "route" value.
    """
    # Seed with the default route first so that any "route" supplied by
    # the parent implementation takes precedence.
    merged = {"route": DEFAULT_INFERENCE_ENDPOINT}
    merged.update(super()._headers(is_async=is_async, body=body))
    return merged
|
506
|
+
|
480
507
|
def _generate(
|
481
508
|
self,
|
482
509
|
prompts: List[str],
|
@@ -712,9 +739,9 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
|
|
712
739
|
def _generate_info(self, choice: dict) -> Any:
|
713
740
|
"""Extracts generation info from the response."""
|
714
741
|
gen_info = {}
|
715
|
-
finish_reason = choice.get("finish_reason"
|
716
|
-
logprobs = choice.get("logprobs"
|
717
|
-
index = choice.get("index"
|
742
|
+
finish_reason = choice.get("finish_reason")
|
743
|
+
logprobs = choice.get("logprobs")
|
744
|
+
index = choice.get("index")
|
718
745
|
if finish_reason:
|
719
746
|
gen_info.update({"finish_reason": finish_reason})
|
720
747
|
if logprobs is not None:
|
ads/model/__init__.py
CHANGED
@@ -1,29 +1,26 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2021,
|
3
|
+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
|
-
from ads.model.generic_model import GenericModel, ModelState
|
8
6
|
from ads.model.datascience_model import DataScienceModel
|
9
|
-
from ads.model.
|
7
|
+
from ads.model.deployment.model_deployer import ModelDeployer
|
8
|
+
from ads.model.deployment.model_deployment import ModelDeployment
|
9
|
+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
|
10
10
|
from ads.model.framework.automl_model import AutoMLModel
|
11
|
+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
|
12
|
+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
|
11
13
|
from ads.model.framework.lightgbm_model import LightGBMModel
|
12
14
|
from ads.model.framework.pytorch_model import PyTorchModel
|
13
15
|
from ads.model.framework.sklearn_model import SklearnModel
|
16
|
+
from ads.model.framework.spark_model import SparkPipelineModel
|
14
17
|
from ads.model.framework.tensorflow_model import TensorFlowModel
|
15
18
|
from ads.model.framework.xgboost_model import XGBoostModel
|
16
|
-
from ads.model.
|
17
|
-
from ads.model.
|
18
|
-
|
19
|
-
from ads.model.deployment.model_deployer import ModelDeployer
|
20
|
-
from ads.model.deployment.model_deployment import ModelDeployment
|
21
|
-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
|
22
|
-
|
19
|
+
from ads.model.generic_model import GenericModel, ModelState
|
20
|
+
from ads.model.model_properties import ModelProperties
|
21
|
+
from ads.model.model_version_set import ModelVersionSet, experiment
|
23
22
|
from ads.model.serde.common import SERDE
|
24
23
|
from ads.model.serde.model_input import ModelInputSerializer
|
25
|
-
|
26
|
-
from ads.model.model_version_set import ModelVersionSet, experiment
|
27
24
|
from ads.model.service.oci_datascience_model_version_set import (
|
28
25
|
ModelVersionSetNotExists,
|
29
26
|
ModelVersionSetNotSaved,
|
@@ -42,6 +39,7 @@ __all__ = [
|
|
42
39
|
"XGBoostModel",
|
43
40
|
"SparkPipelineModel",
|
44
41
|
"HuggingFacePipelineModel",
|
42
|
+
"EmbeddingONNXModel",
|
45
43
|
"ModelDeployer",
|
46
44
|
"ModelDeployment",
|
47
45
|
"ModelDeploymentProperties",
|
ads/model/artifact.py
CHANGED
@@ -1,28 +1,28 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
|
-
# Copyright (c) 2022,
|
3
|
+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
5
|
|
7
6
|
import fnmatch
|
8
7
|
import importlib
|
9
8
|
import os
|
10
|
-
import sys
|
11
9
|
import shutil
|
10
|
+
import sys
|
12
11
|
import tempfile
|
13
12
|
import uuid
|
14
|
-
import
|
13
|
+
from datetime import datetime
|
15
14
|
from typing import Dict, Optional, Tuple
|
15
|
+
|
16
|
+
import fsspec
|
17
|
+
from jinja2 import Environment, PackageLoader
|
18
|
+
|
19
|
+
from ads import __version__
|
16
20
|
from ads.common import auth as authutil
|
17
21
|
from ads.common import logger, utils
|
18
22
|
from ads.common.object_storage_details import ObjectStorageDetails
|
19
23
|
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
|
20
24
|
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
|
21
25
|
from ads.model.runtime.runtime_info import RuntimeInfo
|
22
|
-
from jinja2 import Environment, PackageLoader
|
23
|
-
import warnings
|
24
|
-
from ads import __version__
|
25
|
-
from datetime import datetime
|
26
26
|
|
27
27
|
MODEL_ARTIFACT_VERSION = "3.0"
|
28
28
|
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
|
@@ -378,6 +378,45 @@ class ModelArtifact:
|
|
378
378
|
) as f:
|
379
379
|
f.write(scorefn_template.render(context))
|
380
380
|
|
381
|
+
def prepare_schema(self, schema_name: str):
    """Copies a schema file into the artifact directory.

    The schema is looked up under the ``templates/schemas`` directory that
    sits one level above the directory containing this module, and is
    copied to ``self.artifact_dir`` under the same file name, overwriting
    any existing file.

    Parameters
    ----------
    schema_name: str
        The schema file name.

    Returns
    -------
    None

    Raises
    ------
    FileExistsError
        If `schema_name` cannot be found in the schemas directory.
    """
    parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    uri_src = os.path.join(parent_dir, "templates", "schemas", f"{schema_name}")

    if not os.path.exists(uri_src):
        raise FileExistsError(
            f"{schema_name} does not exists. "
            "Ensure the schema name is valid or specify a different one."
        )

    utils.copy_file(
        uri_src=uri_src,
        uri_dst=os.path.join(self.artifact_dir, os.path.basename(uri_src)),
        force_overwrite=True,
        auth=self.auth,
    )
|
419
|
+
|
381
420
|
def reload(self):
|
382
421
|
"""Syncs the `score.py` to reload the model and predict function.
|
383
422
|
|