trailai 1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. trail/__helpers.py +278 -0
  2. trail/__init__.py +1857 -0
  3. trail/core/datetime_utils.py +4 -0
  4. trail/core/pydantic_utilities.py +17 -0
  5. trail/evals/__init__.py +12 -0
  6. trail/evals/all.py +169 -0
  7. trail/evals/bias_detection.py +173 -0
  8. trail/evals/hallucination.py +170 -0
  9. trail/evals/toxicity.py +168 -0
  10. trail/evals/utils.py +275 -0
  11. trail/instrumentation/ag2/__init__.py +49 -0
  12. trail/instrumentation/ag2/ag2.py +163 -0
  13. trail/instrumentation/ai21/__init__.py +67 -0
  14. trail/instrumentation/ai21/ai21.py +191 -0
  15. trail/instrumentation/ai21/async_ai21.py +191 -0
  16. trail/instrumentation/ai21/utils.py +409 -0
  17. trail/instrumentation/anthropic/__init__.py +50 -0
  18. trail/instrumentation/anthropic/anthropic.py +149 -0
  19. trail/instrumentation/anthropic/async_anthropic.py +149 -0
  20. trail/instrumentation/anthropic/utils.py +251 -0
  21. trail/instrumentation/assemblyai/__init__.py +43 -0
  22. trail/instrumentation/assemblyai/assemblyai.py +150 -0
  23. trail/instrumentation/astra/__init__.py +178 -0
  24. trail/instrumentation/astra/astra.py +45 -0
  25. trail/instrumentation/astra/async_astra.py +45 -0
  26. trail/instrumentation/astra/utils.py +102 -0
  27. trail/instrumentation/azure_ai_inference/__init__.py +53 -0
  28. trail/instrumentation/azure_ai_inference/async_azure_ai_inference.py +144 -0
  29. trail/instrumentation/azure_ai_inference/azure_ai_inference.py +144 -0
  30. trail/instrumentation/azure_ai_inference/utils.py +225 -0
  31. trail/instrumentation/bedrock/__init__.py +42 -0
  32. trail/instrumentation/bedrock/bedrock.py +77 -0
  33. trail/instrumentation/bedrock/utils.py +252 -0
  34. trail/instrumentation/chroma/__init__.py +86 -0
  35. trail/instrumentation/chroma/chroma.py +199 -0
  36. trail/instrumentation/cohere/__init__.py +74 -0
  37. trail/instrumentation/cohere/async_cohere.py +610 -0
  38. trail/instrumentation/cohere/cohere.py +610 -0
  39. trail/instrumentation/controlflow/__init__.py +56 -0
  40. trail/instrumentation/controlflow/controlflow.py +113 -0
  41. trail/instrumentation/crawl4ai/__init__.py +52 -0
  42. trail/instrumentation/crawl4ai/async_crawl4ai.py +104 -0
  43. trail/instrumentation/crawl4ai/crawl4ai.py +104 -0
  44. trail/instrumentation/crewai/__init__.py +50 -0
  45. trail/instrumentation/crewai/crewai.py +153 -0
  46. trail/instrumentation/dynamiq/__init__.py +64 -0
  47. trail/instrumentation/dynamiq/dynamiq.py +110 -0
  48. trail/instrumentation/elevenlabs/__init__.py +70 -0
  49. trail/instrumentation/elevenlabs/async_elevenlabs.py +146 -0
  50. trail/instrumentation/elevenlabs/elevenlabs.py +147 -0
  51. trail/instrumentation/embedchain/__init__.py +55 -0
  52. trail/instrumentation/embedchain/embedchain.py +165 -0
  53. trail/instrumentation/firecrawl/__init__.py +49 -0
  54. trail/instrumentation/firecrawl/firecrawl.py +90 -0
  55. trail/instrumentation/google_ai_studio/__init__.py +56 -0
  56. trail/instrumentation/google_ai_studio/async_google_ai_studio.py +227 -0
  57. trail/instrumentation/google_ai_studio/google_ai_studio.py +227 -0
  58. trail/instrumentation/gpt4all/__init__.py +52 -0
  59. trail/instrumentation/gpt4all/gpt4all.py +497 -0
  60. trail/instrumentation/gpu/__init__.py +213 -0
  61. trail/instrumentation/groq/__init__.py +50 -0
  62. trail/instrumentation/groq/async_groq.py +467 -0
  63. trail/instrumentation/groq/groq.py +467 -0
  64. trail/instrumentation/haystack/__init__.py +49 -0
  65. trail/instrumentation/haystack/haystack.py +84 -0
  66. trail/instrumentation/julep/__init__.py +80 -0
  67. trail/instrumentation/julep/async_julep.py +111 -0
  68. trail/instrumentation/julep/julep.py +112 -0
  69. trail/instrumentation/langchain/__init__.py +118 -0
  70. trail/instrumentation/langchain/async_langchain.py +388 -0
  71. trail/instrumentation/langchain/langchain.py +362 -0
  72. trail/instrumentation/letta/__init__.py +77 -0
  73. trail/instrumentation/letta/letta.py +188 -0
  74. trail/instrumentation/litellm/__init__.py +67 -0
  75. trail/instrumentation/litellm/async_litellm.py +592 -0
  76. trail/instrumentation/litellm/litellm.py +592 -0
  77. trail/instrumentation/llamaindex/__init__.py +55 -0
  78. trail/instrumentation/llamaindex/llamaindex.py +86 -0
  79. trail/instrumentation/mem0/__init__.py +79 -0
  80. trail/instrumentation/mem0/mem0.py +115 -0
  81. trail/instrumentation/milvus/__init__.py +94 -0
  82. trail/instrumentation/milvus/milvus.py +179 -0
  83. trail/instrumentation/mistral/__init__.py +80 -0
  84. trail/instrumentation/mistral/async_mistral.py +611 -0
  85. trail/instrumentation/mistral/mistral.py +611 -0
  86. trail/instrumentation/multion/__init__.py +80 -0
  87. trail/instrumentation/multion/async_multion.py +133 -0
  88. trail/instrumentation/multion/multion.py +133 -0
  89. trail/instrumentation/ollama/__init__.py +84 -0
  90. trail/instrumentation/ollama/async_ollama.py +184 -0
  91. trail/instrumentation/ollama/ollama.py +184 -0
  92. trail/instrumentation/ollama/utils.py +332 -0
  93. trail/instrumentation/openai/__init__.py +132 -0
  94. trail/instrumentation/openai/async_openai.py +1411 -0
  95. trail/instrumentation/openai/openai.py +1411 -0
  96. trail/instrumentation/openai_agents/__init__.py +42 -0
  97. trail/instrumentation/openai_agents/openai_agents.py +65 -0
  98. trail/instrumentation/phidata/__init__.py +42 -0
  99. trail/instrumentation/phidata/phidata.py +100 -0
  100. trail/instrumentation/pinecone/__init__.py +66 -0
  101. trail/instrumentation/pinecone/pinecone.py +173 -0
  102. trail/instrumentation/premai/__init__.py +51 -0
  103. trail/instrumentation/premai/premai.py +556 -0
  104. trail/instrumentation/qdrant/__init__.py +295 -0
  105. trail/instrumentation/qdrant/async_qdrant.py +267 -0
  106. trail/instrumentation/qdrant/qdrant.py +274 -0
  107. trail/instrumentation/reka/__init__.py +54 -0
  108. trail/instrumentation/reka/async_reka.py +197 -0
  109. trail/instrumentation/reka/reka.py +197 -0
  110. trail/instrumentation/together/__init__.py +70 -0
  111. trail/instrumentation/together/async_together.py +600 -0
  112. trail/instrumentation/together/together.py +600 -0
  113. trail/instrumentation/transformers/__init__.py +37 -0
  114. trail/instrumentation/transformers/transformers.py +197 -0
  115. trail/instrumentation/vertexai/__init__.py +97 -0
  116. trail/instrumentation/vertexai/async_vertexai.py +459 -0
  117. trail/instrumentation/vertexai/vertexai.py +459 -0
  118. trail/instrumentation/vllm/__init__.py +43 -0
  119. trail/instrumentation/vllm/vllm.py +173 -0
  120. trail/model/dataset.py +1022 -0
  121. trail/otel/events.py +80 -0
  122. trail/otel/metrics.py +218 -0
  123. trail/otel/tracing.py +88 -0
  124. trail/semcov/__init__.py +310 -0
  125. trail/utils/experiment.py +106 -0
  126. trail/utils/utils.py +22 -0
  127. trailai-1.6.dist-info/METADATA +32 -0
  128. trailai-1.6.dist-info/RECORD +130 -0
  129. trailai-1.6.dist-info/WHEEL +4 -0
  130. trailai-1.6.dist-info/licenses/LICENSE +201 -0
