rails-nl2sql 0.1.4 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c8acace727f6a1308018979554b87e003ca3ce427600a2c1822c4c3bcdbbeaeb
4
- data.tar.gz: 33cfaabfe6e5f069d78b274481a3c7a27d20a5bd634d37160f56fbc1383277a0
3
+ metadata.gz: 4c85af69e2b9d5c60665676f6b1aa853220bb44e98484bd0968b9e659528fa31
4
+ data.tar.gz: 8423aaec96729c329dad2fe601ba11ef84545054f84ab1f3e0fc32ac181a7a8f
5
5
  SHA512:
6
- metadata.gz: d69f8f85025572b8aed9d0c991fc25108964a1a7dbc27c7e6f858f1d0c7a5c9775264e68b896a80dcb430f2df51447fe7cb1ad4aa4e46cb20fc912934e8f2512
7
- data.tar.gz: e4f40efeec6ac18ce06d4d2875bec8aef20f1dae7201962b3050f0b29dcc0e736f45c5c583752b7eec55e1f39056688f86ece31ae94790f88a8a6e9dedab4459
6
+ metadata.gz: 91739a9e2c4456f0232da6f614a5f9b7647b77267bc8d9eb49b3bb36a7d64542cca9ebf5e02897424bd8dc68b20a71a2d336ffd7eededd3b4b05cc2879d25fd4
7
+ data.tar.gz: 4e4540b9939bbb22bb093cf15482215016817b5cee95dbefaf6713cce048e5464136c62aeef0116659f378f020443e409b92d927bd4755a22c2c3829d6c9d022
@@ -3,23 +3,178 @@ require "openai"
3
3
  module Rails
4
4
  module Nl2sql
5
5
  class QueryGenerator
6
- def initialize(api_key, model = "text-davinci-003")
7
- @client = OpenAI::Client.new(api_key: api_key)
6
+ def initialize(api_key, model = "gpt-3.5-turbo-instruct")
7
+ @client = OpenAI::Client.new(access_token: api_key)
8
8
  @model = model
9
9
  end
10
10
 
11
- def generate_query(prompt, schema)
12
- full_prompt = "Given the following schema:\n\n#{schema}\n\nGenerate a SQL query for the following request:\n\n#{prompt}"
11
+ def generate_query(prompt, schema, db_server = "PostgreSQL", tables = nil)
12
+ retrieved_context = build_context(schema, tables)
13
+
14
+ system_prompt = build_system_prompt(db_server, retrieved_context)
15
+ user_prompt = build_user_prompt(prompt)
16
+
17
+ full_prompt = "#{system_prompt}\n\n#{user_prompt}"
13
18
 
14
19
  response = @client.completions(
15
20
  parameters: {
16
21
  model: @model,
17
22
  prompt: full_prompt,
18
- max_tokens: 150
23
+ max_tokens: 300,
24
+ temperature: 0.1
19
25
  }
20
26
  )
21
27
 
