openlit 1.33.11__py3-none-any.whl → 1.33.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openlit/__helpers.py CHANGED
@@ -21,12 +21,13 @@ def response_as_dict(response):
     """
     Return parsed response as a dict
     """
+
     # pylint: disable=no-else-return
     if isinstance(response, dict):
         return response
-    if hasattr(response, "model_dump"):
+    if hasattr(response, 'model_dump'):
         return response.model_dump()
-    elif hasattr(response, "parse"):
+    elif hasattr(response, 'parse'):
         return response_as_dict(response.parse())
     else:
         return response
@@ -34,8 +35,8 @@ def response_as_dict(response):
 def get_env_variable(name, arg_value, error_message):
     """
     Retrieve an environment variable if the argument is not provided
-    and raise an error if both are not set.
     """
+
     if arg_value is not None:
         return arg_value
     value = os.getenv(name)
@@ -47,14 +48,8 @@ def get_env_variable(name, arg_value, error_message):
 def openai_tokens(text, model):
     """
     Calculate the number of tokens a given text would take up for a specified model.
-
-    Args:
-        text (str): The input text to be encoded.
-        model (str): The model identifier used for encoding.
-
-    Returns:
-        int: The number of tokens the text is encoded into.
     """
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except:
@@ -66,16 +61,9 @@ def openai_tokens(text, model):
 def general_tokens(text):
     """
     Calculate the number of tokens a given text would take up.
-
-    Args:
-        text (str): The input text to be encoded.
-        model (str): The model identifier used for encoding.
-
-    Returns:
-        int: The number of tokens the text is encoded into.
     """
 
-    encoding = tiktoken.get_encoding("gpt2")
+    encoding = tiktoken.get_encoding('gpt2')
 
     num_tokens = len(encoding.encode(text))
     return num_tokens
@@ -83,19 +71,11 @@ def general_tokens(text):
 def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
     """
     Retrieve the cost of processing for a given model based on prompt and tokens.
-
-    Args:
-        model (str): The model identifier.
-        pricing_info (dict): A dictionary containing pricing information for various models.
-        prompt_tokens (int): Number of tokens in the prompt.
-        completion_tokens (int): Number of tokens in the completion if applicable.
-
-    Returns:
-        float: The calculated cost for the operation.
     """
+
     try:
-        cost = ((prompt_tokens / 1000) * pricing_info["chat"][model]["promptPrice"]) + \
-            ((completion_tokens / 1000) * pricing_info["chat"][model]["completionPrice"])
+        cost = ((prompt_tokens / 1000) * pricing_info['chat'][model]['promptPrice']) + \
+            ((completion_tokens / 1000) * pricing_info['chat'][model]['completionPrice'])
     except:
         cost = 0
     return cost
@@ -103,17 +83,10 @@ def get_chat_model_cost(model, pricing_info, prompt_tokens, completion_tokens):
 def get_embed_model_cost(model, pricing_info, prompt_tokens):
     """
     Retrieve the cost of processing for a given model based on prompt tokens.
-
-    Args:
-        model (str): The model identifier.
-        pricing_info (dict): A dictionary containing pricing information for various models.
-        prompt_tokens (int): Number of tokens in the prompt.
-
-    Returns:
-        float: The calculated cost for the operation.
     """
+
     try:
-        cost = (prompt_tokens / 1000) * pricing_info["embeddings"][model]
+        cost = (prompt_tokens / 1000) * pricing_info['embeddings'][model]
     except:
         cost = 0
     return cost
@@ -121,18 +94,10 @@ def get_embed_model_cost(model, pricing_info, prompt_tokens):
 def get_image_model_cost(model, pricing_info, size, quality):
     """
     Retrieve the cost of processing for a given model based on image size and quailty.
-
-    Args:
-        model (str): The model identifier.
-        pricing_info (dict): A dictionary containing pricing information for various models.
-        size (str): Size of the Image.
-        quality (int): Quality of the Image.
-
-    Returns:
-        float: The calculated cost for the operation.
     """
+
     try:
-        cost = pricing_info["images"][model][quality][size]
+        cost = pricing_info['images'][model][quality][size]
     except:
         cost = 0
     return cost
