python-datastore-sqlalchemy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2322 @@
+ # Copyright (c) 2025 hychang <hychang.1997.tw@gmail.com>
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ import base64
+ import collections
+ import logging
+ import os
+ import re
+ from datetime import datetime
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import pandas as pd
+ import requests
+ from google.auth.transport.requests import AuthorizedSession
+ from google.cloud import datastore
+ from google.cloud.datastore.helpers import GeoPoint
+ from google.oauth2 import service_account
+ from requests import Response
+ from sqlalchemy import types
+ from sqlglot import exp, parse_one, tokenize, tokens
+ from sqlglot.tokens import TokenType
+
+ from . import _types
+
+ logger = logging.getLogger("sqlalchemy.dialects.datastore_dbapi")
+
+ apilevel = "2.0"
+ threadsafety = 2
+ paramstyle = "named"
+
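+ # Minimal DB-API usage sketch (illustrative; connection construction is
+ # handled elsewhere in this package, so the object below is hypothetical --
+ # anything exposing cursor() works):
+ #
+ #     cur = connection.cursor()
+ #     cur.execute("SELECT * FROM User WHERE age >= :min_age", {"min_age": 21})
+ #     rows = cur.fetchall()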
+
+ # Required exceptions
+ class Warning(Exception):
+     """Exception raised for important DB-API warnings."""
+
+
+ class Error(Exception):
+     """Exception representing all non-warning DB-API errors."""
+
+
+ class InterfaceError(Error):
+     """DB-API error related to the database interface."""
+
+
+ class DatabaseError(Error):
+     """DB-API error related to the database."""
+
+
+ class DataError(DatabaseError):
+     """DB-API error due to problems with the processed data."""
+
+
+ class OperationalError(DatabaseError):
+     """DB-API error related to the database operation."""
+
+
+ class IntegrityError(DatabaseError):
+     """DB-API error when integrity of the database is affected."""
+
+
+ class InternalError(DatabaseError):
+     """DB-API error when the database encounters an internal error."""
+
+
+ class ProgrammingError(DatabaseError):
+     """DB-API exception raised for programming errors."""
+
+
+ Column = collections.namedtuple(
+     "Column",
+     [
+         "name",
+         "type_code",
+         "display_size",
+         "internal_size",
+         "precision",
+         "scale",
+         "null_ok",
+     ],
+ )
+
+ type_map = {
+     str: types.String,
+     int: types.NUMERIC,
+     float: types.FLOAT,
+     bool: types.BOOLEAN,
+     bytes: types.BINARY,
+     datetime: types.DATETIME,
+     datastore.Key: types.JSON,
+     GeoPoint: types.JSON,
+     list: types.JSON,
+     dict: types.JSON,
+     None.__class__: types.String,
+ }
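+
+ # type_map is presumably consumed by the dialect layer (outside this excerpt)
+ # to map Python property types onto SQLAlchemy column types. Illustrative
+ # lookups:
+ #     type_map[type(42)]    # -> types.NUMERIC
+ #     type_map[type(None)]  # -> types.String (missing/NULL property)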
+
+
+ class Cursor:
+     def __init__(self, connection):
+         self.connection = connection
+         self._datastore_client = connection._client
+         self.rowcount = -1
+         self.arraysize = 1  # DB-API 2.0 default
+         self._query_data = None
+         self._query_rows = None
+         self._closed = False
+         self.description = None
+         self.lastrowid = None
+         self.warnings: list[str] = []
+
+     def execute(self, statements, parameters=None):
+         """Execute a Datastore operation."""
+         if self._closed:
+             raise Error("Cursor is closed.")
+
+         # Check for DML statements
+         upper_statement = statements.upper().strip()
+         if upper_statement.startswith("INSERT"):
+             self._execute_insert(statements, parameters)
+             return
+         if upper_statement.startswith("UPDATE"):
+             self._execute_update(statements, parameters)
+             return
+         if upper_statement.startswith("DELETE"):
+             self._execute_delete(statements, parameters)
+             return
+
+         sql_tokens = tokenize(statements)
+         if self._is_derived_query(sql_tokens):
+             self.execute_orm(statements, parameters, sql_tokens)
+         else:
+             self.gql_query(statements, parameters)
+
+     def _execute_insert(self, statement: str, parameters=None):
+         """Execute an INSERT statement using Datastore client."""
+         if parameters is None:
+             parameters = {}
+
+         logger.debug(f"Executing INSERT: {statement} with parameters: {parameters}")
+
+         try:
+             # Parse INSERT statement using sqlglot
+             parsed = parse_one(statement)
+             if not isinstance(parsed, exp.Insert):
+                 raise ProgrammingError(f"Expected INSERT statement, got: {type(parsed)}")
+
+             # Get table/kind name
+             # For INSERT, parsed.this is a Schema containing the table and columns
+             schema_expr = parsed.this
+             if isinstance(schema_expr, exp.Schema):
+                 # Schema has 'this' which is the table
+                 table_expr = schema_expr.this
+                 if isinstance(table_expr, exp.Table):
+                     kind = table_expr.name
+                 else:
+                     kind = str(table_expr)
+             elif isinstance(schema_expr, exp.Table):
+                 kind = schema_expr.name
+             else:
+                 raise ProgrammingError("Could not determine table name from INSERT")
+
+             # Get column names from Schema's expressions
+             columns = []
+             if isinstance(schema_expr, exp.Schema) and schema_expr.expressions:
+                 for col in schema_expr.expressions:
+                     if hasattr(col, "name"):
+                         columns.append(col.name)
+                     else:
+                         columns.append(str(col))
+
+             # Get values
+             values_list = []
+             values_expr = parsed.args.get("expression")
+             if values_expr and hasattr(values_expr, "expressions"):
+                 for tuple_expr in values_expr.expressions:
+                     if hasattr(tuple_expr, "expressions"):
+                         row_values = []
+                         for val in tuple_expr.expressions:
+                             row_values.append(self._parse_insert_value(val, parameters))
+                         values_list.append(row_values)
+             elif values_expr:
+                 # Single row VALUES clause
+                 row_values = []
+                 if hasattr(values_expr, "expressions"):
+                     for val in values_expr.expressions:
+                         row_values.append(self._parse_insert_value(val, parameters))
+                 values_list.append(row_values)
+
+             # Create entities and insert them
+             entities_created = 0
+             for row_values in values_list:
+                 # Create entity key (auto-generated)
+                 key = self._datastore_client.key(kind)
+                 entity = datastore.Entity(key=key)
+
+                 # Set entity properties
+                 for i, col in enumerate(columns):
+                     if i < len(row_values):
+                         entity[col] = row_values[i]
+
+                 # Put entity to datastore
+                 self._datastore_client.put(entity)
+                 entities_created += 1
+                 # Save the last inserted entity's key ID for lastrowid
+                 if entity.key.id is not None:
+                     self.lastrowid = entity.key.id
+                 elif entity.key.name is not None:
+                     # For named keys, use a hash of the name as a numeric ID
+                     self.lastrowid = hash(entity.key.name) & 0x7FFFFFFFFFFFFFFF
+
+             self.rowcount = entities_created
+             self._query_rows = iter([])
+             self.description = None
+
+         except Exception as e:
+             logger.error(f"INSERT failed: {e}")
+             raise ProgrammingError(f"INSERT failed: {e}") from e
+
+     def _execute_update(self, statement: str, parameters=None):
+         """Execute an UPDATE statement using Datastore client."""
+         if parameters is None:
+             parameters = {}
+
+         logger.debug(f"Executing UPDATE: {statement} with parameters: {parameters}")
+
+         try:
+             parsed = parse_one(statement)
+             if not isinstance(parsed, exp.Update):
+                 raise ProgrammingError(f"Expected UPDATE statement, got: {type(parsed)}")
+
+             # Get table/kind name
+             table_expr = parsed.this
+             if isinstance(table_expr, exp.Table):
+                 kind = table_expr.name
+             else:
+                 raise ProgrammingError("Could not determine table name from UPDATE")
+
+             # Get the WHERE clause to find the entity key
+             where = parsed.args.get("where")
+             if not where:
+                 raise ProgrammingError("UPDATE without WHERE clause is not supported")
+
+             # Extract the key ID from WHERE clause (e.g., WHERE id = :id_1)
+             entity_key_id = self._extract_key_id_from_where(where, parameters)
+             if entity_key_id is None:
+                 raise ProgrammingError("Could not extract entity key from WHERE clause")
+
+             # Get the entity
+             key = self._datastore_client.key(kind, entity_key_id)
+             entity = self._datastore_client.get(key)
+             if entity is None:
+                 self.rowcount = 0
+                 self._query_rows = iter([])
+                 self.description = None
+                 return
+
+             # Apply the SET values
+             for set_expr in parsed.args.get("expressions", []):
+                 if isinstance(set_expr, exp.EQ):
+                     col_name = set_expr.left.name if hasattr(set_expr.left, "name") else str(set_expr.left)
+                     value = self._parse_update_value(set_expr.right, parameters)
+                     entity[col_name] = value
+
+             # Save the entity
+             self._datastore_client.put(entity)
+             self.rowcount = 1
+             self._query_rows = iter([])
+             self.description = None
+
+         except Exception as e:
+             logger.error(f"UPDATE failed: {e}")
+             raise ProgrammingError(f"UPDATE failed: {e}") from e
+
+     def _execute_delete(self, statement: str, parameters=None):
+         """Execute a DELETE statement using Datastore client."""
+         if parameters is None:
+             parameters = {}
+
+         logger.debug(f"Executing DELETE: {statement} with parameters: {parameters}")
+
+         try:
+             parsed = parse_one(statement)
+             if not isinstance(parsed, exp.Delete):
+                 raise ProgrammingError(f"Expected DELETE statement, got: {type(parsed)}")
+
+             # Get table/kind name
+             table_expr = parsed.this
+             if isinstance(table_expr, exp.Table):
+                 kind = table_expr.name
+             else:
+                 raise ProgrammingError("Could not determine table name from DELETE")
+
+             # Get the WHERE clause to find the entity key
+             where = parsed.args.get("where")
+             if not where:
+                 raise ProgrammingError("DELETE without WHERE clause is not supported")
+
+             # Extract the key ID from WHERE clause
+             entity_key_id = self._extract_key_id_from_where(where, parameters)
+             if entity_key_id is None:
+                 raise ProgrammingError("Could not extract entity key from WHERE clause")
+
+             # Delete the entity
+             key = self._datastore_client.key(kind, entity_key_id)
+             self._datastore_client.delete(key)
+             self.rowcount = 1
+             self._query_rows = iter([])
+             self.description = None
+
+         except Exception as e:
+             logger.error(f"DELETE failed: {e}")
+             raise ProgrammingError(f"DELETE failed: {e}") from e
+
+     def _extract_key_id_from_where(self, where_expr, parameters: dict) -> Optional[int]:
+         """Extract entity key ID from WHERE clause."""
+         # Handle WHERE id = :param or WHERE id = value
+         if isinstance(where_expr, exp.Where):
+             where_expr = where_expr.this
+
+         if isinstance(where_expr, exp.EQ):
+             left = where_expr.left
+             right = where_expr.right
+
+             # Check if left side is 'id'
+             col_name = left.name if hasattr(left, "name") else str(left)
+             if col_name.lower() == "id":
+                 return self._parse_key_value(right, parameters)
+
+         return None
+
+     def _parse_key_value(self, val_expr, parameters: dict) -> Optional[int]:
+         """Parse a value expression to get key ID."""
+         if isinstance(val_expr, exp.Literal):
+             if val_expr.is_number:
+                 return int(val_expr.this)
+         elif isinstance(val_expr, exp.Placeholder):
+             param_name = val_expr.name or val_expr.this
+             if param_name in parameters:
+                 return int(parameters[param_name])
+             if param_name and param_name.startswith(":"):
+                 param_name = param_name[1:]
+                 if param_name in parameters:
+                     return int(parameters[param_name])
+         elif isinstance(val_expr, exp.Parameter):
+             param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
+             if param_name in parameters:
+                 return int(parameters[param_name])
+         return None
+
+     def _parse_update_value(self, val_expr, parameters: dict) -> Any:
+         """Parse a value expression from UPDATE SET clause."""
+         if isinstance(val_expr, exp.Literal):
+             if val_expr.is_string:
+                 return val_expr.this
+             elif val_expr.is_number:
+                 text = val_expr.this
+                 if "." in text:
+                     return float(text)
+                 return int(text)
+             return val_expr.this
+         elif isinstance(val_expr, exp.Null):
+             return None
+         elif isinstance(val_expr, exp.Boolean):
+             return val_expr.this
+         elif isinstance(val_expr, exp.Placeholder):
+             param_name = val_expr.name or val_expr.this
+             if param_name in parameters:
+                 return parameters[param_name]
+             if param_name and param_name.startswith(":"):
+                 param_name = param_name[1:]
+                 if param_name in parameters:
+                     return parameters[param_name]
+             return None
+         elif isinstance(val_expr, exp.Parameter):
+             param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
+             if param_name in parameters:
+                 return parameters[param_name]
+             return None
+         else:
+             return str(val_expr.this) if hasattr(val_expr, "this") else str(val_expr)
+
+     def _parse_insert_value(self, val_expr, parameters: dict) -> Any:
+         """Parse a value expression from INSERT statement."""
+         if isinstance(val_expr, exp.Literal):
+             if val_expr.is_string:
+                 return val_expr.this
+             elif val_expr.is_number:
+                 text = val_expr.this
+                 if "." in text:
+                     return float(text)
+                 return int(text)
+             return val_expr.this
+         elif isinstance(val_expr, exp.Null):
+             return None
+         elif isinstance(val_expr, exp.Boolean):
+             return val_expr.this
+         elif isinstance(val_expr, exp.Placeholder):
+             # Named parameter like :name
+             param_name = val_expr.name or val_expr.this
+             if param_name and param_name in parameters:
+                 return parameters[param_name]
+             # Handle :name format
+             if param_name and param_name.startswith(":"):
+                 param_name = param_name[1:]
+                 if param_name in parameters:
+                     return parameters[param_name]
+             return None
+         elif isinstance(val_expr, exp.Parameter):
+             # Named parameter
+             param_name = val_expr.this.this if hasattr(val_expr.this, "this") else str(val_expr.this)
+             if param_name in parameters:
+                 return parameters[param_name]
+             return None
+         else:
+             # Try to get the string representation
+             return str(val_expr.this) if hasattr(val_expr, "this") else str(val_expr)
+
+     def _is_derived_query(self, sql_tokens: List[tokens.Token]) -> bool:
+         """
+         Checks if the SQL statement contains a derived table (subquery in FROM).
+         This is a more reliable way to distinguish complex ORM queries from simple GQL.
+         """
+         select_seen = 0
+         for token in sql_tokens:
+             if token.token_type == TokenType.SELECT:
+                 select_seen += 1
+                 if select_seen >= 2:
+                     return True
+         return False
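+
+     # Illustrative behavior (hypothetical statements):
+     #     "SELECT x FROM (SELECT a AS x FROM Kind)" -> 2 SELECT tokens -> True
+     #     "SELECT a FROM Kind"                      -> 1 SELECT token  -> False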
+
+     def _is_aggregation_query(self, statement: str) -> bool:
+         """Check if the statement contains aggregation functions."""
+         upper = statement.upper()
+         # Check for AGGREGATE ... OVER syntax
+         if upper.strip().startswith("AGGREGATE"):
+             return True
+         # Check for aggregation functions in SELECT
+         agg_patterns = [
+             r"\bCOUNT\s*\(",
+             r"\bCOUNT_UP_TO\s*\(",
+             r"\bSUM\s*\(",
+             r"\bAVG\s*\(",
+         ]
+         for pattern in agg_patterns:
+             if re.search(pattern, upper):
+                 return True
+         return False
+
+     def _parse_aggregation_query(self, statement: str) -> Dict[str, Any]:
+         """
+         Parse aggregation query and return components.
+         Returns dict with:
+         - 'agg_functions': list of (func_name, column, alias)
+         - 'base_query': the underlying SELECT query
+         - 'is_aggregate_over': whether it's AGGREGATE...OVER syntax
+         """
+         upper = statement.upper().strip()
+         result: Dict[str, Any] = {
+             "agg_functions": [],
+             "base_query": None,
+             "is_aggregate_over": False,
+         }
+
+         # Handle AGGREGATE ... OVER (SELECT ...) syntax
+         if upper.startswith("AGGREGATE"):
+             result["is_aggregate_over"] = True
+             # Extract the inner SELECT query
+             over_match = re.search(
+                 r"OVER\s*\(\s*(SELECT\s+.+)\s*\)\s*$",
+                 statement,
+                 re.IGNORECASE | re.DOTALL,
+             )
+             if over_match:
+                 result["base_query"] = over_match.group(1).strip()
+             else:
+                 # Fallback - extract everything after OVER
+                 over_idx = upper.find("OVER")
+                 if over_idx > 0:
+                     # Extract content inside parentheses
+                     remaining = statement[over_idx + 4 :].strip()
+                     if remaining.startswith("("):
+                         paren_depth = 0
+                         for i, c in enumerate(remaining):
+                             if c == "(":
+                                 paren_depth += 1
+                             elif c == ")":
+                                 paren_depth -= 1
+                                 if paren_depth == 0:
+                                     result["base_query"] = remaining[1:i].strip()
+                                     break
+
+             # Parse aggregation functions before OVER
+             agg_part = statement[: upper.find("OVER")].strip()
+             if agg_part.upper().startswith("AGGREGATE"):
+                 agg_part = agg_part[9:].strip()  # Remove "AGGREGATE"
+             result["agg_functions"] = self._extract_agg_functions(agg_part)
+         else:
+             # Handle SELECT COUNT(*), SUM(col), etc.
+             result["is_aggregate_over"] = False
+             # Parse the SELECT clause to extract aggregation functions
+             select_match = re.match(
+                 r"SELECT\s+(.+?)\s+FROM\s+(.+)$", statement, re.IGNORECASE | re.DOTALL
+             )
+             if select_match:
+                 select_clause = select_match.group(1)
+                 from_clause = select_match.group(2)
+                 result["agg_functions"] = self._extract_agg_functions(select_clause)
+                 # Build base query to get all data
+                 result["base_query"] = f"SELECT * FROM {from_clause}"
+             else:
+                 # Handle SELECT without FROM (e.g., SELECT COUNT(*))
+                 select_match = re.match(
+                     r"SELECT\s+(.+)$", statement, re.IGNORECASE | re.DOTALL
+                 )
+                 if select_match:
+                     select_clause = select_match.group(1)
+                     result["agg_functions"] = self._extract_agg_functions(select_clause)
+                     result["base_query"] = None  # No base query for kindless
+
+         return result
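+
+     # Example of the returned structure (illustrative input):
+     #     _parse_aggregation_query(
+     #         "AGGREGATE COUNT(*) AS n OVER (SELECT * FROM User)")
+     #     -> {"agg_functions": [("COUNT", "*", "n")],
+     #         "base_query": "SELECT * FROM User",
+     #         "is_aggregate_over": True}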
+
+     def _extract_agg_functions(self, clause: str) -> List[Tuple[str, str, str]]:
+         """Extract aggregation functions from a clause."""
+         functions: List[Tuple[str, str, str]] = []
+         # Pattern to match aggregation functions with optional alias
+         patterns = [
+             (
+                 r"COUNT_UP_TO\s*\(\s*(\d+)\s*\)(?:\s+AS\s+(\w+))?",
+                 "COUNT_UP_TO",
+             ),
+             (r"COUNT\s*\(\s*\*\s*\)(?:\s+AS\s+(\w+))?", "COUNT"),
+             (r"SUM\s*\(\s*(\w+)\s*\)(?:\s+AS\s+(\w+))?", "SUM"),
+             (r"AVG\s*\(\s*(\w+)\s*\)(?:\s+AS\s+(\w+))?", "AVG"),
+         ]
+
+         for pattern, func_name in patterns:
+             for match in re.finditer(pattern, clause, re.IGNORECASE):
+                 if func_name == "COUNT":
+                     col = "*"
+                     alias = match.group(1) if match.group(1) else func_name
+                 elif func_name == "COUNT_UP_TO":
+                     col = match.group(1)  # The limit number
+                     alias = match.group(2) if match.group(2) else func_name
+                 else:
+                     col = match.group(1)
+                     alias = match.group(2) if match.group(2) else func_name
+                 functions.append((func_name, col, alias))
+
+         return functions
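+
+     # Illustrative: _extract_agg_functions("COUNT(*) AS n, SUM(price)")
+     # -> [("COUNT", "*", "n"), ("SUM", "price", "SUM")]
+     # (an unaliased aggregate falls back to the function name as its alias;
+     # results follow pattern order, not textual order)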
+
+     def _compute_aggregations(
+         self,
+         rows: List[Tuple],
+         fields: Dict[str, Any],
+         agg_functions: List[Tuple[str, str, str]],
+     ) -> Tuple[List[Tuple], Dict[str, Any]]:
+         """Compute aggregations on the data."""
+         result_values: List[Any] = []
+         result_fields: Dict[str, Any] = {}
+
+         # Get column name to index mapping
+         field_names = list(fields.keys())
+
+         for func_name, col, alias in agg_functions:
+             if func_name == "COUNT":
+                 value = len(rows)
+             elif func_name == "COUNT_UP_TO":
+                 limit = int(col)
+                 value = min(len(rows), limit)
+             elif func_name in ("SUM", "AVG"):
+                 # Find the column index
+                 if col in field_names:
+                     col_idx = field_names.index(col)
+                     values = [row[col_idx] for row in rows if row[col_idx] is not None]
+                     numeric_values = [v for v in values if isinstance(v, (int, float))]
+                     if func_name == "SUM":
+                         value = sum(numeric_values) if numeric_values else 0
+                     else:  # AVG
+                         value = (
+                             sum(numeric_values) / len(numeric_values)
+                             if numeric_values
+                             else 0
+                         )
+                 else:
+                     value = 0
+             else:
+                 value = None
+
+             result_values.append(value)
+             result_fields[alias] = (alias, None, None, None, None, None, None)
+
+         return [tuple(result_values)], result_fields
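+
+     # Illustrative: with fields {"id": ..., "price": ...} and
+     # rows [(1, 10.0), (2, None), (3, 5.0)],
+     # agg_functions [("SUM", "price", "total")] yields
+     # ([(15.0,)], {"total": ("total", None, None, None, None, None, None)}).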
607
+
608
+ def _execute_gql_request(self, gql_statement: str) -> Response:
609
+ """Execute a GQL query and return the response."""
610
+ body = {
611
+ "gqlQuery": {
612
+ "queryString": gql_statement,
613
+ "allowLiterals": True,
614
+ }
615
+ }
616
+
617
+ project_id = self._datastore_client.project
618
+ if os.getenv("DATASTORE_EMULATOR_HOST") is None:
619
+ credentials = getattr(
620
+ self._datastore_client, "scoped_credentials", None
621
+ )
622
+ if credentials is None and self._datastore_client.credentials_info:
623
+ credentials = service_account.Credentials.from_service_account_info(
624
+ self._datastore_client.credentials_info,
625
+ scopes=["https://www.googleapis.com/auth/datastore"],
626
+ )
627
+ if credentials is None:
628
+ raise ProgrammingError(
629
+ "No credentials available for Datastore query. "
630
+ "Provide credentials_info, credentials_path, or "
631
+ "configure Application Default Credentials."
632
+ )
633
+ authed_session = AuthorizedSession(credentials)
634
+ url = f"https://datastore.googleapis.com/v1/projects/{project_id}:runQuery"
635
+ return authed_session.post(url, json=body)
636
+ else:
637
+ host = os.environ["DATASTORE_EMULATOR_HOST"]
638
+ url = f"http://{host}/v1/projects/{project_id}:runQuery"
639
+ return requests.post(url, json=body)
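+
+     # For reference, a successful runQuery response carries the shape this
+     # module consumes downstream (sketch; see the Datastore REST API docs):
+     #     {"batch": {"entityResults": [
+     #         {"entity": {"key": {...}, "properties": {...}}}, ...]}}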
+
+     def _needs_client_side_filter(self, statement: str) -> bool:
+         """Check if the query needs client-side filtering due to unsupported ops.
+
+         Note: This should be called on the CONVERTED GQL statement (after
+         _convert_sql_to_gql), since that method handles reversing sqlglot
+         transformations like <> -> != and NOT col IN -> col NOT IN.
+         GQL natively supports: =, <, >, <=, >=, !=, IN, NOT IN, CONTAINS.
+         """
+         upper = statement.upper()
+         unsupported_patterns = [
+             r"\bOR\b",  # OR conditions need client-side evaluation
+             r"\bBLOB\s*\(",  # BLOB literal (escaping issues)
+         ]
+         for pattern in unsupported_patterns:
+             if re.search(pattern, upper):
+                 return True
+         return False
+
+     def _extract_base_query_for_filter(self, statement: str) -> str:
+         """Extract base query without WHERE clause for client-side filtering."""
+         # Remove WHERE clause to get all data
+         upper = statement.upper()
+         where_idx = upper.find(" WHERE ")
+         if where_idx > 0:
+             # Find the end of WHERE (before ORDER BY, LIMIT, OFFSET)
+             end_patterns = [" ORDER BY ", " LIMIT ", " OFFSET "]
+             end_idx = len(statement)
+             for pattern in end_patterns:
+                 idx = upper.find(pattern, where_idx)
+                 if idx > 0 and idx < end_idx:
+                     end_idx = idx
+             # Remove WHERE clause
+             base = statement[:where_idx] + statement[end_idx:]
+             return base.strip()
+         return statement
+
+     def _is_missing_index_error(self, response: Response) -> bool:
+         """Check if the GQL response indicates a missing composite index."""
+         if response.status_code not in (400, 409):
+             return False
+         try:
+             body = response.json()
+             error = body.get("error", {})
+             message = error.get("message", "").lower()
+             status = error.get("status", "")
+             return (
+                 "no matching index found" in message
+                 or status == "FAILED_PRECONDITION"
+             )
+         except Exception:
+             return "no matching index" in response.text.lower()
+
+     def _extract_table_only_query(self, gql_statement: str) -> str:
+         """Extract just 'SELECT * FROM <table>' from a GQL statement."""
+         table_match = re.search(
+             r"\bFROM\s+(\w+)", gql_statement, flags=re.IGNORECASE
+         )
+         if table_match:
+             return f"SELECT * FROM {table_match.group(1)}"
+         raise ProgrammingError(
+             f"Could not extract table name from query: {gql_statement}"
+         )
+
+     def _parse_order_by_clause(
+         self, gql_statement: str
+     ) -> List[Tuple[str, bool]]:
+         """Parse ORDER BY clause. Returns list of (column, ascending) tuples."""
+         upper = gql_statement.upper()
+         order_idx = upper.find(" ORDER BY ")
+         if order_idx < 0:
+             return []
+         # Find end of ORDER BY (before LIMIT, OFFSET)
+         end_idx = len(gql_statement)
+         for pattern in [" LIMIT ", " OFFSET "]:
+             idx = upper.find(pattern, order_idx + 10)
+             if 0 < idx < end_idx:
+                 end_idx = idx
+         order_clause = gql_statement[order_idx + 10 : end_idx].strip()
+         if not order_clause:
+             return []
+         result: List[Tuple[str, bool]] = []
+         for part in order_clause.split(","):
+             part = part.strip()
+             parts = part.split()
+             if not parts:
+                 continue
+             col_name = parts[0]
+             ascending = not (len(parts) > 1 and parts[1].upper() == "DESC")
+             result.append((col_name, ascending))
+         return result
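+
+     # Illustrative: _parse_order_by_clause(
+     #     "SELECT * FROM User ORDER BY age DESC, name LIMIT 5")
+     # -> [("age", False), ("name", True)]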
+
+     def _parse_limit_offset_clause(
+         self, gql_statement: str
+     ) -> Tuple[Optional[int], int]:
+         """Parse LIMIT and OFFSET from statement. Returns (limit, offset)."""
+         limit = None
+         offset = 0
+         limit_match = re.search(
+             r"\bLIMIT\s+(\d+)", gql_statement, flags=re.IGNORECASE
+         )
+         if limit_match:
+             limit = int(limit_match.group(1))
+         offset_match = re.search(
+             r"\bOFFSET\s+(\d+)", gql_statement, flags=re.IGNORECASE
+         )
+         if offset_match:
+             offset = int(offset_match.group(1))
+         return limit, offset
+
+     def _apply_client_side_order_by(
+         self,
+         rows: List[Tuple],
+         fields: Dict[str, Any],
+         order_keys: List[Tuple[str, bool]],
+     ) -> List[Tuple]:
+         """Sort rows on the client side based on ORDER BY specification."""
+         if not order_keys or not rows:
+             return rows
+         from functools import cmp_to_key
+
+         field_names = list(fields.keys())
+
+         def compare_rows(row_a: Tuple, row_b: Tuple) -> int:
+             for col_name, ascending in order_keys:
+                 if col_name in field_names:
+                     idx = field_names.index(col_name)
+                     val_a = row_a[idx] if idx < len(row_a) else None
+                     val_b = row_b[idx] if idx < len(row_b) else None
+                 else:
+                     val_a = None
+                     val_b = None
+                 if val_a is None and val_b is None:
+                     continue
+                 if val_a is None:
+                     return 1  # None sorts last
+                 if val_b is None:
+                     return -1
+                 try:
+                     if val_a < val_b:
+                         cmp_result = -1
+                     elif val_a > val_b:
+                         cmp_result = 1
+                     else:
+                         continue
+                 except TypeError:
+                     continue
+                 if not ascending:
+                     cmp_result = -cmp_result
+                 return cmp_result
+             return 0
+
+         return sorted(rows, key=cmp_to_key(compare_rows))
+
+     def _execute_fallback_query(
+         self, original_statement: str, gql_statement: str
+     ):
+         """Execute a fallback query when the original fails due to missing index.
+
+         Fetches all data from the table and applies WHERE, ORDER BY,
+         LIMIT/OFFSET on the client side.
+         """
+         warning_msg = (
+             "Missing index: the query requires an index that does not exist "
+             "in Datastore. Falling back to fetching ALL entities and "
+             "processing client-side (SELECT * mode). This may significantly "
+             "increase query and egress costs. Consider adding the required "
+             "composite index to avoid this."
+         )
+         logger.warning("%s Original GQL: %s", warning_msg, gql_statement)
+         self.warnings.append(warning_msg)
+
+         # Build simple query to fetch all data from the table
+         fallback_query = self._extract_table_only_query(gql_statement)
+         response = self._execute_gql_request(fallback_query)
+         if response.status_code != 200:
+             raise OperationalError(
+                 f"Fallback query failed: {fallback_query} "
+                 f"(original: {gql_statement})"
+             )
+
+         data = response.json()
+         entity_results = data.get("batch", {}).get("entityResults", [])
+
+         # Initialize cursor state for empty result
+         self._query_data = iter([])
+         self._query_rows = iter([])
+         self.rowcount = 0
+         self.description = [(None, None, None, None, None, None, None)]
+         self._last_executed = original_statement
+         self._parameters = {}
+
+         if not entity_results:
+             return
+
+         self._closed = False
+
+         # Parse entities with all columns (needed for filtering/sorting)
+         rows, fields = ParseEntity.parse(entity_results, None)
+
+         # Apply WHERE filter using the original statement to preserve
+         # binary data in BLOB literals (whitespace normalization in
+         # _convert_sql_to_gql would corrupt them).
+         rows = self._apply_client_side_filter(rows, fields, original_statement)
+
+         # Apply ORDER BY
+         order_keys = self._parse_order_by_clause(gql_statement)
+         if order_keys:
+             rows = self._apply_client_side_order_by(rows, fields, order_keys)
+
+         # Apply LIMIT/OFFSET
+         limit, offset = self._parse_limit_offset_clause(gql_statement)
+         if offset > 0:
+             rows = rows[offset:]
+         if limit is not None:
+             rows = rows[:limit]
+
+         # Project to requested columns if the original query specified them
+         selected_columns = self._parse_select_columns(original_statement)
+         if selected_columns is not None:
+             field_names = list(fields.keys())
+             projected_rows: List[Tuple] = []
+             projected_fields: Dict[str, Any] = {}
+
+             for col in selected_columns:
+                 col_lower = col.lower()
+                 if col_lower in ("__key__", "key") and "key" in fields:
+                     projected_fields["key"] = fields["key"]
+                 elif col in fields:
+                     projected_fields[col] = fields[col]
+
+             for row in rows:
+                 new_row: List[Any] = []
+                 for col in selected_columns:
+                     col_lower = col.lower()
+                     lookup = "key" if col_lower in ("__key__", "key") else col
+                     if lookup in field_names:
+                         idx = field_names.index(lookup)
+                         new_row.append(row[idx] if idx < len(row) else None)
+                     else:
+                         new_row.append(None)
+                 projected_rows.append(tuple(new_row))
+
+             rows = projected_rows
+             fields = projected_fields
+
+         fields_list = list(fields.values())
+         self._query_data = iter(rows)
+         self._query_rows = iter(rows)
+         self.rowcount = len(rows)
+         self.description = fields_list if fields_list else None
+
+     def _apply_client_side_filter(
+         self, rows: List[Tuple], fields: Dict[str, Any], statement: str
+     ) -> List[Tuple]:
+         """Apply client-side filtering for unsupported WHERE conditions."""
+         # Parse WHERE clause and apply filters
+         upper = statement.upper()
+         where_idx = upper.find(" WHERE ")
+         if where_idx < 0:
+             return rows
+
+         # Find end of WHERE clause
+         end_patterns = [" ORDER BY ", " LIMIT ", " OFFSET "]
+         end_idx = len(statement)
+         for pattern in end_patterns:
+             idx = upper.find(pattern, where_idx)
+             if idx > 0 and idx < end_idx:
+                 end_idx = idx
+
+         where_clause = statement[where_idx + 7 : end_idx].strip()
+         field_names = list(fields.keys())
+
+         # Apply filter
+         filtered_rows = []
+         for row in rows:
+             if self._evaluate_where(row, field_names, where_clause):
+                 filtered_rows.append(row)
+         return filtered_rows
+
+     def _evaluate_where(
+         self, row: Tuple, field_names: List[str], where_clause: str
+     ) -> bool:
+         """Evaluate WHERE clause against a row. Returns True if row matches."""
+         # Build a context dict from the row
+         context = {}
+         for i, name in enumerate(field_names):
+             if i < len(row):
+                 context[name] = row[i]
+
+         # Parse and evaluate the WHERE clause
+         # This is a simplified evaluator for common patterns
+         try:
+             return self._eval_condition(context, where_clause)
+         except Exception:
+             # If evaluation fails, include the row (fail open)
+             return True
+
+     def _eval_condition(self, context: Dict[str, Any], condition: str) -> bool:
+         """Evaluate a single condition or compound condition."""
+         condition = condition.strip()
+
+         # Handle parentheses
+         if condition.startswith("(") and condition.endswith(")"):
+             # Find matching paren
+             depth = 0
+             for i, c in enumerate(condition):
+                 if c == "(":
+                     depth += 1
+                 elif c == ")":
+                     depth -= 1
+                     if depth == 0:
+                         if i == len(condition) - 1:
+                             return self._eval_condition(context, condition[1:-1])
+                         break
+
+         # Handle OR (lower precedence)
+         or_match = re.search(r"\bOR\b", condition, re.IGNORECASE)
+         if or_match:
+             # Split on OR, but respect parentheses
+             parts = self._split_on_operator(condition, "OR")
+             if len(parts) > 1:
+                 return any(self._eval_condition(context, p) for p in parts)
+
+         # Handle AND (higher precedence)
+         and_match = re.search(r"\bAND\b", condition, re.IGNORECASE)
+         if and_match:
+             parts = self._split_on_operator(condition, "AND")
+             if len(parts) > 1:
+                 return all(self._eval_condition(context, p) for p in parts)
+
+         # Handle simple comparisons
+         return self._eval_simple_condition(context, condition)
+
+     def _split_on_operator(self, condition: str, operator: str) -> List[str]:
+         """Split condition on operator while respecting parentheses."""
+         parts: List[str] = []
+         current = ""
+         depth = 0
+         i = 0
+         pattern = re.compile(rf"\b{operator}\b", re.IGNORECASE)
+
+         while i < len(condition):
+             if condition[i] == "(":
+                 depth += 1
+                 current += condition[i]
+             elif condition[i] == ")":
+                 depth -= 1
+                 current += condition[i]
+             elif depth == 0:
+                 match = pattern.match(condition[i:])
+                 if match:
+                     parts.append(current.strip())
+                     current = ""
+                     i += len(match.group()) - 1
+                 else:
+                     current += condition[i]
+             else:
+                 current += condition[i]
+             i += 1
+
+         if current.strip():
+             parts.append(current.strip())
+         return parts
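+
+     # Illustrative: _split_on_operator("a = 1 OR (b = 2 AND c = 3)", "OR")
+     # -> ["a = 1", "(b = 2 AND c = 3)"]; the AND inside parentheses is
+     # protected by the depth counter.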
+
+     def _eval_simple_condition(self, context: Dict[str, Any], condition: str) -> bool:
+         """Evaluate a simple comparison condition."""
+         condition = condition.strip()
+
+         # Handle __key__ = KEY(kind, value) comparison
+         # Entity key is stored as "key" in context (from ParseEntity)
+         key_eq_match = re.match(
+             r"__key__\s*=\s*KEY\s*\(\s*\w+\s*,\s*(?:'([^']*)'|(\d+))\s*\)",
+             condition,
+             re.IGNORECASE,
+         )
+         if key_eq_match:
+             key_name = key_eq_match.group(1)
+             key_id = key_eq_match.group(2)
+             field_val = context.get("key") or context.get("__key__")
+             if isinstance(field_val, list) and len(field_val) > 0:
+                 last_path = field_val[-1]
+                 if isinstance(last_path, dict):
+                     if key_name is not None:
+                         return last_path.get("name") == key_name
+                     if key_id is not None:
+                         return str(last_path.get("id")) == key_id
+             return False
+
+         # Handle BLOB equality (before generic handlers, since BLOB literal
+         # would confuse the generic _parse_literal path)
+         blob_eq_match = re.match(
+             r"(\w+)\s*=\s*BLOB\s*\('(.*?)'\)",
+             condition,
+             re.IGNORECASE | re.DOTALL,
+         )
+         if blob_eq_match:
+             field = blob_eq_match.group(1)
+             blob_str = blob_eq_match.group(2)
+             try:
+                 blob_bytes = blob_str.encode("latin-1")
+             except (UnicodeEncodeError, UnicodeDecodeError):
+                 blob_bytes = blob_str.encode("utf-8")
+             field_val = context.get(field)
+             if isinstance(field_val, bytes):
+                 return field_val == blob_bytes
+             return False
+
+         # Handle BLOB inequality
+         blob_neq_match = re.match(
+             r"(\w+)\s*!=\s*BLOB\s*\('(.*?)'\)",
+             condition,
+             re.IGNORECASE | re.DOTALL,
+         )
+         if blob_neq_match:
+             field = blob_neq_match.group(1)
+             blob_str = blob_neq_match.group(2)
+             try:
+                 blob_bytes = blob_str.encode("latin-1")
+             except (UnicodeEncodeError, UnicodeDecodeError):
+                 blob_bytes = blob_str.encode("utf-8")
+             field_val = context.get(field)
+             if isinstance(field_val, bytes):
+                 return field_val != blob_bytes
+             return True
+
+         # Handle NOT IN / NOT IN ARRAY
+         not_in_match = re.match(
+             r"(\w+)\s+NOT\s+IN\s+(?:ARRAY\s*)?\(([^)]+)\)",
+             condition, re.IGNORECASE,
+         )
+         if not_in_match:
+             field = not_in_match.group(1)
+             values_str = not_in_match.group(2)
+             values = self._parse_value_list(values_str)
+             field_val = context.get(field)
+             return field_val not in values
+
+         # Handle IN / IN ARRAY
+         in_match = re.match(
+             r"(\w+)\s+IN\s+(?:ARRAY\s*)?\(([^)]+)\)",
+             condition, re.IGNORECASE,
+         )
+         if in_match:
+             field = in_match.group(1)
+             values_str = in_match.group(2)
+             values = self._parse_value_list(values_str)
+             field_val = context.get(field)
+             return field_val in values
+
+         # Handle != and <>
+         neq_match = re.match(r"(\w+)\s*(?:!=|<>)\s*(.+)", condition, re.IGNORECASE)
+         if neq_match:
+             field = neq_match.group(1)
+             value = self._parse_literal(neq_match.group(2).strip())
+             field_val = context.get(field)
+             return field_val != value
+
+         # Handle >=
+         gte_match = re.match(r"(\w+)\s*>=\s*(.+)", condition)
+         if gte_match:
+             field = gte_match.group(1)
+             value = self._parse_literal(gte_match.group(2).strip())
+             field_val = context.get(field)
+             if field_val is not None and value is not None:
+                 try:
+                     return field_val >= value
+                 except TypeError:
+                     return False
+             return False
+
+         # Handle <=
+         lte_match = re.match(r"(\w+)\s*<=\s*(.+)", condition)
+         if lte_match:
+             field = lte_match.group(1)
+             value = self._parse_literal(lte_match.group(2).strip())
+             field_val = context.get(field)
+             if field_val is not None and value is not None:
+                 try:
+                     return field_val <= value
+                 except TypeError:
+                     return False
+             return False
+
+         # Handle >
+         gt_match = re.match(r"(\w+)\s*>\s*(.+)", condition)
+         if gt_match:
+             field = gt_match.group(1)
+             value = self._parse_literal(gt_match.group(2).strip())
+             field_val = context.get(field)
+             if field_val is not None and value is not None:
+                 try:
+                     return field_val > value
+                 except TypeError:
+                     return False
+             return False
+
+         # Handle <
+         lt_match = re.match(r"(\w+)\s*<\s*(.+)", condition)
+         if lt_match:
+             field = lt_match.group(1)
+             value = self._parse_literal(lt_match.group(2).strip())
+             field_val = context.get(field)
+             if field_val is not None and value is not None:
+                 try:
+                     return field_val < value
+                 except TypeError:
+                     return False
+             return False
+
+         # Handle =
+         eq_match = re.match(r"(\w+)\s*=\s*(.+)", condition)
+         if eq_match:
+             field = eq_match.group(1)
+             value = self._parse_literal(eq_match.group(2).strip())
+             field_val = context.get(field)
+             return field_val == value
+
+         # Default: include row
+         return True
+
+     def _parse_value_list(self, values_str: str) -> List[Any]:
+         """Parse a comma-separated list of values."""
+         values: List[Any] = []
+         for v in values_str.split(","):
+             values.append(self._parse_literal(v.strip()))
+         return values
+
+     def _parse_literal(self, literal: str) -> Any:
+         """Parse a literal value from string."""
+         literal = literal.strip()
+         # DATETIME literal: DATETIME('2023-01-01T00:00:00Z')
+         datetime_match = re.match(
+             r"DATETIME\s*\(\s*'([^']*)'\s*\)", literal, re.IGNORECASE
+         )
+         if datetime_match:
+             timestamp_str = datetime_match.group(1)
+             if timestamp_str.endswith("Z"):
+                 timestamp_str = timestamp_str.replace("Z", "+00:00")
+             # Normalize fractional seconds to 6 digits for Python 3.10
+             # compatibility (fromisoformat only handles 0, 3, or 6 digits).
+             frac_match = re.match(
+                 r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})\.(\d+)(.*)",
+                 timestamp_str,
+             )
+             if frac_match:
+                 frac = frac_match.group(2)[:6].ljust(6, "0")
+                 timestamp_str = (
+                     frac_match.group(1) + "." + frac + frac_match.group(3)
+                 )
+             return datetime.fromisoformat(timestamp_str)
+         # String literal
+         if (literal.startswith("'") and literal.endswith("'")) or (
+             literal.startswith('"') and literal.endswith('"')
+         ):
+             return literal[1:-1]
+         # Boolean
+         if literal.upper() == "TRUE":
+             return True
+         if literal.upper() == "FALSE":
+             return False
+         # NULL
+         if literal.upper() == "NULL":
+             return None
+         # Number
+         try:
+             if "." in literal:
+                 return float(literal)
+             return int(literal)
+         except ValueError:
+             return literal
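+
+     # Illustrative conversions:
+     #     _parse_literal("'abc'") -> "abc"
+     #     _parse_literal("42")    -> 42
+     #     _parse_literal("TRUE")  -> True
+     #     _parse_literal("DATETIME('2023-01-01T00:00:00Z')")
+     #         -> datetime(2023, 1, 1, tzinfo=timezone.utc)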
+
+     def _is_orm_id_query(self, statement: str) -> bool:
+         """Check if this is an ORM-style query with table.id in WHERE clause."""
+         upper = statement.upper()
+         # Check for patterns like "table.id = :param" in WHERE clause
+         return (
+             "SELECT" in upper
+             and ".ID" in upper
+             and "WHERE" in upper
+             and (":PK_" in upper or ":ID_" in upper or ".ID =" in upper)
+         )
+
+     def _execute_orm_id_query(self, statement: str, parameters: dict):
+         """Execute an ORM-style query by ID using direct key lookup."""
+         try:
+             parsed = parse_one(statement)
+             if not isinstance(parsed, exp.Select):
+                 raise ProgrammingError("Expected SELECT statement")
+
+             # Get table name
+             from_arg = parsed.args.get("from") or parsed.args.get("from_")
+             if not from_arg:
+                 raise ProgrammingError("Could not find FROM clause")
+             table_name = from_arg.this.name if hasattr(from_arg.this, "name") else str(from_arg.this)
+
+             # Extract column aliases from SELECT clause FIRST (before querying)
+             # This ensures we have description even when no entity is found
+             column_info = []
+             for expr in parsed.expressions:
+                 if isinstance(expr, exp.Alias):
+                     alias = expr.alias
+                     if isinstance(expr.this, exp.Column):
+                         col_name = expr.this.name
+                     else:
+                         col_name = str(expr.this)
+                     column_info.append((col_name, alias))
+                 elif isinstance(expr, exp.Column):
+                     col_name = expr.name
+                     column_info.append((col_name, col_name))
+                 elif isinstance(expr, exp.Star):
+                     # SELECT * - we'll handle this after fetching entity
+                     column_info = None
+                     break
+
+             # Build description from column info (for non-SELECT * cases)
+             if column_info is not None:
+                 field_names = [alias for _, alias in column_info]
+                 self.description = [
+                     (name, None, None, None, None, None, None)
+                     for name in field_names
+                 ]
+
+             # Extract ID from WHERE clause
+             where = parsed.args.get("where")
+             if not where:
+                 raise ProgrammingError("Expected WHERE clause")
+
+             entity_key_id = self._extract_key_id_from_where(where, parameters)
+             if entity_key_id is None:
+                 raise ProgrammingError("Could not extract key ID from WHERE")
+
+             # Fetch entity by key
+             key = self._datastore_client.key(table_name, entity_key_id)
+             entity = self._datastore_client.get(key)
+
+             if entity is None:
+                 # No entity found - description is already set above
+                 self._query_rows = iter([])
+                 self.rowcount = 0
+                 # For SELECT *, set empty description since we don't know the schema
+                 if column_info is None:
+                     self.description = []
+                 return
+
+             # Build result row
+             if column_info is None:
+                 # SELECT * case
+                 row_values = [entity.key.id]  # Add id first
+                 field_names = ["id"]
+                 for prop_name in sorted(entity.keys()):
+                     row_values.append(entity[prop_name])
+                     field_names.append(prop_name)
+                 # Build description for SELECT *
+                 self.description = [
+                     (name, None, None, None, None, None, None)
+                     for name in field_names
+                 ]
+             else:
+                 row_values = []
+                 for col_name, alias in column_info:
+                     if col_name.lower() == "id":
+                         row_values.append(entity.key.id)
+                     else:
+                         row_values.append(entity.get(col_name))
+
+             self._query_rows = iter([tuple(row_values)])
+             self.rowcount = 1
+
+         except Exception as e:
+             logger.error(f"ORM ID query failed: {e}")
+             raise ProgrammingError(f"ORM ID query failed: {e}") from e
+
+     def _substitute_parameters(self, statement: str, parameters: dict) -> str:
+         """Substitute named parameters in SQL statement with their values."""
+         result = statement
+         # Substitute longer names first so a parameter such as :id does not
+         # clobber the prefix of a longer one such as :id_1.
+         for param_name, value in sorted(
+             parameters.items(), key=lambda kv: len(kv[0]), reverse=True
+         ):
+             # Build the placeholder pattern (e.g., :param_name)
+             placeholder = f":{param_name}"
+
+             # Format the value appropriately for GQL
+             if value is None:
+                 formatted_value = "NULL"
+             elif isinstance(value, str):
+                 # Escape single quotes in strings
+                 escaped = value.replace("'", "''")
+                 formatted_value = f"'{escaped}'"
+             elif isinstance(value, bool):
+                 formatted_value = "true" if value else "false"
+             elif isinstance(value, (int, float)):
+                 formatted_value = str(value)
+             elif isinstance(value, datetime):
+                 # Format as ISO string for GQL
+                 formatted_value = f"DATETIME('{value.isoformat()}')"
+             else:
+                 # Default to string representation
+                 formatted_value = f"'{str(value)}'"
+
+             result = result.replace(placeholder, formatted_value)
+
+         return result
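+
+     # Illustrative: _substitute_parameters(
+     #     "SELECT * FROM User WHERE name = :n AND age > :min_age",
+     #     {"n": "O'Brien", "min_age": 21})
+     # -> "SELECT * FROM User WHERE name = 'O''Brien' AND age > 21"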
1341
+
1342
+ def gql_query(self, statement, parameters=None, **kwargs):
1343
+ """Execute a GQL query with support for aggregations."""
1344
+
1345
+ # Check for ORM-style queries with table.id in WHERE clause
1346
+ if parameters and self._is_orm_id_query(statement):
1347
+ self._execute_orm_id_query(statement, parameters)
1348
+ return
1349
+
1350
+ # Substitute parameters if provided
1351
+ if parameters:
1352
+ statement = self._substitute_parameters(statement, parameters)
1353
+
1354
+ # Convert SQL to GQL-compatible format
1355
+ gql_statement = self._convert_sql_to_gql(statement)
1356
+ logging.debug(f"Converted GQL statement: {gql_statement}")
1357
+
1358
+ # Check if this is an aggregation query
1359
+ if self._is_aggregation_query(statement):
1360
+ self._execute_aggregation_query(statement, parameters)
1361
+ return
1362
+
1363
+ # Check if we need client-side filtering (check converted GQL)
1364
+ needs_filter = self._needs_client_side_filter(gql_statement)
1365
+ if needs_filter:
1366
+ # Get base query without unsupported WHERE conditions
1367
+ base_query = self._extract_base_query_for_filter(gql_statement)
1368
+ gql_statement = self._convert_sql_to_gql(base_query)
1369
+
1370
+ # Execute GQL query
1371
+ response = self._execute_gql_request(gql_statement)
1372
+
1373
+ if response.status_code == 200:
1374
+ data = response.json()
1375
+ logging.debug(data)
1376
+ else:
1377
+ # Fall back to client-side processing for any GQL failure.
1378
+ # The emulator may return 400 (INVALID_ARGUMENT for !=, NOT IN,
1379
+ # multi-value IN ARRAY), 409/400 (missing composite index), or
1380
+ # 500 (server error). In all cases we fetch all data from the
1381
+ # table and apply WHERE, ORDER BY, LIMIT/OFFSET client-side.
1382
+ # If even the simple fallback query fails, it raises an error.
1383
+ warning_msg = (
1384
+ "GQL query failed. Falling back to fetching ALL entities "
1385
+ "and processing client-side (SELECT * mode). This may "
1386
+ "significantly increase query and egress costs. Consider "
1387
+ "adding the required index for this query."
1388
+ )
1389
+ logging.warning(
1390
+ "%s (status %d)", warning_msg, response.status_code
1391
+ )
1392
+ self.warnings.append(warning_msg)
1393
+ self._execute_fallback_query(statement, statement)
1394
+ return
1395
+
1396
+ self._query_data = iter([])
1397
+ self._query_rows = iter([])
1398
+ self.rowcount = 0
1399
+ self.description = [(None, None, None, None, None, None, None)]
1400
+ self._last_executed = statement
1401
+ self._parameters = parameters or {}
1402
+
1403
+ data = data.get("batch", {}).get("entityResults", [])
1404
+ if len(data) == 0:
1405
+ return # Everything is already set for an empty result
1406
+
1407
+ # Determine if this statement is expected to return rows (e.g., SELECT)
1408
+ # You'll need a way to figure this out based on 'statement' or a flag passed to your custom execute method.
1409
+ # Example (simplified check, you might need a more robust parsing or flag):
1410
+ is_select_statement = statement.upper().strip().startswith("SELECT")
1411
+
1412
+ if is_select_statement:
1413
+ self._closed = False # For SELECT, cursor should remain open to fetch rows
1414
+
1415
+ # Parse the SELECT statement to get column list
1416
+ selected_columns = self._parse_select_columns(statement)
1417
+
1418
+ rows, fields = ParseEntity.parse(data, selected_columns)
1419
+
1420
+ # Apply client-side filtering if needed.
1421
+ # Use the original statement (not the converted GQL) to preserve
1422
+ # binary data inside BLOB literals that whitespace normalization
1423
+ # in _convert_sql_to_gql would corrupt.
1424
+ if needs_filter:
1425
+ rows = self._apply_client_side_filter(rows, fields, statement)
1426
+
1427
+ fields = list(fields.values())
1428
+ self._query_data = iter(rows)
1429
+ self._query_rows = iter(rows)
1430
+ self.rowcount = len(rows)
1431
+ self.description = fields if len(fields) > 0 else None
1432
+ else:
1433
+ # For INSERT/UPDATE/DELETE, the operation is complete, no rows to yield
1434
+ # For INSERT/UPDATE/DELETE, the operation is complete, set rowcount if possible
1435
+ affected_count = len(data) if isinstance(data, list) else 0
1436
+ self.rowcount = affected_count
1437
+ self._closed = True
1438
+
1439
+ def _execute_aggregation_query(self, statement: str, parameters=None):
1440
+ """Execute an aggregation query with client-side aggregation."""
1441
+ parsed = self._parse_aggregation_query(statement)
1442
+ agg_functions = parsed["agg_functions"]
1443
+ base_query = parsed["base_query"]
1444
+
1445
+ # If there's no base query and no functions, return empty
1446
+ if not agg_functions:
1447
+ self._query_rows = iter([])
1448
+ self.rowcount = 0
1449
+ self.description = []
1450
+ return
1451
+
1452
+ # If there's no base query (e.g., SELECT COUNT(*) without FROM)
1453
+ # Return a count of 0 or handle specially
1454
+ if base_query is None:
1455
+ # For kindless COUNT(*), we return 0 since we can't query all kinds
1456
+ result_values: List[Any] = []
1457
+ result_fields: Dict[str, Any] = {}
1458
+ for func_name, col, alias in agg_functions:
1459
+ if func_name == "COUNT":
1460
+ result_values.append(0)
1461
+ elif func_name == "COUNT_UP_TO":
1462
+ result_values.append(0)
1463
+ else:
1464
+ result_values.append(0)
1465
+ result_fields[alias] = (alias, None, None, None, None, None, None)
1466
+
1467
+ self._query_rows = iter([tuple(result_values)])
1468
+ self.rowcount = 1
1469
+ self.description = list(result_fields.values())
1470
+ return
1471
+
1472
+ # Convert to GQL first, then check for client-side filtering
1473
+ base_gql = self._convert_sql_to_gql(base_query)
1474
+ original_base_gql = base_gql # Save for potential fallback
1475
+ needs_filter = self._needs_client_side_filter(base_gql)
1476
+ if needs_filter:
1477
+ filter_query = self._extract_base_query_for_filter(base_gql)
1478
+ base_gql = self._convert_sql_to_gql(filter_query)
1479
+
1480
+ response = self._execute_gql_request(base_gql)
1481
+
1482
+ if response.status_code != 200:
1483
+ warning_msg = (
1484
+ "Aggregation base query failed. Falling back to fetching "
1485
+ "ALL entities and aggregating client-side (SELECT * mode). "
1486
+ "This may significantly increase query and egress costs. "
1487
+ "Consider adding the required index for the columns used "
1488
+ "in this aggregation."
1489
+ )
1490
+ logging.warning(
1491
+ "%s (status %d)", warning_msg, response.status_code
1492
+ )
1493
+ self.warnings.append(warning_msg)
1494
+ fallback_query = self._extract_table_only_query(
1495
+ original_base_gql
1496
+ )
1497
+ response = self._execute_gql_request(fallback_query)
1498
+ if response.status_code != 200:
1499
+ raise OperationalError(
1500
+ f"Aggregation fallback query failed: "
1501
+ f"{fallback_query} (original: {statement})"
1502
+ )
1503
+ fb_data = response.json()
1504
+ fb_results = fb_data.get("batch", {}).get(
1505
+ "entityResults", []
1506
+ )
1507
+ if not fb_results:
1508
+ result_values: List[Any] = []
1509
+ result_fields: Dict[str, Any] = {}
1510
+ for _fn, _col, alias in agg_functions:
1511
+ result_values.append(0)
1512
+ result_fields[alias] = (
1513
+ alias, None, None, None, None, None, None
1514
+ )
1515
+ self._query_rows = iter([tuple(result_values)])
1516
+ self.rowcount = 1
1517
+ self.description = list(result_fields.values())
1518
+ return
1519
+ rows, fields = ParseEntity.parse(fb_results, None)
1520
+ rows = self._apply_client_side_filter(
1521
+ rows, fields, base_query
1522
+ )
1523
+ agg_rows, agg_fields = self._compute_aggregations(
1524
+ rows, fields, agg_functions
1525
+ )
1526
+ self._query_rows = iter(agg_rows)
1527
+ self.rowcount = len(agg_rows)
1528
+ self.description = list(agg_fields.values())
1529
+ return
1530
+
1531
+ data = response.json()
1532
+ entity_results = data.get("batch", {}).get("entityResults", [])
1533
+
1534
+ if len(entity_results) == 0:
1535
+ # No data - return aggregations with 0 values
1536
+ result_values = []
1537
+ result_fields: Dict[str, Any] = {}
1538
+ for func_name, _col, alias in agg_functions:
1539
+ if func_name == "COUNT":
1540
+ result_values.append(0)
1541
+ elif func_name == "COUNT_UP_TO":
1542
+ result_values.append(0)
1543
+ elif func_name in ("SUM", "AVG"):
1544
+ result_values.append(0)
1545
+ else:
1546
+ result_values.append(None)
1547
+ result_fields[alias] = (alias, None, None, None, None, None, None)
1548
+
1549
+ self._query_rows = iter([tuple(result_values)])
1550
+ self.rowcount = 1
1551
+ self.description = list(result_fields.values())
1552
+ return
1553
+
1554
+ # Parse the entity results
1555
+ rows, fields = ParseEntity.parse(entity_results, None)
1556
+
1557
+ # Apply client-side filtering if needed
1558
+ if needs_filter:
1559
+ rows = self._apply_client_side_filter(rows, fields, base_query)
1560
+
1561
+ # Compute aggregations
1562
+ agg_rows, agg_fields = self._compute_aggregations(rows, fields, agg_functions)
1563
+
1564
+ self._query_rows = iter(agg_rows)
1565
+ self.rowcount = len(agg_rows)
1566
+ self.description = list(agg_fields.values())
1567
+
1568
+ def execute_orm(
1569
+ self, statement: str, parameters=None, tokens: List[tokens.Token] = []
1570
+ ):
1571
+ if parameters is None:
1572
+ parameters = {}
1573
+
1574
+ logging.debug(
1575
+ f"[DataStore DBAPI] Executing ORM query: {statement} with parameters: {parameters}"
1576
+ )
1577
+
1578
+ statement = statement.replace("`", "'")
1579
+ parsed = parse_one(statement)
1580
+ # Note: depending on the sqlglot version, the FROM clause is stored under "from" or "from_"
1581
+ from_arg = parsed.args.get("from") or parsed.args.get("from_")
1582
+ if not isinstance(parsed, exp.Select) or not from_arg:
1583
+ raise ProgrammingError("Unsupported ORM query structure.")
1584
+
1585
+ from_clause = from_arg.this
1586
+ if not isinstance(from_clause, exp.Subquery):
1587
+ raise ProgrammingError("Expected a subquery in the FROM clause.")
1588
+
1589
+ subquery_sql = from_clause.this.sql()
1590
+
1591
+ # 1. Query the subquery table
1592
+ self.gql_query(subquery_sql)
1593
+ subquery_results = self.fetchall()
1594
+ subquery_description = self.description
1595
+
1596
+ # 2. Turn to pandas dataframe
1597
+ if not subquery_description:
1598
+ df = pd.DataFrame(subquery_results)
1599
+ else:
1600
+ column_names = [col[0] for col in subquery_description]
1601
+ df = pd.DataFrame(subquery_results, columns=column_names)
1602
+
1603
+ # Add computed columns from SELECT expressions before grouping or ordering
1604
+ for p in parsed.expressions:
1605
+ if isinstance(p, exp.Alias) and not p.find(exp.AggFunc):
1606
+ # Simplified expression evaluator for computed columns: it strips
+ # double quotes from "col" references and leaves the rest as-is.
1608
+ expr_str = re.sub(r'"(\w+)"', r"\1", p.this.sql())
1609
+ try:
1610
+ # Use assign to add new columns based on expressions
1611
+ df = df.assign(**{p.alias: df.eval(expr_str, engine="python")})
1612
+ except Exception as e:
1613
+ logging.warning(f"Could not evaluate expression '{expr_str}': {e}")
1614
+
1615
+ # 3. Apply outer query logic (aggregations and GROUP BY)
1616
+ has_agg = any(
1617
+ isinstance(p, exp.Alias) and p.find(exp.AggFunc)
1618
+ for p in parsed.expressions
1619
+ )
1620
+
1621
+ if parsed.args.get("group"):
1622
+ group_by_cols = []
1623
+ for e in parsed.args.get("group").expressions:
1624
+ col_name = e.name if hasattr(e, "name") else ""
1625
+ if col_name and col_name in df.columns:
1626
+ group_by_cols.append(col_name)
1627
+ else:
1628
+ # Function expression (e.g. DATETIME_TRUNC) — find
1629
+ # the matching alias in the SELECT clause.
1630
+ expr_sql = e.sql()
1631
+ matched = False
1632
+ for p in parsed.expressions:
1633
+ if isinstance(p, exp.Alias) and p.this.sql() == expr_sql:
1634
+ group_by_cols.append(p.alias)
1635
+ matched = True
1636
+ break
1637
+ if not matched:
+ if col_name:
+ group_by_cols.append(col_name)
+ else:
+ # Don't pass an empty column name to groupby; skip
+ # the expression and surface it in the logs instead.
+ logger.warning(
+ "Could not resolve GROUP BY expression: %s", expr_sql
+ )
1639
+
1640
+ # Convert unhashable types (lists, dicts) to hashable types for groupby.
+ # Datastore keys are stored as lists of dicts, GeoPoints as dicts.
+ for col in group_by_cols:
+ if col not in df.columns:
+ continue
+ sample = df[col].dropna().head(1)
+ if len(sample) > 0 and isinstance(sample.iloc[0], list):
+ df[col] = df[col].apply(
+ lambda x: tuple(
+ tuple(sorted(d.items())) if isinstance(d, dict) else d
+ for d in x
+ )
+ if isinstance(x, list)
+ else x
+ )
+ elif len(sample) > 0 and isinstance(sample.iloc[0], dict):
+ df[col] = df[col].apply(
+ lambda x: tuple(sorted(x.items())) if isinstance(x, dict) else x
+ )
1665
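+ # e.g. (illustrative) a key path [{"kind": "task", "id": 1}] becomes
+ # ((("id", 1), ("kind", "task")),), which pandas can hash for groupby.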
+
1666
+ # Collect every aggregate expression first so a single
+ # groupby().agg() call can compute them together; running groupby
+ # once per aggregate would replace df after the first pass and
+ # lose the remaining source columns.
+ named_aggs: Dict[str, Tuple[str, str]] = {}
+ for p in parsed.expressions:
+ if not isinstance(p, exp.Alias) or not p.find(exp.AggFunc):
+ continue
+ agg_func = p.this
+ agg_func_name = agg_func.key.lower()
+ # Map SQL aggregate names to pandas equivalents
+ sql_to_pandas_agg = {"avg": "mean"}
+ agg_func_name = sql_to_pandas_agg.get(agg_func_name, agg_func_name)
+ if agg_func.expressions:
+ original_col_name = agg_func.expressions[0].name
+ elif isinstance(agg_func.this, exp.Distinct):
+ # COUNT(DISTINCT col) - column is inside Distinct
+ original_col_name = agg_func.this.expressions[0].name
+ # Use pandas nunique for COUNT(DISTINCT)
+ agg_func_name = "nunique"
+ elif isinstance(agg_func.this, exp.Star):
+ # COUNT(*) - use first group_by column for counting
+ original_col_name = group_by_cols[0]
+ elif agg_func.this is not None and hasattr(agg_func.this, "name"):
+ original_col_name = agg_func.this.name
+ else:
+ # Fallback for unknown structures
+ original_col_name = group_by_cols[0]
+ named_aggs[p.alias_or_name] = (original_col_name, agg_func_name)
+ if named_aggs:
+ df = (
+ df.groupby(group_by_cols)
+ .agg(**named_aggs)
+ .reset_index()
+ )
1703
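+ # e.g. (illustrative) SELECT status, COUNT(*) AS cnt, AVG(score) AS avg_s
+ # ... GROUP BY status runs as
+ # df.groupby(["status"]).agg(cnt=("status", "count"), avg_s=("score", "mean"))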
+
1704
+ elif has_agg:
1705
+ # Aggregation without GROUP BY (e.g., SELECT COUNT(*) FROM table)
1706
+ result_data: Dict[str, Any] = {}
1707
+ for p in parsed.expressions:
1708
+ if not isinstance(p, exp.Alias) or not p.find(exp.AggFunc):
1709
+ continue
1710
+ agg_func = p.this
1711
+ agg_func_name = agg_func.key.lower()
1712
+ alias = p.alias_or_name
1713
+
1714
+ if agg_func_name == "count":
1715
+ if isinstance(agg_func.this, exp.Star):
1716
+ result_data[alias] = len(df)
1717
+ elif isinstance(agg_func.this, exp.Distinct):
1718
+ col_name = agg_func.this.expressions[0].name
1719
+ result_data[alias] = df[col_name].nunique()
1720
+ elif agg_func.expressions:
1721
+ col_name = agg_func.expressions[0].name
1722
+ result_data[alias] = df[col_name].count()
1723
+ else:
1724
+ result_data[alias] = len(df)
1725
+ elif agg_func_name == "sum":
1726
+ col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
1727
+ result_data[alias] = df[col_name].sum()
1728
+ elif agg_func_name == "avg":
1729
+ col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
1730
+ result_data[alias] = df[col_name].mean()
1731
+ elif agg_func_name == "min":
1732
+ col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
1733
+ result_data[alias] = df[col_name].min()
1734
+ elif agg_func_name == "max":
1735
+ col_name = agg_func.this.name if agg_func.this else agg_func.expressions[0].name
1736
+ result_data[alias] = df[col_name].max()
1737
+ else:
1738
+ result_data[alias] = None
1739
+
1740
+ df = pd.DataFrame([result_data])
1741
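+ # e.g. (illustrative) SELECT COUNT(*) AS n, MAX(score) AS top FROM t
+ # produces a single-row frame like pd.DataFrame([{"n": 42, "top": 9.5}]).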
+
1742
+ if parsed.args.get("order"):
1743
+ order_by_cols = [e.this.name for e in parsed.args["order"].expressions]
1744
+ ascending = [
1745
+ not e.args.get("desc", False) for e in parsed.args["order"].expressions
1746
+ ]
1747
+ # Convert uncomparable types (dicts, lists) to strings for sorting.
1748
+ # Datastore keys are lists of dicts and GeoPoints are dicts, which
1749
+ # cannot be compared with < in Python 3.
1750
+ for col in order_by_cols:
1751
+ if col in df.columns:
1752
+ sample = df[col].dropna().head(1)
1753
+ if len(sample) > 0 and isinstance(
1754
+ sample.iloc[0], (dict, list)
1755
+ ):
1756
+ df[col] = df[col].apply(
1757
+ lambda x: str(x) if isinstance(x, (dict, list)) else x
1758
+ )
1759
+ df = df.sort_values(by=order_by_cols, ascending=ascending)
1760
+
1761
+ if parsed.args.get("limit"):
1762
+ limit = int(parsed.args["limit"].expression.sql())
1763
+ df = df.head(limit)
1764
+
1765
+ # Final column selection
1766
+ if not any(isinstance(p, exp.Star) for p in parsed.expressions):
1767
+ final_columns = [p.alias_or_name for p in parsed.expressions]
1768
+ # Ensure all selected columns exist in the DataFrame before selecting
1769
+ df = df[[col for col in final_columns if col in df.columns]]
1770
+
1771
+ # Finalize results
1772
+ rows = [tuple(x) for x in df.to_numpy()]
1773
+ schema = self._create_schema_from_df(df)
1774
+ self.rowcount = len(rows) if rows else 0
1775
+ self._set_description(schema)
1776
+ self._query_rows = iter(rows)
1777
+
1778
+ def _create_schema_from_df(self, df: pd.DataFrame) -> tuple:
1779
+ """Create schema from a pandas DataFrame."""
1780
+ schema = []
1781
+ for col_name, dtype in df.dtypes.items():
1782
+ if pd.api.types.is_string_dtype(dtype):
1783
+ sa_type = types.String
1784
+ elif pd.api.types.is_integer_dtype(dtype):
1785
+ sa_type = types.Integer
1786
+ elif pd.api.types.is_float_dtype(dtype):
1787
+ sa_type = types.Float
1788
+ elif pd.api.types.is_bool_dtype(dtype):
1789
+ sa_type = types.Boolean
1790
+ elif pd.api.types.is_datetime64_any_dtype(dtype):
1791
+ sa_type = types.DateTime
1792
+ else:
1793
+ sa_type = types.String # Fallback
1794
+
1795
+ schema.append(
1796
+ Column(
1797
+ name=col_name,
1798
+ type_code=sa_type(),
1799
+ display_size=None,
1800
+ internal_size=None,
1801
+ precision=None,
1802
+ scale=None,
1803
+ null_ok=True,
1804
+ )
1805
+ )
1806
+ return tuple(schema)
1807
+
1808
+ def _set_description(self, schema: tuple = ()):
1809
+ """Set the cursor description based on the schema."""
1810
+ self.description = schema
1811
+
1812
+ def fetchall(self):
1813
+ if self._closed:
1814
+ raise Error("Cursor is closed.")
1815
+ return list(self._query_rows)
1816
+
1817
+ def fetchmany(self, size=None):
1818
+ if self._closed:
1819
+ raise Error("Cursor is closed.")
1820
+ if size is None:
1821
+ size = self.arraysize or 1
1822
+ results = []
1823
+ for _ in range(size):
1824
+ try:
1825
+ results.append(next(self._query_rows))
1826
+ except StopIteration:
1827
+ break
1828
+ return results
1829
+
1830
+ def fetchone(self):
1831
+ if self._closed:
1832
+ raise Error("Cursor is closed.")
1833
+ try:
1834
+ return next(self._query_rows)
1835
+ except StopIteration:
1836
+ return None
1837
+
1838
+ def _parse_select_columns(self, statement: str) -> Optional[List[str]]:
1839
+ """
1840
+ Parse SELECT statement to extract column names.
1841
+ Returns None for SELECT * (all columns) or when parsing fails.
1842
+ """
1843
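+ # e.g. (illustrative) "SELECT id, name FROM task" -> ["__key__", "name"]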
+ try:
1844
+ # Use sqlglot to parse the statement
1845
+ parsed = parse_one(statement)
1846
+ if not isinstance(parsed, exp.Select):
1847
+ return None
1848
+
1849
+ columns = []
1850
+ for expr in parsed.expressions:
1851
+ if isinstance(expr, exp.Star):
1852
+ # SELECT * - return None to indicate all columns
1853
+ return None
1854
+ elif isinstance(expr, exp.Column):
1855
+ # Direct column reference
1856
+ col_name = expr.name
1857
+ # Map 'id' to '__key__' since Datastore uses keys, not id properties
1858
+ if col_name.lower() == "id":
1859
+ col_name = "__key__"
1860
+ columns.append(col_name)
1861
+ elif isinstance(expr, exp.Alias):
1862
+ # Column with alias
1863
+ if isinstance(expr.this, exp.Column):
1864
+ col_name = expr.this.name
1865
+ columns.append(col_name)
1866
+ else:
1867
+ # For complex expressions, use the alias
1868
+ columns.append(expr.alias)
1869
+ else:
1870
+ # For other expressions, try to get the name or use the string representation
1871
+ col_name = expr.alias_or_name
1872
+ if col_name:
1873
+ columns.append(col_name)
1874
+
1875
+ return columns if columns else None
1876
+ except Exception:
1877
+ # If parsing fails, return None to get all columns
1878
+ return None
1879
+
1880
+ def _convert_sql_to_gql(self, statement: str) -> str:
1881
+ """
1882
+ Convert SQL statements to GQL-compatible format.
1883
+
1884
+ GQL (Google Query Language) is similar to SQL but has its own syntax.
1885
+ This method reverses transformations applied by Superset's sqlglot
1886
+ processing (BigQuery dialect) and makes other adjustments for GQL
1887
+ compatibility.
1888
+ """
1889
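+ # End-to-end example (illustrative; "task" is a hypothetical kind):
+ # SELECT name FROM task WHERE x <> 1 AND name NOT IN ('a', 'b')
+ # -> SELECT * FROM task WHERE x != 1 AND name NOT IN ARRAY('a', 'b')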
+ # AGGREGATE queries are valid GQL - pass through directly
1890
+ if statement.strip().upper().startswith("AGGREGATE"):
1891
+ return statement
1892
+
1893
+ # Normalize whitespace: sqlglot pretty-prints with newlines which
1894
+ # breaks position-based string operations (find, regex).
1895
+ statement = re.sub(r"\s+", " ", statement).strip()
1896
+
1897
+ # === Reverse sqlglot / BigQuery dialect transformations ===
1898
+
1899
+ # 1. Convert <> back to != (sqlglot BigQuery dialect converts != to <>)
1900
+ # GQL uses != for not-equals comparisons.
1901
+ statement = re.sub(r"<>", "!=", statement)
1902
+
1903
+ # 2. Fix NOT ... IN -> ... NOT IN
1904
+ # sqlglot converts "col NOT IN (...)" to "NOT col IN (...)"
1905
+ # GQL expects "col NOT IN (...)"
1906
+ statement = re.sub(
1907
+ r"\bNOT\s+(\w+)\s+IN\s*\(",
1908
+ r"\1 NOT IN (",
1909
+ statement,
1910
+ flags=re.IGNORECASE,
1911
+ )
1912
+
1913
+ # 3. Strip ROW_NUMBER() OVER (...) added by sqlglot for DISTINCT ON
1914
+ # BigQuery dialect converts "SELECT DISTINCT ON (col) * FROM t"
1915
+ # to "SELECT *, ROW_NUMBER() OVER (PARTITION BY col ...) AS _row_... FROM t"
1916
+ # We strip the ROW_NUMBER expression and any trailing WHERE _row_... = 1
1917
+ statement = re.sub(
1918
+ r",\s*ROW_NUMBER\s*\(\s*\)\s*OVER\s*\([^)]*\)\s*(?:AS\s+\w+)?",
1919
+ "",
1920
+ statement,
1921
+ flags=re.IGNORECASE,
1922
+ )
1923
+ # Also drop the trailing "WHERE _row_... = 1" filter left behind by that rewrite
1924
+ statement = re.sub(
1925
+ r"\bWHERE\s+_row_\w+\s*=\s*1\b",
1926
+ "",
1927
+ statement,
1928
+ flags=re.IGNORECASE,
1929
+ )
1930
+
1931
+ # 4. Fix IN clause syntax for GQL
1932
+ # a) Convert square bracket arrays: IN ['val'] -> IN ARRAY('val')
1933
+ # b) Convert parenthesized lists: IN ('val1', 'val2') -> IN ARRAY('val1', 'val2')
1934
+ # GQL requires the ARRAY keyword: "name IN ARRAY('val1', 'val2')"
1935
+ # NOT IN also needs: "name NOT IN ARRAY('val1', 'val2')"
1936
+ statement = re.sub(
1937
+ r"\bIN\s*\[([^\]]*)\]",
1938
+ r"IN ARRAY(\1)",
1939
+ statement,
1940
+ flags=re.IGNORECASE,
1941
+ )
1942
+ # Convert IN (...) to IN ARRAY(...). The lookahead skips IN (SELECT ...)
+ # subqueries; IN ARRAY(...) is safe since no "(" directly follows IN.
1943
+ statement = re.sub(
1944
+ r"\bIN\s*\((?![\s]*SELECT\b)",
1945
+ "IN ARRAY(",
1946
+ statement,
1947
+ flags=re.IGNORECASE,
1948
+ )
1949
+ # Fix double ARRAY: an original "IN (ARRAY(...))" becomes "IN ARRAY(ARRAY(...))" above, so collapse it
1950
+ statement = re.sub(
1951
+ r"\bARRAY\s*\(\s*ARRAY\s*\(",
1952
+ "ARRAY(",
1953
+ statement,
1954
+ flags=re.IGNORECASE,
1955
+ )
1956
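+ # e.g. "name IN ('a', 'b')" -> "name IN ARRAY('a', 'b')"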
+
1957
+ # 5. Fix WHERE NULL (from sqlglot optimizing "col = NULL" to "NULL")
1958
+ # sqlglot treats "col = NULL" as always-false and collapses to NULL.
1959
+ # We can't recover the original column, so if the WHERE clause is
+ # just "WHERE NULL" we drop it entirely. Note this returns all rows
+ # rather than none, trading strict SQL semantics for a runnable query.
1961
+ statement = re.sub(
1962
+ r"\bWHERE\s+NULL\b",
1963
+ "",
1964
+ statement,
1965
+ flags=re.IGNORECASE,
1966
+ )
1967
+
1968
+ # === GQL-specific transformations ===
1969
+
1970
+ # Handle LIMIT FIRST(offset, count) syntax
1971
+ # Convert to LIMIT <count> OFFSET <offset>
1972
+ first_match = re.search(
1973
+ r"LIMIT\s+FIRST\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)",
1974
+ statement,
1975
+ flags=re.IGNORECASE,
1976
+ )
1977
+ if first_match:
1978
+ offset = first_match.group(1)
1979
+ count = first_match.group(2)
1980
+ statement = re.sub(
1981
+ r"LIMIT\s+FIRST\s*\(\s*\d+\s*,\s*\d+\s*\)",
1982
+ f"LIMIT {count} OFFSET {offset}",
1983
+ statement,
1984
+ flags=re.IGNORECASE,
1985
+ )
1986
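+ # e.g. "LIMIT FIRST(20, 10)" -> "LIMIT 10 OFFSET 20"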
+
1987
+ # Extract table name from FROM clause for KEY() conversion
1988
+ table_match = re.search(
1989
+ r"\bFROM\s+(\w+)", statement, flags=re.IGNORECASE
1990
+ )
1991
+ table_name = table_match.group(1) if table_match else None
1992
+
1993
+ # Remove DISTINCT ON (...) syntax - not supported by GQL.
1994
+ # GQL supports DISTINCT but not DISTINCT ON.
1995
+ statement = re.sub(
1996
+ r"\bDISTINCT\s+ON\s*\([^)]*\)\s*",
1997
+ "",
1998
+ statement,
1999
+ flags=re.IGNORECASE,
2000
+ )
2001
+
2002
+ # Convert table.id references to __key__
2003
+ if table_name:
2004
+ statement = re.sub(
2005
+ rf"\b{table_name}\.id\b",
2006
+ "__key__",
2007
+ statement,
2008
+ flags=re.IGNORECASE,
2009
+ )
2010
+
2011
+ # Handle bare 'id' references for GQL compatibility
2012
+ upper_stmt = statement.upper()
2013
+ from_pos = upper_stmt.find(" FROM ")
2014
+ if from_pos > 0:
2015
+ select_clause = statement[:from_pos]
2016
+ from_and_rest = statement[from_pos:]
2017
+
2018
+ # Parse SELECT columns and remove id/__key__ from projection
2019
+ select_match = re.match(
2020
+ r"(SELECT\s+(?:DISTINCT\s+(?:ON\s*\([^)]*\)\s*)?)?)(.*)",
2021
+ select_clause,
2022
+ flags=re.IGNORECASE,
2023
+ )
2024
+ if select_match:
2025
+ prefix = select_match.group(1)
2026
+ cols_str = select_match.group(2)
2027
+ cols = [c.strip() for c in cols_str.split(",")]
2028
+ non_key_cols = [
2029
+ c
2030
+ for c in cols
2031
+ if not re.match(
2032
+ r"^(id|__key__)$", c.strip(), flags=re.IGNORECASE
2033
+ )
2034
+ ]
2035
+
2036
+ if not non_key_cols:
2037
+ select_clause = prefix + "__key__"
2038
+ elif len(non_key_cols) < len(cols):
2039
+ select_clause = prefix + ", ".join(non_key_cols)
2040
+
2041
+ # Convert 'id' to '__key__' in WHERE/ORDER BY/etc. (note: this is a
+ # plain word-boundary substitution and will also touch quoted 'id' literals)
2042
+ from_and_rest = re.sub(
2043
+ r"\bid\b", "__key__", from_and_rest, flags=re.IGNORECASE
2044
+ )
2045
+
2046
+ statement = select_clause + from_and_rest
2047
+ else:
2048
+ statement = re.sub(
2049
+ r"\bid\b", "__key__", statement, flags=re.IGNORECASE
2050
+ )
2051
+
2052
+ # Datastore restriction: projection queries with WHERE clauses require
2053
+ # composite indexes. Convert to SELECT * to avoid this requirement and
2054
+ # let ParseEntity handle column filtering from the full entity response.
2055
+ upper_check = statement.upper()
2056
+ from_check_pos = upper_check.find(" FROM ")
2057
+ where_check_pos = upper_check.find(" WHERE ")
2058
+ if from_check_pos > 0 and where_check_pos > from_check_pos:
2059
+ select_cols_str = re.sub(
2060
+ r"^SELECT\s+", "", statement[:from_check_pos], flags=re.IGNORECASE
2061
+ ).strip()
2062
+ if (
2063
+ select_cols_str != "*"
2064
+ and select_cols_str.upper() != "__KEY__"
2065
+ and not select_cols_str.upper().startswith("DISTINCT")
2066
+ ):
2067
+ statement = "SELECT * " + statement[from_check_pos + 1:]
2068
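+ # e.g. (illustrative) "SELECT name FROM task WHERE done = false"
+ # -> "SELECT * FROM task WHERE done = false"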
+
2069
+ # Handle id = <number> in WHERE clauses -> KEY() syntax
2070
+ if table_name:
2071
+ id_where_match = re.search(
2072
+ r"\bWHERE\b.*\b(?:id|__key__)\s*=\s*(\d+)",
2073
+ statement,
2074
+ flags=re.IGNORECASE,
2075
+ )
2076
+ if id_where_match:
2077
+ id_value = id_where_match.group(1)
2078
+ statement = re.sub(
2079
+ r"\b(?:id|__key__)\s*=\s*\d+",
2080
+ f"__key__ = KEY({table_name}, {id_value})",
2081
+ statement,
2082
+ flags=re.IGNORECASE,
2083
+ )
2084
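+ # e.g. (illustrative) "WHERE id = 5" on kind "task"
+ # -> "WHERE __key__ = KEY(task, 5)"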
+
2085
+ # Remove column aliases (AS alias_name) - GQL doesn't support them
2086
+ # (AGGREGATE ... AS ... OVER queries return early above, so their AS is never stripped)
2087
+ statement = re.sub(
2088
+ r"\bAS\s+\w+", "", statement, flags=re.IGNORECASE
2089
+ )
2090
+
2091
+ # Remove table prefix from column names (table.column -> column)
2092
+ if table_name:
2093
+ statement = re.sub(
2094
+ rf"\b{table_name}\.(?!__)", "", statement, flags=re.IGNORECASE
2095
+ )
2096
+
2097
+ # Clean up extra spaces and artifacts
2098
+ statement = re.sub(r"\s+", " ", statement).strip()
2099
+ statement = re.sub(r",\s*,", ",", statement)
2100
+ statement = re.sub(r"\s*,\s*\bFROM\b", " FROM", statement)
2101
+
2102
+ return statement
2103
+
2104
+ def close(self):
2105
+ self._closed = True
2106
+ self.connection = None
2107
+ logging.debug("Cursor is closed.")
2108
+
2109
+
2110
+ class Connection:
2111
+ def __init__(self, client=None):
2112
+ self._client = client
2113
+ self._transaction = None
2114
+
2115
+ def cursor(self):
2116
+ return Cursor(self)
2117
+
2118
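+ # The transaction hooks below are intentional no-ops: this DB-API
+ # layer does not implement transactional semantics, but SQLAlchemy
+ # still expects the methods to exist.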
+ def begin(self):
2119
+ logging.debug("datastore connection transaction begin")
2120
+
2121
+ def commit(self):
2122
+ logging.debug("datastore connection commit")
2123
+
2124
+ def rollback(self):
2125
+ logging.debug("datastore connection rollback")
2126
+
2127
+ def close(self):
2128
+ logging.debug("Closing connection")
2129
+
2130
+
2131
+ def connect(client=None):
2132
+ return Connection(client)
2133
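+ # Minimal DB-API usage sketch (illustrative; assumes the Cursor.execute
+ # defined earlier in this module and a hypothetical kind "task"):
+ # conn = connect()
+ # cur = conn.cursor()
+ # cur.execute("SELECT * FROM task LIMIT 5")
+ # rows = cur.fetchall()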
+
2134
+
2135
+ class ParseEntity:
2136
+ @classmethod
2137
+ def parse(cls, data: List[dict], selected_columns: Optional[List[str]] = None):
+ """
+ Parse Datastore entity results into rows and field descriptors.
+
+ data: list of entityResult dicts as returned by the Datastore
+ runQuery REST API.
+ selected_columns: List of column names to include in results. If None, include all.
+ """
2144
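+ # Example input (illustrative runQuery REST shape):
+ # [{"entity": {"key": {"path": [{"kind": "task", "id": "1"}]},
+ # "properties": {"done": {"booleanValue": False}}}}]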
+ all_property_names_set = set()
2145
+ for entity_data in data:
2146
+ properties = entity_data.get("entity", {}).get("properties", {})
2147
+ all_property_names_set.update(properties.keys())
2148
+
2149
+ # Determine which columns to include
2150
+ if selected_columns is None:
2151
+ # Include all properties if no specific selection
2152
+ sorted_property_names = sorted(list(all_property_names_set))
2153
+ include_key = True
2154
+ else:
2155
+ # Only include selected columns
2156
+ sorted_property_names = []
2157
+ include_key = False
2158
+ for col in selected_columns:
2159
+ if col.lower() == "__key__" or col.lower() == "key":
2160
+ include_key = True
2161
+ elif col in all_property_names_set:
2162
+ sorted_property_names.append(col)
2163
+
2164
+ final_fields: dict = {}
2165
+ final_rows: List[Tuple] = []
2166
+
2167
+ # Add key field if requested
2168
+ if include_key:
2169
+ final_fields["key"] = ("key", None, None, None, None, None, None)
2170
+
2171
+ # Add selected fields in the order they appear in selected_columns if provided
2172
+ if selected_columns:
2173
+ # Keep the order from selected_columns
2174
+ for prop_name in selected_columns:
2175
+ if (
+ prop_name.lower() not in ("__key__", "key")
+ and prop_name in all_property_names_set
+ ):
2180
+ final_fields[prop_name] = (
2181
+ prop_name,
2182
+ None,
2183
+ None,
2184
+ None,
2185
+ None,
2186
+ None,
2187
+ None,
2188
+ )
2189
+ else:
2190
+ # Add all fields sorted by name
2191
+ for prop_name in sorted_property_names:
2192
+ final_fields[prop_name] = (
2193
+ prop_name,
2194
+ None,
2195
+ None,
2196
+ None,
2197
+ None,
2198
+ None,
2199
+ None,
2200
+ )
2201
+
2202
+ # Append the properties
2203
+ for entity_data in data:
2204
+ row_values: List[Any] = []
2205
+ properties = entity_data.get("entity", {}).get("properties", {})
2206
+ key = entity_data.get("entity", {}).get("key", {})
2207
+
2208
+ # Add key value if requested
2209
+ if include_key:
2210
+ row_values.append(key.get("path", []))
2211
+
2212
+ # Append property values; selected_columns (when given) fixes the
+ # column order, otherwise properties are emitted sorted by name.
+ if selected_columns:
+ props_to_emit = [
+ p
+ for p in selected_columns
+ if p.lower() not in ("__key__", "key")
+ and p in all_property_names_set
+ ]
+ else:
+ props_to_emit = sorted_property_names
+ for prop_name in props_to_emit:
+ prop_v = properties.get(prop_name)
+ if prop_v is None:
+ row_values.append(None)
+ continue
+ prop_value, prop_type = ParseEntity.parse_properties(
+ prop_name, prop_v
+ )
+ row_values.append(prop_value)
+ # Backfill the field's type from the first entity that
+ # actually carries a value for it.
+ current_field_info = final_fields[prop_name]
+ if (
+ current_field_info[1] is None
+ or current_field_info[1] == "UNKNOWN"
+ ):
+ final_fields[prop_name] = (
+ prop_name,
+ prop_type,
+ ) + current_field_info[2:]
2265
+
2266
+ final_rows.append(tuple(row_values))
2267
+
2268
+ return final_rows, final_fields
2269
+
2270
+ @classmethod
2271
+ def parse_properties(cls, prop_k: str, prop_v: dict):
2272
+ # A Datastore REST value dict holds exactly one typed value key
+ # (e.g. "stringValue") plus optional metadata such as
+ # "excludeFromIndexes", so plain membership tests identify the type.
+ prop_type = None
+ prop_value: Any = None
+
+ if "nullValue" in prop_v:
+ prop_value = None
+ prop_type = _types.NULL_TYPE
+ elif "booleanValue" in prop_v:
+ prop_value = bool(prop_v["booleanValue"])
+ prop_type = _types.BOOL
+ elif "integerValue" in prop_v:
+ prop_value = int(prop_v["integerValue"])
+ prop_type = _types.INTEGER
+ elif "doubleValue" in prop_v:
+ prop_value = float(prop_v["doubleValue"])
+ prop_type = _types.FLOAT64
+ elif "stringValue" in prop_v:
+ prop_value = prop_v["stringValue"]
+ prop_type = _types.STRING
+ elif "timestampValue" in prop_v:
+ timestamp_str = prop_v["timestampValue"]
+ if timestamp_str.endswith("Z"):
+ # Handle ISO 8601 with Z suffix (UTC); fromisoformat only
+ # accepts "+00:00" offsets before Python 3.11.
+ prop_value = datetime.fromisoformat(
+ timestamp_str.replace("Z", "+00:00")
+ )
+ else:
+ prop_value = datetime.fromisoformat(timestamp_str)
+ prop_type = _types.TIMESTAMP
+ elif "blobValue" in prop_v:
+ prop_value = base64.b64decode(prop_v.get("blobValue", b""))
+ prop_type = _types.BYTES
+ elif "geoPointValue" in prop_v:
+ prop_value = prop_v["geoPointValue"]
+ prop_type = _types.GEOPOINT
+ elif "keyValue" in prop_v:
+ prop_value = prop_v["keyValue"]["path"]
+ prop_type = _types.KEY_TYPE
+ elif "arrayValue" in prop_v:
+ prop_value = []
+ for entity in prop_v["arrayValue"].get("values", []):
+ e_v, _ = ParseEntity.parse_properties(prop_k, entity)
+ prop_value.append(e_v)
+ prop_type = _types.ARRAY
+ elif "dictValue" in prop_v:
+ prop_value = prop_v["dictValue"]
+ prop_type = _types.STRUCT_FIELD_TYPES
+ elif "entityValue" in prop_v:
+ prop_value = prop_v["entityValue"].get("properties") or {}
+ prop_type = _types.STRUCT_FIELD_TYPES
2322
+ return prop_value, prop_type
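+ # Illustrative conversions (REST-style inputs):
+ # parse_properties("age", {"integerValue": "42"}) -> (42, _types.INTEGER)
+ # parse_properties("tags", {"arrayValue": {"values": [{"stringValue": "a"}]}})
+ # -> (["a"], _types.ARRAY)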