kodit 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic; see the registry's advisory for details.

Files changed (95)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +53 -23
  3. kodit/application/factories/reporting_factory.py +6 -2
  4. kodit/application/factories/server_factory.py +311 -0
  5. kodit/application/services/code_search_application_service.py +144 -0
  6. kodit/application/services/commit_indexing_application_service.py +543 -0
  7. kodit/application/services/indexing_worker_service.py +13 -44
  8. kodit/application/services/queue_service.py +24 -3
  9. kodit/application/services/reporting.py +0 -2
  10. kodit/application/services/sync_scheduler.py +15 -31
  11. kodit/cli.py +2 -753
  12. kodit/cli_utils.py +2 -9
  13. kodit/config.py +1 -94
  14. kodit/database.py +38 -1
  15. kodit/domain/{entities.py → entities/__init__.py} +50 -195
  16. kodit/domain/entities/git.py +190 -0
  17. kodit/domain/factories/__init__.py +1 -0
  18. kodit/domain/factories/git_repo_factory.py +76 -0
  19. kodit/domain/protocols.py +263 -64
  20. kodit/domain/services/bm25_service.py +5 -1
  21. kodit/domain/services/embedding_service.py +3 -0
  22. kodit/domain/services/git_repository_service.py +429 -0
  23. kodit/domain/services/git_service.py +300 -0
  24. kodit/domain/services/task_status_query_service.py +2 -2
  25. kodit/domain/value_objects.py +83 -114
  26. kodit/infrastructure/api/client/__init__.py +0 -2
  27. kodit/infrastructure/api/v1/__init__.py +0 -4
  28. kodit/infrastructure/api/v1/dependencies.py +92 -46
  29. kodit/infrastructure/api/v1/routers/__init__.py +0 -6
  30. kodit/infrastructure/api/v1/routers/commits.py +271 -0
  31. kodit/infrastructure/api/v1/routers/queue.py +2 -2
  32. kodit/infrastructure/api/v1/routers/repositories.py +282 -0
  33. kodit/infrastructure/api/v1/routers/search.py +31 -14
  34. kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
  35. kodit/infrastructure/api/v1/schemas/commit.py +96 -0
  36. kodit/infrastructure/api/v1/schemas/context.py +2 -0
  37. kodit/infrastructure/api/v1/schemas/repository.py +128 -0
  38. kodit/infrastructure/api/v1/schemas/search.py +12 -9
  39. kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
  40. kodit/infrastructure/api/v1/schemas/tag.py +31 -0
  41. kodit/infrastructure/api/v1/schemas/task_status.py +2 -0
  42. kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
  43. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
  44. kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
  45. kodit/infrastructure/cloning/git/working_copy.py +1 -1
  46. kodit/infrastructure/embedding/embedding_factory.py +3 -2
  47. kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
  48. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
  49. kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
  50. kodit/infrastructure/indexing/fusion_service.py +1 -1
  51. kodit/infrastructure/mappers/git_mapper.py +193 -0
  52. kodit/infrastructure/mappers/snippet_mapper.py +106 -0
  53. kodit/infrastructure/mappers/task_mapper.py +5 -44
  54. kodit/infrastructure/reporting/log_progress.py +8 -5
  55. kodit/infrastructure/reporting/telemetry_progress.py +21 -0
  56. kodit/infrastructure/slicing/slicer.py +32 -31
  57. kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
  58. kodit/infrastructure/sqlalchemy/entities.py +394 -158
  59. kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
  60. kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
  61. kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
  62. kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
  63. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
  64. kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
  65. kodit/infrastructure/sqlalchemy/task_status_repository.py +24 -12
  66. kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
  67. kodit/mcp.py +12 -30
  68. kodit/migrations/env.py +1 -0
  69. kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
  70. kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
  71. kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
  72. kodit/py.typed +0 -0
  73. kodit/utils/dump_openapi.py +7 -4
  74. kodit/utils/path_utils.py +29 -0
  75. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
  76. kodit-0.5.0.dist-info/RECORD +137 -0
  77. kodit/application/factories/code_indexing_factory.py +0 -195
  78. kodit/application/services/auto_indexing_service.py +0 -99
  79. kodit/application/services/code_indexing_application_service.py +0 -410
  80. kodit/domain/services/index_query_service.py +0 -70
  81. kodit/domain/services/index_service.py +0 -269
  82. kodit/infrastructure/api/client/index_client.py +0 -57
  83. kodit/infrastructure/api/v1/routers/indexes.py +0 -164
  84. kodit/infrastructure/api/v1/schemas/index.py +0 -101
  85. kodit/infrastructure/bm25/bm25_factory.py +0 -28
  86. kodit/infrastructure/cloning/__init__.py +0 -1
  87. kodit/infrastructure/cloning/metadata.py +0 -98
  88. kodit/infrastructure/mappers/index_mapper.py +0 -345
  89. kodit/infrastructure/reporting/tdqm_progress.py +0 -38
  90. kodit/infrastructure/slicing/language_detection_service.py +0 -18
  91. kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
  92. kodit-0.4.3.dist-info/RECORD +0 -125
  93. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
  94. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
  95. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,252 @@
