payi 0.1.0a40__py3-none-any.whl → 0.1.0a42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of payi might be problematic. Click here for more details.

payi/lib/instrument.py CHANGED
@@ -1,21 +1,28 @@
1
1
  import json
2
2
  import uuid
3
- import asyncio
4
3
  import inspect
5
4
  import logging
6
5
  import traceback
6
+ from enum import Enum
7
7
  from typing import Any, Set, Union, Callable, Optional
8
8
 
9
9
  from wrapt import ObjectProxy # type: ignore
10
10
 
11
11
  from payi import Payi, AsyncPayi
12
12
  from payi.types import IngestUnitsParams
13
+ from payi.types.ingest_response import IngestResponse
13
14
  from payi.types.ingest_units_params import Units
15
+ from payi.types.pay_i_common_models_api_router_header_info_param import PayICommonModelsAPIRouterHeaderInfoParam
14
16
 
15
17
  from .Stopwatch import Stopwatch
16
18
  from .Instruments import Instruments
17
19
 
18
20
 
21
+ class IsStreaming(Enum):
22
+ false = 0
23
+ true = 1
24
+ kwargs = 2
25
+
19
26
  class PayiInstrumentor:
20
27
  estimated_prompt_tokens: str = "estimated_prompt_tokens"
21
28
 
@@ -44,12 +51,15 @@ class PayiInstrumentor:
44
51
  def _instrument_all(self) -> None:
45
52
  self._instrument_openai()
46
53
  self._instrument_anthropic()
54
+ self._instrument_aws_bedrock()
47
55
 
48
56
  def _instrument_specific(self, instruments: Set[Instruments]) -> None:
49
57
  if Instruments.OPENAI in instruments:
50
58
  self._instrument_openai()
51
59
  if Instruments.ANTHROPIC in instruments:
52
60
  self._instrument_anthropic()
61
+ if Instruments.AWS_BEDROCK in instruments:
62
+ self._instrument_aws_bedrock()
53
63
 
54
64
  def _instrument_openai(self) -> None:
55
65
  from .OpenAIInstrumentor import OpenAiInstrumentor
@@ -69,79 +79,101 @@ class PayiInstrumentor:
69
79
  except Exception as e:
70
80
  logging.error(f"Error instrumenting Anthropic: {e}")
71
81
 
72
- def _ingest_units(self, ingest_units: IngestUnitsParams) -> None:
73
- # return early if there are no units to ingest and on a successful ingest request
82
+ def _instrument_aws_bedrock(self) -> None:
83
+ from .BedrockInstrumentor import BedrockInstrumentor
84
+
85
+ try:
86
+ BedrockInstrumentor.instrument(self)
87
+
88
+ except Exception as e:
89
+ logging.error(f"Error instrumenting AWS bedrock: {e}")
90
+
91
+ def _process_ingest_units(self, ingest_units: IngestUnitsParams, log_data: 'dict[str, str]') -> bool:
74
92
  if int(ingest_units.get("http_status_code") or 0) < 400:
75
93
  units = ingest_units.get("units", {})
76
94
  if not units or all(unit.get("input", 0) == 0 and unit.get("output", 0) == 0 for unit in units.values()):
77
95
  logging.error(
78
96
  'No units to ingest. For OpenAI streaming calls, make sure you pass stream_options={"include_usage": True}'
79
97
  )
80
- return
98
+ return False
99
+
100
+ if self._log_prompt_and_response and self._prompt_and_response_logger:
101
+ response_json = ingest_units.pop("provider_response_json", None)
102
+ request_json = ingest_units.pop("provider_request_json", None)
103
+ stack_trace = ingest_units.get("properties", {}).pop("system.stack_trace", None) # type: ignore
104
+
105
+ if response_json is not None:
106
+ # response_json is a list of strings, convert a single json string
107
+ log_data["provider_response_json"] = json.dumps(response_json)
108
+ if request_json is not None:
109
+ log_data["provider_request_json"] = request_json
110
+ if stack_trace is not None:
111
+ log_data["stack_trace"] = stack_trace
112
+
113
+ return True
114
+
115
+ def _process_ingest_units_response(self, ingest_response: IngestResponse) -> None:
116
+ if ingest_response.xproxy_result.limits:
117
+ for limit_id, state in ingest_response.xproxy_result.limits.items():
118
+ removeBlockedId: bool = False
119
+
120
+ if state.state == "blocked":
121
+ self._blocked_limits.add(limit_id)
122
+ elif state.state == "exceeded":
123
+ self._exceeded_limits.add(limit_id)
124
+ removeBlockedId = True
125
+ elif state.state == "ok":
126
+ removeBlockedId = True
127
+
128
+ # opportunistically remove blocked limits
129
+ if removeBlockedId:
130
+ self._blocked_limits.discard(limit_id)
131
+
132
+ async def _aingest_units(self, ingest_units: IngestUnitsParams) -> None:
133
+ # return early if there are no units to ingest and on a successful ingest request
134
+ log_data: 'dict[str,str]' = {}
135
+ if not self._process_ingest_units(ingest_units, log_data):
136
+ return
81
137
 
