rosabeats 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rosabeats/__init__.py +1 -1
- rosabeats/__main__.py +59 -0
- rosabeats/beatrecipe_processor.py +63 -46
- rosabeats/beatswitch.py +29 -13
- rosabeats/downbeat.py +207 -0
- rosabeats/rosabeats.py +575 -543
- rosabeats/rosabeats_shell.py +391 -284
- rosabeats/segment_song.py +100 -31
- {rosabeats-0.1.3.dist-info → rosabeats-0.2.0.dist-info}/METADATA +8 -30
- rosabeats-0.2.0.dist-info/RECORD +21 -0
- rosabeats-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/conftest.py +131 -0
- tests/test_beatrecipe_processor.py +193 -0
- tests/test_downbeat.py +149 -0
- tests/test_rosabeats.py +234 -0
- tests/test_segment_song.py +120 -0
- tests/test_shell.py +305 -0
- docs/beatrecipe_docs.txt +0 -80
- rosabeats-0.1.3.dist-info/RECORD +0 -16
- rosabeats-0.1.3.dist-info/top_level.txt +0 -3
- scripts/reverse_beats_in_bars_rosa.py +0 -48
- scripts/shuffle_bars_rosa.py +0 -35
- scripts/shuffle_beats_rosa.py +0 -29
- {rosabeats-0.1.3.dist-info → rosabeats-0.2.0.dist-info}/WHEEL +0 -0
- {rosabeats-0.1.3.dist-info → rosabeats-0.2.0.dist-info}/entry_points.txt +0 -0
- {rosabeats-0.1.3.dist-info → rosabeats-0.2.0.dist-info}/licenses/LICENSE.md +0 -0
tests/test_downbeat.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Tests for downbeat detection module."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from rosabeats.downbeat import (
|
|
7
|
+
compute_beat_features,
|
|
8
|
+
score_offset,
|
|
9
|
+
detect_downbeat,
|
|
10
|
+
detect_downbeat_dbn,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestComputeBeatFeatures:
    """Tests for compute_beat_features function."""

    def test_empty_beat_times(self, mono_audio, sample_rate):
        """Should return empty dict for empty beat times."""
        features = compute_beat_features(mono_audio, sample_rate, np.array([]))
        assert features == {}

    def test_returns_expected_features(self, synthetic_audio_with_beats):
        """Should return dict with expected feature keys."""
        audio, sr = synthetic_audio_with_beats
        beat_times = np.arange(0, 10, 0.5)  # 120 BPM

        features = compute_beat_features(audio, sr, beat_times)

        assert 'onset_strength' in features
        assert 'low_freq_energy' in features
        assert 'mid_freq_energy' in features
        assert 'spectral_flux' in features
        assert 'low_mid_ratio' in features

    def test_feature_array_lengths(self, synthetic_audio_with_beats):
        """Feature arrays should have same length as beat_times."""
        audio, sr = synthetic_audio_with_beats
        beat_times = np.arange(0, 10, 0.5)
        n_beats = len(beat_times)

        features = compute_beat_features(audio, sr, beat_times)

        # Guard against a vacuous pass: if an empty dict came back, the loop
        # below would assert nothing. A non-empty beat grid must yield
        # features (test_returns_expected_features requires the keys).
        assert features, "expected features for a non-empty beat grid"
        for key, values in features.items():
            assert len(values) == n_beats, f"{key} has wrong length"

    def test_features_are_normalized(self, synthetic_audio_with_beats):
        """Features should be normalized to 0-1 range."""
        audio, sr = synthetic_audio_with_beats
        beat_times = np.arange(0, 10, 0.5)

        features = compute_beat_features(audio, sr, beat_times)

        # 'low_mid_ratio' is deliberately not range-checked here.
        for key in ['onset_strength', 'low_freq_energy', 'mid_freq_energy', 'spectral_flux']:
            assert features[key].min() >= 0, f"{key} min should be >= 0"
            assert features[key].max() <= 1, f"{key} max should be <= 1"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TestScoreOffset:
    """Unit tests for score_offset."""

    @staticmethod
    def _bar_pattern(accent, rest, bars):
        """Return a 4/4 feature curve: *accent* on beat one, *rest* elsewhere."""
        return np.array([accent, rest, rest, rest] * bars)

    def test_returns_float(self):
        """A single score comes back as a float."""
        two_bars = {
            'onset_strength': self._bar_pattern(1.0, 0.5, 2),
            'low_freq_energy': self._bar_pattern(1.0, 0.2, 2),
            'low_mid_ratio': self._bar_pattern(1.0, 0.3, 2),
            'spectral_flux': self._bar_pattern(0.8, 0.4, 2),
        }

        assert isinstance(score_offset(two_bars, offset=0, beats_per_bar=4), float)

    def test_correct_offset_scores_higher(self):
        """The offset aligned with the accented beats wins over every other."""
        # Beats 0, 4 and 8 carry the strong (downbeat-like) values.
        three_bars = {
            'onset_strength': self._bar_pattern(1.0, 0.3, 3),
            'low_freq_energy': self._bar_pattern(1.0, 0.1, 3),
            'low_mid_ratio': self._bar_pattern(1.0, 0.2, 3),
            'spectral_flux': self._bar_pattern(0.8, 0.3, 3),
        }

        scores = {
            candidate: score_offset(three_bars, offset=candidate, beats_per_bar=4)
            for candidate in range(4)
        }

        for candidate in (1, 2, 3):
            assert scores[0] > scores[candidate], (
                f"Offset 0 should score higher than offset {candidate}"
            )

    def test_too_few_beats(self):
        """Fewer beats than one bar yields a zero score."""
        short = {
            name: np.array([1.0, 0.5])
            for name in ('onset_strength', 'low_freq_energy', 'low_mid_ratio', 'spectral_flux')
        }

        assert score_offset(short, offset=0, beats_per_bar=4) == 0.0
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class TestDetectDownbeat:
    """Unit tests for detect_downbeat."""

    def test_returns_valid_offset(self, synthetic_audio_with_beats):
        """The detected offset is an integer in [0, beats_per_bar)."""
        signal, rate = synthetic_audio_with_beats
        beats = np.arange(0, 10, 0.5)
        meter = 4

        offset = detect_downbeat(signal, rate, beats, meter)

        assert isinstance(offset, (int, np.integer))
        assert 0 <= offset < meter

    def test_few_beats_returns_zero(self, mono_audio, sample_rate):
        """With fewer beats than a bar, the offset falls back to 0."""
        two_beats = np.array([0.0, 0.5])

        assert detect_downbeat(mono_audio, sample_rate, two_beats, beats_per_bar=4) == 0

    def test_detects_correct_downbeat_synthetic(self, synthetic_audio_with_beats):
        """Kicks placed on beats 0, 4, 8, ... are recognised as offset 0."""
        signal, rate = synthetic_audio_with_beats
        beats = np.arange(0, 10, 0.5)

        found = detect_downbeat(signal, rate, beats, beats_per_bar=4)

        assert found == 0, f"Expected offset 0 but got {found}"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class TestDetectDownbeatDbn:
    """Checks that detect_downbeat_dbn behaves as an alias."""

    def test_is_alias_for_detect_downbeat(self, synthetic_audio_with_beats):
        """Both entry points agree on the same input."""
        signal, rate = synthetic_audio_with_beats
        grid = np.arange(0, 10, 0.5)

        plain = detect_downbeat(signal, rate, grid, beats_per_bar=4)
        via_dbn = detect_downbeat_dbn(signal, rate, grid, beats_per_bar=4)

        assert plain == via_dbn
