misata 0.3.0b0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. misata/__init__.py +1 -1
  2. misata/agents/__init__.py +23 -0
  3. misata/agents/pipeline.py +286 -0
  4. misata/causal/__init__.py +5 -0
  5. misata/causal/graph.py +109 -0
  6. misata/causal/solver.py +115 -0
  7. misata/cli.py +31 -0
  8. misata/generators/__init__.py +19 -0
  9. misata/generators/copula.py +198 -0
  10. misata/llm_parser.py +180 -137
  11. misata/quality.py +78 -33
  12. misata/reference_data.py +221 -0
  13. misata/research/__init__.py +3 -0
  14. misata/research/agent.py +70 -0
  15. misata/schema.py +25 -0
  16. misata/simulator.py +264 -12
  17. misata/smart_values.py +144 -6
  18. misata/studio/__init__.py +55 -0
  19. misata/studio/app.py +49 -0
  20. misata/studio/components/inspector.py +81 -0
  21. misata/studio/components/sidebar.py +35 -0
  22. misata/studio/constraint_generator.py +781 -0
  23. misata/studio/inference.py +319 -0
  24. misata/studio/outcome_curve.py +284 -0
  25. misata/studio/state/store.py +55 -0
  26. misata/studio/tabs/configure.py +50 -0
  27. misata/studio/tabs/generate.py +117 -0
  28. misata/studio/tabs/outcome_curve.py +149 -0
  29. misata/studio/tabs/schema_designer.py +217 -0
  30. misata/studio/utils/styles.py +143 -0
  31. misata/studio_constraints/__init__.py +29 -0
  32. misata/studio_constraints/z3_solver.py +259 -0
  33. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/METADATA +13 -2
  34. misata-0.5.0.dist-info/RECORD +61 -0
  35. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/WHEEL +1 -1
  36. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/entry_points.txt +1 -0
  37. misata-0.3.0b0.dist-info/RECORD +0 -37
  38. /misata/{generators.py → generators_legacy.py} +0 -0
  39. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/licenses/LICENSE +0 -0
  40. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/top_level.txt +0 -0
misata/studio/constraint_generator.py (new file)
@@ -0,0 +1,781 @@
+ """
+ Constraint-Driven Data Warehouse Generator
+
+ This module generates complete data warehouses where:
+ 1. Dimension tables are generated with specified distributions
+ 2. Fact tables are generated to satisfy outcome constraints (e.g., monthly revenue targets)
+ 3. All foreign key relationships are maintained
+
+ Example:
+     schema = {
+         "customers": {...},  # dimension
+         "products": {...},   # dimension
+         "orders": {...}      # fact table with revenue constraint
+     }
+
+     outcome = OutcomeCurve(revenue=[100_000, 150_000, 200_000, ...])
+
+     result = generate_constrained_warehouse(schema, outcome)
+     # result["orders"] will sum to exactly the outcome curve when grouped by month
+ """
+
+ from dataclasses import dataclass, field
+ from datetime import datetime, timedelta
+ from typing import Any, Dict, List, Optional, Tuple
+ import numpy as np
+ import pandas as pd
+
+ from misata.schema import SchemaConfig
+ from misata.reference_data import detect_domain, get_reference_data
+ from misata.studio.outcome_curve import (
+     OutcomeCurve, CurvePoint,
+     generate_transactions_for_bucket
+ )
+ from misata.causal.graph import get_saas_template, CausalGraph
+ from misata.causal.solver import CausalSolver
+
+ # Import ConstraintEngine for 100% business rule compliance (MisataStudio)
+ try:
+     from misata.studio_constraints.z3_solver import (
+         ConstraintEngine,
+         add_common_business_rules,
+         create_constraint_engine
+     )
+     CONSTRAINT_ENGINE_AVAILABLE = True
+ except ImportError:
+     CONSTRAINT_ENGINE_AVAILABLE = False
+     ConstraintEngine = None
+     # Fail silently: the engine is optional and only needed for studio features
+
+
+
+ @dataclass
+ class ColumnSpec:
+     """Specification for a column's distribution."""
+     name: str
+     type: str  # "int", "float", "categorical", "text", "date", "boolean", "foreign_key"
+
+     # For numeric
+     min_val: Optional[float] = None
+     max_val: Optional[float] = None
+     distribution: str = "uniform"  # "uniform", "normal", "lognormal"
+     mean: Optional[float] = None
+     std: Optional[float] = None
+
+     # For categorical
+     choices: Optional[List[str]] = None
+     probabilities: Optional[List[float]] = None
+
+     # For text
+     text_type: str = "name"  # "name", "email", "company", "uuid", etc.
+
+     # For foreign_key
+     references: Optional[str] = None  # "customers.id"
+
+     # For date
+     date_start: Optional[datetime] = None
+     date_end: Optional[datetime] = None
+
+
+ @dataclass
+ class TableSpec:
+     """Specification for a table."""
+     name: str
+     row_count: int
+     columns: List[ColumnSpec]
+     is_fact: bool = False  # True for fact tables that get outcome constraints
+
+     # For fact tables
+     date_column: Optional[str] = None    # Which column holds the date
+     amount_column: Optional[str] = None  # Which column holds the value to constrain
+
+     # For reference tables - LLM-generated actual data
+     inline_data: Optional[List[Dict[str, Any]]] = None
+
+
+ @dataclass
+ class OutcomeConstraint:
+     """An outcome constraint for a metric."""
+     metric_name: str   # e.g., "revenue", "orders", "signups"
+     fact_table: str    # Which table to constrain
+     value_column: str  # Which column to sum (e.g., "amount")
+     date_column: str   # Which column has the date
+     outcome_curve: OutcomeCurve  # The target values
+
+
+ @dataclass
+ class WarehouseSpec:
+     """Complete specification for a data warehouse."""
+     tables: List[TableSpec]
+     constraints: List[OutcomeConstraint] = field(default_factory=list)
+
+
+ class ConstrainedWarehouseGenerator:
+     """Generates complete data warehouses with outcome constraints."""
+
+     def __init__(self, spec: WarehouseSpec, seed: int = 42):
+         self.spec = spec
+         self.rng = np.random.default_rng(seed)
+         self.generated_tables: Dict[str, pd.DataFrame] = {}
+
+         # Auto-detect domain from table names
+         table_names = [t.name for t in spec.tables]
+         self.domain = detect_domain(table_names)
+
+         # Causal Engine integration
+         self.causal_graph: Optional[CausalGraph] = None
+         self.causal_mapping: Dict[str, str] = {}  # GraphNode -> TableName
+         self._detect_and_setup_causal()
+
+     def _detect_and_setup_causal(self):
+         """Attempts to map the current schema to a Causal Template (SaaS)."""
+         # Try the SaaS template, whose nodes are: Traffic, Leads, Deals, Revenue
+         mapping = {}
+         tables = {t.name.lower(): t.name for t in self.spec.tables}
+
+         # Heuristic mapping by table-name keyword
+         # 1. Traffic
+         for kw in ['traffic', 'visits', 'sessions']:
+             matching = [real for low, real in tables.items() if kw in low]
+             if matching:
+                 mapping["Traffic"] = matching[0]
+                 break
+
+         # 2. Leads
+         for kw in ['leads', 'signups', 'users', 'registrations']:
+             matching = [real for low, real in tables.items() if kw in low]
+             if matching:
+                 mapping["Leads"] = matching[0]
+                 break
+
+         # 3. Deals/Revenue (often the same table, e.g. invoices or subscriptions)
+         for kw in ['deals', 'orders', 'subscriptions', 'invoices', 'sales']:
+             matching = [real for low, real in tables.items() if kw in low]
+             if matching:
+                 mapping["Deals"] = matching[0]
+                 mapping["Revenue"] = matching[0]  # Revenue is usually a column in Deals
+                 break
+
+         # If we have at least Traffic and Revenue/Deals, we can use the graph
+         if "Traffic" in mapping and "Revenue" in mapping:
+             self.causal_graph = get_saas_template()
+             self.causal_mapping = mapping
+             print(f"✅ Causal Engine Activated: Mapped to SaaS Template ({mapping})")
+
+
+     def generate_all(self) -> Dict[str, pd.DataFrame]:
+         """Generate all tables in the warehouse."""
+
+         # 1. Try causal generation first if activated
+         if self.causal_graph and self.spec.constraints:
+             try:
+                 self._generate_causal_flow()
+                 # On success we still need to generate dimensions and fill remaining
+                 # columns; the causal flow populates 'self.generated_tables' with the core rows.
+             except Exception as e:
+                 print(f"⚠️ Causal Generation Failed: {e}. Falling back to stochastic.")
+
+         # 2. Separate dimension tables from fact tables
+         dimension_tables = [t for t in self.spec.tables if not t.is_fact]
+         fact_tables = [t for t in self.spec.tables if t.is_fact]
+
+         # 3. Generate dimension tables (if not already generated by the Causal Engine)
+         for table in dimension_tables:
+             if table.name in self.generated_tables:
+                 continue
+             self.generated_tables[table.name] = self._generate_dimension_table(table)
+
+         # 4. Generate fact tables with constraints (if not already generated)
+         for table in fact_tables:
+             if table.name in self.generated_tables:
+                 continue
+             # Find the constraint for this table
+             constraint = next(
+                 (c for c in self.spec.constraints if c.fact_table == table.name),
+                 None
+             )
+             self.generated_tables[table.name] = self._generate_fact_table(table, constraint)
+
+         # 5. ENFORCE BUSINESS RULES (100% compliance) - key differentiator!
+         if CONSTRAINT_ENGINE_AVAILABLE:
+             self._enforce_business_rules()
+
+         return self.generated_tables
+
+     def _enforce_business_rules(self):
+         """
+         Post-generation step to ensure 100% business rule compliance.
+         This is what differentiates us from Gretel (probabilistic only).
+         """
+         if not CONSTRAINT_ENGINE_AVAILABLE:
+             return
+
+         engine = create_constraint_engine()
+
+         # Build a schema dict for rule detection
+         schema_dict = {
+             "tables": [{"name": t.name} for t in self.spec.tables],
+             "columns": {
+                 t.name: [{"name": c.name, "type": c.type, "distribution_params": {
+                     "min": c.min_val,
+                     "max": c.max_val
+                 }} for c in t.columns]
+                 for t in self.spec.tables
+             }
+         }
+
+         # Auto-add common business rules
+         add_common_business_rules(engine, schema_dict)
+
+         # Apply constraints to each table
+         total_violations = 0
+         for table_name, df in self.generated_tables.items():
+             if len(df) == 0:
+                 continue
+
+             # Validate
+             result = engine.validate(df)
+
+             if not result["is_100_percent_compliant"]:
+                 violations = len(df) - result["valid_rows"]
+                 total_violations += violations
+                 print(f"[CONSTRAINT] {table_name}: {violations} violations found, fixing...")
+
+                 # Fix violations by filtering out the offending rows
+                 self.generated_tables[table_name] = engine.filter_valid(df)
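+                 # Note: repair drops invalid rows rather than regenerating them, so a
+                 # filtered table can end up smaller than its requested row_count.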
+
+         if total_violations > 0:
+             print(f"[CONSTRAINT] Total violations fixed: {total_violations} rows removed")
+         else:
+             print("[CONSTRAINT] ✅ All tables 100% compliant!")
+
+     def _generate_causal_flow(self):
+         """
+         Executes the Causal Solver to generate consistent data across tables.
+         """
+         if not self.causal_graph or not self.spec.constraints:
+             return
+
+         # 1. Extract constraints: map a graph node (e.g. 'Revenue') to its target values
+         target_constraints = {}
+
+         for constraint in self.spec.constraints:
+             # Check whether this constraint applies to a mapped node, e.g.
+             # mapping: Revenue -> table 'invoices'; constraint: table 'invoices', metric 'revenue'
+
+             # Find which node maps to this table
+             mapped_node = None
+             for node, table_name in self.causal_mapping.items():
+                 if table_name == constraint.fact_table:
+                     # Check whether the metric matches (Revenue vs Deals)
+                     if node == "Revenue" and constraint.metric_name.lower() in ['revenue', 'amount', 'sales']:
+                         mapped_node = "Revenue"
+                     elif node == "Deals" and constraint.metric_name.lower() in ['count', 'orders', 'deals', 'volume']:
+                         mapped_node = "Deals"
+
+             if mapped_node:
+                 # Extract a simple curve array (assuming monthly buckets, no resampling);
+                 # this v1 uses the raw point values
+                 values = [p.value for p in constraint.outcome_curve.points]
+                 target_constraints[mapped_node] = np.array(values)
+
+         if not target_constraints:
+             return  # No relevant constraints for the graph
+
+         # 2. Solve for Traffic when Revenue is constrained
+         solver = CausalSolver(self.causal_graph)
+         adjustable = ["Traffic"]
+
+         solved_inputs = solver.solve(target_constraints, adjustable_nodes=adjustable)
+
+         # 3. Forward pass to get all node values.
+         # Add defaults for the exogenous conversion rates.
+         # TODO: Get these from "Fact Injection" in Phase 7
+         full_inputs = solved_inputs.copy()
+         sample_size = len(list(target_constraints.values())[0])
+         full_inputs["LeadConversion"] = np.ones(sample_size) * 0.05   # 5% conversion
+         full_inputs["SalesConversion"] = np.ones(sample_size) * 0.20  # 20% conversion
+         full_inputs["AOV"] = np.ones(sample_size) * 100.0             # $100 AOV
+
+         # Caveat: hitting the revenue target exactly with an integer deal count may
+         # require shifting AOV, and the solver solved for Traffic assuming its own
+         # defaults (1.0 in solver.py) rather than the values set above. The two sets
+         # of defaults need to match; making this robust is deferred for now.
+
+         # Re-run the forward pass with our specific defaults
+         final_nodes = self.causal_graph.forward_pass(full_inputs)  # dict of node -> array
+
+         # 4. Generate tables from the node arrays
+         # Traffic node -> Traffic table
+         if "Traffic" in self.causal_mapping:
+             t_name = self.causal_mapping["Traffic"]
+             self.generated_tables[t_name] = self._generate_table_from_curve(t_name, final_nodes["Traffic"])
+
+         # Leads node -> Leads table
+         if "Leads" in self.causal_mapping:
+             t_name = self.causal_mapping["Leads"]
+             if t_name not in self.generated_tables:  # Avoid overwrite if same table
+                 self.generated_tables[t_name] = self._generate_table_from_curve(t_name, final_nodes["Leads"])
+
+         # Deals/Revenue node -> fact table
+         if "Deals" in self.causal_mapping:
+             t_name = self.causal_mapping["Deals"]
+             # If the table is already generated (Traffic = Leads = Deals in one table?),
+             # a merge would be needed; assuming distinct tables or overwriting for now.
+             self.generated_tables[t_name] = self._generate_table_from_curve(t_name, final_nodes["Deals"], revenue_array=final_nodes.get("Revenue"))
+
+     def _generate_table_from_curve(self, table_name: str, count_array: np.ndarray, revenue_array: Optional[np.ndarray] = None) -> pd.DataFrame:
+         """Generates a table where row counts per bucket match the count_array."""
+         table_spec = next(t for t in self.spec.tables if t.name == table_name)
+
+         # Assume monthly buckets (from the constraints) and a default start date
+         start_date = datetime.now() - timedelta(days=365)
+
+         all_rows = []
+         for i, count in enumerate(count_array):
+             bucket_start = start_date + timedelta(days=30*i)
+             num_rows = int(max(0, count))  # Ensure integer and non-negative
+
+             # Generate basic columns
+             data = {}
+             for col in table_spec.columns:
+                 data[col.name] = self._generate_column(col, num_rows)
+
+             df = pd.DataFrame(data)
+
+             # Override the date column: spread rows randomly within the month
+             if table_spec.date_column:
+                 offsets = self.rng.integers(0, 30, num_rows)
+                 dates = [bucket_start + timedelta(days=int(o)) for o in offsets]
+                 df[table_spec.date_column] = dates
+
+             # Override revenue if provided
+             if revenue_array is not None and table_spec.amount_column and num_rows > 0:
+                 # Distribute the bucket's total revenue among its rows: start from the
+                 # simple average, add some noise, then rescale to the exact sum
+                 total_rev = revenue_array[i]
+                 avg_rev = total_rev / num_rows
+                 revs = self.rng.normal(avg_rev, avg_rev * 0.1, num_rows)
+                 revs = revs * (total_rev / revs.sum())
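+                 # Worked example of the rescale: if total_rev is 1000 and the noisy
+                 # draws sum to 980, each draw is scaled by 1000/980, preserving the
+                 # relative noise while forcing the bucket to sum to exactly 1000.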
+                 df[table_spec.amount_column] = revs
+
+             all_rows.append(df)
+
+         final_df = pd.concat(all_rows, ignore_index=True) if all_rows else pd.DataFrame()
+         # Add an ID column
+         if 'id' not in final_df.columns and len(final_df) > 0:
+             final_df.insert(0, 'id', range(1, len(final_df) + 1))
+
+         return final_df
+
+
+     def _generate_dimension_table(self, table: TableSpec) -> pd.DataFrame:
+         """Generate a dimension table with specified distributions."""
+
+         # Priority 1: Use inline_data if provided (from LLM)
+         if table.inline_data:
+             df = pd.DataFrame(table.inline_data)
+             self.generated_tables[table.name] = df
+             return df
+
+         # Priority 2: Use the domain-aware reference library
+         library_data = get_reference_data(self.domain, table.name)
+         if library_data:
+             df = pd.DataFrame(library_data)
+             self.generated_tables[table.name] = df
+             return df
+
+         # Priority 3: Generate from column specs with smart awareness
+         data = {}
+
+         for col in table.columns:
+             data[col.name] = self._generate_column(col, table.row_count)
+
+         df = pd.DataFrame(data)
+
+         # Add an ID column if not present
+         if 'id' not in df.columns:
+             df.insert(0, 'id', range(1, len(df) + 1))
+
+         return df
+
+     def _generate_fact_table(
+         self,
+         table: TableSpec,
+         constraint: Optional[OutcomeConstraint]
+     ) -> pd.DataFrame:
+         """Generate a fact table, optionally with outcome constraints."""
+
+         if constraint is None:
+             # No constraint - just generate normally
+             return self._generate_dimension_table(table)
+
+         # With a constraint: use the outcome curve to generate rows
+         curve = constraint.outcome_curve
+         n_periods = len(curve.points)
+
+         # Determine bucket duration
+         if curve.time_unit == "day":
+             bucket_delta = timedelta(days=1)
+         elif curve.time_unit == "week":
+             bucket_delta = timedelta(weeks=1)
+         else:
+             bucket_delta = timedelta(days=30)
+
+         all_rows = []
+
+         for i, point in enumerate(curve.points):
+             bucket_start = point.timestamp
+             bucket_end = bucket_start + bucket_delta
+
+             # Generate transactions for this bucket
+             bucket_df = generate_transactions_for_bucket(
+                 target_value=point.value,
+                 bucket_start=bucket_start,
+                 bucket_end=bucket_end,
+                 avg_transaction=50.0,  # Could be configurable
+                 rng=self.rng
+             )
+
+             # Add the other columns
+             for col in table.columns:
+                 if col.name == constraint.date_column:
+                     # Date column already generated as 'timestamp'
+                     bucket_df[col.name] = bucket_df['timestamp']
+                 elif col.name == constraint.value_column:
+                     # Amount column already generated as 'amount'
+                     bucket_df[col.name] = bucket_df['amount']
+                 elif col.type == "foreign_key" and col.references:
+                     # Link to the dimension table (uniform sampling with replacement
+                     # keeps referential integrity)
+                     ref_table, ref_col = col.references.split('.')
+                     if ref_table in self.generated_tables:
+                         fk_values = self.generated_tables[ref_table][ref_col].values
+                         bucket_df[col.name] = self.rng.choice(fk_values, size=len(bucket_df))
+                 else:
+                     # Generate the remaining columns
+                     bucket_df[col.name] = self._generate_column(col, len(bucket_df))
+
+             # Clean up temp columns
+             if 'timestamp' in bucket_df.columns and 'timestamp' != constraint.date_column:
+                 bucket_df = bucket_df.drop('timestamp', axis=1)
+             if 'amount' in bucket_df.columns and 'amount' != constraint.value_column:
+                 bucket_df = bucket_df.drop('amount', axis=1)
+
+             all_rows.append(bucket_df)
+
+         # Combine all periods
+         df = pd.concat(all_rows, ignore_index=True)
+
+         # Add an ID column if not present
+         if 'id' not in df.columns:
+             df.insert(0, 'id', range(1, len(df) + 1))
+
+         return df
+
+     def _generate_column(self, col: ColumnSpec, size: int) -> np.ndarray:
+         """Generate values for a single column."""
+         col_name_lower = col.name.lower()
+
+         if col.type == "int":
+             if col.distribution == "normal":
+                 mean = col.mean or ((col.min_val or 0) + (col.max_val or 100)) / 2
+                 std = col.std or ((col.max_val or 100) - (col.min_val or 0)) / 6
+                 values = self.rng.normal(mean, std, size)
+                 values = np.clip(values, col.min_val or 0, col.max_val or 100)
+                 return values.astype(int)
+             else:  # uniform
+                 return self.rng.integers(col.min_val or 0, col.max_val or 100, size)
+
+         elif col.type == "float":
+             # Smart price detection (note: overrides min_val/max_val with fixed tiers)
+             if any(kw in col_name_lower for kw in ['price', 'cost', 'amount', 'fee']):
+                 # Generate realistic price tiers
+                 price_tiers = [0.0, 9.99, 14.99, 19.99, 29.99, 49.99, 99.99, 199.99]
+                 return self.rng.choice(price_tiers, size=size)
+             elif col.distribution == "normal":
+                 mean = col.mean or ((col.min_val or 0) + (col.max_val or 100)) / 2
+                 std = col.std or ((col.max_val or 100) - (col.min_val or 0)) / 6
+                 values = self.rng.normal(mean, std, size)
+                 return np.clip(values, col.min_val or 0, col.max_val or 100)
+             elif col.distribution == "lognormal":
+                 values = self.rng.lognormal(col.mean or 0, col.std or 1, size)
+                 if col.max_val:
+                     values = np.clip(values, col.min_val or 0, col.max_val)
+                 return values
+             else:  # uniform
+                 return self.rng.uniform(col.min_val or 0, col.max_val or 100, size)
+
+         elif col.type == "categorical":
+             choices = col.choices or ["A", "B", "C"]
+             probs = col.probabilities
+             if probs:
+                 probs = np.array(probs) / sum(probs)  # Normalize
+             return self.rng.choice(choices, size=size, p=probs)
+
+         elif col.type == "boolean":
+             prob = col.probabilities[0] if col.probabilities else 0.5
+             return self.rng.random(size) < prob
+
+         elif col.type == "date":
+             start = col.date_start or datetime.now() - timedelta(days=365)
+             end = col.date_end or datetime.now()
+             start_ts = start.timestamp()
+             end_ts = end.timestamp()
+             random_ts = self.rng.uniform(start_ts, end_ts, size)
+             return pd.to_datetime(random_ts, unit='s')
+
+         elif col.type == "text":
+             # Smart text generation based on the column name
+             if "category" in col_name_lower or "type" in col_name_lower:
+                 categories = ["Electronics", "Clothing", "Home & Garden", "Sports", "Books", "Toys", "Health", "Automotive"]
+                 return self.rng.choice(categories, size=size)
+             elif "feature" in col_name_lower or "description" in col_name_lower:
+                 features = ["Premium Support", "Advanced Analytics", "Custom Reports", "API Access", "Priority Queue", "Unlimited Storage", "24/7 Support"]
+                 return self.rng.choice(features, size=size)
+             elif "status" in col_name_lower:
+                 statuses = ["active", "pending", "completed", "cancelled", "on_hold"]
+                 return self.rng.choice(statuses, size=size)
+             elif "plan" in col_name_lower or "tier" in col_name_lower:
+                 plans = ["Free", "Basic", "Pro", "Premium", "Enterprise"]
+                 return self.rng.choice(plans, size=size)
+             elif col.text_type == "email":
+                 return [f"user{i}@example.com" for i in self.rng.integers(1000, 9999, size)]
+             elif col.text_type == "uuid":
+                 import uuid
+                 return [str(uuid.uuid4()) for _ in range(size)]
+             elif col.text_type == "company":
+                 companies = ["Acme Inc", "TechCorp", "GlobalSoft", "DataDrive", "CloudBase", "ByteForge", "NexGen Systems"]
+                 return self.rng.choice(companies, size=size)
+             else:  # name
+                 first = ["John", "Jane", "Bob", "Alice", "Charlie", "Diana", "Eve", "Frank", "Grace", "Henry"]
+                 last = ["Smith", "Jones", "Brown", "Wilson", "Taylor", "Davis", "Clark", "Moore", "Anderson"]
+                 return [f"{self.rng.choice(first)} {self.rng.choice(last)}" for _ in range(size)]
+
+         elif col.type == "foreign_key" and col.references:
+             ref_table, ref_col = col.references.split('.')
+             if ref_table in self.generated_tables:
+                 fk_values = self.generated_tables[ref_table][ref_col].values
+                 return self.rng.choice(fk_values, size=size)
+             return self.rng.integers(1, 100, size)
+
+         else:
+             return self.rng.integers(1, 100, size)
+
+
574
+ def generate_constrained_warehouse(
575
+ spec: WarehouseSpec,
576
+ seed: int = 42
577
+ ) -> Dict[str, pd.DataFrame]:
578
+ """
579
+ Generate a complete data warehouse with outcome constraints.
580
+
581
+ Args:
582
+ spec: Complete warehouse specification
583
+ seed: Random seed for reproducibility
584
+
585
+ Returns:
586
+ Dict mapping table names to DataFrames
587
+ """
588
+ generator = ConstrainedWarehouseGenerator(spec, seed)
589
+ return generator.generate_all()
590
+
591
+
+ # ============ Quick Builder Functions ============
+
+ def create_service_company_schema(
+     customer_count: int = 500,
+     project_count: int = 2000,
+     revenue_curve: Optional[OutcomeCurve] = None
+ ) -> WarehouseSpec:
+     """Create a typical service company data warehouse schema."""
+
+     customers = TableSpec(
+         name="customers",
+         row_count=customer_count,
+         columns=[
+             ColumnSpec(name="id", type="int", min_val=1, max_val=customer_count),
+             ColumnSpec(name="name", type="text", text_type="name"),
+             ColumnSpec(name="email", type="text", text_type="email"),
+             ColumnSpec(name="tier", type="categorical",
+                        choices=["Basic", "Pro", "Enterprise"],
+                        probabilities=[0.5, 0.3, 0.2]),
+             ColumnSpec(name="created_at", type="date"),
+         ]
+     )
+
+     projects = TableSpec(
+         name="projects",
+         row_count=project_count,
+         columns=[
+             ColumnSpec(name="id", type="int", min_val=1, max_val=project_count),
+             ColumnSpec(name="customer_id", type="foreign_key", references="customers.id"),
+             ColumnSpec(name="name", type="text", text_type="company"),
+             ColumnSpec(name="status", type="categorical",
+                        choices=["Active", "Completed", "On Hold"],
+                        probabilities=[0.6, 0.3, 0.1]),
+         ]
+     )
+
+     invoices = TableSpec(
+         name="invoices",
+         row_count=10000,  # Actual count is determined by the constraint
+         columns=[
+             ColumnSpec(name="id", type="int"),
+             ColumnSpec(name="project_id", type="foreign_key", references="projects.id"),
+             ColumnSpec(name="invoice_date", type="date"),
+             ColumnSpec(name="amount", type="float", min_val=100, max_val=10000),
+             ColumnSpec(name="status", type="categorical",
+                        choices=["Paid", "Pending", "Overdue"],
+                        probabilities=[0.7, 0.2, 0.1]),
+         ],
+         is_fact=True,
+         date_column="invoice_date",
+         amount_column="amount"
+     )
+
+     constraints = []
+     if revenue_curve:
+         constraints.append(OutcomeConstraint(
+             metric_name="revenue",
+             fact_table="invoices",
+             value_column="amount",
+             date_column="invoice_date",
+             outcome_curve=revenue_curve
+         ))
+
+     return WarehouseSpec(
+         tables=[customers, projects, invoices],
+         constraints=constraints
+     )
+
+
+ def convert_schema_config_to_spec(
+     config: SchemaConfig,
+     revenue_curve: Optional[OutcomeCurve] = None,
+     fact_table_override: Optional[str] = None,
+     date_col_override: Optional[str] = None,
+     amount_col_override: Optional[str] = None
+ ) -> WarehouseSpec:
+     """Convert a generic SchemaConfig (e.g. from an LLM) to a WarehouseSpec.
+
+     If fact_table_override is provided, use that instead of heuristic detection.
+     """
+
+     table_specs = []
+     fact_table_name = None
+     date_col_name = None
+     amount_col_name = None
+
+     # Use overrides if provided (from the LLM curve specification)
+     if fact_table_override:
+         fact_table_name = fact_table_override
+         date_col_name = date_col_override
+         amount_col_name = amount_col_override
+         print(f"[SPEC DEBUG] Using LLM overrides: fact={fact_table_name}, date={date_col_name}, amount={amount_col_name}")
+     else:
+         # 1. Identify the fact table heuristically: the largest non-reference table
+         # with a date column and a numeric amount column
+         candidate_tables = []
+
+         for table_name, columns in config.columns.items():
+             # Find the table definition
+             table_def = next((t for t in config.tables if t.name == table_name), None)
+             if not table_def or table_def.is_reference:
+                 continue
+
+             has_date = any(c.type in ['date', 'datetime'] for c in columns)
+             # Broader amount detection: any float/int column that isn't an id or count
+             amount_cols = [c.name for c in columns if c.type in ['float', 'int'] and c.name not in ['id', 'count'] and not c.name.endswith('_id')]
+
+             if has_date and amount_cols:
+                 candidate_tables.append({
+                     "name": table_name,
+                     "rows": table_def.row_count,
+                     "date_col": next(c.name for c in columns if c.type in ['date', 'datetime']),
+                     "amount_col": amount_cols[0]  # Pick the first numeric column
+                 })
+
+         # Sort candidates by row count (fact tables are usually largest)
+         if candidate_tables:
+             candidate_tables.sort(key=lambda x: x["rows"], reverse=True)
+             best_candidate = candidate_tables[0]
+             fact_table_name = best_candidate["name"]
+             date_col_name = best_candidate["date_col"]
+             amount_col_name = best_candidate["amount_col"]
+             print(f"[SPEC DEBUG] Heuristic detected: fact={fact_table_name}, date={date_col_name}, amount={amount_col_name}")
+
+     # 2. Convert tables. The generator builds everything from scratch, so every table
+     # needs a spec, including reference tables; those carry their LLM-generated rows
+     # via inline_data, which is passed through below so the generator uses them
+     # verbatim instead of synthesizing values.
+     for table in config.tables:
+         cols = config.columns.get(table.name, [])
+         col_specs = []
+
+         for c in cols:
+             # Map distribution params
+             params = c.distribution_params or {}
+
+             spec = ColumnSpec(
+                 name=c.name,
+                 type=c.type,
+                 min_val=params.get('min'),
+                 max_val=params.get('max'),
+                 mean=params.get('mean'),
+                 std=params.get('std'),
+                 choices=params.get('choices'),
+                 probabilities=params.get('probabilities'),
+                 text_type=params.get('text_type', 'word'),
+                 references=None  # Resolved below for foreign keys
+             )
+
+             # Resolve FK references from the explicit relationships in SchemaConfig
+             if c.type == 'foreign_key':
+                 rel = next((r for r in config.relationships if r.child_table == table.name and r.child_key == c.name), None)
+                 if rel:
+                     spec.references = f"{rel.parent_table}.{rel.parent_key}"
+
+             col_specs.append(spec)
+
+         is_fact = (table.name == fact_table_name)
+
+         table_specs.append(TableSpec(
+             name=table.name,
+             row_count=table.row_count,
+             columns=col_specs,
+             is_fact=is_fact,
+             date_column=date_col_name if is_fact else None,
+             amount_column=amount_col_name if is_fact else None,
+             inline_data=table.inline_data if table.is_reference and table.inline_data else None
+         ))
+
+     # 3. Create constraints
+     constraints = []
+     if revenue_curve and fact_table_name:
+         constraints.append(OutcomeConstraint(
+             metric_name="revenue",
+             fact_table=fact_table_name,
+             value_column=amount_col_name,
+             date_column=date_col_name,
+             outcome_curve=revenue_curve
+         ))
+
+     return WarehouseSpec(tables=table_specs, constraints=constraints)
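
For orientation, a minimal usage sketch of the new constraint_generator module. The OutcomeCurve/CurvePoint constructor fields shown here (points, time_unit, timestamp, value) are assumptions inferred from how the generator reads them, not the documented API of misata/studio/outcome_curve.py:

    from datetime import datetime
    from misata.studio.constraint_generator import (
        create_service_company_schema,
        generate_constrained_warehouse,
    )
    from misata.studio.outcome_curve import CurvePoint, OutcomeCurve

    # Assumed constructor shapes: the generator only reads curve.points,
    # curve.time_unit, point.timestamp and point.value.
    curve = OutcomeCurve(
        points=[
            CurvePoint(timestamp=datetime(2024, m, 1), value=100_000 + 25_000 * (m - 1))
            for m in range(1, 7)
        ],
        time_unit="month",
    )

    spec = create_service_company_schema(revenue_curve=curve)
    tables = generate_constrained_warehouse(spec, seed=42)  # table name -> DataFrame
    print({name: len(df) for name, df in tables.items()})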