22
- response.choices.first.text.strip
28
+ generated_query = response.dig("choices", 0, "text")&.strip
29
+
30
+ # Clean up the response to remove markdown formatting
31
+ generated_query = clean_sql_response(generated_query)
32
+
33
+ # Safety check
34
+ validate_query_safety(generated_query)
35
+
36
+ generated_query
37
+ end
38
+
39
+ private
40
+
41
+ def build_context(schema, tables)
42
+ if tables&.any?
43
+ # Filter schema to only include requested tables
44
+ filtered_schema = filter_schema_by_tables(schema, tables)
45
+ filtered_schema
46
+ else
47
+ schema
48
+ end
49
+ end
50
+
51
+ def filter_schema_by_tables(schema, tables)
52
+ # Simple filtering - in a real implementation, this would be more sophisticated
53
+ lines = schema.split("\n")
54
+ filtered_lines = []
55
+ current_table = nil
56
+ include_current = false
57
+
58
+ lines.each do |line|
59
+ if line.match(/CREATE TABLE (\w+)/)
60
+ current_table = $1
61
+ include_current = tables.include?(current_table)
62
+ end
63
+
64
+ if include_current || line.strip.empty?
65
+ filtered_lines << line
66
+ end
67
+ end
68
+
69
+ filtered_lines.join("\n")
70
+ end
71
+
72
+ def build_system_prompt(db_server, retrieved_context)
73
+ <<~PROMPT
74
+ You are an expert SQL assistant specializing in generating dynamic queries based on natural language.
75
+ Your primary goal is to generate **correct, safe, and executable #{db_server} SQL queries** based on user questions.
76
+
77
+ ---
78
+ **DATABASE CONTEXT (SCHEMA):**
79
+ You are provided with relevant schema details from the database, retrieved to help you.
80
+ **STRICTLY adhere to this provided schema context.** Do not use any tables or columns not explicitly listed here.
81
+ #{retrieved_context}
82
+
83
+ ---
84
+ **SQL GENERATION RULES:**
85
+ 1. **SQL Dialect:** All generated SQL must be valid **#{db_server} syntax**.
86
+ * For limiting results, use `LIMIT` (e.g., `LIMIT 10`) instead of `TOP`.
87
+ * Be mindful of #{db_server}'s specific function names (e.g., `COUNT(*)`, `MAX()`) and behaviors.
88
+ * For subqueries that return a single value to be used in a `WHERE` clause, ensure they are correctly formatted for #{db_server}.
89
+ 2. **Schema Adherence:** Only use table names and column names that are explicitly present in the provided context. Do not invent names.
90
+ 3. **Valid JOIN Paths:** All `JOIN` operations must be based on valid foreign key relationships. The provided schema context explicitly details many of these.
91
+ 4. **Safety First:** Absolutely **DO NOT** generate any DDL (CREATE, ALTER, DROP) or DML (INSERT, UPDATE, DELETE) statements. Only `SELECT` queries are permitted.
92
+ 5. **CRITICAL: Handling Missing/Empty Text Data:**
93
+ * When a user asks about "missing," "no," "empty," or "null" values for a TEXT column (like 'email', 'phone', 'address', 'company', 'fax'), generate a `WHERE` clause that explicitly checks for **both `IS NULL` and `= ''` (an empty string)**.
94
+ * **Example:** To find agents with no email, the query should be `SELECT first_name, last_name FROM agents WHERE email IS NULL OR email = '';`
95
+ * This is essential
96
+ 6. **Ambiguity:** If a user question is ambiguous or requires more information to form a precise SQL query, clearly state that you need clarification and ask for more details. Do not guess.
97
+
98
+ **RESPOND WITH ONLY THE SQL QUERY - NO EXPLANATIONS, NO MARKDOWN FORMATTING, NO CODE BLOCKS, NO ADDITIONAL TEXT.**
99
+ PROMPT
100
+ end
101
+
102
+ def build_user_prompt(input)
103
+ <<~PROMPT
104
+ Here is the **USER QUESTION:** Respond with a thoughtful process that leads to a SQL query, using the tools as necessary.
105
+ "#{input}"
106
+ PROMPT
107
+ end
108
+
109
+ def clean_sql_response(query)
110
+ return query unless query
111
+
112
+ # Remove markdown code blocks
113
+ query = query.gsub(/```sql\n?/, '')
114
+ query = query.gsub(/```\n?/, '')
115
+
116
+ # Remove any leading/trailing whitespace
117
+ query = query.strip
118
+
119
+ # Remove any explanatory text before or after the query
120
+ # Look for common patterns like "Here's the SQL query:" or "The query is:"
121
+ query = query.gsub(/^.*?(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)/i, '\1')
122
+
123
+ # Remove any trailing explanatory text after the query
124
+ # Split by newlines and take only the SQL part
125
+ lines = query.split("\n")
126
+ sql_lines = []
127
+
128
+ lines.each do |line|
129
+ line = line.strip
130
+ # Skip empty lines or lines that look like explanations
131
+ next if line.empty?
132
+ next if line.match(/^(here|this|the query|explanation|note)/i)
133
+
134
+ sql_lines << line
135
+ end
136
+
137
+ # Rejoin the SQL lines
138
+ cleaned_query = sql_lines.join("\n").strip
139
+
140
+ # Ensure it ends with a semicolon if it's a complete query
141
+ if cleaned_query.match(/^(SELECT|WITH)/i) && !cleaned_query.end_with?(';')
142
+ cleaned_query += ';'
143
+ end
144
+
145
+ cleaned_query
146
+ end
147
+
148
+ def validate_query_safety(query)
149
+ return unless query
150
+
151
+ banned_keywords = [
152
+ "delete", "drop", "truncate", "update", "insert", "alter",
153
+ "exec", "execute", "create", "merge", "replace", "into"
154
+ ]
155
+
156
+ banned_phrases = [
157
+ "ignore previous instructions", "pretend you are", "i am the admin",
158
+ "you are no longer bound", "bypass the rules", "run this instead",
159
+ "for testing, run", "no safety constraints", "show me a dangerous query",
160
+ "this is a dev environment", "drop all data", "delete all users", "wipe the database"
161
+ ]
162
+
163
+ query_lower = query.downcase
164
+
165
+ # Check for banned keywords
166
+ banned_keywords.each do |keyword|
167
+ if query_lower.include?(keyword)
168
+ raise Rails::Nl2sql::Error, "Query contains banned keyword: #{keyword}"
169
+ end
170
+ end
171
+
172
+ # Check for banned phrases
173
+ banned_phrases.each do |phrase|
174
+ if query_lower.include?(phrase)
175
+ raise Rails::Nl2sql::Error, "Query contains banned phrase: #{phrase}"
176
+ end
177
+ end
23
178
  end
