canapy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canapy/__init__.py +15 -0
- canapy/annotator/__init__.py +63 -0
- canapy/annotator/base.py +243 -0
- canapy/annotator/commons/__init__.py +3 -0
- canapy/annotator/commons/compat.py +28 -0
- canapy/annotator/commons/esn.py +102 -0
- canapy/annotator/commons/mfccs.py +144 -0
- canapy/annotator/commons/postprocess.py +177 -0
- canapy/annotator/ensemble.py +359 -0
- canapy/annotator/nsynannotator.py +268 -0
- canapy/annotator/synannotator.py +242 -0
- canapy/corpus.py +558 -0
- canapy/correction/__init__.py +4 -0
- canapy/correction/base.py +106 -0
- canapy/formats/__init__.py +6 -0
- canapy/formats/marron1csv.py +193 -0
- canapy/log.py +21 -0
- canapy/metrics/__init__.py +4 -0
- canapy/metrics/base.py +115 -0
- canapy/metrics/utils.py +71 -0
- canapy/optimization.py +906 -0
- canapy/plots/__init__.py +4 -0
- canapy/plots/base.py +272 -0
- canapy/timings/__init__.py +4 -0
- canapy/timings/base.py +78 -0
- canapy/transforms/__init__.py +5 -0
- canapy/transforms/base.py +96 -0
- canapy/transforms/commons/__init__.py +3 -0
- canapy/transforms/commons/annots.py +99 -0
- canapy/transforms/commons/audio.py +225 -0
- canapy/transforms/commons/training.py +193 -0
- canapy/transforms/nsynesn.py +193 -0
- canapy/transforms/synesn.py +15 -0
- canapy/utils/__init__.py +11 -0
- canapy/utils/arrays.py +70 -0
- canapy/utils/exceptions.py +13 -0
- canapy/utils/tempstorage.py +20 -0
- canapy-0.1.0.dist-info/METADATA +213 -0
- canapy-0.1.0.dist-info/RECORD +80 -0
- canapy-0.1.0.dist-info/WHEEL +4 -0
- canapy-0.1.0.dist-info/entry_points.txt +4 -0
- canapy-0.1.0.dist-info/licenses/LICENSE +29 -0
- config/__init__.py +27 -0
- config/config.py +151 -0
- config/default/default.config.toml +62 -0
- config/presets/bengalese_finch.toml +64 -0
- config/presets/canary.toml +62 -0
- config/presets/infant_marmoset.toml +64 -0
- config/presets/mouse_binary_classification.toml +64 -0
- config/presets/zebra_finch.toml +64 -0
- config/store/default.config.toml +62 -0
- config/store/default.config.yml +404 -0
- config/template/default.config.toml +61 -0
- dashboard/__init__.py +6 -0
- dashboard/__main__.py +299 -0
- dashboard/app.py +105 -0
- dashboard/controler/__init__.py +4 -0
- dashboard/controler/base.py +1153 -0
- dashboard/controler/corpusutils.py +30 -0
- dashboard/controler/segments.py +89 -0
- dashboard/view/__init__.py +3 -0
- dashboard/view/annotate/__init__py +0 -0
- dashboard/view/annotate/annotate_dash.py +472 -0
- dashboard/view/eval/__init__.py +3 -0
- dashboard/view/eval/classmerge.py +499 -0
- dashboard/view/eval/eval_dash.py +239 -0
- dashboard/view/eval/samplecorrection.py +377 -0
- dashboard/view/export/__init__.py +3 -0
- dashboard/view/export/export_dash.py +333 -0
- dashboard/view/helpers.py +600 -0
- dashboard/view/home/__init__.py +3 -0
- dashboard/view/home/home_dash.py +172 -0
- dashboard/view/loaddata/__init__.py +3 -0
- dashboard/view/loaddata/load_data_dash.py +617 -0
- dashboard/view/preprocess/__init__.py +3 -0
- dashboard/view/preprocess/preprocess_dash.py +1601 -0
- dashboard/view/settings/__init__.py +3 -0
- dashboard/view/settings/settings_dash.py +1077 -0
- dashboard/view/train/__init__.py +3 -0
- dashboard/view/train/train_dash.py +449 -0
canapy/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Author: Axel Arnaud
|
|
2
|
+
# Licence: BSD-3-Clause
|
|
3
|
+
# Copyright: Axel Arnaud
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import crowsetta
|
|
7
|
+
|
|
8
|
+
from . import log
|
|
9
|
+
from .corpus import Corpus
|
|
10
|
+
from config import Config
|
|
11
|
+
from .formats.marron1csv import Marron1CSV
|
|
12
|
+
|
|
13
|
+
# Register the Marron1CSV annotation format class with crowsetta at import time.
crowsetta.register_format(Marron1CSV)

