openlit 1.33.11__py3-none-any.whl → 1.33.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
openlit/__helpers.py CHANGED
@@ -7,8 +7,8 @@ import json
  import logging
  from urllib.parse import urlparse
  from typing import Any, Dict, List, Tuple
+ import math
  import requests
- import tiktoken
  from opentelemetry.sdk.resources import SERVICE_NAME, TELEMETRY_SDK_NAME, DEPLOYMENT_ENVIRONMENT
  from opentelemetry.trace import Status, StatusCode
  from opentelemetry._events import Event
@@ -21,12 +21,13 @@ def response_as_dict(response):
      """
      Return parsed response as a dict
      """
+
      # pylint: disable=no-else-return
      if isinstance(response, dict):
          return response
-     if hasattr(response, "model_dump"):
+     if hasattr(response, 'model_dump'):
          return response.model_dump()
-     elif hasattr(response, "parse"):
+     elif hasattr(response, 'parse'):
          return response_as_dict(response.parse())
      else:
          return response
@@ -34,8 +35,8 @@ def response_as_dict(response):
  def get_env_variable(name, arg_value, error_message):
      """
      Retrieve an environment variable if the argument is not provided
-     and raise an error if both are not set.
      """
+
      if arg_value is not None:
          return arg_value
      value = os.getenv(name)
@@ -44,58 +45,21 @@ def get_env_variable(name, arg_value, error_message):
          raise RuntimeError(error_message)
      return value

- def openai_tokens(text, model):
-     """
-     Calculate the number of tokens a given text would take up for a specified model.
-
-     Args:
-         text (str): The input text to be encoded.
-         model (str): The model identifier used for encoding.
-
-     Returns:
-         int: The number of tokens the text is encoded into.
-     """
-     try:
-         encoding = tiktoken.encoding_for_model(model)
-     except:
-         encoding = tiktoken.get_encoding("cl100k_base")
-
-     num_tokens = len(encoding.encode(text))
-     return num_tokens
-
  def general_tokens(text):
      """
      Calculate the number of tokens a given text would take up.
-
-     Args:
-         text (str): The input text to be encoded.
-         model (str): The model identifier used for encoding.
-
-     Returns:
-         int: The number of tokens the text is encoded into.
      """

-     encoding = tiktoken.get_encoding("gpt2")
-
-     num_tokens = len(encoding.encode(text))
-     return num_tokens
+     return math.ceil(len(text) / 2)

  def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
      """
      Retrieve the cost of processing for a given model based on prompt and tokens.
-
-     Args:
-         model (str): The model identifier.
-         pricing_info (dict): A dictionary containing pricing information for various models.
-         prompt_tokens (int): Number of tokens in the prompt.
-         completion_tokens (int): Number of tokens in the completion if applicable.
-
-     Returns:
-         float: The calculated cost for the operation.
      """
+
      try:
-         cost = ((prompt_tokens / 1000) * pricing_info["chat"][model]["promptPrice"]) + \
-                ((completion_tokens / 1000) * pricing_info["chat"][model]["completionPrice"])
+         cost = ((prompt_tokens / 1000) * pricing_info['chat'][model]['promptPrice']) + \
+                ((completion_tokens / 1000) * pricing_info['chat'][model]['completionPrice'])
      except:
          cost = 0
      return cost
@@ -103,17 +67,10 @@ def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
  def get_embed_model_cost(model, pricing_info, prompt_tokens):
      """
      Retrieve the cost of processing for a given model based on prompt tokens.
-
-     Args:
-         model (str): The model identifier.
-         pricing_info (dict): A dictionary containing pricing information for various models.
-         prompt_tokens (int): Number of tokens in the prompt.
-
-     Returns:
-         float: The calculated cost for the operation.
      """
+
      try:
-         cost = (prompt_tokens / 1000) * pricing_info["embeddings"][model]
+         cost = (prompt_tokens / 1000) * pricing_info['embeddings'][model]
      except:
          cost = 0
      return cost