@@ -140,20 +105,13 @@ def get_image_model_cost(model, pricing_info, size, quality):
 def get_audio_model_cost(model, pricing_info, prompt, duration=None):
     """
     Retrieve the cost of processing for a given model based on prompt.
-
-    Args:
-        model (str): The model identifier.
-        pricing_info (dict): A dictionary containing pricing information for various models.
-        prompt (str): Prompt to the LLM Model
-
-    Returns:
-        float: The calculated cost for the operation.
     """
+
     try:
         if prompt:
-            cost = (len(prompt) / 1000) * pricing_info["audio"][model]
+            cost = (len(prompt) / 1000) * pricing_info['audio'][model]
         else:
-            cost = duration * pricing_info["audio"][model]
+            cost = duration * pricing_info['audio'][model]
     except:
         cost = 0
     return cost
@@ -161,15 +119,10 @@ def get_audio_model_cost(model, pricing_info, prompt, duration=None):
 def fetch_pricing_info(pricing_json=None):
     """
     Fetches pricing information from a specified URL or File Path.
-
-    Args:
-        pricing_json(str): path or url to the pricing json file
-
-    Returns:
-        dict: The pricing json
     """
+
     if pricing_json:
-        is_url = urlparse(pricing_json).scheme != ""
+        is_url = urlparse(pricing_json).scheme != ''
         if is_url:
             pricing_url = pricing_json
         else:
@@ -177,39 +130,36 @@ def fetch_pricing_info(pricing_json=None):
                 with open(pricing_json, mode='r', encoding='utf-8') as f:
                     return json.load(f)
             except FileNotFoundError:
-                logger.error("Pricing information file not found: %s", pricing_json)
+                logger.error('Pricing information file not found: %s', pricing_json)
             except json.JSONDecodeError:
-                logger.error("Error decoding JSON from file: %s", pricing_json)
+                logger.error('Error decoding JSON from file: %s', pricing_json)
             except Exception as file_err:
-                logger.error("Unexpected error occurred while reading file: %s", file_err)
+                logger.error('Unexpected error occurred while reading file: %s', file_err)
             return {}
     else:
-        pricing_url = "https://raw.githubusercontent.com/openlit/openlit/main/assets/pricing.json"
+        pricing_url = 'https://raw.githubusercontent.com/openlit/openlit/main/assets/pricing.json'
     try:
         # Set a timeout of 10 seconds for both the connection and the read
         response = requests.get(pricing_url, timeout=20)
         response.raise_for_status()
         return response.json()
     except requests.HTTPError as http_err:
-        logger.error("HTTP error occured while fetching pricing info: %s", http_err)
+        logger.error('HTTP error occured while fetching pricing info: %s', http_err)
     except Exception as err:
-        logger.error("Unexpected error occurred while fetching pricing info: %s", err)
+        logger.error('Unexpected error occurred while fetching pricing info: %s', err)
     return {}
 
 def handle_exception(span,e):
     """Handles Exception when LLM Function fails or trace creation fails."""
-    # Record the exception details within the span
+
     span.record_exception(e)
     span.set_status(Status(StatusCode.ERROR))
 
 def calculate_ttft(timestamps: List[float], start_time: float) -> float:
     """
     Calculate the time to the first tokens.
-
-    :param timestamps: List of timestamps for received tokens
-    :param start_time: The start time of the streaming process
-    :return: Time to the first tokens
     """
+
     if timestamps:
         return timestamps[0] - start_time
     return 0.0
@@ -217,10 +167,8 @@ def calculate_ttft(timestamps: List[float], start_time: float) -> float:
 def calculate_tbt(timestamps: List[float]) -> float:
     """
     Calculate the average time between tokens.
-
-    :param timestamps: List of timestamps for received tokens
-    :return: Average time between tokens
     """
+
     if len(timestamps) > 1:
         time_diffs = [timestamps[i] - timestamps[i - 1] for i in range(1, len(timestamps))]
         return sum(time_diffs) / len(time_diffs)
