trailai 1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trail/__helpers.py +278 -0
- trail/__init__.py +1857 -0
- trail/core/datetime_utils.py +4 -0
- trail/core/pydantic_utilities.py +17 -0
- trail/evals/__init__.py +12 -0
- trail/evals/all.py +169 -0
- trail/evals/bias_detection.py +173 -0
- trail/evals/hallucination.py +170 -0
- trail/evals/toxicity.py +168 -0
- trail/evals/utils.py +275 -0
- trail/instrumentation/ag2/__init__.py +49 -0
- trail/instrumentation/ag2/ag2.py +163 -0
- trail/instrumentation/ai21/__init__.py +67 -0
- trail/instrumentation/ai21/ai21.py +191 -0
- trail/instrumentation/ai21/async_ai21.py +191 -0
- trail/instrumentation/ai21/utils.py +409 -0
- trail/instrumentation/anthropic/__init__.py +50 -0
- trail/instrumentation/anthropic/anthropic.py +149 -0
- trail/instrumentation/anthropic/async_anthropic.py +149 -0
- trail/instrumentation/anthropic/utils.py +251 -0
- trail/instrumentation/assemblyai/__init__.py +43 -0
- trail/instrumentation/assemblyai/assemblyai.py +150 -0
- trail/instrumentation/astra/__init__.py +178 -0
- trail/instrumentation/astra/astra.py +45 -0
- trail/instrumentation/astra/async_astra.py +45 -0
- trail/instrumentation/astra/utils.py +102 -0
- trail/instrumentation/azure_ai_inference/__init__.py +53 -0
- trail/instrumentation/azure_ai_inference/async_azure_ai_inference.py +144 -0
- trail/instrumentation/azure_ai_inference/azure_ai_inference.py +144 -0
- trail/instrumentation/azure_ai_inference/utils.py +225 -0
- trail/instrumentation/bedrock/__init__.py +42 -0
- trail/instrumentation/bedrock/bedrock.py +77 -0
- trail/instrumentation/bedrock/utils.py +252 -0
- trail/instrumentation/chroma/__init__.py +86 -0
- trail/instrumentation/chroma/chroma.py +199 -0
- trail/instrumentation/cohere/__init__.py +74 -0
- trail/instrumentation/cohere/async_cohere.py +610 -0
- trail/instrumentation/cohere/cohere.py +610 -0
- trail/instrumentation/controlflow/__init__.py +56 -0
- trail/instrumentation/controlflow/controlflow.py +113 -0
- trail/instrumentation/crawl4ai/__init__.py +52 -0
- trail/instrumentation/crawl4ai/async_crawl4ai.py +104 -0
- trail/instrumentation/crawl4ai/crawl4ai.py +104 -0
- trail/instrumentation/crewai/__init__.py +50 -0
- trail/instrumentation/crewai/crewai.py +153 -0
- trail/instrumentation/dynamiq/__init__.py +64 -0
- trail/instrumentation/dynamiq/dynamiq.py +110 -0
- trail/instrumentation/elevenlabs/__init__.py +70 -0
- trail/instrumentation/elevenlabs/async_elevenlabs.py +146 -0
- trail/instrumentation/elevenlabs/elevenlabs.py +147 -0
- trail/instrumentation/embedchain/__init__.py +55 -0
- trail/instrumentation/embedchain/embedchain.py +165 -0
- trail/instrumentation/firecrawl/__init__.py +49 -0
- trail/instrumentation/firecrawl/firecrawl.py +90 -0
- trail/instrumentation/google_ai_studio/__init__.py +56 -0
- trail/instrumentation/google_ai_studio/async_google_ai_studio.py +227 -0
- trail/instrumentation/google_ai_studio/google_ai_studio.py +227 -0
- trail/instrumentation/gpt4all/__init__.py +52 -0
- trail/instrumentation/gpt4all/gpt4all.py +497 -0
- trail/instrumentation/gpu/__init__.py +213 -0
- trail/instrumentation/groq/__init__.py +50 -0
- trail/instrumentation/groq/async_groq.py +467 -0
- trail/instrumentation/groq/groq.py +467 -0
- trail/instrumentation/haystack/__init__.py +49 -0
- trail/instrumentation/haystack/haystack.py +84 -0
- trail/instrumentation/julep/__init__.py +80 -0
- trail/instrumentation/julep/async_julep.py +111 -0
- trail/instrumentation/julep/julep.py +112 -0
- trail/instrumentation/langchain/__init__.py +118 -0
- trail/instrumentation/langchain/async_langchain.py +388 -0
- trail/instrumentation/langchain/langchain.py +362 -0
- trail/instrumentation/letta/__init__.py +77 -0
- trail/instrumentation/letta/letta.py +188 -0
- trail/instrumentation/litellm/__init__.py +67 -0
- trail/instrumentation/litellm/async_litellm.py +592 -0
- trail/instrumentation/litellm/litellm.py +592 -0
- trail/instrumentation/llamaindex/__init__.py +55 -0
- trail/instrumentation/llamaindex/llamaindex.py +86 -0
- trail/instrumentation/mem0/__init__.py +79 -0
- trail/instrumentation/mem0/mem0.py +115 -0
- trail/instrumentation/milvus/__init__.py +94 -0
- trail/instrumentation/milvus/milvus.py +179 -0
- trail/instrumentation/mistral/__init__.py +80 -0
- trail/instrumentation/mistral/async_mistral.py +611 -0
- trail/instrumentation/mistral/mistral.py +611 -0
- trail/instrumentation/multion/__init__.py +80 -0
- trail/instrumentation/multion/async_multion.py +133 -0
- trail/instrumentation/multion/multion.py +133 -0
- trail/instrumentation/ollama/__init__.py +84 -0
- trail/instrumentation/ollama/async_ollama.py +184 -0
- trail/instrumentation/ollama/ollama.py +184 -0
- trail/instrumentation/ollama/utils.py +332 -0
- trail/instrumentation/openai/__init__.py +132 -0
- trail/instrumentation/openai/async_openai.py +1411 -0
- trail/instrumentation/openai/openai.py +1411 -0
- trail/instrumentation/openai_agents/__init__.py +42 -0
- trail/instrumentation/openai_agents/openai_agents.py +65 -0
- trail/instrumentation/phidata/__init__.py +42 -0
- trail/instrumentation/phidata/phidata.py +100 -0
- trail/instrumentation/pinecone/__init__.py +66 -0
- trail/instrumentation/pinecone/pinecone.py +173 -0
- trail/instrumentation/premai/__init__.py +51 -0
- trail/instrumentation/premai/premai.py +556 -0
- trail/instrumentation/qdrant/__init__.py +295 -0
- trail/instrumentation/qdrant/async_qdrant.py +267 -0
- trail/instrumentation/qdrant/qdrant.py +274 -0
- trail/instrumentation/reka/__init__.py +54 -0
- trail/instrumentation/reka/async_reka.py +197 -0
- trail/instrumentation/reka/reka.py +197 -0
- trail/instrumentation/together/__init__.py +70 -0
- trail/instrumentation/together/async_together.py +600 -0
- trail/instrumentation/together/together.py +600 -0
- trail/instrumentation/transformers/__init__.py +37 -0
- trail/instrumentation/transformers/transformers.py +197 -0
- trail/instrumentation/vertexai/__init__.py +97 -0
- trail/instrumentation/vertexai/async_vertexai.py +459 -0
- trail/instrumentation/vertexai/vertexai.py +459 -0
- trail/instrumentation/vllm/__init__.py +43 -0
- trail/instrumentation/vllm/vllm.py +173 -0
- trail/model/dataset.py +1022 -0
- trail/otel/events.py +80 -0
- trail/otel/metrics.py +218 -0
- trail/otel/tracing.py +88 -0
- trail/semcov/__init__.py +310 -0
- trail/utils/experiment.py +106 -0
- trail/utils/utils.py +22 -0
- trailai-1.6.dist-info/METADATA +32 -0
- trailai-1.6.dist-info/RECORD +130 -0
- trailai-1.6.dist-info/WHEEL +4 -0
- trailai-1.6.dist-info/licenses/LICENSE +201 -0
trail/__helpers.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
# pylint: disable=bare-except, broad-exception-caught
|
|
2
|
+
"""
|
|
3
|
+
This module has functions to calculate model costs based on tokens and to fetch pricing information.
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
from typing import Any, Dict, List, Tuple
|
|
10
|
+
import math
|
|
11
|
+
import requests
|
|
12
|
+
from opentelemetry.sdk.resources import SERVICE_NAME, TELEMETRY_SDK_NAME, DEPLOYMENT_ENVIRONMENT
|
|
13
|
+
from opentelemetry.trace import Status, StatusCode
|
|
14
|
+
from opentelemetry._events import Event
|
|
15
|
+
from trail.semcov import SemanticConvetion
|
|
16
|
+
|
|
17
|
+
# Set up logging
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
def response_as_dict(response):
    """
    Convert a response object into a plain dict where possible.

    Dicts are returned unchanged. Objects exposing ``model_dump`` are dumped
    through it. Objects exposing ``parse`` are parsed first and the parsed
    result is converted recursively. Anything else is returned as-is.
    """

    # Guard clauses, checked in priority order.
    if isinstance(response, dict):
        return response
    if hasattr(response, 'model_dump'):
        return response.model_dump()
    if not hasattr(response, 'parse'):
        return response

    # Raw wrapper: unwrap it, then convert whatever it contained.
    return response_as_dict(response.parse())
|
|
34
|
+
|
|
35
|
+
def get_env_variable(name, arg_value, error_message):
    """
    Resolve a setting from an explicit argument or, failing that, the environment.

    Args:
        name: Name of the environment variable to fall back to.
        arg_value: Explicitly supplied value; returned unchanged whenever it
            is not None (an empty string is returned as-is).
        error_message: Message that is logged and raised when no value is found.

    Returns:
        ``arg_value`` when it is not None, otherwise the environment value.

    Raises:
        RuntimeError: If ``arg_value`` is None and the variable is unset or empty.
    """

    if arg_value is not None:
        return arg_value

    value = os.getenv(name)
    if not value:
        # Use this module's named logger. The module-level logging.error()
        # helper logs on the root logger and implicitly calls basicConfig(),
        # which would install a handler in the host application.
        logging.getLogger(__name__).error(error_message)
        raise RuntimeError(error_message)
    return value
|
|
47
|
+
|
|
48
|
+
def general_tokens(text):
    """
    Estimate the token count of a text as half its length, rounded up.
    """

    character_count = len(text)
    return math.ceil(character_count / 2)
|
|
54
|
+
|
|
55
|
+
def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
    """
    Compute the cost of a chat request from its prompt and completion tokens.

    Prices are read from ``pricing_info['chat'][model]`` under the keys
    ``'promptPrice'`` and ``'completionPrice'`` and applied per 1000 tokens.

    Returns:
        The computed cost, or 0 when the price cannot be looked up or the
        arithmetic fails (best-effort: costing must never break the caller).
    """

    try:
        model_pricing = pricing_info['chat'][model]
        cost = ((prompt_tokens / 1000) * model_pricing['promptPrice']) + \
            ((completion_tokens / 1000) * model_pricing['completionPrice'])
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt
        # and SystemExit must still propagate.
        cost = 0
    return cost
|
|
66
|
+
|
|
67
|
+
def get_embed_model_cost(model, pricing_info, prompt_tokens):
    """
    Compute the cost of an embeddings request from its prompt tokens.

    The price is read from ``pricing_info['embeddings'][model]`` and applied
    per 1000 tokens.

    Returns:
        The computed cost, or 0 when the price cannot be looked up or the
        arithmetic fails (best-effort: costing must never break the caller).
    """

    try:
        price_per_thousand = pricing_info['embeddings'][model]
        cost = (prompt_tokens / 1000) * price_per_thousand
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt
        # and SystemExit must still propagate.
        cost = 0
    return cost
|
|
77
|
+
|
|
78
|
+
def get_image_model_cost(model, pricing_info, size, quality):
    """
    Look up the cost of an image request from its size and quality.

    The price is read from ``pricing_info['images'][model][quality][size]``.

    Returns:
        The looked-up cost, or 0 when any level of the lookup fails
        (best-effort: costing must never break the caller).
    """

    try:
        cost = pricing_info['images'][model][quality][size]
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt
        # and SystemExit must still propagate.
        cost = 0
    return cost
|
|
88
|
+
|
|
89
|
+
def get_audio_model_cost(model, pricing_info, prompt, duration=None):
    """
    Compute the cost of an audio request from its prompt or its duration.

    The price is read from ``pricing_info['audio'][model]``. With a non-empty
    ``prompt`` the price is applied per 1000 units of ``len(prompt)``;
    otherwise it is multiplied by ``duration``.

    Returns:
        The computed cost, or 0 when the price cannot be looked up or the
        arithmetic fails, e.g. neither a prompt nor a duration was given
        (best-effort: costing must never break the caller).
    """

    try:
        if prompt:
            cost = (len(prompt) / 1000) * pricing_info['audio'][model]
        else:
            cost = duration * pricing_info['audio'][model]
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt
        # and SystemExit must still propagate.
        cost = 0
    return cost
|
|
102
|
+
|
|
103
|
+
def fetch_pricing_info(pricing_json=None):
    """
    Load pricing information from a URL or a local JSON file.

    Args:
        pricing_json: Optional http(s) URL or filesystem path of a pricing
            JSON document. When omitted, the default hosted pricing file
            is fetched.

    Returns:
        The parsed JSON document, or an empty dict when it could not be
        loaded (the failure is logged, not raised).
    """

    pricing_url = 'https://raw.githubusercontent.com/tmam-dev/assets/main/pricing.json'

    if pricing_json:
        # Only http(s) locations are fetched remotely. Testing for "any
        # scheme" would misclassify drive-letter paths such as
        # C:\pricing.json (parsed scheme 'c') as URLs.
        if urlparse(pricing_json).scheme in ('http', 'https'):
            pricing_url = pricing_json
        else:
            try:
                with open(pricing_json, mode='r', encoding='utf-8') as f:
                    return json.load(f)
            except FileNotFoundError:
                logger.error('Pricing information file not found: %s', pricing_json)
            except json.JSONDecodeError:
                logger.error('Error decoding JSON from file: %s', pricing_json)
            except Exception as file_err:
                logger.error('Unexpected error occurred while reading file: %s', file_err)
            return {}

    try:
        # 20-second timeout, applied to both the connection and the read.
        response = requests.get(pricing_url, timeout=20)
        response.raise_for_status()
        return response.json()
    except requests.HTTPError as http_err:
        logger.error('HTTP error occured while fetching pricing info: %s', http_err)
    except Exception as err:
        logger.error('Unexpected error occurred while fetching pricing info: %s', err)
    return {}
|
|
135
|
+
|
|
136
|
+
def handle_exception(span, e):
    """Record exception ``e`` on ``span`` and mark the span's status as ERROR."""

    error_status = Status(StatusCode.ERROR)
    span.record_exception(e)
    span.set_status(error_status)
