kodit 0.1.11__tar.gz → 0.1.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (92) hide show
  1. {kodit-0.1.11 → kodit-0.1.13}/PKG-INFO +3 -1
  2. {kodit-0.1.11 → kodit-0.1.13}/docs/_index.md +64 -0
  3. {kodit-0.1.11 → kodit-0.1.13}/pyproject.toml +2 -0
  4. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/_version.py +2 -2
  5. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/bm25/bm25.py +1 -1
  6. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/cli.py +22 -59
  7. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/config.py +43 -3
  8. kodit-0.1.13/src/kodit/embedding/embedding.py +203 -0
  9. kodit-0.1.11/src/kodit/indexing/models.py → kodit-0.1.13/src/kodit/indexing/indexing_models.py +2 -2
  10. kodit-0.1.11/src/kodit/indexing/repository.py → kodit-0.1.13/src/kodit/indexing/indexing_repository.py +5 -5
  11. kodit-0.1.11/src/kodit/indexing/service.py → kodit-0.1.13/src/kodit/indexing/indexing_service.py +17 -12
  12. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/log.py +1 -0
  13. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/mcp.py +27 -34
  14. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/env.py +3 -3
  15. kodit-0.1.13/src/kodit/search/__init__.py +1 -0
  16. kodit-0.1.11/src/kodit/retreival/repository.py → kodit-0.1.13/src/kodit/search/search_repository.py +59 -112
  17. kodit-0.1.11/src/kodit/retreival/service.py → kodit-0.1.13/src/kodit/search/search_service.py +40 -17
  18. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/snippets.py +3 -1
  19. kodit-0.1.11/src/kodit/sources/repository.py → kodit-0.1.13/src/kodit/source/source_repository.py +2 -7
  20. kodit-0.1.11/src/kodit/sources/service.py → kodit-0.1.13/src/kodit/source/source_service.py +2 -2
  21. kodit-0.1.13/tests/kodit/cli_test.py +57 -0
  22. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/e2e.py +4 -4
  23. kodit-0.1.13/tests/kodit/embedding/embedding_test.py +13 -0
  24. kodit-0.1.11/tests/kodit/indexing/test_service.py → kodit-0.1.13/tests/kodit/indexing/indexing_service_test.py +7 -7
  25. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/mcp_test.py +2 -2
  26. kodit-0.1.13/tests/kodit/search/__init__.py +1 -0
  27. kodit-0.1.13/tests/kodit/search/search_repository_test.py +124 -0
  28. kodit-0.1.11/tests/kodit/retreival/test_service.py → kodit-0.1.13/tests/kodit/search/search_service_test.py +52 -44
  29. kodit-0.1.11/tests/kodit/sources/test_service.py → kodit-0.1.13/tests/kodit/source/source_service_test.py +2 -2
  30. {kodit-0.1.11 → kodit-0.1.13}/tests/performance/similarity.py +5 -5
  31. {kodit-0.1.11 → kodit-0.1.13}/tests/smoke.sh +1 -1
  32. {kodit-0.1.11 → kodit-0.1.13}/uv.lock +95 -0
  33. kodit-0.1.11/src/kodit/embedding/embedding.py +0 -52
  34. kodit-0.1.11/src/kodit/retreival/__init__.py +0 -1
  35. kodit-0.1.11/tests/kodit/cli_test.py +0 -71
  36. kodit-0.1.11/tests/kodit/embedding/embedding_test.py +0 -9
  37. kodit-0.1.11/tests/kodit/retreival/__init__.py +0 -1
  38. kodit-0.1.11/tests/kodit/retreival/repository_test.py +0 -57
  39. {kodit-0.1.11 → kodit-0.1.13}/.cursor/rules/kodit.mdc +0 -0
  40. {kodit-0.1.11 → kodit-0.1.13}/.github/CODE_OF_CONDUCT.md +0 -0
  41. {kodit-0.1.11 → kodit-0.1.13}/.github/CONTRIBUTING.md +0 -0
  42. {kodit-0.1.11 → kodit-0.1.13}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  43. {kodit-0.1.11 → kodit-0.1.13}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  44. {kodit-0.1.11 → kodit-0.1.13}/.github/PULL_REQUEST_TEMPLATE.md +0 -0
  45. {kodit-0.1.11 → kodit-0.1.13}/.github/workflows/docker.yaml +0 -0
  46. {kodit-0.1.11 → kodit-0.1.13}/.github/workflows/docs.yaml +0 -0
  47. {kodit-0.1.11 → kodit-0.1.13}/.github/workflows/pypi-test.yaml +0 -0
  48. {kodit-0.1.11 → kodit-0.1.13}/.github/workflows/pypi.yaml +0 -0
  49. {kodit-0.1.11 → kodit-0.1.13}/.github/workflows/test.yaml +0 -0
  50. {kodit-0.1.11 → kodit-0.1.13}/.gitignore +0 -0
  51. {kodit-0.1.11 → kodit-0.1.13}/.python-version +0 -0
  52. {kodit-0.1.11 → kodit-0.1.13}/.vscode/launch.json +0 -0
  53. {kodit-0.1.11 → kodit-0.1.13}/.vscode/settings.json +0 -0
  54. {kodit-0.1.11 → kodit-0.1.13}/Dockerfile +0 -0
  55. {kodit-0.1.11 → kodit-0.1.13}/LICENSE +0 -0
  56. {kodit-0.1.11 → kodit-0.1.13}/README.md +0 -0
  57. {kodit-0.1.11 → kodit-0.1.13}/alembic.ini +0 -0
  58. {kodit-0.1.11 → kodit-0.1.13}/docs/developer/index.md +0 -0
  59. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/.gitignore +0 -0
  60. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/__init__.py +0 -0
  61. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/app.py +0 -0
  62. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/bm25/__init__.py +0 -0
  63. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/database.py +0 -0
  64. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/embedding/__init__.py +0 -0
  65. /kodit-0.1.11/src/kodit/embedding/models.py → /kodit-0.1.13/src/kodit/embedding/embedding_models.py +0 -0
  66. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/indexing/__init__.py +0 -0
  67. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/middleware.py +0 -0
  68. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/README +0 -0
  69. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/__init__.py +0 -0
  70. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/script.py.mako +0 -0
  71. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/versions/7c3bbc2ab32b_add_embeddings_table.py +0 -0
  72. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/versions/85155663351e_initial.py +0 -0
  73. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/migrations/versions/__init__.py +0 -0
  74. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/__init__.py +0 -0
  75. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/languages/__init__.py +0 -0
  76. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/languages/csharp.scm +0 -0
  77. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/languages/python.scm +0 -0
  78. {kodit-0.1.11 → kodit-0.1.13}/src/kodit/snippets/method_snippets.py +0 -0
  79. {kodit-0.1.11/src/kodit/sources → kodit-0.1.13/src/kodit/source}/__init__.py +0 -0
  80. /kodit-0.1.11/src/kodit/sources/models.py → /kodit-0.1.13/src/kodit/source/source_models.py +0 -0
  81. {kodit-0.1.11 → kodit-0.1.13}/tests/__init__.py +0 -0
  82. {kodit-0.1.11 → kodit-0.1.13}/tests/conftest.py +0 -0
  83. {kodit-0.1.11 → kodit-0.1.13}/tests/experiments/embedding.py +0 -0
  84. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/__init__.py +0 -0
  85. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/embedding/__init__.py +0 -0
  86. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/indexing/__init__.py +0 -0
  87. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/snippets/__init__.py +0 -0
  88. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/snippets/csharp.cs +0 -0
  89. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/snippets/detect_language_test.py +0 -0
  90. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/snippets/method_extraction_test.py +0 -0
  91. {kodit-0.1.11 → kodit-0.1.13}/tests/kodit/snippets/python.py +0 -0
  92. {kodit-0.1.11/tests/kodit/sources → kodit-0.1.13/tests/kodit/source}/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kodit