@@ -121,18 +78,10 @@ def get_embed_model_cost(model, pricing_info, prompt_tokens):
  def get_image_model_cost(model, pricing_info, size, quality):
      """
      Retrieve the cost of processing for a given model based on image size and quailty.
-
-     Args:
-         model (str): The model identifier.
-         pricing_info (dict): A dictionary containing pricing information for various models.
-         size (str): Size of the Image.
-         quality (int): Quality of the Image.
-
-     Returns:
-         float: The calculated cost for the operation.
      """
+
      try:
-         cost = pricing_info["images"][model][quality][size]
+         cost = pricing_info['images'][model][quality][size]
      except:
          cost = 0
      return cost
@@ -140,20 +89,13 @@ def get_image_model_cost(model, pricing_info, size, quality):
  def get_audio_model_cost(model, pricing_info, prompt, duration=None):
      """
      Retrieve the cost of processing for a given model based on prompt.
-
-     Args:
-         model (str): The model identifier.
-         pricing_info (dict): A dictionary containing pricing information for various models.
-         prompt (str): Prompt to the LLM Model
-
-     Returns:
-         float: The calculated cost for the operation.
      """
+
      try:
          if prompt:
-             cost = (len(prompt) / 1000) * pricing_info["audio"][model]
+             cost = (len(prompt) / 1000) * pricing_info['audio'][model]
          else:
-             cost = duration * pricing_info["audio"][model]
+             cost = duration * pricing_info['audio'][model]
      except:
          cost = 0
      return cost
@@ -161,15 +103,10 @@ def get_audio_model_cost(model, pricing_info, prompt, duration=None):
  def fetch_pricing_info(pricing_json=None):
      """
      Fetches pricing information from a specified URL or File Path.
-
-     Args:
-         pricing_json(str): path or url to the pricing json file
-
-     Returns:
-         dict: The pricing json
      """
+
      if pricing_json:
-         is_url = urlparse(pricing_json).scheme != ""
+         is_url = urlparse(pricing_json).scheme != ''
          if is_url:
              pricing_url = pricing_json
          else:
@@ -177,39 +114,36 @@ def fetch_pricing_info(pricing_json=None):
                  with open(pricing_json, mode='r', encoding='utf-8') as f:
                      return json.load(f)
              except FileNotFoundError:
-                 logger.error("Pricing information file not found: %s", pricing_json)
+                 logger.error('Pricing information file not found: %s', pricing_json)
              except json.JSONDecodeError:
-                 logger.error("Error decoding JSON from file: %s", pricing_json)
+                 logger.error('Error decoding JSON from file: %s', pricing_json)
              except Exception as file_err:
-                 logger.error("Unexpected error occurred while reading file: %s", file_err)
+                 logger.error('Unexpected error occurred while reading file: %s', file_err)
              return {}
      else:
-         pricing_url = "https://raw.githubusercontent.com/openlit/openlit/main/assets/pricing.json"
+         pricing_url = 'https://raw.githubusercontent.com/openlit/openlit/main/assets/pricing.json'
      try:
          # Set a timeout of 10 seconds for both the connection and the read
          response = requests.get(pricing_url, timeout=20)
          response.raise_for_status()
          return response.json()
      except requests.HTTPError as http_err:
-         logger.error("HTTP error occured while fetching pricing info: %s", http_err)
+         logger.error('HTTP error occured while fetching pricing info: %s', http_err)
      except Exception as err:
-         logger.error("Unexpected error occurred while fetching pricing info: %s", err)
+         logger.error('Unexpected error occurred while fetching pricing info: %s', err)
      return {}

  def handle_exception(span,e):
      """Handles Exception when LLM Function fails or trace creation fails."""
-     # Record the exception details within the span
+
      span.record_exception(e)
      span.set_status(Status(StatusCode.ERROR))

  def calculate_ttft(timestamps: List[float], start_time: float) -> float:
      """
      Calculate the time to the first tokens.
-
-     :param timestamps: List of timestamps for received tokens
-     :param start_time: The start time of the streaming process
-     :return: Time to the first tokens
      """
+
      if timestamps:
          return timestamps[0] - start_time
      return 0.0
@@ -217,10 +151,8 @@ def calculate_ttft(timestamps: List[float], start_time: float) -> float:
  def calculate_tbt(timestamps: List[float]) -> float:
      """
      Calculate the average time between tokens.
-
-     :param timestamps: List of timestamps for received tokens
-     :return: Average time between tokens
      """
