arize-phoenix 1.9.1rc2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; the original registry page links to further details.

@@ -1,5 +1,6 @@
1
1
  import json
2
2
  from collections import defaultdict
3
+ from dataclasses import dataclass, field
3
4
  from datetime import datetime, timezone
4
5
  from enum import Enum
5
6
  from inspect import BoundArguments, signature
@@ -598,6 +599,16 @@ def _parameters(bound_arguments: BoundArguments) -> Parameters:
598
599
  return cast(Parameters, bound_arguments.arguments["options"].json_data)
599
600
 
600
601
 
602
+ @dataclass
603
+ class StreamingFunctionCallData:
604
+ """
605
+ Stores function call data from a streaming chat completion.
606
+ """
607
+
608
+ name: Optional[str] = None
609
+ argument_tokens: List[str] = field(default_factory=list)
610
+
611
+
601
612
  def _accumulate_messages(
602
613
  chunks: List[ChatCompletionChunk], num_choices: int
603
614
  ) -> List[OpenInferenceMessage]:
@@ -616,8 +627,12 @@ def _accumulate_messages(
616
627
  if not chunks:
617
628
  return []
618
629
  content_token_lists: DefaultDict[int, List[str]] = defaultdict(list)
619
- function_argument_token_lists: DefaultDict[int, List[str]] = defaultdict(list)
620
- function_names: Dict[int, str] = {}
630
+ function_calls: DefaultDict[int, StreamingFunctionCallData] = defaultdict(
631
+ StreamingFunctionCallData
632
+ )
633
+ tool_calls: DefaultDict[int, DefaultDict[int, StreamingFunctionCallData]] = defaultdict(
634
+ lambda: defaultdict(StreamingFunctionCallData)
635
+ )
621
636
  roles: Dict[int, str] = {}
622
637
  for chunk in chunks:
623
638
  for choice in chunk.choices:
@@ -626,21 +641,43 @@ def _accumulate_messages(
626
641
  content_token_lists[choice_index].append(content_token)
627
642
  if function_call := choice.delta.function_call:
628
643
  if function_name := function_call.name:
629
- function_names[choice_index] = function_name
644
+ function_calls[choice_index].name = function_name
630
645
  if (function_argument_token := function_call.arguments) is not None:
631
- function_argument_token_lists[choice_index].append(function_argument_token)
646
+ function_calls[choice_index].argument_tokens.append(function_argument_token)
632
647
  if role := choice.delta.role:
633
648
  roles[choice_index] = role
634
- messages: List[OpenInferenceMessage] = [{} for _ in range(num_choices)]
649
+ if choice.delta.tool_calls:
650
+ for tool_call in choice.delta.tool_calls:
651
+ tool_index = tool_call.index
652
+ if not tool_call.function:
653
+ continue
654
+ if (name := tool_call.function.name) is not None:
655
+ tool_calls[choice_index][tool_index].name = name
656
+ if (arguments := tool_call.function.arguments) is not None:
657
+ tool_calls[choice_index][tool_index].argument_tokens.append(arguments)
658
+
659
+ messages: List[OpenInferenceMessage] = []
635
660
  for choice_index in range(num_choices):
661
+ message: Dict[str, Any] = {}
636
662
  if (role_ := roles.get(choice_index)) is not None:
637
- messages[choice_index][MESSAGE_ROLE] = role_
638
- if content_token_list := content_token_lists[choice_index]:
639
- messages[choice_index][MESSAGE_CONTENT] = "".join(content_token_list)
640
- if (function_name := function_names.get(choice_index)) is not None:
641
- messages[choice_index][MESSAGE_FUNCTION_CALL_NAME] = function_name
642
- if function_argument_token_list := function_argument_token_lists[choice_index]:
643
- messages[choice_index][MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] = "".join(
644
- function_argument_token_list
645
- )
663
+ message[MESSAGE_ROLE] = role_
664
+ if content_tokens := content_token_lists[choice_index]:
665
+ message[MESSAGE_CONTENT] = "".join(content_tokens)
666
+ if function_call_ := function_calls.get(choice_index):
667
+ if (name := function_call_.name) is not None:
668
+ message[MESSAGE_FUNCTION_CALL_NAME] = name
669
+ if argument_tokens := function_call_.argument_tokens:
670
+ message[MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] = "".join(argument_tokens)
671
+ if tool_calls_ := tool_calls.get(choice_index):
672
+ num_tool_calls = max(tool_index for tool_index in tool_calls_.keys()) + 1
673
+ message[MESSAGE_TOOL_CALLS] = [{} for _ in range(num_tool_calls)]
674
+ for tool_index, tool_call_ in tool_calls_.items():
675
+ if (name := tool_call_.name) is not None:
676
+ message[MESSAGE_TOOL_CALLS][tool_index][TOOL_CALL_FUNCTION_NAME] = name
677
+ if argument_tokens := tool_call_.argument_tokens:
678
+ message[MESSAGE_TOOL_CALLS][tool_index][
679
+ TOOL_CALL_FUNCTION_ARGUMENTS_JSON
680
+ ] = "".join(argument_tokens)
681
+ messages.append(message)
682
+
646
683
  return messages