nkululeko 0.95.9__py3-none-any.whl → 0.96.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.95.9"
1
+ VERSION="0.96.0"
2
2
  SAMPLING_RATE = 16000
@@ -0,0 +1,105 @@
1
+ import os
2
+
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+ import transformers
6
+ import torch
7
+ from transformers import BertTokenizer, BertModel
8
+
9
+ from nkululeko.feat_extract.featureset import Featureset
10
+ import nkululeko.glob_conf as glob_conf
11
+
12
+
13
class Bert(Featureset):
    """Extract text embeddings with a BERT model.

    Every row of ``data_df`` must provide a ``text`` column. The resulting
    feature frame has the same index as ``data_df``. Each row holds the
    token-mean-pooled hidden state of the selected BERT layer.
    """

    def __init__(self, name, data_df, feat_type):
        """Set up the extractor; the model itself is loaded lazily.

        Args:
            name: Name of the feature set, also used for the cache file name.
            data_df: DataFrame whose rows are embedded; needs a ``text`` column.
            feat_type: ``"bert"`` (mapped to ``bert-base-uncased``) or an
                explicit model name such as ``"bert-large-uncased"``.
        """
        super().__init__(name, data_df, feat_type)
        default_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = self.util.config_val("MODEL", "device", default_device)
        self.model_initialized = False
        if feat_type == "bert":
            self.feat_type = "bert-base-uncased"
        else:
            self.feat_type = feat_type

    @staticmethod
    def _as_bool(value):
        """Interpret a config value (bool, or text such as "True"/"false") as bool."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes", "on")
        return bool(value)

    def init_model(self):
        """Load the tokenizer and model, truncating the encoder at the chosen layer.

        ``[FEATS] bert.model`` selects the checkpoint (default
        ``google-bert/<feat_type>``). ``[FEATS] bert.layer`` is the number of
        top encoder layers to drop (default 0, i.e. use the last layer).
        """
        self.util.debug(f"loading {self.feat_type} model...")
        model_path = self.util.config_val(
            "FEATS", "bert.model", f"google-bert/{self.feat_type}"
        )
        config = transformers.AutoConfig.from_pretrained(model_path)
        layer_num = config.num_hidden_layers
        hidden_layer = int(self.util.config_val("FEATS", "bert.layer", "0"))
        if not 0 <= hidden_layer < layer_num:
            # Dropping all (or a negative number of) layers is meaningless.
            self.util.error(
                f"bert.layer must be between 0 and {layer_num - 1}, got {hidden_layer}"
            )
        # Building the model with only the first (layer_num - hidden_layer)
        # encoder layers makes its last hidden state the requested layer.
        config.num_hidden_layers = layer_num - hidden_layer
        self.util.debug(f"using hidden layer #{config.num_hidden_layers}")

        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertModel.from_pretrained(model_path, config=config).to(
            self.device
        )
        print(f"initialized {self.feat_type} model on {self.device}")
        self.model.eval()
        self.model_initialized = True

    def extract(self):
        """Compute the embeddings for ``data_df`` or load them from the cache.

        Extraction runs when ``[FEATS] needs_feature_extraction`` or
        ``[FEATS] no_reuse`` is true, or when no cached pickle exists.
        Otherwise the pickle in the ``store`` directory is reused. The result
        is stored in ``self.df``.
        """
        store = self.util.get_path("store")
        storage = os.path.join(store, f"{self.name}.pkl")
        # Config values arrive as text: parse them explicitly. A bare
        # truthiness test treats "false" as true, and eval() would execute
        # arbitrary config content.
        extract = self._as_bool(
            self.util.config_val("FEATS", "needs_feature_extraction", False)
        )
        no_reuse = self._as_bool(self.util.config_val("FEATS", "no_reuse", "False"))
        if extract or no_reuse or not os.path.isfile(storage):
            if not self.model_initialized:
                self.init_model()
            self.util.debug(
                f"extracting {self.feat_type} embeddings, this might take a while..."
            )
            # The first index level is used as the file name in error messages
            # (assumes a (file, start, end)-style MultiIndex, as the original did).
            embeddings = [
                self.get_embeddings(text, idx[0])
                for idx, text in tqdm(
                    zip(self.data_df.index, self.data_df["text"]),
                    total=len(self.data_df),
                )
            ]
            self.df = pd.DataFrame(embeddings, index=self.data_df.index)
            self.df.to_pickle(storage)
            try:
                # NOTE(review): the flag is read from [FEATS] above but reset
                # in [DATA] here; confirm which section downstream code expects.
                glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
            except KeyError:
                pass
        else:
            self.util.debug(f"reusing extracted {self.feat_type} embeddings")
            self.df = pd.read_pickle(storage)
        # Check both fresh and cached features; a failed extraction leaves gaps.
        if self.df.isnull().values.any():
            self.util.error(
                f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
            )

    def get_embeddings(self, text, file):
        """Return the mean-pooled BERT embedding of ``text``.

        Args:
            text: Input string to embed. It is truncated to the tokenizer's
                maximum length.
            file: Identifier used only in the error message.

        Returns:
            1-D numpy array with one value per hidden unit. Returns None if
            the model raised a RuntimeError; the failure is reported via
            ``util.error`` first.
        """
        try:
            with torch.no_grad():
                # The inputs must be on the model's device; otherwise a CUDA
                # model fails on CPU tensors. Truncation keeps long texts
                # within BERT's position-embedding limit.
                inputs = self.tokenizer(
                    text, return_tensors="pt", truncation=True
                ).to(self.device)
                outputs = self.model(**inputs)
                # Mean pooling over the token axis of the last hidden state.
                y = torch.mean(outputs[0], dim=1).ravel()
        except RuntimeError as err:
            print(str(err))
            self.util.error(f"couldn't extract file: {file}")
            return None
        return y.detach().cpu().numpy()

    def extract_sample(self, text):
        """Embed a single text, loading the model only on first use."""
        if not self.model_initialized:
            self.init_model()
        return self.get_embeddings(text, "no file")
@@ -80,7 +80,7 @@ class FeatureExtractor:
80
80
  return MLD_set
81
81
 
82
82
  elif feats_type.startswith(
83
- ("wav2vec2", "hubert", "wavlm", "spkrec", "whisper", "ast", "emotion2vec")
83
+ ("bert", "wav2vec2", "hubert", "wavlm", "spkrec", "whisper", "ast", "emotion2vec")
84
84
  ):
85
85
  return self._get_feat_extractor_by_prefix(feats_type)
86
86
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nkululeko
3
- Version: 0.95.9
3
+ Version: 0.96.0
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -4,7 +4,7 @@ nkululeko/aug_train.py,sha256=wpiHCJ7zsW38kumg3ypwXZe2HQrhUblAnv7P2QeJnAc,3525
4
4
  nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
5
5
  nkululeko/balance.py,sha256=r7opXbrqAipm2euPPaOmLlA5J10p2bHQgO5kWk2x9ro,8702
6
6
  nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
7
- nkululeko/constants.py,sha256=t_C_hQqVC1idXJB6HHr1m7ZtCYC5JVvqhYrVLRhzwIw,39
7
+ nkululeko/constants.py,sha256=dv32YQyIGpLW6TD10egbPsxz2wdIYaw29xUA2K28xt4,39
8
8
  nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
9
9
  nkululeko/demo.py,sha256=tu7Al2l5MCLVegkDC-NE2wcuc_YE7NRbgOlPW3yhGEs,4940
10
10
  nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
@@ -13,7 +13,7 @@ nkululeko/ensemble.py,sha256=71V-rre61H3J4sh7lu-OTo4I2_g7mm_rQxwW1ARDHgY,12782
13
13
  nkululeko/experiment.py,sha256=TG9G9kSETT_R8d92aRKMMsb0HRGyM_GBFHBsU9A6ppw,38633
14
14
  nkululeko/explore.py,sha256=PjNcLuPdvWqCqYXUvGhd0hBijIhzdyi3ED1RF6o5Gjk,4212
15
15
  nkululeko/export.py,sha256=U-V4acxtuL6qKt6oAsVcM5TTeWogYUJ3GU-lA6rq6d4,4336
16
- nkululeko/feature_extractor.py,sha256=CsKmBoxwNClRGu20ox_eCxMG4u_1OH8Y83FYw7GfUwA,4230
16
+ nkululeko/feature_extractor.py,sha256=d3G42OOh315Aho-yLaFT739U0UI8otiB1I4ZksK8kfg,4238
17
17
  nkululeko/file_checker.py,sha256=xJY0Q6w47pnmgJVK5rcAKPYBrCpV7eBT4_3YBzTx-H8,3454
18
18
  nkululeko/filter_data.py,sha256=4sGrKvMZ_hLnJPrHm_CqjDPKIRV8REWoT7nfSYGXbwo,7305
19
19
  nkululeko/fixedsegment.py,sha256=Tb92QiuiyMsOO3WRWwuGjZGibS8hbHHCrcWAXGk7g04,2868
@@ -69,6 +69,7 @@ nkululeko/feat_extract/feats_analyser.py,sha256=lodim7qQ8M7c3iMeJ5bHQ-nCy9Cehx1X
69
69
  nkululeko/feat_extract/feats_ast.py,sha256=w62xEoLiFtU-rj6SXkqXAktmoFaXcAcAWpUyEjp8JWo,4652
70
70
  nkululeko/feat_extract/feats_auddim.py,sha256=CGLp_aYhudfwoU5522vjrvjPxfZcyw593A8xLjYefV8,3134
71
71
  nkululeko/feat_extract/feats_audmodel.py,sha256=OsZyB1rdcG0Fai2gAwBlbuubmWor1_-P4IDkZLqgPKE,3161
72
+ nkululeko/feat_extract/feats_bert.py,sha256=KgWLYLA11e86ubY1KtUk74QGOrZaiUongn-2LKWyf1M,4114
72
73
  nkululeko/feat_extract/feats_clap.py,sha256=1tttpfm2SJmQgYm2u8eUVpDiDOpWdKqFChpY3ZZokNs,3395
73
74
  nkululeko/feat_extract/feats_emotion2vec.py,sha256=LnV8xEg7L7HIDqz0ulqUNoaAHBU0d5gyQPb2_32T_18,9694
74
75
  nkululeko/feat_extract/feats_hubert.py,sha256=F3vrPCkx8EimJjFWYCZ7Yg9uo1G3NjYt4UKrGIUev8k,5172
@@ -136,9 +137,9 @@ nkululeko/utils/files.py,sha256=SrrYaU7AB80MZHiV1jcB0h_zigvYLYgSVNTXV4ao38g,4593
136
137
  nkululeko/utils/stats.py,sha256=3Fyx8q8BSKYmiufT6OkRug9RATWmGrr9BaX_y8jziWo,3074
137
138
  nkululeko/utils/unzip.py,sha256=G68f5120TjwACZC3bQcneMniddnwubPbBdMc2L5KBOo,1206
138
139
  nkululeko/utils/util.py,sha256=s7Hd7Ju1r3_WCw8gLD9YK4O6k3S_WhFcN2-XZBSctSM,18705
139
- nkululeko-0.95.9.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
140
- nkululeko-0.95.9.dist-info/METADATA,sha256=WhITXnJHYD5GhyATjEb7kJhmMecWRu-BeMBw7pSWNdc,21998
141
- nkululeko-0.95.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
142
- nkululeko-0.95.9.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
143
- nkululeko-0.95.9.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
144
- nkululeko-0.95.9.dist-info/RECORD,,
140
+ nkululeko-0.96.0.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
141
+ nkululeko-0.96.0.dist-info/METADATA,sha256=uwhWqKjvfyvV6UHK4v_2H7WusIw6Bei_7i03RWVHWHE,21998
142
+ nkululeko-0.96.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
143
+ nkululeko-0.96.0.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
144
+ nkululeko-0.96.0.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
145
+ nkululeko-0.96.0.dist-info/RECORD,,