82
138
  try:
83
139
  if isinstance(self._payi, AsyncPayi):
84
- loop = asyncio.new_event_loop()
85
- asyncio.set_event_loop(loop)
86
- try:
87
- ingest_result = loop.run_until_complete(self._payi.ingest.units(**ingest_units))
88
- finally:
89
- loop.close()
90
- elif isinstance(self._payi, Payi):
91
- ingest_result = self._payi.ingest.units(**ingest_units)
140
+ ingest_response= await self._payi.ingest.units(**ingest_units)
141
+
142
+ self._process_ingest_units_response(ingest_response)
143
+
144
+ if self._log_prompt_and_response and self._prompt_and_response_logger:
145
+ request_id = ingest_response.xproxy_result.request_id
146
+ self._prompt_and_response_logger(request_id, log_data) # type: ignore
92
147
  else:
93
148
  logging.error("No payi instance to ingest units")
94
149
  return
150
+ except Exception as e:
151
+ logging.error(f"Error Pay-i ingesting result: {e}")
95
152
 
96
- if ingest_result.xproxy_result.limits:
97
- for limit_id, state in ingest_result.xproxy_result.limits.items():
98
- removeBlockedId: bool = False
99
-
100
- if state.state == "blocked":
101
- self._blocked_limits.add(limit_id)
102
- elif state.state == "exceeded":
103
- self._exceeded_limits.add(limit_id)
104
- removeBlockedId = True
105
- elif state.state == "ok":
106
- removeBlockedId = True
107
-
108
- # opportunistically remove blocked limits
109
- if removeBlockedId:
110
- self._blocked_limits.discard(limit_id)
111
-
112
- if self._log_prompt_and_response and self._prompt_and_response_logger:
113
- request_id = ingest_result.xproxy_result.request_id
114
-
115
- log_data = {}
116
- response_json = ingest_units.pop("provider_response_json", None)
117
- request_json = ingest_units.pop("provider_request_json", None)
118
- stack_trace = ingest_units.get("properties", {}).pop("system.stack_trace", None) # type: ignore
153
+ def _ingest_units(self, ingest_units: IngestUnitsParams) -> None:
154
+ # return early if there are no units to ingest and on a successful ingest request
155
+ log_data: 'dict[str,str]' = {}
156
+ if not self._process_ingest_units(ingest_units, log_data):
157
+ return
119
158
 
120
- if response_json is not None:
121
- # response_json is a list of strings, convert a single json string
122
- log_data["provider_response_json"] = json.dumps(response_json)
123
- if request_json is not None:
124
- log_data["provider_request_json"] = request_json
125
- if stack_trace is not None:
126
- log_data["stack_trace"] = stack_trace
159
+ try:
160
+ if isinstance(self._payi, Payi):
161
+ ingest_response = self._payi.ingest.units(**ingest_units)
127
162
 
128
- self._prompt_and_response_logger(request_id, log_data) # type: ignore
163
+ self._process_ingest_units_response(ingest_response)
129
164
 
165
+ if self._log_prompt_and_response and self._prompt_and_response_logger:
166
+ request_id = ingest_response.xproxy_result.request_id
167
+ self._prompt_and_response_logger(request_id, log_data) # type: ignore
168
+ else:
169
+ logging.error("No payi instance to ingest units")
170
+ return
130
171
  except Exception as e:
131
172
  logging.error(f"Error Pay-i ingesting result: {e}")
132
173
 
133
- def _call_func(
134
- self,
135
- func: Any,
136
- proxy: bool,
137
- limit_ids: Optional["list[str]"],
138
- request_tags: Optional["list[str]"],
139
- experience_name: Optional[str],
140
- experience_id: Optional[str],
141
- user_id: Optional[str],
142
- *args: Any,
143
- **kwargs: Any,
144
- ) -> Any:
174
+ def _setup_call_func(
175
+ self
176
+ ) -> 'tuple[dict[str, Any], Optional[str], Optional[str]]':
145
177
  if len(self._context_stack) > 0:
