mycelium-ai 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. mycelium/__init__.py +0 -0
  2. mycelium/api/__init__.py +0 -0
  3. mycelium/api/app.py +1147 -0
  4. mycelium/api/client_app.py +170 -0
  5. mycelium/api/generated_sources/__init__.py +0 -0
  6. mycelium/api/generated_sources/server_schemas/__init__.py +97 -0
  7. mycelium/api/generated_sources/server_schemas/api/__init__.py +5 -0
  8. mycelium/api/generated_sources/server_schemas/api/default_api.py +2473 -0
  9. mycelium/api/generated_sources/server_schemas/api_client.py +766 -0
  10. mycelium/api/generated_sources/server_schemas/api_response.py +25 -0
  11. mycelium/api/generated_sources/server_schemas/configuration.py +434 -0
  12. mycelium/api/generated_sources/server_schemas/exceptions.py +166 -0
  13. mycelium/api/generated_sources/server_schemas/models/__init__.py +41 -0
  14. mycelium/api/generated_sources/server_schemas/models/api_section.py +71 -0
  15. mycelium/api/generated_sources/server_schemas/models/chroma_section.py +69 -0
  16. mycelium/api/generated_sources/server_schemas/models/clap_section.py +75 -0
  17. mycelium/api/generated_sources/server_schemas/models/compute_on_server200_response.py +79 -0
  18. mycelium/api/generated_sources/server_schemas/models/compute_on_server_request.py +67 -0
  19. mycelium/api/generated_sources/server_schemas/models/compute_text_search_request.py +69 -0
  20. mycelium/api/generated_sources/server_schemas/models/config_request.py +81 -0
  21. mycelium/api/generated_sources/server_schemas/models/config_response.py +107 -0
  22. mycelium/api/generated_sources/server_schemas/models/create_playlist_request.py +71 -0
  23. mycelium/api/generated_sources/server_schemas/models/get_similar_by_track200_response.py +143 -0
  24. mycelium/api/generated_sources/server_schemas/models/library_stats_response.py +77 -0
  25. mycelium/api/generated_sources/server_schemas/models/logging_section.py +67 -0
  26. mycelium/api/generated_sources/server_schemas/models/media_server_section.py +67 -0
  27. mycelium/api/generated_sources/server_schemas/models/playlist_response.py +73 -0
  28. mycelium/api/generated_sources/server_schemas/models/plex_section.py +71 -0
  29. mycelium/api/generated_sources/server_schemas/models/processing_response.py +90 -0
  30. mycelium/api/generated_sources/server_schemas/models/save_config_response.py +73 -0
  31. mycelium/api/generated_sources/server_schemas/models/scan_library_response.py +75 -0
  32. mycelium/api/generated_sources/server_schemas/models/search_result_response.py +75 -0
  33. mycelium/api/generated_sources/server_schemas/models/server_section.py +67 -0
  34. mycelium/api/generated_sources/server_schemas/models/stop_processing_response.py +71 -0
  35. mycelium/api/generated_sources/server_schemas/models/task_status_response.py +87 -0
  36. mycelium/api/generated_sources/server_schemas/models/track_database_stats.py +75 -0
  37. mycelium/api/generated_sources/server_schemas/models/track_response.py +77 -0
  38. mycelium/api/generated_sources/server_schemas/models/tracks_list_response.py +81 -0
  39. mycelium/api/generated_sources/server_schemas/rest.py +329 -0
  40. mycelium/api/generated_sources/server_schemas/test/__init__.py +0 -0
  41. mycelium/api/generated_sources/server_schemas/test/test_api_section.py +57 -0
  42. mycelium/api/generated_sources/server_schemas/test/test_chroma_section.py +55 -0
  43. mycelium/api/generated_sources/server_schemas/test/test_clap_section.py +60 -0
  44. mycelium/api/generated_sources/server_schemas/test/test_compute_on_server200_response.py +52 -0
  45. mycelium/api/generated_sources/server_schemas/test/test_compute_on_server_request.py +53 -0
  46. mycelium/api/generated_sources/server_schemas/test/test_compute_text_search_request.py +54 -0
  47. mycelium/api/generated_sources/server_schemas/test/test_config_request.py +66 -0
  48. mycelium/api/generated_sources/server_schemas/test/test_config_response.py +97 -0
  49. mycelium/api/generated_sources/server_schemas/test/test_create_playlist_request.py +60 -0
  50. mycelium/api/generated_sources/server_schemas/test/test_default_api.py +150 -0
  51. mycelium/api/generated_sources/server_schemas/test/test_get_similar_by_track200_response.py +61 -0
  52. mycelium/api/generated_sources/server_schemas/test/test_library_stats_response.py +63 -0
  53. mycelium/api/generated_sources/server_schemas/test/test_logging_section.py +53 -0
  54. mycelium/api/generated_sources/server_schemas/test/test_media_server_section.py +53 -0
  55. mycelium/api/generated_sources/server_schemas/test/test_playlist_response.py +58 -0
  56. mycelium/api/generated_sources/server_schemas/test/test_plex_section.py +56 -0
  57. mycelium/api/generated_sources/server_schemas/test/test_processing_response.py +61 -0
  58. mycelium/api/generated_sources/server_schemas/test/test_save_config_response.py +58 -0
  59. mycelium/api/generated_sources/server_schemas/test/test_scan_library_response.py +61 -0
  60. mycelium/api/generated_sources/server_schemas/test/test_search_result_response.py +69 -0
  61. mycelium/api/generated_sources/server_schemas/test/test_server_section.py +53 -0
  62. mycelium/api/generated_sources/server_schemas/test/test_stop_processing_response.py +55 -0
  63. mycelium/api/generated_sources/server_schemas/test/test_task_status_response.py +71 -0
  64. mycelium/api/generated_sources/server_schemas/test/test_track_database_stats.py +60 -0
  65. mycelium/api/generated_sources/server_schemas/test/test_track_response.py +63 -0
  66. mycelium/api/generated_sources/server_schemas/test/test_tracks_list_response.py +75 -0
  67. mycelium/api/generated_sources/worker_schemas/__init__.py +61 -0
  68. mycelium/api/generated_sources/worker_schemas/api/__init__.py +5 -0
  69. mycelium/api/generated_sources/worker_schemas/api/default_api.py +318 -0
  70. mycelium/api/generated_sources/worker_schemas/api_client.py +766 -0
  71. mycelium/api/generated_sources/worker_schemas/api_response.py +25 -0
  72. mycelium/api/generated_sources/worker_schemas/configuration.py +434 -0
  73. mycelium/api/generated_sources/worker_schemas/exceptions.py +166 -0
  74. mycelium/api/generated_sources/worker_schemas/models/__init__.py +23 -0
  75. mycelium/api/generated_sources/worker_schemas/models/save_config_response.py +73 -0
  76. mycelium/api/generated_sources/worker_schemas/models/worker_clap_section.py +75 -0
  77. mycelium/api/generated_sources/worker_schemas/models/worker_client_api_section.py +69 -0
  78. mycelium/api/generated_sources/worker_schemas/models/worker_client_section.py +79 -0
  79. mycelium/api/generated_sources/worker_schemas/models/worker_config_request.py +73 -0
  80. mycelium/api/generated_sources/worker_schemas/models/worker_config_response.py +89 -0
  81. mycelium/api/generated_sources/worker_schemas/models/worker_logging_section.py +67 -0
  82. mycelium/api/generated_sources/worker_schemas/rest.py +329 -0
  83. mycelium/api/generated_sources/worker_schemas/test/__init__.py +0 -0
  84. mycelium/api/generated_sources/worker_schemas/test/test_default_api.py +45 -0
  85. mycelium/api/generated_sources/worker_schemas/test/test_save_config_response.py +58 -0
  86. mycelium/api/generated_sources/worker_schemas/test/test_worker_clap_section.py +60 -0
  87. mycelium/api/generated_sources/worker_schemas/test/test_worker_client_api_section.py +55 -0
  88. mycelium/api/generated_sources/worker_schemas/test/test_worker_client_section.py +65 -0
  89. mycelium/api/generated_sources/worker_schemas/test/test_worker_config_request.py +59 -0
  90. mycelium/api/generated_sources/worker_schemas/test/test_worker_config_response.py +89 -0
  91. mycelium/api/generated_sources/worker_schemas/test/test_worker_logging_section.py +53 -0
  92. mycelium/api/worker_models.py +99 -0
  93. mycelium/application/__init__.py +11 -0
  94. mycelium/application/job_queue.py +323 -0
  95. mycelium/application/library_management_use_cases.py +292 -0
  96. mycelium/application/search_use_cases.py +96 -0
  97. mycelium/application/services.py +340 -0
  98. mycelium/client.py +554 -0
  99. mycelium/client_config.py +251 -0
  100. mycelium/client_frontend_dist/404.html +1 -0
  101. mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_buildManifest.js +1 -0
  102. mycelium/client_frontend_dist/_next/static/a4iyRdfsvkjdyMAK9cE9Y/_ssgManifest.js +1 -0
  103. mycelium/client_frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
  104. mycelium/client_frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
  105. mycelium/client_frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
  106. mycelium/client_frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
  107. mycelium/client_frontend_dist/_next/static/chunks/app/page-cc6bad295789134e.js +1 -0
  108. mycelium/client_frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
  109. mycelium/client_frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
  110. mycelium/client_frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
  111. mycelium/client_frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
  112. mycelium/client_frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
  113. mycelium/client_frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
  114. mycelium/client_frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
  115. mycelium/client_frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
  116. mycelium/client_frontend_dist/favicon.ico +0 -0
  117. mycelium/client_frontend_dist/file.svg +1 -0
  118. mycelium/client_frontend_dist/globe.svg +1 -0
  119. mycelium/client_frontend_dist/index.html +1 -0
  120. mycelium/client_frontend_dist/index.txt +20 -0
  121. mycelium/client_frontend_dist/next.svg +1 -0
  122. mycelium/client_frontend_dist/vercel.svg +1 -0
  123. mycelium/client_frontend_dist/window.svg +1 -0
  124. mycelium/config.py +346 -0
  125. mycelium/domain/__init__.py +13 -0
  126. mycelium/domain/models.py +71 -0
  127. mycelium/domain/repositories.py +98 -0
  128. mycelium/domain/worker.py +77 -0
  129. mycelium/frontend_dist/404.html +1 -0
  130. mycelium/frontend_dist/_next/static/chunks/4bd1b696-cf72ae8a39fa05aa.js +1 -0
  131. mycelium/frontend_dist/_next/static/chunks/964-830f77d7ce1c2463.js +1 -0
  132. mycelium/frontend_dist/_next/static/chunks/app/_not-found/page-d25eede5a9099bd3.js +1 -0
  133. mycelium/frontend_dist/_next/static/chunks/app/layout-9b3d32f96dfe13b6.js +1 -0
  134. mycelium/frontend_dist/_next/static/chunks/app/page-a761463485e0540b.js +1 -0
  135. mycelium/frontend_dist/_next/static/chunks/framework-7c95b8e5103c9e90.js +1 -0
  136. mycelium/frontend_dist/_next/static/chunks/main-6b37be50736577a2.js +1 -0
  137. mycelium/frontend_dist/_next/static/chunks/main-app-4153d115599d3126.js +1 -0
  138. mycelium/frontend_dist/_next/static/chunks/pages/_app-0a0020ddd67f79cf.js +1 -0
  139. mycelium/frontend_dist/_next/static/chunks/pages/_error-03529f2c21436739.js +1 -0
  140. mycelium/frontend_dist/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
  141. mycelium/frontend_dist/_next/static/chunks/webpack-c81e624915b2ea70.js +1 -0
  142. mycelium/frontend_dist/_next/static/css/1eb7f0e2c78e0734.css +1 -0
  143. mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_buildManifest.js +1 -0
  144. mycelium/frontend_dist/_next/static/glVJ0yJSL0zWN7anTTG3_/_ssgManifest.js +1 -0
  145. mycelium/frontend_dist/favicon.ico +0 -0
  146. mycelium/frontend_dist/file.svg +1 -0
  147. mycelium/frontend_dist/globe.svg +1 -0
  148. mycelium/frontend_dist/index.html +10 -0
  149. mycelium/frontend_dist/index.txt +20 -0
  150. mycelium/frontend_dist/next.svg +1 -0
  151. mycelium/frontend_dist/vercel.svg +1 -0
  152. mycelium/frontend_dist/window.svg +1 -0
  153. mycelium/infrastructure/__init__.py +17 -0
  154. mycelium/infrastructure/chroma_adapter.py +232 -0
  155. mycelium/infrastructure/clap_adapter.py +280 -0
  156. mycelium/infrastructure/plex_adapter.py +145 -0
  157. mycelium/infrastructure/track_database.py +467 -0
  158. mycelium/main.py +183 -0
  159. mycelium_ai-0.5.0.dist-info/METADATA +312 -0
  160. mycelium_ai-0.5.0.dist-info/RECORD +164 -0
  161. mycelium_ai-0.5.0.dist-info/WHEEL +5 -0
  162. mycelium_ai-0.5.0.dist-info/entry_points.txt +2 -0
  163. mycelium_ai-0.5.0.dist-info/licenses/LICENSE +21 -0
  164. mycelium_ai-0.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,232 @@
