sonusai 0.17.3__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,15 @@ def initialize_db(location: str, test: bool = False) -> None:
37
37
  id INTEGER PRIMARY KEY NOT NULL,
38
38
  name TEXT NOT NULL,
39
39
  samples INTEGER NOT NULL,
40
- level_type TEXT NOT NULL)
40
+ level_type TEXT NOT NULL,
41
+ speaker_id INTEGER,
42
+ FOREIGN KEY(speaker_id) REFERENCES speaker (id))
43
+ """)
44
+
45
+ con.execute("""
46
+ CREATE TABLE speaker (
47
+ id INTEGER PRIMARY KEY NOT NULL,
48
+ parent TEXT NOT NULL)
41
49
  """)
42
50
 
43
51
  con.execute("""
@@ -58,13 +66,9 @@ def initialize_db(location: str, test: bool = False) -> None:
58
66
  seed INTEGER NOT NULL,
59
67
  truth_mutex BOOLEAN NOT NULL,
60
68
  truth_reduction_function TEXT NOT NULL,
61
- mixid_width INTEGER NOT NULL)
62
- """)
63
-
64
- con.execute("""
65
- CREATE TABLE asr_manifest (
66
- id INTEGER PRIMARY KEY NOT NULL,
67
- manifest TEXT NOT NULL)
69
+ mixid_width INTEGER NOT NULL,
70
+ speaker_metadata_tiers TEXT NOT NULL,
71
+ textgrid_metadata_tiers TEXT NOT NULL)
68
72
  """)
69
73
 
70
74
  con.execute("""
@@ -155,8 +159,8 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
155
159
  con = db_connection(location=location, readonly=False, test=test)
156
160
  con.execute("""
157
161
  INSERT INTO top (version, class_balancing, feature, noise_mix_mode, num_classes,
158
- seed, truth_mutex, truth_reduction_function, mixid_width)
159
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
162
+ seed, truth_mutex, truth_reduction_function, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
163
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
160
164
  """, (
161
165
  1,
162
166
  config['class_balancing'],
@@ -166,19 +170,9 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
166
170
  config['seed'],
167
171
  truth_mutex,
168
172
  config['truth_reduction_function'],
169
- 0))
170
- con.commit()
171
- con.close()
172
-
173
-
174
- def populate_asr_manifest_table(location: str, config: dict, test: bool = False) -> None:
175
- """Populate asr_manifest table
176
- """
177
- from .mixdb import db_connection
178
-
179
- con = db_connection(location=location, readonly=False, test=test)
180
- con.executemany("INSERT INTO asr_manifest (manifest) VALUES (?)",
181
- [(item,) for item in config['asr_manifest']])
173
+ 0,
174
+ '',
175
+ ''))
182
176
  con.commit()
183
177
  con.close()
184
178
 
@@ -242,36 +236,51 @@ def populate_spectral_mask_table(location: str, config: dict, test: bool = False
242
236
  def populate_target_file_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
243
237
  """Populate target file table
244
238
  """
239
+ import json
240
+ from pathlib import Path
241
+
245
242
  from .mixdb import db_connection
246
243
 
247
- con = db_connection(location=location, readonly=False, test=test)
244
+ _populate_truth_setting_table(location, target_files, test)
245
+ _populate_speaker_table(location, target_files, test)
248
246
 
249
- # Populate truth_setting table
250
- truth_settings: list[str] = []
251
- for truth_setting in [truth_setting for target_file in target_files
252
- for truth_setting in target_file.truth_settings]:
253
- ts = truth_setting.to_json()
254
- if ts not in truth_settings:
255
- truth_settings.append(ts)
256
- con.executemany("INSERT INTO truth_setting (setting) VALUES (?)",
257
- [(item,) for item in truth_settings])
247
+ con = db_connection(location=location, readonly=False, test=test)
258
248
 
259
- # Populate target_file table
260
249
  cur = con.cursor()
250
+ textgrid_metadata_tiers: set[str] = set()
261
251
  for target_file in target_files:
252
+ # Get TextGrid tiers for target file and add to collection
253
+ tiers = _get_textgrid_tiers_from_target_file(target_file.name)
254
+ for tier in tiers:
255
+ textgrid_metadata_tiers.add(tier)
256
+
257
+ # Get truth settings for target file
262
258
  truth_setting_ids: list[int] = []
263
259
  for truth_setting in target_file.truth_settings:
264
260
  cur.execute("SELECT truth_setting.id FROM truth_setting WHERE ? = truth_setting.setting",
265
261
  (truth_setting.to_json(),))
266
262
  truth_setting_ids.append(cur.fetchone()[0])
267
263
 
268
- cur.execute("INSERT INTO target_file (name, samples, level_type) VALUES (?, ?, ?)",
269
- (target_file.name, target_file.samples, target_file.level_type))
264
+ # Get speaker_id for target file
265
+ cur.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
266
+ (Path(target_file.name).parent.as_posix(),))
267
+ result = cur.fetchone()
268
+ speaker_id = None
269
+ if result is not None:
270
+ speaker_id = result[0]
271
+
272
+ # Add entry
273
+ cur.execute("INSERT INTO target_file (name, samples, level_type, speaker_id) VALUES (?, ?, ?, ?)",
274
+ (target_file.name, target_file.samples, target_file.level_type, speaker_id))
270
275
  target_file_id = cur.lastrowid
