promptlayer 1.0.72__tar.gz → 1.0.73__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of promptlayer might be problematic. See the registry's release page for more details.

Files changed (28)
  1. {promptlayer-1.0.72 → promptlayer-1.0.73}/PKG-INFO +2 -1
  2. promptlayer-1.0.73/promptlayer/__init__.py +39 -0
  3. promptlayer-1.0.73/promptlayer/exceptions.py +119 -0
  4. promptlayer-1.0.73/promptlayer/groups/__init__.py +24 -0
  5. promptlayer-1.0.73/promptlayer/groups/groups.py +9 -0
  6. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/promptlayer.py +59 -18
  7. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/promptlayer_base.py +2 -1
  8. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/promptlayer_mixins.py +4 -2
  9. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/span_exporter.py +16 -7
  10. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/templates.py +9 -7
  11. promptlayer-1.0.73/promptlayer/track/__init__.py +71 -0
  12. promptlayer-1.0.73/promptlayer/track/track.py +107 -0
  13. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/utils.py +554 -246
  14. {promptlayer-1.0.72 → promptlayer-1.0.73}/pyproject.toml +2 -1
  15. promptlayer-1.0.72/promptlayer/__init__.py +0 -4
  16. promptlayer-1.0.72/promptlayer/groups/__init__.py +0 -22
  17. promptlayer-1.0.72/promptlayer/groups/groups.py +0 -9
  18. promptlayer-1.0.72/promptlayer/track/__init__.py +0 -53
  19. promptlayer-1.0.72/promptlayer/track/track.py +0 -88
  20. {promptlayer-1.0.72 → promptlayer-1.0.73}/LICENSE +0 -0
  21. {promptlayer-1.0.72 → promptlayer-1.0.73}/README.md +0 -0
  22. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/streaming/__init__.py +0 -0
  23. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/streaming/blueprint_builder.py +0 -0
  24. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/streaming/response_handlers.py +0 -0
  25. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/streaming/stream_processor.py +0 -0
  26. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/types/__init__.py +0 -0
  27. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/types/prompt_template.py +0 -0
  28. {promptlayer-1.0.72 → promptlayer-1.0.73}/promptlayer/types/request_log.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: promptlayer
3
- Version: 1.0.72
3
+ Version: 1.0.73
4
4
  Summary: PromptLayer is a platform for prompt engineering and tracks your LLM requests.
5
5
  License: Apache-2.0
6
6
  License-File: LICENSE
@@ -23,6 +23,7 @@ Requires-Dist: nest-asyncio (>=1.6.0,<2.0.0)
23
23
  Requires-Dist: opentelemetry-api (>=1.26.0,<2.0.0)
24
24
  Requires-Dist: opentelemetry-sdk (>=1.26.0,<2.0.0)
25
25
  Requires-Dist: requests (>=2.31.0,<3.0.0)
26
+ Requires-Dist: tenacity (>=9.1.2,<10.0.0)
26
27
  Description-Content-Type: text/markdown
27
28
 
28
29
  <div align="center">