146
178
  # copy current context into the upcoming context
147
179
  context = self._context_stack[-1].copy()
@@ -152,36 +184,100 @@ class PayiInstrumentor:
152
184
  context = {}
153
185
  previous_experience_name = None
154
186
  previous_experience_id = None
187
+ return (context, previous_experience_name, previous_experience_id)
155
188
 
156
- with self:
157
- context["proxy"] = proxy
158
-
159
- # Handle experience name and ID logic
160
- if not experience_name:
161
- # If no experience_name specified, use previous values
162
- context["experience_name"] = previous_experience_name
163
- context["experience_id"] = previous_experience_id
189
+ def _init_context(
190
+ self,
191
+ context: "dict[str, Any]",
192
+ previous_experience_name: Optional[str],
193
+ previous_experience_id: Optional[str],
194
+ proxy: bool,
195
+ limit_ids: Optional["list[str]"],
196
+ request_tags: Optional["list[str]"],
197
+ experience_name: Optional[str],
198
+ experience_id: Optional[str],
199
+ user_id: Optional[str],
200
+ ) -> None:
201
+ context["proxy"] = proxy
202
+
203
+ # Handle experience name and ID logic
204
+ if not experience_name:
205
+ # If no experience_name specified, use previous values
206
+ context["experience_name"] = previous_experience_name
207
+ context["experience_id"] = previous_experience_id
208
+ else:
209
+ # If experience_name is specified
210
+ if experience_name == previous_experience_name:
211
+ # Same experience name, use previous ID unless new one specified
212
+ context["experience_name"] = experience_name
213
+ context["experience_id"] = experience_id if experience_id else previous_experience_id
164
214
  else:
165
- # If experience_name is specified
166
- if experience_name == previous_experience_name:
167
- # Same experience name, use previous ID unless new one specified
168
- context["experience_name"] = experience_name
169
- context["experience_id"] = experience_id if experience_id else previous_experience_id
170
- else:
171
- # Different experience name, use specified ID or generate one
172
- context["experience_name"] = experience_name
173
- context["experience_id"] = experience_id if experience_id else str(uuid.uuid4())
215
+ # Different experience name, use specified ID or generate one
216
+ context["experience_name"] = experience_name
217
+ context["experience_id"] = experience_id if experience_id else str(uuid.uuid4())
218
+
219
+ # set any values explicitly passed by the caller, otherwise use what is already in the context
220
+ if limit_ids:
221
+ context["limit_ids"] = limit_ids
222
+ if request_tags:
223
+ context["request_tags"] = request_tags
224
+ if user_id:
225
+ context["user_id"] = user_id
226
+
227
+ self.set_context(context)
228
+
229
+ async def _acall_func(
230
+ self,
231
+ func: Any,
232
+ proxy: bool,
233
+ limit_ids: Optional["list[str]"],
234
+ request_tags: Optional["list[str]"],
235
+ experience_name: Optional[str],
236
+ experience_id: Optional[str],
237
+ user_id: Optional[str],
238
+ *args: Any,
239
+ **kwargs: Any,
240
+ ) -> Any:
241
+ context, previous_experience_name, previous_experience_id = self._setup_call_func()
174
242
 
175
- # set any values explicitly passed by the caller, otherwise use what is already in the context
176
- if limit_ids:
177
- context["limit_ids"] = limit_ids
178
- if request_tags:
179
- context["request_tags"] = request_tags
180
- if user_id:
181
- context["user_id"] = user_id
243
+ with self:
244
+ self._init_context(
245
+ context,
246
+ previous_experience_name,
247
+ previous_experience_id,
248
+ proxy,
249
+ limit_ids,
250
+ request_tags,
251
+ experience_name,
252
+ experience_id,
253
+ user_id)
254
+ return await func(*args, **kwargs)
182
255
 
183
- self.set_context(context)
256
+ def _call_func(
257
+ self,
258
+ func: Any,
259
+ proxy: bool,
260
+ limit_ids: Optional["list[str]"],
261
+ request_tags: Optional["list[str]"],
262
+ experience_name: Optional[str],
263
+ experience_id: Optional[str],
264
+ user_id: Optional[str],
265
+ *args: Any,
266
+ **kwargs: Any,
267
+ ) -> Any:
268
+ context, previous_experience_name, previous_experience_id = self._setup_call_func()
184
269
 
