kodit 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kodit might be problematic.

Files changed (52)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +9 -2
  3. kodit/application/factories/code_indexing_factory.py +62 -13
  4. kodit/application/factories/reporting_factory.py +32 -0
  5. kodit/application/services/auto_indexing_service.py +41 -33
  6. kodit/application/services/code_indexing_application_service.py +137 -138
  7. kodit/application/services/indexing_worker_service.py +26 -30
  8. kodit/application/services/queue_service.py +12 -14
  9. kodit/application/services/reporting.py +104 -0
  10. kodit/application/services/sync_scheduler.py +21 -20
  11. kodit/cli.py +71 -85
  12. kodit/config.py +26 -3
  13. kodit/database.py +2 -1
  14. kodit/domain/entities.py +99 -1
  15. kodit/domain/protocols.py +34 -1
  16. kodit/domain/services/bm25_service.py +1 -6
  17. kodit/domain/services/index_service.py +23 -57
  18. kodit/domain/services/task_status_query_service.py +19 -0
  19. kodit/domain/value_objects.py +53 -8
  20. kodit/infrastructure/api/v1/dependencies.py +40 -12
  21. kodit/infrastructure/api/v1/routers/indexes.py +45 -0
  22. kodit/infrastructure/api/v1/schemas/task_status.py +39 -0
  23. kodit/infrastructure/cloning/git/working_copy.py +43 -7
  24. kodit/infrastructure/embedding/embedding_factory.py +8 -3
  25. kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +48 -55
  26. kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
  27. kodit/infrastructure/git/git_utils.py +3 -2
  28. kodit/infrastructure/mappers/index_mapper.py +1 -0
  29. kodit/infrastructure/mappers/task_status_mapper.py +85 -0
  30. kodit/infrastructure/reporting/__init__.py +1 -0
  31. kodit/infrastructure/reporting/db_progress.py +23 -0
  32. kodit/infrastructure/reporting/log_progress.py +37 -0
  33. kodit/infrastructure/reporting/tdqm_progress.py +38 -0
  34. kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
  35. kodit/infrastructure/sqlalchemy/entities.py +89 -2
  36. kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
  37. kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
  38. kodit/infrastructure/sqlalchemy/task_status_repository.py +79 -0
  39. kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
  40. kodit/mcp.py +15 -3
  41. kodit/migrations/env.py +0 -1
  42. kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
  43. {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/METADATA +1 -1
  44. {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/RECORD +47 -40
  45. kodit/domain/interfaces.py +0 -27
  46. kodit/infrastructure/ui/__init__.py +0 -1
  47. kodit/infrastructure/ui/progress.py +0 -170
  48. kodit/infrastructure/ui/spinner.py +0 -74
  49. kodit/reporting.py +0 -78
  50. {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/WHEEL +0 -0
  51. {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/entry_points.txt +0 -0
  52. {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/licenses/LICENSE +0 -0

kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py

@@ -7,16 +7,15 @@ from typing import Any
 import httpx
 import litellm
 import structlog
+import tiktoken
 from litellm import aembedding

 from kodit.config import Endpoint
 from kodit.domain.services.embedding_service import EmbeddingProvider
 from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
-
-# Constants
-MAX_TOKENS = 8192  # Conservative token limit for the embedding model
-BATCH_SIZE = 10  # Maximum number of items per API call
-DEFAULT_NUM_PARALLEL_TASKS = 10  # Semaphore limit for concurrent requests
+from kodit.infrastructure.embedding.embedding_providers.batching import (
+    split_sub_batches,
+)


 class LiteLLMEmbeddingProvider(EmbeddingProvider):
@@ -32,46 +31,36 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             endpoint: The endpoint configuration containing all settings.

         """
-        self.model_name = endpoint.model or "text-embedding-3-small"
-        self.api_key = endpoint.api_key
-        self.base_url = endpoint.base_url
-        self.socket_path = endpoint.socket_path
-        self.num_parallel_tasks = (
-            endpoint.num_parallel_tasks or DEFAULT_NUM_PARALLEL_TASKS
-        )
-        self.timeout = endpoint.timeout or 30.0
-        self.extra_params = endpoint.extra_params or {}
+        self.endpoint = endpoint
         self.log = structlog.get_logger(__name__)
+        self._encoding: tiktoken.Encoding | None = None

         # Configure LiteLLM with custom HTTPX client for Unix socket support if needed
         self._setup_litellm_client()

     def _setup_litellm_client(self) -> None:
         """Set up LiteLLM with custom HTTPX client for Unix socket support."""
-        if self.socket_path:
+        if self.endpoint.socket_path:
             # Create HTTPX client with Unix socket transport
-            transport = httpx.AsyncHTTPTransport(uds=self.socket_path)
+            transport = httpx.AsyncHTTPTransport(uds=self.endpoint.socket_path)
             unix_client = httpx.AsyncClient(
                 transport=transport,
                 base_url="http://localhost",  # Base URL for Unix socket
-                timeout=self.timeout,
+                timeout=self.endpoint.timeout,
             )
             # Set as LiteLLM's async client session
             litellm.aclient_session = unix_client

     def _split_sub_batches(
-        self, data: list[EmbeddingRequest]
+        self, encoding: tiktoken.Encoding, data: list[EmbeddingRequest]
     ) -> list[list[EmbeddingRequest]]:
-        """Split data into manageable batches.
-
-        For LiteLLM, we use a simpler batching approach since token counting
-        varies by provider. We use a conservative batch size approach.
-        """
-        batches = []
-        for i in range(0, len(data), BATCH_SIZE):
-            batch = data[i : i + BATCH_SIZE]
-            batches.append(batch)
-        return batches
+        """Proxy to the shared batching utility (kept for backward-compat)."""
+        return split_sub_batches(
+            encoding,
+            data,
+            max_tokens=self.endpoint.max_tokens,
+            batch_size=self.endpoint.num_parallel_tasks,
+        )

     async def _call_embeddings_api(self, texts: list[str]) -> Any:
         """Call the embeddings API using LiteLLM.
@@ -84,21 +73,21 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):

         """
         kwargs = {
-            "model": self.model_name,
+            "model": self.endpoint.model,
             "input": texts,
-            "timeout": self.timeout,
+            "timeout": self.endpoint.timeout,
         }

         # Add API key if provided
-        if self.api_key:
-            kwargs["api_key"] = self.api_key
+        if self.endpoint.api_key:
+            kwargs["api_key"] = self.endpoint.api_key

         # Add base_url if provided
-        if self.base_url:
-            kwargs["api_base"] = self.base_url
+        if self.endpoint.base_url:
+            kwargs["api_base"] = self.endpoint.base_url

         # Add extra parameters
-        kwargs.update(self.extra_params)
+        kwargs.update(self.endpoint.extra_params or {})

         try:
             # Use litellm's async embedding function
@@ -108,7 +97,7 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             )
         except Exception as e:
             self.log.exception(
-                "LiteLLM embedding API error", error=str(e), model=self.model_name
+                "LiteLLM embedding API error", error=str(e), model=self.endpoint.model
             )
             raise

@@ -121,32 +110,28 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             return

         # Split into batches
-        batched_data = self._split_sub_batches(data)
+        encoding = self._get_encoding()
+        batched_data = self._split_sub_batches(encoding, data)

         # Process batches concurrently with semaphore
-        sem = asyncio.Semaphore(self.num_parallel_tasks)
+        sem = asyncio.Semaphore(self.endpoint.num_parallel_tasks or 10)

         async def _process_batch(
             batch: list[EmbeddingRequest],
         ) -> list[EmbeddingResponse]:
             async with sem:
-                try:
-                    response = await self._call_embeddings_api(
-                        [item.text for item in batch]
+                response = await self._call_embeddings_api(
+                    [item.text for item in batch]
+                )
+                embeddings_data = response.get("data", [])
+
+                return [
+                    EmbeddingResponse(
+                        snippet_id=item.snippet_id,
+                        embedding=emb_data.get("embedding", []),
                     )
-                    embeddings_data = response.get("data", [])
-
-                    return [
-                        EmbeddingResponse(
-                            snippet_id=item.snippet_id,
-                            embedding=emb_data.get("embedding", []),
-                        )
-                        for item, emb_data in zip(batch, embeddings_data, strict=True)
-                    ]
-                except Exception as e:
-                    self.log.exception("Error embedding batch", error=str(e))
-                    # Return no embeddings for this batch if there was an error
-                    return []
+                    for item, emb_data in zip(batch, embeddings_data, strict=True)
+                ]

         tasks = [_process_batch(batch) for batch in batched_data]
         for task in asyncio.as_completed(tasks):
@@ -155,9 +140,17 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
     async def close(self) -> None:
         """Close the provider and cleanup HTTPX client if using Unix sockets."""
         if (
-            self.socket_path
+            self.endpoint.socket_path
             and hasattr(litellm, "aclient_session")
             and litellm.aclient_session
         ):
             await litellm.aclient_session.aclose()
             litellm.aclient_session = None
+
+    def _get_encoding(self) -> tiktoken.Encoding:
+        """Return (and cache) the tiktoken encoding for the chosen model."""
+        if self._encoding is None:
+            self._encoding = tiktoken.get_encoding(
+                "o200k_base"
+            )  # Reasonable default for most models, but might not be perfect.
+        return self._encoding
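
Note: the provider above now delegates batching to a shared split_sub_batches helper and counts tokens with tiktoken's o200k_base encoding. That helper, in kodit.infrastructure.embedding.embedding_providers.batching, is not part of the visible diff, so the sketch below only guesses at its shape: a token-aware batcher whose max_tokens and batch_size defaults are placeholders, not kodit's real values.

    # Illustrative sketch only: not the real kodit batching module.
    import tiktoken

    from kodit.domain.value_objects import EmbeddingRequest


    def split_sub_batches(
        encoding: tiktoken.Encoding,
        data: list[EmbeddingRequest],
        max_tokens: int = 8192,  # placeholder default, not kodit's value
        batch_size: int = 10,  # placeholder default, not kodit's value
    ) -> list[list[EmbeddingRequest]]:
        """Group requests so each batch stays within max_tokens and batch_size."""
        batches: list[list[EmbeddingRequest]] = []
        current: list[EmbeddingRequest] = []
        current_tokens = 0
        for item in data:
            item_tokens = len(encoding.encode(item.text))
            # Start a new batch if adding this item would exceed either limit.
            if current and (
                current_tokens + item_tokens > max_tokens or len(current) >= batch_size
            ):
                batches.append(current)
                current, current_tokens = [], 0
            current.append(item)
            current_tokens += item_tokens
        if current:
            batches.append(current)
        return batches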

kodit/infrastructure/enrichment/local_enrichment_provider.py

@@ -1,7 +1,9 @@
 """Local enrichment provider implementation."""

+import asyncio
 import os
 from collections.abc import AsyncGenerator
+from typing import Any

 import structlog
 import tiktoken
@@ -60,23 +62,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
             self.log.warning("No valid requests for enrichment")
             return

-        from transformers.models.auto.modeling_auto import (
-            AutoModelForCausalLM,
-        )
-        from transformers.models.auto.tokenization_auto import AutoTokenizer
-
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name, padding_side="left"
-            )
-        if self.model is None:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                torch_dtype="auto",
-                trust_remote_code=True,
-                device_map="auto",
+        def _init_model() -> None:
+            from transformers.models.auto.modeling_auto import (
+                AutoModelForCausalLM,
             )
+            from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+            if self.tokenizer is None:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name, padding_side="left"
+                )
+            if self.model is None:
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                    device_map="auto",
+                )
+
+        await asyncio.to_thread(_init_model)

         # Prepare prompts
         prompts = [
@@ -96,20 +101,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
         ]

         for prompt in prompts:
-            model_inputs = self.tokenizer(  # type: ignore[misc]
-                prompt["text"],
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-            ).to(self.model.device)  # type: ignore[attr-defined]
-            generated_ids = self.model.generate(  # type: ignore[attr-defined]
-                **model_inputs, max_new_tokens=self.context_window
-            )
-            input_ids = model_inputs["input_ids"][0]
-            output_ids = generated_ids[0][len(input_ids) :].tolist()
-            content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(  # type: ignore[attr-defined]
-                "\n"
-            )
+
+            def process_prompt(prompt: dict[str, Any]) -> str:
+                model_inputs = self.tokenizer(  # type: ignore[misc]
+                    prompt["text"],
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                ).to(self.model.device)  # type: ignore[attr-defined]
+                generated_ids = self.model.generate(  # type: ignore[attr-defined]
+                    **model_inputs, max_new_tokens=self.context_window
+                )
+                input_ids = model_inputs["input_ids"][0]
+                output_ids = generated_ids[0][len(input_ids) :].tolist()
+                return self.tokenizer.decode(  # type: ignore[attr-defined]
+                    output_ids, skip_special_tokens=True
+                ).strip(  # type: ignore[attr-defined]
+                    "\n"
+                )
+
+            content = await asyncio.to_thread(process_prompt, prompt)
             # Remove thinking tags from the response
             cleaned_content = clean_thinking_tags(content)
             yield EnrichmentResponse(
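
Note: the enrichment changes above move the blocking transformers calls (model loading and per-prompt generation) into asyncio.to_thread so they no longer stall the event loop. A minimal standalone illustration of that pattern, unrelated to kodit's actual classes:

    import asyncio
    import time


    def blocking_work(x: int) -> int:
        """Stand-in for a slow, blocking call such as model.generate()."""
        time.sleep(1)
        return x * 2


    async def main() -> None:
        # The blocking call runs in a worker thread; the event loop stays free
        # to service other coroutines (progress reporting, health checks, etc.).
        result = await asyncio.to_thread(blocking_work, 21)
        print(result)


    asyncio.run(main())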

kodit/infrastructure/git/git_utils.py

@@ -3,6 +3,7 @@
 import tempfile

 import git
+import git.cmd
 import structlog


@@ -19,10 +20,10 @@ def is_valid_clone_target(target: str) -> bool:
     """
     with tempfile.TemporaryDirectory() as temp_dir:
         try:
-            git.Repo.clone_from(target, temp_dir)
+            git.cmd.Git(temp_dir).ls_remote(target)
         except git.GitCommandError as e:
             structlog.get_logger(__name__).warning(
-                "Failed to clone git repository",
+                "Failed to list git repository",
                 target=target,
                 error=e,
             )
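
Note: clone-target validation now runs git ls-remote against the target instead of cloning it into a temporary directory, so only the remote's ref list is fetched rather than the whole repository. A small standalone check in the same style (the URL is just an example):

    import tempfile

    import git
    import git.cmd


    def is_reachable(url: str) -> bool:
        """Return True if the remote answers git ls-remote, without cloning it."""
        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                # GitPython turns attribute access into git subcommands, so this
                # runs: git ls-remote <url>
                git.cmd.Git(temp_dir).ls_remote(url)
            except git.GitCommandError:
                return False
        return True


    print(is_reachable("https://github.com/git/git"))  # example URL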

kodit/infrastructure/mappers/index_mapper.py

@@ -15,6 +15,7 @@ from kodit.domain.value_objects import (
 from kodit.infrastructure.sqlalchemy import entities as db_entities


+# TODO(Phil): Make this a pure mapper without any DB access  # noqa: TD003, FIX002
 class IndexMapper:
     """Mapper for converting between domain Index aggregate and database entities."""


kodit/infrastructure/mappers/task_status_mapper.py (new file)

@@ -0,0 +1,85 @@
+"""Task status mapper."""
+
+from kodit.domain import entities as domain_entities
+from kodit.domain.value_objects import ReportingState, TaskOperation, TrackableType
+from kodit.infrastructure.sqlalchemy import entities as db_entities
+
+
+class TaskStatusMapper:
+    """Mapper for converting between domain TaskStatus and database entities."""
+
+    @staticmethod
+    def from_domain_task_status(
+        task_status: domain_entities.TaskStatus,
+    ) -> db_entities.TaskStatus:
+        """Convert domain TaskStatus to database TaskStatus."""
+        return db_entities.TaskStatus(
+            id=task_status.id,
+            operation=task_status.operation,
+            created_at=task_status.created_at,
+            updated_at=task_status.updated_at,
+            trackable_id=task_status.trackable_id,
+            trackable_type=(
+                task_status.trackable_type.value if task_status.trackable_type else None
+            ),
+            parent=task_status.parent.id if task_status.parent else None,
+            state=(
+                task_status.state.value
+                if isinstance(task_status.state, ReportingState)
+                else task_status.state
+            ),
+            error=task_status.error,
+            total=task_status.total,
+            current=task_status.current,
+            message=task_status.message,
+        )
+
+    @staticmethod
+    def to_domain_task_status(
+        db_status: db_entities.TaskStatus,
+    ) -> domain_entities.TaskStatus:
+        """Convert database TaskStatus to domain TaskStatus."""
+        return domain_entities.TaskStatus(
+            id=db_status.id,
+            operation=TaskOperation(db_status.operation),
+            state=ReportingState(db_status.state),
+            created_at=db_status.created_at,
+            updated_at=db_status.updated_at,
+            trackable_id=db_status.trackable_id,
+            trackable_type=(
+                TrackableType(db_status.trackable_type)
+                if db_status.trackable_type
+                else None
+            ),
+            parent=None,  # Parent relationships need to be reconstructed separately
+            error=db_status.error if db_status.error else None,
+            total=db_status.total,
+            current=db_status.current,
+            message=db_status.message,
+        )
+
+    @staticmethod
+    def to_domain_task_status_with_hierarchy(
+        db_statuses: list[db_entities.TaskStatus],
+    ) -> list[domain_entities.TaskStatus]:
+        """Convert database TaskStatus list to domain with parent-child hierarchy.
+
+        This method performs a two-pass conversion:
+        1. First pass: Convert all DB entities to domain entities
+        2. Second pass: Reconstruct parent-child relationships using ID mapping
+        """
+        # First pass: Convert all database entities to domain entities
+        domain_statuses = [
+            TaskStatusMapper.to_domain_task_status(db_status)
+            for db_status in db_statuses
+        ]
+
+        # Create ID-to-entity mapping for efficient parent lookup
+        id_to_entity = {status.id: status for status in domain_statuses}
+
+        # Second pass: Reconstruct parent-child relationships
+        for db_status, domain_status in zip(db_statuses, domain_statuses, strict=True):
+            if db_status.parent and db_status.parent in id_to_entity:
+                domain_status.parent = id_to_entity[db_status.parent]
+
+        return domain_statuses
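
Note: the to_domain_task_status_with_hierarchy docstring above describes a two-pass conversion. The same pattern on toy dataclasses, rather than kodit's TaskStatus entities, looks like this:

    from dataclasses import dataclass


    @dataclass
    class Row:
        """Stand-in for a persisted row that stores its parent as a plain id."""

        id: int
        parent: int | None


    @dataclass
    class Node:
        """Stand-in for a domain entity that holds a reference to its parent."""

        id: int
        parent: "Node | None" = None


    def rebuild_hierarchy(rows: list[Row]) -> list[Node]:
        # Pass 1: convert every row into a domain object, parents left unset.
        nodes = [Node(id=row.id) for row in rows]
        # Index by id for O(1) parent lookup.
        by_id = {node.id: node for node in nodes}
        # Pass 2: wire parent references using the index.
        for row, node in zip(rows, nodes, strict=True):
            if row.parent is not None and row.parent in by_id:
                node.parent = by_id[row.parent]
        return nodes


    nodes = rebuild_hierarchy([Row(1, None), Row(2, 1), Row(3, 1)])
    assert nodes[1].parent is nodes[0] and nodes[2].parent is nodes[0]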

kodit/infrastructure/reporting/__init__.py (new file)

@@ -0,0 +1 @@
+"""Reporting infrastructure."""

kodit/infrastructure/reporting/db_progress.py (new file)

@@ -0,0 +1,23 @@
+"""Log progress using structlog."""
+
+import structlog
+
+from kodit.config import ReportingConfig
+from kodit.domain.entities import TaskStatus
+from kodit.domain.protocols import ReportingModule, TaskStatusRepository
+
+
+class DBProgressReportingModule(ReportingModule):
+    """Database progress reporting module."""
+
+    def __init__(
+        self, task_status_repository: TaskStatusRepository, config: ReportingConfig
+    ) -> None:
+        """Initialize the database progress reporting module."""
+        self.task_status_repository = task_status_repository
+        self.config = config
+        self._log = structlog.get_logger(__name__)
+
+    async def on_change(self, progress: TaskStatus) -> None:
+        """On step changed - update task status in database."""
+        await self.task_status_repository.save(progress)

kodit/infrastructure/reporting/log_progress.py (new file)

@@ -0,0 +1,37 @@
+"""Log progress using structlog."""
+
+from datetime import UTC, datetime
+
+import structlog
+
+from kodit.config import ReportingConfig
+from kodit.domain.entities import TaskStatus
+from kodit.domain.protocols import ReportingModule
+from kodit.domain.value_objects import ReportingState
+
+
+class LoggingReportingModule(ReportingModule):
+    """Logging reporting module."""
+
+    def __init__(self, config: ReportingConfig) -> None:
+        """Initialize the logging reporting module."""
+        self.config = config
+        self._log = structlog.get_logger(__name__)
+        self._last_log_time: datetime = datetime.now(UTC)
+
+    async def on_change(self, progress: TaskStatus) -> None:
+        """On step changed."""
+        current_time = datetime.now(UTC)
+        time_since_last_log = current_time - self._last_log_time
+        step = progress
+
+        if (
+            step.state != ReportingState.IN_PROGRESS
+            or time_since_last_log >= self.config.log_time_interval
+        ):
+            self._log.info(
+                step.operation,
+                state=step.state,
+                completion_percent=step.completion_percent,
+            )
+            self._last_log_time = current_time

kodit/infrastructure/reporting/tdqm_progress.py (new file)

@@ -0,0 +1,38 @@
+"""TQDM progress."""
+
+from tqdm import tqdm
+
+from kodit.config import ReportingConfig
+from kodit.domain.entities import TaskStatus
+from kodit.domain.protocols import ReportingModule
+from kodit.domain.value_objects import ReportingState
+
+
+class TQDMReportingModule(ReportingModule):
+    """TQDM reporting module."""
+
+    def __init__(self, config: ReportingConfig) -> None:
+        """Initialize the TQDM reporting module."""
+        self.config = config
+        self.pbar = tqdm()
+
+    async def on_change(self, progress: TaskStatus) -> None:
+        """On step changed."""
+        step = progress
+        if step.state == ReportingState.COMPLETED:
+            self.pbar.close()
+            return
+
+        self.pbar.set_description(step.operation)
+        self.pbar.refresh()
+        # Update description if message is provided
+        if step.error:
+            # Fix the event message to a specific size so it's not jumping around
+            # If it's too small, add spaces
+            # If it's too large, truncate
+            if len(step.error) < 30:
+                self.pbar.set_description(step.error + " " * (30 - len(step.error)))
+            else:
+                self.pbar.set_description(step.error[-30:])
+        else:
+            self.pbar.set_description(step.operation)
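
Note: all three new reporting modules implement an async on_change(TaskStatus) hook from kodit.domain.protocols.ReportingModule, whose definition changes in this release but is not shown in this diff. The sketch below infers a plausible shape for that hook and for a fan-out reporter; the CompositeReporter name and notify method are invented for illustration, and the real composition presumably lives in kodit/application/services/reporting.py (+104 lines, also not shown).

    from typing import Protocol

    from kodit.domain.entities import TaskStatus


    class ReportingModule(Protocol):
        """Inferred shape: something that reacts to task status changes."""

        async def on_change(self, progress: TaskStatus) -> None:
            ...


    class CompositeReporter:
        """Hypothetical fan-out helper, invented for illustration."""

        def __init__(self, modules: list[ReportingModule]) -> None:
            self.modules = modules

        async def notify(self, status: TaskStatus) -> None:
            # Forward a single status change to every configured module
            # (database, structlog, tqdm, ...).
            for module in self.modules:
                await module.on_change(status)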

kodit/infrastructure/sqlalchemy/embedding_repository.py

@@ -1,85 +1,64 @@
 """SQLAlchemy implementation of embedding repository."""

+from collections.abc import Callable
+
 import numpy as np
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession

 from kodit.infrastructure.sqlalchemy.entities import Embedding, EmbeddingType
+from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork


-class SqlAlchemyEmbeddingRepository:
-    """SQLAlchemy implementation of embedding repository."""
-
-    def __init__(self, session: AsyncSession) -> None:
-        """Initialize the SQLAlchemy embedding repository.
-
-        Args:
-            session: The SQLAlchemy async session to use for database operations
-
-        """
-        self.session = session
+def create_embedding_repository(
+    session_factory: Callable[[], AsyncSession],
+) -> "SqlAlchemyEmbeddingRepository":
+    """Create an embedding repository."""
+    uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
+    return SqlAlchemyEmbeddingRepository(uow)

-    async def create_embedding(self, embedding: Embedding) -> Embedding:
-        """Create a new embedding record in the database.

-        Args:
-            embedding: The Embedding instance to create
+class SqlAlchemyEmbeddingRepository:
+    """SQLAlchemy implementation of embedding repository."""

-        Returns:
-            The created Embedding instance
+    def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
+        """Initialize the SQLAlchemy embedding repository."""
+        self.uow = uow

-        """
-        self.session.add(embedding)
-        return embedding
+    async def create_embedding(self, embedding: Embedding) -> None:
+        """Create a new embedding record in the database."""
+        async with self.uow:
+            self.uow.session.add(embedding)

     async def get_embedding_by_snippet_id_and_type(
         self, snippet_id: int, embedding_type: EmbeddingType
     ) -> Embedding | None:
-        """Get an embedding by its snippet ID and type.
-
-        Args:
-            snippet_id: The ID of the snippet to get the embedding for
-            embedding_type: The type of embedding to get
-
-        Returns:
-            The Embedding instance if found, None otherwise
-
-        """
-        query = select(Embedding).where(
-            Embedding.snippet_id == snippet_id,
-            Embedding.type == embedding_type,
-        )
-        result = await self.session.execute(query)
-        return result.scalar_one_or_none()
+        """Get an embedding by its snippet ID and type."""
+        async with self.uow:
+            query = select(Embedding).where(
+                Embedding.snippet_id == snippet_id,
+                Embedding.type == embedding_type,
+            )
+            result = await self.uow.session.execute(query)
+            return result.scalar_one_or_none()

     async def list_embeddings_by_type(
         self, embedding_type: EmbeddingType
     ) -> list[Embedding]:
-        """List all embeddings of a given type.
-
-        Args:
-            embedding_type: The type of embeddings to list
-
-        Returns:
-            A list of Embedding instances
-
-        """
-        query = select(Embedding).where(Embedding.type == embedding_type)
-        result = await self.session.execute(query)
-        return list(result.scalars())
+        """List all embeddings of a given type."""
+        async with self.uow:
+            query = select(Embedding).where(Embedding.type == embedding_type)
+            result = await self.uow.session.execute(query)
+            return list(result.scalars())

     async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
-        """Delete all embeddings for a snippet.
-
-        Args:
-            snippet_id: The ID of the snippet to delete embeddings for
-
-        """
-        query = select(Embedding).where(Embedding.snippet_id == snippet_id)
-        result = await self.session.execute(query)
-        embeddings = result.scalars().all()
-        for embedding in embeddings:
-            await self.session.delete(embedding)
+        """Delete all embeddings for a snippet."""
+        async with self.uow:
+            query = select(Embedding).where(Embedding.snippet_id == snippet_id)
+            result = await self.uow.session.execute(query)
+            embeddings = result.scalars().all()
+            for embedding in embeddings:
+                await self.uow.session.delete(embedding)

     async def list_semantic_results(
         self,
@@ -130,17 +109,17 @@ class SqlAlchemyEmbeddingRepository:
            List of (snippet_id, embedding) tuples

         """
-        # Only select the fields we need and use a more efficient query
-        query = select(Embedding.snippet_id, Embedding.embedding).where(
-            Embedding.type == embedding_type
-        )
+        async with self.uow:
+            query = select(Embedding.snippet_id, Embedding.embedding).where(
+                Embedding.type == embedding_type
+            )

-        # Add snippet_ids filter if provided
-        if snippet_ids is not None:
-            query = query.where(Embedding.snippet_id.in_(snippet_ids))
+            # Add snippet_ids filter if provided
+            if snippet_ids is not None:
+                query = query.where(Embedding.snippet_id.in_(snippet_ids))

-        rows = await self.session.execute(query)
-        return [tuple(row) for row in rows.all()]  # Convert Row objects to tuples
+            rows = await self.uow.session.execute(query)
+            return [tuple(row) for row in rows.all()]  # Convert Row objects to tuples

     def _prepare_vectors(
         self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
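
Note: every repository method above now wraps its work in async with self.uow, and the new factory builds a SqlAlchemyUnitOfWork from a session factory. The unit-of-work module itself (kodit/infrastructure/sqlalchemy/unit_of_work.py, +59 lines) is not shown in this diff, so the following is only an inferred sketch of what such an async context manager might do: open a session on entry, commit on clean exit, roll back on error. The real implementation may differ.

    from collections.abc import Callable
    from types import TracebackType

    from sqlalchemy.ext.asyncio import AsyncSession


    class SqlAlchemyUnitOfWork:
        """Inferred sketch only; not the actual kodit implementation."""

        def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
            self.session_factory = session_factory
            self.session: AsyncSession | None = None

        async def __aenter__(self) -> "SqlAlchemyUnitOfWork":
            self.session = self.session_factory()
            return self

        async def __aexit__(
            self,
            exc_type: type[BaseException] | None,
            exc: BaseException | None,
            tb: TracebackType | None,
        ) -> None:
            if self.session is None:
                return
            try:
                if exc_type is None:
                    await self.session.commit()  # commit on clean exit
                else:
                    await self.session.rollback()  # roll back if the block raised
            finally:
                await self.session.close()
                self.session = None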