nedo-vision-worker 1.1.2-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nedo_vision_worker/__init__.py +1 -1
- nedo_vision_worker/cli.py +197 -168
- nedo_vision_worker/database/DatabaseManager.py +3 -3
- nedo_vision_worker/doctor.py +1066 -386
- nedo_vision_worker/models/ai_model.py +35 -2
- nedo_vision_worker/protos/AIModelService_pb2.py +12 -10
- nedo_vision_worker/protos/AIModelService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/DatasetSourceService_pb2.py +2 -2
- nedo_vision_worker/protos/DatasetSourceService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/HumanDetectionService_pb2.py +2 -2
- nedo_vision_worker/protos/HumanDetectionService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/PPEDetectionService_pb2.py +2 -2
- nedo_vision_worker/protos/PPEDetectionService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/VisionWorkerService_pb2.py +2 -2
- nedo_vision_worker/protos/VisionWorkerService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/WorkerSourcePipelineService_pb2.py +2 -2
- nedo_vision_worker/protos/WorkerSourcePipelineService_pb2_grpc.py +1 -1
- nedo_vision_worker/protos/WorkerSourceService_pb2.py +2 -2
- nedo_vision_worker/protos/WorkerSourceService_pb2_grpc.py +1 -1
- nedo_vision_worker/services/AIModelClient.py +184 -160
- nedo_vision_worker/services/DirectDeviceToRTMPStreamer.py +534 -0
- nedo_vision_worker/services/GrpcClientBase.py +142 -108
- nedo_vision_worker/services/PPEDetectionClient.py +0 -7
- nedo_vision_worker/services/RestrictedAreaClient.py +0 -5
- nedo_vision_worker/services/SharedDirectDeviceClient.py +278 -0
- nedo_vision_worker/services/SharedVideoStreamServer.py +315 -0
- nedo_vision_worker/services/SystemWideDeviceCoordinator.py +236 -0
- nedo_vision_worker/services/VideoSharingDaemon.py +832 -0
- nedo_vision_worker/services/VideoStreamClient.py +30 -13
- nedo_vision_worker/services/WorkerSourceClient.py +1 -1
- nedo_vision_worker/services/WorkerSourcePipelineClient.py +28 -6
- nedo_vision_worker/services/WorkerSourceUpdater.py +30 -3
- nedo_vision_worker/util/VideoProbeUtil.py +222 -15
- nedo_vision_worker/worker/DataSyncWorker.py +1 -0
- nedo_vision_worker/worker/PipelineImageWorker.py +1 -1
- nedo_vision_worker/worker/VideoStreamWorker.py +27 -3
- nedo_vision_worker/worker/WorkerManager.py +2 -29
- nedo_vision_worker/worker_service.py +24 -11
- {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/METADATA +1 -3
- {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/RECORD +43 -38
- {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/WHEEL +0 -0
- {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/entry_points.txt +0 -0
- {nedo_vision_worker-1.1.2.dist-info → nedo_vision_worker-1.2.0.dist-info}/top_level.txt +0 -0
nedo_vision_worker/services/AIModelClient.py

@@ -4,21 +4,19 @@ import threading
 import time
 from pathlib import Path
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, Set, List
+from contextlib import contextmanager
+from dataclasses import dataclass

 from ..models.ai_model import AIModelEntity
 from ..repositories.AIModelRepository import AIModelRepository
 from .GrpcClientBase import GrpcClientBase
 from ..protos.AIModelService_pb2_grpc import AIModelGRPCServiceStub
-from ..protos.AIModelService_pb2 import (
-    GetAIModelListRequest,
-    DownloadAIModelRequest
-)
+from ..protos.AIModelService_pb2 import GetAIModelListRequest, DownloadAIModelRequest
 from ..database.DatabaseManager import _get_storage_paths


 class DownloadState(Enum):
-    """Enum for tracking download states."""
     PENDING = "pending"
     DOWNLOADING = "downloading"
     COMPLETED = "completed"
@@ -26,24 +24,25 @@ class DownloadState(Enum):
     CANCELLED = "cancelled"


+@dataclass
 class DownloadInfo:
-
-
-
-
-
-
-
-
-
-
-
+    model_id: str
+    model_name: str
+    version: str
+    state: DownloadState = DownloadState.PENDING
+    start_time: Optional[float] = None
+    end_time: Optional[float] = None
+    error_message: Optional[str] = None
+    thread: Optional[threading.Thread] = None
+    stop_event: threading.Event = None
+
+    def __post_init__(self):
+        if self.stop_event is None:
+            self.stop_event = threading.Event()


 class AIModelClient(GrpcClientBase):
-
-
-    def __init__(self, token, server_host: str, server_port: int = 50051):
+    def __init__(self, token: str, server_host: str, server_port: int = 50051):
         super().__init__(server_host, server_port)
         storage_paths = _get_storage_paths()
         self.models_path = storage_paths["models"]
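
A note on the new DownloadInfo dataclass: stop_event defaults to None and is replaced with a fresh threading.Event in __post_init__. A class-level default of threading.Event() would be evaluated once and shared by every instance (dataclasses only reject unhashable defaults such as lists, so the bug would be silent). A minimal standalone sketch of two safe per-instance patterns, using a hypothetical Task class rather than the package's code:

import threading
from dataclasses import dataclass, field

@dataclass
class TaskA:
    name: str
    stop_event: threading.Event = None  # sentinel, replaced per instance

    def __post_init__(self):
        # Give each instance its own Event, as DownloadInfo does above.
        if self.stop_event is None:
            self.stop_event = threading.Event()

@dataclass
class TaskB:
    name: str
    # Equivalent alternative: default_factory runs once per instance.
    stop_event: threading.Event = field(default_factory=threading.Event)

a1, a2 = TaskA("a1"), TaskA("a2")
assert a1.stop_event is not a2.stop_event
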
@@ -51,62 +50,59 @@ class AIModelClient(GrpcClientBase):
         self.repository = AIModelRepository()
         self.token = token

-        # Download tracking
         self.download_tracker: Dict[str, DownloadInfo] = {}
-        self.download_lock = threading.Lock()
+        self.download_lock = threading.RLock()

-
-
-        except Exception as e:
-            logging.error(f"Failed to connect to gRPC server: {e}")
+        if not self.connect(AIModelGRPCServiceStub):
+            logging.error("Failed to connect to gRPC server")
             self.stub = None

-    def _get_model_path(self, file):
-
-        return self.models_path / os.path.basename(file)
+    def _get_model_path(self, filename: str) -> Path:
+        return self.models_path / os.path.basename(filename)

-    def _is_model_file_exists(self, file_path):
-        """Check if the model file actually exists on disk."""
+    def _model_file_exists(self, file_path: str) -> bool:
         if not file_path:
             return False
         model_path = self._get_model_path(file_path)
         return model_path.exists() and model_path.stat().st_size > 0

-
-
+    @contextmanager
+    def _download_lock_context(self):
         with self.download_lock:
+            yield
+
+    def _get_download_info(self, model_id: str) -> Optional[DownloadInfo]:
+        with self._download_lock_context():
             return self.download_tracker.get(model_id)

     def _set_download_info(self, model_id: str, download_info: DownloadInfo):
-
-        with self.download_lock:
+        with self._download_lock_context():
             self.download_tracker[model_id] = download_info

     def _remove_download_info(self, model_id: str):
-
-        with self.download_lock:
+        with self._download_lock_context():
             self.download_tracker.pop(model_id, None)

     def _is_downloading(self, model_id: str) -> bool:
-        """Check if a model is currently being downloaded."""
         download_info = self._get_download_info(model_id)
-
-            return False
-        return download_info.state in [DownloadState.PENDING, DownloadState.DOWNLOADING]
+        return download_info and download_info.state in {DownloadState.PENDING, DownloadState.DOWNLOADING}

-    def _cancel_download(self, model_id: str):
-        """Cancel an ongoing download."""
+    def _cancel_download(self, model_id: str) -> bool:
         download_info = self._get_download_info(model_id)
-        if download_info
-
-
-
-
-
-
-
-
-        """
+        if not download_info or download_info.state not in {DownloadState.PENDING, DownloadState.DOWNLOADING}:
+            return False
+
+        download_info.state = DownloadState.CANCELLED
+        download_info.stop_event.set()
+
+        if download_info.thread and download_info.thread.is_alive():
+            download_info.thread.join(timeout=5)
+
+        self._update_model_status(model_id, "cancelled", "Download cancelled")
+        logging.info(f"🛑 Cancelled download for {download_info.model_name}")
+        return True
+
+    def _update_model_status(self, model_id: str, status: str, error_message: str = None):
         try:
             from datetime import datetime
             model = self.repository.get_model_by_id(model_id)
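
A note on the locking change: the tracker lock becomes a threading.RLock, and every access is funneled through the _download_lock_context contextmanager. The re-entrant lock matters because methods that hold the lock call other methods that acquire it again on the same thread (for example, cancel_all_downloads later in this diff iterates the tracker under the lock and calls _cancel_download, which locks again via _get_download_info). A minimal sketch of the pattern with hypothetical names:

import threading
from contextlib import contextmanager

class Tracker:
    def __init__(self):
        # RLock allows nested acquisition by the same thread;
        # a plain Lock would deadlock on the inner acquire below.
        self._lock = threading.RLock()
        self._items = {"a": 1, "b": 2}

    @contextmanager
    def _locked(self):
        with self._lock:
            yield

    def remove(self, key):
        with self._locked():      # inner acquire
            self._items.pop(key, None)

    def clear_all(self):
        with self._locked():      # outer acquire
            for key in list(self._items):
                self.remove(key)  # nested acquire: fine with RLock

t = Tracker()
t.clear_all()
assert not t._items
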
@@ -117,53 +113,44 @@ class AIModelClient(GrpcClientBase):
             model.download_error = error_message
             self.repository.session.commit()
         except Exception as e:
-            logging.error(f"❌ Error updating model
-            self.repository
+            logging.error(f"❌ Error updating model status: {e}")
+            if hasattr(self.repository, 'session'):
+                self.repository.session.rollback()

     def sync_ai_models(self, worker_id: str) -> dict:
-        """Fetch and sync AI model list from gRPC service using token authentication."""
         if not self.stub:
-            return {"success": False, "message": "gRPC connection
+            return {"success": False, "message": "gRPC connection not established"}

         try:
-            # Get model list from server
             response = self._fetch_model_list(worker_id)
             if not response or not response.success:
-                return {"success": False, "message": response
+                return {"success": False, "message": getattr(response, 'message', 'Unknown error')}

-            # Process models
             self._process_server_models(response.data)
-
             return {"success": True, "message": response.message, "data": response.data}

         except Exception as e:
-            logging.error(f"Error
-            return {"success": False, "message": f"Error
+            logging.error(f"Error syncing models: {e}")
+            return {"success": False, "message": f"Error: {e}"}

     def _fetch_model_list(self, worker_id: str):
-        """Fetch model list from server using token authentication."""
         request = GetAIModelListRequest(worker_id=worker_id, token=self.token)
         return self.handle_rpc(self.stub.GetAIModelList, request)

     def _process_server_models(self, server_models):
-        """Process server models, handling additions, updates, and deletions."""
         local_models = {model.id: model for model in self.repository.get_models()}
-        server_model_ids = set()
+        server_model_ids = {model.id for model in server_models}

         new_models = []
         updated_models = []

-        # Process each model from the server
         for model in server_models:
-            server_model_ids.add(model.id)
             existing_model = local_models.get(model.id)
-
             if existing_model:
                 self._handle_existing_model(model, existing_model, updated_models)
             else:
                 self._handle_new_model(model, new_models)

-        # Handle models that no longer exist on the server
         models_to_delete = [
             model for model_id, model in local_models.items()
             if model_id not in server_model_ids
@@ -171,123 +158,151 @@ class AIModelClient(GrpcClientBase):

         self._save_changes(new_models, updated_models, models_to_delete)

-    def _handle_existing_model(self, server_model, local_model, updated_models):
-
-        # Check if model file actually exists
-        if not self._is_model_file_exists(local_model.file):
+    def _handle_existing_model(self, server_model, local_model, updated_models: List):
+        if not self._model_file_exists(local_model.file):
             logging.warning(f"⚠️ Model file missing for {local_model.name}. Re-downloading...")
-            self._schedule_model_download(server_model)
+            self._schedule_download(server_model)
             return

-
-
+        needs_update, changes = self._check_model_changes(server_model, local_model)
+
+        version_changed = server_model.version != local_model.version
+        if not needs_update:
             return

-
-
-
-        # Cancel any ongoing download for this model
-        self._cancel_download(server_model.id)
-
-        # Delete old model file
-        self.delete_local_model(local_model.file)
-
-        # Schedule new download
-        self._schedule_model_download(server_model)
+        change_desc = ", ".join(changes)
+        logging.info(f"🔄 Model update: {server_model.name} ({change_desc})")

-
-
-
-
+        if version_changed:
+            self._cancel_download(server_model.id)
+            self.delete_local_model(local_model.file)
+            self._schedule_download(server_model)
+
+        self._update_model_properties(local_model, server_model)
         updated_models.append(local_model)

-    def _handle_new_model(self, server_model, new_models):
-        """Handle model that doesn't exist locally."""
-        # Check if already downloading this model
+    def _handle_new_model(self, server_model, new_models: List):
         if self._is_downloading(server_model.id):
-            logging.info(f"⏳ Model {server_model.name}
+            logging.info(f"⏳ Model {server_model.name} already downloading")
             return

-        new_model = AIModelEntity(
-
-
+        new_model = self._create_model_entity(server_model)
+        new_models.append(new_model)
+
+        logging.info(f"⬇️ New model: {server_model.name}")
+        self._schedule_download(server_model)
+
+    def _check_model_changes(self, server_model, local_model) -> tuple[bool, List[str]]:
+        """Check if model needs update and return list of changes"""
+        changes = []
+
+        if server_model.version != local_model.version:
+            changes.append(f"version: {local_model.version} -> {server_model.version}")
+
+        if server_model.ai_model_type_code != local_model.type:
+            changes.append(f"type: {local_model.type} -> {server_model.ai_model_type_code}")
+
+        server_classes = set(server_model.classes)
+        local_classes = set(local_model.get_classes() or [])
+        if server_classes != local_classes:
+            changes.append("classes updated")
+
+        server_ppe_groups = {
+            group.name: {"compliance": group.compliance, "violation": group.violation}
+            for group in (server_model.ppe_class_groups or [])
+        }
+        local_ppe_groups = local_model.get_ppe_class_groups() or {}
+        if server_ppe_groups != local_ppe_groups:
+            changes.append("PPE class groups updated")
+
+        if server_model.main_class != local_model.get_main_class():
+            changes.append(f"main class: {local_model.get_main_class()} -> {server_model.main_class}")
+
+        return bool(changes), changes
+
+    def _create_model_entity(self, server_model) -> AIModelEntity:
+        model = AIModelEntity(
+            id=server_model.id,
+            name=server_model.name,
             type=server_model.ai_model_type_code,
-            file=os.path.basename(server_model.file_path),
+            file=os.path.basename(server_model.file_path),
             version=server_model.version
         )
-
-
-        logging.info(f"⬇️ New model detected: {server_model.name}. Scheduling download...")
-        self._schedule_model_download(server_model)
+        self._update_model_properties(model, server_model)
+        return model

-    def _schedule_model_download(self, model):
-
-
+    def _update_model_properties(self, local_model: AIModelEntity, server_model):
+        local_model.name = server_model.name
+        local_model.type = server_model.ai_model_type_code
+        local_model.version = server_model.version
+        local_model.set_classes(list(server_model.classes))
+        local_model.set_ppe_class_groups({
+            group.name: {
+                "compliance": group.compliance,
+                "violation": group.violation
+            }
+            for group in (server_model.ppe_class_groups or [])
+        })
+        local_model.set_main_class(server_model.main_class)
+
+    def _schedule_download(self, model):
         self._cancel_download(model.id)

-        # Create new download info
         download_info = DownloadInfo(
             model_id=model.id,
             model_name=model.name,
-            version=model.version
+            version=model.version,
+            start_time=time.time()
         )
-        download_info.state = DownloadState.PENDING
-        download_info.start_time = time.time()

-
-        self._update_model_download_status(model.id, "pending", None)
+        self._update_model_status(model.id, "pending")

-        # Start download in background thread
         download_info.thread = threading.Thread(
-            target=self.
+            target=self._download_worker,
             args=(model, download_info),
             daemon=True,
-            name=f"ModelDownload-{model.id}"
+            name=f"ModelDownload-{model.id[:8]}"
         )
         download_info.thread.start()
-
         self._set_download_info(model.id, download_info)

-    def
-        """Background worker for downloading a model."""
+    def _download_worker(self, model, download_info: DownloadInfo):
         try:
             download_info.state = DownloadState.DOWNLOADING
-            self.
-            logging.info(f"📥
+            self._update_model_status(model.id, "downloading")
+            logging.info(f"📥 Downloading {model.name}...")
+
+            success = self._download_model_file(model, download_info)

-            if
+            if success:
                 download_info.state = DownloadState.COMPLETED
                 download_info.end_time = time.time()
                 duration = download_info.end_time - download_info.start_time
-                self.
-                logging.info(f"✅
+                self._update_model_status(model.id, "completed")
+                logging.info(f"✅ Downloaded {model.name} in {duration:.1f}s")
             else:
-                download_info
-                download_info.error_message = "Download failed"
-                self._update_model_download_status(model.id, "failed", "Download failed")
-                logging.error(f"❌ Failed to download AI Model '{model.name}'")
+                self._handle_download_failure(download_info, "Download failed")

         except Exception as e:
-            download_info
-            download_info.error_message = str(e)
-            self._update_model_download_status(model.id, "failed", str(e))
-            logging.error(f"❌ Error downloading AI Model '{model.name}': {e}")
+            self._handle_download_failure(download_info, str(e))
         finally:
-            # Clean up download info after a delay to allow status checking
             threading.Timer(300, lambda: self._remove_download_info(model.id)).start()

-    def _save_changes(self, new_models, updated_models, models_to_delete):
-
+    def _handle_download_failure(self, download_info: DownloadInfo, error: str):
+        download_info.state = DownloadState.FAILED
+        download_info.error_message = error
+        self._update_model_status(download_info.model_id, "failed", error)
+        logging.error(f"❌ Failed to download {download_info.model_name}: {error}")
+
+    def _save_changes(self, new_models: List, updated_models: List, models_to_delete: List):
         try:
             if new_models:
                 self.repository.session.bulk_save_objects(new_models)
-
             if updated_models:
                 self.repository.session.bulk_save_objects(updated_models)

             for model in models_to_delete:
-                logging.info(f"🗑️
-                # Cancel any ongoing download
+                logging.info(f"🗑️ Removing {model.name}")
                 self._cancel_download(model.id)
                 self.repository.session.delete(model)
                 self.delete_local_model(model.file)
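
A note on the threading model introduced above: each download runs on its own daemon thread, cancellation is cooperative via DownloadInfo.stop_event, and the worker's finally block arms a threading.Timer(300, ...) so the tracker entry stays queryable for five minutes after the download ends. A stripped-down sketch of the same cancel-and-join pattern, using a hypothetical worker rather than the package's code:

import threading
import time

def worker(stop_event: threading.Event):
    # Do chunked work, checking for cancellation between chunks,
    # the same way the streaming download checks between gRPC chunks.
    for _ in range(100):
        if stop_event.is_set():
            print("worker: cancelled")
            return
        time.sleep(0.05)
    print("worker: finished")

stop = threading.Event()
t = threading.Thread(target=worker, args=(stop,), daemon=True, name="demo-download")
t.start()

time.sleep(0.2)    # let it run briefly
stop.set()         # request cancellation
t.join(timeout=5)  # bounded wait, mirroring _cancel_download
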
@@ -295,50 +310,43 @@ class AIModelClient(GrpcClientBase):
             self.repository.session.commit()
         except Exception as e:
             self.repository.session.rollback()
-            logging.error(f"Error saving
+            logging.error(f"Error saving changes: {e}")
             raise

-    def
-        """Download the AI model and save it to the models directory."""
+    def _download_model_file(self, model, download_info: DownloadInfo) -> bool:
         if not self.stub:
-            logging.error("gRPC connection is not established.")
             return False

         try:
             request = DownloadAIModelRequest(ai_model_id=model.id, token=self.token)
             file_path = self._get_model_path(model.file_path)

-            # Check if download was cancelled
-            if download_info and download_info.stop_event.is_set():
-                logging.info(f"🛑 Download cancelled for model '{model.name}'")
-                return False
-
             with open(file_path, "wb") as f:
                 for chunk in self.stub.DownloadAIModel(request):
-
-
-                        logging.info(f"🛑 Download cancelled during streaming for model '{model.name}'")
+                    if download_info.stop_event.is_set():
+                        logging.info(f"🛑 Download cancelled: {model.name}")
                         return False
                     f.write(chunk.file_chunk)

             return True

         except Exception as e:
-            logging.error(f"❌
+            logging.error(f"❌ Download error for {model.name}: {e}")
             return False

-    def delete_local_model(self, file):
-        """Delete a local AI model file."""
-        file_path = self._get_model_path(file)
+    def delete_local_model(self, filename: str) -> bool:
         try:
+            file_path = self._get_model_path(filename)
             if file_path.exists():
                 file_path.unlink()
-                logging.info(f"🗑️
+                logging.info(f"🗑️ Deleted {filename}")
+                return True
+            return False
         except Exception as e:
-            logging.error(f"❌ Error deleting
+            logging.error(f"❌ Error deleting {filename}: {e}")
+            return False

     def get_download_status(self, model_id: str) -> Optional[Dict]:
-        """Get the download status for a specific model."""
         download_info = self._get_download_info(model_id)
         if not download_info:
             return None
@@ -354,9 +362,25 @@ class AIModelClient(GrpcClientBase):
         }

     def get_all_download_status(self) -> Dict[str, Dict]:
-
-        with self.download_lock:
+        with self._download_lock_context():
             return {
                 model_id: self.get_download_status(model_id)
                 for model_id in self.download_tracker.keys()
-            }
+            }
+
+    def cancel_all_downloads(self) -> int:
+        cancelled_count = 0
+        with self._download_lock_context():
+            for model_id in list(self.download_tracker.keys()):
+                if self._cancel_download(model_id):
+                    cancelled_count += 1
+        return cancelled_count
+
+    def cleanup_downloads(self):
+        with self._download_lock_context():
+            completed_ids = [
+                model_id for model_id, info in self.download_tracker.items()
+                if info.state in {DownloadState.COMPLETED, DownloadState.FAILED, DownloadState.CANCELLED}
+            ]
+            for model_id in completed_ids:
+                self._remove_download_info(model_id)