sonusai 0.15.8__py3-none-any.whl → 0.15.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,41 @@
1
+ import pyaudio
2
+
3
+
4
def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str | None = None) -> int:
    """Return the index of an input-capable audio device on host API 0.

    Only devices belonging to host API index 0 are searched. A device is
    considered input-capable when it reports more than zero input channels.

    :param p: PyAudio instance used to query host API and device info
    :param name: Exact device name to match. If None, the first
        input-capable device (by index) is returned.
    :return: The device index within host API 0
    :raises ValueError: If no matching input-capable device is found
    """
    info = p.get_host_api_info_by_index(0)
    device_count = info.get('deviceCount')
    for index in range(device_count):
        device_info = p.get_device_info_by_host_api_device_index(0, index)
        # name=None means "any input device": accept the first one that has inputs
        name_matches = name is None or device_info.get('name') == name
        if name_matches and device_info.get('maxInputChannels', 0) > 0:
            return index

    if name is None:
        # Avoid the misleading "Could not find None" message for the default lookup
        raise ValueError('No input audio devices found')
    raise ValueError(f'Could not find {name}')
17
+
18
+
19
def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
    """List the names of all input-capable devices on host API 0.

    A device is included when it reports more than zero input channels.
    Names are returned in device-index order.

    :param p: PyAudio instance used to query host API and device info
    :return: List of device names
    """
    device_count = p.get_host_api_info_by_index(0).get('deviceCount')
    devices = (p.get_device_info_by_host_api_device_index(0, idx) for idx in range(device_count))
    return [dev.get('name') for dev in devices if dev.get('maxInputChannels') > 0]
30
+
31
+
32
def get_default_input_device(p: pyaudio.PyAudio) -> str:
    """Return the name of the first input-capable device on host API 0.

    "First" means lowest device index reporting more than zero input
    channels; this function does not query PyAudio's default-device API.

    :param p: PyAudio instance used to query host API and device info
    :return: Device name
    :raises ValueError: If no device reports any input channels
    """
    host_api = p.get_host_api_info_by_index(0)
    candidates = (p.get_device_info_by_host_api_device_index(0, idx)
                  for idx in range(host_api.get('deviceCount')))
    # The generator is lazy, so device queries stop at the first match
    first = next((dev for dev in candidates if dev.get('maxInputChannels') > 0), None)
    if first is None:
        raise ValueError('No input audio devices found')
    return first.get('name')
