misata 0.2.0b0__tar.gz → 0.3.1b0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {misata-0.2.0b0 → misata-0.3.1b0}/PKG-INFO +1 -1
  2. misata-0.3.1b0/misata/__init__.py +134 -0
  3. misata-0.3.1b0/misata/cache.py +258 -0
  4. misata-0.3.1b0/misata/constraints.py +307 -0
  5. misata-0.3.1b0/misata/context.py +259 -0
  6. misata-0.3.1b0/misata/exceptions.py +277 -0
  7. misata-0.3.1b0/misata/generators/__init__.py +29 -0
  8. misata-0.3.1b0/misata/generators/base.py +586 -0
  9. misata-0.3.1b0/misata/profiles.py +332 -0
  10. {misata-0.2.0b0 → misata-0.3.1b0}/misata/simulator.py +133 -12
  11. {misata-0.2.0b0 → misata-0.3.1b0}/misata/smart_values.py +171 -2
  12. misata-0.3.1b0/misata/streaming.py +228 -0
  13. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/PKG-INFO +1 -1
  14. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/SOURCES.txt +9 -1
  15. {misata-0.2.0b0 → misata-0.3.1b0}/pyproject.toml +1 -1
  16. misata-0.2.0b0/misata/__init__.py +0 -59
  17. {misata-0.2.0b0 → misata-0.3.1b0}/LICENSE +0 -0
  18. {misata-0.2.0b0 → misata-0.3.1b0}/README.md +0 -0
  19. {misata-0.2.0b0 → misata-0.3.1b0}/misata/api.py +0 -0
  20. {misata-0.2.0b0 → misata-0.3.1b0}/misata/audit.py +0 -0
  21. {misata-0.2.0b0 → misata-0.3.1b0}/misata/benchmark.py +0 -0
  22. {misata-0.2.0b0 → misata-0.3.1b0}/misata/cli.py +0 -0
  23. {misata-0.2.0b0 → misata-0.3.1b0}/misata/codegen.py +0 -0
  24. {misata-0.2.0b0 → misata-0.3.1b0}/misata/curve_fitting.py +0 -0
  25. {misata-0.2.0b0 → misata-0.3.1b0}/misata/customization.py +0 -0
  26. {misata-0.2.0b0 → misata-0.3.1b0}/misata/feedback.py +0 -0
  27. {misata-0.2.0b0 → misata-0.3.1b0}/misata/formulas.py +0 -0
  28. /misata-0.2.0b0/misata/generators.py → /misata-0.3.1b0/misata/generators_legacy.py +0 -0
  29. {misata-0.2.0b0 → misata-0.3.1b0}/misata/hybrid.py +0 -0
  30. {misata-0.2.0b0 → misata-0.3.1b0}/misata/llm_parser.py +0 -0
  31. {misata-0.2.0b0 → misata-0.3.1b0}/misata/noise.py +0 -0
  32. {misata-0.2.0b0 → misata-0.3.1b0}/misata/quality.py +0 -0
  33. {misata-0.2.0b0 → misata-0.3.1b0}/misata/schema.py +0 -0
  34. {misata-0.2.0b0 → misata-0.3.1b0}/misata/semantic.py +0 -0
  35. {misata-0.2.0b0 → misata-0.3.1b0}/misata/story_parser.py +0 -0
  36. {misata-0.2.0b0 → misata-0.3.1b0}/misata/templates/__init__.py +0 -0
  37. {misata-0.2.0b0 → misata-0.3.1b0}/misata/templates/library.py +0 -0
  38. {misata-0.2.0b0 → misata-0.3.1b0}/misata/validation.py +0 -0
  39. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/dependency_links.txt +0 -0
  40. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/entry_points.txt +0 -0
  41. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/requires.txt +0 -0
  42. {misata-0.2.0b0 → misata-0.3.1b0}/misata.egg-info/top_level.txt +0 -0
  43. {misata-0.2.0b0 → misata-0.3.1b0}/setup.cfg +0 -0
  44. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_api.py +0 -0
  45. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_cli.py +0 -0
  46. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_constraints.py +0 -0
  47. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_curve_fitting.py +0 -0
  48. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_enterprise.py +0 -0
  49. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_formulas.py +0 -0
  50. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_integrity.py +0 -0
  51. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_llm_parser.py +0 -0
  52. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_schema.py +0 -0
  53. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_security.py +0 -0
  54. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_semantic.py +0 -0
  55. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_simulator.py +0 -0
  56. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_templates.py +0 -0
  57. {misata-0.2.0b0 → misata-0.3.1b0}/tests/test_validation.py +0 -0
{misata-0.2.0b0 → misata-0.3.1b0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: misata
- Version: 0.2.0b0
+ Version: 0.3.1b0
  Summary: AI-Powered Synthetic Data Engine - Generate realistic multi-table datasets from natural language
  Author-email: Muhammed Rasin <rasinbinabdulla@gmail.com>
  License: MIT
misata-0.3.1b0/misata/__init__.py
@@ -0,0 +1,134 @@
+ """
+ Misata - AI-Powered Synthetic Data Engine
+
+ Generate realistic multi-table datasets from natural language descriptions.
+ Supports OpenAI, Groq, Gemini, and Ollama for intelligent schema generation.
+
+ Usage:
+     from misata import DataSimulator, SchemaConfig
+
+     # Or use the CLI:
+     # misata generate --story "A SaaS with 50k users..."
+
+     # Or use pre-built templates:
+     from misata.templates.library import load_template
+     config = load_template("ecommerce")
+ """
+
+ __version__ = "0.3.1b0"
+ __author__ = "Muhammed Rasin"
+
+ from misata.schema import (
+     Column,
+     Constraint,
+     Relationship,
+     ScenarioEvent,
+     SchemaConfig,
+     Table,
+ )
+ from misata.simulator import DataSimulator
+ from misata.generators import TextGenerator
+ from misata.generators.base import (
+     BaseGenerator,
+     IntegerGenerator,
+     FloatGenerator,
+     BooleanGenerator,
+     CategoricalGenerator,
+     DateGenerator,
+     ForeignKeyGenerator,
+     GeneratorFactory,
+ )
+ from misata.constraints import (
+     BaseConstraint,
+     SumConstraint,
+     RangeConstraint,
+     UniqueConstraint,
+     NotNullConstraint,
+     RatioConstraint,
+     ConstraintEngine,
+ )
+ from misata.context import GenerationContext
+ from misata.exceptions import (
+     MisataError,
+     SchemaValidationError,
+     ColumnGenerationError,
+     LLMError,
+     ConfigurationError,
+     ExportError,
+ )
+ from misata.smart_values import SmartValueGenerator
+ from misata.noise import NoiseInjector, add_noise
+ from misata.customization import Customizer, ColumnOverride
+ from misata.quality import DataQualityChecker, check_quality
+ from misata.templates.library import load_template, list_templates
+ from misata.profiles import (
+     DistributionProfile,
+     get_profile,
+     list_profiles,
+     generate_with_profile,
+ )
+ from misata.generators.base import (
+     ConditionalCategoricalGenerator,
+     CONDITIONAL_LOOKUPS,
+     create_conditional_generator,
+ )
+
+ __all__ = [
+     # Core
+     "Column",
+     "Constraint",
+     "Relationship",
+     "ScenarioEvent",
+     "SchemaConfig",
+     "Table",
+     "DataSimulator",
+     # Generators
+     "TextGenerator",
+     "BaseGenerator",
+     "IntegerGenerator",
+     "FloatGenerator",
+     "BooleanGenerator",
+     "CategoricalGenerator",
+     "DateGenerator",
+     "ForeignKeyGenerator",
+     "GeneratorFactory",
+     "ConditionalCategoricalGenerator",
+     "CONDITIONAL_LOOKUPS",
+     "create_conditional_generator",
+     # Constraints
+     "BaseConstraint",
+     "SumConstraint",
+     "RangeConstraint",
+     "UniqueConstraint",
+     "NotNullConstraint",
+     "RatioConstraint",
+     "ConstraintEngine",
+     # Context
+     "GenerationContext",
+     # Exceptions
+     "MisataError",
+     "SchemaValidationError",
+     "ColumnGenerationError",
+     "LLMError",
+     "ConfigurationError",
+     "ExportError",
+     # Smart Values
+     "SmartValueGenerator",
+     # Distribution Profiles
+     "DistributionProfile",
+     "get_profile",
+     "list_profiles",
+     "generate_with_profile",
+     # ML-ready features
+     "NoiseInjector",
+     "add_noise",
+     "Customizer",
+     "ColumnOverride",
+     # Quality
+     "DataQualityChecker",
+     "check_quality",
+     # Templates
+     "load_template",
+     "list_templates",
+ ]
+
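For orientation, a minimal sketch of how the re-exported API above is meant to fit together, based only on the module docstring. The load_template call is confirmed by the docstring; the DataSimulator constructor and generation method shown here are assumptions, not confirmed by this diff:

    from misata import DataSimulator
    from misata.templates.library import load_template

    # Confirmed by the docstring above: templates yield a ready-made config
    config = load_template("ecommerce")

    # Hypothetical usage: the DataSimulator call signature and method
    # names are not visible in this diff
    simulator = DataSimulator(config)
    tables = simulator.run()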
misata-0.3.1b0/misata/cache.py
@@ -0,0 +1,258 @@
+ """
+ Caching utilities for Misata.
+
+ Provides LLM response caching using diskcache for performance optimization.
+ """
+
+ import hashlib
+ import json
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Union
+
+ try:
+     import diskcache
+     HAS_DISKCACHE = True
+ except ImportError:
+     HAS_DISKCACHE = False
+
+
+ class LLMCache:
+     """Cache for LLM responses to avoid repeated API calls.
+
+     Uses diskcache for persistent storage with automatic expiration.
+     Falls back to in-memory dict if diskcache is not installed.
+
+     Example:
+         cache = LLMCache()
+
+         # Check cache first
+         key = cache.make_key("groq", "llama-3.3", prompt)
+         cached = cache.get(key)
+         if cached:
+             return cached
+
+         # Make LLM call
+         response = llm.generate(prompt)
+
+         # Cache the result
+         cache.set(key, response)
+     """
+
+     def __init__(
+         self,
+         cache_dir: Optional[str] = None,
+         max_size_mb: int = 100,
+         expire_days: int = 7,
+     ):
+         """Initialize the cache.
+
+         Args:
+             cache_dir: Directory for cache storage (default: ~/.misata/cache)
+             max_size_mb: Maximum cache size in MB
+             expire_days: Days before cache entries expire
+         """
+         self.expire_seconds = expire_days * 24 * 60 * 60
+
+         if cache_dir is None:
+             cache_dir = os.path.expanduser("~/.misata/cache/llm")
+
+         self.cache_dir = Path(cache_dir)
+
+         if HAS_DISKCACHE:
+             self.cache_dir.mkdir(parents=True, exist_ok=True)
+             self._cache = diskcache.Cache(
+                 str(self.cache_dir),
+                 size_limit=max_size_mb * 1024 * 1024,
+             )
+         else:
+             # Fallback to in-memory cache
+             self._cache: Dict[str, Any] = {}
+             self._memory_mode = True
+
+     def make_key(
+         self,
+         provider: str,
+         model: str,
+         prompt: str,
+         temperature: float = 0.0,
+         **kwargs: Any
+     ) -> str:
+         """Create a cache key from request parameters.
+
+         Args:
+             provider: LLM provider (groq, openai, etc.)
+             model: Model name
+             prompt: The prompt text
+             temperature: Temperature setting
+             **kwargs: Additional parameters to include in key
+
+         Returns:
+             Hash-based cache key
+         """
+         key_data = {
+             "provider": provider,
+             "model": model,
+             "prompt": prompt,
+             "temperature": temperature,
+             **kwargs
+         }
+         key_str = json.dumps(key_data, sort_keys=True)
+         return hashlib.sha256(key_str.encode()).hexdigest()[:32]
+
+     def get(self, key: str) -> Optional[Any]:
+         """Get a cached value.
+
+         Args:
+             key: Cache key
+
+         Returns:
+             Cached value or None if not found/expired
+         """
+         if HAS_DISKCACHE:
+             return self._cache.get(key)
+         else:
+             return self._cache.get(key)
+
+     def set(self, key: str, value: Any) -> None:
+         """Store a value in cache.
+
+         Args:
+             key: Cache key
+             value: Value to cache (must be JSON-serializable for persistence)
+         """
+         if HAS_DISKCACHE:
+             self._cache.set(key, value, expire=self.expire_seconds)
+         else:
+             self._cache[key] = value
+
+     def delete(self, key: str) -> bool:
+         """Delete a cached value.
+
+         Args:
+             key: Cache key
+
+         Returns:
+             True if deleted, False if not found
+         """
+         if HAS_DISKCACHE:
+             return self._cache.delete(key)
+         else:
+             if key in self._cache:
+                 del self._cache[key]
+                 return True
+             return False
+
+     def clear(self) -> None:
+         """Clear all cached values."""
+         if HAS_DISKCACHE:
+             self._cache.clear()
+         else:
+             self._cache.clear()
+
+     def stats(self) -> Dict[str, Any]:
+         """Get cache statistics."""
+         if HAS_DISKCACHE:
+             return {
+                 "type": "diskcache",
+                 "directory": str(self.cache_dir),
+                 "size_bytes": self._cache.volume(),
+                 "count": len(self._cache),
+             }
+         else:
+             return {
+                 "type": "memory",
+                 "count": len(self._cache),
+             }
+
+     def __contains__(self, key: str) -> bool:
+         return self.get(key) is not None
+
+     def close(self) -> None:
+         """Close the cache (required for diskcache)."""
+         if HAS_DISKCACHE:
+             self._cache.close()
+
+
+ class SmartValueCache:
+     """Cache for SmartValueGenerator pools.
+
+     Caches generated value pools by domain/context to avoid
+     repeated LLM calls for the same domain.
+     """
+
+     def __init__(self, cache: Optional[LLMCache] = None):
+         self._cache = cache or LLMCache(
+             cache_dir=os.path.expanduser("~/.misata/cache/smart_values")
+         )
+
+     def get_pool(
+         self,
+         domain: str,
+         context: Optional[str] = None,
+         provider: str = "groq"
+     ) -> Optional[list]:
+         """Get cached value pool for a domain.
+
+         Args:
+             domain: Domain type (disease, prescription, etc.)
+             context: Additional context
+             provider: LLM provider used
+
+         Returns:
+             List of values or None if not cached
+         """
+         key = self._make_pool_key(domain, context, provider)
+         return self._cache.get(key)
+
+     def set_pool(
+         self,
+         domain: str,
+         values: list,
+         context: Optional[str] = None,
+         provider: str = "groq"
+     ) -> None:
+         """Cache a value pool.
+
+         Args:
+             domain: Domain type
+             values: List of generated values
+             context: Additional context
+             provider: LLM provider used
+         """
+         key = self._make_pool_key(domain, context, provider)
+         self._cache.set(key, values)
+
+     def _make_pool_key(
+         self,
+         domain: str,
+         context: Optional[str],
+         provider: str
+     ) -> str:
+         key_data = f"{provider}:{domain}:{context or ''}"
+         return hashlib.sha256(key_data.encode()).hexdigest()[:24]
+
+     def clear(self) -> None:
+         """Clear all cached pools."""
+         self._cache.clear()
+
+
+ # Global cache instances
+ _llm_cache: Optional[LLMCache] = None
+ _smart_value_cache: Optional[SmartValueCache] = None
+
+
+ def get_llm_cache() -> LLMCache:
+     """Get the global LLM cache instance."""
+     global _llm_cache
+     if _llm_cache is None:
+         _llm_cache = LLMCache()
+     return _llm_cache
+
+
+ def get_smart_value_cache() -> SmartValueCache:
+     """Get the global smart value cache instance."""
+     global _smart_value_cache
+     if _smart_value_cache is None:
+         _smart_value_cache = SmartValueCache()
+     return _smart_value_cache
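Since LLMCache's full surface is visible above, here is a self-contained sketch of the intended read-through flow; only call_llm is a hypothetical stand-in for a real provider call:

    from misata.cache import get_llm_cache

    def cached_generate(prompt: str) -> str:
        cache = get_llm_cache()  # global LLMCache; disk-backed when diskcache is installed
        key = cache.make_key(provider="groq", model="llama-3.3", prompt=prompt)

        cached = cache.get(key)
        if cached is not None:
            return cached

        response = call_llm(prompt)  # hypothetical provider call
        cache.set(key, response)  # under diskcache, expires after expire_days (default 7)
        return response

Note that only the diskcache backend honors the expiration and size limit; the in-memory fallback keeps entries for the life of the process.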
misata-0.3.1b0/misata/constraints.py
@@ -0,0 +1,307 @@
+ """
+ Constraint handling for Misata data generation.
+
+ Provides constraint classes for applying business rules to generated data.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+
+ from misata.exceptions import ConstraintError
+
+
+ class BaseConstraint(ABC):
+     """Abstract base class for all constraints."""
+
+     @abstractmethod
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         """Apply the constraint to a DataFrame.
+
+         Args:
+             df: DataFrame to constrain
+
+         Returns:
+             Constrained DataFrame
+         """
+         pass
+
+     @abstractmethod
+     def validate(self, df: pd.DataFrame) -> bool:
+         """Check if the constraint is satisfied.
+
+         Args:
+             df: DataFrame to check
+
+         Returns:
+             True if constraint is satisfied
+         """
+         pass
+
+
+ class SumConstraint(BaseConstraint):
+     """Ensures sum of a column (optionally grouped) doesn't exceed a value."""
+
+     def __init__(
+         self,
+         column: str,
+         max_sum: float,
+         group_by: Optional[List[str]] = None,
+         action: str = "cap"
+     ):
+         self.column = column
+         self.max_sum = max_sum
+         self.group_by = group_by or []
+         self.action = action  # 'cap', 'redistribute', 'drop'
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         if self.column not in df.columns:
+             return df
+
+         df = df.copy()
+
+         if not self.group_by:
+             # Global sum constraint
+             current_sum = df[self.column].sum()
+             if current_sum > self.max_sum:
+                 if self.action == "cap":
+                     scale = self.max_sum / current_sum
+                     df[self.column] = df[self.column] * scale
+                 elif self.action == "drop":
+                     # Keep first N rows that fit
+                     cumsum = df[self.column].cumsum()
+                     df = df[cumsum <= self.max_sum]
+         else:
+             # Grouped sum constraint
+             def cap_group(group):
+                 current_sum = group[self.column].sum()
+                 if current_sum > self.max_sum:
+                     if self.action == "cap":
+                         scale = self.max_sum / current_sum
+                         group = group.copy()
+                         group[self.column] = group[self.column] * scale
+                 return group
+
+             df = df.groupby(self.group_by, group_keys=False).apply(cap_group)
+
+         return df
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         if self.column not in df.columns:
+             return True
+
+         if not self.group_by:
+             return df[self.column].sum() <= self.max_sum
+
+         group_sums = df.groupby(self.group_by)[self.column].sum()
+         return (group_sums <= self.max_sum).all()
+
+
+ class RangeConstraint(BaseConstraint):
+     """Ensures values in a column stay within a range."""
+
+     def __init__(self, column: str, min_val: Optional[float] = None, max_val: Optional[float] = None):
+         self.column = column
+         self.min_val = min_val
+         self.max_val = max_val
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         if self.column not in df.columns:
+             return df
+
+         df = df.copy()
+
+         if self.min_val is not None:
+             df[self.column] = df[self.column].clip(lower=self.min_val)
+
+         if self.max_val is not None:
+             df[self.column] = df[self.column].clip(upper=self.max_val)
+
+         return df
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         if self.column not in df.columns:
+             return True
+
+         values = df[self.column]
+
+         if self.min_val is not None and (values < self.min_val).any():
+             return False
+
+         if self.max_val is not None and (values > self.max_val).any():
+             return False
+
+         return True
+
+
+ class UniqueConstraint(BaseConstraint):
+     """Ensures values in column(s) are unique."""
+
+     def __init__(self, columns: Union[str, List[str]]):
+         self.columns = [columns] if isinstance(columns, str) else columns
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         # Drop duplicates
+         return df.drop_duplicates(subset=self.columns)
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         return not df.duplicated(subset=self.columns).any()
+
+
+ class NotNullConstraint(BaseConstraint):
+     """Ensures a column has no null values."""
+
+     def __init__(self, column: str, fill_value: Any = None):
+         self.column = column
+         self.fill_value = fill_value
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         if self.column not in df.columns:
+             return df
+
+         df = df.copy()
+
+         if self.fill_value is not None:
+             df[self.column] = df[self.column].fillna(self.fill_value)
+         else:
+             df = df.dropna(subset=[self.column])
+
+         return df
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         if self.column not in df.columns:
+             return True
+         return not df[self.column].isnull().any()
+
+
+ class RatioConstraint(BaseConstraint):
+     """Ensures ratio between categories matches target distribution."""
+
+     def __init__(self, column: str, target_ratios: Dict[Any, float]):
+         self.column = column
+         self.target_ratios = target_ratios
+         # Normalize ratios
+         total = sum(target_ratios.values())
+         self.target_ratios = {k: v / total for k, v in target_ratios.items()}
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         if self.column not in df.columns:
+             return df
+
+         df = df.copy()
+         n = len(df)
+
+         # Calculate target counts
+         target_counts = {k: int(n * v) for k, v in self.target_ratios.items()}
+
+         # Randomly assign categories
+         categories = []
+         for cat, count in target_counts.items():
+             categories.extend([cat] * count)
+
+         # Fill remaining
+         remaining = n - len(categories)
+         if remaining > 0:
+             most_common = max(self.target_ratios, key=self.target_ratios.get)
+             categories.extend([most_common] * remaining)
+
+         np.random.shuffle(categories)
+         df[self.column] = categories[:n]
+
+         return df
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         if self.column not in df.columns:
+             return True
+
+         actual = df[self.column].value_counts(normalize=True)
+
+         for cat, target in self.target_ratios.items():
+             actual_ratio = actual.get(cat, 0)
+             if abs(actual_ratio - target) > 0.05:  # 5% tolerance
+                 return False
+
+         return True
+
+
+ class TemporalConstraint(BaseConstraint):
+     """Ensures temporal ordering between columns."""
+
+     def __init__(self, before_column: str, after_column: str, min_gap_days: int = 0):
+         self.before_column = before_column
+         self.after_column = after_column
+         self.min_gap_days = min_gap_days
+
+     def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+         if self.before_column not in df.columns or self.after_column not in df.columns:
+             return df
+
+         df = df.copy()
+
+         before = pd.to_datetime(df[self.before_column])
+         after = pd.to_datetime(df[self.after_column])
+
+         # Fix violations
+         mask = after < before
+         if mask.any():
+             # Swap dates where violated
+             df.loc[mask, self.after_column], df.loc[mask, self.before_column] = \
+                 df.loc[mask, self.before_column], df.loc[mask, self.after_column]
+
+         # Apply minimum gap
+         if self.min_gap_days > 0:
+             before = pd.to_datetime(df[self.before_column])
+             after = pd.to_datetime(df[self.after_column])
+             gap = (after - before).dt.days
+
+             mask = gap < self.min_gap_days
+             if mask.any():
+                 df.loc[mask, self.after_column] = (
+                     before[mask] + pd.Timedelta(days=self.min_gap_days)
+                 ).dt.strftime('%Y-%m-%d')
+
+         return df
+
+     def validate(self, df: pd.DataFrame) -> bool:
+         if self.before_column not in df.columns or self.after_column not in df.columns:
+             return True
+
+         before = pd.to_datetime(df[self.before_column])
+         after = pd.to_datetime(df[self.after_column])
+
+         gap = (after - before).dt.days
+         return (gap >= self.min_gap_days).all()
+
+
+ class ConstraintEngine:
+     """Engine for applying multiple constraints to a DataFrame."""
+
+     def __init__(self, constraints: Optional[List[BaseConstraint]] = None):
+         self.constraints = constraints or []
+
+     def add(self, constraint: BaseConstraint) -> "ConstraintEngine":
+         """Add a constraint."""
+         self.constraints.append(constraint)
+         return self
+
+     def apply_all(self, df: pd.DataFrame) -> pd.DataFrame:
+         """Apply all constraints in order."""
+         for constraint in self.constraints:
+             try:
+                 df = constraint.apply(df)
+             except Exception as e:
+                 raise ConstraintError(
+                     f"Failed to apply constraint: {e}",
+                     constraint_type=type(constraint).__name__
+                 )
+         return df
+
+     def validate_all(self, df: pd.DataFrame) -> Dict[str, bool]:
+         """Check all constraints and return results."""
+         results = {}
+         for constraint in self.constraints:
+             name = type(constraint).__name__
+             results[name] = constraint.validate(df)
+         return results