orchestrator-core 4.4.0rc3__py3-none-any.whl → 4.5.1a1__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (57)
  1. orchestrator/__init__.py +26 -2
  2. orchestrator/agentic_app.py +84 -0
  3. orchestrator/api/api_v1/api.py +10 -0
  4. orchestrator/api/api_v1/endpoints/search.py +277 -0
  5. orchestrator/app.py +32 -0
  6. orchestrator/cli/index_llm.py +73 -0
  7. orchestrator/cli/main.py +22 -1
  8. orchestrator/cli/resize_embedding.py +135 -0
  9. orchestrator/cli/search_explore.py +208 -0
  10. orchestrator/cli/speedtest.py +151 -0
  11. orchestrator/db/models.py +37 -1
  12. orchestrator/graphql/schemas/process.py +2 -2
  13. orchestrator/graphql/schemas/workflow.py +2 -2
  14. orchestrator/llm_settings.py +51 -0
  15. orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
  16. orchestrator/schedules/scheduler.py +6 -7
  17. orchestrator/schemas/search.py +117 -0
  18. orchestrator/search/__init__.py +12 -0
  19. orchestrator/search/agent/__init__.py +8 -0
  20. orchestrator/search/agent/agent.py +47 -0
  21. orchestrator/search/agent/prompts.py +62 -0
  22. orchestrator/search/agent/state.py +8 -0
  23. orchestrator/search/agent/tools.py +121 -0
  24. orchestrator/search/core/__init__.py +0 -0
  25. orchestrator/search/core/embedding.py +64 -0
  26. orchestrator/search/core/exceptions.py +22 -0
  27. orchestrator/search/core/types.py +281 -0
  28. orchestrator/search/core/validators.py +27 -0
  29. orchestrator/search/docs/index.md +37 -0
  30. orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
  31. orchestrator/search/filters/__init__.py +27 -0
  32. orchestrator/search/filters/base.py +272 -0
  33. orchestrator/search/filters/date_filters.py +75 -0
  34. orchestrator/search/filters/definitions.py +93 -0
  35. orchestrator/search/filters/ltree_filters.py +43 -0
  36. orchestrator/search/filters/numeric_filter.py +60 -0
  37. orchestrator/search/indexing/__init__.py +3 -0
  38. orchestrator/search/indexing/indexer.py +323 -0
  39. orchestrator/search/indexing/registry.py +88 -0
  40. orchestrator/search/indexing/tasks.py +53 -0
  41. orchestrator/search/indexing/traverse.py +322 -0
  42. orchestrator/search/retrieval/__init__.py +3 -0
  43. orchestrator/search/retrieval/builder.py +108 -0
  44. orchestrator/search/retrieval/engine.py +152 -0
  45. orchestrator/search/retrieval/pagination.py +83 -0
  46. orchestrator/search/retrieval/retriever.py +447 -0
  47. orchestrator/search/retrieval/utils.py +106 -0
  48. orchestrator/search/retrieval/validation.py +174 -0
  49. orchestrator/search/schemas/__init__.py +0 -0
  50. orchestrator/search/schemas/parameters.py +116 -0
  51. orchestrator/search/schemas/results.py +63 -0
  52. orchestrator/services/settings_env_variables.py +2 -2
  53. orchestrator/settings.py +1 -1
  54. {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/METADATA +8 -3
  55. {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/RECORD +57 -14
  56. {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/WHEEL +0 -0
  57. {orchestrator_core-4.4.0rc3.dist-info → orchestrator_core-4.5.1a1.dist-info}/licenses/LICENSE +0 -0
orchestrator/search/filters/date_filters.py
@@ -0,0 +1,75 @@
+ from datetime import date, datetime
+ from typing import Annotated, Any, Literal
+
+ from dateutil.parser import parse as dt_parse
+ from pydantic import BaseModel, BeforeValidator, Field, model_validator
+ from sqlalchemy import TIMESTAMP, and_
+ from sqlalchemy import cast as sa_cast
+ from sqlalchemy.sql.elements import ColumnElement
+
+ from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+ def _validate_date_string(v: Any) -> Any:
+     if not isinstance(v, str):
+         return v
+     try:
+         dt_parse(v)
+         return v
+     except Exception as exc:
+         raise ValueError("is not a valid date or datetime string") from exc
+
+
+ DateValue = datetime | date | str
+ ValidatedDateValue = Annotated[DateValue, BeforeValidator(_validate_date_string)]
+
+
+ class DateRange(BaseModel):
+
+     start: ValidatedDateValue
+     end: ValidatedDateValue
+
+     @model_validator(mode="after")
+     def _order(self) -> "DateRange":
+         to_datetime = dt_parse(str(self.end))
+         from_datetime = dt_parse(str(self.start))
+         if to_datetime <= from_datetime:
+             raise ValueError("'to' must be after 'from'")
+         return self
+
+
+ class DateValueFilter(BaseModel):
+     """A filter that operates on a single date value."""
+
+     op: Literal[FilterOp.EQ, FilterOp.NEQ, FilterOp.LT, FilterOp.LTE, FilterOp.GT, FilterOp.GTE]
+     value: ValidatedDateValue
+
+     def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+         date_column = sa_cast(column, TIMESTAMP(timezone=True))
+         match self.op:
+             case FilterOp.EQ:
+                 return date_column == self.value
+             case FilterOp.NEQ:
+                 return date_column != self.value
+             case FilterOp.LT:
+                 return date_column < self.value
+             case FilterOp.LTE:
+                 return date_column <= self.value
+             case FilterOp.GT:
+                 return date_column > self.value
+             case FilterOp.GTE:
+                 return date_column >= self.value
+
+
+ class DateRangeFilter(BaseModel):
+     """A filter that operates on a range of dates."""
+
+     op: Literal[FilterOp.BETWEEN]
+     value: DateRange
+
+     def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+         date_column = sa_cast(column, TIMESTAMP(timezone=True))
+         return and_(date_column >= self.value.start, date_column < self.value.end)
+
+
+ DateFilter = Annotated[DateValueFilter | DateRangeFilter, Field(discriminator="op")]
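A minimal usage sketch (not part of the diff) may help show how this discriminated union is meant to be consumed: the `op` field routes a payload to DateValueFilter or DateRangeFilter, and `to_expression` renders it against a timestamp-capable column. The `example.created_at` column is an illustrative stand-in, not something from the package.

from pydantic import TypeAdapter
from sqlalchemy import Column, DateTime, MetaData, Table

from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.date_filters import DateFilter

# Illustrative table; any column castable to TIMESTAMP behaves the same way.
example = Table("example", MetaData(), Column("created_at", DateTime(timezone=True)))

adapter = TypeAdapter(DateFilter)

# FilterOp.BETWEEN selects DateRangeFilter; all other operators select DateValueFilter.
date_range = adapter.validate_python(
    {"op": FilterOp.BETWEEN, "value": {"start": "2025-01-01", "end": "2025-02-01"}}
)
print(date_range.to_expression(example.c.created_at, path="example.created_at"))
# Roughly: CAST(example.created_at AS TIMESTAMP WITH TIME ZONE) >= :param_1 AND ... < :param_2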
orchestrator/search/filters/definitions.py
@@ -0,0 +1,93 @@
+ from orchestrator.search.core.types import FieldType, FilterOp, UIType
+ from orchestrator.search.schemas.results import TypeDefinition, ValueSchema
+
+
+ def operators_for(ft: FieldType) -> list[FilterOp]:
+     """Return the list of valid operators for a given FieldType."""
+     return list(value_schema_for(ft).keys())
+
+
+ def component_operators() -> dict[FilterOp, ValueSchema]:
+     """Return operators available for path components."""
+     return {
+         FilterOp.HAS_COMPONENT: ValueSchema(kind=UIType.COMPONENT),
+         FilterOp.NOT_HAS_COMPONENT: ValueSchema(kind=UIType.COMPONENT),
+     }
+
+
+ def value_schema_for(ft: FieldType) -> dict[FilterOp, ValueSchema]:
+     """Return the value schema map for a given FieldType."""
+     if ft in (FieldType.INTEGER, FieldType.FLOAT):
+         return {
+             FilterOp.EQ: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.NEQ: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.LT: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.LTE: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.GT: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.GTE: ValueSchema(kind=UIType.NUMBER),
+             FilterOp.BETWEEN: ValueSchema(
+                 kind="object",
+                 fields={
+                     "start": ValueSchema(kind=UIType.NUMBER),
+                     "end": ValueSchema(kind=UIType.NUMBER),
+                 },
+             ),
+         }
+
+     if ft == FieldType.BOOLEAN:
+         return {
+             FilterOp.EQ: ValueSchema(kind=UIType.BOOLEAN),
+             FilterOp.NEQ: ValueSchema(kind=UIType.BOOLEAN),
+         }
+
+     if ft == FieldType.DATETIME:
+         return {
+             FilterOp.EQ: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.NEQ: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.LT: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.LTE: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.GT: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.GTE: ValueSchema(kind=UIType.DATETIME),
+             FilterOp.BETWEEN: ValueSchema(
+                 kind="object",
+                 fields={
+                     "start": ValueSchema(kind=UIType.DATETIME),
+                     "end": ValueSchema(kind=UIType.DATETIME),
+                 },
+             ),
+         }
+
+     return {
+         FilterOp.EQ: ValueSchema(kind=UIType.STRING),
+         FilterOp.NEQ: ValueSchema(kind=UIType.STRING),
+     }
+
+
+ def generate_definitions() -> dict[UIType, TypeDefinition]:
+     """Generate the full definitions dictionary for all UI types."""
+     definitions: dict[UIType, TypeDefinition] = {}
+
+     for ui_type in UIType:
+         if ui_type == UIType.COMPONENT:
+             # Special case for component filtering
+             comp_ops = component_operators()
+             definitions[ui_type] = TypeDefinition(
+                 operators=list(comp_ops.keys()),
+                 valueSchema=comp_ops,
+             )
+         else:
+             # Regular field types
+             if ui_type == UIType.NUMBER:
+                 rep_ft = FieldType.INTEGER
+             elif ui_type == UIType.DATETIME:
+                 rep_ft = FieldType.DATETIME
+             elif ui_type == UIType.BOOLEAN:
+                 rep_ft = FieldType.BOOLEAN
+             else:
+                 rep_ft = FieldType.STRING
+
+             definitions[ui_type] = TypeDefinition(
+                 operators=operators_for(rep_ft),
+                 valueSchema=value_schema_for(rep_ft),
+             )
+     return definitions
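As a quick sanity check of the mapping above, a small sketch (assuming only the import path, which follows the file list) can dump the operators each UI type ends up exposing:

from orchestrator.search.filters.definitions import generate_definitions

definitions = generate_definitions()
for ui_type, type_def in definitions.items():
    # NUMBER and DATETIME get the comparison operators plus BETWEEN,
    # BOOLEAN and STRING get the equality pair, COMPONENT gets the path-component pair.
    print(ui_type, type_def.operators)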
orchestrator/search/filters/ltree_filters.py
@@ -0,0 +1,43 @@
+ from typing import Literal
+
+ from pydantic import BaseModel, Field
+ from sqlalchemy import TEXT, bindparam
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy_utils.types.ltree import Ltree
+
+ from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+ class LtreeFilter(BaseModel):
+     """Filter for ltree path operations."""
+
+     op: Literal[
+         FilterOp.MATCHES_LQUERY,
+         FilterOp.IS_ANCESTOR,
+         FilterOp.IS_DESCENDANT,
+         FilterOp.PATH_MATCH,
+         FilterOp.HAS_COMPONENT,
+         FilterOp.NOT_HAS_COMPONENT,
+         FilterOp.ENDS_WITH,
+     ]
+     value: str = Field(description="The ltree path or lquery pattern to compare against.")
+
+     def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+         """Converts the filter condition into a SQLAlchemy expression."""
+         match self.op:
+             case FilterOp.IS_DESCENDANT:
+                 ltree_value = Ltree(self.value)
+                 return column.op("<@")(ltree_value)
+             case FilterOp.IS_ANCESTOR:
+                 ltree_value = Ltree(self.value)
+                 return column.op("@>")(ltree_value)
+             case FilterOp.MATCHES_LQUERY:
+                 param = bindparam(None, self.value, type_=TEXT)
+                 return column.op("~")(param)
+             case FilterOp.PATH_MATCH:
+                 ltree_value = Ltree(path)
+                 return column == ltree_value
+             case FilterOp.HAS_COMPONENT | FilterOp.NOT_HAS_COMPONENT:
+                 return column.op("~")(bindparam(None, f"*.{self.value}.*", type_=TEXT))
+             case FilterOp.ENDS_WITH:
+                 return column.op("~")(bindparam(None, f"*.{self.value}", type_=TEXT))
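To see what these operators compile to, a short sketch may help; it assumes the `AiSearchIndex.path` ltree column from this release as a target, and the value strings are made up for illustration:

from sqlalchemy.dialects import postgresql

from orchestrator.db.models import AiSearchIndex
from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.ltree_filters import LtreeFilter

descendant = LtreeFilter(op=FilterOp.IS_DESCENDANT, value="subscription.product")
component = LtreeFilter(op=FilterOp.HAS_COMPONENT, value="port_mode")

for flt in (descendant, component):
    # IS_DESCENDANT compiles to "<@", HAS_COMPONENT to a "~" lquery match on "*.port_mode.*".
    expr = flt.to_expression(AiSearchIndex.path, path="")
    print(expr.compile(dialect=postgresql.dialect()))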
orchestrator/search/filters/numeric_filter.py
@@ -0,0 +1,60 @@
+ from typing import Annotated, Any, Literal
+
+ from pydantic import BaseModel, Field, model_validator
+ from sqlalchemy import DOUBLE_PRECISION, INTEGER, and_
+ from sqlalchemy import cast as sa_cast
+ from sqlalchemy.sql.elements import ColumnElement
+ from typing_extensions import Self
+
+ from orchestrator.search.core.types import FilterOp, SQLAColumn
+
+
+ class NumericRange(BaseModel):
+     start: int | float
+     end: int | float
+
+     @model_validator(mode="after")
+     def validate_order(self) -> Self:
+         if self.end <= self.start:
+             raise ValueError("'end' must be greater than 'start'")
+         return self
+
+
+ class NumericValueFilter(BaseModel):
+     """A filter for single numeric value comparisons (int or float)."""
+
+     op: Literal[FilterOp.EQ, FilterOp.NEQ, FilterOp.LT, FilterOp.LTE, FilterOp.GT, FilterOp.GTE]
+     value: int | float
+
+     def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+         cast_type = INTEGER if isinstance(self.value, int) else DOUBLE_PRECISION
+         numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
+         match self.op:
+
+             case FilterOp.EQ:
+                 return numeric_column == self.value
+             case FilterOp.NEQ:
+                 return numeric_column != self.value
+             case FilterOp.LT:
+                 return numeric_column < self.value
+             case FilterOp.LTE:
+                 return numeric_column <= self.value
+             case FilterOp.GT:
+                 return numeric_column > self.value
+             case FilterOp.GTE:
+                 return numeric_column >= self.value
+
+
+ class NumericRangeFilter(BaseModel):
+     """A filter for a range of numeric values (int or float)."""
+
+     op: Literal[FilterOp.BETWEEN]
+     value: NumericRange
+
+     def to_expression(self, column: SQLAColumn, path: str) -> ColumnElement[bool]:
+         cast_type = INTEGER if isinstance(self.value.start, int) else DOUBLE_PRECISION
+         numeric_column: ColumnElement[Any] = sa_cast(column, cast_type)
+         return and_(numeric_column >= self.value.start, numeric_column <= self.value.end)
+
+
+ NumericFilter = Annotated[NumericValueFilter | NumericRangeFilter, Field(discriminator="op")]
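The same pattern applies to numbers; note that the SQL cast follows the Python type of the supplied value, as in this sketch (the target column is illustrative, not prescribed by the package):

from orchestrator.db.models import AiSearchIndex
from orchestrator.search.core.types import FilterOp
from orchestrator.search.filters.numeric_filter import NumericRange, NumericRangeFilter, NumericValueFilter

int_filter = NumericValueFilter(op=FilterOp.GTE, value=10)      # column is cast to INTEGER
float_filter = NumericValueFilter(op=FilterOp.GTE, value=10.5)  # column is cast to DOUBLE PRECISION
range_filter = NumericRangeFilter(op=FilterOp.BETWEEN, value=NumericRange(start=1, end=100))

for flt in (int_filter, float_filter, range_filter):
    print(flt.to_expression(AiSearchIndex.value, path=""))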
orchestrator/search/indexing/__init__.py
@@ -0,0 +1,3 @@
+ from .tasks import run_indexing_for_entity
+
+ __all__ = ["run_indexing_for_entity"]
orchestrator/search/indexing/indexer.py
@@ -0,0 +1,323 @@
+ import hashlib
+ from collections.abc import Generator, Iterable, Iterator
+ from contextlib import contextmanager, nullcontext
+ from functools import lru_cache
+ from typing import Any
+
+ import structlog
+ from litellm.utils import encode, get_max_tokens
+ from sqlalchemy import delete, tuple_
+ from sqlalchemy.dialects.postgresql import insert
+ from sqlalchemy.dialects.postgresql.dml import Insert
+ from sqlalchemy.orm import Session
+ from sqlalchemy_utils.types.ltree import Ltree
+
+ from orchestrator.db import db
+ from orchestrator.db.models import AiSearchIndex
+ from orchestrator.llm_settings import llm_settings
+ from orchestrator.search.core.embedding import EmbeddingIndexer
+ from orchestrator.search.core.types import ExtractedField, IndexableRecord
+ from orchestrator.search.indexing.registry import EntityConfig
+ from orchestrator.search.indexing.traverse import DatabaseEntity
+
+ logger = structlog.get_logger(__name__)
+
+
+ @contextmanager
+ def _maybe_begin(session: Session | None) -> Iterator[None]:
+     if session is None:
+         yield
+     else:
+         with session.begin():
+             yield
+
+
+ class Indexer:
+     """Index entities into `AiSearchIndex` using streaming reads and batched writes.
+
+     Entities are read from a streaming iterator and accumulated into chunks of
+     size `chunk_size`. For each chunk, the indexer extracts fields, diffs via
+     content hashes, deletes stale paths, and prepares upserts using a two-list
+     buffer:
+     - Embeddable list (STRING fields): maintains a running token count against a
+       token budget (model context window minus a safety margin) and flushes when
+       adding the next item would exceed the budget.
+     - Non-embeddable list: accumulated in parallel and does not contribute to the
+       flush condition.
+     Each flush (or end-of-chunk) emits a single combined UPSERT batch from both
+     lists (wrapped in a per-chunk transaction in non-dry-runs).
+
+     Args:
+         config (EntityConfig): Registry config describing the entity kind,
+             ORM table, and traverser.
+         dry_run (bool): If True, skip DELETE/UPSERT statements and external
+             embedding calls.
+         force_index (bool): If True, ignore existing hashes and reindex all
+             fields for each entity.
+         chunk_size (int): Number of entities to process per batch. Defaults to 1000.
+
+     Notes:
+         - Non-dry-run runs open a write session and wrap each processed chunk in
+           a transaction (`Session.begin()`).
+         - Read queries use the passed session when available, otherwise the
+           generic `db.session`.
+
+     Workflow:
+         1) Stream entities (yield_per=chunk_size) and accumulate into a chunk.
+         2) Begin transaction for the chunk.
+         3) determine_changes() → fields_to_upsert, paths_to_delete.
+         4) Delete stale paths.
+         5) Build UPSERT batches with a two-list buffer:
+            - Embeddable list (STRING): track running token count; flush when next item
+              would exceed the token budget (model max context - safety margin).
+            - Non-embeddable list: accumulate in parallel; does not affect flushing.
+         6) Execute UPSERT for each batch (skip in dry_run).
+         7) Commit transaction (auto on context exit).
+         8) Repeat until the stream is exhausted.
+     """
+
+     def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk_size: int = 1000) -> None:
+         self.config = config
+         self.dry_run = dry_run
+         self.force_index = force_index
+         self.chunk_size = chunk_size
+         self.embedding_model = llm_settings.EMBEDDING_MODEL
+         self.logger = logger.bind(entity_kind=config.entity_kind.value)
+
+     def run(self, entities: Iterable[DatabaseEntity]) -> int:
+         """Orchestrates the entire indexing process."""
+         chunk: list[DatabaseEntity] = []
+         total_records_processed = 0
+         total_identical_records = 0
+
+         write_scope = db.database_scope() if not self.dry_run else nullcontext()
+
+         def flush() -> None:
+             nonlocal total_records_processed, total_identical_records
+             with _maybe_begin(session):
+                 processed_in_chunk, identical_in_chunk = self._process_chunk(chunk, session)
+                 total_records_processed += processed_in_chunk
+                 total_identical_records += identical_in_chunk
+             chunk.clear()
+
+         with write_scope as database:
+             session: Session | None = getattr(database, "session", None)
+             for entity in entities:
+                 chunk.append(entity)
+                 if len(chunk) >= self.chunk_size:
+                     flush()
+
+             if chunk:
+                 flush()
+
+         final_log_message = (
+             f"processed {total_records_processed} records and skipped {total_identical_records} identical records."
+         )
+         self.logger.info(
+             f"Dry run, would have indexed {final_log_message}"
+             if self.dry_run
+             else f"Indexing done, {final_log_message}"
+         )
+         return total_records_processed
+
+     def _process_chunk(self, entity_chunk: list[DatabaseEntity], session: Session | None = None) -> tuple[int, int]:
+         """Process a chunk of entities."""
+         if not entity_chunk:
+             return 0, 0
+
+         fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session)
+
+         if paths_to_delete and session is not None:
+             self.logger.debug(f"Deleting {len(paths_to_delete)} stale records in chunk.")
+             self._execute_batched_deletes(paths_to_delete, session)
+
+         if fields_to_upsert:
+             upsert_stmt = self._get_upsert_statement()
+             batch_generator = self._generate_upsert_batches(fields_to_upsert)
+
+             for batch in batch_generator:
+                 if self.dry_run:
+                     self.logger.debug(f"Dry Run: Would upsert {len(batch)} records.")
+                 elif batch and session:
+                     session.execute(upsert_stmt, batch)
+
+         return len(fields_to_upsert), identical_count
+
+     def _determine_changes(
+         self, entities: list[DatabaseEntity], session: Session | None = None
+     ) -> tuple[list[tuple[str, ExtractedField]], list[tuple[str, Ltree]], int]:
+         """Identifies all changes across all entities using pre-fetched data."""
+         entity_ids = [str(getattr(e, self.config.pk_name)) for e in entities]
+         read_session = session or db.session
+         existing_hashes = {} if self.force_index else self._get_all_existing_hashes(entity_ids, read_session)
+
+         fields_to_upsert: list[tuple[str, ExtractedField]] = []
+         paths_to_delete: list[tuple[str, Ltree]] = []
+         identical_records_count = 0
+
+         for entity in entities:
+             entity_id = str(getattr(entity, self.config.pk_name))
+             current_fields = self.config.traverser.get_fields(
+                 entity, pk_name=self.config.pk_name, root_name=self.config.root_name
+             )
+
+             entity_hashes = existing_hashes.get(entity_id, {})
+             current_paths = set()
+
+             for field in current_fields:
+                 current_paths.add(field.path)
+                 current_hash = self._compute_content_hash(field.path, field.value, field.value_type)
+                 if field.path not in entity_hashes or entity_hashes[field.path] != current_hash:
+                     fields_to_upsert.append((entity_id, field))
+                 else:
+                     identical_records_count += 1
+
+             stale_paths = set(entity_hashes.keys()) - current_paths
+             paths_to_delete.extend([(entity_id, Ltree(p)) for p in stale_paths])
+
+         return fields_to_upsert, paths_to_delete, identical_records_count
+
+     def _execute_batched_deletes(self, paths_to_delete: list[tuple[str, Ltree]], session: Session) -> None:
+         """Execute delete operations in batches to avoid PostgreSQL stack depth limits."""
+         for i in range(0, len(paths_to_delete), self.chunk_size):
+             batch = paths_to_delete[i : i + self.chunk_size]
+             delete_stmt = delete(AiSearchIndex).where(tuple_(AiSearchIndex.entity_id, AiSearchIndex.path).in_(batch))
+             session.execute(delete_stmt)
+             self.logger.debug(f"Deleted batch of {len(batch)} records.")
+
+     def _get_all_existing_hashes(self, entity_ids: list[str], session: Session) -> dict[str, dict[str, str]]:
+         """Fetches all existing hashes for a list of entity IDs in a single query."""
+         if not entity_ids:
+             return {}
+
+         results = (
+             session.query(AiSearchIndex.entity_id, AiSearchIndex.path, AiSearchIndex.content_hash)
+             .filter(AiSearchIndex.entity_id.in_(entity_ids))
+             .all()
+         )
+
+         hashes_by_entity: dict[str, dict[str, str]] = {eid: {} for eid in entity_ids}
+         for entity_id, path, content_hash in results:
+             hashes_by_entity[str(entity_id)][str(path)] = content_hash
+         return hashes_by_entity
+
+     def _generate_upsert_batches(
+         self, fields_to_upsert: Iterable[tuple[str, ExtractedField]]
+     ) -> Generator[list[IndexableRecord], None, None]:
+         """Streams through fields, buffers them by token count, and yields batches."""
+         embeddable_buffer: list[tuple[str, ExtractedField]] = []
+         non_embeddable_records: list[IndexableRecord] = []
+         current_tokens = 0
+
+         max_ctx = self._get_max_tokens()
+         safe_margin = int(max_ctx * llm_settings.EMBEDDING_SAFE_MARGIN_PERCENT)
+         token_budget = max(1, max_ctx - safe_margin)
+
+         max_batch_size = None
+         if llm_settings.OPENAI_BASE_URL:  # We are using a local model
+             max_batch_size = llm_settings.EMBEDDING_MAX_BATCH_SIZE
+
+         for entity_id, field in fields_to_upsert:
+             if field.value_type.is_embeddable(field.value):
+                 text = self._prepare_text_for_embedding(field)
+                 try:
+                     item_tokens = len(encode(model=self.embedding_model, text=text))
+                 except Exception as e:
+                     self.logger.warning("Tokenization failed; skipping.", path=field.path, err=str(e))
+                     continue
+
+                 if item_tokens > max_ctx:
+                     self.logger.warning(
+                         "Field exceeds context; skipping.", path=field.path, tokens=item_tokens, max_ctx=max_ctx
+                     )
+                     continue
+
+                 should_flush = embeddable_buffer and (
+                     current_tokens + item_tokens > token_budget
+                     or (max_batch_size and len(embeddable_buffer) >= max_batch_size)
+                 )
+
+                 if should_flush:
+                     yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+                     embeddable_buffer.clear()
+                     non_embeddable_records.clear()
+                     current_tokens = 0
+
+                 embeddable_buffer.append((entity_id, field))
+                 current_tokens += item_tokens
+             else:
+                 record = self._make_indexable_record(field, entity_id, embedding=None)
+                 non_embeddable_records.append(record)
+
+         if embeddable_buffer or non_embeddable_records:
+             yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+
+     def _flush_buffer(self, embeddable_buffer: list, non_embeddable_records: list) -> list[IndexableRecord]:
+         """Processes and combines buffers into a single batch."""
+         if not embeddable_buffer:
+             return non_embeddable_records
+
+         texts_to_embed = [self._prepare_text_for_embedding(f) for _, f in embeddable_buffer]
+         embeddings = EmbeddingIndexer.get_embeddings_from_api_batch(texts_to_embed, self.dry_run)
+
+         if len(embeddable_buffer) != len(embeddings):
+             raise ValueError(f"Embedding mismatch: sent {len(embeddable_buffer)}, received {len(embeddings)}")
+
+         with_embeddings = [
+             self._make_indexable_record(field, entity_id, embedding)
+             for (entity_id, field), embedding in zip(embeddable_buffer, embeddings)
+         ]
+         return non_embeddable_records + with_embeddings
+
+     def _get_max_tokens(self) -> int:
+         """Gets max tokens, using a fallback from settings if necessary."""
+         try:
+             max_ctx = get_max_tokens(self.embedding_model)
+             if isinstance(max_ctx, int):
+                 return max_ctx
+         except Exception:
+             # Allow local(unknown) models to fall back.
+             self.logger.warning("Could not auto-detect max tokens.", model=self.embedding_model)
+
+         max_ctx = llm_settings.EMBEDDING_FALLBACK_MAX_TOKENS
+         if not isinstance(max_ctx, int):
+             raise RuntimeError("Model not recognized and EMBEDDING_FALLBACK_MAX_TOKENS not set.")
+         self.logger.warning("Using configured fallback token limit.", fallback=max_ctx)
+         return max_ctx
+
+     @staticmethod
+     def _prepare_text_for_embedding(field: ExtractedField) -> str:
+         return f"{field.path}: {str(field.value)}"
+
+     @staticmethod
+     def _compute_content_hash(path: str, value: Any, value_type: Any) -> str:
+         v = "" if value is None else str(value)
+         content = f"{path}:{v}:{value_type}"
+         return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+     def _make_indexable_record(
+         self, field: ExtractedField, entity_id: str, embedding: list[float] | None
+     ) -> IndexableRecord:
+         return IndexableRecord(
+             entity_id=entity_id,
+             entity_type=self.config.entity_kind.value,
+             path=Ltree(field.path),
+             value=field.value,
+             value_type=field.value_type,
+             content_hash=self._compute_content_hash(field.path, field.value, field.value_type),
+             embedding=embedding if embedding else None,
+         )
+
+     @staticmethod
+     @lru_cache(maxsize=1)
+     def _get_upsert_statement() -> Insert:
+         stmt = insert(AiSearchIndex)
+         return stmt.on_conflict_do_update(
+             index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path],
+             set_={
+                 AiSearchIndex.value: stmt.excluded.value,
+                 AiSearchIndex.value_type: stmt.excluded.value_type,
+                 AiSearchIndex.content_hash: stmt.excluded.content_hash,
+                 AiSearchIndex.embedding: stmt.excluded.embedding,
+             },
+         )
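The change detection in `_determine_changes` hinges on `_compute_content_hash` being a pure function of `(path, value, value_type)`; a tiny sketch with made-up values shows the contract:

from orchestrator.search.indexing.indexer import Indexer

h1 = Indexer._compute_content_hash("subscription.description", "Core link", "STRING")
h2 = Indexer._compute_content_hash("subscription.description", "Core link", "STRING")
h3 = Indexer._compute_content_hash("subscription.description", "Core link (renamed)", "STRING")

# Unchanged fields hash identically and are skipped; any change in value forces an upsert.
assert h1 == h2
assert h1 != h3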
orchestrator/search/indexing/registry.py
@@ -0,0 +1,88 @@
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+ from uuid import UUID
+
+ from sqlalchemy.orm import Query
+ from sqlalchemy.sql import Select
+
+ from orchestrator.db import (
+     ProcessTable,
+     ProductTable,
+     SubscriptionTable,
+     WorkflowTable,
+ )
+ from orchestrator.db.database import BaseModel
+ from orchestrator.search.core.types import EntityType
+
+ from .traverse import (
+     BaseTraverser,
+     ProcessTraverser,
+     ProductTraverser,
+     SubscriptionTraverser,
+     WorkflowTraverser,
+ )
+
+ ModelT = TypeVar("ModelT", bound=BaseModel)
+
+
+ @dataclass(frozen=True)
+ class EntityConfig(Generic[ModelT]):
+     """A container for all configuration related to a specific entity type."""
+
+     entity_kind: EntityType
+     table: type[ModelT]
+
+     traverser: "type[BaseTraverser]"
+     pk_name: str
+     root_name: str
+
+     def get_all_query(self, entity_id: str | None = None) -> Query | Select:
+         query = self.table.query
+         if entity_id:
+             pk_column = getattr(self.table, self.pk_name)
+             query = query.filter(pk_column == UUID(entity_id))
+         return query
+
+
+ @dataclass(frozen=True)
+ class WorkflowConfig(EntityConfig[WorkflowTable]):
+     """Workflows have a custom select() function that filters out deleted workflows."""
+
+     def get_all_query(self, entity_id: str | None = None) -> Select:
+         query = self.table.select()
+         if entity_id:
+             pk_column = getattr(self.table, self.pk_name)
+             query = query.where(pk_column == UUID(entity_id))
+         return query
+
+
+ ENTITY_CONFIG_REGISTRY: dict[EntityType, EntityConfig] = {
+     EntityType.SUBSCRIPTION: EntityConfig(
+         entity_kind=EntityType.SUBSCRIPTION,
+         table=SubscriptionTable,
+         traverser=SubscriptionTraverser,
+         pk_name="subscription_id",
+         root_name="subscription",
+     ),
+     EntityType.PRODUCT: EntityConfig(
+         entity_kind=EntityType.PRODUCT,
+         table=ProductTable,
+         traverser=ProductTraverser,
+         pk_name="product_id",
+         root_name="product",
+     ),
+     EntityType.PROCESS: EntityConfig(
+         entity_kind=EntityType.PROCESS,
+         table=ProcessTable,
+         traverser=ProcessTraverser,
+         pk_name="process_id",
+         root_name="process",
+     ),
+     EntityType.WORKFLOW: WorkflowConfig(
+         entity_kind=EntityType.WORKFLOW,
+         table=WorkflowTable,
+         traverser=WorkflowTraverser,
+         pk_name="workflow_id",
+         root_name="workflow",
+     ),
+ }
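Putting the registry and the indexer together, a dry-run could look like the sketch below; the `yield_per(1000)` streaming and the direct iteration over the query are assumptions about how a caller might wire this up, not code from the package:

from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing.indexer import Indexer
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY

config = ENTITY_CONFIG_REGISTRY[EntityType.SUBSCRIPTION]

# Stream subscriptions in chunks; nothing is written or embedded because dry_run=True.
entities = config.get_all_query().yield_per(1000)
indexer = Indexer(config=config, dry_run=True, force_index=False, chunk_size=1000)
processed = indexer.run(entities)
print(f"{processed} field records would be upserted")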