|
|
141
|
+
|
|
142
|
+
def calculate_ttft(timestamps: List[float], start_time: float) -> float:
    """
    Return the time to first token: the first timestamp minus ``start_time``.

    Returns 0.0 when ``timestamps`` is empty.
    """

    if not timestamps:
        return 0.0

    first_token_time = timestamps[0]
    return first_token_time - start_time
|
|
150
|
+
|
|
151
|
+
def calculate_tbt(timestamps: List[float]) -> float:
    """
    Return the mean gap between consecutive timestamps.

    Returns 0.0 when fewer than two timestamps are given.
    """

    if len(timestamps) <= 1:
        return 0.0

    # Pair every timestamp with its successor and measure the gap.
    gaps = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    return sum(gaps) / len(gaps)
|
|
160
|
+
|
|
161
|
+
def create_metrics_attributes(
    service_name: str,
    deployment_environment: str,
    operation: str,
    system: str,
    request_model: str,
    server_address: str,
    server_port: int,
    response_model: str,
) -> Dict[Any, Any]:
    """
    Build the attribute mapping attached to OTel metrics.

    Each argument is stored under its resource or semantic-convention key;
    the telemetry SDK name is always 'trail'.
    """

    attributes: Dict[Any, Any] = {}

    # Resource-level identification.
    attributes[TELEMETRY_SDK_NAME] = 'trail'
    attributes[SERVICE_NAME] = service_name
    attributes[DEPLOYMENT_ENVIRONMENT] = deployment_environment

    # Request description.
    attributes[SemanticConvetion.GEN_AI_OPERATION] = operation
    attributes[SemanticConvetion.GEN_AI_SYSTEM] = system
    attributes[SemanticConvetion.GEN_AI_REQUEST_MODEL] = request_model

    # Server endpoint and responding model.
    attributes[SemanticConvetion.SERVER_ADDRESS] = server_address
    attributes[SemanticConvetion.SERVER_PORT] = server_port
    attributes[SemanticConvetion.GEN_AI_RESPONSE_MODEL] = response_model

    return attributes
