orchestrator-core 4.4.0rc3__py3-none-any.whl → 4.5.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +26 -2
- orchestrator/agentic_app.py +84 -0
- orchestrator/api/api_v1/api.py +10 -0
- orchestrator/api/api_v1/endpoints/search.py +277 -0
- orchestrator/app.py +32 -0
- orchestrator/cli/index_llm.py +73 -0
- orchestrator/cli/main.py +22 -1
- orchestrator/cli/resize_embedding.py +135 -0
- orchestrator/cli/search_explore.py +208 -0
- orchestrator/cli/speedtest.py +151 -0
- orchestrator/db/models.py +37 -1
- orchestrator/graphql/schemas/process.py +2 -2
- orchestrator/graphql/schemas/workflow.py +2 -2
- orchestrator/llm_settings.py +51 -0
- orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
- orchestrator/schedules/scheduler.py +6 -7
- orchestrator/schemas/search.py +117 -0
- orchestrator/search/__init__.py +12 -0
- orchestrator/search/agent/__init__.py +8 -0
- orchestrator/search/agent/agent.py +47 -0
- orchestrator/search/agent/prompts.py +62 -0
- orchestrator/search/agent/state.py +8 -0
- orchestrator/search/agent/tools.py +121 -0
- orchestrator/search/core/__init__.py +0 -0
- orchestrator/search/core/embedding.py +64 -0
- orchestrator/search/core/exceptions.py +22 -0
- orchestrator/search/core/types.py +281 -0
- orchestrator/search/core/validators.py +27 -0
- orchestrator/search/docs/index.md +37 -0
- orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
- orchestrator/search/filters/__init__.py +27 -0
- orchestrator/search/filters/base.py +272 -0
- orchestrator/search/filters/date_filters.py +75 -0
- orchestrator/search/filters/definitions.py +93 -0
- orchestrator/search/filters/ltree_filters.py +43 -0
- orchestrator/search/filters/numeric_filter.py +60 -0
- orchestrator/search/indexing/__init__.py +3 -0
- orchestrator/search/indexing/indexer.py +323 -0
- orchestrator/search/indexing/registry.py +88 -0
- orchestrator/search/indexing/tasks.py +53 -0
- orchestrator/search/indexing/traverse.py +322 -0
- orchestrator/search/retrieval/__init__.py +3 -0
- orchestrator/search/retrieval/builder.py +108 -0
- orchestrator/search/retrieval/engine.py +152 -0
- orchestrator/search/retrieval/pagination.py +83 -0
- orchestrator/search/retrieval/retriever.py +447 -0
- orchestrator/search/retrieval/utils.py +106 -0
- orchestrator/search/retrieval/validation.py +174 -0
- orchestrator/search/schemas/__init__.py +0 -0
- orchestrator/search/schemas/parameters.py +116 -0
- orchestrator/search/schemas/results.py +63 -0
- orchestrator/services/settings_env_variables.py +2 -2
- orchestrator/settings.py +1 -1
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/METADATA +8 -3
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/RECORD +57 -14
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import array
|
|
2
|
+
import base64
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from orchestrator.search.core.exceptions import InvalidCursorError
|
|
8
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
9
|
+
from orchestrator.search.schemas.results import SearchResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class PaginationParams:
    """Parameters for pagination in search queries."""

    # Score of the last item on the previous page (keyset pagination anchor).
    page_after_score: float | None = None
    # Entity id of the last item on the previous page; tie-breaker for equal scores.
    page_after_id: str | None = None
    # Query embedding carried through the cursor so subsequent pages reuse the
    # exact same vector instead of re-embedding the query text.
    q_vec_override: list[float] | None = None
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def floats_to_b64(v: list[float]) -> str:
    """Pack *v* as 32-bit floats and return the bytes as a URL-safe Base64 string."""
    packed = array.array("f", v).tobytes()
    return base64.urlsafe_b64encode(packed).decode("ascii")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def b64_to_floats(s: str) -> list[float]:
    """Inverse of ``floats_to_b64``: decode a URL-safe Base64 string into 32-bit floats."""
    buf = array.array("f")
    buf.frombytes(base64.urlsafe_b64decode(s.encode("ascii")))
    return buf.tolist()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PageCursor(BaseModel):
    """Keyset-pagination cursor, serialized to/from an opaque Base64 token."""

    # Score and id of the last row on the previous page, plus the Base64-packed
    # query vector so follow-up pages reuse the same embedding.
    score: float
    id: str
    q_vec_b64: str

    def encode(self) -> str:
        """Encode the cursor data into a URL-safe Base64 string."""
        payload = self.model_dump_json().encode("utf-8")
        return base64.urlsafe_b64encode(payload).decode("utf-8")

    @classmethod
    def decode(cls, cursor: str) -> "PageCursor":
        """Decode a Base64 string back into a PageCursor instance."""
        try:
            payload = base64.urlsafe_b64decode(cursor).decode("utf-8")
            return cls.model_validate_json(payload)
        except Exception as e:
            # Any malformed token (bad Base64, bad UTF-8, bad JSON, bad schema)
            # is surfaced uniformly as an invalid-cursor error.
            raise InvalidCursorError("Invalid pagination cursor") from e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def process_pagination_cursor(cursor: str | None, search_params: BaseSearchParameters) -> PaginationParams:
    """Process pagination cursor and return pagination parameters."""
    if not cursor:
        # First page: if there is query text to embed, generate the vector now so
        # it can be carried forward in the next-page cursor.
        if search_params.vector_query:
            from orchestrator.search.core.embedding import QueryEmbedder

            q_vec = await QueryEmbedder.generate_for_text_async(search_params.vector_query)
            return PaginationParams(q_vec_override=q_vec)
        return PaginationParams()

    # Subsequent pages: restore score/id anchor and the original query vector.
    decoded = PageCursor.decode(cursor)
    return PaginationParams(
        page_after_score=decoded.score,
        page_after_id=decoded.id,
        q_vec_override=b64_to_floats(decoded.q_vec_b64),
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def create_next_page_cursor(
    search_results: list[SearchResult], pagination_params: PaginationParams, limit: int
) -> str | None:
    """Create next page cursor if there are more results."""
    # A completely full page suggests more results may follow; a short page
    # (or a non-positive limit) means we are done.
    if limit <= 0 or len(search_results) != limit:
        return None

    last = search_results[-1]
    return PageCursor(
        score=float(last.score),
        id=last.entity_id,
        q_vec_b64=floats_to_b64(pagination_params.q_vec_override or []),
    ).encode()
|
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from decimal import Decimal
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from sqlalchemy import BindParameter, Numeric, Select, and_, bindparam, case, cast, func, literal, or_, select
|
|
6
|
+
from sqlalchemy.sql.expression import ColumnElement
|
|
7
|
+
|
|
8
|
+
from orchestrator.db.models import AiSearchIndex
|
|
9
|
+
from orchestrator.search.core.types import FieldType, SearchMetadata
|
|
10
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
11
|
+
|
|
12
|
+
from .pagination import PaginationParams
|
|
13
|
+
|
|
14
|
+
logger = structlog.get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Retriever(ABC):
    """Abstract base class for applying a ranking strategy to a search query."""

    # Scores are rounded to this many decimals in SQL; cursor values must be
    # quantized to the same precision for keyset pagination to compare cleanly.
    SCORE_PRECISION = 12
    SCORE_NUMERIC_TYPE = Numeric(38, 12)
    HIGHLIGHT_TEXT_LABEL = "highlight_text"
    HIGHLIGHT_PATH_LABEL = "highlight_path"
    SCORE_LABEL = "score"
    # Only these value types participate in fuzzy text matching.
    SEARCHABLE_FIELD_TYPES = [
        FieldType.STRING.value,
        FieldType.UUID.value,
        FieldType.BLOCK.value,
        FieldType.RESOURCE_TYPE.value,
    ]

    @classmethod
    async def from_params(
        cls,
        params: BaseSearchParameters,
        pagination_params: PaginationParams,
    ) -> "Retriever":
        """Create the appropriate retriever instance from search parameters.

        Parameters
        ----------
        params : BaseSearchParameters
            Search parameters including vector queries, fuzzy terms, and filters.
        pagination_params : PaginationParams
            Pagination parameters for cursor-based paging.

        Returns:
        -------
        Retriever
            A concrete retriever instance (semantic, fuzzy, hybrid, or structured).
        """
        fuzzy_term = params.fuzzy_term
        q_vec = await cls._get_query_vector(params.vector_query, pagination_params.q_vec_override)

        # If semantic search was attempted but failed, fall back to fuzzy with the full query
        fallback_fuzzy_term = fuzzy_term
        if q_vec is None and params.vector_query is not None and params.query is not None:
            fallback_fuzzy_term = params.query

        # Selection order: hybrid (both signals) > semantic > fuzzy > structured.
        if q_vec is not None and fallback_fuzzy_term is not None:
            return RrfHybridRetriever(q_vec, fallback_fuzzy_term, pagination_params)
        if q_vec is not None:
            return SemanticRetriever(q_vec, pagination_params)
        if fallback_fuzzy_term is not None:
            return FuzzyRetriever(fallback_fuzzy_term, pagination_params)

        return StructuredRetriever(pagination_params)

    @classmethod
    async def _get_query_vector(
        cls, vector_query: str | None, q_vec_override: list[float] | None
    ) -> list[float] | None:
        """Get query vector either from override (pagination cursor) or by generating from text."""
        if q_vec_override:
            return q_vec_override

        if not vector_query:
            return None

        # Local import — presumably to defer embedding dependencies / avoid an
        # import cycle; confirm before moving to module level.
        from orchestrator.search.core.embedding import QueryEmbedder

        q_vec = await QueryEmbedder.generate_for_text_async(vector_query)
        if not q_vec:
            # Graceful degradation: the caller (from_params) falls back to a
            # non-semantic retriever when no vector is returned.
            logger.warning("Embedding generation failed; using non-semantic retriever")
            return None

        return q_vec

    @abstractmethod
    def apply(self, candidate_query: Select) -> Select:
        """Apply the ranking logic to the given candidate query.

        Parameters
        ----------
        candidate_query : Select
            A SQLAlchemy `Select` statement returning candidate entity IDs.

        Returns:
        -------
        Select
            A new `Select` statement with ranking expressions applied.
        """
        ...

    def _quantize_score_for_pagination(self, score_value: float) -> BindParameter[Decimal]:
        """Convert score value to properly quantized Decimal parameter for pagination."""
        # Quantize to 12 decimal places (SCORE_PRECISION) so that equality
        # comparisons against the rounded score column are exact.
        pas_dec = Decimal(str(score_value)).quantize(Decimal("0.000000000001"))
        return literal(pas_dec, type_=self.SCORE_NUMERIC_TYPE)

    @property
    @abstractmethod
    def metadata(self) -> SearchMetadata:
        """Return metadata describing this search strategy."""
        ...
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class StructuredRetriever(Retriever):
    """Applies a dummy score for purely structured searches with no text query."""

    def __init__(self, pagination_params: PaginationParams) -> None:
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Assign every candidate a constant score of 1.0 and page by entity_id."""
        candidates = candidate_query.subquery()

        # There is no text to rank against, so every row gets the same score;
        # entity_id alone provides a stable, deterministic order.
        query = select(candidates.c.entity_id, literal(1.0).label("score")).select_from(candidates)

        if self.page_after_id:
            query = query.where(candidates.c.entity_id > self.page_after_id)

        return query.order_by(candidates.c.entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        return SearchMetadata.structured()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class FuzzyRetriever(Retriever):
    """Ranks results based on the max of fuzzy text similarity scores."""

    def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None:
        self.fuzzy_term = fuzzy_term
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Rank candidates by their best per-field trigram similarity to the fuzzy term.

        Parameters
        ----------
        candidate_query : Select
            A SQLAlchemy `Select` statement returning candidate entity IDs.

        Returns:
        -------
        Select
            Rows of (entity_id, score, highlight_text, highlight_path), best score first.
        """
        cand = candidate_query.subquery()

        # pg_trgm word similarity between the search term and each indexed value.
        similarity_expr = func.word_similarity(self.fuzzy_term, AiSearchIndex.value)

        # Best similarity per entity across all of its indexed fields.
        raw_max = func.max(similarity_expr).over(partition_by=AiSearchIndex.entity_id)
        # Round to SCORE_PRECISION so cursor comparisons are stable across pages.
        score = cast(
            func.round(cast(raw_max, self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION), self.SCORE_NUMERIC_TYPE
        ).label(self.SCORE_LABEL)

        combined_query = (
            select(
                AiSearchIndex.entity_id,
                score,
                # Highlight = the best-matching field's value/path (ties broken by path).
                func.first_value(AiSearchIndex.value)
                .over(partition_by=AiSearchIndex.entity_id, order_by=[similarity_expr.desc(), AiSearchIndex.path.asc()])
                .label(self.HIGHLIGHT_TEXT_LABEL),
                func.first_value(AiSearchIndex.path)
                .over(partition_by=AiSearchIndex.entity_id, order_by=[similarity_expr.desc(), AiSearchIndex.path.asc()])
                .label(self.HIGHLIGHT_PATH_LABEL),
            )
            .select_from(AiSearchIndex)
            .join(cand, cand.c.entity_id == AiSearchIndex.entity_id)
            .where(
                and_(
                    AiSearchIndex.value_type.in_(self.SEARCHABLE_FIELD_TYPES),
                    # pg_trgm "<%" word-similarity operator; prefilters rows so the
                    # trigram index can be used before computing exact similarity.
                    literal(self.fuzzy_term).op("<%")(AiSearchIndex.value),
                )
            )
            .distinct(AiSearchIndex.entity_id)
        )
        final_query = combined_query.subquery("ranked_fuzzy")

        stmt = select(
            final_query.c.entity_id,
            final_query.c.score,
            final_query.c.highlight_text,
            final_query.c.highlight_path,
        ).select_from(final_query)

        stmt = self._apply_score_pagination(stmt, final_query.c.score, final_query.c.entity_id)

        return stmt.order_by(final_query.c.score.desc().nulls_last(), final_query.c.entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        return SearchMetadata.fuzzy()

    def _apply_score_pagination(
        self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement
    ) -> Select:
        """Apply standard score + entity_id keyset pagination."""
        if self.page_after_score is not None and self.page_after_id is not None:
            # Quantize the cursor score to SCORE_PRECISION before comparing, exactly
            # as SemanticRetriever and RrfHybridRetriever do. Comparing the rounded
            # Numeric(38, 12) score column against a raw float bind can miss the
            # equality branch at page boundaries and skip or repeat rows.
            score_param = self._quantize_score_for_pagination(self.page_after_score)
            stmt = stmt.where(
                or_(
                    score_column < score_param,
                    and_(
                        score_column == score_param,
                        entity_id_column > self.page_after_id,
                    ),
                )
            )
        return stmt
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class SemanticRetriever(Retriever):
    """Ranks results based on the minimum semantic vector distance."""

    def __init__(self, vector_query: list[float], pagination_params: PaginationParams) -> None:
        self.vector_query = vector_query
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Rank candidates by the smallest embedding L2 distance to the query vector."""
        cand = candidate_query.subquery()

        # L2 distance between each indexed field's embedding and the query vector.
        dist = AiSearchIndex.embedding.l2_distance(self.vector_query)

        # Best (smallest) distance per entity across all of its indexed fields.
        raw_min = func.min(dist).over(partition_by=AiSearchIndex.entity_id)

        # Normalize score to preserve ordering in accordance with other retrievers:
        # smaller distance = higher score
        similarity = literal(1.0, type_=self.SCORE_NUMERIC_TYPE) / (
            literal(1.0, type_=self.SCORE_NUMERIC_TYPE) + cast(raw_min, self.SCORE_NUMERIC_TYPE)
        )

        # Round to SCORE_PRECISION so keyset-pagination comparisons are stable.
        score = cast(
            func.round(cast(similarity, self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION), self.SCORE_NUMERIC_TYPE
        ).label(self.SCORE_LABEL)

        combined_query = (
            select(
                AiSearchIndex.entity_id,
                score,
                # Highlight = the field closest to the query (ties broken by path).
                func.first_value(AiSearchIndex.value)
                .over(partition_by=AiSearchIndex.entity_id, order_by=[dist.asc(), AiSearchIndex.path.asc()])
                .label(self.HIGHLIGHT_TEXT_LABEL),
                func.first_value(AiSearchIndex.path)
                .over(partition_by=AiSearchIndex.entity_id, order_by=[dist.asc(), AiSearchIndex.path.asc()])
                .label(self.HIGHLIGHT_PATH_LABEL),
            )
            .select_from(AiSearchIndex)
            .join(cand, cand.c.entity_id == AiSearchIndex.entity_id)
            # Rows without embeddings cannot be ranked semantically.
            .where(AiSearchIndex.embedding.isnot(None))
            .distinct(AiSearchIndex.entity_id)
        )
        final_query = combined_query.subquery("ranked_semantic")

        stmt = select(
            final_query.c.entity_id,
            final_query.c.score,
            final_query.c.highlight_text,
            final_query.c.highlight_path,
        ).select_from(final_query)

        stmt = self._apply_semantic_pagination(stmt, final_query.c.score, final_query.c.entity_id)

        # Deterministic order: best score first, entity_id as tie-breaker.
        return stmt.order_by(final_query.c.score.desc().nulls_last(), final_query.c.entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        return SearchMetadata.semantic()

    def _apply_semantic_pagination(
        self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement
    ) -> Select:
        """Apply semantic score pagination with precise Decimal handling."""
        if self.page_after_score is not None and self.page_after_id is not None:
            # Quantize the cursor score to match the rounding applied in apply(),
            # so the equality comparison below is exact.
            score_param = self._quantize_score_for_pagination(self.page_after_score)
            stmt = stmt.where(
                or_(
                    score_column < score_param,
                    and_(score_column == score_param, entity_id_column > self.page_after_id),
                )
            )
        return stmt
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class RrfHybridRetriever(Retriever):
    """Reciprocal Rank Fusion of semantic and fuzzy ranking with parent-child retrieval."""

    def __init__(
        self,
        q_vec: list[float],
        fuzzy_term: str,
        pagination_params: PaginationParams,
        k: int = 60,
        field_candidates_limit: int = 100,
    ) -> None:
        # k: RRF smoothing constant (larger k flattens rank contributions).
        # field_candidates_limit: cap on fuzzy-prefiltered rows considered per query.
        self.q_vec = q_vec
        self.fuzzy_term = fuzzy_term
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id
        self.k = k
        self.field_candidates_limit = field_candidates_limit

    def apply(self, candidate_query: Select) -> Select:
        """Fuse semantic distance and fuzzy similarity ranks via RRF into one score."""
        cand = candidate_query.subquery()
        q_param: BindParameter[list[float]] = bindparam("q_vec", self.q_vec, type_=AiSearchIndex.embedding.type)

        # Per-row fuzzy similarity (pg_trgm) against the indexed value.
        best_similarity = func.word_similarity(self.fuzzy_term, AiSearchIndex.value)
        # Semantic distance ("<->"); NULL when the row has no embedding.
        sem_expr = case(
            (AiSearchIndex.embedding.is_(None), None),
            else_=AiSearchIndex.embedding.op("<->")(q_param),
        )
        # Missing embeddings fall back to a worst-case distance of 1.0.
        sem_val = func.coalesce(sem_expr, literal(1.0)).label("semantic_distance")

        # pg_trgm "<%" word-similarity prefilter (index-assisted).
        filter_condition = literal(self.fuzzy_term).op("<%")(AiSearchIndex.value)

        # Stage 1: candidate fields matching the fuzzy prefilter, best first,
        # capped at field_candidates_limit rows.
        field_candidates = (
            select(
                AiSearchIndex.entity_id,
                AiSearchIndex.path,
                AiSearchIndex.value,
                sem_val,
                best_similarity.label("fuzzy_score"),
            )
            .select_from(AiSearchIndex)
            .join(cand, cand.c.entity_id == AiSearchIndex.entity_id)
            .where(
                and_(
                    AiSearchIndex.value_type.in_(self.SEARCHABLE_FIELD_TYPES),
                    filter_condition,
                )
            )
            .order_by(
                best_similarity.desc().nulls_last(),
                sem_expr.asc().nulls_last(),
                AiSearchIndex.entity_id.asc(),
            )
            .limit(self.field_candidates_limit)
        ).cte("field_candidates")

        # Stage 2: aggregate field-level signals up to the entity (parent) level.
        entity_scores = (
            select(
                field_candidates.c.entity_id,
                func.avg(field_candidates.c.semantic_distance).label("avg_semantic_distance"),
                func.avg(field_candidates.c.fuzzy_score).label("avg_fuzzy_score"),
            ).group_by(field_candidates.c.entity_id)
        ).cte("entity_scores")

        # Stage 3: pick one highlight per entity — its best fuzzy-scoring field.
        entity_highlights = (
            select(
                field_candidates.c.entity_id,
                func.first_value(field_candidates.c.value)
                .over(
                    partition_by=field_candidates.c.entity_id,
                    order_by=[field_candidates.c.fuzzy_score.desc(), field_candidates.c.path.asc()],
                )
                .label(self.HIGHLIGHT_TEXT_LABEL),
                func.first_value(field_candidates.c.path)
                .over(
                    partition_by=field_candidates.c.entity_id,
                    order_by=[field_candidates.c.fuzzy_score.desc(), field_candidates.c.path.asc()],
                )
                .label(self.HIGHLIGHT_PATH_LABEL),
            ).distinct(field_candidates.c.entity_id)
        ).cte("entity_highlights")

        # Stage 4: dense-rank entities separately by each signal; entity_id breaks ties.
        ranked = (
            select(
                entity_scores.c.entity_id,
                entity_scores.c.avg_semantic_distance,
                entity_scores.c.avg_fuzzy_score,
                entity_highlights.c.highlight_text,
                entity_highlights.c.highlight_path,
                func.dense_rank()
                .over(
                    order_by=[entity_scores.c.avg_semantic_distance.asc().nulls_last(), entity_scores.c.entity_id.asc()]
                )
                .label("sem_rank"),
                func.dense_rank()
                .over(order_by=[entity_scores.c.avg_fuzzy_score.desc().nulls_last(), entity_scores.c.entity_id.asc()])
                .label("fuzzy_rank"),
            ).select_from(
                entity_scores.join(entity_highlights, entity_scores.c.entity_id == entity_highlights.c.entity_id)
            )
        ).cte("ranked_results")

        # RRF (rank-based)
        rrf_raw = (1.0 / (self.k + ranked.c.sem_rank)) + (1.0 / (self.k + ranked.c.fuzzy_rank))
        rrf_num = cast(rrf_raw, self.SCORE_NUMERIC_TYPE)

        # Perfect flag to boost near perfect fuzzy matches as this most likely indicates the desired record.
        perfect = case((ranked.c.avg_fuzzy_score >= 0.9, 1), else_=0).label("perfect_match")

        # Dynamic beta based on k (and number of sources)
        # rrf_max = n_sources / (k + 1)
        k_num = literal(float(self.k), type_=self.SCORE_NUMERIC_TYPE)
        n_sources = literal(2.0, type_=self.SCORE_NUMERIC_TYPE)  # semantic + fuzzy
        rrf_max = n_sources / (k_num + literal(1.0, type_=self.SCORE_NUMERIC_TYPE))

        # Choose a small positive margin above rrf_max to ensure strict separation
        # Keep it small to avoid compressing perfects near 1 after normalization
        margin = rrf_max * literal(0.05, type_=self.SCORE_NUMERIC_TYPE)  # 5% above bound
        beta = rrf_max + margin

        # "Perfect" matches get a bonus of beta, lifting them above any non-perfect score.
        fused_num = rrf_num + beta * cast(perfect, self.SCORE_NUMERIC_TYPE)

        # Normalize to [0,1] via the theoretical max (beta + rrf_max)
        norm_den = beta + rrf_max
        normalized_score = fused_num / norm_den

        # Round to SCORE_PRECISION so keyset-pagination comparisons are stable.
        score = cast(
            func.round(cast(normalized_score, self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION),
            self.SCORE_NUMERIC_TYPE,
        ).label(self.SCORE_LABEL)

        stmt = select(
            ranked.c.entity_id,
            score,
            ranked.c.highlight_text,
            ranked.c.highlight_path,
            perfect.label("perfect_match"),
        ).select_from(ranked)

        stmt = self._apply_fused_pagination(stmt, score, ranked.c.entity_id)

        return stmt.order_by(
            score.desc().nulls_last(),
            ranked.c.entity_id.asc(),
        ).params(q_vec=self.q_vec)

    def _apply_fused_pagination(
        self,
        stmt: Select,
        score_column: ColumnElement,
        entity_id_column: ColumnElement,
    ) -> Select:
        """Keyset paginate by fused score + id."""
        if self.page_after_score is not None and self.page_after_id is not None:
            # Quantize the cursor score to match the rounding applied in apply().
            score_param = self._quantize_score_for_pagination(self.page_after_score)
            stmt = stmt.where(
                or_(
                    score_column < score_param,
                    and_(score_column == score_param, entity_id_column > self.page_after_id),
                )
            )
        return stmt

    @property
    def metadata(self) -> SearchMetadata:
        return SearchMetadata.hybrid()
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from sqlalchemy import and_
|
|
6
|
+
from sqlalchemy_utils.types.ltree import Ltree
|
|
7
|
+
|
|
8
|
+
from orchestrator.db.database import WrappedSession
|
|
9
|
+
from orchestrator.db.models import AiSearchIndex
|
|
10
|
+
from orchestrator.search.core.types import EntityType
|
|
11
|
+
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
|
|
12
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
13
|
+
from orchestrator.search.schemas.results import SearchResult
|
|
14
|
+
|
|
15
|
+
logger = structlog.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def generate_highlight_indices(text: str, term: str) -> list[tuple[int, int]]:
    """Finds all occurrences of individual words from the term with word boundary matching case-insensitively."""
    if not text or not term:
        return []

    spans: set[tuple[int, int]] = set()

    for word in (w.strip() for w in term.split() if w.strip()):
        escaped = re.escape(word)
        # Prefer whole-word matches; fall back to plain substring matches
        # only when the word never appears at a word boundary.
        found = [(m.start(), m.end()) for m in re.finditer(rf"\b{escaped}\b", text, re.IGNORECASE)]
        if not found:
            found = [(m.start(), m.end()) for m in re.finditer(escaped, text, re.IGNORECASE)]
        spans.update(found)

    return sorted(spans)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def display_filtered_paths_only(
    results: list[SearchResult], search_params: BaseSearchParameters, db_session: WrappedSession
) -> None:
    """Display only the paths that were searched for in the results.

    For each result, looks up the indexed value stored at every path referenced
    by the active filters and logs it. Debug/CLI display helper; returns nothing.
    """
    if not results:
        logger.info("No results found.")
        return

    logger.info("--- Search Results ---")

    # Paths referenced by the search filters; without filters there is nothing to show.
    searched_paths = search_params.filters.get_all_paths() if search_params.filters else []
    if not searched_paths:
        return

    for result in results:
        for path in searched_paths:
            # NOTE(review): one query per (result, path) pair — fine for CLI use,
            # but O(results x paths) round-trips if reused elsewhere.
            record: AiSearchIndex | None = (
                db_session.query(AiSearchIndex)
                .filter(and_(AiSearchIndex.entity_id == result.entity_id, AiSearchIndex.path == Ltree(path)))
                .first()
            )

            if record:
                logger.info(f" {record.path}: {record.value}")

        # Separator between entities.
        logger.info("-" * 40)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def display_results(
    results: list[SearchResult],
    db_session: WrappedSession,
    score_label: str = "Score",
) -> None:
    """Display search results, showing matched field when available or uuid+name for vector search."""
    if not results:
        logger.info("No results found.")
        return

    logger.info("--- Search Results ---")
    for result in results:
        entity_id = result.entity_id
        score = result.score

        # If we have a matching field from fuzzy search, display only that
        if result.matching_field:
            logger.info(f"Entity ID: {entity_id}")
            logger.info(f"Matched field ({result.matching_field.path}): {result.matching_field.text}")
            logger.info(f"{score_label}: {score:.4f}\n" + "-" * 20)
            continue

        # No highlight: fall back to dumping the full traversed entity.
        index_records = db_session.query(AiSearchIndex).filter(AiSearchIndex.entity_id == entity_id).all()
        if not index_records:
            logger.warning(f"Could not find indexed records for entity_id={entity_id}")
            continue

        # Every index row carries the entity type, so any record will do.
        first_record = index_records[0]
        kind = EntityType(first_record.entity_type)
        config = ENTITY_CONFIG_REGISTRY[kind]

        db_entity = db_session.get(config.table, entity_id) if config.table else None

        if db_entity and config.traverser:
            # Re-traverse the live entity to rebuild its path -> value mapping for display.
            fields = config.traverser.get_fields(db_entity, config.pk_name, config.root_name)
            result_obj = {p: v for p, v, _ in fields}
            # default=str stringifies non-JSON types (UUIDs, datetimes) for display only.
            logger.info(json.dumps(result_obj, indent=2, default=str))
            logger.info(f"{score_label}: {score:.4f}\n" + "-" * 20)
        else:
            logger.warning(f"Could not display entity {kind.value} with id={entity_id}")
|