unique_toolkit 0.7.25__py3-none-any.whl → 0.7.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/app/schemas.py +52 -8
- unique_toolkit/chat/functions.py +79 -4
- unique_toolkit/chat/service.py +60 -13
- unique_toolkit/content/functions.py +7 -6
- unique_toolkit/content/schemas.py +2 -2
- unique_toolkit/content/service.py +1 -1
- unique_toolkit/language_model/functions.py +142 -9
- unique_toolkit/language_model/reference.py +244 -0
- unique_toolkit/language_model/service.py +57 -1
- unique_toolkit/protocols/support.py +36 -2
- unique_toolkit/smart_rules/__init__.py +0 -0
- unique_toolkit/smart_rules/compile.py +301 -0
- {unique_toolkit-0.7.25.dist-info → unique_toolkit-0.7.27.dist-info}/METADATA +11 -1
- {unique_toolkit-0.7.25.dist-info → unique_toolkit-0.7.27.dist-info}/RECORD +16 -13
- {unique_toolkit-0.7.25.dist-info → unique_toolkit-0.7.27.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.7.25.dist-info → unique_toolkit-0.7.27.dist-info}/WHEEL +0 -0
@@ -0,0 +1,244 @@
|
|
1
|
+
import re
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from unique_toolkit.chat.schemas import ChatMessage, Reference
|
6
|
+
|
7
|
+
|
8
|
+
class NodeReference(Reference):
    """Reference enriched with bookkeeping used while rewriting a message.

    Extends the base ``Reference`` with the original bracket indexes that
    pointed at this source and the owning message id.
    """

    # 1-based bracket numbers ([X]) in the original text that resolve to this
    # source (pydantic copies the default per instance, so [] is safe here).
    original_index: list[int] = []
    # ID of the chat message the reference belongs to, if known.
    message_id: str | None = None
|
11
|
+
|
12
|
+
|
13
|
+
class PotentialReference(BaseModel):
    """A candidate source that message text may cite via ``[X]`` markers."""

    id: str
    chunk_id: str | None = None
    title: str | None = None
    key: str
    url: str | None = None
    # When set, the content is internally stored and is addressed via the
    # unique://content/ scheme instead of ``url`` (see _find_references).
    internally_stored_at: str | None = None
|
20
|
+
|
21
|
+
|
22
|
+
def add_references_to_message(
    message: ChatMessage,
    search_context: list[PotentialReference],
    model: str | None = None,
) -> tuple[ChatMessage, bool]:
    """Add references to a message and return the updated message with change status.

    The content is normalised, bracket citations like ``[3]`` are resolved
    against ``search_context`` (1-based index) and converted into
    ``<sup>N</sup>`` footnotes, and the resolved references are written back
    onto ``message.references``. The message object is mutated in place.

    Args:
        message: Chat message whose ``content`` may contain ``[X]`` citations.
        search_context: Candidate sources the citations index into.
        model: Optional model name; some models get extra post-processing.

    Returns:
        Tuple[ChatMessage, bool]: (updated_message, references_changed)

    Raises:
        ValueError: If ``message.id`` is ``None``.
    """
    if not message.content:
        return message, False

    if message.id is None:
        raise ValueError("Message ID is required")

    message.content = _preprocess_message(message.content)
    text, ref_found = _add_references(
        message.content, search_context, message.id, model
    )
    message.content = _postprocess_message(text)

    # Down-cast NodeReference to the plain Reference expected on the message.
    message.references = [Reference(**ref.model_dump()) for ref in ref_found]
    references_changed = len(ref_found) > 0
    return message, references_changed