trail/__helpers.py ADDED
@@ -0,0 +1,278 @@
1
+ # pylint: disable=bare-except, broad-exception-caught
2
+ """
3
+ This module has functions to calculate model costs based on tokens and to fetch pricing information.
4
+ """
5
+ import os
6
+ import json
7
+ import logging
8
+ from urllib.parse import urlparse
9
+ from typing import Any, Dict, List, Tuple
10
+ import math
11
+ import requests
12
+ from opentelemetry.sdk.resources import SERVICE_NAME, TELEMETRY_SDK_NAME, DEPLOYMENT_ENVIRONMENT
13
+ from opentelemetry.trace import Status, StatusCode
14
+ from opentelemetry._events import Event
15
+ from trail.semcov import SemanticConvetion
16
+
17
# Set up logging: a module-level logger named after this module, used by the
# helpers below (e.g. when loading pricing information fails).
logger = logging.getLogger(__name__)
19
+
20
def response_as_dict(response):
    """
    Normalize a response object into a plain dict where possible.

    Dicts are returned unchanged. Objects exposing ``model_dump`` are dumped.
    Objects exposing ``parse`` (but not ``model_dump``) are parsed, and the
    parsed value is normalized recursively. Anything else is returned as-is.
    """
    if isinstance(response, dict):
        return response
    if hasattr(response, 'model_dump'):
        return response.model_dump()
    if not hasattr(response, 'parse'):
        # Nothing we know how to convert; hand the object back untouched.
        return response
    return response_as_dict(response.parse())
