atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
  3. audio_engine/asr/__init__.py +45 -0
  4. audio_engine/asr/base.py +89 -0
  5. audio_engine/asr/cartesia.py +356 -0
  6. audio_engine/asr/deepgram.py +196 -0
  7. audio_engine/core/__init__.py +13 -0
  8. audio_engine/core/config.py +162 -0
  9. audio_engine/core/pipeline.py +282 -0
  10. audio_engine/core/types.py +87 -0
  11. audio_engine/examples/__init__.py +1 -0
  12. audio_engine/examples/basic_stt_llm_tts.py +200 -0
  13. audio_engine/examples/geneface_animation.py +99 -0
  14. audio_engine/examples/personaplex_pipeline.py +116 -0
  15. audio_engine/examples/websocket_server.py +86 -0
  16. audio_engine/integrations/__init__.py +5 -0
  17. audio_engine/integrations/geneface.py +297 -0
  18. audio_engine/llm/__init__.py +38 -0
  19. audio_engine/llm/base.py +108 -0
  20. audio_engine/llm/groq.py +210 -0
  21. audio_engine/pipelines/__init__.py +1 -0
  22. audio_engine/pipelines/personaplex/__init__.py +41 -0
  23. audio_engine/pipelines/personaplex/client.py +259 -0
  24. audio_engine/pipelines/personaplex/config.py +69 -0
  25. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  26. audio_engine/pipelines/personaplex/types.py +173 -0
  27. audio_engine/pipelines/personaplex/utils.py +192 -0
  28. audio_engine/scripts/debug_pipeline.py +79 -0
  29. audio_engine/scripts/debug_tts.py +162 -0
  30. audio_engine/scripts/test_cartesia_connect.py +57 -0
  31. audio_engine/streaming/__init__.py +5 -0
  32. audio_engine/streaming/websocket_server.py +341 -0
  33. audio_engine/tests/__init__.py +1 -0
  34. audio_engine/tests/test_personaplex/__init__.py +1 -0
  35. audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
  36. audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
  37. audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
  38. audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
  39. audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
  40. audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
  41. audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
  42. audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
  43. audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
  44. audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
  45. audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
  46. audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
  47. audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
  48. audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
  49. audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
  50. audio_engine/tts/__init__.py +37 -0
  51. audio_engine/tts/base.py +155 -0
  52. audio_engine/tts/cartesia.py +392 -0
  53. audio_engine/utils/__init__.py +15 -0
  54. audio_engine/utils/audio.py +220 -0
  55. atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
  56. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
  57. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,80 @@