|
47
|
+
|
48
|
+
|
49
|
+
def _add_references(
    text: str,
    search_context: list[PotentialReference],
    message_id: str,
    model: str | None = None,
) -> tuple[str, list[NodeReference]]:
    """Resolve ``[X]`` citations in *text* into superscript footnotes.

    Returns:
        tuple[str, list[NodeReference]]: (processed_text, found_references)
    """
    references = _find_references(
        text=text,
        search_context=search_context,
        message_id=message_id,
    )

    # Only reference a source once, even if it is mentioned multiple times in the text.
    with_footnotes = _add_footnotes_to_text(text=text, references=references)

    # Gemini 2.5 flash model has tendency to add multiple references for the same fact
    # This is a workaround to limit the number of references to 5
    if model and model.startswith("litellm:gemini-2-5-flash"):
        reduced_text = _limit_consecutive_source_references(with_footnotes)

        # Get the references that remain after reduction
        remaining_numbers = set()
        sup_matches = re.findall(r"<sup>(\d+)</sup>", reduced_text)
        remaining_numbers = {int(match) for match in sup_matches}

        # Drop references whose footnote was removed by the reduction above.
        references = [
            ref for ref in references if ref.sequence_number in remaining_numbers
        ]
        text = _remove_hallucinated_references(reduced_text)
    else:
        text = _remove_hallucinated_references(with_footnotes)

    return text, references
