misata 0.1.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
misata/feedback.py ADDED
@@ -0,0 +1,433 @@
1
+ """
2
+ Human-in-the-Loop Feedback System for Misata.
3
+
4
+ This module provides:
5
+ - Schema correction collection and storage
6
+ - Learning from user feedback to improve future generations
7
+ - Persistent feedback database (SQLite)
8
+ - Feedback-aware prompt enhancement
9
+
10
+ This addresses the review concern that Misata lacked a learning/feedback loop.
11
+ """
12
+
13
+ import json
14
+ import sqlite3
15
+ from dataclasses import dataclass
16
+ from datetime import datetime
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+ from pathlib import Path
19
+
20
+
21
@dataclass(eq=True, repr=True)
class SchemaCorrection:
    """
    One user-supplied fix to a generated column definition.

    Pairs the column definition as originally produced with the version the
    user corrected it to, together with where and why the fix was made.
    """

    # Column definition before the user's fix.
    original_column: Dict[str, Any]
    # Column definition after the user's fix.
    corrected_column: Dict[str, Any]
    # Table that owns the corrected column.
    table_name: str
    # Free-text explanation of the fix.
    reason: str
    # When the fix was recorded, as a string.
    timestamp: str
    # Optional story text that produced the schema; None when not supplied.
    story_context: Optional[str] = None
30
+
31
+
32
@dataclass(eq=True, repr=True)
class FeedbackStats:
    """Aggregate counters summarising the feedback stored so far."""

    # Number of rows in the corrections table.
    total_corrections: int
    # Number of distinct learned pattern keys.
    unique_patterns: int
    # (column_name, correction_count) pairs, most corrected first.
    most_common_fixes: List[Tuple[str, int]]
    # Number of distinct column names that received corrections.
    columns_corrected: int
    # Number of distinct tables that received corrections.
    tables_affected: int
