misata 0.3.0b0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- misata/__init__.py +1 -1
- misata/agents/__init__.py +23 -0
- misata/agents/pipeline.py +286 -0
- misata/causal/__init__.py +5 -0
- misata/causal/graph.py +109 -0
- misata/causal/solver.py +115 -0
- misata/cli.py +31 -0
- misata/generators/__init__.py +19 -0
- misata/generators/copula.py +198 -0
- misata/llm_parser.py +180 -137
- misata/quality.py +78 -33
- misata/reference_data.py +221 -0
- misata/research/__init__.py +3 -0
- misata/research/agent.py +70 -0
- misata/schema.py +25 -0
- misata/simulator.py +264 -12
- misata/smart_values.py +144 -6
- misata/studio/__init__.py +55 -0
- misata/studio/app.py +49 -0
- misata/studio/components/inspector.py +81 -0
- misata/studio/components/sidebar.py +35 -0
- misata/studio/constraint_generator.py +781 -0
- misata/studio/inference.py +319 -0
- misata/studio/outcome_curve.py +284 -0
- misata/studio/state/store.py +55 -0
- misata/studio/tabs/configure.py +50 -0
- misata/studio/tabs/generate.py +117 -0
- misata/studio/tabs/outcome_curve.py +149 -0
- misata/studio/tabs/schema_designer.py +217 -0
- misata/studio/utils/styles.py +143 -0
- misata/studio_constraints/__init__.py +29 -0
- misata/studio_constraints/z3_solver.py +259 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/METADATA +13 -2
- misata-0.5.0.dist-info/RECORD +61 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/WHEEL +1 -1
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/entry_points.txt +1 -0
- misata-0.3.0b0.dist-info/RECORD +0 -37
- /misata/{generators.py → generators_legacy.py} +0 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schema Inference Module - Reverse-engineer schemas from sample data.
|
|
3
|
+
|
|
4
|
+
This module analyzes uploaded CSV/JSON data and infers:
|
|
5
|
+
- Column types (int, float, categorical, date, text, email, uuid, etc.)
|
|
6
|
+
- Distribution parameters (min, max, mean, std, choices, etc.)
|
|
7
|
+
- Correlations between columns
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import re
|
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
|
|
17
|
+
from misata.schema import Column, SchemaConfig, Table
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ============ Type Detection Patterns ============

EMAIL_PATTERN = re.compile(r'^[\w\.-]+@[\w\.-]+\.\w+$')
UUID_PATTERN = re.compile(
    r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.I
)
PHONE_PATTERN = re.compile(r'^[\d\s\-\+\(\)]{7,20}$')
# URL schemes are case-insensitive (RFC 3986), so "HTTP://..." counts as well.
URL_PATTERN = re.compile(r'^https?://', re.I)

# Share of sampled values that must match a regex before a column gets that label.
_MATCH_THRESHOLD = 0.9
# Number of leading non-null values inspected by the regex-based checks.
_PATTERN_SAMPLE_SIZE = 100
# Values that mark a column as boolean. Note that True/1/1.0 and False/0/0.0
# are equal and hash alike, so they collapse into the same set members.
_BOOLEAN_TOKENS = {True, False, 0, 1, "true", "false", "True", "False", "yes", "no", "Yes", "No"}
# Column-name keywords used to refine plain text columns, checked in order.
_NAME_HINTS = (
    (("name",), "name"),
    (("email",), "email"),
    (("address",), "address"),
    (("company", "org"), "company"),
    (("phone",), "phone"),
    (("url", "website"), "url"),
)


def _mostly_matches(pattern: re.Pattern, values: pd.Series) -> bool:
    """Return True if at least ``_MATCH_THRESHOLD`` of the string *values* match *pattern*."""
    if len(values) == 0:
        return False
    hits = values.map(lambda v: pattern.match(v) is not None)
    return float(hits.mean()) >= _MATCH_THRESHOLD