3
- Version: 0.1.11
3
+ Version: 0.1.13
4
4
  Summary: Code indexing for better AI code generation
5
5
  Project-URL: Homepage, https://docs.helixml.tech/kodit/
6
6
  Project-URL: Documentation, https://docs.helixml.tech/kodit/
@@ -32,6 +32,7 @@ Requires-Dist: gitpython>=3.1.44
32
32
  Requires-Dist: hf-xet>=1.1.2
33
33
  Requires-Dist: httpx-retries>=0.3.2
34
34
  Requires-Dist: httpx>=0.28.1
35
+ Requires-Dist: openai>=1.82.0
35
36
  Requires-Dist: posthog>=4.0.1
36
37
  Requires-Dist: pydantic-settings>=2.9.1
37
38
  Requires-Dist: pytable-formatter>=0.1.1
@@ -39,6 +40,7 @@ Requires-Dist: sentence-transformers>=4.1.0
39
40
  Requires-Dist: sqlalchemy[asyncio]>=2.0.40
40
41
  Requires-Dist: structlog>=25.3.0
41
42
  Requires-Dist: tdqm>=0.0.1
43
+ Requires-Dist: tiktoken>=0.9.0
42
44
  Requires-Dist: tree-sitter-language-pack>=0.7.3
43
45
  Requires-Dist: tree-sitter>=0.24.0
