additory 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- additory/__init__.py +4 -0
- additory/common/__init__.py +2 -2
- additory/common/backend.py +20 -4
- additory/common/distributions.py +1 -1
- additory/common/sample_data.py +19 -19
- additory/core/backends/arrow_bridge.py +7 -0
- additory/core/config.py +3 -3
- additory/core/polars_expression_engine.py +66 -16
- additory/core/registry.py +4 -3
- additory/dynamic_api.py +95 -51
- additory/expressions/proxy.py +4 -1
- additory/expressions/registry.py +3 -3
- additory/synthetic/__init__.py +7 -95
- additory/synthetic/column_name_resolver.py +149 -0
- additory/synthetic/deduce.py +259 -0
- additory/{augment → synthetic}/distributions.py +2 -2
- additory/{augment → synthetic}/forecast.py +1 -1
- additory/synthetic/linked_list_parser.py +415 -0
- additory/synthetic/namespace_lookup.py +129 -0
- additory/{augment → synthetic}/smote.py +1 -1
- additory/{augment → synthetic}/strategies.py +87 -44
- additory/{augment/augmentor.py → synthetic/synthesizer.py} +75 -15
- additory/utilities/units.py +4 -1
- {additory-0.1.0a2.dist-info → additory-0.1.0a4.dist-info}/METADATA +44 -28
- {additory-0.1.0a2.dist-info → additory-0.1.0a4.dist-info}/RECORD +28 -43
- {additory-0.1.0a2.dist-info → additory-0.1.0a4.dist-info}/WHEEL +1 -1
- additory/augment/__init__.py +0 -24
- additory/augment/builtin_lists.py +0 -430
- additory/augment/list_registry.py +0 -177
- additory/synthetic/api.py +0 -220
- additory/synthetic/common_integration.py +0 -314
- additory/synthetic/config.py +0 -262
- additory/synthetic/engines.py +0 -529
- additory/synthetic/exceptions.py +0 -180
- additory/synthetic/file_managers.py +0 -518
- additory/synthetic/generator.py +0 -702
- additory/synthetic/generator_parser.py +0 -68
- additory/synthetic/integration.py +0 -319
- additory/synthetic/models.py +0 -241
- additory/synthetic/pattern_resolver.py +0 -573
- additory/synthetic/performance.py +0 -469
- additory/synthetic/polars_integration.py +0 -464
- additory/synthetic/proxy.py +0 -60
- additory/synthetic/schema_parser.py +0 -685
- additory/synthetic/validator.py +0 -553
- {additory-0.1.0a2.dist-info → additory-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {additory-0.1.0a2.dist-info → additory-0.1.0a4.dist-info}/top_level.txt +0 -0
additory/synthetic/engines.py
DELETED
|
@@ -1,529 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Distribution strategy engines for synthetic data generation.
|
|
3
|
-
|
|
4
|
-
Implements the six essential distribution strategies: equal, custom, categorical,
|
|
5
|
-
high_cardinality, numeric_range, and skewed. All engines use polars for
|
|
6
|
-
high-performance data generation.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
from abc import ABC, abstractmethod
|
|
10
|
-
from typing import Dict, Any, List, Optional, Union
|
|
11
|
-
import polars as pl
|
|
12
|
-
import re
|
|
13
|
-
from dataclasses import dataclass
|
|
14
|
-
|
|
15
|
-
from .models import DistributionType, DistributionStrategy, ValidationResult
|
|
16
|
-
from .exceptions import SyntheticDataError, ValidationError
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@dataclass
class DistributionConfig:
    """Configuration for distribution strategy application.

    Bundles the inputs that an engine's ``validate_config`` and
    ``apply_distribution`` methods read for one generation request.
    """
    # Strategy descriptor; the engines read ``strategy.parameters`` from it.
    strategy: DistributionStrategy
    # Candidate values the engines sample from.
    base_values: List[str]
    # Number of values to generate; base validation rejects values <= 0.
    target_rows: int
    # Optional seed carried with the request.
    # NOTE(review): the engines in this module sample with their own
    # ``self.seed`` rather than reading this field -- confirm it is needed.
    seed: Optional[int] = None
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class DistributionEngine(ABC):
    """Common interface implemented by every distribution strategy engine."""

    def __init__(self, seed: Optional[int] = None):
        """Remember the optional seed used when sampling."""
        self.seed = seed

    @abstractmethod
    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Return True when this engine handles ``strategy_type``."""

    @abstractmethod
    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Inspect ``config`` and report any problems found."""

    @abstractmethod
    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Produce the generated series described by ``config``."""

    def _validate_base_requirements(self, config: DistributionConfig) -> ValidationResult:
        """Run the checks shared by all engines.

        Returns:
            A ValidationResult that starts out valid and receives one error
            per failed check (non-positive row count, empty base values).
        """
        outcome = ValidationResult(is_valid=True)

        # (failed?, error message, suggested fix) -- reported in this order.
        shared_checks = (
            (
                config.target_rows <= 0,
                "Target rows must be positive",
                "Set target_rows to a positive integer",
            ),
            (
                not config.base_values,
                "Base values cannot be empty",
                "Provide at least one base value",
            ),
        )
        for failed, message, suggestion in shared_checks:
            if failed:
                outcome.add_error(message, suggestion)

        return outcome
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class EqualDistributionEngine(DistributionEngine):
    """Engine for equal (uniform) distribution strategy."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Check if this engine supports equal distribution."""
        return strategy_type == DistributionType.EQUAL

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Validate equal distribution configuration.

        Equal distribution has no requirements beyond the shared base checks
        (positive ``target_rows`` and non-empty ``base_values``).
        """
        return self._validate_base_requirements(config)

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Apply equal distribution - uniform selection from base values.

        Args:
            config: Request holding the base values and the row count.

        Returns:
            A series named ``"values"`` with ``config.target_rows`` entries
            drawn with replacement from ``config.base_values``.

        Raises:
            ValidationError: If ``config`` fails validation. The message
                carries the collected validation errors, matching the other
                engines in this module.
        """
        validation = self.validate_config(config)
        if not validation.is_valid:
            raise ValidationError(
                f"Invalid configuration for equal distribution: {validation.errors}"
            )

        # Uniform draw with replacement; the engine seed makes it repeatable.
        return (
            pl.Series("values", config.base_values)
            .sample(n=config.target_rows, with_replacement=True, seed=self.seed)
        )
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
class CustomDistributionEngine(DistributionEngine):
    """Engine for custom weighted distribution strategy."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Check if this engine supports custom distribution."""
        return strategy_type == DistributionType.CUSTOM

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Validate custom distribution configuration.

        Beyond the shared base checks, the strategy must provide a non-empty
        ``weights`` mapping of value -> percentage whose entries are numeric,
        non-negative and sum to 100 (within 0.01).
        """
        result = self._validate_base_requirements(config)

        weights = config.strategy.parameters.get('weights', {})
        if not weights:
            result.add_error("Custom distribution requires weights", "Add 'weights' parameter with value:percentage pairs")
            return result

        # Validate weights format and sum
        total_weight = 0
        for value, weight in weights.items():
            if not isinstance(weight, (int, float)):
                result.add_error(f"Weight for '{value}' must be numeric", "Change weight to a number")
                continue
            if weight < 0:
                # A negative percentage cannot be sampled from; report it but
                # still count it so the sum message below stays accurate.
                result.add_error(f"Weight for '{value}' must not be negative", "Change weight to a non-negative number")
            total_weight += weight

        # Written as "not within tolerance" so a NaN total is rejected too
        # (every comparison against NaN is False).
        if not abs(total_weight - 100) <= 0.01:
            result.add_error(f"Weights must sum to 100%, got {total_weight}%", "Adjust weights to sum to 100%")

        return result

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Apply custom weighted distribution.

        Args:
            config: Request whose strategy parameters contain ``weights``.

        Returns:
            A series named ``"values"`` with exactly ``config.target_rows``
            entries drawn from the weight keys in proportion to their weights.

        Raises:
            ValidationError: If ``config`` fails validation.
        """
        validation = self.validate_config(config)
        if not validation.is_valid:
            raise ValidationError(f"Invalid configuration for custom distribution: {validation.errors}")

        weights = config.strategy.parameters['weights']

        # polars' sample has no weights argument, so draw with the stdlib.
        # A private generator is used so the process-wide ``random`` state is
        # never reseeded as a side effect of generating data.
        import random
        rng = random.Random(self.seed)

        # ``choices`` normalises by the actual weight total, so every draw
        # yields a value even when the weights sum slightly below 100.
        result_values = rng.choices(
            list(weights.keys()),
            weights=list(weights.values()),
            k=config.target_rows,
        )

        return pl.Series("values", result_values)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class CategoricalDistributionEngine(DistributionEngine):
    """Engine for categorical distribution strategy."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Check if this engine supports categorical distribution."""
        return strategy_type == DistributionType.CATEGORICAL

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Validate categorical distribution configuration.

        Beyond the shared base checks, the strategy must provide a non-empty
        ``categories`` list. Categories missing from the base values only
        produce warnings.
        """
        result = self._validate_base_requirements(config)

        categories = config.strategy.parameters.get('categories', [])
        if not categories:
            result.add_error("Categorical distribution requires categories", "Add 'categories' parameter with list of allowed values")
            return result

        if not isinstance(categories, list):
            result.add_error("Categories must be a list", "Change categories to a list format")
            # Stop here: iterating a non-list would either raise (e.g. an int)
            # or emit misleading per-character warnings (e.g. a string).
            return result

        # Check that all categories exist in base values
        base_set = set(config.base_values)
        for category in categories:
            if category not in base_set:
                result.add_warning(f"Category '{category}' not found in base values")

        return result

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Apply categorical distribution - select only from specified categories.

        Args:
            config: Request whose strategy parameters contain ``categories``.

        Returns:
            A series named ``"values"`` with ``config.target_rows`` entries
            drawn with replacement from the base values that are listed in
            ``categories``.

        Raises:
            ValidationError: If ``config`` fails validation.
            SyntheticDataError: If no base value matches any category.
        """
        validation = self.validate_config(config)
        if not validation.is_valid:
            raise ValidationError(f"Invalid configuration for categorical distribution: {validation.errors}")

        # Validation already hashed every category, so building a set is safe
        # and makes each membership test constant time.
        allowed = set(config.strategy.parameters['categories'])

        # Filter base values to only include specified categories
        available_categories = [val for val in config.base_values if val in allowed]

        if not available_categories:
            raise SyntheticDataError("No valid categories found in base values")

        return (
            pl.Series("values", available_categories)
            .sample(n=config.target_rows, with_replacement=True, seed=self.seed)
        )
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
class HighCardinalityDistributionEngine(DistributionEngine):
    """Engine for high cardinality distribution strategy."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Check if this engine supports high cardinality distribution."""
        return strategy_type == DistributionType.HIGH_CARDINALITY

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Validate high cardinality distribution configuration.

        Adds a warning (not an error) when there are fewer base values than
        80% of the requested rows.
        """
        result = self._validate_base_requirements(config)

        # Check if we have enough base values for high cardinality
        if len(config.base_values) < config.target_rows * 0.8:
            result.add_warning(
                f"Base values ({len(config.base_values)}) may be insufficient for high cardinality with {config.target_rows} rows"
            )

        return result

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Apply high cardinality distribution - mostly unique values with controlled duplicates.

        Up to 90% of the rows take the leading base values once each; the
        rest are drawn with replacement from all base values. The combined
        rows are then shuffled.

        Args:
            config: Request holding the base values and the row count.

        Returns:
            A series named ``"values"`` with ``config.target_rows`` entries.

        Raises:
            ValidationError: If ``config`` fails validation.
        """
        validation = self.validate_config(config)
        if not validation.is_valid:
            raise ValidationError(f"Invalid configuration for high cardinality distribution: {validation.errors}")

        # Strategy: Use most values once, then fill remainder with random selection
        unique_count = min(len(config.base_values), int(config.target_rows * 0.9))
        duplicate_count = config.target_rows - unique_count

        base_series = pl.Series("values", config.base_values)
        result = base_series.head(unique_count)

        # For any validated row count int(0.9 * n) < n, so there is always a
        # remainder to fill; the guard is kept purely as a safety net.
        if duplicate_count > 0:
            duplicate_values = base_series.sample(n=duplicate_count, with_replacement=True, seed=self.seed)
            # Both parts come from the same base series, so dtypes match.
            result = pl.concat([result, duplicate_values])

        # ``result`` now holds exactly target_rows entries. shuffle=True is
        # required: sampling every row without replacement otherwise keeps
        # the original order, leaving the unique head and duplicate tail
        # unshuffled.
        return result.sample(
            n=config.target_rows,
            with_replacement=False,
            shuffle=True,
            seed=self.seed,
        )
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
class NumericRangeDistributionEngine(DistributionEngine):
    """Engine that samples only the base values lying inside a numeric range."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Report whether ``strategy_type`` is the numeric-range strategy."""
        return strategy_type == DistributionType.NUMERIC_RANGE

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Check the shared requirements plus the ``min`` / ``max`` parameters.

        Both bounds must be present; when both are present they must be
        numeric, and ``min`` must be strictly below ``max``.
        """
        report = self._validate_base_requirements(config)

        lower = config.strategy.parameters.get('min')
        upper = config.strategy.parameters.get('max')

        if lower is None:
            report.add_error("Numeric range requires 'min' parameter", "Add 'min' parameter with minimum value")
        if upper is None:
            report.add_error("Numeric range requires 'max' parameter", "Add 'max' parameter with maximum value")

        # Type and ordering checks only apply once both bounds were supplied.
        if lower is None or upper is None:
            return report

        numeric_types = (int, float)
        lower_is_number = isinstance(lower, numeric_types)
        upper_is_number = isinstance(upper, numeric_types)

        if not lower_is_number:
            report.add_error("Min value must be numeric", "Change 'min' to a number")
        if not upper_is_number:
            report.add_error("Max value must be numeric", "Change 'max' to a number")
        if lower_is_number and upper_is_number and lower >= upper:
            report.add_error("Min value must be less than max value", "Ensure min < max")

        return report

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Sample, with replacement, from the base values inside ``[min, max]``.

        Base values that cannot be converted with ``float()`` are skipped.

        Raises:
            ValidationError: If ``config`` fails validation.
            SyntheticDataError: If no base value falls inside the range.
        """
        report = self.validate_config(config)
        if not report.is_valid:
            raise ValidationError(f"Invalid configuration for numeric range distribution: {report.errors}")

        lower = config.strategy.parameters['min']
        upper = config.strategy.parameters['max']

        def _within_bounds(candidate) -> bool:
            # Non-numeric entries are treated as out of range.
            try:
                return lower <= float(candidate) <= upper
            except (ValueError, TypeError):
                return False

        # Keep the original (unconverted) values; only the test uses floats.
        in_range = [candidate for candidate in config.base_values if _within_bounds(candidate)]

        if not in_range:
            raise SyntheticDataError(f"No numeric values found in range [{lower}, {upper}]")

        pool = pl.Series("values", in_range)
        return pool.sample(n=config.target_rows, with_replacement=True, seed=self.seed)
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
class SkewedDistributionEngine(DistributionEngine):
    """Engine for skewed distribution strategy (80/20 rule)."""

    def supports_strategy(self, strategy_type: DistributionType) -> bool:
        """Check if this engine supports skewed distribution."""
        return strategy_type == DistributionType.SKEWED

    def validate_config(self, config: DistributionConfig) -> ValidationResult:
        """Validate skewed distribution configuration.

        Beyond the shared base checks, at least two base values are required
        so that a "popular" and a "remaining" group can both exist.
        """
        result = self._validate_base_requirements(config)

        # Need at least 2 values for skewed distribution
        if len(config.base_values) < 2:
            result.add_error("Skewed distribution requires at least 2 base values", "Provide more base values")

        return result

    def apply_distribution(self, config: DistributionConfig) -> pl.Series:
        """Apply skewed distribution - 80% from 20% of values.

        The first 20% of ``base_values`` (at least one value) form the
        popular group and supply 80% of the rows; the other values supply
        the rest. The combined rows are then shuffled.

        Args:
            config: Request holding the base values and the row count.

        Returns:
            A series named ``"values"`` with ``config.target_rows`` entries.

        Raises:
            ValidationError: If ``config`` fails validation.
        """
        validation = self.validate_config(config)
        if not validation.is_valid:
            raise ValidationError(f"Invalid configuration for skewed distribution: {validation.errors}")

        # Calculate 20% of values (at least 1)
        popular_count = max(1, len(config.base_values) // 5)
        popular_values = config.base_values[:popular_count]
        remaining_values = config.base_values[popular_count:]

        # 80% of rows from popular values, 20% from remaining
        popular_rows = int(config.target_rows * 0.8)
        remaining_rows = config.target_rows - popular_rows

        if not remaining_values:
            # Nothing to draw the tail from: take every row from the popular
            # group so the output still has target_rows entries.
            popular_rows, remaining_rows = config.target_rows, 0

        popular_series = (
            pl.Series("values", popular_values)
            .sample(n=popular_rows, with_replacement=True, seed=self.seed)
        )

        if remaining_rows > 0:
            remaining_series = (
                pl.Series("values", remaining_values)
                .sample(n=remaining_rows, with_replacement=True, seed=self.seed)
            )
            result = pl.concat([popular_series, remaining_series])
        else:
            result = popular_series

        # shuffle=True is required: sampling every row without replacement
        # otherwise keeps the original order, leaving all popular rows ahead
        # of all remaining rows.
        return result.sample(
            n=config.target_rows,
            with_replacement=False,
            shuffle=True,
            seed=self.seed,
        )
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
class DistributionEngineFactory:
    """Factory for creating distribution engines with plugin support."""

    # Class-level registry for custom engines (shared across all instances)
    _custom_engines: List[type] = []

    def __init__(self):
        """Initialize the factory with all available engines."""
        self._engines = [
            EqualDistributionEngine,
            CustomDistributionEngine,
            CategoricalDistributionEngine,
            HighCardinalityDistributionEngine,
            NumericRangeDistributionEngine,
            SkewedDistributionEngine,
        ]

    @classmethod
    def register_custom_engine(cls, engine_class: type) -> None:
        """
        Register a custom distribution engine.

        Args:
            engine_class: A class that inherits from DistributionEngine

        Raises:
            ValidationError: If the engine class is invalid

        Example:
            >>> class MyCustomEngine(DistributionEngine):
            ...     def supports_strategy(self, strategy_type):
            ...         return strategy_type == DistributionType.CUSTOM
            ...     # ... implement other methods
            >>> DistributionEngineFactory.register_custom_engine(MyCustomEngine)
        """
        # Validate that the argument is a class inheriting from
        # DistributionEngine. The isinstance guard keeps non-class arguments
        # (e.g. an engine instance) from escaping as a TypeError out of
        # issubclass instead of the documented ValidationError.
        if not (isinstance(engine_class, type) and issubclass(engine_class, DistributionEngine)):
            engine_name = getattr(engine_class, '__name__', repr(engine_class))
            raise ValidationError(
                f"Custom engine must inherit from DistributionEngine, got {engine_name}"
            )

        # Validate that required methods are implemented. A subclass always
        # inherits the abstract placeholders, so hasattr alone can never
        # fail; the set of still-abstract methods tells us what is missing.
        still_abstract = getattr(engine_class, '__abstractmethods__', frozenset())
        required_methods = ['supports_strategy', 'validate_config', 'apply_distribution']
        for method_name in required_methods:
            if not hasattr(engine_class, method_name) or method_name in still_abstract:
                raise ValidationError(
                    f"Custom engine {engine_class.__name__} must implement {method_name} method"
                )

        # Check if engine is already registered
        if engine_class in cls._custom_engines:
            raise ValidationError(
                f"Engine {engine_class.__name__} is already registered"
            )

        # Add to registry
        cls._custom_engines.append(engine_class)

    @classmethod
    def unregister_custom_engine(cls, engine_class: type) -> None:
        """
        Unregister a custom distribution engine.

        Args:
            engine_class: The engine class to unregister

        Raises:
            ValidationError: If the engine is not registered
        """
        if engine_class not in cls._custom_engines:
            raise ValidationError(
                f"Engine {engine_class.__name__} is not registered"
            )
        cls._custom_engines.remove(engine_class)

    @classmethod
    def list_custom_engines(cls) -> List[type]:
        """Get list of all registered custom engines."""
        return cls._custom_engines.copy()

    @classmethod
    def clear_custom_engines(cls) -> None:
        """Clear all registered custom engines (useful for testing)."""
        cls._custom_engines.clear()

    def create_engine(self, strategy_type: DistributionType, seed: Optional[int] = None) -> DistributionEngine:
        """
        Create an appropriate engine for the given strategy type.

        Custom engines are checked first, then built-in engines.

        Args:
            strategy_type: The distribution strategy to find an engine for
            seed: Optional seed handed to the engine constructor

        Raises:
            SyntheticDataError: If no registered engine supports the strategy
        """
        # Custom engines come first so they can override built-in engines.
        for engine_class in [*self._custom_engines, *self._engines]:
            engine = engine_class(seed=seed)
            if engine.supports_strategy(strategy_type):
                return engine

        raise SyntheticDataError(f"No engine available for distribution strategy: {strategy_type}")

    def get_supported_strategies(self) -> List[DistributionType]:
        """Get list of all supported distribution strategies (built-in and custom).

        Each strategy appears once, in first-seen order (custom engines
        before built-in ones), so the result is the same on every call.
        """
        # A dict preserves insertion order while de-duplicating.
        strategies = {}

        for engine_class in [*self._custom_engines, *self._engines]:
            engine = engine_class()
            for strategy_type in DistributionType:
                if engine.supports_strategy(strategy_type):
                    strategies[strategy_type] = None

        return list(strategies)
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
class DistributionManager:
    """Facade that selects the engine for a strategy and runs or validates it."""

    def __init__(self, seed: Optional[int] = None):
        """Create the engine factory and remember the seed for every request."""
        self.seed = seed
        self.factory = DistributionEngineFactory()

    def _build_config(self, strategy: DistributionStrategy, base_values: List[str], target_rows: int) -> DistributionConfig:
        """Assemble the DistributionConfig for one request using the manager's seed."""
        return DistributionConfig(
            strategy=strategy,
            base_values=base_values,
            target_rows=target_rows,
            seed=self.seed,
        )

    def apply_distribution(self, strategy: DistributionStrategy, base_values: List[str], target_rows: int) -> pl.Series:
        """Generate ``target_rows`` values from ``base_values`` using ``strategy``."""
        request = self._build_config(strategy, base_values, target_rows)
        chosen_engine = self.factory.create_engine(strategy.strategy_type, self.seed)
        return chosen_engine.apply_distribution(request)

    def validate_strategy(self, strategy: DistributionStrategy, base_values: List[str], target_rows: int) -> ValidationResult:
        """Validate ``strategy`` against the given inputs without generating data.

        A SyntheticDataError raised while locating or running the engine is
        converted into an invalid ValidationResult carrying its message.
        """
        request = self._build_config(strategy, base_values, target_rows)

        try:
            chosen_engine = self.factory.create_engine(strategy.strategy_type, self.seed)
            return chosen_engine.validate_config(request)
        except SyntheticDataError as exc:
            failure = ValidationResult(is_valid=False)
            failure.add_error(str(exc))
            return failure

    def get_supported_strategies(self) -> List[DistributionType]:
        """Return the strategies supported by the underlying factory."""
        return self.factory.get_supported_strategies()
|