@@ -13,13 +13,12 @@ def calculate_input_shape(feature: str,
13
13
  """
14
14
  from pyaaware import FeatureGenerator
15
15
 
16
- # num_classes is irrelevant, set to 2
17
- fg = FeatureGenerator(feature_mode=feature, num_classes=2)
16
+ fg = FeatureGenerator(feature_mode=feature)
18
17
 
19
18
  if flatten:
20
- in_shape = [fg.stride * fg.num_bands]
19
+ in_shape = [fg.stride * fg.feature_parameters]
21
20
  else:
22
- in_shape = [fg.stride, fg.num_bands]
21
+ in_shape = [fg.stride, fg.feature_parameters]
23
22
 
24
23
  if timesteps > 0:
25
24
  in_shape.insert(0, timesteps)
@@ -0,0 +1,5 @@
1
def create_timestamp() -> str:
    """Return the current local time as a 'YYYYmmdd-HHMMSS' string."""
    import datetime

    return f'{datetime.datetime.now():%Y%m%d-%H%M%S}'
sonusai/utils/reshape.py CHANGED
@@ -17,14 +17,14 @@ def reshape_inputs(feature: Feature,
17
17
  timesteps: int = 0,
18
18
  flatten: bool = False,
19
19
  add1ch: bool = False) -> tuple[Feature, Optional[Truth]]:
20
- """Check SonusAI feature and truth data and reshape feature of size [frames, strides, bands] into
20
+ """Check SonusAI feature and truth data and reshape feature of size [frames, strides, feature_parameters] into
21
21
  one of several options:
22
22
 
23
23
  If timesteps > 0: (i.e., for recurrent NNs):
24
- no-flatten, no-channel: [sequences, timesteps, strides, bands] (4-dim)
25
- flatten, no-channel: [sequences, timesteps, strides*bands] (3-dim)
26
- no-flatten, add-1channel: [sequences, timesteps, strides, bands, 1] (5-dim)
27
- flatten, add-1channel: [sequences, timesteps, strides*bands, 1] (4-dim)
24
+ no-flatten, no-channel: [sequences, timesteps, strides, feature_parameters] (4-dim)
25
+ flatten, no-channel: [sequences, timesteps, strides*feature_parameters] (3-dim)
26
+ no-flatten, add-1channel: [sequences, timesteps, strides, feature_parameters, 1] (5-dim)
27
+ flatten, add-1channel: [sequences, timesteps, strides*feature_parameters, 1] (4-dim)
28
28
 
29
29
  If batch_size is None, then do not reshape; just calculate new input shape and return.
30
30
 
@@ -40,7 +40,7 @@ def reshape_inputs(feature: Feature,
40
40
  """
41
41
  from sonusai import SonusAIError
42
42
 
43
- frames, strides, bands = feature.shape
43
+ frames, strides, feature_parameters = feature.shape
44
44
  if truth is not None:
45
45
  truth_frames, num_classes = truth.shape
46
46
  # Double-check correctness of inputs
@@ -50,7 +50,7 @@ def reshape_inputs(feature: Feature,
50
50
  num_classes = None
51
51
 
52
52
  if flatten:
53
- feature = np.reshape(feature, (frames, strides * bands))
53
+ feature = np.reshape(feature, (frames, strides * feature_parameters))
54
54
 
55
55
  # Reshape for Keras/TF recurrent models that require timesteps/sequence length dimension
56
56
  if timesteps > 0:
@@ -73,14 +73,14 @@ def reshape_inputs(feature: Feature,
73
73
 
74
74
  # Reshape
75
75
  if feature.ndim == 2: # flattened input
76
- # was [frames, bands*timesteps]
77
- feature = np.reshape(feature, (sequences, timesteps, strides * bands))
76
+ # was [frames, feature_parameters*timesteps]
77
+ feature = np.reshape(feature, (sequences, timesteps, strides * feature_parameters))
78
78
  if truth is not None:
79
79
  # was [frames, num_classes]
80
80
  truth = np.reshape(truth, (sequences, timesteps, num_classes))
81
81
  elif feature.ndim == 3: # un-flattened input
82
- # was [frames, bands, timesteps]
83
- feature = np.reshape(feature, (sequences, timesteps, strides, bands))
82
+ # was [frames, feature_parameters, timesteps]
83
+ feature = np.reshape(feature, (sequences, timesteps, strides, feature_parameters))
84
84
  if truth is not None:
85
85
  # was [frames, num_classes]
86
86
  truth = np.reshape(truth, (sequences, timesteps, num_classes))
sonusai/utils/wave.py CHANGED
@@ -5,15 +5,22 @@ from sonusai.mixture.datatypes import AudioT
5
5
def write_wav(name: str, audio: AudioT, sample_rate: int = SAMPLE_RATE) -> None:
    """ Write a simple, uncompressed WAV file.

    To write multiple channels, use a 2D array of shape [channels, samples].
    A 2D array with more rows than columns is assumed to be [samples, channels]
    and is transposed before writing (so a [channels, samples] array with fewer
    samples than channels would be misinterpreted). A 1D array is written as a
    single channel.
    The bits per sample and PCM/float are determined by the data type.

    :param name: Output file path
    :param audio: 1D or 2D audio array
    :param sample_rate: Sample rate passed to torchaudio.save
    :raises ValueError: If audio is not a 1D or 2D array
    """
    import torch
    import torchaudio

    data = torch.tensor(audio)

    if data.dim() == 1:
        # Single channel: add the channel axis. No orientation guess is made here,
        # so an empty 1D input is not transposed into a zero-channel array.
        data = data.unsqueeze(0)
    elif data.dim() == 2:
        # Assuming there are more samples than channels, convert [samples, channels]
        # input to the channels-first layout torchaudio.save uses by default.
        if data.shape[1] < data.shape[0]:
            data = torch.transpose(data, 0, 1)
    else:
        raise ValueError('audio must be a 1D or 2D array')

    torchaudio.save(uri=name, src=data, sample_rate=sample_rate)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sonusai
3
- Version: 0.15.8
3
+ Version: 0.15.9
4
4
  Summary: Framework for building deep neural network models for sound, speech, and voice AI
5
5
  Home-page: https://aaware.com
6
6
  License: GPL-3.0-only
@@ -16,28 +16,35 @@ Classifier: Programming Language :: Python :: 3.10
16
16
  Classifier: Programming Language :: Python :: 3.11
17
17
  Requires-Dist: PyYAML (>=6.0.1,<7.0.0)
18
18
  Requires-Dist: aixplain (>=0.2.6,<0.3.0)
19
+ Requires-Dist: bitarray (>=2.9.2,<3.0.0)
19
20
  Requires-Dist: ctranslate2 (==4.1.0)
20
21
  Requires-Dist: dataclasses-json (>=0.6.1,<0.7.0)
21
22
  Requires-Dist: deepgram-sdk (>=3.0.0,<4.0.0)
22
23
  Requires-Dist: docopt (>=0.6.2,<0.7.0)
24
+ Requires-Dist: einops (>=0.7.0,<0.8.0)
23
25
  Requires-Dist: faster-whisper (>=1.0.1,<2.0.0)
26
+ Requires-Dist: geomloss (>=0.2.6,<0.3.0)
24
27
  Requires-Dist: h5py (>=3.11.0,<4.0.0)
28
+ Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
25
29
  Requires-Dist: jiwer (>=3.0.3,<4.0.0)
26
30
  Requires-Dist: keras (>=3.1.1,<4.0.0)
27
31
  Requires-Dist: keras-tuner (>=1.4.7,<2.0.0)
28
32
  Requires-Dist: librosa (>=0.10.1,<0.11.0)
29
33
  Requires-Dist: lightning (>=2.2,<2.3)
30
34
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
35
+ Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
31
36
  Requires-Dist: onnx (>=1.14.1,<2.0.0)
32
37
  Requires-Dist: onnxruntime (>=1.16.1,<2.0.0)
33
38
  Requires-Dist: paho-mqtt (>=2.0.0,<3.0.0)
34
39
  Requires-Dist: pandas (>=2.1.1,<3.0.0)
35
40
  Requires-Dist: pesq (>=0.0.4,<0.0.5)
36
41
  Requires-Dist: pyaaware (>=1.5.3,<2.0.0)
42
+ Requires-Dist: pyaudio (>=0.2.14,<0.3.0)
37
43
  Requires-Dist: pydub (>=0.25.1,<0.26.0)
38
44
  Requires-Dist: pystoi (>=0.4.0,<0.5.0)
39
45
  Requires-Dist: python-magic (>=0.4.27,<0.5.0)
40
46
  Requires-Dist: requests (>=2.31.0,<3.0.0)
47
+ Requires-Dist: sacrebleu (>=2.4.2,<3.0.0)
41
48
  Requires-Dist: samplerate (>=0.2.1,<0.3.0)
42
49
  Requires-Dist: soundfile (>=0.12.1,<0.13.0)
43
50
  Requires-Dist: sox (>=1.4.1,<2.0.0)
@@ -1,27 +1,27 @@
1
1
  sonusai/__init__.py,sha256=KmIJ9wni9d9v5pyu0pUxbacZIHGkAywB9CJwl7JME28,1526
2
2
  sonusai/aawscd_probwrite.py,sha256=GukR5owp_0A3DrqSl9fHWULYgclNft4D5OkHIwfxxkc,3698
3
- sonusai/calc_metric_spenh.py,sha256=Cf02uYB6fzDX5anQuBTR_aD7mYt1WgxkEIzZA08uFvs,60825
3
+ sonusai/audiofe.py,sha256=XE_cgOhhTryjPUePxW_8NY1TwrnRZ6BHCsH-gp8PmYw,11471
4
+ sonusai/calc_metric_spenh.py,sha256=D8iQVSIhFhrsUwKuIP-S38NBnyfAOZlsOIIgOZwGOOI,60852
4
5
  sonusai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
6
  sonusai/data/genmixdb.yml,sha256=-XSs_hUR6wHJVoTPmSewzXL7u61X-xmHY46lNPatxSE,1025
6
7
  sonusai/data/speech_ma01_01.wav,sha256=PK0vMKg-NR6rPE3KouxHGF6PKXnJCr7AwjMqfu98LUA,76644
7
8
  sonusai/data/whitenoise.wav,sha256=I2umov0m34y56F9IsIBi1XtE76ZeZaSKDf70cJRe3pI,1920044
8
9
  sonusai/data_generator/__init__.py,sha256=ouCpY5EDV35fKFeKGQfIcU8uE-c3QcuNerTxUA1X5L8,232
9
- sonusai/data_generator/dataset_from_mixdb.py,sha256=Fe_7xxCYURvbZjuIZrvzozdkrx7ZO7gxGg8c5ob2dys,5478
10
- sonusai/data_generator/keras_from_mixdb.py,sha256=V5CUsGz-akIYdgQy9ABxwNKMYKv01klA4xtMDveF6uI,6167
10
+ sonusai/data_generator/dataset_from_mixdb.py,sha256=D14L8BL7a0WgkF8a8eogQ9Hk9ow4_RK3QBGsZ-HDAog,5493
11
+ sonusai/data_generator/keras_from_mixdb.py,sha256=14r89aX6Dr9ZKsmMRC7HDXbJrPrCZC1liwwLmZUKj0w,6182
11
12
  sonusai/data_generator/torch_from_mixdb.py,sha256=lvEe9DDu_rIaoyhv9PW4UAnAWp5N74L8kRfxUhsh7oo,4279
12
13
  sonusai/doc/__init__.py,sha256=rP5Hgn0Iys_xkuv4caxngdqehuU4zLZsiKuv8Nde67M,19
13
14
  sonusai/doc/doc.py,sha256=3z210v6ZckuOlsGZ3ySQBdlCNmBp2M1ahqhqG_eUN58,22664
14
15
  sonusai/doc.py,sha256=l8CaFgLI8mqx4tn0aXfxKqa2dy9GgC0zjYxZAkpmi1E,878
15
- sonusai/evaluate.py,sha256=1BX9dCXMHg6LefbvkevxYpjM3JR-i0jR3Eob2oNh_hQ,9983
16
- sonusai/genft.py,sha256=CeQN8Sxb_NKeXWJxN9HtzUu687eXl97nHBxzzCzQdLg,5557
16
+ sonusai/genft.py,sha256=6hOds8d-pYRawesLYh7XLrQh4PweWUj8U5Cbzx45bNQ,5572
17
17
  sonusai/genmix.py,sha256=0AiUfF7n0CGOp5v-woNfeP3-QaVQUb0CJZc0oXkvqpk,7016
18
- sonusai/genmixdb.py,sha256=onWkaCPzmUMvDtiGf3A7UdGkOA8xe9zuJTsuLfMdb_s,19597
18
+ sonusai/genmixdb.py,sha256=rAxCKNPkOXaAugEfp9pTcHCQONapdTnxMlBsIPIoizE,19639
19
19
  sonusai/gentcst.py,sha256=8jeXirgJe0OGgknC8A-rIudjHeH8UTYPpuh71Ha-I3w,20165
20
20
  sonusai/keras_onnx.py,sha256=WHcPIcff4VPdiXqGX-TU-_x-UuKUD3nNpQtCX-2NEVQ,2658
21
- sonusai/keras_predict.py,sha256=ffrtI12T8_Rm26KsBee0dfEcnEuhOxvtV80YmbJR0m0,9085
21
+ sonusai/keras_predict.py,sha256=_83EtPtnfrqwUzC2H2tk4LI90RiQdyEEBxFGTgFPl3M,9090
22
22
  sonusai/keras_train.py,sha256=8_M5vY-CkonPzbOtOF3Vk-wox-42o8fkaOKLjk7Oc2k,13226
23
- sonusai/lsdb.py,sha256=qTEHQ5X9Ruc0ph4OUhYKFZ6xGC3gyhcNq7t25oyLhc8,5929
24
- sonusai/main.py,sha256=_s3WCFP_zQ7pD0OgdSVRPQRE38V5Qp6bG56JX-Y79ME,3120
23
+ sonusai/lsdb.py,sha256=TTMQ-0H8fFzUSczt6yjy-9xUjZSdIGQzTVH5Xr6XPSA,5941
24
+ sonusai/main.py,sha256=KjN0dCI6rWare4wo_ACzTlURW7pvTw03n51pH7EyLAU,3108
25
25
  sonusai/metrics/__init__.py,sha256=56itZW3S1I7ZYvbxPmFIVPAh1AIJZdljByz1uCrHqFE,635
26
26
  sonusai/metrics/calc_class_weights.py,sha256=dyY7daEIf5Ms5tfTf6wF0fkx_GnMADHOZR_rtsfGoVM,3933
27
27
  sonusai/metrics/calc_optimal_thresholds.py,sha256=9fRfwl-aKAbzHJyqGHv4o8BpZXG9HHB7zUJObHXfYM4,3522
@@ -35,24 +35,24 @@ sonusai/metrics/class_summary.py,sha256=4Mb25nuk6eqotnQSFMuOQL3zofGcpNXDfDlPa513
35
35
  sonusai/metrics/confusion_matrix_summary.py,sha256=3qg6TMKjJeHtNjj2YnNjPFSlMrQXt0Zcu1dLkGB_aPU,4001
36
36
  sonusai/metrics/one_hot.py,sha256=QSeH_GdqBpOAKLrNnQ8gjcPC-vSdUqC0yPEQueTA6VI,13548
37
37
  sonusai/metrics/snr_summary.py,sha256=P4U5_Xr7v9F8kF-rZBnpsVNt3p42rIVS6zmch8yfVfg,5575
38
- sonusai/mixture/__init__.py,sha256=WzPHGSWz6v64HCSFTGboG5o-xBy_0II4i4tkf1UL1Vw,5251
39
- sonusai/mixture/audio.py,sha256=6W2ihjJGBy7Xggx1imF7bzkymDuipPOgC63j2J7Wf-E,3456
38
+ sonusai/mixture/__init__.py,sha256=fCVSlizYxUUQQD9nSZ8bEbfc_TB2yiOC14HPOB4KFz4,5287
39
+ sonusai/mixture/audio.py,sha256=S-ZROf5rVvwv1TCEuwJHz1FfX4oVubb4QhbybUMMqtM,2150
40
40
  sonusai/mixture/augmentation.py,sha256=Blb90tdTwBOj5w9tRcYyS5H67YJuFiXsGqwZWd7ON4g,10468
41
41
  sonusai/mixture/class_count.py,sha256=_wFnVl2yEOnbor7pLg7cYOUeX6nioov-03Cv3SEbh2k,996
42
42
  sonusai/mixture/config.py,sha256=CXIkVRJmaW2QW_sGl0aIqPf7I_TesyGhUYzxouw5UX4,22266
43
43
  sonusai/mixture/constants.py,sha256=xjCskcQi6khqYZDf7j6z1OkeN1C6wE06kBBapcJiNI4,1428
44
- sonusai/mixture/datatypes.py,sha256=xN-GdPCEHGE2Ak_TdFbjuSyMs4x7TLRp59trbMTiYLg,8164
44
+ sonusai/mixture/datatypes.py,sha256=zaxfOHw8ddt-i8JPYOPnlqWz_EHBEDoO4q2VAqJViHM,8173
45
45
  sonusai/mixture/eq_rule_is_valid.py,sha256=MpQwRA5M76wSiQWEI1lW2cLFdPaMttBLcQp3tWD8efM,1243
46
- sonusai/mixture/feature.py,sha256=io6OiJAJ3GYvPChiUmPQuP3h0OB2onjYF8o9-AWkmqM,1996
46
+ sonusai/mixture/feature.py,sha256=Rwuf82IoXzhHPGbKYVGcatImF_ssBf_FfvbqghVPXtg,4116
47
47
  sonusai/mixture/generation.py,sha256=miUrc3QOSUNIG6mDkiMCZ6M2ulivUZxlYUAJUOVomWc,39039
48
- sonusai/mixture/helpers.py,sha256=XqpcB15MezEMVJwf3jxzATDJSpj_27b8Cru1TDIFD7w,21326
49
- sonusai/mixture/log_duration_and_sizes.py,sha256=r-wVjrLW1XBciOL4pkZSYMR7ZNADbojE95TPSQkp3kc,1329
48
+ sonusai/mixture/helpers.py,sha256=GSGSD2KnvOeEIB6IwNTxyaQNjghTSBMB729kUEd_RiM,22403
49
+ sonusai/mixture/log_duration_and_sizes.py,sha256=baTUpqyM15wA125jo9E3posmVJUe3WlpksyO6v9Jul0,1347
50
50
  sonusai/mixture/mapped_snr_f.py,sha256=mlbYM1t14OXe_Zg4CjpWTuA_Zun4W0O3bSUXeodRBQs,1845
51
- sonusai/mixture/mixdb.py,sha256=FQ5hirb2zR8Aj1UNtz89qJQ8wlE0ELC80IxQDmyhsKk,45188
51
+ sonusai/mixture/mixdb.py,sha256=9Pe0mEG8pnEf9NZynTIldc05GfdOrgmcVoIt63RG5DA,45279
52
52
  sonusai/mixture/soundfile_audio.py,sha256=Ow_IWIMz4pMsLxMP_JsQ8AuHLCWlYQinLa58CFW97f8,2804
53
53
  sonusai/mixture/sox_audio.py,sha256=HT3kYA9TP5QPCuoOJdUMnGVN-qY6q96DGL8zxuog76o,12277
54
54
  sonusai/mixture/sox_augmentation.py,sha256=F9tBdNvX2guCn7gRppAFrxRnBtjw9q6qAq2_v_A4hh0,4490
55
- sonusai/mixture/spectral_mask.py,sha256=LKFnrqZryPHT6FBNiT7yFxOeXc6-AUg6X54N26d8ctg,2107
55
+ sonusai/mixture/spectral_mask.py,sha256=8AkCwhy-PSdP1Uri9miKZP-bXFYnFcH_c9xZCGrHavU,2071
56
56
  sonusai/mixture/target_class_balancing.py,sha256=NTNiKZH0_PWLooeow0l41CjJKK8ZTMVbUqz9ZkaNtWk,4900
57
57
  sonusai/mixture/targets.py,sha256=wyy5vhLhuN-hqBMBGoziVvEJg3FKFvJFgmEE7_LaV2M,7908
58
58
  sonusai/mixture/tokenized_shell_vars.py,sha256=gCxw8SQUcal6mqWKF7hOBTgSQmbJUk1nT0Gn3H8GA0U,4705
@@ -61,24 +61,24 @@ sonusai/mixture/torchaudio_augmentation.py,sha256=1vEDHI0caL1vrgoY2lAWe4CiHE2jKR
61
61
  sonusai/mixture/truth.py,sha256=Y41pZ52Xkols9LUler0NlgnilUOscBIucmw4GcxXNzU,1612
62
62
  sonusai/mixture/truth_functions/__init__.py,sha256=82lKYHhLy8KW3gHngrocoqwupGVLVsWdIXdYs3vhjOc,359
63
63
  sonusai/mixture/truth_functions/crm.py,sha256=_Vy8UMrOUQXsrM3nutvUMWCpvI8GePr01QFlyqLFd4k,2626
64
- sonusai/mixture/truth_functions/data.py,sha256=NJNZz5fB3jnntUDlnsKJVQIeuHNUvD4x5iNaQVQlo3Y,2857
64
+ sonusai/mixture/truth_functions/data.py,sha256=okFJeOf43NxfdLqWFCBA2pOGqujRlNDYdAcwwR_m8z8,2875
65
65
  sonusai/mixture/truth_functions/energy.py,sha256=ydMtMLjMloG76DB30ZHQ5tkBVh4dkMJ82XEhKBokmIk,4281
66
66
  sonusai/mixture/truth_functions/file.py,sha256=jOJuC_3y9BH6GGOp9eKcbVrHLVRzUA80BJq59LhcBUM,1539
67
67
  sonusai/mixture/truth_functions/phoneme.py,sha256=stYdlPuNytQK_LLT61OJLfYSqKd-sDjQZdtJKGzt5wA,479
68
68
  sonusai/mixture/truth_functions/sed.py,sha256=8cHjEFjZaH_0hIOHhPmj4AJz2GpEADM6Ys2x4NoiWSY,2469
69
- sonusai/mixture/truth_functions/target.py,sha256=3rPXYwU4SBiPP3uIDpOL-B2Xw1Zh3JboD_MYNEyUpuk,5746
69
+ sonusai/mixture/truth_functions/target.py,sha256=KAsjugDRooOA5BRcHVAbZRgV7l8S5CFg7CZ0XtKZaQ0,5764
70
70
  sonusai/mkmanifest.py,sha256=dIPVFKKhnhHdq63OGr6p__pK7fyx3OdKVtbmGUJxsR8,7078
71
71
  sonusai/mkwav.py,sha256=LZNyhq4gJEs_NtGvRsYHA2qfgkkODpt6HoH1b-Tjjuw,5266
72
- sonusai/onnx_predict.py,sha256=RhQbbNG3w6rCXuSFUWCaQmUH5JzSP2hmu6TG5_81IVA,9055
73
- sonusai/plot.py,sha256=GPrbk7SxwojtjR9KLE9jQ6ywYWXca296W0iaRPdnBHI,16982
74
- sonusai/post_spenh_targetf.py,sha256=Y9J5JbFnCFKCednNiYvVXaC5Z3lF5KxifFJ2RCy4jmg,4975
72
+ sonusai/onnx_predict.py,sha256=Bz_pR28oAZBarNajlKwyzBxmW7ktum77SmxDN2onKPM,9060
73
+ sonusai/plot.py,sha256=u-PvF8guNcm0b-GN99xfEkrcAAtidAEY3RLDzNvcyYk,17014
74
+ sonusai/post_spenh_targetf.py,sha256=NIMhDXeDuUqeWukNaAUMvDw9JpEVCauwjrL2F4M9nrI,4927
75
75
  sonusai/queries/__init__.py,sha256=oKY5JeqZ4Cz7DwCwPc1_ydB8bUs6KaMcWFp_w02TjOs,255
76
76
  sonusai/queries/queries.py,sha256=FNMUKnoY_Ya9S5sNhsB8ppwy0B7V55ilbbjhQRv_UN8,7552
77
77
  sonusai/torchl_onnx.py,sha256=5JYow3XpBaUdtuyAW0mOZyCKL_4FrHvEekYBRdDT6KA,8967
78
- sonusai/torchl_predict.py,sha256=-wlUdRGPjOvGjCQZY277D8tGeZp1KyDU7TpYH7ovC0c,26657
78
+ sonusai/torchl_predict.py,sha256=P1ySDH_ITOPefZ2xZqyxyIrsNDqblKTBLZqFApgo5EU,26238
79
79
  sonusai/torchl_train.py,sha256=NPCRB0gwTvabivmOz78gjUreDeO1z16PYuw7L1-pIRQ,9680
80
80
  sonusai/tplot.py,sha256=yFyyyg9ymp2Eh-64Muu0EFFEY61MoJSV0a_fy9OWaCk,14485
81
- sonusai/utils/__init__.py,sha256=ynmdoPJVdbk5eYq9TA9t4QpGEkdiGATDx1pFu43u7YQ,2180
81
+ sonusai/utils/__init__.py,sha256=tVSmxinSo0Enexpol6wCzz6tU7WrueC-YslFgQr-o7M,2382
82
82
  sonusai/utils/asl_p56.py,sha256=GCKlz-NLInQ0z41XBi0mOvGdSfRZf3WI53necVNDo80,3837
83
83
  sonusai/utils/asr.py,sha256=QN1wdO9-EqD72-ixr4lnzsPfT8i0syhTGj1evKNJWe4,2021
84
84
  sonusai/utils/asr_functions/__init__.py,sha256=4boXXOXlQHTt8K2DWOwFXSlc8D2NLFd8QTc68yL2ejU,214
@@ -93,9 +93,11 @@ sonusai/utils/asr_manifest_functions/__init__.py,sha256=Lz12aCGvfngZkLoUxHSqFjHc
93
93
  sonusai/utils/asr_manifest_functions/data.py,sha256=mJsaHccBReguOJu9qsshRhL-3GbeyqM0-PXMseFnZbE,151
94
94
  sonusai/utils/asr_manifest_functions/librispeech.py,sha256=HIaytcYmjRUkuR6fCQlv3Jh3IDWSox_A6WFcFFAHN9M,1635
95
95
  sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py,sha256=-69lM0dz18KbU5_-dmSeqDoNNwgJj4UlxgGkNBEi3wM,2169
96
+ sonusai/utils/audio_devices.py,sha256=LgaXTln1oRArBzaet3rZiIO2plgtaThuGBc3sJ_sLlo,1414
96
97
  sonusai/utils/braced_glob.py,sha256=h4hab7YDbM4CjLg9iSzyHZrkd22IPUOY5zZqHdifkh8,1510
97
- sonusai/utils/calculate_input_shape.py,sha256=fyf8Ggxn9xljJ87BNwIzOh2KD4pKUGFZ2RNzE3195NI,1023
98
+ sonusai/utils/calculate_input_shape.py,sha256=63ILxibYKuTQozY83QN8Y2OOhBEbW_1X47Q0askcHDM,984
98
99
  sonusai/utils/convert_string_to_number.py,sha256=i17yIxurp8Iz6NPE-imTRlARrXWqadwm8qbOTuzHZvE,236
100
+ sonusai/utils/create_timestamp.py,sha256=TxoQXWZ3SFdBEHLOv-ujeIsTEJuiFnKOGRy-FQq45YU,148
99
101
  sonusai/utils/create_ts_name.py,sha256=8RLKmgXwuGcbDMGgtTuc0MvGFfA7IOVqfjkE2T18GOo,405
100
102
  sonusai/utils/dataclass_from_dict.py,sha256=vAGnuMjhy0W9bxZ5usrH7mbQsFog3n0__IC4xyJyVUc,390
101
103
  sonusai/utils/db.py,sha256=lI77MJJLs4CTYxhjFUvBom2Kk2imAP34okOeO4irbDc,371
@@ -114,15 +116,15 @@ sonusai/utils/print_mixture_details.py,sha256=BzYM4-wHHNa6zxPzBMUJxwKt0gKHmvbwdd
114
116
  sonusai/utils/ranges.py,sha256=NPBZOVzMb95GTOIxltVO-wSzgcXqZ14wbdV46JDLKrw,1222
115
117
  sonusai/utils/read_mixture_data.py,sha256=Sb30RgSpw6DnH_iD81O7G_KOsdfjQWWLk3euEkxfMa8,453
116
118
  sonusai/utils/read_predict_data.py,sha256=5rR_ijrrcS2cKO1Sea2M2QEicokTtW5XtAo6jT5YSX8,1064
117
- sonusai/utils/reshape.py,sha256=sz-V4Za3DIvqFztTh0yYhQRcofnjc9XBW0CBhxg18lo,5854
119
+ sonusai/utils/reshape.py,sha256=E8Eu6grynaeWwVO6peIR0BF22SrVaJSa1Rkl109lq6Y,5997
118
120
  sonusai/utils/seconds_to_hms.py,sha256=oxLuZhTJJr9swj-fOSOrZJ5vBNM7_BrOMQhX1pYpiv0,260
119
121
  sonusai/utils/stacked_complex.py,sha256=feLhz3GC1ILxBGMHOj3sJK--sidsXKbfwkalwAVwizc,2950
120
122
  sonusai/utils/stratified_shuffle_split.py,sha256=rJNXvBp-GxoKzH3OpL7k0ANSu5xMP2zJ7K1fm_33UzE,7022
121
123
  sonusai/utils/trim_docstring.py,sha256=dSrtiRsEN4wkkvKBp6WDr13RUypfqZzgH_jOBLs1ouY,881
122
- sonusai/utils/wave.py,sha256=TKE-CNPGFXNXUW626CBPzCTNgWJut8I0ZEUsgG9q4Po,586
124
+ sonusai/utils/wave.py,sha256=O4ZXkZ6wjrKGa99wBCdFd8G6bp91MXXDnmGihpaEMh0,856
123
125
  sonusai/utils/yes_or_no.py,sha256=eMLXBVH0cEahiXY4W2KNORmwNQ-ba10eRtldh0y4NYg,263
124
126
  sonusai/vars.py,sha256=m2AefF0m5bXWGXpJj8Pi42zWL2ydeEj7bkak3GrtMyM,940
125
- sonusai-0.15.8.dist-info/METADATA,sha256=3eCpCJmXOfr7GV3a7HDWo0iilEVHB5ANdQqS59O0Yi0,2920
126
- sonusai-0.15.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
127
- sonusai-0.15.8.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
128
- sonusai-0.15.8.dist-info/RECORD,,
127
+ sonusai-0.15.9.dist-info/METADATA,sha256=DudNQlTEQpWpzqyzyowz_V-J9epd7mrKgAYM6rFxaPo,3209
128
+ sonusai-0.15.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
129
+ sonusai-0.15.9.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
130
+ sonusai-0.15.9.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.8.1
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
sonusai/evaluate.py DELETED
@@ -1,245 +0,0 @@
1
- """sonusai evaluate
2
-
3
- usage: evaluate [-hv] [-i MIXID] (-f FEATURE) (-p PREDICT) [-t PTHR] LOC
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -p PREDICT, --predict PREDICT A directory containing prediction data.
10
- -t PTHR, --thr PTHR Optional prediction decision threshold(s). [default: 0].
11
-
12
- Evaluate calculates performance metrics of neural-network models from model prediction data and genft data.
13
-
14
- Inputs:
15
- LOC A SonusAI mixture database directory.
16
- MIXID A glob of mixture ID(s) to generate.
17
- PREDICT A directory containing SonusAI predict HDF5 files. Contains:
18
- dataset: predict (either [frames, num_classes] or [frames, timesteps, num_classes])
19
- PTHR Scalar or array of thresholds. Default 0 will select values:
20
- argmax() if mixdb indicates single-label mode (truth_mutex = true)
21
- 0.5 if mixdb indicates multi-label mode (truth_mutex = false)
22
- If PTHR = -1, optimal thresholds are calculated using precision_recall_curve() which
23
- optimizes F1 score.
24
- """
25
- import numpy as np
26
-
27
- from sonusai import logger
28
- from sonusai.mixture import Feature
29
- from sonusai.mixture import MixtureDatabase
30
- from sonusai.mixture import Predict
31
- from sonusai.mixture import Segsnr
32
- from sonusai.mixture import Truth
33
-
34
-
35
- def evaluate(mixdb: MixtureDatabase,
36
- truth: Truth,
37
- predict: Predict = None,
38
- segsnr: Segsnr = None,
39
- output_dir: str = None,
40
- predict_thr: float | np.ndarray = 0,
41
- feature: Feature = None,
42
- verbose: bool = False) -> None:
43
- from os.path import join
44
-
45
- from sonusai import initial_log_messages
46
- from sonusai import update_console_handler
47
- from sonusai.metrics import calc_optimal_thresholds
48
- from sonusai.metrics import class_summary
49
- from sonusai.metrics import snr_summary
50
- from sonusai.mixture import SAMPLE_RATE
51
- from sonusai.queries import get_mixids_from_snr
52
- from sonusai.utils import get_num_classes_from_predict
53
- from sonusai.utils import human_readable_size
54
- from sonusai.utils import reshape_outputs
55
- from sonusai.utils import seconds_to_hms
56
-
57
- update_console_handler(verbose)
58
- initial_log_messages('evaluate')
59
-
60
- if truth.shape[-1] != predict.shape[-1]:
61
- logger.exception(f'Number of classes in truth and predict are not equal. Exiting.')
62
- raise SystemExit(1)
63
-
64
- # truth, predict can be either [frames, num_classes] or [frames, timesteps, num_classes]
65
- # in binary case dim may not exist, detect this and set num_classes == 1
66
- timesteps = -1
67
- predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
68
- num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
69
-
70
- fdiff = truth.shape[0] - predict.shape[0]
71
- if fdiff > 0:
72
- # truth = truth[0:-fdiff,:]
73
- predict = np.concatenate((predict, np.zeros((fdiff, num_classes), dtype=np.float32)))
74
- logger.info(f'Truth has more feature-frames than predict, padding predict with zeros to match.')
75
-
76
- if fdiff < 0:
77
- predict = predict[0:fdiff, :]
78
- logger.info(f'Predict has more feature-frames than truth, trimming predict to match.')
79
-
80
- # Check segsnr, input is always in transform frames
81
- compute_segsnr = False
82
- if len(segsnr) > 0:
83
- segsnr_feature_frames = segsnr.shape[0] / (mixdb.feature_step_samples / mixdb.ft_config.R)
84
- if segsnr_feature_frames == truth.shape[0]:
85
- compute_segsnr = True
86
- else:
87
- logger.warning('segsnr length does not match truth, ignoring.')
88
-
89
- # Check predict_thr array or scalar and return final scalar predict_thr value
90
- if not mixdb.truth_mutex:
91
- if num_classes > 1:
92
- if not isinstance(predict_thr, np.ndarray):
93
- if predict_thr == 0:
94
- # multi-label predict_thr scalar 0 force to 0.5 default
95
- predict_thr = np.atleast_1d(0.5)
96
- else:
97
- predict_thr = np.atleast_1d(predict_thr)
98
- else:
99
- if predict_thr.ndim == 1:
100
- if predict_thr[0] == 0:
101
- # multi-label predict_thr array scalar 0 force to 0.5 default
102
- predict_thr = np.atleast_1d(0.5)
103
- else:
104
- # multi-label predict_thr array set to scalar = array[0]
105
- predict_thr = predict_thr[0]
106
- else:
107
- # single-label mode, force argmax mode
108
- predict_thr = np.atleast_1d(0)
109
-
110
- if predict_thr == -1:
111
- thrpr, thrroc, _, _ = calc_optimal_thresholds(truth, predict, timesteps)
112
- predict_thr = np.atleast_1d(thrpr)
113
- predict_thr = np.maximum(predict_thr, 0.001) # enforce lower limit
114
- predict_thr = np.minimum(predict_thr, 0.999) # enforce upper limit
115
- predict_thr = predict_thr.round(2)
116
-
117
- # Summarize the mixture data
118
- num_mixtures = mixdb.num_mixtures
119
- total_samples = sum([mixture.samples for mixture in mixdb.mixtures])
120
- duration = total_samples / SAMPLE_RATE
121
-
122
- logger.info('')
123
- logger.info(f'Mixtures: {num_mixtures}')
124
- logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
125
- logger.info(f'truth: {human_readable_size(truth.nbytes, 1)}')
126
- logger.info(f'predict: {human_readable_size(predict.nbytes, 1)}')
127
- if compute_segsnr:
128
- logger.info(f'segsnr: {human_readable_size(segsnr.nbytes, 1)}')
129
- if feature:
130
- logger.info(f'feature: {human_readable_size(feature.nbytes, 1)}')
131
-
132
- logger.info(f'Classes: {num_classes}')
133
- if mixdb.truth_mutex:
134
- logger.info(f'Mode: Single-label / truth_mutex / softmax')
135
- else:
136
- logger.info(f'Mode: Multi-label / Binary')
137
-
138
- mxid_snro = get_mixids_from_snr(mixdb=mixdb)
139
- snrlist = list(mxid_snro.keys())
140
- snrlist.sort(reverse=True)
141
- logger.info(f'Ordered SNRs: {snrlist}\n')
142
- predict_thr_info = predict_thr.transpose() if isinstance(predict_thr, np.ndarray) else predict_thr
143
- logger.info(f'Prediction Threshold(s): {predict_thr_info}\n')
144
-
145
- # Top-level report over all mixtures
146
- macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
147
- mixid=':',
148
- truth_f=truth,
149
- predict=predict,
150
- segsnr=segsnr if compute_segsnr else None,
151
- predict_thr=predict_thr)
152
-
153
- if num_classes > 1:
154
- logger.info(f'Metrics micro-avg per SNR over all {num_mixtures} mixtures:')
155
- else:
156
- logger.info(f'Metrics per SNR over all {num_mixtures} mixtures:')
157
- logger.info(microdf.round(3).to_string())
158
- logger.info('')
159
- if output_dir:
160
- microdf.round(3).to_csv(join(output_dir, 'snr.csv'))
161
-
162
- if mixdb.truth_mutex:
163
- macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
164
- mixid=':',
165
- truth_f=truth[:, 0:-1],
166
- predict=predict[:, 0:-1],
167
- segsnr=segsnr if compute_segsnr else None,
168
- predict_thr=predict_thr)
169
-
170
- logger.info(f'Metrics micro-avg without "Other" class per SNR over all {num_mixtures} mixtures:')
171
- logger.info(microdf.round(3).to_string())
172
- logger.info('')
173
- if output_dir:
174
- microdf.round(3).to_csv(join(output_dir, 'snrwo.csv'))
175
-
176
- for snri in snrlist:
177
- mxids = mxid_snro[snri]
178
- classdf = class_summary(mixdb, mxids, truth, predict, predict_thr)
179
- logger.info(f'Metrics per class for SNR {snri} over {len(mxids)} mixtures:')
180
- logger.info(classdf.round(3).to_string())
181
- logger.info('')
182
- if output_dir:
183
- classdf.round(3).to_csv(join(output_dir, f'class_snr{snri}.csv'))
184
-
185
-
186
- def main() -> None:
187
- from datetime import datetime
188
- from os import mkdir
189
- from os.path import join
190
-
191
- import h5py
192
- from docopt import docopt
193
-
194
- import sonusai
195
- from sonusai import SonusAIError
196
- from sonusai import create_file_handler
197
- from sonusai.utils import read_predict_data
198
- from sonusai.utils import trim_docstring
199
-
200
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
201
-
202
- verbose = args['--verbose']
203
- feature_name = args['--feature']
204
- predict_name = args['--predict']
205
- predict_threshold = np.array(float(args['--thr']), dtype=np.float32)
206
- location = args['LOC']
207
-
208
- mixdb = MixtureDatabase(location)
209
-
210
- # create output directory
211
- output_dir = f'evaluate-{datetime.now():%Y%m%d}'
212
- try:
213
- mkdir(output_dir)
214
- except OSError as _:
215
- output_dir = f'evaluate-{datetime.now():%Y%m%d-%H%M%S}'
216
- try:
217
- mkdir(output_dir)
218
- except OSError as error:
219
- raise SonusAIError(f'Could not create directory, {output_dir}: {error}')
220
-
221
- create_file_handler(join(output_dir, 'evaluate.log'))
222
-
223
- with h5py.File(feature_name, 'r') as f:
224
- truth_f = np.array(f['truth_f'])
225
- segsnr = np.array(f['segsnr'])
226
-
227
- predict = read_predict_data(predict_name)
228
-
229
- evaluate(mixdb=mixdb,
230
- truth=truth_f,
231
- segsnr=segsnr,
232
- output_dir=output_dir,
233
- predict=predict,
234
- predict_thr=predict_threshold,
235
- verbose=verbose)
236
-
237
- logger.info(f'Wrote results to {output_dir}')
238
-
239
-
240
- if __name__ == '__main__':
241
- try:
242
- main()
243
- except KeyboardInterrupt:
244
- logger.info('Canceled due to keyboard interrupt')
245
- raise SystemExit(0)