mycelium_ai-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mycelium/__init__.py +0 -0
- mycelium/api/__init__.py +0 -0
- mycelium/api/app.py +1147 -0
- mycelium/api/client_app.py +170 -0
- mycelium/api/generated_sources/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/__init__.py +97 -0
- mycelium/api/generated_sources/server_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/server_schemas/api/default_api.py +2473 -0
- mycelium/api/generated_sources/server_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/server_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/server_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/server_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/server_schemas/models/__init__.py +41 -0
- mycelium/api/generated_sources/server_schemas/models/api_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/chroma_section.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/clap_section.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server200_response.py +79 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server_request.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/compute_text_search_request.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/config_request.py +81 -0
- mycelium/api/generated_sources/server_schemas/models/config_response.py +107 -0
- mycelium/api/generated_sources/server_schemas/models/create_playlist_request.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/get_similar_by_track200_response.py +143 -0
- mycelium/api/generated_sources/server_schemas/models/library_stats_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/logging_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/media_server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/playlist_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/plex_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/processing_response.py +90 -0
- mycelium/api/generated_sources/server_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/scan_library_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/search_result_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/stop_processing_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/task_status_response.py +87 -0
- mycelium/api/generated_sources/server_schemas/models/track_database_stats.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/track_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/tracks_list_response.py +81 -0
- mycelium/api/generated_sources/server_schemas/rest.py +329 -0
- mycelium/api/generated_sources/server_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/test/test_api_section.py +57 -0
- mycelium/api/generated_sources/server_schemas/test/test_chroma_section.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_clap_section.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server200_response.py +52 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server_request.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_text_search_request.py +54 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_request.py +66 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_response.py +97 -0
- mycelium/api/generated_sources/server_schemas/test/test_create_playlist_request.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_default_api.py +150 -0
- mycelium/api/generated_sources/server_schemas/test/test_get_similar_by_track200_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_library_stats_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_logging_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_media_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_playlist_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_plex_section.py +56 -0
- mycelium/api/generated_sources/server_schemas/test/test_processing_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_scan_library_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_search_result_response.py +69 -0
- mycelium/api/generated_sources/server_schemas/test/test_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_stop_processing_response.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_task_status_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_database_stats.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_tracks_list_response.py +75 -0
- mycelium/api/generated_sources/worker_schemas/__init__.py +61 -0
- mycelium/api/generated_sources/worker_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/worker_schemas/api/default_api.py +318 -0
- mycelium/api/generated_sources/worker_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/worker_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/worker_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/worker_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/worker_schemas/models/__init__.py +23 -0
- mycelium/api/generated_sources/worker_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_clap_section.py +75 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_api_section.py +69 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_section.py +79 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_request.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_logging_section.py +67 -0
- mycelium/api/generated_sources/worker_schemas/rest.py +329 -0
- mycelium/api/generated_sources/worker_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/worker_schemas/test/test_default_api.py +45 -0
- mycelium/api/generated_sources/worker_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_clap_section.py +60 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_api_section.py +55 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_section.py +65 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_request.py +59 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_logging_section.py +53 -0
- mycelium/api/worker_models.py +99 -0
- mycelium/application/__init__.py +11 -0
- mycelium/application/job_queue.py +323 -0
- mycelium/application/library_management_use_cases.py +292 -0
- mycelium/application/search_use_cases.py +96 -0
- mycelium/application/services.py +340 -0
- mycelium/client.py +554 -0
- mycelium/client_config.py +251 -0
- mycelium/client_frontend_dist/404.html +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_buildManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_ssgManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/page-cc6bad295789134e.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/client_frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/client_frontend_dist/favicon.ico +0 -0
- mycelium/client_frontend_dist/file.svg +1 -0
- mycelium/client_frontend_dist/globe.svg +1 -0
- mycelium/client_frontend_dist/index.html +1 -0
- mycelium/client_frontend_dist/index.txt +20 -0
- mycelium/client_frontend_dist/next.svg +1 -0
- mycelium/client_frontend_dist/vercel.svg +1 -0
- mycelium/client_frontend_dist/window.svg +1 -0
- mycelium/config.py +346 -0
- mycelium/domain/__init__.py +13 -0
- mycelium/domain/models.py +71 -0
- mycelium/domain/repositories.py +98 -0
- mycelium/domain/worker.py +77 -0
- mycelium/frontend_dist/404.html +1 -0
- mycelium/frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/page-a761463485e0540b.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_buildManifest.js +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_ssgManifest.js +1 -0
- mycelium/frontend_dist/favicon.ico +0 -0
- mycelium/frontend_dist/file.svg +1 -0
- mycelium/frontend_dist/globe.svg +1 -0
- mycelium/frontend_dist/index.html +10 -0
- mycelium/frontend_dist/index.txt +20 -0
- mycelium/frontend_dist/next.svg +1 -0
- mycelium/frontend_dist/vercel.svg +1 -0
- mycelium/frontend_dist/window.svg +1 -0
- mycelium/infrastructure/__init__.py +17 -0
- mycelium/infrastructure/chroma_adapter.py +232 -0
- mycelium/infrastructure/clap_adapter.py +280 -0
- mycelium/infrastructure/plex_adapter.py +145 -0
- mycelium/infrastructure/track_database.py +467 -0
- mycelium/main.py +183 -0
- mycelium_ai-0.5.0.dist-info/METADATA +312 -0
- mycelium_ai-0.5.0.dist-info/RECORD +164 -0
- mycelium_ai-0.5.0.dist-info/WHEEL +5 -0
- mycelium_ai-0.5.0.dist-info/entry_points.txt +2 -0
- mycelium_ai-0.5.0.dist-info/licenses/LICENSE +21 -0
- mycelium_ai-0.5.0.dist-info/top_level.txt +1 -0
mycelium/application/library_management_use_cases.py
@@ -0,0 +1,292 @@
"""New use cases for separated scanning and processing workflow."""
import logging
from datetime import datetime, timezone
from typing import Optional, Dict, Any

from .job_queue import JobQueueService
from ..domain.models import TrackEmbedding
from ..domain.repositories import MediaServerRepository, EmbeddingRepository, EmbeddingGenerator
from ..domain.worker import ContextType
from ..infrastructure.track_database import TrackDatabase

logger = logging.getLogger(__name__)


class LibraryScanUseCase:
    """Use case for scanning and storing track metadata."""

    def __init__(self, media_server_repository: MediaServerRepository, track_database: TrackDatabase):
        self.media_server_repository = media_server_repository
        self.track_database = track_database

    def execute(self, progress_callback: Optional[callable] = None) -> Dict[str, Any]:
        """Scan the media server music library and store track metadata."""
        logger.info("Starting library scan...")

        try:
            # Get all tracks from media server
            tracks = self.media_server_repository.get_all_tracks()
            logger.info(f"Found {len(tracks)} tracks in library")

            if progress_callback:
                progress_callback({"stage": "scanning", "current": len(tracks), "total": len(tracks)})

            # Save tracks to database
            scan_timestamp = datetime.now(timezone.utc)
            stats = self.track_database.save_tracks(tracks=tracks,
                                                    scan_timestamp=scan_timestamp)

            result = {
                "total_tracks": stats["total"],
                "new_tracks": stats["new"],
                "updated_tracks": stats["updated"],
                "scan_timestamp": scan_timestamp.isoformat()
            }

            logger.info(f"Scan completed: {stats['total']} total, {stats['new']} new, {stats['updated']} updated")

            if progress_callback:
                progress_callback({"stage": "complete", "result": result})

            return result

        except Exception as e:
            logger.error(f"Scan failed: {e}")
            raise


class EmbeddingProcessingUseCase:
    """Use case for embedding processing from stored tracks."""

    def __init__(
            self,
            embedding_generator: EmbeddingGenerator,
            embedding_repository: EmbeddingRepository,
            track_database: TrackDatabase,
            model_id: str,
            gpu_batch_size: int = 16
    ):
        self.embedding_generator = embedding_generator
        self.embedding_repository = embedding_repository
        self.track_database = track_database
        self.model_id = model_id
        self.gpu_batch_size = gpu_batch_size
        self._should_stop = False

    def process_embeddings(
            self,
            progress_callback: Optional[callable] = None,
            max_tracks: Optional[int] = None
    ) -> Dict[str, Any]:
        """Process embeddings for unprocessed tracks with resumability."""
        logger.info(f"Starting embedding processing with model: {self.model_id}")

        # Get unprocessed tracks for this specific model
        unprocessed_tracks = self.track_database.get_unprocessed_tracks(model_id=self.model_id,
                                                                        limit=max_tracks)

        if not unprocessed_tracks:
            logger.info("No unprocessed tracks found")
            return {
                "processed": 0,
                "failed": 0,
                "total": 0,
                "message": "No tracks to process"
            }

        logger.info(f"Found {len(unprocessed_tracks)} unprocessed tracks")
        processed_count = 0
        failed_count = 0

        try:
            for i in range(0, len(unprocessed_tracks), self.gpu_batch_size):
                if self._should_stop:
                    logger.info("Processing stopped by user request")
                    break

                batch = unprocessed_tracks[i:i + self.gpu_batch_size]
                tracks = []
                filepaths = []
                valid_stored_tracks = []

                # Prepare batch data
                for stored_track in batch:
                    try:
                        track = stored_track.to_track()
                        tracks.append(track)
                        filepaths.append(track.filepath)
                        valid_stored_tracks.append(stored_track)
                    except Exception as e:
                        logger.error(f"Error converting track {stored_track.media_server_rating_key}: {e}")
                        failed_count += 1

                if not filepaths:
                    continue

                # Generate embeddings in batch
                embeddings = self.embedding_generator.generate_embedding_batch(filepaths)

                # Process results
                for track, stored_track, embedding in zip(tracks, valid_stored_tracks, embeddings):
                    try:
                        if embedding:
                            # Create track embedding object with model info
                            track_embedding = TrackEmbedding(
                                track=track,
                                embedding=embedding,
                                model_id=self.model_id,
                                processed_at=datetime.now(timezone.utc)
                            )

                            # Save to vector database
                            self.embedding_repository.save_embeddings(embeddings=[track_embedding])

                            # Mark as processed in metadata database
                            self.track_database.mark_track_processed(
                                media_server_rating_key=stored_track.media_server_rating_key,
                                model_id=self.model_id
                            )

                            processed_count += 1

                            if progress_callback:
                                progress_callback({
                                    "stage": "processing",
                                    "current": processed_count + failed_count,
                                    "total": len(unprocessed_tracks),
                                    "processed": processed_count,
                                    "failed": failed_count,
                                    "current_track": track.display_name
                                })
                        else:
                            logger.warning(f"Failed to generate embedding for: {track.display_name}")
                            failed_count += 1

                    except Exception as e:
                        logger.error(f"Error processing track {stored_track.media_server_rating_key}: {e}")
                        failed_count += 1

            result = {
                "processed": processed_count,
                "failed": failed_count,
                "total": len(unprocessed_tracks),
                "stopped": self._should_stop
            }

            logger.info(f"Processing completed: {processed_count} processed, {failed_count} failed")

            if progress_callback:
                progress_callback({"stage": "complete", "result": result})

            return result

        except Exception as e:
            logger.info(f"Processing failed: {e}")
            raise

    def stop(self) -> None:
        """Request to stop processing."""
        self._should_stop = True
        logger.info("Stop requested - will finish current track and stop")

    def reset_stop_flag(self) -> None:
        """Reset the stop flag for a new processing session."""
        self._should_stop = False


class ProcessingProgressUseCase:
    """Use case for tracking processing progress."""

    def __init__(self, track_database: TrackDatabase):
        self.track_database = track_database

    def get_current_stats(self, model_id: Optional[str] = None) -> Dict[str, Any]:
        """Get current processing statistics, optionally filtered by model."""
        stats = self.track_database.get_processing_stats(model_id)

        return {
            "total_tracks": stats["total_tracks"],
            "processed_tracks": stats["processed_tracks"],
            "unprocessed_tracks": stats["unprocessed_tracks"],
            "progress_percentage": (stats["processed_tracks"] / stats["total_tracks"] * 100) if stats[
                "total_tracks"] > 0 else 0,
            "model_id": model_id
        }


class WorkerBasedProcessingUseCase:
    """Use case for processing embeddings using client workers."""

    def __init__(
            self,
            job_queue_service: JobQueueService,
            track_database: TrackDatabase,
            api_host: str = "localhost",
            api_port: int = 8000
    ):
        self.job_queue = job_queue_service
        self.track_database = track_database
        self.api_host = api_host
        self.api_port = api_port

    def can_use_workers(self) -> bool:
        """Check if there are active workers available."""
        active_workers = self.job_queue.get_active_workers()
        return len(active_workers) > 0

    def get_worker_info(self) -> Dict[str, Any]:
        """Get information about available workers."""
        active_workers = self.job_queue.get_active_workers()
        queue_stats = self.job_queue.get_queue_stats()

        return {
            "active_workers": len(active_workers),
            "worker_details": [
                {
                    "id": worker.id,
                    "ip_address": worker.ip_address,
                    "last_heartbeat": worker.last_heartbeat.isoformat()
                }
                for worker in active_workers
            ],
            "queue_stats": queue_stats
        }

    def create_worker_tasks(self, model_id: str, max_tracks: Optional[int] = None) -> Dict[str, Any]:
        """Create tasks for all unprocessed tracks to be handled by workers."""
        if not self.can_use_workers():
            return {
                "success": False,
                "message": "No active workers available",
                "tasks_created": 0
            }

        # Get unprocessed tracks
        unprocessed_tracks = self.track_database.get_unprocessed_tracks(limit=max_tracks, model_id=model_id)

        if not unprocessed_tracks:
            return {
                "success": True,
                "message": "No tracks to process",
                "tasks_created": 0
            }

        # Create tasks for each track
        tasks_created = 0
        for stored_track in unprocessed_tracks:
            try:
                download_url = f"/download_track/{stored_track.media_server_rating_key}"
                self.job_queue.create_task(stored_track.media_server_rating_key, download_url, prioritize=False,
                                           context_type=ContextType.AUDIO_PROCESSING)
                tasks_created += 1
            except Exception as e:
                logger.error(f"Failed to create task for track {stored_track.media_server_rating_key}: {e}")
                continue

        return {
            "success": True,
            "message": f"Created {tasks_created} tasks for worker processing",
            "tasks_created": tasks_created,
            "total_unprocessed": len(unprocessed_tracks),
            "worker_info": self.get_worker_info()
        }
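For orientation, below is a minimal wiring sketch (not part of the wheel) showing how the scan and embedding use cases above could be driven end to end. The adapter instances are passed in as parameters because their constructors live in mycelium/infrastructure and are not shown in this diff; the default model_id string and the function name scan_then_embed are illustrative assumptions, not the package's API.

from typing import Any, Dict

from mycelium.application.library_management_use_cases import (
    EmbeddingProcessingUseCase,
    LibraryScanUseCase,
)


def scan_then_embed(media_server, track_db, generator, repository,
                    model_id: str = "laion/clap-htsat-fused") -> Dict[str, Any]:
    """Scan the library, then embed whatever is still unprocessed for model_id."""

    def log_progress(update: Dict[str, Any]) -> None:
        # The use cases report progress as dicts, e.g.
        # {"stage": "processing", "current": 3, "total": 120, "processed": 2, "failed": 1, ...}
        print(update)

    # Store/refresh track metadata from the media server.
    scan = LibraryScanUseCase(media_server_repository=media_server,
                              track_database=track_db)
    scan.execute(progress_callback=log_progress)

    # Generate and persist embeddings for tracks not yet processed with this model.
    processing = EmbeddingProcessingUseCase(
        embedding_generator=generator,
        embedding_repository=repository,
        track_database=track_db,
        model_id=model_id,  # illustrative default; any model identifier string works
        gpu_batch_size=16,
    )
    return processing.process_embeddings(progress_callback=log_progress)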
mycelium/application/search_use_cases.py
@@ -0,0 +1,96 @@
"""Use cases for the Mycelium application."""

import logging
from pathlib import Path
from typing import List

from ..domain.models import SearchResult, MediaServerType, Track
from ..domain.repositories import EmbeddingRepository, EmbeddingGenerator

class MusicSearchUseCase:
    """Use case for searching music by similarity."""

    def __init__(
            self,
            embedding_repository: EmbeddingRepository,
            embedding_generator: EmbeddingGenerator
    ):
        self.embedding_repository = embedding_repository
        self.embedding_generator = embedding_generator
        self.logger = logging.getLogger(__name__)

    def search_by_audio_file(
            self,
            filepath: Path,
            n_results: int = 10,
            exclude_self: bool = True
    ) -> List[SearchResult]:
        """Find similar songs to an audio file."""
        self.logger.info(f"Searching for songs similar to: {filepath.name}")

        # Generate embedding for the query audio
        query_embedding = self.embedding_generator.generate_embedding(filepath)

        if query_embedding is None:
            self.logger.error("Could not generate embedding for the query.")
            return []

        # Search in the database
        # Request n_results + 1 to account for potentially discarding the same song
        results = self.embedding_repository.search_by_embedding(
            query_embedding,
            n_results=n_results + 1 if exclude_self else n_results
        )

        # Filter out the same file if requested
        if exclude_self:
            results = [
                result for result in results
                if result.track.filepath != filepath
            ][:n_results]

        return results

    def search_by_text(self, query_text: str, n_results: int = 10) -> List[SearchResult]:
        """Find songs that match a text description."""
        self.logger.info(f"Searching for songs matching: '{query_text}'")

        # Generate embedding for the text query
        text_embedding = self.embedding_generator.generate_text_embedding(query_text)

        if text_embedding is None:
            self.logger.error("Could not generate embedding for the text query.")
            return []

        # Search in the database
        results = self.embedding_repository.search_by_embedding(text_embedding, n_results)

        return results

    def search_by_track_id(self, track_id: str, n_results: int = 10) -> List[SearchResult]:
        """Find songs similar to a track identified by its ID."""
        self.logger.info(f"Searching for songs similar to track ID: {track_id}")

        # Get the embedding for this track
        embedding = self.embedding_repository.get_embedding_by_track_id(track_id)

        if embedding is None:
            self.logger.error(f"No embedding found for track ID: {track_id}")
            # Try to check if the embedding exists using has_embedding
            has_emb = self.embedding_repository.has_embedding(track_id)
            self.logger.error(f"Double-check has_embedding for track {track_id}: {has_emb}")
            return []

        self.logger.info(f"Found embedding for track {track_id}, size: {len(embedding)}")

        # Search for similar tracks
        results = self.embedding_repository.search_by_embedding(embedding, n_results + 1)

        # Filter out the same track (it will be the first result with distance 0)
        results = [
            result for result in results
            if result.track.media_server_rating_key != track_id
        ][:n_results]

        self.logger.info(f"Found {len(results)} similar tracks for track {track_id}")
        return results
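Similarly, a short usage sketch (not included in the package) for MusicSearchUseCase. The repository and generator arguments are assumed to be the ChromaDB- and CLAP-backed implementations from mycelium/infrastructure; only method and attribute names visible in this diff are relied on, and the query text, file path, and track id are placeholders.

from pathlib import Path
from typing import List

from mycelium.application.search_use_cases import MusicSearchUseCase
from mycelium.domain.models import SearchResult


def demo_searches(repository, generator) -> None:
    search = MusicSearchUseCase(embedding_repository=repository,
                                embedding_generator=generator)

    # Text-to-music: the query string is embedded and matched against stored tracks.
    by_text: List[SearchResult] = search.search_by_text("mellow acoustic guitar", n_results=5)

    # Audio-to-music: the query file itself is excluded from the results by default.
    by_audio = search.search_by_audio_file(Path("/music/example.flac"), n_results=5)

    # Track-to-track: keyed by a media server rating key already present in the index.
    by_track = search.search_by_track_id("12345", n_results=5)

    for result in by_text + by_audio + by_track:
        # Track.display_name is used by the processing use case above, so it is
        # assumed to be a human-readable label here as well.
        print(result.track.display_name)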