promptlayer 1.0.70__py3-none-any.whl → 1.0.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

promptlayer/utils.py CHANGED
@@ -7,15 +7,24 @@ import logging
 import os
 import sys
 import types
+from contextlib import asynccontextmanager
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
 from uuid import uuid4
 
 import httpx
 import requests
+import urllib3
+import urllib3.util
 from ably import AblyRealtime
 from ably.types.message import Message
+from centrifuge import (
+    Client,
+    PublicationContext,
+    SubscriptionEventHandler,
+    SubscriptionState,
+)
 from opentelemetry import context, trace
 
 from promptlayer.types import RequestLog
@@ -28,8 +37,7 @@ from promptlayer.types.prompt_template import (
 )
 
 # Configuration
-# TODO(dmu) MEDIUM: Use `PROMPTLAYER_` prefix instead of `_PROMPTLAYER` suffix
-URL_API_PROMPTLAYER = os.environ.setdefault("URL_API_PROMPTLAYER", "https://api.promptlayer.com")
+
 RERAISE_ORIGINAL_EXCEPTION = os.getenv("PROMPTLAYER_RE_RAISE_ORIGINAL_EXCEPTION", "False").lower() == "true"
 RAISE_FOR_STATUS = os.getenv("PROMPTLAYER_RAISE_FOR_STATUS", "False").lower() == "true"
 DEFAULT_HTTP_TIMEOUT = 5
@@ -37,7 +45,9 @@ DEFAULT_HTTP_TIMEOUT = 5
 WORKFLOW_RUN_URL_TEMPLATE = "{base_url}/workflows/{workflow_id}/run"
 WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE = "workflows:{workflow_id}:run:{channel_name_suffix}"
 SET_WORKFLOW_COMPLETE_MESSAGE = "SET_WORKFLOW_COMPLETE"
-WS_TOKEN_REQUEST_LIBRARY_URL = URL_API_PROMPTLAYER + "/ws-token-request-library"
+WS_TOKEN_REQUEST_LIBRARY_URL = (
+    f"{os.getenv('PROMPTLAYER_BASE_URL', 'https://api.promptlayer.com')}/ws-token-request-library"
+)
 
 
 logger = logging.getLogger(__name__)
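
Note: this release removes the module-level `URL_API_PROMPTLAYER` constant (previously seeded via `os.environ.setdefault("URL_API_PROMPTLAYER", ...)`); the rest of the diff threads an explicit `base_url` argument through every function instead. Only the WebSocket token URL still reads the environment, now via `PROMPTLAYER_BASE_URL`. A minimal sketch of the new default resolution, assuming no other overrides:

```python
import os

# Falls back to the public API host when PROMPTLAYER_BASE_URL is unset.
base_url = os.getenv("PROMPTLAYER_BASE_URL", "https://api.promptlayer.com")
ws_token_url = f"{base_url}/ws-token-request-library"
```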
@@ -71,10 +81,12 @@ def _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name):
     return workflow_id_or_name
 
 
-async def _get_final_output(execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]) -> Dict[str, Any]:
+async def _get_final_output(
+    base_url: str, execution_id: int, return_all_outputs: bool, *, headers: Dict[str, str]
+) -> Dict[str, Any]:
     async with httpx.AsyncClient() as client:
         response = await client.get(
-            f"{URL_API_PROMPTLAYER}/workflow-version-execution-results",
+            f"{base_url}/workflow-version-execution-results",
             headers=headers,
             params={"workflow_version_execution_id": execution_id, "return_all_outputs": return_all_outputs},
         )
@@ -84,14 +96,14 @@ async def _get_final_output(execution_id: int, return_all_outputs: bool, *, head
 
 # TODO(dmu) MEDIUM: Consider putting all these functions into a class, so we do not have to pass
 # `authorization_headers` into each function
-async def _resolve_workflow_id(workflow_id_or_name: Union[int, str], headers):
+async def _resolve_workflow_id(base_url: str, workflow_id_or_name: Union[int, str], headers):
     if isinstance(workflow_id_or_name, int):
         return workflow_id_or_name
 
     # TODO(dmu) LOW: Should we warn user here to avoid using workflow names in favor of workflow id?
     async with _make_httpx_client() as client:
         # TODO(dmu) MEDIUM: Generalize the way we make async calls to PromptLayer API and reuse it everywhere
-        response = await client.get(f"{URL_API_PROMPTLAYER}/workflows/{workflow_id_or_name}", headers=headers)
+        response = await client.get(f"{base_url}/workflows/{workflow_id_or_name}", headers=headers)
         if RAISE_FOR_STATUS:
             response.raise_for_status()
         elif response.status_code != 200:
@@ -100,11 +112,11 @@ async def _resolve_workflow_id(workflow_id_or_name: Union[int, str], headers):
     return response.json()["workflow"]["id"]
 
 
-async def _get_ably_token(channel_name, authentication_headers):
+async def _get_ably_token(base_url: str, channel_name, authentication_headers):
     try:
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/ws-token-request-library",
+                f"{base_url}/ws-token-request-library",
                 headers=authentication_headers,
                 params={"capability": channel_name},
             )
@@ -115,7 +127,7 @@ async def _get_ably_token(channel_name, authentication_headers):
                 response,
                 "PromptLayer had the following error while getting WebSocket token",
             )
-        return response.json()["token_details"]["token"]
+        return response.json()
     except Exception as ex:
         error_message = f"Failed to get WebSocket token: {ex}"
         print(error_message)  # TODO(dmu) MEDIUM: Remove prints in favor of logging
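
Note: `_get_ably_token` now returns the parsed JSON body instead of just the token string, so callers pick out the fields they need. Judging from the call sites later in this diff, the response presumably looks like the sketch below; only `token_details.token` and the optional `messaging_backend` key are visible in the diff, everything else is an assumption:

```python
# Assumed response shape (inferred from call sites in this diff):
ably_token = {
    "token_details": {"token": "<websocket-token>"},
    "messaging_backend": "centrifugo",  # optional; any other value falls back to Ably
}
token = ably_token["token_details"]["token"]
```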
@@ -126,7 +138,7 @@ async def _get_ably_token(channel_name, authentication_headers):
         raise Exception(error_message)
 
 
-def _make_message_listener(results_future, execution_id_future, return_all_outputs, headers):
+def _make_message_listener(base_url: str, results_future, execution_id_future, return_all_outputs, headers):
     # We need this function to be mocked by unittests
     async def message_listener(message: Message):
         if results_future.cancelled() or message.name != SET_WORKFLOW_COMPLETE_MESSAGE:
@@ -140,7 +152,7 @@ def _make_message_listener(results_future, execution_id_future, return_all_outpu
         if (result_code := message_data.get("result_code")) in (FinalOutputCode.OK.value, None):
             results = message_data["final_output"]
         elif result_code == FinalOutputCode.EXCEEDS_SIZE_LIMIT.value:
-            results = await _get_final_output(execution_id, return_all_outputs, headers=headers)
+            results = await _get_final_output(base_url, execution_id, return_all_outputs, headers=headers)
         else:
             raise NotImplementedError(f"Unsupported final output code: {result_code}")
@@ -149,15 +161,20 @@ def _make_message_listener(results_future, execution_id_future, return_all_outpu
     return message_listener
 
 
-async def _subscribe_to_workflow_completion_channel(channel, execution_id_future, return_all_outputs, headers):
+async def _subscribe_to_workflow_completion_channel(
+    base_url: str, channel, execution_id_future, return_all_outputs, headers
+):
     results_future = asyncio.Future()
-    message_listener = _make_message_listener(results_future, execution_id_future, return_all_outputs, headers)
+    message_listener = _make_message_listener(
+        base_url, results_future, execution_id_future, return_all_outputs, headers
+    )
     await channel.subscribe(SET_WORKFLOW_COMPLETE_MESSAGE, message_listener)
     return results_future, message_listener
 
 
 async def _post_workflow_id_run(
     *,
+    base_url: str,
     authentication_headers,
     workflow_id,
     input_variables: Dict[str, Any],
@@ -168,7 +185,7 @@ async def _post_workflow_id_run(
     channel_name_suffix: str,
     _url_template: str = WORKFLOW_RUN_URL_TEMPLATE,
 ):
-    url = _url_template.format(base_url=URL_API_PROMPTLAYER, workflow_id=workflow_id)
+    url = _url_template.format(base_url=base_url, workflow_id=workflow_id)
     payload = {
         "input_variables": input_variables,
         "metadata": metadata,
@@ -215,14 +232,53 @@ def _make_channel_name_suffix():
     return uuid4().hex
 
 
+MessageCallback = Callable[[Message], Coroutine[None, None, None]]
+
+
+class SubscriptionEventLoggerHandler(SubscriptionEventHandler):
+    def __init__(self, callback: MessageCallback):
+        self.callback = callback
+
+    async def on_publication(self, ctx: PublicationContext):
+        message_name = ctx.pub.data.get("message_name", "unknown")
+        data = ctx.pub.data.get("data", "")
+        message = Message(name=message_name, data=data)
+        await self.callback(message)
+
+
+@asynccontextmanager
+async def centrifugo_client(address: str, token: str):
+    client = Client(address, token=token)
+    try:
+        await client.connect()
+        yield client
+    finally:
+        await client.disconnect()
+
+
+@asynccontextmanager
+async def centrifugo_subscription(client: Client, topic: str, message_listener: MessageCallback):
+    subscription = client.new_subscription(
+        topic,
+        events=SubscriptionEventLoggerHandler(message_listener),
+    )
+    try:
+        await subscription.subscribe()
+        yield
+    finally:
+        if subscription.state == SubscriptionState.SUBSCRIBED:
+            await subscription.unsubscribe()
+
+
 async def arun_workflow_request(
     *,
+    api_key: str,
+    base_url: str,
     workflow_id_or_name: Optional[Union[int, str]] = None,
     input_variables: Dict[str, Any],
     metadata: Optional[Dict[str, Any]] = None,
     workflow_label_name: Optional[str] = None,
     workflow_version_number: Optional[int] = None,
-    api_key: str,
     return_all_outputs: Optional[bool] = False,
     timeout: Optional[int] = 3600,
     # `workflow_name` deprecated, kept for backward compatibility only.
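
Note: the two `@asynccontextmanager` helpers above guarantee that the Centrifugo connection and subscription are torn down even when the awaited body raises. A minimal usage sketch, with a placeholder address, topic, and listener (the real values are derived inside `arun_workflow_request` below):

```python
import asyncio

from ably.types.message import Message

from promptlayer.utils import centrifugo_client, centrifugo_subscription


async def demo():
    async def on_message(message: Message):
        print(message.name, message.data)

    # Placeholder address and topic; arun_workflow_request computes the real ones.
    async with centrifugo_client("wss://example.invalid/connection/websocket", "<token>") as client:
        async with centrifugo_subscription(client, "workflows:1:run:abc", on_message):
            await asyncio.sleep(1)  # stay subscribed while the workflow runs

# asyncio.run(demo())  # would connect given a real address and token
```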
@@ -230,22 +286,50 @@ async def arun_workflow_request(
 ):
     headers = {"X-API-KEY": api_key}
     workflow_id = await _resolve_workflow_id(
-        _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
+        base_url, _get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name), headers
     )
     channel_name_suffix = _make_channel_name_suffix()
     channel_name = WORKFLOW_RUN_CHANNEL_NAME_TEMPLATE.format(
         workflow_id=workflow_id, channel_name_suffix=channel_name_suffix
     )
-    ably_token = await _get_ably_token(channel_name, headers)
-    async with AblyRealtime(token=ably_token) as ably_client:
+    ably_token = await _get_ably_token(base_url, channel_name, headers)
+    token = ably_token["token_details"]["token"]
+
+    execution_id_future = asyncio.Future[int]()
+
+    if ably_token.get("messaging_backend") == "centrifugo":
+        address = urllib3.util.parse_url(base_url)._replace(scheme="wss", path="/connection/websocket").url
+        async with centrifugo_client(address, token) as client:
+            results_future = asyncio.Future[dict[str, Any]]()
+            async with centrifugo_subscription(
+                client,
+                channel_name,
+                _make_message_listener(base_url, results_future, execution_id_future, return_all_outputs, headers),
+            ):
+                execution_id = await _post_workflow_id_run(
+                    base_url=base_url,
+                    authentication_headers=headers,
+                    workflow_id=workflow_id,
+                    input_variables=input_variables,
+                    metadata=metadata,
+                    workflow_label_name=workflow_label_name,
+                    workflow_version_number=workflow_version_number,
+                    return_all_outputs=return_all_outputs,
+                    channel_name_suffix=channel_name_suffix,
+                )
+                execution_id_future.set_result(execution_id)
+                await asyncio.wait_for(results_future, timeout)
+                return results_future.result()
+
+    async with AblyRealtime(token=token) as ably_client:
         # It is crucial to subscribe before running a workflow, otherwise we may miss a completion message
         channel = ably_client.channels.get(channel_name)
-        execution_id_future = asyncio.Future()
         results_future, message_listener = await _subscribe_to_workflow_completion_channel(
-            channel, execution_id_future, return_all_outputs, headers
+            base_url, channel, execution_id_future, return_all_outputs, headers
         )
 
         execution_id = await _post_workflow_id_run(
+            base_url=base_url,
            authentication_headers=headers,
            workflow_id=workflow_id,
            input_variables=input_variables,
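
Note: when the token response selects Centrifugo, the WebSocket address is derived from the HTTP base URL by swapping the scheme and path. `urllib3.util.parse_url` returns a `Url` named tuple, which is why `_replace` works here. A small sketch of the derivation:

```python
import urllib3.util

url = urllib3.util.parse_url("https://api.promptlayer.com")
# Url is a named tuple; _replace returns a copy with the given fields changed.
ws_address = url._replace(scheme="wss", path="/connection/websocket").url
print(ws_address)  # wss://api.promptlayer.com/connection/websocket
```

Also worth noting: the parametrized futures (`asyncio.Future[int]()`, `asyncio.Future[dict[str, Any]]()`) rely on runtime subscription of `asyncio.Future` and builtin `dict`, which requires Python 3.9 or newer.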
@@ -261,6 +345,8 @@ async def arun_workflow_request(
 
 
 def promptlayer_api_handler(
+    api_key: str,
+    base_url: str,
     function_name,
     provider_type,
     args,
@@ -269,7 +355,6 @@ def promptlayer_api_handler(
     response,
     request_start_time,
     request_end_time,
-    api_key,
     return_pl_id=False,
     llm_request_span_id=None,
 ):
@@ -292,9 +377,11 @@ def promptlayer_api_handler(
                 "llm_request_span_id": llm_request_span_id,
             },
             api_key=api_key,
+            base_url=base_url,
         )
     else:
         request_id = promptlayer_api_request(
+            base_url=base_url,
             function_name=function_name,
             provider_type=provider_type,
             args=args,
@@ -313,6 +400,8 @@ def promptlayer_api_handler(
 
 
 async def promptlayer_api_handler_async(
+    api_key: str,
+    base_url: str,
     function_name,
     provider_type,
     args,
@@ -321,13 +410,14 @@ async def promptlayer_api_handler_async(
     response,
     request_start_time,
     request_end_time,
-    api_key,
     return_pl_id=False,
     llm_request_span_id=None,
 ):
     return await run_in_thread_async(
         None,
         promptlayer_api_handler,
+        api_key,
+        base_url,
         function_name,
         provider_type,
         args,
@@ -336,7 +426,6 @@ async def promptlayer_api_handler_async(
         response,
         request_start_time,
         request_end_time,
-        api_key,
         return_pl_id=return_pl_id,
         llm_request_span_id=llm_request_span_id,
     )
@@ -356,6 +445,7 @@ def convert_native_object_to_dict(native_object):
 
 def promptlayer_api_request(
     *,
+    base_url: str,
     function_name,
     provider_type,
     args,
@@ -376,7 +466,7 @@ def promptlayer_api_request(
         response = response.dict()
     try:
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/track-request",
+            f"{base_url}/track-request",
             json={
                 "function_name": function_name,
                 "provider_type": provider_type,
@@ -405,10 +495,10 @@ def promptlayer_api_request(
     return request_response.json().get("request_id")
 
 
-def track_request(**body):
+def track_request(base_url: str, **body):
     try:
         response = requests.post(
-            f"{URL_API_PROMPTLAYER}/track-request",
+            f"{base_url}/track-request",
             json=body,
         )
         if response.status_code != 200:
@@ -421,11 +511,11 @@ def track_request(**body):
         return {}
 
 
-async def atrack_request(**body: Any) -> Dict[str, Any]:
+async def atrack_request(base_url: str, **body: Any) -> Dict[str, Any]:
     try:
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/track-request",
+                f"{base_url}/track-request",
                 json=body,
             )
             if RAISE_FOR_STATUS:
@@ -468,7 +558,7 @@ def promptlayer_api_request_async(
     )
 
 
-def promptlayer_get_prompt(prompt_name, api_key, version: int = None, label: str = None):
+def promptlayer_get_prompt(api_key: str, base_url: str, prompt_name, version: int = None, label: str = None):
     """
     Get a prompt from the PromptLayer library
     version: version of the prompt to get, None for latest
@@ -476,7 +566,7 @@ def promptlayer_get_prompt(prompt_name, api_key, version: int = None, label: str
     """
     try:
         request_response = requests.get(
-            f"{URL_API_PROMPTLAYER}/library-get-prompt-template",
+            f"{base_url}/library-get-prompt-template",
             headers={"X-API-KEY": api_key},
             params={"prompt_name": prompt_name, "version": version, "label": label},
         )
@@ -491,10 +581,12 @@ def promptlayer_get_prompt(prompt_name, api_key, version: int = None, label: str
     return request_response.json()
 
 
-def promptlayer_publish_prompt(prompt_name, prompt_template, commit_message, tags, api_key, metadata=None):
+def promptlayer_publish_prompt(
+    api_key: str, base_url: str, prompt_name, prompt_template, commit_message, tags, metadata=None
+):
     try:
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/library-publish-prompt-template",
+            f"{base_url}/library-publish-prompt-template",
             json={
                 "prompt_name": prompt_name,
                 "prompt_template": prompt_template,
@@ -514,10 +606,10 @@ def promptlayer_publish_prompt(prompt_name, prompt_template, commit_message, tag
     return True
 
 
-def promptlayer_track_prompt(request_id, prompt_name, input_variables, api_key, version, label):
+def promptlayer_track_prompt(api_key: str, base_url: str, request_id, prompt_name, input_variables, version, label):
     try:
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/library-track-prompt",
+            f"{base_url}/library-track-prompt",
             json={
                 "request_id": request_id,
                 "prompt_name": prompt_name,
@@ -543,14 +635,15 @@ def promptlayer_track_prompt(request_id, prompt_name, input_variables, api_key,
 
 
 async def apromptlayer_track_prompt(
+    api_key: str,
+    base_url: str,
     request_id: str,
     prompt_name: str,
     input_variables: Dict[str, Any],
-    api_key: Optional[str] = None,
     version: Optional[int] = None,
     label: Optional[str] = None,
 ) -> bool:
-    url = f"{URL_API_PROMPTLAYER}/library-track-prompt"
+    url = f"{base_url}/library-track-prompt"
     payload = {
         "request_id": request_id,
         "prompt_name": prompt_name,
@@ -581,10 +674,10 @@ async def apromptlayer_track_prompt(
     return True
 
 
-def promptlayer_track_metadata(request_id, metadata, api_key):
+def promptlayer_track_metadata(api_key: str, base_url: str, request_id, metadata):
     try:
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/library-track-metadata",
+            f"{base_url}/library-track-metadata",
             json={
                 "request_id": request_id,
                 "metadata": metadata,
@@ -606,8 +699,8 @@ def promptlayer_track_metadata(request_id, metadata, api_key):
     return True
 
 
-async def apromptlayer_track_metadata(request_id: str, metadata: Dict[str, Any], api_key: Optional[str] = None) -> bool:
-    url = f"{URL_API_PROMPTLAYER}/library-track-metadata"
+async def apromptlayer_track_metadata(api_key: str, base_url: str, request_id: str, metadata: Dict[str, Any]) -> bool:
+    url = f"{base_url}/library-track-metadata"
     payload = {
         "request_id": request_id,
         "metadata": metadata,
@@ -635,13 +728,13 @@ async def apromptlayer_track_metadata(request_id: str, metadata: Dict[str, Any],
     return True
 
 
-def promptlayer_track_score(request_id, score, score_name, api_key):
+def promptlayer_track_score(api_key: str, base_url: str, request_id, score, score_name):
     try:
         data = {"request_id": request_id, "score": score, "api_key": api_key}
         if score_name is not None:
             data["name"] = score_name
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/library-track-score",
+            f"{base_url}/library-track-score",
             json=data,
         )
         if request_response.status_code != 200:
@@ -660,12 +753,13 @@ def promptlayer_track_score(request_id, score, score_name, api_key):
 
 
 async def apromptlayer_track_score(
+    api_key: str,
+    base_url: str,
     request_id: str,
     score: float,
     score_name: Optional[str],
-    api_key: Optional[str] = None,
 ) -> bool:
-    url = f"{URL_API_PROMPTLAYER}/library-track-score"
+    url = f"{base_url}/library-track-score"
     data = {
         "request_id": request_id,
         "score": score,
@@ -753,11 +847,12 @@ def build_anthropic_content_blocks(events):
 
 
 class GeneratorProxy:
-    def __init__(self, generator, api_request_arguments, api_key):
+    def __init__(self, generator, api_request_arguments, api_key, base_url):
         self.generator = generator
         self.results = []
         self.api_request_arugments = api_request_arguments
         self.api_key = api_key
+        self.base_url = base_url
 
     def __iter__(self):
         return self
@@ -772,6 +867,7 @@ class GeneratorProxy:
             await self.generator._AsyncMessageStreamManager__api_request,
             api_request_arguments,
             self.api_key,
+            self.base_url,
         )
 
     def __enter__(self):
@@ -782,6 +878,7 @@ class GeneratorProxy:
             stream,
             api_request_arguments,
             self.api_key,
+            self.base_url,
         )
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -800,7 +897,7 @@ class GeneratorProxy:
 
     def __getattr__(self, name):
         if name == "text_stream":  # anthropic async stream
-            return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key)
+            return GeneratorProxy(self.generator.text_stream, self.api_request_arugments, self.api_key, self.base_url)
         return getattr(self.generator, name)
 
     def _abstracted_next(self, result):
@@ -822,6 +919,7 @@ class GeneratorProxy:
 
         if end_anthropic or end_openai:
             request_id = promptlayer_api_request(
+                base_url=self.base_url,
                 function_name=self.api_request_arugments["function_name"],
                 provider_type=self.api_request_arugments["provider_type"],
                 args=self.api_request_arugments["args"],
938
1036
 
939
1037
 
940
1038
  async def async_wrapper(
1039
+ api_key: str,
1040
+ base_url: str,
941
1041
  coroutine_obj,
942
1042
  return_pl_id,
943
1043
  request_start_time,
944
1044
  function_name,
945
1045
  provider_type,
946
1046
  tags,
947
- api_key: str = None,
948
1047
  llm_request_span_id: str = None,
949
1048
  tracer=None,
950
1049
  *args,
@@ -957,6 +1056,8 @@ async def async_wrapper(
957
1056
  response = await coroutine_obj
958
1057
  request_end_time = datetime.datetime.now().timestamp()
959
1058
  result = await promptlayer_api_handler_async(
1059
+ api_key,
1060
+ base_url,
960
1061
  function_name,
961
1062
  provider_type,
962
1063
  args,
@@ -965,7 +1066,6 @@ async def async_wrapper(
965
1066
  response,
966
1067
  request_start_time,
967
1068
  request_end_time,
968
- api_key,
969
1069
  return_pl_id=return_pl_id,
970
1070
  llm_request_span_id=llm_request_span_id,
971
1071
  )
@@ -980,10 +1080,10 @@ async def async_wrapper(
980
1080
  context.detach(token)
981
1081
 
982
1082
 
983
- def promptlayer_create_group(api_key: str = None):
1083
+ def promptlayer_create_group(api_key: str, base_url: str):
984
1084
  try:
985
1085
  request_response = requests.post(
986
- f"{URL_API_PROMPTLAYER}/create-group",
1086
+ f"{base_url}/create-group",
987
1087
  json={
988
1088
  "api_key": api_key,
989
1089
  },
@@ -1000,11 +1100,11 @@ def promptlayer_create_group(api_key: str = None):
     return request_response.json()["id"]
 
 
-async def apromptlayer_create_group(api_key: Optional[str] = None) -> str:
+async def apromptlayer_create_group(api_key: str, base_url: str):
     try:
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/create-group",
+                f"{base_url}/create-group",
                 json={
                     "api_key": api_key,
                 },
@@ -1023,10 +1123,10 @@ async def apromptlayer_create_group(api_key: Optional[str] = None) -> str:
         raise Exception(f"PromptLayer had the following error while creating your group: {str(e)}") from e
 
 
-def promptlayer_track_group(request_id, group_id, api_key: str = None):
+def promptlayer_track_group(api_key: str, base_url: str, request_id, group_id):
     try:
         request_response = requests.post(
-            f"{URL_API_PROMPTLAYER}/track-group",
+            f"{base_url}/track-group",
             json={
                 "api_key": api_key,
                 "request_id": request_id,
@@ -1045,7 +1145,7 @@ def promptlayer_track_group(request_id, group_id, api_key: str = None):
     return True
 
 
-async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
+async def apromptlayer_track_group(api_key: str, base_url: str, request_id, group_id):
     try:
         payload = {
             "api_key": api_key,
@@ -1054,7 +1154,7 @@ async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
         }
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/track-group",
+                f"{base_url}/track-group",
                 headers={"X-API-KEY": api_key},
                 json=payload,
             )
@@ -1078,14 +1178,14 @@ async def apromptlayer_track_group(request_id, group_id, api_key: str = None):
 
 
 def get_prompt_template(
-    prompt_name: str, params: Union[GetPromptTemplate, None] = None, api_key: str = None
+    api_key: str, base_url: str, prompt_name: str, params: Union[GetPromptTemplate, None] = None
 ) -> GetPromptTemplateResponse:
     try:
         json_body = {"api_key": api_key}
         if params:
             json_body = {**json_body, **params}
         response = requests.post(
-            f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
+            f"{base_url}/prompt-templates/{prompt_name}",
             headers={"X-API-KEY": api_key},
             json=json_body,
         )
@@ -1104,9 +1204,10 @@ def get_prompt_template(
 
 
 async def aget_prompt_template(
+    api_key: str,
+    base_url: str,
     prompt_name: str,
     params: Union[GetPromptTemplate, None] = None,
-    api_key: str = None,
 ) -> GetPromptTemplateResponse:
     try:
         json_body = {"api_key": api_key}
@@ -1114,7 +1215,7 @@ async def aget_prompt_template(
             json_body.update(params)
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/prompt-templates/{prompt_name}",
+                f"{base_url}/prompt-templates/{prompt_name}",
                 headers={"X-API-KEY": api_key},
                 json=json_body,
             )
@@ -1138,12 +1239,13 @@ async def aget_prompt_template(
 
 
 def publish_prompt_template(
+    api_key: str,
+    base_url: str,
     body: PublishPromptTemplate,
-    api_key: str = None,
 ) -> PublishPromptTemplateResponse:
     try:
         response = requests.post(
-            f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
+            f"{base_url}/rest/prompt-templates",
             headers={"X-API-KEY": api_key},
             json={
                 "prompt_template": {**body},
1161
1263
 
1162
1264
 
1163
1265
  async def apublish_prompt_template(
1266
+ api_key: str,
1267
+ base_url: str,
1164
1268
  body: PublishPromptTemplate,
1165
- api_key: str = None,
1166
1269
  ) -> PublishPromptTemplateResponse:
1167
1270
  try:
1168
1271
  async with _make_httpx_client() as client:
1169
1272
  response = await client.post(
1170
- f"{URL_API_PROMPTLAYER}/rest/prompt-templates",
1273
+ f"{base_url}/rest/prompt-templates",
1171
1274
  headers={"X-API-KEY": api_key},
1172
1275
  json={
1173
1276
  "prompt_template": {**body},
@@ -1193,14 +1296,14 @@ async def apublish_prompt_template(
1193
1296
 
1194
1297
 
1195
1298
  def get_all_prompt_templates(
1196
- page: int = 1, per_page: int = 30, api_key: str = None, label: str = None
1299
+ api_key: str, base_url: str, page: int = 1, per_page: int = 30, label: str = None
1197
1300
  ) -> List[ListPromptTemplateResponse]:
1198
1301
  try:
1199
1302
  params = {"page": page, "per_page": per_page}
1200
1303
  if label:
1201
1304
  params["label"] = label
1202
1305
  response = requests.get(
1203
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1306
+ f"{base_url}/prompt-templates",
1204
1307
  headers={"X-API-KEY": api_key},
1205
1308
  params=params,
1206
1309
  )
@@ -1215,7 +1318,7 @@ def get_all_prompt_templates(
1215
1318
 
1216
1319
 
1217
1320
  async def aget_all_prompt_templates(
1218
- page: int = 1, per_page: int = 30, api_key: str = None, label: str = None
1321
+ api_key: str, base_url: str, page: int = 1, per_page: int = 30, label: str = None
1219
1322
  ) -> List[ListPromptTemplateResponse]:
1220
1323
  try:
1221
1324
  params = {"page": page, "per_page": per_page}
@@ -1223,7 +1326,7 @@ async def aget_all_prompt_templates(
1223
1326
  params["label"] = label
1224
1327
  async with _make_httpx_client() as client:
1225
1328
  response = await client.get(
1226
- f"{URL_API_PROMPTLAYER}/prompt-templates",
1329
+ f"{base_url}/prompt-templates",
1227
1330
  headers={"X-API-KEY": api_key},
1228
1331
  params=params,
1229
1332
  )
@@ -1259,7 +1362,7 @@ def openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwargs: d
1259
1362
  from openai import OpenAI
1260
1363
 
1261
1364
  client = OpenAI(**client_kwargs)
1262
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1365
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1263
1366
 
1264
1367
  if api_type == "chat-completions":
1265
1368
  request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1286,7 +1389,7 @@ async def aopenai_request(prompt_blueprint: GetPromptTemplateResponse, client_kw
1286
1389
  from openai import AsyncOpenAI
1287
1390
 
1288
1391
  client = AsyncOpenAI(**client_kwargs)
1289
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1392
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1290
1393
 
1291
1394
  if api_type == "chat-completions":
1292
1395
  request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1299,7 +1402,7 @@ def azure_openai_request(prompt_blueprint: GetPromptTemplateResponse, client_kwa
1299
1402
  from openai import AzureOpenAI
1300
1403
 
1301
1404
  client = AzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1302
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1405
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1303
1406
 
1304
1407
  if api_type == "chat-completions":
1305
1408
  request_to_make = MAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
@@ -1314,7 +1417,7 @@ async def aazure_openai_request(
1314
1417
  from openai import AsyncAzureOpenAI
1315
1418
 
1316
1419
  client = AsyncAzureOpenAI(azure_endpoint=client_kwargs.pop("base_url", None))
1317
- api_type = prompt_blueprint["metadata"]["model"]["api_type"]
1420
+ api_type = prompt_blueprint["metadata"]["model"].get("api_type", "chat-completions")
1318
1421
 
1319
1422
  if api_type == "chat-completions":
1320
1423
  request_to_make = AMAP_TYPE_TO_OPENAI_FUNCTION[prompt_blueprint["prompt_template"]["type"]]
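
Note: all four OpenAI request helpers now tolerate prompt blueprints whose model metadata omits `api_type`, falling back to the chat-completions path instead of raising `KeyError`. In effect (the metadata dict below is a hypothetical blueprint fragment):

```python
metadata = {"model": {"name": "gpt-4o"}}  # hypothetical: no "api_type" key

# Before: metadata["model"]["api_type"] raised KeyError for such blueprints.
api_type = metadata["model"].get("api_type", "chat-completions")
assert api_type == "chat-completions"
```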
@@ -1378,10 +1481,10 @@ def get_api_key():
     return api_key
 
 
-def util_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
+def util_log_request(api_key: str, base_url: str, **kwargs) -> Union[RequestLog, None]:
     try:
         response = requests.post(
-            f"{URL_API_PROMPTLAYER}/log-request",
+            f"{base_url}/log-request",
             headers={"X-API-KEY": api_key},
             json=kwargs,
         )
@@ -1400,11 +1503,11 @@ def util_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
     return None
 
 
-async def autil_log_request(api_key: str, **kwargs) -> Union[RequestLog, None]:
+async def autil_log_request(api_key: str, base_url: str, **kwargs) -> Union[RequestLog, None]:
     try:
         async with _make_httpx_client() as client:
             response = await client.post(
-                f"{URL_API_PROMPTLAYER}/log-request",
+                f"{base_url}/log-request",
                 headers={"X-API-KEY": api_key},
                 json=kwargs,
             )
@@ -1456,7 +1559,7 @@ def google_chat_request(client, **kwargs):
     history = [Content(**item) for item in kwargs.get("history", [])]
     generation_config = kwargs.get("generation_config", {})
     chat = client.chats.create(model=model, history=history, config=generation_config)
-    last_message = history[-1].parts[0] if history else ""
+    last_message = history[-1].parts if history else ""
     if stream:
         return chat.send_message_stream(message=last_message)
     return chat.send_message(message=last_message)