kodit-0.4.2-py3-none-any.whl → kodit-0.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic; see the package's registry page for more details.

Files changed (34):
  1. kodit/_version.py +2 -2
  2. kodit/app.py +6 -1
  3. kodit/application/factories/code_indexing_factory.py +14 -12
  4. kodit/application/factories/reporting_factory.py +10 -5
  5. kodit/application/services/auto_indexing_service.py +28 -32
  6. kodit/application/services/code_indexing_application_service.py +43 -26
  7. kodit/application/services/indexing_worker_service.py +10 -12
  8. kodit/application/services/reporting.py +72 -54
  9. kodit/cli.py +68 -78
  10. kodit/config.py +2 -2
  11. kodit/domain/entities.py +99 -1
  12. kodit/domain/protocols.py +28 -3
  13. kodit/domain/services/index_service.py +11 -9
  14. kodit/domain/services/task_status_query_service.py +19 -0
  15. kodit/domain/value_objects.py +26 -29
  16. kodit/infrastructure/api/v1/dependencies.py +19 -4
  17. kodit/infrastructure/api/v1/routers/indexes.py +45 -0
  18. kodit/infrastructure/api/v1/schemas/task_status.py +39 -0
  19. kodit/infrastructure/cloning/git/working_copy.py +9 -2
  20. kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
  21. kodit/infrastructure/mappers/task_status_mapper.py +85 -0
  22. kodit/infrastructure/reporting/db_progress.py +23 -0
  23. kodit/infrastructure/reporting/log_progress.py +5 -33
  24. kodit/infrastructure/reporting/tdqm_progress.py +10 -45
  25. kodit/infrastructure/sqlalchemy/entities.py +61 -0
  26. kodit/infrastructure/sqlalchemy/task_status_repository.py +79 -0
  27. kodit/mcp.py +6 -2
  28. kodit/migrations/env.py +0 -1
  29. kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
  30. {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/METADATA +1 -1
  31. {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/RECORD +34 -28
  32. {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/WHEEL +0 -0
  33. {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/entry_points.txt +0 -0
  34. {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.4.2'
32
- __version_tuple__ = version_tuple = (0, 4, 2)
31
+ __version__ = version = '0.4.3'
32
+ __version_tuple__ = version_tuple = (0, 4, 3)
33
33
 
34
34
  __commit_id__ = commit_id = None
kodit/app.py CHANGED
@@ -19,6 +19,9 @@ from kodit.infrastructure.api.v1.routers import (
19
19
  search_router,
20
20
  )
21
21
  from kodit.infrastructure.api.v1.schemas.context import AppLifespanState
22
+ from kodit.infrastructure.sqlalchemy.task_status_repository import (
23
+ create_task_status_repository,
24
+ )
22
25
  from kodit.mcp import mcp
23
26
  from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
24
27
 
@@ -35,7 +38,9 @@ async def app_lifespan(_: FastAPI) -> AsyncIterator[AppLifespanState]:
35
38
  # App context has already been configured by the CLI.
36
39
  app_context = AppContext()
37
40
  db = await app_context.get_db()
38
- operation = create_server_operation()
41
+ operation = create_server_operation(
42
+ create_task_status_repository(db.session_factory)
43
+ )
39
44
 
40
45
  # Start the queue worker service
41
46
  _indexing_worker_service = IndexingWorkerService(
@@ -51,22 +51,26 @@ from kodit.infrastructure.sqlalchemy.entities import EmbeddingType
51
51
  from kodit.infrastructure.sqlalchemy.index_repository import (
52
52
  create_index_repository,
53
53
  )
54
+ from kodit.infrastructure.sqlalchemy.task_status_repository import (
55
+ create_task_status_repository,
56
+ )
54
57
 
55
58
 
56
59
  def create_code_indexing_application_service(
57
60
  app_context: AppContext,
58
- session: AsyncSession,
59
61
  session_factory: Callable[[], AsyncSession],
60
62
  operation: ProgressTracker,
61
63
  ) -> CodeIndexingApplicationService:
62
64
  """Create a unified code indexing application service with all dependencies."""
63
65
  # Create domain services
64
- bm25_service = BM25DomainService(bm25_repository_factory(app_context, session))
66
+ bm25_service = BM25DomainService(
67
+ bm25_repository_factory(app_context, session_factory())
68
+ )
65
69
  code_search_service = embedding_domain_service_factory(
66
- "code", app_context, session, session_factory
70
+ "code", app_context, session_factory(), session_factory
67
71
  )
68
72
  text_search_service = embedding_domain_service_factory(
69
- "text", app_context, session, session_factory
73
+ "text", app_context, session_factory(), session_factory
70
74
  )
71
75
  enrichment_service = enrichment_domain_service_factory(app_context)
72
76
  index_repository = create_index_repository(session_factory=session_factory)
@@ -95,20 +99,17 @@ def create_code_indexing_application_service(
95
99
  code_search_service=code_search_service,
96
100
  text_search_service=text_search_service,
97
101
  enrichment_service=enrichment_service,
98
- session=session,
99
102
  operation=operation,
100
103
  )
101
104
 
102
105
 
103
106
  def create_cli_code_indexing_application_service(
104
107
  app_context: AppContext,
105
- session: AsyncSession,
106
108
  session_factory: Callable[[], AsyncSession],
107
109
  ) -> CodeIndexingApplicationService:
108
110
  """Create a CLI code indexing application service."""
109
111
  return create_code_indexing_application_service(
110
112
  app_context,
111
- session,
112
113
  session_factory,
113
114
  create_cli_operation(),
114
115
  )
@@ -116,23 +117,25 @@ def create_cli_code_indexing_application_service(
116
117
 
117
118
  def create_server_code_indexing_application_service(
118
119
  app_context: AppContext,
119
- session: AsyncSession,
120
120
  session_factory: Callable[[], AsyncSession],
121
121
  ) -> CodeIndexingApplicationService:
122
122
  """Create a server code indexing application service."""
123
123
  return create_code_indexing_application_service(
124
- app_context, session, session_factory, create_server_operation()
124
+ app_context,
125
+ session_factory,
126
+ create_server_operation(create_task_status_repository(session_factory)),
125
127
  )
126
128
 
127
129
 
128
130
  def create_fast_test_code_indexing_application_service(
129
131
  app_context: AppContext,
130
- session: AsyncSession,
131
132
  session_factory: Callable[[], AsyncSession],
132
133
  ) -> CodeIndexingApplicationService:
133
134
  """Create a fast test code indexing application service."""
134
135
  # Create domain services
135
- bm25_service = BM25DomainService(bm25_repository_factory(app_context, session))
136
+ bm25_service = BM25DomainService(
137
+ bm25_repository_factory(app_context, session_factory())
138
+ )
136
139
  embedding_repository = create_embedding_repository(session_factory=session_factory)
137
140
  operation = create_noop_operation()
138
141
 
@@ -188,6 +191,5 @@ def create_fast_test_code_indexing_application_service(
188
191
  code_search_service=code_search_service,
189
192
  text_search_service=text_search_service,
190
193
  enrichment_service=enrichment_service,
191
- session=session,
192
194
  operation=operation,
193
195
  )
@@ -1,27 +1,32 @@
1
1
  """Reporting factory."""
2
2
 
3
- from kodit.application.services.reporting import OperationType, ProgressTracker
3
+ from kodit.application.services.reporting import ProgressTracker, TaskOperation
4
4
  from kodit.config import ReportingConfig
5
+ from kodit.domain.protocols import TaskStatusRepository
6
+ from kodit.infrastructure.reporting.db_progress import DBProgressReportingModule
5
7
  from kodit.infrastructure.reporting.log_progress import LoggingReportingModule
6
8
  from kodit.infrastructure.reporting.tdqm_progress import TQDMReportingModule
7
9
 
8
10
 
9
11
  def create_noop_operation() -> ProgressTracker:
10
12
  """Create a noop reporter."""
11
- return ProgressTracker(OperationType.ROOT.value)
13
+ return ProgressTracker.create(TaskOperation.ROOT)
12
14
 
13
15
 
14
16
  def create_cli_operation(config: ReportingConfig | None = None) -> ProgressTracker:
15
17
  """Create a CLI reporter."""
16
18
  shared_config = config or ReportingConfig()
17
- s = ProgressTracker(OperationType.ROOT.value)
19
+ s = ProgressTracker.create(TaskOperation.ROOT)
18
20
  s.subscribe(TQDMReportingModule(shared_config))
19
21
  return s
20
22
 
21
23
 
22
- def create_server_operation(config: ReportingConfig | None = None) -> ProgressTracker:
24
+ def create_server_operation(
25
+ task_status_repository: TaskStatusRepository, config: ReportingConfig | None = None
26
+ ) -> ProgressTracker:
23
27
  """Create a server reporter."""
24
28
  shared_config = config or ReportingConfig()
25
- s = ProgressTracker(OperationType.ROOT.value)
29
+ s = ProgressTracker.create(TaskOperation.ROOT)
26
30
  s.subscribe(LoggingReportingModule(shared_config))
31
+ s.subscribe(DBProgressReportingModule(task_status_repository, shared_config))
27
32
  return s
@@ -62,38 +62,34 @@ class AutoIndexingService:
62
62
  ) -> None:
63
63
  """Index all configured sources in the background."""
64
64
  operation = operation or create_noop_operation()
65
- async with self.session_factory() as session:
66
- queue_service = QueueService(session_factory=self.session_factory)
67
- service = create_code_indexing_application_service(
68
- app_context=self.app_context,
69
- session=session,
70
- session_factory=self.session_factory,
71
- operation=operation,
72
- )
73
-
74
- for source in sources:
75
- try:
76
- # Only auto-index a source if it is new
77
- if await service.does_index_exist(source):
78
- self.log.info("Index already exists, skipping", source=source)
79
- continue
80
-
81
- self.log.info("Adding auto-indexing task to queue", source=source)
82
-
83
- # Create index
84
- index = await service.create_index_from_uri(source)
85
-
86
- await queue_service.enqueue_task(
87
- Task.create_index_update_task(
88
- index.id, QueuePriority.BACKGROUND
89
- )
90
- )
91
-
92
- except Exception as exc:
93
- self.log.exception(
94
- "Failed to auto-index source", source=source, error=str(exc)
95
- )
96
- # Continue with other sources even if one fails
65
+ queue_service = QueueService(session_factory=self.session_factory)
66
+ service = create_code_indexing_application_service(
67
+ app_context=self.app_context,
68
+ session_factory=self.session_factory,
69
+ operation=operation,
70
+ )
71
+
72
+ for source in sources:
73
+ try:
74
+ # Only auto-index a source if it is new
75
+ if await service.does_index_exist(source):
76
+ self.log.info("Index already exists, skipping", source=source)
77
+ continue
78
+
79
+ self.log.info("Adding auto-indexing task to queue", source=source)
80
+
81
+ # Create index
82
+ index = await service.create_index_from_uri(source)
83
+
84
+ await queue_service.enqueue_task(
85
+ Task.create_index_update_task(index.id, QueuePriority.BACKGROUND)
86
+ )
87
+
88
+ except Exception as exc:
89
+ self.log.exception(
90
+ "Failed to auto-index source", source=source, error=str(exc)
91
+ )
92
+ # Continue with other sources even if one fails
97
93
 
98
94
  async def stop(self) -> None:
99
95
  """Stop background indexing."""
@@ -4,11 +4,10 @@ from dataclasses import replace
4
4
  from datetime import UTC, datetime
5
5
 
6
6
  import structlog
7
- from sqlalchemy.ext.asyncio import AsyncSession
8
7
 
9
8
  from kodit.application.services.reporting import (
10
- OperationType,
11
9
  ProgressTracker,
10
+ TaskOperation,
12
11
  )
13
12
  from kodit.domain.entities import Index, Snippet
14
13
  from kodit.domain.protocols import IndexRepository
@@ -26,6 +25,7 @@ from kodit.domain.value_objects import (
26
25
  SearchRequest,
27
26
  SearchResult,
28
27
  SnippetSearchFilters,
28
+ TrackableType,
29
29
  )
30
30
  from kodit.log import log_event
31
31
 
@@ -42,7 +42,6 @@ class CodeIndexingApplicationService:
42
42
  code_search_service: EmbeddingDomainService,
43
43
  text_search_service: EmbeddingDomainService,
44
44
  enrichment_service: EnrichmentDomainService,
45
- session: AsyncSession,
46
45
  operation: ProgressTracker,
47
46
  ) -> None:
48
47
  """Initialize the code indexing application service."""
@@ -53,7 +52,6 @@ class CodeIndexingApplicationService:
53
52
  self.code_search_service = code_search_service
54
53
  self.text_search_service = text_search_service
55
54
  self.enrichment_service = enrichment_service
56
- self.session = session
57
55
  self.operation = operation
58
56
  self.log = structlog.get_logger(__name__)
59
57
 
@@ -67,7 +65,7 @@ class CodeIndexingApplicationService:
67
65
  async def create_index_from_uri(self, uri: str) -> Index:
68
66
  """Create a new index for a source."""
69
67
  log_event("kodit.index.create")
70
- with self.operation.create_child(OperationType.CREATE_INDEX.value) as operation:
68
+ async with self.operation.create_child(TaskOperation.CREATE_INDEX) as operation:
71
69
  # Check if index already exists
72
70
  sanitized_uri, _ = self.index_domain_service.sanitize_uri(uri)
73
71
  self.log.info("Creating index from URI", uri=str(sanitized_uri))
@@ -86,14 +84,16 @@ class CodeIndexingApplicationService:
86
84
 
87
85
  # Create new index
88
86
  self.log.info("Creating index", uri=str(sanitized_uri))
89
- index = await self.index_repository.create(sanitized_uri, working_copy)
90
- await self.session.commit()
91
- return index
87
+ return await self.index_repository.create(sanitized_uri, working_copy)
92
88
 
93
89
  async def run_index(self, index: Index) -> None:
94
90
  """Run the complete indexing process for a specific index."""
95
91
  # Create a new operation
96
- with self.operation.create_child(OperationType.RUN_INDEX.value) as operation:
92
+ async with self.operation.create_child(
93
+ TaskOperation.RUN_INDEX,
94
+ trackable_type=TrackableType.INDEX,
95
+ trackable_id=index.id,
96
+ ) as operation:
97
97
  # TODO(philwinder): Move this into a reporter # noqa: TD003, FIX002
98
98
  log_event("kodit.index.run")
99
99
 
@@ -102,7 +102,9 @@ class CodeIndexingApplicationService:
102
102
  raise ValueError(msg)
103
103
 
104
104
  # Refresh working copy
105
- with operation.create_child("Refresh working copy") as step:
105
+ async with operation.create_child(
106
+ TaskOperation.REFRESH_WORKING_COPY
107
+ ) as step:
106
108
  index.source.working_copy = (
107
109
  await self.index_domain_service.refresh_working_copy(
108
110
  index.source.working_copy, step
@@ -110,11 +112,13 @@ class CodeIndexingApplicationService:
110
112
  )
111
113
  if len(index.source.working_copy.changed_files()) == 0:
112
114
  self.log.info("No new changes to index", index_id=index.id)
113
- step.skip("No new changes to index")
115
+ await step.skip("No new changes to index")
114
116
  return
115
117
 
116
118
  # Delete the old snippets from the files that have changed
117
- with operation.create_child("Delete old snippets") as step:
119
+ async with operation.create_child(
120
+ TaskOperation.DELETE_OLD_SNIPPETS
121
+ ) as step:
118
122
  await self.index_repository.delete_snippets_by_file_ids(
119
123
  [
120
124
  file.id
@@ -124,7 +128,7 @@ class CodeIndexingApplicationService:
124
128
  )
125
129
 
126
130
  # Extract and create snippets (domain service handles progress)
127
- with operation.create_child("Extract snippets") as step:
131
+ async with operation.create_child(TaskOperation.EXTRACT_SNIPPETS) as step:
128
132
  index = await self.index_domain_service.extract_snippets_from_index(
129
133
  index=index, step=step
130
134
  )
@@ -140,20 +144,22 @@ class CodeIndexingApplicationService:
140
144
  self.log.info(
141
145
  "No snippets to index after extraction", index_id=index.id
142
146
  )
143
- step.skip("No snippets to index after extraction")
147
+ await step.skip("No snippets to index after extraction")
144
148
  return
145
149
 
146
150
  # Create BM25 index
147
151
  self.log.info("Creating keyword index")
148
- with operation.create_child("Create BM25 index") as step:
152
+ async with operation.create_child(TaskOperation.CREATE_BM25_INDEX) as step:
149
153
  await self._create_bm25_index(index.snippets)
150
154
 
151
155
  # Create code embeddings
152
- with operation.create_child("Create code embeddings") as step:
156
+ async with operation.create_child(
157
+ TaskOperation.CREATE_CODE_EMBEDDINGS
158
+ ) as step:
153
159
  await self._create_code_embeddings(index.snippets, step)
154
160
 
155
161
  # Enrich snippets
156
- with operation.create_child("Enrich snippets") as step:
162
+ async with operation.create_child(TaskOperation.ENRICH_SNIPPETS) as step:
157
163
  enriched_snippets = (
158
164
  await self.index_domain_service.enrich_snippets_in_index(
159
165
  snippets=index.snippets,
@@ -164,15 +170,21 @@ class CodeIndexingApplicationService:
164
170
  await self.index_repository.update_snippets(index.id, enriched_snippets)
165
171
 
166
172
  # Create text embeddings (on enriched content)
167
- with operation.create_child("Create text embeddings") as step:
173
+ async with operation.create_child(
174
+ TaskOperation.CREATE_TEXT_EMBEDDINGS
175
+ ) as step:
168
176
  await self._create_text_embeddings(enriched_snippets, step)
169
177
 
170
178
  # Update index timestamp
171
- with operation.create_child("Update index timestamp") as step:
179
+ async with operation.create_child(
180
+ TaskOperation.UPDATE_INDEX_TIMESTAMP
181
+ ) as step:
172
182
  await self.index_repository.update_index_timestamp(index.id)
173
183
 
174
184
  # After indexing, clear the file processing statuses
175
- with operation.create_child("Clear file processing statuses") as step:
185
+ async with operation.create_child(
186
+ TaskOperation.CLEAR_FILE_PROCESSING_STATUSES
187
+ ) as step:
176
188
  index.source.working_copy.clear_file_processing_statuses()
177
189
  await self.index_repository.update(index)
178
190
 
@@ -340,7 +352,7 @@ class CodeIndexingApplicationService:
340
352
  async def _create_code_embeddings(
341
353
  self, snippets: list[Snippet], reporting_step: ProgressTracker
342
354
  ) -> None:
343
- reporting_step.set_total(len(snippets))
355
+ await reporting_step.set_total(len(snippets))
344
356
  processed = 0
345
357
  async for result in self.code_search_service.index_documents(
346
358
  IndexRequest(
@@ -352,7 +364,9 @@ class CodeIndexingApplicationService:
352
364
  )
353
365
  ):
354
366
  processed += len(result)
355
- reporting_step.set_current(processed)
367
+ await reporting_step.set_current(
368
+ processed, f"Creating code embeddings for {processed} snippets"
369
+ )
356
370
 
357
371
  async def _create_text_embeddings(
358
372
  self, snippets: list[Snippet], reporting_step: ProgressTracker
@@ -372,16 +386,20 @@ class CodeIndexingApplicationService:
372
386
  continue
373
387
 
374
388
  if not documents_with_summaries:
375
- reporting_step.skip("No snippets with summaries to create text embeddings")
389
+ await reporting_step.skip(
390
+ "No snippets with summaries to create text embeddings"
391
+ )
376
392
  return
377
393
 
378
- reporting_step.set_total(len(documents_with_summaries))
394
+ await reporting_step.set_total(len(documents_with_summaries))
379
395
  processed = 0
380
396
  async for result in self.text_search_service.index_documents(
381
397
  IndexRequest(documents=documents_with_summaries)
382
398
  ):
383
399
  processed += len(result)
384
- reporting_step.set_current(processed)
400
+ await reporting_step.set_current(
401
+ processed, f"Creating text embeddings for {processed} snippets"
402
+ )
385
403
 
386
404
  async def delete_index(self, index: Index) -> None:
387
405
  """Delete an index."""
@@ -390,4 +408,3 @@ class CodeIndexingApplicationService:
390
408
 
391
409
  # Delete index from the database
392
410
  await self.index_repository.delete(index)
393
- await self.session.commit()
@@ -136,17 +136,15 @@ class IndexingWorkerService:
136
136
  # Create a fresh database connection for this thread's event loop
137
137
  db = await self.app_context.new_db(run_migrations=True)
138
138
  try:
139
- async with db.session_factory() as session:
140
- service = create_code_indexing_application_service(
141
- app_context=self.app_context,
142
- session=session,
143
- session_factory=self.session_factory,
144
- operation=operation,
145
- )
146
- index = await service.index_repository.get(index_id)
147
- if not index:
148
- raise ValueError(f"Index not found: {index_id}")
149
-
150
- await service.run_index(index)
139
+ service = create_code_indexing_application_service(
140
+ app_context=self.app_context,
141
+ session_factory=self.session_factory,
142
+ operation=operation,
143
+ )
144
+ index = await service.index_repository.get(index_id)
145
+ if not index:
146
+ raise ValueError(f"Index not found: {index_id}")
147
+
148
+ await service.run_index(index)
151
149
  finally:
152
150
  await db.close()
@@ -1,86 +1,104 @@
1
1
  """Reporting."""
2
2
 
3
- from enum import StrEnum
4
- from types import TracebackType
3
+ from collections.abc import AsyncGenerator
4
+ from contextlib import asynccontextmanager
5
5
  from typing import TYPE_CHECKING
6
6
 
7
7
  import structlog
8
8
 
9
- from kodit.domain.value_objects import Progress, ReportingState
9
+ from kodit.domain.entities import TaskStatus
10
+ from kodit.domain.value_objects import TaskOperation, TrackableType
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from kodit.domain.protocols import ReportingModule
13
14
 
14
15
 
15
- class OperationType(StrEnum):
16
- """Operation type."""
17
-
18
- ROOT = "kodit.root"
19
- CREATE_INDEX = "kodit.index.create"
20
- RUN_INDEX = "kodit.index.run"
16
+ class ProgressTracker:
17
+ """Progress tracker.
21
18
 
19
+ Provides a reactive wrapper around TaskStatus domain entities that automatically
20
+ propagates state changes to the database and reporting modules. This pattern was
21
+ chosen over a traditional service-repository approach because:
22
+ - State changes must trigger immediate side effects (database writes, notifications)
23
+ - Multiple consumers need real-time updates without polling
24
+ - The wrapper pattern allows transparent interception of all state mutations
22
25
 
23
- class ProgressTracker:
24
- """Progress tracker."""
26
+ The tracker monitors all modifications to the underlying TaskStatus and ensures
27
+ consistency across all downstream systems.
28
+ """
25
29
 
26
- def __init__(self, name: str, parent: "ProgressTracker | None" = None) -> None:
30
+ def __init__(
31
+ self,
32
+ task_status: TaskStatus,
33
+ ) -> None:
27
34
  """Initialize the progress tracker."""
28
- self._parent: ProgressTracker | None = parent
29
- self._children: list[ProgressTracker] = []
35
+ self.task_status = task_status
30
36
  self._log = structlog.get_logger(__name__)
31
37
  self._subscribers: list[ReportingModule] = []
32
- self._snapshot: Progress = Progress(name=name, state=ReportingState.IN_PROGRESS)
33
-
34
- def __enter__(self) -> "ProgressTracker":
35
- """Enter the operation."""
36
- self._notify_subscribers()
37
- return self
38
38
 
39
- def __exit__(
40
- self,
41
- exc_type: type[BaseException] | None,
42
- exc_value: BaseException | None,
43
- traceback: TracebackType | None,
44
- ) -> None:
45
- """Exit the operation."""
46
- if exc_value:
47
- self._snapshot = self._snapshot.with_error(exc_value)
48
- self._snapshot = self._snapshot.with_state(
49
- ReportingState.FAILED, str(exc_value)
39
+ @staticmethod
40
+ def create(
41
+ operation: TaskOperation,
42
+ parent: "TaskStatus | None" = None,
43
+ trackable_type: TrackableType | None = None,
44
+ trackable_id: int | None = None,
45
+ ) -> "ProgressTracker":
46
+ """Create a progress tracker."""
47
+ return ProgressTracker(
48
+ TaskStatus.create(
49
+ operation=operation,
50
+ trackable_type=trackable_type,
51
+ trackable_id=trackable_id,
52
+ parent=parent,
50
53
  )
54
+ )
51
55
 
52
- if self._snapshot.state == ReportingState.IN_PROGRESS:
53
- self._snapshot = self._snapshot.with_progress(100)
54
- self._snapshot = self._snapshot.with_state(ReportingState.COMPLETED)
55
- self._notify_subscribers()
56
-
57
- def create_child(self, name: str) -> "ProgressTracker":
56
+ @asynccontextmanager
57
+ async def create_child(
58
+ self,
59
+ operation: TaskOperation,
60
+ trackable_type: TrackableType | None = None,
61
+ trackable_id: int | None = None,
62
+ ) -> AsyncGenerator["ProgressTracker", None]:
58
63
  """Create a child step."""
59
- s = ProgressTracker(name, self)
60
- self._children.append(s)
61
- for subscriber in self._subscribers:
62
- s.subscribe(subscriber)
63
- return s
64
-
65
- def skip(self, reason: str | None = None) -> None:
64
+ c = ProgressTracker.create(
65
+ operation=operation,
66
+ parent=self.task_status,
67
+ trackable_type=trackable_type or self.task_status.trackable_type,
68
+ trackable_id=trackable_id or self.task_status.trackable_id,
69
+ )
70
+ try:
71
+ for subscriber in self._subscribers:
72
+ c.subscribe(subscriber)
73
+
74
+ await c.notify_subscribers()
75
+ yield c
76
+ except Exception as e: # noqa: BLE001
77
+ c.task_status.fail(str(e))
78
+ finally:
79
+ c.task_status.complete()
80
+ await c.notify_subscribers()
81
+
82
+ async def skip(self, reason: str) -> None:
66
83
  """Skip the step."""
67
- self._snapshot = self._snapshot.with_state(ReportingState.SKIPPED, reason or "")
84
+ self.task_status.skip(reason)
85
+ await self.notify_subscribers()
68
86
 
69
87
  def subscribe(self, subscriber: "ReportingModule") -> None:
70
88
  """Subscribe to the step."""
71
89
  self._subscribers.append(subscriber)
72
90
 
73
- def set_total(self, total: int) -> None:
91
+ async def set_total(self, total: int) -> None:
74
92
  """Set the total for the step."""
75
- self._snapshot = self._snapshot.with_total(total)
76
- self._notify_subscribers()
93
+ self.task_status.set_total(total)
94
+ await self.notify_subscribers()
77
95
 
78
- def set_current(self, current: int) -> None:
96
+ async def set_current(self, current: int, message: str | None = None) -> None:
79
97
  """Progress the step."""
80
- self._snapshot = self._snapshot.with_progress(current)
81
- self._notify_subscribers()
98
+ self.task_status.set_current(current, message)
99
+ await self.notify_subscribers()
82
100
 
83
- def _notify_subscribers(self) -> None:
84
- """Notify the subscribers."""
101
+ async def notify_subscribers(self) -> None:
102
+ """Notify the subscribers only if progress has changed."""
85
103
  for subscriber in self._subscribers:
86
- subscriber.on_change(self._snapshot)
104
+ await subscriber.on_change(self.task_status)