unique_toolkit 0.8.4__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,6 +51,7 @@ class BaseEvent(BaseModel):
51
51
  # MCP schemas
52
52
  ###
53
53
 
54
+
54
55
  class McpTool(BaseModel):
55
56
  model_config = model_config
56
57
 
@@ -79,6 +80,7 @@ class McpTool(BaseModel):
79
80
  description="Whether the tool is connected to the MCP server. This is a Unique specific field.",
80
81
  )
81
82
 
83
+
82
84
  class McpServer(BaseModel):
83
85
  model_config = model_config
84
86
 
@@ -94,6 +96,7 @@ class McpServer(BaseModel):
94
96
  )
95
97
  tools: list[McpTool] = []
96
98
 
99
+
97
100
  ###
98
101
  # ChatEvent schemas
99
102
  ###
@@ -86,6 +86,24 @@ class LanguageModelFunction(BaseModel):
86
86
  return seralization
87
87
 
88
88
 
89
def __eq__(self, other: object) -> bool:
    """
    Compare two tool calls based on id, name and arguments.
    """
    # `object` (not `Self`) keeps the signature compatible with the
    # `object.__eq__` contract; non-LanguageModelFunction operands compare unequal.
    if not isinstance(other, LanguageModelFunction):
        return False
    # NOTE(review): defining __eq__ sets __hash__ to None unless a __hash__ is
    # also provided — confirm instances are never used in sets or as dict keys.
    return (
        self.id == other.id
        and self.name == other.name
        and self.arguments == other.arguments
    )
106
+
89
107
  # This is tailored to the unique backend
90
108
  class LanguageModelStreamResponse(BaseModel):
91
109
  model_config = model_config
@@ -0,0 +1,72 @@
1
+ from unique_toolkit.content.schemas import ContentChunk, ContentReference
2
+ from unique_toolkit.tools.schemas import ToolCallResponse
3
+
4
+
5
class tool_chunks:
    """Pairs a tool name with the content chunks that tool produced."""

    # NOTE(review): class name is not PascalCase (PEP 8); kept as-is because it
    # is referenced by name elsewhere in this module.
    def __init__(self, name: str, chunks: list[ContentChunk]) -> None:
        # Name of the tool the chunks came from.
        self.name = name
        # Content chunks returned by that tool call.
        self.chunks = chunks
9
+
10
+
11
class ReferenceManager:
    """Collects referenceable content chunks from tool calls and tracks the
    history of references made against them."""

    def __init__(self):
        # Chunks grouped per tool-call id.
        self._tool_chunks: dict[str, tool_chunks] = {}
        # Flat list of every referenceable chunk collected so far.
        self._chunks: list[ContentChunk] = []
        # History of reference batches; one entry per add_references() call.
        self._references: list[list[ContentReference]] = []

    def extract_referenceable_chunks(
        self, tool_responses: list[ToolCallResponse]
    ) -> None:
        """Register the content chunks of every tool response that has any."""
        for tool_response in tool_responses:
            if not tool_response.content_chunks:
                continue
            # content_chunks is known to be truthy here; no `or []` fallback needed.
            self._chunks.extend(tool_response.content_chunks)
            self._tool_chunks[tool_response.id] = tool_chunks(
                tool_response.name, tool_response.content_chunks
            )

    def get_chunks(self) -> list[ContentChunk]:
        """Return all collected chunks."""
        return self._chunks

    def get_tool_chunks(self) -> dict:
        """Return the chunks grouped by tool-call id."""
        return self._tool_chunks

    def replace(self, chunks: list[ContentChunk]):
        """Replace the collected chunks wholesale (e.g. after post-processing)."""
        self._chunks = chunks

    def add_references(
        self,
        references: list[ContentReference],
    ):
        """Append one batch of references to the history."""
        self._references.append(references)

    def get_references(
        self,
    ) -> list[list[ContentReference]]:
        """Return the full reference history, oldest batch first."""
        return self._references

    def get_latest_references(
        self,
    ) -> list[ContentReference]:
        """Return the most recent reference batch, or [] when none exist."""
        if not self._references:
            return []
        return self._references[-1]

    def get_latest_referenced_chunks(self) -> list[ContentChunk]:
        """Return the chunks referenced by the most recent batch, or []."""
        if not self._references:
            return []
        return self._get_referenced_chunks_from_references(self._references[-1])

    def _get_referenced_chunks_from_references(
        self,
        references: list[ContentReference],
    ) -> list[ContentChunk]:
        """
        Get _referenced_chunks by matching sourceId from _references with merged id and chunk_id from _chunks.
        """
        # NOTE(review): the "{id}-{chunk_id}" key must stay in sync with wherever
        # source_id is produced — confirm against the citation builder.
        referenced_chunks: list[ContentChunk] = []
        for ref in references:
            for chunk in self._chunks:
                if ref.source_id == f"{chunk.id}-{chunk.chunk_id}":
                    referenced_chunks.append(chunk)
        return referenced_chunks
