databao-context-engine 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. databao_context_engine/__init__.py +35 -0
  2. databao_context_engine/build_sources/__init__.py +0 -0
  3. databao_context_engine/build_sources/internal/__init__.py +0 -0
  4. databao_context_engine/build_sources/internal/build_runner.py +111 -0
  5. databao_context_engine/build_sources/internal/build_service.py +77 -0
  6. databao_context_engine/build_sources/internal/build_wiring.py +52 -0
  7. databao_context_engine/build_sources/internal/export_results.py +43 -0
  8. databao_context_engine/build_sources/internal/plugin_execution.py +74 -0
  9. databao_context_engine/build_sources/public/__init__.py +0 -0
  10. databao_context_engine/build_sources/public/api.py +4 -0
  11. databao_context_engine/cli/__init__.py +0 -0
  12. databao_context_engine/cli/add_datasource_config.py +130 -0
  13. databao_context_engine/cli/commands.py +256 -0
  14. databao_context_engine/cli/datasources.py +64 -0
  15. databao_context_engine/cli/info.py +32 -0
  16. databao_context_engine/config/__init__.py +0 -0
  17. databao_context_engine/config/log_config.yaml +16 -0
  18. databao_context_engine/config/logging.py +43 -0
  19. databao_context_engine/databao_context_project_manager.py +92 -0
  20. databao_context_engine/databao_engine.py +85 -0
  21. databao_context_engine/datasource_config/__init__.py +0 -0
  22. databao_context_engine/datasource_config/add_config.py +50 -0
  23. databao_context_engine/datasource_config/check_config.py +131 -0
  24. databao_context_engine/datasource_config/datasource_context.py +60 -0
  25. databao_context_engine/event_journal/__init__.py +0 -0
  26. databao_context_engine/event_journal/writer.py +29 -0
  27. databao_context_engine/generate_configs_schemas.py +92 -0
  28. databao_context_engine/init_project.py +18 -0
  29. databao_context_engine/introspection/__init__.py +0 -0
  30. databao_context_engine/introspection/property_extract.py +202 -0
  31. databao_context_engine/llm/__init__.py +0 -0
  32. databao_context_engine/llm/config.py +20 -0
  33. databao_context_engine/llm/descriptions/__init__.py +0 -0
  34. databao_context_engine/llm/descriptions/ollama.py +21 -0
  35. databao_context_engine/llm/descriptions/provider.py +10 -0
  36. databao_context_engine/llm/embeddings/__init__.py +0 -0
  37. databao_context_engine/llm/embeddings/ollama.py +37 -0
  38. databao_context_engine/llm/embeddings/provider.py +13 -0
  39. databao_context_engine/llm/errors.py +16 -0
  40. databao_context_engine/llm/factory.py +61 -0
  41. databao_context_engine/llm/install.py +227 -0
  42. databao_context_engine/llm/runtime.py +73 -0
  43. databao_context_engine/llm/service.py +159 -0
  44. databao_context_engine/main.py +19 -0
  45. databao_context_engine/mcp/__init__.py +0 -0
  46. databao_context_engine/mcp/all_results_tool.py +5 -0
  47. databao_context_engine/mcp/mcp_runner.py +16 -0
  48. databao_context_engine/mcp/mcp_server.py +63 -0
  49. databao_context_engine/mcp/retrieve_tool.py +22 -0
  50. databao_context_engine/pluginlib/__init__.py +0 -0
  51. databao_context_engine/pluginlib/build_plugin.py +107 -0
  52. databao_context_engine/pluginlib/config.py +37 -0
  53. databao_context_engine/pluginlib/plugin_utils.py +68 -0
  54. databao_context_engine/plugins/__init__.py +0 -0
  55. databao_context_engine/plugins/athena_db_plugin.py +12 -0
  56. databao_context_engine/plugins/base_db_plugin.py +45 -0
  57. databao_context_engine/plugins/clickhouse_db_plugin.py +15 -0
  58. databao_context_engine/plugins/databases/__init__.py +0 -0
  59. databao_context_engine/plugins/databases/athena_introspector.py +101 -0
  60. databao_context_engine/plugins/databases/base_introspector.py +144 -0
  61. databao_context_engine/plugins/databases/clickhouse_introspector.py +162 -0
  62. databao_context_engine/plugins/databases/database_chunker.py +69 -0
  63. databao_context_engine/plugins/databases/databases_types.py +114 -0
  64. databao_context_engine/plugins/databases/duckdb_introspector.py +325 -0
  65. databao_context_engine/plugins/databases/introspection_model_builder.py +270 -0
  66. databao_context_engine/plugins/databases/introspection_scope.py +74 -0
  67. databao_context_engine/plugins/databases/introspection_scope_matcher.py +103 -0
  68. databao_context_engine/plugins/databases/mssql_introspector.py +433 -0
  69. databao_context_engine/plugins/databases/mysql_introspector.py +338 -0
  70. databao_context_engine/plugins/databases/postgresql_introspector.py +428 -0
  71. databao_context_engine/plugins/databases/snowflake_introspector.py +287 -0
  72. databao_context_engine/plugins/duckdb_db_plugin.py +12 -0
  73. databao_context_engine/plugins/mssql_db_plugin.py +12 -0
  74. databao_context_engine/plugins/mysql_db_plugin.py +12 -0
  75. databao_context_engine/plugins/parquet_plugin.py +32 -0
  76. databao_context_engine/plugins/plugin_loader.py +110 -0
  77. databao_context_engine/plugins/postgresql_db_plugin.py +12 -0
  78. databao_context_engine/plugins/resources/__init__.py +0 -0
  79. databao_context_engine/plugins/resources/parquet_chunker.py +23 -0
  80. databao_context_engine/plugins/resources/parquet_introspector.py +154 -0
  81. databao_context_engine/plugins/snowflake_db_plugin.py +12 -0
  82. databao_context_engine/plugins/unstructured_files_plugin.py +68 -0
  83. databao_context_engine/project/__init__.py +0 -0
  84. databao_context_engine/project/datasource_discovery.py +141 -0
  85. databao_context_engine/project/info.py +44 -0
  86. databao_context_engine/project/init_project.py +102 -0
  87. databao_context_engine/project/layout.py +127 -0
  88. databao_context_engine/project/project_config.py +32 -0
  89. databao_context_engine/project/resources/examples/src/databases/example_postgres.yaml +7 -0
  90. databao_context_engine/project/resources/examples/src/files/documentation.md +30 -0
  91. databao_context_engine/project/resources/examples/src/files/notes.txt +20 -0
  92. databao_context_engine/project/runs.py +39 -0
  93. databao_context_engine/project/types.py +134 -0
  94. databao_context_engine/retrieve_embeddings/__init__.py +0 -0
  95. databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
  96. databao_context_engine/retrieve_embeddings/internal/export_results.py +12 -0
  97. databao_context_engine/retrieve_embeddings/internal/retrieve_runner.py +34 -0
  98. databao_context_engine/retrieve_embeddings/internal/retrieve_service.py +68 -0
  99. databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +29 -0
  100. databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
  101. databao_context_engine/retrieve_embeddings/public/api.py +3 -0
  102. databao_context_engine/serialisation/__init__.py +0 -0
  103. databao_context_engine/serialisation/yaml.py +35 -0
  104. databao_context_engine/services/__init__.py +0 -0
  105. databao_context_engine/services/chunk_embedding_service.py +104 -0
  106. databao_context_engine/services/embedding_shard_resolver.py +64 -0
  107. databao_context_engine/services/factories.py +88 -0
  108. databao_context_engine/services/models.py +12 -0
  109. databao_context_engine/services/persistence_service.py +61 -0
  110. databao_context_engine/services/run_name_policy.py +8 -0
  111. databao_context_engine/services/table_name_policy.py +15 -0
  112. databao_context_engine/storage/__init__.py +0 -0
  113. databao_context_engine/storage/connection.py +32 -0
  114. databao_context_engine/storage/exceptions/__init__.py +0 -0
  115. databao_context_engine/storage/exceptions/exceptions.py +6 -0
  116. databao_context_engine/storage/migrate.py +127 -0
  117. databao_context_engine/storage/migrations/V01__init.sql +63 -0
  118. databao_context_engine/storage/models.py +51 -0
  119. databao_context_engine/storage/repositories/__init__.py +0 -0
  120. databao_context_engine/storage/repositories/chunk_repository.py +130 -0
  121. databao_context_engine/storage/repositories/datasource_run_repository.py +136 -0
  122. databao_context_engine/storage/repositories/embedding_model_registry_repository.py +87 -0
  123. databao_context_engine/storage/repositories/embedding_repository.py +113 -0
  124. databao_context_engine/storage/repositories/factories.py +35 -0
  125. databao_context_engine/storage/repositories/run_repository.py +157 -0
  126. databao_context_engine/storage/repositories/vector_search_repository.py +63 -0
  127. databao_context_engine/storage/transaction.py +14 -0
  128. databao_context_engine/system/__init__.py +0 -0
  129. databao_context_engine/system/properties.py +13 -0
  130. databao_context_engine/templating/__init__.py +0 -0
  131. databao_context_engine/templating/renderer.py +29 -0
  132. databao_context_engine-0.1.1.dist-info/METADATA +186 -0
  133. databao_context_engine-0.1.1.dist-info/RECORD +135 -0
  134. databao_context_engine-0.1.1.dist-info/WHEEL +4 -0
  135. databao_context_engine-0.1.1.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,134 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from enum import StrEnum
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from databao_context_engine.pluginlib.build_plugin import DatasourceType
8
+
9
+
10
class DatasourceKind(StrEnum):
    """How a datasource is declared on disk (cf. PreparedConfig / PreparedFile below)."""

    # Declared via a config file (e.g. a YAML connection definition).
    CONFIG = "config"
    # Declared as a raw file that is itself the datasource content.
    FILE = "file"
