unique-web-search 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_web_search/config.py +11 -2
- unique_web_search/service.py +95 -138
- unique_web_search/services/executors/__init__.py +12 -1
- unique_web_search/services/executors/base_executor.py +44 -59
- unique_web_search/services/executors/context.py +104 -0
- unique_web_search/services/executors/web_search_v1_executor.py +34 -77
- unique_web_search/services/executors/web_search_v2_executor.py +73 -83
- unique_web_search/services/message_log.py +70 -0
- unique_web_search/services/query_elicitation.py +162 -0
- unique_web_search/services/search_engine/vertexai.py +1 -1
- {unique_web_search-1.9.1.dist-info → unique_web_search-1.10.0.dist-info}/METADATA +1 -1
- {unique_web_search-1.9.1.dist-info → unique_web_search-1.10.0.dist-info}/RECORD +13 -10
- {unique_web_search-1.9.1.dist-info → unique_web_search-1.10.0.dist-info}/WHEEL +0 -0
|
@@ -1,38 +1,30 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from time import time
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Literal, overload, override
|
|
4
4
|
|
|
5
5
|
from pydantic import Field
|
|
6
6
|
from unique_toolkit import LanguageModelService
|
|
7
|
-
from unique_toolkit._common.chunk_relevancy_sorter.config import (
|
|
8
|
-
ChunkRelevancySortConfig,
|
|
9
|
-
)
|
|
10
|
-
from unique_toolkit._common.chunk_relevancy_sorter.service import ChunkRelevancySorter
|
|
11
7
|
from unique_toolkit._common.utils.structured_output.schema import StructuredOutputModel
|
|
12
8
|
from unique_toolkit._common.validators import LMI
|
|
13
|
-
from unique_toolkit.agentic.tools.tool_progress_reporter import (
|
|
14
|
-
ToolProgressReporter,
|
|
15
|
-
)
|
|
16
9
|
from unique_toolkit.content import ContentChunk
|
|
17
10
|
from unique_toolkit.language_model import LanguageModelFunction
|
|
18
11
|
from unique_toolkit.language_model.builder import MessagesBuilder
|
|
19
12
|
|
|
20
|
-
from unique_web_search.schema import
|
|
21
|
-
from unique_web_search.services.content_processing import ContentProcessor, WebPageChunk
|
|
22
|
-
from unique_web_search.services.crawlers import CrawlerTypes
|
|
13
|
+
from unique_web_search.schema import WebSearchToolParameters
|
|
23
14
|
from unique_web_search.services.executors.base_executor import (
|
|
24
15
|
BaseWebSearchExecutor,
|
|
25
|
-
MessageLogCallback,
|
|
26
|
-
WebSearchLogEntry,
|
|
27
16
|
)
|
|
28
17
|
from unique_web_search.services.executors.configs import RefineQueryMode
|
|
29
|
-
from unique_web_search.services.
|
|
18
|
+
from unique_web_search.services.executors.context import (
|
|
19
|
+
ExecutorCallbacks,
|
|
20
|
+
ExecutorConfiguration,
|
|
21
|
+
ExecutorServiceContext,
|
|
22
|
+
)
|
|
30
23
|
from unique_web_search.services.search_engine.schema import (
|
|
31
24
|
WebSearchResult,
|
|
32
25
|
)
|
|
33
26
|
from unique_web_search.utils import (
|
|
34
27
|
StepDebugInfo,
|
|
35
|
-
WebSearchDebugInfo,
|
|
36
28
|
query_params_to_human_string,
|
|
37
29
|
)
|
|
38
30
|
|
|
@@ -139,105 +131,74 @@ class WebSearchV1Executor(BaseWebSearchExecutor):
|
|
|
139
131
|
@override
|
|
140
132
|
def __init__(
|
|
141
133
|
self,
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
search_service: SearchEngineTypes,
|
|
146
|
-
crawler_service: CrawlerTypes,
|
|
147
|
-
content_processor: ContentProcessor,
|
|
148
|
-
message_log_callback: MessageLogCallback,
|
|
149
|
-
chunk_relevancy_sorter: ChunkRelevancySorter | None,
|
|
150
|
-
chunk_relevancy_sort_config: ChunkRelevancySortConfig,
|
|
151
|
-
content_reducer: Callable[[list[WebPageChunk]], list[WebPageChunk]],
|
|
134
|
+
services: ExecutorServiceContext,
|
|
135
|
+
config: ExecutorConfiguration,
|
|
136
|
+
callbacks: ExecutorCallbacks,
|
|
152
137
|
tool_call: LanguageModelFunction,
|
|
153
138
|
tool_parameters: WebSearchToolParameters,
|
|
154
139
|
refine_query_system_prompt: str,
|
|
155
|
-
debug_info: WebSearchDebugInfo,
|
|
156
|
-
tool_progress_reporter: Optional[ToolProgressReporter] = None,
|
|
157
140
|
mode: RefineQueryMode = RefineQueryMode.BASIC,
|
|
158
141
|
max_queries: int = 10,
|
|
159
142
|
):
|
|
160
143
|
super().__init__(
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
crawler_service=crawler_service,
|
|
144
|
+
services=services,
|
|
145
|
+
config=config,
|
|
146
|
+
callbacks=callbacks,
|
|
165
147
|
tool_call=tool_call,
|
|
166
148
|
tool_parameters=tool_parameters,
|
|
167
|
-
company_id=company_id,
|
|
168
|
-
content_processor=content_processor,
|
|
169
|
-
chunk_relevancy_sorter=chunk_relevancy_sorter,
|
|
170
|
-
chunk_relevancy_sort_config=chunk_relevancy_sort_config,
|
|
171
|
-
debug_info=debug_info,
|
|
172
|
-
content_reducer=content_reducer,
|
|
173
|
-
tool_progress_reporter=tool_progress_reporter,
|
|
174
|
-
message_log_callback=message_log_callback,
|
|
175
149
|
)
|
|
176
150
|
self.mode = mode
|
|
177
151
|
self.tool_parameters = tool_parameters
|
|
178
152
|
self.refine_query_system_prompt = refine_query_system_prompt
|
|
179
153
|
self.max_queries = max_queries
|
|
180
154
|
|
|
181
|
-
async def run(self) ->
|
|
155
|
+
async def run(self) -> list[ContentChunk]:
|
|
182
156
|
query = self.tool_parameters.query
|
|
183
157
|
date_restrict = self.tool_parameters.date_restrict
|
|
184
158
|
|
|
185
159
|
self.notify_name = "**Refining Query**"
|
|
186
160
|
self.notify_message = query_params_to_human_string(query, date_restrict)
|
|
187
161
|
await self.notify_callback()
|
|
188
|
-
|
|
189
|
-
|
|
162
|
+
|
|
163
|
+
await self._message_log_callback.log_progress(
|
|
164
|
+
f"_Refining Query:_ {self.notify_message}"
|
|
190
165
|
)
|
|
191
166
|
refined_queries, objective = await self._refine_query(query)
|
|
192
167
|
|
|
168
|
+
elicitated_queries = await self._ff_elicitate_queries(refined_queries)
|
|
169
|
+
|
|
193
170
|
web_search_results = []
|
|
194
171
|
# Pass query strings only - callback handles creating WebSearchLogEntry objects
|
|
172
|
+
|
|
195
173
|
queries_wo_results = [
|
|
196
174
|
query_params_to_human_string(refined_query, date_restrict)
|
|
197
|
-
for refined_query in
|
|
175
|
+
for refined_query in elicitated_queries
|
|
198
176
|
]
|
|
199
|
-
|
|
200
|
-
queries_for_log=queries_wo_results
|
|
201
|
-
)
|
|
177
|
+
await self._message_log_callback.log_queries(queries_wo_results)
|
|
202
178
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
if len(refined_queries) > 1:
|
|
179
|
+
for index, query in enumerate(elicitated_queries):
|
|
180
|
+
if len(elicitated_queries) > 1:
|
|
206
181
|
self.notify_name = (
|
|
207
|
-
f"**Searching Web {index + 1}/{len(
|
|
182
|
+
f"**Searching Web {index + 1}/{len(elicitated_queries)}**"
|
|
208
183
|
)
|
|
209
184
|
else:
|
|
210
185
|
self.notify_name = "**Searching Web**"
|
|
211
186
|
|
|
212
|
-
self.notify_message = query_params_to_human_string(
|
|
213
|
-
refined_query, date_restrict
|
|
214
|
-
)
|
|
187
|
+
self.notify_message = query_params_to_human_string(query, date_restrict)
|
|
215
188
|
await self.notify_callback()
|
|
189
|
+
await self._message_log_callback.log_progress(self.notify_message)
|
|
216
190
|
|
|
217
|
-
search_results = await self._search(
|
|
218
|
-
refined_query, date_restrict=date_restrict
|
|
219
|
-
)
|
|
220
|
-
queries_for_log.append(
|
|
221
|
-
WebSearchLogEntry(
|
|
222
|
-
type=StepType.SEARCH,
|
|
223
|
-
message=self.notify_message,
|
|
224
|
-
web_search_results=search_results,
|
|
225
|
-
)
|
|
226
|
-
)
|
|
191
|
+
search_results = await self._search(query, date_restrict=date_restrict)
|
|
227
192
|
|
|
228
|
-
|
|
193
|
+
await self._message_log_callback.log_web_search_results(search_results)
|
|
229
194
|
|
|
230
|
-
|
|
231
|
-
queries_for_log=queries_for_log
|
|
232
|
-
)
|
|
195
|
+
web_search_results.extend(search_results)
|
|
233
196
|
|
|
234
197
|
if self.search_service.requires_scraping:
|
|
235
198
|
self.notify_name = "**Crawling URLs**"
|
|
236
199
|
self.notify_message = f"{len(web_search_results)} URLs to fetch"
|
|
237
200
|
await self.notify_callback()
|
|
238
|
-
|
|
239
|
-
progress_message="_Crawling URLs_", queries_for_log=queries_for_log
|
|
240
|
-
)
|
|
201
|
+
await self._message_log_callback.log_progress("_Crawling URLs_")
|
|
241
202
|
crawl_results = await self._crawl(web_search_results)
|
|
242
203
|
for web_search_result, crawl_result in zip(
|
|
243
204
|
web_search_results, crawl_results
|
|
@@ -247,9 +208,7 @@ class WebSearchV1Executor(BaseWebSearchExecutor):
|
|
|
247
208
|
self.notify_name = "**Analyzing Web Pages**"
|
|
248
209
|
self.notify_message = objective
|
|
249
210
|
await self.notify_callback()
|
|
250
|
-
|
|
251
|
-
progress_message="_Analyzing Web Pages_", queries_for_log=queries_for_log
|
|
252
|
-
)
|
|
211
|
+
await self._message_log_callback.log_progress("_Analyzing Web Pages_")
|
|
253
212
|
|
|
254
213
|
content_results = await self._content_processing(objective, web_search_results)
|
|
255
214
|
|
|
@@ -257,15 +216,13 @@ class WebSearchV1Executor(BaseWebSearchExecutor):
|
|
|
257
216
|
self.notify_name = "**Resorting Sources**"
|
|
258
217
|
self.notify_message = objective
|
|
259
218
|
await self.notify_callback()
|
|
260
|
-
|
|
261
|
-
progress_message="_Resorting Sources_", queries_for_log=queries_for_log
|
|
262
|
-
)
|
|
219
|
+
await self._message_log_callback.log_progress("_Resorting Sources_")
|
|
263
220
|
|
|
264
221
|
relevant_sources = await self._select_relevant_sources(
|
|
265
222
|
objective, content_results
|
|
266
223
|
)
|
|
267
224
|
|
|
268
|
-
return relevant_sources
|
|
225
|
+
return relevant_sources
|
|
269
226
|
|
|
270
227
|
async def _refine_query(self, query: str) -> tuple[list[str], str]:
|
|
271
228
|
start_time = time()
|
|
@@ -1,17 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
3
|
from time import time
|
|
4
|
-
from typing import Callable, Optional
|
|
5
4
|
|
|
6
|
-
from unique_toolkit import LanguageModelService
|
|
7
|
-
from unique_toolkit._common.chunk_relevancy_sorter.config import (
|
|
8
|
-
ChunkRelevancySortConfig,
|
|
9
|
-
)
|
|
10
|
-
from unique_toolkit._common.chunk_relevancy_sorter.service import ChunkRelevancySorter
|
|
11
|
-
from unique_toolkit._common.validators import LMI
|
|
12
|
-
from unique_toolkit.agentic.tools.tool_progress_reporter import (
|
|
13
|
-
ToolProgressReporter,
|
|
14
|
-
)
|
|
15
5
|
from unique_toolkit.content import ContentChunk
|
|
16
6
|
from unique_toolkit.language_model import LanguageModelFunction
|
|
17
7
|
|
|
@@ -20,18 +10,18 @@ from unique_web_search.schema import (
|
|
|
20
10
|
StepType,
|
|
21
11
|
WebSearchPlan,
|
|
22
12
|
)
|
|
23
|
-
from unique_web_search.services.content_processing import ContentProcessor, WebPageChunk
|
|
24
|
-
from unique_web_search.services.crawlers import CrawlerTypes
|
|
25
13
|
from unique_web_search.services.executors.base_executor import (
|
|
26
14
|
BaseWebSearchExecutor,
|
|
27
|
-
MessageLogCallback,
|
|
28
|
-
WebSearchLogEntry,
|
|
29
15
|
)
|
|
30
|
-
from unique_web_search.services.
|
|
16
|
+
from unique_web_search.services.executors.context import (
|
|
17
|
+
ExecutorCallbacks,
|
|
18
|
+
ExecutorConfiguration,
|
|
19
|
+
ExecutorServiceContext,
|
|
20
|
+
)
|
|
31
21
|
from unique_web_search.services.search_engine.schema import (
|
|
32
22
|
WebSearchResult,
|
|
33
23
|
)
|
|
34
|
-
from unique_web_search.utils import StepDebugInfo
|
|
24
|
+
from unique_web_search.utils import StepDebugInfo
|
|
35
25
|
|
|
36
26
|
_LOGGER = logging.getLogger(__name__)
|
|
37
27
|
|
|
@@ -46,42 +36,23 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
46
36
|
|
|
47
37
|
def __init__(
|
|
48
38
|
self,
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
crawler_service: CrawlerTypes,
|
|
39
|
+
services: ExecutorServiceContext,
|
|
40
|
+
config: ExecutorConfiguration,
|
|
41
|
+
callbacks: ExecutorCallbacks,
|
|
53
42
|
tool_call: LanguageModelFunction,
|
|
54
43
|
tool_parameters: WebSearchPlan,
|
|
55
|
-
company_id: str,
|
|
56
|
-
content_processor: ContentProcessor,
|
|
57
|
-
message_log_callback: MessageLogCallback,
|
|
58
|
-
chunk_relevancy_sorter: ChunkRelevancySorter | None,
|
|
59
|
-
chunk_relevancy_sort_config: ChunkRelevancySortConfig,
|
|
60
|
-
content_reducer: Callable[[list[WebPageChunk]], list[WebPageChunk]],
|
|
61
|
-
debug_info: WebSearchDebugInfo,
|
|
62
|
-
tool_progress_reporter: Optional[ToolProgressReporter] = None,
|
|
63
44
|
max_steps: int = 3,
|
|
64
45
|
):
|
|
65
46
|
super().__init__(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
crawler_service=crawler_service,
|
|
47
|
+
services=services,
|
|
48
|
+
config=config,
|
|
49
|
+
callbacks=callbacks,
|
|
70
50
|
tool_call=tool_call,
|
|
71
51
|
tool_parameters=tool_parameters,
|
|
72
|
-
company_id=company_id,
|
|
73
|
-
content_processor=content_processor,
|
|
74
|
-
chunk_relevancy_sorter=chunk_relevancy_sorter,
|
|
75
|
-
chunk_relevancy_sort_config=chunk_relevancy_sort_config,
|
|
76
|
-
content_reducer=content_reducer,
|
|
77
|
-
debug_info=debug_info,
|
|
78
|
-
tool_progress_reporter=tool_progress_reporter,
|
|
79
|
-
message_log_callback=message_log_callback,
|
|
80
52
|
)
|
|
81
53
|
|
|
82
54
|
self.tool_parameters = tool_parameters
|
|
83
55
|
self.max_steps = max_steps
|
|
84
|
-
self.queries_for_log: list[WebSearchLogEntry] = []
|
|
85
56
|
|
|
86
57
|
@property
|
|
87
58
|
def notify_name(self):
|
|
@@ -99,20 +70,20 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
99
70
|
def notify_message(self, value):
|
|
100
71
|
self._notify_message = value
|
|
101
72
|
|
|
102
|
-
async def run(self) ->
|
|
73
|
+
async def run(self) -> list[ContentChunk]:
|
|
103
74
|
await self._enforce_max_steps()
|
|
104
75
|
|
|
105
76
|
results: list[WebSearchResult] = []
|
|
106
77
|
self.notify_name = "**Searching Web**"
|
|
107
78
|
self.notify_message = self.tool_parameters.objective
|
|
79
|
+
|
|
108
80
|
await self.notify_callback()
|
|
109
|
-
self._message_log_callback(
|
|
110
|
-
|
|
111
|
-
)
|
|
81
|
+
await self._message_log_callback.log_progress("_Searching Web_")
|
|
82
|
+
|
|
83
|
+
elicitated_steps = await self._elicitate_steps(self.tool_parameters.steps)
|
|
112
84
|
|
|
113
85
|
tasks = [
|
|
114
|
-
asyncio.create_task(self._execute_step(step))
|
|
115
|
-
for step in self.tool_parameters.steps
|
|
86
|
+
asyncio.create_task(self._execute_step(step)) for step in elicitated_steps
|
|
116
87
|
]
|
|
117
88
|
|
|
118
89
|
results_nested = await asyncio.gather(*tasks, return_exceptions=True)
|
|
@@ -123,20 +94,10 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
123
94
|
else:
|
|
124
95
|
results.extend(result)
|
|
125
96
|
|
|
126
|
-
# for step in self.tool_parameters.steps:
|
|
127
|
-
# self.notify_name = step_type_to_name[step.step_type]
|
|
128
|
-
# self.notify_message = step.objective
|
|
129
|
-
# await self.notify_callback()
|
|
130
|
-
# step_results = await self._execute_step(step)
|
|
131
|
-
# results.extend(step_results)
|
|
132
|
-
|
|
133
97
|
self.notify_name = "**Analyzing Web Pages**"
|
|
134
98
|
self.notify_message = self.tool_parameters.expected_outcome
|
|
135
99
|
await self.notify_callback()
|
|
136
|
-
self._message_log_callback(
|
|
137
|
-
progress_message="_Analyzing Web Pages_",
|
|
138
|
-
queries_for_log=self.queries_for_log,
|
|
139
|
-
)
|
|
100
|
+
await self._message_log_callback.log_progress("_Analyzing Web Pages_")
|
|
140
101
|
|
|
141
102
|
content_results = await self._content_processing(
|
|
142
103
|
self.tool_parameters.objective, results
|
|
@@ -146,16 +107,13 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
146
107
|
self.notify_name = "**Resorting Sources**"
|
|
147
108
|
self.notify_message = self.tool_parameters.objective
|
|
148
109
|
await self.notify_callback()
|
|
149
|
-
self._message_log_callback(
|
|
150
|
-
progress_message="_Resorting Sources_",
|
|
151
|
-
queries_for_log=self.queries_for_log,
|
|
152
|
-
)
|
|
110
|
+
await self._message_log_callback.log_progress("_Resorting Sources_")
|
|
153
111
|
|
|
154
112
|
relevant_sources = await self._select_relevant_sources(
|
|
155
113
|
self.tool_parameters.objective, content_results
|
|
156
114
|
)
|
|
157
115
|
|
|
158
|
-
return relevant_sources
|
|
116
|
+
return relevant_sources
|
|
159
117
|
|
|
160
118
|
async def _execute_step(self, step: Step) -> list[WebSearchResult]:
|
|
161
119
|
if step.step_type == StepType.SEARCH:
|
|
@@ -178,14 +136,9 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
178
136
|
time_start = time()
|
|
179
137
|
_LOGGER.info(f"Company {self.company_id} Searching with {self.search_service}")
|
|
180
138
|
|
|
139
|
+
await self._message_log_callback.log_queries([step.query_or_url])
|
|
181
140
|
results = await self.search_service.search(step.query_or_url)
|
|
182
|
-
self.
|
|
183
|
-
WebSearchLogEntry(
|
|
184
|
-
type=StepType.SEARCH,
|
|
185
|
-
message=step.query_or_url,
|
|
186
|
-
web_search_results=results,
|
|
187
|
-
)
|
|
188
|
-
)
|
|
141
|
+
await self._message_log_callback.log_web_search_results(results)
|
|
189
142
|
|
|
190
143
|
delta_time = time() - time_start
|
|
191
144
|
|
|
@@ -245,20 +198,17 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
245
198
|
)
|
|
246
199
|
time_start = time()
|
|
247
200
|
_LOGGER.info(f"Company {self.company_id} Crawling with {self.crawler_service}")
|
|
201
|
+
|
|
248
202
|
results = await self.crawler_service.crawl([step.query_or_url])
|
|
249
|
-
self.
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
title=step.objective,
|
|
259
|
-
)
|
|
260
|
-
],
|
|
261
|
-
)
|
|
203
|
+
await self._message_log_callback.log_web_search_results(
|
|
204
|
+
[
|
|
205
|
+
WebSearchResult(
|
|
206
|
+
url=step.query_or_url,
|
|
207
|
+
content=results[0],
|
|
208
|
+
snippet=step.objective,
|
|
209
|
+
title=step.objective,
|
|
210
|
+
)
|
|
211
|
+
]
|
|
262
212
|
)
|
|
263
213
|
delta_time = time() - time_start
|
|
264
214
|
_LOGGER.debug(
|
|
@@ -300,3 +250,43 @@ class WebSearchV2Executor(BaseWebSearchExecutor):
|
|
|
300
250
|
},
|
|
301
251
|
)
|
|
302
252
|
)
|
|
253
|
+
|
|
254
|
+
async def _elicitate_steps(self, steps: list[Step]) -> list[Step]:
    """Let the user approve or edit the planned search queries.

    Only search steps go through elicitation: their queries are collected and
    passed to ``self._ff_elicitate_queries``. Every other step (presumably
    READ_URL, per the plan schema) is passed through untouched.

    The result is regrouped rather than order-preserving. All elicited search
    steps come first, followed by the non-search steps in their original
    relative order. The user may add, remove or rewrite queries, so there is
    no reliable one-to-one mapping back to the original search steps.

    Args:
        steps: The planned steps from the web search plan.

    Returns:
        Fresh SEARCH steps built from the elicited queries (with empty
        objectives), followed by the original non-search steps. When the plan
        has no search steps, elicitation is skipped and only the non-search
        steps are returned.
    """
    proposed_queries: list[str] = []
    untouched_steps: list[Step] = []
    for planned in steps:
        if planned.step_type == StepType.SEARCH:
            proposed_queries.append(planned.query_or_url)
        else:
            untouched_steps.append(planned)

    # Nothing to ask the user about.
    if not proposed_queries:
        return untouched_steps

    approved_queries = await self._ff_elicitate_queries(proposed_queries)

    # Objectives cannot be mapped back onto edited queries, so they are blanked.
    approved_steps = [
        Step(step_type=StepType.SEARCH, query_or_url=approved, objective="")
        for approved in approved_queries
    ]
    return [*approved_steps, *untouched_steps]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from unique_toolkit.agentic.message_log_manager.service import MessageStepLogger
