orchestrator-core 4.6.2__py3-none-any.whl → 4.6.3rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +1 -1
- orchestrator/api/api_v1/endpoints/search.py +44 -34
- orchestrator/{search/retrieval/utils.py → cli/search/display.py} +4 -29
- orchestrator/cli/search/search_explore.py +22 -24
- orchestrator/cli/search/speedtest.py +11 -9
- orchestrator/db/models.py +6 -6
- orchestrator/log_config.py +2 -0
- orchestrator/schemas/search.py +1 -1
- orchestrator/schemas/search_requests.py +59 -0
- orchestrator/search/agent/handlers.py +129 -0
- orchestrator/search/agent/prompts.py +54 -33
- orchestrator/search/agent/state.py +9 -24
- orchestrator/search/agent/tools.py +223 -144
- orchestrator/search/agent/validation.py +80 -0
- orchestrator/search/{schemas → aggregations}/__init__.py +20 -0
- orchestrator/search/aggregations/base.py +201 -0
- orchestrator/search/core/types.py +3 -2
- orchestrator/search/filters/__init__.py +4 -0
- orchestrator/search/filters/definitions.py +22 -1
- orchestrator/search/filters/numeric_filter.py +3 -3
- orchestrator/search/llm_migration.py +2 -1
- orchestrator/search/query/__init__.py +90 -0
- orchestrator/search/query/builder.py +285 -0
- orchestrator/search/query/engine.py +162 -0
- orchestrator/search/{retrieval → query}/exceptions.py +38 -7
- orchestrator/search/query/mixins.py +95 -0
- orchestrator/search/query/queries.py +129 -0
- orchestrator/search/query/results.py +252 -0
- orchestrator/search/{retrieval/query_state.py → query/state.py} +31 -11
- orchestrator/search/{retrieval → query}/validation.py +58 -1
- orchestrator/search/retrieval/__init__.py +0 -5
- orchestrator/search/retrieval/pagination.py +7 -8
- orchestrator/search/retrieval/retrievers/base.py +9 -9
- {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/METADATA +6 -6
- {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/RECORD +38 -32
- orchestrator/search/retrieval/builder.py +0 -127
- orchestrator/search/retrieval/engine.py +0 -197
- orchestrator/search/schemas/parameters.py +0 -133
- orchestrator/search/schemas/results.py +0 -80
- /orchestrator/search/{export.py → query/export.py} +0 -0
- {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.6.2.dist-info → orchestrator_core-4.6.3rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from abc import abstractmethod
|
|
15
|
+
from enum import Enum
|
|
16
|
+
from typing import Annotated, Any, Literal, TypeAlias
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
19
|
+
from sqlalchemy import Integer, cast, func
|
|
20
|
+
from sqlalchemy.sql.elements import ColumnElement, Label
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AggregationType(str, Enum):
    """Enumerate the aggregate functions a query can request.

    Members subclass ``str``, so each compares equal to its lowercase value
    (for example ``AggregationType.SUM == "sum"``).
    """

    COUNT = "count"  # number of entities
    SUM = "sum"  # total of a field's values
    AVG = "avg"  # mean of a field's values
    MIN = "min"  # smallest of a field's values
    MAX = "max"  # largest of a field's values
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TemporalPeriod(str, Enum):
    """Bucket sizes available for temporal grouping.

    The member value is the unit passed to ``date_trunc`` when a
    ``TemporalGrouping`` builds its SQL expression.
    """

    YEAR = "year"
    QUARTER = "quarter"
    MONTH = "month"
    WEEK = "week"
    DAY = "day"
    HOUR = "hour"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TemporalGrouping(BaseModel):
    """Defines temporal grouping for date/time fields.

    Used to group query results by time periods (e.g., group subscriptions by start_date per month).
    """

    field: str = Field(description="The datetime field path to group by temporally.")
    period: TemporalPeriod = Field(description="The time period to group by (year, quarter, month, week, day, hour).")

    model_config = ConfigDict(
        extra="forbid",
        json_schema_extra={
            "examples": [
                {"field": "subscription.start_date", "period": "month"},
                {"field": "subscription.end_date", "period": "year"},
                {"field": "process.created_at", "period": "day"},
            ]
        },
    )

    def get_pivot_fields(self) -> list[str]:
        """Return fields that need to be pivoted for this temporal grouping."""
        return [self.field]

    def to_expression(self, pivot_cte_columns: Any) -> tuple[Label, Any, str]:
        """Build SQLAlchemy expression for temporal grouping.

        The pivoted value is cast to ``TIMESTAMP WITH TIME ZONE`` and truncated
        with ``date_trunc`` to the configured period.

        Args:
            pivot_cte_columns: The columns object from the pivot CTE

        Returns:
            tuple: (select_column, group_by_column, column_name)
                - select_column: Labeled column for SELECT
                - group_by_column: Column expression for GROUP BY
                - column_name: The label/name of the column in results,
                  ``"<field_alias>_<period>"``

        Raises:
            ValueError: If the field is not found in the pivot CTE (same error
                ``FieldAggregation.to_expression`` raises for a missing field).
        """
        # ``cast`` and ``func`` come from the module-level sqlalchemy import;
        # only TIMESTAMP is needed locally.
        from sqlalchemy import TIMESTAMP

        field_alias = BaseAggregation.field_to_alias(self.field)
        if not hasattr(pivot_cte_columns, field_alias):
            raise ValueError(f"Field '{self.field}' (alias: '{field_alias}') not found in pivot CTE columns")

        col = getattr(pivot_cte_columns, field_alias)
        truncated_col = func.date_trunc(self.period.value, cast(col, TIMESTAMP(timezone=True)))

        # Column name without prefix
        col_name = f"{field_alias}_{self.period.value}"
        return truncated_col.label(col_name), truncated_col, col_name
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class BaseAggregation(BaseModel):
    """Shared fields and helpers for every concrete aggregation model."""

    type: AggregationType = Field(description="The type of aggregation to perform.")
    alias: str = Field(description="The name for this aggregation in the results.")

    @classmethod
    def create(cls, data: dict) -> "Aggregation":
        """Validate ``data`` into the concrete aggregation its ``type`` selects.

        ``type`` is the discriminator of the ``Aggregation`` union, so the
        result is either a ``CountAggregation`` or a ``FieldAggregation``.

        Args:
            data: Raw aggregation payload, including the ``type`` key.

        Returns:
            The validated ``CountAggregation`` or ``FieldAggregation`` instance.

        Raises:
            ValidationError: If ``data`` matches no member of the union.
        """
        from pydantic import TypeAdapter

        return TypeAdapter(Aggregation).validate_python(data)

    @staticmethod
    def field_to_alias(field_path: str) -> str:
        """Map a dotted/hyphenated field path to a SQL-friendly column alias.

        Both ``.`` and ``-`` become ``_``:
        ``'subscription.name'`` -> ``'subscription_name'``,
        ``'product.serial-number'`` -> ``'product_serial_number'``.
        """
        return field_path.translate(str.maketrans({".": "_", "-": "_"}))

    def get_pivot_fields(self) -> list[str]:
        """Return the field paths to pivot; the base aggregation needs none."""
        return []

    @abstractmethod
    def to_expression(self, *args: Any, **kwargs: Any) -> Label:
        """Return the labelled SQLAlchemy expression computing this aggregation.

        Subclasses define the concrete arguments they need.
        """
        raise NotImplementedError
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class CountAggregation(BaseAggregation):
    """Aggregation that counts entities."""

    type: Literal[AggregationType.COUNT]

    def to_expression(self, entity_id_column: ColumnElement) -> Label:
        """Return ``count(entity_id_column)`` labelled with this aggregation's alias.

        Args:
            entity_id_column: The ``entity_id`` column of the pivot CTE.

        Returns:
            Label: The labelled count expression.
        """
        counted = func.count(entity_id_column)
        return counted.label(self.alias)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class FieldAggregation(BaseAggregation):
    """Aggregation over the values of one field: sum, avg, min or max."""

    type: Literal[AggregationType.SUM, AggregationType.AVG, AggregationType.MIN, AggregationType.MAX]
    field: str = Field(description="The field path to aggregate on.")

    def get_pivot_fields(self) -> list[str]:
        """Return the single field path this aggregation reads."""
        return [self.field]

    def to_expression(self, pivot_cte_columns: Any) -> Label:
        """Return the labelled SQL aggregate over this aggregation's field.

        Args:
            pivot_cte_columns: The columns object from the pivot CTE.

        Returns:
            Label: The aggregate expression labelled with ``self.alias``.

        Raises:
            ValueError: If the field has no column in the pivot CTE, or if
                ``type`` is not one of SUM/AVG/MIN/MAX.
        """
        alias_name = self.field_to_alias(self.field)
        if not hasattr(pivot_cte_columns, alias_name):
            raise ValueError(f"Field '{self.field}' (alias: '{alias_name}') not found in pivot CTE columns")

        # NOTE(review): values are cast to Integer, so stored non-integer values
        # (e.g. "1.5") will presumably be rejected by the database, and large values
        # may overflow a 32-bit integer. The numeric filters in this package cast to
        # BIGINT/DOUBLE_PRECISION instead; consider aligning once callers' expected
        # result types are confirmed.
        as_number = cast(getattr(pivot_cte_columns, alias_name), Integer)

        if self.type == AggregationType.SUM:
            aggregate = func.sum
        elif self.type == AggregationType.AVG:
            aggregate = func.avg
        elif self.type == AggregationType.MIN:
            aggregate = func.min
        elif self.type == AggregationType.MAX:
            aggregate = func.max
        else:
            raise ValueError(f"Unsupported aggregation type: {self.type}")

        return aggregate(as_number).label(self.alias)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Discriminated union of the concrete aggregations: pydantic picks the member
# from the value of the ``type`` field.
Aggregation: TypeAlias = Annotated[
    CountAggregation | FieldAggregation,
    Field(discriminator="type"),
]
|
|
@@ -105,7 +105,8 @@ class ActionType(str, Enum):
|
|
|
105
105
|
"""Defines the explicit, safe actions the agent can request."""
|
|
106
106
|
|
|
107
107
|
SELECT = "select" # Retrieve a list of matching records.
|
|
108
|
-
|
|
108
|
+
COUNT = "count" # Count matching records, optionally grouped.
|
|
109
|
+
AGGREGATE = "aggregate" # Compute aggregations (sum, avg, etc.) over matching records.
|
|
109
110
|
|
|
110
111
|
|
|
111
112
|
class UIType(str, Enum):
|
|
@@ -261,7 +262,7 @@ class FieldType(str, Enum):
|
|
|
261
262
|
|
|
262
263
|
def is_embeddable(self, value: str | None) -> bool:
|
|
263
264
|
"""Check if a field should be embedded."""
|
|
264
|
-
if value is None:
|
|
265
|
+
if value is None or value == "":
|
|
265
266
|
return False
|
|
266
267
|
|
|
267
268
|
# If inference suggests it's not actually a string, don't embed it
|
|
@@ -19,6 +19,7 @@ from .base import (
|
|
|
19
19
|
StringFilter,
|
|
20
20
|
)
|
|
21
21
|
from .date_filters import DateFilter, DateRangeFilter, DateValueFilter
|
|
22
|
+
from .definitions import TypeDefinition, ValueSchema
|
|
22
23
|
from .ltree_filters import LtreeFilter
|
|
23
24
|
from .numeric_filter import NumericFilter, NumericRangeFilter, NumericValueFilter
|
|
24
25
|
|
|
@@ -37,4 +38,7 @@ __all__ = [
|
|
|
37
38
|
"DateFilter",
|
|
38
39
|
"LtreeFilter",
|
|
39
40
|
"NumericFilter",
|
|
41
|
+
# Schema types
|
|
42
|
+
"TypeDefinition",
|
|
43
|
+
"ValueSchema",
|
|
40
44
|
]
|
|
@@ -11,8 +11,29 @@
|
|
|
11
11
|
# See the License for the specific language governing permissions and
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
|
|
14
|
+
from typing import Literal
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, ConfigDict
|
|
17
|
+
|
|
14
18
|
from orchestrator.search.core.types import FieldType, FilterOp, UIType
|
|
15
|
-
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ValueSchema(BaseModel):
    """Describe the value a filter operator expects.

    ``kind`` is a UI value type, or ``"none"`` / ``"object"``; ``fields``
    optionally maps sub-value names to nested ``ValueSchema`` entries.
    """

    model_config = ConfigDict(extra="forbid")

    kind: UIType | Literal["none", "object"] = UIType.STRING
    fields: dict[str, "ValueSchema"] | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TypeDefinition(BaseModel):
    """Operators offered for one field type, with the value schema each one expects."""

    model_config = ConfigDict(use_enum_values=True)

    operators: list[FilterOp]  # operators available for this field type
    value_schema: dict[FilterOp, ValueSchema]  # expected value shape, keyed by operator
|
|
16
37
|
|
|
17
38
|
|
|
18
39
|
def operators_for(ft: FieldType) -> list[FilterOp]:
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
from typing import Annotated, Any, Literal
|
|
15
15
|
|
|
16
16
|
from pydantic import BaseModel, Field, model_validator
|
|
17
|
-
from sqlalchemy import
|
|
17
|
+
from sqlalchemy import BIGINT, DOUBLE_PRECISION, and_
|
|
18
18
|
from sqlalchemy import cast as sa_cast
|
|
19
19
|
from sqlalchemy.sql.elements import ColumnElement
|
|
20
20
|
from typing_extensions import Self
|
|
@@ -40,7 +40,7 @@ class NumericValueFilter(BaseModel):
|
|
|
40
40
|
value: int | float
|
|
41
41
|
|
|
42
42
|
def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
|
|
43
|
-
cast_type =
|
|
43
|
+
cast_type = BIGINT if isinstance(self.value, int) else DOUBLE_PRECISION
|
|
44
44
|
numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
|
|
45
45
|
match self.op:
|
|
46
46
|
|
|
@@ -65,7 +65,7 @@ class NumericRangeFilter(BaseModel):
|
|
|
65
65
|
value: NumericRange
|
|
66
66
|
|
|
67
67
|
def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
|
|
68
|
-
cast_type =
|
|
68
|
+
cast_type = BIGINT if isinstance(self.value.start, int) else DOUBLE_PRECISION
|
|
69
69
|
numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
|
|
70
70
|
return and_(numeric_column >= self.value.start, numeric_column <= self.value.end)
|
|
71
71
|
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
|
|
14
14
|
"""Simple search migration function that runs when SEARCH_ENABLED = True."""
|
|
15
15
|
|
|
16
|
+
import pydantic_ai
|
|
16
17
|
from sqlalchemy import text
|
|
17
18
|
from sqlalchemy.engine import Connection
|
|
18
19
|
from structlog import get_logger
|
|
@@ -28,7 +29,7 @@ TARGET_DIM = 1536
|
|
|
28
29
|
|
|
29
30
|
def run_migration(connection: Connection) -> None:
|
|
30
31
|
"""Run LLM migration with ON CONFLICT DO NOTHING pattern."""
|
|
31
|
-
logger.info("Running LLM migration")
|
|
32
|
+
logger.info("Running LLM migration", pydantic_ai_version=pydantic_ai.__version__)
|
|
32
33
|
|
|
33
34
|
try:
|
|
34
35
|
# Test to see if the extenstion exists and then skip the migration; Needed for certain situations where db user
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
"""Query building and execution module."""
|
|
15
|
+
|
|
16
|
+
from orchestrator.search.aggregations import TemporalGrouping
|
|
17
|
+
|
|
18
|
+
from . import engine
|
|
19
|
+
from .builder import (
|
|
20
|
+
ComponentInfo,
|
|
21
|
+
LeafInfo,
|
|
22
|
+
build_aggregation_query,
|
|
23
|
+
build_candidate_query,
|
|
24
|
+
build_paths_query,
|
|
25
|
+
process_path_rows,
|
|
26
|
+
)
|
|
27
|
+
from .exceptions import (
|
|
28
|
+
EmptyFilterPathError,
|
|
29
|
+
IncompatibleAggregationTypeError,
|
|
30
|
+
IncompatibleFilterTypeError,
|
|
31
|
+
IncompatibleTemporalGroupingTypeError,
|
|
32
|
+
InvalidEntityPrefixError,
|
|
33
|
+
InvalidLtreePatternError,
|
|
34
|
+
PathNotFoundError,
|
|
35
|
+
QueryValidationError,
|
|
36
|
+
)
|
|
37
|
+
from .queries import AggregateQuery, CountQuery, ExportQuery, Query, SelectQuery
|
|
38
|
+
from .results import (
|
|
39
|
+
AggregationResponse,
|
|
40
|
+
AggregationResult,
|
|
41
|
+
MatchingField,
|
|
42
|
+
SearchResponse,
|
|
43
|
+
SearchResult,
|
|
44
|
+
VisualizationType,
|
|
45
|
+
format_aggregation_response,
|
|
46
|
+
format_search_response,
|
|
47
|
+
generate_highlight_indices,
|
|
48
|
+
)
|
|
49
|
+
from .state import QueryState
|
|
50
|
+
|
|
51
|
+
__all__ = [
    # Statement builders and path-row processing (from .builder)
    "build_aggregation_query",
    "build_candidate_query",
    "build_paths_query",
    "process_path_rows",
    # Path metadata models (from .builder)
    "ComponentInfo",
    "LeafInfo",
    # Execution engine module
    "engine",
    # Query errors (from .exceptions)
    "EmptyFilterPathError",
    "IncompatibleAggregationTypeError",
    "IncompatibleFilterTypeError",
    "IncompatibleTemporalGroupingTypeError",
    "InvalidEntityPrefixError",
    "InvalidLtreePatternError",
    "PathNotFoundError",
    "QueryValidationError",
    # Query models (from .queries, plus TemporalGrouping from aggregations)
    "AggregateQuery",
    "CountQuery",
    "ExportQuery",
    "Query",
    "SelectQuery",
    "TemporalGrouping",
    # Result models and formatters (from .results)
    "AggregationResponse",
    "AggregationResult",
    "MatchingField",
    "SearchResponse",
    "SearchResult",
    "VisualizationType",
    "format_aggregation_response",
    "format_search_response",
    "generate_highlight_indices",
    # Query state (from .state)
    "QueryState",
]
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
from typing import Any, Sequence
|
|
16
|
+
|
|
17
|
+
from pydantic import BaseModel, ConfigDict
|
|
18
|
+
from sqlalchemy import Select, String, case, cast, func, select
|
|
19
|
+
from sqlalchemy.engine import Row
|
|
20
|
+
from sqlalchemy.sql.elements import Label
|
|
21
|
+
from sqlalchemy.sql.selectable import CTE
|
|
22
|
+
from sqlalchemy_utils.types.ltree import Ltree
|
|
23
|
+
|
|
24
|
+
from orchestrator.db.models import AiSearchIndex
|
|
25
|
+
from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
|
|
26
|
+
from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
|
|
27
|
+
from orchestrator.search.filters import LtreeFilter
|
|
28
|
+
from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LeafInfo(BaseModel):
    """A terminal field of an entity schema, with its UI types and the concrete paths it appears at."""

    model_config = ConfigDict(extra="forbid", use_enum_values=True)

    name: str
    ui_types: list[UIType]
    paths: list[str]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ComponentInfo(BaseModel):
    """A nested object (component) of an entity schema and its UI types."""

    model_config = ConfigDict(extra="forbid", use_enum_values=True)

    name: str
    ui_types: list[UIType]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def create_path_autocomplete_lquery(prefix: str) -> str:
    """Return the lquery used to autocomplete paths starting with ``prefix``.

    ``*`` directly after the prefix lets the last label match by prefix, and the
    trailing ``.*`` allows any number of further labels below it.

    >>> create_path_autocomplete_lquery("subscription")
    'subscription*.*'
    """
    return "{}*.*".format(prefix)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def build_candidate_query(query: Query) -> Select:
    """Build the statement selecting candidate entities for a query.

    Selects distinct ``(entity_id, entity_title)`` pairs from the search index
    for the query's entity type, narrowed by the query's structured filters
    when it has any.

    Args:
        query: Any query type (SelectQuery, CountQuery, AggregateQuery) with an
            entity type and optional filters.

    Returns:
        Select: The candidate ``SELECT`` statement.
    """
    entity_type_value = query.entity_type.value

    stmt = select(AiSearchIndex.entity_id, AiSearchIndex.entity_title)
    stmt = stmt.where(AiSearchIndex.entity_type == entity_type_value)
    stmt = stmt.distinct()

    if query.filters is None:
        return stmt

    filter_clause = query.filters.to_expression(
        AiSearchIndex.entity_id,
        entity_type_value=entity_type_value,
    )
    return stmt.where(filter_clause)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def build_paths_query(entity_type: EntityType, prefix: str | None = None, q: str | None = None) -> Select:
    """Build the statement listing distinct ``(path, value_type)`` pairs for an entity type.

    Args:
        entity_type: Entity type whose index rows are inspected.
        prefix: Optional path prefix; when given, only paths matching the
            autocomplete lquery for it are returned.
        q: Optional search text; when given, rows are ordered by ``similarity``
            of the path to ``q`` (best first) before ordering by path.

    Returns:
        Select: The grouped and ordered statement.
    """
    stmt = select(AiSearchIndex.path, AiSearchIndex.value_type).where(AiSearchIndex.entity_type == entity_type.value)

    if prefix:
        path_filter = LtreeFilter(op=FilterOp.MATCHES_LQUERY, value=create_path_autocomplete_lquery(prefix))
        stmt = stmt.where(path_filter.to_expression(AiSearchIndex.path, path=""))

    stmt = stmt.group_by(AiSearchIndex.path, AiSearchIndex.value_type)

    ordering: list[Any] = [AiSearchIndex.path]
    if q:
        ordering.insert(0, func.similarity(cast(AiSearchIndex.path, String), q).desc())
    return stmt.order_by(*ordering)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def process_path_rows(rows: Sequence[Row]) -> tuple[list[LeafInfo], list[ComponentInfo]]:
    """Split indexed paths into leaf fields and intermediate components.

    Numeric segments (list indices) are dropped from each path. The last
    remaining segment is recorded as a leaf, together with its UI type and the
    full path it came from. Every segment strictly between the first and the
    last is recorded as a component.

    Args:
        rows: Rows of ``(path, value_type)`` pairs, e.g. from ``build_paths_query``.

    Returns:
        tuple[list[LeafInfo], list[ComponentInfo]]: Leaves in first-seen row
        order, which preserves any ordering applied by the query, and components
        sorted by name.

    Raises:
        ValueError: If a row's ``value_type`` is not a valid ``FieldType``.
    """
    leaves_dict: dict[str, set[UIType]] = defaultdict(set)
    leaves_paths_dict: dict[str, set[str]] = defaultdict(set)
    components_set: set[str] = set()

    for path, value_type in rows:
        path_str = str(path)

        # Remove numeric segments
        clean_segments = [seg for seg in path_str.split(".") if not seg.isdigit()]
        if not clean_segments:
            continue

        # Last segment is a leaf
        leaf_name = clean_segments[-1]
        leaves_dict[leaf_name].add(UIType.from_field_type(FieldType(value_type)))
        leaves_paths_dict[leaf_name].add(path_str)

        # All segments except the first/last are components
        components_set.update(clean_segments[1:-1])

    leaves = [
        # Sort the UI types: iteration order of a set of str-based enum members
        # depends on per-process string hash randomization, so list(set) would
        # make the response order non-deterministic.
        LeafInfo(name=leaf, ui_types=sorted(types), paths=sorted(leaves_paths_dict[leaf]))
        for leaf, types in leaves_dict.items()
    ]
    components = [ComponentInfo(name=component, ui_types=[UIType.COMPONENT]) for component in sorted(components_set)]

    return leaves, components
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:
    """Build a CTE that pivots EAV rows into columns using CASE WHEN.

    For each entity matched by ``base_query``, the CTE has an ``entity_id``
    column. It also has one column per entry in ``pivot_fields``, named via
    ``BaseAggregation.field_to_alias``. Each such column holds ``max`` of the
    values stored at that exact path, or NULL when the entity has none.

    Args:
        base_query: Candidate query; must select ``entity_id``.
        pivot_fields: Field paths to turn into columns; may be empty.

    Returns:
        CTE: The ``pivoted_entities`` CTE.
    """
    # Explicit subquery rather than the deprecated implicit ``Select.c`` access.
    candidates = base_query.subquery()

    pivot_columns = [AiSearchIndex.entity_id.label("entity_id")]
    pivot_columns.extend(
        func.max(case((AiSearchIndex.path == Ltree(field_path), AiSearchIndex.value), else_=None)).label(
            BaseAggregation.field_to_alias(field_path)
        )
        for field_path in pivot_fields
    )

    conditions = [AiSearchIndex.entity_id.in_(select(candidates.c.entity_id))]
    if pivot_fields:
        # Only restrict paths when there are some. ``path IN ()`` is always false
        # and would silently drop every entity if a caller passes no pivot fields.
        conditions.append(AiSearchIndex.path.in_([Ltree(p) for p in pivot_fields]))

    return select(*pivot_columns).where(*conditions).group_by(AiSearchIndex.entity_id).cte("pivoted_entities")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _build_grouping_columns(
    query: CountQuery | AggregateQuery, pivot_cte: CTE
) -> tuple[list[Any], list[Any], list[str]]:
    """Collect the SELECT and GROUP BY columns for a grouped query.

    Plain ``group_by`` fields come first (labelled with their alias), followed
    by ``temporal_group_by`` buckets.

    Args:
        query: CountQuery or AggregateQuery with group_by and temporal_group_by fields
        pivot_cte: The pivoted CTE containing entity fields as columns

    Returns:
        tuple: (select_columns, group_by_columns, group_column_names)
            - select_columns: labelled columns for SELECT
            - group_by_columns: expressions for the GROUP BY clause
            - group_column_names: labels of the grouping columns, in order
    """
    select_columns: list[Any] = []
    group_by_columns: list[Any] = []
    group_column_names: list[str] = []

    for group_field in query.group_by or []:
        alias = BaseAggregation.field_to_alias(group_field)
        column = getattr(pivot_cte.c, alias)
        select_columns.append(column.label(alias))
        group_by_columns.append(column)
        group_column_names.append(alias)

    for temporal in query.temporal_group_by or []:
        labelled, grouped, name = temporal.to_expression(pivot_cte.c)
        select_columns.append(labelled)
        group_by_columns.append(grouped)
        group_column_names.append(name)

    return select_columns, group_by_columns, group_column_names
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CTE) -> list[Label]:
    """Return the aggregate expressions (COUNT, SUM, AVG, MIN, MAX) for a grouped query.

    An ``AggregateQuery`` yields one expression per requested aggregation.
    Count aggregations receive the CTE's ``entity_id`` column; field
    aggregations receive the whole column collection. Any other query gets a
    single count labelled ``"count"``.

    Args:
        query: CountQuery or AggregateQuery
        pivot_cte: The pivoted CTE containing entity fields as columns

    Returns:
        List of labelled aggregation expressions.
    """
    if not isinstance(query, AggregateQuery):
        default_count = CountAggregation(type=AggregationType.COUNT, alias="count")
        return [default_count.to_expression(pivot_cte.c.entity_id)]

    return [
        agg.to_expression(pivot_cte.c.entity_id) if isinstance(agg, CountAggregation) else agg.to_expression(pivot_cte.c)
        for agg in query.aggregations
    ]
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def build_simple_count_query(base_query: Select) -> Select:
    """Build a simple count query without grouping.

    Args:
        base_query: Base candidate query with filters applied; must select ``entity_id``.

    Returns:
        Select statement with a single ``total_count`` column counting distinct entity IDs.
    """
    # Create the subquery once and reference its own column. The previous form
    # used ``base_query.c`` (a deprecated, implicitly created subquery) together
    # with a second ``base_query.subquery()`` in ``select_from``. That placed two
    # independent copies of the candidate query in FROM, a cartesian product.
    candidates = base_query.subquery()
    return select(func.count(func.distinct(candidates.c.entity_id)).label("total_count")).select_from(candidates)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Select) -> tuple[Select, list[str]]:
    """Build a grouped aggregation statement over the pivoted candidate entities.

    EAV rows are first pivoted into one row per entity. The grouping columns
    and aggregate expressions are then selected from that CTE, with a
    GROUP BY only when grouping columns exist. Simple ungrouped counts are
    handled by the engine, not here.

    Args:
        query: CountQuery or AggregateQuery with group_by and optional aggregations
        base_query: Base candidate query with filters applied

    Returns:
        tuple: (query_stmt, group_column_names)
            - query_stmt: SQLAlchemy Select statement for the grouped aggregation
            - group_column_names: labels of the grouping columns
    """
    pivot_cte = _build_pivot_cte(base_query, query.get_pivot_fields())
    group_select_cols, group_by_cols, group_names = _build_grouping_columns(query, pivot_cte)
    aggregate_cols = _build_aggregation_columns(query, pivot_cte)

    stmt = select(*group_select_cols, *aggregate_cols).select_from(pivot_cte)
    if not group_by_cols:
        return stmt, group_names
    return stmt.group_by(*group_by_cols), group_names
|