videopython-0.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- videopython/__init__.py +3 -0
- videopython/ai/__init__.py +39 -0
- videopython/ai/backends.py +77 -0
- videopython/ai/config.py +125 -0
- videopython/ai/generation/__init__.py +11 -0
- videopython/ai/generation/audio.py +285 -0
- videopython/ai/generation/image.py +117 -0
- videopython/ai/generation/video.py +366 -0
- videopython/ai/understanding/__init__.py +30 -0
- videopython/ai/understanding/audio.py +436 -0
- videopython/ai/understanding/color.py +135 -0
- videopython/ai/understanding/detection.py +914 -0
- videopython/ai/understanding/image.py +274 -0
- videopython/ai/understanding/motion.py +323 -0
- videopython/ai/understanding/text.py +208 -0
- videopython/ai/understanding/video.py +570 -0
- videopython/base/__init__.py +82 -0
- videopython/base/audio/__init__.py +12 -0
- videopython/base/audio/analysis.py +92 -0
- videopython/base/audio/audio.py +992 -0
- videopython/base/combine.py +60 -0
- videopython/base/description.py +345 -0
- videopython/base/effects.py +231 -0
- videopython/base/exceptions.py +2 -0
- videopython/base/scene.py +457 -0
- videopython/base/text/__init__.py +12 -0
- videopython/base/text/overlay.py +1108 -0
- videopython/base/text/transcription.py +289 -0
- videopython/base/transforms.py +269 -0
- videopython/base/transitions.py +117 -0
- videopython/base/utils.py +6 -0
- videopython/base/video.py +833 -0
- videopython/py.typed +0 -0
- videopython-0.9.1.dist-info/METADATA +145 -0
- videopython-0.9.1.dist-info/RECORD +37 -0
- videopython-0.9.1.dist-info/WHEEL +4 -0
- videopython-0.9.1.dist-info/licenses/LICENSE +192 -0
videopython/ai/__init__.py
ADDED
@@ -0,0 +1,39 @@
from .generation import ImageToVideo, TextToImage, TextToMusic, TextToSpeech, TextToVideo
from .understanding import (
    AudioClassifier,
    AudioToText,
    CameraMotionDetector,
    CombinedFrameAnalyzer,
    FaceDetector,
    ImageToText,
    LLMSummarizer,
    MotionAnalyzer,
    ObjectDetector,
    ShotTypeClassifier,
    TextDetector,
    VideoAnalyzer,
)

__all__ = [
    # Generation
    "TextToVideo",
    "ImageToVideo",
    "TextToImage",
    "TextToSpeech",
    "TextToMusic",
    # Understanding
    "AudioToText",
    "AudioClassifier",
    "ImageToText",
    "LLMSummarizer",
    "VideoAnalyzer",
    # Detection
    "ObjectDetector",
    "FaceDetector",
    "TextDetector",
    "ShotTypeClassifier",
    "CameraMotionDetector",
    "CombinedFrameAnalyzer",
    # Motion
    "MotionAnalyzer",
]
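For orientation, a minimal usage sketch of the re-exported API (illustrative, not part of the package; it assumes the optional model dependencies are installed and, for cloud backends, the relevant API keys are available):

from videopython.ai import TextToMusic, TextToSpeech

# Cloud backend: needs OPENAI_API_KEY (see backends.py below).
speech = TextToSpeech(backend="openai").generate_audio("Hello from videopython.")

# Local backend: downloads facebook/musicgen-small on first use.
music = TextToMusic().generate_audio("lo-fi beat", max_new_tokens=256)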
videopython/ai/backends.py
ADDED
@@ -0,0 +1,77 @@
"""Backend utilities for videopython.ai module."""

from __future__ import annotations

import os
from typing import Literal

# Backend type definitions per task
TextToVideoBackend = Literal["local", "luma"]
ImageToVideoBackend = Literal["local", "luma", "runway"]
TextToSpeechBackend = Literal["local", "openai", "elevenlabs"]
TextToMusicBackend = Literal["local"]
TextToImageBackend = Literal["local", "openai"]
ImageToTextBackend = Literal["local", "openai", "gemini"]
AudioToTextBackend = Literal["local", "openai", "gemini"]
AudioClassifierBackend = Literal["local"]
LLMBackend = Literal["local", "openai", "gemini"]

# Environment variable names per provider
API_KEY_ENV_VARS: dict[str, str] = {
    "openai": "OPENAI_API_KEY",
    "gemini": "GOOGLE_API_KEY",
    "elevenlabs": "ELEVENLABS_API_KEY",
    "runway": "RUNWAYML_API_KEY",
    "luma": "LUMAAI_API_KEY",
}


class BackendError(Exception):
    """Base exception for backend-related errors."""

    pass


class MissingAPIKeyError(BackendError):
    """Raised when a required API key is not found."""

    def __init__(self, provider: str):
        env_var = API_KEY_ENV_VARS.get(provider, f"{provider.upper()}_API_KEY")
        super().__init__(
            f"API key for '{provider}' not found. Set the {env_var} environment variable or pass api_key parameter."
        )
        self.provider = provider


class UnsupportedBackendError(BackendError):
    """Raised when an unsupported backend is requested."""

    def __init__(self, backend: str, supported: list[str]):
        super().__init__(f"Backend '{backend}' is not supported. Supported backends: {', '.join(supported)}")
        self.backend = backend
        self.supported = supported


def get_api_key(provider: str, api_key: str | None = None) -> str:
    """Get API key for a provider.

    Args:
        provider: Provider name (e.g., 'openai', 'runway', 'luma')
        api_key: Optional explicit API key. If provided, returns this directly.

    Returns:
        The API key string.

    Raises:
        MissingAPIKeyError: If no API key is found.
    """
    if api_key:
        return api_key

    env_var = API_KEY_ENV_VARS.get(provider)
    if env_var:
        key = os.environ.get(env_var)
        if key:
            return key

    raise MissingAPIKeyError(provider)
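A short sketch of how these helpers compose (illustrative, not from the package; "sk-..." is a placeholder key):

import os

from videopython.ai.backends import MissingAPIKeyError, get_api_key

os.environ["OPENAI_API_KEY"] = "sk-..."           # picked up when no explicit key is passed
print(get_api_key("openai"))                      # -> "sk-..."
print(get_api_key("openai", api_key="override"))  # explicit argument wins

try:
    get_api_key("luma")                           # fails if LUMAAI_API_KEY is unset
except MissingAPIKeyError as err:
    print(err)                                    # message names the LUMAAI_API_KEY variable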
videopython/ai/config.py
ADDED
@@ -0,0 +1,125 @@
"""Configuration loader for videopython.ai module."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Any

import tomllib

# Default backends per task (used when no config file is found)
DEFAULT_BACKENDS: dict[str, str] = {
    "text_to_video": "local",
    "image_to_video": "local",
    "video_upscaler": "local",
    "text_to_speech": "local",
    "text_to_music": "local",
    "text_to_image": "local",
    "image_to_text": "local",
    "audio_to_text": "local",
    "audio_classifier": "local",
    "llm_summarizer": "local",
}


def _find_config_file() -> Path | None:
    """Find the configuration file in the current directory.

    Looks for:
    1. videopython.toml in current directory
    2. pyproject.toml in current directory

    Returns:
        Path to config file if found, None otherwise.
    """
    cwd = Path.cwd()

    # Check for videopython.toml
    videopython_toml = cwd / "videopython.toml"
    if videopython_toml.exists():
        return videopython_toml

    # Check for pyproject.toml
    pyproject_toml = cwd / "pyproject.toml"
    if pyproject_toml.exists():
        return pyproject_toml

    return None


def _load_toml(path: Path) -> dict[str, Any]:
    """Load and parse a TOML file."""
    with open(path, "rb") as f:
        return tomllib.load(f)


def _extract_config(data: dict[str, Any], filename: str) -> dict[str, Any]:
    """Extract videopython config from parsed TOML data.

    Args:
        data: Parsed TOML data
        filename: Name of the file (to determine extraction method)

    Returns:
        The videopython configuration section, or empty dict if not found.
    """
    if filename == "videopython.toml":
        return data
    elif filename == "pyproject.toml":
        return data.get("tool", {}).get("videopython", {})
    return {}


@lru_cache(maxsize=1)
def _get_cached_config() -> dict[str, Any]:
    """Load and cache the configuration.

    Returns:
        The loaded configuration, or empty dict if no config file found.
    """
    config_path = _find_config_file()
    if config_path is None:
        return {}

    try:
        data = _load_toml(config_path)
        return _extract_config(data, config_path.name)
    except Exception:
        return {}


def get_config() -> dict[str, Any]:
    """Get the current configuration.

    Returns:
        The configuration dictionary.
    """
    return _get_cached_config()


def get_default_backend(task: str) -> str:
    """Get the default backend for a task.

    Priority:
    1. Config file setting
    2. Hardcoded default ("local")

    Args:
        task: Task name (e.g., 'text_to_video', 'text_to_speech')

    Returns:
        The backend name to use.
    """
    config = get_config()
    ai_defaults = config.get("ai", {}).get("defaults", {})

    if task in ai_defaults:
        return ai_defaults[task]

    return DEFAULT_BACKENDS.get(task, "local")


def clear_config_cache() -> None:
    """Clear the configuration cache. Useful for testing."""
    _get_cached_config.cache_clear()
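To illustrate the lookup order, a hypothetical end-to-end check (the [ai.defaults] table shape is inferred from get_default_backend, which reads config["ai"]["defaults"]; in pyproject.toml the same table would nest under [tool.videopython]):

from pathlib import Path

from videopython.ai.config import clear_config_cache, get_default_backend

# Hypothetical project config written to the working directory.
Path("videopython.toml").write_text('[ai.defaults]\ntext_to_speech = "openai"\n')

clear_config_cache()  # the loader is wrapped in lru_cache, so drop the stale entry
print(get_default_backend("text_to_speech"))  # -> "openai" (from the config file)
print(get_default_backend("text_to_video"))   # -> "local" (hardcoded default)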
videopython/ai/generation/audio.py
ADDED
@@ -0,0 +1,285 @@
"""Audio generation with multi-backend support."""

from __future__ import annotations

from typing import Any

from videopython.ai.backends import (
    TextToMusicBackend,
    TextToSpeechBackend,
    UnsupportedBackendError,
    get_api_key,
)
from videopython.ai.config import get_default_backend
from videopython.base.audio import Audio, AudioMetadata


class TextToSpeech:
    """Generates speech audio from text."""

    SUPPORTED_BACKENDS: list[str] = ["local", "openai", "elevenlabs"]

    def __init__(
        self,
        backend: TextToSpeechBackend | None = None,
        model_size: str = "base",
        voice: str | None = None,
        api_key: str | None = None,
        device: str | None = None,
    ):
        """Initialize text-to-speech generator.

        Args:
            backend: Backend to use. If None, uses config default or 'local'.
            model_size: Model size for local backend ('base' or 'small').
            voice: Voice to use (backend-specific).
            api_key: API key for cloud backends. If None, reads from environment.
            device: Device for local backend ('cuda' or 'cpu').
        """
        resolved_backend: str = backend if backend is not None else get_default_backend("text_to_speech")
        if resolved_backend not in self.SUPPORTED_BACKENDS:
            raise UnsupportedBackendError(resolved_backend, self.SUPPORTED_BACKENDS)

        self.backend: TextToSpeechBackend = resolved_backend  # type: ignore[assignment]
        self.model_size = model_size
        self.voice = voice
        self.api_key = api_key
        self.device = device
        self._model: Any = None
        self._processor: Any = None

    def _init_local(self) -> None:
        """Initialize local Bark model."""
        import torch
        from transformers import AutoModel, AutoProcessor

        if self.model_size not in ["base", "small"]:
            raise ValueError(f"model_size must be 'base' or 'small', got '{self.model_size}'")

        device = self.device
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        model_name = "suno/bark" if self.model_size == "base" else "suno/bark-small"
        self._processor = AutoProcessor.from_pretrained(model_name)
        self._model = AutoModel.from_pretrained(model_name).to(device)
        self.device = device

    def _generate_local(
        self,
        text: str,
        voice_preset: str | None,
    ) -> Audio:
        """Generate speech using local Bark model."""
        import torch

        if self._model is None:
            self._init_local()

        inputs = self._processor(text=[text], return_tensors="pt", voice_preset=voice_preset)
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        with torch.no_grad():
            speech_values = self._model.generate(**inputs, do_sample=True)

        audio_data = speech_values.cpu().float().numpy().squeeze()
        sample_rate = self._model.generation_config.sample_rate

        metadata = AudioMetadata(
            sample_rate=sample_rate,
            channels=1,
            sample_width=2,
            duration_seconds=len(audio_data) / sample_rate,
            frame_count=len(audio_data),
        )
        return Audio(audio_data, metadata)

    def _generate_openai(self, text: str) -> Audio:
        """Generate speech using OpenAI TTS."""
        import numpy as np
        from openai import OpenAI

        api_key = get_api_key("openai", self.api_key)
        client = OpenAI(api_key=api_key)

        voice = self.voice or "alloy"
        response = client.audio.speech.create(
            model="tts-1-hd",
            voice=voice,  # type: ignore
            input=text,
            response_format="pcm",
        )

        # OpenAI returns raw PCM at 24kHz, 16-bit, mono
        audio_bytes = response.read()
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        sample_rate = 24000

        metadata = AudioMetadata(
            sample_rate=sample_rate,
            channels=1,
            sample_width=2,
            duration_seconds=len(audio_data) / sample_rate,
            frame_count=len(audio_data),
        )
        return Audio(audio_data, metadata)

    def _generate_elevenlabs(self, text: str) -> Audio:
        """Generate speech using ElevenLabs."""
        import numpy as np
        from elevenlabs import ElevenLabs

        api_key = get_api_key("elevenlabs", self.api_key)
        client = ElevenLabs(api_key=api_key)

        voice = self.voice or "Sarah"

        # Resolve voice name to ID if needed (voice IDs are 20+ chars)
        if len(voice) < 20:
            voices = client.voices.get_all()
            voice_id = None
            for v in voices.voices:
                if v.name and voice.lower() in v.name.lower():
                    voice_id = v.voice_id
                    break
            if voice_id is None:
                raise ValueError(f"Voice '{voice}' not found. Use a voice ID or valid name.")
            voice = voice_id

        # Generate audio - returns generator
        audio_chunks = []
        for chunk in client.text_to_speech.convert(
            voice_id=voice,
            text=text,
            model_id="eleven_multilingual_v2",
            output_format="pcm_24000",
        ):
            audio_chunks.append(chunk)

        audio_bytes = b"".join(audio_chunks)
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        sample_rate = 24000

        metadata = AudioMetadata(
            sample_rate=sample_rate,
            channels=1,
            sample_width=2,
            duration_seconds=len(audio_data) / sample_rate,
            frame_count=len(audio_data),
        )
        return Audio(audio_data, metadata)

    def generate_audio(
        self,
        text: str,
        voice_preset: str | None = None,
    ) -> Audio:
        """Generate speech audio from text.

        Args:
            text: Text to synthesize. For local backend, can include emotion markers
                like [laughs], [sighs].
            voice_preset: Voice preset (backend-specific). For local backend, use
                IDs like "v2/en_speaker_0".

        Returns:
            Generated speech audio.
        """
        effective_voice = voice_preset or self.voice

        if self.backend == "local":
            return self._generate_local(text, effective_voice)
        elif self.backend == "openai":
            return self._generate_openai(text)
        elif self.backend == "elevenlabs":
            return self._generate_elevenlabs(text)
        else:
            raise UnsupportedBackendError(self.backend, self.SUPPORTED_BACKENDS)


class TextToMusic:
    """Generates music from text descriptions."""

    SUPPORTED_BACKENDS: list[str] = ["local"]

    def __init__(
        self,
        backend: TextToMusicBackend | None = None,
    ):
        """Initialize text-to-music generator.

        Args:
            backend: Backend to use. If None, uses config default or 'local'.
        """
        resolved_backend: str = backend if backend is not None else get_default_backend("text_to_music")
        if resolved_backend not in self.SUPPORTED_BACKENDS:
            raise UnsupportedBackendError(resolved_backend, self.SUPPORTED_BACKENDS)

        self.backend: TextToMusicBackend = resolved_backend  # type: ignore[assignment]
        self._processor: Any = None
        self._model: Any = None
        self._device: str = "cpu"

    def _init_local(self) -> None:
        """Initialize local MusicGen model."""
        import os

        import torch
        from transformers import AutoProcessor, MusicgenForConditionalGeneration

        # Enable MPS fallback for unsupported operations
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

        if torch.cuda.is_available():
            self._device = "cuda"
        elif torch.backends.mps.is_available():
            self._device = "mps"
        else:
            self._device = "cpu"

        model_name = "facebook/musicgen-small"
        self._processor = AutoProcessor.from_pretrained(model_name)
        self._model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        self._model.to(self._device)

    def _generate_local(self, text: str, max_new_tokens: int) -> Audio:
        """Generate music using local MusicGen model."""
        if self._model is None:
            self._init_local()

        inputs = self._processor(text=[text], padding=True, return_tensors="pt")
        # Move inputs to the same device as the model
        inputs = {k: v.to(self._device) if hasattr(v, "to") else v for k, v in inputs.items()}
        audio_values = self._model.generate(**inputs, max_new_tokens=max_new_tokens)
        sampling_rate = self._model.config.audio_encoder.sampling_rate

        audio_data = audio_values[0, 0].cpu().float().numpy()

        metadata = AudioMetadata(
            sample_rate=sampling_rate,
            channels=1,
            sample_width=2,
            duration_seconds=len(audio_data) / sampling_rate,
            frame_count=len(audio_data),
        )
        return Audio(audio_data, metadata)

    def generate_audio(self, text: str, max_new_tokens: int = 256) -> Audio:
        """Generate music audio from text description.

        Args:
            text: Text description of desired music.
            max_new_tokens: Maximum length of generated audio in tokens.

        Returns:
            Generated music audio.
        """
        if self.backend == "local":
            return self._generate_local(text, max_new_tokens)
        else:
            raise UnsupportedBackendError(self.backend, self.SUPPORTED_BACKENDS)
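A usage sketch for the two generators (illustrative, not from the package); voice handling differs per backend, and the local Bark path accepts preset IDs while ElevenLabs accepts either a voice name or a raw voice ID:

from videopython.ai.generation.audio import TextToMusic, TextToSpeech

# Local Bark: voice presets are Bark speaker IDs.
bark = TextToSpeech(backend="local", model_size="small")
speech = bark.generate_audio("Hello there! [laughs]", voice_preset="v2/en_speaker_0")

# ElevenLabs: names shorter than 20 chars are resolved to voice IDs via the API.
eleven = TextToSpeech(backend="elevenlabs", voice="Sarah")
speech_cloud = eleven.generate_audio("Hello from the cloud.")

# MusicGen: max_new_tokens bounds clip length (roughly 50 tokens per second).
music = TextToMusic().generate_audio("ambient synth pad", max_new_tokens=512)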
videopython/ai/generation/image.py
ADDED
@@ -0,0 +1,117 @@
"""Image generation with multi-backend support."""

from __future__ import annotations

import io
from typing import Any

from PIL import Image

from videopython.ai.backends import TextToImageBackend, UnsupportedBackendError, get_api_key
from videopython.ai.config import get_default_backend


class TextToImage:
    """Generates images from text descriptions."""

    SUPPORTED_BACKENDS: list[str] = ["local", "openai"]

    def __init__(
        self,
        backend: TextToImageBackend | None = None,
        api_key: str | None = None,
    ):
        """Initialize text-to-image generator.

        Args:
            backend: Backend to use. If None, uses config default or 'local'.
            api_key: API key for cloud backends. If None, reads from environment.
        """
        resolved_backend: str = backend if backend is not None else get_default_backend("text_to_image")
        if resolved_backend not in self.SUPPORTED_BACKENDS:
            raise UnsupportedBackendError(resolved_backend, self.SUPPORTED_BACKENDS)

        self.backend: TextToImageBackend = resolved_backend  # type: ignore[assignment]
        self.api_key = api_key
        self._pipeline: Any = None

    def _init_local(self) -> None:
        """Initialize local diffusion pipeline."""
        import torch
        from diffusers import DiffusionPipeline

        if torch.cuda.is_available():
            device = "cuda"
            dtype = torch.float16
            variant = "fp16"
        elif torch.backends.mps.is_available():
            device = "mps"
            dtype = torch.float32  # MPS works better with float32
            variant = None  # No fp16 variant for MPS
        else:
            raise ValueError("No GPU available. Local TextToImage requires CUDA or MPS (Apple Silicon).")

        model_name = "stabilityai/stable-diffusion-xl-base-1.0"
        self._pipeline = DiffusionPipeline.from_pretrained(
            model_name, torch_dtype=dtype, variant=variant, use_safetensors=True
        )
        self._pipeline.to(device)

        # Enable attention slicing for MPS memory efficiency
        if device == "mps":
            self._pipeline.enable_attention_slicing()

    def _generate_local(self, prompt: str) -> Image.Image:
        """Generate image using local diffusion model."""
        if self._pipeline is None:
            self._init_local()

        return self._pipeline(prompt=prompt).images[0]

    def _generate_openai(self, prompt: str, size: str) -> Image.Image:
        """Generate image using OpenAI DALL-E."""
        import httpx
        from openai import OpenAI

        api_key = get_api_key("openai", self.api_key)
        client = OpenAI(api_key=api_key)

        response = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size=size,  # type: ignore
            quality="hd",
            n=1,
        )

        image_url = response.data[0].url
        if image_url is None:
            raise RuntimeError("OpenAI returned no image URL")

        # Download the image
        with httpx.Client() as http_client:
            image_response = http_client.get(image_url, timeout=60.0)
            image_response.raise_for_status()

        return Image.open(io.BytesIO(image_response.content))

    def generate_image(
        self,
        prompt: str,
        size: str = "1024x1024",
    ) -> Image.Image:
        """Generate image from text prompt.

        Args:
            prompt: Text description of desired image.
            size: Image size (OpenAI backend only). Options: "1024x1024", "1792x1024", "1024x1792".

        Returns:
            Generated PIL Image.
        """
        if self.backend == "local":
            return self._generate_local(prompt)
        elif self.backend == "openai":
            return self._generate_openai(prompt, size)
        else:
            raise UnsupportedBackendError(self.backend, self.SUPPORTED_BACKENDS)
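A closing sketch tying image generation to the backend machinery above (illustrative, not from the package; the cloud path assumes OPENAI_API_KEY is set):

from videopython.ai.generation.image import TextToImage

# Cloud path: DALL-E 3; the result URL is downloaded and returned as a PIL.Image.
img = TextToImage(backend="openai").generate_image(
    "a lighthouse at dusk, oil painting",
    size="1792x1024",
)
img.save("lighthouse.png")

# Local path: SDXL via diffusers; raises ValueError without CUDA or MPS.
local_img = TextToImage(backend="local").generate_image("a red bicycle")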