+
      if len(timestamps) > 1:
          time_diffs = [timestamps[i] - timestamps[i - 1] for i in range(1, len(timestamps))]
          return sum(time_diffs) / len(time_diffs)
@@ -239,8 +171,9 @@ def create_metrics_attributes(
      """
      Returns OTel metrics attributes
      """
+
      return {
-         TELEMETRY_SDK_NAME: "openlit",
+         TELEMETRY_SDK_NAME: 'openlit',
          SERVICE_NAME: service_name,
          DEPLOYMENT_ENVIRONMENT: deployment_environment,
          SemanticConvetion.GEN_AI_OPERATION: operation,
@@ -259,18 +192,18 @@ def set_server_address_and_port(client_instance: Any,
      """

      # Try getting base_url from multiple potential attributes
-     base_client = getattr(client_instance, "_client", None)
-     base_url = getattr(base_client, "base_url", None)
+     base_client = getattr(client_instance, '_client', None)
+     base_url = getattr(base_client, 'base_url', None)

      if not base_url:
          # Attempt to get endpoint from instance._config.endpoint if base_url is not set
-         config = getattr(client_instance, "_config", None)
-         base_url = getattr(config, "endpoint", None)
+         config = getattr(client_instance, '_config', None)
+         base_url = getattr(config, 'endpoint', None)

      if not base_url:
          # Attempt to get server_url from instance.sdk_configuration.server_url
-         config = getattr(client_instance, "sdk_configuration", None)
-         base_url = getattr(config, "server_url", None)
+         config = getattr(client_instance, 'sdk_configuration', None)
+         base_url = getattr(config, 'server_url', None)

      if base_url:
          if isinstance(base_url, str):
@@ -278,8 +211,8 @@ def set_server_address_and_port(client_instance: Any,
              server_address = url.hostname or default_server_address
              server_port = url.port if url.port is not None else default_server_port
          else: # base_url might not be a str; handle as an object.
-             server_address = getattr(base_url, "host", None) or default_server_address
-             port_attr = getattr(base_url, "port", None)
+             server_address = getattr(base_url, 'host', None) or default_server_address
+             port_attr = getattr(base_url, 'port', None)
              server_port = port_attr if port_attr is not None else default_server_port
      else: # no base_url or endpoint provided; use defaults.
          server_address = default_server_address
@@ -301,59 +234,74 @@ def otel_event(name, attributes, body):
  def extract_and_format_input(messages):
      """
      Process a list of messages to extract content and categorize
-     them into fixed roles like 'user', 'assistant', 'system'.
+     them into fixed roles like 'user', 'assistant', 'system', 'tool'.
      """

      fixed_roles = ['user', 'assistant', 'system', 'tool'] # Ensure these are your fixed keys
      # Initialize the dictionary with fixed keys and empty structures
-     formatted_messages = {role_key: {"role": "", "content": ""} for role_key in fixed_roles}
+     formatted_messages = {role_key: {'role': '', 'content': ''} for role_key in fixed_roles}

      for message in messages:
          # Normalize the message structure
          message = response_as_dict(message)

          # Extract role and content
-         role = message.get("role")
+         role = message.get('role')
          if role not in fixed_roles:
              continue # Skip any role not in our predefined roles

-         content = message.get("content", "")
+         content = message.get('content', '')

          # Prepare content as a string
          if isinstance(content, list):
              content_str = ", ".join(
-                 # pylint: disable=line-too-long
-                 f'{item.get("type", "text")}: {item.get("text", item.get("image_url", "").get("url", "") if isinstance(item.get("image_url", ""), dict) else item.get("image_url", ""))}'
+                 f'{item.get("type", "text")}: {extract_text_from_item(item)}'
                  for item in content
              )
          else:
              content_str = content

          # Set the role in the formatted message and concatenate content
-         if not formatted_messages[role]["role"]:
-             formatted_messages[role]["role"] = role
+         if not formatted_messages[role]['role']:
+             formatted_messages[role]['role'] = role

-         if formatted_messages[role]["content"]:
-             formatted_messages[role]["content"] += " " + content_str
+         if formatted_messages[role]['content']:
+             formatted_messages[role]['content'] += ' ' + content_str
          else:
-             formatted_messages[role]["content"] = content_str
+             formatted_messages[role]['content'] = content_str

      return formatted_messages

+ def extract_text_from_item(item):
+     """
+     Extract text from inpit message
+     """
+
+     #pylint: disable=no-else-return
+     if item.get('type') == 'text':
+         return item.get('text', '')
+     elif item.get('type') == 'image':
+         # Handle image content specifically checking for 'url' or 'base64'
+         source = item.get('source', {})
+         if isinstance(source, dict):
+             if source.get('type') == 'base64':
+                 # Return the actual base64 data if present
+                 return source.get('data', '[Missing base64 data]')
+             elif source.get('type') == 'url':
+                 return source.get('url', '[Missing URL]')
+     elif item.get('type') == 'image_url':
+         # New format: Handle the 'image_url' type
+         image_url = item.get('image_url', {})
+         if isinstance(image_url, dict):
+             return image_url.get('url', '[Missing image URL]')
+     return ''
+
  # To be removed one the change to log events (from span events) is complete
  def concatenate_all_contents(formatted_messages):
      """
-     Concatenate all 'content' fields from the formatted messages
-     dictionary into a single string.
-
-     Parameters:
-     - formatted_messages: Dictionary with roles as keys and corresponding
-       role and content as values.
-
-     Returns:
-     - A single string with all content concatenated.
+     Concatenate all 'content' fields into a single strin
      """
-     return " ".join(
+     return ' '.join(
          message_data['content']
          for message_data in formatted_messages.values()
          if message_data['content']

openlit/instrumentation/ag2/__init__.py CHANGED
@@ -9,7 +9,7 @@ from openlit.instrumentation.ag2.ag2 import (
      conversable_agent, agent_run
  )

- _instruments = ("ag2 >= 0.3.2",)
+ _instruments = ('ag2 >= 0.3.2',)

  class AG2Instrumentor(BaseInstrumentor):
      """
@@ -20,26 +20,26 @@ class AG2Instrumentor(BaseInstrumentor):
          return _instruments

      def _instrument(self, **kwargs):
-         application_name = kwargs.get("application_name", "default_application")
-         environment = kwargs.get("environment", "default_environment")
-         tracer = kwargs.get("tracer")
-         event_provider = kwargs.get("event_provider")
-         metrics = kwargs.get("metrics_dict")
-         pricing_info = kwargs.get("pricing_info", {})
-         capture_message_content = kwargs.get("capture_message_content", False)
-         disable_metrics = kwargs.get("disable_metrics")
-         version = importlib.metadata.version("ag2")
+         application_name = kwargs.get('application_name', 'default_application')
+         environment = kwargs.get('environment', 'default_environment')
+         tracer = kwargs.get('tracer')
+         event_provider = kwargs.get('event_provider')
+         metrics = kwargs.get('metrics_dict')
+         pricing_info = kwargs.get('pricing_info', {})
+         capture_message_content = kwargs.get('capture_message_content', False)
+         disable_metrics = kwargs.get('disable_metrics')
+         version = importlib.metadata.version('ag2')

          wrap_function_wrapper(
-             "autogen.agentchat.conversable_agent",
-             "ConversableAgent.__init__",
+             'autogen.agentchat.conversable_agent',
+             'ConversableAgent.__init__',
              conversable_agent(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )

          wrap_function_wrapper(
-             "autogen.agentchat.conversable_agent",
-             "ConversableAgent.run",
+             'autogen.agentchat.conversable_agent',
+             'ConversableAgent.run',
              agent_run(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )

openlit/instrumentation/ag2/ag2.py CHANGED
@@ -28,7 +28,7 @@ def set_span_attributes(span, version, operation_name, environment,
      """

      # Set Span attributes (OTel Semconv)
-     span.set_attribute(TELEMETRY_SDK_NAME, "openlit")
+     span.set_attribute(TELEMETRY_SDK_NAME, 'openlit')
      span.set_attribute(SemanticConvetion.GEN_AI_OPERATION, operation_name)
      span.set_attribute(SemanticConvetion.GEN_AI_SYSTEM, SemanticConvetion.GEN_AI_SYSTEM_AG2)
      span.set_attribute(SemanticConvetion.GEN_AI_AGENT_NAME, AGENT_NAME)
@@ -73,10 +73,10 @@ def emit_events(response, event_provider, capture_message_content):
                  SemanticConvetion.GEN_AI_SYSTEM: SemanticConvetion.GEN_AI_SYSTEM_AG2
              },
              body={
-                 "index": response.chat_history.index(chat),
-                 "message": {
-                     **({"content": chat['content']} if capture_message_content else {}),
-                     "role": 'assistant' if chat['role'] == 'user' else 'user'
+                 'index': response.chat_history.index(chat),
+                 'message': {
+                     **({'content': chat['content']} if capture_message_content else {}),
+                     'role': 'assistant' if chat['role'] == 'user' else 'user'
                  }
              }
          )
@@ -92,12 +92,12 @@ def conversable_agent(version, environment, application_name,
          global AGENT_NAME, MODEL_AND_NAME_SET, REQUEST_MODEL, SYSTEM_MESSAGE

          if not MODEL_AND_NAME_SET:
-             AGENT_NAME = kwargs.get("name", "NOT_FOUND")
-             REQUEST_MODEL = kwargs.get("llm_config", {}).get('model', 'gpt-4o')
+             AGENT_NAME = kwargs.get('name', 'NOT_FOUND')
+             REQUEST_MODEL = kwargs.get('llm_config', {}).get('model', 'gpt-4o')
              SYSTEM_MESSAGE = kwargs.get('system_message', '')
              MODEL_AND_NAME_SET = True

-         span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CREATE_AGENT} {AGENT_NAME}"
+         span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CREATE_AGENT} {AGENT_NAME}'

          with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
              try:
@@ -117,7 +117,7 @@ def conversable_agent(version, environment, application_name,

              except Exception as e:
                  handle_exception(span, e)
-                 logger.error("Error in trace creation: %s", e)
+                 logger.error('Error in trace creation: %s', e)
                  return response

      return wrapper
@@ -130,7 +130,7 @@ def agent_run(version, environment, application_name,
      def wrapper(wrapped, instance, args, kwargs):
          server_address, server_port = '127.0.0.1', 80

-         span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_EXECUTE_AGENT_TASK} {AGENT_NAME}"
+         span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_EXECUTE_AGENT_TASK} {AGENT_NAME}'

          with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
              try:
@@ -157,7 +157,7 @@ def agent_run(version, environment, application_name,

              except Exception as e:
                  handle_exception(span, e)
-                 logger.error("Error in trace creation: %s", e)
+                 logger.error('Error in trace creation: %s', e)
                  return response

      return wrapper

openlit/instrumentation/ai21/__init__.py CHANGED
@@ -13,7 +13,7 @@ from openlit.instrumentation.ai21.async_ai21 import (
      async_chat, async_chat_rag
  )

- _instruments = ("ai21 >= 3.0.0",)
+ _instruments = ('ai21 >= 3.0.0',)

  class AI21Instrumentor(BaseInstrumentor):
      """
@@ -24,40 +24,40 @@ class AI21Instrumentor(BaseInstrumentor):
          return _instruments

      def _instrument(self, **kwargs):
-         application_name = kwargs.get("application_name", "default_application")
-         environment = kwargs.get("environment", "default_environment")
-         tracer = kwargs.get("tracer")
-         event_provider = kwargs.get("event_provider")
-         metrics = kwargs.get("metrics_dict")
-         pricing_info = kwargs.get("pricing_info", {})
-         capture_message_content = kwargs.get("capture_message_content", False)
-         disable_metrics = kwargs.get("disable_metrics")
-         version = importlib.metadata.version("ai21")
+         application_name = kwargs.get('application_name', 'default')
+         environment = kwargs.get('environment', 'default')
+         tracer = kwargs.get('tracer')
+         event_provider = kwargs.get('event_provider')
+         metrics = kwargs.get('metrics_dict')
+         pricing_info = kwargs.get('pricing_info', {})
+         capture_message_content = kwargs.get('capture_message_content', False)
+         disable_metrics = kwargs.get('disable_metrics')
+         version = importlib.metadata.version('ai21')

          #sync
          wrap_function_wrapper(
-             "ai21.clients.studio.resources.chat.chat_completions",
-             "ChatCompletions.create",
+             'ai21.clients.studio.resources.chat.chat_completions',
+             'ChatCompletions.create',
              chat(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )
          wrap_function_wrapper(
-             "ai21.clients.studio.resources.studio_conversational_rag",
-             "StudioConversationalRag.create",
+             'ai21.clients.studio.resources.studio_conversational_rag',
+             'StudioConversationalRag.create',
              chat_rag(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )

          #Async
          wrap_function_wrapper(
-             "ai21.clients.studio.resources.chat.async_chat_completions",
-             "AsyncChatCompletions.create",
+             'ai21.clients.studio.resources.chat.async_chat_completions',
+             'AsyncChatCompletions.create',
              async_chat(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )
          wrap_function_wrapper(
-             "ai21.clients.studio.resources.studio_conversational_rag",
-             "AsyncStudioConversationalRag.create",
+             'ai21.clients.studio.resources.studio_conversational_rag',
+             'AsyncStudioConversationalRag.create',
              async_chat_rag(version, environment, application_name,
                  tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
          )

openlit/instrumentation/ai21/ai21.py CHANGED
@@ -7,7 +7,6 @@ import time
  from opentelemetry.trace import SpanKind
  from openlit.__helpers import (
      handle_exception,
-     response_as_dict,
      set_server_address_and_port,
  )
  from openlit.instrumentation.ai21.utils import (
@@ -47,9 +46,9 @@ def chat(version, environment, application_name,
              self._span = span
              self._span_name = span_name
              # Placeholder for aggregating streaming response
-             self._llmresponse = ""
-             self._response_id = ""
-             self._finish_reason = ""
+             self._llmresponse = ''
+             self._response_id = ''
+             self._finish_reason = ''
              self._input_tokens = 0
              self._output_tokens = 0
              self._choices = []
@@ -100,7 +99,7 @@ def chat(version, environment, application_name,
                  )
              except Exception as e:
                  handle_exception(self._span, e)
-                 logger.error("Error in trace creation: %s", e)
+                 logger.error('Error in trace creation: %s', e)
                  raise

      def wrapper(wrapped, instance, args, kwargs):
@@ -109,12 +108,12 @@ def chat(version, environment, application_name,
          """

          # Check if streaming is enabled for the API call
-         streaming = kwargs.get("stream", False)
+         streaming = kwargs.get('stream', False)

-         server_address, server_port = set_server_address_and_port(instance, "api.ai21.com", 443)
-         request_model = kwargs.get("model", "jamba-1.5-mini")
+         server_address, server_port = set_server_address_and_port(instance, 'api.ai21.com', 443)
+         request_model = kwargs.get('model', 'jamba-1.5-mini')

-         span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}"
+         span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}'

          # pylint: disable=no-else-return
          if streaming:
@@ -129,7 +128,7 @@ def chat(version, environment, application_name,
              start_time = time.time()
              response = wrapped(*args, **kwargs)
              response = process_chat_response(
-                 response=response_as_dict(response),
+                 response=response,
                  request_model=request_model,
                  pricing_info=pricing_info,
                  server_port=server_port,
@@ -161,16 +160,16 @@ def chat_rag(version, environment, application_name,
          Wraps the GenAI function call.
          """

-         server_address, server_port = set_server_address_and_port(instance, "api.ai21.com", 443)
-         request_model = kwargs.get("model", "jamba-1.5-mini")
+         server_address, server_port = set_server_address_and_port(instance, 'api.ai21.com', 443)
+         request_model = kwargs.get('model', 'jamba-1.5-mini')

-         span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}"
+         span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}'

          with tracer.start_as_current_span(span_name, kind= SpanKind.CLIENT) as span:
              start_time = time.time()
              response = wrapped(*args, **kwargs)
              response = process_chat_rag_response(
-                 response=response_as_dict(response),
+                 response=response,
                  request_model=request_model,
                  pricing_info=pricing_info,
                  server_port=server_port,