sqlshell 0.1.9__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlshell might be problematic. Click here for more details.

@@ -0,0 +1,765 @@
1
+ """
2
+ Context-aware SQL suggestions for SQLShell.
3
+
4
+ This module provides advanced context-based suggestion capabilities
5
+ for the SQL editor, providing intelligent and relevant completions
6
+ based on the current query context, schema information, and query history.
7
+ """
8
+
9
+ import re
10
+ from collections import defaultdict, Counter
11
+ from typing import Dict, List, Set, Tuple, Optional, Any
12
+
13
+
14
+ class ContextSuggester:
15
+ """
16
+ Provides context-aware SQL suggestions based on schema information,
17
+ query context analysis, and usage patterns.
18
+ """
19
+
20
    def __init__(self):
        """Create an empty suggester: no schema, no usage history, and the
        static SQL keyword lists loaded via _initialize_sql_keywords()."""
        # Schema information
        self.tables: Set[str] = set()  # Set of table names
        self.table_columns: Dict[str, List[str]] = defaultdict(list)  # {table_name: [column_names]}
        self.column_types: Dict[str, str] = {}  # {table.column: data_type} or {column: data_type}

        # Detected relationships between tables for JOIN suggestions
        self.relationships: List[Tuple[str, str, str, str]] = []  # [(table1, column1, table2, column2)]

        # Usage statistics for prioritizing suggestions
        self.usage_counts: Counter = Counter()  # {completion_term: count}

        # Query pattern detection (currently unused; see _extract_patterns placeholder)
        self.common_patterns: List[Any] = []

        # Context cache to avoid recomputing (keyed by "text:current_word")
        self._context_cache: Dict[str, Dict[str, Any]] = {}
        self._last_analyzed_text: str = ""

        # Query history (record_query caps this at 100 entries)
        self.query_history: List[str] = []  # List of recent queries for pattern detection

        # Initialize with common SQL elements
        self._initialize_sql_keywords()
45
    def _initialize_sql_keywords(self) -> None:
        """Initialize common SQL keywords by category.

        Builds two attributes:
          * self.sql_keywords - dict of keyword lists keyed by category
            ('basic', 'aggregation', 'functions', 'table_ops', 'join');
            function entries keep a trailing '(' so completion places the
            cursor inside the call.
          * self.all_keywords - a flat list of every keyword, used for
            membership checks elsewhere (e.g. usage stats, scoring).
        """
        self.sql_keywords = {
            'basic': [
                'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING',
                'LIMIT', 'OFFSET', 'INSERT INTO', 'VALUES', 'UPDATE', 'SET',
                'DELETE FROM', 'CREATE TABLE', 'DROP TABLE', 'ALTER TABLE',
                'ADD COLUMN', 'DROP COLUMN', 'RENAME TO', 'UNION', 'UNION ALL',
                'INTERSECT', 'EXCEPT', 'AS', 'WITH', 'DISTINCT', 'CASE', 'WHEN',
                'THEN', 'ELSE', 'END', 'AND', 'OR', 'NOT', 'LIKE', 'IN', 'BETWEEN',
                'IS NULL', 'IS NOT NULL', 'ALL', 'ANY', 'EXISTS'
            ],
            'aggregation': [
                'AVG(', 'COUNT(', 'COUNT(*)', 'COUNT(DISTINCT ', 'SUM(', 'MIN(', 'MAX(',
                'MEDIAN(', 'PERCENTILE_CONT(', 'PERCENTILE_DISC(', 'VARIANCE(', 'STDDEV(',
                'FIRST(', 'LAST(', 'ARRAY_AGG(', 'STRING_AGG(', 'GROUP_CONCAT('
            ],
            'functions': [
                # String functions
                'LOWER(', 'UPPER(', 'INITCAP(', 'TRIM(', 'LTRIM(', 'RTRIM(', 'SUBSTRING(',
                'SUBSTR(', 'REPLACE(', 'POSITION(', 'CONCAT(', 'LENGTH(', 'CHAR_LENGTH(',
                'LEFT(', 'RIGHT(', 'REGEXP_REPLACE(', 'REGEXP_EXTRACT(', 'REGEXP_MATCH(',
                'FORMAT(', 'LPAD(', 'RPAD(', 'REVERSE(', 'SPLIT_PART(',

                # Numeric functions
                'ABS(', 'SIGN(', 'ROUND(', 'CEIL(', 'FLOOR(', 'TRUNC(', 'MOD(',
                'POWER(', 'SQRT(', 'CBRT(', 'LOG(', 'LOG10(', 'EXP(', 'RANDOM(',

                # Date/time functions
                'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'NOW()',
                'DATE(', 'TIME(', 'DATETIME(', 'EXTRACT(', 'DATE_TRUNC(', 'DATE_PART(',
                'DATEADD(', 'DATEDIFF(', 'DATE_FORMAT(', 'STRFTIME(', 'MAKEDATE(',
                'YEAR(', 'QUARTER(', 'MONTH(', 'WEEK(', 'DAY(', 'HOUR(', 'MINUTE(', 'SECOND(',

                # Conditional functions
                'CASE', 'COALESCE(', 'NULLIF(', 'GREATEST(', 'LEAST(', 'IFF(', 'IFNULL(',
                'DECODE(', 'NVL(', 'NVL2(',

                # Type conversion
                'CAST(', 'CONVERT(', 'TRY_CAST(', 'TO_CHAR(', 'TO_DATE(', 'TO_NUMBER(',
                'TO_TIMESTAMP(', 'PARSE_JSON(',

                # Window functions
                'ROW_NUMBER() OVER (', 'RANK() OVER (', 'DENSE_RANK() OVER (',
                'LEAD(', 'LAG(', 'FIRST_VALUE(', 'LAST_VALUE(', 'NTH_VALUE('
            ],
            'table_ops': [
                'CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'INSERT', 'UPDATE', 'DELETE',
                'MERGE', 'COPY', 'GRANT', 'REVOKE', 'INDEX', 'PRIMARY KEY', 'FOREIGN KEY',
                'REFERENCES', 'UNIQUE', 'NOT NULL', 'CHECK', 'DEFAULT'
            ],
            'join': [
                'INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'FULL JOIN', 'CROSS JOIN',
                'NATURAL JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN'
            ]
        }

        # Create a flattened list of all keywords for easy lookup
        self.all_keywords: List[str] = []
        for category, keywords in self.sql_keywords.items():
            self.all_keywords.extend(keywords)
