databao-context-engine 0.1.2__py3-none-any.whl → 0.1.4.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. databao_context_engine/__init__.py +18 -6
  2. databao_context_engine/build_sources/__init__.py +4 -0
  3. databao_context_engine/build_sources/{internal/build_runner.py → build_runner.py} +27 -23
  4. databao_context_engine/build_sources/build_service.py +53 -0
  5. databao_context_engine/build_sources/build_wiring.py +84 -0
  6. databao_context_engine/build_sources/export_results.py +41 -0
  7. databao_context_engine/build_sources/{internal/plugin_execution.py → plugin_execution.py} +3 -7
  8. databao_context_engine/cli/add_datasource_config.py +41 -15
  9. databao_context_engine/cli/commands.py +12 -43
  10. databao_context_engine/cli/info.py +3 -2
  11. databao_context_engine/databao_context_engine.py +137 -0
  12. databao_context_engine/databao_context_project_manager.py +96 -6
  13. databao_context_engine/datasources/add_config.py +34 -0
  14. databao_context_engine/{datasource_config → datasources}/check_config.py +18 -7
  15. databao_context_engine/datasources/datasource_context.py +93 -0
  16. databao_context_engine/{project → datasources}/datasource_discovery.py +17 -16
  17. databao_context_engine/{project → datasources}/types.py +64 -15
  18. databao_context_engine/init_project.py +25 -3
  19. databao_context_engine/introspection/property_extract.py +67 -53
  20. databao_context_engine/llm/errors.py +2 -8
  21. databao_context_engine/llm/install.py +13 -20
  22. databao_context_engine/llm/service.py +1 -3
  23. databao_context_engine/mcp/mcp_runner.py +4 -2
  24. databao_context_engine/mcp/mcp_server.py +10 -10
  25. databao_context_engine/plugin_loader.py +111 -0
  26. databao_context_engine/pluginlib/build_plugin.py +25 -9
  27. databao_context_engine/pluginlib/config.py +16 -2
  28. databao_context_engine/plugins/base_db_plugin.py +5 -2
  29. databao_context_engine/plugins/databases/athena_introspector.py +85 -22
  30. databao_context_engine/plugins/databases/base_introspector.py +5 -3
  31. databao_context_engine/plugins/databases/clickhouse_introspector.py +22 -11
  32. databao_context_engine/plugins/databases/duckdb_introspector.py +3 -5
  33. databao_context_engine/plugins/databases/introspection_model_builder.py +1 -1
  34. databao_context_engine/plugins/databases/introspection_scope.py +11 -9
  35. databao_context_engine/plugins/databases/introspection_scope_matcher.py +2 -5
  36. databao_context_engine/plugins/databases/mssql_introspector.py +26 -17
  37. databao_context_engine/plugins/databases/mysql_introspector.py +23 -12
  38. databao_context_engine/plugins/databases/postgresql_introspector.py +2 -2
  39. databao_context_engine/plugins/databases/snowflake_introspector.py +43 -10
  40. databao_context_engine/plugins/duckdb_tools.py +18 -0
  41. databao_context_engine/plugins/plugin_loader.py +43 -42
  42. databao_context_engine/plugins/resources/parquet_introspector.py +7 -19
  43. databao_context_engine/project/info.py +34 -2
  44. databao_context_engine/project/init_project.py +16 -7
  45. databao_context_engine/project/layout.py +3 -3
  46. databao_context_engine/retrieve_embeddings/__init__.py +3 -0
  47. databao_context_engine/retrieve_embeddings/{internal/export_results.py → export_results.py} +2 -2
  48. databao_context_engine/retrieve_embeddings/{internal/retrieve_runner.py → retrieve_runner.py} +5 -9
  49. databao_context_engine/retrieve_embeddings/{internal/retrieve_service.py → retrieve_service.py} +3 -17
  50. databao_context_engine/retrieve_embeddings/retrieve_wiring.py +49 -0
  51. databao_context_engine/{serialisation → serialization}/yaml.py +1 -1
  52. databao_context_engine/services/chunk_embedding_service.py +23 -11
  53. databao_context_engine/services/factories.py +1 -46
  54. databao_context_engine/services/persistence_service.py +11 -11
  55. databao_context_engine/storage/connection.py +11 -7
  56. databao_context_engine/storage/exceptions/exceptions.py +2 -2
  57. databao_context_engine/storage/migrate.py +2 -4
  58. databao_context_engine/storage/migrations/V01__init.sql +6 -31
  59. databao_context_engine/storage/models.py +2 -23
  60. databao_context_engine/storage/repositories/chunk_repository.py +16 -12
  61. databao_context_engine/storage/repositories/factories.py +1 -12
  62. databao_context_engine/storage/repositories/vector_search_repository.py +8 -13
  63. databao_context_engine/system/properties.py +4 -2
  64. databao_context_engine-0.1.4.dev1.dist-info/METADATA +75 -0
  65. databao_context_engine-0.1.4.dev1.dist-info/RECORD +125 -0
  66. {databao_context_engine-0.1.2.dist-info → databao_context_engine-0.1.4.dev1.dist-info}/WHEEL +1 -1
  67. databao_context_engine/build_sources/internal/build_service.py +0 -77
  68. databao_context_engine/build_sources/internal/build_wiring.py +0 -52
  69. databao_context_engine/build_sources/internal/export_results.py +0 -43
  70. databao_context_engine/build_sources/public/api.py +0 -4
  71. databao_context_engine/databao_engine.py +0 -85
  72. databao_context_engine/datasource_config/__init__.py +0 -0
  73. databao_context_engine/datasource_config/add_config.py +0 -50
  74. databao_context_engine/datasource_config/datasource_context.py +0 -60
  75. databao_context_engine/mcp/all_results_tool.py +0 -5
  76. databao_context_engine/mcp/retrieve_tool.py +0 -22
  77. databao_context_engine/project/runs.py +0 -39
  78. databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
  79. databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +0 -29
  80. databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
  81. databao_context_engine/retrieve_embeddings/public/api.py +0 -3
  82. databao_context_engine/serialisation/__init__.py +0 -0
  83. databao_context_engine/services/run_name_policy.py +0 -8
  84. databao_context_engine/storage/repositories/datasource_run_repository.py +0 -136
  85. databao_context_engine/storage/repositories/run_repository.py +0 -157
  86. databao_context_engine-0.1.2.dist-info/METADATA +0 -187
  87. databao_context_engine-0.1.2.dist-info/RECORD +0 -135
  88. /databao_context_engine/{build_sources/internal → datasources}/__init__.py +0 -0
  89. /databao_context_engine/{build_sources/public → serialization}/__init__.py +0 -0
  90. {databao_context_engine-0.1.2.dist-info → databao_context_engine-0.1.4.dev1.dist-info}/entry_points.txt +0 -0