@@ -0,0 +1,62 @@
1
+ from unique_toolkit.content.schemas import ContentChunk, ContentReference
2
+
3
+
4
class AgentChunksHandler:
    """Accumulates content chunks gathered during an agent run together with
    the batches of references that were created against them."""

    def __init__(self):
        # NOTE(review): _tool_chunks is exposed but never written in this
        # class — presumably populated externally; confirm with callers.
        self._tool_chunks = {}
        self._chunks: list[ContentChunk] = []
        self._references: list[list[ContentReference]] = []

    @property
    def chunks(self) -> list[ContentChunk]:
        """All chunks collected so far."""
        return self._chunks

    @property
    def tool_chunks(self) -> dict:
        """Per-tool chunk mapping (see note in __init__)."""
        return self._tool_chunks

    def extend(self, chunks: list[ContentChunk]):
        """Append the given chunks to the collection."""
        self._chunks.extend(chunks)

    def replace(self, chunks: list[ContentChunk]):
        """Swap out the whole chunk collection."""
        self._chunks = chunks

    def add_references(
        self,
        references: list[ContentReference],
    ):
        """Record one new batch of references."""
        self._references.append(references)

    @property
    def all_references(
        self,
    ) -> list[list[ContentReference]]:
        """Every recorded reference batch, oldest first."""
        return self._references

    @property
    def latest_references(
        self,
    ) -> list[ContentReference]:
        """The newest reference batch, or an empty list when none exist."""
        return self._references[-1] if self._references else []

    @property
    def latest_referenced_chunks(self) -> list[ContentChunk]:
        """Chunks matched by the newest reference batch, or an empty list."""
        if not self._references:
            return []
        return self._get_referenced_chunks_from_references(self._references[-1])

    def _get_referenced_chunks_from_references(
        self,
        references: list[ContentReference],
    ) -> list[ContentChunk]:
        """Match each reference's source_id against the "{id}_{chunk_id}" key of
        the collected chunks and return every chunk that is referenced.

        NOTE(review): this uses an underscore separator while ReferenceManager
        uses a hyphen — confirm which format the source_id producer emits.
        """
        return [
            candidate
            for ref in references
            for candidate in self._chunks
            if ref.source_id == f"{candidate.id}_{candidate.chunk_id}"
        ]
@@ -0,0 +1,108 @@
1
+ from enum import StrEnum
2
+ import humps
3
+ from typing import Any
4
+ from pydantic.fields import ComputedFieldInfo, FieldInfo
5
+ from pydantic.alias_generators import to_camel
6
+ from pydantic import (
7
+ BaseModel,
8
+ ConfigDict,
9
+ Field,
10
+ ValidationInfo,
11
+ model_validator,
12
+ )
13
+
14
+ from typing import TYPE_CHECKING
15
+
16
+ if TYPE_CHECKING:
17
+ from unique_toolkit.tools.schemas import BaseToolConfig
18
+
19
+
20
def field_title_generator(
    title: str,
    info: FieldInfo | ComputedFieldInfo,
) -> str:
    """Render a (possibly camelCase) field title as Title Case words."""
    spaced = humps.decamelize(title).replace("_", " ")
    return spaced.title()
25
+
26
+
27
def model_title_generator(model: type) -> str:
    """Render a model class name as Title Case words."""
    snake_name = humps.decamelize(model.__name__)
    return snake_name.replace("_", " ").title()