40
+
41
+
42
class FeedbackDatabase:
    """
    Persistent storage for user feedback using SQLite.

    Stores schema corrections that can be used to:
    1. Improve prompts over time
    2. Auto-fix common mistakes
    3. Learn industry-specific patterns

    Each public method opens its own short-lived connection to ``db_path``
    and always closes it before returning, including when an error occurs.
    Because connections are not shared, the special ``":memory:"`` path is
    not usable here: every connection would see a fresh, empty database.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize feedback database.

        Args:
            db_path: Path to SQLite database. Defaults to ~/.misata/feedback.db;
                the ~/.misata directory is created when it does not exist.
        """
        if db_path is None:
            misata_dir = Path.home() / ".misata"
            misata_dir.mkdir(parents=True, exist_ok=True)
            db_path = str(misata_dir / "feedback.db")

        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the corrections, patterns and sessions tables if missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Corrections table: one row per user fix.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS corrections (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    table_name TEXT NOT NULL,
                    column_name TEXT NOT NULL,
                    original_type TEXT,
                    corrected_type TEXT,
                    original_params TEXT,
                    corrected_params TEXT,
                    reason TEXT,
                    story_context TEXT,
                    industry TEXT
                )
            """)

            # Patterns table for learned rules.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS patterns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    pattern_type TEXT NOT NULL,
                    pattern_key TEXT NOT NULL,
                    pattern_value TEXT NOT NULL,
                    confidence REAL,
                    occurrence_count INTEGER DEFAULT 1,
                    last_updated TEXT
                )
            """)

            # Sessions table for audit logging.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT UNIQUE NOT NULL,
                    start_time TEXT NOT NULL,
                    end_time TEXT,
                    story TEXT,
                    schema_generated TEXT,
                    tables_count INTEGER,
                    rows_generated INTEGER,
                    corrections_made INTEGER DEFAULT 0
                )
            """)

            conn.commit()
        finally:
            conn.close()

    def add_correction(
        self,
        table_name: str,
        column_name: str,
        original: Dict[str, Any],
        corrected: Dict[str, Any],
        reason: str = "",
        story_context: str = "",
        industry: str = ""
    ) -> int:
        """
        Store a schema correction and update the learned patterns.

        The correction row and the pattern update are committed together.
        If anything fails, neither is persisted and the connection is still
        closed.

        Args:
            table_name: Name of the table
            column_name: Name of the corrected column
            original: Original column definition from LLM
            corrected: User's corrected definition
            reason: Why the correction was made
            story_context: Original story that generated this
            industry: Industry context (saas, healthcare, etc.)

        Returns:
            ID of the inserted correction
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO corrections (
                    timestamp, table_name, column_name,
                    original_type, corrected_type,
                    original_params, corrected_params,
                    reason, story_context, industry
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                datetime.now().isoformat(),
                table_name,
                column_name,
                original.get("type"),
                corrected.get("type"),
                json.dumps(original.get("distribution_params", {})),
                json.dumps(corrected.get("distribution_params", {})),
                reason,
                story_context,
                industry
            ))
            correction_id = cursor.lastrowid

            # Update learned patterns inside the same transaction.
            self._update_patterns(cursor, column_name, original, corrected)

            conn.commit()
        finally:
            # Closing without commit (on error) discards the partial work.
            conn.close()

        return correction_id

    def _update_patterns(
        self,
        cursor: sqlite3.Cursor,
        column_name: str,
        original: Dict,
        corrected: Dict
    ):
        """
        Learn a column-name -> (type, params) pattern from one correction.

        ``occurrence_count`` counts the corrections that agree with the
        currently stored suggestion. A correction that disagrees replaces the
        suggestion and restarts the count at 1. This prevents contradictory
        feedback from inflating the confidence of whichever value came last.
        ``original`` is not used by the current logic.
        """
        pattern_key = column_name.lower()
        pattern_value = json.dumps({
            "type": corrected.get("type"),
            "params": corrected.get("distribution_params", {})
        })
        now = datetime.now().isoformat()

        cursor.execute("""
            SELECT id, pattern_value FROM patterns
            WHERE pattern_type = 'column_name' AND pattern_key = ?
        """, (pattern_key,))
        existing = cursor.fetchone()

        if existing is None:
            cursor.execute("""
                INSERT INTO patterns (
                    pattern_type, pattern_key, pattern_value,
                    confidence, last_updated
                ) VALUES (?, ?, ?, ?, ?)
            """, ('column_name', pattern_key, pattern_value, 0.5, now))
            return

        pattern_id, stored_value = existing
        # Compare parsed JSON so that differing dict key order (or tuple vs
        # list after a JSON round trip) is not mistaken for a different fix.
        agrees = json.loads(stored_value) == json.loads(pattern_value)

        if agrees:
            cursor.execute("""
                UPDATE patterns
                SET occurrence_count = occurrence_count + 1,
                    pattern_value = ?,
                    last_updated = ?
                WHERE id = ?
            """, (pattern_value, now, pattern_id))
        else:
            cursor.execute("""
                UPDATE patterns
                SET occurrence_count = 1,
                    pattern_value = ?,
                    last_updated = ?
                WHERE id = ?
            """, (pattern_value, now, pattern_id))

    def get_learned_patterns(self, min_occurrences: int = 2) -> Dict[str, Dict]:
        """
        Get patterns learned from corrections.

        Args:
            min_occurrences: Minimum times a pattern was seen

        Returns:
            Dict mapping lower-cased column names to a dict with keys
            "suggestion" ({"type", "params"}), "confidence"
            (min(0.9, 0.5 + 0.1 * occurrences)) and "occurrences".
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT pattern_key, pattern_value, occurrence_count
                FROM patterns
                WHERE pattern_type = 'column_name' AND occurrence_count >= ?
                ORDER BY occurrence_count DESC
            """, (min_occurrences,))
            rows = cursor.fetchall()
        finally:
            conn.close()

        patterns = {}
        for key, value, count in rows:
            patterns[key] = {
                "suggestion": json.loads(value),
                "confidence": min(0.9, 0.5 + count * 0.1),
                "occurrences": count
            }
        return patterns

    def get_stats(self) -> "FeedbackStats":
        """Get statistics about collected feedback."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Total corrections
            cursor.execute("SELECT COUNT(*) FROM corrections")
            total = cursor.fetchone()[0]

            # Unique patterns
            cursor.execute("SELECT COUNT(DISTINCT pattern_key) FROM patterns")
            patterns = cursor.fetchone()[0]

            # Most common column fixes (top 5)
            cursor.execute("""
                SELECT column_name, COUNT(*) as cnt
                FROM corrections
                GROUP BY column_name
                ORDER BY cnt DESC
                LIMIT 5
            """)
            common_fixes = cursor.fetchall()

            # Unique columns
            cursor.execute("SELECT COUNT(DISTINCT column_name) FROM corrections")
            unique_cols = cursor.fetchone()[0]

            # Unique tables
            cursor.execute("SELECT COUNT(DISTINCT table_name) FROM corrections")
            unique_tables = cursor.fetchone()[0]
        finally:
            conn.close()

        return FeedbackStats(
            total_corrections=total,
            unique_patterns=patterns,
            most_common_fixes=common_fixes,
            columns_corrected=unique_cols,
            tables_affected=unique_tables
        )

    def generate_prompt_enhancement(self) -> str:
        """
        Generate prompt enhancement based on learned corrections.

        This is injected into the LLM prompt to improve future generations.
        It covers at most 10 patterns, including ones seen only once, and
        returns "" when nothing has been learned.
        """
        patterns = self.get_learned_patterns(min_occurrences=1)

        if not patterns:
            return ""

        lines = [
            "Based on previous user corrections, apply these rules:",
            ""
        ]

        for col_name, data in list(patterns.items())[:10]:
            suggestion = data["suggestion"]
            lines.append(f"- Column '{col_name}': use type '{suggestion.get('type')}' with params {suggestion.get('params')}")

        return "\n".join(lines)
315
+
316
+
317
class HumanFeedbackLoop:
    """
    Main interface for human-in-the-loop feedback.

    Provides methods to:
    1. Collect corrections from users
    2. Apply learned patterns to new schemas
    3. Generate enhanced prompts
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Args:
            db_path: Forwarded to FeedbackDatabase; None selects its default
                database location.
        """
        self.db = FeedbackDatabase(db_path)

    def submit_correction(
        self,
        table_name: str,
        column_name: str,
        original: Dict[str, Any],
        corrected: Dict[str, Any],
        reason: str = "",
        context: str = "",
        industry: str = ""
    ) -> Dict[str, Any]:
        """
        Submit a schema correction.

        Args:
            table_name: Table containing the corrected column.
            column_name: Name of the corrected column.
            original: Column definition before the fix.
            corrected: Column definition after the fix.
            reason: Why the correction was made.
            context: Story text that produced the schema (stored as
                ``story_context``).
            industry: Optional industry tag stored alongside the correction.

        Returns:
            Confirmation dict with the correction "id", a "message", and the
            lower-cased column name under "pattern_learned".
        """
        correction_id = self.db.add_correction(
            table_name=table_name,
            column_name=column_name,
            original=original,
            corrected=corrected,
            reason=reason,
            story_context=context,
            industry=industry
        )

        return {
            "id": correction_id,
            "message": "Correction recorded. Misata will learn from this.",
            "pattern_learned": column_name.lower()
        }

    def apply_learned_patterns(
        self,
        schema: Dict[str, Any],
        min_confidence: float = 0.6
    ) -> Tuple[Dict[str, Any], List[str]]:
        """
        Apply learned patterns to improve a schema.

        The schema is modified in place and also returned. A column is
        rewritten only when a pattern exists for its lower-cased name, the
        pattern's confidence is strictly greater than ``min_confidence``, and
        the suggested type differs from the column's current type. In that
        case both "type" and "distribution_params" are replaced.

        Args:
            schema: Schema to enhance. It is expected to hold a "columns"
                mapping of table name to a list of column dicts.
            min_confidence: Confidence a pattern must exceed to be applied.

        Returns:
            (enhanced_schema, list of changes made)
        """
        patterns = self.db.get_learned_patterns()
        changes: List[str] = []

        columns = schema.get("columns", {})

        for table_name, cols in columns.items():
            for col in cols:
                # `or ""` also covers an explicit None name.
                col_name = (col.get("name") or "").lower()
                if not col_name:
                    continue

                pattern = patterns.get(col_name)
                if pattern is None or pattern["confidence"] <= min_confidence:
                    continue

                suggestion = pattern["suggestion"]
                old_type = col.get("type")
                new_type = suggestion.get("type")
                if old_type == new_type:
                    continue

                # `col` is the dict stored in the list, so this mutates the schema.
                col["type"] = new_type
                col["distribution_params"] = suggestion.get("params", {})
                changes.append(
                    f"Applied learned pattern to {table_name}.{col['name']}: "
                    f"{old_type} -> {new_type} (confidence: {pattern['confidence']:.0%})"
                )

        schema["columns"] = columns
        return schema, changes

    def get_enhanced_prompt(self) -> str:
        """Get prompt enhancement from learned patterns."""
        return self.db.generate_prompt_enhancement()

    def get_feedback_report(self) -> str:
        """Return a human-readable summary of the feedback collected so far."""
        stats = self.db.get_stats()

        lines = [
            "=" * 50,
            "MISATA FEEDBACK LEARNING REPORT",
            "=" * 50,
            f"Total Corrections Collected: {stats.total_corrections}",
            f"Patterns Learned: {stats.unique_patterns}",
            f"Columns Improved: {stats.columns_corrected}",
            f"Tables Affected: {stats.tables_affected}",
            "",
            "Most Common Corrections:"
        ]

        for col, count in stats.most_common_fixes:
            lines.append(f" - {col}: {count} corrections")
        if not stats.most_common_fixes:
            # Avoid leaving the section header dangling with no entries.
            lines.append(" (none yet)")

        lines.append("=" * 50)

        return "\n".join(lines)
427
+
428
+
429
# Convenience function for CLI
def collect_feedback_interactive():
    """
    Print the feedback learning report for the default feedback database.

    Intended as a CLI entry point. Despite its name, it currently only
    displays the report from ``HumanFeedbackLoop.get_feedback_report``; it
    does not prompt the user for new corrections.
    """
    report = HumanFeedbackLoop().get_feedback_report()
    print(report)