oracle-ads 2.12.8__py3-none-any.whl → 2.12.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. ads/aqua/__init__.py +4 -3
  2. ads/aqua/app.py +40 -18
  3. ads/aqua/client/__init__.py +3 -0
  4. ads/aqua/client/client.py +799 -0
  5. ads/aqua/common/enums.py +3 -0
  6. ads/aqua/common/utils.py +62 -2
  7. ads/aqua/data.py +2 -19
  8. ads/aqua/evaluation/entities.py +6 -0
  9. ads/aqua/evaluation/evaluation.py +45 -15
  10. ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
  11. ads/aqua/extension/base_handler.py +12 -9
  12. ads/aqua/extension/deployment_handler.py +8 -4
  13. ads/aqua/extension/finetune_handler.py +8 -14
  14. ads/aqua/extension/model_handler.py +30 -6
  15. ads/aqua/extension/ui_handler.py +13 -1
  16. ads/aqua/finetuning/constants.py +5 -2
  17. ads/aqua/finetuning/entities.py +73 -17
  18. ads/aqua/finetuning/finetuning.py +110 -82
  19. ads/aqua/model/entities.py +5 -1
  20. ads/aqua/model/model.py +230 -104
  21. ads/aqua/modeldeployment/deployment.py +35 -11
  22. ads/aqua/modeldeployment/entities.py +7 -4
  23. ads/aqua/ui.py +24 -2
  24. ads/cli.py +16 -8
  25. ads/common/auth.py +9 -9
  26. ads/llm/autogen/__init__.py +2 -0
  27. ads/llm/autogen/constants.py +15 -0
  28. ads/llm/autogen/reports/__init__.py +2 -0
  29. ads/llm/autogen/reports/base.py +67 -0
  30. ads/llm/autogen/reports/data.py +103 -0
  31. ads/llm/autogen/reports/session.py +526 -0
  32. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  33. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  34. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  35. ads/llm/autogen/reports/utils.py +56 -0
  36. ads/llm/autogen/v02/__init__.py +4 -0
  37. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  38. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  39. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  40. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  41. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  42. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  43. ads/llm/autogen/v02/loggers/utils.py +86 -0
  44. ads/llm/autogen/v02/runtime_logging.py +163 -0
  45. ads/llm/guardrails/base.py +6 -5
  46. ads/llm/langchain/plugins/chat_models/oci_data_science.py +46 -20
  47. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +38 -11
  48. ads/model/__init__.py +11 -13
  49. ads/model/artifact.py +47 -8
  50. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  51. ads/model/framework/embedding_onnx_model.py +438 -0
  52. ads/model/generic_model.py +26 -24
  53. ads/model/model_metadata.py +8 -7
  54. ads/opctl/config/merger.py +13 -14
  55. ads/opctl/operator/common/operator_config.py +4 -4
  56. ads/opctl/operator/lowcode/common/transformations.py +50 -8
  57. ads/opctl/operator/lowcode/common/utils.py +22 -6
  58. ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
  59. ads/opctl/operator/lowcode/forecast/const.py +3 -0
  60. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  61. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  62. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  63. ads/opctl/operator/lowcode/forecast/model/base_model.py +58 -17
  64. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
  65. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  66. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  67. ads/opctl/operator/lowcode/forecast/model_evaluator.py +3 -2
  68. ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
  69. ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
  70. ads/opctl/operator/lowcode/forecast/utils.py +8 -6
  71. ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
  72. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
  73. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
  74. ads/telemetry/base.py +18 -11
  75. ads/telemetry/client.py +33 -13
  76. ads/templates/schemas/openapi.json +1740 -0
  77. ads/templates/score_embedding_onnx.jinja2 +202 -0
  78. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +11 -10
  79. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +82 -56
  80. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
  81. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
  82. {oracle_ads-2.12.8.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,163 @@
1
+ # Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
2
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3
+ import logging
4
+ import traceback
5
+ from sqlite3 import Connection
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import autogen.runtime_logging
9
+ from autogen.logger.base_logger import BaseLogger
10
+ from autogen.logger.logger_factory import LoggerFactory
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class LoggerManager(BaseLogger):
    """Manages multiple AutoGen loggers.

    Every logging call received by the manager is forwarded to each
    registered logger, in registration order. An exception raised by one
    logger is logged and does not prevent the remaining loggers from being
    called, so logging failures never interrupt the program.
    """

    def __init__(self) -> None:
        self.loggers: List[BaseLogger] = []
        super().__init__()

    def add_logger(self, logger: BaseLogger) -> None:
        """Adds a new AutoGen logger."""
        self.loggers.append(logger)

    def _call_loggers(self, method: str, *args, **kwargs) -> List[Any]:
        """Calls ``method`` on each AutoGen logger in ``self.loggers``.

        Parameters
        ----------
        method : str
            Name of the logger method to call.
        *args, **kwargs
            Arguments forwarded to the method.

        Returns
        -------
        list
            Return values of the loggers whose call succeeded, in
            registration order. Loggers that raised are skipped.
        """
        results = []
        for autogen_logger in self.loggers:
            try:
                results.append(getattr(autogen_logger, method)(*args, **kwargs))
            except Exception as e:
                # Catch the logging exception so that the program will not be interrupted.
                logger.error(
                    "Failed to %s with %s: %s",
                    method,
                    autogen_logger.__class__.__name__,
                    str(e),
                )
                logger.debug(traceback.format_exc())
        return results

    def start(self) -> Optional[str]:
        """Starts all loggers.

        Returns
        -------
        str or None
            The session ID of the first logger that started successfully and
            returned a non-empty ID, or None if no such logger exists.
        """
        return next((sid for sid in self._call_loggers("start") if sid), None)

    def stop(self) -> None:
        """Stops all loggers and removes them from the manager."""
        self._call_loggers("stop")
        # Remove the loggers once they are stopped.
        self.loggers = []

    def get_connection(self) -> Optional[Connection]:
        """Returns the first non-None connection exposed by a managed logger.

        Returns
        -------
        sqlite3.Connection or None
            The connection from the first logger that provides one, or None.
        """
        return next(
            (
                conn
                for conn in self._call_loggers("get_connection")
                if conn is not None
            ),
            None,
        )

    def log_chat_completion(self, *args, **kwargs) -> None:
        """Forwards ``log_chat_completion`` to every managed logger."""
        self._call_loggers("log_chat_completion", *args, **kwargs)

    def log_new_agent(self, *args, **kwargs) -> None:
        """Forwards ``log_new_agent`` to every managed logger."""
        self._call_loggers("log_new_agent", *args, **kwargs)

    def log_event(self, *args, **kwargs) -> None:
        """Forwards ``log_event`` to every managed logger."""
        self._call_loggers("log_event", *args, **kwargs)

    def log_new_wrapper(self, *args, **kwargs) -> None:
        """Forwards ``log_new_wrapper`` to every managed logger."""
        self._call_loggers("log_new_wrapper", *args, **kwargs)

    def log_new_client(self, *args, **kwargs) -> None:
        """Forwards ``log_new_client`` to every managed logger."""
        self._call_loggers("log_new_client", *args, **kwargs)

    def log_function_use(self, *args, **kwargs) -> None:
        """Forwards ``log_function_use`` to every managed logger."""
        self._call_loggers("log_function_use", *args, **kwargs)

    def __repr__(self) -> str:
        """Returns the class and repr of each managed logger, separated by blank lines."""
        return "\n\n".join(
            f"{str(managed.__class__)}:\n{managed.__repr__()}"
            for managed in self.loggers
        )
78
+
79
+
80
def start(
    autogen_logger: Optional[BaseLogger] = None,
    logger_type: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Starts logging with AutoGen logger.
    Specify your custom autogen_logger, or the logger_type and config to use a built-in logger.

    Parameters
    ----------
    autogen_logger : BaseLogger, optional
        An AutoGen logger, which should be a subclass of autogen.logger.base_logger.BaseLogger.
    logger_type : str, optional
        Logger type, which can be a built-in AutoGen logger type ("file", or "sqlite"), by default None.
    config : dict, optional
        Configurations for the built-in AutoGen logger, by default None

    Returns
    -------
    str or None
        A unique session ID returned from starting the logger,
        or None if the logger failed to start (the error is logged).

    Raises
    ------
    ValueError
        If both ``autogen_logger`` and ``logger_type`` are specified.
    """
    if autogen_logger and logger_type:
        raise ValueError(
            f"Please specify only autogen_logger({autogen_logger}) "
            f"or logger_type({logger_type})."
        )

    # Check if a logger is already configured
    existing_logger = autogen.runtime_logging.autogen_logger
    if not existing_logger:
        # No logger is configured
        logger_manager = LoggerManager()
    elif isinstance(existing_logger, LoggerManager):
        # Logger is already configured with ADS
        logger_manager = existing_logger
    else:
        # Logger is configured but it is not via ADS; keep it alongside the new one.
        logger.warning("AutoGen is already configured with %s", str(existing_logger))
        logger_manager = LoggerManager()
        logger_manager.add_logger(existing_logger)

    if not autogen_logger:
        # NOTE(review): when neither argument is given, logger_type=None is
        # forwarded to LoggerFactory.get_logger — confirm the factory accepts it.
        autogen_logger = LoggerFactory.get_logger(
            logger_type=logger_type, config=config
        )

    try:
        session_id = autogen_logger.start()
    except Exception as e:
        # Do not interrupt the program; the previous logging setup is left untouched.
        logger.error("Failed to start logging: %s", e)
        logger.debug(traceback.format_exc())
        return None

    # Register the logger only after it started, so a failed logger never
    # ends up receiving log calls from an already-installed manager.
    logger_manager.add_logger(autogen_logger)
    autogen.runtime_logging.is_logging = True
    autogen.runtime_logging.autogen_logger = logger_manager
    return session_id
138
+
139
+
140
def stop(*loggers) -> Optional[BaseLogger]:
    """Stops AutoGen logger.
    If loggers are managed by LoggerManager,
    you may specify one or more loggers to be stopped.
    If no logger is specified, all loggers will be stopped.
    Stopped loggers will be removed from the LoggerManager.

    Parameters
    ----------
    *loggers : BaseLogger
        Specific loggers to stop. They are only honored when the active
        AutoGen logger is a LoggerManager; otherwise all logging is stopped.

    Returns
    -------
    BaseLogger or None
        The AutoGen logger that was active when this function was called,
        or None if no logger was configured.
    """
    autogen_logger = autogen.runtime_logging.autogen_logger
    if isinstance(autogen_logger, LoggerManager) and loggers:
        # Use a distinct name so the module-level `logger` is not shadowed.
        for target in loggers:
            target.stop()
            if target in autogen_logger.loggers:
                autogen_logger.loggers.remove(target)
    else:
        autogen.runtime_logging.stop()
    return autogen_logger
156
+
157
+
158
def get_loggers() -> List[BaseLogger]:
    """Gets a list of existing AutoGen loggers.

    Returns
    -------
    list of BaseLogger
        The LoggerManager's own (live) list when logging was configured via
        ADS, a one-element list with the logger configured directly in
        AutoGen, or an empty list when no logger is configured.
    """
    autogen_logger = autogen.runtime_logging.autogen_logger
    if autogen_logger is None:
        # Previously returned [None], which callers would iterate as a logger.
        return []
    if isinstance(autogen_logger, LoggerManager):
        return autogen_logger.loggers
    return [autogen_logger]
@@ -1,17 +1,16 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
8
7
  import datetime
9
8
  import functools
10
- import operator
11
9
  import importlib.util
10
+ import operator
12
11
  import sys
12
+ from typing import Any, List, Optional, Union
13
13
 
14
- from typing import Any, List, Dict, Tuple
15
14
  from langchain.schema.prompt import PromptValue
16
15
  from langchain.tools.base import BaseTool, ToolException
17
16
  from pydantic import BaseModel, model_validator
@@ -207,7 +206,9 @@ class Guardrail(BaseTool):
207
206
  return input.to_string()
208
207
  return str(input)
209
208
 
210
- def _to_args_and_kwargs(self, tool_input: Any) -> Tuple[Tuple, Dict]:
209
+ def _to_args_and_kwargs(
210
+ self, tool_input: Union[str, dict], tool_call_id: Optional[str]
211
+ ) -> tuple[tuple, dict]:
211
212
  if isinstance(tool_input, dict):
212
213
  return (), tool_input
213
214
  else:
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
  """Chat model for OCI data science model deployment endpoint."""
7
6
 
@@ -50,6 +49,7 @@ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint i
50
49
  )
51
50
 
52
51
  logger = logging.getLogger(__name__)
52
+ DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
53
53
 
54
54
 
55
55
  def _is_pydantic_class(obj: Any) -> bool:
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
93
93
  Key init args — client params:
94
94
  auth: dict
95
95
  ADS auth dictionary for OCI authentication.
96
+ default_headers: Optional[Dict]
97
+ The headers to be added to the Model Deployment request.
96
98
 
97
99
  Instantiate:
98
100
  .. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
109
111
  "temperature": 0.2,
110
112
  # other model parameters ...
111
113
  },
114
+ default_headers={
115
+ "route": "/v1/chat/completions",
116
+ # other request headers ...
117
+ },
112
118
  )
113
119
 
114
120
  Invocation:
@@ -291,6 +297,25 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
291
297
  "stream": self.streaming,
292
298
  }
293
299
 
300
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Build the headers for a chat completions request.

    A ``route`` header set to ``DEFAULT_INFERENCE_ENDPOINT_CHAT`` comes first.
    Headers returned by the parent implementation are merged on top of it,
    so a ``route`` supplied via ``default_headers`` overrides the default.

    Args:
        is_async (bool, optional): Whether the request is asynchronous.
            Defaults to `False`.
        body (optional): Request body, used by the parent implementation
            when the request is asynchronous.

    Returns:
        Dict: The headers to send with the request.
    """
    parent_headers = super()._headers(is_async=is_async, body=body)
    merged = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
    merged.update(parent_headers)
    return merged
318
+
294
319
  def _generate(
295
320
  self,
296
321
  messages: List[BaseMessage],
@@ -704,7 +729,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
704
729
 
705
730
  for choice in choices:
706
731
  message = _convert_dict_to_message(choice["message"])
707
- generation_info = dict(finish_reason=choice.get("finish_reason"))
732
+ generation_info = {"finish_reason": choice.get("finish_reason")}
708
733
  if "logprobs" in choice:
709
734
  generation_info["logprobs"] = choice["logprobs"]
710
735
 
@@ -744,6 +769,8 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
744
769
  Science Model Deployment endpoint. See:
745
770
  https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
746
771
 
772
+ See https://docs.vllm.ai/en/latest/api/inference_params.html for the defaults of the parameters.
773
+
747
774
  Example:
748
775
 
749
776
  .. code-block:: python
@@ -761,7 +788,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
761
788
 
762
789
  """ # noqa: E501
763
790
 
764
- frequency_penalty: float = 0.0
791
+ frequency_penalty: Optional[float] = None
765
792
  """Penalizes repeated tokens according to frequency. Between 0 and 1."""
766
793
 
767
794
  logit_bias: Optional[Dict[str, float]] = None
@@ -773,7 +800,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
773
800
  n: int = 1
774
801
  """Number of output sequences to return for the given prompt."""
775
802
 
776
- presence_penalty: float = 0.0
803
+ presence_penalty: Optional[float] = None
777
804
  """Penalizes repeated tokens. Between 0 and 1."""
778
805
 
779
806
  temperature: float = 0.2
@@ -787,25 +814,25 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
787
814
  (the one with the highest log probability per token).
788
815
  """
789
816
 
790
- use_beam_search: Optional[bool] = False
817
+ use_beam_search: Optional[bool] = None
791
818
  """Whether to use beam search instead of sampling."""
792
819
 
793
820
  top_k: Optional[int] = -1
794
821
  """Number of most likely tokens to consider at each step."""
795
822
 
796
823
  min_p: Optional[float] = 0.0
797
- """Float that represents the minimum probability for a token to be considered.
824
+ """Float that represents the minimum probability for a token to be considered.
798
825
  Must be in [0,1]. 0 to disable this."""
799
826
 
800
- repetition_penalty: Optional[float] = 1.0
827
+ repetition_penalty: Optional[float] = None
801
828
  """Float that penalizes new tokens based on their frequency in the
802
829
  generated text. Values > 1 encourage the model to use new tokens."""
803
830
 
804
- length_penalty: Optional[float] = 1.0
831
+ length_penalty: Optional[float] = None
805
832
  """Float that penalizes sequences based on their length. Used only
806
833
  when `use_beam_search` is True."""
807
834
 
808
- early_stopping: Optional[bool] = False
835
+ early_stopping: Optional[bool] = None
809
836
  """Controls the stopping condition for beam search. It accepts the
810
837
  following values: `True`, where the generation stops as soon as there
811
838
  are `best_of` complete candidates; `False`, where a heuristic is applied
@@ -817,8 +844,8 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
817
844
  """Whether to ignore the EOS token and continue generating tokens after
818
845
  the EOS token is generated."""
819
846
 
820
- min_tokens: Optional[int] = 0
821
- """Minimum number of tokens to generate per output sequence before
847
+ min_tokens: Optional[int] = None
848
+ """Minimum number of tokens to generate per output sequence before
822
849
  EOS or stop_token_ids can be generated"""
823
850
 
824
851
  stop_token_ids: Optional[List[int]] = None
@@ -826,17 +853,16 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
826
853
  The returned output will contain the stop tokens unless the stop tokens
827
854
  are special tokens."""
828
855
 
829
- skip_special_tokens: Optional[bool] = True
856
+ skip_special_tokens: Optional[bool] = None
830
857
  """Whether to skip special tokens in the output. Defaults to True."""
831
858
 
832
- spaces_between_special_tokens: Optional[bool] = True
833
- """Whether to add spaces between special tokens in the output.
834
- Defaults to True."""
859
+ spaces_between_special_tokens: Optional[bool] = None
860
+ """Whether to add spaces between special tokens in the output."""
835
861
 
836
862
  tool_choice: Optional[str] = None
837
863
  """Whether to use tool calling.
838
864
  Defaults to None, tool calling is disabled.
839
- Tool calling requires model support and the vLLM to be configured
865
+ Tool calling requires model support and the vLLM to be configured
840
866
  with `--tool-call-parser`.
841
867
  Set this to `auto` for the model to make tool calls automatically.
842
868
  Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +982,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
956
982
  """Total probability mass of tokens to consider at each step."""
957
983
 
958
984
  top_logprobs: Optional[int] = None
959
- """An integer between 0 and 5 specifying the number of most
960
- likely tokens to return at each token position, each with an
961
- associated log probability. logprobs must be set to true if
985
+ """An integer between 0 and 5 specifying the number of most
986
+ likely tokens to return at each token position, each with an
987
+ associated log probability. logprobs must be set to true if
962
988
  this parameter is used."""
963
989
 
964
990
  @property
@@ -1,7 +1,6 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
 
@@ -24,6 +23,7 @@ from typing import (
24
23
 
25
24
  import aiohttp
26
25
  import requests
26
+ from langchain_community.utilities.requests import Requests
27
27
  from langchain_core.callbacks import (
28
28
  AsyncCallbackManagerForLLMRun,
29
29
  CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
34
34
  from langchain_core.utils import get_from_dict_or_env
35
35
  from pydantic import Field, model_validator
36
36
 
37
- from langchain_community.utilities.requests import Requests
38
-
39
37
  logger = logging.getLogger(__name__)
40
38
 
41
39
 
42
40
  DEFAULT_TIME_OUT = 300
43
41
  DEFAULT_CONTENT_TYPE_JSON = "application/json"
44
42
  DEFAULT_MODEL_NAME = "odsc-llm"
43
+ DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
45
44
 
46
45
 
47
46
  class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
86
85
  max_retries: int = 3
87
86
  """Maximum number of retries to make when generating."""
88
87
 
88
+ default_headers: Optional[Dict[str, Any]] = None
89
+ """The headers to be added to the Model Deployment request."""
90
+
89
91
  @model_validator(mode="before")
90
92
  @classmethod
91
93
  def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ class BaseOCIModelDeployment(Serializable):
101
103
  "Please install it with `pip install oracle_ads`."
102
104
  ) from ex
103
105
 
104
- if not values.get("auth", None):
106
+ if not values.get("auth"):
105
107
  values["auth"] = ads.common.auth.default_signer()
106
108
 
107
109
  values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ class BaseOCIModelDeployment(Serializable):
125
127
  Returns:
126
128
  Dict: A dictionary containing the appropriate headers for the request.
127
129
  """
130
+ headers = self.default_headers or {}
128
131
  if is_async:
129
132
  signer = self.auth["signer"]
130
133
  _req = requests.Request("POST", self.endpoint, json=body)
131
134
  req = _req.prepare()
132
135
  req = signer(req)
133
- headers = {}
134
136
  for key, value in req.headers.items():
135
137
  headers[key] = value
136
138
 
@@ -140,7 +142,7 @@ class BaseOCIModelDeployment(Serializable):
140
142
  )
141
143
  return headers
142
144
 
143
- return (
145
+ headers.update(
144
146
  {
145
147
  "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
146
148
  "enable-streaming": "true",
@@ -152,6 +154,8 @@ class BaseOCIModelDeployment(Serializable):
152
154
  }
153
155
  )
154
156
 
157
+ return headers
158
+
155
159
  def completion_with_retry(
156
160
  self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
157
161
  ) -> Any:
@@ -357,7 +361,7 @@ class BaseOCIModelDeployment(Serializable):
357
361
  self.auth["signer"].refresh_security_token()
358
362
  return True
359
363
  return False
360
-
364
+
361
365
  @classmethod
362
366
  def is_lc_serializable(cls) -> bool:
363
367
  """Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
388
392
  model="odsc-llm",
389
393
  streaming=True,
390
394
  model_kwargs={"frequency_penalty": 1.0},
395
+ default_headers={
396
+ "route": "/v1/completions",
397
+ # other request headers ...
398
+ }
391
399
  )
392
400
  llm.invoke("tell me a joke.")
393
401
 
@@ -477,6 +485,25 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
477
485
  **self._default_params,
478
486
  }
479
487
 
488
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Assemble the headers sent with a text completions request.

    The ``route`` header defaults to ``DEFAULT_INFERENCE_ENDPOINT``. Any header
    produced by the base class (including a ``route`` supplied through
    ``default_headers``) takes precedence over that default.

    Args:
        is_async (bool, optional): Whether the request is asynchronous.
            Defaults to `False`.
        body (optional): Request body, passed to the base class, which uses
            it for asynchronous requests.

    Returns:
        Dict: The headers to send with the request.
    """
    headers = {"route": DEFAULT_INFERENCE_ENDPOINT}
    headers.update(super()._headers(is_async=is_async, body=body))
    return headers
506
+
480
507
  def _generate(
481
508
  self,
482
509
  prompts: List[str],
@@ -712,9 +739,9 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
712
739
  def _generate_info(self, choice: dict) -> Any:
713
740
  """Extracts generation info from the response."""
714
741
  gen_info = {}
715
- finish_reason = choice.get("finish_reason", None)
716
- logprobs = choice.get("logprobs", None)
717
- index = choice.get("index", None)
742
+ finish_reason = choice.get("finish_reason")
743
+ logprobs = choice.get("logprobs")
744
+ index = choice.get("index")
718
745
  if finish_reason:
719
746
  gen_info.update({"finish_reason": finish_reason})
720
747
  if logprobs is not None:
ads/model/__init__.py CHANGED
@@ -1,29 +1,26 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2021, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from ads.model.generic_model import GenericModel, ModelState
8
6
  from ads.model.datascience_model import DataScienceModel
9
- from ads.model.model_properties import ModelProperties
7
+ from ads.model.deployment.model_deployer import ModelDeployer
8
+ from ads.model.deployment.model_deployment import ModelDeployment
9
+ from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
10
10
  from ads.model.framework.automl_model import AutoMLModel
11
+ from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
12
+ from ads.model.framework.huggingface_model import HuggingFacePipelineModel
11
13
  from ads.model.framework.lightgbm_model import LightGBMModel
12
14
  from ads.model.framework.pytorch_model import PyTorchModel
13
15
  from ads.model.framework.sklearn_model import SklearnModel
16
+ from ads.model.framework.spark_model import SparkPipelineModel
14
17
  from ads.model.framework.tensorflow_model import TensorFlowModel
15
18
  from ads.model.framework.xgboost_model import XGBoostModel
16
- from ads.model.framework.spark_model import SparkPipelineModel
17
- from ads.model.framework.huggingface_model import HuggingFacePipelineModel
18
-
19
- from ads.model.deployment.model_deployer import ModelDeployer
20
- from ads.model.deployment.model_deployment import ModelDeployment
21
- from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
22
-
19
+ from ads.model.generic_model import GenericModel, ModelState
20
+ from ads.model.model_properties import ModelProperties
21
+ from ads.model.model_version_set import ModelVersionSet, experiment
23
22
  from ads.model.serde.common import SERDE
24
23
  from ads.model.serde.model_input import ModelInputSerializer
25
-
26
- from ads.model.model_version_set import ModelVersionSet, experiment
27
24
  from ads.model.service.oci_datascience_model_version_set import (
28
25
  ModelVersionSetNotExists,
29
26
  ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@ __all__ = [
42
39
  "XGBoostModel",
43
40
  "SparkPipelineModel",
44
41
  "HuggingFacePipelineModel",
42
+ "EmbeddingONNXModel",
45
43
  "ModelDeployer",
46
44
  "ModelDeployment",
47
45
  "ModelDeploymentProperties",
ads/model/artifact.py CHANGED
@@ -1,28 +1,28 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  import fnmatch
8
7
  import importlib
9
8
  import os
10
- import sys
11
9
  import shutil
10
+ import sys
12
11
  import tempfile
13
12
  import uuid
14
- import fsspec
13
+ from datetime import datetime
15
14
  from typing import Dict, Optional, Tuple
15
+
16
+ import fsspec
17
+ from jinja2 import Environment, PackageLoader
18
+
19
+ from ads import __version__
16
20
  from ads.common import auth as authutil
17
21
  from ads.common import logger, utils
18
22
  from ads.common.object_storage_details import ObjectStorageDetails
19
23
  from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
20
24
  from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
21
25
  from ads.model.runtime.runtime_info import RuntimeInfo
22
- from jinja2 import Environment, PackageLoader
23
- import warnings
24
- from ads import __version__
25
- from datetime import datetime
26
26
 
27
27
  MODEL_ARTIFACT_VERSION = "3.0"
28
28
  REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ class ModelArtifact:
378
378
  ) as f:
379
379
  f.write(scorefn_template.render(context))
380
380
 
381
def prepare_schema(self, schema_name: str):
    """Copies a bundled schema file to the artifact directory.

    The schema is looked up in the ``templates/schemas`` folder of the ``ads``
    package and copied into ``self.artifact_dir`` under the same file name,
    overwriting any existing copy.

    Parameters
    ----------
    schema_name: str
        The schema file name, e.g. ``"openapi.json"``.

    Returns
    -------
    None

    Raises
    ------
    FileExistsError
        If `schema_name` does not exist in the bundled schemas folder.
        NOTE: ``FileNotFoundError`` would be the natural type; ``FileExistsError``
        is kept so existing callers that catch it keep working.
    """
    # `__file__` is ads/model/artifact.py, so ".." resolves to the ads package root.
    ads_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    uri_src = os.path.join(ads_root, "templates", "schemas", schema_name)

    if not os.path.exists(uri_src):
        raise FileExistsError(
            f"{schema_name} does not exist. "
            "Ensure the schema name is valid or specify a different one."
        )

    uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))

    utils.copy_file(
        uri_src=uri_src,
        uri_dst=uri_dst,
        force_overwrite=True,
        auth=self.auth,
    )
419
+
381
420
  def reload(self):
382
421
  """Syncs the `score.py` to reload the model and predict function.
383
422