# Import-time side effect: configure the root logger at INFO level.
# NOTE(review): calling logging.basicConfig from a library's __init__ changes
# logging for every application that imports canapy (and it silently does
# nothing if the root logger already has handlers). Consider moving this call
# to the CLI / dashboard entry points instead.
logging.basicConfig(level=logging.INFO)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Author: Axel Arnaud
|
|
2
|
+
# Licence: BSD-3-Clause
|
|
3
|
+
# Copyright: Axel Arnaud
|
|
4
|
+
"""
|
|
5
|
+
This module provides annotator classes and a registry for audio classification in Canapy.
|
|
6
|
+
|
|
7
|
+
Classes
|
|
8
|
+
-------
|
|
9
|
+
Annotator
|
|
10
|
+
Base class for annotators.
|
|
11
|
+
SynAnnotator
|
|
12
|
+
An annotator that uses a Syntaxic approach with an Echo State Network (ESN).
|
|
13
|
+
NSynAnnotator
|
|
14
|
+
An annotator that uses a non-Syntaxic approach with an Echo State Network (ESN).
|
|
15
|
+
Ensemble
|
|
16
|
+
An ensemble annotator that combines the predictions of other annotators.
|
|
17
|
+
|
|
18
|
+
Functions
|
|
19
|
+
---------
|
|
20
|
+
get_annotator
|
|
21
|
+
Retrieves an annotator object by name from the registry.
|
|
22
|
+
get_annotator_names
|
|
23
|
+
Retrieves a list of available annotator names.
|
|
24
|
+
|
|
25
|
+
"""
|
|
26
|
+
import logging
|
|
27
|
+
|
|
28
|
+
from .base import Annotator
|
|
29
|
+
from .synannotator import SynAnnotator
|
|
30
|
+
from .nsynannotator import NSynAnnotator
|
|
31
|
+
from .ensemble import Ensemble
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger("canapy")


class _Registry:
    """
    Mapping from annotator names to annotator classes.

    The three built-in annotators ("syn-esn", "nsyn-esn", "ensemble") are
    available from construction. Additional classes can be added with
    ``register_annotator``; a name that is already taken is never overwritten.
    """

    def __init__(self):
        builtin_annotators = (
            ("syn-esn", SynAnnotator),
            ("nsyn-esn", NSynAnnotator),
            ("ensemble", Ensemble),
        )
        self._registry = {name: klass for name, klass in builtin_annotators}

    def __getitem__(self, item):
        # Unknown names propagate the dict's KeyError unchanged.
        return self._registry[item]

    def register_annotator(self, name, cls):
        """Register ``cls`` under ``name``; log a warning and keep the old entry if taken."""
        if name not in self._registry:
            self._registry[name] = cls
        else:
            logger.warning(f"'{name}' is already registered. Skipping.")


# Module-level singleton used by get_annotator / get_annotator_names.
registry = _Registry()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_annotator(name):
    """
    Retrieve an annotator class from the registry by name.

    Parameters
    ----------
    name : str
        Registered annotator name. "syn-esn", "nsyn-esn" and "ensemble" are
        registered by default.

    Returns
    -------
    type
        The annotator class registered under ``name`` (not an instance).

    Raises
    ------
    KeyError
        If no annotator is registered under ``name``. The message lists the
        available names.
    """
    try:
        return registry[name]
    except KeyError:
        available = sorted(registry._registry, key=str)
        raise KeyError(
            f"Unknown annotator '{name}'. Available annotators: {available}"
        ) from None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_annotator_names():
    """Return the names of all registered annotators, sorted alphabetically."""
    names = list(registry._registry)
    names.sort()
    return names
