misata 0.1.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- misata/__init__.py +48 -0
- misata/api.py +460 -0
- misata/audit.py +415 -0
- misata/benchmark.py +376 -0
- misata/cli.py +680 -0
- misata/codegen.py +153 -0
- misata/curve_fitting.py +106 -0
- misata/customization.py +256 -0
- misata/feedback.py +433 -0
- misata/formulas.py +362 -0
- misata/generators.py +247 -0
- misata/hybrid.py +398 -0
- misata/llm_parser.py +493 -0
- misata/noise.py +346 -0
- misata/schema.py +252 -0
- misata/semantic.py +185 -0
- misata/simulator.py +742 -0
- misata/story_parser.py +425 -0
- misata/templates/__init__.py +444 -0
- misata/validation.py +313 -0
- misata-0.1.0b0.dist-info/METADATA +291 -0
- misata-0.1.0b0.dist-info/RECORD +25 -0
- misata-0.1.0b0.dist-info/WHEEL +5 -0
- misata-0.1.0b0.dist-info/entry_points.txt +2 -0
- misata-0.1.0b0.dist-info/top_level.txt +1 -0
misata/noise.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Noise injection module for realistic ML training data.
|
|
3
|
+
|
|
4
|
+
Adds real-world imperfections to synthetic data:
|
|
5
|
+
- Missing values (nulls/NaN)
|
|
6
|
+
- Outliers
|
|
7
|
+
- Typos and data entry errors
|
|
8
|
+
- Duplicates and near-duplicates
|
|
9
|
+
- Distribution drift over time
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import random
|
|
13
|
+
import string
|
|
14
|
+
from typing import Any, Dict, List, Optional
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class NoiseInjector:
    """
    Inject realistic noise and imperfections into synthetic data.

    Makes data suitable for ML training by adding real-world issues:
    - Missing values at configurable rates
    - Statistical outliers
    - Typos in text fields
    - Duplicate rows
    - Temporal distribution shifts

    Every public method works on a copy; the input DataFrame is never mutated.

    Usage:
        injector = NoiseInjector(seed=42)
        noisy_df = injector.apply(df, config={
            "null_rate": 0.05,
            "outlier_rate": 0.02,
            "typo_rate": 0.01,
            "duplicate_rate": 0.03,
        })
    """

    # Substrings marking a text column as structured data rather than free text;
    # such columns are excluded from default typo injection.
    _TYPO_SKIP_SUBSTRINGS = ("email", "uuid", "guid")

    def __init__(self, seed: Optional[int] = None):
        """Initialize with an optional random seed for reproducibility.

        Two generators are seeded from the same value: a NumPy ``Generator``
        (masks, outlier magnitudes, duplicate selection) and a stdlib
        ``random.Random`` (per-string typo choices).
        """
        self.rng = np.random.default_rng(seed)
        self.py_rng = random.Random(seed)

    @staticmethod
    def _is_id_column(name: Any) -> bool:
        """Return True for identifier-like labels: 'id' or '*_id' (case-insensitive).

        ``str()`` keeps this safe for non-string labels, e.g. the integer
        columns of ``pd.DataFrame(array)``.
        """
        label = str(name).lower()
        return label == "id" or label.endswith("_id")

    def apply(
        self,
        df: pd.DataFrame,
        config: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """
        Apply all configured noise types to a DataFrame.

        Noise is applied in this order: nulls, outliers, typos, duplicates.
        A step runs only when its rate is > 0.

        Args:
            df: Input DataFrame
            config: Noise configuration dict. Recognised keys:
                null_rate / null_columns,
                outlier_rate / outlier_columns,
                typo_rate / typo_columns,
                duplicate_rate.

        Returns:
            DataFrame with noise applied
        """
        if config is None:
            config = {}

        result = df.copy()

        if config.get("null_rate", 0) > 0:
            result = self.inject_nulls(result, rate=config["null_rate"],
                                       columns=config.get("null_columns"))

        if config.get("outlier_rate", 0) > 0:
            result = self.inject_outliers(result, rate=config["outlier_rate"],
                                          columns=config.get("outlier_columns"))

        if config.get("typo_rate", 0) > 0:
            result = self.inject_typos(result, rate=config["typo_rate"],
                                       columns=config.get("typo_columns"))

        if config.get("duplicate_rate", 0) > 0:
            result = self.inject_duplicates(result, rate=config["duplicate_rate"])

        return result

    def inject_nulls(
        self,
        df: pd.DataFrame,
        rate: float = 0.05,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Inject null/missing values at a specified rate.

        Args:
            df: Input DataFrame
            rate: Probability of any cell becoming null (0.0-1.0)
            columns: Specific columns to apply to (default: all except ID columns).
                Names not present in ``df`` are ignored.

        Returns:
            DataFrame with nulls injected. Columns whose dtype cannot hold NaN
            (e.g. int, bool) may be upcast by pandas.
        """
        result = df.copy()

        if columns is None:
            columns = [c for c in df.columns if not self._is_id_column(c)]

        for col in columns:
            if col not in result.columns:
                continue

            # Independent Bernoulli(rate) draw per row for this column.
            mask = self.rng.random(len(result)) < rate
            result.loc[mask, col] = np.nan

        return result

    def inject_outliers(
        self,
        df: pd.DataFrame,
        rate: float = 0.02,
        columns: Optional[List[str]] = None,
        multiplier: float = 5.0,
    ) -> pd.DataFrame:
        """
        Inject statistical outliers into numeric columns.

        Each selected cell is replaced by ``mean +/- multiplier * std * (1 + u)``
        with ``u`` uniform in [0, 1), so outliers lie between ``multiplier`` and
        ``2 * multiplier`` standard deviations from the column mean.

        Args:
            df: Input DataFrame
            rate: Probability of any numeric cell becoming an outlier
            columns: Specific columns (default: all numeric, excluding ID columns).
                Non-numeric and boolean columns are skipped.
            multiplier: How extreme the outliers should be (times std dev)

        Returns:
            DataFrame with outliers injected. Integer columns keep their dtype;
            their outlier values are rounded to the nearest integer.
        """
        result = df.copy()

        if columns is None:
            columns = result.select_dtypes(include=[np.number]).columns.tolist()
            columns = [c for c in columns if not self._is_id_column(c)]

        for col in columns:
            if col not in result.columns:
                continue

            series = result[col]
            # pandas' dtype check (unlike np.issubdtype) accepts extension dtypes
            # such as category/Int64 without raising; bool is numeric to pandas
            # but is not a meaningful outlier target.
            if (not pd.api.types.is_numeric_dtype(series.dtype)
                    or pd.api.types.is_bool_dtype(series.dtype)):
                continue

            mean = series.mean()
            std = series.std()

            # pd.isna also covers pd.NA returned by nullable dtypes.
            if pd.isna(std) or std == 0:
                continue

            mask = self.rng.random(len(result)) < rate
            n_outliers = int(mask.sum())

            if n_outliers == 0:
                continue

            # Generate outliers above or below mean
            direction = self.rng.choice([-1, 1], size=n_outliers)
            outlier_values = mean + direction * multiplier * std * (1 + self.rng.random(n_outliers))
            if pd.api.types.is_integer_dtype(series.dtype):
                # Writing floats into an int column is an incompatible-dtype
                # setitem in pandas >= 2.1; keep the column integral instead.
                outlier_values = np.rint(outlier_values).astype(np.int64)
            result.loc[mask, col] = outlier_values

        return result

    def inject_typos(
        self,
        df: pd.DataFrame,
        rate: float = 0.01,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Inject typos into text columns.

        Typo types:
        - Character swap
        - Character deletion
        - Character insertion
        - Case change

        Args:
            df: Input DataFrame
            rate: Probability of any text cell getting a typo
            columns: Specific columns (default: all object/string columns except
                ID columns and columns whose name contains email/uuid/guid)

        Returns:
            DataFrame with typos injected. Only ``str`` values of length >= 2
            are modified.
        """
        result = df.copy()

        if columns is None:
            text_columns = result.select_dtypes(include=['object', 'string']).columns
            # Skip IDs and structured fields. Whole-label ID matching avoids
            # excluding names that merely contain "id" (e.g. "provider").
            columns = [
                c for c in text_columns
                if not self._is_id_column(c)
                and not any(s in str(c).lower() for s in self._TYPO_SKIP_SUBSTRINGS)
            ]

        for col in columns:
            if col not in result.columns:
                continue

            mask = self.rng.random(len(result)) < rate
            positions = np.flatnonzero(mask)
            if positions.size == 0:
                continue

            # Work positionally so duplicate index labels cannot make a lookup
            # return a Series instead of a scalar.
            updated = result[col].copy()
            for pos in positions:
                value = updated.iat[pos]
                if isinstance(value, str) and len(value) >= 2:
                    updated.iat[pos] = self._add_typo(value)
            result[col] = updated

        return result

    def _add_typo(self, text: str) -> str:
        """Return ``text`` with a single random typo (swap/delete/insert/case).

        Strings shorter than 2 characters are returned unchanged. A 'case' typo
        on a non-letter, or a swap of two equal neighbours, leaves the text as is.
        """
        if len(text) < 2:
            return text

        typo_type = self.py_rng.choice(['swap', 'delete', 'insert', 'case'])
        chars = list(text)

        if typo_type == 'swap':
            # Draw from [0, len-2] so there is always a right-hand neighbour.
            pos = self.py_rng.randint(0, len(chars) - 2)
            chars[pos], chars[pos + 1] = chars[pos + 1], chars[pos]
        else:
            pos = self.py_rng.randint(0, len(chars) - 1)
            if typo_type == 'delete':
                del chars[pos]
            elif typo_type == 'insert':
                chars.insert(pos, self.py_rng.choice(string.ascii_lowercase))
            else:  # 'case'
                chars[pos] = chars[pos].swapcase()

        return ''.join(chars)

    def inject_duplicates(
        self,
        df: pd.DataFrame,
        rate: float = 0.03,
        exact: bool = True,
    ) -> pd.DataFrame:
        """
        Inject duplicate rows.

        ``int(len(df) * rate)`` rows are sampled with replacement and appended;
        the result has a fresh RangeIndex.

        Args:
            df: Input DataFrame
            rate: Rate of rows to duplicate
            exact: If True, exact duplicates. If False, near-duplicates whose
                non-ID numeric columns get ~1% multiplicative Gaussian jitter
                (integer columns are rounded back to their dtype).

        Returns:
            New DataFrame with duplicates added (a copy even when none are added)
        """
        n_duplicates = int(len(df) * rate)

        if n_duplicates == 0:
            return df.copy()

        # Sample positions rather than labels: with duplicate index labels,
        # df.loc would return more rows than requested.
        positions = self.rng.choice(len(df), size=n_duplicates, replace=True)
        duplicates = df.iloc[positions].copy()

        if not exact:
            for col in duplicates.select_dtypes(include=[np.number]).columns:
                if self._is_id_column(col):
                    continue
                noise = self.rng.normal(0, 0.01, len(duplicates))
                jittered = duplicates[col] * (1 + noise)
                if pd.api.types.is_integer_dtype(duplicates[col].dtype):
                    jittered = jittered.round().astype(duplicates[col].dtype)
                duplicates[col] = jittered

        return pd.concat([df, duplicates], ignore_index=True)

    def apply_temporal_drift(
        self,
        df: pd.DataFrame,
        date_column: str,
        value_column: str,
        drift_rate: float = 0.1,
        drift_direction: str = "up",
    ) -> pd.DataFrame:
        """
        Apply temporal distribution drift to simulate changing trends.

        Each value is scaled by ``1 +/- drift_rate * t`` where ``t`` is the
        row's date normalised to [0, 1] over the column's date range.

        Args:
            df: Input DataFrame
            date_column: Column containing dates (anything pd.to_datetime accepts)
            value_column: Numeric column to apply drift to
            drift_rate: Rate of drift (0.1 = 10% change over time range)
            drift_direction: "up" for increasing, "down" for decreasing

        Returns:
            DataFrame with temporal drift applied. Rows with a missing date keep
            their original value. A copy of ``df`` is returned unchanged if
            either column is missing or the dates span no time.

        Raises:
            ValueError: If ``drift_direction`` is not "up" or "down".
        """
        if drift_direction not in ("up", "down"):
            raise ValueError(
                f"drift_direction must be 'up' or 'down', got {drift_direction!r}"
            )

        result = df.copy()

        if date_column not in result.columns or value_column not in result.columns:
            return result

        dates = pd.to_datetime(result[date_column])
        min_date = dates.min()
        max_date = dates.max()

        # All-NaT dates (min is NaT) or a zero-length range: nothing to normalise.
        if pd.isna(min_date) or min_date == max_date:
            return result

        # Normalize dates to 0-1 range (NaT rows become NaN)
        time_fraction = (dates - min_date) / (max_date - min_date)

        signed_rate = drift_rate if drift_direction == "up" else -drift_rate
        # NaN multipliers (undated rows) fall back to 1 so values are kept, not wiped.
        multiplier = (1 + signed_rate * time_fraction).fillna(1.0)

        result[value_column] = result[value_column] * multiplier

        return result
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# Convenience wrapper: one-shot noise injection without managing a NoiseInjector.
def add_noise(
    df: pd.DataFrame,
    null_rate: float = 0.0,
    outlier_rate: float = 0.0,
    typo_rate: float = 0.0,
    duplicate_rate: float = 0.0,
    seed: Optional[int] = None,
) -> pd.DataFrame:
    """
    Add noise to a DataFrame in a single call.

    Builds a fresh ``NoiseInjector`` seeded with ``seed`` and runs its
    ``apply`` with the four rates below; a rate of 0 disables that step.

    Args:
        df: Input DataFrame (not modified)
        null_rate: Rate of null value injection (0.0-1.0)
        outlier_rate: Rate of outlier injection
        typo_rate: Rate of typo injection in text
        duplicate_rate: Rate of duplicate rows
        seed: Random seed for reproducibility

    Returns:
        DataFrame with noise applied

    Example:
        noisy_df = add_noise(df, null_rate=0.05, outlier_rate=0.02)
    """
    noise_config = dict(
        null_rate=null_rate,
        outlier_rate=outlier_rate,
        typo_rate=typo_rate,
        duplicate_rate=duplicate_rate,
    )
    return NoiseInjector(seed=seed).apply(df, config=noise_config)
|
misata/schema.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic models for Misata configuration.
|
|
3
|
+
|
|
4
|
+
These models define the blueprint for synthetic data generation,
|
|
5
|
+
including tables, columns, relationships, and scenario events.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field, field_validator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Column(BaseModel):
    """
    Defines a single column in a table.

    Attributes:
        name: Column name
        type: Data type (int, float, date, categorical, foreign_key, text, boolean)
        distribution_params: Parameters for data generation (mean, std, choices, etc.).
            Checked against ``type``: categorical requires ``choices``; date
            requires ``start`` and ``end`` unless ``relative_to`` is given;
            int/float get ``distribution`` defaulted to ``"normal"``.
        nullable: Whether the column can contain NULL values
        unique: Whether values must be unique
    """

    name: str
    type: Literal["int", "float", "date", "categorical", "foreign_key", "text", "boolean"]
    # pydantic does not run field validators on default values unless
    # validate_default=True; without it, omitting distribution_params would
    # bypass every check in validate_params below.
    distribution_params: Dict[str, Any] = Field(default_factory=dict, validate_default=True)
    nullable: bool = False
    unique: bool = False

    @field_validator("distribution_params")
    @classmethod
    def validate_params(cls, v: Dict[str, Any], info: Any) -> Dict[str, Any]:
        """Validate distribution parameters based on column type.

        ``type`` is declared before this field, so it is present in
        ``info.data`` unless ``type`` itself failed validation (then no
        type-specific check applies).

        Raises:
            ValueError: If a categorical column lacks ``choices`` or a date
                column lacks both ``relative_to`` and a ``start``/``end`` pair.
        """
        col_type = info.data.get("type")

        if col_type == "categorical" and "choices" not in v:
            raise ValueError("Categorical columns must have 'choices' in distribution_params")

        if col_type == "date" and "relative_to" not in v:
            if "start" not in v or "end" not in v:
                raise ValueError("Date columns must have 'start' and 'end' OR 'relative_to' in distribution_params")

        if col_type in ("int", "float"):
            v.setdefault("distribution", "normal")  # Default to normal distribution

        return v
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Table(BaseModel):
    """
    Blueprint for one table to be generated.

    Two flavours are supported:
    - Reference tables: small lookup tables whose rows are supplied directly
      (e.g. exercises, plans), typically via ``inline_data``.
    - Transactional tables: bulk-generated tables that point at reference
      tables through foreign keys.

    Attributes:
        name: Table name
        row_count: Number of rows to generate; must be > 0. Intended to be
            ignored when inline_data is supplied (not enforced by this model).
        description: Optional free-text description of the table's purpose
        is_reference: True when this is a lookup/reference table
        inline_data: Literal data rows (list of dicts) for reference tables
        constraints: Business-rule constraints attached to this table
    """

    name: str
    row_count: int = Field(100, gt=0)
    description: Optional[str] = Field(default=None)
    is_reference: bool = Field(default=False)
    inline_data: Optional[List[Dict[str, Any]]] = Field(default=None)
    # String forward reference: Constraint is declared further down this module.
    constraints: List["Constraint"] = Field(default_factory=list)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class Relationship(BaseModel):
    """
    Describes a parent-child link between two tables.

    The link is meant to keep child foreign keys pointing at existing parent
    keys. This model only describes the link and performs no enforcement itself.

    Attributes:
        parent_table: Name of the parent table
        child_table: Name of the child table
        parent_key: Column name in the parent table (usually its primary key)
        child_key: Column name in the child table (the foreign key)
        temporal_constraint: If True, child events should occur after parent events
        filters: Optional column -> value mapping, e.g. {"status": "active"};
            presumably restricts which parent rows are eligible (confirm in the generator)
    """

    parent_table: str
    child_table: str
    parent_key: str
    child_key: str
    temporal_constraint: bool = Field(default=False)
    filters: Optional[Dict[str, Any]] = Field(default=None)  # e.g., {"status": "active"}
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Constraint(BaseModel):
    """
    A business rule to enforce on generated data.

    Constraints are meant to be applied after a batch is generated so that
    the data respects real-world rules. This model only carries the rule
    definition.

    Attributes:
        name: Descriptive name of the constraint
        type: Kind of rule (max_per_group, min_per_group, sum_limit, unique_combination)
        group_by: Columns to group by (e.g. ["employee_id", "date"])
        column: Column being constrained (not needed for unique_combination)
        value: The limit value (e.g. 8 for "max 8 hours")
        action: What to do on violation (cap, redistribute, drop, error); defaults to "cap"

    Examples:
        # Max 8 hours per employee per day
        Constraint(
            name="max_daily_hours",
            type="max_per_group",
            group_by=["employee_id", "date"],
            column="hours",
            value=8,
            action="cap"
        )

        # Each employee-project-date combination must be unique
        Constraint(
            name="unique_timesheet_entry",
            type="unique_combination",
            group_by=["employee_id", "project_id", "date"],
            action="drop"
        )
    """

    name: str
    type: Literal["max_per_group", "min_per_group", "sum_limit", "unique_combination"]
    group_by: List[str] = Field(default_factory=list)
    column: Optional[str] = Field(default=None)  # unused by unique_combination
    value: Optional[float] = Field(default=None)  # the limit
    action: Literal["cap", "redistribute", "drop", "error"] = Field(default="cap")
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class ScenarioEvent(BaseModel):
    """
    A time-based or conditional modifier applied on top of generated data.

    This is the "story" layer: events push data into specific patterns such
    as growth, crashes or seasonality. The model only stores the event;
    evaluation of ``condition`` happens elsewhere.

    Attributes:
        name: Descriptive name of the event
        table: Table the event targets
        column: Column to modify
        condition: Expression string selecting affected rows
            (e.g. "date > '2023-11-01'")
        modifier_type: Kind of modification (multiply, add, set, function)
        modifier_value: Value (or function reference) to apply
        description: Optional note on what the event represents

    Examples:
        # Revenue crash
        ScenarioEvent(
            name="Q3_Revenue_Crash",
            table="sales",
            column="revenue",
            condition="date >= '2023-07-01' and date < '2023-10-01'",
            modifier_type="multiply",
            modifier_value=0.5
        )

        # Set all churned users
        ScenarioEvent(
            name="Churn_Flag",
            table="users",
            column="churned",
            condition="signup_date < '2023-06-01'",
            modifier_type="set",
            modifier_value=True
        )
    """

    name: str
    table: str
    column: str
    condition: str
    modifier_type: Literal["multiply", "add", "set", "function"]
    modifier_value: Union[int, float, str, bool]
    description: Optional[str] = Field(default=None)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class SchemaConfig(BaseModel):
    """
    Root configuration object for a synthetic-data generation run.

    Bundles every table, its column definitions, the relationships between
    tables and the scenario events layered on top.

    Attributes:
        name: Name of the dataset/scenario
        description: Optional description of what the data represents
        tables: Tables to generate
        columns: Table name -> list of Column definitions for that table
        relationships: Inter-table (parent/child) relationships
        events: Scenario events to apply
        seed: Optional random seed for reproducibility

    Field order matters: the validators read ``tables`` from ``info.data``,
    which only holds fields declared (and successfully validated) earlier.
    """

    name: str
    description: Optional[str] = Field(default=None)
    tables: List[Table]
    columns: Dict[str, List[Column]]
    relationships: List[Relationship] = Field(default_factory=list)
    events: List[ScenarioEvent] = Field(default_factory=list)
    seed: Optional[int] = Field(default=None)

    @field_validator("columns")
    @classmethod
    def validate_columns(cls, v: Dict[str, List[Column]], info: Any) -> Dict[str, List[Column]]:
        """Reject the config when a declared table has no column list."""
        declared = {table.name for table in info.data.get("tables", [])}
        undefined = next((table_name for table_name in declared if table_name not in v), None)
        if undefined is not None:
            raise ValueError(f"Table '{undefined}' has no column definitions")
        return v

    @field_validator("relationships")
    @classmethod
    def validate_relationships(cls, v: List[Relationship], info: Any) -> List[Relationship]:
        """Reject relationships whose parent or child table is not declared."""
        known = {table.name for table in info.data.get("tables", [])}
        for rel in v:
            for role, table_name in (("Parent", rel.parent_table), ("Child", rel.child_table)):
                if table_name not in known:
                    raise ValueError(f"{role} table '{table_name}' not found in schema")
        return v

    def get_table(self, name: str) -> Optional[Table]:
        """Return the first table called ``name``, or None if absent."""
        return next((table for table in self.tables if table.name == name), None)

    def get_columns(self, table_name: str) -> List[Column]:
        """Return the column definitions for ``table_name`` ([] if unknown)."""
        return self.columns[table_name] if table_name in self.columns else []
|