1
+ """Tests for PersonaPlexMessage (Step 2)."""
2
+
3
+ import pytest
4
+ from pipelines.personaplex import PersonaPlexMessage, MessageType
5
+
6
+
7
class TestPersonaPlexMessage:
    """WebSocket frame encoding/decoding for PersonaPlexMessage.

    These tests pin a simple wire layout: one leading type byte followed by
    the payload (raw bytes for audio/error, UTF-8 bytes for text).
    """

    # Leading type bytes these tests expect for each message kind.
    AUDIO_TAG = 0x01
    TEXT_TAG = 0x02
    ERROR_TAG = 0x05

    def test_encode_audio(self):
        """Audio encodes as the audio tag followed by the untouched payload."""
        payload = b"fake_opus_audio"

        frame = PersonaPlexMessage(type=MessageType.AUDIO, data=payload).encode()

        assert frame[0] == self.AUDIO_TAG
        assert frame[1:] == payload

    def test_encode_text(self):
        """Text encodes as the text tag followed by its UTF-8 bytes."""
        greeting = "Hello, world!"

        frame = PersonaPlexMessage(type=MessageType.TEXT, data=greeting).encode()

        assert frame[0] == self.TEXT_TAG
        assert frame[1:] == greeting.encode("utf-8")

    def test_decode_audio(self):
        """An audio-tagged frame decodes to an AUDIO message with raw bytes."""
        payload = b"fake_opus_audio"

        decoded = PersonaPlexMessage.decode(bytes([self.AUDIO_TAG]) + payload)

        assert decoded.type == MessageType.AUDIO
        assert decoded.data == payload

    def test_decode_text(self):
        """A text-tagged frame decodes to a TEXT message holding a str."""
        greeting = "Hello, assistant!"

        decoded = PersonaPlexMessage.decode(
            bytes([self.TEXT_TAG]) + greeting.encode("utf-8")
        )

        assert decoded.type == MessageType.TEXT
        assert decoded.data == greeting

    def test_decode_error(self):
        """An error-tagged frame decodes to an ERROR message keeping raw bytes."""
        reason = b"Connection timeout"

        decoded = PersonaPlexMessage.decode(bytes([self.ERROR_TAG]) + reason)

        assert decoded.type == MessageType.ERROR
        assert decoded.data == reason

    @staticmethod
    def _roundtrip(kind, payload):
        """Encode a *kind*/*payload* message, decode it, return (sent, received)."""
        sent = PersonaPlexMessage(type=kind, data=payload)
        return sent, PersonaPlexMessage.decode(sent.encode())

    def test_roundtrip_audio(self):
        """Arbitrary bytes (including NUL) survive encode then decode."""
        payload = b"\x00\x01\x02\x03\x04\x05"

        sent, received = self._roundtrip(MessageType.AUDIO, payload)

        assert received.type == sent.type
        assert received.data == payload

    def test_roundtrip_text(self):
        """Non-ASCII text survives encode then decode."""
        payload = "This is a test message with émojis 🎉"

        sent, received = self._roundtrip(MessageType.TEXT, payload)

        assert received.type == sent.type
        assert received.data == payload

    def test_decode_too_short(self):
        """Decoding an empty frame is rejected with ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexMessage.decode(b"")
@@ -0,0 +1,226 @@
1
+ """Tests for PersonaPlexPipeline lifecycle and orchestration (Step 6-7)."""
2
+
3
+ import pytest
4
+ from unittest.mock import AsyncMock
5
+
6
+ from pipelines.personaplex import (
7
+ PersonaPlexPipeline,
8
+ PersonaPlexConfig,
9
+ SessionData,
10
+ AudioChunk,
11
+ TextChunk,
12
+ )
13
+
14
+
15
class TestPersonaPlexPipelineInit:
    """Construction-time state of PersonaPlexPipeline."""

    def test_pipeline_init_with_defaults(self):
        """A default pipeline exposes a session id, session data and a prompt."""
        pipe = PersonaPlexPipeline()

        for attr in ("session_id", "session_data", "system_prompt"):
            assert getattr(pipe, attr) is not None

    def test_pipeline_init_with_custom_config(self):
        """A supplied config is kept, including its voice prompt."""
        custom = PersonaPlexConfig(voice_prompt="NATM0.pt")

        pipe = PersonaPlexPipeline(config=custom)

        assert pipe.config.voice_prompt == "NATM0.pt"

    def test_pipeline_init_creates_session_data(self):
        """The constructor builds SessionData carrying the given system prompt."""
        session = PersonaPlexPipeline(system_prompt="Custom prompt").session_data

        assert isinstance(session, SessionData)
        assert session.system_prompt == "Custom prompt"

    def test_pipeline_is_not_running_initially(self):
        """A freshly built pipeline has not been started."""
        assert not PersonaPlexPipeline()._is_running
45
+
46
+
47
@pytest.mark.asyncio
class TestPersonaPlexPipelineLifecycle:
    """Test pipeline start/stop lifecycle."""

    @staticmethod
    def _mock_client():
        """Build a mock client whose ``stream_messages()`` yields nothing.

        ``stream_messages`` is replaced with a plain (synchronous) mock that
        returns an empty async generator, so ``async for`` over it ends
        immediately.

        The previous ``AsyncMock(return_value=...)`` had two problems:
        - calling it returned a coroutine, which is not async-iterable;
        - its nested ``__aiter__`` was itself an ``AsyncMock``, also returning
          a coroutine.

        NOTE(review): assumes the pipeline consumes
        ``client.stream_messages()`` with ``async for`` — confirm against
        pipeline.py.
        """
        from unittest.mock import MagicMock

        async def _empty_stream():
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any item.
            return
            yield

        client = AsyncMock()
        # side_effect hands every call a fresh generator.
        client.stream_messages = MagicMock(
            side_effect=lambda *args, **kwargs: _empty_stream()
        )
        return client

    async def test_pipeline_start_connects_client(self):
        """Does start() connect the client?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = self._mock_client()

        await pipeline.start()
        try:
            assert pipeline._is_running
            pipeline.client.connect.assert_called_once()
        finally:
            # Tear the pipeline down so nothing started by start() outlives
            # the test's event loop.
            await pipeline.stop()

    async def test_pipeline_start_creates_receive_task(self):
        """Does start() create a receive task?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = self._mock_client()

        await pipeline.start()
        try:
            assert pipeline._receive_task is not None
        finally:
            await pipeline.stop()

    async def test_pipeline_stop_disconnects_client(self):
        """Does stop() disconnect the client?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = AsyncMock()
        pipeline._is_running = True

        await pipeline.stop()

        pipeline.client.disconnect.assert_called_once()
        assert not pipeline._is_running

    async def test_pipeline_context_manager(self):
        """Does the async context manager start and then stop the pipeline?"""
        pipeline = PersonaPlexPipeline()
        pipeline.client = self._mock_client()

        async with pipeline:
            assert pipeline._is_running

        pipeline.client.disconnect.assert_called_once()
