misata 0.3.1b0__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. {misata-0.3.1b0 → misata-0.5.0}/PKG-INFO +13 -2
  2. {misata-0.3.1b0 → misata-0.5.0}/README.md +1 -1
  3. {misata-0.3.1b0 → misata-0.5.0}/misata/__init__.py +1 -1
  4. misata-0.5.0/misata/agents/__init__.py +23 -0
  5. misata-0.5.0/misata/agents/pipeline.py +286 -0
  6. misata-0.5.0/misata/causal/__init__.py +5 -0
  7. misata-0.5.0/misata/causal/graph.py +109 -0
  8. misata-0.5.0/misata/causal/solver.py +115 -0
  9. {misata-0.3.1b0 → misata-0.5.0}/misata/cli.py +31 -0
  10. {misata-0.3.1b0 → misata-0.5.0}/misata/generators/__init__.py +19 -0
  11. misata-0.5.0/misata/generators/copula.py +198 -0
  12. {misata-0.3.1b0 → misata-0.5.0}/misata/llm_parser.py +180 -137
  13. {misata-0.3.1b0 → misata-0.5.0}/misata/quality.py +78 -33
  14. misata-0.5.0/misata/reference_data.py +221 -0
  15. misata-0.5.0/misata/research/__init__.py +3 -0
  16. misata-0.5.0/misata/research/agent.py +70 -0
  17. {misata-0.3.1b0 → misata-0.5.0}/misata/schema.py +25 -0
  18. {misata-0.3.1b0 → misata-0.5.0}/misata/simulator.py +131 -0
  19. {misata-0.3.1b0 → misata-0.5.0}/misata/smart_values.py +144 -6
  20. misata-0.5.0/misata/studio/__init__.py +55 -0
  21. misata-0.5.0/misata/studio/app.py +49 -0
  22. misata-0.5.0/misata/studio/components/inspector.py +81 -0
  23. misata-0.5.0/misata/studio/components/sidebar.py +35 -0
  24. misata-0.5.0/misata/studio/constraint_generator.py +781 -0
  25. misata-0.5.0/misata/studio/inference.py +319 -0
  26. misata-0.5.0/misata/studio/outcome_curve.py +284 -0
  27. misata-0.5.0/misata/studio/state/store.py +55 -0
  28. misata-0.5.0/misata/studio/tabs/configure.py +50 -0
  29. misata-0.5.0/misata/studio/tabs/generate.py +117 -0
  30. misata-0.5.0/misata/studio/tabs/outcome_curve.py +149 -0
  31. misata-0.5.0/misata/studio/tabs/schema_designer.py +217 -0
  32. misata-0.5.0/misata/studio/utils/styles.py +143 -0
  33. misata-0.5.0/misata/studio_constraints/__init__.py +29 -0
  34. misata-0.5.0/misata/studio_constraints/z3_solver.py +259 -0
  35. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/PKG-INFO +13 -2
  36. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/SOURCES.txt +24 -0
  37. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/entry_points.txt +1 -0
  38. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/requires.txt +14 -0
  39. {misata-0.3.1b0 → misata-0.5.0}/pyproject.toml +17 -1
  40. {misata-0.3.1b0 → misata-0.5.0}/LICENSE +0 -0
  41. {misata-0.3.1b0 → misata-0.5.0}/misata/api.py +0 -0
  42. {misata-0.3.1b0 → misata-0.5.0}/misata/audit.py +0 -0
  43. {misata-0.3.1b0 → misata-0.5.0}/misata/benchmark.py +0 -0
  44. {misata-0.3.1b0 → misata-0.5.0}/misata/cache.py +0 -0
  45. {misata-0.3.1b0 → misata-0.5.0}/misata/codegen.py +0 -0
  46. {misata-0.3.1b0 → misata-0.5.0}/misata/constraints.py +0 -0
  47. {misata-0.3.1b0 → misata-0.5.0}/misata/context.py +0 -0
  48. {misata-0.3.1b0 → misata-0.5.0}/misata/curve_fitting.py +0 -0
  49. {misata-0.3.1b0 → misata-0.5.0}/misata/customization.py +0 -0
  50. {misata-0.3.1b0 → misata-0.5.0}/misata/exceptions.py +0 -0
  51. {misata-0.3.1b0 → misata-0.5.0}/misata/feedback.py +0 -0
  52. {misata-0.3.1b0 → misata-0.5.0}/misata/formulas.py +0 -0
  53. {misata-0.3.1b0 → misata-0.5.0}/misata/generators/base.py +0 -0
  54. {misata-0.3.1b0 → misata-0.5.0}/misata/generators_legacy.py +0 -0
  55. {misata-0.3.1b0 → misata-0.5.0}/misata/hybrid.py +0 -0
  56. {misata-0.3.1b0 → misata-0.5.0}/misata/noise.py +0 -0
  57. {misata-0.3.1b0 → misata-0.5.0}/misata/profiles.py +0 -0
  58. {misata-0.3.1b0 → misata-0.5.0}/misata/semantic.py +0 -0
  59. {misata-0.3.1b0 → misata-0.5.0}/misata/story_parser.py +0 -0
  60. {misata-0.3.1b0 → misata-0.5.0}/misata/streaming.py +0 -0
  61. {misata-0.3.1b0 → misata-0.5.0}/misata/templates/__init__.py +0 -0
  62. {misata-0.3.1b0 → misata-0.5.0}/misata/templates/library.py +0 -0
  63. {misata-0.3.1b0 → misata-0.5.0}/misata/validation.py +0 -0
  64. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/dependency_links.txt +0 -0
  65. {misata-0.3.1b0 → misata-0.5.0}/misata.egg-info/top_level.txt +0 -0
  66. {misata-0.3.1b0 → misata-0.5.0}/setup.cfg +0 -0
  67. {misata-0.3.1b0 → misata-0.5.0}/tests/test_api.py +0 -0
  68. {misata-0.3.1b0 → misata-0.5.0}/tests/test_cli.py +0 -0
  69. {misata-0.3.1b0 → misata-0.5.0}/tests/test_constraints.py +0 -0
  70. {misata-0.3.1b0 → misata-0.5.0}/tests/test_curve_fitting.py +0 -0
  71. {misata-0.3.1b0 → misata-0.5.0}/tests/test_enterprise.py +0 -0
  72. {misata-0.3.1b0 → misata-0.5.0}/tests/test_formulas.py +0 -0
  73. {misata-0.3.1b0 → misata-0.5.0}/tests/test_integrity.py +0 -0
  74. {misata-0.3.1b0 → misata-0.5.0}/tests/test_llm_parser.py +0 -0
  75. {misata-0.3.1b0 → misata-0.5.0}/tests/test_schema.py +0 -0
  76. {misata-0.3.1b0 → misata-0.5.0}/tests/test_security.py +0 -0
  77. {misata-0.3.1b0 → misata-0.5.0}/tests/test_semantic.py +0 -0
  78. {misata-0.3.1b0 → misata-0.5.0}/tests/test_simulator.py +0 -0
  79. {misata-0.3.1b0 → misata-0.5.0}/tests/test_templates.py +0 -0
  80. {misata-0.3.1b0 → misata-0.5.0}/tests/test_validation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: misata
