langtrace-python-sdk 2.3.2__py3-none-any.whl → 2.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. examples/anthropic_example/completion.py +1 -1
  2. examples/crewai_example/instagram_post/__init__.py +0 -0
  3. examples/crewai_example/instagram_post/agents.py +96 -0
  4. examples/crewai_example/instagram_post/main.py +80 -0
  5. examples/crewai_example/instagram_post/tasks.py +146 -0
  6. examples/crewai_example/instagram_post/tools/__init__.py +0 -0
  7. examples/crewai_example/instagram_post/tools/browser_tools.py +40 -0
  8. examples/openai_example/__init__.py +1 -0
  9. langtrace_python_sdk/instrumentation/anthropic/instrumentation.py +10 -9
  10. langtrace_python_sdk/instrumentation/anthropic/patch.py +33 -29
  11. langtrace_python_sdk/instrumentation/anthropic/types.py +105 -0
  12. langtrace_python_sdk/instrumentation/cohere/patch.py +1 -4
  13. langtrace_python_sdk/instrumentation/crewai/instrumentation.py +15 -0
  14. langtrace_python_sdk/instrumentation/crewai/patch.py +47 -25
  15. langtrace_python_sdk/instrumentation/gemini/patch.py +2 -5
  16. langtrace_python_sdk/instrumentation/groq/patch.py +7 -19
  17. langtrace_python_sdk/instrumentation/openai/instrumentation.py +14 -19
  18. langtrace_python_sdk/instrumentation/openai/patch.py +93 -101
  19. langtrace_python_sdk/instrumentation/openai/types.py +170 -0
  20. langtrace_python_sdk/instrumentation/vertexai/patch.py +2 -5
  21. langtrace_python_sdk/instrumentation/weaviate/patch.py +3 -13
  22. langtrace_python_sdk/langtrace.py +20 -21
  23. langtrace_python_sdk/utils/llm.py +12 -7
  24. langtrace_python_sdk/utils/silently_fail.py +19 -3
  25. langtrace_python_sdk/version.py +1 -1
  26. {langtrace_python_sdk-2.3.2.dist-info → langtrace_python_sdk-2.3.4.dist-info}/METADATA +1 -1
  27. {langtrace_python_sdk-2.3.2.dist-info → langtrace_python_sdk-2.3.4.dist-info}/RECORD +30 -22
  28. {langtrace_python_sdk-2.3.2.dist-info → langtrace_python_sdk-2.3.4.dist-info}/WHEEL +0 -0
  29. {langtrace_python_sdk-2.3.2.dist-info → langtrace_python_sdk-2.3.4.dist-info}/entry_points.txt +0 -0
  30. {langtrace_python_sdk-2.3.2.dist-info → langtrace_python_sdk-2.3.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,21 +1,5 @@
1
- """
2
- Copyright (c) 2024 Scale3 Labs
3
-
4
- Licensed under the Apache License, Version 2.0 (the "License");
5
- you may not use this file except in compliance with the License.
6
- You may obtain a copy of the License at
7
-
8
- http://www.apache.org/licenses/LICENSE-2.0
9
-
10
- Unless required by applicable law or agreed to in writing, software
11
- distributed under the License is distributed on an "AS IS" BASIS,
12
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- See the License for the specific language governing permissions and
14
- limitations under the License.
15
- """
16
-
17
1
  import json
18
-
2
+ from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
19
3
  from langtrace.trace_attributes import (
20
4
  LLMSpanAttributes,
21
5
  SpanAttributes,
@@ -23,7 +7,7 @@ from langtrace.trace_attributes import (
23
7
  from langtrace_python_sdk.utils import set_span_attribute
24
8
  from langtrace_python_sdk.utils.silently_fail import silently_fail
25
9
  from opentelemetry import trace
26
- from opentelemetry.trace import SpanKind
10
+ from opentelemetry.trace import SpanKind, Tracer, Span
27
11
  from opentelemetry.trace.status import Status, StatusCode
28
12
  from opentelemetry.trace.propagation import set_span_in_context
29
13
  from langtrace_python_sdk.constants.instrumentation.common import (
@@ -46,20 +30,31 @@ from langtrace_python_sdk.utils.llm import (
46
30
  )
47
31
  from langtrace_python_sdk.types import NOT_GIVEN
48
32
 
33
+ from langtrace_python_sdk.instrumentation.openai.types import (
34
+ ImagesGenerateKwargs,
35
+ ChatCompletionsCreateKwargs,
36
+ EmbeddingsCreateKwargs,
37
+ ImagesEditKwargs,
38
+ ResultType,
39
+ ContentItem,
40
+ )
41
+
49
42
 
50
- def images_generate(original_method, version, tracer):
43
+ def images_generate(version: str, tracer: Tracer) -> Callable:
51
44
  """
52
45
  Wrap the `generate` method of the `Images` class to trace it.
53
46
  """
54
47
 
55
- def traced_method(wrapped, instance, args, kwargs):
48
+ def traced_method(
49
+ wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesGenerateKwargs
50
+ ) -> Any:
56
51
  service_provider = SERVICE_PROVIDERS["OPENAI"]
57
52
  span_attributes = {
58
53
  **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
59
54
  **get_llm_request_attributes(kwargs, operation_name="images_generate"),
60
55
  **get_llm_url(instance),
61
56
  SpanAttributes.LLM_PATH: APIS["IMAGES_GENERATION"]["ENDPOINT"],
62
- **get_extra_attributes(),
57
+ **get_extra_attributes(), # type: ignore
63
58
  }
64
59
 
65
60
  attributes = LLMSpanAttributes(**span_attributes)
@@ -74,21 +69,17 @@ def images_generate(original_method, version, tracer):
74
69
  # Attempt to call the original method
75
70
  result = wrapped(*args, **kwargs)
76
71
  if not is_streaming(kwargs):
77
- data = (
72
+ data: Optional[ContentItem] = (
78
73
  result.data[0]
79
74
  if hasattr(result, "data") and len(result.data) > 0
80
- else {}
75
+ else None
81
76
  )
82
77
  response = [
83
78
  {
84
79
  "role": "assistant",
85
80
  "content": {
86
- "url": data.url if hasattr(data, "url") else "",
87
- "revised_prompt": (
88
- data.revised_prompt
89
- if hasattr(data, "revised_prompt")
90
- else ""
91
- ),
81
+ "url": getattr(data, "url", ""),
82
+ "revised_prompt": getattr(data, "revised_prompt", ""),
92
83
  },
93
84
  }
94
85
  ]
@@ -109,12 +100,14 @@ def images_generate(original_method, version, tracer):
109
100
  return traced_method
110
101
 
111
102
 
112
- def async_images_generate(original_method, version, tracer):
103
+ def async_images_generate(version: str, tracer: Tracer) -> Callable:
113
104
  """
114
105
  Wrap the `generate` method of the `Images` class to trace it.
115
106
  """
116
107
 
117
- async def traced_method(wrapped, instance, args, kwargs):
108
+ async def traced_method(
109
+ wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesGenerateKwargs
110
+ ) -> Awaitable[Any]:
118
111
  service_provider = SERVICE_PROVIDERS["OPENAI"]
119
112
 
120
113
  span_attributes = {
@@ -122,7 +115,7 @@ def async_images_generate(original_method, version, tracer):
122
115
  **get_llm_request_attributes(kwargs, operation_name="images_generate"),
123
116
  **get_llm_url(instance),
124
117
  SpanAttributes.LLM_PATH: APIS["IMAGES_GENERATION"]["ENDPOINT"],
125
- **get_extra_attributes(),
118
+ **get_extra_attributes(), # type: ignore
126
119
  }
127
120
 
128
121
  attributes = LLMSpanAttributes(**span_attributes)
@@ -137,21 +130,17 @@ def async_images_generate(original_method, version, tracer):
137
130
  # Attempt to call the original method
138
131
  result = await wrapped(*args, **kwargs)
139
132
  if not is_streaming(kwargs):
140
- data = (
133
+ data: Optional[ContentItem] = (
141
134
  result.data[0]
142
135
  if hasattr(result, "data") and len(result.data) > 0
143
- else {}
136
+ else None
144
137
  )
145
138
  response = [
146
139
  {
147
140
  "role": "assistant",
148
141
  "content": {
149
- "url": data.url if hasattr(data, "url") else "",
150
- "revised_prompt": (
151
- data.revised_prompt
152
- if hasattr(data, "revised_prompt")
153
- else ""
154
- ),
142
+ "url": getattr(data, "url", ""),
143
+ "revised_prompt": getattr(data, "revised_prompt", ""),
155
144
  },
156
145
  }
157
146
  ]
@@ -172,12 +161,14 @@ def async_images_generate(original_method, version, tracer):
172
161
  return traced_method
173
162
 
174
163
 
175
- def images_edit(original_method, version, tracer):
164
+ def images_edit(version: str, tracer: Tracer) -> Callable:
176
165
  """
177
166
  Wrap the `edit` method of the `Images` class to trace it.
178
167
  """
179
168
 
180
- def traced_method(wrapped, instance, args, kwargs):
169
+ def traced_method(
170
+ wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesEditKwargs
171
+ ) -> Any:
181
172
  service_provider = SERVICE_PROVIDERS["OPENAI"]
182
173
 
183
174
  span_attributes = {
@@ -187,7 +178,7 @@ def images_edit(original_method, version, tracer):
187
178
  SpanAttributes.LLM_PATH: APIS["IMAGES_EDIT"]["ENDPOINT"],
188
179
  SpanAttributes.LLM_RESPONSE_FORMAT: kwargs.get("response_format"),
189
180
  SpanAttributes.LLM_IMAGE_SIZE: kwargs.get("size"),
190
- **get_extra_attributes(),
181
+ **get_extra_attributes(), # type: ignore
191
182
  }
192
183
 
193
184
  attributes = LLMSpanAttributes(**span_attributes)
@@ -233,10 +224,15 @@ def images_edit(original_method, version, tracer):
233
224
  return traced_method
234
225
 
235
226
 
236
- def chat_completions_create(original_method, version, tracer):
227
+ def chat_completions_create(version: str, tracer: Tracer) -> Callable:
237
228
  """Wrap the `create` method of the `ChatCompletion` class to trace it."""
238
229
 
239
- def traced_method(wrapped, instance, args, kwargs):
230
+ def traced_method(
231
+ wrapped: Callable,
232
+ instance: Any,
233
+ args: List[Any],
234
+ kwargs: ChatCompletionsCreateKwargs,
235
+ ) -> Any:
240
236
  service_provider = SERVICE_PROVIDERS["OPENAI"]
241
237
  if "perplexity" in get_base_url(instance):
242
238
  service_provider = SERVICE_PROVIDERS["PPLX"]
@@ -251,21 +247,13 @@ def chat_completions_create(original_method, version, tracer):
251
247
  tool_calls = []
252
248
  for tool_call in tools:
253
249
  tool_call_dict = {
254
- "id": tool_call.id if hasattr(tool_call, "id") else "",
255
- "type": tool_call.type if hasattr(tool_call, "type") else "",
250
+ "id": getattr(tool_call, "id", ""),
251
+ "type": getattr(tool_call, "type", ""),
256
252
  }
257
253
  if hasattr(tool_call, "function"):
258
254
  tool_call_dict["function"] = {
259
- "name": (
260
- tool_call.function.name
261
- if hasattr(tool_call.function, "name")
262
- else ""
263
- ),
264
- "arguments": (
265
- tool_call.function.arguments
266
- if hasattr(tool_call.function, "arguments")
267
- else ""
268
- ),
255
+ "name": getattr(tool_call.function, "name", ""),
256
+ "arguments": getattr(tool_call.function, "arguments", ""),
269
257
  }
270
258
  tool_calls.append(tool_call_dict)
271
259
  llm_prompts.append(tool_calls)
@@ -277,7 +265,7 @@ def chat_completions_create(original_method, version, tracer):
277
265
  **get_llm_request_attributes(kwargs, prompts=llm_prompts),
278
266
  **get_llm_url(instance),
279
267
  SpanAttributes.LLM_PATH: APIS["CHAT_COMPLETION"]["ENDPOINT"],
280
- **get_extra_attributes(),
268
+ **get_extra_attributes(), # type: ignore
281
269
  }
282
270
 
283
271
  attributes = LLMSpanAttributes(**span_attributes)
@@ -297,12 +285,9 @@ def chat_completions_create(original_method, version, tracer):
297
285
  prompt_tokens += calculate_prompt_tokens(
298
286
  json.dumps(str(message)), kwargs.get("model")
299
287
  )
300
-
301
- if (
302
- kwargs.get("functions") is not None
303
- and kwargs.get("functions") != NOT_GIVEN
304
- ):
305
- for function in kwargs.get("functions"):
288
+ functions = kwargs.get("functions")
289
+ if functions is not None and functions != NOT_GIVEN:
290
+ for function in functions:
306
291
  prompt_tokens += calculate_prompt_tokens(
307
292
  json.dumps(function), kwargs.get("model")
308
293
  )
@@ -315,7 +300,7 @@ def chat_completions_create(original_method, version, tracer):
315
300
  tool_calls=kwargs.get("tools") is not None,
316
301
  )
317
302
  else:
318
- _set_response_attributes(span, kwargs, result)
303
+ _set_response_attributes(span, result)
319
304
  span.set_status(StatusCode.OK)
320
305
  span.end()
321
306
  return result
@@ -329,10 +314,15 @@ def chat_completions_create(original_method, version, tracer):
329
314
  return traced_method
330
315
 
331
316
 
332
- def async_chat_completions_create(original_method, version, tracer):
317
+ def async_chat_completions_create(version: str, tracer: Tracer) -> Callable:
333
318
  """Wrap the `create` method of the `ChatCompletion` class to trace it."""
334
319
 
335
- async def traced_method(wrapped, instance, args, kwargs):
320
+ async def traced_method(
321
+ wrapped: Callable,
322
+ instance: Any,
323
+ args: List[Any],
324
+ kwargs: ChatCompletionsCreateKwargs,
325
+ ) -> Awaitable[Any]:
336
326
  service_provider = SERVICE_PROVIDERS["OPENAI"]
337
327
  if "perplexity" in get_base_url(instance):
338
328
  service_provider = SERVICE_PROVIDERS["PPLX"]
@@ -345,21 +335,13 @@ def async_chat_completions_create(original_method, version, tracer):
345
335
  tool_calls = []
346
336
  for tool_call in tools:
347
337
  tool_call_dict = {
348
- "id": tool_call.id if hasattr(tool_call, "id") else "",
349
- "type": tool_call.type if hasattr(tool_call, "type") else "",
338
+ "id": getattr(tool_call, "id", ""),
339
+ "type": getattr(tool_call, "type", ""),
350
340
  }
351
341
  if hasattr(tool_call, "function"):
352
342
  tool_call_dict["function"] = {
353
- "name": (
354
- tool_call.function.name
355
- if hasattr(tool_call.function, "name")
356
- else ""
357
- ),
358
- "arguments": (
359
- tool_call.function.arguments
360
- if hasattr(tool_call.function, "arguments")
361
- else ""
362
- ),
343
+ "name": getattr(tool_call.function, "name", ""),
344
+ "arguments": getattr(tool_call.function, "arguments", ""),
363
345
  }
364
346
  tool_calls.append(json.dumps(tool_call_dict))
365
347
  llm_prompts.append(tool_calls)
@@ -371,7 +353,7 @@ def async_chat_completions_create(original_method, version, tracer):
371
353
  **get_llm_request_attributes(kwargs, prompts=llm_prompts),
372
354
  **get_llm_url(instance),
373
355
  SpanAttributes.LLM_PATH: APIS["CHAT_COMPLETION"]["ENDPOINT"],
374
- **get_extra_attributes(),
356
+ **get_extra_attributes(), # type: ignore
375
357
  }
376
358
 
377
359
  attributes = LLMSpanAttributes(**span_attributes)
@@ -392,11 +374,9 @@ def async_chat_completions_create(original_method, version, tracer):
392
374
  json.dumps((str(message))), kwargs.get("model")
393
375
  )
394
376
 
395
- if (
396
- kwargs.get("functions") is not None
397
- and kwargs.get("functions") != NOT_GIVEN
398
- ):
399
- for function in kwargs.get("functions"):
377
+ functions = kwargs.get("functions")
378
+ if functions is not None and functions != NOT_GIVEN:
379
+ for function in functions:
400
380
  prompt_tokens += calculate_prompt_tokens(
401
381
  json.dumps(function), kwargs.get("model")
402
382
  )
@@ -407,9 +387,9 @@ def async_chat_completions_create(original_method, version, tracer):
407
387
  prompt_tokens,
408
388
  function_call=kwargs.get("functions") is not None,
409
389
  tool_calls=kwargs.get("tools") is not None,
410
- )
390
+ ) # type: ignore
411
391
  else:
412
- _set_response_attributes(span, kwargs, result)
392
+ _set_response_attributes(span, result)
413
393
  span.set_status(StatusCode.OK)
414
394
  span.end()
415
395
  return result
@@ -423,12 +403,17 @@ def async_chat_completions_create(original_method, version, tracer):
423
403
  return traced_method
424
404
 
425
405
 
426
- def embeddings_create(original_method, version, tracer):
406
+ def embeddings_create(version: str, tracer: Tracer) -> Callable:
427
407
  """
428
408
  Wrap the `create` method of the `Embeddings` class to trace it.
429
409
  """
430
410
 
431
- def traced_method(wrapped, instance, args, kwargs):
411
+ def traced_method(
412
+ wrapped: Callable,
413
+ instance: Any,
414
+ args: List[Any],
415
+ kwargs: EmbeddingsCreateKwargs,
416
+ ) -> Any:
432
417
  service_provider = SERVICE_PROVIDERS["OPENAI"]
433
418
 
434
419
  span_attributes = {
@@ -437,7 +422,7 @@ def embeddings_create(original_method, version, tracer):
437
422
  **get_llm_url(instance),
438
423
  SpanAttributes.LLM_PATH: APIS["EMBEDDINGS_CREATE"]["ENDPOINT"],
439
424
  SpanAttributes.LLM_REQUEST_DIMENSIONS: kwargs.get("dimensions"),
440
- **get_extra_attributes(),
425
+ **get_extra_attributes(), # type: ignore
441
426
  }