|
|
186
|
+
|
|
187
|
+
def set_server_address_and_port(client_instance: Any,
    default_server_address: str, default_server_port: int) -> Tuple[str, int]:
    """
    Determine the server address and port a client talks to.

    The URL is looked up, in order, at ``client_instance._client.base_url``,
    ``client_instance._config.endpoint`` and
    ``client_instance.sdk_configuration.server_url``; the first truthy value
    wins. A string URL is parsed; any other object is read through its
    ``host`` and ``port`` attributes.

    Args:
        client_instance: The SDK client (or resource) being instrumented.
        default_server_address: Address used when no host can be determined.
        default_server_port: Port used when no port can be determined.

    Returns:
        A ``(server_address, server_port)`` tuple. Missing or unparsable
        parts fall back to the supplied defaults instead of raising.
    """

    # (holder attribute on the client, URL attribute on the holder), by priority.
    url_locations = (
        ('_client', 'base_url'),
        ('_config', 'endpoint'),
        ('sdk_configuration', 'server_url'),
    )

    base_url = None
    for holder_name, url_name in url_locations:
        holder = getattr(client_instance, holder_name, None)
        base_url = getattr(holder, url_name, None)
        if base_url:
            break

    if not base_url:
        # No base_url or endpoint provided; use defaults.
        return default_server_address, default_server_port

    if isinstance(base_url, str):
        try:
            parsed = urlparse(base_url)
        except ValueError:
            # Unparsable URL: telemetry must not break the instrumented call.
            return default_server_address, default_server_port
        host = parsed.hostname
        try:
            # ``.port`` raises ValueError for a non-numeric or out-of-range port.
            port = parsed.port
        except ValueError:
            port = None
    else:
        # base_url might not be a str; handle it as a URL-like object.
        host = getattr(base_url, 'host', None)
        port = getattr(base_url, 'port', None)

    server_address = host or default_server_address
    server_port = port if port is not None else default_server_port
    return server_address, server_port