def _is_truthy(value: Any) -> bool:
    """Interpret one boolean-like token ("true"/"yes"/"1", True, 1, 1.0) as a bool."""
    if isinstance(value, str):
        return value.lower() in ('true', '1', 'yes')
    # Numeric tokens: compare by value so that 1.0 (a 0/1 column that became
    # float because of nulls) counts as True. str(1.0) would be "1.0".
    return bool(value == 1)


def _sample_std(values: pd.Series) -> float:
    """Sample standard deviation, or 0.0 when fewer than two values exist (pandas gives NaN)."""
    return float(values.std()) if len(values) > 1 else 0.0


def detect_column_type(series: pd.Series) -> Tuple[str, Dict[str, Any]]:
    """Detect the type and distribution parameters of a column.

    Checks run in this order, and the first match wins:

    1. boolean
    2. UUID / email / URL text
    3. date
    4. phone text
    5. categorical
    6. integer dtype
    7. float dtype
    8. numeric strings
    9. free text, refined by keywords in the column name

    Args:
        series: Pandas Series to analyze. Null values are ignored.

    Returns:
        Tuple of (type_name, distribution_params). type_name is one of
        "boolean", "text", "date", "categorical", "int" or "float".
    """
    clean = series.dropna()
    if len(clean) == 0:
        return "text", {"text_type": "sentence"}

    # Boolean: only true/false-like tokens are present.
    unique_vals = set(clean.unique())
    if unique_vals <= _BOOLEAN_TOKENS:
        prob = float(clean.map(_is_truthy).mean())
        return "boolean", {"probability": round(prob, 2)}

    is_object = clean.dtype == object
    head = clean.head(_PATTERN_SAMPLE_SIZE).astype(str) if is_object else None

    # Unambiguous string formats.
    if is_object:
        for pattern, text_type in (
            (UUID_PATTERN, "uuid"),
            (EMAIL_PATTERN, "email"),
            (URL_PATTERN, "url"),
        ):
            if _mostly_matches(pattern, head):
                return "text", {"text_type": text_type}

    # Dates. These must be checked before phone numbers: ISO dates such as
    # "2023-01-01" contain only digits and dashes and would match PHONE_PATTERN.
    if pd.api.types.is_datetime64_any_dtype(clean):
        return "date", {
            "start": str(clean.min().date()),
            "end": str(clean.max().date())
        }
    if is_object:
        try:
            parsed = pd.to_datetime(clean, errors='coerce')
        except (TypeError, ValueError, OverflowError):
            # e.g. mixed time zones; treat the column as not date-like.
            parsed = None
        if parsed is not None and parsed.notna().mean() > 0.9:  # 90%+ parse as dates
            return "date", {
                "start": str(parsed.min().date()),
                "end": str(parsed.max().date())
            }

    if is_object and _mostly_matches(PHONE_PATTERN, head):
        return "text", {"text_type": "phone"}

    # Categorical: few distinct values (<= 20 and <= 20% of rows).
    n_unique = clean.nunique()
    if n_unique <= min(20, len(clean) * 0.2):
        value_counts = clean.value_counts(normalize=True)
        choices = value_counts.index.tolist()
        probabilities = [round(float(p), 3) for p in value_counts.to_numpy()]
        # Rounding to 3 decimals can leave the total off 1.0 (e.g. 3 x 0.333).
        # Samplers that require sum == 1 reject that, so the residue is folded
        # into the most frequent choice (value_counts sorts descending).
        probabilities[0] = round(probabilities[0] + (1.0 - sum(probabilities)), 3)
        return "categorical", {
            "choices": choices,
            "probabilities": probabilities
        }

    if pd.api.types.is_integer_dtype(clean):
        return "int", {
            "min": int(clean.min()),
            "max": int(clean.max()),
            "distribution": "uniform"
        }

    if pd.api.types.is_float_dtype(clean):
        # Mostly two decimal places looks like currency.
        decimals = clean.apply(lambda x: len(str(x).split('.')[-1]) if '.' in str(x) else 0)
        if decimals.mode().iloc[0] == 2:
            return "float", {
                "min": round(float(clean.min()), 2),
                "max": round(float(clean.max()), 2),
                "distribution": "lognormal",
                "decimals": 2
            }
        return "float", {
            "min": float(clean.min()),
            "max": float(clean.max()),
            "distribution": "normal",
            "mean": float(clean.mean()),
            "std": _sample_std(clean)
        }

    # Numbers stored as strings (or mixed objects).
    try:
        numeric = pd.to_numeric(clean, errors='coerce')
        if numeric.notna().mean() > 0.9:  # 90%+ are numeric
            numeric = numeric.dropna()
            # Vectorised integrality test. The former float.is_integer call
            # raised TypeError on int64 values, so integer strings became text.
            if np.isfinite(numeric).all() and (numeric == np.floor(numeric)).all():
                return "int", {
                    "min": int(numeric.min()),
                    "max": int(numeric.max()),
                    "distribution": "uniform"
                }
            return "float", {
                "min": float(numeric.min()),
                "max": float(numeric.max()),
                "distribution": "normal",
                "mean": float(numeric.mean()),
                "std": _sample_std(numeric)
            }
    except (TypeError, ValueError, OverflowError):
        # Best effort: values that cannot be handled numerically fall through to text.
        pass

    # Default to text, refined by the column name. str() because labels may
    # be non-strings, e.g. integers from read_csv(header=None).
    col_name = str(series.name).lower() if series.name is not None else ""
    for keywords, text_type in _NAME_HINTS:
        if any(keyword in col_name for keyword in keywords):
            return "text", {"text_type": text_type}

    return "text", {"text_type": "sentence"}
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def fit_distribution(series: pd.Series) -> Dict[str, Any]:
    """Fit a simple parametric or empirical distribution to numeric data.

    Non-numeric entries are coerced to NaN and dropped before fitting.

    Args:
        series: Pandas Series, ideally numeric.

    Returns:
        A dict whose "distribution" key selects the shape:

        * "uniform": fewer than 5 usable values; fixed 0..100 fallback.
        * "normal": |skew| < 0.5; includes "mean" and "std".
        * "lognormal": skew > 1 and non-negative data. "mean" and "sigma"
          are the parameters of the underlying normal, i.e. log-space mu and sigma.
        * "custom": anything else; a 20-bin density histogram given as
          "control_points" (bin midpoint "x", density "y").

        Every non-fallback result also carries the observed "min" and "max".
    """
    clean = pd.to_numeric(series.dropna(), errors='coerce').dropna()
    if len(clean) < 5:
        return {"distribution": "uniform", "min": 0, "max": 100}

    mean = float(clean.mean())
    std = float(clean.std())
    min_val = float(clean.min())
    max_val = float(clean.max())
    skew = float(clean.skew())

    if abs(skew) < 0.5:
        # Roughly symmetric -> normal
        return {
            "distribution": "normal",
            "mean": mean,
            "std": std,
            "min": min_val,
            "max": max_val
        }

    if skew > 1.0 and min_val >= 0:
        # Right-skewed, non-negative -> lognormal, fitted by the method of moments.
        # For X ~ LogNormal(mu, sigma):
        #   E[X]   = exp(mu + sigma^2 / 2)
        #   Var[X] = (exp(sigma^2) - 1) * E[X]^2
        # Solving for the parameters:
        #   sigma^2 = ln(1 + (std / mean)^2)
        #   mu      = ln(mean) - sigma^2 / 2
        # Using mu = ln(mean) with sigma = std / mean overshoots the mean
        # by a factor of exp(sigma^2 / 2).
        if mean > 0:
            sigma2 = float(np.log1p((std / mean) ** 2))
            mu = float(np.log(mean)) - sigma2 / 2
            sigma = float(np.sqrt(sigma2))
        else:
            mu, sigma = 0, 1
        return {
            "distribution": "lognormal",
            "mean": mu,
            "sigma": sigma,
            "min": min_val,
            "max": max_val
        }

    # Otherwise use an empirical (histogram-based) shape.
    density, edges = np.histogram(clean, bins=20, density=True)
    midpoints = (edges[:-1] + edges[1:]) / 2
    control_points = [{"x": float(x), "y": float(y)} for x, y in zip(midpoints, density)]

    return {
        "distribution": "custom",
        "control_points": control_points,
        "min": min_val,
        "max": max_val
    }
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def infer_schema(
    data: pd.DataFrame,
    table_name: str = "data",
    row_count: Optional[int] = None
) -> SchemaConfig:
    """Infer a complete single-table schema from sample data.

    Each column's type and parameters come from detect_column_type.
    "nullable" is set when the sample holds any null. "unique" is set when
    every row holds a distinct non-null value.

    Args:
        data: Sample DataFrame to analyze.
        table_name: Name for the inferred table.
        row_count: Target row count. Defaults to 100x the sample size, with a
            floor of 1000 rows.

    Returns:
        SchemaConfig with one table and no relationships, ready for generation.
    """
    if row_count is None:
        row_count = max(len(data) * 100, 1000)

    columns = []
    for col_name in data.columns:
        values = data[col_name]
        col_type, params = detect_column_type(values)

        columns.append(Column(
            name=str(col_name),
            table_name=table_name,
            type=col_type,
            distribution_params=params,
            # bool(): pandas returns numpy.bool_ here. json.dumps rejects it,
            # and strict bool validators may not accept it either.
            nullable=bool(values.isna().any()),
            unique=bool(values.nunique() == len(data)),
        ))

    return SchemaConfig(
        name=f"Inferred: {table_name}",
        tables=[Table(
            name=table_name,
            row_count=row_count,
            columns=[c.name for c in columns]
        )],
        columns={table_name: columns},
        relationships=[]
    )
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def detect_correlations(
    data: pd.DataFrame,
    threshold: float = 0.5,
    strong_threshold: float = 0.7,
) -> List[Dict[str, Any]]:
    """Detect notable Pearson correlations between numeric columns.

    Args:
        data: DataFrame to analyze. Non-numeric columns are ignored.
        threshold: Minimum absolute correlation for a pair to be reported.
        strong_threshold: Absolute correlation above which a pair is labelled
            "strong"; reported pairs at or below it are "moderate".

    Returns:
        One dict per reported pair, in column order, with keys "column1",
        "column2", "correlation" (a float rounded to 3 places) and
        "strength". Pairs with an undefined (NaN) correlation, such as
        those involving a constant column, are never reported.
    """
    numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
    if len(numeric_cols) < 2:
        return []

    corr_matrix = data[numeric_cols].corr()

    correlations = []
    for i, col1 in enumerate(numeric_cols):
        for col2 in numeric_cols[i + 1:]:
            # Plain float keeps the result free of numpy scalar types.
            corr = float(corr_matrix.loc[col1, col2])
            if abs(corr) > threshold:  # NaN compares False and is skipped
                correlations.append({
                    "column1": col1,
                    "column2": col2,
                    "correlation": round(corr, 3),
                    "strength": "strong" if abs(corr) > strong_threshold else "moderate"
                })

    return correlations
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def schema_to_dict(schema: SchemaConfig) -> Dict[str, Any]:
    """Flatten a schema into plain dicts and lists for the Studio UI.

    The result has three keys:

    * "name": the schema name.
    * "tables": one entry per table (name, row count, column names).
    * "columns": maps each table name to per-column summaries (name, type,
      params, nullable, unique).

    Values are taken from the schema as-is; nothing is copied or converted.
    """
    def _describe_column(col) -> Dict[str, Any]:
        return {
            "name": col.name,
            "type": col.type,
            "params": col.distribution_params,
            "nullable": col.nullable,
            "unique": col.unique,
        }

    table_summaries = []
    for table in schema.tables:
        table_summaries.append({
            "name": table.name,
            "row_count": table.row_count,
            "columns": table.columns,
        })

    column_summaries = {}
    for tbl_name, tbl_columns in schema.columns.items():
        column_summaries[tbl_name] = [_describe_column(col) for col in tbl_columns]

    return {
        "name": schema.name,
        "tables": table_summaries,
        "columns": column_summaries,
    }
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Outcome Curve Designer - Reverse Time-Series Generation
|
|
3
|
+
|
|
4
|
+
The killer feature: Users draw the aggregated outcome they want,
|
|
5
|
+
and Misata generates individual transactions that produce that exact curve.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
User draws: Revenue = [$100K, $150K, $200K, $180K, ...] over 12 months
|
|
9
|
+
Misata generates: 50,000 individual orders with dates/amounts
|
|
10
|
+
When aggregated: SUM(amount) GROUP BY month = exactly the drawn curve
|
|
11
|
+
|
|
12
|
+
Algorithm:
|
|
13
|
+
1. Parse curve control points into time buckets
|
|
14
|
+
2. For each bucket, calculate target aggregate
|
|
15
|
+
3. Distribute transactions across bucket:
|
|
16
|
+
- Determine transaction count (based on avg ticket or specified)
|
|
17
|
+
- Generate individual amounts that sum to target
|
|
18
|
+
4. Add variance/noise for realism
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from datetime import datetime, timedelta
|
|
23
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class CurvePoint:
    """One sample on a target outcome curve: a timestamp paired with its value."""

    timestamp: datetime  # position of this point on the time axis
    value: float  # target aggregate for the corresponding period


