unique_toolkit 0.8.4__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/app/schemas.py +3 -0
- unique_toolkit/language_model/schemas.py +18 -0
- unique_toolkit/reference_manager/reference_manager.py +72 -0
- unique_toolkit/tools/agent_chunks_handler.py +62 -0
- unique_toolkit/tools/config.py +108 -0
- unique_toolkit/tools/{tool_factory.py → factory.py} +15 -5
- unique_toolkit/tools/schemas.py +138 -0
- unique_toolkit/tools/test/test_tool_progress_reporter.py +204 -0
- unique_toolkit/tools/tool.py +168 -0
- unique_toolkit/tools/tool_manager.py +242 -0
- unique_toolkit/tools/tool_progress_reporter.py +4 -11
- unique_toolkit/tools/utils/execution/execution.py +282 -0
- unique_toolkit/tools/utils/source_handling/schema.py +22 -0
- unique_toolkit/tools/utils/source_handling/source_formatting.py +207 -0
- unique_toolkit/tools/utils/source_handling/tests/test_source_formatting.py +215 -0
- {unique_toolkit-0.8.4.dist-info → unique_toolkit-0.8.5.dist-info}/METADATA +4 -1
- {unique_toolkit-0.8.4.dist-info → unique_toolkit-0.8.5.dist-info}/RECORD +19 -10
- unique_toolkit/tools/tool_definitions.py +0 -145
- unique_toolkit/tools/tool_definitionsV2.py +0 -137
- {unique_toolkit-0.8.4.dist-info → unique_toolkit-0.8.5.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.8.4.dist-info → unique_toolkit-0.8.5.dist-info}/WHEEL +0 -0
    
        unique_toolkit/app/schemas.py
    CHANGED
    
    | @@ -51,6 +51,7 @@ class BaseEvent(BaseModel): | |
| 51 51 | 
             
            # MCP schemas
         | 
| 52 52 | 
             
            ###
         | 
| 53 53 |  | 
| 54 | 
            +
             | 
| 54 55 | 
             
            class McpTool(BaseModel):
         | 
| 55 56 | 
             
                model_config = model_config
         | 
| 56 57 |  | 
| @@ -79,6 +80,7 @@ class McpTool(BaseModel): | |
| 79 80 | 
             
                    description="Whether the tool is connected to the MCP server. This is a Unique specific field.",
         | 
| 80 81 | 
             
                )
         | 
| 81 82 |  | 
| 83 | 
            +
             | 
| 82 84 | 
             
            class McpServer(BaseModel):
         | 
| 83 85 | 
             
                model_config = model_config
         | 
| 84 86 |  | 
| @@ -94,6 +96,7 @@ class McpServer(BaseModel): | |
| 94 96 | 
             
                )
         | 
| 95 97 | 
             
                tools: list[McpTool] = []
         | 
| 96 98 |  | 
| 99 | 
            +
             | 
| 97 100 | 
             
            ###
         | 
| 98 101 | 
             
            # ChatEvent schemas
         | 
| 99 102 | 
             
            ###
         | 
| @@ -86,6 +86,24 @@ class LanguageModelFunction(BaseModel): | |
| 86 86 | 
             
                    return seralization
         | 
| 87 87 |  | 
| 88 88 |  | 
| 89 | 
            +
                def __eq__(self, other: object) -> bool:
                    """
                    Compare two function calls by ``id``, ``name`` and ``arguments``.

                    All three must be equal for the calls to be considered equal.
                    Comparing against anything that is not a ``LanguageModelFunction``
                    returns ``False``.
                    """
                    if not isinstance(other, LanguageModelFunction):
                        return False
                    return (
                        self.id == other.id
                        and self.name == other.name
                        and self.arguments == other.arguments
                    )
         | 
| 106 | 
            +
             | 
| 89 107 | 
             
            # This is tailored to the unique backend
         | 
| 90 108 | 
             
            class LanguageModelStreamResponse(BaseModel):
         | 
| 91 109 | 
             
                model_config = model_config
         | 
| @@ -0,0 +1,72 @@ | |
| 1 | 
            +
            from unique_toolkit.content.schemas import ContentChunk, ContentReference
         | 