13
+
14
+
15
@dataclass(frozen=True)
class DatasourceDescriptor:
    """Lightweight description of a datasource discovered on disk."""

    # Location of the datasource (a config file or a raw file).
    path: Path
    # Whether `path` points at a config file or a raw file.
    kind: DatasourceKind
    # Main type identifier for the datasource — presumably used to select the
    # handling plugin; confirm against the discovery code.
    main_type: str
20
+
21
+
22
@dataclass(frozen=True)
class PreparedConfig:
    """A config-declared datasource, parsed and ready for the build pipeline."""

    # Resolved plugin-facing type of the datasource.
    datasource_type: DatasourceType
    # Path to the datasource's config file.
    path: Path
    # Parsed configuration contents — presumably the loaded YAML mapping; verify against the loader.
    config: dict[Any, Any]
    # Name of the datasource — presumably derived from the config file name; confirm.
    datasource_name: str
28
+
29
+
30
@dataclass(frozen=True)
class PreparedFile:
    """A raw-file datasource ready for the build pipeline (no config to parse)."""

    # Resolved plugin-facing type of the datasource.
    datasource_type: DatasourceType
    # Path to the raw file itself.
    path: Path
34
+
35
+
36
# Union of datasources ready to build: either a parsed config or a raw file.
PreparedDatasource = PreparedConfig | PreparedFile
37
+
38
+
39
+ @dataclass(kw_only=True, frozen=True, eq=True)
40
+ class DatasourceId:
41
+ """
42
+ The ID of a datasource. The ID is the path to the datasource's config file relative to the src folder in the project.
43
+
44
+ e.g: "databases/my_postgres_datasource.yaml"
45
+
46
+ Use the provided factory methods `from_string_repr` and `from_datasource_config_file_path` to create a DatasourceId, rather than its constructor.
47
+ """
48
+
49
+ datasource_config_folder: str
50
+ datasource_name: str
51
+ config_file_suffix: str
52
+
53
+ def __post_init__(self):
54
+ if not self.datasource_config_folder.strip():
55
+ raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_config_folder must not be empty")
56
+ if not self.datasource_name.strip():
57
+ raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not be empty")
58
+ if not self.config_file_suffix.strip():
59
+ raise ValueError(f"Invalid DatasourceId ({str(self)}): config_file_suffix must not be empty")
60
+
61
+ if os.sep in self.datasource_config_folder:
62
+ raise ValueError(
63
+ f"Invalid DatasourceId ({str(self)}): datasource_config_folder must not contain a path separator"
64
+ )
65
+
66
+ if os.sep in self.datasource_name:
67
+ raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not contain a path separator")
68
+
69
+ if not self.config_file_suffix.startswith("."):
70
+ raise ValueError(
71
+ f'Invalid DatasourceId ({str(self)}): config_file_suffix must start with a dot "." (e.g.: .yaml)'
72
+ )
73
+
74
+ if self.datasource_name.endswith(self.config_file_suffix):
75
+ raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not contain the file suffix")
76
+
77
+ def __str__(self):
78
+ return str(self.relative_path_to_config_file())
79
+
80
+ def relative_path_to_config_file(self) -> Path:
81
+ """
82
+ Returns a path to the config file for this datasource.
83
+
84
+ The returned path is relative to the src folder in the project.
85
+ """
86
+ return Path(self.datasource_config_folder).joinpath(self.datasource_name + self.config_file_suffix)
87
+
88
+ def relative_path_to_context_file(self) -> Path:
89
+ """
90
+ Returns a path to the config file for this datasource.
91
+
92
+ The returned path is relative to an output run folder in the project.
93
+ """
94
+ # Keep the suffix in the filename if this datasource is a raw file, to handle multiple files with the same name and different extensions
95
+ suffix = ".yaml" if self.config_file_suffix == ".yaml" else (self.config_file_suffix + ".yaml")
96
+
97
+ return Path(self.datasource_config_folder).joinpath(self.datasource_name + suffix)
98
+
99
+ @classmethod
100
+ def from_string_repr(cls, datasource_id_as_string: str):
101
+ """
102
+ Creates a DatasourceId from a string representation.
103
+
104
+ The string representation of a DatasourceId is the path to the datasource's config file relative to the src folder in the project.
105
+
106
+ e.g: "databases/my_postgres_datasource.yaml"
107
+ """
108
+ config_file_path = Path(datasource_id_as_string)
109
+
110
+ if len(config_file_path.parents) > 2:
111
+ raise ValueError(
112
+ f"Invalid string representation of a DatasourceId: too many parent folders defined in {datasource_id_as_string}"
113
+ )
114
+
115
+ return DatasourceId.from_datasource_config_file_path(config_file_path)
116
+
117
+ @classmethod
118
+ def from_datasource_config_file_path(cls, datasource_config_file: Path):
119
+ """
120
+ Creates a DatasourceId from a config file path.
121
+
122
+ The `datasource_config_file` path provided can either be the config file path relative to the src folder or the full path to the config file.
123
+ """
124
+ return DatasourceId(
125
+ datasource_config_folder=datasource_config_file.parent.name,
126
+ datasource_name=datasource_config_file.stem,
127
+ config_file_suffix=datasource_config_file.suffix,
128
+ )
129
+
130
+
131
@dataclass
class Datasource:
    """Pairs a datasource's identity with its resolved type."""

    # Stable identity: the config file path relative to the project's src folder.
    id: DatasourceId
    # Plugin-facing type, as declared in databao_context_engine.pluginlib.build_plugin.
    type: DatasourceType