@@ -1,22 +1,57 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Mapping
3
+ from typing import Annotated, Any, Dict, List
4
4
 
5
5
  import snowflake.connector
6
- from pydantic import Field
6
+ from pydantic import BaseModel, Field
7
7
  from snowflake.connector import DictCursor
8
8
 
9
+ from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
9
10
  from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
10
11
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
11
12
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
12
13
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
13
14
 
14
15
 
16
+ class SnowflakePasswordAuth(BaseModel):
17
+ password: Annotated[str, ConfigPropertyAnnotation(secret=True)]
18
+
19
+
20
+ class SnowflakeKeyPairAuth(BaseModel):
21
+ private_key_file: str | None = None
22
+ private_key_file_pwd: str | None = None
23
+ private_key: Annotated[str, ConfigPropertyAnnotation(secret=True)]
24
+
25
+
26
+ class SnowflakeSSOAuth(BaseModel):
27
+ authenticator: str = Field(description='e.g. "externalbrowser"')
28
+
29
+
30
+ class SnowflakeConnectionProperties(BaseModel):
31
+ account: Annotated[str, ConfigPropertyAnnotation(required=True)]
32
+ warehouse: str | None = None
33
+ database: str | None = None
34
+ user: str | None = None
35
+ role: str | None = None
36
+ auth: SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth
37
+ additional_properties: dict[str, Any] = {}
38
+
39
+ def to_snowflake_kwargs(self) -> dict[str, Any]:
40
+ kwargs = self.model_dump(
41
+ exclude={
42
+ "additional_properties": True,
43
+ },
44
+ exclude_none=True,
45
+ )
46
+ auth_fields = kwargs.pop("auth", {})
47
+ kwargs.update(auth_fields)
48
+ kwargs.update(self.additional_properties)
49
+ return kwargs
50
+
51
+
15
52
  class SnowflakeConfigFile(BaseDatabaseConfigFile):