|
|
222
|
+
|
|
223
|
+
def otel_event(name, attributes, body):
    """
    Build an OpenTelemetry Event from the given name, attributes and body.
    """

    event_fields = {
        'name': name,
        'attributes': attributes,
        'body': body,
    }
    return Event(**event_fields)
|
|
233
|
+
|
|
234
|
+
def extract_and_format_input(messages):
    """
    Group message contents by role.

    Each message is converted with ``response_as_dict``. Messages whose role
    is not one of 'user', 'assistant', 'system', 'tool' or 'developer' are
    skipped. Contents that share a role are joined with a single space;
    list contents are first flattened to a ', '-separated string.

    Args:
        messages: Iterable of message dicts or objects; ``None`` is treated
            as no messages.

    Returns:
        A dict keyed by every fixed role, each value being
        ``{'role': <role or ''>, 'content': <joined text or ''>}``.
    """

    fixed_roles = ['user', 'assistant', 'system', 'tool', 'developer']
    formatted_messages = {role_key: {'role': '', 'content': ''} for role_key in fixed_roles}

    for message in messages or []:
        message = response_as_dict(message)
        if not isinstance(message, dict):
            # Could not be converted to a dict, so there is no role/content to read.
            continue

        role = message.get('role')
        if role not in fixed_roles:
            continue

        content = message.get('content', '')

        # Always end up with a string. Content may be None or another
        # non-string value, and concatenating that onto existing text
        # would raise TypeError.
        if content is None:
            content_str = ''
        elif isinstance(content, list):
            content_str = ", ".join(str(item) for item in content)
        else:
            content_str = str(content)

        # Set the role in the formatted message and concatenate content.
        entry = formatted_messages[role]
        if not entry['role']:
            entry['role'] = role

        if entry['content']:
            entry['content'] += ' ' + content_str
        else:
            entry['content'] = content_str

    return formatted_messages
|
|
268
|
+
|
|
269
|
+
# To be removed once the change to log events (from span events) is complete
|
|
270
|
+
def concatenate_all_contents(formatted_messages):
    """
    Join every non-empty 'content' field into a single space-separated string.
    """

    non_empty_contents = []
    for entry in formatted_messages.values():
        text = entry['content']
        if text:
            non_empty_contents.append(text)
    return ' '.join(non_empty_contents)
|