bowdet 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bowdet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """bowdet - Bow change detection for bowed string instruments"""
2
+
3
+ from .detect import detect
4
+
5
# Names exported by ``from bowdet import *``; ``detect`` is re-exported from
# ``bowdet.detect``.
__all__ = ["detect"]

# Package version; should match the ``Version`` field of the wheel metadata.
__version__ = "0.1.0"
bowdet/detect.py ADDED
@@ -0,0 +1,197 @@
1
+ """
2
+ bowdet - Bow change detection for bowed string instruments
3
+ Main detection function
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ import librosa
12
+ import soundfile as sf
13
+ from scipy.signal import find_peaks
14
+ from huggingface_hub import hf_hub_download
15
+
16
+
17
+ # ========== Weight Download ==========
18
+
19
# Hugging Face Hub repository hosting the BowDET-C / BowDET-M checkpoints.
HF_REPO_ID: str = "Haotian-Yuan/bowdet"
20
+
21
+
22
+ def _get_weights_dir():
23
+ weights_dir = Path.home() / ".bowdet" / "weights"
24
+ weights_dir.mkdir(parents=True, exist_ok=True)
25
+ return weights_dir
26
+
27
+
28
def _download_weight(filename):
    """Return the local path of checkpoint *filename*, fetching it on first use.

    If the file is not yet in the weights cache directory, it is downloaded
    from the ``HF_REPO_ID`` Hugging Face repository into that directory.
    Progress messages are printed to stdout.

    Parameters
    ----------
    filename : str
        Name of the file inside the Hub repository (e.g. ``"BowDET-C.pth"``).

    Returns
    -------
    pathlib.Path
        ``<weights dir>/<filename>``.
    """
    cache_dir = _get_weights_dir()
    target = cache_dir / filename
    if target.exists():
        # Already cached: skip the network entirely.
        return target

    print(f"Downloading {filename} from Hugging Face Hub...")
    hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        local_dir=str(cache_dir),
    )
    print(f"Weights saved to {target}")
    return target
40
+
41
+
42
+ # ========== CNN Inference (parameters match training exactly) ==========
43
+
44
# Front-end and sliding-window settings for the CNN model. These presumably
# must match the settings used to train the BowDET-C checkpoint.
CNN_SR: int = 22050          # sample rate (Hz) the audio is resampled to
CNN_N_MELS: int = 80         # number of mel bands
CNN_HOP: int = 220           # STFT hop length (samples)
CNN_WIN: int = 512           # STFT window length (samples)
CNN_N_FFT: int = 512         # FFT size (samples)
CNN_WIN_SEC: float = 1.0     # analysis window length (seconds)
CNN_STRIDE_SEC: float = 0.1  # hop between successive analysis windows (seconds)
51
+
52
+
53
def _infer_cnn(audio, model, threshold, min_dist):
    """Score sliding windows of *audio* with the CNN and return bow-change times.

    Each ``CNN_WIN_SEC`` window (hop ``CNN_STRIDE_SEC``) is turned into a
    per-window standardised log-mel spectrogram, scored by *model*, and the
    resulting probability curve is peak-picked.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D mono signal, assumed already resampled to ``CNN_SR``.
    model : torch.nn.Module
        Network returning one logit per input of shape ``[1, 1, n_mels, T]``.
    threshold : float
        Minimum sigmoid probability for a peak to be reported.
    min_dist : float
        Minimum spacing between reported peaks, in seconds. It is converted to
        ``floor(min_dist / CNN_STRIDE_SEC)`` windows (at least 1).

    Returns
    -------
    list of float
        Centre times (seconds) of the peak windows. Empty when *audio* is
        shorter than one window.
    """
    win_samples = int(CNN_WIN_SEC * CNN_SR)
    stride_samples = int(CNN_STRIDE_SEC * CNN_SR)
    times, probs = [], []

    model.eval()
    with torch.no_grad():
        # Only full-length windows are scored. Window centres therefore start
        # at WIN_SEC / 2, and any trailing partial window is never analysed.
        for start in range(0, len(audio) - win_samples + 1, stride_samples):
            chunk = audio[start:start + win_samples].astype(np.float32)

            # Mel spectrogram, dB-scaled per window and standardised per
            # window, mirroring the training pipeline.
            mel = librosa.feature.melspectrogram(
                y=chunk, sr=CNN_SR,
                n_mels=CNN_N_MELS, n_fft=CNN_N_FFT,
                hop_length=CNN_HOP, win_length=CNN_WIN,
            )
            mel = librosa.power_to_db(mel, ref=np.max).astype(np.float32)
            mel = (mel - mel.mean()) / (mel.std() + 1e-9)

            x = torch.tensor(mel).unsqueeze(0).unsqueeze(0)  # [1, 1, 80, T]
            probs.append(torch.sigmoid(model(x)).item())
            times.append((start + win_samples / 2) / CNN_SR)

    if not probs:
        return []

    times = np.asarray(times)
    probs = np.asarray(probs)
    # Small tolerance before flooring: e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which a bare int() would truncate to 2 windows instead of 3.
    min_dist_frames = max(1, int(np.floor(min_dist / CNN_STRIDE_SEC + 1e-6)))
    peaks, _ = find_peaks(probs, height=threshold, distance=min_dist_frames)
    return times[peaks].tolist()
84
+
85
+
86
+ # ========== MERT Inference (parameters match training exactly) ==========
87
+
88
# Sliding-window settings for the MERT model. These presumably must match how
# the BowDET-M checkpoint was trained.
MERT_SR: int = 24000          # sample rate (Hz) the audio is resampled to
MERT_WIN_SEC: float = 1.0     # analysis window length (seconds)
MERT_STRIDE_SEC: float = 0.1  # hop between successive analysis windows (seconds)
91
+
92
+
93
def _infer_mert(audio, model, processor, threshold, min_dist):
    """Score sliding windows of *audio* with the MERT classifier and pick peaks.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D mono signal, assumed already resampled to ``MERT_SR``.
    model : torch.nn.Module
        Classifier taking the processor's ``input_values`` and returning one
        logit per window.
    processor : callable
        Feature extractor called as
        ``processor(chunk, sampling_rate=MERT_SR, return_tensors="pt")``.
    threshold : float
        Minimum sigmoid probability for a peak to be reported.
    min_dist : float
        Minimum spacing between reported peaks, in seconds. It is converted to
        ``floor(min_dist / MERT_STRIDE_SEC)`` windows (at least 1).

    Returns
    -------
    list of float
        Centre times (seconds) of the peak windows. Empty when *audio* is
        shorter than one window.
    """
    win_samples = int(MERT_WIN_SEC * MERT_SR)
    stride_samples = int(MERT_STRIDE_SEC * MERT_SR)
    times, probs = [], []

    model.eval()
    with torch.no_grad():
        # Only full-length windows are scored, so audio within half a window
        # of either end cannot yield a detection.
        for start in range(0, len(audio) - win_samples + 1, stride_samples):
            chunk = audio[start:start + win_samples].astype(np.float32)
            inputs = processor(chunk, sampling_rate=MERT_SR, return_tensors="pt")
            logit = model(inputs["input_values"])
            probs.append(torch.sigmoid(logit).item())
            times.append((start + win_samples / 2) / MERT_SR)

    if not probs:
        return []

    times = np.asarray(times)
    probs = np.asarray(probs)
    # Small tolerance before flooring: e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which a bare int() would truncate to 2 windows instead of 3.
    min_dist_frames = max(1, int(np.floor(min_dist / MERT_STRIDE_SEC + 1e-6)))
    peaks, _ = find_peaks(probs, height=threshold, distance=min_dist_frames)
    return times[peaks].tolist()
114
+
115
+
116
+ # ========== Public API ==========
117
+
118
def _load_mono_audio(audio_path, target_sr):
    """Read *audio_path* with soundfile, average channels, resample to *target_sr*."""
    audio, sr = sf.read(audio_path)
    if audio.ndim > 1:
        # Multi-channel input: average across axis 1 to obtain a mono signal.
        audio = audio.mean(axis=1)
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    return audio


def _load_checkpoint(module, weight_path):
    """Load ``model_state_dict`` from *weight_path* into *module*; return it in eval mode."""
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, i.e. it fully trusts the downloaded checkpoint. Switch to
    # weights_only=True once it is confirmed the checkpoints contain only
    # tensors and plain containers.
    ckpt = torch.load(str(weight_path), map_location="cpu", weights_only=False)
    module.load_state_dict(ckpt["model_state_dict"])
    module.eval()
    return module


def detect(audio_path, model="mert", threshold=0.5, min_dist=0.35):
    """
    Detect bow changes in a bowed string instrument recording.

    Parameters
    ----------
    audio_path : str
        Path to audio file (wav recommended).
    model : str, default="mert"
        Model to use: "mert" (more accurate, IoU@0.1 F1=0.616) or
        "cnn" (faster, IoU@0.1 F1=0.554).
    threshold : float, default=0.5
        Peak detection threshold (0-1).
        Lower values detect more bow changes but may increase false positives.
    min_dist : float, default=0.35
        Minimum time between bow changes in seconds.
        Decrease for fast passages (e.g. spiccato): min_dist=0.2
        Increase for slow passages (e.g. long bows): min_dist=0.5

    Returns
    -------
    list of float
        Timestamps of detected bow changes in seconds. Recordings shorter
        than one 1-second analysis window yield an empty list.

    Raises
    ------
    FileNotFoundError
        If *audio_path* is not an existing regular file.
    ValueError
        If *model* is neither "mert" nor "cnn".

    Examples
    --------
    >>> from bowdet import detect
    >>> bow_changes = detect("recording.wav")
    >>> print(bow_changes)
    [1.23, 2.45, 3.67, ...]

    >>> # Use CNN model (faster, less accurate)
    >>> bow_changes = detect("recording.wav", model="cnn")

    >>> # Custom parameters for fast spiccato passages
    >>> bow_changes = detect("recording.wav", threshold=0.4, min_dist=0.2)
    """
    # isfile (not exists) so a directory is rejected here rather than
    # failing later inside the audio decoder.
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    if model not in ("cnn", "mert"):
        raise ValueError(f"Unknown model '{model}'. Choose 'mert' or 'cnn'.")

    if model == "cnn":
        from .model_cnn import BowCNN

        # Decode the audio before any download so an unreadable file fails
        # fast, without fetching weights first.
        audio = _load_mono_audio(audio_path, CNN_SR)
        weight_path = _download_weight("BowDET-C.pth")
        cnn = _load_checkpoint(BowCNN(), weight_path)
        return _infer_cnn(audio, cnn, threshold, min_dist)

    # model == "mert"
    from .model_mert import MERTClassifier
    from transformers import AutoProcessor

    audio = _load_mono_audio(audio_path, MERT_SR)
    weight_path = _download_weight("BowDET-M.pth")
    processor = AutoProcessor.from_pretrained(
        "m-a-p/MERT-v1-95M", trust_remote_code=True
    )
    mert = _load_checkpoint(MERTClassifier(), weight_path)
    return _infer_mert(audio, mert, processor, threshold, min_dist)
bowdet/model_cnn.py ADDED
@@ -0,0 +1,36 @@
1
+ """CNN model for bow change detection"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
def _conv_bn_relu(in_ch, out_ch):
    """Return [3x3 same-padded Conv2d, BatchNorm2d, ReLU] as a list of layers."""
    return [
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    ]


class BowCNN(nn.Module):
    """Small 2-D CNN mapping a single-channel spectrogram patch to one logit.

    Input has shape ``[batch, 1, H, W]`` (``detect`` feeds ``[1, 1, 80, T]``
    log-mel patches). Output is ``[batch]`` raw logits; apply a sigmoid for
    probabilities. The positional indices inside ``features`` and
    ``classifier`` determine the checkpoint's state-dict keys, so the layer
    order must stay fixed.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            *_conv_bn_relu(1, 16),
            *_conv_bn_relu(16, 32),
            nn.MaxPool2d(2, 2),
            *_conv_bn_relu(32, 64),
            nn.MaxPool2d(2, 2),
            *_conv_bn_relu(64, 64),
            nn.AdaptiveAvgPool2d(1),  # global average pool -> [batch, 64, 1, 1]
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        embedding = self.features(x)
        logits = self.classifier(embedding)  # [batch, 1]
        return logits.squeeze(1)
bowdet/model_mert.py ADDED
@@ -0,0 +1,27 @@
1
+ """MERT-based model for bow change detection"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoModel
6
+
7
+
8
class MERTClassifier(nn.Module):
    """MERT-v1-95M encoder followed by an MLP head giving one logit per clip.

    The encoder's ``last_hidden_state`` is averaged over dim 1 and passed
    through a 768 -> 256 -> 64 -> 1 MLP. The attribute names (``mert``,
    ``classifier``) and the head's layer order determine the checkpoint's
    state-dict keys, so they must stay fixed.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbone loaded via transformers (repository code trusted).
        self.mert = AutoModel.from_pretrained(
            "m-a-p/MERT-v1-95M", trust_remote_code=True
        )
        head = []
        for fan_in, fan_out in ((768, 256), (256, 64)):
            head += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)]
        head.append(nn.Linear(64, 1))
        self.classifier = nn.Sequential(*head)

    def forward(self, input_values):
        encoded = self.mert(input_values, output_hidden_states=False)
        # Mean-pool over dim 1 (presumably the frame/time axis — confirm
        # against the MERT output layout).
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(pooled).squeeze(1)
@@ -0,0 +1,93 @@
1
+ Metadata-Version: 2.4
2
+ Name: bowdet
3
+ Version: 0.1.0
4
+ Summary: Bow change detection for bowed string instruments
5
+ Home-page: https://github.com/haotian-yuan/bowdet
6
+ Author: Haotian Yuan
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Requires-Python: >=3.8
17
+ Description-Content-Type: text/markdown
18
+ Requires-Dist: torch>=1.13.0
19
+ Requires-Dist: torchaudio>=0.13.0
20
+ Requires-Dist: transformers>=4.30.0
21
+ Requires-Dist: huggingface-hub>=0.16.0
22
+ Requires-Dist: numpy>=1.21.0
23
+ Requires-Dist: scipy>=1.7.0
24
+ Dynamic: author
25
+ Dynamic: classifier
26
+ Dynamic: description
27
+ Dynamic: description-content-type
28
+ Dynamic: home-page
29
+ Dynamic: requires-dist
30
+ Dynamic: requires-python
31
+ Dynamic: summary
32
+
33
+ # bowdet
34
+
35
+ **Bow change detection for bowed string instruments**
36
+
37
+ bowdet detects bow changes in audio recordings of bowed string instruments (viola, violin, cello) using deep learning.
38
+
39
+ ## Installation
40
+
41
+ ```bash
42
+ pip install bowdet
43
+ ```
44
+
45
+ ## Quick Start
46
+
47
+ ```python
48
+ from bowdet import detect
49
+
50
+ # Detect bow changes (returns list of timestamps in seconds)
51
+ bow_changes = detect("recording.wav")
52
+ print(bow_changes) # [1.23, 2.45, 3.67, ...]
53
+ ```
54
+
55
+ ## Models
56
+
57
+ | Model | IoU@0.1 F1 | Speed | Size |
58
+ |-------|-----------|-------|------|
59
+ | MERT (default) | 0.616 | ~2 min/min audio | 378 MB |
60
+ | CNN | 0.554 | ~30 sec/min audio | 5 MB |
61
+
62
+ ## Parameters
63
+
64
+ ```python
65
+ detect(
66
+ audio_path, # path to audio file (wav recommended)
67
+ model="mert", # "mert" or "cnn"
68
+ threshold=0.5, # peak detection threshold (0-1)
69
+ min_dist=0.35, # minimum distance between bow changes (seconds)
70
+ # decrease for fast passages (e.g. spiccato): min_dist=0.2
71
+ # increase for slow passages (e.g. long bows): min_dist=0.5
72
+ )
73
+ ```
74
+
75
+ ## Weights
76
+
77
+ Weights are downloaded automatically on first use (~380 MB for MERT).
78
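+ BowDET checkpoints are cached at `~/.bowdet/weights/`. The MERT backbone and feature processor (`m-a-p/MERT-v1-95M`) are downloaded separately by `transformers` and stored in the standard Hugging Face cache.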
79
+
80
+ ## Limitations
81
+
82
+ - Trained on 9 performers across diverse repertoire (Bach, Biber, Penderecki, Hindemith, etc.)
83
+ - Performance may degrade on playing styles significantly different from training data
84
+ - Evaluated only on viola recordings; generalization to violin and cello is expected but has not been measured
85
+ - Designed for solo string instrument recordings; accompaniment or ensemble recordings may reduce accuracy
86
+
87
+ ## Citation
88
+
89
+ [paper pending]
90
+
91
+ ## License
92
+
93
+ MIT
@@ -0,0 +1,8 @@
1
+ bowdet/__init__.py,sha256=A8XnAkWdRshEdu20sXpAPgiTAin17kdgFsxoEZewusY,137
2
+ bowdet/detect.py,sha256=tZV-_pxcRO9ryUjbfjvYhmaL5FeF-0Qw-4QmBEqZ4TE,6264
3
+ bowdet/model_cnn.py,sha256=lE8rHxoO8vGIOhfge-1s95rXo6tS1HCNKpIi1Gw3JOg,981
4
+ bowdet/model_mert.py,sha256=vtgvKEuElR2Xca_kD76BXABYbb-r12xfpy5BC26HLu0,781
5
+ bowdet-0.1.0.dist-info/METADATA,sha256=gGQIl-sgDaIln5GdwO2j8FC9rWshTXYh3t1qXRts6qA,2759
6
+ bowdet-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
7
+ bowdet-0.1.0.dist-info/top_level.txt,sha256=mADqg52I_9JN6B8dkhXwrxA0oswmGeuYBVfsNjxQgLE,7
8
+ bowdet-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ bowdet