atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
- audio_engine/asr/__init__.py +45 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +356 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +282 -0
- audio_engine/core/types.py +87 -0
- audio_engine/examples/__init__.py +1 -0
- audio_engine/examples/basic_stt_llm_tts.py +200 -0
- audio_engine/examples/geneface_animation.py +99 -0
- audio_engine/examples/personaplex_pipeline.py +116 -0
- audio_engine/examples/websocket_server.py +86 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +38 -0
- audio_engine/llm/base.py +108 -0
- audio_engine/llm/groq.py +210 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/scripts/debug_pipeline.py +79 -0
- audio_engine/scripts/debug_tts.py +162 -0
- audio_engine/scripts/test_cartesia_connect.py +57 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +341 -0
- audio_engine/tests/__init__.py +1 -0
- audio_engine/tests/test_personaplex/__init__.py +1 -0
- audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
- audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
- audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
- audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
- audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
- audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
- audio_engine/tts/__init__.py +37 -0
- audio_engine/tts/base.py +155 -0
- audio_engine/tts/cartesia.py +392 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +220 -0
- atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Tests for PersonaPlexMessage (Step 2)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from pipelines.personaplex import PersonaPlexMessage, MessageType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TestPersonaPlexMessage:
    """Encode/decode checks for PersonaPlexMessage byte framing."""

    def test_encode_audio(self):
        """AUDIO encodes to a leading 0x01 byte followed by the raw payload."""
        payload = b"fake_opus_audio"

        wire = PersonaPlexMessage(type=MessageType.AUDIO, data=payload).encode()

        assert wire[0] == 0x01
        assert wire[1:] == payload

    def test_encode_text(self):
        """TEXT encodes to a leading 0x02 byte followed by the UTF-8 bytes."""
        sentence = "Hello, world!"

        wire = PersonaPlexMessage(type=MessageType.TEXT, data=sentence).encode()

        assert wire[0] == 0x02
        assert wire[1:] == sentence.encode("utf-8")

    def test_decode_audio(self):
        """A frame tagged 0x01 decodes to an AUDIO message with the same bytes."""
        payload = b"fake_opus_audio"
        frame = bytes([0x01]) + payload

        decoded = PersonaPlexMessage.decode(frame)

        assert decoded.type == MessageType.AUDIO
        assert decoded.data == payload

    def test_decode_text(self):
        """A frame tagged 0x02 decodes to a TEXT message carrying a str."""
        sentence = "Hello, assistant!"
        frame = bytes([0x02]) + sentence.encode("utf-8")

        decoded = PersonaPlexMessage.decode(frame)

        assert decoded.type == MessageType.TEXT
        assert decoded.data == sentence

    def test_decode_error(self):
        """A frame tagged 0x05 decodes to an ERROR message; data is compared as bytes."""
        detail = b"Connection timeout"
        frame = bytes([0x05]) + detail

        decoded = PersonaPlexMessage.decode(frame)

        assert decoded.type == MessageType.ERROR
        assert decoded.data == detail

    def test_roundtrip_audio(self):
        """encode() followed by decode() preserves an AUDIO message."""
        raw_bytes = b"\x00\x01\x02\x03\x04\x05"
        sent = PersonaPlexMessage(type=MessageType.AUDIO, data=raw_bytes)

        received = PersonaPlexMessage.decode(sent.encode())

        assert received.type == sent.type
        assert received.data == raw_bytes

    def test_roundtrip_text(self):
        """encode() followed by decode() preserves non-ASCII TEXT content."""
        sentence = "This is a test message with émojis 🎉"
        sent = PersonaPlexMessage(type=MessageType.TEXT, data=sentence)

        received = PersonaPlexMessage.decode(sent.encode())

        assert received.type == sent.type
        assert received.data == sentence

    def test_decode_too_short(self):
        """Decoding an empty byte string raises ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexMessage.decode(b"")
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""Tests for PersonaPlexPipeline lifecycle and orchestration (Step 6-7)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from unittest.mock import AsyncMock
|
|
5
|
+
|
|
6
|
+
from pipelines.personaplex import (
|
|
7
|
+
PersonaPlexPipeline,
|
|
8
|
+
PersonaPlexConfig,
|
|
9
|
+
SessionData,
|
|
10
|
+
AudioChunk,
|
|
11
|
+
TextChunk,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestPersonaPlexPipelineInit:
    """Construction-time checks for PersonaPlexPipeline."""

    def test_pipeline_init_with_defaults(self):
        """A pipeline built without arguments has its core attributes set."""
        pipe = PersonaPlexPipeline()

        for attribute in ("session_id", "session_data", "system_prompt"):
            assert getattr(pipe, attribute) is not None

    def test_pipeline_init_with_custom_config(self):
        """A config passed to the constructor is reachable via pipeline.config."""
        custom = PersonaPlexConfig(voice_prompt="NATM0.pt")

        pipe = PersonaPlexPipeline(config=custom)

        assert pipe.config.voice_prompt == "NATM0.pt"

    def test_pipeline_init_creates_session_data(self):
        """The constructor builds a SessionData carrying the given system prompt."""
        pipe = PersonaPlexPipeline(system_prompt="Custom prompt")

        assert isinstance(pipe.session_data, SessionData)
        assert pipe.session_data.system_prompt == "Custom prompt"

    def test_pipeline_is_not_running_initially(self):
        """The _is_running flag is falsy right after construction."""
        pipe = PersonaPlexPipeline()

        assert not pipe._is_running
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@pytest.mark.asyncio
class TestPersonaPlexPipelineLifecycle:
    """Test pipeline start/stop lifecycle.

    Every test replaces ``pipeline.client`` with an ``AsyncMock`` before
    calling into the pipeline.
    """

    @staticmethod
    def _exhausted_stream_messages():
        """Build the ``stream_messages`` stand-in shared by the start tests.

        Returns:
            An ``AsyncMock`` whose return value is a mock configured with an
            ``__aiter__`` mock; that in turn returns a mock whose
            ``__anext__`` raises ``StopAsyncIteration``.

        NOTE(review): how the pipeline consumes ``stream_messages`` is not
        visible from this file; this helper only reproduces the nested mock
        the tests previously built inline.
        """
        return AsyncMock(
            return_value=AsyncMock(
                __aiter__=AsyncMock(
                    return_value=AsyncMock(
                        __anext__=AsyncMock(side_effect=StopAsyncIteration)
                    )
                )
            )
        )

    async def test_pipeline_start_connects_client(self):
        """Does start() connect the client?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = AsyncMock()

        await pipeline.start()

        assert pipeline._is_running
        pipeline.client.connect.assert_called_once()

    async def test_pipeline_start_creates_receive_task(self):
        """Does start() create a receive task?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = AsyncMock()
        pipeline.client.stream_messages = self._exhausted_stream_messages()

        await pipeline.start()

        assert pipeline._receive_task is not None

    async def test_pipeline_stop_disconnects_client(self):
        """Does stop() disconnect the client?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = AsyncMock()
        pipeline._is_running = True

        # The return value of stop() is not checked by this test.
        await pipeline.stop()

        pipeline.client.disconnect.assert_called_once()
        assert not pipeline._is_running

    async def test_pipeline_context_manager(self):
        """Does context manager handle lifecycle?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = AsyncMock()
        pipeline.client.stream_messages = self._exhausted_stream_messages()

        async with pipeline:
            assert pipeline._is_running

        # Checked after the block exits, i.e. once __aexit__ has run.
        pipeline.client.disconnect.assert_called_once()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.mark.asyncio
class TestPersonaPlexPipelineSendAudio:
    """Checks on PersonaPlexPipeline.send_audio in running and stopped states."""

    async def test_pipeline_send_audio_when_running(self):
        """With _is_running set, the bytes are passed to client.send_audio once."""
        pipe = PersonaPlexPipeline()
        pipe.client = AsyncMock()
        pipe._is_running = True
        payload = b"opus_audio_data"

        await pipe.send_audio(payload)

        pipe.client.send_audio.assert_called_once_with(payload)

    async def test_pipeline_send_audio_fails_when_stopped(self):
        """With _is_running cleared, send_audio raises RuntimeError."""
        pipe = PersonaPlexPipeline()
        pipe._is_running = False

        with pytest.raises(RuntimeError):
            await pipe.send_audio(b"audio")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@pytest.mark.asyncio
class TestPersonaPlexPipelineStreaming:
    """Test streaming interface."""

    async def test_pipeline_stream_requires_running(self):
        """Does stream() require pipeline to be running?"""
        pipeline = PersonaPlexPipeline()
        pipeline._is_running = False

        with pytest.raises(RuntimeError):
            async for _ in pipeline.stream():
                pass

    async def test_pipeline_stream_yields_tuples(self):
        """Does stream() yield (audio, text) tuples?

        One audio chunk and one text chunk are queued up front, and the test
        stops consuming after the first yielded pair.
        """
        import asyncio

        pipeline = PersonaPlexPipeline()
        pipeline._is_running = True

        # Pre-populate queues
        audio_chunk = AudioChunk(data=b"audio", sample_rate=48000)
        text_chunk = TextChunk(text="Hello")

        await pipeline._audio_queue.put(audio_chunk)
        await pipeline._text_queue.put(text_chunk)

        # Clears the running flag shortly after start; presumably this lets
        # stream() finish if it never yields (stream() is not visible here).
        async def stop_soon():
            await asyncio.sleep(0.01)
            pipeline._is_running = False

        # Hold a reference to the task: the event loop keeps only weak
        # references, so an unreferenced task may be garbage-collected
        # before it finishes.
        stopper = asyncio.create_task(stop_soon())

        results = []
        try:
            async for audio, text in pipeline.stream():
                results.append((audio, text))
                break  # one tuple is enough for this test
        finally:
            # Do not leave the helper task pending when the test returns.
            stopper.cancel()
            await asyncio.gather(stopper, return_exceptions=True)

        # Should get at least one result
        assert len(results) >= 1
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class TestPersonaPlexPipelineTranscript:
    """Test transcript management.

    Only ``test_pipeline_stop_can_save_transcript`` is a coroutine, so the
    ``asyncio`` marker sits on that method alone; a class-level marker would
    also tag the two synchronous tests.
    """

    def test_pipeline_session_data_stores_messages(self):
        """Does session_data store conversation messages?"""
        pipeline = PersonaPlexPipeline()

        pipeline.session_data.add_message("user", "Hello")
        pipeline.session_data.add_message("assistant", "Hi there!")

        assert len(pipeline.session_data.messages) == 2
        assert pipeline.session_data.messages[0].role == "user"
        assert pipeline.session_data.messages[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_pipeline_stop_can_save_transcript(self):
        """Can pipeline save transcript when configured?"""
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmpdir:
            config = PersonaPlexConfig(transcript_path=tmpdir, save_transcripts=True)
            pipeline = PersonaPlexPipeline(config=config)
            pipeline.client = AsyncMock()
            pipeline._is_running = True  # Pretend running so stop works

            # Add a message to transcript
            pipeline.session_data.add_message("user", "Test message")

            # Stop (should save because save_transcripts=True)
            result = await pipeline.stop()

            # NOTE(review): the file check only runs when stop() returns a
            # non-None value, so a None result passes without verifying that
            # anything was written — confirm this is intended.
            if result is not None:
                transcript_files = list(Path(tmpdir).glob("*.json"))
                assert len(transcript_files) > 0

    def test_pipeline_session_data_to_dict(self):
        """Can we convert session data to dict?"""
        pipeline = PersonaPlexPipeline(system_prompt="Test")
        pipeline.session_data.add_message("user", "Hello")

        data = pipeline.session_data.to_dict()

        assert "session_id" in data
        assert "timestamp" in data
        assert "system_prompt" in data
        assert "messages" in data
        assert len(data["messages"]) == 1
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Tests for session management utilities (Step 4)."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, UTC
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
from pipelines.personaplex import (
|
|
7
|
+
generate_session_id,
|
|
8
|
+
get_timestamp_iso,
|
|
9
|
+
list_transcripts,
|
|
10
|
+
format_transcript_for_display,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestSessionUtilities:
    """Checks for generate_session_id and get_timestamp_iso."""

    def test_generate_session_id_returns_string(self):
        """The generated id is a non-empty str."""
        sid = generate_session_id()

        assert isinstance(sid, str)
        assert len(sid) > 0

    def test_generate_session_id_is_uuid4(self):
        """The generated id parses with uuid.UUID (the version is not inspected)."""
        # uuid.UUID raises ValueError on malformed input, which fails the test.
        uuid.UUID(generate_session_id())

    def test_generate_session_id_unique(self):
        """Two back-to-back calls return different ids."""
        first = generate_session_id()
        second = generate_session_id()

        assert first != second

    def test_get_timestamp_iso_returns_string(self):
        """The timestamp is a non-empty str."""
        stamp = get_timestamp_iso()

        assert isinstance(stamp, str)
        assert len(stamp) > 0

    def test_get_timestamp_iso_ends_with_z(self):
        """The timestamp carries a trailing 'Z'."""
        stamp = get_timestamp_iso()

        assert stamp.endswith("Z")

    def test_get_timestamp_iso_is_valid_iso8601(self):
        """Everything before the final character parses with datetime.fromisoformat."""
        stamp = get_timestamp_iso()

        # Drop the last character (the Z suffix) before parsing.
        without_suffix = stamp[:-1]
        moment = datetime.fromisoformat(without_suffix)

        assert isinstance(moment, datetime)

    def test_get_timestamp_iso_close_to_now(self):
        """The timestamp, read as UTC, lies between two now(UTC) samples taken around the call."""
        lower = datetime.now(UTC)
        stamp = get_timestamp_iso()
        upper = datetime.now(UTC)

        # Drop the last character (the Z suffix) and attach UTC so the value
        # can be compared with the aware bounds.
        without_suffix = stamp[:-1]
        moment = datetime.fromisoformat(without_suffix).replace(tzinfo=UTC)

        assert lower <= moment <= upper
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TestTranscriptUtilities:
    """Test transcript management utilities.

    Covers ``format_transcript_for_display`` and ``list_transcripts``.
    Project types are imported inside each test, as in the rest of this
    module.
    """

    def test_format_transcript_for_display_empty(self):
        """Can we format an empty transcript?"""
        from pipelines.personaplex import SessionData

        session = SessionData(
            session_id="test-format",
            timestamp="2025-01-01T00:00:00Z",
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=[],
        )

        formatted = format_transcript_for_display(session)

        assert isinstance(formatted, str)

    def test_format_transcript_for_display_with_messages(self):
        """Does formatting show user/assistant messages?"""
        from pipelines.personaplex import SessionData, TranscriptMessage

        msg1 = TranscriptMessage(
            role="user",
            text="Hello",
            timestamp="2025-01-01T00:00:01Z",
        )
        msg2 = TranscriptMessage(
            role="assistant",
            text="Hi there!",
            timestamp="2025-01-01T00:00:02Z",
        )
        session = SessionData(
            session_id="test-format-msg",
            timestamp="2025-01-01T00:00:00Z",
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=[msg1, msg2],
        )

        formatted = format_transcript_for_display(session)

        # Should contain message text
        assert "Hello" in formatted
        assert "Hi there!" in formatted

    def test_format_transcript_for_display_labels_speakers(self):
        """Does formatting include speaker labels?"""
        from pipelines.personaplex import SessionData, TranscriptMessage

        msg = TranscriptMessage(
            role="user",
            text="Test message",
            timestamp="2025-01-01T00:00:01Z",
        )
        session = SessionData(
            session_id="test-labels",
            timestamp="2025-01-01T00:00:00Z",
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=[msg],
        )

        formatted = format_transcript_for_display(session)

        # Should show role/speaker. The case-insensitive check already
        # covers a capitalised "User", so no separate clause is needed.
        assert "user" in formatted.lower()

    def test_list_transcripts_returns_list(self):
        """Does list_transcripts return a list?"""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            transcripts = list_transcripts(tmpdir)

            assert isinstance(transcripts, list)

    def test_list_transcripts_empty_dir(self):
        """Does it return empty list for empty directory?"""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            transcripts = list_transcripts(tmpdir)

            assert len(transcripts) == 0

    def test_list_transcripts_finds_json_files(self):
        """Does it find transcript JSON files?"""
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            from pipelines.personaplex import save_transcript, SessionData

            # Create a transcript
            session = SessionData(
                session_id="list-test",
                timestamp="2025-01-01T00:00:00Z",
                system_prompt="You are helpful",
                voice_prompt="NATF0.pt",
                messages=[],
            )
            save_transcript(session, tmpdir)

            # List should find it
            transcripts = list_transcripts(tmpdir)

            assert len(transcripts) >= 1
            # The session id is expected to appear in the string form of
            # at least one listed entry.
            assert any("list-test" in str(t) for t in transcripts)
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Tests for transcript persistence (Step 3)."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import tempfile
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from pipelines.personaplex import (
|
|
8
|
+
SessionData,
|
|
9
|
+
TranscriptMessage,
|
|
10
|
+
save_transcript,
|
|
11
|
+
load_transcript,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestTranscriptPersistence:
    """Checks for save_transcript/load_transcript and SessionData dict conversion."""

    @staticmethod
    def _session(session_id, messages):
        """Build a SessionData with the fixed timestamp and prompts used by these tests."""
        return SessionData(
            session_id=session_id,
            timestamp="2025-01-01T00:00:00Z",
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=messages,
        )

    def test_save_transcript_creates_file(self):
        """save_transcript returns a path that exists and has a .json suffix."""
        with tempfile.TemporaryDirectory() as workdir:
            record = self._session("test-123", [])

            written = save_transcript(record, Path(workdir))

            assert written.exists()
            assert written.suffix == ".json"

    def test_save_transcript_contains_correct_structure(self):
        """The saved JSON carries session_id, timestamp and the message list."""
        with tempfile.TemporaryDirectory() as workdir:
            greeting = TranscriptMessage(
                role="user",
                text="Hello",
                timestamp="2025-01-01T00:00:01Z",
            )
            record = self._session("test-456", [greeting])

            written = save_transcript(record, Path(workdir))

            with open(written, "r") as handle:
                payload = json.load(handle)

            assert payload["session_id"] == "test-456"
            assert payload["timestamp"] == "2025-01-01T00:00:00Z"
            assert len(payload["messages"]) == 1
            first = payload["messages"][0]
            assert first["role"] == "user"
            assert first["text"] == "Hello"

    def test_load_transcript_reads_file(self):
        """load_transcript rebuilds the session that save_transcript wrote."""
        with tempfile.TemporaryDirectory() as workdir:
            # Write a one-message session to disk first.
            reply = TranscriptMessage(
                role="assistant",
                text="Hi there!",
                timestamp="2025-01-01T00:00:02Z",
            )
            source = self._session("test-789", [reply])

            written = save_transcript(source, Path(workdir))

            # Then read it back.
            restored = load_transcript(written)

            assert restored.session_id == "test-789"
            assert restored.timestamp == "2025-01-01T00:00:00Z"
            assert len(restored.messages) == 1
            assert restored.messages[0].role == "assistant"
            assert restored.messages[0].text == "Hi there!"

    def test_transcript_roundtrip(self):
        """Saving then loading preserves every session and message field."""
        with tempfile.TemporaryDirectory() as workdir:
            question = TranscriptMessage(
                role="user",
                text="First message",
                timestamp="2025-01-01T00:00:01Z",
            )
            answer = TranscriptMessage(
                role="assistant",
                text="Response",
                timestamp="2025-01-01T00:00:02Z",
            )
            source = self._session("roundtrip-123", [question, answer])

            written = save_transcript(source, Path(workdir))
            restored = load_transcript(written)

            # Session-level fields.
            assert restored.session_id == source.session_id
            assert restored.timestamp == source.timestamp
            assert restored.system_prompt == source.system_prompt
            assert restored.voice_prompt == source.voice_prompt
            assert len(restored.messages) == len(source.messages)
            # Message-level fields, compared position by position.
            for position, expected in enumerate(source.messages):
                actual = restored.messages[position]
                assert actual.role == expected.role
                assert actual.text == expected.text
                assert actual.timestamp == expected.timestamp

    def test_transcript_add_message(self):
        """add_message appends one entry to an initially empty session."""
        record = self._session("add-test", [])

        assert len(record.messages) == 0

        record.add_message("user", "Test")

        assert len(record.messages) == 1
        assert record.messages[0].text == "Test"

    def test_transcript_to_dict(self):
        """to_dict returns a dict exposing session_id, timestamp and messages."""
        entry = TranscriptMessage(
            role="user",
            text="Convert me",
            timestamp="2025-01-01T00:00:01Z",
        )
        record = self._session("dict-test", [entry])

        payload = record.to_dict()

        assert isinstance(payload, dict)
        assert payload["session_id"] == "dict-test"
        assert payload["timestamp"] == "2025-01-01T00:00:00Z"
        assert len(payload["messages"]) == 1

    def test_transcript_from_dict(self):
        """from_dict builds a SessionData from a plain dict."""
        payload = {
            "session_id": "from-dict-test",
            "timestamp": "2025-01-01T00:00:00Z",
            "system_prompt": "You are helpful",
            "voice_prompt": "NATF0.pt",
            "messages": [
                {
                    "role": "user",
                    "text": "Recreate me",
                    "timestamp": "2025-01-01T00:00:01Z",
                }
            ],
        }

        record = SessionData.from_dict(payload)

        assert record.session_id == "from-dict-test"
        assert record.system_prompt == "You are helpful"
        assert len(record.messages) == 1
        assert record.messages[0].text == "Recreate me"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Tests for the audio engine."""
|