python-datastore-sqlalchemy 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_datastore_sqlalchemy-0.0.1.dist-info/METADATA +65 -0
- python_datastore_sqlalchemy-0.0.1.dist-info/RECORD +12 -0
- python_datastore_sqlalchemy-0.0.1.dist-info/WHEEL +5 -0
- python_datastore_sqlalchemy-0.0.1.dist-info/entry_points.txt +2 -0
- python_datastore_sqlalchemy-0.0.1.dist-info/licenses/LICENSE +9 -0
- python_datastore_sqlalchemy-0.0.1.dist-info/top_level.txt +1 -0
- sqlalchemy_datastore/__init__.py +27 -0
- sqlalchemy_datastore/_helpers.py +135 -0
- sqlalchemy_datastore/_types.py +147 -0
- sqlalchemy_datastore/base.py +291 -0
- sqlalchemy_datastore/datastore_dbapi.py +2322 -0
- sqlalchemy_datastore/parse_url.py +287 -0
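
Taken together, the files above ship a SQLAlchemy dialect (base.py, parse_url.py) backed by a PEP 249-style DB-API shim (datastore_dbapi.py) for Google Cloud Datastore. As a rough orientation only, a minimal usage sketch might look like the following; the `datastore://` URL scheme and the exact dialect registration are assumptions inferred from the file names and entry point, not behavior confirmed by this diff:

    # Hypothetical usage sketch -- URL format and connect arguments are
    # assumptions inferred from the package layout, not documented behavior.
    from sqlalchemy import create_engine, text

    engine = create_engine("datastore://my-gcp-project")  # assumed URL scheme
    with engine.connect() as conn:
        rows = conn.execute(text("SELECT * FROM Task LIMIT 10")).fetchall()
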
sqlalchemy_datastore/datastore_dbapi.py
@@ -0,0 +1,2322 @@
|
|
|
1
|
+
# Copyright (c) 2025 hychang <hychang.1997.tw@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
# this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
# the Software without restriction, including without limitation the rights to
|
|
6
|
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
7
|
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
8
|
+
# subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
15
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
16
|
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
17
|
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
18
|
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
19
|
+
import base64
|
|
20
|
+
import collections
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import re
|
|
24
|
+
from datetime import datetime
|
|
25
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
26
|
+
|
|
27
|
+
import pandas as pd
|
|
28
|
+
import requests
|
|
29
|
+
from google.auth.transport.requests import AuthorizedSession
|
|
30
|
+
from google.cloud import datastore
|
|
31
|
+
from google.cloud.datastore.helpers import GeoPoint
|
|
32
|
+
from google.oauth2 import service_account
|
|
33
|
+
from requests import Response
|
|
34
|
+
from sqlalchemy import types
|
|
35
|
+
from sqlglot import exp, parse_one, tokenize, tokens
|
|
36
|
+
from sqlglot.tokens import TokenType
|
|
37
|
+
|
|
38
|
+
from . import _types
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger("sqlalchemy.dialects.datastore_dbapi")
|
|
41
|
+
|
|
42
|
+
apilevel = "2.0"
|
|
43
|
+
threadsafety = 2
|
|
44
|
+
paramstyle = "named"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Required exceptions
|
|
48
|
+
class Warning(Exception):
|
|
49
|
+
"""Exception raised for important DB-API warnings."""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Error(Exception):
|
|
53
|
+
"""Exception representing all non-warning DB-API errors."""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class InterfaceError(Error):
|
|
57
|
+
"""DB-API error related to the database interface."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DatabaseError(Error):
|
|
61
|
+
"""DB-API error related to the database."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class DataError(DatabaseError):
|
|
65
|
+
"""DB-API error due to problems with the processed data."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class OperationalError(DatabaseError):
|
|
69
|
+
"""DB-API error related to the database operation."""
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class IntegrityError(DatabaseError):
|
|
73
|
+
"""DB-API error when integrity of the database is affected."""
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class InternalError(DatabaseError):
|
|
77
|
+
"""DB-API error when the database encounters an internal error."""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ProgrammingError(DatabaseError):
|
|
81
|
+
"""DB-API exception raised for programming errors."""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
Column = collections.namedtuple(
|
|
85
|
+
"Column",
|
|
86
|
+
[
|
|
87
|
+
"name",
|
|
88
|
+
"type_code",
|
|
89
|
+
"display_size",
|
|
90
|
+
"internal_size",
|
|
91
|
+
"precision",
|
|
92
|
+
"scale",
|
|
93
|
+
"null_ok",
|
|
94
|
+
],
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
type_map = {
|
|
98
|
+
str: types.String,
|
|
99
|
+
int: types.NUMERIC,
|
|
100
|
+
float: types.FLOAT,
|
|
101
|
+
bool: types.BOOLEAN,
|
|
102
|
+
bytes: types.BINARY,
|
|
103
|
+
datetime: types.DATETIME,
|
|
104
|
+
datastore.Key: types.JSON,
|
|
105
|
+
GeoPoint: types.JSON,
|
|
106
|
+
list: types.JSON,
|
|
107
|
+
dict: types.JSON,
|
|
108
|
+
None.__class__: types.String,
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Cursor:
|
|
113
|
+
def __init__(self, connection):
|
|
114
|
+
self.connection = connection
|
|
115
|
+
self._datastore_client = connection._client
|
|
116
|
+
self.rowcount = -1
|
|
117
|
+
self.arraysize = None
|
|
118
|
+
self._query_data = None
|
|
119
|
+
self._query_rows = None
|
|
120
|
+
self._closed = False
|
|
121
|
+
self.description = None
|
|
122
|
+
self.lastrowid = None
|
|
123
|
+
self.warnings: list[str] = []
|
|
124
|
+
|
|
125
|
+
def execute(self, statements, parameters=None):
|
|
126
|
+
"""Execute a Datastore operation."""
|
|
127
|
+
if self._closed:
|
|
128
|
+
raise Error("Cursor is closed.")
|
|
129
|
+
|
|
130
|
+
# Check for DML statements
|
|
131
|
+
upper_statement = statements.upper().strip()
|
|
132
|
+
if upper_statement.startswith("INSERT"):
|
|
133
|
+
self._execute_insert(statements, parameters)
|
|
134
|
+
return
|
|
135
|
+
if upper_statement.startswith("UPDATE"):
|
|
136
|
+
self._execute_update(statements, parameters)
|
|
137
|
+
return
|
|
138
|
+
if upper_statement.startswith("DELETE"):
|
|
139
|
+
self._execute_delete(statements, parameters)
|
|
140
|
+
return
|
|
141
|
+
|
|
142
|
+
tokens = tokenize(statements)
|
|
143
|
+
if self._is_derived_query(tokens):
|
|
144
|
+
self.execute_orm(statements, parameters, tokens)
|
|
145
|
+
else:
|
|
146
|
+
self.gql_query(statements, parameters)
|
|
147
|
+
|
|
148
|
+
def _execute_insert(self, statement: str, parameters=None):
|
|
149
|
+
"""Execute an INSERT statement using Datastore client."""
|
|
150
|
+
if parameters is None:
|
|
151
|
+
parameters = {}
|
|
152
|
+
|
|
153
|
+
logging.debug(f"Executing INSERT: {statement} with parameters: {parameters}")
|
|
154
|
+
|
|
155
|
+
try:
|
|
156
|
+
# Parse INSERT statement using sqlglot
|
|
157
|
+
parsed = parse_one(statement)
|
|
158
|
+
if not isinstance(parsed, exp.Insert):
|
|
159
|
+
raise ProgrammingError(f"Expected INSERT statement, got: {type(parsed)}")
|
|
160
|
+
|
|
161
|
+
# Get table/kind name
|
|
162
|
+
# For INSERT, parsed.this is a Schema containing the table and columns
|
|
163
|
+
schema_expr = parsed.this
|
|
164
|
+
if isinstance(schema_expr, exp.Schema):
|
|
165
|
+
# Schema has 'this' which is the table
|
|
166
|
+
table_expr = schema_expr.this
|
|
167
|
+
if isinstance(table_expr, exp.Table):
|
|
168
|
+
kind = table_expr.name
|
|
169
|
+
else:
|
|
170
|
+
kind = str(table_expr)
|
|
171
|
+
elif isinstance(schema_expr, exp.Table):
|
|
172
|
+
kind = schema_expr.name
|
|
173
|
+
else:
|
|
174
|
+
raise ProgrammingError("Could not determine table name from INSERT")
|
|
175
|
+
|
|
176
|
+
# Get column names from Schema's expressions
|
|
177
|
+
columns = []
|
|
178
|
+
if isinstance(schema_expr, exp.Schema) and schema_expr.expressions:
|
|
179
|
+
for col in schema_expr.expressions:
|
|
180
|
+
if hasattr(col, "name"):
|
|
181
|
+
columns.append(col.name)
|
|
182
|
+
else:
|
|
183
|
+
columns.append(str(col))
|
|
184
|
+
|
|
185
|
+
# Get values
|
|
186
|
+
values_list = []
|
|
187
|
+
values_expr = parsed.args.get("expression")
|
|
188
|
+
if values_expr and hasattr(values_expr, "expressions"):
|
|
189
|
+
for tuple_expr in values_expr.expressions:
|
|
190
|
+
if hasattr(tuple_expr, "expressions"):
|
|
191
|
+
row_values = []
|
|
192
|
+
for val in tuple_expr.expressions:
|
|
193
|
+
row_values.append(self._parse_insert_value(val, parameters))
|
|
194
|
+
values_list.append(row_values)
|
|
195
|
+
elif values_expr:
|
|
196
|
+
# Single row VALUES clause
|
|
197
|
+
row_values = []
|
|
198
|
+
if hasattr(values_expr, "expressions"):
|
|
199
|
+
for val in values_expr.expressions:
|
|
200
|
+
row_values.append(self._parse_insert_value(val, parameters))
|
|
201
|
+
values_list.append(row_values)
|
|
202
|
+
|
|
203
|
+
# Create entities and insert them
|
|
204
|
+
entities_created = 0
|
|
205
|
+
for row_values in values_list:
|
|
206
|
+
# Create entity key (auto-generated)
|
|
207
|
+
key = self._datastore_client.key(kind)
|
|
208
|
+
entity = datastore.Entity(key=key)
|
|
209
|
+
|
|
210
|
+
# Set entity properties
|
|
211
|
+
for i, col in enumerate(columns):
|
|
212
|
+
if i < len(row_values):
|
|
213
|
+
entity[col] = row_values[i]
|
|
214
|
+
|
|
215
|
+
# Put entity to datastore
|
|
216
|
+
self._datastore_client.put(entity)
|
|
217
|
+
entities_created += 1
|
|
218
|
+
# Save the last inserted entity's key ID for lastrowid
|
|
219
|
+
if entity.key.id is not None:
|
|
220
|
+
self.lastrowid = entity.key.id
|
|
221
|
+
elif entity.key.name is not None:
|
|
222
|
+
# For named keys, use a hash of the name as a numeric ID
|
|
223
|
+
self.lastrowid = hash(entity.key.name) & 0x7FFFFFFFFFFFFFFF
|
|
224
|
+
|
|
225
|
+
self.rowcount = entities_created
|
|
226
|
+
self._query_rows = iter([])
|
|
227
|
+
self.description = None
|
|
228
|
+
|
|
229
|
+
except Exception as e:
|
|
230
|
+
logging.error(f"INSERT failed: {e}")
|
|
231
|
+
raise ProgrammingError(f"INSERT failed: {e}")
|
|
232
|
+
|
|
233
|
+
def _execute_update(self, statement: str, parameters=None):
|
|
234
|
+
"""Execute an UPDATE statement using Datastore client."""
|
|
235
|
+
if parameters is None:
|
|
236
|
+
parameters = {}
|
|
237
|
+
|
|
238
|
+
logging.debug(f"Executing UPDATE: {statement} with parameters: {parameters}")
|
|
239
|
+
|
|
240
|
+
try:
|
|
241
|
+
parsed = parse_one(statement)
|
|
242
|
+
if not isinstance(parsed, exp.Update):
|
|
243
|
+
raise ProgrammingError(f"Expected UPDATE statement, got: {type(parsed)}")
|
|
244
|
+
|
|
245
|
+
# Get table/kind name
|
|
246
|
+
table_expr = parsed.this
|
|
247
|
+
if isinstance(table_expr, exp.Table):
|
|
248
|
+
kind = table_expr.name
|
|
249
|
+
else:
|
|
250
|
+
raise ProgrammingError("Could not determine table name from UPDATE")
|
|
251
|
+
|
|
252
|
+
# Get the WHERE clause to find the entity key
|
|
253
|
+
where = parsed.args.get("where")
|
|
254
|
+
if not where:
|
|
255
|
+
raise ProgrammingError("UPDATE without WHERE clause is not supported")
|
|
256
|
+
|
|
257
|
+
# Extract the key ID from WHERE clause (e.g., WHERE id = :id_1)
|
|
258
|
+
entity_key_id = self._extract_key_id_from_where(where, parameters)
|
|
259
|
+
if entity_key_id is None:
|
|
260
|
+
raise ProgrammingError("Could not extract entity key from WHERE clause")
|
|
261
|
+
|
|
262
|
+
# Get the entity
|
|
263
|
+
key = self._datastore_client.key(kind, entity_key_id)
|
|
264
|
+
entity = self._datastore_client.get(key)
|
|
265
|
+
if entity is None:
|
|
266
|
+
self.rowcount = 0
|
|
267
|
+
self._query_rows = iter([])
|
|
268
|
+
self.description = None
|
|
269
|
+
return
|
|
270
|
+
|
|
271
|
+
# Apply the SET values
|
|
272
|
+
for set_expr in parsed.args.get("expressions", []):
|
|
273
|
+
if isinstance(set_expr, exp.EQ):
|
|
274
|
+
col_name = set_expr.left.name if hasattr(set_expr.left, "name") else str(set_expr.left)
|
|
275
|
+
value = self._parse_update_value(set_expr.right, parameters)
|
|
276
|
+
entity[col_name] = value
|
|
277
|
+
|
|
278
|
+
# Save the entity
|
|
279
|
+
self._datastore_client.put(entity)
|
|
280
|
+
self.rowcount = 1
|
|
281
|
+
self._query_rows = iter([])
|
|
282
|
+
self.description = None
|
|
283
|
+
|
|
284
|
+
except Exception as e:
|
|
285
|
+
logging.error(f"UPDATE failed: {e}")
|
|
286
|
+
raise ProgrammingError(f"UPDATE failed: {e}") from e
|
|
287
|
+
|
|
288
|
+
def _execute_delete(self, statement: str, parameters=None):
|
|
289
|
+
"""Execute a DELETE statement using Datastore client."""
|
|
290
|
+
if parameters is None:
|
|
291
|
+
parameters = {}
|
|
292
|
+
|
|
293
|
+
logging.debug(f"Executing DELETE: {statement} with parameters: {parameters}")
|
|
294
|
+
|
|
295
|
+
try:
|
|
296
|
+
parsed = parse_one(statement)
|
|
297
|
+
if not isinstance(parsed, exp.Delete):
|
|
298
|
+
raise ProgrammingError(f"Expected DELETE statement, got: {type(parsed)}")
|
|
299
|
+
|
|
300
|
+
# Get table/kind name
|
|
301
|
+
table_expr = parsed.this
|
|
302
|
+
if isinstance(table_expr, exp.Table):
|
|
303
|
+
kind = table_expr.name
|
|
304
|
+
else:
|
|
305
|
+
raise ProgrammingError("Could not determine table name from DELETE")
|
|
306
|
+
|
|
307
|
+
# Get the WHERE clause to find the entity key
|
|
308
|
+
where = parsed.args.get("where")
|
|
309
|
+
if not where:
|
|
310
|
+
raise ProgrammingError("DELETE without WHERE clause is not supported")
|
|
311
|
+
|
|
312
|
+
# Extract the key ID from WHERE clause
|
|
313
|
+
entity_key_id = self._extract_key_id_from_where(where, parameters)
|
|
314
|
+
if entity_key_id is None:
|
|
315
|
+
raise ProgrammingError("Could not extract entity key from WHERE clause")
|
|
316
|
+
|
|
317
|
+
# Delete the entity
|
|
318
|
+
key = self._datastore_client.key(kind, entity_key_id)
|
|
319
|
+
self._datastore_client.delete(key)
|
|
320
|
+
self.rowcount = 1
|
|
321
|
+
self._query_rows = iter([])
|
|
322
|
+
self.description = None
|
|
323
|
+
|
|
324
|
+
except Exception as e:
|
|
325
|
+
logging.error(f"DELETE failed: {e}")
|
|
326
|
+
raise ProgrammingError(f"DELETE failed: {e}") from e
|
|
327
|
+
|
|
328
|
+
def _extract_key_id_from_where(self, where_expr, parameters: dict) -> Optional[int]:
|
|
329
|
+
"""Extract entity key ID from WHERE clause."""
|
|
330
|
+
# Handle WHERE id = :param or WHERE id = value
|
|
331
|
+
if isinstance(where_expr, exp.Where):
|
|
332
|
+
where_expr = where_expr.this
|
|
333
|
+
|
|
334
|
+
if isinstance(where_expr, exp.EQ):
|
|
335
|
+
left = where_expr.left
|
|
336
|
+
right = where_expr.right
|
|
337
|
+
|
|
338
|
+
# Check if left side is 'id'
|
|
339
|
+
col_name = left.name if hasattr(left, "name") else str(left)
|
|
340
|
+
if col_name.lower() == "id":
|
|
341
|
+
return self._parse_key_value(right, parameters)
|
|
342
|
+
|
|
343
|
+
return None
|
|
344
|
+
|
|
345
|
+
def _parse_key_value(self, val_expr, parameters: dict) -> Optional[int]:
|
|
346
|
+
"""Parse a value expression to get key ID."""
|
|
347
|
+
if isinstance(val_expr, exp.Literal):
|
|
348
|
+
if val_expr.is_number:
|
|
349
|
+
return int(val_expr.this)
|
|
350
|
+
elif isinstance(val_expr, exp.Placeholder):
|
|
351
|
+
param_name = val_expr.name or val_expr.this
|
|
352
|
+
if param_name in parameters:
|
|
353
|
+
return int(parameters[param_name])
|
|
354
|
+
if param_name.startswith(":"):
|
|
355
|
+
param_name = param_name[1:]
|
|
356
|
+
if param_name in parameters:
|
|
357
|
+
return int(parameters[param_name])
|
|
358
|
+
elif isinstance(val_expr, exp.Parameter):
|
|
359
|
+
param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
|
|
360
|
+
if param_name in parameters:
|
|
361
|
+
return int(parameters[param_name])
|
|
362
|
+
return None
|
|
363
|
+
|
|
364
|
+
def _parse_update_value(self, val_expr, parameters: dict) -> Any:
|
|
365
|
+
"""Parse a value expression from UPDATE SET clause."""
|
|
366
|
+
if isinstance(val_expr, exp.Literal):
|
|
367
|
+
if val_expr.is_string:
|
|
368
|
+
return val_expr.this
|
|
369
|
+
elif val_expr.is_number:
|
|
370
|
+
text = val_expr.this
|
|
371
|
+
if "." in text:
|
|
372
|
+
return float(text)
|
|
373
|
+
return int(text)
|
|
374
|
+
return val_expr.this
|
|
375
|
+
elif isinstance(val_expr, exp.Null):
|
|
376
|
+
return None
|
|
377
|
+
elif isinstance(val_expr, exp.Boolean):
|
|
378
|
+
return val_expr.this
|
|
379
|
+
elif isinstance(val_expr, exp.Placeholder):
|
|
380
|
+
param_name = val_expr.name or val_expr.this
|
|
381
|
+
if param_name in parameters:
|
|
382
|
+
return parameters[param_name]
|
|
383
|
+
if param_name.startswith(":"):
|
|
384
|
+
param_name = param_name[1:]
|
|
385
|
+
if param_name in parameters:
|
|
386
|
+
return parameters[param_name]
|
|
387
|
+
return None
|
|
388
|
+
elif isinstance(val_expr, exp.Parameter):
|
|
389
|
+
param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
|
|
390
|
+
if param_name in parameters:
|
|
391
|
+
return parameters[param_name]
|
|
392
|
+
return None
|
|
393
|
+
else:
|
|
394
|
+
return str(val_expr.this) if hasattr(val_expr, "this") else str(val_expr)
|
|
395
|
+
|
|
396
|
+
def _parse_insert_value(self, val_expr, parameters: dict) -> Any:
|
|
397
|
+
"""Parse a value expression from INSERT statement."""
|
|
398
|
+
if isinstance(val_expr, exp.Literal):
|
|
399
|
+
if val_expr.is_string:
|
|
400
|
+
return val_expr.this
|
|
401
|
+
elif val_expr.is_number:
|
|
402
|
+
text = val_expr.this
|
|
403
|
+
if "." in text:
|
|
404
|
+
return float(text)
|
|
405
|
+
return int(text)
|
|
406
|
+
return val_expr.this
|
|
407
|
+
elif isinstance(val_expr, exp.Null):
|
|
408
|
+
return None
|
|
409
|
+
elif isinstance(val_expr, exp.Boolean):
|
|
410
|
+
return val_expr.this
|
|
411
|
+
elif isinstance(val_expr, exp.Placeholder):
|
|
412
|
+
# Named parameter like :name
|
|
413
|
+
param_name = val_expr.name or val_expr.this
|
|
414
|
+
if param_name and param_name in parameters:
|
|
415
|
+
return parameters[param_name]
|
|
416
|
+
# Handle :name format
|
|
417
|
+
if param_name and param_name.startswith(":"):
|
|
418
|
+
param_name = param_name[1:]
|
|
419
|
+
if param_name in parameters:
|
|
420
|
+
return parameters[param_name]
|
|
421
|
+
return None
|
|
422
|
+
elif isinstance(val_expr, exp.Parameter):
|
|
423
|
+
# Named parameter
|
|
424
|
+
param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
|
|
425
|
+
if param_name in parameters:
|
|
426
|
+
return parameters[param_name]
|
|
427
|
+
return None
|
|
428
|
+
else:
|
|
429
|
+
# Try to get the string representation
|
|
430
|
+
return str(val_expr.this) if hasattr(val_expr, "this") else str(val_expr)
|
|
431
|
+
|
|
432
|
+
def _is_derived_query(self, tokens: List[tokens.Token]) -> bool:
|
|
433
|
+
"""
|
|
434
|
+
Checks if the SQL statement contains a derived table (subquery in FROM).
|
|
435
|
+
This is a more reliable way to distinguish complex ORM queries from simple GQL.
|
|
436
|
+
"""
|
|
437
|
+
select_seen = 0
|
|
438
|
+
for token in tokens:
|
|
439
|
+
if token.token_type == TokenType.SELECT:
|
|
440
|
+
select_seen += 1
|
|
441
|
+
if select_seen >= 2:
|
|
442
|
+
return True
|
|
443
|
+
return False
|
|
444
|
+
|
|
445
|
+
def _is_aggregation_query(self, statement: str) -> bool:
|
|
446
|
+
"""Check if the statement contains aggregation functions."""
|
|
447
|
+
upper = statement.upper()
|
|
448
|
+
# Check for AGGREGATE ... OVER syntax
|
|
449
|
+
if upper.strip().startswith("AGGREGATE"):
|
|
450
|
+
return True
|
|
451
|
+
# Check for aggregation functions in SELECT
|
|
452
|
+
agg_patterns = [
|
|
453
|
+
r"\bCOUNT\s*\(",
|
|
454
|
+
r"\bCOUNT_UP_TO\s*\(",
|
|
455
|
+
r"\bSUM\s*\(",
|
|
456
|
+
r"\bAVG\s*\(",
|
|
457
|
+
]
|
|
458
|
+
for pattern in agg_patterns:
|
|
459
|
+
if re.search(pattern, upper):
|
|
460
|
+
return True
|
|
461
|
+
return False
|
|
462
|
+
|
|
463
|
+
def _parse_aggregation_query(self, statement: str) -> Dict[str, Any]:
|
|
464
|
+
"""
|
|
465
|
+
Parse aggregation query and return components.
|
|
466
|
+
Returns dict with:
|
|
467
|
+
- 'agg_functions': list of (func_name, column, alias)
|
|
468
|
+
- 'base_query': the underlying SELECT query
|
|
469
|
+
- 'is_aggregate_over': whether it's AGGREGATE...OVER syntax
|
|
470
|
+
"""
|
|
471
|
+
upper = statement.upper().strip()
|
|
472
|
+
result: Dict[str, Any] = {
|
|
473
|
+
"agg_functions": [],
|
|
474
|
+
"base_query": None,
|
|
475
|
+
"is_aggregate_over": False,
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
# Handle AGGREGATE ... OVER (SELECT ...) syntax
|
|
479
|
+
if upper.startswith("AGGREGATE"):
|
|
480
|
+
result["is_aggregate_over"] = True
|
|
481
|
+
# Extract the inner SELECT query
|
|
482
|
+
over_match = re.search(
|
|
483
|
+
r"OVER\s*\(\s*(SELECT\s+.+)\s*\)\s*$",
|
|
484
|
+
statement,
|
|
485
|
+
re.IGNORECASE | re.DOTALL,
|
|
486
|
+
)
|
|
487
|
+
if over_match:
|
|
488
|
+
result["base_query"] = over_match.group(1).strip()
|
|
489
|
+
else:
|
|
490
|
+
# Fallback - extract everything after OVER
|
|
491
|
+
over_idx = upper.find("OVER")
|
|
492
|
+
if over_idx > 0:
|
|
493
|
+
# Extract content inside parentheses
|
|
494
|
+
remaining = statement[over_idx + 4 :].strip()
|
|
495
|
+
if remaining.startswith("("):
|
|
496
|
+
paren_depth = 0
|
|
497
|
+
for i, c in enumerate(remaining):
|
|
498
|
+
if c == "(":
|
|
499
|
+
paren_depth += 1
|
|
500
|
+
elif c == ")":
|
|
501
|
+
paren_depth -= 1
|
|
502
|
+
if paren_depth == 0:
|
|
503
|
+
result["base_query"] = remaining[1:i].strip()
|
|
504
|
+
break
|
|
505
|
+
|
|
506
|
+
# Parse aggregation functions before OVER
|
|
507
|
+
agg_part = statement[: upper.find("OVER")].strip()
|
|
508
|
+
if agg_part.upper().startswith("AGGREGATE"):
|
|
509
|
+
agg_part = agg_part[9:].strip() # Remove "AGGREGATE"
|
|
510
|
+
result["agg_functions"] = self._extract_agg_functions(agg_part)
|
|
511
|
+
else:
|
|
512
|
+
# Handle SELECT COUNT(*), SUM(col), etc.
|
|
513
|
+
result["is_aggregate_over"] = False
|
|
514
|
+
# Parse the SELECT clause to extract aggregation functions
|
|
515
|
+
select_match = re.match(
|
|
516
|
+
r"SELECT\s+(.+?)\s+FROM\s+(.+)$", statement, re.IGNORECASE | re.DOTALL
|
|
517
|
+
)
|
|
518
|
+
if select_match:
|
|
519
|
+
select_clause = select_match.group(1)
|
|
520
|
+
from_clause = select_match.group(2)
|
|
521
|
+
result["agg_functions"] = self._extract_agg_functions(select_clause)
|
|
522
|
+
# Build base query to get all data
|
|
523
|
+
result["base_query"] = f"SELECT * FROM {from_clause}"
|
|
524
|
+
else:
|
|
525
|
+
# Handle SELECT without FROM (e.g., SELECT COUNT(*))
|
|
526
|
+
select_match = re.match(
|
|
527
|
+
r"SELECT\s+(.+)$", statement, re.IGNORECASE | re.DOTALL
|
|
528
|
+
)
|
|
529
|
+
if select_match:
|
|
530
|
+
select_clause = select_match.group(1)
|
|
531
|
+
result["agg_functions"] = self._extract_agg_functions(select_clause)
|
|
532
|
+
result["base_query"] = None # No base query for kindless
|
|
533
|
+
|
|
534
|
+
return result
|
|
535
|
+
|
|
536
|
+
def _extract_agg_functions(self, clause: str) -> List[Tuple[str, str, str]]:
|
|
537
|
+
"""Extract aggregation functions from a clause."""
|
|
538
|
+
functions: List[Tuple[str, str, str]] = []
|
|
539
|
+
# Pattern to match aggregation functions with optional alias
|
|
540
|
+
patterns = [
|
|
541
|
+
(
|
|
542
|
+
r"COUNT_UP_TO\s*\(\s*(\d+)\s*\)(?:\s+AS\s+(\w+))?",
|
|
543
|
+
"COUNT_UP_TO",
|
|
544
|
+
),
|
|
545
|
+
(r"COUNT\s*\(\s*\*\s*\)(?:\s+AS\s+(\w+))?", "COUNT"),
|
|
546
|
+
(r"SUM\s*\(\s*(\w+)\s*\)(?:\s+AS\s+(\w+))?", "SUM"),
|
|
547
|
+
(r"AVG\s*\(\s*(\w+)\s*\)(?:\s+AS\s+(\w+))?", "AVG"),
|
|
548
|
+
]
|
|
549
|
+
|
|
550
|
+
for pattern, func_name in patterns:
|
|
551
|
+
for match in re.finditer(pattern, clause, re.IGNORECASE):
|
|
552
|
+
if func_name == "COUNT":
|
|
553
|
+
col = "*"
|
|
554
|
+
alias = match.group(1) if match.group(1) else func_name
|
|
555
|
+
elif func_name == "COUNT_UP_TO":
|
|
556
|
+
col = match.group(1) # The limit number
|
|
557
|
+
alias = match.group(2) if match.group(2) else func_name
|
|
558
|
+
else:
|
|
559
|
+
col = match.group(1)
|
|
560
|
+
alias = match.group(2) if match.group(2) else func_name
|
|
561
|
+
functions.append((func_name, col, alias))
|
|
562
|
+
|
|
563
|
+
return functions
|
|
564
|
+
|
|
565
|
+
def _compute_aggregations(
|
|
566
|
+
self,
|
|
567
|
+
rows: List[Tuple],
|
|
568
|
+
fields: Dict[str, Any],
|
|
569
|
+
agg_functions: List[Tuple[str, str, str]],
|
|
570
|
+
) -> Tuple[List[Tuple], Dict[str, Any]]:
|
|
571
|
+
"""Compute aggregations on the data."""
|
|
572
|
+
result_values: List[Any] = []
|
|
573
|
+
result_fields: Dict[str, Any] = {}
|
|
574
|
+
|
|
575
|
+
# Get column name to index mapping
|
|
576
|
+
field_names = list(fields.keys())
|
|
577
|
+
|
|
578
|
+
for func_name, col, alias in agg_functions:
|
|
579
|
+
if func_name == "COUNT":
|
|
580
|
+
value = len(rows)
|
|
581
|
+
elif func_name == "COUNT_UP_TO":
|
|
582
|
+
limit = int(col)
|
|
583
|
+
value = min(len(rows), limit)
|
|
584
|
+
elif func_name in ("SUM", "AVG"):
|
|
585
|
+
# Find the column index
|
|
586
|
+
if col in field_names:
|
|
587
|
+
col_idx = field_names.index(col)
|
|
588
|
+
values = [row[col_idx] for row in rows if row[col_idx] is not None]
|
|
589
|
+
numeric_values = [v for v in values if isinstance(v, (int, float))]
|
|
590
|
+
if func_name == "SUM":
|
|
591
|
+
value = sum(numeric_values) if numeric_values else 0
|
|
592
|
+
else: # AVG
|
|
593
|
+
value = (
|
|
594
|
+
sum(numeric_values) / len(numeric_values)
|
|
595
|
+
if numeric_values
|
|
596
|
+
else 0
|
|
597
|
+
)
|
|
598
|
+
else:
|
|
599
|
+
value = 0
|
|
600
|
+
else:
|
|
601
|
+
value = None
|
|
602
|
+
|
|
603
|
+
result_values.append(value)
|
|
604
|
+
result_fields[alias] = (alias, None, None, None, None, None, None)
|
|
605
|
+
|
|
606
|
+
return [tuple(result_values)], result_fields
|
|
607
|
+
|
|
608
|
+
def _execute_gql_request(self, gql_statement: str) -> Response:
|
|
609
|
+
"""Execute a GQL query and return the response."""
|
|
610
|
+
body = {
|
|
611
|
+
"gqlQuery": {
|
|
612
|
+
"queryString": gql_statement,
|
|
613
|
+
"allowLiterals": True,
|
|
614
|
+
}
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
project_id = self._datastore_client.project
|
|
618
|
+
if os.getenv("DATASTORE_EMULATOR_HOST") is None:
|
|
619
|
+
credentials = getattr(
|
|
620
|
+
self._datastore_client, "scoped_credentials", None
|
|
621
|
+
)
|
|
622
|
+
if credentials is None and self._datastore_client.credentials_info:
|
|
623
|
+
credentials = service_account.Credentials.from_service_account_info(
|
|
624
|
+
self._datastore_client.credentials_info,
|
|
625
|
+
scopes=["https://www.googleapis.com/auth/datastore"],
|
|
626
|
+
)
|
|
627
|
+
if credentials is None:
|
|
628
|
+
raise ProgrammingError(
|
|
629
|
+
"No credentials available for Datastore query. "
|
|
630
|
+
"Provide credentials_info, credentials_path, or "
|
|
631
|
+
"configure Application Default Credentials."
|
|
632
|
+
)
|
|
633
|
+
authed_session = AuthorizedSession(credentials)
|
|
634
|
+
url = f"https://datastore.googleapis.com/v1/projects/{project_id}:runQuery"
|
|
635
|
+
return authed_session.post(url, json=body)
|
|
636
|
+
else:
|
|
637
|
+
host = os.environ["DATASTORE_EMULATOR_HOST"]
|
|
638
|
+
url = f"http://{host}/v1/projects/{project_id}:runQuery"
|
|
639
|
+
return requests.post(url, json=body)
|
|
640
|
+
|
|
641
|
+
def _needs_client_side_filter(self, statement: str) -> bool:
|
|
642
|
+
"""Check if the query needs client-side filtering due to unsupported ops.
|
|
643
|
+
|
|
644
|
+
Note: This should be called on the CONVERTED GQL statement (after
|
|
645
|
+
_convert_sql_to_gql), since that method handles reversing sqlglot
|
|
646
|
+
transformations like <> -> != and NOT col IN -> col NOT IN.
|
|
647
|
+
GQL natively supports: =, <, >, <=, >=, !=, IN, NOT IN, CONTAINS.
|
|
648
|
+
"""
|
|
649
|
+
upper = statement.upper()
|
|
650
|
+
unsupported_patterns = [
|
|
651
|
+
r"\bOR\b", # OR conditions need client-side evaluation
|
|
652
|
+
r"\bBLOB\s*\(", # BLOB literal (escaping issues)
|
|
653
|
+
]
|
|
654
|
+
for pattern in unsupported_patterns:
|
|
655
|
+
if re.search(pattern, upper):
|
|
656
|
+
return True
|
|
657
|
+
return False
|
|
658
|
+
|
|
659
|
+
def _extract_base_query_for_filter(self, statement: str) -> str:
|
|
660
|
+
"""Extract base query without WHERE clause for client-side filtering."""
|
|
661
|
+
# Remove WHERE clause to get all data
|
|
662
|
+
upper = statement.upper()
|
|
663
|
+
where_idx = upper.find(" WHERE ")
|
|
664
|
+
if where_idx > 0:
|
|
665
|
+
# Find the end of WHERE (before ORDER BY, LIMIT, OFFSET)
|
|
666
|
+
end_patterns = [" ORDER BY ", " LIMIT ", " OFFSET "]
|
|
667
|
+
end_idx = len(statement)
|
|
668
|
+
for pattern in end_patterns:
|
|
669
|
+
idx = upper.find(pattern, where_idx)
|
|
670
|
+
if idx > 0 and idx < end_idx:
|
|
671
|
+
end_idx = idx
|
|
672
|
+
# Remove WHERE clause
|
|
673
|
+
base = statement[:where_idx] + statement[end_idx:]
|
|
674
|
+
return base.strip()
|
|
675
|
+
return statement
|
|
676
|
+
|
|
677
|
+
def _is_missing_index_error(self, response: Response) -> bool:
|
|
678
|
+
"""Check if the GQL response indicates a missing composite index."""
|
|
679
|
+
if response.status_code not in (400, 409):
|
|
680
|
+
return False
|
|
681
|
+
try:
|
|
682
|
+
body = response.json()
|
|
683
|
+
error = body.get("error", {})
|
|
684
|
+
message = error.get("message", "").lower()
|
|
685
|
+
status = error.get("status", "")
|
|
686
|
+
return (
|
|
687
|
+
"no matching index found" in message
|
|
688
|
+
or status == "FAILED_PRECONDITION"
|
|
689
|
+
)
|
|
690
|
+
except Exception:
|
|
691
|
+
return "no matching index" in response.text.lower()
|
|
692
|
+
|
|
693
|
+
def _extract_table_only_query(self, gql_statement: str) -> str:
|
|
694
|
+
"""Extract just 'SELECT * FROM <table>' from a GQL statement."""
|
|
695
|
+
table_match = re.search(
|
|
696
|
+
r"\bFROM\s+(\w+)", gql_statement, flags=re.IGNORECASE
|
|
697
|
+
)
|
|
698
|
+
if table_match:
|
|
699
|
+
return f"SELECT * FROM {table_match.group(1)}"
|
|
700
|
+
raise ProgrammingError(
|
|
701
|
+
f"Could not extract table name from query: {gql_statement}"
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
def _parse_order_by_clause(
|
|
705
|
+
self, gql_statement: str
|
|
706
|
+
) -> List[Tuple[str, bool]]:
|
|
707
|
+
"""Parse ORDER BY clause. Returns list of (column, ascending) tuples."""
|
|
708
|
+
upper = gql_statement.upper()
|
|
709
|
+
order_idx = upper.find(" ORDER BY ")
|
|
710
|
+
if order_idx < 0:
|
|
711
|
+
return []
|
|
712
|
+
# Find end of ORDER BY (before LIMIT, OFFSET)
|
|
713
|
+
end_idx = len(gql_statement)
|
|
714
|
+
for pattern in [" LIMIT ", " OFFSET "]:
|
|
715
|
+
idx = upper.find(pattern, order_idx + 10)
|
|
716
|
+
if 0 < idx < end_idx:
|
|
717
|
+
end_idx = idx
|
|
718
|
+
order_clause = gql_statement[order_idx + 10 : end_idx].strip()
|
|
719
|
+
if not order_clause:
|
|
720
|
+
return []
|
|
721
|
+
result: List[Tuple[str, bool]] = []
|
|
722
|
+
for part in order_clause.split(","):
|
|
723
|
+
part = part.strip()
|
|
724
|
+
parts = part.split()
|
|
725
|
+
if not parts:
|
|
726
|
+
continue
|
|
727
|
+
col_name = parts[0]
|
|
728
|
+
ascending = not (len(parts) > 1 and parts[1].upper() == "DESC")
|
|
729
|
+
result.append((col_name, ascending))
|
|
730
|
+
return result
|
|
731
|
+
|
|
732
|
+
def _parse_limit_offset_clause(
|
|
733
|
+
self, gql_statement: str
|
|
734
|
+
) -> Tuple[Optional[int], int]:
|
|
735
|
+
"""Parse LIMIT and OFFSET from statement. Returns (limit, offset)."""
|
|
736
|
+
limit = None
|
|
737
|
+
offset = 0
|
|
738
|
+
limit_match = re.search(
|
|
739
|
+
r"\bLIMIT\s+(\d+)", gql_statement, flags=re.IGNORECASE
|
|
740
|
+
)
|
|
741
|
+
if limit_match:
|
|
742
|
+
limit = int(limit_match.group(1))
|
|
743
|
+
offset_match = re.search(
|
|
744
|
+
r"\bOFFSET\s+(\d+)", gql_statement, flags=re.IGNORECASE
|
|
745
|
+
)
|
|
746
|
+
if offset_match:
|
|
747
|
+
offset = int(offset_match.group(1))
|
|
748
|
+
return limit, offset
|
|
749
|
+
|
|
750
|
+
def _apply_client_side_order_by(
|
|
751
|
+
self,
|
|
752
|
+
rows: List[Tuple],
|
|
753
|
+
fields: Dict[str, Any],
|
|
754
|
+
order_keys: List[Tuple[str, bool]],
|
|
755
|
+
) -> List[Tuple]:
|
|
756
|
+
"""Sort rows on the client side based on ORDER BY specification."""
|
|
757
|
+
if not order_keys or not rows:
|
|
758
|
+
return rows
|
|
759
|
+
from functools import cmp_to_key
|
|
760
|
+
|
|
761
|
+
field_names = list(fields.keys())
|
|
762
|
+
|
|
763
|
+
def compare_rows(row_a: Tuple, row_b: Tuple) -> int:
|
|
764
|
+
for col_name, ascending in order_keys:
|
|
765
|
+
if col_name in field_names:
|
|
766
|
+
idx = field_names.index(col_name)
|
|
767
|
+
val_a = row_a[idx] if idx < len(row_a) else None
|
|
768
|
+
val_b = row_b[idx] if idx < len(row_b) else None
|
|
769
|
+
else:
|
|
770
|
+
val_a = None
|
|
771
|
+
val_b = None
|
|
772
|
+
if val_a is None and val_b is None:
|
|
773
|
+
continue
|
|
774
|
+
if val_a is None:
|
|
775
|
+
return 1 # None sorts last
|
|
776
|
+
if val_b is None:
|
|
777
|
+
return -1
|
|
778
|
+
try:
|
|
779
|
+
if val_a < val_b:
|
|
780
|
+
cmp_result = -1
|
|
781
|
+
elif val_a > val_b:
|
|
782
|
+
cmp_result = 1
|
|
783
|
+
else:
|
|
784
|
+
continue
|
|
785
|
+
except TypeError:
|
|
786
|
+
continue
|
|
787
|
+
if not ascending:
|
|
788
|
+
cmp_result = -cmp_result
|
|
789
|
+
return cmp_result
|
|
790
|
+
return 0
|
|
791
|
+
|
|
792
|
+
return sorted(rows, key=cmp_to_key(compare_rows))
|
|
793
|
+
|
|
794
|
+
def _execute_fallback_query(
|
|
795
|
+
self, original_statement: str, gql_statement: str
|
|
796
|
+
):
|
|
797
|
+
"""Execute a fallback query when the original fails due to missing index.
|
|
798
|
+
|
|
799
|
+
Fetches all data from the table and applies WHERE, ORDER BY,
|
|
800
|
+
LIMIT/OFFSET on the client side.
|
|
801
|
+
"""
|
|
802
|
+
warning_msg = (
|
|
803
|
+
"Missing index: the query requires an index that does not exist "
|
|
804
|
+
"in Datastore. Falling back to fetching ALL entities and "
|
|
805
|
+
"processing client-side (SELECT * mode). This may significantly "
|
|
806
|
+
"increase query and egress costs. Consider adding the required "
|
|
807
|
+
"composite index to avoid this."
|
|
808
|
+
)
|
|
809
|
+
logging.warning("%s Original GQL: %s", warning_msg, gql_statement)
|
|
810
|
+
self.warnings.append(warning_msg)
|
|
811
|
+
|
|
812
|
+
# Build simple query to fetch all data from the table
|
|
813
|
+
fallback_query = self._extract_table_only_query(gql_statement)
|
|
814
|
+
response = self._execute_gql_request(fallback_query)
|
|
815
|
+
if response.status_code != 200:
|
|
816
|
+
raise OperationalError(
|
|
817
|
+
f"Fallback query failed: {fallback_query} "
|
|
818
|
+
f"(original: {gql_statement})"
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
data = response.json()
|
|
822
|
+
entity_results = data.get("batch", {}).get("entityResults", [])
|
|
823
|
+
|
|
824
|
+
# Initialize cursor state for empty result
|
|
825
|
+
self._query_data = iter([])
|
|
826
|
+
self._query_rows = iter([])
|
|
827
|
+
self.rowcount = 0
|
|
828
|
+
self.description = [(None, None, None, None, None, None, None)]
|
|
829
|
+
self._last_executed = original_statement
|
|
830
|
+
self._parameters = {}
|
|
831
|
+
|
|
832
|
+
if not entity_results:
|
|
833
|
+
return
|
|
834
|
+
|
|
835
|
+
self._closed = False
|
|
836
|
+
|
|
837
|
+
# Parse entities with all columns (needed for filtering/sorting)
|
|
838
|
+
rows, fields = ParseEntity.parse(entity_results, None)
|
|
839
|
+
|
|
840
|
+
# Apply WHERE filter using the original statement to preserve
|
|
841
|
+
# binary data in BLOB literals (whitespace normalization in
|
|
842
|
+
# _convert_sql_to_gql would corrupt them).
|
|
843
|
+
rows = self._apply_client_side_filter(rows, fields, original_statement)
|
|
844
|
+
|
|
845
|
+
# Apply ORDER BY
|
|
846
|
+
order_keys = self._parse_order_by_clause(gql_statement)
|
|
847
|
+
if order_keys:
|
|
848
|
+
rows = self._apply_client_side_order_by(rows, fields, order_keys)
|
|
849
|
+
|
|
850
|
+
# Apply LIMIT/OFFSET
|
|
851
|
+
limit, offset = self._parse_limit_offset_clause(gql_statement)
|
|
852
|
+
if offset > 0:
|
|
853
|
+
rows = rows[offset:]
|
|
854
|
+
if limit is not None:
|
|
855
|
+
rows = rows[:limit]
|
|
856
|
+
|
|
857
|
+
# Project to requested columns if the original query specified them
|
|
858
|
+
selected_columns = self._parse_select_columns(original_statement)
|
|
859
|
+
if selected_columns is not None:
|
|
860
|
+
field_names = list(fields.keys())
|
|
861
|
+
projected_rows: List[Tuple] = []
|
|
862
|
+
projected_fields: Dict[str, Any] = {}
|
|
863
|
+
|
|
864
|
+
for col in selected_columns:
|
|
865
|
+
col_lower = col.lower()
|
|
866
|
+
if col_lower in ("__key__", "key") and "key" in fields:
|
|
867
|
+
projected_fields["key"] = fields["key"]
|
|
868
|
+
elif col in fields:
|
|
869
|
+
projected_fields[col] = fields[col]
|
|
870
|
+
|
|
871
|
+
for row in rows:
|
|
872
|
+
new_row: List[Any] = []
|
|
873
|
+
for col in selected_columns:
|
|
874
|
+
col_lower = col.lower()
|
|
875
|
+
lookup = "key" if col_lower in ("__key__", "key") else col
|
|
876
|
+
if lookup in field_names:
|
|
877
|
+
idx = field_names.index(lookup)
|
|
878
|
+
new_row.append(row[idx] if idx < len(row) else None)
|
|
879
|
+
else:
|
|
880
|
+
new_row.append(None)
|
|
881
|
+
projected_rows.append(tuple(new_row))
|
|
882
|
+
|
|
883
|
+
rows = projected_rows
|
|
884
|
+
fields = projected_fields
|
|
885
|
+
|
|
886
|
+
fields_list = list(fields.values())
|
|
887
|
+
self._query_data = iter(rows)
|
|
888
|
+
self._query_rows = iter(rows)
|
|
889
|
+
self.rowcount = len(rows)
|
|
890
|
+
self.description = fields_list if fields_list else None
|
|
891
|
+
|
|
892
|
+
def _apply_client_side_filter(
|
|
893
|
+
self, rows: List[Tuple], fields: Dict[str, Any], statement: str
|
|
894
|
+
) -> List[Tuple]:
|
|
895
|
+
"""Apply client-side filtering for unsupported WHERE conditions."""
|
|
896
|
+
# Parse WHERE clause and apply filters
|
|
897
|
+
upper = statement.upper()
|
|
898
|
+
where_idx = upper.find(" WHERE ")
|
|
899
|
+
if where_idx < 0:
|
|
900
|
+
return rows
|
|
901
|
+
|
|
902
|
+
# Find end of WHERE clause
|
|
903
|
+
end_patterns = [" ORDER BY ", " LIMIT ", " OFFSET "]
|
|
904
|
+
end_idx = len(statement)
|
|
905
|
+
for pattern in end_patterns:
|
|
906
|
+
idx = upper.find(pattern, where_idx)
|
|
907
|
+
if idx > 0 and idx < end_idx:
|
|
908
|
+
end_idx = idx
|
|
909
|
+
|
|
910
|
+
where_clause = statement[where_idx + 7 : end_idx].strip()
|
|
911
|
+
field_names = list(fields.keys())
|
|
912
|
+
|
|
913
|
+
# Apply filter
|
|
914
|
+
filtered_rows = []
|
|
915
|
+
for row in rows:
|
|
916
|
+
if self._evaluate_where(row, field_names, where_clause):
|
|
917
|
+
filtered_rows.append(row)
|
|
918
|
+
return filtered_rows
|
|
919
|
+
|
|
920
|
+
def _evaluate_where(
|
|
921
|
+
self, row: Tuple, field_names: List[str], where_clause: str
|
|
922
|
+
) -> bool:
|
|
923
|
+
"""Evaluate WHERE clause against a row. Returns True if row matches."""
|
|
924
|
+
# Build a context dict from the row
|
|
925
|
+
context = {}
|
|
926
|
+
for i, name in enumerate(field_names):
|
|
927
|
+
if i < len(row):
|
|
928
|
+
context[name] = row[i]
|
|
929
|
+
|
|
930
|
+
# Parse and evaluate the WHERE clause
|
|
931
|
+
# This is a simplified evaluator for common patterns
|
|
932
|
+
try:
|
|
933
|
+
return self._eval_condition(context, where_clause)
|
|
934
|
+
except Exception:
|
|
935
|
+
# If evaluation fails, include the row (fail open)
|
|
936
|
+
return True
|
|
937
|
+
|
|
938
|
+
def _eval_condition(self, context: Dict[str, Any], condition: str) -> bool:
|
|
939
|
+
"""Evaluate a single condition or compound condition."""
|
|
940
|
+
condition = condition.strip()
|
|
941
|
+
|
|
942
|
+
# Handle parentheses
|
|
943
|
+
if condition.startswith("(") and condition.endswith(")"):
|
|
944
|
+
# Find matching paren
|
|
945
|
+
depth = 0
|
|
946
|
+
for i, c in enumerate(condition):
|
|
947
|
+
if c == "(":
|
|
948
|
+
depth += 1
|
|
949
|
+
elif c == ")":
|
|
950
|
+
depth -= 1
|
|
951
|
+
if depth == 0:
|
|
952
|
+
if i == len(condition) - 1:
|
|
953
|
+
return self._eval_condition(context, condition[1:-1])
|
|
954
|
+
break
|
|
955
|
+
|
|
956
|
+
# Handle OR (lower precedence)
|
|
957
|
+
or_match = re.search(r"\bOR\b", condition, re.IGNORECASE)
|
|
958
|
+
if or_match:
|
|
959
|
+
# Split on OR, but respect parentheses
|
|
960
|
+
parts = self._split_on_operator(condition, "OR")
|
|
961
|
+
if len(parts) > 1:
|
|
962
|
+
return any(self._eval_condition(context, p) for p in parts)
|
|
963
|
+
|
|
964
|
+
# Handle AND (higher precedence)
|
|
965
|
+
and_match = re.search(r"\bAND\b", condition, re.IGNORECASE)
|
|
966
|
+
if and_match:
|
|
967
|
+
parts = self._split_on_operator(condition, "AND")
|
|
968
|
+
if len(parts) > 1:
|
|
969
|
+
return all(self._eval_condition(context, p) for p in parts)
|
|
970
|
+
|
|
971
|
+
# Handle simple comparisons
|
|
972
|
+
return self._eval_simple_condition(context, condition)
|
|
973
|
+
|
|
974
|
+
def _split_on_operator(self, condition: str, operator: str) -> List[str]:
|
|
975
|
+
"""Split condition on operator while respecting parentheses."""
|
|
976
|
+
parts: List[str] = []
|
|
977
|
+
current = ""
|
|
978
|
+
depth = 0
|
|
979
|
+
i = 0
|
|
980
|
+
pattern = re.compile(rf"\b{operator}\b", re.IGNORECASE)
|
|
981
|
+
|
|
982
|
+
while i < len(condition):
|
|
983
|
+
if condition[i] == "(":
|
|
984
|
+
depth += 1
|
|
985
|
+
current += condition[i]
|
|
986
|
+
elif condition[i] == ")":
|
|
987
|
+
depth -= 1
|
|
988
|
+
current += condition[i]
|
|
989
|
+
elif depth == 0:
|
|
990
|
+
match = pattern.match(condition[i:])
|
|
991
|
+
if match:
|
|
992
|
+
parts.append(current.strip())
|
|
993
|
+
current = ""
|
|
994
|
+
i += len(match.group()) - 1
|
|
995
|
+
else:
|
|
996
|
+
current += condition[i]
|
|
997
|
+
else:
|
|
998
|
+
current += condition[i]
|
|
999
|
+
i += 1
|
|
1000
|
+
|
|
1001
|
+
if current.strip():
|
|
1002
|
+
parts.append(current.strip())
|
|
1003
|
+
return parts
|
|
1004
|
+
|
|
1005
|
+
def _eval_simple_condition(self, context: Dict[str, Any], condition: str) -> bool:
|
|
1006
|
+
"""Evaluate a simple comparison condition."""
|
|
1007
|
+
condition = condition.strip()
|
|
1008
|
+
|
|
1009
|
+
# Handle __key__ = KEY(kind, value) comparison
|
|
1010
|
+
# Entity key is stored as "key" in context (from ParseEntity)
|
|
1011
|
+
key_eq_match = re.match(
|
|
1012
|
+
r"__key__\s*=\s*KEY\s*\(\s*\w+\s*,\s*(?:'([^']*)'|(\d+))\s*\)",
|
|
1013
|
+
condition,
|
|
1014
|
+
re.IGNORECASE,
|
|
1015
|
+
)
|
|
1016
|
+
if key_eq_match:
|
|
1017
|
+
key_name = key_eq_match.group(1)
|
|
1018
|
+
key_id = key_eq_match.group(2)
|
|
1019
|
+
field_val = context.get("key") or context.get("__key__")
|
|
1020
|
+
if isinstance(field_val, list) and len(field_val) > 0:
|
|
1021
|
+
last_path = field_val[-1]
|
|
1022
|
+
if isinstance(last_path, dict):
|
|
1023
|
+
if key_name is not None:
|
|
1024
|
+
return last_path.get("name") == key_name
|
|
1025
|
+
if key_id is not None:
|
|
1026
|
+
return str(last_path.get("id")) == key_id
|
|
1027
|
+
return False
|
|
1028
|
+
|
|
1029
|
+
# Handle BLOB equality (before generic handlers, since BLOB literal
|
|
1030
|
+
# would confuse the generic _parse_literal path)
|
|
1031
|
+
blob_eq_match = re.match(
|
|
1032
|
+
r"(\w+)\s*=\s*BLOB\s*\('(.*?)'\)",
|
|
1033
|
+
condition,
|
|
1034
|
+
re.IGNORECASE | re.DOTALL,
|
|
1035
|
+
)
|
|
1036
|
+
if blob_eq_match:
|
|
1037
|
+
field = blob_eq_match.group(1)
|
|
1038
|
+
blob_str = blob_eq_match.group(2)
|
|
1039
|
+
try:
|
|
1040
|
+
blob_bytes = blob_str.encode("latin-1")
|
|
1041
|
+
except (UnicodeEncodeError, UnicodeDecodeError):
|
|
1042
|
+
blob_bytes = blob_str.encode("utf-8")
|
|
1043
|
+
field_val = context.get(field)
|
|
1044
|
+
if isinstance(field_val, bytes):
|
|
1045
|
+
return field_val == blob_bytes
|
|
1046
|
+
return False
|
|
1047
|
+
|
|
1048
|
+
# Handle BLOB inequality
|
|
1049
|
+
blob_neq_match = re.match(
|
|
1050
|
+
r"(\w+)\s*!=\s*BLOB\s*\('(.*?)'\)",
|
|
1051
|
+
condition,
|
|
1052
|
+
re.IGNORECASE | re.DOTALL,
|
|
1053
|
+
)
|
|
1054
|
+
if blob_neq_match:
|
|
1055
|
+
field = blob_neq_match.group(1)
|
|
1056
|
+
blob_str = blob_neq_match.group(2)
|
|
1057
|
+
try:
|
|
1058
|
+
blob_bytes = blob_str.encode("latin-1")
|
|
1059
|
+
except (UnicodeEncodeError, UnicodeDecodeError):
|
|
1060
|
+
blob_bytes = blob_str.encode("utf-8")
|
|
1061
|
+
field_val = context.get(field)
|
|
1062
|
+
if isinstance(field_val, bytes):
|
|
1063
|
+
return field_val != blob_bytes
|
|
1064
|
+
return True
|
|
1065
|
+
|
|
1066
|
+
# Handle NOT IN / NOT IN ARRAY
|
|
1067
|
+
not_in_match = re.match(
|
|
1068
|
+
r"(\w+)\s+NOT\s+IN\s+(?:ARRAY\s*)?\(([^)]+)\)",
|
|
1069
|
+
condition, re.IGNORECASE,
|
|
1070
|
+
)
|
|
1071
|
+
if not_in_match:
|
|
1072
|
+
field = not_in_match.group(1)
|
|
1073
|
+
values_str = not_in_match.group(2)
|
|
1074
|
+
values = self._parse_value_list(values_str)
|
|
1075
|
+
field_val = context.get(field)
|
|
1076
|
+
return field_val not in values
|
|
1077
|
+
|
|
1078
|
+
# Handle IN / IN ARRAY
|
|
1079
|
+
in_match = re.match(
|
|
1080
|
+
r"(\w+)\s+IN\s+(?:ARRAY\s*)?\(([^)]+)\)",
|
|
1081
|
+
condition, re.IGNORECASE,
|
|
1082
|
+
)
|
|
1083
|
+
if in_match:
|
|
1084
|
+
field = in_match.group(1)
|
|
1085
|
+
values_str = in_match.group(2)
|
|
1086
|
+
values = self._parse_value_list(values_str)
|
|
1087
|
+
field_val = context.get(field)
|
|
1088
|
+
return field_val in values
|
|
1089
|
+
|
|
1090
|
+
# Handle != and <>
|
|
1091
|
+
neq_match = re.match(r"(\w+)\s*(?:!=|<>)\s*(.+)", condition, re.IGNORECASE)
|
|
1092
|
+
if neq_match:
|
|
1093
|
+
field = neq_match.group(1)
|
|
1094
|
+
value = self._parse_literal(neq_match.group(2).strip())
|
|
1095
|
+
field_val = context.get(field)
|
|
1096
|
+
return field_val != value
|
|
1097
|
+
|
|
1098
|
+
# Handle >=
|
|
1099
|
+
gte_match = re.match(r"(\w+)\s*>=\s*(.+)", condition)
|
|
1100
|
+
if gte_match:
|
|
1101
|
+
field = gte_match.group(1)
|
|
1102
|
+
value = self._parse_literal(gte_match.group(2).strip())
|
|
1103
|
+
field_val = context.get(field)
|
|
1104
|
+
if field_val is not None and value is not None:
|
|
1105
|
+
try:
|
|
1106
|
+
return field_val >= value
|
|
1107
|
+
except TypeError:
|
|
1108
|
+
return False
|
|
1109
|
+
return False
|
|
1110
|
+
|
|
1111
|
+
# Handle <=
|
|
1112
|
+
lte_match = re.match(r"(\w+)\s*<=\s*(.+)", condition)
|
|
1113
|
+
if lte_match:
|
|
1114
|
+
field = lte_match.group(1)
|
|
1115
|
+
value = self._parse_literal(lte_match.group(2).strip())
|
|
1116
|
+
field_val = context.get(field)
|
|
1117
|
+
if field_val is not None and value is not None:
|
|
1118
|
+
try:
|
|
1119
|
+
return field_val <= value
|
|
1120
|
+
except TypeError:
|
|
1121
|
+
return False
|
|
1122
|
+
return False
|
|
1123
|
+
|
|
1124
|
+
# Handle >
|
|
1125
|
+
gt_match = re.match(r"(\w+)\s*>\s*(.+)", condition)
|
|
1126
|
+
if gt_match:
|
|
1127
|
+
field = gt_match.group(1)
|
|
1128
|
+
value = self._parse_literal(gt_match.group(2).strip())
|
|
1129
|
+
field_val = context.get(field)
|
|
1130
|
+
if field_val is not None and value is not None:
|
|
1131
|
+
try:
|
|
1132
|
+
return field_val > value
|
|
1133
|
+
except TypeError:
|
|
1134
|
+
return False
|
|
1135
|
+
return False
|
|
1136
|
+
|
|
1137
|
+
# Handle <
|
|
1138
|
+
lt_match = re.match(r"(\w+)\s*<\s*(.+)", condition)
|
|
1139
|
+
if lt_match:
|
|
1140
|
+
field = lt_match.group(1)
|
|
1141
|
+
value = self._parse_literal(lt_match.group(2).strip())
|
|
1142
|
+
field_val = context.get(field)
|
|
1143
|
+
if field_val is not None and value is not None:
|
|
1144
|
+
try:
|
|
1145
|
+
return field_val < value
|
|
1146
|
+
except TypeError:
|
|
1147
|
+
return False
|
|
1148
|
+
return False
|
|
1149
|
+
|
|
1150
|
+
# Handle =
|
|
1151
|
+
eq_match = re.match(r"(\w+)\s*=\s*(.+)", condition)
|
|
1152
|
+
if eq_match:
|
|
1153
|
+
field = eq_match.group(1)
|
|
1154
|
+
value = self._parse_literal(eq_match.group(2).strip())
|
|
1155
|
+
field_val = context.get(field)
|
|
1156
|
+
return field_val == value
|
|
1157
|
+
|
|
1158
|
+
# Default: include row
|
|
1159
|
+
return True
|
|
1160
|
+
|
|
1161
|
+
def _parse_value_list(self, values_str: str) -> List[Any]:
|
|
1162
|
+
"""Parse a comma-separated list of values."""
|
|
1163
|
+
values: List[Any] = []
|
|
1164
|
+
for v in values_str.split(","):
|
|
1165
|
+
values.append(self._parse_literal(v.strip()))
|
|
1166
|
+
return values
|
|
1167
|
+
|
|
1168
|
+
def _parse_literal(self, literal: str) -> Any:
|
|
1169
|
+
"""Parse a literal value from string."""
|
|
1170
|
+
literal = literal.strip()
|
|
1171
|
+
# DATETIME literal: DATETIME('2023-01-01T00:00:00Z')
|
|
1172
|
+
datetime_match = re.match(
|
|
1173
|
+
r"DATETIME\s*\(\s*'([^']*)'\s*\)", literal, re.IGNORECASE
|
|
1174
|
+
)
|
|
1175
|
+
if datetime_match:
|
|
1176
|
+
timestamp_str = datetime_match.group(1)
|
|
1177
|
+
if timestamp_str.endswith("Z"):
|
|
1178
|
+
timestamp_str = timestamp_str.replace("Z", "+00:00")
|
|
1179
|
+
# Normalize fractional seconds to 6 digits for Python 3.10
|
|
1180
|
+
# compatibility (fromisoformat only handles 0, 3, or 6 digits).
|
|
1181
|
+
frac_match = re.match(
|
|
1182
|
+
r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})\.(\d+)(.*)",
|
|
1183
|
+
timestamp_str,
|
|
1184
|
+
)
|
|
1185
|
+
if frac_match:
|
|
1186
|
+
frac = frac_match.group(2)[:6].ljust(6, "0")
|
|
1187
|
+
timestamp_str = (
|
|
1188
|
+
frac_match.group(1) + "." + frac + frac_match.group(3)
|
|
1189
|
+
)
|
|
1190
|
+
return datetime.fromisoformat(timestamp_str)
|
|
1191
|
+
# String literal
|
|
1192
|
+
if (literal.startswith("'") and literal.endswith("'")) or (
|
|
1193
|
+
literal.startswith('"') and literal.endswith('"')
|
|
1194
|
+
):
|
|
1195
|
+
return literal[1:-1]
|
|
1196
|
+
# Boolean
|
|
1197
|
+
if literal.upper() == "TRUE":
|
|
1198
|
+
return True
|
|
1199
|
+
if literal.upper() == "FALSE":
|
|
1200
|
+
return False
|
|
1201
|
+
# NULL
|
|
1202
|
+
if literal.upper() == "NULL":
|
|
1203
|
+
return None
|
|
1204
|
+
# Number
|
|
1205
|
+
try:
|
|
1206
|
+
if "." in literal:
|
|
1207
|
+
return float(literal)
|
|
1208
|
+
return int(literal)
|
|
1209
|
+
except ValueError:
|
|
1210
|
+
return literal
|
|
1211
|
+
|
|
1212
|
+
def _is_orm_id_query(self, statement: str) -> bool:
|
|
1213
|
+
"""Check if this is an ORM-style query with table.id in WHERE clause."""
|
|
1214
|
+
upper = statement.upper()
|
|
1215
|
+
# Check for patterns like "table.id = :param" in WHERE clause
|
|
1216
|
+
return (
|
|
1217
|
+
"SELECT" in upper
|
|
1218
|
+
and ".ID" in upper
|
|
1219
|
+
and "WHERE" in upper
|
|
1220
|
+
and (":PK_" in upper or ":ID_" in upper or ".ID =" in upper)
|
|
1221
|
+
)
|
|
1222
|
+
|
|
1223
|
+
def _execute_orm_id_query(self, statement: str, parameters: dict):
|
|
1224
|
+
"""Execute an ORM-style query by ID using direct key lookup."""
|
|
1225
|
+
try:
|
|
1226
|
+
parsed = parse_one(statement)
|
|
1227
|
+
if not isinstance(parsed, exp.Select):
|
|
1228
|
+
raise ProgrammingError("Expected SELECT statement")
|
|
1229
|
+
|
|
1230
|
+
# Get table name
|
|
1231
|
+
from_arg = parsed.args.get("from") or parsed.args.get("from_")
|
|
1232
|
+
if not from_arg:
|
|
1233
|
+
raise ProgrammingError("Could not find FROM clause")
|
|
1234
|
+
table_name = from_arg.this.name if hasattr(from_arg.this, "name") else str(from_arg.this)
|
|
1235
|
+
|
|
1236
|
+
# Extract column aliases from SELECT clause FIRST (before querying)
|
|
1237
|
+
# This ensures we have description even when no entity is found
|
|
1238
|
+
column_info = []
|
|
1239
|
+
for expr in parsed.expressions:
|
|
1240
|
+
if isinstance(expr, exp.Alias):
|
|
1241
|
+
alias = expr.alias
|
|
1242
|
+
if isinstance(expr.this, exp.Column):
|
|
1243
|
+
col_name = expr.this.name
|
|
1244
|
+
else:
|
|
1245
|
+
col_name = str(expr.this)
|
|
1246
|
+
column_info.append((col_name, alias))
|
|
1247
|
+
elif isinstance(expr, exp.Column):
|
|
1248
|
+
col_name = expr.name
|
|
1249
|
+
column_info.append((col_name, col_name))
|
|
1250
|
+
elif isinstance(expr, exp.Star):
|
|
1251
|
+
# SELECT * - we'll handle this after fetching entity
|
|
1252
|
+
column_info = None
|
|
1253
|
+
break
|
|
1254
|
+
|
|
1255
|
+
# Build description from column info (for non-SELECT * cases)
|
|
1256
|
+
if column_info is not None:
|
|
1257
|
+
field_names = [alias for _, alias in column_info]
|
|
1258
|
+
self.description = [
|
|
1259
|
+
(name, None, None, None, None, None, None)
|
|
1260
|
+
for name in field_names
|
|
1261
|
+
]
|
|
1262
|
+
|
|
1263
|
+
# Extract ID from WHERE clause
|
|
1264
|
+
where = parsed.args.get("where")
|
|
1265
|
+
if not where:
|
|
1266
|
+
raise ProgrammingError("Expected WHERE clause")
|
|
1267
|
+
|
|
1268
|
+
entity_key_id = self._extract_key_id_from_where(where, parameters)
|
|
1269
|
+
if entity_key_id is None:
|
|
1270
|
+
raise ProgrammingError("Could not extract key ID from WHERE")
|
|
1271
|
+
|
|
1272
|
+
# Fetch entity by key
|
|
1273
|
+
key = self._datastore_client.key(table_name, entity_key_id)
|
|
1274
|
+
entity = self._datastore_client.get(key)
|
|
1275
|
+
|
|
1276
|
+
if entity is None:
|
|
1277
|
+
# No entity found - description is already set above
|
|
1278
|
+
self._query_rows = iter([])
|
|
1279
|
+
self.rowcount = 0
|
|
1280
|
+
# For SELECT *, set empty description since we don't know the schema
|
|
1281
|
+
if column_info is None:
|
|
1282
|
+
self.description = []
|
|
1283
|
+
return
|
|
1284
|
+
|
|
1285
|
+
# Build result row
|
|
1286
|
+
if column_info is None:
|
|
1287
|
+
# SELECT * case
|
|
1288
|
+
row_values = [entity.key.id] # Add id first
|
|
1289
|
+
field_names = ["id"]
|
|
1290
|
+
for prop_name in sorted(entity.keys()):
|
|
1291
|
+
row_values.append(entity[prop_name])
|
|
1292
|
+
field_names.append(prop_name)
|
|
1293
|
+
# Build description for SELECT *
|
|
1294
|
+
self.description = [
|
|
1295
|
+
(name, None, None, None, None, None, None)
|
|
1296
|
+
for name in field_names
|
|
1297
|
+
]
|
|
1298
|
+
else:
|
|
1299
|
+
row_values = []
|
|
1300
|
+
for col_name, alias in column_info:
|
|
1301
|
+
if col_name.lower() == "id":
|
|
1302
|
+
row_values.append(entity.key.id)
|
|
1303
|
+
else:
|
|
1304
|
+
row_values.append(entity.get(col_name))
|
|
1305
|
+
|
|
1306
|
+
self._query_rows = iter([tuple(row_values)])
|
|
1307
|
+
self.rowcount = 1
|
|
1308
|
+
|
|
1309
|
+
except Exception as e:
|
|
1310
|
+
logging.error(f"ORM ID query failed: {e}")
|
|
1311
|
+
raise ProgrammingError(f"ORM ID query failed: {e}") from e
|
|
1312
|
+
|
|
1313
|
+
def _substitute_parameters(self, statement: str, parameters: dict) -> str:
|
|
1314
|
+
"""Substitute named parameters in SQL statement with their values."""
|
|
1315
|
+
result = statement
|
|
1316
|
+
for param_name, value in parameters.items():
|
|
1317
|
+
# Build the placeholder pattern (e.g., :param_name)
|
|
1318
|
+
placeholder = f":{param_name}"
|
|
1319
|
+
|
|
1320
|
+
# Format the value appropriately for GQL
|
|
1321
|
+
if value is None:
|
|
1322
|
+
formatted_value = "NULL"
|
|
1323
|
+
elif isinstance(value, str):
|
|
1324
|
+
# Escape single quotes in strings
|
|
1325
|
+
escaped = value.replace("'", "''")
|
|
1326
|
+
formatted_value = f"'{escaped}'"
|
|
1327
|
+
elif isinstance(value, bool):
|
|
1328
|
+
formatted_value = "true" if value else "false"
|
|
1329
|
+
elif isinstance(value, (int, float)):
|
|
1330
|
+
formatted_value = str(value)
|
|
1331
|
+
elif isinstance(value, datetime):
|
|
1332
|
+
# Format as ISO string for GQL
|
|
1333
|
+
formatted_value = f"DATETIME('{value.isoformat()}')"
|
|
1334
|
+
else:
|
|
1335
|
+
# Default to string representation
|
|
1336
|
+
formatted_value = f"'{str(value)}'"
|
|
1337
|
+
|
|
1338
|
+
result = result.replace(placeholder, formatted_value)
|
|
1339
|
+
|
|
1340
|
+
return result
|
|
1341
|
+
|
|
1342
|
+
def gql_query(self, statement, parameters=None, **kwargs):
|
|
1343
|
+
"""Execute a GQL query with support for aggregations."""
|
|
1344
|
+
|
|
1345
|
+
# Check for ORM-style queries with table.id in WHERE clause
|
|
1346
|
+
if parameters and self._is_orm_id_query(statement):
|
|
1347
|
+
self._execute_orm_id_query(statement, parameters)
|
|
1348
|
+
return
|
|
1349
|
+
|
|
1350
|
+
# Substitute parameters if provided
|
|
1351
|
+
if parameters:
|
|
1352
|
+
statement = self._substitute_parameters(statement, parameters)
|
|
1353
|
+
|
|
1354
|
+
# Convert SQL to GQL-compatible format
|
|
1355
|
+
gql_statement = self._convert_sql_to_gql(statement)
|
|
1356
|
+
logging.debug(f"Converted GQL statement: {gql_statement}")
|
|
1357
|
+
|
|
1358
|
+
# Check if this is an aggregation query
|
|
1359
|
+
if self._is_aggregation_query(statement):
|
|
1360
|
+
self._execute_aggregation_query(statement, parameters)
|
|
1361
|
+
return
|
|
1362
|
+
|
|
1363
|
+
# Check if we need client-side filtering (check converted GQL)
|
|
1364
|
+
needs_filter = self._needs_client_side_filter(gql_statement)
|
|
1365
|
+
if needs_filter:
|
|
1366
|
+
# Get base query without unsupported WHERE conditions
|
|
1367
|
+
base_query = self._extract_base_query_for_filter(gql_statement)
|
|
1368
|
+
gql_statement = self._convert_sql_to_gql(base_query)
|
|
1369
|
+
|
|
1370
|
+
# Execute GQL query
|
|
1371
|
+
response = self._execute_gql_request(gql_statement)
|
|
1372
|
+
|
|
1373
|
+
if response.status_code == 200:
|
|
1374
|
+
data = response.json()
|
|
1375
|
+
logging.debug(data)
|
|
1376
|
+
else:
|
|
1377
|
+
# Fall back to client-side processing for any GQL failure.
|
|
1378
|
+
# The emulator may return 400 (INVALID_ARGUMENT for !=, NOT IN,
|
|
1379
|
+
# multi-value IN ARRAY), 409/400 (missing composite index), or
|
|
1380
|
+
# 500 (server error). In all cases we fetch all data from the
|
|
1381
|
+
# table and apply WHERE, ORDER BY, LIMIT/OFFSET client-side.
|
|
1382
|
+
# If even the simple fallback query fails, it raises an error.
|
|
1383
|
+
warning_msg = (
|
|
1384
|
+
"GQL query failed. Falling back to fetching ALL entities "
|
|
1385
|
+
"and processing client-side (SELECT * mode). This may "
|
|
1386
|
+
"significantly increase query and egress costs. Consider "
|
|
1387
|
+
"adding the required index for this query."
|
|
1388
|
+
)
|
|
1389
|
+
logging.warning(
|
|
1390
|
+
"%s (status %d)", warning_msg, response.status_code
|
|
1391
|
+
)
|
|
1392
|
+
self.warnings.append(warning_msg)
|
|
1393
|
+
self._execute_fallback_query(statement, statement)
|
|
1394
|
+
return
|
|
1395
|
+
|
|
1396
|
+
self._query_data = iter([])
|
|
1397
|
+
self._query_rows = iter([])
|
|
1398
|
+
self.rowcount = 0
|
|
1399
|
+
self.description = [(None, None, None, None, None, None, None)]
|
|
1400
|
+
self._last_executed = statement
|
|
1401
|
+
self._parameters = parameters or {}
|
|
1402
|
+
|
|
1403
|
+
data = data.get("batch", {}).get("entityResults", [])
|
|
1404
|
+
if len(data) == 0:
|
|
1405
|
+
return # Everything is already set for an empty result
|
|
1406
|
+
|
|
1407
|
+
# Determine if this statement is expected to return rows (e.g., SELECT)
|
|
1408
|
+
# A simple keyword check is used here; a flag or a full parse would be more robust.
|
|
1409
|
+
# Statements that do not start with SELECT fall through to the DML branch below.
|
|
1410
|
+
is_select_statement = statement.upper().strip().startswith("SELECT")
|
|
1411
|
+
|
|
1412
|
+
if is_select_statement:
|
|
1413
|
+
self._closed = False # For SELECT, cursor should remain open to fetch rows
|
|
1414
|
+
|
|
1415
|
+
# Parse the SELECT statement to get column list
|
|
1416
|
+
selected_columns = self._parse_select_columns(statement)
|
|
1417
|
+
|
|
1418
|
+
rows, fields = ParseEntity.parse(data, selected_columns)
|
|
1419
|
+
|
|
1420
|
+
# Apply client-side filtering if needed.
|
|
1421
|
+
# Use the original statement (not the converted GQL) to preserve
|
|
1422
|
+
# binary data inside BLOB literals that whitespace normalization
|
|
1423
|
+
# in _convert_sql_to_gql would corrupt.
|
|
1424
|
+
if needs_filter:
|
|
1425
|
+
rows = self._apply_client_side_filter(rows, fields, statement)
|
|
1426
|
+
|
|
1427
|
+
fields = list(fields.values())
|
|
1428
|
+
self._query_data = iter(rows)
|
|
1429
|
+
self._query_rows = iter(rows)
|
|
1430
|
+
self.rowcount = len(rows)
|
|
1431
|
+
self.description = fields if len(fields) > 0 else None
|
|
1432
|
+
else:
|
|
1433
|
+
# For INSERT/UPDATE/DELETE the operation is already complete and yields no rows;
|
|
1434
|
+
# set rowcount from the returned batch (when available) and close the cursor.
|
|
1435
|
+
affected_count = len(data) if isinstance(data, list) else 0
|
|
1436
|
+
self.rowcount = affected_count
|
|
1437
|
+
self._closed = True
|
|
1438
|
+
|
|
1439
|
+
def _execute_aggregation_query(self, statement: str, parameters=None):
|
|
1440
|
+
"""Execute an aggregation query with client-side aggregation."""
|
|
1441
|
+
parsed = self._parse_aggregation_query(statement)
|
|
1442
|
+
agg_functions = parsed["agg_functions"]
|
|
1443
|
+
base_query = parsed["base_query"]
|
|
1444
|
+
|
|
1445
|
+
# If there are no aggregation functions, return an empty result
|
|
1446
|
+
if not agg_functions:
|
|
1447
|
+
self._query_rows = iter([])
|
|
1448
|
+
self.rowcount = 0
|
|
1449
|
+
self.description = []
|
|
1450
|
+
return
|
|
1451
|
+
|
|
1452
|
+
# If there's no base query (e.g., SELECT COUNT(*) without FROM)
|
|
1453
|
+
# Return zero-valued aggregates, since there is no kind to query
|
|
1454
|
+
if base_query is None:
|
|
1455
|
+
# For kindless COUNT(*), we return 0 since we can't query all kinds
|
|
1456
|
+
result_values: List[Any] = []
|
|
1457
|
+
result_fields: Dict[str, Any] = {}
|
|
1458
|
+
for func_name, col, alias in agg_functions:
|
|
1459
|
+
if func_name == "COUNT":
|
|
1460
|
+
result_values.append(0)
|
|
1461
|
+
elif func_name == "COUNT_UP_TO":
|
|
1462
|
+
result_values.append(0)
|
|
1463
|
+
else:
|
|
1464
|
+
result_values.append(0)
|
|
1465
|
+
result_fields[alias] = (alias, None, None, None, None, None, None)
|
|
1466
|
+
|
|
1467
|
+
self._query_rows = iter([tuple(result_values)])
|
|
1468
|
+
self.rowcount = 1
|
|
1469
|
+
self.description = list(result_fields.values())
|
|
1470
|
+
return
|
|
1471
|
+
|
|
1472
|
+
# Convert to GQL first, then check for client-side filtering
|
|
1473
|
+
base_gql = self._convert_sql_to_gql(base_query)
|
|
1474
|
+
original_base_gql = base_gql # Save for potential fallback
|
|
1475
|
+
needs_filter = self._needs_client_side_filter(base_gql)
|
|
1476
|
+
if needs_filter:
|
|
1477
|
+
filter_query = self._extract_base_query_for_filter(base_gql)
|
|
1478
|
+
base_gql = self._convert_sql_to_gql(filter_query)
|
|
1479
|
+
|
|
1480
|
+
response = self._execute_gql_request(base_gql)
|
|
1481
|
+
|
|
1482
|
+
if response.status_code != 200:
|
|
1483
|
+
warning_msg = (
|
|
1484
|
+
"Aggregation base query failed. Falling back to fetching "
|
|
1485
|
+
"ALL entities and aggregating client-side (SELECT * mode). "
|
|
1486
|
+
"This may significantly increase query and egress costs. "
|
|
1487
|
+
"Consider adding the required index for the columns used "
|
|
1488
|
+
"in this aggregation."
|
|
1489
|
+
)
|
|
1490
|
+
logging.warning(
|
|
1491
|
+
"%s (status %d)", warning_msg, response.status_code
|
|
1492
|
+
)
|
|
1493
|
+
self.warnings.append(warning_msg)
|
|
1494
|
+
fallback_query = self._extract_table_only_query(
|
|
1495
|
+
original_base_gql
|
|
1496
|
+
)
|
|
1497
|
+
response = self._execute_gql_request(fallback_query)
|
|
1498
|
+
if response.status_code != 200:
|
|
1499
|
+
raise OperationalError(
|
|
1500
|
+
f"Aggregation fallback query failed: "
|
|
1501
|
+
f"{fallback_query} (original: {statement})"
|
|
1502
|
+
)
|
|
1503
|
+
fb_data = response.json()
|
|
1504
|
+
fb_results = fb_data.get("batch", {}).get(
|
|
1505
|
+
"entityResults", []
|
|
1506
|
+
)
|
|
1507
|
+
if not fb_results:
|
|
1508
|
+
result_values: List[Any] = []
|
|
1509
|
+
result_fields: Dict[str, Any] = {}
|
|
1510
|
+
for _fn, _col, alias in agg_functions:
|
|
1511
|
+
result_values.append(0)
|
|
1512
|
+
result_fields[alias] = (
|
|
1513
|
+
alias, None, None, None, None, None, None
|
|
1514
|
+
)
|
|
1515
|
+
self._query_rows = iter([tuple(result_values)])
|
|
1516
|
+
self.rowcount = 1
|
|
1517
|
+
self.description = list(result_fields.values())
|
|
1518
|
+
return
|
|
1519
|
+
rows, fields = ParseEntity.parse(fb_results, None)
|
|
1520
|
+
rows = self._apply_client_side_filter(
|
|
1521
|
+
rows, fields, base_query
|
|
1522
|
+
)
|
|
1523
|
+
agg_rows, agg_fields = self._compute_aggregations(
|
|
1524
|
+
rows, fields, agg_functions
|
|
1525
|
+
)
|
|
1526
|
+
self._query_rows = iter(agg_rows)
|
|
1527
|
+
self.rowcount = len(agg_rows)
|
|
1528
|
+
self.description = list(agg_fields.values())
|
|
1529
|
+
return
|
|
1530
|
+
|
|
1531
|
+
data = response.json()
|
|
1532
|
+
entity_results = data.get("batch", {}).get("entityResults", [])
|
|
1533
|
+
|
|
1534
|
+
if len(entity_results) == 0:
|
|
1535
|
+
# No data - return aggregations with 0 values
|
|
1536
|
+
result_values = []
|
|
1537
|
+
result_fields: Dict[str, Any] = {}
|
|
1538
|
+
for func_name, _col, alias in agg_functions:
|
|
1539
|
+
if func_name == "COUNT":
|
|
1540
|
+
result_values.append(0)
|
|
1541
|
+
elif func_name == "COUNT_UP_TO":
|
|
1542
|
+
result_values.append(0)
|
|
1543
|
+
elif func_name in ("SUM", "AVG"):
|
|
1544
|
+
result_values.append(0)
|
|
1545
|
+
else:
|
|
1546
|
+
result_values.append(None)
|
|
1547
|
+
result_fields[alias] = (alias, None, None, None, None, None, None)
|
|
1548
|
+
|
|
1549
|
+
self._query_rows = iter([tuple(result_values)])
|
|
1550
|
+
self.rowcount = 1
|
|
1551
|
+
self.description = list(result_fields.values())
|
|
1552
|
+
return
|
|
1553
|
+
|
|
1554
|
+
# Parse the entity results
|
|
1555
|
+
rows, fields = ParseEntity.parse(entity_results, None)
|
|
1556
|
+
|
|
1557
|
+
# Apply client-side filtering if needed
|
|
1558
|
+
if needs_filter:
|
|
1559
|
+
rows = self._apply_client_side_filter(rows, fields, base_query)
|
|
1560
|
+
|
|
1561
|
+
# Compute aggregations
|
|
1562
|
+
agg_rows, agg_fields = self._compute_aggregations(rows, fields, agg_functions)
|
|
1563
|
+
|
|
1564
|
+
self._query_rows = iter(agg_rows)
|
|
1565
|
+
self.rowcount = len(agg_rows)
|
|
1566
|
+
self.description = list(agg_fields.values())
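# Illustrative sketch (hypothetical kind, property names and values):
#   cursor.gql_query("SELECT COUNT(*) AS n, AVG(price) AS avg_price FROM product")
#   cursor.fetchone()   # e.g. (42, 19.9); description carries the aliases "n" and "avg_price"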
|
|
1567
|
+
|
|
1568
|
+
def execute_orm(
|
|
1569
|
+
self, statement: str, parameters=None, tokens: Optional[List[tokens.Token]] = None
|
|
1570
|
+
):
|
|
1571
|
+
if parameters is None:
|
|
1572
|
+
parameters = {}
|
|
1573
|
+
|
|
1574
|
+
logging.debug(
|
|
1575
|
+
f"[DataStore DBAPI] Executing ORM query: {statement} with parameters: {parameters}"
|
|
1576
|
+
)
|
|
1577
|
+
|
|
1578
|
+
statement = statement.replace("`", "'")
|
|
1579
|
+
parsed = parse_one(statement)
|
|
1580
|
+
# Note: sqlglot versions differ on the arg key ("from" vs "from_"), so check both
|
|
1581
|
+
from_arg = parsed.args.get("from") or parsed.args.get("from_")
|
|
1582
|
+
if not isinstance(parsed, exp.Select) or not from_arg:
|
|
1583
|
+
raise ProgrammingError("Unsupported ORM query structure.")
|
|
1584
|
+
|
|
1585
|
+
from_clause = from_arg.this
|
|
1586
|
+
if not isinstance(from_clause, exp.Subquery):
|
|
1587
|
+
raise ProgrammingError("Expected a subquery in the FROM clause.")
|
|
1588
|
+
|
|
1589
|
+
subquery_sql = from_clause.this.sql()
|
|
1590
|
+
|
|
1591
|
+
# 1. Query the subquery table
|
|
1592
|
+
self.gql_query(subquery_sql)
|
|
1593
|
+
subquery_results = self.fetchall()
|
|
1594
|
+
subquery_description = self.description
|
|
1595
|
+
|
|
1596
|
+
# 2. Turn to pandas dataframe
|
|
1597
|
+
if not subquery_description:
|
|
1598
|
+
df = pd.DataFrame(subquery_results)
|
|
1599
|
+
else:
|
|
1600
|
+
column_names = [col[0] for col in subquery_description]
|
|
1601
|
+
df = pd.DataFrame(subquery_results, columns=column_names)
|
|
1602
|
+
|
|
1603
|
+
# Add computed columns from SELECT expressions before grouping or ordering
|
|
1604
|
+
for p in parsed.expressions:
|
|
1605
|
+
if isinstance(p, exp.Alias) and not p.find(exp.AggFunc):
|
|
1606
|
+
# This is a simplified expression evaluator for computed columns.
|
|
1607
|
+
# It strips double quotes from identifiers ("col" -> col) and leaves the rest unchanged.
|
|
1608
|
+
expr_str = re.sub(r'"(\w+)"', r"\1", p.this.sql())
|
|
1609
|
+
try:
|
|
1610
|
+
# Use assign to add new columns based on expressions
|
|
1611
|
+
df = df.assign(**{p.alias: df.eval(expr_str, engine="python")})
|
|
1612
|
+
except Exception as e:
|
|
1613
|
+
logging.warning(f"Could not evaluate expression '{expr_str}': {e}")
|
|
1614
|
+
|
|
1615
|
+
# 3. Apply outer query logic (aggregations and GROUP BY)
|
|
1616
|
+
has_agg = any(
|
|
1617
|
+
isinstance(p, exp.Alias) and p.find(exp.AggFunc)
|
|
1618
|
+
for p in parsed.expressions
|
|
1619
|
+
)
|
|
1620
|
+
|
|
1621
|
+
if parsed.args.get("group"):
|
|
1622
|
+
group_by_cols = []
|
|
1623
|
+
for e in parsed.args.get("group").expressions:
|
|
1624
|
+
col_name = e.name if hasattr(e, "name") else ""
|
|
1625
|
+
if col_name and col_name in df.columns:
|
|
1626
|
+
group_by_cols.append(col_name)
|
|
1627
|
+
else:
|
|
1628
|
+
# Function expression (e.g. DATETIME_TRUNC) — find
|
|
1629
|
+
# the matching alias in the SELECT clause.
|
|
1630
|
+
expr_sql = e.sql()
|
|
1631
|
+
matched = False
|
|
1632
|
+
for p in parsed.expressions:
|
|
1633
|
+
if isinstance(p, exp.Alias) and p.this.sql() == expr_sql:
|
|
1634
|
+
group_by_cols.append(p.alias)
|
|
1635
|
+
matched = True
|
|
1636
|
+
break
|
|
1637
|
+
if not matched:
|
|
1638
|
+
group_by_cols.append(col_name)
|
|
1639
|
+
|
|
1640
|
+
# Convert unhashable types (lists, dicts) to hashable types for groupby.
|
|
1641
|
+
# Datastore keys are stored as lists of dicts, GeoPoints as dicts.
|
|
1642
|
+
converted_cols = {}
|
|
1643
|
+
for col in group_by_cols:
|
|
1644
|
+
if col in df.columns:
|
|
1645
|
+
sample = df[col].dropna().head(1)
|
|
1646
|
+
if len(sample) > 0 and isinstance(sample.iloc[0], list):
|
|
1647
|
+
converted_cols[col] = df[col].apply(
|
|
1648
|
+
lambda x: tuple(
|
|
1649
|
+
tuple(sorted(d.items()))
|
|
1650
|
+
if isinstance(d, dict)
|
|
1651
|
+
else d
|
|
1652
|
+
for d in x
|
|
1653
|
+
)
|
|
1654
|
+
if isinstance(x, list)
|
|
1655
|
+
else x
|
|
1656
|
+
)
|
|
1657
|
+
df[col] = converted_cols[col]
|
|
1658
|
+
elif len(sample) > 0 and isinstance(sample.iloc[0], dict):
|
|
1659
|
+
converted_cols[col] = df[col].apply(
|
|
1660
|
+
lambda x: tuple(sorted(x.items()))
|
|
1661
|
+
if isinstance(x, dict)
|
|
1662
|
+
else x
|
|
1663
|
+
)
|
|
1664
|
+
df[col] = converted_cols[col]
|
|
1665
|
+
|
|
1666
|
+
col_renames = {}
|
|
1667
|
+
for p in parsed.expressions:
|
|
1668
|
+
if isinstance(p, exp.Alias) and p.find(exp.AggFunc):
|
|
1669
|
+
agg_func = p.this
|
|
1670
|
+
agg_func_name = agg_func.key.lower()
|
|
1671
|
+
# Map SQL aggregate names to pandas equivalents
|
|
1672
|
+
sql_to_pandas_agg = {"avg": "mean"}
|
|
1673
|
+
agg_func_name = sql_to_pandas_agg.get(
|
|
1674
|
+
agg_func_name, agg_func_name
|
|
1675
|
+
)
|
|
1676
|
+
if agg_func.expressions:
|
|
1677
|
+
original_col_name = agg_func.expressions[0].name
|
|
1678
|
+
elif isinstance(agg_func.this, exp.Distinct):
|
|
1679
|
+
# COUNT(DISTINCT col) - column is inside Distinct
|
|
1680
|
+
original_col_name = (
|
|
1681
|
+
agg_func.this.expressions[0].name
|
|
1682
|
+
)
|
|
1683
|
+
# Use pandas nunique for COUNT(DISTINCT)
|
|
1684
|
+
agg_func_name = "nunique"
|
|
1685
|
+
elif isinstance(agg_func.this, exp.Star):
|
|
1686
|
+
# COUNT(*) - use first group_by column for counting
|
|
1687
|
+
original_col_name = group_by_cols[0]
|
|
1688
|
+
elif agg_func.this is not None and hasattr(
|
|
1689
|
+
agg_func.this, "name"
|
|
1690
|
+
):
|
|
1691
|
+
original_col_name = agg_func.this.name
|
|
1692
|
+
else:
|
|
1693
|
+
# Fallback for unknown structures
|
|
1694
|
+
original_col_name = group_by_cols[0]
|
|
1695
|
+
desired_sql_alias = p.alias_or_name
|
|
1696
|
+
col_renames = {"temp_agg": desired_sql_alias}
|
|
1697
|
+
df = (
|
|
1698
|
+
df.groupby(group_by_cols)
|
|
1699
|
+
.agg(temp_agg=(original_col_name, agg_func_name))
|
|
1700
|
+
.reset_index()
|
|
1701
|
+
.rename(columns=col_renames)
|
|
1702
|
+
)
|
|
1703
|
+
|
|
1704
|
+
elif has_agg:
|
|
1705
|
+
# Aggregation without GROUP BY (e.g., SELECT COUNT(*) FROM table)
|
|
1706
|
+
result_data: Dict[str, Any] = {}
|
|
1707
|
+
for p in parsed.expressions:
|
|
1708
|
+
if not isinstance(p, exp.Alias) or not p.find(exp.AggFunc):
|
|
1709
|
+
continue
|
|
1710
|
+
agg_func = p.this
|
|
1711
|
+
agg_func_name = agg_func.key.lower()
|
|
1712
|
+
alias = p.alias_or_name
|
|
1713
|
+
|
|
1714
|
+
if agg_func_name == "count":
|
|
1715
|
+
if isinstance(agg_func.this, exp.Star):
|
|
1716
|
+
result_data[alias] = len(df)
|
|
1717
|
+
elif isinstance(agg_func.this, exp.Distinct):
|
|
1718
|
+
col_name = agg_func.this.expressions[0].name
|
|
1719
|
+
result_data[alias] = df[col_name].nunique()
|
|
1720
|
+
elif agg_func.expressions:
|
|
1721
|
+
col_name = agg_func.expressions[0].name
|
|
1722
|
+
result_data[alias] = df[col_name].count()
|
|
1723
|
+
else:
|
|
1724
|
+
result_data[alias] = len(df)
|
|
1725
|
+
elif agg_func_name == "sum":
|
|
1726
|
+
col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
|
|
1727
|
+
result_data[alias] = df[col_name].sum()
|
|
1728
|
+
elif agg_func_name == "avg":
|
|
1729
|
+
col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
|
|
1730
|
+
result_data[alias] = df[col_name].mean()
|
|
1731
|
+
elif agg_func_name == "min":
|
|
1732
|
+
col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
|
|
1733
|
+
result_data[alias] = df[col_name].min()
|
|
1734
|
+
elif agg_func_name == "max":
|
|
1735
|
+
col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
|
|
1736
|
+
result_data[alias] = df[col_name].max()
|
|
1737
|
+
else:
|
|
1738
|
+
result_data[alias] = None
|
|
1739
|
+
|
|
1740
|
+
df = pd.DataFrame([result_data])
|
|
1741
|
+
|
|
1742
|
+
if parsed.args.get("order"):
|
|
1743
|
+
order_by_cols = [e.this.name for e in parsed.args["order"].expressions]
|
|
1744
|
+
ascending = [
|
|
1745
|
+
not e.args.get("desc", False) for e in parsed.args["order"].expressions
|
|
1746
|
+
]
|
|
1747
|
+
# Convert uncomparable types (dicts, lists) to strings for sorting.
|
|
1748
|
+
# Datastore keys are lists of dicts and GeoPoints are dicts, which
|
|
1749
|
+
# cannot be compared with < in Python 3.
|
|
1750
|
+
for col in order_by_cols:
|
|
1751
|
+
if col in df.columns:
|
|
1752
|
+
sample = df[col].dropna().head(1)
|
|
1753
|
+
if len(sample) > 0 and isinstance(
|
|
1754
|
+
sample.iloc[0], (dict, list)
|
|
1755
|
+
):
|
|
1756
|
+
df[col] = df[col].apply(
|
|
1757
|
+
lambda x: str(x) if isinstance(x, (dict, list)) else x
|
|
1758
|
+
)
|
|
1759
|
+
df = df.sort_values(by=order_by_cols, ascending=ascending)
|
|
1760
|
+
|
|
1761
|
+
if parsed.args.get("limit"):
|
|
1762
|
+
limit = int(parsed.args["limit"].expression.sql())
|
|
1763
|
+
df = df.head(limit)
|
|
1764
|
+
|
|
1765
|
+
# Final column selection
|
|
1766
|
+
if not any(isinstance(p, exp.Star) for p in parsed.expressions):
|
|
1767
|
+
final_columns = [p.alias_or_name for p in parsed.expressions]
|
|
1768
|
+
# Ensure all selected columns exist in the DataFrame before selecting
|
|
1769
|
+
df = df[[col for col in final_columns if col in df.columns]]
|
|
1770
|
+
|
|
1771
|
+
# Finalize results
|
|
1772
|
+
rows = [tuple(x) for x in df.to_numpy()]
|
|
1773
|
+
schema = self._create_schema_from_df(df)
|
|
1774
|
+
self.rowcount = len(rows) if rows else 0
|
|
1775
|
+
self._set_description(schema)
|
|
1776
|
+
self._query_rows = iter(rows)
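# Illustrative sketch: execute_orm expects an outer SELECT whose FROM clause is a
# subquery (hypothetical kind and property names):
#   cursor.execute_orm(
#       "SELECT category, COUNT(*) AS cnt FROM (SELECT * FROM product) GROUP BY category"
#   )
#   cursor.fetchall()   # grouped counts computed with pandas over the subquery rows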
|
|
1777
|
+
|
|
1778
|
+
def _create_schema_from_df(self, df: pd.DataFrame) -> tuple:
|
|
1779
|
+
"""Create schema from a pandas DataFrame."""
|
|
1780
|
+
schema = []
|
|
1781
|
+
for col_name, dtype in df.dtypes.items():
|
|
1782
|
+
if pd.api.types.is_string_dtype(dtype):
|
|
1783
|
+
sa_type = types.String
|
|
1784
|
+
elif pd.api.types.is_integer_dtype(dtype):
|
|
1785
|
+
sa_type = types.Integer
|
|
1786
|
+
elif pd.api.types.is_float_dtype(dtype):
|
|
1787
|
+
sa_type = types.Float
|
|
1788
|
+
elif pd.api.types.is_bool_dtype(dtype):
|
|
1789
|
+
sa_type = types.Boolean
|
|
1790
|
+
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
|
1791
|
+
sa_type = types.DateTime
|
|
1792
|
+
else:
|
|
1793
|
+
sa_type = types.String # Fallback
|
|
1794
|
+
|
|
1795
|
+
schema.append(
|
|
1796
|
+
Column(
|
|
1797
|
+
name=col_name,
|
|
1798
|
+
type_code=sa_type(),
|
|
1799
|
+
display_size=None,
|
|
1800
|
+
internal_size=None,
|
|
1801
|
+
precision=None,
|
|
1802
|
+
scale=None,
|
|
1803
|
+
null_ok=True,
|
|
1804
|
+
)
|
|
1805
|
+
)
|
|
1806
|
+
return tuple(schema)
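# Illustrative sketch: a DataFrame with an object column "name" and an int64 column
# "age" yields roughly (Column(name="name", type_code=String(), ...), Column(name="age", type_code=Integer(), ...)).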
|
|
1807
|
+
|
|
1808
|
+
def _set_description(self, schema: tuple = ()):
|
|
1809
|
+
"""Set the cursor description based on the schema."""
|
|
1810
|
+
self.description = schema
|
|
1811
|
+
|
|
1812
|
+
def fetchall(self):
|
|
1813
|
+
if self._closed:
|
|
1814
|
+
raise Error("Cursor is closed.")
|
|
1815
|
+
return list(self._query_rows)
|
|
1816
|
+
|
|
1817
|
+
def fetchmany(self, size=None):
|
|
1818
|
+
if self._closed:
|
|
1819
|
+
raise Error("Cursor is closed.")
|
|
1820
|
+
if size is None:
|
|
1821
|
+
size = self.arraysize or 1
|
|
1822
|
+
results = []
|
|
1823
|
+
for _ in range(size):
|
|
1824
|
+
try:
|
|
1825
|
+
results.append(next(self._query_rows))
|
|
1826
|
+
except StopIteration:
|
|
1827
|
+
break
|
|
1828
|
+
return results
|
|
1829
|
+
|
|
1830
|
+
def fetchone(self):
|
|
1831
|
+
if self._closed:
|
|
1832
|
+
raise Error("Cursor is closed.")
|
|
1833
|
+
try:
|
|
1834
|
+
return next(self._query_rows)
|
|
1835
|
+
except StopIteration:
|
|
1836
|
+
return None
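# Illustrative sketch of the DB-API fetch methods (hypothetical kind):
#   cursor.gql_query("SELECT * FROM task")
#   first = cursor.fetchone()      # a single row tuple, or None when exhausted
#   batch = cursor.fetchmany(10)   # up to 10 further rows
#   rest = cursor.fetchall()       # whatever remains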
|
|
1837
|
+
|
|
1838
|
+
def _parse_select_columns(self, statement: str) -> Optional[List[str]]:
|
|
1839
|
+
"""
|
|
1840
|
+
Parse SELECT statement to extract column names.
|
|
1841
|
+
Returns None for SELECT * (all columns)
|
|
1842
|
+
"""
|
|
1843
|
+
try:
|
|
1844
|
+
# Use sqlglot to parse the statement
|
|
1845
|
+
parsed = parse_one(statement)
|
|
1846
|
+
if not isinstance(parsed, exp.Select):
|
|
1847
|
+
return None
|
|
1848
|
+
|
|
1849
|
+
columns = []
|
|
1850
|
+
for expr in parsed.expressions:
|
|
1851
|
+
if isinstance(expr, exp.Star):
|
|
1852
|
+
# SELECT * - return None to indicate all columns
|
|
1853
|
+
return None
|
|
1854
|
+
elif isinstance(expr, exp.Column):
|
|
1855
|
+
# Direct column reference
|
|
1856
|
+
col_name = expr.name
|
|
1857
|
+
# Map 'id' to '__key__' since Datastore uses keys, not id properties
|
|
1858
|
+
if col_name.lower() == "id":
|
|
1859
|
+
col_name = "__key__"
|
|
1860
|
+
columns.append(col_name)
|
|
1861
|
+
elif isinstance(expr, exp.Alias):
|
|
1862
|
+
# Column with alias
|
|
1863
|
+
if isinstance(expr.this, exp.Column):
|
|
1864
|
+
col_name = expr.this.name
|
|
1865
|
+
columns.append(col_name)
|
|
1866
|
+
else:
|
|
1867
|
+
# For complex expressions, use the alias
|
|
1868
|
+
columns.append(expr.alias)
|
|
1869
|
+
else:
|
|
1870
|
+
# For other expressions, try to get the name or use the string representation
|
|
1871
|
+
col_name = expr.alias_or_name
|
|
1872
|
+
if col_name:
|
|
1873
|
+
columns.append(col_name)
|
|
1874
|
+
|
|
1875
|
+
return columns if columns else None
|
|
1876
|
+
except Exception:
|
|
1877
|
+
# If parsing fails, return None to get all columns
|
|
1878
|
+
return None
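# Illustrative sketch (hypothetical kind and property names):
#   _parse_select_columns("SELECT id, name FROM user")  # -> ["__key__", "name"]
#   _parse_select_columns("SELECT * FROM user")         # -> None (all columns)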
|
|
1879
|
+
|
|
1880
|
+
def _convert_sql_to_gql(self, statement: str) -> str:
|
|
1881
|
+
"""
|
|
1882
|
+
Convert SQL statements to GQL-compatible format.
|
|
1883
|
+
|
|
1884
|
+
GQL (Google Query Language) is similar to SQL but has its own syntax.
|
|
1885
|
+
This method reverses transformations applied by Superset's sqlglot
|
|
1886
|
+
processing (BigQuery dialect) and makes other adjustments for GQL
|
|
1887
|
+
compatibility.
|
|
1888
|
+
"""
|
|
1889
|
+
# AGGREGATE queries are valid GQL - pass through directly
|
|
1890
|
+
if statement.strip().upper().startswith("AGGREGATE"):
|
|
1891
|
+
return statement
|
|
1892
|
+
|
|
1893
|
+
# Normalize whitespace: sqlglot pretty-prints with newlines which
|
|
1894
|
+
# breaks position-based string operations (find, regex).
|
|
1895
|
+
statement = re.sub(r"\s+", " ", statement).strip()
|
|
1896
|
+
|
|
1897
|
+
# === Reverse sqlglot / BigQuery dialect transformations ===
|
|
1898
|
+
|
|
1899
|
+
# 1. Convert <> back to != (sqlglot BigQuery dialect converts != to <>)
|
|
1900
|
+
# GQL uses != for not-equals comparisons.
|
|
1901
|
+
statement = re.sub(r"<>", "!=", statement)
|
|
1902
|
+
|
|
1903
|
+
# 2. Fix NOT ... IN -> ... NOT IN
|
|
1904
|
+
# sqlglot converts "col NOT IN (...)" to "NOT col IN (...)"
|
|
1905
|
+
# GQL expects "col NOT IN (...)"
|
|
1906
|
+
statement = re.sub(
|
|
1907
|
+
r"\bNOT\s+(\w+)\s+IN\s*\(",
|
|
1908
|
+
r"\1 NOT IN (",
|
|
1909
|
+
statement,
|
|
1910
|
+
flags=re.IGNORECASE,
|
|
1911
|
+
)
|
|
1912
|
+
|
|
1913
|
+
# 3. Strip ROW_NUMBER() OVER (...) added by sqlglot for DISTINCT ON
|
|
1914
|
+
# BigQuery dialect converts "SELECT DISTINCT ON (col) * FROM t"
|
|
1915
|
+
# to "SELECT *, ROW_NUMBER() OVER (PARTITION BY col ...) AS _row_... FROM t"
|
|
1916
|
+
# We strip the ROW_NUMBER expression and any trailing WHERE _row_... = 1
|
|
1917
|
+
statement = re.sub(
|
|
1918
|
+
r",\s*ROW_NUMBER\s*\(\s*\)\s*OVER\s*\([^)]*\)\s*(?:AS\s+\w+)?",
|
|
1919
|
+
"",
|
|
1920
|
+
statement,
|
|
1921
|
+
flags=re.IGNORECASE,
|
|
1922
|
+
)
|
|
1923
|
+
# Also remove the WHERE _row_number = 1 subquery wrapper if present
|
|
1924
|
+
statement = re.sub(
|
|
1925
|
+
r"\bWHERE\s+_row_\w+\s*=\s*1\b",
|
|
1926
|
+
"",
|
|
1927
|
+
statement,
|
|
1928
|
+
flags=re.IGNORECASE,
|
|
1929
|
+
)
|
|
1930
|
+
|
|
1931
|
+
# 4. Fix IN clause syntax for GQL
|
|
1932
|
+
# a) Convert square bracket arrays: IN ['val'] -> IN ARRAY('val')
|
|
1933
|
+
# b) Convert parenthesized lists: IN ('val1', 'val2') -> IN ARRAY('val1', 'val2')
|
|
1934
|
+
# GQL requires the ARRAY keyword: "name IN ARRAY('val1', 'val2')"
|
|
1935
|
+
# NOT IN also needs: "name NOT IN ARRAY('val1', 'val2')"
|
|
1936
|
+
statement = re.sub(
|
|
1937
|
+
r"\bIN\s*\[([^\]]*)\]",
|
|
1938
|
+
r"IN ARRAY(\1)",
|
|
1939
|
+
statement,
|
|
1940
|
+
flags=re.IGNORECASE,
|
|
1941
|
+
)
|
|
1942
|
+
# Convert IN (...) to IN ARRAY(...) but don't double-convert IN ARRAY(...)
|
|
1943
|
+
statement = re.sub(
|
|
1944
|
+
r"\bIN\s*\((?![\s]*SELECT\b)",
|
|
1945
|
+
"IN ARRAY(",
|
|
1946
|
+
statement,
|
|
1947
|
+
flags=re.IGNORECASE,
|
|
1948
|
+
)
|
|
1949
|
+
# Fix double ARRAY: if original was already ARRAY, we'd get IN ARRAY(ARRAY(...)
|
|
1950
|
+
statement = re.sub(
|
|
1951
|
+
r"\bARRAY\s*\(\s*ARRAY\s*\(",
|
|
1952
|
+
"ARRAY(",
|
|
1953
|
+
statement,
|
|
1954
|
+
flags=re.IGNORECASE,
|
|
1955
|
+
)
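# Examples of the rewrites above:
#   "name IN ('a', 'b')"  -> "name IN ARRAY('a', 'b')"
#   "name NOT IN ['x']"   -> "name NOT IN ARRAY('x')"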
|
|
1956
|
+
|
|
1957
|
+
# 5. Fix WHERE NULL (from sqlglot optimizing "col = NULL" to "NULL")
|
|
1958
|
+
# sqlglot treats "col = NULL" as always-false and collapses to NULL.
|
|
1959
|
+
# We can't recover the original column, but if the WHERE clause is
|
|
1960
|
+
# just "WHERE NULL", remove it since it would return no results.
|
|
1961
|
+
statement = re.sub(
|
|
1962
|
+
r"\bWHERE\s+NULL\b",
|
|
1963
|
+
"",
|
|
1964
|
+
statement,
|
|
1965
|
+
flags=re.IGNORECASE,
|
|
1966
|
+
)
|
|
1967
|
+
|
|
1968
|
+
# === GQL-specific transformations ===
|
|
1969
|
+
|
|
1970
|
+
# Handle LIMIT FIRST(offset, count) syntax
|
|
1971
|
+
# Convert to LIMIT <count> OFFSET <offset>
|
|
1972
|
+
first_match = re.search(
|
|
1973
|
+
r"LIMIT\s+FIRST\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)",
|
|
1974
|
+
statement,
|
|
1975
|
+
flags=re.IGNORECASE,
|
|
1976
|
+
)
|
|
1977
|
+
if first_match:
|
|
1978
|
+
offset = first_match.group(1)
|
|
1979
|
+
count = first_match.group(2)
|
|
1980
|
+
statement = re.sub(
|
|
1981
|
+
r"LIMIT\s+FIRST\s*\(\s*\d+\s*,\s*\d+\s*\)",
|
|
1982
|
+
f"LIMIT {count} OFFSET {offset}",
|
|
1983
|
+
statement,
|
|
1984
|
+
flags=re.IGNORECASE,
|
|
1985
|
+
)
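# Example of the rewrite above: "LIMIT FIRST(10, 5)" -> "LIMIT 5 OFFSET 10"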
|
|
1986
|
+
|
|
1987
|
+
# Extract table name from FROM clause for KEY() conversion
|
|
1988
|
+
table_match = re.search(
|
|
1989
|
+
r"\bFROM\s+(\w+)", statement, flags=re.IGNORECASE
|
|
1990
|
+
)
|
|
1991
|
+
table_name = table_match.group(1) if table_match else None
|
|
1992
|
+
|
|
1993
|
+
# Remove DISTINCT ON (...) syntax - not supported by GQL.
|
|
1994
|
+
# GQL supports DISTINCT but not DISTINCT ON.
|
|
1995
|
+
statement = re.sub(
|
|
1996
|
+
r"\bDISTINCT\s+ON\s*\([^)]*\)\s*",
|
|
1997
|
+
"",
|
|
1998
|
+
statement,
|
|
1999
|
+
flags=re.IGNORECASE,
|
|
2000
|
+
)
|
|
2001
|
+
|
|
2002
|
+
# Convert table.id in SELECT clause to __key__
|
|
2003
|
+
if table_name:
|
|
2004
|
+
statement = re.sub(
|
|
2005
|
+
rf"\b{table_name}\.id\b",
|
|
2006
|
+
"__key__",
|
|
2007
|
+
statement,
|
|
2008
|
+
flags=re.IGNORECASE,
|
|
2009
|
+
)
|
|
2010
|
+
|
|
2011
|
+
# Handle bare 'id' references for GQL compatibility
|
|
2012
|
+
upper_stmt = statement.upper()
|
|
2013
|
+
from_pos = upper_stmt.find(" FROM ")
|
|
2014
|
+
if from_pos > 0:
|
|
2015
|
+
select_clause = statement[:from_pos]
|
|
2016
|
+
from_and_rest = statement[from_pos:]
|
|
2017
|
+
|
|
2018
|
+
# Parse SELECT columns and remove id/__key__ from projection
|
|
2019
|
+
select_match = re.match(
|
|
2020
|
+
r"(SELECT\s+(?:DISTINCT\s+(?:ON\s*\([^)]*\)\s*)?)?)(.*)",
|
|
2021
|
+
select_clause,
|
|
2022
|
+
flags=re.IGNORECASE,
|
|
2023
|
+
)
|
|
2024
|
+
if select_match:
|
|
2025
|
+
prefix = select_match.group(1)
|
|
2026
|
+
cols_str = select_match.group(2)
|
|
2027
|
+
cols = [c.strip() for c in cols_str.split(",")]
|
|
2028
|
+
non_key_cols = [
|
|
2029
|
+
c
|
|
2030
|
+
for c in cols
|
|
2031
|
+
if not re.match(
|
|
2032
|
+
r"^(id|__key__)$", c.strip(), flags=re.IGNORECASE
|
|
2033
|
+
)
|
|
2034
|
+
]
|
|
2035
|
+
|
|
2036
|
+
if not non_key_cols:
|
|
2037
|
+
select_clause = prefix + "__key__"
|
|
2038
|
+
elif len(non_key_cols) < len(cols):
|
|
2039
|
+
select_clause = prefix + ", ".join(non_key_cols)
|
|
2040
|
+
|
|
2041
|
+
# Convert 'id' to '__key__' in WHERE/ORDER BY/etc.
|
|
2042
|
+
from_and_rest = re.sub(
|
|
2043
|
+
r"\bid\b", "__key__", from_and_rest, flags=re.IGNORECASE
|
|
2044
|
+
)
|
|
2045
|
+
|
|
2046
|
+
statement = select_clause + from_and_rest
|
|
2047
|
+
else:
|
|
2048
|
+
statement = re.sub(
|
|
2049
|
+
r"\bid\b", "__key__", statement, flags=re.IGNORECASE
|
|
2050
|
+
)
|
|
2051
|
+
|
|
2052
|
+
# Datastore restriction: projection queries with WHERE clauses require
|
|
2053
|
+
# composite indexes. Convert to SELECT * to avoid this requirement and
|
|
2054
|
+
# let ParseEntity handle column filtering from the full entity response.
|
|
2055
|
+
upper_check = statement.upper()
|
|
2056
|
+
from_check_pos = upper_check.find(" FROM ")
|
|
2057
|
+
where_check_pos = upper_check.find(" WHERE ")
|
|
2058
|
+
if from_check_pos > 0 and where_check_pos > from_check_pos:
|
|
2059
|
+
select_cols_str = re.sub(
|
|
2060
|
+
r"^SELECT\s+", "", statement[:from_check_pos], flags=re.IGNORECASE
|
|
2061
|
+
).strip()
|
|
2062
|
+
if (
|
|
2063
|
+
select_cols_str != "*"
|
|
2064
|
+
and select_cols_str.upper() != "__KEY__"
|
|
2065
|
+
and not select_cols_str.upper().startswith("DISTINCT")
|
|
2066
|
+
):
|
|
2067
|
+
statement = "SELECT * " + statement[from_check_pos + 1:]
|
|
2068
|
+
|
|
2069
|
+
# Handle id = <number> in WHERE clauses -> KEY() syntax
|
|
2070
|
+
if table_name:
|
|
2071
|
+
id_where_match = re.search(
|
|
2072
|
+
r"\bWHERE\b.*\b(?:id|__key__)\s*=\s*(\d+)",
|
|
2073
|
+
statement,
|
|
2074
|
+
flags=re.IGNORECASE,
|
|
2075
|
+
)
|
|
2076
|
+
if id_where_match:
|
|
2077
|
+
id_value = id_where_match.group(1)
|
|
2078
|
+
statement = re.sub(
|
|
2079
|
+
r"\b(?:id|__key__)\s*=\s*\d+",
|
|
2080
|
+
f"__key__ = KEY({table_name}, {id_value})",
|
|
2081
|
+
statement,
|
|
2082
|
+
flags=re.IGNORECASE,
|
|
2083
|
+
)
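# Example of the rewrite above (hypothetical kind "user"):
#   "... FROM user WHERE id = 123" ends up as "... FROM user WHERE __key__ = KEY(user, 123)"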
|
|
2084
|
+
|
|
2085
|
+
# Remove column aliases (AS alias_name) - GQL doesn't support them
|
|
2086
|
+
# (AGGREGATE ... statements were already returned unchanged above, so their AS clauses are safe)
|
|
2087
|
+
statement = re.sub(
|
|
2088
|
+
r"\bAS\s+\w+", "", statement, flags=re.IGNORECASE
|
|
2089
|
+
)
|
|
2090
|
+
|
|
2091
|
+
# Remove table prefix from column names (table.column -> column)
|
|
2092
|
+
if table_name:
|
|
2093
|
+
statement = re.sub(
|
|
2094
|
+
rf"\b{table_name}\.(?!__)", "", statement, flags=re.IGNORECASE
|
|
2095
|
+
)
|
|
2096
|
+
|
|
2097
|
+
# Clean up extra spaces and artifacts
|
|
2098
|
+
statement = re.sub(r"\s+", " ", statement).strip()
|
|
2099
|
+
statement = re.sub(r",\s*,", ",", statement)
|
|
2100
|
+
statement = re.sub(r"\s*,\s*\bFROM\b", " FROM", statement)
|
|
2101
|
+
|
|
2102
|
+
return statement
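# Illustrative end-to-end conversion (hypothetical kind and property names):
#   _convert_sql_to_gql(
#       "SELECT user.name AS n FROM user WHERE user.age <> 30 LIMIT FIRST(10, 5)"
#   )
#   # -> "SELECT * FROM user WHERE age != 30 LIMIT 5 OFFSET 10"
#   # The projection collapses to SELECT * because projection queries with a
#   # WHERE clause would otherwise require a composite index.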
|
|
2103
|
+
|
|
2104
|
+
def close(self):
|
|
2105
|
+
self._closed = True
|
|
2106
|
+
self.connection = None
|
|
2107
|
+
logging.debug("Cursor is closed.")
|
|
2108
|
+
|
|
2109
|
+
|
|
2110
|
+
class Connection:
|
|
2111
|
+
def __init__(self, client=None):
|
|
2112
|
+
self._client = client
|
|
2113
|
+
self._transaction = None
|
|
2114
|
+
|
|
2115
|
+
def cursor(self):
|
|
2116
|
+
return Cursor(self)
|
|
2117
|
+
|
|
2118
|
+
def begin(self):
|
|
2119
|
+
logging.debug("datastore connection transaction begin")
|
|
2120
|
+
|
|
2121
|
+
def commit(self):
|
|
2122
|
+
logging.debug("datastore connection commit")
|
|
2123
|
+
|
|
2124
|
+
def rollback(self):
|
|
2125
|
+
logging.debug("datastore connection rollback")
|
|
2126
|
+
|
|
2127
|
+
def close(self):
|
|
2128
|
+
logging.debug("Closing connection")
|
|
2129
|
+
|
|
2130
|
+
|
|
2131
|
+
def connect(client=None):
|
|
2132
|
+
return Connection(client)
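# Illustrative DB-API usage (hypothetical project id and kind):
#   from google.cloud import datastore
#   conn = connect(datastore.Client(project="my-project"))
#   cur = conn.cursor()
#   cur.gql_query("SELECT * FROM task")
#   print(cur.fetchall())
#   conn.close()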
|
|
2133
|
+
|
|
2134
|
+
|
|
2135
|
+
class ParseEntity:
|
|
2136
|
+
@classmethod
|
|
2137
|
+
def parse(cls, data: List[dict], selected_columns: Optional[List[str]] = None):
|
|
2138
|
+
"""
|
|
2139
|
+
Parse Datastore entity results into DB-API rows and field descriptors.
|
|
2140
|
+
|
|
2141
|
+
data: list of JSON entity results as returned by the Datastore runQuery API.
|
|
2142
|
+
selected_columns: List of column names to include in results. If None, include all.
|
|
2143
|
+
"""
|
|
2144
|
+
all_property_names_set = set()
|
|
2145
|
+
for entity_data in data:
|
|
2146
|
+
properties = entity_data.get("entity", {}).get("properties", {})
|
|
2147
|
+
all_property_names_set.update(properties.keys())
|
|
2148
|
+
|
|
2149
|
+
# Determine which columns to include
|
|
2150
|
+
if selected_columns is None:
|
|
2151
|
+
# Include all properties if no specific selection
|
|
2152
|
+
sorted_property_names = sorted(list(all_property_names_set))
|
|
2153
|
+
include_key = True
|
|
2154
|
+
else:
|
|
2155
|
+
# Only include selected columns
|
|
2156
|
+
sorted_property_names = []
|
|
2157
|
+
include_key = False
|
|
2158
|
+
for col in selected_columns:
|
|
2159
|
+
if col.lower() == "__key__" or col.lower() == "key":
|
|
2160
|
+
include_key = True
|
|
2161
|
+
elif col in all_property_names_set:
|
|
2162
|
+
sorted_property_names.append(col)
|
|
2163
|
+
|
|
2164
|
+
final_fields: dict = {}
|
|
2165
|
+
final_rows: List[Tuple] = []
|
|
2166
|
+
|
|
2167
|
+
# Add key field if requested
|
|
2168
|
+
if include_key:
|
|
2169
|
+
final_fields["key"] = ("key", None, None, None, None, None, None)
|
|
2170
|
+
|
|
2171
|
+
# Add selected fields in the order they appear in selected_columns if provided
|
|
2172
|
+
if selected_columns:
|
|
2173
|
+
# Keep the order from selected_columns
|
|
2174
|
+
for prop_name in selected_columns:
|
|
2175
|
+
if (
|
|
2176
|
+
prop_name.lower() != "__key__"
|
|
2177
|
+
and prop_name.lower() != "key"
|
|
2178
|
+
and prop_name in all_property_names_set
|
|
2179
|
+
):
|
|
2180
|
+
final_fields[prop_name] = (
|
|
2181
|
+
prop_name,
|
|
2182
|
+
None,
|
|
2183
|
+
None,
|
|
2184
|
+
None,
|
|
2185
|
+
None,
|
|
2186
|
+
None,
|
|
2187
|
+
None,
|
|
2188
|
+
)
|
|
2189
|
+
else:
|
|
2190
|
+
# Add all fields sorted by name
|
|
2191
|
+
for prop_name in sorted_property_names:
|
|
2192
|
+
final_fields[prop_name] = (
|
|
2193
|
+
prop_name,
|
|
2194
|
+
None,
|
|
2195
|
+
None,
|
|
2196
|
+
None,
|
|
2197
|
+
None,
|
|
2198
|
+
None,
|
|
2199
|
+
None,
|
|
2200
|
+
)
|
|
2201
|
+
|
|
2202
|
+
# Append the properties
|
|
2203
|
+
for entity_data in data:
|
|
2204
|
+
row_values: List[Any] = []
|
|
2205
|
+
properties = entity_data.get("entity", {}).get("properties", {})
|
|
2206
|
+
key = entity_data.get("entity", {}).get("key", {})
|
|
2207
|
+
|
|
2208
|
+
# Add key value if requested
|
|
2209
|
+
if include_key:
|
|
2210
|
+
row_values.append(key.get("path", []))
|
|
2211
|
+
|
|
2212
|
+
# Append selected properties in the correct order
|
|
2213
|
+
if selected_columns:
|
|
2214
|
+
for prop_name in selected_columns:
|
|
2215
|
+
if prop_name.lower() == "__key__" or prop_name.lower() == "key":
|
|
2216
|
+
continue # already added above
|
|
2217
|
+
if prop_name in all_property_names_set:
|
|
2218
|
+
prop_v = properties.get(prop_name)
|
|
2219
|
+
if prop_v is not None:
|
|
2220
|
+
prop_value, prop_type = ParseEntity.parse_properties(
|
|
2221
|
+
prop_name, prop_v
|
|
2222
|
+
)
|
|
2223
|
+
row_values.append(prop_value)
|
|
2224
|
+
current_field_info = final_fields[prop_name]
|
|
2225
|
+
if (
|
|
2226
|
+
current_field_info[1] is None
|
|
2227
|
+
or current_field_info[1] == "UNKNOWN"
|
|
2228
|
+
):
|
|
2229
|
+
final_fields[prop_name] = (
|
|
2230
|
+
prop_name,
|
|
2231
|
+
prop_type,
|
|
2232
|
+
current_field_info[2],
|
|
2233
|
+
current_field_info[3],
|
|
2234
|
+
current_field_info[4],
|
|
2235
|
+
current_field_info[5],
|
|
2236
|
+
current_field_info[6],
|
|
2237
|
+
)
|
|
2238
|
+
else:
|
|
2239
|
+
row_values.append(None)
|
|
2240
|
+
else:
|
|
2241
|
+
# Append all properties in sorted order
|
|
2242
|
+
for prop_name in sorted_property_names:
|
|
2243
|
+
prop_v = properties.get(prop_name)
|
|
2244
|
+
if prop_v is not None:
|
|
2245
|
+
prop_value, prop_type = ParseEntity.parse_properties(
|
|
2246
|
+
prop_name, prop_v
|
|
2247
|
+
)
|
|
2248
|
+
row_values.append(prop_value)
|
|
2249
|
+
current_field_info = final_fields[prop_name]
|
|
2250
|
+
if (
|
|
2251
|
+
current_field_info[1] is None
|
|
2252
|
+
or current_field_info[1] == "UNKNOWN"
|
|
2253
|
+
):
|
|
2254
|
+
final_fields[prop_name] = (
|
|
2255
|
+
prop_name,
|
|
2256
|
+
prop_type,
|
|
2257
|
+
current_field_info[2],
|
|
2258
|
+
current_field_info[3],
|
|
2259
|
+
current_field_info[4],
|
|
2260
|
+
current_field_info[5],
|
|
2261
|
+
current_field_info[6],
|
|
2262
|
+
)
|
|
2263
|
+
else:
|
|
2264
|
+
row_values.append(None)
|
|
2265
|
+
|
|
2266
|
+
final_rows.append(tuple(row_values))
|
|
2267
|
+
|
|
2268
|
+
return final_rows, final_fields
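# Illustrative sketch with a single runQuery-style entity result (hypothetical
# kind and property names):
#   data = [{"entity": {"key": {"path": [{"kind": "task", "id": "1"}]},
#                       "properties": {"done": {"booleanValue": False}}}}]
#   rows, fields = ParseEntity.parse(data)
#   # rows   -> [([{"kind": "task", "id": "1"}], False)]
#   # fields -> {"key": ("key", ...), "done": ("done", BOOL, ...)}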
|
|
2269
|
+
|
|
2270
|
+
@classmethod
|
|
2271
|
+
def parse_properties(cls, prop_k: str, prop_v: dict):
|
|
2272
|
+
value_type = next(iter(prop_v), None)
|
|
2273
|
+
prop_type = None
|
|
2274
|
+
prop_value: Any = None
|
|
2275
|
+
|
|
2276
|
+
if value_type == "nullValue" or "nullValue" in prop_v:
|
|
2277
|
+
prop_value = None
|
|
2278
|
+
prop_type = _types.NULL_TYPE
|
|
2279
|
+
elif value_type == "booleanValue" or "booleanValue" in prop_v:
|
|
2280
|
+
prop_value = bool(prop_v["booleanValue"])
|
|
2281
|
+
prop_type = _types.BOOL
|
|
2282
|
+
elif value_type == "integerValue" or "integerValue" in prop_v:
|
|
2283
|
+
prop_value = int(prop_v["integerValue"])
|
|
2284
|
+
prop_type = _types.INTEGER
|
|
2285
|
+
elif value_type == "doubleValue" or "doubleValue" in prop_v:
|
|
2286
|
+
prop_value = float(prop_v["doubleValue"])
|
|
2287
|
+
prop_type = _types.FLOAT64
|
|
2288
|
+
elif value_type == "stringValue" or "stringValue" in prop_v:
|
|
2289
|
+
prop_value = prop_v["stringValue"]
|
|
2290
|
+
prop_type = _types.STRING
|
|
2291
|
+
elif value_type == "timestampValue" or "timestampValue" in prop_v:
|
|
2292
|
+
timestamp_str = prop_v["timestampValue"]
|
|
2293
|
+
if timestamp_str.endswith("Z"):
|
|
2294
|
+
# Handle ISO 8601 with Z suffix (UTC)
|
|
2295
|
+
prop_value = datetime.fromisoformat(
|
|
2296
|
+
timestamp_str.replace("Z", "+00:00")
|
|
2297
|
+
)
|
|
2298
|
+
else:
|
|
2299
|
+
prop_value = datetime.fromisoformat(timestamp_str)
|
|
2300
|
+
prop_type = _types.TIMESTAMP
|
|
2301
|
+
elif value_type == "blobValue" or "blobValue" in prop_v:
|
|
2302
|
+
prop_value = base64.b64decode(prop_v.get("blobValue", b""))
|
|
2303
|
+
prop_type = _types.BYTES
|
|
2304
|
+
elif value_type == "geoPointValue" or "geoPointValue" in prop_v:
|
|
2305
|
+
prop_value = prop_v["geoPointValue"]
|
|
2306
|
+
prop_type = _types.GEOPOINT
|
|
2307
|
+
elif value_type == "keyValue" or "keyValue" in prop_v:
|
|
2308
|
+
prop_value = prop_v["keyValue"]["path"]
|
|
2309
|
+
prop_type = _types.KEY_TYPE
|
|
2310
|
+
elif value_type == "arrayValue" or "arrayValue" in prop_v:
|
|
2311
|
+
prop_value = []
|
|
2312
|
+
for entity in prop_v["arrayValue"].get("values", []):
|
|
2313
|
+
e_v, _ = ParseEntity.parse_properties(prop_k, entity)
|
|
2314
|
+
prop_value.append(e_v)
|
|
2315
|
+
prop_type = _types.ARRAY
|
|
2316
|
+
elif value_type == "dictValue" or "dictValue" in prop_v:
|
|
2317
|
+
prop_value = prop_v["dictValue"]
|
|
2318
|
+
prop_type = _types.STRUCT_FIELD_TYPES
|
|
2319
|
+
elif value_type == "entityValue" or "entityValue" in prop_v:
|
|
2320
|
+
prop_value = prop_v["entityValue"].get("properties") or {}
|
|
2321
|
+
prop_type = _types.STRUCT_FIELD_TYPES
|
|
2322
|
+
return prop_value, prop_type
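# Illustrative sketch (hypothetical property names):
#   parse_properties("age", {"integerValue": "42"})
#   # -> (42, _types.INTEGER)
#   parse_properties("created", {"timestampValue": "2024-01-01T00:00:00Z"})
#   # -> (datetime(2024, 1, 1, tzinfo=timezone.utc), _types.TIMESTAMP)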
|