24
179
  end
25
180
  end
@@ -2,17 +2,36 @@ module Rails
2
2
  module Nl2sql
3
3
  class QueryValidator
4
4
  def self.validate(query)
5
+ return false unless query && !query.strip.empty?
6
+
7
+ # Clean the query first
8
+ query = query.strip
9
+
10
+ # Check if query is malformed (contains markdown or other formatting)
11
+ if query.include?('```') || query.include?('```sql')
12
+ raise Rails::Nl2sql::Error, "Query contains markdown formatting and could not be cleaned properly"
13
+ end
14
+
5
15
  # Basic validation: prevent destructive commands
6
- disallowed_keywords = %w(DROP DELETE UPDATE INSERT TRUNCATE ALTER CREATE)
7
- if disallowed_keywords.any? { |keyword| query.upcase.include?(keyword) }
8
- raise "Query contains disallowed keywords."
16
+ disallowed_keywords = %w(DROP DELETE UPDATE INSERT TRUNCATE ALTER CREATE EXEC EXECUTE MERGE REPLACE)
17
+ query_upper = query.upcase
18
+
19
+ if disallowed_keywords.any? { |keyword| query_upper.include?(keyword) }
20
+ raise Rails::Nl2sql::Error, "Query contains disallowed keywords."
21
+ end
22
+
23
+ # Ensure it's a SELECT query
24
+ unless query_upper.strip.start_with?('SELECT', 'WITH')
25
+ raise Rails::Nl2sql::Error, "Only SELECT queries are allowed."
9
26
  end
10
27
 
11
- # Use Rails' built-in sanitization to be safe
28
+ # Use Rails' built-in validation with EXPLAIN
12
29
  begin
13
- ActiveRecord::Base.connection.execute("EXPLAIN #{query}")
30
+ # Remove trailing semicolon for EXPLAIN
31
+ explain_query = query.gsub(/;\s*$/, '')
32
+ ActiveRecord::Base.connection.execute("EXPLAIN #{explain_query}")
14
33
  rescue ActiveRecord::StatementInvalid => e
15
- raise "Invalid SQL query: #{e.message}"
34
+ raise Rails::Nl2sql::Error, "Invalid SQL query: #{e.message}"
16
35
  end
17
36
 
18
37
  true
@@ -2,10 +2,157 @@ module Rails
2
2
  module Nl2sql
3
3
  class SchemaBuilder