270
+ with self:
271
+ self._init_context(
272
+ context,
273
+ previous_experience_name,
274
+ previous_experience_id,
275
+ proxy,
276
+ limit_ids,
277
+ request_tags,
278
+ experience_name,
279
+ experience_id,
280
+ user_id)
185
281
  return func(*args, **kwargs)
186
282
 
187
283
  def __enter__(self) -> Any:
@@ -203,22 +299,68 @@ class PayiInstrumentor:
203
299
  # Return the current top of the stack
204
300
  return self._context_stack[-1] if self._context_stack else None
205
301
 
206
- def chat_wrapper(
302
+
303
+ def _prepare_ingest(
304
+ self,
305
+ ingest: IngestUnitsParams,
306
+ ingest_extra_headers: "dict[str, str]", # do not conflict with potential kwargs["extra_headers"]
307
+ **kwargs: Any,
308
+ ) -> None:
309
+ limit_ids = ingest_extra_headers.pop("xProxy-Limit-IDs", None)
310
+ request_tags = ingest_extra_headers.pop("xProxy-Request-Tags", None)
311
+ experience_name = ingest_extra_headers.pop("xProxy-Experience-Name", None)
312
+ experience_id = ingest_extra_headers.pop("xProxy-Experience-ID", None)
313
+ user_id = ingest_extra_headers.pop("xProxy-User-ID", None)
314
+
315
+ if limit_ids:
316
+ ingest["limit_ids"] = limit_ids.split(",")
317
+ if request_tags:
318
+ ingest["request_tags"] = request_tags.split(",")
319
+ if experience_name:
320
+ ingest["experience_name"] = experience_name
321
+ if experience_id:
322
+ ingest["experience_id"] = experience_id
323
+ if user_id:
324
+ ingest["user_id"] = user_id
325
+
326
+ if len(ingest_extra_headers) > 0:
327
+ ingest["provider_request_headers"] = [PayICommonModelsAPIRouterHeaderInfoParam(name=k, value=v) for k, v in ingest_extra_headers.items()]
328
+
329
+ provider_prompt = {}
330
+ for k, v in kwargs.items():
331
+ if k == "messages":
332
+ provider_prompt[k] = [m.model_dump() if hasattr(m, "model_dump") else m for m in v]
333
+ elif k in ["extra_headers", "extra_query"]:
334
+ pass
335
+ else:
336
+ provider_prompt[k] = v
337
+
338
+ if self._log_prompt_and_response:
339
+ ingest["provider_request_json"] = json.dumps(provider_prompt)
340
+
341
+ async def achat_wrapper(
207
342
  self,
208
343
  category: str,
209
- process_chunk: Callable[[Any, IngestUnitsParams], None],
210
- process_request: Optional[Callable[[IngestUnitsParams, Any], None]],
211
- process_synchronous_response: Optional[Callable[[Any, IngestUnitsParams, bool], None]],
344
+ process_chunk: Optional[Callable[[Any, IngestUnitsParams], None]],
345
+ process_request: Optional[Callable[[IngestUnitsParams, Any, Any], None]],
346
+ process_synchronous_response: Any,
347
+ is_streaming: IsStreaming,
212
348
  wrapped: Any,
213
349
  instance: Any,
214
350
  args: Any,
215
- kwargs: 'dict[str, Any]',
351
+ kwargs: Any,
216
352
  ) -> Any:
217
353
  context = self.get_context()
218
354
 
355
+ is_bedrock:bool = category == "system.aws.bedrock"
356
+
219
357
  if not context:
220
- # should not happen
221
- return wrapped(*args, **kwargs)
358
+ if is_bedrock:
359
+ # boto3 doesn't allow extra_headers
360
+ kwargs.pop("extra_headers", None)
361
+
362
+ # wrapped function invoked outside of decorator scope
363
+ return await wrapped(*args, **kwargs)
222
364
 
223
365
  # after _update_headers, all metadata to add to ingest is in extra_headers, keyed by the xproxy-xxx header name
224
366
  extra_headers = kwargs.get("extra_headers", {})
@@ -228,13 +370,16 @@ class PayiInstrumentor:
228
370
  if "extra_headers" not in kwargs:
229
371
  kwargs["extra_headers"] = extra_headers
230
372
 
231
- return wrapped(*args, **kwargs)
373
+ return await wrapped(*args, **kwargs)
232
374
 
233
- ingest: IngestUnitsParams = {"category": category, "resource": kwargs.get("model"), "units": {}} # type: ignore
375
+ ingest: IngestUnitsParams = {"category": category, "units": {}} # type: ignore
376
+ if is_bedrock:
377
+ # boto3 doesn't allow extra_headers
378
+ kwargs.pop("extra_headers", None)
379
+ ingest["resource"] = kwargs.get("modelId", "")
380
+ else:
381
+ ingest["resource"] = kwargs.get("model", "")
234
382
 
235
- # blocked_limit = next((limit for limit in (context.get('limit_ids') or []) if limit in self._blocked_limits), None)
236
- # if blocked_limit:
237
- # raise Exception(f"Limit {blocked_limit} is blocked")
238
383
  current_frame = inspect.currentframe()
239
384
  # f_back excludes the current frame, strip() cleans up whitespace and newlines
240
385
  stack = [frame.strip() for frame in traceback.format_stack(current_frame.f_back)] # type: ignore
@@ -242,46 +387,135 @@ class PayiInstrumentor:
242
387
  ingest['properties'] = { 'system.stack_trace': json.dumps(stack) }
243
388
 
244
389
  if process_request:
245
- process_request(ingest, kwargs)
390
+ process_request(ingest, (), instance)
246
391
 
247
392
  sw = Stopwatch()
248
- stream = kwargs.get("stream", False)
393
+ stream: bool = False
394
+
395
+ if is_streaming == IsStreaming.kwargs:
396
+ stream = kwargs.get("stream", False)
397
+ elif is_streaming == IsStreaming.true:
398
+ stream = True
399
+ else:
400
+ stream = False
249
401
 
250
402
  try:
251
- limit_ids = extra_headers.pop("xProxy-Limit-IDs", None)
252
- request_tags = extra_headers.pop("xProxy-Request-Tags", None)
253
- experience_name = extra_headers.pop("xProxy-Experience-Name", None)
254
- experience_id = extra_headers.pop("xProxy-Experience-ID", None)
255
- user_id = extra_headers.pop("xProxy-User-ID", None)
256
-
257
- if limit_ids:
258
- ingest["limit_ids"] = limit_ids.split(",")
259
- if request_tags:
260
- ingest["request_tags"] = request_tags.split(",")
261
- if experience_name:
262
- ingest["experience_name"] = experience_name
263
- if experience_id:
264
- ingest["experience_id"] = experience_id
265
- if user_id:
266
- ingest["user_id"] = user_id
267
-
268
- if len(extra_headers) > 0:
269
- ingest["provider_request_headers"] = {k: [v] for k, v in extra_headers.items()} # type: ignore
270
-
271
- provider_prompt = {}
272
- for k, v in kwargs.items():
273
- if k == "messages":
274
- provider_prompt[k] = [m.model_dump() if hasattr(m, "model_dump") else m for m in v]
275
- elif k in ["extra_headers", "extra_query"]:
276
- pass
403
+ self._prepare_ingest(ingest, extra_headers, **kwargs)
404
+ sw.start()
405
+ response = await wrapped(*args, **kwargs)
406
+
407
+ except Exception as e: # pylint: disable=broad-except
408
+ sw.stop()
409
+ duration = sw.elapsed_ms_int()
410
+
411
+ # TODO ingest error
412
+
413
+ raise e
414
+
415
+ if stream:
416
+ stream_result = ChatStreamWrapper(
417
+ response=response,
418
+ instance=instance,
419
+ instrumentor=self,
420
+ log_prompt_and_response=self._log_prompt_and_response,
421
+ ingest=ingest,
422
+ stopwatch=sw,
423
+ process_chunk=process_chunk,
424
+ is_bedrock=is_bedrock,
425
+ )
426
+
427
+ if is_bedrock:
428
+ if "body" in response:
429
+ response["body"] = stream_result
277
430
  else:
278
- provider_prompt[k] = v
431
+ response["stream"] = stream_result
432
+ return response
433
+
434
+ return stream_result
435
+
436
+ sw.stop()
437
+ duration = sw.elapsed_ms_int()
438
+ ingest["end_to_end_latency_ms"] = duration
439
+ ingest["http_status_code"] = 200
440
+
441
+ if process_synchronous_response:
442
+ return_result: Any = process_synchronous_response(
443
+ response=response,
444
+ ingest=ingest,
445
+ log_prompt_and_response=self._log_prompt_and_response,
446
+ instrumentor=self)
447
+ if return_result:
448
+ return return_result
449
+
450
+ await self._aingest_units(ingest)
451
+
452
+ return response
453
+
454
+ def chat_wrapper(
455
+ self,
456
+ category: str,
457
+ process_chunk: Optional[Callable[[Any, IngestUnitsParams], None]],
458
+ process_request: Optional[Callable[[IngestUnitsParams, Any, Any], None]],
459
+ process_synchronous_response: Any,
460
+ is_streaming: IsStreaming,
461
+ wrapped: Any,
462
+ instance: Any,
463
+ args: Any,
464
+ kwargs: Any,
465
+ ) -> Any:
466
+ context = self.get_context()
467
+
468
+ is_bedrock:bool = category == "system.aws.bedrock"
469
+
470
+ if not context:
471
+ if is_bedrock:
472
+ # boto3 doesn't allow extra_headers
473
+ kwargs.pop("extra_headers", None)
474
+
475
+ # wrapped function invoked outside of decorator scope
476
+ return wrapped(*args, **kwargs)
477
+
478
+ # after _update_headers, all metadata to add to ingest is in extra_headers, keyed by the xproxy-xxx header name
479
+ extra_headers = kwargs.get("extra_headers", {})
480
+ self._update_headers(context, extra_headers)
481
+
482
+ if context.get("proxy", True):
483
+ if "extra_headers" not in kwargs:
484
+ kwargs["extra_headers"] = extra_headers
485
+
486
+ return wrapped(*args, **kwargs)
487
+
488
+ ingest: IngestUnitsParams = {"category": category, "units": {}} # type: ignore
489
+ if is_bedrock:
490
+ # boto3 doesn't allow extra_headers
491
+ kwargs.pop("extra_headers", None)
492
+ ingest["resource"] = kwargs.get("modelId", "")
493
+ else:
494
+ ingest["resource"] = kwargs.get("model", "")
495
+
496
+ current_frame = inspect.currentframe()
497
+ # f_back excludes the current frame, strip() cleans up whitespace and newlines
498
+ stack = [frame.strip() for frame in traceback.format_stack(current_frame.f_back)] # type: ignore
279
499
 
280
- if self._log_prompt_and_response:
281
- ingest["provider_request_json"] = json.dumps(provider_prompt)
500
+ ingest['properties'] = { 'system.stack_trace': json.dumps(stack) }
501
+
502
+ if process_request:
503
+ process_request(ingest, (), kwargs)
504
+
505
+ sw = Stopwatch()
506
+ stream: bool = False
507
+
508
+ if is_streaming == IsStreaming.kwargs:
509
+ stream = kwargs.get("stream", False)
510
+ elif is_streaming == IsStreaming.true:
511
+ stream = True
512
+ else:
513
+ stream = False
282
514
 
515
+ try:
516
+ self._prepare_ingest(ingest, extra_headers, **kwargs)
283
517
  sw.start()
284
- response = wrapped(*args, **kwargs.copy())
518
+ response = wrapped(*args, **kwargs)
285
519
 
286
520
  except Exception as e: # pylint: disable=broad-except
287
521
  sw.stop()
@@ -292,7 +526,7 @@ class PayiInstrumentor:
292
526
  raise e
293
527
 
294
528
  if stream:
295
- return ChatStreamWrapper(
529
+ stream_result = ChatStreamWrapper(
296
530
  response=response,
297
531
  instance=instance,
298
532
  instrumentor=self,
@@ -300,15 +534,31 @@ class PayiInstrumentor:
300
534
  ingest=ingest,
301
535
  stopwatch=sw,
302
536
  process_chunk=process_chunk,
537
+ is_bedrock=is_bedrock,
303
538
  )
304
539
 
540
+ if is_bedrock:
541
+ if "body" in response:
542
+ response["body"] = stream_result
543
+ else:
544
+ response["stream"] = stream_result
545
+ return response
546
+
547
+ return stream_result
548
+
305
549
  sw.stop()
306
550
  duration = sw.elapsed_ms_int()
307
551
  ingest["end_to_end_latency_ms"] = duration
308
552
  ingest["http_status_code"] = 200
309
553
 
310
554
  if process_synchronous_response:
311
- process_synchronous_response(response, ingest, self._log_prompt_and_response)
555
+ return_result: Any = process_synchronous_response(
556
+ response=response,
557
+ ingest=ingest,
558
+ log_prompt_and_response=self._log_prompt_and_response,
559
+ instrumentor=self)
560
+ if return_result:
561
+ return return_result
312
562
 
313
563
  self._ingest_units(ingest)
314
564
 
@@ -379,14 +629,29 @@ class PayiInstrumentor:
379
629
  o,
380
630
  wrapped,
381
631
  instance,
382
- args,
383
- kwargs,
632
+ *args,
633
+ **kwargs,
384
634
  )
385
635
 
386
636
  return wrapper
387
637
 
388
638
  return _payi_wrapper
389
639
 
640
+ @staticmethod
641
+ def payi_awrapper(func: Any) -> Any:
642
+ def _payi_awrapper(o: Any) -> Any:
643
+ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
644
+ return await func(
645
+ o,
646
+ wrapped,
647
+ instance,
648
+ *args,
649
+ **kwargs,
650
+ )
651
+
652
+ return wrapper
653
+
654
+ return _payi_awrapper
390
655
 
391
656
  class ChatStreamWrapper(ObjectProxy): # type: ignore
392
657
  def __init__(
@@ -398,7 +663,19 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
398
663
  stopwatch: Stopwatch,
399
664
  process_chunk: Optional[Callable[[Any, IngestUnitsParams], None]] = None,
400
665
  log_prompt_and_response: bool = True,
666
+ is_bedrock: bool = False,
401
667
  ) -> None:
668
+
669
+ bedrock_from_stream: bool = False
670
+ if is_bedrock:
671
+ stream = response.get("stream", None)
672
+ if stream:
673
+ response = stream
674
+ bedrock_from_stream = True
675
+ else:
676
+ response = response.get("body")
677
+ bedrock_from_stream = False
678
+
402
679
  super().__init__(response) # type: ignore
403
680
 
404
681
  self._response = response
@@ -413,6 +690,8 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
413
690
  self._process_chunk: Optional[Callable[[Any, IngestUnitsParams], None]] = process_chunk
414
691
 
415
692
  self._first_token: bool = True
693
+ self._is_bedrock: bool = is_bedrock
694
+ self._bedrock_from_stream: bool = bedrock_from_stream
416
695
 
417
696
  def __enter__(self) -> Any:
418
697
  return self
@@ -426,9 +705,26 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
426
705
  async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
427
706
  await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) # type: ignore