|
tests/test_rosabeats.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Tests for core rosabeats module."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import numpy as np
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
from rosabeats.rosabeats import rosabeats
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestRosabeatsInit:
    """Tests for rosabeats initialization."""

    def test_init_without_file(self):
        """Should initialize without a file."""
        r = rosabeats()
        assert r.sourcefile is None
        assert r.data is None
        assert r.sr is None

    def test_init_with_file(self, temp_audio_file):
        """Should initialize with a file path."""
        r = rosabeats(temp_audio_file)
        assert r.sourcefile is not None
        assert temp_audio_file in r.sourcefile

    def test_debug_mode(self):
        """Should set debug mode."""
        # The flag is read from the class itself, so constructing with
        # debug=True mutates class-level state. Restore it in ``finally`` so
        # a failing assertion cannot leak debug mode into later tests.
        previous = getattr(rosabeats, "debug", False)
        try:
            rosabeats(debug=True)
            assert rosabeats.debug is True
        finally:
            rosabeats.debug = previous
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TestRosabeatsSetfile:
    """Behaviour of rosabeats.setfile."""

    def test_setfile_sets_sourcefile(self, temp_audio_file):
        """After setfile the source path is stored as an absolute path."""
        rb = rosabeats()
        rb.setfile(temp_audio_file)

        stored = rb.sourcefile
        assert stored is not None
        assert os.path.isabs(stored)

    def test_setfile_sets_saved_features_path(self, temp_audio_file):
        """After setfile a .pkl path for cached features is recorded."""
        rb = rosabeats()
        rb.setfile(temp_audio_file)

        cache_path = rb.saved_features
        assert cache_path is not None
        assert cache_path.endswith(".pkl")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TestRosabeatsLoad:
    """Loading audio from disk."""

    def test_load_wav_file(self, temp_audio_file):
        """A WAV file populates data, sample rate and channel count."""
        loaded = rosabeats(temp_audio_file)
        loaded.load()

        for attribute in ("data", "sr", "channels"):
            assert getattr(loaded, attribute) is not None

    def test_load_creates_stereo_data(self, temp_audio_file):
        """Loaded data should carry a channel dimension."""
        loaded = rosabeats(temp_audio_file)
        loaded.load()

        # Expected layout is (channels, samples).
        # NOTE(review): ``ndim >= 1`` does not actually verify a channel axis;
        # confirm load()'s output shape and tighten this assertion.
        assert loaded.data.ndim >= 1
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class TestRosabeatsMixToMono:
    """Down-mixing to a single channel."""

    def test_mix_to_mono(self, temp_audio_file):
        """mix_to_mono leaves a one-dimensional signal in ``mono``."""
        rb = rosabeats(temp_audio_file)
        rb.load()
        rb.mix_to_mono()

        mono = rb.mono
        assert mono is not None
        assert mono.ndim == 1
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class TestRosabeatsTrackBeats:
    """Beat tracking on a real audio file."""

    @staticmethod
    def _tracked(path, downbeat):
        """Build a rosabeats for *path* and track 4/4 beats from *downbeat*."""
        instance = rosabeats(path)
        instance.track_beats(beatsper=4, downbeat=downbeat)
        return instance

    def test_track_beats_basic(self, temp_audio_file):
        """Tracking fills in timings, samples, slices and meter settings."""
        rb = self._tracked(temp_audio_file, 0)

        for attribute in ("beat_timings", "beat_samples", "beat_slices", "total_beats"):
            assert getattr(rb, attribute) is not None
        assert rb.total_beats > 0
        assert rb.beatsperbar == 4
        assert rb.downbeat == 0

    def test_track_beats_sets_total_bars(self, temp_audio_file):
        """Bar count equals whole bars after the downbeat."""
        rb = self._tracked(temp_audio_file, 0)

        assert rb.total_bars is not None
        assert rb.total_bars == (rb.total_beats - rb.downbeat) // rb.beatsperbar

    def test_track_beats_with_downbeat(self, temp_audio_file):
        """A non-zero downbeat is stored as given."""
        rb = self._tracked(temp_audio_file, 2)

        assert rb.downbeat == 2
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class TestRosabeatsBeatStartsBar:
    """Mapping beat indices to the bars they open."""

    @pytest.fixture
    def tracked_rosabeats(self, temp_audio_file):
        """A rosabeats instance with 4/4 beats tracked from beat 0."""
        instance = rosabeats(temp_audio_file)
        instance.track_beats(beatsper=4, downbeat=0)
        return instance

    def test_downbeat_starts_bar(self, tracked_rosabeats):
        """With downbeat 0, beat 0 opens bar 0."""
        assert tracked_rosabeats.beat_starts_bar(0) == 0

    def test_non_downbeat_returns_none(self, tracked_rosabeats):
        """Beats inside a bar do not start one."""
        for inner_beat in (1, 2):
            assert tracked_rosabeats.beat_starts_bar(inner_beat) is None

    def test_subsequent_downbeats(self, tracked_rosabeats):
        """Every fourth beat opens the next bar."""
        # NOTE(review): assumes the fixture audio yields at least 9 beats.
        for beat_index, bar_number in ((4, 1), (8, 2)):
            assert tracked_rosabeats.beat_starts_bar(beat_index) == bar_number
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TestRosabeatsOutputControl:
    """Switching the play / save / beats outputs on and off."""

    def test_enable_disable_play(self):
        """Playback output toggles with enable/disable."""
        rb = rosabeats()

        rb.enable_output_play()
        assert rb.output_play is True

        rb.disable_output_play()
        assert rb.output_play is False

    def test_enable_disable_save(self, tmp_path):
        """Save output toggles and records the target file."""
        rb = rosabeats()
        target = str(tmp_path / "out.wav")

        rb.enable_output_save(target)
        assert rb.output_save is True
        assert rb.remix_output_file == target

        rb.disable_output_save()
        assert rb.output_save is False

    def test_enable_disable_beats(self, tmp_path):
        """Beats output toggles and records the target file."""
        rb = rosabeats()
        target = str(tmp_path / "out.br")

        rb.enable_output_beats(target)
        assert rb.output_beats is True
        assert rb.beats_output_file == target

        rb.disable_output_beats()
        assert rb.output_beats is False
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class TestRosabeatsDetectDownbeat:
    """Downbeat detection exposed on the rosabeats object."""

    def test_detect_downbeat_returns_int(self, temp_audio_file):
        """detect_downbeat yields an integer offset inside one bar."""
        rb = rosabeats(temp_audio_file)
        rb.track_beats(beatsper=4, downbeat=0)

        offset = rb.detect_downbeat(4)

        assert isinstance(offset, (int, np.integer))
        assert 0 <= offset < 4

    def test_detect_downbeat_dbn(self, temp_audio_file):
        """detect_downbeat_dbn yields an integer offset inside one bar."""
        rb = rosabeats(temp_audio_file)
        rb.track_beats(beatsper=4, downbeat=0)

        offset = rb.detect_downbeat_dbn(4)

        assert isinstance(offset, (int, np.integer))
        assert 0 <= offset < 4
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class TestRosabeatsRest:
    """Inserting silence into the remix."""

    @pytest.fixture
    def output_rosabeats(self, temp_audio_file, tmp_path):
        """A tracked rosabeats with save output on and a fresh remix buffer."""
        instance = rosabeats(temp_audio_file)
        instance.track_beats(beatsper=4, downbeat=0)
        instance.enable_output_save(str(tmp_path / "out.wav"))
        instance.reset_remix()
        return instance

    def test_rest_adds_silence(self, output_rosabeats):
        """rest() advances the remix write position."""
        before = output_rosabeats.remix_index

        output_rosabeats.rest(1.0)  # one beat of silence

        assert output_rosabeats.remix_index > before
