sdg-core-lib 0.1.8.dev1__tar.gz → 0.1.8.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/PKG-INFO +1 -1
  2. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/pyproject.toml +1 -1
  3. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/CTGANComponents.py +33 -24
  4. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/implementation/CTGAN.py +51 -15
  5. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/steps.py +14 -23
  6. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/README.md +0 -0
  7. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/__init__.py +0 -0
  8. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/browser.py +0 -0
  9. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/commons.py +0 -0
  10. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/__init__.py +0 -0
  11. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/__init__.py +0 -0
  12. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/implementation/__init__.py +0 -0
  13. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/ModelInfo.py +0 -0
  14. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/TrainingInfo.py +0 -0
  15. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/UnspecializedModel.py +0 -0
  16. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/KerasBaseVAE.py +0 -0
  17. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/VAE.py +0 -0
  18. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/__init__.py +0 -0
  19. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/AutoTabularVAE.py +0 -0
  20. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/TabularVAE.py +0 -0
  21. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/TimeSeriesVAE.py +0 -0
  22. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/__init__.py +0 -0
  23. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/__init__.py +0 -0
  24. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/__init__.py +0 -0
  25. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/columns.py +0 -0
  26. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/datasets.py +0 -0
  27. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/__init__.py +0 -0
  28. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/base_evaluator.py +0 -0
  29. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/metrics.py +0 -0
  30. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/tables.py +0 -0
  31. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/time_series.py +0 -0
  32. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/job.py +0 -0
  33. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/mappings.py +0 -0
  34. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/FunctionApplier.py +0 -0
  35. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/__init__.py +0 -0
  36. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/function_factory.py +0 -0
  37. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/function_utils.py +0 -0
  38. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/Parameter.py +0 -0
  39. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/UnspecializedFunction.py +0 -0
  40. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/__init__.py +0 -0
  41. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/__init__.py +0 -0
  42. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/implementation/NormalTester.py +0 -0
  43. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/implementation/__init__.py +0 -0
  44. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/IntervalThreshold.py +0 -0
  45. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/MonoThreshold.py +0 -0
  46. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/__init__.py +0 -0
  47. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/InnerThreshold.py +0 -0
  48. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/LowerThreshold.py +0 -0
  49. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/OuterThreshold.py +0 -0
  50. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/UpperThreshold.py +0 -0
  51. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/__init__.py +0 -0
  52. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/__init__.py +0 -0
  53. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/LinearFunction.py +0 -0
  54. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/NormalDistributionSample.py +0 -0
  55. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/QuadraticFunction.py +0 -0
  56. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/SinusoidalFunction.py +0 -0
  57. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/__init__.py +0 -0
  58. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/__init__.py +0 -0
  59. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/BurstNoiseAdder.py +0 -0
  60. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/WhiteNoiseAdder.py +0 -0
  61. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/__init__.py +0 -0
  62. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/__init__.py +0 -0
  63. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/base_processor.py +0 -0
  64. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/__init__.py +0 -0
  65. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/base_strategy.py +0 -0
  66. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/ctgan_strategy.py +0 -0
  67. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/vae_strategy.py +0 -0
  68. {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/table_processor.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sdg-core-lib
3
- Version: 0.1.8.dev1
3
+ Version: 0.1.8.dev2
4
4
  Summary: Add your description here
5
5
  Author: emiliocimino
6
6
  Author-email: emiliocimino <emilio.cimino@outlook.it>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sdg-core-lib"
3
- version = "0.1.8.dev1"
3
+ version = "0.1.8.dev2"
4
4
  description = "Add your description here"
5
5
  license = "AGPL-3.0"
6
6
  readme = "README.md"
@@ -185,8 +185,8 @@ class CTGANModel(keras.Model):
185
185
  def metrics(self):
186
186
  return [self.gen_loss_tracker, self.critic_loss_tracker]
187
187
 
188
- @tf.function
189
188
  def generate_batch_cond(self, batch_size):
189
+ batch_size = int(batch_size) # Convert symbolic tensor to int
190
190
  num_cats = len(self.generator.cats_disc)
191
191
  total_cond_dim = sum(self.generator.cats_disc)
192
192
  cats_disc = tf.convert_to_tensor(self.generator.cats_disc, dtype=tf.int32)
@@ -195,25 +195,22 @@ class CTGANModel(keras.Model):
195
195
  shape=[batch_size], minval=0, maxval=num_cats, dtype=tf.int32
196
196
  )
197
197
 