|
canapy/annotator/base.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# Authors: Nathan Trouvain at 29/06/2023 <nathan.trouvain<at>inria.fr>
|
|
2
|
+
# Vincent Gardies at 10/07/2023 <vincent.gardies<at>inria.fr>
|
|
3
|
+
# Licence: BSD-3-Clause
|
|
4
|
+
# Copyright: Nathan Trouvain
|
|
5
|
+
|
|
6
|
+
import abc
|
|
7
|
+
import pickle
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from config import default_config
|
|
12
|
+
from .commons.compat import _Compat, _CompatModelUnpickler
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger("canapy")


class Annotator(abc.ABC):
    """
    Base abstract class for annotators.

    The basic usage of an annotator involves three steps:

    1. Load your annotator (using the 'from_disk' method or by creating a new one).

    2. Train your annotator (using the 'fit' method).

    3. Make predictions with the annotator (using the 'predict' method).

    Subclasses are expected to override 'fit' and 'predict'; the base
    implementations only raise NotImplementedError.
    """

    # Class-level defaults.
    # NOTE(review): ``_vocab`` is one list object shared by every instance until
    # an instance rebinds ``self._vocab``; subclasses should assign a new list
    # rather than mutate this one in place.
    _trained: bool = False
    _vocab: list = list()

    @classmethod
    def from_disk(
        cls,
        path,
        config=default_config,
    ):
        """
        Load an annotator object from disk.

        Parameters
        ----------
        path : str or Path
            Path to the pickle file containing the annotator object.
        config : Config, default=default_config
            Configuration used to rebuild annotators saved with a previous
            version of canapy. It is not used for annotators saved with this
            version.

        Returns
        -------
        Annotator
            The loaded annotator object.

        Raises
        ------
        NotImplementedError
            If the unpickled object is neither an Annotator nor a model saved
            by a previous version of canapy.

        Warning
        -------
        The file is read with ``pickle``, which can execute arbitrary code
        while loading. Only load annotator files from trusted sources.

        Note
        ----
        Annotators saved with the former version of canapy are also loadable.
        Their type is inferred from the file name, which must end with its
        type ("syn" or "nsyn"). Pass the configuration the model was trained
        with if you had made changes to it.

        Examples
        --------
        >>> from canapy.annotator.base import Annotator

        >>> my_annotator = Annotator.from_disk("/path/to/annotator")
        >>> # The annotator saved in the file is now loaded into 'my_annotator'

        >>> from config import default_config
        >>> my_old_config = default_config  # The config the old model was saved with
        >>> my_annotator_old = Annotator.from_disk(
        ...     "/path/to/old/model_nsyn",
        ...     config=my_old_config,  # default_config works if nothing was changed
        ... )
        >>> # The file name ends with "nsyn", so a non-syntactic annotator is rebuilt
        """
        from . import get_annotator

        with Path(path).open("rb") as file:
            try:
                loaded_annot = pickle.load(file)
            except ModuleNotFoundError:
                # Pickles from the previous canapy reference modules that no
                # longer exist. The failed attempt has already consumed part of
                # the stream, so rewind before the compatibility unpickler
                # starts reading from the first byte again.
                file.seek(0)
                loaded_annot = _CompatModelUnpickler(file).load()

        # Annotators made with this version of canapy
        if isinstance(loaded_annot, Annotator):
            if hasattr(loaded_annot, "rpy_model") or (
                hasattr(loaded_annot, "_vocab") and len(loaded_annot._vocab) > 0
            ):
                loaded_annot._trained = True
            return loaded_annot

        # Annotators made with a previous version of canapy
        if isinstance(loaded_annot, _Compat):
            # The file name determines the annotator type: a name ending in
            # "nsyn" has "n" as its 4th-to-last character. If you rename the
            # model, make sure it still ends with its type (syn / nsyn).
            # str() lets pathlib.Path arguments through; len()/indexing a Path
            # would raise TypeError.
            name = str(path)
            annotator_type = get_annotator(
                "nsyn-esn" if len(name) > 3 and name[-4] == "n" else "syn-esn"
            )

            new_annotator = annotator_type(config)
            new_annotator._trained = True
            new_annotator.rpy_model = loaded_annot.esn
            new_annotator._vocab = loaded_annot.vocab
            return new_annotator

        raise NotImplementedError(
            f"Cannot build an annotator from an object of type "
            f"{type(loaded_annot).__name__!r} loaded from {path}."
        )

    def to_disk(self, path):
        """
        Save the annotator object to disk (pickled).

        Parameters
        ----------
        path : str or Path
            Path to the file where the annotator object will be saved. An
            existing file is overwritten.

        Example
        -------
        >>> from canapy.annotator.synannotator import SynAnnotator
        >>> from config import default_config
        >>> my_annotator = SynAnnotator(default_config, "/path/to/spec")
        >>> my_annotator.to_disk("/path/to/annotators/my_annotator")
        >>> # The annotator is now saved in the disk
        """
        with Path(path).open("wb+") as file:
            pickle.dump(self, file)

    @property
    def trained(self):
        """
        Property indicating if the annotator is trained.

        Returns
        -------
        bool
            True if the annotator is trained, False otherwise.

        Example
        -------
        >>> from canapy.annotator.ensemble import Ensemble
        >>> from config import default_config
        >>> my_ensemble_annotator = Ensemble(default_config, None)
        >>> # For example, we are using an Ensemble annotator
        >>> print(f"My annotator is trained : {my_ensemble_annotator.trained}")
        My annotator is trained : False
        >>> from canapy.corpus import Corpus
        >>> corpus = Corpus.from_directory(audio_directory="/path/to/audio", annots_directory="/path/to/annotation")
        >>> my_ensemble_annotator.fit(corpus)
        >>> # We create a corpus from some files and then train the annotator on it
        >>> print(f"My annotator is trained : {my_ensemble_annotator.trained}")
        My annotator is trained : True
        """
        return self._trained

    @property
    def vocab(self):
        """
        Property containing the vocabulary of the annotator.

        Returns
        -------
        list
            The vocabulary of the annotator.

        Note
        ----
        The vocabulary of the annotator is determined during the 'fit' method.
        Before training it is the (empty) class-level default.
        """
        return self._vocab

    def fit(self, corpus):
        """
        Fit the annotator to the given corpus.

        Parameters
        ----------
        corpus : Corpus
            The corpus object used for training the annotator.

        Returns
        -------
        Annotator
            The trained annotator itself.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses implement training.

        Note
        ----
        The annotator is trained in place; the returned object is the same
        instance.

        Example
        -------
        >>> from canapy.annotator.nsynannotator import NSynAnnotator
        >>> from config import default_config
        >>> my_annotator = NSynAnnotator(default_config, "/path/to/spec")
        >>> # A non-syntactic annotator is used in this example
        >>> from canapy.corpus import Corpus
        >>> corpus = Corpus.from_directory(audio_directory="/path/to/audio", annots_directory="/path/to/annotation")
        >>> my_annotator_trained = my_annotator.fit(corpus)
        >>> # The annotator is now trained with the given corpus
        >>> my_annotator_trained == my_annotator
        True
        """
        raise NotImplementedError

    def predict(self, corpus, return_raw=False, redo_transforms=False):
        """
        Predict annotations for the given corpus.

        Parameters
        ----------
        corpus : Corpus
            The corpus object for which to predict annotations.
        return_raw : bool, optional
            If True, raw annotations are added into the 'data_resources'.
            Raw outputs are necessary to train an 'Ensemble' annotator on a corpus.
        redo_transforms : bool, optional
            If True, redo the transformations on the corpus before predicting.

        Returns
        -------
        Corpus
            The corpus object with predicted annotations.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses implement prediction.

        Note
        ----
        Annotators need to be trained before being able to make predictions.

        Examples
        --------
        Find examples of use in annotators subclasses.
        """
        raise NotImplementedError
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Author: Axel Arnaud
|
|
2
|
+
# Licence: BSD-3-Clause
|
|
3
|
+
# Copyright: Axel Arnaud
|
|
4
|
+
import pickle
|
|
5
|
+
from types import SimpleNamespace
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class _CompatModelUnpickler(pickle._Unpickler):
    """
    Unpickler for models saved by a previous version of canapy.

    Any class requested from a legacy ``canapy`` submodule (``dataset``,
    ``processor``, ``sequence``, ``model`` or ``config``) is replaced by
    ``_Compat``. Every other class is resolved normally.
    """

    # Legacy canapy submodules whose classes no longer exist.
    _LEGACY_SUBMODULES = frozenset(
        {"dataset", "processor", "sequence", "model", "config"}
    )

    def __init__(self, fp):
        super().__init__(fp)
        self._magic_classes = {}

    def find_class(self, module, name):
        parts = module.split(".")
        # Length check: a bare "canapy" module has no second component, and
        # indexing it used to raise IndexError instead of falling through.
        if (
            "canapy" in parts
            and len(parts) > 1
            and parts[1] in self._LEGACY_SUBMODULES
        ):
            return _Compat
        return super().find_class(module, name)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class _Compat(SimpleNamespace):
    """
    Stand-in for objects unpickled from legacy canapy modules.

    As a SimpleNamespace it keeps the attributes restored on it. Item
    assignment (``obj[key] = value``) is accepted but discarded, so that
    (presumably) dict-style item data in a legacy pickle does not make
    unpickling fail.
    """

    def __setitem__(self, key, value):
        # Intentionally a no-op: the assigned value is dropped.
        return None
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Author: Nathan Trouvain at 04/07/2023 <nathan.trouvain<at>inria.fr>
|
|
2
|
+
# Licence: BSD-3-Clause
|
|
3
|
+
# Copyright: Nathan Trouvain
|
|
4
|
+
import numpy as np
|
|
5
|
+
import reservoirpy as rpy
|
|
6
|
+
from reservoirpy.nodes import Reservoir, Ridge
|
|
7
|
+
from reservoirpy import ESN
|
|
8
|
+
from reservoirpy.mat_gen import fast_spectral_initialization
|
|
9
|
+
from canapy.utils.exceptions import NotTrainedError
|
|
10
|
+
from .mfccs import load_mfccs_for_annotation
|
|
11
|
+
|
|
12
|
+
def maximum_a_posteriori(logits, classes=None):
    """
    Pick the highest-scoring class for every row of ``logits``.

    Parameters
    ----------
    logits : array-like
        Scores; a 1-D input is treated as a single row.
    classes : array-like, optional
        Class labels indexed by column. When given, labels are returned
        instead of column indices.

    Returns
    -------
    np.ndarray
        One argmax index (or label) per row.
    """
    best_columns = np.atleast_2d(logits).argmax(axis=1)
    return best_columns if classes is None else np.take(classes, best_columns)
