lmnr 0.5.1a0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. lmnr/__init__.py +0 -8
  2. lmnr/openllmetry_sdk/__init__.py +5 -33
  3. lmnr/openllmetry_sdk/decorators/base.py +24 -17
  4. lmnr/openllmetry_sdk/instruments.py +1 -0
  5. lmnr/openllmetry_sdk/opentelemetry/instrumentation/google_genai/__init__.py +454 -0
  6. lmnr/openllmetry_sdk/opentelemetry/instrumentation/google_genai/config.py +9 -0
  7. lmnr/openllmetry_sdk/opentelemetry/instrumentation/google_genai/utils.py +216 -0
  8. lmnr/openllmetry_sdk/tracing/__init__.py +1 -0
  9. lmnr/openllmetry_sdk/tracing/context_manager.py +13 -0
  10. lmnr/openllmetry_sdk/tracing/tracing.py +230 -252
  11. lmnr/sdk/browser/playwright_otel.py +42 -58
  12. lmnr/sdk/browser/pw_utils.py +8 -40
  13. lmnr/sdk/client/asynchronous/async_client.py +0 -34
  14. lmnr/sdk/client/asynchronous/resources/__init__.py +0 -4
  15. lmnr/sdk/client/asynchronous/resources/agent.py +96 -6
  16. lmnr/sdk/client/synchronous/resources/__init__.py +1 -3
  17. lmnr/sdk/client/synchronous/resources/agent.py +94 -8
  18. lmnr/sdk/client/synchronous/sync_client.py +0 -36
  19. lmnr/sdk/decorators.py +16 -2
  20. lmnr/sdk/laminar.py +3 -3
  21. lmnr/sdk/types.py +84 -170
  22. lmnr/sdk/utils.py +8 -1
  23. lmnr/version.py +1 -1
  24. {lmnr-0.5.1a0.dist-info → lmnr-0.5.2.dist-info}/METADATA +57 -57
  25. lmnr-0.5.2.dist-info/RECORD +54 -0
  26. lmnr/sdk/client/asynchronous/resources/pipeline.py +0 -89
  27. lmnr/sdk/client/asynchronous/resources/semantic_search.py +0 -60
  28. lmnr/sdk/client/synchronous/resources/pipeline.py +0 -89
  29. lmnr/sdk/client/synchronous/resources/semantic_search.py +0 -60
  30. lmnr-0.5.1a0.dist-info/RECORD +0 -54
  31. {lmnr-0.5.1a0.dist-info → lmnr-0.5.2.dist-info}/LICENSE +0 -0
  32. {lmnr-0.5.1a0.dist-info → lmnr-0.5.2.dist-info}/WHEEL +0 -0
  33. {lmnr-0.5.1a0.dist-info → lmnr-0.5.2.dist-info}/entry_points.txt +0 -0
lmnr/__init__.py CHANGED
@@ -6,11 +6,7 @@ from .sdk.laminar import Laminar
 from .sdk.types import (
     AgentOutput,
     FinalOutputChunkContent,
-    ChatMessage,
     HumanEvaluator,
-    NodeInput,
-    PipelineRunError,
-    PipelineRunResponse,
     RunAgentResponseChunk,
     StepChunkContent,
     TracingLevel,
@@ -25,7 +21,6 @@ __all__ = [
     "AgentOutput",
     "AsyncLaminarClient",
     "Attributes",
-    "ChatMessage",
     "EvaluationDataset",
     "FinalOutputChunkContent",
     "HumanEvaluator",
@@ -34,9 +29,6 @@ __all__ = [
     "LaminarClient",
     "LaminarDataset",
     "LaminarSpanContext",
-    "NodeInput",
-    "PipelineRunError",
-    "PipelineRunResponse",
     "RunAgentResponseChunk",
     "StepChunkContent",
     "TracingLevel",
lmnr/openllmetry_sdk/__init__.py CHANGED
@@ -1,6 +1,5 @@
 import sys
 
-from contextlib import contextmanager
 from typing import Optional, Set
 from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.sdk.trace.export import SpanExporter
@@ -19,7 +18,6 @@ from typing import Dict
 
 class TracerManager:
     __tracer_wrapper: TracerWrapper
-    __initialized: bool = False
 
     @staticmethod
     def init(
@@ -53,6 +51,9 @@ class TracerManager:
 
         # Tracer init
         resource_attributes.update({SERVICE_NAME: app_name})
+        TracerWrapper.set_static_params(
+            resource_attributes, enable_content_tracing, api_endpoint, headers
+        )
         TracerManager.__tracer_wrapper = TracerWrapper(
             disable_batch=disable_batch,
             processor=processor,
@@ -63,41 +64,12 @@
             base_http_url=base_http_url,
             project_api_key=project_api_key,
             max_export_batch_size=max_export_batch_size,
-            resource_attributes=resource_attributes,
-            enable_content_tracing=enable_content_tracing,
-            endpoint=api_endpoint,
-            headers=headers,
         )
-        TracerManager.__initialized = True
 
     @staticmethod
     def flush() -> bool:
         return TracerManager.__tracer_wrapper.flush()
 
     @staticmethod
-    def shutdown() -> bool:
-        try:
-            res = TracerManager.__tracer_wrapper.shutdown()
-            TracerManager.__tracer_wrapper = None
-            TracerManager.__initialized = False
-            return res
-        except Exception:
-            return False
-
-    @staticmethod
-    def is_initialized() -> bool:
-        return TracerManager.__initialized
-
-    @staticmethod
-    def get_tracer_wrapper() -> TracerWrapper:
-        return TracerManager.__tracer_wrapper
-
-
-@contextmanager
-def get_tracer(flush_on_exit: bool = False):
-    wrapper = TracerManager.get_tracer_wrapper()
-    try:
-        yield wrapper.get_tracer()
-    finally:
-        if flush_on_exit:
-            wrapper.flush()
+    def shutdown():
+        TracerManager.__tracer_wrapper.shutdown()
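Note: get_tracer moves out of this module (the new lmnr/openllmetry_sdk/tracing/context_manager.py, file 9 above, is its likely new home), and the initialization flag moves to TracerWrapper. A sketch of the updated import, assuming the relocated context manager keeps the old flush_on_exit signature:

    # Old import path (0.5.1a0): from lmnr.openllmetry_sdk import get_tracer
    from lmnr.openllmetry_sdk.tracing import get_tracer

    with get_tracer(flush_on_exit=True) as tracer:  # assumed signature
        with tracer.start_as_current_span("example"):
            pass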
lmnr/openllmetry_sdk/decorators/base.py CHANGED
@@ -1,7 +1,6 @@
 import json
 from functools import wraps
 import logging
-import os
 import pydantic
 import types
 from typing import Any, Literal, Optional, Union
@@ -11,9 +10,9 @@ from opentelemetry import context as context_api
 from opentelemetry.trace import Span
 
 from lmnr.sdk.utils import get_input_from_func_args, is_method
-from lmnr.openllmetry_sdk import get_tracer
+from lmnr.openllmetry_sdk.tracing import get_tracer
 from lmnr.openllmetry_sdk.tracing.attributes import SPAN_INPUT, SPAN_OUTPUT, SPAN_TYPE
-from lmnr.openllmetry_sdk import TracerManager
+from lmnr.openllmetry_sdk.tracing.tracing import TracerWrapper
 from lmnr.openllmetry_sdk.utils.json_encoder import JSONEncoder
 from lmnr.openllmetry_sdk.config import MAX_MANUAL_SPAN_PAYLOAD_SIZE
 
@@ -40,13 +39,14 @@ def json_dumps(data: dict) -> str:
 def entity_method(
     name: Optional[str] = None,
     ignore_input: bool = False,
+    ignore_inputs: Optional[list[str]] = None,
     ignore_output: bool = False,
     span_type: Union[Literal["DEFAULT"], Literal["LLM"], Literal["TOOL"]] = "DEFAULT",
 ):
     def decorate(fn):
         @wraps(fn)
         def wrap(*args, **kwargs):
-            if not TracerManager.is_initialized():
+            if not TracerWrapper.verify_initialized():
                 return fn(*args, **kwargs)
 
             span_name = name or fn.__name__
@@ -58,9 +58,15 @@ def entity_method(
             ctx_token = context_api.attach(ctx)
 
             try:
-                if _should_send_prompts() and not ignore_input:
+                if not ignore_input:
                     inp = json_dumps(
-                        get_input_from_func_args(fn, is_method(fn), args, kwargs)
+                        get_input_from_func_args(
+                            fn,
+                            is_method=is_method(fn),
+                            func_args=args,
+                            func_kwargs=kwargs,
+                            ignore_inputs=ignore_inputs,
+                        )
                     )
                     if len(inp) > MAX_MANUAL_SPAN_PAYLOAD_SIZE:
                         span.set_attribute(
@@ -83,7 +89,7 @@ def entity_method(
                 return _handle_generator(span, res)
 
             try:
-                if _should_send_prompts() and not ignore_output:
+                if not ignore_output:
                     output = json_dumps(res)
                     if len(output) > MAX_MANUAL_SPAN_PAYLOAD_SIZE:
                         span.set_attribute(
@@ -108,13 +114,14 @@ def entity_method(
 def aentity_method(
     name: Optional[str] = None,
     ignore_input: bool = False,
+    ignore_inputs: Optional[list[str]] = None,
     ignore_output: bool = False,
     span_type: Union[Literal["DEFAULT"], Literal["LLM"], Literal["TOOL"]] = "DEFAULT",
 ):
     def decorate(fn):
         @wraps(fn)
         async def wrap(*args, **kwargs):
-            if not TracerManager.is_initialized():
+            if not TracerWrapper.verify_initialized():
                 return await fn(*args, **kwargs)
 
             span_name = name or fn.__name__
@@ -126,9 +133,15 @@ def aentity_method(
             ctx_token = context_api.attach(ctx)
 
             try:
-                if _should_send_prompts() and not ignore_input:
+                if not ignore_input:
                     inp = json_dumps(
-                        get_input_from_func_args(fn, is_method(fn), args, kwargs)
+                        get_input_from_func_args(
+                            fn,
+                            is_method=is_method(fn),
+                            func_args=args,
+                            func_kwargs=kwargs,
+                            ignore_inputs=ignore_inputs,
+                        )
                     )
                     if len(inp) > MAX_MANUAL_SPAN_PAYLOAD_SIZE:
                         span.set_attribute(
@@ -151,7 +164,7 @@ def aentity_method(
                 return await _ahandle_generator(span, ctx_token, res)
 
             try:
-                if _should_send_prompts() and not ignore_output:
+                if not ignore_output:
                     output = json_dumps(res)
                     if len(output) > MAX_MANUAL_SPAN_PAYLOAD_SIZE:
                         span.set_attribute(
@@ -192,12 +205,6 @@ async def _ahandle_generator(span, ctx_token, res):
         context_api.detach(ctx_token)
 
 
-def _should_send_prompts():
-    return (
-        os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
-    ).lower() == "true" or context_api.get_value("override_enable_content_tracing")
-
-
 def _process_exception(span: Span, e: Exception):
     # Note that this `escaped` is sent as a StringValue("True"), not a boolean.
     span.record_exception(e, escaped=True)
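Note: two behavioral changes here. Input/output capture no longer consults the TRACELOOP_TRACE_CONTENT gate (_should_send_prompts is deleted; that env-var check survives only in the new google_genai instrumentation below), and a new ignore_inputs parameter drops named arguments from the recorded span input. A minimal sketch of the new parameter using entity_method directly; the argument names are hypothetical:

    from lmnr.openllmetry_sdk.decorators.base import entity_method

    # "api_key" is excluded from the recorded SPAN_INPUT; "query" is kept.
    @entity_method(name="search", ignore_inputs=["api_key"])
    def search(query: str, api_key: str) -> list[str]:
        return [query]

The +16/-2 change to lmnr/sdk/decorators.py (file 19) suggests the public observe decorator forwards the same parameter.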
lmnr/openllmetry_sdk/instruments.py CHANGED
@@ -11,6 +11,7 @@ class Instruments(Enum):
     CHROMA = "chroma"
     COHERE = "cohere"
     GOOGLE_GENERATIVEAI = "google_generativeai"
+    GOOGLE_GENAI = "google_genai"
     GROQ = "groq"
     HAYSTACK = "haystack"
     LANCEDB = "lancedb"
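Note: the new member targets the newer google-genai SDK, separate from the legacy GOOGLE_GENERATIVEAI instrument. A sketch of opting in, assuming Laminar.initialize continues to accept an instruments set as in prior releases:

    from lmnr import Laminar
    from lmnr.openllmetry_sdk.instruments import Instruments

    Laminar.initialize(
        project_api_key="...",  # placeholder
        instruments={Instruments.GOOGLE_GENAI},
    )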
lmnr/openllmetry_sdk/opentelemetry/instrumentation/google_genai/__init__.py ADDED
@@ -0,0 +1,454 @@
+"""OpenTelemetry Google Generative AI API instrumentation"""
+
+from collections import defaultdict
+import logging
+import os
+from typing import AsyncGenerator, Callable, Collection, Generator, Optional
+
+from google.genai import types
+
+from .config import (
+    Config,
+)
+from .utils import (
+    dont_throw,
+    role_from_content_union,
+    set_span_attribute,
+    process_content_union,
+    to_dict,
+    with_tracer_wrapper,
+)
+from opentelemetry.trace import Tracer
+from wrapt import wrap_function_wrapper
+
+from opentelemetry import context as context_api
+from opentelemetry.trace import get_tracer, SpanKind, Span
+from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
+
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY, unwrap
+
+from opentelemetry.semconv_ai import (
+    SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
+    SpanAttributes,
+    LLMRequestTypeValues,
+)
+
+logger = logging.getLogger(__name__)
+
+_instruments = ("google-genai >= 1.0.0",)
+
+WRAPPED_METHODS = [
+    {
+        "package": "google.genai.models",
+        "object": "Models",
+        "method": "generate_content",
+        "span_name": "gemini.generate_content",
+        "is_streaming": False,
+        "is_async": False,
+    },
+    {
+        "package": "google.genai.models",
+        "object": "AsyncModels",
+        "method": "generate_content",
+        "span_name": "gemini.generate_content",
+        "is_streaming": False,
+        "is_async": True,
+    },
+    {
+        "package": "google.genai.models",
+        "object": "Models",
+        "method": "generate_content_stream",
+        "span_name": "gemini.generate_content_stream",
+        "is_streaming": True,
+        "is_async": False,
+    },
+    {
+        "package": "google.genai.models",
+        "object": "AsyncModels",
+        "method": "generate_content_stream",
+        "span_name": "gemini.generate_content_stream",
+        "is_streaming": True,
+        "is_async": True,
+    },
+]
+
+
+def should_send_prompts():
+    return (
+        os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
+    ).lower() == "true" or context_api.get_value("override_enable_content_tracing")
+
+
+@dont_throw
+def _set_request_attributes(span, args, kwargs):
+    config_dict = to_dict(kwargs.get("config", {}))
+    set_span_attribute(
+        span, gen_ai_attributes.GEN_AI_REQUEST_MODEL, kwargs.get("model")
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE,
+        config_dict.get("temperature"),
+    )
+    set_span_attribute(
+        span, gen_ai_attributes.GEN_AI_REQUEST_TOP_P, config_dict.get("top_p")
+    )
+    set_span_attribute(
+        span, gen_ai_attributes.GEN_AI_REQUEST_TOP_K, config_dict.get("top_k")
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT,
+        config_dict.get("candidate_count"),
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS,
+        config_dict.get("max_output_tokens"),
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES,
+        config_dict.get("stop_sequences"),
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY,
+        config_dict.get("frequency_penalty"),
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY,
+        config_dict.get("presence_penalty"),
+    )
+    set_span_attribute(
+        span, gen_ai_attributes.GEN_AI_REQUEST_SEED, config_dict.get("seed")
+    )
+
+    tools: list[types.FunctionDeclaration] = []
+    if kwargs.get("tools"):
+        for tool in kwargs.get("tools"):
+            if isinstance(tool, types.Tool):
+                tools += tool.function_declarations or []
+            elif isinstance(tool, Callable):
+                tools.append(types.FunctionDeclaration.from_callable(tool))
+    for tool_num, tool in enumerate(tools):
+        set_span_attribute(
+            span,
+            f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{tool_num}.name",
+            to_dict(tool).get("name"),
+        )
+        set_span_attribute(
+            span,
+            f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{tool_num}.description",
+            to_dict(tool).get("description"),
+        )
+        set_span_attribute(
+            span,
+            f"{SpanAttributes.LLM_REQUEST_FUNCTIONS}.{tool_num}.parameters",
+            to_dict(tool).get("parameters"),
+        )
+
+    if should_send_prompts():
+        i = 0
+        system_instruction: Optional[types.ContentUnion] = config_dict.get(
+            "system_instruction"
+        )
+        if system_instruction:
+            set_span_attribute(
+                span,
+                f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.content",
+                process_content_union(system_instruction),
+            )
+            set_span_attribute(
+                span, f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.role", "system"
+            )
+            i += 1
+        contents = kwargs.get("contents", [])
+        if not isinstance(contents, list):
+            contents = [contents]
+        for content in contents:
+            set_span_attribute(
+                span,
+                f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.content",
+                process_content_union(content),
+            )
+            set_span_attribute(
+                span,
+                f"{gen_ai_attributes.GEN_AI_PROMPT}.{i}.role",
+                role_from_content_union(content) or "user",
+            )
+            i += 1
+
+
+@dont_throw
+def _set_response_attributes(span, response: types.GenerateContentResponse):
+    candidates = response.candidates or []
+    set_span_attribute(
+        span, gen_ai_attributes.GEN_AI_RESPONSE_ID, to_dict(response).get("response_id")
+    )
+    set_span_attribute(
+        span,
+        gen_ai_attributes.GEN_AI_RESPONSE_MODEL,
+        to_dict(response).get("model_version"),
+    )
+
+    if response.usage_metadata:
+        usage_dict = to_dict(response.usage_metadata)
+        set_span_attribute(
+            span,
+            gen_ai_attributes.GEN_AI_USAGE_INPUT_TOKENS,
+            usage_dict.get("prompt_token_count"),
+        )
+        set_span_attribute(
+            span,
+            gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS,
+            usage_dict.get("candidates_token_count"),
+        )
+        set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+            usage_dict.get("total_token_count"),
+        )
+        set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS,
+            usage_dict.get("cached_content_token_count"),
+        )
+
+    if should_send_prompts():
+        if len(candidates) > 1:
+            for i, candidate in enumerate(candidates):
+                set_span_attribute(
+                    span,
+                    f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.content",
+                    process_content_union(candidate.content),
+                )
+                set_span_attribute(
+                    span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.{i}.role", "assistant"
+                )
+        else:
+            set_span_attribute(
+                span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.0.content", response.text
+            )
+            set_span_attribute(
+                span, f"{gen_ai_attributes.GEN_AI_COMPLETION}.0.role", "assistant"
+            )
+
+
+@dont_throw
+def _build_from_streaming_response(
+    span: Span, response: Generator[types.GenerateContentResponse, None, None]
+) -> Generator[types.GenerateContentResponse, None, None]:
+    final_parts = []
+    role = "model"
+    aggregated_usage_metadata = defaultdict(int)
+    model_version = None
+    for chunk in response:
+        if chunk.model_version:
+            model_version = chunk.model_version
+
+        if chunk.candidates:
+            # Currently gemini throws an error if you pass more than one candidate
+            # with streaming
+            if chunk.candidates and len(chunk.candidates) > 0:
+                final_parts += chunk.candidates[0].content.parts or []
+                role = chunk.candidates[0].content.role or role
+        if chunk.usage_metadata:
+            usage_dict = to_dict(chunk.usage_metadata)
+            # prompt token count is sent in every chunk
+            # (and is less by 1 in the last chunk, so we set it once);
+            # total token count in every chunk is greater by prompt token count than it should be,
+            # thus this awkward logic here
+            if aggregated_usage_metadata.get("prompt_token_count") is None:
+                aggregated_usage_metadata["prompt_token_count"] = (
+                    usage_dict.get("prompt_token_count") or 0
+                )
+                aggregated_usage_metadata["total_token_count"] = (
+                    usage_dict.get("total_token_count") or 0
+                )
+            aggregated_usage_metadata["candidates_token_count"] += (
+                usage_dict.get("candidates_token_count") or 0
+            )
+            aggregated_usage_metadata["total_token_count"] += (
+                usage_dict.get("candidates_token_count") or 0
+            )
+        yield chunk
+
+    compound_response = types.GenerateContentResponse(
+        candidates=[
+            {
+                "content": {
+                    "parts": final_parts,
+                    "role": role,
+                },
+            }
+        ],
+        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+            **aggregated_usage_metadata
+        ),
+        model_version=model_version,
+    )
+    if span.is_recording():
+        _set_response_attributes(span, compound_response)
+    span.end()
+
+
+@dont_throw
+async def _abuild_from_streaming_response(
+    span: Span, response: AsyncGenerator[types.GenerateContentResponse, None]
+) -> AsyncGenerator[types.GenerateContentResponse, None]:
+    final_parts = []
+    role = "model"
+    aggregated_usage_metadata = defaultdict(int)
+    model_version = None
+    async for chunk in response:
+        if chunk.candidates:
+            # Currently gemini throws an error if you pass more than one candidate
+            # with streaming
+            if chunk.candidates and len(chunk.candidates) > 0:
+                final_parts += chunk.candidates[0].content.parts or []
+                role = chunk.candidates[0].content.role or role
+        if chunk.model_version:
+            model_version = chunk.model_version
+        if chunk.usage_metadata:
+            usage_dict = to_dict(chunk.usage_metadata)
+            # prompt token count is sent in every chunk
+            # (and is less by 1 in the last chunk, so we set it once);
+            # total token count in every chunk is greater by prompt token count than it should be,
+            # thus this awkward logic here
+            if aggregated_usage_metadata.get("prompt_token_count") is None:
+                aggregated_usage_metadata["prompt_token_count"] = usage_dict.get(
+                    "prompt_token_count"
+                )
+                aggregated_usage_metadata["total_token_count"] = usage_dict.get(
+                    "total_token_count"
+                )
+            aggregated_usage_metadata["candidates_token_count"] += (
+                usage_dict.get("candidates_token_count") or 0
+            )
+            aggregated_usage_metadata["total_token_count"] += (
+                usage_dict.get("candidates_token_count") or 0
+            )
+        yield chunk
+
+    compound_response = types.GenerateContentResponse(
+        candidates=[
+            {
+                "content": {
+                    "parts": final_parts,
+                    "role": role,
+                },
+            }
+        ],
+        usage_metadata=types.GenerateContentResponseUsageMetadataDict(
+            **aggregated_usage_metadata
+        ),
+        model_version=model_version,
+    )
+    if span.is_recording():
+        _set_response_attributes(span, compound_response)
+    span.end()
+
+
+@with_tracer_wrapper
+def _wrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
+    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
+        SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
+    ):
+        return wrapped(*args, **kwargs)
+
+    span = tracer.start_span(
+        to_wrap.get("span_name"),
+        kind=SpanKind.CLIENT,
+        attributes={
+            SpanAttributes.LLM_SYSTEM: "gemini",
+            SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
+        },
+    )
+
+    if span.is_recording():
+        _set_request_attributes(span, args, kwargs)
+
+    if to_wrap.get("is_streaming"):
+        return _build_from_streaming_response(span, wrapped(*args, **kwargs))
+    else:
+        response = wrapped(*args, **kwargs)
+
+        if span.is_recording():
+            _set_response_attributes(span, response)
+
+        span.end()
+        return response
+
+
+@with_tracer_wrapper
+async def _awrap(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
+    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
+        SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
+    ):
+        return await wrapped(*args, **kwargs)
+
+    span = tracer.start_span(
+        to_wrap.get("span_name"),
+        kind=SpanKind.CLIENT,
+        attributes={
+            SpanAttributes.LLM_SYSTEM: "gemini",
+            SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
+        },
+    )
+
+    if span.is_recording():
+        _set_request_attributes(span, args, kwargs)
+
+    if to_wrap.get("is_streaming"):
+        return _abuild_from_streaming_response(span, await wrapped(*args, **kwargs))
+    else:
+        response = await wrapped(*args, **kwargs)
+
+        if span.is_recording():
+            _set_response_attributes(span, response)
+
+        span.end()
+        return response
+
+
+class GoogleGenAiSdkInstrumentor(BaseInstrumentor):
+    """An instrumentor for Google GenAI's client library."""
+
+    def __init__(
+        self,
+        exception_logger=None,
+        upload_base64_image=None,
+        convert_image_to_openai_format=True,
+    ):
+        super().__init__()
+        Config.exception_logger = exception_logger
+        Config.upload_base64_image = upload_base64_image
+        Config.convert_image_to_openai_format = convert_image_to_openai_format
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return _instruments
+
+    def _instrument(self, **kwargs):
+        tracer_provider = kwargs.get("tracer_provider")
+        tracer = get_tracer(__name__, "0.0.1a0", tracer_provider)
+
+        for wrapped_method in WRAPPED_METHODS:
+            wrap_function_wrapper(
+                wrapped_method.get("package"),
+                f"{wrapped_method.get('object')}.{wrapped_method.get('method')}",
+                (
+                    _awrap(tracer, wrapped_method)
+                    if wrapped_method.get("is_async")
+                    else _wrap(tracer, wrapped_method)
+                ),
+            )
+
+    def _uninstrument(self, **kwargs):
+        for wrapped_method in WRAPPED_METHODS:
+            unwrap(
+                f"{wrapped_method.get('package')}.{wrapped_method.get('object')}",
+                wrapped_method.get("method"),
+            )
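Note: the instrumentor wraps the four Models/AsyncModels methods listed in WRAPPED_METHODS and emits one CLIENT span per call; for streaming calls the span ends only after the generator is exhausted. A standalone usage sketch via the standard BaseInstrumentor entry point (the exporter setup is illustrative):

    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
    from lmnr.openllmetry_sdk.opentelemetry.instrumentation.google_genai import (
        GoogleGenAiSdkInstrumentor,
    )

    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
    GoogleGenAiSdkInstrumentor().instrument(tracer_provider=provider)
    # google-genai calls now produce "gemini.generate_content" /
    # "gemini.generate_content_stream" spans with gen_ai.* attributes.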
lmnr/openllmetry_sdk/opentelemetry/instrumentation/google_genai/config.py ADDED
@@ -0,0 +1,9 @@
+from typing import Callable, Coroutine, Optional
+
+
+class Config:
+    exception_logger = None
+    upload_base64_image: Optional[
+        Callable[[str, str, str, str], Coroutine[None, None, str]]
+    ] = None
+    convert_image_to_openai_format: bool = True
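Note: upload_base64_image is declared as an async callable taking four strings and returning a string, but this diff does not name the parameters. A hypothetical hook matching the declared type, set the same way GoogleGenAiSdkInstrumentor.__init__ (above) writes these Config fields:

    from lmnr.openllmetry_sdk.opentelemetry.instrumentation.google_genai.config import (
        Config,
    )

    # Matches Callable[[str, str, str, str], Coroutine[None, None, str]];
    # the meanings of the four string arguments are not specified in this diff.
    async def upload_image(s1: str, s2: str, s3: str, s4: str) -> str:
        return "https://example.invalid/uploaded"  # placeholder URL

    Config.upload_base64_image = upload_image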