plexmix 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plexmix/__init__.py +20 -0
- plexmix/ai/__init__.py +51 -0
- plexmix/ai/base.py +113 -0
- plexmix/ai/claude_provider.py +50 -0
- plexmix/ai/gemini_provider.py +64 -0
- plexmix/ai/openai_provider.py +50 -0
- plexmix/ai/tag_generator.py +274 -0
- plexmix/cli/__init__.py +0 -0
- plexmix/cli/main.py +678 -0
- plexmix/config/__init__.py +0 -0
- plexmix/config/credentials.py +73 -0
- plexmix/config/settings.py +132 -0
- plexmix/database/__init__.py +0 -0
- plexmix/database/models.py +166 -0
- plexmix/database/sqlite_manager.py +433 -0
- plexmix/database/vector_index.py +169 -0
- plexmix/playlist/__init__.py +0 -0
- plexmix/playlist/generator.py +195 -0
- plexmix/plex/__init__.py +0 -0
- plexmix/plex/client.py +278 -0
- plexmix/plex/sync.py +275 -0
- plexmix/utils/__init__.py +0 -0
- plexmix/utils/embeddings.py +286 -0
- plexmix/utils/logging.py +64 -0
- plexmix-0.1.0.dist-info/METADATA +394 -0
- plexmix-0.1.0.dist-info/RECORD +28 -0
- plexmix-0.1.0.dist-info/WHEEL +4 -0
- plexmix-0.1.0.dist-info/entry_points.txt +3 -0
plexmix/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import os

# Silence gRPC / glog / TensorFlow native-library chatter.  These variables
# are read by the native libraries at import time, so they must be set
# before anything that pulls those libraries in.
_NATIVE_LOG_SETTINGS = {
    'GRPC_VERBOSITY': 'NONE',
    'GRPC_TRACE': '',
    'GLOG_minloglevel': '2',
    'TF_CPP_MIN_LOG_LEVEL': '3',
}
for _name, _value in _NATIVE_LOG_SETTINGS.items():
    os.environ[_name] = _value

import sys
import warnings

# Keep warnings visible when running under `python -X dev`.
if not sys.flags.dev_mode:
    for _category in (DeprecationWarning, FutureWarning):
        warnings.filterwarnings('ignore', category=_category, module='google')

try:
    import absl.logging
except ImportError:
    pass  # absl is optional; nothing to silence if it is absent
else:
    absl.logging.set_verbosity('error')
    absl.logging.set_stderrthreshold('error')
|
plexmix/ai/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import os
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from .base import AIProvider
|
|
6
|
+
from .gemini_provider import GeminiProvider
|
|
7
|
+
from .openai_provider import OpenAIProvider
|
|
8
|
+
from .claude_provider import ClaudeProvider
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_ai_provider(
    provider_name: str,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    temperature: float = 0.7
) -> AIProvider:
    """Construct the AI provider selected by ``provider_name``.

    When ``api_key`` is omitted, the provider's conventional environment
    variable is consulted; when ``model`` is omitted, a per-provider default
    is used.  Raises ``ValueError`` for a missing key or an unrecognized
    provider name.
    """
    provider_name = provider_name.lower()

    # Conventional environment variable for each supported provider.
    _ENV_KEYS = {
        "gemini": "GOOGLE_API_KEY",
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
    }
    if api_key is None and provider_name in _ENV_KEYS:
        api_key = os.getenv(_ENV_KEYS[provider_name])

    if not api_key:
        raise ValueError(f"API key required for {provider_name} provider")

    if provider_name == "gemini":
        return GeminiProvider(api_key, model or "gemini-2.5-flash", temperature)
    if provider_name == "openai":
        return OpenAIProvider(api_key, model or "gpt-4o-mini", temperature)
    if provider_name == "claude":
        return ClaudeProvider(api_key, model or "claude-sonnet-4-5-20250929", temperature)

    raise ValueError(f"Unknown provider: {provider_name}. Choose from: gemini, openai, claude")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Public API of the plexmix.ai package: the abstract provider interface, the
