orchestrator-core 4.6.2__py3-none-any.whl → 4.6.3rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/endpoints/search.py +44 -34
  3. orchestrator/{search/retrieval/utils.py → cli/search/display.py} +4 -29
  4. orchestrator/cli/search/search_explore.py +22 -24
  5. orchestrator/cli/search/speedtest.py +11 -9
  6. orchestrator/db/models.py +6 -6
  7. orchestrator/log_config.py +2 -0
  8. orchestrator/schemas/search.py +1 -1
  9. orchestrator/schemas/search_requests.py +59 -0
  10. orchestrator/search/agent/handlers.py +129 -0
  11. orchestrator/search/agent/prompts.py +54 -33
  12. orchestrator/search/agent/state.py +9 -24
  13. orchestrator/search/agent/tools.py +223 -144
  14. orchestrator/search/agent/validation.py +80 -0
  15. orchestrator/search/{schemas → aggregations}/__init__.py +20 -0
  16. orchestrator/search/aggregations/base.py +201 -0
  17. orchestrator/search/core/types.py +3 -2
  18. orchestrator/search/filters/__init__.py +4 -0
  19. orchestrator/search/filters/definitions.py +22 -1
  20. orchestrator/search/filters/numeric_filter.py +3 -3
  21. orchestrator/search/llm_migration.py +2 -1
  22. orchestrator/search/query/__init__.py +90 -0
  23. orchestrator/search/query/builder.py +285 -0
  24. orchestrator/search/query/engine.py +162 -0
  25. orchestrator/search/{retrieval → query}/exceptions.py +38 -7
  26. orchestrator/search/query/mixins.py +95 -0
  27. orchestrator/search/query/queries.py +129 -0
  28. orchestrator/search/query/results.py +252 -0
  29. orchestrator/search/{retrieval/query_state.py → query/state.py} +31 -11
  30. orchestrator/search/{retrieval → query}/validation.py +58 -1
  31. orchestrator/search/retrieval/__init__.py +0 -5
  32. orchestrator/search/retrieval/pagination.py +7 -8
  33. orchestrator/search/retrieval/retrievers/base.py +9 -9
  34. {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/METADATA +6 -6
  35. {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/RECORD +38 -32
  36. orchestrator/search/retrieval/builder.py +0 -127
  37. orchestrator/search/retrieval/engine.py +0 -197
  38. orchestrator/search/schemas/parameters.py +0 -133
  39. orchestrator/search/schemas/results.py +0 -80
  40. /orchestrator/search/{export.py → query/export.py} +0 -0
  41. {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/WHEEL +0 -0
  42. {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,201 @@
1
+ # Copyright 2019-2025 SURF, GÉANT.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ from abc import abstractmethod
15
+ from enum import Enum
16
+ from typing import Annotated, Any, Literal, TypeAlias
17
+
18
+ from pydantic import BaseModel, ConfigDict, Field
19
+ from sqlalchemy import Integer, cast, func
20
+ from sqlalchemy.sql.elements import ColumnElement, Label
21
+
22
+
23
class AggregationType(str, Enum):
    """Enumerates the aggregation operations a query may request."""

    COUNT = "count"  # number of matching entities
    SUM = "sum"  # total over a numeric field
    AVG = "avg"  # arithmetic mean of a numeric field
    MIN = "min"  # smallest value of a field
    MAX = "max"  # largest value of a field
31
+
32
+
33
class TemporalPeriod(str, Enum):
    """Granularities available for time-based grouping.

    Values are the period names accepted by PostgreSQL's date_trunc, which is
    how TemporalGrouping applies them.
    """

    YEAR = "year"
    QUARTER = "quarter"
    MONTH = "month"
    WEEK = "week"
    DAY = "day"
    HOUR = "hour"
42
+
43
+
44
class TemporalGrouping(BaseModel):
    """Defines temporal grouping for date/time fields.

    Used to group query results by time periods (e.g., group subscriptions by
    start_date per month).
    """

    field: str = Field(description="The datetime field path to group by temporally.")
    period: TemporalPeriod = Field(description="The time period to group by (year, quarter, month, week, day, hour).")

    model_config = ConfigDict(
        extra="forbid",
        json_schema_extra={
            "examples": [
                {"field": "subscription.start_date", "period": "month"},
                {"field": "subscription.end_date", "period": "year"},
                {"field": "process.created_at", "period": "day"},
            ]
        },
    )

    def get_pivot_fields(self) -> list[str]:
        """Return fields that need to be pivoted for this temporal grouping."""
        return [self.field]

    def to_expression(self, pivot_cte_columns: Any) -> tuple[Label, Any, str]:
        """Build SQLAlchemy expression for temporal grouping.

        Args:
            pivot_cte_columns: The columns object from the pivot CTE

        Returns:
            tuple: (select_column, group_by_column, column_name)
                - select_column: Labeled column for SELECT
                - group_by_column: Column expression for GROUP BY
                - column_name: The label/name of the column in results
        """
        # Only TIMESTAMP needs a local import here; `cast` and `func` are
        # already imported at module level, so re-importing them would just
        # shadow those names.
        from sqlalchemy import TIMESTAMP

        field_alias = BaseAggregation.field_to_alias(self.field)
        col = getattr(pivot_cte_columns, field_alias)
        # Cast the pivoted value to a tz-aware timestamp so date_trunc gets a
        # proper datetime regardless of how the index stores the value.
        truncated_col = func.date_trunc(self.period.value, cast(col, TIMESTAMP(timezone=True)))

        # Column name without prefix, e.g. "subscription_start_date_month".
        col_name = f"{field_alias}_{self.period.value}"
        select_col = truncated_col.label(col_name)
        return select_col, truncated_col, col_name
90
+
91
+
92
class BaseAggregation(BaseModel):
    """Common base for all aggregation models."""

    # `type` doubles as the discriminator for the `Aggregation` union.
    type: AggregationType = Field(description="The type of aggregation to perform.")
    alias: str = Field(description="The name for this aggregation in the results.")

    @classmethod
    def create(cls, data: dict) -> "Aggregation":
        """Validate *data* into the concrete aggregation selected by its 'type' key.

        Args:
            data: Dictionary with aggregation data including 'type' discriminator

        Returns:
            Validated aggregation instance (CountAggregation or FieldAggregation)

        Raises:
            ValidationError: If data is invalid or type is unknown
        """
        from pydantic import TypeAdapter

        return TypeAdapter(Aggregation).validate_python(data)

    @staticmethod
    def field_to_alias(field_path: str) -> str:
        """Convert a dotted field path into a safe SQL column alias.

        Examples:
            'subscription.name' -> 'subscription_name'
            'product.serial-number' -> 'product_serial_number'
        """
        # Map both '.' and '-' to '_' in a single pass.
        return field_path.translate(str.maketrans(".-", "__"))

    def get_pivot_fields(self) -> list[str]:
        """Fields that must be pivoted for this aggregation; none by default."""
        return []

    @abstractmethod
    def to_expression(self, *args: Any, **kwargs: Any) -> Label:
        """Build a labeled SQLAlchemy expression for this aggregation.

        Returns:
            Label: A labeled SQLAlchemy expression
        """
        raise NotImplementedError
138
+
139
+
140
class CountAggregation(BaseAggregation):
    """Count aggregation — counts the number of matching entities."""

    type: Literal[AggregationType.COUNT]

    def to_expression(self, entity_id_column: ColumnElement) -> Label:
        """Return COUNT(entity_id) labeled with this aggregation's alias.

        Args:
            entity_id_column: The entity_id column from the pivot CTE

        Returns:
            Label: A labeled SQLAlchemy expression
        """
        counted = func.count(entity_id_column)
        return counted.label(self.alias)
155
+
156
+
157
+ class FieldAggregation(BaseAggregation):
158
+ """Field-based aggregation (sum, avg, min, max)."""
159
+
160
+ type: Literal[AggregationType.SUM, AggregationType.AVG, AggregationType.MIN, AggregationType.MAX]
161
+ field: str = Field(description="The field path to aggregate on.")
162
+
163
+ def get_pivot_fields(self) -> list[str]:
164
+ """Return fields that need to be pivoted for this aggregation."""
165
+ return [self.field]
166
+
167
+ def to_expression(self, pivot_cte_columns: Any) -> Label:
168
+ """Build SQLAlchemy expression for field-based aggregation.
169
+
170
+ Args:
171
+ pivot_cte_columns: The columns object from the pivot CTE
172
+
173
+ Returns:
174
+ Label: A labeled SQLAlchemy expression
175
+
176
+ Raises:
177
+ ValueError: If the field is not found in the pivot CTE
178
+ """
179
+ field_alias = self.field_to_alias(self.field)
180
+
181
+ if not hasattr(pivot_cte_columns, field_alias):
182
+ raise ValueError(f"Field '{self.field}' (alias: '{field_alias}') not found in pivot CTE columns")
183
+
184
+ col = getattr(pivot_cte_columns, field_alias)
185
+
186
+ numeric_col = cast(col, Integer)
187
+
188
+ match self.type:
189
+ case AggregationType.SUM:
190
+ return func.sum(numeric_col).label(self.alias)
191
+ case AggregationType.AVG:
192
+ return func.avg(numeric_col).label(self.alias)
193
+ case AggregationType.MIN:
194
+ return func.min(numeric_col).label(self.alias)
195
+ case AggregationType.MAX:
196
+ return func.max(numeric_col).label(self.alias)
197
+ case _:
198
+ raise ValueError(f"Unsupported aggregation type: {self.type}")
199
+
200
+
201
# Discriminated union: Pydantic dispatches on the `type` field to choose the
# concrete model (CountAggregation vs FieldAggregation) during validation.
Aggregation: TypeAlias = Annotated[CountAggregation | FieldAggregation, Field(discriminator="type")]
@@ -105,7 +105,8 @@ class ActionType(str, Enum):
105
105
  """Defines the explicit, safe actions the agent can request."""
106
106
 
107
107
  SELECT = "select" # Retrieve a list of matching records.
108
- # COUNT = "count" # For phase1; the agent will not support this yet.
108
+ COUNT = "count" # Count matching records, optionally grouped.
109
+ AGGREGATE = "aggregate" # Compute aggregations (sum, avg, etc.) over matching records.
109
110
 
110
111
 
111
112
  class UIType(str, Enum):
@@ -261,7 +262,7 @@ class FieldType(str, Enum):
261
262
 
262
263
  def is_embeddable(self, value: str | None) -> bool:
263
264
  """Check if a field should be embedded."""
264
- if value is None:
265
+ if value is None or value == "":
265
266
  return False
266
267
 
267
268
  # If inference suggests it's not actually a string, don't embed it
@@ -19,6 +19,7 @@ from .base import (
19
19
  StringFilter,
20
20
  )
21
21
  from .date_filters import DateFilter, DateRangeFilter, DateValueFilter
22
+ from .definitions import TypeDefinition, ValueSchema
22
23
  from .ltree_filters import LtreeFilter
23
24
  from .numeric_filter import NumericFilter, NumericRangeFilter, NumericValueFilter
24
25
 
@@ -37,4 +38,7 @@ __all__ = [
37
38
  "DateFilter",
38
39
  "LtreeFilter",
39
40
  "NumericFilter",
41
+ # Schema types
42
+ "TypeDefinition",
43
+ "ValueSchema",
40
44
  ]
@@ -11,8 +11,29 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
 
14
+ from typing import Literal
15
+
16
+ from pydantic import BaseModel, ConfigDict
17
+
14
18
  from orchestrator.search.core.types import FieldType, FilterOp, UIType
15
- from orchestrator.search.schemas.results import TypeDefinition, ValueSchema
19
+
20
+
21
+ class ValueSchema(BaseModel):
22
+ """Schema describing the expected value type for a filter operator."""
23
+
24
+ kind: UIType | Literal["none", "object"] = UIType.STRING
25
+ fields: dict[str, "ValueSchema"] | None = None
26
+
27
+ model_config = ConfigDict(extra="forbid")
28
+
29
+
30
+ class TypeDefinition(BaseModel):
31
+ """Definition of available operators and their value schemas for a field type."""
32
+
33
+ operators: list[FilterOp]
34
+ value_schema: dict[FilterOp, ValueSchema]
35
+
36
+ model_config = ConfigDict(use_enum_values=True)
16
37
 
17
38
 
18
39
  def operators_for(ft: FieldType) -> list[FilterOp]:
@@ -14,7 +14,7 @@
14
14
  from typing import Annotated, Any, Literal
15
15
 
16
16
  from pydantic import BaseModel, Field, model_validator
17
- from sqlalchemy import DOUBLE_PRECISION, INTEGER, and_
17
+ from sqlalchemy import BIGINT, DOUBLE_PRECISION, and_
18
18
  from sqlalchemy import cast as sa_cast
19
19
  from sqlalchemy.sql.elements import ColumnElement
20
20
  from typing_extensions import Self
@@ -40,7 +40,7 @@ class NumericValueFilter(BaseModel):
40
40
  value: int | float
41
41
 
42
42
  def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
43
- cast_type = INTEGER if isinstance(self.value, int) else DOUBLE_PRECISION
43
+ cast_type = BIGINT if isinstance(self.value, int) else DOUBLE_PRECISION
44
44
  numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
45
45
  match self.op:
46
46
 
@@ -65,7 +65,7 @@ class NumericRangeFilter(BaseModel):
65
65
  value: NumericRange
66
66
 
67
67
  def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
68
- cast_type = INTEGER if isinstance(self.value.start, int) else DOUBLE_PRECISION
68
+ cast_type = BIGINT if isinstance(self.value.start, int) else DOUBLE_PRECISION
69
69
  numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
70
70
  return and_(numeric_column >= self.value.start, numeric_column <= self.value.end)
71
71
 
@@ -13,6 +13,7 @@
13
13
 
14
14
  """Simple search migration function that runs when SEARCH_ENABLED = True."""
15
15
 
16
+ import pydantic_ai
16
17
  from sqlalchemy import text
17
18
  from sqlalchemy.engine import Connection
18
19
  from structlog import get_logger
@@ -28,7 +29,7 @@ TARGET_DIM = 1536
28
29
 
29
30
  def run_migration(connection: Connection) -> None:
30
31
  """Run LLM migration with ON CONFLICT DO NOTHING pattern."""
31
- logger.info("Running LLM migration")
32
+ logger.info("Running LLM migration", pydantic_ai_version=pydantic_ai.__version__)
32
33
 
33
34
  try:
34
35
+ # Test to see if the extension exists and then skip the migration; Needed for certain situations where db user
@@ -0,0 +1,90 @@
1
+ # Copyright 2019-2025 SURF, GÉANT.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ """Query building and execution module."""
15
+
16
+ from orchestrator.search.aggregations import TemporalGrouping
17
+
18
+ from . import engine
19
+ from .builder import (
20
+ ComponentInfo,
21
+ LeafInfo,
22
+ build_aggregation_query,
23
+ build_candidate_query,
24
+ build_paths_query,
25
+ process_path_rows,
26
+ )
27
+ from .exceptions import (
28
+ EmptyFilterPathError,
29
+ IncompatibleAggregationTypeError,
30
+ IncompatibleFilterTypeError,
31
+ IncompatibleTemporalGroupingTypeError,
32
+ InvalidEntityPrefixError,
33
+ InvalidLtreePatternError,
34
+ PathNotFoundError,
35
+ QueryValidationError,
36
+ )
37
+ from .queries import AggregateQuery, CountQuery, ExportQuery, Query, SelectQuery
38
+ from .results import (
39
+ AggregationResponse,
40
+ AggregationResult,
41
+ MatchingField,
42
+ SearchResponse,
43
+ SearchResult,
44
+ VisualizationType,
45
+ format_aggregation_response,
46
+ format_search_response,
47
+ generate_highlight_indices,
48
+ )
49
+ from .state import QueryState
50
+
51
# Public API re-exported by `orchestrator.search.query`; kept grouped by
# concern so additions land in the matching section.
__all__ = [
    # Builder functions
    "build_aggregation_query",
    "build_candidate_query",
    "build_paths_query",
    "process_path_rows",
    # Builder metadata
    "ComponentInfo",
    "LeafInfo",
    # Engine
    "engine",
    # Exceptions
    "EmptyFilterPathError",
    "IncompatibleAggregationTypeError",
    "IncompatibleFilterTypeError",
    "IncompatibleTemporalGroupingTypeError",
    "InvalidEntityPrefixError",
    "InvalidLtreePatternError",
    "PathNotFoundError",
    "QueryValidationError",
    # Query models
    "AggregateQuery",
    "CountQuery",
    "ExportQuery",
    "Query",
    "SelectQuery",
    "TemporalGrouping",
    # Results
    "AggregationResponse",
    "AggregationResult",
    "MatchingField",
    "SearchResponse",
    "SearchResult",
    "VisualizationType",
    "format_aggregation_response",
    "format_search_response",
    "generate_highlight_indices",
    # State
    "QueryState",
]
@@ -0,0 +1,285 @@
1
+ # Copyright 2019-2025 SURF, GÉANT.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ from collections import defaultdict
15
+ from typing import Any, Sequence
16
+
17
+ from pydantic import BaseModel, ConfigDict
18
+ from sqlalchemy import Select, String, case, cast, func, select
19
+ from sqlalchemy.engine import Row
20
+ from sqlalchemy.sql.elements import Label
21
+ from sqlalchemy.sql.selectable import CTE
22
+ from sqlalchemy_utils.types.ltree import Ltree
23
+
24
+ from orchestrator.db.models import AiSearchIndex
25
+ from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
26
+ from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
27
+ from orchestrator.search.filters import LtreeFilter
28
+ from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query
29
+
30
+
31
class LeafInfo(BaseModel):
    """Information about a leaf (terminal field) in the entity schema."""

    # Final (non-numeric) path segment, e.g. "status" for "subscription.status".
    name: str
    # UI types observed for this leaf name across all occurrences
    # (serialized as plain strings via use_enum_values).
    ui_types: list[UIType]
    # All full index paths where this leaf name appears, sorted by the builder.
    paths: list[str]

    model_config = ConfigDict(
        extra="forbid",
        use_enum_values=True,
    )
42
+
43
+
44
class ComponentInfo(BaseModel):
    """Information about a component (nested object) in the entity schema."""

    # Intermediate path segment, e.g. "product" in "subscription.product.name".
    name: str
    # Always [UIType.COMPONENT] as produced by process_path_rows
    # (serialized as plain strings via use_enum_values).
    ui_types: list[UIType]

    model_config = ConfigDict(
        extra="forbid",
        use_enum_values=True,
    )
55
+
56
def create_path_autocomplete_lquery(prefix: str) -> str:
    """Create the lquery pattern for a multi-level path autocomplete search.

    Appends "*.*" so the pattern matches any path that starts with *prefix*
    and has at least one further level.
    """
    return prefix + "*.*"
59
+
60
+
61
def build_candidate_query(query: Query) -> Select:
    """Build the base query for retrieving candidate entities.

    Constructs a `SELECT` statement that retrieves distinct `entity_id` values
    (with titles) from the index table for the given entity type, applying any
    structured filters from the provided query plan.

    Args:
        query: Any query type (SelectQuery, CountQuery, AggregateQuery) containing entity type and optional filters.

    Returns:
        Select: The SQLAlchemy `Select` object representing the query.
    """
    conditions = [AiSearchIndex.entity_type == query.entity_type.value]

    if query.filters is not None:
        conditions.append(
            query.filters.to_expression(
                AiSearchIndex.entity_id,
                entity_type_value=query.entity_type.value,
            )
        )

    return select(AiSearchIndex.entity_id, AiSearchIndex.entity_title).where(*conditions).distinct()
92
+
93
def build_paths_query(entity_type: EntityType, prefix: str | None = None, q: str | None = None) -> Select:
    """Build the query for retrieving paths and their value types for leaves/components processing.

    Optionally restricts paths to those under *prefix* (via an lquery match)
    and, when *q* is given, orders by trigram similarity to it.
    """
    stmt = select(AiSearchIndex.path, AiSearchIndex.value_type).where(
        AiSearchIndex.entity_type == entity_type.value
    )

    if prefix:
        pattern = create_path_autocomplete_lquery(prefix)
        stmt = stmt.where(
            LtreeFilter(op=FilterOp.MATCHES_LQUERY, value=pattern).to_expression(AiSearchIndex.path, path="")
        )

    # De-duplicate (path, value_type) pairs.
    stmt = stmt.group_by(AiSearchIndex.path, AiSearchIndex.value_type)

    if not q:
        return stmt.order_by(AiSearchIndex.path)

    similarity = func.similarity(cast(AiSearchIndex.path, String), q)
    return stmt.order_by(similarity.desc(), AiSearchIndex.path)
111
+
112
+
113
def process_path_rows(rows: Sequence[Row]) -> tuple[list[LeafInfo], list[ComponentInfo]]:
    """Process query results to extract leaves and components information.

    Parameters
    ----------
    rows : Sequence[Row]
        Database rows containing path and value_type information

    Returns:
    -------
    tuple[list[LeafInfo], list[ComponentInfo]]
        Processed leaves and components
    """
    leaf_types: dict[str, set[UIType]] = defaultdict(set)
    leaf_paths: dict[str, set[str]] = defaultdict(set)
    component_names: set[str] = set()

    for path, value_type in rows:
        full_path = str(path)
        # Numeric segments are list indices, not schema names — drop them.
        segments = [seg for seg in full_path.split(".") if not seg.isdigit()]
        if not segments:
            continue

        # The last segment names the leaf; record its UI type and origin path.
        leaf = segments[-1]
        leaf_types[leaf].add(UIType.from_field_type(FieldType(value_type)))
        leaf_paths[leaf].add(full_path)

        # Everything strictly between the root and the leaf is a component.
        component_names.update(segments[1:-1])

    leaves = [
        LeafInfo(name=leaf, ui_types=list(types), paths=sorted(leaf_paths[leaf]))
        for leaf, types in leaf_types.items()
    ]
    components = [ComponentInfo(name=name, ui_types=[UIType.COMPONENT]) for name in sorted(component_names)]

    return leaves, components
157
+
158
+
159
def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:
    """Build CTE that pivots EAV rows into columns using CASE WHEN.

    Args:
        base_query: Candidate query whose entity_id column limits the pivot.
        pivot_fields: Index paths to turn into columns, one per field.

    Returns:
        CTE: One row per entity, with one column per pivoted field.
    """
    # NOTE: BaseAggregation is already imported at module level; the previous
    # redundant local import has been removed.
    pivot_columns = [AiSearchIndex.entity_id.label("entity_id")]

    for field_path in pivot_fields:
        # MAX(CASE WHEN path = field THEN value END) collapses the EAV rows of
        # one entity into a single row with the field's value as a column.
        pivot_columns.append(
            func.max(case((AiSearchIndex.path == Ltree(field_path), AiSearchIndex.value), else_=None)).label(
                BaseAggregation.field_to_alias(field_path)
            )
        )

    return (
        select(*pivot_columns)
        .where(
            AiSearchIndex.entity_id.in_(select(base_query.c.entity_id)),
            AiSearchIndex.path.in_([Ltree(p) for p in pivot_fields]),
        )
        .group_by(AiSearchIndex.entity_id)
        .cte("pivoted_entities")
    )
181
+
182
+
183
def _build_grouping_columns(
    query: CountQuery | AggregateQuery, pivot_cte: CTE
) -> tuple[list[Any], list[Any], list[str]]:
    """Collect SELECT and GROUP BY columns for plain and temporal groupings.

    Args:
        query: CountQuery or AggregateQuery with group_by and temporal_group_by fields
        pivot_cte: The pivoted CTE containing entity fields as columns

    Returns:
        tuple: (select_columns, group_by_columns, group_column_names)
            - select_columns: List of labeled columns for SELECT
            - group_by_columns: List of columns for GROUP BY clause
            - group_column_names: List of column names (labels) that are grouping columns
    """
    selected: list[Any] = []
    grouped: list[Any] = []
    names: list[str] = []

    # Plain field groupings: group directly on the pivoted column.
    for group_field in query.group_by or []:
        alias = BaseAggregation.field_to_alias(group_field)
        column = getattr(pivot_cte.c, alias)
        selected.append(column.label(alias))
        grouped.append(column)
        names.append(alias)

    # Temporal groupings: delegate expression building to TemporalGrouping.
    for temporal in query.temporal_group_by or []:
        select_col, group_col, label = temporal.to_expression(pivot_cte.c)
        selected.append(select_col)
        grouped.append(group_col)
        names.append(label)

    return selected, grouped, names
219
+
220
+
221
def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CTE) -> list[Label]:
    """Build aggregation columns (COUNT, SUM, AVG, MIN, MAX).

    Args:
        query: CountQuery or AggregateQuery
        pivot_cte: The pivoted CTE containing entity fields as columns

    Returns:
        List of labeled aggregation expressions
    """
    if isinstance(query, AggregateQuery):
        # AGGREGATE query: count aggregations take the entity_id column,
        # field aggregations take the full pivoted-columns object.
        return [
            agg.to_expression(pivot_cte.c.entity_id)
            if isinstance(agg, CountAggregation)
            else agg.to_expression(pivot_cte.c)
            for agg in query.aggregations
        ]

    # CountQuery without explicit aggregations: default to a single count.
    default_count = CountAggregation(type=AggregationType.COUNT, alias="count")
    return [default_count.to_expression(pivot_cte.c.entity_id)]
245
+
246
+
247
def build_simple_count_query(base_query: Select) -> Select:
    """Build a simple count query without grouping.

    Args:
        base_query: Base candidate query with filters applied

    Returns:
        Select statement that counts distinct entity IDs
    """
    # NOTE(review): `Select.c` is deprecated in recent SQLAlchemy versions —
    # confirm against the pinned library version.
    total = func.count(func.distinct(base_query.c.entity_id)).label("total_count")
    return select(total).select_from(base_query.subquery())
259
+
260
+
261
def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Select) -> tuple[Select, list[str]]:
    """Build aggregation query with GROUP BY and aggregation functions.

    Handles EAV storage by pivoting rows to columns, then applying SQL
    aggregations. Only grouped aggregations come through here; simple counts
    are handled directly in the engine.

    Args:
        query: CountQuery or AggregateQuery with group_by and optional aggregations
        base_query: Base candidate query with filters applied

    Returns:
        tuple: (query_stmt, group_column_names)
            - query_stmt: SQLAlchemy Select statement for grouped aggregation
            - group_column_names: List of column names that are grouping columns
    """
    pivot = _build_pivot_cte(base_query, query.get_pivot_fields())
    select_cols, group_cols, group_names = _build_grouping_columns(query, pivot)
    aggregation_cols = _build_aggregation_columns(query, pivot)

    stmt = select(*select_cols, *aggregation_cols).select_from(pivot)
    if group_cols:
        stmt = stmt.group_by(*group_cols)

    return stmt, group_names