promptlayer 1.0.35__py3-none-any.whl → 1.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptlayer/utils.py CHANGED
@@ -3,29 +3,38 @@ import contextvars
  import datetime
  import functools
  import json
+ import logging
  import os
- import sys
  import types
+ from contextlib import asynccontextmanager
  from copy import deepcopy
  from enum import Enum
- from typing import (
-     Any,
-     AsyncGenerator,
-     AsyncIterable,
-     Callable,
-     Dict,
-     Generator,
-     List,
-     Optional,
-     Union,
- )
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+ from urllib.parse import quote
+ from uuid import uuid4

  import httpx
  import requests
+ import urllib3
+ import urllib3.util
  from ably import AblyRealtime
  from ably.types.message import Message
+ from centrifuge import (
+     Client,
+     PublicationContext,
+     SubscriptionEventHandler,
+     SubscriptionState,
+ )
  from opentelemetry import context, trace
+ from tenacity import (
+     before_sleep_log,
+     retry,
+     retry_if_exception,
+     stop_after_attempt,
+     wait_exponential,
+ )

+ from promptlayer import exceptions as _exceptions
  from promptlayer.types import RequestLog
  from promptlayer.types.prompt_template import (
      GetPromptTemplate,
@@ -35,112 +44,341 @@ from promptlayer.types.prompt_template import (
      PublishPromptTemplateResponse,
  )

- URL_API_PROMPTLAYER = os.environ.setdefault(
-     "URL_API_PROMPTLAYER", "https://api.promptlayer.com"
+ # Configuration
+ RERAISE_ORIGINAL_EXCEPTION = os.getenv("PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION", "False").lower() == "true"
+ RAISE_FOR_STATUS = os.getenv("PROMPTLAYER_RAISE_FOR_STATUS", "False").lower() == "true"
+ DEFAULT_HTTP_TIMEOUT = 5
+
+ WORKFLOW_RUN_URL_TEMPLATE = "{base_url}/workflows/{workflow_id}/run"
+ WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE = "workflows:{workflow_id}:run:{channel_name_suffix}"
+ SET_WORKFLOW_COMPLETE_MESSAGE = "SET_WORKFLOW_COMPLETE"
+ WS_TOKEN_REQUEST_LIBRARY_URL = (
+     f"{os.getenv('PROMPTLAYER_BASE_URL', 'https://api.promptlayer.com')}/ws-token-request-library"
  )


- async def arun_workflow_request(
+ logger = logging.getLogger(__name__)
+
+
+ class FinalOutputCode(Enum):
+     OK = "OK"
+     EXCEEDS_SIZE_LIMIT = "EXCEEDS_SIZE_LIMIT"
+
+
+ def should_retry_error(exception):
+     """Check if an exception should trigger a retry.
+
+     Only retries on server errors (5xx) and rate limits (429).
+     """
+     if hasattr(exception, "response"):
+         response = exception.response
+         if hasattr(response, "status_code"):
+             status_code = response.status_code
+             if status_code >= 500 or status_code == 429:
+                 return True
+
+     if isinstance(exception, (_exceptions.PromptLayerInternalServerError, _exceptions.PromptLayerRateLimitError)):
+         return True
+
+     return False
+
+
+ def retry_on_api_error(func):
+     return retry(
+         retry=retry_if_exception(should_retry_error),
+         stop=stop_after_attempt(4),  # 4 total attempts (1 initial + 3 retries)
+         wait=wait_exponential(multiplier=2, max=15),  # 2s, 4s, 8s
+         before_sleep=before_sleep_log(logger, logging.WARNING),
+         reraise=True,
+     )(func)
+
+
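
The `retry_on_api_error` decorator above is a thin tenacity wrapper: a decorated call makes up to four attempts with exponential backoff (roughly 2s, 4s, 8s) whenever `should_retry_error` matches, and re-raises the last exception. A minimal sketch of the behavior; `fetch` and the `_Fake*` classes are hypothetical stand-ins, not part of the package:

    import logging

    logging.basicConfig(level=logging.WARNING)

    class _FakeResponse:
        def __init__(self, status_code):
            self.status_code = status_code

    class _FakeHTTPError(Exception):
        # Hypothetical error carrying a `.response`, the shape should_retry_error inspects.
        def __init__(self, status_code):
            super().__init__(f"HTTP {status_code}")
            self.response = _FakeResponse(status_code)

    attempts = {"n": 0}

    @retry_on_api_error
    def fetch():
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise _FakeHTTPError(503)  # 5xx: should_retry_error returns True, so tenacity retries
        return "ok"

    print(fetch())  # sleeps ~2s then ~4s between attempts; a 404 would re-raise immediately
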
+ def _get_http_timeout():
+     try:
+         return float(os.getenv("PROMPTLAYER_HTTP_TIMEOUT", DEFAULT_HTTP_TIMEOUT))
+     except (ValueError, TypeError):
+         return DEFAULT_HTTP_TIMEOUT
+
+
+ def _make_httpx_client():
+     return httpx.AsyncClient(timeout=_get_http_timeout())
+
+
+ def _make_simple_httpx_client():
+     return httpx.Client(timeout=_get_http_timeout())
+
+
+ def _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name):
+     # This is backward compatibility code
+     if (workflow_id_or_name := workflow_name if workflow_id_or_name is None else workflow_id_or_name) is None:
+         raise ValueError('Either "workflow_id_or_name" or "workflow_name" must be provided')
+
+     return workflow_id_or_name
+
+
+ async def _get_final_output(
+     base_url: str, execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]
+ ) -> Dict[str, Any]:
+     async with httpx.AsyncClient() as client:
+         response = await client.get(
+             f"{base_url}/workflow-version-execution-results",
+             headers=headers,
+             params={"workflow_version_execution_id": execution_id, "return_all_outputs": return_all_outputs},
+         )
+         if response.status_code != 200:
+             raise_on_bad_response(response, "PromptLayer had the following error while getting workflow results")
+         return response.json()
+
+
+ # TODO(dmu) MEDIUM: Consider putting all these functions into a class, so we do not have to pass
+ # `authorization_headers` into each function
+ async def _resolve_workflow_id(base_url: str, workflow_id_or_name: Union[int, str], headers):
+     if isinstance(workflow_id_or_name, int):
+         return workflow_id_or_name
+
+     # TODO(dmu) LOW: Should we warn user here to avoid using workflow names in favor of workflow id?
+     async with _make_httpx_client() as client:
+         # TODO(dmu) MEDIUM: Generalize the way we make async calls to PromptLayer API and reuse it everywhere
+         response = await client.get(f"{base_url}/workflows/{workflow_id_or_name}", headers=headers)
+         if response.status_code != 200:
+             raise_on_bad_response(response, "PromptLayer had the following error while resolving workflow")
+
+         return response.json()["workflow"]["id"]
+
+
+ async def _get_ably_token(base_url: str, channel_name, authentication_headers):
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/ws-token-request-library",
+                 headers=authentication_headers,
+                 params={"capability": channel_name},
+             )
+             if response.status_code != 201:
+                 raise_on_bad_response(
+                     response,
+                     "PromptLayer had the following error while getting WebSocket token",
+                 )
+             return response.json()
+     except Exception as ex:
+         error_message = f"Failed to get WebSocket token: {ex}"
+         logger.exception(error_message)
+         if RERAISE_ORIGINAL_EXCEPTION:
+             raise
+         else:
+             raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex
+
+
+ def _make_message_listener(base_url: str, results_future, execution_id_future, return_all_outputs, headers):
+     # We need this function to be mocked by unittests
+     async def message_listener(message: Message):
+         if results_future.cancelled() or message.name != SET_WORKFLOW_COMPLETE_MESSAGE:
+             return  # TODO(dmu) LOW: Do we really need this check?
+
+         execution_id = await asyncio.wait_for(execution_id_future, _get_http_timeout() * 1.1)
+         message_data = json.loads(message.data)
+         if message_data["workflow_version_execution_id"] != execution_id:
+             return
+
+         if (result_code := message_data.get("result_code")) in (FinalOutputCode.OK.value, None):
+             results = message_data["final_output"]
+         elif result_code == FinalOutputCode.EXCEEDS_SIZE_LIMIT.value:
+             results = await _get_final_output(base_url, execution_id, return_all_outputs, headers=headers)
+         else:
+             raise NotImplementedError(f"Unsupported final output code: {result_code}")
+
+         results_future.set_result(results)
+
+     return message_listener
+
+
+ async def _subscribe_to_workflow_completion_channel(
+     base_url: str, channel, execution_id_future, return_all_outputs, headers
+ ):
+     results_future = asyncio.Future()
+     message_listener = _make_message_listener(
+         base_url, results_future, execution_id_future, return_all_outputs, headers
+     )
+     await channel.subscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
+     return results_future, message_listener
+
+
+ async def _post_workflow_id_run(
      *,
-     workflow_name: str,
+     base_url: str,
+     authentication_headers,
+     workflow_id,
      input_variables: Dict[str, Any],
-     metadata: Optional[Dict[str, Any]] = None,
-     workflow_label_name: Optional[str] = None,
-     workflow_version_number: Optional[int] = None,
-     api_key: str,
-     return_all_outputs: Optional[bool] = False,
-     timeout: Optional[int] = 120,
- ) -> Dict[str, Any]:
+     metadata: Dict[str, Any],
+     workflow_label_name: str,
+     workflow_version_number: int,
+     return_all_outputs: bool,
+     channel_name_suffix: str,
+     _url_template: str = WORKFLOW_RUN_URL_TEMPLATE,
+ ):
+     url = _url_template.format(base_url=base_url, workflow_id=workflow_id)
      payload = {
          "input_variables": input_variables,
          "metadata": metadata,
          "workflow_label_name": workflow_label_name,
          "workflow_version_number": workflow_version_number,
          "return_all_outputs": return_all_outputs,
+         "channel_name_suffix": channel_name_suffix,
      }
-
-     url = f"{URL_API_PROMPTLAYER}/workflows/{workflow_name}/run"
-     headers = {"X-API-KEY": api_key}
-
      try:
-         async with httpx.AsyncClient() as client:
-             response = await client.post(url, json=payload, headers=headers)
+         async with _make_httpx_client() as client:
+             response = await client.post(url, json=payload, headers=authentication_headers)
          if response.status_code != 201:
-             raise_on_bad_response(
-                 response,
-                 "PromptLayer had the following error while running your workflow",
-             )
+             raise_on_bad_response(response, "PromptLayer had the following error while running your workflow")

          result = response.json()
-         warning = result.get("warning")
-         if warning:
-             print(f"WARNING: {warning}")
-
-     except Exception as e:
-         error_message = f"Failed to run workflow: {str(e)}"
-         print(error_message)
-         raise Exception(error_message)
+         if warning := result.get("warning"):
+             logger.warning(f"{warning}")
+     except Exception as ex:
+         error_message = f"Failed to run workflow: {str(ex)}"
+         logger.exception(error_message)
+         if RERAISE_ORIGINAL_EXCEPTION:
+             raise
+         else:
+             raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex

-     execution_id = result.get("workflow_version_execution_id")
-     if not execution_id:
-         raise Exception("No execution ID returned from workflow run")
+     return result.get("workflow_version_execution_id")

-     channel_name = f"workflow_updates:{execution_id}"

-     # Get WebSocket token
+ async def _wait_for_workflow_completion(channel, results_future, message_listener, timeout):
+     # We need this function for mocking in unittests
      try:
-         async with httpx.AsyncClient() as client:
-             ws_response = await client.post(
-                 f"{URL_API_PROMPTLAYER}/ws-token-request-library",
-                 headers=headers,
-                 params={"capability": channel_name},
-             )
-             if ws_response.status_code != 201:
-                 raise_on_bad_response(
-                     ws_response,
-                     "PromptLayer had the following error while getting WebSocket token",
-                 )
-             token_details = ws_response.json()["token_details"]
-     except Exception as e:
-         error_message = f"Failed to get WebSocket token: {e}"
-         print(error_message)
-         raise Exception(error_message)
+         return await asyncio.wait_for(results_future, timeout)
+     except asyncio.TimeoutError:
+         raise _exceptions.PromptLayerAPITimeoutError(
+             "Workflow execution did not complete properly", response=None, body=None
+         )
+     finally:
+         channel.unsubscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)

-     # Initialize Ably client
-     ably_client = AblyRealtime(token=token_details["token"])

-     # Subscribe to the channel named after the execution ID
-     channel = ably_client.channels.get(channel_name)
+ def _make_channel_name_suffix():
+     # We need this function for mocking in unittests
+     return uuid4().hex

-     final_output = {}
-     message_received_event = asyncio.Event()

-     async def message_listener(message: Message):
-         if message.name == "set_workflow_node_output":
-             data = json.loads(message.data)
-             if data.get("status") == "workflow_complete":
-                 final_output.update(data.get("final_output", {}))
-                 message_received_event.set()
+ MessageCallback = Callable[[Message], Coroutine[None, None, None]]
+

-     # Subscribe to the channel
-     await channel.subscribe("set_workflow_node_output", message_listener)
+ class SubscriptionEventLoggerHandler(SubscriptionEventHandler):
+     def __init__(self, callback: MessageCallback):
+         self.callback = callback

-     # Wait for the message or timeout
+     async def on_publication(self, ctx: PublicationContext):
+         message_name = ctx.pub.data.get("message_name", "unknown")
+         data = ctx.pub.data.get("data", "")
+         message = Message(name=message_name, data=data)
+         await self.callback(message)
+
+
+ @asynccontextmanager
+ async def centrifugo_client(address: str, token: str):
+     client = Client(address, token=token)
      try:
-         await asyncio.wait_for(message_received_event.wait(), timeout)
-     except asyncio.TimeoutError:
-         channel.unsubscribe("set_workflow_node_output", message_listener)
-         await ably_client.close()
-         raise Exception("Workflow execution did not complete properly")
+         await client.connect()
+         yield client
+     finally:
+         await client.disconnect()

-     # Unsubscribe from the channel and close the client
-     channel.unsubscribe("set_workflow_node_output", message_listener)
-     await ably_client.close()

-     return final_output
+ @asynccontextmanager
+ async def centrifugo_subscription(client: Client, topic: str, message_listener: MessageCallback):
+     subscription = client.new_subscription(
+         topic,
+         events=SubscriptionEventLoggerHandler(message_listener),
+     )
+     try:
+         await subscription.subscribe()
+         yield
+     finally:
+         if subscription.state == SubscriptionState.SUBSCRIBED:
+             await subscription.unsubscribe()
+
+
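
The two `asynccontextmanager` helpers above guarantee that the Centrifugo connection and subscription are torn down even if the body raises. A minimal usage sketch with placeholder address, token, and topic values:

    async def listen_once(address: str, token: str, topic: str):
        received = asyncio.Future()

        async def on_message(message: Message):
            if not received.done():
                received.set_result(message.data)

        # disconnect() and unsubscribe() run in the managers' finally blocks
        async with centrifugo_client(address, token) as client:
            async with centrifugo_subscription(client, topic, on_message):
                return await asyncio.wait_for(received, timeout=30)
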
+ @retry_on_api_error
+ async def arun_workflow_request(
+     *,
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
+     workflow_id_or_name: Optional[Union[int, str]] = None,
+     input_variables: Dict[str, Any],
+     metadata: Optional[Dict[str, Any]] = None,
+     workflow_label_name: Optional[str] = None,
+     workflow_version_number: Optional[int] = None,
+     return_all_outputs: Optional[bool] = False,
+     timeout: Optional[int] = 3600,
+     # `workflow_name` deprecated, kept for backward compatibility only.
+     workflow_name: Optional[str] = None,
+ ):
+     headers = {"X-API-KEY": api_key}
+     workflow_id = await _resolve_workflow_id(
+         base_url, _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
+     )
+     channel_name_suffix = _make_channel_name_suffix()
+     channel_name = WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE.format(
+         workflow_id=workflow_id, channel_name_suffix=channel_name_suffix
+     )
+     ably_token = await _get_ably_token(base_url, channel_name, headers)
+     token = ably_token["token_details"]["token"]
+
+     execution_id_future = asyncio.Future[int]()
+
+     if ably_token.get("messaging_backend") == "centrifugo":
+         address = urllib3.util.parse_url(base_url)._replace(scheme="wss", path="/connection/websocket").url
+         async with centrifugo_client(address, token) as client:
+             results_future = asyncio.Future[dict[str, Any]]()
+             async with centrifugo_subscription(
+                 client,
+                 channel_name,
+                 _make_message_listener(base_url, results_future, execution_id_future, return_all_outputs, headers),
+             ):
+                 execution_id = await _post_workflow_id_run(
+                     base_url=base_url,
+                     authentication_headers=headers,
+                     workflow_id=workflow_id,
+                     input_variables=input_variables,
+                     metadata=metadata,
+                     workflow_label_name=workflow_label_name,
+                     workflow_version_number=workflow_version_number,
+                     return_all_outputs=return_all_outputs,
+                     channel_name_suffix=channel_name_suffix,
+                 )
+                 execution_id_future.set_result(execution_id)
+                 await asyncio.wait_for(results_future, timeout)
+                 return results_future.result()
+
+     async with AblyRealtime(token=token) as ably_client:
+         # It is crucial to subscribe before running a workflow, otherwise we may miss a completion message
+         channel = ably_client.channels.get(channel_name)
+         results_future, message_listener = await _subscribe_to_workflow_completion_channel(
+             base_url, channel, execution_id_future, return_all_outputs, headers
+         )
+
+         execution_id = await _post_workflow_id_run(
+             base_url=base_url,
+             authentication_headers=headers,
+             workflow_id=workflow_id,
+             input_variables=input_variables,
+             metadata=metadata,
+             workflow_label_name=workflow_label_name,
+             workflow_version_number=workflow_version_number,
+             return_all_outputs=return_all_outputs,
+             channel_name_suffix=channel_name_suffix,
+         )
+         execution_id_future.set_result(execution_id)
+
+         return await _wait_for_workflow_completion(channel, results_future, message_listener, timeout)


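
A hedged sketch of calling the rewritten entry point; the API key and workflow id below are placeholders, and passing an integer `workflow_id_or_name` skips the extra name-resolution request:

    async def main():
        results = await arun_workflow_request(
            api_key="pl_...",  # placeholder
            base_url="https://api.promptlayer.com",
            throw_on_error=True,
            workflow_id_or_name=123,
            input_variables={"topic": "retries"},
            timeout=3600,  # seconds to wait for the completion message
        )
        print(results)

    # asyncio.run(main())
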
  def promptlayer_api_handler(
+     api_key: str,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -149,20 +387,13 @@ def promptlayer_api_handler(
      response,
      request_start_time,
      request_end_time,
-     api_key,
      return_pl_id=False,
      llm_request_span_id=None,
  ):
      if (
          isinstance(response, types.GeneratorType)
          or isinstance(response, types.AsyncGeneratorType)
-         or type(response).__name__
-         in [
-             "Stream",
-             "AsyncStream",
-             "AsyncMessageStreamManager",
-             "MessageStreamManager",
-         ]
+         or type(response).__name__ in ["Stream", "AsyncStream", "AsyncMessageStreamManager", "MessageStreamManager"]
      ):
          return GeneratorProxy(
              generator=response,
@@ -178,9 +409,11 @@ def promptlayer_api_handler(
                  "llm_request_span_id": llm_request_span_id,
              },
              api_key=api_key,
+             base_url=base_url,
          )
      else:
          request_id = promptlayer_api_request(
+             base_url=base_url,
              function_name=function_name,
              provider_type=provider_type,
              args=args,
@@ -199,6 +432,8 @@


  async def promptlayer_api_handler_async(
+     api_key: str,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -207,13 +442,14 @@ async def promptlayer_api_handler_async(
      response,
      request_start_time,
      request_end_time,
-     api_key,
      return_pl_id=False,
      llm_request_span_id=None,
  ):
      return await run_in_thread_async(
          None,
          promptlayer_api_handler,
+         api_key,
+         base_url,
          function_name,
          provider_type,
          args,
@@ -222,7 +458,6 @@ async def promptlayer_api_handler_async(
          response,
          request_start_time,
          request_end_time,
-         api_key,
          return_pl_id=return_pl_id,
          llm_request_span_id=llm_request_span_id,
      )
@@ -236,15 +471,13 @@ def convert_native_object_to_dict(native_object):
      if isinstance(native_object, Enum):
          return native_object.value
      if hasattr(native_object, "__dict__"):
-         return {
-             k: convert_native_object_to_dict(v)
-             for k, v in native_object.__dict__.items()
-         }
+         return {k: convert_native_object_to_dict(v) for k, v in native_object.__dict__.items()}
      return native_object


  def promptlayer_api_request(
      *,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -261,13 +494,11 @@ def promptlayer_api_request(
      if isinstance(response, dict) and hasattr(response, "to_dict_recursive"):
          response = response.to_dict_recursive()
      request_response = None
-     if hasattr(
-         response, "dict"
-     ):  # added this for anthropic 3.0 changes, they return a completion object
+     if hasattr(response, "dict"):  # added this for anthropic 3.0 changes, they return a completion object
          response = response.dict()
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-request",
+             f"{base_url}/track-request",
              json={
                  "function_name": function_name,
                  "provider_type": provider_type,
@@ -284,61 +515,64 @@ def promptlayer_api_request(
          )
          if not hasattr(request_response, "status_code"):
              warn_on_bad_response(
-                 request_response,
-                 "WARNING: While logging your request PromptLayer had the following issue",
+                 request_response, "WARNING: While logging your request PromptLayer had the following issue"
              )
          elif request_response.status_code != 200:
              warn_on_bad_response(
-                 request_response,
-                 "WARNING: While logging your request PromptLayer had the following error",
+                 request_response, "WARNING: While logging your request PromptLayer had the following error"
              )
      except Exception as e:
-         print(
-             f"WARNING: While logging your request PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
      if request_response is not None and return_pl_id:
          return request_response.json().get("request_id")


- def track_request(**body):
+ @retry_on_api_error
+ def track_request(base_url: str, throw_on_error: bool, **body):
      try:
          response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-request",
+             f"{base_url}/track-request",
              json=body,
          )
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 f"PromptLayer had the following error while tracking your request: {response.text}",
-             )
+             if throw_on_error:
+                 raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
+             else:
+                 warn_on_bad_response(
+                     response, f"PromptLayer had the following error while tracking your request: {response.text}"
+                 )
          return response.json()
      except requests.exceptions.RequestException as e:
-         print(
-             f"WARNING: While logging your request PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
          return {}


- async def atrack_request(**body: Any) -> Dict[str, Any]:
+ @retry_on_api_error
+ async def atrack_request(base_url: str, throw_on_error: bool, **body: Any) -> Dict[str, Any]:
      try:
-         async with httpx.AsyncClient() as client:
+         async with _make_httpx_client() as client:
              response = await client.post(
-                 f"{URL_API_PROMPTLAYER}/track-request",
+                 f"{base_url}/track-request",
                  json=body,
              )
-             if response.status_code != 200:
-                 warn_on_bad_response(
-                     response,
-                     f"PromptLayer had the following error while tracking your request: {response.text}",
-                 )
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
+             else:
+                 warn_on_bad_response(
+                     response, f"PromptLayer had the following error while tracking your request: {response.text}"
+                 )
          return response.json()
      except httpx.RequestError as e:
-         print(
-             f"WARNING: While logging your request PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
          return {}

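
Most of the configuration introduced at the top of this file is read once at import time; `_get_http_timeout` is the exception and re-reads its variable on every client construction. A sketch of setting the new knobs, assuming the environment variable names shown above:

    import os

    # Set before importing promptlayer.utils; only the timeout is re-read later.
    os.environ["PROMPTLAYER_HTTP_TIMEOUT"] = "10"  # per-request httpx timeout in seconds (default 5)
    os.environ["PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION"] = "true"  # re-raise instead of wrapping
    os.environ["PROMPTLAYER_BASE_URL"] = "https://api.promptlayer.com"  # feeds WS_TOKEN_REQUEST_LIBRARY_URL

    from promptlayer import utils  # noqa: E402
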
@@ -370,8 +604,9 @@ def promptlayer_api_request_async(
      )


+ @retry_on_api_error
  def promptlayer_get_prompt(
-     prompt_name, api_key, version: int = None, label: str = None
+     api_key: str, base_url: str, throw_on_error: bool, prompt_name, version: int = None, label: str = None
  ):
      """
      Get a prompt from the PromptLayer library
@@ -380,29 +615,40 @@ def promptlayer_get_prompt(
      """
      try:
          request_response = requests.get(
-             f"{URL_API_PROMPTLAYER}/library-get-prompt-template",
+             f"{base_url}/library-get-prompt-template",
              headers={"X-API-KEY": api_key},
              params={"prompt_name": prompt_name, "version": version, "label": label},
          )
      except Exception as e:
-         raise Exception(
-             f"PromptLayer had the following error while getting your prompt: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"PromptLayer had the following error while getting your prompt: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"PromptLayer had the following error while getting your prompt: {e}")
+         return None
      if request_response.status_code != 200:
-         raise_on_bad_response(
-             request_response,
-             "PromptLayer had the following error while getting your prompt",
-         )
+         if throw_on_error:
+             raise_on_bad_response(
+                 request_response,
+                 "PromptLayer had the following error while getting your prompt",
+             )
+         else:
+             warn_on_bad_response(
+                 request_response,
+                 "WARNING: PromptLayer had the following error while getting your prompt",
+             )
+             return None

      return request_response.json()


+ @retry_on_api_error
  def promptlayer_publish_prompt(
-     prompt_name, prompt_template, commit_message, tags, api_key, metadata=None
+     api_key: str, base_url: str, throw_on_error: bool, prompt_name, prompt_template, commit_message, tags, metadata=None
  ):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-publish-prompt-template",
+             f"{base_url}/library-publish-prompt-template",
              json={
                  "prompt_name": prompt_name,
                  "prompt_template": prompt_template,
@@ -413,23 +659,34 @@ def promptlayer_publish_prompt(
              },
          )
      except Exception as e:
-         raise Exception(
-             f"PromptLayer had the following error while publishing your prompt: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"PromptLayer had the following error while publishing your prompt: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"PromptLayer had the following error while publishing your prompt: {e}")
+         return False
      if request_response.status_code != 200:
-         raise_on_bad_response(
-             request_response,
-             "PromptLayer had the following error while publishing your prompt",
-         )
+         if throw_on_error:
+             raise_on_bad_response(
+                 request_response,
+                 "PromptLayer had the following error while publishing your prompt",
+             )
+         else:
+             warn_on_bad_response(
+                 request_response,
+                 "WARNING: PromptLayer had the following error while publishing your prompt",
+             )
+             return False
      return True


+ @retry_on_api_error
  def promptlayer_track_prompt(
-     request_id, prompt_name, input_variables, api_key, version, label
+     api_key: str, base_url: str, throw_on_error: bool, request_id, prompt_name, input_variables, version, label
  ):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-prompt",
+             f"{base_url}/library-track-prompt",
              json={
                  "request_id": request_id,
                  "prompt_name": prompt_name,
@@ -440,29 +697,39 @@ def promptlayer_track_prompt(
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your prompt PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your prompt PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your prompt PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
          return False
      return True


+ @retry_on_api_error
  async def apromptlayer_track_prompt(
+     api_key: str,
+     base_url: str,
      request_id: str,
      prompt_name: str,
      input_variables: Dict[str, Any],
-     api_key: Optional[str] = None,
      version: Optional[int] = None,
      label: Optional[str] = None,
+     throw_on_error: bool = True,
  ) -> bool:
-     url = f"{URL_API_PROMPTLAYER}/library-track-prompt"
+     url = f"{base_url}/library-track-prompt"
      payload = {
          "request_id": request_id,
          "prompt_name": prompt_name,
@@ -472,28 +739,34 @@ async def apromptlayer_track_prompt(
          "label": label,
      }
      try:
-         async with httpx.AsyncClient() as client:
+         async with _make_httpx_client() as client:
              response = await client.post(url, json=payload)
+
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 "WARNING: While tracking your prompt, PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(response, "While tracking your prompt, PromptLayer had the following error")
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your prompt, PromptLayer had the following error",
+                 )
+                 return False
      except httpx.RequestError as e:
-         print(
-             f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
          return False

      return True


- def promptlayer_track_metadata(request_id, metadata, api_key):
+ @retry_on_api_error
+ def promptlayer_track_metadata(api_key: str, base_url: str, throw_on_error: bool, request_id, metadata):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-metadata",
+             f"{base_url}/library-track-metadata",
              json={
                  "request_id": request_id,
                  "metadata": metadata,
@@ -501,79 +774,106 @@ def promptlayer_track_metadata(request_id, metadata, api_key):
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your metadata PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your metadata PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your metadata PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your metadata PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
          return False
      return True


+ @retry_on_api_error
  async def apromptlayer_track_metadata(
-     request_id: str, metadata: Dict[str, Any], api_key: Optional[str] = None
+     api_key: str, base_url: str, throw_on_error: bool, request_id: str, metadata: Dict[str, Any]
  ) -> bool:
-     url = f"{URL_API_PROMPTLAYER}/library-track-metadata"
+     url = f"{base_url}/library-track-metadata"
      payload = {
          "request_id": request_id,
          "metadata": metadata,
          "api_key": api_key,
      }
      try:
-         async with httpx.AsyncClient() as client:
+         async with _make_httpx_client() as client:
              response = await client.post(url, json=payload)
+
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 "WARNING: While tracking your metadata, PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While tracking your metadata, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your metadata, PromptLayer had the following error",
+                 )
+                 return False
      except httpx.RequestError as e:
-         print(
-             f"WARNING: While tracking your metadata PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
          return False

      return True


- def promptlayer_track_score(request_id, score, score_name, api_key):
+ @retry_on_api_error
+ def promptlayer_track_score(api_key: str, base_url: str, throw_on_error: bool, request_id, score, score_name):
      try:
          data = {"request_id": request_id, "score": score, "api_key": api_key}
          if score_name is not None:
              data["name"] = score_name
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-score",
+             f"{base_url}/library-track-score",
              json=data,
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your score PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your score PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your score PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your score PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your score PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your score PromptLayer had the following error: {e}")
          return False
      return True


+ @retry_on_api_error
  async def apromptlayer_track_score(
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
      request_id: str,
      score: float,
      score_name: Optional[str],
-     api_key: Optional[str] = None,
  ) -> bool:
-     url = f"{URL_API_PROMPTLAYER}/library-track-score"
+     url = f"{base_url}/library-track-score"
      data = {
          "request_id": request_id,
          "score": score,
@@ -582,30 +882,96 @@ async def apromptlayer_track_score(
      if score_name is not None:
          data["name"] = score_name
      try:
-         async with httpx.AsyncClient() as client:
+         async with _make_httpx_client() as client:
              response = await client.post(url, json=data)
+
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 "WARNING: While tracking your score, PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While tracking your score, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your score, PromptLayer had the following error",
+                 )
+                 return False
      except httpx.RequestError as e:
-         print(
-             f"WARNING: While tracking your score PromptLayer had the following error: {str(e)}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your score: {str(e)}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your score PromptLayer had the following error: {str(e)}")
          return False

      return True


+ def build_anthropic_content_blocks(events):
+     content_blocks = []
+     current_block = None
+     current_signature = ""
+     current_thinking = ""
+     current_text = ""
+     current_tool_input_json = ""
+     usage = None
+     stop_reason = None
+
+     for event in events:
+         if event.type == "content_block_start":
+             current_block = deepcopy(event.content_block)
+             if current_block.type == "thinking":
+                 current_signature = ""
+                 current_thinking = ""
+             elif current_block.type == "text":
+                 current_text = ""
+             elif current_block.type == "tool_use":
+                 current_tool_input_json = ""
+         elif event.type == "content_block_delta" and current_block is not None:
+             if current_block.type == "thinking":
+                 if hasattr(event.delta, "signature"):
+                     current_signature = event.delta.signature
+                 if hasattr(event.delta, "thinking"):
+                     current_thinking += event.delta.thinking
+             elif current_block.type == "text":
+                 if hasattr(event.delta, "text"):
+                     current_text += event.delta.text
+             elif current_block.type == "tool_use":
+                 if hasattr(event.delta, "partial_json"):
+                     current_tool_input_json += event.delta.partial_json
+         elif event.type == "content_block_stop" and current_block is not None:
+             if current_block.type == "thinking":
+                 current_block.signature = current_signature
+                 current_block.thinking = current_thinking
+             elif current_block.type == "text":
+                 current_block.text = current_text
+             elif current_block.type == "tool_use":
+                 try:
+                     current_block.input = json.loads(current_tool_input_json)
+                 except json.JSONDecodeError:
+                     current_block.input = {}
+             content_blocks.append(current_block)
+             current_block = None
+             current_signature = ""
+             current_thinking = ""
+             current_text = ""
+             current_tool_input_json = ""
+         elif event.type == "message_delta":
+             if hasattr(event, "usage"):
+                 usage = event.usage
+             if hasattr(event.delta, "stop_reason"):
+                 stop_reason = event.delta.stop_reason
+     return content_blocks, usage, stop_reason
+
+
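
`build_anthropic_content_blocks` relies only on attribute access, so it can be exercised with synthetic events; a sketch using `SimpleNamespace` stand-ins for the Anthropic stream event objects:

    from types import SimpleNamespace as NS

    events = [
        NS(type="content_block_start", content_block=NS(type="text", text="")),
        NS(type="content_block_delta", delta=NS(text="Hel")),
        NS(type="content_block_delta", delta=NS(text="lo")),
        NS(type="content_block_stop"),
        NS(type="message_delta", usage=NS(output_tokens=2), delta=NS(stop_reason="end_turn")),
    ]

    blocks, usage, stop_reason = build_anthropic_content_blocks(events)
    assert blocks[0].text == "Hello"
    assert usage.output_tokens == 2 and stop_reason == "end_turn"
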
  class GeneratorProxy:
-     def __init__(self, generator, api_request_arguments, api_key):
+     def __init__(self, generator, api_request_arguments, api_key, base_url):
          self.generator = generator
          self.results = []
          self.api_request_arugments = api_request_arguments
          self.api_key = api_key
+         self.base_url = base_url

      def __iter__(self):
          return self
@@ -620,6 +986,7 @@ class GeneratorProxy:
              await self.generator._AsyncMessageStreamManager__api_request,
              api_request_arguments,
              self.api_key,
+             self.base_url,
          )

      def __enter__(self):
@@ -630,6 +997,7 @@ class GeneratorProxy:
              stream,
              api_request_arguments,
              self.api_key,
+             self.base_url,
          )

      def __exit__(self, exc_type, exc_val, exc_tb):
@@ -648,9 +1016,7 @@ class GeneratorProxy:

      def __getattr__(self, name):
          if name == "text_stream":  # anthropic async stream
-             return GeneratorProxy(
-                 self.generator.text_stream, self.api_request_arugments, self.api_key
-             )
+             return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key, self.base_url)
          return getattr(self.generator, name)

      def _abstracted_next(self, result):
@@ -667,12 +1033,12 @@ class GeneratorProxy:
              end_anthropic = True

          end_openai = provider_type == "openai" and (
-             result.choices[0].finish_reason == "stop"
-             or result.choices[0].finish_reason == "length"
+             result.choices[0].finish_reason == "stop" or result.choices[0].finish_reason == "length"
          )

          if end_anthropic or end_openai:
              request_id = promptlayer_api_request(
+                 base_url=self.base_url,
                  function_name=self.api_request_arugments["function_name"],
                  provider_type=self.api_request_arugments["provider_type"],
                  args=self.api_request_arugments["args"],
@@ -683,9 +1049,7 @@ class GeneratorProxy:
                  request_end_time=self.api_request_arugments["request_end_time"],
                  api_key=self.api_key,
                  return_pl_id=self.api_request_arugments["return_pl_id"],
-                 llm_request_span_id=self.api_request_arugments.get(
-                     "llm_request_span_id"
-                 ),
+                 llm_request_span_id=self.api_request_arugments.get("llm_request_span_id"),
              )

          if self.api_request_arugments["return_pl_id"]:
@@ -702,31 +1066,35 @@ class GeneratorProxy:
          response = ""
          for result in self.results:
              if hasattr(result, "completion"):
-                 response = f"{response}{result.completion}"
+                 response += result.completion
              elif hasattr(result, "message") and isinstance(result.message, str):
-                 response = f"{response}{result.message}"
+                 response += result.message
              elif (
                  hasattr(result, "content_block")
                  and hasattr(result.content_block, "text")
-                 and "type" in result
-                 and result.type != "message_stop"
+                 and getattr(result, "type", None) != "message_stop"
              ):
-                 response = f"{response}{result.content_block.text}"
-             elif hasattr(result, "delta") and hasattr(result.delta, "text"):
-                 response = f"{response}{result.delta.text}"
-         if (
-             hasattr(self.results[-1], "type")
-             and self.results[-1].type == "message_stop"
-         ):  # this is a message stream and not the correct event
+                 response += result.content_block.text
+             elif hasattr(result, "delta"):
+                 if hasattr(result.delta, "thinking"):
+                     response += result.delta.thinking
+                 elif hasattr(result.delta, "text"):
+                     response += result.delta.text
+
+         # 2) If this is a "stream" (ended by message_stop), reconstruct both ThinkingBlock & TextBlock
+         last_event = self.results[-1]
+         if getattr(last_event, "type", None) == "message_stop":
              final_result = deepcopy(self.results[0].message)
-             final_result.usage = None
-             content_block = deepcopy(self.results[1].content_block)
-             content_block.text = response
-             final_result.content = [content_block]
-         else:
-             final_result = deepcopy(self.results[-1])
-             final_result.completion = response
+
+             content_blocks, usage, stop_reason = build_anthropic_content_blocks(self.results)
+             final_result.content = content_blocks
+             if usage:
+                 final_result.usage.output_tokens = usage.output_tokens
+             if stop_reason:
+                 final_result.stop_reason = stop_reason
              return final_result
+         else:
+             return deepcopy(self.results[-1])
          if hasattr(self.results[0].choices[0], "text"):  # this is regular completion
              response = ""
              for result in self.results:
@@ -734,23 +1102,15 @@ class GeneratorProxy:
              final_result = deepcopy(self.results[-1])
              final_result.choices[0].text = response
              return final_result
-         elif hasattr(
-             self.results[0].choices[0], "delta"
-         ):  # this is completion with delta
+         elif hasattr(self.results[0].choices[0], "delta"):  # this is completion with delta
              response = {"role": "", "content": ""}
              for result in self.results:
-                 if (
-                     hasattr(result.choices[0].delta, "role")
-                     and result.choices[0].delta.role is not None
-                 ):
+                 if hasattr(result.choices[0].delta, "role") and result.choices[0].delta.role is not None:
                      response["role"] = result.choices[0].delta.role
-                 if (
-                     hasattr(result.choices[0].delta, "content")
-                     and result.choices[0].delta.content is not None
-                 ):
-                     response["content"] = response[
-                         "content"
-                     ] = f"{response['content']}{result.choices[0].delta.content}"
+                 if hasattr(result.choices[0].delta, "content") and result.choices[0].delta.content is not None:
+                     response["content"] = response["content"] = (
+                         f"{response['content']}{result.choices[0].delta.content}"
+                     )
              final_result = deepcopy(self.results[-1])
              final_result.choices[0] = response
              return final_result
@@ -769,39 +1129,71 @@ async def run_in_thread_async(executor, func, *args, **kwargs):
  def warn_on_bad_response(request_response, main_message):
      if hasattr(request_response, "json"):
          try:
-             print(
-                 f"{main_message}: {request_response.json().get('message')}",
-                 file=sys.stderr,
-             )
+             logger.warning(f"{main_message}: {request_response.json().get('message')}")
          except json.JSONDecodeError:
-             print(
-                 f"{main_message}: {request_response}",
-                 file=sys.stderr,
-             )
+             logger.warning(f"{main_message}: {request_response}")
      else:
-         print(f"{main_message}: {request_response}", file=sys.stderr)
+         logger.warning(f"{main_message}: {request_response}")


  def raise_on_bad_response(request_response, main_message):
+     """Raise an appropriate exception based on the HTTP status code."""
+     status_code = getattr(request_response, "status_code", None)
+
+     body = None
+     error_detail = None
      if hasattr(request_response, "json"):
          try:
-             raise Exception(
-                 f"{main_message}: {request_response.json().get('message') or request_response.json().get('error')}"
-             )
-         except json.JSONDecodeError:
-             raise Exception(f"{main_message}: {request_response}")
+             body = request_response.json()
+             error_detail = body.get("message") or body.get("error") or body.get("detail")
+         except (json.JSONDecodeError, AttributeError):
+             body = getattr(request_response, "text", str(request_response))
+             error_detail = body
+     else:
+         body = str(request_response)
+         error_detail = body
+
+     if error_detail:
+         err_msg = f"{main_message}: {error_detail}"
      else:
-         raise Exception(f"{main_message}: {request_response}")
+         err_msg = main_message
+
+     if status_code == 400:
+         raise _exceptions.PromptLayerBadRequestError(err_msg, response=request_response, body=body)
+
+     if status_code == 401:
+         raise _exceptions.PromptLayerAuthenticationError(err_msg, response=request_response, body=body)
+
+     if status_code == 403:
+         raise _exceptions.PromptLayerPermissionDeniedError(err_msg, response=request_response, body=body)
+
+     if status_code == 404:
+         raise _exceptions.PromptLayerNotFoundError(err_msg, response=request_response, body=body)
+
+     if status_code == 409:
+         raise _exceptions.PromptLayerConflictError(err_msg, response=request_response, body=body)
+
+     if status_code == 422:
+         raise _exceptions.PromptLayerUnprocessableEntityError(err_msg, response=request_response, body=body)
+
+     if status_code == 429:
+         raise _exceptions.PromptLayerRateLimitError(err_msg, response=request_response, body=body)
+
+     if status_code and status_code >= 500:
+         raise _exceptions.PromptLayerInternalServerError(err_msg, response=request_response, body=body)
+
+     raise _exceptions.PromptLayerAPIStatusError(err_msg, response=request_response, body=body)

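
Because `raise_on_bad_response` now maps status codes onto the hierarchy in `promptlayer.exceptions`, callers that pass `throw_on_error=True` can branch on the failure class instead of parsing messages. A sketch, assuming these exception names are importable from the package:

    from promptlayer import exceptions

    try:
        prompt = promptlayer_get_prompt(
            api_key="pl_...",  # placeholder
            base_url="https://api.promptlayer.com",
            throw_on_error=True,
            prompt_name="welcome-email",
        )
    except exceptions.PromptLayerAuthenticationError:
        ...  # 401: bad or missing API key; retrying will not help
    except exceptions.PromptLayerRateLimitError:
        ...  # 429: retry_on_api_error backs off first, re-raising after four attempts
    except exceptions.PromptLayerAPIConnectionError:
        ...  # network-level failure; no HTTP response attached
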
  async def async_wrapper(
+     api_key: str,
+     base_url: str,
      coroutine_obj,
      return_pl_id,
      request_start_time,
      function_name,
      provider_type,
      tags,
-     api_key: str = None,
      llm_request_span_id: str = None,
      tracer=None,
      *args,
@@ -814,6 +1206,8 @@ async def async_wrapper(
          response = await coroutine_obj
          request_end_time = datetime.datetime.now().timestamp()
          result = await promptlayer_api_handler_async(
+             api_key,
+             base_url,
              function_name,
              provider_type,
              args,
@@ -822,7 +1216,6 @@ async def async_wrapper(
              response,
              request_start_time,
              request_end_time,
-             api_key,
              return_pl_id=return_pl_id,
              llm_request_span_id=llm_request_span_id,
          )
@@ -837,54 +1230,75 @@ async def async_wrapper(
          context.detach(token)


- def promptlayer_create_group(api_key: str = None):
+ @retry_on_api_error
+ def promptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/create-group",
+             f"{base_url}/create-group",
              json={
                  "api_key": api_key,
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While creating your group PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While creating your group PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While creating your group PromptLayer had the following error",
+                 )
+                 return False
      except requests.exceptions.RequestException as e:
-         # I'm aiming for a more specific exception catch here
-         raise Exception(
-             f"PromptLayer had the following error while creating your group: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while creating your group: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While creating your group PromptLayer had the following error: {e}")
+         return False
      return request_response.json()["id"]


- async def apromptlayer_create_group(api_key: Optional[str] = None) -> str:
+ @retry_on_api_error
+ async def apromptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
      try:
-         async with httpx.AsyncClient() as client:
+         async with _make_httpx_client() as client:
              response = await client.post(
-                 f"{URL_API_PROMPTLAYER}/create-group",
+                 f"{base_url}/create-group",
                  json={
                      "api_key": api_key,
                  },
              )
+
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 "WARNING: While creating your group, PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While creating your group, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While creating your group, PromptLayer had the following error",
+                 )
+                 return False
          return response.json()["id"]
      except httpx.RequestError as e:
-         raise Exception(
-             f"PromptLayer had the following error while creating your group: {str(e)}"
-         ) from e
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while creating your group: {str(e)}", response=None, body=None
+             ) from e
+         logger.warning(f"While creating your group PromptLayer had the following error: {e}")
+         return False


- def promptlayer_track_group(request_id, group_id, api_key: str = None):
+ @retry_on_api_error
+ def promptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-group",
+             f"{base_url}/track-group",
              json={
                  "api_key": api_key,
                  "request_id": request_id,
@@ -892,118 +1306,170 @@ def promptlayer_track_group(request_id, group_id, api_key: str = None):
892
1306
  },
893
1307
  )
894
1308
  if request_response.status_code != 200:
895
- warn_on_bad_response(
896
- request_response,
897
- "WARNING: While tracking your group PromptLayer had the following error",
898
- )
899
- return False
1309
+ if throw_on_error:
1310
+ raise_on_bad_response(
1311
+ request_response,
1312
+ "While tracking your group PromptLayer had the following error",
1313
+ )
1314
+ else:
1315
+ warn_on_bad_response(
1316
+ request_response,
1317
+ "WARNING: While tracking your group PromptLayer had the following error",
1318
+ )
1319
+ return False
900
1320
  except requests.exceptions.RequestException as e:
901
- # I'm aiming for a more specific exception catch here
902
- raise Exception(
903
- f"PromptLayer had the following error while tracking your group: {e}"
904
- )
1321
+ if throw_on_error:
1322
+ raise _exceptions.PromptLayerAPIConnectionError(
1323
+ f"PromptLayer had the following error while tracking your group: {e}", response=None, body=None
1324
+ ) from e
1325
+ logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
1326
+ return False
905
1327
  return True
906
1328
 
907
1329
 
908
- async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
1330
+ @retry_on_api_error
1331
+ async def apromptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
909
1332
  try:
910
1333
  payload = {
911
1334
  "api_key": api_key,
912
1335
  "request_id": request_id,
913
1336
  "group_id": group_id,
914
1337
  }
915
- async with httpx.AsyncClient() as client:
1338
+ async with _make_httpx_client() as client:
916
1339
  response = await client.post(
917
- f"{URL_API_PROMPTLAYER}/track-group",
1340
+ f"{base_url}/track-group",
918
1341
  headers={"X-API-KEY": api_key},
919
1342
  json=payload,
920
1343
  )
1344
+
921
1345
  if response.status_code != 200:
922
- warn_on_bad_response(
923
- response,
924
- "WARNING: While tracking your group, PromptLayer had the following error",
925
- )
926
- return False
1346
+ if throw_on_error:
1347
+ raise_on_bad_response(
1348
+ response,
1349
+ "While tracking your group, PromptLayer had the following error",
1350
+ )
1351
+ else:
1352
+ warn_on_bad_response(
1353
+ response,
1354
+ "WARNING: While tracking your group, PromptLayer had the following error",
1355
+ )
1356
+ return False
927
1357
  except httpx.RequestError as e:
928
- print(
929
- f"WARNING: While tracking your group PromptLayer had the following error: {e}",
930
- file=sys.stderr,
931
- )
1358
+ if throw_on_error:
1359
+ raise _exceptions.PromptLayerAPIConnectionError(
1360
+ f"PromptLayer had the following error while tracking your group: {str(e)}", response=None, body=None
1361
+ ) from e
1362
+ logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
932
1363
  return False
933
1364
 
934
1365
  return True
935
1366
 
936
1367
 
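Both track-group helpers now take configuration first (api_key, base_url, throw_on_error) and the identifiers last, and both are wrapped in retry_on_api_error. A usage sketch with hypothetical identifiers:

from promptlayer.utils import promptlayer_track_group

ok = promptlayer_track_group(
    "pl_xxx",                       # api_key
    "https://api.promptlayer.com",  # base_url
    False,                          # throw_on_error: warn and return False instead of raising
    12345,                          # request_id from a previous log call (hypothetical)
    "group-uuid",                   # group_id (hypothetical)
)
if not ok:
    print("tracking failed; details were logged as a warning")
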
1368
+ @retry_on_api_error
937
1369
  def get_prompt_template(
938
- prompt_name: str, params: Union[GetPromptTemplate, None] = None, api_key: str = None
1370
+ api_key: str, base_url: str, throw_on_error: bool, prompt_name: str, params: Union[GetPromptTemplate, None] = None
939
1371
  ) -> GetPromptTemplateResponse:
940
1372
  try:
941
1373
  json_body = {"api_key": api_key}
942
1374
  if params:
943
1375
  json_body = {**json_body, **params}
944
1376
  response = requests.post(
945
- f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
1377
+ f"{base_url}/prompt-templates/{prompt_name}",
946
1378
  headers={"X-API-KEY": api_key},
947
1379
  json=json_body,
948
1380
  )
949
1381
  if response.status_code != 200:
950
- raise Exception(
951
- f"PromptLayer had the following error while getting your prompt template: {response.text}"
952
- )
1382
+ if throw_on_error:
1383
+ raise_on_bad_response(
1384
+ response, "PromptLayer had the following error while getting your prompt template"
1385
+ )
1386
+ else:
1387
+ warn_on_bad_response(
1388
+ response, "WARNING: PromptLayer had the following error while getting your prompt template"
1389
+ )
1390
+ return None
953
1391
 
954
- warning = response.json().get("warning", None)
955
- if warning is not None:
956
- warn_on_bad_response(
957
- warning,
958
- "WARNING: While getting your prompt template",
959
- )
960
1392
  return response.json()
1393
+ except requests.exceptions.ConnectionError as e:
1394
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1395
+ if throw_on_error:
1396
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1397
+ logger.warning(err_msg)
1398
+ return None
1399
+ except requests.exceptions.Timeout as e:
1400
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1401
+ if throw_on_error:
1402
+ raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
1403
+ logger.warning(err_msg)
1404
+ return None
961
1405
  except requests.exceptions.RequestException as e:
962
- raise Exception(
963
- f"PromptLayer had the following error while getting your prompt template: {e}"
964
- )
1406
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
1407
+ if throw_on_error:
1408
+ raise _exceptions.PromptLayerError(err_msg, response=None, body=None) from e
1409
+ logger.warning(err_msg)
1410
+ return None
965
1411
 
966
1412
 
1413
+ @retry_on_api_error
967
1414
  async def aget_prompt_template(
1415
+ api_key: str,
1416
+ base_url: str,
1417
+ throw_on_error: bool,
968
1418
  prompt_name: str,
969
1419
  params: Union[GetPromptTemplate, None] = None,
970
- api_key: str = None,
971
1420
  ) -> GetPromptTemplateResponse:
972
1421
  try:
973
1422
  json_body = {"api_key": api_key}
974
1423
  if params:
975
1424
  json_body.update(params)
976
- async with httpx.AsyncClient() as client:
1425
+ async with _make_httpx_client() as client:
977
1426
  response = await client.post(
978
- f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
1427
+ f"{base_url}/prompt-templates/{quote(prompt_name, safe='')}",
979
1428
  headers={"X-API-KEY": api_key},
980
1429
  json=json_body,
981
1430
  )
982
1431
  if response.status_code != 200:
983
- raise_on_bad_response(
984
- response,
985
- "PromptLayer had the following error while getting your prompt template",
986
- )
987
- warning = response.json().get("warning", None)
988
- if warning:
989
- warn_on_bad_response(
990
- warning,
991
- "WARNING: While getting your prompt template",
992
- )
1432
+ if throw_on_error:
1433
+ raise_on_bad_response(
1434
+ response,
1435
+ "PromptLayer had the following error while getting your prompt template",
1436
+ )
1437
+ else:
1438
+ warn_on_bad_response(
1439
+ response, "WARNING: While getting your prompt template PromptLayer had the following error"
1440
+ )
1441
+ return None
993
1442
  return response.json()
1443
+ except (httpx.ConnectError, httpx.NetworkError) as e:
1444
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1445
+ if throw_on_error:
1446
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1447
+ logger.warning(err_msg)
1448
+ return None
1449
+ except httpx.TimeoutException as e:
1450
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1451
+ if throw_on_error:
1452
+ raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
1453
+ logger.warning(err_msg)
1454
+ return None
994
1455
  except httpx.RequestError as e:
995
- raise Exception(
996
- f"PromptLayer had the following error while getting your prompt template: {str(e)}"
997
- ) from e
1456
+ err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
1457
+ if throw_on_error:
1458
+ raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
1459
+ logger.warning(err_msg)
1460
+ return None
998
1461
 
999
1462
 
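Note the asymmetry introduced here: the async variant percent-encodes prompt_name with quote(prompt_name, safe=''), while the sync variant still interpolates the raw name into the URL. A usage sketch for the sync helper, with hypothetical values; the params dict follows the GetPromptTemplate type, whose exact fields are not shown in this diff:

from promptlayer.utils import get_prompt_template

template = get_prompt_template(
    "pl_xxx",                       # api_key
    "https://api.promptlayer.com",  # base_url
    True,                           # throw_on_error: raise typed errors instead of returning None
    "welcome-email",                # prompt_name (hypothetical)
    {"label": "prod"},              # hypothetical GetPromptTemplate params
)
if template is not None:
    print(template["prompt_template"]["type"])
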
1463
+ @retry_on_api_error
1000
1464
  def publish_prompt_template(
1465
+ api_key: str,
1466
+ base_url: str,
1467
+ throw_on_error: bool,
1001
1468
  body: PublishPromptTemplate,
1002
- api_key: str = None,
1003
1469
  ) -> PublishPromptTemplateResponse:
1004
1470
  try:
1005
1471
  response = requests.post(
1006
- f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
1472
+ f"{base_url}/rest/prompt-templates",
1007
1473
  headers={"X-API-KEY": api_key},
1008
1474
  json={
1009
1475
  "prompt_template": {**body},
@@ -1012,24 +1478,38 @@ def publish_prompt_template(
1012
1478
  },
1013
1479
  )
1014
1480
  if response.status_code == 400:
1015
- raise Exception(
1016
- f"PromptLayer had the following error while publishing your prompt template: {response.text}"
1017
- )
1481
+ if throw_on_error:
1482
+ raise_on_bad_response(
1483
+ response, "PromptLayer had the following error while publishing your prompt template"
1484
+ )
1485
+ else:
1486
+ warn_on_bad_response(
1487
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1488
+ )
1489
+ return None
1018
1490
  return response.json()
1019
1491
  except requests.exceptions.RequestException as e:
1020
- raise Exception(
1021
- f"PromptLayer had the following error while publishing your prompt template: {e}"
1022
- )
1492
+ if throw_on_error:
1493
+ raise _exceptions.PromptLayerAPIConnectionError(
1494
+ f"PromptLayer had the following error while publishing your prompt template: {e}",
1495
+ response=None,
1496
+ body=None,
1497
+ ) from e
1498
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1499
+ return None
1023
1500
 
1024
1501
 
1502
+ @retry_on_api_error
1025
1503
  async def apublish_prompt_template(
1504
+ api_key: str,
1505
+ base_url: str,
1506
+ throw_on_error: bool,
1026
1507
  body: PublishPromptTemplate,
1027
- api_key: str = None,
1028
1508
  ) -> PublishPromptTemplateResponse:
1029
1509
  try:
1030
- async with httpx.AsyncClient() as client:
1510
+ async with _make_httpx_client() as client:
1031
1511
  response = await client.post(
1032
- f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
1512
+ f"{base_url}/rest/prompt-templates",
1033
1513
  headers={"X-API-KEY": api_key},
1034
1514
  json={
1035
1515
  "prompt_template": {**body},
@@ -1037,429 +1517,103 @@ async def apublish_prompt_template(
1037
1517
  "release_labels": body.get("release_labels"),
1038
1518
  },
1039
1519
  )
1040
- if response.status_code == 400:
1041
- raise Exception(
1042
- f"PromptLayer had the following error while publishing your prompt template: {response.text}"
1043
- )
1044
- if response.status_code != 201:
1045
- raise_on_bad_response(
1046
- response,
1047
- "PromptLayer had the following error while publishing your prompt template",
1048
- )
1520
+
1521
+ if response.status_code != 201:  # any non-201 status, including 400, fails this check
1522
+ if throw_on_error:
1523
+ raise_on_bad_response(
1524
+ response,
1525
+ "PromptLayer had the following error while publishing your prompt template",
1526
+ )
1527
+ else:
1528
+ warn_on_bad_response(
1529
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1530
+ )
1531
+ return None
1049
1532
  return response.json()
1050
1533
  except httpx.RequestError as e:
1051
- raise Exception(
1052
- f"PromptLayer had the following error while publishing your prompt template: {str(e)}"
1053
- ) from e
1534
+ if throw_on_error:
1535
+ raise _exceptions.PromptLayerAPIConnectionError(
1536
+ f"PromptLayer had the following error while publishing your prompt template: {str(e)}",
1537
+ response=None,
1538
+ body=None,
1539
+ ) from e
1540
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1541
+ return None
1054
1542
 
1055
1543
 
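Publishing nests the entire body under prompt_template in the outgoing JSON and additionally lifts release_labels out of it. A sketch under the new signature; the body field names are assumptions, since the PublishPromptTemplate schema is not shown in this diff:

from promptlayer.utils import publish_prompt_template

body = {
    "prompt_name": "welcome-email",                       # hypothetical
    "prompt_template": {"type": "chat", "messages": []},  # hypothetical minimal template
    "release_labels": ["prod"],
}
result = publish_prompt_template("pl_xxx", "https://api.promptlayer.com", True, body)
if result is not None:
    print(result)
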
1544
+ @retry_on_api_error
1056
1545
  def get_all_prompt_templates(
1057
- page: int = 1, per_page: int = 30, api_key: str = None
1546
+ api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
1058
1547
  ) -> List[ListPromptTemplateResponse]:
1059
1548
  try:
1549
+ params = {"page": page, "per_page": per_page}
1550
+ if label:
1551
+ params["label"] = label
1060
1552
  response = requests.get(
1061
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1553
+ f"{base_url}/prompt-templates",
1062
1554
  headers={"X-API-KEY": api_key},
1063
- params={"page": page, "per_page": per_page},
1555
+ params=params,
1064
1556
  )
1065
1557
  if response.status_code != 200:
1066
- raise Exception(
1067
- f"PromptLayer had the following error while getting all your prompt templates: {response.text}"
1068
- )
1558
+ if throw_on_error:
1559
+ raise_on_bad_response(
1560
+ response, "PromptLayer had the following error while getting all your prompt templates"
1561
+ )
1562
+ else:
1563
+ warn_on_bad_response(
1564
+ response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
1565
+ )
1566
+ return []
1069
1567
  items = response.json().get("items", [])
1070
1568
  return items
1071
1569
  except requests.exceptions.RequestException as e:
1072
- raise Exception(
1073
- f"PromptLayer had the following error while getting all your prompt templates: {e}"
1074
- )
1570
+ if throw_on_error:
1571
+ raise _exceptions.PromptLayerAPIConnectionError(
1572
+ f"PromptLayer had the following error while getting all your prompt templates: {e}",
1573
+ response=None,
1574
+ body=None,
1575
+ ) from e
1576
+ logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
1577
+ return []
1075
1578
 
1076
1579
 
1580
+ @retry_on_api_error
1077
1581
  async def aget_all_prompt_templates(
1078
- page: int = 1, per_page: int = 30, api_key: str = None
1582
+ api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
1079
1583
  ) -> List[ListPromptTemplateResponse]:
1080
1584
  try:
1081
- async with httpx.AsyncClient() as client:
1585
+ params = {"page": page, "per_page": per_page}
1586
+ if label:
1587
+ params["label"] = label
1588
+ async with _make_httpx_client() as client:
1082
1589
  response = await client.get(
1083
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1590
+ f"{base_url}/prompt-templates",
1084
1591
  headers={"X-API-KEY": api_key},
1085
- params={"page": page, "per_page": per_page},
1592
+ params=params,
1086
1593
  )
1594
+
1087
1595
  if response.status_code != 200:
1088
- raise_on_bad_response(
1089
- response,
1090
- "PromptLayer had the following error while getting all your prompt templates",
1091
- )
1596
+ if throw_on_error:
1597
+ raise_on_bad_response(
1598
+ response,
1599
+ "PromptLayer had the following error while getting all your prompt templates",
1600
+ )
1601
+ else:
1602
+ warn_on_bad_response(
1603
+ response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
1604
+ )
1605
+ return []
1092
1606
  items = response.json().get("items", [])
1093
1607
  return items
1094
1608
  except httpx.RequestError as e:
1095
- raise Exception(
1096
- f"PromptLayer had the following error while getting all your prompt templates: {str(e)}"
1097
- ) from e
1098
-
1099
-
1100
- def openai_stream_chat(results: list):
1101
- from openai.types.chat import (
1102
- ChatCompletion,
1103
- ChatCompletionChunk,
1104
- ChatCompletionMessage,
1105
- ChatCompletionMessageToolCall,
1106
- )
1107
- from openai.types.chat.chat_completion import Choice
1108
- from openai.types.chat.chat_completion_message_tool_call import Function
1109
-
1110
- chat_completion_chunks: List[ChatCompletionChunk] = results
1111
- response: ChatCompletion = ChatCompletion(
1112
- id="",
1113
- object="chat.completion",
1114
- choices=[
1115
- Choice(
1116
- finish_reason="stop",
1117
- index=0,
1118
- message=ChatCompletionMessage(role="assistant"),
1119
- )
1120
- ],
1121
- created=0,
1122
- model="",
1123
- )
1124
- last_result = chat_completion_chunks[-1]
1125
- response.id = last_result.id
1126
- response.created = last_result.created
1127
- response.model = last_result.model
1128
- response.system_fingerprint = last_result.system_fingerprint
1129
- response.usage = last_result.usage
1130
- content = ""
1131
- tool_calls: Union[List[ChatCompletionMessageToolCall], None] = None
1132
- for result in chat_completion_chunks:
1133
- choices = result.choices
1134
- if len(choices) == 0:
1135
- continue
1136
- if choices[0].delta.content:
1137
- content = f"{content}{result.choices[0].delta.content}"
1138
-
1139
- delta = choices[0].delta
1140
- if delta.tool_calls:
1141
- tool_calls = tool_calls or []
1142
- last_tool_call = None
1143
- if len(tool_calls) > 0:
1144
- last_tool_call = tool_calls[-1]
1145
- tool_call = delta.tool_calls[0]
1146
- if not tool_call.function:
1147
- continue
1148
- if not last_tool_call or tool_call.id:
1149
- tool_calls.append(
1150
- ChatCompletionMessageToolCall(
1151
- id=tool_call.id or "",
1152
- function=Function(
1153
- name=tool_call.function.name or "",
1154
- arguments=tool_call.function.arguments or "",
1155
- ),
1156
- type=tool_call.type or "function",
1157
- )
1158
- )
1159
- continue
1160
- last_tool_call.function.name = (
1161
- f"{last_tool_call.function.name}{tool_call.function.name or ''}"
1162
- )
1163
- last_tool_call.function.arguments = f"{last_tool_call.function.arguments}{tool_call.function.arguments or ''}"
1164
-
1165
- response.choices[0].message.content = content
1166
- response.choices[0].message.tool_calls = tool_calls
1167
- return response
1168
-
1169
-
1170
- async def aopenai_stream_chat(generator: AsyncIterable[Any]) -> Any:
1171
- from openai.types.chat import (
1172
- ChatCompletion,
1173
- ChatCompletionChunk,
1174
- ChatCompletionMessage,
1175
- ChatCompletionMessageToolCall,
1176
- )
1177
- from openai.types.chat.chat_completion import Choice
1178
- from openai.types.chat.chat_completion_message_tool_call import Function
1179
-
1180
- chat_completion_chunks: List[ChatCompletionChunk] = []
1181
- response: ChatCompletion = ChatCompletion(
1182
- id="",
1183
- object="chat.completion",
1184
- choices=[
1185
- Choice(
1186
- finish_reason="stop",
1187
- index=0,
1188
- message=ChatCompletionMessage(role="assistant"),
1189
- )
1190
- ],
1191
- created=0,
1192
- model="",
1193
- )
1194
- content = ""
1195
- tool_calls: Union[List[ChatCompletionMessageToolCall], None] = None
1196
-
1197
- async for result in generator:
1198
- chat_completion_chunks.append(result)
1199
- choices = result.choices
1200
- if len(choices) == 0:
1201
- continue
1202
- if choices[0].delta.content:
1203
- content = f"{content}{choices[0].delta.content}"
1204
-
1205
- delta = choices[0].delta
1206
- if delta.tool_calls:
1207
- tool_calls = tool_calls or []
1208
- last_tool_call = None
1209
- if len(tool_calls) > 0:
1210
- last_tool_call = tool_calls[-1]
1211
- tool_call = delta.tool_calls[0]
1212
- if not tool_call.function:
1213
- continue
1214
- if not last_tool_call or tool_call.id:
1215
- tool_calls.append(
1216
- ChatCompletionMessageToolCall(
1217
- id=tool_call.id or "",
1218
- function=Function(
1219
- name=tool_call.function.name or "",
1220
- arguments=tool_call.function.arguments or "",
1221
- ),
1222
- type=tool_call.type or "function",
1223
- )
1224
- )
1225
- continue
1226
- last_tool_call.function.name = (
1227
- f"{last_tool_call.function.name}{tool_call.function.name or ''}"
1228
- )
1229
- last_tool_call.function.arguments = f"{last_tool_call.function.arguments}{tool_call.function.arguments or ''}"
1230
-
1231
- # After collecting all chunks, set the response attributes
1232
- if chat_completion_chunks:
1233
- last_result = chat_completion_chunks[-1]
1234
- response.id = last_result.id
1235
- response.created = last_result.created
1236
- response.model = last_result.model
1237
- response.system_fingerprint = getattr(last_result, "system_fingerprint", None)
1238
- response.usage = last_result.usage
1239
-
1240
- response.choices[0].message.content = content
1241
- response.choices[0].message.tool_calls = tool_calls
1242
- return response
1243
-
1244
-
1245
- def openai_stream_completion(results: list):
1246
- from openai.types.completion import Completion, CompletionChoice
1247
-
1248
- completions: List[Completion] = results
1249
- last_chunk = completions[-1]
1250
- response = Completion(
1251
- id=last_chunk.id,
1252
- created=last_chunk.created,
1253
- model=last_chunk.model,
1254
- object="text_completion",
1255
- choices=[CompletionChoice(finish_reason="stop", index=0, text="")],
1256
- )
1257
- text = ""
1258
- for completion in completions:
1259
- usage = completion.usage
1260
- system_fingerprint = completion.system_fingerprint
1261
- if len(completion.choices) > 0 and completion.choices[0].text:
1262
- text = f"{text}{completion.choices[0].text}"
1263
- if usage:
1264
- response.usage = usage
1265
- if system_fingerprint:
1266
- response.system_fingerprint = system_fingerprint
1267
- response.choices[0].text = text
1268
- return response
1269
-
1270
-
1271
- async def aopenai_stream_completion(generator: AsyncIterable[Any]) -> Any:
1272
- from openai.types.completion import Completion, CompletionChoice
1273
-
1274
- completions: List[Completion] = []
1275
- text = ""
1276
- response = Completion(
1277
- id="",
1278
- created=0,
1279
- model="",
1280
- object="text_completion",
1281
- choices=[CompletionChoice(finish_reason="stop", index=0, text="")],
1282
- )
1283
-
1284
- async for completion in generator:
1285
- completions.append(completion)
1286
- usage = completion.usage
1287
- system_fingerprint = getattr(completion, "system_fingerprint", None)
1288
- if len(completion.choices) > 0 and completion.choices[0].text:
1289
- text = f"{text}{completion.choices[0].text}"
1290
- if usage:
1291
- response.usage = usage
1292
- if system_fingerprint:
1293
- response.system_fingerprint = system_fingerprint
1294
-
1295
- # After collecting all completions, set the response attributes
1296
- if completions:
1297
- last_chunk = completions[-1]
1298
- response.id = last_chunk.id
1299
- response.created = last_chunk.created
1300
- response.model = last_chunk.model
1301
-
1302
- response.choices[0].text = text
1303
- return response
1304
-
1305
-
1306
- def anthropic_stream_message(results: list):
1307
- from anthropic.types import Message, MessageStreamEvent, TextBlock, Usage
1308
-
1309
- message_stream_events: List[MessageStreamEvent] = results
1310
- response: Message = Message(
1311
- id="",
1312
- model="",
1313
- content=[],
1314
- role="assistant",
1315
- type="message",
1316
- stop_reason="stop_sequence",
1317
- stop_sequence=None,
1318
- usage=Usage(input_tokens=0, output_tokens=0),
1319
- )
1320
- content = ""
1321
- for result in message_stream_events:
1322
- if result.type == "message_start":
1323
- response = result.message
1324
- elif result.type == "content_block_delta":
1325
- if result.delta.type == "text_delta":
1326
- content = f"{content}{result.delta.text}"
1327
- elif result.type == "message_delta":
1328
- if hasattr(result, "usage"):
1329
- response.usage.output_tokens = result.usage.output_tokens
1330
- if hasattr(result.delta, "stop_reason"):
1331
- response.stop_reason = result.delta.stop_reason
1332
- response.content.append(TextBlock(type="text", text=content))
1333
- return response
1334
-
1335
-
1336
- async def aanthropic_stream_message(generator: AsyncIterable[Any]) -> Any:
1337
- from anthropic.types import Message, MessageStreamEvent, TextBlock, Usage
1338
-
1339
- message_stream_events: List[MessageStreamEvent] = []
1340
- response: Message = Message(
1341
- id="",
1342
- model="",
1343
- content=[],
1344
- role="assistant",
1345
- type="message",
1346
- stop_reason="stop_sequence",
1347
- stop_sequence=None,
1348
- usage=Usage(input_tokens=0, output_tokens=0),
1349
- )
1350
- content = ""
1351
-
1352
- async for result in generator:
1353
- message_stream_events.append(result)
1354
- if result.type == "message_start":
1355
- response = result.message
1356
- elif result.type == "content_block_delta":
1357
- if result.delta.type == "text_delta":
1358
- content = f"{content}{result.delta.text}"
1359
- elif result.type == "message_delta":
1360
- if hasattr(result, "usage"):
1361
- response.usage.output_tokens = result.usage.output_tokens
1362
- if hasattr(result.delta, "stop_reason"):
1363
- response.stop_reason = result.delta.stop_reason
1364
-
1365
- response.content.append(TextBlock(type="text", text=content))
1366
- return response
1367
-
1368
-
1369
- def anthropic_stream_completion(results: list):
1370
- from anthropic.types import Completion
1371
-
1372
- completions: List[Completion] = results
1373
- last_chunk = completions[-1]
1374
- response = Completion(
1375
- id=last_chunk.id,
1376
- completion="",
1377
- model=last_chunk.model,
1378
- stop_reason="stop",
1379
- type="completion",
1380
- )
1381
-
1382
- text = ""
1383
- for completion in completions:
1384
- text = f"{text}{completion.completion}"
1385
- response.completion = text
1386
- return response
1387
-
1388
-
1389
- async def aanthropic_stream_completion(generator: AsyncIterable[Any]) -> Any:
1390
- from anthropic.types import Completion
1391
-
1392
- completions: List[Completion] = []
1393
- text = ""
1394
- response = Completion(
1395
- id="",
1396
- completion="",
1397
- model="",
1398
- stop_reason="stop",
1399
- type="completion",
1400
- )
1401
-
1402
- async for completion in generator:
1403
- completions.append(completion)
1404
- text = f"{text}{completion.completion}"
1405
-
1406
- # After collecting all completions, set the response attributes
1407
- if completions:
1408
- last_chunk = completions[-1]
1409
- response.id = last_chunk.id
1410
- response.model = last_chunk.model
1411
-
1412
- response.completion = text
1413
- return response
1414
-
1415
-
1416
- def stream_response(
1417
- generator: Generator, after_stream: Callable, map_results: Callable
1418
- ):
1419
- data = {
1420
- "request_id": None,
1421
- "raw_response": None,
1422
- "prompt_blueprint": None,
1423
- }
1424
- results = []
1425
- for result in generator:
1426
- results.append(result)
1427
- data["raw_response"] = result
1428
- yield data
1429
- request_response = map_results(results)
1430
- response = after_stream(request_response=request_response.model_dump())
1431
- data["request_id"] = response.get("request_id")
1432
- data["prompt_blueprint"] = response.get("prompt_blueprint")
1433
- yield data
1434
-
1435
-
1436
- async def astream_response(
1437
- generator: AsyncIterable[Any],
1438
- after_stream: Callable[..., Any],
1439
- map_results: Callable[[Any], Any],
1440
- ) -> AsyncGenerator[Dict[str, Any], None]:
1441
- data = {
1442
- "request_id": None,
1443
- "raw_response": None,
1444
- "prompt_blueprint": None,
1445
- }
1446
- results = []
1447
- async for result in generator:
1448
- results.append(result)
1449
- data["raw_response"] = result
1450
- yield data
1451
-
1452
- async def async_generator_from_list(lst):
1453
- for item in lst:
1454
- yield item
1455
-
1456
- request_response = await map_results(async_generator_from_list(results))
1457
- after_stream_response = await after_stream(
1458
- request_response=request_response.model_dump()
1459
- )
1460
- data["request_id"] = after_stream_response.get("request_id")
1461
- data["prompt_blueprint"] = after_stream_response.get("prompt_blueprint")
1462
- yield data
1609
+ if throw_on_error:
1610
+ raise _exceptions.PromptLayerAPIConnectionError(
1611
+ f"PromptLayer had the following error while getting all your prompt templates: {str(e)}",
1612
+ response=None,
1613
+ body=None,
1614
+ ) from e
1615
+ logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
1616
+ return []
1463
1617
 
1464
1618
 
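Both list helpers gained an optional label filter that is added to the query string only when set, and both now degrade to an empty list instead of raising when throw_on_error is False. A sketch with hypothetical values:

from promptlayer.utils import get_all_prompt_templates

templates = get_all_prompt_templates(
    "pl_xxx",                       # api_key
    "https://api.promptlayer.com",  # base_url
    False,                          # throw_on_error: return [] on failure
    page=1,
    per_page=30,
    label="prod",                   # hypothetical release label; omitted when None
)
print(f"fetched {len(templates)} templates")
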
1465
1619
  def openai_chat_request(client, **kwargs):
@@ -1476,14 +1630,20 @@ MAP_TYPE_TO_OPENAI_FUNCTION = {
1476
1630
  }
1477
1631
 
1478
1632
 
1479
- def openai_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1633
+ def openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1480
1634
  from openai import OpenAI
1481
1635
 
1482
- client = OpenAI(base_url=kwargs.pop("base_url", None))
1483
- request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[
1484
- prompt_blueprint["prompt_template"]["type"]
1485
- ]
1486
- return request_to_make(client, **kwargs)
1636
+ client = OpenAI(**client_kwargs)
1637
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1638
+
1639
+ if api_type is None:
1640
+ api_type = "chat-completions"
1641
+
1642
+ if api_type == "chat-completions":
1643
+ request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1644
+ return request_to_make(client, **function_kwargs)
1645
+ else:
1646
+ return client.responses.create(**function_kwargs)
1487
1647
 
1488
1648
 
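openai_request now separates client construction from the call: client_kwargs goes to OpenAI(**client_kwargs), function_kwargs to the chosen method, and metadata.model.api_type routes between the Chat Completions path and client.responses.create. A minimal sketch; the blueprint literal is a hypothetical stand-in trimmed to the keys the function reads, and OPENAI_API_KEY is assumed to be set in the environment:

from promptlayer.utils import openai_request

blueprint = {
    "prompt_template": {"type": "chat"},
    "metadata": {"model": {"api_type": "chat-completions"}},
}

result = openai_request(
    blueprint,
    client_kwargs={"base_url": None},  # forwarded to OpenAI(**client_kwargs)
    function_kwargs={
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "hi"}],
    },
)
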
1489
1649
  async def aopenai_chat_request(client, **kwargs):
@@ -1500,34 +1660,45 @@ AMAP_TYPE_TO_OPENAI_FUNCTION = {
1500
1660
  }
1501
1661
 
1502
1662
 
1503
- async def aopenai_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1663
+ async def aopenai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1504
1664
  from openai import AsyncOpenAI
1505
1665
 
1506
- client = AsyncOpenAI(base_url=kwargs.pop("base_url", None))
1507
- request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[
1508
- prompt_blueprint["prompt_template"]["type"]
1509
- ]
1510
- return await request_to_make(client, **kwargs)
1666
+ client = AsyncOpenAI(**client_kwargs)
1667
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1668
+
1669
+ if api_type == "chat-completions":
1670
+ request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1671
+ return await request_to_make(client, **function_kwargs)
1672
+ else:
1673
+ return await client.responses.create(**function_kwargs)
1511
1674
 
1512
1675
 
1513
- def azure_openai_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1676
+ def azure_openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1514
1677
  from openai import AzureOpenAI
1515
1678
 
1516
- client = AzureOpenAI(azure_endpoint=kwargs.pop("base_url", None))
1517
- request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[
1518
- prompt_blueprint["prompt_template"]["type"]
1519
- ]
1520
- return request_to_make(client, **kwargs)
1679
+ client = AzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1680
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1681
+
1682
+ if api_type == "chat-completions":
1683
+ request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1684
+ return request_to_make(client, **function_kwargs)
1685
+ else:
1686
+ return client.responses.create(**function_kwargs)
1521
1687
 
1522
1688
 
1523
- async def aazure_openai_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1689
+ async def aazure_openai_request(
1690
+ prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
1691
+ ):
1524
1692
  from openai import AsyncAzureOpenAI
1525
1693
 
1526
- client = AsyncAzureOpenAI(azure_endpoint=kwargs.pop("base_url", None))
1527
- request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[
1528
- prompt_blueprint["prompt_template"]["type"]
1529
- ]
1530
- return await request_to_make(client, **kwargs)
1694
+ client = AsyncAzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1695
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1696
+
1697
+ if api_type == "chat-completions":
1698
+ request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1699
+ return await request_to_make(client, **function_kwargs)
1700
+ else:
1701
+ return await client.responses.create(**function_kwargs)
1531
1702
 
1532
1703
 
1533
1704
  def anthropic_chat_request(client, **kwargs):
@@ -1544,14 +1715,12 @@ MAP_TYPE_TO_ANTHROPIC_FUNCTION = {
1544
1715
  }
1545
1716
 
1546
1717
 
1547
- def anthropic_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1718
+ def anthropic_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1548
1719
  from anthropic import Anthropic
1549
1720
 
1550
- client = Anthropic(base_url=kwargs.pop("base_url", None))
1551
- request_to_make = MAP_TYPE_TO_ANTHROPIC_FUNCTION[
1552
- prompt_blueprint["prompt_template"]["type"]
1553
- ]
1554
- return request_to_make(client, **kwargs)
1721
+ client = Anthropic(**client_kwargs)
1722
+ request_to_make = MAP_TYPE_TO_ANTHROPIC_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1723
+ return request_to_make(client, **function_kwargs)
1555
1724
 
1556
1725
 
1557
1726
  async def aanthropic_chat_request(client, **kwargs):
@@ -1568,14 +1737,12 @@ AMAP_TYPE_TO_ANTHROPIC_FUNCTION = {
1568
1737
  }
1569
1738
 
1570
1739
 
1571
- async def aanthropic_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
1740
+ async def aanthropic_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1572
1741
  from anthropic import AsyncAnthropic
1573
1742
 
1574
- client = AsyncAnthropic(base_url=kwargs.pop("base_url", None))
1575
- request_to_make = AMAP_TYPE_TO_ANTHROPIC_FUNCTION[
1576
- prompt_blueprint["prompt_template"]["type"]
1577
- ]
1578
- return await request_to_make(client, **kwargs)
1743
+ client = AsyncAnthropic(**client_kwargs)
1744
+ request_to_make = AMAP_TYPE_TO_ANTHROPIC_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1745
+ return await request_to_make(client, **function_kwargs)
1579
1746
 
1580
1747
 
1581
1748
  # do not remove! This is used in the langchain integration.
@@ -1583,214 +1750,336 @@ def get_api_key():
1583
1750
  # raise an error if the api key is not set
1584
1751
  api_key = os.environ.get("PROMPTLAYER_API_KEY")
1585
1752
  if not api_key:
1586
- raise Exception(
1587
- "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>' "
1753
+ raise _exceptions.PromptLayerAuthenticationError(
1754
+ "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>'",
1755
+ response=None,
1756
+ body=None,
1588
1757
  )
1589
1758
  return api_key
1590
1759
 
1591
1760
 
1592
- def util_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
1761
+ @retry_on_api_error
1762
+ def util_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
1593
1763
  try:
1594
1764
  response = requests.post(
1595
- f"{URL_API_PROMPTLAYER}/log-request",
1765
+ f"{base_url}/log-request",
1596
1766
  headers={"X-API-KEY": api_key},
1597
1767
  json=kwargs,
1598
1768
  )
1599
1769
  if response.status_code != 201:
1600
- warn_on_bad_response(
1601
- response,
1602
- "WARNING: While logging your request PromptLayer had the following error",
1603
- )
1604
- return None
1770
+ if throw_on_error:
1771
+ raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
1772
+ else:
1773
+ warn_on_bad_response(
1774
+ response,
1775
+ "WARNING: While logging your request PromptLayer had the following error",
1776
+ )
1777
+ return None
1605
1778
  return response.json()
1606
1779
  except Exception as e:
1607
- print(
1608
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
1609
- file=sys.stderr,
1610
- )
1780
+ if throw_on_error:
1781
+ raise _exceptions.PromptLayerAPIError(
1782
+ f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
1783
+ ) from e
1784
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
1611
1785
  return None
1612
1786
 
1613
1787
 
1614
- async def autil_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
1788
+ @retry_on_api_error
1789
+ async def autil_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
1615
1790
  try:
1616
- async with httpx.AsyncClient() as client:
1791
+ async with _make_httpx_client() as client:
1617
1792
  response = await client.post(
1618
- f"{URL_API_PROMPTLAYER}/log-request",
1793
+ f"{base_url}/log-request",
1619
1794
  headers={"X-API-KEY": api_key},
1620
1795
  json=kwargs,
1621
1796
  )
1622
1797
  if response.status_code != 201:
1623
- warn_on_bad_response(
1624
- response,
1625
- "WARNING: While logging your request PromptLayer had the following error",
1626
- )
1627
- return None
1798
+ if throw_on_error:
1799
+ raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
1800
+ else:
1801
+ warn_on_bad_response(
1802
+ response,
1803
+ "WARNING: While logging your request PromptLayer had the following error",
1804
+ )
1805
+ return None
1628
1806
  return response.json()
1629
1807
  except Exception as e:
1630
- print(
1631
- f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
1632
- file=sys.stderr,
1633
- )
1808
+ if throw_on_error:
1809
+ raise _exceptions.PromptLayerAPIError(
1810
+ f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
1811
+ ) from e
1812
+ logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
1634
1813
  return None
1635
1814
 
1636
1815
 
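util_log_request forwards every remaining keyword argument verbatim as the /log-request JSON body and treats anything other than a 201 as failure. A sketch with a hypothetical payload; the fields the endpoint accepts are not visible in this diff:

from promptlayer.utils import util_log_request

log = util_log_request(
    "pl_xxx",                       # api_key, also sent as the X-API-KEY header
    "https://api.promptlayer.com",  # base_url
    False,                          # throw_on_error: warn and return None on failure
    provider="openai",              # everything from here down is the JSON body (hypothetical fields)
    model="gpt-4o-mini",
    input={"messages": [{"role": "user", "content": "hi"}]},
    output={"content": "hello"},
)
print(log)
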
1637
- def mistral_request(
1638
- prompt_blueprint: GetPromptTemplateResponse,
1639
- **kwargs,
1640
- ):
1816
+ def mistral_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1641
1817
  from mistralai import Mistral
1642
1818
 
1643
- client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
1644
- if "stream" in kwargs and kwargs["stream"]:
1645
- kwargs.pop("stream")
1646
- return client.chat.stream(**kwargs)
1647
- if "stream" in kwargs:
1648
- kwargs.pop("stream")
1649
- return client.chat.complete(**kwargs)
1819
+ client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"), client=_make_simple_httpx_client())
1820
+ if "stream" in function_kwargs and function_kwargs["stream"]:
1821
+ function_kwargs.pop("stream")
1822
+ return client.chat.stream(**function_kwargs)
1823
+ if "stream" in function_kwargs:
1824
+ function_kwargs.pop("stream")
1825
+ return client.chat.complete(**function_kwargs)
1650
1826
 
1651
1827
 
1652
1828
  async def amistral_request(
1653
1829
  prompt_blueprint: GetPromptTemplateResponse,
1654
- **kwargs,
1830
+ _: dict,
1831
+ function_kwargs: dict,
1655
1832
  ):
1656
1833
  from mistralai import Mistral
1657
1834
 
1658
- client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
1659
- if "stream" in kwargs and kwargs["stream"]:
1660
- return await client.chat.stream_async(**kwargs)
1661
- return await client.chat.complete_async(**kwargs)
1835
+ client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"), async_client=_make_httpx_client())
1836
+ if "stream" in function_kwargs and function_kwargs["stream"]:
1837
+ return await client.chat.stream_async(**function_kwargs)
1838
+ return await client.chat.complete_async(**function_kwargs)
1839
+
1840
+
1841
+ class _GoogleStreamWrapper:
1842
+ """Wrapper to keep Google client alive during streaming."""
1843
+
1844
+ def __init__(self, stream_generator, client):
1845
+ self._stream = stream_generator
1846
+ self._client = client # Keep client alive
1847
+
1848
+ def __iter__(self):
1849
+ return self._stream.__iter__()
1850
+
1851
+ def __next__(self):
1852
+ return next(self._stream)
1853
+
1854
+ def __aiter__(self):
1855
+ return self._stream.__aiter__()
1856
+
1857
+ async def __anext__(self):
1858
+ return await self._stream.__anext__()
1859
+
1860
+
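Without this wrapper, the genai.Client constructed inside google_request would become garbage as soon as the function returned, and its transport could be closed while the caller is still consuming the stream; holding the client on the wrapper pins it for the stream's lifetime. The same pattern in miniature, independent of the Google SDK:

class KeepAliveStream:
    """Pin an owner object for as long as a stream is being iterated."""

    def __init__(self, stream, owner):
        self._stream = stream
        self._owner = owner  # the reference alone is the point: it defers garbage collection

    def __iter__(self):
        return iter(self._stream)

for chunk in KeepAliveStream(iter(["a", "b"]), owner=object()):
    print(chunk)
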
1861
+ def google_chat_request(client, **kwargs):
1862
+ from google.genai.chats import Content
1863
+
1864
+ stream = kwargs.pop("stream", False)
1865
+ model = kwargs.get("model", "gemini-2.0-flash")
1866
+ history = [Content(**item) for item in kwargs.get("history", [])]
1867
+ generation_config = kwargs.get("generation_config", {})
1868
+ chat = client.chats.create(model=model, history=history, config=generation_config)
1869
+ last_message = history[-1].parts if history else ""
1870
+ if stream:
1871
+ stream_gen = chat.send_message_stream(message=last_message)
1872
+ return _GoogleStreamWrapper(stream_gen, client)
1873
+ return chat.send_message(message=last_message)
1874
+
1875
+
1876
+ def google_completions_request(client, **kwargs):
1877
+ config = kwargs.pop("generation_config", {})
1878
+ model = kwargs.get("model", "gemini-2.0-flash")
1879
+ contents = kwargs.get("contents", [])
1880
+ stream = kwargs.pop("stream", False)
1881
+ if stream:
1882
+ stream_gen = client.models.generate_content_stream(model=model, contents=contents, config=config)
1883
+ return _GoogleStreamWrapper(stream_gen, client)
1884
+ return client.models.generate_content(model=model, contents=contents, config=config)
1885
+
1886
+
1887
+ MAP_TYPE_TO_GOOGLE_FUNCTION = {
1888
+ "chat": google_chat_request,
1889
+ "completion": google_completions_request,
1890
+ }
1891
+
1662
1892
 
1893
+ def google_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1894
+ from google import genai
1663
1895
 
1664
- def mistral_stream_chat(results: list):
1665
- from openai.types.chat import (
1666
- ChatCompletion,
1667
- ChatCompletionMessage,
1668
- ChatCompletionMessageToolCall,
1896
+ if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
1897
+ client = genai.Client(
1898
+ vertexai=True,
1899
+ project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
1900
+ location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
1901
+ )
1902
+ else:
1903
+ client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"))
1904
+ request_to_make = MAP_TYPE_TO_GOOGLE_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1905
+ return request_to_make(client, **function_kwargs)
1906
+
1907
+
1908
+ async def agoogle_chat_request(client, **kwargs):
1909
+ from google.genai.chats import Content
1910
+
1911
+ stream = kwargs.pop("stream", False)
1912
+ model = kwargs.get("model", "gemini-2.0-flash")
1913
+ history = [Content(**item) for item in kwargs.get("history", [])]
1914
+ generation_config = kwargs.get("generation_config", {})
1915
+ chat = client.aio.chats.create(model=model, history=history, config=generation_config)
1916
+ last_message = history[-1].parts[0] if history else ""
1917
+ if stream:
1918
+ stream_gen = await chat.send_message_stream(message=last_message)
1919
+ return _GoogleStreamWrapper(stream_gen, client)
1920
+ return await chat.send_message(message=last_message)
1921
+
1922
+
1923
+ async def agoogle_completions_request(client, **kwargs):
1924
+ config = kwargs.pop("generation_config", {})
1925
+ model = kwargs.get("model", "gemini-2.0-flash")
1926
+ contents = kwargs.get("contents", [])
1927
+ stream = kwargs.pop("stream", False)
1928
+ if stream:
1929
+ stream_gen = await client.aio.models.generate_content_stream(model=model, contents=contents, config=config)
1930
+ return _GoogleStreamWrapper(stream_gen, client)
1931
+ return await client.aio.models.generate_content(model=model, contents=contents, config=config)
1932
+
1933
+
1934
+ AMAP_TYPE_TO_GOOGLE_FUNCTION = {
1935
+ "chat": agoogle_chat_request,
1936
+ "completion": agoogle_completions_request,
1937
+ }
1938
+
1939
+
1940
+ async def agoogle_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1941
+ from google import genai
1942
+
1943
+ if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
1944
+ client = genai.Client(
1945
+ vertexai=True,
1946
+ project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
1947
+ location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
1948
+ )
1949
+ else:
1950
+ client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"))
1951
+ request_to_make = AMAP_TYPE_TO_GOOGLE_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
1952
+ return await request_to_make(client, **function_kwargs)
1953
+
1954
+
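Client construction here is driven entirely by environment variables: GOOGLE_GENAI_USE_VERTEXAI=true selects Vertex AI with GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION, otherwise GEMINI_API_KEY (or GOOGLE_API_KEY) is used. A sketch of the API-key path with hypothetical values; the blueprint is trimmed to the single key google_request reads:

import os

from promptlayer.utils import google_request

os.environ.setdefault("GEMINI_API_KEY", "hypothetical-key")

blueprint = {"prompt_template": {"type": "completion"}}
result = google_request(
    blueprint,
    client_kwargs={},  # accepted but currently unused by google_request
    function_kwargs={
        "model": "gemini-2.0-flash",
        "contents": ["Say hi"],
        "stream": False,
    },
)
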
1955
+ def vertexai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1956
+ if "gemini" in prompt_blueprint["metadata"]["model"]["name"]:
1957
+ return google_request(
1958
+ prompt_blueprint=prompt_blueprint,
1959
+ client_kwargs=client_kwargs,
1960
+ function_kwargs=function_kwargs,
1961
+ )
1962
+
1963
+ if "claude" in prompt_blueprint["metadata"]["model"]["name"]:
1964
+ from anthropic import AnthropicVertex
1965
+
1966
+ client = AnthropicVertex(**client_kwargs)
1967
+ if prompt_blueprint["prompt_template"]["type"] == "chat":
1968
+ return anthropic_chat_request(client=client, **function_kwargs)
1969
+ raise NotImplementedError(
1970
+ f"Unsupported prompt template type {prompt_blueprint['prompt_template']['type']}' for Anthropic Vertex AI"
1971
+ )
1972
+
1973
+ raise NotImplementedError(
1974
+ f"Vertex AI request for model {prompt_blueprint['metadata']['model']['name']} is not implemented yet."
1669
1975
  )
1670
- from openai.types.chat.chat_completion import Choice
1671
- from openai.types.chat.chat_completion_message_tool_call import Function
1672
-
1673
- last_result = results[-1]
1674
- response = ChatCompletion(
1675
- id=last_result.data.id,
1676
- object="chat.completion",
1677
- choices=[
1678
- Choice(
1679
- finish_reason=last_result.data.choices[0].finish_reason or "stop",
1680
- index=0,
1681
- message=ChatCompletionMessage(role="assistant"),
1682
- )
1683
- ],
1684
- created=last_result.data.created,
1685
- model=last_result.data.model,
1976
+
1977
+
1978
+ async def avertexai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
1979
+ if "gemini" in prompt_blueprint["metadata"]["model"]["name"]:
1980
+ return await agoogle_request(
1981
+ prompt_blueprint=prompt_blueprint,
1982
+ client_kwargs=client_kwargs,
1983
+ function_kwargs=function_kwargs,
1984
+ )
1985
+
1986
+ if "claude" in prompt_blueprint["metadata"]["model"]["name"]:
1987
+ from anthropic import AsyncAnthropicVertex
1988
+
1989
+ client = AsyncAnthropicVertex(**client_kwargs)
1990
+ if prompt_blueprint["prompt_template"]["type"] == "chat":
1991
+ return await aanthropic_chat_request(client=client, **function_kwargs)
1992
+ raise NotImplementedError(
1993
+ f"Unsupported prompt template type {prompt_blueprint['prompt_template']['type']}' for Anthropic Vertex AI"
1994
+ )
1995
+
1996
+ raise NotImplementedError(
1997
+ f"Vertex AI request for model {prompt_blueprint['metadata']['model']['name']} is not implemented yet."
1686
1998
  )
1687
1999
 
1688
- content = ""
1689
- tool_calls = None
1690
-
1691
- for result in results:
1692
- choices = result.data.choices
1693
- if len(choices) == 0:
1694
- continue
1695
-
1696
- delta = choices[0].delta
1697
- if delta.content is not None:
1698
- content = f"{content}{delta.content}"
1699
-
1700
- if delta.tool_calls:
1701
- tool_calls = tool_calls or []
1702
- for tool_call in delta.tool_calls:
1703
- if len(tool_calls) == 0 or tool_call.id:
1704
- tool_calls.append(
1705
- ChatCompletionMessageToolCall(
1706
- id=tool_call.id or "",
1707
- function=Function(
1708
- name=tool_call.function.name,
1709
- arguments=tool_call.function.arguments,
1710
- ),
1711
- type="function",
1712
- )
1713
- )
1714
- else:
1715
- last_tool_call = tool_calls[-1]
1716
- if tool_call.function.name:
1717
- last_tool_call.function.name = (
1718
- f"{last_tool_call.function.name}{tool_call.function.name}"
1719
- )
1720
- if tool_call.function.arguments:
1721
- last_tool_call.function.arguments = f"{last_tool_call.function.arguments}{tool_call.function.arguments}"
1722
-
1723
- response.choices[0].message.content = content
1724
- response.choices[0].message.tool_calls = tool_calls
1725
- response.usage = last_result.data.usage
1726
- return response
1727
-
1728
-
1729
- async def amistral_stream_chat(generator: AsyncIterable[Any]) -> Any:
1730
- from openai.types.chat import (
1731
- ChatCompletion,
1732
- ChatCompletionMessage,
1733
- ChatCompletionMessageToolCall,
2000
+
2001
+ def amazon_bedrock_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
2002
+ import boto3
2003
+
2004
+ bedrock_client = boto3.client(
2005
+ "bedrock-runtime",
2006
+ aws_access_key_id=function_kwargs.pop("aws_access_key", None),
2007
+ aws_secret_access_key=function_kwargs.pop("aws_secret_key", None),
2008
+ region_name=function_kwargs.pop("aws_region", "us-east-1"),
1734
2009
  )
1735
- from openai.types.chat.chat_completion import Choice
1736
- from openai.types.chat.chat_completion_message_tool_call import Function
1737
-
1738
- completion_chunks = []
1739
- response = ChatCompletion(
1740
- id="",
1741
- object="chat.completion",
1742
- choices=[
1743
- Choice(
1744
- finish_reason="stop",
1745
- index=0,
1746
- message=ChatCompletionMessage(role="assistant"),
1747
- )
1748
- ],
1749
- created=0,
1750
- model="",
2010
+
2011
+ stream = function_kwargs.pop("stream", False)
2012
+
2013
+ if stream:
2014
+ return bedrock_client.converse_stream(**function_kwargs)
2015
+ else:
2016
+ return bedrock_client.converse(**function_kwargs)
2017
+
2018
+
2019
+ async def aamazon_bedrock_request(
2020
+ prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
2021
+ ):
2022
+ import aioboto3
2023
+
2024
+ aws_access_key = function_kwargs.pop("aws_access_key", None)
2025
+ aws_secret_key = function_kwargs.pop("aws_secret_key", None)
2026
+ aws_region = function_kwargs.pop("aws_region", "us-east-1")
2027
+
2028
+ session_kwargs = {}
2029
+ if aws_access_key:
2030
+ session_kwargs["aws_access_key_id"] = aws_access_key
2031
+ if aws_secret_key:
2032
+ session_kwargs["aws_secret_access_key"] = aws_secret_key
2033
+ if aws_region:
2034
+ session_kwargs["region_name"] = aws_region
2035
+
2036
+ stream = function_kwargs.pop("stream", False)
2037
+ session = aioboto3.Session()
2038
+
2039
+ async with session.client("bedrock-runtime", **session_kwargs) as client:
2040
+ if stream:
2041
+ return await client.converse_stream(**function_kwargs)
2042
+ else:
2043
+ return await client.converse(**function_kwargs)
2044
+
2045
+
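The Bedrock helpers pop AWS credentials out of function_kwargs rather than client_kwargs before forwarding the rest to converse/converse_stream; when the keys are absent, boto3 receives None and falls back to its standard credential chain. A sketch with a hypothetical Bedrock model id:

from promptlayer.utils import amazon_bedrock_request

blueprint = {"prompt_template": {"type": "chat"}}  # not consulted by this helper

response = amazon_bedrock_request(
    blueprint,
    client_kwargs={},
    function_kwargs={
        "aws_region": "us-east-1",
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # hypothetical model id
        "messages": [{"role": "user", "content": [{"text": "hi"}]}],
        "stream": False,
    },
)
print(response["output"]["message"]["content"][0]["text"])
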
2046
+ def anthropic_bedrock_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
2047
+ from anthropic import AnthropicBedrock
2048
+
2049
+ client = AnthropicBedrock(
2050
+ aws_access_key=function_kwargs.pop("aws_access_key", None),
2051
+ aws_secret_key=function_kwargs.pop("aws_secret_key", None),
2052
+ aws_region=function_kwargs.pop("aws_region", None),
2053
+ aws_session_token=function_kwargs.pop("aws_session_token", None),
2054
+ base_url=function_kwargs.pop("base_url", None),
2055
+ **client_kwargs,
2056
+ )
2057
+ if prompt_blueprint["prompt_template"]["type"] == "chat":
2058
+ return anthropic_chat_request(client=client, **function_kwargs)
2059
+ elif prompt_blueprint["prompt_template"]["type"] == "completion":
2060
+ return anthropic_completions_request(client=client, **function_kwargs)
2061
+ raise NotImplementedError(
2062
+ f"Unsupported prompt template type {prompt_blueprint['prompt_template']['type']}' for Anthropic Bedrock"
2063
+ )
2064
+
2065
+
2066
+ async def aanthropic_bedrock_request(
2067
+ prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
2068
+ ):
2069
+ from anthropic import AsyncAnthropicBedrock
2070
+
2071
+ client = AsyncAnthropicBedrock(
2072
+ aws_access_key=function_kwargs.pop("aws_access_key", None),
2073
+ aws_secret_key=function_kwargs.pop("aws_secret_key", None),
2074
+ aws_region=function_kwargs.pop("aws_region", None),
2075
+ aws_session_token=function_kwargs.pop("aws_session_token", None),
2076
+ base_url=function_kwargs.pop("base_url", None),
2077
+ **client_kwargs,
2078
+ )
2079
+ if prompt_blueprint["prompt_template"]["type"] == "chat":
2080
+ return await aanthropic_chat_request(client=client, **function_kwargs)
2081
+ elif prompt_blueprint["prompt_template"]["type"] == "completion":
2082
+ return await aanthropic_completions_request(client=client, **function_kwargs)
2083
+ raise NotImplementedError(
2084
+ f"Unsupported prompt template type {prompt_blueprint['prompt_template']['type']}' for Anthropic Bedrock"
1751
2085
  )
1752
- content = ""
1753
- tool_calls = None
1754
-
1755
- async for result in generator:
1756
- completion_chunks.append(result)
1757
- choices = result.data.choices
1758
- if len(choices) == 0:
1759
- continue
1760
- delta = choices[0].delta
1761
- if delta.content is not None:
1762
- content = f"{content}{delta.content}"
1763
-
1764
- if delta.tool_calls:
1765
- tool_calls = tool_calls or []
1766
- for tool_call in delta.tool_calls:
1767
- if len(tool_calls) == 0 or tool_call.id:
1768
- tool_calls.append(
1769
- ChatCompletionMessageToolCall(
1770
- id=tool_call.id or "",
1771
- function=Function(
1772
- name=tool_call.function.name,
1773
- arguments=tool_call.function.arguments,
1774
- ),
1775
- type="function",
1776
- )
1777
- )
1778
- else:
1779
- last_tool_call = tool_calls[-1]
1780
- if tool_call.function.name:
1781
- last_tool_call.function.name = (
1782
- f"{last_tool_call.function.name}{tool_call.function.name}"
1783
- )
1784
- if tool_call.function.arguments:
1785
- last_tool_call.function.arguments = f"{last_tool_call.function.arguments}{tool_call.function.arguments}"
1786
-
1787
- if completion_chunks:
1788
- last_result = completion_chunks[-1]
1789
- response.id = last_result.data.id
1790
- response.created = last_result.data.created
1791
- response.model = last_result.data.model
1792
- response.usage = last_result.data.usage
1793
-
1794
- response.choices[0].message.content = content
1795
- response.choices[0].message.tool_calls = tool_calls
1796
- return response