16
- type: str = Field(default="databases/snowflake")
17
- connection: dict[str, Any] = Field(
18
- description="Connection parameters for Snowflake. It can contain any of the keys supported by the Snowflake connection library"
19
- )
53
+ type: str = Field(default="snowflake")
54
+ connection: SnowflakeConnectionProperties
20
55
 
21
56
 
22
57
  class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):
@@ -28,11 +63,9 @@ class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):
28
63
 
29
64
  def _connect(self, file_config: SnowflakeConfigFile):
30
65
  connection = file_config.connection
31
- if not isinstance(connection, Mapping):
32
- raise ValueError("Invalid YAML config: 'connection' must be a mapping of connection parameters")
33
66
  snowflake.connector.paramstyle = "qmark"
34
67
  return snowflake.connector.connect(
35
- **connection,
68
+ **connection.to_snowflake_kwargs(),
36
69
  )
37
70
 
38
71
  def _connect_to_catalog(self, file_config: SnowflakeConfigFile, catalog: str):
@@ -41,7 +74,7 @@ class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):
41
74
  return snowflake.connector.connect(**cfg)
42
75
 
43
76
  def _get_catalogs(self, connection, file_config: SnowflakeConfigFile) -> list[str]:
44
- database = file_config.connection.get("database")
77
+ database = file_config.connection.database
45
78
  if database:
46
79
  return [database]
47
80
 
@@ -0,0 +1,18 @@
1
+ from typing import Any
2
+
3
+ from databao_context_engine.pluginlib.config import DuckDBSecret
4
+
5
+
6
+ def generate_create_secret_sql(secret_name, duckdb_secret: DuckDBSecret) -> str:
7
+ parameters = [("type", duckdb_secret.type)] + list(duckdb_secret.properties.items())
8
+ return f"""CREATE SECRET {secret_name} (
9
+ {", ".join([f"{k} '{v}'" for (k, v) in parameters])}
10
+ );
11
+ """
12
+
13
+
14
+ def fetchall_dicts(cur, sql: str, params=None) -> list[dict[str, Any]]:
15
+ cur.execute(sql, params or [])
16
+ columns = [desc[0].lower() for desc in cur.description] if cur.description else []
17
+ rows = cur.fetchall()
18
+ return [dict(zip(columns, row)) for row in rows]
@@ -14,29 +14,11 @@ class DuplicatePluginTypeError(RuntimeError):
14
14
  """Raised when two plugins register the same <main>/<sub> plugin key."""
15
15
 
16
16
 
17
- PluginList = dict[DatasourceType, BuildPlugin]
18
-
19
-
20
- def get_all_available_plugin_types(exclude_file_plugins: bool = False) -> set[DatasourceType]:
21
- return set(load_plugins(exclude_file_plugins=exclude_file_plugins).keys())
22
-
23
-
24
- def get_plugin_for_type(datasource_type: DatasourceType) -> BuildPlugin:
25
- all_plugins = load_plugins()
26
-
27
- if datasource_type not in all_plugins:
28
- raise ValueError(f"No plugin found for type '{datasource_type.full_type}'")
29
-
30
- return load_plugins()[datasource_type]
31
-
32
-
33
- def load_plugins(exclude_file_plugins: bool = False) -> PluginList:
34
- """
35
- Loads both builtin and external plugins and merges them into one list
36
- """
17
+ def load_plugins(exclude_file_plugins: bool = False) -> dict[DatasourceType, BuildPlugin]:
18
+ """Load both builtin and external plugins and merges them into one list."""
37
19
  builtin_plugins = _load_builtin_plugins(exclude_file_plugins)
38
20
  external_plugins = _load_external_plugins(exclude_file_plugins)
39
- plugins = merge_plugins(builtin_plugins, external_plugins)
21
+ plugins = _merge_plugins(builtin_plugins, external_plugins)
40
22
 
41
23
  return plugins
42
24
 
@@ -61,16 +43,9 @@ def _load_builtin_file_plugins() -> list[BuildFilePlugin]:
61
43
 
62
44
 
63
45
  def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
64
- """
65
- Statically register built-in plugins
66
- """
67
- from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
68
- from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
46
+ """Statically register built-in plugins."""
69
47
  from databao_context_engine.plugins.duckdb_db_plugin import DuckDbPlugin
