arize-phoenix 2.11.1__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of arize-phoenix has been flagged as a potentially problematic release.

@@ -1,683 +1,26 @@
- import json
- from collections import defaultdict
- from dataclasses import dataclass, field
- from datetime import datetime, timezone
- from enum import Enum
- from inspect import BoundArguments, signature
- from types import TracebackType
- from typing import (
-     Any,
-     AsyncIterator,
-     Callable,
-     ContextManager,
-     Coroutine,
-     DefaultDict,
-     Dict,
-     Iterator,
-     List,
-     Mapping,
-     Optional,
-     Tuple,
-     Type,
-     TypeVar,
-     cast,
- )
+ import logging
+ from importlib.metadata import PackageNotFoundError
+ from importlib.util import find_spec
+ from typing import Any

- import openai
- from openai import AsyncStream, Stream
- from openai.types.chat import ChatCompletion, ChatCompletionChunk
- from typing_extensions import ParamSpec
- from wrapt import ObjectProxy
+ from openinference.instrumentation.openai import OpenAIInstrumentor as Instrumentor
+ from opentelemetry.sdk import trace as trace_sdk
+ from opentelemetry.sdk.trace.export import SimpleSpanProcessor

- from phoenix.trace.schemas import (
-     MimeType,
-     SpanAttributes,
-     SpanEvent,
-     SpanException,
-     SpanKind,
-     SpanStatusCode,
- )
- from phoenix.trace.semantic_conventions import (
-     INPUT_MIME_TYPE,
-     INPUT_VALUE,
-     LLM_FUNCTION_CALL,
-     LLM_INPUT_MESSAGES,
-     LLM_INVOCATION_PARAMETERS,
-     LLM_OUTPUT_MESSAGES,
-     LLM_TOKEN_COUNT_COMPLETION,
-     LLM_TOKEN_COUNT_PROMPT,
-     LLM_TOKEN_COUNT_TOTAL,
-     MESSAGE_CONTENT,
-     MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
-     MESSAGE_FUNCTION_CALL_NAME,
-     MESSAGE_NAME,
-     MESSAGE_ROLE,
-     MESSAGE_TOOL_CALLS,
-     OUTPUT_MIME_TYPE,
-     OUTPUT_VALUE,
-     TOOL_CALL_FUNCTION_ARGUMENTS_JSON,
-     TOOL_CALL_FUNCTION_NAME,
- )
- from phoenix.trace.utils import get_stacktrace
+ from phoenix.trace.exporter import _OpenInferenceExporter
+ from phoenix.trace.tracer import _show_deprecation_warnings

- from ..tracer import Tracer
+ logger = logging.getLogger(__name__)

- ParameterSpec = ParamSpec("ParameterSpec")
- GenericType = TypeVar("GenericType")
- AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]
- Parameters = Mapping[str, Any]
- OpenInferenceMessage = Dict[str, str]

- INSTRUMENTED_ATTRIBUTE_NAME = "is_instrumented_with_openinference_tracer"
-
-
- class RequestType(Enum):
-     CHAT_COMPLETION = "chat_completion"
-     COMPLETION = "completion"
-     EMBEDDING = "embedding"
-
-
- class OpenAIInstrumentor:
-     def __init__(self, tracer: Optional[Tracer] = None) -> None:
-         """Instruments your OpenAI client to automatically create spans for each API call.
-
-         Args:
-             tracer (Optional[Tracer], optional): A tracer to record and handle spans. If not
-                 provided, the default tracer will be used.
-         """
-         self._tracer = tracer or Tracer()
+ class OpenAIInstrumentor(Instrumentor):
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
+         _show_deprecation_warnings(self, *args, **kwargs)
+         if find_spec("openai") is None:
+             raise PackageNotFoundError("Missing `openai`. Install with `pip install openai`.")
+         super().__init__()

      def instrument(self) -> None:
-         """
-         Instruments your OpenAI client.
-         """
-         if not hasattr(openai.OpenAI, INSTRUMENTED_ATTRIBUTE_NAME):
-             openai.OpenAI.request = _wrapped_openai_sync_client_request_function(  # type: ignore
-                 openai.OpenAI.request, self._tracer
-             )
-             setattr(
-                 openai.OpenAI,
-                 INSTRUMENTED_ATTRIBUTE_NAME,
-                 True,
-             )
-         if not hasattr(openai.AsyncOpenAI, INSTRUMENTED_ATTRIBUTE_NAME):
-             openai.AsyncOpenAI.request = _wrapped_openai_async_client_request_function(  # type: ignore
-                 openai.AsyncOpenAI.request, self._tracer
-             )
-             setattr(
-                 openai.AsyncOpenAI,
-                 INSTRUMENTED_ATTRIBUTE_NAME,
-                 True,
-             )
-
-
- class ChatCompletionContext(ContextManager["ChatCompletionContext"]):
-     """
-     A context manager for creating spans for chat completion requests. The
-     context manager extracts attributes from the input parameters and response
-     from the API and records any exceptions that are raised.
-     """
-
-     def __init__(self, bound_arguments: BoundArguments, tracer: Tracer) -> None:
-         """
-         Initializes the context manager.
-
-         Args:
-             bound_arguments (BoundArguments): The arguments to the request
-                 function from which parameter attributes are extracted.
-
-             tracer (Tracer): The tracer to use to create spans.
-         """
-         self.tracer = tracer
-         self.start_time: Optional[datetime] = None
-         self.status_code = SpanStatusCode.UNSET
-         self.status_message = ""
-         self.events: List[SpanEvent] = []
-         self.attributes: SpanAttributes = dict()
-         parameters = _parameters(bound_arguments)
-         self.num_choices = parameters.get("n", 1)
-         self.chat_completion_chunks: List[ChatCompletionChunk] = []
-         self.stream_complete = False
-         self._process_parameters(parameters)
-         self._span_created = False
-
-     def __enter__(self) -> "ChatCompletionContext":
-         self.start_time = datetime.now(tz=timezone.utc)
-         return self
-
-     def __exit__(
-         self,
-         exc_type: Optional[Type[BaseException]],
-         exc_value: Optional[BaseException],
-         traceback: Optional[TracebackType],
-     ) -> None:
-         if exc_value is None:
-             return
-         self.status_code = SpanStatusCode.ERROR
-         status_message = str(exc_value)
-         self.status_message = status_message
-         self.events.append(
-             SpanException(
-                 message=status_message,
-                 timestamp=datetime.now(tz=timezone.utc),
-                 exception_type=type(exc_value).__name__,
-                 exception_stacktrace=get_stacktrace(exc_value),
-             )
-         )
-         self.create_span()
-
-     def process_response(self, response: Any) -> Any:
-         """
-         Processes the response from the OpenAI chat completions API call to extract attributes.
-
-         Args:
-             response (ChatCompletion): The chat completion object.
-         """
-         if isinstance(response, ChatCompletion):
-             self._process_chat_completion(response)
-         elif isinstance(response, Stream):
-             return StreamWrapper(stream=response, context=self)
-         elif isinstance(response, AsyncStream):
-             return AsyncStreamWrapper(stream=response, context=self)
-         elif hasattr(response, "parse") and callable(
-             response.parse
-         ):  # handle raw response by converting them to chat completions
-             self._process_chat_completion(response.parse())
-         self.status_code = SpanStatusCode.OK
-         self.create_span()
-         return response
-
-     def create_span(self) -> None:
-         """
-         Creates a span from the context if one has not already been created.
-         """
-         if self._span_created:
-             return
-         self.tracer.create_span(
-             name="OpenAI Chat Completion",
-             span_kind=SpanKind.LLM,
-             start_time=cast(datetime, self.start_time),
-             end_time=datetime.now(tz=timezone.utc),
-             status_code=self.status_code,
-             status_message=self.status_message,
-             attributes=self.attributes,
-             events=self.events,
-         )
-         self._span_created = True
-
-     def _process_chat_completion(self, chat_completion: ChatCompletion) -> None:
-         """
-         Processes a chat completion response to extract and add fields and
-         attributes to the context.
-
-         Args:
-             chat_completion (ChatCompletion): Response object from the chat
-                 completions API.
-         """
-         for (
-             attribute_name,
-             get_chat_completion_attribute_fn,
-         ) in _CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS.items():
-             if (attribute_value := get_chat_completion_attribute_fn(chat_completion)) is not None:
-                 self.attributes[attribute_name] = attribute_value
-
-     def _process_parameters(self, parameters: Parameters) -> None:
-         """
-         Processes the input parameters to the chat completions API to extract
-         and add fields and attributes to the context.
-
-         Args:
-             parameters (Parameters): Input parameters.
-         """
-         for (
-             attribute_name,
-             get_parameter_attribute_fn,
-         ) in _PARAMETER_ATTRIBUTE_FUNCTIONS.items():
-             if (attribute_value := get_parameter_attribute_fn(parameters)) is not None:
-                 self.attributes[attribute_name] = attribute_value
-
-
- class StreamWrapper(ObjectProxy):  # type: ignore
-     """
-     A wrapper for streams of chat completion chunks that records events and
-     creates the span upon completion of the stream or upon an exception.
-     """
-
-     def __init__(self, stream: Stream[ChatCompletionChunk], context: ChatCompletionContext) -> None:
-         """Initializes the stream wrapper.
-
-         Args:
-             stream (Stream[ChatCompletionChunk]): The stream to wrap.
-
-             context (ChatCompletionContext): The context used to store span
-                 fields and attributes.
-         """
-         super().__init__(stream)
-         self._self_context = ChatCompletionStreamEventContext(context)
-
-     def __next__(self) -> ChatCompletionChunk:
-         """
-         A wrapped __next__ method that records span stream events and updates
-         the span upon completion of the stream or upon exception.
-
-         Returns:
-             ChatCompletionChunk: The forwarded chat completion chunk.
-         """
-         with self._self_context as context:
-             chat_completion_chunk = next(self.__wrapped__)
-             context.process_chat_completion_chunk(chat_completion_chunk)
-         return cast(ChatCompletionChunk, chat_completion_chunk)
-
-     def __iter__(self) -> Iterator[ChatCompletionChunk]:
-         """
-         A __iter__ method that bypasses the wrapped class' __iter__ method so
-         that __iter__ is automatically instrumented using __next__.
-         """
-         return self
-
-
- class AsyncStreamWrapper(ObjectProxy):  # type: ignore
-     """
-     A wrapper for asynchronous streams of chat completion chunks that records
-     events and creates the span upon completion of the stream or upon an
-     exception.
-     """
-
-     def __init__(
-         self, stream: AsyncStream[ChatCompletionChunk], context: ChatCompletionContext
-     ) -> None:
-         """Initializes the stream wrapper.
-
-         Args:
-             stream (AsyncStream[ChatCompletionChunk]): The stream to wrap.
-
-             context (ChatCompletionContext): The context used to store span
-                 fields and attributes.
-         """
-         super().__init__(stream)
-         self._self_context = ChatCompletionStreamEventContext(context)
-
-     async def __anext__(self) -> ChatCompletionChunk:
-         """
-         A wrapped __anext__ method that records span stream events and updates
-         the span upon completion of the stream or upon exception.
-
-         Returns:
-             ChatCompletionChunk: The forwarded chat completion chunk.
-         """
-         with self._self_context as context:
-             chat_completion_chunk = await self.__wrapped__.__anext__()
-             context.process_chat_completion_chunk(chat_completion_chunk)
-         return cast(ChatCompletionChunk, chat_completion_chunk)
-
-     def __aiter__(self) -> AsyncIterator[ChatCompletionChunk]:
-         """
-         An __aiter__ method that bypasses the wrapped class' __aiter__ method so
-         that __aiter__ is automatically instrumented using __anext__.
-         """
-         return self
-
-
- class ChatCompletionStreamEventContext(ContextManager["ChatCompletionStreamEventContext"]):
-     """
-     A context manager that surrounds stream events in a stream of chat
-     completions and processes each chat completion chunk.
-     """
-
-     def __init__(self, chat_completion_context: ChatCompletionContext) -> None:
-         """Initializes the context manager.
-
-         Args:
-             chat_completion_context (ChatCompletionContext): The chat completion
-                 context storing span fields and attributes.
-         """
-         self._context = chat_completion_context
-
-     def __exit__(
-         self,
-         exc_type: Optional[Type[BaseException]],
-         exc_value: Optional[BaseException],
-         traceback: Optional[TracebackType],
-     ) -> None:
-         if isinstance(exc_value, StopIteration) or isinstance(exc_value, StopAsyncIteration):
-             self._context.stream_complete = True
-             self._context.status_code = SpanStatusCode.OK
-         elif exc_value is not None:
-             self._context.stream_complete = True
-             self._context.status_code = SpanStatusCode.ERROR
-             status_message = str(exc_value)
-             self._context.status_message = status_message
-             self._context.events.append(
-                 SpanException(
-                     message=status_message,
-                     timestamp=datetime.now(tz=timezone.utc),
-                     exception_type=type(exc_value).__name__,
-                     exception_stacktrace=get_stacktrace(exc_value),
-                 )
-             )
-         if not self._context.stream_complete:
-             return
-         self._context.attributes = {
-             **self._context.attributes,
-             LLM_OUTPUT_MESSAGES: _accumulate_messages(
-                 chunks=self._context.chat_completion_chunks,
-                 num_choices=self._context.num_choices,
-             ),  # type: ignore
-             OUTPUT_VALUE: json.dumps(
-                 [chunk.dict() for chunk in self._context.chat_completion_chunks]
-             ),
-             OUTPUT_MIME_TYPE: MimeType.JSON,  # type: ignore
-         }
-         self._context.create_span()
-
-     def process_chat_completion_chunk(self, chat_completion_chunk: ChatCompletionChunk) -> None:
-         """
-         Processes a chat completion chunk and adds relevant information to the
-         context.
-
-         Args:
-             chat_completion_chunk (ChatCompletionChunk): The chat completion
-                 chunk to be processed.
-         """
-         if not self._context.chat_completion_chunks:
-             self._context.events.append(
-                 SpanEvent(
-                     name="First Token Stream Event",
-                     timestamp=datetime.now(tz=timezone.utc),
-                     attributes={},
-                 )
-             )
-         self._context.chat_completion_chunks.append(chat_completion_chunk)
-
-
- def _wrapped_openai_sync_client_request_function(
-     request_fn: Callable[ParameterSpec, GenericType], tracer: Tracer
- ) -> Callable[ParameterSpec, GenericType]:
-     """
-     Wraps the synchronous OpenAI client's request method to create spans for
-     each API call.
-
-     Args:
-         request_fn (Callable[ParameterSpec, GenericType]): The request method on
-             the OpenAI client.
-
-         tracer (Tracer): The tracer to use to create spans.
-
-     Returns:
-         Callable[ParameterSpec, GenericType]: The wrapped request method.
-     """
-     call_signature = signature(request_fn)
-
-     def wrapped(*args: Any, **kwargs: Any) -> Any:
-         bound_arguments = call_signature.bind(*args, **kwargs)
-         if _request_type(bound_arguments) is not RequestType.CHAT_COMPLETION:
-             return request_fn(*args, **kwargs)
-         with ChatCompletionContext(bound_arguments, tracer) as context:
-             response = request_fn(*args, **kwargs)
-             return context.process_response(response)
-
-     return wrapped
-
-
- def _wrapped_openai_async_client_request_function(
-     request_fn: AsyncCallable[ParameterSpec, GenericType], tracer: Tracer
- ) -> AsyncCallable[ParameterSpec, GenericType]:
-     """
-     Wraps the asynchronous AsyncOpenAI client's request method to create spans
-     for each API call.
-
-     Args:
-         request_fn (AsyncCallable[ParameterSpec, GenericType]): The request
-             method on the AsyncOpenAI client.
-
-         tracer (Tracer): The tracer to use to create spans.
-
-     Returns:
-         AsyncCallable[ParameterSpec, GenericType]: The wrapped request method.
-     """
-     call_signature = signature(request_fn)
-
-     async def wrapped(*args: Any, **kwargs: Any) -> Any:
-         bound_arguments = call_signature.bind(*args, **kwargs)
-         if _request_type(bound_arguments) is not RequestType.CHAT_COMPLETION:
-             return await request_fn(*args, **kwargs)
-         with ChatCompletionContext(bound_arguments, tracer) as context:
-             response = await request_fn(*args, **kwargs)
-             return context.process_response(response)
-
-     return wrapped
-
-
- def _input_value(parameters: Parameters) -> str:
-     return json.dumps(parameters)
-
-
- def _input_mime_type(_: Any) -> MimeType:
-     return MimeType.JSON
-
-
- def _llm_input_messages(parameters: Parameters) -> Optional[List[OpenInferenceMessage]]:
-     if not (messages := parameters.get("messages")):
-         return None
-     return [_to_openinference_message(message, expects_name=True) for message in messages]
-
-
- def _llm_invocation_parameters(
-     parameters: Parameters,
- ) -> str:
-     return json.dumps(parameters)
-
-
- def _output_value(chat_completion: ChatCompletion) -> str:
-     return chat_completion.json()
-
-
- def _output_mime_type(_: Any) -> MimeType:
-     return MimeType.JSON
-
-
- def _llm_output_messages(chat_completion: ChatCompletion) -> List[OpenInferenceMessage]:
-     return [
-         _to_openinference_message(choice.message.dict(), expects_name=False)
-         for choice in chat_completion.choices
-     ]
-
-
- def _llm_token_count_prompt(chat_completion: ChatCompletion) -> Optional[int]:
-     if completion_usage := chat_completion.usage:
-         return completion_usage.prompt_tokens
-     return None
-
-
- def _llm_token_count_completion(chat_completion: ChatCompletion) -> Optional[int]:
-     if completion_usage := chat_completion.usage:
-         return completion_usage.completion_tokens
-     return None
-
-
- def _llm_token_count_total(chat_completion: ChatCompletion) -> Optional[int]:
-     if completion_usage := chat_completion.usage:
-         return completion_usage.total_tokens
-     return None
-
-
- def _llm_function_call(
-     chat_completion: ChatCompletion,
- ) -> Optional[str]:
-     choices = chat_completion.choices
-     choice = choices[0]
-     if choice.finish_reason == "function_call" and (function_call := choice.message.function_call):
-         return function_call.json()
-     return None
-
-
- def _request_type(bound_arguments: BoundArguments) -> Optional[RequestType]:
-     options = bound_arguments.arguments["options"]
-     url = options.url
-     """Get OpenAI request type from URL, or returns None if the request type cannot be recognized"""
-     if "chat/completions" in url:
-         return RequestType.CHAT_COMPLETION
-     if "completions" in url:
-         return RequestType.COMPLETION
-     if "embeddings" in url:
-         return RequestType.EMBEDDING
-     return None
-
-
- def _to_openinference_message(
-     message: Mapping[str, Any], *, expects_name: bool
- ) -> OpenInferenceMessage:
-     """Converts an OpenAI input or output message to an OpenInference message.
-
-     Args:
-         message (Dict[str, Any]): The OpenAI message to be parsed.
-
-         expects_name (bool): Whether to parse the "name" key in the OpenAI message. This key is
-             sometimes included in "function"-role input messages to specify the function name, but is
-             not included in output messages.
-
-     Returns:
-         OpenInferenceMessage: A message in OpenInference format.
-     """
-     openinference_message = {}
-     if role := message.get("role"):
-         openinference_message[MESSAGE_ROLE] = role
-     if content := message.get("content"):
-         openinference_message[MESSAGE_CONTENT] = content
-     if function_call_data := message.get("function_call"):
-         if function_name := function_call_data.get("name"):
-             openinference_message[MESSAGE_FUNCTION_CALL_NAME] = function_name
-         if function_arguments := function_call_data.get("arguments"):
-             openinference_message[MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] = function_arguments
-     if tool_calls_data := message.get("tool_calls"):
-         message_tool_calls = []
-         for tool_call_data in tool_calls_data:
-             if message_tool_call := dict(_get_tool_call(tool_call_data)):
-                 message_tool_calls.append(message_tool_call)
-         if message_tool_calls:
-             openinference_message[MESSAGE_TOOL_CALLS] = message_tool_calls
-     if expects_name and (name := message.get("name")):
-         openinference_message[MESSAGE_NAME] = name
-     return openinference_message
-
-
- def _get_tool_call(tool_call: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
-     if function := tool_call.get("function"):
-         if name := function.get("name"):
-             yield TOOL_CALL_FUNCTION_NAME, name
-         if arguments := function.get("arguments"):
-             yield TOOL_CALL_FUNCTION_ARGUMENTS_JSON, arguments
-
-
- _PARAMETER_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[[Parameters], Any]] = {
-     INPUT_VALUE: _input_value,
-     INPUT_MIME_TYPE: _input_mime_type,
-     LLM_INPUT_MESSAGES: _llm_input_messages,
-     LLM_INVOCATION_PARAMETERS: _llm_invocation_parameters,
- }
- _CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[[ChatCompletion], Any]] = {
-     OUTPUT_VALUE: _output_value,
-     OUTPUT_MIME_TYPE: _output_mime_type,
-     LLM_OUTPUT_MESSAGES: _llm_output_messages,
-     LLM_TOKEN_COUNT_PROMPT: _llm_token_count_prompt,
-     LLM_TOKEN_COUNT_COMPLETION: _llm_token_count_completion,
-     LLM_TOKEN_COUNT_TOTAL: _llm_token_count_total,
-     LLM_FUNCTION_CALL: _llm_function_call,
- }
-
-
- def _parameters(bound_arguments: BoundArguments) -> Parameters:
-     """
-     The parameters for the LLM call, e.g., temperature.
-
-     Args:
-         bound_arguments (BoundArguments): The bound arguments to the request function.
-
-     Returns:
-         Parameters: The parameters to the request function.
-     """
-     return cast(Parameters, bound_arguments.arguments["options"].json_data)
-
-
- @dataclass
- class StreamingFunctionCallData:
-     """
-     Stores function call data from a streaming chat completion.
-     """
-
-     name: Optional[str] = None
-     argument_tokens: List[str] = field(default_factory=list)
-
-
- def _accumulate_messages(
-     chunks: List[ChatCompletionChunk], num_choices: int
- ) -> List[OpenInferenceMessage]:
-     """
-     Converts a list of chat completion chunks to a list of OpenInference messages.
-
-     Args:
-         chunks (List[ChatCompletionChunk]): The input chunks to be converted.
-
-         num_choices (int): The number of choices in the chat completion (i.e.,
-             the parameter `n` in the input parameters).
-
-     Returns:
-         List[OpenInferenceMessage]: The list of OpenInference messages.
-     """
-     if not chunks:
-         return []
-     content_token_lists: DefaultDict[int, List[str]] = defaultdict(list)
-     function_calls: DefaultDict[int, StreamingFunctionCallData] = defaultdict(
-         StreamingFunctionCallData
-     )
-     tool_calls: DefaultDict[int, DefaultDict[int, StreamingFunctionCallData]] = defaultdict(
-         lambda: defaultdict(StreamingFunctionCallData)
-     )
-     roles: Dict[int, str] = {}
-     for chunk in chunks:
-         for choice in chunk.choices:
-             choice_index = choice.index
-             if content_token := choice.delta.content:
-                 content_token_lists[choice_index].append(content_token)
-             if function_call := choice.delta.function_call:
-                 if function_name := function_call.name:
-                     function_calls[choice_index].name = function_name
-                 if (function_argument_token := function_call.arguments) is not None:
-                     function_calls[choice_index].argument_tokens.append(function_argument_token)
-             if role := choice.delta.role:
-                 roles[choice_index] = role
-             if choice.delta.tool_calls:
-                 for tool_call in choice.delta.tool_calls:
-                     tool_index = tool_call.index
-                     if not tool_call.function:
-                         continue
-                     if (name := tool_call.function.name) is not None:
-                         tool_calls[choice_index][tool_index].name = name
-                     if (arguments := tool_call.function.arguments) is not None:
-                         tool_calls[choice_index][tool_index].argument_tokens.append(arguments)
-
-     messages: List[OpenInferenceMessage] = []
-     for choice_index in range(num_choices):
-         message: Dict[str, Any] = {}
-         if (role_ := roles.get(choice_index)) is not None:
-             message[MESSAGE_ROLE] = role_
-         if content_tokens := content_token_lists[choice_index]:
-             message[MESSAGE_CONTENT] = "".join(content_tokens)
-         if function_call_ := function_calls.get(choice_index):
-             if (name := function_call_.name) is not None:
-                 message[MESSAGE_FUNCTION_CALL_NAME] = name
-             if argument_tokens := function_call_.argument_tokens:
-                 message[MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] = "".join(argument_tokens)
-         if tool_calls_ := tool_calls.get(choice_index):
-             num_tool_calls = max(tool_index for tool_index in tool_calls_.keys()) + 1
-             message[MESSAGE_TOOL_CALLS] = [{} for _ in range(num_tool_calls)]
-             for tool_index, tool_call_ in tool_calls_.items():
-                 if (name := tool_call_.name) is not None:
-                     message[MESSAGE_TOOL_CALLS][tool_index][TOOL_CALL_FUNCTION_NAME] = name
-                 if argument_tokens := tool_call_.argument_tokens:
-                     message[MESSAGE_TOOL_CALLS][tool_index][
-                         TOOL_CALL_FUNCTION_ARGUMENTS_JSON
-                     ] = "".join(argument_tokens)
-         messages.append(message)
-
-     return messages
+         tracer_provider = trace_sdk.TracerProvider()
+         tracer_provider.add_span_processor(SimpleSpanProcessor(_OpenInferenceExporter()))
+         super().instrument(skip_dep_check=True, tracer_provider=tracer_provider)
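
Net effect of the hunk: the 2.11.1 module implemented instrumentation in-house by monkey-patching openai.OpenAI.request and openai.AsyncOpenAI.request and tracing only chat completion requests (see the _request_type gate on RequestType.CHAT_COMPLETION), while 3.0.1 replaces all of it with a thin subclass of the openinference-instrumentation-openai instrumentor that wires an OpenTelemetry TracerProvider and a SimpleSpanProcessor around Phoenix's _OpenInferenceExporter. The calling pattern appears unchanged in shape. Below is a minimal usage sketch, not taken from this diff: it assumes the module lives at phoenix.trace.openai, that a Phoenix collector is running locally, and that OPENAI_API_KEY is set; verify the import path against your installed version.

# Usage sketch for the 3.x instrumentor shown above. Assumes arize-phoenix,
# openai, and openinference-instrumentation-openai are installed.
import openai

from phoenix.trace.openai import OpenAIInstrumentor  # assumed module path

# In 2.11.1 this call monkey-patched the OpenAI clients' request methods;
# in 3.0.1 it delegates to OpenInference and registers a TracerProvider
# whose SimpleSpanProcessor exports spans to Phoenix.
OpenAIInstrumentor().instrument()

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)

One practical consequence of the delegation: span names, attributes, and coverage beyond chat completions are now whatever the OpenInference instrumentor emits, rather than the hand-rolled "OpenAI Chat Completion" spans defined in the deleted code.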