britekit 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (45)
  1. britekit/__about__.py +1 -1
  2. britekit/__init__.py +6 -2
  3. britekit/cli.py +6 -1
  4. britekit/commands/__init__.py +2 -1
  5. britekit/commands/_analyze.py +12 -10
  6. britekit/commands/_audioset.py +8 -8
  7. britekit/commands/_calibrate.py +8 -8
  8. britekit/commands/_ckpt_ops.py +6 -6
  9. britekit/commands/_db_add.py +12 -12
  10. britekit/commands/_db_delete.py +15 -15
  11. britekit/commands/_embed.py +4 -4
  12. britekit/commands/_ensemble.py +7 -7
  13. britekit/commands/_extract.py +158 -19
  14. britekit/commands/_find_dup.py +5 -5
  15. britekit/commands/_inat.py +4 -4
  16. britekit/commands/_init.py +1 -1
  17. britekit/commands/_pickle.py +7 -7
  18. britekit/commands/_plot.py +26 -26
  19. britekit/commands/_reextract.py +6 -6
  20. britekit/commands/_reports.py +41 -27
  21. britekit/commands/_search.py +12 -12
  22. britekit/commands/_train.py +6 -6
  23. britekit/commands/_tune.py +12 -12
  24. britekit/commands/_wav2mp3.py +2 -2
  25. britekit/commands/_xeno.py +7 -7
  26. britekit/commands/_youtube.py +3 -3
  27. britekit/core/analyzer.py +8 -8
  28. britekit/core/audio.py +14 -14
  29. britekit/core/data_module.py +2 -2
  30. britekit/core/plot.py +8 -8
  31. britekit/core/predictor.py +21 -21
  32. britekit/core/reextractor.py +6 -6
  33. britekit/core/util.py +8 -8
  34. britekit/models/base_model.py +1 -0
  35. britekit/occurrence_db/occurrence_data_provider.py +13 -13
  36. britekit/testing/{per_minute_tester.py → per_block_tester.py} +39 -36
  37. britekit/training_db/extractor.py +65 -30
  38. britekit/training_db/training_data_provider.py +1 -1
  39. britekit/training_db/training_db.py +97 -100
  40. britekit-0.1.4.dist-info/METADATA +299 -0
  41. {britekit-0.1.2.dist-info → britekit-0.1.4.dist-info}/RECORD +44 -44
  42. britekit-0.1.2.dist-info/METADATA +0 -290
  43. {britekit-0.1.2.dist-info → britekit-0.1.4.dist-info}/WHEEL +0 -0
  44. {britekit-0.1.2.dist-info → britekit-0.1.4.dist-info}/entry_points.txt +0 -0
  45. {britekit-0.1.2.dist-info → britekit-0.1.4.dist-info}/licenses/LICENSE.txt +0 -0
britekit/core/reextractor.py CHANGED
@@ -22,12 +22,12 @@ class Reextractor:
     updating the database.
 
     Args:
-        cfg_path (str, optional): Path to YAML file defining configuration overrides.
-        db_path (str, optional): Path to the training database. Defaults to cfg.train.training_db.
-        class_name (str, optional): Name of a specific class to reextract. If omitted, processes all classes.
-        classes_path (str, optional): Path to CSV file listing classes to reextract. Alternative to class_name.
-        check (bool): If True, only check that all recording paths are accessible without updating database.
-        spec_group (str): Spectrogram group name for storing the extracted spectrograms. Defaults to 'default'.
+        - cfg_path (str, optional): Path to YAML file defining configuration overrides.
+        - db_path (str, optional): Path to the training database. Defaults to cfg.train.training_db.
+        - class_name (str, optional): Name of a specific class to reextract. If omitted, processes all classes.
+        - classes_path (str, optional): Path to CSV file listing classes to reextract. Alternative to class_name.
+        - check (bool): If True, only check that all recording paths are accessible without updating database.
+        - spec_group (str): Spectrogram group name for storing the extracted spectrograms. Defaults to 'default'.
     """
 
     def __init__(
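
The hunk above only reformats the Args list, but it documents the full constructor surface. As a hedged construction sketch (the import path is inferred from britekit/core/reextractor.py; the argument values are hypothetical, and the run-style entry point is not shown in this diff):

    from britekit.core.reextractor import Reextractor

    reextractor = Reextractor(
        cfg_path=None,        # optional YAML config overrides
        db_path=None,         # defaults to cfg.train.training_db
        class_name=None,      # omit to process all classes
        classes_path=None,    # or a CSV listing classes to reextract
        check=True,           # only verify recording paths are accessible
        spec_group="default",
    )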
britekit/core/util.py CHANGED
@@ -166,7 +166,7 @@ def cfg_to_pure(obj: Any) -> JSONValue:
     str, int, float, bool) that can be safely serialized.
 
     Args:
-        obj: Any object to convert to JSON-serializable format
+        - obj: Any object to convert to JSON-serializable format
 
     Returns:
         JSON-serializable representation of the input object
@@ -284,8 +284,8 @@ def get_audio_files(path: str, short_names: bool = False) -> List[str]:
     Return list of audio files in the given directory.
 
     Args:
-        path (str): Directory path
-        short_names (bool): If true, return file names, else return full paths
+        - path (str): Directory path
+        - short_names (bool): If true, return file names, else return full paths
 
     Returns:
         List of audio files in the given directory
@@ -325,8 +325,8 @@ def get_file_lines(path: str, encoding: str = "utf-8") -> List[str]:
     and lines that start with #.
 
     Args:
-        path: Path to text file
-        encoding: File encoding (default: utf-8)
+        - path: Path to text file
+        - encoding: File encoding (default: utf-8)
 
     Returns:
         List of lines
@@ -354,7 +354,7 @@ def get_source_name(filename: str) -> str:
     Return a source name given a recording file name.
 
     Args:
-        filename: Recording file name
+        - filename: Recording file name
 
     Returns:
         Source name
@@ -390,7 +390,7 @@ def compress_spectrogram(spec) -> bytes:
     Compress a spectrogram in preparation for inserting into database.
 
     Args:
-        spec: Uncompressed spectrogram
+        - spec: Uncompressed spectrogram
 
     Returns:
         Compressed spectrogram
@@ -421,7 +421,7 @@ def expand_spectrogram(spec: bytes):
     Decompress a spectrogram, then convert from bytes to floats and reshape it.
 
     Args:
-        spec: Compressed spectrogram
+        - spec: Compressed spectrogram
 
     Returns:
         Uncompressed spectrogram
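
The compress/expand pair above is documented as a round trip to and from database BLOBs. A minimal usage sketch, assuming the functions are importable from britekit.core.util; the spectrogram shape and dtype here are made up:

    import numpy as np
    from britekit.core import util

    spec = np.random.rand(128, 256).astype(np.float32)  # hypothetical spectrogram
    blob = util.compress_spectrogram(spec)    # bytes, ready for a BLOB column
    restored = util.expand_spectrogram(blob)  # decompressed and reshaped back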
britekit/models/base_model.py CHANGED
@@ -252,6 +252,7 @@ class BaseModel(pl.LightningModule):
         }
 
     def on_save_checkpoint(self, checkpoint):
+        print("on_save_checkpoint")
         """Save model metadata to checkpoint."""
         if not hasattr(self, "identifier"):
             self.identifier = str(uuid.uuid4()).upper()
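
The added print aside, on_save_checkpoint is PyTorch Lightning's checkpoint hook. A minimal sketch of the lazy-identifier pattern visible above (the surrounding model and the checkpoint key name are assumptions):

    import uuid
    import pytorch_lightning as pl

    class SketchModel(pl.LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # Assign a stable identifier the first time a checkpoint is written.
            if not hasattr(self, "identifier"):
                self.identifier = str(uuid.uuid4()).upper()
            checkpoint["identifier"] = self.identifier  # assumed key name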
britekit/occurrence_db/occurrence_data_provider.py CHANGED
@@ -10,7 +10,7 @@ class OccurrenceDataProvider:
     you must call the refresh method.
 
     Args:
-        db (OccurrenceDatabase): The database object.
+        - db (OccurrenceDatabase): The database object.
     """
 
     def __init__(self, db: OccurrenceDatabase):
@@ -31,8 +31,8 @@ class OccurrenceDataProvider:
         Return county info for a given latitude/longitude, or None if not found.
 
         Args:
-            latitude (float): Latitude.
-            longitude (float): Longitude.
+            - latitude (float): Latitude.
+            - longitude (float): Longitude.
 
         Returns:
             County object, or None if not found.
@@ -54,8 +54,8 @@ class OccurrenceDataProvider:
         For each week, return the maximum of it and the adjacent weeks.
 
         Args:
-            county_code (str): County code
-            class_name (str): Class name
+            - county_code (str): County code
+            - class_name (str): Class name
 
         Returns:
             List of smoothed occurrence values.
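
A minimal sketch of the adjacent-week smoothing described above; the real method reads its values from the occurrence database, and whether it wraps at the year boundary is not shown in this diff, so this sketch clamps at the ends:

    def smooth_occurrences(values):
        # For each week, take the max of it and its immediate neighbors.
        smoothed = []
        for i in range(len(values)):
            window = values[max(0, i - 1):i + 2]
            smoothed.append(max(window))
        return smoothed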
@@ -75,8 +75,8 @@ class OccurrenceDataProvider:
         Return list of occurrence values for given county code and class name.
 
         Args:
-            county_code (str): County code
-            class_name (str): Class name
+            - county_code (str): County code
+            - class_name (str): Class name
 
         Returns:
             List of occurrence values.
@@ -97,9 +97,9 @@ class OccurrenceDataProvider:
         If area_weight = True, weight each county by its area.
 
         Args:
-            county_prefix (str): County code prefix
-            class_name (str): Class name
-            area_weight (bool, Optional): If true, weight by county area (default = False)
+            - county_prefix (str): County code prefix
+            - class_name (str): Class name
+            - area_weight (bool, Optional): If true, weight by county area (default = False)
 
         Returns:
             Numpy array of 48 average occurrence values (one per week, using 4-week months).
@@ -139,9 +139,9 @@ class OccurrenceDataProvider:
         county don't occur in the same week.
 
         Args:
-            county_prefix (str): County code prefix
-            class_name (str): Class name
-            area_weight (bool, Optional): If true, weight by county area (default = False)
+            - county_prefix (str): County code prefix
+            - class_name (str): Class name
+            - area_weight (bool, Optional): If true, weight by county area (default = False)
 
         Returns:
             Numpy average maximum occurrence value.
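
Both averaging methods above support optional area weighting. A hedged sketch of the weekly averaging (county_values and county_areas are hypothetical inputs that the real provider assembles from counties matching the code prefix):

    import numpy as np

    def average_occurrences(county_values, county_areas, area_weight=False):
        values = np.asarray(county_values, dtype=float)  # shape (n_counties, 48)
        weights = np.asarray(county_areas, dtype=float) if area_weight else None
        return np.average(values, axis=0, weights=weights)  # one value per week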
britekit/testing/{per_minute_tester.py → per_block_tester.py} RENAMED
@@ -17,16 +17,16 @@ class Annotation:
         return f"{self.class_code}: {self.start_time}-{self.end_time}"
 
 
-class PerMinuteTester(BaseTester):
+class PerBlockTester(BaseTester):
     """
-    Calculate test metrics when annotations are specified per minute. That is, for selected minutes of
-    each recording, a list of classes known to be present is given, and we are to calculate metrics for
-    those minutes only.
+    Calculate test metrics when annotations are specified per block, where a block is a fixed length, such
+    as a minute. That is, for selected blocks of each recording, a list of classes known to be present is given,
+    and we are to calculate metrics for those blocks only.
 
-    Annotations are read as a CSV with three columns: "recording", "minute", and "classes".
+    Annotations are read as a CSV with three columns: "recording", "block", and "classes".
     The recording column is the file name without the path or type suffix, e.g. "recording1".
-    The minute column contains 1 for the first minute, 2 for the second minute etc. and may
-    exclude some minutes. The classes column contains a comma-separated list of codes for the classes found in the corresponding minute.
+    The block column contains 1 for the first block, 2 for the second block etc. and may exclude some blocks.
+    The classes column contains a comma-separated list of codes for the classes found in the corresponding block.
     If your annotations are in a different format, simply convert to this format to use this script.
 
     Classifiers should be run with a threshold of 0, and with label merging disabled so segment-specific scores are retained.
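
For reference, an annotations file in the format described above might look like this (the class codes are made up):

    recording,block,classes
    recording1,1,"AMRO,BCCH"
    recording1,3,BCCH
    recording2,2,AMRO

Here recording1 is annotated for its first and third blocks only; unlisted blocks are simply excluded from scoring.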
@@ -37,6 +37,7 @@ class PerMinuteTester(BaseTester):
         label_dir (str): Directory containing Audacity labels.
         output_dir (str): Output directory, where reports will be written.
         threshold (float): Score threshold for precision/recall reporting.
+        block_size (int, optional): block_size in seconds (default=60).
         gen_pr_table (bool, optional): If true, generate a PR table, which may be slow (default = False).
     """
@@ -47,10 +48,11 @@
         label_dir: str,
         output_dir: str,
         threshold: float,
+        block_size: int = 60,
         gen_pr_table: bool = False,
     ):
         """
-        Initialize the PerMinuteTester.
+        Initialize the PerBlockTester.
 
         See class docstring for detailed parameter descriptions and usage information.
         """
@@ -60,6 +62,7 @@ class PerMinuteTester(BaseTester):
         self.label_dir = label_dir
         self.output_dir = output_dir
         self.threshold = threshold
+        self.block_size = block_size
         self.gen_pr_table = gen_pr_table
 
         self.cfg = get_config()
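
A hedged construction sketch using the renamed class. The import path is inferred from britekit/testing/per_block_tester.py, and only the parameters visible in this hunk appear; the constructor may take additional arguments that this diff does not show:

    from britekit.testing.per_block_tester import PerBlockTester

    tester = PerBlockTester(
        label_dir="labels",      # hypothetical paths
        output_dir="reports",
        threshold=0.8,
        block_size=30,           # new in 0.1.4; previously fixed at 60 seconds
        gen_pr_table=False,
    )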
@@ -119,8 +122,8 @@ class PerMinuteTester(BaseTester):
         Load annotation data from CSV file and process into internal format.
 
         This method reads a CSV file containing ground truth annotations where each row
-        represents a recording, minute, and its associated classes. The CSV should have columns:
-        "recording" (filename without path/extension), "minute" (minute number starting from 1),
+        represents a recording, block, and its associated classes. The CSV should have columns:
+        "recording" (filename without path/extension), "block" (block number starting from 1),
         and "classes" (comma-separated class codes).
 
         The method processes the annotations, handles class code mapping, filters out
@@ -151,10 +154,10 @@ class PerMinuteTester(BaseTester):
                 self.annotations[recording] = {}
                 self.segments_per_recording[recording] = []
 
-            minute = row["minute"]
-            if minute not in self.annotations[recording]:
-                self.annotations[recording][minute] = []
-                self.segments_per_recording[recording].append(minute - 1)
+            block = row["block"]
+            if block not in self.annotations[recording]:
+                self.annotations[recording][block] = []
+                self.segments_per_recording[recording].append(block - 1)
 
             input_class_list = []
             for code in row["classes"].split(","):
@@ -176,7 +179,7 @@ class PerMinuteTester(BaseTester):
                     continue  # exclude from saved annotations
 
                 if class_code:
-                    self.annotations[recording][minute].append(class_code)
+                    self.annotations[recording][block].append(class_code)
                     self.annotated_class_set.add(class_code)
 
         self.annotated_classes = sorted(list(self.annotated_class_set))
@@ -192,12 +195,12 @@ class PerMinuteTester(BaseTester):
 
         This method evaluates precision and recall metrics at different threshold values
         (0.01 to 1.00 in 0.01 increments) to create comprehensive precision-recall curves.
-        It calculates both per-minute granularity metrics and per-second granularity metrics.
+        It calculates both per-block granularity metrics and per-second granularity metrics.
 
         Returns:
             dict: Dictionary containing precision-recall data with keys:
             - annotated_thresholds: List of threshold values for annotated classes
-            - annotated_precisions_minutes: List of precision values (minutes) for annotated classes
+            - annotated_precisions_blocks: List of precision values (blocks) for annotated classes
             - annotated_precisions_seconds: List of precision values (seconds) for annotated classes
             - annotated_recalls: List of recall values for annotated classes
             - trained_thresholds: List of threshold values for trained classes
@@ -219,20 +222,20 @@ class PerMinuteTester(BaseTester):
 
         # use the looping method so we get per_second precision
         thresholds = []
-        recall_annotated, precision_annotated_minutes, precision_annotated_seconds = (
+        recall_annotated, precision_annotated_blocks, precision_annotated_seconds = (
            [],
            [],
            [],
        )
-        recall_trained, precision_trained_minutes = [], []
+        recall_trained, precision_trained_blocks = [], []
         for threshold in np.arange(0.01, 1.01, 0.01):
             info = self.get_precision_recall(threshold)
             thresholds.append(threshold)
             recall_annotated.append(info["recall_annotated"])
-            precision_annotated_minutes.append(info["precision_annotated"])
+            precision_annotated_blocks.append(info["precision_annotated"])
             precision_annotated_seconds.append(info["precision_secs"])
             recall_trained.append(info["recall_trained"])
-            precision_trained_minutes.append(info["precision_trained"])
+            precision_trained_blocks.append(info["precision_trained"])
             logging.info(
                 f"\rPercent complete: {int(threshold * 100)}%", end="", flush=True
             )
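
The loop above sweeps thresholds from 0.01 to 1.00 in 0.01 steps. A self-contained sketch of the same sweep pattern over multi-label arrays (the micro-averaged precision/recall definitions here are an assumption, not BriteKit's exact metric code):

    import numpy as np

    def sweep_thresholds(y_true, y_score):
        # y_true, y_score: (n_blocks, n_classes) arrays of 0/1 labels and scores
        thresholds, precisions, recalls = [], [], []
        for t in np.arange(0.01, 1.01, 0.01):
            y_pred = (y_score >= t).astype(int)
            tp = int(np.sum((y_pred == 1) & (y_true == 1)))
            fp = int(np.sum((y_pred == 1) & (y_true == 0)))
            fn = int(np.sum((y_pred == 0) & (y_true == 1)))
            thresholds.append(float(t))
            precisions.append(tp / (tp + fp) if tp + fp else 1.0)
            recalls.append(tp / (tp + fn) if tp + fn else 0.0)
        return thresholds, precisions, recalls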
@@ -240,12 +243,12 @@ class PerMinuteTester(BaseTester):
         logging.info("")
         pr_table_dict = {}
         pr_table_dict["annotated_thresholds"] = thresholds
-        pr_table_dict["annotated_precisions_minutes"] = precision_annotated_minutes
+        pr_table_dict["annotated_precisions_blocks"] = precision_annotated_blocks
         pr_table_dict["annotated_precisions_seconds"] = precision_annotated_seconds
         pr_table_dict["annotated_recalls"] = recall_annotated
 
         pr_table_dict["trained_thresholds"] = thresholds
-        pr_table_dict["trained_precisions"] = precision_trained_minutes
+        pr_table_dict["trained_precisions"] = precision_trained_blocks
         pr_table_dict["trained_recalls"] = recall_trained
 
         # use this method for more granular results without per_second precision
@@ -303,7 +306,7 @@ class PerMinuteTester(BaseTester):
         if self.gen_pr_table:
             # calculate and output precision/recall per threshold
             threshold_annotated = self.pr_table_dict["annotated_thresholds"]
-            precision_annotated = self.pr_table_dict["annotated_precisions_minutes"]
+            precision_annotated = self.pr_table_dict["annotated_precisions_blocks"]
             precision_annotated_secs = self.pr_table_dict[
                 "annotated_precisions_seconds"
             ]
@@ -401,13 +404,13 @@ class PerMinuteTester(BaseTester):
             )
             rpt.append(f" For threshold = {self.threshold}:\n")
             rpt.append(
-                f" Precision (minutes) = {100 * self.details_dict['precision_annotated']:.2f}%\n"
+                f" Precision (blocks) = {100 * self.details_dict['precision_annotated']:.2f}%\n"
             )
             rpt.append(
                 f" Precision (seconds) = {100 * self.details_dict['precision_secs']:.2f}%\n"
             )
             rpt.append(
-                f" Recall (minutes) = {100 * self.details_dict['recall_annotated']:.2f}%\n"
+                f" Recall (blocks) = {100 * self.details_dict['recall_annotated']:.2f}%\n"
             )
 
         rpt.append("\n")
@@ -420,10 +423,10 @@ class PerMinuteTester(BaseTester):
             )
             rpt.append(f" For threshold = {self.threshold}:\n")
             rpt.append(
-                f" Precision (minutes) = {100 * self.details_dict['precision_trained']:.2f}%\n"
+                f" Precision (blocks) = {100 * self.details_dict['precision_trained']:.2f}%\n"
             )
             rpt.append(
-                f" Recall (minutes) = {100 * self.details_dict['recall_trained']:.2f}%\n"
+                f" Recall (blocks) = {100 * self.details_dict['recall_trained']:.2f}%\n"
             )
         logging.info("")
         with open(os.path.join(self.output_dir, "summary_report.txt"), "w") as summary:
@@ -551,7 +554,7 @@ class PerMinuteTester(BaseTester):
 
         # initialize y_true and y_pred and save them as CSV files
         logging.info("Initializing")
-        self.get_labels(self.label_dir, segment_len=60, overlap=0)
+        self.get_labels(self.label_dir, segment_len=self.block_size, overlap=0)
         self.get_annotations()
         self._init_y_true()
         self.init_y_pred(segments_per_recording=self.segments_per_recording)
@@ -573,7 +576,7 @@ class PerMinuteTester(BaseTester):
 
     def _init_y_true(self):
         """
-        Create a dataframe representing the ground truth data, with recordings segmented into 1-minute segments
+        Create a dataframe representing the ground truth data, with recordings segmented into 1-block segments
         """
         import pandas as pd
 
@@ -582,11 +585,11 @@ class PerMinuteTester(BaseTester):
         self.recordings = []  # base class needs array with recording per row
         rows = []
         for recording in sorted(self.annotations.keys()):
-            for minute in sorted(self.annotations[recording].keys()):
+            for block in sorted(self.annotations[recording].keys()):
                 self.recordings.append(recording)
-                row = [f"{recording}-{minute - 1}"]
+                row = [f"{recording}-{block - 1}"]
                 row.extend([0 for class_code in self.trained_classes])
-                for class_code in self.annotations[recording][minute]:
+                for class_code in self.annotations[recording][block]:
                     if class_code in self.trained_class_indexes:
                         row[self.trained_class_indexes[class_code] + 1] = 1
 
@@ -618,8 +621,8 @@ class PerMinuteTester(BaseTester):
 
         df = pd.DataFrame()
         df["threshold"] = pd.Series(threshold)
-        df["recall (minutes)"] = pd.Series(recall)
-        df["precision (minutes)"] = pd.Series(precision)
+        df["recall (blocks)"] = pd.Series(recall)
+        df["precision (blocks)"] = pd.Series(precision)
         if precision_secs is not None:
             df["precision (seconds)"] = pd.Series(precision_secs)
 
@@ -631,7 +634,7 @@ class PerMinuteTester(BaseTester):
 
         plt.clf()
         plt.plot(recall, label="Recall")
-        plt.plot(precision, label="Precision (Minutes)")
+        plt.plot(precision, label="Precision (blocks)")
         if precision_secs is not None:
             plt.plot(precision_secs, label="Precision (Seconds)")
 
britekit/training_db/extractor.py CHANGED
@@ -109,13 +109,45 @@ class Extractor:
 
         return offsets_per_file
 
+    def _insert_by_dict(self, recording_dir, destination_dir, offsets_per_file):
+        """
+        Given a recording directory and a dict from recording stems to offsets,
+        insert the corresponding spectrograms.
+        """
+        num_inserted = 0
+        recording_paths = util.get_audio_files(recording_dir)
+        for recording_dir in recording_paths:
+            filename = Path(recording_dir).stem
+            if filename not in offsets_per_file:
+                continue
+
+            if destination_dir is not None:
+                dest_path = os.path.join(destination_dir, Path(recording_dir).name)
+                if not os.path.exists(dest_path):
+                    shutil.copy(recording_dir, dest_path)
+
+                recording_dir = dest_path
+
+            logging.info(f"Processing {recording_dir}")
+            try:
+                self.audio.load(recording_dir)
+            except Exception as e:
+                logging.error(f"Caught exception: {e}")
+                continue
+
+            num_inserted += self.insert_spectrograms(
+                recording_dir, offsets_per_file[filename]
+            )
+
+        return num_inserted
+
     def insert_spectrograms(self, recording_path, offsets):
         """
         Insert a spectrogram at each of the given offsets of the specified file.
 
         Args:
-            recording_path (str): Path to audio recording.
-            offsets (list[float]): List of offsets, where each represents number of seconds to start of spectrogram.
+            - recording_path (str): Path to audio recording.
+            - offsets (list[float]): List of offsets, where each represents number of seconds to start of spectrogram.
 
         Returns:
             Number of spectrograms inserted.
@@ -156,7 +188,7 @@ class Extractor:
         Extract spectrograms for all recordings in the given directory.
 
         Args:
-            dir_path (str): Directory containing recordings.
+            - dir_path (str): Directory containing recordings.
 
         Returns:
             Number of spectrograms inserted.
@@ -187,45 +219,48 @@ class Extractor:
 
         return num_inserted
 
-    def extract_by_image(
-        self, rec_dir: str, spec_dir: str, dest_dir: Optional[str] = None
+    def extract_by_csv(
+        self, rec_dir: str, csv_path: str, dest_dir: Optional[str] = None
     ):
         """
         Extract spectrograms that match names of spectrogram images in a given directory.
         Typically the spectrograms were generated using the 'search' or 'plot-db' commands.
 
         Args:
-            rec_dir (str): Directory containing recordings.
-            spec_dir (str): Directory containing spectrogram images.
-            dest_dir (str, optional): Optionally copy used recordings to this directory.
+            - rec_dir (str): Directory containing recordings.
+            - csv_path (str): Path to CSV file containing two columns (recording and offset) to identify segments to extract.
+            - dest_dir (str, optional): Optionally copy used recordings to this directory.
 
         Returns:
             Number of spectrograms inserted.
         """
-        offsets_per_file = self._process_image_dir(spec_dir)
-        num_inserted = 0
-        recording_paths = util.get_audio_files(rec_dir)
-        for recording_path in recording_paths:
-            filename = Path(recording_path).stem
-            if filename not in offsets_per_file:
-                continue
+        import pandas as pd
 
-            if dest_dir is not None:
-                dest_path = os.path.join(dest_dir, Path(recording_path).name)
-                if not os.path.exists(dest_path):
-                    shutil.copy(recording_path, dest_path)
+        df = pd.read_csv(csv_path)
+        offsets_per_file: dict[str, list] = {}
+        for i, row in df.iterrows():
+            recording = row["recording"]
+            if recording not in offsets_per_file:
+                offsets_per_file[recording] = []
 
-            recording_path = dest_path
+            offsets_per_file[recording].append(row["offset"])
 
-        logging.info(f"Processing {recording_path}")
-        try:
-            self.audio.load(recording_path)
-        except Exception as e:
-            logging.error(f"Caught exception: {e}")
-            continue
+        return self._insert_by_dict(rec_dir, dest_dir, offsets_per_file)
 
-        num_inserted += self.insert_spectrograms(
-            recording_path, offsets_per_file[filename]
-        )
+    def extract_by_image(
+        self, rec_dir: str, spec_dir: str, dest_dir: Optional[str] = None
+    ):
+        """
+        Extract spectrograms that match names of spectrogram images in a given directory.
+        Typically the spectrograms were generated using the 'search' or 'plot-db' commands.
 
-        return num_inserted
+        Args:
+            - rec_dir (str): Directory containing recordings.
+            - spec_dir (str): Directory containing spectrogram images.
+            - dest_dir (str, optional): Optionally copy used recordings to this directory.
+
+        Returns:
+            Number of spectrograms inserted.
+        """
+        offsets_per_file = self._process_image_dir(spec_dir)
+        return self._insert_by_dict(rec_dir, dest_dir, offsets_per_file)
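
The new extract_by_csv reads segment locations from a CSV instead of inferring them from image file names. A hedged usage sketch (Extractor's constructor is not shown in this diff, so `extractor` is assumed to be an already-configured instance; the file names are hypothetical):

    segments.csv:
        recording,offset
        XC100001,12.5
        XC100001,47.0
        meadow_morning,3.0

    num = extractor.extract_by_csv(
        rec_dir="recordings",     # scanned with util.get_audio_files
        csv_path="segments.csv",  # "recording" is the file stem; "offset" is seconds
        dest_dir=None,            # or a directory to copy used recordings into
    )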
britekit/training_db/training_data_provider.py CHANGED
@@ -8,7 +8,7 @@ class TrainingDataProvider:
     Data access layer on top of TrainingDatabase.
 
     Args:
-        db (TrainingDatabase): The database object.
+        - db (TrainingDatabase): The database object.
     """
 
     def __init__(self, db: TrainingDatabase):