arize-phoenix 1.6.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

@@ -6,11 +6,9 @@ A set of **highly experimental** helper functions to
 - ingest evaluation results into Phoenix via HttpExporter
 """
 import math
+from time import sleep
 from typing import (
     Any,
-    Iterable,
-    List,
-    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -20,56 +18,19 @@ from typing import (
 
 import pandas as pd
 from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
+from tqdm import tqdm
 
 import phoenix.trace.v1 as pb
-from phoenix.core.traces import TRACE_ID
-from phoenix.session.session import Session
+from phoenix.trace.dsl.helpers import get_qa_with_reference, get_retrieved_documents
 from phoenix.trace.exporter import HttpExporter
-from phoenix.trace.schemas import ATTRIBUTE_PREFIX
-from phoenix.trace.semantic_conventions import (
-    DOCUMENT_CONTENT,
-    DOCUMENT_SCORE,
-    INPUT_VALUE,
-    RETRIEVAL_DOCUMENTS,
-)
 
+__all__ = [
+    "get_retrieved_documents",
+    "get_qa_with_reference",
+    "add_evaluations",
+]
 
-def get_retrieved_documents(session: Session) -> pd.DataFrame:
-    data: List[Mapping[str, Any]] = []
-    if (df := session.get_spans_dataframe("span_kind == 'RETRIEVER'")) is not None:
-        for span_id, query, documents, trace_id in df.loc[
-            :,
-            [
-                ATTRIBUTE_PREFIX + INPUT_VALUE,
-                ATTRIBUTE_PREFIX + RETRIEVAL_DOCUMENTS,
-                TRACE_ID,
-            ],
-        ].itertuples():
-            if not isinstance(documents, Iterable):
-                continue
-            for position, document in enumerate(documents):
-                if not hasattr(document, "get"):
-                    continue
-                data.append(
-                    {
-                        "context.trace_id": trace_id,
-                        "context.span_id": span_id,
-                        "input": query,
-                        "document_position": position,
-                        "reference": document.get(DOCUMENT_CONTENT),
-                        "document_score": document.get(DOCUMENT_SCORE),
-                    }
-                )
-    index = ["context.span_id", "document_position"]
-    columns = [
-        "context.span_id",
-        "document_position",
-        "input",
-        "reference",
-        "document_score",
-        "context.trace_id",
-    ]
-    return pd.DataFrame(data=data, columns=columns).set_index(index)
+from phoenix.trace.span_evaluations import Evaluations
 
 
 def add_evaluations(
@@ -155,3 +116,22 @@ def _extract_result(row: "pd.Series[Any]") -> Optional[pb.Evaluation.Result]:
         label=StringValue(value=label) if label else None,
         explanation=StringValue(value=explanation) if explanation else None,
     )
+
+
+def log_evaluations(
+    *evals: Evaluations,
+    endpoint: Optional[str] = None,
+    host: Optional[str] = None,
+    port: Optional[int] = None,
+) -> None:
+    if not (n := sum(map(len, evals))):
+        return
+    exporter = HttpExporter(endpoint=endpoint, host=host, port=port)
+    for eval in filter(bool, evals):
+        add_evaluations(exporter, eval.dataframe, eval.eval_name)
+    with tqdm(total=n, desc="Sending Evaluations") as pbar:
+        while n:
+            sleep(0.1)
+            n_left = exporter._queue.qsize()
+            n, diff = n_left, n - n_left
+            pbar.update(diff)
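The new log_evaluations helper queues each record on an HttpExporter and then blocks on a tqdm progress bar until the exporter's queue drains. A minimal usage sketch (not part of the diff; it assumes these helpers live in phoenix.session.evaluation and that SpanEvaluations takes an eval name plus a dataframe indexed by context.span_id):

    # Hedged sketch: the module path, constructor keywords, and span IDs are
    # illustrative assumptions, not taken from the diff.
    import pandas as pd

    from phoenix.session.evaluation import log_evaluations
    from phoenix.trace import SpanEvaluations

    eval_df = pd.DataFrame(
        {"score": [1.0, 0.0], "label": ["relevant", "irrelevant"]},
        index=pd.Index(["span-id-1", "span-id-2"], name="context.span_id"),
    )

    # Queues the records, then reports progress until the exporter queue is empty.
    log_evaluations(SpanEvaluations(eval_name="relevance", dataframe=eval_df))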
@@ -30,6 +30,7 @@ from phoenix.server.app import create_app
 from phoenix.server.thread_server import ThreadServer
 from phoenix.services import AppService
 from phoenix.trace.dsl import SpanFilter
+from phoenix.trace.dsl.query import SpanQuery
 from phoenix.trace.span_json_encoder import span_to_json
 from phoenix.trace.trace_dataset import TraceDataset
 
@@ -167,6 +168,25 @@ class Session(ABC):
         """Returns the url for the phoenix app"""
         return _get_url(self.host, self.port, self.notebook_env)
 
+    def query_spans(
+        self,
+        *queries: SpanQuery,
+        start_time: Optional[datetime] = None,
+        stop_time: Optional[datetime] = None,
+    ) -> Optional[Union[pd.DataFrame, List[pd.DataFrame]]]:
+        if len(queries) == 0 or (traces := self.traces) is None:
+            return None
+        spans = tuple(
+            traces.get_spans(
+                start_time=start_time,
+                stop_time=stop_time,
+            )
+        )
+        dataframes = [query(spans) for query in queries]
+        if len(dataframes) == 1:
+            return dataframes[0]
+        return dataframes
+
     def get_spans_dataframe(
         self,
         filter_condition: Optional[str] = None,
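Session.query_spans is the entry point for the new query DSL: a single SpanQuery returns one DataFrame, and several queries return a list of DataFrames in the same order. A hedged usage sketch (the filter string and the input.value attribute key are assumptions based on the conventions used elsewhere in this diff):

    # Hedged sketch; assumes a Phoenix session is already running.
    import phoenix as px
    from phoenix.trace.dsl import SpanQuery

    session = px.active_session()
    if session is not None:
        # One DataFrame, indexed by context.span_id.
        llm_df = session.query_spans(
            SpanQuery().where("span_kind == 'LLM'").select("trace_id", input="input.value")
        )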
phoenix/trace/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .span_evaluations import SpanEvaluations
+from .span_evaluations import DocumentEvaluations, SpanEvaluations, TraceEvaluations
 from .trace_dataset import TraceDataset
 
-__all__ = ["TraceDataset", "SpanEvaluations"]
+__all__ = ["TraceDataset", "SpanEvaluations", "DocumentEvaluations", "TraceEvaluations"]
@@ -1,5 +1,7 @@
 from phoenix.trace.dsl.filter import SpanFilter
+from phoenix.trace.dsl.query import SpanQuery
 
 __all__ = [
     "SpanFilter",
+    "SpanQuery",
 ]
@@ -0,0 +1,61 @@
+from typing import List, Optional, Protocol, Union, cast
+
+import pandas as pd
+
+from phoenix.trace.dsl import SpanQuery
+from phoenix.trace.semantic_conventions import (
+    DOCUMENT_CONTENT,
+    DOCUMENT_SCORE,
+    INPUT_VALUE,
+    OUTPUT_VALUE,
+    RETRIEVAL_DOCUMENTS,
+)
+
+INPUT = {"input": INPUT_VALUE}
+OUTPUT = {"output": OUTPUT_VALUE}
+IO = {**INPUT, **OUTPUT}
+
+IS_ROOT = "parent_id is None"
+IS_LLM = "span_kind == 'LLM'"
+IS_RETRIEVER = "span_kind == 'RETRIEVER'"
+
+
+class Session(Protocol):
+    def query_spans(self, *query: SpanQuery) -> Optional[Union[pd.DataFrame, List[pd.DataFrame]]]:
+        ...
+
+
+def get_retrieved_documents(session: Session) -> pd.DataFrame:
+    return cast(
+        pd.DataFrame,
+        session.query_spans(
+            SpanQuery()
+            .where(IS_RETRIEVER)
+            .select("trace_id", **INPUT)
+            .explode(
+                RETRIEVAL_DOCUMENTS,
+                reference=DOCUMENT_CONTENT,
+                document_score=DOCUMENT_SCORE,
+            )
+        ),
+    )
+
+
+def get_qa_with_reference(session: Session) -> pd.DataFrame:
+    return pd.concat(
+        cast(
+            List[pd.DataFrame],
+            session.query_spans(
+                SpanQuery().select(**IO).where(IS_ROOT),
+                SpanQuery()
+                .where(IS_RETRIEVER)
+                .select(span_id="parent_id")
+                .concat(
+                    RETRIEVAL_DOCUMENTS,
+                    reference=DOCUMENT_CONTENT,
+                ),
+            ),
+        ),
+        axis=1,
+        join="inner",
+    )
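These helpers rebuild get_retrieved_documents on top of the query DSL and add get_qa_with_reference, which inner-joins root-span input/output with the concatenated retrieval documents of child retriever spans. A minimal sketch of how they are called (assumes an active session holding traces):

    # Hedged sketch; px.active_session() is assumed to return a live session.
    import phoenix as px
    from phoenix.trace.dsl.helpers import get_qa_with_reference, get_retrieved_documents

    session = px.active_session()
    if session is not None:
        # One row per retrieved document, indexed by (context.span_id, document_position).
        documents_df = get_retrieved_documents(session)
        # One row per question/answer pair with a concatenated "reference" column.
        qa_df = get_qa_with_reference(session)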
@@ -0,0 +1,261 @@
+from collections import defaultdict
+from dataclasses import dataclass, field, fields, replace
+from functools import cached_property, partial
+from types import MappingProxyType
+from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Sequence, Tuple
+
+import pandas as pd
+
+from phoenix.trace.dsl import SpanFilter
+from phoenix.trace.schemas import ATTRIBUTE_PREFIX, CONTEXT_PREFIX, Span
+from phoenix.trace.semantic_conventions import RETRIEVAL_DOCUMENTS
+
+_SPAN_ID = "context.span_id"
+_PRESCRIBED_POSITION_PREFIXES = {
+    RETRIEVAL_DOCUMENTS: "document_",
+    ATTRIBUTE_PREFIX + RETRIEVAL_DOCUMENTS: "document_",
+}
+_ALIASES = {
+    "span_id": "context.span_id",
+    "trace_id": "context.trace_id",
+}
+
+# Because UUIDs are not convertible to Parquet,
+# they need to be converted to strings.
+_CONVERT_TO_STRING = (
+    "context.span_id",
+    "context.trace_id",
+    "parent_id",
+)
+
+
+def _unalias(key: str) -> str:
+    return _ALIASES.get(key, key)
+
+
+@dataclass(frozen=True)
+class Projection:
+    key: str = ""
+    value: Callable[[Span], Any] = field(init=False, repr=False)
+    span_fields: ClassVar[Tuple[str, ...]] = tuple(f.name for f in fields(Span))
+
+    def __bool__(self) -> bool:
+        return bool(self.key)
+
+    def __post_init__(self) -> None:
+        key = _unalias(self.key)
+        object.__setattr__(self, "key", key)
+        if key.startswith(CONTEXT_PREFIX):
+            key = key[len(CONTEXT_PREFIX) :]
+            value = partial(self._from_context, key=key)
+        elif key.startswith(ATTRIBUTE_PREFIX):
+            key = self.key[len(ATTRIBUTE_PREFIX) :]
+            value = partial(self._from_attributes, key=key)
+        elif key in self.span_fields:
+            value = partial(self._from_span, key=key)
+        else:
+            value = partial(self._from_attributes, key=key)
+        if self.key in _CONVERT_TO_STRING:
+            object.__setattr__(
+                self,
+                "value",
+                lambda span: None if (v := value(span)) is None else str(v),
+            )
+        else:
+            object.__setattr__(self, "value", value)
+
+    def __call__(self, span: Span) -> Any:
+        return self.value(span)
+
+    @staticmethod
+    def _from_attributes(span: Span, key: str) -> Any:
+        return span.attributes.get(key)
+
+    @staticmethod
+    def _from_context(span: Span, key: str) -> Any:
+        return getattr(span.context, key, None)
+
+    @staticmethod
+    def _from_span(span: Span, key: str) -> Any:
+        return getattr(span, key, None)
+
+
+@dataclass(frozen=True)
+class Explosion(Projection):
+    kwargs: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
+    primary_index_key: str = "context.span_id"
+
+    position_prefix: str = field(init=False, repr=False)
+    primary_index: Projection = field(init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        position_prefix = _PRESCRIBED_POSITION_PREFIXES.get(self.key, "")
+        object.__setattr__(self, "position_prefix", position_prefix)
+        object.__setattr__(self, "primary_index", Projection(self.primary_index_key))
+
+    @cached_property
+    def index_keys(self) -> Tuple[str, str]:
+        return (self.primary_index.key, f"{self.position_prefix}position")
+
+    def with_primary_index_key(self, primary_index_key: str) -> "Explosion":
+        return replace(self, primary_index_key=primary_index_key)
+
+    def __call__(self, span: Span) -> Iterator[Dict[str, Any]]:
+        if not isinstance(seq := self.value(span), Sequence):
+            return
+        has_mapping = False
+        for item in seq:
+            if isinstance(item, Mapping):
+                has_mapping = True
+                break
+        if not has_mapping:
+            for i, item in enumerate(seq):
+                if item is not None:
+                    yield {
+                        self.key: item,
+                        self.primary_index.key: self.primary_index(span),
+                        f"{self.position_prefix}position": i,
+                    }
+            return
+        for i, item in enumerate(seq):
+            if not isinstance(item, Mapping):
+                continue
+            record = (
+                {name: item.get(key) for name, key in self.kwargs.items()}
+                if self.kwargs
+                else dict(item)
+            )
+            for v in record.values():
+                if v is not None:
+                    break
+            else:
+                record = {}
+            if not record:
+                continue
+            record[self.primary_index.key] = self.primary_index(span)
+            record[f"{self.position_prefix}position"] = i
+            yield record
+
+
+@dataclass(frozen=True)
+class Concatenation(Projection):
+    kwargs: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
+    separator: str = "\n\n"
+
+    def with_separator(self, separator: str = "\n\n") -> "Concatenation":
+        return replace(self, separator=separator)
+
+    def __call__(self, span: Span) -> Iterator[Tuple[str, str]]:
+        if not isinstance(seq := self.value(span), Sequence):
+            return
+        if not self.kwargs:
+            yield self.key, self.separator.join(map(str, seq))
+        record = defaultdict(list)
+        for item in seq:
+            if not isinstance(item, Mapping):
+                continue
+            for k, v in self.kwargs.items():
+                if value := item.get(v):
+                    record[k].append(value)
+        for name, values in record.items():
+            yield name, self.separator.join(map(str, values))
+
+
+@dataclass(frozen=True)
+class SpanQuery:
+    _select: Mapping[str, Projection] = field(default_factory=lambda: MappingProxyType({}))
+    _concat: Concatenation = field(default_factory=Concatenation)
+    _explode: Explosion = field(default_factory=Explosion)
+    _filter: SpanFilter = field(default_factory=SpanFilter)
+    _rename: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
+    _index: Projection = field(default_factory=lambda: Projection("context.span_id"))
+
+    def __bool__(self) -> bool:
+        return bool(self._select) or bool(self._filter) or bool(self._explode) or bool(self._concat)
+
+    def select(self, *args: str, **kwargs: str) -> "SpanQuery":
+        _select = {
+            _unalias(name): Projection(key) for name, key in (*zip(args, args), *kwargs.items())
+        }
+        return replace(self, _select=MappingProxyType(_select))
+
+    def where(self, condition: str) -> "SpanQuery":
+        _filter = SpanFilter(condition)
+        return replace(self, _filter=_filter)
+
+    def explode(self, key: str, **kwargs: str) -> "SpanQuery":
+        _explode = Explosion(key=key, kwargs=kwargs, primary_index_key=self._index.key)
+        return replace(self, _explode=_explode)
+
+    def concat(self, key: str, **kwargs: str) -> "SpanQuery":
+        _concat = Concatenation(key=key, kwargs=kwargs)
+        return replace(self, _concat=_concat)
+
+    def rename(self, **kwargs: str) -> "SpanQuery":
+        _rename = MappingProxyType(kwargs)
+        return replace(self, _rename=_rename)
+
+    def with_index(self, key: str = "context.span_id") -> "SpanQuery":
+        _index = Projection(key=key)
+        return replace(self, _index=_index)
+
+    def with_concat_separator(self, separator: str = "\n\n") -> "SpanQuery":
+        _concat = self._concat.with_separator(separator)
+        return replace(self, _concat=_concat)
+
+    def with_explode_primary_index_key(self, primary_index_key: str) -> "SpanQuery":
+        _explode = self._explode.with_primary_index_key(primary_index_key)
+        return replace(self, _explode=_explode)
+
+    def __call__(self, spans: Iterable[Span]) -> pd.DataFrame:
+        if self._filter:
+            spans = filter(self._filter, spans)
+        if self._explode:
+            spans = filter(
+                lambda span: (isinstance(seq := self._explode.value(span), Sequence) and len(seq)),
+                spans,
+            )
+        if self._concat:
+            spans = filter(
+                lambda span: (isinstance(seq := self._concat.value(span), Sequence) and len(seq)),
+                spans,
+            )
+        _selected: List[Dict[str, Any]] = []
+        _exploded: List[Dict[str, Any]] = []
+        for span in spans:
+            if self._select:
+                record = {name: proj(span) for name, proj in self._select.items()}
+                for v in record.values():
+                    if v is not None:
+                        break
+                else:
+                    record = {}
+                if self._concat:
+                    record.update(self._concat(span))
+                if record:
+                    if self._index.key not in record:
+                        record[self._index.key] = self._index(span)
+                    _selected.append(record)
+            elif self._concat:
+                record = {self._index.key: self._index(span)}
+                record.update(self._concat(span))
+                if record:
+                    _selected.append(record)
+            if self._explode:
+                _exploded.extend(self._explode(span))
+        if _selected:
+            select_df = pd.DataFrame(_selected)
+        else:
+            select_df = pd.DataFrame(columns=[self._index.key])
+        select_df = select_df.set_index(self._index.key)
+        if self._explode:
+            if _exploded:
+                explode_df = pd.DataFrame(_exploded)
+            else:
+                explode_df = pd.DataFrame(columns=self._explode.index_keys)
+            explode_df = explode_df.set_index(list(self._explode.index_keys))
+            if not self._select:
+                return explode_df.rename(self._rename, axis=1, errors="ignore")
+            select_df = select_df.join(explode_df, how="outer")
+        return select_df.rename(self._rename, axis=1, errors="ignore")
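SpanQuery is an immutable builder: select projects span fields or attributes into columns, explode fans a list-valued attribute out into one row per item, and concat collapses it into a single separator-joined string. A hedged sketch contrasting the two (the retrieval.documents, document.content, and document.score keys are assumed to match the semantic conventions referenced above):

    # Hedged sketch of composing SpanQuery; the keys are illustrative assumptions.
    from phoenix.trace.dsl import SpanQuery

    # explode(): one row per document, indexed by (context.span_id, document_position).
    per_document = (
        SpanQuery()
        .where("span_kind == 'RETRIEVER'")
        .explode(
            "retrieval.documents",
            reference="document.content",
            document_score="document.score",
        )
    )

    # concat(): one row per span, with document contents joined by the separator
    # (default "\n\n").
    per_span = (
        SpanQuery()
        .where("span_kind == 'RETRIEVER'")
        .concat("retrieval.documents", reference="document.content")
        .with_concat_separator("\n---\n")
    )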
@@ -27,7 +27,7 @@ from typing import (
     Union,
     cast,
 )
-from uuid import uuid4
+from uuid import UUID, uuid4
 
 import llama_index
 from llama_index.callbacks.base_handler import BaseCallbackHandler
@@ -40,8 +40,12 @@ from llama_index.callbacks.schema import (
 from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.response.schema import Response, StreamingResponse
 from llama_index.tools import ToolMetadata
+from typing_extensions import TypeGuard
 
 from phoenix.trace.exporter import HttpExporter
+from phoenix.trace.llama_index.streaming import (
+    instrument_streaming_response as _instrument_streaming_response,
+)
 from phoenix.trace.schemas import (
     Span,
     SpanEvent,
@@ -49,6 +53,7 @@ from phoenix.trace.schemas import (
     SpanID,
     SpanKind,
     SpanStatusCode,
+    TraceID,
 )
 from phoenix.trace.semantic_conventions import (
     DOCUMENT_CONTENT,
@@ -109,6 +114,10 @@ class CBEventData:
     start_event: Optional[CBEvent] = field(default=None)
     end_event: Optional[CBEvent] = field(default=None)
     attributes: Dict[str, Any] = field(default_factory=dict)
+    span_id: Optional[CBEventID] = field(default=None)
+    parent_id: Optional[CBEventID] = field(default=None)
+    trace_id: Optional[TraceID] = field(default=None)
+    streaming_event: bool = field(default=False)
 
     def set_if_unset(self, key: str, value: Any) -> None:
         if not getattr(self, key):
@@ -270,9 +279,16 @@ class OpenInferenceTraceCallbackHandler(BaseCallbackHandler):
         **kwargs: Any,
     ) -> CBEventID:
         event_id = event_id or str(uuid4())
+        if parent_data := self._event_id_to_event_data.get(parent_id):
+            trace_id = parent_data.trace_id
+        else:
+            trace_id = uuid4()
         event_data = self._event_id_to_event_data[event_id]
         event_data.name = event_type.value
         event_data.event_type = event_type
+        event_data.parent_id = None if parent_id == "root" else parent_id
+        event_data.span_id = event_id
+        event_data.trace_id = trace_id
         event_data.start_event = CBEvent(
             event_type=event_type,
             payload=payload,
@@ -308,6 +324,10 @@ class OpenInferenceTraceCallbackHandler(BaseCallbackHandler):
         event_data.attributes.update(
             payload_to_semantic_attributes(event_type, payload, is_event_end=True),
         )
+        response = payload.get(EventPayload.RESPONSE)
+        if _is_streaming_response(response):
+            event_data.streaming_event = True
+            response = _instrument_streaming_response(response, self._tracer, event_data)
 
     @graceful_fallback(_null_fallback)
     def start_trace(self, trace_id: Optional[str] = None) -> None:
@@ -353,7 +373,6 @@ def _add_spans_to_tracer(
         tracer (Tracer): The tracer that stores spans.
     """
 
-    trace_id = uuid4()
     parent_child_id_stack: List[Tuple[Optional[SpanID], CBEventID]] = [
         (None, root_event_id) for root_event_id in trace_map["root"]
    ]
@@ -383,7 +402,11 @@
 
         start_time = _timestamp_to_tz_aware_datetime(start_event.time)
         span_exceptions = _get_span_exceptions(event_data, start_time)
-        end_time = _get_end_time(event_data, span_exceptions)
+        if event_data.streaming_event:
+            # Do not set the end time for streaming events so we can update the event later
+            end_time = None
+        else:
+            end_time = _get_end_time(event_data, span_exceptions)
         start_time = start_time or end_time or datetime.now(timezone.utc)
 
         name = event_name if (event_name := event_data.name) is not None else "unknown"
@@ -391,7 +414,7 @@
         span = tracer.create_span(
             name=name,
             span_kind=span_kind,
-            trace_id=trace_id,
+            trace_id=event_data.trace_id,
             start_time=start_time,
             end_time=end_time,
             status_code=SpanStatusCode.ERROR if span_exceptions else SpanStatusCode.OK,
@@ -400,6 +423,7 @@
             attributes=attributes,
             events=sorted(span_exceptions, key=lambda event: event.timestamp) or None,
             conversation=None,
+            span_id=UUID(event_data.span_id),
         )
         new_parent_span_id = span.context.span_id
         for new_child_event_id in trace_map.get(event_id, []):
@@ -640,3 +664,7 @@ def _get_tool_call(tool_call: object) -> Iterator[Tuple[str, Any]]:
         if arguments := getattr(function, "arguments", None):
             assert isinstance(arguments, str), f"arguments must be str, found {type(arguments)}"
             yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments
+
+
+def _is_streaming_response(response: Any) -> TypeGuard[StreamingResponse]:
+    return isinstance(response, StreamingResponse)
@@ -0,0 +1,93 @@
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Generator, List
+from uuid import UUID
+
+from llama_index.callbacks.schema import TIMESTAMP_FORMAT
+from llama_index.response.schema import StreamingResponse
+
+from phoenix.trace.schemas import SpanKind, SpanStatusCode
+from phoenix.trace.semantic_conventions import OUTPUT_VALUE
+from phoenix.trace.tracer import Tracer
+
+if TYPE_CHECKING:
+    from phoenix.trace.llama_index.callback import CBEventData
+
+_LOCAL_TZINFO = datetime.now().astimezone().tzinfo
+
+
+class TokenGenInstrumentor:
+    def __init__(
+        self, stream: Generator[str, None, None], tracer: Tracer, event_data: "CBEventData"
+    ):
+        self._stream = stream
+        self._token_stream: List[str] = []
+        self._finished = False
+        self._tracer = tracer
+        self._event_data = event_data
+
+    def __iter__(self) -> "TokenGenInstrumentor":
+        return self
+
+    def __next__(self) -> str:
+        if self._finished:
+            raise StopIteration
+
+        try:
+            value = next(self._stream)
+            self._token_stream.append(value)
+            return value
+        except StopIteration:
+            self._finished = True
+            self._handle_end_of_stream()
+            raise
+
+    def _handle_end_of_stream(self) -> None:
+        # Handle the end-of-stream logic here
+        parent_id = self._event_data.parent_id
+        output = "".join(self._token_stream)
+        start_event = self._event_data.start_event
+        if start_event:
+            start_time = _timestamp_to_tz_aware_datetime(start_event.time)
+        else:
+            start_time = datetime.now(timezone.utc)
+        attributes = self._event_data.attributes
+        attributes.update({OUTPUT_VALUE: output})
+        self._tracer.create_span(
+            name=self._event_data.name if self._event_data.name else "",
+            span_kind=SpanKind.LLM,
+            trace_id=self._event_data.trace_id,
+            start_time=start_time,
+            end_time=datetime.now(timezone.utc),
+            status_code=SpanStatusCode.OK,
+            status_message="",
+            parent_id=UUID(parent_id) if parent_id else None,
+            attributes=self._event_data.attributes,
+            events=[],
+            conversation=None,
+            span_id=UUID(self._event_data.span_id),
+        )
+
+
+def instrument_streaming_response(
+    response: StreamingResponse,
+    tracer: Tracer,
+    event_data: "CBEventData",
+) -> StreamingResponse:
+    if response.response_gen is not None:
+        response.response_gen = TokenGenInstrumentor(response.response_gen, tracer, event_data)  # type: ignore
+    return response
+
+
+def _timestamp_to_tz_aware_datetime(timestamp: str) -> datetime:
+    """Converts a timestamp string to a timezone-aware datetime."""
+    return _tz_naive_to_tz_aware_datetime(_timestamp_to_tz_naive_datetime(timestamp))
+
+
+def _timestamp_to_tz_naive_datetime(timestamp: str) -> datetime:
+    """Converts a timestamp string to a timezone-naive datetime."""
+    return datetime.strptime(timestamp, TIMESTAMP_FORMAT)
+
+
+def _tz_naive_to_tz_aware_datetime(timestamp: datetime) -> datetime:
+    """Converts a timezone-naive datetime to a timezone-aware datetime."""
+    return timestamp.replace(tzinfo=_LOCAL_TZINFO)
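With the callback handler installed, streaming responses are wrapped in TokenGenInstrumentor transparently: consuming the generator accumulates tokens, and the LLM span is finalized with the full output once the stream is exhausted. A hedged end-to-end sketch (the llama_index imports and data-loading step are assumptions for a 0.8-era llama_index API, not part of the diff):

    # Hedged sketch; assumes a ./data directory and a 0.8-era llama_index API.
    from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
    from llama_index.callbacks import CallbackManager

    from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler

    callback_manager = CallbackManager([OpenInferenceTraceCallbackHandler()])
    service_context = ServiceContext.from_defaults(callback_manager=callback_manager)
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    query_engine = index.as_query_engine(streaming=True)

    response = query_engine.query("What did the author do growing up?")
    # Iterating drives TokenGenInstrumentor.__next__; on StopIteration the
    # LLM span is created with the accumulated output as OUTPUT_VALUE.
    for token in response.response_gen:
        print(token, end="")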