# three concrete provider implementations, and the factory that selects one.
__all__ = [
    "AIProvider",
    "GeminiProvider",
    "OpenAIProvider",
    "ClaudeProvider",
    "get_ai_provider"
]
|
plexmix/ai/base.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Dict, Any, Optional
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AIProvider(ABC):
    """Abstract base class for playlist-generating AI backends.

    Concrete subclasses implement :meth:`generate_playlist`; this base
    supplies the shared prompt construction, model-aware candidate capping,
    response parsing, and selection validation.
    """

    def __init__(self, api_key: str, model: str, temperature: float = 0.7):
        self.api_key = api_key          # credential used by the subclass SDK client
        self.model = model              # model name; also keys get_max_candidates()
        self.temperature = temperature  # sampling temperature forwarded to the API

    @abstractmethod
    def generate_playlist(
        self,
        mood_query: str,
        candidate_tracks: List[Dict[str, Any]],
        max_tracks: int = 50
    ) -> List[int]:
        """Return up to ``max_tracks`` track IDs from ``candidate_tracks``
        that best match ``mood_query``."""

    def get_max_candidates(self) -> int:
        """Return maximum candidate pool size based on model context window."""
        context_limits = {
            'gemini-2.5-flash': 1000,
            'gpt-5-mini': 500,
            'gpt-5-nano': 500,
            'gpt-4o-mini': 200,
            'gpt-4o': 300,
            'claude-sonnet-4-5': 300,
            'claude-sonnet-4-5-20250929': 300,
            'claude-3-5-haiku-20241022': 300,
            'claude-3-haiku': 200,
        }
        # Unknown models fall back to a conservative cap.
        return context_limits.get(self.model, 200)

    def _prepare_prompt(
        self,
        mood_query: str,
        candidate_tracks: List[Dict[str, Any]],
        max_tracks: int
    ) -> str:
        """Build the combined system+user prompt sent to the model.

        The candidate list is truncated to :meth:`get_max_candidates` so the
        request stays within the model's context window.
        """
        max_candidates = self.get_max_candidates()
        if len(candidate_tracks) > max_candidates:
            logger.warning(f"Truncating {len(candidate_tracks)} candidates to {max_candidates} for model {self.model}")
            candidate_tracks = candidate_tracks[:max_candidates]
        system_prompt = """You are an expert music curator helping create the perfect playlist.
Your task is to select tracks from the provided candidate list that best match the user's mood query.

Rules:
1. Select exactly the requested number of tracks
2. Only select from the provided candidate list
3. Order tracks by relevance to the mood query
4. Consider the track's title, artist, album, genre, and year
5. **CRITICAL: Do NOT select the same track title + artist combination more than once (no duplicates)**
6. **IMPORTANT: Prioritize artist diversity - avoid selecting multiple tracks from the same artist unless necessary**
7. **IMPORTANT: Prioritize album diversity - avoid selecting multiple tracks from the same album unless necessary**
8. Return ONLY a JSON array of track IDs, nothing else

Response format: [1, 5, 12, 23, ...]"""

        tracks_json = json.dumps(candidate_tracks, indent=2)

        user_prompt = f"""Mood Query: "{mood_query}"

Number of tracks to select: {max_tracks}

Candidate tracks:
{tracks_json}

Select {max_tracks} tracks that best match the mood "{mood_query}". Return only a JSON array of track IDs."""

        return system_prompt + "\n\n" + user_prompt

    def _parse_response(self, response: str) -> List[int]:
        """Parse the model's reply into a list of integer track IDs.

        Strips markdown code fences (models often add them despite the
        instructions) before decoding.  Returns ``[]`` on any parse failure
        rather than raising.
        """
        try:
            response = response.strip()

            if response.startswith("```"):
                lines = response.split("\n")
                response = "\n".join([line for line in lines if not line.startswith("```")])

            track_ids = json.loads(response)

            if not isinstance(track_ids, list):
                logger.error("Response is not a list")
                return []

            return [int(tid) for tid in track_ids]

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            return []
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            return []

    def _validate_selections(
        self,
        selections: List[int],
        candidate_tracks: List[Dict[str, Any]]
    ) -> List[int]:
        """Filter ``selections`` down to IDs present in ``candidate_tracks``,
        dropping duplicates while preserving the model's ordering.

        Deduplication enforces the prompt's "no duplicates" rule even when
        the model ignores it (previously duplicate IDs were passed through).
        """
        valid_ids = {track['id'] for track in candidate_tracks}

        invalid = [tid for tid in selections if tid not in valid_ids]
        if invalid:
            logger.warning(
                f"Filtered out {len(invalid)} invalid track IDs"
            )

        # Order-preserving dedup of the valid IDs.
        seen = set()
        validated = []
        for tid in selections:
            if tid in valid_ids and tid not in seen:
                seen.add(tid)
                validated.append(tid)

        return validated
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import List, Dict, Any
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from .base import AIProvider
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ClaudeProvider(AIProvider):
    """Playlist provider backed by Anthropic's Claude Messages API."""

    def __init__(self, api_key: str, model: str = "claude-sonnet-4-5-20250929", temperature: float = 0.7):
        super().__init__(api_key, model, temperature)
        try:
            from anthropic import Anthropic
            self.client = Anthropic(api_key=api_key)
            logger.info(f"Initialized Claude provider with model {model}")
        except ImportError:
            raise ImportError("anthropic not installed. Run: pip install anthropic")

    def generate_playlist(
        self,
        mood_query: str,
        candidate_tracks: List[Dict[str, Any]],
        max_tracks: int = 50
    ) -> List[int]:
        """Ask Claude to pick up to ``max_tracks`` IDs from ``candidate_tracks``.

        Returns ``[]`` on any API or parsing failure.
        """
        try:
            request_text = self._prepare_prompt(mood_query, candidate_tracks, max_tracks)

            reply = self.client.messages.create(
                model=self.model,
                max_tokens=4096,
                temperature=self.temperature,
                messages=[{"role": "user", "content": request_text}],
            )

            # Short-circuit: only index content[0] when content is non-empty.
            has_text = bool(reply.content) and bool(reply.content[0].text)
            if not has_text:
                logger.error("Empty response from Claude")
                return []

            chosen = self._validate_selections(
                self._parse_response(reply.content[0].text),
                candidate_tracks,
            )

            logger.info(f"Claude selected {len(chosen)} tracks for mood: {mood_query}")
            return chosen[:max_tracks]

        except Exception as e:
            logger.error(f"Failed to generate playlist with Claude: {e}")
            return []
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import List, Dict, Any
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from .base import AIProvider
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GeminiProvider(AIProvider):
    """Playlist provider backed by Google's Gemini generative API."""

    def __init__(self, api_key: str, model: str = "gemini-2.5-flash", temperature: float = 0.7):
        super().__init__(api_key, model, temperature)
        try:
            import google.generativeai as genai
            genai.configure(api_key=api_key)
            self.genai = genai
            logger.info(f"Initialized Gemini AI provider with model {model}")
        except ImportError:
            raise ImportError("google-generativeai not installed. Run: pip install google-generativeai")

    def generate_playlist(
        self,
        mood_query: str,
        candidate_tracks: List[Dict[str, Any]],
        max_tracks: int = 50
    ) -> List[int]:
        """Ask Gemini to pick up to ``max_tracks`` IDs from ``candidate_tracks``.

        Returns ``[]`` on any API or parsing failure.
        """
        try:
            request_text = self._prepare_prompt(mood_query, candidate_tracks, max_tracks)

            engine = self.genai.GenerativeModel(
                model_name=self.model,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": 8192,
                },
            )
            reply = engine.generate_content(request_text)

            if not reply:
                logger.error("Empty response from Gemini")
                return []

            try:
                reply_text = reply.text
            except ValueError:
                # Accessing .text raises when the reply was blocked or
                # truncated; fall back to joining the raw candidate parts.
                parts = reply.candidates[0].content.parts if reply.candidates else None
                if not parts:
                    logger.error("Could not extract text from Gemini response")
                    return []
                reply_text = "".join(part.text for part in parts)

            if not reply_text:
                logger.error("Empty response text from Gemini")
                return []

            chosen = self._validate_selections(
                self._parse_response(reply_text),
                candidate_tracks,
            )

            logger.info(f"Gemini selected {len(chosen)} tracks for mood: {mood_query}")
            return chosen[:max_tracks]

        except Exception as e:
            logger.error(f"Failed to generate playlist with Gemini: {e}")
            return []
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import List, Dict, Any
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from .base import AIProvider
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OpenAIProvider(AIProvider):
    """Playlist provider backed by the OpenAI chat-completions API."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(api_key, model, temperature)
        try:
            from openai import OpenAI
            self.client = OpenAI(api_key=api_key)
            logger.info(f"Initialized OpenAI provider with model {model}")
        except ImportError:
            raise ImportError("openai not installed. Run: pip install openai")

    def generate_playlist(
        self,
        mood_query: str,
        candidate_tracks: List[Dict[str, Any]],
        max_tracks: int = 50
    ) -> List[int]:
        """Ask the model to pick up to ``max_tracks`` IDs from ``candidate_tracks``.

        Returns ``[]`` on any API or parsing failure.
        """
        try:
            request_text = self._prepare_prompt(mood_query, candidate_tracks, max_tracks)

            completion = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": request_text}],
                temperature=self.temperature,
                max_tokens=4096
            )

            choices = completion.choices
            if not choices or not choices[0].message.content:
                logger.error("Empty response from OpenAI")
                return []

            picked = self._parse_response(choices[0].message.content)
            kept = self._validate_selections(picked, candidate_tracks)

            logger.info(f"OpenAI selected {len(kept)} tracks for mood: {mood_query}")
            return kept[:max_tracks]

        except Exception as e:
            logger.error(f"Failed to generate playlist with OpenAI: {e}")
            return []
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
from typing import List, Dict, Any, Optional
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
import re
|
|
6
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
|
7
|
+
|
|
8
|
+
from .base import AIProvider
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TagGenerator:
    """Assigns descriptive tags, listening environments, and prominent
    instruments to tracks by querying the configured AI provider.
    """

    def __init__(self, ai_provider: "AIProvider"):
        # Forward-reference annotation so the class is importable without
        # evaluating the AIProvider name at definition time.  The provider
        # only needs .model plus either .genai (Gemini) or .client
        # (OpenAI / Anthropic) -- see _call_ai_provider.
        self.ai_provider = ai_provider

    @staticmethod
    def _empty_result(tracks: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
        """Empty tag data for every track -- the failure fallback."""
        return {track['id']: {'tags': [], 'environments': [], 'instruments': []} for track in tracks}

    @staticmethod
    def _coerce_str_list(value: Any, limit: int) -> List[str]:
        """Normalize a tags/environments/instruments field to a lowercase,
        stripped list capped at ``limit`` entries; a bare string becomes a
        one-element list, anything else becomes []."""
        if isinstance(value, list):
            return [str(item).lower().strip() for item in value[:limit]]
        if isinstance(value, str):
            return [value.lower().strip()]
        return []

    def generate_tags_batch(
        self,
        tracks: List[Dict[str, Any]],
        batch_size: int = 20
    ) -> Dict[int, Dict[str, Any]]:
        """Generate tag data for every track, querying the provider in
        batches of at most ``batch_size`` tracks.

        Fix: ``batch_size`` was previously accepted but ignored, so every
        track was sent in one request, risking context-window overflow on
        large libraries.  Tracks in a failed batch get empty lists rather
        than raising.  Returns {track_id: {'tags', 'environments',
        'instruments'}}.
        """
        logger.debug(f"Generating tags for {len(tracks)} tracks")
        batch_size = max(1, batch_size)  # guard against a non-positive size
        results: Dict[int, Dict[str, Any]] = {}
        for start in range(0, len(tracks), batch_size):
            results.update(self._generate_batch(tracks[start:start + batch_size]))
        return results

    def _generate_batch(self, tracks: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
        """Tag one batch of tracks, retrying transient API failures.

        Retries up to 3 times on malformed JSON, rate limits (429/quota),
        timeouts (504), and server errors (5xx).  Any other failure, or
        exhausting the retries, yields empty tag data for the whole batch.
        """
        prompt = self._prepare_tag_prompt(tracks)

        max_retries = 3
        base_delay = 1

        for attempt in range(max_retries):
            try:
                response = self._call_ai_provider(prompt)
                return self._parse_tag_response(response, tracks)
            except json.JSONDecodeError as e:
                if attempt < max_retries - 1:
                    delay = base_delay * (attempt + 1)
                    logger.warning(f"JSON parse error (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s...")
                    time.sleep(delay)
                    continue
                logger.error(f"Failed to parse JSON after {max_retries} attempts: {e}")
                return self._empty_result(tracks)
            except Exception as e:
                error_str = str(e)

                is_rate_limit = "429" in error_str or "quota" in error_str.lower() or "rate" in error_str.lower()
                is_timeout = "504" in error_str or "timeout" in error_str.lower() or "timed out" in error_str.lower()
                is_server_error = "500" in error_str or "502" in error_str or "503" in error_str

                # Non-transient failures are not retried.
                if not (is_rate_limit or is_timeout or is_server_error):
                    logger.error(f"Failed to generate tags for batch: {e}")
                    return self._empty_result(tracks)

                if attempt >= max_retries - 1:
                    logger.error(f"Failed after {max_retries} attempts: {e}")
                    return self._empty_result(tracks)

                retry_after = self._extract_retry_delay(error_str)
                if retry_after:
                    delay = retry_after * 1.5  # pad the server's hint to be safe
                    logger.warning(f"API error (attempt {attempt + 1}/{max_retries}). Server suggested {retry_after}s, using {delay:.1f}s with backoff...")
                else:
                    delay = base_delay * (2 ** attempt)  # exponential backoff
                    if is_rate_limit:
                        logger.warning(f"Rate limit hit (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s...")
                    elif is_timeout:
                        logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s...")
                    else:
                        logger.warning(f"Server error (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s...")

                time.sleep(delay)

        return self._empty_result(tracks)

    def _extract_retry_delay(self, error_message: str) -> Optional[float]:
        """Pull a server-suggested retry delay (seconds) out of an error
        message; supports Google's ``retry_delay { seconds: N }`` proto text
        and HTTP ``Retry-After: N``.  Returns None when no hint is found."""
        retry_match = re.search(r'retry_delay\s*\{\s*seconds:\s*(\d+)', error_message)
        if retry_match:
            return float(retry_match.group(1))

        retry_after_match = re.search(r'Retry-After:\s*(\d+)', error_message, re.IGNORECASE)
        if retry_after_match:
            return float(retry_after_match.group(1))

        return None

    def _prepare_tag_prompt(self, tracks: List[Dict[str, Any]]) -> str:
        """Build the tagging prompt for one batch of tracks."""
        system_prompt = """You are a music expert helping to categorize songs with descriptive tags, environment contexts, and primary instruments.