|
|
2
|
+
from unique_toolkit.chat.schemas import (
|
|
3
|
+
MessageLog,
|
|
4
|
+
MessageLogDetails,
|
|
5
|
+
MessageLogEvent,
|
|
6
|
+
MessageLogStatus,
|
|
7
|
+
)
|
|
8
|
+
from unique_toolkit.content import ContentReference
|
|
9
|
+
|
|
10
|
+
from unique_web_search.services.search_engine.schema import WebSearchResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class WebSearchMessageLogger:
    """Collect web-search progress and mirror it into one message log entry.

    Every public method updates local state and then pushes the complete
    current state through
    ``MessageStepLogger.create_or_update_message_log_async``. The pushed state
    covers header, progress message, details, references and status. The log
    object returned by each push is kept and handed back as
    ``active_message_log`` on the next push. Judging by the helper's name, the
    first push creates the entry and later pushes update it (confirm against
    MessageStepLogger).
    """

    def __init__(self, message_step_logger: MessageStepLogger, tool_display_name: str):
        self._message_step_logger = message_step_logger
        self._tool_display_name = tool_display_name
        # Result of the most recent push; None until the first push happens.
        self._current_message_log: MessageLog | None = None

        self._status = MessageLogStatus.RUNNING
        self._progress_message = ""
        # Both collections only grow and are re-sent in full on every push.
        self._details: MessageLogDetails = MessageLogDetails(data=[])
        self._references: list[ContentReference] = []

    async def finished(self) -> None:
        """Mark the log as COMPLETED and push the final state."""
        await self._push_with_status(MessageLogStatus.COMPLETED)

    async def failed(self) -> None:
        """Mark the log as FAILED and push the final state."""
        await self._push_with_status(MessageLogStatus.FAILED)

    async def log_progress(self, progress_message: str) -> None:
        """Replace the current progress message and push."""
        self._progress_message = progress_message
        await self._propagate_message_log()

    async def log_queries(self, queries: list[str]) -> None:
        """Append one ``WebSearch`` event per query to the details and push."""
        events = [MessageLogEvent(type="WebSearch", text=text) for text in queries]
        # `data` starts as an empty list in __init__, so it is never None here.
        self._details.data.extend(events)  # type: ignore
        await self._propagate_message_log()

    async def log_web_search_results(
        self, web_search_results: list[WebSearchResult]
    ) -> None:
        """Turn search results into content references and push.

        Sequence numbers continue after the references already logged, so
        they stay unique across repeated calls.
        """
        first_number = len(self._references)
        # Build the whole batch before touching self._references.
        fresh_references = [
            result.to_content_reference(first_number + position)
            for position, result in enumerate(web_search_results)
        ]
        self._references += fresh_references
        await self._propagate_message_log()

    async def _push_with_status(self, status: MessageLogStatus) -> None:
        """Set the log status, then push the current state once."""
        self._status = status
        await self._propagate_message_log()

    async def _propagate_message_log(self) -> None:
        """Send the full current state and remember the returned log entry."""
        updated_log = await self._message_step_logger.create_or_update_message_log_async(
            active_message_log=self._current_message_log,
            header=self._tool_display_name,
            progress_message=self._progress_message,
            details=self._details,
            references=self._references,
            status=self._status,
        )
        self._current_message_log = updated_log
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Query elicitation service for web search.
|
|
2
|
+
|
|
3
|
+
This module provides functionality for creating and evaluating query elicitations,
|
|
4
|
+
allowing users to review and modify proposed search queries before execution.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Self
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
12
|
+
from unique_toolkit._common.pydantic_helpers import get_configuration_dict
|
|
13
|
+
from unique_toolkit.chat.service import ChatService
|
|
14
|
+
from unique_toolkit.elicitation import (
|
|
15
|
+
ElicitationCancelledException,
|
|
16
|
+
ElicitationDeclinedException,
|
|
17
|
+
ElicitationExpiredException,
|
|
18
|
+
ElicitationFailedException,
|
|
19
|
+
ElicitationMode,
|
|
20
|
+
ElicitationStatus,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
_LOGGER = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class QueryElicitationConfig(BaseModel):
    # Settings for the optional "approve your search queries" step.
    # These are comments rather than a class docstring on purpose: on a
    # pydantic model the docstring becomes the JSON-schema `description`,
    # which would alter the generated configuration schema.
    model_config = get_configuration_dict()

    # Switch for the elicitation step. When False, QueryElicitationService
    # returns the proposed queries unchanged.
    enable_elicitation: bool = Field(
        True,
        description="Whether to enable elicitation. This flag is relevant only if the associated feature flag is enabled.",
    )

    # Bounded to 1..300. QueryElicitationService uses it both as the
    # elicitation expiry and as the number of one-second polls it makes
    # while waiting for the user.
    timeout_seconds: int = Field(
        60,
        ge=1,
        le=300,
        description="Timeout in seconds for waiting for user approval",
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class QueryElicitationModel(BaseModel):
    """Form model for reviewing and editing proposed web search queries.

    A plain BaseModel with a single ``queries`` list field. Use
    ``create_model_with_default_queries`` to get a variant whose form is
    pre-filled with suggested queries.
    """

    model_config = ConfigDict(title="Query Elicitation")
    queries: list[str] = Field(description="The queries to search the web for")

    @classmethod
    def create_model_with_default_queries(cls, queries: list[str]) -> type[Self]:
        """Create a subclass whose ``queries`` field defaults to ``queries``.

        The dynamic subclass keeps the same name and config (including the
        "Query Elicitation" title). Its JSON schema therefore carries the
        suggested queries as the field's ``default``, which the elicitation
        form can use for pre-population.

        Args:
            queries: Suggested query strings to pre-populate the form with.

        Returns:
            A new model class derived from ``cls`` with the default set.
        """
        model = create_model(
            cls.__name__,
            queries=(
                list[str],
                Field(
                    description="The queries to search the web for",
                    # Copy so later mutation of the caller's list cannot
                    # change this model's default or schema.
                    default=list(queries),
                ),
            ),
            __base__=cls,
        )
        return model
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class QueryElicitationService:
    """Service for managing the query elicitation workflow.

    The service first creates a form elicitation pre-filled with the proposed
    queries. It then polls the elicitation once per second until the user
    accepts, declines or cancels, or until ``config.timeout_seconds`` polls
    have elapsed. Instances are callable: ``await service(queries)`` returns
    the approved queries.
    """

    def __init__(
        self,
        chat_service: ChatService,
        display_name: str,
        config: QueryElicitationConfig,
    ):
        """Initialize the query elicitation service.

        Args:
            chat_service: Service exposing the elicitation API
                (``chat_service.elicitation``).
            display_name: Tool name shown in the elicitation UI.
            config: Elicitation settings (enable flag and timeout in seconds).
        """
        self._chat_service = chat_service
        self._display_name = display_name
        self._config = config

    async def __call__(self, queries: list[str]) -> list[str]:
        """Ask the user to approve or edit ``queries``.

        Args:
            queries: The proposed search queries.

        Returns:
            ``queries`` unchanged when elicitation is disabled. Otherwise, the
            queries from the accepted form with blank entries removed.

        Raises:
            ElicitationFailedException: The form was accepted but no
                non-blank query remained.
            ElicitationDeclinedException: The user declined the request.
            ElicitationCancelledException: The elicitation was cancelled.
            ElicitationExpiredException: No decision was observed within the
                polling window.
        """
        if not self._config.enable_elicitation:
            return queries

        _LOGGER.info("Creating elicitation...")

        model = QueryElicitationModel.create_model_with_default_queries(queries)
        elicitation = await self._chat_service.elicitation.create_async(
            mode=ElicitationMode.FORM,
            tool_name=self._display_name,
            message="Web Search Query Approval",
            json_schema=model.model_json_schema(),
            expires_in_seconds=self._config.timeout_seconds,
        )
        _LOGGER.info(
            f"Elicitation created: {elicitation.id}. Waiting for user response for {self._config.timeout_seconds} seconds..."
        )

        # One poll per second. Each poll also costs a round-trip, so the real
        # wait can slightly exceed timeout_seconds.
        for _ in range(self._config.timeout_seconds):
            await asyncio.sleep(1)
            elicitation = await self._chat_service.elicitation.get_async(
                elicitation_id=elicitation.id,
            )

            if elicitation.status == ElicitationStatus.ACCEPTED:
                _LOGGER.info(f"Query elicitation {elicitation.id} accepted")
                submitted = QueryElicitationModel.model_validate(
                    elicitation.response_content
                ).queries

                # Empty form fields come back as blank strings. Drop them so a
                # form whose fields were all cleared is detected below instead
                # of triggering searches for "".
                approved = [query for query in submitted if query.strip()]

                if not approved:
                    raise ElicitationFailedException(
                        context="The user approved the web search request but removed all search queries from the form, resulting in zero queries to execute.",
                        instruction="The web search tool did not execute because no search queries were provided. "
                        "**IMPORTANT INFORMATION TO PROPAGATE TO THE USER:** The user was presented with an approval form (elicitation UI) showing the proposed search queries before execution. "
                        "They clicked 'Approve' or submitted the form, but all search query fields were either removed or left empty. "
                        "Because of this, the web search tool cannot perform any searches - it needs at least one search query to execute. "
                        "Explain this situation clearly to the user and inform them that the search was not performed. "
                        "Ask if they would like to retry the search with specific queries, or if they can describe what information they're looking for so you can help formulate appropriate search queries for the next approval.",
                    )

                return approved

            if elicitation.status == ElicitationStatus.DECLINED:
                _LOGGER.info(f"Query elicitation {elicitation.id} declined")
                raise ElicitationDeclinedException(
                    f"Elicitation triggered with queries {queries} was declined"
                )

            if elicitation.status == ElicitationStatus.CANCELLED:
                _LOGGER.info(f"Query elicitation {elicitation.id} cancelled")
                raise ElicitationCancelledException(
                    f"Elicitation triggered with queries {queries} was cancelled"
                )

        raise ElicitationExpiredException(
            f"Query elicitation {elicitation.id} not accepted"
        )
|