@@ -239,8 +187,9 @@ def create_metrics_attributes(
     """
     Returns OTel metrics attributes
     """
+
     return {
-        TELEMETRY_SDK_NAME: "openlit",
+        TELEMETRY_SDK_NAME: 'openlit',
         SERVICE_NAME: service_name,
         DEPLOYMENT_ENVIRONMENT: deployment_environment,
         SemanticConvetion.GEN_AI_OPERATION: operation,
@@ -259,18 +208,18 @@ def set_server_address_and_port(client_instance: Any,
     """
 
     # Try getting base_url from multiple potential attributes
-    base_client = getattr(client_instance, "_client", None)
-    base_url = getattr(base_client, "base_url", None)
+    base_client = getattr(client_instance, '_client', None)
+    base_url = getattr(base_client, 'base_url', None)
 
     if not base_url:
         # Attempt to get endpoint from instance._config.endpoint if base_url is not set
-        config = getattr(client_instance, "_config", None)
-        base_url = getattr(config, "endpoint", None)
+        config = getattr(client_instance, '_config', None)
+        base_url = getattr(config, 'endpoint', None)
 
     if not base_url:
         # Attempt to get server_url from instance.sdk_configuration.server_url
-        config = getattr(client_instance, "sdk_configuration", None)
-        base_url = getattr(config, "server_url", None)
+        config = getattr(client_instance, 'sdk_configuration', None)
+        base_url = getattr(config, 'server_url', None)
 
     if base_url:
         if isinstance(base_url, str):
@@ -278,8 +227,8 @@ def set_server_address_and_port(client_instance: Any,
             server_address = url.hostname or default_server_address
             server_port = url.port if url.port is not None else default_server_port
         else: # base_url might not be a str; handle as an object.
-            server_address = getattr(base_url, "host", None) or default_server_address
-            port_attr = getattr(base_url, "port", None)
+            server_address = getattr(base_url, 'host', None) or default_server_address
+            port_attr = getattr(base_url, 'port', None)
             server_port = port_attr if port_attr is not None else default_server_port
     else: # no base_url or endpoint provided; use defaults.
         server_address = default_server_address
@@ -301,59 +250,74 @@ def otel_event(name, attributes, body):
 def extract_and_format_input(messages):
     """
     Process a list of messages to extract content and categorize
-    them into fixed roles like 'user', 'assistant', 'system'.
+    them into fixed roles like 'user', 'assistant', 'system', 'tool'.
     """
 
     fixed_roles = ['user', 'assistant', 'system', 'tool'] # Ensure these are your fixed keys
     # Initialize the dictionary with fixed keys and empty structures
-    formatted_messages = {role_key: {"role": "", "content": ""} for role_key in fixed_roles}
+    formatted_messages = {role_key: {'role': '', 'content': ''} for role_key in fixed_roles}
 
     for message in messages:
         # Normalize the message structure
         message = response_as_dict(message)
 
         # Extract role and content
-        role = message.get("role")
+        role = message.get('role')
        if role not in fixed_roles:
            continue # Skip any role not in our predefined roles
 
-        content = message.get("content", "")
+        content = message.get('content', '')
 
         # Prepare content as a string
         if isinstance(content, list):
             content_str = ", ".join(
-                # pylint: disable=line-too-long
-                f'{item.get("type", "text")}: {item.get("text", item.get("image_url", "").get("url", "") if isinstance(item.get("image_url", ""), dict) else item.get("image_url", ""))}'
+                f'{item.get("type", "text")}: {extract_text_from_item(item)}'
                 for item in content
             )
         else:
             content_str = content
 
         # Set the role in the formatted message and concatenate content
-        if not formatted_messages[role]["role"]:
-            formatted_messages[role]["role"] = role
+        if not formatted_messages[role]['role']:
+            formatted_messages[role]['role'] = role
 
-        if formatted_messages[role]["content"]:
-            formatted_messages[role]["content"] += " " + content_str
+        if formatted_messages[role]['content']:
+            formatted_messages[role]['content'] += ' ' + content_str
         else:
-            formatted_messages[role]["content"] = content_str
+            formatted_messages[role]['content'] = content_str
 
     return formatted_messages
 
+def extract_text_from_item(item):
+    """
+    Extract text from inpit message
+    """
+
+    #pylint: disable=no-else-return
+    if item.get('type') == 'text':
+        return item.get('text', '')
+    elif item.get('type') == 'image':
+        # Handle image content specifically checking for 'url' or 'base64'
+        source = item.get('source', {})
+        if isinstance(source, dict):
+            if source.get('type') == 'base64':
+                # Return the actual base64 data if present
+                return source.get('data', '[Missing base64 data]')
+            elif source.get('type') == 'url':
+                return source.get('url', '[Missing URL]')
+    elif item.get('type') == 'image_url':
+        # New format: Handle the 'image_url' type
+        image_url = item.get('image_url', {})
+        if isinstance(image_url, dict):
+            return image_url.get('url', '[Missing image URL]')
+    return ''
+
 # To be removed one the change to log events (from span events) is complete
 def concatenate_all_contents(formatted_messages):
     """
-    Concatenate all 'content' fields from the formatted messages
-    dictionary into a single string.
-
-    Parameters:
-    - formatted_messages: Dictionary with roles as keys and corresponding
-      role and content as values.
-
-    Returns:
-    - A single string with all content concatenated.
+    Concatenate all 'content' fields into a single strin
     """
-    return " ".join(
+    return ' '.join(
         message_data['content']
         for message_data in formatted_messages.values()
         if message_data['content']
openlit/instrumentation/ag2/__init__.py CHANGED
@@ -9,7 +9,7 @@ from openlit.instrumentation.ag2.ag2 import (
     conversable_agent, agent_run
 )
 
-_instruments = ("ag2 >= 0.3.2",)
+_instruments = ('ag2 >= 0.3.2',)
 
 class AG2Instrumentor(BaseInstrumentor):
     """
@@ -20,26 +20,26 @@ class AG2Instrumentor(BaseInstrumentor):
         return _instruments
 
     def _instrument(self, **kwargs):
-        application_name = kwargs.get("application_name", "default_application")
-        environment = kwargs.get("environment", "default_environment")
-        tracer = kwargs.get("tracer")
-        event_provider = kwargs.get("event_provider")
-        metrics = kwargs.get("metrics_dict")
-        pricing_info = kwargs.get("pricing_info", {})
-        capture_message_content = kwargs.get("capture_message_content", False)
-        disable_metrics = kwargs.get("disable_metrics")
-        version = importlib.metadata.version("ag2")
+        application_name = kwargs.get('application_name', 'default_application')
+        environment = kwargs.get('environment', 'default_environment')
+        tracer = kwargs.get('tracer')
+        event_provider = kwargs.get('event_provider')
+        metrics = kwargs.get('metrics_dict')
+        pricing_info = kwargs.get('pricing_info', {})
+        capture_message_content = kwargs.get('capture_message_content', False)
+        disable_metrics = kwargs.get('disable_metrics')
+        version = importlib.metadata.version('ag2')
 
         wrap_function_wrapper(
-            "autogen.agentchat.conversable_agent",
-            "ConversableAgent.__init__",
+            'autogen.agentchat.conversable_agent',
+            'ConversableAgent.__init__',
             conversable_agent(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
 
         wrap_function_wrapper(
-            "autogen.agentchat.conversable_agent",
-            "ConversableAgent.run",
+            'autogen.agentchat.conversable_agent',
+            'ConversableAgent.run',
             agent_run(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
openlit/instrumentation/ag2/ag2.py CHANGED
@@ -28,7 +28,7 @@ def set_span_attributes(span, version, operation_name, environment,
     """
 
     # Set Span attributes (OTel Semconv)
-    span.set_attribute(TELEMETRY_SDK_NAME, "openlit")
+    span.set_attribute(TELEMETRY_SDK_NAME, 'openlit')
     span.set_attribute(SemanticConvetion.GEN_AI_OPERATION, operation_name)
     span.set_attribute(SemanticConvetion.GEN_AI_SYSTEM, SemanticConvetion.GEN_AI_SYSTEM_AG2)
     span.set_attribute(SemanticConvetion.GEN_AI_AGENT_NAME, AGENT_NAME)
@@ -73,10 +73,10 @@ def emit_events(response, event_provider, capture_message_content):
                 SemanticConvetion.GEN_AI_SYSTEM: SemanticConvetion.GEN_AI_SYSTEM_AG2
             },
             body={
-                "index": response.chat_history.index(chat),
-                "message": {
-                    **({"content": chat['content']} if capture_message_content else {}),
-                    "role": 'assistant' if chat['role'] == 'user' else 'user'
+                'index': response.chat_history.index(chat),
+                'message': {
+                    **({'content': chat['content']} if capture_message_content else {}),
+                    'role': 'assistant' if chat['role'] == 'user' else 'user'
                 }
             }
         )
@@ -92,12 +92,12 @@ def conversable_agent(version, environment, application_name,
         global AGENT_NAME, MODEL_AND_NAME_SET, REQUEST_MODEL, SYSTEM_MESSAGE
 
         if not MODEL_AND_NAME_SET:
-            AGENT_NAME = kwargs.get("name", "NOT_FOUND")
-            REQUEST_MODEL = kwargs.get("llm_config", {}).get('model', 'gpt-4o')
+            AGENT_NAME = kwargs.get('name', 'NOT_FOUND')
+            REQUEST_MODEL = kwargs.get('llm_config', {}).get('model', 'gpt-4o')
             SYSTEM_MESSAGE = kwargs.get('system_message', '')
             MODEL_AND_NAME_SET = True
 
-        span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CREATE_AGENT} {AGENT_NAME}"
+        span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CREATE_AGENT} {AGENT_NAME}'
 
         with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
             try:
@@ -117,7 +117,7 @@ def conversable_agent(version, environment, application_name,
 
             except Exception as e:
                 handle_exception(span, e)
-                logger.error("Error in trace creation: %s", e)
+                logger.error('Error in trace creation: %s', e)
             return response
 
     return wrapper
@@ -130,7 +130,7 @@ def agent_run(version, environment, application_name,
     def wrapper(wrapped, instance, args, kwargs):
         server_address, server_port = '127.0.0.1', 80
 
-        span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_EXECUTE_AGENT_TASK} {AGENT_NAME}"
+        span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_EXECUTE_AGENT_TASK} {AGENT_NAME}'
 
         with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
             try:
@@ -157,7 +157,7 @@ def agent_run(version, environment, application_name,
 
             except Exception as e:
                 handle_exception(span, e)
-                logger.error("Error in trace creation: %s", e)
+                logger.error('Error in trace creation: %s', e)
             return response
 
     return wrapper
openlit/instrumentation/ai21/__init__.py CHANGED
@@ -13,7 +13,7 @@ from openlit.instrumentation.ai21.async_ai21 import (
     async_chat, async_chat_rag
 )
 
-_instruments = ("ai21 >= 3.0.0",)
+_instruments = ('ai21 >= 3.0.0',)
 
 class AI21Instrumentor(BaseInstrumentor):
     """
@@ -24,40 +24,40 @@ class AI21Instrumentor(BaseInstrumentor):
         return _instruments
 
     def _instrument(self, **kwargs):
-        application_name = kwargs.get("application_name", "default_application")
-        environment = kwargs.get("environment", "default_environment")
-        tracer = kwargs.get("tracer")
-        event_provider = kwargs.get("event_provider")
-        metrics = kwargs.get("metrics_dict")
-        pricing_info = kwargs.get("pricing_info", {})
-        capture_message_content = kwargs.get("capture_message_content", False)
-        disable_metrics = kwargs.get("disable_metrics")
-        version = importlib.metadata.version("ai21")
+        application_name = kwargs.get('application_name', 'default')
+        environment = kwargs.get('environment', 'default')
+        tracer = kwargs.get('tracer')
+        event_provider = kwargs.get('event_provider')
+        metrics = kwargs.get('metrics_dict')
+        pricing_info = kwargs.get('pricing_info', {})
+        capture_message_content = kwargs.get('capture_message_content', False)
+        disable_metrics = kwargs.get('disable_metrics')
+        version = importlib.metadata.version('ai21')
 
         #sync
         wrap_function_wrapper(
-            "ai21.clients.studio.resources.chat.chat_completions",
-            "ChatCompletions.create",
+            'ai21.clients.studio.resources.chat.chat_completions',
+            'ChatCompletions.create',
             chat(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
         wrap_function_wrapper(
-            "ai21.clients.studio.resources.studio_conversational_rag",
-            "StudioConversationalRag.create",
+            'ai21.clients.studio.resources.studio_conversational_rag',
+            'StudioConversationalRag.create',
             chat_rag(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
 
         #Async
         wrap_function_wrapper(
-            "ai21.clients.studio.resources.chat.async_chat_completions",
-            "AsyncChatCompletions.create",
+            'ai21.clients.studio.resources.chat.async_chat_completions',
+            'AsyncChatCompletions.create',
             async_chat(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
         wrap_function_wrapper(
-            "ai21.clients.studio.resources.studio_conversational_rag",
-            "AsyncStudioConversationalRag.create",
+            'ai21.clients.studio.resources.studio_conversational_rag',
+            'AsyncStudioConversationalRag.create',
             async_chat_rag(version, environment, application_name,
                 tracer, event_provider, pricing_info, capture_message_content, metrics, disable_metrics),
         )
openlit/instrumentation/ai21/ai21.py CHANGED
@@ -7,7 +7,6 @@ import time
 from opentelemetry.trace import SpanKind
 from openlit.__helpers import (
     handle_exception,
-    response_as_dict,
     set_server_address_and_port,
 )
 from openlit.instrumentation.ai21.utils import (
@@ -47,9 +46,9 @@ def chat(version, environment, application_name,
             self._span = span
             self._span_name = span_name
             # Placeholder for aggregating streaming response
-            self._llmresponse = ""
-            self._response_id = ""
-            self._finish_reason = ""
+            self._llmresponse = ''
+            self._response_id = ''
+            self._finish_reason = ''
             self._input_tokens = 0
             self._output_tokens = 0
             self._choices = []
@@ -100,7 +99,7 @@ def chat(version, environment, application_name,
                 )
             except Exception as e:
                 handle_exception(self._span, e)
-                logger.error("Error in trace creation: %s", e)
+                logger.error('Error in trace creation: %s', e)
                 raise
 
     def wrapper(wrapped, instance, args, kwargs):
@@ -109,12 +108,12 @@ def chat(version, environment, application_name,
         """
 
         # Check if streaming is enabled for the API call
-        streaming = kwargs.get("stream", False)
+        streaming = kwargs.get('stream', False)
 
-        server_address, server_port = set_server_address_and_port(instance, "api.ai21.com", 443)
-        request_model = kwargs.get("model", "jamba-1.5-mini")
+        server_address, server_port = set_server_address_and_port(instance, 'api.ai21.com', 443)
+        request_model = kwargs.get('model', 'jamba-1.5-mini')
 
-        span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}"
+        span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}'
 
         # pylint: disable=no-else-return
         if streaming:
@@ -129,7 +128,7 @@ def chat(version, environment, application_name,
                 start_time = time.time()
                 response = wrapped(*args, **kwargs)
                 response = process_chat_response(
-                    response=response_as_dict(response),
+                    response=response,
                     request_model=request_model,
                     pricing_info=pricing_info,
                     server_port=server_port,
@@ -161,16 +160,16 @@ def chat_rag(version, environment, application_name,
         Wraps the GenAI function call.
         """
 
-        server_address, server_port = set_server_address_and_port(instance, "api.ai21.com", 443)
-        request_model = kwargs.get("model", "jamba-1.5-mini")
+        server_address, server_port = set_server_address_and_port(instance, 'api.ai21.com', 443)
+        request_model = kwargs.get('model', 'jamba-1.5-mini')
 
-        span_name = f"{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}"
+        span_name = f'{SemanticConvetion.GEN_AI_OPERATION_TYPE_CHAT} {request_model}'
 
         with tracer.start_as_current_span(span_name, kind= SpanKind.CLIENT) as span:
             start_time = time.time()
             response = wrapped(*args, **kwargs)
             response = process_chat_rag_response(
-                response=response_as_dict(response),
+                response=response,
                 request_model=request_model,
                 pricing_info=pricing_info,
                 server_port=server_port,