misata 0.1.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- misata/__init__.py +48 -0
- misata/api.py +460 -0
- misata/audit.py +415 -0
- misata/benchmark.py +376 -0
- misata/cli.py +680 -0
- misata/codegen.py +153 -0
- misata/curve_fitting.py +106 -0
- misata/customization.py +256 -0
- misata/feedback.py +433 -0
- misata/formulas.py +362 -0
- misata/generators.py +247 -0
- misata/hybrid.py +398 -0
- misata/llm_parser.py +493 -0
- misata/noise.py +346 -0
- misata/schema.py +252 -0
- misata/semantic.py +185 -0
- misata/simulator.py +742 -0
- misata/story_parser.py +425 -0
- misata/templates/__init__.py +444 -0
- misata/validation.py +313 -0
- misata-0.1.0b0.dist-info/METADATA +291 -0
- misata-0.1.0b0.dist-info/RECORD +25 -0
- misata-0.1.0b0.dist-info/WHEEL +5 -0
- misata-0.1.0b0.dist-info/entry_points.txt +2 -0
- misata-0.1.0b0.dist-info/top_level.txt +1 -0
misata/feedback.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Human-in-the-Loop Feedback System for Misata.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- Schema correction collection and storage
|
|
6
|
+
- Learning from user feedback to improve future generations
|
|
7
|
+
- Persistent feedback database (SQLite)
|
|
8
|
+
- Feedback-aware prompt enhancement
|
|
9
|
+
|
|
10
|
+
This addresses the critic's concern: "No learning/feedback loop"
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import sqlite3
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class SchemaCorrection:
    """Record describing one user-supplied fix to a generated column.

    Attributes:
        original_column: Column definition before the correction.
        corrected_column: Column definition after the correction.
        table_name: Name of the table the column belongs to.
        reason: Free-text explanation of the correction.
        timestamp: When the correction was made, as a string
            (presumably ISO-8601 - confirm against callers).
        story_context: Optional story text the schema was generated from;
            ``None`` when not supplied.
    """

    # Required fields (order matters: it defines the positional constructor).
    original_column: Dict[str, Any]
    corrected_column: Dict[str, Any]
    table_name: str
    reason: str
    timestamp: str

    # Optional field.
    story_context: Optional[str] = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class FeedbackStats:
    """Aggregate counts describing the feedback stored so far.

    Attributes:
        total_corrections: Number of rows in the ``corrections`` table.
        unique_patterns: Number of distinct ``pattern_key`` values in the
            ``patterns`` table.
        most_common_fixes: Up to five ``(column_name, count)`` pairs, most
            frequently corrected column first.
        columns_corrected: Number of distinct column names that were corrected.
        tables_affected: Number of distinct table names that were corrected.
    """

    total_corrections: int
    unique_patterns: int
    most_common_fixes: List[Tuple[str, int]]
    columns_corrected: int
    tables_affected: int


