mycelium-ai 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mycelium/__init__.py +0 -0
- mycelium/api/__init__.py +0 -0
- mycelium/api/app.py +1147 -0
- mycelium/api/client_app.py +170 -0
- mycelium/api/generated_sources/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/__init__.py +97 -0
- mycelium/api/generated_sources/server_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/server_schemas/api/default_api.py +2473 -0
- mycelium/api/generated_sources/server_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/server_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/server_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/server_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/server_schemas/models/__init__.py +41 -0
- mycelium/api/generated_sources/server_schemas/models/api_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/chroma_section.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/clap_section.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server200_response.py +79 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server_request.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/compute_text_search_request.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/config_request.py +81 -0
- mycelium/api/generated_sources/server_schemas/models/config_response.py +107 -0
- mycelium/api/generated_sources/server_schemas/models/create_playlist_request.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/get_similar_by_track200_response.py +143 -0
- mycelium/api/generated_sources/server_schemas/models/library_stats_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/logging_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/media_server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/playlist_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/plex_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/processing_response.py +90 -0
- mycelium/api/generated_sources/server_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/scan_library_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/search_result_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/stop_processing_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/task_status_response.py +87 -0
- mycelium/api/generated_sources/server_schemas/models/track_database_stats.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/track_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/tracks_list_response.py +81 -0
- mycelium/api/generated_sources/server_schemas/rest.py +329 -0
- mycelium/api/generated_sources/server_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/test/test_api_section.py +57 -0
- mycelium/api/generated_sources/server_schemas/test/test_chroma_section.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_clap_section.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server200_response.py +52 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server_request.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_text_search_request.py +54 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_request.py +66 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_response.py +97 -0
- mycelium/api/generated_sources/server_schemas/test/test_create_playlist_request.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_default_api.py +150 -0
- mycelium/api/generated_sources/server_schemas/test/test_get_similar_by_track200_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_library_stats_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_logging_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_media_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_playlist_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_plex_section.py +56 -0
- mycelium/api/generated_sources/server_schemas/test/test_processing_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_scan_library_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_search_result_response.py +69 -0
- mycelium/api/generated_sources/server_schemas/test/test_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_stop_processing_response.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_task_status_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_database_stats.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_tracks_list_response.py +75 -0
- mycelium/api/generated_sources/worker_schemas/__init__.py +61 -0
- mycelium/api/generated_sources/worker_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/worker_schemas/api/default_api.py +318 -0
- mycelium/api/generated_sources/worker_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/worker_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/worker_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/worker_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/worker_schemas/models/__init__.py +23 -0
- mycelium/api/generated_sources/worker_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_clap_section.py +75 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_api_section.py +69 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_section.py +79 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_request.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_logging_section.py +67 -0
- mycelium/api/generated_sources/worker_schemas/rest.py +329 -0
- mycelium/api/generated_sources/worker_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/worker_schemas/test/test_default_api.py +45 -0
- mycelium/api/generated_sources/worker_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_clap_section.py +60 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_api_section.py +55 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_section.py +65 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_request.py +59 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_logging_section.py +53 -0
- mycelium/api/worker_models.py +99 -0
- mycelium/application/__init__.py +11 -0
- mycelium/application/job_queue.py +323 -0
- mycelium/application/library_management_use_cases.py +292 -0
- mycelium/application/search_use_cases.py +96 -0
- mycelium/application/services.py +340 -0
- mycelium/client.py +554 -0
- mycelium/client_config.py +251 -0
- mycelium/client_frontend_dist/404.html +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_buildManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_ssgManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/page-cc6bad295789134e.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/client_frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/client_frontend_dist/favicon.ico +0 -0
- mycelium/client_frontend_dist/file.svg +1 -0
- mycelium/client_frontend_dist/globe.svg +1 -0
- mycelium/client_frontend_dist/index.html +1 -0
- mycelium/client_frontend_dist/index.txt +20 -0
- mycelium/client_frontend_dist/next.svg +1 -0
- mycelium/client_frontend_dist/vercel.svg +1 -0
- mycelium/client_frontend_dist/window.svg +1 -0
- mycelium/config.py +346 -0
- mycelium/domain/__init__.py +13 -0
- mycelium/domain/models.py +71 -0
- mycelium/domain/repositories.py +98 -0
- mycelium/domain/worker.py +77 -0
- mycelium/frontend_dist/404.html +1 -0
- mycelium/frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/page-a761463485e0540b.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_buildManifest.js +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_ssgManifest.js +1 -0
- mycelium/frontend_dist/favicon.ico +0 -0
- mycelium/frontend_dist/file.svg +1 -0
- mycelium/frontend_dist/globe.svg +1 -0
- mycelium/frontend_dist/index.html +10 -0
- mycelium/frontend_dist/index.txt +20 -0
- mycelium/frontend_dist/next.svg +1 -0
- mycelium/frontend_dist/vercel.svg +1 -0
- mycelium/frontend_dist/window.svg +1 -0
- mycelium/infrastructure/__init__.py +17 -0
- mycelium/infrastructure/chroma_adapter.py +232 -0
- mycelium/infrastructure/clap_adapter.py +280 -0
- mycelium/infrastructure/plex_adapter.py +145 -0
- mycelium/infrastructure/track_database.py +467 -0
- mycelium/main.py +183 -0
- mycelium_ai-0.5.0.dist-info/METADATA +312 -0
- mycelium_ai-0.5.0.dist-info/RECORD +164 -0
- mycelium_ai-0.5.0.dist-info/WHEEL +5 -0
- mycelium_ai-0.5.0.dist-info/entry_points.txt +2 -0
- mycelium_ai-0.5.0.dist-info/licenses/LICENSE +21 -0
- mycelium_ai-0.5.0.dist-info/top_level.txt +1 -0
mycelium/client.py
ADDED
@@ -0,0 +1,554 @@
|
|
1
|
+
"""Mycelium client for processing audio embeddings on GPU workers."""
|
2
|
+
import gc
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import socket
|
6
|
+
import tempfile
|
7
|
+
import threading
|
8
|
+
import time
|
9
|
+
import uuid
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from pathlib import Path
|
12
|
+
from queue import Queue, Empty
|
13
|
+
from typing import Optional, List
|
14
|
+
|
15
|
+
import requests
|
16
|
+
|
17
|
+
from mycelium.client_config import MyceliumClientConfig
|
18
|
+
from mycelium.client_config import get_client_config_file_path
|
19
|
+
from mycelium.infrastructure.clap_adapter import CLAPEmbeddingGenerator
|
20
|
+
|
21
|
+
# Module-level logger, named after this module so output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class DownloadedJob:
    """A job fetched from the server, paired with its local audio file (if any).

    Text-embedding jobs are queued without a download, so ``audio_file`` is
    ``None`` for them.
    """

    # Task identifier taken from the server's job payload.
    task_id: str
    # Identifier of the track the task refers to.
    track_id: str
    # The job payload exactly as received from the server.
    original_job: dict
    # Temporary file holding the downloaded audio, or None when nothing was downloaded.
    audio_file: Optional[Path]