| 2 | 
            +
            from unique_toolkit.tools.schemas import ToolCallResponse
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class tool_chunks:
                """Pair a tool's name with the content chunks that tool returned.

                The lower-case class name is kept for backward compatibility.
                """

                def __init__(self, name: str, chunks: list) -> None:
                    # Chunks are stored by reference (not copied).
                    self.chunks = chunks
                    self.name = name
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class ReferenceManager:
                """Collect tool-produced content chunks and model references, and
                resolve references back to the chunks they point at."""

                def __init__(self):
                    # Tool-call id -> (tool name, chunks returned by that call).
                    self._tool_chunks: dict[str, tool_chunks] = {}
                    # Flat list of every referenceable chunk, in arrival order.
                    self._chunks: list[ContentChunk] = []
                    # One entry per add_references() call; the last one is the "latest".
                    self._references: list[list[ContentReference]] = []

                def extract_referenceable_chunks(
                    self, tool_responses: list[ToolCallResponse]
                ) -> None:
                    """Record the content chunks of every tool response that has any.

                    Responses whose ``content_chunks`` is ``None`` or empty are skipped.
                    """
                    for tool_response in tool_responses:
                        chunks = tool_response.content_chunks
                        if not chunks:
                            continue
                        self._chunks.extend(chunks)
                        self._tool_chunks[tool_response.id] = tool_chunks(
                            tool_response.name, chunks
                        )

                def get_chunks(self) -> list[ContentChunk]:
                    """Return the (live, not copied) list of collected chunks."""
                    return self._chunks

                def get_tool_chunks(self) -> dict[str, tool_chunks]:
                    """Return the mapping of tool-call id to its ``tool_chunks`` record."""
                    return self._tool_chunks

                def replace(self, chunks: list[ContentChunk]):
                    """Replace the flat chunk list; the per-tool mapping is left untouched."""
                    self._chunks = chunks

                def add_references(
                    self,
                    references: list[ContentReference],
                ):
                    """Append one batch of references (becomes the latest batch)."""
                    self._references.append(references)

                def get_references(
                    self,
                ) -> list[list[ContentReference]]:
                    """Return every reference batch, oldest first."""
                    return self._references

                def get_latest_references(
                    self,
                ) -> list[ContentReference]:
                    """Return the most recently added batch, or ``[]`` if there is none."""
                    if not self._references:
                        return []
                    return self._references[-1]

                def get_latest_referenced_chunks(self) -> list[ContentChunk]:
                    """Resolve the most recent reference batch to its chunks."""
                    if not self._references:
                        return []
                    return self._get_referenced_chunks_from_references(self._references[-1])

                def _get_referenced_chunks_from_references(
                    self,
                    references: list[ContentReference],
                ) -> list[ContentChunk]:
                    """
                    Get _referenced_chunks by matching sourceId from _references with the
                    merged key "<chunk.id>_<chunk.chunk_id>" of the chunks in _chunks.

                    For each reference (in order), every chunk with a matching key is
                    returned (in chunk order); references without a match are skipped.
                    """
                    if not references:
                        return []
                    # NOTE(review): "_" separator aligned with AgentChunksHandler; the
                    # previous "-" separator disagreed with it. Confirm against the
                    # backend's source_id format.
                    # Index chunks once instead of scanning every chunk per reference.
                    chunks_by_source_id: dict[str, list[ContentChunk]] = {}
                    for chunk in self._chunks:
                        chunks_by_source_id.setdefault(
                            f"{chunk.id}_{chunk.chunk_id}", []
                        ).append(chunk)
                    return [
                        chunk
                        for ref in references
                        for chunk in chunks_by_source_id.get(ref.source_id, [])
                    ]
         | 