29
+
30
+
31
def get_configuration_dict(**kwargs) -> ConfigDict:
    """Build the shared pydantic model configuration.

    Uses camelCase aliases for (de)serialization, human-readable field/model
    titles, and allows population by field name; extra keyword arguments are
    merged into the resulting ConfigDict.
    """
    return ConfigDict(
        alias_generator=to_camel,
        field_title_generator=field_title_generator,
        model_title_generator=model_title_generator,
        populate_by_name=True,
        protected_namespaces=(),
        **kwargs,
    )
40
+
41
+
42
class ToolIcon(StrEnum):
    """Icon identifiers associated with tools (string values as consumed
    downstream — rendering semantics not visible here)."""

    ANALYTICS = "IconAnalytics"
    BOOK = "IconBook"
    FOLDERDATA = "IconFolderData"
    INTEGRATION = "IconIntegration"
    TEXT_COMPARE = "IconTextCompare"
    WORLD = "IconWorld"
    QUICK_REPLY = "IconQuickReply"
    CHAT_PLUS = "IconChatPlus"
51
+
52
+
53
class ToolSelectionPolicy(StrEnum):
    """Determine the usage policy of tools."""

    # NOTE(review): the exact runtime semantics of each policy are enforced
    # elsewhere and are not visible in this module.
    FORCED_BY_DEFAULT = "ForcedByDefault"
    ON_BY_DEFAULT = "OnByDefault"
    BY_USER = "ByUser"
59
+
60
+
61
class ToolBuildConfig(BaseModel):
    """Main tool configuration."""

    # The docstring above was previously placed after `model_config`, where it
    # was a no-op string statement instead of the class docstring.
    model_config = get_configuration_dict()

    name: str
    configuration: "BaseToolConfig"
    display_name: str = ""
    icon: ToolIcon = ToolIcon.BOOK
    selection_policy: ToolSelectionPolicy = Field(
        default=ToolSelectionPolicy.BY_USER,
    )
    is_exclusive: bool = Field(
        default=False,
        description="This tool must be chosen by the user and no other tools are used for this iteration.",
    )

    is_enabled: bool = Field(default=True)

    @model_validator(mode="before")
    def initialize_config_based_on_tool_name(
        cls,
        value: Any,
        info: ValidationInfo,
    ) -> Any:
        """Coerce the raw ``configuration`` value into the config model
        registered for ``name``.

        A dict is built through the ToolFactory; an already-constructed config
        object is asserted to match the registered type and kept as-is.
        Non-dict inputs are passed through untouched.
        """
        if not isinstance(value, dict):
            return value

        # Local import to avoid circular import at module import time
        # (previously duplicated in both branches).
        from unique_toolkit.tools.factory import ToolFactory

        configuration = value.get("configuration", {})
        if isinstance(configuration, dict):
            config = ToolFactory.build_tool_config(
                value["name"],
                **configuration,
            )
        else:
            # Check that the type of config matches the tool name.
            # NOTE(review): `assert` is stripped under `python -O`; consider
            # raising ValueError for a validation that must always run.
            assert isinstance(
                configuration,
                ToolFactory.tool_config_map[value["name"]],  # type: ignore
            )
            config = configuration
        value["configuration"] = config
        return value
@@ -1,7 +1,11 @@
1
1
  from typing import Callable
2
2
 
3
- from unique_toolkit.unique_toolkit.tools.tool_definitions import BaseToolConfig, Tool
3
+ from typing import TYPE_CHECKING
4
+ from unique_toolkit.tools.schemas import BaseToolConfig
5
+ from unique_toolkit.tools.tool import Tool
4
6
 
7
+ if TYPE_CHECKING:
8
+ from unique_toolkit.tools.config import ToolBuildConfig
5
9
 
6
10
 
7
11
  class ToolFactory:
@@ -18,14 +22,20 @@ class ToolFactory:
18
22
  cls.tool_config_map[tool.name] = tool_config
19
23
 
20
24
    @classmethod
    def build_tool(cls, tool_name: str, *args, **kwargs) -> Tool[BaseToolConfig]:
        """Instantiate the tool class registered under ``tool_name``.

        Raises:
            KeyError: if ``tool_name`` was never registered in ``tool_map``.
        """
        tool = cls.tool_map[tool_name](*args, **kwargs)
        return tool