+
107
+ def update_schema(self, tables: Set[str], table_columns: Dict[str, List[str]],
108
+ column_types: Dict[str, str] = None) -> None:
109
+ """
110
+ Update schema information for suggestions.
111
+
112
+ Args:
113
+ tables: Set of table names
114
+ table_columns: Dictionary mapping table names to column lists
115
+ column_types: Optional dictionary of column data types
116
+ """
117
+ self.tables = tables
118
+ self.table_columns = table_columns
119
+
120
+ if column_types:
121
+ self.column_types = column_types
122
+
123
+ # Clear context cache since schema has changed
124
+ self._context_cache = {}
125
+
126
+ # Detect relationships between tables
127
+ self._detect_relationships()
128
+
129
    def _detect_relationships(self) -> None:
        """Detect potential relationships between tables based on column naming patterns.

        Rebuilds self.relationships from scratch as a list of
        (table, column, other_table, other_column) tuples, using two heuristics:

        1. Foreign-key naming: a column ending in ``_id`` (or ``...Id``-style
           camelCase, length > 2) is matched against a table whose lowercased
           name equals — or ends with ``_`` plus — the stripped prefix.
        2. Shared names: a column that appears verbatim in another table.

        Heuristic only; it can produce duplicates and misses plural table
        names (e.g. ``user_id`` does not match table ``users``).
        """
        self.relationships = []

        # For each table and its columns
        for table, columns in self.table_columns.items():
            for col in columns:
                # Check for foreign key naming pattern (table_id, tableId)
                if col.lower().endswith('_id') or (col.lower().endswith('id') and len(col) > 2):
                    # Extract potential referenced table name
                    if col.lower().endswith('_id'):
                        ref_table = col[:-3]  # Remove '_id'
                    else:
                        # Extract from camelCase/PascalCase
                        ref_table = col[:-2]  # Remove 'Id'

                    # Normalize and check if this table exists
                    ref_table_lower = ref_table.lower()
                    for other_table in self.tables:
                        other_lower = other_table.lower()
                        if other_lower == ref_table_lower or other_lower.endswith(f"_{ref_table_lower}"):
                            # Found a potential relationship - check for id column
                            # NOTE(review): if table_columns is a defaultdict, indexing a
                            # table with no column entry inserts an empty list here.
                            if 'id' in self.table_columns[other_table]:
                                self.relationships.append((table, col, other_table, 'id'))
                            else:
                                # Look for any primary key column
                                for other_col in self.table_columns[other_table]:
                                    if other_col.lower() == 'id' or other_col.lower().endswith('_id'):
                                        self.relationships.append((table, col, other_table, other_col))
                                        break

                # Also check for columns with the same name across tables
                for other_table, other_columns in self.table_columns.items():
                    if other_table != table and col in other_columns:
                        self.relationships.append((table, col, other_table, col))
+
165
+ def record_query(self, query: str) -> None:
166
+ """
167
+ Record a query to improve suggestion relevance.
168
+
169
+ Args:
170
+ query: SQL query to record
171
+ """
172
+ if not query.strip():
173
+ return
174
+
175
+ # Add to query history (limit size)
176
+ self.query_history.append(query)
177
+ if len(self.query_history) > 100:
178
+ self.query_history.pop(0)
179
+
180
+ # Update usage statistics
181
+ self._update_usage_stats(query)
182
+
183
+ # Extract common patterns
184
+ self._extract_patterns(query)
185
+
186
+ def _update_usage_stats(self, query: str) -> None:
187
+ """Update usage statistics by analyzing the query"""
188
+ # Extract tables
189
+ tables = self._extract_tables_from_query(query)
190
+ for table in tables:
191
+ if table in self.tables:
192
+ self.usage_counts[table] += 1
193
+
194
+ # Extract columns (with and without table prefix)
195
+ columns = self._extract_columns_from_query(query)
196
+ for col in columns:
197
+ self.usage_counts[col] += 1
198
+
199
+ # Extract SQL keywords
200
+ keywords = re.findall(r'\b([A-Z_]{2,})\b', query.upper())
201
+ for kw in keywords:
202
+ if kw in self.all_keywords:
203
+ self.usage_counts[kw] += 1
204
+
205
+ # Extract common patterns (like "GROUP BY")
206
+ patterns = [
207
+ r'(SELECT\s+.*?\s+FROM)',
208
+ r'(GROUP\s+BY\s+.*?(?:HAVING|ORDER|LIMIT|$))',
209
+ r'(ORDER\s+BY\s+.*?(?:LIMIT|$))',
210
+ r'(INNER\s+JOIN|LEFT\s+JOIN|RIGHT\s+JOIN|FULL\s+JOIN).*?ON\s+.*?=\s+.*?(?:WHERE|JOIN|GROUP|ORDER|LIMIT|$)',
211
+ r'(INSERT\s+INTO\s+.*?\s+VALUES)',
212
+ r'(UPDATE\s+.*?\s+SET\s+.*?\s+WHERE)',
213
+ r'(DELETE\s+FROM\s+.*?\s+WHERE)'
214
+ ]
215
+
216
+ for pattern in patterns:
217
+ matches = re.findall(pattern, query, re.IGNORECASE | re.DOTALL)
218
+ for match in matches:
219
+ # Normalize pattern by removing extra whitespace and converting to uppercase
220
+ normalized = re.sub(r'\s+', ' ', match).strip().upper()
221
+ if len(normalized) < 50: # Only track reasonably sized patterns
222
+ self.usage_counts[normalized] += 1
223
+
224
+ def _extract_tables_from_query(self, query: str) -> List[str]:
225
+ """Extract table names from a SQL query"""
226
+ tables = []
227
+
228
+ # Look for tables after FROM and JOIN
229
+ from_matches = re.findall(r'FROM\s+([a-zA-Z0-9_]+)', query, re.IGNORECASE)
230
+ join_matches = re.findall(r'JOIN\s+([a-zA-Z0-9_]+)', query, re.IGNORECASE)
231
+
232
+ tables.extend(from_matches)
233
+ tables.extend(join_matches)
234
+
235
+ # Look for tables in UPDATE and INSERT statements
236
+ update_matches = re.findall(r'UPDATE\s+([a-zA-Z0-9_]+)', query, re.IGNORECASE)
237
+ insert_matches = re.findall(r'INSERT\s+INTO\s+([a-zA-Z0-9_]+)', query, re.IGNORECASE)
238
+
239
+ tables.extend(update_matches)
240
+ tables.extend(insert_matches)
241
+
242
+ return tables
243
+
244
+ def _extract_columns_from_query(self, query: str) -> List[str]:
245
+ """Extract column names from a SQL query"""
246
+ columns = []
247
+
248
+ # Extract qualified column names (table.column)
249
+ qual_columns = re.findall(r'([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', query)
250
+ for table, column in qual_columns:
251
+ columns.append(f"{table}.{column}")
252
+ columns.append(column)
253
+
254
+ # Other patterns would need more complex parsing which is beyond the scope
255
+ return columns
256
+
257
    def _extract_patterns(self, query: str) -> None:
        """Extract common query patterns for future suggestions.

        Currently a no-op: accurate pattern mining would require a real SQL
        parser. Kept so record_query's pipeline (history -> stats -> patterns)
        stays stable when the logic is added.
        """
        # This would require a more sophisticated SQL parser to be accurate
        # Placeholder for future pattern extraction logic
        pass