271
276
  for truth_setting_id in truth_setting_ids:
272
277
  cur.execute("INSERT INTO target_file_truth_setting (target_file_id, truth_setting_id) VALUES (?, ?)",
273
278
  (target_file_id, truth_setting_id))
274
279
 
280
+ # Update textgrid_metadata_tiers in the top table
281
+ con.execute("UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
282
+ (json.dumps(sorted(textgrid_metadata_tiers)), 1))
283
+
275
284
  con.commit()
276
285
  con.close()
277
286
 
@@ -304,8 +313,8 @@ def populate_impulse_response_file_table(location: str, impulse_response_files:
304
313
  def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
305
314
  """Update the mixid width
306
315
  """
307
- from sonusai.utils import max_text_width
308
316
  from .mixdb import db_connection
317
+ from sonusai.utils import max_text_width
309
318
 
310
319
  con = db_connection(location=location, readonly=False, test=test)
311
320
  con.execute("UPDATE top SET mixid_width=? WHERE top.id = ?", (max_text_width(num_mixtures), 1))
@@ -367,8 +376,8 @@ def update_mixture(mixdb: MixtureDatabase,
367
376
  """
368
377
  from .audio import get_next_noise
369
378
  from .augmentation import apply_gain
370
- from .helpers import get_target
371
379
  from .datatypes import GenMixData
380
+ from .helpers import get_target
372
381
 
373
382
  mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
374
383
 
@@ -917,3 +926,92 @@ def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
917
926
 
918
927
  return ([UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config['snrs']] +
919
928
  [UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config['random_snrs']])
929
+
930
+
931
def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
    """Get the sorted tier names of the TextGrid file that accompanies a target file

    The TextGrid file is looked up at the expanded target file path with its suffix
    replaced by '.TextGrid'.

    :param target_file: Target file name
    :return: Sorted list of tier names; empty list if no such TextGrid file exists
    """
    from pathlib import Path

    from praatio import textgrid

    from sonusai.mixture import tokenized_expand

    expanded_name = tokenized_expand(target_file)[0]
    candidate = Path(expanded_name).with_suffix('.TextGrid')

    if candidate.exists():
        grid = textgrid.openTextgrid(str(candidate), includeEmptyIntervals=False)
        return sorted(grid.tierNames)

    return []
945
+
946
+
947
def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
    """Populate speaker table

    For every distinct parent directory of the target files that contains a 'speaker.yml',
    the YAML mapping is loaded and stored as one row of the speaker table. The union of all
    keys found becomes additional TEXT columns of the table; a speaker lacking a key gets
    NULL in that column. The resulting extra column names are written to
    top.speaker_metadata_tiers as a JSON list.

    :param location: Passed through to db_connection
    :param target_files: Target files whose parent directories are searched for speaker.yml
    :param test: Passed through to db_connection
    """
    import json
    from pathlib import Path

    import yaml

    from .mixdb import db_connection
    from .tokenized_shell_vars import tokenized_expand

    def quote_identifier(name: str) -> str:
        # Column names cannot be bound as SQL parameters and come from user-supplied YAML keys,
        # so quote them as SQL identifiers (embedded double quotes are doubled).
        return '"' + str(name).replace('"', '""') + '"'

    # Load speaker.yml for each parent directory that has one (sorted for a stable row order)
    speakers: dict[Path, dict[str, str]] = {}
    for parent in sorted({Path(target_file.name).parent for target_file in target_files}):
        speaker_file = Path(tokenized_expand(parent / 'speaker.yml')[0])
        if not speaker_file.exists():
            continue
        with open(speaker_file, 'r') as f:
            # An empty speaker.yml loads as None; treat it as a speaker without metadata
            speakers[parent] = yaml.safe_load(f) or {}

    # Determine columns for speaker table
    new_columns = sorted({column for data in speakers.values() for column in data})

    con = db_connection(location=location, readonly=False, test=test)
    try:
        for new_column in new_columns:
            con.execute(f'ALTER TABLE speaker ADD COLUMN {quote_identifier(new_column)} TEXT')

        # Populate speaker table
        speaker_rows = [(parent.as_posix(), *(data.get(column) for column in new_columns))
                        for parent, data in speakers.items()]

        column_ids = ', '.join(['parent', *(quote_identifier(column) for column in new_columns)])
        column_values = ', '.join(['?'] * (len(new_columns) + 1))
        con.executemany(f'INSERT INTO speaker ({column_ids}) VALUES ({column_values})', speaker_rows)

        con.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")

        # Update speaker_metadata_tiers in the top table
        tiers = [description[0] for description in con.execute("SELECT * FROM speaker").description
                 if description[0] not in ('id', 'parent')]
        con.execute("UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?", (json.dumps(tiers), 1))

        con.commit()
    finally:
        # Release the connection even when a statement above raises
        con.close()
997
+
998
+
999
def _populate_truth_setting_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
    """Populate truth_setting table

    Inserts one row per distinct truth setting (compared by its JSON form) found across
    all target files, in first-seen order.

    :param location: Passed through to db_connection
    :param target_files: Target files whose truth settings are collected
    :param test: Passed through to db_connection
    """
    from .mixdb import db_connection

    # dict.fromkeys drops duplicates in linear time while keeping first-seen order
    truth_settings = dict.fromkeys(truth_setting.to_json()
                                   for target_file in target_files
                                   for truth_setting in target_file.truth_settings)

    con = db_connection(location=location, readonly=False, test=test)
    try:
        con.executemany("INSERT INTO truth_setting (setting) VALUES (?)",
                        [(item,) for item in truth_settings])
        con.commit()
    finally:
        # Release the connection even when the insert raises
        con.close()
@@ -1,5 +1,7 @@
1
1
  from typing import Any
2
+ from typing import Optional
2
3
 
4
+ from praatio.utilities.constants import Interval
3
5
  from pyaaware import ForwardTransform
4
6
  from pyaaware import InverseTransform
5
7
 
@@ -18,6 +20,7 @@ from sonusai.mixture.datatypes import Mixture
18
20
  from sonusai.mixture.datatypes import NoiseFile
19
21
  from sonusai.mixture.datatypes import NoiseFiles
20
22
  from sonusai.mixture.datatypes import Segsnr
23
+ from sonusai.mixture.datatypes import SpeechMetadata
21
24
  from sonusai.mixture.datatypes import Target
22
25
  from sonusai.mixture.datatypes import TargetFiles
23
26
  from sonusai.mixture.datatypes import Targets
@@ -123,6 +126,35 @@ def write_mixture_data(mixdb: MixtureDatabase,
123
126
  f.create_dataset(name=item[0], data=item[1])
124
127
 
125
128
 
129
def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
    """Collect the speech metadata of every target in the given mixture

    :param mixdb: Mixture database
    :param mixture: Mixture record
    :return: One dictionary per target (in target order), keyed by tier name. Speaker tiers
             come from mixdb.speaker(); TextGrid tiers come from the target's TextGrid file.
    """
    all_metadata: list[dict[str, SpeechMetadata]] = []

    for target in mixture.targets:
        tiers: dict[str, SpeechMetadata] = {
            name: mixdb.speaker(mixdb.target_file(target.file_id).speaker_id, name)
            for name in mixdb.speaker_metadata_tiers
        }

        for name in mixdb.textgrid_metadata_tiers:
            value = get_textgrid_tier_from_target_file(mixdb.target_file(target.file_id).name, name)

            if not isinstance(value, list):
                tiers[name] = value
                continue

            # A list holds Interval entries: when the target has a tempo augmentation,
            # divide each start and end by the tempo; otherwise keep the entry as is
            tiers[name] = [
                interval if target.augmentation.tempo is None
                else Interval(interval.start / target.augmentation.tempo,
                              interval.end / target.augmentation.tempo,
                              interval.label)
                for interval in value
            ]

        all_metadata.append(tiers)

    return all_metadata
156
+
157
+
126
158
  def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
127
159
  """Create a string of metadata for a Mixture
128
160
 
@@ -131,6 +163,7 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
131
163
  :return: String of metadata
132
164
  """
133
165
  metadata = ''
166
+ speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
134
167
  for mi, target in enumerate(mixture.targets):
135
168
  target_file = mixdb.target_file(target.file_id)
136
169
  target_augmentation = target.augmentation
@@ -147,7 +180,8 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
147
180
  metadata += f'target {mi} truth index {tsi}: {truth_settings[tsi].index}\n'
148
181
  metadata += f'target {mi} truth function {tsi}: {truth_settings[tsi].function}\n'
149
182
  metadata += f'target {mi} truth config {tsi}: {truth_settings[tsi].config}\n'
150
- metadata += f'target {mi} asr: {mixdb.target_asr_data(target.file_id)}\n'
183
+ for key in speech_metadata[mi].keys():
184
+ metadata += f'target {mi} speech {key}: {speech_metadata[mi][key]}\n'
151
185
  noise = mixdb.noise_file(mixture.noise.file_id)
152
186
  noise_augmentation = mixture.noise.augmentation
153
187
  metadata += f'noise name: {noise.name}\n'
@@ -582,3 +616,26 @@ def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentati
582
616
 
583
617
  return estimate_augmented_length_from_length(length=noise_file.samples,
584
618
  tempo=noise_augmentation.tempo)
619
+
620
+
621
def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> Optional[SpeechMetadata]:
    """Get the contents of one TextGrid tier for a target file

    The TextGrid file is looked up at the expanded target file path with its suffix
    replaced by '.TextGrid'.

    :param target_file: Target file name
    :param tier: Name of the tier to read
    :return: None if the TextGrid file does not exist, the tier is not in the file, or the
             tier has no entries; the label of the entry if the tier has exactly one entry;
             otherwise the list of entries
    """
    from pathlib import Path

    from praatio import textgrid

    from .tokenized_shell_vars import tokenized_expand

    textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix('.TextGrid')
    if not textgrid_file.exists():
        return None

    tg = textgrid.openTextgrid(str(textgrid_file), includeEmptyIntervals=False)

    if tier not in tg.tierNames:
        return None

    entries = tg.getTier(tier).entries
    if not entries:
        # The file is opened with includeEmptyIntervals=False, so a tier may end up with
        # no entries; report that as "no data" instead of failing on entries[0]
        return None

    if len(entries) > 1:
        return list(entries)

    return entries[0].label
@@ -7,35 +7,82 @@ def calculate_snr_f_statistics(truth_f: np.ndarray) -> tuple[np.ndarray, np.ndar
7
7
  For now, includes mean and standard deviation of the raw values (usually energy)
8
8
  and mean and standard deviation of the dB values (10 * log10).
9
9
  """
10
- classes = truth_f.shape[1]
10
+ return (
11
+ calculate_snr_mean(truth_f),
12
+ calculate_snr_std(truth_f),
13
+ calculate_snr_db_mean(truth_f),
14
+ calculate_snr_db_std(truth_f),
15
+ )
11
16
 
12
- snr_mean = np.zeros(classes, dtype=np.float32)
13
- snr_std = np.zeros(classes, dtype=np.float32)
14
- snr_db_mean = np.zeros(classes, dtype=np.float32)
15
- snr_db_std = np.zeros(classes, dtype=np.float32)
16
17
 
17
- for c in range(classes):
18
def calculate_snr_mean(truth_f: np.ndarray) -> np.ndarray:
    """Calculate mean of snr_f truth data.

    The mean is taken per column (second axis) over the finite values only.

    :param truth_f: Truth data, indexed as truth_f[:, column]
    :return: float32 array with one mean per column; -inf where a column has no finite values
    """
    num_columns = truth_f.shape[1]

    # Start from the "no finite data" marker and overwrite columns that do have data
    result = np.full(num_columns, -np.inf, dtype=np.float32)

    for idx in range(num_columns):
        column = truth_f[:, idx]
        finite = column[np.isfinite(column)].astype(np.double)
        if finite.size > 0:
            result[idx] = np.mean(finite)

    return result
32
+
33
+
34
def calculate_snr_std(truth_f: np.ndarray) -> np.ndarray:
    """Calculate standard deviation of snr_f truth data.

    The sample standard deviation (ddof=1) is taken per column (second axis) over the
    finite values only.

    :param truth_f: Truth data, indexed as truth_f[:, column]
    :return: float32 array with one value per column; -inf where a column has no finite values
    """
    num_columns = truth_f.shape[1]

    # Start from the "no finite data" marker and overwrite columns that do have data
    result = np.full(num_columns, -np.inf, dtype=np.float32)

    for idx in range(num_columns):
        column = truth_f[:, idx]
        finite = column[np.isfinite(column)].astype(np.double)
        if finite.size > 0:
            result[idx] = np.std(finite, ddof=1)

    return result
48
+
49
+
50
def calculate_snr_db_mean(truth_f: np.ndarray) -> np.ndarray:
    """Calculate dB mean of snr_f truth data.

    Per column (second axis), the finite values are converted to dB (10 * log10) and the
    mean of the finite dB values is taken.

    :param truth_f: Truth data, indexed as truth_f[:, column]
    :return: float32 array with one value per column; -inf where a column has no finite dB values
    """
    num_columns = truth_f.shape[1]

    # Start from the "no finite data" marker and overwrite columns that do have data
    result = np.full(num_columns, -np.inf, dtype=np.float32)

    for idx in range(num_columns):
        column = truth_f[:, idx]
        finite = column[np.isfinite(column)].astype(np.double)

        # Masked log10 turns invalid inputs into -inf, which the finite filter then removes
        decibels = 10 * np.ma.log10(finite).filled(-np.inf)
        decibels = decibels[np.isfinite(decibels)]

        if decibels.size > 0:
            result[idx] = np.mean(decibels)

    return result
67
+
68
+
69
def calculate_snr_db_std(truth_f: np.ndarray) -> np.ndarray:
    """Calculate dB standard deviation of snr_f truth data.

    Per column (second axis), the finite values are converted to dB (10 * log10) and the
    sample standard deviation (ddof=1) of the finite dB values is taken.

    :param truth_f: Truth data, indexed as truth_f[:, column]
    :return: float32 array with one value per column; -inf where a column has no finite dB values
    """
    num_columns = truth_f.shape[1]

    # Start from the "no finite data" marker and overwrite columns that do have data
    result = np.full(num_columns, -np.inf, dtype=np.float32)

    for idx in range(num_columns):
        column = truth_f[:, idx]
        finite = column[np.isfinite(column)].astype(np.double)

        # Masked log10 turns invalid inputs into -inf, which the finite filter then removes
        decibels = 10 * np.ma.log10(finite).filled(-np.inf)
        decibels = decibels[np.isfinite(decibels)]

        if decibels.size > 0:
            result[idx] = np.std(decibels, ddof=1)

    return result
39
86
 
40
87
 
41
88
  def calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray: