orchestrator-core 4.6.1-py3-none-any.whl → 4.6.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/endpoints/processes.py +4 -1
  3. orchestrator/api/api_v1/endpoints/search.py +44 -34
  4. orchestrator/{search/retrieval/utils.py → cli/search/display.py} +4 -29
  5. orchestrator/cli/search/search_explore.py +22 -24
  6. orchestrator/cli/search/speedtest.py +11 -9
  7. orchestrator/db/models.py +6 -6
  8. orchestrator/graphql/resolvers/helpers.py +15 -0
  9. orchestrator/graphql/resolvers/process.py +5 -3
  10. orchestrator/graphql/resolvers/product.py +3 -2
  11. orchestrator/graphql/resolvers/product_block.py +3 -2
  12. orchestrator/graphql/resolvers/resource_type.py +3 -2
  13. orchestrator/graphql/resolvers/scheduled_tasks.py +3 -1
  14. orchestrator/graphql/resolvers/settings.py +2 -0
  15. orchestrator/graphql/resolvers/subscription.py +5 -3
  16. orchestrator/graphql/resolvers/version.py +2 -0
  17. orchestrator/graphql/resolvers/workflow.py +3 -2
  18. orchestrator/graphql/schemas/process.py +3 -3
  19. orchestrator/log_config.py +2 -0
  20. orchestrator/schemas/search.py +1 -1
  21. orchestrator/schemas/search_requests.py +59 -0
  22. orchestrator/search/agent/handlers.py +129 -0
  23. orchestrator/search/agent/prompts.py +54 -33
  24. orchestrator/search/agent/state.py +9 -24
  25. orchestrator/search/agent/tools.py +223 -144
  26. orchestrator/search/agent/validation.py +80 -0
  27. orchestrator/search/{schemas → aggregations}/__init__.py +20 -0
  28. orchestrator/search/aggregations/base.py +201 -0
  29. orchestrator/search/core/types.py +3 -2
  30. orchestrator/search/filters/__init__.py +4 -0
  31. orchestrator/search/filters/definitions.py +22 -1
  32. orchestrator/search/filters/numeric_filter.py +3 -3
  33. orchestrator/search/llm_migration.py +2 -1
  34. orchestrator/search/query/__init__.py +90 -0
  35. orchestrator/search/query/builder.py +285 -0
  36. orchestrator/search/query/engine.py +162 -0
  37. orchestrator/search/{retrieval → query}/exceptions.py +38 -7
  38. orchestrator/search/query/mixins.py +95 -0
  39. orchestrator/search/query/queries.py +129 -0
  40. orchestrator/search/query/results.py +252 -0
  41. orchestrator/search/{retrieval/query_state.py → query/state.py} +31 -11
  42. orchestrator/search/{retrieval → query}/validation.py +58 -1
  43. orchestrator/search/retrieval/__init__.py +0 -5
  44. orchestrator/search/retrieval/pagination.py +7 -8
  45. orchestrator/search/retrieval/retrievers/base.py +9 -9
  46. orchestrator/workflows/translations/en-GB.json +1 -0
  47. {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/METADATA +16 -15
  48. {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/RECORD +51 -45
  49. orchestrator/search/retrieval/builder.py +0 -127
  50. orchestrator/search/retrieval/engine.py +0 -197
  51. orchestrator/search/schemas/parameters.py +0 -133
  52. orchestrator/search/schemas/results.py +0 -80
  53. /orchestrator/search/{export.py → query/export.py} +0 -0
  54. {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/WHEEL +0 -0
  55. {orchestrator_core-4.6.1.dist-info → orchestrator_core-4.6.3.dist-info}/licenses/LICENSE +0 -0
orchestrator/search/agent/tools.py

@@ -12,10 +12,10 @@
 # limitations under the License.

 import json
-from typing import Any
+from typing import Any, cast

 import structlog
-from ag_ui.core import EventType, StateDeltaEvent, StateSnapshotEvent
+from ag_ui.core import EventType, StateSnapshotEvent
 from pydantic_ai import RunContext
 from pydantic_ai.ag_ui import StateDeps
 from pydantic_ai.exceptions import ModelRetry
@@ -26,17 +26,28 @@ from orchestrator.api.api_v1.endpoints.search import (
     get_definitions,
     list_paths,
 )
-from orchestrator.db import AgentRunTable, SearchQueryTable, db
-from orchestrator.search.agent.json_patch import JSONPatchOp
-from orchestrator.search.agent.state import ExportData, SearchResultsData, SearchState
+from orchestrator.db import db
+from orchestrator.search.agent.handlers import (
+    execute_aggregation_with_persistence,
+    execute_search_with_persistence,
+)
+from orchestrator.search.agent.state import SearchState
+from orchestrator.search.agent.validation import require_action
+from orchestrator.search.aggregations import Aggregation, FieldAggregation, TemporalGrouping
 from orchestrator.search.core.types import ActionType, EntityType, FilterOp
-from orchestrator.search.export import fetch_export_data
 from orchestrator.search.filters import FilterTree
-from orchestrator.search.retrieval.engine import execute_search
-from orchestrator.search.retrieval.exceptions import FilterValidationError, PathNotFoundError
-from orchestrator.search.retrieval.query_state import SearchQueryState
-from orchestrator.search.retrieval.validation import validate_filter_tree
-from orchestrator.search.schemas.parameters import BaseSearchParameters
+from orchestrator.search.query import engine
+from orchestrator.search.query.exceptions import PathNotFoundError, QueryValidationError
+from orchestrator.search.query.export import fetch_export_data
+from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query, SelectQuery
+from orchestrator.search.query.results import AggregationResponse, AggregationResult, ExportData, VisualizationType
+from orchestrator.search.query.state import QueryState
+from orchestrator.search.query.validation import (
+    validate_aggregation_field,
+    validate_filter_path,
+    validate_filter_tree,
+    validate_temporal_grouping_field,
+)
 from orchestrator.settings import app_settings

 logger = structlog.get_logger(__name__)
@@ -53,27 +64,11 @@ def last_user_message(ctx: RunContext[StateDeps[SearchState]]) -> str | None:
     return None


-def _set_parameters(
-    ctx: RunContext[StateDeps[SearchState]],
-    entity_type: EntityType,
-    action: str | ActionType,
-    query: str,
-    filters: Any | None,
-) -> None:
-    """Internal helper to set parameters."""
-    ctx.deps.state.parameters = {
-        "action": action,
-        "entity_type": entity_type,
-        "filters": filters,
-        "query": query,
-    }
-
-
 @search_toolset.tool
 async def start_new_search(
     ctx: RunContext[StateDeps[SearchState]],
     entity_type: EntityType,
-    action: str | ActionType = ActionType.SELECT,
+    action: ActionType = ActionType.SELECT,
 ) -> StateSnapshotEvent:
     """Starts a completely new search, clearing all previous state.

@@ -90,13 +85,26 @@ async def start_new_search(
     )

     # Clear all state
-    ctx.deps.state.results_data = None
-    ctx.deps.state.export_data = None
-
-    # Set fresh parameters with no filters
-    _set_parameters(ctx, entity_type, action, final_query, None)
+    ctx.deps.state.results_count = None
+    ctx.deps.state.action = action
+
+    # Create the appropriate query object based on action
+    if action == ActionType.SELECT:
+        ctx.deps.state.query = SelectQuery(
+            entity_type=entity_type,
+            query_text=final_query,
+        )
+    elif action == ActionType.COUNT:
+        ctx.deps.state.query = CountQuery(
+            entity_type=entity_type,
+        )
+    else:  # ActionType.AGGREGATE
+        ctx.deps.state.query = AggregateQuery(
+            entity_type=entity_type,
+            aggregations=[],  # Will be set by set_aggregations tool
+        )

-    logger.debug("New search started", parameters=ctx.deps.state.parameters)
+    logger.debug("New search started", action=action.value, query_type=type(ctx.deps.state.query).__name__)

     return StateSnapshotEvent(
         type=EventType.STATE_SNAPSHOT,
@@ -108,18 +116,15 @@ async def start_new_search(
 async def set_filter_tree(
     ctx: RunContext[StateDeps[SearchState]],
     filters: FilterTree | None,
-) -> StateDeltaEvent:
+) -> StateSnapshotEvent:
     """Replace current filters atomically with a full FilterTree, or clear with None.

-    Requirements:
-    - Root/group operators must be 'AND' or 'OR' (uppercase).
-    - Provide either PathFilters or nested groups under `children`.
-    - See the FilterTree schema examples for the exact shape.
+    See FilterTree model for structure, operators, and examples.
     """
-    if ctx.deps.state.parameters is None:
-        raise ModelRetry("Search parameters are not initialized. Call start_new_search first.")
+    if ctx.deps.state.query is None:
+        raise ModelRetry("Search query is not initialized. Call start_new_search first.")

-    entity_type = EntityType(ctx.deps.state.parameters["entity_type"])
+    entity_type = ctx.deps.state.query.entity_type

     logger.debug(
         "Setting filter tree",
@@ -133,108 +138,112 @@ async def set_filter_tree(
     except PathNotFoundError as e:
         logger.debug(f"{PathNotFoundError.__name__}: {str(e)}")
         raise ModelRetry(f"{str(e)} Use discover_filter_paths tool to find valid paths.")
-    except FilterValidationError as e:
+    except QueryValidationError as e:
         # ModelRetry will trigger an agent retry, containing the specific validation error.
-        logger.debug(f"Filter validation failed: {str(e)}")
+        logger.debug(f"Query validation failed: {str(e)}")
         raise ModelRetry(str(e))
     except Exception as e:
         logger.error("Unexpected Filter validation exception", error=str(e))
         raise ModelRetry(f"Filter validation failed: {str(e)}. Please check your filter structure and try again.")

-    filter_data = None if filters is None else filters.model_dump(mode="json", by_alias=True)
-    filters_existed = "filters" in ctx.deps.state.parameters
-    ctx.deps.state.parameters["filters"] = filter_data
-    return StateDeltaEvent(
-        type=EventType.STATE_DELTA,
-        delta=[
-            JSONPatchOp.upsert(
-                path="/parameters/filters",
-                value=filter_data,
-                existed=filters_existed,
-            )
-        ],
+    ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"filters": filters})
+
+    # Use snapshot to workaround state persistence issue
+    # TODO: Fix root cause; state tree may be empty on frontend when parameters are being set
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=ctx.deps.state.model_dump(),
     )


 @search_toolset.tool
+@require_action(ActionType.SELECT)
 async def run_search(
     ctx: RunContext[StateDeps[SearchState]],
     limit: int = 10,
-) -> StateDeltaEvent:
-    """Execute the search with the current parameters and save to database."""
-    if not ctx.deps.state.parameters:
-        raise ValueError("No search parameters set")
+) -> AggregationResponse:
+    """Execute a search to find and rank entities.

-    params = BaseSearchParameters.create(**ctx.deps.state.parameters)
-    logger.debug(
-        "Executing database search",
-        search_entity_type=params.entity_type.value,
-        limit=limit,
-        has_filters=params.filters is not None,
-        query=params.query,
-        action=params.action,
+    Use this tool for SELECT action to find entities matching your criteria.
+    For counting or computing statistics, use run_aggregation instead.
+    """
+    query = cast(SelectQuery, cast(Query, ctx.deps.state.query).model_copy(update={"limit": limit}))
+
+    search_response, run_id, query_id = await execute_search_with_persistence(query, db.session, ctx.deps.state.run_id)
+
+    ctx.deps.state.run_id = run_id
+    ctx.deps.state.query_id = query_id
+    ctx.deps.state.results_count = len(search_response.results)
+
+    # Convert SearchResults to AggregationResults for consistent rendering
+    aggregation_results = [
+        AggregationResult(
+            group_values={
+                "entity_id": result.entity_id,
+                "title": result.entity_title,
+                "entity_type": result.entity_type.value,
+            },
+            aggregations={"score": result.score},
+        )
+        for result in search_response.results
+    ]
+
+    # For now use the default table visualization for search results
+    aggregation_response = AggregationResponse(
+        results=aggregation_results,
+        total_groups=len(aggregation_results),
+        metadata=search_response.metadata,
+        visualization_type=VisualizationType(type="table"),
     )

-    if params.filters:
-        logger.debug("Search filters", filters=params.filters)
-
-    params.limit = limit
+    logger.debug(
+        "Search completed",
+        total_count=ctx.deps.state.results_count,
+        query_id=str(query_id),
+    )

-    changes: list[JSONPatchOp] = []
+    return aggregation_response

-    if not ctx.deps.state.run_id:
-        agent_run = AgentRunTable(agent_type="search")

-        db.session.add(agent_run)
-        db.session.commit()
-        db.session.expire_all()  # Release connection to prevent stacking while agent runs
+@search_toolset.tool
+@require_action(ActionType.COUNT, ActionType.AGGREGATE)
+async def run_aggregation(
+    ctx: RunContext[StateDeps[SearchState]],
+    visualization_type: VisualizationType,
+) -> AggregationResponse:
+    """Execute an aggregation to compute counts or statistics over entities.

-        ctx.deps.state.run_id = agent_run.run_id
-        logger.debug("Created new agent run", run_id=str(agent_run.run_id))
-        changes.append(JSONPatchOp(op="add", path="/run_id", value=str(ctx.deps.state.run_id)))
+    Use this tool for COUNT or AGGREGATE actions after setting up:
+    - Grouping fields with set_grouping or set_temporal_grouping
+    - Aggregation functions with set_aggregations (for AGGREGATE action)
+    """
+    query = cast(CountQuery | AggregateQuery, ctx.deps.state.query)

-    # Get query with embedding and save to DB
-    search_response = await execute_search(params, db.session)
-    query_embedding = search_response.query_embedding
-    query_state = SearchQueryState(parameters=params, query_embedding=query_embedding)
-    query_number = db.session.query(SearchQueryTable).filter_by(run_id=ctx.deps.state.run_id).count() + 1
-    search_query = SearchQueryTable.from_state(
-        state=query_state,
-        run_id=ctx.deps.state.run_id,
-        query_number=query_number,
+    logger.debug(
+        "Executing aggregation",
+        search_entity_type=query.entity_type.value,
+        has_filters=query.filters is not None,
+        action=query.action,
     )
-    db.session.add(search_query)
-    db.session.commit()
-    db.session.expire_all()
-
-    query_id_existed = ctx.deps.state.query_id is not None
-    ctx.deps.state.query_id = search_query.query_id
-    logger.debug("Saved search query", query_id=str(search_query.query_id), query_number=query_number)
-    changes.append(JSONPatchOp.upsert(path="/query_id", value=str(ctx.deps.state.query_id), existed=query_id_existed))

-    logger.debug(
-        "Search completed",
-        total_results=len(search_response.results),
+    aggregation_response, run_id, query_id = await execute_aggregation_with_persistence(
+        query, db.session, ctx.deps.state.run_id
     )

-    # Store results data for both frontend display and agent context
-    results_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}"
+    ctx.deps.state.run_id = run_id
+    ctx.deps.state.query_id = query_id
+    ctx.deps.state.results_count = aggregation_response.total_groups

-    results_data_existed = ctx.deps.state.results_data is not None
-    ctx.deps.state.results_data = SearchResultsData(
-        query_id=str(ctx.deps.state.query_id),
-        results_url=results_url,
-        total_count=len(search_response.results),
-        message=f"Found {len(search_response.results)} results.",
-        results=search_response.results,  # Include actual results in state
-    )
-    changes.append(
-        JSONPatchOp.upsert(
-            path="/results_data", value=ctx.deps.state.results_data.model_dump(), existed=results_data_existed
-        )
+    aggregation_response.visualization_type = visualization_type
+
+    logger.debug(
+        "Aggregation completed",
+        total_groups=aggregation_response.total_groups,
+        visualization_type=visualization_type.type,
+        query_id=str(query_id),
     )

-    return StateDeltaEvent(type=EventType.STATE_DELTA, delta=changes)
+    return aggregation_response


 @search_toolset.tool
@@ -248,10 +257,11 @@ async def discover_filter_paths(
     Returns a dictionary where each key is a field_name from the input list and
     the value is its discovery result.
     """
-    if not entity_type and ctx.deps.state.parameters:
-        entity_type = EntityType(ctx.deps.state.parameters.get("entity_type"))
     if not entity_type:
-        entity_type = EntityType.SUBSCRIPTION
+        if ctx.deps.state.query:
+            entity_type = ctx.deps.state.query.entity_type
+        else:
+            raise ModelRetry("Entity type not specified and no query in state. Call start_new_search first.")

     all_results = {}
     for field_name in field_names:
@@ -332,17 +342,21 @@ async def fetch_entity_details(
         JSON string containing detailed entity information.

     Raises:
-        ValueError: If no search results are available.
+        ModelRetry: If no search has been executed.
     """
-    if not ctx.deps.state.results_data or not ctx.deps.state.results_data.results:
-        raise ValueError("No search results available. Run a search first before fetching entity details.")
+    if ctx.deps.state.query_id is None:
+        raise ModelRetry("No query_id found. Run a search first.")

-    if not ctx.deps.state.parameters:
-        raise ValueError("No search parameters found.")
+    # Load the saved query and re-execute it to get entity IDs
+    query_state = QueryState.load_from_id(ctx.deps.state.query_id, SelectQuery)
+    query = query_state.query.model_copy(update={"limit": limit})
+    search_response = await engine.execute_search(query, db.session)
+    entity_ids = [r.entity_id for r in search_response.results]

-    entity_type = EntityType(ctx.deps.state.parameters["entity_type"])
+    if not entity_ids:
+        return json.dumps({"message": "No entities found in search results."})

-    entity_ids = [r.entity_id for r in ctx.deps.state.results_data.results[:limit]]
+    entity_type = query.entity_type

     logger.debug(
         "Fetching detailed entity data",
@@ -356,23 +370,16 @@ async def fetch_entity_details(


 @search_toolset.tool
+@require_action(ActionType.SELECT)
 async def prepare_export(
     ctx: RunContext[StateDeps[SearchState]],
-) -> StateSnapshotEvent:
-    """Prepares export URL using the last executed search query."""
-    if not ctx.deps.state.query_id or not ctx.deps.state.run_id:
-        raise ValueError("No search has been executed yet. Run a search first before exporting.")
+) -> ExportData:
+    """Prepares export URL using the last executed search query.

-    if not ctx.deps.state.parameters:
-        raise ValueError("No search parameters found. Run a search first before exporting.")
-
-    # Validate that export is only available for SELECT actions
-    action = ctx.deps.state.parameters.get("action", ActionType.SELECT)
-    if action != ActionType.SELECT:
-        raise ValueError(
-            f"Export is only available for SELECT actions. Current action is '{action}'. "
-            "Please run a SELECT search first."
-        )
+    Returns export data which is displayed directly in the UI.
+    """
+    if not ctx.deps.state.query_id or not ctx.deps.state.run_id:
+        raise ModelRetry("No search has been executed yet. Run a search first before exporting.")

     logger.debug(
         "Prepared query for export",
@@ -381,16 +388,88 @@ async def prepare_export(
     )

     download_url = f"{app_settings.BASE_URL}/api/search/queries/{ctx.deps.state.query_id}/export"
-    ctx.deps.state.export_data = ExportData(
+    export_data = ExportData(
         query_id=str(ctx.deps.state.query_id),
         download_url=download_url,
         message="Export ready for download.",
     )

-    logger.debug("Export data set in state", export_data=ctx.deps.state.export_data.model_dump())
+    logger.debug("Export prepared", query_id=export_data.query_id)
+
+    return export_data
+
+
+@search_toolset.tool(retries=2)
+@require_action(ActionType.COUNT, ActionType.AGGREGATE)
+async def set_grouping(
+    ctx: RunContext[StateDeps[SearchState]],
+    group_by_paths: list[str],
+) -> StateSnapshotEvent:
+    """Set which field paths to group results by for aggregation.
+
+    Only used with COUNT or AGGREGATE actions. Paths must exist in the schema; use discover_filter_paths to verify.
+    """
+    for path in group_by_paths:
+        field_type = validate_filter_path(path)
+        if field_type is None:
+            raise ModelRetry(
+                f"Path '{path}' not found in database schema. "
+                f"Use discover_filter_paths(['{path.split('.')[-1]}']) to find valid paths."
+            )
+
+    ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"group_by": group_by_paths})
+
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=ctx.deps.state.model_dump(),
+    )
+
+
+@search_toolset.tool(retries=2)
+@require_action(ActionType.AGGREGATE)
+async def set_aggregations(
+    ctx: RunContext[StateDeps[SearchState]],
+    aggregations: list[Aggregation],
+) -> StateSnapshotEvent:
+    """Define what aggregations to compute over the matching records.
+
+    Only used with AGGREGATE action. See Aggregation model (CountAggregation, FieldAggregation) for structure and field requirements.
+    """
+    # Validate field paths for FieldAggregations
+    try:
+        for agg in aggregations:
+            if isinstance(agg, FieldAggregation):
+                validate_aggregation_field(agg.type, agg.field)
+    except ValueError as e:
+        raise ModelRetry(f"{str(e)} Use discover_filter_paths to find valid paths.")
+
+    ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"aggregations": aggregations})
+
+    return StateSnapshotEvent(
+        type=EventType.STATE_SNAPSHOT,
+        snapshot=ctx.deps.state.model_dump(),
+    )
+
+
+@search_toolset.tool(retries=2)
+@require_action(ActionType.COUNT, ActionType.AGGREGATE)
+async def set_temporal_grouping(
+    ctx: RunContext[StateDeps[SearchState]],
+    temporal_groups: list[TemporalGrouping],
+) -> StateSnapshotEvent:
+    """Set temporal grouping to group datetime fields by time periods.
+
+    Only used with COUNT or AGGREGATE actions. See TemporalGrouping model for structure, periods, and examples.
+    """
+    # Validate that fields exist and are datetime types
+    try:
+        for tg in temporal_groups:
+            validate_temporal_grouping_field(tg.field)
+    except ValueError as e:
+        raise ModelRetry(f"{str(e)} Use discover_filter_paths to find datetime fields.")
+
+    ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"temporal_group_by": temporal_groups})

-    # Should use StateDelta here? Use snapshot to workaround state persistence issue
-    # TODO: Fix root cause; state is empty on frontend when it should have data from run_search
     return StateSnapshotEvent(
         type=EventType.STATE_SNAPSHOT,
         snapshot=ctx.deps.state.model_dump(),
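
Taken together, the tools.py hunks above replace the dict-based parameters state with typed query objects from the new orchestrator.search.query package. Below is a minimal sketch of that flow outside the agent toolset, assuming an initialized database session; the constructors, model_copy updates, and engine.execute_search call are taken from the hunks above, while the helper function name is hypothetical.

from orchestrator.db import db
from orchestrator.search.core.types import EntityType
from orchestrator.search.query import engine
from orchestrator.search.query.queries import SelectQuery


async def find_subscription_ids(text: str, limit: int = 10) -> list[str]:
    # Build a SELECT-style query, then apply the limit via an immutable copy,
    # mirroring how run_search updates the stored query object.
    query = SelectQuery(entity_type=EntityType.SUBSCRIPTION, query_text=text)
    query = query.model_copy(update={"limit": limit})
    response = await engine.execute_search(query, db.session)
    return [result.entity_id for result in response.results]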
orchestrator/search/agent/validation.py (new file)

@@ -0,0 +1,80 @@
+# Copyright 2019-2025 SURF, GÉANT.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable
+
+import structlog
+from pydantic_ai import RunContext
+from pydantic_ai.ag_ui import StateDeps
+from pydantic_ai.exceptions import ModelRetry
+
+from orchestrator.search.core.types import ActionType
+
+if TYPE_CHECKING:
+    from orchestrator.search.agent.state import SearchState
+
+logger = structlog.get_logger(__name__)
+
+
+def require_action(*allowed_actions: ActionType) -> Callable:
+    """Validate that the current search action is one of the allowed types.
+
+    This decorator is in preparation for a future finite state machine implementation
+    where we explicitly define which actions are valid in which states.
+
+    Example:
+        @search_toolset.tool
+        @require_action(ActionType.SELECT)
+        async def run_search(ctx, ...):
+            # Only callable when action is SELECT
+            ...
+
+    Args:
+        allowed_actions: One or more ActionType values that are valid for this tool.
+
+    Returns:
+        Decorated function that validates action before execution.
+
+    Raises:
+        ModelRetry: If current action is not in allowed_actions.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        async def wrapper(ctx: "RunContext[StateDeps[SearchState]]", *args: Any, **kwargs: Any) -> Any:
+            if ctx.deps.state.action is None or ctx.deps.state.query is None:
+                logger.warning(f"Action validation failed for {func.__name__}: action or query is None")
+                raise ModelRetry("Search action and query are not initialized. Call start_new_search first.")
+
+            current_action = ctx.deps.state.action
+
+            if current_action not in allowed_actions:
+                allowed_names = ", ".join(a.value for a in allowed_actions)
+                logger.warning(
+                    "Invalid action for tool",
+                    tool=func.__name__,
+                    allowed_actions=allowed_names,
+                    current_action=current_action.value,
+                )
+                raise ModelRetry(
+                    f"{func.__name__} is only available for {allowed_names} action(s). "
+                    f"Current action is '{current_action.value}'."
+                )
+
+            return await func(ctx, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
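
For illustration, the decorator composes with a tool exactly as the docstring example above suggests. The sketch below omits the @search_toolset.tool registration; the tool name is hypothetical, and the imports follow the tools.py hunks.

from pydantic_ai import RunContext
from pydantic_ai.ag_ui import StateDeps

from orchestrator.search.agent.state import SearchState
from orchestrator.search.agent.validation import require_action
from orchestrator.search.core.types import ActionType


@require_action(ActionType.COUNT, ActionType.AGGREGATE)
async def summarize_results(ctx: RunContext[StateDeps[SearchState]]) -> int:
    # The wrapper raises ModelRetry before this body runs if start_new_search has not
    # initialized state.action and state.query, or if the action is not COUNT/AGGREGATE.
    return ctx.deps.state.results_count or 0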
orchestrator/search/{schemas → aggregations}/__init__.py

@@ -10,3 +10,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .base import (
+    Aggregation,
+    AggregationType,
+    BaseAggregation,
+    CountAggregation,
+    FieldAggregation,
+    TemporalGrouping,
+    TemporalPeriod,
+)
+
+__all__ = [
+    "Aggregation",
+    "AggregationType",
+    "BaseAggregation",
+    "CountAggregation",
+    "FieldAggregation",
+    "TemporalGrouping",
+    "TemporalPeriod",
+]
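
A hedged sketch of how these re-exports might be used when preparing an AGGREGATE query: only the type/field attributes of FieldAggregation and the field attribute of TemporalGrouping are visible in the tools.py hunks, so the enum members, the period keyword, and the field paths below are assumptions for illustration.

from orchestrator.search.aggregations import (
    AggregationType,
    FieldAggregation,
    TemporalGrouping,
    TemporalPeriod,
)

# Hypothetical enum members and paths; validate real paths with discover_filter_paths.
avg_speed = FieldAggregation(type=AggregationType.AVG, field="subscription.port_speed")
by_month = TemporalGrouping(field="subscription.start_date", period=TemporalPeriod.MONTH)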