108
+
109
+
110
@pytest.mark.asyncio
class TestPersonaPlexPipelineSendAudio:
    """send_audio() forwards to the client only while the pipeline runs."""

    @staticmethod
    def _pipeline(running):
        """Return a fresh pipeline with its running flag forced to *running*."""
        pipe = PersonaPlexPipeline()
        pipe._is_running = running
        return pipe

    async def test_pipeline_send_audio_when_running(self):
        """Audio sent to a running pipeline reaches client.send_audio unchanged."""
        pipe = self._pipeline(running=True)
        pipe.client = AsyncMock()
        chunk = b"opus_audio_data"

        await pipe.send_audio(chunk)

        pipe.client.send_audio.assert_called_once_with(chunk)

    async def test_pipeline_send_audio_fails_when_stopped(self):
        """Sending on a stopped pipeline raises RuntimeError."""
        pipe = self._pipeline(running=False)

        with pytest.raises(RuntimeError):
            await pipe.send_audio(b"audio")
132
+
133
+
134
@pytest.mark.asyncio
class TestPersonaPlexPipelineStreaming:
    """Test streaming interface."""

    async def test_pipeline_stream_requires_running(self):
        """Does stream() require pipeline to be running?"""
        pipeline = PersonaPlexPipeline()
        pipeline._is_running = False

        with pytest.raises(RuntimeError):
            async for _ in pipeline.stream():
                pass

    async def test_pipeline_stream_yields_tuples(self):
        """Does stream() yield (audio, text) tuples?"""
        import asyncio
        import contextlib

        pipeline = PersonaPlexPipeline()
        pipeline._is_running = True

        # Pre-populate the internal queues so stream() has data to yield.
        await pipeline._audio_queue.put(AudioChunk(data=b"audio", sample_rate=48000))
        await pipeline._text_queue.put(TextChunk(text="Hello"))

        async def stop_soon():
            # Flip the running flag shortly after, so stream() can end even if
            # it never yields (presumably it checks _is_running — confirm).
            await asyncio.sleep(0.01)
            pipeline._is_running = False

        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        stopper = asyncio.create_task(stop_soon())
        results = []
        try:
            async for audio, text in pipeline.stream():
                results.append((audio, text))
                break  # one item is enough for this test
        finally:
            # Don't leave the helper task pending after breaking out early.
            stopper.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await stopper

        # Should get at least one result
        assert len(results) >= 1
