misata 0.1.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
misata/hybrid.py ADDED
@@ -0,0 +1,406 @@
+ """
+ Hybrid Learning Module for Misata.
+ 
+ This module provides:
+ - Learning distributions from sample real data
+ - Combining LLM schema generation with statistical learning
+ - Detecting and replicating correlation patterns
+ 
+ This addresses the critic's concern: "If user HAS data, learn from it"
+ """
+ 
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+ 
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+ 
+ 
+ @dataclass
+ class LearnedDistribution:
+     """A distribution learned from real data."""
+     column_name: str
+     dtype: str
+     distribution_type: str
+     parameters: Dict[str, Any]
+     sample_stats: Dict[str, float]
+ 
+     def to_schema_params(self) -> Dict[str, Any]:
+         """Convert to Misata schema distribution_params."""
+         return {
+             "distribution": self.distribution_type,
+             **self.parameters
+         }
+ 
+ 
+ @dataclass
+ class LearnedCorrelation:
+     """A correlation learned between columns."""
+     column1: str
+     column2: str
+     correlation: float
+     relationship_type: str  # linear, monotonic, categorical
+     strength: str  # weak, moderate, strong
+ 
+ 
+ class DistributionLearner:
+     """
+     Learn distributions from sample data.
+ 
+     Analyzes real data to:
+     1. Detect best-fit distributions
+     2. Extract parameters
+     3. Identify correlations
+     """
+ 
+     def __init__(self):
+         self.distributions: Dict[str, LearnedDistribution] = {}
+         self.correlations: List[LearnedCorrelation] = []
+ 
+     def fit(self, df: pd.DataFrame, table_name: str = "data") -> Dict[str, Any]:
+         """
+         Learn from a DataFrame.
+ 
+         Args:
+             df: Sample data to learn from
+             table_name: Name for the learned table
+ 
+         Returns:
+             Schema configuration matching the learned patterns
+         """
+         columns = []
+ 
+         for col_name in df.columns:
+             col_data = df[col_name]
+             learned = self._learn_column(col_name, col_data)
+ 
+             if learned:
+                 self.distributions[f"{table_name}.{col_name}"] = learned
+                 columns.append({
+                     "name": col_name,
+                     "type": learned.dtype,
+                     "distribution_params": learned.to_schema_params()
+                 })
+ 
+         # Learn correlations
+         self._learn_correlations(df, table_name)
+ 
+         return {
+             "tables": [{"name": table_name, "row_count": len(df)}],
+             "columns": {table_name: columns},
+             "relationships": [],
+             "events": []
+         }
+ 
+     def _learn_column(self, name: str, data: pd.Series) -> Optional[LearnedDistribution]:
+         """Learn distribution for a single column."""
+         # Skip if mostly null
+         if data.isna().mean() > 0.5:
+             return None
+ 
+         data = data.dropna()
+ 
+         if len(data) < 10:
+             return None
+ 
+         # Detect dtype and appropriate distribution
+         if pd.api.types.is_numeric_dtype(data):
+             return self._learn_numeric(name, data)
+         elif pd.api.types.is_datetime64_any_dtype(data):
+             return self._learn_datetime(name, data)
+         else:
+             return self._learn_categorical(name, data)
+ 
+     def _learn_numeric(self, name: str, data: pd.Series) -> LearnedDistribution:
+         """Learn distribution for numeric column."""
+         values = data.values.astype(float)
+ 
+         # Calculate basic stats
+         mean = float(np.mean(values))
+         std = float(np.std(values))
+         min_val = float(np.min(values))
+         max_val = float(np.max(values))
+         skewness = float(stats.skew(values))
+ 
+         # Test for different distributions
+         # Each candidate yields a (KS statistic, p-value) pair from a
+         # Kolmogorov-Smirnov test with fitted parameters; the exponential
+         # test only applies to strictly positive data
+         distributions_to_test = [
+             ('normal', lambda: stats.kstest(values, 'norm', args=(mean, std))),
+             ('uniform', lambda: stats.kstest(values, 'uniform', args=(min_val, max_val - min_val))),
+             ('exponential', lambda: stats.kstest(values, 'expon', args=(0, mean)) if (values > 0).all() else (1.0, 0.0)),
+         ]
+ 
+         best_dist = 'normal'
+         best_p = 0
+ 
+         for dist_name, test_func in distributions_to_test:
+             try:
+                 stat, p = test_func()
+                 if p > best_p:
+                     best_p = p
+                     best_dist = dist_name
+             except Exception:
+                 continue
+ 
+         # Build parameters based on best distribution
+         if best_dist == 'normal':
+             params = {"mean": round(mean, 2), "std": round(std, 2)}
+         elif best_dist == 'uniform':
+             params = {"min": round(min_val, 2), "max": round(max_val, 2)}
+         elif best_dist == 'exponential':
+             scale = float(np.mean(values))
+             params = {"scale": round(scale, 2)}
+         else:
+             params = {"mean": round(mean, 2), "std": round(std, 2)}
+ 
+         # Determine if int or float
+         is_int = np.allclose(values, np.round(values))
+         dtype = "int" if is_int else "float"
+ 
+         return LearnedDistribution(
+             column_name=name,
+             dtype=dtype,
+             distribution_type=best_dist,
+             parameters=params,
+             sample_stats={
+                 "mean": mean,
+                 "std": std,
+                 "min": min_val,
+                 "max": max_val,
+                 "skewness": skewness
+             }
+         )
+ 
+     def _learn_datetime(self, name: str, data: pd.Series) -> LearnedDistribution:
+         """Learn distribution for datetime column."""
+         min_date = data.min()
+         max_date = data.max()
+ 
+         return LearnedDistribution(
+             column_name=name,
+             dtype="date",
+             distribution_type="uniform",
+             parameters={
+                 "start": str(min_date.date()) if hasattr(min_date, 'date') else str(min_date),
+                 "end": str(max_date.date()) if hasattr(max_date, 'date') else str(max_date)
+             },
+             sample_stats={
+                 "count": len(data),
+                 "range_days": (max_date - min_date).days if hasattr(max_date - min_date, 'days') else 0
+             }
+         )
+ 
+     def _learn_categorical(self, name: str, data: pd.Series) -> LearnedDistribution:
+         """Learn distribution for categorical column."""
+         value_counts = data.value_counts(normalize=True)
+ 
+         # If too many unique values, might be text
+         if len(value_counts) > 50:
+             # Detect if it's an email, name, etc.
+             sample = str(data.iloc[0]).lower()
+             if '@' in sample:
+                 return LearnedDistribution(
+                     column_name=name,
+                     dtype="text",
+                     distribution_type="pattern",
+                     parameters={"text_type": "email"},
+                     sample_stats={"unique_count": len(value_counts)}
+                 )
+             else:
+                 return LearnedDistribution(
+                     column_name=name,
+                     dtype="text",
+                     distribution_type="pattern",
+                     parameters={"text_type": "word"},
+                     sample_stats={"unique_count": len(value_counts)}
+                 )
+ 
+         choices = list(value_counts.index[:20])  # Top 20
+         probs = list(value_counts.values[:20])
+ 
+         # Normalize probabilities
+         total = sum(probs)
+         probs = [p / total for p in probs]
+ 
+         return LearnedDistribution(
+             column_name=name,
+             dtype="categorical",
+             distribution_type="categorical",
+             parameters={
+                 "choices": choices,
+                 "probabilities": [round(p, 3) for p in probs]
+             },
+             sample_stats={
+                 "unique_count": len(value_counts),
+                 "top_value": choices[0] if choices else None,
+                 "entropy": float(stats.entropy(probs))
+             }
+         )
+ 
+     def _learn_correlations(self, df: pd.DataFrame, table_name: str):
+         """Learn correlations between numeric columns."""
+         numeric_cols = df.select_dtypes(include=[np.number]).columns
+ 
+         if len(numeric_cols) < 2:
+             return
+ 
+         for i, col1 in enumerate(numeric_cols):
+             for col2 in numeric_cols[i+1:]:
250
+ try:
251
+ corr, p_value = stats.pearsonr(
252
+ df[col1].dropna(),
253
+ df[col2].dropna()
254
+ )
+ 
+                     if abs(corr) > 0.3:  # Non-trivial correlation
+                         strength = (
+                             "strong" if abs(corr) > 0.7 else
+                             "moderate" if abs(corr) > 0.5 else
+                             "weak"
+                         )
+ 
+                         self.correlations.append(LearnedCorrelation(
+                             column1=col1,
+                             column2=col2,
+                             correlation=round(corr, 3),
+                             relationship_type="linear",
+                             strength=strength
+                         ))
+                 except Exception:
+                     continue
+ 
+     def get_correlation_report(self) -> str:
+         """Get human-readable correlation report."""
+         if not self.correlations:
+             return "No significant correlations detected."
+ 
+         lines = ["Detected Correlations:", "-" * 40]
+ 
+         for corr in sorted(self.correlations, key=lambda x: abs(x.correlation), reverse=True):
+             direction = "positive" if corr.correlation > 0 else "negative"
+             lines.append(
+                 f" {corr.column1} ↔ {corr.column2}: "
+                 f"{corr.correlation:+.3f} ({corr.strength} {direction})"
+             )
+ 
+         return "\n".join(lines)
+ 
+ 
+ class HybridSchemaGenerator:
+     """
+     Combines LLM generation with statistical learning.
+ 
+     If user provides sample data:
+     1. Learn distributions from sample
+     2. Use LLM for schema structure
+     3. Override LLM params with learned params
+     """
+ 
+     def __init__(self):
+         self.learner = DistributionLearner()
+         self.learned_schema: Optional[Dict] = None
+ 
+     def learn_from_sample(self, sample_data: Dict[str, pd.DataFrame]):
+         """
+         Learn from sample data.
+ 
+         Args:
+             sample_data: Dict of table_name -> DataFrame
+         """
+         combined_schema = {
+             "tables": [],
+             "columns": {},
+             "relationships": [],
+             "events": []
+         }
+ 
+         for table_name, df in sample_data.items():
+             learned = self.learner.fit(df, table_name)
+             combined_schema["tables"].extend(learned["tables"])
+             combined_schema["columns"].update(learned["columns"])
+ 
+         self.learned_schema = combined_schema
+ 
+     def enhance_llm_schema(self, llm_schema: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Enhance LLM-generated schema with learned patterns.
+ 
+         Args:
+             llm_schema: Schema from LLM
+ 
+         Returns:
+             Enhanced schema with learned distributions
+         """
+         if not self.learned_schema:
+             return llm_schema
+ 
+         enhanced = json.loads(json.dumps(llm_schema))  # Deep copy
+ 
+         # Override columns with learned distributions
+         for table_name, learned_cols in self.learned_schema["columns"].items():
+             if table_name in enhanced.get("columns", {}):
+                 for learned_col in learned_cols:
+                     col_name = learned_col["name"]
+ 
+                     # Find matching column in LLM schema
+                     for i, llm_col in enumerate(enhanced["columns"][table_name]):
+                         if llm_col["name"] == col_name:
+                             # Keep LLM structure, use learned params
+                             enhanced["columns"][table_name][i]["distribution_params"] = \
+                                 learned_col["distribution_params"]
+                             break
+ 
+         return enhanced
+ 
+     def generate_schema_from_csv(self, csv_path: str) -> Dict[str, Any]:
+         """
+         Generate schema from a CSV file.
+ 
+         Args:
+             csv_path: Path to CSV file
+ 
+         Returns:
+             Complete schema configuration
+         """
+         df = pd.read_csv(csv_path)
+         # pathlib handles both / and \ path separators
+         table_name = Path(csv_path).stem
368
+
369
+ return self.learner.fit(df, table_name)
370
+
371
+
372
+ # Convenience function for CLI
373
+ def learn_from_csv(csv_path: str) -> str:
374
+ """Learn and return schema from CSV file."""
375
+ generator = HybridSchemaGenerator()
376
+ schema = generator.generate_schema_from_csv(csv_path)
377
+
378
+ report = [
379
+ "=" * 50,
380
+ "MISATA HYBRID LEARNING REPORT",
381
+ "=" * 50,
382
+ f"Source: {csv_path}",
383
+ f"Tables: {len(schema['tables'])}",
384
+ "",
385
+ "Learned Columns:"
386
+ ]
387
+
388
+ for table, cols in schema["columns"].items():
389
+ report.append(f"\n{table}:")
390
+ for col in cols:
391
+ params = col["distribution_params"]
392
+ report.append(f" - {col['name']}: {col['type']} ({params.get('distribution', 'n/a')})")
393
+
394
+ report.append("")
395
+ report.append(generator.learner.get_correlation_report())
396
+ report.append("=" * 50)
397
+
398
+ return "\n".join(report)