additory 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- additory/__init__.py +15 -0
- additory/analysis/__init__.py +48 -0
- additory/analysis/cardinality.py +126 -0
- additory/analysis/correlations.py +124 -0
- additory/analysis/distributions.py +376 -0
- additory/analysis/quality.py +158 -0
- additory/analysis/scan.py +400 -0
- additory/augment/__init__.py +24 -0
- additory/augment/augmentor.py +653 -0
- additory/augment/builtin_lists.py +430 -0
- additory/augment/distributions.py +22 -0
- additory/augment/forecast.py +1132 -0
- additory/augment/list_registry.py +177 -0
- additory/augment/smote.py +320 -0
- additory/augment/strategies.py +883 -0
- additory/common/__init__.py +157 -0
- additory/common/backend.py +355 -0
- additory/common/column_utils.py +191 -0
- additory/common/distributions.py +737 -0
- additory/common/exceptions.py +62 -0
- additory/common/lists.py +229 -0
- additory/common/patterns.py +240 -0
- additory/common/resolver.py +567 -0
- additory/common/sample_data.py +182 -0
- additory/common/validation.py +197 -0
- additory/core/__init__.py +27 -0
- additory/core/ast_builder.py +165 -0
- additory/core/backends/__init__.py +23 -0
- additory/core/backends/arrow_bridge.py +476 -0
- additory/core/backends/cudf_bridge.py +355 -0
- additory/core/column_positioning.py +358 -0
- additory/core/compiler_polars.py +166 -0
- additory/core/config.py +342 -0
- additory/core/enhanced_cache_manager.py +1119 -0
- additory/core/enhanced_matchers.py +473 -0
- additory/core/enhanced_version_manager.py +325 -0
- additory/core/executor.py +59 -0
- additory/core/integrity_manager.py +477 -0
- additory/core/loader.py +190 -0
- additory/core/logging.py +24 -0
- additory/core/memory_manager.py +547 -0
- additory/core/namespace_manager.py +657 -0
- additory/core/parser.py +176 -0
- additory/core/polars_expression_engine.py +551 -0
- additory/core/registry.py +176 -0
- additory/core/sample_data_manager.py +492 -0
- additory/core/user_namespace.py +751 -0
- additory/core/validator.py +27 -0
- additory/dynamic_api.py +308 -0
- additory/expressions/__init__.py +26 -0
- additory/expressions/engine.py +551 -0
- additory/expressions/parser.py +176 -0
- additory/expressions/proxy.py +546 -0
- additory/expressions/registry.py +313 -0
- additory/expressions/samples.py +492 -0
- additory/synthetic/__init__.py +101 -0
- additory/synthetic/api.py +220 -0
- additory/synthetic/common_integration.py +314 -0
- additory/synthetic/config.py +262 -0
- additory/synthetic/engines.py +529 -0
- additory/synthetic/exceptions.py +180 -0
- additory/synthetic/file_managers.py +518 -0
- additory/synthetic/generator.py +702 -0
- additory/synthetic/generator_parser.py +68 -0
- additory/synthetic/integration.py +319 -0
- additory/synthetic/models.py +241 -0
- additory/synthetic/pattern_resolver.py +573 -0
- additory/synthetic/performance.py +469 -0
- additory/synthetic/polars_integration.py +464 -0
- additory/synthetic/proxy.py +60 -0
- additory/synthetic/schema_parser.py +685 -0
- additory/synthetic/validator.py +553 -0
- additory/utilities/__init__.py +53 -0
- additory/utilities/encoding.py +600 -0
- additory/utilities/games.py +300 -0
- additory/utilities/keys.py +8 -0
- additory/utilities/lookup.py +103 -0
- additory/utilities/matchers.py +216 -0
- additory/utilities/resolvers.py +286 -0
- additory/utilities/settings.py +167 -0
- additory/utilities/units.py +746 -0
- additory/utilities/validators.py +153 -0
- additory-0.1.0a1.dist-info/METADATA +293 -0
- additory-0.1.0a1.dist-info/RECORD +87 -0
- additory-0.1.0a1.dist-info/WHEEL +5 -0
- additory-0.1.0a1.dist-info/licenses/LICENSE +21 -0
- additory-0.1.0a1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""
|
|
2
|
+
List Registry Manager for Data Augmentation
|
|
3
|
+
|
|
4
|
+
Manages user-registered lists and provides access to built-in lists.
|
|
5
|
+
|
|
6
|
+
List Resolution Order:
|
|
7
|
+
1. User-registered lists (highest priority)
|
|
8
|
+
2. Built-in lists
|
|
9
|
+
3. Error if not found
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import List, Optional, Dict, Any
|
|
13
|
+
|
|
14
|
+
from additory.common.exceptions import ValidationError
|
|
15
|
+
from additory.augment.builtin_lists import BUILTIN_LISTS, list_builtin_names
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Global registry for user-registered lists
|
|
19
|
+
# Maps list name -> list of values. Lookups consult this mapping before
# BUILTIN_LISTS, so an entry here shadows a built-in list of the same name.
_USER_LISTS: Dict[str, List[Any]] = {}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def register_list(name: str, values: List[Any]) -> None:
    """
    Register a custom list for use in augmentation strategies.

    User-registered lists take priority over built-in lists with the same name.
    The registry keeps its own shallow copy of ``values``, so mutating the
    caller's list after registration does not alter (or empty) the
    registered list.

    Args:
        name: List name (e.g., "my_custom_list")
        values: Non-empty list of values

    Raises:
        ValidationError: If ``name`` is not a non-empty string, or if
            ``values`` is not a list or is empty

    Examples:
        >>> add.register_list("custom_statuses", ["New", "Processing", "Done"])
        >>> add.register_list("banks", ["My Bank", "Your Bank"])  # Overrides built-in
    """
    if not isinstance(name, str) or not name.strip():
        raise ValidationError("List name must be a non-empty string")

    if not isinstance(values, list):
        raise ValidationError("Values must be a list")

    if not values:
        raise ValidationError("List must contain at least one value")

    # Store a copy rather than the caller's object: otherwise a later
    # `values.clear()` / `values.append(...)` on the caller side would change
    # the registry behind its back and could break the non-empty guarantee
    # validated above.
    _USER_LISTS[name] = list(values)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_list(name: str) -> List[Any]:
    """
    Look up a list by name, preferring user registrations over built-ins.

    The registries are consulted in priority order:
    1. User-registered lists
    2. Built-in lists

    Args:
        name: List name

    Returns:
        The list of values stored under ``name``

    Raises:
        ValidationError: If neither registry contains ``name``
    """
    # First hit wins, so a user-registered list shadows a built-in list
    # that has the same name.
    for registry in (_USER_LISTS, BUILTIN_LISTS):
        if name in registry:
            return registry[name]

    message = (
        f"List '{name}' not found. "
        f"Use add.list_available() to see available lists or "
        f"add.register_list('{name}', [...]) to create it."
    )
    raise ValidationError(message)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def list_exists(name: str) -> bool:
    """
    Report whether ``name`` refers to a known list.

    Args:
        name: List name

    Returns:
        True if the name is user-registered or built-in, False otherwise
    """
    if name in _USER_LISTS:
        return True
    return name in BUILTIN_LISTS
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def list_available() -> Dict[str, int]:
    """
    Summarise every available list by name and size.

    Returns:
        Dictionary mapping list names to the number of values they hold
        Format: {"list_name": count, ...}

    Examples:
        >>> lists = add.list_available()
        >>> print(lists)
        {'first_names': 200, 'banks': 120, 'my_custom_list': 5, ...}
    """
    # Start from the built-in sizes ...
    sizes = {list_name: len(items) for list_name, items in BUILTIN_LISTS.items()}

    # ... then layer the user-registered sizes on top, so a user list that
    # shadows a built-in reports its own length.
    sizes.update(
        {list_name: len(items) for list_name, items in _USER_LISTS.items()}
    )

    return sizes
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def list_show(name: str) -> List[Any]:
    """
    Return the values stored under a list name.

    This is a thin wrapper that delegates the lookup to ``get_list``.

    Args:
        name: List name

    Returns:
        List of values

    Raises:
        ValidationError: If list not found

    Examples:
        >>> add.list_show("statuses")
        ['Active', 'Inactive', 'Pending', 'Completed', ...]
    """
    contents = get_list(name)
    return contents
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def list_remove(name: str) -> bool:
    """
    Delete a user-registered list from the registry.

    Note: Built-in lists are never touched by this function. If the removed
    user list was shadowing a built-in list of the same name, lookups fall
    back to the built-in afterwards.

    Args:
        name: List name

    Returns:
        True if a user-registered list was removed, False if there was no
        user-registered list under that name

    Examples:
        >>> add.register_list("temp_list", ["a", "b"])
        >>> add.list_remove("temp_list")
        True
        >>> add.list_remove("temp_list")
        False
    """
    try:
        del _USER_LISTS[name]
    except KeyError:
        # Nothing registered under this name.
        return False
    return True
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def clear_user_lists() -> None:
    """
    Drop every user-registered list.

    The user registry is emptied in place; built-in lists are left untouched.
    """
    _USER_LISTS.clear()
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SMOTE (Synthetic Minority Over-sampling Technique) for Data Augmentation
|
|
3
|
+
|
|
4
|
+
Provides imbalanced data handling strategies:
|
|
5
|
+
- SMOTE: Generate synthetic samples for minority class
|
|
6
|
+
- Balance: Balance class distribution
|
|
7
|
+
- Oversample: Simple oversampling with variation
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import List, Optional, Dict, Any, Tuple
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from additory.common.exceptions import ValidationError, AugmentError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def calculate_distances(point: np.ndarray, data: np.ndarray) -> np.ndarray:
    """
    Compute the Euclidean distance from one point to every row of ``data``.

    Args:
        point: Single data point (1D array)
        data: Array of data points (2D array)

    Returns:
        Array with one distance per row of ``data``
    """
    offsets = data - point
    squared_lengths = np.sum(offsets ** 2, axis=1)
    return np.sqrt(squared_lengths)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def find_k_nearest_neighbors(
    point_idx: int,
    data: np.ndarray,
    k: int = 5
) -> np.ndarray:
    """
    Find k nearest neighbors of a point.

    The point itself is never returned as its own neighbor. If ``k`` is
    larger than the number of other points, all other points are returned
    (so the result may hold fewer than ``k`` indices).

    Args:
        point_idx: Index of the point
        data: Array of all data points
        k: Number of neighbors to find

    Returns:
        Array of indices of (at most) k nearest neighbors, nearest first
    """
    distances = calculate_distances(data[point_idx], data)

    # Push the point itself to the back of the ordering so it cannot be
    # picked ahead of a real neighbor.
    distances[point_idx] = np.inf

    ranked = np.argsort(distances)

    # Masking with inf only moves the query point to the end; with
    # k >= len(data) a plain slice would still include it. Remove it
    # explicitly so callers can never interpolate a point with itself.
    ranked = ranked[ranked != point_idx]

    return ranked[:k]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def generate_synthetic_sample(
    point: np.ndarray,
    neighbor: np.ndarray,
    seed: Optional[int] = None
) -> np.ndarray:
    """
    Create one synthetic sample on the segment from ``point`` to ``neighbor``.

    The sample is a linear interpolation whose weight is drawn from NumPy's
    global random state.

    Args:
        point: Original data point
        neighbor: Neighbor data point
        seed: Random seed; when given, NumPy's global random state is
            re-seeded with it before the weight is drawn

    Returns:
        Synthetic sample
    """
    if seed is not None:
        np.random.seed(seed)

    # Interpolation weight drawn from [0, 1)
    alpha = np.random.random()

    # Step a fraction `alpha` of the way from the point towards its neighbor
    direction = neighbor - point
    return point + alpha * direction
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def smote_generate(
    data: np.ndarray,
    n_samples: int,
    k_neighbors: int = 5,
    seed: Optional[int] = None
) -> np.ndarray:
    """
    Generate synthetic samples using SMOTE algorithm.

    SMOTE creates synthetic samples by:
    1. Randomly picking an original sample and finding its k nearest neighbors
    2. Randomly selecting one of those neighbors
    3. Creating a synthetic sample along the line between sample and neighbor

    Randomness comes from NumPy's global random state, which is re-seeded
    when ``seed`` is given.

    Args:
        data: Original data (2D array: samples x features)
        n_samples: Number of synthetic samples to generate
        k_neighbors: Number of nearest neighbors to consider. Values greater
            than or equal to the number of samples are reduced to
            ``n_original - 1`` with a warning.
        seed: Random seed for reproducibility

    Returns:
        Array of synthetic samples (n_samples x features)

    Raises:
        ValidationError: If ``n_samples`` or ``k_neighbors`` is not positive,
            or if ``data`` holds fewer than 2 samples
    """
    # Unpacking (rather than shape[0]) keeps the requirement that `data`
    # is two-dimensional.
    n_original, _ = data.shape

    # Validate parameters
    if n_samples <= 0:
        raise ValidationError(f"n_samples must be positive, got {n_samples}")

    if k_neighbors <= 0:
        raise ValidationError(f"k_neighbors must be positive, got {k_neighbors}")

    # Check the sample count before clamping k_neighbors: otherwise a 0- or
    # 1-row input first emits a meaningless "Using k_neighbors=0/-1" warning
    # and only then fails.
    if n_original < 2:
        raise ValidationError(
            f"Need at least 2 samples for SMOTE, got {n_original}"
        )

    if k_neighbors >= n_original:
        warnings.warn(
            f"k_neighbors ({k_neighbors}) >= number of samples ({n_original}). "
            f"Using k_neighbors={n_original - 1}"
        )
        k_neighbors = n_original - 1

    # Set seed for reproducibility
    if seed is not None:
        np.random.seed(seed)

    # The neighbor search is deterministic and consumes no random numbers,
    # so its result can be reused whenever the same sample is drawn again
    # without changing the generated values.
    neighbor_cache: Dict[int, np.ndarray] = {}
    synthetic_samples = []

    for _ in range(n_samples):
        # Randomly select a sample
        sample_idx = int(np.random.randint(0, n_original))

        # Find (or reuse) its k nearest neighbors
        neighbor_indices = neighbor_cache.get(sample_idx)
        if neighbor_indices is None:
            neighbor_indices = find_k_nearest_neighbors(
                sample_idx, data, k_neighbors
            )
            neighbor_cache[sample_idx] = neighbor_indices

        # Randomly select one neighbor
        neighbor_idx = np.random.choice(neighbor_indices)

        # Interpolate between the sample and the chosen neighbor. seed=None
        # so the global random state seeded above keeps advancing instead of
        # being reset on every iteration.
        synthetic_samples.append(
            generate_synthetic_sample(
                data[sample_idx], data[neighbor_idx], seed=None
            )
        )

    return np.array(synthetic_samples)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def apply_smote_strategy(
    df_polars,
    columns: List[str],
    n_rows: int,
    k_neighbors: int = 5,
    seed: Optional[int] = None
) -> Dict[str, List[float]]:
    """
    Apply SMOTE to generate synthetic rows for specified columns.

    Each requested column is converted to a NumPy array, the arrays are
    stacked into a (samples x features) matrix, and ``smote_generate`` is
    run on that matrix.

    Args:
        df_polars: Input Polars DataFrame
        columns: Non-empty list of column names to use for SMOTE
        n_rows: Number of synthetic rows to generate
        k_neighbors: Number of nearest neighbors
        seed: Random seed for reproducibility

    Returns:
        Dictionary mapping column names to generated values

    Raises:
        ValidationError: If ``columns`` is empty, a column is missing, a
            column is not numeric or contains NaN values, or the data is
            insufficient for SMOTE
    """
    # An empty selection would otherwise surface as an opaque numpy error
    # from column_stack; report it as the validation problem it is.
    if len(columns) == 0:
        raise ValidationError("SMOTE requires at least one column")

    # Validate that every column exists before touching any data, so a
    # missing column is reported ahead of dtype/null problems.
    for col in columns:
        if col not in df_polars.columns:
            raise ValidationError(f"Column '{col}' not found in DataFrame")

    # Extract data for specified columns
    data_list = []
    for col in columns:
        col_data = df_polars[col].to_numpy()

        # Check if numeric
        if not np.issubdtype(col_data.dtype, np.number):
            raise ValidationError(
                f"SMOTE requires numeric columns. Column '{col}' is not numeric."
            )

        # Check for nulls (visible as NaN in the numpy array)
        if np.any(np.isnan(col_data)):
            raise ValidationError(
                f"SMOTE requires non-null values. Column '{col}' contains nulls."
            )

        data_list.append(col_data)

    # Stack into 2D array (samples x features)
    data = np.column_stack(data_list)

    # Generate synthetic samples
    synthetic_data = smote_generate(data, n_rows, k_neighbors, seed)

    # Split back into columns: feature i of the matrix is columns[i]
    return {
        col: synthetic_data[:, i].tolist()
        for i, col in enumerate(columns)
    }
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def balance_classes(
    df_polars,
    class_column: str,
    target_ratio: float = 1.0,
    method: str = "smote",
    k_neighbors: int = 5,
    seed: Optional[int] = None
) -> Tuple[int, str]:
    """
    Calculate how many samples needed to balance classes.

    Only the smallest and the largest class are considered. No samples are
    generated here; ``method``, ``k_neighbors`` and ``seed`` are accepted
    but not used by this function.

    Args:
        df_polars: Input Polars DataFrame
        class_column: Column containing class labels
        target_ratio: Target ratio of minority to majority class (default: 1.0 for perfect balance)
        method: Balancing method ('smote' or 'oversample'); currently unused
        k_neighbors: Number of neighbors for SMOTE; currently unused
        seed: Random seed; currently unused

    Returns:
        Tuple of (n_samples_needed, minority_class)

    Raises:
        ValidationError: If class column invalid or fewer than 2 classes exist
    """
    # Validate class column
    if class_column not in df_polars.columns:
        raise ValidationError(f"Class column '{class_column}' not found in DataFrame")

    # Get class counts
    class_counts = df_polars[class_column].value_counts()

    if len(class_counts) < 2:
        raise ValidationError(
            f"Need at least 2 classes for balancing, found {len(class_counts)}"
        )

    # The name of the count column produced by value_counts() differs
    # between Polars releases ('counts' in older ones, 'count' in newer
    # ones), so locate it as "the column that is not the class column"
    # instead of hard-coding either spelling. Falls back to the historical
    # name if no other column is found.
    count_column = next(
        (col for col in class_counts.columns if col != class_column),
        "counts",
    )

    # Find minority and majority classes
    class_counts_dict = dict(zip(
        class_counts[class_column].to_list(),
        class_counts[count_column].to_list()
    ))

    minority_class = min(class_counts_dict, key=class_counts_dict.get)
    majority_class = max(class_counts_dict, key=class_counts_dict.get)

    minority_count = class_counts_dict[minority_class]
    majority_count = class_counts_dict[majority_class]

    # Calculate target count for minority class (fraction is truncated)
    target_count = int(majority_count * target_ratio)

    # Calculate how many samples needed; never negative when the minority
    # class already meets the target
    n_samples_needed = max(0, target_count - minority_count)

    return n_samples_needed, minority_class
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def generate_smote_values(
    df_polars,
    columns: List[str],
    n_rows: int,
    k_neighbors: int = 5,
    seed: Optional[int] = None,
    **params
) -> Dict[str, List[Any]]:
    """
    Main SMOTE generation function.

    Delegates to ``apply_smote_strategy`` and normalises failures: library
    errors (ValidationError / AugmentError) propagate unchanged, any other
    exception is wrapped in an AugmentError.

    Args:
        df_polars: Input Polars DataFrame
        columns: Columns to use for SMOTE (numeric only)
        n_rows: Number of synthetic rows to generate
        k_neighbors: Number of nearest neighbors (default: 5)
        seed: Random seed for reproducibility
        **params: Additional parameters (reserved for future use; ignored)

    Returns:
        Dictionary mapping column names to generated values

    Raises:
        ValidationError: If parameters invalid
        AugmentError: If generation fails
    """
    try:
        return apply_smote_strategy(
            df_polars,
            columns,
            n_rows,
            k_neighbors,
            seed
        )
    except (ValidationError, AugmentError):
        # Already in the library's own vocabulary; let callers see them as-is.
        raise
    except Exception as e:
        # Chain explicitly so the original failure is recorded as the cause.
        raise AugmentError(f"SMOTE generation failed: {e}") from e
|