@@ -0,0 +1,39 @@
1
+ from .exceptions import (
2
+ PromptLayerAPIConnectionError,
3
+ PromptLayerAPIError,
4
+ PromptLayerAPIStatusError,
5
+ PromptLayerAPITimeoutError,
6
+ PromptLayerAuthenticationError,
7
+ PromptLayerBadRequestError,
8
+ PromptLayerConflictError,
9
+ PromptLayerError,
10
+ PromptLayerInternalServerError,
11
+ PromptLayerNotFoundError,
12
+ PromptLayerPermissionDeniedError,
13
+ PromptLayerRateLimitError,
14
+ PromptLayerUnprocessableEntityError,
15
+ PromptLayerValidationError,
16
+ )
17
+ from .promptlayer import AsyncPromptLayer, PromptLayer
18
+
19
+ __version__ = "1.0.73"
20
+ __all__ = [
21
+ "PromptLayer",
22
+ "AsyncPromptLayer",
23
+ "__version__",
24
+ # Exceptions
25
+ "PromptLayerError",
26
+ "PromptLayerAPIError",
27
+ "PromptLayerBadRequestError",
28
+ "PromptLayerAuthenticationError",
29
+ "PromptLayerPermissionDeniedError",
30
+ "PromptLayerNotFoundError",
31
+ "PromptLayerConflictError",
32
+ "PromptLayerUnprocessableEntityError",
33
+ "PromptLayerRateLimitError",
34
+ "PromptLayerInternalServerError",
35
+ "PromptLayerAPIStatusError",
36
+ "PromptLayerAPIConnectionError",
37
+ "PromptLayerAPITimeoutError",
38
+ "PromptLayerValidationError",
39
+ ]
@@ -0,0 +1,119 @@
1
+ class PromptLayerError(Exception):
2
+ """Base exception for all PromptLayer SDK errors."""
3
+
4
+ def __init__(self, message: str, response=None, body=None):
5
+ super().__init__(message)
6
+ self.message = message
7
+ self.response = response
8
+ self.body = body
9
+
10
+ def __str__(self):
11
+ return self.message
12
+
13
+
14
+ class PromptLayerAPIError(PromptLayerError):
15
+ """Base exception for API-related errors."""
16
+
17
+ pass
18
+
19
+
20
+ class PromptLayerBadRequestError(PromptLayerAPIError):
21
+ """Exception raised for 400 Bad Request errors.
22
+
23
+ Indicates that the request was malformed or contained invalid parameters.
24
+ """
25
+
26
+ pass
27
+
28
+
29
+ class PromptLayerAuthenticationError(PromptLayerAPIError):
30
+ """Exception raised for 401 Unauthorized errors.
31
+
32
+ Indicates that the API key is missing, invalid, or expired.
33
+ """
34
+
35
+ pass
36
+
37
+
38
+ class PromptLayerPermissionDeniedError(PromptLayerAPIError):
39
+ """Exception raised for 403 Forbidden errors.
40
+
41
+ Indicates that the API key doesn't have permission to perform the requested operation.
42
+ """
43
+
44
+ pass
45
+
46
+
47
+ class PromptLayerNotFoundError(PromptLayerAPIError):
48
+ """Exception raised for 404 Not Found errors.
49
+
50
+ Indicates that the requested resource (e.g., prompt template) was not found.
51
+ """
52
+
53
+ pass
54
+
55
+
56
+ class PromptLayerConflictError(PromptLayerAPIError):
57
+ """Exception raised for 409 Conflict errors.
58
+
59
+ Indicates that the request conflicts with the current state of the resource.
60
+ """
61
+
62
+ pass
63
+
64
+
65
+ class PromptLayerUnprocessableEntityError(PromptLayerAPIError):
66
+ """Exception raised for 422 Unprocessable Entity errors.
67
+
68
+ Indicates that the request was well-formed but contains semantic errors.
69
+ """
70
+
71
+ pass
72
+
73
+
74
+ class PromptLayerRateLimitError(PromptLayerAPIError):
75
+ """Exception raised for 429 Too Many Requests errors.
76
+
77
+ Indicates that the API rate limit has been exceeded.
78
+ """
79
+
80
+ pass
81
+
82
+
83
+ class PromptLayerInternalServerError(PromptLayerAPIError):
84
+ """Exception raised for 500+ Internal Server errors.
85
+
86
+ Indicates that the PromptLayer API encountered an internal error.
87
+ """
88
+
89
+ pass
90
+
91
+
92
+ class PromptLayerAPIStatusError(PromptLayerAPIError):
93
+ """Exception raised for other API errors not covered by specific exception classes."""
94
+
95
+ pass
96
+
97
+
98
+ class PromptLayerAPIConnectionError(PromptLayerError):
99
+ """Exception raised when unable to connect to the API.
100
+
101
+ This can be due to network issues, timeouts, or connection errors.
102
+ """
103
+
104
+ pass
105
+
106
+
107
+ class PromptLayerAPITimeoutError(PromptLayerError):
108
+ """Exception raised when an API request times out."""
109
+
110
+ pass
111
+
112
+
113
+ class PromptLayerValidationError(PromptLayerError):
114
+ """Exception raised when input validation fails.
115
+
116
+ This can be due to invalid types, out of range values, or malformed data.
117
+ """
118
+
119
+ pass
@@ -0,0 +1,24 @@
1
+ from promptlayer.groups.groups import acreate, create
2
+
3
+
4
+ class GroupManager:
5
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
6
+ self.api_key = api_key
7
+ self.base_url = base_url
8
+ self.throw_on_error = throw_on_error
9
+
10
+ def create(self):
11
+ return create(self.api_key, self.base_url, self.throw_on_error)
12
+
13
+
14
+ class AsyncGroupManager:
15
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
16
+ self.api_key = api_key
17
+ self.base_url = base_url
18
+ self.throw_on_error = throw_on_error
19
+
20
+ async def create(self):
21
+ return await acreate(self.api_key, self.base_url, self.throw_on_error)
22
+
23
+
24
+ __all__ = ["GroupManager", "AsyncGroupManager"]
@@ -0,0 +1,9 @@
1
+ from promptlayer.utils import apromptlayer_create_group, promptlayer_create_group
2
+
3
+
4
+ def create(api_key: str, base_url: str, throw_on_error: bool):
5
+ return promptlayer_create_group(api_key, base_url, throw_on_error)
6
+
7
+
8
+ async def acreate(api_key: str, base_url: str, throw_on_error: bool):
9
+ return await apromptlayer_create_group(api_key, base_url, throw_on_error)
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
6
6
 