| @@ -0,0 +1,62 @@ | |
| 1 | 
            +
            from unique_toolkit.content.schemas import ContentChunk, ContentReference
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class AgentChunksHandler:
                """Hold the agent's content chunks and reference batches, and resolve
                references back to the chunks they point at."""

                def __init__(self):
                    # Per-tool chunk bookkeeping; populated by callers via `tool_chunks`.
                    self._tool_chunks: dict = {}
                    # Flat list of every referenceable chunk, in arrival order.
                    self._chunks: list[ContentChunk] = []
                    # One entry per add_references() call; the last one is the "latest".
                    self._references: list[list[ContentReference]] = []

                @property
                def chunks(self) -> list[ContentChunk]:
                    """The (live, not copied) list of collected chunks."""
                    return self._chunks

                @property
                def tool_chunks(self) -> dict:
                    """The (live) per-tool chunk mapping."""
                    return self._tool_chunks

                def extend(self, chunks: list[ContentChunk]):
                    """Append ``chunks`` to the collected chunks."""
                    self._chunks.extend(chunks)

                def replace(self, chunks: list[ContentChunk]):
                    """Replace the collected chunks with ``chunks``."""
                    self._chunks = chunks

                def add_references(
                    self,
                    references: list[ContentReference],
                ):
                    """Append one batch of references (becomes the latest batch)."""
                    self._references.append(references)

                @property
                def all_references(
                    self,
                ) -> list[list[ContentReference]]:
                    """Every reference batch, oldest first."""
                    return self._references

                @property
                def latest_references(
                    self,
                ) -> list[ContentReference]:
                    """The most recently added batch, or ``[]`` if there is none."""
                    if not self._references:
                        return []
                    return self._references[-1]

                @property
                def latest_referenced_chunks(self) -> list[ContentChunk]:
                    """Chunks referenced by the most recent reference batch."""
                    if not self._references:
                        return []
                    return self._get_referenced_chunks_from_references(self._references[-1])

                def _get_referenced_chunks_from_references(
                    self,
                    references: list[ContentReference],
                ) -> list[ContentChunk]:
                    """
                    Get _referenced_chunks by matching sourceId from _references with the
                    merged key "<chunk.id>_<chunk.chunk_id>" of the chunks in _chunks.

                    For each reference (in order), every chunk with a matching key is
                    returned (in chunk order); references without a match are skipped.
                    """
                    if not references:
                        return []
                    # Index chunks once instead of scanning every chunk per reference.
                    chunks_by_source_id: dict[str, list[ContentChunk]] = {}
                    for chunk in self._chunks:
                        key = str(chunk.id) + "_" + str(chunk.chunk_id)
                        chunks_by_source_id.setdefault(key, []).append(chunk)
                    return [
                        chunk
                        for ref in references
                        for chunk in chunks_by_source_id.get(ref.source_id, [])
                    ]
         | 
| @@ -0,0 +1,108 @@ | |
| 1 | 
            +
            from enum import StrEnum
         | 
| 2 | 
            +
            import humps
         | 
| 3 | 
            +
            from typing import Any
         | 
| 4 | 
            +
            from pydantic.fields import ComputedFieldInfo, FieldInfo
         | 
| 5 | 
            +
            from pydantic.alias_generators import to_camel
         | 
| 6 | 
            +
            from pydantic import (
         | 
| 7 | 
            +
                BaseModel,
         | 
| 8 | 
            +
                ConfigDict,
         | 
| 9 | 
            +
                Field,
         | 
| 10 | 
            +
                ValidationInfo,
         | 
| 11 | 
            +
                model_validator,
         | 
| 12 | 
            +
            )
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from typing import TYPE_CHECKING
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            if TYPE_CHECKING:
         | 
| 17 | 
            +
                from unique_toolkit.tools.schemas import BaseToolConfig
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def field_title_generator(
                title: str,
                info: FieldInfo | ComputedFieldInfo,
            ) -> str:
                """Turn a field name into a human-readable, title-cased label.

                The name is decamelized, underscores become spaces, then each word is
                capitalized. ``info`` is required by pydantic's callback signature and
                is not used.
                """
                snake_cased = humps.decamelize(title)
                spaced = snake_cased.replace("_", " ")
                return spaced.title()
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def model_title_generator(model: type) -> str:
                """Derive a human-readable, title-cased label from a model's class name."""
                class_name = model.__name__
                words = humps.decamelize(class_name).replace("_", " ")
                return words.title()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def get_configuration_dict(**kwargs) -> ConfigDict:
                """Build the shared pydantic ``ConfigDict`` for tool configuration models.

                Defaults:
                    * camelCase aliases (``to_camel``) with population by field name,
                    * human-readable field and model titles,
                    * no protected namespaces (``protected_namespaces=()``).

                Any keyword argument is added to the config. It may also override one of
                the defaults above; previously that raised ``TypeError`` for a duplicate
                keyword argument.
                """
                defaults = {
                    "alias_generator": to_camel,
                    "field_title_generator": field_title_generator,
                    "model_title_generator": model_title_generator,
                    "populate_by_name": True,
                    "protected_namespaces": (),
                }
                # Caller-supplied values win over the defaults.
                return ConfigDict(**{**defaults, **kwargs})
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            class ToolIcon(StrEnum):
         | 