70
- from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
71
48
  from databao_context_engine.plugins.parquet_plugin import ParquetPlugin
72
- from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
73
- from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin
74
49
 
75
50
  # optional plugins are added to the python environment via extras
76
51
  optional_plugins: list[BuildDatasourcePlugin] = []
@@ -81,31 +56,57 @@ def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
81
56
  except ImportError:
82
57
  pass
83
58
 
59
+ try:
60
+ from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
61
+
62
+ optional_plugins.append(ClickhouseDbPlugin())
63
+ except ImportError:
64
+ pass
65
+
66
+ try:
67
+ from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
68
+
69
+ optional_plugins.append(AthenaDbPlugin())
70
+ except ImportError:
71
+ pass
72
+
73
+ try:
74
+ from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin
75
+
76
+ optional_plugins.append(SnowflakeDbPlugin())
77
+ except ImportError:
78
+ pass
79
+
80
+ try:
81
+ from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
82
+
83
+ optional_plugins.append(MySQLDbPlugin())
84
+ except ImportError:
85
+ pass
86
+
87
+ try:
88
+ from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
89
+
90
+ optional_plugins.append(PostgresqlDbPlugin())
91
+ except ImportError:
92
+ pass
93
+
84
94
  required_plugins: list[BuildDatasourcePlugin] = [
85
- AthenaDbPlugin(),
86
- ClickhouseDbPlugin(),
87
95
  DuckDbPlugin(),
88
- MySQLDbPlugin(),
89
- PostgresqlDbPlugin(),
90
- SnowflakeDbPlugin(),
91
96
  ParquetPlugin(),
92
97
  ]
93
98
  return required_plugins + optional_plugins
94
99
 
95
100
 
96
101
  def _load_external_plugins(exclude_file_plugins: bool = False) -> list[BuildPlugin]:
97
- """
98
- Discover external plugins via entry points
99
- """
102
+ """Discover external plugins via entry points."""
100
103
  # TODO: implement external plugin loading
101
104
  return []
102
105
 
103
106
 
104
- def merge_plugins(*plugin_lists: list[BuildPlugin]) -> PluginList:
105
- """
106
- Merge multiple plugin maps
107
- """
108
- registry: PluginList = {}
107
+ def _merge_plugins(*plugin_lists: list[BuildPlugin]) -> dict[DatasourceType, BuildPlugin]:
108
+ """Merge multiple plugin maps."""
109
+ registry: dict[DatasourceType, BuildPlugin] = {}
109
110
  for plugins in plugin_lists:
110
111
  for plugin in plugins:
111
112
  for full_type in plugin.supported_types():
@@ -6,10 +6,11 @@ from dataclasses import dataclass, replace
6
6
  from urllib.parse import urlparse
7
7
 
8
8
  import duckdb
9
- from _duckdb import DuckDBPyConnection
9
+ from duckdb import DuckDBPyConnection
10
10
  from pydantic import BaseModel, Field
11
11
 
12
12
  from databao_context_engine.pluginlib.config import DuckDBSecret
13
+ from databao_context_engine.plugins.duckdb_tools import fetchall_dicts, generate_create_secret_sql
13
14
 
14
15
  parquet_type = "resources/parquet"
15
16
 
@@ -18,9 +19,8 @@ logger = logging.getLogger(__name__)
18
19
 
19
20
  class ParquetConfigFile(BaseModel):
20
21
  name: str | None = Field(default=None)
21
- type: str = Field(default=parquet_type)
22
+ type: str = Field(default="parquet")
22
23
  url: str = Field(
23
- default=type,
24
24
  description="Parquet resource location. Should be a valid URL or a path to a local file. "
25
25
  "Examples: s3://your_bucket/file.parquet, s3://your-bucket/*.parquet, https://some.url/some_file.parquet, ~/path_to/file.parquet",
26
26
  )
@@ -50,14 +50,6 @@ class ParquetIntrospectionResult:
50
50
  files: list[ParquetFile]
51
51
 
52
52
 
53
- def generate_create_secret_sql(secret_name, duckdb_secret: DuckDBSecret) -> str:
54
- parameters = [("type", duckdb_secret.type)] + list(duckdb_secret.properties.items())
55
- return f"""CREATE SECRET {secret_name} (
56
- {", ".join([f"{k} {v}" for (k, v) in parameters])}
57
- );
58
- """
59
-
60
-
61
53
  @contextlib.contextmanager