198
- relevant_pmfs = tf.gather(self.probability_mass_function_list, col_indices)
199
-
200
- cat_indices = tf.random.categorical(tf.math.log(relevant_pmfs), num_samples=1)
201
- cat_indices = tf.cast(tf.squeeze(cat_indices, axis=1), tf.int32)
202
-
203
- offsets_table = tf.concat([[0], tf.cumsum(cats_disc)[:-1]], axis=0)
204
- batch_offsets = tf.gather(offsets_table, col_indices)
205
-
206
- global_hot_indices = batch_offsets + cat_indices
207
- row_indices = tf.range(batch_size)
208
- scatter_indices = tf.stack([row_indices, global_hot_indices], axis=1)
209
-
210
- cond_batch = tf.scatter_nd(
211
- indices=scatter_indices,
212
- updates=tf.ones([batch_size], dtype=tf.float32),
213
- shape=[batch_size, total_cond_dim],
214
- )
215
-
216
- return cond_batch
198
+ # Create condition vector by sampling from PMFs and creating one-hot encoding
199
+ condition_list = []
200
+ for i in range(batch_size):
201
+ idx = col_indices[i].numpy()
202
+ pmf = self.probability_mass_function_list[idx].numpy().flatten()
203
+ # Sample from PMF
204
+ cat_idx = np.random.choice(len(pmf), p=pmf)
205
+ # Create one-hot vector for total_cond_dim
206
+ one_hot = np.zeros(total_cond_dim)
207
+ # Calculate offset for this categorical variable
208
+ offset = sum(self.generator.cats_disc[:idx])
209
+ one_hot[offset + cat_idx] = 1.0
210
+ condition_list.append(one_hot)
211
+
212
+ cond = tf.convert_to_tensor(condition_list, dtype=tf.float32)
213
+ return cond
217
214
 
218
215
  @staticmethod
219
216
  @tf.function
@@ -291,9 +288,12 @@ class CTGANModel(keras.Model):
291
288
  pmfs = []
292
289
  curr = 0
293
290
  for sz in self.generator.cats_disc:
294
- chunk = onehot_all[:, curr : curr + sz]
295
- log_freqs = tf.math.log(tf.reduce_sum(chunk, axis=0) + 1.0)
296
- pmfs.append(log_freqs / tf.reduce_sum(log_freqs))
291
+ chunk = onehot_all[:, curr : curr + sz] # (N_row, cats)
292
+ chunk_np = chunk.numpy()
293
+ log_freqs = np.log(np.sum(chunk_np, axis=0) + 1.0).reshape(1, -1)
294
+ pmfs.append(
295
+ tf.convert_to_tensor(log_freqs / np.sum(log_freqs), dtype=tf.float32)
296
+ )
297
297
  curr += sz
298
298
  return pmfs
299
299
 
@@ -301,7 +301,16 @@ class CTGANModel(keras.Model):
301
301
  batch = ops.shape(data)[0]
302
302
  self.row_dim = ops.shape(data)[1]
303
303
  z = tf.random.normal([batch, self.row_dim - sum(self.generator.cats_disc)])
304
- cond = self.generate_batch_cond(batch)
304
+
305
+ # Use tf.py_function to call generate_batch_cond in eager mode
306
+ def generate_cond_eager(batch_size):
307
+ return self.generate_batch_cond(batch_size)
308
+
309
+ cond = tf.py_function(generate_cond_eager, [batch], tf.float32)
310
+ # Set the shape explicitly - it should be [batch_size, total_cond_dim]
311
+ total_cond_dim = sum(self.generator.cats_disc)
312
+ cond.set_shape([None, total_cond_dim])
313
+
305
314
  real_batch = CTGANModel.sample_real_data(
306
315
  self._train_data, cond, self.onehot_discrete_indexes
307
316
  )
@@ -11,6 +11,7 @@ from sdg_core_lib.data_generator.models.GANs.CTGANComponents import (
11
11
  CTGANModel,
12
12
  )
13
13
  import keras
14
+ import tensorflow as tf
14
15
  from sdg_core_lib.data_generator.models.TrainingInfo import TrainingInfo
15
16
  import numpy as np
16
17
 
@@ -18,7 +19,7 @@ import numpy as np
18
19
  class CTGAN(UnspecializedModel):
19
20
  def __init__(
20
21
  self,
21
- metadata: dict,
22
+ metadata: list[dict],
22
23
  model_name: str,
23
24
  input_shape: str = None,
24
25
  load_path: str = None,
@@ -31,7 +32,9 @@ class CTGAN(UnspecializedModel):
31
32
  gen_steps=4,
32
33
  critic_dropout=0.2,
33
34
  ):
34
- super().__init__(metadata, model_name, input_shape, load_path)
35
+ super().__init__(
36
+ self._clean_skeleton(metadata), model_name, input_shape, load_path
37
+ )
35
38
  self._batch_size = batch_size
36
39
  self._epochs = epochs
37
40
  self._gen_steps = gen_steps
@@ -43,10 +46,19 @@ class CTGAN(UnspecializedModel):
43
46
  self._instantiate()
44
47
 
45
48
  @staticmethod
46
- def infer_data_structure(skeleton):
49
+ def _clean_skeleton(skeleton):
50
+ if skeleton != [{}]:
51
+ return [
52
+ item
53
+ for item in skeleton
54
+ if item["feature_type"] in ["continuous", "categorical"]
55
+ ]
56
+ return skeleton
57
+
58
+ def infer_data_structure(self):
47
59
  cats, modes, idxs = [], [], []
48
60
  true_index = 0
49
- for col in skeleton:
61
+ for col in self._metadata:
50
62
  try:
51
63
  f_size = int(col["feature_size"])
52
64
  if col["feature_type"] == "categorical":
@@ -90,7 +102,7 @@ class CTGAN(UnspecializedModel):
90
102
  categories_per_discrete_column,
91
103
  modes_per_continuous_column,
92
104
  onehot_discrete_indexes,