34
+
35
def get_env_variable(name, arg_value, error_message):
    """
    Return ``arg_value`` if given, otherwise the environment variable ``name``.

    Args:
        name: Environment variable to read when ``arg_value`` is None.
        arg_value: Explicit value; it takes precedence whenever it is not None.
        error_message: Message that is logged and raised when no value is available.

    Returns:
        ``arg_value``, or the environment variable's value.

    Raises:
        RuntimeError: If ``arg_value`` is None and the variable is unset or empty.
    """
    if arg_value is not None:
        return arg_value
    value = os.getenv(name)
    if not value:
        # Log through the module logger (not the root ``logging`` module) so the
        # record is attributed to this module like the other helpers here.
        logger.error(error_message)
        raise RuntimeError(error_message)
    return value
47
+
48
def general_tokens(text):
    """
    Estimate how many tokens ``text`` occupies: half its length, rounded up.
    """
    pairs, leftover = divmod(len(text), 2)
    return pairs + 1 if leftover else pairs
54
+
55
def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
    """
    Return the cost of a chat request from its prompt and completion token counts.

    ``pricing_info['chat'][model]['promptPrice']`` and ``['completionPrice']``
    are applied per 1,000 tokens.

    Returns:
        The computed cost, or 0 when the model or its prices are missing, or the
        inputs cannot be combined (e.g. ``None`` token counts).
    """
    try:
        model_pricing = pricing_info['chat'][model]
        cost = ((prompt_tokens / 1000) * model_pricing['promptPrice']) + \
            ((completion_tokens / 1000) * model_pricing['completionPrice'])
    except (LookupError, TypeError):
        # Best-effort: unknown models or malformed pricing data yield no cost.
        # Narrow on purpose so KeyboardInterrupt/SystemExit and unrelated
        # errors still propagate.
        cost = 0
    return cost
66
+
67
def get_embed_model_cost(model, pricing_info, prompt_tokens):
    """
    Return the cost of an embeddings request from its prompt token count.

    ``pricing_info['embeddings'][model]`` is applied per 1,000 tokens.

    Returns:
        The computed cost, or 0 when the model's price is missing, or the inputs
        cannot be combined (e.g. a ``None`` token count).
    """
    try:
        cost = (prompt_tokens / 1000) * pricing_info['embeddings'][model]
    except (LookupError, TypeError):
        # Best-effort fallback; narrow so interrupts and unrelated errors propagate.
        cost = 0
    return cost
77
+
78
def get_image_model_cost(model, pricing_info, size, quality):
    """
    Return the cost of an image request for the given image size and quality.

    The price is read from ``pricing_info['images'][model][quality][size]``.

    Returns:
        That price, or 0 when any level of the lookup is missing or
        ``pricing_info`` is not subscriptable.
    """
    try:
        cost = pricing_info['images'][model][quality][size]
    except (LookupError, TypeError):
        # Best-effort fallback; narrow so interrupts and unrelated errors propagate.
        cost = 0
    return cost
