prestogres 0.1.0 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -170,10 +170,10 @@ class StatementClient(object):
170
170
  "User-Agent": "presto-python/%s" % VERSION
171
171
  }
172
172
 
173
- def __init__(self, http_client, session, query):
173
+ def __init__(self, http_client, query, **options):
174
174
  self.http_client = http_client
175
- self.session = session
176
175
  self.query = query
176
+ self.options = options
177
177
 
178
178
  self.closed = False
179
179
  self.exception = None
@@ -183,14 +183,14 @@ class StatementClient(object):
183
183
  def _post_query_request(self):
184
184
  headers = StatementClient.HEADERS.copy()
185
185
 
186
- if self.session.user is not None:
187
- headers[PrestoHeaders.PRESTO_USER] = self.session.user
188
- if self.session.source is not None:
189
- headers[PrestoHeaders.PRESTO_SOURCE] = self.session.source
190
- if self.session.catalog is not None:
191
- headers[PrestoHeaders.PRESTO_CATALOG] = self.session.catalog
192
- if self.session.schema is not None:
193
- headers[PrestoHeaders.PRESTO_SCHEMA] = self.session.schema
186
+ if self.options.get("user") is not None:
187
+ headers[PrestoHeaders.PRESTO_USER] = self.options["user"]
188
+ if self.options.get("source") is not None:
189
+ headers[PrestoHeaders.PRESTO_SOURCE] = self.options["source"]
190
+ if self.options.get("catalog") is not None:
191
+ headers[PrestoHeaders.PRESTO_CATALOG] = self.options["catalog"]
192
+ if self.options.get("schema") is not None:
193
+ headers[PrestoHeaders.PRESTO_SCHEMA] = self.options["schema"]
194
194
 
195
195
  self.http_client.request("POST", "/v1/statement", self.query, headers)
196
196
  response = self.http_client.getresponse()
@@ -265,9 +265,9 @@ class StatementClient(object):
265
265
 
266
266
  class Query(object):
267
267
  @classmethod
268
- def start(cls, session, query):
269
- http_client = httplib.HTTPConnection(session.server)
270
- return Query(StatementClient(http_client, session, query))
268
+ def start(cls, query, **options):
269
+ http_client = httplib.HTTPConnection(host=options["server"], timeout=options.get("http_timeout", 300))
270
+ return Query(StatementClient(http_client, query, **options))
271
271
 
272
272
  def __init__(self, client):
273
273
  self.client = client
@@ -327,13 +327,13 @@ class Query(object):
327
327
 
328
328
  class Client(object):
329
329
  def __init__(self, **options):
330
- self.session = ClientSession(**options)
330
+ self.options = options
331
331
 
332
332
  def query(self, query):
333
- return Query.start(self.session, query)
333
+ return Query.start(query, **self.options)
334
334
 
335
335
  def run(self, query):
336
- q = Query.start(self.session, query)
336
+ q = Query.start(query, **self.options)
337
337
  try:
338
338
  columns = q.columns()
339
339
  if columns is None:
data/pgsql/prestogres.py CHANGED
@@ -1,156 +1,290 @@
1
- import presto_client
2
1
  import plpy
2
+ import presto_client
3
3
  from collections import namedtuple
4
4
  import time
5
5
 
