orchestrator-core 4.4.2__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +17 -2
- orchestrator/agentic_app.py +103 -0
- orchestrator/api/api_v1/api.py +14 -2
- orchestrator/api/api_v1/endpoints/search.py +296 -0
- orchestrator/app.py +32 -0
- orchestrator/cli/main.py +22 -1
- orchestrator/cli/search/__init__.py +32 -0
- orchestrator/cli/search/index_llm.py +73 -0
- orchestrator/cli/search/resize_embedding.py +135 -0
- orchestrator/cli/search/search_explore.py +208 -0
- orchestrator/cli/search/speedtest.py +151 -0
- orchestrator/db/models.py +37 -1
- orchestrator/devtools/populator.py +16 -0
- orchestrator/domain/base.py +2 -7
- orchestrator/domain/lifecycle.py +24 -7
- orchestrator/llm_settings.py +57 -0
- orchestrator/log_config.py +1 -0
- orchestrator/migrations/helpers.py +7 -1
- orchestrator/schemas/search.py +130 -0
- orchestrator/schemas/workflow.py +1 -0
- orchestrator/search/__init__.py +12 -0
- orchestrator/search/agent/__init__.py +21 -0
- orchestrator/search/agent/agent.py +62 -0
- orchestrator/search/agent/prompts.py +100 -0
- orchestrator/search/agent/state.py +21 -0
- orchestrator/search/agent/tools.py +258 -0
- orchestrator/search/core/__init__.py +12 -0
- orchestrator/search/core/embedding.py +73 -0
- orchestrator/search/core/exceptions.py +36 -0
- orchestrator/search/core/types.py +296 -0
- orchestrator/search/core/validators.py +40 -0
- orchestrator/search/docs/index.md +37 -0
- orchestrator/search/docs/running_local_text_embedding_inference.md +46 -0
- orchestrator/search/filters/__init__.py +40 -0
- orchestrator/search/filters/base.py +295 -0
- orchestrator/search/filters/date_filters.py +88 -0
- orchestrator/search/filters/definitions.py +107 -0
- orchestrator/search/filters/ltree_filters.py +56 -0
- orchestrator/search/filters/numeric_filter.py +73 -0
- orchestrator/search/indexing/__init__.py +16 -0
- orchestrator/search/indexing/indexer.py +334 -0
- orchestrator/search/indexing/registry.py +101 -0
- orchestrator/search/indexing/tasks.py +69 -0
- orchestrator/search/indexing/traverse.py +334 -0
- orchestrator/search/llm_migration.py +108 -0
- orchestrator/search/retrieval/__init__.py +16 -0
- orchestrator/search/retrieval/builder.py +123 -0
- orchestrator/search/retrieval/engine.py +154 -0
- orchestrator/search/retrieval/exceptions.py +90 -0
- orchestrator/search/retrieval/pagination.py +96 -0
- orchestrator/search/retrieval/retrievers/__init__.py +26 -0
- orchestrator/search/retrieval/retrievers/base.py +123 -0
- orchestrator/search/retrieval/retrievers/fuzzy.py +94 -0
- orchestrator/search/retrieval/retrievers/hybrid.py +277 -0
- orchestrator/search/retrieval/retrievers/semantic.py +94 -0
- orchestrator/search/retrieval/retrievers/structured.py +39 -0
- orchestrator/search/retrieval/utils.py +120 -0
- orchestrator/search/retrieval/validation.py +152 -0
- orchestrator/search/schemas/__init__.py +12 -0
- orchestrator/search/schemas/parameters.py +129 -0
- orchestrator/search/schemas/results.py +77 -0
- orchestrator/services/processes.py +1 -1
- orchestrator/services/settings_env_variables.py +2 -2
- orchestrator/settings.py +8 -1
- orchestrator/utils/state.py +6 -1
- orchestrator/workflows/steps.py +15 -1
- orchestrator/workflows/tasks/validate_products.py +1 -1
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/METADATA +15 -8
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/RECORD +71 -21
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from typing import TypedDict
|
|
15
|
+
|
|
16
|
+
from sqlalchemy import BindParameter, Select, and_, bindparam, case, cast, func, literal, or_, select
|
|
17
|
+
from sqlalchemy.sql.expression import ColumnElement, Label
|
|
18
|
+
from sqlalchemy.types import TypeEngine
|
|
19
|
+
|
|
20
|
+
from orchestrator.db.models import AiSearchIndex
|
|
21
|
+
from orchestrator.search.core.types import SearchMetadata
|
|
22
|
+
|
|
23
|
+
from ..pagination import PaginationParams
|
|
24
|
+
from .base import Retriever
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RrfScoreSqlComponents(TypedDict):
    """SQL expression components of the RRF hybrid score calculation.

    This is the return shape of ``compute_rrf_hybrid_score_sql``. Every value is a
    SQLAlchemy expression to be evaluated by the database, not a Python number.
    """

    # Sum of the reciprocal-rank terms, cast to the score numeric type when one is supplied.
    rrf_num: ColumnElement
    # Labeled CASE expression: 1 when the average fuzzy score reaches the threshold, else 0.
    perfect: Label
    # Boost added to the score of a perfect match (rrf_max plus a margin).
    beta: ColumnElement
    # Upper bound used for the plain RRF score: n_sources / (k + 1).
    rrf_max: ColumnElement
    # rrf_num plus beta multiplied by the perfect flag.
    fused_num: ColumnElement
    # fused_num divided by (beta + rrf_max).
    normalized_score: ColumnElement
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def compute_rrf_hybrid_score_sql(
    sem_rank_col: ColumnElement,
    fuzzy_rank_col: ColumnElement,
    avg_fuzzy_score_col: ColumnElement,
    k: int,
    perfect_threshold: float,
    n_sources: int = 2,
    margin_factor: float = 0.05,
    score_numeric_type: TypeEngine | None = None,
) -> RrfScoreSqlComponents:
    """Build the SQL expressions of the RRF (Reciprocal Rank Fusion) hybrid score.

    Nothing is executed here; the function only composes SQLAlchemy expressions that
    the database evaluates later. The pieces are:

    1. The base RRF score: ``1 / (k + sem_rank) + 1 / (k + fuzzy_rank)``.
    2. A 0/1 "perfect match" flag derived from the average fuzzy score.
    3. A boost ``beta`` for perfect matches, derived from ``k`` and ``n_sources``.
    4. The fused score divided by its theoretical maximum ``beta + rrf_max``.

    Args:
        sem_rank_col: SQLAlchemy column expression for semantic rank
        fuzzy_rank_col: SQLAlchemy column expression for fuzzy rank
        avg_fuzzy_score_col: SQLAlchemy column expression for average fuzzy score
        k: RRF constant controlling rank influence (typically 60)
        perfect_threshold: Average fuzzy score at or above which an item is flagged as perfect (typically 0.9)
        n_sources: Number of ranking sources assumed when computing ``rrf_max`` (default: 2, semantic + fuzzy).
            Note that the base RRF sum itself always uses exactly the two rank columns passed in.
        margin_factor: Margin added on top of ``rrf_max`` as a fraction of it (default: 0.05 = 5%)
        score_numeric_type: SQLAlchemy numeric type used for literals and casts; when omitted,
            literals are untyped and no casts are emitted

    Returns:
        RrfScoreSqlComponents: Dictionary of SQL expressions for score components
            - rrf_num: Raw RRF score (cast to numeric type if provided)
            - perfect: Perfect match flag (1 if avg_fuzzy_score >= threshold, else 0), labeled "perfect_match"
            - beta: Boost amount for perfect matches (``rrf_max * (1 + margin_factor)``)
            - rrf_max: ``n_sources / (k + 1)``, the largest possible RRF score
            - fused_num: ``rrf_num + beta * perfect``
            - normalized_score: ``fused_num / (beta + rrf_max)``

    Note:
        - Keep margin_factor small to avoid compressing perfects near 1 after normalization.

        - ``beta`` is larger than ``rrf_max`` whenever ``margin_factor`` is positive, so an item
          flagged as perfect always scores above any item that is not.

        - Rank columns are assumed not to contain ``NULL``. A ``NULL`` in either rank column
          propagates to a ``NULL`` final score for that item.
    """

    def _score_literal(value: float) -> ColumnElement:
        """Bind ``value`` as a literal, typed with ``score_numeric_type`` when one is configured."""
        if score_numeric_type:
            return literal(value, type_=score_numeric_type)
        return literal(value)

    def _to_score_type(expression: ColumnElement) -> ColumnElement:
        """Cast ``expression`` to ``score_numeric_type`` when one is configured, else return it as is."""
        if score_numeric_type:
            return cast(expression, score_numeric_type)
        return expression

    # Rank-based part: one reciprocal term per ranking source.
    semantic_term = 1.0 / (k + sem_rank_col)
    fuzzy_term = 1.0 / (k + fuzzy_rank_col)
    rrf_num = _to_score_type(semantic_term + fuzzy_term)

    # Flag near-perfect fuzzy matches so they can be boosted.
    perfect = case((avg_fuzzy_score_col >= perfect_threshold, 1), else_=0).label("perfect_match")

    # Largest RRF score reachable when every source ranks the item first.
    rrf_max = _score_literal(float(n_sources)) / (_score_literal(float(k)) + _score_literal(1.0))

    # The boost sits a margin above rrf_max.
    beta = rrf_max + rrf_max * _score_literal(margin_factor)

    # Fused score: RRF plus the boost for flagged rows.
    fused_num = rrf_num + beta * _to_score_type(perfect)

    # Scale by the theoretical maximum of the fused score.
    normalized_score = fused_num / (beta + rrf_max)

    components: RrfScoreSqlComponents = {
        "rrf_num": rrf_num,
        "perfect": perfect,
        "beta": beta,
        "rrf_max": rrf_max,
        "fused_num": fused_num,
        "normalized_score": normalized_score,
    }
    return components
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class RrfHybridRetriever(Retriever):
    """Reciprocal Rank Fusion of semantic and fuzzy ranking with parent-child retrieval.

    Field-level index rows are scored first, then aggregated per entity, ranked twice
    (by semantic distance and by fuzzy score) and fused into one normalized score.
    Constants such as ``SEARCHABLE_FIELD_TYPES``, ``SCORE_NUMERIC_TYPE``, ``SCORE_PRECISION``
    and the label names, as well as ``_quantize_score_for_pagination``, are not defined
    in this class (presumably inherited from ``Retriever`` — confirm).
    """

    def __init__(
        self,
        q_vec: list[float],
        fuzzy_term: str,
        pagination_params: PaginationParams,
        k: int = 60,
        field_candidates_limit: int = 100,
    ) -> None:
        """Store the query vector, the fuzzy term, the keyset cursor and the tuning knobs.

        Args:
            q_vec: Embedding of the query, compared against ``AiSearchIndex.embedding``.
            fuzzy_term: Text compared against ``AiSearchIndex.value``.
            pagination_params: Source of the ``page_after_score`` / ``page_after_id`` cursor.
            k: RRF constant passed on to ``compute_rrf_hybrid_score_sql``.
            field_candidates_limit: Maximum number of field rows kept before aggregation.
        """
        self.q_vec = q_vec
        self.fuzzy_term = fuzzy_term
        self.k = k
        self.field_candidates_limit = field_candidates_limit
        # Keyset cursor: resume after this (score, id) pair.
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Wrap ``candidate_query`` in the hybrid ranking pipeline and return the ordered statement.

        The result selects entity id, rounded score, highlight text, highlight path and
        the perfect-match flag, ordered by score descending and entity id ascending.
        """
        candidates = candidate_query.subquery()
        query_vector: BindParameter[list[float]] = bindparam(
            "q_vec", self.q_vec, type_=AiSearchIndex.embedding.type
        )

        # Per-field signals. Rows without an embedding get a NULL distance for ordering,
        # and a fallback distance of 1.0 in the selected column.
        word_similarity = func.word_similarity(self.fuzzy_term, AiSearchIndex.value)
        distance_or_null = case(
            (AiSearchIndex.embedding.is_(None), None),
            else_=AiSearchIndex.embedding.op("<->")(query_vector),
        )
        semantic_distance = func.coalesce(distance_or_null, literal(1.0)).label("semantic_distance")

        # Only rows that pass the `<%` operator against the fuzzy term are considered.
        fuzzy_match = literal(self.fuzzy_term).op("<%")(AiSearchIndex.value)

        candidate_rows = (
            select(
                AiSearchIndex.entity_id,
                AiSearchIndex.path,
                AiSearchIndex.value,
                semantic_distance,
                word_similarity.label("fuzzy_score"),
            )
            .select_from(AiSearchIndex)
            .join(candidates, candidates.c.entity_id == AiSearchIndex.entity_id)
            .where(and_(AiSearchIndex.value_type.in_(self.SEARCHABLE_FIELD_TYPES), fuzzy_match))
            .order_by(
                word_similarity.desc().nulls_last(),
                distance_or_null.asc().nulls_last(),
                AiSearchIndex.entity_id.asc(),
            )
            .limit(self.field_candidates_limit)
        )
        field_candidates = candidate_rows.cte("field_candidates")
        fields = field_candidates.c

        # Aggregate the field rows into one row of averages per entity.
        entity_scores = (
            select(
                fields.entity_id,
                func.avg(fields.semantic_distance).label("avg_semantic_distance"),
                func.avg(fields.fuzzy_score).label("avg_fuzzy_score"),
            )
            .group_by(fields.entity_id)
            .cte("entity_scores")
        )

        def _of_best_fuzzy_field(column: ColumnElement, label: str) -> Label:
            """Value of ``column`` on the entity's field row with the highest fuzzy score (ties: lowest path)."""
            return (
                func.first_value(column)
                .over(
                    partition_by=fields.entity_id,
                    order_by=[fields.fuzzy_score.desc(), fields.path.asc()],
                )
                .label(label)
            )

        # One highlight (text + path) per entity.
        entity_highlights = (
            select(
                fields.entity_id,
                _of_best_fuzzy_field(fields.value, self.HIGHLIGHT_TEXT_LABEL),
                _of_best_fuzzy_field(fields.path, self.HIGHLIGHT_PATH_LABEL),
            )
            .distinct(fields.entity_id)
            .cte("entity_highlights")
        )

        scores = entity_scores.c
        highlights = entity_highlights.c

        # Entity id is the tie-breaker in both rankings.
        sem_rank = (
            func.dense_rank()
            .over(order_by=[scores.avg_semantic_distance.asc().nulls_last(), scores.entity_id.asc()])
            .label("sem_rank")
        )
        fuzzy_rank = (
            func.dense_rank()
            .over(order_by=[scores.avg_fuzzy_score.desc().nulls_last(), scores.entity_id.asc()])
            .label("fuzzy_rank")
        )

        ranked = (
            select(
                scores.entity_id,
                scores.avg_semantic_distance,
                scores.avg_fuzzy_score,
                highlights.highlight_text,
                highlights.highlight_path,
                sem_rank,
                fuzzy_rank,
            )
            .select_from(entity_scores.join(entity_highlights, scores.entity_id == highlights.entity_id))
            .cte("ranked_results")
        )

        # Compute RRF hybrid score
        components = compute_rrf_hybrid_score_sql(
            sem_rank_col=ranked.c.sem_rank,
            fuzzy_rank_col=ranked.c.fuzzy_rank,
            avg_fuzzy_score_col=ranked.c.avg_fuzzy_score,
            k=self.k,
            perfect_threshold=0.9,
            score_numeric_type=self.SCORE_NUMERIC_TYPE,
        )

        # Round to configured precision
        rounded = func.round(
            cast(components["normalized_score"], self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION
        )
        score = cast(rounded, self.SCORE_NUMERIC_TYPE).label(self.SCORE_LABEL)

        stmt = select(
            ranked.c.entity_id,
            score,
            ranked.c.highlight_text,
            ranked.c.highlight_path,
            components["perfect"].label("perfect_match"),
        ).select_from(ranked)

        stmt = self._apply_fused_pagination(stmt, score, ranked.c.entity_id)

        ordered = stmt.order_by(score.desc().nulls_last(), ranked.c.entity_id.asc())
        return ordered.params(q_vec=self.q_vec)

    def _apply_fused_pagination(
        self,
        stmt: Select,
        score_column: ColumnElement,
        entity_id_column: ColumnElement,
    ) -> Select:
        """Keyset paginate by fused score + id.

        When both cursor parts are set, keep rows with a lower score than the cursor, or
        the same score and a greater entity id. Otherwise return ``stmt`` unchanged.
        """
        if self.page_after_score is None or self.page_after_id is None:
            return stmt

        cursor_score = self._quantize_score_for_pagination(self.page_after_score)
        lower_score = score_column < cursor_score
        same_score_later_id = and_(score_column == cursor_score, entity_id_column > self.page_after_id)
        return stmt.where(or_(lower_score, same_score_later_id))

    @property
    def metadata(self) -> SearchMetadata:
        """Search metadata describing this retriever as the hybrid variant."""
        return SearchMetadata.hybrid()
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from sqlalchemy import Select, and_, cast, func, literal, or_, select
|
|
15
|
+
from sqlalchemy.sql.expression import ColumnElement
|
|
16
|
+
|
|
17
|
+
from orchestrator.db.models import AiSearchIndex
|
|
18
|
+
from orchestrator.search.core.types import SearchMetadata
|
|
19
|
+
|
|
20
|
+
from ..pagination import PaginationParams
|
|
21
|
+
from .base import Retriever
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SemanticRetriever(Retriever):
    """Ranks results based on the minimum semantic vector distance.

    For every candidate entity the smallest L2 distance between the query vector and
    any of its embedded index rows is turned into a score of ``1 / (1 + distance)``.
    Score type, precision and label names come from attributes not defined in this
    class (presumably inherited from ``Retriever`` — confirm).
    """

    def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None:
        """Store the query vector and the keyset pagination cursor."""
        self.vector_query = vector_query
        # Keyset cursor: resume after this (score, id) pair.
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Wrap ``candidate_query`` in the semantic ranking and return the ordered statement.

        The result selects entity id, score, highlight text and highlight path, ordered
        by score descending and entity id ascending.
        """
        candidates = candidate_query.subquery()
        numeric = self.SCORE_NUMERIC_TYPE

        distance = AiSearchIndex.embedding.l2_distance(self.vector_query)
        closest = func.min(distance).over(partition_by=AiSearchIndex.entity_id)

        # Normalize score to preserve ordering in accordance with other retrievers:
        # smaller distance = higher score
        similarity = literal(1.0, type_=numeric) / (literal(1.0, type_=numeric) + cast(closest, numeric))
        rounded = func.round(cast(similarity, numeric), self.SCORE_PRECISION)
        score = cast(rounded, numeric).label(self.SCORE_LABEL)

        def _of_closest_field(column: ColumnElement, label: str) -> ColumnElement:
            """Value of ``column`` on the entity's row with the smallest distance (ties: lowest path)."""
            return (
                func.first_value(column)
                .over(
                    partition_by=AiSearchIndex.entity_id,
                    order_by=[distance.asc(), AiSearchIndex.path.asc()],
                )
                .label(label)
            )

        # Rows without an embedding are excluded; one row per entity is kept.
        per_entity = (
            select(
                AiSearchIndex.entity_id,
                score,
                _of_closest_field(AiSearchIndex.value, self.HIGHLIGHT_TEXT_LABEL),
                _of_closest_field(AiSearchIndex.path, self.HIGHLIGHT_PATH_LABEL),
            )
            .select_from(AiSearchIndex)
            .join(candidates, candidates.c.entity_id == AiSearchIndex.entity_id)
            .where(AiSearchIndex.embedding.isnot(None))
            .distinct(AiSearchIndex.entity_id)
        )
        ranked = per_entity.subquery("ranked_semantic")
        cols = ranked.c

        stmt = select(
            cols.entity_id,
            cols.score,
            cols.highlight_text,
            cols.highlight_path,
        ).select_from(ranked)

        stmt = self._apply_semantic_pagination(stmt, cols.score, cols.entity_id)

        return stmt.order_by(cols.score.desc().nulls_last(), cols.entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        """Search metadata describing this retriever as the semantic variant."""
        return SearchMetadata.semantic()

    def _apply_semantic_pagination(
        self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement
    ) -> Select:
        """Keyset paginate by semantic score + id.

        When both cursor parts are set, keep rows with a lower score than the quantized
        cursor score, or the same score and a greater entity id. Otherwise return
        ``stmt`` unchanged.
        """
        if self.page_after_score is None or self.page_after_id is None:
            return stmt

        cursor_score = self._quantize_score_for_pagination(self.page_after_score)
        lower_score = score_column < cursor_score
        same_score_later_id = and_(score_column == cursor_score, entity_id_column > self.page_after_id)
        return stmt.where(or_(lower_score, same_score_later_id))
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from sqlalchemy import Select, literal, select
|
|
15
|
+
|
|
16
|
+
from orchestrator.search.core.types import SearchMetadata
|
|
17
|
+
|
|
18
|
+
from ..pagination import PaginationParams
|
|
19
|
+
from .base import Retriever
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class StructuredRetriever(Retriever):
    """Applies a dummy score for purely structured searches with no text query.

    Every candidate gets the constant score 1.0; results are ordered, and paginated,
    by entity id alone.
    """

    def __init__(self, pagination_params: PaginationParams) -> None:
        """Store the entity id after which the next page starts."""
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Select entity id plus a constant score from ``candidate_query``, ordered by entity id."""
        candidates = candidate_query.subquery()
        entity_id = candidates.c.entity_id
        constant_score = literal(1.0).label("score")

        stmt = select(entity_id, constant_score).select_from(candidates)

        # With a cursor present, only entities after it are returned.
        if self.page_after_id:
            stmt = stmt.where(entity_id > self.page_after_id)

        return stmt.order_by(entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        """Search metadata describing this retriever as the structured variant."""
        return SearchMetadata.structured()
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
import structlog
|
|
18
|
+
from sqlalchemy import and_
|
|
19
|
+
from sqlalchemy_utils.types.ltree import Ltree
|
|
20
|
+
|
|
21
|
+
from orchestrator.db.database import WrappedSession
|
|
22
|
+
from orchestrator.db.models import AiSearchIndex
|
|
23
|
+
from orchestrator.search.core.types import EntityType
|
|
24
|
+
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
|
|
25
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
26
|
+
from orchestrator.search.schemas.results import SearchResult
|
|
27
|
+
|
|
28
|
+
# Module-level structlog logger named after this module; the display helpers below emit their output through it.
logger = structlog.get_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def generate_highlight_indices(text: str, term: str) -> list[tuple[int, int]]:
    """Return the sorted, de-duplicated ``(start, end)`` spans where words of ``term`` occur in ``text``.

    ``term`` is split on whitespace. Each word is searched case-insensitively twice:
    once as a whole word (``\\b`` on both sides) and once as a plain substring. The
    spans of both passes are merged into one set, so overlapping spans can occur.

    Args:
        text: Text to search in.
        term: Whitespace-separated search words.

    Returns:
        Sorted list of unique ``(start, end)`` index pairs; empty when ``text`` or
        ``term`` is empty or ``term`` contains only whitespace.
    """
    if not (text and term):
        return []

    spans: set[tuple[int, int]] = set()
    # str.split() without arguments already drops empty and whitespace-only pieces.
    for token in term.split():
        escaped = re.escape(token)
        # Whole-word pass first, then the plain substring pass.
        for pattern in (rf"\b{escaped}\b", escaped):
            spans.update(match.span() for match in re.finditer(pattern, text, re.IGNORECASE))

    return sorted(spans)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def display_filtered_paths_only(
    results: list[SearchResult], search_params: BaseSearchParameters, db_session: WrappedSession
) -> None:
    """Log, for every result, the indexed value of each path referenced by the search filters.

    Args:
        results: Search results whose entities should be shown.
        search_params: Search parameters; the paths come from ``filters.get_all_paths()``.
        db_session: Session used to look up the index rows (one query per result and path).
    """
    if not results:
        logger.info("No results found.")
        return

    logger.info("--- Search Results ---")

    filters = search_params.filters
    searched_paths = filters.get_all_paths() if filters else []
    if not searched_paths:
        # No filter paths means there is nothing to show beyond the header.
        return

    separator = "-" * 40
    for result in results:
        for path in searched_paths:
            same_entity_and_path = and_(
                AiSearchIndex.entity_id == result.entity_id, AiSearchIndex.path == Ltree(path)
            )
            record: AiSearchIndex | None = db_session.query(AiSearchIndex).filter(same_entity_and_path).first()

            if record:
                logger.info(f" {record.path}: {record.value}")

        logger.info(separator)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def display_results(
    results: list[SearchResult],
    db_session: WrappedSession,
    score_label: str = "Score",
) -> None:
    """Log search results: the matched field when one is present, otherwise the full entity.

    For a result without a matching field, the entity's index rows are loaded to find
    its entity type, the entity is fetched through the registry configuration for that
    type, and all of its traversed fields are logged as JSON.

    Args:
        results: Search results to display.
        db_session: Session used for the index and entity lookups.
        score_label: Label printed in front of the score.
    """
    if not results:
        logger.info("No results found.")
        return

    logger.info("--- Search Results ---")
    for result in results:
        entity_id, score = result.entity_id, result.score

        # If we have a matching field from fuzzy search, display only that
        matched = result.matching_field
        if matched:
            logger.info(f"Entity ID: {entity_id}")
            logger.info(f"Matched field ({matched.path}): {matched.text}")
            logger.info(f"{score_label}: {score:.4f}\n" + "-" * 20)
            continue

        index_records = db_session.query(AiSearchIndex).filter(AiSearchIndex.entity_id == entity_id).all()
        if not index_records:
            logger.warning(f"Could not find indexed records for entity_id={entity_id}")
            continue

        # The entity type is read from the first index row.
        kind = EntityType(index_records[0].entity_type)
        config = ENTITY_CONFIG_REGISTRY[kind]

        db_entity = db_session.get(config.table, entity_id) if config.table else None

        if not (db_entity and config.traverser):
            logger.warning(f"Could not display entity {kind.value} with id={entity_id}")
            continue

        fields = config.traverser.get_fields(db_entity, config.pk_name, config.root_name)
        result_obj = {}
        for field_path, field_value, _ in fields:
            result_obj[field_path] = field_value
        logger.info(json.dumps(result_obj, indent=2, default=str))
        logger.info(f"{score_label}: {score:.4f}\n" + "-" * 20)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from sqlalchemy import select, text
|
|
15
|
+
from sqlalchemy.exc import ProgrammingError
|
|
16
|
+
from sqlalchemy_utils import Ltree
|
|
17
|
+
|
|
18
|
+
from orchestrator.db import db
|
|
19
|
+
from orchestrator.db.database import WrappedSession
|
|
20
|
+
from orchestrator.db.models import AiSearchIndex
|
|
21
|
+
from orchestrator.search.core.types import EntityType, FieldType
|
|
22
|
+
from orchestrator.search.filters import FilterCondition, FilterTree, LtreeFilter, PathFilter
|
|
23
|
+
from orchestrator.search.filters.definitions import operators_for
|
|
24
|
+
from orchestrator.search.retrieval.exceptions import (
|
|
25
|
+
EmptyFilterPathError,
|
|
26
|
+
IncompatibleFilterTypeError,
|
|
27
|
+
InvalidEntityPrefixError,
|
|
28
|
+
InvalidLtreePatternError,
|
|
29
|
+
PathNotFoundError,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def is_filter_compatible_with_field_type(filter_condition: FilterCondition, field_type: FieldType) -> bool:
    """Tell whether ``filter_condition`` may be applied to a field of type ``field_type``.

    Args:
        filter_condition (FilterCondition): The filter condition instance to check.
        field_type (FieldType): The type of field from the index schema.

    Returns:
        bool: True for any ``LtreeFilter``; otherwise True only when the condition's
        operator is among ``operators_for(field_type)``.
    """
    if not isinstance(filter_condition, LtreeFilter):
        # The operator must be one of those allowed for this field type.
        allowed_operators = operators_for(field_type)
        return filter_condition.op in allowed_operators

    # LtreeFilter is for path filtering only and is thus compatible with all field types.
    return True
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def is_lquery_syntactically_valid(pattern: str, db_session: WrappedSession) -> bool:
    """Check an `lquery` pattern by asking the database to cast it.

    The cast runs inside a nested transaction (``begin_nested``).

    Args:
        pattern (str): The LTree lquery pattern string to validate.
        db_session (WrappedSession): The database session used to test casting.

    Returns:
        bool: True if the cast succeeds, False if it raises ``ProgrammingError``.
        Any other exception propagates to the caller.
    """
    probe = text("SELECT CAST(:pattern AS lquery)")
    try:
        with db_session.begin_nested():
            db_session.execute(probe, {"pattern": pattern})
    except ProgrammingError:
        return False
    return True
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_structured_filter_schema() -> dict[str, str]:
    """Collect the distinct (path, value type) pairs found in the search index.

    Rows are read ordered by path. If one path occurs with several value types, the
    last one read wins.

    Returns:
        dict[str, str]: Mapping of path strings to their corresponding field type values.
    """
    distinct_paths = select(AiSearchIndex.path, AiSearchIndex.value_type).distinct().order_by(AiSearchIndex.path)

    schema: dict[str, str] = {}
    for path, value_type in db.session.execute(distinct_paths):
        schema[str(path)] = value_type.value
    return schema
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def validate_filter_path(path: str) -> str | None:
    """Look up ``path`` in the search index and return the value type stored for it.

    Args:
        path (str): The fully qualified LTree path.

    Returns:
        str | None: The ``.value`` of the first matching row's value type, or None when
        no row has this path.
    """
    lookup = select(AiSearchIndex.value_type).where(AiSearchIndex.path == Ltree(path)).limit(1)
    value_type = db.session.execute(lookup).scalar_one_or_none()

    if not value_type:
        return None
    return value_type.value
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def complete_filter_validation(filter: PathFilter, entity_type: EntityType) -> None:
    """Validate a PathFilter against the indexed schema and the searched entity type.

    Checks, in the order they run:
        1. For an ``LtreeFilter`` only: the lquery pattern must be syntactically valid.
           No further checks are applied to such filters.
        2. The path must be non-empty.
        3. The path must exist in the search index.
        4. The filter operator must be allowed for the field's value type.
        5. The path must start with ``"<entity type>."`` (lower-cased) or with ``"*"``.

    Args:
        filter (PathFilter): The filter to validate.
        entity_type (EntityType): The entity type being searched.

    Raises:
        InvalidLtreePatternError: The lquery pattern of an ``LtreeFilter`` is invalid.
        EmptyFilterPathError: The path is empty or whitespace only.
        PathNotFoundError: The path does not occur in the index.
        IncompatibleFilterTypeError: The operator is not valid for the field type.
        InvalidEntityPrefixError: The path lacks the expected entity prefix.
    """
    condition = filter.condition

    # Ltree is a special case
    if isinstance(condition, LtreeFilter):
        lquery_pattern = condition.value
        if is_lquery_syntactically_valid(lquery_pattern, db.session):
            return
        raise InvalidLtreePatternError(lquery_pattern)

    path = filter.path
    if not (path and path.strip()):
        raise EmptyFilterPathError()

    # The path has to be present in the index before its type can be checked.
    db_field_type_str = validate_filter_path(path)
    if db_field_type_str is None:
        raise PathNotFoundError(path)

    db_field_type = FieldType(db_field_type_str)

    # The operator has to be valid for the indexed field type.
    if not is_filter_compatible_with_field_type(condition, db_field_type):
        expected_operators = operators_for(db_field_type)
        raise IncompatibleFilterTypeError(condition.op.value, db_field_type.value, path, expected_operators)

    # The path has to be scoped to the searched entity type, unless it is a wildcard path.
    expected_prefix = f"{entity_type.value.lower()}."
    if not path.startswith((expected_prefix, "*")):
        raise InvalidEntityPrefixError(path, expected_prefix, entity_type.value)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
async def validate_filter_tree(filters: FilterTree | None, entity_type: EntityType) -> None:
    """Run ``complete_filter_validation`` on every leaf of ``filters``, in order.

    Does nothing when ``filters`` is None. The first failing leaf raises, and later
    leaves are not checked.
    """
    if filters is not None:
        for leaf in filters.get_all_leaves():
            await complete_filter_validation(leaf, entity_type)
|