orchestrator-core 4.4.0rc3__py3-none-any.whl → 4.5.1a1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- orchestrator/__init__.py +26 -2
- orchestrator/agentic_app.py +84 -0
- orchestrator/api/api_v1/api.py +10 -0
- orchestrator/api/api_v1/endpoints/search.py +277 -0
- orchestrator/app.py +32 -0
- orchestrator/cli/index_llm.py +73 -0
- orchestrator/cli/main.py +22 -1
- orchestrator/cli/resize_embedding.py +135 -0
- orchestrator/cli/search_explore.py +208 -0
- orchestrator/cli/speedtest.py +151 -0
- orchestrator/db/models.py +37 -1
- orchestrator/graphql/schemas/process.py +2 -2
- orchestrator/graphql/schemas/workflow.py +2 -2
- orchestrator/llm_settings.py +51 -0
- orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
- orchestrator/schedules/scheduler.py +6 -7
- orchestrator/schemas/search.py +117 -0
- orchestrator/search/__init__.py +12 -0
- orchestrator/search/agent/__init__.py +8 -0
- orchestrator/search/agent/agent.py +47 -0
- orchestrator/search/agent/prompts.py +62 -0
- orchestrator/search/agent/state.py +8 -0
- orchestrator/search/agent/tools.py +121 -0
- orchestrator/search/core/__init__.py +0 -0
- orchestrator/search/core/embedding.py +64 -0
- orchestrator/search/core/exceptions.py +22 -0
- orchestrator/search/core/types.py +281 -0
- orchestrator/search/core/validators.py +27 -0
- orchestrator/search/docs/index.md +37 -0
- orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
- orchestrator/search/filters/__init__.py +27 -0
- orchestrator/search/filters/base.py +272 -0
- orchestrator/search/filters/date_filters.py +75 -0
- orchestrator/search/filters/definitions.py +93 -0
- orchestrator/search/filters/ltree_filters.py +43 -0
- orchestrator/search/filters/numeric_filter.py +60 -0
- orchestrator/search/indexing/__init__.py +3 -0
- orchestrator/search/indexing/indexer.py +323 -0
- orchestrator/search/indexing/registry.py +88 -0
- orchestrator/search/indexing/tasks.py +53 -0
- orchestrator/search/indexing/traverse.py +322 -0
- orchestrator/search/retrieval/__init__.py +3 -0
- orchestrator/search/retrieval/builder.py +108 -0
- orchestrator/search/retrieval/engine.py +152 -0
- orchestrator/search/retrieval/pagination.py +83 -0
- orchestrator/search/retrieval/retriever.py +447 -0
- orchestrator/search/retrieval/utils.py +106 -0
- orchestrator/search/retrieval/validation.py +174 -0
- orchestrator/search/schemas/__init__.py +0 -0
- orchestrator/search/schemas/parameters.py +116 -0
- orchestrator/search/schemas/results.py +63 -0
- orchestrator/services/settings_env_variables.py +2 -2
- orchestrator/settings.py +1 -1
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/METADATA +8 -3
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/RECORD +57 -14
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/licenses/LICENSE +0 -0

orchestrator/search/filters/date_filters.py

@@ -0,0 +1,75 @@
+from datetime import date, datetime
+from typing import Annotated, Any, Literal
+
+from dateutil.parser import parse as dt_parse
+from pydantic import BaseModel, BeforeValidator, Field, model_validator
+from sqlalchemy import TIMESTAMP, and_
+from sqlalchemy import cast as sa_cast
+from sqlalchemy.sql.elements import ColumnElement
+
+from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+def _validate_date_string(v: Any) -> Any:
+    if not isinstance(v, str):
+        return v
+    try:
+        dt_parse(v)
+        return v
+    except Exception as exc:
+        raise ValueError("is not a valid date or datetime string") from exc
+
+
+DateValue = datetime | date | str
+ValidatedDateValue = Annotated[DateValue, BeforeValidator(_validate_date_string)]
+
+
+class DateRange(BaseModel):
+
+    start: ValidatedDateValue
+    end: ValidatedDateValue
+
+    @model_validator(mode="after")
+    def _order(self) -> "DateRange":
+        to_datetime = dt_parse(str(self.end))
+        from_datetime = dt_parse(str(self.start))
+        if to_datetime <= from_datetime:
+            raise ValueError("'to' must be after 'from'")
+        return self
+
+
+class DateValueFilter(BaseModel):
+    """A filter that operates on a single date value."""
+
+    op: Literal[FilterOp.EQ, FilterOp.NEQ, FilterOp.LT, FilterOp.LTE, FilterOp.GT, FilterOp.GTE]
+    value: ValidatedDateValue
+
+    def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+        date_column = sa_cast(column, TIMESTAMP(timezone=True))
+        match self.op:
+            case FilterOp.EQ:
+                return date_column == self.value
+            case FilterOp.NEQ:
+                return date_column != self.value
+            case FilterOp.LT:
+                return date_column < self.value
+            case FilterOp.LTE:
+                return date_column <= self.value
+            case FilterOp.GT:
+                return date_column > self.value
+            case FilterOp.GTE:
+                return date_column >= self.value
+
+
+class DateRangeFilter(BaseModel):
+    """A filter that operates on a range of dates."""
+
+    op: Literal[FilterOp.BETWEEN]
+    value: DateRange
+
+    def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+        date_column = sa_cast(column, TIMESTAMP(timezone=True))
+        return and_(date_column >= self.value.start, date_column < self.value.end)
+
+
+DateFilter = Annotated[DateValueFilter | DateRangeFilter, Field(discriminator="op")]
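
The two models above feed the `DateFilter` discriminated union, so the `op` field alone decides which model parses a payload, and the range filter is half-open (`>= start`, `< end`). A minimal usage sketch, assuming only what this hunk shows; the lightweight `column("value")` construct is an illustrative stand-in for the real indexed column, not part of the package:

from sqlalchemy import column

from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.date_filters import DateRangeFilter, DateValueFilter

value_col = column("value")  # stand-in column, for illustration only

# Single-value comparison: the string is validated by dateutil before use.
gte = DateValueFilter(op=FilterOp.GTE, value="2025-01-01")
print(gte.to_expression(value_col, path="subscription.start_date"))
# e.g. CAST(value AS TIMESTAMP) >= :param_1 (exact SQL depends on the dialect)

# Range comparison: DateRange rejects ranges whose end is not after the start.
window = DateRangeFilter(op=FilterOp.BETWEEN, value={"start": "2025-01-01", "end": "2025-02-01"})
print(window.to_expression(value_col, path="subscription.start_date"))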

orchestrator/search/filters/definitions.py

@@ -0,0 +1,93 @@
+from orchestrator.search.core.types import FieldType, FilterOp, UIType
+from orchestrator.search.schemas.results import TypeDefinition, ValueSchema
+
+
+def operators_for(ft: FieldType) -> list[FilterOp]:
+    """Return the list of valid operators for a given FieldType."""
+    return list(value_schema_for(ft).keys())
+
+
+def component_operators() -> dict[FilterOp, ValueSchema]:
+    """Return operators available for path components."""
+    return {
+        FilterOp.HAS_COMPONENT: ValueSchema(kind=UIType.COMPONENT),
+        FilterOp.NOT_HAS_COMPONENT: ValueSchema(kind=UIType.COMPONENT),
+    }
+
+
+def value_schema_for(ft: FieldType) -> dict[FilterOp, ValueSchema]:
+    """Return the value schema map for a given FieldType."""
+    if ft in (FieldType.INTEGER, FieldType.FLOAT):
+        return {
+            FilterOp.EQ: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.NEQ: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.LT: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.LTE: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.GT: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.GTE: ValueSchema(kind=UIType.NUMBER),
+            FilterOp.BETWEEN: ValueSchema(
+                kind="object",
+                fields={
+                    "start": ValueSchema(kind=UIType.NUMBER),
+                    "end": ValueSchema(kind=UIType.NUMBER),
+                },
+            ),
+        }
+
+    if ft == FieldType.BOOLEAN:
+        return {
+            FilterOp.EQ: ValueSchema(kind=UIType.BOOLEAN),
+            FilterOp.NEQ: ValueSchema(kind=UIType.BOOLEAN),
+        }
+
+    if ft == FieldType.DATETIME:
+        return {
+            FilterOp.EQ: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.NEQ: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.LT: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.LTE: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.GT: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.GTE: ValueSchema(kind=UIType.DATETIME),
+            FilterOp.BETWEEN: ValueSchema(
+                kind="object",
+                fields={
+                    "start": ValueSchema(kind=UIType.DATETIME),
+                    "end": ValueSchema(kind=UIType.DATETIME),
+                },
+            ),
+        }
+
+    return {
+        FilterOp.EQ: ValueSchema(kind=UIType.STRING),
+        FilterOp.NEQ: ValueSchema(kind=UIType.STRING),
+    }
+
+
+def generate_definitions() -> dict[UIType, TypeDefinition]:
+    """Generate the full definitions dictionary for all UI types."""
+    definitions: dict[UIType, TypeDefinition] = {}
+
+    for ui_type in UIType:
+        if ui_type == UIType.COMPONENT:
+            # Special case for component filtering
+            comp_ops = component_operators()
+            definitions[ui_type] = TypeDefinition(
+                operators=list(comp_ops.keys()),
+                valueSchema=comp_ops,
+            )
+        else:
+            # Regular field types
+            if ui_type == UIType.NUMBER:
+                rep_ft = FieldType.INTEGER
+            elif ui_type == UIType.DATETIME:
+                rep_ft = FieldType.DATETIME
+            elif ui_type == UIType.BOOLEAN:
+                rep_ft = FieldType.BOOLEAN
+            else:
+                rep_ft = FieldType.STRING
+
+            definitions[ui_type] = TypeDefinition(
+                operators=operators_for(rep_ft),
+                valueSchema=value_schema_for(rep_ft),
+            )
+    return definitions
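
`generate_definitions()` maps each UI-facing `UIType` onto a representative `FieldType` and returns the operator/value-schema table the frontend can render from. A short sketch of inspecting its output, assuming nothing beyond the functions and enums referenced in this hunk:

from orchestrator.search.core.types import FieldType
from orchestrator.search.filters.definitions import generate_definitions, operators_for

definitions = generate_definitions()
for ui_type, type_def in definitions.items():
    # Each TypeDefinition pairs the operator list with the per-operator value schema the UI should render.
    print(ui_type, type_def.operators)

# INTEGER and FLOAT share one representative operator map, so their operator lists are identical.
assert operators_for(FieldType.INTEGER) == operators_for(FieldType.FLOAT)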

orchestrator/search/filters/ltree_filters.py

@@ -0,0 +1,43 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+from sqlalchemy import TEXT, bindparam
+from sqlalchemy.sql.elements import ColumnElement
+from sqlalchemy_utils.types.ltree import Ltree
+
+from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+class LtreeFilter(BaseModel):
+    """Filter for ltree path operations."""
+
+    op: Literal[
+        FilterOp.MATCHES_LQUERY,
+        FilterOp.IS_ANCESTOR,
+        FilterOp.IS_DESCENDANT,
+        FilterOp.PATH_MATCH,
+        FilterOp.HAS_COMPONENT,
+        FilterOp.NOT_HAS_COMPONENT,
+        FilterOp.ENDS_WITH,
+    ]
+    value: str = Field(description="The ltree path or lquery pattern to compare against.")
+
+    def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+        """Converts the filter condition into a SQLAlchemy expression."""
+        match self.op:
+            case FilterOp.IS_DESCENDANT:
+                ltree_value = Ltree(self.value)
+                return column.op("<@")(ltree_value)
+            case FilterOp.IS_ANCESTOR:
+                ltree_value = Ltree(self.value)
+                return column.op("@>")(ltree_value)
+            case FilterOp.MATCHES_LQUERY:
+                param = bindparam(None, self.value, type_=TEXT)
+                return column.op("~")(param)
+            case FilterOp.PATH_MATCH:
+                ltree_value = Ltree(path)
+                return column == ltree_value
+            case FilterOp.HAS_COMPONENT | FilterOp.NOT_HAS_COMPONENT:
+                return column.op("~")(bindparam(None, f"*.{self.value}.*", type_=TEXT))
+            case FilterOp.ENDS_WITH:
+                return column.op("~")(bindparam(None, f"*.{self.value}", type_=TEXT))
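
`LtreeFilter` translates path-oriented operators into PostgreSQL ltree/lquery comparisons. Note that `HAS_COMPONENT` and `NOT_HAS_COMPONENT` both compile to the same positive `*.<component>.*` match here, so the negation is presumably applied by the surrounding filter machinery (e.g. `filters/base.py`, not shown). A sketch, with `column("path")` as an illustrative stand-in for the indexed ltree column:

from sqlalchemy import column

from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.ltree_filters import LtreeFilter

path_col = column("path")  # stand-in for the real ltree column, illustration only

# lquery match: does any label in the path equal "port"?
component = LtreeFilter(op=FilterOp.HAS_COMPONENT, value="port")
print(component.to_expression(path_col, path="subscription"))  # path ~ :param_1, bound to '*.port.*'

# Ancestor/descendant checks wrap the value in an Ltree before comparing.
descendant = LtreeFilter(op=FilterOp.IS_DESCENDANT, value="subscription.port")
print(descendant.to_expression(path_col, path="subscription"))  # path <@ :param_1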

orchestrator/search/filters/numeric_filter.py

@@ -0,0 +1,60 @@
+from typing import Annotated, Any, Literal
+
+from pydantic import BaseModel, Field, model_validator
+from sqlalchemy import DOUBLE_PRECISION, INTEGER, and_
+from sqlalchemy import cast as sa_cast
+from sqlalchemy.sql.elements import ColumnElement
+from typing_extensions import Self
+
+from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+class NumericRange(BaseModel):
+    start: int | float
+    end: int | float
+
+    @model_validator(mode="after")
+    def validate_order(self) -> Self:
+        if self.end <= self.start:
+            raise ValueError("'end' must be greater than 'start'")
+        return self
+
+
+class NumericValueFilter(BaseModel):
+    """A filter for single numeric value comparisons (int or float)."""
+
+    op: Literal[FilterOp.EQ, FilterOp.NEQ, FilterOp.LT, FilterOp.LTE, FilterOp.GT, FilterOp.GTE]
+    value: int | float
+
+    def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+        cast_type = INTEGER if isinstance(self.value, int) else DOUBLE_PRECISION
+        numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
+        match self.op:
+
+            case FilterOp.EQ:
+                return numeric_column == self.value
+            case FilterOp.NEQ:
+                return numeric_column != self.value
+            case FilterOp.LT:
+                return numeric_column < self.value
+            case FilterOp.LTE:
+                return numeric_column <= self.value
+            case FilterOp.GT:
+                return numeric_column > self.value
+            case FilterOp.GTE:
+                return numeric_column >= self.value
+
+
+class NumericRangeFilter(BaseModel):
+    """A filter for a range of numeric values (int or float)."""
+
+    op: Literal[FilterOp.BETWEEN]
+    value: NumericRange
+
+    def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+        cast_type = INTEGER if isinstance(self.value.start, int) else DOUBLE_PRECISION
+        numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
+        return and_(numeric_column >= self.value.start, numeric_column <= self.value.end)
+
+
+NumericFilter = Annotated[NumericValueFilter | NumericRangeFilter, Field(discriminator="op")]
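
As with the date filters, the two models are tied together by a discriminated union on `op`, so raw payloads can be validated straight into the correct class. A sketch assuming pydantic v2's `TypeAdapter` (these models already use v2-only constructs such as `model_validator(mode="after")`):

from pydantic import TypeAdapter

from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.numeric_filter import (
    NumericFilter,
    NumericRangeFilter,
    NumericValueFilter,
)

adapter = TypeAdapter(NumericFilter)

# "op" selects the concrete model; "value" must then match that model's shape.
eq_filter = adapter.validate_python({"op": FilterOp.EQ, "value": 1500})
assert isinstance(eq_filter, NumericValueFilter)

between_filter = adapter.validate_python({"op": FilterOp.BETWEEN, "value": {"start": 1, "end": 10}})
assert isinstance(between_filter, NumericRangeFilter)

# NumericRange enforces end > start, so an inverted range raises a ValidationError.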

orchestrator/search/indexing/indexer.py

@@ -0,0 +1,323 @@
+import hashlib
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import contextmanager, nullcontext
+from functools import lru_cache
+from typing import Any
+
+import structlog
+from litellm.utils import encode, get_max_tokens
+from sqlalchemy import delete, tuple_
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.dialects.postgresql.dml import Insert
+from sqlalchemy.orm import Session
+from sqlalchemy_utils.types.ltree import Ltree
+
+from orchestrator.db import db
+from orchestrator.db.models import AiSearchIndex
+from orchestrator.llm_settings import llm_settings
+from orchestrator.search.core.embedding import EmbeddingIndexer
+from orchestrator.search.core.types import ExtractedField, IndexableRecord
+from orchestrator.search.indexing.registry import EntityConfig
+from orchestrator.search.indexing.traverse import DatabaseEntity
+
+logger = structlog.get_logger(__name__)
+
+
+@contextmanager
+def _maybe_begin(session: Session | None) -> Iterator[None]:
+    if session is None:
+        yield
+    else:
+        with session.begin():
+            yield
+
+
+class Indexer:
+    """Index entities into `AiSearchIndex` using streaming reads and batched writes.
+
+    Entities are read from a streaming iterator and accumulated into chunks of
+    size `chunk_size`. For each chunk, the indexer extracts fields, diffs via
+    content hashes, deletes stale paths, and prepares upserts using a two-list
+    buffer:
+    - Embeddable list (STRING fields): maintains a running token count against a
+      token budget (model context window minus a safety margin) and flushes when
+      adding the next item would exceed the budget.
+    - Non-embeddable list: accumulated in parallel and does not contribute to the
+      flush condition.
+    Each flush (or end-of-chunk) emits a single combined UPSERT batch from both
+    lists (wrapped in a per-chunk transaction in non-dry-runs).
+
+    Args:
+        config (EntityConfig): Registry config describing the entity kind,
+            ORM table, and traverser.
+        dry_run (bool): If True, skip DELETE/UPSERT statements and external
+            embedding calls.
+        force_index (bool): If True, ignore existing hashes and reindex all
+            fields for each entity.
+        chunk_size (int): Number of entities to process per batch. Defaults to 1000.
+
+    Notes:
+        - Non-dry-run runs open a write session and wrap each processed chunk in
+          a transaction (`Session.begin()`).
+        - Read queries use the passed session when available, otherwise the
+          generic `db.session`.
+
+    Workflow:
+        1) Stream entities (yield_per=chunk_size) and accumulate into a chunk.
+        2) Begin transaction for the chunk.
+        3) determine_changes() → fields_to_upsert, paths_to_delete.
+        4) Delete stale paths.
+        5) Build UPSERT batches with a two-list buffer:
+           - Embeddable list (STRING): track running token count; flush when next item
+             would exceed the token budget (model max context - safety margin).
+           - Non-embeddable list: accumulate in parallel; does not affect flushing.
+        6) Execute UPSERT for each batch (skip in dry_run).
+        7) Commit transaction (auto on context exit).
+        8) Repeat until the stream is exhausted.
+    """
+
+    def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk_size: int = 1000) -> None:
+        self.config = config
+        self.dry_run = dry_run
+        self.force_index = force_index
+        self.chunk_size = chunk_size
+        self.embedding_model = llm_settings.EMBEDDING_MODEL
+        self.logger = logger.bind(entity_kind=config.entity_kind.value)
+
+    def run(self, entities: Iterable[DatabaseEntity]) -> int:
+        """Orchestrates the entire indexing process."""
+        chunk: list[DatabaseEntity] = []
+        total_records_processed = 0
+        total_identical_records = 0
+
+        write_scope = db.database_scope() if not self.dry_run else nullcontext()
+
+        def flush() -> None:
+            nonlocal total_records_processed, total_identical_records
+            with _maybe_begin(session):
+                processed_in_chunk, identical_in_chunk = self._process_chunk(chunk, session)
+            total_records_processed += processed_in_chunk
+            total_identical_records += identical_in_chunk
+            chunk.clear()
+
+        with write_scope as database:
+            session: Session | None = getattr(database, "session", None)
+            for entity in entities:
+                chunk.append(entity)
+                if len(chunk) >= self.chunk_size:
+                    flush()
+
+            if chunk:
+                flush()
+
+        final_log_message = (
+            f"processed {total_records_processed} records and skipped {total_identical_records} identical records."
+        )
+        self.logger.info(
+            f"Dry run, would have indexed {final_log_message}"
+            if self.dry_run
+            else f"Indexing done, {final_log_message}"
+        )
+        return total_records_processed
+
+    def _process_chunk(self, entity_chunk: list[DatabaseEntity], session: Session | None = None) -> tuple[int, int]:
+        """Process a chunk of entities."""
+        if not entity_chunk:
+            return 0, 0
+
+        fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session)
+
+        if paths_to_delete and session is not None:
+            self.logger.debug(f"Deleting {len(paths_to_delete)} stale records in chunk.")
+            self._execute_batched_deletes(paths_to_delete, session)
+
+        if fields_to_upsert:
+            upsert_stmt = self._get_upsert_statement()
+            batch_generator = self._generate_upsert_batches(fields_to_upsert)
+
+            for batch in batch_generator:
+                if self.dry_run:
+                    self.logger.debug(f"Dry Run: Would upsert {len(batch)} records.")
+                elif batch and session:
+                    session.execute(upsert_stmt, batch)
+
+        return len(fields_to_upsert), identical_count
+
+    def _determine_changes(
+        self, entities: list[DatabaseEntity], session: Session | None = None
+    ) -> tuple[list[tuple[str, ExtractedField]], list[tuple[str, Ltree]], int]:
+        """Identifies all changes across all entities using pre-fetched data."""
+        entity_ids = [str(getattr(e, self.config.pk_name)) for e in entities]
+        read_session = session or db.session
+        existing_hashes = {} if self.force_index else self._get_all_existing_hashes(entity_ids, read_session)
+
+        fields_to_upsert: list[tuple[str, ExtractedField]] = []
+        paths_to_delete: list[tuple[str, Ltree]] = []
+        identical_records_count = 0
+
+        for entity in entities:
+            entity_id = str(getattr(entity, self.config.pk_name))
+            current_fields = self.config.traverser.get_fields(
+                entity, pk_name=self.config.pk_name, root_name=self.config.root_name
+            )
+
+            entity_hashes = existing_hashes.get(entity_id, {})
+            current_paths = set()
+
+            for field in current_fields:
+                current_paths.add(field.path)
+                current_hash = self._compute_content_hash(field.path, field.value, field.value_type)
+                if field.path not in entity_hashes or entity_hashes[field.path] != current_hash:
+                    fields_to_upsert.append((entity_id, field))
+                else:
+                    identical_records_count += 1
+
+            stale_paths = set(entity_hashes.keys()) - current_paths
+            paths_to_delete.extend([(entity_id, Ltree(p)) for p in stale_paths])
+
+        return fields_to_upsert, paths_to_delete, identical_records_count
+
+    def _execute_batched_deletes(self, paths_to_delete: list[tuple[str, Ltree]], session: Session) -> None:
+        """Execute delete operations in batches to avoid PostgreSQL stack depth limits."""
+        for i in range(0, len(paths_to_delete), self.chunk_size):
+            batch = paths_to_delete[i : i + self.chunk_size]
+            delete_stmt = delete(AiSearchIndex).where(tuple_(AiSearchIndex.entity_id, AiSearchIndex.path).in_(batch))
+            session.execute(delete_stmt)
+            self.logger.debug(f"Deleted batch of {len(batch)} records.")
+
+    def _get_all_existing_hashes(self, entity_ids: list[str], session: Session) -> dict[str, dict[str, str]]:
+        """Fetches all existing hashes for a list of entity IDs in a single query."""
+        if not entity_ids:
+            return {}
+
+        results = (
+            session.query(AiSearchIndex.entity_id, AiSearchIndex.path, AiSearchIndex.content_hash)
+            .filter(AiSearchIndex.entity_id.in_(entity_ids))
+            .all()
+        )
+
+        hashes_by_entity: dict[str, dict[str, str]] = {eid: {} for eid in entity_ids}
+        for entity_id, path, content_hash in results:
+            hashes_by_entity[str(entity_id)][str(path)] = content_hash
+        return hashes_by_entity
+
+    def _generate_upsert_batches(
+        self, fields_to_upsert: Iterable[tuple[str, ExtractedField]]
+    ) -> Generator[list[IndexableRecord], None, None]:
+        """Streams through fields, buffers them by token count, and yields batches."""
+        embeddable_buffer: list[tuple[str, ExtractedField]] = []
+        non_embeddable_records: list[IndexableRecord] = []
+        current_tokens = 0
+
+        max_ctx = self._get_max_tokens()
+        safe_margin = int(max_ctx * llm_settings.EMBEDDING_SAFE_MARGIN_PERCENT)
+        token_budget = max(1, max_ctx - safe_margin)
+
+        max_batch_size = None
+        if llm_settings.OPENAI_BASE_URL:  # We are using a local model
+            max_batch_size = llm_settings.EMBEDDING_MAX_BATCH_SIZE
+
+        for entity_id, field in fields_to_upsert:
+            if field.value_type.is_embeddable(field.value):
+                text = self._prepare_text_for_embedding(field)
+                try:
+                    item_tokens = len(encode(model=self.embedding_model, text=text))
+                except Exception as e:
+                    self.logger.warning("Tokenization failed; skipping.", path=field.path, err=str(e))
+                    continue
+
+                if item_tokens > max_ctx:
+                    self.logger.warning(
+                        "Field exceeds context; skipping.", path=field.path, tokens=item_tokens, max_ctx=max_ctx
+                    )
+                    continue
+
+                should_flush = embeddable_buffer and (
+                    current_tokens + item_tokens > token_budget
+                    or (max_batch_size and len(embeddable_buffer) >= max_batch_size)
+                )
+
+                if should_flush:
+                    yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+                    embeddable_buffer.clear()
+                    non_embeddable_records.clear()
+                    current_tokens = 0
+
+                embeddable_buffer.append((entity_id, field))
+                current_tokens += item_tokens
+            else:
+                record = self._make_indexable_record(field, entity_id, embedding=None)
+                non_embeddable_records.append(record)
+
+        if embeddable_buffer or non_embeddable_records:
+            yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+
+    def _flush_buffer(self, embeddable_buffer: list, non_embeddable_records: list) -> list[IndexableRecord]:
+        """Processes and combines buffers into a single batch."""
+        if not embeddable_buffer:
+            return non_embeddable_records
+
+        texts_to_embed = [self._prepare_text_for_embedding(f) for _, f in embeddable_buffer]
+        embeddings = EmbeddingIndexer.get_embeddings_from_api_batch(texts_to_embed, self.dry_run)
+
+        if len(embeddable_buffer) != len(embeddings):
+            raise ValueError(f"Embedding mismatch: sent {len(embeddable_buffer)}, received {len(embeddings)}")
+
+        with_embeddings = [
+            self._make_indexable_record(field, entity_id, embedding)
+            for (entity_id, field), embedding in zip(embeddable_buffer, embeddings)
+        ]
+        return non_embeddable_records + with_embeddings
+
+    def _get_max_tokens(self) -> int:
+        """Gets max tokens, using a fallback from settings if necessary."""
+        try:
+            max_ctx = get_max_tokens(self.embedding_model)
+            if isinstance(max_ctx, int):
+                return max_ctx
+        except Exception:
+            # Allow local(unknown) models to fall back.
+            self.logger.warning("Could not auto-detect max tokens.", model=self.embedding_model)
+
+        max_ctx = llm_settings.EMBEDDING_FALLBACK_MAX_TOKENS
+        if not isinstance(max_ctx, int):
+            raise RuntimeError("Model not recognized and EMBEDDING_FALLBACK_MAX_TOKENS not set.")
+        self.logger.warning("Using configured fallback token limit.", fallback=max_ctx)
+        return max_ctx
+
+    @staticmethod
+    def _prepare_text_for_embedding(field: ExtractedField) -> str:
+        return f"{field.path}: {str(field.value)}"
+
+    @staticmethod
+    def _compute_content_hash(path: str, value: Any, value_type: Any) -> str:
+        v = "" if value is None else str(value)
+        content = f"{path}:{v}:{value_type}"
+        return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+    def _make_indexable_record(
+        self, field: ExtractedField, entity_id: str, embedding: list[float] | None
+    ) -> IndexableRecord:
+        return IndexableRecord(
+            entity_id=entity_id,
+            entity_type=self.config.entity_kind.value,
+            path=Ltree(field.path),
+            value=field.value,
+            value_type=field.value_type,
+            content_hash=self._compute_content_hash(field.path, field.value, field.value_type),
+            embedding=embedding if embedding else None,
+        )
+
+    @staticmethod
+    @lru_cache(maxsize=1)
+    def _get_upsert_statement() -> Insert:
+        stmt = insert(AiSearchIndex)
+        return stmt.on_conflict_do_update(
+            index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path],
+            set_={
+                AiSearchIndex.value: stmt.excluded.value,
+                AiSearchIndex.value_type: stmt.excluded.value_type,
+                AiSearchIndex.content_hash: stmt.excluded.content_hash,
+                AiSearchIndex.embedding: stmt.excluded.embedding,
+            },
+        )
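
The docstring above describes the stream → diff → delete → batched-upsert loop; the entry points that actually drive it are presumably `orchestrator/search/indexing/tasks.py` and `orchestrator/cli/index_llm.py` (both in the file summary, neither shown here). Purely as a sketch of how the classes fit together, using the registry from the next hunk and assuming the registry query supports `yield_per` as the "Stream entities (yield_per=chunk_size)" step implies:

from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing.indexer import Indexer
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY

config = ENTITY_CONFIG_REGISTRY[EntityType.SUBSCRIPTION]
indexer = Indexer(config=config, dry_run=True, force_index=False, chunk_size=500)

# Dry run: fields are extracted, hashed and diffed, but no DELETE/UPSERT or external embedding calls happen.
entities = config.get_all_query().yield_per(indexer.chunk_size)
changed_fields = indexer.run(entities)
print(f"{changed_fields} field records would have been upserted")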

orchestrator/search/indexing/registry.py

@@ -0,0 +1,88 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+from uuid import UUID
+
+from sqlalchemy.orm import Query
+from sqlalchemy.sql import Select
+
+from orchestrator.db import (
+    ProcessTable,
+    ProductTable,
+    SubscriptionTable,
+    WorkflowTable,
+)
+from orchestrator.db.database import BaseModel
+from orchestrator.search.core.types import EntityType
+
+from .traverse import (
+    BaseTraverser,
+    ProcessTraverser,
+    ProductTraverser,
+    SubscriptionTraverser,
+    WorkflowTraverser,
+)
+
+ModelT = TypeVar("ModelT", bound=BaseModel)
+
+
+@dataclass(frozen=True)
+class EntityConfig(Generic[ModelT]):
+    """A container for all configuration related to a specific entity type."""
+
+    entity_kind: EntityType
+    table: type[ModelT]
+
+    traverser: "type[BaseTraverser]"
+    pk_name: str
+    root_name: str
+
+    def get_all_query(self, entity_id: str | None = None) -> Query | Select:
+        query = self.table.query
+        if entity_id:
+            pk_column = getattr(self.table, self.pk_name)
+            query = query.filter(pk_column == UUID(entity_id))
+        return query
+
+
+@dataclass(frozen=True)
+class WorkflowConfig(EntityConfig[WorkflowTable]):
+    """Workflows have a custom select() function that filters out deleted workflows."""
+
+    def get_all_query(self, entity_id: str | None = None) -> Select:
+        query = self.table.select()
+        if entity_id:
+            pk_column = getattr(self.table, self.pk_name)
+            query = query.where(pk_column == UUID(entity_id))
+        return query
+
+
+ENTITY_CONFIG_REGISTRY: dict[EntityType, EntityConfig] = {
+    EntityType.SUBSCRIPTION: EntityConfig(
+        entity_kind=EntityType.SUBSCRIPTION,
+        table=SubscriptionTable,
+        traverser=SubscriptionTraverser,
+        pk_name="subscription_id",
+        root_name="subscription",
+    ),
+    EntityType.PRODUCT: EntityConfig(
+        entity_kind=EntityType.PRODUCT,
+        table=ProductTable,
+        traverser=ProductTraverser,
+        pk_name="product_id",
+        root_name="product",
+    ),
+    EntityType.PROCESS: EntityConfig(
+        entity_kind=EntityType.PROCESS,
+        table=ProcessTable,
+        traverser=ProcessTraverser,
+        pk_name="process_id",
+        root_name="process",
+    ),
+    EntityType.WORKFLOW: WorkflowConfig(
+        entity_kind=EntityType.WORKFLOW,
+        table=WorkflowTable,
+        traverser=WorkflowTraverser,
+        pk_name="workflow_id",
+        root_name="workflow",
+    ),
+}
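
The registry is keyed by `EntityType`, so indexing and search code can stay generic and resolve the primary-key name, root label, traverser and query per entity kind. A small sketch of a single-entity lookup, assuming the standard SQLAlchemy 2.0 execution pattern used elsewhere in orchestrator-core; the UUID below is a placeholder:

from orchestrator.db import db
from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY

cfg = ENTITY_CONFIG_REGISTRY[EntityType.WORKFLOW]
print(cfg.pk_name, cfg.root_name)  # workflow_id workflow

# WorkflowConfig overrides get_all_query() with the table's select(), which excludes deleted workflows.
stmt = cfg.get_all_query(entity_id="11111111-2222-3333-4444-555555555555")
workflow = db.session.execute(stmt).scalars().first()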