bowdet 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bowdet/__init__.py +6 -0
- bowdet/detect.py +197 -0
- bowdet/model_cnn.py +36 -0
- bowdet/model_mert.py +27 -0
- bowdet-0.1.0.dist-info/METADATA +93 -0
- bowdet-0.1.0.dist-info/RECORD +8 -0
- bowdet-0.1.0.dist-info/WHEEL +5 -0
- bowdet-0.1.0.dist-info/top_level.txt +1 -0
bowdet/__init__.py
ADDED
bowdet/detect.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""
|
|
2
|
+
bowdet - Bow change detection for bowed string instruments
|
|
3
|
+
Main detection function
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
import librosa
|
|
12
|
+
import soundfile as sf
|
|
13
|
+
from scipy.signal import find_peaks
|
|
14
|
+
from huggingface_hub import hf_hub_download
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# ========== Weight Download ==========

# Hugging Face Hub repository that hosts the BowDET checkpoints
# (``BowDET-C.pth`` for the CNN, ``BowDET-M.pth`` for the MERT head).
HF_REPO_ID = "Haotian-Yuan/bowdet"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_weights_dir():
|
|
23
|
+
weights_dir = Path.home() / ".bowdet" / "weights"
|
|
24
|
+
weights_dir.mkdir(parents=True, exist_ok=True)
|
|
25
|
+
return weights_dir
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _download_weight(filename):
    """Return the local path of checkpoint *filename*, fetching it if absent.

    The file is looked up inside ``_get_weights_dir()``. If it already exists
    it is returned immediately; otherwise it is downloaded from the
    ``HF_REPO_ID`` Hugging Face repository into that directory first.
    Progress is reported with ``print``.
    """
    cache_dir = _get_weights_dir()
    local_path = cache_dir / filename
    if local_path.exists():
        return local_path

    print(f"Downloading {filename} from Hugging Face Hub...")
    hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=str(cache_dir),
    )
    print(f"Weights saved to {local_path}")
    return local_path
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ========== CNN Inference (parameters match training exactly) ==========

# Log-mel front end used by ``_infer_cnn`` (audio is resampled to CNN_SR).
CNN_SR = 22_050     # sample rate in Hz
CNN_N_MELS = 80     # number of mel bands
CNN_HOP = 220       # STFT hop length, in samples
CNN_WIN = 512       # STFT window length, in samples
CNN_N_FFT = 512     # FFT size, in samples

# Sliding-window scan: every window produces one bow-change probability.
CNN_WIN_SEC = 1.0       # window length, in seconds
CNN_STRIDE_SEC = 0.1    # step between consecutive windows, in seconds
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _infer_cnn(audio, model, threshold, min_dist):
    """Scan *audio* with the CNN in sliding windows and pick bow-change peaks.

    Parameters
    ----------
    audio : numpy.ndarray
        Mono signal. ``detect`` resamples it to ``CNN_SR`` beforehand; this
        function does not check the rate itself.
    model : torch.nn.Module
        Network returning one logit for a ``[1, 1, CNN_N_MELS, T]`` input.
    threshold : float
        Minimum probability a peak must reach (``find_peaks`` ``height``).
    min_dist : float
        Minimum spacing between reported bow changes, in seconds. It is
        converted to a whole number of ``CNN_STRIDE_SEC`` frames (at least 1).

    Returns
    -------
    list of float
        Centre times (seconds) of the windows whose probability is a peak.
        Only full windows are scored, so audio shorter than ``CNN_WIN_SEC``
        yields ``[]`` and the tail after the last full window is not analysed.
    """
    win_samples = int(CNN_WIN_SEC * CNN_SR)
    stride_samples = int(CNN_STRIDE_SEC * CNN_SR)
    times, probs = [], []

    model.eval()
    with torch.no_grad():
        # Same window starts as "while start + win_samples <= len(audio)".
        for start in range(0, len(audio) - win_samples + 1, stride_samples):
            chunk = audio[start:start + win_samples].astype(np.float32)

            # Mel spectrogram — identical to training: power -> dB relative to
            # the window maximum, then per-window z-normalisation.
            mel = librosa.feature.melspectrogram(
                y=chunk, sr=CNN_SR,
                n_mels=CNN_N_MELS, n_fft=CNN_N_FFT,
                hop_length=CNN_HOP, win_length=CNN_WIN,
            )
            mel = librosa.power_to_db(mel, ref=np.max).astype(np.float32)
            mel = (mel - mel.mean()) / (mel.std() + 1e-9)

            x = torch.tensor(mel).unsqueeze(0).unsqueeze(0)  # [1, 1, 80, T]
            probs.append(torch.sigmoid(model(x)).item())
            times.append((start + win_samples / 2) / CNN_SR)

    if not probs:
        # Audio shorter than one window: nothing was scored.
        return []

    times = np.asarray(times)
    probs = np.asarray(probs)
    # Seconds -> frames. The epsilon keeps float error from dropping a frame
    # (0.3 / 0.1 == 2.9999999999999996 would otherwise truncate to 2).
    min_dist_frames = max(1, int(min_dist / CNN_STRIDE_SEC + 1e-9))
    peaks, _ = find_peaks(probs, height=threshold, distance=min_dist_frames)
    return times[peaks].tolist()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# ========== MERT Inference (parameters match training exactly) ==========

# Sliding-window scan for the MERT model: one probability per window.
MERT_SR = 24_000        # sample rate in Hz (resample target and processor rate)
MERT_WIN_SEC = 1.0      # window length, in seconds
MERT_STRIDE_SEC = 0.1   # step between consecutive windows, in seconds
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _infer_mert(audio, model, processor, threshold, min_dist):
    """Scan *audio* with the MERT classifier in sliding windows and pick peaks.

    Parameters
    ----------
    audio : numpy.ndarray
        Mono signal. ``detect`` resamples it to ``MERT_SR`` beforehand; this
        function does not check the rate itself.
    model : torch.nn.Module
        Classifier called with the processor's ``input_values`` tensor and
        returning one logit per window.
    processor : callable
        Feature extractor called as
        ``processor(chunk, sampling_rate=MERT_SR, return_tensors="pt")``.
    threshold : float
        Minimum probability a peak must reach (``find_peaks`` ``height``).
    min_dist : float
        Minimum spacing between reported bow changes, in seconds. It is
        converted to a whole number of ``MERT_STRIDE_SEC`` frames (at least 1).

    Returns
    -------
    list of float
        Centre times (seconds) of the windows whose probability is a peak.
        Only full windows are scored, so audio shorter than ``MERT_WIN_SEC``
        yields ``[]`` and the tail after the last full window is not analysed.
    """
    win_samples = int(MERT_WIN_SEC * MERT_SR)
    stride_samples = int(MERT_STRIDE_SEC * MERT_SR)
    times, probs = [], []

    model.eval()
    with torch.no_grad():
        # Same window starts as "while start + win_samples <= len(audio)".
        for start in range(0, len(audio) - win_samples + 1, stride_samples):
            chunk = audio[start:start + win_samples].astype(np.float32)
            inputs = processor(chunk, sampling_rate=MERT_SR, return_tensors="pt")
            logit = model(inputs["input_values"])
            probs.append(torch.sigmoid(logit).item())
            times.append((start + win_samples / 2) / MERT_SR)

    if not probs:
        # Audio shorter than one window: nothing was scored.
        return []

    times = np.asarray(times)
    probs = np.asarray(probs)
    # Seconds -> frames. The epsilon keeps float error from dropping a frame
    # (0.3 / 0.1 == 2.9999999999999996 would otherwise truncate to 2).
    min_dist_frames = max(1, int(min_dist / MERT_STRIDE_SEC + 1e-9))
    peaks, _ = find_peaks(probs, height=threshold, distance=min_dist_frames)
    return times[peaks].tolist()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ========== Public API ==========
|
|
117
|
+
|
|
118
|
+
def detect(audio_path, model="mert", threshold=0.5, min_dist=0.35):
    """
    Detect bow changes in a bowed string instrument recording.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by ``soundfile`` (wav recommended).
        Multi-channel audio is averaged to mono and resampled to the
        selected model's sample rate.
    model : str, default="mert"
        Model to use: "mert" (more accurate, IoU@0.1 F1=0.616) or
        "cnn" (faster, IoU@0.1 F1=0.554).
    threshold : float, default=0.5
        Peak detection threshold (0-1).
        Lower values detect more bow changes but may increase false positives.
    min_dist : float, default=0.35
        Minimum time between bow changes in seconds.
        Decrease for fast passages (e.g. spiccato): min_dist=0.2
        Increase for slow passages (e.g. long bows): min_dist=0.5

    Returns
    -------
    list of float
        Timestamps of detected bow changes in seconds.

    Raises
    ------
    ValueError
        If *model* is not "mert" or "cnn" (checked before any file access).
    FileNotFoundError
        If *audio_path* does not name an existing regular file.

    Notes
    -----
    Every call constructs the model and reloads its checkpoint; only the
    checkpoint download itself is cached (in ``~/.bowdet/weights``).

    Examples
    --------
    >>> from bowdet import detect
    >>> bow_changes = detect("recording.wav")
    >>> print(bow_changes)
    [1.23, 2.45, 3.67, ...]

    >>> # Use CNN model (faster, less accurate)
    >>> bow_changes = detect("recording.wav", model="cnn")

    >>> # Custom parameters for fast spiccato passages
    >>> bow_changes = detect("recording.wav", threshold=0.4, min_dist=0.2)
    """
    # Validate cheap arguments first so a typo never triggers a download.
    if model not in ("mert", "cnn"):
        raise ValueError(f"Unknown model '{model}'. Choose 'mert' or 'cnn'.")
    # isfile (not exists) so a directory fails here with a clear message
    # instead of deep inside soundfile.
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    if model == "cnn":
        from .model_cnn import BowCNN

        weight_path = _download_weight("BowDET-C.pth")
        cnn = _load_checkpoint(BowCNN(), weight_path)
        audio = _load_mono_audio(audio_path, CNN_SR)
        return _infer_cnn(audio, cnn, threshold, min_dist)

    # model == "mert"
    from .model_mert import MERTClassifier
    from transformers import AutoProcessor

    weight_path = _download_weight("BowDET-M.pth")
    processor = AutoProcessor.from_pretrained(
        "m-a-p/MERT-v1-95M", trust_remote_code=True
    )
    mert = _load_checkpoint(MERTClassifier(), weight_path)
    audio = _load_mono_audio(audio_path, MERT_SR)
    return _infer_mert(audio, mert, processor, threshold, min_dist)


# ---- private helpers used by detect() ----

def _load_mono_audio(audio_path, target_sr):
    """Read *audio_path*, average channels to mono and resample to *target_sr*."""
    audio, sr = sf.read(audio_path)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    return audio


def _load_checkpoint(module, weight_path):
    """Load ``ckpt["model_state_dict"]`` from *weight_path* into *module*.

    The module is switched to eval mode and returned for convenience.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. The checkpoint is downloaded from the Hugging Face Hub
    # (HF_REPO_ID), so this is only as trustworthy as that repository.
    # Prefer weights_only=True if the checkpoint holds only tensors and
    # plain containers.
    ckpt = torch.load(str(weight_path), map_location="cpu", weights_only=False)
    module.load_state_dict(ckpt["model_state_dict"])
    module.eval()
    return module
|
bowdet/model_cnn.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""CNN model for bow change detection"""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BowCNN(nn.Module):
    """Small 2-D CNN that maps a log-mel patch to a single bow-change logit.

    Four conv -> batch-norm -> ReLU stages (two followed by 2x2 max-pooling)
    end in global average pooling. A two-layer MLP head then produces one
    logit per input. The layer order fixes the ``features.N`` /
    ``classifier.N`` state_dict keys the released checkpoint relies on.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # 3x3 convolution with "same" padding, then BN and ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        layers = []
        layers += conv_bn_relu(1, 16)
        layers += conv_bn_relu(16, 32)
        layers.append(nn.MaxPool2d(2, 2))
        layers += conv_bn_relu(32, 64)
        layers.append(nn.MaxPool2d(2, 2))
        layers += conv_bn_relu(64, 64)
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*layers)

        head = [
            nn.Flatten(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        pooled = self.features(x)
        logits = self.classifier(pooled)
        # Drop the trailing size-1 output dimension: [B, 1] -> [B].
        return logits.squeeze(1)
|
bowdet/model_mert.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""MERT-based model for bow change detection"""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from transformers import AutoModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MERTClassifier(nn.Module):
    """Bow-change classifier on top of the pretrained ``m-a-p/MERT-v1-95M`` encoder.

    The encoder's ``last_hidden_state`` is mean-pooled over dim 1 and passed
    through a three-layer MLP head that emits one logit per input.
    """

    def __init__(self):
        super().__init__()
        self.mert = AutoModel.from_pretrained(
            "m-a-p/MERT-v1-95M", trust_remote_code=True
        )
        head = [
            # 768 must match the encoder's hidden size, since the pooled
            # hidden state is fed straight into this layer.
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, input_values):
        encoded = self.mert(input_values, output_hidden_states=False)
        pooled = torch.mean(encoded.last_hidden_state, dim=1)
        logit = self.classifier(pooled)
        # [B, 1] -> [B]
        return logit.squeeze(1)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: bowdet
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Bow change detection for bowed string instruments
|
|
5
|
+
Home-page: https://github.com/haotian-yuan/bowdet
|
|
6
|
+
Author: Haotian Yuan
|
|
7
|
+
Classifier: Development Status :: 3 - Alpha
|
|
8
|
+
Classifier: Intended Audience :: Science/Research
|
|
9
|
+
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Requires-Python: >=3.8
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
Requires-Dist: torch>=1.13.0
|
|
19
|
+
Requires-Dist: torchaudio>=0.13.0
|
|
20
|
+
Requires-Dist: transformers>=4.30.0
|
|
21
|
+
Requires-Dist: huggingface-hub>=0.16.0
|
|
22
|
+
Requires-Dist: numpy>=1.21.0
|
|
23
|
+
Requires-Dist: scipy>=1.7.0
Requires-Dist: librosa
Requires-Dist: soundfile
|
|
24
|
+
Dynamic: author
|
|
25
|
+
Dynamic: classifier
|
|
26
|
+
Dynamic: description
|
|
27
|
+
Dynamic: description-content-type
|
|
28
|
+
Dynamic: home-page
|
|
29
|
+
Dynamic: requires-dist
|
|
30
|
+
Dynamic: requires-python
|
|
31
|
+
Dynamic: summary
|
|
32
|
+
|
|
33
|
+
# bowdet
|
|
34
|
+
|
|
35
|
+
**Bow change detection for bowed string instruments**
|
|
36
|
+
|
|
37
|
+
bowdet detects bow changes in audio recordings of bowed string instruments (viola, violin, cello) using deep learning.
|
|
38
|
+
|
|
39
|
+
## Installation
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pip install bowdet
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Quick Start
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
from bowdet import detect
|
|
49
|
+
|
|
50
|
+
# Detect bow changes (returns list of timestamps in seconds)
|
|
51
|
+
bow_changes = detect("recording.wav")
|
|
52
|
+
print(bow_changes) # [1.23, 2.45, 3.67, ...]
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Models
|
|
56
|
+
|
|
57
|
+
| Model | IoU@0.1 F1 | Speed | Size |
|
|
58
|
+
|-------|-----------|-------|------|
|
|
59
|
+
| MERT (default) | 0.616 | ~2 min per minute of audio | 378 MB |
|
|
60
|
+
| CNN | 0.554 | ~30 s per minute of audio | 5 MB |
|
|
61
|
+
|
|
62
|
+
## Parameters
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
detect(
|
|
66
|
+
audio_path, # path to wav file
|
|
67
|
+
model="mert", # "mert" or "cnn"
|
|
68
|
+
threshold=0.5, # peak detection threshold (0-1)
|
|
69
|
+
min_dist=0.35, # minimum distance between bow changes (seconds)
|
|
70
|
+
# decrease for fast passages (e.g. spiccato): min_dist=0.2
|
|
71
|
+
# increase for slow passages (e.g. long bows): min_dist=0.5
|
|
72
|
+
)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Weights
|
|
76
|
+
|
|
77
|
+
Weights are downloaded automatically on first use (~380 MB for MERT).
|
|
78
|
+
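Cached at `~/.bowdet/weights/`. The MERT model additionally downloads the `m-a-p/MERT-v1-95M` backbone and processor through `transformers`, which stores them in the default Hugging Face cache.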
|
|
79
|
+
|
|
80
|
+
## Limitations
|
|
81
|
+
|
|
82
|
+
- Trained on recordings by 9 performers across a diverse repertoire (Bach, Biber, Penderecki, Hindemith, etc.)
|
|
83
|
+
- Performance may degrade on playing styles significantly different from training data
|
|
84
|
+
- Evaluated only on viola recordings; expected, but not yet verified, to generalize to violin and cello
|
|
85
|
+
- Designed for solo string instrument recordings; accompaniment or ensemble recordings may reduce accuracy
|
|
86
|
+
|
|
87
|
+
## Citation
|
|
88
|
+
|
|
89
|
+
[paper pending]
|
|
90
|
+
|
|
91
|
+
## License
|
|
92
|
+
|
|
93
|
+
MIT
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
bowdet/__init__.py,sha256=A8XnAkWdRshEdu20sXpAPgiTAin17kdgFsxoEZewusY,137
|
|
2
|
+
bowdet/detect.py,sha256=tZV-_pxcRO9ryUjbfjvYhmaL5FeF-0Qw-4QmBEqZ4TE,6264
|
|
3
|
+
bowdet/model_cnn.py,sha256=lE8rHxoO8vGIOhfge-1s95rXo6tS1HCNKpIi1Gw3JOg,981
|
|
4
|
+
bowdet/model_mert.py,sha256=vtgvKEuElR2Xca_kD76BXABYbb-r12xfpy5BC26HLu0,781
|
|
5
|
+
bowdet-0.1.0.dist-info/METADATA,sha256=gGQIl-sgDaIln5GdwO2j8FC9rWshTXYh3t1qXRts6qA,2759
|
|
6
|
+
bowdet-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
7
|
+
bowdet-0.1.0.dist-info/top_level.txt,sha256=mADqg52I_9JN6B8dkhXwrxA0oswmGeuYBVfsNjxQgLE,7
|
|
8
|
+
bowdet-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
bowdet
|