aitracer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,350 @@
+ """Google Gemini client wrapper for automatic logging."""
+
+ from __future__ import annotations
+
+ import time
+ import uuid
+ from functools import wraps
+ from typing import TYPE_CHECKING, Any, Iterator
+
+ if TYPE_CHECKING:
+     from google.generativeai import GenerativeModel
+     from google.generativeai.types import GenerateContentResponse
+     from aitracer.client import AITracer
+
+
+ def wrap_gemini_model(model: "GenerativeModel", tracer: "AITracer") -> "GenerativeModel":
+     """
+     Wrap a Google Gemini GenerativeModel to automatically log all API calls.
+
+     Args:
+         model: GenerativeModel instance.
+         tracer: AITracer instance.
+
+     Returns:
+         Wrapped GenerativeModel (same instance, modified in place).
+
+     Example:
+         >>> import google.generativeai as genai
+         >>> from aitracer import AITracer
+         >>> from aitracer.wrappers import wrap_gemini_model
+         >>>
+         >>> genai.configure(api_key="your-api-key")
+         >>> model = genai.GenerativeModel("gemini-1.5-flash")
+         >>> tracer = AITracer(api_key="your-aitracer-key")
+         >>> model = wrap_gemini_model(model, tracer)
+         >>>
+         >>> response = model.generate_content("Hello!")
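+         >>>
+         >>> # Streaming calls are traced too; a sketch (the call is logged
+         >>> # once the returned iterator is exhausted):
+         >>> for chunk in model.generate_content("Hello!", stream=True):
+         ...     print(chunk.text, end="")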
+     """
+     model_name = getattr(model, "model_name", "gemini-unknown")
+
+     # Wrap generate_content
+     original_generate = model.generate_content
+
+     @wraps(original_generate)
+     def wrapped_generate_content(
+         contents: Any,
+         *args: Any,
+         stream: bool = False,
+         **kwargs: Any,
+     ) -> Any:
+         """Wrapped generate_content method."""
+         start_time = time.time()
+         span_id = str(uuid.uuid4())
+
+         try:
+             response = original_generate(contents, *args, stream=stream, **kwargs)
+
+             if stream:
+                 return _wrap_stream_response(
+                     response=response,
+                     tracer=tracer,
+                     model=model_name,
+                     contents=contents,
+                     start_time=start_time,
+                     span_id=span_id,
+                 )
+             else:
+                 latency_ms = int((time.time() - start_time) * 1000)
+                 _log_generation(
+                     tracer=tracer,
+                     model=model_name,
+                     contents=contents,
+                     response=response,
+                     latency_ms=latency_ms,
+                     span_id=span_id,
+                 )
+                 return response
+
+         except Exception as e:
+             latency_ms = int((time.time() - start_time) * 1000)
+             tracer.log(
+                 model=model_name,
+                 provider="google",
+                 input_data={"contents": _serialize_contents(contents)},
+                 output_data=None,
+                 latency_ms=latency_ms,
+                 status="error",
+                 error_message=str(e),
+                 span_id=span_id,
+             )
+             raise
+
+     model.generate_content = wrapped_generate_content  # type: ignore
+
+     # Wrap generate_content_async if available
+     if hasattr(model, "generate_content_async"):
+         original_generate_async = model.generate_content_async
+
+         @wraps(original_generate_async)
+         async def wrapped_generate_content_async(
+             contents: Any,
+             *args: Any,
+             stream: bool = False,
+             **kwargs: Any,
+         ) -> Any:
+             """Wrapped async generate_content method."""
+             start_time = time.time()
+             span_id = str(uuid.uuid4())
+
+             try:
+                 response = await original_generate_async(
+                     contents, *args, stream=stream, **kwargs
+                 )
+
+                 if stream:
+                     return _wrap_async_stream_response(
+                         response=response,
+                         tracer=tracer,
+                         model=model_name,
+                         contents=contents,
+                         start_time=start_time,
+                         span_id=span_id,
+                     )
+                 else:
+                     latency_ms = int((time.time() - start_time) * 1000)
+                     _log_generation(
+                         tracer=tracer,
+                         model=model_name,
+                         contents=contents,
+                         response=response,
+                         latency_ms=latency_ms,
+                         span_id=span_id,
+                     )
+                     return response
+
+             except Exception as e:
+                 latency_ms = int((time.time() - start_time) * 1000)
+                 tracer.log(
+                     model=model_name,
+                     provider="google",
+                     input_data={"contents": _serialize_contents(contents)},
+                     output_data=None,
+                     latency_ms=latency_ms,
+                     status="error",
+                     error_message=str(e),
+                     span_id=span_id,
+                 )
+                 raise
+
+         model.generate_content_async = wrapped_generate_content_async  # type: ignore
+
+     return model
+
+
+ def _wrap_stream_response(
+     response: Iterator["GenerateContentResponse"],
+     tracer: "AITracer",
+     model: str,
+     contents: Any,
+     start_time: float,
+     span_id: str,
+ ) -> Iterator["GenerateContentResponse"]:
+     """Wrap streaming response to log after completion."""
+     content_parts: list[str] = []
+     input_tokens = 0
+     output_tokens = 0
167
+ last_response = None
168
+
169
+ try:
170
+ for chunk in response:
171
+ last_response = chunk
172
+
173
+ # Accumulate content
174
+ if chunk.text:
175
+ content_parts.append(chunk.text)
176
+
177
+ # Get usage if available
178
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
179
+ input_tokens = getattr(chunk.usage_metadata, "prompt_token_count", 0) or 0
180
+ output_tokens = getattr(chunk.usage_metadata, "candidates_token_count", 0) or 0
181
+
182
+ yield chunk
183
+
184
+ # Log after stream completes
185
+ latency_ms = int((time.time() - start_time) * 1000)
186
+ full_content = "".join(content_parts)
187
+
188
+ tracer.log(
189
+ model=model,
190
+ provider="google",
191
+ input_data={"contents": _serialize_contents(contents)},
192
+ output_data={"content": full_content},
193
+ input_tokens=input_tokens,
194
+ output_tokens=output_tokens,
195
+ latency_ms=latency_ms,
196
+ status="success",
197
+ span_id=span_id,
198
+ )
199
+
200
+ except Exception as e:
201
+ latency_ms = int((time.time() - start_time) * 1000)
202
+ tracer.log(
203
+ model=model,
204
+ provider="google",
205
+ input_data={"contents": _serialize_contents(contents)},
206
+ output_data=None,
207
+ latency_ms=latency_ms,
208
+ status="error",
209
+ error_message=str(e),
210
+ span_id=span_id,
211
+ )
212
+ raise
213
+
214
+
215
+ async def _wrap_async_stream_response(
216
+ response: Any,
217
+ tracer: "AITracer",
218
+ model: str,
219
+ contents: Any,
220
+ start_time: float,
221
+ span_id: str,
222
+ ) -> Any:
223
+ """Wrap async streaming response to log after completion."""
224
+ content_parts: list[str] = []
225
+ input_tokens = 0
226
+ output_tokens = 0
227
+
228
+ try:
229
+ async for chunk in response:
230
+ # Accumulate content
231
+ if chunk.text:
232
+ content_parts.append(chunk.text)
233
+
234
+ # Get usage if available
235
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
236
+ input_tokens = getattr(chunk.usage_metadata, "prompt_token_count", 0) or 0
237
+ output_tokens = getattr(chunk.usage_metadata, "candidates_token_count", 0) or 0
238
+
239
+ yield chunk
240
+
241
+ # Log after stream completes
242
+ latency_ms = int((time.time() - start_time) * 1000)
243
+ full_content = "".join(content_parts)
244
+
245
+ tracer.log(
246
+ model=model,
247
+ provider="google",
248
+ input_data={"contents": _serialize_contents(contents)},
249
+ output_data={"content": full_content},
250
+ input_tokens=input_tokens,
251
+ output_tokens=output_tokens,
252
+ latency_ms=latency_ms,
253
+ status="success",
254
+ span_id=span_id,
255
+ )
256
+
257
+ except Exception as e:
258
+ latency_ms = int((time.time() - start_time) * 1000)
259
+ tracer.log(
260
+ model=model,
261
+ provider="google",
262
+ input_data={"contents": _serialize_contents(contents)},
263
+ output_data=None,
264
+ latency_ms=latency_ms,
265
+ status="error",
266
+ error_message=str(e),
267
+ span_id=span_id,
268
+ )
269
+ raise
270
+
271
+
272
+ def _log_generation(
273
+ tracer: "AITracer",
274
+ model: str,
275
+ contents: Any,
276
+ response: "GenerateContentResponse",
277
+ latency_ms: int,
278
+ span_id: str,
279
+ ) -> None:
280
+ """Log a non-streaming generation."""
281
+ # Extract response data
282
+ output_content = None
283
+ try:
284
+ output_content = response.text
285
+ except Exception:
286
+ # Response may not have text (e.g., safety blocked)
287
+ if response.candidates:
288
+ parts = []
289
+ for candidate in response.candidates:
290
+ if candidate.content and candidate.content.parts:
291
+ for part in candidate.content.parts:
292
+ if hasattr(part, "text"):
293
+ parts.append(part.text)
294
+ output_content = "".join(parts) if parts else None
295
+
296
+ # Extract token usage
297
+ input_tokens = 0
298
+ output_tokens = 0
299
+ if hasattr(response, "usage_metadata") and response.usage_metadata:
300
+ input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0
301
+ output_tokens = getattr(response.usage_metadata, "candidates_token_count", 0) or 0
302
+
303
+ tracer.log(
304
+ model=model,
305
+ provider="google",
306
+ input_data={"contents": _serialize_contents(contents)},
307
+ output_data={"content": output_content},
308
+ input_tokens=input_tokens,
309
+ output_tokens=output_tokens,
310
+ latency_ms=latency_ms,
311
+ status="success",
312
+ span_id=span_id,
313
+ )
314
+
315
+
316
+ def _serialize_contents(contents: Any) -> Any:
317
+ """Serialize contents to JSON-compatible format."""
318
+ if contents is None:
319
+ return None
320
+
321
+ if isinstance(contents, str):
322
+ return contents
323
+
324
+ if isinstance(contents, list):
325
+ result = []
326
+ for item in contents:
327
+ result.append(_serialize_content_item(item))
328
+ return result
329
+
330
+ return _serialize_content_item(contents)
331
+
332
+
333
+ def _serialize_content_item(item: Any) -> Any:
334
+ """Serialize a single content item."""
335
+ if isinstance(item, str):
336
+ return item
337
+
338
+ if isinstance(item, dict):
339
+ return item
340
+
341
+ if hasattr(item, "to_dict"):
342
+ return item.to_dict()
343
+
344
+ if hasattr(item, "model_dump"):
345
+ return item.model_dump()
346
+
347
+ if hasattr(item, "__dict__"):
348
+ return {k: v for k, v in item.__dict__.items() if not k.startswith("_")}
349
+
350
+ return str(item)
@@ -0,0 +1,193 @@
+ """OpenAI client wrapper for automatic logging."""
+
+ from __future__ import annotations
+
+ import time
+ import uuid
+ from functools import wraps
+ from typing import TYPE_CHECKING, Any, Iterator
+
+ if TYPE_CHECKING:
+     from openai import OpenAI
+     from openai.types.chat import ChatCompletion, ChatCompletionChunk
+     from aitracer.client import AITracer
+
+
+ def wrap_openai_client(client: "OpenAI", tracer: "AITracer") -> "OpenAI":
+     """
+     Wrap an OpenAI client to automatically log all API calls.
+
+     Args:
+         client: OpenAI client instance.
+         tracer: AITracer instance.
+
+     Returns:
+         Wrapped OpenAI client (same instance, modified in place).
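+
+     Example (a sketch mirroring the Gemini wrapper's docstring; the model
+     name is illustrative):
+         >>> from openai import OpenAI
+         >>> from aitracer import AITracer
+         >>> from aitracer.wrappers import wrap_openai_client
+         >>>
+         >>> client = OpenAI(api_key="your-api-key")
+         >>> tracer = AITracer(api_key="your-aitracer-key")
+         >>> client = wrap_openai_client(client, tracer)
+         >>>
+         >>> response = client.chat.completions.create(
+         ...     model="gpt-4o-mini",
+         ...     messages=[{"role": "user", "content": "Hello!"}],
+         ... )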
+     """
+     # Store original method
+     original_create = client.chat.completions.create
+
+     @wraps(original_create)
+     def wrapped_create(*args: Any, **kwargs: Any) -> Any:
+         """Wrapped chat.completions.create method."""
+         start_time = time.time()
+         span_id = str(uuid.uuid4())
+
+         # Extract request data
+         model = kwargs.get("model", args[0] if args else "unknown")
+         messages = kwargs.get("messages", args[1] if len(args) > 1 else [])
+         stream = kwargs.get("stream", False)
+
+         try:
+             response = original_create(*args, **kwargs)
+
+             if stream:
+                 # Handle streaming response
+                 return _wrap_stream_response(
+                     response=response,
+                     tracer=tracer,
+                     model=model,
+                     messages=messages,
+                     start_time=start_time,
+                     span_id=span_id,
+                 )
+             else:
+                 # Handle non-streaming response
+                 latency_ms = int((time.time() - start_time) * 1000)
+                 _log_completion(
+                     tracer=tracer,
+                     model=model,
+                     messages=messages,
+                     response=response,
+                     latency_ms=latency_ms,
+                     span_id=span_id,
+                 )
+                 return response
+
+         except Exception as e:
+             # Log error
+             latency_ms = int((time.time() - start_time) * 1000)
+             tracer.log(
+                 model=model,
+                 provider="openai",
+                 input_data={"messages": _serialize_messages(messages)},
+                 output_data=None,
+                 latency_ms=latency_ms,
+                 status="error",
+                 error_message=str(e),
+                 span_id=span_id,
+             )
+             raise
+
+     # Replace method
+     client.chat.completions.create = wrapped_create  # type: ignore
+
+     return client
+
+
+ def _wrap_stream_response(
+     response: Iterator["ChatCompletionChunk"],
+     tracer: "AITracer",
+     model: str,
+     messages: list,
+     start_time: float,
+     span_id: str,
+ ) -> Iterator["ChatCompletionChunk"]:
+     """Wrap streaming response to log after completion."""
+     content_parts: list[str] = []
+     input_tokens = 0
+     output_tokens = 0
+
+     try:
+         for chunk in response:
+             # Accumulate content
+             if chunk.choices and chunk.choices[0].delta.content:
+                 content_parts.append(chunk.choices[0].delta.content)
+
+             # Get usage if available (the API reports it on a final chunk when
+             # the request sets stream_options={"include_usage": True})
+             if hasattr(chunk, "usage") and chunk.usage:
+                 input_tokens = chunk.usage.prompt_tokens or 0
+                 output_tokens = chunk.usage.completion_tokens or 0
+
+             yield chunk
+
+         # Log after stream completes
+         latency_ms = int((time.time() - start_time) * 1000)
+         full_content = "".join(content_parts)
+
+         tracer.log(
+             model=model,
+             provider="openai",
+             input_data={"messages": _serialize_messages(messages)},
+             output_data={"content": full_content},
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             latency_ms=latency_ms,
+             status="success",
+             span_id=span_id,
+         )
+
+     except Exception as e:
+         latency_ms = int((time.time() - start_time) * 1000)
+         tracer.log(
+             model=model,
+             provider="openai",
+             input_data={"messages": _serialize_messages(messages)},
+             output_data=None,
+             latency_ms=latency_ms,
+             status="error",
+             error_message=str(e),
+             span_id=span_id,
+         )
+         raise
+
+
+ def _log_completion(
+     tracer: "AITracer",
+     model: str,
+     messages: list,
+     response: "ChatCompletion",
+     latency_ms: int,
+     span_id: str,
+ ) -> None:
+     """Log a non-streaming completion."""
+     # Extract response data
+     output_content = None
+     if response.choices and response.choices[0].message:
+         output_content = response.choices[0].message.content
+
+     input_tokens = 0
+     output_tokens = 0
+     if response.usage:
+         input_tokens = response.usage.prompt_tokens
+         output_tokens = response.usage.completion_tokens
+
+     tracer.log(
+         model=model,
+         provider="openai",
+         input_data={"messages": _serialize_messages(messages)},
+         output_data={"content": output_content},
+         input_tokens=input_tokens,
+         output_tokens=output_tokens,
+         latency_ms=latency_ms,
+         status="success",
+         span_id=span_id,
+     )
+
+
+ def _serialize_messages(messages: list) -> list[dict]:
+     """Serialize messages to JSON-compatible format."""
+     result = []
+     for msg in messages:
+         if isinstance(msg, dict):
+             result.append(msg)
+         elif hasattr(msg, "model_dump"):
+             result.append(msg.model_dump())
+         elif hasattr(msg, "__dict__"):
+             result.append(msg.__dict__)
+         else:
+             result.append({"content": str(msg)})
+     return result
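
Both wrappers depend only on the tracer exposing a single log method. A minimal structural sketch of that interface, inferred from the call sites above (parameter names come from the calls; the keyword-only marker and the defaults are assumptions, not the published aitracer.client.AITracer signature):

    from typing import Optional, Protocol

    class TracerLike(Protocol):
        """Structural type the wrappers rely on (inferred, not authoritative)."""

        def log(
            self,
            *,
            model: str,
            provider: str,
            input_data: Optional[dict],
            output_data: Optional[dict],
            latency_ms: int,
            status: str,
            input_tokens: int = 0,
            output_tokens: int = 0,
            error_message: Optional[str] = None,
            span_id: Optional[str] = None,
        ) -> None: ...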