|
87
|
+
|
88
|
+
|
89
|
+
def _preprocess_message(text: str) -> str:
|
90
|
+
"""Preprocess message text to normalize reference formats."""
|
91
|
+
# Remove user & assistant references: XML format '[<user>]', '[\<user>]', etc.
|
92
|
+
patterns = [
|
93
|
+
(r"\[(\\)?(<)?user(>)?\]", ""),
|
94
|
+
(r"\[(\\)?(<)?assistant(>)?\]", ""),
|
95
|
+
(r"source[\s]?\[(\\)?(<)?conversation(>)?\]", "the previous conversation"),
|
96
|
+
(r"\[(\\)?(<)?previous[_,\s]conversation(>)?\]", ""),
|
97
|
+
(r"\[(\\)?(<)?past[_,\s]conversation(>)?\]", ""),
|
98
|
+
(r"\[(\\)?(<)?previous[_,\s]?answer(>)?\]", ""),
|
99
|
+
(r"\[(\\)?(<)?previous[_,\s]question(>)?\]", ""),
|
100
|
+
(r"\[(\\)?(<)?conversation(>)?\]", ""),
|
101
|
+
(r"\[(\\)?(<)?none(>)?\]", ""),
|
102
|
+
]
|
103
|
+
|
104
|
+
for pattern, replacement in patterns:
|
105
|
+
text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
|
106
|
+
|
107
|
+
# Replace XML format '[<source XX>]', '[<sourceXX>]' and '[\<sourceXX>]' with [XX]
|
108
|
+
text = re.sub(r"\[(\\)?<source[\s]?(\d+)>\]", r"[\2]", text)
|
109
|
+
|
110
|
+
# Replace format '[source XX]' and '[sourceXX]' with [XX]
|
111
|
+
text = re.sub(r"\[source[\s]?(\d+)\]", r"[\1]", text)
|
112
|
+
|
113
|
+
# Make all references non-bold
|
114
|
+
text = re.sub(r"\[\*\*(\d+)\*\*\]", r"[\1]", text)
|
115
|
+
|
116
|
+
# Replace 'SOURCEXX' and 'SOURCE XX' with [XX]
|
117
|
+
text = re.sub(r"source[\s]?(\d+)", r"[\1]", text, flags=re.IGNORECASE)
|
118
|
+
|
119
|
+
# Replace 'SOURCE n°X' with [XX]
|
120
|
+
text = re.sub(r"source[\s]?n°(\d+)", r"[\1]", text, flags=re.IGNORECASE)
|
121
|
+
|
122
|
+
# Replace '[<[XX]>]' and '[\<[XX]>]' with [XX]
|
123
|
+
text = re.sub(r"\[(\\)?\[?<\[(\d+)\]?\]>\]", r"[\2]", text)
|
124
|
+
|
125
|
+
# Replace '[[A], [B], ...]' or '[[A], B, C, ...]' with [A][B][C]...
|
126
|
+
def replace_combined_brackets(match):
|
127
|
+
numbers = re.findall(r"\d+", match.group(0))
|
128
|
+
return "".join(f"[{n}]" for n in numbers)
|
129
|
+
|
130
|
+
text = re.sub(
|
131
|
+
r"\[\[(\d+)\](?:,\s*(?:\[)?\d+(?:\])?)*\]", replace_combined_brackets, text
|
132
|
+
)
|
133
|
+
|
134
|
+
return text
|
135
|
+
|
136
|
+
|
137
|
+
def _limit_consecutive_source_references(text: str) -> str:
|
138
|
+
"""Limit consecutive source references to maximum 5 unique sources."""
|
139
|
+
|
140
|
+
def replace_consecutive(match):
|
141
|
+
# Extract all numbers from the match and get unique values
|
142
|
+
numbers = list(set(re.findall(r"\d+", match.group(0))))
|
143
|
+
# Take only the first five unique numbers
|
144
|
+
return "".join(f"<sup>{n}</sup>" for n in numbers[:5])
|
145
|
+
|
146
|
+
# Find sequences of 5+ consecutive sources
|
147
|
+
pattern = r"(?:<sup>\d+</sup>){5,}"
|
148
|
+
return re.sub(pattern, replace_consecutive, text)
|
149
|
+
|
150
|
+
|
151
|
+
def _postprocess_message(text: str) -> str:
|
152
|
+
"""Format superscript references to remove duplicates."""
|
153
|
+
|
154
|
+
def replace_sup_sequence(match):
|
155
|
+
# Extract unique numbers from the entire match
|
156
|
+
sup_numbers = set(re.findall(r"\d+", match.group(0)))
|
157
|
+
return "".join(f"<sup>{n}</sup>" for n in sup_numbers)
|
158
|
+
|
159
|
+
# Find sequences of 2+ superscripts including internal spaces
|
160
|
+
pattern = r"(<sup>\d+</sup>[ ]*)+<sup>\d+</sup>"
|
161
|
+
return re.sub(pattern, replace_sup_sequence, text)
|
162
|
+
|
163
|
+
|
164
|
+
def _get_max_sub_count_in_text(text: str) -> int:
|
165
|
+
"""Get the maximum superscript number in the text."""
|
166
|
+
matches = re.findall(r"<sup>(\d+)</sup>", text)
|
167
|
+
return max((int(match) for match in matches), default=0)
|
168
|
+
|
169
|
+
|
170
|
+
def _find_references(
    text: str,
    search_context: list[PotentialReference],
    message_id: str,
) -> list[NodeReference]:
    """Find references in text based on search context.

    Bracket citations ``[X]`` are 1-based indexes into ``search_context``.
    Out-of-range numbers are skipped; repeated citations of the same source
    are merged into one reference that remembers every original index.
    """
    references: list[NodeReference] = []
    # Continue numbering after any <sup> footnotes already present in the text.
    sequence_number = 1 + _get_max_sub_count_in_text(text)

    # Find all numbers in brackets to ensure we get references in order of occurrence
    numbers_in_brackets = _extract_numbers_in_brackets(text)

    for number in numbers_in_brackets:
        # Convert 1-based reference to 0-based index
        index = number - 1
        if index < 0 or index >= len(search_context):
            continue

        search = search_context[index]
        if not search:
            continue

        # Don't put the reference twice
        reference_name = search.title or search.key
        found_reference = next(
            (r for r in references if r.name == reference_name), None
        )

        if found_reference:
            found_reference.original_index.append(number)
            continue

        # Externally hosted sources keep their URL; internally stored content
        # is addressed via the unique://content/ scheme instead.
        url = (
            search.url
            if search.url and not search.internally_stored_at
            else f"unique://content/{search.id}"
        )

        references.append(
            NodeReference(
                name=reference_name,
                url=url,
                sequence_number=sequence_number,
                original_index=[number],
                source_id=f"{search.id}_{search.chunk_id}"
                if search.chunk_id
                else search.id,
                source="node-ingestion-chunks",
                message_id=message_id,
            )
        )
        sequence_number += 1

    return references
|
224
|
+
|
225
|
+
|
226
|
+
def _extract_numbers_in_brackets(text: str) -> list[int]:
|
227
|
+
"""Extract numbers from [X] format in text."""
|
228
|
+
matches = re.findall(r"\[(\d+)\]", text)
|
229
|
+
return [int(match) for match in matches]
|
230
|
+
|
231
|
+
|
232
|
+
def _add_footnotes_to_text(text: str, references: list[NodeReference]) -> str:
    """Replace each cited ``[X]`` with the superscript of its resolved reference."""
    for reference in references:
        footnote = f"<sup>{reference.sequence_number}</sup>"
        for index in reference.original_index:
            text = text.replace(f"[{index}]", footnote)
    return text
