promptlayer 1.0.71__py3-none-any.whl → 1.0.73__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

This version of promptlayer has been flagged as potentially problematic.

promptlayer/utils.py CHANGED
@@ -5,19 +5,35 @@ import functools
5
5
  import json
6
6
  import logging
7
7
  import os
8
- import sys
9
8
  import types
9
+ from contextlib import asynccontextmanager
10
10
  from copy import deepcopy
11
11
  from enum import Enum
12
- from typing import Any, Dict, List, Optional, Union
12
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
13
13
  from uuid import uuid4
14
14
 
15
15
  import httpx
16
16
  import requests
17
+ import urllib3
18
+ import urllib3.util
17
19
  from ably import AblyRealtime
18
20
  from ably.types.message import Message
21
+ from centrifuge import (
22
+ Client,
23
+ PublicationContext,
24
+ SubscriptionEventHandler,
25
+ SubscriptionState,
26
+ )
19
27
  from opentelemetry import context, trace
28
+ from tenacity import (
29
+ before_sleep_log,
30
+ retry,
31
+ retry_if_exception,
32
+ stop_after_attempt,
33
+ wait_exponential,
34
+ )
20
35
 
36
+ from promptlayer import exceptions as _exceptions
21
37
  from promptlayer.types import RequestLog
22
38
  from promptlayer.types.prompt_template import (
23
39
  GetPromptTemplate,
@@ -28,8 +44,6 @@ from promptlayer.types.prompt_template import (
28
44
  )
29
45
 
30
46
  # Configuration
31
- # TODO(dmu) MEDIUM: Use `PROMPTLAYER_` prefix instead of `_PROMPTLAYER` suffix
32
- URL_API_PROMPTLAYER = os.environ.setdefault("URL_API_PROMPTLAYER", "https://api.promptlayer.com")
33
47
  RERAISE_ORIGINAL_EXCEPTION = os.getenv("PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION", "False").lower() == "true"
34
48
  RAISE_FOR_STATUS = os.getenv("PROMPTLAYER_RAISE_FOR_STATUS", "False").lower() == "true"
35
49
  DEFAULT_HTTP_TIMEOUT = 5
@@ -37,7 +51,9 @@ DEFAULT_HTTP_TIMEOUT = 5
37
51
  WORKFLOW_RUN_URL_TEMPLATE = "{base_url}/workflows/{workflow_id}/run"
38
52
  WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE = "workflows:{workflow_id}:run:{channel_name_suffix}"
39
53
  SET_WORKFLOW_COMPLETE_MESSAGE = "SET_WORKFLOW_COMPLETE"
40
- WS_TOKEN_REQUEST_LIBRARY_URL = URL_API_PROMPTLAYER + "/ws-token-request-library"
54
+ WS_TOKEN_REQUEST_LIBRARY_URL = (
55
+ f"{os.getenv('PROMPTLAYER_BASE_URL', 'https://api.promptlayer.com')}/ws-token-request-library"
56
+ )
41
57
 
42
58
 
43
59
  logger = logging.getLogger(__name__)
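
Note: this release drops the module-level URL_API_PROMPTLAYER constant; WS_TOKEN_REQUEST_LIBRARY_URL above is now built from the PROMPTLAYER_BASE_URL environment variable at import time, falling back to https://api.promptlayer.com. A minimal sketch of that fallback (the self-hosted URL is a made-up example, and the variable has to be set before promptlayer.utils is imported):

import os

# Hypothetical self-hosted deployment; export this before importing promptlayer.utils,
# because WS_TOKEN_REQUEST_LIBRARY_URL is evaluated at module import time.
os.environ["PROMPTLAYER_BASE_URL"] = "https://promptlayer.internal.example.com"

ws_token_url = f"{os.getenv('PROMPTLAYER_BASE_URL', 'https://api.promptlayer.com')}/ws-token-request-library"
assert ws_token_url == "https://promptlayer.internal.example.com/ws-token-request-library"
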
@@ -48,6 +64,34 @@ class FinalOutputCode(Enum):
48
64
  EXCEEDS_SIZE_LIMIT = "EXCEEDS_SIZE_LIMIT"
49
65
 
50
66
 
67
+ def should_retry_error(exception):
68
+ """Check if an exception should trigger a retry.
69
+
70
+ Only retries on server errors (5xx) and rate limits (429).
71
+ """
72
+ if hasattr(exception, "response"):
73
+ response = exception.response
74
+ if hasattr(response, "status_code"):
75
+ status_code = response.status_code
76
+ if status_code >= 500 or status_code == 429:
77
+ return True
78
+
79
+ if isinstance(exception, (_exceptions.PromptLayerInternalServerError, _exceptions.PromptLayerRateLimitError)):
80
+ return True
81
+
82
+ return False
83
+
84
+
85
+ def retry_on_api_error(func):
86
+ return retry(
87
+ retry=retry_if_exception(should_retry_error),
88
+ stop=stop_after_attempt(4), # 4 total attempts (1 initial + 3 retries)
89
+ wait=wait_exponential(multiplier=2, max=15), # 2s, 4s, 8s
90
+ before_sleep=before_sleep_log(logger, logging.WARNING),
91
+ reraise=True,
92
+ )(func)
93
+
94
+
51
95
  def _get_http_timeout():
52
96
  try:
53
97
  return float(os.getenv("PROMPTLAYER_HTTP_TIMEOUT", DEFAULT_HTTP_TIMEOUT))
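
For reference, a self-contained sketch of the retry policy wired up above, assuming only that tenacity is installed. FakeResponse, FlakyError, and flaky_call are hypothetical stand-ins used to exercise the same duck-typed status-code check:

import logging

from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt, wait_exponential

logger = logging.getLogger("retry_sketch")


class FakeResponse:
    def __init__(self, status_code: int):
        self.status_code = status_code


class FlakyError(Exception):
    def __init__(self, status_code: int):
        super().__init__(f"HTTP {status_code}")
        self.response = FakeResponse(status_code)


def should_retry_error(exception) -> bool:
    # Mirrors the check above: retry only on server errors (5xx) and rate limits (429).
    status_code = getattr(getattr(exception, "response", None), "status_code", None)
    return status_code is not None and (status_code >= 500 or status_code == 429)


calls = {"count": 0}


@retry(
    retry=retry_if_exception(should_retry_error),
    stop=stop_after_attempt(4),  # 1 initial attempt + up to 3 retries
    wait=wait_exponential(multiplier=2, max=15),  # exponential backoff, capped at 15 seconds
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)
def flaky_call() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise FlakyError(503)  # retryable server error on the first two attempts
    return "ok"


print(flaky_call(), calls["count"])  # -> ok 3
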
@@ -71,62 +115,60 @@ def _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name):
71
115
  return workflow_id_or_name
72
116
 
73
117
 
74
- async def _get_final_output(execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]) -> Dict[str, Any]:
118
+ async def _get_final_output(
119
+ base_url: str, execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]
120
+ ) -> Dict[str, Any]:
75
121
  async with httpx.AsyncClient() as client:
76
122
  response = await client.get(
77
- f"{URL_API_PROMPTLAYER}/workflow-version-execution-results",
123
+ f"{base_url}/workflow-version-execution-results",
78
124
  headers=headers,
79
125
  params={"workflow_version_execution_id": execution_id, "return_all_outputs": return_all_outputs},
80
126
  )
81
- response.raise_for_status()
127
+ if response.status_code != 200:
128
+ raise_on_bad_response(response, "PromptLayer had the following error while getting workflow results")
82
129
  return response.json()
83
130
 
84
131
 
85
132
  # TODO(dmu) MEDIUM: Consider putting all these functions into a class, so we do not have to pass
86
133
  # `authorization_headers` into each function
87
- async def _resolve_workflow_id(workflow_id_or_name: Union[int, str], headers):
134
+ async def _resolve_workflow_id(base_url: str, workflow_id_or_name: Union[int, str], headers):
88
135
  if isinstance(workflow_id_or_name, int):
89
136
  return workflow_id_or_name
90
137
 
91
138
  # TODO(dmu) LOW: Should we warn user here to avoid using workflow names in favor of workflow id?
92
139
  async with _make_httpx_client() as client:
93
140
  # TODO(dmu) MEDIUM: Generalize the way we make async calls to PromptLayer API and reuse it everywhere
94
- response = await client.get(f"{URL_API_PROMPTLAYER}/workflows/{workflow_id_or_name}", headers=headers)
95
- if RAISE_FOR_STATUS:
96
- response.raise_for_status()
97
- elif response.status_code != 200:
98
- raise_on_bad_response(response, "PromptLayer had the following error while running your workflow")
141
+ response = await client.get(f"{base_url}/workflows/{workflow_id_or_name}", headers=headers)
142
+ if response.status_code != 200:
143
+ raise_on_bad_response(response, "PromptLayer had the following error while resolving workflow")
99
144
 
100
145
  return response.json()["workflow"]["id"]
101
146
 
102
147
 
103
- async def _get_ably_token(channel_name, authentication_headers):
148
+ async def _get_ably_token(base_url: str, channel_name, authentication_headers):
104
149
  try:
105
150
  async with _make_httpx_client() as client:
106
151
  response = await client.post(
107
- f"{URL_API_PROMPTLAYER}/ws-token-request-library",
152
+ f"{base_url}/ws-token-request-library",
108
153
  headers=authentication_headers,
109
154
  params={"capability": channel_name},
110
155
  )
111
- if RAISE_FOR_STATUS:
112
- response.raise_for_status()
113
- elif response.status_code != 201:
156
+ if response.status_code != 201:
114
157
  raise_on_bad_response(
115
158
  response,
116
159
  "PromptLayer had the following error while getting WebSocket token",
117
160
  )
118
- return response.json()["token_details"]["token"]
161
+ return response.json()
119
162
  except Exception as ex:
120
163
  error_message = f"Failed to get WebSocket token: {ex}"
121
- print(error_message) # TODO(dmu) MEDIUM: Remove prints in favor of logging
122
164
  logger.exception(error_message)
123
165
  if RERAISE_ORIGINAL_EXCEPTION:
124
166
  raise
125
167
  else:
126
- raise Exception(error_message)
168
+ raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex
127
169
 
128
170
 
129
- def _make_message_listener(results_future, execution_id_future, return_all_outputs, headers):
171
+ def _make_message_listener(base_url: str, results_future, execution_id_future, return_all_outputs, headers):
130
172
  # We need this function to be mocked by unittests
131
173
  async def message_listener(message: Message):
132
174
  if results_future.cancelled() or message.name != SET_WORKFLOW_COMPLETE_MESSAGE:
@@ -140,7 +182,7 @@ def _make_message_listener(results_future, execution_id_future, return_all_outpu
140
182
  if (result_code := message_data.get("result_code")) in (FinalOutputCode.OK.value, None):
141
183
  results = message_data["final_output"]
142
184
  elif result_code == FinalOutputCode.EXCEEDS_SIZE_LIMIT.value:
143
- results = await _get_final_output(execution_id, return_all_outputs, headers=headers)
185
+ results = await _get_final_output(base_url, execution_id, return_all_outputs, headers=headers)
144
186
  else:
145
187
  raise NotImplementedError(f"Unsupported final output code: {result_code}")
146
188
 
@@ -149,15 +191,20 @@ def _make_message_listener(results_future, execution_id_future, return_all_outpu
149
191
  return message_listener
150
192
 
151
193
 
152
- async def _subscribe_to_workflow_completion_channel(channel, execution_id_future, return_all_outputs, headers):
194
+ async def _subscribe_to_workflow_completion_channel(
195
+ base_url: str, channel, execution_id_future, return_all_outputs, headers
196
+ ):
153
197
  results_future = asyncio.Future()
154
- message_listener = _make_message_listener(results_future, execution_id_future, return_all_outputs, headers)
198
+ message_listener = _make_message_listener(
199
+ base_url, results_future, execution_id_future, return_all_outputs, headers
200
+ )
155
201
  await channel.subscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
156
202
  return results_future, message_listener
157
203
 
158
204
 
159
205
  async def _post_workflow_id_run(
160
206
  *,
207
+ base_url: str,
161
208
  authentication_headers,
162
209
  workflow_id,
163
210
  input_variables: Dict[str, Any],
@@ -168,7 +215,7 @@ async def _post_workflow_id_run(
168
215
  channel_name_suffix: str,
169
216
  _url_template: str = WORKFLOW_RUN_URL_TEMPLATE,
170
217
  ):
171
- url = _url_template.format(base_url=URL_API_PROMPTLAYER, workflow_id=workflow_id)
218
+ url = _url_template.format(base_url=base_url, workflow_id=workflow_id)
172
219
  payload = {
173
220
  "input_variables": input_variables,
174
221
  "metadata": metadata,
@@ -180,22 +227,19 @@ async def _post_workflow_id_run(
180
227
  try:
181
228
  async with _make_httpx_client() as client:
182
229
  response = await client.post(url, json=payload, headers=authentication_headers)
183
- if RAISE_FOR_STATUS:
184
- response.raise_for_status()
185
- elif response.status_code != 201:
230
+ if response.status_code != 201:
186
231
  raise_on_bad_response(response, "PromptLayer had the following error while running your workflow")
187
232
 
188
233
  result = response.json()
189
234
  if warning := result.get("warning"):
190
- print(f"WARNING: {warning}")
235
+ logger.warning(f"{warning}")
191
236
  except Exception as ex:
192
237
  error_message = f"Failed to run workflow: {str(ex)}"
193
- print(error_message) # TODO(dmu) MEDIUM: Remove prints in favor of logging
194
238
  logger.exception(error_message)
195
239
  if RERAISE_ORIGINAL_EXCEPTION:
196
240
  raise
197
241
  else:
198
- raise Exception(error_message)
242
+ raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex
199
243
 
200
244
  return result.get("workflow_version_execution_id")
201
245
 
@@ -205,7 +249,9 @@ async def _wait_for_workflow_completion(channel, results_future, message_listene
205
249
  try:
206
250
  return await asyncio.wait_for(results_future, timeout)
207
251
  except asyncio.TimeoutError:
208
- raise Exception("Workflow execution did not complete properly")
252
+ raise _exceptions.PromptLayerAPITimeoutError(
253
+ "Workflow execution did not complete properly", response=None, body=None
254
+ )
209
255
  finally:
210
256
  channel.unsubscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
211
257
 
@@ -215,14 +261,55 @@ def _make_channel_name_suffix():
215
261
  return uuid4().hex
216
262
 
217
263
 
264
+ MessageCallback = Callable[[Message], Coroutine[None, None, None]]
265
+
266
+
267
+ class SubscriptionEventLoggerHandler(SubscriptionEventHandler):
268
+ def __init__(self, callback: MessageCallback):
269
+ self.callback = callback
270
+
271
+ async def on_publication(self, ctx: PublicationContext):
272
+ message_name = ctx.pub.data.get("message_name", "unknown")
273
+ data = ctx.pub.data.get("data", "")
274
+ message = Message(name=message_name, data=data)
275
+ await self.callback(message)
276
+
277
+
278
+ @asynccontextmanager
279
+ async def centrifugo_client(address: str, token: str):
280
+ client = Client(address, token=token)
281
+ try:
282
+ await client.connect()
283
+ yield client
284
+ finally:
285
+ await client.disconnect()
286
+
287
+
288
+ @asynccontextmanager
289
+ async def centrifugo_subscription(client: Client, topic: str, message_listener: MessageCallback):
290
+ subscription = client.new_subscription(
291
+ topic,
292
+ events=SubscriptionEventLoggerHandler(message_listener),
293
+ )
294
+ try:
295
+ await subscription.subscribe()
296
+ yield
297
+ finally:
298
+ if subscription.state == SubscriptionState.SUBSCRIBED:
299
+ await subscription.unsubscribe()
300
+
301
+
302
+ @retry_on_api_error
218
303
  async def arun_workflow_request(
219
304
  *,
305
+ api_key: str,
306
+ base_url: str,
307
+ throw_on_error: bool,
220
308
  workflow_id_or_name: Optional[Union[int, str]] = None,
221
309
  input_variables: Dict[str, Any],
222
310
  metadata: Optional[Dict[str, Any]] = None,
223
311
  workflow_label_name: Optional[str] = None,
224
312
  workflow_version_number: Optional[int] = None,
225
- api_key: str,
226
313
  return_all_outputs: Optional[bool] = False,
227
314
  timeout: Optional[int] = 3600,
228
315
  # `workflow_name` deprecated, kept for backward compatibility only.
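
api_key, base_url, and throw_on_error are now required keyword-only arguments of arun_workflow_request (the base URL previously came from the module-level URL_API_PROMPTLAYER constant). A hedged usage sketch before the body hunk below; the workflow name and input variables are invented:

import asyncio

from promptlayer.utils import arun_workflow_request


async def run_example():
    return await arun_workflow_request(
        api_key="<promptlayer api key>",
        base_url="https://api.promptlayer.com",
        throw_on_error=True,
        workflow_id_or_name="my-workflow",  # hypothetical workflow name
        input_variables={"question": "What changed in 1.0.73?"},
        return_all_outputs=False,
        timeout=3600,
    )


# results = asyncio.run(run_example())  # requires valid credentials
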
@@ -230,22 +317,50 @@ async def arun_workflow_request(
230
317
  ):
231
318
  headers = {"X-API-KEY": api_key}
232
319
  workflow_id = await _resolve_workflow_id(
233
- _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
320
+ base_url, _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
234
321
  )
235
322
  channel_name_suffix = _make_channel_name_suffix()
236
323
  channel_name = WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE.format(
237
324
  workflow_id=workflow_id, channel_name_suffix=channel_name_suffix
238
325
  )
239
- ably_token = await _get_ably_token(channel_name, headers)
240
- async with AblyRealtime(token=ably_token) as ably_client:
326
+ ably_token = await _get_ably_token(base_url, channel_name, headers)
327
+ token = ably_token["token_details"]["token"]
328
+
329
+ execution_id_future = asyncio.Future[int]()
330
+
331
+ if ably_token.get("messaging_backend") == "centrifugo":
332
+ address = urllib3.util.parse_url(base_url)._replace(scheme="wss", path="/connection/websocket").url
333
+ async with centrifugo_client(address, token) as client:
334
+ results_future = asyncio.Future[dict[str, Any]]()
335
+ async with centrifugo_subscription(
336
+ client,
337
+ channel_name,
338
+ _make_message_listener(base_url, results_future, execution_id_future, return_all_outputs, headers),
339
+ ):
340
+ execution_id = await _post_workflow_id_run(
341
+ base_url=base_url,
342
+ authentication_headers=headers,
343
+ workflow_id=workflow_id,
344
+ input_variables=input_variables,
345
+ metadata=metadata,
346
+ workflow_label_name=workflow_label_name,
347
+ workflow_version_number=workflow_version_number,
348
+ return_all_outputs=return_all_outputs,
349
+ channel_name_suffix=channel_name_suffix,
350
+ )
351
+ execution_id_future.set_result(execution_id)
352
+ await asyncio.wait_for(results_future, timeout)
353
+ return results_future.result()
354
+
355
+ async with AblyRealtime(token=token) as ably_client:
241
356
  # It is crucial to subscribe before running a workflow, otherwise we may miss a completion message
242
357
  channel = ably_client.channels.get(channel_name)
243
- execution_id_future = asyncio.Future()
244
358
  results_future, message_listener = await _subscribe_to_workflow_completion_channel(
245
- channel, execution_id_future, return_all_outputs, headers
359
+ base_url, channel, execution_id_future, return_all_outputs, headers
246
360
  )
247
361
 
248
362
  execution_id = await _post_workflow_id_run(
363
+ base_url=base_url,
249
364
  authentication_headers=headers,
250
365
  workflow_id=workflow_id,
251
366
  input_variables=input_variables,
@@ -261,6 +376,8 @@ async def arun_workflow_request(
261
376
 
262
377
 
263
378
  def promptlayer_api_handler(
379
+ api_key: str,
380
+ base_url: str,
264
381
  function_name,
265
382
  provider_type,
266
383
  args,
@@ -269,7 +386,6 @@ def promptlayer_api_handler(
269
386
  response,
270
387
  request_start_time,
271
388
  request_end_time,
272
- api_key,
273
389
  return_pl_id=False,
274
390
  llm_request_span_id=None,
275
391
  ):
@@ -292,9 +408,11 @@ def promptlayer_api_handler(
292
408
  "llm_request_span_id": llm_request_span_id,
293
409
  },
294
410
  api_key=api_key,
411
+ base_url=base_url,
295
412
  )
296
413
  else:
297
414
  request_id = promptlayer_api_request(
415
+ base_url=base_url,
298
416
  function_name=function_name,
299
417
  provider_type=provider_type,
300
418
  args=args,
@@ -313,6 +431,8 @@ def promptlayer_api_handler(
313
431
 
314
432
 
315
433
  async def promptlayer_api_handler_async(
434
+ api_key: str,
435
+ base_url: str,
316
436
  function_name,
317
437
  provider_type,
318
438
  args,
@@ -321,13 +441,14 @@ async def promptlayer_api_handler_async(
321
441
  response,
322
442
  request_start_time,
323
443
  request_end_time,
324
- api_key,
325
444
  return_pl_id=False,
326
445
  llm_request_span_id=None,
327
446
  ):
328
447
  return await run_in_thread_async(
329
448
  None,
330
449
  promptlayer_api_handler,
450
+ api_key,
451
+ base_url,
331
452
  function_name,
332
453
  provider_type,
333
454
  args,
@@ -336,7 +457,6 @@ async def promptlayer_api_handler_async(
336
457
  response,
337
458
  request_start_time,
338
459
  request_end_time,
339
- api_key,
340
460
  return_pl_id=return_pl_id,
341
461
  llm_request_span_id=llm_request_span_id,
342
462
  )
@@ -356,6 +476,7 @@ def convert_native_object_to_dict(native_object):
356
476
 
357
477
  def promptlayer_api_request(
358
478
  *,
479
+ base_url: str,
359
480
  function_name,
360
481
  provider_type,
361
482
  args,
@@ -376,7 +497,7 @@ def promptlayer_api_request(
376
497
  response = response.dict()
377
498
  try:
378
499
  request_response = requests.post(
379
- f"{URL_API_PROMPTLAYER}/track-request",
500
+ f"{base_url}/track-request",
380
501
  json={
381
502
  "function_name": function_name,
382
503
  "provider_type": provider_type,
@@ -400,43 +521,57 @@ def promptlayer_api_request(
400
521
  request_response, "WARNING: While logging your request PromptLayer had the following error"
401
522
  )
402
523
  except Exception as e:
403
- print(f"WARNING: While logging your request PromptLayer had the following error: {e}", file=sys.stderr)
524
+ logger.warning(f"While logging your request PromptLayer had the following error: {e}")
404
525
  if request_response is not None and return_pl_id:
405
526
  return request_response.json().get("request_id")
406
527
 
407
528
 
408
- def track_request(**body):
529
+ @retry_on_api_error
530
+ def track_request(base_url: str, throw_on_error: bool, **body):
409
531
  try:
410
532
  response = requests.post(
411
- f"{URL_API_PROMPTLAYER}/track-request",
533
+ f"{base_url}/track-request",
412
534
  json=body,
413
535
  )
414
536
  if response.status_code != 200:
415
- warn_on_bad_response(
416
- response, f"PromptLayer had the following error while tracking your request: {response.text}"
417
- )
537
+ if throw_on_error:
538
+ raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
539
+ else:
540
+ warn_on_bad_response(
541
+ response, f"PromptLayer had the following error while tracking your request: {response.text}"
542
+ )
418
543
  return response.json()
419
544
  except requests.exceptions.RequestException as e:
420
- print(f"WARNING: While logging your request PromptLayer had the following error: {e}", file=sys.stderr)
545
+ if throw_on_error:
546
+ raise _exceptions.PromptLayerAPIConnectionError(
547
+ f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
548
+ ) from e
549
+ logger.warning(f"While logging your request PromptLayer had the following error: {e}")
421
550
  return {}
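
track_request (and its async twin below) now take base_url and throw_on_error as leading positional parameters. With throw_on_error=False the old warn-and-return-{} behaviour is kept; with True, failures surface as typed exceptions from promptlayer.exceptions once @retry_on_api_error has exhausted its attempts. A hedged sketch, with an illustrative payload:

from promptlayer import exceptions as pl_exceptions
from promptlayer.utils import track_request

payload = {"api_key": "<promptlayer api key>", "function_name": "openai.chat.completions.create"}  # illustrative

# Old behaviour: log a warning and return {} on failure.
result = track_request("https://api.promptlayer.com", False, **payload)

# New opt-in behaviour: raise typed exceptions instead of swallowing errors.
try:
    result = track_request("https://api.promptlayer.com", True, **payload)
except pl_exceptions.PromptLayerAPIConnectionError as exc:
    print(f"could not reach PromptLayer: {exc}")
except pl_exceptions.PromptLayerRateLimitError:
    pass  # 429s are retried by @retry_on_api_error first, then re-raised here
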
422
551
 
423
552
 
424
- async def atrack_request(**body: Any) -> Dict[str, Any]:
553
+ @retry_on_api_error
554
+ async def atrack_request(base_url: str, throw_on_error: bool, **body: Any) -> Dict[str, Any]:
425
555
  try:
426
556
  async with _make_httpx_client() as client:
427
557
  response = await client.post(
428
- f"{URL_API_PROMPTLAYER}/track-request",
558
+ f"{base_url}/track-request",
429
559
  json=body,
430
560
  )
431
- if RAISE_FOR_STATUS:
432
- response.raise_for_status()
433
- elif response.status_code != 200:
434
- warn_on_bad_response(
435
- response, f"PromptLayer had the following error while tracking your request: {response.text}"
436
- )
561
+ if response.status_code != 200:
562
+ if throw_on_error:
563
+ raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
564
+ else:
565
+ warn_on_bad_response(
566
+ response, f"PromptLayer had the following error while tracking your request: {response.text}"
567
+ )
437
568
  return response.json()
438
569
  except httpx.RequestError as e:
439
- print(f"WARNING: While logging your request PromptLayer had the following error: {e}", file=sys.stderr)
570
+ if throw_on_error:
571
+ raise _exceptions.PromptLayerAPIConnectionError(
572
+ f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
573
+ ) from e
574
+ logger.warning(f"While logging your request PromptLayer had the following error: {e}")
440
575
  return {}
441
576
 
442
577
 
@@ -468,7 +603,10 @@ def promptlayer_api_request_async(
468
603
  )
469
604
 
470
605
 
471
- def promptlayer_get_prompt(prompt_name, api_key, version: int = None, label: str = None):
606
+ @retry_on_api_error
607
+ def promptlayer_get_prompt(
608
+ api_key: str, base_url: str, throw_on_error: bool, prompt_name, version: int = None, label: str = None
609
+ ):
472
610
  """
473
611
  Get a prompt from the PromptLayer library
474
612
  version: version of the prompt to get, None for latest
@@ -476,25 +614,40 @@ def promptlayer_get_prompt(prompt_name, api_key, version: int = None, label: str
476
614
  """
477
615
  try:
478
616
  request_response = requests.get(
479
- f"{URL_API_PROMPTLAYER}/library-get-prompt-template",
617
+ f"{base_url}/library-get-prompt-template",
480
618
  headers={"X-API-KEY": api_key},
481
619
  params={"prompt_name": prompt_name, "version": version, "label": label},
482
620
  )
483
621
  except Exception as e:
484
- raise Exception(f"PromptLayer had the following error while getting your prompt: {e}")
622
+ if throw_on_error:
623
+ raise _exceptions.PromptLayerAPIError(
624
+ f"PromptLayer had the following error while getting your prompt: {e}", response=None, body=None
625
+ ) from e
626
+ logger.warning(f"PromptLayer had the following error while getting your prompt: {e}")
627
+ return None
485
628
  if request_response.status_code != 200:
486
- raise_on_bad_response(
487
- request_response,
488
- "PromptLayer had the following error while getting your prompt",
489
- )
629
+ if throw_on_error:
630
+ raise_on_bad_response(
631
+ request_response,
632
+ "PromptLayer had the following error while getting your prompt",
633
+ )
634
+ else:
635
+ warn_on_bad_response(
636
+ request_response,
637
+ "WARNING: PromptLayer had the following error while getting your prompt",
638
+ )
639
+ return None
490
640
 
491
641
  return request_response.json()
492
642
 
493
643
 
494
- def promptlayer_publish_prompt(prompt_name, prompt_template, commit_message, tags, api_key, metadata=None):
644
+ @retry_on_api_error
645
+ def promptlayer_publish_prompt(
646
+ api_key: str, base_url: str, throw_on_error: bool, prompt_name, prompt_template, commit_message, tags, metadata=None
647
+ ):
495
648
  try:
496
649
  request_response = requests.post(
497
- f"{URL_API_PROMPTLAYER}/library-publish-prompt-template",
650
+ f"{base_url}/library-publish-prompt-template",
498
651
  json={
499
652
  "prompt_name": prompt_name,
500
653
  "prompt_template": prompt_template,
@@ -505,19 +658,34 @@ def promptlayer_publish_prompt(prompt_name, prompt_template, commit_message, tag
505
658
  },
506
659
  )
507
660
  except Exception as e:
508
- raise Exception(f"PromptLayer had the following error while publishing your prompt: {e}")
661
+ if throw_on_error:
662
+ raise _exceptions.PromptLayerAPIError(
663
+ f"PromptLayer had the following error while publishing your prompt: {e}", response=None, body=None
664
+ ) from e
665
+ logger.warning(f"PromptLayer had the following error while publishing your prompt: {e}")
666
+ return False
509
667
  if request_response.status_code != 200:
510
- raise_on_bad_response(
511
- request_response,
512
- "PromptLayer had the following error while publishing your prompt",
513
- )
668
+ if throw_on_error:
669
+ raise_on_bad_response(
670
+ request_response,
671
+ "PromptLayer had the following error while publishing your prompt",
672
+ )
673
+ else:
674
+ warn_on_bad_response(
675
+ request_response,
676
+ "WARNING: PromptLayer had the following error while publishing your prompt",
677
+ )
678
+ return False
514
679
  return True
515
680
 
516
681
 
517
- def promptlayer_track_prompt(request_id, prompt_name, input_variables, api_key, version, label):
682
+ @retry_on_api_error
683
+ def promptlayer_track_prompt(
684
+ api_key: str, base_url: str, throw_on_error: bool, request_id, prompt_name, input_variables, version, label
685
+ ):
518
686
  try:
519
687
  request_response = requests.post(
520
- f"{URL_API_PROMPTLAYER}/library-track-prompt",
688
+ f"{base_url}/library-track-prompt",
521
689
  json={
522
690
  "request_id": request_id,
523
691
  "prompt_name": prompt_name,
@@ -528,29 +696,39 @@ def promptlayer_track_prompt(request_id, prompt_name, input_variables, api_key,
528
696
  },
529
697
  )
530
698
  if request_response.status_code != 200:
531
- warn_on_bad_response(
532
- request_response,
533
- "WARNING: While tracking your prompt PromptLayer had the following error",
534
- )
535
- return False
699
+ if throw_on_error:
700
+ raise_on_bad_response(
701
+ request_response,
702
+ "While tracking your prompt PromptLayer had the following error",
703
+ )
704
+ else:
705
+ warn_on_bad_response(
706
+ request_response,
707
+ "WARNING: While tracking your prompt PromptLayer had the following error",
708
+ )
709
+ return False
536
710
  except Exception as e:
537
- print(
538
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
539
- file=sys.stderr,
540
- )
711
+ if throw_on_error:
712
+ raise _exceptions.PromptLayerAPIError(
713
+ f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
714
+ ) from e
715
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
541
716
  return False
542
717
  return True
543
718
 
544
719
 
720
+ @retry_on_api_error
545
721
  async def apromptlayer_track_prompt(
722
+ api_key: str,
723
+ base_url: str,
546
724
  request_id: str,
547
725
  prompt_name: str,
548
726
  input_variables: Dict[str, Any],
549
- api_key: Optional[str] = None,
550
727
  version: Optional[int] = None,
551
728
  label: Optional[str] = None,
729
+ throw_on_error: bool = True,
552
730
  ) -> bool:
553
- url = f"{URL_API_PROMPTLAYER}/library-track-prompt"
731
+ url = f"{base_url}/library-track-prompt"
554
732
  payload = {
555
733
  "request_id": request_id,
556
734
  "prompt_name": prompt_name,
@@ -563,28 +741,31 @@ async def apromptlayer_track_prompt(
563
741
  async with _make_httpx_client() as client:
564
742
  response = await client.post(url, json=payload)
565
743
 
566
- if RAISE_FOR_STATUS:
567
- response.raise_for_status()
568
- elif response.status_code != 200:
569
- warn_on_bad_response(
570
- response,
571
- "WARNING: While tracking your prompt, PromptLayer had the following error",
572
- )
573
- return False
744
+ if response.status_code != 200:
745
+ if throw_on_error:
746
+ raise_on_bad_response(response, "While tracking your prompt, PromptLayer had the following error")
747
+ else:
748
+ warn_on_bad_response(
749
+ response,
750
+ "WARNING: While tracking your prompt, PromptLayer had the following error",
751
+ )
752
+ return False
574
753
  except httpx.RequestError as e:
575
- print(
576
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
577
- file=sys.stderr,
578
- )
754
+ if throw_on_error:
755
+ raise _exceptions.PromptLayerAPIConnectionError(
756
+ f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
757
+ ) from e
758
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
579
759
  return False
580
760
 
581
761
  return True
582
762
 
583
763
 
584
- def promptlayer_track_metadata(request_id, metadata, api_key):
764
+ @retry_on_api_error
765
+ def promptlayer_track_metadata(api_key: str, base_url: str, throw_on_error: bool, request_id, metadata):
585
766
  try:
586
767
  request_response = requests.post(
587
- f"{URL_API_PROMPTLAYER}/library-track-metadata",
768
+ f"{base_url}/library-track-metadata",
588
769
  json={
589
770
  "request_id": request_id,
590
771
  "metadata": metadata,
@@ -592,22 +773,32 @@ def promptlayer_track_metadata(request_id, metadata, api_key):
592
773
  },
593
774
  )
594
775
  if request_response.status_code != 200:
595
- warn_on_bad_response(
596
- request_response,
597
- "WARNING: While tracking your metadata PromptLayer had the following error",
598
- )
599
- return False
776
+ if throw_on_error:
777
+ raise_on_bad_response(
778
+ request_response,
779
+ "While tracking your metadata PromptLayer had the following error",
780
+ )
781
+ else:
782
+ warn_on_bad_response(
783
+ request_response,
784
+ "WARNING: While tracking your metadata PromptLayer had the following error",
785
+ )
786
+ return False
600
787
  except Exception as e:
601
- print(
602
- f"WARNING: While tracking your metadata PromptLayer had the following error: {e}",
603
- file=sys.stderr,
604
- )
788
+ if throw_on_error:
789
+ raise _exceptions.PromptLayerAPIError(
790
+ f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
791
+ ) from e
792
+ logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
605
793
  return False
606
794
  return True
607
795
 
608
796
 
609
- async def apromptlayer_track_metadata(request_id: str, metadata: Dict[str, Any], api_key: Optional[str] = None) -> bool:
610
- url = f"{URL_API_PROMPTLAYER}/library-track-metadata"
797
+ @retry_on_api_error
798
+ async def apromptlayer_track_metadata(
799
+ api_key: str, base_url: str, throw_on_error: bool, request_id: str, metadata: Dict[str, Any]
800
+ ) -> bool:
801
+ url = f"{base_url}/library-track-metadata"
611
802
  payload = {
612
803
  "request_id": request_id,
613
804
  "metadata": metadata,
@@ -617,55 +808,71 @@ async def apromptlayer_track_metadata(request_id: str, metadata: Dict[str, Any],
617
808
  async with _make_httpx_client() as client:
618
809
  response = await client.post(url, json=payload)
619
810
 
620
- if RAISE_FOR_STATUS:
621
- response.raise_for_status()
622
- elif response.status_code != 200:
623
- warn_on_bad_response(
624
- response,
625
- "WARNING: While tracking your metadata, PromptLayer had the following error",
626
- )
627
- return False
811
+ if response.status_code != 200:
812
+ if throw_on_error:
813
+ raise_on_bad_response(
814
+ response,
815
+ "While tracking your metadata, PromptLayer had the following error",
816
+ )
817
+ else:
818
+ warn_on_bad_response(
819
+ response,
820
+ "WARNING: While tracking your metadata, PromptLayer had the following error",
821
+ )
822
+ return False
628
823
  except httpx.RequestError as e:
629
- print(
630
- f"WARNING: While tracking your metadata PromptLayer had the following error: {e}",
631
- file=sys.stderr,
632
- )
824
+ if throw_on_error:
825
+ raise _exceptions.PromptLayerAPIConnectionError(
826
+ f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
827
+ ) from e
828
+ logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
633
829
  return False
634
830
 
635
831
  return True
636
832
 
637
833
 
638
- def promptlayer_track_score(request_id, score, score_name, api_key):
834
+ @retry_on_api_error
835
+ def promptlayer_track_score(api_key: str, base_url: str, throw_on_error: bool, request_id, score, score_name):
639
836
  try:
640
837
  data = {"request_id": request_id, "score": score, "api_key": api_key}
641
838
  if score_name is not None:
642
839
  data["name"] = score_name
643
840
  request_response = requests.post(
644
- f"{URL_API_PROMPTLAYER}/library-track-score",
841
+ f"{base_url}/library-track-score",
645
842
  json=data,
646
843
  )
647
844
  if request_response.status_code != 200:
648
- warn_on_bad_response(
649
- request_response,
650
- "WARNING: While tracking your score PromptLayer had the following error",
651
- )
652
- return False
845
+ if throw_on_error:
846
+ raise_on_bad_response(
847
+ request_response,
848
+ "While tracking your score PromptLayer had the following error",
849
+ )
850
+ else:
851
+ warn_on_bad_response(
852
+ request_response,
853
+ "WARNING: While tracking your score PromptLayer had the following error",
854
+ )
855
+ return False
653
856
  except Exception as e:
654
- print(
655
- f"WARNING: While tracking your score PromptLayer had the following error: {e}",
656
- file=sys.stderr,
657
- )
857
+ if throw_on_error:
858
+ raise _exceptions.PromptLayerAPIError(
859
+ f"While tracking your score PromptLayer had the following error: {e}", response=None, body=None
860
+ ) from e
861
+ logger.warning(f"While tracking your score PromptLayer had the following error: {e}")
658
862
  return False
659
863
  return True
660
864
 
661
865
 
866
+ @retry_on_api_error
662
867
  async def apromptlayer_track_score(
868
+ api_key: str,
869
+ base_url: str,
870
+ throw_on_error: bool,
663
871
  request_id: str,
664
872
  score: float,
665
873
  score_name: Optional[str],
666
- api_key: Optional[str] = None,
667
874
  ) -> bool:
668
- url = f"{URL_API_PROMPTLAYER}/library-track-score"
875
+ url = f"{base_url}/library-track-score"
669
876
  data = {
670
877
  "request_id": request_id,
671
878
  "score": score,
@@ -677,19 +884,24 @@ async def apromptlayer_track_score(
677
884
  async with _make_httpx_client() as client:
678
885
  response = await client.post(url, json=data)
679
886
 
680
- if RAISE_FOR_STATUS:
681
- response.raise_for_status()
682
- elif response.status_code != 200:
683
- warn_on_bad_response(
684
- response,
685
- "WARNING: While tracking your score, PromptLayer had the following error",
686
- )
687
- return False
887
+ if response.status_code != 200:
888
+ if throw_on_error:
889
+ raise_on_bad_response(
890
+ response,
891
+ "While tracking your score, PromptLayer had the following error",
892
+ )
893
+ else:
894
+ warn_on_bad_response(
895
+ response,
896
+ "WARNING: While tracking your score, PromptLayer had the following error",
897
+ )
898
+ return False
688
899
  except httpx.RequestError as e:
689
- print(
690
- f"WARNING: While tracking your score PromptLayer had the following error: {str(e)}",
691
- file=sys.stderr,
692
- )
900
+ if throw_on_error:
901
+ raise _exceptions.PromptLayerAPIConnectionError(
902
+ f"PromptLayer had the following error while tracking your score: {str(e)}", response=None, body=None
903
+ ) from e
904
+ logger.warning(f"While tracking your score PromptLayer had the following error: {str(e)}")
693
905
  return False
694
906
 
695
907
  return True
@@ -753,11 +965,12 @@ def build_anthropic_content_blocks(events):
753
965
 
754
966
 
755
967
  class GeneratorProxy:
756
- def __init__(self, generator, api_request_arguments, api_key):
968
+ def __init__(self, generator, api_request_arguments, api_key, base_url):
757
969
  self.generator = generator
758
970
  self.results = []
759
971
  self.api_request_arugments = api_request_arguments
760
972
  self.api_key = api_key
973
+ self.base_url = base_url
761
974
 
762
975
  def __iter__(self):
763
976
  return self
@@ -772,6 +985,7 @@ class GeneratorProxy:
772
985
  await self.generator._AsyncMessageStreamManager__api_request,
773
986
  api_request_arguments,
774
987
  self.api_key,
988
+ self.base_url,
775
989
  )
776
990
 
777
991
  def __enter__(self):
@@ -782,6 +996,7 @@ class GeneratorProxy:
782
996
  stream,
783
997
  api_request_arguments,
784
998
  self.api_key,
999
+ self.base_url,
785
1000
  )
786
1001
 
787
1002
  def __exit__(self, exc_type, exc_val, exc_tb):
@@ -800,7 +1015,7 @@ class GeneratorProxy:
800
1015
 
801
1016
  def __getattr__(self, name):
802
1017
  if name == "text_stream": # anthropic async stream
803
- return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key)
1018
+ return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key, self.base_url)
804
1019
  return getattr(self.generator, name)
805
1020
 
806
1021
  def _abstracted_next(self, result):
@@ -822,6 +1037,7 @@ class GeneratorProxy:
822
1037
 
823
1038
  if end_anthropic or end_openai:
824
1039
  request_id = promptlayer_api_request(
1040
+ base_url=self.base_url,
825
1041
  function_name=self.api_request_arugments["function_name"],
826
1042
  provider_type=self.api_request_arugments["provider_type"],
827
1043
  args=self.api_request_arugments["args"],
@@ -912,39 +1128,71 @@ async def run_in_thread_async(executor, func, *args, **kwargs):
912
1128
  def warn_on_bad_response(request_response, main_message):
913
1129
  if hasattr(request_response, "json"):
914
1130
  try:
915
- print(
916
- f"{main_message}: {request_response.json().get('message')}",
917
- file=sys.stderr,
918
- )
1131
+ logger.warning(f"{main_message}: {request_response.json().get('message')}")
919
1132
  except json.JSONDecodeError:
920
- print(
921
- f"{main_message}: {request_response}",
922
- file=sys.stderr,
923
- )
1133
+ logger.warning(f"{main_message}: {request_response}")
924
1134
  else:
925
- print(f"{main_message}: {request_response}", file=sys.stderr)
1135
+ logger.warning(f"{main_message}: {request_response}")
926
1136
 
927
1137
 
928
1138
  def raise_on_bad_response(request_response, main_message):
1139
+ """Raise an appropriate exception based on the HTTP status code."""
1140
+ status_code = getattr(request_response, "status_code", None)
1141
+
1142
+ body = None
1143
+ error_detail = None
929
1144
  if hasattr(request_response, "json"):
930
1145
  try:
931
- raise Exception(
932
- f"{main_message}: {request_response.json().get('message') or request_response.json().get('error')}"
933
- )
934
- except json.JSONDecodeError:
935
- raise Exception(f"{main_message}: {request_response}")
1146
+ body = request_response.json()
1147
+ error_detail = body.get("message") or body.get("error") or body.get("detail")
1148
+ except (json.JSONDecodeError, AttributeError):
1149
+ body = getattr(request_response, "text", str(request_response))
1150
+ error_detail = body
1151
+ else:
1152
+ body = str(request_response)
1153
+ error_detail = body
1154
+
1155
+ if error_detail:
1156
+ err_msg = f"{main_message}: {error_detail}"
936
1157
  else:
937
- raise Exception(f"{main_message}: {request_response}")
1158
+ err_msg = main_message
1159
+
1160
+ if status_code == 400:
1161
+ raise _exceptions.PromptLayerBadRequestError(err_msg, response=request_response, body=body)
1162
+
1163
+ if status_code == 401:
1164
+ raise _exceptions.PromptLayerAuthenticationError(err_msg, response=request_response, body=body)
1165
+
1166
+ if status_code == 403:
1167
+ raise _exceptions.PromptLayerPermissionDeniedError(err_msg, response=request_response, body=body)
1168
+
1169
+ if status_code == 404:
1170
+ raise _exceptions.PromptLayerNotFoundError(err_msg, response=request_response, body=body)
1171
+
1172
+ if status_code == 409:
1173
+ raise _exceptions.PromptLayerConflictError(err_msg, response=request_response, body=body)
1174
+
1175
+ if status_code == 422:
1176
+ raise _exceptions.PromptLayerUnprocessableEntityError(err_msg, response=request_response, body=body)
1177
+
1178
+ if status_code == 429:
1179
+ raise _exceptions.PromptLayerRateLimitError(err_msg, response=request_response, body=body)
1180
+
1181
+ if status_code and status_code >= 500:
1182
+ raise _exceptions.PromptLayerInternalServerError(err_msg, response=request_response, body=body)
1183
+
1184
+ raise _exceptions.PromptLayerAPIStatusError(err_msg, response=request_response, body=body)
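
raise_on_bad_response now maps HTTP status codes onto the typed exception hierarchy in promptlayer.exceptions instead of raising bare Exception. A hedged sketch of catching these from a caller such as get_prompt_template (signature taken from this diff; the template name is invented):

from promptlayer import exceptions as pl_exceptions
from promptlayer.utils import get_prompt_template

try:
    template = get_prompt_template(
        api_key="<promptlayer api key>",
        base_url="https://api.promptlayer.com",
        throw_on_error=True,
        prompt_name="my-prompt",  # hypothetical template name
    )
except pl_exceptions.PromptLayerAuthenticationError:
    raise  # 401: bad API key, retrying will not help
except pl_exceptions.PromptLayerNotFoundError:
    template = None  # 404: no template with that name
except pl_exceptions.PromptLayerAPIStatusError as exc:
    # Raised for status codes without a dedicated class; the response and body are attached at construction.
    print(f"PromptLayer returned an error: {exc}")
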
938
1185
 
939
1186
 
940
1187
  async def async_wrapper(
1188
+ api_key: str,
1189
+ base_url: str,
941
1190
  coroutine_obj,
942
1191
  return_pl_id,
943
1192
  request_start_time,
944
1193
  function_name,
945
1194
  provider_type,
946
1195
  tags,
947
- api_key: str = None,
948
1196
  llm_request_span_id: str = None,
949
1197
  tracer=None,
950
1198
  *args,
@@ -957,6 +1205,8 @@ async def async_wrapper(
957
1205
  response = await coroutine_obj
958
1206
  request_end_time = datetime.datetime.now().timestamp()
959
1207
  result = await promptlayer_api_handler_async(
1208
+ api_key,
1209
+ base_url,
960
1210
  function_name,
961
1211
  provider_type,
962
1212
  args,
@@ -965,7 +1215,6 @@ async def async_wrapper(
965
1215
  response,
966
1216
  request_start_time,
967
1217
  request_end_time,
968
- api_key,
969
1218
  return_pl_id=return_pl_id,
970
1219
  llm_request_span_id=llm_request_span_id,
971
1220
  )
@@ -980,53 +1229,75 @@ async def async_wrapper(
980
1229
  context.detach(token)
981
1230
 
982
1231
 
983
- def promptlayer_create_group(api_key: str = None):
1232
+ @retry_on_api_error
1233
+ def promptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
984
1234
  try:
985
1235
  request_response = requests.post(
986
- f"{URL_API_PROMPTLAYER}/create-group",
1236
+ f"{base_url}/create-group",
987
1237
  json={
988
1238
  "api_key": api_key,
989
1239
  },
990
1240
  )
991
1241
  if request_response.status_code != 200:
992
- warn_on_bad_response(
993
- request_response,
994
- "WARNING: While creating your group PromptLayer had the following error",
995
- )
996
- return False
1242
+ if throw_on_error:
1243
+ raise_on_bad_response(
1244
+ request_response,
1245
+ "While creating your group PromptLayer had the following error",
1246
+ )
1247
+ else:
1248
+ warn_on_bad_response(
1249
+ request_response,
1250
+ "WARNING: While creating your group PromptLayer had the following error",
1251
+ )
1252
+ return False
997
1253
  except requests.exceptions.RequestException as e:
998
- # I'm aiming for a more specific exception catch here
999
- raise Exception(f"PromptLayer had the following error while creating your group: {e}")
1254
+ if throw_on_error:
1255
+ raise _exceptions.PromptLayerAPIConnectionError(
1256
+ f"PromptLayer had the following error while creating your group: {e}", response=None, body=None
1257
+ ) from e
1258
+ logger.warning(f"While creating your group PromptLayer had the following error: {e}")
1259
+ return False
1000
1260
  return request_response.json()["id"]
1001
1261
 
1002
1262
 
1003
- async def apromptlayer_create_group(api_key: Optional[str] = None) -> str:
1263
+ @retry_on_api_error
1264
+ async def apromptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
1004
1265
  try:
1005
1266
  async with _make_httpx_client() as client:
1006
1267
  response = await client.post(
1007
- f"{URL_API_PROMPTLAYER}/create-group",
1268
+ f"{base_url}/create-group",
1008
1269
  json={
1009
1270
  "api_key": api_key,
1010
1271
  },
1011
1272
  )
1012
1273
 
1013
- if RAISE_FOR_STATUS:
1014
- response.raise_for_status()
1015
- elif response.status_code != 200:
1016
- warn_on_bad_response(
1017
- response,
1018
- "WARNING: While creating your group, PromptLayer had the following error",
1019
- )
1020
- return False
1274
+ if response.status_code != 200:
1275
+ if throw_on_error:
1276
+ raise_on_bad_response(
1277
+ response,
1278
+ "While creating your group, PromptLayer had the following error",
1279
+ )
1280
+ else:
1281
+ warn_on_bad_response(
1282
+ response,
1283
+ "WARNING: While creating your group, PromptLayer had the following error",
1284
+ )
1285
+ return False
1021
1286
  return response.json()["id"]
1022
1287
  except httpx.RequestError as e:
1023
- raise Exception(f"PromptLayer had the following error while creating your group: {str(e)}") from e
1288
+ if throw_on_error:
1289
+ raise _exceptions.PromptLayerAPIConnectionError(
1290
+ f"PromptLayer had the following error while creating your group: {str(e)}", response=None, body=None
1291
+ ) from e
1292
+ logger.warning(f"While creating your group PromptLayer had the following error: {e}")
1293
+ return False
1024
1294
 
1025
1295
 
1026
- def promptlayer_track_group(request_id, group_id, api_key: str = None):
1296
+ @retry_on_api_error
1297
+ def promptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
1027
1298
  try:
1028
1299
  request_response = requests.post(
1029
- f"{URL_API_PROMPTLAYER}/track-group",
1300
+ f"{base_url}/track-group",
1030
1301
  json={
1031
1302
  "api_key": api_key,
1032
1303
  "request_id": request_id,
@@ -1034,18 +1305,29 @@ def promptlayer_track_group(request_id, group_id, api_key: str = None):
1034
1305
  },
1035
1306
  )
1036
1307
  if request_response.status_code != 200:
1037
- warn_on_bad_response(
1038
- request_response,
1039
- "WARNING: While tracking your group PromptLayer had the following error",
1040
- )
1041
- return False
1308
+ if throw_on_error:
1309
+ raise_on_bad_response(
1310
+ request_response,
1311
+ "While tracking your group PromptLayer had the following error",
1312
+ )
1313
+ else:
1314
+ warn_on_bad_response(
1315
+ request_response,
1316
+ "WARNING: While tracking your group PromptLayer had the following error",
1317
+ )
1318
+ return False
1042
1319
  except requests.exceptions.RequestException as e:
1043
- # I'm aiming for a more specific exception catch here
1044
- raise Exception(f"PromptLayer had the following error while tracking your group: {e}")
1320
+ if throw_on_error:
1321
+ raise _exceptions.PromptLayerAPIConnectionError(
1322
+ f"PromptLayer had the following error while tracking your group: {e}", response=None, body=None
1323
+ ) from e
1324
+ logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
1325
+ return False
1045
1326
  return True
1046
1327
 
1047
1328
 
1048
- async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
1329
+ @retry_on_api_error
1330
+ async def apromptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
1049
1331
  try:
1050
1332
  payload = {
1051
1333
  "api_key": api_key,
@@ -1054,59 +1336,86 @@ async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
1054
1336
  }
1055
1337
  async with _make_httpx_client() as client:
1056
1338
  response = await client.post(
1057
- f"{URL_API_PROMPTLAYER}/track-group",
1339
+ f"{base_url}/track-group",
1058
1340
  headers={"X-API-KEY": api_key},
1059
1341
  json=payload,
1060
1342
  )
1061
1343
 
1062
- if RAISE_FOR_STATUS:
1063
- response.raise_for_status()
1064
- elif response.status_code != 200:
1065
- warn_on_bad_response(
1066
- response,
1067
- "WARNING: While tracking your group, PromptLayer had the following error",
1068
- )
1069
- return False
1344
+ if response.status_code != 200:
1345
+ if throw_on_error:
1346
+ raise_on_bad_response(
1347
+ response,
1348
+ "While tracking your group, PromptLayer had the following error",
1349
+ )
1350
+ else:
1351
+ warn_on_bad_response(
1352
+ response,
1353
+ "WARNING: While tracking your group, PromptLayer had the following error",
1354
+ )
1355
+ return False
1070
1356
  except httpx.RequestError as e:
1071
- print(
1072
- f"WARNING: While tracking your group PromptLayer had the following error: {e}",
1073
- file=sys.stderr,
1074
- )
1357
+ if throw_on_error:
1358
+ raise _exceptions.PromptLayerAPIConnectionError(
1359
+ f"PromptLayer had the following error while tracking your group: {str(e)}", response=None, body=None
1360
+ ) from e
1361
+ logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
1075
1362
  return False
1076
1363
 
1077
1364
  return True
1078
1365
 
1079
1366
 
1367
+ @retry_on_api_error
1080
1368
  def get_prompt_template(
1081
- prompt_name: str, params: Union[GetPromptTemplate, None] = None, api_key: str = None
1369
+ api_key: str, base_url: str, throw_on_error: bool, prompt_name: str, params: Union[GetPromptTemplate, None] = None
1082
1370
  ) -> GetPromptTemplateResponse:
1083
1371
  try:
1084
1372
  json_body = {"api_key": api_key}
1085
1373
  if params:
1086
1374
  json_body = {**json_body, **params}
1087
1375
  response = requests.post(
1088
- f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
1376
+ f"{base_url}/prompt-templates/{prompt_name}",
1089
1377
  headers={"X-API-KEY": api_key},
1090
1378
  json=json_body,
1091
1379
  )
1092
1380
  if response.status_code != 200:
1093
- raise Exception(f"PromptLayer had the following error while getting your prompt template: {response.text}")
1381
+ if throw_on_error:
1382
+ raise_on_bad_response(
1383
+ response, "PromptLayer had the following error while getting your prompt template"
1384
+ )
1385
+ else:
1386
+ warn_on_bad_response(
1387
+ response, "WARNING: PromptLayer had the following error while getting your prompt template"
1388
+ )
1389
+ return None
1094
1390
 
1095
- warning = response.json().get("warning", None)
1096
- if warning is not None:
1097
- warn_on_bad_response(
1098
- warning,
1099
- "WARNING: While getting your prompt template",
1100
- )
1101
1391
  return response.json()
1392
+ except requests.exceptions.ConnectionError as e:
1393
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1394
+ if throw_on_error:
1395
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1396
+ logger.warning(err_msg)
1397
+ return None
1398
+ except requests.exceptions.Timeout as e:
1399
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1400
+ if throw_on_error:
1401
+ raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
1402
+ logger.warning(err_msg)
1403
+ return None
1102
1404
  except requests.exceptions.RequestException as e:
1103
- raise Exception(f"PromptLayer had the following error while getting your prompt template: {e}")
1405
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1406
+ if throw_on_error:
1407
+ raise _exceptions.PromptLayerError(err_msg, response=None, body=None) from e
1408
+ logger.warning(err_msg)
1409
+ return None
1104
1410
 
1105
1411
 
1412
+ @retry_on_api_error
1106
1413
  async def aget_prompt_template(
1414
+ api_key: str,
1415
+ base_url: str,
1416
+ throw_on_error: bool,
1107
1417
  prompt_name: str,
1108
1418
  params: Union[GetPromptTemplate, None] = None,
1109
- api_key: str = None,
1110
1419
  ) -> GetPromptTemplateResponse:
1111
1420
  try:
1112
1421
  json_body = {"api_key": api_key}
@@ -1114,36 +1423,53 @@ async def aget_prompt_template(
1114
1423
  json_body.update(params)
1115
1424
  async with _make_httpx_client() as client:
1116
1425
  response = await client.post(
1117
- f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
1426
+ f"{base_url}/prompt-templates/{prompt_name}",
1118
1427
  headers={"X-API-KEY": api_key},
1119
1428
  json=json_body,
1120
1429
  )
1121
1430
 
1122
- if RAISE_FOR_STATUS:
1123
- response.raise_for_status()
1124
- elif response.status_code != 200:
1125
- raise_on_bad_response(
1126
- response,
1127
- "PromptLayer had the following error while getting your prompt template",
1128
- )
1129
- warning = response.json().get("warning", None)
1130
- if warning:
1131
- warn_on_bad_response(
1132
- warning,
1133
- "WARNING: While getting your prompt template",
1134
- )
1431
+ if response.status_code != 200:
1432
+ if throw_on_error:
1433
+ raise_on_bad_response(
1434
+ response,
1435
+ "PromptLayer had the following error while getting your prompt template",
1436
+ )
1437
+ else:
1438
+ warn_on_bad_response(
1439
+ response, "WARNING: While getting your prompt template PromptLayer had the following error"
1440
+ )
1441
+ return None
1135
1442
  return response.json()
1443
+ except (httpx.ConnectError, httpx.NetworkError) as e:
1444
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1445
+ if throw_on_error:
1446
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1447
+ logger.warning(err_msg)
1448
+ return None
1449
+ except httpx.TimeoutException as e:
1450
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1451
+ if throw_on_error:
1452
+ raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
1453
+ logger.warning(err_msg)
1454
+ return None
1136
1455
  except httpx.RequestError as e:
1137
- raise Exception(f"PromptLayer had the following error while getting your prompt template: {str(e)}") from e
1456
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1457
+ if throw_on_error:
1458
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1459
+ logger.warning(err_msg)
1460
+ return None
1138
1461
 
1139
1462
 
1463
+ @retry_on_api_error
1140
1464
  def publish_prompt_template(
1465
+ api_key: str,
1466
+ base_url: str,
1467
+ throw_on_error: bool,
1141
1468
  body: PublishPromptTemplate,
1142
- api_key: str = None,
1143
1469
  ) -> PublishPromptTemplateResponse:
1144
1470
  try:
1145
1471
  response = requests.post(
1146
- f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
1472
+ f"{base_url}/rest/prompt-templates",
1147
1473
  headers={"X-API-KEY": api_key},
1148
1474
  json={
1149
1475
  "prompt_template": {**body},
@@ -1152,22 +1478,38 @@ def publish_prompt_template(
1152
1478
  },
1153
1479
  )
1154
1480
  if response.status_code == 400:
1155
- raise Exception(
1156
- f"PromptLayer had the following error while publishing your prompt template: {response.text}"
1157
- )
1481
+ if throw_on_error:
1482
+ raise_on_bad_response(
1483
+ response, "PromptLayer had the following error while publishing your prompt template"
1484
+ )
1485
+ else:
1486
+ warn_on_bad_response(
1487
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1488
+ )
1489
+ return None
1158
1490
  return response.json()
1159
1491
  except requests.exceptions.RequestException as e:
1160
- raise Exception(f"PromptLayer had the following error while publishing your prompt template: {e}")
1492
+ if throw_on_error:
1493
+ raise _exceptions.PromptLayerAPIConnectionError(
1494
+ f"PromptLayer had the following error while publishing your prompt template: {e}",
1495
+ response=None,
1496
+ body=None,
1497
+ ) from e
1498
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1499
+ return None
1161
1500
 
1162
1501
 
1502
+ @retry_on_api_error
1163
1503
  async def apublish_prompt_template(
1504
+ api_key: str,
1505
+ base_url: str,
1506
+ throw_on_error: bool,
1164
1507
  body: PublishPromptTemplate,
1165
- api_key: str = None,
1166
1508
  ) -> PublishPromptTemplateResponse:
1167
1509
  try:
1168
1510
  async with _make_httpx_client() as client:
1169
1511
  response = await client.post(
1170
- f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
1512
+ f"{base_url}/rest/prompt-templates",
1171
1513
  headers={"X-API-KEY": api_key},
1172
1514
  json={
1173
1515
  "prompt_template": {**body},
@@ -1176,46 +1518,68 @@ async def apublish_prompt_template(
1176
1518
  },
1177
1519
  )
1178
1520
 
1179
- if RAISE_FOR_STATUS:
1180
- response.raise_for_status()
1181
- elif response.status_code == 400:
1182
- raise Exception(
1183
- f"PromptLayer had the following error while publishing your prompt template: {response.text}"
1184
- )
1185
- if response.status_code != 201:
1186
- raise_on_bad_response(
1187
- response,
1188
- "PromptLayer had the following error while publishing your prompt template",
1189
- )
1521
+ if response.status_code == 400 or response.status_code != 201:
1522
+ if throw_on_error:
1523
+ raise_on_bad_response(
1524
+ response,
1525
+ "PromptLayer had the following error while publishing your prompt template",
1526
+ )
1527
+ else:
1528
+ warn_on_bad_response(
1529
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1530
+ )
1531
+ return None
1190
1532
  return response.json()
1191
1533
  except httpx.RequestError as e:
1192
- raise Exception(f"PromptLayer had the following error while publishing your prompt template: {str(e)}") from e
1534
+ if throw_on_error:
1535
+ raise _exceptions.PromptLayerAPIConnectionError(
1536
+ f"PromptLayer had the following error while publishing your prompt template: {str(e)}",
1537
+ response=None,
1538
+ body=None,
1539
+ ) from e
1540
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1541
+ return None
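
The async variant follows the same pattern, so with throw_on_error=False a failed publish is logged and None is returned. A sketch under the same placeholder assumptions:

import asyncio

from promptlayer.utils import apublish_prompt_template

async def publish_quietly(api_key: str) -> None:
    result = await apublish_prompt_template(
        api_key,
        "https://api.promptlayer.com",
        False,                             # throw_on_error off: warn and return None
        body={"prompt_name": "greeting"},  # placeholder payload
    )
    if result is None:
        print("publish failed; details were logged as a warning")

# asyncio.run(publish_quietly("pl_..."))  # illustrative invocation
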
1193
1542
 
1194
1543
 
1544
+ @retry_on_api_error
1195
1545
  def get_all_prompt_templates(
1196
- page: int = 1, per_page: int = 30, api_key: str = None, label: str = None
1546
+ api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
1197
1547
  ) -> List[ListPromptTemplateResponse]:
1198
1548
  try:
1199
1549
  params = {"page": page, "per_page": per_page}
1200
1550
  if label:
1201
1551
  params["label"] = label
1202
1552
  response = requests.get(
1203
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1553
+ f"{base_url}/prompt-templates",
1204
1554
  headers={"X-API-KEY": api_key},
1205
1555
  params=params,
1206
1556
  )
1207
1557
  if response.status_code != 200:
1208
- raise Exception(
1209
- f"PromptLayer had the following error while getting all your prompt templates: {response.text}"
1210
- )
1558
+ if throw_on_error:
1559
+ raise_on_bad_response(
1560
+ response, "PromptLayer had the following error while getting all your prompt templates"
1561
+ )
1562
+ else:
1563
+ warn_on_bad_response(
1564
+ response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
1565
+ )
1566
+ return []
1211
1567
  items = response.json().get("items", [])
1212
1568
  return items
1213
1569
  except requests.exceptions.RequestException as e:
1214
- raise Exception(f"PromptLayer had the following error while getting all your prompt templates: {e}")
1570
+ if throw_on_error:
1571
+ raise _exceptions.PromptLayerAPIConnectionError(
1572
+ f"PromptLayer had the following error while getting all your prompt templates: {e}",
1573
+ response=None,
1574
+ body=None,
1575
+ ) from e
1576
+ logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
1577
+ return []
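
get_all_prompt_templates keeps its page, per_page and label keywords but now requires api_key, base_url and throw_on_error up front, and it degrades to an empty list instead of raising when throw_on_error is False. A pagination sketch under those assumptions:

from promptlayer.utils import get_all_prompt_templates

def iter_templates(api_key: str, base_url: str = "https://api.promptlayer.com"):
    page = 1
    while True:
        items = get_all_prompt_templates(
            api_key, base_url, False,  # throw_on_error=False: [] on failure
            page=page, per_page=30,
        )
        if not items:
            break  # either past the last page or a soft failure returned []
        yield from items
        page += 1
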
1215
1578
 
1216
1579
 
1580
+ @retry_on_api_error
1217
1581
  async def aget_all_prompt_templates(
1218
- page: int = 1, per_page: int = 30, api_key: str = None, label: str = None
1582
+ api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
1219
1583
  ) -> List[ListPromptTemplateResponse]:
1220
1584
  try:
1221
1585
  params = {"page": page, "per_page": per_page}
@@ -1223,22 +1587,33 @@ async def aget_all_prompt_templates(
1223
1587
  params["label"] = label
1224
1588
  async with _make_httpx_client() as client:
1225
1589
  response = await client.get(
1226
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1590
+ f"{base_url}/prompt-templates",
1227
1591
  headers={"X-API-KEY": api_key},
1228
1592
  params=params,
1229
1593
  )
1230
1594
 
1231
- if RAISE_FOR_STATUS:
1232
- response.raise_for_status()
1233
- elif response.status_code != 200:
1234
- raise_on_bad_response(
1235
- response,
1236
- "PromptLayer had the following error while getting all your prompt templates",
1237
- )
1595
+ if response.status_code != 200:
1596
+ if throw_on_error:
1597
+ raise_on_bad_response(
1598
+ response,
1599
+ "PromptLayer had the following error while getting all your prompt templates",
1600
+ )
1601
+ else:
1602
+ warn_on_bad_response(
1603
+ response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
1604
+ )
1605
+ return []
1238
1606
  items = response.json().get("items", [])
1239
1607
  return items
1240
1608
  except httpx.RequestError as e:
1241
- raise Exception(f"PromptLayer had the following error while getting all your prompt templates: {str(e)}") from e
1609
+ if throw_on_error:
1610
+ raise _exceptions.PromptLayerAPIConnectionError(
1611
+ f"PromptLayer had the following error while getting all your prompt templates: {str(e)}",
1612
+ response=None,
1613
+ body=None,
1614
+ ) from e
1615
+ logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
1616
+ return []
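
A matching async sketch using the label filter; the label value and the prompt_name field are illustrative:

import asyncio

from promptlayer.utils import aget_all_prompt_templates

async def list_prod_templates(api_key: str):
    items = await aget_all_prompt_templates(
        api_key,
        "https://api.promptlayer.com",
        True,          # throw_on_error: propagate typed exceptions
        label="prod",  # illustrative release label
    )
    return [item.get("prompt_name") for item in items]  # field name assumed

# asyncio.run(list_prod_templates("pl_..."))
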
1242
1617
 
1243
1618
 
1244
1619
  def openai_chat_request(client, **kwargs):
@@ -1259,7 +1634,7 @@ def openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: d
1259
1634
  from openai import OpenAI
1260
1635
 
1261
1636
  client = OpenAI(**client_kwargs)
1262
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1637
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1263
1638
 
1264
1639
  if api_type == "chat-completions":
1265
1640
  request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1286,7 +1661,7 @@ async def aopenai_request(prompt_blueprint: GetPromptTemplateResponse, client_kw
1286
1661
  from openai import AsyncOpenAI
1287
1662
 
1288
1663
  client = AsyncOpenAI(**client_kwargs)
1289
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1664
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1290
1665
 
1291
1666
  if api_type == "chat-completions":
1292
1667
  request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1299,7 +1674,7 @@ def azure_openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwa
1299
1674
  from openai import AzureOpenAI
1300
1675
 
1301
1676
  client = AzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1302
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1677
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1303
1678
 
1304
1679
  if api_type == "chat-completions":
1305
1680
  request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1314,7 +1689,7 @@ async def aazure_openai_request(
1314
1689
  from openai import AsyncAzureOpenAI
1315
1690
 
1316
1691
  client = AsyncAzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1317
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1692
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1318
1693
 
1319
1694
  if api_type == "chat-completions":
1320
1695
  request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
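
All four OpenAI and Azure request helpers now read api_type with a default instead of indexing it, so older prompt blueprints whose model metadata lacks the key fall back to chat completions rather than raising KeyError. A toy illustration:

# Blueprint without an explicit api_type, as older templates may be stored.
prompt_blueprint = {"metadata": {"model": {"name": "gpt-4o", "parameters": {}}}}

api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
assert api_type == "chat-completions"  # previously this lookup raised KeyError
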
@@ -1372,54 +1747,66 @@ def get_api_key():
1372
1747
  # raise an error if the api key is not set
1373
1748
  api_key = os.environ.get("PROMPTLAYER_API_KEY")
1374
1749
  if not api_key:
1375
- raise Exception(
1376
- "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>' "
1750
+ raise _exceptions.PromptLayerAuthenticationError(
1751
+ "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>'",
1752
+ response=None,
1753
+ body=None,
1377
1754
  )
1378
1755
  return api_key
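
A missing key is now reported as PromptLayerAuthenticationError rather than a bare Exception, so callers can catch the configuration problem specifically. A short sketch, assuming the module path promptlayer.utils:

import os

from promptlayer import exceptions as _exceptions
from promptlayer.utils import get_api_key

os.environ.pop("PROMPTLAYER_API_KEY", None)  # simulate a missing key
try:
    get_api_key()
except _exceptions.PromptLayerAuthenticationError as exc:
    print(f"PromptLayer is not configured: {exc}")
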
1379
1756
 
1380
1757
 
1381
- def util_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
1758
+ @retry_on_api_error
1759
+ def util_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
1382
1760
  try:
1383
1761
  response = requests.post(
1384
- f"{URL_API_PROMPTLAYER}/log-request",
1762
+ f"{base_url}/log-request",
1385
1763
  headers={"X-API-KEY": api_key},
1386
1764
  json=kwargs,
1387
1765
  )
1388
1766
  if response.status_code != 201:
1389
- warn_on_bad_response(
1390
- response,
1391
- "WARNING: While logging your request PromptLayer had the following error",
1392
- )
1393
- return None
1767
+ if throw_on_error:
1768
+ raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
1769
+ else:
1770
+ warn_on_bad_response(
1771
+ response,
1772
+ "WARNING: While logging your request PromptLayer had the following error",
1773
+ )
1774
+ return None
1394
1775
  return response.json()
1395
1776
  except Exception as e:
1396
- print(
1397
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
1398
- file=sys.stderr,
1399
- )
1777
+ if throw_on_error:
1778
+ raise _exceptions.PromptLayerAPIError(
1779
+ f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
1780
+ ) from e
1781
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
1400
1782
  return None
1401
1783
 
1402
1784
 
1403
- async def autil_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
1785
+ @retry_on_api_error
1786
+ async def autil_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
1404
1787
  try:
1405
1788
  async with _make_httpx_client() as client:
1406
1789
  response = await client.post(
1407
- f"{URL_API_PROMPTLAYER}/log-request",
1790
+ f"{base_url}/log-request",
1408
1791
  headers={"X-API-KEY": api_key},
1409
1792
  json=kwargs,
1410
1793
  )
1411
1794
  if response.status_code != 201:
1412
- warn_on_bad_response(
1413
- response,
1414
- "WARNING: While logging your request PromptLayer had the following error",
1415
- )
1416
- return None
1795
+ if throw_on_error:
1796
+ raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
1797
+ else:
1798
+ warn_on_bad_response(
1799
+ response,
1800
+ "WARNING: While logging your request PromptLayer had the following error",
1801
+ )
1802
+ return None
1417
1803
  return response.json()
1418
1804
  except Exception as e:
1419
- print(
1420
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
1421
- file=sys.stderr,
1422
- )
1805
+ if throw_on_error:
1806
+ raise _exceptions.PromptLayerAPIError(
1807
+ f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
1808
+ ) from e
1809
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
1423
1810
  return None
1424
1811
 
1425
1812
 
@@ -1448,6 +1835,26 @@ async def amistral_request(
1448
1835
  return await client.chat.complete_async(**function_kwargs)
1449
1836
 
1450
1837
 
1838
+ class _GoogleStreamWrapper:
1839
+ """Wrapper to keep Google client alive during streaming."""
1840
+
1841
+ def __init__(self, stream_generator, client):
1842
+ self._stream = stream_generator
1843
+ self._client = client # Keep client alive
1844
+
1845
+ def __iter__(self):
1846
+ return self._stream.__iter__()
1847
+
1848
+ def __next__(self):
1849
+ return next(self._stream)
1850
+
1851
+ def __aiter__(self):
1852
+ return self._stream.__aiter__()
1853
+
1854
+ async def __anext__(self):
1855
+ return await self._stream.__anext__()
1856
+
1857
+
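
The wrapper only delegates iteration while holding a reference to the client, so the client constructed inside the google_* helpers below cannot be garbage-collected while a caller is still consuming the stream. A toy demonstration with stand-in objects, using _GoogleStreamWrapper as defined above:

class _FakeStream:
    def __iter__(self):
        return iter(["hel", "lo"])

class _FakeClient:
    pass

wrapped = _GoogleStreamWrapper(_FakeStream(), _FakeClient())
print("".join(wrapped))  # "hello"; the client stays referenced for the stream's lifetime
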
1451
1858
  def google_chat_request(client, **kwargs):
1452
1859
  from google.genai.chats import Content
1453
1860
 
@@ -1458,7 +1865,8 @@ def google_chat_request(client, **kwargs):
1458
1865
  chat = client.chats.create(model=model, history=history, config=generation_config)
1459
1866
  last_message = history[-1].parts if history else ""
1460
1867
  if stream:
1461
- return chat.send_message_stream(message=last_message)
1868
+ stream_gen = chat.send_message_stream(message=last_message)
1869
+ return _GoogleStreamWrapper(stream_gen, client)
1462
1870
  return chat.send_message(message=last_message)
1463
1871
 
1464
1872
 
@@ -1468,7 +1876,8 @@ def google_completions_request(client, **kwargs):
1468
1876
  contents = kwargs.get("contents", [])
1469
1877
  stream = kwargs.pop("stream", False)
1470
1878
  if stream:
1471
- return client.models.generate_content_stream(model=model, contents=contents, config=config)
1879
+ stream_gen = client.models.generate_content_stream(model=model, contents=contents, config=config)
1880
+ return _GoogleStreamWrapper(stream_gen, client)
1472
1881
  return client.models.generate_content(model=model, contents=contents, config=config)
1473
1882
 
1474
1883
 
@@ -1503,7 +1912,8 @@ async def agoogle_chat_request(client, **kwargs):
1503
1912
  chat = client.aio.chats.create(model=model, history=history, config=generation_config)
1504
1913
  last_message = history[-1].parts[0] if history else ""
1505
1914
  if stream:
1506
- return await chat.send_message_stream(message=last_message)
1915
+ stream_gen = await chat.send_message_stream(message=last_message)
1916
+ return _GoogleStreamWrapper(stream_gen, client)
1507
1917
  return await chat.send_message(message=last_message)
1508
1918
 
1509
1919
 
@@ -1513,8 +1923,9 @@ async def agoogle_completions_request(client, **kwargs):
1513
1923
  contents = kwargs.get("contents", [])
1514
1924
  stream = kwargs.pop("stream", False)
1515
1925
  if stream:
1516
- return await client.aio.models.generate_content_stream(model=model, contents=contents, config=config)
1517
- return await client.aio.models.generate_content(model=model, contents=contents, config=config)
1926
+ stream_gen = await client.aio.models.generate_content_stream(model=model, contents=contents, config=config)
1927
+ return _GoogleStreamWrapper(stream_gen, client)
1928
+ return await client.aio.models.generate_content(model=model, contents=contents, config=config)
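
With stream=True these helpers now return the wrapper instead of the raw generator, and consumption is unchanged. A hedged async consumer sketch; collect_text, the client argument and the keyword values are placeholders for whatever the caller already passes to agoogle_completions_request:

async def collect_text(client, **kwargs):
    # kwargs carries the same model/contents/config keywords shown above.
    stream = await agoogle_completions_request(client, stream=True, **kwargs)
    chunks = []
    async for chunk in stream:  # delegated to the wrapped async generator
        chunks.append(getattr(chunk, "text", "") or "")
    return "".join(chunks)
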
1518
1929
 
1519
1930
 
1520
1931
  AMAP_TYPE_TO_GOOGLE_FUNCTION = {