442
427
 
443
428
  encoding_format = kwargs.get("encoding_format")
@@ -480,12 +465,17 @@ def embeddings_create(original_method, version, tracer):
480
465
  return traced_method
481
466
 
482
467
 
483
- def async_embeddings_create(original_method, version, tracer):
468
+ def async_embeddings_create(version: str, tracer: Tracer) -> Callable:
484
469
  """
485
470
  Wrap the `create` method of the `Embeddings` class to trace it.
486
471
  """
487
472
 
488
- async def traced_method(wrapped, instance, args, kwargs):
473
+ async def traced_method(
474
+ wrapped: Callable,
475
+ instance: Any,
476
+ args: List[Any],
477
+ kwargs: EmbeddingsCreateKwargs,
478
+ ) -> Awaitable[Any]:
489
479
 
490
480
  service_provider = SERVICE_PROVIDERS["OPENAI"]
491
481
 
@@ -494,7 +484,7 @@ def async_embeddings_create(original_method, version, tracer):
494
484
  **get_llm_request_attributes(kwargs, operation_name="embed"),
495
485
  SpanAttributes.LLM_PATH: APIS["EMBEDDINGS_CREATE"]["ENDPOINT"],
496
486
  SpanAttributes.LLM_REQUEST_DIMENSIONS: kwargs.get("dimensions"),