@dataclass
class OutcomeCurve:
    """The aggregated outcome a user wants the generated data to reproduce."""

    metric_name: str  # label such as "revenue", "signups" or "orders"
    time_unit: str  # period length: "day", "week" or "month"
    points: List[CurvePoint]  # one target value per period, in order

    # Optional knobs for the transaction generator.
    avg_transaction_value: Optional[float] = None  # typical ticket size (revenue curves)
    min_transactions_per_period: int = 10  # lower bound on rows per period
    max_transactions_per_period: int = 10000  # upper bound on rows per period
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def interpolate_curve(curve: OutcomeCurve, num_buckets: int) -> List[float]:
    """Resample the curve's point values onto num_buckets evenly spaced buckets.

    Points are treated as equally spaced (their timestamps are not used). The
    spline order adapts to the data:

    * 2 points: linear
    * 3 points: quadratic
    * 4 or more points: cubic

    A cubic spline needs at least four points, so a fixed cubic kind would
    make scipy raise for short curves. Results are clipped at 0.

    Args:
        curve: Curve whose points[i].value are the samples.
        num_buckets: Number of output values.

    Returns:
        num_buckets non-negative values. A single-point curve is repeated as-is.

    Raises:
        ValueError: If the curve has no points.
    """
    if not curve.points:
        raise ValueError("curve has no points to interpolate")
    if len(curve.points) < 2:
        return [curve.points[0].value] * num_buckets

    from scipy.interpolate import interp1d

    y = np.array([p.value for p in curve.points])
    x = np.arange(len(y))
    x_new = np.linspace(0, len(y) - 1, num_buckets)

    kind = {2: "linear", 3: "quadratic"}.get(len(y), "cubic")
    f = interp1d(x, y, kind=kind, fill_value='extrapolate')
    return [float(v) for v in np.maximum(f(x_new), 0)]  # Ensure non-negative
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def generate_transactions_for_bucket(
    target_value: float,
    bucket_start: datetime,
    bucket_end: datetime,
    avg_transaction: Optional[float] = None,
    min_transactions: int = 10,
    max_transactions: int = 1000,
    rng: Optional[np.random.Generator] = None
) -> pd.DataFrame:
    """Generate individual transactions that sum to target_value for a time bucket.

    The transaction count is target_value / avg_transaction, or
    target_value / 50 when no average is given. It is clamped to
    [min_transactions, max_transactions] and is always at least 1.

    Amounts are built in steps:

    1. A symmetric Dirichlet(alpha=2) split of the target.
    2. A small jitter, then a rescale back to the target.
    3. Clipping to at least 0.01, then a final rescale.
    4. Rounding to cents. The rounding residue goes to the largest
       transaction, so the amounts sum to target_value to the cent.

    Timestamps are drawn uniformly over the bucket.

    Args:
        target_value: Desired sum of amounts for the bucket.
        bucket_start: Start of the bucket.
        bucket_end: End of the bucket.
        avg_transaction: Typical transaction size used to pick the count.
        min_transactions: Lower bound on the transaction count.
        max_transactions: Upper bound on the transaction count.
        rng: Random generator. A fresh unseeded one is used when omitted.

    Returns:
        DataFrame with columns [timestamp, amount], sorted by timestamp.
        It is empty when target_value <= 0.
    """
    if rng is None:
        rng = np.random.default_rng()

    if target_value <= 0:
        return pd.DataFrame(columns=['timestamp', 'amount'])

    # Determine number of transactions (50 is the fallback ticket size).
    ticket = avg_transaction if avg_transaction else 50
    n_transactions = max(min_transactions, min(int(target_value / ticket), max_transactions))
    n_transactions = max(n_transactions, 1)  # Dirichlet needs at least one component

    # Split the target with a Dirichlet draw (alpha=2 for moderate variance).
    proportions = rng.dirichlet(np.ones(n_transactions) * 2)
    amounts = proportions * target_value

    # Jitter sized at 1% of the *average transaction*. Scaling it by the whole
    # bucket total swamps individual amounts once there are many transactions,
    # driving a large share negative (and then to the 0.01 floor). At 10
    # transactions both scalings coincide.
    noise = rng.normal(0, abs(target_value) / n_transactions * 0.01, n_transactions)
    amounts = amounts + noise
    amounts = amounts * (target_value / amounts.sum())

    # Ensure all positive, then rescale back to the target.
    amounts = np.maximum(amounts, 0.01)
    amounts = amounts * (target_value / amounts.sum())

    # Round to cents. Per-row rounding drifts the total, so the residue is
    # pushed onto the largest transaction to keep the bucket sum on target.
    amounts = np.round(amounts, 2)
    residue = round(target_value - float(amounts.sum()), 2)
    largest = int(np.argmax(amounts))
    amounts[largest] = round(float(amounts[largest]) + residue, 2)

    # Generate timestamps uniformly distributed within the bucket.
    bucket_duration = (bucket_end - bucket_start).total_seconds()
    random_seconds = rng.uniform(0, bucket_duration, n_transactions)
    timestamps = [bucket_start + timedelta(seconds=float(s)) for s in random_seconds]

    return pd.DataFrame({
        'timestamp': timestamps,
        'amount': amounts
    }).sort_values('timestamp').reset_index(drop=True)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def generate_from_outcome_curve(
    curve: OutcomeCurve,
    start_date: Optional[datetime] = None,
    seed: int = 42
) -> pd.DataFrame:
    """Generate a full transaction dataset from an outcome curve.

    Each curve point becomes one consecutive bucket starting at start_date.
    Bucket lengths are 1 day for "day", 1 week for "week" and 30 days for
    "month" (an approximation of calendar months). Any other unit falls
    back to 1 day.

    Args:
        curve: The target outcome curve.
        start_date: Start of the first bucket. Defaults to now minus the curve
            duration. Because that default depends on datetime.now(),
            timestamps differ between runs even with the same seed.
        seed: Random seed for the amounts and the within-bucket timestamps.

    Returns:
        DataFrame with columns [id, timestamp, amount], with ids starting at 1,
        where SUM(amount) per bucket matches the curve values. Points with
        value <= 0 contribute no rows. An empty curve yields an empty frame.
    """
    rng = np.random.default_rng(seed)

    n_periods = len(curve.points)

    bucket_delta = {
        "day": timedelta(days=1),
        "week": timedelta(weeks=1),
        "month": timedelta(days=30),  # Approximate
    }.get(curve.time_unit, timedelta(days=1))

    if start_date is None:
        start_date = datetime.now() - (bucket_delta * n_periods)

    frames = []
    for i, point in enumerate(curve.points):
        bucket_start = start_date + (bucket_delta * i)
        transactions = generate_transactions_for_bucket(
            target_value=point.value,
            bucket_start=bucket_start,
            bucket_end=bucket_start + bucket_delta,
            avg_transaction=curve.avg_transaction_value,
            min_transactions=curve.min_transactions_per_period,
            max_transactions=curve.max_transactions_per_period,
            rng=rng
        )
        # Skip empty buckets. pd.concat([]) raises, and concatenating empty
        # frames triggers pandas' dtype-deprecation warning.
        if not transactions.empty:
            frames.append(transactions)

    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=['timestamp', 'amount'])
    df.insert(0, 'id', range(1, len(df) + 1))

    return df
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def verify_curve_match(
    transactions: pd.DataFrame,
    curve: OutcomeCurve,
    start_date: datetime
) -> Dict[str, Any]:
    """Verify that generated transactions aggregate to match the target curve.

    Each transaction is assigned to a bucket from its offset to start_date.
    Bucket lengths follow the same time-unit mapping as
    generate_from_outcome_curve:

    * "day": 1 day
    * "week": 1 week
    * "month": 30 days
    * anything else: 1 day

    Transactions before start_date are ignored. Those at or past the last
    bucket's end are counted in the last bucket.

    Args:
        transactions: Frame with "timestamp" and "amount" columns.
        curve: The target curve.
        start_date: Start of the first bucket.

    Returns:
        Dict with these keys:

        * "match_score": 100 minus the mean error, floored at 0.
        * "expected": the curve's target values.
        * "actual": the summed amounts per bucket.
        * "error_pct": the per-bucket percentage error.
        * "avg_error_pct": the mean of "error_pct".
    """
    n_periods = len(curve.points)
    expected = np.array([p.value for p in curve.points])

    bucket_delta = {
        "day": timedelta(days=1),
        "week": timedelta(weeks=1),
        "month": timedelta(days=30),
    }.get(curve.time_unit, timedelta(days=1))
    bucket_seconds = bucket_delta.total_seconds()

    def get_bucket_index(ts):
        offset = (ts - start_date).total_seconds()
        # floor, not int(): int() truncates toward zero, so a transaction
        # slightly before start_date would land in bucket 0 instead of -1.
        idx = int(np.floor(offset / bucket_seconds))
        return min(idx, n_periods - 1)

    transactions = transactions.copy()
    transactions['bucket_idx'] = transactions['timestamp'].apply(get_bucket_index)

    actual_by_bucket = transactions.groupby('bucket_idx')['amount'].sum()

    # Build the actual array matching the expected length; drop negative buckets.
    actual = np.zeros(n_periods)
    for idx, val in actual_by_bucket.items():
        if 0 <= idx < n_periods:
            actual[idx] = val

    # Percentage error per bucket (denominator floored at 1 to avoid division by 0).
    error_pct = np.abs(actual - expected) / np.maximum(expected, 1) * 100
    avg_error = error_pct.mean()
    match_score = max(0, 100 - avg_error)

    return {
        'match_score': round(match_score, 2),
        'expected': expected.tolist(),
        'actual': actual.tolist(),
        'error_pct': error_pct.tolist(),
        'avg_error_pct': round(avg_error, 2)
    }
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# ============ Preset Curve Shapes ============
|
|
236
|
+
|
|
237
|
+
def get_curve_presets() -> Dict[str, List[float]]:
    """Return named 12-point curve shapes for common business patterns.

    A new dict with new lists is built on every call, so callers may mutate
    the result without affecting later calls.
    """
    shapes = (
        ("Linear Growth", (100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300, 320)),
        ("Exponential Growth", (100, 115, 132, 152, 175, 201, 231, 266, 306, 352, 405, 466)),
        ("Hockey Stick", (100, 102, 105, 108, 112, 118, 140, 180, 250, 350, 500, 700)),
        ("Seasonal (Retail)", (100, 80, 70, 90, 100, 120, 110, 100, 130, 160, 200, 300)),
        ("SaaS Growth", (10, 18, 30, 50, 80, 120, 170, 230, 300, 380, 470, 570)),
        ("Churn Decline", (1000, 920, 850, 790, 740, 700, 665, 635, 610, 590, 575, 560)),
        ("V-shaped Recovery", (100, 80, 60, 50, 45, 50, 65, 85, 110, 140, 170, 200)),
        ("Plateau", (100, 150, 200, 240, 270, 290, 300, 305, 308, 310, 311, 312)),
    )
    return {label: list(series) for label, series in shapes}
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def create_curve_from_preset(
    preset_name: str,
    metric_name: str = "revenue",
    time_unit: str = "month",
    start_date: Optional[datetime] = None,
    scale: float = 1000  # Multiply preset values by this
) -> OutcomeCurve:
    """Create an OutcomeCurve from one of the shapes in get_curve_presets().

    Args:
        preset_name: Preset key. Unknown names silently fall back to
            "Linear Growth".
        metric_name: Metric label stored on the curve.
        time_unit: Spacing between point timestamps: 1 day for "day", 1 week
            for "week", otherwise 30 days.
        start_date: Timestamp of the first point. Defaults to one period per
            point before now, so the curve ends at the present for every
            time unit.
        scale: Multiplier applied to every preset value.

    Returns:
        OutcomeCurve with one CurvePoint per preset value.
    """
    presets = get_curve_presets()
    values = presets.get(preset_name, presets["Linear Growth"])

    if time_unit == "day":
        delta = timedelta(days=1)
    elif time_unit == "week":
        delta = timedelta(weeks=1)
    else:
        delta = timedelta(days=30)

    if start_date is None:
        # Step back by the actual period length. A fixed 30-day step would
        # place "day"/"week" curves months in the past.
        start_date = datetime.now() - delta * len(values)

    points = [
        CurvePoint(
            timestamp=start_date + delta * i,
            value=v * scale
        )
        for i, v in enumerate(values)
    ]

    return OutcomeCurve(
        metric_name=metric_name,
        time_unit=time_unit,
        points=points
    )
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import streamlit as st
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
|
|
4
|
+
class StudioStore:
    """Thin static wrapper around st.session_state for Misata Studio's UI state."""

    @staticmethod
    def init():
        """Seed st.session_state with defaults for every key not yet present.

        Keys that already exist are left untouched, so repeated calls do not
        overwrite user changes.
        """
        navigation = {
            "active_tab": "Schema",
            "sidebar_expanded": True,
        }
        data_and_schema = {
            "schema_config": None,
            "schema_source": "Template",  # "Template" or "AI"
            "warehouse_schema": {
                "type": "service_company",
                "customer_count": 500,
                "project_count": 2000,
            },
        }
        constraints = {
            "selected_constraint": None,  # e.g. "invoices.amount"
            "warehouse_curve": [100000] * 12,  # flat default annual curve
            "start_date_input": datetime.now().date(),
        }
        generation = {
            "warehouse_config": {
                "avg_transaction": 50.0,
                "seed": 42,
                "tier_distribution": [0.5, 0.3, 0.2],
            },
        }
        results = {
            "generated_warehouse": None,
            "warehouse_generated": False,
        }

        for section in (navigation, data_and_schema, constraints, generation, results):
            for name, initial in section.items():
                if name in st.session_state:
                    continue
                st.session_state[name] = initial

    @staticmethod
    def get(key, default=None):
        """Return st.session_state[key], or default when the key is absent."""
        return st.session_state.get(key, default)

    @staticmethod
    def set(key, value):
        """Store value under key in st.session_state."""
        st.session_state[key] = value

    @property
    def schema(self):
        """The "schema_config" entry of the session state (None when unset).

        NOTE(review): unlike the static helpers above, this is an instance
        property. It only works on a StudioStore() instance; accessing
        StudioStore.schema on the class yields the property object itself.
        Confirm how callers use it.
        """
        return st.session_state.get("schema_config")
|