mycelium_ai-0.5.0-py3-none-any.whl
This diff shows the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- mycelium/__init__.py +0 -0
- mycelium/api/__init__.py +0 -0
- mycelium/api/app.py +1147 -0
- mycelium/api/client_app.py +170 -0
- mycelium/api/generated_sources/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/__init__.py +97 -0
- mycelium/api/generated_sources/server_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/server_schemas/api/default_api.py +2473 -0
- mycelium/api/generated_sources/server_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/server_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/server_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/server_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/server_schemas/models/__init__.py +41 -0
- mycelium/api/generated_sources/server_schemas/models/api_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/chroma_section.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/clap_section.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server200_response.py +79 -0
- mycelium/api/generated_sources/server_schemas/models/compute_on_server_request.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/compute_text_search_request.py +69 -0
- mycelium/api/generated_sources/server_schemas/models/config_request.py +81 -0
- mycelium/api/generated_sources/server_schemas/models/config_response.py +107 -0
- mycelium/api/generated_sources/server_schemas/models/create_playlist_request.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/get_similar_by_track200_response.py +143 -0
- mycelium/api/generated_sources/server_schemas/models/library_stats_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/logging_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/media_server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/playlist_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/plex_section.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/processing_response.py +90 -0
- mycelium/api/generated_sources/server_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/server_schemas/models/scan_library_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/search_result_response.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/server_section.py +67 -0
- mycelium/api/generated_sources/server_schemas/models/stop_processing_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/models/task_status_response.py +87 -0
- mycelium/api/generated_sources/server_schemas/models/track_database_stats.py +75 -0
- mycelium/api/generated_sources/server_schemas/models/track_response.py +77 -0
- mycelium/api/generated_sources/server_schemas/models/tracks_list_response.py +81 -0
- mycelium/api/generated_sources/server_schemas/rest.py +329 -0
- mycelium/api/generated_sources/server_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/server_schemas/test/test_api_section.py +57 -0
- mycelium/api/generated_sources/server_schemas/test/test_chroma_section.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_clap_section.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server200_response.py +52 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_on_server_request.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_compute_text_search_request.py +54 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_request.py +66 -0
- mycelium/api/generated_sources/server_schemas/test/test_config_response.py +97 -0
- mycelium/api/generated_sources/server_schemas/test/test_create_playlist_request.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_default_api.py +150 -0
- mycelium/api/generated_sources/server_schemas/test/test_get_similar_by_track200_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_library_stats_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_logging_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_media_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_playlist_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_plex_section.py +56 -0
- mycelium/api/generated_sources/server_schemas/test/test_processing_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/server_schemas/test/test_scan_library_response.py +61 -0
- mycelium/api/generated_sources/server_schemas/test/test_search_result_response.py +69 -0
- mycelium/api/generated_sources/server_schemas/test/test_server_section.py +53 -0
- mycelium/api/generated_sources/server_schemas/test/test_stop_processing_response.py +55 -0
- mycelium/api/generated_sources/server_schemas/test/test_task_status_response.py +71 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_database_stats.py +60 -0
- mycelium/api/generated_sources/server_schemas/test/test_track_response.py +63 -0
- mycelium/api/generated_sources/server_schemas/test/test_tracks_list_response.py +75 -0
- mycelium/api/generated_sources/worker_schemas/__init__.py +61 -0
- mycelium/api/generated_sources/worker_schemas/api/__init__.py +5 -0
- mycelium/api/generated_sources/worker_schemas/api/default_api.py +318 -0
- mycelium/api/generated_sources/worker_schemas/api_client.py +766 -0
- mycelium/api/generated_sources/worker_schemas/api_response.py +25 -0
- mycelium/api/generated_sources/worker_schemas/configuration.py +434 -0
- mycelium/api/generated_sources/worker_schemas/exceptions.py +166 -0
- mycelium/api/generated_sources/worker_schemas/models/__init__.py +23 -0
- mycelium/api/generated_sources/worker_schemas/models/save_config_response.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_clap_section.py +75 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_api_section.py +69 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_client_section.py +79 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_request.py +73 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/models/worker_logging_section.py +67 -0
- mycelium/api/generated_sources/worker_schemas/rest.py +329 -0
- mycelium/api/generated_sources/worker_schemas/test/__init__.py +0 -0
- mycelium/api/generated_sources/worker_schemas/test/test_default_api.py +45 -0
- mycelium/api/generated_sources/worker_schemas/test/test_save_config_response.py +58 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_clap_section.py +60 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_api_section.py +55 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_client_section.py +65 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_request.py +59 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_config_response.py +89 -0
- mycelium/api/generated_sources/worker_schemas/test/test_worker_logging_section.py +53 -0
- mycelium/api/worker_models.py +99 -0
- mycelium/application/__init__.py +11 -0
- mycelium/application/job_queue.py +323 -0
- mycelium/application/library_management_use_cases.py +292 -0
- mycelium/application/search_use_cases.py +96 -0
- mycelium/application/services.py +340 -0
- mycelium/client.py +554 -0
- mycelium/client_config.py +251 -0
- mycelium/client_frontend_dist/404.html +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_buildManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_ssgManifest.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/app/page-cc6bad295789134e.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/client_frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/client_frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/client_frontend_dist/favicon.ico +0 -0
- mycelium/client_frontend_dist/file.svg +1 -0
- mycelium/client_frontend_dist/globe.svg +1 -0
- mycelium/client_frontend_dist/index.html +1 -0
- mycelium/client_frontend_dist/index.txt +20 -0
- mycelium/client_frontend_dist/next.svg +1 -0
- mycelium/client_frontend_dist/vercel.svg +1 -0
- mycelium/client_frontend_dist/window.svg +1 -0
- mycelium/config.py +346 -0
- mycelium/domain/__init__.py +13 -0
- mycelium/domain/models.py +71 -0
- mycelium/domain/repositories.py +98 -0
- mycelium/domain/worker.py +77 -0
- mycelium/frontend_dist/404.html +1 -0
- mycelium/frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/app/page-a761463485e0540b.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- mycelium/frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
- mycelium/frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_buildManifest.js +1 -0
- mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_ssgManifest.js +1 -0
- mycelium/frontend_dist/favicon.ico +0 -0
- mycelium/frontend_dist/file.svg +1 -0
- mycelium/frontend_dist/globe.svg +1 -0
- mycelium/frontend_dist/index.html +10 -0
- mycelium/frontend_dist/index.txt +20 -0
- mycelium/frontend_dist/next.svg +1 -0
- mycelium/frontend_dist/vercel.svg +1 -0
- mycelium/frontend_dist/window.svg +1 -0
- mycelium/infrastructure/__init__.py +17 -0
- mycelium/infrastructure/chroma_adapter.py +232 -0
- mycelium/infrastructure/clap_adapter.py +280 -0
- mycelium/infrastructure/plex_adapter.py +145 -0
- mycelium/infrastructure/track_database.py +467 -0
- mycelium/main.py +183 -0
- mycelium_ai-0.5.0.dist-info/METADATA +312 -0
- mycelium_ai-0.5.0.dist-info/RECORD +164 -0
- mycelium_ai-0.5.0.dist-info/WHEEL +5 -0
- mycelium_ai-0.5.0.dist-info/entry_points.txt +2 -0
- mycelium_ai-0.5.0.dist-info/licenses/LICENSE +21 -0
- mycelium_ai-0.5.0.dist-info/top_level.txt +1 -0
mycelium/infrastructure/chroma_adapter.py
@@ -0,0 +1,232 @@
+"""ChromaDB integration for storing and searching embeddings."""
+
+import logging
+import re
+from pathlib import Path
+from typing import List, Optional
+
+import chromadb
+from tqdm import tqdm
+
+from ..domain.models import Track, TrackEmbedding, SearchResult, MediaServerType
+from ..domain.repositories import EmbeddingRepository
+
+logger = logging.getLogger(__name__)
+
+
+class ChromaEmbeddingRepository(EmbeddingRepository):
+    """Implementation of EmbeddingRepository using ChromaDB with model-specific collections."""
+
+    def __init__(
+            self,
+            db_path: str,
+            media_server_type: MediaServerType,
+            collection_name: str = "my_music_library",
+            model_id: str = "laion/larger_clap_music_and_speech",
+            batch_size: int = 1000,
+    ):
+        self.db_path = db_path
+        self.base_collection_name = collection_name
+        self.model_id = model_id
+        self.batch_size = batch_size
+        self.media_server_type = media_server_type
+
+        # Initialize ChromaDB client
+        try:
+            Path(db_path).mkdir(parents=True, exist_ok=True)
+        except Exception:
+            logger.error(f"Failed to create database directory at {db_path}. Please check permissions.")
+        self.client = chromadb.PersistentClient(path=db_path)
+
+        # Create model-specific collection name
+        self.collection_name = self._get_collection_name_for_model(model_id)
+
+        # Specify 'cosine' distance metric for normalized embeddings
+        self.collection = self.client.get_or_create_collection(
+            name=self.collection_name,
+            metadata={"hnsw:space": "cosine", "model_id": model_id}
+        )
+
+        logger.info(
+            f"Collection '{self.collection_name}' ready for model '{model_id}'. Current elements: {self.collection.count()}")
+
+    def _get_collection_name_for_model(self, model_id: str) -> str:
+        """Generate a safe collection name for the given model ID."""
+        # Make model ID safe for collection name (alphanumeric and underscores only)
+        safe_model_id = re.sub(r'\W', '_', model_id.replace('/', '_'))
+        return f"{self.base_collection_name}_{safe_model_id}"
+
+    def save_embeddings(self, embeddings: List[TrackEmbedding]) -> None:
+        """Save track embeddings to ChromaDB."""
+        if not embeddings:
+            return
+
+        # Prepare data for batch insertion
+        ids = []
+        embedding_vectors = []
+        metadatas = []
+
+        for track_embedding in embeddings:
+            track = track_embedding.track
+            ids.append(track.unique_id)
+            embedding_vectors.append(track_embedding.embedding)
+            metadatas.append({
+                "filepath": str(track.filepath),
+                "artist": track.artist,
+                "album": track.album,
+                "title": track.title,
+                "media_server_type": track.media_server_type.value,
+                "media_server_rating_key": track.media_server_rating_key,
+                "model_id": self.model_id
+            })
+
+        # Insert in batches for maximum efficiency
+        for i in tqdm(range(0, len(ids), self.batch_size), desc="Indexing in ChromaDB"):
+            end_idx = min(i + self.batch_size, len(ids))
+            id_batch = ids[i:end_idx]
+            embedding_batch = embedding_vectors[i:end_idx]
+            metadata_batch = metadatas[i:end_idx]
+
+            self.collection.add(
+                ids=id_batch,
+                embeddings=embedding_batch,
+                metadatas=metadata_batch
+            )
+
+        logger.info("Indexing completed!")
+        logger.info(f"Total elements in collection '{self.collection_name}': {self.collection.count()}")
+
+    def search_by_embedding(self, embedding: List[float], n_results: int = 10) -> List[SearchResult]:
+        """Search for similar tracks by embedding."""
+        results = self.collection.query(
+            query_embeddings=[embedding],
+            n_results=n_results
+        )
+
+        return self._parse_search_results(results.copy())
+
+    def get_embedding_count(self) -> int:
+        """Get the total number of embeddings stored."""
+        return self.collection.count()
+
+    @staticmethod
+    def _parse_search_results(results: dict) -> List[SearchResult]:
+        """Parse ChromaDB results into SearchResult objects."""
+        search_results = []
+
+        if not results['ids'] or not results['ids'][0]:
+            return search_results
+
+        for i in range(len(results['ids'][0])):
+            metadata = results['metadatas'][0][i]
+            distance = results['distances'][0][i]
+            unique_id = results['ids'][0][i]
+
+            # Parse unique_id to get media server info
+            media_server_type_str, media_server_rating_key = unique_id.split(':', 1)
+
+            from ..domain.models import MediaServerType
+            try:
+                media_server_type = MediaServerType(media_server_type_str)
+            except ValueError:
+                media_server_type = MediaServerType.PLEX  # Default fallback
+
+            track = Track(
+                artist=metadata['artist'],
+                album=metadata['album'],
+                title=metadata['title'],
+                filepath=Path(metadata['filepath']),
+                media_server_rating_key=media_server_rating_key,
+                media_server_type=media_server_type
+            )
+
+            # Convert distance to similarity score (1 - distance for cosine)
+            similarity_score = 1.0 - distance
+
+            search_results.append(SearchResult(
+                track=track,
+                similarity_score=similarity_score,
+                distance=distance
+            ))
+
+        return search_results
+
+    def has_embedding(self, track_id: str) -> bool:
+        """Check if an embedding exists for a track."""
+        track_id = Track(media_server_type=self.media_server_type,
+                         media_server_rating_key=track_id).unique_id
+        logger.debug(
+            f"Checking embedding for track {track_id}: collection_name={self.collection_name}, model_id={self.model_id}"
+        )
+        try:
+            result = self.collection.get(ids=[track_id])
+            exists = len(result['ids']) > 0
+            logger.debug(f"Checking embedding for track {track_id}: exists={exists}")
+            return exists
+        except Exception as e:
+            logger.error(f"Error checking embedding for track {track_id}: {e}")
+            return False
+
+    def save_embedding(self, track_embedding: TrackEmbedding) -> None:
+        """Save a single track embedding to ChromaDB."""
+        track = track_embedding.track
+        track_id = track.unique_id
+
+        logger.info(f"Saving embedding to ChromaDB for track {track_id}: {track.artist} - {track.title}")
+        logger.info(f"Collection count before save: {self.collection.count()}")
+
+        # Check if embedding already exists, if so, update it
+        existing = self.collection.get(ids=[track_id])
+        if existing['ids']:
+            logger.info(f"Updating existing embedding for track {track_id}")
+            self.collection.update(
+                ids=[track_id],
+                embeddings=[track_embedding.embedding],
+                metadatas=[{
+                    "filepath": str(track.filepath),
+                    "artist": track.artist,
+                    "album": track.album,
+                    "title": track.title,
+                    "media_server_type": track.media_server_type.value,
+                    "media_server_rating_key": track.media_server_rating_key,
+                    "model_id": self.model_id
+                }]
+            )
+        else:
+            logger.info(f"Adding new embedding for track {track_id}")
+            self.collection.add(
+                ids=[track_id],
+                embeddings=[track_embedding.embedding],
+                metadatas=[{
+                    "filepath": str(track.filepath),
+                    "artist": track.artist,
+                    "album": track.album,
+                    "title": track.title,
+                    "media_server_type": track.media_server_type.value,
+                    "media_server_rating_key": track.media_server_rating_key,
+                    "model_id": self.model_id
+                }]
+            )
+
+        logger.info(f"Collection count after save: {self.collection.count()}")
+        logger.info(f"Successfully saved embedding to ChromaDB for track {track_id}")
+
+    def get_embedding_by_track_id(self, track_id: str) -> Optional[List[float]]:
+        """Get embedding for a specific track."""
+        track_id = Track(media_server_type=self.media_server_type, media_server_rating_key=track_id).unique_id
+        try:
+            result = self.collection.get(
+                ids=[track_id],
+                include=['embeddings']
+            )
+            if result['embeddings'] is not None and len(result['embeddings']) > 0:
+                embedding = result['embeddings'][0]
+                logger.debug(
+                    f"Retrieved embedding for track {track_id}, size: {len(embedding) if embedding is not None else 0}")
+                return embedding
+            else:
+                logger.debug(f"No embedding found in ChromaDB for track {track_id}")
+                return None
+        except Exception as e:
+            logger.error(f"Error retrieving embedding for track {track_id}: {e}")
+            return None
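
The repository keys every record by Track.unique_id, which _parse_search_results evidently encodes as media_server_type:rating_key, and it relies on the collection's cosine metric, so callers are expected to store L2-normalized vectors. Below is a minimal usage sketch, not part of the package diff; the Track and TrackEmbedding constructor fields mirror the metadata the adapter writes, and the 512 embedding dimension is a placeholder.

# Hypothetical usage sketch; field names mirror the metadata written by
# save_embeddings() and may not match the real Track/TrackEmbedding models.
from pathlib import Path

from mycelium.domain.models import MediaServerType, Track, TrackEmbedding
from mycelium.infrastructure.chroma_adapter import ChromaEmbeddingRepository

repo = ChromaEmbeddingRepository(
    db_path="./chroma_db",                    # placeholder path
    media_server_type=MediaServerType.PLEX,
)

track = Track(
    artist="Some Artist",
    album="Some Album",
    title="Some Title",
    filepath=Path("/music/some_track.flac"),  # placeholder path
    media_server_rating_key="12345",          # placeholder rating key
    media_server_type=MediaServerType.PLEX,
)

# 512 is a placeholder dimension; the real size comes from the CLAP model.
repo.save_embeddings([TrackEmbedding(track=track, embedding=[0.0] * 512)])

for result in repo.search_by_embedding([0.0] * 512, n_results=5):
    print(result.track.title, result.similarity_score)
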
mycelium/infrastructure/clap_adapter.py
@@ -0,0 +1,280 @@
+"""CLAP model integration for generating embeddings."""
+
+import logging
+import random
+from pathlib import Path
+from typing import List, Optional
+
+import librosa
+import torch
+from transformers import ClapModel, ClapProcessor
+
+from ..domain.repositories import EmbeddingGenerator
+
+
+class CLAPEmbeddingGenerator(EmbeddingGenerator):
+    """ Implementation of EmbeddingGenerator using LAION's CLAP model. """
+
+    def __init__(
+            self,
+            model_id: str = "laion/larger_clap_music_and_speech",
+            target_sr: int = 48000,
+            chunk_duration_s: int = 10,
+            num_chunks: int = 3,
+            max_load_duration_s: Optional[int] = 120
+    ):
+        self.model_id = model_id
+        self.target_sr = target_sr
+        self.chunk_duration_s = chunk_duration_s
+        self.num_chunks = num_chunks
+        self.max_load_duration_s = max_load_duration_s
+        self.logger = logging.getLogger(__name__)
+
+        self.device = self.get_best_device()
+        self.logger.info(f"Selected device: {self.device}")
+
+        # Lazy loading. Model is not loaded on instantiation.
+        self.model: Optional[ClapModel] = None
+        self.processor: Optional[ClapProcessor] = None
+
+        self.use_half = self.can_use_half_precision()
+        if self.use_half:
+            self.logger.info("Half precision (FP16) is supported and will be used.")
+        else:
+            self.logger.info("Half precision not supported, using full precision (FP32).")
+
+    def _load_model_if_needed(self):
+        """Loads the model and processor on the first call that needs them."""
+        if self.model is None or self.processor is None:
+            self.logger.info(f"Loading model '{self.model_id}' to device '{self.device}'...")
+
+            self.model = ClapModel.from_pretrained(self.model_id).to(self.device)
+            self.processor = ClapProcessor.from_pretrained(self.model_id)
+
+            if self.use_half and self.device == "cuda":
+                self.logger.info("Applying half precision (FP16) to model for CUDA device.")
+                self.model.half()
+            elif self.use_half and self.device == "mps":
+                self.logger.warning(
+                    "Half precision is supported but disabled on MPS device to prevent potential crashes. Using FP32.")
+
+            self.model.eval()
+            self.logger.info("Model loaded successfully.")
+            try:
+                self.logger.info(f"Model dtype after load: {next(self.model.parameters()).dtype}")
+            except StopIteration:
+                self.logger.debug("Could not determine model dtype (no parameters found).")
+
+    @staticmethod
+    def get_best_device() -> str:
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+
+    def can_use_half_precision(self) -> bool:
+        """Checks once if the device supports half precision."""
+        if self.device == "cuda":
+            # Most modern CUDA devices support FP16.
+            return True
+        if self.device == "mps":
+            # Check for potential runtime errors on some MPS devices.
+            try:
+                torch.tensor([1.0], dtype=torch.half).to(self.device)
+                return True
+            except RuntimeError:
+                self.logger.warning("MPS device does not support half precision, falling back to FP32.")
+                return False
+        return False
+
+    def _get_processor(self) -> ClapProcessor:
+        """Return a ready-to-use processor with a non-optional type."""
+        self._load_model_if_needed()
+        assert self.processor is not None
+        return self.processor
+
+    def _get_model(self) -> ClapModel:
+        """Return a ready-to-use model with a non-optional type."""
+        self._load_model_if_needed()
+        assert self.model is not None
+        return self.model
+
+    def _prepare_inputs(self, inputs: dict) -> dict:
+        """Move inputs to the correct device and cast floating tensors to the model's dtype.
+
+        This prevents dtype mismatches when the model runs in half precision on CUDA.
+        """
+        model = self._get_model()
+        # Determine model parameter dtype (e.g., torch.float32 or torch.float16)
+        model_dtype = next(model.parameters()).dtype
+        prepared = {}
+        for k, v in inputs.items():
+            if isinstance(v, torch.Tensor):
+                if v.is_floating_point():
+                    prepared[k] = v.to(device=self.device, dtype=model_dtype)
+                else:
+                    prepared[k] = v.to(device=self.device)
+            else:
+                prepared[k] = v
+        return prepared
+
+    def generate_embedding(self, filepath: Path) -> Optional[List[float]]:
+        """Generate embedding for a single audio file by delegating to batch method."""
+        results = self.generate_embedding_batch([filepath])
+        return results[0] if results else None
+
+    def generate_embedding_batch(self, filepaths: List[Path]) -> List[Optional[List[float]]]:
+        """Generate embeddings for multiple audio files in a single GPU batch"""
+        if not filepaths:
+            return []
+
+        try:
+            processor = self._get_processor()
+            model = self._get_model()
+
+            all_chunks = []
+            file_chunk_counts = []
+
+            chunk_size_samples = self.chunk_duration_s * self.target_sr
+
+            # Load and prepare all audio files
+            for filepath in filepaths:
+                try:
+                    waveform, _ = librosa.load(
+                        str(filepath),
+                        sr=self.target_sr,
+                        mono=True,
+                        duration=self.max_load_duration_s
+                    )
+
+                    total_samples = len(waveform)
+                    chunks = []
+
+                    # Calculate how many full, non-overlapping chunks can fit.
+                    num_possible_bins = total_samples // chunk_size_samples
+
+                    if num_possible_bins == 0:
+                        self.logger.warning(
+                            f"File {filepath} is too short ({total_samples / self.target_sr:.1f}s) "
+                            f"for even one chunk of {self.chunk_duration_s:.1f}s.")
+                        file_chunk_counts.append(0)
+                        continue
+
+                    # Determine which bin indices to sample from.
+                    if num_possible_bins < self.num_chunks:
+                        self.logger.warning(
+                            f"File {filepath} only has space for {num_possible_bins} non-overlapping chunks, "
+                            f"less than the requested {self.num_chunks}. Using all available chunks."
+                        )
+                        chosen_bin_indices = range(num_possible_bins)
+                    else:
+                        possible_bin_indices = range(num_possible_bins)
+                        chosen_bin_indices = random.sample(possible_bin_indices, k=self.num_chunks)
+
+                    # Create chunks based on the chosen indices.
+                    for bin_index in chosen_bin_indices:
+                        start_idx = bin_index * chunk_size_samples
+                        end_idx = start_idx + chunk_size_samples
+                        chunk = waveform[start_idx:end_idx]
+                        chunks.append(chunk)
+
+                    if not chunks:
+                        self.logger.warning(f"No valid chunks generated for {filepath}.")
+                        file_chunk_counts.append(0)
+                        continue
+
+                    all_chunks.extend(chunks)
+                    file_chunk_counts.append(len(chunks))
+
+                except Exception as e:
+                    self.logger.error(f"Error loading audio file {filepath}: {e}")
+                    file_chunk_counts.append(0)
+
+            if not all_chunks:
+                return [None] * len(filepaths)
+
+            # Process all chunks in a single batch
+            inputs = processor(
+                audios=all_chunks,
+                sampling_rate=self.target_sr,
+                return_tensors="pt",
+                padding=True
+            )
+
+            inputs = self._prepare_inputs(inputs)
+
+            with torch.no_grad():
+                audio_features = model.get_audio_features(**inputs)
+
+            # Split results back to individual files and compute mean embeddings
+            results = []
+            chunk_idx = 0
+
+            for chunk_count in file_chunk_counts:
+                if chunk_count == 0:
+                    results.append(None)
+                else:
+                    file_features = audio_features[chunk_idx:chunk_idx + chunk_count]
+                    mean_embedding = torch.mean(file_features, dim=0)
+                    normalized_embedding = torch.nn.functional.normalize(mean_embedding, p=2, dim=0)
+                    results.append(normalized_embedding.cpu().numpy().tolist())
+                    chunk_idx += chunk_count
+
+            self.logger.info(
+                f"Successfully processed batch of {len(filepaths)} audio files ({len(all_chunks)} total chunks)")
+            return results
+
+        except Exception as e:
+            self.logger.error(f"Error in batch audio embedding generation: {e}", exc_info=True)
+            return [None] * len(filepaths)
+
+    def generate_text_embedding(self, text: str) -> Optional[List[float]]:
+        """Generate embedding for a single text query by delegating to batch method."""
+        results = self.generate_text_embedding_batch([text])
+        return results[0] if results else None
+
+    def generate_text_embedding_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
+        """Generate embeddings for multiple text queries in a single GPU batch for better utilization."""
+        if not texts:
+            return []
+
+        try:
+            processor = self._get_processor()
+            model = self._get_model()
+
+            inputs = processor(
+                text=texts,
+                return_tensors="pt",
+                padding=True
+            )
+
+            inputs = self._prepare_inputs(inputs)
+
+            with torch.no_grad():
+                text_features = model.get_text_features(**inputs)
+                text_embeddings = torch.nn.functional.normalize(text_features, p=2, dim=-1)
+
+            # Convert to list of lists
+            results = text_embeddings.cpu().numpy().tolist()
+            self.logger.info(f"Successfully processed batch of {len(texts)} text queries")
+            return results
+
+        except Exception as e:
+            self.logger.error(f"Error in batch text embedding generation: {e}", exc_info=True)
+            return [None] * len(texts)
+
+    def unload_model(self) -> None:
+        """Unload model to free GPU memory."""
+        if self.model is not None:
+            del self.model
+            del self.processor
+            self.model = None
+            self.processor = None
+
+            if self.device == "cuda":
+                torch.cuda.empty_cache()
+            elif self.device == "mps":
+                torch.mps.empty_cache()
+
+            self.logger.info("Model unloaded")
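
Both generate_embedding_batch and generate_text_embedding_batch return L2-normalized vectors, so a plain dot product between an audio embedding and a text embedding is already a cosine similarity. A minimal sketch of driving the generator, not part of the package diff; the audio path and query text are placeholders:

# Hypothetical usage sketch; the audio path and query text are placeholders.
from pathlib import Path

from mycelium.infrastructure.clap_adapter import CLAPEmbeddingGenerator

generator = CLAPEmbeddingGenerator(num_chunks=3, chunk_duration_s=10)

# The model loads lazily on the first call that needs it.
audio_vec = generator.generate_embedding(Path("/music/some_track.flac"))
text_vec = generator.generate_text_embedding("melancholic solo piano")

if audio_vec is not None and text_vec is not None:
    # Both vectors are L2-normalized, so the dot product is cosine similarity.
    similarity = sum(a * b for a, b in zip(audio_vec, text_vec))
    print(f"audio/text cosine similarity: {similarity:.3f}")

generator.unload_model()  # release GPU memory between jobs
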
mycelium/infrastructure/plex_adapter.py
@@ -0,0 +1,145 @@
+"""Plex integration for accessing music library."""
+
+import logging
+from pathlib import Path
+from typing import List, Optional
+from datetime import datetime
+
+from plexapi.audio import Artist
+from plexapi.server import PlexServer
+from tqdm import tqdm
+
+from ..domain.models import Track, Playlist, MediaServerType
+from ..domain.repositories import MediaServerRepository
+
+
+class PlexMusicRepository(MediaServerRepository):
+    """Implementation of MediaServerRepository for accessing Plex music library."""
+
+    def __init__(
+            self,
+            plex_url: str = None,
+            plex_token: str = None,
+            music_library_name: str = "Music"
+    ):
+        self.plex_url = plex_url
+        self.plex_token = plex_token
+        self.music_library_name = music_library_name
+        self.logger = logging.getLogger(__name__)
+
+    def get_all_tracks(self) -> List[Track]:
+        """Get all tracks from the Plex music library."""
+        try:
+            plex = PlexServer(self.plex_url, self.plex_token, timeout=3600)
+            music_lib = plex.library.section(self.music_library_name)
+            self.logger.info(f"Connected to Plex. Scanning library '{self.music_library_name}'...")
+        except Exception as e:
+            raise ConnectionError(f"Error connecting to Plex server: {e}")
+
+        all_tracks = []
+
+        # Hierarchical iteration for better robustness and memory efficiency
+        artists = music_lib.all(libtype='artist')
+        artists: List[Artist]
+        for artist in tqdm(artists, desc="Processing Artists"):
+            try:
+                for album in artist.albums():
+                    for track in album.tracks():
+                        for part in track.iterParts():
+                            filepath = Path(part.file)
+                            if filepath.exists():
+                                track_obj = Track(
+                                    artist=artist.title,
+                                    album=album.title,
+                                    title=track.title,
+                                    filepath=filepath,
+                                    media_server_rating_key=str(track.ratingKey),
+                                    media_server_type=MediaServerType.PLEX
+                                )
+                                all_tracks.append(track_obj)
+                            else:
+                                self.logger.warning(f"File not found, skipping: {filepath}")
+            except Exception as e:
+                self.logger.error(f"Error processing artist {artist.title}: {e}. Continuing...", exc_info=True)
+
+        return all_tracks
+
+    def get_track_by_id(self, track_id: str) -> Optional[Track]:
+        """Get a specific track by Plex rating key."""
+        try:
+            plex = PlexServer(self.plex_url, self.plex_token)
+            track = plex.fetchItem(int(track_id))
+
+            # Get the first available part of the track
+            for part in track.iterParts():
+                filepath = Path(part.file)
+                if filepath.exists():
+                    return Track(
+                        artist=track.grandparentTitle or "Unknown Artist",
+                        album=track.parentTitle or "Unknown Album",
+                        title=track.title,
+                        filepath=filepath,
+                        media_server_rating_key=str(track.ratingKey),
+                        media_server_type=MediaServerType.PLEX
+                    )
+
+            return None
+
+        except Exception as e:
+            self.logger.error(f"Error getting track {track_id}: {e}", exc_info=True)
+            return None
+
+    def create_playlist(self, playlist: Playlist, batch_size: int = 100) -> Playlist:
+        """Create a playlist on the Plex server using batch processing for large playlists.
+
+        Args:
+            playlist: The playlist to create
+            batch_size: Number of tracks to add per batch (default: 100)
+        """
+        try:
+            plex = PlexServer(self.plex_url, self.plex_token)
+
+            # Get Plex track objects for all tracks in the playlist
+            plex_tracks = []
+            for track in playlist.tracks:
+                try:
+                    plex_track = plex.fetchItem(int(track.media_server_rating_key))
+                    plex_tracks.append(plex_track)
+                except Exception as e:
+                    self.logger.warning(f"Could not fetch track {track.media_server_rating_key}: {e}")
+                    continue
+
+            if not plex_tracks:
+                raise ValueError("No valid tracks found for playlist creation")
+
+            total_tracks = len(plex_tracks)
+            self.logger.info(f"Creating playlist '{playlist.name}' with {total_tracks} tracks")
+
+            # Create the playlist with the first batch
+            first_batch = plex_tracks[:batch_size]
+            created_playlist = plex.createPlaylist(title=playlist.name, items=first_batch)
+            self.logger.info(f"Created playlist '{playlist.name}' with initial batch of {len(first_batch)} tracks")
+
+            # Add remaining tracks in batches
+            remaining_tracks = plex_tracks[batch_size:]
+            if remaining_tracks:
+                self.logger.info(f"Adding {len(remaining_tracks)} remaining tracks in batches of {batch_size}")
+
+                for i in range(0, len(remaining_tracks), batch_size):
+                    batch = remaining_tracks[i:i + batch_size]
+                    created_playlist.addItems(batch)
+                    self.logger.debug(f"Added batch {i//batch_size + 1}: {len(batch)} tracks")
+
+            self.logger.info(f"Successfully completed playlist creation with all {total_tracks} tracks")
+
+            # Return the playlist with server ID and creation time
+            return Playlist(
+                name=playlist.name,
+                tracks=playlist.tracks,
+                created_at=datetime.now(),
+                server_id=str(created_playlist.ratingKey)
+            )
+
+        except Exception as e:
+            self.logger.error(f"Error creating playlist '{playlist.name}': {e}", exc_info=True)
+            raise
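
A sketch of exercising the adapter end to end, not part of the package diff. It assumes Playlist can be constructed from just a name and a track list (the adapter itself fills in created_at and server_id on return); the URL, token, and rating keys are placeholders.

# Hypothetical usage sketch; the URL, token, and rating keys are placeholders.
from mycelium.domain.models import Playlist
from mycelium.infrastructure.plex_adapter import PlexMusicRepository

repo = PlexMusicRepository(
    plex_url="http://localhost:32400",
    plex_token="YOUR_PLEX_TOKEN",
    music_library_name="Music",
)

# Resolve rating keys to Track objects, dropping any that cannot be fetched.
tracks = [t for t in (repo.get_track_by_id(k) for k in ["12345", "67890"]) if t]

# Tracks are added in batches of 100 to keep individual requests manageable.
created = repo.create_playlist(Playlist(name="Mycelium Picks", tracks=tracks))
print(f"Playlist created with server_id={created.server_id}")
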