|
|
18
|
+
|
|
19
|
+
def init_esn_model(model_config, input_dim, audio_features, seed=None, workers=None, dtype=np.float64, **overrides):
    """
    Build an untrained reservoirpy ESN from a model configuration.

    Parameters
    ----------
    model_config : object
        Source of hyperparameters, read as attributes ``sr``, ``leak``,
        ``iss``, ``ridge``, ``isd``, ``isd2``, ``units``, ``workers`` and
        ``backend``. Missing attributes fall back to built-in defaults.
    input_dim : int
        Size of each enabled feature block of the input.
    audio_features : container
        Tested with ``in`` for "mfcc", "delta" and "delta2". Each enabled
        feature contributes one block of ``input_dim`` input-scaling values
        (iss, isd and isd2 respectively), in that order.
    seed : optional
        Passed to ``rpy.set_seed`` before anything is built.
    workers : int, optional
        Worker count; defaults to ``model_config.workers`` or -1.
    dtype : numpy dtype, default=np.float64
        Reservoir dtype.
    **overrides
        Values for any of "sr", "leak", "iss", "ridge", "isd", "isd2",
        "units" that take precedence over ``model_config``.

    Returns
    -------
    ESN
        The assembled (untrained) model.
    """
    rpy.set_seed(seed)

    def hyper(name, fallback):
        # Explicit keyword overrides win over the configuration object.
        return overrides[name] if name in overrides else getattr(model_config, name, fallback)

    spectral_radius = hyper("sr", 0.4)
    leak_rate = hyper("leak", 0.1)
    iss = hyper("iss", 0.0005)
    ridge_coef = hyper("ridge", 1e-8)
    isd = hyper("isd", 0.02)
    isd2 = hyper("isd2", 0.002)
    units = hyper("units", 1000)

    # Input scaling is assembled block by block, one block per enabled feature.
    feature_scales = (("mfcc", iss), ("delta", isd), ("delta2", isd2))
    blocks = [
        np.ones((input_dim,)) * scale
        for feature, scale in feature_scales
        if feature in audio_features
    ]
    input_scaling = np.concatenate(blocks, axis=0) if blocks else iss

    reservoir = Reservoir(
        units,
        sr=spectral_radius,
        lr=leak_rate,
        input_scaling=input_scaling,
        bias=iss,
        W=fast_spectral_initialization,
        dtype=dtype,
    )
    readout = Ridge(ridge=ridge_coef)

    if workers is None:
        workers = getattr(model_config, "workers", -1)

    return ESN(
        reservoir=reservoir,
        readout=readout,
        workers=workers,
        backend=getattr(model_config, "backend", "multiprocessing"),
    )
