nkululeko 0.86.2__py3-none-any.whl → 0.86.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.86.2"
1
+ VERSION="0.86.3"
2
2
  SAMPLING_RATE = 16000
@@ -35,6 +35,7 @@ class ImportSet(Featureset):
35
35
  if not os.path.isfile(feat_import_file):
36
36
  self.util.error(f"no import file: {feat_import_file}")
37
37
  df = audformat.utils.read_csv(feat_import_file)
38
+ df = self.util.make_segmented_index(df)
38
39
  df = df[df.index.isin(self.data_df.index)]
39
40
  feat_df = pd.concat([feat_df, df])
40
41
  if feat_df.shape[0] == 0:
@@ -62,11 +62,13 @@ class TunedModel(BaseModel):
62
62
  if drop:
63
63
  self.drop = float(drop)
64
64
  self.util.debug(f"init: training with dropout: {self.drop}")
65
+ self.push = eval(self.util.config_val("MODEL", "push_to_hub", "False"))
65
66
  self._init_model()
66
67
 
67
68
  def _init_model(self):
68
69
  model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
69
- pretrained_model = self.util.config_val("MODEL", "pretrained_model", model_path)
70
+ pretrained_model = self.util.config_val(
71
+ "MODEL", "pretrained_model", model_path)
70
72
  self.num_layers = None
71
73
  self.sampling_rate = 16000
72
74
  self.max_duration_sec = self.max_duration
@@ -131,6 +133,11 @@ class TunedModel(BaseModel):
131
133
  tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
132
134
  tokenizer.save_pretrained(".")
133
135
 
136
+ # upload tokenizer to hub if true
137
+ if self.push:
138
+ tokenizer.push_to_hub(self.util.get_name())
139
+
140
+
134
141
  feature_extractor = transformers.Wav2Vec2FeatureExtractor(
135
142
  feature_size=1,
136
143
  sampling_rate=16000,
@@ -273,7 +280,7 @@ class TunedModel(BaseModel):
273
280
  self.util.error(f"criterion {criterion} not supported for regressor")
274
281
 
275
282
  # set push_to_hub value, default false
276
- push = eval(self.util.config_val("MODEL", "push_to_hub", "False"))
283
+ # push = eval(self.util.config_val("MODEL", "push_to_hub", "False"))
277
284
 
278
285
  class Trainer(transformers.Trainer):
279
286
  def compute_loss(
@@ -312,7 +319,7 @@ class TunedModel(BaseModel):
312
319
  self.util.error(f"unknown metric/measure: {metrics_for_best_model}")
313
320
 
314
321
  training_args = transformers.TrainingArguments(
315
- output_dir=model_root,
322
+ output_dir=self.torch_root,
316
323
  logging_dir=self.log_root,
317
324
  per_device_train_batch_size=self.batch_size,
318
325
  per_device_eval_batch_size=self.batch_size,
@@ -331,7 +338,7 @@ class TunedModel(BaseModel):
331
338
  load_best_model_at_end=True,
332
339
  remove_unused_columns=False,
333
340
  report_to="none",
334
- push_to_hub=push,
341
+ push_to_hub=self.push,
335
342
  hub_model_id=f"{self.util.get_name()}",
336
343
  )
337
344
 
@@ -347,7 +354,7 @@ class TunedModel(BaseModel):
347
354
  )
348
355
 
349
356
  trainer.train()
350
- trainer.save_model(self.torch_root)
357
+ # trainer.save_model(self.torch_root)
351
358
  log_file = os.path.join(
352
359
  self.log_root,
353
360
  "log.txt",
@@ -517,7 +524,10 @@ class Model(Wav2Vec2PreTrainedModel):
517
524
  )
518
525
  outputs = torch.sum(hidden_states, dim=1)
519
526
  attention_sum = torch.sum(attention_mask, dim=1)
520
- outputs = outputs / torch.reshape(attention_sum, (-1, 1))
527
+
528
+ epsilon = 1e-6 # to avoid division by zero and numerical instability
529
+ outputs = outputs / (torch.reshape(attention_sum, (-1, 1)) +
530
+ epsilon)
521
531
 
522
532
  return outputs
523
533
 
nkululeko/utils/util.py CHANGED
@@ -35,9 +35,9 @@ class Util:
35
35
  if has_config:
36
36
  try:
37
37
  import nkululeko.glob_conf as glob_conf
38
+
38
39
  self.config = glob_conf.config
39
- self.got_data_roots = self.config_val(
40
- "DATA", "root_folders", False)
40
+ self.got_data_roots = self.config_val("DATA", "root_folders", False)
41
41
  if self.got_data_roots:
42
42
  # if there is a global data rootfolder file, read from there
43
43
  if not os.path.isfile(self.got_data_roots):
@@ -116,8 +116,7 @@ class Util:
116
116
  )
117
117
  return default
118
118
  if not default in self.stopvals:
119
- self.debug(
120
- f"value for {key} not found, using default: {default}")
119
+ self.debug(f"value for {key} not found, using default: {default}")
121
120
  return default
122
121
 
123
122
  def set_config(self, config):
@@ -160,8 +159,8 @@ class Util:
160
159
  if len(df) == 0:
161
160
  return df
162
161
  if not isinstance(df.index, pd.MultiIndex):
