ragaai-catalyst 2.1.4.1b0__py3-none-any.whl → 2.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +23 -2
- ragaai_catalyst/dataset.py +462 -1
- ragaai_catalyst/evaluation.py +76 -7
- ragaai_catalyst/ragaai_catalyst.py +52 -10
- ragaai_catalyst/redteaming/__init__.py +7 -0
- ragaai_catalyst/redteaming/config/detectors.toml +13 -0
- ragaai_catalyst/redteaming/data_generator/scenario_generator.py +95 -0
- ragaai_catalyst/redteaming/data_generator/test_case_generator.py +120 -0
- ragaai_catalyst/redteaming/evaluator.py +125 -0
- ragaai_catalyst/redteaming/llm_generator.py +136 -0
- ragaai_catalyst/redteaming/llm_generator_old.py +83 -0
- ragaai_catalyst/redteaming/red_teaming.py +331 -0
- ragaai_catalyst/redteaming/requirements.txt +4 -0
- ragaai_catalyst/redteaming/tests/grok.ipynb +97 -0
- ragaai_catalyst/redteaming/tests/stereotype.ipynb +2258 -0
- ragaai_catalyst/redteaming/upload_result.py +38 -0
- ragaai_catalyst/redteaming/utils/issue_description.py +114 -0
- ragaai_catalyst/redteaming/utils/rt.png +0 -0
- ragaai_catalyst/redteaming_old.py +171 -0
- ragaai_catalyst/synthetic_data_generation.py +400 -22
- ragaai_catalyst/tracers/__init__.py +17 -1
- ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +4 -2
- ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +212 -148
- ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +657 -247
- ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +50 -19
- ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +588 -177
- ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +99 -100
- ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +3 -3
- ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +230 -29
- ragaai_catalyst/tracers/agentic_tracing/upload/trace_uploader.py +358 -0
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +75 -20
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_code.py +55 -11
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +74 -0
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +47 -16
- ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +4 -2
- ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +26 -3
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +182 -17
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1233 -497
- ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +81 -10
- ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml +34 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py +215 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
- ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
- ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +73 -47
- ragaai_catalyst/tracers/distributed.py +300 -0
- ragaai_catalyst/tracers/exporters/__init__.py +3 -1
- ragaai_catalyst/tracers/exporters/dynamic_trace_exporter.py +160 -0
- ragaai_catalyst/tracers/exporters/ragaai_trace_exporter.py +129 -0
- ragaai_catalyst/tracers/langchain_callback.py +809 -0
- ragaai_catalyst/tracers/llamaindex_instrumentation.py +424 -0
- ragaai_catalyst/tracers/tracer.py +301 -55
- ragaai_catalyst/tracers/upload_traces.py +24 -7
- ragaai_catalyst/tracers/utils/convert_langchain_callbacks_output.py +61 -0
- ragaai_catalyst/tracers/utils/convert_llama_instru_callback.py +69 -0
- ragaai_catalyst/tracers/utils/extraction_logic_llama_index.py +74 -0
- ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +82 -0
- ragaai_catalyst/tracers/utils/model_prices_and_context_window_backup.json +9365 -0
- ragaai_catalyst/tracers/utils/trace_json_converter.py +269 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/METADATA +367 -45
- ragaai_catalyst-2.1.5.dist-info/RECORD +97 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/WHEEL +1 -1
- ragaai_catalyst-2.1.4.1b0.dist-info/RECORD +0 -67
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/LICENSE +0 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py (new file)

@@ -0,0 +1,74 @@
+import logging
+import os
+import requests
+
+from ragaai_catalyst import RagaAICatalyst
+
+logger = logging.getLogger(__name__)
+logging_level = (
+    logger.setLevel(logging.DEBUG)
+    if os.getenv("DEBUG")
+    else logger.setLevel(logging.INFO)
+)
+
+
+def calculate_metric(project_id, metric_name, model, provider, **kwargs):
+    user_id = "1"
+    org_domain = "raga"
+
+    headers = {
+        "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+        "X-Project-Id": str(project_id),
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "data": [
+            {
+                "metric_name": metric_name,
+                "metric_config": {
+                    "threshold": {
+                        "isEditable": True,
+                        "lte": 0.3
+                    },
+                    "model": model,
+                    "orgDomain": org_domain,
+                    "provider": provider,
+                    "user_id": user_id,
+                    "job_id": 1,
+                    "metric_name": metric_name,
+                    "request_id": 1
+                },
+                "variable_mapping": kwargs,
+                "trace_object": {
+                    "Data": {
+                        "DocId": "doc-1",
+                        "Prompt": kwargs.get("prompt"),
+                        "Response": kwargs.get("response"),
+                        "Context": kwargs.get("context"),
+                        "ExpectedResponse": kwargs.get("expected_response"),
+                        "ExpectedContext": kwargs.get("expected_context"),
+                        "Chat": kwargs.get("chat"),
+                        "Instructions": kwargs.get("instructions"),
+                        "SystemPrompt": kwargs.get("system_prompt"),
+                        "Text": kwargs.get("text")
+                    },
+                    "claims": {},
+                    "last_computed_metrics": {
+                        metric_name: {
+                        }
+                    }
+                }
+            }
+        ]
+    }
+
+    try:
+        BASE_URL = RagaAICatalyst.BASE_URL
+        response = requests.post(f"{BASE_URL}/v1/llm/calculate-metric", headers=headers, json=payload, timeout=30)
+        logger.debug(f"Metric calculation response status {response.status_code}")
+        response.raise_for_status()
+        return response.json()
+    except requests.exceptions.RequestException as e:
+        logger.debug(f"Error in calculate-metric api: {e}, payload: {payload}")
+        raise Exception(f"Error in calculate-metric: {e}")
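The new `upload_local_metric.py` module posts a single metric-evaluation job to the `/v1/llm/calculate-metric` endpoint, mapping keyword arguments such as `prompt`, `response`, and `context` into the `trace_object`. A minimal usage sketch (the project id and metric name are hypothetical placeholders; a valid `RAGAAI_CATALYST_TOKEN` is assumed to already be set):

```python
import os

from ragaai_catalyst.tracers.agentic_tracing.upload.upload_local_metric import (
    calculate_metric,
)

# The module reads the bearer token from the environment when building headers.
assert os.getenv("RAGAAI_CATALYST_TOKEN"), "authenticate with Catalyst first"

# project_id=12 and the metric name are placeholders, not values from this diff.
result = calculate_metric(
    project_id=12,
    metric_name="hallucination",
    model="gpt-4o-mini",
    provider="openai",
    prompt="What is the capital of France?",
    response="The capital of France is Paris.",
    context="France is a country in Europe. Its capital is Paris.",
)
print(result)  # parsed JSON body returned by the metric service
```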
ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py

@@ -1,27 +1,40 @@
+import logging
+
 import requests
 import os
 import json
+import time
 from ....ragaai_catalyst import RagaAICatalyst
 from ..utils.get_user_trace_metrics import get_user_trace_metrics
 
+logger = logging.getLogger(__name__)
+logging_level = (
+    logger.setLevel(logging.DEBUG)
+    if os.getenv("DEBUG")
+    else logger.setLevel(logging.INFO)
+)
+
+
-def upload_trace_metric(json_file_path, dataset_name, project_name):
+def upload_trace_metric(json_file_path, dataset_name, project_name, base_url=None):
     try:
         with open(json_file_path, "r") as f:
             traces = json.load(f)
+
        metrics = get_trace_metrics_from_trace(traces)
        metrics = _change_metrics_format_for_payload(metrics)
 
        user_trace_metrics = get_user_trace_metrics(project_name, dataset_name)
        if user_trace_metrics:
            user_trace_metrics_list = [metric["displayName"] for metric in user_trace_metrics]
-
+
        if user_trace_metrics:
            for metric in metrics:
                if metric["displayName"] in user_trace_metrics_list:
-                    metricConfig = next((user_metric["metricConfig"] for user_metric in user_trace_metrics if
+                    metricConfig = next((user_metric["metricConfig"] for user_metric in user_trace_metrics if
+                                         user_metric["displayName"] == metric["displayName"]), None)
                    if not metricConfig or metricConfig.get("Metric Source", {}).get("value") != "user":
-                        raise ValueError(
-
+                        raise ValueError(
+                            f"Metrics {metric['displayName']} already exist in dataset {dataset_name} of project {project_name}.")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",

@@ -31,11 +44,17 @@ def upload_trace_metric(json_file_path, dataset_name, project_name):
            "datasetName": dataset_name,
            "metrics": metrics
        })
-
-
-
+        url_base = base_url if base_url is not None else RagaAICatalyst.BASE_URL
+        start_time = time.time()
+        endpoint = f"{url_base}/v1/llm/trace/metrics"
+        response = requests.request("POST",
+                                    endpoint,
+                                    headers=headers,
                                     data=payload,
                                     timeout=10)
+        elapsed_ms = (time.time() - start_time) * 1000
+        logger.debug(
+            f"API Call: [POST] {endpoint} | Status: {response.status_code} | Time: {elapsed_ms:.2f}ms")
        if response.status_code != 200:
            raise ValueError(f"Error inserting agentic trace metrics")
    except requests.exceptions.RequestException as e:

@@ -59,25 +78,37 @@ def _get_children_metrics_of_agent(children_traces):
 
 def get_trace_metrics_from_trace(traces):
     metrics = []
+
+    # get trace level metrics
+    if "metrics" in traces.keys():
+        if len(traces["metrics"]) > 0:
+            metrics.extend(traces["metrics"])
+
+    # get span level metrics
     for span in traces["data"][0]["spans"]:
         if span["type"] == "agent":
+            # Add children metrics of agent
             children_metric = _get_children_metrics_of_agent(span["data"]["children"])
             if children_metric:
                 metrics.extend(children_metric)
-
-
-
-
+
+        metric = span.get("metrics", [])
+        if metric:
+            metrics.extend(metric)
     return metrics
 
+
 def _change_metrics_format_for_payload(metrics):
     formatted_metrics = []
     for metric in metrics:
-        if any(m["name"] == metric["name"
+        if any(m["name"] == metric.get("displayName") or m['name'] == metric.get("name") for m in formatted_metrics):
             continue
+        metric_display_name = metric["name"]
+        if metric.get("displayName"):
+            metric_display_name = metric['displayName']
         formatted_metrics.append({
-            "name":
-            "displayName":
+            "name": metric_display_name,
+            "displayName": metric_display_name,
             "config": {"source": "user"},
         })
-    return formatted_metrics
+    return formatted_metrics
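The rewritten `_change_metrics_format_for_payload` now dedupes metrics by either `name` or `displayName` and prefers `displayName` for both output fields. A standalone restatement of that logic with sample input (illustration only, not imported from the package):

```python
def change_metrics_format_for_payload(metrics):
    # Mirrors the diffed helper: skip a metric if its name or displayName is
    # already present, prefer displayName, and tag the source as "user".
    formatted_metrics = []
    for metric in metrics:
        if any(m["name"] == metric.get("displayName") or m["name"] == metric.get("name")
               for m in formatted_metrics):
            continue
        display_name = metric.get("displayName") or metric["name"]
        formatted_metrics.append({
            "name": display_name,
            "displayName": display_name,
            "config": {"source": "user"},
        })
    return formatted_metrics

metrics = [
    {"name": "faithfulness"},
    {"name": "faithfulness"},                                  # duplicate: skipped
    {"name": "relevance", "displayName": "answer_relevance"},  # displayName wins
]
print(change_metrics_format_for_payload(metrics))
# [{'name': 'faithfulness', ...}, {'name': 'answer_relevance', ...}]
```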
ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py

@@ -4,7 +4,7 @@ import re
 import requests
 from ragaai_catalyst.tracers.agentic_tracing.tracers.base import RagaAICatalyst
 
-def create_dataset_schema_with_trace(project_name, dataset_name):
+def create_dataset_schema_with_trace(project_name, dataset_name, base_url=None):
     def make_request():
         headers = {
             "Content-Type": "application/json",

@@ -15,8 +15,10 @@ def create_dataset_schema_with_trace(project_name, dataset_name):
             "datasetName": dataset_name,
             "traceFolderUrl": None,
         })
+        # Use provided base_url or fall back to default
+        url_base = base_url if base_url is not None else RagaAICatalyst.BASE_URL
         response = requests.request("POST",
-            f"{
+            f"{url_base}/v1/llm/dataset/logs",
             headers=headers,
             data=payload,
             timeout=10
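The same optional `base_url` pattern used in the metrics upload appears here: the caller's value wins, otherwise `RagaAICatalyst.BASE_URL` is used. A hedged usage sketch (the self-hosted URL is a placeholder and authentication is assumed to be configured):

```python
from ragaai_catalyst.tracers.agentic_tracing.utils.create_dataset_schema import (
    create_dataset_schema_with_trace,
)

# Default: the request resolves against RagaAICatalyst.BASE_URL.
create_dataset_schema_with_trace("my-project", "my-dataset")

# Override: target a self-hosted deployment instead (placeholder URL).
create_dataset_schema_with_trace(
    "my-project", "my-dataset", base_url="https://catalyst.example.com/api"
)
```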
ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py

@@ -8,13 +8,32 @@ class TrackName:
     def trace_decorator(self, func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            file_name = self.
+            file_name = self._get_decorated_file_name()
             self.files.add(file_name)
 
             return func(*args, **kwargs)
         return wrapper
+
+    def trace_wrapper(self, func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            file_name = self._get_wrapped_file_name()
+            self.files.add(file_name)
+            return func(*args, **kwargs)
+        return wrapper
+
+    def _get_wrapped_file_name(self):
+        try:
+            from IPython import get_ipython
+            if 'IPKernelApp' in get_ipython().config:
+                return self._get_notebook_name()
+        except Exception:
+            pass
+
+        frame = inspect.stack()[4]
+        return frame.filename
 
-    def
+    def _get_decorated_file_name(self):
         # Check if running in a Jupyter notebook
         try:
             from IPython import get_ipython

@@ -43,4 +62,8 @@ class TrackName:
 
     def reset(self):
         """Reset the file tracker by clearing all tracked files."""
-        self.files.clear()
+        self.files.clear()
+
+    def trace_main_file(self):
+        frame = inspect.stack()[-1]
+        self.files.add(frame.filename)
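`TrackName` gains a `trace_wrapper` counterpart to `trace_decorator`, plus `trace_main_file` for recording the entry-point script. A short usage sketch based only on the methods visible in this diff:

```python
from ragaai_catalyst.tracers.agentic_tracing.utils.file_name_tracker import TrackName

tracker = TrackName()

@tracker.trace_decorator
def my_step():
    return "done"

my_step()                  # records the file that defines my_step
tracker.trace_main_file()  # new in 2.1.5: records the outermost (entry-point) frame
print(tracker.files)       # set of tracked file paths
tracker.reset()            # clear all tracked files
```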
ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py

@@ -2,14 +2,30 @@ from ..data.data_structure import LLMCall
 from .trace_utils import (
     calculate_cost,
     convert_usage_to_dict,
-    load_model_costs,
 )
 from importlib import resources
+#from litellm import model_cost
 import json
 import os
 import asyncio
 import psutil
+import tiktoken
+import logging
 
+logger = logging.getLogger(__name__)
+
+def get_model_cost():
+    """Load model costs from a JSON file.
+    Note: This file should be updated periodically or whenever a new package is created to ensure accurate cost calculations.
+    To Do: Implement to do this automatically.
+    """
+    file="model_prices_and_context_window_backup.json"
+    d={}
+    with resources.open_text("ragaai_catalyst.tracers.utils", file) as f:
+        d= json.load(f)
+    return d
+
+model_cost = get_model_cost()
 
 def extract_model_name(args, kwargs, result):
     """Extract model name from kwargs or result"""

@@ -35,7 +51,18 @@ def extract_model_name(args, kwargs, result):
             metadata = manager.metadata
             model_name = metadata.get('ls_model_name', None)
             if model_name:
-                model = model_name
+                model = model_name
+
+    if not model:
+        if 'to_dict' in dir(result):
+            result = result.to_dict()
+            if 'model_version' in result:
+                model = result['model_version']
+    try:
+        if not model:
+            model = result.raw.model
+    except Exception as e:
+        pass
 
 
     # Normalize Google model names

@@ -48,10 +75,9 @@ def extract_model_name(args, kwargs, result):
     if "gemini-pro" in model:
         return "gemini-pro"
 
-    if '
-
-
-        model = result['model_version']
+    if 'response_metadata' in dir(result):
+        if 'model_name' in result.response_metadata:
+            model = result.response_metadata['model_name']
 
     return model or "default"
 

@@ -67,6 +93,9 @@ def extract_parameters(kwargs):
     # Remove messages key in parameters (OpenAI message)
     if 'messages' in parameters:
         del parameters['messages']
+
+    if 'run_manager' in parameters:
+        del parameters['run_manager']
 
     if 'generation_config' in parameters:
         generation_config = parameters['generation_config']

@@ -91,8 +120,8 @@ def extract_token_usage(result):
         # Run the coroutine in the current event loop
         result = loop.run_until_complete(result)
 
-    # Handle text attribute responses (JSON string
-    if hasattr(result, "text"):
+    # Handle text attribute responses (JSON string for Vertex AI)
+    if hasattr(result, "text") and isinstance(result.text, (str, bytes, bytearray)):
         # First try parsing as JSON for OpenAI responses
         try:
             import json

@@ -137,10 +166,34 @@ def extract_token_usage(result):
     # Handle Google GenerativeAI format with usage_metadata
     if hasattr(result, "usage_metadata"):
         metadata = result.usage_metadata
+        if hasattr(metadata, "prompt_token_count"):
+            return {
+                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                "total_tokens": getattr(metadata, "total_token_count", 0)
+            }
+        elif hasattr(metadata, "input_tokens"):
+            return {
+                "prompt_tokens": getattr(metadata, "input_tokens", 0),
+                "completion_tokens": getattr(metadata, "output_tokens", 0),
+                "total_tokens": getattr(metadata, "total_tokens", 0)
+            }
+        elif "input_tokens" in metadata:
+            return {
+                "prompt_tokens": metadata["input_tokens"],
+                "completion_tokens": metadata["output_tokens"],
+                "total_tokens": metadata["total_tokens"]
+            }
+
+
+
+    # Handle ChatResponse format with raw usuage
+    if hasattr(result, "raw") and hasattr(result.raw, "usage"):
+        usage = result.raw.usage
         return {
-            "prompt_tokens": getattr(
-            "completion_tokens": getattr(
-            "total_tokens": getattr(
+            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+            "completion_tokens": getattr(usage, "completion_tokens", 0),
+            "total_tokens": getattr(usage, "total_tokens", 0)
         }
 
     # Handle ChatResult format with generations
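The added branches normalize three `usage_metadata` shapes: Google GenerativeAI attributes (`prompt_token_count`/`candidates_token_count`), attribute-style `input_tokens`/`output_tokens`, and plain dicts keyed by `input_tokens`. A self-contained sketch of that normalization:

```python
from types import SimpleNamespace

def normalize_usage(metadata):
    # Mirrors the three branches added in the diff above.
    if hasattr(metadata, "prompt_token_count"):  # Google GenerativeAI objects
        return {
            "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
            "completion_tokens": getattr(metadata, "candidates_token_count", 0),
            "total_tokens": getattr(metadata, "total_token_count", 0),
        }
    if hasattr(metadata, "input_tokens"):        # attribute-style usage objects
        return {
            "prompt_tokens": metadata.input_tokens,
            "completion_tokens": metadata.output_tokens,
            "total_tokens": metadata.total_tokens,
        }
    if isinstance(metadata, dict) and "input_tokens" in metadata:  # dict style
        return {
            "prompt_tokens": metadata["input_tokens"],
            "completion_tokens": metadata["output_tokens"],
            "total_tokens": metadata["total_tokens"],
        }
    return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

gemini = SimpleNamespace(prompt_token_count=12, candidates_token_count=5,
                         total_token_count=17)
print(normalize_usage(gemini))
# {'prompt_tokens': 12, 'completion_tokens': 5, 'total_tokens': 17}
```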
ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py (continued)

@@ -173,24 +226,129 @@ def extract_token_usage(result):
         "total_tokens": 0
     }
 
+def num_tokens_from_messages(model="gpt-4o-mini-2024-07-18", prompt_messages=None, response_message=None):
+    """Calculate the number of tokens used by messages.
+
+    Args:
+        messages: Optional list of messages (deprecated, use prompt_messages and response_message instead)
+        model: The model name to use for token calculation
+        prompt_messages: List of prompt messages
+        response_message: Response message from the assistant
+
+    Returns:
+        dict: A dictionary containing:
+            - prompt_tokens: Number of tokens in the prompt
+            - completion_tokens: Number of tokens in the completion
+            - total_tokens: Total number of tokens
+    """
+    #import pdb; pdb.set_trace()
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logging.warning("Warning: model not found. Using o200k_base encoding.")
+        encoding = tiktoken.get_encoding("o200k_base")
+
+    if model in {
+        "gpt-3.5-turbo-0125",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+        "gpt-4o-2024-08-06",
+        "gpt-4o-mini-2024-07-18"
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif "gpt-3.5-turbo" in model:
+        logging.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
+        return num_tokens_from_messages(model="gpt-3.5-turbo-0125",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4o-mini" in model:
+        logging.warning("Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.")
+        return num_tokens_from_messages(model="gpt-4o-mini-2024-07-18",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4o" in model:
+        logging.warning("Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.")
+        return num_tokens_from_messages(model="gpt-4o-2024-08-06",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    elif "gpt-4" in model:
+        logging.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return num_tokens_from_messages(model="gpt-4-0613",
+                                        prompt_messages=prompt_messages, response_message=response_message)
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}."""
+        )
+
+    all_messages = []
+    if prompt_messages:
+        all_messages.extend(prompt_messages)
+    if response_message:
+        if isinstance(response_message, dict):
+            all_messages.append(response_message)
+        else:
+            all_messages.append({"role": "assistant", "content": response_message})
+
+    prompt_tokens = 0
+    completion_tokens = 0
+
+    for message in all_messages:
+        num_tokens = tokens_per_message
+        for key, value in message.items():
+            token_count = len(encoding.encode(str(value)))  # Convert value to string for safety
+            num_tokens += token_count
+            if key == "name":
+                num_tokens += tokens_per_name
+
+        # Add tokens to prompt or completion based on role
+        if message.get("role") == "assistant":
+            completion_tokens += num_tokens
+        else:
+            prompt_tokens += num_tokens
+
+    # Add the assistant message prefix tokens to completion tokens if we have a response
+    if completion_tokens > 0:
+        completion_tokens += 3  # <|start|>assistant<|message|>
+
+    total_tokens = prompt_tokens + completion_tokens
+
+    return {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens
+    }
 
 def extract_input_data(args, kwargs, result):
-    """
+    """Sanitize and format input data, including handling of nested lists and dictionaries."""
+
+    def sanitize_value(value):
+        if isinstance(value, (int, float, bool, str)):
+            return value
+        elif isinstance(value, list):
+            return [sanitize_value(item) for item in value]
+        elif isinstance(value, dict):
+            return {key: sanitize_value(val) for key, val in value.items()}
+        else:
+            return str(value)  # Convert non-standard types to string
+
     return {
-
-
+        "args": [sanitize_value(arg) for arg in args],
+        "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
     }
 
 
-def calculate_llm_cost(token_usage, model_name, model_costs):
+def calculate_llm_cost(token_usage, model_name, model_costs, model_custom_cost=None):
     """Calculate cost based on token usage and model"""
+    if model_custom_cost is None:
+        model_custom_cost = {}
+    model_costs.update(model_custom_cost)
     if not isinstance(token_usage, dict):
         token_usage = {
             "prompt_tokens": 0,
             "completion_tokens": 0,
             "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
         }
-
+
     # Get model costs, defaulting to default costs if unknown
     model_cost = model_cost = model_costs.get(model_name, {
         "input_cost_per_token": 0.0,
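The new `num_tokens_from_messages` counts tokens with `tiktoken`, using the familiar per-message overhead scheme (3 tokens per message, 1 per `name` field, 3 for the assistant reply prefix). A hedged usage sketch (requires `tiktoken`; exact counts depend on the encoding):

```python
from ragaai_catalyst.tracers.agentic_tracing.utils.llm_utils import (
    num_tokens_from_messages,
)

usage = num_tokens_from_messages(
    model="gpt-4o-mini-2024-07-18",
    prompt_messages=[{"role": "user", "content": "What is the capital of France?"}],
    response_message="The capital of France is Paris.",
)
print(usage)  # {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}
```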
ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py (continued)

@@ -277,6 +435,13 @@ def extract_llm_output(result):
         })
         return OutputResponse(output)
 
+    # Handle AIMessage Format
+    if hasattr(result, "content"):
+        return OutputResponse([{
+            "content": result.content,
+            "role": getattr(result, "role", "assistant")
+        }])
+
     # Handle Vertex AI format
     # format1
     if hasattr(result, "text"):

@@ -424,7 +589,7 @@ def extract_llm_data(args, kwargs, result):
     token_usage = extract_token_usage(result)
 
     # Load model costs
-    model_costs =
+    model_costs = model_cost
 
     # Calculate cost
     cost = calculate_llm_cost(token_usage, model_name, model_costs)
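With `load_model_costs` gone, `extract_llm_data` now reads the module-level `model_cost` table loaded from `model_prices_and_context_window_backup.json`, and `calculate_llm_cost` merges an optional `model_custom_cost` mapping over that table so local prices win. A sketch of the underlying cost arithmetic, with made-up per-token prices:

```python
# Hypothetical prices; the real table ships in
# model_prices_and_context_window_backup.json.
custom_cost = {
    "my-finetuned-model": {
        "input_cost_per_token": 1.0e-06,
        "output_cost_per_token": 2.0e-06,
    }
}
usage = {"prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500}

entry = custom_cost["my-finetuned-model"]
# Cost = prompt tokens * input price + completion tokens * output price.
cost = (usage["prompt_tokens"] * entry["input_cost_per_token"]
        + usage["completion_tokens"] * entry["output_cost_per_token"])
print(f"${cost:.6f}")  # $0.001800
```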