datajunction-query 0.0.1a65__py3-none-any.whl → 0.0.1a67__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of datajunction-query might be problematic.

djqs/config.py CHANGED

@@ -3,94 +3,197 @@ Configuration for the query service
 """
 
 import json
+import logging
+import os
 from datetime import timedelta
-from typing import Optional
+from enum import Enum
+from typing import Dict, List, Optional
 
 import toml
 import yaml
 from cachelib.base import BaseCache
 from cachelib.file import FileSystemCache
-from pydantic import BaseSettings
-from sqlmodel import Session, delete, select
 
-from djqs.exceptions import DJException
-from djqs.models.catalog import Catalog, CatalogEngines
-from djqs.models.engine import Engine
+from djqs.exceptions import DJUnknownCatalog, DJUnknownEngine
 
+_logger = logging.getLogger(__name__)
 
-class Settings(BaseSettings):  # pylint: disable=too-few-public-methods
+
+class EngineType(Enum):
     """
-    Configuration for the query service
+    Supported engine types
     """
 
-    name: str = "DJQS"
-    description: str = "A DataJunction Query Service"
-    url: str = "http://localhost:8001/"
+    DUCKDB = "duckdb"
+    SQLALCHEMY = "sqlalchemy"
+    SNOWFLAKE = "snowflake"
+    TRINO = "trino"
 
-    # SQLAlchemy URI for the metadata database.
-    index: str = "sqlite:///djqs.db?check_same_thread=False"
 
-    # The default engine to use for reflection
-    default_reflection_engine: str = "default"
+class EngineInfo:  # pylint: disable=too-few-public-methods
+    """
+    Information about a query engine
+    """
 
-    # The default engine version to use for reflection
-    default_reflection_engine_version: str = ""
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        name: str,
+        version: str,
+        type: str,  # pylint: disable=redefined-builtin
+        uri: str,
+        extra_params: Optional[Dict[str, str]] = None,
+    ):
+        self.name = name
+        self.version = str(version)
+        self.type = EngineType(type)
+        self.uri = uri
+        self.extra_params = extra_params or {}
+
+
+class CatalogInfo:  # pylint: disable=too-few-public-methods
+    """
+    Information about a catalog
+    """
 
-    # Where to store the results from queries.
-    results_backend: BaseCache = FileSystemCache("/tmp/djqs", default_timeout=0)
+    def __init__(self, name: str, engines: List[str]):
+        self.name = name
+        self.engines = engines
 
-    paginating_timeout: timedelta = timedelta(minutes=5)
 
-    # How long to wait when pinging databases to find out the fastest online database.
-    do_ping_timeout: timedelta = timedelta(seconds=5)
+class Settings:  # pylint: disable=too-many-instance-attributes
+    """
+    Configuration for the query service
+    """
 
-    # Configuration file for catalogs and engines
-    configuration_file: Optional[str] = None
+    def __init__(  # pylint: disable=too-many-arguments,too-many-locals,dangerous-default-value
+        self,
+        name: Optional[str] = "DJQS",
+        description: Optional[str] = "A DataJunction Query Service",
+        url: Optional[str] = "http://localhost:8001/",
+        index: Optional[str] = "postgresql://dj:dj@postgres_metadata:5432/dj",
+        default_catalog: Optional[str] = "",
+        default_engine: Optional[str] = "",
+        default_engine_version: Optional[str] = "",
+        results_backend: Optional[BaseCache] = None,
+        results_backend_path: Optional[str] = "/tmp/djqs",
+        results_backend_timeout: Optional[str] = "0",
+        paginating_timeout_minutes: Optional[str] = "5",
+        do_ping_timeout_seconds: Optional[str] = "5",
+        configuration_file: Optional[str] = None,
+        engines: Optional[List[EngineInfo]] = None,
+        catalogs: Optional[List[CatalogInfo]] = None,
+    ):
+        self.name: str = os.getenv("NAME", name or "")
+        self.description: str = os.getenv("DESCRIPTION", description or "")
+        self.url: str = os.getenv("URL", url or "")
+
+        # SQLAlchemy URI for the metadata database.
+        self.index: str = os.getenv("INDEX", index or "")
+
+        # The default catalog to use if not specified in query payload
+        self.default_catalog: str = os.getenv("DEFAULT_CATALOG", default_catalog or "")
+
+        # The default engine to use if not specified in query payload
+        self.default_engine: str = os.getenv("DEFAULT_ENGINE", default_engine or "")
+
+        # The default engine version to use if not specified in query payload
+        self.default_engine_version: str = os.getenv(
+            "DEFAULT_ENGINE_VERSION",
+            default_engine_version or "",
+        )
 
-    # Enable setting catalog and engine config via REST API calls
-    enable_dynamic_config: bool = True
+        # Where to store the results from queries.
+        self.results_backend: BaseCache = results_backend or FileSystemCache(
+            os.getenv("RESULTS_BACKEND_PATH", results_backend_path or ""),
+            default_timeout=int(
+                os.getenv("RESULTS_BACKEND_TIMEOUT", results_backend_timeout or "0"),
+            ),
+        )
 
+        self.paginating_timeout: timedelta = timedelta(
+            minutes=int(
+                os.getenv(
+                    "PAGINATING_TIMEOUT_MINUTES",
+                    paginating_timeout_minutes or "5",
+                ),
+            ),
+        )
 
-def load_djqs_config(settings: Settings, session: Session) -> None:  # pragma: no cover
-    """
-    Load the DJQS config file into the server metadata database
-    """
-    config_file = settings.configuration_file if settings.configuration_file else None
-    if not config_file:
-        return
-
-    session.exec(delete(Catalog))
-    session.exec(delete(Engine))
-    session.exec(delete(CatalogEngines))
-    session.commit()
-
-    with open(config_file, mode="r", encoding="utf-8") as filestream:
-
-        def unknown_filetype():
-            raise DJException(message=f"Unknown config file type: {config_file}")
-
-        data = (
-            yaml.safe_load(filestream)
-            if any([config_file.endswith(".yml"), config_file.endswith(".yaml")])
-            else toml.load(filestream)
-            if config_file.endswith(".toml")
-            else json.load(filestream)
-            if config_file.endswith(".json")
-            else unknown_filetype()
+        # How long to wait when pinging databases to find out the fastest online database.
+        self.do_ping_timeout: timedelta = timedelta(
+            seconds=int(
+                os.getenv("DO_PING_TIMEOUT_SECONDS", do_ping_timeout_seconds or "5"),
+            ),
         )
 
-    for engine in data["engines"]:
-        session.add(Engine.parse_obj(engine))
-    session.commit()
+        # Configuration file for catalogs and engines
+        self.configuration_file: Optional[str] = (
+            os.getenv("CONFIGURATION_FILE") or configuration_file
+        )
 
-    for catalog in data["catalogs"]:
-        attached_engines = []
-        catalog_engines = catalog.pop("engines")
-        for name in catalog_engines:
-            attached_engines.append(
-                session.exec(select(Engine).where(Engine.name == name)).one(),
+        self.engines: List[EngineInfo] = engines or []
+        self.catalogs: List[CatalogInfo] = catalogs or []
+
+        self._load_configuration()
+
+    def _load_configuration(self):
+        config_file = self.configuration_file
+
+        if config_file:
+            if config_file.endswith(".yaml") or config_file.endswith(".yml"):
+                with open(config_file, "r", encoding="utf-8") as file:
+                    config = yaml.safe_load(file)
+            elif config_file.endswith(".toml"):
+                with open(config_file, "r", encoding="utf-8") as file:
+                    config = toml.load(file)
+            elif config_file.endswith(".json"):
+                with open(config_file, "r", encoding="utf-8") as file:
+                    config = json.load(file)
+            else:
+                raise ValueError(
+                    f"Unsupported configuration file format: {config_file}",
+                )
+
+            self.engines = [
+                EngineInfo(**engine) for engine in config.get("engines", [])
+            ]
+            self.catalogs = [
+                CatalogInfo(**catalog) for catalog in config.get("catalogs", [])
+            ]
+        else:
+            _logger.warning("No settings configuration file has been set")
+
+    def find_engine(
+        self,
+        engine_name: str,
+        engine_version: str,
+    ) -> EngineInfo:
+        """
+        Find an engine defined in the server configuration
+        """
+        found_engine = None
+        for engine in self.engines:
+            if engine.name == engine_name and engine.version == engine_version:
+                found_engine = engine
+        if not found_engine:
+            raise DJUnknownEngine(
+                (
+                    f"Configuration error, cannot find engine {engine_name} "
+                    f"with version {engine_version}"
+                ),
+            )
+        return found_engine
+
+    def find_catalog(self, catalog_name: str) -> CatalogInfo:
+        """
+        Find a catalog defined in the server configuration
+        """
+        found_catalog = None
+        for catalog in self.catalogs:
+            if catalog.name == catalog_name:
+                found_catalog = catalog
+        if not found_catalog:
+            raise DJUnknownCatalog(
+                f"Configuration error, cannot find catalog {catalog_name}",
             )
-        catalog_entry = Catalog.parse_obj(catalog)
-        catalog_entry.engines = attached_engines
-        session.add(catalog_entry)
-    session.commit()
+        return found_catalog
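
The new Settings no longer loads catalogs and engines into a metadata database; it reads them from a YAML, TOML, or JSON file at construction time. As a minimal sketch (illustrative, not part of this release), the following writes a JSON config of the shape `_load_configuration` expects and resolves an engine and catalog from it; the engine/catalog names, version, URI, and file path are all assumptions:

import json

from djqs.config import Settings

config = {
    "engines": [
        {
            "name": "duckdb_local",       # hypothetical engine name
            "version": "0.7.1",           # matched as a string by find_engine
            "type": "duckdb",             # must be a valid EngineType value
            "uri": "duckdb:///local.db",  # illustrative URI
        },
    ],
    "catalogs": [
        {"name": "warehouse", "engines": ["duckdb_local"]},  # hypothetical catalog
    ],
}
with open("/tmp/djqs_config.json", "w", encoding="utf-8") as fh:
    json.dump(config, fh)

settings = Settings(configuration_file="/tmp/djqs_config.json")
engine = settings.find_engine(engine_name="duckdb_local", engine_version="0.7.1")
catalog = settings.find_catalog("warehouse")  # raises DJUnknownCatalog if absent

Note that every constructor argument also has an environment-variable override (NAME, INDEX, CONFIGURATION_FILE, and so on), since each os.getenv call uses the argument only as a fallback.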
djqs/db/postgres.py ADDED

@@ -0,0 +1,138 @@
+"""
+Dependency for getting the postgres pool and running backend DB queries
+"""
+# pylint: disable=too-many-arguments
+from datetime import datetime
+from typing import List
+from uuid import UUID
+
+from fastapi import Request
+from psycopg import sql
+from psycopg_pool import AsyncConnectionPool
+
+from djqs.exceptions import DJDatabaseError
+
+
+async def get_postgres_pool(request: Request) -> AsyncConnectionPool:
+    """
+    Get the postgres pool from the app instance
+    """
+    app = request.app
+    return app.state.pool
+
+
+class DBQuery:
+    """
+    Metadata DB queries using the psycopg composition utility
+    """
+
+    def __init__(self):
+        self._reset()
+
+    def _reset(self):
+        self.selects: List = []
+        self.inserts: List = []
+
+    def get_query(self, query_id: UUID):
+        """
+        Get metadata about a query
+        """
+        self.selects.append(
+            sql.SQL(
+                """
+                SELECT id, catalog_name, engine_name, engine_version, submitted_query,
+                       async_, executed_query, scheduled, started, finished, state, progress
+                FROM query
+                WHERE id = {query_id}
+                """,
+            ).format(query_id=sql.Literal(query_id)),
+        )
+        return self
+
+    def save_query(
+        self,
+        query_id: UUID,
+        catalog_name: str = "",
+        engine_name: str = "",
+        engine_version: str = "",
+        submitted_query: str = "",
+        async_: bool = False,
+        state: str = "",
+        progress: float = 0.0,
+        executed_query: str = None,
+        scheduled: datetime = None,
+        started: datetime = None,
+        finished: datetime = None,
+    ):
+        """
+        Save metadata about a query
+        """
+        self.inserts.append(
+            sql.SQL(
+                """
+                INSERT INTO query (id, catalog_name, engine_name, engine_version,
+                                   submitted_query, async_, executed_query, scheduled,
+                                   started, finished, state, progress)
+                VALUES ({query_id}, {catalog_name}, {engine_name}, {engine_version},
+                        {submitted_query}, {async_}, {executed_query}, {scheduled},
+                        {started}, {finished}, {state}, {progress})
+                ON CONFLICT (id) DO UPDATE SET
+                    catalog_name = EXCLUDED.catalog_name,
+                    engine_name = EXCLUDED.engine_name,
+                    engine_version = EXCLUDED.engine_version,
+                    submitted_query = EXCLUDED.submitted_query,
+                    async_ = EXCLUDED.async_,
+                    executed_query = EXCLUDED.executed_query,
+                    scheduled = EXCLUDED.scheduled,
+                    started = EXCLUDED.started,
+                    finished = EXCLUDED.finished,
+                    state = EXCLUDED.state,
+                    progress = EXCLUDED.progress
+                RETURNING *
+                """,
+            ).format(
+                query_id=sql.Literal(query_id),
+                catalog_name=sql.Literal(catalog_name),
+                engine_name=sql.Literal(engine_name),
+                engine_version=sql.Literal(engine_version),
+                submitted_query=sql.Literal(submitted_query),
+                async_=sql.Literal(async_),
+                executed_query=sql.Literal(executed_query),
+                scheduled=sql.Literal(scheduled),
+                started=sql.Literal(started),
+                finished=sql.Literal(finished),
+                state=sql.Literal(state),
+                progress=sql.Literal(progress),
+            ),
+        )
+        return self
+
+    async def execute(self, conn):
+        """
+        Submit all statements to the backend DB, multiple statements are submitted together
+        """
+        if not self.selects and not self.inserts:  # pragma: no cover
+            return
+
+        async with conn.cursor() as cur:
+            results = []
+            if len(self.inserts) > 1:  # pragma: no cover
+                async with conn.transaction():
+                    for statement in self.inserts:
+                        await cur.execute(statement)
+                        results.append(await cur.fetchall())
+
+            if len(self.inserts) == 1:
+                await cur.execute(self.inserts[0])
+                if cur.rowcount == 0:  # pragma: no cover
+                    raise DJDatabaseError(
+                        "Insert statement resulted in no records being inserted",
+                    )
+                results.append((await cur.fetchone()))
+            if self.selects:
+                for statement in self.selects:
+                    await cur.execute(statement)
+                    results.append(await cur.fetchall())
+            await conn.commit()
+        self._reset()
+        return results
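
The DBQuery builder only queues composed statements; nothing is sent until execute() is awaited, which returns one list entry per statement. A minimal usage sketch, assuming a reachable Postgres with the query table already created; the DSN and query values are illustrative:

import asyncio
from uuid import uuid4

from psycopg_pool import AsyncConnectionPool

from djqs.db.postgres import DBQuery


async def main():
    query_id = uuid4()  # illustrative id
    async with AsyncConnectionPool("postgresql://dj:dj@localhost:5432/dj") as pool:
        async with pool.connection() as conn:
            # queue an upsert, then run it; execute() returns the RETURNING row
            saved = await (
                DBQuery()
                .save_query(
                    query_id=query_id,
                    submitted_query="SELECT 1",
                    state="FINISHED",
                )
                .execute(conn=conn)
            )
        async with pool.connection() as conn:
            # queue a SELECT for the same row, then run it
            rows = await DBQuery().get_query(query_id).execute(conn=conn)
    print(saved[0], rows[0])


asyncio.run(main())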
djqs/engine.py CHANGED

@@ -1,29 +1,31 @@
 """
 Query related functions.
 """
+import json
 import logging
 import os
+from dataclasses import asdict
 from datetime import datetime, timezone
 from typing import Dict, List, Optional, Tuple
 
 import duckdb
 import snowflake.connector
-import sqlparse
+from psycopg_pool import AsyncConnectionPool
 from sqlalchemy import create_engine, text
-from sqlmodel import Session, select
 
-from djqs.config import Settings
+from djqs.config import EngineType, Settings
 from djqs.constants import SQLALCHEMY_URI
-from djqs.models.engine import Engine, EngineType
+from djqs.db.postgres import DBQuery
+from djqs.exceptions import DJDatabaseError
 from djqs.models.query import (
     ColumnMetadata,
     Query,
     QueryResults,
     QueryState,
-    Results,
     StatementResults,
 )
 from djqs.typing import ColumnType, Description, SQLADialect, Stream, TypeEnum
+from djqs.utils import get_settings
 
 _logger = logging.getLogger(__name__)
 
@@ -67,7 +69,6 @@ def get_columns_from_description(
 
 
 def run_query(  # pylint: disable=R0914
-    session: Session,
     query: Query,
     headers: Optional[Dict[str, str]] = None,
 ) -> List[Tuple[str, List[ColumnMetadata], Stream]]:
@@ -80,13 +81,14 @@ def run_query(  # pylint: disable=R0914
 
     _logger.info("Running query on catalog %s", query.catalog_name)
 
-    engine = session.exec(
-        select(Engine)
-        .where(Engine.name == query.engine_name)
-        .where(Engine.version == query.engine_version),
-    ).one()
-
-    query_server = headers.get("SQLALCHEMY_URI") if headers else None
+    settings = get_settings()
+    engine_name = query.engine_name or settings.default_engine
+    engine_version = query.engine_version or settings.default_engine_version
+    engine = settings.find_engine(
+        engine_name=engine_name,
+        engine_version=engine_version,
+    )
+    query_server = headers.get(SQLALCHEMY_URI) if headers else None
 
     if query_server:
         _logger.info(
@@ -122,18 +124,13 @@ def run_query(  # pylint: disable=R0914
     connection = sqla_engine.connect()
 
     output: List[Tuple[str, List[ColumnMetadata], Stream]] = []
-    statements = sqlparse.parse(query.executed_query)
-    for statement in statements:
-        # Druid doesn't like statements that end in a semicolon...
-        sql = str(statement).strip().rstrip(";")
-
-        results = connection.execute(text(sql))
-        stream = (tuple(row) for row in results)
-        columns = get_columns_from_description(
-            results.cursor.description,
-            sqla_engine.dialect,
-        )
-        output.append((sql, columns, stream))
+    results = connection.execute(text(query.executed_query))
+    stream = (tuple(row) for row in results)
+    columns = get_columns_from_description(
+        results.cursor.description,
+        sqla_engine.dialect,
+    )
+    output.append((query.executed_query, columns, stream))  # type: ignore
 
     return output
 
@@ -166,9 +163,9 @@ def run_snowflake_query(
     return output
 
 
-def process_query(
-    session: Session,
+async def process_query(
     settings: Settings,
+    postgres_pool: AsyncConnectionPool,
     query: Query,
     headers: Optional[Dict[str, str]] = None,
 ) -> QueryResults:
@@ -182,14 +179,13 @@ def process_query(
     errors = []
     query.started = datetime.now(timezone.utc)
     try:
-        root = []
+        results = []
         for sql, columns, stream in run_query(
-            session=session,
             query=query,
             headers=headers,
         ):
             rows = list(stream)
-            root.append(
+            results.append(
                 StatementResults(
                     sql=sql,
                     columns=columns,
@@ -197,21 +193,48 @@ def process_query(
                     row_count=len(rows),
                 ),
             )
-        results = Results(__root__=root)
 
        query.state = QueryState.FINISHED
         query.progress = 1.0
     except Exception as ex:  # pylint: disable=broad-except
-        results = Results(__root__=[])
+        results = []
        query.state = QueryState.FAILED
         errors = [str(ex)]
 
     query.finished = datetime.now(timezone.utc)
 
-    session.add(query)
-    session.commit()
-    session.refresh(query)
+    async with postgres_pool.connection() as conn:
+        dbquery_results = (
+            await DBQuery()
+            .save_query(
+                query_id=query.id,
+                submitted_query=query.submitted_query,
+                state=QueryState.FINISHED.value,
+                async_=query.async_,
+            )
+            .execute(conn=conn)
+        )
+        query_save_result = dbquery_results[0]
+        if not query_save_result:  # pragma: no cover
+            raise DJDatabaseError("Query failed to save")
 
-    settings.results_backend.add(str(query.id), results.json())
+    settings.results_backend.add(
+        str(query.id),
+        json.dumps([asdict(statement_result) for statement_result in results]),
+    )
 
-    return QueryResults(results=results, errors=errors, **query.dict())
+    return QueryResults(
+        id=query.id,
+        catalog_name=query.catalog_name,
+        engine_name=query.engine_name,
+        engine_version=query.engine_version,
+        submitted_query=query.submitted_query,
+        executed_query=query.executed_query,
+        scheduled=query.scheduled,
+        started=query.started,
+        finished=query.finished,
+        state=query.state,
+        progress=query.progress,
+        results=results,
+        errors=errors,
+    )
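
Since process_query is now a coroutine that takes the connection pool explicitly, callers must await it and supply both the settings and the pool. A minimal sketch of wiring it into a FastAPI route; the route path and dependency wiring here are illustrative, not taken from this diff:

from fastapi import Depends, FastAPI
from psycopg_pool import AsyncConnectionPool

from djqs.config import Settings
from djqs.db.postgres import get_postgres_pool
from djqs.engine import process_query
from djqs.models.query import Query
from djqs.utils import get_settings

app = FastAPI()  # assumes app.state.pool is set at startup, as get_postgres_pool expects


@app.post("/queries/")  # hypothetical route
async def submit_query(
    query: Query,
    settings: Settings = Depends(get_settings),
    postgres_pool: AsyncConnectionPool = Depends(get_postgres_pool),
):
    # process_query persists metadata via DBQuery and caches the statement
    # results in settings.results_backend before returning QueryResults
    return await process_query(
        settings=settings,
        postgres_pool=postgres_pool,
        query=query,
        headers=None,
    )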