|
240
|
+
|
241
|
+
|
242
|
+
def _remove_hallucinated_references(text: str) -> str:
|
243
|
+
"""Remove any remaining bracket references that weren't converted."""
|
244
|
+
return re.sub(r"\[\d+\]", "", text).strip()
|
@@ -1,11 +1,12 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import Optional, Type
|
2
|
+
from typing import Any, Optional, Type
|
3
3
|
|
4
4
|
from pydantic import BaseModel
|
5
5
|
from typing_extensions import deprecated
|
6
6
|
|
7
7
|
from unique_toolkit._common.validate_required_values import validate_required_values
|
8
8
|
from unique_toolkit.app.schemas import BaseEvent, ChatEvent, Event
|
9
|
+
from unique_toolkit.content.schemas import ContentChunk
|
9
10
|
from unique_toolkit.language_model.constants import (
|
10
11
|
DEFAULT_COMPLETE_TEMPERATURE,
|
11
12
|
DEFAULT_COMPLETE_TIMEOUT,
|
@@ -14,11 +15,14 @@ from unique_toolkit.language_model.constants import (
|
|
14
15
|
from unique_toolkit.language_model.functions import (
|
15
16
|
complete,
|
16
17
|
complete_async,
|
18
|
+
complete_with_references,
|
19
|
+
complete_with_references_async,
|
17
20
|
)
|
18
21
|
from unique_toolkit.language_model.infos import LanguageModelName
|
19
22
|
from unique_toolkit.language_model.schemas import (
|
20
23
|
LanguageModelMessages,
|
21
24
|
LanguageModelResponse,
|
25
|
+
LanguageModelStreamResponse,
|
22
26
|
LanguageModelTool,
|
23
27
|
LanguageModelToolDescription,
|
24
28
|
)
|
@@ -260,3 +264,55 @@ class LanguageModelService:
|
|
260
264
|
structured_output_model=structured_output_model,
|
261
265
|
structured_output_enforce_schema=structured_output_enforce_schema,
|
262
266
|
)
|
267
|
+
|
268
|
+
    def complete_with_references(
        self,
        messages: LanguageModelMessages,
        model_name: LanguageModelName | str,
        content_chunks: list[ContentChunk] | None = None,
        debug_info: dict = {},
        temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
        timeout: int = DEFAULT_COMPLETE_TIMEOUT,
        tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
        start_text: str | None = None,
        other_options: dict[str, Any] | None = None,
    ) -> LanguageModelStreamResponse:
        """Complete the given messages, resolving source references in the reply.

        Thin wrapper around ``functions.complete_with_references`` that
        supplies this service's company id.

        NOTE(review): ``debug_info`` has a mutable default ({}) and is not
        forwarded to the underlying call — confirm whether it is still needed.

        Raises:
            Whatever ``validate_required_values`` raises when the company id
            is missing.
        """
        [company_id] = validate_required_values([self._company_id])

        return complete_with_references(
            company_id=company_id,
            messages=messages,
            model_name=model_name,
            content_chunks=content_chunks,
            temperature=temperature,
            timeout=timeout,
            other_options=other_options,
            tools=tools,
            start_text=start_text,
        )
|
293
|
+
|
294
|
+
    async def complete_with_references_async(
        self,
        messages: LanguageModelMessages,
        model_name: LanguageModelName | str,
        content_chunks: list[ContentChunk] | None = None,
        debug_info: dict = {},
        temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
        timeout: int = DEFAULT_COMPLETE_TIMEOUT,
        tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
        start_text: str | None = None,
        other_options: dict[str, Any] | None = None,
    ) -> LanguageModelStreamResponse:
        """Async variant of ``complete_with_references``.

        Thin wrapper around ``functions.complete_with_references_async`` that
        supplies this service's company id.

        NOTE(review): ``debug_info`` has a mutable default ({}) and is not
        forwarded to the underlying call — confirm whether it is still needed.
        """
        [company_id] = validate_required_values([self._company_id])

        return await complete_with_references_async(
            company_id=company_id,
            messages=messages,
            model_name=model_name,
            content_chunks=content_chunks,
            temperature=temperature,
            timeout=timeout,
            other_options=other_options,
            tools=tools,
            start_text=start_text,
        )
|
@@ -1,9 +1,11 @@
|
|
1
|
-
from typing import Protocol
|
1
|
+
from typing import Any, Awaitable, Protocol
|
2
2
|
|
3
|
+
from unique_toolkit.content import ContentChunk
|
3
4
|
from unique_toolkit.language_model import (
|
4
5
|
LanguageModelMessages,
|
5
6
|
LanguageModelName,
|
6
7
|
LanguageModelResponse,
|
8
|
+
LanguageModelStreamResponse,
|
7
9
|
LanguageModelTool,
|
8
10
|
LanguageModelToolDescription,
|
9
11
|
)
|
@@ -25,5 +27,37 @@ class SupportsComplete(Protocol):
|
|
25
27
|
temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
|
26
28
|
timeout: int = DEFAULT_COMPLETE_TIMEOUT,
|
27
29
|
tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
|
28
|
-
**kwargs,
|
29
30
|
) -> LanguageModelResponse: ...
|
31
|
+
|
32
|
+
    # NOTE(review): declared ``async def`` *and* annotated as returning
    # ``Awaitable`` — a conforming implementation would effectively be awaited
    # twice. Confirm whether this should be a plain ``def`` returning
    # ``Awaitable[LanguageModelResponse]`` or an ``async def`` returning
    # ``LanguageModelResponse``.
    async def complete_async(
        self,
        messages: LanguageModelMessages,
        model_name: LanguageModelName | str,
        temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
        timeout: int = DEFAULT_COMPLETE_TIMEOUT,
        tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
    ) -> Awaitable[LanguageModelResponse]: ...
