orchestrator-core 4.6.4-py3-none-any.whl → 4.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/api.py +4 -0
  3. orchestrator/api/api_v1/endpoints/processes.py +25 -9
  4. orchestrator/api/api_v1/endpoints/schedules.py +44 -0
  5. orchestrator/app.py +34 -1
  6. orchestrator/cli/scheduler.py +126 -11
  7. orchestrator/cli/search/resize_embedding.py +3 -0
  8. orchestrator/db/models.py +26 -0
  9. orchestrator/graphql/schemas/process.py +2 -2
  10. orchestrator/graphql/schemas/workflow.py +1 -1
  11. orchestrator/llm_settings.py +0 -1
  12. orchestrator/migrations/versions/schema/2020-10-19_a76b9185b334_add_generic_workflows_to_core.py +1 -0
  13. orchestrator/migrations/versions/schema/2021-04-06_3c8b9185c221_add_validate_products_task.py +1 -0
  14. orchestrator/migrations/versions/schema/2025-11-18_961eddbd4c13_create_linker_table_workflow_apscheduler.py +106 -0
  15. orchestrator/migrations/versions/schema/2025-12-10_9736496e3eba_set_is_task_true_on_certain_tasks.py +40 -0
  16. orchestrator/schedules/__init__.py +8 -7
  17. orchestrator/schedules/scheduler.py +27 -1
  18. orchestrator/schedules/scheduling.py +5 -1
  19. orchestrator/schedules/service.py +253 -0
  20. orchestrator/schemas/schedules.py +71 -0
  21. orchestrator/search/agent/prompts.py +10 -6
  22. orchestrator/search/agent/tools.py +55 -15
  23. orchestrator/search/aggregations/base.py +6 -2
  24. orchestrator/search/query/builder.py +76 -1
  25. orchestrator/search/query/mixins.py +57 -2
  26. orchestrator/search/query/queries.py +15 -1
  27. orchestrator/search/query/validation.py +43 -0
  28. orchestrator/services/processes.py +0 -7
  29. orchestrator/services/workflows.py +4 -0
  30. orchestrator/settings.py +48 -0
  31. orchestrator/utils/auth.py +2 -2
  32. orchestrator/websocket/__init__.py +14 -0
  33. orchestrator/workflow.py +1 -1
  34. orchestrator/workflows/__init__.py +1 -0
  35. orchestrator/workflows/modify_note.py +10 -1
  36. orchestrator/workflows/removed_workflow.py +8 -1
  37. orchestrator/workflows/tasks/cleanup_tasks_log.py +9 -2
  38. orchestrator/workflows/tasks/resume_workflows.py +4 -0
  39. orchestrator/workflows/tasks/validate_product_type.py +7 -1
  40. orchestrator/workflows/tasks/validate_products.py +9 -1
  41. orchestrator/{schedules → workflows/tasks}/validate_subscriptions.py +16 -3
  42. orchestrator/workflows/translations/en-GB.json +2 -1
  43. {orchestrator_core-4.6.4.dist-info → orchestrator_core-4.7.0.dist-info}/METADATA +13 -13
  44. {orchestrator_core-4.6.4.dist-info → orchestrator_core-4.7.0.dist-info}/RECORD +46 -43
  45. orchestrator/schedules/resume_workflows.py +0 -21
  46. orchestrator/schedules/task_vacuum.py +0 -21
  47. {orchestrator_core-4.6.4.dist-info → orchestrator_core-4.7.0.dist-info}/WHEEL +0 -0
  48. {orchestrator_core-4.6.4.dist-info → orchestrator_core-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -11,7 +11,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
  from contextlib import contextmanager
  from datetime import datetime
  from typing import Any, Generator
@@ -27,6 +26,7 @@ from orchestrator.db.filters import Filter
  from orchestrator.db.filters.filters import CallableErrorHandler
  from orchestrator.db.sorting import Sort
  from orchestrator.db.sorting.sorting import SortOrder
+ from orchestrator.schedules.service import get_linker_entries_by_schedule_ids
  from orchestrator.utils.helpers import camel_to_snake, to_camel

  executors = {
@@ -75,6 +75,7 @@ def get_scheduler(paused: bool = False) -> Generator[BackgroundScheduler, Any, N

  class ScheduledTask(BaseModel):
      id: str
+     workflow_id: str | None = None
      name: str | None = None
      next_run_time: datetime | None = None
      trigger: str
@@ -161,6 +162,29 @@ def default_error_handler(message: str, **context) -> None: # type: ignore
      raise ValueError(f"{message} {_format_context(context)}")


+ def enrich_with_workflow_id(scheduled_tasks: list[ScheduledTask]) -> list[ScheduledTask]:
+     """Does a get call to the linker table to get the workflow_id for each scheduled task.
+
+     Returns all the scheduled tasks with the workflow_id added.
+     """
+     schedule_ids = [task.id for task in scheduled_tasks]
+
+     entries = {
+         str(entry.schedule_id): str(entry.workflow_id) for entry in get_linker_entries_by_schedule_ids(schedule_ids)
+     }
+
+     return [
+         ScheduledTask(
+             id=task.id,
+             workflow_id=entries.get(task.id, None),
+             name=task.name,
+             next_run_time=task.next_run_time,
+             trigger=str(task.trigger),
+         )
+         for task in scheduled_tasks
+     ]
+
+
  def get_scheduler_tasks(
      first: int = 10,
      after: int = 0,
@@ -171,6 +195,7 @@ def get_scheduler_tasks(
      scheduled_tasks = get_all_scheduler_tasks()
      scheduled_tasks = filter_scheduled_tasks(scheduled_tasks, error_handler, filter_by)
      scheduled_tasks = sort_scheduled_tasks(scheduled_tasks, error_handler, sort_by)
+     scheduled_tasks = enrich_with_workflow_id(scheduled_tasks)

      total = len(scheduled_tasks)
      paginated_tasks = scheduled_tasks[after : after + first + 1]
@@ -178,6 +203,7 @@ def get_scheduler_tasks(
      return [
          ScheduledTask(
              id=task.id,
+             workflow_id=task.workflow_id,
              name=task.name,
              next_run_time=task.next_run_time,
              trigger=str(task.trigger),
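The `enrich_with_workflow_id` step added above is a plain dictionary lookup: linker-table rows keyed by `schedule_id` supply each job's `workflow_id`, and jobs without a linker row keep `workflow_id=None`. A minimal, self-contained sketch of that mapping (no database; `FakeLinkerRow` and the IDs are illustrative stand-ins, not names from the package):

```python
from datetime import datetime

from pydantic import BaseModel


class ScheduledTask(BaseModel):  # mirrors the model shown in the diff above
    id: str
    workflow_id: str | None = None
    name: str | None = None
    next_run_time: datetime | None = None
    trigger: str


class FakeLinkerRow(BaseModel):  # stand-in for a workflows_apscheduler_jobs row
    schedule_id: str
    workflow_id: str


def enrich(tasks: list[ScheduledTask], rows: list[FakeLinkerRow]) -> list[ScheduledTask]:
    # Same idea as enrich_with_workflow_id: build one lookup dict, then rebuild each task.
    mapping = {row.schedule_id: row.workflow_id for row in rows}
    return [task.model_copy(update={"workflow_id": mapping.get(task.id)}) for task in tasks]


tasks = [ScheduledTask(id="job-1", trigger="interval"), ScheduledTask(id="job-2", trigger="cron")]
rows = [FakeLinkerRow(schedule_id="job-1", workflow_id="5f8b7c2e-0000-0000-0000-000000000001")]
print([t.workflow_id for t in enrich(tasks, rows)])  # first task gets the linker row's workflow_id, second stays None
```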
@@ -23,7 +23,11 @@ F = TypeVar("F", bound=Callable[..., object])


  @deprecated(
-     reason="We changed from scheduler to apscheduler which has its own decoractor, use `@scheduler.scheduled_job()` from `from orchestrator.scheduling.scheduler import scheduler`"
+     reason=(
+         "Scheduling tasks with a decorator is deprecated in favor of using the API. "
+         "This decorator will be removed in 5.0.0. "
+         "For more details, please consult https://workfloworchestrator.org/orchestrator-core/guides/upgrading/4.7/"
+     )
  )
  def scheduler(
      name: str,
@@ -0,0 +1,253 @@
+ # Copyright 2019-2025 SURF.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import json
+ import logging
+ from uuid import UUID, uuid4
+
+ from apscheduler.schedulers.base import BaseScheduler
+ from apscheduler.triggers.cron import CronTrigger
+ from apscheduler.triggers.date import DateTrigger
+ from apscheduler.triggers.interval import IntervalTrigger
+ from sqlalchemy import delete
+
+ from orchestrator import app_settings
+ from orchestrator.db import db
+ from orchestrator.db.models import WorkflowApschedulerJob
+ from orchestrator.schemas.schedules import (
+     APSchedulerJobCreate,
+     APSchedulerJobDelete,
+     APSchedulerJobs,
+     APSchedulerJobUpdate,
+     APSJobAdapter,
+ )
+ from orchestrator.services.processes import start_process
+ from orchestrator.services.workflows import get_workflow_by_workflow_id
+ from orchestrator.utils.redis_client import create_redis_client
+
+ redis_connection = create_redis_client(app_settings.CACHE_URI)
+
+ SCHEDULER_QUEUE = "scheduler:queue:"
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def serialize_payload(payload: APSchedulerJobs) -> bytes:
+     """Serialize the payload to bytes for Redis storage.
+
+     Args:
+         payload: APSchedulerJobs The scheduled task payload.
+     """
+     data = json.loads(payload.model_dump_json())
+     data["scheduled_type"] = payload.scheduled_type
+     return json.dumps(data).encode()
+
+
+ def deserialize_payload(bytes_dump: bytes) -> APSchedulerJobs:
+     """Deserialize the payload from bytes for Redis retrieval.
+
+     Args:
+         bytes_dump: bytes The serialized payload.
+     """
+     json_dump = bytes_dump.decode()
+     return APSJobAdapter.validate_json(json_dump)
+
+
+ def add_scheduled_task_to_queue(payload: APSchedulerJobs) -> None:
+     """Create a scheduled task service function.
+
+     We need to create a apscheduler job, and put the workflow and schedule_id in
+     the linker table workflows_apscheduler_jobs.
+
+     Args:
+         payload: APSchedulerJobCreate The scheduled task to create.
+     """
+     bytes_dump = serialize_payload(payload)
+     redis_connection.lpush(SCHEDULER_QUEUE, bytes_dump)
+     logger.info("Added scheduled task to queue.")
+
+
+ def get_linker_entries_by_schedule_ids(schedule_ids: list[str]) -> list[WorkflowApschedulerJob]:
+     """Get linker table entries for multiple schedule IDs in a single query.
+
+     Args:
+         schedule_ids: list[str] — One or many schedule IDs.
+
+     Returns:
+         list[WorkflowApschedulerJob]: All linker table rows matching those IDs.
+     """
+     if not schedule_ids:
+         return []
+
+     return db.session.query(WorkflowApschedulerJob).filter(WorkflowApschedulerJob.schedule_id.in_(schedule_ids)).all()
+
+
+ def _add_linker_entry(workflow_id: UUID, schedule_id: str) -> None:
+     """Add an entry to the linker table workflows_apscheduler_jobs.
+
+     Args:
+         workflow_id: UUID The workflow ID.
+         schedule_id: str The schedule ID.
+     """
+     workflows_apscheduler_job = WorkflowApschedulerJob(workflow_id=workflow_id, schedule_id=schedule_id)
+     db.session.add(workflows_apscheduler_job)
+     db.session.commit()
+
+
+ def _delete_linker_entry(workflow_id: UUID, schedule_id: str) -> None:
+     """Delete an entry from the linker table workflows_apscheduler_jobs.
+
+     Args:
+         workflow_id: UUID The workflow ID.
+         schedule_id: str The schedule ID.
+     """
+     db.session.execute(
+         delete(WorkflowApschedulerJob).where(
+             WorkflowApschedulerJob.workflow_id == workflow_id, WorkflowApschedulerJob.schedule_id == schedule_id
+         )
+     )
+     db.session.commit()
+
+
+ def run_start_workflow_scheduler_task(workflow_name: str) -> None:
+     """Function to start a workflow from the scheduler.
+
+     Args:
+         workflow_name: str The name of the workflow to start.
+     """
+     logger.info(f"Starting workflow: {workflow_name}")
+     start_process(workflow_name)
+
+
+ def _add_scheduled_task(payload: APSchedulerJobCreate, scheduler_connection: BaseScheduler) -> None:
+     """Create a new scheduled task in the scheduler and also in the linker table.
+
+     Args:
+         payload: APSchedulerJobCreate The scheduled task to create.
+         scheduler_connection: BaseScheduler The scheduler connection.
+     """
+     logger.info(f"Adding scheduled task: {payload}")
+
+     workflow_description = None
+     # Check if a workflow exists - we cannot schedule a non-existing workflow
+     workflow = get_workflow_by_workflow_id(str(payload.workflow_id))
+     if not workflow:
+         raise ValueError(f"Workflow with id {payload.workflow_id} does not exist.")
+     workflow_description = workflow.description
+
+     # This function is always the same for scheduled tasks, it will run the workflow
+     func = run_start_workflow_scheduler_task
+
+     # Ensure payload has required data
+     if not payload.trigger or not payload.workflow_name or not payload.trigger_kwargs or not payload.workflow_id:
+         raise ValueError("Trigger must be specified for scheduled tasks.")
+
+     schedule_id = str(uuid4())
+     scheduler_connection.add_job(
+         func=func,
+         trigger=payload.trigger,
+         id=schedule_id,
+         name=payload.name or workflow_description,
+         kwargs={"workflow_name": payload.workflow_name},
+         **(payload.trigger_kwargs or {}),
+     )
+
+     _add_linker_entry(workflow_id=payload.workflow_id, schedule_id=schedule_id)
+
+
+ def _build_trigger_on_update(
+     trigger_name: str | None, trigger_kwargs: dict
+ ) -> IntervalTrigger | CronTrigger | DateTrigger | None:
+     if not trigger_name or not trigger_kwargs:
+         logger.info("Skipping building trigger as no trigger information is provided.")
+         return None
+
+     match trigger_name:
+         case "interval":
+             return IntervalTrigger(**trigger_kwargs)
+         case "cron":
+             return CronTrigger(**trigger_kwargs)
+         case "date":
+             return DateTrigger(**trigger_kwargs)
+         case _:
+             raise ValueError(f"Invalid trigger type: {trigger_name}")
+
+
+ def _update_scheduled_task(payload: APSchedulerJobUpdate, scheduler_connection: BaseScheduler) -> None:
+     """Update an existing scheduled task in the scheduler.
+
+     Only allow update of name and trigger
+     Job id must be that of an existing job
+     Do not insert in linker table - it should already exist.
+
+     Args:
+         payload: APSchedulerJobUpdate The scheduled task to update.
+         scheduler_connection: BaseScheduler The scheduler connection.
+     """
+     logger.info(f"Updating scheduled task: {payload}")
+
+     schedule_id = str(payload.schedule_id)
+     job = scheduler_connection.get_job(job_id=schedule_id)
+     if not job:
+         raise ValueError(f"Schedule Job with id {schedule_id} does not exist.")
+
+     trigger = _build_trigger_on_update(payload.trigger, payload.trigger_kwargs or {})
+     modify_kwargs = {}
+
+     if trigger:
+         job = job.reschedule(trigger=trigger)
+
+     if payload.name:
+         modify_kwargs["name"] = payload.name
+
+     job.modify(**modify_kwargs)
+
+
+ def _delete_scheduled_task(payload: APSchedulerJobDelete, scheduler_connection: BaseScheduler) -> None:
+     """Delete an existing scheduled task in the scheduler and also in the linker table.
+
+     Args:
+         payload: APSchedulerJobDelete The scheduled task to delete.
+         scheduler_connection: BaseScheduler The scheduler connection.
+     """
+     logger.info(f"Deleting scheduled task: {payload}")
+
+     schedule_id = str(payload.schedule_id)
+     scheduler_connection.remove_job(job_id=schedule_id)
+     _delete_linker_entry(workflow_id=payload.workflow_id, schedule_id=schedule_id)
+
+
+ def workflow_scheduler_queue(queue_item: tuple[str, bytes], scheduler_connection: BaseScheduler) -> None:
+     """Process an item from the scheduler queue.
+
+     Args:
+         queue_item: tuple[str, bytes] The item from the scheduler queue.
+         scheduler_connection: BaseScheduler The scheduler connection.
+     """
+     try:
+         _, bytes_dump = queue_item
+         payload = deserialize_payload(bytes_dump)
+         match payload:
+             case APSchedulerJobCreate():
+                 _add_scheduled_task(payload, scheduler_connection)
+
+             case APSchedulerJobUpdate():
+                 _update_scheduled_task(payload, scheduler_connection)
+
+             case APSchedulerJobDelete():
+                 _delete_scheduled_task(payload, scheduler_connection)
+
+             case _:
+                 logger.warning(f"Unexpected schedule type: {payload}")  # type: ignore
+     except Exception:
+         logger.exception("Error processing scheduler queue item")
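End to end, the producer side of this new service module is a single `lpush` of a serialized payload onto `scheduler:queue:`, and the consumer side (`workflow_scheduler_queue`) dispatches on the payload type. A hedged usage sketch of the create path, assuming an environment where orchestrator-core 4.7.0 is installed and its settings (CACHE_URI, database) are configured; the workflow name and UUID are placeholders:

```python
from uuid import UUID

from orchestrator.schedules.service import add_scheduled_task_to_queue
from orchestrator.schemas.schedules import APSchedulerJobCreate

payload = APSchedulerJobCreate(
    name="Nightly subscription validation",       # optional display name
    workflow_name="task_validate_subscriptions",  # placeholder workflow name
    workflow_id=UUID("00000000-0000-0000-0000-000000000001"),  # placeholder workflow UUID
    trigger="interval",
    trigger_kwargs={"hours": 24},  # must be non-empty: _add_scheduled_task rejects empty trigger_kwargs
)

# Producer: serializes the payload and pushes it onto the "scheduler:queue:" Redis list.
add_scheduled_task_to_queue(payload)

# Consumer (the scheduler process): pops queue items and calls workflow_scheduler_queue(),
# which matches on the payload type; for a create it adds the APScheduler job and a
# workflows_apscheduler_jobs linker row tying the new schedule_id to the workflow_id.
```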
@@ -0,0 +1,71 @@
+ # Copyright 2019-2025 SURF.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Annotated, Any, Literal, Union
+ from uuid import UUID
+
+ from pydantic import BaseModel, Field, TypeAdapter
+
+ SCHEDULER_Q_CREATE = "create"
+ SCHEDULER_Q_UPDATE = "update"
+ SCHEDULER_Q_DELETE = "delete"
+
+
+ class APSchedulerJob(BaseModel):
+     scheduled_type: Literal["create", "update", "delete"] = Field(..., description="Discriminator for job type")
+
+
+ class APSchedulerJobCreate(APSchedulerJob):
+     name: str | None = Field(None, description="Human readable name e.g. 'My Process'")
+     workflow_name: str = Field(..., description="Name of the workflow to run e.g. 'my_workflow_name'")
+     workflow_id: UUID = Field(..., description="UUID of the workflow associated with this scheduled task")
+
+     trigger: Literal["interval", "cron", "date"] = Field(..., description="APScheduler trigger type")
+     trigger_kwargs: dict[str, Any] = Field(
+         default_factory=lambda: {},
+         description="Arguments passed to the trigger on job creation",
+         examples=[{"hours": 12}, {"minutes": 30}, {"days": 1, "hours": 2}],
+     )
+
+     scheduled_type: Literal["create"] = Field("create", frozen=True)
+
+
+ class APSchedulerJobUpdate(APSchedulerJob):
+     name: str | None = Field(None, description="Human readable name e.g. 'My Process'")
+     schedule_id: UUID = Field(..., description="UUID of the scheduled task")
+
+     trigger: Literal["interval", "cron", "date"] | None = Field(None, description="APScheduler trigger type")
+     trigger_kwargs: dict[str, Any] | None = Field(
+         default=None,
+         description="Arguments passed to the job function",
+         examples=[{"hours": 12}, {"minutes": 30}, {"days": 1, "hours": 2}],
+     )
+
+     scheduled_type: Literal["update"] = Field("update", frozen=True)
+
+
+ class APSchedulerJobDelete(APSchedulerJob):
+     workflow_id: UUID = Field(..., description="UUID of the workflow associated with this scheduled task")
+     schedule_id: UUID | None = Field(None, description="UUID of the scheduled task")
+
+     scheduled_type: Literal["delete"] = Field("delete", frozen=True)
+
+
+ APSchedulerJobs = Annotated[
+     Union[
+         APSchedulerJobCreate,
+         APSchedulerJobUpdate,
+         APSchedulerJobDelete,
+     ],
+     Field(discriminator="scheduled_type"),
+ ]
+ APSJobAdapter = TypeAdapter(APSchedulerJobs)  # type: ignore
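Because `scheduled_type` is a frozen `Literal` on each subclass and the discriminator on the union, a raw JSON payload taken off the queue parses directly into the right job class via the `TypeAdapter`. A small round-trip sketch (assumes orchestrator-core 4.7.0 is importable; the UUIDs are placeholders):

```python
from orchestrator.schemas.schedules import APSchedulerJobDelete, APSJobAdapter

raw = (
    '{"scheduled_type": "delete", '
    '"workflow_id": "00000000-0000-0000-0000-000000000001", '
    '"schedule_id": "00000000-0000-0000-0000-000000000002"}'
)

job = APSJobAdapter.validate_json(raw)  # the discriminator selects the subclass
assert isinstance(job, APSchedulerJobDelete)
print(job.schedule_id)  # 00000000-0000-0000-0000-000000000002
```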
@@ -26,7 +26,7 @@ logger = structlog.get_logger(__name__)

  async def get_base_instructions() -> str:
      return dedent(
-         """
+         f"""
          You are an expert assistant designed to find relevant information by building and running database queries.

          ---
@@ -50,17 +50,21 @@ async def get_base_instructions() -> str:

          Follow these steps:

-         1. **Set Context**: Call `start_new_search` with appropriate entity_type and action
+         1. **Set Context**: Call `start_new_search` with appropriate entity_type and action:
+         - `action={ActionType.SELECT.value}` for finding/searching entities
+         - `action={ActionType.COUNT.value}` for counting (e.g., "how many", "count by status", "monthly growth")
+         - `action={ActionType.AGGREGATE.value}` for numeric operations (SUM, AVG, MIN, MAX of specific fields)
          2. **Set Filters** (if needed): Discover paths, build FilterTree, call `set_filter_tree`
          - IMPORTANT: Temporal constraints like "in 2025", "in January", "between X and Y" require filters on datetime fields
          - Filters restrict WHICH records to include; grouping controls HOW to aggregate them
-         3. **Set Grouping/Aggregations** (for COUNT/AGGREGATE):
+         3. **Set Grouping/Aggregations** (for {ActionType.COUNT.value}/{ActionType.AGGREGATE.value}):
          - For temporal grouping (per month, per year, per day, etc.): Use `set_temporal_grouping`
          - For regular grouping (by status, by name, etc.): Use `set_grouping`
-         - For aggregations: Use `set_aggregations`
+         - For {ActionType.AGGREGATE.value} action ONLY: Use `set_aggregations` to specify what to compute (SUM, AVG, etc.)
+         - For {ActionType.COUNT.value} action: Do NOT call `set_aggregations` (counting is automatic)
          4. **Execute**:
-         - For SELECT action: Call `run_search()`
-         - For COUNT/AGGREGATE actions: Call `run_aggregation()`
+         - For {ActionType.SELECT.value} action: Call `run_search()`
+         - For {ActionType.COUNT.value}/{ActionType.AGGREGATE.value} actions: Call `run_aggregation()`

          After search execution, follow the dynamic instructions based on the current state.

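The instruction template above is now an f-string, so the `ActionType` enum values are interpolated into the prompt at call time and cannot drift from the values the tools actually accept. A stand-alone illustration with a stand-in enum (the real `ActionType` lives elsewhere in the search package; only the member names appear in this diff, so the string values here are assumptions):

```python
from enum import Enum
from textwrap import dedent


class ActionType(str, Enum):  # stand-in; values are assumed, not taken from the diff
    SELECT = "select"
    COUNT = "count"
    AGGREGATE = "aggregate"


instructions = dedent(
    f"""
    1. **Set Context**: Call `start_new_search` with appropriate entity_type and action:
       - `action={ActionType.SELECT.value}` for finding/searching entities
       - `action={ActionType.COUNT.value}` for counting
       - `action={ActionType.AGGREGATE.value}` for numeric operations
    """
)
print(instructions)  # the enum values are baked into the prompt text when this runs
```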
@@ -16,6 +16,7 @@ from typing import Any, cast

  import structlog
  from ag_ui.core import EventType, StateSnapshotEvent
+ from pydantic import ValidationError
  from pydantic_ai import RunContext
  from pydantic_ai.ag_ui import StateDeps
  from pydantic_ai.exceptions import ModelRetry
@@ -39,13 +40,15 @@ from orchestrator.search.filters import FilterTree
  from orchestrator.search.query import engine
  from orchestrator.search.query.exceptions import PathNotFoundError, QueryValidationError
  from orchestrator.search.query.export import fetch_export_data
+ from orchestrator.search.query.mixins import OrderBy
  from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query, SelectQuery
  from orchestrator.search.query.results import AggregationResponse, AggregationResult, ExportData, VisualizationType
  from orchestrator.search.query.state import QueryState
  from orchestrator.search.query.validation import (
      validate_aggregation_field,
-     validate_filter_path,
      validate_filter_tree,
+     validate_grouping_fields,
+     validate_order_by_fields,
      validate_temporal_grouping_field,
  )
  from orchestrator.settings import app_settings
@@ -404,20 +407,30 @@ async def prepare_export(
  async def set_grouping(
      ctx: RunContext[StateDeps[SearchState]],
      group_by_paths: list[str],
+     order_by: list[OrderBy] | None = None,
  ) -> StateSnapshotEvent:
      """Set which field paths to group results by for aggregation.

      Only used with COUNT or AGGREGATE actions. Paths must exist in the schema; use discover_filter_paths to verify.
+     Optionally specify ordering for the grouped results.
+
+     For order_by: You can order by grouping field paths OR aggregation aliases (e.g., 'count').
+     Grouping field paths will be validated; aggregation aliases cannot be validated until execution.
      """
-     for path in group_by_paths:
-         field_type = validate_filter_path(path)
-         if field_type is None:
-             raise ModelRetry(
-                 f"Path '{path}' not found in database schema. "
-                 f"Use discover_filter_paths(['{path.split('.')[-1]}']) to find valid paths."
-             )
+     try:
+         validate_grouping_fields(group_by_paths)
+         validate_order_by_fields(order_by)
+     except PathNotFoundError as e:
+         raise ModelRetry(f"{str(e)} Use discover_filter_paths to find valid paths.")

-     ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"group_by": group_by_paths})
+     update_dict: dict[str, Any] = {"group_by": group_by_paths}
+     if order_by is not None:
+         update_dict["order_by"] = order_by
+
+     try:
+         ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update=update_dict)
+     except ValidationError as e:
+         raise ModelRetry(str(e))

      return StateSnapshotEvent(
          type=EventType.STATE_SNAPSHOT,
@@ -434,16 +447,26 @@ async def set_aggregations(
      """Define what aggregations to compute over the matching records.

      Only used with AGGREGATE action. See Aggregation model (CountAggregation, FieldAggregation) for structure and field requirements.
+
      """
      # Validate field paths for FieldAggregations
      try:
          for agg in aggregations:
              if isinstance(agg, FieldAggregation):
                  validate_aggregation_field(agg.type, agg.field)
-     except ValueError as e:
-         raise ModelRetry(f"{str(e)} Use discover_filter_paths to find valid paths.")
+     except PathNotFoundError as e:
+         raise ModelRetry(
+             f"{str(e)} "
+             f"You MUST call discover_filter_paths first to find valid fields. "
+             f"If the field truly doesn't exist, inform the user that this data is not available."
+         )
+     except QueryValidationError as e:
+         raise ModelRetry(f"{str(e)}")

-     ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"aggregations": aggregations})
+     try:
+         ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"aggregations": aggregations})
+     except ValidationError as e:
+         raise ModelRetry(str(e))

      return StateSnapshotEvent(
          type=EventType.STATE_SNAPSHOT,
@@ -456,19 +479,36 @@ async def set_temporal_grouping(
  async def set_temporal_grouping(
      ctx: RunContext[StateDeps[SearchState]],
      temporal_groups: list[TemporalGrouping],
+     cumulative: bool = False,
+     order_by: list[OrderBy] | None = None,
  ) -> StateSnapshotEvent:
      """Set temporal grouping to group datetime fields by time periods.

      Only used with COUNT or AGGREGATE actions. See TemporalGrouping model for structure, periods, and examples.
+     Optionally enable cumulative aggregations (running totals) and specify ordering.
+
+     For order_by: You can order by temporal field paths OR aggregation aliases (e.g., 'count').
+     Temporal field paths will be validated; aggregation aliases cannot be validated until execution.
      """
-     # Validate that fields exist and are datetime types
      try:
          for tg in temporal_groups:
              validate_temporal_grouping_field(tg.field)
-     except ValueError as e:
+         validate_order_by_fields(order_by)
+     except PathNotFoundError as e:
+         raise ModelRetry(f"{str(e)} Use discover_filter_paths to find valid paths.")
+     except QueryValidationError as e:
          raise ModelRetry(f"{str(e)} Use discover_filter_paths to find datetime fields.")

-     ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update={"temporal_group_by": temporal_groups})
+     update_dict: dict[str, Any] = {"temporal_group_by": temporal_groups}
+     if cumulative:
+         update_dict["cumulative"] = cumulative
+     if order_by is not None:
+         update_dict["order_by"] = order_by
+
+     try:
+         ctx.deps.state.query = cast(Query, ctx.deps.state.query).model_copy(update=update_dict)
+     except ValidationError as e:
+         raise ModelRetry(str(e))

      return StateSnapshotEvent(
          type=EventType.STATE_SNAPSHOT,
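A pattern shared by all three tools above: wrap the query-state update in `try/except` and convert `ValidationError` (and the search-specific `PathNotFoundError`/`QueryValidationError`) into `ModelRetry`, so an invalid update is fed back to the model as retryable feedback instead of aborting the run. A generic sketch of that pattern outside the orchestrator codebase (`SimpleQuery` is an illustrative stand-in; here the re-validation is done with `model_validate`):

```python
from pydantic import BaseModel, Field, ValidationError
from pydantic_ai.exceptions import ModelRetry


class SimpleQuery(BaseModel):  # illustrative stand-in, not a class from the package
    group_by: list[str] = Field(default_factory=list)
    limit: int = Field(10, ge=1, le=100)


def set_limit(query: SimpleQuery, limit: int) -> SimpleQuery:
    """Re-validate the updated state; surface bad values to the model as a retry."""
    try:
        return SimpleQuery.model_validate({**query.model_dump(), "limit": limit})
    except ValidationError as e:
        # ModelRetry asks the agent framework to hand the error text back to the LLM.
        raise ModelRetry(str(e))


query = SimpleQuery(group_by=["status"])
try:
    set_limit(query, 0)  # violates ge=1 -> ValidationError -> ModelRetry
except ModelRetry as exc:
    print("retry requested:", type(exc).__name__)
```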
@@ -61,6 +61,11 @@ class TemporalGrouping(BaseModel):
          },
      )

+     @property
+     def alias(self) -> str:
+         """Return the SQL-friendly alias for this temporal grouping."""
+         return f"{BaseAggregation.field_to_alias(self.field)}_{self.period.value}"
+
      def get_pivot_fields(self) -> list[str]:
          """Return fields that need to be pivoted for this temporal grouping."""
          return [self.field]
@@ -83,8 +88,7 @@ class TemporalGrouping(BaseModel):
          col = getattr(pivot_cte_columns, field_alias)
          truncated_col = func.date_trunc(self.period.value, cast(col, TIMESTAMP(timezone=True)))

-         # Column name without prefix
-         col_name = f"{field_alias}_{self.period.value}"
+         col_name = self.alias
          select_col = truncated_col.label(col_name)
          return select_col, truncated_col, col_name
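Centralizing the column name in the new `alias` property means the label put on the truncated column and any later reference to it (for example an `order_by` on the grouped column) are derived from one place instead of being rebuilt by hand. A rough illustration of the naming idea with a stand-in `field_to_alias` (the real helper is `BaseAggregation.field_to_alias`; its exact scheme is not shown in this diff, so the dot-to-underscore rule here is an assumption):

```python
from enum import Enum


def field_to_alias(path: str) -> str:
    # Stand-in: assume dotted field paths become underscore-joined identifiers.
    return path.replace(".", "_")


class Period(str, Enum):
    MONTH = "month"
    YEAR = "year"


def temporal_alias(field: str, period: Period) -> str:
    # Mirrors TemporalGrouping.alias: "<field alias>_<period>".
    return f"{field_to_alias(field)}_{period.value}"


print(temporal_alias("subscription.start_date", Period.MONTH))
# subscription_start_date_month -- one name reused for the SELECT label and the GROUP BY reference
```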