497
- **get_extra_attributes(),
487
+ **get_extra_attributes(), # type: ignore
498
488
  }
499
489
 
500
490
  attributes = LLMSpanAttributes(**span_attributes)
@@ -537,7 +527,7 @@ def async_embeddings_create(original_method, version, tracer):
537
527
  return traced_method
538
528
 
539
529
 
540
- def extract_content(choice):
530
+ def extract_content(choice: Any) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
541
531
  # Check if choice.message exists and has a content attribute
542
532
  if (
543
533
  hasattr(choice, "message")
@@ -582,13 +572,15 @@ def extract_content(choice):
582
572
 
583
573
 
584
574
  @silently_fail
585
- def _set_input_attributes(span, kwargs, attributes):
575
+ def _set_input_attributes(
576
+ span: Span, kwargs: ChatCompletionsCreateKwargs, attributes: LLMSpanAttributes
577
+ ) -> None:
586
578
  tools = []
587
579
  for field, value in attributes.model_dump(by_alias=True).items():
588
580
  set_span_attribute(span, field, value)
589
-
590
- if kwargs.get("functions") is not None and kwargs.get("functions") != NOT_GIVEN:
591
- for function in kwargs.get("functions"):
581
+ functions = kwargs.get("functions")
582
+ if functions is not None and functions != NOT_GIVEN:
583
+ for function in functions:
592
584
  tools.append(json.dumps({"type": "function", "function": function}))
593
585
 
594
586
  if kwargs.get("tools") is not None and kwargs.get("tools") != NOT_GIVEN:
@@ -599,7 +591,7 @@ def _set_input_attributes(span, kwargs, attributes):
599
591
 
600
592
 
601
593
  @silently_fail
602
- def _set_response_attributes(span, kwargs, result):
594
+ def _set_response_attributes(span: Span, result: ResultType) -> None:
603
595
  set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, result.model)
