sql-redis 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {sql_redis-0.1.2 → sql_redis-0.2.0}/PKG-INFO +1 -1
  2. {sql_redis-0.1.2 → sql_redis-0.2.0}/pyproject.toml +1 -1
  3. sql_redis-0.2.0/sql_redis/__init__.py +17 -0
  4. sql_redis-0.2.0/sql_redis/executor.py +228 -0
  5. sql_redis-0.2.0/sql_redis/schema.py +215 -0
  6. {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/translator.py +5 -3
  7. {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/version.py +1 -1
  8. sql_redis-0.2.0/tests/test_async_executor.py +344 -0
  9. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_schema_registry.py +5 -9
  10. sql_redis-0.1.2/sql_redis/__init__.py +0 -6
  11. sql_redis-0.1.2/sql_redis/executor.py +0 -155
  12. sql_redis-0.1.2/sql_redis/schema.py +0 -142
  13. {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/lint.yml +0 -0
  14. {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/release.yml +0 -0
  15. {sql_redis-0.1.2 → sql_redis-0.2.0}/.github/workflows/test.yml +0 -0
  16. {sql_redis-0.1.2 → sql_redis-0.2.0}/.gitignore +0 -0
  17. {sql_redis-0.1.2 → sql_redis-0.2.0}/.pre-commit-config.yaml +0 -0
  18. {sql_redis-0.1.2 → sql_redis-0.2.0}/Makefile +0 -0
  19. {sql_redis-0.1.2 → sql_redis-0.2.0}/README.md +0 -0
  20. {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/analyzer.py +0 -0
  21. {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/parser.py +0 -0
  22. {sql_redis-0.1.2 → sql_redis-0.2.0}/sql_redis/query_builder.py +0 -0
  23. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/__init__.py +0 -0
  24. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/conftest.py +0 -0
  25. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_analyzer.py +0 -0
  26. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_executor.py +0 -0
  27. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_parameter_substitution.py +0 -0
  28. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_query_builder.py +0 -0
  29. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_redis_queries.py +0 -0
  30. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_sql_parser.py +0 -0
  31. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_sql_queries.py +0 -0
  32. {sql_redis-0.1.2 → sql_redis-0.2.0}/tests/test_translator.py +0 -0
  33. {sql_redis-0.1.2 → sql_redis-0.2.0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sql-redis
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: SQL to Redis command translation utility
5
5
  Project-URL: Homepage, https://github.com/redis/sql-redis
6
6
  Project-URL: Repository, https://github.com/redis/sql-redis
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sql-redis"
3
- version = "0.1.2"
3
+ version = "0.2.0"
4
4
  description = "SQL to Redis command translation utility"
5
5
  authors = [{ name = "Redis Inc.", email = "applied.ai@redis.com" }]
6
6
  requires-python = ">=3.9,<3.14"
@@ -0,0 +1,17 @@
1
+ """SQL to Redis command translation utility."""
2
+
3
+ from sql_redis.executor import AsyncExecutor, Executor, QueryResult
4
+ from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
5
+ from sql_redis.translator import TranslatedQuery, Translator
6
+ from sql_redis.version import __version__
7
+
8
# Names re-exported at the package root; this list is what
# ``from sql_redis import *`` exposes.
__all__ = [
    "Translator",
    "TranslatedQuery",
    "SchemaRegistry",
    "AsyncSchemaRegistry",
    "Executor",
    "AsyncExecutor",
    "QueryResult",
    "__version__",
]
@@ -0,0 +1,228 @@
1
+ """SQL Executor - executes translated queries against Redis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ import redis
10
+
11
+ from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
12
+ from sql_redis.translator import Translator
13
+
14
+ if TYPE_CHECKING:
15
+ import redis.asyncio as async_redis
16
+
17
+
18
+ def _substitute_params(sql: str, params: dict[str, Any]) -> str:
19
+ """Substitute parameter placeholders in SQL with actual values.
20
+
21
+ This is a pure function with no I/O operations, shared by both
22
+ sync and async executors.
23
+
24
+ Uses token-based approach: splits SQL on :param patterns, then rebuilds
25
+ with substituted values. This approach solves two critical bugs:
26
+
27
+ 1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
28
+ by treating each :identifier as a complete token
29
+
30
+ 2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
31
+ using SQL standard (single quote -> double single quote)
32
+
33
+ Args:
34
+ sql: The SQL string with :param placeholders.
35
+ params: Dictionary mapping parameter names to values.
36
+
37
+ Returns:
38
+ SQL string with parameters substituted.
39
+
40
+ Implementation Details:
41
+ - Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
42
+ - Keeps delimiters (the :param tokens) in the split result
43
+ - Iterates through tokens, substituting matched parameters
44
+ - String values are wrapped in single quotes with proper escaping
45
+ - Numeric values are converted to strings
46
+ - Bytes values (e.g., vectors) are NOT substituted here
47
+
48
+ Known Limitations:
49
+ - Colons in string literals: SQL like "WHERE x = 'test:value'" would
50
+ theoretically match :value as a parameter. However, this is not a
51
+ practical issue because:
52
+ 1. Users pass values via parameters, not hardcoded in SQL
53
+ 2. The translator has its own handling of string literals
54
+ 3. No real-world use cases have been identified
55
+ - Parameter names are case-sensitive (:id != :ID)
56
+ - Only handles int, float, str types; other types keep placeholder
57
+ """
58
+ if not params:
59
+ return sql
60
+
61
+ # Split SQL on :param patterns, keeping the delimiters
62
+ # Pattern matches : followed by valid identifier:
63
+ # [a-zA-Z_] - First char must be letter or underscore
64
+ # [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
65
+ # This prevents partial matching: :id and :product_id are separate tokens
66
+ tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)
67
+
68
+ result = []
69
+ for token in tokens:
70
+ if token.startswith(":"):
71
+ # This is a parameter placeholder
72
+ key = token[1:] # Remove leading :
73
+ if key in params:
74
+ value = params[key]
75
+ if isinstance(value, (int, float)):
76
+ # Numeric values: convert to string
77
+ result.append(str(value))
78
+ elif isinstance(value, str):
79
+ # String values: wrap in quotes and escape single quotes
80
+ # SQL standard: ' -> '' (double single quote)
81
+ # This fixes the quote escaping bug
82
+ escaped = value.replace("'", "''")
83
+ result.append(f"'{escaped}'")
84
+ else:
85
+ # Other types (bytes, None, bool, list, etc.):
86
+ # Keep placeholder as-is (handled elsewhere or unsupported)
87
+ result.append(token)
88
+ else:
89
+ # Parameter not provided: keep placeholder as-is
90
+ result.append(token)
91
+ else:
92
+ # Not a parameter: keep as-is
93
+ result.append(token)
94
+
95
+ return "".join(result)
96
+
97
+
98
@dataclass
class QueryResult:
    """Result of executing a SQL query."""

    # One dict per returned row, built by pairing the alternating
    # name/value items of each row in the Redis reply.
    rows: list[dict]
    # First element of the raw Redis reply (0 when the reply is empty).
    count: int
104
+
105
+
106
# Argument emitted by the translator where the query vector must be bound.
_VECTOR_PLACEHOLDER = "$vector"


def _bind_vector_param(cmd: list[str | bytes], params: dict) -> list[str | bytes]:
    """Return *cmd* with every ``$vector`` argument replaced by the vector.

    The vector is the first ``bytes`` value found in *params*. When no bytes
    value is present the command is returned unchanged.
    """
    vector = next(
        (value for value in params.values() if isinstance(value, bytes)), None
    )
    # Test against None (not truthiness) so presence of a bytes parameter,
    # rather than its content, decides whether binding happens.
    if vector is None:
        return cmd
    return [vector if arg == _VECTOR_PLACEHOLDER else arg for arg in cmd]


def _parse_result(command: str, raw_result: Any) -> QueryResult:
    """Convert a raw FT.SEARCH / FT.AGGREGATE reply into a QueryResult.

    An empty or missing reply yields ``QueryResult(rows=[], count=0)``.
    """
    if not raw_result:
        return QueryResult(rows=[], count=0)

    if command == "FT.SEARCH":
        # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
        # Document keys sit at odd indices; field lists at even indices >= 2.
        field_lists = raw_result[2::2]
    else:
        # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
        field_lists = raw_result[1:]

    # Each field list alternates name, value, name, value, ...
    rows = [dict(zip(fields[::2], fields[1::2])) for fields in field_lists]
    return QueryResult(rows=rows, count=raw_result[0])


class Executor:
    """Executes SQL queries against Redis."""

    def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry) -> None:
        """Initialize executor with Redis client and schema registry."""
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Values for the placeholders. int/float/str values are
                substituted into the SQL text; the first bytes value is bound
                to the ``$vector`` argument of the translated command.

        Returns:
            QueryResult with the parsed rows and the reply's count.
        """
        params = params or {}

        # Substitute non-bytes params in the SQL text, then translate.
        translated = self._translator.translate(_substitute_params(sql, params))

        # Build the command list and bind the vector param, if any.
        cmd = _bind_vector_param(list(translated.to_command_list()), params)

        raw_result = self._client.execute_command(*cmd)
        return _parse_result(translated.command, raw_result)


class AsyncExecutor:
    """Async version of Executor for use with redis.asyncio clients."""

    def __init__(
        self,
        client: "async_redis.Redis",
        schema_registry: AsyncSchemaRegistry,
    ) -> None:
        """Initialize async executor with Redis client and schema registry.

        Args:
            client: An async Redis client (redis.asyncio.Redis).
            schema_registry: An AsyncSchemaRegistry instance.
        """
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query asynchronously and return results.

        Same contract as ``Executor.execute``; only the Redis round trip is
        awaited. Parameter substitution and translation are synchronous.
        """
        params = params or {}

        # Substitute non-bytes params in the SQL text, then translate.
        translated = self._translator.translate(_substitute_params(sql, params))

        # Build the command list and bind the vector param, if any.
        cmd = _bind_vector_param(list(translated.to_command_list()), params)

        raw_result = await self._client.execute_command(*cmd)
        return _parse_result(translated.command, raw_result)
@@ -0,0 +1,215 @@
1
+ """Schema registry for Redis search indexes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Callable
6
+
7
+ import redis
8
+
9
+ if TYPE_CHECKING:
10
+ import redis.asyncio as async_redis
11
+
12
+
13
+ def _parse_schema_from_info(info: list) -> dict[str, str]:
14
+ """Parse field types from FT.INFO response.
15
+
16
+ This is a pure function with no I/O operations, shared by both
17
+ sync and async schema registries.
18
+
19
+ Args:
20
+ info: The raw response from FT.INFO command.
21
+
22
+ Returns:
23
+ Dictionary mapping field names to their types (e.g., {"title": "TEXT"}).
24
+ """
25
+ schema = {}
26
+ # Find the 'attributes' section in the info response
27
+ for i, item in enumerate(info):
28
+ # Handle bytes or string comparison
29
+ item_str = item.decode("utf-8") if isinstance(item, bytes) else item
30
+ if item_str == "attributes":
31
+ attributes = info[i + 1]
32
+ for attr in attributes:
33
+ field_name = None
34
+ field_type = None
35
+ # Each attribute is a list like:
36
+ # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...]
37
+ for j, val in enumerate(attr):
38
+ val_str = val.decode("utf-8") if isinstance(val, bytes) else val
39
+ if val_str == "attribute" and j + 1 < len(attr):
40
+ fn = attr[j + 1]
41
+ field_name = fn.decode("utf-8") if isinstance(fn, bytes) else fn
42
+ if val_str == "type" and j + 1 < len(attr):
43
+ ft = attr[j + 1]
44
+ field_type = ft.decode("utf-8") if isinstance(ft, bytes) else ft
45
+ if field_name and field_type:
46
+ schema[field_name] = field_type
47
+ break
48
+ return schema
49
+
50
+
51
class SchemaRegistry:
    """Loads and caches index schemas from Redis.

    Index changes are detected by polling: after ``start_watching()``, each
    call to ``process_pending_events()`` compares ``FT._LIST`` with the cache
    and reports created and dropped indexes.
    """

    def __init__(self, redis_client: redis.Redis):
        """Initialize with a sync Redis client; performs no I/O."""
        self._client = redis_client
        # index name -> {field name -> field type}
        self._schemas: dict[str, dict[str, str]] = {}
        self._on_change: Callable[[str, str], None] | None = None
        self._watching = False

    @staticmethod
    def _decode_name(index_name: str | bytes) -> str:
        """Return an index name as ``str`` (clients may return bytes)."""
        if isinstance(index_name, bytes):
            return index_name.decode("utf-8")
        return index_name

    def load_all(self) -> None:
        """Load schemas for all indexes on the server."""
        self._schemas.clear()
        for index_name in self._client.execute_command("FT._LIST"):
            self._load_index_schema(self._decode_name(index_name))

    def _load_index_schema(self, index_name: str) -> None:
        """Load schema for a single index.

        A ``redis.ResponseError`` from FT.INFO removes the index from the
        cache instead of propagating.
        """
        try:
            info = self._client.execute_command("FT.INFO", index_name)
            schema = _parse_schema_from_info(info)
            self._schemas[index_name] = schema
        except redis.ResponseError:
            # Index doesn't exist or was deleted
            self._schemas.pop(index_name, None)

    def get_field_type(self, index: str, field: str) -> str | None:
        """Get field type for a given index and field.

        Returns None if index or field is unknown.
        """
        schema = self._schemas.get(index, {})
        return schema.get(field)

    def get_schema(self, index: str) -> dict[str, str]:
        """Get full schema for an index.

        Returns empty dict if index is unknown.
        """
        return self._schemas.get(index, {})

    def refresh(self, index_name: str) -> None:
        """Refresh schema for a single index.

        If the index no longer exists, removes it from the registry.
        If the index is new, adds it to the registry.
        """
        self._load_index_schema(index_name)

    def start_watching(
        self, on_change: Callable[[str, str], None] | None = None
    ) -> None:
        """Start watching for index changes.

        Since RediSearch doesn't emit keyspace notifications for FT commands,
        this uses polling via FT._LIST to detect changes. Nothing is polled
        until ``process_pending_events()`` is called.

        Args:
            on_change: Optional callback invoked with (event_type, index_name)
                where event_type is "created" or "dropped".
        """
        self._on_change = on_change
        self._watching = True

    def stop_watching(self) -> None:
        """Stop watching for index changes."""
        self._watching = False
        self._on_change = None

    def process_pending_events(self) -> None:
        """Process any pending index change events.

        Since RediSearch doesn't emit keyspace notifications, this polls
        FT._LIST to detect new and deleted indexes. Call this periodically.
        Does nothing unless ``start_watching()`` has been called.
        """
        if not self._watching:
            return

        # Decode names exactly as load_all() does. Comparing raw bytes
        # against the str cache keys would report every index as both
        # created and dropped on each poll.
        current_indexes = {
            self._decode_name(name)
            for name in self._client.execute_command("FT._LIST")
        }
        cached_indexes = set(self._schemas)

        # Detect new indexes
        for idx in current_indexes - cached_indexes:
            self._load_index_schema(idx)
            if self._on_change:
                self._on_change("created", idx)

        # Detect deleted indexes
        for idx in cached_indexes - current_indexes:
            self._schemas.pop(idx, None)
            if self._on_change:
                self._on_change("dropped", idx)
153
+
154
+
155
class AsyncSchemaRegistry:
    """Schema cache for redis.asyncio clients.

    Async counterpart of SchemaRegistry: schemas are fetched with awaited
    Redis calls, while the lookup methods stay synchronous.
    """

    def __init__(self, redis_client: "async_redis.Redis") -> None:
        """Store the client; no Redis call is made here.

        Args:
            redis_client: An async Redis client (redis.asyncio.Redis).
        """
        self._schemas: dict[str, dict[str, str]] = {}
        self._client = redis_client

    async def load_all(self) -> None:
        """Drop the cache and reload the schema of every index on the server.

        Index names come from FT._LIST; the per-index loads are awaited
        together through asyncio.gather().
        """
        import asyncio

        self._schemas.clear()
        listed = await self._client.execute_command("FT._LIST")

        # Normalize names to str before any load is scheduled.
        names = []
        for raw_name in listed:
            if isinstance(raw_name, bytes):
                names.append(raw_name.decode("utf-8"))
            else:
                names.append(raw_name)

        await asyncio.gather(*map(self._load_index_schema, names))

    async def _load_index_schema(self, index_name: str) -> None:
        """Fetch FT.INFO for one index and cache its parsed schema."""
        try:
            reply = await self._client.execute_command("FT.INFO", index_name)
            parsed = _parse_schema_from_info(reply)
        except redis.ResponseError:
            # Index doesn't exist or was deleted: forget any cached entry.
            self._schemas.pop(index_name, None)
        else:
            self._schemas[index_name] = parsed

    def get_field_type(self, index: str, field: str) -> str | None:
        """Return the type of *field* in *index*.

        Returns None if index or field is unknown.
        """
        return self._schemas.get(index, {}).get(field)

    def get_schema(self, index: str) -> dict[str, str]:
        """Return the cached schema of *index*.

        Returns empty dict if index is unknown.
        """
        try:
            return self._schemas[index]
        except KeyError:
            return {}

    async def refresh(self, index_name: str) -> None:
        """Reload the schema of one index."""
        await self._load_index_schema(index_name)
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
7
7
  from sql_redis.analyzer import AnalyzedQuery, Analyzer
8
8
  from sql_redis.parser import Condition, ParsedQuery, SQLParser
9
9
  from sql_redis.query_builder import QueryBuilder
10
- from sql_redis.schema import SchemaRegistry
10
+ from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
11
11
 
12
12
 
13
13
  @dataclass
@@ -34,11 +34,13 @@ class TranslatedQuery:
34
34
  class Translator:
35
35
  """Translates SQL queries to Redis FT.SEARCH/FT.AGGREGATE commands."""
36
36
 
37
- def __init__(self, schema_registry: SchemaRegistry):
37
+ def __init__(self, schema_registry: SchemaRegistry | AsyncSchemaRegistry) -> None:
38
38
  """Initialize translator with schema registry.
39
39
 
40
40
  Args:
41
- schema_registry: Registry containing index schemas.
41
+ schema_registry: Registry containing index schemas. Can be either
42
+ sync (SchemaRegistry) or async (AsyncSchemaRegistry) - only
43
+ the sync get_schema() method is used.
42
44
  """
43
45
  self._schema_registry = schema_registry
44
46
  self._parser = SQLParser()
@@ -2,7 +2,7 @@ try:
2
2
  from importlib.metadata import PackageNotFoundError, version
3
3
  except ImportError:
4
4
  # Python < 3.8 fallback
5
- from importlib_metadata import PackageNotFoundError, version # type: ignore
5
+ from importlib_metadata import PackageNotFoundError, version # type: ignore # isort: skip
6
6
 
7
7
  try:
8
8
  __version__ = version("sql-redis")
@@ -0,0 +1,344 @@
1
+ """Integration tests for async SQL executor.
2
+
3
+ TDD: These tests define the expected behavior for AsyncSchemaRegistry and AsyncExecutor.
4
+ """
5
+
6
+ import struct
7
+
8
+ import pytest
9
+ import redis.asyncio as async_redis
10
+ from testcontainers.redis import RedisContainer
11
+
12
+ from sql_redis.executor import AsyncExecutor, QueryResult
13
+ from sql_redis.schema import AsyncSchemaRegistry
14
+
15
+
16
@pytest.fixture(scope="module")
def redis_container():
    """Start a Redis container for testing.

    Module-scoped, so a single ``redis:8.0.2`` container is shared by the
    tests in this module; the ``with`` block is exited after the last one.
    """
    with RedisContainer("redis:8.0.2") as container:
        yield container
21
+
22
+
23
@pytest.fixture
async def async_client(redis_container) -> async_redis.Redis:
    """Create an async Redis client connected to the test container.

    The client is created with ``decode_responses=True`` and is closed once
    the dependent test has finished.
    """
    client = async_redis.Redis(
        host=redis_container.get_container_host_ip(),
        port=int(redis_container.get_exposed_port(6379)),
        decode_responses=True,
    )
    yield client
    # Teardown
    await client.aclose()
33
+
34
+
35
@pytest.fixture
async def products_index(async_client: async_redis.Redis) -> str:
    """Build the ``async_products`` index, fill it with four hashes, yield its name.

    Any index left over from an earlier run is dropped first, and the index
    (with its documents, via ``DD``) is dropped again on teardown.
    """
    index_name = "async_products"

    async def _drop_index() -> None:
        # Best-effort: failure (e.g. the index does not exist) is ignored.
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except Exception:
            pass

    await _drop_index()

    schema_args = (
        "title",
        "TEXT",
        "category",
        "TAG",
        "price",
        "NUMERIC",
        "stock",
        "NUMERIC",
    )
    await async_client.execute_command(
        "FT.CREATE",
        index_name,
        "ON",
        "HASH",
        "PREFIX",
        "1",
        "async_product:",
        "SCHEMA",
        *schema_args,
    )

    # Test data, written in this order.
    documents = {
        "async_product:1": {
            "title": "Laptop Pro",
            "category": "electronics",
            "price": "999.99",
            "stock": "10",
        },
        "async_product:2": {
            "title": "Wireless Mouse",
            "category": "electronics",
            "price": "29.99",
            "stock": "50",
        },
        "async_product:3": {
            "title": "Python Book",
            "category": "books",
            "price": "49.99",
            "stock": "25",
        },
        "async_product:4": {
            "title": "Redis Guide",
            "category": "books",
            "price": "39.99",
            "stock": "15",
        },
    }
    for key, fields in documents.items():
        await async_client.hset(key, mapping=fields)

    yield index_name

    # Cleanup
    await _drop_index()
108
+
109
+
110
@pytest.fixture
async def async_executor(
    async_client: async_redis.Redis, products_index: str
) -> AsyncExecutor:
    """Create an async executor with the products index loaded.

    Depends on ``products_index`` so the index exists before the registry
    loads schemas from the server.
    """
    registry = AsyncSchemaRegistry(async_client)
    await registry.load_all()
    return AsyncExecutor(async_client, registry)
118
+
119
+
120
class TestAsyncSchemaRegistry:
    """Tests for AsyncSchemaRegistry."""

    async def test_load_all_loads_indexes(
        self, async_client: async_redis.Redis, products_index: str
    ):
        """load_all() should load index schemas from Redis."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()

        schema = registry.get_schema(products_index)
        assert schema is not None
        # Expected types mirror the SCHEMA passed to FT.CREATE in the
        # products_index fixture.
        assert "title" in schema
        assert schema["title"] == "TEXT"
        assert "category" in schema
        assert schema["category"] == "TAG"
        assert "price" in schema
        assert schema["price"] == "NUMERIC"

    async def test_get_schema_returns_empty_for_unknown(
        self, async_client: async_redis.Redis
    ):
        """get_schema() returns empty dict for unknown index."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()

        schema = registry.get_schema("nonexistent_index")
        assert schema == {}
148
+
149
+
150
class TestAsyncExecutorBasic:
    """Tests for basic async query execution.

    Expected counts refer to the four documents written by the
    products_index fixture (two "electronics", two "books").
    """

    async def test_select_all(self, async_executor: AsyncExecutor, products_index: str):
        """SELECT * returns all documents."""
        result = await async_executor.execute(f"SELECT * FROM {products_index}")
        assert result.count == 4
        assert len(result.rows) == 4

    async def test_result_is_query_result(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Result should be a QueryResult instance."""
        result = await async_executor.execute(f"SELECT * FROM {products_index}")
        assert isinstance(result, QueryResult)
        assert hasattr(result, "rows")
        assert hasattr(result, "count")

    async def test_select_with_tag_filter(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with tag filter."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE category = 'books'"
        )
        assert result.count == 2
        for row in result.rows:
            assert row["category"] == "books"

    async def test_select_with_numeric_filter(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with numeric comparison."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE price < 50"
        )
        assert result.count >= 2
        for row in result.rows:
            # Values come back as strings (client uses decode_responses=True).
            assert float(row["price"]) < 50

    async def test_select_with_limit(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with LIMIT."""
        result = await async_executor.execute(f"SELECT * FROM {products_index} LIMIT 2")
        assert len(result.rows) == 2

    async def test_select_with_order_by(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with ORDER BY."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} ORDER BY price DESC"
        )
        prices = [float(row["price"]) for row in result.rows]
        # Rows must already be in descending price order.
        assert prices == sorted(prices, reverse=True)
206
+
207
+
208
class TestAsyncExecutorAggregation:
    """Tests for async aggregate query execution."""

    async def test_count_all(self, async_executor: AsyncExecutor, products_index: str):
        """SELECT COUNT(*) returns count."""
        result = await async_executor.execute(f"SELECT COUNT(*) FROM {products_index}")
        assert len(result.rows) == 1
        row = result.rows[0]
        # Accept either column name for the aggregate; only its presence is
        # checked, not its value.
        count_value = row.get("COUNT(*)", row.get("count", None))
        assert count_value is not None

    async def test_group_by_with_count(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """SELECT with GROUP BY and COUNT."""
        result = await async_executor.execute(
            f"SELECT category, COUNT(*) as cnt FROM {products_index} GROUP BY category"
        )
        assert len(result.rows) == 2  # electronics and books
        categories = {row["category"] for row in result.rows}
        assert categories == {"electronics", "books"}
229
+
230
+
231
class TestAsyncExecutorParams:
    """Tests for parameterized async execution.

    Values are passed through ``params`` and referenced in the SQL as
    ``:name`` placeholders.
    """

    async def test_numeric_param(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Execute with numeric parameter."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE price > :min_price",
            params={"min_price": 40},
        )
        for row in result.rows:
            assert float(row["price"]) > 40

    async def test_string_param(
        self, async_executor: AsyncExecutor, products_index: str
    ):
        """Execute with string parameter."""
        result = await async_executor.execute(
            f"SELECT * FROM {products_index} WHERE category = :cat",
            params={"cat": "books"},
        )
        assert len(result.rows) == 2
        for row in result.rows:
            assert row["category"] == "books"
256
+
257
+
258
class TestAsyncVectorSearch:
    """Tests for async vector search execution."""

    @pytest.fixture
    async def vector_index(self, async_client: async_redis.Redis) -> str:
        """Create a vector index with test data.

        Builds ``async_vectors`` (a TEXT field and a 4-dimensional FLOAT32
        HNSW/COSINE vector field), stores three hashes, yields the index
        name, and drops the index with its documents on teardown.
        """
        index_name = "async_vectors"
        # Best-effort removal of an index left over from an earlier run.
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except Exception:
            pass

        await async_client.execute_command(
            "FT.CREATE",
            index_name,
            "ON",
            "HASH",
            "PREFIX",
            "1",
            "async_vec:",
            "SCHEMA",
            "title",
            "TEXT",
            "embedding",
            "VECTOR",
            "HNSW",
            "6",
            "TYPE",
            "FLOAT32",
            "DIM",
            "4",
            "DISTANCE_METRIC",
            "COSINE",
        )

        def to_bytes(v):
            # Pack the floats as consecutive 32-bit values.
            return struct.pack(f"{len(v)}f", *v)

        # Use a separate non-decode client for binary data
        raw_client = async_redis.Redis(
            host=async_client.connection_pool.connection_kwargs["host"],
            port=async_client.connection_pool.connection_kwargs["port"],
            decode_responses=False,
        )
        try:
            await raw_client.hset(
                "async_vec:1",
                mapping={"title": "First", "embedding": to_bytes([0.1, 0.2, 0.3, 0.4])},
            )
            await raw_client.hset(
                "async_vec:2",
                mapping={"title": "Second", "embedding": to_bytes([0.5, 0.6, 0.7, 0.8])},
            )
            await raw_client.hset(
                "async_vec:3",
                mapping={"title": "Third", "embedding": to_bytes([0.9, 0.8, 0.7, 0.6])},
            )
        finally:
            # Close the helper client even when a write fails, so the
            # connection is not leaked.
            await raw_client.aclose()

        yield index_name

        # Cleanup
        try:
            await async_client.execute_command("FT.DROPINDEX", index_name, "DD")
        except Exception:
            pass

    async def test_vector_search_with_param(
        self, async_client: async_redis.Redis, vector_index: str
    ):
        """Vector search with vector parameter."""
        registry = AsyncSchemaRegistry(async_client)
        await registry.load_all()
        executor = AsyncExecutor(async_client, registry)

        # Same components as the "First" document, so it should rank first.
        query_vector = struct.pack("4f", 0.1, 0.2, 0.3, 0.4)
        result = await executor.execute(
            f"SELECT title, vector_distance(embedding, :vec) AS score "
            f"FROM {vector_index} LIMIT 3",
            params={"vec": query_vector},
        )
        assert len(result.rows) <= 3
        # First result should be closest to query vector
        assert result.rows[0]["title"] == "First"
        # Verify vector distance score is returned
        assert "score" in result.rows[0]
        score = float(result.rows[0]["score"])
        assert score >= 0  # Distance should be non-negative
@@ -3,7 +3,7 @@
3
3
  import pytest
4
4
  import redis
5
5
 
6
- from sql_redis.schema import SchemaRegistry
6
+ from sql_redis.schema import SchemaRegistry, _parse_schema_from_info
7
7
 
8
8
 
9
9
  def _create_test_indexes(redis_client: redis.Redis) -> list[str]:
@@ -222,20 +222,16 @@ class TestSchemaRegistryEmptyServer:
222
222
  class TestSchemaRegistryParsing:
223
223
  """Tests for schema parsing edge cases."""
224
224
 
225
- def test_parse_schema_no_attributes_section(self, redis_client: redis.Redis):
225
+ def test_parse_schema_no_attributes_section(self):
226
226
  """_parse_schema_from_info handles response without attributes."""
227
- registry = SchemaRegistry(redis_client)
228
-
229
227
  # FT.INFO response without 'attributes' key
230
228
  fake_info = ["index_name", "test", "other_key", "value"]
231
- schema = registry._parse_schema_from_info(fake_info)
229
+ schema = _parse_schema_from_info(fake_info)
232
230
 
233
231
  assert schema == {}
234
232
 
235
- def test_parse_schema_incomplete_attribute(self, redis_client: redis.Redis):
233
+ def test_parse_schema_incomplete_attribute(self):
236
234
  """_parse_schema_from_info handles attribute without type."""
237
- registry = SchemaRegistry(redis_client)
238
-
239
235
  # FT.INFO response with attribute but missing type
240
236
  fake_info = [
241
237
  "attributes",
@@ -244,7 +240,7 @@ class TestSchemaRegistryParsing:
244
240
  ["identifier", "field2", "attribute", "field2", "type", "TEXT"],
245
241
  ],
246
242
  ]
247
- schema = registry._parse_schema_from_info(fake_info)
243
+ schema = _parse_schema_from_info(fake_info)
248
244
 
249
245
  # Only field2 should be captured (field1 has no type)
250
246
  assert schema == {"field2": "TEXT"}
@@ -1,6 +0,0 @@
1
- """SQL to Redis command translation utility."""
2
-
3
- from sql_redis.translator import TranslatedQuery, Translator
4
- from sql_redis.version import __version__
5
-
6
- __all__ = ["Translator", "TranslatedQuery", "__version__"]
@@ -1,155 +0,0 @@
1
- """SQL Executor - executes translated queries against Redis."""
2
-
3
- from __future__ import annotations
4
-
5
- import re
6
- from dataclasses import dataclass
7
- from typing import Any
8
-
9
- import redis
10
-
11
- from sql_redis.schema import SchemaRegistry
12
- from sql_redis.translator import Translator
13
-
14
-
15
- @dataclass
16
- class QueryResult:
17
- """Result of executing a SQL query."""
18
-
19
- rows: list[dict]
20
- count: int
21
-
22
-
23
- class Executor:
24
- """Executes SQL queries against Redis."""
25
-
26
- def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry):
27
- """Initialize executor with Redis client and schema registry."""
28
- self._client = client
29
- self._schema_registry = schema_registry
30
- self._translator = Translator(schema_registry)
31
-
32
- def _substitute_params(self, sql: str, params: dict[str, Any]) -> str:
33
- """Substitute parameter placeholders in SQL with actual values.
34
-
35
- Uses token-based approach: splits SQL on :param patterns, then rebuilds
36
- with substituted values. This approach solves two critical bugs:
37
-
38
- 1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
39
- by treating each :identifier as a complete token
40
-
41
- 2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
42
- using SQL standard (single quote -> double single quote)
43
-
44
- Args:
45
- sql: The SQL string with :param placeholders.
46
- params: Dictionary mapping parameter names to values.
47
-
48
- Returns:
49
- SQL string with parameters substituted.
50
-
51
- Implementation Details:
52
- - Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
53
- - Keeps delimiters (the :param tokens) in the split result
54
- - Iterates through tokens, substituting matched parameters
55
- - String values are wrapped in single quotes with proper escaping
56
- - Numeric values are converted to strings
57
- - Bytes values (e.g., vectors) are NOT substituted here
58
-
59
- Known Limitations:
60
- - Colons in string literals: SQL like "WHERE x = 'test:value'" would
61
- theoretically match :value as a parameter. However, this is not a
62
- practical issue because:
63
- 1. Users pass values via parameters, not hardcoded in SQL
64
- 2. The translator has its own handling of string literals
65
- 3. No real-world use cases have been identified
66
- - Parameter names are case-sensitive (:id != :ID)
67
- - Only handles int, float, str types; other types keep placeholder
68
- """
69
- if not params:
70
- return sql
71
-
72
- # Split SQL on :param patterns, keeping the delimiters
73
- # Pattern matches : followed by valid identifier:
74
- # [a-zA-Z_] - First char must be letter or underscore
75
- # [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
76
- # This prevents partial matching: :id and :product_id are separate tokens
77
- tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)
78
-
79
- result = []
80
- for token in tokens:
81
- if token.startswith(":"):
82
- # This is a parameter placeholder
83
- key = token[1:] # Remove leading :
84
- if key in params:
85
- value = params[key]
86
- if isinstance(value, (int, float)):
87
- # Numeric values: convert to string
88
- result.append(str(value))
89
- elif isinstance(value, str):
90
- # String values: wrap in quotes and escape single quotes
91
- # SQL standard: ' -> '' (double single quote)
92
- # This fixes the quote escaping bug
93
- escaped = value.replace("'", "''")
94
- result.append(f"'{escaped}'")
95
- else:
96
- # Other types (bytes, None, bool, list, etc.):
97
- # Keep placeholder as-is (handled elsewhere or unsupported)
98
- result.append(token)
99
- else:
100
- # Parameter not provided: keep placeholder as-is
101
- result.append(token)
102
- else:
103
- # Not a parameter: keep as-is
104
- result.append(token)
105
-
106
- return "".join(result)
107
-
108
- def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
109
- """Execute a SQL query and return results."""
110
- params = params or {}
111
-
112
- # Substitute non-bytes params in SQL using token-based approach
113
- sql = self._substitute_params(sql, params)
114
-
115
- # Translate SQL to Redis command
116
- translated = self._translator.translate(sql)
117
-
118
- # Build command list and substitute vector params
119
- # Use list[str | bytes] to allow bytes for vector params
120
- cmd: list[str | bytes] = list(translated.to_command_list())
121
-
122
- # Find any bytes params (vectors) to substitute
123
- vector_param: bytes | None = None
124
- for value in params.values():
125
- if isinstance(value, bytes):
126
- vector_param = value
127
- break
128
-
129
- # Replace $vector placeholder with actual bytes
130
- if vector_param:
131
- for i, arg in enumerate(cmd):
132
- if arg == "$vector":
133
- cmd[i] = vector_param
134
-
135
- # Execute command
136
- raw_result = self._client.execute_command(*cmd)
137
-
138
- # Parse result based on command type
139
- count = raw_result[0] if raw_result else 0
140
- rows = []
141
-
142
- if translated.command == "FT.SEARCH":
143
- # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
144
- # Skip document keys (odd indices), take field lists (even indices after count)
145
- for i in range(2, len(raw_result), 2):
146
- row_data = raw_result[i]
147
- row = dict(zip(row_data[::2], row_data[1::2]))
148
- rows.append(row)
149
- else:
150
- # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
151
- for row_data in raw_result[1:]:
152
- row = dict(zip(row_data[::2], row_data[1::2]))
153
- rows.append(row)
154
-
155
- return QueryResult(rows=rows, count=count)
@@ -1,142 +0,0 @@
1
- """Schema registry for Redis search indexes."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Callable
6
-
7
- import redis
8
-
9
-
10
- class SchemaRegistry:
11
- """Loads and caches index schemas from Redis.
12
-
13
- Supports automatic schema refresh via Redis keyspace notifications.
14
- """
15
-
16
- def __init__(self, redis_client: redis.Redis):
17
- self._client = redis_client
18
- self._schemas: dict[str, dict[str, str]] = {}
19
- self._on_change: Callable[[str, str], None] | None = None
20
- self._watching = False
21
-
22
- def load_all(self) -> None:
23
- """Load schemas for all indexes on the server."""
24
- self._schemas.clear()
25
- indexes = self._client.execute_command("FT._LIST")
26
- for index_name in indexes:
27
- # Decode bytes to string if needed
28
- if isinstance(index_name, bytes):
29
- index_name = index_name.decode("utf-8")
30
- self._load_index_schema(index_name)
31
-
32
- def _load_index_schema(self, index_name: str) -> None:
33
- """Load schema for a single index."""
34
- try:
35
- info = self._client.execute_command("FT.INFO", index_name)
36
- schema = self._parse_schema_from_info(info)
37
- self._schemas[index_name] = schema
38
- except redis.ResponseError:
39
- # Index doesn't exist or was deleted
40
- self._schemas.pop(index_name, None)
41
-
42
- def _parse_schema_from_info(self, info: list) -> dict[str, str]:
43
- """Parse field types from FT.INFO response."""
44
- schema = {}
45
- # Find the 'attributes' section in the info response
46
- for i, item in enumerate(info):
47
- # Handle bytes or string comparison
48
- item_str = item.decode("utf-8") if isinstance(item, bytes) else item
49
- if item_str == "attributes":
50
- attributes = info[i + 1]
51
- for attr in attributes:
52
- field_name = None
53
- field_type = None
54
- # Each attribute is a list like:
55
- # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...]
56
- for j, val in enumerate(attr):
57
- val_str = val.decode("utf-8") if isinstance(val, bytes) else val
58
- if val_str == "attribute" and j + 1 < len(attr):
59
- fn = attr[j + 1]
60
- field_name = (
61
- fn.decode("utf-8") if isinstance(fn, bytes) else fn
62
- )
63
- if val_str == "type" and j + 1 < len(attr):
64
- ft = attr[j + 1]
65
- field_type = (
66
- ft.decode("utf-8") if isinstance(ft, bytes) else ft
67
- )
68
- if field_name and field_type:
69
- schema[field_name] = field_type
70
- break
71
- return schema
72
-
73
- def get_field_type(self, index: str, field: str) -> str | None:
74
- """Get field type for a given index and field.
75
-
76
- Returns None if index or field is unknown.
77
- """
78
- schema = self._schemas.get(index, {})
79
- return schema.get(field)
80
-
81
- def get_schema(self, index: str) -> dict[str, str]:
82
- """Get full schema for an index.
83
-
84
- Returns empty dict if index is unknown.
85
- """
86
- return self._schemas.get(index, {})
87
-
88
- def refresh(self, index_name: str) -> None:
89
- """Refresh schema for a single index.
90
-
91
- If the index no longer exists, removes it from the registry.
92
- If the index is new, adds it to the registry.
93
- """
94
- self._load_index_schema(index_name)
95
-
96
- def start_watching(
97
- self, on_change: Callable[[str, str], None] | None = None
98
- ) -> None:
99
- """Start watching for index changes.
100
-
101
- Since RediSearch doesn't emit keyspace notifications for FT commands,
102
- this uses polling via FT._LIST to detect changes.
103
-
104
- Args:
105
- on_change: Optional callback invoked with (event_type, index_name)
106
- when an index is created, dropped, or altered.
107
- """
108
- self._on_change = on_change
109
- self._watching = True
110
-
111
- def stop_watching(self) -> None:
112
- """Stop watching for index changes."""
113
- self._watching = False
114
- self._on_change = None
115
-
116
- def process_pending_events(self) -> None:
117
- """Process any pending index change events.
118
-
119
- Since RediSearch doesn't emit keyspace notifications, this polls
120
- FT._LIST to detect new and deleted indexes. Call this periodically.
121
- """
122
- if not self._watching:
123
- return
124
-
125
- # Get current indexes from Redis
126
- current_indexes = set(self._client.execute_command("FT._LIST"))
127
-
128
- cached_indexes = set(self._schemas.keys())
129
-
130
- # Detect new indexes
131
- new_indexes = current_indexes - cached_indexes
132
- for idx in new_indexes:
133
- self._load_index_schema(idx)
134
- if self._on_change:
135
- self._on_change("created", idx)
136
-
137
- # Detect deleted indexes
138
- deleted_indexes = cached_indexes - current_indexes
139
- for idx in deleted_indexes:
140
- self._schemas.pop(idx, None)
141
- if self._on_change:
142
- self._on_change("dropped", idx)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes