misata 0.3.0b0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. misata/__init__.py +1 -1
  2. misata/agents/__init__.py +23 -0
  3. misata/agents/pipeline.py +286 -0
  4. misata/causal/__init__.py +5 -0
  5. misata/causal/graph.py +109 -0
  6. misata/causal/solver.py +115 -0
  7. misata/cli.py +31 -0
  8. misata/generators/__init__.py +19 -0
  9. misata/generators/copula.py +198 -0
  10. misata/llm_parser.py +180 -137
  11. misata/quality.py +78 -33
  12. misata/reference_data.py +221 -0
  13. misata/research/__init__.py +3 -0
  14. misata/research/agent.py +70 -0
  15. misata/schema.py +25 -0
  16. misata/simulator.py +264 -12
  17. misata/smart_values.py +144 -6
  18. misata/studio/__init__.py +55 -0
  19. misata/studio/app.py +49 -0
  20. misata/studio/components/inspector.py +81 -0
  21. misata/studio/components/sidebar.py +35 -0
  22. misata/studio/constraint_generator.py +781 -0
  23. misata/studio/inference.py +319 -0
  24. misata/studio/outcome_curve.py +284 -0
  25. misata/studio/state/store.py +55 -0
  26. misata/studio/tabs/configure.py +50 -0
  27. misata/studio/tabs/generate.py +117 -0
  28. misata/studio/tabs/outcome_curve.py +149 -0
  29. misata/studio/tabs/schema_designer.py +217 -0
  30. misata/studio/utils/styles.py +143 -0
  31. misata/studio_constraints/__init__.py +29 -0
  32. misata/studio_constraints/z3_solver.py +259 -0
  33. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/METADATA +13 -2
  34. misata-0.5.0.dist-info/RECORD +61 -0
  35. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/WHEEL +1 -1
  36. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/entry_points.txt +1 -0
  37. misata-0.3.0b0.dist-info/RECORD +0 -37
  38. /misata/{generators.py → generators_legacy.py} +0 -0
  39. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/licenses/LICENSE +0 -0
  40. {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/top_level.txt +0 -0
misata/studio/inference.py (new file)
@@ -0,0 +1,319 @@
+ """
+ Schema Inference Module - Reverse-engineer schemas from sample data.
+
+ This module analyzes uploaded CSV/JSON data and infers:
+ - Column types (int, float, categorical, date, text, email, uuid, etc.)
+ - Distribution parameters (min, max, mean, std, choices, etc.)
+ - Correlations between columns
+ """
+
+ import re
+ from typing import Any, Dict, List, Optional, Tuple
+ from datetime import datetime
+
+ import numpy as np
+ import pandas as pd
+
+ from misata.schema import Column, SchemaConfig, Table
+
+
+ # ============ Type Detection Patterns ============
+
+ EMAIL_PATTERN = re.compile(r'^[\w\.-]+@[\w\.-]+\.\w+$')
+ UUID_PATTERN = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.I)
+ PHONE_PATTERN = re.compile(r'^[\d\s\-\+\(\)]{7,20}$')
+ URL_PATTERN = re.compile(r'^https?://')
+
+
+ def detect_column_type(series: pd.Series) -> Tuple[str, Dict[str, Any]]:
+     """Detect the type and distribution parameters of a column.
+
+     Args:
+         series: Pandas Series to analyze
+
+     Returns:
+         Tuple of (type_name, distribution_params)
+     """
+     # Drop nulls for analysis
+     clean = series.dropna()
+     if len(clean) == 0:
+         return "text", {"text_type": "sentence"}
+
+     # Check for boolean
+     unique_vals = set(clean.unique())
+     if unique_vals <= {True, False, 0, 1, "true", "false", "True", "False", "yes", "no", "Yes", "No"}:
+         # Calculate probability of True
+         bool_vals = clean.map(lambda x: str(x).lower() in ('true', '1', 'yes'))
+         prob = bool_vals.mean()
+         return "boolean", {"probability": round(prob, 2)}
+
+     # Check for UUID
+     if clean.dtype == object:
+         sample = str(clean.iloc[0])
+         if UUID_PATTERN.match(sample):
+             return "text", {"text_type": "uuid"}
+
+         # Check for email
+         if EMAIL_PATTERN.match(sample):
+             return "text", {"text_type": "email"}
+
+         # Check for URL
+         if URL_PATTERN.match(sample):
+             return "text", {"text_type": "url"}
+
+         # Check for phone
+         if PHONE_PATTERN.match(sample):
+             return "text", {"text_type": "phone"}
+
+     # Check for date
+     if clean.dtype == 'datetime64[ns]' or pd.api.types.is_datetime64_any_dtype(clean):
+         return "date", {
+             "start": str(clean.min().date()),
+             "end": str(clean.max().date())
+         }
+
+     # Try parsing as date
+     if clean.dtype == object:
+         try:
+             parsed = pd.to_datetime(clean, errors='coerce')
+             if parsed.notna().mean() > 0.9:  # 90%+ parse as dates
+                 return "date", {
+                     "start": str(parsed.min().date()),
+                     "end": str(parsed.max().date())
+                 }
+         except:
+             pass
+
+     # Check for categorical (limited unique values)
+     n_unique = clean.nunique()
+     if n_unique <= min(20, len(clean) * 0.2):  # <=20 or <=20% unique
+         value_counts = clean.value_counts(normalize=True)
+         choices = value_counts.index.tolist()
+         probabilities = [round(p, 3) for p in value_counts.values.tolist()]
+         return "categorical", {
+             "choices": choices,
+             "probabilities": probabilities
+         }
+
+     # Check for numeric
+     if pd.api.types.is_integer_dtype(clean):
+         return "int", {
+             "min": int(clean.min()),
+             "max": int(clean.max()),
+             "distribution": "uniform"
+         }
+
+     if pd.api.types.is_float_dtype(clean):
+         # Check if it looks like currency (2 decimal places)
+         decimals = clean.apply(lambda x: len(str(x).split('.')[-1]) if '.' in str(x) else 0)
+         if decimals.mode().iloc[0] == 2:
+             return "float", {
+                 "min": round(float(clean.min()), 2),
+                 "max": round(float(clean.max()), 2),
+                 "distribution": "lognormal",
+                 "decimals": 2
+             }
+         return "float", {
+             "min": float(clean.min()),
+             "max": float(clean.max()),
+             "distribution": "normal",
+             "mean": float(clean.mean()),
+             "std": float(clean.std())
+         }
+
+     # Try converting to numeric
+     try:
+         numeric = pd.to_numeric(clean, errors='coerce')
+         if numeric.notna().mean() > 0.9:  # 90%+ are numeric
+             if numeric.apply(float.is_integer).all():
+                 return "int", {
+                     "min": int(numeric.min()),
+                     "max": int(numeric.max())
+                 }
+             return "float", {
+                 "min": float(numeric.min()),
+                 "max": float(numeric.max()),
+                 "mean": float(numeric.mean()),
+                 "std": float(numeric.std())
+             }
+     except:
+         pass
+
+     # Default to text
+     # Try to detect text type from column name
+     col_name = series.name.lower() if series.name else ""
+
+     if "name" in col_name:
+         return "text", {"text_type": "name"}
+     elif "email" in col_name:
+         return "text", {"text_type": "email"}
+     elif "address" in col_name:
+         return "text", {"text_type": "address"}
+     elif "company" in col_name or "org" in col_name:
+         return "text", {"text_type": "company"}
+     elif "phone" in col_name:
+         return "text", {"text_type": "phone"}
+     elif "url" in col_name or "website" in col_name:
+         return "text", {"text_type": "url"}
+
+     return "text", {"text_type": "sentence"}
+
+
+ def fit_distribution(series: pd.Series) -> Dict[str, Any]:
+     """Fit a statistical distribution to numeric data.
+
+     Args:
+         series: Numeric pandas Series
+
+     Returns:
+         Distribution parameters including type and fitted params
+     """
+     clean = pd.to_numeric(series.dropna(), errors='coerce').dropna()
+     if len(clean) < 5:
+         return {"distribution": "uniform", "min": 0, "max": 100}
+
+     mean = float(clean.mean())
+     std = float(clean.std())
+     min_val = float(clean.min())
+     max_val = float(clean.max())
+     skew = float(clean.skew())
+
+     # Determine best distribution based on characteristics
+     if abs(skew) < 0.5:
+         # Roughly symmetric → Normal
+         return {
+             "distribution": "normal",
+             "mean": mean,
+             "std": std,
+             "min": min_val,
+             "max": max_val
+         }
+     elif skew > 1.0 and min_val >= 0:
+         # Right-skewed, positive → Lognormal
+         return {
+             "distribution": "lognormal",
+             "mean": np.log(mean) if mean > 0 else 0,
+             "sigma": std / mean if mean > 0 else 1,
+             "min": min_val,
+             "max": max_val
+         }
+     else:
+         # Use empirical (histogram-based)
+         hist, bins = np.histogram(clean, bins=20, density=True)
+         control_points = []
+         for i in range(len(hist)):
+             x = (bins[i] + bins[i+1]) / 2
+             y = float(hist[i])
+             control_points.append({"x": x, "y": y})
+
+         return {
+             "distribution": "custom",
+             "control_points": control_points,
+             "min": min_val,
+             "max": max_val
+         }
+
+
+ def infer_schema(
+     data: pd.DataFrame,
+     table_name: str = "data",
+     row_count: Optional[int] = None
+ ) -> SchemaConfig:
+     """Infer a complete schema from sample data.
+
+     Args:
+         data: Sample DataFrame to analyze
+         table_name: Name for the inferred table
+         row_count: Target row count (default: 100x input)
+
+     Returns:
+         SchemaConfig ready for generation
+     """
+     if row_count is None:
+         row_count = max(len(data) * 100, 1000)
+
+     columns = []
+     for col_name in data.columns:
+         col_type, params = detect_column_type(data[col_name])
+
+         # Check for unique constraint
+         is_unique = data[col_name].nunique() == len(data)
+
+         column = Column(
+             name=str(col_name),
+             table_name=table_name,
+             type=col_type,
+             distribution_params=params,
+             nullable=data[col_name].isna().any(),
+             unique=is_unique
+         )
+         columns.append(column)
+
+     return SchemaConfig(
+         name=f"Inferred: {table_name}",
+         tables=[Table(
+             name=table_name,
+             row_count=row_count,
+             columns=[c.name for c in columns]
+         )],
+         columns={table_name: columns},
+         relationships=[]
+     )
+
+
+ def detect_correlations(data: pd.DataFrame) -> List[Dict[str, Any]]:
+     """Detect correlations between numeric columns.
+
+     Args:
+         data: DataFrame to analyze
+
+     Returns:
+         List of correlation dicts with column pairs and strength
+     """
+     numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
+     if len(numeric_cols) < 2:
+         return []
+
+     correlations = []
+     corr_matrix = data[numeric_cols].corr()
+
+     for i, col1 in enumerate(numeric_cols):
+         for col2 in numeric_cols[i+1:]:
+             corr = corr_matrix.loc[col1, col2]
+             if abs(corr) > 0.5:  # Only report strong correlations
+                 correlations.append({
+                     "column1": col1,
+                     "column2": col2,
+                     "correlation": round(corr, 3),
+                     "strength": "strong" if abs(corr) > 0.7 else "moderate"
+                 })
+
+     return correlations
+
+
+ def schema_to_dict(schema: SchemaConfig) -> Dict[str, Any]:
+     """Convert schema to a JSON-serializable dict for the UI."""
+     return {
+         "name": schema.name,
+         "tables": [
+             {
+                 "name": t.name,
+                 "row_count": t.row_count,
+                 "columns": t.columns
+             }
+             for t in schema.tables
+         ],
+         "columns": {
+             table_name: [
+                 {
+                     "name": c.name,
+                     "type": c.type,
+                     "params": c.distribution_params,
+                     "nullable": c.nullable,
+                     "unique": c.unique
+                 }
+                 for c in cols
+             ]
+             for table_name, cols in schema.columns.items()
+         }
+     }
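The inference module above can be driven end to end with a few calls. A minimal usage sketch (not part of the wheel; the sample CSV name and target row count are hypothetical, and the import path assumes the module location listed above):

import pandas as pd
from misata.studio.inference import detect_correlations, infer_schema, schema_to_dict

# Load a small sample of real data (hypothetical file name)
sample = pd.read_csv("orders_sample.csv")

# Reverse-engineer a SchemaConfig: per-column types, distribution params, nullability, uniqueness
schema = infer_schema(sample, table_name="orders", row_count=50_000)

# Strongly correlated numeric column pairs (|r| > 0.5), labelled strong/moderate
pairs = detect_correlations(sample)

# JSON-serializable view of the inferred schema, as the Studio UI consumes it
print(schema_to_dict(schema))
print(pairs)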
misata/studio/outcome_curve.py (new file)
@@ -0,0 +1,284 @@
+ """
+ Outcome Curve Designer - Reverse Time-Series Generation
+
+ The killer feature: Users draw the aggregated outcome they want,
+ and Misata generates individual transactions that produce that exact curve.
+
+ Example:
+     User draws: Revenue = [$100K, $150K, $200K, $180K, ...] over 12 months
+     Misata generates: 50,000 individual orders with dates/amounts
+     When aggregated: SUM(amount) GROUP BY month = exactly the drawn curve
+
+ Algorithm:
+     1. Parse curve control points into time buckets
+     2. For each bucket, calculate target aggregate
+     3. Distribute transactions across bucket:
+        - Determine transaction count (based on avg ticket or specified)
+        - Generate individual amounts that sum to target
+     4. Add variance/noise for realism
+ """
+
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta
+ from typing import Any, Dict, List, Optional, Tuple
+ import numpy as np
+ import pandas as pd
+
+
+ @dataclass
+ class CurvePoint:
+     """A single point on the outcome curve."""
+     timestamp: datetime
+     value: float
+
+
+ @dataclass
+ class OutcomeCurve:
+     """Represents the target outcome curve drawn by user."""
+     metric_name: str  # e.g., "revenue", "signups", "orders"
+     time_unit: str  # "day", "week", "month"
+     points: List[CurvePoint]
+
+     # Optional constraints
+     avg_transaction_value: Optional[float] = None  # For revenue curves
+     min_transactions_per_period: int = 10
+     max_transactions_per_period: int = 10000
+
+
+ def interpolate_curve(curve: OutcomeCurve, num_buckets: int) -> List[float]:
+     """Interpolate curve to get values for each time bucket."""
+     if len(curve.points) < 2:
+         return [curve.points[0].value] * num_buckets
+
+     # Extract x (time indices) and y (values)
+     x = np.array([i for i in range(len(curve.points))])
+     y = np.array([p.value for p in curve.points])
+
+     # Interpolate to num_buckets
+     x_new = np.linspace(0, len(curve.points) - 1, num_buckets)
+
+     from scipy.interpolate import interp1d
+     f = interp1d(x, y, kind='cubic', fill_value='extrapolate')
+     return list(np.maximum(f(x_new), 0))  # Ensure non-negative
+
+
+ def generate_transactions_for_bucket(
+     target_value: float,
+     bucket_start: datetime,
+     bucket_end: datetime,
+     avg_transaction: Optional[float] = None,
+     min_transactions: int = 10,
+     max_transactions: int = 1000,
+     rng: Optional[np.random.Generator] = None
+ ) -> pd.DataFrame:
+     """Generate individual transactions that sum to target_value for a time bucket.
+
+     Returns DataFrame with columns: [timestamp, amount]
+     """
+     if rng is None:
+         rng = np.random.default_rng()
+
+     if target_value <= 0:
+         return pd.DataFrame(columns=['timestamp', 'amount'])
+
+     # Determine number of transactions
+     if avg_transaction:
+         n_transactions = int(target_value / avg_transaction)
+         n_transactions = max(min_transactions, min(n_transactions, max_transactions))
+     else:
+         # Estimate based on target value
+         n_transactions = max(min_transactions, min(int(target_value / 50), max_transactions))
+
+     # Generate amounts that sum to target using Dirichlet distribution
+     # This ensures realistic variation while hitting exact target
+     proportions = rng.dirichlet(np.ones(n_transactions) * 2)  # alpha=2 for moderate variance
+     amounts = proportions * target_value
+
+     # Add some variance to make it more realistic
+     # Small noise that doesn't change the sum significantly
+     noise = rng.normal(0, abs(target_value) * 0.001, n_transactions)
+     amounts = amounts + noise
+
+     # Adjust to hit exact target (compensate for noise)
+     amounts = amounts * (target_value / amounts.sum())
+
+     # Ensure all positive
+     amounts = np.maximum(amounts, 0.01)
+     amounts = amounts * (target_value / amounts.sum())  # Re-normalize
+
+     # Generate timestamps uniformly distributed within bucket
+     bucket_duration = (bucket_end - bucket_start).total_seconds()
+     random_seconds = rng.uniform(0, bucket_duration, n_transactions)
+     timestamps = [bucket_start + timedelta(seconds=s) for s in random_seconds]
+
+     # Sort by timestamp
+     df = pd.DataFrame({
+         'timestamp': timestamps,
+         'amount': amounts.round(2)
+     }).sort_values('timestamp').reset_index(drop=True)
+
+     return df
+
+
+ def generate_from_outcome_curve(
+     curve: OutcomeCurve,
+     start_date: Optional[datetime] = None,
+     seed: int = 42
+ ) -> pd.DataFrame:
+     """Generate a full transaction dataset from an outcome curve.
+
+     Args:
+         curve: The target outcome curve
+         start_date: Start date (defaults to today minus curve duration)
+         seed: Random seed for reproducibility
+
+     Returns:
+         DataFrame with columns: [id, timestamp, amount] where
+         SUM(amount) GROUP BY period = the drawn curve
+     """
+     rng = np.random.default_rng(seed)
+
+     n_periods = len(curve.points)
+
+     # Determine bucket duration
+     if curve.time_unit == "day":
+         bucket_delta = timedelta(days=1)
+     elif curve.time_unit == "week":
+         bucket_delta = timedelta(weeks=1)
+     elif curve.time_unit == "month":
+         bucket_delta = timedelta(days=30)  # Approximate
+     else:
+         bucket_delta = timedelta(days=1)
+
+     # Set start date
+     if start_date is None:
+         start_date = datetime.now() - (bucket_delta * n_periods)
+
+     all_transactions = []
+
+     for i, point in enumerate(curve.points):
+         bucket_start = start_date + (bucket_delta * i)
+         bucket_end = bucket_start + bucket_delta
+
+         transactions = generate_transactions_for_bucket(
+             target_value=point.value,
+             bucket_start=bucket_start,
+             bucket_end=bucket_end,
+             avg_transaction=curve.avg_transaction_value,
+             min_transactions=curve.min_transactions_per_period,
+             max_transactions=curve.max_transactions_per_period,
+             rng=rng
+         )
+
+         all_transactions.append(transactions)
+
+     # Combine all transactions
+     df = pd.concat(all_transactions, ignore_index=True)
+     df.insert(0, 'id', range(1, len(df) + 1))
+
+     return df
+
+
+ def verify_curve_match(
+     transactions: pd.DataFrame,
+     curve: OutcomeCurve,
+     start_date: datetime
+ ) -> Dict[str, Any]:
+     """Verify that generated transactions aggregate to match the target curve.
+
+     Returns:
+         Dict with 'match_score', 'expected', 'actual', 'error_pct'
+     """
+     n_periods = len(curve.points)
+     expected = np.array([p.value for p in curve.points])
+
+     # Determine bucket duration
+     if curve.time_unit == "day":
+         bucket_delta = timedelta(days=1)
+     elif curve.time_unit == "week":
+         bucket_delta = timedelta(weeks=1)
+     else:  # month
+         bucket_delta = timedelta(days=30)
+
+     # Assign each transaction to a bucket index based on time offset from start
+     def get_bucket_index(ts):
+         offset = (ts - start_date).total_seconds()
+         bucket_seconds = bucket_delta.total_seconds()
+         return min(int(offset / bucket_seconds), n_periods - 1)
+
+     transactions = transactions.copy()
+     transactions['bucket_idx'] = transactions['timestamp'].apply(get_bucket_index)
+
+     # Aggregate by bucket index
+     actual_by_bucket = transactions.groupby('bucket_idx')['amount'].sum()
+
+     # Build actual array matching expected length
+     actual = np.zeros(n_periods)
+     for idx, val in actual_by_bucket.items():
+         if 0 <= idx < n_periods:
+             actual[idx] = val
+
+     # Calculate match score
+     error_pct = np.abs(actual - expected) / np.maximum(expected, 1) * 100
+     avg_error = error_pct.mean()
+     match_score = max(0, 100 - avg_error)
+
+     return {
+         'match_score': round(match_score, 2),
+         'expected': expected.tolist(),
+         'actual': actual.tolist(),
+         'error_pct': error_pct.tolist(),
+         'avg_error_pct': round(avg_error, 2)
+     }
+
+
+ # ============ Preset Curve Shapes ============
+
+ def get_curve_presets() -> Dict[str, List[float]]:
+     """Get preset curve shapes for common business patterns."""
+     return {
+         "Linear Growth": [100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320],
+         "Exponential Growth": [100, 115, 132, 152, 175, 201, 231, 266, 306, 352, 405, 466],
+         "Hockey Stick": [100, 102, 105, 108, 112, 118, 140, 180, 250, 350, 500, 700],
+         "Seasonal (Retail)": [100, 80, 70, 90, 100, 120, 110, 100, 130, 160, 200, 300],
+         "SaaS Growth": [10, 18, 30, 50, 80, 120, 170, 230, 300, 380, 470, 570],
+         "Churn Decline": [1000, 920, 850, 790, 740, 700, 665, 635, 610, 590, 575, 560],
+         "V-shaped Recovery": [100, 80, 60, 50, 45, 50, 65, 85, 110, 140, 170, 200],
+         "Plateau": [100, 150, 200, 240, 270, 290, 300, 305, 308, 310, 311, 312],
+     }
+
+
+ def create_curve_from_preset(
+     preset_name: str,
+     metric_name: str = "revenue",
+     time_unit: str = "month",
+     start_date: datetime = None,
+     scale: float = 1000  # Multiply preset values by this
+ ) -> OutcomeCurve:
+     """Create an OutcomeCurve from a preset shape."""
+     presets = get_curve_presets()
+     values = presets.get(preset_name, presets["Linear Growth"])
+
+     if start_date is None:
+         start_date = datetime.now() - timedelta(days=30 * len(values))
+
+     if time_unit == "day":
+         delta = timedelta(days=1)
+     elif time_unit == "week":
+         delta = timedelta(weeks=1)
+     else:
+         delta = timedelta(days=30)
+
+     points = [
+         CurvePoint(
+             timestamp=start_date + delta * i,
+             value=v * scale
+         )
+         for i, v in enumerate(values)
+     ]
+
+     return OutcomeCurve(
+         metric_name=metric_name,
+         time_unit=time_unit,
+         points=points
+     )
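Putting the pieces above together, a minimal sketch of the outcome-curve workflow (illustrative only; the preset name, dates, and scale are arbitrary, and the import path assumes the module location listed above):

from datetime import datetime
from misata.studio.outcome_curve import (
    create_curve_from_preset,
    generate_from_outcome_curve,
    verify_curve_match,
)

start = datetime(2024, 1, 1)

# 12 monthly revenue targets following the "Hockey Stick" preset, scaled by 1000
curve = create_curve_from_preset("Hockey Stick", metric_name="revenue",
                                 time_unit="month", start_date=start, scale=1000)

# Individual transactions whose per-month sums reproduce the drawn curve
orders = generate_from_outcome_curve(curve, start_date=start, seed=42)

# Confirm the aggregation matches the target curve
report = verify_curve_match(orders, curve, start_date=start)
print(report["match_score"], report["avg_error_pct"])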
misata/studio/state/store.py (new file)
@@ -0,0 +1,55 @@
+ import streamlit as st
+ from datetime import datetime
+
+ class StudioStore:
+     """Centralized state management for Misata Studio."""
+
+     @staticmethod
+     def init():
+         """Initialize all session state variables with smart defaults."""
+         defaults = {
+             # Navigation
+             "active_tab": "Schema",
+             "sidebar_expanded": True,
+
+             # Data & Schema
+             "schema_config": None,
+             "schema_source": "Template",  # "Template" or "AI"
+             "warehouse_schema": {
+                 "type": "service_company",
+                 "customer_count": 500,
+                 "project_count": 2000
+             },
+
+             # Constraint Configuration
+             "selected_constraint": None,  # e.g. "invoices.amount"
+             "warehouse_curve": [100000] * 12,  # Default annual curve
+             "start_date_input": datetime.now().date(),
+
+             # Generation Config
+             "warehouse_config": {
+                 "avg_transaction": 50.0,
+                 "seed": 42,
+                 "tier_distribution": [0.5, 0.3, 0.2]
+             },
+
+             # Results
+             "generated_warehouse": None,
+             "warehouse_generated": False
+         }
+
+         for key, default_val in defaults.items():
+             if key not in st.session_state:
+                 st.session_state[key] = default_val
+
+     @staticmethod
+     def get(key, default=None):
+         return st.session_state.get(key, default)
+
+     @staticmethod
+     def set(key, value):
+         st.session_state[key] = value
+
+     @property
+     def schema(self):
+         return st.session_state.get("schema_config")
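A short sketch of how a Studio tab might use StudioStore inside a Streamlit script (illustrative only; assumes the class is importable from misata.studio.state.store per the file layout above):

import streamlit as st
from misata.studio.state.store import StudioStore

StudioStore.init()  # seeds st.session_state, but only for keys that are not already set

seed = StudioStore.get("warehouse_config", {}).get("seed", 42)
st.write(f"Current seed: {seed}")

if st.button("Go to Generate tab"):
    StudioStore.set("active_tab", "Generate")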