datajunction-query 0.0.1a58__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
djqs/config.py CHANGED
@@ -3,94 +3,197 @@ Configuration for the query service
  """

  import json
+ import logging
+ import os
  from datetime import timedelta
- from typing import Optional
+ from enum import Enum
+ from typing import Dict, List, Optional

  import toml
  import yaml
  from cachelib.base import BaseCache
  from cachelib.file import FileSystemCache
- from pydantic import BaseSettings
- from sqlmodel import Session, delete, select

- from djqs.exceptions import DJException
- from djqs.models.catalog import Catalog, CatalogEngines
- from djqs.models.engine import Engine
+ from djqs.exceptions import DJUnknownCatalog, DJUnknownEngine

+ _logger = logging.getLogger(__name__)

- class Settings(BaseSettings):  # pylint: disable=too-few-public-methods
+
+ class EngineType(Enum):
      """
-     Configuration for the query service
+     Supported engine types
      """

-     name: str = "DJQS"
-     description: str = "A DataJunction Query Service"
-     url: str = "http://localhost:8001/"
+     DUCKDB = "duckdb"
+     SQLALCHEMY = "sqlalchemy"
+     SNOWFLAKE = "snowflake"
+     TRINO = "trino"

-     # SQLAlchemy URI for the metadata database.
-     index: str = "sqlite:///djqs.db?check_same_thread=False"

-     # The default engine to use for reflection
-     default_reflection_engine: str = "default"
+ class EngineInfo:  # pylint: disable=too-few-public-methods
+     """
+     Information about a query engine
+     """

-     # The default engine version to use for reflection
-     default_reflection_engine_version: str = ""
+     def __init__(  # pylint: disable=too-many-arguments
+         self,
+         name: str,
+         version: str,
+         type: str,  # pylint: disable=redefined-builtin
+         uri: str,
+         extra_params: Optional[Dict[str, str]] = None,
+     ):
+         self.name = name
+         self.version = str(version)
+         self.type = EngineType(type)
+         self.uri = uri
+         self.extra_params = extra_params or {}
+
+
+ class CatalogInfo:  # pylint: disable=too-few-public-methods
+     """
+     Information about a catalog
+     """

-     # Where to store the results from queries.
-     results_backend: BaseCache = FileSystemCache("/tmp/djqs", default_timeout=0)
+     def __init__(self, name: str, engines: List[str]):
+         self.name = name
+         self.engines = engines

-     paginating_timeout: timedelta = timedelta(minutes=5)

-     # How long to wait when pinging databases to find out the fastest online database.
-     do_ping_timeout: timedelta = timedelta(seconds=5)
+ class Settings:  # pylint: disable=too-many-instance-attributes
+     """
+     Configuration for the query service
+     """

-     # Configuration file for catalogs and engines
-     configuration_file: Optional[str] = None
+     def __init__(  # pylint: disable=too-many-arguments,too-many-locals,dangerous-default-value
+         self,
+         name: Optional[str] = "DJQS",
+         description: Optional[str] = "A DataJunction Query Service",
+         url: Optional[str] = "http://localhost:8001/",
+         index: Optional[str] = "postgresql://dj:dj@postgres_metadata:5432/djqs",
+         default_catalog: Optional[str] = "",
+         default_engine: Optional[str] = "",
+         default_engine_version: Optional[str] = "",
+         results_backend: Optional[BaseCache] = None,
+         results_backend_path: Optional[str] = "/tmp/djqs",
+         results_backend_timeout: Optional[str] = "0",
+         paginating_timeout_minutes: Optional[str] = "5",
+         do_ping_timeout_seconds: Optional[str] = "5",
+         configuration_file: Optional[str] = None,
+         engines: Optional[List[EngineInfo]] = None,
+         catalogs: Optional[List[CatalogInfo]] = None,
+     ):
+         self.name: str = os.getenv("NAME", name or "")
+         self.description: str = os.getenv("DESCRIPTION", description or "")
+         self.url: str = os.getenv("URL", url or "")
+
+         # SQLAlchemy URI for the metadata database.
+         self.index: str = os.getenv("INDEX", index or "")
+
+         # The default catalog to use if not specified in query payload
+         self.default_catalog: str = os.getenv("DEFAULT_CATALOG", default_catalog or "")
+
+         # The default engine to use if not specified in query payload
+         self.default_engine: str = os.getenv("DEFAULT_ENGINE", default_engine or "")
+
+         # The default engine version to use if not specified in query payload
+         self.default_engine_version: str = os.getenv(
+             "DEFAULT_ENGINE_VERSION",
+             default_engine_version or "",
+         )

-     # Enable setting catalog and engine config via REST API calls
-     enable_dynamic_config: bool = True
+         # Where to store the results from queries.
+         self.results_backend: BaseCache = results_backend or FileSystemCache(
+             os.getenv("RESULTS_BACKEND_PATH", results_backend_path or ""),
+             default_timeout=int(
+                 os.getenv("RESULTS_BACKEND_TIMEOUT", results_backend_timeout or "0"),
+             ),
+         )

+         self.paginating_timeout: timedelta = timedelta(
+             minutes=int(
+                 os.getenv(
+                     "PAGINATING_TIMEOUT_MINUTES",
+                     paginating_timeout_minutes or "5",
+                 ),
+             ),
+         )

- def load_djqs_config(settings: Settings, session: Session) -> None:  # pragma: no cover
-     """
-     Load the DJQS config file into the server metadata database
-     """
-     config_file = settings.configuration_file if settings.configuration_file else None
-     if not config_file:
-         return
-
-     session.exec(delete(Catalog))
-     session.exec(delete(Engine))
-     session.exec(delete(CatalogEngines))
-     session.commit()
-
-     with open(config_file, mode="r", encoding="utf-8") as filestream:
-
-         def unknown_filetype():
-             raise DJException(message=f"Unknown config file type: {config_file}")
-
-         data = (
-             yaml.safe_load(filestream)
-             if any([config_file.endswith(".yml"), config_file.endswith(".yaml")])
-             else toml.load(filestream)
-             if config_file.endswith(".toml")
-             else json.load(filestream)
-             if config_file.endswith(".json")
-             else unknown_filetype()
+         # How long to wait when pinging databases to find out the fastest online database.
+         self.do_ping_timeout: timedelta = timedelta(
+             seconds=int(
+                 os.getenv("DO_PING_TIMEOUT_SECONDS", do_ping_timeout_seconds or "5"),
+             ),
          )

-     for engine in data["engines"]:
-         session.add(Engine.parse_obj(engine))
-     session.commit()
+         # Configuration file for catalogs and engines
+         self.configuration_file: Optional[str] = (
+             os.getenv("CONFIGURATION_FILE") or configuration_file
+         )

-     for catalog in data["catalogs"]:
-         attached_engines = []
-         catalog_engines = catalog.pop("engines")
-         for name in catalog_engines:
-             attached_engines.append(
-                 session.exec(select(Engine).where(Engine.name == name)).one(),
+         self.engines: List[EngineInfo] = engines or []
+         self.catalogs: List[CatalogInfo] = catalogs or []
+
+         self._load_configuration()
+
+     def _load_configuration(self):
+         config_file = self.configuration_file
+
+         if config_file:
+             if config_file.endswith(".yaml") or config_file.endswith(".yml"):
+                 with open(config_file, "r", encoding="utf-8") as file:
+                     config = yaml.safe_load(file)
+             elif config_file.endswith(".toml"):
+                 with open(config_file, "r", encoding="utf-8") as file:
+                     config = toml.load(file)
+             elif config_file.endswith(".json"):
+                 with open(config_file, "r", encoding="utf-8") as file:
+                     config = json.load(file)
+             else:
+                 raise ValueError(
+                     f"Unsupported configuration file format: {config_file}",
+                 )
+
+             self.engines = [
+                 EngineInfo(**engine) for engine in config.get("engines", [])
+             ]
+             self.catalogs = [
+                 CatalogInfo(**catalog) for catalog in config.get("catalogs", [])
+             ]
+         else:
+             _logger.warning("No settings configuration file has been set")
+
+     def find_engine(
+         self,
+         engine_name: str,
+         engine_version: str,
+     ) -> EngineInfo:
+         """
+         Find an engine defined in the server configuration
+         """
+         found_engine = None
+         for engine in self.engines:
+             if engine.name == engine_name and engine.version == engine_version:
+                 found_engine = engine
+         if not found_engine:
+             raise DJUnknownEngine(
+                 (
+                     f"Configuration error, cannot find engine {engine_name} "
+                     f"with version {engine_version}"
+                 ),
+             )
+         return found_engine
+
+     def find_catalog(self, catalog_name: str) -> CatalogInfo:
+         """
+         Find a catalog defined in the server configuration
+         """
+         found_catalog = None
+         for catalog in self.catalogs:
+             if catalog.name == catalog_name:
+                 found_catalog = catalog
+         if not found_catalog:
+             raise DJUnknownCatalog(
+                 f"Configuration error, cannot find catalog {catalog_name}",
              )
-         catalog_entry = Catalog.parse_obj(catalog)
-         catalog_entry.engines = attached_engines
-         session.add(catalog_entry)
-     session.commit()
+         return found_catalog
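
The `Settings` rewrite drops pydantic `BaseSettings` in favor of a plain class whose constructor defaults can each be overridden by an upper-cased environment variable (`NAME`, `INDEX`, `DEFAULT_ENGINE`, `CONFIGURATION_FILE`, and so on), and engines and catalogs are now loaded from a static YAML/TOML/JSON file instead of the metadata database. A minimal sketch of the new flow, assuming djqs 0.0.28 is importable and `CONFIGURATION_FILE` is unset in the environment; the engine and catalog entries are hypothetical:

    import tempfile

    from djqs.config import EngineType, Settings

    CONFIG_YAML = """
    engines:
      - name: duckdb_local        # hypothetical engine entry
        version: "0.7"
        type: duckdb              # one of: duckdb, sqlalchemy, snowflake, trino
        uri: duckdb:///:memory:
    catalogs:
      - name: warehouse           # hypothetical catalog entry
        engines:
          - duckdb_local
    """

    # Write the config somewhere readable and point Settings at it.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
        tmp.write(CONFIG_YAML)

    settings = Settings(configuration_file=tmp.name)
    engine = settings.find_engine(engine_name="duckdb_local", engine_version="0.7")
    assert engine.type is EngineType.DUCKDB
    catalog = settings.find_catalog("warehouse")
    assert catalog.engines == ["duckdb_local"]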
djqs/constants.py CHANGED
@@ -15,3 +15,6 @@ DEFAULT_DIMENSION_COLUMN = "id"
  # used by the SQLAlchemy client
  QUERY_EXECUTE_TIMEOUT = timedelta(seconds=60)
  GET_COLUMNS_TIMEOUT = timedelta(seconds=60)
+
+ # Request header configuration params
+ SQLALCHEMY_URI = "SQLALCHEMY_URI"
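
`SQLALCHEMY_URI` names the new request header that `run_query` in `djqs/engine.py` (below) checks before falling back to the configured engine: when present, the SQLAlchemy engine is created directly from the header value. A hypothetical client-side illustration; the endpoint path and payload shape are assumptions, not taken from this diff:

    import requests

    response = requests.post(
        "http://localhost:8001/queries/",      # hypothetical DJQS endpoint
        json={"submitted_query": "SELECT 1"},  # hypothetical payload shape
        headers={"SQLALCHEMY_URI": "sqlite:///:memory:"},
    )
    print(response.status_code)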
djqs/db/postgres.py ADDED
@@ -0,0 +1,139 @@
+ """
+ Dependency for getting the postgres pool and running backend DB queries
+ """
+
+ # pylint: disable=too-many-arguments
+ from datetime import datetime
+ from typing import List
+ from uuid import UUID
+
+ from fastapi import Request
+ from psycopg import sql
+ from psycopg_pool import AsyncConnectionPool
+
+ from djqs.exceptions import DJDatabaseError
+
+
+ async def get_postgres_pool(request: Request) -> AsyncConnectionPool:
+     """
+     Get the postgres pool from the app instance
+     """
+     app = request.app
+     return app.state.pool
+
+
+ class DBQuery:
+     """
+     Metadata DB queries using the psycopg composition utility
+     """
+
+     def __init__(self):
+         self._reset()
+
+     def _reset(self):
+         self.selects: List = []
+         self.inserts: List = []
+
+     def get_query(self, query_id: UUID):
+         """
+         Get metadata about a query
+         """
+         self.selects.append(
+             sql.SQL(
+                 """
+                 SELECT id, catalog_name, engine_name, engine_version, submitted_query,
+                 async_, executed_query, scheduled, started, finished, state, progress
+                 FROM query
+                 WHERE id = {query_id}
+                 """,
+             ).format(query_id=sql.Literal(query_id)),
+         )
+         return self
+
+     def save_query(
+         self,
+         query_id: UUID,
+         catalog_name: str = "",
+         engine_name: str = "",
+         engine_version: str = "",
+         submitted_query: str = "",
+         async_: bool = False,
+         state: str = "",
+         progress: float = 0.0,
+         executed_query: str = None,
+         scheduled: datetime = None,
+         started: datetime = None,
+         finished: datetime = None,
+     ):
+         """
+         Save metadata about a query
+         """
+         self.inserts.append(
+             sql.SQL(
+                 """
+                 INSERT INTO query (id, catalog_name, engine_name, engine_version,
+                                    submitted_query, async_, executed_query, scheduled,
+                                    started, finished, state, progress)
+                 VALUES ({query_id}, {catalog_name}, {engine_name}, {engine_version},
+                         {submitted_query}, {async_}, {executed_query}, {scheduled},
+                         {started}, {finished}, {state}, {progress})
+                 ON CONFLICT (id) DO UPDATE SET
+                     catalog_name = EXCLUDED.catalog_name,
+                     engine_name = EXCLUDED.engine_name,
+                     engine_version = EXCLUDED.engine_version,
+                     submitted_query = EXCLUDED.submitted_query,
+                     async_ = EXCLUDED.async_,
+                     executed_query = EXCLUDED.executed_query,
+                     scheduled = EXCLUDED.scheduled,
+                     started = EXCLUDED.started,
+                     finished = EXCLUDED.finished,
+                     state = EXCLUDED.state,
+                     progress = EXCLUDED.progress
+                 RETURNING *
+                 """,
+             ).format(
+                 query_id=sql.Literal(query_id),
+                 catalog_name=sql.Literal(catalog_name),
+                 engine_name=sql.Literal(engine_name),
+                 engine_version=sql.Literal(engine_version),
+                 submitted_query=sql.Literal(submitted_query),
+                 async_=sql.Literal(async_),
+                 executed_query=sql.Literal(executed_query),
+                 scheduled=sql.Literal(scheduled),
+                 started=sql.Literal(started),
+                 finished=sql.Literal(finished),
+                 state=sql.Literal(state),
+                 progress=sql.Literal(progress),
+             ),
+         )
+         return self
+
+     async def execute(self, conn):
+         """
+         Submit all statements to the backend DB, multiple statements are submitted together
+         """
+         if not self.selects and not self.inserts:  # pragma: no cover
+             return
+
+         async with conn.cursor() as cur:
+             results = []
+             if len(self.inserts) > 1:  # pragma: no cover
+                 async with conn.transaction():
+                     for statement in self.inserts:
+                         await cur.execute(statement)
+                         results.append(await cur.fetchall())
+
+             if len(self.inserts) == 1:
+                 await cur.execute(self.inserts[0])
+                 if cur.rowcount == 0:  # pragma: no cover
+                     raise DJDatabaseError(
+                         "Insert statement resulted in no records being inserted",
+                     )
+                 results.append((await cur.fetchone()))
+             if self.selects:
+                 for statement in self.selects:
+                     await cur.execute(statement)
+                     results.append(await cur.fetchall())
+             await conn.commit()
+             self._reset()
+             return results
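
`DBQuery` is a small statement builder: `save_query` and `get_query` queue composed statements and return `self` for chaining, and `execute` submits everything on one connection, returning one result set per statement. A minimal standalone sketch, assuming a reachable Postgres that already has the `query` table from djqs's metadata schema; the DSN is hypothetical:

    import asyncio
    from uuid import uuid4

    from psycopg_pool import AsyncConnectionPool

    from djqs.db.postgres import DBQuery


    async def main():
        pool = AsyncConnectionPool(
            "postgresql://dj:dj@localhost:5432/djqs",  # hypothetical DSN
            open=False,
        )
        await pool.open()
        async with pool.connection() as conn:
            query_id = uuid4()
            # Queue an upsert and a lookup, then submit both together.
            results = (
                await DBQuery()
                .save_query(
                    query_id=query_id,
                    submitted_query="SELECT 1",
                    state="FINISHED",
                )
                .get_query(query_id)
                .execute(conn=conn)
            )
            print(results)  # [inserted row, rows matching the SELECT]
        await pool.close()


    asyncio.run(main())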
djqs/engine.py CHANGED
@@ -2,28 +2,31 @@
  Query related functions.
  """

+ import json
  import logging
  import os
- from datetime import datetime, timezone
- from typing import List, Tuple
+ from dataclasses import asdict
+ from datetime import date, datetime, timezone
+ from typing import Dict, List, Optional, Tuple

  import duckdb
  import snowflake.connector
- import sqlparse
+ from psycopg_pool import AsyncConnectionPool
  from sqlalchemy import create_engine, text
- from sqlmodel import Session, select

- from djqs.config import Settings
- from djqs.models.engine import Engine, EngineType
+ from djqs.config import EngineType, Settings
+ from djqs.constants import SQLALCHEMY_URI
+ from djqs.db.postgres import DBQuery
+ from djqs.exceptions import DJDatabaseError
  from djqs.models.query import (
      ColumnMetadata,
      Query,
      QueryResults,
      QueryState,
-     Results,
      StatementResults,
  )
  from djqs.typing import ColumnType, Description, SQLADialect, Stream, TypeEnum
+ from djqs.utils import get_settings

  _logger = logging.getLogger(__name__)

@@ -66,9 +69,9 @@ def get_columns_from_description(
      return columns


- def run_query(
-     session: Session,
+ def run_query(  # pylint: disable=R0914
      query: Query,
+     headers: Optional[Dict[str, str]] = None,
  ) -> List[Tuple[str, List[ColumnMetadata], Stream]]:
      """
      Run a query and return its results.
@@ -76,13 +79,26 @@ def run_query(
      For each statement we return a tuple with the statement SQL, a description of the
      columns (name and type) and a stream of rows (tuples).
      """
+
      _logger.info("Running query on catalog %s", query.catalog_name)
-     engine = session.exec(
-         select(Engine)
-         .where(Engine.name == query.engine_name)
-         .where(Engine.version == query.engine_version),
-     ).one()
-     if engine.type == EngineType.DUCKDB:
+
+     settings = get_settings()
+     engine_name = query.engine_name or settings.default_engine
+     engine_version = query.engine_version
+     engine = settings.find_engine(
+         engine_name=engine_name,
+         engine_version=engine_version,
+     )
+     query_server = headers.get(SQLALCHEMY_URI) if headers else None
+
+     if query_server:
+         _logger.info(
+             "Creating sqlalchemy engine using request header param %s",
+             SQLALCHEMY_URI,
+         )
+         sqla_engine = create_engine(query_server)
+     elif engine.type == EngineType.DUCKDB:
+         _logger.info("Creating duckdb connection")
          conn = (
              duckdb.connect()
              if engine.uri == "duckdb:///:memory:"
@@ -92,7 +108,8 @@ def run_query(
              )
          )
          return run_duckdb_query(query, conn)
-     if engine.type == EngineType.SNOWFLAKE:
+     elif engine.type == EngineType.SNOWFLAKE:
+         _logger.info("Creating snowflake connection")
          conn = snowflake.connector.connect(
              **engine.extra_params,
              password=os.getenv("SNOWSQL_PWD"),
@@ -100,22 +117,21 @@
          cur = conn.cursor()

          return run_snowflake_query(query, cur)
+
+     _logger.info(
+         "Creating sqlalchemy engine using engine name and version defined on query",
+     )
      sqla_engine = create_engine(engine.uri, connect_args=engine.extra_params)
      connection = sqla_engine.connect()

      output: List[Tuple[str, List[ColumnMetadata], Stream]] = []
-     statements = sqlparse.parse(query.executed_query)
-     for statement in statements:
-         # Druid doesn't like statements that end in a semicolon...
-         sql = str(statement).strip().rstrip(";")
-
-         results = connection.execute(text(sql))
-         stream = (tuple(row) for row in results)
-         columns = get_columns_from_description(
-             results.cursor.description,
-             sqla_engine.dialect,
-         )
-         output.append((sql, columns, stream))
+     results = connection.execute(text(query.executed_query))
+     stream = (tuple(row) for row in results)
+     columns = get_columns_from_description(
+         results.cursor.description,
+         sqla_engine.dialect,
+     )
+     output.append((query.executed_query, columns, stream))  # type: ignore

      return output

@@ -148,10 +164,24 @@ def run_snowflake_query(
      return output


- def process_query(
-     session: Session,
+ def serialize_for_json(obj):
+     """
+     Handle serialization of date/datetimes for JSON output.
+     """
+     if isinstance(obj, list):
+         return [serialize_for_json(x) for x in obj]
+     if isinstance(obj, dict):
+         return {k: serialize_for_json(v) for k, v in obj.items()}
+     if isinstance(obj, (date, datetime)):
+         return obj.isoformat()
+     return obj
+
+
+ async def process_query(
      settings: Settings,
+     postgres_pool: AsyncConnectionPool,
      query: Query,
+     headers: Optional[Dict[str, str]] = None,
  ) -> QueryResults:
      """
      Process a query.
@@ -163,10 +193,13 @@
      errors = []
      query.started = datetime.now(timezone.utc)
      try:
-         root = []
-         for sql, columns, stream in run_query(session=session, query=query):
+         results = []
+         for sql, columns, stream in run_query(
+             query=query,
+             headers=headers,
+         ):
              rows = list(stream)
-             root.append(
+             results.append(
                  StatementResults(
                      sql=sql,
                      columns=columns,
@@ -174,21 +207,52 @@
                      row_count=len(rows),
                  ),
              )
-         results = Results(__root__=root)

          query.state = QueryState.FINISHED
          query.progress = 1.0
      except Exception as ex:  # pylint: disable=broad-except
-         results = Results(__root__=[])
+         results = []
          query.state = QueryState.FAILED
          errors = [str(ex)]

      query.finished = datetime.now(timezone.utc)

-     session.add(query)
-     session.commit()
-     session.refresh(query)
+     async with postgres_pool.connection() as conn:
+         dbquery_results = (
+             await DBQuery()
+             .save_query(
+                 query_id=query.id,
+                 submitted_query=query.submitted_query,
+                 state=QueryState.FINISHED.value,
+                 async_=query.async_,
+             )
+             .execute(conn=conn)
+         )
+         query_save_result = dbquery_results[0]
+         if not query_save_result:  # pragma: no cover
+             raise DJDatabaseError("Query failed to save")

-     settings.results_backend.add(str(query.id), results.json())
+     settings.results_backend.add(
+         str(query.id),
+         json.dumps(
+             serialize_for_json(
+                 [asdict(statement_result) for statement_result in results],
+             ),
+         ),
+     )

-     return QueryResults(results=results, errors=errors, **query.dict())
+     return QueryResults(
+         id=query.id,
+         catalog_name=query.catalog_name,
+         engine_name=query.engine_name,
+         engine_version=query.engine_version,
+         submitted_query=query.submitted_query,
+         executed_query=query.executed_query,
+         scheduled=query.scheduled,
+         started=query.started,
+         finished=query.finished,
+         state=query.state,
+         progress=query.progress,
+         results=results,
+         errors=errors,
+     )
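
`process_query` is now a coroutine: it persists query metadata through the Postgres pool via `DBQuery` and stores results in the results backend as plain JSON rather than a pydantic `Results` model. `serialize_for_json` exists because `json.dumps` cannot encode the `date`/`datetime` values that may appear in result rows; it recursively walks lists and dicts, converting those values to ISO-8601 strings. A small illustration of the helper, assuming djqs 0.0.28 and its dependencies are installed:

    import json
    from datetime import date, datetime

    from djqs.engine import serialize_for_json

    payload = {
        "columns": ["day", "created"],
        "rows": [[date(2024, 1, 1), datetime(2024, 1, 1, 12, 30)]],
    }
    print(json.dumps(serialize_for_json(payload)))
    # {"columns": ["day", "created"], "rows": [["2024-01-01", "2024-01-01T12:30:00"]]}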
djqs/enum.py CHANGED
@@ -1,6 +1,7 @@
  """
  Backwards-compatible StrEnum for both Python >= and < 3.11
  """
+
  import enum
  import sys