7
7
  import nest_asyncio
8
8
 
9
+ from promptlayer import exceptions as _exceptions
9
10
  from promptlayer.groups import AsyncGroupManager, GroupManager
10
11
  from promptlayer.promptlayer_base import PromptLayerBase
11
12
  from promptlayer.promptlayer_mixins import PromptLayerMixin
@@ -53,7 +54,11 @@ def is_workflow_results_dict(obj: Any) -> bool:
53
54
 
54
55
  class PromptLayer(PromptLayerMixin):
55
56
  def __init__(
56
- self, api_key: Union[str, None] = None, enable_tracing: bool = False, base_url: Union[str, None] = None
57
+ self,
58
+ api_key: Union[str, None] = None,
59
+ enable_tracing: bool = False,
60
+ base_url: Union[str, None] = None,
61
+ throw_on_error: bool = True,
57
62
  ):
58
63
  if api_key is None:
59
64
  api_key = os.environ.get("PROMPTLAYER_API_KEY")
@@ -66,10 +71,13 @@ class PromptLayer(PromptLayerMixin):
66
71
 
67
72
  self.base_url = get_base_url(base_url)
68
73
  self.api_key = api_key
69
- self.templates = TemplateManager(api_key, self.base_url)
70
- self.group = GroupManager(api_key, self.base_url)
71
- self.tracer_provider, self.tracer = self._initialize_tracer(api_key, self.base_url, enable_tracing)
72
- self.track = TrackManager(api_key, self.base_url)
74
+ self.throw_on_error = throw_on_error
75
+ self.templates = TemplateManager(api_key, self.base_url, self.throw_on_error)
76
+ self.group = GroupManager(api_key, self.base_url, self.throw_on_error)
77
+ self.tracer_provider, self.tracer = self._initialize_tracer(
78
+ api_key, self.base_url, self.throw_on_error, enable_tracing
79
+ )
80
+ self.track = TrackManager(api_key, self.base_url, self.throw_on_error)
73
81
 
74
82
  def __getattr__(
75
83
  self,
@@ -114,7 +122,7 @@ class PromptLayer(PromptLayerMixin):
114
122
  pl_run_span_id,
115
123
  **body,
116
124
  )
117
- return track_request(**track_request_kwargs)
125
+ return track_request(self.base_url, self.throw_on_error, **track_request_kwargs)
118
126
 
119
127
  return _track_request
120
128
 