62
54
  def _create_secret(conn: DuckDBPyConnection, duckdb_secret: DuckDBSecret):
63
55
  secret_name = duckdb_secret.name or "gen_secret_" + str(uuid.uuid4()).replace("-", "_")
@@ -98,10 +90,9 @@ class ParquetIntrospector:
98
90
  with self._connect(file_config) as conn:
99
91
  with conn.cursor() as cur:
100
92
  resolved_url = _resolve_url(file_config)
101
- cur.execute(f"SELECT * FROM parquet_file_metadata('{resolved_url}') LIMIT 1")
102
- columns = [desc[0].lower() for desc in cur.description] if cur.description else []
103
- rows = cur.fetchall()
104
- parquet_file_metadata = [dict(zip(columns, row)) for row in rows]
93
+ parquet_file_metadata = fetchall_dicts(
94
+ cur, f"SELECT * FROM parquet_file_metadata('{resolved_url}') LIMIT 1"
95
+ )
105
96
  if not parquet_file_metadata:
106
97
  raise ValueError(f"No parquet files found by url {resolved_url}")
107
98
  if not parquet_file_metadata or not parquet_file_metadata[0]["file_name"]:
@@ -111,10 +102,7 @@ class ParquetIntrospector:
111
102
  with self._connect(file_config) as conn:
112
103
  with conn.cursor() as cur:
113
104
  resolved_url = _resolve_url(file_config)
114
- cur.execute(f"SELECT * from parquet_metadata('{resolved_url}')")
115
- cols = [desc[0].lower() for desc in cur.description] if cur.description else []
116
- rows = cur.fetchall()
117
- file_metas = [dict(zip(cols, row)) for row in rows]
105
+ file_metas = fetchall_dicts(cur, f"SELECT * from parquet_metadata('{resolved_url}')")
118
106
 
119
107
  columns_per_file: dict[str, dict[int, ParquetColumn]] = defaultdict(defaultdict)
120
108
  for file_meta in file_metas:
@@ -3,29 +3,56 @@ from importlib.metadata import version
3
3
  from pathlib import Path
4
4
  from uuid import UUID
5
5
 
6
+ from databao_context_engine.plugins.plugin_loader import load_plugins
6
7
  from databao_context_engine.project.layout import validate_project_dir
7
8
  from databao_context_engine.system.properties import get_dce_path
8
9
 
9
10
 
10
11
  @dataclass(kw_only=True, frozen=True)
11
12
  class DceProjectInfo:
13
+ """Information about a Databao Context Engine project.
14
+
15
+ Attributes:
16
+ project_path: The root directory of the Databao Context Engine project.
17
+ is_initialized: Whether the project has been initialized.
18
+ project_id: The UUID of the project, or None if the project has not been initialized.
19
+ """
20
+
12
21
  project_path: Path
13
- is_initialised: bool
22
+ is_initialized: bool
14
23
  project_id: UUID | None
15
24
 
16
25
 
17
26
  @dataclass(kw_only=True, frozen=True)
18
27
  class DceInfo:
28
+ """Information about the current Databao Context Engine installation and project.
29
+
30
+ Attributes:
31
+ version: The version of the databao_context_engine package installed on the system.
32
+ dce_path: The path where databao_context_engine stores its global data.
33
+ project_info: Information about the Databao Context Engine project.
34
+ """
35
+
19
36
  version: str
20
37
  dce_path: Path
38
+ plugin_ids: list[str]
21
39
 
22
40
  project_info: DceProjectInfo
23
41
 
24
42
 
25
43
  def get_databao_context_engine_info(project_dir: Path) -> DceInfo:
44
+ """Return information about the current Databao Context Engine installation and project.
45
+
46
+ Args:
47
+ project_dir: The root directory of the Databao Context Project.
48
+
49
+ Returns:
50
+ A DceInfo instance containing information about the Databao Context Engine installation and project.
51
+ """
26
52
  return DceInfo(
27
53
  version=get_dce_version(),
28
54
  dce_path=get_dce_path(),
55
+ plugin_ids=[plugin.id for plugin in load_plugins().values()],
29
56
  project_info=_get_project_info(project_dir),
30
57
  )
31
58
 
