unique_toolkit 1.28.8__py3-none-any.whl → 1.33.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/__init__.py +12 -6
- unique_toolkit/_common/docx_generator/service.py +8 -32
- unique_toolkit/_common/utils/jinja/helpers.py +10 -0
- unique_toolkit/_common/utils/jinja/render.py +18 -0
- unique_toolkit/_common/utils/jinja/schema.py +65 -0
- unique_toolkit/_common/utils/jinja/utils.py +80 -0
- unique_toolkit/agentic/message_log_manager/service.py +9 -0
- unique_toolkit/agentic/tools/a2a/postprocessing/_display_utils.py +58 -3
- unique_toolkit/agentic/tools/a2a/postprocessing/_ref_utils.py +11 -0
- unique_toolkit/agentic/tools/a2a/postprocessing/config.py +33 -0
- unique_toolkit/agentic/tools/a2a/postprocessing/display.py +99 -15
- unique_toolkit/agentic/tools/a2a/postprocessing/test/test_display.py +421 -0
- unique_toolkit/agentic/tools/a2a/postprocessing/test/test_display_utils.py +768 -0
- unique_toolkit/agentic/tools/a2a/tool/config.py +77 -1
- unique_toolkit/agentic/tools/a2a/tool/service.py +67 -3
- unique_toolkit/agentic/tools/config.py +5 -45
- unique_toolkit/agentic/tools/openai_builtin/base.py +4 -0
- unique_toolkit/agentic/tools/openai_builtin/code_interpreter/service.py +4 -0
- unique_toolkit/agentic/tools/tool_manager.py +16 -19
- unique_toolkit/app/__init__.py +3 -0
- unique_toolkit/app/fast_api_factory.py +131 -0
- unique_toolkit/app/webhook.py +77 -0
- unique_toolkit/chat/functions.py +1 -1
- unique_toolkit/content/functions.py +4 -4
- unique_toolkit/content/service.py +1 -1
- unique_toolkit/data_extraction/README.md +96 -0
- unique_toolkit/data_extraction/__init__.py +11 -0
- unique_toolkit/data_extraction/augmented/__init__.py +5 -0
- unique_toolkit/data_extraction/augmented/service.py +93 -0
- unique_toolkit/data_extraction/base.py +25 -0
- unique_toolkit/data_extraction/basic/__init__.py +11 -0
- unique_toolkit/data_extraction/basic/config.py +18 -0
- unique_toolkit/data_extraction/basic/prompt.py +13 -0
- unique_toolkit/data_extraction/basic/service.py +55 -0
- unique_toolkit/embedding/service.py +1 -1
- unique_toolkit/framework_utilities/langchain/__init__.py +10 -0
- unique_toolkit/framework_utilities/openai/client.py +2 -1
- unique_toolkit/language_model/infos.py +22 -1
- unique_toolkit/services/knowledge_base.py +4 -6
- {unique_toolkit-1.28.8.dist-info → unique_toolkit-1.33.3.dist-info}/METADATA +51 -2
- {unique_toolkit-1.28.8.dist-info → unique_toolkit-1.33.3.dist-info}/RECORD +43 -27
- unique_toolkit/agentic/tools/test/test_tool_manager.py +0 -1686
- {unique_toolkit-1.28.8.dist-info → unique_toolkit-1.33.3.dist-info}/LICENSE +0 -0
- {unique_toolkit-1.28.8.dist-info → unique_toolkit-1.33.3.dist-info}/WHEEL +0 -0
unique_toolkit/__init__.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
# Re-export commonly used classes for easier imports
|
|
2
2
|
from unique_toolkit.chat import ChatService
|
|
3
3
|
from unique_toolkit.content import ContentService
|
|
4
|
+
from unique_toolkit.data_extraction import (
|
|
5
|
+
StructuredOutputDataExtractor,
|
|
6
|
+
StructuredOutputDataExtractorConfig,
|
|
7
|
+
)
|
|
4
8
|
from unique_toolkit.embedding import EmbeddingService
|
|
5
9
|
from unique_toolkit.framework_utilities.openai.client import (
|
|
6
10
|
get_async_openai_client,
|
|
@@ -26,17 +30,19 @@ except ImportError:
|
|
|
26
30
|
# You can add other classes you frequently use here as well
|
|
27
31
|
|
|
28
32
|
__all__ = [
|
|
29
|
-
"LanguageModelService",
|
|
30
|
-
"LanguageModelMessages",
|
|
31
|
-
"LanguageModelName",
|
|
32
|
-
"LanguageModelToolDescription",
|
|
33
33
|
"ChatService",
|
|
34
34
|
"ContentService",
|
|
35
35
|
"EmbeddingService",
|
|
36
|
-
"ShortTermMemoryService",
|
|
37
|
-
"KnowledgeBaseService",
|
|
38
36
|
"get_openai_client",
|
|
39
37
|
"get_async_openai_client",
|
|
38
|
+
"KnowledgeBaseService",
|
|
39
|
+
"LanguageModelMessages",
|
|
40
|
+
"LanguageModelName",
|
|
41
|
+
"LanguageModelService",
|
|
42
|
+
"LanguageModelToolDescription",
|
|
43
|
+
"ShortTermMemoryService",
|
|
44
|
+
"StructuredOutputDataExtractor",
|
|
45
|
+
"StructuredOutputDataExtractorConfig",
|
|
40
46
|
]
|
|
41
47
|
|
|
42
48
|
# Add langchain-specific exports if available
|
|
@@ -14,7 +14,6 @@ from unique_toolkit._common.docx_generator.schemas import (
|
|
|
14
14
|
RunField,
|
|
15
15
|
RunsField,
|
|
16
16
|
)
|
|
17
|
-
from unique_toolkit.services import KnowledgeBaseService
|
|
18
17
|
|
|
19
18
|
generator_dir_path = Path(__file__).resolve().parent
|
|
20
19
|
|
|
@@ -23,13 +22,9 @@ _LOGGER = logging.getLogger(__name__)
|
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class DocxGeneratorService:
|
|
26
|
-
def __init__(
|
|
27
|
-
self,
|
|
28
|
-
knowledge_base_service: KnowledgeBaseService,
|
|
29
|
-
config: DocxGeneratorConfig,
|
|
30
|
-
):
|
|
31
|
-
self._knowledge_base_service = knowledge_base_service
|
|
25
|
+
def __init__(self, config: DocxGeneratorConfig, *, template: bytes | None = None):
|
|
32
26
|
self._config = config
|
|
27
|
+
self._template = template
|
|
33
28
|
|
|
34
29
|
@staticmethod
|
|
35
30
|
def parse_markdown_to_list_content_fields(
|
|
@@ -190,10 +185,7 @@ class DocxGeneratorService:
|
|
|
190
185
|
subdoc_content (list[HeadingField | ParagraphField | RunsField]): The content to be added to the docx file.
|
|
191
186
|
fields (dict): Other fields to be added to the docx file. Defaults to None.
|
|
192
187
|
"""
|
|
193
|
-
|
|
194
|
-
docx_template_object = self._get_template(self._config.template_content_id)
|
|
195
|
-
|
|
196
|
-
doc = DocxTemplate(io.BytesIO(docx_template_object))
|
|
188
|
+
doc = DocxTemplate(io.BytesIO(self.template))
|
|
197
189
|
|
|
198
190
|
try:
|
|
199
191
|
content = {}
|
|
@@ -218,27 +210,7 @@ class DocxGeneratorService:
|
|
|
218
210
|
_LOGGER.error(f"Error generating docx: {e}")
|
|
219
211
|
return None
|
|
220
212
|
|
|
221
|
-
def
|
|
222
|
-
try:
|
|
223
|
-
if template_content_id:
|
|
224
|
-
_LOGGER.info(
|
|
225
|
-
f"Downloading template from content ID: {template_content_id}"
|
|
226
|
-
)
|
|
227
|
-
file_content = self._knowledge_base_service.download_content_to_bytes(
|
|
228
|
-
content_id=template_content_id
|
|
229
|
-
)
|
|
230
|
-
else:
|
|
231
|
-
_LOGGER.info("No template content ID provided. Using default template.")
|
|
232
|
-
file_content = self._get_default_template()
|
|
233
|
-
except Exception as e:
|
|
234
|
-
_LOGGER.warning(
|
|
235
|
-
f"An error occurred while downloading the template {e}. Make sure the template content ID is valid. Falling back to default template."
|
|
236
|
-
)
|
|
237
|
-
file_content = self._get_default_template()
|
|
238
|
-
|
|
239
|
-
return file_content
|
|
240
|
-
|
|
241
|
-
def _get_default_template(self):
|
|
213
|
+
def _get_default_template(self) -> bytes:
|
|
242
214
|
generator_dir_path = Path(__file__).resolve().parent
|
|
243
215
|
path = generator_dir_path / "template" / "Doc Template.docx"
|
|
244
216
|
|
|
@@ -247,3 +219,7 @@ class DocxGeneratorService:
|
|
|
247
219
|
_LOGGER.info("Template downloaded from default template")
|
|
248
220
|
|
|
249
221
|
return file_content
|
|
222
|
+
|
|
223
|
+
@property
|
|
224
|
+
def template(self) -> bytes:
|
|
225
|
+
return self._template or self._get_default_template()
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jinja2 import Template
|
|
4
|
+
|
|
5
|
+
from unique_toolkit._common.utils.jinja.schema import Jinja2PromptParams
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def render_template(
    template: str, params: Jinja2PromptParams | dict[str, Any] | None = None, **kwargs
) -> str:
    """Render a Jinja2 template string with the given variables.

    Args:
        template: The Jinja2 template source.
        params: Template variables. This is either a ``Jinja2PromptParams``
            model or a plain dict. A model is dumped in JSON mode with
            ``None`` fields excluded.
        **kwargs: Extra template variables. They override keys of the same
            name coming from ``params``.

    Returns:
        The rendered string. The template is compiled with
        ``lstrip_blocks=True``.
    """
    if isinstance(params, Jinja2PromptParams):
        # model_dump already returns a fresh dict, so it is safe to update.
        context = params.model_dump(exclude_none=True, mode="json")
    else:
        # Copy the caller's dict so that merging ``kwargs`` below never
        # mutates an object owned by the caller.
        context = dict(params or {})

    context.update(kwargs)

    return Template(template, lstrip_blocks=True).render(**context)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from datetime import date, datetime
|
|
2
|
+
from typing import Annotated, Any
|
|
3
|
+
|
|
4
|
+
from jinja2 import Template
|
|
5
|
+
from pydantic import (
|
|
6
|
+
BaseModel,
|
|
7
|
+
ConfigDict,
|
|
8
|
+
Field,
|
|
9
|
+
SerializerFunctionWrapHandler,
|
|
10
|
+
WrapSerializer,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from unique_toolkit.agentic.tools.tool import Tool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Jinja2PromptParams(BaseModel):
    """Base model for parameter sets that are rendered into Jinja2 prompt templates.

    String fields have surrounding whitespace stripped during validation.
    """

    model_config = ConfigDict(str_strip_whitespace=True)

    def render_template(self, template: str) -> str:
        """Render ``template`` using this model's fields as template variables.

        Fields whose value is ``None`` are omitted, and values are dumped in
        pydantic's JSON mode. Block tags are left-stripped (``lstrip_blocks``).
        """
        template_variables = self.model_dump(exclude_none=True, mode="json")
        compiled = Template(template, lstrip_blocks=True)
        return compiled.render(**template_variables)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ToolPromptParams(Jinja2PromptParams):
    """Prompt-template view of a single tool: its name and its prompt fragments."""

    name: str
    tool_description_for_system_prompt: str = ""
    tool_format_information_for_system_prompt: str = ""
    tool_format_reminder_for_user_prompt: str = ""

    @classmethod
    def from_tool(cls, tool: Tool) -> "ToolPromptParams":
        """Build the parameters by asking ``tool`` for each of its prompt fragments."""
        tool_name = tool.name
        description = tool.tool_description_for_system_prompt()
        format_information = tool.tool_format_information_for_system_prompt()
        format_reminder = tool.tool_format_reminder_for_user_prompt()

        return cls(
            name=tool_name,
            tool_description_for_system_prompt=description,
            tool_format_information_for_system_prompt=format_information,
            tool_format_reminder_for_user_prompt=format_reminder,
        )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def serialize_iso8601_date(v: Any, handler: SerializerFunctionWrapHandler) -> str:
    """Serialize ``date`` values as ISO 8601 strings; defer anything else to pydantic.

    ``datetime`` is a subclass of ``date``, so a ``datetime`` value is also
    emitted via its own ``isoformat()``, which includes the time part.
    """
    if not isinstance(v, date):
        return handler(v)
    return v.isoformat()


# A ``date`` that always serializes to its ISO 8601 string form.
ISO8601Date = Annotated[
    date,
    WrapSerializer(serialize_iso8601_date, return_type=str),
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AgentSystemPromptParams(Jinja2PromptParams):
    """Variables available to an agent's system-prompt Jinja2 template."""

    # Required field, but it may explicitly be None. Serialized as an ISO 8601 string.
    info_cutoff_at: ISO8601Date | None
    # Defaults to today's date from the local (naive) clock at validation time.
    current_date: ISO8601Date = Field(default_factory=lambda: datetime.now().date())
    # NOTE(review): presumably all tools offered to the agent vs. those already
    # used in the current run. Confirm against the consuming template.
    tools: list[ToolPromptParams]
    used_tools: list[ToolPromptParams]
    add_citation_appendix: bool = True
    max_tools_per_iteration: int
    max_loop_iterations: int
    current_iteration: int


class AgentUserPromptParams(Jinja2PromptParams):
    """Variables available to an agent's user-prompt Jinja2 template."""

    user_prompt: str
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from jinja2 import Environment
|
|
2
|
+
from jinja2.nodes import Const, Getattr, Getitem, Name
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TemplateValidationResult(BaseModel):
    """Outcome of checking a Jinja template's placeholders against expected ones."""

    # True when no required placeholder is missing (unexpected ones do not count).
    is_valid: bool
    # Required placeholders that the template does not reference, sorted.
    missing_placeholders: list[str]
    # Optional placeholders that the template does reference, sorted.
    optional_placeholders: list[str]
    # Referenced names that are neither required nor optional, sorted.
    unexpected_placeholders: list[str]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_nested_variables(node) -> set[str]:
    """Recursively extract all variable references from a Jinja2 AST node.

    Collected names:
    - the name of every ``Name`` node (e.g. ``example``);
    - for a ``Getattr``/``Getitem`` applied directly to a ``Name``, a dotted
      form: ``example.category`` for ``example.category`` or for a constant
      key/index such as ``example["category"]``. A dynamic index yields only
      the base name.

    Child nodes are always visited, so the base name of a dotted access is
    reported as well: ``{{ example.category }}`` yields both ``example`` and
    ``example.category``.

    NOTE(review): ``Name`` nodes are collected regardless of their context.
    Loop targets and ``{% set %}`` variables (e.g. ``substring`` in
    ``{% for substring in substrings %}``) are therefore reported as
    variables too.
    """
    variables = set()

    if isinstance(node, Name):
        variables.add(node.name)
    elif isinstance(node, (Getattr, Getitem)):
        # For nested attributes like example.category
        if isinstance(node.node, Name):
            if isinstance(node, Getattr):
                variables.add(f"{node.node.name}.{node.attr}")
            else:  # Getitem
                if isinstance(node.arg, Const):
                    variables.add(f"{node.node.name}.{node.arg.value}")
                else:
                    # For dynamic indices, just use the base variable
                    variables.add(node.node.name)
        # Recursively process nested nodes.
        # NOTE(review): node.node also appears to be visited by the child loop
        # below, so this looks like repeated work with no effect on the result.
        variables.update(_get_nested_variables(node.node))

    # Process child nodes
    for child in node.iter_child_nodes():
        variables.update(_get_nested_variables(child))

    return variables
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def validate_template_placeholders(
    template_content: str,
    required_placeholders: set[str],
    optional_placeholders: set[str],
) -> TemplateValidationResult:
    """Compare the placeholders referenced by a Jinja template with the expected ones.

    Both top-level names and nested references (e.g. ``example.category``)
    are considered.

    Args:
        template_content: Source text of the Jinja template.
        required_placeholders: Names the template must reference.
        optional_placeholders: Names the template may reference.

    Returns:
        A ``TemplateValidationResult``. ``is_valid`` is True exactly when no
        required placeholder is missing. All name lists are sorted.
    """
    syntax_tree = Environment().parse(template_content)
    referenced = _get_nested_variables(syntax_tree)

    missing = required_placeholders - referenced
    declared = required_placeholders | optional_placeholders

    return TemplateValidationResult(
        is_valid=not missing,
        missing_placeholders=sorted(missing),
        optional_placeholders=sorted(optional_placeholders & referenced),
        unexpected_placeholders=sorted(referenced - declared),
    )
|
|
@@ -5,6 +5,7 @@ Target of the method is to extend the step tracking on all levels of the tool.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from collections import defaultdict
|
|
8
|
+
from logging import getLogger
|
|
8
9
|
|
|
9
10
|
from unique_toolkit.chat.schemas import (
|
|
10
11
|
MessageLogDetails,
|
|
@@ -14,6 +15,8 @@ from unique_toolkit.chat.schemas import (
|
|
|
14
15
|
from unique_toolkit.chat.service import ChatService
|
|
15
16
|
from unique_toolkit.content.schemas import ContentReference
|
|
16
17
|
|
|
18
|
+
_LOGGER = getLogger(__name__)
|
|
19
|
+
|
|
17
20
|
# Per-request counters for message log ordering - keyed by message_id
|
|
18
21
|
# This is a mandatory global variable since we have in the system a bug which makes it impossible to use it as a proper class variable.
|
|
19
22
|
_request_counters = defaultdict(int)
|
|
@@ -71,6 +74,12 @@ class MessageStepLogger:
|
|
|
71
74
|
"""
|
|
72
75
|
|
|
73
76
|
# Creating a new message log entry with the found hits.
|
|
77
|
+
if not self._chat_service._assistant_message_id:
|
|
78
|
+
_LOGGER.warning(
|
|
79
|
+
"Assistant message id is not set. Skipping message log entry creation."
|
|
80
|
+
)
|
|
81
|
+
return
|
|
82
|
+
|
|
74
83
|
_ = self._chat_service.create_message_log(
|
|
75
84
|
message_id=self._chat_service._assistant_message_id,
|
|
76
85
|
text=text,
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import re
|
|
2
|
-
from typing import Literal
|
|
2
|
+
from typing import Literal, NamedTuple
|
|
3
3
|
|
|
4
|
+
from unique_toolkit._common.utils.jinja.render import render_template
|
|
4
5
|
from unique_toolkit.agentic.tools.a2a.postprocessing.config import (
|
|
5
6
|
SubAgentDisplayConfig,
|
|
6
7
|
SubAgentResponseDisplayMode,
|
|
@@ -84,6 +85,12 @@ def _prepare_title_template(
|
|
|
84
85
|
return display_title_template.replace("{}", "{%s}" % display_name_placeholder)
|
|
85
86
|
|
|
86
87
|
|
|
88
|
+
def _clean_linebreaks(text: str) -> str:
|
|
89
|
+
text = text.strip()
|
|
90
|
+
text = re.sub(r"^(<br>)*|(<br>)*$", "", text)
|
|
91
|
+
return text
|
|
92
|
+
|
|
93
|
+
|
|
87
94
|
def _get_display_template(
|
|
88
95
|
mode: SubAgentResponseDisplayMode,
|
|
89
96
|
add_quote_border: bool,
|
|
@@ -126,7 +133,7 @@ def _get_display_template(
|
|
|
126
133
|
if add_block_border:
|
|
127
134
|
template = _wrap_with_block_border(template)
|
|
128
135
|
|
|
129
|
-
return template
|
|
136
|
+
return _clean_linebreaks(template)
|
|
130
137
|
|
|
131
138
|
|
|
132
139
|
def _get_display_removal_re(
|
|
@@ -150,10 +157,51 @@ def _get_display_removal_re(
|
|
|
150
157
|
return re.compile(pattern, flags=re.DOTALL)
|
|
151
158
|
|
|
152
159
|
|
|
160
|
+
class SubAgentAnswerPart(NamedTuple):
    """One displayable piece of a sub-agent answer.

    ``matching_text`` is also the key used when duplicate answers are filtered.
    """

    matching_text: str  # Matching text as found in the answer
    formatted_text: str  # Formatted text to be displayed
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def get_sub_agent_answer_parts(
    answer: str,
    display_config: SubAgentDisplayConfig,
) -> list[SubAgentAnswerPart]:
    """Split a sub-agent answer into the parts that should be displayed.

    - ``HIDDEN`` mode: nothing is shown, so no parts are returned.
    - No substring configs: the whole answer is one part, left unformatted.
    - Otherwise: the whole match (``group(0)``) of each regex occurrence becomes
      a part, formatted with that config's ``display_template``. Parts are
      ordered by config first, then by position within the answer.
    """
    if display_config.mode == SubAgentResponseDisplayMode.HIDDEN:
        return []

    substring_configs = display_config.answer_substrings_config
    if not substring_configs:
        return [SubAgentAnswerPart(matching_text=answer, formatted_text=answer)]

    return [
        SubAgentAnswerPart(
            matching_text=found.group(0),
            formatted_text=substring_config.display_template.format(found.group(0)),
        )
        for substring_config in substring_configs
        for found in substring_config.regexp.finditer(answer)
    ]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_sub_agent_answer_from_parts(
    answer_parts: list[SubAgentAnswerPart],
    config: SubAgentDisplayConfig,
) -> str:
    """Combine answer parts into one string using the configured Jinja template.

    The template receives a single variable, ``substrings``: the
    ``formatted_text`` of every part, in order.
    """
    jinja_template = config.answer_substrings_jinja_template
    formatted_texts = [part.formatted_text for part in answer_parts]
    return render_template(jinja_template, {"substrings": formatted_texts})
|
|
199
|
+
|
|
200
|
+
|
|
153
201
|
def get_sub_agent_answer_display(
|
|
154
202
|
display_name: str,
|
|
155
203
|
display_config: SubAgentDisplayConfig,
|
|
156
|
-
answer: str,
|
|
204
|
+
answer: str | list[SubAgentAnswerPart],
|
|
157
205
|
assistant_id: str,
|
|
158
206
|
) -> str:
|
|
159
207
|
template = _get_display_template(
|
|
@@ -162,6 +210,13 @@ def get_sub_agent_answer_display(
|
|
|
162
210
|
add_block_border=display_config.add_block_border,
|
|
163
211
|
display_title_template=display_config.display_title_template,
|
|
164
212
|
)
|
|
213
|
+
|
|
214
|
+
if isinstance(answer, list):
|
|
215
|
+
answer = get_sub_agent_answer_from_parts(
|
|
216
|
+
answer_parts=answer,
|
|
217
|
+
config=display_config,
|
|
218
|
+
)
|
|
219
|
+
|
|
165
220
|
return template.format(
|
|
166
221
|
display_name=display_name, answer=answer, assistant_id=assistant_id
|
|
167
222
|
)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from typing import Callable, Iterable, Mapping, Sequence
|
|
2
3
|
|
|
3
4
|
from unique_toolkit._common.referencing import get_reference_pattern
|
|
@@ -50,6 +51,16 @@ def add_content_refs(
|
|
|
50
51
|
return message_refs
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
def remove_unused_refs(
    references: Sequence[ContentReference],
    text: str,
    ref_pattern_f: Callable[[int], str] = get_reference_pattern,
) -> list[ContentReference]:
    """Return only the references whose marker actually occurs in ``text``.

    The input sequence is not modified; a new list is returned, preserving
    order.

    Args:
        references: Candidate references.
        text: Text that is searched for reference markers.
        ref_pattern_f: Maps a reference's ``sequence_number`` to the regex
            pattern of its marker in ``text``.
    """
    used: list[ContentReference] = []
    for reference in references:
        marker_pattern = ref_pattern_f(reference.sequence_number)
        if re.search(marker_pattern, text) is not None:
            used.append(reference)
    return used
|
|
62
|
+
|
|
63
|
+
|
|
53
64
|
def add_content_refs_and_replace_in_text(
|
|
54
65
|
message_text: str,
|
|
55
66
|
message_refs: Sequence[ContentReference],
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from enum import StrEnum
|
|
2
3
|
from typing import Literal
|
|
3
4
|
|
|
@@ -13,6 +14,25 @@ class SubAgentResponseDisplayMode(StrEnum):
|
|
|
13
14
|
PLAIN = "plain"
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
class SubAgentAnswerSubstringConfig(BaseModel):
    """How to extract one kind of substring from a sub-agent answer and display it."""

    model_config = get_configuration_dict()

    # Every occurrence is taken as its whole match (``match.group(0)``), not a
    # capture group. The description below previously claimed otherwise.
    regexp: re.Pattern[str] = Field(
        description="The regular expression to use to extract the substrings. The entire match of every occurrence will be used.",
    )
    # Applied with ``str.format``, so any other literal braces must be doubled.
    display_template: str = Field(
        default="{}",
        description="The template to use to display the substring. It should contain exactly one empty placeholder '{}' for the substring.",
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
_ANSWER_SUBSTRINGS_JINJA_TEMPLATE = """
|
|
30
|
+
{% for substring in substrings %}
|
|
31
|
+
{{ substring }}
|
|
32
|
+
{% endfor %}
|
|
33
|
+
""".strip()
|
|
34
|
+
|
|
35
|
+
|
|
16
36
|
class SubAgentDisplayConfig(BaseModel):
|
|
17
37
|
model_config = get_configuration_dict()
|
|
18
38
|
|
|
@@ -43,3 +63,16 @@ class SubAgentDisplayConfig(BaseModel):
|
|
|
43
63
|
default="before",
|
|
44
64
|
description="The position of the sub agent response in the main agent response.",
|
|
45
65
|
)
|
|
66
|
+
force_include_references: bool = Field(
|
|
67
|
+
default=False,
|
|
68
|
+
description="If set, the sub agent references will be added to the main agent response references even in not mentioned in the main agent response text.",
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
answer_substrings_config: list[SubAgentAnswerSubstringConfig] = Field(
|
|
72
|
+
default=[],
|
|
73
|
+
description="If set, only parts of the answer matching the provided regular expressions will be displayed.",
|
|
74
|
+
)
|
|
75
|
+
answer_substrings_jinja_template: str = Field(
|
|
76
|
+
default=_ANSWER_SUBSTRINGS_JINJA_TEMPLATE,
|
|
77
|
+
description="The template to use in order to format the different answer substrings, if any.",
|
|
78
|
+
)
|
|
@@ -7,13 +7,18 @@ import unique_sdk
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
|
|
9
9
|
from unique_toolkit._common.pydantic_helpers import get_configuration_dict
|
|
10
|
+
from unique_toolkit._common.utils.jinja.render import render_template
|
|
10
11
|
from unique_toolkit.agentic.postprocessor.postprocessor_manager import Postprocessor
|
|
11
12
|
from unique_toolkit.agentic.tools.a2a.postprocessing._display_utils import (
|
|
13
|
+
SubAgentAnswerPart,
|
|
12
14
|
get_sub_agent_answer_display,
|
|
15
|
+
get_sub_agent_answer_from_parts,
|
|
16
|
+
get_sub_agent_answer_parts,
|
|
13
17
|
remove_sub_agent_answer_from_text,
|
|
14
18
|
)
|
|
15
19
|
from unique_toolkit.agentic.tools.a2a.postprocessing._ref_utils import (
|
|
16
20
|
add_content_refs_and_replace_in_text,
|
|
21
|
+
remove_unused_refs,
|
|
17
22
|
)
|
|
18
23
|
from unique_toolkit.agentic.tools.a2a.postprocessing.config import (
|
|
19
24
|
SubAgentDisplayConfig,
|
|
@@ -37,11 +42,26 @@ class SubAgentDisplaySpec(NamedTuple):
|
|
|
37
42
|
display_config: SubAgentDisplayConfig
|
|
38
43
|
|
|
39
44
|
|
|
45
|
+
_ANSWERS_JINJA_TEMPLATE = """
|
|
46
|
+
{% for answer in answers %}
|
|
47
|
+
{{ answer }}
|
|
48
|
+
{% endfor %}
|
|
49
|
+
""".strip()
|
|
50
|
+
|
|
51
|
+
|
|
40
52
|
class SubAgentResponsesPostprocessorConfig(BaseModel):
    """Settings for how sub-agent answers are merged into the main agent message."""

    model_config = get_configuration_dict()

    sleep_time_before_update: float = Field(
        default=0, description="Time to sleep before updating the main agent message."
    )
    # Jinja2 template rendered with an ``answers`` variable: the list of
    # rendered sub-agent displays placed before or after the main text.
    answers_jinja_template: str = Field(
        default=_ANSWERS_JINJA_TEMPLATE,
        description="The template to use to display the sub agent answers.",
    )
    # Deduplication compares each answer part's matching text across all
    # sub-agent responses processed in the same run.
    filter_duplicate_answers: bool = Field(
        default=True,
        description="If set, duplicate answers will be filtered out.",
    )
|
|
46
66
|
|
|
47
67
|
|
|
@@ -60,6 +80,8 @@ class SubAgentResponsesDisplayPostprocessor(Postprocessor):
|
|
|
60
80
|
display_spec.assistant_id: display_spec
|
|
61
81
|
for display_spec in display_specs
|
|
62
82
|
if display_spec.display_config.mode != SubAgentResponseDisplayMode.HIDDEN
|
|
83
|
+
# We should keep track of these messages even if they are hidden
|
|
84
|
+
or display_spec.display_config.force_include_references
|
|
63
85
|
}
|
|
64
86
|
|
|
65
87
|
@override
|
|
@@ -93,19 +115,23 @@ class SubAgentResponsesDisplayPostprocessor(Postprocessor):
|
|
|
93
115
|
|
|
94
116
|
answers_displayed_before = []
|
|
95
117
|
answers_displayed_after = []
|
|
118
|
+
all_displayed_answers = set()
|
|
96
119
|
|
|
97
120
|
for assistant_id, responses in displayed_sub_agent_responses.items():
|
|
121
|
+
tool_info = self._display_specs[assistant_id]
|
|
122
|
+
tool_name = tool_info.display_name
|
|
123
|
+
|
|
98
124
|
for response in responses:
|
|
99
125
|
message = response.message
|
|
100
|
-
tool_info = self._display_specs[assistant_id]
|
|
101
|
-
|
|
102
|
-
_add_response_references_to_message_in_place(
|
|
103
|
-
loop_response=loop_response, response=message
|
|
104
|
-
)
|
|
105
126
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
127
|
+
if tool_info.display_config.mode == SubAgentResponseDisplayMode.HIDDEN:
|
|
128
|
+
# Add references and continue
|
|
129
|
+
_add_response_references_to_message_in_place(
|
|
130
|
+
loop_response=loop_response,
|
|
131
|
+
response=message,
|
|
132
|
+
remove_unused_references=False,
|
|
133
|
+
)
|
|
134
|
+
continue
|
|
109
135
|
|
|
110
136
|
if message["text"] is None:
|
|
111
137
|
logger.warning(
|
|
@@ -113,11 +139,44 @@ class SubAgentResponsesDisplayPostprocessor(Postprocessor):
|
|
|
113
139
|
assistant_id,
|
|
114
140
|
response.sequence_number,
|
|
115
141
|
)
|
|
142
|
+
continue
|
|
143
|
+
|
|
144
|
+
answer_parts = get_sub_agent_answer_parts(
|
|
145
|
+
answer=message["text"],
|
|
146
|
+
display_config=tool_info.display_config,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
if self._config.filter_duplicate_answers:
|
|
150
|
+
answer_parts, all_displayed_answers = (
|
|
151
|
+
_filter_and_update_duplicate_answers(
|
|
152
|
+
answers=answer_parts,
|
|
153
|
+
existing_answers=all_displayed_answers,
|
|
154
|
+
)
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
answer = get_sub_agent_answer_from_parts(
|
|
158
|
+
answer_parts=answer_parts,
|
|
159
|
+
config=tool_info.display_config,
|
|
160
|
+
)
|
|
161
|
+
message["text"] = answer
|
|
162
|
+
|
|
163
|
+
_add_response_references_to_message_in_place(
|
|
164
|
+
loop_response=loop_response,
|
|
165
|
+
response=message,
|
|
166
|
+
remove_unused_references=not tool_info.display_config.force_include_references,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
if len(answer_parts) == 0:
|
|
170
|
+
continue
|
|
171
|
+
|
|
172
|
+
display_name = tool_name
|
|
173
|
+
if len(responses) > 1:
|
|
174
|
+
display_name = tool_name + f" {response.sequence_number}"
|
|
116
175
|
|
|
117
176
|
answer = get_sub_agent_answer_display(
|
|
118
177
|
display_name=display_name,
|
|
119
178
|
display_config=tool_info.display_config,
|
|
120
|
-
answer=
|
|
179
|
+
answer=answer,
|
|
121
180
|
assistant_id=assistant_id,
|
|
122
181
|
)
|
|
123
182
|
|
|
@@ -130,6 +189,7 @@ class SubAgentResponsesDisplayPostprocessor(Postprocessor):
|
|
|
130
189
|
text=loop_response.message.text,
|
|
131
190
|
answers_before=answers_displayed_before,
|
|
132
191
|
answers_after=answers_displayed_after,
|
|
192
|
+
template=self._config.answers_jinja_template,
|
|
133
193
|
)
|
|
134
194
|
|
|
135
195
|
return True
|
|
@@ -146,7 +206,9 @@ class SubAgentResponsesDisplayPostprocessor(Postprocessor):
|
|
|
146
206
|
|
|
147
207
|
|
|
148
208
|
def _add_response_references_to_message_in_place(
|
|
149
|
-
loop_response: LanguageModelStreamResponse,
|
|
209
|
+
loop_response: LanguageModelStreamResponse,
|
|
210
|
+
response: unique_sdk.Space.Message,
|
|
211
|
+
remove_unused_references: bool = True,
|
|
150
212
|
) -> None:
|
|
151
213
|
references = response["references"]
|
|
152
214
|
text = response["text"]
|
|
@@ -156,6 +218,12 @@ def _add_response_references_to_message_in_place(
|
|
|
156
218
|
|
|
157
219
|
content_refs = [ContentReference.from_sdk_reference(ref) for ref in references]
|
|
158
220
|
|
|
221
|
+
if remove_unused_references:
|
|
222
|
+
content_refs = remove_unused_refs(
|
|
223
|
+
references=content_refs,
|
|
224
|
+
text=text,
|
|
225
|
+
)
|
|
226
|
+
|
|
159
227
|
text, refs = add_content_refs_and_replace_in_text(
|
|
160
228
|
message_text=text,
|
|
161
229
|
message_refs=loop_response.message.references,
|
|
@@ -170,11 +238,27 @@ def _get_final_answer_display(
|
|
|
170
238
|
text: str,
|
|
171
239
|
answers_before: list[str],
|
|
172
240
|
answers_after: list[str],
|
|
173
|
-
|
|
241
|
+
template: str = _ANSWERS_JINJA_TEMPLATE,
|
|
174
242
|
) -> str:
|
|
175
243
|
if len(answers_before) > 0:
|
|
176
|
-
text =
|
|
244
|
+
text = render_template(template, {"answers": answers_before}) + text
|
|
177
245
|
|
|
178
246
|
if len(answers_after) > 0:
|
|
179
|
-
text = text +
|
|
180
|
-
|
|
247
|
+
text = text + render_template(template, {"answers": answers_after})
|
|
248
|
+
|
|
249
|
+
return text.strip()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _filter_and_update_duplicate_answers(
    answers: list[SubAgentAnswerPart],
    existing_answers: set[str],
) -> tuple[list[SubAgentAnswerPart], set[str]]:
    """Drop answer parts whose ``matching_text`` has already been seen.

    Duplicates within ``answers`` are filtered as well; only the first
    occurrence is kept.

    Note: ``existing_answers`` is updated in place with every newly kept
    ``matching_text`` and is also returned, so callers may rebind it.
    """
    kept: list[SubAgentAnswerPart] = []
    for part in answers:
        key = part.matching_text
        if key not in existing_answers:
            existing_answers.add(key)
            kept.append(part)
    return kept, existing_answers
|