428
707
 
429
- def __iter__(self) -> Any:
708
+ def __iter__(self) -> Any:
709
+ if self._is_bedrock:
710
+ # MUST reside in a separate function so that the yield statement doesn't implicitly return its own iterator and override self
711
+ return self._iter_bedrock()
430
712
  return self
431
713
 
714
+ def _iter_bedrock(self) -> Any:
715
+ # botocore EventStream doesn't have a __next__ method so iterate over the wrapped object in place
716
+ for event in self.__wrapped__: # type: ignore
717
+ if (self._bedrock_from_stream):
718
+ self._evaluate_chunk(event)
719
+ else:
720
+ chunk = event.get('chunk') # type: ignore
721
+ if chunk:
722
+ decode = chunk.get('bytes').decode() # type: ignore
723
+ self._evaluate_chunk(decode)
724
+ yield event
725
+
726
+ self._stop_iteration()
727
+
432
728
  def __aiter__(self) -> Any:
433
729
  return self
434
730
 
@@ -448,7 +744,7 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
448
744
  chunk: Any = await self.__wrapped__.__anext__() # type: ignore
449
745
  except Exception as e:
450
746
  if isinstance(e, StopAsyncIteration):
451
- self._stop_iteration()
747
+ await self._astop_iteration()
452
748
  raise e