@@ -35,10 +62,15 @@ def _get_project_info(project_dir: Path) -> DceProjectInfo:
35
62
 
36
63
  return DceProjectInfo(
37
64
  project_path=project_dir,
38
- is_initialised=project_layout is not None,
65
+ is_initialized=project_layout is not None,
39
66
  project_id=project_layout.read_config_file().project_id if project_layout is not None else None,
40
67
  )
41
68
 
42
69
 
43
70
  def get_dce_version() -> str:
71
+ """Return the installed version of the databao_context_engine package.
72
+
73
+ Returns:
74
+ The installed version of the databao_context_engine package.
75
+ """
44
76
  return version("databao_context_engine")
@@ -13,12 +13,21 @@ from databao_context_engine.project.project_config import ProjectConfig
13
13
 
14
14
 
15
15
  class InitErrorReason(Enum):
16
+ """Reasons for which project initialization can fail."""
17
+
16
18
  PROJECT_DIR_DOESNT_EXIST = "PROJECT_DIR_DOESNT_EXIST"
17
19
  PROJECT_DIR_NOT_DIRECTORY = "PROJECT_DIR_NOT_DIRECTORY"
18
- PROJECT_DIR_ALREADY_INITIALISED = "PROJECT_DIR_ALREADY_INITIALISED"
20
+ PROJECT_DIR_ALREADY_INITIALIZED = "PROJECT_DIR_ALREADY_INITIALIZED"
19
21
 
20
22
 
21
23
  class InitProjectError(Exception):
24
+ """Raised when a project can't be initialized.
25
+
26
+ Attributes:
27
+ message: The error message.
28
+ reason: The reason for the initialization failure.
29
+ """
30
+
22
31
  reason: InitErrorReason
23
32
 
24
33
  def __init__(self, reason: InitErrorReason, message: str | None):
@@ -65,20 +74,20 @@ class _ProjectCreator:
65
74
 
66
75
  if self.config_file.is_file() or self.deprecated_config_file.is_file():
67
76
  raise InitProjectError(
68
- message=f"Can't initialise a Databao Context Engine project in a folder that already contains a config file. [project_dir: {self.project_dir.resolve()}]",
69
- reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALISED,
77
+ message=f"Can't initialize a Databao Context Engine project in a folder that already contains a config file. [project_dir: {self.project_dir.resolve()}]",
78
+ reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
70
79
  )
71
80
 
72
81
  if self.src_dir.is_dir():
73
82
  raise InitProjectError(
74
- message=f"Can't initialise a Databao Context Engine project in a folder that already contains a src directory. [project_dir: {self.project_dir.resolve()}]",
75
- reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALISED,
83
+ message=f"Can't initialize a Databao Context Engine project in a folder that already contains a src directory. [project_dir: {self.project_dir.resolve()}]",
84
+ reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
76
85
  )
77
86
 
78
87
  if self.examples_dir.is_file():
79
88
  raise InitProjectError(
80
- message=f"Can't initialise a Databao Context Engine project in a folder that already contains an examples dir. [project_dir: {self.project_dir.resolve()}]",
81
- reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALISED,
89
+ message=f"Can't initialize a Databao Context Engine project in a folder that already contains an examples dir. [project_dir: {self.project_dir.resolve()}]",
90
+ reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
82
91
  )
83
92
 
84
93
  return True
@@ -2,8 +2,8 @@ import logging
2
2
  from dataclasses import dataclass
3
3
  from pathlib import Path
4
4
 
5
+ from databao_context_engine.datasources.types import DatasourceId
5
6
  from databao_context_engine.project.project_config import ProjectConfig
6
- from databao_context_engine.project.types import DatasourceId
7
7
 
8
8
  SOURCE_FOLDER_NAME = "src"
9
9
  OUTPUT_FOLDER_NAME = "output"
@@ -96,12 +96,12 @@ class _ProjectValidator:
96
96
 
97
97
  if self.config_file is None:
98
98
  raise ValueError(
99
- f"The current project directory has not been initialised. It should contain a config file. [project_dir: {self.project_dir.resolve()}]"
99
+ f"The current project directory has not been initialized. It should contain a config file. [project_dir: {self.project_dir.resolve()}]"
100
100
  )
101
101
 
102
102
  if not self.is_src_valid():