|
|
68
|
+
|
|
69
|
+
def fit_esn_seq_by_seq(model, X_seqs, Y_seqs):
    """
    Fit an ESN readout one sequence at a time to bound peak memory use.

    The model is first fitted on the first sequence only (presumably to
    initialise the reservoir and readout — NOTE(review): confirm). Then every
    sequence, the first one included, is run through the reservoir from a
    reset state. Each ``readout.worker(states, targets)`` result is streamed
    to ``readout.master``, so only one state matrix is alive at a time.

    The original author describes this as mathematically equivalent to
    ``model.fit(X_seqs, Y_seqs)``. That relies on ``readout.master``
    recomputing the readout from the streamed results only — TODO confirm
    against the reservoirpy version in use.

    Parameters
    ----------
    model : ESN
        Model exposing ``fit``, ``reservoir`` (``reset``, ``run``) and
        ``readout`` (``worker``, ``master``).
    X_seqs : sequence
        Input sequences (indexable, sized).
    Y_seqs : sequence
        Target sequences, one per input sequence.

    Raises
    ------
    ValueError
        If no sequence is given, or X_seqs and Y_seqs differ in length. The
        latter would otherwise be silently truncated by ``zip``.
    """
    if len(X_seqs) == 0:
        raise ValueError("fit_esn_seq_by_seq needs at least one training sequence.")
    if len(X_seqs) != len(Y_seqs):
        raise ValueError(
            f"Got {len(X_seqs)} input sequences but {len(Y_seqs)} target sequences."
        )

    model.fit([np.asarray(X_seqs[0])], [np.asarray(Y_seqs[0])])
    reservoir = model.reservoir
    readout = model.readout

    def _gen():
        # Lazily produce one worker result per sequence.
        for x_seq, y_seq in zip(X_seqs, Y_seqs):
            reservoir.reset()
            states = np.asarray(reservoir.run(x_seq))
            yield readout.worker(states, y_seq)
            del states  # drop the state matrix before the next sequence runs

    readout.master(_gen())
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def predict_with_esn(annotator, corpus, return_raw=False, redo_transforms=False):
    """
    Run a trained ESN annotator over the MFCC features of a corpus.

    Parameters
    ----------
    annotator : Annotator
        Must carry an ``rpy_model`` attribute and expose ``transforms`` and
        ``vocab``.
    corpus : Corpus
        Corpus to annotate; its ``spec_directory`` is used as the transforms'
        output directory.
    return_raw : bool, default=False
        Also return the raw model outputs.
    redo_transforms : bool, default=False
        Forwarded to the transforms as ``redo``.

    Returns
    -------
    tuple
        ``(notated_paths, cls_preds, raw)``. ``cls_preds`` holds, per
        sequence, the vocabulary label with the highest output at each
        timestep. ``raw`` is the raw model output if ``return_raw`` is True,
        else None.

    Raises
    ------
    NotTrainedError
        If ``annotator`` has no ``rpy_model`` attribute.
    """
    if not hasattr(annotator, 'rpy_model'):
        raise NotTrainedError("Annotator does not contain a trained rpy_model.")

    spec_dir = corpus.spec_directory
    corpus = annotator.transforms(
        corpus,
        purpose="annotation",
        output_directory=spec_dir,
        redo=redo_transforms,
    )
    notated_paths, features = load_mfccs_for_annotation(corpus)
    outputs = annotator.rpy_model.run(features)

    # An array with fewer than 3 dimensions is a single sequence's output;
    # wrap it so the loop below always iterates over sequences.
    if isinstance(outputs, np.ndarray) and outputs.ndim < 3:
        outputs = [outputs]

    labelled = [maximum_a_posteriori(scores, classes=annotator.vocab) for scores in outputs]

    if return_raw:
        return notated_paths, labelled, outputs
    return notated_paths, labelled, None
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Author: Nathan Trouvain at 05/07/2023 <nathan.trouvain<at>inria.fr>
|
|
2
|
+
# Licence: BSD-3-Clause
|
|
3
|
+
# Copyright: Nathan Trouvain
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from ...timings import seconds_to_frames, seconds_to_audio
|
|
10
|
+
from ...utils.exceptions import NotTrainableError, MissingData
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger("canapy")