453
749
  else:
454
750
  self._evaluate_chunk(chunk)
@@ -460,12 +756,12 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
460
756
  self._first_token = False
461
757
 
462
758
  if self._log_prompt_and_response:
463
- self._responses.append(chunk.to_json())
759
+ self._responses.append(self.chunk_to_json(chunk))
464
760
 
465
761
  if self._process_chunk:
466
762
  self._process_chunk(chunk, self._ingest)
467
763
 
468
- def _stop_iteration(self) -> None:
764
+ def _process_stop_iteration(self) -> None:
469
765
  self._stopwatch.stop()
470
766
  self._ingest["end_to_end_latency_ms"] = self._stopwatch.elapsed_ms_int()
471
767
  self._ingest["http_status_code"] = 200
@@ -473,13 +769,29 @@ class ChatStreamWrapper(ObjectProxy): # type: ignore
473
769
  if self._log_prompt_and_response:
474
770
  self._ingest["provider_response_json"] = self._responses
475
771
 
772
+ async def _astop_iteration(self) -> None:
773
+ self._process_stop_iteration()
774
+ await self._instrumentor._aingest_units(self._ingest)
775
+
776
+ def _stop_iteration(self) -> None:
777
+ self._process_stop_iteration()
476
778
  self._instrumentor._ingest_units(self._ingest)