103
103
  raise ValueError(
104
- f"The current project directory has not been initialised. It should contain a src directory. [project_dir: {self.project_dir.resolve()}]"
104
+ f"The current project directory has not been initialized. It should contain a src directory. [project_dir: {self.project_dir.resolve()}]"
105
105
  )
106
106
 
107
107
  return ProjectLayout(project_dir=self.project_dir, config_file=self.config_file)
@@ -0,0 +1,3 @@
1
+ from databao_context_engine.retrieve_embeddings.retrieve_wiring import retrieve_embeddings
2
+
3
+ __all__ = ["retrieve_embeddings"]
@@ -1,8 +1,8 @@
1
1
  from pathlib import Path
2
2
 
3
3
 
4
- def export_retrieve_results(run_dir: Path, retrieve_results: list[str]) -> Path:
5
- path = run_dir.joinpath("context_duckdb.yaml")
4
+ def export_retrieve_results(output_dir: Path, retrieve_results: list[str]) -> Path:
5
+ path = output_dir.joinpath("context_duckdb.yaml")
6
6
 
7
7
  with path.open("w") as export_file:
8
8
  for result in retrieve_results:
@@ -1,9 +1,9 @@
1
1
  import logging
2
2
  from pathlib import Path
3
3
 
4
- from databao_context_engine.project.runs import get_run_dir
5
- from databao_context_engine.retrieve_embeddings.internal.export_results import export_retrieve_results
6
- from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
4
+ from databao_context_engine.project.layout import get_output_dir
5
+ from databao_context_engine.retrieve_embeddings.export_results import export_retrieve_results
6
+ from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
7
7
  from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
8
8
 
9
9
  logger = logging.getLogger(__name__)
@@ -15,17 +15,13 @@ def retrieve(
15
15
  retrieve_service: RetrieveService,
16
16
  project_id: str,
17
17
  text: str,
18
- run_name: str | None,
19
18
  limit: int | None,
20
19
  export_to_file: bool,
21
20
  ) -> list[VectorSearchResult]:
22
- resolved_run_name = retrieve_service.resolve_run_name(project_id=project_id, run_name=run_name)
23
- retrieve_results = retrieve_service.retrieve(
24
- project_id=project_id, text=text, run_name=resolved_run_name, limit=limit
25
- )
21
+ retrieve_results = retrieve_service.retrieve(project_id=project_id, text=text, limit=limit)
26
22
 
27
23
  if export_to_file:
28
- export_directory = get_run_dir(project_dir=project_dir, run_name=resolved_run_name)
24
+ export_directory = get_output_dir(project_dir)
29
25
 
30
26
  display_texts = [result.display_text for result in retrieve_results]
31
27
  export_file = export_retrieve_results(export_directory, display_texts)
@@ -2,9 +2,7 @@ import logging
2
2
  from collections.abc import Sequence
3
3
 
4
4
  from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
5
- from databao_context_engine.project.runs import resolve_run_name_from_repo
6
5
  from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