44
46
  Requires-Dist: uritools>=5.0.0
@@ -94,3 +94,67 @@ You MUST use the code-search MCP tool and always include any file context the us
94
94
  ```
95
95
 
96
96
  Alternatively, you can browse to the cursor settings and set this prompt globally.
97
+
98
+ ### Integration with Cline
99
+
100
+ 1. Click on the Cline icon in the menu
101
+ 2. Click the `MCP Servers` button at the top right of the Cline window (looks like a
102
+ server)
103
+ 3. Click the `Remote Servers` tab.
104
+ 4. Click `Edit Configuration`
105
+ 5. Add the following configuration:
106
+
107
+ ```json
108
+ {
109
+ "mcpServers": {
110
+ "kodit": {
111
+ "autoApprove": [],
112
+ "disabled": true,
113
+ "timeout": 60,
114
+ "url": "http://localhost:8080/sse",
115
+ "transportType": "sse"
116
+ }
117
+ }
118
+ }
119
+ ```
120
+
121
+ 6. Save the configuration and browse to the `Installed` tab.
122
+
123
+ Kodit should be listed and responding. Now code on!
124
+
125
+ ## Configuring Kodit
126
+
127
+ Configuration of Kodit is performed by setting environment variables or adding
128
+ variables to a .env file.
129
+
130
+ {{< warn >}}
131
+ Note that updating a setting does not automatically update the data that uses that
132
+ setting. For example, if you change a provider, you will need to delete and
133
+ recreate all indexes.
134
+ {{< /warn >}}
135
+
136
+ ### Indexing
137
+
138
+ #### Default Provider
139
+
140
+ By default, Kodit will use small local models for semantic search and enrichment. If you
141
+ are using Kodit in a professional capacity, it is likely that the local model latency is
142
+ too high to provide a good developer experience.
143
+
144
+ Instead, you should use an external provider. The settings provided here will cause all
145
+ embedding and enrichment requests to be sent to this provider by default. You can
146
+ override the provider used for each task if you wish. (Coming soon!)
147
+
148
+ ##### OpenAI
149
+
150
+ Add the following settings to your .env file, or export them as environment variables:
151
+
152
+ ```bash
153
+ DEFAULT_ENDPOINT_BASE_URL=https://api.openai.com/v1
154
+ DEFAULT_ENDPOINT_API_KEY=sk-xxxxxx
155
+ ```
156
+
157
+ ## Managing Kodit
158
+
159
+ There is limited management functionality at this time. To delete indexes you must
160
+ delete the database and/or tables.
@@ -46,6 +46,8 @@ dependencies = [
46
46
  "gitpython>=3.1.44",
47
47
  "sentence-transformers>=4.1.0",
48
48
  "hf-xet>=1.1.2",
49
+ "openai>=1.82.0",
50
+ "tiktoken>=0.9.0",
49
51
  ]
50
52
 
51
53
  [dependency-groups]
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.11'
21
- __version_tuple__ = version_tuple = (0, 1, 11)
20
+ __version__ = version = '0.1.13'
21
+ __version_tuple__ = version_tuple = (0, 1, 13)
@@ -52,7 +52,7 @@ class BM25Service:
52
52
  self.log.warning("No documents to retrieve from, returning empty list")
53
53
  return []
54
54
 
55
- top_k = min(top_k, len(doc_ids))
55
+ top_k = min(top_k, len(self.retriever.scores))
56
56
  self.log.debug(
57
57
  "Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
58
58
  )
@@ -1,6 +1,5 @@
1
1
  """Command line interface for kodit."""
2
2
 
3
- import os
4
3
  import signal
5
4
  from pathlib import Path
6
5
  from typing import Any
@@ -12,35 +11,21 @@ from pytable_formatter import Cell, Table
12
11
  from sqlalchemy.ext.asyncio import AsyncSession
13
12
 
14
13
  from kodit.config import (
15
- DEFAULT_BASE_DIR,
16
- DEFAULT_DB_URL,
17
- DEFAULT_DISABLE_TELEMETRY,
18
- DEFAULT_EMBEDDING_MODEL_NAME,
19
- DEFAULT_LOG_FORMAT,
20
- DEFAULT_LOG_LEVEL,
21
14
  AppContext,
22
15
  with_app_context,
23
16
  with_session,
24
17
  )
25
- from kodit.indexing.repository import IndexRepository
26
- from kodit.indexing.service import IndexService
18
+ from kodit.embedding.embedding import embedding_factory
19
+ from kodit.indexing.indexing_repository import IndexRepository
20
+ from kodit.indexing.indexing_service import IndexService
27
21
  from kodit.log import configure_logging, configure_telemetry, log_event
28
- from kodit.retreival.repository import RetrievalRepository
29
- from kodit.retreival.service import RetrievalRequest, RetrievalService
30
- from kodit.sources.repository import SourceRepository
31
- from kodit.sources.service import SourceService
22
+ from kodit.search.search_repository import SearchRepository
23
+ from kodit.search.search_service import SearchRequest, SearchService
24
+ from kodit.source.source_repository import SourceRepository
25
+ from kodit.source.source_service import SourceService
32
26
 
33
27
 
34
28
  @click.group(context_settings={"max_content_width": 100})
35
- @click.option("--log-level", help=f"Log level [default: {DEFAULT_LOG_LEVEL}]")
36
- @click.option("--log-format", help=f"Log format [default: {DEFAULT_LOG_FORMAT}]")
37
- @click.option(
38
- "--disable-telemetry",
39
- is_flag=True,
40
- help=f"Disable telemetry [default: {DEFAULT_DISABLE_TELEMETRY}]",
41
- )
42
- @click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
43
- @click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
44
29
  @click.option(
45
30
  "--env-file",
46
31
  help="Path to a .env file [default: .env]",
@@ -52,13 +37,8 @@ from kodit.sources.service import SourceService
52
37
  ),
53
38
  )
54
39
  @click.pass_context
55
- def cli( # noqa: PLR0913
40
+ def cli(
56
41
  ctx: click.Context,
57
- log_level: str | None,
58
- log_format: str | None,
59
- disable_telemetry: bool | None,
60
- db_url: str | None,
61
- data_dir: str | None,
62
42
  env_file: Path | None,
63
43
  ) -> None:
64
44
  """kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
@@ -67,17 +47,6 @@ def cli( # noqa: PLR0913
67
47
  if env_file:
68
48
  config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
69
49
 
70
- # Now override with CLI arguments, if set
71
- if data_dir:
72
- config.data_dir = Path(data_dir)
73
- if db_url:
74
- config.db_url = db_url
75
- if log_level:
76
- config.log_level = log_level
77
- if log_format:
78
- config.log_format = log_format
79
- if disable_telemetry:
80
- config.disable_telemetry = disable_telemetry
81
50
  configure_logging(config)
82
51
  configure_telemetry(config)
83
52
 
@@ -102,7 +71,7 @@ async def index(
102
71
  repository,
103
72
  source_service,
104
73
  app_context.get_data_dir(),
105
- embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
74
+ embedding_service=embedding_factory(app_context.get_default_openai_client()),
106
75
  )
107
76
 
108
77
  if not sources:
@@ -159,14 +128,14 @@ async def code(
159
128
 
160
129
  This works best if your query is code.
161
130
  """
162
- repository = RetrievalRepository(session)
163
- service = RetrievalService(
131
+ repository = SearchRepository(session)
132
+ service = SearchService(
164
133
  repository,
165
134
  app_context.get_data_dir(),
166
- embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
135
+ embedding_service=embedding_factory(app_context.get_default_openai_client()),
167
136
  )
168
137
 
169
- snippets = await service.retrieve(RetrievalRequest(code_query=query, top_k=top_k))
138
+ snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
170
139
 
171
140
  if len(snippets) == 0:
172
141
  click.echo("No snippets found")
@@ -192,14 +161,14 @@ async def keyword(
192
161
  top_k: int,
193
162
  ) -> None:
194
163
  """Search for snippets using keyword search."""
195
- repository = RetrievalRepository(session)
196
- service = RetrievalService(
164
+ repository = SearchRepository(session)
165
+ service = SearchService(
197
166
  repository,
198
167
  app_context.get_data_dir(),
199
- embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
168
+ embedding_service=embedding_factory(app_context.get_default_openai_client()),
200
169
  )
201
170
 
202
- snippets = await service.retrieve(RetrievalRequest(keywords=keywords, top_k=top_k))
171
+ snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
203
172
 
204
173
  if len(snippets) == 0:
205
174
  click.echo("No snippets found")
@@ -227,18 +196,18 @@ async def hybrid(
227
196
  code: str,
228
197
  ) -> None:
229
198
  """Search for snippets using hybrid search."""
230
- repository = RetrievalRepository(session)
231
- service = RetrievalService(
199
+ repository = SearchRepository(session)
200
+ service = SearchService(
232
201
  repository,
233
202
  app_context.get_data_dir(),
234
- embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
203
+ embedding_service=embedding_factory(app_context.get_default_openai_client()),
235
204
  )
236
205
 
237
206
  # Parse keywords into a list of strings
238
207
  keywords_list = [k.strip().lower() for k in keywords.split(",")]
239
208
 
240
- snippets = await service.retrieve(
241
- RetrievalRequest(keywords=keywords_list, code_query=code, top_k=top_k)
209
+ snippets = await service.search(
210
+ SearchRequest(keywords=keywords_list, code_query=code, top_k=top_k)
242
211
  )
243
212
 
244
213
  if len(snippets) == 0:
@@ -256,9 +225,7 @@ async def hybrid(
256
225
  @cli.command()
257
226
  @click.option("--host", default="127.0.0.1", help="Host to bind the server to")
258
227
  @click.option("--port", default=8080, help="Port to bind the server to")
259
- @with_app_context
260
228
  def serve(
261
- app_context: AppContext,
262
229
  host: str,
263
230
  port: int,
264
231
  ) -> None:
@@ -267,10 +234,6 @@ def serve(
267
234
  log.info("Starting kodit server", host=host, port=port)
268
235
  log_event("kodit_server_started")
269
236
 
270
- # Dump AppContext to a dictionary of strings, and set the env vars
271
- app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
272
- os.environ.update(app_context_dict)
273
-
274
237
  # Configure uvicorn with graceful shutdown
275
238
  config = uvicorn.Config(
276
239
  "kodit.app:app",
@@ -4,10 +4,11 @@ import asyncio
4
4
  from collections.abc import Callable, Coroutine
5
5
  from functools import wraps
6
6
  from pathlib import Path
7
- from typing import Any, TypeVar
7
+ from typing import Any, Literal, TypeVar
8
8
 
9
9
  import click
10
- from pydantic import Field
10
+ from openai import AsyncOpenAI
11
+ from pydantic import BaseModel, Field
11
12
  from pydantic_settings import BaseSettings, SettingsConfigDict
12
13
 
13
14
  from kodit.database import Database
@@ -22,16 +23,40 @@ DEFAULT_EMBEDDING_MODEL_NAME = TINY
22
23
  T = TypeVar("T")
23
24
 
24
25
 
26
+ class Endpoint(BaseModel):
27
+ """Endpoint provides configuration for an AI service."""
28
+
29
+ type: Literal["openai"] = Field(default="openai")
30
+ api_key: str | None = None
31
+ base_url: str | None = None
32
+
33
+
25
34
  class AppContext(BaseSettings):
26
35
  """Global context for the kodit project. Provides a shared state for the app."""
27
36
 
28
- model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
37
+ model_config = SettingsConfigDict(
38
+ env_file=".env",
39
+ env_file_encoding="utf-8",
40
+ env_nested_delimiter="_",
41
+ nested_model_default_partial_update=True,
42
+ env_nested_max_split=1,
43
+ )
29
44
 
30
45
  data_dir: Path = Field(default=DEFAULT_BASE_DIR)
31
46
  db_url: str = Field(default=DEFAULT_DB_URL)
32
47
  log_level: str = Field(default=DEFAULT_LOG_LEVEL)
33
48
  log_format: str = Field(default=DEFAULT_LOG_FORMAT)
34
49
  disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
50
+ default_endpoint: Endpoint | None = Field(
51
+ default=Endpoint(
52
+ type="openai",
53
+ base_url="https://api.openai.com/v1",
54
+ ),
55
+ description=(
56
+ "Default endpoint to use for all AI interactions "
57
+ "(can be overridden by task-specific configuration)."
58
+ ),
59
+ )
35
60
  _db: Database | None = None
36
61
 
37
62
  def model_post_init(self, _: Any) -> None:
@@ -58,6 +83,21 @@ class AppContext(BaseSettings):
58
83
  await self._db.run_migrations(self.db_url)
59
84
  return self._db
60
85
 
86
+ def get_default_openai_client(self) -> AsyncOpenAI | None:
87
+ """Get the default OpenAI client, if it is configured."""
88
+ endpoint = self.default_endpoint
89
+ if not (
90
+ endpoint
91
+ and endpoint.type == "openai"
92
+ and endpoint.api_key
93
+ and endpoint.base_url
94
+ ):
95
+ return None
96
+ return AsyncOpenAI(
97
+ api_key=endpoint.api_key,
98
+ base_url=endpoint.base_url,
99
+ )
100
+
61
101
 
62
102
  with_app_context = click.make_pass_decorator(AppContext)
63
103
 
@@ -0,0 +1,203 @@
1
+ """Embedding service."""
2
+
3
+ import asyncio
4
+ import os
5
+ from abc import ABC, abstractmethod
6
+ from collections.abc import AsyncGenerator
7
+ from typing import NamedTuple
8
+
9
+ import structlog
10
+ import tiktoken
11
+ from openai import AsyncOpenAI
12
+ from sentence_transformers import SentenceTransformer
13
+
14
+ TINY = "tiny"
15
+ CODE = "code"
16
+ TEST = "test"
17
+
18
+ COMMON_EMBEDDING_MODELS = {
19
+ TINY: "ibm-granite/granite-embedding-30m-english",
20
+ CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
21
+ TEST: "minishlab/potion-base-4M",
22
+ }
23
+
24
+
25
+ class EmbeddingInput(NamedTuple):
26
+ """Input for embedding."""
27
+
28
+ id: int
29
+ text: str
30
+
31
+
32
+ class EmbeddingOutput(NamedTuple):
33
+ """Output for embedding."""
34
+
35
+ id: int
36
+ embedding: list[float]
37
+
38
+
39
+ class Embedder(ABC):
40
+ """Embedder interface."""
41
+
42
+ @abstractmethod
43
+ def embed(
44
+ self, data: list[EmbeddingInput]
45
+ ) -> AsyncGenerator[EmbeddingOutput, None]:
46
+ """Embed a list of documents.
47
+
48
+ The embedding service accepts a massive list of id,strings to embed. Behind the
49
+ scenes it batches up requests and parallelizes them for performance according to
50
+ the specifics of the embedding service.
51
+
52
+ The id reference is required because the parallelization may return results out
53
+ of order.
54
+ """
55
+
56
+ @abstractmethod
57
+ def query(self, data: list[str]) -> AsyncGenerator[list[float], None]:
58
+ """Query the embedding model."""
59
+
60
+
61
+ def embedding_factory(openai_client: AsyncOpenAI | None = None) -> Embedder:
62
+ """Create an embedding service."""
63
+ if openai_client is not None:
64
+ return OpenAIEmbedder(openai_client)
65
+ return LocalEmbedder(model_name=TINY)
66
+
67
+
68
+ class LocalEmbedder(Embedder):
69
+ """Local embedder."""
70
+
71
+ def __init__(self, model_name: str) -> None:
72
+ """Initialize the local embedder."""
73
+ self.log = structlog.get_logger(__name__)
74
+ self.log.info("Creating local embedder", model_name=model_name)
75
+ self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
76
+ self.embedding_model = None
77
+ self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
78
+
79
+ def _model(self) -> SentenceTransformer:
80
+ """Get the embedding model."""
81
+ if self.embedding_model is None:
82
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
83
+ self.embedding_model = SentenceTransformer(
84
+ self.model_name,
85
+ trust_remote_code=True,
86
+ device="cpu", # Force CPU so we don't have to install accelerate, etc.
87
+ )
88
+ return self.embedding_model
89
+
90
+ async def embed(
91
+ self, data: list[EmbeddingInput]
92
+ ) -> AsyncGenerator[EmbeddingOutput, None]:
93
+ """Embed a list of documents."""
94
+ model = self._model()
95
+
96
+ batched_data = _split_sub_batches(self.encoding, data)
97
+
98
+ for batch in batched_data:
99
+ embeddings = model.encode(
100
+ [i.text for i in batch], show_progress_bar=False, batch_size=4
101
+ )
102
+ for i, x in zip(batch, embeddings, strict=False):
103
+ yield EmbeddingOutput(i.id, [float(y) for y in x])
104
+
105
+ async def query(self, data: list[str]) -> AsyncGenerator[list[float], None]:
106
+ """Query the embedding model."""
107
+ model = self._model()
108
+ embeddings = model.encode(data, show_progress_bar=False, batch_size=4)
109
+ for embedding in embeddings:
110
+ yield [float(x) for x in embedding]
111
+
112
+
113
+ OPENAI_MAX_EMBEDDING_SIZE = 8192
114
+ OPENAI_NUM_PARALLEL_TASKS = 10
115
+
116
+
117
+ def _split_sub_batches(
118
+ encoding: tiktoken.Encoding, data: list[EmbeddingInput]
119
+ ) -> list[list[EmbeddingInput]]:
120
+ """Split a list of strings into smaller sub-batches."""
121
+ log = structlog.get_logger(__name__)
122
+ result = []
123
+ data_to_process = [s for s in data if s.text.strip()] # Filter out empty strings
124
+
125
+ while data_to_process:
126
+ next_batch = []
127
+ current_tokens = 0
128
+
129
+ while data_to_process:
130
+ next_item = data_to_process[0]
131
+ item_tokens = len(encoding.encode(next_item.text))
132
+
133
+ if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
134
+ log.warning("Skipping too long snippet", snippet=data_to_process.pop(0))
135
+ continue
136
+
137
+ if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
138
+ break
139
+
140
+ next_batch.append(data_to_process.pop(0))
141
+ current_tokens += item_tokens
142
+
143
+ if next_batch:
144
+ result.append(next_batch)
145
+
146
+ return result
147
+
148
+
149
+ class OpenAIEmbedder(Embedder):
150
+ """OpenAI embedder."""
151
+
152
+ def __init__(
153
+ self, openai_client: AsyncOpenAI, model_name: str = "text-embedding-3-small"
154
+ ) -> None:
155
+ """Initialize the OpenAI embedder."""
156
+ self.log = structlog.get_logger(__name__)
157
+ self.log.info("Creating OpenAI embedder", model_name=model_name)
158
+ self.openai_client = openai_client
159
+ self.encoding = tiktoken.encoding_for_model(model_name)
160
+ self.log = structlog.get_logger(__name__)
161
+
162
+ async def embed(
163
+ self,
164
+ data: list[EmbeddingInput],
165
+ ) -> AsyncGenerator[EmbeddingOutput, None]:
166
+ """Embed a list of documents."""
167
+ # First split the list into a list of list where each sublist has fewer than
168
+ # max tokens.
169
+ batched_data = _split_sub_batches(self.encoding, data)
170
+
171
+ # Process batches in parallel with a semaphore to limit concurrent requests
172
+ sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
173
+
174
+ async def process_batch(batch: list[EmbeddingInput]) -> list[EmbeddingOutput]:
175
+ async with sem:
176
+ try:
177
+ response = await self.openai_client.embeddings.create(
178
+ model="text-embedding-3-small",
179
+ input=[i.text for i in batch],
180
+ )
181
+ return [
182
+ EmbeddingOutput(i.id, x.embedding)
183
+ for i, x in zip(batch, response.data, strict=False)
184
+ ]
185
+ except Exception as e:
186
+ self.log.exception("Error embedding batch", error=str(e))
187
+ return []
188
+
189
+ # Create tasks for all batches
190
+ tasks = [process_batch(batch) for batch in batched_data]
191
+
192
+ # Process all batches and yield results as they complete
193
+ for task in asyncio.as_completed(tasks):
194
+ embeddings = await task
195
+ for e in embeddings:
196
+ yield e
197
+
198
+ async def query(self, data: list[str]) -> AsyncGenerator[list[float], None]:
199
+ """Query the embedding model."""
200
+ async for e in self.embed(
201
+ [EmbeddingInput(i, text) for i, text in enumerate(data)]
202
+ ):
203
+ yield e.embedding
@@ -31,8 +31,8 @@ class Snippet(Base, CommonMixin):
31
31
 
32
32
  __tablename__ = "snippets"
33
33
 
34
- file_id: Mapped[int] = mapped_column(ForeignKey("files.id"))
35
- index_id: Mapped[int] = mapped_column(ForeignKey("indexes.id"))
34
+ file_id: Mapped[int] = mapped_column(ForeignKey("files.id"), index=True)
35
+ index_id: Mapped[int] = mapped_column(ForeignKey("indexes.id"), index=True)
36
36
  content: Mapped[str] = mapped_column(UnicodeText, default="")
37
37
 
38
38
  def __init__(self, file_id: int, index_id: int, content: str) -> None:
@@ -11,9 +11,9 @@ from typing import TypeVar
11
11
  from sqlalchemy import delete, func, select
12
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
13
 
14
- from kodit.embedding.models import Embedding
15
- from kodit.indexing.models import Index, Snippet
16
- from kodit.sources.models import File, Source
14
+ from kodit.embedding.embedding_models import Embedding
15
+ from kodit.indexing.indexing_models import Index, Snippet
16
+ from kodit.source.source_models import File, Source
17
17
 
18
18
  T = TypeVar("T")
19
19
 
@@ -156,14 +156,14 @@ class IndexRepository:
156
156
  result = await self.session.execute(query)
157
157
  return list(result.scalars())
158
158
 
159
- async def get_all_snippets(self) -> list[Snippet]:
159
+ async def get_all_snippets(self, index_id: int) -> list[Snippet]:
160
160
  """Get all snippets.
161
161
 
162
162
  Returns:
163
163
  A list of all snippets.
164
164
 
165
165
  """
166
- query = select(Snippet).order_by(Snippet.id)
166
+ query = select(Snippet).where(Snippet.index_id == index_id).order_by(Snippet.id)
167
167
  result = await self.session.execute(query)
168
168
  return list(result.scalars())
169
169
 
@@ -14,12 +14,12 @@ import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
16
  from kodit.bm25.bm25 import BM25Service
17
- from kodit.embedding.embedding import EmbeddingService
18
- from kodit.embedding.models import Embedding, EmbeddingType
19
- from kodit.indexing.models import Snippet
20
- from kodit.indexing.repository import IndexRepository
17
+ from kodit.embedding.embedding import Embedder, EmbeddingInput
18
+ from kodit.embedding.embedding_models import Embedding, EmbeddingType
19
+ from kodit.indexing.indexing_models import Snippet
20
+ from kodit.indexing.indexing_repository import IndexRepository
21
21
  from kodit.snippets.snippets import SnippetService
22
- from kodit.sources.service import SourceService
22
+ from kodit.source.source_service import SourceService
23
23
 
24
24
  # List of MIME types that are blacklisted from being indexed
25
25
  MIME_BLACKLIST = ["unknown/unknown"]
@@ -52,7 +52,7 @@ class IndexService:
52
52
  repository: IndexRepository,
53
53
  source_service: SourceService,
54
54
  data_dir: Path,
55
- embedding_model_name: str,
55
+ embedding_service: Embedder,
56
56
  ) -> None:
57
57
  """Initialize the index service.
58
58
 
@@ -66,7 +66,7 @@ class IndexService:
66
66
  self.snippet_service = SnippetService()
67
67
  self.log = structlog.get_logger(__name__)
68
68
  self.bm25 = BM25Service(data_dir)
69
- self.code_embedding_service = EmbeddingService(model_name=embedding_model_name)
69
+ self.code_embedding_service = embedding_service
70
70
 
71
71
  async def create(self, source_id: int) -> IndexView:
72
72
  """Create a new index for a source.
@@ -132,7 +132,7 @@ class IndexService:
132
132
  # Create snippets for supported file types
133
133
  await self._create_snippets(index_id)
134
134
 
135
- snippets = await self.repository.get_all_snippets()
135
+ snippets = await self.repository.get_all_snippets(index_id)
136
136
 
137
137
  self.log.info("Creating keyword index")
138
138
  self.bm25.index(
@@ -143,12 +143,17 @@ class IndexService:
143
143
  )
144
144
 
145
145
  self.log.info("Creating semantic code index")
146
- for snippet in tqdm(snippets, total=len(snippets), leave=False):
147
- embedding = next(self.code_embedding_service.embed([snippet.content]))
146
+ async for e in tqdm(
147
+ self.code_embedding_service.embed(
148
+ [EmbeddingInput(snippet.id, snippet.content) for snippet in snippets]
149
+ ),
150
+ total=len(snippets),
151
+ leave=False,
152
+ ):
148
153
  await self.repository.add_embedding(
149
154
  Embedding(
150
- snippet_id=snippet.id,
151
- embedding=embedding,
155
+ snippet_id=e.id,
156
+ embedding=e.embedding,
152
157
  type=EmbeddingType.CODE,
153
158
  )
154
159
  )