orchestrator-core 4.6.1__py3-none-any.whl → 4.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +1 -1
- orchestrator/api/api_v1/endpoints/processes.py +4 -1
- orchestrator/api/api_v1/endpoints/search.py +44 -34
- orchestrator/{search/retrieval/utils.py → cli/search/display.py} +4 -29
- orchestrator/cli/search/search_explore.py +22 -24
- orchestrator/cli/search/speedtest.py +11 -9
- orchestrator/db/models.py +6 -6
- orchestrator/graphql/resolvers/helpers.py +15 -0
- orchestrator/graphql/resolvers/process.py +5 -3
- orchestrator/graphql/resolvers/product.py +3 -2
- orchestrator/graphql/resolvers/product_block.py +3 -2
- orchestrator/graphql/resolvers/resource_type.py +3 -2
- orchestrator/graphql/resolvers/scheduled_tasks.py +3 -1
- orchestrator/graphql/resolvers/settings.py +2 -0
- orchestrator/graphql/resolvers/subscription.py +5 -3
- orchestrator/graphql/resolvers/version.py +2 -0
- orchestrator/graphql/resolvers/workflow.py +3 -2
- orchestrator/graphql/schemas/process.py +3 -3
- orchestrator/log_config.py +2 -0
- orchestrator/schemas/search.py +1 -1
- orchestrator/schemas/search_requests.py +59 -0
- orchestrator/search/agent/handlers.py +129 -0
- orchestrator/search/agent/prompts.py +54 -33
- orchestrator/search/agent/state.py +9 -24
- orchestrator/search/agent/tools.py +223 -144
- orchestrator/search/agent/validation.py +80 -0
- orchestrator/search/{schemas → aggregations}/__init__.py +20 -0
- orchestrator/search/aggregations/base.py +201 -0
- orchestrator/search/core/types.py +3 -2
- orchestrator/search/filters/__init__.py +4 -0
- orchestrator/search/filters/definitions.py +22 -1
- orchestrator/search/filters/numeric_filter.py +3 -3
- orchestrator/search/llm_migration.py +2 -1
- orchestrator/search/query/__init__.py +90 -0
- orchestrator/search/query/builder.py +285 -0
- orchestrator/search/query/engine.py +162 -0
- orchestrator/search/{retrieval → query}/exceptions.py +38 -7
- orchestrator/search/query/mixins.py +95 -0
- orchestrator/search/query/queries.py +129 -0
- orchestrator/search/query/results.py +252 -0
- orchestrator/search/{retrieval/query_state.py → query/state.py} +31 -11
- orchestrator/search/{retrieval → query}/validation.py +58 -1
- orchestrator/search/retrieval/__init__.py +0 -5
- orchestrator/search/retrieval/pagination.py +7 -8
- orchestrator/search/retrieval/retrievers/base.py +9 -9
- orchestrator/workflows/translations/en-GB.json +1 -0
- {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/METADATA +16 -15
- {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/RECORD +51 -45
- orchestrator/search/retrieval/builder.py +0 -127
- orchestrator/search/retrieval/engine.py +0 -197
- orchestrator/search/schemas/parameters.py +0 -133
- orchestrator/search/schemas/results.py +0 -80
- /orchestrator/search/{export.py → query/export.py} +0 -0
- {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,10 +12,10 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
|
|
14
14
|
import json
|
|
15
|
-
from typing import Any
|
|
15
|
+
from typing import Any, cast
|
|
16
16
|
|
|
17
17
|
import structlog
|
|
18
|
-
from ag_ui.core import EventType,
|
|
18
|
+
from ag_ui.core import EventType, StateSnapshotEvent
|
|
19
19
|
from pydantic_ai import RunContext
|
|
20
20
|
from pydantic_ai.ag_ui import StateDeps
|
|
21
21
|
from pydantic_ai.exceptions import ModelRetry
|
|
@@ -26,17 +26,28 @@ from orchestrator.api.api_v1.endpoints.search import (
|
|
|
26
26
|
get_definitions,
|
|
27
27
|
list_paths,
|
|
28
28
|
)
|
|
29
|
-
from orchestrator.db import
|
|
30
|
-
from orchestrator.search.agent.
|
|
31
|
-
|
|
29
|
+
from orchestrator.db import db
|
|
30
|
+
from orchestrator.search.agent.handlers import (
|
|
31
|
+
execute_aggregation_with_persistence,
|
|
32
|
+
execute_search_with_persistence,
|
|
33
|
+
)
|
|
34
|
+
from orchestrator.search.agent.state import SearchState
|
|
35
|
+
from orchestrator.search.agent.validation import require_action
|
|
36
|
+
from orchestrator.search.aggregations import Aggregation, FieldAggregation, TemporalGrouping
|
|
32
37
|
from orchestrator.search.core.types import ActionType, EntityType, FilterOp
|
|
33
|
-
from orchestrator.search.export import fetch_export_data
|
|
34
38
|
from orchestrator.search.filters import FilterTree
|
|
35
|
-
from orchestrator.search.
|
|
36
|
-
from orchestrator.search.
|
|
37
|
-
from orchestrator.search.
|
|
38
|
-
from orchestrator.search.
|
|
39
|
-
from orchestrator.search.
|
|
39
|
+
from orchestrator.search.query import engine
|
|
40
|
+
from orchestrator.search.query.exceptions import PathNotFoundError, QueryValidationError
|
|
41
|
+
from orchestrator.search.query.export import fetch_export_data
|
|
42
|
+
from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query, SelectQuery
|
|
43
|
+
from orchestrator.search.query.results import AggregationResponse, AggregationResult, ExportData, VisualizationType
|
|
44
|
+
from orchestrator.search.query.state import QueryState
|
|
45
|
+
from orchestrator.search.query.validation import (
|
|
46
|
+
validate_aggregation_field,
|
|
47
|
+
validate_filter_path,
|
|
48
|
+
validate_filter_tree,
|
|
49
|
+
validate_temporal_grouping_field,
|
|
50
|
+
)
|
|
40
51
|
from orchestrator.settings import app_settings
|
|
41
52
|
|
|
42
53
|
logger = structlog.get_logger(__name__)
|
|
@@ -53,27 +64,11 @@ def last_user_message(ctx: RunContext[StateDeps[SearchState]]) -> str | None:
|
|
|
53
64
|
return None
|
|
54
65
|
|
|
55
66
|
|
|
56
|
-
def _set_parameters(
|
|
57
|
-
ctx: RunContext[StateDeps[SearchState]],
|
|
58
|
-
entity_type: EntityType,
|
|
59
|
-
action: str | ActionType,
|
|
60
|
-
query: str,
|
|
61
|
-
filters: Any | None,
|
|
62
|
-
) -> None:
|
|
63
|
-
"""Internal helper to set parameters."""
|
|
64
|
-
ctx.deps.state.parameters = {
|
|
65
|
-
"action": action,
|
|
66
|
-
"entity_type": entity_type,
|
|
67
|
-
"filters": filters,
|
|
68
|
-
"query": query,
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
|
|
72
67
|
@search_toolset.tool
|
|
73
68
|
async def start_new_search(
|
|
74
69
|
ctx: RunContext[StateDeps[SearchState]],
|
|
75
70
|
entity_type: EntityType,
|
|
76
|
-
action:
|
|
71
|
+
action: ActionType = ActionType.SELECT,
|
|
77
72
|
) -> StateSnapshotEvent:
|
|
78
73
|
"""Starts a completely new search, clearing all previous state.
|
|
79
74
|
|
|
@@ -90,13 +85,26 @@ async def start_new_search(
|
|
|
90
85
|
)
|
|
91
86
|
|
|
92
87
|
# Clear all state
|
|
93
|
-
ctx.deps.state.
|
|
94
|
-
ctx.deps.state.
|
|
95
|
-
|
|
96
|
-
#
|
|
97
|
-
|
|
88
|
+
ctx.deps.state.results_count = None
|
|
89
|
+
ctx.deps.state.action = action
|
|
90
|
+
|
|
91
|
+
# Create the appropriate query object based on action
|
|
92
|
+
if action == ActionType.SELECT:
|
|
93
|
+
ctx.deps.state.query = SelectQuery(
|
|
94
|
+
entity_type=entity_type,
|
|
95
|
+
query_text=final_query,
|
|
96
|
+
)
|
|
97
|
+
elif action == ActionType.COUNT:
|
|
98
|
+
ctx.deps.state.query = CountQuery(
|
|
99
|
+
entity_type=entity_type,
|
|
100
|
+
)
|
|
101
|
+
else: # ActionType.AGGREGATE
|
|
102
|
+
ctx.deps.state.query = AggregateQuery(
|
|
103
|
+
entity_type=entity_type,
|
|
104
|
+
aggregations=[], # Will be set by set_aggregations tool
|
|
105
|
+
)
|
|
98
106
|
|
|
99
|
-
logger.debug("New search started",
|
|
107
|
+
logger.debug("New search started", action=action.value, query_type=type(ctx.deps.state.query).__name__)
|
|
100
108
|
|
|
101
109
|
return StateSnapshotEvent(
|
|
102
110
|
type=EventType.STATE_SNAPSHOT,
|
|
@@ -108,18 +116,15 @@ async def start_new_search(
|
|
|
108
116
|
async def set_filter_tree(
|
|
109
117
|
ctx: RunContext[StateDeps[SearchState]],
|
|
110
118
|
filters: FilterTree | None,
|
|
111
|
-
) ->
|
|
119
|
+
) -> StateSnapshotEvent:
|
|
112
120
|
"""Replace current filters atomically with a full FilterTree, or clear with None.
|
|
113
121
|
|
|
114
|
-
|
|
115
|
-
- Root/group operators must be 'AND' or 'OR' (uppercase).
|
|
116
|
-
- Provide either PathFilters or nested groups under `children`.
|
|
117
|
-
- See the FilterTree schema examples for the exact shape.
|
|
122
|
+
See FilterTree model for structure, operators, and examples.
|
|
118
123
|
"""
|
|
119
|
-
if ctx.deps.state.
|
|
120
|
-
raise ModelRetry("Search
|
|
124
|
+
if ctx.deps.state.query is None:
|
|
125
|
+
raise ModelRetry("Search query is not initialized. Call start_new_search first.")
|
|
121
126
|
|
|
122
|
-
entity_type =
|
|
127
|
+
entity_type = ctx.deps.state.query.entity_type
|
|
123
128
|
|
|
124
129
|
logger.debug(
|
|
125
130
|
"Setting filter tree",
|
|
@@ -133,108 +138,112 @@ async def set_filter_tree(
|
|
|
133
138
|
except PathNotFoundError as e:
|
|
134
139
|
logger.debug(f"{PathNotFoundError.__name__}: {str(e)}")
|
|
135
140
|
raise ModelRetry(f"{str(e)} Use discover_filter_paths tool to find valid paths.")
|
|
136
|
-
except
|
|
141
|
+
except QueryValidationError as e:
|
|
137
142
|
# ModelRetry will trigger an agent retry, containing the specific validation error.
|
|
138
|
-
logger.debug(f"
|
|
143
|
+
logger.debug(f"Query validation failed: {str(e)}")
|
|
139
144
|
raise ModelRetry(str(e))
|
|
140
145
|
except Exception as e:
|
|
141
146
|
logger.error("Unexpected Filter validation exception", error=str(e))
|
|
142
147
|
raise ModelRetry(f"Filter validation failed: {str(e)}. Please check your filter structure and try again.")
|
|
143
148
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
path="/parameters/filters",
|
|
152
|
-
value=filter_data,
|
|
153
|
-
existed=filters_existed,
|
|
154
|
-
)
|
|
155
|
-
],
|
|
149
|
+
ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"filters": filters})
|
|
150
|
+
|
|
151
|
+
# Use snapshot to workaround state persistence issue
|
|
152
|
+
# TODO: Fix root cause; state tree may be empty on frontend when parameters are being set
|
|
153
|
+
return StateSnapshotEvent(
|
|
154
|
+
type=EventType.STATE_SNAPSHOT,
|
|
155
|
+
snapshot=ctx.deps.state.model_dump(),
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
|
|
159
159
|
@search_toolset.tool
|
|
160
|
+
@require_action(ActionType.SELECT)
|
|
160
161
|
async def run_search(
|
|
161
162
|
ctx: RunContext[StateDeps[SearchState]],
|
|
162
163
|
limit: int = 10,
|
|
163
|
-
) ->
|
|
164
|
-
"""Execute
|
|
165
|
-
if not ctx.deps.state.parameters:
|
|
166
|
-
raise ValueError("No search parameters set")
|
|
164
|
+
) -> AggregationResponse:
|
|
165
|
+
"""Execute a search to find and rank entities.
|
|
167
166
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
167
|
+
Use this tool for SELECT action to find entities matching your criteria.
|
|
168
|
+
For counting or computing statistics, use run_aggregation instead.
|
|
169
|
+
"""
|
|
170
|
+
query = cast(SelectQuery, cast(Query, ctx.deps.state.query).model_copy(update={"limit": limit}))
|
|
171
|
+
|
|
172
|
+
search_response, run_id, query_id = await execute_search_with_persistence(query, db.session, ctx.deps.state.run_id)
|
|
173
|
+
|
|
174
|
+
ctx.deps.state.run_id = run_id
|
|
175
|
+
ctx.deps.state.query_id = query_id
|
|
176
|
+
ctx.deps.state.results_count = len(search_response.results)
|
|
177
|
+
|
|
178
|
+
# Convert SearchResults to AggregationResults for consistent rendering
|
|
179
|
+
aggregation_results = [
|
|
180
|
+
AggregationResult(
|
|
181
|
+
group_values={
|
|
182
|
+
"entity_id": result.entity_id,
|
|
183
|
+
"title": result.entity_title,
|
|
184
|
+
"entity_type": result.entity_type.value,
|
|
185
|
+
},
|
|
186
|
+
aggregations={"score": result.score},
|
|
187
|
+
)
|
|
188
|
+
for result in search_response.results
|
|
189
|
+
]
|
|
190
|
+
|
|
191
|
+
# For now use the default table visualization for search results
|
|
192
|
+
aggregation_response = AggregationResponse(
|
|
193
|
+
results=aggregation_results,
|
|
194
|
+
total_groups=len(aggregation_results),
|
|
195
|
+
metadata=search_response.metadata,
|
|
196
|
+
visualization_type=VisualizationType(type="table"),
|
|
176
197
|
)
|
|
177
198
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
199
|
+
logger.debug(
|
|
200
|
+
"Search completed",
|
|
201
|
+
total_count=ctx.deps.state.results_count,
|
|
202
|
+
query_id=str(query_id),
|
|
203
|
+
)
|
|
182
204
|
|
|
183
|
-
|
|
205
|
+
return aggregation_response
|
|
184
206
|
|
|
185
|
-
if not ctx.deps.state.run_id:
|
|
186
|
-
agent_run = AgentRunTable(agent_type="search")
|
|
187
207
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
208
|
+
@search_toolset.tool
|
|
209
|
+
@require_action(ActionType.COUNT, ActionType.AGGREGATE)
|
|
210
|
+
async def run_aggregation(
|
|
211
|
+
ctx: RunContext[StateDeps[SearchState]],
|
|
212
|
+
visualization_type: VisualizationType,
|
|
213
|
+
) -> AggregationResponse:
|
|
214
|
+
"""Execute an aggregation to compute counts or statistics over entities.
|
|
191
215
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
216
|
+
Use this tool for COUNT or AGGREGATE actions after setting up:
|
|
217
|
+
- Grouping fields with set_grouping or set_temporal_grouping
|
|
218
|
+
- Aggregation functions with set_aggregations (for AGGREGATE action)
|
|
219
|
+
"""
|
|
220
|
+
query = cast(CountQuery | AggregateQuery, ctx.deps.state.query)
|
|
195
221
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
search_query = SearchQueryTable.from_state(
|
|
202
|
-
state=query_state,
|
|
203
|
-
run_id=ctx.deps.state.run_id,
|
|
204
|
-
query_number=query_number,
|
|
222
|
+
logger.debug(
|
|
223
|
+
"Executing aggregation",
|
|
224
|
+
search_entity_type=query.entity_type.value,
|
|
225
|
+
has_filters=query.filters is not None,
|
|
226
|
+
action=query.action,
|
|
205
227
|
)
|
|
206
|
-
db.session.add(search_query)
|
|
207
|
-
db.session.commit()
|
|
208
|
-
db.session.expire_all()
|
|
209
|
-
|
|
210
|
-
query_id_existed = ctx.deps.state.query_id is not None
|
|
211
|
-
ctx.deps.state.query_id = search_query.query_id
|
|
212
|
-
logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number)
|
|
213
|
-
changes.append(JSONPatchOp.upsert(path="/query_id", value=str(ctx.deps.state.query_id), existed=query_id_existed))
|
|
214
228
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
total_results=len(search_response.results),
|
|
229
|
+
aggregation_response, run_id, query_id = await execute_aggregation_with_persistence(
|
|
230
|
+
query, db.session, ctx.deps.state.run_id
|
|
218
231
|
)
|
|
219
232
|
|
|
220
|
-
|
|
221
|
-
|
|
233
|
+
ctx.deps.state.run_id = run_id
|
|
234
|
+
ctx.deps.state.query_id = query_id
|
|
235
|
+
ctx.deps.state.results_count = aggregation_response.total_groups
|
|
222
236
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
)
|
|
231
|
-
changes.append(
|
|
232
|
-
JSONPatchOp.upsert(
|
|
233
|
-
path="/results_data", value=ctx.deps.state.results_data.model_dump(), existed=results_data_existed
|
|
234
|
-
)
|
|
237
|
+
aggregation_response.visualization_type = visualization_type
|
|
238
|
+
|
|
239
|
+
logger.debug(
|
|
240
|
+
"Aggregation completed",
|
|
241
|
+
total_groups=aggregation_response.total_groups,
|
|
242
|
+
visualization_type=visualization_type.type,
|
|
243
|
+
query_id=str(query_id),
|
|
235
244
|
)
|
|
236
245
|
|
|
237
|
-
return
|
|
246
|
+
return aggregation_response
|
|
238
247
|
|
|
239
248
|
|
|
240
249
|
@search_toolset.tool
|
|
@@ -248,10 +257,11 @@ async def discover_filter_paths(
|
|
|
248
257
|
Returns a dictionary where each key is a field_name from the input list and
|
|
249
258
|
the value is its discovery result.
|
|
250
259
|
"""
|
|
251
|
-
if not entity_type and ctx.deps.state.parameters:
|
|
252
|
-
entity_type = EntityType(ctx.deps.state.parameters.get("entity_type"))
|
|
253
260
|
if not entity_type:
|
|
254
|
-
|
|
261
|
+
if ctx.deps.state.query:
|
|
262
|
+
entity_type = ctx.deps.state.query.entity_type
|
|
263
|
+
else:
|
|
264
|
+
raise ModelRetry("Entity type not specified and no query in state. Call start_new_search first.")
|
|
255
265
|
|
|
256
266
|
all_results = {}
|
|
257
267
|
for field_name in field_names:
|
|
@@ -332,17 +342,21 @@ async def fetch_entity_details(
|
|
|
332
342
|
JSON string containing detailed entity information.
|
|
333
343
|
|
|
334
344
|
Raises:
|
|
335
|
-
|
|
345
|
+
ModelRetry: If no search has been executed.
|
|
336
346
|
"""
|
|
337
|
-
if
|
|
338
|
-
raise
|
|
347
|
+
if ctx.deps.state.query_id is None:
|
|
348
|
+
raise ModelRetry("No query_id found. Run a search first.")
|
|
339
349
|
|
|
340
|
-
|
|
341
|
-
|
|
350
|
+
# Load the saved query and re-execute it to get entity IDs
|
|
351
|
+
query_state = QueryState.load_from_id(ctx.deps.state.query_id, SelectQuery)
|
|
352
|
+
query = query_state.query.model_copy(update={"limit": limit})
|
|
353
|
+
search_response = await engine.execute_search(query, db.session)
|
|
354
|
+
entity_ids = [r.entity_id for r in search_response.results]
|
|
342
355
|
|
|
343
|
-
|
|
356
|
+
if not entity_ids:
|
|
357
|
+
return json.dumps({"message": "No entities found in search results."})
|
|
344
358
|
|
|
345
|
-
|
|
359
|
+
entity_type = query.entity_type
|
|
346
360
|
|
|
347
361
|
logger.debug(
|
|
348
362
|
"Fetching detailed entity data",
|
|
@@ -356,23 +370,16 @@ async def fetch_entity_details(
|
|
|
356
370
|
|
|
357
371
|
|
|
358
372
|
@search_toolset.tool
|
|
373
|
+
@require_action(ActionType.SELECT)
|
|
359
374
|
async def prepare_export(
|
|
360
375
|
ctx: RunContext[StateDeps[SearchState]],
|
|
361
|
-
) ->
|
|
362
|
-
"""Prepares export URL using the last executed search query.
|
|
363
|
-
if not ctx.deps.state.query_id or not ctx.deps.state.run_id:
|
|
364
|
-
raise ValueError("No search has been executed yet. Run a search first before exporting.")
|
|
376
|
+
) -> ExportData:
|
|
377
|
+
"""Prepares export URL using the last executed search query.
|
|
365
378
|
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
action = ctx.deps.state.parameters.get("action", ActionType.SELECT)
|
|
371
|
-
if action != ActionType.SELECT:
|
|
372
|
-
raise ValueError(
|
|
373
|
-
f"Export is only available for SELECT actions. Current action is '{action}'. "
|
|
374
|
-
"Please run a SELECT search first."
|
|
375
|
-
)
|
|
379
|
+
Returns export data which is displayed directly in the UI.
|
|
380
|
+
"""
|
|
381
|
+
if not ctx.deps.state.query_id or not ctx.deps.state.run_id:
|
|
382
|
+
raise ModelRetry("No search has been executed yet. Run a search first before exporting.")
|
|
376
383
|
|
|
377
384
|
logger.debug(
|
|
378
385
|
"Prepared query for export",
|
|
@@ -381,16 +388,88 @@ async def prepare_export(
|
|
|
381
388
|
|
|
382
389
|
download_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}/export"
|
|
383
390
|
|
|
384
|
-
|
|
391
|
+
export_data = ExportData(
|
|
385
392
|
query_id=str(ctx.deps.state.query_id),
|
|
386
393
|
download_url=download_url,
|
|
387
394
|
message="Export ready for download.",
|
|
388
395
|
)
|
|
389
396
|
|
|
390
|
-
logger.debug("Export
|
|
397
|
+
logger.debug("Export prepared", query_id=export_data.query_id)
|
|
398
|
+
|
|
399
|
+
return export_data
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
@search_toolset.tool(retries=2)
|
|
403
|
+
@require_action(ActionType.COUNT, ActionType.AGGREGATE)
|
|
404
|
+
async def set_grouping(
|
|
405
|
+
ctx: RunContext[StateDeps[SearchState]],
|
|
406
|
+
group_by_paths: list[str],
|
|
407
|
+
) -> StateSnapshotEvent:
|
|
408
|
+
"""Set which field paths to group results by for aggregation.
|
|
409
|
+
|
|
410
|
+
Only used with COUNT or AGGREGATE actions. Paths must exist in the schema; use discover_filter_paths to verify.
|
|
411
|
+
"""
|
|
412
|
+
for path in group_by_paths:
|
|
413
|
+
field_type = validate_filter_path(path)
|
|
414
|
+
if field_type is None:
|
|
415
|
+
raise ModelRetry(
|
|
416
|
+
f"Path '{path}' not found in database schema. "
|
|
417
|
+
f"Use discover_filter_paths(['{path.split('.')[-1]}']) to find valid paths."
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"group_by": group_by_paths})
|
|
421
|
+
|
|
422
|
+
return StateSnapshotEvent(
|
|
423
|
+
type=EventType.STATE_SNAPSHOT,
|
|
424
|
+
snapshot=ctx.deps.state.model_dump(),
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
@search_toolset.tool(retries=2)
|
|
429
|
+
@require_action(ActionType.AGGREGATE)
|
|
430
|
+
async def set_aggregations(
|
|
431
|
+
ctx: RunContext[StateDeps[SearchState]],
|
|
432
|
+
aggregations: list[Aggregation],
|
|
433
|
+
) -> StateSnapshotEvent:
|
|
434
|
+
"""Define what aggregations to compute over the matching records.
|
|
435
|
+
|
|
436
|
+
Only used with AGGREGATE action. See Aggregation model (CountAggregation, FieldAggregation) for structure and field requirements.
|
|
437
|
+
"""
|
|
438
|
+
# Validate field paths for FieldAggregations
|
|
439
|
+
try:
|
|
440
|
+
for agg in aggregations:
|
|
441
|
+
if isinstance(agg, FieldAggregation):
|
|
442
|
+
validate_aggregation_field(agg.type, agg.field)
|
|
443
|
+
except ValueError as e:
|
|
444
|
+
raise ModelRetry(f"{str(e)} Use discover_filter_paths to find valid paths.")
|
|
445
|
+
|
|
446
|
+
ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"aggregations": aggregations})
|
|
447
|
+
|
|
448
|
+
return StateSnapshotEvent(
|
|
449
|
+
type=EventType.STATE_SNAPSHOT,
|
|
450
|
+
snapshot=ctx.deps.state.model_dump(),
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
@search_toolset.tool(retries=2)
|
|
455
|
+
@require_action(ActionType.COUNT, ActionType.AGGREGATE)
|
|
456
|
+
async def set_temporal_grouping(
|
|
457
|
+
ctx: RunContext[StateDeps[SearchState]],
|
|
458
|
+
temporal_groups: list[TemporalGrouping],
|
|
459
|
+
) -> StateSnapshotEvent:
|
|
460
|
+
"""Set temporal grouping to group datetime fields by time periods.
|
|
461
|
+
|
|
462
|
+
Only used with COUNT or AGGREGATE actions. See TemporalGrouping model for structure, periods, and examples.
|
|
463
|
+
"""
|
|
464
|
+
# Validate that fields exist and are datetime types
|
|
465
|
+
try:
|
|
466
|
+
for tg in temporal_groups:
|
|
467
|
+
validate_temporal_grouping_field(tg.field)
|
|
468
|
+
except ValueError as e:
|
|
469
|
+
raise ModelRetry(f"{str(e)} Use discover_filter_paths to find datetime fields.")
|
|
470
|
+
|
|
471
|
+
ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"temporal_group_by": temporal_groups})
|
|
391
472
|
|
|
392
|
-
# Should use StateDelta here? Use snapshot to workaround state persistence issue
|
|
393
|
-
# TODO: Fix root cause; state is empty on frontend when it should have data from run_search
|
|
394
473
|
return StateSnapshotEvent(
|
|
395
474
|
type=EventType.STATE_SNAPSHOT,
|
|
396
475
|
snapshot=ctx.deps.state.model_dump(),
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from functools import wraps
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
17
|
+
|
|
18
|
+
import structlog
|
|
19
|
+
from pydantic_ai import RunContext
|
|
20
|
+
from pydantic_ai.ag_ui import StateDeps
|
|
21
|
+
from pydantic_ai.exceptions import ModelRetry
|
|
22
|
+
|
|
23
|
+
from orchestrator.search.core.types import ActionType
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from orchestrator.search.agent.state import SearchState
|
|
27
|
+
|
|
28
|
+
logger = structlog.get_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def require_action(*allowed_actions: ActionType) -> Callable:
|
|
32
|
+
"""Validate that the current search action is one of the allowed types.
|
|
33
|
+
|
|
34
|
+
This decorator is in preparation for a future finite state machine implementation
|
|
35
|
+
where we explicitly define which actions are valid in which states.
|
|
36
|
+
|
|
37
|
+
Example:
|
|
38
|
+
@search_toolset.tool
|
|
39
|
+
@require_action(ActionType.SELECT)
|
|
40
|
+
async def run_search(ctx, ...):
|
|
41
|
+
# Only callable when action is SELECT
|
|
42
|
+
...
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
allowed_actions: One or more ActionType values that are valid for this tool.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
Decorated function that validates action before execution.
|
|
49
|
+
|
|
50
|
+
Raises:
|
|
51
|
+
ModelRetry: If current action is not in allowed_actions.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
def decorator(func: Callable) -> Callable:
|
|
55
|
+
@wraps(func)
|
|
56
|
+
async def wrapper(ctx: "RunContext[StateDeps[SearchState]]", *args: Any, **kwargs: Any) -> Any:
|
|
57
|
+
if ctx.deps.state.action is None or ctx.deps.state.query is None:
|
|
58
|
+
logger.warning(f"Action validation failed for {func.__name__}: action or query is None")
|
|
59
|
+
raise ModelRetry("Search action and query are not initialized. Call start_new_search first.")
|
|
60
|
+
|
|
61
|
+
current_action = ctx.deps.state.action
|
|
62
|
+
|
|
63
|
+
if current_action not in allowed_actions:
|
|
64
|
+
allowed_names = ", ".join(a.value for a in allowed_actions)
|
|
65
|
+
logger.warning(
|
|
66
|
+
"Invalid action for tool",
|
|
67
|
+
tool=func.__name__,
|
|
68
|
+
allowed_actions=allowed_names,
|
|
69
|
+
current_action=current_action.value,
|
|
70
|
+
)
|
|
71
|
+
raise ModelRetry(
|
|
72
|
+
f"{func.__name__} is only available for {allowed_names} action(s). "
|
|
73
|
+
f"Current action is '{current_action.value}'."
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
return await func(ctx, *args, **kwargs)
|
|
77
|
+
|
|
78
|
+
return wrapper
|
|
79
|
+
|
|
80
|
+
return decorator
|
|
@@ -10,3 +10,23 @@
|
|
|
10
10
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from .base import (
|
|
15
|
+
Aggregation,
|
|
16
|
+
AggregationType,
|
|
17
|
+
BaseAggregation,
|
|
18
|
+
CountAggregation,
|
|
19
|
+
FieldAggregation,
|
|
20
|
+
TemporalGrouping,
|
|
21
|
+
TemporalPeriod,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"Aggregation",
|
|
26
|
+
"AggregationType",
|
|
27
|
+
"BaseAggregation",
|
|
28
|
+
"CountAggregation",
|
|
29
|
+
"FieldAggregation",
|
|
30
|
+
"TemporalGrouping",
|
|
31
|
+
"TemporalPeriod",
|
|
32
|
+
]
|