24
28
 
25
29
  @classmethod
26
- def build_tool_config(
27
- cls, tool_name: str, **kwargs
28
- ) -> BaseToolConfig:
30
+ def build_tool_with_settings(
31
+ cls, tool_name: str, settings: "ToolBuildConfig", *args, **kwargs
32
+ ) -> Tool[BaseToolConfig]:
33
+ tool = cls.tool_map[tool_name](*args, **kwargs)
34
+ tool.settings = settings
35
+ return tool
36
+
37
+ @classmethod
38
+ def build_tool_config(cls, tool_name: str, **kwargs) -> BaseToolConfig:
29
39
  if tool_name not in cls.tool_config_map:
30
40
  raise ValueError(f"Tool {tool_name} not found")
31
41
  return cls.tool_config_map[tool_name](**kwargs)
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import gzip
3
+ import re
4
+ from typing import Any, Optional
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
7
+ from unique_toolkit.content.schemas import ContentChunk
8
+
9
+ from unique_toolkit.tools.config import get_configuration_dict
10
+ from unique_toolkit.tools.utils.source_handling.schema import SourceFormatConfig
11
+
12
+
13
# TODO: this needs to be more general as the tools can potentially return anything maybe make a base class and then derive per "type" of tool
class ToolCallResponse(BaseModel):
    """Result of a single tool call: identity, produced chunks, and error state."""

    id: str
    name: str
    debug_info: Optional[dict] = None  # TODO: Make the default {}
    content_chunks: Optional[list[ContentChunk]] = None  # TODO: Make the default []
    reasoning_result: Optional[dict] = None  # TODO: Make the default {}
    # Empty string means the call succeeded (see `successful`).
    error_message: str = ""

    @property
    def successful(self) -> bool:
        """True when no error message was recorded."""
        return self.error_message == ""
25
+
26
+
27
class BaseToolConfig(BaseModel):
    """Base class for all tool configuration models; applies the shared
    camelCase/title model configuration."""

    model_config = get_configuration_dict()
    # TODO: add a check for the parameters to all be consistent within the tool config
    pass
31
+
32
+
33
class Source(BaseModel):
    """Represents the sources in the tool call response that the llm will see

    Args:
        source_number: The number of the source
        content: The content of the source
    """

    # Field aliases are bracketed instruction strings shown to the LLM;
    # validation and serialization both go through these aliases.
    model_config = ConfigDict(
        validate_by_alias=True, serialize_by_alias=True, validate_by_name=True
    )

    source_number: int | None = Field(
        default=None,
        serialization_alias="[source_number] - Used for citations!",
        validation_alias="[source_number] - Used for citations!",
    )
    content: str = Field(
        serialization_alias="[content] - Content of source",
        validation_alias="[content] - Content of source",
    )
    order: int = Field(
        serialization_alias="[order] - Index in the document!",
        validation_alias="[order] - Index in the document!",
    )
    chunk_id: str | None = Field(
        default=None,
        serialization_alias="[chunk_id] - IGNORE",
        validation_alias="[chunk_id] - IGNORE",
    )
    id: str = Field(
        serialization_alias="[id] - IGNORE",
        validation_alias="[id] - IGNORE",
    )
    key: str | None = Field(
        default=None,
        serialization_alias="[key] - IGNORE",
        validation_alias="[key] - IGNORE",
    )
    metadata: dict[str, str] | str | None = Field(
        default=None,
        serialization_alias="[metadata] - Formatted metadata",
        validation_alias="[metadata] - Formatted metadata",
    )
    url: str | None = Field(
        default=None,
        serialization_alias="[url] - IGNORE",
        validation_alias="[url] - IGNORE",
    )

    @field_validator("metadata", mode="before")
    def _metadata_str_to_dict(
        cls, v: str | dict[str, str] | None
    ) -> dict[str, str] | None:
        """
        Accept • dict → keep as-is
               • str → parse tag-string back to dict
        """
        if v is None or isinstance(v, dict):
            return v

        # v is the rendered string. Build a dict by matching the
        # patterns defined in SourceFormatConfig.sections.
        cfg = SourceFormatConfig()  # or inject your app-wide config
        out: dict[str, str] = {}
        for key, tmpl in cfg.sections.items():
            pattern = cfg.template_to_pattern(tmpl)
            # re.S lets sections span multiple lines.
            m = re.search(pattern, v, flags=re.S)
            if m:
                out[key] = m.group(1).strip()

        # Fall back to the raw string when no section pattern matched.
        return out if out else v  # type: ignore

    # Compression + Base64 for url to hide it from the LLM
    @field_serializer("url")
    def serialize_url(self, value: str | None) -> str | None:
        if value is None:
            return None
        # Compress then base64 encode
        compressed = gzip.compress(value.encode())
        return base64.b64encode(compressed).decode()

    @field_validator("url", mode="before")
    @classmethod
    def validate_url(cls, value: Any) -> str | None:
        # `or` binds looser than `and`: None, or an empty string, becomes None.
        if value is None or isinstance(value, str) and not value:
            return None
        if isinstance(value, str):
            try:
                # Try to decode base64 then decompress
                # NOTE(review): a plain-text URL that happens to be valid
                # base64+gzip would be silently decoded — confirm inputs are
                # always either serialized here or plain text that fails decode.
                decoded_bytes = base64.b64decode(value.encode())
                decompressed = gzip.decompress(decoded_bytes).decode()
                return decompressed
            except Exception:
                # If decoding/decompression fails, assume it's plain text
                return value
        return str(value)