1
+ """SQLAlchemy implementation of GitRepoRepository."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ from pydantic import AnyUrl
6
+ from sqlalchemy import delete, select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain.entities.git import GitRepo
10
+ from kodit.domain.protocols import GitRepoRepository
11
+ from kodit.infrastructure.mappers.git_mapper import GitMapper
12
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
13
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
14
+
15
+
16
def create_git_repo_repository(
    session_factory: Callable[[], AsyncSession],
) -> GitRepoRepository:
    """Build the SQLAlchemy-backed ``GitRepoRepository``.

    Args:
        session_factory: Callable returning a fresh ``AsyncSession``.

    Returns:
        A ``SqlAlchemyGitRepoRepository`` bound to ``session_factory``.
    """
    repository: GitRepoRepository = SqlAlchemyGitRepoRepository(
        session_factory=session_factory
    )
    return repository
21
+
22
+
23
class SqlAlchemyGitRepoRepository(GitRepoRepository):
    """SQLAlchemy implementation of GitRepoRepository.

    Writes cover the ``GitRepo`` row and the repo's ``GitTrackingBranch``
    pointer. Reads assemble the domain ``GitRepo`` from the repo row, its
    tags, the ``GitBranch`` row named by the tracking branch, and only the
    commits (plus their files) that those tags and that branch point at.

    Note: Commits are managed by the separate GitCommitRepository; branches
    and tags are likewise not written by this repository.
    """

    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the repository.

        Args:
            session_factory: Callable returning a new ``AsyncSession``; each
                public method wraps it in its own ``SqlAlchemyUnitOfWork``.
        """
        self.session_factory = session_factory

    @property
    def _mapper(self) -> GitMapper:
        """Return a ``GitMapper`` used to convert DB rows to domain objects."""
        return GitMapper()

    async def save(self, repo: GitRepo) -> GitRepo:
        """Insert or update a repository and its tracking branch.

        The repo is matched on ``sanitized_remote_uri``. If a row exists, its
        mutable columns are overwritten; otherwise a new row is inserted.
        Branches, tags and commits are NOT written here.

        Args:
            repo: Domain repository to persist. Mutated: ``repo.id`` is set to
                the database ID.

        Returns:
            The same ``repo`` instance, with ``id`` populated.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            # Look up by URI: repos coming from the domain may not have an ID.
            existing_repo_stmt = select(db_entities.GitRepo).where(
                db_entities.GitRepo.sanitized_remote_uri
                == str(repo.sanitized_remote_uri)
            )
            existing_repo = await session.scalar(existing_repo_stmt)

            if existing_repo:
                existing_repo.remote_uri = str(repo.remote_uri)
                existing_repo.cloned_path = repo.cloned_path
                existing_repo.last_scanned_at = repo.last_scanned_at
                existing_repo.num_commits = repo.num_commits
                existing_repo.num_branches = repo.num_branches
                existing_repo.num_tags = repo.num_tags
                repo.id = existing_repo.id
            else:
                db_repo = db_entities.GitRepo(
                    sanitized_remote_uri=str(repo.sanitized_remote_uri),
                    remote_uri=str(repo.remote_uri),
                    cloned_path=repo.cloned_path,
                    last_scanned_at=repo.last_scanned_at,
                    num_commits=repo.num_commits,
                    num_branches=repo.num_branches,
                    num_tags=repo.num_tags,
                )
                session.add(db_repo)
                await session.flush()  # Get the new ID
                repo.id = db_repo.id

            await self._save_tracking_branch(session, repo)

            await session.flush()
            return repo

    async def _save_tracking_branch(
        self, session: AsyncSession, repo: GitRepo
    ) -> None:
        """Make the stored tracking branch match ``repo.tracking_branch``.

        Rows for any other branch name of this repo are deleted, so a repo
        keeps a single tracking branch. ``_load_complete_repo`` reads it with
        ``session.scalar``, which would otherwise return an arbitrary (possibly
        stale) row after the tracking branch changed. Does nothing when the
        repo has no tracking branch or no ID.
        """
        if not repo.tracking_branch or repo.id is None:
            return

        tracking_name = repo.tracking_branch.name
        stored_names = set(
            (
                await session.scalars(
                    select(db_entities.GitTrackingBranch.name).where(
                        db_entities.GitTrackingBranch.repo_id == repo.id
                    )
                )
            ).all()
        )

        # Drop rows left over from a previously tracked branch.
        if stored_names - {tracking_name}:
            await session.execute(
                delete(db_entities.GitTrackingBranch).where(
                    db_entities.GitTrackingBranch.repo_id == repo.id,
                    db_entities.GitTrackingBranch.name != tracking_name,
                )
            )

        if tracking_name not in stored_names:
            session.add(
                db_entities.GitTrackingBranch(repo_id=repo.id, name=tracking_name)
            )

    async def get_by_id(self, repo_id: int) -> GitRepo:
        """Get a repository by database ID, with its tags and tracking branch.

        Raises:
            ValueError: If no repository has ``repo_id``.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_repo = await session.get(db_entities.GitRepo, repo_id)
            if not db_repo:
                raise ValueError(f"Repository with ID {repo_id} not found")

            return await self._load_complete_repo(session, db_repo)

    async def get_by_uri(self, sanitized_uri: AnyUrl) -> GitRepo:
        """Get a repository by sanitized remote URI.

        Raises:
            ValueError: If no repository has ``sanitized_uri``.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(db_entities.GitRepo).where(
                db_entities.GitRepo.sanitized_remote_uri == str(sanitized_uri)
            )
            db_repo = await session.scalar(stmt)
            if not db_repo:
                raise ValueError(f"Repository with URI {sanitized_uri} not found")

            return await self._load_complete_repo(session, db_repo)

    async def get_by_commit(self, commit_sha: str) -> GitRepo:
        """Get the repository that owns the commit ``commit_sha``.

        Raises:
            ValueError: If the commit is unknown, or its ``repo_id`` does not
                resolve to a repository.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(db_entities.GitCommit).where(
                db_entities.GitCommit.commit_sha == commit_sha
            )
            db_commit = await session.scalar(stmt)
            if not db_commit:
                raise ValueError(f"Commit with SHA {commit_sha} not found")

            db_repo = await session.get(db_entities.GitRepo, db_commit.repo_id)
            if not db_repo:
                raise ValueError(f"Repository with commit SHA {commit_sha} not found")

            return await self._load_complete_repo(session, db_repo)

    async def get_all(self) -> list[GitRepo]:
        """Get all repositories, each loaded as in ``get_by_id``.

        Each repository is loaded with its own follow-up queries, one repo
        after another.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_repos = (await session.scalars(select(db_entities.GitRepo))).all()
            return [
                await self._load_complete_repo(session, db_repo)
                for db_repo in db_repos
            ]

    async def delete(self, sanitized_uri: AnyUrl) -> bool:
        """Delete a repository row and its tracking-branch rows.

        According to DDD principles, this repository only deletes what it
        directly controls: the GitRepo row and its GitTrackingBranch rows.
        Related entities (commits, branches, tags, snippets) should be deleted
        by their respective repositories before calling this method.

        Returns:
            ``True`` if a repository was found and deleted, ``False`` if no
            repository has ``sanitized_uri``.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(db_entities.GitRepo).where(
                db_entities.GitRepo.sanitized_remote_uri == str(sanitized_uri)
            )
            db_repo = await session.scalar(stmt)
            if not db_repo:
                return False

            # Tracking branches reference the repo, so remove them first.
            await session.execute(
                delete(db_entities.GitTrackingBranch).where(
                    db_entities.GitTrackingBranch.repo_id == db_repo.id
                )
            )

            # NOTE(review): if other rows still reference the repo, this relies
            # on database foreign-key constraints to reject the delete.
            await session.execute(
                delete(db_entities.GitRepo).where(
                    db_entities.GitRepo.id == db_repo.id
                )
            )
            return True

    async def _load_complete_repo(
        self, session: AsyncSession, db_repo: db_entities.GitRepo
    ) -> GitRepo:
        """Assemble a domain ``GitRepo`` from ``db_repo`` and related rows.

        Loads all tags of the repo, its tracking-branch pointer and the
        matching ``GitBranch`` row, then only the commits (and their files)
        referenced by the tags or by the tracking branch's head.
        """
        all_tags = list(
            (
                await session.scalars(
                    select(db_entities.GitTag).where(
                        db_entities.GitTag.repo_id == db_repo.id
                    )
                )
            ).all()
        )
        tracking_branch = await session.scalar(
            select(db_entities.GitTrackingBranch).where(
                db_entities.GitTrackingBranch.repo_id == db_repo.id
            )
        )

        # Resolve the tracking pointer to the actual branch row (head commit).
        db_tracking_branch_entity = None
        if tracking_branch:
            db_tracking_branch_entity = await session.scalar(
                select(db_entities.GitBranch).where(
                    db_entities.GitBranch.repo_id == db_repo.id,
                    db_entities.GitBranch.name == tracking_branch.name,
                )
            )

        # Only the commits that tags or the tracking branch point at are loaded.
        referenced_commit_shas: set[str] = {tag.target_commit_sha for tag in all_tags}
        if db_tracking_branch_entity:
            referenced_commit_shas.add(db_tracking_branch_entity.head_commit_sha)

        referenced_commits = []
        referenced_files = []
        if referenced_commit_shas:
            referenced_commits = list(
                (
                    await session.scalars(
                        select(db_entities.GitCommit).where(
                            db_entities.GitCommit.commit_sha.in_(referenced_commit_shas)
                        )
                    )
                ).all()
            )
            referenced_files = list(
                (
                    await session.scalars(
                        select(db_entities.GitCommitFile).where(
                            db_entities.GitCommitFile.commit_sha.in_(
                                referenced_commit_shas
                            )
                        )
                    )
                ).all()
            )

        return self._mapper.to_domain_git_repo(
            db_repo=db_repo,
            db_tracking_branch_entity=db_tracking_branch_entity,
            db_commits=referenced_commits,
            db_tags=all_tags,
            db_commit_files=referenced_files,
            db_tracking_branch=tracking_branch,
        )
