payi 0.1.0a110__py3-none-any.whl → 0.1.0a137__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. payi/__init__.py +3 -1
  2. payi/_base_client.py +12 -12
  3. payi/_client.py +8 -8
  4. payi/_compat.py +48 -48
  5. payi/_models.py +87 -59
  6. payi/_qs.py +7 -7
  7. payi/_streaming.py +4 -6
  8. payi/_types.py +53 -12
  9. payi/_utils/__init__.py +9 -2
  10. payi/_utils/_compat.py +45 -0
  11. payi/_utils/_datetime_parse.py +136 -0
  12. payi/_utils/_sync.py +3 -31
  13. payi/_utils/_transform.py +13 -3
  14. payi/_utils/_typing.py +6 -1
  15. payi/_utils/_utils.py +5 -6
  16. payi/_version.py +1 -1
  17. payi/lib/AnthropicInstrumentor.py +83 -57
  18. payi/lib/BedrockInstrumentor.py +292 -57
  19. payi/lib/GoogleGenAiInstrumentor.py +18 -31
  20. payi/lib/OpenAIInstrumentor.py +56 -72
  21. payi/lib/ProviderRequest.py +216 -0
  22. payi/lib/StreamWrappers.py +379 -0
  23. payi/lib/VertexInstrumentor.py +18 -37
  24. payi/lib/VertexRequest.py +16 -2
  25. payi/lib/data/cohere_embed_english_v3.json +30706 -0
  26. payi/lib/helpers.py +53 -1
  27. payi/lib/instrument.py +404 -668
  28. payi/resources/categories/__init__.py +0 -14
  29. payi/resources/categories/categories.py +25 -53
  30. payi/resources/categories/resources.py +27 -23
  31. payi/resources/ingest.py +126 -132
  32. payi/resources/limits/__init__.py +14 -14
  33. payi/resources/limits/limits.py +58 -58
  34. payi/resources/limits/properties.py +171 -0
  35. payi/resources/requests/request_id/properties.py +8 -8
  36. payi/resources/requests/request_id/result.py +3 -3
  37. payi/resources/requests/response_id/properties.py +8 -8
  38. payi/resources/requests/response_id/result.py +3 -3
  39. payi/resources/use_cases/definitions/definitions.py +27 -27
  40. payi/resources/use_cases/definitions/kpis.py +23 -23
  41. payi/resources/use_cases/definitions/limit_config.py +14 -14
  42. payi/resources/use_cases/definitions/version.py +3 -3
  43. payi/resources/use_cases/kpis.py +15 -15
  44. payi/resources/use_cases/properties.py +6 -6
  45. payi/resources/use_cases/use_cases.py +7 -7
  46. payi/types/__init__.py +2 -0
  47. payi/types/bulk_ingest_response.py +3 -20
  48. payi/types/categories/__init__.py +0 -1
  49. payi/types/categories/resource_list_params.py +5 -1
  50. payi/types/category_list_resources_params.py +5 -1
  51. payi/types/category_resource_response.py +31 -1
  52. payi/types/ingest_event_param.py +7 -6
  53. payi/types/ingest_units_params.py +5 -4
  54. payi/types/limit_create_params.py +3 -3
  55. payi/types/limit_list_response.py +1 -3
  56. payi/types/limit_response.py +1 -3
  57. payi/types/limits/__init__.py +2 -9
  58. payi/types/limits/{tag_remove_params.py → property_update_params.py} +4 -5
  59. payi/types/limits/{tag_delete_response.py → property_update_response.py} +3 -3
  60. payi/types/requests/request_id/property_update_params.py +2 -2
  61. payi/types/requests/response_id/property_update_params.py +2 -2
  62. payi/types/shared/__init__.py +2 -0
  63. payi/types/shared/api_error.py +18 -0
  64. payi/types/shared/pay_i_common_models_budget_management_create_limit_base.py +3 -3
  65. payi/types/shared/properties_request.py +11 -0
  66. payi/types/shared/xproxy_result.py +2 -0
  67. payi/types/shared_params/pay_i_common_models_budget_management_create_limit_base.py +3 -3
  68. payi/types/use_cases/definitions/limit_config_create_params.py +3 -3
  69. payi/types/use_cases/property_update_params.py +2 -2
  70. {payi-0.1.0a110.dist-info → payi-0.1.0a137.dist-info}/METADATA +6 -6
  71. {payi-0.1.0a110.dist-info → payi-0.1.0a137.dist-info}/RECORD +73 -75
  72. payi/resources/categories/fixed_cost_resources.py +0 -196
  73. payi/resources/limits/tags.py +0 -507
  74. payi/types/categories/fixed_cost_resource_create_params.py +0 -21
  75. payi/types/limits/limit_tags.py +0 -16
  76. payi/types/limits/tag_create_params.py +0 -13
  77. payi/types/limits/tag_create_response.py +0 -10
  78. payi/types/limits/tag_list_response.py +0 -10
  79. payi/types/limits/tag_remove_response.py +0 -10
  80. payi/types/limits/tag_update_params.py +0 -13
  81. payi/types/limits/tag_update_response.py +0 -10
  82. {payi-0.1.0a110.dist-info → payi-0.1.0a137.dist-info}/WHEEL +0 -0
  83. {payi-0.1.0a110.dist-info → payi-0.1.0a137.dist-info}/licenses/LICENSE +0 -0
payi/lib/instrument.py CHANGED
@@ -1,150 +1,57 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
4
+ import copy
2
5
  import json
3
6
  import time
4
7
  import uuid
8
+ import atexit
5
9
  import asyncio
6
10
  import inspect
7
11
  import logging
12
+ import threading
8
13
  import traceback
9
- from abc import abstractmethod
10
14
  from enum import Enum
11
- from typing import Any, Set, Union, Callable, Optional, Sequence, TypedDict
15
+ from typing import Any, Set, Union, Optional, Sequence, TypedDict, cast
12
16
  from datetime import datetime, timezone
13
- from dataclasses import dataclass
14
17
 
15
18
  import nest_asyncio # type: ignore
16
- from wrapt import ObjectProxy # type: ignore
19
+ from wrapt import wrap_function_wrapper # type: ignore
17
20
 
18
21
  from payi import Payi, AsyncPayi, APIStatusError, APIConnectionError, __version__ as _payi_version
19
22
  from payi.types import IngestUnitsParams
20
- from payi.lib.helpers import PayiHeaderNames
23
+ from payi.lib.helpers import PayiHeaderNames, _compact_json
21
24
  from payi.types.shared import XproxyResult
22
25
  from payi.types.ingest_response import IngestResponse
23
- from payi.types.ingest_units_params import Units, ProviderResponseFunctionCall
24
26
  from payi.types.shared.xproxy_error import XproxyError
25
27
  from payi.types.pay_i_common_models_api_router_header_info_param import PayICommonModelsAPIRouterHeaderInfoParam
26
28
 
27
29
  from .helpers import PayiCategories
28
30
  from .Stopwatch import Stopwatch
31
+ from .StreamWrappers import _GeneratorWrapper, _StreamManagerWrapper, _StreamIteratorWrapper
32
+ from .ProviderRequest import PriceAs, _StreamingType, _ProviderRequest
29
33
 
30
34
  global _g_logger
31
35
  _g_logger: logging.Logger = logging.getLogger("payi.instrument")
32
36
 
33
- @dataclass
34
- class _ChunkResult:
35
- send_chunk_to_caller: bool
36
- ingest: bool = False
37
-
38
- class _ProviderRequest:
39
- def __init__(
40
- self,
41
- instrumentor: '_PayiInstrumentor',
42
- category: str,
43
- streaming_type: '_StreamingType',
44
- module_name: str,
45
- module_version: str,
46
- is_aws_client: Optional[bool] = None,
47
- is_google_vertex_or_genai_client: Optional[bool] = None,
48
- ) -> None:
49
- self._instrumentor: '_PayiInstrumentor' = instrumentor
50
- self._module_name: str = module_name
51
- self._module_version: str = module_version
52
- self._estimated_prompt_tokens: Optional[int] = None
53
- self._category: str = category
54
- self._ingest: IngestUnitsParams = { "category": category, "units": {} } # type: ignore
55
- self._streaming_type: '_StreamingType' = streaming_type
56
- self._is_aws_client: Optional[bool] = is_aws_client
57
- self._is_google_vertex_or_genai_client: Optional[bool] = is_google_vertex_or_genai_client
58
- self._function_call_builder: Optional[dict[int, ProviderResponseFunctionCall]] = None
59
- self._building_function_response: bool = False
60
- self._function_calls: Optional[list[ProviderResponseFunctionCall]] = None
61
-
62
- def process_chunk(self, _chunk: Any) -> _ChunkResult:
63
- return _ChunkResult(send_chunk_to_caller=True)
64
-
65
- def process_synchronous_response(self, response: Any, log_prompt_and_response: bool, kwargs: Any) -> Optional[object]: # noqa: ARG002
66
- return None
67
-
68
- @abstractmethod
69
- def process_request(self, instance: Any, extra_headers: 'dict[str, str]', args: Sequence[Any], kwargs: Any) -> bool:
70
- ...
71
-
72
- def process_request_prompt(self, prompt: 'dict[str, Any]', args: Sequence[Any], kwargs: 'dict[str, Any]') -> None:
73
- ...
74
-
75
- def process_initial_stream_response(self, response: Any) -> None:
76
- pass
77
-
78
- def remove_inline_data(self, prompt: 'dict[str, Any]') -> bool:# noqa: ARG002
79
- return False
80
-
81
- @property
82
- def is_aws_client(self) -> bool:
83
- return self._is_aws_client if self._is_aws_client is not None else False
84
-
85
- @property
86
- def is_google_vertex_or_genai_client(self) -> bool:
87
- return self._is_google_vertex_or_genai_client if self._is_google_vertex_or_genai_client is not None else False
88
-
89
- def process_exception(self, exception: Exception, kwargs: Any, ) -> bool: # noqa: ARG002
90
- self.exception_to_semantic_failure(exception)
91
- return True
92
-
93
- @property
94
- def supports_extra_headers(self) -> bool:
95
- return not self.is_aws_client and not self.is_google_vertex_or_genai_client
96
-
97
- @property
98
- def streaming_type(self) -> '_StreamingType':
99
- return self._streaming_type
100
-
101
- def exception_to_semantic_failure(self, e: Exception) -> None:
102
- exception_str = f"{type(e).__name__}"
103
-
104
- fields: list[str] = []
105
-
106
- for attr in dir(e):
107
- if not attr.startswith("__"):
108
- try:
109
- value = getattr(e, attr)
110
- if value and not inspect.ismethod(value) and not inspect.isfunction(value) and not callable(value):
111
- fields.append(f"{attr}={value}")
112
- except Exception as _ex:
113
- pass
114
-
115
- existing_properties = self._ingest.get("properties", None)
116
- if not existing_properties:
117
- existing_properties = {}
118
-
119
- existing_properties['system.failure'] = exception_str
120
- if fields:
121
- failure_description = ",".join(fields)
122
- existing_properties["system.failure.description"] = failure_description[:128]
123
-
124
- self._ingest["properties"] = existing_properties
37
+ class PayiInstrumentModelMapping(TypedDict, total=False):
38
+ model: str
39
+ price_as_category: Optional[str]
40
+ price_as_resource: Optional[str]
41
+ # "global", "datazone", "region", "region.<region_name>"
42
+ resource_scope: Optional[str]
125
43
 
126
- if "http_status_code" not in self._ingest:
127
- # use a non existent http status code so when presented to the user, the origin is clear
128
- self._ingest["http_status_code"] = 299
44
+ class PayiInstrumentAwsBedrockConfig(TypedDict, total=False):
45
+ guardrail_trace: Optional[bool]
46
+ add_streaming_xproxy_result: Optional[bool]
47
+ model_mappings: Optional[Sequence[PayiInstrumentModelMapping]]
129
48
 
130
- def add_streaming_function_call(self, index: int, name: Optional[str], arguments: Optional[str]) -> None:
131
- if not self._function_call_builder:
132
- self._function_call_builder = {}
49
+ class PayiInstrumentAzureOpenAiConfig(TypedDict, total=False):
50
+ # map deployment name known model
51
+ model_mappings: Sequence[PayiInstrumentModelMapping]
133
52
 
134
- if not index in self._function_call_builder:
135
- self._function_call_builder[index] = ProviderResponseFunctionCall(name=name or "", arguments=arguments or "")
136
- else:
137
- function = self._function_call_builder[index]
138
- if name:
139
- function["name"] = function["name"] + name
140
- if arguments:
141
- function["arguments"] = (function.get("arguments", "") or "") + arguments
142
-
143
- def add_synchronous_function_call(self, name: str, arguments: Optional[str]) -> None:
144
- if not self._function_calls:
145
- self._function_calls = []
146
- self._ingest["provider_response_function_calls"] = self._function_calls
147
- self._function_calls.append(ProviderResponseFunctionCall(name=name, arguments=arguments))
53
+ class PayiInstrumentOfflineInstrumentationConfig(TypedDict, total=False):
54
+ file_name: str
148
55
 
149
56
  class PayiInstrumentConfig(TypedDict, total=False):
150
57
  proxy: bool
@@ -155,40 +62,55 @@ class PayiInstrumentConfig(TypedDict, total=False):
155
62
  use_case_name: Optional[str]
156
63
  use_case_id: Optional[str]
157
64
  use_case_version: Optional[int]
158
- use_case_properties: Optional["dict[str, str]"]
65
+ use_case_properties: Optional["dict[str, Optional[str]]"]
159
66
  user_id: Optional[str]
160
67
  account_name: Optional[str]
161
68
  request_tags: Optional["list[str]"]
162
- request_properties: Optional["dict[str, str]"]
69
+ request_properties: Optional["dict[str, Optional[str]]"]
70
+ aws_config: Optional[PayiInstrumentAwsBedrockConfig]
71
+ azure_openai_config: Optional[PayiInstrumentAzureOpenAiConfig]
72
+ offline_instrumentation: Optional[PayiInstrumentOfflineInstrumentationConfig]
163
73
 
164
74
  class PayiContext(TypedDict, total=False):
165
75
  use_case_name: Optional[str]
166
76
  use_case_id: Optional[str]
167
77
  use_case_version: Optional[int]
168
78
  use_case_step: Optional[str]
169
- use_case_properties: Optional["dict[str, str]"]
79
+ use_case_properties: Optional["dict[str, Optional[str]]"]
170
80
  limit_ids: Optional['list[str]']
171
81
  user_id: Optional[str]
172
82
  account_name: Optional[str]
173
83
  request_tags: Optional["list[str]"]
174
- request_properties: Optional["dict[str, str]"]
84
+ request_properties: Optional["dict[str, Optional[str]]"]
175
85
  price_as_category: Optional[str]
176
86
  price_as_resource: Optional[str]
177
87
  resource_scope: Optional[str]
178
88
  last_result: Optional[Union[XproxyResult, XproxyError]]
179
89
 
90
+ class PayiInstanceDefaultContext(TypedDict, total=False):
91
+ use_case_name: Optional[str]
92
+ use_case_id: Optional[str]
93
+ use_case_version: Optional[int]
94
+ use_case_properties: Optional["dict[str, str]"]
95
+ limit_ids: Optional['list[str]']
96
+ user_id: Optional[str]
97
+ account_name: Optional[str]
98
+ request_properties: Optional["dict[str, str]"]
99
+ price_as_category: Optional[str]
100
+ price_as_resource: Optional[str]
101
+ resource_scope: Optional[str]
102
+
180
103
  class _Context(TypedDict, total=False):
181
104
  proxy: Optional[bool]
182
105
  use_case_name: Optional[str]
183
106
  use_case_id: Optional[str]
184
107
  use_case_version: Optional[int]
185
108
  use_case_step: Optional[str]
186
- use_case_properties: Optional["dict[str, str]"]
109
+ use_case_properties: Optional["dict[str, Optional[str]]"]
187
110
  limit_ids: Optional['list[str]']
188
111
  user_id: Optional[str]
189
112
  account_name: Optional[str]
190
- request_tags: Optional["list[str]"]
191
- request_properties: Optional["dict[str, str]"]
113
+ request_properties: Optional["dict[str, Optional[str]]"]
192
114
  price_as_category: Optional[str]
193
115
  price_as_resource: Optional[str]
194
116
  resource_scope: Optional[str]
@@ -198,10 +120,14 @@ class _IsStreaming(Enum):
198
120
  true = 1
199
121
  kwargs = 2
200
122
 
201
- class _StreamingType(Enum):
202
- generator = 0
203
- iterator = 1
204
- stream_manager = 2
123
+ class _ThreadLocalContextStorage(threading.local):
124
+ """
125
+ Thread-local storage for context stacks. Each thread gets its own context stack.
126
+
127
+ Note: We don't use __init__ because threading.local's __init__ semantics are tricky.
128
+ Instead, we lazily initialize the context_stack attribute in the property accessor.
129
+ """
130
+ context_stack: "list[_Context]"
205
131
 
206
132
  class _InternalTrackContext:
207
133
  def __init__(
@@ -233,9 +159,6 @@ class _PayiInstrumentor:
233
159
  instruments: Union[Set[str], None] = None,
234
160
  log_prompt_and_response: bool = True,
235
161
  logger: Optional[logging.Logger] = None,
236
- prompt_and_response_logger: Optional[
237
- Callable[[str, "dict[str, str]"], None]
238
- ] = None, # (request id, dict of data to store) -> None
239
162
  global_config: PayiInstrumentConfig = {},
240
163
  caller_filename: str = ""
241
164
  ):
@@ -252,14 +175,16 @@ class _PayiInstrumentor:
252
175
  if self._apayi:
253
176
  _g_logger.debug(f"Pay-i instrumentor initialized with AsyncPayi instance: {self._apayi}")
254
177
 
255
- self._context_stack: list[_Context] = [] # Stack of context dictionaries
178
+ # Thread-local storage for context stacks - each thread gets its own stack
179
+ self._thread_local_storage = _ThreadLocalContextStorage()
180
+
256
181
  self._log_prompt_and_response: bool = log_prompt_and_response
257
- self._prompt_and_response_logger: Optional[Callable[[str, dict[str, str]], None]] = prompt_and_response_logger
258
182
 
259
183
  self._blocked_limits: set[str] = set()
260
184
  self._exceeded_limits: set[str] = set()
261
185
 
262
- self._api_connection_error_last_log_time: float = time.time()
186
+ # by not setting to time.time() the first connection error is always logged
187
+ self._api_connection_error_last_log_time: float = 0
263
188
  self._api_connection_error_count: int = 0
264
189
  self._api_connection_error_window: int = global_config.get("connection_error_logging_window", 60)
265
190
  if self._api_connection_error_window < 0:
@@ -272,21 +197,43 @@ class _PayiInstrumentor:
272
197
 
273
198
  self._last_result: Optional[Union[XproxyResult, XproxyError]] = None
274
199
 
200
+ self._offline_instrumentation = global_config.pop("offline_instrumentation", None)
201
+ self._offline_ingest_packets: list[IngestUnitsParams] = []
202
+ self._offline_instrumentation_file_name: Optional[str] = None
203
+
204
+ if self._offline_instrumentation is not None:
205
+ timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
206
+ self._offline_instrumentation_file_name = self._offline_instrumentation.get("file_name", f"payi_instrumentation_{timestamp}.json")
207
+
208
+ # Register exit handler to write packets when process exits
209
+ atexit.register(lambda: self._write_offline_ingest_packets())
210
+
275
211
  global_instrumentation = global_config.pop("global_instrumentation", True)
276
212
 
213
+ # configure first, then instrument
214
+ aws_config = global_config.get("aws_config", None)
215
+ if aws_config:
216
+ from .BedrockInstrumentor import BedrockInstrumentor
217
+ BedrockInstrumentor.configure(aws_config=aws_config)
218
+
219
+ azure_openai_config = global_config.get("azure_openai_config", None)
220
+ if azure_openai_config:
221
+ from .OpenAIInstrumentor import OpenAiInstrumentor
222
+ OpenAiInstrumentor.configure(azure_openai_config=azure_openai_config)
223
+
277
224
  if instruments is None or "*" in instruments:
278
225
  self._instrument_all()
279
226
  else:
280
- self._instrument_specific(instruments)
227
+ self._instrument_specific(instruments=instruments)
228
+
229
+ self._instrument_futures()
281
230
 
282
231
  if global_instrumentation:
283
232
  if "proxy" not in global_config:
284
233
  global_config["proxy"] = self._proxy_default
285
234
 
286
235
  # Use default clients if not provided for global ingest instrumentation
287
- if not self._payi and not self._apayi:
288
- self._payi = Payi()
289
- self._apayi = AsyncPayi()
236
+ self._ensure_payi_clients()
290
237
 
291
238
  if "use_case_name" not in global_config and caller_filename:
292
239
  description = f"Default use case for {caller_filename}.py"
@@ -295,23 +242,53 @@ class _PayiInstrumentor:
295
242
  self._payi.use_cases.definitions.create(name=caller_filename, description=description)
296
243
  elif self._apayi:
297
244
  self._call_async_use_case_definition_create(use_case_name=caller_filename, use_case_description=description)
245
+ else:
246
+ # in the case of _local_instrumentation is not None
247
+ pass
298
248
  global_config["use_case_name"] = caller_filename
299
249
  except Exception as e:
300
250
  self._logger.error(f"Error creating default use case definition based on file name {caller_filename}: {e}")
301
251
 
302
252
  self.__enter__()
303
253
 
304
- # _init_current_context will update the currrent context stack location
254
+ # _init_current_context will update the current context stack location
305
255
  context: _Context = {}
256
+
306
257
  # Copy allowed keys from global_config into context
307
258
  # Dynamically use keys from _Context TypedDict
308
259
  context_keys = list(_Context.__annotations__.keys()) if hasattr(_Context, '__annotations__') else []
309
260
  for key in context_keys:
310
261
  if key in global_config:
311
- context[key] = global_config[key] # type: ignore
262
+ context[key] = global_config[key] # type: ignore[literal-required]
263
+
264
+ self._init_current_context(**context)
265
+
266
+ def _ensure_payi_clients(self) -> None:
267
+ if self._offline_instrumentation is not None:
268
+ return
269
+
270
+ if not self._payi and not self._apayi:
271
+ self._payi = Payi()
272
+ self._apayi = AsyncPayi()
273
+
274
+ def _instrument_futures(self) -> None:
275
+ """Install hooks for all common concurrent execution patterns."""
276
+ def _thread_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
277
+ return self._thread_submit_wrapper(wrapped, instance, args, kwargs)
312
278
 
313
- self._init_current_context(**context)
279
+ async def _task_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
280
+ return await self._create_task_wrapper(wrapped, instance, args, kwargs)
314
281
 
282
+ try:
283
+ wrap_function_wrapper("concurrent.futures", "ThreadPoolExecutor.submit", _thread_wrapper)
284
+ except Exception as e:
285
+ self._logger.debug(f"Error wrapping ThreadPoolExecutor.submit: {e}")
286
+
287
+ try:
288
+ wrap_function_wrapper("asyncio", "create_task", _task_wrapper)
289
+ except Exception as e:
290
+ self._logger.debug(f"Error wrapping asyncio.create_task: {e}")
291
+
315
292
  def _instrument_all(self) -> None:
316
293
  self._instrument_openai()
317
294
  self._instrument_anthropic()
@@ -375,6 +352,93 @@ class _PayiInstrumentor:
375
352
  except Exception as e:
376
353
  self._logger.error(f"Error instrumenting Google GenAi: {e}")
377
354
 
355
+ def _thread_submit_wrapper(
356
+ self,
357
+ wrapped: Any,
358
+ _instance: Any,
359
+ args: Any,
360
+ kwargs: Any,
361
+ ) -> Any:
362
+ if len(args) > 0:
363
+ fn = args[0]
364
+ fn_args = args[1:]
365
+ captured_context = copy.deepcopy(self._context_safe)
366
+
367
+ def context_wrapper(*inner_args: Any, **inner_kwargs: Any) -> Any:
368
+ with self:
369
+ # self._context_stack[-1].update(captured_context)
370
+ self._init_current_context(**captured_context)
371
+ return fn(*inner_args, **inner_kwargs)
372
+
373
+ return wrapped(context_wrapper, *fn_args, **kwargs)
374
+ return wrapped(*args, **kwargs)
375
+
376
+ async def _create_task_wrapper(
377
+ self,
378
+ wrapped: Any,
379
+ _instance: Any,
380
+ args: Any,
381
+ kwargs: Any,
382
+ ) -> Any:
383
+ if len(args) > 0:
384
+ coro = args[0]
385
+ captured_context = copy.deepcopy(self._context_safe)
386
+
387
+ async def context_wrapper() -> Any:
388
+ with self:
389
+ # self._context_stack[-1].update(captured_context)
390
+ self._init_current_context(**captured_context)
391
+ return await coro
392
+
393
+ return wrapped(context_wrapper(), *args[1:], **kwargs)
394
+ return wrapped(*args, **kwargs)
395
+
396
+ @staticmethod
397
+ def _model_mapping_to_context_dict(model_mappings: Sequence[PayiInstrumentModelMapping]) -> 'dict[str, _Context]':
398
+ context: dict[str, _Context] = {}
399
+ for mapping in model_mappings:
400
+ model = mapping.get("model", "")
401
+ if not model:
402
+ continue
403
+
404
+ price_as_category = mapping.get("price_as_category", None)
405
+ price_as_resource = mapping.get("price_as_resource", None)
406
+ resource_scope = mapping.get("resource_scope", None)
407
+
408
+ if not price_as_category and not price_as_resource:
409
+ continue
410
+
411
+ context[model] = _Context(
412
+ price_as_category=price_as_category,
413
+ price_as_resource=price_as_resource,
414
+ resource_scope=resource_scope,
415
+ )
416
+ return context
417
+
418
+ def _write_offline_ingest_packets(self) -> None:
419
+ if not self._offline_instrumentation_file_name or not self._offline_ingest_packets:
420
+ return
421
+
422
+ try:
423
+ # Convert datetime objects to ISO strings for JSON serialization
424
+ serializable_packets: list[IngestUnitsParams] = []
425
+ for packet in self._offline_ingest_packets:
426
+ serializable_packet = packet.copy()
427
+
428
+ # Convert datetime fields to ISO format strings
429
+ if 'event_timestamp' in serializable_packet and isinstance(serializable_packet['event_timestamp'], datetime):
430
+ serializable_packet['event_timestamp'] = serializable_packet['event_timestamp'].isoformat()
431
+
432
+ serializable_packets.append(serializable_packet)
433
+
434
+ with open(self._offline_instrumentation_file_name, 'w', encoding='utf-8') as f:
435
+ json.dump(serializable_packets, f)
436
+
437
+ self._logger.debug(f"Written {len(self._offline_ingest_packets)} ingest packets to {self._offline_instrumentation_file_name}")
438
+
439
+ except Exception as e:
440
+ self._logger.error(f"Error writing offline ingest packets to {self._offline_instrumentation_file_name}: {e}")
441
+
378
442
  @staticmethod
379
443
  def _create_logged_ingest_units(
380
444
  ingest_units: IngestUnitsParams,
@@ -391,10 +455,10 @@ class _PayiInstrumentor:
391
455
 
392
456
  return log_ingest_units
393
457
 
394
- def _process_ingest_units(
395
- self,
396
- request: _ProviderRequest, log_data: 'dict[str, str]',
397
- extra_headers: 'dict[str, str]') -> None:
458
+ def _after_invoke_update_request(
459
+ self,
460
+ request: _ProviderRequest,
461
+ extra_headers: 'dict[str, str]') -> None:
398
462
  ingest_units = request._ingest
399
463
 
400
464
  if request._module_version:
@@ -404,9 +468,14 @@ class _PayiInstrumentor:
404
468
  # convert the function call builder to a list of function calls
405
469
  ingest_units["provider_response_function_calls"] = list(request._function_call_builder.values())
406
470
 
471
+ if "provider_response_id" not in ingest_units or not ingest_units["provider_response_id"]:
472
+ ingest_units["provider_response_id"] = f"payi_{uuid.uuid4()}"
473
+
407
474
  if 'resource' not in ingest_units or ingest_units['resource'] == '':
408
475
  ingest_units['resource'] = "system.unknown_model"
409
476
 
477
+ request.merge_internal_request_properties()
478
+
410
479
  request_json = ingest_units.get('provider_request_json', "")
411
480
  if request_json and self._instrument_inline_data is False:
412
481
  try:
@@ -414,7 +483,7 @@ class _PayiInstrumentor:
414
483
  if request.remove_inline_data(prompt_dict):
415
484
  self._logger.debug(f"Removed inline data from provider_request_json")
416
485
  # store the modified dict back as JSON string
417
- ingest_units['provider_request_json'] = json.dumps(prompt_dict)
486
+ ingest_units['provider_request_json'] = _compact_json(prompt_dict)
418
487
 
419
488
  except Exception as e:
420
489
  self._logger.error(f"Error serializing provider_request_json: {e}")
@@ -424,19 +493,6 @@ class _PayiInstrumentor:
424
493
  if not units or all(unit.get("input", 0) == 0 and unit.get("output", 0) == 0 for unit in units.values()):
425
494
  self._logger.info('ingesting with no token counts')
426
495
 
427
- if self._log_prompt_and_response and self._prompt_and_response_logger:
428
- response_json = ingest_units.pop("provider_response_json", None)
429
- request_json = ingest_units.pop("provider_request_json", None)
430
- stack_trace = ingest_units.get("properties", {}).pop("system.stack_trace", None) # type: ignore
431
-
432
- if response_json is not None:
433
- # response_json is a list of strings, convert a single json string
434
- log_data["provider_response_json"] = json.dumps(response_json)
435
- if request_json is not None:
436
- log_data["provider_request_json"] = request_json
437
- if stack_trace is not None:
438
- log_data["stack_trace"] = stack_trace
439
-
440
496
  def _process_ingest_units_response(self, ingest_response: IngestResponse) -> None:
441
497
  if ingest_response.xproxy_result.limits:
442
498
  for limit_id, state in ingest_response.xproxy_result.limits.items():
@@ -479,11 +535,8 @@ class _PayiInstrumentor:
479
535
 
480
536
  self._logger.debug(f"_aingest_units")
481
537
 
482
- # return early if there are no units to ingest and on a successul ingest request
483
- log_data: 'dict[str,str]' = {}
484
538
  extra_headers: 'dict[str, str]' = {}
485
-
486
- self._process_ingest_units(request, log_data=log_data, extra_headers=extra_headers)
539
+ self._after_invoke_update_request(request, extra_headers=extra_headers)
487
540
 
488
541
  try:
489
542
  if self._logger.isEnabledFor(logging.DEBUG):
@@ -493,6 +546,18 @@ class _PayiInstrumentor:
493
546
  ingest_response = await self._apayi.ingest.units(**ingest_units, extra_headers=extra_headers)
494
547
  elif self._payi:
495
548
  ingest_response = self._payi.ingest.units(**ingest_units, extra_headers=extra_headers)
549
+ elif self._offline_instrumentation is not None:
550
+ self._offline_ingest_packets.append(ingest_units.copy())
551
+
552
+ # simulate a successful ingest for local instrumentation
553
+ now=datetime.now(timezone.utc)
554
+ ingest_response = IngestResponse(
555
+ event_timestamp=now,
556
+ ingest_timestamp=now,
557
+ request_id="local_instrumentation",
558
+ xproxy_result=XproxyResult(request_id="local_instrumentation"))
559
+ pass
560
+
496
561
  else:
497
562
  self._logger.error("No payi instance to ingest units")
498
563
  return XproxyError(code="configuration_error", message="No Payi or AsyncPayi instance configured for ingesting units")
@@ -502,10 +567,6 @@ class _PayiInstrumentor:
502
567
  if ingest_response:
503
568
  self._process_ingest_units_response(ingest_response)
504
569
 
505
- if ingest_response and self._log_prompt_and_response and self._prompt_and_response_logger:
506
- request_id = ingest_response.xproxy_result.request_id
507
- self._prompt_and_response_logger(request_id, log_data) # type: ignore
508
-
509
570
  return ingest_response.xproxy_result
510
571
 
511
572
  except APIConnectionError as api_ex:
@@ -564,7 +625,7 @@ class _PayiInstrumentor:
564
625
  # Try to get the response body as JSON
565
626
  body = e.body
566
627
  if body is None:
567
- self._logger.error("APIStatusError response has no body attribute")
628
+ self._logger.warning(f"Pay-i ingest exception {e}, status {e.status_code} has no body")
568
629
  return XproxyError(code="unknown_error", message=str(e))
569
630
 
570
631
  # If body is bytes, decode to string
@@ -578,8 +639,9 @@ class _PayiInstrumentor:
578
639
  if not body_dict:
579
640
  try:
580
641
  body_dict = json.loads(body) # type: ignore
581
- except Exception as json_ex:
582
- self._logger.error(f"Failed to parse response body as JSON: {json_ex}")
642
+ except Exception:
643
+ body_type = type(body).__name__ # type: ignore
644
+ self._logger.warning(f"Pay-i ingest exception {e}, status {e.status_code} cannot parse response JSON body for body type {body_type}")
583
645
  return XproxyError(code="invalid_json", message=str(e))
584
646
 
585
647
  xproxy_error = body_dict.get("xproxy_error", {})
@@ -588,7 +650,7 @@ class _PayiInstrumentor:
588
650
  return XproxyError(code=code, message=message)
589
651
 
590
652
  except Exception as ex:
591
- self._logger.error(f"Exception in _process_api_status_error: {ex}")
653
+ self._logger.warning(f"Pay-i ingest exception {e}, status {e.status_code} processing handled exception {ex}")
592
654
  return XproxyError(code="exception", message=str(ex))
593
655
 
594
656
  def _ingest_units_worker(self, request: _ProviderRequest) -> Optional[Union[XproxyResult, XproxyError]]:
@@ -597,10 +659,8 @@ class _PayiInstrumentor:
597
659
 
598
660
  self._logger.debug(f"_ingest_units")
599
661
 
600
- # return early if there are no units to ingest and on a successul ingest request
601
- log_data: 'dict[str,str]' = {}
602
662
  extra_headers: 'dict[str, str]' = {}
603
- self._process_ingest_units(request, log_data=log_data, extra_headers=extra_headers)
663
+ self._after_invoke_update_request(request, extra_headers=extra_headers)
604
664
 
605
665
  try:
606
666
  if self._payi:
@@ -612,16 +672,18 @@ class _PayiInstrumentor:
612
672
 
613
673
  self._process_ingest_units_response(ingest_response)
614
674
 
615
- if self._log_prompt_and_response and self._prompt_and_response_logger:
616
- request_id = ingest_response.xproxy_result.request_id
617
- self._prompt_and_response_logger(request_id, log_data) # type: ignore
618
-
619
675
  return ingest_response.xproxy_result
620
676
  elif self._apayi:
621
677
  # task runs async. aingest_units will invoke the callback and post process
622
678
  sync_response = self._call_aingest_sync(request)
623
679
  self._logger.debug(f"_ingest_units: apayi success ({sync_response})")
624
680
  return sync_response
681
+ elif self._offline_instrumentation is not None:
682
+ self._offline_ingest_packets.append(ingest_units.copy())
683
+
684
+ # simulate a successful ingest for local instrumentation
685
+ return XproxyResult(request_id="local_instrumentation")
686
+
625
687
  else:
626
688
  self._logger.error("No payi instance to ingest units")
627
689
  return XproxyError(code="configuration_error", message="No Payi or AsyncPayi instance configured for ingesting units")
@@ -639,6 +701,21 @@ class _PayiInstrumentor:
639
701
  def _ingest_units(self, request: _ProviderRequest) -> Optional[Union[XproxyResult, XproxyError]]:
640
702
  return self.set_xproxy_result(self._ingest_units_worker(request))
641
703
 
704
+ @property
705
+ def _context_stack(self) -> "list[_Context]":
706
+ """
707
+ Get the thread-local context stack. On first access per thread,
708
+ initializes with the current state of the main thread's context stack.
709
+ """
710
+ # Lazy-initialize the context_stack for this thread if it doesn't exist
711
+ if not hasattr(self._thread_local_storage, 'context_stack'):
712
+ self._thread_local_storage.context_stack = []
713
+
714
+ stack = self._thread_local_storage.context_stack
715
+
716
+
717
+ return stack
718
+
642
719
  def _setup_call_func(
643
720
  self
644
721
  ) -> _Context:
@@ -649,6 +726,31 @@ class _PayiInstrumentor:
649
726
 
650
727
  return {}
651
728
 
729
+ @staticmethod
730
+ def _valid_str_or_none(value: Optional[str], default: Optional[str] = None) -> Optional[str]:
731
+ if value is None:
732
+ return default
733
+ elif len(value) == 0:
734
+ # an empty string explicitly blocks the default value
735
+ return None
736
+ else:
737
+ return value
738
+
739
+ @staticmethod
740
+ def _valid_properties_or_none(value: Optional["dict[str, Optional[str]]"], default: Optional["dict[str, Optional[str]]"] = None) -> Optional["dict[str, Optional[str]]"]:
741
+ if value is None:
742
+ return default.copy() if default else None
743
+ elif len(value) == 0:
744
+ # an empty dictionary explicitly blocks the default value
745
+ return None
746
+ elif default:
747
+ # merge dictionaries, child overrides parent keys
748
+ merged = default.copy()
749
+ merged.update(value)
750
+ return merged
751
+ else:
752
+ return value.copy()
753
+
652
754
  def _init_current_context(
653
755
  self,
654
756
  proxy: Optional[bool] = None,
@@ -659,16 +761,15 @@ class _PayiInstrumentor:
659
761
  use_case_step: Optional[str]= None,
660
762
  user_id: Optional[str]= None,
661
763
  account_name: Optional[str]= None,
662
- request_tags: Optional["list[str]"] = None,
663
- request_properties: Optional["dict[str, str]"] = None,
664
- use_case_properties: Optional["dict[str, str]"] = None,
764
+ request_properties: Optional["dict[str, Optional[str]]"] = None,
765
+ use_case_properties: Optional["dict[str, Optional[str]]"] = None,
665
766
  price_as_category: Optional[str] = None,
666
767
  price_as_resource: Optional[str] = None,
667
768
  resource_scope: Optional[str] = None,
668
769
  ) -> None:
669
770
 
670
771
  # there will always be a current context
671
- context: _Context = self.get_context() # type: ignore
772
+ context: _Context = self._context # type: ignore
672
773
  parent_context: _Context = self._context_stack[-2] if len(self._context_stack) > 1 else {}
673
774
 
674
775
  parent_proxy = parent_context.get("proxy", self._proxy_default)
@@ -709,26 +810,12 @@ class _PayiInstrumentor:
709
810
  assign_use_case_values = True
710
811
 
711
812
  if assign_use_case_values:
712
- context["use_case_id"] = use_case_id if use_case_id else parent_use_case_id
713
- context["use_case_version"] = use_case_version if use_case_version else parent_use_case_version
714
- context["use_case_step"] = use_case_step if use_case_step else parent_use_case_step
813
+ context["use_case_version"] = use_case_version if use_case_version is not None else parent_use_case_version
814
+ context["use_case_id"] = self._valid_str_or_none(use_case_id, parent_use_case_id)
815
+ context["use_case_step"] = self._valid_str_or_none(use_case_step, parent_use_case_step)
715
816
 
716
817
  parent_use_case_properties = parent_context.get("use_case_properties", None)
717
- if use_case_properties is not None:
718
- if not use_case_properties:
719
- # an empty dictionary explicitly blocks inheriting from the parent state
720
- context["use_case_properties"] = None
721
- else:
722
- if parent_use_case_properties:
723
- # merge dictionaries, child overrides parent keys
724
- merged = parent_use_case_properties.copy()
725
- merged.update(use_case_properties)
726
- context["use_case_properties"] = merged
727
- else:
728
- context["use_case_properties"] = use_case_properties.copy()
729
- elif parent_use_case_properties:
730
- # use the parent use_case_properties if it exists
731
- context["use_case_properties"] = parent_use_case_properties.copy()
818
+ context["use_case_properties"] = self._valid_properties_or_none(use_case_properties, parent_use_case_properties)
732
819
 
733
820
  parent_limit_ids = parent_context.get("limit_ids", None)
734
821
  if limit_ids is None:
@@ -742,56 +829,13 @@ class _PayiInstrumentor:
742
829
  context["limit_ids"] = list(set(limit_ids) | set(parent_limit_ids)) if parent_limit_ids else limit_ids.copy()
743
830
 
744
831
  parent_user_id = parent_context.get("user_id", None)
745
- if user_id is None:
746
- # use the parent user_id if it exists
747
- context["user_id"] = parent_user_id
748
- elif len(user_id) == 0:
749
- # caller passing an empty string explicitly blocks inheriting from the parent state
750
- context["user_id"] = None
751
- else:
752
- context["user_id"] = user_id
832
+ context["user_id"] = self._valid_str_or_none(user_id, parent_user_id)
753
833
 
754
834
  parent_account_name = parent_context.get("account_name", None)
755
- if account_name is None:
756
- # use the parent account_name if it exists
757
- context["account_name"] = parent_account_name
758
- elif len(account_name) == 0:
759
- # caller passing an empty string explicitly blocks inheriting from the parent state
760
- context["account_name"] = None
761
- else:
762
- context["account_name"] = account_name
763
-
764
- parent_request_tags = parent_context.get("request_tags", None)
765
- if request_tags is not None:
766
- if len(request_tags) == 0:
767
- # caller passing an empty list explicitly blocks inheriting from the parent state
768
- context["request_tags"] = None
769
- else:
770
- if parent_request_tags:
771
- # union of new and parent lists if the parent context contains request tags
772
- context["request_tags"] = list(set(request_tags) | set(parent_request_tags))
773
- else:
774
- context["request_tags"] = request_tags.copy()
775
- elif parent_request_tags:
776
- # use the parent request_tags if it exists
777
- context["request_tags"] = parent_request_tags.copy()
835
+ context["account_name"] = self._valid_str_or_none(account_name, parent_account_name)
778
836
 
779
837
  parent_request_properties = parent_context.get("request_properties", None)
780
- if request_properties is not None:
781
- if not request_properties:
782
- # an empty dictionary explicitly blocks inheriting from the parent state
783
- context["request_properties"] = None
784
- else:
785
- if parent_request_properties:
786
- # merge dictionaries, child overrides parent keys
787
- merged = parent_request_properties.copy()
788
- merged.update(request_properties)
789
- context["request_properties"] = merged
790
- else:
791
- context["request_properties"] = request_properties.copy()
792
- elif parent_request_properties:
793
- # use the parent request_properties if it exists
794
- context["request_properties"] = parent_request_properties.copy()
838
+ context["request_properties"] = self._valid_properties_or_none(request_properties, parent_request_properties)
795
839
 
796
840
  if price_as_category:
797
841
  context["price_as_category"] = price_as_category
@@ -810,9 +854,8 @@ class _PayiInstrumentor:
810
854
  use_case_version: Optional[int],
811
855
  user_id: Optional[str],
812
856
  account_name: Optional[str],
813
- request_tags: Optional["list[str]"] = None,
814
- request_properties: Optional["dict[str, str]"] = None,
815
- use_case_properties: Optional["dict[str, str]"] = None,
857
+ request_properties: Optional["dict[str, Optional[str]]"] = None,
858
+ use_case_properties: Optional["dict[str, Optional[str]]"] = None,
816
859
  *args: Any,
817
860
  **kwargs: Any,
818
861
  ) -> Any:
@@ -825,7 +868,6 @@ class _PayiInstrumentor:
825
868
  use_case_version=use_case_version,
826
869
  user_id=user_id,
827
870
  account_name=account_name,
828
- request_tags=request_tags,
829
871
  request_properties=request_properties,
830
872
  use_case_properties=use_case_properties
831
873
  )
@@ -841,9 +883,8 @@ class _PayiInstrumentor:
841
883
  use_case_version: Optional[int],
842
884
  user_id: Optional[str],
843
885
  account_name: Optional[str],
844
- request_tags: Optional["list[str]"] = None,
845
- request_properties: Optional["dict[str, str]"] = None,
846
- use_case_properties: Optional["dict[str, str]"] = None,
886
+ request_properties: Optional["dict[str, Optional[str]]"] = None,
887
+ use_case_properties: Optional["dict[str, Optional[str]]"] = None,
847
888
  *args: Any,
848
889
  **kwargs: Any,
849
890
  ) -> Any:
@@ -856,40 +897,51 @@ class _PayiInstrumentor:
856
897
  use_case_version=use_case_version,
857
898
  user_id=user_id,
858
899
  account_name=account_name,
859
- request_tags=request_tags,
860
900
  request_properties=request_properties,
861
901
  use_case_properties=use_case_properties)
862
902
  return func(*args, **kwargs)
863
903
 
864
904
  def __enter__(self) -> Any:
865
- # Push a new context dictionary onto the stack
905
+ # Push a new context dictionary onto the thread-local stack
866
906
  self._context_stack.append({})
867
907
  return self
868
908
 
869
909
  def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
870
- # Pop the current context off the stack
910
+ # Pop the current context off the thread-local stack
871
911
  if self._context_stack:
872
912
  self._context_stack.pop()
873
913
 
874
- def get_context(self) -> Optional[_Context]:
914
+ @property
915
+ def _context(self) -> Optional[_Context]:
875
916
  # Return the current top of the stack
876
917
  return self._context_stack[-1] if self._context_stack else None
877
918
 
878
- def get_context_safe(self) -> _Context:
919
+ @property
920
+ def _context_safe(self) -> _Context:
879
921
  # Return the current top of the stack
880
- return self.get_context() or {}
922
+ return self._context or {}
923
+
924
+ def _extract_price_as(self, extra_headers: "dict[str, str]") -> PriceAs:
925
+ context = self._context_safe
926
+
927
+ return PriceAs(
928
+ category=extra_headers.pop(PayiHeaderNames.price_as_category, None) or context.get("price_as_category", None),
929
+ resource=extra_headers.pop(PayiHeaderNames.price_as_resource, None) or context.get("price_as_resource", None),
930
+ resource_scope=extra_headers.pop(PayiHeaderNames.resource_scope, None) or context.get("resource_scope", None),
931
+ )
881
932
 
882
- def _prepare_ingest(
933
+ def _before_invoke_update_request(
883
934
  self,
884
935
  request: _ProviderRequest,
885
- context: _Context,
886
936
  ingest_extra_headers: "dict[str, str]", # do not conflict with potential kwargs["extra_headers"]
887
937
  args: Sequence[Any],
888
938
  kwargs: 'dict[str, Any]',
889
939
  ) -> None:
890
940
 
941
+ # pop and ignore the request tags header since it is no longer processed
942
+ ingest_extra_headers.pop(PayiHeaderNames.request_tags, None)
943
+
891
944
  limit_ids = ingest_extra_headers.pop(PayiHeaderNames.limit_ids, None)
892
- request_tags = ingest_extra_headers.pop(PayiHeaderNames.request_tags, None)
893
945
 
894
946
  use_case_name = ingest_extra_headers.pop(PayiHeaderNames.use_case_name, None)
895
947
  use_case_id = ingest_extra_headers.pop(PayiHeaderNames.use_case_id, None)
@@ -899,10 +951,11 @@ class _PayiInstrumentor:
899
951
  user_id = ingest_extra_headers.pop(PayiHeaderNames.user_id, None)
900
952
  account_name = ingest_extra_headers.pop(PayiHeaderNames.account_name, None)
901
953
 
954
+ request_properties = ingest_extra_headers.pop(PayiHeaderNames.request_properties, "")
955
+ use_case_properties = ingest_extra_headers.pop(PayiHeaderNames.use_case_properties, "")
956
+
902
957
  if limit_ids:
903
958
  request._ingest["limit_ids"] = limit_ids.split(",")
904
- if request_tags:
905
- request._ingest["request_tags"] = request_tags.split(",")
906
959
  if use_case_name:
907
960
  request._ingest["use_case_name"] = use_case_name
908
961
  if use_case_id:
@@ -915,14 +968,10 @@ class _PayiInstrumentor:
915
968
  request._ingest["user_id"] = user_id
916
969
  if account_name:
917
970
  request._ingest["account_name"] = account_name
918
-
919
- request_properties = context.get("request_properties", None)
920
971
  if request_properties:
921
- request._ingest["properties"] = request_properties
922
-
923
- use_case_properties = context.get("use_case_properties", None)
972
+ request._ingest["properties"] = json.loads(request_properties)
924
973
  if use_case_properties:
925
- request._ingest["use_case_properties"] = use_case_properties
974
+ request._ingest["use_case_properties"] = json.loads(use_case_properties)
926
975
 
927
976
  if len(ingest_extra_headers) > 0:
928
977
  request._ingest["provider_request_headers"] = [PayICommonModelsAPIRouterHeaderInfoParam(name=k, value=v) for k, v in ingest_extra_headers.items()]
@@ -946,10 +995,10 @@ class _PayiInstrumentor:
946
995
  request.process_request_prompt(provider_prompt, args, kwargs)
947
996
 
948
997
  if self._log_prompt_and_response:
949
- request._ingest["provider_request_json"] = json.dumps(provider_prompt)
998
+ request._ingest["provider_request_json"] = _compact_json(provider_prompt)
950
999
 
951
1000
  request._ingest["event_timestamp"] = datetime.now(timezone.utc)
952
-
1001
+
953
1002
  async def async_invoke_wrapper(
954
1003
  self,
955
1004
  request: _ProviderRequest,
@@ -961,7 +1010,7 @@ class _PayiInstrumentor:
961
1010
  ) -> Any:
962
1011
  self._logger.debug(f"async_invoke_wrapper: instance {instance}, category {request._category}")
963
1012
 
964
- context = self.get_context()
1013
+ context = self._context
965
1014
 
966
1015
  # Bedrock client does not have an async method
967
1016
 
@@ -973,23 +1022,29 @@ class _PayiInstrumentor:
973
1022
 
974
1023
  # after _udpate_headers, all metadata to add to ingest is in extra_headers, keyed by the xproxy-xxx header name
975
1024
  extra_headers: Optional[dict[str, str]] = kwargs.get("extra_headers")
976
- if extra_headers is None:
977
- extra_headers = {}
1025
+ extra_headers = (extra_headers or {}).copy()
978
1026
  self._update_extra_headers(context, extra_headers)
979
1027
 
980
1028
  if context.get("proxy", self._proxy_default):
981
- if "extra_headers" not in kwargs and extra_headers:
1029
+ if not request.supports_extra_headers:
1030
+ kwargs.pop("extra_headers", None)
1031
+ elif extra_headers:
1032
+ # Pass the copy to the wrapped function. Assumes anthropic and openai clients
982
1033
  kwargs["extra_headers"] = extra_headers
983
1034
 
984
1035
  self._logger.debug(f"async_invoke_wrapper: sending proxy request")
985
1036
 
986
1037
  return await wrapped(*args, **kwargs)
1038
+
1039
+ request._price_as = self._extract_price_as(extra_headers)
1040
+ if not request.supports_extra_headers and "extra_headers" in kwargs:
1041
+ kwargs.pop("extra_headers", None)
987
1042
 
988
1043
  current_frame = inspect.currentframe()
989
1044
  # f_back excludes the current frame, strip() cleans up whitespace and newlines
990
1045
  stack = [frame.strip() for frame in traceback.format_stack(current_frame.f_back)] # type: ignore
991
1046
 
992
- request._ingest['properties'] = { 'system.stack_trace': json.dumps(stack) }
1047
+ request._ingest['properties'] = { 'system.stack_trace': _compact_json(stack) }
993
1048
 
994
1049
  if request.process_request(instance, extra_headers, args, kwargs) is False:
995
1050
  self._logger.debug(f"async_invoke_wrapper: calling wrapped instance")
@@ -1006,9 +1061,13 @@ class _PayiInstrumentor:
1006
1061
  stream = False
1007
1062
 
1008
1063
  try:
1009
- self._prepare_ingest(request, context, extra_headers, args, kwargs)
1064
+ self._before_invoke_update_request(request, extra_headers, args, kwargs)
1010
1065
  self._logger.debug(f"async_invoke_wrapper: calling wrapped instance (stream={stream})")
1011
1066
 
1067
+ if "extra_headers" in kwargs:
1068
+ # replace the original extra_headers with the updated copy which has all of the Pay-i headers removed
1069
+ kwargs["extra_headers"] = extra_headers
1070
+
1012
1071
  sw.start()
1013
1072
  response = await wrapped(*args, **kwargs)
1014
1073
 
@@ -1055,6 +1114,8 @@ class _PayiInstrumentor:
1055
1114
  request._ingest["end_to_end_latency_ms"] = duration
1056
1115
  request._ingest["http_status_code"] = 200
1057
1116
 
1117
+ request.add_instrumented_response_headers(response)
1118
+
1058
1119
  return_result: Any = request.process_synchronous_response(
1059
1120
  response=response,
1060
1121
  log_prompt_and_response=self._log_prompt_and_response,
@@ -1064,7 +1125,8 @@ class _PayiInstrumentor:
1064
1125
  self._logger.debug(f"async_invoke_wrapper: process sync response return")
1065
1126
  return return_result
1066
1127
 
1067
- await self._aingest_units(request)
1128
+ xproxy_result = await self._aingest_units(request)
1129
+ request.assign_xproxy_result(response, xproxy_result)
1068
1130
 
1069
1131
  self._logger.debug(f"async_invoke_wrapper: finished")
1070
1132
  return response
@@ -1080,7 +1142,7 @@ class _PayiInstrumentor:
1080
1142
  ) -> Any:
1081
1143
  self._logger.debug(f"invoke_wrapper: instance {instance}, category {request._category}")
1082
1144
 
1083
- context = self.get_context()
1145
+ context = self._context
1084
1146
 
1085
1147
  if not context:
1086
1148
  if not request.supports_extra_headers:
@@ -1093,26 +1155,29 @@ class _PayiInstrumentor:
1093
1155
 
1094
1156
  # after _udpate_headers, all metadata to add to ingest is in extra_headers, keyed by the xproxy-xxx header name
1095
1157
  extra_headers: Optional[dict[str, str]] = kwargs.get("extra_headers")
1096
- if extra_headers is None:
1097
- extra_headers = {}
1158
+ extra_headers = (extra_headers or {}).copy()
1098
1159
  self._update_extra_headers(context, extra_headers)
1099
1160
 
1100
1161
  if context.get("proxy", self._proxy_default):
1101
1162
  if not request.supports_extra_headers:
1102
1163
  kwargs.pop("extra_headers", None)
1103
- elif "extra_headers" not in kwargs and extra_headers:
1104
- # assumes anthropic and openai clients
1164
+ elif extra_headers:
1165
+ # Pass the copy to the wrapped function. Assumes anthropic and openai clients
1105
1166
  kwargs["extra_headers"] = extra_headers
1106
1167
 
1107
1168
  self._logger.debug(f"invoke_wrapper: sending proxy request")
1108
1169
 
1109
1170
  return wrapped(*args, **kwargs)
1171
+
1172
+ request._price_as = self._extract_price_as(extra_headers)
1173
+ if not request.supports_extra_headers and "extra_headers" in kwargs:
1174
+ kwargs.pop("extra_headers", None)
1110
1175
 
1111
1176
  current_frame = inspect.currentframe()
1112
1177
  # f_back excludes the current frame, strip() cleans up whitespace and newlines
1113
1178
  stack = [frame.strip() for frame in traceback.format_stack(current_frame.f_back)] # type: ignore
1114
1179
 
1115
- request._ingest['properties'] = { 'system.stack_trace': json.dumps(stack) }
1180
+ request._ingest['properties'] = { 'system.stack_trace': _compact_json(stack) }
1116
1181
 
1117
1182
  if request.process_request(instance, extra_headers, args, kwargs) is False:
1118
1183
  self._logger.debug(f"invoke_wrapper: calling wrapped instance")
@@ -1129,9 +1194,13 @@ class _PayiInstrumentor:
1129
1194
  stream = False
1130
1195
 
1131
1196
  try:
1132
- self._prepare_ingest(request, context, extra_headers, args, kwargs)
1197
+ self._before_invoke_update_request(request, extra_headers, args, kwargs)
1133
1198
  self._logger.debug(f"invoke_wrapper: calling wrapped instance (stream={stream})")
1134
1199
 
1200
+ if "extra_headers" in kwargs:
1201
+ # replace the original extra_headers with the updated copy which has all of the Pay-i headers removed
1202
+ kwargs["extra_headers"] = extra_headers
1203
+
1135
1204
  sw.start()
1136
1205
  response = wrapped(*args, **kwargs)
1137
1206
 
@@ -1188,15 +1257,19 @@ class _PayiInstrumentor:
1188
1257
  request._ingest["end_to_end_latency_ms"] = duration
1189
1258
  request._ingest["http_status_code"] = 200
1190
1259
 
1260
+ request.add_instrumented_response_headers(response)
1261
+
1191
1262
  return_result: Any = request.process_synchronous_response(
1192
1263
  response=response,
1193
1264
  log_prompt_and_response=self._log_prompt_and_response,
1194
1265
  kwargs=kwargs)
1266
+
1195
1267
  if return_result:
1196
1268
  self._logger.debug(f"invoke_wrapper: process sync response return")
1197
1269
  return return_result
1198
1270
 
1199
- self._ingest_units(request)
1271
+ xproxy_result = self._ingest_units(request)
1272
+ request.assign_xproxy_result(response, xproxy_result)
1200
1273
 
1201
1274
  self._logger.debug(f"invoke_wrapper: finished")
1202
1275
  return response
@@ -1205,7 +1278,7 @@ class _PayiInstrumentor:
1205
1278
  self
1206
1279
  ) -> 'dict[str, str]':
1207
1280
  extra_headers: dict[str, str] = {}
1208
- context = self.get_context()
1281
+ context = self._context
1209
1282
  if context:
1210
1283
  self._update_extra_headers(context, extra_headers)
1211
1284
 
@@ -1229,19 +1302,43 @@ class _PayiInstrumentor:
1229
1302
 
1230
1303
  context_user_id: Optional[str] = context.get("user_id")
1231
1304
  context_account_name: Optional[str] = context.get("account_name")
1232
- context_request_tags: Optional[list[str]] = context.get("request_tags")
1233
1305
 
1234
1306
  context_price_as_category: Optional[str] = context.get("price_as_category")
1235
1307
  context_price_as_resource: Optional[str] = context.get("price_as_resource")
1236
1308
  context_resource_scope: Optional[str] = context.get("resource_scope")
1237
1309
 
1238
- # headers_limit_ids = extra_headers.get(PayiHeaderNames.limit_ids, None)
1239
-
1310
+ context_request_properties: Optional[dict[str, Optional[str]]] = context.get("request_properties")
1311
+ context_use_case_properties: Optional[dict[str, Optional[str]]] = context.get("use_case_properties")
1312
+
1313
+ if PayiHeaderNames.request_properties in extra_headers:
1314
+ headers_request_properties = extra_headers.get(PayiHeaderNames.request_properties, None)
1315
+
1316
+ if not headers_request_properties:
1317
+ # headers_request_properties is empty, remove it from extra_headers
1318
+ extra_headers.pop(PayiHeaderNames.request_properties, None)
1319
+ else:
1320
+ # leave the value in extra_headers
1321
+ ...
1322
+ elif context_request_properties:
1323
+ extra_headers[PayiHeaderNames.request_properties] = _compact_json(context_request_properties)
1324
+
1325
+ if PayiHeaderNames.use_case_properties in extra_headers:
1326
+ headers_use_case_properties = extra_headers.get(PayiHeaderNames.use_case_properties, None)
1327
+
1328
+ if not headers_use_case_properties:
1329
+ # headers_use_case_properties is empty, remove it from extra_headers
1330
+ extra_headers.pop(PayiHeaderNames.use_case_properties, None)
1331
+ else:
1332
+ # leave the value in extra_headers
1333
+ ...
1334
+ elif context_use_case_properties:
1335
+ extra_headers[PayiHeaderNames.use_case_properties] = _compact_json(context_use_case_properties)
1336
+
1240
1337
  # If the caller specifies limit_ids in extra_headers, it takes precedence over the decorator
1241
1338
  if PayiHeaderNames.limit_ids in extra_headers:
1242
1339
  headers_limit_ids = extra_headers.get(PayiHeaderNames.limit_ids)
1243
1340
 
1244
- if headers_limit_ids is None or len(headers_limit_ids) == 0:
1341
+ if not headers_limit_ids:
1245
1342
  # headers_limit_ids is empty, remove it from extra_headers
1246
1343
  extra_headers.pop(PayiHeaderNames.limit_ids, None)
1247
1344
  else:
@@ -1252,7 +1349,7 @@ class _PayiInstrumentor:
1252
1349
 
1253
1350
  if PayiHeaderNames.user_id in extra_headers:
1254
1351
  headers_user_id = extra_headers.get(PayiHeaderNames.user_id, None)
1255
- if headers_user_id is None or len(headers_user_id) == 0:
1352
+ if not headers_user_id:
1256
1353
  # headers_user_id is empty, remove it from extra_headers
1257
1354
  extra_headers.pop(PayiHeaderNames.user_id, None)
1258
1355
  else:
@@ -1263,7 +1360,7 @@ class _PayiInstrumentor:
1263
1360
 
1264
1361
  if PayiHeaderNames.account_name in extra_headers:
1265
1362
  headers_account_name = extra_headers.get(PayiHeaderNames.account_name, None)
1266
- if headers_account_name is None or len(headers_account_name) == 0:
1363
+ if not headers_account_name:
1267
1364
  # headers_account_name is empty, remove it from extra_headers
1268
1365
  extra_headers.pop(PayiHeaderNames.account_name, None)
1269
1366
  else:
@@ -1274,7 +1371,7 @@ class _PayiInstrumentor:
1274
1371
 
1275
1372
  if PayiHeaderNames.use_case_name in extra_headers:
1276
1373
  headers_use_case_name = extra_headers.get(PayiHeaderNames.use_case_name, None)
1277
- if headers_use_case_name is None or len(headers_use_case_name) == 0:
1374
+ if not headers_use_case_name:
1278
1375
  # headers_use_case_name is empty, remove all use case related headers
1279
1376
  extra_headers.pop(PayiHeaderNames.use_case_name, None)
1280
1377
  extra_headers.pop(PayiHeaderNames.use_case_id, None)
@@ -1290,10 +1387,7 @@ class _PayiInstrumentor:
1290
1387
  if context_use_case_version is not None:
1291
1388
  extra_headers[PayiHeaderNames.use_case_version] = str(context_use_case_version)
1292
1389
  if context_use_case_step is not None:
1293
- extra_headers[PayiHeaderNames.use_case_step] = str(context_use_case_step)
1294
-
1295
- if PayiHeaderNames.request_tags not in extra_headers and context_request_tags:
1296
- extra_headers[PayiHeaderNames.request_tags] = ",".join(context_request_tags)
1390
+ extra_headers[PayiHeaderNames.use_case_step] = context_use_case_step
1297
1391
 
1298
1392
  if PayiHeaderNames.price_as_category not in extra_headers and context_price_as_category:
1299
1393
  extra_headers[PayiHeaderNames.price_as_category] = context_price_as_category
@@ -1304,16 +1398,6 @@ class _PayiInstrumentor:
1304
1398
  if PayiHeaderNames.resource_scope not in extra_headers and context_resource_scope:
1305
1399
  extra_headers[PayiHeaderNames.resource_scope] = context_resource_scope
1306
1400
 
1307
- @staticmethod
1308
- def update_for_vision(input: int, units: 'dict[str, Units]', estimated_prompt_tokens: Optional[int]) -> int:
1309
- if estimated_prompt_tokens:
1310
- vision = input - estimated_prompt_tokens
1311
- if (vision > 0):
1312
- units["vision"] = Units(input=vision, output=0)
1313
- input = estimated_prompt_tokens
1314
-
1315
- return input
1316
-
1317
1401
  @staticmethod
1318
1402
  def payi_wrapper(func: Any) -> Any:
1319
1403
  def _payi_wrapper(o: Any) -> Any:
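
The removed update_for_vision helper above encoded a small accounting rule: when an estimated text-prompt token count is known, any input tokens beyond that estimate are attributed to a separate vision bucket. The same arithmetic as a standalone sketch, with plain dicts standing in for the Units model:

    from typing import Dict, Optional

    def update_for_vision(input_tokens: int, units: Dict[str, Dict[str, int]], estimated_prompt_tokens: Optional[int]) -> int:
        # Tokens beyond the estimated text prompt are billed as vision input;
        # the remainder is returned as the adjusted text input count.
        if estimated_prompt_tokens:
            vision = input_tokens - estimated_prompt_tokens
            if vision > 0:
                units["vision"] = {"input": vision, "output": 0}
                input_tokens = estimated_prompt_tokens
        return input_tokens
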
@@ -1346,351 +1430,6 @@ class _PayiInstrumentor:
1346
1430
 
1347
1431
  return _payi_awrapper
1348
1432
 
1349
- class _StreamIteratorWrapper(ObjectProxy): # type: ignore
1350
- def __init__(
1351
- self,
1352
- response: Any,
1353
- instance: Any,
1354
- instrumentor: _PayiInstrumentor,
1355
- stopwatch: Stopwatch,
1356
- request: _ProviderRequest,
1357
- ) -> None:
1358
-
1359
- instrumentor._logger.debug(f"StreamIteratorWrapper: instance {instance}, category {request._category}")
1360
-
1361
- request.process_initial_stream_response(response)
1362
-
1363
- bedrock_from_stream: bool = False
1364
- if request.is_aws_client:
1365
- stream = response.get("stream", None)
1366
-
1367
- if stream:
1368
- response = stream
1369
- bedrock_from_stream = True
1370
- else:
1371
- response = response.get("body")
1372
- bedrock_from_stream = False
1373
-
1374
- super().__init__(response) # type: ignore
1375
-
1376
- self._response = response
1377
- self._instance = instance
1378
-
1379
- self._instrumentor = instrumentor
1380
- self._stopwatch: Stopwatch = stopwatch
1381
- self._responses: list[str] = []
1382
-
1383
- self._request: _ProviderRequest = request
1384
-
1385
- self._first_token: bool = True
1386
- self._bedrock_from_stream: bool = bedrock_from_stream
1387
- self._ingested: bool = False
1388
- self._iter_started: bool = False
1389
-
1390
- def __enter__(self) -> Any:
1391
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __enter__")
1392
- return self
1393
-
1394
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
1395
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __exit__")
1396
- self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) # type: ignore
1397
-
1398
- async def __aenter__(self) -> Any:
1399
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __aenter__")
1400
- return self
1401
-
1402
- async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
1403
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __aexit__")
1404
- await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) # type: ignore
1405
-
1406
- def __iter__(self) -> Any:
1407
- self._iter_started = True
1408
- if self._request.is_aws_client:
1409
- # MUST reside in a separate function so that the yield statement (e.g. the generator) doesn't implicitly return its own iterator and overriding self
1410
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: bedrock __iter__")
1411
- return self._iter_bedrock()
1412
-
1413
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __iter__")
1414
- return self
1415
-
1416
- def _iter_bedrock(self) -> Any:
1417
- # botocore EventStream doesn't have a __next__ method so iterate over the wrapped object in place
1418
- for event in self.__wrapped__: # type: ignore
1419
- result: Optional[_ChunkResult] = None
1420
-
1421
- if (self._bedrock_from_stream):
1422
- result = self._evaluate_chunk(event)
1423
- else:
1424
- chunk = event.get('chunk') # type: ignore
1425
- if chunk:
1426
- decode = chunk.get('bytes').decode() # type: ignore
1427
- result = self._evaluate_chunk(decode)
1428
-
1429
- if result and result.ingest:
1430
- self._stop_iteration()
1431
-
1432
- yield event
1433
-
1434
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: bedrock iter finished")
1435
-
1436
- self._stop_iteration()
1437
-
1438
- def __aiter__(self) -> Any:
1439
- self._iter_started = True
1440
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __aiter__")
1441
- return self
1442
-
1443
- def __next__(self) -> object:
1444
- try:
1445
- chunk: object = self.__wrapped__.__next__() # type: ignore
1446
-
1447
- if self._ingested:
1448
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __next__ already ingested, not processing chunk {chunk}")
1449
- return chunk # type: ignore
1450
-
1451
- result = self._evaluate_chunk(chunk)
1452
-
1453
- if result.ingest:
1454
- self._stop_iteration()
1455
-
1456
- if result.send_chunk_to_caller:
1457
- return chunk # type: ignore
1458
- else:
1459
- return self.__next__()
1460
- except Exception as e:
1461
- if isinstance(e, StopIteration):
1462
- self._stop_iteration()
1463
- else:
1464
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __next__ exception {e}")
1465
- raise e
1466
-
1467
- async def __anext__(self) -> object:
1468
- try:
1469
- chunk: object = await self.__wrapped__.__anext__() # type: ignore
1470
-
1471
- if self._ingested:
1472
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __next__ already ingested, not processing chunk {chunk}")
1473
- return chunk # type: ignore
1474
-
1475
- result = self._evaluate_chunk(chunk)
1476
-
1477
- if result.ingest:
1478
- await self._astop_iteration()
1479
-
1480
- if result.send_chunk_to_caller:
1481
- return chunk # type: ignore
1482
- else:
1483
- return await self.__anext__()
1484
-
1485
- except Exception as e:
1486
- if isinstance(e, StopAsyncIteration):
1487
- await self._astop_iteration()
1488
- else:
1489
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: __anext__ exception {e}")
1490
- raise e
1491
-
1492
- def _evaluate_chunk(self, chunk: Any) -> _ChunkResult:
1493
- if self._first_token:
1494
- self._request._ingest["time_to_first_token_ms"] = self._stopwatch.elapsed_ms_int()
1495
- self._first_token = False
1496
-
1497
- if self._instrumentor._log_prompt_and_response:
1498
- self._responses.append(self.chunk_to_json(chunk))
1499
-
1500
- return self._request.process_chunk(chunk)
1501
-
1502
- def _process_stop_iteration(self) -> None:
1503
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: process stop iteration")
1504
-
1505
- self._stopwatch.stop()
1506
- self._request._ingest["end_to_end_latency_ms"] = self._stopwatch.elapsed_ms_int()
1507
- self._request._ingest["http_status_code"] = 200
1508
-
1509
- if self._instrumentor._log_prompt_and_response:
1510
- self._request._ingest["provider_response_json"] = self._responses
1511
-
1512
- async def _astop_iteration(self) -> None:
1513
- if self._ingested:
1514
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: astop iteration already ingested, skipping")
1515
- return
1516
-
1517
- self._process_stop_iteration()
1518
-
1519
- await self._instrumentor._aingest_units(self._request)
1520
- self._ingested = True
1521
-
1522
- def _stop_iteration(self) -> None:
1523
- if self._ingested:
1524
- self._instrumentor._logger.debug(f"StreamIteratorWrapper: stop iteration already ingested, skipping")
1525
- return
1526
-
1527
- self._process_stop_iteration()
1528
- self._instrumentor._ingest_units(self._request)
1529
- self._ingested = True
1530
-
1531
- @staticmethod
1532
- def chunk_to_json(chunk: Any) -> str:
1533
- if hasattr(chunk, "to_json"):
1534
- return str(chunk.to_json())
1535
- elif isinstance(chunk, bytes):
1536
- return chunk.decode()
1537
- elif isinstance(chunk, str):
1538
- return chunk
1539
- else:
1540
- # assume dict
1541
- return json.dumps(chunk)
1542
-
1543
- class _StreamManagerWrapper(ObjectProxy): # type: ignore
1544
- def __init__(
1545
- self,
1546
- stream_manager: Any, # type: ignore
1547
- instance: Any,
1548
- instrumentor: _PayiInstrumentor,
1549
- stopwatch: Stopwatch,
1550
- request: _ProviderRequest,
1551
- ) -> None:
1552
- instrumentor._logger.debug(f"StreamManagerWrapper: instance {instance}, category {request._category}")
1553
-
1554
- super().__init__(stream_manager) # type: ignore
1555
-
1556
- self._stream_manager = stream_manager
1557
- self._instance = instance
1558
- self._instrumentor = instrumentor
1559
- self._stopwatch: Stopwatch = stopwatch
1560
- self._responses: list[str] = []
1561
- self._request: _ProviderRequest = request
1562
- self._first_token: bool = True
1563
-
1564
- def __enter__(self) -> _StreamIteratorWrapper:
1565
- self._instrumentor._logger.debug(f"_StreamManagerWrapper: __enter__")
1566
-
1567
- return _StreamIteratorWrapper(
1568
- response=self.__wrapped__.__enter__(), # type: ignore
1569
- instance=self._instance,
1570
- instrumentor=self._instrumentor,
1571
- stopwatch=self._stopwatch,
1572
- request=self._request,
1573
- )
1574
-
1575
- class _GeneratorWrapper: # type: ignore
1576
- def __init__(
1577
- self,
1578
- generator: Any,
1579
- instance: Any,
1580
- instrumentor: _PayiInstrumentor,
1581
- stopwatch: Stopwatch,
1582
- request: _ProviderRequest,
1583
- ) -> None:
1584
- instrumentor._logger.debug(f"GeneratorWrapper: instance {instance}, category {request._category}")
1585
-
1586
- super().__init__() # type: ignore
1587
-
1588
- self._generator = generator
1589
- self._instance = instance
1590
- self._instrumentor = instrumentor
1591
- self._stopwatch: Stopwatch = stopwatch
1592
- self._log_prompt_and_response: bool = instrumentor._log_prompt_and_response
1593
- self._responses: list[str] = []
1594
- self._request: _ProviderRequest = request
1595
- self._first_token: bool = True
1596
- self._ingested: bool = False
1597
- self._iter_started: bool = False
1598
-
1599
- def __iter__(self) -> Any:
1600
- self._iter_started = True
1601
- self._instrumentor._logger.debug(f"GeneratorWrapper: __iter__")
1602
- return self
1603
-
1604
- def __aiter__(self) -> Any:
1605
- self._instrumentor._logger.debug(f"GeneratorWrapper: __aiter__")
1606
- return self
1607
-
1608
- def _process_chunk(self, chunk: Any) -> _ChunkResult:
1609
- if self._first_token:
1610
- self._request._ingest["time_to_first_token_ms"] = self._stopwatch.elapsed_ms_int()
1611
- self._first_token = False
1612
-
1613
- if self._log_prompt_and_response:
1614
- dict = self._chunk_to_dict(chunk)
1615
- self._responses.append(json.dumps(dict))
1616
-
1617
- return self._request.process_chunk(chunk)
1618
-
1619
- def __next__(self) -> Any:
1620
- try:
1621
- chunk = next(self._generator)
1622
- result = self._process_chunk(chunk)
1623
-
1624
- if result.ingest:
1625
- self._stop_iteration()
1626
-
1627
- # ignore result.send_chunk_to_caller:
1628
- return chunk
1629
-
1630
- except Exception as e:
1631
- if isinstance(e, StopIteration):
1632
- self._stop_iteration()
1633
- else:
1634
- self._instrumentor._logger.debug(f"GeneratorWrapper: __next__ exception {e}")
1635
- raise e
1636
-
1637
- async def __anext__(self) -> Any:
1638
- try:
1639
- chunk = await anext(self._generator) # type: ignore
1640
- result = self._process_chunk(chunk)
1641
-
1642
- if result.ingest:
1643
- await self._astop_iteration()
1644
-
1645
- # ignore result.send_chunk_to_caller:
1646
- return chunk # type: ignore
1647
-
1648
- except Exception as e:
1649
- if isinstance(e, StopAsyncIteration):
1650
- await self._astop_iteration()
1651
- else:
1652
- self._instrumentor._logger.debug(f"GeneratorWrapper: __anext__ exception {e}")
1653
- raise e
1654
-
1655
- @staticmethod
1656
- def _chunk_to_dict(chunk: Any) -> 'dict[str, object]':
1657
- if hasattr(chunk, "to_dict"):
1658
- return chunk.to_dict() # type: ignore
1659
- elif hasattr(chunk, "to_json_dict"):
1660
- return chunk.to_json_dict() # type: ignore
1661
- else:
1662
- return {}
1663
-
1664
- def _stop_iteration(self) -> None:
1665
- if self._ingested:
1666
- self._instrumentor._logger.debug(f"GeneratorWrapper: stop iteration already ingested, skipping")
1667
- return
1668
-
1669
- self._process_stop_iteration()
1670
-
1671
- self._instrumentor._ingest_units(self._request)
1672
- self._ingested = True
1673
-
1674
- async def _astop_iteration(self) -> None:
1675
- if self._ingested:
1676
- self._instrumentor._logger.debug(f"GeneratorWrapper: astop iteration already ingested, skipping")
1677
- return
1678
-
1679
- self._process_stop_iteration()
1680
-
1681
- await self._instrumentor._aingest_units(self._request)
1682
- self._ingested = True
1683
-
1684
- def _process_stop_iteration(self) -> None:
1685
- self._instrumentor._logger.debug(f"GeneratorWrapper: stop iteration")
1686
-
1687
- self._stopwatch.stop()
1688
- self._request._ingest["end_to_end_latency_ms"] = self._stopwatch.elapsed_ms_int()
1689
- self._request._ingest["http_status_code"] = 200
1690
-
1691
- if self._log_prompt_and_response:
1692
- self._request._ingest["provider_response_json"] = self._responses
1693
-
1694
1433
  global _instrumentor
1695
1434
  _instrumentor: Optional[_PayiInstrumentor] = None
1696
1435
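
The stream wrapper classes removed in this hunk all follow the same pattern: proxy the provider's chunk iterator, stamp time_to_first_token_ms on the first chunk, optionally collect chunk payloads for logging, and ingest usage exactly once when iteration stops. A condensed, self-contained sketch of that pattern; the class and callback below are illustrative only, not the package's implementation:

    import time
    from typing import Any, Callable, Dict, Iterator, List, Optional

    class TimedStreamLogger:
        # Wraps a chunk iterator, records latency figures, and reports once.
        def __init__(self, chunks: Iterator[Any], on_complete: Callable[[Dict[str, Any]], None]) -> None:
            self._chunks = chunks
            self._on_complete = on_complete
            self._start = time.monotonic()
            self._first_token_ms: Optional[int] = None
            self._collected: List[Any] = []
            self._done = False

        def __iter__(self) -> "TimedStreamLogger":
            return self

        def __next__(self) -> Any:
            try:
                chunk = next(self._chunks)
            except StopIteration:
                self._finish()
                raise
            if self._first_token_ms is None:
                self._first_token_ms = int((time.monotonic() - self._start) * 1000)
            self._collected.append(chunk)
            return chunk

        def _finish(self) -> None:
            if self._done:  # the removed wrappers guard the same way against double ingestion
                return
            self._done = True
            self._on_complete({
                "time_to_first_token_ms": self._first_token_ms,
                "end_to_end_latency_ms": int((time.monotonic() - self._start) * 1000),
                "provider_response_json": self._collected,
            })
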
 
@@ -1699,7 +1438,6 @@ def payi_instrument(
1699
1438
  payi: Optional[Union[Payi, AsyncPayi, 'list[Union[Payi, AsyncPayi]]']] = None,
1700
1439
  instruments: Optional[Set[str]] = None,
1701
1440
  log_prompt_and_response: bool = True,
1702
- prompt_and_response_logger: Optional[Callable[[str, "dict[str, str]"], None]] = None,
1703
1441
  config: Optional[PayiInstrumentConfig] = None,
1704
1442
  logger: Optional[logging.Logger] = None,
1705
1443
  ) -> None:
@@ -1732,7 +1470,6 @@ def payi_instrument(
1732
1470
  instruments=instruments,
1733
1471
  log_prompt_and_response=log_prompt_and_response,
1734
1472
  logger=logger,
1735
- prompt_and_response_logger=prompt_and_response_logger,
1736
1473
  global_config=config if config else PayiInstrumentConfig(),
1737
1474
  caller_filename=caller_filename
1738
1475
  )
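
With the prompt_and_response_logger callback removed from payi_instrument, a call using the remaining keyword arguments from this hunk might look like the sketch below; the import path, the Payi() construction, and the "openai" instrument name are assumptions rather than confirmed API details:

    import logging

    from payi import Payi
    from payi.lib.instrument import payi_instrument

    client = Payi()  # assumed to pick up its API key from the environment

    payi_instrument(
        payi=client,
        instruments={"openai"},          # illustrative value only
        log_prompt_and_response=True,
        logger=logging.getLogger("payi"),
    )
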
@@ -1750,9 +1487,9 @@ def track(
1750
1487
  use_case_properties: Optional["dict[str, str]"] = None,
1751
1488
  proxy: Optional[bool] = None,
1752
1489
  ) -> Any:
1490
+ _ = request_tags
1753
1491
 
1754
1492
  def _track(func: Any) -> Any:
1755
- import asyncio
1756
1493
  if asyncio.iscoroutinefunction(func):
1757
1494
  async def awrapper(*args: Any, **kwargs: Any) -> Any:
1758
1495
  if not _instrumentor:
@@ -1770,9 +1507,8 @@ def track(
1770
1507
  use_case_version,
1771
1508
  user_id,
1772
1509
  account_name,
1773
- request_tags,
1774
- request_properties,
1775
- use_case_properties,
1510
+ cast(Optional['dict[str, Optional[str]]'], request_properties),
1511
+ cast(Optional['dict[str, Optional[str]]'], use_case_properties),
1776
1512
  *args,
1777
1513
  **kwargs,
1778
1514
  )
@@ -1794,9 +1530,8 @@ def track(
1794
1530
  use_case_version,
1795
1531
  user_id,
1796
1532
  account_name,
1797
- request_tags,
1798
- request_properties,
1799
- use_case_properties,
1533
+ cast(Optional['dict[str, Optional[str]]'], request_properties),
1534
+ cast(Optional['dict[str, Optional[str]]'], use_case_properties),
1800
1535
  *args,
1801
1536
  **kwargs,
1802
1537
  )
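
After these changes the track decorator forwards request_properties and use_case_properties (as dicts with optional string values) and quietly discards request_tags. A hedged usage sketch; the import path and the decorated function body are illustrative, and only keyword names visible in this diff are used:

    from payi.lib.instrument import track

    @track(
        user_id="user-123",
        account_name="acme",
        request_properties={"experiment": "variant-b"},
        use_case_properties={"stage": "draft"},
        # request_tags=["legacy"],  # still accepted, but ignored in this version
    )
    def summarize(text: str) -> str:
        # call an instrumented provider client here
        return text[:100]
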
@@ -1834,14 +1569,15 @@ def track_context(
1834
1569
 
1835
1570
  context["user_id"] = user_id
1836
1571
  context["account_name"] = account_name
1837
- context["request_tags"] = request_tags
1838
1572
 
1839
1573
  context["price_as_category"] = price_as_category
1840
1574
  context["price_as_resource"] = price_as_resource
1841
1575
  context["resource_scope"] = resource_scope
1842
1576
 
1843
- context["request_properties"] = request_properties
1844
- context["use_case_properties"] = use_case_properties
1577
+ context["request_properties"] = cast(Optional['dict[str, Optional[str]]'], request_properties)
1578
+ context["use_case_properties"] = cast(Optional['dict[str, Optional[str]]'], use_case_properties)
1579
+
1580
+ _ = request_tags
1845
1581
 
1846
1582
  return _InternalTrackContext(context)
1847
1583
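
track_context populates the same context keys that _update_extra_headers reads above, and request_tags is now accepted but unused. A usage sketch under the assumption that the returned _InternalTrackContext object is a context manager; the keyword values are illustrative:

    from payi.lib.instrument import track_context

    with track_context(
        user_id="user-123",
        account_name="acme",
        price_as_category="chat",
        price_as_resource="example-model",
        request_properties={"experiment": "variant-b"},
        use_case_properties={"owner": "billing-team"},
    ):
        ...  # provider calls made here inherit these values as Pay-i headers
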
 
@@ -1852,7 +1588,7 @@ def get_context() -> PayiContext:
1852
1588
  """
1853
1589
  if not _instrumentor:
1854
1590
  return PayiContext()
1855
- internal_context = _instrumentor.get_context() or {}
1591
+ internal_context = _instrumentor._context_safe
1856
1592
 
1857
1593
  context_dict = {
1858
1594
  key: value