130
+
131
+
132
class ToolPrompts(BaseModel):
    """Prompt-related metadata describing a tool (names, description, system
    prompt formatting hints, and the tool's input JSON schema)."""

    name: str
    display_name: str
    tool_description: str
    # The original declared this field twice back-to-back; the duplicate
    # annotation was a no-op and has been removed.
    tool_format_information_for_system_prompt: str
    input_model: dict[str, Any]
@@ -0,0 +1,204 @@
1
+ from unittest.mock import AsyncMock
2
+
3
+ import pytest
4
+ from unique_toolkit.chat.service import ChatService
5
+ from unique_toolkit.content.schemas import ContentReference
6
+ from unique_toolkit.language_model.schemas import LanguageModelFunction
7
+ from unique_toolkit.tools.tool_progress_reporter import (
8
+ DUMMY_REFERENCE_PLACEHOLDER,
9
+ ProgressState,
10
+ ToolExecutionStatus,
11
+ ToolProgressReporter,
12
+ ToolWithToolProgressReporter,
13
+ track_tool_progress,
14
+ )
15
+
16
+
17
@pytest.fixture
def chat_service():
    """Async-mocked ChatService so no real chat backend is touched."""
    return AsyncMock(spec=ChatService)
20
+
21
+
22
@pytest.fixture
def tool_progress_reporter(chat_service):
    """ToolProgressReporter wired to the mocked chat service."""
    return ToolProgressReporter(chat_service)
25
+
26
+
27
@pytest.fixture
def tool_call():
    """Minimal LanguageModelFunction used as the tool call under test."""
    return LanguageModelFunction(id="test_id", name="test_tool")