604
596
  if hasattr(result, "choices") and result.choices is not None:
605
597
  responses = [
@@ -611,8 +603,8 @@ def _set_response_attributes(span, kwargs, result):
611
603
  ),
612
604
  "content": extract_content(choice),
613
605
  **(
614
- {"content_filter_results": choice["content_filter_results"]}
615
- if "content_filter_results" in choice
606
+ {"content_filter_results": choice.content_filter_results}
607
+ if hasattr(choice, "content_filter_results")
616
608
  else {}
617
609
  ),
618
610
  }
@@ -0,0 +1,170 @@
1
+ """
2
+ Copyright (c) 2024 Scale3 Labs
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+ http://www.apache.org/licenses/LICENSE-2.0
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS,
9
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ See the License for the specific language governing permissions and
11
+ limitations under the License.
12
+ """
13
+
14
+ from typing import Any, Dict, List, Union, Optional, TypedDict
15
+
16
+
17
+ class ContentItem:
18
+ url: str
19
+ revised_prompt: str
20
+ base64: Optional[str]
21
+
22
+ def __init__(
23
+ self,
24
+ url: str,
25
+ revised_prompt: str,
26
+ base64: Optional[str],
27
+ ):
28
+ self.url = url
29
+ self.revised_prompt = revised_prompt
30
+ self.base64 = base64
31
+
32
+
33
+ class ToolFunction:
34
+ name: str
35
+ arguments: str
36
+
37
+ def __init__(
38
+ self,
39
+ name: str,
40
+ arguments: str,
41
+ ):
42
+ self.name = name
43
+ self.arguments = arguments
44
+
45
+
46
+ class ToolCall:
47
+ id: str
48
+ type: str
49
+ function: ToolFunction
50
+
51
+ def __init__(
52
+ self,
53
+ id: str,
54
+ type: str,
55
+ function: ToolFunction,
56
+ ):
57
+ self.id = id
58
+ self.type = type
59
+ self.function = function
60
+
61
+
62
+ class Message:
63
+ role: str
64
+ content: Union[str, List[ContentItem], Dict[str, Any]]
65
+ tool_calls: Optional[List[ToolCall]]
66
+
67
+ def __init__(
68
+ self,
69
+ role: str,
70
+ content: Union[str, List[ContentItem], Dict[str, Any]],
71
+ content_filter_results: Optional[Any],
72
+ ):
73
+ self.role = role
74
+ self.content = content
75
+ self.content_filter_results = content_filter_results
76
+
77
+
78
+ class Usage:
79
+ prompt_tokens: int
80
+ completion_tokens: int
81
+ total_tokens: int
82
+
83
+ def __init__(
84
+ self,
85
+ prompt_tokens: int,
86
+ completion_tokens: int,
87
+ total_tokens: int,
88
+ ):
89
+ self.prompt_tokens = prompt_tokens
90
+ self.completion_tokens = completion_tokens
91
+ self.total_tokens = total_tokens
92
+
93
+
94
+ class Choice:
95
+ message: Message
96
+ content_filter_results: Optional[Any]
97
+
98
+ def __init__(
99
+ self,
100
+ message: Message,
101
+ content_filter_results: Optional[Any],
102
+ ):
103
+ self.message = message
104
+ self.content_filter_results = content_filter_results
105
+
106
+
107
+ class ResultType:
108
+ model: Optional[str]
109
+ content: List[ContentItem]
110
+ system_fingerprint: Optional[str]
111
+ usage: Optional[Usage]
112
+ choices: Optional[List[Choice]]
113
+ response_format: Optional[str]
114
+ size: Optional[str]
115
+ encoding_format: Optional[str]
116
+
117
+ def __init__(
118
+ self,
119
+ model: Optional[str],
120
+ role: Optional[str],
121
+ content: List[ContentItem],
122
+ system_fingerprint: Optional[str],
123
+ usage: Optional[Usage],
124
+ functions: Optional[List[ToolCall]],
125
+ tools: Optional[List[ToolCall]],
126
+ choices: Optional[List[Choice]],
127
+ response_format: Optional[str],
128
+ size: Optional[str],
129
+ encoding_format: Optional[str],
130
+ ):
131
+ self.model = model
132
+ self.role = role
133
+ self.content = content
134
+ self.system_fingerprint = system_fingerprint
135
+ self.usage = usage
136
+ self.functions = functions
137
+ self.tools = tools
138
+ self.choices = choices
139
+ self.response_format = response_format
140
+ self.size = size
141
+ self.encoding_format = encoding_format
142
+
143
+
144
+ class ImagesGenerateKwargs(TypedDict, total=False):
145
+ operation_name: str
146
+ model: Optional[str]
147
+ messages: Optional[List[Message]]
148
+ functions: Optional[List[ToolCall]]
149
+ tools: Optional[List[ToolCall]]
150
+ response_format: Optional[str]
151
+ size: Optional[str]
152
+ encoding_format: Optional[str]
153
+
154
+
155
+ class ImagesEditKwargs(TypedDict, total=False):
156
+ response_format: Optional[str]
157
+ size: Optional[str]
158
+
159
+
160
+ class ChatCompletionsCreateKwargs(TypedDict, total=False):
161
+ model: Optional[str]
162
+ messages: List[Message]
163
+ functions: Optional[List[ToolCall]]
164
+ tools: Optional[List[ToolCall]]
165
+
166
+
167
+ class EmbeddingsCreateKwargs(TypedDict, total=False):
168
+ dimensions: Optional[str]
169
+ input: Union[str, List[str], None]
170
+ encoding_format: Optional[Union[List[str], str]]
@@ -102,12 +102,9 @@ def is_streaming_response(response):
102
102
 