File without changes
@@ -0,0 +1,12 @@
1
+ from pathlib import Path
2
+
3
+
4
def export_retrieve_results(run_dir: Path, retrieve_results: list[str]) -> Path:
    """Write each retrieved result on its own line to the run's export file.

    Args:
        run_dir: Directory of the run the results belong to.
        retrieve_results: Display texts to export, written in order.

    Returns:
        The path of the written export file ("context_duckdb.yaml" in run_dir).
    """
    path = run_dir.joinpath("context_duckdb.yaml")

    # Explicit UTF-8: results may contain non-ASCII text, and the platform
    # default encoding (e.g. cp1252 on Windows) could otherwise raise or corrupt.
    with path.open("w", encoding="utf-8") as export_file:
        for result in retrieve_results:
            export_file.write(result)
            export_file.write("\n")

    return path
@@ -0,0 +1,34 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ from databao_context_engine.project.runs import get_run_dir
5
+ from databao_context_engine.retrieve_embeddings.internal.export_results import export_retrieve_results
6
+ from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
7
+ from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def retrieve(
    project_dir: Path,
    *,
    retrieve_service: RetrieveService,
    project_id: str,
    text: str,
    run_name: str | None,
    limit: int | None,
    export_to_file: bool,
) -> list[VectorSearchResult]:
    """Run a similarity retrieval for *text* against a (possibly defaulted) run.

    When `run_name` is None the service resolves it to the latest run. If
    `export_to_file` is set, the result display texts are also written into
    the run's output directory.
    """
    effective_run = retrieve_service.resolve_run_name(project_id=project_id, run_name=run_name)
    results = retrieve_service.retrieve(project_id=project_id, text=text, run_name=effective_run, limit=limit)

    if export_to_file:
        target_dir = get_run_dir(project_dir=project_dir, run_name=effective_run)
        exported_path = export_retrieve_results(target_dir, [r.display_text for r in results])
        logger.info(f"Exported results to {exported_path}")

    return results