|
40
|
+
|
41
|
+
|
42
|
+
class SupportCompleteWithReferences(Protocol):
    """Structural type for services that stream completions with references.

    Matched by duck typing (``typing.Protocol``); implementations are not
    required to subclass it.
    """

    def complete_with_references(
        self,
        messages: LanguageModelMessages,
        model_name: LanguageModelName | str,
        content_chunks: list[ContentChunk] | None = None,
        debug_info: dict[str, Any] = {},
        temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
        timeout: int = DEFAULT_COMPLETE_TIMEOUT,
        tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
    ) -> LanguageModelStreamResponse: ...

    def complete_with_references_async(
        self,
        messages: LanguageModelMessages,
        model_name: LanguageModelName | str,
        content_chunks: list[ContentChunk] | None = None,
        debug_info: dict[str, Any] = {},
        temperature: float = DEFAULT_COMPLETE_TEMPERATURE,
        timeout: int = DEFAULT_COMPLETE_TIMEOUT,
        tools: list[LanguageModelTool | LanguageModelToolDescription] | None = None,
    ) -> Awaitable[LanguageModelStreamResponse]: ...
|
File without changes
|
@@ -0,0 +1,301 @@
|
|
1
|
+
import re
|
2
|
+
from datetime import datetime, timedelta, timezone
|
3
|
+
from enum import Enum
|
4
|
+
from typing import Any, Dict, List, Self, Union
|
5
|
+
|
6
|
+
from pydantic import AliasChoices, BaseModel, Field
|
7
|
+
from pydantic.config import ConfigDict
|
8
|
+
|
9
|
+
|
10
|
+
class Operator(str, Enum):
    """Comparison and structural operators available in UniqueQL statements.

    Values are the camelCase strings used in the serialized JSON form.
    """

    EQUALS = "equals"
    NOT_EQUALS = "notEquals"
    GREATER_THAN = "greaterThan"
    GREATER_THAN_OR_EQUAL = "greaterThanOrEqual"
    LESS_THAN = "lessThan"
    LESS_THAN_OR_EQUAL = "lessThanOrEqual"
    IN = "in"
    NOT_IN = "notIn"
    CONTAINS = "contains"
    NOT_CONTAINS = "notContains"
    IS_NULL = "isNull"
    IS_NOT_NULL = "isNotNull"
    IS_EMPTY = "isEmpty"
    IS_NOT_EMPTY = "isNotEmpty"
    NESTED = "nested"