@@ -0,0 +1,257 @@
1
+ """SQLAlchemy implementation of GitTagRepository."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ from sqlalchemy import delete, func, insert, select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from kodit.domain.entities.git import GitCommit, GitFile, GitTag
9
+ from kodit.domain.protocols import GitTagRepository
10
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
11
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
12
+
13
+
14
def create_git_tag_repository(
    session_factory: Callable[[], AsyncSession],
) -> GitTagRepository:
    """Build the SQLAlchemy-backed ``GitTagRepository``.

    Args:
        session_factory: Callable returning a fresh ``AsyncSession``.

    Returns:
        A ``SqlAlchemyGitTagRepository`` bound to ``session_factory``.
    """
    repository: GitTagRepository = SqlAlchemyGitTagRepository(
        session_factory=session_factory
    )
    return repository
19
+
20
+
21
class SqlAlchemyGitTagRepository(GitTagRepository):
    """SQLAlchemy implementation of GitTagRepository.

    Tags are ``GitTag`` rows keyed by ``(repo_id, name)`` that store the SHA of
    their target commit. Reads rebuild the domain ``GitTag`` together with its
    target ``GitCommit`` and that commit's ``GitFile`` list.
    """

    # Max rows per multi-row INSERT and max names per IN (...) lookup in
    # ``save_bulk``; bounds the bind parameters a single statement uses.
    _BULK_CHUNK_SIZE = 1000

    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the repository.

        Args:
            session_factory: Callable returning a new ``AsyncSession``; each
                public method wraps it in its own ``SqlAlchemyUnitOfWork``.
        """
        self.session_factory = session_factory

    @staticmethod
    def _to_domain_file(db_file: db_entities.GitCommitFile) -> GitFile:
        """Convert a ``GitCommitFile`` row into a domain ``GitFile``."""
        return GitFile(
            blob_sha=db_file.blob_sha,
            path=db_file.path,
            mime_type=db_file.mime_type,
            size=db_file.size,
            extension=db_file.extension,
            created_at=db_file.created_at,
        )

    @staticmethod
    def _to_domain_commit(
        db_commit: db_entities.GitCommit, files: list[GitFile]
    ) -> GitCommit:
        """Convert a ``GitCommit`` row plus its domain files into a ``GitCommit``."""
        return GitCommit(
            commit_sha=db_commit.commit_sha,
            date=db_commit.date,
            message=db_commit.message,
            parent_commit_sha=db_commit.parent_commit_sha,
            files=files,
            author=db_commit.author,
            created_at=db_commit.created_at,
            updated_at=db_commit.updated_at,
        )

    @staticmethod
    def _to_domain_tag(db_tag: db_entities.GitTag, target_commit: GitCommit) -> GitTag:
        """Convert a ``GitTag`` row plus its resolved target commit into a ``GitTag``."""
        return GitTag(
            repo_id=db_tag.repo_id,
            name=db_tag.name,
            target_commit=target_commit,
            created_at=db_tag.created_at,
            updated_at=db_tag.updated_at,
        )

    async def get_by_name(self, tag_name: str, repo_id: int) -> GitTag:
        """Get a tag, with its target commit and files, by name within a repo.

        Raises:
            ValueError: If the tag does not exist in ``repo_id``, or its target
                commit row is missing.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_tag = await session.scalar(
                select(db_entities.GitTag).where(
                    db_entities.GitTag.name == tag_name,
                    db_entities.GitTag.repo_id == repo_id,
                )
            )
            if not db_tag:
                raise ValueError(f"Tag {tag_name} not found in repo {repo_id}")

            target_sha = db_tag.target_commit_sha
            db_commit = await session.scalar(
                select(db_entities.GitCommit).where(
                    db_entities.GitCommit.commit_sha == target_sha
                )
            )
            if not db_commit:
                raise ValueError(f"Target commit {target_sha} not found")

            db_files = (
                await session.scalars(
                    select(db_entities.GitCommitFile).where(
                        db_entities.GitCommitFile.commit_sha == target_sha
                    )
                )
            ).all()

            files = [self._to_domain_file(db_file) for db_file in db_files]
            return self._to_domain_tag(
                db_tag, self._to_domain_commit(db_commit, files)
            )

    async def get_by_repo_id(self, repo_id: int) -> list[GitTag]:
        """Get all tags of a repository with their target commits and files.

        Tags whose target commit row is missing are skipped rather than
        raising, unlike ``get_by_name``.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_tags = (
                await session.scalars(
                    select(db_entities.GitTag).where(
                        db_entities.GitTag.repo_id == repo_id
                    )
                )
            ).all()
            if not db_tags:
                return []

            # Several tags can point at the same commit; query each SHA once.
            commit_shas = {db_tag.target_commit_sha for db_tag in db_tags}

            db_commits = (
                await session.scalars(
                    select(db_entities.GitCommit).where(
                        db_entities.GitCommit.commit_sha.in_(commit_shas)
                    )
                )
            ).all()
            db_files = (
                await session.scalars(
                    select(db_entities.GitCommitFile).where(
                        db_entities.GitCommitFile.commit_sha.in_(commit_shas)
                    )
                )
            ).all()

            files_by_commit: dict[str, list[GitFile]] = {}
            for db_file in db_files:
                files_by_commit.setdefault(db_file.commit_sha, []).append(
                    self._to_domain_file(db_file)
                )
            commits_by_sha = {commit.commit_sha: commit for commit in db_commits}

            domain_tags: list[GitTag] = []
            for db_tag in db_tags:
                db_commit = commits_by_sha.get(db_tag.target_commit_sha)
                if not db_commit:
                    continue  # Best-effort listing: target commit row missing.
                target_commit = self._to_domain_commit(
                    db_commit, files_by_commit.get(db_tag.target_commit_sha, [])
                )
                domain_tags.append(self._to_domain_tag(db_tag, target_commit))

            return domain_tags

    async def save(self, tag: GitTag, repo_id: int) -> GitTag:
        """Insert or update a single tag of a repository.

        An existing tag (same ``repo_id`` and name) is re-pointed at
        ``tag.target_commit``; its ``updated_at`` is copied only when
        ``tag.updated_at`` is set. Mutates ``tag.repo_id`` to ``repo_id``.

        Returns:
            The same ``tag`` instance.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            tag.repo_id = repo_id

            existing_tag = await session.get(db_entities.GitTag, (repo_id, tag.name))

            if existing_tag:
                existing_tag.target_commit_sha = tag.target_commit.commit_sha
                if tag.updated_at:
                    existing_tag.updated_at = tag.updated_at
            else:
                session.add(
                    db_entities.GitTag(
                        repo_id=repo_id,
                        name=tag.name,
                        target_commit_sha=tag.target_commit.commit_sha,
                    )
                )

            return tag

    async def save_bulk(self, tags: list[GitTag], repo_id: int) -> None:
        """Insert or update many tags of one repository in one unit of work.

        Existing tags (matched by name within ``repo_id``) are re-pointed at
        their new target commit when it changed. The remaining tags are
        inserted with multi-row INSERTs. If ``tags`` repeats a name, the last
        occurrence wins. Unlike ``save``, this neither sets ``updated_at`` nor
        mutates ``tag.repo_id``.
        """
        if not tags:
            return

        # Collapse duplicate names (last wins) so the INSERT never hits the
        # (repo_id, name) key twice within one batch.
        target_sha_by_name = {tag.name: tag.target_commit.commit_sha for tag in tags}
        names = list(target_sha_by_name)
        chunk_size = self._BULK_CHUNK_SIZE

        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            # Chunk the IN (...) lookup like the inserts below: one bind
            # parameter per name would otherwise grow without bound.
            existing_tags: list[db_entities.GitTag] = []
            for start in range(0, len(names), chunk_size):
                existing_tags.extend(
                    (
                        await session.scalars(
                            select(db_entities.GitTag).where(
                                db_entities.GitTag.repo_id == repo_id,
                                db_entities.GitTag.name.in_(
                                    names[start : start + chunk_size]
                                ),
                            )
                        )
                    ).all()
                )

            existing_names: set[str] = set()
            for existing_tag in existing_tags:
                existing_names.add(existing_tag.name)
                new_sha = target_sha_by_name[existing_tag.name]
                if existing_tag.target_commit_sha != new_sha:
                    existing_tag.target_commit_sha = new_sha

            new_rows = [
                {"repo_id": repo_id, "name": name, "target_commit_sha": sha}
                for name, sha in target_sha_by_name.items()
                if name not in existing_names
            ]
            for start in range(0, len(new_rows), chunk_size):
                chunk = new_rows[start : start + chunk_size]
                await session.execute(insert(db_entities.GitTag).values(chunk))

    async def exists(self, tag_name: str, repo_id: int) -> bool:
        """Return whether a tag named ``tag_name`` exists in ``repo_id``."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(db_entities.GitTag.name).where(
                db_entities.GitTag.name == tag_name,
                db_entities.GitTag.repo_id == repo_id,
            )
            result = await session.scalar(stmt)
            return result is not None

    async def delete_by_repo_id(self, repo_id: int) -> None:
        """Delete all tag rows of a repository (commits are left untouched)."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            await session.execute(
                delete(db_entities.GitTag).where(
                    db_entities.GitTag.repo_id == repo_id
                )
            )

    async def count_by_repo_id(self, repo_id: int) -> int:
        """Count the tags of a repository; returns 0 when there are none."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = (
                select(func.count())
                .select_from(db_entities.GitTag)
                .where(db_entities.GitTag.repo_id == repo_id)
            )
            result = await session.scalar(stmt)
            return result or 0