nedo-vision-worker 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. nedo_vision_worker/__init__.py +1 -1
  2. nedo_vision_worker/cli.py +197 -168
  3. nedo_vision_worker/database/DatabaseManager.py +3 -3
  4. nedo_vision_worker/doctor.py +1066 -386
  5. nedo_vision_worker/models/ai_model.py +35 -2
  6. nedo_vision_worker/protos/AIModelService_pb2.py +12 -10
  7. nedo_vision_worker/protos/AIModelService_pb2_grpc.py +1 -1
  8. nedo_vision_worker/protos/DatasetSourceService_pb2.py +2 -2
  9. nedo_vision_worker/protos/DatasetSourceService_pb2_grpc.py +1 -1
  10. nedo_vision_worker/protos/HumanDetectionService_pb2.py +2 -2
  11. nedo_vision_worker/protos/HumanDetectionService_pb2_grpc.py +1 -1
  12. nedo_vision_worker/protos/PPEDetectionService_pb2.py +2 -2
  13. nedo_vision_worker/protos/PPEDetectionService_pb2_grpc.py +1 -1
  14. nedo_vision_worker/protos/VisionWorkerService_pb2.py +2 -2
  15. nedo_vision_worker/protos/VisionWorkerService_pb2_grpc.py +1 -1
  16. nedo_vision_worker/protos/WorkerSourcePipelineService_pb2.py +2 -2
  17. nedo_vision_worker/protos/WorkerSourcePipelineService_pb2_grpc.py +1 -1
  18. nedo_vision_worker/protos/WorkerSourceService_pb2.py +2 -2
  19. nedo_vision_worker/protos/WorkerSourceService_pb2_grpc.py +1 -1
  20. nedo_vision_worker/services/AIModelClient.py +184 -160
  21. nedo_vision_worker/services/DirectDeviceToRTMPStreamer.py +534 -0
  22. nedo_vision_worker/services/GrpcClientBase.py +142 -108
  23. nedo_vision_worker/services/PPEDetectionClient.py +0 -7
  24. nedo_vision_worker/services/RestrictedAreaClient.py +0 -5
  25. nedo_vision_worker/services/SharedDirectDeviceClient.py +278 -0
  26. nedo_vision_worker/services/SharedVideoStreamServer.py +315 -0
  27. nedo_vision_worker/services/SystemWideDeviceCoordinator.py +236 -0
  28. nedo_vision_worker/services/VideoSharingDaemon.py +832 -0
  29. nedo_vision_worker/services/VideoStreamClient.py +30 -13
  30. nedo_vision_worker/services/WorkerSourceClient.py +1 -1
  31. nedo_vision_worker/services/WorkerSourcePipelineClient.py +28 -6
  32. nedo_vision_worker/services/WorkerSourceUpdater.py +30 -3
  33. nedo_vision_worker/util/VideoProbeUtil.py +222 -15
  34. nedo_vision_worker/worker/DataSyncWorker.py +1 -0
  35. nedo_vision_worker/worker/PipelineImageWorker.py +1 -1
  36. nedo_vision_worker/worker/VideoStreamWorker.py +27 -3
  37. nedo_vision_worker/worker/WorkerManager.py +2 -29
  38. nedo_vision_worker/worker_service.py +24 -11
  39. {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/METADATA +1 -3
  40. {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/RECORD +43 -38
  41. {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/WHEEL +0 -0
  42. {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/entry_points.txt +0 -0
  43. {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/top_level.txt +0 -0
@@ -4,21 +4,19 @@ import threading
4
4
  import time
5
5
  from pathlib import Path
6
6
  from enum import Enum
7
- from typing import Dict, Optional
7
+ from typing import Dict, Optional, Set, List
8
+ from contextlib import contextmanager
9
+ from dataclasses import dataclass
8
10
 
9
11
  from ..models.ai_model import AIModelEntity
10
12
  from ..repositories.AIModelRepository import AIModelRepository
11
13
  from .GrpcClientBase import GrpcClientBase
12
14
  from ..protos.AIModelService_pb2_grpc import AIModelGRPCServiceStub
13
- from ..protos.AIModelService_pb2 import (
14
- GetAIModelListRequest,
15
- DownloadAIModelRequest
16
- )
15
+ from ..protos.AIModelService_pb2 import GetAIModelListRequest, DownloadAIModelRequest
17
16
  from ..database.DatabaseManager import _get_storage_paths
18
17
 
19
18
 
20
19
  class DownloadState(Enum):
21
- """Enum for tracking download states."""
22
20
  PENDING = "pending"
23
21
  DOWNLOADING = "downloading"
24
22
  COMPLETED = "completed"
@@ -26,24 +24,25 @@ class DownloadState(Enum):
26
24
  CANCELLED = "cancelled"
27
25
 
28
26
 
27
+ @dataclass
29
28
  class DownloadInfo:
30
- """Class to track download information."""
31
- def __init__(self, model_id: str, model_name: str, version: str):
32
- self.model_id = model_id
33
- self.model_name = model_name
34
- self.version = version
35
- self.state = DownloadState.PENDING
36
- self.start_time = None
37
- self.end_time = None
38
- self.error_message = None
39
- self.thread = None
40
- self.stop_event = threading.Event()
29
+ model_id: str
30
+ model_name: str
31
+ version: str
32
+ state: DownloadState = DownloadState.PENDING
33
+ start_time: Optional[float] = None
34
+ end_time: Optional[float] = None
35
+ error_message: Optional[str] = None
36
+ thread: Optional[threading.Thread] = None
37
+ stop_event: threading.Event = None
38
+
39
+ def __post_init__(self):
40
+ if self.stop_event is None:
41
+ self.stop_event = threading.Event()
41
42
 
42
43
 
43
44
  class AIModelClient(GrpcClientBase):
44
- """Client for interacting with AI models via gRPC with improved download tracking."""
45
-
46
- def __init__(self, token, server_host: str, server_port: int = 50051):
45
+ def __init__(self, token: str, server_host: str, server_port: int = 50051):
47
46
  super().__init__(server_host, server_port)
48
47
  storage_paths = _get_storage_paths()
49
48
  self.models_path = storage_paths["models"]
@@ -51,62 +50,59 @@ class AIModelClient(GrpcClientBase):
51
50
  self.repository = AIModelRepository()
52
51
  self.token = token
53
52
 
54
- # Download tracking
55
53
  self.download_tracker: Dict[str, DownloadInfo] = {}
56
- self.download_lock = threading.Lock()
54
+ self.download_lock = threading.RLock()
57
55
 
58
- try:
59
- self.connect(AIModelGRPCServiceStub)
60
- except Exception as e:
61
- logging.error(f"Failed to connect to gRPC server: {e}")
56
+ if not self.connect(AIModelGRPCServiceStub):
57
+ logging.error("Failed to connect to gRPC server")
62
58
  self.stub = None
63
59
 
64
- def _get_model_path(self, file: str) -> Path:
65
- """Get the path to a local AI model file."""
66
- return self.models_path / os.path.basename(file)
60
+ def _get_model_path(self, filename: str) -> Path:
61
+ return self.models_path / os.path.basename(filename)
67
62
 
68
- def _is_model_file_exists(self, file_path: str) -> bool:
69
- """Check if the model file actually exists on disk."""
63
+ def _model_file_exists(self, file_path: str) -> bool:
70
64
  if not file_path:
71
65
  return False
72
66
  model_path = self._get_model_path(file_path)
73
67
  return model_path.exists() and model_path.stat().st_size > 0
74
68
 
75
- def _get_download_info(self, model_id: str) -> Optional[DownloadInfo]:
76
- """Get download info for a model."""
69
+ @contextmanager
70
+ def _download_lock_context(self):
77
71
  with self.download_lock:
72
+ yield
73
+
74
+ def _get_download_info(self, model_id: str) -> Optional[DownloadInfo]:
75
+ with self._download_lock_context():
78
76
  return self.download_tracker.get(model_id)
79
77
 
80
78
  def _set_download_info(self, model_id: str, download_info: DownloadInfo):
81
- """Set download info for a model."""
82
- with self.download_lock:
79
+ with self._download_lock_context():
83
80
  self.download_tracker[model_id] = download_info
84
81
 
85
82
  def _remove_download_info(self, model_id: str):
86
- """Remove download info for a model."""
87
- with self.download_lock:
83
+ with self._download_lock_context():
88
84
  self.download_tracker.pop(model_id, None)
89
85
 
90
86
  def _is_downloading(self, model_id: str) -> bool:
91
- """Check if a model is currently being downloaded."""
92
87
  download_info = self._get_download_info(model_id)
93
- if not download_info:
94
- return False
95
- return download_info.state in [DownloadState.PENDING, DownloadState.DOWNLOADING]
88
+ return download_info and download_info.state in {DownloadState.PENDING, DownloadState.DOWNLOADING}
96
89
 
97
- def _cancel_download(self, model_id: str):
98
- """Cancel an ongoing download."""
90
+ def _cancel_download(self, model_id: str) -> bool:
99
91
  download_info = self._get_download_info(model_id)
100
- if download_info and download_info.state in [DownloadState.PENDING, DownloadState.DOWNLOADING]:
101
- download_info.state = DownloadState.CANCELLED
102
- download_info.stop_event.set()
103
- if download_info.thread and download_info.thread.is_alive():
104
- download_info.thread.join(timeout=5)
105
- self._update_model_download_status(model_id, "cancelled", "Download cancelled")
106
- logging.info(f"🛑 Cancelled download for model {download_info.model_name}")
107
-
108
- def _update_model_download_status(self, model_id: str, status: str, error_message: str = None):
109
- """Update the download status in the database."""
92
+ if not download_info or download_info.state not in {DownloadState.PENDING, DownloadState.DOWNLOADING}:
93
+ return False
94
+
95
+ download_info.state = DownloadState.CANCELLED
96
+ download_info.stop_event.set()
97
+
98
+ if download_info.thread and download_info.thread.is_alive():
99
+ download_info.thread.join(timeout=5)
100
+
101
+ self._update_model_status(model_id, "cancelled", "Download cancelled")
102
+ logging.info(f"🛑 Cancelled download for {download_info.model_name}")
103
+ return True
104
+
105
+ def _update_model_status(self, model_id: str, status: str, error_message: str = None):
110
106
  try:
111
107
  from datetime import datetime
112
108
  model = self.repository.get_model_by_id(model_id)
@@ -117,53 +113,44 @@ class AIModelClient(GrpcClientBase):
117
113
  model.download_error = error_message
118
114
  self.repository.session.commit()
119
115
  except Exception as e:
120
- logging.error(f"❌ Error updating model download status: {e}")
121
- self.repository.session.rollback()
116
+ logging.error(f"❌ Error updating model status: {e}")
117
+ if hasattr(self.repository, 'session'):
118
+ self.repository.session.rollback()
122
119
 
123
120
  def sync_ai_models(self, worker_id: str) -> dict:
124
- """Fetch and sync AI model list from gRPC service using token authentication."""
125
121
  if not self.stub:
126
- return {"success": False, "message": "gRPC connection is not established."}
122
+ return {"success": False, "message": "gRPC connection not established"}
127
123
 
128
124
  try:
129
- # Get model list from server
130
125
  response = self._fetch_model_list(worker_id)
131
126
  if not response or not response.success:
132
- return {"success": False, "message": response.message if response else "Unknown error"}
127
+ return {"success": False, "message": getattr(response, 'message', 'Unknown error')}
133
128
 
134
- # Process models
135
129
  self._process_server_models(response.data)
136
-
137
130
  return {"success": True, "message": response.message, "data": response.data}
138
131
 
139
132
  except Exception as e:
140
- logging.error(f"Error fetching AI model list: {e}")
141
- return {"success": False, "message": f"Error occurred: {e}"}
133
+ logging.error(f"Error syncing models: {e}")
134
+ return {"success": False, "message": f"Error: {e}"}
142
135
 
143
136
  def _fetch_model_list(self, worker_id: str):
144
- """Fetch model list from server using token authentication."""
145
137
  request = GetAIModelListRequest(worker_id=worker_id, token=self.token)
146
138
  return self.handle_rpc(self.stub.GetAIModelList, request)
147
139
 
148
140
  def _process_server_models(self, server_models):
149
- """Process server models, handling additions, updates, and deletions."""
150
141
  local_models = {model.id: model for model in self.repository.get_models()}
151
- server_model_ids = set()
142
+ server_model_ids = {model.id for model in server_models}
152
143
 
153
144
  new_models = []
154
145
  updated_models = []
155
146
 
156
- # Process each model from the server
157
147
  for model in server_models:
158
- server_model_ids.add(model.id)
159
148
  existing_model = local_models.get(model.id)
160
-
161
149
  if existing_model:
162
150
  self._handle_existing_model(model, existing_model, updated_models)
163
151
  else:
164
152
  self._handle_new_model(model, new_models)
165
153
 
166
- # Handle models that no longer exist on the server
167
154
  models_to_delete = [
168
155
  model for model_id, model in local_models.items()
169
156
  if model_id not in server_model_ids
@@ -171,123 +158,151 @@ class AIModelClient(GrpcClientBase):
171
158
 
172
159
  self._save_changes(new_models, updated_models, models_to_delete)
173
160
 
174
- def _handle_existing_model(self, server_model, local_model, updated_models):
175
- """Handle model that exists locally but might need updates."""
176
- # Check if model file actually exists
177
- if not self._is_model_file_exists(local_model.file):
161
+ def _handle_existing_model(self, server_model, local_model, updated_models: List):
162
+ if not self._model_file_exists(local_model.file):
178
163
  logging.warning(f"⚠️ Model file missing for {local_model.name}. Re-downloading...")
179
- self._schedule_model_download(server_model)
164
+ self._schedule_download(server_model)
180
165
  return
181
166
 
182
- # Check if version or type changed
183
- if server_model.version == local_model.version and server_model.ai_model_type_code == local_model.type:
167
+ needs_update, changes = self._check_model_changes(server_model, local_model)
168
+
169
+ version_changed = server_model.version != local_model.version
170
+ if not needs_update:
184
171
  return
185
172
 
186
- logging.info(f"🔄 Model update detected: {server_model.name} "
187
- f"(Version {local_model.version} -> {server_model.version}). Updating...")
188
-
189
- # Cancel any ongoing download for this model
190
- self._cancel_download(server_model.id)
191
-
192
- # Delete old model file
193
- self.delete_local_model(local_model.file)
194
-
195
- # Schedule new download
196
- self._schedule_model_download(server_model)
173
+ change_desc = ", ".join(changes)
174
+ logging.info(f"🔄 Model update: {server_model.name} ({change_desc})")
197
175
 
198
- # Update properties regardless
199
- local_model.name = server_model.name
200
- local_model.type = server_model.ai_model_type_code
201
- local_model.version = server_model.version
176
+ if version_changed:
177
+ self._cancel_download(server_model.id)
178
+ self.delete_local_model(local_model.file)
179
+ self._schedule_download(server_model)
180
+
181
+ self._update_model_properties(local_model, server_model)
202
182
  updated_models.append(local_model)
203
183
 
204
- def _handle_new_model(self, server_model, new_models):
205
- """Handle model that doesn't exist locally."""
206
- # Check if already downloading this model
184
+ def _handle_new_model(self, server_model, new_models: List):
207
185
  if self._is_downloading(server_model.id):
208
- logging.info(f"⏳ Model {server_model.name} is already being downloaded. Skipping...")
186
+ logging.info(f"⏳ Model {server_model.name} already downloading")
209
187
  return
210
188
 
211
- new_model = AIModelEntity(
212
- id=server_model.id,
213
- name=server_model.name,
189
+ new_model = self._create_model_entity(server_model)
190
+ new_models.append(new_model)
191
+
192
+ logging.info(f"⬇️ New model: {server_model.name}")
193
+ self._schedule_download(server_model)
194
+
195
+ def _check_model_changes(self, server_model, local_model) -> tuple[bool, List[str]]:
196
+ """Check if model needs update and return list of changes"""
197
+ changes = []
198
+
199
+ if server_model.version != local_model.version:
200
+ changes.append(f"version: {local_model.version} -> {server_model.version}")
201
+
202
+ if server_model.ai_model_type_code != local_model.type:
203
+ changes.append(f"type: {local_model.type} -> {server_model.ai_model_type_code}")
204
+
205
+ server_classes = set(server_model.classes)
206
+ local_classes = set(local_model.get_classes() or [])
207
+ if server_classes != local_classes:
208
+ changes.append("classes updated")
209
+
210
+ server_ppe_groups = {
211
+ group.name: {"compliance": group.compliance, "violation": group.violation}
212
+ for group in (server_model.ppe_class_groups or [])
213
+ }
214
+ local_ppe_groups = local_model.get_ppe_class_groups() or {}
215
+ if server_ppe_groups != local_ppe_groups:
216
+ changes.append("PPE class groups updated")
217
+
218
+ if server_model.main_class != local_model.get_main_class():
219
+ changes.append(f"main class: {local_model.get_main_class()} -> {server_model.main_class}")
220
+
221
+ return bool(changes), changes
222
+
223
+ def _create_model_entity(self, server_model) -> AIModelEntity:
224
+ model = AIModelEntity(
225
+ id=server_model.id,
226
+ name=server_model.name,
214
227
  type=server_model.ai_model_type_code,
215
- file=os.path.basename(server_model.file_path),
228
+ file=os.path.basename(server_model.file_path),
216
229
  version=server_model.version
217
230
  )
218
- new_models.append(new_model)
219
-
220
- logging.info(f"⬇️ New model detected: {server_model.name}. Scheduling download...")
221
- self._schedule_model_download(server_model)
231
+ self._update_model_properties(model, server_model)
232
+ return model
222
233
 
223
- def _schedule_model_download(self, model):
224
- """Schedule a model download in background thread."""
225
- # Cancel any existing download for this model
234
+ def _update_model_properties(self, local_model: AIModelEntity, server_model):
235
+ local_model.name = server_model.name
236
+ local_model.type = server_model.ai_model_type_code
237
+ local_model.version = server_model.version
238
+ local_model.set_classes(list(server_model.classes))
239
+ local_model.set_ppe_class_groups({
240
+ group.name: {
241
+ "compliance": group.compliance,
242
+ "violation": group.violation
243
+ }
244
+ for group in (server_model.ppe_class_groups or [])
245
+ })
246
+ local_model.set_main_class(server_model.main_class)
247
+
248
+ def _schedule_download(self, model):
226
249
  self._cancel_download(model.id)
227
250
 
228
- # Create new download info
229
251
  download_info = DownloadInfo(
230
252
  model_id=model.id,
231
253
  model_name=model.name,
232
- version=model.version
254
+ version=model.version,
255
+ start_time=time.time()
233
256
  )
234
- download_info.state = DownloadState.PENDING
235
- download_info.start_time = time.time()
236
257
 
237
- # Update database status
238
- self._update_model_download_status(model.id, "pending", None)
258
+ self._update_model_status(model.id, "pending")
239
259
 
240
- # Start download in background thread
241
260
  download_info.thread = threading.Thread(
242
- target=self._download_model_worker,
261
+ target=self._download_worker,
243
262
  args=(model, download_info),
244
263
  daemon=True,
245
- name=f"ModelDownload-{model.id}"
264
+ name=f"ModelDownload-{model.id[:8]}"
246
265
  )
247
266
  download_info.thread.start()
248
-
249
267
  self._set_download_info(model.id, download_info)
250
268
 
251
- def _download_model_worker(self, model, download_info):
252
- """Background worker for downloading a model."""
269
+ def _download_worker(self, model, download_info: DownloadInfo):
253
270
  try:
254
271
  download_info.state = DownloadState.DOWNLOADING
255
- self._update_model_download_status(model.id, "downloading", None)
256
- logging.info(f"📥 Starting download for AI model '{model.name}'...")
272
+ self._update_model_status(model.id, "downloading")
273
+ logging.info(f"📥 Downloading {model.name}...")
274
+
275
+ success = self._download_model_file(model, download_info)
257
276
 
258
- if self.download_model(model, download_info):
277
+ if success:
259
278
  download_info.state = DownloadState.COMPLETED
260
279
  download_info.end_time = time.time()
261
280
  duration = download_info.end_time - download_info.start_time
262
- self._update_model_download_status(model.id, "completed", None)
263
- logging.info(f"✅ AI Model '{model.name}' downloaded successfully in {duration:.2f}s")
281
+ self._update_model_status(model.id, "completed")
282
+ logging.info(f"✅ Downloaded {model.name} in {duration:.1f}s")
264
283
  else:
265
- download_info.state = DownloadState.FAILED
266
- download_info.error_message = "Download failed"
267
- self._update_model_download_status(model.id, "failed", "Download failed")
268
- logging.error(f"❌ Failed to download AI Model '{model.name}'")
284
+ self._handle_download_failure(download_info, "Download failed")
269
285
 
270
286
  except Exception as e:
271
- download_info.state = DownloadState.FAILED
272
- download_info.error_message = str(e)
273
- self._update_model_download_status(model.id, "failed", str(e))
274
- logging.error(f"❌ Error downloading AI Model '{model.name}': {e}")
287
+ self._handle_download_failure(download_info, str(e))
275
288
  finally:
276
- # Clean up download info after a delay to allow status checking
277
289
  threading.Timer(300, lambda: self._remove_download_info(model.id)).start()
278
290
 
279
- def _save_changes(self, new_models, updated_models, models_to_delete):
280
- """Save all changes to database in a single transaction."""
291
+ def _handle_download_failure(self, download_info: DownloadInfo, error: str):
292
+ download_info.state = DownloadState.FAILED
293
+ download_info.error_message = error
294
+ self._update_model_status(download_info.model_id, "failed", error)
295
+ logging.error(f"❌ Failed to download {download_info.model_name}: {error}")
296
+
297
+ def _save_changes(self, new_models: List, updated_models: List, models_to_delete: List):
281
298
  try:
282
299
  if new_models:
283
300
  self.repository.session.bulk_save_objects(new_models)
284
-
285
301
  if updated_models:
286
302
  self.repository.session.bulk_save_objects(updated_models)
287
303
 
288
304
  for model in models_to_delete:
289
- logging.info(f"🗑️ Model removed from server: {model.name}. Deleting local copy...")
290
- # Cancel any ongoing download
305
+ logging.info(f"🗑️ Removing {model.name}")
291
306
  self._cancel_download(model.id)
292
307
  self.repository.session.delete(model)
293
308
  self.delete_local_model(model.file)
@@ -295,50 +310,43 @@ class AIModelClient(GrpcClientBase):
295
310
  self.repository.session.commit()
296
311
  except Exception as e:
297
312
  self.repository.session.rollback()
298
- logging.error(f"Error saving model changes: {e}")
313
+ logging.error(f"Error saving changes: {e}")
299
314
  raise
300
315
 
301
- def download_model(self, model, download_info=None) -> bool:
302
- """Download the AI model and save it to the models directory."""
316
+ def _download_model_file(self, model, download_info: DownloadInfo) -> bool:
303
317
  if not self.stub:
304
- logging.error("gRPC connection is not established.")
305
318
  return False
306
319
 
307
320
  try:
308
321
  request = DownloadAIModelRequest(ai_model_id=model.id, token=self.token)
309
322
  file_path = self._get_model_path(model.file_path)
310
323
 
311
- # Check if download was cancelled
312
- if download_info and download_info.stop_event.is_set():
313
- logging.info(f"🛑 Download cancelled for model '{model.name}'")
314
- return False
315
-
316
324
  with open(file_path, "wb") as f:
317
325
  for chunk in self.stub.DownloadAIModel(request):
318
- # Check if download was cancelled during streaming
319
- if download_info and download_info.stop_event.is_set():
320
- logging.info(f"🛑 Download cancelled during streaming for model '{model.name}'")
326
+ if download_info.stop_event.is_set():
327
+ logging.info(f"🛑 Download cancelled: {model.name}")
321
328
  return False
322
329
  f.write(chunk.file_chunk)
323
330
 
324
331
  return True
325
332
 
326
333
  except Exception as e:
327
- logging.error(f"❌ Error downloading AI Model '{model.name}': {e}")
334
+ logging.error(f"❌ Download error for {model.name}: {e}")
328
335
  return False
329
336
 
330
- def delete_local_model(self, file: str) -> None:
331
- """Delete a local AI model file."""
332
- file_path = self._get_model_path(file)
337
+ def delete_local_model(self, filename: str) -> bool:
333
338
  try:
339
+ file_path = self._get_model_path(filename)
334
340
  if file_path.exists():
335
341
  file_path.unlink()
336
- logging.info(f"🗑️ Model file deleted: {file}")
342
+ logging.info(f"🗑️ Deleted {filename}")
343
+ return True
344
+ return False
337
345
  except Exception as e:
338
- logging.error(f"❌ Error deleting model file: {e}")
346
+ logging.error(f"❌ Error deleting {filename}: {e}")
347
+ return False
339
348
 
340
349
  def get_download_status(self, model_id: str) -> Optional[Dict]:
341
- """Get the download status for a specific model."""
342
350
  download_info = self._get_download_info(model_id)
343
351
  if not download_info:
344
352
  return None
@@ -354,9 +362,25 @@ class AIModelClient(GrpcClientBase):
354
362
  }
355
363
 
356
364
  def get_all_download_status(self) -> Dict[str, Dict]:
357
- """Get download status for all models."""
358
- with self.download_lock:
365
+ with self._download_lock_context():
359
366
  return {
360
367
  model_id: self.get_download_status(model_id)
361
368
  for model_id in self.download_tracker.keys()
362
- }
369
+ }
370
+
371
+ def cancel_all_downloads(self) -> int:
372
+ cancelled_count = 0
373
+ with self._download_lock_context():
374
+ for model_id in list(self.download_tracker.keys()):
375
+ if self._cancel_download(model_id):
376
+ cancelled_count += 1
377
+ return cancelled_count
378
+
379
+ def cleanup_downloads(self):
380
+ with self._download_lock_context():
381
+ completed_ids = [
382
+ model_id for model_id, info in self.download_tracker.items()
383
+ if info.state in {DownloadState.COMPLETED, DownloadState.FAILED, DownloadState.CANCELLED}
384
+ ]
385
+ for model_id in completed_ids:
386
+ self._remove_download_info(model_id)