477
779
 
780
+ @staticmethod
781
+ def chunk_to_json(chunk: Any) -> str:
782
+ if hasattr(chunk, "to_json"):
783
+ return str(chunk.to_json())
784
+ elif isinstance(chunk, bytes):
785
+ return chunk.decode()
786
+ elif isinstance(chunk, str):
787
+ return chunk
788
+ else:
789
+ # assume dict
790
+ return json.dumps(chunk)
478
791
 
479
792
  global _instrumentor
480
793
  _instrumentor: PayiInstrumentor
481
794
 
482
-
483
795
  def payi_instrument(
484
796
  payi: Optional[Union[Payi, AsyncPayi]] = None,
485
797
  instruments: Optional[Set[Instruments]] = None,
@@ -494,7 +806,6 @@ def payi_instrument(
494
806
  prompt_and_response_logger=prompt_and_response_logger,
495
807
  )
496
808
 
497
-
498
809
  def ingest(
499
810
  limit_ids: Optional["list[str]"] = None,
500
811
  request_tags: Optional["list[str]"] = None,
@@ -503,24 +814,38 @@ def ingest(
503
814
  user_id: Optional[str] = None,
504
815
  ) -> Any:
505
816
  def _ingest(func: Any) -> Any:
506
- def _ingest_wrapper(*args: Any, **kwargs: Any) -> Any:
507
- return _instrumentor._call_func(
508
- func,
509
- False, # false -> ingest
510
- limit_ids,
511
- request_tags,
512
- experience_name,
513
- experience_id,
514
- user_id,
515
- *args,
516
- **kwargs,
517
- )
518
-
519
- return _ingest_wrapper
520
-
817
+ import asyncio
818
+ if asyncio.iscoroutinefunction(func):
819
+ async def awrapper(*args: Any, **kwargs: Any) -> Any:
820
+ # Call the instrumentor's _call_func for async functions
821
+ return await _instrumentor._acall_func(
822
+ func,
823
+ False,
824
+ limit_ids,
825
+ request_tags,
826
+ experience_name,
827
+ experience_id,
828
+ user_id,
829
+ *args,
830
+ *kwargs,
831
+ )
832
+ return awrapper
833
+ else:
834
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
835
+ return _instrumentor._call_func(
836
+ func,
837
+ False,
838
+ limit_ids,
839
+ request_tags,
840
+ experience_name,
841
+ experience_id,
842
+ user_id,
843
+ *args,
844
+ **kwargs,
845
+ )
846
+ return wrapper
521
847
  return _ingest
522
848
 
523
-
524
849
  def proxy(
525
850
  limit_ids: Optional["list[str]"] = None,
526
851
  request_tags: Optional["list[str]"] = None,
@@ -529,11 +854,36 @@ def proxy(
529
854
  user_id: Optional[str] = None,
530
855
  ) -> Any:
531
856
  def _proxy(func: Any) -> Any:
532
- def _proxy_wrapper(*args: Any, **kwargs: Any) -> Any:
533
- return _instrumentor._call_func(
534
- func, True, limit_ids, request_tags, experience_name, experience_id, user_id, *args, **kwargs
535
- )
857
+ import asyncio
858
+ if asyncio.iscoroutinefunction(func):
859
+ async def _proxy_awrapper(*args: Any, **kwargs: Any) -> Any:
860
+ return await _instrumentor._call_func(
861
+ func,
862
+ True,
863
+ limit_ids,
864
+ request_tags,
865
+ experience_name,
866
+ experience_id,
867
+ user_id,
868
+ *args,
869
+ **kwargs
870
+ )
871
+
872
+ return _proxy_awrapper
873
+ else:
874
+ def _proxy_wrapper(*args: Any, **kwargs: Any) -> Any:
875
+ return _instrumentor._call_func(
876
+ func,
877
+ True,
878
+ limit_ids,
879
+ request_tags,
880
+ experience_name,
881
+ experience_id,
882
+ user_id,
883
+ *args,
884
+ **kwargs
885
+ )
536
886
 
537
- return _proxy_wrapper
887
+ return _proxy_wrapper
538
888
 
539
889
  return _proxy