93
- ) = CTGAN.infer_data_structure(self._metadata)
105
+ ) = self.infer_data_structure()
94
106
  self.generator = CTGANGenerator(
95
107
  self._metadata,
96
108
  modes_per_continuous_column,
@@ -106,7 +118,7 @@ class CTGAN(UnspecializedModel):
106
118
  # Should set the _model variable CTGAN Model complete with Generator and Critic
107
119
  # Does NOT return the model
108
120
  # self._metadata is available
109
- _, _, onehot_discrete_indexes = CTGAN.infer_data_structure(self._metadata)
121
+ _, _, onehot_discrete_indexes = self.infer_data_structure()
110
122
  critic = keras.saving.load_model(os.path.join(folder_path, "critic.keras"))
111
123
  generator = keras.saving.load_model(
112
124
  os.path.join(folder_path, "generator.keras")
@@ -114,11 +126,33 @@ class CTGAN(UnspecializedModel):
114
126
  self._model = CTGANModel(generator, critic, onehot_discrete_indexes)
115
127
 
116
128
  # Load probability_mass_function_list if it exists
117
- pmf_path = os.path.join(folder_path, "probability_mass_function_list.npy")
129
+ pmf_path = os.path.join(folder_path, "probability_mass_function_list.npz")
118
130
  if os.path.exists(pmf_path):
119
- self._model.probability_mass_function_list = np.load(
120
- pmf_path, allow_pickle=True
131
+ pmf_data = np.load(pmf_path)
132
+ # Convert back to list of TensorFlow tensors
133
+ pmf_list = []
134
+ for key in sorted(pmf_data.keys()):
135
+ pmf_list.append(tf.convert_to_tensor(pmf_data[key], dtype=tf.float32))
136
+ self._model.probability_mass_function_list = pmf_list
137
+ # Also check for old .npy format for backward compatibility
138
+ elif os.path.exists(
139
+ os.path.join(folder_path, "probability_mass_function_list.npy")
140
+ ):
141
+ pmf_list = np.load(
142
+ os.path.join(folder_path, "probability_mass_function_list.npy"),
143
+ allow_pickle=True,
121
144
  )
145
+ # Convert to TensorFlow tensors if needed
146
+ if (
147
+ isinstance(pmf_list, list)
148
+ and len(pmf_list) > 0
149
+ and isinstance(pmf_list[0], np.ndarray)
150
+ ):
151
+ self._model.probability_mass_function_list = [
152
+ tf.convert_to_tensor(pmf, dtype=tf.float32) for pmf in pmf_list
153
+ ]
154
+ else:
155
+ self._model.probability_mass_function_list = pmf_list
122
156
 
123
157
  def save(self, folder_path: str):
124
158
  keras.saving.save_model(
@@ -132,9 +166,13 @@ class CTGAN(UnspecializedModel):
132
166
  hasattr(self._model, "probability_mass_function_list")
133
167
  and self._model.probability_mass_function_list is not None
134
168
  ):
135
- np.save(
136
- os.path.join(folder_path, "probability_mass_function_list.npy"),
137
- self._model.probability_mass_function_list,
169
+ # Convert TensorFlow tensors to numpy arrays before saving
170
+ pmf_list = [
171
+ tensor.numpy() for tensor in self._model.probability_mass_function_list
172
+ ]
173
+ np.savez(
174
+ os.path.join(folder_path, "probability_mass_function_list.npz"),
175
+ *pmf_list,
138
176
  )
139
177
 
140
178
  def train(self, data: np.ndarray):
@@ -157,9 +195,7 @@ class CTGAN(UnspecializedModel):
157
195
  )
158
196
  self._model._train_data = data
159
197
  probability_mass_function_list = self._model.get_pmfs(data)
160
- self._model.probability_mass_function_list = keras.ops.convert_to_numpy(
161
- probability_mass_function_list
162
- )
198
+ self._model.probability_mass_function_list = probability_mass_function_list
163
199
  history = self._model.fit(
164
200
  data, batch_size=self._batch_size, epochs=self._epochs, verbose=1
165
201
  )
@@ -180,33 +180,32 @@ class PerModeNormalization(Step):
180
180
  if self.operator is None:
181
181
  raise ValueError("Operator not initialized")
182
182
  column = data.reshape(-1, 1)
183
- active_weights_indx = np.where(self.operator.weights_ > 0.01)
184
- weights = self.operator.weights_[active_weights_indx]
185
- means = self.operator.means_[active_weights_indx].flatten()
186
- stds = np.sqrt(self.operator.covariances_[active_weights_indx].flatten())
183
+ active_weights_indexes = np.where(self.operator.weights_ > 0.01)
184
+ weights = self.operator.weights_[active_weights_indexes]
185
+ means = self.operator.means_[active_weights_indexes].flatten()
186
+ stds = np.sqrt(self.operator.covariances_[active_weights_indexes].flatten())
187
187
  mixture_probability_density = []
188
- for w, m, s in zip(weights, means, stds):
188
+ for weight, mean, std in zip(weights, means, stds):
189
189
  mixture_probability_density.append(
190
- w
191
- * PerModeNormalization._gaussian_probability_density_function(
192
- column, m, s
193
- )
190
+ weight * self._gaussian_probability_density_function(column, mean, std)
194
191
  )
195
192
  marginal_mixture_probability_density = np.hstack(mixture_probability_density)
196
- responsibilities = PerModeNormalization._compute_responsibilities(
193
+ responsibilities = self._compute_responsibilities(
197
194
  marginal_mixture_probability_density
198
195
  )
199
196
  rng = np.random.default_rng(self.random_state)
200
- n, K = responsibilities.shape
197
+ n, k = responsibilities.shape
201
198
  sampled_mode = np.array(
202
- [rng.choice(K, p=responsibilities[i]) for i in range(n)]
199
+ [rng.choice(k, p=responsibilities[i]) for i in range(n)]
203
200
  )
204
- f = np.zeros((n, K), dtype=int)
205
- f[np.arange(n), sampled_mode] = 1
201
+ mode_assignment = np.zeros((n, k), dtype=int)
202
+ mode_assignment[np.arange(n), sampled_mode] = 1
206
203
  mu_sel = means[sampled_mode]
207
204
  std_sel = stds[sampled_mode]
208
205
  normalized_value = (column.reshape(-1) - mu_sel) / (4.0 * std_sel)
209
- to_return = np.concatenate([normalized_value.reshape(-1, 1), f], axis=1)
206
+ to_return = np.concatenate(
207
+ [normalized_value.reshape(-1, 1), mode_assignment], axis=1
208
+ )
210
209
  return to_return
211
210
 
212
211
  def inverse_transform(self, data: np.ndarray) -> np.ndarray:
@@ -221,18 +220,10 @@ class PerModeNormalization(Step):
221
220
  data = data.reshape(1, -1)
222
221
 
223
222
  active_modes = np.argmax(data[:, 1:], axis=1)
224
-
225
- # Get the means and stds for the active modes
226
223
  selected_mus = means[active_modes]
227
224
  selected_devs = stds[active_modes]
228
-
229
- # Get the normalized values (first column)
230
225
  normalized_values = data[:, 0]
231
-
232
- # Denormalize the values
233
226
  values = (normalized_values * 4 * selected_devs) + selected_mus
234
-
235
- # Always return 2D array with shape (n_samples, 1) for consistency
236
227
  return values.reshape(-1, 1)
237
228
 
238
229
  @staticmethod