promptlayer 1.0.16__py3-none-any.whl → 1.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,66 +1,64 @@
  import asyncio
- import datetime
+ import json
+ import logging
  import os
- from copy import deepcopy
- from functools import wraps
- from typing import Any, Dict, List, Literal, Union
+ from typing import Any, Dict, List, Literal, Optional, Union

- from opentelemetry import trace
- from opentelemetry.sdk.resources import Resource
- from opentelemetry.sdk.trace import TracerProvider
- from opentelemetry.sdk.trace.export import BatchSpanProcessor
- from opentelemetry.semconv.resource import ResourceAttributes
+ import nest_asyncio

- from promptlayer.groups import GroupManager
+ from promptlayer import exceptions as _exceptions
+ from promptlayer.groups import AsyncGroupManager, GroupManager
  from promptlayer.promptlayer_base import PromptLayerBase
- from promptlayer.span_exporter import PromptLayerSpanExporter
- from promptlayer.templates import TemplateManager
- from promptlayer.track import TrackManager
+ from promptlayer.promptlayer_mixins import PromptLayerMixin
+ from promptlayer.streaming import astream_response, stream_response
+ from promptlayer.templates import AsyncTemplateManager, TemplateManager
+ from promptlayer.track import AsyncTrackManager, TrackManager
+ from promptlayer.types.prompt_template import PromptTemplate
  from promptlayer.utils import (
-     anthropic_request,
-     anthropic_stream_completion,
-     anthropic_stream_message,
-     openai_request,
-     openai_stream_chat,
-     openai_stream_completion,
-     stream_response,
+     RERAISE_ORIGINAL_EXCEPTION,
+     _get_workflow_workflow_id_or_name,
+     arun_workflow_request,
+     atrack_request,
+     autil_log_request,
      track_request,
+     util_log_request,
  )

- MAP_PROVIDER_TO_FUNCTION_NAME = {
-     "openai": {
-         "chat": {
-             "function_name": "openai.chat.completions.create",
-             "stream_function": openai_stream_chat,
-         },
-         "completion": {
-             "function_name": "openai.completions.create",
-             "stream_function": openai_stream_completion,
-         },
-     },
-     "anthropic": {
-         "chat": {
-             "function_name": "anthropic.messages.create",
-             "stream_function": anthropic_stream_message,
-         },
-         "completion": {
-             "function_name": "anthropic.completions.create",
-             "stream_function": anthropic_stream_completion,
-         },
-     },
- }
-
- MAP_PROVIDER_TO_FUNCTION = {
-     "openai": openai_request,
-     "anthropic": anthropic_request,
- }
-
-
- class PromptLayer:
+ logger = logging.getLogger(__name__)
+
+
+ def get_base_url(base_url: Union[str, None]):
+     return base_url or os.environ.get("PROMPTLAYER_BASE_URL", "https://api.promptlayer.com")
+
+
+ def is_workflow_results_dict(obj: Any) -> bool:
+     if not isinstance(obj, dict):
+         return False
+
+     required_keys = {
+         "status",
+         "value",
+         "error_message",
+         "raw_error_message",
+         "is_output_node",
+     }
+
+     for val in obj.values():
+         if not isinstance(val, dict):
+             return False
+         if not required_keys.issubset(val.keys()):
+             return False
+
+     return True
+
+
+ class PromptLayer(PromptLayerMixin):
      def __init__(
          self,
-         api_key: str = None,
+         api_key: Union[str, None] = None,
          enable_tracing: bool = False,
+         base_url: Union[str, None] = None,
+         throw_on_error: bool = True,
      ):
          if api_key is None:
              api_key = os.environ.get("PROMPTLAYER_API_KEY")
@@ -71,11 +69,15 @@ class PromptLayer:
                  "Please set the PROMPTLAYER_API_KEY environment variable or pass the api_key parameter."
              )

+         self.base_url = get_base_url(base_url)
          self.api_key = api_key
-         self.templates = TemplateManager(api_key)
-         self.group = GroupManager(api_key)
-         self.tracer = self._initialize_tracer(api_key, enable_tracing)
-         self.track = TrackManager(api_key)
+         self.throw_on_error = throw_on_error
+         self.templates = TemplateManager(api_key, self.base_url, self.throw_on_error)
+         self.group = GroupManager(api_key, self.base_url, self.throw_on_error)
+         self.tracer_provider, self.tracer = self._initialize_tracer(
+             api_key, self.base_url, self.throw_on_error, enable_tracing
+         )
+         self.track = TrackManager(api_key, self.base_url, self.throw_on_error)

      def __getattr__(
          self,
@@ -84,24 +86,20 @@ class PromptLayer:
          if name == "openai":
              import openai as openai_module

-             openai = PromptLayerBase(
-                 openai_module,
-                 function_name="openai",
-                 api_key=self.api_key,
-                 tracer=self.tracer,
+             return PromptLayerBase(
+                 self.api_key, self.base_url, openai_module, function_name="openai", tracer=self.tracer
              )
-             return openai
          elif name == "anthropic":
              import anthropic as anthropic_module

-             anthropic = PromptLayerBase(
+             return PromptLayerBase(
+                 self.api_key,
+                 self.base_url,
                  anthropic_module,
                  function_name="anthropic",
                  provider_type="anthropic",
-                 api_key=self.api_key,
                  tracer=self.tracer,
              )
-             return anthropic
          else:
              raise AttributeError(f"module {__name__} has no attribute {name}")

@@ -112,105 +110,24 @@ class PromptLayer:
          tags,
          input_variables,
          group_id,
-         pl_run_span_id: str | None = None,
+         pl_run_span_id: Union[str, None] = None,
+         request_start_time: Union[float, None] = None,
      ):
          def _track_request(**body):
              track_request_kwargs = self._prepare_track_request_kwargs(
-                 request_params, tags, input_variables, group_id, pl_run_span_id, **body
+                 self.api_key,
+                 request_params,
+                 tags,
+                 input_variables,
+                 group_id,
+                 pl_run_span_id,
+                 request_start_time=request_start_time,
+                 **body,
              )
-             return track_request(**track_request_kwargs)
+             return track_request(self.base_url, self.throw_on_error, **track_request_kwargs)

          return _track_request

-     @staticmethod
-     def _initialize_tracer(api_key: str = None, enable_tracing: bool = False):
-         if enable_tracing:
-             resource = Resource(
-                 attributes={ResourceAttributes.SERVICE_NAME: "prompt-layer-library"}
-             )
-             tracer_provider = TracerProvider(resource=resource)
-             promptlayer_exporter = PromptLayerSpanExporter(api_key=api_key)
-             span_processor = BatchSpanProcessor(promptlayer_exporter)
-             tracer_provider.add_span_processor(span_processor)
-             trace.set_tracer_provider(tracer_provider)
-             return trace.get_tracer(__name__)
-         else:
-             return None
-
-     @staticmethod
-     def _prepare_get_prompt_template_params(
-         *, prompt_version, prompt_release_label, input_variables, metadata
-     ):
-         params = {}
-
-         if prompt_version:
-             params["version"] = prompt_version
-         if prompt_release_label:
-             params["label"] = prompt_release_label
-         if input_variables:
-             params["input_variables"] = input_variables
-         if metadata:
-             params["metadata_filters"] = metadata
-
-         return params
-
-     @staticmethod
-     def _prepare_llm_request_params(
-         *, prompt_blueprint, prompt_template, prompt_blueprint_model, stream
-     ):
-         provider = prompt_blueprint_model["provider"]
-         kwargs = deepcopy(prompt_blueprint["llm_kwargs"])
-         config = MAP_PROVIDER_TO_FUNCTION_NAME[provider][prompt_template["type"]]
-
-         if provider_base_url := prompt_blueprint.get("provider_base_url"):
-             kwargs["base_url"] = provider_base_url["url"]
-
-         kwargs["stream"] = stream
-         if stream and provider == "openai":
-             kwargs["stream_options"] = {"include_usage": True}
-
-         return {
-             "provider": provider,
-             "function_name": config["function_name"],
-             "stream_function": config["stream_function"],
-             "request_function": MAP_PROVIDER_TO_FUNCTION[provider],
-             "kwargs": kwargs,
-             "prompt_blueprint": prompt_blueprint,
-         }
-
-     def _prepare_track_request_kwargs(
-         self,
-         request_params,
-         tags,
-         input_variables,
-         group_id,
-         pl_run_span_id: str | None = None,
-         metadata: Dict[str, str] | None = None,
-         **body,
-     ):
-         return {
-             "function_name": request_params["function_name"],
-             "provider_type": request_params["provider"],
-             "args": [],
-             "kwargs": request_params["kwargs"],
-             "tags": tags,
-             "request_start_time": datetime.datetime.now(
-                 datetime.timezone.utc
-             ).timestamp(),
-             "request_end_time": datetime.datetime.now(
-                 datetime.timezone.utc
-             ).timestamp(),
-             "api_key": self.api_key,
-             "metadata": metadata,
-             "prompt_id": request_params["prompt_blueprint"]["id"],
-             "prompt_version": request_params["prompt_blueprint"]["version"],
-             "prompt_input_variables": input_variables,
-             "group_id": group_id,
-             "return_prompt_blueprint": True,
-             "span_id": pl_run_span_id,
-             **body,
-         }
-
      def _run_internal(
          self,
          *,
@@ -218,54 +135,88 @@ class PromptLayer:
          prompt_version: Union[int, None] = None,
          prompt_release_label: Union[str, None] = None,
          input_variables: Union[Dict[str, Any], None] = None,
+         model_parameter_overrides: Union[Dict[str, Any], None] = None,
          tags: Union[List[str], None] = None,
          metadata: Union[Dict[str, str], None] = None,
          group_id: Union[int, None] = None,
          stream: bool = False,
-         pl_run_span_id: str | None = None,
+         pl_run_span_id: Union[str, None] = None,
+         provider: Union[str, None] = None,
+         model: Union[str, None] = None,
      ) -> Dict[str, Any]:
+         import datetime
+
          get_prompt_template_params = self._prepare_get_prompt_template_params(
              prompt_version=prompt_version,
              prompt_release_label=prompt_release_label,
              input_variables=input_variables,
              metadata=metadata,
+             provider=provider,
+             model=model,
+             model_parameter_overrides=model_parameter_overrides,
          )
          prompt_blueprint = self.templates.get(prompt_name, get_prompt_template_params)
+         if not prompt_blueprint:
+             raise _exceptions.PromptLayerNotFoundError(
+                 f"Prompt template '{prompt_name}' not found.",
+                 response=None,
+                 body=None,
+             )
          prompt_blueprint_model = self._validate_and_extract_model_from_prompt_blueprint(
              prompt_blueprint=prompt_blueprint, prompt_name=prompt_name
          )
-         llm_request_params = self._prepare_llm_request_params(
+         llm_data = self._prepare_llm_data(
              prompt_blueprint=prompt_blueprint,
              prompt_template=prompt_blueprint["prompt_template"],
              prompt_blueprint_model=prompt_blueprint_model,
              stream=stream,
          )

-         response = llm_request_params["request_function"](
-             llm_request_params["prompt_blueprint"], **llm_request_params["kwargs"]
+         # Capture start time before making the LLM request
+         request_start_time = datetime.datetime.now(datetime.timezone.utc).timestamp()
+
+         # response is just whatever the LLM call returns
+         # streaming=False > Pydantic model instance
+         # streaming=True > generator that yields ChatCompletionChunk pieces as they arrive
+         response = llm_data["request_function"](
+             prompt_blueprint=llm_data["prompt_blueprint"],
+             client_kwargs=llm_data["client_kwargs"],
+             function_kwargs=llm_data["function_kwargs"],
          )

+         # Capture end time after the LLM request completes
+         request_end_time = datetime.datetime.now(datetime.timezone.utc).timestamp()
+
          if stream:
              return stream_response(
-                 response,
-                 self._create_track_request_callable(
-                     request_params=llm_request_params,
+                 generator=response,
+                 after_stream=self._create_track_request_callable(
+                     request_params=llm_data,
                      tags=tags,
                      input_variables=input_variables,
                      group_id=group_id,
                      pl_run_span_id=pl_run_span_id,
+                     request_start_time=request_start_time,
                  ),
-                 llm_request_params["stream_function"],
+                 map_results=llm_data["stream_function"],
+                 metadata=llm_data["prompt_blueprint"]["metadata"],
              )

+         if isinstance(response, dict):
+             request_response = response
+         else:
+             request_response = response.model_dump(mode="json")
+
          request_log = self._track_request_log(
-             llm_request_params,
+             llm_data,
              tags,
              input_variables,
              group_id,
              pl_run_span_id,
              metadata=metadata,
-             request_response=response.model_dump(),
+             request_response=request_response,
+             request_start_time=request_start_time,
+             request_end_time=request_end_time,
          )

          return {
@@ -280,11 +231,12 @@ class PromptLayer:
          tags,
          input_variables,
          group_id,
-         pl_run_span_id: str | None = None,
-         metadata: Dict[str, str] | None = None,
+         pl_run_span_id: Union[str, None] = None,
+         metadata: Union[Dict[str, str], None] = None,
          **body,
      ):
          track_request_kwargs = self._prepare_track_request_kwargs(
+             self.api_key,
              request_params,
              tags,
              input_variables,
@@ -293,53 +245,268 @@ class PromptLayer:
              metadata=metadata,
              **body,
          )
-         return track_request(**track_request_kwargs)
+         return track_request(self.base_url, self.throw_on_error, **track_request_kwargs)

-     @staticmethod
-     def _validate_and_extract_model_from_prompt_blueprint(
-         *, prompt_blueprint, prompt_name
-     ):
-         if not prompt_blueprint["llm_kwargs"]:
-             raise ValueError(
-                 f"Prompt '{prompt_name}' does not have any LLM kwargs associated with it."
+     def run(
+         self,
+         prompt_name: str,
+         prompt_version: Union[int, None] = None,
+         prompt_release_label: Union[str, None] = None,
+         input_variables: Union[Dict[str, Any], None] = None,
+         model_parameter_overrides: Union[Dict[str, Any], None] = None,
+         tags: Union[List[str], None] = None,
+         metadata: Union[Dict[str, str], None] = None,
+         group_id: Union[int, None] = None,
+         stream: bool = False,
+         provider: Union[str, None] = None,
+         model: Union[str, None] = None,
+     ) -> Dict[str, Any]:
+         _run_internal_kwargs = {
+             "prompt_name": prompt_name,
+             "prompt_version": prompt_version,
+             "prompt_release_label": prompt_release_label,
+             "input_variables": input_variables or {},
+             "model_parameter_overrides": model_parameter_overrides,
+             "tags": tags,
+             "metadata": metadata,
+             "group_id": group_id,
+             "stream": stream,
+             "provider": provider,
+             "model": model,
+         }
+
+         if self.tracer:
+             with self.tracer.start_as_current_span("PromptLayer Run") as span:
+                 span.set_attribute("prompt_name", prompt_name)
+                 span.set_attribute("function_input", str(_run_internal_kwargs))
+                 pl_run_span_id = hex(span.context.span_id)[2:].zfill(16)
+                 result = self._run_internal(**_run_internal_kwargs, pl_run_span_id=pl_run_span_id)
+                 span.set_attribute("function_output", str(result))
+                 return result
+         else:
+             return self._run_internal(**_run_internal_kwargs)
+
+     def run_workflow(
+         self,
+         workflow_id_or_name: Optional[Union[int, str]] = None,
+         input_variables: Optional[Dict[str, Any]] = None,
+         metadata: Optional[Dict[str, str]] = None,
+         workflow_label_name: Optional[str] = None,
+         workflow_version: Optional[int] = None,
+         return_all_outputs: Optional[bool] = False,
+         # `workflow_name` deprecated, kept for backward compatibility only.
+         # Allows `workflow_name` to be passed both as keyword and positional argument
+         # (virtually identical to `workflow_id_or_name`)
+         workflow_name: Optional[str] = None,
+     ) -> Union[Dict[str, Any], Any]:
+         try:
+             try:
+                 loop = asyncio.get_running_loop()  # Check if we're inside a running event loop
+             except RuntimeError:
+                 loop = None
+
+             if loop and loop.is_running():
+                 nest_asyncio.apply()
+
+             results = asyncio.run(
+                 arun_workflow_request(
+                     api_key=self.api_key,
+                     base_url=self.base_url,
+                     throw_on_error=self.throw_on_error,
+                     workflow_id_or_name=_get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name),
+                     input_variables=input_variables or {},
+                     metadata=metadata,
+                     workflow_label_name=workflow_label_name,
+                     workflow_version_number=workflow_version,
+                     return_all_outputs=return_all_outputs,
+                 )
              )

-         prompt_blueprint_metadata = prompt_blueprint.get("metadata")
+             if not return_all_outputs and is_workflow_results_dict(results):
+                 output_nodes = [node_data for node_data in results.values() if node_data.get("is_output_node")]
+                 if not output_nodes:
+                     raise _exceptions.PromptLayerNotFoundError(
+                         f"Output nodes not found: {json.dumps(results, indent=4)}", response=None, body=results
+                     )
+
+                 if not any(node.get("status") == "SUCCESS" for node in output_nodes):
+                     raise _exceptions.PromptLayerAPIError(
+                         f"None of the output nodes have succeeded: {json.dumps(results, indent=4)}",
+                         response=None,
+                         body=results,
+                     )
+
+             return results
+         except Exception as ex:
+             logger.exception("Error running workflow")
+             if RERAISE_ORIGINAL_EXCEPTION:
+                 raise
+             else:
+                 raise _exceptions.PromptLayerAPIError(
+                     f"Error running workflow: {str(ex)}", response=None, body=None
+                 ) from ex
+
+     def log_request(
+         self,
+         *,
+         provider: str,
+         model: str,
+         input: PromptTemplate,
+         output: PromptTemplate,
+         request_start_time: float,
+         request_end_time: float,
+         # TODO(dmu) MEDIUM: Avoid using mutable defaults
+         # TODO(dmu) MEDIUM: Deprecate and remove this wrapper function?
+         parameters: Dict[str, Any] = {},
+         tags: List[str] = [],
+         metadata: Dict[str, str] = {},
+         prompt_name: Union[str, None] = None,
+         prompt_version_number: Union[int, None] = None,
+         prompt_input_variables: Dict[str, Any] = {},
+         input_tokens: int = 0,
+         output_tokens: int = 0,
+         price: float = 0.0,
+         function_name: str = "",
+         score: int = 0,
+         prompt_id: Union[int, None] = None,
+         score_name: Union[str, None] = None,
+     ):
+         return util_log_request(
+             self.api_key,
+             self.base_url,
+             throw_on_error=self.throw_on_error,
+             provider=provider,
+             model=model,
+             input=input,
+             output=output,
+             request_start_time=request_start_time,
+             request_end_time=request_end_time,
+             parameters=parameters,
+             tags=tags,
+             metadata=metadata,
+             prompt_name=prompt_name,
+             prompt_version_number=prompt_version_number,
+             prompt_input_variables=prompt_input_variables,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             price=price,
+             function_name=function_name,
+             score=score,
+             prompt_id=prompt_id,
+             score_name=score_name,
+         )
+
+
+ class AsyncPromptLayer(PromptLayerMixin):
+     def __init__(
+         self,
+         api_key: Union[str, None] = None,
+         enable_tracing: bool = False,
+         base_url: Union[str, None] = None,
+         throw_on_error: bool = True,
+     ):
+         if api_key is None:
+             api_key = os.environ.get("PROMPTLAYER_API_KEY")

-         if not prompt_blueprint_metadata:
+         if api_key is None:
              raise ValueError(
-                 f"Prompt '{prompt_name}' does not have any metadata associated with it."
+                 "PromptLayer API key not provided. "
+                 "Please set the PROMPTLAYER_API_KEY environment variable or pass the api_key parameter."
              )

-         prompt_blueprint_model = prompt_blueprint_metadata.get("model")
+         self.base_url = get_base_url(base_url)
+         self.api_key = api_key
+         self.throw_on_error = throw_on_error
+         self.templates = AsyncTemplateManager(api_key, self.base_url, self.throw_on_error)
+         self.group = AsyncGroupManager(api_key, self.base_url, self.throw_on_error)
+         self.tracer_provider, self.tracer = self._initialize_tracer(
+             api_key, self.base_url, self.throw_on_error, enable_tracing
+         )
+         self.track = AsyncTrackManager(api_key, self.base_url, self.throw_on_error)

-         if not prompt_blueprint_model:
-             raise ValueError(
-                 f"Prompt '{prompt_name}' does not have a model parameters associated with it."
+     def __getattr__(self, name: Union[Literal["openai"], Literal["anthropic"], Literal["prompts"]]):
+         if name == "openai":
+             import openai as openai_module
+
+             openai = PromptLayerBase(
+                 self.api_key, self.base_url, openai_module, function_name="openai", tracer=self.tracer
              )
+             return openai
+         elif name == "anthropic":
+             import anthropic as anthropic_module

-         return prompt_blueprint_model
+             anthropic = PromptLayerBase(
+                 self.api_key,
+                 self.base_url,
+                 anthropic_module,
+                 function_name="anthropic",
+                 provider_type="anthropic",
+                 tracer=self.tracer,
+             )
+             return anthropic
+         else:
+             raise AttributeError(f"module {__name__} has no attribute {name}")

-     def run(
+     async def run_workflow(
+         self,
+         workflow_id_or_name: Optional[Union[int, str]] = None,
+         input_variables: Optional[Dict[str, Any]] = None,
+         metadata: Optional[Dict[str, str]] = None,
+         workflow_label_name: Optional[str] = None,
+         workflow_version: Optional[int] = None,  # This is the version number, not the version ID
+         return_all_outputs: Optional[bool] = False,
+         # `workflow_name` deprecated, kept for backward compatibility only.
+         # Allows `workflow_name` to be passed both as keyword and positional argument
+         # (virtually identical to `workflow_id_or_name`)
+         workflow_name: Optional[str] = None,
+     ) -> Union[Dict[str, Any], Any]:
+         try:
+             return await arun_workflow_request(
+                 api_key=self.api_key,
+                 base_url=self.base_url,
+                 throw_on_error=self.throw_on_error,
+                 workflow_id_or_name=_get_workflow_workflow_id_or_name(workflow_id_or_name, workflow_name),
+                 input_variables=input_variables or {},
+                 metadata=metadata,
+                 workflow_label_name=workflow_label_name,
+                 workflow_version_number=workflow_version,
+                 return_all_outputs=return_all_outputs,
+             )
+         except Exception as ex:
+             logger.exception("Error running workflow")
+             if RERAISE_ORIGINAL_EXCEPTION:
+                 raise
+             else:
+                 raise _exceptions.PromptLayerAPIError(
+                     f"Error running workflow: {str(ex)}", response=None, body=None
+                 ) from ex
+
+     async def run(
          self,
          prompt_name: str,
          prompt_version: Union[int, None] = None,
          prompt_release_label: Union[str, None] = None,
          input_variables: Union[Dict[str, Any], None] = None,
+         model_parameter_overrides: Union[Dict[str, Any], None] = None,
          tags: Union[List[str], None] = None,
          metadata: Union[Dict[str, str], None] = None,
          group_id: Union[int, None] = None,
          stream: bool = False,
+         provider: Union[str, None] = None,
+         model: Union[str, None] = None,
      ) -> Dict[str, Any]:
          _run_internal_kwargs = {
              "prompt_name": prompt_name,
              "prompt_version": prompt_version,
              "prompt_release_label": prompt_release_label,
              "input_variables": input_variables,
+             "model_parameter_overrides": model_parameter_overrides,
              "tags": tags,
              "metadata": metadata,
              "group_id": group_id,
              "stream": stream,
+             "provider": provider,
+             "model": model,
          }

          if self.tracer:
@@ -347,52 +514,197 @@ class PromptLayer:
                  span.set_attribute("prompt_name", prompt_name)
                  span.set_attribute("function_input", str(_run_internal_kwargs))
                  pl_run_span_id = hex(span.context.span_id)[2:].zfill(16)
-                 result = self._run_internal(
-                     **_run_internal_kwargs, pl_run_span_id=pl_run_span_id
-                 )
+                 result = await self._run_internal(**_run_internal_kwargs, pl_run_span_id=pl_run_span_id)
                  span.set_attribute("function_output", str(result))
                  return result
          else:
-             return self._run_internal(**_run_internal_kwargs)
+             return await self._run_internal(**_run_internal_kwargs)
+
+     async def log_request(
+         self,
+         *,
+         provider: str,
+         model: str,
+         input: PromptTemplate,
+         output: PromptTemplate,
+         request_start_time: float,
+         request_end_time: float,
+         parameters: Dict[str, Any] = {},
+         tags: List[str] = [],
+         metadata: Dict[str, str] = {},
+         prompt_name: Union[str, None] = None,
+         prompt_version_number: Union[int, None] = None,
+         prompt_input_variables: Dict[str, Any] = {},
+         input_tokens: int = 0,
+         output_tokens: int = 0,
+         price: float = 0.0,
+         function_name: str = "",
+         score: int = 0,
+         prompt_id: Union[int, None] = None,
+     ):
+         return await autil_log_request(
+             self.api_key,
+             self.base_url,
+             throw_on_error=self.throw_on_error,
+             provider=provider,
+             model=model,
+             input=input,
+             output=output,
+             request_start_time=request_start_time,
+             request_end_time=request_end_time,
+             parameters=parameters,
+             tags=tags,
+             metadata=metadata,
+             prompt_name=prompt_name,
+             prompt_version_number=prompt_version_number,
+             prompt_input_variables=prompt_input_variables,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             price=price,
+             function_name=function_name,
+             score=score,
+             prompt_id=prompt_id,
+         )

-     def traceable(self, attributes=None):
-         def decorator(func):
-             @wraps(func)
-             def sync_wrapper(*args, **kwargs):
-                 if self.tracer:
-                     with self.tracer.start_as_current_span(func.__name__) as span:
-                         if attributes:
-                             for key, value in attributes.items():
-                                 span.set_attribute(key, value)
-
-                         span.set_attribute(
-                             "function_input", str({"args": args, "kwargs": kwargs})
-                         )
-                         result = func(*args, **kwargs)
-                         span.set_attribute("function_output", str(result))
-
-                         return result
-                 else:
-                     return func(*args, **kwargs)
-
-             @wraps(func)
-             async def async_wrapper(*args, **kwargs):
-                 if self.tracer:
-                     with self.tracer.start_as_current_span(func.__name__) as span:
-                         if attributes:
-                             for key, value in attributes.items():
-                                 span.set_attribute(key, value)
-
-                         span.set_attribute(
-                             "function_input", str({"args": args, "kwargs": kwargs})
-                         )
-                         result = await func(*args, **kwargs)
-                         span.set_attribute("function_output", str(result))
-
-                         return result
-                 else:
-                     return await func(*args, **kwargs)
-
-             return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
-
-         return decorator
+     async def _create_track_request_callable(
+         self,
+         *,
+         request_params,
+         tags,
+         input_variables,
+         group_id,
+         pl_run_span_id: Union[str, None] = None,
+         request_start_time: Union[float, None] = None,
+     ):
+         async def _track_request(**body):
+             track_request_kwargs = self._prepare_track_request_kwargs(
+                 self.api_key,
+                 request_params,
+                 tags,
+                 input_variables,
+                 group_id,
+                 pl_run_span_id,
+                 request_start_time=request_start_time,
+                 **body,
+             )
+             return await atrack_request(self.base_url, self.throw_on_error, **track_request_kwargs)
+
+         return _track_request
+
+     async def _track_request_log(
+         self,
+         request_params,
+         tags,
+         input_variables,
+         group_id,
+         pl_run_span_id: Union[str, None] = None,
+         metadata: Union[Dict[str, str], None] = None,
+         **body,
+     ):
+         track_request_kwargs = self._prepare_track_request_kwargs(
+             self.api_key,
+             request_params,
+             tags,
+             input_variables,
+             group_id,
+             pl_run_span_id,
+             metadata=metadata,
+             **body,
+         )
+         return await atrack_request(self.base_url, self.throw_on_error, **track_request_kwargs)
+
+     async def _run_internal(
+         self,
+         *,
+         prompt_name: str,
+         prompt_version: Union[int, None] = None,
+         prompt_release_label: Union[str, None] = None,
+         input_variables: Union[Dict[str, Any], None] = None,
+         model_parameter_overrides: Union[Dict[str, Any], None] = None,
+         tags: Union[List[str], None] = None,
+         metadata: Union[Dict[str, str], None] = None,
+         group_id: Union[int, None] = None,
+         stream: bool = False,
+         pl_run_span_id: Union[str, None] = None,
+         provider: Union[str, None] = None,
+         model: Union[str, None] = None,
+     ) -> Dict[str, Any]:
+         import datetime
+
+         get_prompt_template_params = self._prepare_get_prompt_template_params(
+             prompt_version=prompt_version,
+             prompt_release_label=prompt_release_label,
+             input_variables=input_variables,
+             metadata=metadata,
+             provider=provider,
+             model=model,
+             model_parameter_overrides=model_parameter_overrides,
+         )
+         prompt_blueprint = await self.templates.get(prompt_name, get_prompt_template_params)
+         if not prompt_blueprint:
+             raise _exceptions.PromptLayerNotFoundError(
+                 f"Prompt template '{prompt_name}' not found.",
+                 response=None,
+                 body=None,
+             )
+         prompt_blueprint_model = self._validate_and_extract_model_from_prompt_blueprint(
+             prompt_blueprint=prompt_blueprint, prompt_name=prompt_name
+         )
+         llm_data = self._prepare_llm_data(
+             prompt_blueprint=prompt_blueprint,
+             prompt_template=prompt_blueprint["prompt_template"],
+             prompt_blueprint_model=prompt_blueprint_model,
+             stream=stream,
+             is_async=True,
+         )
+
+         # Capture start time before making the LLM request
+         request_start_time = datetime.datetime.now(datetime.timezone.utc).timestamp()
+
+         response = await llm_data["request_function"](
+             prompt_blueprint=llm_data["prompt_blueprint"],
+             client_kwargs=llm_data["client_kwargs"],
+             function_kwargs=llm_data["function_kwargs"],
+         )
+
+         # Capture end time after the LLM request completes
+         request_end_time = datetime.datetime.now(datetime.timezone.utc).timestamp()
+
+         if hasattr(response, "model_dump"):
+             request_response = response.model_dump(mode="json")
+         else:
+             request_response = response
+
+         if stream:
+             track_request_callable = await self._create_track_request_callable(
+                 request_params=llm_data,
+                 tags=tags,
+                 input_variables=input_variables,
+                 group_id=group_id,
+                 pl_run_span_id=pl_run_span_id,
+                 request_start_time=request_start_time,
+             )
+             return astream_response(
+                 request_response,
+                 track_request_callable,
+                 llm_data["stream_function"],
+                 llm_data["prompt_blueprint"]["metadata"],
+             )
+
+         request_log = await self._track_request_log(
+             llm_data,
+             tags,
+             input_variables,
+             group_id,
+             pl_run_span_id,
+             metadata=metadata,
+             request_response=request_response,
+             request_start_time=request_start_time,
+             request_end_time=request_end_time,
+         )
+
+         return {
+             "request_id": request_log.get("request_id", None),
+             "raw_response": response,
+             "prompt_blueprint": request_log.get("prompt_blueprint", None),
+         }
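
Taken together, the new surface in this version range adds a configurable base_url, a throw_on_error flag, provider/model and model_parameter_overrides options on run(), workflow execution via run_workflow(), log_request() helpers, and a full AsyncPromptLayer client. The following is a minimal usage sketch based only on the signatures visible in this diff; the package-root import path, the API key value, and the "my-template" / "my-workflow" names are illustrative assumptions, not part of the diff.

import asyncio

from promptlayer import AsyncPromptLayer, PromptLayer  # import path assumed from the package layout

# base_url and throw_on_error are new constructor options in this version range.
pl = PromptLayer(api_key="pl_...", base_url="https://api.promptlayer.com", throw_on_error=False)

# run() now accepts provider/model overrides; it returns request_id, raw_response, prompt_blueprint.
result = pl.run(
    prompt_name="my-template",  # placeholder prompt template name
    input_variables={"topic": "release notes"},
    provider="openai",
    model="gpt-4o-mini",  # placeholder model name
)
print(result["request_id"], result["raw_response"])

# run_workflow() executes a workflow by id or name (workflow_name is kept only for backward compatibility).
outputs = pl.run_workflow(workflow_id_or_name="my-workflow", input_variables={"question": "hi"})

# The async client mirrors the sync surface (templates, group, track, run, run_workflow, log_request).
async def main():
    apl = AsyncPromptLayer(api_key="pl_...")
    return await apl.run(prompt_name="my-template", input_variables={"topic": "release notes"})

asyncio.run(main())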