Your task is to analyze each song and provide:
1. **Tags** (3-5 descriptive tags)
2. **Environments** (top 3 best-fit contexts from: work, study, focus, relax, party, workout, sleep, driving, social)
3. **Instruments** (top 3 most prominent instruments from: piano, guitar, saxophone, trumpet, drums, bass, synth, vocals, strings, orchestra, flute, violin, cello, harmonica, accordion)

Tags should describe:
- Mood (e.g., energetic, melancholic, upbeat, chill, intense)
- Energy level (e.g., high-energy, low-energy, moderate)
- Tempo feel (e.g., fast-paced, slow, mid-tempo)
- Emotional tone (e.g., happy, sad, angry, romantic, nostalgic)

Rules:
1. Assign 3-5 tags per song
2. Assign 1-3 environments per song (ordered by best fit)
3. Assign 1-3 instruments per song (ordered by prominence)
4. Use lowercase for all fields
5. Be consistent with naming
6. Return ONLY a JSON object mapping track IDs to objects with tags, environments, and instruments

Example output format:
{
  "1": {
    "tags": ["energetic", "workout", "high-energy", "upbeat"],
    "environments": ["workout", "party", "driving"],
    "instruments": ["guitar", "drums", "bass"]
  },
  "2": {
    "tags": ["melancholic", "slow", "sad", "introspective"],
    "environments": ["study", "focus", "relax"],
    "instruments": ["piano", "strings"]
  }
}"""

        tracks_list = []
        for track in tracks:
            tracks_list.append({
                'id': track['id'],
                'title': track['title'],
                'artist': track['artist'],
                'genre': track.get('genre', 'unknown')
            })

        tracks_json = json.dumps(tracks_list, indent=2)

        # Fix: the closing instruction previously asked for "an array of 3-5
        # descriptive tags", contradicting the object format above and
        # biasing models toward the legacy list shape (which loses
        # environments/instruments).
        user_prompt = f"""Assign tags to the following songs:

{tracks_json}

Return a JSON object mapping each track ID to an object with "tags", "environments", and "instruments" as shown in the example format."""

        return system_prompt + "\n\n" + user_prompt

    def _call_ai_provider(self, prompt: str) -> str:
        """Send ``prompt`` through whichever SDK the provider wraps and
        return the raw response text.

        Dispatches structurally: Gemini providers expose ``.genai``, OpenAI
        exposes ``.client.chat``, Anthropic exposes ``.client.messages``.
        Fix: removed the unconditional ``import google.generativeai``, which
        made Gemini's SDK a hard requirement even when tagging through
        OpenAI or Claude.
        """
        try:
            if hasattr(self.ai_provider, 'genai'):
                model = self.ai_provider.genai.GenerativeModel(
                    model_name=self.ai_provider.model,
                    generation_config={
                        "temperature": 0.3,  # low temperature: tagging should be consistent
                        "max_output_tokens": 8192,
                    }
                )
                response = model.generate_content(prompt)

                if not response:
                    raise ValueError("Empty response from Gemini")

                try:
                    response_text = response.text
                except ValueError:
                    # .text raises when the reply was blocked; recover the
                    # raw candidate parts when present.
                    if response.candidates and response.candidates[0].content.parts:
                        response_text = "".join(part.text for part in response.candidates[0].content.parts)
                    else:
                        raise ValueError("Could not extract text from Gemini response")

                return response_text

            elif hasattr(self.ai_provider, 'client'):
                if hasattr(self.ai_provider.client, 'chat'):
                    response = self.ai_provider.client.chat.completions.create(
                        model=self.ai_provider.model,
                        messages=[{"role": "user", "content": prompt}],
                        temperature=0.3,
                        max_tokens=4096
                    )
                    return response.choices[0].message.content
                else:
                    response = self.ai_provider.client.messages.create(
                        model=self.ai_provider.model,
                        max_tokens=4096,
                        temperature=0.3,
                        messages=[{"role": "user", "content": prompt}]
                    )
                    return response.content[0].text

            else:
                raise ValueError("Unknown AI provider type")

        except Exception as e:
            logger.error(f"AI provider call failed: {e}")
            raise

    def _parse_tag_response(
        self,
        response: str,
        tracks: List[Dict[str, Any]]
    ) -> Dict[int, Dict[str, Any]]:
        """Parse the provider's JSON reply into per-track tag data.

        Tolerates markdown code fences, surrounding prose, and trailing
        commas.  Tracks missing from the reply get empty lists.  Raises
        ``json.JSONDecodeError`` when the reply still cannot be decoded so
        :meth:`_generate_batch` can retry.
        """
        try:
            response = response.strip()

            # Drop markdown code fences (``` / ```json).
            if response.startswith("```"):
                lines = response.split("\n")
                response = "\n".join([line for line in lines if not line.startswith("```")])

            # Keep only the outermost {...} in case the model added prose.
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                response = json_match.group(0)

            # Trailing commas are invalid JSON but common in model output.
            response = re.sub(r',\s*}', '}', response)
            response = re.sub(r',\s*\]', ']', response)

            tags_dict = json.loads(response)

            result = {}
            for track in tracks:
                track_id = track['id']
                data = tags_dict.get(str(track_id))

                if isinstance(data, dict):
                    result[track_id] = {
                        'tags': self._coerce_str_list(data.get('tags', []), 5),
                        'environments': self._coerce_str_list(data.get('environments', []), 3),
                        'instruments': self._coerce_str_list(data.get('instruments', []), 3),
                    }
                elif isinstance(data, list):
                    # Legacy shape: a bare list is treated as tags only.
                    result[track_id] = {
                        'tags': self._coerce_str_list(data, 5),
                        'environments': [],
                        'instruments': []
                    }
                else:
                    result[track_id] = {'tags': [], 'environments': [], 'instruments': []}

            return result

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            logger.debug(f"Problematic response (first 500 chars): {response[:500]}")
            raise
        except Exception as e:
            logger.error(f"Failed to parse tag response: {e}")
            raise
|
plexmix/cli/__init__.py
ADDED
|
File without changes
|