oracle-ads 2.12.9__py3-none-any.whl → 2.12.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/__init__.py +4 -3
- ads/aqua/app.py +28 -16
- ads/aqua/client/__init__.py +3 -0
- ads/aqua/client/client.py +799 -0
- ads/aqua/common/enums.py +3 -0
- ads/aqua/common/utils.py +62 -2
- ads/aqua/data.py +2 -19
- ads/aqua/evaluation/evaluation.py +20 -12
- ads/aqua/extension/aqua_ws_msg_handler.py +14 -7
- ads/aqua/extension/base_handler.py +12 -9
- ads/aqua/extension/finetune_handler.py +8 -14
- ads/aqua/extension/model_handler.py +24 -2
- ads/aqua/finetuning/constants.py +5 -2
- ads/aqua/finetuning/entities.py +67 -17
- ads/aqua/finetuning/finetuning.py +69 -54
- ads/aqua/model/entities.py +3 -1
- ads/aqua/model/model.py +196 -98
- ads/aqua/modeldeployment/deployment.py +22 -10
- ads/cli.py +16 -8
- ads/common/auth.py +9 -9
- ads/llm/autogen/__init__.py +2 -0
- ads/llm/autogen/constants.py +15 -0
- ads/llm/autogen/reports/__init__.py +2 -0
- ads/llm/autogen/reports/base.py +67 -0
- ads/llm/autogen/reports/data.py +103 -0
- ads/llm/autogen/reports/session.py +526 -0
- ads/llm/autogen/reports/templates/chat_box.html +13 -0
- ads/llm/autogen/reports/templates/chat_box_lt.html +5 -0
- ads/llm/autogen/reports/templates/chat_box_rt.html +6 -0
- ads/llm/autogen/reports/utils.py +56 -0
- ads/llm/autogen/v02/__init__.py +4 -0
- ads/llm/autogen/{client_v02.py → v02/client.py} +23 -10
- ads/llm/autogen/v02/log_handlers/__init__.py +2 -0
- ads/llm/autogen/v02/log_handlers/oci_file_handler.py +83 -0
- ads/llm/autogen/v02/loggers/__init__.py +6 -0
- ads/llm/autogen/v02/loggers/metric_logger.py +320 -0
- ads/llm/autogen/v02/loggers/session_logger.py +580 -0
- ads/llm/autogen/v02/loggers/utils.py +86 -0
- ads/llm/autogen/v02/runtime_logging.py +163 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +12 -11
- ads/model/__init__.py +11 -13
- ads/model/artifact.py +47 -8
- ads/model/extractor/embedding_onnx_extractor.py +80 -0
- ads/model/framework/embedding_onnx_model.py +438 -0
- ads/model/generic_model.py +26 -24
- ads/model/model_metadata.py +8 -7
- ads/opctl/config/merger.py +13 -14
- ads/opctl/operator/common/operator_config.py +4 -4
- ads/opctl/operator/lowcode/common/transformations.py +50 -8
- ads/opctl/operator/lowcode/common/utils.py +22 -6
- ads/opctl/operator/lowcode/forecast/__main__.py +10 -0
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/model/arima.py +19 -13
- ads/opctl/operator/lowcode/forecast/model/automlx.py +129 -36
- ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +61 -14
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +10 -3
- ads/opctl/operator/lowcode/forecast/model/prophet.py +25 -18
- ads/opctl/operator/lowcode/forecast/operator_config.py +31 -0
- ads/opctl/operator/lowcode/forecast/schema.yaml +76 -0
- ads/opctl/operator/lowcode/forecast/utils.py +4 -3
- ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py +7 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +233 -0
- ads/opctl/operator/lowcode/forecast/whatifserve/score.py +238 -0
- ads/telemetry/base.py +18 -11
- ads/telemetry/client.py +33 -13
- ads/templates/schemas/openapi.json +1740 -0
- ads/templates/score_embedding_onnx.jinja2 +202 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/METADATA +9 -8
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/RECORD +74 -48
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/WHEEL +0 -0
- {oracle_ads-2.12.9.dist-info → oracle_ads-2.12.10.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,580 @@
|
|
1
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
2
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
3
|
+
import importlib
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import tempfile
|
7
|
+
import threading
|
8
|
+
import traceback
|
9
|
+
import uuid
|
10
|
+
from dataclasses import dataclass, field
|
11
|
+
from datetime import datetime, timedelta, timezone
|
12
|
+
from typing import Any, Dict, List, Optional, Union
|
13
|
+
from urllib.parse import urlparse
|
14
|
+
|
15
|
+
import autogen
|
16
|
+
import fsspec
|
17
|
+
import oci
|
18
|
+
from autogen import Agent, ConversableAgent, GroupChatManager, OpenAIWrapper
|
19
|
+
from autogen.logger.file_logger import (
|
20
|
+
ChatCompletion,
|
21
|
+
F,
|
22
|
+
FileLogger,
|
23
|
+
get_current_ts,
|
24
|
+
safe_serialize,
|
25
|
+
)
|
26
|
+
from oci.object_storage import ObjectStorageClient
|
27
|
+
from oci.object_storage.models import (
|
28
|
+
CreatePreauthenticatedRequestDetails,
|
29
|
+
PreauthenticatedRequest,
|
30
|
+
)
|
31
|
+
|
32
|
+
import ads
|
33
|
+
from ads.common.auth import default_signer
|
34
|
+
from ads.llm.autogen.constants import Events
|
35
|
+
from ads.llm.autogen.reports.data import (
|
36
|
+
AgentData,
|
37
|
+
LLMCompletionData,
|
38
|
+
LogRecord,
|
39
|
+
ToolCallData,
|
40
|
+
)
|
41
|
+
from ads.llm.autogen.reports.session import SessionReport
|
42
|
+
from ads.llm.autogen.v02 import runtime_logging
|
43
|
+
from ads.llm.autogen.v02.log_handlers.oci_file_handler import OCIFileHandler
|
44
|
+
from ads.llm.autogen.v02.loggers.utils import (
|
45
|
+
serialize,
|
46
|
+
serialize_response,
|
47
|
+
)
|
48
|
+
|
49
|
+
# Key looked up in ``SessionLogger.log_event`` kwargs; its value is compared
# against "check_termination_and_human_reply" to track tool call start times.
CONST_REPLY_FUNC_NAME = "reply_func_name"

# Module logger for diagnostics emitted by this module itself
# (not the per-session log records).
logger = logging.getLogger(__name__)
|
53
|
+
|
54
|
+
|
55
|
+
@dataclass
class LoggingSession:
    """Represents a logging session for a specific thread."""

    session_id: str
    log_dir: str
    log_file: str
    thread_id: int
    pid: int
    logger: logging.Logger
    # OCI auth dict (config/signer); falls back to default_signer() when empty or None.
    auth: Optional[dict] = field(default_factory=dict)
    # Path/URI of the most recently generated report, if any.
    report_file: Optional[str] = None
    # Pre-authenticated link for ``report_file``, only set when one was requested.
    par_uri: Optional[str] = None

    @property
    def report(self) -> Optional[str]:
        """HTML report path of the logging session.

        If a pre-authenticated link was generated for the report,
        the pre-authenticated link is returned.

        If the report is saved to OCI object storage, the URI is returned.
        If the report is saved locally, the local path is returned.
        If no report has been generated, `None` is returned.
        """
        if self.par_uri:
            return self.par_uri
        if self.report_file:
            return self.report_file
        return None

    def __repr__(self) -> str:
        """Shows the link to report if it is available, otherwise shows the log file link."""
        if self.report:
            return self.report
        return self.log_file

    def create_par_uri(self, oci_file: str, **kwargs) -> str:
        """Creates a pre-authenticated request URI for a file on OCI object storage.

        The OCI authentication is taken from ``self.auth``, falling back to
        ``ads.common.auth.default_signer()`` when it is empty.

        Parameters
        ----------
        oci_file : str
            OCI file URI in the format of oci://bucket@namespace/prefix/to/file
        **kwargs
            Additional arguments for ``CreatePreauthenticatedRequestDetails``.
            ``time_expires`` defaults to one week from now (UTC) and
            ``access_type`` defaults to ``"ObjectRead"``.

        Returns
        -------
        str
            The pre-authenticated URI

        Raises
        ------
        ValueError
            If ``oci_file`` does not contain both a bucket and a namespace.
        """
        parsed = urlparse(oci_file)
        # Split the netloc ourselves: ``parsed.hostname`` would lowercase the namespace.
        bucket, _, namespace = parsed.netloc.partition("@")
        if not bucket or not namespace:
            raise ValueError(
                f"Invalid OCI file URI: {oci_file}. "
                "Expected format: oci://bucket@namespace/prefix/to/file"
            )

        auth = self.auth or default_signer()
        client = ObjectStorageClient(**auth)
        time_expires = kwargs.pop(
            "time_expires", datetime.now(timezone.utc) + timedelta(weeks=1)
        )
        access_type = kwargs.pop("access_type", "ObjectRead")
        response: PreauthenticatedRequest = client.create_preauthenticated_request(
            bucket_name=bucket,
            namespace_name=namespace,
            create_preauthenticated_request_details=CreatePreauthenticatedRequestDetails(
                name=os.path.basename(oci_file),
                object_name=str(parsed.path).lstrip("/"),
                access_type=access_type,
                time_expires=time_expires,
                **kwargs,
            ),
        ).data
        return response.full_path

    def create_report(
        self, report_file: str, return_par_uri: bool = False, **kwargs
    ) -> str:
        """Creates a report in HTML format.

        Parameters
        ----------
        report_file : str
            The file path to save the report.
            Paths starting with ``oci://`` are built in a temporary local
            directory and then uploaded with ``fsspec``.
        return_par_uri : bool, optional
            If the report is saved in object storage,
            whether to create a pre-authenticated link for the report, by default False.
            This is ignored if the report is not saved in object storage.
        **kwargs
            Forwarded to ``create_par_uri`` when a pre-authenticated link is created.

        Returns
        -------
        str
            The full path or pre-authenticated link of the report.
        """
        auth = self.auth or default_signer()
        report = SessionReport(log_file=self.log_file, auth=auth)
        if report_file.startswith("oci://"):
            with tempfile.TemporaryDirectory() as temp_dir:
                # Save the report to local temp dir
                temp_report = os.path.join(temp_dir, os.path.basename(report_file))
                report.build(temp_report)
                # Upload to OCI object storage
                fs = fsspec.filesystem("oci", **auth)
                fs.put(temp_report, report_file)
            if return_par_uri:
                par_uri = self.create_par_uri(oci_file=report_file, **kwargs)
                self.report_file = report_file
                self.par_uri = par_uri
                return par_uri
        else:
            report_file = os.path.abspath(os.path.expanduser(report_file))
            os.makedirs(os.path.dirname(report_file), exist_ok=True)
            report.build(report_file)
        self.report_file = report_file
        # Drop any link from an earlier report; it would point at stale content
        # and otherwise take precedence in ``report``.
        self.par_uri = None
        return report_file
|
169
|
+
|
170
|
+
|
171
|
+
class SessionLogger(FileLogger):
    """Logger for saving AutoGen session logs to a local path or OCI object storage.

    Events (new agents, new clients, LLM calls, tool calls and custom events)
    are written as ``LogRecord`` strings through a per-session
    ``logging.Logger`` backed by ``OCIFileHandler``. Nothing is written until
    ``start()`` is called. When ``report_dir`` is set, an HTML report is
    generated on ``stop()``.
    """

    def __init__(
        self,
        log_dir: str,
        report_dir: Optional[str] = None,
        session_id: Optional[str] = None,
        auth: Optional[dict] = None,
        log_for_all_threads: bool = False,
        report_par_uri: bool = False,
        par_kwargs: Optional[dict] = None,
    ):
        """Initialize a file logger for new session.

        Parameters
        ----------
        log_dir : str
            Directory for saving the log file.
        report_dir : str, optional
            Directory for saving the HTML report, by default None.
            If set, a report is generated when the session stops.
        session_id : str, optional
            Session ID, by default None.
            If the session ID is None, a new UUID4 will be generated.
            The session ID will be used as the log filename.
        auth: dict, optional
            Dictionary containing the OCI authentication config and signer.
            If auth is None, `ads.common.auth.default_signer()` will be used.
        log_for_all_threads : bool, optional
            Indicate if the logger should handle logging for all threads.
            Defaults to False, the logger will only log for the current thread.
        report_par_uri : bool, optional
            Whether to create a pre-authenticated link for reports saved to
            OCI object storage, by default False.
        par_kwargs : dict, optional
            Default keyword arguments for creating the pre-authenticated link,
            by default None.
        """
        # Note: FileLogger.__init__ is not called; records are written
        # through the session logger created by ``new_session``.
        self.report_dir = report_dir
        self.report_par_uri = report_par_uri
        self.par_kwargs = par_kwargs
        self.log_for_all_threads = log_for_all_threads

        self.session = self.new_session(
            log_dir=log_dir, session_id=session_id, auth=auth
        )
        # Log only if started is True
        self.started = False

        # Keep track of last check_termination_and_human_reply for calculating tool call duration
        # This will be a dictionary mapping the IDs of the agents to their last timestamp
        # of check_termination_and_human_reply
        self.last_agent_checks = {}

    @property
    def logger(self) -> Optional[logging.Logger]:
        """Logger for the thread.

        This property is used to determine whether the log should be saved.
        No log will be saved if the logger is None, i.e. when the session has
        not started, or when the current thread is not the session thread and
        ``log_for_all_threads`` is False.
        """
        if not self.started:
            return None
        thread_id = threading.get_ident()
        if not self.log_for_all_threads and thread_id != self.session.thread_id:
            return None
        return self.session.logger

    @property
    def session_id(self) -> Optional[str]:
        """Session ID for the current session."""
        return self.session.session_id

    @property
    def log_file(self) -> Optional[str]:
        """Log file path for the current session."""
        return self.session.log_file

    @property
    def report(self) -> Optional[str]:
        """Report path/link for the session."""
        return self.session.report

    @property
    def name(self) -> str:
        """Name of the logger."""
        return self.session_id or "oci_file_logger"

    def new_session(
        self,
        log_dir: str,
        session_id: Optional[str] = None,
        auth: Optional[dict] = None,
    ) -> LoggingSession:
        """Creates a new logging session.

        Parameters
        ----------
        log_dir : str
            Directory for saving the log file.
        session_id : str, optional
            Session ID, by default None.
            If the session ID is None, a new UUID4 will be generated.
            The session ID will be used as the log filename.
        auth: dict, optional
            Dictionary containing the OCI authentication config and signer.
            If auth is None, `ads.common.auth.default_signer()` will be used.

        Returns
        -------
        LoggingSession
            The new logging session
        """
        thread_id = threading.get_ident()

        session_id = str(session_id or uuid.uuid4())
        log_file = os.path.join(log_dir, f"{session_id}.log")

        # Prepare the logger (named after the session ID)
        session_logger = logging.getLogger(session_id)
        session_logger.setLevel(logging.INFO)
        file_handler = OCIFileHandler(log_file, session_id=session_id, auth=auth)
        session_logger.addHandler(file_handler)

        # Create logging session bound to the calling thread
        session = LoggingSession(
            session_id=session_id,
            log_dir=log_dir,
            log_file=log_file,
            thread_id=thread_id,
            pid=os.getpid(),
            logger=session_logger,
            auth=auth,
        )

        logger.info("Start logging session %s to file %s", session_id, log_file)
        return session

    def generate_report(
        self,
        report_dir: Optional[str] = None,
        report_par_uri: Optional[bool] = None,
        **kwargs,
    ) -> str:
        """Generates a report for the session.

        Parameters
        ----------
        report_dir : str, optional
            Directory for saving the report, by default None.
            If not set, the value of `self.report_dir` will be used.
        report_par_uri : bool, optional
            Whether to create a pre-authenticated link for the report, by default None.
            If the `report_par_uri` is not set, the value of `self.report_par_uri` will be used.
        **kwargs
            Arguments for creating the pre-authenticated link.
            If empty, `self.par_kwargs` will be used.

        Returns
        -------
        str
            The link to the report.
            If the `report_dir` is local, the local file path will be returned.
            If a pre-authenticated link is created, the link will be returned.

        Raises
        ------
        ValueError
            If neither `report_dir` nor `self.report_dir` is set.
        """
        report_dir = report_dir or self.report_dir
        if not report_dir:
            raise ValueError(
                "report_dir must be specified either when creating the logger "
                "or when calling generate_report()."
            )
        report_par_uri = (
            report_par_uri if report_par_uri is not None else self.report_par_uri
        )
        kwargs = kwargs or self.par_kwargs or {}

        # Use the resolved values so explicit arguments take effect.
        report_file = os.path.join(report_dir, f"{self.session_id}.html")
        report_link = self.session.create_report(
            report_file=report_file, return_par_uri=report_par_uri, **kwargs
        )
        print(f"ADS AutoGen Session Report: {report_link}")
        return report_link

    def new_record(self, event_name: str, source: Any = None) -> LogRecord:
        """Initialize a new log record.

        The record is not logged until `self.log()` is called.
        """
        record = LogRecord(
            session_id=self.session_id,
            thread_id=threading.get_ident(),
            timestamp=get_current_ts(),
            event_name=event_name,
        )
        if source:
            record.source_id = id(source)
            record.source_name = str(source.name) if hasattr(source, "name") else source
        return record

    def log(self, record: LogRecord) -> None:
        """Logs a record.

        Parameters
        ----------
        record : LogRecord
            Record to be logged. Nothing is written when there is no logger
            for the current thread (see the ``logger`` property).
        """
        # Do nothing if there is no logger for the thread.
        if not self.logger:
            return

        try:
            self.logger.info(record.to_string())
        except Exception:
            self.logger.info("Failed to log %s", record.event_name)

    def start(self) -> str:
        """Start the logging session and return the session_id.

        Collects the versions of ``oracle-ads``, ``oci`` and ``autogen``, plus
        those of the optional libraries below that can be imported, and logs
        them with the session start event.
        """
        envs = {
            "oracle-ads": ads.__version__,
            "oci": oci.__version__,
            "autogen": autogen.__version__,
        }
        # Keys are distribution names; their importable module names use underscores.
        libraries = [
            "langchain",
            "langchain-core",
            "langchain-community",
            "langchain-openai",
            "openai",
        ]
        for library in libraries:
            module_name = library.replace("-", "_")
            try:
                envs[library] = importlib.import_module(module_name).__version__
            except Exception:
                # Best effort: the library may be missing or lack __version__.
                logger.debug("Unable to determine the version of %s.", library)
        self.started = True
        self.log_event(source=self, name=Events.SESSION_START, environment=envs)
        return self.session_id

    def stop(self) -> None:
        """Stops the logging session.

        Logs the session stop event, delegates to ``FileLogger.stop()``,
        and generates the report if `self.report_dir` is set. Report failures
        are logged rather than raised.
        """
        self.log_event(source=self, name=Events.SESSION_STOP)
        super().stop()
        self.started = False
        if self.report_dir:
            try:
                self.generate_report()
            except Exception as e:
                logger.error(
                    "Failed to create session report for AutoGen session %s\n%s",
                    self.session_id,
                    str(e),
                )
                logger.debug(traceback.format_exc())

    def log_chat_completion(
        self,
        invocation_id: uuid.UUID,
        client_id: int,
        wrapper_id: int,
        source: Union[str, Agent],
        request: Dict[str, Union[float, str, List[Dict[str, str]]]],
        response: Union[str, ChatCompletion],
        is_cached: int,
        cost: float,
        start_time: str,
    ) -> None:
        """
        Logs a chat completion.

        The end time of the call is the time this method is invoked.
        """
        if not self.logger:
            return

        record = self.new_record(event_name=Events.LLM_CALL, source=source)
        record.data = LLMCompletionData(
            invocation_id=str(invocation_id),
            request=serialize(request),
            response=serialize_response(response),
            start_time=start_time,
            end_time=get_current_ts(),
            cost=cost,
            is_cached=is_cached,
        )
        record.kwargs = {
            "client_id": client_id,
            "wrapper_id": wrapper_id,
        }

        self.log(record)

    def log_function_use(
        self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any
    ) -> None:
        """
        Logs a registered function(can be a tool) use from an agent or a string source.

        The start time is the agent's last ``check_termination_and_human_reply``
        timestamp recorded by ``log_event``, or the current time if none exists.
        """
        if not self.logger:
            return

        source_id = id(source)
        if source_id in self.last_agent_checks:
            start_time = self.last_agent_checks[source_id]
        else:
            start_time = get_current_ts()

        record = self.new_record(Events.TOOL_CALL, source=source)
        record.data = ToolCallData(
            tool_name=function.__name__,
            start_time=start_time,
            end_time=record.timestamp,
            agent_name=str(source.name) if hasattr(source, "name") else source,
            agent_module=source.__module__,
            agent_class=source.__class__.__name__,
            input_args=safe_serialize(args),
            returns=safe_serialize(returns),
        )

        self.log(record)

    def log_new_agent(
        self, agent: ConversableAgent, init_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Logs a new agent instance.

        ``init_args`` defaults to an empty dict.
        """
        if not self.logger:
            return
        # Avoid a shared mutable default argument.
        if init_args is None:
            init_args = {}

        record = self.new_record(event_name=Events.NEW_AGENT, source=agent)
        record.data = AgentData(
            agent_name=(
                agent.name
                if hasattr(agent, "name") and agent.name is not None
                else str(agent)
            ),
            agent_module=agent.__module__,
            agent_class=agent.__class__.__name__,
            is_manager=isinstance(agent, GroupChatManager),
        )
        record.kwargs = {
            "wrapper_id": serialize(
                agent.client.wrapper_id
                if hasattr(agent, "client") and agent.client is not None
                else ""
            ),
            "args": serialize(init_args),
        }
        self.log(record)

    def log_event(
        self, source: Union[str, Agent], name: str, **kwargs: Any
    ) -> None:
        """
        Logs an event.

        For agent sources, the timestamp of a ``check_termination_and_human_reply``
        reply is remembered per agent so that ``log_function_use`` can compute
        the tool call start time.
        """
        record = self.new_record(event_name=name)
        record.source_id = id(source)
        record.source_name = str(source.name) if hasattr(source, "name") else source
        record.kwargs = kwargs
        if isinstance(source, Agent):
            if (
                CONST_REPLY_FUNC_NAME in kwargs
                and kwargs[CONST_REPLY_FUNC_NAME] == "check_termination_and_human_reply"
            ):
                self.last_agent_checks[record.source_id] = record.timestamp
            record.data = AgentData(
                agent_name=record.source_name,
                agent_module=source.__module__,
                agent_class=source.__class__.__name__,
                is_manager=isinstance(source, GroupChatManager),
            )
        self.log(record)

    def log_new_wrapper(self, *args, **kwargs) -> None:
        # Do not log new wrapper.
        # This is not used at the moment.
        return

    def log_new_client(
        self,
        client,
        wrapper: OpenAIWrapper,
        init_args: Dict[str, Any],
    ) -> None:
        """Logs a new LLM client instance."""
        if not self.logger:
            return

        record = self.new_record(event_name=Events.NEW_CLIENT)
        # init_args may contain credentials like api_key;
        # serialize() excludes the "api_key" key by default.
        record.kwargs = {
            "client_id": id(client),
            "wrapper_id": id(wrapper),
            "class": client.__class__.__name__,
            "args": serialize(init_args),
        }

        self.log(record)

    def __repr__(self) -> str:
        return self.session.__repr__()

    def __enter__(self) -> "SessionLogger":
        """Starts the session logger

        Returns
        -------
        SessionLogger
            The session logger
        """
        runtime_logging.start(self)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Stops the session logger.

        If the block raised, an exception record (type, message, traceback and
        the serialized locals of the frame) is logged first. Returns None, so
        the exception is not suppressed.
        """
        if exc_type:
            record = self.new_record(event_name=Events.EXCEPTION)
            record.kwargs = {
                "exc_type": exc_type.__name__,
                "exc_value": str(exc_value),
                "traceback": "".join(traceback.format_tb(tb)),
                "locals": serialize(tb.tb_frame.f_locals),
            }
            self.log(record)
        runtime_logging.stop(self)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
import inspect
|
5
|
+
import json
|
6
|
+
from types import SimpleNamespace
|
7
|
+
from typing import Any, Dict, List, Tuple, Union
|
8
|
+
|
9
|
+
|
10
|
+
def is_json_serializable(obj: Any) -> bool:
    """Tell whether ``json.dumps`` can encode an object with default settings.

    Parameters
    ----------
    obj : Any
        The object to probe.

    Returns
    -------
    bool
        False if ``json.dumps(obj)`` raises any exception, True otherwise.
    """
    try:
        json.dumps(obj)
        serializable = True
    except Exception:
        serializable = False
    return serializable
|
28
|
+
|
29
|
+
|
30
|
+
def serialize_response(response) -> Any:
    """Serializes the LLM response to JSON-compatible data (usually a dictionary).

    - ``SimpleNamespace`` responses (possibly nested) and other JSON-serializable
      responses are round-tripped through JSON. Nested objects with attributes
      become dicts; other values JSON cannot encode become their string form.
    - Objects exposing a callable ``dict()`` (e.g. pydantic models) are
      converted through it.
    - Objects with ``model`` and ``choices`` attributes are summarized as
      ``{"model", "choices", "response"}``.
    - Anything else is wrapped in the same structure using ``str(response)``.
    """

    def _to_jsonable(obj: Any) -> Any:
        # Called by json.dumps for values it cannot encode natively.
        try:
            return vars(obj)
        except TypeError:
            # No __dict__ (e.g. set, datetime): fall back to the string form
            # instead of failing the whole serialization.
            return str(obj)

    if isinstance(response, SimpleNamespace) or is_json_serializable(response):
        # Convert SimpleNamespace to dict
        return json.loads(json.dumps(response, default=_to_jsonable))
    elif hasattr(response, "dict") and callable(response.dict):
        return json.loads(json.dumps(response.dict(), default=str))
    elif hasattr(response, "model") and hasattr(response, "choices"):
        return {
            "model": response.model,
            "choices": [
                {"message": {"content": choice.message.content}}
                for choice in response.choices
            ],
            "response": str(response),
        }
    # Reached only for objects JSON cannot encode, so store their string form.
    return {
        "model": "",
        "choices": [{"message": {"content": str(response)}}],
        "response": str(response),
    }
|
51
|
+
|
52
|
+
|
53
|
+
def serialize(
    obj: Union[int, float, str, bool, Dict[Any, Any], List[Any], Tuple[Any, ...], Any],
    exclude: Tuple[str, ...] = ("api_key", "__class__"),
    no_recursive: Tuple[Any, ...] = (),
) -> Any:
    """Convert an object into loggable, JSON-friendly data.

    Scalars (int/float/str/bool) pass through. Callables become their
    stripped source code. Dicts are walked with keys stringified and any
    key listed in ``exclude`` dropped. Lists and tuples become lists.
    Nested values whose type is in ``no_recursive`` are stringified instead
    of walked. Everything else, and any failure, becomes ``str(obj)``.
    """

    def _convert(value: Any) -> Any:
        # Types listed in no_recursive are flattened to a string, not traversed.
        if isinstance(value, no_recursive):
            return serialize(str(value))
        return serialize(value, exclude, no_recursive)

    try:
        if isinstance(obj, (int, float, str, bool)):
            return obj
        if callable(obj):
            return inspect.getsource(obj).strip()
        if isinstance(obj, dict):
            return {
                str(key): _convert(val)
                for key, val in obj.items()
                if key not in exclude
            }
        if isinstance(obj, (list, tuple)):
            return list(map(_convert, obj))
        return str(obj)
    except Exception:
        return str(obj)
|