sonusai 0.11.3.tar.gz → 0.11.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. {sonusai-0.11.3 → sonusai-0.11.4}/PKG-INFO +1 -2
  2. {sonusai-0.11.3 → sonusai-0.11.4}/pyproject.toml +1 -2
  3. {sonusai-0.11.3 → sonusai-0.11.4}/setup.py +1 -2
  4. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/__init__.py +2 -3
  5. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data_generator/dataset_from_mixdb.py +104 -0
  6. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data_generator/keras_from_mixdb.py +26 -2
  7. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/keras_train.py +38 -4
  8. {sonusai-0.11.3 → sonusai-0.11.4}/README.rst +0 -0
  9. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/aawscd_probwrite.py +0 -0
  10. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/calc_metric_spenh_targetf.py +0 -0
  11. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data/genmixdb.yml +0 -0
  12. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data/whitenoise.wav +0 -0
  13. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data_generator/__init__.py +0 -0
  14. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/evaluate.py +0 -0
  15. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/genft.py +0 -0
  16. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/genmix.py +0 -0
  17. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/genmixdb.py +0 -0
  18. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/gentcst.py +0 -0
  19. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/keras_onnx.py +0 -0
  20. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/keras_predict.py +0 -0
  21. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/lsdb.py +0 -0
  22. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/main.py +0 -0
  23. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/__init__.py +0 -0
  24. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_class_weights.py +0 -0
  25. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_optimal_thresholds.py +0 -0
  26. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_pcm.py +0 -0
  27. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_pesq.py +0 -0
  28. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_sa_sdr.py +0 -0
  29. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_sample_weights.py +0 -0
  30. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/calc_wer.py +0 -0
  31. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/class_summary.py +0 -0
  32. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/confusion_matrix_summary.py +0 -0
  33. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/one_hot.py +0 -0
  34. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/metrics/snr_summary.py +0 -0
  35. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/__init__.py +0 -0
  36. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/active_truth_class_balancing.py +0 -0
  37. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/audio.py +0 -0
  38. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/augmentation.py +0 -0
  39. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/balance.py +0 -0
  40. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/class_count.py +0 -0
  41. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/config.py +0 -0
  42. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/constants.py +0 -0
  43. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/feature.py +0 -0
  44. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/initialize.py +0 -0
  45. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/log_duration_and_sizes.py +0 -0
  46. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/mapped_snr_f.py +0 -0
  47. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/mixdb.py +0 -0
  48. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/spectral_mask.py +0 -0
  49. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/target_class_balancing.py +0 -0
  50. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/targets.py +0 -0
  51. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth.py +0 -0
  52. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/__init__.py +0 -0
  53. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/crm.py +0 -0
  54. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/data.py +0 -0
  55. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/energy.py +0 -0
  56. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/file.py +0 -0
  57. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/phoneme.py +0 -0
  58. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/sed.py +0 -0
  59. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/truth_functions/target.py +0 -0
  60. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mixture/types.py +0 -0
  61. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/mkwav.py +0 -0
  62. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/onnx_predict.py +0 -0
  63. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/plot.py +0 -0
  64. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/post_spenh_targetf.py +0 -0
  65. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/queries/__init__.py +0 -0
  66. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/queries/queries.py +0 -0
  67. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/tplot.py +0 -0
  68. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/__init__.py +0 -0
  69. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asl_p56.py +0 -0
  70. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr.py +0 -0
  71. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr_functions/__init__.py +0 -0
  72. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr_functions/data.py +0 -0
  73. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr_functions/deepgram.py +0 -0
  74. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr_functions/google.py +0 -0
  75. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/asr_functions/whisper.py +0 -0
  76. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/braced_glob.py +0 -0
  77. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/calculate_input_shape.py +0 -0
  78. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/create_ts_name.py +0 -0
  79. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/dataclass_from_dict.py +0 -0
  80. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/db.py +0 -0
  81. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/engineering_number.py +0 -0
  82. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/get_frames_per_batch.py +0 -0
  83. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/get_label_names.py +0 -0
  84. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/grouper.py +0 -0
  85. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/human_readable_size.py +0 -0
  86. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/keras_utils.py +0 -0
  87. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/numeric_conversion.py +0 -0
  88. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/onnx_utils.py +0 -0
  89. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/parallel.py +0 -0
  90. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/parallel_tqdm.py +0 -0
  91. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/print_mixture_details.py +0 -0
  92. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/ranges.py +0 -0
  93. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/read_mixture_data.py +0 -0
  94. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/read_predict_data.py +0 -0
  95. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/reshape.py +0 -0
  96. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/seconds_to_hms.py +0 -0
  97. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/stacked_complex.py +0 -0
  98. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/stratified_shuffle_split.py +0 -0
  99. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/trim_docstring.py +0 -0
  100. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/wave.py +0 -0
  101. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/utils/yes_or_no.py +0 -0
  102. {sonusai-0.11.3 → sonusai-0.11.4}/sonusai/vars.py +0 -0
{sonusai-0.11.3 → sonusai-0.11.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sonusai
-Version: 0.11.3
+Version: 0.11.4
 Summary: Framework for building deep neural network models for sound, speech, and voice AI
 Home-page: https://aaware.com
 License: GPL-3.0-only
@@ -29,7 +29,6 @@ Requires-Dist: pesq (>=0.0.4,<0.0.5)
 Requires-Dist: pyaaware (>=1.4.10,<2.0.0)
 Requires-Dist: python-magic (>=0.4.27,<0.5.0)
 Requires-Dist: scikit-learn (>=1.2.0,<2.0.0)
-Requires-Dist: setuptools (>=67.0.0,<68.0.0)
 Requires-Dist: sh (>=1.14.3,<2.0.0)
 Requires-Dist: sox (>=1.4.1,<2.0.0)
 Requires-Dist: speechrecognition (>=3.9.0,<4.0.0)
{sonusai-0.11.3 → sonusai-0.11.4}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sonusai"
-version = "0.11.3"
+version = "0.11.4"
 description = "Framework for building deep neural network models for sound, speech, and voice AI"
 authors = ["Chris Eddington <chris@aaware.com>", "Jason Calderwood <jason@aaware.com>"]
 maintainers = ["Chris Eddington <chris@aaware.com>", "Jason Calderwood <jason@aaware.com>"]
@@ -30,7 +30,6 @@ pyaaware = "^1.4.10"
 python = ">=3.8,<3.11"
 python-magic = "^0.4.27"
 scikit-learn = "^1.2.0"
-setuptools = "^67.0.0"
 sh = "^1.14.3"
 sox = "^1.4.1"
 speechrecognition = "^3.9.0"
{sonusai-0.11.3 → sonusai-0.11.4}/setup.py

@@ -30,7 +30,6 @@ install_requires = \
  'pyaaware>=1.4.10,<2.0.0',
  'python-magic>=0.4.27,<0.5.0',
  'scikit-learn>=1.2.0,<2.0.0',
- 'setuptools>=67.0.0,<68.0.0',
  'sh>=1.14.3,<2.0.0',
  'sox>=1.4.1,<2.0.0',
  'speechrecognition>=3.9.0,<4.0.0',
@@ -45,7 +44,7 @@ entry_points = \
 
 setup_kwargs = {
     'name': 'sonusai',
-    'version': '0.11.3',
+    'version': '0.11.4',
     'description': 'Framework for building deep neural network models for sound, speech, and voice AI',
     'long_description': "Sonus AI: Framework for simplified creation of deep NN models for sound, speech, and voice AI\n\nSonus AI includes functions for pre-processing training and validation data and\ncreating performance metrics reports for key types of Keras models:\n- recurrent, convolutional, or a combination (i.e. RCNNs)\n- binary, multiclass single-label, multiclass multi-label, and regresssion\n- training with data augmentations: noise mixing, pitch and time stretch, etc.\n\nSonus AI python functions are used by:\n - Aaware Inc. sonusai executable: Easily create train/validation data, run prediction, evaluate model performance\n - Keras model scripts: User python scripts for keras model creation, training, and prediction. These can use sonusai-specific data but also some general useful utilities for trainining rnn-based models like CRNN's, DSCRNN's, etc. in Keras\n",
     'author': 'Chris Eddington',
{sonusai-0.11.3 → sonusai-0.11.4}/sonusai/__init__.py

@@ -1,9 +1,8 @@
 import logging
+from importlib import metadata
 from os.path import dirname
 
-from pkg_resources import get_distribution
-
-__version__ = get_distribution('sonusai').version
+__version__ = metadata.version('sonusai')
 BASEDIR = dirname(__file__)
 
 # create logger
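
Note: this swap replaces a pkg_resources lookup with the standard-library importlib.metadata, which lines up with the setuptools requirement being dropped from the dependency lists above (pkg_resources ships with setuptools, while importlib.metadata has been in the standard library since Python 3.8, the oldest interpreter this package supports). A minimal sketch of the same lookup with a fallback for running from an uninstalled checkout; the fallback value is illustrative, not something the package does:

    from importlib import metadata

    try:
        __version__ = metadata.version('sonusai')
    except metadata.PackageNotFoundError:
        # Not installed, e.g. running straight from a source checkout (illustrative fallback)
        __version__ = '0.0.0.dev0'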
{sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data_generator/dataset_from_mixdb.py

@@ -1,10 +1,114 @@
 import warnings
+from dataclasses import dataclass
 from typing import List
 
 import numpy as np
+import tensorflow as tf
 
 from sonusai.mixture import GeneralizedIDs
 from sonusai.mixture import MixtureDatabase
+from sonusai.utils import get_frames_per_batch
+
+
+def get_dataset_from_mixdb(mixdb: MixtureDatabase,
+                           mixids: GeneralizedIDs,
+                           batch_size: int,
+                           timesteps: int,
+                           flatten: bool,
+                           add1ch: bool,
+                           shuffle: bool = False) -> tf.data.Dataset:
+    @dataclass(frozen=True)
+    class BatchParams:
+        mixids: List[int]
+        offset: int
+        extra: int
+        padding: int
+
+    def _getitem(batch_index) -> (np.ndarray, np.ndarray):
+        """Get one batch of data
+        """
+        from sonusai.utils import reshape_inputs
+
+        batch_params = self.batch_params[batch_index]
+
+        result = [self.mixdb.mixture_ft(mixid) for mixid in batch_params.mixids]
+        feature = np.vstack([result[i][0] for i in range(len(result))])
+        truth = np.vstack([result[i][1] for i in range(len(result))])
+
+        pad_shape = list(feature.shape)
+        pad_shape[0] = batch_params.padding
+        feature = np.vstack([feature, np.zeros(pad_shape)])
+
+        pad_shape = list(truth.shape)
+        pad_shape[0] = batch_params.padding
+        truth = np.vstack([truth, np.zeros(pad_shape)])
+
+        if batch_params.extra > 0:
+            feature = feature[batch_params.offset:-batch_params.extra]
+            truth = truth[batch_params.offset:-batch_params.extra]
+        else:
+            feature = feature[batch_params.offset:]
+            truth = truth[batch_params.offset:]
+
+        feature, truth = reshape_inputs(feature=feature,
+                                        truth=truth,
+                                        batch_size=self.batch_size,
+                                        timesteps=self.timesteps,
+                                        flatten=self.flatten,
+                                        add1ch=self.add1ch)
+
+        return feature, truth
+
+    mixids = mixdb.mixids_to_list(mixids)
+    stride = mixdb.fg.stride
+    num_bands = mixdb.fg.num_bands
+    num_classes = mixdb.num_classes
+    mixture_frame_segments = None
+    batch_frame_segments = None
+
+    frames_per_batch = get_frames_per_batch(batch_size, timesteps)
+    # Always extend the number of batches to use all available data
+    # The last batch may need padding
+    total_batches = int(np.ceil(mixdb.total_feature_frames(mixids) / frames_per_batch))
+
+    # Compute mixid, offset, and extra for dataset
+    # offsets and extras are needed because mixtures are not guaranteed to fall on batch boundaries.
+    # When fetching a new index that starts in the middle of a sequence of mixtures, the
+    # previous feature frame offset must be maintained in order to preserve the correct
+    # data sequence. And the extra must be maintained in order to preserve the correct data length.
+    cumulative_frames = 0
+    start_mixture_index = 0
+    offset = 0
+    batch_params = []
+    file_indices = []
+    total_frames = 0
+    for idx, mixid in enumerate(mixids):
+        current_frames = mixdb.mixture_samples(mixid) // mixdb.feature_step_samples
+        file_indices.append(slice(total_frames, total_frames + current_frames))
+        total_frames += current_frames
+        cumulative_frames += current_frames
+        while cumulative_frames >= frames_per_batch:
+            extra = cumulative_frames - frames_per_batch
+            mixids = mixids[start_mixture_index:idx + 1]
+            batch_params.append(BatchParams(mixids=mixids, offset=offset, extra=extra, padding=0))
+            if extra == 0:
+                start_mixture_index = idx + 1
+                offset = 0
+            else:
+                start_mixture_index = idx
+                offset = current_frames - extra
+            cumulative_frames = extra
+
+    # If needed, add final batch with padding
+    needed_frames = total_batches * frames_per_batch
+    padding = needed_frames - total_frames
+    if padding != 0:
+        mixids = mixids[start_mixture_index:]
+        batch_params.append(BatchParams(mixids=mixids, offset=offset, extra=0, padding=padding))
+
+    dataset = tf.data.Dataset.from_generator()
+    return dataset
+
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
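
Note: the new get_dataset_from_mixdb helper looks unfinished in this release: _getitem still references self.* attributes that do not exist inside a plain function, and tf.data.Dataset.from_generator() is called with no arguments. For orientation only, a generator-backed dataset is typically wired up along the following lines; the helper name, shapes, and toy data are illustrative placeholders, not the package's implementation:

    import numpy as np
    import tensorflow as tf


    def make_batch_dataset(batches, feature_shape, truth_shape):
        """Wrap precomputed (feature, truth) batches in a tf.data.Dataset."""
        def generator():
            for feature, truth in batches:
                yield feature.astype(np.float32), truth.astype(np.float32)

        return tf.data.Dataset.from_generator(
            generator,
            output_signature=(tf.TensorSpec(shape=feature_shape, dtype=tf.float32),
                              tf.TensorSpec(shape=truth_shape, dtype=tf.float32)))


    # Example: two batches of 8 feature frames with 40 bands and 3 truth classes
    batches = [(np.zeros((8, 40)), np.zeros((8, 3))) for _ in range(2)]
    dataset = make_batch_dataset(batches, feature_shape=(8, 40), truth_shape=(8, 3))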
{sonusai-0.11.3 → sonusai-0.11.4}/sonusai/data_generator/keras_from_mixdb.py

@@ -1,20 +1,40 @@
+import multiprocessing as mp
 import warnings
+from dataclasses import dataclass
 from typing import List
 
 import numpy as np
 
+from sonusai.mixture import Feature
 from sonusai.mixture import GeneralizedIDs
 from sonusai.mixture import MixtureDatabase
+from sonusai.mixture import Truth
 
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
     from keras.utils import Sequence
 
 
+@dataclass
+class MPGlobal:
+    mixdb: MixtureDatabase = None
+
+
+MP_GLOBAL = MPGlobal()
+
+
+def _pool_init(mixdb: MixtureDatabase) -> None:
+    MP_GLOBAL.mixdb = mixdb
+
+
+def _pool_func(mixid: int) -> (Feature, Truth):
+    mixdb = MP_GLOBAL.mixdb
+    return mixdb.mixture_ft(mixid)
+
+
 class KerasFromMixtureDatabase(Sequence):
     """Generates data for Keras from a SonusAI mixture database
     """
-    from dataclasses import dataclass
 
     @dataclass(frozen=True)
     class BatchParams:
@@ -49,6 +69,10 @@ class KerasFromMixtureDatabase(Sequence):
 
         self._initialize_mixtures()
 
+        self.pool = mp.Pool(processes=mp.cpu_count(),
+                            initializer=_pool_init,
+                            initargs=[mixdb])
+
     def __len__(self) -> int:
         """Denotes the number of batches per epoch
        """
@@ -61,7 +85,7 @@ class KerasFromMixtureDatabase(Sequence):
 
         batch_params = self.batch_params[batch_index]
 
-        result = [self.mixdb.mixture_ft(mixid) for mixid in batch_params.mixids]
+        result = self.pool.map(_pool_func, batch_params.mixids)
         feature = np.vstack([result[i][0] for i in range(len(result))])
         truth = np.vstack([result[i][1] for i in range(len(result))])
 
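Note: the pool pattern above — send the MixtureDatabase to each worker once via an initializer, stash it in a module-level global, then fan work out with pool.map — is a general multiprocessing idiom for sharing a large read-only object. A self-contained sketch with a toy stand-in object (all names here are illustrative, not from the package); using the pool as a context manager also guarantees worker cleanup, which the Sequence above leaves to the interpreter:

    import multiprocessing as mp


    class ToyDatabase:
        """Stand-in for a large, read-only object such as a mixture database."""

        def __init__(self, size: int) -> None:
            self.data = list(range(size))

        def lookup(self, index: int) -> int:
            return self.data[index] * 2


    _SHARED = {}


    def _init_worker(db: ToyDatabase) -> None:
        # Runs once in each worker; the object is pickled once per worker, not once per task
        _SHARED['db'] = db


    def _work(index: int) -> int:
        return _SHARED['db'].lookup(index)


    if __name__ == '__main__':
        db = ToyDatabase(100)
        with mp.Pool(processes=mp.cpu_count(), initializer=_init_worker, initargs=[db]) as pool:
            print(pool.map(_work, range(10)))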
 
{sonusai-0.11.3 → sonusai-0.11.4}/sonusai/keras_train.py

@@ -1,6 +1,6 @@
 """sonusai keras_train
 
-usage: keras_train [-hv] (-m MODEL) (-l VLOC) [-w KMODEL] [-e EPOCHS] [-b BATCH] [-t TSTEPS] [-p ESP] TLOC
+usage: keras_train [-hgv] (-m MODEL) (-l VLOC) [-w KMODEL] [-e EPOCHS] [-b BATCH] [-t TSTEPS] [-p ESP] TLOC
 
 options:
     -h, --help
@@ -12,6 +12,7 @@ options:
     -b BATCH, --batch BATCH         Batch size.
     -t TSTEPS, --tsteps TSTEPS      Timesteps.
     -p ESP, --patience ESP          Early stopping patience.
+    -g, --loss-batch-log            Enable per-batch loss log. [default: False]
 
 Use Keras to train a model defined by a Python definition file and SonusAI genft data.
 
@@ -20,6 +21,7 @@ Inputs:
     VLOC    A SonusAI mixture database directory to use for validation data.
 
 Results are written into subdirectory <MODEL>-<TIMESTAMP>.
+Per-batch loss history, if enabled, is written to <basename>-history-lossb.npy
 
 """
 import tensorflow as tf
@@ -27,6 +29,20 @@ import tensorflow as tf
 
 from sonusai import logger
 
+class LossBatchHistory(tf.keras.callbacks.Callback):
+    def __init__(self):
+        super().__init__()
+        self.history = None
+
+    def on_train_begin(self, logs=None):
+        self.history = {'loss': []}
+
+    def on_batch_end(self, batch, logs=None):
+        if logs is None:
+            logs = {}
+        self.history['loss'].append(logs.get('loss'))
+
+
 class SonusAIModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
     def __init__(self,
                  filepath,
@@ -80,6 +96,7 @@ def main():
     batch_size = args['--batch']
     timesteps = args['--tsteps']
     esp = args['--patience']
+    loss_batch_log = args['--loss-batch-log']
     t_name = args['TLOC']
 
     import warnings
@@ -108,6 +125,7 @@ def main():
     from sonusai.utils import import_keras_model
     from sonusai.utils import stratified_shuffle_split_mixid
     from sonusai.utils import reshape_outputs
+    from sonusai.utils import get_frames_per_batch
 
     model_base = basename(model_name)
     model_root = splitext(model_base)[0]
@@ -143,15 +161,18 @@ def main():
 
     # Check overrides
     timesteps = check_keras_overrides(model, t_mixdb.feature, t_mixdb.num_classes, timesteps, batch_size)
+    # Calculate batches per epoch, use ceiling as last batch is zero extended
+    frames_per_batch = get_frames_per_batch(batch_size, timesteps)
+    batches_per_epoch = int(np.ceil(t_mixdb.total_feature_frames('*') / frames_per_batch))
 
-    logger.info('Building model')
+    logger.info('Building and compiling model')
     try:
         hypermodel = model.MyHyperModel(feature=t_mixdb.feature,
                                         num_classes=t_mixdb.num_classes,
                                         timesteps=timesteps,
                                         batch_size=batch_size)
         built_model = hypermodel.build_model(kt.HyperParameters())
-        built_model = hypermodel.compile_default(built_model)
+        built_model = hypermodel.compile_default(built_model, batches_per_epoch)
     except Exception as e:
         logger.exception(f'Error: build_model() in {model_base} failed: {e}')
         raise SystemExit(1)
@@ -225,6 +246,15 @@ def main():
                                              feature=hypermodel.feature,
                                              num_classes=hypermodel.num_classes)
 
+    csv_logger = tf.keras.callbacks.CSVLogger(base_name + '-history.csv')
+    callbacks = [es, ckpt_callback, csv_logger]
+    # loss_batch_log = True
+    loss_batchlogger = None
+    if loss_batch_log is True:
+        loss_batchlogger = LossBatchHistory()
+        callbacks.append(loss_batchlogger)
+        logger.info(f'Adding per batch loss logging to training')
+
     if weights_name is not None:
         logger.info(f'Loading weights from {weights_name}')
         built_model.load_weights(weights_name)
@@ -240,13 +270,17 @@ def main():
                              epochs=epochs,
                              validation_data=v_datagen,
                              shuffle=False,
-                             callbacks=[es, ckpt_callback])
+                             callbacks=callbacks)
 
     # Save history into numpy file
     history_name = base_name + '-history'
     np.save(history_name, history.history)
     # Note: Reload with history=np.load(history_name, allow_pickle='TRUE').item()
     logger.info(f'Saved training history to numpy file {history_name}.npy')
+    if loss_batch_log is True:
+        his_batch_loss_name = base_name + '-history-lossb.npy'
+        np.save(his_batch_loss_name, loss_batchlogger.history)
+        logger.info(f'Saved per-batch loss history to numpy file {his_batch_loss_name}')
 
     # Find checkpoint file and load weights for prediction and model save
     checkpoint_name = None
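
Note: with the new -g/--loss-batch-log option, training writes a per-batch loss file (<basename>-history-lossb.npy) alongside the epoch-level history (<basename>-history.npy plus a CSV from CSVLogger). A short sketch of loading both back, assuming a hypothetical base name of 'model'; the file-name patterns follow the diff, the rest is illustrative:

    import numpy as np

    # Epoch-level Keras history saved via np.save(history_name, history.history)
    history = np.load('model-history.npy', allow_pickle='TRUE').item()

    # Per-batch loss dictionary written when --loss-batch-log is enabled
    batch_history = np.load('model-history-lossb.npy', allow_pickle='TRUE').item()

    print('epochs recorded: ', len(history['loss']))
    print('batches recorded:', len(batch_history['loss']))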