@@ -0,0 +1,68 @@
1
+ import logging
2
+ from collections.abc import Sequence
3
+
4
+ from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
5
+ from databao_context_engine.project.runs import resolve_run_name_from_repo
6
+ from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
7
+ from databao_context_engine.storage.repositories.run_repository import RunRepository
8
+ from databao_context_engine.storage.repositories.vector_search_repository import (
9
+ VectorSearchRepository,
10
+ VectorSearchResult,
11
+ )
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class RetrieveService:
    """Resolves run names and performs vector-similarity retrieval over stored embeddings."""

    def __init__(
        self,
        *,
        run_repo: RunRepository,
        vector_search_repo: VectorSearchRepository,
        shard_resolver: EmbeddingShardResolver,
        provider: EmbeddingProvider,
    ):
        self._run_repo = run_repo
        self._shard_resolver = shard_resolver
        self._provider = provider
        self._vector_search_repo = vector_search_repo

    def retrieve(
        self, *, project_id: str, text: str, run_name: str, limit: int | None = None
    ) -> list[VectorSearchResult]:
        """Embed *text* and return the `limit` most similar chunks of the given run.

        Args:
            project_id: Project the run belongs to.
            text: Query text to embed and search with.
            run_name: Name of an existing run to search within.
            limit: Maximum results to return; defaults to 10 when None.

        Raises:
            LookupError: If no run named `run_name` exists for the project.
            ValueError: If the provider's model is not registered (via the shard resolver).
        """
        if limit is None:
            limit = 10

        run = self._run_repo.get_by_run_name(project_id=project_id, run_name=run_name)
        if run is None:
            raise LookupError(f"Run '{run_name}' not found for project '{project_id}'.")

        # Each embedding model has its own table ("shard"); look up where this
        # provider's vectors live and their dimension.
        table_name, dimension = self._shard_resolver.resolve(
            embedder=self._provider.embedder, model_id=self._provider.model_id
        )

        retrieve_vec: Sequence[float] = self._provider.embed(text)

        logger.debug(f"Retrieving display texts for run {run.run_id} in table {table_name}")

        search_results = self._vector_search_repo.get_display_texts_by_similarity(
            table_name=table_name,
            run_id=run.run_id,
            retrieve_vec=retrieve_vec,
            dimension=dimension,
            limit=limit,
        )

        logger.debug(f"Retrieved {len(search_results)} display texts for run {run.run_id} in table {table_name}")

        # Guard on non-empty results: min()/max() raise ValueError on an empty
        # sequence, which previously crashed retrieval whenever DEBUG logging
        # was enabled and the search returned no rows.
        if search_results and logger.isEnabledFor(logging.DEBUG):
            closest_result = min(search_results, key=lambda result: result.cosine_distance)
            logger.debug(f"Best result: ({closest_result.cosine_distance}, {closest_result.embeddable_text})")

            farthest_result = max(search_results, key=lambda result: result.cosine_distance)
            logger.debug(f"Worst result: ({farthest_result.cosine_distance}, {farthest_result.embeddable_text})")

        return search_results

    def resolve_run_name(self, *, project_id: str, run_name: str | None) -> str:
        """Return `run_name` unchanged, or the repository's default when None."""
        return resolve_run_name_from_repo(run_repository=self._run_repo, project_id=project_id, run_name=run_name)