class FeedbackDatabase:
    """
    Persistent storage for user feedback using SQLite.

    Stores schema corrections that can be used to:
    1. Improve prompts over time
    2. Auto-fix common mistakes
    3. Learn industry-specific patterns

    Every public method opens its own short-lived connection to
    ``self.db_path`` and always closes it, even when an error occurs.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize feedback database.

        Args:
            db_path: Path to SQLite database. Defaults to ~/.misata/feedback.db
        """
        if db_path is None:
            misata_dir = Path.home() / ".misata"
            # parents=True so a not-yet-existing home directory does not
            # make the default location unusable.
            misata_dir.mkdir(parents=True, exist_ok=True)
            db_path = str(misata_dir / "feedback.db")

        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``corrections``, ``patterns`` and ``sessions`` tables if missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Corrections table: one row per submitted correction.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS corrections (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    table_name TEXT NOT NULL,
                    column_name TEXT NOT NULL,
                    original_type TEXT,
                    corrected_type TEXT,
                    original_params TEXT,
                    corrected_params TEXT,
                    reason TEXT,
                    story_context TEXT,
                    industry TEXT
                )
            """)

            # Patterns table for learned rules
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS patterns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    pattern_type TEXT NOT NULL,
                    pattern_key TEXT NOT NULL,
                    pattern_value TEXT NOT NULL,
                    confidence REAL,
                    occurrence_count INTEGER DEFAULT 1,
                    last_updated TEXT
                )
            """)

            # Sessions table for audit logging (created here; nothing in
            # this class writes to it).
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT UNIQUE NOT NULL,
                    start_time TEXT NOT NULL,
                    end_time TEXT,
                    story TEXT,
                    schema_generated TEXT,
                    tables_count INTEGER,
                    rows_generated INTEGER,
                    corrections_made INTEGER DEFAULT 0
                )
            """)

            conn.commit()
        finally:
            conn.close()

    def add_correction(
        self,
        table_name: str,
        column_name: str,
        original: Dict[str, Any],
        corrected: Dict[str, Any],
        reason: str = "",
        story_context: str = "",
        industry: str = ""
    ) -> int:
        """
        Store a schema correction.

        The correction row and the learned-pattern update are committed
        together; if either step raises, the connection is closed without
        committing, so neither is persisted.

        Args:
            table_name: Name of the table
            column_name: Name of the corrected column
            original: Original column definition from LLM
            corrected: User's corrected definition
            reason: Why the correction was made
            story_context: Original story that generated this
            industry: Industry context (saas, healthcare, etc.)

        Returns:
            ID of the inserted correction
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute("""
                INSERT INTO corrections (
                    timestamp, table_name, column_name,
                    original_type, corrected_type,
                    original_params, corrected_params,
                    reason, story_context, industry
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                datetime.now().isoformat(),
                table_name,
                column_name,
                original.get("type"),
                corrected.get("type"),
                # Distribution parameters are stored as JSON text.
                json.dumps(original.get("distribution_params", {})),
                json.dumps(corrected.get("distribution_params", {})),
                reason,
                story_context,
                industry
            ))

            correction_id = cursor.lastrowid

            # Update learned patterns
            self._update_patterns(cursor, column_name, original, corrected)

            conn.commit()
        finally:
            conn.close()

        return correction_id

    def _update_patterns(
        self,
        cursor: sqlite3.Cursor,
        column_name: str,
        original: Dict,
        corrected: Dict
    ):
        """
        Learn patterns from corrections.

        Patterns are keyed by the lower-cased column name. The first
        correction for a key inserts a row; later ones bump
        ``occurrence_count`` and overwrite the stored suggestion with the
        most recent correction. The caller owns the transaction.

        Args:
            cursor: Cursor on the connection whose transaction is in progress
            column_name: Name of the corrected column
            original: Original column definition (currently unused)
            corrected: User's corrected definition
        """
        # Pattern: column name -> correct type
        pattern_key = column_name.lower()
        pattern_value = json.dumps({
            "type": corrected.get("type"),
            "params": corrected.get("distribution_params", {})
        })

        # Check if pattern exists
        cursor.execute("""
            SELECT id, occurrence_count FROM patterns
            WHERE pattern_type = 'column_name' AND pattern_key = ?
        """, (pattern_key,))

        existing = cursor.fetchone()

        if existing:
            # Update occurrence count
            cursor.execute("""
                UPDATE patterns
                SET occurrence_count = occurrence_count + 1,
                    pattern_value = ?,
                    last_updated = ?
                WHERE id = ?
            """, (pattern_value, datetime.now().isoformat(), existing[0]))
        else:
            # Insert new pattern
            cursor.execute("""
                INSERT INTO patterns (
                    pattern_type, pattern_key, pattern_value,
                    confidence, last_updated
                ) VALUES (?, ?, ?, ?, ?)
            """, ('column_name', pattern_key, pattern_value, 0.5, datetime.now().isoformat()))

    def get_learned_patterns(self, min_occurrences: int = 2) -> Dict[str, Dict]:
        """
        Get patterns learned from corrections.

        Args:
            min_occurrences: Minimum times a pattern was seen

        Returns:
            Dict mapping lower-cased column names to a dict with keys
            ``suggestion`` (``{"type": ..., "params": ...}``), ``confidence``
            (``min(0.9, 0.5 + 0.1 * occurrences)``) and ``occurrences``.
            Entries are inserted most-frequent first.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT pattern_key, pattern_value, occurrence_count
                FROM patterns
                WHERE pattern_type = 'column_name' AND occurrence_count >= ?
                ORDER BY occurrence_count DESC
            """, (min_occurrences,))

            patterns = {}
            for key, value, count in cursor.fetchall():
                patterns[key] = {
                    "suggestion": json.loads(value),
                    # Confidence is derived from the count at read time; the
                    # ``confidence`` column is not consulted.
                    "confidence": min(0.9, 0.5 + count * 0.1),
                    "occurrences": count
                }
        finally:
            conn.close()

        return patterns

    def get_stats(self) -> FeedbackStats:
        """Get statistics about collected feedback."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Total corrections
            cursor.execute("SELECT COUNT(*) FROM corrections")
            total = cursor.fetchone()[0]

            # Unique patterns
            cursor.execute("SELECT COUNT(DISTINCT pattern_key) FROM patterns")
            patterns = cursor.fetchone()[0]

            # Most common column fixes
            cursor.execute("""
                SELECT column_name, COUNT(*) as cnt
                FROM corrections
                GROUP BY column_name
                ORDER BY cnt DESC
                LIMIT 5
            """)
            common_fixes = cursor.fetchall()

            # Unique columns
            cursor.execute("SELECT COUNT(DISTINCT column_name) FROM corrections")
            unique_cols = cursor.fetchone()[0]

            # Unique tables
            cursor.execute("SELECT COUNT(DISTINCT table_name) FROM corrections")
            unique_tables = cursor.fetchone()[0]
        finally:
            conn.close()

        return FeedbackStats(
            total_corrections=total,
            unique_patterns=patterns,
            most_common_fixes=common_fixes,
            columns_corrected=unique_cols,
            tables_affected=unique_tables
        )

    def generate_prompt_enhancement(self) -> str:
        """
        Generate prompt enhancement based on learned corrections.

        This is injected into the LLM prompt to improve future generations.

        Returns:
            A header line, a blank line, and one rule per pattern for at
            most ten patterns (every pattern seen at least once qualifies),
            or an empty string when nothing has been learned yet.
        """
        patterns = self.get_learned_patterns(min_occurrences=1)

        if not patterns:
            return ""

        lines = [
            "Based on previous user corrections, apply these rules:",
            ""
        ]

        for col_name, data in list(patterns.items())[:10]:
            suggestion = data["suggestion"]
            lines.append(f"- Column '{col_name}': use type '{suggestion.get('type')}' with params {suggestion.get('params')}")

        return "\n".join(lines)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class HumanFeedbackLoop:
    """
    Main interface for human-in-the-loop feedback.

    Provides methods to:
    1. Collect corrections from users
    2. Apply learned patterns to new schemas
    3. Generate enhanced prompts

    All persistence is delegated to the ``FeedbackDatabase`` held in
    ``self.db``.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Create the loop on top of a feedback database.

        Args:
            db_path: Path passed through to ``FeedbackDatabase``; ``None``
                lets the database pick its default location.
        """
        self.db = FeedbackDatabase(db_path)

    def submit_correction(
        self,
        table_name: str,
        column_name: str,
        original: Dict[str, Any],
        corrected: Dict[str, Any],
        reason: str = "",
        context: str = "",
        industry: str = ""
    ) -> Dict[str, Any]:
        """
        Submit a schema correction.

        Args:
            table_name: Name of the table
            column_name: Name of the corrected column
            original: Original column definition
            corrected: User's corrected definition
            reason: Why the correction was made
            context: Story text, stored as the correction's ``story_context``
            industry: Industry context, forwarded to the database

        Returns:
            Confirmation dict with the new correction ``id``, a ``message``
            and ``pattern_learned`` (the lower-cased column name).
        """
        correction_id = self.db.add_correction(
            table_name=table_name,
            column_name=column_name,
            original=original,
            corrected=corrected,
            reason=reason,
            story_context=context,
            industry=industry
        )

        return {
            "id": correction_id,
            "message": "Correction recorded. Misata will learn from this.",
            "pattern_learned": column_name.lower()
        }

    def apply_learned_patterns(
        self,
        schema: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], List[str]]:
        """
        Apply learned patterns to improve a schema.

        The schema is modified in place and also returned. A column is
        rewritten only when a pattern exists for its lower-cased name, the
        pattern's confidence is above 0.6, and the suggested type differs
        from the current one. Columns without a name are left untouched.

        Args:
            schema: Schema to enhance; ``schema["columns"]`` maps table names
                to lists of column dicts

        Returns:
            (enhanced_schema, list of changes made)
        """
        patterns = self.db.get_learned_patterns()
        changes: List[str] = []

        columns = schema.get("columns", {})

        for table_name, cols in columns.items():
            for col in cols:
                name = col.get("name")
                if not name:
                    # A name-keyed pattern cannot meaningfully target a
                    # nameless column (and there is no name to report).
                    continue

                pattern = patterns.get(name.lower())
                if pattern is None or not pattern["confidence"] > 0.6:
                    continue

                suggestion = pattern["suggestion"]
                old_type = col.get("type")
                new_type = suggestion.get("type")
                if old_type == new_type:
                    continue

                # Apply correction
                col["type"] = new_type
                col["distribution_params"] = suggestion.get("params", {})
                changes.append(
                    f"Applied learned pattern to {table_name}.{name}: "
                    f"{old_type} -> {new_type} (confidence: {pattern['confidence']:.0%})"
                )

        schema["columns"] = columns
        return schema, changes

    def get_enhanced_prompt(self) -> str:
        """Get prompt enhancement from learned patterns."""
        return self.db.generate_prompt_enhancement()

    def get_feedback_report(self) -> str:
        """Get a summary of feedback collected, as a multi-line string."""
        stats = self.db.get_stats()

        lines = [
            "=" * 50,
            "MISATA FEEDBACK LEARNING REPORT",
            "=" * 50,
            f"Total Corrections Collected: {stats.total_corrections}",
            f"Patterns Learned: {stats.unique_patterns}",
            f"Columns Improved: {stats.columns_corrected}",
            f"Tables Affected: {stats.tables_affected}",
            "",
            "Most Common Corrections:"
        ]

        lines.extend(
            f" - {col}: {count} corrections"
            for col, count in stats.most_common_fixes
        )

        lines.append("=" * 50)

        return "\n".join(lines)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# Convenience function for CLI
|
|
430
|
+
def collect_feedback_interactive():
    """Print the feedback report for the default feedback database (for CLI use).

    Builds a ``HumanFeedbackLoop`` with its default database location and
    writes its report to standard output. No input is read from the user.
    """
    report = HumanFeedbackLoop().get_feedback_report()
    print(report)
|