103
103
 
104
104
  def get_llm_model(instance):
105
- llm_model = "unknown"
106
- if hasattr(instance, "_model_id"):
107
- llm_model = instance._model_id
108
105
  if hasattr(instance, "_model_name"):
109
- llm_model = instance._model_name.replace("publishers/google/models/", "")
110
- return llm_model
106
+ return instance._model_name.replace("models/", "")
107
+ return getattr(instance, "_model_id", "unknown")
111
108
 
112
109
 
113
110
  def serialize_prompts(args, kwargs):
@@ -96,19 +96,9 @@ def get_response_object_attributes(response_object):
96
96
  response_attributes = {
97
97
  **response_object.properties,
98
98
  "uuid": str(response_object.uuid) if hasattr(response_object, "uuid") else None,
99
- "collection": (
100
- response_object.collection
101
- if hasattr(response_object, "collection")
102
- else None
103
- ),
104
- "vector": (
105
- response_object.vector if hasattr(response_object, "vector") else None
106
- ),
107
- "references": (
108
- response_object.references
109
- if hasattr(response_object, "references")
110
- else None
111
- ),
99
+ "collection": getattr(response_object, "collection", None),
100
+ "vector": getattr(response_object, "vector", None),
101
+ "references": getattr(response_object, "references", None),
112
102
  "metadata": (
113
103
  extract_metadata(response_object.metadata)
114
104
  if hasattr(response_object, "metadata")