@@ -0,0 +1,29 @@
1
+ from databao_context_engine.llm.factory import create_ollama_embedding_provider, create_ollama_service
2
+ from databao_context_engine.project.layout import ProjectLayout
3
+ from databao_context_engine.retrieve_embeddings.internal.retrieve_runner import retrieve
4
+ from databao_context_engine.services.factories import create_retrieve_service
5
+ from databao_context_engine.storage.connection import open_duckdb_connection
6
+ from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
7
+ from databao_context_engine.system.properties import get_db_path
8
+
9
+
10
def retrieve_embeddings(
    project_layout: ProjectLayout,
    retrieve_text: str,
    run_name: str | None,
    limit: int | None,
    export_to_file: bool,
) -> list[VectorSearchResult]:
    """Wire up the retrieval stack against the project DuckDB and run a query.

    Opens the engine database, builds the Ollama-backed embedding provider and
    retrieve service, then delegates to the internal `retrieve` runner.
    """
    with open_duckdb_connection(get_db_path()) as conn:
        service = create_ollama_service()
        provider = create_ollama_embedding_provider(service)
        retriever = create_retrieve_service(conn, embedding_provider=provider)
        resolved_project_id = str(project_layout.read_config_file().project_id)
        return retrieve(
            project_dir=project_layout.project_dir,
            retrieve_service=retriever,
            project_id=resolved_project_id,
            text=retrieve_text,
            run_name=run_name,
            limit=limit,
            export_to_file=export_to_file,
        )
@@ -0,0 +1,3 @@
1
+ from databao_context_engine.retrieve_embeddings.internal.retrieve_wiring import retrieve_embeddings
2
+
3
+ __all__ = ["retrieve_embeddings"]
File without changes
@@ -0,0 +1,35 @@
1
+ from typing import Any, Mapping, TextIO, cast
2
+
3
+ import yaml
4
+ from yaml import Node, SafeDumper
5
+
6
+
7
def default_representer(dumper: SafeDumper, data: object) -> Node:
    """Fallback YAML representer for objects with no dedicated representer.

    Mappings are dumped as dicts; other objects are dumped as a dict of their
    public attributes; anything else falls back to its str() form.
    """
    if isinstance(data, Mapping):
        return dumper.represent_dict(data)
    elif hasattr(data, "__dict__"):
        # Doesn't serialise "private" attributes (those that start with an _)
        data_public_attributes = {key: value for key, value in data.__dict__.items() if not key.startswith("_")}
        if data_public_attributes:
            return dumper.represent_dict(data_public_attributes)
        else:
            # If there are no public attributes, we default to the string representation
            return dumper.represent_str(str(data))
    else:
        return dumper.represent_str(str(data))
20
+
21
+
22
# Registered once, at import time of this module: routes any object without a
# more specific representer through default_representer when dumping with SafeDumper.
yaml.add_multi_representer(object, default_representer, Dumper=SafeDumper)
24
+
25
+
26
def write_yaml_to_stream(*, data: Any, file_stream: TextIO) -> None:
    """Serialise *data* as YAML directly into *file_stream*."""
    _to_yaml(data, file_stream)
28
+
29
+
30
def to_yaml_string(data: Any) -> str:
    """Serialise *data* to a YAML string."""
    # safe_dump returns a str when no stream is given, hence the cast.
    return cast(str, _to_yaml(data, None))
32
+
33
+
34
def _to_yaml(data: Any, stream: TextIO | None) -> str | None:
    """Dump with SafeDumper semantics, preserving key order and block (non-flow) style."""
    return yaml.safe_dump(data, stream, sort_keys=False, default_flow_style=False)
File without changes
@@ -0,0 +1,104 @@
1
+ import logging
2
+ from enum import Enum
3
+ from typing import cast
4
+
5
+ from databao_context_engine.llm.descriptions.provider import DescriptionProvider
6
+ from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
7
+ from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
8
+ from databao_context_engine.serialisation.yaml import to_yaml_string
9
+ from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
10
+ from databao_context_engine.services.models import ChunkEmbedding
11
+ from databao_context_engine.services.persistence_service import PersistenceService
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class ChunkEmbeddingMode(Enum):
    """Selects which text gets embedded for each chunk: the chunk's own
    embeddable text, an LLM-generated description, or both concatenated."""

    EMBEDDABLE_TEXT_ONLY = "EMBEDDABLE_TEXT_ONLY"
    GENERATED_DESCRIPTION_ONLY = "GENERATED_DESCRIPTION_ONLY"
    EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION = "EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION"

    def should_generate_description(self) -> bool:
        """True for every mode that needs an LLM-generated description."""
        # Every mode except plain embeddable text requires a description.
        return self is not ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY
26
+
27
+
28
class ChunkEmbeddingService:
    """Turns plugin chunks into persisted chunks plus embedding vectors.

    The embedding text for each chunk is chosen by `chunk_embedding_mode`;
    description-based modes require a DescriptionProvider (enforced in __init__).
    """

    def __init__(
        self,
        *,
        persistence_service: PersistenceService,
        embedding_provider: EmbeddingProvider,
        description_provider: DescriptionProvider | None,
        shard_resolver: EmbeddingShardResolver,
        chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
    ):
        self._persistence_service = persistence_service
        self._embedding_provider = embedding_provider
        self._description_provider = description_provider
        self._shard_resolver = shard_resolver
        self._chunk_embedding_mode = chunk_embedding_mode

        # Fail fast: description-generating modes cannot work without a provider.
        if self._chunk_embedding_mode.should_generate_description() and description_provider is None:
            raise ValueError("A DescriptionProvider must be provided when generating descriptions")

    def embed_chunks(self, *, datasource_run_id: int, chunks: list[EmbeddableChunk], result: str) -> None:
        """
        Turn plugin chunks into persisted chunks and embeddings

        Flow:
        1) Embed each chunk into an embedded vector
        2) Get or create embedding table for the appropriate model and embedding dimensions
        3) Persist chunks and embeddings vectors in a single transaction

        Args:
            datasource_run_id: ID of the datasource run these chunks belong to.
            chunks: Chunks produced by a build plugin; no-op when empty.
            result: Passed as LLM context when generating descriptions —
                presumably the plugin's full build result; confirm at call site.
        """

        if not chunks:
            return

        logger.debug(
            f"Embedding {len(chunks)} chunks for datasource run {datasource_run_id}, with chunk_embedding_mode={self._chunk_embedding_mode}"
        )

        enriched_embeddings: list[ChunkEmbedding] = []
        for chunk in chunks:
            # YAML rendering of the chunk content, kept as the display text.
            chunk_display_text = to_yaml_string(chunk.content)

            # Empty unless a description-generating mode fills it in below.
            generated_description = ""
            match self._chunk_embedding_mode:
                case ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY:
                    embedding_text = chunk.embeddable_text
                case ChunkEmbeddingMode.GENERATED_DESCRIPTION_ONLY:
                    # cast is safe: __init__ guarantees a provider for this mode.
                    generated_description = cast(DescriptionProvider, self._description_provider).describe(
                        text=chunk_display_text, context=result
                    )
                    embedding_text = generated_description
                case ChunkEmbeddingMode.EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION:
                    generated_description = cast(DescriptionProvider, self._description_provider).describe(
                        text=chunk_display_text, context=result
                    )
                    embedding_text = generated_description + "\n" + chunk.embeddable_text

            vec = self._embedding_provider.embed(embedding_text)

            enriched_embeddings.append(
                ChunkEmbedding(
                    chunk=chunk,
                    vec=vec,
                    display_text=chunk_display_text,
                    generated_description=generated_description,
                )
            )

        # Resolve (or lazily create) the per-model embedding table before persisting.
        table_name = self._shard_resolver.resolve_or_create(
            embedder=self._embedding_provider.embedder,
            model_id=self._embedding_provider.model_id,
            dim=self._embedding_provider.dim,
        )

        # Chunks and vectors are written together in one transaction.
        self._persistence_service.write_chunks_and_embeddings(
            datasource_run_id=datasource_run_id,
            chunk_embeddings=enriched_embeddings,
            table_name=table_name,
        )
@@ -0,0 +1,64 @@
1
+ import duckdb
2
+
3
+ from databao_context_engine.services.table_name_policy import TableNamePolicy
4
+ from databao_context_engine.storage.repositories.embedding_model_registry_repository import (
5
+ EmbeddingModelRegistryRepository,
6
+ )
7
+
8
+
9
class EmbeddingShardResolver:
    """Maps an (embedder, model_id) pair to its dedicated embedding table ("shard"),
    creating the table, its HNSW index, and the registry entry on first use."""

    def __init__(
        self,
        *,
        conn: duckdb.DuckDBPyConnection,
        registry_repo: EmbeddingModelRegistryRepository,
        table_name_policy: TableNamePolicy | None = None,
    ):
        self._conn = conn
        self._registry = registry_repo
        # Fall back to the default naming policy when none is supplied.
        self._policy = table_name_policy or TableNamePolicy()

    def resolve(self, *, embedder: str, model_id: str) -> tuple[str, int]:
        """Return (table_name, dim) for a registered model; raise ValueError if unknown."""
        row = self._registry.get(embedder=embedder, model_id=model_id)
        if not row:
            raise ValueError(f"Model not registered: {embedder}:{model_id}")
        return row.table_name, row.dim

    def resolve_or_create(self, *, embedder: str, model_id: str, dim: int) -> str:
        """Return the model's table name, creating table + registry entry on first use.

        Raises ValueError if the model is already registered with a different dimension.
        """
        row = self._registry.get(embedder=embedder, model_id=model_id)
        if row:
            # A model's dimension is fixed at registration; a mismatch is a caller error.
            if row.dim != dim:
                raise ValueError(f"Model already registered with dim={row.dim}, requested dim={dim}")
            return row.table_name

        table_name = self._policy.build(embedder=embedder, model_id=model_id, dim=dim)
        # Create the physical table before registering it, so the registry never
        # points at a table that does not exist.
        self._create_table_and_index(table_name, dim)

        self._registry.create(
            embedder=embedder,
            model_id=model_id,
            dim=dim,
            table_name=table_name,
        )

        return table_name

    def _create_table_and_index(self, table_name: str, dim: int) -> None:
        """Create the fixed-dimension vector table and its cosine HNSW index."""
        # DuckDB's VSS extension must be loaded, and persisting HNSW indexes to
        # disk is still gated behind an experimental flag.
        self._conn.execute("LOAD vss;")
        self._conn.execute("SET hnsw_enable_experimental_persistence = true;")

        # table_name/dim are interpolated into DDL: they come from TableNamePolicy
        # and the registry, not user input — presumably sanitised there; confirm.
        self._conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                chunk_id BIGINT NOT NULL REFERENCES chunk(chunk_id),
                vec FLOAT[{dim}] NOT NULL,
                created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (chunk_id)
            )
            """
        )
        self._conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS emb_hnsw_{table_name} ON {table_name} USING HNSW (vec) WITH (metric='cosine');
            """
        )
@@ -0,0 +1,88 @@
1
+ from _duckdb import DuckDBPyConnection
2
+
3
+ from databao_context_engine.build_sources.internal.build_service import BuildService
4
+ from databao_context_engine.llm.descriptions.provider import DescriptionProvider
5
+ from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
6
+ from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
7
+ from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode, ChunkEmbeddingService
8
+ from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
9
+ from databao_context_engine.services.persistence_service import PersistenceService
10
+ from databao_context_engine.services.table_name_policy import TableNamePolicy
11
+ from databao_context_engine.storage.repositories.factories import (
12
+ create_chunk_repository,
13
+ create_datasource_run_repository,
14
+ create_embedding_repository,
15
+ create_registry_repository,
16
+ create_run_repository,
17
+ create_vector_search_repository,
18
+ )
19
+
20
+
21
def create_shard_resolver(conn: DuckDBPyConnection, policy: TableNamePolicy | None = None) -> EmbeddingShardResolver:
    """Build an EmbeddingShardResolver over *conn*, defaulting the table-name policy."""
    # NOTE(review): this module imports DuckDBPyConnection from the private
    # `_duckdb` module; prefer `from duckdb import DuckDBPyConnection` — confirm.
    registry_repo = create_registry_repository(conn)
    table_name_policy = policy or TableNamePolicy()
    return EmbeddingShardResolver(conn=conn, registry_repo=registry_repo, table_name_policy=table_name_policy)
25
+
26
+
27
def create_persistence_service(conn: DuckDBPyConnection) -> PersistenceService:
    """Assemble a PersistenceService with chunk and embedding repositories on *conn*."""
    chunk_repo = create_chunk_repository(conn)
    embedding_repo = create_embedding_repository(conn)
    return PersistenceService(conn=conn, chunk_repo=chunk_repo, embedding_repo=embedding_repo)
31
+
32
+
33
def create_chunk_embedding_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
    description_provider: DescriptionProvider | None,
    chunk_embedding_mode: ChunkEmbeddingMode,
) -> ChunkEmbeddingService:
    """Wire a ChunkEmbeddingService from a shard resolver and persistence service on *conn*."""
    shard_resolver = create_shard_resolver(conn)
    persistence_service = create_persistence_service(conn)
    return ChunkEmbeddingService(
        persistence_service=persistence_service,
        embedding_provider=embedding_provider,
        shard_resolver=shard_resolver,
        description_provider=description_provider,
        chunk_embedding_mode=chunk_embedding_mode,
    )
