sdg-core-lib 0.1.8.dev1__tar.gz → 0.1.8.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/PKG-INFO +1 -1
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/pyproject.toml +1 -1
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/CTGANComponents.py +33 -24
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/implementation/CTGAN.py +51 -15
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/steps.py +14 -23
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/README.md +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/browser.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/commons.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/GANs/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/ModelInfo.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/TrainingInfo.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/UnspecializedModel.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/KerasBaseVAE.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/VAE.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/AutoTabularVAE.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/TabularVAE.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/TimeSeriesVAE.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/VAEs/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/models/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/columns.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/dataset/datasets.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/base_evaluator.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/metrics.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/tables.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/time_series.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/job.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/mappings.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/FunctionApplier.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/function_factory.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/function_utils.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/Parameter.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/UnspecializedFunction.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/implementation/NormalTester.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/distribution_evaluator/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/IntervalThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/MonoThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/InnerThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/LowerThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/OuterThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/UpperThreshold.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/filter/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/LinearFunction.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/NormalDistributionSample.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/QuadraticFunction.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/SinusoidalFunction.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/generation/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/BurstNoiseAdder.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/WhiteNoiseAdder.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/functions/modification/implementation/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/base_processor.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/__init__.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/base_strategy.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/ctgan_strategy.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/vae_strategy.py +0 -0
- {sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/table_processor.py +0 -0
|
@@ -185,8 +185,8 @@ class CTGANModel(keras.Model):
|
|
|
185
185
|
def metrics(self):
|
|
186
186
|
return [self.gen_loss_tracker, self.critic_loss_tracker]
|
|
187
187
|
|
|
188
|
-
@tf.function
|
|
189
188
|
def generate_batch_cond(self, batch_size):
|
|
189
|
+
batch_size = int(batch_size) # Convert symbolic tensor to int
|
|
190
190
|
num_cats = len(self.generator.cats_disc)
|
|
191
191
|
total_cond_dim = sum(self.generator.cats_disc)
|
|
192
192
|
cats_disc = tf.convert_to_tensor(self.generator.cats_disc, dtype=tf.int32)
|
|
@@ -195,25 +195,22 @@ class CTGANModel(keras.Model):
|
|
|
195
195
|
shape=[batch_size], minval=0, maxval=num_cats, dtype=tf.int32
|
|
196
196
|
)
|
|
197
197
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
)
|
|
215
|
-
|
|
216
|
-
return cond_batch
|
|
198
|
+
# Create condition vector by sampling from PMFs and creating one-hot encoding
|
|
199
|
+
condition_list = []
|
|
200
|
+
for i in range(batch_size):
|
|
201
|
+
idx = col_indices[i].numpy()
|
|
202
|
+
pmf = self.probability_mass_function_list[idx].numpy().flatten()
|
|
203
|
+
# Sample from PMF
|
|
204
|
+
cat_idx = np.random.choice(len(pmf), p=pmf)
|
|
205
|
+
# Create one-hot vector for total_cond_dim
|
|
206
|
+
one_hot = np.zeros(total_cond_dim)
|
|
207
|
+
# Calculate offset for this categorical variable
|
|
208
|
+
offset = sum(self.generator.cats_disc[:idx])
|
|
209
|
+
one_hot[offset + cat_idx] = 1.0
|
|
210
|
+
condition_list.append(one_hot)
|
|
211
|
+
|
|
212
|
+
cond = tf.convert_to_tensor(condition_list, dtype=tf.float32)
|
|
213
|
+
return cond
|
|
217
214
|
|
|
218
215
|
@staticmethod
|
|
219
216
|
@tf.function
|
|
@@ -291,9 +288,12 @@ class CTGANModel(keras.Model):
|
|
|
291
288
|
pmfs = []
|
|
292
289
|
curr = 0
|
|
293
290
|
for sz in self.generator.cats_disc:
|
|
294
|
-
chunk = onehot_all[:, curr : curr + sz]
|
|
295
|
-
|
|
296
|
-
|
|
291
|
+
chunk = onehot_all[:, curr : curr + sz] # (N_row, cats)
|
|
292
|
+
chunk_np = chunk.numpy()
|
|
293
|
+
log_freqs = np.log(np.sum(chunk_np, axis=0) + 1.0).reshape(1, -1)
|
|
294
|
+
pmfs.append(
|
|
295
|
+
tf.convert_to_tensor(log_freqs / np.sum(log_freqs), dtype=tf.float32)
|
|
296
|
+
)
|
|
297
297
|
curr += sz
|
|
298
298
|
return pmfs
|
|
299
299
|
|
|
@@ -301,7 +301,16 @@ class CTGANModel(keras.Model):
|
|
|
301
301
|
batch = ops.shape(data)[0]
|
|
302
302
|
self.row_dim = ops.shape(data)[1]
|
|
303
303
|
z = tf.random.normal([batch, self.row_dim - sum(self.generator.cats_disc)])
|
|
304
|
-
|
|
304
|
+
|
|
305
|
+
# Use tf.py_function to call generate_batch_cond in eager mode
|
|
306
|
+
def generate_cond_eager(batch_size):
|
|
307
|
+
return self.generate_batch_cond(batch_size)
|
|
308
|
+
|
|
309
|
+
cond = tf.py_function(generate_cond_eager, [batch], tf.float32)
|
|
310
|
+
# Set the shape explicitly - it should be [batch_size, total_cond_dim]
|
|
311
|
+
total_cond_dim = sum(self.generator.cats_disc)
|
|
312
|
+
cond.set_shape([None, total_cond_dim])
|
|
313
|
+
|
|
305
314
|
real_batch = CTGANModel.sample_real_data(
|
|
306
315
|
self._train_data, cond, self.onehot_discrete_indexes
|
|
307
316
|
)
|
|
@@ -11,6 +11,7 @@ from sdg_core_lib.data_generator.models.GANs.CTGANComponents import (
|
|
|
11
11
|
CTGANModel,
|
|
12
12
|
)
|
|
13
13
|
import keras
|
|
14
|
+
import tensorflow as tf
|
|
14
15
|
from sdg_core_lib.data_generator.models.TrainingInfo import TrainingInfo
|
|
15
16
|
import numpy as np
|
|
16
17
|
|
|
@@ -18,7 +19,7 @@ import numpy as np
|
|
|
18
19
|
class CTGAN(UnspecializedModel):
|
|
19
20
|
def __init__(
|
|
20
21
|
self,
|
|
21
|
-
metadata: dict,
|
|
22
|
+
metadata: list[dict],
|
|
22
23
|
model_name: str,
|
|
23
24
|
input_shape: str = None,
|
|
24
25
|
load_path: str = None,
|
|
@@ -31,7 +32,9 @@ class CTGAN(UnspecializedModel):
|
|
|
31
32
|
gen_steps=4,
|
|
32
33
|
critic_dropout=0.2,
|
|
33
34
|
):
|
|
34
|
-
super().__init__(
|
|
35
|
+
super().__init__(
|
|
36
|
+
self._clean_skeleton(metadata), model_name, input_shape, load_path
|
|
37
|
+
)
|
|
35
38
|
self._batch_size = batch_size
|
|
36
39
|
self._epochs = epochs
|
|
37
40
|
self._gen_steps = gen_steps
|
|
@@ -43,10 +46,19 @@ class CTGAN(UnspecializedModel):
|
|
|
43
46
|
self._instantiate()
|
|
44
47
|
|
|
45
48
|
@staticmethod
|
|
46
|
-
def
|
|
49
|
+
def _clean_skeleton(skeleton):
|
|
50
|
+
if skeleton != [{}]:
|
|
51
|
+
return [
|
|
52
|
+
item
|
|
53
|
+
for item in skeleton
|
|
54
|
+
if item["feature_type"] in ["continuous", "categorical"]
|
|
55
|
+
]
|
|
56
|
+
return skeleton
|
|
57
|
+
|
|
58
|
+
def infer_data_structure(self):
|
|
47
59
|
cats, modes, idxs = [], [], []
|
|
48
60
|
true_index = 0
|
|
49
|
-
for col in
|
|
61
|
+
for col in self._metadata:
|
|
50
62
|
try:
|
|
51
63
|
f_size = int(col["feature_size"])
|
|
52
64
|
if col["feature_type"] == "categorical":
|
|
@@ -90,7 +102,7 @@ class CTGAN(UnspecializedModel):
|
|
|
90
102
|
categories_per_discrete_column,
|
|
91
103
|
modes_per_continuous_column,
|
|
92
104
|
onehot_discrete_indexes,
|
|
93
|
-
) =
|
|
105
|
+
) = self.infer_data_structure()
|
|
94
106
|
self.generator = CTGANGenerator(
|
|
95
107
|
self._metadata,
|
|
96
108
|
modes_per_continuous_column,
|
|
@@ -106,7 +118,7 @@ class CTGAN(UnspecializedModel):
|
|
|
106
118
|
# Should set the _model variable CTGAN Model complete with Generator and Critic
|
|
107
119
|
# Does NOT return the model
|
|
108
120
|
# self._metadata is available
|
|
109
|
-
_, _, onehot_discrete_indexes =
|
|
121
|
+
_, _, onehot_discrete_indexes = self.infer_data_structure()
|
|
110
122
|
critic = keras.saving.load_model(os.path.join(folder_path, "critic.keras"))
|
|
111
123
|
generator = keras.saving.load_model(
|
|
112
124
|
os.path.join(folder_path, "generator.keras")
|
|
@@ -114,11 +126,33 @@ class CTGAN(UnspecializedModel):
|
|
|
114
126
|
self._model = CTGANModel(generator, critic, onehot_discrete_indexes)
|
|
115
127
|
|
|
116
128
|
# Load probability_mass_function_list if it exists
|
|
117
|
-
pmf_path = os.path.join(folder_path, "probability_mass_function_list.
|
|
129
|
+
pmf_path = os.path.join(folder_path, "probability_mass_function_list.npz")
|
|
118
130
|
if os.path.exists(pmf_path):
|
|
119
|
-
|
|
120
|
-
|
|
131
|
+
pmf_data = np.load(pmf_path)
|
|
132
|
+
# Convert back to list of TensorFlow tensors
|
|
133
|
+
pmf_list = []
|
|
134
|
+
for key in sorted(pmf_data.keys()):
|
|
135
|
+
pmf_list.append(tf.convert_to_tensor(pmf_data[key], dtype=tf.float32))
|
|
136
|
+
self._model.probability_mass_function_list = pmf_list
|
|
137
|
+
# Also check for old .npy format for backward compatibility
|
|
138
|
+
elif os.path.exists(
|
|
139
|
+
os.path.join(folder_path, "probability_mass_function_list.npy")
|
|
140
|
+
):
|
|
141
|
+
pmf_list = np.load(
|
|
142
|
+
os.path.join(folder_path, "probability_mass_function_list.npy"),
|
|
143
|
+
allow_pickle=True,
|
|
121
144
|
)
|
|
145
|
+
# Convert to TensorFlow tensors if needed
|
|
146
|
+
if (
|
|
147
|
+
isinstance(pmf_list, list)
|
|
148
|
+
and len(pmf_list) > 0
|
|
149
|
+
and isinstance(pmf_list[0], np.ndarray)
|
|
150
|
+
):
|
|
151
|
+
self._model.probability_mass_function_list = [
|
|
152
|
+
tf.convert_to_tensor(pmf, dtype=tf.float32) for pmf in pmf_list
|
|
153
|
+
]
|
|
154
|
+
else:
|
|
155
|
+
self._model.probability_mass_function_list = pmf_list
|
|
122
156
|
|
|
123
157
|
def save(self, folder_path: str):
|
|
124
158
|
keras.saving.save_model(
|
|
@@ -132,9 +166,13 @@ class CTGAN(UnspecializedModel):
|
|
|
132
166
|
hasattr(self._model, "probability_mass_function_list")
|
|
133
167
|
and self._model.probability_mass_function_list is not None
|
|
134
168
|
):
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
self._model.probability_mass_function_list
|
|
169
|
+
# Convert TensorFlow tensors to numpy arrays before saving
|
|
170
|
+
pmf_list = [
|
|
171
|
+
tensor.numpy() for tensor in self._model.probability_mass_function_list
|
|
172
|
+
]
|
|
173
|
+
np.savez(
|
|
174
|
+
os.path.join(folder_path, "probability_mass_function_list.npz"),
|
|
175
|
+
*pmf_list,
|
|
138
176
|
)
|
|
139
177
|
|
|
140
178
|
def train(self, data: np.ndarray):
|
|
@@ -157,9 +195,7 @@ class CTGAN(UnspecializedModel):
|
|
|
157
195
|
)
|
|
158
196
|
self._model._train_data = data
|
|
159
197
|
probability_mass_function_list = self._model.get_pmfs(data)
|
|
160
|
-
self._model.probability_mass_function_list =
|
|
161
|
-
probability_mass_function_list
|
|
162
|
-
)
|
|
198
|
+
self._model.probability_mass_function_list = probability_mass_function_list
|
|
163
199
|
history = self._model.fit(
|
|
164
200
|
data, batch_size=self._batch_size, epochs=self._epochs, verbose=1
|
|
165
201
|
)
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/strategies/steps.py
RENAMED
|
@@ -180,33 +180,32 @@ class PerModeNormalization(Step):
|
|
|
180
180
|
if self.operator is None:
|
|
181
181
|
raise ValueError("Operator not initialized")
|
|
182
182
|
column = data.reshape(-1, 1)
|
|
183
|
-
|
|
184
|
-
weights = self.operator.weights_[
|
|
185
|
-
means = self.operator.means_[
|
|
186
|
-
stds = np.sqrt(self.operator.covariances_[
|
|
183
|
+
active_weights_indexes = np.where(self.operator.weights_ > 0.01)
|
|
184
|
+
weights = self.operator.weights_[active_weights_indexes]
|
|
185
|
+
means = self.operator.means_[active_weights_indexes].flatten()
|
|
186
|
+
stds = np.sqrt(self.operator.covariances_[active_weights_indexes].flatten())
|
|
187
187
|
mixture_probability_density = []
|
|
188
|
-
for
|
|
188
|
+
for weight, mean, std in zip(weights, means, stds):
|
|
189
189
|
mixture_probability_density.append(
|
|
190
|
-
|
|
191
|
-
* PerModeNormalization._gaussian_probability_density_function(
|
|
192
|
-
column, m, s
|
|
193
|
-
)
|
|
190
|
+
weight * self._gaussian_probability_density_function(column, mean, std)
|
|
194
191
|
)
|
|
195
192
|
marginal_mixture_probability_density = np.hstack(mixture_probability_density)
|
|
196
|
-
responsibilities =
|
|
193
|
+
responsibilities = self._compute_responsibilities(
|
|
197
194
|
marginal_mixture_probability_density
|
|
198
195
|
)
|
|
199
196
|
rng = np.random.default_rng(self.random_state)
|
|
200
|
-
n,
|
|
197
|
+
n, k = responsibilities.shape
|
|
201
198
|
sampled_mode = np.array(
|
|
202
|
-
[rng.choice(
|
|
199
|
+
[rng.choice(k, p=responsibilities[i]) for i in range(n)]
|
|
203
200
|
)
|
|
204
|
-
|
|
205
|
-
|
|
201
|
+
mode_assignment = np.zeros((n, k), dtype=int)
|
|
202
|
+
mode_assignment[np.arange(n), sampled_mode] = 1
|
|
206
203
|
mu_sel = means[sampled_mode]
|
|
207
204
|
std_sel = stds[sampled_mode]
|
|
208
205
|
normalized_value = (column.reshape(-1) - mu_sel) / (4.0 * std_sel)
|
|
209
|
-
to_return = np.concatenate(
|
|
206
|
+
to_return = np.concatenate(
|
|
207
|
+
[normalized_value.reshape(-1, 1), mode_assignment], axis=1
|
|
208
|
+
)
|
|
210
209
|
return to_return
|
|
211
210
|
|
|
212
211
|
def inverse_transform(self, data: np.ndarray) -> np.ndarray:
|
|
@@ -221,18 +220,10 @@ class PerModeNormalization(Step):
|
|
|
221
220
|
data = data.reshape(1, -1)
|
|
222
221
|
|
|
223
222
|
active_modes = np.argmax(data[:, 1:], axis=1)
|
|
224
|
-
|
|
225
|
-
# Get the means and stds for the active modes
|
|
226
223
|
selected_mus = means[active_modes]
|
|
227
224
|
selected_devs = stds[active_modes]
|
|
228
|
-
|
|
229
|
-
# Get the normalized values (first column)
|
|
230
225
|
normalized_values = data[:, 0]
|
|
231
|
-
|
|
232
|
-
# Denormalize the values
|
|
233
226
|
values = (normalized_values * 4 * selected_devs) + selected_mus
|
|
234
|
-
|
|
235
|
-
# Always return 2D array with shape (n_samples, 1) for consistency
|
|
236
227
|
return values.reshape(-1, 1)
|
|
237
228
|
|
|
238
229
|
@staticmethod
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/data_generator/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/base_evaluator.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/evaluate/time_series.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/FunctionApplier.py
RENAMED
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/post_process/function_utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/base_processor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sdg_core_lib-0.1.8.dev1 → sdg_core_lib-0.1.8.dev2}/src/sdg_core_lib/preprocess/table_processor.py
RENAMED
|
File without changes
|