misata 0.3.0b0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- misata/__init__.py +1 -1
- misata/agents/__init__.py +23 -0
- misata/agents/pipeline.py +286 -0
- misata/causal/__init__.py +5 -0
- misata/causal/graph.py +109 -0
- misata/causal/solver.py +115 -0
- misata/cli.py +31 -0
- misata/generators/__init__.py +19 -0
- misata/generators/copula.py +198 -0
- misata/llm_parser.py +180 -137
- misata/quality.py +78 -33
- misata/reference_data.py +221 -0
- misata/research/__init__.py +3 -0
- misata/research/agent.py +70 -0
- misata/schema.py +25 -0
- misata/simulator.py +264 -12
- misata/smart_values.py +144 -6
- misata/studio/__init__.py +55 -0
- misata/studio/app.py +49 -0
- misata/studio/components/inspector.py +81 -0
- misata/studio/components/sidebar.py +35 -0
- misata/studio/constraint_generator.py +781 -0
- misata/studio/inference.py +319 -0
- misata/studio/outcome_curve.py +284 -0
- misata/studio/state/store.py +55 -0
- misata/studio/tabs/configure.py +50 -0
- misata/studio/tabs/generate.py +117 -0
- misata/studio/tabs/outcome_curve.py +149 -0
- misata/studio/tabs/schema_designer.py +217 -0
- misata/studio/utils/styles.py +143 -0
- misata/studio_constraints/__init__.py +29 -0
- misata/studio_constraints/z3_solver.py +259 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/METADATA +13 -2
- misata-0.5.0.dist-info/RECORD +61 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/WHEEL +1 -1
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/entry_points.txt +1 -0
- misata-0.3.0b0.dist-info/RECORD +0 -37
- /misata/{generators.py → generators_legacy.py} +0 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {misata-0.3.0b0.dist-info → misata-0.5.0.dist-info}/top_level.txt +0 -0
misata/__init__.py
CHANGED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Agents package for Misata.
|
|
3
|
+
|
|
4
|
+
Multi-agent AI pipeline for synthetic data generation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from misata.agents.pipeline import (
|
|
8
|
+
GenerationState,
|
|
9
|
+
SchemaArchitectAgent,
|
|
10
|
+
DomainExpertAgent,
|
|
11
|
+
ValidationAgent,
|
|
12
|
+
SimplePipeline,
|
|
13
|
+
create_pipeline,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"GenerationState",
|
|
18
|
+
"SchemaArchitectAgent",
|
|
19
|
+
"DomainExpertAgent",
|
|
20
|
+
"ValidationAgent",
|
|
21
|
+
"SimplePipeline",
|
|
22
|
+
"create_pipeline",
|
|
23
|
+
]
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LangGraph-based Multi-Agent Pipeline for Synthetic Data Generation
|
|
3
|
+
|
|
4
|
+
This is the 2026 production-grade agent architecture using LangGraph
|
|
5
|
+
for stateful, controllable AI pipelines.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TypedDict, Optional, List, Dict, Any, Annotated
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import json
|
|
12
|
+
|
|
13
|
+
# LangGraph imports (optional - handles graceful fallback)
|
|
14
|
+
try:
|
|
15
|
+
from langgraph.graph import StateGraph, END
|
|
16
|
+
LANGGRAPH_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
LANGGRAPH_AVAILABLE = False
|
|
19
|
+
print("[WARNING] LangGraph not installed. Run: pip install langgraph")
|
|
20
|
+
|
|
21
|
+
# Groq imports (already integrated in misata)
|
|
22
|
+
try:
|
|
23
|
+
from groq import Groq
|
|
24
|
+
GROQ_AVAILABLE = True
|
|
25
|
+
except ImportError:
|
|
26
|
+
GROQ_AVAILABLE = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class GenerationState:
    """State passed through the multi-agent pipeline.

    Every schema/generation/validation field starts as ``None`` (meaning
    "not produced yet"), except ``errors``, which is always a list so that
    pipeline steps can append to it without checking first.
    """

    # Input
    story: str = ""

    # Schema extraction
    schema: Optional[Dict] = None
    tables: Optional[List[Dict]] = None
    columns: Optional[Dict[str, List[Dict]]] = None
    relationships: Optional[List[Dict]] = None
    outcome_curves: Optional[List[Dict]] = None

    # Generation
    data: Optional[Dict[str, pd.DataFrame]] = None

    # Validation
    validation_results: Optional[Dict] = None
    errors: Optional[List[str]] = None

    # Control flow
    current_step: str = "init"
    retry_count: int = 0
    max_retries: int = 3

    def __post_init__(self) -> None:
        # A fresh list per instance: the declared default stays ``None`` so the
        # constructor signature is unchanged, but ``state.errors.append(...)``
        # can no longer fail with AttributeError when no list was passed in.
        if self.errors is None:
            self.errors = []
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class SchemaArchitectAgent:
    """
    Agent 1: Extracts schema from natural language story.
    Uses Groq for fast LLM inference.
    """

    def __init__(self, groq_api_key: Optional[str] = None):
        """Create the agent.

        Args:
            groq_api_key: API key to use. When omitted, the ``GROQ_API_KEY``
                environment variable is read instead.

        ``self.client`` is ``None`` when the groq package is not importable
        or no key could be found; ``extract_schema`` then raises.
        """
        import os

        self.api_key = groq_api_key or os.environ.get("GROQ_API_KEY")
        if GROQ_AVAILABLE and self.api_key:
            self.client = Groq(api_key=self.api_key)
        else:
            self.client = None

    def extract_schema(self, story: str) -> Dict:
        """Extract schema from story using Groq LLM.

        Args:
            story: Free-text business description sent as the user message.

        Returns:
            The JSON object returned by the model, parsed into a dict.

        Raises:
            ValueError: If no client is configured, the completion carries no
                content, or the parsed JSON is not an object.
                ``json.JSONDecodeError`` (a ``ValueError`` subclass) still
                propagates unchanged when the content is not valid JSON.
        """
        if not self.client:
            raise ValueError("Groq client not available. Set GROQ_API_KEY.")

        system_prompt = """You are a database schema architect. Given a business description,
extract a detailed schema with:
1. tables (name, row_count)
2. columns (name, type - one of: int, float, text, date, boolean, categorical, foreign_key)
3. relationships (parent_table, child_table, parent_key, child_key)
4. outcome_curves (temporal patterns like seasonal peaks)

Respond in JSON format only."""

        response = self.client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": story}
            ],
            response_format={"type": "json_object"},
            temperature=0.7
        )

        # An empty completion would otherwise surface as an opaque
        # IndexError / TypeError from the indexing or json.loads below.
        choices = response.choices
        content = choices[0].message.content if choices else None
        if not content:
            raise ValueError("Groq returned an empty response; no schema to parse.")

        schema = json.loads(content)
        if not isinstance(schema, dict):
            # The declared return type is a dict; callers use schema.get(...).
            raise ValueError(
                f"Expected a JSON object for the schema, got {type(schema).__name__}."
            )
        return schema
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class DomainExpertAgent:
    """
    Agent 2: Enriches schema with domain-specific knowledge.

    ``DOMAIN_PATTERNS`` maps a domain name to ``{column-name fragment:
    distribution parameters}``. It is shared by every instance and is never
    handed out by reference.
    """

    DOMAIN_PATTERNS = {
        "ecommerce": {
            "order_amount": {"min": 10, "max": 5000, "distribution": "lognormal"},
            "product_price": {"min": 1, "max": 2000, "distribution": "lognormal"},
            "customer_age": {"min": 18, "max": 80, "distribution": "normal"},
        },
        "saas": {
            "mrr": {"min": 0, "max": 50000, "distribution": "lognormal"},
            "churn_rate": {"min": 0.01, "max": 0.15, "distribution": "beta"},
            "seats": {"min": 1, "max": 1000, "distribution": "lognormal"},
        },
        "healthcare": {
            "age": {"min": 0, "max": 120, "distribution": "normal"},
            "blood_pressure": {"min": 60, "max": 200, "distribution": "normal"},
        }
    }

    def enrich_schema(self, schema: Dict, domain: Optional[str] = None) -> Dict:
        """Add domain-specific constraints and distributions.

        Args:
            schema: Schema dict; ``schema["columns"]`` is read as a mapping of
                table name to a list of column dicts.
            domain: Key into ``DOMAIN_PATTERNS``. Auto-detected from the table
                names when falsy.

        Returns:
            The same ``schema`` object, mutated in place: every column whose
            lower-cased name contains a pattern fragment gets a
            ``"distribution_params"`` entry. When several fragments match,
            the last one in the pattern table wins.
        """
        if not domain:
            # Auto-detect domain from table names
            domain = self._detect_domain(schema)

        patterns = self.DOMAIN_PATTERNS.get(domain, {})
        if not patterns:
            return schema

        # Enrich column parameters
        for columns in (schema.get("columns") or {}).values():
            for col in columns:
                col_name_lower = str(col.get("name", "")).lower()
                if not col_name_lower:
                    # Column without a name: nothing to match against.
                    continue
                for pattern_name, params in patterns.items():
                    # NOTE(review): plain substring match, so a short fragment
                    # such as "age" also matches e.g. "dosage".
                    if pattern_name in col_name_lower:
                        # Copy, so callers that tweak a column's parameters
                        # cannot corrupt the shared class-level table.
                        col["distribution_params"] = dict(params)

        return schema

    def _detect_domain(self, schema: Dict) -> str:
        """Detect domain from table names.

        Joins the lower-cased table names and returns the first domain whose
        keyword list has a substring hit, checked in the order ecommerce,
        saas, healthcare. Returns ``"general"`` when nothing matches.
        """
        table_names = " ".join(
            str(t.get("name", "")).lower() for t in (schema.get("tables") or [])
        )

        if any(k in table_names for k in ["order", "product", "cart", "customer"]):
            return "ecommerce"
        if any(k in table_names for k in ["subscription", "plan", "user", "mrr"]):
            return "saas"
        if any(k in table_names for k in ["patient", "diagnosis", "treatment"]):
            return "healthcare"

        return "general"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class ValidationAgent:
    """
    Agent 3: Validates generated data against the schema it was built from.

    Tables that are absent from ``data`` are skipped rather than reported.
    """

    def validate(self, data: Dict[str, pd.DataFrame], schema: Dict) -> Dict[str, Any]:
        """Run all validation checks.

        Args:
            data: Generated tables keyed by table name.
            schema: Schema dict; the ``tables``, ``columns``, ``relationships``
                and ``outcome_curves`` entries are read when present.

        Returns:
            ``{"passed": bool, "checks": {...}, "errors": [...]}``. ``passed``
            is ``False`` as soon as any row-count, column or foreign-key
            problem is found, and each such problem has an entry in
            ``errors``.
        """
        results = {
            "passed": True,
            "checks": {},
            "errors": []
        }

        def fail(message: str) -> None:
            # Keep the error list and the overall verdict in lock-step.
            results["errors"].append(message)
            results["passed"] = False

        # 1. Row count validation
        for table in schema.get("tables", []):
            table_name = table["name"]
            expected_rows = table.get("row_count", 100)

            if table_name not in data:
                continue

            actual_rows = len(data[table_name])
            row_count_ok = actual_rows == expected_rows
            results["checks"][f"{table_name}_row_count"] = {
                "expected": expected_rows,
                "actual": actual_rows,
                "passed": row_count_ok
            }
            if not row_count_ok:
                # Without this the failed check was recorded while the
                # overall result still claimed success.
                fail(
                    f"Row count mismatch: {table_name} expected {expected_rows}, "
                    f"got {actual_rows}"
                )

        # 2. Column presence validation
        for table_name, columns in schema.get("columns", {}).items():
            if table_name not in data:
                continue
            df = data[table_name]

            for col in columns:
                col_name = col["name"]

                if col_name not in df.columns:
                    fail(f"Missing column: {table_name}.{col_name}")
                    continue

                # Only existence is checked; the declared type is not compared
                # with the DataFrame dtype.
                results["checks"][f"{table_name}.{col_name}_exists"] = {
                    "passed": True
                }

        # 3. Foreign key validation
        for rel in schema.get("relationships", []):
            parent_table = rel["parent_table"]
            child_table = rel["child_table"]
            parent_key = rel["parent_key"]
            child_key = rel["child_key"]

            if parent_table not in data or child_table not in data:
                continue

            parent_df = data[parent_table]
            child_df = data[child_table]

            # Report a missing key column instead of crashing with KeyError.
            if parent_key not in parent_df.columns:
                fail(f"Missing FK column: {parent_table}.{parent_key}")
                continue
            if child_key not in child_df.columns:
                fail(f"Missing FK column: {child_table}.{child_key}")
                continue

            # NULLs are dropped: a missing reference is not an orphan, and NaN
            # never compares equal to itself inside a set difference.
            parent_ids = set(parent_df[parent_key].dropna())
            child_refs = set(child_df[child_key].dropna())

            orphans = child_refs - parent_ids
            if orphans:
                fail(
                    f"FK violation: {child_table}.{child_key} has {len(orphans)} orphan references"
                )
            else:
                results["checks"][f"{child_table}.{child_key}_fk"] = {"passed": True}

        # 4. Outcome curve validation (if applicable)
        for curve in schema.get("outcome_curves", []):
            table_name = curve.get("table")
            column = curve.get("column")

            if table_name in data and column in data[table_name].columns:
                # Presence check only: the shape of the curve is not verified.
                results["checks"][f"{table_name}.{column}_curve"] = {
                    "passed": True,  # Basic presence check
                    "note": "Curve applied (visual verification recommended)"
                }

        return results
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# Simple non-LangGraph pipeline for when LangGraph is not available
|
|
233
|
+
class SimplePipeline:
    """Sequential pipeline used when no LangGraph graph is built.

    Runs schema extraction, domain enrichment and (when data is present)
    validation one after the other, recording progress on a
    ``GenerationState``.
    """

    def __init__(self):
        self.schema_agent = SchemaArchitectAgent()
        self.domain_agent = DomainExpertAgent()
        self.validator = ValidationAgent()

    def run(self, story: str) -> GenerationState:
        """Run every step for ``story`` and return the resulting state.

        Exceptions raised by a step are not propagated: the message is
        appended to ``state.errors`` and ``state.current_step`` becomes
        ``"error"``.
        """
        state = GenerationState(story=story, errors=[])

        try:
            self._run_steps(state, story)
        except Exception as exc:
            state.errors.append(str(exc))
            state.current_step = "error"

        return state

    def _run_steps(self, state: GenerationState, story: str) -> None:
        """Execute the steps in order, updating ``state.current_step`` as it goes."""
        # Step 1: pull the raw schema out of the story.
        state.current_step = "schema_extraction"
        extracted = self.schema_agent.extract_schema(story)
        state.schema = extracted
        state.tables = extracted.get("tables", [])
        state.columns = extracted.get("columns", {})
        state.relationships = extracted.get("relationships", [])
        state.outcome_curves = extracted.get("outcome_curves", [])

        # Step 2: layer domain knowledge on top.
        state.current_step = "domain_enrichment"
        state.schema = self.domain_agent.enrich_schema(extracted)

        # Step 3: generation is only marked here; per the original note the
        # data itself is produced in constraint_generator.py.
        state.current_step = "generation"

        # Step 4: validate whatever data has been attached to the state.
        state.current_step = "validation"
        if state.data:
            state.validation_results = self.validator.validate(state.data, state.schema)

        state.current_step = "complete"
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
# Factory function
|
|
278
|
+
def create_pipeline():
    """Build the pipeline that matches the installed dependencies.

    A ``SimplePipeline`` is returned in both cases for now; only the
    printed message depends on ``LANGGRAPH_AVAILABLE``.
    """
    # TODO: Create full LangGraph StateGraph when available
    if not LANGGRAPH_AVAILABLE:
        print("[PIPELINE] Using simple pipeline (install langgraph for advanced features)")
        return SimplePipeline()

    print("[PIPELINE] LangGraph available - using stateful pipeline")
    # Placeholder until full LangGraph implementation
    return SimplePipeline()
