oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. ads/aqua/__init__.py +4 -4
  2. ads/aqua/common/enums.py +3 -0
  3. ads/aqua/common/utils.py +62 -2
  4. ads/aqua/data.py +2 -19
  5. ads/aqua/extension/finetune_handler.py +8 -14
  6. ads/aqua/extension/model_handler.py +19 -2
  7. ads/aqua/finetuning/constants.py +5 -2
  8. ads/aqua/finetuning/entities.py +64 -17
  9. ads/aqua/finetuning/finetuning.py +38 -54
  10. ads/aqua/model/entities.py +2 -1
  11. ads/aqua/model/model.py +61 -23
  12. ads/common/auth.py +9 -9
  13. ads/llm/autogen/__init__.py +2 -0
  14. ads/llm/autogen/constants.py +15 -0
  15. ads/llm/autogen/reports/__init__.py +2 -0
  16. ads/llm/autogen/reports/base.py +67 -0
  17. ads/llm/autogen/reports/data.py +103 -0
  18. ads/llm/autogen/reports/session.py +526 -0
  19. ads/llm/autogen/reports/templates/chat_box.html +13 -0
  20. ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
  21. ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
  22. ads/llm/autogen/reports/utils.py +56 -0
  23. ads/llm/autogen/v02/__init__.py +4 -0
  24. ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
  25. ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
  26. ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
  27. ads/llm/autogen/v02/loggers/__init__.py +6 -0
  28. ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
  29. ads/llm/autogen/v02/loggers/session_logger.py +580 -0
  30. ads/llm/autogen/v02/loggers/utils.py +86 -0
  31. ads/llm/autogen/v02/runtime_logging.py +163 -0
  32. ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
  33. ads/model/__init__.py +11 -13
  34. ads/model/artifact.py +47 -8
  35. ads/model/extractor/embedding_onnx_extractor.py +80 -0
  36. ads/model/framework/embedding_onnx_model.py +438 -0
  37. ads/model/generic_model.py +26 -24
  38. ads/model/model_metadata.py +8 -7
  39. ads/opctl/config/merger.py +13 -14
  40. ads/opctl/operator/common/operator_config.py +4 -4
  41. ads/opctl/operator/lowcode/common/transformations.py +12 -5
  42. ads/opctl/operator/lowcode/common/utils.py +11 -5
  43. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  44. ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
  45. ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
  46. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  47. ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
  48. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
  49. ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
  50. ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
  51. ads/opctl/operator/lowcode/forecast/utils.py +4 -3
  52. ads/telemetry/base.py +18 -11
  53. ads/telemetry/client.py +33 -13
  54. ads/templates/schemas/openapi.json +1740 -0
  55. ads/templates/score_embedding_onnx.jinja2 +202 -0
  56. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
  57. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
  58. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
  59. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
  60. {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,5 @@
1
- # coding: utf-8
2
- # Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
3
- # This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
1
+ # Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
2
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
3
 
5
4
  """This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
6
5
  https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/
@@ -72,14 +71,14 @@ import copy
72
71
  import importlib
73
72
  import json
74
73
  import logging
75
- from typing import Any, Dict, List, Union
74
+ from dataclasses import asdict, dataclass
76
75
  from types import SimpleNamespace
76
+ from typing import Any, Dict, List, Union
77
77
 
78
78
  from autogen import ModelClient
79
79
  from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
80
80
  from langchain_core.messages import AIMessage
81
81
 
82
-
83
82
  logger = logging.getLogger(__name__)
84
83
 
85
84
  # custom_clients is a dictionary mapping the name of the class to the actual class
@@ -177,6 +176,13 @@ class Message(AIMessage):
177
176
  return self.tool_calls
178
177
 
179
178
 
179
@dataclass
class Usage:
    """Token usage counters attached to a chat completion response.

    Every counter starts at zero; ``asdict()`` of an instance yields the
    mapping returned as the usage summary.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
184
+
185
+
180
186
  class LangChainModelClient(ModelClient):
181
187
  """Represents a model client wrapping a LangChain chat model."""
182
188
 
@@ -202,8 +208,8 @@ class LangChainModelClient(ModelClient):
202
208
  # Import the LangChain class
203
209
  if "langchain_cls" not in config:
204
210
  raise ValueError("Missing langchain_cls in LangChain Model Client config.")
205
- module_cls = config.pop("langchain_cls")
206
- module_name, cls_name = str(module_cls).rsplit(".", 1)
211
+ self.langchain_cls = config.pop("langchain_cls")
212
+ module_name, cls_name = str(self.langchain_cls).rsplit(".", 1)
207
213
  langchain_module = importlib.import_module(module_name)
208
214
  langchain_cls = getattr(langchain_module, cls_name)
209
215
 
@@ -232,7 +238,14 @@ class LangChainModelClient(ModelClient):
232
238
  streaming = params.get("stream", False)
233
239
  # TODO: num_of_responses
234
240
  num_of_responses = params.get("n", 1)
235
- messages = params.pop("messages", [])
241
+
242
+ messages = copy.deepcopy(params.get("messages", []))
243
+
244
+ # OCI Gen AI does not allow empty message.
245
+ if str(self.langchain_cls).endswith("oci_generative_ai.ChatOCIGenAI"):
246
+ for message in messages:
247
+ if len(message.get("content", "")) == 0:
248
+ message["content"] = " "
236
249
 
237
250
  invoke_params = copy.deepcopy(self.invoke_params)
238
251
 
@@ -241,7 +254,6 @@ class LangChainModelClient(ModelClient):
241
254
  model = self.model.bind_tools(
242
255
  [_convert_to_langchain_tool(tool) for tool in tools]
243
256
  )
244
- # invoke_params["tools"] = tools
245
257
  invoke_params.update(self.function_call_params)
246
258
  else:
247
259
  model = self.model
@@ -249,6 +261,7 @@ class LangChainModelClient(ModelClient):
249
261
  response = SimpleNamespace()
250
262
  response.choices = []
251
263
  response.model = self.model_name
264
+ response.usage = Usage()
252
265
 
253
266
  if streaming and messages:
254
267
  # If streaming is enabled and has messages, then iterate over the chunks of the response.
@@ -279,4 +292,4 @@ class LangChainModelClient(ModelClient):
279
292
  @staticmethod
280
293
  def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
281
294
  """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
282
- return {}
295
+ return asdict(response.usage)
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
2
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -0,0 +1,83 @@
1
+ # Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
2
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3
+ import io
4
+ import json
5
+ import logging
6
+ import os
7
+ import threading
8
+
9
+ import fsspec
10
+
11
+ from ads.common.auth import default_signer
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class OCIFileHandler(logging.FileHandler):
17
+ """Log handler for saving log file to OCI object storage."""
18
+
19
+ def __init__(
20
+ self,
21
+ filename: str,
22
+ session_id: str,
23
+ mode: str = "a",
24
+ encoding: str | None = None,
25
+ delay: bool = False,
26
+ errors: str | None = None,
27
+ auth: dict | None = None,
28
+ ) -> None:
29
+ self.session_id = session_id
30
+ self.auth = auth
31
+
32
+ if filename.startswith("oci://"):
33
+ self.baseFilename = filename
34
+ else:
35
+ self.baseFilename = os.path.abspath(os.path.expanduser(filename))
36
+ os.makedirs(os.path.dirname(self.baseFilename), exist_ok=True)
37
+
38
+ # The following code are from the `FileHandler.__init__()`
39
+ self.mode = mode
40
+ self.encoding = encoding
41
+ if "b" not in mode:
42
+ self.encoding = io.text_encoding(encoding)
43
+ self.errors = errors
44
+ self.delay = delay
45
+
46
+ if delay:
47
+ # We don't open the stream, but we still need to call the
48
+ # Handler constructor to set level, formatter, lock etc.
49
+ logging.Handler.__init__(self)
50
+ self.stream = None
51
+ else:
52
+ logging.StreamHandler.__init__(self, self._open())
53
+
54
+ def _open(self):
55
+ """
56
+ Open the current base file with the (original) mode and encoding.
57
+ Return the resulting stream.
58
+ """
59
+ auth = self.auth or default_signer()
60
+ return fsspec.open(
61
+ self.baseFilename,
62
+ self.mode,
63
+ encoding=self.encoding,
64
+ errors=self.errors,
65
+ **auth,
66
+ ).open()
67
+
68
+ def format(self, record: logging.LogRecord):
69
+ """Formats the log record as JSON payload and add session_id."""
70
+ msg = record.getMessage()
71
+ try:
72
+ data = json.loads(msg)
73
+ except Exception as e:
74
+ data = {"message": msg}
75
+
76
+ if "session_id" not in data:
77
+ data["session_id"] = self.session_id
78
+ if "thread_id" not in data:
79
+ data["thread_id"] = threading.get_ident()
80
+
81
+ record.msg = json.dumps(data)
82
+ return super().format(record)
83
+
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+
5
+ from ads.llm.autogen.v02.loggers.metric_logger import MetricLogger
6
+ from ads.llm.autogen.v02.loggers.session_logger import SessionLogger
@@ -0,0 +1,320 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
+ import logging
5
+ from datetime import datetime
6
+ from typing import Any, Dict, List, Optional, Union
7
+ from uuid import UUID
8
+
9
+ import oci
10
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
11
+ from autogen.logger.base_logger import BaseLogger, LLMConfig
12
+ from autogen.logger.logger_utils import get_current_ts
13
+ from oci.monitoring import MonitoringClient
14
+ from pydantic import BaseModel, Field
15
+
16
+ import ads
17
+ import ads.config
18
+ from ads.llm.autogen.v02.loggers.utils import serialize_response
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class MetricName:
    """Names of the metrics posted by the metric logger."""

    # Session lifecycle markers.
    SESSION_START = "session_start"
    SESSION_STOP = "session_stop"
    # Chat completion count and cost.
    CHAT_COMPLETION = "chat_completion_count"
    COST = "chat_completion_cost"
    # Invocation of a registered function (tool).
    TOOL_CALL = "tool_call"
32
+
33
class MetricDimension:
    """Dimension keys that may be attached to posted metrics."""

    # Application / session identifiers.
    APP_NAME = "app_name"
    SESSION_ID = "session_id"
    # Per-call details.
    AGENT_NAME = "agent_name"
    MODEL = "model"
    TOOL_NAME = "tool_name"
41
+
42
+
43
class Metric(BaseModel):
    """Represents the metric to be logged.

    Fields
    ------
    name : str
        Metric name.
    value : float
        Metric value.
    timestamp : str
        Datapoint time as a string.
    dimensions : dict
        Metric dimensions; each instance gets its own empty dict by default.
    """

    name: str
    value: float
    timestamp: str
    dimensions: dict = Field(default_factory=dict)
50
+
51
+
52
class MetricLogger(BaseLogger):
    """AutoGen logger for agent metrics.

    Metrics are posted to OCI Monitoring only while the logger is started
    (between :meth:`start` and :meth:`stop`). Errors raised while building or
    posting a metric are logged and swallowed, so metric collection does not
    interrupt the agent workflow.
    """

    def __init__(
        self,
        namespace: str,
        app_name: Optional[str] = None,
        compartment_id: Optional[str] = None,
        session_id: Optional[str] = None,
        region: Optional[str] = None,
        resource_group: Optional[str] = None,
        log_agent_name: bool = False,
        log_tool_name: bool = False,
        log_model_name: bool = False,
    ):
        """Initialize the metric logger.

        Parameters
        ----------
        namespace : str
            Namespace for posting the metric
        app_name : str, optional
            Application name, which will be a metric dimension if specified.
        compartment_id : str, optional
            Compartment OCID for posting the metric.
            If compartment_id is not specified,
            ADS will try to fetch the compartment OCID from environment variable.
        session_id : str, optional
            Session ID to be saved as a metric dimension, by default None.
        region : str, optional
            OCI region for posting the metric, by default None.
            If region is not specified, the region from the authentication
            signer (or, failing that, its config) will be used.
        resource_group : str, optional
            Resource group for the metric, by default None
        log_agent_name : bool, optional
            Whether to log agent name as a metric dimension, by default False.
        log_tool_name : bool, optional
            Whether to log tool name as a metric dimension, by default False.
        log_model_name : bool, optional
            Whether to log model name as a metric dimension, by default False.

        Raises
        ------
        ValueError
            If the compartment OCID or the region cannot be determined.
        """
        self.app_name = app_name
        self.session_id = session_id
        self.compartment_id = compartment_id or ads.config.COMPARTMENT_OCID
        if not self.compartment_id:
            raise ValueError(
                "Unable to determine compartment OCID for metric logger. "
                "Please specify the compartment_id."
            )
        self.namespace = namespace
        self.resource_group = resource_group
        self.log_agent_name = log_agent_name
        self.log_tool_name = log_tool_name
        self.log_model_name = log_model_name
        # Indicate if the logger has started.
        self.started = False

        auth = ads.auth.default_signer()

        # Use the signer, then the config, to determine the region if it is
        # not specified.
        signer = auth.get("signer")
        config = auth.get("config", {})
        if not region:
            if getattr(signer, "region", None):
                region = signer.region
            elif config.get("region"):
                region = config.get("region")
            else:
                raise ValueError(
                    "Unable to determine the region for OCI monitoring service. "
                    "Please specify the region using the `region` argument."
                )

        self.monitoring_client = MonitoringClient(
            config=config,
            signer=signer,
            # Metrics should be submitted with the "telemetry-ingestion" endpoint instead.
            # See note here: https://docs.oracle.com/iaas/api/#/en/monitoring/20180401/MetricData/PostMetricData
            service_endpoint=f"https://telemetry-ingestion.{region}.oraclecloud.com",
        )

    def _post_metric(self, metric: Metric):
        """Posts a single metric datapoint to OCI monitoring.

        ``app_name`` and ``session_id`` are added to the metric's dimensions
        when they are set on the logger.
        """
        # Add app_name and session_id to dimensions
        dimensions = metric.dimensions
        if self.app_name:
            dimensions[MetricDimension.APP_NAME] = self.app_name
        if self.session_id:
            dimensions[MetricDimension.SESSION_ID] = self.session_id

        logger.debug("Posting metrics:\n%s", str(metric))
        # The timestamp is parsed as "YYYY-MM-DD HH:MM:SS.ffffff", presumably
        # the format produced by autogen's `get_current_ts()` — TODO confirm.
        timestamp = datetime.strptime(
            metric.timestamp.replace(" ", "T") + "Z",
            "%Y-%m-%dT%H:%M:%S.%fZ",
        )
        self.monitoring_client.post_metric_data(
            post_metric_data_details=oci.monitoring.models.PostMetricDataDetails(
                metric_data=[
                    oci.monitoring.models.MetricDataDetails(
                        namespace=self.namespace,
                        compartment_id=self.compartment_id,
                        name=metric.name,
                        dimensions=dimensions,
                        datapoints=[
                            oci.monitoring.models.Datapoint(
                                timestamp=timestamp,
                                value=metric.value,
                                count=1,
                            )
                        ],
                        resource_group=self.resource_group,
                    )
                ],
                batch_atomicity="ATOMIC",
            ),
        )

    def start(self):
        """Starts the logger and posts a session start metric.

        Returns
        -------
        str or None
            The session ID given at initialization.
        """
        if self.session_id:
            logger.info("Starting metric logging for session_id: %s", self.session_id)
        else:
            logger.info("Starting metric logging.")
        self.started = True
        try:
            metric = Metric(
                name=MetricName.SESSION_START,
                value=1,
                timestamp=get_current_ts(),
            )
            self._post_metric(metric=metric)
        except Exception as e:
            logger.error("MetricLogger Failed to log session start: %s", e)
        return self.session_id

    def log_new_agent(
        self, agent: ConversableAgent, init_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """Metric logger does not log new agent."""
        pass

    def log_function_use(
        self,
        source: Union[str, Agent],
        function: Any,
        args: Dict[str, Any],
        returns: Any,
    ) -> None:
        """
        Log a registered function(can be a tool) use from an agent or a string source.
        """
        if not self.started:
            return
        # Dimension building is inside the `try` so that, for example, a
        # callable without `__name__` is logged as a failure instead of
        # propagating into the caller.
        try:
            agent_name = str(source.name) if hasattr(source, "name") else source
            dimensions = {}
            if self.log_tool_name:
                dimensions[MetricDimension.TOOL_NAME] = function.__name__
            if self.log_agent_name:
                dimensions[MetricDimension.AGENT_NAME] = agent_name
            self._post_metric(
                Metric(
                    name=MetricName.TOOL_CALL,
                    value=1,
                    timestamp=get_current_ts(),
                    dimensions=dimensions,
                )
            )
        except Exception as e:
            logger.error("MetricLogger Failed to log tool call: %s", e)

    def log_chat_completion(
        self,
        invocation_id: UUID,
        client_id: int,
        wrapper_id: int,
        source: Union[str, Agent],
        request: Dict[str, Union[float, str, List[Dict[str, str]]]],
        response: Union[str, Any],
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """
        Log a chat completion.

        Nothing is posted when the serialized response has no ``usage`` dict.
        Otherwise, a completion count, the cost (when non-zero) and one metric
        per numeric usage entry are posted.
        """
        if not self.started:
            return

        try:
            response_dict: dict = serialize_response(response)
            if "usage" not in response_dict or not isinstance(
                response_dict["usage"], dict
            ):
                return
            # Post usage metric
            agent_name = str(source.name) if hasattr(source, "name") else source
            model = response_dict.get("model", "N/A")
            dimensions = {}
            if self.log_model_name:
                dimensions[MetricDimension.MODEL] = model
            if self.log_agent_name:
                dimensions[MetricDimension.AGENT_NAME] = agent_name

            # Chat completion count
            self._post_metric(
                Metric(
                    name=MetricName.CHAT_COMPLETION,
                    value=1,
                    timestamp=get_current_ts(),
                    dimensions=dimensions,
                )
            )
            # Cost
            if cost:
                self._post_metric(
                    Metric(
                        name=MetricName.COST,
                        value=cost,
                        timestamp=get_current_ts(),
                        dimensions=dimensions,
                    )
                )
            # Usage
            for key, val in response_dict["usage"].items():
                # `Metric.value` must be a number. Skip None or nested entries
                # (e.g. token-detail breakdowns) rather than failing validation
                # and dropping the remaining usage metrics.
                if not isinstance(val, (int, float)):
                    continue
                self._post_metric(
                    Metric(
                        name=key,
                        value=val,
                        timestamp=get_current_ts(),
                        dimensions=dimensions,
                    )
                )

        except Exception as e:
            logger.error("MetricLogger Failed to log chat completion: %s", e)

    def log_new_wrapper(
        self,
        wrapper: OpenAIWrapper,
        init_args: Optional[Dict[str, Union[LLMConfig, List[LLMConfig]]]] = None,
    ) -> None:
        """Metric logger does not log new wrapper."""
        pass

    def log_new_client(self, client, wrapper, init_args):
        """Metric logger does not log new client."""
        pass

    def log_event(self, source, name, **kwargs):
        """Metric logger does not log events."""
        pass

    def get_connection(self):
        """Metric logger has no connection; always returns None."""
        pass

    def stop(self):
        """Stops the metric logger and posts a session stop metric.

        Does nothing if the logger has not been started.
        """
        if not self.started:
            return
        self.started = False
        try:
            metric = Metric(
                name=MetricName.SESSION_STOP,
                value=1,
                timestamp=get_current_ts(),
            )
            self._post_metric(metric=metric)
        except Exception as e:
            logger.error("MetricLogger Failed to log session stop: %s", e)
        logger.info("Metric logger stopped.")