4
4
  def self.build_schema(options = {})
5
- tables = ActiveRecord::Base.connection.tables
6
- tables -= options[:exclude] if options[:exclude]
7
- tables = options[:include] if options[:include]
5
+ tables = get_filtered_tables(options)
6
+
7
+ schema_text = build_schema_text(tables)
8
+ schema_text
9
+ end
10
+
11
+ def self.get_database_type
12
+ adapter = ActiveRecord::Base.connection.adapter_name.downcase
13
+ case adapter
14
+ when 'postgresql'
15
+ 'PostgreSQL'
16
+ when 'mysql', 'mysql2'
17
+ 'MySQL'
18
+ when 'sqlite3'
19
+ 'SQLite'
20
+ when 'oracle'
21
+ 'Oracle'
22
+ when 'sqlserver'
23
+ 'SQL Server'
24
+ else
25
+ 'PostgreSQL' # Default fallback
26
+ end
27
+ end
28
+
29
+ def self.get_filtered_tables(options = {})
30
+ all_tables = ActiveRecord::Base.connection.tables
31
+
32
+ # Remove system tables
33
+ all_tables.reject! { |table| system_table?(table) }
34
+
35
+ # Apply filtering options
36
+ if options[:exclude]
37
+ all_tables -= options[:exclude]
38
+ end
39
+
40
+ if options[:include]
41
+ all_tables = options[:include] & all_tables
42
+ end
43
+
44
+ all_tables
45
+ end
46
+
47
+ def self.build_schema_text(tables)
48
+ schema_parts = []
49
+
50
+ tables.each do |table|
51
+ schema_parts << build_table_schema(table)
52
+ end
53
+
54
+ schema_parts.join("\n\n")
55
+ end
56
+
57
+ def self.build_table_schema(table)
58
+ columns = ActiveRecord::Base.connection.columns(table)
59
+
60
+ schema = "CREATE TABLE #{table} (\n"
61
+
62
+ column_definitions = columns.map do |column|
63
+ type_info = get_column_type_info(column)
64
+ nullable = column.null ? "" : " NOT NULL"
65
+ default = column.default ? " DEFAULT #{column.default}" : ""
66
+
67
+ " #{column.name} #{type_info}#{nullable}#{default}"
68
+ end
69
+
70
+ schema += column_definitions.join(",\n")
71
+ schema += "\n);"
72
+
73
+ # Add indexes and foreign keys if available
74
+ indexes = get_table_indexes(table)
75
+ if indexes.any?
76
+ schema += "\n\n-- Indexes for #{table}:"
77
+ indexes.each do |index|
78
+ schema += "\n-- #{index[:type]}: #{index[:columns].join(', ')}"
79
+ end
80
+ end
81
+
82
+ schema
83
+ end
84
+
85
+ def self.get_column_type_info(column)
86
+ case column.type
87
+ when :string
88
+ "VARCHAR(#{column.limit || 255})"
89
+ when :text
90
+ "TEXT"
91
+ when :integer
92
+ "INTEGER"
93
+ when :bigint
94
+ "BIGINT"
95
+ when :float
96
+ "FLOAT"
97
+ when :decimal
98
+ precision = column.precision || 10
99
+ scale = column.scale || 0
100
+ "DECIMAL(#{precision},#{scale})"
101
+ when :datetime
102
+ "TIMESTAMP"
103
+ when :date
104
+ "DATE"
105
+ when :time
106
+ "TIME"
107
+ when :boolean
108
+ "BOOLEAN"
109
+ when :json
110
+ "JSON"
111
+ else
112
+ column.sql_type || "TEXT"
113
+ end
114
+ end
115
+
116
+ def self.get_table_indexes(table)
117
+ indexes = []
118
+
119
+ begin
120
+ connection = ActiveRecord::Base.connection
121
+ if connection.respond_to?(:indexes)
122
+ table_indexes = connection.indexes(table)
123
+ table_indexes.each do |index|
124
+ indexes << {
125
+ type: index.unique? ? "UNIQUE INDEX" : "INDEX",
126
+ columns: index.columns,
127
+ name: index.name
128
+ }
129
+ end
130
+ end
131
+ rescue => e
132
+ # Skip if indexes can't be retrieved
133
+ end
134
+
135
+ indexes
136
+ end
137
+
138
+ def self.system_table?(table)
139
+ system_tables = [
140
+ 'schema_migrations',
141
+ 'ar_internal_metadata',
142
+ 'sqlite_sequence',
143
+ 'information_schema',
144
+ 'performance_schema',
145
+ 'mysql',
146
+ 'sys'
147
+ ]
148
+
149
+ system_tables.any? { |sys_table| table.include?(sys_table) }
150
+ end
8
151
 