| 43 | 
            +
                ANALYTICS = "IconAnalytics"
         | 
| 44 | 
            +
                BOOK = "IconBook"
         | 
| 45 | 
            +
                FOLDERDATA = "IconFolderData"
         | 
| 46 | 
            +
                INTEGRATION = "IconIntegration"
         | 
| 47 | 
            +
                TEXT_COMPARE = "IconTextCompare"
         | 
| 48 | 
            +
                WORLD = "IconWorld"
         | 
| 49 | 
            +
                QUICK_REPLY = "IconQuickReply"
         | 
| 50 | 
            +
                CHAT_PLUS = "IconChatPlus"
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            class ToolSelectionPolicy(StrEnum):
         | 
| 54 | 
            +
                """Determine the usage policy of tools."""
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                FORCED_BY_DEFAULT = "ForcedByDefault"
         | 
| 57 | 
            +
                ON_BY_DEFAULT = "OnByDefault"
         | 
| 58 | 
            +
                BY_USER = "ByUser"
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            class ToolBuildConfig(BaseModel):
                model_config = get_configuration_dict()
                # Main tool configuration: which tool to build, its tool-specific
                # configuration, and how it is presented/selected. Kept as a comment
                # (not a class docstring) so the generated JSON-schema description
                # stays unchanged.

                name: str
                configuration: "BaseToolConfig"
                display_name: str = ""
                icon: ToolIcon = ToolIcon.BOOK
                selection_policy: ToolSelectionPolicy = Field(
                    default=ToolSelectionPolicy.BY_USER,
                )
                is_exclusive: bool = Field(
                    default=False,
                    description="This tool must be chosen by the user and no other tools are used for this iteration.",
                )

                is_enabled: bool = Field(default=True)

                @model_validator(mode="before")
                def initialize_config_based_on_tool_name(
                    cls,
                    value: Any,
                    info: ValidationInfo,
                ) -> Any:
                    """Resolve ``configuration`` into the config class registered for ``name``.

                    * A dict ``configuration`` (or a missing one, treated as ``{}``) is
                      built via ``ToolFactory.build_tool_config``, which raises
                      ``ValueError`` for an unregistered tool name.
                    * Any other ``configuration`` must be an instance of the config class
                      registered for ``name``; otherwise ``ValueError`` is raised.

                    Non-dict input, and dicts without ``name``, are passed through so that
                    pydantic's regular validation reports the problem. The incoming dict is
                    not mutated; a shallow copy with the resolved configuration is returned.
                    """
                    if not isinstance(value, dict):
                        return value

                    name = value.get("name")
                    if name is None:
                        # Let pydantic report the missing required field.
                        return value

                    # Local import to avoid circular import at module import time
                    from unique_toolkit.tools.factory import ToolFactory

                    configuration = value.get("configuration", {})
                    if isinstance(configuration, dict):
                        config = ToolFactory.build_tool_config(
                            name,
                            **configuration,
                        )
                    else:
                        # Check that the type of config matches the tool name. Explicit
                        # raise instead of `assert`, which is stripped under `python -O`.
                        expected_cls = ToolFactory.tool_config_map.get(name)  # type: ignore
                        if expected_cls is None:
                            raise ValueError(f"Tool {name} not found")
                        if not isinstance(configuration, expected_cls):
                            raise ValueError(
                                f"Configuration for tool {name} must be of type "
                                f"{expected_cls.__name__}, got {type(configuration).__name__}"
                            )
                        config = configuration
                    return {**value, "configuration": config}
         | 
