ragaai-catalyst 2.2.2__py3-none-any.whl → 2.2.3b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,17 +2,23 @@
 trace_uploader.py - A dedicated process for handling trace uploads
 """
 
-import argparse
-import atexit
-import concurrent.futures
+import os
+import sys
 import json
+import time
+import signal
 import logging
-import os
+import argparse
 import tempfile
-import time
+from pathlib import Path
+import multiprocessing
+import queue
 from datetime import datetime
+import atexit
+import glob
 from logging.handlers import RotatingFileHandler
-from typing import Any, Dict
+import concurrent.futures
+from typing import Dict, Any, Optional
 
 # Set up logging
 log_dir = os.path.join(tempfile.gettempdir(), "ragaai_logs")
@@ -37,16 +43,11 @@ logging.basicConfig(
 logger = logging.getLogger("trace_uploader")
 
 try:
-    from ragaai_catalyst import RagaAICatalyst
-    from ragaai_catalyst.tracers.agentic_tracing.upload.upload_agentic_traces import (
-        UploadAgenticTraces,
-    )
+    from ragaai_catalyst.tracers.agentic_tracing.upload.upload_agentic_traces import UploadAgenticTraces
     from ragaai_catalyst.tracers.agentic_tracing.upload.upload_code import upload_code
-
     # from ragaai_catalyst.tracers.agentic_tracing.upload.upload_trace_metric import upload_trace_metric
-    from ragaai_catalyst.tracers.agentic_tracing.utils.create_dataset_schema import (
-        create_dataset_schema_with_trace,
-    )
+    from ragaai_catalyst.tracers.agentic_tracing.utils.create_dataset_schema import create_dataset_schema_with_trace
+    from ragaai_catalyst import RagaAICatalyst
     IMPORTS_AVAILABLE = True
 except ImportError:
     logger.warning("RagaAI Catalyst imports not available - running in test mode")
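The reshuffled imports above feed the same guarded-import pattern as before: the uploader sets IMPORTS_AVAILABLE depending on whether the Catalyst modules can be loaded, and otherwise degrades to a test mode. A minimal sketch of that pattern (illustrative only; upload_trace and its body are hypothetical stand-ins, not the packaged code):

import logging

logger = logging.getLogger("trace_uploader")

try:
    from ragaai_catalyst import RagaAICatalyst  # real dependency import
    IMPORTS_AVAILABLE = True
except ImportError:
    logger.warning("RagaAI Catalyst imports not available - running in test mode")
    IMPORTS_AVAILABLE = False


def upload_trace(payload):
    # In test mode there is nothing to upload; callers simply get None back.
    if not IMPORTS_AVAILABLE:
        return None
    # A real implementation would hand the payload to the upload helpers here.
    return payload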
@@ -1,6 +1,6 @@
 import asyncio
 
-#from litellm import model_cost
+# from litellm import model_cost
 import json
 import logging
 import os
@@ -12,24 +12,26 @@ from ..data.data_structure import LLMCall
 
 logger = logging.getLogger(__name__)
 
+
 def get_model_cost():
-    """Load model costs from a JSON file.
+    """Load model costs from a JSON file.
     Note: This file should be updated periodically or whenever a new package is created to ensure accurate cost calculations.
     To Do: Implement to do this automatically.
     """
-    file="model_prices_and_context_window_backup.json"
-    d={}
+    file = "model_prices_and_context_window_backup.json"
+    d = {}
     with resources.open_text("ragaai_catalyst.tracers.utils", file) as f:
-        d= json.load(f)
-    return d
+        d = json.load(f)
+    return d
+
 
 model_cost = get_model_cost()
 
+
 def extract_model_name(args, kwargs, result):
     """Extract model name from kwargs or result"""
     # First try direct model parameter
     model = kwargs.get("model", "")
-
     if not model:
         # Try to get from instance
         instance = kwargs.get("self", None)
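For context, get_model_cost() simply loads the litellm-style pricing table that ships inside the package via importlib.resources. A small standalone sketch of the same lookup (package and file names are taken from the hunk above; the example model key is an assumption about the bundled JSON):

import json
from importlib import resources

def load_model_costs():
    # Read the bundled pricing JSON exactly as get_model_cost() does.
    file = "model_prices_and_context_window_backup.json"
    with resources.open_text("ragaai_catalyst.tracers.utils", file) as f:
        return json.load(f)

costs = load_model_costs()
entry = costs.get("gpt-4o-mini", {})
print(entry.get("input_cost_per_token"), entry.get("output_cost_per_token"))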
@@ -45,24 +47,23 @@ def extract_model_name(args, kwargs, result):
     if not model:
         manager = kwargs.get("run_manager", None)
         if manager:
-            if hasattr(manager, 'metadata'):
+            if hasattr(manager, "metadata"):
                 metadata = manager.metadata
-                model_name = metadata.get('ls_model_name', None)
+                model_name = metadata.get("ls_model_name", None)
                 if model_name:
-                    model = model_name
-
+                    model = model_name
+
     if not model:
-        if 'to_dict' in dir(result):
+        if "to_dict" in dir(result):
             result = result.to_dict()
-            if 'model_version' in result:
-                model = result['model_version']
+            if "model_version" in result:
+                model = result["model_version"]
     try:
         if not model:
             model = result.raw.model
     except Exception:
         pass
-
-
+
     # Normalize Google model names
     if model and isinstance(model, str):
         model = model.lower()
@@ -73,10 +74,10 @@ def extract_model_name(args, kwargs, result):
         if "gemini-pro" in model:
             return "gemini-pro"
 
-    if 'response_metadata' in dir(result):
-        if 'model_name' in result.response_metadata:
-            model = result.response_metadata['model_name']
-
+    if "response_metadata" in dir(result):
+        if "model_name" in result.response_metadata:
+            model = result.response_metadata["model_name"]
+
     return model or "default"
 
 
@@ -85,27 +86,29 @@ def extract_parameters(kwargs):
     parameters = {k: v for k, v in kwargs.items() if v is not None}
 
     # Remove contents key in parameters (Google LLM Response)
-    if 'contents' in parameters:
-        del parameters['contents']
+    if "contents" in parameters:
+        del parameters["contents"]
 
     # Remove messages key in parameters (OpenAI message)
-    if 'messages' in parameters:
-        del parameters['messages']
-
-    if 'run_manager' in parameters:
-        del parameters['run_manager']
-
-    if 'generation_config' in parameters:
-        generation_config = parameters['generation_config']
+    if "messages" in parameters:
+        del parameters["messages"]
+
+    if "run_manager" in parameters:
+        del parameters["run_manager"]
+
+    if "generation_config" in parameters:
+        generation_config = parameters["generation_config"]
         # If generation_config is already a dict, use it directly
         if isinstance(generation_config, dict):
             config_dict = generation_config
         else:
             # Convert GenerationConfig to dictionary if it has a to_dict method, otherwise try to get its __dict__
-            config_dict = getattr(generation_config, 'to_dict', lambda: generation_config.__dict__)()
+            config_dict = getattr(
+                generation_config, "to_dict", lambda: generation_config.__dict__
+            )()
         parameters.update(config_dict)
-        del parameters['generation_config']
-
+        del parameters["generation_config"]
+
     return parameters
 
 
@@ -123,13 +126,14 @@ def extract_token_usage(result):
     # First try parsing as JSON for OpenAI responses
     try:
         import json
+
         json_data = json.loads(result.text)
         if isinstance(json_data, dict) and "usage" in json_data:
             usage = json_data["usage"]
             return {
                 "prompt_tokens": usage.get("prompt_tokens", 0),
                 "completion_tokens": usage.get("completion_tokens", 0),
-                "total_tokens": usage.get("total_tokens", 0)
+                "total_tokens": usage.get("total_tokens", 0),
             }
     except (json.JSONDecodeError, AttributeError):
         pass
@@ -142,7 +146,7 @@ def extract_token_usage(result):
         return {
             "prompt_tokens": 0, # Vertex AI doesn't provide this breakdown
             "completion_tokens": total_tokens,
-            "total_tokens": total_tokens
+            "total_tokens": total_tokens,
         }
 
     # Handle Claude 3 message format
@@ -152,15 +156,15 @@ def extract_token_usage(result):
             return {
                 "prompt_tokens": usage.input_tokens,
                 "completion_tokens": usage.output_tokens,
-                "total_tokens": usage.input_tokens + usage.output_tokens
+                "total_tokens": usage.input_tokens + usage.output_tokens,
             }
         # Handle standard OpenAI/Anthropic format
         return {
             "prompt_tokens": getattr(usage, "prompt_tokens", 0),
             "completion_tokens": getattr(usage, "completion_tokens", 0),
-            "total_tokens": getattr(usage, "total_tokens", 0)
+            "total_tokens": getattr(usage, "total_tokens", 0),
         }
-
+
     # Handle Google GenerativeAI format with usage_metadata
     if hasattr(result, "usage_metadata"):
         metadata = result.usage_metadata
@@ -168,37 +172,35 @@ def extract_token_usage(result):
             return {
                 "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                 "completion_tokens": getattr(metadata, "candidates_token_count", 0),
-                "total_tokens": getattr(metadata, "total_token_count", 0)
+                "total_tokens": getattr(metadata, "total_token_count", 0),
            }
         elif hasattr(metadata, "input_tokens"):
             return {
                 "prompt_tokens": getattr(metadata, "input_tokens", 0),
                 "completion_tokens": getattr(metadata, "output_tokens", 0),
-                "total_tokens": getattr(metadata, "total_tokens", 0)
+                "total_tokens": getattr(metadata, "total_tokens", 0),
             }
         elif "input_tokens" in metadata:
             return {
                 "prompt_tokens": metadata["input_tokens"],
                 "completion_tokens": metadata["output_tokens"],
-                "total_tokens": metadata["total_tokens"]
+                "total_tokens": metadata["total_tokens"],
             }
 
-
-
     # Handle ChatResponse format with raw usuage
     if hasattr(result, "raw") and hasattr(result.raw, "usage"):
         usage = result.raw.usage
         return {
             "prompt_tokens": getattr(usage, "prompt_tokens", 0),
             "completion_tokens": getattr(usage, "completion_tokens", 0),
-            "total_tokens": getattr(usage, "total_tokens", 0)
+            "total_tokens": getattr(usage, "total_tokens", 0),
         }
-
+
     # Handle ChatResult format with generations
     if hasattr(result, "generations") and result.generations:
         # Get the first generation
         generation = result.generations[0]
-
+
         # Try to get usage from generation_info
         if hasattr(generation, "generation_info"):
             metadata = generation.generation_info.get("usage_metadata", {})
@@ -206,46 +208,47 @@ def extract_token_usage(result):
                 return {
                     "prompt_tokens": metadata.get("prompt_token_count", 0),
                     "completion_tokens": metadata.get("candidates_token_count", 0),
-                    "total_tokens": metadata.get("total_token_count", 0)
+                    "total_tokens": metadata.get("total_token_count", 0),
                 }
-
+
         # Try to get usage from message's usage_metadata
-        if hasattr(generation, "message") and hasattr(generation.message, "usage_metadata"):
+        if hasattr(generation, "message") and hasattr(
+            generation.message, "usage_metadata"
+        ):
             metadata = generation.message.usage_metadata
             return {
                 "prompt_tokens": metadata.get("input_tokens", 0),
                 "completion_tokens": metadata.get("output_tokens", 0),
-                "total_tokens": metadata.get("total_tokens", 0)
+                "total_tokens": metadata.get("total_tokens", 0),
             }
-
-    return {
-        "prompt_tokens": 0,
-        "completion_tokens": 0,
-        "total_tokens": 0
-    }
 
-def num_tokens_from_messages(model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None):
+    return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
+
+def num_tokens_from_messages(
+    model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None
+):
     """Calculate the number of tokens used by messages.
-
+
     Args:
         messages: Optional list of messages (deprecated, use prompt_messages and response_message instead)
         model: The model name to use for token calculation
         prompt_messages: List of prompt messages
         response_message: Response message from the assistant
-
+
     Returns:
         dict: A dictionary containing:
             - prompt_tokens: Number of tokens in the prompt
             - completion_tokens: Number of tokens in the completion
            - total_tokens: Total number of tokens
     """
-    #import pdb; pdb.set_trace()
+    # import pdb; pdb.set_trace()
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
         logging.warning("Warning: model not found. Using o200k_base encoding.")
         encoding = tiktoken.get_encoding("o200k_base")
-
+
     if model in {
         "gpt-3.5-turbo-0125",
         "gpt-4-0314",
@@ -253,31 +256,51 @@ def num_tokens_from_messages(model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None):
         "gpt-4-0613",
         "gpt-4-32k-0613",
         "gpt-4o-2024-08-06",
-        "gpt-4o-mini-2024-07-18"
-    }:
+        "gpt-4o-mini-2024-07-18",
+    }:
         tokens_per_message = 3
         tokens_per_name = 1
     elif "gpt-3.5-turbo" in model:
-        logging.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
-        return num_tokens_from_messages(model="gpt-3.5-turbo-0125",
-                                        prompt_messages=prompt_messages, response_message=response_message)
+        logging.warning(
+            "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125."
+        )
+        return num_tokens_from_messages(
+            model="gpt-3.5-turbo-0125",
+            prompt_messages=prompt_messages,
+            response_message=response_message,
+        )
     elif "gpt-4o-mini" in model:
-        logging.warning("Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.")
-        return num_tokens_from_messages(model="gpt-4o-mini-2024-07-18",
-                                        prompt_messages=prompt_messages, response_message=response_message)
+        logging.warning(
+            "Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18."
+        )
+        return num_tokens_from_messages(
+            model="gpt-4o-mini-2024-07-18",
+            prompt_messages=prompt_messages,
+            response_message=response_message,
+        )
     elif "gpt-4o" in model:
-        logging.warning("Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.")
-        return num_tokens_from_messages(model="gpt-4o-2024-08-06",
-                                        prompt_messages=prompt_messages, response_message=response_message)
+        logging.warning(
+            "Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06."
+        )
+        return num_tokens_from_messages(
+            model="gpt-4o-2024-08-06",
+            prompt_messages=prompt_messages,
+            response_message=response_message,
+        )
     elif "gpt-4" in model:
-        logging.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-        return num_tokens_from_messages(model="gpt-4-0613",
-                                        prompt_messages=prompt_messages, response_message=response_message)
+        logging.warning(
+            "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
+        )
+        return num_tokens_from_messages(
+            model="gpt-4-0613",
+            prompt_messages=prompt_messages,
+            response_message=response_message,
+        )
     else:
         raise NotImplementedError(
             f"""num_tokens_from_messages() is not implemented for model {model}."""
         )
-
+
     all_messages = []
     if prompt_messages:
         all_messages.extend(prompt_messages)
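The two hunks around this point only reformat num_tokens_from_messages; its behaviour is unchanged. As a usage illustration, the per-role split below is what feeds prompt_tokens vs completion_tokens in the returned dict (made-up messages; assumes tiktoken is installed and the function above is in scope, since the module path is not shown in this diff):

prompt_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the trace upload flow."},
]
response_message = {"role": "assistant", "content": "Traces are queued and then uploaded."}

usage = num_tokens_from_messages(
    model="gpt-4o-mini-2024-07-18",
    prompt_messages=prompt_messages,
    response_message=response_message,
)
# usage is a dict with prompt_tokens, completion_tokens and total_tokens.
print(usage)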
@@ -286,36 +309,39 @@ def num_tokens_from_messages(model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None):
             all_messages.append(response_message)
         else:
             all_messages.append({"role": "assistant", "content": response_message})
-
+
     prompt_tokens = 0
     completion_tokens = 0
-
+
     for message in all_messages:
         num_tokens = tokens_per_message
         for key, value in message.items():
-            token_count = len(encoding.encode(str(value))) # Convert value to string for safety
+            token_count = len(
+                encoding.encode(str(value))
+            )  # Convert value to string for safety
             num_tokens += token_count
             if key == "name":
                 num_tokens += tokens_per_name
-
+
         # Add tokens to prompt or completion based on role
         if message.get("role") == "assistant":
             completion_tokens += num_tokens
         else:
             prompt_tokens += num_tokens
-
+
     # Add the assistant message prefix tokens to completion tokens if we have a response
     if completion_tokens > 0:
         completion_tokens += 3 # <|start|>assistant<|message|>
-
+
     total_tokens = prompt_tokens + completion_tokens
-
+
     return {
         "prompt_tokens": prompt_tokens,
         "completion_tokens": completion_tokens,
-        "total_tokens": total_tokens
+        "total_tokens": total_tokens,
     }
 
+
 def extract_input_data(args, kwargs, result):
     """Sanitize and format input data, including handling of nested lists and dictionaries."""
 
@@ -336,6 +362,7 @@ def extract_input_data(args, kwargs, result):
 
 
 def calculate_llm_cost(token_usage, model_name, model_costs, model_custom_cost=None):
+    model_name = extract_model_name([], {"model": model_name}, None)
     """Calculate cost based on token usage and model"""
     if model_custom_cost is None:
         model_custom_cost = {}
@@ -344,40 +371,51 @@ def calculate_llm_cost(token_usage, model_name, model_costs, model_custom_cost=None):
         token_usage = {
             "prompt_tokens": 0,
             "completion_tokens": 0,
-            "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
+            "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0,
         }
-
+
     # Get model costs, defaulting to default costs if unknown
-    model_cost = model_cost = model_costs.get(model_name, {
-        "input_cost_per_token": 0.0,
-        "output_cost_per_token": 0.0
-    })
-    if model_cost['input_cost_per_token'] == 0.0 and model_cost['output_cost_per_token'] == 0.0:
-        provide_name = model_name.split('-')[0]
-        if provide_name == 'azure':
-            model_name = os.path.join('azure', '-'.join(model_name.split('-')[1:]))
-
-        model_cost = model_costs.get(model_name, {
-            "input_cost_per_token": 0.0,
-            "output_cost_per_token": 0.0
-        })
-
-    input_cost = (token_usage.get("prompt_tokens", 0)) * model_cost.get("input_cost_per_token", 0.0)
-    output_cost = (token_usage.get("completion_tokens", 0)) * model_cost.get("output_cost_per_token", 0.0)
+    model_cost = model_cost = model_costs.get(
+        model_name, {"input_cost_per_token": 0.0, "output_cost_per_token": 0.0}
+    )
+    if (
+        model_cost["input_cost_per_token"] == 0.0
+        and model_cost["output_cost_per_token"] == 0.0
+    ):
+        provide_name = model_name.split("-")[0]
+        if provide_name == "azure":
+            model_name = os.path.join("azure", "-".join(model_name.split("-")[1:]))
+
+        model_cost = model_costs.get(
+            model_name, {"input_cost_per_token": 0.0, "output_cost_per_token": 0.0}
+        )
+
+    input_cost = (token_usage.get("prompt_tokens", 0)) * model_cost.get(
+        "input_cost_per_token", 0.0
+    )
+    output_cost = (token_usage.get("completion_tokens", 0)) * model_cost.get(
+        "output_cost_per_token", 0.0
+    )
     total_cost = input_cost + output_cost
 
     return {
         "input_cost": round(input_cost, 10),
         "output_cost": round(output_cost, 10),
-        "total_cost": round(total_cost, 10)
+        "total_cost": round(total_cost, 10),
     }
 
 
 def sanitize_api_keys(data):
     """Remove sensitive information from data"""
     if isinstance(data, dict):
-        return {k: sanitize_api_keys(v) for k, v in data.items()
-                if not any(sensitive in k.lower() for sensitive in ['key', 'token', 'secret', 'password'])}
+        return {
+            k: sanitize_api_keys(v)
+            for k, v in data.items()
+            if not any(
+                sensitive in k.lower()
+                for sensitive in ["key", "token", "secret", "password"]
+            )
+        }
     elif isinstance(data, list):
         return [sanitize_api_keys(item) for item in data]
     elif isinstance(data, tuple):
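The cost math itself is untouched by this reformatting: each side is tokens multiplied by the per-token price from the costs table. A hedged worked example, with illustrative prices only (real values come from the bundled JSON) and assuming calculate_llm_cost from this module is in scope:

token_usage = {"prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500}
model_costs = {
    "gpt-4o-mini": {
        # Illustrative prices, not the shipped values.
        "input_cost_per_token": 1.5e-07,
        "output_cost_per_token": 6.0e-07,
    }
}

cost = calculate_llm_cost(token_usage, "gpt-4o-mini", model_costs)
# input_cost  = 1200 * 1.5e-07 = 1.8e-04
# output_cost =  300 * 6.0e-07 = 1.8e-04
# total_cost  = 3.6e-04
print(cost)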
@@ -387,10 +425,10 @@ def sanitize_api_keys(data):
 
 def sanitize_input(args, kwargs):
     """Convert input arguments to text format.
-
+
     Args:
         args: Input arguments that may contain nested dictionaries
-
+
     Returns:
         str: Text representation of the input arguments
     """
@@ -403,6 +441,7 @@ def sanitize_input(args, kwargs):
 
 def extract_llm_output(result):
     """Extract output from LLM response"""
+
     class OutputResponse:
         def __init__(self, output_response):
             self.output_response = output_response
@@ -415,7 +454,9 @@ def extract_llm_output(result):
         else:
             # We're in an async context, but this function is called synchronously
             # Return a placeholder and let the caller handle the coroutine
-            return OutputResponse([{'content': "Coroutine result pending", "role": "assistant"}])
+            return OutputResponse(
+                [{"content": "Coroutine result pending", "role": "assistant"}]
+            )
 
     # Handle Google GenerativeAI format
     if hasattr(result, "result"):
@@ -426,56 +467,52 @@ def extract_llm_output(result):
             if content and hasattr(content, "parts"):
                 for part in content.parts:
                     if hasattr(part, "text"):
-                        output.append({
-                            "content": part.text,
-                            "role": getattr(content, "role", "assistant"),
-                            "finish_reason": getattr(candidate, "finish_reason", None)
-                        })
+                        output.append(
+                            {
+                                "content": part.text,
+                                "role": getattr(content, "role", "assistant"),
+                                "finish_reason": getattr(
+                                    candidate, "finish_reason", None
+                                ),
+                            }
+                        )
         return OutputResponse(output)
-
+
     # Handle AIMessage Format
     if hasattr(result, "content"):
-        return OutputResponse([{
-            "content": result.content,
-            "role": getattr(result, "role", "assistant")
-        }])
-
+        return OutputResponse(
+            [{"content": result.content, "role": getattr(result, "role", "assistant")}]
+        )
+
     # Handle Vertex AI format
     # format1
     if hasattr(result, "text"):
-        return OutputResponse([{
-            "content": result.text,
-            "role": "assistant"
-        }])
-
+        return OutputResponse([{"content": result.text, "role": "assistant"}])
 
     # format2
     if hasattr(result, "generations"):
         output = []
         for generation in result.generations:
-            output.append({
-                "content": generation.text,
-                "role": "assistant"
-            })
+            output.append({"content": generation.text, "role": "assistant"})
         return OutputResponse(output)
-
+
     # Handle OpenAI format
     if hasattr(result, "choices"):
-        return OutputResponse([{
-            "content": choice.message.content,
-            "role": choice.message.role
-        } for choice in result.choices])
-
+        return OutputResponse(
+            [
+                {"content": choice.message.content, "role": choice.message.role}
+                for choice in result.choices
+            ]
+        )
 
     # Handle Anthropic format
     if hasattr(result, "content"):
-        return OutputResponse([{
-            "content": result.content[0].text,
-            "role": "assistant"
-        }])
-
+        return OutputResponse(
+            [{"content": result.content[0].text, "role": "assistant"}]
+        )
+
     # Default case
-    return OutputResponse([{'content': result, 'role': 'assistant'}])
+    return OutputResponse([{"content": result, "role": "assistant"}])
 
 
 def extract_llm_data(args, kwargs, result):
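Every branch in the hunk above normalizes a provider response into an OutputResponse, i.e. a list of {"content", "role"} dicts. A small hedged illustration with a fake OpenAI-style result object (assumes extract_llm_output from this module is in scope; the fake object is purely for demonstration):

from types import SimpleNamespace

# Fake OpenAI-style response: result.choices[i].message.{content, role}
fake_result = SimpleNamespace(
    choices=[
        SimpleNamespace(message=SimpleNamespace(content="Hi there!", role="assistant"))
    ]
)

out = extract_llm_output(fake_result)
print(out.output_response)  # [{'content': 'Hi there!', 'role': 'assistant'}]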
@@ -608,12 +645,12 @@ def count_tokens(input_str: str) -> int:
     # Use tiktoken to count tokens
     try:
         import tiktoken
-
+
         # Use GPT-4o model's encoding (cl100k_base)
         encoding = tiktoken.get_encoding("cl100k_base")
-
+
         # Count tokens
         tokens = encoding.encode(input_str)
         return len(tokens)
     except Exception:
-        raise Exception("Failed to count tokens")
+        raise Exception("Failed to count tokens")
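For completeness, the count_tokens helper in this final hunk just wraps tiktoken's cl100k_base encoding; the whitespace-only changes above do not alter it. A minimal equivalent sketch (requires tiktoken):

import tiktoken

def count_tokens(input_str: str) -> int:
    # Count tokens with the cl100k_base encoding, as in the hunk above.
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(input_str))

print(count_tokens("Hello, world"))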