orchestrator-core 4.4.2__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +17 -2
- orchestrator/agentic_app.py +103 -0
- orchestrator/api/api_v1/api.py +14 -2
- orchestrator/api/api_v1/endpoints/search.py +296 -0
- orchestrator/app.py +32 -0
- orchestrator/cli/main.py +22 -1
- orchestrator/cli/search/__init__.py +32 -0
- orchestrator/cli/search/index_llm.py +73 -0
- orchestrator/cli/search/resize_embedding.py +135 -0
- orchestrator/cli/search/search_explore.py +208 -0
- orchestrator/cli/search/speedtest.py +151 -0
- orchestrator/db/models.py +37 -1
- orchestrator/devtools/populator.py +16 -0
- orchestrator/domain/base.py +2 -7
- orchestrator/domain/lifecycle.py +24 -7
- orchestrator/llm_settings.py +57 -0
- orchestrator/log_config.py +1 -0
- orchestrator/migrations/helpers.py +7 -1
- orchestrator/schemas/search.py +130 -0
- orchestrator/schemas/workflow.py +1 -0
- orchestrator/search/__init__.py +12 -0
- orchestrator/search/agent/__init__.py +21 -0
- orchestrator/search/agent/agent.py +62 -0
- orchestrator/search/agent/prompts.py +100 -0
- orchestrator/search/agent/state.py +21 -0
- orchestrator/search/agent/tools.py +258 -0
- orchestrator/search/core/__init__.py +12 -0
- orchestrator/search/core/embedding.py +73 -0
- orchestrator/search/core/exceptions.py +36 -0
- orchestrator/search/core/types.py +296 -0
- orchestrator/search/core/validators.py +40 -0
- orchestrator/search/docs/index.md +37 -0
- orchestrator/search/docs/running_local_text_embedding_inference.md +46 -0
- orchestrator/search/filters/__init__.py +40 -0
- orchestrator/search/filters/base.py +295 -0
- orchestrator/search/filters/date_filters.py +88 -0
- orchestrator/search/filters/definitions.py +107 -0
- orchestrator/search/filters/ltree_filters.py +56 -0
- orchestrator/search/filters/numeric_filter.py +73 -0
- orchestrator/search/indexing/__init__.py +16 -0
- orchestrator/search/indexing/indexer.py +334 -0
- orchestrator/search/indexing/registry.py +101 -0
- orchestrator/search/indexing/tasks.py +69 -0
- orchestrator/search/indexing/traverse.py +334 -0
- orchestrator/search/llm_migration.py +108 -0
- orchestrator/search/retrieval/__init__.py +16 -0
- orchestrator/search/retrieval/builder.py +123 -0
- orchestrator/search/retrieval/engine.py +154 -0
- orchestrator/search/retrieval/exceptions.py +90 -0
- orchestrator/search/retrieval/pagination.py +96 -0
- orchestrator/search/retrieval/retrievers/__init__.py +26 -0
- orchestrator/search/retrieval/retrievers/base.py +123 -0
- orchestrator/search/retrieval/retrievers/fuzzy.py +94 -0
- orchestrator/search/retrieval/retrievers/hybrid.py +277 -0
- orchestrator/search/retrieval/retrievers/semantic.py +94 -0
- orchestrator/search/retrieval/retrievers/structured.py +39 -0
- orchestrator/search/retrieval/utils.py +120 -0
- orchestrator/search/retrieval/validation.py +152 -0
- orchestrator/search/schemas/__init__.py +12 -0
- orchestrator/search/schemas/parameters.py +129 -0
- orchestrator/search/schemas/results.py +77 -0
- orchestrator/services/processes.py +1 -1
- orchestrator/services/settings_env_variables.py +2 -2
- orchestrator/settings.py +8 -1
- orchestrator/utils/state.py +6 -1
- orchestrator/workflows/steps.py +15 -1
- orchestrator/workflows/tasks/validate_products.py +1 -1
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/METADATA +15 -8
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/RECORD +71 -21
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from collections.abc import Sequence
|
|
15
|
+
|
|
16
|
+
import structlog
|
|
17
|
+
from sqlalchemy.engine.row import RowMapping
|
|
18
|
+
from sqlalchemy.orm import Session
|
|
19
|
+
|
|
20
|
+
from orchestrator.search.core.types import FilterOp, SearchMetadata
|
|
21
|
+
from orchestrator.search.filters import FilterTree, LtreeFilter
|
|
22
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
23
|
+
from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult
|
|
24
|
+
|
|
25
|
+
from .builder import build_candidate_query
|
|
26
|
+
from .pagination import PaginationParams
|
|
27
|
+
from .retrievers import Retriever
|
|
28
|
+
from .utils import generate_highlight_indices
|
|
29
|
+
|
|
30
|
+
logger = structlog.get_logger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _format_response(
    db_rows: Sequence[RowMapping], search_params: BaseSearchParameters, metadata: SearchMetadata
) -> SearchResponse:
    """Format database query results into a `SearchResponse`.

    Every row becomes one `SearchResult`. When the request carried a text query and the
    row exposes both highlight columns, those are turned into a `MatchingField` with
    highlight indices. For filter-only structured searches the matching field is derived
    from the filters instead.

    Args:
        db_rows (Sequence[RowMapping]): The rows returned from the executed SQLAlchemy query.
        search_params (BaseSearchParameters): The search parameters, including query text and filters.
        metadata (SearchMetadata): Metadata about the search execution.

    Returns:
        SearchResponse: The formatted results together with the given metadata. An empty
            `db_rows` yields a response with an empty result list.
    """
    if not db_rows:
        return SearchResponse(results=[], metadata=metadata)

    user_query = search_params.query

    # The filter-derived matching field depends only on the request, never on the row,
    # so resolve it once instead of re-walking the filter tree for every result.
    # It stays None whenever a text query is present, which mirrors the per-row rule.
    filter_match = None
    if not user_query and search_params.filters and metadata.search_type == "structured":
        # Structured search (filter-only)
        filter_match = _extract_matching_field_from_filters(search_params.filters)

    results = []
    for row in db_rows:
        matching_field = filter_match

        if (
            user_query
            and (text := row.get(Retriever.HIGHLIGHT_TEXT_LABEL))
            and (path := row.get(Retriever.HIGHLIGHT_PATH_LABEL))
        ):
            # Highlight columns may come back as non-string values; normalise them.
            if not isinstance(text, str):
                text = str(text)
            if not isinstance(path, str):
                path = str(path)

            # An empty highlight list is reported as None.
            highlight_indices = generate_highlight_indices(text, user_query) or None
            matching_field = MatchingField(text=text, path=path, highlight_indices=highlight_indices)

        results.append(
            SearchResult(
                entity_id=str(row.entity_id),
                score=row.score,
                perfect_match=row.get("perfect_match", 0),
                matching_field=matching_field,
            )
        )
    return SearchResponse(results=results, metadata=metadata)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _extract_matching_field_from_filters(filters: FilterTree) -> MatchingField | None:
    """Derive a matching field from the filters of a structured (filter-only) search.

    A matching field is only produced when the filter tree has exactly one leaf;
    any other number of leaves yields None. The whole text is marked as highlighted.

    Args:
        filters (FilterTree): The filter tree of the search request.

    Returns:
        MatchingField | None: The matching field for the single leaf, or None when there
            is not exactly one leaf or the leaf tests for the absence of a component.
    """
    all_leaves = filters.get_all_leaves()
    if len(all_leaves) != 1:
        return None

    leaf = all_leaves[0]
    condition = leaf.condition

    if not isinstance(condition, LtreeFilter):
        # Everything that is not an Ltree filter
        raw_value = getattr(condition, "value", "")
        shown = str(raw_value) if raw_value is not None else ""
        return MatchingField(text=shown, path=leaf.path, highlight_indices=[(0, len(shown))])

    # There can be no match for absence.
    if condition.op == FilterOp.NOT_HAS_COMPONENT:
        return None

    # Prefer the original component/pattern (validator may set path="*" and move the value)
    label = str(getattr(condition, "value", "") or leaf.path)
    return MatchingField(text=label, path=label, highlight_indices=[(0, len(label))])
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
async def execute_search(
    search_params: BaseSearchParameters,
    db_session: Session,
    pagination_params: PaginationParams | None = None,
) -> SearchResponse:
    """Execute a search and return ranked results.

    Builds a candidate entity query from the search parameters, lets the retriever
    selected for those parameters apply its ranking, limits the statement to
    `search_params.limit` and executes it.

    Args:
        search_params (BaseSearchParameters): The search parameters specifying vector, fuzzy, or filter criteria.
        db_session (Session): The active SQLAlchemy session for executing the query.
        pagination_params (PaginationParams | None): Parameters controlling pagination of the
            search results. A default `PaginationParams()` is used when omitted.

    Returns:
        SearchResponse: The formatted search results plus the metadata of the retriever used.

    Notes:
        If no vector query, filters, or fuzzy term are provided, a warning is logged
        and an empty result set is returned.
    """
    has_criteria = search_params.vector_query or search_params.filters or search_params.fuzzy_term
    if not has_criteria:
        logger.warning("No search criteria provided (vector_query, fuzzy_term, or filters).")
        return SearchResponse(results=[], metadata=SearchMetadata.empty())

    candidates = build_candidate_query(search_params)

    if not pagination_params:
        pagination_params = PaginationParams()
    retriever = await Retriever.from_params(search_params, pagination_params)
    logger.debug("Using retriever", retriever_type=retriever.__class__.__name__)

    final_stmt = retriever.apply(candidates).limit(search_params.limit)
    logger.debug(final_stmt)

    # NOTE(review): this is a synchronous Session call inside a coroutine, so the
    # event loop is blocked while the query runs.
    rows = db_session.execute(final_stmt).mappings().all()

    return _format_response(rows, search_params, retriever.metadata)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from orchestrator.search.core.types import FilterOp
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FilterValidationError(Exception):
    """Common base class for every error raised while validating search filters.

    Catch this type to handle any of the more specific filter validation errors.
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InvalidLtreePatternError(FilterValidationError):
    """Signals that an ltree pattern is not valid ltree query syntax.

    Args:
        pattern (str): The offending pattern; it is echoed in the error message.
    """

    def __init__(self, pattern: str) -> None:
        super().__init__(f"Ltree pattern '{pattern}' has invalid syntax. Use valid PostgreSQL ltree lquery syntax.")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class EmptyFilterPathError(FilterValidationError):
    """Signals that a filter path is empty or consists of whitespace only.

    The error takes no arguments; its message is fixed and includes example paths.
    """

    def __init__(self) -> None:
        super().__init__(
            "Filter path cannot be empty. Provide a valid path like 'subscription.product.name' or 'workflow.name'."
        )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class PathNotFoundError(FilterValidationError):
    """Signals that a filter path does not exist in the database schema.

    Args:
        path (str): The unknown filter path; it is echoed in the error message.

    Examples:
        Using a non-existent filter path:

        >>> print(PathNotFoundError('subscription.nonexistent.field'))
        Path 'subscription.nonexistent.field' does not exist in the database.
    """

    def __init__(self, path: str) -> None:
        super().__init__(f"Path '{path}' does not exist in the database.")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class IncompatibleFilterTypeError(FilterValidationError):
    """Signals that a filter operator cannot be used with the data type of a field.

    Args:
        operator (str): The operator that was requested.
        field_type (str): The data type of the field addressed by `path`.
        path (str): The filter path the operator was applied to.
        expected_operators (list[FilterOp]): Operators that are valid for `field_type`;
            their `.value` is listed in the message.

    Examples:
        Using a numeric comparison operator on a string field:

        >>> print(IncompatibleFilterTypeError(
        ...     operator='gt',
        ...     field_type='string',
        ...     path='subscription.customer_name',
        ...     expected_operators=[FilterOp.EQ, FilterOp.NEQ, FilterOp.LIKE],
        ... ))
        Operator 'gt' is not compatible with field type 'string' for path 'subscription.customer_name'. Valid operators for 'string': [eq, neq, like]
    """

    def __init__(self, operator: str, field_type: str, path: str, expected_operators: list[FilterOp]) -> None:
        valid_ops_str = ", ".join(op.value for op in expected_operators)
        super().__init__(
            f"Operator '{operator}' is not compatible with field type '{field_type}' for path '{path}'. Valid operators for '{field_type}': [{valid_ops_str}]"
        )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class InvalidEntityPrefixError(FilterValidationError):
    """Signals that a filter path lacks the prefix required for the searched entity type.

    Args:
        path (str): The filter path that was supplied.
        expected_prefix (str): The prefix the path has to start with.
        entity_type (str): Name of the entity type being searched, as shown in the message.

    Examples:
        Using wrong entity prefix in filter path:

        >>> print(InvalidEntityPrefixError('workflow.name', 'subscription.', 'SUBSCRIPTION'))
        Filter path 'workflow.name' must start with 'subscription.' for SUBSCRIPTION searches, or use '*' for wildcard paths.
    """

    def __init__(self, path: str, expected_prefix: str, entity_type: str) -> None:
        super().__init__(
            f"Filter path '{path}' must start with '{expected_prefix}' for {entity_type} searches, or use '*' for wildcard paths."
        )
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
import array
|
|
15
|
+
import base64
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from orchestrator.search.core.exceptions import InvalidCursorError
|
|
21
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
22
|
+
from orchestrator.search.schemas.results import SearchResult
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class PaginationParams:
|
|
27
|
+
"""Parameters for pagination in search queries."""
|
|
28
|
+
|
|
29
|
+
page_after_score: float | None = None
|
|
30
|
+
page_after_id: str | None = None
|
|
31
|
+
q_vec_override: list[float] | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def floats_to_b64(v: list[float]) -> str:
    """Serialise a list of floats into a URL-safe Base64 string.

    The values are packed as C `float` items (array typecode "f") before being
    encoded, so they are stored with that item size rather than as Python floats.

    Args:
        v (list[float]): The values to encode; may be empty.

    Returns:
        str: The URL-safe Base64 text, or an empty string for an empty list.
    """
    packed = array.array("f", v).tobytes()
    encoded = base64.urlsafe_b64encode(packed)
    return encoded.decode("ascii")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def b64_to_floats(s: str) -> list[float]:
    """Deserialise a URL-safe Base64 string into a list of floats.

    The decoded bytes are interpreted as packed C `float` items (array typecode "f").

    Args:
        s (str): ASCII text holding URL-safe Base64 data; may be empty.

    Returns:
        list[float]: The decoded values, or an empty list for an empty string.

    Raises:
        ValueError: If the text is not ASCII, is not valid Base64, or decodes to a
            byte count that is not a multiple of the float item size.
    """
    payload = base64.urlsafe_b64decode(s.encode("ascii"))
    values = array.array("f")
    values.frombytes(payload)
    return values.tolist()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PageCursor(BaseModel):
    """Opaque pagination cursor exchanged with API clients.

    The cursor is the model's JSON representation wrapped in URL-safe Base64.
    """

    # Score of the last result on the page the cursor was created for.
    score: float
    # Entity id of that last result.
    id: str
    # Base64-encoded query vector (see `floats_to_b64`); may be an empty string.
    q_vec_b64: str

    def encode(self) -> str:
        """Encode the cursor data into a URL-safe Base64 string."""
        payload = self.model_dump_json().encode("utf-8")
        return base64.urlsafe_b64encode(payload).decode("utf-8")

    @classmethod
    def decode(cls, cursor: str) -> "PageCursor":
        """Decode a Base64 string back into a PageCursor instance.

        Raises:
            InvalidCursorError: If the text cannot be decoded or does not validate
                as a `PageCursor`; the original exception is chained as the cause.
        """
        try:
            raw_json = base64.urlsafe_b64decode(cursor).decode("utf-8")
            parsed = cls.model_validate_json(raw_json)
        except Exception as e:
            raise InvalidCursorError("Invalid pagination cursor") from e
        return parsed
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def process_pagination_cursor(cursor: str | None, search_params: BaseSearchParameters) -> PaginationParams:
    """Process pagination cursor and return pagination parameters.

    With a cursor, the position (score and id) and the query vector stored in the
    cursor are returned. Without a cursor but with a vector query, the query text is
    embedded once so the vector can be reused. Otherwise default parameters are returned.

    Args:
        cursor (str | None): The encoded cursor supplied by the client, if any.
        search_params (BaseSearchParameters): The search parameters of the request.

    Returns:
        PaginationParams: The pagination parameters to hand to the retriever.

    Raises:
        InvalidCursorError: If the cursor, or the query vector embedded in it, cannot be decoded.
    """
    if cursor:
        c = PageCursor.decode(cursor)
        try:
            q_vec = b64_to_floats(c.q_vec_b64)
        except ValueError as e:
            # The cursor is client-supplied: a malformed vector payload (bad Base64,
            # non-ASCII text, or a byte count that is not a whole number of floats)
            # surfaces as ValueError or one of its subclasses. Report it the same way
            # as any other broken cursor instead of leaking a raw decoding error.
            raise InvalidCursorError("Invalid pagination cursor") from e
        return PaginationParams(
            page_after_score=c.score,
            page_after_id=c.id,
            q_vec_override=q_vec,
        )
    if search_params.vector_query:
        # Imported here rather than at module level, as in the original code.
        from orchestrator.search.core.embedding import QueryEmbedder

        q_vec_override = await QueryEmbedder.generate_for_text_async(search_params.vector_query)
        return PaginationParams(q_vec_override=q_vec_override)
    return PaginationParams()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def create_next_page_cursor(
    search_results: list[SearchResult], pagination_params: PaginationParams, limit: int
) -> str | None:
    """Create the cursor for the next page, if the current page suggests there is one.

    A next page is assumed only when the current page is completely filled, i.e. it
    holds exactly `limit` results and `limit` is positive.

    Args:
        search_results (list[SearchResult]): The results of the current page.
        pagination_params (PaginationParams): The parameters used for the current page;
            its query vector is carried over into the cursor.
        limit (int): The page size that was requested.

    Returns:
        str | None: The encoded cursor pointing after the last result, or None.
    """
    if limit <= 0 or len(search_results) != limit:
        return None

    tail = search_results[-1]
    query_vector = pagination_params.q_vec_override or []
    next_cursor = PageCursor(
        score=float(tail.score),
        id=tail.entity_id,
        q_vec_b64=floats_to_b64(query_vector),
    )
    return next_cursor.encode()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from .base import Retriever
|
|
15
|
+
from .fuzzy import FuzzyRetriever
|
|
16
|
+
from .hybrid import RrfHybridRetriever
|
|
17
|
+
from .semantic import SemanticRetriever
|
|
18
|
+
from .structured import StructuredRetriever
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"Retriever",
|
|
22
|
+
"FuzzyRetriever",
|
|
23
|
+
"RrfHybridRetriever",
|
|
24
|
+
"SemanticRetriever",
|
|
25
|
+
"StructuredRetriever",
|
|
26
|
+
]
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from decimal import Decimal
|
|
16
|
+
|
|
17
|
+
import structlog
|
|
18
|
+
from sqlalchemy import BindParameter, Numeric, Select, literal
|
|
19
|
+
|
|
20
|
+
from orchestrator.search.core.types import FieldType, SearchMetadata
|
|
21
|
+
from orchestrator.search.schemas.parameters import BaseSearchParameters
|
|
22
|
+
|
|
23
|
+
from ..pagination import PaginationParams
|
|
24
|
+
|
|
25
|
+
logger = structlog.get_logger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Retriever(ABC):
    """Abstract base class for applying a ranking strategy to a search query.

    Concrete retrievers implement `apply` and `metadata`; `from_params` picks the
    retriever matching the supplied search parameters.
    """

    # Number of decimals scores are rounded to, and the matching SQL numeric type.
    SCORE_PRECISION = 12
    SCORE_NUMERIC_TYPE = Numeric(38, 12)
    # Column labels used in the statements produced by retrievers.
    HIGHLIGHT_TEXT_LABEL = "highlight_text"
    HIGHLIGHT_PATH_LABEL = "highlight_path"
    SCORE_LABEL = "score"
    # Field types whose values take part in text matching.
    SEARCHABLE_FIELD_TYPES = [
        FieldType.STRING.value,
        FieldType.UUID.value,
        FieldType.BLOCK.value,
        FieldType.RESOURCE_TYPE.value,
    ]

    @classmethod
    async def from_params(
        cls,
        params: BaseSearchParameters,
        pagination_params: PaginationParams,
    ) -> "Retriever":
        """Create the appropriate retriever instance from search parameters.

        Args:
            params (BaseSearchParameters): Search parameters including vector queries, fuzzy terms, and filters.
            pagination_params (PaginationParams): Pagination parameters for cursor-based paging.

        Returns:
            Retriever: A concrete retriever instance (semantic, fuzzy, hybrid, or structured).
        """
        # Concrete retrievers are imported inside the method, as in the original code.
        from .fuzzy import FuzzyRetriever
        from .hybrid import RrfHybridRetriever
        from .semantic import SemanticRetriever
        from .structured import StructuredRetriever

        fuzzy_term = params.fuzzy_term
        query_vector = await cls._get_query_vector(params.vector_query, pagination_params.q_vec_override)

        # If semantic search was attempted but failed, fall back to fuzzy with the full query
        semantic_failed = query_vector is None and params.vector_query is not None and params.query is not None
        text_term = params.query if semantic_failed else fuzzy_term

        if query_vector is None:
            if text_term is None:
                return StructuredRetriever(pagination_params)
            return FuzzyRetriever(text_term, pagination_params)

        if text_term is None:
            return SemanticRetriever(query_vector, pagination_params)
        return RrfHybridRetriever(query_vector, text_term, pagination_params)

    @classmethod
    async def _get_query_vector(
        cls, vector_query: str | None, q_vec_override: list[float] | None
    ) -> list[float] | None:
        """Get query vector either from override or by generating from text.

        Returns:
            list[float] | None: The non-empty override when given; otherwise the embedding
                of `vector_query`; None when there is no vector query or no embedding
                could be produced.
        """
        if q_vec_override:
            return q_vec_override

        if vector_query:
            from orchestrator.search.core.embedding import QueryEmbedder

            embedding = await QueryEmbedder.generate_for_text_async(vector_query)
            if embedding:
                return embedding
            logger.warning("Embedding generation failed; using non-semantic retriever")

        return None

    @abstractmethod
    def apply(self, candidate_query: Select) -> Select:
        """Apply the ranking logic to the given candidate query.

        Args:
            candidate_query (Select): A SQLAlchemy `Select` statement returning candidate entity IDs.

        Returns:
            Select: A new `Select` statement with ranking expressions applied.
        """
        ...

    def _quantize_score_for_pagination(self, score_value: float) -> BindParameter[Decimal]:
        """Convert score value to properly quantized Decimal parameter for pagination.

        The float is converted through its string form and quantized to
        `SCORE_PRECISION` decimals, then bound with `SCORE_NUMERIC_TYPE`.
        """
        step = Decimal(1).scaleb(-self.SCORE_PRECISION)
        quantized = Decimal(str(score_value)).quantize(step)
        return literal(quantized, type_=self.SCORE_NUMERIC_TYPE)

    @property
    @abstractmethod
    def metadata(self) -> SearchMetadata:
        """Return metadata describing this search strategy."""
        ...
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from sqlalchemy import Select, and_, cast, func, literal, or_, select
|
|
15
|
+
from sqlalchemy.sql.expression import ColumnElement
|
|
16
|
+
|
|
17
|
+
from orchestrator.db.models import AiSearchIndex
|
|
18
|
+
from orchestrator.search.core.types import SearchMetadata
|
|
19
|
+
|
|
20
|
+
from ..pagination import PaginationParams
|
|
21
|
+
from .base import Retriever
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FuzzyRetriever(Retriever):
    """Ranks results based on the max of fuzzy text similarity scores.

    Similarity is computed with the database's `word_similarity` function and the
    `<%` operator against the indexed values of each candidate entity.
    """

    def __init__(self, fuzzy_term: str, pagination_params: PaginationParams) -> None:
        """Store the search term and the pagination position.

        Args:
            fuzzy_term (str): The text to match against indexed values.
            pagination_params (PaginationParams): Supplies the score/id to page after, if any.
        """
        self.fuzzy_term = fuzzy_term
        self.page_after_score = pagination_params.page_after_score
        self.page_after_id = pagination_params.page_after_id

    def apply(self, candidate_query: Select) -> Select:
        """Rank the candidate entities by their best fuzzy match.

        Args:
            candidate_query (Select): Statement returning the candidate entity IDs.

        Returns:
            Select: Statement yielding entity id, score, highlight text and highlight path,
                ordered by score (descending, nulls last) and then entity id.
        """
        cand = candidate_query.subquery()

        similarity_expr = func.word_similarity(self.fuzzy_term, AiSearchIndex.value)

        # Entity score: highest similarity over all of its matching rows, rounded to the
        # shared precision so every page sees identical values.
        raw_max = func.max(similarity_expr).over(partition_by=AiSearchIndex.entity_id)
        score = cast(
            func.round(cast(raw_max, self.SCORE_NUMERIC_TYPE), self.SCORE_PRECISION), self.SCORE_NUMERIC_TYPE
        ).label(self.SCORE_LABEL)

        # Window selecting the best matching row per entity; the path breaks similarity ties.
        best_match_window = {
            "partition_by": AiSearchIndex.entity_id,
            "order_by": [similarity_expr.desc(), AiSearchIndex.path.asc()],
        }

        combined_query = (
            select(
                AiSearchIndex.entity_id,
                score,
                func.first_value(AiSearchIndex.value).over(**best_match_window).label(self.HIGHLIGHT_TEXT_LABEL),
                func.first_value(AiSearchIndex.path).over(**best_match_window).label(self.HIGHLIGHT_PATH_LABEL),
            )
            .select_from(AiSearchIndex)
            .join(cand, cand.c.entity_id == AiSearchIndex.entity_id)
            .where(
                and_(
                    AiSearchIndex.value_type.in_(self.SEARCHABLE_FIELD_TYPES),
                    literal(self.fuzzy_term).op("<%")(AiSearchIndex.value),
                )
            )
            # Every row of an entity carries the same window results; keep one per entity.
            .distinct(AiSearchIndex.entity_id)
        )
        final_query = combined_query.subquery("ranked_fuzzy")

        stmt = select(
            final_query.c.entity_id,
            final_query.c.score,
            final_query.c.highlight_text,
            final_query.c.highlight_path,
        ).select_from(final_query)

        stmt = self._apply_score_pagination(stmt, final_query.c.score, final_query.c.entity_id)

        return stmt.order_by(final_query.c.score.desc().nulls_last(), final_query.c.entity_id.asc())

    @property
    def metadata(self) -> SearchMetadata:
        """Return the metadata describing a fuzzy search."""
        return SearchMetadata.fuzzy()

    def _apply_score_pagination(
        self, stmt: Select, score_column: ColumnElement, entity_id_column: ColumnElement
    ) -> Select:
        """Apply standard score + entity_id pagination.

        Rows are kept when they rank after the cursor position: a lower score, or the
        same score with a greater entity id. Without a complete cursor position the
        statement is returned unchanged.
        """
        if self.page_after_score is None or self.page_after_id is None:
            return stmt

        # Compare against a Decimal quantized to the same precision as the score column
        # (via the shared base-class helper) rather than a raw float, so the equality
        # branch is an exact numeric comparison.
        after_score = self._quantize_score_for_pagination(self.page_after_score)
        return stmt.where(
            or_(
                score_column < after_score,
                and_(
                    score_column == after_score,
                    entity_id_column > self.page_after_id,
                ),
            )
        )
|