promptlayer 1.0.16__py3-none-any.whl → 1.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptlayer/utils.py CHANGED
@@ -3,16 +3,39 @@ import contextvars
  import datetime
  import functools
  import json
+ import logging
  import os
- import sys
  import types
+ from contextlib import asynccontextmanager
  from copy import deepcopy
  from enum import Enum
- from typing import Callable, Generator, List, Union
+ from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+ from urllib.parse import quote
+ from uuid import uuid4

+ import httpx
  import requests
+ import urllib3
+ import urllib3.util
+ from ably import AblyRealtime
+ from ably.types.message import Message
+ from centrifuge import (
+     Client,
+     PublicationContext,
+     SubscriptionEventHandler,
+     SubscriptionState,
+ )
  from opentelemetry import context, trace
+ from tenacity import (
+     before_sleep_log,
+     retry,
+     retry_if_exception,
+     stop_after_attempt,
+     wait_exponential,
+ )

+ from promptlayer import exceptions as _exceptions
+ from promptlayer.types import RequestLog
  from promptlayer.types.prompt_template import (
      GetPromptTemplate,
      GetPromptTemplateResponse,
@@ -21,12 +44,341 @@ from promptlayer.types.prompt_template import (
      PublishPromptTemplateResponse,
  )

- URL_API_PROMPTLAYER = os.environ.setdefault(
-     "URL_API_PROMPTLAYER", "https://api.promptlayer.com"
+ # Configuration
+ RERAISE_ORIGINAL_EXCEPTION = os.getenv("PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION", "False").lower() == "true"
+ RAISE_FOR_STATUS = os.getenv("PROMPTLAYER_RAISE_FOR_STATUS", "False").lower() == "true"
+ DEFAULT_HTTP_TIMEOUT = 5
+
+ WORKFLOW_RUN_URL_TEMPLATE = "{base_url}/workflows/{workflow_id}/run"
+ WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE = "workflows:{workflow_id}:run:{channel_name_suffix}"
+ SET_WORKFLOW_COMPLETE_MESSAGE = "SET_WORKFLOW_COMPLETE"
+ WS_TOKEN_REQUEST_LIBRARY_URL = (
+     f"{os.getenv('PROMPTLAYER_BASE_URL', 'https://api.promptlayer.com')}/ws-token-request-library"
  )
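The new version moves configuration into environment variables; the module-level flags above are resolved once, at import time. A sketch of overriding them, with illustrative values only:

import os

# Set these before importing promptlayer.utils; the flags above are read at import time.
os.environ["PROMPTLAYER_BASE_URL"] = "https://api.promptlayer.com"   # API host
os.environ["PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION"] = "true"       # re-raise instead of wrapping
os.environ["PROMPTLAYER_HTTP_TIMEOUT"] = "10"                        # seconds; read per call by _get_http_timeout() below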


+ logger = logging.getLogger(__name__)
+
+
+ class FinalOutputCode(Enum):
+     OK = "OK"
+     EXCEEDS_SIZE_LIMIT = "EXCEEDS_SIZE_LIMIT"
+
+
+ def should_retry_error(exception):
+     """Check if an exception should trigger a retry.
+
+     Only retries on server errors (5xx) and rate limits (429).
+     """
+     if hasattr(exception, "response"):
+         response = exception.response
+         if hasattr(response, "status_code"):
+             status_code = response.status_code
+             if status_code >= 500 or status_code == 429:
+                 return True
+
+     if isinstance(exception, (_exceptions.PromptLayerInternalServerError, _exceptions.PromptLayerRateLimitError)):
+         return True
+
+     return False
+
+
+ def retry_on_api_error(func):
+     return retry(
+         retry=retry_if_exception(should_retry_error),
+         stop=stop_after_attempt(4),  # 4 total attempts (1 initial + 3 retries)
+         wait=wait_exponential(multiplier=2, max=15),  # 2s, 4s, 8s
+         before_sleep=before_sleep_log(logger, logging.WARNING),
+         reraise=True,
+     )(func)
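A minimal sketch of what the decorator buys you, using a hypothetical flaky function whose exception carries a retryable status code:

class _FakeResponse:
    status_code = 503  # >= 500, so should_retry_error returns True

class _FakeServerError(Exception):
    response = _FakeResponse()

@retry_on_api_error
def flaky_call():  # hypothetical; tenacity re-invokes it after ~2s, 4s, 8s waits
    raise _FakeServerError("server busy")

# flaky_call() would raise _FakeServerError after 4 failed attempts (reraise=True).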
+
+
+ def _get_http_timeout():
+     try:
+         return float(os.getenv("PROMPTLAYER_HTTP_TIMEOUT", DEFAULT_HTTP_TIMEOUT))
+     except (ValueError, TypeError):
+         return DEFAULT_HTTP_TIMEOUT
+
+
+ def _make_httpx_client():
+     return httpx.AsyncClient(timeout=_get_http_timeout())
+
+
+ def _make_simple_httpx_client():
+     return httpx.Client(timeout=_get_http_timeout())
+
+
+ def _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name):
+     # This is backward compatibility code
+     if (workflow_id_or_name := workflow_name if workflow_id_or_name is None else workflow_id_or_name) is None:
+         raise ValueError('Either "workflow_id_or_name" or "workflow_name" must be provided')
+
+     return workflow_id_or_name
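The walrus expression collapses the two parameters into one; its behavior, sketched:

assert _get_workflow_workflow_id_or_name(42, None) == 42                # id wins
assert _get_workflow_workflow_id_or_name(None, "my-flow") == "my-flow"  # deprecated name still accepted
try:
    _get_workflow_workflow_id_or_name(None, None)
except ValueError as e:
    print(e)  # Either "workflow_id_or_name" or "workflow_name" must be provided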
+
+
+ async def _get_final_output(
+     base_url: str, execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]
+ ) -> Dict[str, Any]:
+     async with httpx.AsyncClient() as client:
+         response = await client.get(
+             f"{base_url}/workflow-version-execution-results",
+             headers=headers,
+             params={"workflow_version_execution_id": execution_id, "return_all_outputs": return_all_outputs},
+         )
+         if response.status_code != 200:
+             raise_on_bad_response(response, "PromptLayer had the following error while getting workflow results")
+         return response.json()
+
+
+ # TODO(dmu) MEDIUM: Consider putting all these functions into a class, so we do not have to pass
+ # `authorization_headers` into each function
+ async def _resolve_workflow_id(base_url: str, workflow_id_or_name: Union[int, str], headers):
+     if isinstance(workflow_id_or_name, int):
+         return workflow_id_or_name
+
+     # TODO(dmu) LOW: Should we warn user here to avoid using workflow names in favor of workflow id?
+     async with _make_httpx_client() as client:
+         # TODO(dmu) MEDIUM: Generalize the way we make async calls to PromptLayer API and reuse it everywhere
+         response = await client.get(f"{base_url}/workflows/{workflow_id_or_name}", headers=headers)
+         if response.status_code != 200:
+             raise_on_bad_response(response, "PromptLayer had the following error while resolving workflow")
+
+         return response.json()["workflow"]["id"]
+
+
+ async def _get_ably_token(base_url: str, channel_name, authentication_headers):
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/ws-token-request-library",
+                 headers=authentication_headers,
+                 params={"capability": channel_name},
+             )
+             if response.status_code != 201:
+                 raise_on_bad_response(
+                     response,
+                     "PromptLayer had the following error while getting WebSocket token",
+                 )
+             return response.json()
+     except Exception as ex:
+         error_message = f"Failed to get WebSocket token: {ex}"
+         logger.exception(error_message)
+         if RERAISE_ORIGINAL_EXCEPTION:
+             raise
+         else:
+             raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex
+
+
+ def _make_message_listener(base_url: str, results_future, execution_id_future, return_all_outputs, headers):
+     # We need this function to be mocked by unittests
+     async def message_listener(message: Message):
+         if results_future.cancelled() or message.name != SET_WORKFLOW_COMPLETE_MESSAGE:
+             return  # TODO(dmu) LOW: Do we really need this check?
+
+         execution_id = await asyncio.wait_for(execution_id_future, _get_http_timeout() * 1.1)
+         message_data = json.loads(message.data)
+         if message_data["workflow_version_execution_id"] != execution_id:
+             return
+
+         if (result_code := message_data.get("result_code")) in (FinalOutputCode.OK.value, None):
+             results = message_data["final_output"]
+         elif result_code == FinalOutputCode.EXCEEDS_SIZE_LIMIT.value:
+             results = await _get_final_output(base_url, execution_id, return_all_outputs, headers=headers)
+         else:
+             raise NotImplementedError(f"Unsupported final output code: {result_code}")
+
+         results_future.set_result(results)
+
+     return message_listener
+
+
+ async def _subscribe_to_workflow_completion_channel(
+     base_url: str, channel, execution_id_future, return_all_outputs, headers
+ ):
+     results_future = asyncio.Future()
+     message_listener = _make_message_listener(
+         base_url, results_future, execution_id_future, return_all_outputs, headers
+     )
+     await channel.subscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
+     return results_future, message_listener
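The listener/future pair is a small handshake: the subscriber parks two futures, and the listener resolves the results future only once the execution id is known and matches. The same pattern in isolation (no Ably or Centrifugo; all names here are hypothetical):

import asyncio
import json

async def handshake_sketch():
    loop = asyncio.get_running_loop()
    execution_id_future = loop.create_future()
    results_future = loop.create_future()

    async def on_message(raw):
        data = json.loads(raw)
        if data["workflow_version_execution_id"] != await execution_id_future:
            return  # completion message for some other run
        results_future.set_result(data["final_output"])

    execution_id_future.set_result(7)  # set once the run has been POSTed
    await on_message('{"workflow_version_execution_id": 7, "final_output": {"ok": true}}')
    return await results_future

print(asyncio.run(handshake_sketch()))  # {'ok': True}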
+
+
+ async def _post_workflow_id_run(
+     *,
+     base_url: str,
+     authentication_headers,
+     workflow_id,
+     input_variables: Dict[str, Any],
+     metadata: Dict[str, Any],
+     workflow_label_name: str,
+     workflow_version_number: int,
+     return_all_outputs: bool,
+     channel_name_suffix: str,
+     _url_template: str = WORKFLOW_RUN_URL_TEMPLATE,
+ ):
+     url = _url_template.format(base_url=base_url, workflow_id=workflow_id)
+     payload = {
+         "input_variables": input_variables,
+         "metadata": metadata,
+         "workflow_label_name": workflow_label_name,
+         "workflow_version_number": workflow_version_number,
+         "return_all_outputs": return_all_outputs,
+         "channel_name_suffix": channel_name_suffix,
+     }
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(url, json=payload, headers=authentication_headers)
+             if response.status_code != 201:
+                 raise_on_bad_response(response, "PromptLayer had the following error while running your workflow")
+
+             result = response.json()
+             if warning := result.get("warning"):
+                 logger.warning(f"{warning}")
+     except Exception as ex:
+         error_message = f"Failed to run workflow: {str(ex)}"
+         logger.exception(error_message)
+         if RERAISE_ORIGINAL_EXCEPTION:
+             raise
+         else:
+             raise _exceptions.PromptLayerAPIError(error_message, response=None, body=None) from ex
+
+     return result.get("workflow_version_execution_id")
+
+
+ async def _wait_for_workflow_completion(channel, results_future, message_listener, timeout):
+     # We need this function for mocking in unittests
+     try:
+         return await asyncio.wait_for(results_future, timeout)
+     except asyncio.TimeoutError:
+         raise _exceptions.PromptLayerAPITimeoutError(
+             "Workflow execution did not complete properly", response=None, body=None
+         )
+     finally:
+         channel.unsubscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
+
+
+ def _make_channel_name_suffix():
+     # We need this function for mocking in unittests
+     return uuid4().hex
+
+
+ MessageCallback = Callable[[Message], Coroutine[None, None, None]]
+
+
+ class SubscriptionEventLoggerHandler(SubscriptionEventHandler):
+     def __init__(self, callback: MessageCallback):
+         self.callback = callback
+
+     async def on_publication(self, ctx: PublicationContext):
+         message_name = ctx.pub.data.get("message_name", "unknown")
+         data = ctx.pub.data.get("data", "")
+         message = Message(name=message_name, data=data)
+         await self.callback(message)
+
+
+ @asynccontextmanager
+ async def centrifugo_client(address: str, token: str):
+     client = Client(address, token=token)
+     try:
+         await client.connect()
+         yield client
+     finally:
+         await client.disconnect()
+
+
+ @asynccontextmanager
+ async def centrifugo_subscription(client: Client, topic: str, message_listener: MessageCallback):
+     subscription = client.new_subscription(
+         topic,
+         events=SubscriptionEventLoggerHandler(message_listener),
+     )
+     try:
+         await subscription.subscribe()
+         yield
+     finally:
+         if subscription.state == SubscriptionState.SUBSCRIBED:
+             await subscription.unsubscribe()
+
+
+ @retry_on_api_error
+ async def arun_workflow_request(
+     *,
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
+     workflow_id_or_name: Optional[Union[int, str]] = None,
+     input_variables: Dict[str, Any],
+     metadata: Optional[Dict[str, Any]] = None,
+     workflow_label_name: Optional[str] = None,
+     workflow_version_number: Optional[int] = None,
+     return_all_outputs: Optional[bool] = False,
+     timeout: Optional[int] = 3600,
+     # `workflow_name` deprecated, kept for backward compatibility only.
+     workflow_name: Optional[str] = None,
+ ):
+     headers = {"X-API-KEY": api_key}
+     workflow_id = await _resolve_workflow_id(
+         base_url, _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
+     )
+     channel_name_suffix = _make_channel_name_suffix()
+     channel_name = WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE.format(
+         workflow_id=workflow_id, channel_name_suffix=channel_name_suffix
+     )
+     ably_token = await _get_ably_token(base_url, channel_name, headers)
+     token = ably_token["token_details"]["token"]
+
+     execution_id_future = asyncio.Future[int]()
+
+     if ably_token.get("messaging_backend") == "centrifugo":
+         address = urllib3.util.parse_url(base_url)._replace(scheme="wss", path="/connection/websocket").url
+         async with centrifugo_client(address, token) as client:
+             results_future = asyncio.Future[dict[str, Any]]()
+             async with centrifugo_subscription(
+                 client,
+                 channel_name,
+                 _make_message_listener(base_url, results_future, execution_id_future, return_all_outputs, headers),
+             ):
+                 execution_id = await _post_workflow_id_run(
+                     base_url=base_url,
+                     authentication_headers=headers,
+                     workflow_id=workflow_id,
+                     input_variables=input_variables,
+                     metadata=metadata,
+                     workflow_label_name=workflow_label_name,
+                     workflow_version_number=workflow_version_number,
+                     return_all_outputs=return_all_outputs,
+                     channel_name_suffix=channel_name_suffix,
+                 )
+                 execution_id_future.set_result(execution_id)
+                 await asyncio.wait_for(results_future, timeout)
+                 return results_future.result()
+
+     async with AblyRealtime(token=token) as ably_client:
+         # It is crucial to subscribe before running a workflow, otherwise we may miss a completion message
+         channel = ably_client.channels.get(channel_name)
+         results_future, message_listener = await _subscribe_to_workflow_completion_channel(
+             base_url, channel, execution_id_future, return_all_outputs, headers
+         )
+
+         execution_id = await _post_workflow_id_run(
+             base_url=base_url,
+             authentication_headers=headers,
+             workflow_id=workflow_id,
+             input_variables=input_variables,
+             metadata=metadata,
+             workflow_label_name=workflow_label_name,
+             workflow_version_number=workflow_version_number,
+             return_all_outputs=return_all_outputs,
+             channel_name_suffix=channel_name_suffix,
+         )
+         execution_id_future.set_result(execution_id)
+
+         return await _wait_for_workflow_completion(channel, results_future, message_listener, timeout)
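A usage sketch for the new entry point; all argument values are placeholders, and the keyword-only signature above is respected:

import asyncio

async def main():
    results = await arun_workflow_request(
        api_key="pl_...",                        # placeholder
        base_url="https://api.promptlayer.com",
        throw_on_error=True,
        workflow_id_or_name="my-workflow",       # pass an integer id to skip the name lookup
        input_variables={"question": "What changed in 1.0.78?"},
        timeout=600,                             # seconds to wait for the completion message
    )
    print(results)

asyncio.run(main())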
+
+
  def promptlayer_api_handler(
+     api_key: str,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -35,20 +387,13 @@ def promptlayer_api_handler(
      response,
      request_start_time,
      request_end_time,
-     api_key,
      return_pl_id=False,
      llm_request_span_id=None,
  ):
      if (
          isinstance(response, types.GeneratorType)
          or isinstance(response, types.AsyncGeneratorType)
-         or type(response).__name__
-         in [
-             "Stream",
-             "AsyncStream",
-             "AsyncMessageStreamManager",
-             "MessageStreamManager",
-         ]
+         or type(response).__name__ in ["Stream", "AsyncStream", "AsyncMessageStreamManager", "MessageStreamManager"]
      ):
          return GeneratorProxy(
              generator=response,
@@ -64,9 +409,11 @@ def promptlayer_api_handler(
                  "llm_request_span_id": llm_request_span_id,
              },
              api_key=api_key,
+             base_url=base_url,
          )
      else:
          request_id = promptlayer_api_request(
+             base_url=base_url,
              function_name=function_name,
              provider_type=provider_type,
              args=args,
@@ -85,6 +432,8 @@ def promptlayer_api_handler(


  async def promptlayer_api_handler_async(
+     api_key: str,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -93,13 +442,14 @@ async def promptlayer_api_handler_async(
      response,
      request_start_time,
      request_end_time,
-     api_key,
      return_pl_id=False,
      llm_request_span_id=None,
  ):
      return await run_in_thread_async(
          None,
          promptlayer_api_handler,
+         api_key,
+         base_url,
          function_name,
          provider_type,
          args,
@@ -108,7 +458,6 @@ async def promptlayer_api_handler_async(
          response,
          request_start_time,
          request_end_time,
-         api_key,
          return_pl_id=return_pl_id,
          llm_request_span_id=llm_request_span_id,
      )
@@ -122,15 +471,13 @@ def convert_native_object_to_dict(native_object):
      if isinstance(native_object, Enum):
          return native_object.value
      if hasattr(native_object, "__dict__"):
-         return {
-             k: convert_native_object_to_dict(v)
-             for k, v in native_object.__dict__.items()
-         }
+         return {k: convert_native_object_to_dict(v) for k, v in native_object.__dict__.items()}
      return native_object


  def promptlayer_api_request(
      *,
+     base_url: str,
      function_name,
      provider_type,
      args,
@@ -147,13 +494,11 @@ def promptlayer_api_request(
      if isinstance(response, dict) and hasattr(response, "to_dict_recursive"):
          response = response.to_dict_recursive()
      request_response = None
-     if hasattr(
-         response, "dict"
-     ):  # added this for anthropic 3.0 changes, they return a completion object
+     if hasattr(response, "dict"):  # added this for anthropic 3.0 changes, they return a completion object
          response = response.dict()
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-request",
+             f"{base_url}/track-request",
              json={
                  "function_name": function_name,
                  "provider_type": provider_type,
@@ -170,40 +515,64 @@ def promptlayer_api_request(
          )
      if not hasattr(request_response, "status_code"):
          warn_on_bad_response(
-             request_response,
-             "WARNING: While logging your request PromptLayer had the following issue",
+             request_response, "WARNING: While logging your request PromptLayer had the following issue"
          )
      elif request_response.status_code != 200:
          warn_on_bad_response(
-             request_response,
-             "WARNING: While logging your request PromptLayer had the following error",
+             request_response, "WARNING: While logging your request PromptLayer had the following error"
          )
      except Exception as e:
-         print(
-             f"WARNING: While logging your request PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
      if request_response is not None and return_pl_id:
          return request_response.json().get("request_id")


- def track_request(**body):
+ @retry_on_api_error
+ def track_request(base_url: str, throw_on_error: bool, **body):
      try:
          response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-request",
+             f"{base_url}/track-request",
              json=body,
          )
          if response.status_code != 200:
-             warn_on_bad_response(
-                 response,
-                 f"PromptLayer had the following error while tracking your request: {response.text}",
-             )
+             if throw_on_error:
+                 raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
+             else:
+                 warn_on_bad_response(
+                     response, f"PromptLayer had the following error while tracking your request: {response.text}"
+                 )
          return response.json()
      except requests.exceptions.RequestException as e:
-         print(
-             f"WARNING: While logging your request PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
+         return {}
+
+
+ @retry_on_api_error
+ async def atrack_request(base_url: str, throw_on_error: bool, **body: Any) -> Dict[str, Any]:
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/track-request",
+                 json=body,
+             )
+             if response.status_code != 200:
+                 if throw_on_error:
+                     raise_on_bad_response(response, "PromptLayer had the following error while tracking your request")
+                 else:
+                     warn_on_bad_response(
+                         response, f"PromptLayer had the following error while tracking your request: {response.text}"
+                     )
+             return response.json()
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your request: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While logging your request PromptLayer had the following error: {e}")
          return {}
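With throw_on_error=False both trackers degrade to a logged warning plus an empty dict, so callers can probe the result uniformly; a sketch, with illustrative body fields:

log = track_request(
    "https://api.promptlayer.com",    # base_url
    False,                            # throw_on_error: warn instead of raising
    function_name="openai.chat.completions.create",
    provider_type="openai",
    api_key="pl_...",                 # placeholder
)
request_id = log.get("request_id")    # None if the call failed quietly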


@@ -235,8 +604,9 @@ def promptlayer_api_request_async(
      )


+ @retry_on_api_error
  def promptlayer_get_prompt(
-     prompt_name, api_key, version: int = None, label: str = None
+     api_key: str, base_url: str, throw_on_error: bool, prompt_name, version: int = None, label: str = None
  ):
      """
      Get a prompt from the PromptLayer library
@@ -245,29 +615,40 @@ def promptlayer_get_prompt(
      """
      try:
          request_response = requests.get(
-             f"{URL_API_PROMPTLAYER}/library-get-prompt-template",
+             f"{base_url}/library-get-prompt-template",
              headers={"X-API-KEY": api_key},
              params={"prompt_name": prompt_name, "version": version, "label": label},
          )
      except Exception as e:
-         raise Exception(
-             f"PromptLayer had the following error while getting your prompt: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"PromptLayer had the following error while getting your prompt: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"PromptLayer had the following error while getting your prompt: {e}")
+         return None
      if request_response.status_code != 200:
-         raise_on_bad_response(
-             request_response,
-             "PromptLayer had the following error while getting your prompt",
-         )
+         if throw_on_error:
+             raise_on_bad_response(
+                 request_response,
+                 "PromptLayer had the following error while getting your prompt",
+             )
+         else:
+             warn_on_bad_response(
+                 request_response,
+                 "WARNING: PromptLayer had the following error while getting your prompt",
+             )
+             return None

      return request_response.json()


+ @retry_on_api_error
  def promptlayer_publish_prompt(
-     prompt_name, prompt_template, commit_message, tags, api_key, metadata=None
+     api_key: str, base_url: str, throw_on_error: bool, prompt_name, prompt_template, commit_message, tags, metadata=None
  ):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-publish-prompt-template",
+             f"{base_url}/library-publish-prompt-template",
              json={
                  "prompt_name": prompt_name,
                  "prompt_template": prompt_template,
@@ -278,23 +659,34 @@ def promptlayer_publish_prompt(
              },
          )
      except Exception as e:
-         raise Exception(
-             f"PromptLayer had the following error while publishing your prompt: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"PromptLayer had the following error while publishing your prompt: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"PromptLayer had the following error while publishing your prompt: {e}")
+         return False
      if request_response.status_code != 200:
-         raise_on_bad_response(
-             request_response,
-             "PromptLayer had the following error while publishing your prompt",
-         )
+         if throw_on_error:
+             raise_on_bad_response(
+                 request_response,
+                 "PromptLayer had the following error while publishing your prompt",
+             )
+         else:
+             warn_on_bad_response(
+                 request_response,
+                 "WARNING: PromptLayer had the following error while publishing your prompt",
+             )
+             return False
      return True


+ @retry_on_api_error
  def promptlayer_track_prompt(
-     request_id, prompt_name, input_variables, api_key, version, label
+     api_key: str, base_url: str, throw_on_error: bool, request_id, prompt_name, input_variables, version, label
  ):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-prompt",
+             f"{base_url}/library-track-prompt",
              json={
                  "request_id": request_id,
                  "prompt_name": prompt_name,
@@ -305,24 +697,76 @@ def promptlayer_track_prompt(
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your prompt PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your prompt PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your prompt PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your prompt PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
          return False
      return True


- def promptlayer_track_metadata(request_id, metadata, api_key):
+ @retry_on_api_error
+ async def apromptlayer_track_prompt(
+     api_key: str,
+     base_url: str,
+     request_id: str,
+     prompt_name: str,
+     input_variables: Dict[str, Any],
+     version: Optional[int] = None,
+     label: Optional[str] = None,
+     throw_on_error: bool = True,
+ ) -> bool:
+     url = f"{base_url}/library-track-prompt"
+     payload = {
+         "request_id": request_id,
+         "prompt_name": prompt_name,
+         "prompt_input_variables": input_variables,
+         "api_key": api_key,
+         "version": version,
+         "label": label,
+     }
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(url, json=payload)
+
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(response, "While tracking your prompt, PromptLayer had the following error")
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your prompt, PromptLayer had the following error",
+                 )
+                 return False
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"While tracking your prompt PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your prompt PromptLayer had the following error: {e}")
+         return False
+
+     return True
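The async variant returns False instead of raising when throw_on_error=False; a usage sketch with placeholder values:

import asyncio

ok = asyncio.run(
    apromptlayer_track_prompt(
        api_key="pl_...",                     # placeholder
        base_url="https://api.promptlayer.com",
        request_id="req_123",
        prompt_name="welcome-email",
        input_variables={"name": "Ada"},
        version=3,
        throw_on_error=False,                 # log and return False on failure
    )
)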
+
+
+ @retry_on_api_error
+ def promptlayer_track_metadata(api_key: str, base_url: str, throw_on_error: bool, request_id, metadata):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-metadata",
+             f"{base_url}/library-track-metadata",
              json={
                  "request_id": request_id,
                  "metadata": metadata,
@@ -330,50 +774,204 @@ def promptlayer_track_metadata(request_id, metadata, api_key):
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your metadata PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your metadata PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your metadata PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your metadata PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
+         return False
+     return True
+
+
+ @retry_on_api_error
+ async def apromptlayer_track_metadata(
+     api_key: str, base_url: str, throw_on_error: bool, request_id: str, metadata: Dict[str, Any]
+ ) -> bool:
+     url = f"{base_url}/library-track-metadata"
+     payload = {
+         "request_id": request_id,
+         "metadata": metadata,
+         "api_key": api_key,
+     }
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(url, json=payload)
+
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While tracking your metadata, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your metadata, PromptLayer had the following error",
+                 )
+                 return False
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"While tracking your metadata PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your metadata PromptLayer had the following error: {e}")
          return False
+
      return True


- def promptlayer_track_score(request_id, score, score_name, api_key):
+ @retry_on_api_error
+ def promptlayer_track_score(api_key: str, base_url: str, throw_on_error: bool, request_id, score, score_name):
      try:
          data = {"request_id": request_id, "score": score, "api_key": api_key}
          if score_name is not None:
              data["name"] = score_name
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/library-track-score",
+             f"{base_url}/library-track-score",
              json=data,
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your score PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your score PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your score PromptLayer had the following error",
+                 )
+                 return False
      except Exception as e:
-         print(
-             f"WARNING: While tracking your score PromptLayer had the following error: {e}",
-             file=sys.stderr,
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIError(
+                 f"While tracking your score PromptLayer had the following error: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your score PromptLayer had the following error: {e}")
          return False
      return True


+ @retry_on_api_error
+ async def apromptlayer_track_score(
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
+     request_id: str,
+     score: float,
+     score_name: Optional[str],
+ ) -> bool:
+     url = f"{base_url}/library-track-score"
+     data = {
+         "request_id": request_id,
+         "score": score,
+         "api_key": api_key,
+     }
+     if score_name is not None:
+         data["name"] = score_name
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(url, json=data)
+
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While tracking your score, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your score, PromptLayer had the following error",
+                 )
+                 return False
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your score: {str(e)}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your score PromptLayer had the following error: {str(e)}")
+         return False
+
+     return True
+
+
+ def build_anthropic_content_blocks(events):
+     content_blocks = []
+     current_block = None
+     current_signature = ""
+     current_thinking = ""
+     current_text = ""
+     current_tool_input_json = ""
+     usage = None
+     stop_reason = None
+
+     for event in events:
+         if event.type == "content_block_start":
+             current_block = deepcopy(event.content_block)
+             if current_block.type == "thinking":
+                 current_signature = ""
+                 current_thinking = ""
+             elif current_block.type == "text":
+                 current_text = ""
+             elif current_block.type == "tool_use":
+                 current_tool_input_json = ""
+         elif event.type == "content_block_delta" and current_block is not None:
+             if current_block.type == "thinking":
+                 if hasattr(event.delta, "signature"):
+                     current_signature = event.delta.signature
+                 if hasattr(event.delta, "thinking"):
+                     current_thinking += event.delta.thinking
+             elif current_block.type == "text":
+                 if hasattr(event.delta, "text"):
+                     current_text += event.delta.text
+             elif current_block.type == "tool_use":
+                 if hasattr(event.delta, "partial_json"):
+                     current_tool_input_json += event.delta.partial_json
+         elif event.type == "content_block_stop" and current_block is not None:
+             if current_block.type == "thinking":
+                 current_block.signature = current_signature
+                 current_block.thinking = current_thinking
+             elif current_block.type == "text":
+                 current_block.text = current_text
+             elif current_block.type == "tool_use":
+                 try:
+                     current_block.input = json.loads(current_tool_input_json)
+                 except json.JSONDecodeError:
+                     current_block.input = {}
+             content_blocks.append(current_block)
+             current_block = None
+             current_signature = ""
+             current_thinking = ""
+             current_text = ""
+             current_tool_input_json = ""
+         elif event.type == "message_delta":
+             if hasattr(event, "usage"):
+                 usage = event.usage
+             if hasattr(event.delta, "stop_reason"):
+                 stop_reason = event.delta.stop_reason
+     return content_blocks, usage, stop_reason
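A sketch of the reassembly with synthetic events; SimpleNamespace stands in for the Anthropic SDK's event objects:

from types import SimpleNamespace as NS

events = [
    NS(type="content_block_start", content_block=NS(type="text", text="")),
    NS(type="content_block_delta", delta=NS(text="Hel")),
    NS(type="content_block_delta", delta=NS(text="lo")),
    NS(type="content_block_stop"),
    NS(type="message_delta", usage=NS(output_tokens=2), delta=NS(stop_reason="end_turn")),
]
blocks, usage, stop_reason = build_anthropic_content_blocks(events)
print(blocks[0].text, usage.output_tokens, stop_reason)  # Hello 2 end_turn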
+
+
  class GeneratorProxy:
-     def __init__(self, generator, api_request_arguments, api_key):
+     def __init__(self, generator, api_request_arguments, api_key, base_url):
          self.generator = generator
          self.results = []
          self.api_request_arugments = api_request_arguments
          self.api_key = api_key
+         self.base_url = base_url

      def __iter__(self):
          return self
@@ -388,6 +986,7 @@ class GeneratorProxy:
              await self.generator._AsyncMessageStreamManager__api_request,
              api_request_arguments,
              self.api_key,
+             self.base_url,
          )

      def __enter__(self):
@@ -398,6 +997,7 @@ class GeneratorProxy:
              stream,
              api_request_arguments,
              self.api_key,
+             self.base_url,
          )

      def __exit__(self, exc_type, exc_val, exc_tb):
@@ -416,9 +1016,7 @@ class GeneratorProxy:

      def __getattr__(self, name):
          if name == "text_stream":  # anthropic async stream
-             return GeneratorProxy(
-                 self.generator.text_stream, self.api_request_arugments, self.api_key
-             )
+             return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key, self.base_url)
          return getattr(self.generator, name)

      def _abstracted_next(self, result):
@@ -435,12 +1033,12 @@ class GeneratorProxy:
              end_anthropic = True

          end_openai = provider_type == "openai" and (
-             result.choices[0].finish_reason == "stop"
-             or result.choices[0].finish_reason == "length"
+             result.choices[0].finish_reason == "stop" or result.choices[0].finish_reason == "length"
          )

          if end_anthropic or end_openai:
              request_id = promptlayer_api_request(
+                 base_url=self.base_url,
                  function_name=self.api_request_arugments["function_name"],
                  provider_type=self.api_request_arugments["provider_type"],
                  args=self.api_request_arugments["args"],
@@ -451,9 +1049,7 @@ class GeneratorProxy:
                  request_end_time=self.api_request_arugments["request_end_time"],
                  api_key=self.api_key,
                  return_pl_id=self.api_request_arugments["return_pl_id"],
-                 llm_request_span_id=self.api_request_arugments.get(
-                     "llm_request_span_id"
-                 ),
+                 llm_request_span_id=self.api_request_arugments.get("llm_request_span_id"),
              )

          if self.api_request_arugments["return_pl_id"]:
@@ -470,31 +1066,35 @@ class GeneratorProxy:
          response = ""
          for result in self.results:
              if hasattr(result, "completion"):
-                 response = f"{response}{result.completion}"
+                 response += result.completion
              elif hasattr(result, "message") and isinstance(result.message, str):
-                 response = f"{response}{result.message}"
+                 response += result.message
              elif (
                  hasattr(result, "content_block")
                  and hasattr(result.content_block, "text")
-                 and "type" in result
-                 and result.type != "message_stop"
+                 and getattr(result, "type", None) != "message_stop"
              ):
-                 response = f"{response}{result.content_block.text}"
-             elif hasattr(result, "delta") and hasattr(result.delta, "text"):
-                 response = f"{response}{result.delta.text}"
-         if (
-             hasattr(self.results[-1], "type")
-             and self.results[-1].type == "message_stop"
-         ):  # this is a message stream and not the correct event
+                 response += result.content_block.text
+             elif hasattr(result, "delta"):
+                 if hasattr(result.delta, "thinking"):
+                     response += result.delta.thinking
+                 elif hasattr(result.delta, "text"):
+                     response += result.delta.text
+
+         # 2) If this is a “stream” (ended by message_stop), reconstruct both ThinkingBlock & TextBlock
+         last_event = self.results[-1]
+         if getattr(last_event, "type", None) == "message_stop":
              final_result = deepcopy(self.results[0].message)
-             final_result.usage = None
-             content_block = deepcopy(self.results[1].content_block)
-             content_block.text = response
-             final_result.content = [content_block]
-         else:
-             final_result = deepcopy(self.results[-1])
-             final_result.completion = response
+
+             content_blocks, usage, stop_reason = build_anthropic_content_blocks(self.results)
+             final_result.content = content_blocks
+             if usage:
+                 final_result.usage.output_tokens = usage.output_tokens
+             if stop_reason:
+                 final_result.stop_reason = stop_reason
              return final_result
+         else:
+             return deepcopy(self.results[-1])
          if hasattr(self.results[0].choices[0], "text"):  # this is regular completion
              response = ""
              for result in self.results:
@@ -502,23 +1102,15 @@ class GeneratorProxy:
              final_result = deepcopy(self.results[-1])
              final_result.choices[0].text = response
              return final_result
-         elif hasattr(
-             self.results[0].choices[0], "delta"
-         ):  # this is completion with delta
+         elif hasattr(self.results[0].choices[0], "delta"):  # this is completion with delta
              response = {"role": "", "content": ""}
              for result in self.results:
-                 if (
-                     hasattr(result.choices[0].delta, "role")
-                     and result.choices[0].delta.role is not None
-                 ):
+                 if hasattr(result.choices[0].delta, "role") and result.choices[0].delta.role is not None:
                      response["role"] = result.choices[0].delta.role
-                 if (
-                     hasattr(result.choices[0].delta, "content")
-                     and result.choices[0].delta.content is not None
-                 ):
-                     response["content"] = response[
-                         "content"
-                     ] = f"{response['content']}{result.choices[0].delta.content}"
+                 if hasattr(result.choices[0].delta, "content") and result.choices[0].delta.content is not None:
+                     response["content"] = response["content"] = (
+                         f"{response['content']}{result.choices[0].delta.content}"
+                     )
              final_result = deepcopy(self.results[-1])
              final_result.choices[0] = response
              return final_result
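The delta branch folds role and content chunks into one message; the same folding in isolation, with synthetic chunks standing in for OpenAI stream events:

from types import SimpleNamespace as NS

chunks = [
    NS(choices=[NS(delta=NS(role="assistant", content=None))]),
    NS(choices=[NS(delta=NS(role=None, content="Hi"))]),
    NS(choices=[NS(delta=NS(role=None, content=" there"))]),
]
response = {"role": "", "content": ""}
for chunk in chunks:
    delta = chunk.choices[0].delta
    if getattr(delta, "role", None) is not None:
        response["role"] = delta.role
    if getattr(delta, "content", None) is not None:
        response["content"] += delta.content
print(response)  # {'role': 'assistant', 'content': 'Hi there'}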
@@ -537,37 +1129,71 @@ async def run_in_thread_async(executor, func, *args, **kwargs):
  def warn_on_bad_response(request_response, main_message):
      if hasattr(request_response, "json"):
          try:
-             print(
-                 f"{main_message}: {request_response.json().get('message')}",
-                 file=sys.stderr,
-             )
+             logger.warning(f"{main_message}: {request_response.json().get('message')}")
          except json.JSONDecodeError:
-             print(
-                 f"{main_message}: {request_response}",
-                 file=sys.stderr,
-             )
+             logger.warning(f"{main_message}: {request_response}")
      else:
-         print(f"{main_message}: {request_response}", file=sys.stderr)
+         logger.warning(f"{main_message}: {request_response}")


  def raise_on_bad_response(request_response, main_message):
+     """Raise an appropriate exception based on the HTTP status code."""
+     status_code = getattr(request_response, "status_code", None)
+
+     body = None
+     error_detail = None
      if hasattr(request_response, "json"):
          try:
-             raise Exception(f"{main_message}: {request_response.json().get('message')}")
-         except json.JSONDecodeError:
-             raise Exception(f"{main_message}: {request_response}")
+             body = request_response.json()
+             error_detail = body.get("message") or body.get("error") or body.get("detail")
+         except (json.JSONDecodeError, AttributeError):
+             body = getattr(request_response, "text", str(request_response))
+             error_detail = body
      else:
-         raise Exception(f"{main_message}: {request_response}")
+         body = str(request_response)
+         error_detail = body
+
+     if error_detail:
+         err_msg = f"{main_message}: {error_detail}"
+     else:
+         err_msg = main_message
+
+     if status_code == 400:
+         raise _exceptions.PromptLayerBadRequestError(err_msg, response=request_response, body=body)
+
+     if status_code == 401:
+         raise _exceptions.PromptLayerAuthenticationError(err_msg, response=request_response, body=body)
+
+     if status_code == 403:
+         raise _exceptions.PromptLayerPermissionDeniedError(err_msg, response=request_response, body=body)
+
+     if status_code == 404:
+         raise _exceptions.PromptLayerNotFoundError(err_msg, response=request_response, body=body)
+
+     if status_code == 409:
+         raise _exceptions.PromptLayerConflictError(err_msg, response=request_response, body=body)
+
+     if status_code == 422:
+         raise _exceptions.PromptLayerUnprocessableEntityError(err_msg, response=request_response, body=body)
+
+     if status_code == 429:
+         raise _exceptions.PromptLayerRateLimitError(err_msg, response=request_response, body=body)
+
+     if status_code and status_code >= 500:
+         raise _exceptions.PromptLayerInternalServerError(err_msg, response=request_response, body=body)
+
+     raise _exceptions.PromptLayerAPIStatusError(err_msg, response=request_response, body=body)
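The resulting status-to-exception mapping, sketched with a stub response (anything exposing status_code and json() works); note that a 429 raised this way is also retryable under should_retry_error:

class _StubResponse:
    status_code = 429
    def json(self):
        return {"message": "slow down"}

try:
    raise_on_bad_response(_StubResponse(), "PromptLayer had the following error")
except _exceptions.PromptLayerRateLimitError as e:
    print(e)  # PromptLayer had the following error: slow down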


  async def async_wrapper(
+     api_key: str,
+     base_url: str,
      coroutine_obj,
      return_pl_id,
      request_start_time,
      function_name,
      provider_type,
      tags,
-     api_key: str = None,
      llm_request_span_id: str = None,
      tracer=None,
      *args,
@@ -580,6 +1206,8 @@ async def async_wrapper(
      response = await coroutine_obj
      request_end_time = datetime.datetime.now().timestamp()
      result = await promptlayer_api_handler_async(
+         api_key,
+         base_url,
          function_name,
          provider_type,
          args,
@@ -588,7 +1216,6 @@ async def async_wrapper(
          response,
          request_start_time,
          request_end_time,
-         api_key,
          return_pl_id=return_pl_id,
          llm_request_span_id=llm_request_span_id,
      )
@@ -603,32 +1230,75 @@ async def async_wrapper(
          context.detach(token)


- def promptlayer_create_group(api_key: str = None):
+ @retry_on_api_error
+ def promptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/create-group",
+             f"{base_url}/create-group",
              json={
                  "api_key": api_key,
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While creating your group PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While creating your group PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While creating your group PromptLayer had the following error",
+                 )
+                 return False
      except requests.exceptions.RequestException as e:
-         # I'm aiming for a more specific exception catch here
-         raise Exception(
-             f"PromptLayer had the following error while creating your group: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while creating your group: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While creating your group PromptLayer had the following error: {e}")
+         return False
      return request_response.json()["id"]


- def promptlayer_track_group(request_id, group_id, api_key: str = None):
+ @retry_on_api_error
+ async def apromptlayer_create_group(api_key: str, base_url: str, throw_on_error: bool):
+     try:
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/create-group",
+                 json={
+                     "api_key": api_key,
+                 },
+             )
+
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While creating your group, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While creating your group, PromptLayer had the following error",
+                 )
+                 return False
+         return response.json()["id"]
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while creating your group: {str(e)}", response=None, body=None
+             ) from e
+         logger.warning(f"While creating your group PromptLayer had the following error: {e}")
+         return False
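Group creation returns the new group id on success and, with throw_on_error=False, returns False on failure; a sketch with placeholder values:

group_id = promptlayer_create_group(
    api_key="pl_...",                     # placeholder
    base_url="https://api.promptlayer.com",
    throw_on_error=False,
)
if group_id:
    promptlayer_track_group(
        api_key="pl_...",
        base_url="https://api.promptlayer.com",
        throw_on_error=False,
        request_id="req_123",             # a previously tracked request
        group_id=group_id,
    )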
+
+
+ @retry_on_api_error
+ def promptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
      try:
          request_response = requests.post(
-             f"{URL_API_PROMPTLAYER}/track-group",
+             f"{base_url}/track-group",
              json={
                  "api_key": api_key,
                  "request_id": request_id,
@@ -636,49 +1306,170 @@ def promptlayer_track_group(request_id, group_id, api_key: str = None):
              },
          )
          if request_response.status_code != 200:
-             warn_on_bad_response(
-                 request_response,
-                 "WARNING: While tracking your group PromptLayer had the following error",
-             )
-             return False
+             if throw_on_error:
+                 raise_on_bad_response(
+                     request_response,
+                     "While tracking your group PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     request_response,
+                     "WARNING: While tracking your group PromptLayer had the following error",
+                 )
+                 return False
      except requests.exceptions.RequestException as e:
-         # I'm aiming for a more specific exception catch here
-         raise Exception(
-             f"PromptLayer had the following error while tracking your group: {e}"
-         )
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your group: {e}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
+         return False
      return True


+ @retry_on_api_error
+ async def apromptlayer_track_group(api_key: str, base_url: str, throw_on_error: bool, request_id, group_id):
+     try:
+         payload = {
+             "api_key": api_key,
+             "request_id": request_id,
+             "group_id": group_id,
+         }
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/track-group",
+                 headers={"X-API-KEY": api_key},
+                 json=payload,
+             )
+
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "While tracking your group, PromptLayer had the following error",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response,
+                     "WARNING: While tracking your group, PromptLayer had the following error",
+                 )
+                 return False
+     except httpx.RequestError as e:
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(
+                 f"PromptLayer had the following error while tracking your group: {str(e)}", response=None, body=None
+             ) from e
+         logger.warning(f"While tracking your group PromptLayer had the following error: {e}")
+         return False
+
+     return True
+
+
+ @retry_on_api_error
  def get_prompt_template(
-     prompt_name: str, params: Union[GetPromptTemplate, None] = None, api_key: str = None
+     api_key: str, base_url: str, throw_on_error: bool, prompt_name: str, params: Union[GetPromptTemplate, None] = None
  ) -> GetPromptTemplateResponse:
      try:
          json_body = {"api_key": api_key}
          if params:
              json_body = {**json_body, **params}
          response = requests.post(
-             f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
+             f"{base_url}/prompt-templates/{prompt_name}",
              headers={"X-API-KEY": api_key},
              json=json_body,
          )
          if response.status_code != 200:
-             raise Exception(
-                 f"PromptLayer had the following error while getting your prompt template: {response.text}"
-             )
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response, "PromptLayer had the following error while getting your prompt template"
+                 )
+             else:
+                 warn_on_bad_response(
+                     response, "WARNING: PromptLayer had the following error while getting your prompt template"
+                 )
+                 return None
+
          return response.json()
+     except requests.exceptions.ConnectionError as e:
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
+     except requests.exceptions.Timeout as e:
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
      except requests.exceptions.RequestException as e:
-         raise Exception(
-             f"PromptLayer had the following error while getting your prompt template: {e}"
-         )
-
-
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {e}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
+
+
+ @retry_on_api_error
+ async def aget_prompt_template(
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
+     prompt_name: str,
+     params: Union[GetPromptTemplate, None] = None,
+ ) -> GetPromptTemplateResponse:
+     try:
+         json_body = {"api_key": api_key}
+         if params:
+             json_body.update(params)
+         async with _make_httpx_client() as client:
+             response = await client.post(
+                 f"{base_url}/prompt-templates/{quote(prompt_name, safe='')}",
+                 headers={"X-API-KEY": api_key},
+                 json=json_body,
+             )
+         if response.status_code != 200:
+             if throw_on_error:
+                 raise_on_bad_response(
+                     response,
+                     "PromptLayer had the following error while getting your prompt template",
+                 )
+             else:
+                 warn_on_bad_response(
+                     response, "WARNING: While getting your prompt template PromptLayer had the following error"
+                 )
+                 return None
+         return response.json()
+     except (httpx.ConnectError, httpx.NetworkError) as e:
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
+     except httpx.TimeoutException as e:
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPITimeoutError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
+     except httpx.RequestError as e:
+         err_msg = f"PromptLayer had the following error while getting your prompt template: {str(e)}"
+         if throw_on_error:
+             raise _exceptions.PromptLayerAPIConnectionError(err_msg, response=None, body=None) from e
+         logger.warning(err_msg)
+         return None
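Unlike the sync version, the async variant URL-quotes the template name (quote(prompt_name, safe='')), so slashes and spaces are safe; a usage sketch with placeholder values:

import asyncio

template = asyncio.run(
    aget_prompt_template(
        api_key="pl_...",                     # placeholder
        base_url="https://api.promptlayer.com",
        throw_on_error=False,
        prompt_name="billing/invoice v2",     # quoted into the URL path
        params={"label": "prod"},             # illustrative GetPromptTemplate fields
    )
)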
+
+
+ @retry_on_api_error
  def publish_prompt_template(
+     api_key: str,
+     base_url: str,
+     throw_on_error: bool,
      body: PublishPromptTemplate,
-     api_key: str = None,
  ) -> PublishPromptTemplateResponse:
      try:
          response = requests.post(
-             f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
+             f"{base_url}/rest/prompt-templates",
              headers={"X-API-KEY": api_key},
              json={
                  "prompt_template": {**body},
@@ -687,167 +1478,142 @@ def publish_prompt_template(
687
1478
  },
688
1479
  )
689
1480
  if response.status_code == 400:
690
- raise Exception(
691
- f"PromptLayer had the following error while publishing your prompt template: {response.text}"
692
- )
1481
+ if throw_on_error:
1482
+ raise_on_bad_response(
1483
+ response, "PromptLayer had the following error while publishing your prompt template"
1484
+ )
1485
+ else:
1486
+ warn_on_bad_response(
1487
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1488
+ )
1489
+ return None
693
1490
  return response.json()
694
1491
  except requests.exceptions.RequestException as e:
695
- raise Exception(
696
- f"PromptLayer had the following error while publishing your prompt template: {e}"
697
- )
698
-
1492
+ if throw_on_error:
1493
+ raise _exceptions.PromptLayerAPIConnectionError(
1494
+ f"PromptLayer had the following error while publishing your prompt template: {e}",
1495
+ response=None,
1496
+ body=None,
1497
+ ) from e
1498
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1499
+ return None
1500
+
1501
+
1502
+ @retry_on_api_error
1503
+ async def apublish_prompt_template(
1504
+ api_key: str,
1505
+ base_url: str,
1506
+ throw_on_error: bool,
1507
+ body: PublishPromptTemplate,
1508
+ ) -> PublishPromptTemplateResponse:
1509
+ try:
1510
+ async with _make_httpx_client() as client:
1511
+ response = await client.post(
1512
+ f"{base_url}/rest/prompt-templates",
1513
+ headers={"X-API-KEY": api_key},
1514
+ json={
1515
+ "prompt_template": {**body},
1516
+ "prompt_version": {**body},
1517
+ "release_labels": body.get("release_labels"),
1518
+ },
1519
+ )
699
1520
 
1521
+ if response.status_code == 400 or response.status_code != 201:
1522
+ if throw_on_error:
1523
+ raise_on_bad_response(
1524
+ response,
1525
+ "PromptLayer had the following error while publishing your prompt template",
1526
+ )
1527
+ else:
1528
+ warn_on_bad_response(
1529
+ response, "WARNING: PromptLayer had the following error while publishing your prompt template"
1530
+ )
1531
+ return None
1532
+ return response.json()
1533
+ except httpx.RequestError as e:
1534
+ if throw_on_error:
1535
+ raise _exceptions.PromptLayerAPIConnectionError(
1536
+ f"PromptLayer had the following error while publishing your prompt template: {str(e)}",
1537
+ response=None,
1538
+ body=None,
1539
+ ) from e
1540
+ logger.warning(f"PromptLayer had the following error while publishing your prompt template: {e}")
1541
+ return None
1542
+
1543
+
1544
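Both publish variants send the template body twice, once as prompt_template and once as prompt_version, with release_labels lifted out separately. A hedged call sketch; the body keys shown are placeholders, and the real shape is defined by promptlayer.types.prompt_template.PublishPromptTemplate:

from promptlayer.utils import publish_prompt_template

body = {
    "prompt_name": "my-prompt",  # placeholder keys; see PublishPromptTemplate for the schema
    "release_labels": ["staging"],
}

result = publish_prompt_template(
    api_key="pl_...",  # placeholder
    base_url="https://api.promptlayer.com",
    throw_on_error=True,
    body=body,
)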
+@retry_on_api_error
 def get_all_prompt_templates(
-    page: int = 1, per_page: int = 30, api_key: str = None
+    api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
 ) -> List[ListPromptTemplateResponse]:
     try:
+        params = {"page": page, "per_page": per_page}
+        if label:
+            params["label"] = label
         response = requests.get(
-            f"{URL_API_PROMPTLAYER}/prompt-templates",
+            f"{base_url}/prompt-templates",
             headers={"X-API-KEY": api_key},
-            params={"page": page, "per_page": per_page},
+            params=params,
         )
         if response.status_code != 200:
-            raise Exception(
-                f"PromptLayer had the following error while getting all your prompt templates: {response.text}"
-            )
+            if throw_on_error:
+                raise_on_bad_response(
+                    response, "PromptLayer had the following error while getting all your prompt templates"
+                )
+            else:
+                warn_on_bad_response(
+                    response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
+                )
+            return []
         items = response.json().get("items", [])
         return items
     except requests.exceptions.RequestException as e:
-        raise Exception(
-            f"PromptLayer had the following error while getting all your prompt templates: {e}"
-        )
-
-
-def openai_stream_chat(results: list):
-    from openai.types.chat import (
-        ChatCompletion,
-        ChatCompletionChunk,
-        ChatCompletionMessage,
-    )
-    from openai.types.chat.chat_completion import Choice
-
-    chat_completion_chunks: List[ChatCompletionChunk] = results
-    response: ChatCompletion = ChatCompletion(
-        id="",
-        object="chat.completion",
-        choices=[
-            Choice(
-                finish_reason="stop",
-                index=0,
-                message=ChatCompletionMessage(role="assistant"),
+        if throw_on_error:
+            raise _exceptions.PromptLayerAPIConnectionError(
+                f"PromptLayer had the following error while getting all your prompt templates: {e}",
+                response=None,
+                body=None,
+            ) from e
+        logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
+        return []
+
+
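page/per_page plus the optional label filter make exhaustive listing a simple loop. A sketch with placeholder credentials:

from promptlayer.utils import get_all_prompt_templates

templates, page = [], 1
while True:
    batch = get_all_prompt_templates(
        api_key="pl_...",  # placeholder
        base_url="https://api.promptlayer.com",
        throw_on_error=True,
        page=page,
        per_page=30,
        label="prod",  # optional release-label filter
    )
    templates.extend(batch)
    if len(batch) < 30:  # a short page means there are no more results
        break
    page += 1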
+@retry_on_api_error
+async def aget_all_prompt_templates(
+    api_key: str, base_url: str, throw_on_error: bool, page: int = 1, per_page: int = 30, label: str = None
+) -> List[ListPromptTemplateResponse]:
+    try:
+        params = {"page": page, "per_page": per_page}
+        if label:
+            params["label"] = label
+        async with _make_httpx_client() as client:
+            response = await client.get(
+                f"{base_url}/prompt-templates",
+                headers={"X-API-KEY": api_key},
+                params=params,
             )
-        ],
-        created=0,
-        model="",
-    )
-    last_result = chat_completion_chunks[-1]
-    response.id = last_result.id
-    response.created = last_result.created
-    response.model = last_result.model
-    response.system_fingerprint = last_result.system_fingerprint
-    response.usage = last_result.usage
-    content = ""
-    for result in chat_completion_chunks:
-        if len(result.choices) > 0 and result.choices[0].delta.content:
-            content = f"{content}{result.choices[0].delta.content}"
-    response.choices[0].message.content = content
-    return response
-
-
-def openai_stream_completion(results: list):
-    from openai.types.completion import Completion, CompletionChoice
-
-    completions: List[Completion] = results
-    last_chunk = completions[-1]
-    response = Completion(
-        id=last_chunk.id,
-        created=last_chunk.created,
-        model=last_chunk.model,
-        object="text_completion",
-        choices=[CompletionChoice(finish_reason="stop", index=0, text="")],
-    )
-    text = ""
-    for completion in completions:
-        usage = completion.usage
-        system_fingerprint = completion.system_fingerprint
-        if len(completion.choices) > 0 and completion.choices[0].text:
-            text = f"{text}{completion.choices[0].text}"
-        if usage:
-            response.usage = usage
-        if system_fingerprint:
-            response.system_fingerprint = system_fingerprint
-    response.choices[0].text = text
-    return response
-
-
-def anthropic_stream_message(results: list):
-    from anthropic.types import Message, MessageStreamEvent, TextBlock, Usage
-
-    message_stream_events: List[MessageStreamEvent] = results
-    response: Message = Message(
-        id="",
-        model="",
-        content=[],
-        role="assistant",
-        type="message",
-        stop_reason="stop_sequence",
-        stop_sequence=None,
-        usage=Usage(input_tokens=0, output_tokens=0),
-    )
-    content = ""
-    for result in message_stream_events:
-        if result.type == "message_start":
-            response = result.message
-        elif result.type == "content_block_delta":
-            if result.delta.type == "text_delta":
-                content = f"{content}{result.delta.text}"
-        elif result.type == "message_delta":
-            if hasattr(result, "usage"):
-                response.usage.output_tokens = result.usage.output_tokens
-            if hasattr(result.delta, "stop_reason"):
-                response.stop_reason = result.delta.stop_reason
-    response.content.append(TextBlock(type="text", text=content))
-    return response
-
-
-def anthropic_stream_completion(results: list):
-    from anthropic.types import Completion
-
-    completions: List[Completion] = results
-    last_chunk = completions[-1]
-    response = Completion(
-        id=last_chunk.id,
-        completion="",
-        model=last_chunk.model,
-        stop_reason="stop",
-        type="completion",
-    )
-
-    text = ""
-    for completion in completions:
-        text = f"{text}{completion.completion}"
-    response.completion = text
-    return response
-
 
-def stream_response(
-    generator: Generator, after_stream: Callable, map_results: Callable
-):
-    data = {
-        "request_id": None,
-        "raw_response": None,
-        "prompt_blueprint": None,
-    }
-    results = []
-    for result in generator:
-        results.append(result)
-        data["raw_response"] = result
-        yield data
-    request_response = map_results(results)
-    response = after_stream(request_response=request_response.model_dump())
-    data["request_id"] = response.get("request_id")
-    data["prompt_blueprint"] = response.get("prompt_blueprint")
-    yield data
+        if response.status_code != 200:
+            if throw_on_error:
+                raise_on_bad_response(
+                    response,
+                    "PromptLayer had the following error while getting all your prompt templates",
+                )
+            else:
+                warn_on_bad_response(
+                    response, "WARNING: PromptLayer had the following error while getting all your prompt templates"
+                )
+            return []
+        items = response.json().get("items", [])
+        return items
+    except httpx.RequestError as e:
+        if throw_on_error:
+            raise _exceptions.PromptLayerAPIConnectionError(
+                f"PromptLayer had the following error while getting all your prompt templates: {str(e)}",
+                response=None,
+                body=None,
+            ) from e
+        logger.warning(f"PromptLayer had the following error while getting all your prompt templates: {e}")
+        return []
 
 
 def openai_chat_request(client, **kwargs):
@@ -864,14 +1630,75 @@ MAP_TYPE_TO_OPENAI_FUNCTION = {
 }
 
 
-def openai_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
+def openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
     from openai import OpenAI
 
-    client = OpenAI(base_url=kwargs.pop("base_url", None))
-    request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[
-        prompt_blueprint["prompt_template"]["type"]
-    ]
-    return request_to_make(client, **kwargs)
+    client = OpenAI(**client_kwargs)
+    api_type = prompt_blueprint["metadata"]["model"].get("api_type") or "chat-completions"
+
+    if api_type == "chat-completions":
+        request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+        return request_to_make(client, **function_kwargs)
+    else:
+        return client.responses.create(**function_kwargs)
+
+
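The api_type metadata field selects between the Chat Completions path and the Responses path; a present-but-null value falls back to chat. The routing decision in isolation (blueprint fragments are illustrative):

blueprints = [
    {"metadata": {"model": {"api_type": None}}},  # null field -> chat-completions
    {"metadata": {"model": {}}},  # missing field -> chat-completions
    {"metadata": {"model": {"api_type": "responses"}}},  # explicit -> responses
]

for bp in blueprints:
    api_type = bp["metadata"]["model"].get("api_type") or "chat-completions"
    print(api_type)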
+async def aopenai_chat_request(client, **kwargs):
+    return await client.chat.completions.create(**kwargs)
+
+
+async def aopenai_completions_request(client, **kwargs):
+    return await client.completions.create(**kwargs)
+
+
+AMAP_TYPE_TO_OPENAI_FUNCTION = {
+    "chat": aopenai_chat_request,
+    "completion": aopenai_completions_request,
+}
+
+
+async def aopenai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from openai import AsyncOpenAI
+
+    client = AsyncOpenAI(**client_kwargs)
+    api_type = prompt_blueprint["metadata"]["model"].get("api_type") or "chat-completions"
+
+    if api_type == "chat-completions":
+        request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+        return await request_to_make(client, **function_kwargs)
+    else:
+        return await client.responses.create(**function_kwargs)
+
+
+def azure_openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from openai import AzureOpenAI
+
+    client = AzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
+    api_type = prompt_blueprint["metadata"]["model"].get("api_type") or "chat-completions"
+
+    if api_type == "chat-completions":
+        request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+        return request_to_make(client, **function_kwargs)
+    else:
+        return client.responses.create(**function_kwargs)
+
+
+async def aazure_openai_request(
+    prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
+):
+    from openai import AsyncAzureOpenAI
+
+    client = AsyncAzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
+    api_type = prompt_blueprint["metadata"]["model"].get("api_type") or "chat-completions"
+
+    if api_type == "chat-completions":
+        request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+        return await request_to_make(client, **function_kwargs)
+    else:
+        return await client.responses.create(**function_kwargs)
 
 
 def anthropic_chat_request(client, **kwargs):
@@ -888,14 +1715,34 @@ MAP_TYPE_TO_ANTHROPIC_FUNCTION = {
 }
 
 
-def anthropic_request(prompt_blueprint: GetPromptTemplateResponse, **kwargs):
+def anthropic_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
     from anthropic import Anthropic
 
-    client = Anthropic(base_url=kwargs.pop("base_url", None))
-    request_to_make = MAP_TYPE_TO_ANTHROPIC_FUNCTION[
-        prompt_blueprint["prompt_template"]["type"]
-    ]
-    return request_to_make(client, **kwargs)
+    client = Anthropic(**client_kwargs)
+    request_to_make = MAP_TYPE_TO_ANTHROPIC_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+    return request_to_make(client, **function_kwargs)
+
+
+async def aanthropic_chat_request(client, **kwargs):
+    return await client.messages.create(**kwargs)
+
+
+async def aanthropic_completions_request(client, **kwargs):
+    return await client.completions.create(**kwargs)
+
+
+AMAP_TYPE_TO_ANTHROPIC_FUNCTION = {
+    "chat": aanthropic_chat_request,
+    "completion": aanthropic_completions_request,
+}
+
+
+async def aanthropic_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from anthropic import AsyncAnthropic
+
+    client = AsyncAnthropic(**client_kwargs)
+    request_to_make = AMAP_TYPE_TO_ANTHROPIC_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+    return await request_to_make(client, **function_kwargs)
 
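The *MAP_TYPE_TO_*_FUNCTION dicts above are small dispatch tables: the template's type string selects the request function. A self-contained sketch of the idea with dummy handlers:

def handle_chat(client, **kwargs):
    return ("chat", kwargs)


def handle_completion(client, **kwargs):
    return ("completion", kwargs)


DISPATCH = {"chat": handle_chat, "completion": handle_completion}

template_type = "chat"  # in the real code: prompt_blueprint["prompt_template"]["type"]
print(DISPATCH[template_type](client=None, max_tokens=16))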
 
 # do not remove! This is used in the langchain integration.
@@ -903,7 +1750,336 @@ def get_api_key():
     # raise an error if the api key is not set
     api_key = os.environ.get("PROMPTLAYER_API_KEY")
     if not api_key:
-        raise Exception(
-            "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>' "
+        raise _exceptions.PromptLayerAuthenticationError(
+            "Please set your PROMPTLAYER_API_KEY environment variable or set API KEY in code using 'promptlayer.api_key = <your_api_key>'",
+            response=None,
+            body=None,
         )
     return api_key
+
+
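Key resolution is environment-first, so a quick sanity check looks like this (placeholder key):

import os

os.environ["PROMPTLAYER_API_KEY"] = "pl_test_key"  # placeholder

from promptlayer.utils import get_api_key

assert get_api_key() == "pl_test_key"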
+@retry_on_api_error
+def util_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
+    try:
+        response = requests.post(
+            f"{base_url}/log-request",
+            headers={"X-API-KEY": api_key},
+            json=kwargs,
+        )
+        if response.status_code != 201:
+            if throw_on_error:
+                raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
+            else:
+                warn_on_bad_response(
+                    response,
+                    "WARNING: While logging your request PromptLayer had the following error",
+                )
+            return None
+        return response.json()
+    except Exception as e:
+        if throw_on_error:
+            raise _exceptions.PromptLayerAPIError(
+                f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
+            ) from e
+        logger.warning(f"While logging your request PromptLayer had the following error: {e}")
+        return None
+
+
+@retry_on_api_error
+async def autil_log_request(api_key: str, base_url: str, throw_on_error: bool, **kwargs) -> Union[RequestLog, None]:
+    try:
+        async with _make_httpx_client() as client:
+            response = await client.post(
+                f"{base_url}/log-request",
+                headers={"X-API-KEY": api_key},
+                json=kwargs,
+            )
+        if response.status_code != 201:
+            if throw_on_error:
+                raise_on_bad_response(response, "PromptLayer had the following error while logging your request")
+            else:
+                warn_on_bad_response(
+                    response,
+                    "WARNING: While logging your request PromptLayer had the following error",
+                )
+            return None
+        return response.json()
+    except Exception as e:
+        if throw_on_error:
+            raise _exceptions.PromptLayerAPIError(
+                f"While logging your request PromptLayer had the following error: {e}", response=None, body=None
+            ) from e
+        logger.warning(f"While logging your request PromptLayer had the following error: {e}")
+        return None
+
+
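Both log helpers forward **kwargs verbatim as the JSON body of POST /log-request, so the accepted fields are whatever that endpoint defines; the payload keys below are purely hypothetical placeholders:

from promptlayer.utils import util_log_request

log = util_log_request(
    api_key="pl_...",  # placeholder
    base_url="https://api.promptlayer.com",
    throw_on_error=False,
    provider="openai",  # hypothetical payload key
    function_name="openai.chat.completions.create",  # hypothetical payload key
)
print(log)  # a RequestLog dict on HTTP 201, otherwise None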
+def mistral_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from mistralai import Mistral
+
+    client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"), client=_make_simple_httpx_client())
+    if function_kwargs.pop("stream", False):
+        return client.chat.stream(**function_kwargs)
+    return client.chat.complete(**function_kwargs)
+
+
+async def amistral_request(
+    prompt_blueprint: GetPromptTemplateResponse,
+    _: dict,
+    function_kwargs: dict,
+):
+    from mistralai import Mistral
+
+    client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"), async_client=_make_httpx_client())
+    if function_kwargs.pop("stream", False):
+        return await client.chat.stream_async(**function_kwargs)
+    return await client.chat.complete_async(**function_kwargs)
+
+
+class _GoogleStreamWrapper:
+    """Wrapper to keep the Google client alive during streaming."""
+
+    def __init__(self, stream_generator, client):
+        self._stream = stream_generator
+        self._client = client  # strong reference so the client is not garbage collected mid-stream
+
+    def __iter__(self):
+        return self._stream.__iter__()
+
+    def __next__(self):
+        return next(self._stream)
+
+    def __aiter__(self):
+        return self._stream.__aiter__()
+
+    async def __anext__(self):
+        return await self._stream.__anext__()
+
+
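The wrapper exists because a streaming generator pulls lazily from the client's connection; if nothing else references the client, it can be finalized before the stream is drained. A toy demonstration of the keep-alive pattern, with no Google SDK involved:

class KeepAliveStream:
    """Toy analogue of _GoogleStreamWrapper: pins the client while streaming."""

    def __init__(self, stream, client):
        self._stream = stream
        self._client = client  # strong reference keeps the client alive

    def __iter__(self):
        return iter(self._stream)


class DummyClient:
    def stream(self):
        yield from ("chunk-1", "chunk-2")


def make_stream():
    client = DummyClient()
    return KeepAliveStream(client.stream(), client)  # client would otherwise go out of scope


for chunk in make_stream():
    print(chunk)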
+def google_chat_request(client, **kwargs):
+    from google.genai.chats import Content
+
+    stream = kwargs.pop("stream", False)
+    model = kwargs.get("model", "gemini-2.0-flash")
+    history = [Content(**item) for item in kwargs.get("history", [])]
+    generation_config = kwargs.get("generation_config", {})
+    chat = client.chats.create(model=model, history=history, config=generation_config)
+    last_message = history[-1].parts if history else ""
+    if stream:
+        stream_gen = chat.send_message_stream(message=last_message)
+        return _GoogleStreamWrapper(stream_gen, client)
+    return chat.send_message(message=last_message)
+
+
+def google_completions_request(client, **kwargs):
+    config = kwargs.pop("generation_config", {})
+    model = kwargs.get("model", "gemini-2.0-flash")
+    contents = kwargs.get("contents", [])
+    stream = kwargs.pop("stream", False)
+    if stream:
+        stream_gen = client.models.generate_content_stream(model=model, contents=contents, config=config)
+        return _GoogleStreamWrapper(stream_gen, client)
+    return client.models.generate_content(model=model, contents=contents, config=config)
+
+
+MAP_TYPE_TO_GOOGLE_FUNCTION = {
+    "chat": google_chat_request,
+    "completion": google_completions_request,
+}
+
+
+def google_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from google import genai
+
+    if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
+        client = genai.Client(
+            vertexai=True,
+            project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
+            location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
+        )
+    else:
+        client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"))
+    request_to_make = MAP_TYPE_TO_GOOGLE_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+    return request_to_make(client, **function_kwargs)
+
+
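Client construction is driven entirely by environment variables; the two paths as read by the code above (values are placeholders):

import os

# Path 1: Vertex AI (uses Google Cloud credentials)
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true"
os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"  # placeholder
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"  # placeholder

# Path 2: API key (leave GOOGLE_GENAI_USE_VERTEXAI unset)
# os.environ["GEMINI_API_KEY"] = "..."  # checked first
# os.environ["GOOGLE_API_KEY"] = "..."  # fallback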
+async def agoogle_chat_request(client, **kwargs):
+    from google.genai.chats import Content
+
+    stream = kwargs.pop("stream", False)
+    model = kwargs.get("model", "gemini-2.0-flash")
+    history = [Content(**item) for item in kwargs.get("history", [])]
+    generation_config = kwargs.get("generation_config", {})
+    chat = client.aio.chats.create(model=model, history=history, config=generation_config)
+    last_message = history[-1].parts[0] if history else ""
+    if stream:
+        stream_gen = await chat.send_message_stream(message=last_message)
+        return _GoogleStreamWrapper(stream_gen, client)
+    return await chat.send_message(message=last_message)
+
+
+async def agoogle_completions_request(client, **kwargs):
+    config = kwargs.pop("generation_config", {})
+    model = kwargs.get("model", "gemini-2.0-flash")
+    contents = kwargs.get("contents", [])
+    stream = kwargs.pop("stream", False)
+    if stream:
+        stream_gen = await client.aio.models.generate_content_stream(model=model, contents=contents, config=config)
+        return _GoogleStreamWrapper(stream_gen, client)
+    return await client.aio.models.generate_content(model=model, contents=contents, config=config)
+
+
+AMAP_TYPE_TO_GOOGLE_FUNCTION = {
+    "chat": agoogle_chat_request,
+    "completion": agoogle_completions_request,
+}
+
+
+async def agoogle_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from google import genai
+
+    if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
+        client = genai.Client(
+            vertexai=True,
+            project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
+            location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
+        )
+    else:
+        client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"))
+    request_to_make = AMAP_TYPE_TO_GOOGLE_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
+    return await request_to_make(client, **function_kwargs)
+
+
+def vertexai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    if "gemini" in prompt_blueprint["metadata"]["model"]["name"]:
+        return google_request(
+            prompt_blueprint=prompt_blueprint,
+            client_kwargs=client_kwargs,
+            function_kwargs=function_kwargs,
+        )
+
+    if "claude" in prompt_blueprint["metadata"]["model"]["name"]:
+        from anthropic import AnthropicVertex
+
+        client = AnthropicVertex(**client_kwargs)
+        if prompt_blueprint["prompt_template"]["type"] == "chat":
+            return anthropic_chat_request(client=client, **function_kwargs)
+        raise NotImplementedError(
+            f"Unsupported prompt template type '{prompt_blueprint['prompt_template']['type']}' for Anthropic Vertex AI"
+        )
+
+    raise NotImplementedError(
+        f"Vertex AI request for model {prompt_blueprint['metadata']['model']['name']} is not implemented yet."
+    )
+
+
+async def avertexai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    if "gemini" in prompt_blueprint["metadata"]["model"]["name"]:
+        return await agoogle_request(
+            prompt_blueprint=prompt_blueprint,
+            client_kwargs=client_kwargs,
+            function_kwargs=function_kwargs,
+        )
+
+    if "claude" in prompt_blueprint["metadata"]["model"]["name"]:
+        from anthropic import AsyncAnthropicVertex
+
+        client = AsyncAnthropicVertex(**client_kwargs)
+        if prompt_blueprint["prompt_template"]["type"] == "chat":
+            return await aanthropic_chat_request(client=client, **function_kwargs)
+        raise NotImplementedError(
+            f"Unsupported prompt template type '{prompt_blueprint['prompt_template']['type']}' for Anthropic Vertex AI"
+        )
+
+    raise NotImplementedError(
+        f"Vertex AI request for model {prompt_blueprint['metadata']['model']['name']} is not implemented yet."
+    )
+
+
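Routing is a substring match on the model name; a minimal reproduction of just that decision:

def route(model_name: str) -> str:
    if "gemini" in model_name:
        return "google"
    if "claude" in model_name:
        return "anthropic-vertex"
    raise NotImplementedError(f"Vertex AI request for model {model_name} is not implemented yet.")


print(route("gemini-2.0-flash"))  # google
print(route("claude-3-5-sonnet"))  # anthropic-vertex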
+def amazon_bedrock_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    import boto3
+
+    bedrock_client = boto3.client(
+        "bedrock-runtime",
+        aws_access_key_id=function_kwargs.pop("aws_access_key", None),
+        aws_secret_access_key=function_kwargs.pop("aws_secret_key", None),
+        region_name=function_kwargs.pop("aws_region", "us-east-1"),
+    )
+
+    stream = function_kwargs.pop("stream", False)
+
+    if stream:
+        return bedrock_client.converse_stream(**function_kwargs)
+    else:
+        return bedrock_client.converse(**function_kwargs)
+
+
+async def aamazon_bedrock_request(
+    prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
+):
+    import aioboto3
+
+    aws_access_key = function_kwargs.pop("aws_access_key", None)
+    aws_secret_key = function_kwargs.pop("aws_secret_key", None)
+    aws_region = function_kwargs.pop("aws_region", "us-east-1")
+
+    session_kwargs = {}
+    if aws_access_key:
+        session_kwargs["aws_access_key_id"] = aws_access_key
+    if aws_secret_key:
+        session_kwargs["aws_secret_access_key"] = aws_secret_key
+    if aws_region:
+        session_kwargs["region_name"] = aws_region
+
+    stream = function_kwargs.pop("stream", False)
+    session = aioboto3.Session()
+
+    async with session.client("bedrock-runtime", **session_kwargs) as client:
+        if stream:
+            return await client.converse_stream(**function_kwargs)
+        else:
+            return await client.converse(**function_kwargs)
+
+
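The remaining kwargs are passed straight to the Bedrock Converse API, so callers supply Converse-shaped arguments. A hedged sketch; the model ID and message are placeholders, and a real call needs AWS credentials:

from promptlayer.utils import amazon_bedrock_request

blueprint = {"metadata": {"model": {"name": "bedrock"}}, "prompt_template": {"type": "chat"}}  # unused here

response = amazon_bedrock_request(
    prompt_blueprint=blueprint,
    client_kwargs={},
    function_kwargs={
        "aws_region": "us-east-1",
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model ID
        "messages": [{"role": "user", "content": [{"text": "Hello"}]}],  # Converse message shape
    },
)
print(response["output"]["message"])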
+def anthropic_bedrock_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict):
+    from anthropic import AnthropicBedrock
+
+    client = AnthropicBedrock(
+        aws_access_key=function_kwargs.pop("aws_access_key", None),
+        aws_secret_key=function_kwargs.pop("aws_secret_key", None),
+        aws_region=function_kwargs.pop("aws_region", None),
+        aws_session_token=function_kwargs.pop("aws_session_token", None),
+        base_url=function_kwargs.pop("base_url", None),
+        **client_kwargs,
+    )
+    if prompt_blueprint["prompt_template"]["type"] == "chat":
+        return anthropic_chat_request(client=client, **function_kwargs)
+    elif prompt_blueprint["prompt_template"]["type"] == "completion":
+        return anthropic_completions_request(client=client, **function_kwargs)
+    raise NotImplementedError(
+        f"Unsupported prompt template type '{prompt_blueprint['prompt_template']['type']}' for Anthropic Bedrock"
+    )
+
+
+async def aanthropic_bedrock_request(
+    prompt_blueprint: GetPromptTemplateResponse, client_kwargs: dict, function_kwargs: dict
+):
+    from anthropic import AsyncAnthropicBedrock
+
+    client = AsyncAnthropicBedrock(
+        aws_access_key=function_kwargs.pop("aws_access_key", None),
+        aws_secret_key=function_kwargs.pop("aws_secret_key", None),
+        aws_region=function_kwargs.pop("aws_region", None),
+        aws_session_token=function_kwargs.pop("aws_session_token", None),
+        base_url=function_kwargs.pop("base_url", None),
+        **client_kwargs,
+    )
+    if prompt_blueprint["prompt_template"]["type"] == "chat":
+        return await aanthropic_chat_request(client=client, **function_kwargs)
+    elif prompt_blueprint["prompt_template"]["type"] == "completion":
+        return await aanthropic_completions_request(client=client, **function_kwargs)
+    raise NotImplementedError(
+        f"Unsupported prompt template type '{prompt_blueprint['prompt_template']['type']}' for Anthropic Bedrock"
+    )