ragaai-catalyst 2.1.4.1b0__py3-none-any.whl → 2.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ragaai_catalyst/__init__.py +23 -2
  2. ragaai_catalyst/dataset.py +462 -1
  3. ragaai_catalyst/evaluation.py +76 -7
  4. ragaai_catalyst/ragaai_catalyst.py +52 -10
  5. ragaai_catalyst/redteaming/__init__.py +7 -0
  6. ragaai_catalyst/redteaming/config/detectors.toml +13 -0
  7. ragaai_catalyst/redteaming/data_generator/scenario_generator.py +95 -0
  8. ragaai_catalyst/redteaming/data_generator/test_case_generator.py +120 -0
  9. ragaai_catalyst/redteaming/evaluator.py +125 -0
  10. ragaai_catalyst/redteaming/llm_generator.py +136 -0
  11. ragaai_catalyst/redteaming/llm_generator_old.py +83 -0
  12. ragaai_catalyst/redteaming/red_teaming.py +331 -0
  13. ragaai_catalyst/redteaming/requirements.txt +4 -0
  14. ragaai_catalyst/redteaming/tests/grok.ipynb +97 -0
  15. ragaai_catalyst/redteaming/tests/stereotype.ipynb +2258 -0
  16. ragaai_catalyst/redteaming/upload_result.py +38 -0
  17. ragaai_catalyst/redteaming/utils/issue_description.py +114 -0
  18. ragaai_catalyst/redteaming/utils/rt.png +0 -0
  19. ragaai_catalyst/redteaming_old.py +171 -0
  20. ragaai_catalyst/synthetic_data_generation.py +400 -22
  21. ragaai_catalyst/tracers/__init__.py +17 -1
  22. ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +4 -2
  23. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +212 -148
  24. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +657 -247
  25. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +50 -19
  26. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +588 -177
  27. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +99 -100
  28. ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +3 -3
  29. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +230 -29
  30. ragaai_catalyst/tracers/agentic_tracing/upload/trace_uploader.py +358 -0
  31. ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +75 -20
  32. ragaai_catalyst/tracers/agentic_tracing/upload/upload_code.py +55 -11
  33. ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +74 -0
  34. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +47 -16
  35. ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +4 -2
  36. ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +26 -3
  37. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +182 -17
  38. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1233 -497
  39. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +81 -10
  40. ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml +34 -0
  41. ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py +215 -0
  42. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
  43. ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
  44. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +73 -47
  45. ragaai_catalyst/tracers/distributed.py +300 -0
  46. ragaai_catalyst/tracers/exporters/__init__.py +3 -1
  47. ragaai_catalyst/tracers/exporters/dynamic_trace_exporter.py +160 -0
  48. ragaai_catalyst/tracers/exporters/ragaai_trace_exporter.py +129 -0
  49. ragaai_catalyst/tracers/langchain_callback.py +809 -0
  50. ragaai_catalyst/tracers/llamaindex_instrumentation.py +424 -0
  51. ragaai_catalyst/tracers/tracer.py +301 -55
  52. ragaai_catalyst/tracers/upload_traces.py +24 -7
  53. ragaai_catalyst/tracers/utils/convert_langchain_callbacks_output.py +61 -0
  54. ragaai_catalyst/tracers/utils/convert_llama_instru_callback.py +69 -0
  55. ragaai_catalyst/tracers/utils/extraction_logic_llama_index.py +74 -0
  56. ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +82 -0
  57. ragaai_catalyst/tracers/utils/model_prices_and_context_window_backup.json +9365 -0
  58. ragaai_catalyst/tracers/utils/trace_json_converter.py +269 -0
  59. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/METADATA +367 -45
  60. ragaai_catalyst-2.1.5.dist-info/RECORD +97 -0
  61. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/WHEEL +1 -1
  62. ragaai_catalyst-2.1.4.1b0.dist-info/RECORD +0 -67
  63. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/LICENSE +0 -0
  64. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py (new file)
@@ -0,0 +1,74 @@
+import logging
+import os
+import requests
+
+from ragaai_catalyst import RagaAICatalyst
+
+logger = logging.getLogger(__name__)
+logging_level = (
+    logger.setLevel(logging.DEBUG)
+    if os.getenv("DEBUG")
+    else logger.setLevel(logging.INFO)
+)
+
+
+def calculate_metric(project_id, metric_name, model, provider, **kwargs):
+    user_id = "1"
+    org_domain = "raga"
+
+    headers = {
+        "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+        "X-Project-Id": str(project_id),
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "data": [
+            {
+                "metric_name": metric_name,
+                "metric_config": {
+                    "threshold": {
+                        "isEditable": True,
+                        "lte": 0.3
+                    },
+                    "model": model,
+                    "orgDomain": org_domain,
+                    "provider": provider,
+                    "user_id": user_id,
+                    "job_id": 1,
+                    "metric_name": metric_name,
+                    "request_id": 1
+                },
+                "variable_mapping": kwargs,
+                "trace_object": {
+                    "Data": {
+                        "DocId": "doc-1",
+                        "Prompt": kwargs.get("prompt"),
+                        "Response": kwargs.get("response"),
+                        "Context": kwargs.get("context"),
+                        "ExpectedResponse": kwargs.get("expected_response"),
+                        "ExpectedContext": kwargs.get("expected_context"),
+                        "Chat": kwargs.get("chat"),
+                        "Instructions": kwargs.get("instructions"),
+                        "SystemPrompt": kwargs.get("system_prompt"),
+                        "Text": kwargs.get("text")
+                    },
+                    "claims": {},
+                    "last_computed_metrics": {
+                        metric_name: {
+                        }
+                    }
+                }
+            }
+        ]
+    }
+
+    try:
+        BASE_URL = RagaAICatalyst.BASE_URL
+        response = requests.post(f"{BASE_URL}/v1/llm/calculate-metric", headers=headers, json=payload, timeout=30)
+        logger.debug(f"Metric calculation response status {response.status_code}")
+        response.raise_for_status()
+        return response.json()
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"Error in calculate-metric api: {e}, payload: {payload}")
+        raise Exception(f"Error in calculate-metric: {e}")
ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py
@@ -1,27 +1,40 @@
+import logging
+
 import requests
 import os
 import json
+import time
 from ....ragaai_catalyst import RagaAICatalyst
 from ..utils.get_user_trace_metrics import get_user_trace_metrics
 
-def upload_trace_metric(json_file_path, dataset_name, project_name):
+logger = logging.getLogger(__name__)
+logging_level = (
+    logger.setLevel(logging.DEBUG)
+    if os.getenv("DEBUG")
+    else logger.setLevel(logging.INFO)
+)
+
+
+def upload_trace_metric(json_file_path, dataset_name, project_name, base_url=None):
     try:
         with open(json_file_path, "r") as f:
             traces = json.load(f)
+
         metrics = get_trace_metrics_from_trace(traces)
         metrics = _change_metrics_format_for_payload(metrics)
 
         user_trace_metrics = get_user_trace_metrics(project_name, dataset_name)
         if user_trace_metrics:
             user_trace_metrics_list = [metric["displayName"] for metric in user_trace_metrics]
-
+
         if user_trace_metrics:
             for metric in metrics:
                 if metric["displayName"] in user_trace_metrics_list:
-                    metricConfig = next((user_metric["metricConfig"] for user_metric in user_trace_metrics if user_metric["displayName"] == metric["displayName"]), None)
+                    metricConfig = next((user_metric["metricConfig"] for user_metric in user_trace_metrics if
+                                         user_metric["displayName"] == metric["displayName"]), None)
                     if not metricConfig or metricConfig.get("Metric Source", {}).get("value") != "user":
-                        raise ValueError(f"Metrics {metric['displayName']} already exist in dataset {dataset_name} of project {project_name}.")
-
+                        raise ValueError(
+                            f"Metrics {metric['displayName']} already exist in dataset {dataset_name} of project {project_name}.")
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
@@ -31,11 +44,17 @@ def upload_trace_metric(json_file_path, dataset_name, project_name):
            "datasetName": dataset_name,
            "metrics": metrics
        })
-        response = requests.request("POST",
-                                    f"{RagaAICatalyst.BASE_URL}/v1/llm/trace/metrics",
-                                    headers=headers,
+        url_base = base_url if base_url is not None else RagaAICatalyst.BASE_URL
+        start_time = time.time()
+        endpoint = f"{url_base}/v1/llm/trace/metrics"
+        response = requests.request("POST",
+                                    endpoint,
+                                    headers=headers,
                                     data=payload,
                                     timeout=10)
+        elapsed_ms = (time.time() - start_time) * 1000
+        logger.debug(
+            f"API Call: [POST] {endpoint} | Status: {response.status_code} | Time: {elapsed_ms:.2f}ms")
        if response.status_code != 200:
            raise ValueError(f"Error inserting agentic trace metrics")
    except requests.exceptions.RequestException as e:
@@ -59,25 +78,37 @@ def _get_children_metrics_of_agent(children_traces):
 
 def get_trace_metrics_from_trace(traces):
     metrics = []
+
+    # get trace level metrics
+    if "metrics" in traces.keys():
+        if len(traces["metrics"]) > 0:
+            metrics.extend(traces["metrics"])
+
+    # get span level metrics
     for span in traces["data"][0]["spans"]:
         if span["type"] == "agent":
+            # Add children metrics of agent
            children_metric = _get_children_metrics_of_agent(span["data"]["children"])
            if children_metric:
                metrics.extend(children_metric)
-        else:
-            metric = span.get("metrics", [])
-            if metric:
-                metrics.extend(metric)
+
+        metric = span.get("metrics", [])
+        if metric:
+            metrics.extend(metric)
     return metrics
 
+
 def _change_metrics_format_for_payload(metrics):
     formatted_metrics = []
     for metric in metrics:
-        if any(m["name"] == metric["name"] for m in formatted_metrics):
+        if any(m["name"] == metric.get("displayName") or m['name'] == metric.get("name") for m in formatted_metrics):
            continue
+        metric_display_name = metric["name"]
+        if metric.get("displayName"):
+            metric_display_name = metric['displayName']
        formatted_metrics.append({
-            "name": metric["name"],
-            "displayName": metric["name"],
+            "name": metric_display_name,
+            "displayName": metric_display_name,
            "config": {"source": "user"},
        })
-    return formatted_metrics
+    return formatted_metrics
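The reworked `_change_metrics_format_for_payload` now keys the payload on a metric's `displayName` when present and skips a metric if its name or display name already appears in the formatted list. A self-contained sketch of that behavior (the metric dicts are invented sample data):

```python
# Standalone re-implementation of the display-name handling, for illustration.
metrics = [
    {"name": "faithfulness", "displayName": "Faithfulness v2"},
    {"name": "Faithfulness v2"},   # collides with an emitted displayName -> skipped
    {"name": "context_recall"},
]

formatted_metrics = []
for metric in metrics:
    if any(m["name"] == metric.get("displayName") or m["name"] == metric.get("name")
           for m in formatted_metrics):
        continue
    display = metric.get("displayName") or metric["name"]
    formatted_metrics.append({"name": display, "displayName": display,
                              "config": {"source": "user"}})

print([m["name"] for m in formatted_metrics])  # ['Faithfulness v2', 'context_recall']
```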
ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py
@@ -4,7 +4,7 @@ import re
 import requests
 from ragaai_catalyst.tracers.agentic_tracing.tracers.base import RagaAICatalyst
 
-def create_dataset_schema_with_trace(project_name, dataset_name):
+def create_dataset_schema_with_trace(project_name, dataset_name, base_url=None):
     def make_request():
         headers = {
             "Content-Type": "application/json",
@@ -15,8 +15,10 @@ def create_dataset_schema_with_trace(project_name, dataset_name):
            "datasetName": dataset_name,
            "traceFolderUrl": None,
        })
+        # Use provided base_url or fall back to default
+        url_base = base_url if base_url is not None else RagaAICatalyst.BASE_URL
        response = requests.request("POST",
-            f"{RagaAICatalyst.BASE_URL}/v1/llm/dataset/logs",
+            f"{url_base}/v1/llm/dataset/logs",
            headers=headers,
            data=payload,
            timeout=10
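The same `base_url=None` parameter is threaded through `upload_trace_metric`, `create_dataset_schema_with_trace`, and the other upload helpers in this release. The pattern is a plain fallback, sketched below with a made-up default URL standing in for RagaAICatalyst.BASE_URL:

```python
# Minimal sketch of the fallback; the default URL is a placeholder only.
DEFAULT_BASE_URL = "https://catalyst.example.com/api"

def resolve_base_url(base_url=None):
    # Caller-supplied URL wins; otherwise use the configured default.
    return base_url if base_url is not None else DEFAULT_BASE_URL

assert resolve_base_url() == DEFAULT_BASE_URL
assert resolve_base_url("http://localhost:8080") == "http://localhost:8080"
```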
ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py
@@ -8,13 +8,32 @@ class TrackName:
     def trace_decorator(self, func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            file_name = self._get_file_name()
+            file_name = self._get_decorated_file_name()
             self.files.add(file_name)
 
             return func(*args, **kwargs)
         return wrapper
+
+    def trace_wrapper(self, func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            file_name = self._get_wrapped_file_name()
+            self.files.add(file_name)
+            return func(*args, **kwargs)
+        return wrapper
+
+    def _get_wrapped_file_name(self):
+        try:
+            from IPython import get_ipython
+            if 'IPKernelApp' in get_ipython().config:
+                return self._get_notebook_name()
+        except Exception:
+            pass
+
+        frame = inspect.stack()[4]
+        return frame.filename
 
-    def _get_file_name(self):
+    def _get_decorated_file_name(self):
         # Check if running in a Jupyter notebook
         try:
             from IPython import get_ipython
@@ -43,4 +62,8 @@ class TrackName:
 
     def reset(self):
         """Reset the file tracker by clearing all tracked files."""
-        self.files.clear()
+        self.files.clear()
+
+    def trace_main_file(self):
+        frame = inspect.stack()[-1]
+        self.files.add(frame.filename)
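Both `_get_wrapped_file_name` and the new `trace_main_file` lean on `inspect.stack()` frame records: index 0 is the current frame, small positive indices walk toward the caller (the hard-coded `[4]` assumes a fixed wrapper depth), and `[-1]` is the outermost frame, i.e. the entry-point script. A toy illustration (not package code):

```python
import inspect

def report_frames():
    caller = inspect.stack()[1]   # immediate caller of this function
    entry = inspect.stack()[-1]   # outermost frame: the main script/module
    return caller.filename, entry.filename

def some_wrapper():
    return report_frames()

# Both usually print this file's path when run as a plain script.
print(some_wrapper())
```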
ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py
@@ -2,14 +2,30 @@ from ..data.data_structure import LLMCall
 from .trace_utils import (
     calculate_cost,
     convert_usage_to_dict,
-    load_model_costs,
 )
 from importlib import resources
+#from litellm import model_cost
 import json
 import os
 import asyncio
 import psutil
+import tiktoken
+import logging
 
+logger = logging.getLogger(__name__)
+
+def get_model_cost():
+    """Load model costs from a JSON file.
+    Note: This file should be updated periodically or whenever a new package is created to ensure accurate cost calculations.
+    To Do: Implement to do this automatically.
+    """
+    file="model_prices_and_context_window_backup.json"
+    d={}
+    with resources.open_text("ragaai_catalyst.tracers.utils", file) as f:
+        d= json.load(f)
+    return d
+
+model_cost = get_model_cost()
 
 def extract_model_name(args, kwargs, result):
     """Extract model name from kwargs or result"""
@@ -35,7 +51,18 @@ def extract_model_name(args, kwargs, result):
             metadata = manager.metadata
             model_name = metadata.get('ls_model_name', None)
             if model_name:
-                model = model_name
+                model = model_name
+
+    if not model:
+        if 'to_dict' in dir(result):
+            result = result.to_dict()
+            if 'model_version' in result:
+                model = result['model_version']
+    try:
+        if not model:
+            model = result.raw.model
+    except Exception as e:
+        pass
 
 
     # Normalize Google model names
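Together with the `response_metadata` branch in the next hunk, `extract_model_name` now tries, in order: explicit kwargs and LangChain manager metadata, a `to_dict()` payload carrying `model_version` (Vertex AI style), `result.raw.model` (LlamaIndex ChatResponse style), and finally `response_metadata['model_name']` (LangChain message style). A toy walk-through of that chain with stand-in objects (none of these are real SDK types):

```python
class FakeRaw:
    model = "gpt-4o-mini"

class FakeChatResponse:          # LlamaIndex-style: model name hangs off .raw
    raw = FakeRaw()

class FakeLCMessage:             # LangChain-style: model name in response_metadata
    response_metadata = {"model_name": "gpt-4o"}

def model_from(result):
    model = None
    if hasattr(result, "to_dict"):
        model = result.to_dict().get("model_version")
    if not model and hasattr(result, "raw"):
        model = getattr(result.raw, "model", None)
    if not model and hasattr(result, "response_metadata"):
        model = result.response_metadata.get("model_name")
    return model or "default"

print(model_from(FakeChatResponse()))  # gpt-4o-mini
print(model_from(FakeLCMessage()))     # gpt-4o
print(model_from(object()))            # default
```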
@@ -48,10 +75,9 @@
     if "gemini-pro" in model:
         return "gemini-pro"
 
-    if 'to_dict' in dir(result):
-        result = result.to_dict()
-        if 'model_version' in result:
-            model = result['model_version']
+    if 'response_metadata' in dir(result):
+        if 'model_name' in result.response_metadata:
+            model = result.response_metadata['model_name']
 
     return model or "default"
 
 
@@ -67,6 +93,9 @@ def extract_parameters(kwargs):
     # Remove messages key in parameters (OpenAI message)
     if 'messages' in parameters:
         del parameters['messages']
+
+    if 'run_manager' in parameters:
+        del parameters['run_manager']
 
     if 'generation_config' in parameters:
         generation_config = parameters['generation_config']
@@ -91,8 +120,8 @@ def extract_token_usage(result):
         # Run the coroutine in the current event loop
         result = loop.run_until_complete(result)
 
-    # Handle text attribute responses (JSON string or Vertex AI)
-    if hasattr(result, "text"):
+    # Handle text attribute responses (JSON string for Vertex AI)
+    if hasattr(result, "text") and isinstance(result.text, (str, bytes, bytearray)):
         # First try parsing as JSON for OpenAI responses
         try:
             import json
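The added `isinstance` check guards against response objects whose `text` attribute is a method or non-string property, which `hasattr` alone cannot distinguish. A short demonstration with an invented class:

```python
# Invented class for illustration: .text exists but is callable, not a str.
class MethodText:
    def text(self):
        return "{}"

r = MethodText()
print(hasattr(r, "text"))                            # True  (old check passes)
print(isinstance(r.text, (str, bytes, bytearray)))   # False (new check filters it out)
```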
@@ -137,10 +166,34 @@
     # Handle Google GenerativeAI format with usage_metadata
     if hasattr(result, "usage_metadata"):
         metadata = result.usage_metadata
+        if hasattr(metadata, "prompt_token_count"):
+            return {
+                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                "total_tokens": getattr(metadata, "total_token_count", 0)
+            }
+        elif hasattr(metadata, "input_tokens"):
+            return {
+                "prompt_tokens": getattr(metadata, "input_tokens", 0),
+                "completion_tokens": getattr(metadata, "output_tokens", 0),
+                "total_tokens": getattr(metadata, "total_tokens", 0)
+            }
+        elif "input_tokens" in metadata:
+            return {
+                "prompt_tokens": metadata["input_tokens"],
+                "completion_tokens": metadata["output_tokens"],
+                "total_tokens": metadata["total_tokens"]
+            }
+
+
+
+    # Handle ChatResponse format with raw usuage
+    if hasattr(result, "raw") and hasattr(result.raw, "usage"):
+        usage = result.raw.usage
         return {
-            "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
-            "completion_tokens": getattr(metadata, "candidates_token_count", 0),
-            "total_tokens": getattr(metadata, "total_token_count", 0)
+            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+            "completion_tokens": getattr(usage, "completion_tokens", 0),
+            "total_tokens": getattr(usage, "total_tokens", 0)
         }
 
     # Handle ChatResult format with generations
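`extract_token_usage` now recognizes three `usage_metadata` shapes plus LlamaIndex's `raw.usage`. The sketch below feeds SimpleNamespace stand-ins (not real SDK objects) through the same branch order to show each shape normalizing to the same triple:

```python
from types import SimpleNamespace

# Stand-ins for the three usage_metadata shapes handled above.
google_style = SimpleNamespace(prompt_token_count=10, candidates_token_count=5,
                               total_token_count=15)
tokens_as_attrs = SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15)
tokens_as_dict = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

def normalize(metadata):
    if hasattr(metadata, "prompt_token_count"):
        return {"prompt_tokens": metadata.prompt_token_count,
                "completion_tokens": metadata.candidates_token_count,
                "total_tokens": metadata.total_token_count}
    if hasattr(metadata, "input_tokens"):
        return {"prompt_tokens": metadata.input_tokens,
                "completion_tokens": metadata.output_tokens,
                "total_tokens": metadata.total_tokens}
    return {"prompt_tokens": metadata["input_tokens"],
            "completion_tokens": metadata["output_tokens"],
            "total_tokens": metadata["total_tokens"]}

for m in (google_style, tokens_as_attrs, tokens_as_dict):
    assert normalize(m) == {"prompt_tokens": 10, "completion_tokens": 5,
                            "total_tokens": 15}
```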
@@ -173,24 +226,129 @@
         "total_tokens": 0
     }
 
+def num_tokens_from_messages(model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None):
+    """Calculate the number of tokens used by messages.
+
+    Args:
+        messages: Optional list of messages (deprecated, use prompt_messages and response_message instead)
+        model: The model name to use for token calculation
+        prompt_messages: List of prompt messages
+        response_message: Response message from the assistant
+
+    Returns:
+        dict: A dictionary containing:
+            - prompt_tokens: Number of tokens in the prompt
+            - completion_tokens: Number of tokens in the completion
+            - total_tokens: Total number of tokens
+    """
+    #import pdb; pdb.set_trace()
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logging.warning("Warning: model not found. Using o200k_base encoding.")
+        encoding = tiktoken.get_encoding("o200k_base")
+
+    if model in {
+        "gpt-3.5-turbo-0125",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+        "gpt-4o-2024-08-06",
+        "gpt-4o-mini-2024-07-18"
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif "gpt-3.5-turbo" in model:
+        logging.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
+        return num_tokens_from_messages(model="gpt-3.5-turbo-0125",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4o-mini" in model:
+        logging.warning("Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.")
+        return num_tokens_from_messages(model="gpt-4o-mini-2024-07-18",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4o" in model:
+        logging.warning("Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.")
+        return num_tokens_from_messages(model="gpt-4o-2024-08-06",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4" in model:
+        logging.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return num_tokens_from_messages(model="gpt-4-0613",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}."""
+        )
+
+    all_messages = []
+    if prompt_messages:
+        all_messages.extend(prompt_messages)
+    if response_message:
+        if isinstance(response_message, dict):
+            all_messages.append(response_message)
+        else:
+            all_messages.append({"role": "assistant", "content": response_message})
+
+    prompt_tokens = 0
+    completion_tokens = 0
+
+    for message in all_messages:
+        num_tokens = tokens_per_message
+        for key, value in message.items():
+            token_count = len(encoding.encode(str(value)))  # Convert value to string for safety
+            num_tokens += token_count
+            if key == "name":
+                num_tokens += tokens_per_name
+
+        # Add tokens to prompt or completion based on role
+        if message.get("role") == "assistant":
+            completion_tokens += num_tokens
+        else:
+            prompt_tokens += num_tokens
+
+    # Add the assistant message prefix tokens to completion tokens if we have a response
+    if completion_tokens > 0:
+        completion_tokens += 3  # <|start|>assistant<|message|>
+
+    total_tokens = prompt_tokens + completion_tokens
+
+    return {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens
+    }
 
 def extract_input_data(args, kwargs, result):
-    """Extract input data from function call"""
+    """Sanitize and format input data, including handling of nested lists and dictionaries."""
+
+    def sanitize_value(value):
+        if isinstance(value, (int, float, bool, str)):
+            return value
+        elif isinstance(value, list):
+            return [sanitize_value(item) for item in value]
+        elif isinstance(value, dict):
+            return {key: sanitize_value(val) for key, val in value.items()}
+        else:
+            return str(value)  # Convert non-standard types to string
+
     return {
-        'args': args,
-        'kwargs': kwargs
+        "args": [sanitize_value(arg) for arg in args],
+        "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
     }
 
 
-def calculate_llm_cost(token_usage, model_name, model_costs):
+def calculate_llm_cost(token_usage, model_name, model_costs, model_custom_cost=None):
     """Calculate cost based on token usage and model"""
+    if model_custom_cost is None:
+        model_custom_cost = {}
+    model_costs.update(model_custom_cost)
     if not isinstance(token_usage, dict):
         token_usage = {
             "prompt_tokens": 0,
             "completion_tokens": 0,
             "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
         }
-
+
     # Get model costs, defaulting to default costs if unknown
     model_cost = model_cost = model_costs.get(model_name, {
         "input_cost_per_token": 0.0,
@@ -277,6 +435,13 @@ def extract_llm_output(result):
         })
         return OutputResponse(output)
 
+    # Handle AIMessage Format
+    if hasattr(result, "content"):
+        return OutputResponse([{
+            "content": result.content,
+            "role": getattr(result, "role", "assistant")
+        }])
+
     # Handle Vertex AI format
     # format1
     if hasattr(result, "text"):
@@ -424,7 +589,7 @@ def extract_llm_data(args, kwargs, result):
     token_usage = extract_token_usage(result)
 
     # Load model costs
-    model_costs = load_model_costs()
+    model_costs = model_cost
 
     # Calculate cost
     cost = calculate_llm_cost(token_usage, model_name, model_costs)