3
- Version: 0.3.1b0
3
+ Version: 0.5.0
4
4
  Summary: AI-Powered Synthetic Data Engine - Generate realistic multi-table datasets from natural language
5
5
  Author-email: Muhammed Rasin <rasinbinabdulla@gmail.com>
6
6
  License: MIT
@@ -36,12 +36,23 @@ Requires-Dist: uvicorn>=0.27.0
36
36
  Requires-Dist: python-multipart>=0.0.6
37
37
  Requires-Dist: simpleeval>=0.9.0
38
38
  Requires-Dist: scipy>=1.10.0
39
+ Requires-Dist: networkx>=3.0
39
40
  Provides-Extra: dev
40
41
  Requires-Dist: pytest>=7.4.0; extra == "dev"
41
42
  Requires-Dist: pytest-benchmark>=4.0.0; extra == "dev"
42
43
  Requires-Dist: black>=23.0.0; extra == "dev"
43
44
  Requires-Dist: ruff>=0.1.0; extra == "dev"
44
45
  Requires-Dist: mypy>=1.5.0; extra == "dev"
46
+ Provides-Extra: studio
47
+ Requires-Dist: streamlit>=1.30.0; extra == "studio"
48
+ Requires-Dist: plotly>=5.0.0; extra == "studio"
49
+ Requires-Dist: openpyxl>=3.0.0; extra == "studio"
50
+ Provides-Extra: advanced
51
+ Requires-Dist: sdv>=1.0.0; extra == "advanced"
52
+ Requires-Dist: langgraph>=0.2.0; extra == "advanced"
53
+ Requires-Dist: z3-solver>=4.12.0; extra == "advanced"
54
+ Provides-Extra: all
55
+ Requires-Dist: misata[advanced,dev,studio]; extra == "all"
45
56
  Dynamic: license-file
