arize-phoenix 1.9.1rc3__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arize-phoenix has been flagged as potentially problematic.

@@ -0,0 +1,353 @@
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ import signal
+ import traceback
+ from typing import Any, Callable, Coroutine, List, Optional, Protocol, Sequence, Tuple, Union
+
+ from tqdm.auto import tqdm
+
+ from phoenix.exceptions import PhoenixException
+
+ logger = logging.getLogger(__name__)
+
+
+ class Unset:
+     pass
+
+
+ _unset = Unset()
+
+
+ class Executor(Protocol):
+     def run(self, inputs: Sequence[Any]) -> List[Any]:
+         ...
+
+
+ class AsyncExecutor(Executor):
+     """
+     A class that provides asynchronous execution of tasks using a producer-consumer pattern.
+
+     An async interface is provided by the `execute` method, which returns a coroutine, and a sync
+     interface is provided by the `run` method.
+
+     Args:
+         generation_fn (Callable[[Any], Coroutine[Any, Any, Any]]): A coroutine function that
+             generates tasks to be executed.
+
+         concurrency (int, optional): The number of concurrent consumers. Defaults to 3.
+
+         tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
+             to None.
+
+         max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
+             10.
+
+         exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
+             Defaults to True.
+
+         fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
+             that encounter errors. Defaults to _unset.
+
+         termination_signal (signal.Signals, optional): The signal that is handled to terminate the
+             executor. Defaults to signal.SIGINT.
+     """
+
+     def __init__(
+         self,
+         generation_fn: Callable[[Any], Coroutine[Any, Any, Any]],
+         concurrency: int = 3,
+         tqdm_bar_format: Optional[str] = None,
+         max_retries: int = 10,
+         exit_on_error: bool = True,
+         fallback_return_value: Union[Unset, Any] = _unset,
+         termination_signal: signal.Signals = signal.SIGINT,
+     ):
+         self.generate = generation_fn
+         self.fallback_return_value = fallback_return_value
+         self.concurrency = concurrency
+         self.tqdm_bar_format = tqdm_bar_format
+         self.max_retries = max_retries
+         self.exit_on_error = exit_on_error
+         self.base_priority = 0
+         self.termination_signal = termination_signal
+
+     async def producer(
+         self,
+         inputs: Sequence[Any],
+         queue: asyncio.PriorityQueue[Tuple[int, Any]],
+         max_fill: int,
+         done_producing: asyncio.Event,
+         termination_event: asyncio.Event,
+     ) -> None:
+         try:
+             for index, input in enumerate(inputs):
+                 if termination_event.is_set():
+                     break
+                 while queue.qsize() >= max_fill:
+                     # keep room in the queue for requeues
+                     await asyncio.sleep(1)
+                 await queue.put((self.base_priority, (index, input)))
+         finally:
+             done_producing.set()
+
+     async def consumer(
+         self,
+         output: List[Any],
+         queue: asyncio.PriorityQueue[Tuple[int, Any]],
+         done_producing: asyncio.Event,
+         termination_event: asyncio.Event,
+         progress_bar: tqdm[Any],
+     ) -> None:
+         termination_event_watcher = None
+         while True:
+             marked_done = False
+             try:
+                 priority, item = await asyncio.wait_for(queue.get(), timeout=1)
+             except asyncio.TimeoutError:
+                 if done_producing.is_set() and queue.empty():
+                     break
+                 continue
+             if termination_event.is_set():
+                 # discard any remaining items in the queue
+                 queue.task_done()
+                 marked_done = True
+                 continue
+
+             index, payload = item
+             try:
+                 generate_task = asyncio.create_task(self.generate(payload))
+                 termination_event_watcher = asyncio.create_task(termination_event.wait())
+                 done, pending = await asyncio.wait(
+                     [generate_task, termination_event_watcher],
+                     timeout=120,
+                     return_when=asyncio.FIRST_COMPLETED,
+                 )
+                 if generate_task in done:
+                     output[index] = generate_task.result()
+                     progress_bar.update()
+                 elif termination_event.is_set():
+                     # discard the pending task and remaining items in the queue
+                     if not generate_task.done():
+                         generate_task.cancel()
+                         try:
+                             # allow any cleanup to finish for the cancelled task
+                             await generate_task
+                         except asyncio.CancelledError:
+                             # handle the cancellation exception
+                             pass
+                     queue.task_done()
+                     marked_done = True
+                     continue
+                 else:
+                     tqdm.write("Worker timeout, requeuing")
+                     # task timeouts are requeued at base priority
+                     await queue.put((self.base_priority, item))
+             except Exception as exc:
+                 is_phoenix_exception = isinstance(exc, PhoenixException)
+                 if (retry_count := abs(priority)) <= self.max_retries and not is_phoenix_exception:
+                     tqdm.write(
+                         f"Exception in worker on attempt {retry_count + 1}: raised {repr(exc)}"
+                     )
+                     tqdm.write("Requeuing...")
+                     await queue.put((priority - 1, item))
+                 else:
+                     tqdm.write(f"Exception in worker: {traceback.format_exc()}")
+                     if self.exit_on_error:
+                         termination_event.set()
+                     else:
+                         progress_bar.update()
+             finally:
+                 if not marked_done:
+                     queue.task_done()
+         if termination_event_watcher and not termination_event_watcher.done():
+             termination_event_watcher.cancel()
+
+     async def execute(self, inputs: Sequence[Any]) -> List[Any]:
+         termination_event = asyncio.Event()
+
+         def termination_handler(signum: int, frame: Any) -> None:
+             termination_event.set()
+             tqdm.write("Process was interrupted. The return value will be incomplete...")
+
+         signal.signal(self.termination_signal, termination_handler)
+         outputs = [self.fallback_return_value] * len(inputs)
+         progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)
+
+         max_queue_size = 5 * self.concurrency  # limit the queue to bound memory usage
+         max_fill = max_queue_size - (2 * self.concurrency)  # ensure there is always room to requeue
+         queue: asyncio.PriorityQueue[Tuple[int, Any]] = asyncio.PriorityQueue(
+             maxsize=max_queue_size
+         )
+         done_producing = asyncio.Event()
+
+         producer = asyncio.create_task(
+             self.producer(inputs, queue, max_fill, done_producing, termination_event)
+         )
+         consumers = [
+             asyncio.create_task(
+                 self.consumer(outputs, queue, done_producing, termination_event, progress_bar)
+             )
+             for _ in range(self.concurrency)
+         ]
+
+         await asyncio.gather(producer, *consumers)
+         join_task = asyncio.create_task(queue.join())
+         termination_event_watcher = asyncio.create_task(termination_event.wait())
+         done, pending = await asyncio.wait(
+             [join_task, termination_event_watcher], return_when=asyncio.FIRST_COMPLETED
+         )
+         if termination_event_watcher in done:
+             # cancel all remaining tasks
+             if not join_task.done():
+                 join_task.cancel()
+             if not producer.done():
+                 producer.cancel()
+             for task in consumers:
+                 if not task.done():
+                     task.cancel()
+
+         if not termination_event_watcher.done():
+             termination_event_watcher.cancel()
+
+         # restore the default handler for the termination signal
+         signal.signal(self.termination_signal, signal.SIG_DFL)
+         return outputs
+
+     def run(self, inputs: Sequence[Any]) -> List[Any]:
+         return asyncio.run(self.execute(inputs))
+
+
+ class SyncExecutor(Executor):
+     """
+     Synchronous executor for generating outputs from inputs using a given generation function.
+
+     Args:
+         generation_fn (Callable[[Any], Any]): The generation function that takes an input and
+             returns an output.
+
+         tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
+             to None.
+
+         max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
+             10.
+
+         exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
+             Defaults to True.
+
+         fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
+             that encounter errors. Defaults to _unset.
+
+         termination_signal (signal.Signals, optional): The signal that is handled to terminate the
+             executor. Defaults to signal.SIGINT.
+     """
+
+     def __init__(
+         self,
+         generation_fn: Callable[[Any], Any],
+         tqdm_bar_format: Optional[str] = None,
+         max_retries: int = 10,
+         exit_on_error: bool = True,
+         fallback_return_value: Union[Unset, Any] = _unset,
+         termination_signal: signal.Signals = signal.SIGINT,
+     ):
+         self.generate = generation_fn
+         self.fallback_return_value = fallback_return_value
+         self.tqdm_bar_format = tqdm_bar_format
+         self.max_retries = max_retries
+         self.exit_on_error = exit_on_error
+         self.termination_signal = termination_signal
+
+         self._TERMINATE = False
+
+     def _signal_handler(self, signum: int, frame: Any) -> None:
+         tqdm.write("Process was interrupted. The return value will be incomplete...")
+         self._TERMINATE = True
+
+     def run(self, inputs: Sequence[Any]) -> List[Any]:
+         signal.signal(self.termination_signal, self._signal_handler)
+         outputs = [self.fallback_return_value] * len(inputs)
+         progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)
+
+         for index, input in enumerate(inputs):
+             try:
+                 for attempt in range(self.max_retries + 1):
+                     if self._TERMINATE:
+                         return outputs
+                     try:
+                         result = self.generate(input)
+                         outputs[index] = result
+                         progress_bar.update()
+                         break  # success: stop retrying this input
+                     except Exception as exc:
+                         is_phoenix_exception = isinstance(exc, PhoenixException)
+                         if attempt >= self.max_retries or is_phoenix_exception:
+                             raise exc
+                         else:
+                             tqdm.write(f"Exception in worker on attempt {attempt + 1}: {exc}")
+                             tqdm.write("Retrying...")
+             except Exception as exc:
+                 tqdm.write(f"Exception in worker: {exc}")
+                 if self.exit_on_error:
+                     return outputs
+                 else:
+                     progress_bar.update()
+         signal.signal(self.termination_signal, signal.SIG_DFL)  # restore the default handler
+         return outputs
+
+
+ def get_executor_on_sync_context(
+     sync_fn: Callable[[Any], Any],
+     async_fn: Callable[[Any], Coroutine[Any, Any, Any]],
+     run_sync: bool = False,
+     concurrency: int = 3,
+     tqdm_bar_format: Optional[str] = None,
+     exit_on_error: bool = True,
+     fallback_return_value: Union[Unset, Any] = _unset,
+ ) -> Executor:
+     if run_sync:
+         return SyncExecutor(
+             sync_fn,
+             tqdm_bar_format=tqdm_bar_format,
+             exit_on_error=exit_on_error,
+             fallback_return_value=fallback_return_value,
+         )
+
+     if _running_event_loop_exists():
+         if getattr(asyncio, "_nest_patched", False):
+             return AsyncExecutor(
+                 async_fn,
+                 concurrency=concurrency,
+                 tqdm_bar_format=tqdm_bar_format,
+                 exit_on_error=exit_on_error,
+                 fallback_return_value=fallback_return_value,
+             )
+         else:
+             logger.warning(
+                 "🐌!! If running llm_classify inside a notebook, patching the event loop with "
+                 "nest_asyncio will allow asynchronous eval submission, and is significantly "
+                 "faster. To patch the event loop, run `nest_asyncio.apply()`."
+             )
+             return SyncExecutor(
+                 sync_fn,
+                 tqdm_bar_format=tqdm_bar_format,
+                 exit_on_error=exit_on_error,
+                 fallback_return_value=fallback_return_value,
+             )
+     else:
+         return AsyncExecutor(
+             async_fn,
+             concurrency=concurrency,
+             tqdm_bar_format=tqdm_bar_format,
+             exit_on_error=exit_on_error,
+             fallback_return_value=fallback_return_value,
+         )
+
+
+ def _running_event_loop_exists() -> bool:
+     """Checks for a running event loop.
+
+     Returns:
+         bool: True if a running event loop exists, False otherwise.
+     """
+     try:
+         asyncio.get_running_loop()
+         return True
+     except RuntimeError:
+         return False
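To make the new executor concrete, here is a minimal usage sketch. Only `AsyncExecutor`, its constructor arguments, and the import path are taken from the code above; the task function and inputs are hypothetical:

```python
import asyncio
import random

from phoenix.experimental.evals.functions.executor import AsyncExecutor


async def flaky_double(x: int) -> int:
    # hypothetical task: fails ~30% of the time to exercise the retry path
    await asyncio.sleep(0.05)  # simulate I/O latency
    if random.random() < 0.3:
        raise RuntimeError("transient failure")
    return 2 * x


executor = AsyncExecutor(
    flaky_double,
    concurrency=5,
    max_retries=3,
    exit_on_error=False,         # keep going past inputs that exhaust their retries
    fallback_return_value=None,  # such inputs show up as None in the result
)
results = executor.run(list(range(10)))  # sync entry point; wraps asyncio.run()
print(results)  # one slot per input, in input order
```

Note how retries ride on the priority queue: a failed item is requeued with `priority - 1`, so the `PriorityQueue` (which serves the smallest priority first) prefers retries over fresh work, and `abs(priority)` doubles as the retry count. `get_executor_on_sync_context` then picks `AsyncExecutor` or `SyncExecutor` depending on whether a (nest_asyncio-patched) event loop is already running.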
@@ -1,9 +1,11 @@
  import logging
- from typing import Any, Callable, Dict, Optional, Union
+ from typing import Any, Callable, Dict, Optional, Tuple, Union

  import pandas as pd
- from tqdm.auto import tqdm

+ from phoenix.experimental.evals.functions.executor import (
+     get_executor_on_sync_context,
+ )
  from phoenix.experimental.evals.models import BaseEvalModel, set_verbosity
  from phoenix.experimental.evals.templates import (
      PromptTemplate,
@@ -15,7 +17,7 @@ from phoenix.experimental.evals.utils import get_tqdm_progress_bar_formatter
  logger = logging.getLogger(__name__)


- def _no_op_parser(response: str) -> Dict[str, str]:
+ def _no_op_parser(response: str, response_index: int) -> Dict[str, str]:
      return {"output": response}


@@ -25,7 +27,11 @@ def llm_generate(
      model: BaseEvalModel,
      system_instruction: Optional[str] = None,
      verbose: bool = False,
-     output_parser: Optional[Callable[[str], Dict[str, Any]]] = None,
+     output_parser: Optional[Callable[[str, int], Dict[str, Any]]] = None,
+     include_prompt: bool = False,
+     include_response: bool = False,
+     run_sync: bool = False,
+     concurrency: int = 20,
  ) -> pd.DataFrame:
      """
      Generates a text using a template using an LLM. This function is useful
@@ -49,10 +55,23 @@ def llm_generate(
          verbose (bool, optional): If True, prints detailed information to stdout such as model
              invocation parameters and retry info. Default False.

-         output_parser (Callable[[str], Dict[str, Any]], optional): An optional function
-             that takes each generated response and parses it to a dictionary. The keys of the dictionary
-             should correspond to the column names of the output dataframe. If None, the output dataframe
-             will have a single column named "output". Default None.
+         output_parser (Callable[[str, int], Dict[str, Any]], optional): An optional function
+             that takes each generated response and response index and parses it to a dictionary. The
+             keys of the dictionary should correspond to the column names of the output dataframe. If
+             None, the output dataframe will have a single column named "output". Default None.
+
+         include_prompt (bool, default=False): If True, includes a column named `prompt` in the
+             output dataframe containing the prompt used for each generation.
+
+         include_response (bool, default=False): If True, includes a column named `response` in the
+             output dataframe containing the raw response from the LLM prior to applying the output
+             parser.
+
+         run_sync (bool, default=False): If True, forces synchronous request submission. Otherwise
+             evaluations will be run asynchronously if possible.
+
+         concurrency (int, default=20): The number of concurrent evals if async submission is
+             possible.

      Returns:
          generations_dataframe (pandas.DataFrame): A dataframe where each row
@@ -61,28 +80,53 @@ def llm_generate(
      """
      tqdm_bar_format = get_tqdm_progress_bar_formatter("llm_generate")
      output_parser = output_parser or _no_op_parser
-     with set_verbosity(model, verbose) as verbose_model:
-         template = normalize_prompt_template(template)
-         logger.info(f"Template: \n{template.prompt()}\n")
-         logger.info(f"Template variables: {template.variables}")
-         prompts = map_template(dataframe, template)
-
-         # For each prompt, generate and parse the response
-         output = []
-         # Wrap the loop in a try / catch so that we can still return a dataframe
-         # even if the process is interrupted
-         try:
-             for prompt in tqdm(prompts, bar_format=tqdm_bar_format):
-                 logger.info(f"Prompt: {prompt}")
-                 response = verbose_model(prompt, instruction=system_instruction)
-                 parsed_response = output_parser(response)
-                 output.append(parsed_response)
-
-         except (Exception, KeyboardInterrupt) as e:
-             logger.error(e)
-             print(
-                 "Process was interrupted. The return value will be incomplete",
-                 e,
+     template = normalize_prompt_template(template)
+     logger.info(f"Template: \n{template.prompt()}\n")
+     logger.info(f"Template variables: {template.variables}")
+     prompts = map_template(dataframe, template)
+
+     async def _run_llm_generation_async(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]:
+         index, prompt = enumerated_prompt
+         with set_verbosity(model, verbose) as verbose_model:
+             response = await verbose_model._async_generate(
+                 prompt,
+                 instruction=system_instruction,
+             )
+         parsed_response = output_parser(response, index)
+         if include_prompt:
+             parsed_response["prompt"] = prompt
+         if include_response:
+             parsed_response["response"] = response
+         return parsed_response
+
+     def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]:
+         index, prompt = enumerated_prompt
+         with set_verbosity(model, verbose) as verbose_model:
+             response = verbose_model._generate(
+                 prompt,
+                 instruction=system_instruction,
              )
-         # Return the data as a dataframe
-         return pd.DataFrame(output)
+         parsed_response = output_parser(response, index)
+         if include_prompt:
+             parsed_response["prompt"] = prompt
+         if include_response:
+             parsed_response["response"] = response
+         return parsed_response
+
+     fallback_return_value = {
+         "output": "generation-failed",
+         **({"prompt": ""} if include_prompt else {}),
+         **({"response": ""} if include_response else {}),
+     }
+
+     executor = get_executor_on_sync_context(
+         _run_llm_generation_sync,
+         _run_llm_generation_async,
+         run_sync=run_sync,
+         concurrency=concurrency,
+         tqdm_bar_format=tqdm_bar_format,
+         exit_on_error=True,
+         fallback_return_value=fallback_return_value,
+     )
+     output = executor.run(list(enumerate(prompts.tolist())))
+     return pd.DataFrame(output)
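A usage sketch of the reworked `llm_generate` (the dataframe, template text, and `OpenAIModel` settings are illustrative assumptions; the keyword arguments and the two-argument parser signature come from the code above):

```python
import json
from typing import Any, Dict

import pandas as pd

from phoenix.experimental.evals import OpenAIModel, llm_generate


def parser(response: str, response_index: int) -> Dict[str, Any]:
    # output_parser now receives the row index as a second argument
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return {"output": response}


df = pd.DataFrame({"topic": ["sports", "cooking"]})  # hypothetical input data
generations = llm_generate(
    dataframe=df,
    template="Write a one-line joke about {topic}.",
    model=OpenAIModel(model_name="gpt-4"),
    output_parser=parser,
    include_prompt=True,    # adds a `prompt` column
    include_response=True,  # adds a raw `response` column
    concurrency=20,         # async submission unless run_sync=True
)
```

Because the executor is created with `exit_on_error=True`, a row that exhausts its retries terminates the run, and the remaining rows are filled with the `"generation-failed"` fallback record.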
@@ -6,6 +6,7 @@ from typing import Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar

  from typing_extensions import ParamSpec

+ from phoenix.exceptions import PhoenixException
  from phoenix.utilities.logging import printif

  ParameterSpec = ParamSpec("ParameterSpec")
@@ -13,7 +14,7 @@ GenericType = TypeVar("GenericType")
  AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]


- class UnavailableTokensError(Exception):
+ class UnavailableTokensError(PhoenixException):
      pass


@@ -133,7 +134,7 @@ class AdaptiveTokenBucket:
              continue


- class RateLimitError(BaseException):
+ class RateLimitError(PhoenixException):
      ...


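Re-parenting these errors onto `PhoenixException` wires them into the new executors: both `AsyncExecutor.consumer` and `SyncExecutor.run` check `isinstance(exc, PhoenixException)` and fail fast instead of retrying. It also fixes a subtlety: `RateLimitError` previously derived from `BaseException`, so it slipped past `except Exception` handlers entirely. A minimal sketch of a custom non-retryable error (the class name is hypothetical):

```python
from phoenix.exceptions import PhoenixException


class EvalConfigurationError(PhoenixException):
    """Hypothetical: an error that retrying cannot fix."""


def generate(row: object) -> str:
    # an executor running this task fails the row immediately
    # rather than burning through max_retries attempts
    raise EvalConfigurationError("missing template variable")
```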
@@ -162,9 +163,9 @@ class RateLimiter:
              rate_increase_factor=rate_increase_factor,
              cooldown_seconds=cooldown_seconds,
          )
-         self._rate_limit_handling = asyncio.Event()
-         self._rate_limit_handling.set()  # allow requests to start immediately
-         self._rate_limit_handling_lock = asyncio.Lock()
+         self._rate_limit_handling: Optional[asyncio.Event] = None
+         self._rate_limit_handling_lock: Optional[asyncio.Lock] = None
+         self._current_loop: Optional[asyncio.AbstractEventLoop] = None
          self._verbose = verbose

      def limit(
@@ -192,11 +193,30 @@ class RateLimiter:

          return wrapper

+     def _initialize_async_primitives(self) -> None:
+         """
+         Lazily initialize async primitives to ensure they are created in the correct event loop.
+         """
+
+         loop = asyncio.get_running_loop()
+         if loop is not self._current_loop:
+             self._current_loop = loop
+             self._rate_limit_handling = asyncio.Event()
+             self._rate_limit_handling.set()
+             self._rate_limit_handling_lock = asyncio.Lock()
+
      def alimit(
          self, fn: AsyncCallable[ParameterSpec, GenericType]
      ) -> AsyncCallable[ParameterSpec, GenericType]:
          @wraps(fn)
          async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+             self._initialize_async_primitives()
+             assert self._rate_limit_handling_lock is not None and isinstance(
+                 self._rate_limit_handling_lock, asyncio.Lock
+             )
+             assert self._rate_limit_handling is not None and isinstance(
+                 self._rate_limit_handling, asyncio.Event
+             )
              try:
                  try:
                      await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
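The lazy initialization matters because `asyncio.Event` and `asyncio.Lock` can end up bound to whichever event loop is current when they are constructed (notably on Python 3.9 and earlier). A `RateLimiter` created at import time and then used inside `asyncio.run(...)`, or reused across successive `asyncio.run(...)` calls, could therefore fail with a "bound to a different event loop" error. A standalone sketch of the same pattern (names are illustrative, not part of the package):

```python
import asyncio
from typing import Optional


class LoopBoundGate:
    """Illustrative: recreate asyncio primitives per running event loop."""

    def __init__(self) -> None:
        self._event: Optional[asyncio.Event] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _ensure_primitives(self) -> None:
        loop = asyncio.get_running_loop()
        if loop is not self._loop:  # first use, or a new loop since last use
            self._loop = loop
            self._event = asyncio.Event()
            self._event.set()

    async def wait_open(self) -> None:
        self._ensure_primitives()
        assert self._event is not None
        await self._event.wait()


gate = LoopBoundGate()
asyncio.run(gate.wait_open())  # first loop
asyncio.run(gate.wait_open())  # second loop: primitives are rebuilt, no cross-loop error
```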
@@ -15,7 +15,6 @@ from .default_templates import (
      TOXICITY_PROMPT_TEMPLATE,
  )
  from .template import (
-     NOT_PARSABLE,
      ClassificationTemplate,
      PromptOptions,
      PromptTemplate,
@@ -32,7 +31,6 @@ __all__ = [
      "normalize_classification_template",
      "normalize_prompt_template",
      "map_template",
-     "NOT_PARSABLE",
      "CODE_READABILITY_PROMPT_RAILS_MAP",
      "CODE_READABILITY_PROMPT_TEMPLATE",
      "HALLUCINATION_PROMPT_RAILS_MAP",
@@ -4,14 +4,11 @@ from typing import Callable, List, Mapping, Optional, Tuple, Union

  import pandas as pd

+ from phoenix.experimental.evals.utils import NOT_PARSABLE
+
  DEFAULT_START_DELIM = "{"
  DEFAULT_END_DELIM = "}"

- # Rather than returning None, we return this string to indicate that the LLM output could not be
- # parsed.
- # This is useful for debugging as well as to just treat the output as a non-parsable category
- NOT_PARSABLE = "NOT_PARSABLE"
-

  @dataclass
  class PromptOptions:
@@ -1,11 +1,24 @@
  import json
  from io import BytesIO
+ from typing import List, Optional, Tuple
  from urllib.error import HTTPError
  from urllib.request import urlopen
  from zipfile import ZipFile

  import pandas as pd

+ from phoenix.utilities.logging import printif
+
+ # Rather than returning None, we return this string to indicate that the LLM output could not be
+ # parsed.
+ # This is useful for debugging as well as to just treat the output as a non-parsable category
+ NOT_PARSABLE = "NOT_PARSABLE"
+
+ # argument keys in the default openai function call,
+ # defined here only to prevent typos
+ _RESPONSE = "response"
+ _EXPLANATION = "explanation"
+

  def download_benchmark_dataset(task: str, dataset_name: str) -> pd.DataFrame:
      """Downloads an Arize evals benchmark dataset as a pandas dataframe.
@@ -42,3 +55,56 @@ def get_tqdm_progress_bar_formatter(title: str) -> str:
          title + " |{bar}| {n_fmt}/{total_fmt} ({percentage:3.1f}%) "
          "| ⏳ {elapsed}<{remaining} | {rate_fmt}{postfix}"
      )
+
+
+ def snap_to_rail(raw_string: Optional[str], rails: List[str], verbose: bool = False) -> str:
+     """
+     Snaps a string to the nearest rail, or returns NOT_PARSABLE if the string cannot be
+     snapped to exactly one rail.
+
+     Args:
+         raw_string (str): An input to be snapped to a rail.
+
+         rails (List[str]): The target set of strings to snap to.
+
+         verbose (bool, optional): If True, prints verbose parsing info. Defaults to False.
+
+     Returns:
+         str: A string from the rails argument or "NOT_PARSABLE" if the input
+             string could not be snapped.
+     """
+     if not raw_string:
+         return NOT_PARSABLE
+     snap_string = raw_string.lower()
+     rails = list(set(rail.lower() for rail in rails))
+     rails.sort(key=len, reverse=True)
+     found_rails = set()
+     for rail in rails:
+         if rail in snap_string:
+             found_rails.add(rail)
+             snap_string = snap_string.replace(rail, "")
+     if len(found_rails) != 1:
+         printif(verbose, f"- Cannot snap {repr(raw_string)} to rails")
+         return NOT_PARSABLE
+     rail = list(found_rails)[0]
+     printif(verbose, f"- Snapped {repr(raw_string)} to rail: {rail}")
+     return rail
+
+
+ def parse_openai_function_call(raw_output: str) -> Tuple[str, Optional[str]]:
+     """
+     Parses the output of an OpenAI function call.
+
+     Args:
+         raw_output (str): The raw output of an OpenAI function call.
+
+     Returns:
+         Tuple[str, Optional[str]]: A tuple of the unrailed label and an optional
+             explanation.
+     """
+     try:
+         function_arguments = json.loads(raw_output, strict=False)
+         unrailed_label = function_arguments.get(_RESPONSE, "")
+         explanation = function_arguments.get(_EXPLANATION)
+     except json.JSONDecodeError:
+         unrailed_label = raw_output
+         explanation = None
+     return unrailed_label, explanation
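The behavior of the two helpers that moved into `utils`, shown with inline assertions (a sketch; the import path is inferred from the templates hunk above, which imports `NOT_PARSABLE` from the same module):

```python
from phoenix.experimental.evals.utils import (
    NOT_PARSABLE,
    parse_openai_function_call,
    snap_to_rail,
)

# exactly one rail found -> that rail; longer rails are matched (and removed) first,
# so "irrelevant" is not mistaken for output containing the rail "relevant"
assert snap_to_rail("The answer is relevant.", ["relevant", "irrelevant"]) == "relevant"
assert snap_to_rail("irrelevant", ["relevant", "irrelevant"]) == "irrelevant"

# zero or multiple distinct rails -> NOT_PARSABLE
assert snap_to_rail("relevant and irrelevant", ["relevant", "irrelevant"]) == NOT_PARSABLE
assert snap_to_rail(None, ["relevant", "irrelevant"]) == NOT_PARSABLE

# function-call output: pull the label and optional explanation from the JSON
# arguments, falling back to the raw string when it is not valid JSON
label, explanation = parse_openai_function_call(
    '{"response": "relevant", "explanation": "it cites the document"}'
)
assert (label, explanation) == ("relevant", "it cites the document")
assert parse_openai_function_call("not json") == ("not json", None)
```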