def load_mfccs_and_repeat_labels(corpus, purpose="training"):
    """
    Load the MFCC sequences of a corpus split with frame-level label matrices.

    Each (sequence, annotation) pair of the selected split yields one feature
    matrix and one label matrix. In the label matrix, every frame covered by
    an annotated segment holds that segment's ``encoded_label``.

    Parameters
    ----------
    corpus : Corpus
        Corpus whose ``dataset`` is filtered with ``query("train")`` or
        ``query("not train")``. Its ``data_resources["syn_mfcc"]`` table maps
        ``notated_path`` to ``feature_path``.
    purpose : {"training", "eval"}, default="training"
        "training" selects the train split, "eval" the rest of the dataset.

    Returns
    -------
    annotations : list
        Annotation identifier of each sequence.
    sequences : list
        Sequence identifier of each sequence.
    mfccs : list of np.ndarray
        Stored features (axis 1 treated as time) trimmed to the end of the
        last annotation, then transposed.
    labels : list of np.ndarray
        Arrays of shape (seq_end, n_classes). ``n_classes`` is the number of
        distinct labels in the whole dataset.

    Raises
    ------
    ValueError
        If ``purpose`` is neither "training" nor "eval".
    NotTrainableError
        If the selected split is empty.
    MissingData
        If 'syn_mfcc' features were never computed, or no feature file is
        listed for an annotated audio.
    KeyError
        If an .npz feature archive has no "feature" entry.
    """
    if purpose == "training":
        split = "train"
        missing_msg = "Training data was not provided, or corpus was "
    elif purpose == "eval":
        split = "not train"
        missing_msg = "Test data was not provided, or corpus was "
    else:
        raise ValueError("'purpose' should be either 'training' or 'eval'.")

    # load data (queried once; the copy is then augmented locally)
    df = corpus.dataset.query(split).copy()

    if len(df) == 0:
        raise NotTrainableError(
            missing_msg + "not properly divided between train and test data."
        )

    if "syn_mfcc" not in corpus.data_resources:
        raise MissingData(
            "'syn_mfcc' were never computed or can't be found in Corpus. "
            "Maybe provide and audio/spec directory to the Corpus ?"
        )

    mfcc_paths = corpus.data_resources["syn_mfcc"]

    sampling_rate = corpus.config.transforms.audio.sampling_rate
    hop_length = seconds_to_audio(
        corpus.config.transforms.audio.hop_length, sampling_rate
    )

    df["onset_spec"] = seconds_to_frames(df["onset_s"], hop_length, sampling_rate)
    df["offset_spec"] = seconds_to_frames(df["offset_s"], hop_length, sampling_rate)

    n_classes = len(corpus.dataset["label"].unique())

    mfccs = []
    labels = []
    sequences = []
    annotations = []

    # Group on the (sequence, annotation) pair itself. Concatenating the two
    # as strings can merge distinct pairs (e.g. "1" + "12.csv" == "11" + "2.csv").
    # sort=False keeps first-appearance order; dropna=False keeps missing ids;
    # observed=True avoids empty groups for categorical columns.
    grouped = df.groupby(
        ["sequence", "annotation"], sort=False, dropna=False, observed=True
    )
    for _, seq_annots in grouped:
        notated_audio = seq_annots["notated_path"].unique()[0]
        feature_paths = mfcc_paths.query("notated_path == @notated_audio")[
            "feature_path"
        ].unique()
        if len(feature_paths) == 0:
            raise MissingData(
                f"No 'syn_mfcc' feature file found for audio {notated_audio}."
            )
        notated_spec = feature_paths[0]

        seq_end = seq_annots["offset_spec"].iloc[-1]

        # MFCC may be stored as plain arrays or as .npz archives holding them
        # under a "feature" key (see transforms/commons/audio.py).
        loaded = np.load(notated_spec)
        if hasattr(loaded, "keys"):
            with loaded:  # close the archive file deterministically
                if "feature" not in loaded:
                    raise KeyError("No key named 'feature' in mfcc archive file.")
                mfcc = loaded["feature"].squeeze()
        else:
            mfcc = loaded

        if seq_end > mfcc.shape[1]:
            logger.warning(
                f"Found inconsistent sequence length: "
                f"audio {notated_audio} was converted to "
                f"{mfcc.shape[1]} timesteps but last annotation is at "
                f"timestep {seq_end}. Annotation will be trimmed."
            )

        seq_end = min(seq_end, mfcc.shape[1])

        mfcc = mfcc[:, :seq_end]

        # repeat labels along time axis
        repeated_labels = np.zeros((seq_end, n_classes))
        for row in seq_annots.itertuples():
            onset = row.onset_spec
            offset = min(row.offset_spec, seq_end)
            repeated_labels[onset:offset] = row.encoded_label

        mfccs.append(mfcc.T)
        labels.append(repeated_labels)
        sequences.append(seq_annots["sequence"].unique()[0])
        annotations.append(seq_annots["annotation"].unique()[0])

    return annotations, sequences, mfccs, labels
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def load_mfccs_for_annotation(corpus):
    """
    Load the MFCC features to annotate, one array per feature file.

    Parameters
    ----------
    corpus : Corpus
        Corpus with a ``data_resources["syn_mfcc"]`` table (columns
        ``feature_path`` and ``notated_path``) and a ``dataset`` with a
        ``notated_path`` column.

    Returns
    -------
    notated_paths : list
        Audio path of each loaded feature file, or the feature path itself
        when the table has no audio path for it.
    mfccs : list of np.ndarray
        Loaded features, transposed.

    Raises
    ------
    MissingData
        If 'syn_mfcc' features were never computed.
    KeyError
        If an .npz feature archive has no "feature" entry.

    Note
    ----
    When the corpus dataset lists audio paths, only features of those files
    are loaded. With an empty dataset, every feature file is loaded.
    """
    if "syn_mfcc" not in corpus.data_resources:
        raise MissingData(
            "'syn_mfcc' were never computed or can't be found in Corpus. "
            "Maybe provide and audio/spec directory to the Corpus ?"
        )

    # A set built once gives O(1) membership tests in the loop below.
    selected_paths = set(corpus.dataset.notated_path.unique().tolist())
    mfcc_paths = corpus.data_resources["syn_mfcc"]
    mfccs = []
    notated_paths = []
    for row in mfcc_paths.itertuples():
        spec_path = row.feature_path
        notated_path = spec_path if pd.isna(row.notated_path) else row.notated_path

        # We might not want to load all data, but only the subset
        # actually present in the corpus
        if selected_paths and notated_path not in selected_paths:
            continue

        # MFCC may be stored as plain arrays or as .npz archives holding them
        # under a "feature" key (see transforms/commons/audio.py).
        loaded = np.load(spec_path)
        if hasattr(loaded, "keys"):
            with loaded:  # close the archive file deterministically
                if "feature" not in loaded:
                    raise KeyError("No key named 'feature' in mfcc archive file.")
                mfcc = loaded["feature"].squeeze()
        else:
            mfcc = loaded

        mfccs.append(mfcc.T)
        notated_paths.append(notated_path)

    return notated_paths, mfccs
|