nkululeko 0.94.3__py3-none-any.whl → 0.95.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/glob_conf.py CHANGED
@@ -1,5 +1,14 @@
1
1
  # glob_conf.py
2
2
 
3
+ # Initialize global variables
4
+ config = None
5
+ label_encoder = None
6
+ util = None
7
+ module = None
8
+ report = None
9
+ labels = None
10
+ target = None
11
+
3
12
 
4
13
  def init_config(config_obj):
5
14
  global config
nkululeko/modelrunner.py CHANGED
@@ -279,9 +279,18 @@ class Modelrunner:
279
279
  self.util.debug(
280
280
  f"balanced with: {balancing}, new size: {X_res.shape[0]} (was {orig_size})"
281
281
  )
282
- le = glob_conf.label_encoder
283
- res = y_res.value_counts()
284
- resd = {}
285
- for i, e in enumerate(le.inverse_transform(res.index.values)):
286
- resd[e] = res.values[i]
287
- self.util.debug(f"{resd})")
282
+ # Check if label encoder is available before using it
283
+ if (
284
+ hasattr(glob_conf, "label_encoder")
285
+ and glob_conf.label_encoder is not None
286
+ ):
287
+ le = glob_conf.label_encoder
288
+ res = y_res.value_counts()
289
+ resd = {}
290
+ for i, e in enumerate(le.inverse_transform(res.index.values)):
291
+ resd[e] = res.values[i]
292
+ self.util.debug(f"class distribution after balancing: {resd}")
293
+ else:
294
+ self.util.debug(
295
+ "Label encoder not available, skipping class distribution report"
296
+ )
nkululeko/models/model.py CHANGED
@@ -3,15 +3,11 @@ import ast
3
3
  import pickle
4
4
  import random
5
5
 
6
- from joblib import parallel_backend
7
6
  import numpy as np
8
7
  import pandas as pd
9
- from sklearn.model_selection import GridSearchCV
10
- from sklearn.model_selection import LeaveOneGroupOut
11
- from sklearn.model_selection import StratifiedKFold
12
8
  import sklearn.utils
13
-
14
- import audeer
9
+ from joblib import parallel_backend
10
+ from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut, StratifiedKFold
15
11
 
16
12
  import nkululeko.glob_conf as glob_conf
17
13
  from nkululeko.reporting.reporter import Reporter
@@ -305,15 +301,8 @@ class Model:
305
301
  def get_type(self):
306
302
  return "generic"
307
303
 
308
- def predict_sample(self, features: np.ndarray) -> dict | float:
309
- """Predict a single sample using the trained model.
310
-
311
- Args:
312
- features (np.ndarray): The feature vector of the sample to predict.
313
-
314
- Returns:
315
- dict: A dictionary containing the predicted class probabilities or value.
316
- """
304
+ def predict_sample(self, features):
305
+ """Predict one sample"""
317
306
  prediction = {}
318
307
  if self.util.exp_is_classification():
319
308
  # get the class probabilities
@@ -347,30 +336,3 @@ class Model:
347
336
  self.set_id(run, epoch)
348
337
  with open(path, "rb") as handle:
349
338
  self.clf = pickle.load(handle)
350
-
351
- # next function exports the model to onnx
352
- def export_onnx(self, onnx_path, input_shape=None):
353
- """Export the trained sklearn model to ONNX format.
354
-
355
- Args:
356
- onnx_path (str): Path to save the ONNX model.
357
- input_shape (tuple, optional): Shape of the input features. If None, inferred from feats_train.
358
- """
359
- import skl2onnx
360
- from skl2onnx import convert_sklearn
361
- from skl2onnx.common.data_types import FloatTensorType
362
-
363
- if not hasattr(self, "clf"):
364
- self.util.error("No trained model found to export.")
365
- return
366
-
367
- if input_shape is None:
368
- n_features = self.feats_train.shape[1]
369
- initial_type = [("input", FloatTensorType([None, n_features]))]
370
- else:
371
- initial_type = [("input", FloatTensorType(input_shape))]
372
-
373
- onnx_model = convert_sklearn(self.clf, initial_types=initial_type)
374
- with open(audeer.path(onnx_path), "wb") as f:
375
- f.write(onnx_model.SerializeToString())
376
- self.util.debug(f"Model exported to ONNX at {onnx_path}")