|
misata/causal/graph.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from typing import List, Dict, Callable, Optional, Any
|
|
2
|
+
import networkx as nx # type: ignore
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
class CausalNode:
    """
    Represents a variable in the Causal Graph.

    Attributes:
        name: Identifier of the node.
        node_type: ``'exogenous'`` (root) or ``'endogenous'`` (derived).
        mechanism: Callable that takes the parent values, in ``parents``
            order, and returns this node's value; ``None`` for nodes without
            one.
        parents: Names of the parent nodes (empty list when none were given).
        current_value: Initialised to ``None``.
    """
    def __init__(
        self,
        name: str,
        node_type: str = "endogenous",  # 'exogenous' or 'endogenous'
        mechanism: Optional[Callable] = None,
        parents: Optional[List[str]] = None
    ):
        self.name = name
        self.node_type = node_type  # exogenous (root) or endogenous (derived)
        self.mechanism = mechanism  # Function that takes parent values and returns node value
        # Copy so later edits to the caller's list cannot silently change
        # this node's parents.
        self.parents = list(parents) if parents else []
        self.current_value: Optional[np.ndarray] = None
|
|
21
|
+
|
|
22
|
+
class CausalGraph:
    """
    Manages the DAG structure and execution order.

    ``graph`` holds the topology as a networkx ``DiGraph``; ``nodes`` maps a
    node name to its ``CausalNode``.
    """
    def __init__(self):
        self.graph = nx.DiGraph()
        self.nodes: Dict[str, "CausalNode"] = {}

    def add_node(self, node: "CausalNode"):
        """Register ``node`` and add an edge from each of its parents to it.

        Re-adding a name replaces the earlier node, including the parent
        edges the earlier definition contributed.
        """
        previous = self.nodes.get(node.name)
        if previous is not None:
            # Drop edges from the old definition so stale parents cannot
            # distort the execution order.
            self.graph.remove_edges_from(
                [(parent, node.name) for parent in previous.parents]
            )

        self.nodes[node.name] = node
        self.graph.add_node(node.name)
        for parent in node.parents:
            self.graph.add_edge(parent, node.name)

    def get_topological_sort(self) -> List[str]:
        """Returns execution order"""
        return list(nx.topological_sort(self.graph))

    def forward_pass(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Computes values for all nodes given inputs for exogenous nodes.

        Args:
            inputs: Values keyed by node name. Any node present here is used
                as given and its mechanism is not run.

        Returns:
            A new dict with ``inputs`` plus one entry per computed node.

        Raises:
            ValueError: If a node has neither an input value nor a mechanism,
                or is only referenced as a parent and has no input value.
        """
        results = inputs.copy()

        for node_name in self.get_topological_sort():
            # Provided values win. This check comes before the node lookup so
            # that a parent known only through an edge still works when the
            # caller supplies its value.
            if node_name in results:
                continue

            node = self.nodes.get(node_name)
            if node is None:
                raise ValueError(
                    f"Node {node_name} is referenced as a parent but was never "
                    f"added and has no input value!"
                )

            if node.mechanism is None:
                raise ValueError(f"Node {node_name} has no inputs and no mechanism!")

            # Parents precede their children in topological order, so their
            # values are already in ``results``.
            parent_values = [results[p] for p in node.parents]
            results[node_name] = node.mechanism(*parent_values)

        return results
|
|
64
|
+
|
|
65
|
+
def saas_mechanism_leads(traffic, conversion_rate):
    """Return ``traffic * conversion_rate``."""
    leads = traffic * conversion_rate
    return leads
|
|
67
|
+
|
|
68
|
+
def saas_mechanism_deals(leads, sales_conversion):
    """Return ``leads * sales_conversion``."""
    deals = leads * sales_conversion
    return deals
|
|
70
|
+
|
|
71
|
+
def saas_mechanism_revenue(deals, aov):
    """Return ``deals * aov``."""
    revenue = deals * aov
    return revenue
|
|
73
|
+
|
|
74
|
+
def get_saas_template() -> CausalGraph:
    """
    Build the standard SaaS funnel graph:
    Traffic -> Leads -> Deals -> Revenue

    Four exogenous roots (Traffic, LeadConversion, SalesConversion, AOV)
    feed three endogenous nodes, each computed by its ``saas_mechanism_*``
    function from two parents.
    """
    template = CausalGraph()

    # Exogenous (root) nodes; AOV is the Average Order Value.
    for root_name in ("Traffic", "LeadConversion", "SalesConversion", "AOV"):
        template.add_node(CausalNode(root_name, "exogenous"))

    # Endogenous (derived) nodes: (name, mechanism, parents), in funnel order.
    derived_nodes = (
        ("Leads", saas_mechanism_leads, ["Traffic", "LeadConversion"]),
        ("Deals", saas_mechanism_deals, ["Leads", "SalesConversion"]),
        ("Revenue", saas_mechanism_revenue, ["Deals", "AOV"]),
    )
    for node_name, mechanism, parent_names in derived_nodes:
        template.add_node(CausalNode(
            node_name,
            "endogenous",
            mechanism=mechanism,
            parents=parent_names
        ))

    return template
|
misata/causal/solver.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.optimize import minimize # type: ignore
|
|
3
|
+
from typing import Dict, List, Optional, Tuple
|
|
4
|
+
from .graph import CausalGraph
|
|
5
|
+
|
|
6
|
+
class CausalSolver:
    """
    Solves for exogenous inputs given constraints on endogenous outputs.
    """
    def __init__(self, graph: "CausalGraph"):
        self.graph = graph

    def solve(
        self,
        target_constraints: Dict[str, np.ndarray],
        adjustable_nodes: List[str],
        initial_values: Optional[Dict[str, np.ndarray]] = None,
        bounds: Optional[Tuple[Optional[float], Optional[float]]] = (0, None)  # Non-negative by default
    ) -> Dict[str, np.ndarray]:
        """
        Back-solves the graph.

        Minimises the summed mean squared error between the graph's forward
        pass and the targets, using L-BFGS-B over the adjustable nodes.

        Args:
            target_constraints: Dict mapping NodeName -> TargetArray (e.g., {'Revenue': [100, 200]})
            adjustable_nodes: List of Exogenous Node Names to adjust (e.g., ['Traffic'])
            initial_values: Starting guess for adjustable nodes. Defaults to 1.0.
            bounds: (min, max) applied to every adjustable value, or ``None``
                for an unbounded search.

        Returns:
            Dict of optimized inputs for the adjustable nodes.

        Raises:
            ValueError: If ``target_constraints`` or ``adjustable_nodes`` is
                empty, or an ``initial_values`` entry does not have one value
                per target sample.
        """
        if not target_constraints:
            raise ValueError("target_constraints must contain at least one target array")
        if not adjustable_nodes:
            raise ValueError("adjustable_nodes must name at least one node to optimize")

        # Convert once, outside the objective the optimizer calls repeatedly.
        targets = {
            name: np.asarray(values, dtype=float)
            for name, values in target_constraints.items()
        }
        # The first target fixes how many samples every node carries.
        sample_size = len(next(iter(targets.values())))

        # Flatten initial guess into 1D array for scipy
        # x0 = [node1_t0, node1_t1, ..., node2_t0, ...]
        guesses = []
        for node in adjustable_nodes:
            if initial_values and node in initial_values:
                guess = np.asarray(initial_values[node], dtype=float).ravel()
                if guess.size != sample_size:
                    # A wrong length would shift every later node's slice.
                    raise ValueError(
                        f"initial_values[{node!r}] has {guess.size} values, "
                        f"expected {sample_size}"
                    )
            else:
                guess = np.ones(sample_size)  # Default guess: 1.0
            guesses.append(guess)
        x0 = np.concatenate(guesses)

        def unpack(x):
            """Split the flat vector back into one array per adjustable node."""
            return {
                node_name: x[i * sample_size:(i + 1) * sample_size]
                for i, node_name in enumerate(adjustable_nodes)
            }

        def objective_function(x):
            """
            Input x: Flattened array of adjustable values.
            Returns: Error (MSE) between Generated and Target.
            """
            current_inputs = unpack(x)

            # Non-adjustable exogenous nodes are pinned to 1.0.
            # TODO: Allow passing base inputs for non-optimized nodes
            for node_name, node in self.graph.nodes.items():
                if node.node_type == 'exogenous' and node_name not in current_inputs:
                    current_inputs[node_name] = np.ones(sample_size)

            try:
                results = self.graph.forward_pass(current_inputs)
            except Exception:
                # Deliberate best-effort: if optimization goes wild
                # (e.g. NaN), return high error instead of aborting.
                return 1e9

            total_error = 0.0
            for target_node, target_arr in targets.items():
                total_error += np.mean((results[target_node] - target_arr) ** 2)
            return total_error

        # L-BFGS-B takes one (min, max) pair per variable; pass None through
        # untouched, since a list of None entries is not a valid bounds value.
        scipy_bounds = None if bounds is None else [bounds] * len(x0)

        res = minimize(
            objective_function,
            x0,
            method='L-BFGS-B',
            bounds=scipy_bounds,
            options={'ftol': 1e-9, 'disp': False}
        )

        if not res.success:
            print(f"Warning: Optimization failed: {res.message}")

        return unpack(res.x)
|
misata/cli.py
CHANGED
|
@@ -675,6 +675,37 @@ def templates_list() -> None:
|
|
|
675
675
|
console.print("\\nUsage: [cyan]misata template <name> [OPTIONS][/cyan]")
|
|
676
676
|
|
|
677
677
|
|
|
678
|
+
@main.command()
@click.option("--port", "-p", type=int, default=8501, help="Port to run Studio on")
@click.option("--host", "-h", type=str, default="localhost", help="Host to bind to")
@click.option("--no-browser", is_flag=True, help="Don't open browser automatically")
def studio(port: int, host: str, no_browser: bool) -> None:
    """
    Launch Misata Studio - the visual schema designer.

    Features:
    - Upload CSV to reverse-engineer schema
    - Visual distribution curve editor (Reverse Graph)
    - Generate millions of matching rows

    Example:

    misata studio
    misata studio --port 8080
    """
    print_banner()
    console.print("\n🎨 [bold purple]Launching Misata Studio...[/bold purple]")
    console.print(f"   URL: [cyan]http://{host}:{port}[/cyan]")
    console.print("\nPress [bold]Ctrl+C[/bold] to stop.\n")

    try:
        # Imported lazily so the rest of the CLI works without the optional
        # Studio dependencies.
        from misata.studio import launch
        launch(port=port, host=host, open_browser=not no_browser)
    except ImportError as exc:
        console.print("[red]Error: Misata Studio requires additional dependencies.[/red]")
        # The backslash stops the console markup parser from consuming
        # "[studio]" as a style tag and dropping it from the printed hint.
        console.print("Install with: [cyan]pip install misata\\[studio][/cyan]")
        # Exit non-zero so scripts can tell that Studio did not start.
        raise SystemExit(1) from exc
|
|
707
|
+
|
|
708
|
+
|
|
678
709
|
if __name__ == "__main__":
|
|
679
710
|
main()
|
|
680
711
|
|
misata/generators/__init__.py
CHANGED
|
@@ -16,6 +16,20 @@ from misata.generators.base import (
|
|
|
16
16
|
TextGenerator,
|
|
17
17
|
)
|
|
18
18
|
|
|
19
|
+
# Optional SDV-based generators (require: pip install sdv)
|
|
20
|
+
try:
|
|
21
|
+
from misata.generators.copula import (
|
|
22
|
+
CopulaGenerator,
|
|
23
|
+
ConstraintAwareCopulaGenerator,
|
|
24
|
+
create_copula_generator,
|
|
25
|
+
)
|
|
26
|
+
COPULA_AVAILABLE = True
|
|
27
|
+
except ImportError:
|
|
28
|
+
COPULA_AVAILABLE = False
|
|
29
|
+
CopulaGenerator = None
|
|
30
|
+
ConstraintAwareCopulaGenerator = None
|
|
31
|
+
create_copula_generator = None
|
|
32
|
+
|
|
19
33
|
__all__ = [
|
|
20
34
|
"BaseGenerator",
|
|
21
35
|
"GeneratorFactory",
|
|
@@ -26,4 +40,9 @@ __all__ = [
|
|
|
26
40
|
"DateGenerator",
|
|
27
41
|
"TextGenerator",
|
|
28
42
|
"ForeignKeyGenerator",
|
|
43
|
+
# Optional SDV
|
|
44
|
+
"CopulaGenerator",
|
|
45
|
+
"ConstraintAwareCopulaGenerator",
|
|
46
|
+
"create_copula_generator",
|
|
47
|
+
"COPULA_AVAILABLE",
|
|
29
48
|
]
|