49
+
50
+
51
def create_build_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
    description_provider: DescriptionProvider | None,
    chunk_embedding_mode: ChunkEmbeddingMode,
) -> BuildService:
    """Wire a BuildService: run repositories plus a fully configured chunk-embedding service."""
    return BuildService(
        run_repo=create_run_repository(conn),
        datasource_run_repo=create_datasource_run_repository(conn),
        chunk_embedding_service=create_chunk_embedding_service(
            conn,
            embedding_provider=embedding_provider,
            description_provider=description_provider,
            chunk_embedding_mode=chunk_embedding_mode,
        ),
    )
72
+
73
+
74
def create_retrieve_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
) -> RetrieveService:
    """Wire a RetrieveService for vector-similarity search over *conn*."""
    return RetrieveService(
        run_repo=create_run_repository(conn),
        vector_search_repo=create_vector_search_repository(conn),
        shard_resolver=create_shard_resolver(conn),
        provider=embedding_provider,
    )
@@ -0,0 +1,12 @@
1
+ from collections.abc import Sequence
2
+ from dataclasses import dataclass
3
+
4
+ from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
5
+
6
+
7
@dataclass(frozen=True)
class ChunkEmbedding:
    """A plugin chunk paired with its embedding vector and presentation texts."""

    # The original chunk produced by a build plugin.
    chunk: EmbeddableChunk
    # The embedding vector computed for this chunk.
    vec: Sequence[float]
    # Text shown to users at retrieval time (YAML rendering of the chunk content).
    display_text: str
    # LLM-generated description; empty string when description generation is disabled.
    generated_description: str
@@ -0,0 +1,61 @@
1
+ from collections.abc import Sequence
2
+
3
+ import duckdb
4
+
5
+ from databao_context_engine.services.models import ChunkEmbedding
6
+ from databao_context_engine.storage.models import ChunkDTO
7
+ from databao_context_engine.storage.repositories.chunk_repository import ChunkRepository
8
+ from databao_context_engine.storage.repositories.embedding_repository import EmbeddingRepository
9
+ from databao_context_engine.storage.transaction import transaction
10
+
11
+
12
class PersistenceService:
    """Persists chunks and their embedding vectors atomically."""

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        chunk_repo: ChunkRepository,
        embedding_repo: EmbeddingRepository,
        *,
        dim: int = 768,
    ):
        self._conn = conn
        self._chunk_repo = chunk_repo
        self._embedding_repo = embedding_repo
        # NOTE(review): `dim` is stored but never read anywhere in this class —
        # kept for backward compatibility with existing callers; consider removing.
        self._dim = dim

    def write_chunks_and_embeddings(
        self, *, datasource_run_id: int, chunk_embeddings: list[ChunkEmbedding], table_name: str
    ) -> int:
        """
        Atomically persist chunks and their vectors.
        Returns the number of embeddings written.

        Raises:
            ValueError: If `chunk_embeddings` is empty.
        """
        if not chunk_embeddings:
            raise ValueError("chunk_embeddings must be a non-empty list")

        # One transaction for all chunk+vector pairs: either everything lands or nothing does.
        with transaction(self._conn):
            for chunk_embedding in chunk_embeddings:
                chunk_dto = self.create_chunk(
                    datasource_run_id=datasource_run_id,
                    embeddable_text=chunk_embedding.chunk.embeddable_text,
                    display_text=chunk_embedding.display_text,
                    generated_description=chunk_embedding.generated_description,
                )
                self.create_embedding(table_name=table_name, chunk_id=chunk_dto.chunk_id, vec=chunk_embedding.vec)

        # The docstring always promised a count; previously the method returned None.
        return len(chunk_embeddings)

    def create_chunk(
        self, *, datasource_run_id: int, embeddable_text: str, display_text: str, generated_description: str
    ) -> ChunkDTO:
        """Persist a single chunk row and return its DTO (including the new chunk_id)."""
        return self._chunk_repo.create(
            datasource_run_id=datasource_run_id,
            embeddable_text=embeddable_text,
            display_text=display_text,
            generated_description=generated_description,
        )

    def create_embedding(self, *, table_name: str, chunk_id: int, vec: Sequence[float]):
        """Persist one embedding vector for `chunk_id` into the given shard table."""
        self._embedding_repo.create(
            table_name=table_name,
            chunk_id=chunk_id,
            vec=vec,
        )
@@ -0,0 +1,8 @@
1
+ from datetime import datetime
2
+
3
+
4
class RunNamePolicy:
    """Derives run directory names from a run's start timestamp."""

    # All run directories share this prefix.
    _RUN_DIR_PREFIX = "run-"

    def build(self, *, run_started_at: datetime) -> str:
        """Return the run name: the prefix followed by the start time at second precision."""
        timestamp = run_started_at.isoformat(timespec="seconds")
        return RunNamePolicy._RUN_DIR_PREFIX + timestamp