sql-redis 0.1.2__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sql_redis-0.1.2 → sql_redis-0.2.0}/PKG-INFO +1 -1
- {sql_redis-0.1.2 → sql_redis-0.2.0}/pyproject.toml +1 -1
- sql_redis-0.2.0/sql_redis/__init__.py +17 -0
- sql_redis-0.2.0/sql_redis/executor.py +228 -0
- sql_redis-0.2.0/sql_redis/schema.py +215 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/translator.py +5 -3
- {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/version.py +1 -1
- sql_redis-0.2.0/tests/test_async_executor.py +344 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_schema_registry.py +5 -9
- sql_redis-0.1.2/sql_redis/__init__.py +0 -6
- sql_redis-0.1.2/sql_redis/executor.py +0 -155
- sql_redis-0.1.2/sql_redis/schema.py +0 -142
- {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/lint.yml +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/release.yml +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/test.yml +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/.gitignore +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/.pre-commit-config.yaml +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/Makefile +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/README.md +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/analyzer.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/parser.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/query_builder.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/__init__.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/conftest.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_analyzer.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_executor.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_parameter_substitution.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_query_builder.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_redis_queries.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_sql_parser.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_sql_queries.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_translator.py +0 -0
- {sql_redis-0.1.2 → sql_redis-0.2.0}/uv.lock +0 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""SQL to Redis command translation utility."""
|
|
2
|
+
|
|
3
|
+
from sql_redis.executor import AsyncExecutor, Executor, QueryResult
|
|
4
|
+
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
|
|
5
|
+
from sql_redis.translator import TranslatedQuery, Translator
|
|
6
|
+
from sql_redis.version import __version__
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"Translator",
|
|
10
|
+
"TranslatedQuery",
|
|
11
|
+
"SchemaRegistry",
|
|
12
|
+
"AsyncSchemaRegistry",
|
|
13
|
+
"Executor",
|
|
14
|
+
"AsyncExecutor",
|
|
15
|
+
"QueryResult",
|
|
16
|
+
"__version__",
|
|
17
|
+
]
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""SQL Executor - executes translated queries against Redis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
import redis
|
|
10
|
+
|
|
11
|
+
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
|
|
12
|
+
from sql_redis.translator import Translator
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import redis.asyncio as async_redis
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _substitute_params(sql: str, params: dict[str, Any]) -> str:
|
|
19
|
+
"""Substitute parameter placeholders in SQL with actual values.
|
|
20
|
+
|
|
21
|
+
This is a pure function with no I/O operations, shared by both
|
|
22
|
+
sync and async executors.
|
|
23
|
+
|
|
24
|
+
Uses token-based approach: splits SQL on :param patterns, then rebuilds
|
|
25
|
+
with substituted values. This approach solves two critical bugs:
|
|
26
|
+
|
|
27
|
+
1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
|
|
28
|
+
by treating each :identifier as a complete token
|
|
29
|
+
|
|
30
|
+
2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
|
|
31
|
+
using SQL standard (single quote -> double single quote)
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
sql: The SQL string with :param placeholders.
|
|
35
|
+
params: Dictionary mapping parameter names to values.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
SQL string with parameters substituted.
|
|
39
|
+
|
|
40
|
+
Implementation Details:
|
|
41
|
+
- Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
|
|
42
|
+
- Keeps delimiters (the :param tokens) in the split result
|
|
43
|
+
- Iterates through tokens, substituting matched parameters
|
|
44
|
+
- String values are wrapped in single quotes with proper escaping
|
|
45
|
+
- Numeric values are converted to strings
|
|
46
|
+
- Bytes values (e.g., vectors) are NOT substituted here
|
|
47
|
+
|
|
48
|
+
Known Limitations:
|
|
49
|
+
- Colons in string literals: SQL like "WHERE x = 'test:value'" would
|
|
50
|
+
theoretically match :value as a parameter. However, this is not a
|
|
51
|
+
practical issue because:
|
|
52
|
+
1. Users pass values via parameters, not hardcoded in SQL
|
|
53
|
+
2. The translator has its own handling of string literals
|
|
54
|
+
3. No real-world use cases have been identified
|
|
55
|
+
- Parameter names are case-sensitive (:id != :ID)
|
|
56
|
+
- Only handles int, float, str types; other types keep placeholder
|
|
57
|
+
"""
|
|
58
|
+
if not params:
|
|
59
|
+
return sql
|
|
60
|
+
|
|
61
|
+
# Split SQL on :param patterns, keeping the delimiters
|
|
62
|
+
# Pattern matches : followed by valid identifier:
|
|
63
|
+
# [a-zA-Z_] - First char must be letter or underscore
|
|
64
|
+
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
|
|
65
|
+
# This prevents partial matching: :id and :product_id are separate tokens
|
|
66
|
+
tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)
|
|
67
|
+
|
|
68
|
+
result = []
|
|
69
|
+
for token in tokens:
|
|
70
|
+
if token.startswith(":"):
|
|
71
|
+
# This is a parameter placeholder
|
|
72
|
+
key = token[1:] # Remove leading :
|
|
73
|
+
if key in params:
|
|
74
|
+
value = params[key]
|
|
75
|
+
if isinstance(value, (int, float)):
|
|
76
|
+
# Numeric values: convert to string
|
|
77
|
+
result.append(str(value))
|
|
78
|
+
elif isinstance(value, str):
|
|
79
|
+
# String values: wrap in quotes and escape single quotes
|
|
80
|
+
# SQL standard: ' -> '' (double single quote)
|
|
81
|
+
# This fixes the quote escaping bug
|
|
82
|
+
escaped = value.replace("'", "''")
|
|
83
|
+
result.append(f"'{escaped}'")
|
|
84
|
+
else:
|
|
85
|
+
# Other types (bytes, None, bool, list, etc.):
|
|
86
|
+
# Keep placeholder as-is (handled elsewhere or unsupported)
|
|
87
|
+
result.append(token)
|
|
88
|
+
else:
|
|
89
|
+
# Parameter not provided: keep placeholder as-is
|
|
90
|
+
result.append(token)
|
|
91
|
+
else:
|
|
92
|
+
# Not a parameter: keep as-is
|
|
93
|
+
result.append(token)
|
|
94
|
+
|
|
95
|
+
return "".join(result)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class QueryResult:
    """Container for the outcome of one executed SQL query.

    Attributes:
        rows: One dictionary per returned row, mapping field name to value.
        count: The count reported for the query.
    """

    # Row dictionaries, in the order they were returned.
    rows: list[dict]
    # Count reported alongside the rows.
    count: int
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class Executor:
    """Executes SQL queries against Redis."""

    def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry) -> None:
        """Initialize executor with Redis client and schema registry."""
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Values for the placeholders. The first ``bytes`` value
                found (if any) replaces every ``$vector`` argument of the
                translated command; other values are substituted into the
                SQL text before translation.

        Returns:
            A QueryResult holding one dict per row and the count taken from
            the first element of the Redis reply (0 for an empty reply).
        """
        params = params or {}

        # Substitute non-bytes params in SQL using token-based approach.
        sql = _substitute_params(sql, params)

        # Translate SQL to a Redis command.
        translated = self._translator.translate(sql)

        # list[str | bytes] because a vector argument is sent as raw bytes.
        cmd: list[str | bytes] = list(translated.to_command_list())

        # Only the first bytes param is used as the vector.
        vector_param: bytes | None = next(
            (value for value in params.values() if isinstance(value, bytes)),
            None,
        )
        # Compare against None so an empty bytes value is still substituted.
        if vector_param is not None:
            cmd = [vector_param if arg == "$vector" else arg for arg in cmd]

        # Normalise a missing/empty reply so the parsing below cannot fail.
        raw_result = self._client.execute_command(*cmd) or []
        count = raw_result[0] if raw_result else 0

        if translated.command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            # Skip the count and document keys; keep only the field lists.
            field_lists = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            field_lists = raw_result[1:]

        # Each field list is flat: [name1, value1, name2, value2, ...].
        rows = [dict(zip(fields[::2], fields[1::2])) for fields in field_lists]

        return QueryResult(rows=rows, count=count)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class AsyncExecutor:
    """Async version of Executor for use with redis.asyncio clients."""

    def __init__(
        self,
        client: "async_redis.Redis",
        schema_registry: AsyncSchemaRegistry,
    ) -> None:
        """Initialize async executor with Redis client and schema registry.

        Args:
            client: An async Redis client (redis.asyncio.Redis).
            schema_registry: An AsyncSchemaRegistry instance.
        """
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query asynchronously and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Values for the placeholders. The first ``bytes`` value
                found (if any) replaces every ``$vector`` argument of the
                translated command; other values are substituted into the
                SQL text before translation.

        Returns:
            A QueryResult holding one dict per row and the count taken from
            the first element of the Redis reply (0 for an empty reply).
        """
        params = params or {}

        # Substitute non-bytes params in SQL.
        sql = _substitute_params(sql, params)

        # Translation is synchronous; it issues no Redis calls here.
        translated = self._translator.translate(sql)

        # list[str | bytes] because a vector argument is sent as raw bytes.
        cmd: list[str | bytes] = list(translated.to_command_list())

        # Only the first bytes param is used as the vector.
        vector_param: bytes | None = next(
            (value for value in params.values() if isinstance(value, bytes)),
            None,
        )
        # Compare against None so an empty bytes value is still substituted.
        if vector_param is not None:
            cmd = [vector_param if arg == "$vector" else arg for arg in cmd]

        # Normalise a missing/empty reply so the parsing below cannot fail.
        raw_result = (await self._client.execute_command(*cmd)) or []
        count = raw_result[0] if raw_result else 0

        if translated.command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            # Skip the count and document keys; keep only the field lists.
            field_lists = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            field_lists = raw_result[1:]

        # Each field list is flat: [name1, value1, name2, value2, ...].
        rows = [dict(zip(fields[::2], fields[1::2])) for fields in field_lists]

        return QueryResult(rows=rows, count=count)
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Schema registry for Redis search indexes."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Callable
|
|
6
|
+
|
|
7
|
+
import redis
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import redis.asyncio as async_redis
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _parse_schema_from_info(info: list) -> dict[str, str]:
|
|
14
|
+
"""Parse field types from FT.INFO response.
|
|
15
|
+
|
|
16
|
+
This is a pure function with no I/O operations, shared by both
|
|
17
|
+
sync and async schema registries.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
info: The raw response from FT.INFO command.
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
Dictionary mapping field names to their types (e.g., {"title": "TEXT"}).
|
|
24
|
+
"""
|
|
25
|
+
schema = {}
|
|
26
|
+
# Find the 'attributes' section in the info response
|
|
27
|
+
for i, item in enumerate(info):
|
|
28
|
+
# Handle bytes or string comparison
|
|
29
|
+
item_str = item.decode("utf-8") if isinstance(item, bytes) else item
|
|
30
|
+
if item_str == "attributes":
|
|
31
|
+
attributes = info[i + 1]
|
|
32
|
+
for attr in attributes:
|
|
33
|
+
field_name = None
|
|
34
|
+
field_type = None
|
|
35
|
+
# Each attribute is a list like:
|
|
36
|
+
# [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...]
|
|
37
|
+
for j, val in enumerate(attr):
|
|
38
|
+
val_str = val.decode("utf-8") if isinstance(val, bytes) else val
|
|
39
|
+
if val_str == "attribute" and j + 1 < len(attr):
|
|
40
|
+
fn = attr[j + 1]
|
|
41
|
+
field_name = fn.decode("utf-8") if isinstance(fn, bytes) else fn
|
|
42
|
+
if val_str == "type" and j + 1 < len(attr):
|
|
43
|
+
ft = attr[j + 1]
|
|
44
|
+
field_type = ft.decode("utf-8") if isinstance(ft, bytes) else ft
|
|
45
|
+
if field_name and field_type:
|
|
46
|
+
schema[field_name] = field_type
|
|
47
|
+
break
|
|
48
|
+
return schema
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SchemaRegistry:
    """Loads and caches index schemas from Redis.

    Index creation and deletion are detected by polling ``FT._LIST`` in
    ``process_pending_events()``; no keyspace notifications are used.
    """

    def __init__(self, redis_client: redis.Redis):
        self._client = redis_client
        # index name -> {field name -> field type}
        self._schemas: dict[str, dict[str, str]] = {}
        self._on_change: Callable[[str, str], None] | None = None
        self._watching = False

    def _list_index_names(self) -> list[str]:
        """Return the names reported by FT._LIST, always as ``str``."""
        return [
            name.decode("utf-8") if isinstance(name, bytes) else name
            for name in self._client.execute_command("FT._LIST")
        ]

    def load_all(self) -> None:
        """Load schemas for all indexes on the server.

        The index list is fetched before the cache is touched, so a failing
        ``FT._LIST`` leaves the previously cached schemas intact. Cached
        indexes that the server no longer lists are removed afterwards.
        """
        index_names = self._list_index_names()
        for index_name in index_names:
            self._load_index_schema(index_name)

        listed = set(index_names)
        for stale in [name for name in self._schemas if name not in listed]:
            del self._schemas[stale]

    def _load_index_schema(self, index_name: str) -> None:
        """Load schema for a single index."""
        try:
            info = self._client.execute_command("FT.INFO", index_name)
            schema = _parse_schema_from_info(info)
            self._schemas[index_name] = schema
        except redis.ResponseError:
            # Index doesn't exist or was deleted
            self._schemas.pop(index_name, None)

    def get_field_type(self, index: str, field: str) -> str | None:
        """Get field type for a given index and field.

        Returns None if index or field is unknown.
        """
        schema = self._schemas.get(index, {})
        return schema.get(field)

    def get_schema(self, index: str) -> dict[str, str]:
        """Get full schema for an index.

        Returns empty dict if index is unknown.
        """
        return self._schemas.get(index, {})

    def refresh(self, index_name: str) -> None:
        """Refresh schema for a single index.

        If the index no longer exists, removes it from the registry.
        If the index is new, adds it to the registry.
        """
        self._load_index_schema(index_name)

    def start_watching(
        self, on_change: Callable[[str, str], None] | None = None
    ) -> None:
        """Start watching for index changes.

        This only enables change detection; the actual polling happens each
        time ``process_pending_events()`` is called.

        Args:
            on_change: Optional callback invoked with (event_type, index_name).
                The events emitted are "created" and "dropped".
        """
        self._on_change = on_change
        self._watching = True

    def stop_watching(self) -> None:
        """Stop watching for index changes."""
        self._watching = False
        self._on_change = None

    def process_pending_events(self) -> None:
        """Process any pending index change events.

        Polls FT._LIST and compares it with the cached index names to detect
        new and deleted indexes. Call this periodically. Does nothing unless
        ``start_watching()`` has been called.
        """
        if not self._watching:
            return

        # Names are decoded so they compare equal to the (str) cache keys
        # even when the client does not decode responses.
        current_indexes = set(self._list_index_names())
        cached_indexes = set(self._schemas.keys())

        # Detect new indexes (sorted for a deterministic callback order).
        for idx in sorted(current_indexes - cached_indexes):
            self._load_index_schema(idx)
            if self._on_change:
                self._on_change("created", idx)

        # Detect deleted indexes.
        for idx in sorted(cached_indexes - current_indexes):
            self._schemas.pop(idx, None)
            if self._on_change:
                self._on_change("dropped", idx)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class AsyncSchemaRegistry:
    """Async version of SchemaRegistry for use with redis.asyncio clients.

    Loads and caches index schemas from Redis asynchronously.
    """

    def __init__(self, redis_client: "async_redis.Redis") -> None:
        """Initialize with an async Redis client.

        Args:
            redis_client: An async Redis client (redis.asyncio.Redis).
        """
        self._client = redis_client
        # index name -> {field name -> field type}
        self._schemas: dict[str, dict[str, str]] = {}

    async def load_all(self) -> None:
        """Load schemas for all indexes on the server.

        Uses asyncio.gather() to load all index schemas concurrently.

        The cache is not cleared up front: entries are replaced as they are
        reloaded and stale ones are pruned at the end. Other coroutines that
        read the registry while this method is awaiting therefore never see
        an empty cache, and a failing ``FT._LIST`` leaves the cache intact.
        """
        import asyncio

        indexes = await self._client.execute_command("FT._LIST")
        # Decode bytes to strings
        decoded_indexes = [
            idx.decode("utf-8") if isinstance(idx, bytes) else idx for idx in indexes
        ]
        # Load all schemas concurrently
        await asyncio.gather(
            *[self._load_index_schema(name) for name in decoded_indexes]
        )

        # Drop cached indexes that the server no longer lists.
        listed = set(decoded_indexes)
        for stale in [name for name in self._schemas if name not in listed]:
            del self._schemas[stale]

    async def _load_index_schema(self, index_name: str) -> None:
        """Load schema for a single index."""
        try:
            info = await self._client.execute_command("FT.INFO", index_name)
            schema = _parse_schema_from_info(info)
            self._schemas[index_name] = schema
        except redis.ResponseError:
            # Index doesn't exist or was deleted
            self._schemas.pop(index_name, None)

    def get_field_type(self, index: str, field: str) -> str | None:
        """Get field type for a given index and field.

        Returns None if index or field is unknown.
        """
        schema = self._schemas.get(index, {})
        return schema.get(field)

    def get_schema(self, index: str) -> dict[str, str]:
        """Get full schema for an index.

        Returns empty dict if index is unknown.
        """
        return self._schemas.get(index, {})

    async def refresh(self, index_name: str) -> None:
        """Refresh schema for a single index.

        Reloads the index from Redis; if Redis reports an error for it, the
        index is removed from the registry.
        """
        await self._load_index_schema(index_name)
|
|
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
|
|
|
7
7
|
from sql_redis.analyzer import AnalyzedQuery, Analyzer
|
|
8
8
|
from sql_redis.parser import Condition, ParsedQuery, SQLParser
|
|
9
9
|
from sql_redis.query_builder import QueryBuilder
|
|
10
|
-
from sql_redis.schema import SchemaRegistry
|
|
10
|
+
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
@dataclass
|
|
@@ -34,11 +34,13 @@ class TranslatedQuery:
|
|
|
34
34
|
class Translator:
|
|
35
35
|
"""Translates SQL queries to Redis FT.SEARCH/FT.AGGREGATE commands."""
|
|
36
36
|
|
|
37
|
-
def __init__(self, schema_registry: SchemaRegistry):
|
|
37
|
+
def __init__(self, schema_registry: SchemaRegistry | AsyncSchemaRegistry) -> None:
|
|
38
38
|
"""Initialize translator with schema registry.
|
|
39
39
|
|
|
40
40
|
Args:
|
|
41
|
-
schema_registry: Registry containing index schemas.
|
|
41
|
+
schema_registry: Registry containing index schemas. Can be either
|
|
42
|
+
sync (SchemaRegistry) or async (AsyncSchemaRegistry) - only
|
|
43
|
+
the sync get_schema() method is used.
|
|
42
44
|
"""
|
|
43
45
|
self._schema_registry = schema_registry
|
|
44
46
|
self._parser = SQLParser()
|
|
@@ -2,7 +2,7 @@ try:
|
|
|
2
2
|
from importlib.metadata import PackageNotFoundError, version
|
|
3
3
|
except ImportError:
|
|
4
4
|
# Python < 3.8 fallback
|
|
5
|
-
from importlib_metadata import PackageNotFoundError, version # type: ignore
|
|
5
|
+
from importlib_metadata import PackageNotFoundError, version # type: ignore # isort: skip
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
8
|
__version__ = version("sql-redis")
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""Integration tests for async SQL executor.
|
|
2
|
+
|
|
3
|
+
TDD: These tests define the expected behavior for AsyncSchemaRegistry and AsyncExecutor.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import struct
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import redis.asyncio as async_redis
|
|
10
|
+
from testcontainers.redis import RedisContainer
|
|
11
|
+
|
|
12
|
+
from sql_redis.executor import AsyncExecutor, QueryResult
|
|
13
|
+
from sql_redis.schema import AsyncSchemaRegistry
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture(scope="module")
def redis_container():
    """Provide a Redis container shared by the tests in this module."""
    # The container is managed by the context manager around the yield.
    with RedisContainer("redis:8.0.2") as redis_server:
        yield redis_server
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
async def async_client(redis_container) -> async_redis.Redis:
    """Yield an async Redis client bound to the test container.

    Responses are decoded to ``str``; the client is closed after the test.
    """
    host = redis_container.get_container_host_ip()
    port = int(redis_container.get_exposed_port(6379))
    connection = async_redis.Redis(host=host, port=port, decode_responses=True)
    yield connection
    await connection.aclose()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
async def products_index(async_client: async_redis.Redis) -> str:
    """Create a products index with test data.

    Builds the ``async_products`` index over hashes prefixed with
    ``async_product:``, inserts four products, yields the index name, and
    drops the index (and its documents) afterwards.
    """
    index_name = "async_products"

    async def drop_index() -> None:
        # A missing index is reported as an error reply, which is expected
        # here. Anything else (e.g. a connection failure) should surface
        # rather than be silently swallowed.
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except async_redis.ResponseError:
            pass

    # Start from a clean slate in case a previous run left the index behind.
    await drop_index()

    await async_client.execute_command(
        "FT.CREATE",
        index_name,
        "ON",
        "HASH",
        "PREFIX",
        "1",
        "async_product:",
        "SCHEMA",
        "title",
        "TEXT",
        "category",
        "TAG",
        "price",
        "NUMERIC",
        "stock",
        "NUMERIC",
    )

    # Add test data: (title, category, price, stock), keyed async_product:1..4
    products = [
        ("Laptop Pro", "electronics", "999.99", "10"),
        ("Wireless Mouse", "electronics", "29.99", "50"),
        ("Python Book", "books", "49.99", "25"),
        ("Redis Guide", "books", "39.99", "15"),
    ]
    for doc_id, (title, category, price, stock) in enumerate(products, start=1):
        await async_client.hset(
            f"async_product:{doc_id}",
            mapping={
                "title": title,
                "category": category,
                "price": price,
                "stock": stock,
            },
        )

    yield index_name

    # Cleanup
    await drop_index()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.fixture
async def async_executor(
    async_client: async_redis.Redis, products_index: str
) -> AsyncExecutor:
    """Build an AsyncExecutor whose registry has been loaded from Redis.

    Depends on ``products_index`` so the index exists before schemas load.
    """
    schema_registry = AsyncSchemaRegistry(async_client)
    await schema_registry.load_all()
    executor = AsyncExecutor(async_client, schema_registry)
    return executor
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class TestAsyncSchemaRegistry:
    """Behaviour of AsyncSchemaRegistry against a live Redis."""

    async def test_load_all_loads_indexes(
        self, async_client: async_redis.Redis, products_index: str
    ):
        """After load_all(), the products index schema is available."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()

        schema = registry.get_schema(products_index)
        assert schema is not None

        # Each expected field must be present with the declared type.
        expected_types = (
            ("title", "TEXT"),
            ("category", "TAG"),
            ("price", "NUMERIC"),
        )
        for field_name, field_type in expected_types:
            assert field_name in schema
            assert schema[field_name] == field_type

    async def test_get_schema_returns_empty_for_unknown(
        self, async_client: async_redis.Redis
    ):
        """An index the registry has never seen yields an empty dict."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()

        unknown_schema = registry.get_schema("nonexistent_index")
        assert unknown_schema == {}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class TestAsyncExecutorBasic:
    """Basic SELECT queries run through AsyncExecutor."""

    async def test_select_all(self, async_executor: AsyncExecutor, products_index: str):
        """An unfiltered SELECT * returns every document."""
        query = f"SELECT * FROM {products_index}"
        result = await async_executor.execute(query)
        assert result.count == 4
        assert len(result.rows) == 4

    async def test_result_is_query_result(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """execute() hands back a QueryResult exposing rows and count."""
        query = f"SELECT * FROM {products_index}"
        result = await async_executor.execute(query)
        assert isinstance(result, QueryResult)
        assert hasattr(result, "rows")
        assert hasattr(result, "count")

    async def test_select_with_tag_filter(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Equality on a TAG field keeps only matching documents."""
        query = f"SELECT * FROM {products_index} WHERE category = 'books'"
        result = await async_executor.execute(query)
        assert result.count == 2
        assert all(row["category"] == "books" for row in result.rows)

    async def test_select_with_numeric_filter(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """A numeric comparison keeps only rows under the bound."""
        query = f"SELECT * FROM {products_index} WHERE price < 50"
        result = await async_executor.execute(query)
        assert result.count >= 2
        assert all(float(row["price"]) < 50 for row in result.rows)

    async def test_select_with_limit(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """LIMIT caps the number of rows returned."""
        query = f"SELECT * FROM {products_index} LIMIT 2"
        result = await async_executor.execute(query)
        assert len(result.rows) == 2

    async def test_select_with_order_by(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """ORDER BY price DESC returns rows in descending price order."""
        query = f"SELECT * FROM {products_index} ORDER BY price DESC"
        result = await async_executor.execute(query)
        prices = [float(row["price"]) for row in result.rows]
        assert prices == sorted(prices, reverse=True)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class TestAsyncExecutorAggregation:
    """Tests for async aggregate query execution."""

    async def test_count_all(self, async_executor: AsyncExecutor, products_index: str):
        """SELECT COUNT(*) returns a single row holding the document count."""
        result = await async_executor.execute(f"SELECT COUNT(*) FROM {products_index}")
        assert len(result.rows) == 1
        row = result.rows[0]
        # The result column may be named after the expression or aliased.
        count_value = row.get("COUNT(*)", row.get("count", None))
        assert count_value is not None
        # The products fixture inserts exactly four documents; check the
        # value itself, not just that some value came back.
        assert float(count_value) == 4

    async def test_group_by_with_count(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with GROUP BY and COUNT yields one row per category."""
        result = await async_executor.execute(
            f"SELECT category, COUNT(*) as cnt FROM {products_index} GROUP BY category"
        )
        assert len(result.rows) == 2  # electronics and books
        categories = {row["category"] for row in result.rows}
        assert categories == {"electronics", "books"}
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class TestAsyncExecutorParams:
    """Tests for parameterized async execution."""

    async def test_numeric_param(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Execute with numeric parameter."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE price > :min_price",
            params={"min_price": 40},
        )
        # Two fixture products cost more than 40 (999.99 and 49.99). Without
        # this check the loop below would pass vacuously on an empty result.
        assert len(result.rows) == 2
        for row in result.rows:
            assert float(row["price"]) > 40

    async def test_string_param(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Execute with string parameter."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE category = :cat",
            params={"cat": "books"},
        )
        assert len(result.rows) == 2
        for row in result.rows:
            assert row["category"] == "books"
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class TestAsyncVectorSearch:
    """Tests for async vector search execution."""

    @pytest.fixture
    async def vector_index(self, async_client: async_redis.Redis) -> str:
        """Create a 4-dimensional HNSW/COSINE vector index with three documents.

        Yields the index name, then drops the index (and its documents) again.
        """
        index_name = "async_vectors"

        # Best-effort removal of a leftover index from a previous run; the
        # command fails when the index does not exist, which is fine here.
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except Exception:
            pass

        await async_client.execute_command(
            "FT.CREATE",
            index_name,
            "ON",
            "HASH",
            "PREFIX",
            "1",
            "async_vec:",
            "SCHEMA",
            "title",
            "TEXT",
            "embedding",
            "VECTOR",
            "HNSW",
            "6",
            "TYPE",
            "FLOAT32",
            "DIM",
            "4",
            "DISTANCE_METRIC",
            "COSINE",
        )

        def to_bytes(v):
            # Pack the floats as consecutive 32-bit values (FLOAT32 field).
            return struct.pack(f"{len(v)}f", *v)

        documents = {
            "async_vec:1": ("First", [0.1, 0.2, 0.3, 0.4]),
            "async_vec:2": ("Second", [0.5, 0.6, 0.7, 0.8]),
            "async_vec:3": ("Third", [0.9, 0.8, 0.7, 0.6]),
        }

        # Use a separate non-decode client for binary data
        pool_kwargs = async_client.connection_pool.connection_kwargs
        raw_client = async_redis.Redis(
            host=pool_kwargs["host"],
            port=pool_kwargs["port"],
            decode_responses=False,
        )
        try:
            for key, (title, vector) in documents.items():
                await raw_client.hset(
                    key,
                    mapping={"title": title, "embedding": to_bytes(vector)},
                )
        finally:
            # Close the extra connection even when seeding fails, so a broken
            # setup does not leak a client.
            await raw_client.aclose()

        yield index_name

        # Cleanup (best-effort, same reasoning as the initial drop).
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except Exception:
            pass

    async def test_vector_search_with_param(
        self, async_client: async_redis.Redis, vector_index: str
    ):
        """Vector search with vector parameter."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()
        executor = AsyncExecutor(async_client, registry)

        query_vector = struct.pack("4f", 0.1, 0.2, 0.3, 0.4)
        result = await executor.execute(
            f"SELECT title, vector_distance(embedding, :vec) AS score "
            f"FROM {vector_index} LIMIT 3",
            params={"vec": query_vector},
        )

        # Require at least one hit so an empty result fails as an assertion
        # rather than as an IndexError on rows[0] below.
        assert 1 <= len(result.rows) <= 3

        nearest = result.rows[0]
        # First result should be closest to query vector
        assert nearest["title"] == "First"
        # Verify vector distance score is returned
        assert "score" in nearest
        score = float(nearest["score"])
        assert score >= 0  # Distance should be non-negative
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import pytest
|
|
4
4
|
import redis
|
|
5
5
|
|
|
6
|
-
from sql_redis.schema import SchemaRegistry
|
|
6
|
+
from sql_redis.schema import SchemaRegistry, _parse_schema_from_info
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def _create_test_indexes(redis_client: redis.Redis) -> list[str]:
|
|
@@ -222,20 +222,16 @@ class TestSchemaRegistryEmptyServer:
|
|
|
222
222
|
class TestSchemaRegistryParsing:
|
|
223
223
|
"""Tests for schema parsing edge cases."""
|
|
224
224
|
|
|
225
|
-
def test_parse_schema_no_attributes_section(self
|
|
225
|
+
def test_parse_schema_no_attributes_section(self):
|
|
226
226
|
"""_parse_schema_from_info handles response without attributes."""
|
|
227
|
-
registry = SchemaRegistry(redis_client)
|
|
228
|
-
|
|
229
227
|
# FT.INFO response without 'attributes' key
|
|
230
228
|
fake_info = ["index_name", "test", "other_key", "value"]
|
|
231
|
-
schema =
|
|
229
|
+
schema = _parse_schema_from_info(fake_info)
|
|
232
230
|
|
|
233
231
|
assert schema == {}
|
|
234
232
|
|
|
235
|
-
def test_parse_schema_incomplete_attribute(self
|
|
233
|
+
def test_parse_schema_incomplete_attribute(self):
|
|
236
234
|
"""_parse_schema_from_info handles attribute without type."""
|
|
237
|
-
registry = SchemaRegistry(redis_client)
|
|
238
|
-
|
|
239
235
|
# FT.INFO response with attribute but missing type
|
|
240
236
|
fake_info = [
|
|
241
237
|
"attributes",
|
|
@@ -244,7 +240,7 @@ class TestSchemaRegistryParsing:
|
|
|
244
240
|
["identifier", "field2", "attribute", "field2", "type", "TEXT"],
|
|
245
241
|
],
|
|
246
242
|
]
|
|
247
|
-
schema =
|
|
243
|
+
schema = _parse_schema_from_info(fake_info)
|
|
248
244
|
|
|
249
245
|
# Only field2 should be captured (field1 has no type)
|
|
250
246
|
assert schema == {"field2": "TEXT"}
|
|
@@ -1,155 +0,0 @@
|
|
|
1
|
-
"""SQL Executor - executes translated queries against Redis."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import re
|
|
6
|
-
from dataclasses import dataclass
|
|
7
|
-
from typing import Any
|
|
8
|
-
|
|
9
|
-
import redis
|
|
10
|
-
|
|
11
|
-
from sql_redis.schema import SchemaRegistry
|
|
12
|
-
from sql_redis.translator import Translator
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@dataclass
class QueryResult:
    """Outcome of one executed SQL statement.

    ``rows`` holds one field-name -> value mapping per returned record and
    ``count`` holds the count taken from the first element of the Redis reply.
    """

    # One dict per returned record.
    rows: list[dict]
    # Count as reported by Redis (first element of the raw reply).
    count: int
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class Executor:
    """Executes SQL queries against Redis.

    A query is processed in three steps: scalar ``:name`` parameters are
    substituted into the SQL text, the SQL is translated into a Redis command,
    and the raw reply is parsed into a ``QueryResult``.
    """

    # A complete ``:identifier`` placeholder: a colon, then a letter or
    # underscore, then any number of letters, digits or underscores. Matching
    # the whole identifier keeps ``:id`` from matching inside ``:product_id``.
    _PARAM_PATTERN = re.compile(r":[a-zA-Z_][a-zA-Z0-9_]*")

    def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry):
        """Initialize executor with Redis client and schema registry."""
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    def _substitute_params(self, sql: str, params: dict[str, Any]) -> str:
        """Substitute ``:param`` placeholders in SQL with literal values.

        Args:
            sql: The SQL string with ``:param`` placeholders.
            params: Dictionary mapping parameter names to values.

        Returns:
            SQL string with parameters substituted.

        Substitution rules:
            - ``int`` / ``float`` values are inserted via ``str(value)``.
            - ``str`` values are wrapped in single quotes, with embedded
              single quotes doubled (``'`` -> ``''``).
            - Every other type (bytes, None, bool, list, ...) and every
              placeholder without a matching key is left untouched. ``bool``
              is excluded explicitly because it is a subclass of ``int``.

        Known limitations:
            - A colon followed by an identifier inside a string literal
              (``'test:value'``) is also treated as a placeholder.
            - Parameter names are case-sensitive (``:id`` != ``:ID``).
        """
        if not params:
            return sql

        def render(match):
            placeholder = match.group(0)
            name = placeholder[1:]  # drop the leading colon
            if name not in params:
                return placeholder
            value = params[name]
            if isinstance(value, bool):
                # bool is an int subclass; keep the placeholder as documented.
                return placeholder
            if isinstance(value, (int, float)):
                return str(value)
            if isinstance(value, str):
                escaped = value.replace("'", "''")
                return f"'{escaped}'"
            return placeholder

        return self._PARAM_PATTERN.sub(render, sql)

    def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query and return results.

        Args:
            sql: SQL text, optionally containing ``:param`` placeholders.
            params: Values for the placeholders. The first ``bytes`` value is
                used for every ``$vector`` argument of the translated command.

        Returns:
            A ``QueryResult`` with the parsed rows and the reply's count.
        """
        params = params or {}

        # Scalar params go into the SQL text; bytes params are handled below.
        sql = self._substitute_params(sql, params)

        translated = self._translator.translate(sql)

        # list[str | bytes] because a vector argument is raw bytes.
        cmd: list[str | bytes] = list(translated.to_command_list())

        # Only the first bytes-valued parameter is used as the vector.
        vector_param: bytes | None = next(
            (value for value in params.values() if isinstance(value, bytes)),
            None,
        )
        if vector_param is not None:
            cmd = [vector_param if arg == "$vector" else arg for arg in cmd]

        raw_result = self._client.execute_command(*cmd)
        if not raw_result:
            return QueryResult(rows=[], count=0)

        count = raw_result[0]
        if translated.command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            # Skip the document keys and keep only the field lists.
            field_lists = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            field_lists = raw_result[1:]

        # Each field list is flat: [name1, value1, name2, value2, ...]
        rows = [dict(zip(fields[::2], fields[1::2])) for fields in field_lists]
        return QueryResult(rows=rows, count=count)
|
|
@@ -1,142 +0,0 @@
|
|
|
1
|
-
"""Schema registry for Redis search indexes."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from typing import Callable
|
|
6
|
-
|
|
7
|
-
import redis
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class SchemaRegistry:
    """Loads and caches index schemas from Redis.

    Schemas are read with ``FT._LIST`` / ``FT.INFO`` and cached per index as a
    ``{field_name: field_type}`` mapping. Index changes are detected by
    polling: call ``start_watching()`` and then ``process_pending_events()``
    periodically.
    """

    def __init__(self, redis_client: redis.Redis):
        self._client = redis_client
        self._schemas: dict[str, dict[str, str]] = {}
        self._on_change: Callable[[str, str], None] | None = None
        self._watching = False

    @staticmethod
    def _to_str(value):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def load_all(self) -> None:
        """Load schemas for all indexes on the server."""
        self._schemas.clear()
        for index_name in self._client.execute_command("FT._LIST"):
            # Names arrive as bytes when the client does not decode responses.
            self._load_index_schema(self._to_str(index_name))

    def _load_index_schema(self, index_name: str) -> None:
        """Load schema for a single index, dropping it from the cache on error."""
        try:
            info = self._client.execute_command("FT.INFO", index_name)
            schema = self._parse_schema_from_info(info)
            self._schemas[index_name] = schema
        except redis.ResponseError:
            # Index doesn't exist or was deleted
            self._schemas.pop(index_name, None)

    def _parse_schema_from_info(self, info: list) -> dict[str, str]:
        """Parse field types from FT.INFO response.

        Args:
            info: Flat FT.INFO reply; the element after ``attributes`` is
                expected to be a list of per-field lists such as
                ``[b'identifier', b'title', b'attribute', b'title',
                b'type', b'TEXT', ...]``.

        Returns:
            Mapping of field name to field type. Fields lacking a name or a
            type are skipped; a reply without a usable ``attributes`` section
            yields an empty dict.
        """
        schema = {}
        for i, item in enumerate(info):
            if self._to_str(item) != "attributes":
                continue
            if i + 1 >= len(info):
                # Malformed reply: the key is present but its value is not.
                break
            for attr in info[i + 1]:
                field_name = None
                field_type = None
                # Pair every element with its successor; the element that
                # follows "attribute" / "type" is that key's value.
                for key, following in zip(attr, attr[1:]):
                    key_str = self._to_str(key)
                    if key_str == "attribute":
                        field_name = self._to_str(following)
                    if key_str == "type":
                        field_type = self._to_str(following)
                if field_name and field_type:
                    schema[field_name] = field_type
            break  # only the first "attributes" section is used
        return schema

    def get_field_type(self, index: str, field: str) -> str | None:
        """Get field type for a given index and field.

        Returns None if index or field is unknown.
        """
        schema = self._schemas.get(index, {})
        return schema.get(field)

    def get_schema(self, index: str) -> dict[str, str]:
        """Get full schema for an index.

        Returns empty dict if index is unknown.
        """
        return self._schemas.get(index, {})

    def refresh(self, index_name: str) -> None:
        """Refresh schema for a single index.

        If the index no longer exists, removes it from the registry.
        If the index is new, adds it to the registry.
        """
        self._load_index_schema(index_name)

    def start_watching(
        self, on_change: Callable[[str, str], None] | None = None
    ) -> None:
        """Start watching for index changes.

        Since RediSearch doesn't emit keyspace notifications for FT commands,
        this uses polling via FT._LIST to detect changes.

        Args:
            on_change: Optional callback invoked with (event_type, index_name)
                when ``process_pending_events`` sees an index appear
                ("created") or disappear ("dropped").
        """
        self._on_change = on_change
        self._watching = True

    def stop_watching(self) -> None:
        """Stop watching for index changes."""
        self._watching = False
        self._on_change = None

    def process_pending_events(self) -> None:
        """Process any pending index change events.

        Since RediSearch doesn't emit keyspace notifications, this polls
        FT._LIST to detect new and deleted indexes. Call this periodically.
        Does nothing unless ``start_watching()`` has been called.
        """
        if not self._watching:
            return

        # Decode names exactly as load_all() does. Comparing raw bytes names
        # with the str cache keys would report every index as both created
        # and dropped on each poll.
        current_indexes = {
            self._to_str(name) for name in self._client.execute_command("FT._LIST")
        }
        cached_indexes = set(self._schemas)

        # Detect new indexes
        for idx in current_indexes - cached_indexes:
            self._load_index_schema(idx)
            if self._on_change:
                self._on_change("created", idx)

        # Detect deleted indexes
        for idx in cached_indexes - current_indexes:
            self._schemas.pop(idx, None)
            if self._on_change:
                self._on_change("dropped", idx)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|