176
+
177
+
178
class TestPersonaPlexPipelineTranscript:
    """Test transcript management.

    Only the coroutine test carries ``@pytest.mark.asyncio``. Marking the
    whole class also tagged the two synchronous tests, which recent
    pytest-asyncio versions warn about; that is an error under ``-W error``.
    """

    def test_pipeline_session_data_stores_messages(self):
        """Does session_data store conversation messages in order?"""
        pipeline = PersonaPlexPipeline()

        pipeline.session_data.add_message("user", "Hello")
        pipeline.session_data.add_message("assistant", "Hi there!")

        assert len(pipeline.session_data.messages) == 2
        assert pipeline.session_data.messages[0].role == "user"
        assert pipeline.session_data.messages[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_pipeline_stop_can_save_transcript(self):
        """Does stop() write a transcript file when save_transcripts=True?"""
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmpdir:
            config = PersonaPlexConfig(transcript_path=tmpdir, save_transcripts=True)
            pipeline = PersonaPlexPipeline(config=config)
            pipeline.client = AsyncMock()
            pipeline._is_running = True  # Pretend running so stop works

            # Add a message to transcript
            pipeline.session_data.add_message("user", "Test message")

            # Stop (should save because save_transcripts=True)
            await pipeline.stop()

            # Assert unconditionally: the previous ``if result is not None``
            # guard let this test pass without checking anything.
            transcript_files = list(Path(tmpdir).glob("*.json"))
            assert transcript_files, "stop() did not save a transcript"

    def test_pipeline_session_data_to_dict(self):
        """Can we convert session data to dict?"""
        pipeline = PersonaPlexPipeline(system_prompt="Test")
        pipeline.session_data.add_message("user", "Hello")

        data = pipeline.session_data.to_dict()

        assert "session_id" in data
        assert "timestamp" in data
        assert "system_prompt" in data
        assert "messages" in data
        assert len(data["messages"]) == 1
@@ -0,0 +1,184 @@
1
+ """Tests for session management utilities (Step 4)."""
2
+
3
+ from datetime import datetime, UTC
4
+ import uuid
5
+
6
+ from pipelines.personaplex import (
7
+ generate_session_id,
8
+ get_timestamp_iso,
9
+ list_transcripts,
10
+ format_transcript_for_display,
11
+ )
12
+
13
+
14
class TestSessionUtilities:
    """Test session ID and timestamp generation."""

    def test_generate_session_id_returns_string(self):
        """Does generate_session_id return a non-empty string?"""
        session_id = generate_session_id()

        assert isinstance(session_id, str)
        assert len(session_id) > 0

    def test_generate_session_id_is_uuid4(self):
        """Does it generate a valid version-4 UUID string?"""
        session_id = generate_session_id()

        # uuid.UUID raises ValueError on malformed input; the version check
        # makes the test match its name (it previously accepted any UUID).
        assert uuid.UUID(session_id).version == 4

    def test_generate_session_id_unique(self):
        """Do consecutive calls generate different IDs?"""
        id1 = generate_session_id()
        id2 = generate_session_id()

        assert id1 != id2

    def test_get_timestamp_iso_returns_string(self):
        """Does get_timestamp_iso return a non-empty string?"""
        timestamp = get_timestamp_iso()

        assert isinstance(timestamp, str)
        assert len(timestamp) > 0

    def test_get_timestamp_iso_ends_with_z(self):
        """Does timestamp end with Z suffix?"""
        timestamp = get_timestamp_iso()

        assert timestamp.endswith("Z")

    def test_get_timestamp_iso_is_valid_iso8601(self):
        """Is the timestamp in ISO 8601 format?"""
        timestamp = get_timestamp_iso()

        # Remove the Z suffix and parse the remainder.
        parsed = datetime.fromisoformat(timestamp[:-1])

        assert isinstance(parsed, datetime)

    def test_get_timestamp_iso_close_to_now(self):
        """Is generated timestamp within a second of the current time?"""
        from datetime import timedelta

        # Allow a second of slack on each side. If get_timestamp_iso()
        # truncates sub-second precision, the parsed value can fall before
        # ``before``, so the old strict check was flaky.
        tolerance = timedelta(seconds=1)
        before = datetime.now(UTC)
        timestamp = get_timestamp_iso()
        after = datetime.now(UTC)

        # Parse timestamp (remove Z and attach UTC so it compares with aware datetimes).
        parsed = datetime.fromisoformat(timestamp[:-1]).replace(tzinfo=UTC)

        assert before - tolerance <= parsed <= after + tolerance
73
+
74
+
75
class TestTranscriptUtilities:
    """Transcript display formatting and on-disk discovery."""

    @staticmethod
    def _session(session_id, messages):
        """Build SessionData with the fixed timestamp/prompt/voice used here."""
        from pipelines.personaplex import SessionData

        return SessionData(
            session_id=session_id,
            timestamp="2025-01-01T00:00:00Z",
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=messages,
        )

    @staticmethod
    def _message(role, text, timestamp):
        """Build a single TranscriptMessage."""
        from pipelines.personaplex import TranscriptMessage

        return TranscriptMessage(role=role, text=text, timestamp=timestamp)

    def test_format_transcript_for_display_empty(self):
        """Formatting a transcript with no messages still yields a string."""
        rendered = format_transcript_for_display(self._session("test-format", []))

        assert isinstance(rendered, str)

    def test_format_transcript_for_display_with_messages(self):
        """Both user and assistant message bodies appear in the output."""
        conversation = [
            self._message("user", "Hello", "2025-01-01T00:00:01Z"),
            self._message("assistant", "Hi there!", "2025-01-01T00:00:02Z"),
        ]

        rendered = format_transcript_for_display(
            self._session("test-format-msg", conversation)
        )

        for expected in ("Hello", "Hi there!"):
            assert expected in rendered

    def test_format_transcript_for_display_labels_speakers(self):
        """The speaker's role label appears in the output (any case)."""
        note = self._message("user", "Test message", "2025-01-01T00:00:01Z")

        rendered = format_transcript_for_display(self._session("test-labels", [note]))

        # Case-insensitive, so a capitalised "User" label also satisfies it.
        assert "user" in rendered.lower()

    def test_list_transcripts_returns_list(self):
        """list_transcripts() returns a list."""
        import tempfile

        with tempfile.TemporaryDirectory() as folder:
            found = list_transcripts(folder)

            assert isinstance(found, list)

    def test_list_transcripts_empty_dir(self):
        """An empty directory yields no transcripts."""
        import tempfile

        with tempfile.TemporaryDirectory() as folder:
            found = list_transcripts(folder)

            assert len(found) == 0

    def test_list_transcripts_finds_json_files(self):
        """A transcript written by save_transcript() is discovered."""
        import tempfile

        from pipelines.personaplex import save_transcript

        with tempfile.TemporaryDirectory() as folder:
            save_transcript(self._session("list-test", []), folder)

            found = list_transcripts(folder)

            assert len(found) >= 1
            assert any("list-test" in str(entry) for entry in found)
@@ -0,0 +1,184 @@
1
+ """Tests for transcript persistence (Step 3)."""
2
+
3
+ import json
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ from pipelines.personaplex import (
8
+ SessionData,
9
+ TranscriptMessage,
10
+ save_transcript,
11
+ load_transcript,
12
+ )
13
+
14
+
15
class TestTranscriptPersistence:
    """Saving SessionData to JSON and loading it back."""

    # Session-level timestamp shared by every fixture below.
    SESSION_TS = "2025-01-01T00:00:00Z"

    @classmethod
    def _session(cls, session_id, messages):
        """Return SessionData with the shared timestamp/prompt/voice values."""
        return SessionData(
            session_id=session_id,
            timestamp=cls.SESSION_TS,
            system_prompt="You are helpful",
            voice_prompt="NATF0.pt",
            messages=messages,
        )

    def test_save_transcript_creates_file(self):
        """save_transcript() writes a .json file and returns its path."""
        with tempfile.TemporaryDirectory() as folder:
            path = save_transcript(self._session("test-123", []), Path(folder))

            assert path.exists()
            assert path.suffix == ".json"

    def test_save_transcript_contains_correct_structure(self):
        """The JSON on disk carries session_id, timestamp and the messages."""
        with tempfile.TemporaryDirectory() as folder:
            greeting = TranscriptMessage(
                role="user", text="Hello", timestamp="2025-01-01T00:00:01Z"
            )
            path = save_transcript(self._session("test-456", [greeting]), Path(folder))

            with open(path, "r") as handle:
                payload = json.load(handle)

            assert payload["session_id"] == "test-456"
            assert payload["timestamp"] == "2025-01-01T00:00:00Z"
            assert len(payload["messages"]) == 1
            first = payload["messages"][0]
            assert first["role"] == "user"
            assert first["text"] == "Hello"

    def test_load_transcript_reads_file(self):
        """load_transcript() rebuilds the SessionData that was saved."""
        with tempfile.TemporaryDirectory() as folder:
            reply = TranscriptMessage(
                role="assistant", text="Hi there!", timestamp="2025-01-01T00:00:02Z"
            )
            path = save_transcript(self._session("test-789", [reply]), Path(folder))

            restored = load_transcript(path)

            assert restored.session_id == "test-789"
            assert restored.timestamp == "2025-01-01T00:00:00Z"
            assert len(restored.messages) == 1
            assert restored.messages[0].role == "assistant"
            assert restored.messages[0].text == "Hi there!"

    def test_transcript_roundtrip(self):
        """save then load reproduces every field of the original session."""
        with tempfile.TemporaryDirectory() as folder:
            original = self._session(
                "roundtrip-123",
                [
                    TranscriptMessage(
                        role="user",
                        text="First message",
                        timestamp="2025-01-01T00:00:01Z",
                    ),
                    TranscriptMessage(
                        role="assistant",
                        text="Response",
                        timestamp="2025-01-01T00:00:02Z",
                    ),
                ],
            )

            restored = load_transcript(save_transcript(original, Path(folder)))

            for attr in ("session_id", "timestamp", "system_prompt", "voice_prompt"):
                assert getattr(restored, attr) == getattr(original, attr)
            assert len(restored.messages) == len(original.messages)
            for got, want in zip(restored.messages, original.messages):
                assert got.role == want.role
                assert got.text == want.text
                assert got.timestamp == want.timestamp

    def test_transcript_add_message(self):
        """add_message() appends to an initially empty transcript."""
        session = self._session("add-test", [])
        assert len(session.messages) == 0

        session.add_message("user", "Test")

        assert len(session.messages) == 1
        assert session.messages[0].text == "Test"

    def test_transcript_to_dict(self):
        """to_dict() exposes the session id, timestamp and messages."""
        only = TranscriptMessage(
            role="user", text="Convert me", timestamp="2025-01-01T00:00:01Z"
        )

        as_dict = self._session("dict-test", [only]).to_dict()

        assert isinstance(as_dict, dict)
        assert as_dict["session_id"] == "dict-test"
        assert as_dict["timestamp"] == "2025-01-01T00:00:00Z"
        assert len(as_dict["messages"]) == 1

    def test_transcript_from_dict(self):
        """from_dict() builds SessionData (and its messages) from plain data."""
        raw = {
            "session_id": "from-dict-test",
            "timestamp": "2025-01-01T00:00:00Z",
            "system_prompt": "You are helpful",
            "voice_prompt": "NATF0.pt",
            "messages": [
                {
                    "role": "user",
                    "text": "Recreate me",
                    "timestamp": "2025-01-01T00:00:01Z",
                }
            ],
        }

        session = SessionData.from_dict(raw)

        assert session.session_id == "from-dict-test"
        assert session.system_prompt == "You are helpful"
        assert len(session.messages) == 1
        assert session.messages[0].text == "Recreate me"
@@ -0,0 +1 @@
1
+ """Tests for the audio engine."""