88
+
89
def get_audio_model_cost(model, pricing_info, prompt, duration=None):
    """
    Return the cost of an audio request.

    If ``prompt`` is truthy, ``pricing_info['audio'][model]`` is applied per
    1,000 characters of the prompt. Otherwise it is multiplied by ``duration``.

    Returns:
        The computed cost, or 0 when the model's price is missing, or the inputs
        cannot be combined (e.g. no prompt and ``duration=None``).
    """
    try:
        if prompt:
            cost = (len(prompt) / 1000) * pricing_info['audio'][model]
        else:
            cost = duration * pricing_info['audio'][model]
    except (LookupError, TypeError):
        # Best-effort fallback; narrow so interrupts and unrelated errors propagate.
        cost = 0
    return cost
102
+
103
def fetch_pricing_info(pricing_json=None):
    """
    Load model pricing information from a URL, a local JSON file, or the default URL.

    Args:
        pricing_json: Optional ``http``/``https`` URL or path to a local JSON
            file. When falsy, pricing is downloaded from the default URL.

    Returns:
        The decoded pricing data, or ``{}`` if it could not be loaded. Errors
        are logged, not raised.
    """
    if pricing_json:
        # Only http(s) values are URLs. Checking for *any* scheme misclassifies
        # Windows paths such as ``C:\\pricing.json`` (scheme ``c``) and would
        # send them to ``requests`` instead of reading the file.
        if urlparse(pricing_json).scheme in ('http', 'https'):
            pricing_url = pricing_json
        else:
            try:
                with open(pricing_json, mode='r', encoding='utf-8') as f:
                    return json.load(f)
            except FileNotFoundError:
                logger.error('Pricing information file not found: %s', pricing_json)
            except json.JSONDecodeError:
                logger.error('Error decoding JSON from file: %s', pricing_json)
            except Exception as file_err:
                logger.error('Unexpected error occurred while reading file: %s', file_err)
            return {}
    else:
        pricing_url = 'https://raw.githubusercontent.com/tmam-dev/assets/main/pricing.json'

    try:
        # A single 20-second timeout, applied to both the connect and the read phase.
        response = requests.get(pricing_url, timeout=20)
        response.raise_for_status()
        return response.json()
    except requests.HTTPError as http_err:
        logger.error('HTTP error occured while fetching pricing info: %s', http_err)
    except Exception as err:
        logger.error('Unexpected error occurred while fetching pricing info: %s', err)
    return {}
135
+
136
def handle_exception(span, e):
    """
    Mark ``span`` as failed: record exception ``e`` on it and set an ERROR status.

    Called when the LLM function fails or trace creation fails.
    """
    error_status = Status(StatusCode.ERROR)
    span.record_exception(e)
    span.set_status(error_status)
141
+
142
def calculate_ttft(timestamps: List[float], start_time: float) -> float:
    """
    Return the time to first token: the first timestamp minus ``start_time``.

    Returns 0.0 when no timestamps were recorded.
    """
    if not timestamps:
        return 0.0
    first_token_at = timestamps[0]
    return first_token_at - start_time
150
+
151
def calculate_tbt(timestamps: List[float]) -> float:
    """
    Return the mean interval between consecutive timestamps.

    Returns 0.0 when fewer than two timestamps are available.
    """
    if len(timestamps) < 2:
        return 0.0
    gaps = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    return sum(gaps) / len(gaps)
160
+
161
def create_metrics_attributes(
    service_name: str,
    deployment_environment: str,
    operation: str,
    system: str,
    request_model: str,
    server_address: str,
    server_port: int,
    response_model: str,
) -> Dict[Any, Any]:
    """
    Build the attribute mapping attached to OTel metric data points.

    The telemetry SDK name is always ``'trail'``. Every other value comes from
    the arguments, keyed by the OTel resource / semantic-convention constants.
    """
    attribute_pairs = (
        (TELEMETRY_SDK_NAME, 'trail'),
        (SERVICE_NAME, service_name),
        (DEPLOYMENT_ENVIRONMENT, deployment_environment),
        (SemanticConvetion.GEN_AI_OPERATION, operation),
        (SemanticConvetion.GEN_AI_SYSTEM, system),
        (SemanticConvetion.GEN_AI_REQUEST_MODEL, request_model),
        (SemanticConvetion.SERVER_ADDRESS, server_address),
        (SemanticConvetion.SERVER_PORT, server_port),
        (SemanticConvetion.GEN_AI_RESPONSE_MODEL, response_model),
    )
    return dict(attribute_pairs)