46
57
 
47
58
  # 🧠 Misata
@@ -50,7 +61,7 @@ Dynamic: license-file
50
61
 
51
62
  No schema writing. No training data. Just describe what you need.
52
63
 
53
- [![Version](https://img.shields.io/badge/version-0.2.0--beta-purple.svg)]()
64
+ [![Version](https://img.shields.io/badge/version-0.5.0-purple.svg)]()
54
65
  [![License](https://img.shields.io/badge/license-MIT-blue.svg)]()
55
66
  [![Python](https://img.shields.io/badge/python-3.10+-green.svg)]()
56
67
 
@@ -4,7 +4,7 @@
4
4
 
5
5
  No schema writing. No training data. Just describe what you need.
6
6
 
7
- [![Version](https://img.shields.io/badge/version-0.2.0--beta-purple.svg)]()
7
+ [![Version](https://img.shields.io/badge/version-0.5.0-purple.svg)]()
8
8
  [![License](https://img.shields.io/badge/license-MIT-blue.svg)]()
9
9
  [![Python](https://img.shields.io/badge/python-3.10+-green.svg)]()
10
10
 
@@ -15,7 +15,7 @@ Usage:
15
15
  config = load_template("ecommerce")
16
16
  """
17
17
 
18
- __version__ = "0.3.1b0"
18
+ __version__ = "0.4.0b0"
19
19
  __author__ = "Muhammed Rasin"
20
20
 
21
21
  from misata.schema import (
@@ -0,0 +1,23 @@
1
+ """
2
+ Agents package for Misata.
3
+
4
+ Multi-agent AI pipeline for synthetic data generation.
5
+ """
6
+
7
+ from misata.agents.pipeline import (
8
+ GenerationState,
9
+ SchemaArchitectAgent,
10
+ DomainExpertAgent,
11
+ ValidationAgent,
12
+ SimplePipeline,
13
+ create_pipeline,
14
+ )
15
+
16
+ __all__ = [
17
+ "GenerationState",
18
+ "SchemaArchitectAgent",
19
+ "DomainExpertAgent",
20
+ "ValidationAgent",
21
+ "SimplePipeline",
22
+ "create_pipeline",
23
+ ]
@@ -0,0 +1,286 @@
1
+ """
2
+ LangGraph-based Multi-Agent Pipeline for Synthetic Data Generation
3
+
4
+ This is the 2026 production-grade agent architecture using LangGraph
5
+ for stateful, controllable AI pipelines.
6
+ """
7
+
8
+ from typing import TypedDict, Optional, List, Dict, Any, Annotated
9
+ from dataclasses import dataclass
10
+ import pandas as pd
11
+ import json
12
+
13
+ # LangGraph imports (optional - handles graceful fallback)
14
+ try:
15
+ from langgraph.graph import StateGraph, END
16
+ LANGGRAPH_AVAILABLE = True
17
+ except ImportError:
18
+ LANGGRAPH_AVAILABLE = False
19
+ print("[WARNING] LangGraph not installed. Run: pip install langgraph")
20
+
21
+ # Groq imports (already integrated in misata)
22
+ try:
23
+ from groq import Groq
24
+ GROQ_AVAILABLE = True
25
+ except ImportError:
26
+ GROQ_AVAILABLE = False
27
+
28
+
29
+ @dataclass
30
+ class GenerationState:
31
+ """State passed through the multi-agent pipeline."""
32
+ # Input
33
+ story: str = ""
34
+
35
+ # Schema extraction
36
+ schema: Optional[Dict] = None
37
+ tables: List[Dict] = None
38
+ columns: Dict[str, List[Dict]] = None
39
+ relationships: List[Dict] = None
40
+ outcome_curves: List[Dict] = None
41
+
42
+ # Generation
43
+ data: Optional[Dict[str, pd.DataFrame]] = None
44
+
45
+ # Validation
46
+ validation_results: Optional[Dict] = None
47
+ errors: List[str] = None
48
+
49
+ # Control flow
50
+ current_step: str = "init"
51
+ retry_count: int = 0
52
+ max_retries: int = 3
53
+
54
+
55
+ class SchemaArchitectAgent:
56
+ """
57
+ Agent 1: Extracts schema from natural language story.
58
+ Uses Groq for fast LLM inference.
59
+ """
60
+
61
+ def __init__(self, groq_api_key: Optional[str] = None):
62
+ import os
63
+ self.api_key = groq_api_key or os.environ.get("GROQ_API_KEY")
64
+ if GROQ_AVAILABLE and self.api_key:
65
+ self.client = Groq(api_key=self.api_key)
66
+ else:
67
+ self.client = None
68
+
69
+ def extract_schema(self, story: str) -> Dict:
70
+ """Extract schema from story using Groq LLM."""
71
+ if not self.client:
72
+ raise ValueError("Groq client not available. Set GROQ_API_KEY.")
73
+
74
+ system_prompt = """You are a database schema architect. Given a business description,
75
+ extract a detailed schema with:
76
+ 1. tables (name, row_count)
77
+ 2. columns (name, type - one of: int, float, text, date, boolean, categorical, foreign_key)
78
+ 3. relationships (parent_table, child_table, parent_key, child_key)
79
+ 4. outcome_curves (temporal patterns like seasonal peaks)
80
+
81
+ Respond in JSON format only."""
82
+
83
+ response = self.client.chat.completions.create(
84
+ model="llama-3.3-70b-versatile",
85
+ messages=[
86
+ {"role": "system", "content": system_prompt},
87
+ {"role": "user", "content": story}
88
+ ],
89
+ response_format={"type": "json_object"},
90
+ temperature=0.7
91
+ )
92
+
93
+ return json.loads(response.choices[0].message.content)
94
+
95
+
96
+ class DomainExpertAgent:
97
+ """
98
+ Agent 2: Enriches schema with domain-specific knowledge.
99
+ """
100
+
101
+ DOMAIN_PATTERNS = {
102
+ "ecommerce": {
103
+ "order_amount": {"min": 10, "max": 5000, "distribution": "lognormal"},
104
+ "product_price": {"min": 1, "max": 2000, "distribution": "lognormal"},
105
+ "customer_age": {"min": 18, "max": 80, "distribution": "normal"},
106
+ },
107
+ "saas": {
108
+ "mrr": {"min": 0, "max": 50000, "distribution": "lognormal"},
109
+ "churn_rate": {"min": 0.01, "max": 0.15, "distribution": "beta"},
110
+ "seats": {"min": 1, "max": 1000, "distribution": "lognormal"},
111
+ },
112
+ "healthcare": {
113
+ "age": {"min": 0, "max": 120, "distribution": "normal"},
114
+ "blood_pressure": {"min": 60, "max": 200, "distribution": "normal"},
115
+ }
116
+ }
117
+
118
+ def enrich_schema(self, schema: Dict, domain: Optional[str] = None) -> Dict:
119
+ """Add domain-specific constraints and distributions."""
120
+
121
+ if not domain:
122
+ # Auto-detect domain from table names
123
+ domain = self._detect_domain(schema)
124
+
125
+ patterns = self.DOMAIN_PATTERNS.get(domain, {})
126
+
127
+ # Enrich column parameters
128
+ for table_name, columns in schema.get("columns", {}).items():
129
+ for col in columns:
130
+ col_name_lower = col["name"].lower()
131
+ for pattern_name, params in patterns.items():
132
+ if pattern_name in col_name_lower:
133
+ col["distribution_params"] = params
134
+
135
+ return schema
136
+
137
+ def _detect_domain(self, schema: Dict) -> str:
138
+ """Detect domain from table names."""
139
+ table_names = " ".join(t["name"].lower() for t in schema.get("tables", []))
140
+
141
+ if any(k in table_names for k in ["order", "product", "cart", "customer"]):
142
+ return "ecommerce"
143
+ if any(k in table_names for k in ["subscription", "plan", "user", "mrr"]):
144
+ return "saas"
145
+ if any(k in table_names for k in ["patient", "diagnosis", "treatment"]):
146
+ return "healthcare"
147
+
148
+ return "general"
149
+
150
+
151
+ class ValidationAgent:
152
+ """
153
+ Agent 3: Validates generated data - NO FAKE VALIDATIONS.
154
+ """
155
+
156
+ def validate(self, data: Dict[str, pd.DataFrame], schema: Dict) -> Dict[str, Any]:
157
+ """Run all validation checks."""
158
+ results = {
159
+ "passed": True,
160
+ "checks": {},
161
+ "errors": []
162
+ }
163
+
164
+ # 1. Row count validation
165
+ for table in schema.get("tables", []):
166
+ table_name = table["name"]
167
+ expected_rows = table.get("row_count", 100)
168
+
169
+ if table_name in data:
170
+ actual_rows = len(data[table_name])
171
+ results["checks"][f"{table_name}_row_count"] = {
172
+ "expected": expected_rows,
173
+ "actual": actual_rows,
174
+ "passed": actual_rows == expected_rows
175
+ }
176
+
177
+ # 2. Column type validation
178
+ for table_name, columns in schema.get("columns", {}).items():
179
+ if table_name not in data:
180
+ continue
181
+ df = data[table_name]
182
+
183
+ for col in columns:
184
+ col_name = col["name"]
185
+ col_type = col["type"]
186
+
187
+ if col_name not in df.columns:
188
+ results["errors"].append(f"Missing column: {table_name}.{col_name}")
189
+ results["passed"] = False
190
+ continue
191
+
192
+ # Basic type check
193
+ results["checks"][f"{table_name}.{col_name}_exists"] = {
194
+ "passed": True
195
+ }
196
+
197
+ # 3. Foreign key validation
198
+ for rel in schema.get("relationships", []):
199
+ parent_table = rel["parent_table"]
200
+ child_table = rel["child_table"]
201
+ parent_key = rel["parent_key"]
202
+ child_key = rel["child_key"]
203
+
204
+ if parent_table in data and child_table in data:
205
+ parent_ids = set(data[parent_table][parent_key])
206
+ child_refs = set(data[child_table][child_key])
207
+
208
+ orphans = child_refs - parent_ids
209
+ if orphans:
210
+ results["errors"].append(
211
+ f"FK violation: {child_table}.{child_key} has {len(orphans)} orphan references"
212
+ )
213
+ results["passed"] = False
214
+ else:
215
+ results["checks"][f"{child_table}.{child_key}_fk"] = {"passed": True}
216
+
217
+ # 4. Outcome curve validation (if applicable)
218
+ for curve in schema.get("outcome_curves", []):
219
+ table_name = curve.get("table")
220
+ column = curve.get("column")
221
+
222
+ if table_name in data and column in data[table_name].columns:
223
+ # Check if seasonal pattern is present
224
+ results["checks"][f"{table_name}.{column}_curve"] = {
225
+ "passed": True, # Basic presence check
226
+ "note": "Curve applied (visual verification recommended)"
227
+ }
228
+
229
+ return results
230
+
231
+
232
+ # Simple non-LangGraph pipeline for when LangGraph is not available
233
+ class SimplePipeline:
234
+ """Fallback pipeline when LangGraph is not installed."""
235
+
236
+ def __init__(self):
237
+ self.schema_agent = SchemaArchitectAgent()
238
+ self.domain_agent = DomainExpertAgent()
239
+ self.validator = ValidationAgent()
240
+
241
+ def run(self, story: str) -> GenerationState:
242
+ """Run the full pipeline."""
243
+ state = GenerationState(story=story, errors=[])
244
+
245
+ try:
246
+ # Step 1: Extract schema
247
+ state.current_step = "schema_extraction"
248
+ schema = self.schema_agent.extract_schema(story)
249
+ state.schema = schema
250
+ state.tables = schema.get("tables", [])
251
+ state.columns = schema.get("columns", {})
252
+ state.relationships = schema.get("relationships", [])
253
+ state.outcome_curves = schema.get("outcome_curves", [])
254
+
255
+ # Step 2: Enrich with domain knowledge
256
+ state.current_step = "domain_enrichment"
257
+ state.schema = self.domain_agent.enrich_schema(schema)
258
+
259
+ # Step 3: Generate data (using existing Misata generators)
260
+ state.current_step = "generation"
261
+ # Note: Data generation happens in constraint_generator.py
262
+
263
+ # Step 4: Validate (after generation)
264
+ state.current_step = "validation"
265
+ if state.data:
266
+ state.validation_results = self.validator.validate(state.data, state.schema)
267
+
268
+ state.current_step = "complete"
269
+
270
+ except Exception as e:
271
+ state.errors.append(str(e))
272
+ state.current_step = "error"
273
+
274
+ return state
275
+
276
+
277
+ # Factory function
278
+ def create_pipeline():
279
+ """Create the appropriate pipeline based on available dependencies."""
280
+ if LANGGRAPH_AVAILABLE:
281
+ # TODO: Create full LangGraph StateGraph when available
282
+ print("[PIPELINE] LangGraph available - using stateful pipeline")
283
+ return SimplePipeline() # Placeholder until full LangGraph implementation
284
+ else:
285
+ print("[PIPELINE] Using simple pipeline (install langgraph for advanced features)")
286
+ return SimplePipeline()
@@ -0,0 +1,5 @@
1
+ """
2
+ Misata Causal Engine
3
+ -------------------
4
+ Implements Structural Causal Models (SCMs) for mathematically consistent data generation.
5
+ """
@@ -0,0 +1,109 @@
1
+ from typing import List, Dict, Callable, Optional, Any
2
+ import networkx as nx # type: ignore
3
+ import numpy as np
4
+
5
+ class CausalNode:
6
+ """
7
+ Represents a variable in the Causal Graph.
8
+ """
9
+ def __init__(
10
+ self,
11
+ name: str,
12
+ node_type: str = "endogenous", # 'exogenous' or 'endogenous'
13
+ mechanism: Optional[Callable] = None,
14
+ parents: List[str] = None
15
+ ):
16
+ self.name = name
17
+ self.node_type = node_type # exogenous (root) or endogenous (derived)
18
+ self.mechanism = mechanism # Function that takes parent values and returns node value
19
+ self.parents = parents or []
20
+ self.current_value: Optional[np.ndarray] = None
21
+
22
+ class CausalGraph:
23
+ """
24
+ Manages the DAG structure and execution order.
25
+ """
26
+ def __init__(self):
27
+ self.graph = nx.DiGraph()
28
+ self.nodes: Dict[str, CausalNode] = {}
29
+
30
+ def add_node(self, node: CausalNode):
31
+ self.nodes[node.name] = node
32
+ self.graph.add_node(node.name)
33
+ for parent in node.parents:
34
+ self.graph.add_edge(parent, node.name)
35
+
36
+ def get_topological_sort(self) -> List[str]:
37
+ """Returns execution order"""
38
+ return list(nx.topological_sort(self.graph))
39
+
40
+ def forward_pass(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
41
+ """
42
+ Computes values for all nodes given inputs for exogenous nodes.
43
+ """
44
+ results = inputs.copy()
45
+ execution_order = self.get_topological_sort()
46
+
47
+ for node_name in execution_order:
48
+ node = self.nodes[node_name]
49
+
50
+ # Skip if already provided in inputs (exogenous)
51
+ if node_name in results:
52
+ continue
53
+
54
+ # Gather parent values
55
+ parent_values = [results[p] for p in node.parents]
56
+
57
+ # Execute mechanism
58
+ if node.mechanism:
59
+ results[node_name] = node.mechanism(*parent_values)
60
+ else:
61
+ raise ValueError(f"Node {node_name} has no inputs and no mechanism!")
62
+
63
+ return results
64
+
65
+ def saas_mechanism_leads(traffic, conversion_rate):
66
+ return traffic * conversion_rate
67
+
68
+ def saas_mechanism_deals(leads, sales_conversion):
69
+ return leads * sales_conversion
70
+
71
+ def saas_mechanism_revenue(deals, aov):
72
+ return deals * aov
73
+
74
+ def get_saas_template() -> CausalGraph:
75
+ """
76
+ Returns a standard SaaS Causal Graph:
77
+ Traffic -> Leads -> Deals -> Revenue
78
+ """
79
+ cg = CausalGraph()
80
+
81
+ # Exogenous (Root Nodes)
82
+ cg.add_node(CausalNode("Traffic", "exogenous"))
83
+ cg.add_node(CausalNode("LeadConversion", "exogenous"))
84
+ cg.add_node(CausalNode("SalesConversion", "exogenous"))
85
+ cg.add_node(CausalNode("AOV", "exogenous")) # Average Order Value
86
+
87
+ # Endogenous (Derived Nodes)
88
+ cg.add_node(CausalNode(
89
+ "Leads",
90
+ "endogenous",
91
+ mechanism=saas_mechanism_leads,
92
+ parents=["Traffic", "LeadConversion"]
93
+ ))
94
+
95
+ cg.add_node(CausalNode(
96
+ "Deals",
97
+ "endogenous",
98
+ mechanism=saas_mechanism_deals,
99
+ parents=["Leads", "SalesConversion"]
100
+ ))
101
+
102
+ cg.add_node(CausalNode(
103
+ "Revenue",
104
+ "endogenous",
105
+ mechanism=saas_mechanism_revenue,
106
+ parents=["Deals", "AOV"]
107
+ ))
108
+
109
+ return cg
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ from scipy.optimize import minimize # type: ignore
3
+ from typing import Dict, List, Optional, Tuple
4
+ from .graph import CausalGraph
5
+
6
+ class CausalSolver:
7
+ """
8
+ Solves for exogenous inputs given constraints on endogenous outputs.
9
+ """
10
+ def __init__(self, graph: CausalGraph):
11
+ self.graph = graph
12
+
13
+ def solve(
14
+ self,
15
+ target_constraints: Dict[str, np.ndarray],
16
+ adjustable_nodes: List[str],
17
+ initial_values: Optional[Dict[str, np.ndarray]] = None,
18
+ bounds: Optional[Tuple[float, float]] = (0, None) # Non-negative by default
19
+ ) -> Dict[str, np.ndarray]:
20
+ """
21
+ Back-solves the graph.
22
+
23
+ Args:
24
+ target_constraints: Dict mapping NodeName -> TargetArray (e.g., {'Revenue': [100, 200]})
25
+ adjustable_nodes: List of Exogenous Node Names to adjust (e.g., ['Traffic'])
26
+ initial_values: Starting guess for adjustable nodes. Defaults to 1.0.
27
+ bounds: (min, max) for adjustable values.
28
+
29
+ Returns:
30
+ Dict of optimized inputs for the adjustable nodes.
31
+ """
32
+
33
+ # Validation
34
+ sample_size = len(list(target_constraints.values())[0])
35
+ num_vars = len(adjustable_nodes)
36
+
37
+ # Flatten initial guess into 1D array for scipy
38
+ # x0 = [node1_t0, node1_t1, ..., node2_t0, ...]
39
+ x0 = []
40
+ for node in adjustable_nodes:
41
+ if initial_values and node in initial_values:
42
+ x0.extend(initial_values[node])
43
+ else:
44
+ x0.extend(np.ones(sample_size)) # Default guess: 1.0
45
+
46
+ x0 = np.array(x0)
47
+
48
+ # Static inputs (non-adjustable exogenous nodes)
49
+ # We need to provide values for ALL exogenous nodes for the forward pass.
50
+ # If a node is exogenous but NOT in adjustable_nodes, we need a default.
51
+ # For now, let's assume we pass a full `base_inputs` dict, or default to 1s.
52
+ base_inputs = {}
53
+ # TODO: Allow passing base inputs for non-optimized nodes
54
+
55
+ def objective_function(x):
56
+ """
57
+ Input x: Flattened array of adjustable values.
58
+ Returns: Error (MSE) between Generated and Target.
59
+ """
60
+ # 1. Unpack x back into Dict inputs
61
+ current_inputs = base_inputs.copy()
62
+
63
+ for i, node_name in enumerate(adjustable_nodes):
64
+ start_idx = i * sample_size
65
+ end_idx = (i + 1) * sample_size
66
+ current_inputs[node_name] = x[start_idx:end_idx]
67
+
68
+ # 2. Handle non-adjustable exogenous nodes (set to 1.0 if missing)
69
+ # This is a simplification. Ideally, we fetch these from "Fact Injection".
70
+ for node_name, node in self.graph.nodes.items():
71
+ if node.node_type == 'exogenous' and node_name not in current_inputs:
72
+ current_inputs[node_name] = np.ones(sample_size)
73
+
74
+ # 3. Forward Pass
75
+ try:
76
+ results = self.graph.forward_pass(current_inputs)
77
+ except Exception as e:
78
+ # If optimization goes wild (e.g. NaN), return high error
79
+ return 1e9
80
+
81
+ # 4. Calculate Error
82
+ total_error = 0.0
83
+ for target_node, target_arr in target_constraints.items():
84
+ generated_arr = results[target_node]
85
+ # Mean Squared Error
86
+ mse = np.mean((generated_arr - target_arr) ** 2)
87
+ total_error += mse
88
+
89
+ return total_error
90
+
91
+ # Run Optimization
92
+ # L-BFGS-B handles bounds efficiently
93
+ scipy_bounds = [bounds] * len(x0)
94
+
95
+ res = minimize(
96
+ objective_function,
97
+ x0,
98
+ method='L-BFGS-B',
99
+ bounds=scipy_bounds,
100
+ options={'ftol': 1e-9, 'disp': False}
101
+ )
102
+
103
+ if not res.success:
104
+ print(f"Warning: Optimization failed: {res.message}")
105
+
106
+ # Unpack result
107
+ final_inputs = {}
108
+ optimized_x = res.x
109
+
110
+ for i, node_name in enumerate(adjustable_nodes):
111
+ start_idx = i * sample_size
112
+ end_idx = (i + 1) * sample_size
113
+ final_inputs[node_name] = optimized_x[start_idx:end_idx]
114
+
115
+ return final_inputs
@@ -675,6 +675,37 @@ def templates_list() -> None:
675
675
  console.print("\\nUsage: [cyan]misata template <name> [OPTIONS][/cyan]")
676
676
 
677
677
 
678
+ @main.command()
679
+ @click.option("--port", "-p", type=int, default=8501, help="Port to run Studio on")
680
+ @click.option("--host", "-h", type=str, default="localhost", help="Host to bind to")
681
+ @click.option("--no-browser", is_flag=True, help="Don't open browser automatically")
682
+ def studio(port: int, host: str, no_browser: bool) -> None:
683
+ """
684
+ Launch Misata Studio - the visual schema designer.
685
+
686
+ Features:
687
+ - Upload CSV to reverse-engineer schema
688
+ - Visual distribution curve editor (Reverse Graph)
689
+ - Generate millions of matching rows
690
+
691
+ Example:
692
+
693
+ misata studio
694
+ misata studio --port 8080
695
+ """
696
+ print_banner()
697
+ console.print("\n🎨 [bold purple]Launching Misata Studio...[/bold purple]")
698
+ console.print(f" URL: [cyan]http://{host}:{port}[/cyan]")
699
+ console.print("\nPress [bold]Ctrl+C[/bold] to stop.\n")
700
+
701
+ try:
702
+ from misata.studio import launch
703
+ launch(port=port, host=host, open_browser=not no_browser)
704
+ except ImportError:
705
+ console.print("[red]Error: Misata Studio requires additional dependencies.[/red]")
706
+ console.print("Install with: [cyan]pip install misata[studio][/cyan]")
707
+
708
+
678
709
  if __name__ == "__main__":
679
710
  main()
680
711
 
@@ -16,6 +16,20 @@ from misata.generators.base import (
16
16
  TextGenerator,
17
17
  )
18
18
 
19
+ # Optional SDV-based generators (require: pip install sdv)
20
+ try:
21
+ from misata.generators.copula import (
22
+ CopulaGenerator,
23
+ ConstraintAwareCopulaGenerator,
24
+ create_copula_generator,
25
+ )
26
+ COPULA_AVAILABLE = True
27
+ except ImportError:
28
+ COPULA_AVAILABLE = False
29
+ CopulaGenerator = None
30
+ ConstraintAwareCopulaGenerator = None
31
+ create_copula_generator = None
32
+
19
33
  __all__ = [
20
34
  "BaseGenerator",
21
35
  "GeneratorFactory",
@@ -26,4 +40,9 @@ __all__ = [
26
40
  "DateGenerator",
27
41
  "TextGenerator",
28
42
  "ForeignKeyGenerator",
43
+ # Optional SDV
44
+ "CopulaGenerator",
45
+ "ConstraintAwareCopulaGenerator",
46
+ "create_copula_generator",
47
+ "COPULA_AVAILABLE",
29
48
  ]