|
26
|
+
|
27
|
+
|
28
|
+
class BaseStatement(BaseModel):
    """Shared behaviour for all UniqueQL statement nodes."""

    model_config: ConfigDict = {"serialize_by_alias": True}

    def with_variables(
        self,
        user_metadata: Dict[str, Union[str, int, bool]],
        tool_parameters: Dict[str, Union[str, int, bool]],
    ) -> Self:
        """Return a copy of this statement with template variables resolved."""
        return self._fill_in_variables(user_metadata, tool_parameters)

    def is_compiled(self) -> bool:
        """Check whether the serialized statement still contains any template
        placeholder (``<T>``, ``<T+N>``, ``<T-N>``, ``<toolParameters...>``,
        ``<userMetadata...>``)."""
        serialized = self.model_dump_json()
        markers = ("<T>", "<T+", "<T-", "<toolParameters", "<userMetadata")
        return any(marker in serialized for marker in markers)

    def _fill_in_variables(
        self,
        user_metadata: Dict[str, Union[str, int, bool]],
        tool_parameters: Dict[str, Union[str, int, bool]],
    ) -> Self:
        # Base case: nothing to substitute; subclasses override.
        return self.model_copy()
|
56
|
+
|
57
|
+
|
58
|
+
class Statement(BaseStatement):
    """Leaf UniqueQL condition: ``path`` compared to ``value`` via ``operator``."""

    operator: Operator
    # Nested statements are allowed as values for Operator.NESTED.
    value: Union[str, int, bool, list[str], "AndStatement", "OrStatement"]
    path: List[str] = Field(default_factory=list)

    def _fill_in_variables(
        self,
        user_metadata: Dict[str, Union[str, int, bool]],
        tool_parameters: Dict[str, Union[str, int, bool]],
    ) -> Self:
        # Substitute template variables in the value; operator/path untouched.
        new_stmt = self.model_copy()
        new_stmt.value = eval_operator(self, user_metadata, tool_parameters)
        return new_stmt
|
71
|
+
|
72
|
+
|
73
|
+
class AndStatement(BaseStatement):
    """Branch node grouping sub-statements under the JSON key ``"and"``."""

    # Serialized as "and"; the Python name avoids shadowing the keyword.
    and_list: List[Union["Statement", "AndStatement", "OrStatement"]] = Field(
        alias="and", validation_alias=AliasChoices("and", "and_list")
    )

    def _fill_in_variables(
        self,
        user_metadata: Dict[str, Union[str, int, bool]],
        tool_parameters: Dict[str, Union[str, int, bool]],
    ) -> Self:
        # Recursively substitute variables in every child, keeping structure.
        new_stmt = self.model_copy()
        new_stmt.and_list = [
            sub_query._fill_in_variables(user_metadata, tool_parameters)
            for sub_query in self.and_list
        ]
        return new_stmt
|
89
|
+
|
90
|
+
|
91
|
+
class OrStatement(BaseStatement):
    """Branch node grouping sub-statements under the JSON key ``"or"``."""

    # Serialized as "or"; the Python name avoids shadowing the keyword.
    or_list: List[Union["Statement", "AndStatement", "OrStatement"]] = Field(
        alias="or", validation_alias=AliasChoices("or", "or_list")
    )

    def _fill_in_variables(
        self,
        user_metadata: Dict[str, Union[str, int, bool]],
        tool_parameters: Dict[str, Union[str, int, bool]],
    ) -> Self:
        # Recursively substitute variables in every child, keeping structure.
        new_stmt = self.model_copy()
        new_stmt.or_list = [
            sub_query._fill_in_variables(user_metadata, tool_parameters)
            for sub_query in self.or_list
        ]
        return new_stmt
|
107
|
+
|
108
|
+
|
109
|
+
# Update the forward references ("Statement"/"AndStatement"/"OrStatement"
# reference each other as strings, so the models must be rebuilt once all
# three classes exist).
Statement.model_rebuild()
AndStatement.model_rebuild()
OrStatement.model_rebuild()