7
- from databao_context_engine.storage.repositories.run_repository import RunRepository
8
6
  from databao_context_engine.storage.repositories.vector_search_repository import (
9
7
  VectorSearchRepository,
10
8
  VectorSearchResult,
@@ -17,43 +15,34 @@ class RetrieveService:
17
15
  def __init__(
18
16
  self,
19
17
  *,
20
- run_repo: RunRepository,
21
18
  vector_search_repo: VectorSearchRepository,
22
19
  shard_resolver: EmbeddingShardResolver,
23
20
  provider: EmbeddingProvider,
24
21
  ):
25
- self._run_repo = run_repo
26
22
  self._shard_resolver = shard_resolver
27
23
  self._provider = provider
28
24
  self._vector_search_repo = vector_search_repo
29
25
 
30
- def retrieve(
31
- self, *, project_id: str, text: str, run_name: str, limit: int | None = None
32
- ) -> list[VectorSearchResult]:
26
+ def retrieve(self, *, project_id: str, text: str, limit: int | None = None) -> list[VectorSearchResult]:
33
27
  if limit is None:
34
28
  limit = 10
35
29
 
36
- run = self._run_repo.get_by_run_name(project_id=project_id, run_name=run_name)
37
- if run is None:
38
- raise LookupError(f"Run '{run_name}' not found for project '{project_id}'.")
39
-
40
30
  table_name, dimension = self._shard_resolver.resolve(
41
31
  embedder=self._provider.embedder, model_id=self._provider.model_id
42
32
  )
43
33
 
44
34
  retrieve_vec: Sequence[float] = self._provider.embed(text)
45
35
 
46
- logger.debug(f"Retrieving display texts for run {run.run_id} in table {table_name}")
36
+ logger.debug(f"Retrieving display texts in table {table_name}")
47
37
 
48
38
  search_results = self._vector_search_repo.get_display_texts_by_similarity(
49
39
  table_name=table_name,
50
- run_id=run.run_id,
51
40
  retrieve_vec=retrieve_vec,
52
41
  dimension=dimension,
53
42
  limit=limit,
54
43
  )
55
44
 
56
- logger.debug(f"Retrieved {len(search_results)} display texts for run {run.run_id} in table {table_name}")
45
+ logger.debug(f"Retrieved {len(search_results)} display texts in table {table_name}")
57
46
 
58
47
  if logger.isEnabledFor(logging.DEBUG):
59
48
  closest_result = min(search_results, key=lambda result: result.cosine_distance)
@@ -63,6 +52,3 @@ class RetrieveService:
63
52
  logger.debug(f"Worst result: ({farthest_result.cosine_distance}, {farthest_result.embeddable_text})")
64
53
 
65
54
  return search_results
66
-
67
- def resolve_run_name(self, *, project_id: str, run_name: str | None) -> str:
68
- return resolve_run_name_from_repo(run_repository=self._run_repo, project_id=project_id, run_name=run_name)
@@ -0,0 +1,49 @@
1
+ from duckdb import DuckDBPyConnection
2
+
3
+ from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
4
+ from databao_context_engine.llm.factory import create_ollama_embedding_provider, create_ollama_service
5
+ from databao_context_engine.project.layout import ProjectLayout, ensure_project_dir
6
+ from databao_context_engine.retrieve_embeddings.retrieve_runner import retrieve
7
+ from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
8
+ from databao_context_engine.services.factories import create_shard_resolver
9
+ from databao_context_engine.storage.connection import open_duckdb_connection
10
+ from databao_context_engine.storage.repositories.factories import create_vector_search_repository
11
+ from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
12
+ from databao_context_engine.system.properties import get_db_path
13
+
14
+
15
+ def retrieve_embeddings(
16
+ project_layout: ProjectLayout,
17
+ retrieve_text: str,
18
+ limit: int | None,
19
+ export_to_file: bool,
20
+ ) -> list[VectorSearchResult]:
21
+ ensure_project_dir(project_layout.project_dir)
22
+
23
+ with open_duckdb_connection(get_db_path(project_layout.project_dir)) as conn:
24
+ ollama_service = create_ollama_service()
25
+ embedding_provider = create_ollama_embedding_provider(ollama_service)
26
+ retrieve_service = _create_retrieve_service(conn, embedding_provider=embedding_provider)
27
+ return retrieve(
28
+ project_dir=project_layout.project_dir,
29
+ retrieve_service=retrieve_service,
30
+ project_id=str(project_layout.read_config_file().project_id),
31
+ text=retrieve_text,
32
+ limit=limit,
33
+ export_to_file=export_to_file,
34
+ )
35
+
36
+
37
+ def _create_retrieve_service(
38
+ conn: DuckDBPyConnection,
39
+ *,
40
+ embedding_provider: EmbeddingProvider,
41
+ ) -> RetrieveService:
42
+ vector_search_repo = create_vector_search_repository(conn)
43
+ shard_resolver = create_shard_resolver(conn)
44
+
45
+ return RetrieveService(
46
+ vector_search_repo=vector_search_repo,
47
+ shard_resolver=shard_resolver,
48
+ provider=embedding_provider,
49
+ )
@@ -8,7 +8,7 @@ def default_representer(dumper: SafeDumper, data: object) -> Node:
8
8
  if isinstance(data, Mapping):
9
9
  return dumper.represent_dict(data)
10
10
  elif hasattr(data, "__dict__"):
11
- # Doesn't serialise "private" attributes (that starts with an _)
11
+ # Doesn't serialize "private" attributes (that starts with an _)
12
12
  data_public_attributes = {key: value for key, value in data.__dict__.items() if not key.startswith("_")}
13
13
  if data_public_attributes:
14
14
  return dumper.represent_dict(data_public_attributes)