|
|
tests/test_segment_song.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Tests for segment_song module."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from rosabeats.segment_song import (
|
|
6
|
+
get_segment_letter,
|
|
7
|
+
generate_segment_names,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestGetSegmentLetter:
    """Index-to-letter naming for segments."""

    def test_single_letters(self):
        """Indices 0 through 25 map to A through Z."""
        for index, letter in ((0, "A"), (1, "B"), (25, "Z")):
            assert get_segment_letter(index) == letter

    def test_double_letters(self):
        """Indices from 26 onward map to two-letter names."""
        for index, letters in ((26, "AA"), (27, "AB"), (51, "AZ"), (52, "BA")):
            assert get_segment_letter(index) == letters
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestGenerateSegmentNames:
    """Naming a sequence of clustered segments."""

    @staticmethod
    def _segments(*labels):
        """Build the minimal segment dicts for the given cluster labels."""
        return [{'label': label} for label in labels]

    def test_unique_clusters(self):
        """Distinct clusters get consecutive letters."""
        assert generate_segment_names(self._segments(0, 1, 2)) == ["A", "B", "C"]

    def test_repeated_clusters_get_numbers(self):
        """A cluster's second appearance gets a numeric suffix."""
        # Clusters 0 and 1 each occur twice.
        result = generate_segment_names(self._segments(0, 1, 0, 1))
        assert result == ["A", "B", "A2", "B2"]

    def test_multiple_repeats(self):
        """Suffixes keep counting up for further repeats."""
        result = generate_segment_names(self._segments(0, 0, 0, 0))
        assert result == ["A", "A2", "A3", "A4"]

    def test_complex_pattern(self):
        """An A B A C A B structure is numbered per cluster."""
        result = generate_segment_names(self._segments(0, 1, 0, 2, 0, 1))
        assert result == ["A", "B", "A2", "C", "A3", "B2"]

    def test_empty_segments(self):
        """No segments means no names."""
        assert generate_segment_names([]) == []

    def test_single_segment(self):
        """A lone segment is named A."""
        assert generate_segment_names(self._segments(0)) == ["A"]

    def test_many_unique_clusters(self):
        """More than 26 clusters roll over into two-letter names."""
        result = generate_segment_names(self._segments(*range(30)))

        assert len(result) == 30
        for position, expected in ((0, "A"), (25, "Z"), (26, "AA"), (27, "AB")):
            assert result[position] == expected

    def test_no_apostrophes_in_names(self):
        """Generated names never use apostrophes to mark repeats."""
        for name in generate_segment_names(self._segments(0, 0, 0)):
            assert "'" not in name, f"Name '{name}' contains apostrophe"
|