arize-phoenix 1.6.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix might be problematic.

@@ -6,6 +6,7 @@ from inspect import BoundArguments, signature
 from types import TracebackType
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     ContextManager,
     Coroutine,
@@ -22,7 +23,7 @@ from typing import (
 )
 
 import openai
-from openai import Stream
+from openai import AsyncStream, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from typing_extensions import ParamSpec
 from wrapt import ObjectProxy
@@ -128,13 +129,14 @@ class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
         """
         self.tracer = tracer
         self.start_time: Optional[datetime] = None
-        self.end_time: Optional[datetime] = None
         self.status_code = SpanStatusCode.UNSET
         self.status_message = ""
         self.events: List[SpanEvent] = []
         self.attributes: SpanAttributes = dict()
         parameters = _parameters(bound_arguments)
         self.num_choices = parameters.get("n", 1)
+        self.chat_completion_chunks: List[ChatCompletionChunk] = []
+        self.stream_complete = False
         self._process_parameters(parameters)
         self._span_created = False
 
@@ -150,14 +152,13 @@ class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
     ) -> None:
         if exc_value is None:
             return
-        self.end_time = datetime.now(tz=timezone.utc)
         self.status_code = SpanStatusCode.ERROR
         status_message = str(exc_value)
         self.status_message = status_message
         self.events.append(
             SpanException(
                 message=status_message,
-                timestamp=self.end_time,
+                timestamp=datetime.now(tz=timezone.utc),
                 exception_type=type(exc_value).__name__,
                 exception_stacktrace=get_stacktrace(exc_value),
             )
@@ -171,17 +172,17 @@ class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
         Args:
             response (ChatCompletion): The chat completion object.
         """
-        self.end_time = datetime.now(tz=timezone.utc)
-        self.status_code = SpanStatusCode.OK
         if isinstance(response, ChatCompletion):
             self._process_chat_completion(response)
         elif isinstance(response, Stream):
-            self.end_time = None  # set end time to None to indicate that the stream is still open
             return StreamWrapper(stream=response, context=self)
+        elif isinstance(response, AsyncStream):
+            return AsyncStreamWrapper(stream=response, context=self)
         elif hasattr(response, "parse") and callable(
             response.parse
         ):  # handle raw response by converting them to chat completions
             self._process_chat_completion(response.parse())
+        self.status_code = SpanStatusCode.OK
         self.create_span()
         return response
 
@@ -195,7 +196,7 @@ class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
             name="OpenAI Chat Completion",
             span_kind=SpanKind.LLM,
             start_time=cast(datetime, self.start_time),
-            end_time=self.end_time,
+            end_time=datetime.now(tz=timezone.utc),
             status_code=self.status_code,
             status_message=self.status_message,
             attributes=self.attributes,
@@ -237,9 +238,8 @@ class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
 
 class StreamWrapper(ObjectProxy):  # type: ignore
     """
-    A wrapper for streams of chat completion chunks that records each span
-    stream event and updates the span upon completion of the stream or upon an
-    exception.
+    A wrapper for streams of chat completion chunks that records events and
+    creates the span upon completion of the stream or upon an exception.
     """
 
     def __init__(self, stream: Stream[ChatCompletionChunk], context: ChatCompletionContext) -> None:
@@ -252,8 +252,7 @@ class StreamWrapper(ObjectProxy):  # type: ignore
             fields and attributes.
         """
         super().__init__(stream)
-        self._self_context = context
-        self._self_chunks: List[ChatCompletionChunk] = []
+        self._self_context = ChatCompletionStreamEventContext(context)
 
     def __next__(self) -> ChatCompletionChunk:
         """
@@ -263,48 +262,10 @@ class StreamWrapper(ObjectProxy):  # type: ignore
         Returns:
             ChatCompletionChunk: The forwarded chat completion chunk.
         """
-        finished_streaming = False
-        try:
+        with self._self_context as context:
             chat_completion_chunk = next(self.__wrapped__)
-            if not self._self_chunks:
-                self._self_context.events.append(
-                    SpanEvent(
-                        name="First Token Stream Event",
-                        timestamp=datetime.now(tz=timezone.utc),
-                        attributes={},
-                    )
-                )
-            self._self_chunks.append(chat_completion_chunk)
+            context.process_chat_completion_chunk(chat_completion_chunk)
             return cast(ChatCompletionChunk, chat_completion_chunk)
-        except StopIteration:
-            finished_streaming = True
-            raise
-        except Exception as error:
-            finished_streaming = True
-            status_message = str(error)
-            self._self_context.status_code = SpanStatusCode.ERROR
-            self._self_context.status_message = status_message
-            self._self_context.events.append(
-                SpanException(
-                    message=status_message,
-                    timestamp=datetime.now(tz=timezone.utc),
-                    exception_type=type(error).__name__,
-                    exception_stacktrace=get_stacktrace(error),
-                )
-            )
-            raise
-        finally:
-            if finished_streaming:
-                self._self_context.end_time = datetime.now(tz=timezone.utc)
-                self._self_context.attributes = {
-                    **self._self_context.attributes,
-                    LLM_OUTPUT_MESSAGES: _accumulate_messages(
-                        chunks=self._self_chunks, num_choices=self._self_context.num_choices
-                    ),  # type: ignore
-                    OUTPUT_VALUE: json.dumps([chunk.dict() for chunk in self._self_chunks]),
-                    OUTPUT_MIME_TYPE: MimeType.JSON,  # type: ignore
-                }
-                self._self_context.create_span()
 
     def __iter__(self) -> Iterator[ChatCompletionChunk]:
         """
@@ -314,6 +275,120 @@ class StreamWrapper(ObjectProxy):  # type: ignore
         return self
 
 
+class AsyncStreamWrapper(ObjectProxy):  # type: ignore
+    """
+    A wrapper for asynchronous streams of chat completion chunks that records
+    events and creates the span upon completion of the stream or upon an
+    exception.
+    """
+
+    def __init__(
+        self, stream: AsyncStream[ChatCompletionChunk], context: ChatCompletionContext
+    ) -> None:
+        """Initializes the stream wrapper.
+
+        Args:
+            stream (AsyncStream[ChatCompletionChunk]): The stream to wrap.
+
+            context (ChatCompletionContext): The context used to store span
+                fields and attributes.
+        """
+        super().__init__(stream)
+        self._self_context = ChatCompletionStreamEventContext(context)
+
+    async def __anext__(self) -> ChatCompletionChunk:
+        """
+        A wrapped __anext__ method that records span stream events and updates
+        the span upon completion of the stream or upon exception.
+
+        Returns:
+            ChatCompletionChunk: The forwarded chat completion chunk.
+        """
+        with self._self_context as context:
+            chat_completion_chunk = await self.__wrapped__.__anext__()
+            context.process_chat_completion_chunk(chat_completion_chunk)
+            return cast(ChatCompletionChunk, chat_completion_chunk)
+
+    def __aiter__(self) -> AsyncIterator[ChatCompletionChunk]:
+        """
+        An __aiter__ method that bypasses the wrapped class' __aiter__ method so
+        that __aiter__ is automatically instrumented using __anext__.
+        """
+        return self
+
+
+class ChatCompletionStreamEventContext(ContextManager["ChatCompletionStreamEventContext"]):
+    """
+    A context manager that surrounds stream events in a stream of chat
+    completions and processes each chat completion chunk.
+    """
+
+    def __init__(self, chat_completion_context: ChatCompletionContext) -> None:
+        """Initializes the context manager.
+
+        Args:
+            chat_completion_context (ChatCompletionContext): The chat completion
+                context storing span fields and attributes.
+        """
+        self._context = chat_completion_context
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        if isinstance(exc_value, StopIteration) or isinstance(exc_value, StopAsyncIteration):
+            self._context.stream_complete = True
+            self._context.status_code = SpanStatusCode.OK
+        elif exc_value is not None:
+            self._context.stream_complete = True
+            self._context.status_code = SpanStatusCode.ERROR
+            status_message = str(exc_value)
+            self._context.status_message = status_message
+            self._context.events.append(
+                SpanException(
+                    message=status_message,
+                    timestamp=datetime.now(tz=timezone.utc),
+                    exception_type=type(exc_value).__name__,
+                    exception_stacktrace=get_stacktrace(exc_value),
+                )
+            )
+        if not self._context.stream_complete:
+            return
+        self._context.attributes = {
+            **self._context.attributes,
+            LLM_OUTPUT_MESSAGES: _accumulate_messages(
+                chunks=self._context.chat_completion_chunks,
+                num_choices=self._context.num_choices,
+            ),  # type: ignore
+            OUTPUT_VALUE: json.dumps(
+                [chunk.dict() for chunk in self._context.chat_completion_chunks]
+            ),
+            OUTPUT_MIME_TYPE: MimeType.JSON,  # type: ignore
+        }
+        self._context.create_span()
+
+    def process_chat_completion_chunk(self, chat_completion_chunk: ChatCompletionChunk) -> None:
+        """
+        Processes a chat completion chunk and adds relevant information to the
+        context.
+
+        Args:
+            chat_completion_chunk (ChatCompletionChunk): The chat completion
+                chunk to be processed.
+        """
+        if not self._context.chat_completion_chunks:
+            self._context.events.append(
+                SpanEvent(
+                    name="First Token Stream Event",
+                    timestamp=datetime.now(tz=timezone.utc),
+                    attributes={},
+                )
+            )
+        self._context.chat_completion_chunks.append(chat_completion_chunk)
+
+
 def _wrapped_openai_sync_client_request_function(
     request_fn: Callable[ParameterSpec, GenericType], tracer: Tracer
 ) -> Callable[ParameterSpec, GenericType]:
@@ -363,22 +438,11 @@ def _wrapped_openai_async_client_request_function(
 
     async def wrapped(*args: Any, **kwargs: Any) -> Any:
        bound_arguments = call_signature.bind(*args, **kwargs)
-        if (
-            _is_streaming_request(bound_arguments)
-            or _request_type(bound_arguments) is not RequestType.CHAT_COMPLETION
-        ):
+        if _request_type(bound_arguments) is not RequestType.CHAT_COMPLETION:
            return await request_fn(*args, **kwargs)
        with ChatCompletionContext(bound_arguments, tracer) as context:
            response = await request_fn(*args, **kwargs)
-            context.process_response(
-                cast(
-                    ChatCompletion,
-                    response.parse()
-                    if hasattr(response, "parse") and callable(response.parse)
-                    else response,
-                )
-            )
-            return response
+            return context.process_response(response)
 
     return wrapped
 
@@ -521,19 +585,6 @@ _CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[[ChatCompletion], Any]]
 }
 
 
-def _is_streaming_request(bound_arguments: BoundArguments) -> bool:
-    """
-    Determines whether the request is a streaming request.
-
-    Args:
-        bound_arguments (BoundArguments): The bound arguments to the request function.
-
-    Returns:
-        bool: True if the request is a streaming request, False otherwise.
-    """
-    return cast(bool, bound_arguments.arguments["stream"])
-
-
 def _parameters(bound_arguments: BoundArguments) -> Parameters:
     """
     The parameters for the LLM call, e.g., temperature.
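
The refactor above replaces StreamWrapper's try/except/finally bookkeeping with ChatCompletionStreamEventContext, whose __exit__ inspects the exception escaping __next__ or __anext__: StopIteration and StopAsyncIteration mark a clean end of stream, any other exception records an error, and the span is created exactly once either way. A minimal standalone sketch of that pattern (all names hypothetical, not Phoenix or OpenAI APIs):

from types import TracebackType
from typing import Iterator, List, Optional, Type


class StreamFinalizer:
    """Hypothetical stand-in for ChatCompletionStreamEventContext."""

    def __init__(self) -> None:
        self.chunks: List[str] = []
        self.finalized = False

    def __enter__(self) -> "StreamFinalizer":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # StopIteration escaping the with-block means the stream ended cleanly.
        # Returning None (falsy) re-raises it, so the caller's loop still stops.
        if isinstance(exc_value, (StopIteration, StopAsyncIteration)) and not self.finalized:
            self.finalized = True
            print(f"\nspan finalized with {len(self.chunks)} chunks")


class InstrumentedStream:
    """Hypothetical stand-in for StreamWrapper."""

    def __init__(self, stream: Iterator[str], finalizer: StreamFinalizer) -> None:
        self._stream, self._finalizer = stream, finalizer

    def __iter__(self) -> "InstrumentedStream":
        return self

    def __next__(self) -> str:
        with self._finalizer as f:
            chunk = next(self._stream)  # a StopIteration raised here reaches __exit__ first
            f.chunks.append(chunk)
            return chunk


for token in InstrumentedStream(iter(["Hello", ", ", "world"]), StreamFinalizer()):
    print(token, end="")

Because __exit__ returns None, the terminating exception is never suppressed; the consumer's loop ends normally while the finalizer runs exactly once.
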
phoenix/trace/span_evaluations.py CHANGED
@@ -1,12 +1,163 @@
-import pandas as pd
+from abc import ABC
+from dataclasses import dataclass, field
+from itertools import product
+from types import MappingProxyType
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Set, Tuple
 
-EVALUATIONS_INDEX_NAME = "context.span_id"
-RESULTS_COLUMN_NAMES = ["score", "label", "explanation"]
+import pandas as pd
+from pandas.api.types import is_integer_dtype, is_numeric_dtype, is_string_dtype
 
 EVAL_NAME_COLUMN_PREFIX = "eval."
 
 
-class SpanEvaluations:
+class NeedsNamedIndex(ABC):
+    index_names: Mapping[Tuple[str, ...], Callable[[Any], bool]]
+    all_valid_index_name_sorted_combos: Set[Tuple[str, ...]]
+
+    @classmethod
+    def preferred_names(cls) -> List[str]:
+        return [choices[0] for choices in cls.index_names.keys()]
+
+    @classmethod
+    def aliases(cls) -> Mapping[str, str]:
+        return {alias: choices[0] for choices in cls.index_names.keys() for alias in choices[1:]}
+
+    @classmethod
+    def unalias(cls, name: str) -> str:
+        return cls.aliases().get(name, name)
+
+    @classmethod
+    def is_valid_index_names(cls, names: Sequence[str]) -> bool:
+        return (
+            len(names) == len(cls.index_names)
+            and tuple(sorted(names)) in cls.all_valid_index_name_sorted_combos
+        )
+
+    @classmethod
+    def find_valid_index_names(cls, dtypes: "pd.Series[Any]") -> Optional[List[str]]:
+        valid_names = []
+        for names, check_type in cls.index_names.items():
+            for name in names:
+                if name in dtypes.index and check_type(dtypes[name]):
+                    valid_names.append(name)
+                    break
+            else:
+                return None
+        return valid_names
+
+
+class NeedsResultColumns(ABC):
+    result_column_names: Mapping[str, Callable[[Any], bool]] = MappingProxyType(
+        {
+            "score": is_numeric_dtype,
+            "label": is_string_dtype,
+            "explanation": is_string_dtype,
+        }
+    )
+
+    @classmethod
+    def is_valid_result_columns(cls, dtypes: "pd.Series[Any]") -> bool:
+        names = cls.result_column_names.keys()
+        intersection = dtypes.index.intersection(names)  # type: ignore
+        if not len(intersection):
+            return False
+        for name in intersection:
+            check_type = cls.result_column_names[name]
+            if not check_type(dtypes[name]):
+                return False
+        return True
+
+
+@dataclass(frozen=True)
+class Evaluations(NeedsNamedIndex, NeedsResultColumns, ABC):
+    eval_name: str  # The name for the evaluation, e.g. 'toxicity'
+    dataframe: pd.DataFrame = field(repr=False)
+
+    def __len__(self) -> int:
+        return len(self.dataframe)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(eval_name={self.eval_name!r}, "
+            f"dataframe=<rows: {len(self.dataframe)!r}>)"
+        )
+
+    def __dir__(self) -> List[str]:
+        return ["get_dataframe"]
+
+    def get_dataframe(self, prefix_columns_with_name: bool = True) -> pd.DataFrame:
+        """
+        Returns a copy of the dataframe with the evaluation annotations
+
+        Parameters
+        __________
+        prefix_columns_with_name: bool
+            if True, the columns will be prefixed with the eval_name, e.g. 'eval.toxicity.value'
+        """
+        if prefix_columns_with_name:
+            prefix = f"{EVAL_NAME_COLUMN_PREFIX}{self.eval_name}."
+            return self.dataframe.add_prefix(prefix)
+        return self.dataframe.copy(deep=False)
+
+    def __bool__(self) -> bool:
+        return not self.dataframe.empty
+
+    def __post_init__(self) -> None:
+        dataframe = (
+            pd.DataFrame() if self.dataframe.empty else self._clean_dataframe(self.dataframe)
+        )
+        object.__setattr__(self, "dataframe", dataframe)
+
+    def _clean_dataframe(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+        # Ensure column names are strings.
+        column_names = dataframe.columns.astype(str)
+        dataframe = dataframe.set_axis(column_names, axis=1)
+
+        # If the dataframe contains the index columns, set the index to those columns
+        if not self.is_valid_index_names(dataframe.index.names) and (
+            index_names := self.find_valid_index_names(dataframe.dtypes)
+        ):
+            dataframe = dataframe.set_index(index_names)
+
+        # Validate that the dataframe is indexed correctly.
+        if not self.is_valid_index_names(dataframe.index.names):
+            raise ValueError(
+                f"The dataframe index must be {self.preferred_names()} but was "
+                f"'{dataframe.index.name or dataframe.index.names}'"
+            )
+
+        # Validate that the dataframe contains result columns of appropriate types.
+        if not self.is_valid_result_columns(dataframe.dtypes):
+            raise ValueError(
+                f"The dataframe must contain one of these columns with appropriate "
+                f"value types: {self.result_column_names.keys()} "
+            )
+
+        # Un-alias to the preferred names.
+        preferred_names = [self.unalias(name) for name in dataframe.index.names]
+        dataframe = dataframe.rename_axis(preferred_names, axis=0)
+
+        # Drop the unnecessary columns.
+        result_column_names = dataframe.columns.intersection(self.result_column_names.keys())  # type: ignore
+        return dataframe.loc[:, result_column_names]  # type: ignore
+
+    def __init_subclass__(
+        cls,
+        index_names: Mapping[Tuple[str, ...], Callable[[Any], bool]],
+        **kwargs: Any,
+    ) -> None:
+        super().__init_subclass__(**kwargs)
+        cls.index_names = index_names
+        cls.all_valid_index_name_sorted_combos = set(
+            tuple(sorted(prod)) for prod in product(*cls.index_names.keys())
+        )
+
+
+@dataclass(frozen=True)
+class SpanEvaluations(
+    Evaluations,
+    index_names=MappingProxyType({("context.span_id", "span_id"): is_string_dtype}),
+):
     """
     SpanEvaluations is a set of evaluation annotations for a set of spans.
     SpanEvaluations encompasses the evaluation annotations for a single evaluation task
@@ -18,7 +169,7 @@ class SpanEvaluations:
     Parameters
     __________
     eval_name: str
-        the name of the evaluation, e.x. 'toxicity'
+        the name of the evaluation, e.g. 'toxicity'
     dataframe: pandas.DataFrame
         the pandas dataframe containing the evaluation annotations Each row
         represents the evaluations on a span.
@@ -35,38 +186,52 @@ class SpanEvaluations:
     | span_3 | 1 | toxic | discrimination |
     """
 
-    dataframe: pd.DataFrame
 
-    eval_name: str  # The name for the evaluation, e.x. 'toxicity'
+@dataclass(frozen=True)
+class DocumentEvaluations(
+    Evaluations,
+    index_names=MappingProxyType(
+        {
+            ("context.span_id", "span_id"): is_string_dtype,
+            ("document_position", "position"): is_integer_dtype,
+        }
+    ),
+):
+    """
+    DocumentEvaluations is a set of evaluation annotations for a set of documents.
+    DocumentEvaluations encompasses the evaluation annotations for a single evaluation task
+    such as relevance.
 
-    def __init__(self, eval_name: str, dataframe: pd.DataFrame):
-        self.eval_name = eval_name
+    Parameters
+    __________
+    eval_name: str
+        the name of the evaluation, e.g. 'relevance'
+    dataframe: pandas.DataFrame
+        the pandas dataframe containing the evaluation annotations. Each row
+        represents the evaluations on a document.
 
-        # If the dataframe contains the index column, set the index to that column
-        if EVALUATIONS_INDEX_NAME in dataframe.columns:
-            dataframe = dataframe.set_index(EVALUATIONS_INDEX_NAME)
+    Example
+    _______
 
-        # validate that the dataframe is indexed by context.span_id
-        if dataframe.index.name != EVALUATIONS_INDEX_NAME:
-            raise ValueError(
-                f"The dataframe index must be '{EVALUATIONS_INDEX_NAME}' but was "
-                f"'{dataframe.index.name}'"
-            )
+    DataFrame of document evaluations for relevance may look like:
 
-        # Drop the unnecessary columns
-        extra_column_names = dataframe.columns.difference(RESULTS_COLUMN_NAMES)
-        self.dataframe = dataframe.drop(extra_column_names, axis=1)
+    | span_id | position | score | label      | explanation  |
+    |---------|----------|-------|------------|--------------|
+    | span_1  | 0        | 1     | relevant   | it's apropos |
+    | span_1  | 1        | 1     | relevant   | it's germane |
+    | span_2  | 0        | 0     | irrelevant | it's rubbish |
+    """
 
-    def get_dataframe(self, prefix_columns_with_name: bool = True) -> pd.DataFrame:
-        """
-        Returns a copy of the dataframe with the evaluation annotations
+    def _clean_dataframe(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+        dataframe = super()._clean_dataframe(dataframe)
+        if dataframe.index.names != self.preferred_names():
+            return dataframe.swaplevel()
+        return dataframe
 
-        Parameters
-        __________
-        prefix_columns_with_name: bool
-            if True, the columns will be prefixed with the eval_name, e.x. 'eval.toxicity.value'
-        """
-        if prefix_columns_with_name:
-            prefix = f"{EVAL_NAME_COLUMN_PREFIX}{self.eval_name}."
-            return self.dataframe.add_prefix(prefix)
-        return self.dataframe.copy()
+
+@dataclass(frozen=True)
+class TraceEvaluations(
+    Evaluations,
+    index_names=MappingProxyType({("context.trace_id", "trace_id"): is_string_dtype}),
+):
+    ...
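
Per the __init_subclass__ hook above, each Evaluations subclass declares its required index along with aliases, and _clean_dataframe promotes matching columns to the index, un-aliases them to the preferred names, and drops non-result columns. A hedged usage sketch, assuming the 1.7.1 module layout shown in this diff:

import pandas as pd

# Import path inferred from this diff: tracer.py sits under phoenix/trace/, and the
# trace-dataset hunk below imports 'from .span_evaluations import ...'.
from phoenix.trace.span_evaluations import SpanEvaluations

df = pd.DataFrame(
    {
        "span_id": ["span_1", "span_2", "span_3"],  # alias, un-aliased to 'context.span_id'
        "score": [0, 0, 1],
        "label": ["non-toxic", "non-toxic", "toxic"],
        "notes": ["a", "b", "c"],  # not a result column, so _clean_dataframe drops it
    }
)
evals = SpanEvaluations(eval_name="toxicity", dataframe=df)
print(evals.get_dataframe().columns.tolist())
# expected: ['eval.toxicity.score', 'eval.toxicity.label']
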
phoenix/trace/trace_dataset.py CHANGED
@@ -16,7 +16,7 @@ from .semantic_conventions import (
     RERANKER_OUTPUT_DOCUMENTS,
     RETRIEVAL_DOCUMENTS,
 )
-from .span_evaluations import EVALUATIONS_INDEX_NAME, SpanEvaluations
+from .span_evaluations import Evaluations, SpanEvaluations
 from .span_json_decoder import json_to_span
 from .span_json_encoder import span_to_json
 
@@ -93,14 +93,14 @@ class TraceDataset:
 
     name: str
     dataframe: pd.DataFrame
-    evaluations: List[SpanEvaluations] = []
+    evaluations: List[Evaluations] = []
    _data_file_name: str = "data.parquet"
 
     def __init__(
         self,
         dataframe: DataFrame,
         name: Optional[str] = None,
-        evaluations: Iterable[SpanEvaluations] = (),
+        evaluations: Iterable[Evaluations] = (),
     ):
         """
         Constructs a TraceDataset from a dataframe of spans. Optionally takes in
@@ -199,7 +199,7 @@ class TraceDataset:
             coerce_timestamps="ms",
         )
 
-    def append_evaluations(self, evaluations: SpanEvaluations) -> None:
+    def append_evaluations(self, evaluations: Evaluations) -> None:
         """adds an evaluation to the traces"""
         # Append the evaluations to the list of evaluations
         self.evaluations.append(evaluations)
@@ -209,7 +209,11 @@ class TraceDataset:
         Creates a flat dataframe of all the evaluations for the dataset.
         """
         return pd.concat(
-            [evals.get_dataframe(prefix_columns_with_name=True) for evals in self.evaluations],
+            [
+                evals.get_dataframe(prefix_columns_with_name=True)
+                for evals in self.evaluations
+                if isinstance(evals, SpanEvaluations)
+            ],
             axis=1,
         )
 
@@ -227,5 +231,5 @@ class TraceDataset:
             return self.dataframe.copy()
         evals_df = self.get_evals_dataframe()
         # Make sure the index is set to the span_id
-        df = self.dataframe.set_index(EVALUATIONS_INDEX_NAME, drop=False)
+        df = self.dataframe.set_index("context.span_id", drop=False)
         return pd.concat([df, evals_df], axis=1)
phoenix/trace/tracer.py CHANGED
@@ -72,6 +72,7 @@ class Tracer:
         attributes: Optional[SpanAttributes] = None,
         events: Optional[List[SpanEvent]] = None,
         conversation: Optional[SpanConversationAttributes] = None,
+        span_id: Optional[UUID] = None,
     ) -> Span:
         """
         create_span creates a new span with the given name and options.
@@ -90,7 +91,7 @@ class Tracer:
 
         span = Span(
             name=name,
-            context=SpanContext(trace_id=trace_id, span_id=uuid4()),
+            context=SpanContext(trace_id=trace_id, span_id=span_id or uuid4()),
             span_kind=span_kind,
             parent_id=parent_id,
             start_time=start_time,
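
The tracer change lets a caller supply the span ID instead of always minting one with uuid4(), presumably so an ID can be fixed before the span object itself is created (for example, while a stream is still open). The idiom, as a self-contained sketch with a stand-in SpanContext:

from dataclasses import dataclass
from typing import Optional
from uuid import UUID, uuid4


@dataclass(frozen=True)
class SpanContext:  # minimal stand-in for Phoenix's SpanContext
    trace_id: UUID
    span_id: UUID


def make_context(trace_id: UUID, span_id: Optional[UUID] = None) -> SpanContext:
    # Same idiom as the hunk above: honor a caller-supplied ID, else mint a fresh one.
    return SpanContext(trace_id=trace_id, span_id=span_id or uuid4())


pinned = uuid4()
assert make_context(uuid4(), span_id=pinned).span_id == pinned
assert make_context(uuid4()).span_id != pinned
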