| @@ -1,7 +1,11 @@ | |
| 1 1 | 
             
            from typing import Callable
         | 
| 2 2 |  | 
| 3 | 
            -
            from  | 
| 3 | 
            +
            from typing import TYPE_CHECKING
         | 
| 4 | 
            +
            from unique_toolkit.tools.schemas import BaseToolConfig
         | 
| 5 | 
            +
            from unique_toolkit.tools.tool import Tool
         | 
| 4 6 |  | 
| 7 | 
            +
            if TYPE_CHECKING:
         | 
| 8 | 
            +
                from unique_toolkit.tools.config import ToolBuildConfig
         | 
| 5 9 |  | 
| 6 10 |  | 
| 7 11 | 
             
            class ToolFactory:
         | 
| @@ -18,14 +22,20 @@ class ToolFactory: | |
| 18 22 | 
             
                    cls.tool_config_map[tool.name] = tool_config
         | 
| 19 23 |  | 
| 20 24 | 
             
                @classmethod
         | 
| 21 | 
            -
                def build_tool(cls, tool_name: str, *args, **kwargs) -> Tool:
         | 
| 25 | 
            +
                def build_tool(cls, tool_name: str, *args, **kwargs) -> Tool[BaseToolConfig]:
                    """Instantiate the tool registered under ``tool_name``.

                    All positional and keyword arguments are forwarded to the tool's
                    constructor; an unknown name fails at the ``tool_map`` lookup.
                    """
                    tool_cls = cls.tool_map[tool_name]
                    return tool_cls(*args, **kwargs)
         | 
| 24 28 |  | 
| 25 29 | 
             
                @classmethod
         | 
| 26 | 
            -
                def  | 
| 27 | 
            -
                    cls, tool_name: str, **kwargs
         | 
| 28 | 
            -
                ) -> BaseToolConfig:
         | 
| 30 | 
            +
                def build_tool_with_settings(
                    cls, tool_name: str, settings: "ToolBuildConfig", *args, **kwargs
                ) -> Tool[BaseToolConfig]:
                    """Instantiate the tool registered under ``tool_name`` and attach ``settings``.

                    Extra arguments go to the tool's constructor. ``settings`` is assigned
                    to ``tool.settings`` after construction.
                    """
                    tool_cls = cls.tool_map[tool_name]
                    built = tool_cls(*args, **kwargs)
                    built.settings = settings
                    return built
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                @classmethod
                def build_tool_config(cls, tool_name: str, **kwargs) -> BaseToolConfig:
                    """Build the configuration object registered for ``tool_name``.

                    Keyword arguments are passed to the config class constructor.

                    Raises:
                        ValueError: if no config class is registered for ``tool_name``.
                    """
                    try:
                        config_cls = cls.tool_config_map[tool_name]
                    except KeyError:
                        raise ValueError(f"Tool {tool_name} not found") from None
                    return config_cls(**kwargs)
         | 
| @@ -0,0 +1,138 @@ | |
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import gzip
         | 
| 3 | 
            +
            import re
         | 
| 4 | 
            +
            from typing import Any, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
         | 
| 7 | 
            +
            from unique_toolkit.content.schemas import ContentChunk
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from unique_toolkit.tools.config import get_configuration_dict
         | 
| 10 | 
            +
            from unique_toolkit.tools.utils.source_handling.schema import SourceFormatConfig
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            # TODO: generalize this: tools can return arbitrary payloads, so consider a base class with one subclass per tool "type".
         | 
| 14 | 
            +
            class ToolCallResponse(BaseModel):
                # Result of one tool call: identifiers, optional payloads, and an error
                # message where the empty string means the call succeeded.
                id: str
                name: str
                debug_info: dict | None = None  # TODO: Make the default {}
                content_chunks: list[ContentChunk] | None = None  # TODO: Make the default []
                reasoning_result: dict | None = None  # TODO: Make the default {}
                error_message: str = ""

                @property
                def successful(self) -> bool:
                    # Success is defined solely by the absence of an error message.
                    no_error = self.error_message == ""
                    return no_error
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            class BaseToolConfig(BaseModel):
                # Base class for every tool-specific configuration model. It only
                # contributes the shared model config from get_configuration_dict()
                # (camelCase aliases, readable titles, populate-by-name).
                # TODO: add a check for the parameters to all be consistent within the tool config
                model_config = get_configuration_dict()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class Source(BaseModel):
         | 
| 34 | 
            +
                """Represents the sources in the tool call response that the llm will see
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                Args:
         | 
| 37 | 
            +
                    source_number: The number of the source
         | 
| 38 | 
            +
                    content: The content of the source
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                model_config = ConfigDict(
         | 
| 42 | 
            +
                    validate_by_alias=True, serialize_by_alias=True, validate_by_name=True
         | 
| 43 | 
            +
                )
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                source_number: int | None = Field(
         | 
| 46 | 
            +
                    default=None,
         | 
| 47 | 
            +
                    serialization_alias="[source_number] - Used for citations!",
         | 
| 48 | 
            +
                    validation_alias="[source_number] - Used for citations!",
         | 
| 49 | 
            +
                )
         | 
| 50 | 
            +
                content: str = Field(
         | 
| 51 | 
            +
                    serialization_alias="[content] - Content of source",
         | 
| 52 | 
            +
                    validation_alias="[content] - Content of source",
         | 
| 53 | 
            +
                )
         | 
| 54 | 
            +
                order: int = Field(
         | 
| 55 | 
            +
                    serialization_alias="[order] - Index in the document!",
         | 
| 56 | 
            +
                    validation_alias="[order] - Index in the document!",
         | 
| 57 | 
            +
                )
         | 
| 58 | 
            +
                chunk_id: str | None = Field(
         | 
| 59 | 
            +
                    default=None,
         | 
| 60 | 
            +
                    serialization_alias="[chunk_id] - IGNORE",
         | 
| 61 | 
            +
                    validation_alias="[chunk_id] - IGNORE",
         | 
| 62 | 
            +
                )
         | 
| 63 | 
            +
                id: str = Field(
         | 
| 64 | 
            +
                    serialization_alias="[id] - IGNORE",
         | 
| 65 | 
            +
                    validation_alias="[id] - IGNORE",
         | 
| 66 | 
            +
                )
         | 
| 67 | 
            +
                key: str | None = Field(
         | 
| 68 | 
            +
                    default=None,
         | 
| 69 | 
            +
                    serialization_alias="[key] - IGNORE",
         | 
| 70 | 
            +
                    validation_alias="[key] - IGNORE",
         | 
| 71 | 
            +
                )
         | 
| 72 | 
            +
                metadata: dict[str, str] | str | None = Field(
         | 
| 73 | 
            +
                    default=None,
         | 
| 74 | 
            +
                    serialization_alias="[metadata] - Formatted metadata",
         | 
| 75 | 
            +
                    validation_alias="[metadata] - Formatted metadata",
         | 
| 76 | 
            +
                )
         | 
| 77 | 
            +
                url: str | None = Field(
         | 
| 78 | 
            +
                    default=None,
         | 
| 79 | 
            +
                    serialization_alias="[url] - IGNORE",
         | 
| 80 | 
            +
                    validation_alias="[url] - IGNORE",
         | 
| 81 | 
            +
                )
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                @field_validator("metadata", mode="before")
                def _metadata_str_to_dict(
                    cls, v: str | dict[str, str] | None
                ) -> dict[str, str] | None:
                    """
                    Normalize ``metadata`` before field validation.

                    * ``None`` or a dict is returned unchanged.
                    * A string is treated as rendered metadata: each section template in
                      ``SourceFormatConfig().sections`` is converted to a regex, and the
                      stripped first capture group of its first match becomes the value
                      for that section key. If no section matches, the original string
                      is returned unchanged.
                    """
                    if v is None or isinstance(v, dict):
                        return v

                    # Default config; an app-wide config could be injected here instead.
                    cfg = SourceFormatConfig()
                    parsed: dict[str, str] = {}
                    for section_key, template in cfg.sections.items():
                        found = re.search(cfg.template_to_pattern(template), v, flags=re.S)
                        if found is not None:
                            parsed[section_key] = found.group(1).strip()

                    return parsed or v  # type: ignore
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # Compression + Base64 for url to hide it from the LLM
                @field_serializer("url")
                def serialize_url(self, value: str | None) -> str | None:
                    """
                    Serialize ``url`` as gzip-compressed, base64-encoded ASCII text.

                    ``validate_url`` reverses this encoding when the model is parsed
                    again. ``None`` stays ``None``.

                    ``mtime=0`` pins the timestamp that gzip writes into its header.
                    Without it, gzip embeds the current time, so the same URL would
                    serialize to a different string from one second to the next.
                    Decompression ignores the header timestamp, so values produced
                    before this change still decode.
                    """
                    if value is None:
                        return None
                    # Compress then base64 encode
                    compressed = gzip.compress(value.encode(), mtime=0)
                    return base64.b64encode(compressed).decode()
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                @field_validator("url", mode="before")
                @classmethod
                def validate_url(cls, value: Any) -> str | None:
                    """
                    Turn an incoming ``url`` value back into plain text.

                    - ``None`` or an empty string -> ``None``
                    - non-string values -> ``str(value)``
                    - any other string -> base64-decoded and gunzipped when that
                      succeeds (the ``serialize_url`` encoding), otherwise returned
                      unchanged because it is already a plain URL
                    """
                    if value is None:
                        return None
                    if not isinstance(value, str):
                        return str(value)
                    if not value:
                        return None

                    try:
                        packed = base64.b64decode(value.encode())
                        return gzip.decompress(packed).decode()
                    except Exception:
                        # Not produced by serialize_url: keep the text as-is.
                        return value
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            class ToolPrompts(BaseModel):
                """A tool's names, description, prompt-formatting text and input model."""

                name: str
                display_name: str
                tool_description: str
                # NOTE(review): this field used to be declared twice in a row
                # (copy-paste slip). A repeated annotation only re-declares the same
                # field, so a single declaration is kept. If a second, distinct prompt
                # field was intended here, add it under its own name.
                tool_format_information_for_system_prompt: str
                # Presumably a JSON-schema-like description of the tool's input;
                # confirm against callers.
                input_model: dict[str, Any]
         | 
| @@ -0,0 +1,204 @@ | |
| 1 | 
            +
            from unittest.mock import AsyncMock
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import pytest
         | 
| 4 | 
            +
            from unique_toolkit.chat.service import ChatService
         | 
| 5 | 
            +
            from unique_toolkit.content.schemas import ContentReference
         | 
| 6 | 
            +
            from unique_toolkit.language_model.schemas import LanguageModelFunction
         | 
| 7 | 
            +
            from unique_toolkit.tools.tool_progress_reporter import (
         | 
| 8 | 
            +
                DUMMY_REFERENCE_PLACEHOLDER,
         | 
| 9 | 
            +
                ProgressState,
         | 
| 10 | 
            +
                ToolExecutionStatus,
         | 
| 11 | 
            +
                ToolProgressReporter,
         | 
| 12 | 
            +
                ToolWithToolProgressReporter,
         | 
| 13 | 
            +
                track_tool_progress,
         | 
| 14 | 
            +
            )
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            @pytest.fixture
            def chat_service():
                """Async mock restricted to the ``ChatService`` interface."""
                service_mock = AsyncMock(spec=ChatService)
                return service_mock
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            @pytest.fixture
            def tool_progress_reporter(chat_service):
                """A fresh ``ToolProgressReporter`` wired to the mocked chat service."""
                reporter = ToolProgressReporter(chat_service)
                return reporter
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            @pytest.fixture
            def tool_call():
                """Minimal ``LanguageModelFunction`` with the fixed id ``"test_id"``."""
                return LanguageModelFunction(name="test_tool", id="test_id")
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class TestToolProgressReporter:
                """Tests for ToolProgressReporter status tracking, reference numbering and publishing."""

                @staticmethod
                def _reference(sequence_number, ident):
                    # Dummy reference in which every identifying field carries the same value.
                    return ContentReference(
                        sequence_number=sequence_number,
                        id=ident,
                        message_id=ident,
                        name=ident,
                        source=ident,
                        source_id=ident,
                        url=ident,
                    )

                @pytest.mark.asyncio
                async def test_notify_from_tool_call(self, tool_progress_reporter, tool_call):
                    expected_name = "Test Tool"
                    expected_message = "Processing..."
                    expected_state = ProgressState.RUNNING
                    expected_references = [self._reference(1, "1")]

                    await tool_progress_reporter.notify_from_tool_call(
                        tool_call=tool_call,
                        name=expected_name,
                        message=expected_message,
                        state=expected_state,
                        references=expected_references,
                        requires_new_assistant_message=True,
                    )

                    statuses = tool_progress_reporter.tool_statuses
                    assert tool_call.id in statuses
                    recorded = statuses[tool_call.id]
                    assert recorded.name == expected_name
                    assert recorded.message == expected_message
                    assert recorded.state == expected_state
                    assert recorded.references == expected_references
                    assert tool_progress_reporter.requires_new_assistant_message is True

                def test_replace_placeholders(self, tool_progress_reporter):
                    text = f"Test{DUMMY_REFERENCE_PLACEHOLDER}message{DUMMY_REFERENCE_PLACEHOLDER}"

                    rendered = tool_progress_reporter._replace_placeholders(text, start_number=1)

                    assert rendered == "Test<sup>1</sup>message<sup>2</sup>"

                def test_correct_reference_sequence(self, tool_progress_reporter):
                    unnumbered = [self._reference(0, "1"), self._reference(0, "2")]

                    corrected = tool_progress_reporter._correct_reference_sequence(
                        unnumbered, start_number=1
                    )

                    assert len(corrected) == 2
                    assert [ref.sequence_number for ref in corrected] == [1, 2]

                @pytest.mark.asyncio
                async def test_publish_updates_chat_service(
                    self, tool_progress_reporter, tool_call
                ):
                    tool_progress_reporter.tool_statuses[tool_call.id] = ToolExecutionStatus(
                        name="Test Tool",
                        message="Test message",
                        state=ProgressState.FINISHED,
                        references=[self._reference(1, "1")],
                    )

                    await tool_progress_reporter.publish()

                    chat_mock = tool_progress_reporter.chat_service
                    chat_mock.modify_assistant_message_async.assert_called_once()
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class TestToolProgressDecorator:
                """Tests for the ``track_tool_progress`` decorator on a tool's ``execute``."""

                class DummyTool(ToolWithToolProgressReporter):
                    """Tool whose decorated ``execute`` succeeds and returns one reference."""

                    def __init__(self, tool_progress_reporter):
                        self.tool_progress_reporter = tool_progress_reporter

                    @track_tool_progress(
                        message="Processing",
                        on_start_state=ProgressState.STARTED,
                        on_success_state=ProgressState.FINISHED,
                        on_success_message="Completed",
                        requires_new_assistant_message=True,
                    )
                    async def execute(self, tool_call, notification_tool_name):
                        only_reference = ContentReference(
                            sequence_number=1,
                            id="1",
                            message_id="1",
                            name="1",
                            source="1",
                            source_id="1",
                            url="1",
                        )
                        return {"references": [only_reference]}

                @pytest.mark.asyncio
                async def test_decorator_success_flow(self, tool_progress_reporter, tool_call):
                    dummy = self.DummyTool(tool_progress_reporter)

                    await dummy.execute(tool_call, "Test Tool")

                    statuses = tool_progress_reporter.tool_statuses
                    assert len(statuses) == 1
                    final_status = statuses[tool_call.id]
                    assert final_status.state == ProgressState.FINISHED
                    assert final_status.message == "Completed"

                @pytest.mark.asyncio
                async def test_decorator_error_flow(self, tool_progress_reporter, tool_call):
                    class ErrorTool(ToolWithToolProgressReporter):
                        def __init__(self, tool_progress_reporter):
                            self.tool_progress_reporter = tool_progress_reporter

                        @track_tool_progress(message="Processing")
                        async def execute(self, tool_call, notification_tool_name):
                            raise ValueError("Test error")

                    failing_tool = ErrorTool(tool_progress_reporter)

                    with pytest.raises(ValueError):
                        await failing_tool.execute(tool_call, "Test Tool")

                    failed_status = tool_progress_reporter.tool_statuses[tool_call.id]
                    assert failed_status.state == ProgressState.FAILED
         |