sql-redis 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_redis/__init__.py +5 -0
- sql_redis/analyzer.py +133 -0
- sql_redis/executor.py +83 -0
- sql_redis/parser.py +440 -0
- sql_redis/query_builder.py +270 -0
- sql_redis/schema.py +142 -0
- sql_redis/translator.py +324 -0
- sql_redis-0.1.0.dist-info/METADATA +211 -0
- sql_redis-0.1.0.dist-info/RECORD +10 -0
- sql_redis-0.1.0.dist-info/WHEEL +4 -0
sql_redis/__init__.py
ADDED
sql_redis/analyzer.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""SQL analyzer component - resolves field types from schema."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from sql_redis.parser import AggregationSpec, ComputedField, Condition, ParsedQuery
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class VectorSearchAnalysis:
    """Vector (KNN) search parameters after schema analysis."""

    field: str  # vector field the KNN clause searches
    k: int  # number of neighbours to request
    alias: str  # name given to the distance expression in the SELECT list
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class AnalyzedQuery:
    """A parsed query enriched with information derived from the index schema.

    ``field_types`` maps each field the query references to its schema type;
    the other attributes carry aggregation and vector-search details for the
    translator.
    """

    parsed: ParsedQuery = field(default_factory=ParsedQuery)
    # field name -> schema type, for every field the query references
    field_types: dict[str, str] = field(default_factory=dict)
    aggregations: list[AggregationSpec] = field(default_factory=list)
    computed_fields: list[ComputedField] = field(default_factory=list)
    groupby_fields: list[str] = field(default_factory=list)
    # True when the query aggregates without any GROUP BY fields
    is_global_aggregation: bool = False
    vector_search: VectorSearchAnalysis | None = None
    # True when WHERE conditions should be applied as a pre-filter
    has_prefilter: bool = False

    def get_field_type(self, field_name: str) -> str | None:
        """Return the schema type recorded for *field_name*, or None."""
        types = self.field_types
        return types.get(field_name)

    def get_conditions_by_type(self, field_type: str) -> list[Condition]:
        """Return the WHERE conditions whose field resolves to *field_type*.

        Conditions keep the order in which they appear in ``parsed``.
        """
        matching: list[Condition] = []
        for cond in self.parsed.conditions:
            if self.field_types.get(cond.field) == field_type:
                matching.append(cond)
        return matching
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Analyzer:
    """Analyzes parsed SQL queries with schema context.

    Every field a query references is resolved to its type in the index
    schema; aggregation and vector-search details are derived alongside.
    """

    def __init__(self, schemas: dict[str, dict[str, str]]):
        """Initialize analyzer with schema registry data.

        Args:
            schemas: Dict mapping index names to field->type dicts.
        """
        self._schemas = schemas

    def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
        """Analyze a parsed query, resolving field types.

        Args:
            parsed: The parsed SQL query.

        Returns:
            An AnalyzedQuery with field types resolved.

        Raises:
            ValueError: If the index or a field is unknown. When several
                fields are unknown, the first one in query order is reported.
        """
        if parsed.index not in self._schemas:
            raise ValueError(f"Unknown index: {parsed.index}")

        schema = self._schemas[parsed.index]

        # Validate every referenced field before building the result.
        field_types: dict[str, str] = {}
        for field_name in self._collect_referenced_fields(parsed, schema):
            if field_name not in schema:
                raise ValueError(f"Unknown field: {field_name}")
            field_types[field_name] = schema[field_name]

        result = AnalyzedQuery(parsed=parsed)
        result.field_types.update(field_types)

        # Carry aggregation-related pieces over unchanged.
        result.aggregations = parsed.aggregations
        result.computed_fields = parsed.computed_fields
        result.groupby_fields = parsed.groupby_fields

        # Aggregates without GROUP BY aggregate over the whole result set.
        result.is_global_aggregation = (
            len(parsed.aggregations) > 0 and len(parsed.groupby_fields) == 0
        )

        if parsed.vector_search:
            result.vector_search = VectorSearchAnalysis(
                field=parsed.vector_search.field,
                # LIMIT wins, then an explicit k, then a default of 10.
                k=parsed.limit or parsed.vector_search.k or 10,
                alias=parsed.vector_search.alias,
            )
            # WHERE conditions become a pre-filter for the vector search.
            result.has_prefilter = len(parsed.conditions) > 0

        return result

    @staticmethod
    def _collect_referenced_fields(parsed, schema) -> list[str]:
        """Return the distinct field names *parsed* references, in query order.

        Sources are the SELECT columns (except ``*``), WHERE conditions,
        aggregation fields, identifiers inside computed expressions, the
        vector-search field and GROUP BY fields.
        """
        import re

        seen: dict[str, None] = {}  # insertion-ordered set

        def add(name: str) -> None:
            seen.setdefault(name, None)

        for name in parsed.fields:
            if name != "*":
                add(name)
        for condition in parsed.conditions:
            add(condition.field)
        for agg in parsed.aggregations:
            if agg.field:
                add(agg.field)
        # Computed fields only carry expression text. Match schema fields as
        # whole identifiers: a bare substring test made ``price`` appear to be
        # referenced by ``price_total * 2``.
        for computed in parsed.computed_fields:
            for name in schema:
                pattern = rf"(?<!\w){re.escape(name)}(?!\w)"
                if re.search(pattern, computed.expression):
                    add(name)
        if parsed.vector_search:
            add(parsed.vector_search.field)
        for name in parsed.groupby_fields:
            add(name)
        return list(seen)
