prestogres 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -170,10 +170,10 @@ class StatementClient(object):
170
170
  "User-Agent": "presto-python/%s" % VERSION
171
171
  }
172
172
 
173
- def __init__(self, http_client, session, query):
173
+ def __init__(self, http_client, query, **options):
174
174
  self.http_client = http_client
175
- self.session = session
176
175
  self.query = query
176
+ self.options = options
177
177
 
178
178
  self.closed = False
179
179
  self.exception = None
@@ -183,14 +183,14 @@ class StatementClient(object):
183
183
  def _post_query_request(self):
184
184
  headers = StatementClient.HEADERS.copy()
185
185
 
186
- if self.session.user is not None:
187
- headers[PrestoHeaders.PRESTO_USER] = self.session.user
188
- if self.session.source is not None:
189
- headers[PrestoHeaders.PRESTO_SOURCE] = self.session.source
190
- if self.session.catalog is not None:
191
- headers[PrestoHeaders.PRESTO_CATALOG] = self.session.catalog
192
- if self.session.schema is not None:
193
- headers[PrestoHeaders.PRESTO_SCHEMA] = self.session.schema
186
+ if self.options.get("user") is not None:
187
+ headers[PrestoHeaders.PRESTO_USER] = self.options["user"]
188
+ if self.options.get("source") is not None:
189
+ headers[PrestoHeaders.PRESTO_SOURCE] = self.options["source"]
190
+ if self.options.get("catalog") is not None:
191
+ headers[PrestoHeaders.PRESTO_CATALOG] = self.options["catalog"]
192
+ if self.options.get("schema") is not None:
193
+ headers[PrestoHeaders.PRESTO_SCHEMA] = self.options["schema"]
194
194
 
195
195
  self.http_client.request("POST", "/v1/statement", self.query, headers)
196
196
  response = self.http_client.getresponse()
@@ -265,9 +265,9 @@ class StatementClient(object):
265
265
 
266
266
  class Query(object):
267
267
  @classmethod
268
- def start(cls, session, query):
269
- http_client = httplib.HTTPConnection(session.server)
270
- return Query(StatementClient(http_client, session, query))
268
+ def start(cls, query, **options):
269
+ http_client = httplib.HTTPConnection(host=options["server"], timeout=options.get("http_timeout", 300))
270
+ return Query(StatementClient(http_client, query, **options))
271
271
 
272
272
  def __init__(self, client):
273
273
  self.client = client
@@ -327,13 +327,13 @@ class Query(object):
327
327
 
328
328
  class Client(object):
329
329
  def __init__(self, **options):
330
- self.session = ClientSession(**options)
330
+ self.options = options
331
331
 
332
332
  def query(self, query):
333
- return Query.start(self.session, query)
333
+ return Query.start(query, **self.options)
334
334
 
335
335
  def run(self, query):
336
- q = Query.start(self.session, query)
336
+ q = Query.start(query, **self.options)
337
337
  try:
338
338
  columns = q.columns()
339
339
  if columns is None:
data/pgsql/prestogres.py CHANGED
@@ -1,156 +1,290 @@
1
- import presto_client
2
1
  import plpy
2
+ import presto_client
3
3
  from collections import namedtuple
4
4
  import time
5
5
 
6
- def run_presto_as_temp_table(server, user, catalog, schema, table_name, query):
7
- client = presto_client.Client(server=server, user=user, catalog=catalog, schema=schema)
8
-
9
- create_sql = 'create temp table ' + plpy.quote_ident(table_name) + ' (\n '
10
- insert_sql = 'insert into ' + plpy.quote_ident(table_name) + ' (\n '
11
- values_types = []
12
-
13
- q = client.query(query)
6
+ # convert Presto query result type to PostgreSQL type
7
def _pg_result_type(presto_type):
    """Map a Presto query-result column type to a PostgreSQL column type.

    Recognised Presto type names are "varchar", "bigint", "boolean" and
    "double"; they map to "text", "bigint", "boolean" and
    "double precision" respectively.

    Raises Exception for any other type name.
    """
    if presto_type == "varchar":
        return "text"
    elif presto_type == "bigint":
        return "bigint"
    elif presto_type == "boolean":
        return "boolean"
    elif presto_type == "double":
        return "double precision"
    else:
        # Call-style raise: the "raise Exception, msg" statement form is a
        # syntax error on Python 3, while this form works on Python 2 and 3
        # and matches _pg_table_type.
        raise Exception("unknown result column type: " + plpy.quote_ident(presto_type))
18
+
19
+ # convert Presto type to PostgreSQL type
20
def _pg_table_type(presto_type):
    """Map a Presto table column type to a PostgreSQL column type.

    "varchar" becomes "varchar(100)", "bigint" and "boolean" keep their
    names, and "double" becomes "double precision".

    Raises Exception for any other type name.
    """
    known_types = (
        ("varchar", "varchar(100)"),
        ("bigint", "bigint"),
        ("boolean", "boolean"),
        ("double", "double precision"),
    )
    for presto_name, pg_name in known_types:
        if presto_type == presto_name:
            return pg_name
    raise Exception("unknown table column type: " + plpy.quote_ident(presto_type))
31
+
32
+ # build CREATE TEMPORARY TABLE statement
33
def _build_create_temp_table_sql(table_name, column_names, column_types, not_nulls=None):
    """Return a CREATE TEMPORARY TABLE statement for the given columns.

    table_name and every column name are quoted with plpy.quote_ident;
    column_types are inserted verbatim. column_names and column_types are
    paired positionally with zip().

    not_nulls is accepted but currently ignored.
    """
    header = "create temporary table %s (\n " % plpy.quote_ident(table_name)

    # TODO not null: not_nulls is not applied yet, so no " not null"
    # constraint is ever emitted.
    column_defs = [
        plpy.quote_ident(name) + " " + pg_type
        for name, pg_type in zip(column_names, column_types)
    ]

    return header + ",\n ".join(column_defs) + "\n)"
53
+
54
+ # build INSERT INTO statement and string format to build VALUES (..), ...
55
def _build_insert_into_sql(table_name, column_names):
    """Return the INSERT INTO prefix and a per-row VALUES template.

    The result is a tuple (insert_sql, values_sql_format):
    insert_sql ends with "values\\n" and lists the quoted column names;
    values_sql_format is "(${}, ${}, ...)" with one "${}" slot per column,
    meant to be filled in with parameter numbers via str.format.
    """
    quoted_table = plpy.quote_ident(table_name)
    quoted_columns = [plpy.quote_ident(name) for name in column_names]

    insert_sql = ("insert into %s (\n " % quoted_table) \
        + ",\n ".join(quoted_columns) \
        + "\n) values\n"

    placeholders = ["${}"] * len(column_names)
    values_sql_format = "(%s)" % ", ".join(placeholders)

    return (insert_sql, values_sql_format)
72
+
73
+ # create a prepared statement for batch INSERT
74
def _plan_batch(insert_sql, values_sql_format, column_types, batch_size):
    """Prepare a plan that inserts batch_size rows with one statement.

    values_sql_format is repeated batch_size times and its "${}" slots are
    numbered 1..batch_size * len(column_types), giving
    'values ($1, $2), ($3, $4) ...'. The plan's argument types are
    column_types repeated once per row.
    """
    row_templates = [values_sql_format] * batch_size
    last_param = batch_size * len(column_types)
    values_sql = ", ".join(row_templates).format(*range(1, last_param + 1))
    return plpy.prepare(insert_sql + values_sql, column_types * batch_size)
79
+
80
+ # run batch INSERT
81
def _batch_insert(insert_sql, values_sql_format, batch_size, column_types, rows):
    """Insert rows in batches of batch_size using prepared statements.

    The plan for a full batch is prepared once, on first use, and reused;
    a trailing partial batch gets its own plan sized to the rows left over.
    """
    def run(plan, pending_rows):
        # flatten the rows into one positional parameter list
        params = []
        for pending_row in pending_rows:
            params.extend(pending_row)
        plpy.execute(plan, params)

    reusable_plan = None
    pending = []

    for row in rows:
        pending.append(row)
        if len(pending) < batch_size:
            continue
        if reusable_plan is None:
            reusable_plan = _plan_batch(insert_sql, values_sql_format, column_types, len(pending))
        run(reusable_plan, pending)
        del pending[:]

    if pending:
        run(_plan_batch(insert_sql, values_sql_format, column_types, len(pending)), pending)
97
+
98
class SchemaCache(object):
    """Single-entry cache of schema names and SQL statements.

    The entry is keyed by (server, user, catalog) and is valid until
    expire_time.
    """

    def __init__(self):
        # start empty: is_cached() is false until set_cache() is called
        self.set_cache(None, None, None, None, None, None)

    def is_cached(self, server, user, catalog, current_time):
        """Tell whether a non-expired entry exists for this server/user/catalog."""
        if not (self.server == server and self.user == user and self.catalog == catalog):
            return False
        if self.statements is None:
            return False
        return current_time < self.expire_time

    def set_cache(self, server, user, catalog, schema_names, statements, expire_time):
        """Replace the cached entry and its expiry time."""
        self.server, self.user, self.catalog = server, user, catalog
        self.schema_names, self.statements = schema_names, statements
        self.expire_time = expire_time
118
+
119
# Module-level cache: PostgreSQL type oid (int) -> type name (pg_type.typname).
OidToTypeNameMapping = {}


def _load_oid_to_type_name_mapping(oids):
    """Return the oid -> type name cache, loading any oids not yet in it.

    Oids missing from OidToTypeNameMapping are looked up in
    pg_catalog.pg_type with a single query and added to the cache. When
    every oid is already cached (or oids is empty) no query is run.

    Returns the shared OidToTypeNameMapping dict itself, not a copy.
    """
    # Build a real list: on Python 3 filter() returns a lazy object that is
    # always truthy, so the emptiness check below would always pass and an
    # empty input would produce the invalid SQL "where oid in ()".
    missing = [oid for oid in oids if oid not in OidToTypeNameMapping]
    if missing:
        sql = ("select oid, typname"
               " from pg_catalog.pg_type"
               " where oid in (%s)") % (", ".join(map(str, missing)))
        for row in plpy.execute(sql):
            OidToTypeNameMapping[int(row["oid"])] = row["typname"]

    return OidToTypeNameMapping
131
+
132
+ Column = namedtuple("Column", ("name", "type", "nullable"))
133
+
134
+ SchemaCacheEntry = SchemaCache()
135
+
136
+ def run_presto_as_temp_table(server, user, catalog, schema, result_table, query):
14
137
  try:
15
- columns = q.columns()
16
- column_num = len(columns)
17
-
18
- first = True
19
- for column in columns:
20
- if column.type == "varchar":
21
- pg_column_type = "text"
22
- elif column.type == "bigint":
23
- pg_column_type = "bigint"
24
- elif column.type == "boolean":
25
- pg_column_type = "boolean"
26
- elif column.type == "double":
27
- pg_column_type = "double precision"
28
- else:
29
- raise Exception, "unknown column type: " + plpy.quote_ident(column.type)
30
-
31
- if first:
32
- first = False
33
- else:
34
- create_sql += ",\n "
35
- insert_sql += ",\n "
36
-
37
- create_sql += plpy.quote_ident(column.name) + ' ' + pg_column_type
38
- insert_sql += plpy.quote_ident(column.name)
39
- values_types.append(pg_column_type)
40
-
41
- create_sql += '\n)'
42
- #if trait:
43
- # create_sql += ' '
44
- # create_sql += trait
45
- create_sql += ';'
46
-
47
- insert_sql += '\n) values\n'
48
- values_sql_format = '(' + ', '.join(['${}'] * column_num) + ')'
49
-
50
- #plpy.execute("drop table if exists "+plpy.quote_ident(table_name))
138
+ client = presto_client.Client(server=server, user=user, catalog=catalog, schema=schema)
139
+
140
+ create_sql = "create temporary table %s (\n " % plpy.quote_ident(result_table)
141
+ insert_sql = "insert into %s (\n " % plpy.quote_ident(result_table)
142
+ values_types = []
143
+
144
+ q = client.query(query)
145
+ try:
146
+ # result schema
147
+ column_names = []
148
+ column_types = []
149
+ for column in q.columns():
150
+ column_names.append(column.name)
151
+ column_types.append(_pg_result_type(column.type))
152
+
153
+ # build SQL
154
+ create_sql = _build_create_temp_table_sql(result_table, column_names, column_types)
155
+ insert_sql, values_sql_format = _build_insert_into_sql(result_table, column_names)
156
+
157
+ # run CREATE TABLE
158
+ plpy.execute("drop table if exists " + plpy.quote_ident(result_table))
159
+ plpy.execute(create_sql)
160
+
161
+ # run INSERT
162
+ _batch_insert(insert_sql, values_sql_format, 10, column_types, q.results())
163
+ finally:
164
+ q.close()
165
+
166
+ except (plpy.SPIError, presto_client.PrestoException) as e:
167
+ # PL/Python converts an exception object in Python to an error message in PostgreSQL
168
+ # using exception class name if exc.__module__ is either of "builtins", "exceptions",
169
+ # or "__main__". Otherwise using "module.name" format. Set __module__ = "__main__"
170
+ # to generate pretty messages.
171
+ e.__class__.__module__ = "__main__"
172
+ raise
173
+
174
+ def run_system_catalog_as_temp_table(server, user, catalog, result_table, query):
175
+ try:
176
+ client = presto_client.Client(server=server, user=user, catalog=catalog, schema="default")
177
+
178
+ # create SQL statements which put data to system catalogs
179
+ if SchemaCacheEntry.is_cached(server, user, catalog, time.time()):
180
+ schema_names = SchemaCacheEntry.schema_names
181
+ statements = SchemaCacheEntry.statements
182
+
183
+ else:
184
+ # get table list
185
+ sql = "select table_schema, table_name, column_name, is_nullable, data_type" \
186
+ " from information_schema.columns"
187
+ columns, rows = client.run(sql)
188
+
189
+ schemas = {}
190
+
191
+ if rows is None:
192
+ rows = []
193
+
194
+ for row in rows:
195
+ schema_name = row[0]
196
+ table_name = row[1]
197
+ column_name = row[2]
198
+ is_nullable = row[3]
199
+ column_type = row[4]
200
+
201
+ tables = schemas.setdefault(schema_name, {})
202
+ columns = tables.setdefault(table_name, [])
203
+ columns.append(Column(column_name, column_type, is_nullable))
204
+
205
+ # generate SQL statements
206
+ statements = []
207
+ schema_names = []
208
+
209
+ for schema_name, tables in schemas.items():
210
+ if schema_name == "sys" or schema_name == "information_schema":
211
+ # skip system schemas
212
+ continue
213
+
214
+ schema_names.append(schema_name)
215
+
216
+ for table_name, columns in tables.items():
217
+ # table schema
218
+ column_names = []
219
+ column_types = []
220
+ not_nulls = []
221
+ for column in columns:
222
+ column_names.append(column.name)
223
+ column_types.append(_pg_table_type(column.type))
224
+ not_nulls.append(not column.nullable)
225
+
226
+ create_sql = _build_create_temp_table_sql(table_name, column_names, column_types, not_nulls)
227
+ statements.append(create_sql)
228
+
229
+ # cache expires after 10 seconds
230
+ SchemaCacheEntry.set_cache(server, user, catalog, schema_names, statements, time.time() + 10)
231
+
232
+ # enter subtransaction to rollback tables right after running the query
233
+ subxact = plpy.subtransaction()
234
+ subxact.enter()
235
+ try:
236
+ # delete all schemas except prestogres_catalog, pg_catalog, information_schema, public and pg_toast*
237
+ sql = "select n.nspname as schema_name from pg_catalog.pg_namespace n" \
238
+ " where n.nspname not in ('prestogres_catalog', 'pg_catalog', 'information_schema', 'public')" \
239
+ " and n.nspname !~ '^pg_toast'"
240
+ for row in plpy.cursor(sql):
241
+ plpy.execute("drop schema %s cascade" % plpy.quote_ident(row["schema_name"]))
242
+
243
+ # delete all tables in prestogres_catalog
244
+ # relkind = 'r' selects only tables and skips views, indexes, etc.
245
+ sql = "select n.nspname as schema_name, c.relname as table_name from pg_catalog.pg_class c" \
246
+ " left join pg_catalog.pg_namespace n on n.oid = c.relnamespace" \
247
+ " where c.relkind in ('r')" \
248
+ " and n.nspname in ('prestogres_catalog')" \
249
+ " and n.nspname !~ '^pg_toast'"
250
+ for row in plpy.cursor(sql):
251
+ plpy.execute("drop table %s.%s" % (plpy.quote_ident(row["schema_name"]), plpy.quote_ident(row["table_name"])))
252
+
253
+ # create schemas
254
+ for schema_name in schema_names:
255
+ try:
256
+ plpy.execute("create schema %s" % plpy.quote_ident(schema_name))
257
+ except:
258
+ # ignore error
259
+ pass
260
+
261
+ # create tables
262
+ for statement in statements:
263
+ plpy.execute(statement)
264
+
265
+ # run the actual query
266
+ metadata = plpy.execute(query)
267
+ result = map(lambda row: row.values(), metadata)
268
+
269
+ finally:
270
+ # rollback subtransaction
271
+ subxact.exit("rollback subtransaction", None, None)
272
+
273
+ # table schema
274
+ oid_to_type_name = _load_oid_to_type_name_mapping(metadata.coltypes())
275
+ column_names = metadata.colnames()
276
+ column_types = map(oid_to_type_name.get, metadata.coltypes())
277
+
278
+ create_sql = _build_create_temp_table_sql(result_table, column_names, column_types)
279
+ insert_sql, values_sql_format = _build_insert_into_sql(result_table, column_names)
280
+
281
+ # run CREATE TABLE and INSERT
282
+ plpy.execute("drop table if exists " + plpy.quote_ident(result_table))
51
283
  plpy.execute(create_sql)
284
+ _batch_insert(insert_sql, values_sql_format, 10, column_types, result)
52
285
 
53
- batch = []
54
- for row in q.results():
55
- batch.append(row)
56
- if len(batch) > 10:
57
- batch_len = len(batch)
58
- # format string 'values ($1, $2), ($3, $4) ...'
59
- values_sql = (', '.join([values_sql_format] * batch_len)).format(*range(1, batch_len * column_num + 1))
60
- batch_insert_sql = insert_sql + values_sql
61
- # flatten rows into an array
62
- params = [item for sublist in batch for item in sublist]
63
- plpy.execute(plpy.prepare(batch_insert_sql, values_types * batch_len), params)
64
- del batch[:]
65
-
66
- if batch:
67
- batch_len = len(batch)
68
- # format string 'values ($1, $2), ($3, $4) ...'
69
- values_sql = (', '.join([values_sql_format] * batch_len)).format(*range(1, batch_len * column_num + 1))
70
- batch_insert_sql = insert_sql + values_sql
71
- # flatten rows into an array
72
- params = [item for sublist in batch for item in sublist]
73
- plpy.execute(plpy.prepare(batch_insert_sql, values_types * batch_len), params)
74
-
75
- finally:
76
- q.close()
77
-
78
- Column = namedtuple('Column', ('name', 'type', 'nullable'))
79
-
80
- cache_expire_times = {}
81
-
82
- def presto_create_tables(server, user, catalog):
83
- client = presto_client.Client(server=server, user=user, catalog=catalog, schema="default")
84
-
85
- cache_key = "%s:%s.%s" % (server, user, catalog)
86
- expire_time = cache_expire_times.get(cache_key)
87
- if expire_time is not None and time.time() - expire_time < 10:
88
- # TODO scan cache_expire_times and remove expired cache entries if it is large
89
- return
90
-
91
- try:
92
- schemas = {}
93
-
94
- columns, rows = client.run("select table_schema, table_name, column_name, is_nullable, data_type from information_schema.columns")
95
-
96
- if rows is None:
97
- return
98
-
99
- for row in rows:
100
- schema_name = row[0]
101
- table_name = row[1]
102
- column_name = row[2]
103
- is_nullable = row[3]
104
- column_type = row[4]
105
-
106
- tables = schemas.setdefault(schema_name, {})
107
- columns = tables.setdefault(table_name, [])
108
- columns.append(Column(column_name, column_type, is_nullable))
109
-
110
- for schema_name, tables in schemas.items():
111
- if schema_name == "sys" or schema_name == "information_schema":
112
- # skip system schemas
113
- continue
114
-
115
- # create schema
116
- try:
117
- plpy.execute("create schema %s" % plpy.quote_ident(schema_name))
118
- except:
119
- # ignore error
120
- pass
121
-
122
- for table_name, columns in tables.items():
123
- create_sql = "create table %s.%s (\n " % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name))
124
-
125
- first = True
126
- for column in columns:
127
- if column.type == "varchar":
128
- pg_column_type = "varchar(100)"
129
- elif column.type == "bigint":
130
- pg_column_type = "bigint"
131
- elif column.type == "boolean":
132
- pg_column_type = "boolean"
133
- elif column.type == "double":
134
- pg_column_type = "double precision"
135
- else:
136
- raise Exception("unknown column type: " + plpy.quote_ident(column.type))
137
-
138
- if first:
139
- first = False
140
- else:
141
- create_sql += ",\n "
142
-
143
- create_sql += plpy.quote_ident(column.name) + " " + pg_column_type
144
- if not column.nullable:
145
- create_sql += " not null"
146
-
147
- create_sql += "\n)"
148
-
149
- plpy.execute("drop table if exists %s.%s" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))
150
- plpy.execute(create_sql)
151
-
152
- cache_expire_times[cache_key] = time.time()
153
-
154
- except Exception as e:
155
- plpy.error(str(e))
286
+ except (plpy.SPIError, presto_client.PrestoException) as e:
287
+ # Set __module__ = "__main__" to generate pretty messages.
288
+ e.__class__.__module__ = "__main__"
289
+ raise
156
290