+
263
+ def analyze_context(self, text_before_cursor: str, current_word: str) -> Dict[str, Any]:
264
+ """
265
+ Analyze the SQL context at the current cursor position.
266
+
267
+ Args:
268
+ text_before_cursor: Text from the start of the document to the cursor
269
+ current_word: The current word being typed
270
+
271
+ Returns:
272
+ Dictionary with context information
273
+ """
274
+ # Use cached context if analyzing the same text
275
+ cache_key = f"{text_before_cursor}:{current_word}"
276
+ if cache_key in self._context_cache:
277
+ return self._context_cache[cache_key]
278
+
279
+ # Convert to uppercase for easier keyword matching
280
+ text_upper = text_before_cursor.upper()
281
+
282
+ # Initialize context dictionary
283
+ context = {
284
+ 'type': 'unknown',
285
+ 'table_prefix': None,
286
+ 'after_from': False,
287
+ 'after_join': False,
288
+ 'after_select': False,
289
+ 'after_where': False,
290
+ 'after_group_by': False,
291
+ 'after_order_by': False,
292
+ 'after_having': False,
293
+ 'in_function_args': False,
294
+ 'columns_already_selected': [],
295
+ 'tables_in_from': [],
296
+ 'last_token': '',
297
+ 'current_word': current_word,
298
+ 'current_function': None,
299
+ }
300
+
301
+ # Extract tables from the query for context-aware suggestions
302
+ # Look for tables after FROM and JOIN
303
+ from_matches = re.findall(r'FROM\s+([a-zA-Z0-9_]+)', text_upper)
304
+ join_matches = re.findall(r'JOIN\s+([a-zA-Z0-9_]+)', text_upper)
305
+
306
+ # Add all found tables to context
307
+ if from_matches or join_matches:
308
+ tables = []
309
+ tables.extend(from_matches)
310
+ tables.extend(join_matches)
311
+ context['tables_in_from'] = tables
312
+
313
+ # Check for table.column context
314
+ if '.' in current_word:
315
+ parts = current_word.split('.')
316
+ if len(parts) == 2:
317
+ context['type'] = 'column'
318
+ context['table_prefix'] = parts[0]
319
+
320
+ # Extract the last few keywords to determine context
321
+ keywords = re.findall(r'\b([A-Z_]+)\b', text_upper)
322
+ last_keywords = keywords[-5:] if keywords else []
323
+ last_keyword = last_keywords[-1] if last_keywords else ""
324
+ context['last_token'] = last_keyword
325
+
326
+ # Check for function context - match the last opening parenthesis
327
+ if '(' in text_before_cursor:
328
+ # Count parentheses to check if we're inside function arguments
329
+ open_parens = text_before_cursor.count('(')
330
+ close_parens = text_before_cursor.count(')')
331
+
332
+ if open_parens > close_parens:
333
+ context['type'] = 'function_arg'
334
+ context['in_function_args'] = True
335
+
336
+ # Find the last open parenthesis position
337
+ last_open_paren_pos = text_before_cursor.rindex('(')
338
+
339
+ # Extract text before the parenthesis to identify the function
340
+ func_text = text_before_cursor[:last_open_paren_pos].strip()
341
+ # Get the last word which should be the function name
342
+ func_words = re.findall(r'\b([A-Za-z0-9_]+)\b', func_text)
343
+ if func_words:
344
+ context['current_function'] = func_words[-1].upper()
345
+ context['last_token'] = context['current_function']
346
+
347
+ # Extract the last line or statement
348
+ last_line = text_before_cursor.split('\n')[-1].strip().upper()
349
+
350
+ # Check for specific contexts
351
+
352
+ # FROM/JOIN context - likely to be followed by table names
353
+ if 'FROM' in last_keywords and not any(k in last_keywords[last_keywords.index('FROM'):] for k in ['WHERE', 'GROUP', 'HAVING', 'ORDER']):
354
+ context['type'] = 'table'
355
+ context['after_from'] = True
356
+
357
+ elif any(k.endswith('JOIN') for k in last_keywords):
358
+ context['type'] = 'table'
359
+ context['after_join'] = True
360
+
361
+ # WHERE/AND/OR context - likely to be followed by columns or expressions
362
+ elif any(kw in last_keywords for kw in ['WHERE', 'AND', 'OR']):
363
+ context['type'] = 'column_or_expression'
364
+ context['after_where'] = True
365
+
366
+ # SELECT context - likely to be followed by columns
367
+ elif 'SELECT' in last_keywords and not any(k in last_keywords[last_keywords.index('SELECT'):] for k in ['FROM', 'WHERE']):
368
+ context['type'] = 'column'
369
+ context['after_select'] = True
370
+ # Try to extract columns already in SELECT clause
371
+ select_text = text_before_cursor[text_before_cursor.upper().find('SELECT'):]
372
+ if 'FROM' in select_text.upper():
373
+ select_text = select_text[:select_text.upper().find('FROM')]
374
+ context['columns_already_selected'] = [c.strip() for c in select_text.split(',')[1:]]
375
+
376
+ # GROUP BY context
377
+ elif 'GROUP' in last_keywords or ('BY' in last_keywords and len(last_keywords) >= 2 and last_keywords[-2:] == ['GROUP', 'BY']):
378
+ context['type'] = 'column'
379
+ context['after_group_by'] = True
380
+
381
+ # ORDER BY context
382
+ elif 'ORDER' in last_keywords or ('BY' in last_keywords and len(last_keywords) >= 2 and last_keywords[-2:] == ['ORDER', 'BY']):
383
+ context['type'] = 'column'
384
+ context['after_order_by'] = True
385
+
386
+ # HAVING context
387
+ elif 'HAVING' in last_keywords:
388
+ context['type'] = 'aggregation'
389
+ context['after_having'] = True
390
+
391
+ # Cache the context
392
+ self._context_cache[cache_key] = context
393
+ return context
394
+
395
+ def get_suggestions(self, text_before_cursor: str, current_word: str) -> List[str]:
396
+ """
397
+ Get context-aware SQL suggestions.
398
+
399
+ Args:
400
+ text_before_cursor: Text from start of document to cursor position
401
+ current_word: The current word being typed (possibly empty)
402
+
403
+ Returns:
404
+ List of suggestion strings relevant to the current context
405
+ """
406
+ # Get detailed context
407
+ context = self.analyze_context(text_before_cursor, current_word)
408
+
409
+ # Start with an empty suggestion list
410
+ suggestions = []
411
+
412
+ # Different suggestion strategies based on context type
413
+ if context['type'] == 'table':
414
+ suggestions = self._get_table_suggestions(context)
415
+ elif context['type'] == 'column' and context['table_prefix']:
416
+ suggestions = self._get_column_suggestions_for_table(context['table_prefix'])
417
+ elif context['type'] == 'column' or context['type'] == 'column_or_expression':
418
+ suggestions = self._get_column_suggestions(context)
419
+ elif context['type'] == 'function_arg':
420
+ suggestions = self._get_function_arg_suggestions(context)
421
+ elif context['type'] == 'aggregation':
422
+ suggestions = self._get_aggregation_suggestions(context)
423
+ else:
424
+ # Default case - general SQL keywords
425
+ suggestions = self._get_default_suggestions()
426
+
427
+ # Filter by current word if needed
428
+ if current_word:
429
+ suggestions = [s for s in suggestions if s.lower().startswith(current_word.lower())]
430
+
431
+ # Prioritize by usage frequency
432
+ return self._prioritize_suggestions(suggestions, context)
433
+
434
+ def _get_table_suggestions(self, context: Dict[str, Any]) -> List[str]:
435
+ """Get table name suggestions"""
436
+ suggestions = list(self.tables)
437
+
438
+ # Add table aliases if relevant
439
+ aliases = [f"{t} AS {t[0]}" for t in self.tables]
440
+ suggestions.extend(aliases)
441
+
442
+ # Add keywords that might follow FROM/JOIN
443
+ if context['after_from'] or context['after_join']:
444
+ suggestions.extend(self.sql_keywords['join'])
445
+
446
+ # If we have previous tables, suggest relationships
447
+ if context['tables_in_from'] and self.relationships:
448
+ prev_tables = context['tables_in_from']
449
+ for prev_table in prev_tables:
450
+ for t1, c1, t2, c2 in self.relationships:
451
+ if t1 == prev_table:
452
+ join_suggestion = f"{t2} ON {t2}.{c2} = {t1}.{c1}"
453
+ suggestions.append(join_suggestion)
454
+ elif t2 == prev_table:
455
+ join_suggestion = f"{t1} ON {t1}.{c1} = {t2}.{c2}"
456
+ suggestions.append(join_suggestion)
457
+
458
+ return suggestions
459
+
460
+ def _get_column_suggestions_for_table(self, table_prefix: str) -> List[str]:
461
+ """Get column suggestions for a specific table"""
462
+ if table_prefix in self.table_columns:
463
+ return self.table_columns[table_prefix]
464
+ return []
465
+
466
+ def _get_column_suggestions(self, context: Dict[str, Any]) -> List[str]:
467
+ """Get column name suggestions"""
468
+ suggestions = []
469
+
470
+ # Add SQL functions and keywords for columns
471
+ suggestions.extend(self.sql_keywords['aggregation'])
472
+ suggestions.extend(self.sql_keywords['functions'])
473
+
474
+ # Identify active tables in the current query
475
+ active_tables = set()
476
+ # First check tables extracted from FROM/JOIN clauses
477
+ if 'tables_in_from' in context and context['tables_in_from']:
478
+ active_tables.update(context['tables_in_from'])
479
+
480
+ # Define column lists by priority
481
+ active_table_columns = []
482
+ other_columns = []
483
+
484
+ # Get columns from active tables first
485
+ for table in active_tables:
486
+ if table in self.table_columns:
487
+ columns = self.table_columns[table]
488
+ # Add both plain column names and qualified ones
489
+ active_table_columns.extend(columns)
490
+ active_table_columns.extend([f"{table}.{col}" for col in columns])
491
+
492
+ # Then get all other columns as fallback
493
+ for table, columns in self.table_columns.items():
494
+ if table not in active_tables:
495
+ other_columns.extend(columns)
496
+ # Only add qualified names if we have multiple tables to avoid confusion
497
+ if len(self.table_columns) > 1:
498
+ other_columns.extend([f"{table}.{col}" for col in columns])
499
+
500
+ # Add * and table.* suggestions
501
+ suggestions.append("*")
502
+ for table in self.tables:
503
+ suggestions.append(f"{table}.*")
504
+
505
+ # Context-specific additions
506
+ if context['after_select']:
507
+ # Add common SELECT patterns
508
+ suggestions.append("DISTINCT ")
509
+ # Avoid suggesting columns already in the select list
510
+ already_selected = [col.split(' ')[0].split('.')[0] for col in context['columns_already_selected']]
511
+ for col in already_selected:
512
+ if col in suggestions:
513
+ suggestions.remove(col)
514
+
515
+ elif context['after_where']:
516
+ # Add comparison operators for WHERE clause
517
+ operators = ["=", ">", "<", ">=", "<=", "<>", "!=", "LIKE", "IN", "BETWEEN", "IS NULL", "IS NOT NULL"]
518
+ suggestions.extend(operators)
519
+
520
+ # Add columns with priority ordering
521
+ suggestions.extend(active_table_columns)
522
+ suggestions.extend(other_columns)
523
+
524
+ # Remove duplicates while preserving order
525
+ seen = set()
526
+ filtered_suggestions = []
527
+ for item in suggestions:
528
+ if item not in seen:
529
+ seen.add(item)
530
+ filtered_suggestions.append(item)
531
+
532
+ return filtered_suggestions
533
+
534
+ def _get_function_arg_suggestions(self, context: Dict[str, Any]) -> List[str]:
535
+ """Get suggestions for function arguments"""
536
+ suggestions = []
537
+
538
+ # Identify active tables in the current query
539
+ active_tables = set()
540
+ # First check tables extracted from FROM/JOIN clauses
541
+ if 'tables_in_from' in context and context['tables_in_from']:
542
+ active_tables.update(context['tables_in_from'])
543
+
544
+ # Add column names as function arguments, prioritizing columns from active tables
545
+ active_table_columns = []
546
+ other_columns = []
547
+
548
+ # First get columns from active tables
549
+ for table in active_tables:
550
+ if table in self.table_columns:
551
+ columns = self.table_columns[table]
552
+ # Add both plain column names and qualified ones
553
+ active_table_columns.extend(columns)
554
+ active_table_columns.extend([f"{table}.{col}" for col in columns])
555
+
556
+ # Then get all other columns as fallback
557
+ for table, columns in self.table_columns.items():
558
+ if table not in active_tables:
559
+ other_columns.extend(columns)
560
+ # Only add qualified names if we have multiple tables to avoid confusion
561
+ if len(self.table_columns) > 1:
562
+ other_columns.extend([f"{table}.{col}" for col in columns])
563
+
564
+ # Add context-specific suggestions based on the last token
565
+ last_token = context['last_token']
566
+
567
+ if last_token in ['AVG', 'SUM', 'MIN', 'MAX', 'COUNT']:
568
+ # For aggregate functions, prioritize numeric columns
569
+ numeric_columns = []
570
+
571
+ # First check active tables for numeric columns
572
+ for table in active_tables:
573
+ if table in self.table_columns:
574
+ for col in self.table_columns[table]:
575
+ qualified_name = f"{table}.{col}"
576
+ # Check if column type info is available
577
+ if qualified_name in self.column_types:
578
+ data_type = self.column_types[qualified_name].upper()
579
+ if any(t in data_type for t in ['INT', 'NUM', 'FLOAT', 'DOUBLE', 'DECIMAL']):
580
+ numeric_columns.append(qualified_name)
581
+ numeric_columns.append(col)
582
+
583
+ # If no numeric columns found in active tables, check all columns
584
+ if not numeric_columns:
585
+ for col_name, data_type in self.column_types.items():
586
+ if data_type and any(t in data_type.upper() for t in ['INT', 'NUM', 'FLOAT', 'DOUBLE', 'DECIMAL']):
587
+ numeric_columns.append(col_name)
588
+
589
+ # Build final suggestion list with priority order:
590
+ # 1. Numeric columns from active tables
591
+ # 2. All columns from active tables
592
+ # 3. Numeric columns from other tables
593
+ # 4. All other columns
594
+ suggestions = numeric_columns + active_table_columns + other_columns
595
+
596
+ elif last_token in ['SUBSTRING', 'LOWER', 'UPPER', 'TRIM', 'REPLACE', 'CONCAT']:
597
+ # For string functions, prioritize text columns
598
+ text_columns = []
599
+
600
+ # First check active tables for text columns
601
+ for table in active_tables:
602
+ if table in self.table_columns:
603
+ for col in self.table_columns[table]:
604
+ qualified_name = f"{table}.{col}"
605
+ # Check if column type info is available
606
+ if qualified_name in self.column_types:
607
+ data_type = self.column_types[qualified_name].upper()
608
+ if any(t in data_type for t in ['CHAR', 'VARCHAR', 'TEXT', 'STRING']):
609
+ text_columns.append(qualified_name)
610
+ text_columns.append(col)
611
+
612
+ # If no text columns found in active tables, check all columns
613
+ if not text_columns:
614
+ for col_name, data_type in self.column_types.items():
615
+ if data_type and any(t in data_type.upper() for t in ['CHAR', 'VARCHAR', 'TEXT', 'STRING']):
616
+ text_columns.append(col_name)
617
+
618
+ suggestions = text_columns + active_table_columns + other_columns
619
+
620
+ elif last_token in ['DATE', 'DATETIME', 'EXTRACT', 'DATEADD', 'DATEDIFF']:
621
+ # For date functions, prioritize date columns
622
+ date_columns = []
623
+
624
+ # First check active tables for date columns
625
+ for table in active_tables:
626
+ if table in self.table_columns:
627
+ for col in self.table_columns[table]:
628
+ qualified_name = f"{table}.{col}"
629
+ # Check if column type info is available
630
+ if qualified_name in self.column_types:
631
+ data_type = self.column_types[qualified_name].upper()
632
+ if any(t in data_type for t in ['DATE', 'TIME', 'TIMESTAMP']):
633
+ date_columns.append(qualified_name)
634
+ date_columns.append(col)
635
+
636
+ # If no date columns found in active tables, check all columns
637
+ if not date_columns:
638
+ for col_name, data_type in self.column_types.items():
639
+ if data_type and any(t in data_type.upper() for t in ['DATE', 'TIME', 'TIMESTAMP']):
640
+ date_columns.append(col_name)
641
+
642
+ suggestions = date_columns + active_table_columns + other_columns
643
+
644
+ else:
645
+ # For other functions or generic cases, prioritize active table columns
646
+ suggestions = active_table_columns + other_columns
647
+
648
+ # Remove duplicates while preserving order
649
+ seen = set()
650
+ filtered_suggestions = []
651
+ for item in suggestions:
652
+ if item not in seen:
653
+ seen.add(item)
654
+ filtered_suggestions.append(item)
655
+
656
+ return filtered_suggestions
657
+
658
+ def _get_aggregation_suggestions(self, context: Dict[str, Any]) -> List[str]:
659
+ """Get suggestions for aggregation functions (HAVING clause)"""
660
+ suggestions = []
661
+
662
+ # Aggregation functions
663
+ suggestions.extend(self.sql_keywords['aggregation'])
664
+
665
+ # Common HAVING patterns
666
+ having_patterns = [
667
+ "COUNT(*) > ",
668
+ "COUNT(*) < ",
669
+ "COUNT(DISTINCT ",
670
+ "SUM(",
671
+ "AVG(",
672
+ "MIN(",
673
+ "MAX("
674
+ ]
675
+ suggestions.extend(having_patterns)
676
+
677
+ return suggestions
678
+
679
+ def _get_default_suggestions(self) -> List[str]:
680
+ """Get default suggestions when no specific context is detected"""
681
+ suggestions = []
682
+
683
+ # Basic SQL keywords
684
+ suggestions.extend(self.sql_keywords['basic'])
685
+
686
+ # Common query starters
687
+ query_starters = [
688
+ "SELECT * FROM ",
689
+ "SELECT COUNT(*) FROM ",
690
+ "SELECT DISTINCT ",
691
+ "INSERT INTO ",
692
+ "UPDATE ",
693
+ "DELETE FROM ",
694
+ "CREATE TABLE ",
695
+ "DROP TABLE ",
696
+ "ALTER TABLE "
697
+ ]
698
+ suggestions.extend(query_starters)
699
+
700
+ # Add most-used tables and columns
701
+ top_used = [item for item, _ in self.usage_counts.most_common(10)]
702
+ suggestions.extend(top_used)
703
+
704
+ return suggestions
705
+
706
+ def _prioritize_suggestions(self, suggestions: List[str], context: Dict[str, Any]) -> List[str]:
707
+ """
708
+ Prioritize suggestions based on relevance and usage statistics.
709
+
710
+ Args:
711
+ suggestions: List of initial suggestions
712
+ context: Current SQL context
713
+
714
+ Returns:
715
+ Prioritized list of suggestions
716
+ """
717
+ # If there are no suggestions, return empty list
718
+ if not suggestions:
719
+ return []
720
+
721
+ # Create a set for O(1) lookups and to remove duplicates
722
+ suggestion_set = set(suggestions)
723
+
724
+ # Start with a list of (suggestion, score) tuples
725
+ scored_suggestions = []
726
+
727
+ for suggestion in suggestion_set:
728
+ # Base score from usage count (normalize to 0-10 range)
729
+ count = self.usage_counts.get(suggestion, 0)
730
+ max_count = max(self.usage_counts.values()) if self.usage_counts else 1
731
+ usage_score = (count / max_count) * 10 if max_count > 0 else 0
732
+
733
+ # Start with usage score
734
+ score = usage_score
735
+
736
+ # Boost for SQL keywords
737
+ if suggestion.upper() in self.all_keywords:
738
+ score += 5
739
+
740
+ # Context-specific boosting
741
+ if context['type'] == 'table' and suggestion in self.tables:
742
+ score += 10
743
+ elif context['type'] == 'column' and context['table_prefix']:
744
+ if suggestion in self.table_columns.get(context['table_prefix'], []):
745
+ score += 15
746
+ elif context['type'] == 'column' and any(suggestion in cols for cols in self.table_columns.values()):
747
+ score += 8
748
+ elif context['type'] == 'aggregation' and suggestion in self.sql_keywords['aggregation']:
749
+ score += 12
750
+
751
+ # Exact prefix match gives a big boost
752
+ current_word = context['current_word']
753
+ if current_word and suggestion.lower().startswith(current_word.lower()):
754
+ # More boost for exact case match
755
+ if suggestion.startswith(current_word):
756
+ score += 20
757
+ else:
758
+ score += 15
759
+
760
+ # Add to scored list
761
+ scored_suggestions.append((suggestion, score))
762
+
763
+ # Sort by score (descending) and return just the suggestions
764
+ scored_suggestions.sort(key=lambda x: x[1], reverse=True)
765
+ return [suggestion for suggestion, _ in scored_suggestions]