30
+
31
+
32
class TestToolProgressReporter:
    """Unit tests for ToolProgressReporter's status tracking and publishing."""

    @pytest.mark.asyncio
    async def test_notify_from_tool_call(self, tool_progress_reporter, tool_call):
        """A notification stores a status entry under the tool call's id."""
        # Arrange
        name = "Test Tool"
        message = "Processing..."
        state = ProgressState.RUNNING
        references = [
            ContentReference(
                sequence_number=1,
                id="1",
                message_id="1",
                name="1",
                source="1",
                source_id="1",
                url="1",
            )
        ]

        # Act
        await tool_progress_reporter.notify_from_tool_call(
            tool_call=tool_call,
            name=name,
            message=message,
            state=state,
            references=references,
            requires_new_assistant_message=True,
        )

        # Assert
        assert tool_call.id in tool_progress_reporter.tool_statuses
        status = tool_progress_reporter.tool_statuses[tool_call.id]
        assert status.name == name
        assert status.message == message
        assert status.state == state
        assert status.references == references
        assert tool_progress_reporter.requires_new_assistant_message is True

    def test_replace_placeholders(self, tool_progress_reporter):
        """Placeholders are replaced by consecutive superscript numbers."""
        # Arrange
        message = (
            f"Test{DUMMY_REFERENCE_PLACEHOLDER}message{DUMMY_REFERENCE_PLACEHOLDER}"
        )

        # Act
        result = tool_progress_reporter._replace_placeholders(message, start_number=1)

        # Assert
        assert result == "Test<sup>1</sup>message<sup>2</sup>"

    def test_correct_reference_sequence(self, tool_progress_reporter):
        """Sequence numbers are renumbered consecutively from start_number."""
        # Arrange
        references = [
            ContentReference(
                sequence_number=0,
                id="1",
                message_id="1",
                name="1",
                source="1",
                source_id="1",
                url="1",
            ),
            ContentReference(
                sequence_number=0,
                id="2",
                message_id="2",
                name="2",
                source="2",
                source_id="2",
                url="2",
            ),
        ]

        # Act
        result = tool_progress_reporter._correct_reference_sequence(
            references, start_number=1
        )

        # Assert
        assert len(result) == 2
        assert result[0].sequence_number == 1
        assert result[1].sequence_number == 2

    @pytest.mark.asyncio
    async def test_publish_updates_chat_service(
        self, tool_progress_reporter, tool_call
    ):
        """publish() forwards the stored statuses to the chat service."""
        # Arrange
        status = ToolExecutionStatus(
            name="Test Tool",
            message="Test message",
            state=ProgressState.FINISHED,
            references=[
                ContentReference(
                    sequence_number=1,
                    id="1",
                    message_id="1",
                    name="1",
                    source="1",
                    source_id="1",
                    url="1",
                )
            ],
        )
        tool_progress_reporter.tool_statuses[tool_call.id] = status

        # Act
        await tool_progress_reporter.publish()

        # Assert
        tool_progress_reporter.chat_service.modify_assistant_message_async.assert_called_once()
143
+
144
+
145
class TestToolProgressDecorator:
    """Tests for the track_tool_progress decorator's success and error paths."""

    class DummyTool(ToolWithToolProgressReporter):
        # Minimal tool whose execute() succeeds and returns references.
        def __init__(self, tool_progress_reporter):
            self.tool_progress_reporter = tool_progress_reporter

        @track_tool_progress(
            message="Processing",
            on_start_state=ProgressState.STARTED,
            on_success_state=ProgressState.FINISHED,
            on_success_message="Completed",
            requires_new_assistant_message=True,
        )
        async def execute(self, tool_call, notification_tool_name):
            return {
                "references": [
                    ContentReference(
                        sequence_number=1,
                        id="1",
                        message_id="1",
                        name="1",
                        source="1",
                        source_id="1",
                        url="1",
                    )
                ]
            }

    @pytest.mark.asyncio
    async def test_decorator_success_flow(self, tool_progress_reporter, tool_call):
        """On success the decorator records the FINISHED state and message."""
        # Arrange
        tool = self.DummyTool(tool_progress_reporter)

        # Act
        await tool.execute(tool_call, "Test Tool")

        # Assert
        assert len(tool_progress_reporter.tool_statuses) == 1
        status = tool_progress_reporter.tool_statuses[tool_call.id]
        assert status.state == ProgressState.FINISHED
        assert status.message == "Completed"

    @pytest.mark.asyncio
    async def test_decorator_error_flow(self, tool_progress_reporter, tool_call):
        """On an exception the decorator records FAILED and re-raises."""
        # Arrange
        class ErrorTool(ToolWithToolProgressReporter):
            def __init__(self, tool_progress_reporter):
                self.tool_progress_reporter = tool_progress_reporter

            @track_tool_progress(message="Processing")
            async def execute(self, tool_call, notification_tool_name):
                raise ValueError("Test error")

        tool = ErrorTool(tool_progress_reporter)

        # Act & Assert
        with pytest.raises(ValueError):
            await tool.execute(tool_call, "Test Tool")

        status = tool_progress_reporter.tool_statuses[tool_call.id]
        assert status.state == ProgressState.FAILED