152
+ # Legacy method for backward compatibility
153
+ def self.build_hash_schema(options = {})
154
+ tables = get_filtered_tables(options)
155
+
9
156
  schema = {}
10
157
  tables.each do |table|
11
158
  schema[table] = ActiveRecord::Base.connection.columns(table).map(&:name)
@@ -1,5 +1,5 @@
1
1
  module Rails
2
2
  module Nl2sql
3
- VERSION = "0.1.4"
3
+ VERSION = "0.1.8"
4
4
  end
5
5
  end
data/lib/rails/nl2sql.rb CHANGED
@@ -12,7 +12,7 @@ module Rails
12
12
  attr_accessor :api_key
13
13
  attr_accessor :model
14
14
  end
15
- @@model = "text-davinci-003"
15
+ @@model = "gpt-3.5-turbo-instruct"
16
16
 
17
17
  def self.configure
18
18
  yield self
@@ -20,22 +20,67 @@ module Rails
20
20
 
21
21
  class Processor
22
22
  def self.execute(natural_language_query, options = {})
23
+ # Get database type
24
+ db_server = SchemaBuilder.get_database_type
25
+
26
+ # Build schema with optional table filtering
23
27
  schema = SchemaBuilder.build_schema(options)
28
+
29
+ # Extract tables for filtering if specified
30
+ tables = options[:tables]
31
+
32
+ # Generate query with enhanced prompt
24
33
  query_generator = QueryGenerator.new(Rails::Nl2sql.api_key, Rails::Nl2sql.model)
25
- generated_query = query_generator.generate_query(natural_language_query, schema)
34
+ generated_query = query_generator.generate_query(
35
+ natural_language_query,
36
+ schema,
37
+ db_server,
38
+ tables
39
+ )
26
40
 
41
+ # Validate the generated query
27
42
  QueryValidator.validate(generated_query)
28
43
 
44
+ # Execute the query
29
45
  ActiveRecord::Base.connection.execute(generated_query)
30
46
  end
31
47
 
48
+ def self.generate_query_only(natural_language_query, options = {})
49
+ # Get database type
50
+ db_server = SchemaBuilder.get_database_type
51
+
52
+ # Build schema with optional table filtering
53
+ schema = SchemaBuilder.build_schema(options)
54
+
55
+ # Extract tables for filtering if specified
56
+ tables = options[:tables]
57
+
58
+ # Generate query with enhanced prompt
59
+ query_generator = QueryGenerator.new(Rails::Nl2sql.api_key, Rails::Nl2sql.model)
60
+ generated_query = query_generator.generate_query(
61
+ natural_language_query,
62
+ schema,
63
+ db_server,
64
+ tables
65
+ )
66
+
67
+ # Validate the generated query
68
+ QueryValidator.validate(generated_query)
69
+
70
+ generated_query
71
+ end
72
+
32
73
  def self.get_tables(options = {})
33
- ActiveRecord::Base.connection.tables
74
+ SchemaBuilder.get_filtered_tables(options)
34
75
  end
35
76
 
36
77
  def self.get_schema(options = {})
37
78
  SchemaBuilder.build_schema(options)
38
79
  end
80
+
81
+ def self.get_database_type
82
+ SchemaBuilder.get_database_type
83
+ end
39
84
  end
40
85
  end
41
86
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rails-nl2sql
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.4
4
+ version: 0.1.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - Russell Van Curen