edge-gemma-speak 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edge_gemma_speak/__init__.py +26 -0
- edge_gemma_speak/cli.py +305 -0
- edge_gemma_speak/voice_assistant.py +661 -0
- edge_gemma_speak-0.1.0.dist-info/METADATA +376 -0
- edge_gemma_speak-0.1.0.dist-info/RECORD +9 -0
- edge_gemma_speak-0.1.0.dist-info/WHEEL +5 -0
- edge_gemma_speak-0.1.0.dist-info/entry_points.txt +2 -0
- edge_gemma_speak-0.1.0.dist-info/licenses/LICENSE +21 -0
- edge_gemma_speak-0.1.0.dist-info/top_level.txt +1 -0
"""
Edge Gemma Speak - Edge-based voice assistant using Gemma LLM with STT and TTS capabilities
"""

# Re-export the public API from the implementation module so callers can do
# `from edge_gemma_speak import VoiceAssistant` without knowing the layout.
from .voice_assistant import (
    VoiceAssistant,
    STTModule,
    LLMModule,
    TTSModule,
    AudioConfig,
    ModelConfig,
    main
)

# Package metadata.
__version__ = "0.1.0"
__author__ = "MimicLab, Sogang University"

# Explicit public API surface (mirrors the re-export list above).
__all__ = [
    "VoiceAssistant",
    "STTModule",
    "LLMModule",
    "TTSModule",
    "AudioConfig",
    "ModelConfig",
    "main"
]
edge_gemma_speak/cli.py
ADDED
@@ -0,0 +1,305 @@
|
|
1
|
+
"""
|
2
|
+
Command-line interface for edge_gemma_speak
|
3
|
+
"""
|
4
|
+
|
5
|
+
import argparse
|
6
|
+
import sys
|
7
|
+
from pathlib import Path
|
8
|
+
from .voice_assistant import main as voice_assistant_main, ModelConfig, AudioConfig
|
9
|
+
|
10
|
+
|
11
|
+
def download_model():
    """Download the default Gemma GGUF model into ~/.edge_gemma_speak/models.

    Tries external downloaders first (wget with resume, then curl with
    resume), falling back to a pure-Python urllib download. On any failure
    the partially-downloaded file is removed and the process exits with
    status 1. Returns early if the model file already exists.
    """
    import subprocess

    model_dir = Path.home() / ".edge_gemma_speak" / "models"
    model_dir.mkdir(parents=True, exist_ok=True)

    model_filename = "gemma-3-12b-it-Q4_K_M.gguf"
    model_path = model_dir / model_filename

    if model_path.exists():
        print(f"✓ Model already exists at {model_path}")
        return

    model_url = "https://huggingface.co/tgisaturday/Docsray/resolve/main/gemma-3-12b-it-GGUF/gemma-3-12b-it-Q4_K_M.gguf"

    print(f"Downloading Gemma model to {model_path}")
    print("This may take a while depending on your internet connection...")
    print()

    try:
        # Prefer wget (resumable via -c), then curl (-C -), then pure Python.
        if subprocess.run(["which", "wget"], capture_output=True, text=True).returncode == 0:
            cmd = ["wget", "-c", model_url, "-O", str(model_path)]
            print(f"Using wget: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        elif subprocess.run(["which", "curl"], capture_output=True, text=True).returncode == 0:
            cmd = ["curl", "-L", "-C", "-", model_url, "-o", str(model_path)]
            print(f"Using curl: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        else:
            print("Neither wget nor curl found. Using Python to download...")
            _download_with_urllib(model_url, model_path)

        print(f"\n✓ Model downloaded successfully to {model_path}")

    except Exception as e:
        # Single handler replaces the two identical CalledProcessError /
        # Exception handlers in the original.
        print(f"\n✗ Download failed with error: {e}")
        # Remove any partial file so a retry starts clean (wget/curl resume
        # would otherwise continue from a possibly corrupt prefix).
        if model_path.exists():
            model_path.unlink()
        sys.exit(1)


def _download_with_urllib(url, path):
    """Stream *url* to *path* with urllib, showing a tqdm bar if available.

    tqdm is a third-party optional dependency here; the download must still
    work (silently) when it is not installed.
    """
    import urllib.request

    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None  # progress bar is optional

    with urllib.request.urlopen(url) as response:
        total_size = int(response.headers.get('Content-Length', 0))
        pbar = tqdm(total=total_size, unit='iB', unit_scale=True) if tqdm else None
        try:
            with open(path, 'wb') as f:
                while True:
                    chunk = response.read(8192)
                    if not chunk:
                        break
                    f.write(chunk)
                    if pbar is not None:
                        pbar.update(len(chunk))
        finally:
            if pbar is not None:
                pbar.close()
|
+
def _list_tts_voices():
    """Fetch and print available Edge-TTS voices grouped by language.

    Shells out to the `edge-tts --list-voices` CLI and parses its table
    output. Korean voices are listed first, then up to five voices for a
    handful of other popular languages.
    """
    import subprocess

    print("\nFetching available Edge-TTS voices...")
    print("=" * 70)

    try:
        result = subprocess.run(
            ["edge-tts", "--list-voices"],
            capture_output=True,
            text=True,
            check=True
        )

        # Parse the table output; the first two lines are header + separator.
        voices = []
        lines = result.stdout.strip().split('\n')
        if len(lines) > 2:
            for line in lines[2:]:
                parts = line.split()
                if len(parts) >= 2:
                    name, gender = parts[0], parts[1]
                    # Locale is the first two dash-separated components
                    # (voice names look like lang-COUNTRY-NameNeural).
                    locale_parts = name.split('-')
                    if len(locale_parts) >= 2:
                        locale = f"{locale_parts[0]}-{locale_parts[1]}"
                    else:
                        locale = "Unknown"
                    voices.append({'name': name, 'gender': gender, 'locale': locale})

        # Group voices by the language prefix of their locale.
        voices_by_lang = {}
        for voice in voices:
            lang = voice['locale'].split('-')[0]
            voices_by_lang.setdefault(lang, []).append(voice)

        print("\nAvailable Edge-TTS voices by language:")
        print("=" * 70)

        # Korean first: it is this tool's primary audience.
        if 'ko' in voices_by_lang:
            print("\n[Korean Voices]")
            for voice in voices_by_lang['ko']:
                gender = voice.get('gender', 'Unknown')
                print(f" {voice['name']:<35} ({gender}, {voice.get('locale', 'Unknown')})")

        # A short sample for other popular languages.
        for lang in ['en', 'ja', 'zh', 'es', 'fr', 'de']:
            if lang in voices_by_lang:
                lang_name = {
                    'en': 'English', 'ja': 'Japanese', 'zh': 'Chinese',
                    'es': 'Spanish', 'fr': 'French', 'de': 'German'
                }.get(lang, lang.upper())
                print(f"\n[{lang_name} Voices]")
                for voice in voices_by_lang[lang][:5]:  # Show first 5 voices
                    gender = voice.get('gender', 'Unknown')
                    print(f" {voice['name']:<35} ({gender}, {voice.get('locale', 'Unknown')})")
                if len(voices_by_lang[lang]) > 5:
                    print(f" ... and {len(voices_by_lang[lang]) - 5} more")

        print(f"\n\nTotal voices available: {len(voices)}")
        print("\nUsage examples:")
        print(" edge-gemma-speak --voice ko-KR-InJoonNeural")
        print(" edge-gemma-speak --voice en-US-JennyNeural")
        print(" edge-gemma-speak --voice ja-JP-NanamiNeural")

        print("\nQuick presets:")
        print(" edge-gemma-speak --voice male (Korean male)")
        print(" edge-gemma-speak --voice female (Korean female)")

    except subprocess.CalledProcessError:
        print("Error: Could not fetch Edge-TTS voices.")
        print("Make sure edge-tts is installed: pip install edge-tts")
    except Exception as e:
        print(f"Error: {e}")

    print("=" * 70)


def _resolve_stt_language(requested, tts_voice, preset_voices):
    """Return the STT language, auto-syncing it to a non-preset TTS voice.

    Only overrides when the user left STT at its Korean default ("ko") and
    chose a voice outside the Korean presets whose language prefix is one we
    recognize; otherwise the requested language is returned unchanged.
    """
    if requested == "ko" and tts_voice not in preset_voices:
        voice_parts = tts_voice.split('-')
        if len(voice_parts) >= 2:
            voice_lang = voice_parts[0]
            if voice_lang in ['en', 'ja', 'zh', 'es', 'fr', 'de', 'ko']:
                print(f"Notice: STT language automatically set to '{voice_lang}' to match TTS voice '{tts_voice}'")
                return voice_lang
    return requested


def main():
    """CLI entry point: parse arguments and run the voice assistant.

    Handles the one-shot actions (--download-model, --list-voices) first,
    then builds ModelConfig/AudioConfig and starts the conversation loop.
    Exits with status 1 on errors, 0 on normal/interrupted termination.
    """
    parser = argparse.ArgumentParser(description="Edge Gemma Speak - Voice Assistant")
    parser.add_argument("--model", type=str, help="Path to GGUF model file")
    parser.add_argument("--stt-model", type=str, default="base",
                        choices=["tiny", "base", "small", "medium", "large"],
                        help="Whisper model size for STT")
    parser.add_argument("--device", type=str, default=None,
                        choices=["cpu", "cuda", "mps", "auto"],
                        help="Device to use for inference (default: auto-detect)")
    parser.add_argument("--voice", type=str, default="multilingual",
                        help="TTS voice: use preset (male/female/multilingual) or any Edge-TTS voice name")
    parser.add_argument("--download-model", action="store_true",
                        help="Download the default Gemma model")
    parser.add_argument("--list-voices", action="store_true",
                        help="List all available Korean TTS voices")

    # STT parameters
    parser.add_argument("--stt-language", type=str, default="ko",
                        help="STT language (default: ko)")
    parser.add_argument("--stt-beam-size", type=int, default=5,
                        help="STT beam size for decoding (default: 5)")
    parser.add_argument("--stt-temperature", type=float, default=0.0,
                        help="STT temperature for sampling (default: 0.0)")
    parser.add_argument("--stt-vad-threshold", type=float, default=0.5,
                        help="STT VAD threshold (default: 0.5)")

    # LLM parameters
    parser.add_argument("--llm-max-tokens", type=int, default=512,
                        help="Maximum tokens for LLM response (default: 512)")
    parser.add_argument("--llm-temperature", type=float, default=0.7,
                        help="LLM temperature for sampling (default: 0.7)")
    parser.add_argument("--llm-top-p", type=float, default=0.95,
                        help="LLM top-p for nucleus sampling (default: 0.95)")
    parser.add_argument("--llm-context-size", type=int, default=4096,
                        help="LLM context window size (default: 4096)")

    args = parser.parse_args()

    if args.download_model:
        download_model()
        sys.exit(0)

    if args.list_voices:
        _list_tts_voices()
        sys.exit(0)

    # Map preset names to concrete Edge-TTS voices; anything else is passed
    # through as a literal voice name.
    voice_map = {
        "male": "ko-KR-InJoonNeural",
        "female": "ko-KR-SunHiNeural",
        "multilingual": "ko-KR-HyunsuMultilingualNeural"
    }
    tts_voice = voice_map.get(args.voice, args.voice)

    stt_language = _resolve_stt_language(args.stt_language, tts_voice, voice_map.values())

    # Warn when a non-preset voice's language differs from the STT language.
    # BUG FIX: the original tested `tts_voice not in voice_map` (the preset
    # KEYS), which is always true for a resolved voice name, so the warning
    # path also ran for presets; compare against the preset values instead,
    # consistent with _resolve_stt_language above.
    if tts_voice not in voice_map.values():
        voice_lang = tts_voice.split('-')[0] if '-' in tts_voice else None
        if voice_lang and voice_lang != stt_language:
            print(f"Warning: TTS voice language '{voice_lang}' doesn't match STT language '{stt_language}'")
            print(f" Consider using --stt-language {voice_lang} for better recognition")

    # Device auto-detection happens in ModelConfig.__post_init__.
    device = args.device if args.device else "auto"

    model_config = ModelConfig(
        stt_model=args.stt_model,
        llm_model=args.model,
        device=device,
        # STT parameters
        stt_language=stt_language,
        stt_beam_size=args.stt_beam_size,
        stt_temperature=args.stt_temperature,
        stt_vad_threshold=args.stt_vad_threshold,
        # TTS parameters
        tts_voice=tts_voice,
        # LLM parameters
        llm_max_tokens=args.llm_max_tokens,
        llm_temperature=args.llm_temperature,
        llm_top_p=args.llm_top_p,
        llm_context_size=args.llm_context_size
    )

    audio_config = AudioConfig()

    # Run the voice assistant.
    try:
        from .voice_assistant import VoiceAssistant

        assistant = VoiceAssistant(model_config, audio_config)
        assistant.run_conversation_loop()
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("\nTo download the model, run:")
        print("  edge-gemma-speak --download-model")
        sys.exit(1)
    except KeyboardInterrupt:
        print("\n\nProgram interrupted by user.")
        sys.exit(0)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
# Allow running this module directly (python -m edge_gemma_speak.cli).
if __name__ == "__main__":
    main()