|
31
|
+
|
32
|
+
|
33
|
+
class MyceliumClient:
    """Client for processing CLAP embeddings on GPU hardware.

    Work flows through three stages connected by bounded queues:

    1. ``_job_fetcher`` (one thread) polls the server and fills ``job_queue``.
    2. ``_download_worker`` (``download_workers`` threads) downloads the audio
       of each job to a temp file and fills ``download_queue`` with
       ``DownloadedJob`` items.
    3. ``run`` (the calling thread) drains ``download_queue`` in batches of up
       to ``client.gpu_batch_size``, computes embeddings and submits results.
    """

    # Every CLAP setting passed to the embedding generator's constructor; a
    # change to any of them requires rebuilding the generator.
    _CLAP_FIELDS = ("model_id", "target_sr", "chunk_duration_s", "num_chunks", "max_load_duration_s")

    # Client settings that are only read at start-up: (config attribute, log label).
    _RESTART_REQUIRED_FIELDS = (
        ("server_host", "Server host"),
        ("server_port", "Server port"),
        ("download_workers", "Download workers"),
        ("download_queue_size", "Download queue size"),
        ("job_queue_size", "Job queue size"),
    )

    # Seconds between shutdown checks while waiting on a full download queue.
    _QUEUE_WAIT_S = 0.5

    def __init__(self):
        """Load the client config, set up queues and the CLAP embedding generator."""
        # Load configuration
        self.config = MyceliumClientConfig.load_from_yaml()
        client_cfg = self.config.client

        # Use config values for all settings
        self.server_host = client_cfg.server_host
        self.server_port = client_cfg.server_port
        self.server_url = f"http://{self.server_host}:{self.server_port}"
        self.model_id = self.config.clap.model_id
        self.poll_interval = client_cfg.poll_interval
        self.download_queue_size = client_cfg.download_queue_size
        self.download_workers = client_cfg.download_workers

        # Baseline mtime that _check_config_reload compares against.
        self.config_file_path = get_client_config_file_path()
        self.last_config_mtime = self._get_config_mtime()

        self.worker_id = f"worker-{uuid.uuid4().hex[:8]}"
        self.ip_address = self._get_local_ip()

        # Turn off the tokenizers library's own parallelism; this process
        # already runs several threads of its own.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        self.device = CLAPEmbeddingGenerator.get_best_device()

        self.job_queue: Queue[dict] = Queue(maxsize=client_cfg.job_queue_size)
        self.download_queue: Queue[DownloadedJob] = Queue(maxsize=self.download_queue_size)

        self.job_fetcher_thread: Optional[threading.Thread] = None
        self.download_threads: List[threading.Thread] = []
        self.stop_event = threading.Event()

        self.clap_embedding_generator = self._build_embedding_generator(self.config)

        logger.info("Mycelium Client initialized")
        logger.info(f"Worker ID: {self.worker_id}")
        logger.info(f"Server: {self.server_url}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Download queue size: {self.download_queue_size}")
        logger.info(f"Job queue size: {client_cfg.job_queue_size}")
        logger.info(f"Poll interval: {self.poll_interval}s")
        logger.info(f"Parallel download workers: {self.download_workers}")

    @staticmethod
    def _build_embedding_generator(config: "MyceliumClientConfig") -> "CLAPEmbeddingGenerator":
        """Create a CLAP embedding generator from the ``clap`` section of *config*."""
        clap = config.clap
        return CLAPEmbeddingGenerator(
            model_id=clap.model_id,
            target_sr=clap.target_sr,
            chunk_duration_s=clap.chunk_duration_s,
            num_chunks=clap.num_chunks,
            max_load_duration_s=clap.max_load_duration_s,
        )

    def _log_queue_status(self, context: str = ""):
        """Log current queue status with context."""
        job_q_size = self.job_queue.qsize()
        dl_q_size = self.download_queue.qsize()
        dl_q_cap = self.download_queue.maxsize
        # maxsize 0 means "unbounded", so there is no meaningful percentage.
        dl_q_percent = (dl_q_size / dl_q_cap) * 100 if dl_q_cap > 0 else 0

        logger.info(
            f"Queue status ({context}): "
            f"Jobs to download: {job_q_size}, "
            f"Jobs ready for GPU: {dl_q_size}/{dl_q_cap} ({dl_q_percent:.1f}%)"
        )

    @staticmethod
    def _get_local_ip() -> str:
        """Return the IP of the interface used for outbound traffic, or loopback.

        Connecting a UDP socket sends no packets; it only makes the OS choose
        the outgoing interface, whose address ``getsockname`` then reports.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                return s.getsockname()[0]
        except OSError:
            return "127.0.0.1"

    def _get_config_mtime(self) -> float:
        """Return the config file's modification time, or 0.0 if it cannot be read."""
        try:
            return self.config_file_path.stat().st_mtime
        except OSError:
            # Missing or unreadable file: treat as "never modified".
            return 0.0

    def _check_config_reload(self) -> None:
        """Reload the config if the file's mtime has advanced since the last check.

        The stored mtime is advanced even if the reload fails (``reload_config``
        logs and swallows its own errors), so a broken edit is reported once
        rather than on every loop iteration.
        """
        try:
            current_mtime = self._get_config_mtime()
            if current_mtime > self.last_config_mtime:
                logger.info("Config file modification detected, reloading...")
                self.reload_config()
                self.last_config_mtime = current_mtime
        except Exception as e:
            logger.error(f"Error checking config reload: {e}")

    def reload_config(self):
        """Reload configuration and apply changes that can be hot-reloaded.

        Any change in ``_CLAP_FIELDS`` rebuilds the embedding generator.
        ``poll_interval`` and ``gpu_batch_size`` take effect immediately
        (``run`` re-reads the batch size before every batch). Settings listed in
        ``_RESTART_REQUIRED_FIELDS`` are only logged, because the threads and
        queues they size already exist. On error the previous configuration is
        kept and the error is logged.
        """
        try:
            logger.info("Reloading client configuration...")
            new_config = MyceliumClientConfig.load_from_yaml()
            old_client, new_client = self.config.client, new_config.client

            # Check for CLAP model changes (every constructor argument counts).
            clap_changed = any(
                getattr(new_config.clap, name) != getattr(self.config.clap, name)
                for name in self._CLAP_FIELDS
            )

            # Check for client configuration changes that require worker restart
            restart_changes = [
                (label, getattr(old_client, name), getattr(new_client, name))
                for name, label in self._RESTART_REQUIRED_FIELDS
                if getattr(new_client, name) != getattr(old_client, name)
            ]

            if clap_changed:
                logger.info("CLAP configuration changed, recreating embedding generator...")
                self.clap_embedding_generator.unload_model()
                self.clap_embedding_generator = self._build_embedding_generator(new_config)
                self.model_id = new_config.clap.model_id
                logger.info("CLAP embedding generator updated.")

            if restart_changes:
                logger.warning("Client configuration changed. Some changes require restart:")
                for label, old_value, new_value in restart_changes:
                    logger.warning(f" - {label}: {old_value} -> {new_value} (requires restart)")

            # Apply hot-reloadable changes
            self.poll_interval = new_client.poll_interval
            if new_client.poll_interval != old_client.poll_interval:
                logger.info(f"Poll interval updated: {old_client.poll_interval}s -> {new_client.poll_interval}s")

            # GPU batch size is picked up by run() before its next batch.
            if new_client.gpu_batch_size != old_client.gpu_batch_size:
                logger.info(f"GPU batch size updated: {old_client.gpu_batch_size} -> {new_client.gpu_batch_size}")

            self.config = new_config
            logger.info("Client configuration reloaded successfully")
        except Exception as e:
            logger.error(f"Failed to reload client configuration: {e}", exc_info=True)

    def register_with_server(self) -> bool:
        """Register this worker with the server, retrying every 3 s on failure.

        Returns True once registered, or False if ``stop_event`` is set first.
        """
        delay_seconds = 3
        attempt = 1
        logger.info("Attempting to register with server...")
        while not self.stop_event.is_set():
            try:
                response = requests.post(
                    f"{self.server_url}/workers/register",
                    json={"worker_id": self.worker_id, "ip_address": self.ip_address},
                    timeout=10
                )
                response.raise_for_status()
                logger.info(f"Successfully registered with server (attempt {attempt})")
                return True
            except requests.exceptions.RequestException as e:
                logger.warning(f"Error registering with server (attempt {attempt}): {e}")

            # Event.wait returns early if shutdown is requested during the back-off.
            self.stop_event.wait(delay_seconds)
            attempt += 1
        return False

    def get_job(self) -> Optional[dict]:
        """Ask the server for the next job.

        Returns the decoded job dict, or None when the server has no job (empty
        body / non-200 success code) or on any HTTP, network or JSON error.
        """
        try:
            response = requests.get(
                f"{self.server_url}/workers/get_job",
                params={"worker_id": self.worker_id, "ip_address": self.ip_address},
                timeout=30
            )
            response.raise_for_status()
            if response.status_code == 200 and response.text.strip():
                return response.json()
            return None
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers malformed JSON on requests versions whose
            # JSONDecodeError is not a RequestException.
            logger.error(f"Error getting job from server: {e}")
            return None

    @staticmethod
    def download_audio_file(download_url: str) -> Optional[Path]:
        """Download audio file from server into a new temporary file.

        Returns the temp file path (the caller owns it and must delete it), or
        None on any HTTP, network or disk error; in that case any partially
        written temp file is removed.
        """
        temp_path: Optional[Path] = None
        try:
            # The context managers release the pooled connection held open by
            # stream=True and close the file handle, also when an error occurs.
            with requests.get(download_url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as temp_file:
                    temp_path = Path(temp_file.name)
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)
            return temp_path
        except (requests.exceptions.RequestException, OSError) as e:
            logger.error(f"Error downloading file from {download_url}: {e}")
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass  # best-effort cleanup of the partial download
            return None

    def _job_fetcher(self):
        """
        A single thread that requests jobs from the server and puts them in the job_queue.
        """
        logger.info("Job fetcher thread started")
        while not self.stop_event.is_set():
            try:
                if self.job_queue.full():
                    logger.info("Job fetcher: Job queue is full, pausing.")
                    self.stop_event.wait(5)
                    continue

                job = self.get_job()
                if job:
                    logger.info(f"Job fetcher: Got job {job['task_id']}, adding to download queue.")
                    # This thread is the only producer, so the slot seen above is still free.
                    self.job_queue.put(job)
                else:
                    self.stop_event.wait(self.poll_interval)
            except Exception as e:
                logger.error(f"Job fetcher error: {e}")
                self.stop_event.wait(self.poll_interval)
        logger.info("Job fetcher thread stopped")

    def _download_worker(self):
        """
        Takes jobs from the job_queue, downloads the audio, and puts them in the download_queue.
        """
        logger.info("Download worker thread started")
        while not self.stop_event.is_set():
            try:
                job = self.job_queue.get(timeout=1)
            except Empty:
                continue
            try:
                self._prepare_job(job)
            except Exception as e:
                logger.error(f"Download worker error: {e}")
            finally:
                # Exactly one task_done() per get(), on every path.
                self.job_queue.task_done()
        logger.info("Download worker thread stopped")

    def _prepare_job(self, job: dict) -> None:
        """Turn one fetched job into a DownloadedJob on ``download_queue``.

        Text-embedding jobs are forwarded without a download. Audio jobs first
        wait for room in ``download_queue`` (back-pressure), then download.
        Jobs are dropped with a log message when ``download_url`` is missing,
        the download fails, or shutdown is requested.
        """
        task_id = job["task_id"]
        # Read before downloading so a malformed job cannot orphan a temp file.
        track_id = job["track_id"]
        task_type = job.get("task_type", "compute_audio_embedding")

        if task_type == "compute_text_embedding":
            text_job = DownloadedJob(task_id=task_id, track_id=track_id, audio_file=None, original_job=job)
            if self._enqueue_downloaded(text_job):
                logger.info(f"Queued text search job {task_id} for processing.")
            return

        download_url = job.get("download_url")
        if not download_url:
            logger.error(f"Job {task_id} is missing download_url.")
            return

        # Wait for room before downloading to bound the number of temp files on
        # disk. (Pushing the job back into job_queue instead can block forever
        # once the fetcher has refilled that queue.)
        if not self._wait_for_download_slot():
            return

        full_url = f"{self.server_url}{download_url}"
        logger.info(f"Downloading audio for job {task_id} from {full_url}")
        audio_file = self.download_audio_file(full_url)
        if audio_file is None:
            logger.error(f"Failed to download audio for job {task_id}. Job discarded.")
            return

        audio_job = DownloadedJob(task_id=task_id, track_id=track_id, audio_file=audio_file, original_job=job)
        if self._enqueue_downloaded(audio_job):
            logger.info(f"Queued audio job {task_id} for processing.")

    def _wait_for_download_slot(self) -> bool:
        """Block until ``download_queue`` has room; return False if shutting down."""
        while self.download_queue.full():
            if self.stop_event.wait(self._QUEUE_WAIT_S):
                return False
        return not self.stop_event.is_set()

    def _enqueue_downloaded(self, downloaded_job: "DownloadedJob") -> bool:
        """Put *downloaded_job* on ``download_queue``, giving up on shutdown.

        Returns True once queued. If shutdown is requested first, the job is
        dropped, its temp file is deleted and False is returned.
        """
        from queue import Full

        while not self.stop_event.is_set():
            try:
                self.download_queue.put(downloaded_job, timeout=self._QUEUE_WAIT_S)
                return True
            except Full:
                continue
        self._discard_audio_file(downloaded_job.audio_file)
        return False

    @staticmethod
    def _discard_audio_file(audio_file: Optional[Path]) -> None:
        """Delete a downloaded temp file; a file that is already gone is not an error."""
        if not audio_file:
            return
        try:
            os.unlink(audio_file)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.error(f"Error deleting temp file {audio_file}: {e}")

    def _start_workers(self):
        """Start job fetcher and download worker threads."""
        self.stop_event.clear()

        self.job_fetcher_thread = threading.Thread(target=self._job_fetcher, daemon=True)
        self.job_fetcher_thread.start()

        for _ in range(self.download_workers):
            thread = threading.Thread(target=self._download_worker, daemon=True)
            thread.start()
            self.download_threads.append(thread)
        logger.info(f"Started 1 job fetcher and {self.download_workers} download workers.")

    def _stop_workers(self):
        """Signal all worker threads to stop, wait briefly for them, and drop queued downloads.

        Each thread is joined with a 5 s timeout. Temp files of jobs still
        waiting in ``download_queue`` are deleted.
        """
        logger.info("Stopping all worker threads...")
        self.stop_event.set()

        if self.job_fetcher_thread:
            self.job_fetcher_thread.join(timeout=5)
            self.job_fetcher_thread = None

        for thread in self.download_threads:
            thread.join(timeout=5)
        # Forget the old threads so a later _start_workers() starts a fresh set.
        self.download_threads.clear()

        # Drain every pending job; one failed delete must not leave the rest behind.
        while True:
            try:
                job = self.download_queue.get_nowait()
            except Empty:
                break
            self._discard_audio_file(job.audio_file)
        logger.info("All worker threads stopped.")

    def submit_result(self, task_id: str, track_id: str, embedding: Optional[List[float]],
                      error_message: Optional[str] = None) -> bool:
        """Submit task result to server.

        The status sent is "success" when *embedding* is not None, otherwise
        "failed". Returns the server's ``success`` flag, or False on any HTTP,
        network or JSON error.
        """
        status = "success" if (embedding is not None) else "failed"
        try:
            response = requests.post(
                f"{self.server_url}/workers/submit_result",
                json={
                    "task_id": task_id,
                    "track_id": track_id,
                    "status": status,
                    "embedding": embedding,
                    "error_message": error_message
                },
                timeout=30
            )
            response.raise_for_status()
            return response.json().get("success", False)
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error(f"Error submitting result for task {task_id}: {e}")
            return False

    def _process_batch(self, batch: List["DownloadedJob"]) -> None:
        """Process a batch of jobs to improve GPU utilization.

        Jobs are grouped by ``task_type`` (default ``compute_audio_embedding``).
        Jobs of any other type are reported to the server as failed and their
        temp files removed, instead of being dropped without an answer.
        """
        if not batch:
            return

        logger.info(f"Processing batch of {len(batch)} jobs")

        # Separate jobs by type for more efficient batching
        audio_jobs: List["DownloadedJob"] = []
        text_jobs: List["DownloadedJob"] = []
        for job in batch:
            task_type = job.original_job.get("task_type", "compute_audio_embedding")
            if task_type == "compute_audio_embedding":
                audio_jobs.append(job)
            elif task_type == "compute_text_embedding":
                text_jobs.append(job)
            else:
                logger.warning(f"Unsupported task type {task_type!r} for job {job.task_id}")
                self.submit_result(job.task_id, job.track_id, None, f"Unsupported task type: {task_type}")
                self._discard_audio_file(job.audio_file)

        if audio_jobs:
            self._process_audio_batch(audio_jobs)
        if text_jobs:
            self._process_text_batch(text_jobs)

    def _process_audio_batch(self, audio_jobs: List["DownloadedJob"]) -> None:
        """Process a batch of audio embedding jobs."""
        self._process_audio_batch_gpu(audio_jobs)

    def _process_audio_batch_gpu(self, audio_jobs: List["DownloadedJob"]) -> None:
        """Process audio jobs using GPU batch processing.

        Jobs whose temp file is missing are reported as failed up front. If the
        batch embedding call raises, every remaining job is reported as failed.
        The temp files of all *audio_jobs* are deleted afterwards.
        """
        # Prepare batch data
        audio_files: List[Path] = []
        ready_jobs: List["DownloadedJob"] = []
        for job in audio_jobs:
            if job.audio_file and job.audio_file.exists():
                audio_files.append(job.audio_file)
                ready_jobs.append(job)
            else:
                # Handle jobs with missing files individually
                self.submit_result(job.task_id, job.track_id, None, "Audio file not available")

        if not ready_jobs:
            return

        try:
            try:
                embeddings = self.clap_embedding_generator.generate_embedding_batch(audio_files)
            except Exception as e:
                logger.error(f"Batch processing failed: {e}", exc_info=True)
                embeddings = None  # every job below is then reported as failed
            self._submit_batch_results(ready_jobs, embeddings, "Failed to compute audio embedding", "batch job")
        except Exception as e:
            logger.error(f"Batch processing failed: {e}", exc_info=True)
        finally:
            # Clean up audio files
            for job in audio_jobs:
                self._discard_audio_file(job.audio_file)

            # Force garbage collection after batch processing
            collected = gc.collect()
            if collected > 0:
                logger.debug(f"Post-batch cleanup: collected {collected} objects")

    def _process_text_batch(self, text_jobs: List["DownloadedJob"]) -> None:
        """Process a batch of text embedding jobs."""
        self._process_text_batch_gpu(text_jobs)

    def _process_text_batch_gpu(self, text_jobs: List["DownloadedJob"]) -> None:
        """Process text jobs using GPU batch processing.

        Jobs without a ``text_query`` are reported as failed up front. If the
        batch embedding call raises, every remaining job is reported as failed.
        """
        # Prepare batch data
        text_queries: List[str] = []
        ready_jobs: List["DownloadedJob"] = []
        for job in text_jobs:
            text_query = job.original_job.get("text_query")
            if text_query:
                text_queries.append(text_query)
                ready_jobs.append(job)
            else:
                # Handle jobs with missing text individually
                self.submit_result(job.task_id, job.track_id, None, "Missing text query")

        if not ready_jobs:
            return

        try:
            try:
                embeddings = self.clap_embedding_generator.generate_text_embedding_batch(text_queries)
            except Exception as e:
                logger.error(f"Text batch processing failed: {e}", exc_info=True)
                embeddings = None  # every job below is then reported as failed
            self._submit_batch_results(ready_jobs, embeddings, "Failed to compute text embedding", "batch text job")
        except Exception as e:
            logger.error(f"Text batch processing failed: {e}", exc_info=True)

    def _submit_batch_results(self, jobs: List["DownloadedJob"], embeddings,
                              error_message: str, label: str) -> None:
        """Submit exactly one result per job, pairing jobs and *embeddings* by position.

        A missing (sequence too short), ``None`` or empty embedding is reported
        as a failure carrying *error_message*. *label* names the job kind in
        log messages (e.g. "batch job").
        """
        results = list(embeddings) if embeddings is not None else []
        if results and len(results) != len(jobs):
            logger.warning(
                f"Got {len(results)} embeddings for {len(jobs)} jobs; unmatched jobs are reported as failed"
            )

        for index, job in enumerate(jobs):
            embedding = results[index] if index < len(results) else None
            # Normalise empty embeddings to None so submit_result's status
            # ("failed") agrees with the error message being sent.
            if not embedding:
                embedding = None
            success = self.submit_result(job.task_id, job.track_id, embedding,
                                         None if embedding is not None else error_message)
            if success:
                logger.debug(f"Successfully submitted {label} {job.task_id}")
            else:
                logger.warning(f"Failed to submit {label} {job.task_id}")

    def _collect_batch(self, max_size: int, timeout: float = 0.5) -> List["DownloadedJob"]:
        """Return up to *max_size* downloaded jobs without waiting for a full batch.

        Blocks at most *timeout* seconds for the first job, then takes whatever
        else is already queued. Returns an empty list if nothing arrived.
        """
        batch = []
        try:
            batch.append(self.download_queue.get(timeout=timeout))
            while len(batch) < max_size:
                batch.append(self.download_queue.get_nowait())
        except Empty:
            pass
        return batch

    def run(self):
        """Main worker loop with batch processing for better GPU utilization.

        Registers with the server, starts the fetcher/download threads, then
        repeatedly processes whatever downloaded jobs are available (up to
        ``gpu_batch_size`` at a time) until interrupted. Always stops the
        threads and unloads the model on exit.
        """
        logger.info("Starting Mycelium client worker loop...")

        if not self.register_with_server():
            logger.error("Failed to register with server. Exiting.")
            return

        self._start_workers()
        self._log_queue_status("worker started")

        last_status_log = time.time()
        status_log_interval = 30

        try:
            while True:
                # Re-read every pass so a hot-reloaded gpu_batch_size takes effect;
                # clamp to 1 so a zero/negative setting cannot stall the loop.
                gpu_batch_size = max(1, self.config.client.gpu_batch_size)
                # Process what is available instead of waiting for a full batch,
                # otherwise trailing jobs (and text searches) wait indefinitely.
                batch = self._collect_batch(gpu_batch_size)

                if batch:
                    try:
                        self._process_batch(batch)
                    finally:
                        # Mark all jobs as done, even if processing raised.
                        for _ in batch:
                            self.download_queue.task_done()

                    if time.time() - last_status_log > status_log_interval:
                        self._log_queue_status("processing")
                        last_status_log = time.time()
                    logger.info(f"Processed batch of {len(batch)} jobs")
                else:
                    # No jobs available
                    if self.stop_event.is_set():
                        break
                    if time.time() - last_status_log > status_log_interval:
                        self._log_queue_status("idle")
                        last_status_log = time.time()

                self._check_config_reload()

        except KeyboardInterrupt:
            logger.info("\nShutting down worker...")
        finally:
            self._log_queue_status("shutdown")
            self._stop_workers()
            self.clap_embedding_generator.unload_model()
            logger.info("Worker stopped")
|
549
|
+
|
550
|
+
|
551
|
+
def run_client():
    """Build a client from its on-disk configuration and run its worker loop."""
    MyceliumClient().run()
|