163
- df.index = audformat.utils.to_segmented_index(
164
- df.index, allow_nat=False)
162
+ self.debug("converting to segmented index, this might take a while...")
163
+ df.index = audformat.utils.to_segmented_index(df.index, allow_nat=False)
165
164
  return df
166
165
 
167
166
  def _get_value_descript(self, section, name):
@@ -272,8 +271,7 @@ class Util:
272
271
  return self.config[section][key]
273
272
  except KeyError:
274
273
  if default not in self.stopvals:
275
- self.debug(
276
- f"value for {key} not found, using default: {default}")
274
+ self.debug(f"value for {key} not found, using default: {default}")
277
275
  return default
278
276
 
279
277
  def config_val_list(self, section, key, default):
@@ -281,8 +279,7 @@ class Util:
281
279
  return ast.literal_eval(self.config[section][key])
282
280
  except KeyError:
283
281
  if not default in self.stopvals:
284
- self.debug(
285
- f"value for {key} not found, using default: {default}")
282
+ self.debug(f"value for {key} not found, using default: {default}")
286
283
  return default
287
284
 
288
285
  def continuous_to_categorical(self, series):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.86.2
3
+ Version: 0.86.3
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -334,6 +334,11 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
334
334
  Changelog
335
335
  =========
336
336
 
337
+ Version 0.86.3
338
+ --------------
339
+ * bug fix: NaN in fine-tuned model and double saving
340
+ * import features now get multiindex automatically
341
+
337
342
  Version 0.86.2
338
343
  --------------
339
344
  * plots epoch progression for finetuned models now
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
3
3
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
4
4
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
5
- nkululeko/constants.py,sha256=E9mXpvAI5IDamRRXgBlBH8XGTw1xEjEBzNibjhFPEFc,39
5
+ nkululeko/constants.py,sha256=2ysebEFzu3zwO0-FXWf2pBOs8XRLPmy718GFrZ2O9pU,39
6
6
  nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
7
7
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
8
8
  nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
@@ -55,7 +55,7 @@ nkululeko/feat_extract/feats_auddim.py,sha256=VlzKKXTXa5kjLgQBWyEFy-daIyU1SkOwCC
55
55
  nkululeko/feat_extract/feats_audmodel.py,sha256=VjBNgAoxsHJhwr6Kwt9CxX6SaCM4RK_OV-GU2W5-bhU,3187
56
56
  nkululeko/feat_extract/feats_clap.py,sha256=nR6eEIRdsMHcfmD1bNtt5WfDvkxKjvEbukSSrXHm-HU,3489
57
57
  nkululeko/feat_extract/feats_hubert.py,sha256=cLoUzSLjSYBkQnftjacSL7ES3O7Ysh_KrPYvZtLX_TU,5196
58
- nkululeko/feat_extract/feats_import.py,sha256=rj1p8lz19tCAC8hLzzZAwZ0M6gzwH3BzfabFUgal0yw,1622
58
+ nkululeko/feat_extract/feats_import.py,sha256=WiU5lCkJsmFNTDyPV0qIh8mJssa6bpgP7AYw_ClKfWM,1674
59
59
  nkululeko/feat_extract/feats_mld.py,sha256=Vvu7GZOkn7Vda8eIOXqHjg78zegkFe3vTUaCXyVM0eA,2021
60
60
  nkululeko/feat_extract/feats_mos.py,sha256=KXNt7QYEfxkvr6UyVhig2aWQBaIvovlrR4gPuP03gmo,4174
61
61
  nkululeko/feat_extract/feats_opensmile.py,sha256=g6ZsAxjjGGvGfrr5fngWC-NJ8E7CP1kYZwrlodZJzzU,4028
@@ -88,7 +88,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
88
88
  nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
89
89
  nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
90
90
  nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
91
- nkululeko/models/model_tuned.py,sha256=xOoY5TROzzTVu3sDtlmEle3V1MAgpf8S3WxO9o4MzV4,20777
91
+ nkululeko/models/model_tuned.py,sha256=RDcvcejBQNGY_uW00r22i7EDT6oKchS5uqFFnj0Gtzg,21146
92
92
  nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
93
93
  nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
94
94
  nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -104,9 +104,9 @@ nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5Ev
104
104
  nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
105
105
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
106
106
  nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
107
- nkululeko/utils/util.py,sha256=mK1MgO14NinrPhavJw72eR_2WN_kBKjVKiEJnzvdO1Q,13946
108
- nkululeko-0.86.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
109
- nkululeko-0.86.2.dist-info/METADATA,sha256=DmmpMrftBptpWqx7h9US7_4mvMIQbZ5ugzv_4kyBjkM,37170
110
- nkululeko-0.86.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
111
- nkululeko-0.86.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
112
- nkululeko-0.86.2.dist-info/RECORD,,
107
+ nkululeko/utils/util.py,sha256=PcyAuCGgGuxjlv-e4JrVbpewiRTiAXWk47w5X0dVgx8,13930
108
+ nkululeko-0.86.3.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
109
+ nkululeko-0.86.3.dist-info/METADATA,sha256=Nnb3gRWEI1DSqf8KpaD8CDqdkHyiKdv-j9HpN4jjeks,37305
110
+ nkululeko-0.86.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
111
+ nkululeko-0.86.3.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
112
+ nkululeko-0.86.3.dist-info/RECORD,,