6
- def run_presto_as_temp_table(server, user, catalog, schema, table_name, query):
7
- client = presto_client.Client(server=server, user=user, catalog=catalog, schema=schema)
8
-
9
- create_sql = 'create temp table ' + plpy.quote_ident(table_name) + ' (\n '
10
- insert_sql = 'insert into ' + plpy.quote_ident(table_name) + ' (\n '
11
- values_types = []
12
-
13
- q = client.query(query)
6
+ # convert Presto query result type to PostgreSQL type
7
+ def _pg_result_type(presto_type):
8
+ if presto_type == "varchar":
9
+ return "text"
10
+ elif presto_type == "bigint":
11
+ return "bigint"
12
+ elif presto_type == "boolean":
13
+ return "boolean"
14
+ elif presto_type == "double":
15
+ return "double precision"
16
+ else:
17
+ raise Exception, "unknown result column type: " + plpy.quote_ident(presto_type)
18
+
19
+ # convert Presto type to PostgreSQL type
20
+ def _pg_table_type(presto_type):
21
+ if presto_type == "varchar":
22
+ return "varchar(100)"
23
+ elif presto_type == "bigint":
24
+ return "bigint"
25
+ elif presto_type == "boolean":
26
+ return "boolean"
27
+ elif presto_type == "double":
28
+ return "double precision"
29
+ else:
30
+ raise Exception("unknown table column type: " + plpy.quote_ident(presto_type))
31
+
32
+ # build CREATE TEMPORARY TABLE statement
33
+ def _build_create_temp_table_sql(table_name, column_names, column_types, not_nulls=None):
34
+ create_sql = "create temporary table %s (\n " % plpy.quote_ident(table_name)
35
+
36
+ first = True
37
+ for column_name, column_type in zip(column_names, column_types):
38
+ if first:
39
+ first = False
40
+ else:
41
+ create_sql += ",\n "
42
+
43
+ create_sql += plpy.quote_ident(column_name)
44
+ create_sql += " "
45
+ create_sql += column_type
46
+
47
+ # TODO not null
48
+ #if not column.nullable:
49
+ # create_sql += " not null"
50
+
51
+ create_sql += "\n)"
52
+ return create_sql
53
+
54
+ # build INSERT INTO statement and string format to build VALUES (..), ...
55
+ def _build_insert_into_sql(table_name, column_names):
56
+ insert_sql = "insert into %s (\n " % plpy.quote_ident(table_name)
57
+
58
+ first = True
59
+ for column_name in column_names:
60
+ if first:
61
+ first = False
62
+ else:
63
+ insert_sql += ",\n "
64
+
65
+ insert_sql += plpy.quote_ident(column_name)
66
+
67
+ insert_sql += "\n) values\n"
68
+
69
+ values_sql_format = "(%s)" % (", ".join(["${}"] * len(column_names)))
70
+
71
+ return (insert_sql, values_sql_format)
72
+
73
+ # create a prepared statement for batch INSERT
74
+ def _plan_batch(insert_sql, values_sql_format, column_types, batch_size):
75
+ # format string 'values ($1, $2), ($3, $4) ...'
76
+ values_sql = (", ".join([values_sql_format] * batch_size)).format(*range(1, batch_size * len(column_types) + 1))
77
+ batch_insert_sql = insert_sql + values_sql
78
+ return plpy.prepare(batch_insert_sql, column_types * batch_size)
79
+
80
+ # run batch INSERT
81
+ def _batch_insert(insert_sql, values_sql_format, batch_size, column_types, rows):
82
+ full_batch_plan = None
83
+
84
+ batch = []
85
+ for row in rows:
86
+ batch.append(row)
87
+ batch_len = len(batch)
88
+ if batch_len >= batch_size:
89
+ if full_batch_plan is None:
90
+ full_batch_plan = _plan_batch(insert_sql, values_sql_format, column_types, batch_len)
91
+ plpy.execute(full_batch_plan, [item for sublist in batch for item in sublist])
92
+ del batch[:]
93
+
94
+ if batch:
95
+ plan = _plan_batch(insert_sql, values_sql_format, column_types, len(batch))
96
+ plpy.execute(plan, [item for sublist in batch for item in sublist])
97
+
98
+ class SchemaCache(object):
99
+ def __init__(self):
100
+ self.server = None
101
+ self.user = None
102
+ self.catalog = None
103
+ self.schema_names = None
104
+ self.statements = None
105
+ self.expire_time = None
106
+
107
+ def is_cached(self, server, user, catalog, current_time):
108
+ return self.server == server and self.user == user and self.catalog == catalog \
109
+ and self.statements is not None and current_time < self.expire_time
110
+
111
+ def set_cache(self, server, user, catalog, schema_names, statements, expire_time):
112
+ self.server = server
113
+ self.user = user
114
+ self.catalog = catalog
115
+ self.schema_names = schema_names
116
+ self.statements = statements
117
+ self.expire_time = expire_time
118
+
119
+ OidToTypeNameMapping = {}
120
+
121
+ def _load_oid_to_type_name_mapping(oids):
122
+ oids = filter(lambda oid: oid not in OidToTypeNameMapping, oids)
123
+ if oids:
124
+ sql = ("select oid, typname" \
125
+ " from pg_catalog.pg_type" \
126
+ " where oid in (%s)") % (", ".join(map(str, oids)))
127
+ for row in plpy.execute(sql):
128
+ OidToTypeNameMapping[int(row["oid"])] = row["typname"]
129
+
130
+ return OidToTypeNameMapping
131
+
132
+ Column = namedtuple("Column", ("name", "type", "nullable"))
133
+
134
+ SchemaCacheEntry = SchemaCache()
135
+
136
+ def run_presto_as_temp_table(server, user, catalog, schema, result_table, query):
14
137
  try:
15
- columns = q.columns()
16
- column_num = len(columns)
17
-
18
- first = True
19
- for column in columns:
20
- if column.type == "varchar":
21
- pg_column_type = "text"
22
- elif column.type == "bigint":
23
- pg_column_type = "bigint"
24
- elif column.type == "boolean":
25
- pg_column_type = "boolean"
26
- elif column.type == "double":
27
- pg_column_type = "double precision"
28
- else:
29
- raise Exception, "unknown column type: " + plpy.quote_ident(column.type)
30
-
31
- if first:
32
- first = False
33
- else:
34
- create_sql += ",\n "
35
- insert_sql += ",\n "
36
-
37
- create_sql += plpy.quote_ident(column.name) + ' ' + pg_column_type
38
- insert_sql += plpy.quote_ident(column.name)
39
- values_types.append(pg_column_type)
40
-
41
- create_sql += '\n)'
42
- #if trait:
43
- # create_sql += ' '
44
- # create_sql += trait
45
- create_sql += ';'
46
-
47
- insert_sql += '\n) values\n'
48
- values_sql_format = '(' + ', '.join(['${}'] * column_num) + ')'
49
-
50
- #plpy.execute("drop table if exists "+plpy.quote_ident(table_name))
138
+ client = presto_client.Client(server=server, user=user, catalog=catalog, schema=schema)
139
+
140
+ create_sql = "create temporary table %s (\n " % plpy.quote_ident(result_table)
141
+ insert_sql = "insert into %s (\n " % plpy.quote_ident(result_table)
142
+ values_types = []
143
+
144
+ q = client.query(query)
145
+ try:
146
+ # result schema
147
+ column_names = []
148
+ column_types = []
149
+ for column in q.columns():
150
+ column_names.append(column.name)
151
+ column_types.append(_pg_result_type(column.type))
152
+
153
+ # build SQL
154
+ create_sql = _build_create_temp_table_sql(result_table, column_names, column_types)
155
+ insert_sql, values_sql_format = _build_insert_into_sql(result_table, column_names)
156
+
157
+ # run CREATE TABLE
158
+ plpy.execute("drop table if exists " + plpy.quote_ident(result_table))
159
+ plpy.execute(create_sql)
160
+
161
+ # run INSERT
162
+ _batch_insert(insert_sql, values_sql_format, 10, column_types, q.results())
163
+ finally:
164
+ q.close()
165
+
166
+ except (plpy.SPIError, presto_client.PrestoException) as e:
167
+ # PL/Python converts an exception object in Python to an error message in PostgreSQL
168
+ # using exception class name if exc.__module__ is either of "builtins", "exceptions",
169
+ # or "__main__". Otherwise using "module.name" format. Set __module__ = "__module__"
170
+ # to generate pretty messages.
171
+ e.__class__.__module__ = "__main__"
172
+ raise
173
+
174
+ def run_system_catalog_as_temp_table(server, user, catalog, result_table, query):
175
+ try:
176
+ client = presto_client.Client(server=server, user=user, catalog=catalog, schema="default")
177
+
178
+ # create SQL statements which put data to system catalogs
179
+ if SchemaCacheEntry.is_cached(server, user, catalog, time.time()):
180
+ schema_names = SchemaCacheEntry.schema_names
181
+ statements = SchemaCacheEntry.statements
182
+
183
+ else:
184
+ # get table list
185
+ sql = "select table_schema, table_name, column_name, is_nullable, data_type" \
186
+ " from information_schema.columns"
187
+ columns, rows = client.run(sql)
188
+
189
+ schemas = {}
190
+
191
+ if rows is None:
192
+ rows = []
193
+
194
+ for row in rows:
195
+ schema_name = row[0]
196
+ table_name = row[1]
197
+ column_name = row[2]
198
+ is_nullable = row[3]
199
+ column_type = row[4]
200
+
201
+ tables = schemas.setdefault(schema_name, {})
202
+ columns = tables.setdefault(table_name, [])
203
+ columns.append(Column(column_name, column_type, is_nullable))
204
+
205
+ # generate SQL statements
206
+ statements = []
207
+ schema_names = []
208
+
209
+ for schema_name, tables in schemas.items():
210
+ if schema_name == "sys" or schema_name == "information_schema":
211
+ # skip system schemas
212
+ continue
213
+
214
+ schema_names.append(schema_name)
215
+
216
+ for table_name, columns in tables.items():
217
+ # table schema
218
+ column_names = []
219
+ column_types = []
220
+ not_nulls = []
221
+ for column in columns:
222
+ column_names.append(column.name)
223
+ column_types.append(_pg_table_type(column.type))
224
+ not_nulls.append(not column.nullable)
225
+
226
+ create_sql = _build_create_temp_table_sql(table_name, column_names, column_types, not_nulls)
227
+ statements.append(create_sql)
228
+
229
+ # cache expires after 10 seconds
230
+ SchemaCacheEntry.set_cache(server, user, catalog, schema_names, statements, time.time() + 10)
231
+
232
+ # enter subtransaction to rollback tables right after running the query
233
+ subxact = plpy.subtransaction()
234
+ subxact.enter()
235
+ try:
236
+ # delete all schemas excepting prestogres_catalog
237
+ sql = "select n.nspname as schema_name from pg_catalog.pg_namespace n" \
238
+ " where n.nspname not in ('prestogres_catalog', 'pg_catalog', 'information_schema', 'public')" \
239
+ " and n.nspname !~ '^pg_toast'"
240
+ for row in plpy.cursor(sql):
241
+ plpy.execute("drop schema %s cascade" % plpy.quote_ident(row["schema_name"]))
242
+
243
+ # delete all tables in prestogres_catalog
244
+ # relkind = 'r' takes only tables and skip views, indexes, etc.
245
+ sql = "select n.nspname as schema_name, c.relname as table_name from pg_catalog.pg_class c" \
246
+ " left join pg_catalog.pg_namespace n on n.oid = c.relnamespace" \
247
+ " where c.relkind in ('r')" \
248
+ " and n.nspname in ('prestogres_catalog')" \
249
+ " and n.nspname !~ '^pg_toast'"
250
+ for row in plpy.cursor(sql):
251
+ plpy.execute("drop table %s.%s" % (plpy.quote_ident(row["schema_name"]), plpy.quote_ident(row["table_name"])))
252
+
253
+ # create schemas
254
+ for schema_name in schema_names:
255
+ try:
256
+ plpy.execute("create schema %s" % plpy.quote_ident(schema_name))
257
+ except:
258
+ # ignore error
259
+ pass
260
+
261
+ # create tables
262
+ for statement in statements:
263
+ plpy.execute(statement)
264
+
265
+ # run the actual query
266
+ metadata = plpy.execute(query)
267
+ result = map(lambda row: row.values(), metadata)
268
+
269
+ finally:
270
+ # rollback subtransaction
271
+ subxact.exit("rollback subtransaction", None, None)
272
+
273
+ # table schema
274
+ oid_to_type_name = _load_oid_to_type_name_mapping(metadata.coltypes())
275
+ column_names = metadata.colnames()
276
+ column_types = map(oid_to_type_name.get, metadata.coltypes())
277
+
278
+ create_sql = _build_create_temp_table_sql(result_table, column_names, column_types)
279
+ insert_sql, values_sql_format = _build_insert_into_sql(result_table, column_names)
280
+
281
+ # run CREATE TABLE and INSERT
282
+ plpy.execute("drop table if exists " + plpy.quote_ident(result_table))
51
283
  plpy.execute(create_sql)