# A UniqueQL query is any statement node: a leaf condition or an and/or branch.
UniqueQL = Union[Statement, AndStatement, OrStatement]
|
116
|
+
|
117
|
+
|
118
|
+
def is_array_of_strings(value: Any) -> bool:
    """Return True when *value* is a list whose every item is a str.

    An empty list qualifies (``all`` of nothing is True).
    """
    if not isinstance(value, list):
        return False
    return all(isinstance(item, str) for item in value)
|
120
|
+
|
121
|
+
|
122
|
+
def eval_operator(
    query: Statement,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Substitute variables in a statement's value, dispatched by operator.

    Raises:
        ValueError: If the operator is not recognised.
    """
    # Operators whose operand is a single scalar value.
    scalar_comparisons = {
        Operator.EQUALS,
        Operator.NOT_EQUALS,
        Operator.GREATER_THAN,
        Operator.GREATER_THAN_OR_EQUAL,
        Operator.LESS_THAN,
        Operator.LESS_THAN_OR_EQUAL,
        Operator.CONTAINS,
        Operator.NOT_CONTAINS,
    }

    if query.operator in scalar_comparisons:
        return binary_operator(query.value, user_metadata, tool_parameters)
    if query.operator in (Operator.IS_NULL, Operator.IS_NOT_NULL):
        return null_operator(query.value, user_metadata, tool_parameters)
    if query.operator in (Operator.IS_EMPTY, Operator.IS_NOT_EMPTY):
        return empty_operator(query.operator, user_metadata, tool_parameters)
    if query.operator == Operator.NESTED:
        return eval_nested_operator(query.value, user_metadata, tool_parameters)
    if query.operator in (Operator.IN, Operator.NOT_IN):
        return array_operator(query.value, user_metadata, tool_parameters)
    raise ValueError(f"Operator {query.operator} not supported")
|
148
|
+
|
149
|
+
|
150
|
+
def eval_nested_operator(
    value: Any,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Union[AndStatement, OrStatement]:
    """Recursively substitute variables inside a nested and/or statement.

    Raises:
        ValueError: If *value* is not an ``AndStatement`` or ``OrStatement``.
    """
    if not isinstance(value, (AndStatement, OrStatement)):
        raise ValueError("Nested operator must be an AndStatement or OrStatement")
    return value._fill_in_variables(user_metadata, tool_parameters)
|
158
|
+
|
159
|
+
|
160
|
+
def binary_operator(
    value: Any,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Substitute template variables in the operand of a scalar comparison."""
    return replace_variables(value, user_metadata, tool_parameters)
|
166
|
+
|
167
|
+
|
168
|
+
def array_operator(
    value: Any,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Substitute template variables element-wise for IN / NOT_IN operands."""
    if not is_array_of_strings(value):
        # Anything but a list of strings passes through untouched.
        return value
    return [
        replace_variables(item, user_metadata, tool_parameters) for item in value
    ]
|
178
|
+
|
179
|
+
|
180
|
+
def null_operator(
    value: Any,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Pass the operand through unchanged; null checks carry no variables."""
    return value
|
186
|
+
|
187
|
+
|
188
|
+
def empty_operator(
    operator: Operator,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Handle IS_EMPTY and IS_NOT_EMPTY operators."""
    if operator == Operator.IS_EMPTY:
        return ""
    if operator == Operator.IS_NOT_EMPTY:
        return "not_empty"
    # Any other operator falls through; mirror the permissive original.
    return None
|
199
|
+
|
200
|
+
|
201
|
+
def calculate_current_date() -> str:
    """Return the current UTC time as an ISO-8601 string (second precision)."""
    now_utc = datetime.now(timezone.utc)
    return now_utc.isoformat(timespec="seconds")
|
204
|
+
|
205
|
+
|
206
|
+
def calculate_earlier_date(input_str: str) -> str:
    """Resolve a ``<T-N>`` placeholder to the UTC timestamp N days in the past."""
    match = re.search(r"<T-(\d+)>", input_str)
    if match is None:
        # No placeholder found: fall back to "now".
        return calculate_current_date()
    days_back = int(match.group(1))
    earlier = datetime.now(timezone.utc) - timedelta(days=days_back)
    return earlier.isoformat(timespec="seconds")
|
214
|
+
|
215
|
+
|
216
|
+
def calculate_later_date(input_str: str) -> str:
    """Resolve a ``<T+N>`` placeholder to the UTC timestamp N days in the future."""
    match = re.search(r"<T\+(\d+)>", input_str)  # '+' must be escaped in the regex
    if match is None:
        # No placeholder found: fall back to "now".
        return calculate_current_date()
    days_forward = int(match.group(1))
    later = datetime.now(timezone.utc) + timedelta(days=days_forward)
    return later.isoformat(timespec="seconds")
|
224
|
+
|
225
|
+
|
226
|
+
def replace_variables(
    value: Any,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Resolve template placeholders in a scalar rule value.

    Checked in order: fallback chains (``a||b``), date placeholders
    (``<T>``, ``<T-N>``, ``<T+N>``), then ``<toolParameters.x>`` /
    ``<userMetadata.x>`` substitution. Substituted strings are coerced to
    ``int`` or ``bool`` where possible; non-string values pass through.
    """
    if isinstance(value, str):
        if "||" in value:
            # 'a||b||c' → first candidate that resolves to a non-empty value.
            return get_fallback_values(value, user_metadata, tool_parameters)
        elif value == "<T>":
            return calculate_current_date()
        elif "<T-" in value:
            return calculate_earlier_date(value)
        elif "<T+" in value:
            return calculate_later_date(value)

        value = replace_tool_parameters_patterns(value, tool_parameters)
        value = replace_user_metadata_patterns(value, user_metadata)

        if value == "":
            return value
        try:
            # Coerce numeric strings so later comparisons use ints, not text.
            return int(value)
        except ValueError:
            if value.lower() in ["true", "false"]:
                return value.lower() == "true"
            return value
    # Non-string values carry no placeholders.
    return value
|
253
|
+
|
254
|
+
|
255
|
+
def replace_tool_parameters_patterns(
    value: str, tool_parameters: Dict[str, Union[str, int, bool]]
) -> str:
    """Replace every ``<toolParameters.name>`` with its configured value."""

    def substitute(match):
        # Unknown parameter names resolve to the empty string.
        return str(tool_parameters.get(match.group(1), ""))

    return re.sub(r"<toolParameters\.(\w+)>", substitute, value)
|
263
|
+
|
264
|
+
|
265
|
+
def replace_user_metadata_patterns(
    value: str, user_metadata: Dict[str, Union[str, int, bool]]
) -> str:
    """Replace every ``<userMetadata.name>`` with its metadata value."""

    def substitute(match):
        # Unknown metadata keys resolve to the empty string.
        return str(user_metadata.get(match.group(1), ""))

    return re.sub(r"<userMetadata\.(\w+)>", substitute, value)
|
273
|
+
|
274
|
+
|
275
|
+
def get_fallback_values(
    value: str,
    user_metadata: Dict[str, Union[str, int, bool]],
    tool_parameters: Dict[str, Union[str, int, bool]],
) -> Any:
    """Resolve an ``a||b||c`` fallback chain to its first non-empty candidate.

    Each candidate is resolved via ``replace_variables``; the first one that
    does not resolve to ``""`` wins. If every candidate resolves empty, the
    raw (unresolved) list of candidates is returned — note the asymmetric
    return type (scalar on success, list on total fallback).
    """
    values = value.split("||")
    for val in values:
        data = replace_variables(val, user_metadata, tool_parameters)
        if data != "":
            return data
    return values
|
286
|
+
|
287
|
+
|
288
|
+
# Example usage:
def parse_uniqueql(json_data: Dict[str, Any]) -> UniqueQL:
    """Parse a raw UniqueQL JSON dict into its statement tree.

    Dispatches on the discriminating key: ``operator`` → leaf ``Statement``,
    ``or`` / ``and`` → recursively parsed branch statements.

    Raises:
        ValueError: If none of the recognised keys is present.
    """
    if "operator" in json_data:
        return Statement.model_validate(json_data)
    elif "or" in json_data:
        return OrStatement.model_validate(
            {"or": [parse_uniqueql(item) for item in json_data["or"]]}
        )
    elif "and" in json_data:
        return AndStatement.model_validate(
            {"and": [parse_uniqueql(item) for item in json_data["and"]]}
        )
    else:
        raise ValueError("Invalid UniqueQL format")
|