1
+ """ChromaDB integration for storing and searching embeddings."""
2
+
3
+ import logging
4
+ import re
5
+ from pathlib import Path
6
+ from typing import List, Optional
7
+
8
+ import chromadb
9
+ from tqdm import tqdm
10
+
11
+ from ..domain.models import Track, TrackEmbedding, SearchResult, MediaServerType
12
+ from ..domain.repositories import EmbeddingRepository
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ChromaEmbeddingRepository(EmbeddingRepository):
18
+ """Implementation of EmbeddingRepository using ChromaDB with model-specific collections."""
19
+
20
+ def __init__(
21
+ self,
22
+ db_path: str,
23
+ media_server_type: MediaServerType,
24
+ collection_name: str = "my_music_library",
25
+ model_id: str = "laion/larger_clap_music_and_speech",
26
+ batch_size: int = 1000,
27
+ ):
28
+ self.db_path = db_path
29
+ self.base_collection_name = collection_name
30
+ self.model_id = model_id
31
+ self.batch_size = batch_size
32
+ self.media_server_type = media_server_type
33
+
34
+ # Initialize ChromaDB client
35
+ try:
36
+ Path(db_path).mkdir(parents=True, exist_ok=True)
37
+ except Exception:
38
+ logger.error(f"Failed to create database directory at {db_path}. Please check permissions.")
39
+ self.client = chromadb.PersistentClient(path=db_path)
40
+
41
+ # Create model-specific collection name
42
+ self.collection_name = self._get_collection_name_for_model(model_id)
43
+
44
+ # Specify 'cosine' distance metric for normalized embeddings
45
+ self.collection = self.client.get_or_create_collection(
46
+ name=self.collection_name,
47
+ metadata={"hnsw:space": "cosine", "model_id": model_id}
48
+ )
49
+
50
+ logger.info(
51
+ f"Collection '{self.collection_name}' ready for model '{model_id}'. Current elements: {self.collection.count()}")
52
+
53
+ def _get_collection_name_for_model(self, model_id: str) -> str:
54
+ """Generate a safe collection name for the given model ID."""
55
+ # Make model ID safe for collection name (alphanumeric and underscores only)
56
+ safe_model_id = re.sub(r'\W', '_', model_id.replace('/', '_'))
57
+ return f"{self.base_collection_name}_{safe_model_id}"
58
+
59
+ def save_embeddings(self, embeddings: List[TrackEmbedding]) -> None:
60
+ """Save track embeddings to ChromaDB."""
61
+ if not embeddings:
62
+ return
63
+
64
+ # Prepare data for batch insertion
65
+ ids = []
66
+ embedding_vectors = []
67
+ metadatas = []
68
+
69
+ for track_embedding in embeddings:
70
+ track = track_embedding.track
71
+ ids.append(track.unique_id)
72
+ embedding_vectors.append(track_embedding.embedding)
73
+ metadatas.append({
74
+ "filepath": str(track.filepath),
75
+ "artist": track.artist,
76
+ "album": track.album,
77
+ "title": track.title,
78
+ "media_server_type": track.media_server_type.value,
79
+ "media_server_rating_key": track.media_server_rating_key,
80
+ "model_id": self.model_id
81
+ })
82
+
83
+ # Insert in batches for maximum efficiency
84
+ for i in tqdm(range(0, len(ids), self.batch_size), desc="Indexing in ChromaDB"):
85
+ end_idx = min(i + self.batch_size, len(ids))
86
+ id_batch = ids[i:end_idx]
87
+ embedding_batch = embedding_vectors[i:end_idx]
88
+ metadata_batch = metadatas[i:end_idx]
89
+
90
+ self.collection.add(
91
+ ids=id_batch,
92
+ embeddings=embedding_batch,
93
+ metadatas=metadata_batch
94
+ )
95
+
96
+ logger.info("Indexing completed!")
97
+ logger.info(f"Total elements in collection '{self.collection_name}': {self.collection.count()}")
98
+
99
+ def search_by_embedding(self, embedding: List[float], n_results: int = 10) -> List[SearchResult]:
100
+ """Search for similar tracks by embedding."""
101
+ results = self.collection.query(
102
+ query_embeddings=[embedding],
103
+ n_results=n_results
104
+ )
105
+
106
+ return self._parse_search_results(results.copy())
107
+
108
+ def get_embedding_count(self) -> int:
109
+ """Get the total number of embeddings stored."""
110
+ return self.collection.count()
111
+
112
+ @staticmethod
113
+ def _parse_search_results(results: dict) -> List[SearchResult]:
114
+ """Parse ChromaDB results into SearchResult objects."""
115
+ search_results = []
116
+
117
+ if not results['ids'] or not results['ids'][0]:
118
+ return search_results
119
+
120
+ for i in range(len(results['ids'][0])):
121
+ metadata = results['metadatas'][0][i]
122
+ distance = results['distances'][0][i]
123
+ unique_id = results['ids'][0][i]
124
+
125
+ # Parse unique_id to get media server info
126
+ media_server_type_str, media_server_rating_key = unique_id.split(':', 1)
127
+
128
+ from ..domain.models import MediaServerType
129
+ try:
130
+ media_server_type = MediaServerType(media_server_type_str)
131
+ except ValueError:
132
+ media_server_type = MediaServerType.PLEX # Default fallback
133
+
134
+ track = Track(
135
+ artist=metadata['artist'],
136
+ album=metadata['album'],
137
+ title=metadata['title'],
138
+ filepath=Path(metadata['filepath']),
139
+ media_server_rating_key=media_server_rating_key,
140
+ media_server_type=media_server_type
141
+ )
142
+
143
+ # Convert distance to similarity score (1 - distance for cosine)
144
+ similarity_score = 1.0 - distance
145
+
146
+ search_results.append(SearchResult(
147
+ track=track,
148
+ similarity_score=similarity_score,
149
+ distance=distance
150
+ ))
151
+
152
+ return search_results
153
+
154
+ def has_embedding(self, track_id: str) -> bool:
155
+ """Check if an embedding exists for a track."""
156
+ track_id = Track(media_server_type=self.media_server_type,
157
+ media_server_rating_key=track_id).unique_id
158
+ logger.debug(
159
+ f"Checking embedding for track {track_id}: collection_name={self.collection_name}, model_id={self.model_id}"
160
+ )
161
+ try:
162
+ result = self.collection.get(ids=[track_id])
163
+ exists = len(result['ids']) > 0
164
+ logger.debug(f"Checking embedding for track {track_id}: exists={exists}")
165
+ return exists
166
+ except Exception as e:
167
+ logger.error(f"Error checking embedding for track {track_id}: {e}")
168
+ return False
169
+
170
+ def save_embedding(self, track_embedding: TrackEmbedding) -> None:
171
+ """Save a single track embedding to ChromaDB."""
172
+ track = track_embedding.track
173
+ track_id = track.unique_id
174
+
175
+ logger.info(f"Saving embedding to ChromaDB for track {track_id}: {track.artist} - {track.title}")
176
+ logger.info(f"Collection count before save: {self.collection.count()}")
177
+
178
+ # Check if embedding already exists, if so, update it
179
+ existing = self.collection.get(ids=[track_id])
180
+ if existing['ids']:
181
+ logger.info(f"Updating existing embedding for track {track_id}")
182
+ self.collection.update(
183
+ ids=[track_id],
184
+ embeddings=[track_embedding.embedding],
185
+ metadatas=[{
186
+ "filepath": str(track.filepath),
187
+ "artist": track.artist,
188
+ "album": track.album,
189
+ "title": track.title,
190
+ "media_server_type": track.media_server_type.value,
191
+ "media_server_rating_key": track.media_server_rating_key,
192
+ "model_id": self.model_id
193
+ }]
194
+ )
195
+ else:
196
+ logger.info(f"Adding new embedding for track {track_id}")
197
+ self.collection.add(
198
+ ids=[track_id],
199
+ embeddings=[track_embedding.embedding],
200
+ metadatas=[{
201
+ "filepath": str(track.filepath),
202
+ "artist": track.artist,
203
+ "album": track.album,
204
+ "title": track.title,
205
+ "media_server_type": track.media_server_type.value,
206
+ "media_server_rating_key": track.media_server_rating_key,
207
+ "model_id": self.model_id
208
+ }]
209
+ )
210
+
211
+ logger.info(f"Collection count after save: {self.collection.count()}")
212
+ logger.info(f"Successfully saved embedding to ChromaDB for track {track_id}")
213
+
214
+ def get_embedding_by_track_id(self, track_id: str) -> Optional[List[float]]:
215
+ """Get embedding for a specific track."""
216
+ track_id = Track(media_server_type=self.media_server_type, media_server_rating_key=track_id).unique_id
217
+ try:
218
+ result = self.collection.get(
219
+ ids=[track_id],
220
+ include=['embeddings']
221
+ )
222
+ if result['embeddings'] is not None and len(result['embeddings']) > 0:
223
+ embedding = result['embeddings'][0]
224
+ logger.debug(
225
+ f"Retrieved embedding for track {track_id}, size: {len(embedding) if embedding is not None else 0}")
226
+ return embedding
227
+ else:
228
+ logger.debug(f"No embedding found in ChromaDB for track {track_id}")
229
+ return None
230
+ except Exception as e:
231
+ logger.error(f"Error retrieving embedding for track {track_id}: {e}")
232
+ return None
@@ -0,0 +1,280 @@
1
+ """CLAP model integration for generating embeddings."""
2
+
3
+ import logging
4
+ import random
5
+ from pathlib import Path
6
+ from typing import List, Optional
7
+
8
+ import librosa
9
+ import torch
10
+ from transformers import ClapModel, ClapProcessor
11
+
12
+ from ..domain.repositories import EmbeddingGenerator
13
+
14
+
15
class CLAPEmbeddingGenerator(EmbeddingGenerator):
    """Implementation of EmbeddingGenerator using LAION's CLAP model."""

    def __init__(
        self,
        model_id: str = "laion/larger_clap_music_and_speech",
        target_sr: int = 48000,
        chunk_duration_s: int = 10,
        num_chunks: int = 3,
        max_load_duration_s: Optional[int] = 120
    ):
        """Configure the generator; the model itself is loaded lazily.

        Args:
            model_id: Hugging Face model identifier.
            target_sr: Sample rate audio is resampled to before encoding.
            chunk_duration_s: Length (seconds) of each audio chunk fed to the model.
            num_chunks: Number of non-overlapping chunks sampled per file.
            max_load_duration_s: Cap on seconds of audio loaded per file;
                None loads the whole file.
        """
        self.model_id = model_id
        self.target_sr = target_sr
        self.chunk_duration_s = chunk_duration_s
        self.num_chunks = num_chunks
        self.max_load_duration_s = max_load_duration_s
        self.logger = logging.getLogger(__name__)

        self.device = self.get_best_device()
        self.logger.info(f"Selected device: {self.device}")

        # Lazy loading: the model is not loaded on instantiation.
        self.model: Optional[ClapModel] = None
        self.processor: Optional[ClapProcessor] = None

        self.use_half = self.can_use_half_precision()
        if self.use_half:
            self.logger.info("Half precision (FP16) is supported and will be used.")
        else:
            self.logger.info("Half precision not supported, using full precision (FP32).")

    def _load_model_if_needed(self):
        """Loads the model and processor on the first call that needs them."""
        if self.model is None or self.processor is None:
            self.logger.info(f"Loading model '{self.model_id}' to device '{self.device}'...")

            self.model = ClapModel.from_pretrained(self.model_id).to(self.device)
            self.processor = ClapProcessor.from_pretrained(self.model_id)

            if self.use_half and self.device == "cuda":
                self.logger.info("Applying half precision (FP16) to model for CUDA device.")
                self.model.half()
            elif self.use_half and self.device == "mps":
                # FP16 is detected as available on MPS but deliberately not
                # applied to avoid known instability.
                self.logger.warning(
                    "Half precision is supported but disabled on MPS device to prevent potential crashes. Using FP32.")

            self.model.eval()
            self.logger.info("Model loaded successfully.")
            try:
                self.logger.info(f"Model dtype after load: {next(self.model.parameters()).dtype}")
            except StopIteration:
                self.logger.debug("Could not determine model dtype (no parameters found).")

    @staticmethod
    def get_best_device() -> str:
        """Pick the best available torch device: cuda > mps > cpu."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def can_use_half_precision(self) -> bool:
        """Checks once if the device supports half precision."""
        if self.device == "cuda":
            # Most modern CUDA devices support FP16.
            return True
        if self.device == "mps":
            # Check for potential runtime errors on some MPS devices.
            try:
                torch.tensor([1.0], dtype=torch.half).to(self.device)
                return True
            except RuntimeError:
                self.logger.warning("MPS device does not support half precision, falling back to FP32.")
                return False
        return False

    def _get_processor(self) -> ClapProcessor:
        """Return a ready-to-use processor with a non-optional type."""
        self._load_model_if_needed()
        assert self.processor is not None
        return self.processor

    def _get_model(self) -> ClapModel:
        """Return a ready-to-use model with a non-optional type."""
        self._load_model_if_needed()
        assert self.model is not None
        return self.model

    def _prepare_inputs(self, inputs: dict) -> dict:
        """Move inputs to the correct device and cast floating tensors to the model's dtype.

        This prevents dtype mismatches when the model runs in half precision on CUDA.
        """
        model = self._get_model()
        # Determine model parameter dtype (e.g., torch.float32 or torch.float16)
        model_dtype = next(model.parameters()).dtype
        prepared = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                if v.is_floating_point():
                    prepared[k] = v.to(device=self.device, dtype=model_dtype)
                else:
                    prepared[k] = v.to(device=self.device)
            else:
                prepared[k] = v
        return prepared

    def _sample_chunks(self, waveform, filepath: Path) -> list:
        """Split a loaded waveform into up to num_chunks random non-overlapping chunks.

        Returns an empty list when the file is too short for even one chunk.
        """
        chunk_size_samples = self.chunk_duration_s * self.target_sr
        total_samples = len(waveform)

        # Calculate how many full, non-overlapping chunks can fit.
        num_possible_bins = total_samples // chunk_size_samples
        if num_possible_bins == 0:
            self.logger.warning(
                f"File {filepath} is too short ({total_samples / self.target_sr:.1f}s) "
                f"for even one chunk of {self.chunk_duration_s:.1f}s.")
            return []

        # Determine which bin indices to sample from.
        if num_possible_bins < self.num_chunks:
            self.logger.warning(
                f"File {filepath} only has space for {num_possible_bins} non-overlapping chunks, "
                f"less than the requested {self.num_chunks}. Using all available chunks."
            )
            chosen_bin_indices = range(num_possible_bins)
        else:
            chosen_bin_indices = random.sample(range(num_possible_bins), k=self.num_chunks)

        # Create chunks based on the chosen indices.
        chunks = []
        for bin_index in chosen_bin_indices:
            start_idx = bin_index * chunk_size_samples
            chunks.append(waveform[start_idx:start_idx + chunk_size_samples])
        return chunks

    def generate_embedding(self, filepath: Path) -> Optional[List[float]]:
        """Generate embedding for a single audio file by delegating to batch method."""
        results = self.generate_embedding_batch([filepath])
        return results[0] if results else None

    def generate_embedding_batch(self, filepaths: List[Path]) -> List[Optional[List[float]]]:
        """Generate embeddings for multiple audio files in a single GPU batch.

        Returns one entry per input path, in order; entries are None for files
        that could not be loaded or were too short.
        """
        if not filepaths:
            return []

        try:
            processor = self._get_processor()
            model = self._get_model()

            all_chunks = []
            # Chunks contributed per file, used to split the batch back apart.
            file_chunk_counts = []

            # Load and chunk all audio files.
            for filepath in filepaths:
                try:
                    waveform, _ = librosa.load(
                        str(filepath),
                        sr=self.target_sr,
                        mono=True,
                        duration=self.max_load_duration_s
                    )

                    chunks = self._sample_chunks(waveform, filepath)
                    if not chunks:
                        self.logger.warning(f"No valid chunks generated for {filepath}.")
                        file_chunk_counts.append(0)
                        continue

                    all_chunks.extend(chunks)
                    file_chunk_counts.append(len(chunks))

                except Exception as e:
                    self.logger.error(f"Error loading audio file {filepath}: {e}")
                    file_chunk_counts.append(0)

            if not all_chunks:
                return [None] * len(filepaths)

            # Process all chunks in a single batch.
            inputs = processor(
                audios=all_chunks,
                sampling_rate=self.target_sr,
                return_tensors="pt",
                padding=True
            )

            inputs = self._prepare_inputs(inputs)

            with torch.no_grad():
                audio_features = model.get_audio_features(**inputs)

            # Split results back to individual files and compute mean embeddings.
            results = []
            chunk_idx = 0

            for chunk_count in file_chunk_counts:
                if chunk_count == 0:
                    results.append(None)
                else:
                    file_features = audio_features[chunk_idx:chunk_idx + chunk_count]
                    mean_embedding = torch.mean(file_features, dim=0)
                    # L2-normalize so cosine search works on unit vectors.
                    normalized_embedding = torch.nn.functional.normalize(mean_embedding, p=2, dim=0)
                    results.append(normalized_embedding.cpu().numpy().tolist())
                    chunk_idx += chunk_count

            self.logger.info(
                f"Successfully processed batch of {len(filepaths)} audio files ({len(all_chunks)} total chunks)")
            return results

        except Exception as e:
            self.logger.error(f"Error in batch audio embedding generation: {e}", exc_info=True)
            return [None] * len(filepaths)

    def generate_text_embedding(self, text: str) -> Optional[List[float]]:
        """Generate embedding for a single text query by delegating to batch method."""
        results = self.generate_text_embedding_batch([text])
        return results[0] if results else None

    def generate_text_embedding_batch(self, texts: List[str]) -> List[Optional[List[float]]]:
        """Generate embeddings for multiple text queries in a single GPU batch for better utilization."""
        if not texts:
            return []

        try:
            processor = self._get_processor()
            model = self._get_model()

            inputs = processor(
                text=texts,
                return_tensors="pt",
                padding=True
            )

            inputs = self._prepare_inputs(inputs)

            with torch.no_grad():
                text_features = model.get_text_features(**inputs)
                # L2-normalize so text embeddings are comparable with audio ones.
                text_embeddings = torch.nn.functional.normalize(text_features, p=2, dim=-1)

            # Convert to list of lists.
            results = text_embeddings.cpu().numpy().tolist()
            self.logger.info(f"Successfully processed batch of {len(texts)} text queries")
            return results

        except Exception as e:
            self.logger.error(f"Error in batch text embedding generation: {e}", exc_info=True)
            return [None] * len(texts)

    def unload_model(self) -> None:
        """Unload model to free GPU memory."""
        if self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None

            if self.device == "cuda":
                torch.cuda.empty_cache()
            elif self.device == "mps":
                torch.mps.empty_cache()

            self.logger.info("Model unloaded")
@@ -0,0 +1,145 @@
1
+ """Plex integration for accessing music library."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import List, Optional
6
+ from datetime import datetime
7
+
8
+ from plexapi.audio import Artist
9
+ from plexapi.server import PlexServer
10
+ from tqdm import tqdm
11
+
12
+ from ..domain.models import Track, Playlist, MediaServerType
13
+ from ..domain.repositories import MediaServerRepository
14
+
15
+
16
+ class PlexMusicRepository(MediaServerRepository):
17
+ """Implementation of MediaServerRepository for accessing Plex music library."""
18
+
19
+ def __init__(
20
+ self,
21
+ plex_url: str = None,
22
+ plex_token: str = None,
23
+ music_library_name: str = "Music"
24
+ ):
25
+ self.plex_url = plex_url
26
+ self.plex_token = plex_token
27
+ self.music_library_name = music_library_name
28
+ self.logger = logging.getLogger(__name__)
29
+
30
+ def get_all_tracks(self) -> List[Track]:
31
+ """Get all tracks from the Plex music library."""
32
+ try:
33
+ plex = PlexServer(self.plex_url, self.plex_token, timeout=3600)
34
+ music_lib = plex.library.section(self.music_library_name)
35
+ self.logger.info(f"Connected to Plex. Scanning library '{self.music_library_name}'...")
36
+ except Exception as e:
37
+ raise ConnectionError(f"Error connecting to Plex server: {e}")
38
+
39
+ all_tracks = []
40
+
41
+ # Hierarchical iteration for better robustness and memory efficiency
42
+ artists = music_lib.all(libtype='artist')
43
+ artists: List[Artist]
44
+ for artist in tqdm(artists, desc="Processing Artists"):
45
+ try:
46
+ for album in artist.albums():
47
+ for track in album.tracks():
48
+ for part in track.iterParts():
49
+ filepath = Path(part.file)
50
+ if filepath.exists():
51
+ track_obj = Track(
52
+ artist=artist.title,
53
+ album=album.title,
54
+ title=track.title,
55
+ filepath=filepath,
56
+ media_server_rating_key=str(track.ratingKey),
57
+ media_server_type=MediaServerType.PLEX
58
+ )
59
+ all_tracks.append(track_obj)
60
+ else:
61
+ self.logger.warning(f"File not found, skipping: {filepath}")
62
+ except Exception as e:
63
+ self.logger.error(f"Error processing artist {artist.title}: {e}. Continuing...", exc_info=True)
64
+
65
+ return all_tracks
66
+
67
+ def get_track_by_id(self, track_id: str) -> Optional[Track]:
68
+ """Get a specific track by Plex rating key."""
69
+ try:
70
+ plex = PlexServer(self.plex_url, self.plex_token)
71
+ track = plex.fetchItem(int(track_id))
72
+
73
+ # Get the first available part of the track
74
+ for part in track.iterParts():
75
+ filepath = Path(part.file)
76
+ if filepath.exists():
77
+ return Track(
78
+ artist=track.grandparentTitle or "Unknown Artist",
79
+ album=track.parentTitle or "Unknown Album",
80
+ title=track.title,
81
+ filepath=filepath,
82
+ media_server_rating_key=str(track.ratingKey),
83
+ media_server_type=MediaServerType.PLEX
84
+ )
85
+
86
+ return None
87
+
88
+ except Exception as e:
89
+ self.logger.error(f"Error getting track {track_id}: {e}", exc_info=True)
90
+ return None
91
+
92
+ def create_playlist(self, playlist: Playlist, batch_size: int = 100) -> Playlist:
93
+ """Create a playlist on the Plex server using batch processing for large playlists.
94
+
95
+ Args:
96
+ playlist: The playlist to create
97
+ batch_size: Number of tracks to add per batch (default: 100)
98
+ """
99
+ try:
100
+ plex = PlexServer(self.plex_url, self.plex_token)
101
+
102
+ # Get Plex track objects for all tracks in the playlist
103
+ plex_tracks = []
104
+ for track in playlist.tracks:
105
+ try:
106
+ plex_track = plex.fetchItem(int(track.media_server_rating_key))
107
+ plex_tracks.append(plex_track)
108
+ except Exception as e:
109
+ self.logger.warning(f"Could not fetch track {track.media_server_rating_key}: {e}")
110
+ continue
111
+
112
+ if not plex_tracks:
113
+ raise ValueError("No valid tracks found for playlist creation")
114
+
115
+ total_tracks = len(plex_tracks)
116
+ self.logger.info(f"Creating playlist '{playlist.name}' with {total_tracks} tracks")
117
+
118
+ # Create the playlist with the first batch
119
+ first_batch = plex_tracks[:batch_size]
120
+ created_playlist = plex.createPlaylist(title=playlist.name, items=first_batch)
121
+ self.logger.info(f"Created playlist '{playlist.name}' with initial batch of {len(first_batch)} tracks")
122
+
123
+ # Add remaining tracks in batches
124
+ remaining_tracks = plex_tracks[batch_size:]
125
+ if remaining_tracks:
126
+ self.logger.info(f"Adding {len(remaining_tracks)} remaining tracks in batches of {batch_size}")
127
+
128
+ for i in range(0, len(remaining_tracks), batch_size):
129
+ batch = remaining_tracks[i:i + batch_size]
130
+ created_playlist.addItems(batch)
131
+ self.logger.debug(f"Added batch {i//batch_size + 1}: {len(batch)} tracks")
132
+
133
+ self.logger.info(f"Successfully completed playlist creation with all {total_tracks} tracks")
134
+
135
+ # Return the playlist with server ID and creation time
136
+ return Playlist(
137
+ name=playlist.name,
138
+ tracks=playlist.tracks,
139
+ created_at=datetime.now(),
140
+ server_id=str(created_playlist.ratingKey)
141
+ )
142
+
143
+ except Exception as e:
144
+ self.logger.error(f"Error creating playlist '{playlist.name}': {e}", exc_info=True)
145
+ raise