284
+ _batch_insert(insert_sql, values_sql_format, 10, column_types, result)
52
285
 
53
- batch = []
54
- for row in q.results():
55
- batch.append(row)
56
- if len(batch) > 10:
57
- batch_len = len(batch)
58
- # format string 'values ($1, $2), ($3, $4) ...'
59
- values_sql = (', '.join([values_sql_format] * batch_len)).format(*range(1, batch_len * column_num + 1))
60
- batch_insert_sql = insert_sql + values_sql
61
- # flatten rows into an array
62
- params = [item for sublist in batch for item in sublist]
63
- plpy.execute(plpy.prepare(batch_insert_sql, values_types * batch_len), params)
64
- del batch[:]
65
-
66
- if batch:
67
- batch_len = len(batch)
68
- # format string 'values ($1, $2), ($3, $4) ...'
69
- values_sql = (', '.join([values_sql_format] * batch_len)).format(*range(1, batch_len * column_num + 1))
70
- batch_insert_sql = insert_sql + values_sql
71
- # flatten rows into an array
72
- params = [item for sublist in batch for item in sublist]
73
- plpy.execute(plpy.prepare(batch_insert_sql, values_types * batch_len), params)
74
-
75
- finally:
76
- q.close()
77
-
78
- Column = namedtuple('Column', ('name', 'type', 'nullable'))
79
-
80
- cache_expire_times = {}
81
-
82
- def presto_create_tables(server, user, catalog):
83
- client = presto_client.Client(server=server, user=user, catalog=catalog, schema="default")
84
-
85
- cache_key = "%s:%s.%s" % (server, user, catalog)
86
- expire_time = cache_expire_times.get(cache_key)
87
- if expire_time is not None and time.time() - expire_time < 10:
88
- # TODO scan cache_expire_times and remove expired cache entries if it is large
89
- return
90
-
91
- try:
92
- schemas = {}
93
-
94
- columns, rows = client.run("select table_schema, table_name, column_name, is_nullable, data_type from information_schema.columns")
95
-
96
- if rows is None:
97
- return
98
-
99
- for row in rows:
100
- schema_name = row[0]
101
- table_name = row[1]
102
- column_name = row[2]
103
- is_nullable = row[3]
104
- column_type = row[4]
105
-
106
- tables = schemas.setdefault(schema_name, {})
107
- columns = tables.setdefault(table_name, [])
108
- columns.append(Column(column_name, column_type, is_nullable))
109
-
110
- for schema_name, tables in schemas.items():
111
- if schema_name == "sys" or schema_name == "information_schema":
112
- # skip system schemas
113
- continue
114
-
115
- # create schema
116
- try:
117
- plpy.execute("create schema %s" % plpy.quote_ident(schema_name))
118
- except:
119
- # ignore error
120
- pass
121
-
122
- for table_name, columns in tables.items():
123
- create_sql = "create table %s.%s (\n " % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name))
124
-
125
- first = True
126
- for column in columns:
127
- if column.type == "varchar":
128
- pg_column_type = "varchar(100)"
129
- elif column.type == "bigint":
130
- pg_column_type = "bigint"
131
- elif column.type == "boolean":
132
- pg_column_type = "boolean"
133
- elif column.type == "double":
134
- pg_column_type = "double precision"
135
- else:
136
- raise Exception("unknown column type: " + plpy.quote_ident(column.type))
137
-
138
- if first:
139
- first = False
140
- else:
141
- create_sql += ",\n "
142
-
143
- create_sql += plpy.quote_ident(column.name) + " " + pg_column_type
144
- if not column.nullable:
145
- create_sql += " not null"
146
-
147
- create_sql += "\n)"
148
-
149
- plpy.execute("drop table if exists %s.%s" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))
150
- plpy.execute(create_sql)
151
-
152
- cache_expire_times[cache_key] = time.time()
153
-
154
- except Exception as e:
155
- plpy.error(str(e))
286
+ except (plpy.SPIError, presto_client.PrestoException) as e:
287
+ # Set __module__ = "__module__" to generate pretty messages.
288
+ e.__class__.__module__ = "__main__"
289
+ raise
156
290