186
+
187
def set_server_address_and_port(client_instance: Any,
                                default_server_address: str,
                                default_server_port: int) -> Tuple[str, int]:
    """
    Resolve the server address and port that a client instance talks to.

    The URL is looked up in this order:

    1. ``client_instance._client.base_url``
    2. ``client_instance._config.endpoint``
    3. ``client_instance.sdk_configuration.server_url``

    String URLs are parsed with ``urlparse``. Any other URL object is read
    through its ``host`` / ``port`` attributes. Missing or unusable pieces fall
    back to the supplied defaults.

    Returns:
        A ``(server_address, server_port)`` tuple.
    """
    base_client = getattr(client_instance, '_client', None)
    base_url = getattr(base_client, 'base_url', None)

    if not base_url:
        # Attempt to get endpoint from instance._config.endpoint.
        config = getattr(client_instance, '_config', None)
        base_url = getattr(config, 'endpoint', None)

    if not base_url:
        # Attempt to get server_url from instance.sdk_configuration.server_url.
        config = getattr(client_instance, 'sdk_configuration', None)
        base_url = getattr(config, 'server_url', None)

    if not base_url:
        # No base_url or endpoint provided; use defaults.
        return default_server_address, default_server_port

    if isinstance(base_url, str):
        url = urlparse(base_url)
        server_address = url.hostname or default_server_address
        try:
            port = url.port
        except ValueError:
            # ``.port`` raises for non-numeric or out-of-range ports; fall back to
            # the default instead of failing the instrumented call.
            port = None
    else:
        # base_url is not a str (e.g. a URL object); read its host/port attributes.
        server_address = getattr(base_url, 'host', None) or default_server_address
        port = getattr(base_url, 'port', None)

    server_port = port if port is not None else default_server_port
    return server_address, server_port
222
+
223
def otel_event(name, attributes, body):
    """
    Build an OpenTelemetry ``Event`` from a name, an attribute mapping and a body.
    """
    event_fields = {'name': name, 'attributes': attributes, 'body': body}
    return Event(**event_fields)
233
+
234
def extract_and_format_input(messages):
    """
    Group chat message contents by role.

    Each message is first passed through ``response_as_dict``. It is bucketed
    under its ``'role'`` if that role is one of 'user', 'assistant', 'system',
    'tool' or 'developer'; messages with any other role are skipped.

    List contents are joined with ``", "``. ``None`` content counts as empty.
    Non-empty contents for the same role are joined with a single space.

    Returns:
        A dict keyed by every fixed role, each value ``{'role': str, 'content': str}``.
        Roles that never appear keep ``''`` for both fields.
    """
    fixed_roles = ['user', 'assistant', 'system', 'tool', 'developer']
    formatted_messages = {role_key: {'role': '', 'content': ''} for role_key in fixed_roles}

    for message in messages:
        message = response_as_dict(message)

        role = message.get('role')
        if role not in fixed_roles:
            continue

        content = message.get('content', '')

        # Prepare content as a string.
        if content is None:
            # Messages may carry ``content: None`` (presumably e.g. tool-call-only
            # assistant turns). Treat it as empty, rather than storing None or
            # failing on ``' ' + None``.
            content_str = ''
        elif isinstance(content, list):
            content_str = ", ".join(str(item) for item in content)
        elif isinstance(content, str):
            content_str = content
        else:
            content_str = str(content)

        bucket = formatted_messages[role]
        bucket['role'] = role

        if not content_str:
            # Nothing to add; avoid appending a stray separator space.
            continue
        if bucket['content']:
            bucket['content'] += ' ' + content_str
        else:
            bucket['content'] = content_str

    return formatted_messages
268
+
269
# To be removed once the change to log events (from span events) is complete.
def concatenate_all_contents(formatted_messages):
    """
    Join the non-empty 'content' fields of ``formatted_messages`` into a single
    space-separated string, following the mapping's iteration order.
    """
    non_empty_contents = [
        entry['content']
        for entry in formatted_messages.values()
        if entry['content']
    ]
    return ' '.join(non_empty_contents)