oracle-ads 2.12.9__py3-none-any.whl → 2.12.10rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -4
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +19 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +64 -17
- ads/aqua/finetuning/finetuning.py +38 -54
- ads/aqua/model/entities.py +2 -1
- ads/aqua/model/model.py +61 -23
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +12 -5
- ads/opctl/operator/lowcode/common/utils.py +11 -5
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/schema.yaml +13 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/METADATA +7 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/RECORD +60 -39
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10rc0.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,5 @@
|
|
1
|
-
#
|
2
|
-
#
|
3
|
-
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
|
1
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
2
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
3
|
|
5
4
|
"""This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
|
6
5
|
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/
|
@@ -72,14 +71,14 @@ import copy
|
|
72
71
|
import importlib
|
73
72
|
import json
|
74
73
|
import logging
|
75
|
-
from
|
74
|
+
from dataclasses import asdict, dataclass
|
76
75
|
from types import SimpleNamespace
|
76
|
+
from typing import Any, Dict, List, Union
|
77
77
|
|
78
78
|
from autogen import ModelClient
|
79
79
|
from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
|
80
80
|
from langchain_core.messages import AIMessage
|
81
81
|
|
82
|
-
|
83
82
|
logger = logging.getLogger(__name__)
|
84
83
|
|
85
84
|
# custom_clients is a dictionary mapping the name of the class to the actual class
|
@@ -177,6 +176,13 @@ class Message(AIMessage):
|
|
177
176
|
return self.tool_calls
|
178
177
|
|
179
178
|
|
179
|
+
@dataclass()
class Usage:
    """Token-usage record attached to every response built by the client.

    The client initialises ``response.usage`` with a fresh instance, and
    ``get_usage()`` turns it into a plain dict via ``dataclasses.asdict``.
    All counters start at zero.
    """

    # Prompt-side token count (meaning inferred from the field name).
    prompt_tokens: int = 0
    # Completion-side token count (meaning inferred from the field name).
    completion_tokens: int = 0
    # Overall token count; presumably prompt + completion -- TODO confirm.
    total_tokens: int = 0
|
184
|
+
|
185
|
+
|
180
186
|
class LangChainModelClient(ModelClient):
|
181
187
|
"""Represents a model client wrapping a LangChain chat model."""
|
182
188
|
|
@@ -202,8 +208,8 @@ class LangChainModelClient(ModelClient):
|
|
202
208
|
# Import the LangChain class
|
203
209
|
if "langchain_cls" not in config:
|
204
210
|
raise ValueError("Missing langchain_cls in LangChain Model Client config.")
|
205
|
-
|
206
|
-
module_name, cls_name = str(
|
211
|
+
self.langchain_cls = config.pop("langchain_cls")
|
212
|
+
module_name, cls_name = str(self.langchain_cls).rsplit(".", 1)
|
207
213
|
langchain_module = importlib.import_module(module_name)
|
208
214
|
langchain_cls = getattr(langchain_module, cls_name)
|
209
215
|
|
@@ -232,7 +238,14 @@ class LangChainModelClient(ModelClient):
|
|
232
238
|
streaming = params.get("stream", False)
|
233
239
|
# TODO: num_of_responses
|
234
240
|
num_of_responses = params.get("n", 1)
|
235
|
-
|
241
|
+
|
242
|
+
messages = copy.deepcopy(params.get("messages", []))
|
243
|
+
|
244
|
+
# OCI Gen AI does not allow empty message.
|
245
|
+
if str(self.langchain_cls).endswith("oci_generative_ai.ChatOCIGenAI"):
|
246
|
+
for message in messages:
|
247
|
+
if len(message.get("content", "")) == 0:
|
248
|
+
message["content"] = " "
|
236
249
|
|
237
250
|
invoke_params = copy.deepcopy(self.invoke_params)
|
238
251
|
|
@@ -241,7 +254,6 @@ class LangChainModelClient(ModelClient):
|
|
241
254
|
model = self.model.bind_tools(
|
242
255
|
[_convert_to_langchain_tool(tool) for tool in tools]
|
243
256
|
)
|
244
|
-
# invoke_params["tools"] = tools
|
245
257
|
invoke_params.update(self.function_call_params)
|
246
258
|
else:
|
247
259
|
model = self.model
|
@@ -249,6 +261,7 @@ class LangChainModelClient(ModelClient):
|
|
249
261
|
response = SimpleNamespace()
|
250
262
|
response.choices = []
|
251
263
|
response.model = self.model_name
|
264
|
+
response.usage = Usage()
|
252
265
|
|
253
266
|
if streaming and messages:
|
254
267
|
# If streaming is enabled and has messages, then iterate over the chunks of the response.
|
@@ -279,4 +292,4 @@ class LangChainModelClient(ModelClient):
|
|
279
292
|
@staticmethod
|
280
293
|
def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
|
281
294
|
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
|
282
|
-
return
|
295
|
+
return asdict(response.usage)
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
2
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
3
|
+
import io
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
import threading
|
8
|
+
|
9
|
+
import fsspec
|
10
|
+
|
11
|
+
from ads.common.auth import default_signer
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class OCIFileHandler(logging.FileHandler):
|
17
|
+
"""Log handler for saving log file to OCI object storage."""
|
18
|
+
|
19
|
+
def __init__(
|
20
|
+
self,
|
21
|
+
filename: str,
|
22
|
+
session_id: str,
|
23
|
+
mode: str = "a",
|
24
|
+
encoding: str | None = None,
|
25
|
+
delay: bool = False,
|
26
|
+
errors: str | None = None,
|
27
|
+
auth: dict | None = None,
|
28
|
+
) -> None:
|
29
|
+
self.session_id = session_id
|
30
|
+
self.auth = auth
|
31
|
+
|
32
|
+
if filename.startswith("oci://"):
|
33
|
+
self.baseFilename = filename
|
34
|
+
else:
|
35
|
+
self.baseFilename = os.path.abspath(os.path.expanduser(filename))
|
36
|
+
os.makedirs(os.path.dirname(self.baseFilename), exist_ok=True)
|
37
|
+
|
38
|
+
# The following code are from the `FileHandler.__init__()`
|
39
|
+
self.mode = mode
|
40
|
+
self.encoding = encoding
|
41
|
+
if "b" not in mode:
|
42
|
+
self.encoding = io.text_encoding(encoding)
|
43
|
+
self.errors = errors
|
44
|
+
self.delay = delay
|
45
|
+
|
46
|
+
if delay:
|
47
|
+
# We don't open the stream, but we still need to call the
|
48
|
+
# Handler constructor to set level, formatter, lock etc.
|
49
|
+
logging.Handler.__init__(self)
|
50
|
+
self.stream = None
|
51
|
+
else:
|
52
|
+
logging.StreamHandler.__init__(self, self._open())
|
53
|
+
|
54
|
+
def _open(self):
|
55
|
+
"""
|
56
|
+
Open the current base file with the (original) mode and encoding.
|
57
|
+
Return the resulting stream.
|
58
|
+
"""
|
59
|
+
auth = self.auth or default_signer()
|
60
|
+
return fsspec.open(
|
61
|
+
self.baseFilename,
|
62
|
+
self.mode,
|
63
|
+
encoding=self.encoding,
|
64
|
+
errors=self.errors,
|
65
|
+
**auth,
|
66
|
+
).open()
|
67
|
+
|
68
|
+
def format(self, record: logging.LogRecord):
|
69
|
+
"""Formats the log record as JSON payload and add session_id."""
|
70
|
+
msg = record.getMessage()
|
71
|
+
try:
|
72
|
+
data = json.loads(msg)
|
73
|
+
except Exception as e:
|
74
|
+
data = {"message": msg}
|
75
|
+
|
76
|
+
if "session_id" not in data:
|
77
|
+
data["session_id"] = self.session_id
|
78
|
+
if "thread_id" not in data:
|
79
|
+
data["thread_id"] = threading.get_ident()
|
80
|
+
|
81
|
+
record.msg = json.dumps(data)
|
82
|
+
return super().format(record)
|
83
|
+
|
@@ -0,0 +1,6 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
from ads.llm.autogen.v02.loggers.metric_logger import MetricLogger
|
6
|
+
from ads.llm.autogen.v02.loggers.session_logger import SessionLogger
|
@@ -0,0 +1,320 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
import logging
|
5
|
+
from datetime import datetime
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
7
|
+
from uuid import UUID
|
8
|
+
|
9
|
+
import oci
|
10
|
+
from autogen import Agent, ConversableAgent, OpenAIWrapper
|
11
|
+
from autogen.logger.base_logger import BaseLogger, LLMConfig
|
12
|
+
from autogen.logger.logger_utils import get_current_ts
|
13
|
+
from oci.monitoring import MonitoringClient
|
14
|
+
from pydantic import BaseModel, Field
|
15
|
+
|
16
|
+
import ads
|
17
|
+
import ads.config
|
18
|
+
from ads.llm.autogen.v02.loggers.utils import serialize_response
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class MetricName:
    """Names of the metrics that ``MetricLogger`` posts to OCI Monitoring."""

    # Session lifecycle markers, posted by ``start()`` / ``stop()``.
    SESSION_START = "session_start"
    SESSION_STOP = "session_stop"

    # Chat-completion metrics, posted by ``log_chat_completion()``.
    CHAT_COMPLETION = "chat_completion_count"
    COST = "chat_completion_cost"

    # Registered-function (tool) invocations, posted by ``log_function_use()``.
    TOOL_CALL = "tool_call"
|
31
|
+
|
32
|
+
|
33
|
+
class MetricDimension:
    """Dimension keys that ``MetricLogger`` may attach to a posted metric."""

    # Logger-wide dimensions, added whenever they are configured.
    APP_NAME = "app_name"
    SESSION_ID = "session_id"

    # Per-event dimensions, each gated by a ``log_*_name`` flag.
    AGENT_NAME = "agent_name"
    MODEL = "model"
    TOOL_NAME = "tool_name"
|
41
|
+
|
42
|
+
|
43
|
+
class Metric(BaseModel):
    """One data point queued for posting by the metric logger."""

    # Metric name: a ``MetricName`` constant or a token-usage key.
    name: str
    # Value of the data point.
    value: float
    # Timestamp string, e.g. as returned by autogen's ``get_current_ts()``.
    timestamp: str
    # Dimension name -> value mapping; each instance gets its own empty dict.
    dimensions: dict = Field(default_factory=dict)
|
50
|
+
|
51
|
+
|
52
|
+
class MetricLogger(BaseLogger):
    """AutoGen logger for agent metrics.

    Each event (session start/stop, tool call, chat completion, cost and
    token usage) is posted as an individual data point to OCI Monitoring
    through the telemetry-ingestion endpoint. Errors raised while building or
    posting a metric are logged and not propagated to the calling agent.
    """

    def __init__(
        self,
        namespace: str,
        app_name: Optional[str] = None,
        compartment_id: Optional[str] = None,
        session_id: Optional[str] = None,
        region: Optional[str] = None,
        resource_group: Optional[str] = None,
        log_agent_name: bool = False,
        log_tool_name: bool = False,
        log_model_name: bool = False,
    ):
        """Initialize the metric logger.

        Parameters
        ----------
        namespace : str
            Namespace for posting the metric.
        app_name : str, optional
            Application name, which will be a metric dimension if specified.
        compartment_id : str, optional
            Compartment OCID for posting the metric.
            If compartment_id is not specified,
            ``ads.config.COMPARTMENT_OCID`` is used.
        session_id : str, optional
            Session ID to be saved as a metric dimension, by default None.
        region : str, optional
            OCI region for posting the metric, by default None.
            If region is not specified, the region from the authentication
            signer (or, failing that, its config) will be used.
        resource_group : str, optional
            Resource group for the metric, by default None.
        log_agent_name : bool, optional
            Whether to log agent name as a metric dimension, by default False.
        log_tool_name : bool, optional
            Whether to log tool name as a metric dimension, by default False.
        log_model_name : bool, optional
            Whether to log model name as a metric dimension, by default False.

        Raises
        ------
        ValueError
            If the compartment OCID or the region cannot be determined.
        """
        self.app_name = app_name
        self.session_id = session_id
        self.compartment_id = compartment_id or ads.config.COMPARTMENT_OCID
        if not self.compartment_id:
            raise ValueError(
                "Unable to determine compartment OCID for metric logger."
                "Please specify the compartment_id."
            )
        self.namespace = namespace
        self.resource_group = resource_group
        self.log_agent_name = log_agent_name
        self.log_tool_name = log_tool_name
        self.log_model_name = log_model_name
        # Indicate if the logger has started.
        self.started = False

        auth = ads.auth.default_signer()

        # Use the config/signer to determine the region if it not specified.
        signer = auth.get("signer")
        config = auth.get("config", {})
        if not region:
            if hasattr(signer, "region") and signer.region:
                region = signer.region
            elif config.get("region"):
                region = config.get("region")
            else:
                raise ValueError(
                    "Unable to determine the region for OCI monitoring service. "
                    "Please specify the region using the `region` argument."
                )

        self.monitoring_client = MonitoringClient(
            config=config,
            signer=signer,
            # Metrics should be submitted with the "telemetry-ingestion" endpoint instead.
            # See note here: https://docs.oracle.com/iaas/api/#/en/monitoring/20180401/MetricData/PostMetricData
            service_endpoint=f"https://telemetry-ingestion.{region}.oraclecloud.com",
        )

    def _post_metric(self, metric: Metric):
        """Posts a single metric data point to OCI monitoring.

        ``app_name`` and ``session_id`` are added to the metric dimensions
        when they are set on the logger.
        """
        # Add app_name and session_id to dimensions
        dimensions = metric.dimensions
        if self.app_name:
            dimensions[MetricDimension.APP_NAME] = self.app_name
        if self.session_id:
            dimensions[MetricDimension.SESSION_ID] = self.session_id

        logger.debug("Posting metrics:\n%s", str(metric))
        self.monitoring_client.post_metric_data(
            post_metric_data_details=oci.monitoring.models.PostMetricDataDetails(
                metric_data=[
                    oci.monitoring.models.MetricDataDetails(
                        namespace=self.namespace,
                        compartment_id=self.compartment_id,
                        name=metric.name,
                        dimensions=dimensions,
                        datapoints=[
                            oci.monitoring.models.Datapoint(
                                # Parses "YYYY-MM-DD HH:MM:SS.ffffff" into a naive datetime.
                                timestamp=datetime.strptime(
                                    metric.timestamp.replace(" ", "T") + "Z",
                                    "%Y-%m-%dT%H:%M:%S.%fZ",
                                ),
                                value=metric.value,
                                count=1,
                            )
                        ],
                        resource_group=self.resource_group,
                    )
                ],
                batch_atomicity="ATOMIC",
            ),
        )

    def start(self):
        """Starts the logger and posts a session-start metric.

        Returns
        -------
        str or None
            The session ID configured for this logger.
        """
        if self.session_id:
            logger.info(f"Starting metric logging for session_id: {self.session_id}")
        else:
            logger.info("Starting metric logging.")
        self.started = True
        try:
            metric = Metric(
                name=MetricName.SESSION_START,
                value=1,
                timestamp=get_current_ts(),
            )
            self._post_metric(metric=metric)
        except Exception as e:
            logger.error(f"MetricLogger Failed to log session start: {str(e)}")
        return self.session_id

    def log_new_agent(
        self, agent: ConversableAgent, init_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """Metric logger does not log new agent."""
        pass

    def log_function_use(
        self,
        source: Union[str, Agent],
        function: Any,
        args: Dict[str, Any],
        returns: Any,
    ) -> None:
        """
        Log a registered function(can be a tool) use from an agent or a string source.

        Does nothing unless the logger has been started.
        """
        if not self.started:
            return
        # Everything, including building the dimensions, runs inside the try
        # so a logging problem never breaks the agent's tool call.
        try:
            agent_name = str(source.name) if hasattr(source, "name") else source
            dimensions = {}
            if self.log_tool_name:
                # Not every callable (e.g. functools.partial) has ``__name__``;
                # fall back to its type name to keep the dimension bounded.
                dimensions[MetricDimension.TOOL_NAME] = getattr(
                    function, "__name__", type(function).__name__
                )
            if self.log_agent_name:
                dimensions[MetricDimension.AGENT_NAME] = agent_name
            self._post_metric(
                Metric(
                    name=MetricName.TOOL_CALL,
                    value=1,
                    timestamp=get_current_ts(),
                    dimensions=dimensions,
                )
            )
        except Exception as e:
            logger.error(f"MetricLogger Failed to log tool call: {str(e)}")

    def log_chat_completion(
        self,
        invocation_id: UUID,
        client_id: int,
        wrapper_id: int,
        source: Union[str, Agent],
        request: Dict[str, Union[float, str, List[Dict[str, str]]]],
        response: Union[str, Any],
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """
        Log a chat completion.

        Posts a completion-count metric, a cost metric when ``cost`` is truthy,
        and one metric per numeric entry of the serialized response's
        ``usage`` dict. Nothing is posted if the logger is not started or the
        response has no ``usage`` dict.
        """
        if not self.started:
            return

        try:
            response_dict = serialize_response(response)
            if "usage" not in response_dict or not isinstance(
                response_dict["usage"], dict
            ):
                return
            # Post usage metric
            agent_name = str(source.name) if hasattr(source, "name") else source
            model = response_dict.get("model", "N/A")
            dimensions = {}
            if self.log_model_name:
                dimensions[MetricDimension.MODEL] = model
            if self.log_agent_name:
                dimensions[MetricDimension.AGENT_NAME] = agent_name

            # Chat completion count
            self._post_metric(
                Metric(
                    name=MetricName.CHAT_COMPLETION,
                    value=1,
                    timestamp=get_current_ts(),
                    dimensions=dimensions,
                )
            )
            # Cost
            if cost:
                self._post_metric(
                    Metric(
                        name=MetricName.COST,
                        value=cost,
                        timestamp=get_current_ts(),
                        dimensions=dimensions,
                    )
                )
            # Usage
            for key, val in response_dict["usage"].items():
                # Skip non-numeric entries (e.g. nested detail dicts or None).
                # They cannot be a datapoint value, and a validation error here
                # would abort the remaining usage metrics.
                if isinstance(val, bool) or not isinstance(val, (int, float)):
                    continue
                self._post_metric(
                    Metric(
                        name=key,
                        value=val,
                        timestamp=get_current_ts(),
                        dimensions=dimensions,
                    )
                )

        except Exception as e:
            logger.error(f"MetricLogger Failed to log chat completion: {str(e)}")

    def log_new_wrapper(
        self,
        wrapper: OpenAIWrapper,
        init_args: Optional[Dict[str, Union[LLMConfig, List[LLMConfig]]]] = None,
    ) -> None:
        """Metric logger does not log new wrapper."""
        pass

    def log_new_client(self, client, wrapper, init_args):
        """Metric logger does not log new client."""
        pass

    def log_event(self, source, name, **kwargs):
        """Metric logger does not log events."""
        pass

    def get_connection(self):
        """Metric logger has no connection object; returns None."""
        pass

    def stop(self):
        """Stops the metric logger and posts a session-stop metric.

        Does nothing if the logger is not started.
        """
        if not self.started:
            return
        self.started = False
        try:
            metric = Metric(
                name=MetricName.SESSION_STOP,
                value=1,
                timestamp=get_current_ts(),
            )
            self._post_metric(metric=metric)
        except Exception as e:
            logger.error(f"MetricLogger Failed to log session stop: {str(e)}")
        logger.info("Metric logger stopped.")
|