@@ -144,6 +152,12 @@ class PromptLayer(PromptLayerMixin):
144
152
  model_parameter_overrides=model_parameter_overrides,
145
153
  )
146
154
  prompt_blueprint = self.templates.get(prompt_name, get_prompt_template_params)
155
+ if not prompt_blueprint:
156
+ raise _exceptions.PromptLayerNotFoundError(
157
+ f"Prompt template '{prompt_name}' not found.",
158
+ response=None,
159
+ body=None,
160
+ )
147
161
  prompt_blueprint_model = self._validate_and_extract_model_from_prompt_blueprint(
148
162
  prompt_blueprint=prompt_blueprint, prompt_name=prompt_name
149
163
  )
@@ -218,7 +232,7 @@ class PromptLayer(PromptLayerMixin):
218
232
  metadata=metadata,
219
233
  **body,
220
234
  )
221
- return track_request(self.base_url, **track_request_kwargs)
235
+ return track_request(self.base_url, self.throw_on_error, **track_request_kwargs)
222
236
 
223
237
  def run(
224
238
  self,
@@ -285,6 +299,7 @@ class PromptLayer(PromptLayerMixin):
285
299
  arun_workflow_request(
286
300
  api_key=self.api_key,
287
301
  base_url=self.base_url,
302
+ throw_on_error=self.throw_on_error,
288
303
  workflow_id_or_name=_get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name),
289
304
  input_variables=input_variables or {},
290
305
  metadata=metadata,
@@ -297,10 +312,16 @@ class PromptLayer(PromptLayerMixin):
297
312
  if not return_all_outputs and is_workflow_results_dict(results):
298
313
  output_nodes = [node_data for node_data in results.values() if node_data.get("is_output_node")]
299
314
  if not output_nodes:
300
- raise Exception("Output nodes not found: %S", json.dumps(results, indent=4))
315
+ raise _exceptions.PromptLayerNotFoundError(
316
+ f"Output nodes not found: {json.dumps(results, indent=4)}", response=None, body=results
317
+ )
301
318
 
302
319
  if not any(node.get("status") == "SUCCESS" for node in output_nodes):
303
- raise Exception("None of the output nodes have succeeded", json.dumps(results, indent=4))
320
+ raise _exceptions.PromptLayerAPIError(
321
+ f"None of the output nodes have succeeded: {json.dumps(results, indent=4)}",
322
+ response=None,
323
+ body=results,
324
+ )
304
325
 
305
326
  return results
306
327
  except Exception as ex:
@@ -308,7 +329,9 @@ class PromptLayer(PromptLayerMixin):
308
329
  if RERAISE_ORIGINAL_EXCEPTION:
309
330
  raise
310
331
  else:
311
- raise Exception(f"Error running workflow: {str(ex)}") from ex
332
+ raise _exceptions.PromptLayerAPIError(
333
+ f"Error running workflow: {str(ex)}", response=None, body=None
334
+ ) from ex
312
335
 
313
336
  def log_request(
314
337
  self,
@@ -338,6 +361,7 @@ class PromptLayer(PromptLayerMixin):
338
361
  return util_log_request(
339
362
  self.api_key,
340
363
  self.base_url,
364
+ throw_on_error=self.throw_on_error,
341
365
  provider=provider,
342
366
  model=model,
343
367
  input=input,
@@ -362,7 +386,11 @@ class PromptLayer(PromptLayerMixin):
362
386
 
363
387
  class AsyncPromptLayer(PromptLayerMixin):
364
388
  def __init__(
365
- self, api_key: Union[str, None] = None, enable_tracing: bool = False, base_url: Union[str, None] = None
389
+ self,
390
+ api_key: Union[str, None] = None,
391
+ enable_tracing: bool = False,
392
+ base_url: Union[str, None] = None,
393
+ throw_on_error: bool = True,
366
394
  ):
367
395
  if api_key is None:
368
396
  api_key = os.environ.get("PROMPTLAYER_API_KEY")
@@ -375,10 +403,13 @@ class AsyncPromptLayer(PromptLayerMixin):
375
403
 
376
404
  self.base_url = get_base_url(base_url)
377
405
  self.api_key = api_key
378
- self.templates = AsyncTemplateManager(api_key, self.base_url)
379
- self.group = AsyncGroupManager(api_key, self.base_url)
380
- self.tracer_provider, self.tracer = self._initialize_tracer(api_key, self.base_url, enable_tracing)
381
- self.track = AsyncTrackManager(api_key, self.base_url)
406
+ self.throw_on_error = throw_on_error
407
+ self.templates = AsyncTemplateManager(api_key, self.base_url, self.throw_on_error)
408
+ self.group = AsyncGroupManager(api_key, self.base_url, self.throw_on_error)
409
+ self.tracer_provider, self.tracer = self._initialize_tracer(
410
+ api_key, self.base_url, self.throw_on_error, enable_tracing
411
+ )
412
+ self.track = AsyncTrackManager(api_key, self.base_url, self.throw_on_error)
382
413
 
383
414
  def __getattr__(self, name: Union[Literal["openai"], Literal["anthropic"], Literal["prompts"]]):
384
415
  if name == "openai":
@@ -420,6 +451,7 @@ class AsyncPromptLayer(PromptLayerMixin):
420
451
  return await arun_workflow_request(
421
452
  api_key=self.api_key,
422
453
  base_url=self.base_url,
454
+ throw_on_error=self.throw_on_error,
423
455
  workflow_id_or_name=_get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name),
424
456
  input_variables=input_variables or {},
425
457
  metadata=metadata,
@@ -432,7 +464,9 @@ class AsyncPromptLayer(PromptLayerMixin):
432
464
  if RERAISE_ORIGINAL_EXCEPTION:
433
465
  raise
434
466
  else:
435
- raise Exception(f"Error running workflow: {str(ex)}")
467
+ raise _exceptions.PromptLayerAPIError(
468
+ f"Error running workflow: {str(ex)}", response=None, body=None
469
+ ) from ex
436
470
 
437
471
  async def run(
438
472
  self,
@@ -498,6 +532,7 @@ class AsyncPromptLayer(PromptLayerMixin):
498
532
  return await autil_log_request(
499
533
  self.api_key,
500
534
  self.base_url,
535
+ throw_on_error=self.throw_on_error,
501
536
  provider=provider,
502
537
  model=model,
503
538
  input=input,
@@ -537,7 +572,7 @@ class AsyncPromptLayer(PromptLayerMixin):
537
572
  pl_run_span_id,
538
573
  **body,
539
574
  )
540
- return await atrack_request(self.base_url, **track_request_kwargs)
575
+ return await atrack_request(self.base_url, self.throw_on_error, **track_request_kwargs)
541
576
 
542
577
  return _track_request
543
578
 
@@ -561,7 +596,7 @@ class AsyncPromptLayer(PromptLayerMixin):
561
596
  metadata=metadata,
562
597
  **body,
563
598
  )
564
- return await atrack_request(self.base_url, **track_request_kwargs)
599
+ return await atrack_request(self.base_url, self.throw_on_error, **track_request_kwargs)
565
600
 
566
601
  async def _run_internal(
567
602
  self,
@@ -589,6 +624,12 @@ class AsyncPromptLayer(PromptLayerMixin):
589
624
  model_parameter_overrides=model_parameter_overrides,
590
625
  )
591
626
  prompt_blueprint = await self.templates.get(prompt_name, get_prompt_template_params)
627
+ if not prompt_blueprint:
628
+ raise _exceptions.PromptLayerNotFoundError(
629
+ f"Prompt template '{prompt_name}' not found.",
630
+ response=None,
631
+ body=None,
632
+ )
592
633
  prompt_blueprint_model = self._validate_and_extract_model_from_prompt_blueprint(
593
634
  prompt_blueprint=prompt_blueprint, prompt_name=prompt_name
594
635
  )
@@ -2,6 +2,7 @@ import datetime
2
2
  import inspect
3
3
  import re
4
4
 
5
+ from promptlayer import exceptions as _exceptions
5
6
  from promptlayer.utils import async_wrapper, promptlayer_api_handler
6
7
 
7
8
 
@@ -61,7 +62,7 @@ class PromptLayerBase(object):
61
62
  def __call__(self, *args, **kwargs):
62
63
  tags = kwargs.pop("pl_tags", None)
63
64
  if tags is not None and not isinstance(tags, list):
64
- raise Exception("pl_tags must be a list of strings.")
65
+ raise _exceptions.PromptLayerValidationError("pl_tags must be a list of strings.", response=None, body=None)
65
66
 
66
67
  return_pl_id = kwargs.pop("return_pl_id", False)
67
68
  request_start_time = datetime.datetime.now().timestamp()
@@ -262,11 +262,13 @@ AMAP_PROVIDER_TO_FUNCTION = {
262
262
 
263
263
  class PromptLayerMixin:
264
264
  @staticmethod
265
- def _initialize_tracer(api_key: str, base_url: str, enable_tracing: bool = False):
265
+ def _initialize_tracer(api_key: str, base_url: str, throw_on_error: bool, enable_tracing: bool = False):
266
266
  if enable_tracing:
267
267
  resource = Resource(attributes={ResourceAttributes.SERVICE_NAME: "prompt-layer-library"})
268
268
  tracer_provider = TracerProvider(resource=resource)
269
- promptlayer_exporter = PromptLayerSpanExporter(api_key=api_key, base_url=base_url)
269
+ promptlayer_exporter = PromptLayerSpanExporter(
270
+ api_key=api_key, base_url=base_url, throw_on_error=throw_on_error
271
+ )
270
272
  span_processor = BatchSpanProcessor(promptlayer_exporter)
271
273
  tracer_provider.add_span_processor(span_processor)
272
274
  tracer = tracer_provider.get_tracer(__name__)
@@ -4,11 +4,25 @@ import requests
4
4
  from opentelemetry.sdk.trace import ReadableSpan
5
5
  from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
6
6
 
7
+ from promptlayer.utils import raise_on_bad_response, retry_on_api_error
8
+
7
9
 
8
10
  class PromptLayerSpanExporter(SpanExporter):
9
- def __init__(self, api_key: str, base_url: str):
11
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
10
12
  self.api_key = api_key
11
13
  self.url = f"{base_url}/spans-bulk"
14
+ self.throw_on_error = throw_on_error
15
+
16
+ @retry_on_api_error
17
+ def _post_spans(self, request_data):
18
+ response = requests.post(
19
+ self.url,
20
+ headers={"X-Api-Key": self.api_key, "Content-Type": "application/json"},
21
+ json={"spans": request_data},
22
+ )
23
+ if response.status_code not in (200, 201):
24
+ raise_on_bad_response(response, "PromptLayer had the following error while exporting spans")
25
+ return response
12
26
 
13
27
  def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
14
28
  request_data = []
@@ -47,12 +61,7 @@ class PromptLayerSpanExporter(SpanExporter):
47
61
  request_data.append(span_info)
48
62
 
49
63
  try:
50
- response = requests.post(
51
- self.url,
52
- headers={"X-Api-Key": self.api_key, "Content-Type": "application/json"},
53
- json={"spans": request_data},
54
- )
55
- response.raise_for_status()
64
+ self._post_spans(request_data)
56
65
  return SpanExportResult.SUCCESS
57
66
  except requests.RequestException:
58
67
  return SpanExportResult.FAILURE
@@ -11,27 +11,29 @@ from promptlayer.utils import (
11
11
 
12
12
 
13
13
  class TemplateManager:
14
- def __init__(self, api_key: str, base_url: str):
14
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
15
15
  self.api_key = api_key
16
16
  self.base_url = base_url
17
+ self.throw_on_error = throw_on_error
17
18
 
18
19
  def get(self, prompt_name: str, params: Union[GetPromptTemplate, None] = None):
19
- return get_prompt_template(self.api_key, self.base_url, prompt_name, params)
20
+ return get_prompt_template(self.api_key, self.base_url, self.throw_on_error, prompt_name, params)
20
21
 
21
22
  def publish(self, body: PublishPromptTemplate):
22
- return publish_prompt_template(self.api_key, self.base_url, body)
23
+ return publish_prompt_template(self.api_key, self.base_url, self.throw_on_error, body)
23
24
 
24
25
  def all(self, page: int = 1, per_page: int = 30, label: str = None):
25
- return get_all_prompt_templates(self.api_key, self.base_url, page, per_page, label)
26
+ return get_all_prompt_templates(self.api_key, self.base_url, self.throw_on_error, page, per_page, label)
26
27
 
27
28
 
28
29
  class AsyncTemplateManager:
29
- def __init__(self, api_key: str, base_url: str):
30
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
30
31
  self.api_key = api_key
31
32
  self.base_url = base_url
33
+ self.throw_on_error = throw_on_error
32
34
 
33
35
  async def get(self, prompt_name: str, params: Union[GetPromptTemplate, None] = None):
34
- return await aget_prompt_template(self.api_key, self.base_url, prompt_name, params)
36
+ return await aget_prompt_template(self.api_key, self.base_url, self.throw_on_error, prompt_name, params)
35
37
 
36
38
  async def all(self, page: int = 1, per_page: int = 30, label: str = None):
37
- return await aget_all_prompt_templates(self.api_key, self.base_url, page, per_page, label)
39
+ return await aget_all_prompt_templates(self.api_key, self.base_url, self.throw_on_error, page, per_page, label)
@@ -0,0 +1,71 @@
1
+ from promptlayer.track.track import (
2
+ agroup,
3
+ ametadata,
4
+ aprompt,
5
+ ascore,
6
+ group,
7
+ metadata as metadata_,
8
+ prompt,
9
+ score as score_,
10
+ )
11
+
12
+ # TODO(dmu) LOW: Move this code to another file
13
+
14
+
15
+ class TrackManager:
16
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
17
+ self.api_key = api_key
18
+ self.base_url = base_url
19
+ self.throw_on_error = throw_on_error
20
+
21
+ def group(self, request_id, group_id):
22
+ return group(self.api_key, self.base_url, self.throw_on_error, request_id, group_id)
23
+
24
+ def metadata(self, request_id, metadata):
25
+ return metadata_(self.api_key, self.base_url, self.throw_on_error, request_id, metadata)
26
+
27
+ def prompt(self, request_id, prompt_name, prompt_input_variables, version=None, label=None):
28
+ return prompt(
29
+ self.api_key,
30
+ self.base_url,
31
+ self.throw_on_error,
32
+ request_id,
33
+ prompt_name,
34
+ prompt_input_variables,
35
+ version,
36
+ label,
37
+ )
38
+
39
+ def score(self, request_id, score, score_name=None):
40
+ return score_(self.api_key, self.base_url, self.throw_on_error, request_id, score, score_name)
41
+
42
+
43
+ class AsyncTrackManager:
44
+ def __init__(self, api_key: str, base_url: str, throw_on_error: bool):
45
+ self.api_key = api_key
46
+ self.base_url = base_url
47
+ self.throw_on_error = throw_on_error
48
+
49
+ async def group(self, request_id, group_id):
50
+ return await agroup(self.api_key, self.base_url, self.throw_on_error, request_id, group_id)
51
+
52
+ async def metadata(self, request_id, metadata):
53
+ return await ametadata(self.api_key, self.base_url, self.throw_on_error, request_id, metadata)
54
+
55
+ async def prompt(self, request_id, prompt_name, prompt_input_variables, version=None, label=None):
56
+ return await aprompt(
57
+ self.api_key,
58
+ self.base_url,
59
+ self.throw_on_error,
60
+ request_id,
61
+ prompt_name,
62
+ prompt_input_variables,
63
+ version,
64
+ label,
65
+ )
66
+
67
+ async def score(self, request_id, score, score_name=None):
68
+ return await ascore(self.api_key, self.base_url, self.throw_on_error, request_id, score, score_name)
69
+
70
+
71
+ __all__ = ["TrackManager"]