|
sql_redis/executor.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""SQL Executor - executes translated queries against Redis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import redis
|
|
8
|
+
|
|
9
|
+
from sql_redis.schema import SchemaRegistry
|
|
10
|
+
from sql_redis.translator import Translator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class QueryResult:
    """Outcome of running one SQL statement: decoded rows plus a count."""

    rows: list[dict]  # one field -> value mapping per returned row
    count: int  # leading element of the Redis reply
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Executor:
    """Executes SQL queries against Redis.

    Execution has three steps:

    1. Scalar parameters are inlined into the SQL text.
    2. The SQL is translated to a RediSearch command.
    3. The raw reply is decoded into row dicts.
    """

    def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry):
        """Initialize executor with Redis client and schema registry.

        Args:
            client: Redis client used to send the translated command.
            schema_registry: Registry handed to the ``Translator``.
        """
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Placeholder values.
                - ``int``/``float``/``str`` values are inlined into the SQL
                  as literals; strings are quoted and escaped.
                - The first ``bytes`` value (a vector) replaces every
                  ``$vector`` argument of the translated command.

        Returns:
            A QueryResult with the decoded rows and the reply's leading count.
        """
        params = params or {}

        sql = self._inline_scalar_params(sql, params)
        translated = self._translator.translate(sql)

        # list[str | bytes] so the vector bytes can sit among string args.
        cmd: list[str | bytes] = list(translated.to_command_list())
        vector_param = next(
            (value for value in params.values() if isinstance(value, bytes)), None
        )
        if vector_param:
            cmd = [vector_param if arg == "$vector" else arg for arg in cmd]

        raw_result = self._client.execute_command(*cmd)
        count = raw_result[0] if raw_result else 0
        rows = self._parse_rows(translated.command, raw_result)
        return QueryResult(rows=rows, count=count)

    @staticmethod
    def _inline_scalar_params(sql: str, params: dict) -> str:
        """Replace ``:name`` placeholders of int/float/str params with literals.

        Behavior:
        - All placeholders are replaced in a single pass, so a substituted
          value is never re-scanned for placeholders.
        - A placeholder only matches when no identifier character follows it,
          so ``:id`` does not rewrite the start of ``:id2``.
        - Strings are single-quoted with embedded quotes doubled (standard
          SQL escaping), so values such as ``O'Brien`` cannot break out of
          the literal.
        - ``bytes`` values are left for the Redis PARAMS step.
        """
        import re

        literals: dict[str, str] = {}
        for key, value in params.items():
            if isinstance(value, (int, float)):
                literals[str(key)] = str(value)
            elif isinstance(value, str):
                literals[str(key)] = "'" + value.replace("'", "''") + "'"
        if not literals:
            return sql

        alternatives = "|".join(
            re.escape(key) for key in sorted(literals, key=len, reverse=True)
        )
        pattern = re.compile(rf":({alternatives})(?!\w)")
        # A callable replacement avoids backslash interpretation in values.
        return pattern.sub(lambda match: literals[match.group(1)], sql)

    @staticmethod
    def _parse_rows(command: str, raw_result) -> list[dict]:
        """Decode a raw FT.SEARCH / FT.AGGREGATE reply into row dicts."""
        if not raw_result:
            return []
        if command == "FT.SEARCH":
            # [count, key1, [fields1], key2, [fields2], ...]: skip the count
            # and the document keys, keep the flat field/value lists.
            field_lists = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            field_lists = raw_result[1:]
        return [dict(zip(flat[::2], flat[1::2])) for flat in field_lists]
|
sql_redis/parser.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
"""SQL parser component using sqlglot."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
import sqlglot
|
|
9
|
+
from sqlglot import exp
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class AggregationSpec:
    """One aggregate from the SELECT list, named as a Redis reducer.

    ``function`` is the reducer name (e.g. ``"COUNT"`` or ``"TOLIST"``).
    ``field`` stays None when the reducer takes no field, as in ``COUNT(*)``.
    """

    function: str
    field: str | None = None
    alias: str | None = None
    # Extra reducer arguments after the field, e.g. the quantile for QUANTILE.
    extra_args: list[str] = dataclasses.field(default_factory=list)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class ComputedField:
    """A derived SELECT column (an APPLY step): expression text plus its name."""

    expression: str  # SQL text of the expression as rendered by sqlglot
    alias: str  # output column name
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class VectorSearchSpec:
    """A vector search requested from the SELECT list (e.g. ``vector_distance``)."""

    field: str  # vector field to search
    alias: str  # name of the distance column
    k: int | None = None  # optional neighbour count
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class Condition:
    """One predicate extracted from the WHERE clause."""

    field: str  # field the predicate constrains
    operator: str  # e.g. "=", ">", "BETWEEN", "IN", "FULLTEXT", "DISTANCE_<"
    value: object  # literal; (low, high) for BETWEEN; list for IN
    negated: bool = False  # True when the predicate sits under NOT
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class ParsedQuery:
    """Structured view of a SQL SELECT statement, filled in by the parser."""

    index: str = ""  # FROM table name, used as the index name
    fields: list[str] = dataclasses.field(default_factory=list)  # "*" for star
    conditions: list[Condition] = dataclasses.field(default_factory=list)
    boolean_operator: str = "AND"  # how the conditions are combined
    aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list)
    computed_fields: list[ComputedField] = dataclasses.field(default_factory=list)
    vector_search: VectorSearchSpec | None = None
    groupby_fields: list[str] = dataclasses.field(default_factory=list)
    # (field, "ASC" | "DESC") pairs in ORDER BY order
    orderby_fields: list[tuple[str, str]] = dataclasses.field(default_factory=list)
    limit: int | None = None
    offset: int | None = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SQLParser:
    """Parses SQL into a ParsedQuery structure.

    The SQL text is parsed with ``sqlglot``. The resulting AST is walked to
    collect:

    - the index name (FROM);
    - projected fields, aggregations, computed fields and vector-search
      details (SELECT);
    - WHERE conditions;
    - GROUP BY / ORDER BY fields;
    - LIMIT / OFFSET values.
    """

    # sqlglot aggregate names that differ from the Redis reducer names.
    _REDIS_FUNC_MAP = {
        "FIRSTVALUE": "FIRST_VALUE",
        "ARRAYAGG": "TOLIST",
    }

    # Redis-specific reducers that sqlglot parses as anonymous functions.
    _REDIS_REDUCERS = frozenset(
        {"count_distinct", "count_distinctish", "quantile", "random_sample"}
    )

    def parse(self, sql: str) -> ParsedQuery:
        """Parse a SQL statement into a ParsedQuery.

        Args:
            sql: The SQL statement to parse.

        Returns:
            A ParsedQuery containing the extracted components.
        """
        ast = sqlglot.parse_one(sql)
        result = ParsedQuery()

        # FROM clause -> index name
        from_clause = ast.find(exp.From)
        if from_clause:
            table = from_clause.find(exp.Table)
            if table:
                result.index = table.name

        # SELECT list -> fields, aggregations, computed fields, vector search
        select = ast.find(exp.Select)
        if select:
            for expression in select.expressions:
                self._process_select_expression(expression, result)

        where = ast.find(exp.Where)
        if where:
            self._process_where_clause(where.this, result)

        group = ast.find(exp.Group)
        if group:
            for expr in group.expressions:
                if isinstance(expr, exp.Column):
                    result.groupby_fields.append(expr.name)

        order = ast.find(exp.Order)
        if order:
            for ordered in order.expressions:
                col = ordered.this
                # Only plain columns become sort keys. ORDER BY a vector
                # distance is served by the KNN clause derived from the SELECT
                # list, so it is deliberately not added here.
                if isinstance(col, exp.Column):
                    direction = "DESC" if ordered.args.get("desc") else "ASC"
                    result.orderby_fields.append((col.name, direction))

        limit = ast.find(exp.Limit)
        if limit:
            result.limit = self._clause_int(limit)

        offset = ast.find(exp.Offset)
        if offset:
            result.offset = self._clause_int(offset)

        return result

    @staticmethod
    def _clause_int(clause) -> int | None:
        """Return the integer literal of a LIMIT/OFFSET node, or None."""
        value = clause.args.get("expression") or clause.this
        if isinstance(value, exp.Literal):
            return int(value.this)
        return None

    def _process_select_expression(self, expression, result: ParsedQuery) -> None:
        """Process a single SELECT expression, unwrapping ``expr AS alias``."""
        if isinstance(expression, exp.Alias):
            self._process_select_expression_inner(
                expression.this, result, expression.alias
            )
        else:
            self._process_select_expression_inner(expression, result, None)

    def _process_select_expression_inner(
        self, expression, result: ParsedQuery, alias: str | None
    ) -> None:
        """Classify an unaliased SELECT expression and record it on *result*.

        Branch order matters: ``exp.Anonymous`` is a subclass of ``exp.Func``
        and must be tested before the generic function branch.
        """
        if isinstance(expression, exp.Column):
            result.fields.append(expression.name)
        elif isinstance(expression, exp.Star):
            result.fields.append("*")
        elif isinstance(
            expression,
            (
                exp.Count,
                exp.Sum,
                exp.Avg,
                exp.Min,
                exp.Max,
                exp.Stddev,
                exp.Variance,
                exp.FirstValue,
                exp.ArrayAgg,
            ),
        ):
            self._add_aggregation(expression, result, alias)
        elif isinstance(expression, exp.Paren):
            # Parenthesized expression - computed field from its contents
            self._add_computed_field(expression.this.sql(), result, alias)
        elif isinstance(expression, (exp.Mul, exp.Div, exp.Add, exp.Sub)):
            # Arithmetic without parentheses - computed field
            self._add_computed_field(expression.sql(), result, alias)
        elif isinstance(expression, (exp.Distance, exp.CosineDistance)):
            # Vector distance: Distance (L2/Euclidean) or cosine_distance()
            self._process_vector_distance(expression, result, alias)
        elif isinstance(expression, exp.Quantile):
            # QUANTILE(field, q) -> REDUCE QUANTILE 2 @field q
            field_name = None
            if isinstance(expression.this, exp.Column):
                field_name = expression.this.name
            quantile_value = None
            if expression.args.get("quantile"):
                quantile_value = str(expression.args["quantile"].this)
            result.aggregations.append(
                AggregationSpec(
                    function="QUANTILE",
                    field=field_name,
                    alias=alias,
                    extra_args=[quantile_value] if quantile_value else [],
                )
            )
        elif isinstance(expression, exp.Anonymous):
            self._process_anonymous_select(expression, result, alias)
        elif isinstance(expression, exp.Func):
            # Built-in function call (UPPER, LOWER, ...) - computed field
            self._add_computed_field(expression.sql(), result, alias)

    def _add_aggregation(
        self, expression, result: ParsedQuery, alias: str | None
    ) -> None:
        """Record a sqlglot aggregate node as a Redis reducer specification."""
        func_name = expression.key.upper()
        func_name = self._REDIS_FUNC_MAP.get(func_name, func_name)
        field_name = None
        arg = expression.this
        if isinstance(arg, exp.Column):
            field_name = arg.name
        elif isinstance(expression, exp.Count) and isinstance(arg, exp.Distinct):
            # COUNT(DISTINCT x) maps onto the COUNT_DISTINCT reducer; before
            # this branch it silently degraded to a plain COUNT.
            distinct_args = arg.expressions
            if len(distinct_args) == 1 and isinstance(distinct_args[0], exp.Column):
                func_name = "COUNT_DISTINCT"
                field_name = distinct_args[0].name
        # COUNT(*) (exp.Star) and other argument shapes keep field_name None.
        result.aggregations.append(
            AggregationSpec(function=func_name, field=field_name, alias=alias)
        )

    @staticmethod
    def _add_computed_field(
        expr_str: str, result: ParsedQuery, alias: str | None
    ) -> None:
        """Record *expr_str* as a computed field named *alias* (or itself)."""
        result.computed_fields.append(
            ComputedField(expression=expr_str, alias=alias if alias else expr_str)
        )

    def _process_anonymous_select(
        self, expression, result: ParsedQuery, alias: str | None
    ) -> None:
        """Handle a function sqlglot does not know (vector_distance, reducers, ...)."""
        func_name = expression.name.lower()
        args = expression.expressions
        if func_name == "vector_distance":
            # The vector field is the first argument; without a column there
            # is nothing to search, so no vector search is recorded.
            if args and isinstance(args[0], exp.Column):
                result.vector_search = VectorSearchSpec(
                    field=args[0].name,
                    alias=alias or func_name,
                )
        elif func_name in self._REDIS_REDUCERS:
            field_name = None
            reducer_extra_args: list[str] = []
            if args:
                if isinstance(args[0], exp.Column):
                    field_name = args[0].name
                # Additional literal arguments, e.g. the value for QUANTILE
                for arg in args[1:]:
                    if isinstance(arg, exp.Literal):
                        reducer_extra_args.append(str(arg.this))
            result.aggregations.append(
                AggregationSpec(
                    function=func_name.upper(),
                    field=field_name,
                    alias=alias,
                    extra_args=reducer_extra_args,
                )
            )
        else:
            # Other custom functions - computed field
            self._add_computed_field(expression.sql(), result, alias)

    def _process_vector_distance(
        self, expression, result: ParsedQuery, alias: str | None
    ) -> None:
        """Process a vector distance expression (cosine_distance, etc.)."""
        # Both Distance and CosineDistance carry the field as ``this``.
        if isinstance(expression.this, exp.Column):
            result.vector_search = VectorSearchSpec(
                field=expression.this.name,
                alias=alias or "vector_distance",
            )

    def _process_where_clause(
        self, expression, result: ParsedQuery, negated: bool = False
    ) -> None:
        """Process a WHERE clause expression recursively.

        Args:
            expression: The sqlglot node to process.
            result: Query being filled in.
            negated: Whether an odd number of enclosing NOTs applies.
        """
        if isinstance(expression, exp.EQ):
            self._add_condition(expression, "=", result, negated)
        elif isinstance(expression, exp.GT):
            self._add_condition(expression, ">", result, negated)
        elif isinstance(expression, exp.GTE):
            self._add_condition(expression, ">=", result, negated)
        elif isinstance(expression, exp.LT):
            self._add_condition(expression, "<", result, negated)
        elif isinstance(expression, exp.LTE):
            self._add_condition(expression, "<=", result, negated)
        elif isinstance(expression, exp.NEQ):
            self._add_condition(expression, "!=", result, negated)
        elif isinstance(expression, exp.Between):
            self._add_between_condition(expression, result, negated)
        elif isinstance(expression, exp.In):
            self._add_in_condition(expression, result, negated)
        elif isinstance(expression, (exp.And, exp.Or)):
            operator = "AND" if isinstance(expression, exp.And) else "OR"
            if negated:
                # De Morgan: NOT (a AND b) == NOT a OR NOT b, and vice versa,
                # since each condition is negated individually.
                operator = "OR" if operator == "AND" else "AND"
            result.boolean_operator = operator
            self._process_where_clause(expression.this, result, negated)
            self._process_where_clause(expression.expression, result, negated)
        elif isinstance(expression, exp.Not):
            # Toggle rather than force, so NOT NOT x is not negated.
            self._process_where_clause(expression.this, result, not negated)
        elif isinstance(expression, exp.Paren):
            # (a OR b), NOT (...): previously these were silently dropped.
            self._process_where_clause(expression.this, result, negated)
        elif isinstance(expression, exp.Anonymous):
            # Custom function like FULLTEXT(field, value)
            self._add_function_condition(expression, result, negated)

    def _add_condition(
        self, expression, operator: str, result: ParsedQuery, negated: bool
    ) -> None:
        """Add a condition from a binary comparison expression."""
        field_name = None
        left = expression.this

        if isinstance(left, exp.Column):
            field_name = left.name
        elif isinstance(left, exp.Anonymous):
            # Function on the left, e.g. DISTANCE(location, POINT(...)) < 10.
            # The field is the first argument and the function name prefixes
            # the operator (e.g. "DISTANCE_<").
            first_arg = left.expressions[0] if left.expressions else None
            if isinstance(first_arg, exp.Column):
                field_name = first_arg.name
                operator = f"{left.name.upper()}_{operator}"

        value = self._extract_literal_value(expression.expression)

        if field_name is not None:
            result.conditions.append(
                Condition(
                    field=field_name, operator=operator, value=value, negated=negated
                )
            )

    def _add_between_condition(
        self, expression, result: ParsedQuery, negated: bool
    ) -> None:
        """Add a BETWEEN condition with a ``(low, high)`` value tuple."""
        if not isinstance(expression.this, exp.Column):
            return
        low_val = self._extract_literal_value(expression.args.get("low"))
        high_val = self._extract_literal_value(expression.args.get("high"))
        result.conditions.append(
            Condition(
                field=expression.this.name,
                operator="BETWEEN",
                value=(low_val, high_val),
                negated=negated,
            )
        )

    def _add_in_condition(self, expression, result: ParsedQuery, negated: bool) -> None:
        """Add an IN condition with a list of literal values."""
        if not isinstance(expression.this, exp.Column):
            return
        values = [self._extract_literal_value(e) for e in expression.expressions]
        result.conditions.append(
            Condition(
                field=expression.this.name,
                operator="IN",
                value=values,
                negated=negated,
            )
        )

    def _add_function_condition(
        self, expression, result: ParsedQuery, negated: bool
    ) -> None:
        """Add a condition from a function call like fulltext(field, value)."""
        func_name = expression.name.upper()
        if func_name != "FULLTEXT" or len(expression.expressions) < 2:
            return
        first_arg, second_arg = expression.expressions[0], expression.expressions[1]
        if isinstance(first_arg, exp.Column):
            result.conditions.append(
                Condition(
                    field=first_arg.name,
                    operator="FULLTEXT",
                    value=self._extract_literal_value(second_arg),
                    negated=negated,
                )
            )

    def _extract_literal_value(self, expression):
        """Extract a Python value from a sqlglot literal node.

        Numbers become int or float, negative numbers (``exp.Neg``) are
        negated, and strings are returned as-is. Anything else yields None.
        """
        if isinstance(expression, exp.Neg):
            inner = self._extract_literal_value(expression.this)
            return -inner if isinstance(inner, (int, float)) else None
        if isinstance(expression, exp.Literal):
            value = expression.this
            if expression.is_number:
                return self._to_number(value)
            return value
        return None

    @staticmethod
    def _to_number(text):
        """Convert numeric literal text to int when integral, else float.

        ``int()`` is tried first; exponent forms such as ``1e5`` previously
        crashed because only the presence of ``.`` selected ``float``.
        """
        try:
            return int(text)
        except ValueError:
            return float(text)
|