orchestrator-core 4.4.1__py3-none-any.whl → 4.5.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +26 -2
- orchestrator/agentic_app.py +84 -0
- orchestrator/api/api_v1/api.py +10 -0
- orchestrator/api/api_v1/endpoints/search.py +277 -0
- orchestrator/app.py +32 -0
- orchestrator/cli/index_llm.py +73 -0
- orchestrator/cli/main.py +22 -1
- orchestrator/cli/resize_embedding.py +135 -0
- orchestrator/cli/search_explore.py +208 -0
- orchestrator/cli/speedtest.py +151 -0
- orchestrator/db/models.py +37 -1
- orchestrator/llm_settings.py +51 -0
- orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
- orchestrator/schemas/search.py +117 -0
- orchestrator/search/__init__.py +12 -0
- orchestrator/search/agent/__init__.py +8 -0
- orchestrator/search/agent/agent.py +47 -0
- orchestrator/search/agent/prompts.py +87 -0
- orchestrator/search/agent/state.py +8 -0
- orchestrator/search/agent/tools.py +236 -0
- orchestrator/search/core/__init__.py +0 -0
- orchestrator/search/core/embedding.py +64 -0
- orchestrator/search/core/exceptions.py +22 -0
- orchestrator/search/core/types.py +281 -0
- orchestrator/search/core/validators.py +27 -0
- orchestrator/search/docs/index.md +37 -0
- orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
- orchestrator/search/filters/__init__.py +27 -0
- orchestrator/search/filters/base.py +275 -0
- orchestrator/search/filters/date_filters.py +75 -0
- orchestrator/search/filters/definitions.py +93 -0
- orchestrator/search/filters/ltree_filters.py +43 -0
- orchestrator/search/filters/numeric_filter.py +60 -0
- orchestrator/search/indexing/__init__.py +3 -0
- orchestrator/search/indexing/indexer.py +323 -0
- orchestrator/search/indexing/registry.py +88 -0
- orchestrator/search/indexing/tasks.py +53 -0
- orchestrator/search/indexing/traverse.py +322 -0
- orchestrator/search/retrieval/__init__.py +3 -0
- orchestrator/search/retrieval/builder.py +113 -0
- orchestrator/search/retrieval/engine.py +152 -0
- orchestrator/search/retrieval/pagination.py +83 -0
- orchestrator/search/retrieval/retriever.py +447 -0
- orchestrator/search/retrieval/utils.py +106 -0
- orchestrator/search/retrieval/validation.py +174 -0
- orchestrator/search/schemas/__init__.py +0 -0
- orchestrator/search/schemas/parameters.py +116 -0
- orchestrator/search/schemas/results.py +64 -0
- orchestrator/services/settings_env_variables.py +2 -2
- orchestrator/settings.py +1 -1
- {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.0a2.dist-info}/METADATA +8 -3
- {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.0a2.dist-info}/RECORD +54 -11
- {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.0a2.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.0a2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
from collections.abc import Generator, Iterable, Iterator
|
|
3
|
+
from contextlib import contextmanager, nullcontext
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
from litellm.utils import encode, get_max_tokens
|
|
9
|
+
from sqlalchemy import delete, tuple_
|
|
10
|
+
from sqlalchemy.dialects.postgresql import insert
|
|
11
|
+
from sqlalchemy.dialects.postgresql.dml import Insert
|
|
12
|
+
from sqlalchemy.orm import Session
|
|
13
|
+
from sqlalchemy_utils.types.ltree import Ltree
|
|
14
|
+
|
|
15
|
+
from orchestrator.db import db
|
|
16
|
+
from orchestrator.db.models import AiSearchIndex
|
|
17
|
+
from orchestrator.llm_settings import llm_settings
|
|
18
|
+
from orchestrator.search.core.embedding import EmbeddingIndexer
|
|
19
|
+
from orchestrator.search.core.types import ExtractedField, IndexableRecord
|
|
20
|
+
from orchestrator.search.indexing.registry import EntityConfig
|
|
21
|
+
from orchestrator.search.indexing.traverse import DatabaseEntity
|
|
22
|
+
|
|
23
|
+
# Module-level structlog logger; each Indexer binds its entity kind onto it.
logger = structlog.get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@contextmanager
def _maybe_begin(session: Session | None) -> Iterator[None]:
    """Run the managed block inside ``session.begin()`` when a session is supplied.

    With ``session=None`` (the dry-run case) the context does nothing and simply
    yields. Otherwise the block executes within the transaction opened by
    ``session.begin()``, which is closed when the block exits.

    Args:
        session: Write session to open a transaction on, or None for a no-op.
    """
    if session is not None:
        with session.begin():
            yield
        return
    yield
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Indexer:
    """Index entities into `AiSearchIndex` using streaming reads and batched writes.

    Entities are read from a streaming iterator and accumulated into chunks of
    size `chunk_size`. For each chunk, the indexer extracts fields, diffs via
    content hashes, deletes stale paths, and prepares upserts using a two-list
    buffer:
      - Embeddable list (fields for which `value_type.is_embeddable(value)` is
        true): maintains a running token count against a token budget (model
        context window minus a safety margin) and flushes when adding the next
        item would exceed the budget.
      - Non-embeddable list: accumulated in parallel and does not contribute to
        the flush condition.
    Each flush (or end-of-chunk) emits a single combined UPSERT batch from both
    lists (wrapped in a per-chunk transaction in non-dry-runs).

    Args:
        config (EntityConfig): Registry config describing the entity kind,
            ORM table, and traverser.
        dry_run (bool): If True, skip DELETE/UPSERT statements; the flag is also
            forwarded to the embedding helper (expected to avoid external calls).
        force_index (bool): If True, reindex every field regardless of its stored
            hash. Stale paths are still detected and deleted.
        chunk_size (int): Number of entities to process per batch. Defaults to 1000.

    Notes:
        - Non-dry-run runs open a write session and wrap each processed chunk in
          a transaction (`Session.begin()`).
        - Read queries use the passed session when available, otherwise the
          generic `db.session`.
        - The processed count returned by `run()` includes only records that
          were placed in an UPSERT batch. Embeddable fields skipped because
          tokenization failed or they exceed the model context are logged and
          not counted.

    Workflow:
        1) Stream entities and group them into chunks of `chunk_size`.
        2) Begin transaction for the chunk (non-dry-run only).
        3) _determine_changes() → fields_to_upsert, paths_to_delete.
        4) Delete stale paths.
        5) Build UPSERT batches with the two-list buffer described above.
        6) Execute UPSERT for each batch (skip in dry_run).
        7) Commit transaction (on context exit).
        8) Repeat until the stream is exhausted.
    """

    def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool, chunk_size: int = 1000) -> None:
        self.config = config
        self.dry_run = dry_run
        self.force_index = force_index
        self.chunk_size = chunk_size
        self.embedding_model = llm_settings.EMBEDDING_MODEL
        self.logger = logger.bind(entity_kind=config.entity_kind.value)

    def run(self, entities: Iterable[DatabaseEntity]) -> int:
        """Index all `entities`, one chunk at a time.

        Outside dry runs, each chunk is processed in its own transaction.

        Args:
            entities: Iterable (typically a streaming result) of ORM rows.

        Returns:
            Number of index records upserted (in a dry run: that would have been).
        """
        total_records_processed = 0
        total_identical_records = 0

        # Dry runs never write, so they do not open a dedicated write scope.
        write_scope = db.database_scope() if not self.dry_run else nullcontext()

        with write_scope as database:
            # nullcontext() yields None, so a dry run ends up with session=None.
            session: Session | None = getattr(database, "session", None)
            for chunk in self._iter_chunks(entities):
                with _maybe_begin(session):
                    processed_in_chunk, identical_in_chunk = self._process_chunk(chunk, session)
                total_records_processed += processed_in_chunk
                total_identical_records += identical_in_chunk

        final_log_message = (
            f"processed {total_records_processed} records and skipped {total_identical_records} identical records."
        )
        self.logger.info(
            f"Dry run, would have indexed {final_log_message}"
            if self.dry_run
            else f"Indexing done, {final_log_message}"
        )
        return total_records_processed

    def _iter_chunks(self, entities: Iterable[DatabaseEntity]) -> Iterator[list[DatabaseEntity]]:
        """Group `entities` into consecutive lists of at most `chunk_size` items.

        The source is consumed lazily: the next chunk is read only after the
        caller has finished with the previous one.
        """
        chunk: list[DatabaseEntity] = []
        for entity in entities:
            chunk.append(entity)
            if len(chunk) >= self.chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    def _process_chunk(self, entity_chunk: list[DatabaseEntity], session: Session | None = None) -> tuple[int, int]:
        """Diff, delete and upsert the index records for one chunk of entities.

        Args:
            entity_chunk: Entities to index.
            session: Write session. It is None in dry runs, in which case no
                DELETE/UPSERT statements are executed.

        Returns:
            `(records_upserted, identical_records)`. `records_upserted` counts only
            records that were placed into an UPSERT batch.
        """
        if not entity_chunk:
            return 0, 0

        fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session)

        if paths_to_delete and session is not None:
            self.logger.debug(f"Deleting {len(paths_to_delete)} stale records in chunk.")
            self._execute_batched_deletes(paths_to_delete, session)

        upserted_count = 0
        if fields_to_upsert:
            upsert_stmt = self._get_upsert_statement()
            for batch in self._generate_upsert_batches(fields_to_upsert):
                # Count what is actually batched; fields skipped during batching are excluded.
                upserted_count += len(batch)
                if self.dry_run:
                    self.logger.debug(f"Dry Run: Would upsert {len(batch)} records.")
                elif batch and session is not None:
                    session.execute(upsert_stmt, batch)

        return upserted_count, identical_count

    def _determine_changes(
        self, entities: list[DatabaseEntity], session: Session | None = None
    ) -> tuple[list[tuple[str, ExtractedField]], list[tuple[str, Ltree]], int]:
        """Identify fields to upsert and stale paths to delete across all `entities`.

        Existing hashes for the whole chunk are fetched in one query. A field is
        upserted when its path is new, its content hash changed, or `force_index`
        is set. Previously indexed paths that no longer occur among an entity's
        current fields are returned for deletion. This also applies to forced
        reindexes.

        Returns:
            `(fields_to_upsert, paths_to_delete, identical_records_count)`.
        """
        entity_ids = [str(getattr(e, self.config.pk_name)) for e in entities]
        read_session = session or db.session
        # Always fetch existing rows: even a forced reindex must know which
        # previously indexed paths disappeared so they can be deleted.
        existing_hashes = self._get_all_existing_hashes(entity_ids, read_session)

        fields_to_upsert: list[tuple[str, ExtractedField]] = []
        paths_to_delete: list[tuple[str, Ltree]] = []
        identical_records_count = 0

        for entity in entities:
            entity_id = str(getattr(entity, self.config.pk_name))
            current_fields = self.config.traverser.get_fields(
                entity, pk_name=self.config.pk_name, root_name=self.config.root_name
            )

            entity_hashes = existing_hashes.get(entity_id, {})
            current_paths = set()

            for field in current_fields:
                current_paths.add(field.path)
                current_hash = self._compute_content_hash(field.path, field.value, field.value_type)
                # A missing path yields None from .get(), which never equals a hex digest.
                if self.force_index or entity_hashes.get(field.path) != current_hash:
                    fields_to_upsert.append((entity_id, field))
                else:
                    identical_records_count += 1

            stale_paths = set(entity_hashes.keys()) - current_paths
            paths_to_delete.extend([(entity_id, Ltree(p)) for p in stale_paths])

        return fields_to_upsert, paths_to_delete, identical_records_count

    def _execute_batched_deletes(self, paths_to_delete: list[tuple[str, Ltree]], session: Session) -> None:
        """Execute delete operations in batches to avoid PostgreSQL stack depth limits.

        Each batch holds at most `chunk_size` `(entity_id, path)` pairs and is
        deleted with a single tuple `IN (...)` statement.
        """
        for i in range(0, len(paths_to_delete), self.chunk_size):
            batch = paths_to_delete[i : i + self.chunk_size]
            delete_stmt = delete(AiSearchIndex).where(tuple_(AiSearchIndex.entity_id, AiSearchIndex.path).in_(batch))
            session.execute(delete_stmt)
            self.logger.debug(f"Deleted batch of {len(batch)} records.")

    def _get_all_existing_hashes(self, entity_ids: list[str], session: Session) -> dict[str, dict[str, str]]:
        """Fetch all existing hashes for a list of entity IDs in a single query.

        Returns:
            `{entity_id: {path: content_hash}}`. Every requested id is present,
            with an empty dict when nothing is indexed for it yet.
        """
        if not entity_ids:
            return {}

        results = (
            session.query(AiSearchIndex.entity_id, AiSearchIndex.path, AiSearchIndex.content_hash)
            .filter(AiSearchIndex.entity_id.in_(entity_ids))
            .all()
        )

        hashes_by_entity: dict[str, dict[str, str]] = {eid: {} for eid in entity_ids}
        for entity_id, path, content_hash in results:
            hashes_by_entity[str(entity_id)][str(path)] = content_hash
        return hashes_by_entity

    def _generate_upsert_batches(
        self, fields_to_upsert: Iterable[tuple[str, ExtractedField]]
    ) -> Generator[list[IndexableRecord], None, None]:
        """Stream through fields, buffer them by token count, and yield UPSERT batches.

        Each embeddable field's text is prepared once. That same text is both
        tokenized and later embedded. A batch is emitted in either of two cases:
          - the next embeddable field would push the running token count past
            the budget;
          - the buffer has reached `EMBEDDING_MAX_BATCH_SIZE`, a limit applied
            only when `OPENAI_BASE_URL` is set.
        Non-embeddable records are included in whichever batch is emitted next.

        Embeddable fields that cannot be tokenized, or whose token count exceeds
        the model context, are logged and skipped; they appear in no batch.
        """
        embeddable_buffer: list[tuple[str, ExtractedField]] = []
        embeddable_texts: list[str] = []  # parallel to embeddable_buffer
        non_embeddable_records: list[IndexableRecord] = []
        current_tokens = 0

        max_ctx = self._get_max_tokens()
        safe_margin = int(max_ctx * llm_settings.EMBEDDING_SAFE_MARGIN_PERCENT)
        token_budget = max(1, max_ctx - safe_margin)

        max_batch_size = None
        if llm_settings.OPENAI_BASE_URL:  # A custom base URL is treated as a local model.
            max_batch_size = llm_settings.EMBEDDING_MAX_BATCH_SIZE

        for entity_id, field in fields_to_upsert:
            if not field.value_type.is_embeddable(field.value):
                non_embeddable_records.append(self._make_indexable_record(field, entity_id, embedding=None))
                continue

            text = self._prepare_text_for_embedding(field)
            try:
                item_tokens = len(encode(model=self.embedding_model, text=text))
            except Exception as e:
                # Best-effort: one untokenizable field must not abort the whole run.
                self.logger.warning("Tokenization failed; skipping.", path=field.path, err=str(e))
                continue

            if item_tokens > max_ctx:
                self.logger.warning(
                    "Field exceeds context; skipping.", path=field.path, tokens=item_tokens, max_ctx=max_ctx
                )
                continue

            should_flush = bool(embeddable_buffer) and (
                current_tokens + item_tokens > token_budget
                or bool(max_batch_size and len(embeddable_buffer) >= max_batch_size)
            )
            if should_flush:
                yield self._flush_buffer(embeddable_buffer, non_embeddable_records, embeddable_texts)
                embeddable_buffer.clear()
                embeddable_texts.clear()
                non_embeddable_records.clear()
                current_tokens = 0

            embeddable_buffer.append((entity_id, field))
            embeddable_texts.append(text)
            current_tokens += item_tokens

        if embeddable_buffer or non_embeddable_records:
            yield self._flush_buffer(embeddable_buffer, non_embeddable_records, embeddable_texts)

    def _flush_buffer(
        self,
        embeddable_buffer: list,
        non_embeddable_records: list,
        texts: list[str] | None = None,
    ) -> list[IndexableRecord]:
        """Embed the buffered fields and combine them with the non-embeddable records.

        Args:
            embeddable_buffer: `(entity_id, field)` pairs to embed.
            non_embeddable_records: Ready-made records without embeddings.
            texts: Optional pre-computed embedding texts, parallel to
                `embeddable_buffer`. When omitted they are derived via
                `_prepare_text_for_embedding`.

        Returns:
            A new list: the non-embeddable records followed by the embedded records.

        Raises:
            ValueError: If `texts`, or the embeddings returned by the API, do not
                match the buffer length.
        """
        if not embeddable_buffer:
            # Copy so the caller may keep mutating its buffer without affecting this batch.
            return list(non_embeddable_records)

        if texts is None:
            texts_to_embed = [self._prepare_text_for_embedding(f) for _, f in embeddable_buffer]
        else:
            if len(texts) != len(embeddable_buffer):
                raise ValueError(f"Text mismatch: {len(embeddable_buffer)} fields, {len(texts)} texts")
            texts_to_embed = list(texts)

        embeddings = EmbeddingIndexer.get_embeddings_from_api_batch(texts_to_embed, self.dry_run)

        if len(embeddable_buffer) != len(embeddings):
            raise ValueError(f"Embedding mismatch: sent {len(embeddable_buffer)}, received {len(embeddings)}")

        with_embeddings = [
            self._make_indexable_record(field, entity_id, embedding)
            for (entity_id, field), embedding in zip(embeddable_buffer, embeddings)
        ]
        return non_embeddable_records + with_embeddings

    def _get_max_tokens(self) -> int:
        """Get the model's max tokens, using `EMBEDDING_FALLBACK_MAX_TOKENS` if necessary.

        Raises:
            RuntimeError: If auto-detection fails and no integer fallback is configured.
        """
        try:
            max_ctx = get_max_tokens(self.embedding_model)
            if isinstance(max_ctx, int):
                return max_ctx
        except Exception:
            # Allow local (unknown) models to fall back.
            self.logger.warning("Could not auto-detect max tokens.", model=self.embedding_model)

        max_ctx = llm_settings.EMBEDDING_FALLBACK_MAX_TOKENS
        if not isinstance(max_ctx, int):
            raise RuntimeError("Model not recognized and EMBEDDING_FALLBACK_MAX_TOKENS not set.")
        self.logger.warning("Using configured fallback token limit.", fallback=max_ctx)
        return max_ctx

    @staticmethod
    def _prepare_text_for_embedding(field: ExtractedField) -> str:
        """Return the text sent to the embedding model: `"<path>: <value>"`."""
        return f"{field.path}: {str(field.value)}"

    @staticmethod
    def _compute_content_hash(path: str, value: Any, value_type: Any) -> str:
        """Return the SHA-256 hex digest of `"<path>:<value>:<value_type>"`.

        A `None` value is hashed as the empty string.
        """
        v = "" if value is None else str(value)
        content = f"{path}:{v}:{value_type}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _make_indexable_record(
        self, field: ExtractedField, entity_id: str, embedding: list[float] | None
    ) -> IndexableRecord:
        """Build the row written to `AiSearchIndex` for one extracted field.

        A falsy embedding (None or empty) is stored as None.
        """
        return IndexableRecord(
            entity_id=entity_id,
            entity_type=self.config.entity_kind.value,
            path=Ltree(field.path),
            value=field.value,
            value_type=field.value_type,
            content_hash=self._compute_content_hash(field.path, field.value, field.value_type),
            embedding=embedding if embedding else None,
        )

    @staticmethod
    @lru_cache(maxsize=1)
    def _get_upsert_statement() -> Insert:
        """Build (once, then cached) the `AiSearchIndex` UPSERT statement.

        On a conflict on `(entity_id, path)`, the statement overwrites `value`,
        `value_type`, `content_hash` and `embedding`.
        """
        stmt = insert(AiSearchIndex)
        return stmt.on_conflict_do_update(
            index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path],
            set_={
                AiSearchIndex.value: stmt.excluded.value,
                AiSearchIndex.value_type: stmt.excluded.value_type,
                AiSearchIndex.content_hash: stmt.excluded.content_hash,
                AiSearchIndex.embedding: stmt.excluded.embedding,
            },
        )
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.orm import Query
|
|
6
|
+
from sqlalchemy.sql import Select
|
|
7
|
+
|
|
8
|
+
from orchestrator.db import (
|
|
9
|
+
ProcessTable,
|
|
10
|
+
ProductTable,
|
|
11
|
+
SubscriptionTable,
|
|
12
|
+
WorkflowTable,
|
|
13
|
+
)
|
|
14
|
+
from orchestrator.db.database import BaseModel
|
|
15
|
+
from orchestrator.search.core.types import EntityType
|
|
16
|
+
|
|
17
|
+
from .traverse import (
|
|
18
|
+
BaseTraverser,
|
|
19
|
+
ProcessTraverser,
|
|
20
|
+
ProductTraverser,
|
|
21
|
+
SubscriptionTraverser,
|
|
22
|
+
WorkflowTraverser,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Type variable for the ORM table class an EntityConfig describes; bound to
# `orchestrator.db.database.BaseModel`.
ModelT = TypeVar("ModelT", bound=BaseModel)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
class EntityConfig(Generic[ModelT]):
    """Immutable description of one indexable entity kind.

    Groups the `EntityType`, its ORM table, the traverser class used to extract
    fields, the primary-key attribute name, and the root name passed to the
    traverser.
    """

    entity_kind: EntityType
    table: type[ModelT]

    # Traverser class used to flatten a row into indexable fields.
    traverser: "type[BaseTraverser]"
    # Name of the primary-key attribute on `table`.
    pk_name: str
    # Root name handed to the traverser when extracting fields.
    root_name: str

    def get_all_query(self, entity_id: str | None = None) -> Query | Select:
        """Return a query over every row of `table`, optionally limited to one primary key.

        Args:
            entity_id: UUID string of a single entity; a falsy value selects all rows.

        Raises:
            ValueError: If `entity_id` is not a valid UUID string.
        """
        base_query = self.table.query
        if not entity_id:
            return base_query
        pk_column = getattr(self.table, self.pk_name)
        return base_query.filter(pk_column == UUID(entity_id))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class WorkflowConfig(EntityConfig[WorkflowTable]):
    """Entity config for workflows, built on `WorkflowTable.select()` instead of `.query`.

    The base statement comes from the table's own `select()` helper, which is
    meant to exclude deleted workflows. NOTE(review): confirm against WorkflowTable.
    """

    def get_all_query(self, entity_id: str | None = None) -> Select:
        """Return the workflow SELECT, optionally limited to one primary key.

        Raises:
            ValueError: If `entity_id` is not a valid UUID string.
        """
        statement = self.table.select()
        if not entity_id:
            return statement
        pk_column = getattr(self.table, self.pk_name)
        return statement.where(pk_column == UUID(entity_id))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Registry of every indexable entity kind. Each config is keyed by its own
# `entity_kind`, so a key can never disagree with the config it maps to.
ENTITY_CONFIG_REGISTRY: dict[EntityType, EntityConfig] = {
    config.entity_kind: config
    for config in (
        EntityConfig(
            entity_kind=EntityType.SUBSCRIPTION,
            table=SubscriptionTable,
            traverser=SubscriptionTraverser,
            pk_name="subscription_id",
            root_name="subscription",
        ),
        EntityConfig(
            entity_kind=EntityType.PRODUCT,
            table=ProductTable,
            traverser=ProductTraverser,
            pk_name="product_id",
            root_name="product",
        ),
        EntityConfig(
            entity_kind=EntityType.PROCESS,
            table=ProcessTable,
            traverser=ProcessTraverser,
            pk_name="process_id",
            root_name="process",
        ),
        # Workflows use WorkflowConfig for their select()-based base query.
        WorkflowConfig(
            entity_kind=EntityType.WORKFLOW,
            table=WorkflowTable,
            traverser=WorkflowTraverser,
            pk_name="workflow_id",
            root_name="workflow",
        ),
    )
}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import structlog
|
|
2
|
+
from sqlalchemy.orm import Query
|
|
3
|
+
|
|
4
|
+
from orchestrator.db import db
|
|
5
|
+
from orchestrator.search.core.types import EntityType
|
|
6
|
+
from orchestrator.search.indexing.indexer import Indexer
|
|
7
|
+
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
|
|
8
|
+
|
|
9
|
+
# Module-level structlog logger. NOTE(review): not referenced by any code in this
# module at present; presumably kept for future diagnostics.
logger = structlog.get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def run_indexing_for_entity(
    entity_kind: EntityType,
    entity_id: str | None = None,
    dry_run: bool = False,
    force_index: bool = False,
    chunk_size: int = 1000,
) -> None:
    """Stream and index entities for the given kind.

    Builds a streaming query via the entity's registry config, disables ORM eager
    loads when applicable, and delegates processing to `Indexer`. The streamed
    result is always closed afterwards, so its cursor is released even if
    indexing fails part-way through.

    Args:
        entity_kind (EntityType): The entity type to index (must exist in
            `ENTITY_CONFIG_REGISTRY`).
        entity_id (Optional[str]): If provided, restricts indexing to a single
            entity (UUID string).
        dry_run (bool): When True, runs the full pipeline without performing
            writes or external embedding calls.
        force_index (bool): When True, re-indexes all fields regardless of
            existing hashes.
        chunk_size (int): Number of rows fetched per round-trip and passed to
            the indexer per batch.

    Returns:
        None

    Raises:
        KeyError: If `entity_kind` has no entry in `ENTITY_CONFIG_REGISTRY`.
        ValueError: If `entity_id` is given but is not a valid UUID string.
    """
    config = ENTITY_CONFIG_REGISTRY[entity_kind]

    q = config.get_all_query(entity_id)

    if isinstance(q, Query):
        # SQLAlchemy's yield_per is incompatible with joined eager loading of
        # collections, so eager loads are disabled before taking the Core statement.
        q = q.enable_eagerloads(False)
        stmt = q.statement
    else:
        stmt = q

    # Server-side cursor, fetched in chunks of `chunk_size` rows to bound memory.
    stmt = stmt.execution_options(stream_results=True, yield_per=chunk_size)

    indexer = Indexer(config=config, dry_run=dry_run, force_index=force_index, chunk_size=chunk_size)
    result = db.session.execute(stmt)
    try:
        indexer.run(result.scalars())
    finally:
        # An exhausted result releases its cursor on its own. Closing explicitly
        # also covers the error path, where a partially read cursor would stay open.
        result.close()
|