databricks-sql-connector 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import unicode_literals
3
+
4
+ # Make all DB-API interface objects visible in this module.
5
+ from databricks.sql.dbapi import *
6
+ # Make all exceptions visible in this module per DB-API
7
+ from databricks.sql.exc import *
8
+
9
# Product token placed at the front of the User-Agent header sent by the client.
USER_AGENT_NAME = "PyDatabricksSqlConnector"
# Package version string; the client also reports it in the User-Agent header.
__version__ = "0.9.1"
11
+
12
def connect(server_hostname, http_path, access_token, **kwargs):
    """Open a connection to a Databricks SQL endpoint or an interactive cluster.

    :param server_hostname: Databricks instance host name.
    :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
        or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
    :param access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
    :param kwargs: forwarded unchanged to :py:class:`Connection` (internal options).

    :returns: a :py:class:`Connection` object.
    """
    # Imported here rather than at module top: the client module imports USER_AGENT_NAME and
    # __version__ from this package, so a top-level import would likely be circular.
    from databricks.sql.client import Connection

    connection = Connection(server_hostname, http_path, access_token, **kwargs)
    return connection
@@ -0,0 +1,454 @@
1
+ """DB-API implementation backed by Databricks.
2
+
3
+ See http://www.python.org/dev/peps/pep-0249/
4
+
5
+ Many docstrings in this file are based on the PEP, which is in the public domain.
6
+ """
7
+
8
+ import base64
9
+ import datetime
10
+ import re
11
+ from decimal import Decimal
12
+ from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
13
+
14
+
15
+ from databricks.sql import common, USER_AGENT_NAME, __version__
16
+ from databricks.sql.exc import *
17
+ from databricks.sql.TCLIService import TCLIService, ttypes
18
+ from builtins import range
19
+ from future.utils import iteritems
20
+ import logging
21
+ import sys
22
+ import thrift.transport.THttpClient
23
+ import thrift.protocol.TBinaryProtocol
24
+ import thrift.transport.TSocket
25
+ import thrift.transport.TTransport
26
+
27
_logger = logging.getLogger(__name__)

# Matches "<date> <time>" at the start of a string. Group 2 is an optional fractional-second
# part of at most six digits, so anything beyond microsecond precision is not captured.
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')

# Maps user-facing certificate-requirement names to the ssl module's verify modes.
ssl_cert_parameter_map = dict(
    none=CERT_NONE,
    optional=CERT_OPTIONAL,
    required=CERT_REQUIRED,
)
36
+
37
def _parse_timestamp(value):
    """Parse a ``YYYY-MM-DD HH:MM:SS[.ffffff]`` string into a naive :py:class:`datetime.datetime`.

    When a fractional part is present, only the prefix matched by ``_TIMESTAMP_PATTERN`` is
    parsed. That pattern captures at most six fractional digits, so extra precision is dropped.

    :param value: the timestamp string, or a falsy value such as ``None`` or ``""``.
    :returns: the parsed datetime, or ``None`` when ``value`` is falsy.
    :raises ValueError: if ``value`` does not start with a timestamp, or if
        ``datetime.strptime`` rejects it.
    """
    if not value:
        return None

    match = _TIMESTAMP_PATTERN.match(value)
    if not match:
        # ValueError (previously a bare Exception; existing ``except Exception`` handlers still
        # catch it) is consistent with what strptime raises for malformed input below.
        raise ValueError('Cannot convert "{}" into a datetime'.format(value))

    if match.group(2):
        fmt = '%Y-%m-%d %H:%M:%S.%f'
        # use the pattern to truncate the value
        value = match.group()
    else:
        fmt = '%Y-%m-%d %H:%M:%S'
    return datetime.datetime.strptime(value, fmt)
54
+
55
+
56
+ def _parse_date(value):
57
+ if value:
58
+ format = '%Y-%m-%d'
59
+ value = datetime.datetime.strptime(value, format).date()
60
+ else:
61
+ value = None
62
+ return value
63
+
64
+
65
# Maps a column type code (lower-cased, "_TYPE" suffix stripped, as built in
# Cursor.description) to a callable that converts the raw column value.
TYPES_CONVERTER = dict(
    decimal=Decimal,
    timestamp=_parse_timestamp,
    date=_parse_date,
)
68
+
69
+
70
class _HiveParamEscaper(common.ParamEscaper):
    """Parameter escaper that renders strings as single-quoted SQL literals."""

    def escape_string(self, item):
        """Return ``item`` quoted, with backslashes, single quotes and CR/LF/TAB escaped.

        Byte strings are decoded as UTF-8 first. Old SQLAlchemy always encodes Unicode
        parameters as byte strings, which would otherwise break the string formatting below.
        Newer SQLAlchemy checks ``dialect.supports_unicode_binds`` before encoding.
        """
        # TODO verify against parser
        text = item.decode('utf-8') if isinstance(item, bytes) else item
        # Order matters: backslashes are doubled first so that the backslashes introduced
        # by the later escapes are not themselves doubled.
        for needle, escaped in (('\\', '\\\\'),
                                ("'", "\\'"),
                                ('\r', '\\r'),
                                ('\n', '\\n'),
                                ('\t', '\\t')):
            text = text.replace(needle, escaped)
        return "'{}'".format(text)


# Shared escaper instance used by Cursor.execute when interpolating parameters.
_escaper = _HiveParamEscaper()
91
+
92
+
93
class Connection(object):
    """Wraps a Thrift session"""

    def __init__(
        self,
        server_hostname,
        http_path,
        access_token,
        **kwargs
    ):
        """Connect to a Databricks SQL endpoint or a Databricks cluster.

        :param server_hostname: Databricks instance host name.
        :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
            or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
        :param access_token: Http Bearer access token, e.g. Databricks Personal Access Token.
        :raises ValueError: if no connection target can be built (neither ``_connection_uri`` nor
            both ``server_hostname`` and ``http_path``), or if no credentials are available.
        """

        # Internal arguments in **kwargs:
        # _user_agent_entry
        #   Tag to add to User-Agent header. For use by partners.
        # _username, _password
        #   Username and password Basic authentication (no official support)
        # _tls_no_verify
        #   Set to True (Boolean) to completely disable SSL verification.
        # _tls_verify_hostname
        #   Set to False (Boolean) to disable SSL hostname verification, but check certificate.
        # _tls_trusted_ca_file
        #   Set to the path of the file containing trusted CA certificates for server certificate
        #   verification. If not provide, uses system truststore.
        # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
        #   Set client SSL certificate.
        #   See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
        # _connection_uri
        #   Overrides server_hostname and http_path.

        port = 443
        if kwargs.get("_connection_uri"):
            uri = kwargs.get("_connection_uri")
        elif server_hostname and http_path:
            uri = "https://{host}:{port}/{path}".format(
                host=server_hostname, port=port, path=http_path.lstrip("/"))
        else:
            raise ValueError("No valid connection settings.")

        ssl_context = self._build_ssl_context(kwargs)

        self._transport = thrift.transport.THttpClient.THttpClient(
            uri_or_host=uri,
            ssl_context=ssl_context,
        )

        authorization_header = self._build_authorization_header(access_token, kwargs)

        if not kwargs.get("_user_agent_entry"):
            useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
        else:
            useragent_header = "{}/{} ({})".format(
                USER_AGENT_NAME, __version__, kwargs.get("_user_agent_entry"))

        self._transport.setCustomHeaders({
            "Authorization": authorization_header,
            "User-Agent": useragent_header,
        })
        protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = TCLIService.Client(protocol)
        # oldest version that still contains features we care about
        # "V6 uses binary type for binary payload (was string) and uses columnar result set"
        protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6

        try:
            self._transport.open()
            open_session_req = ttypes.TOpenSessionReq(
                client_protocol=protocol_version
            )
            response = self._client.OpenSession(open_session_req)
            _check_status(response)
            assert response.sessionHandle is not None, "Expected a session from OpenSession"
            self._sessionHandle = response.sessionHandle
            assert response.serverProtocolVersion == protocol_version, \
                "Unable to handle protocol version {}".format(response.serverProtocolVersion)
        except BaseException:
            # Release the transport on any failure (including interrupts), then re-raise.
            self._transport.close()
            raise

    @staticmethod
    def _build_ssl_context(kwargs):
        """Build the client SSL context from the internal ``_tls_*`` keyword arguments.

        :param kwargs: the ``**kwargs`` mapping passed to ``__init__``.
        :returns: a configured :py:class:`ssl.SSLContext`.
        """
        ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
        # check_hostname must be turned off before verify_mode can be set to CERT_NONE.
        if kwargs.get("_tls_no_verify") is True:
            ssl_context.check_hostname = False
            ssl_context.verify_mode = CERT_NONE
        elif kwargs.get("_tls_verify_hostname") is False:
            ssl_context.check_hostname = False
            ssl_context.verify_mode = CERT_REQUIRED
        else:
            ssl_context.check_hostname = True
            ssl_context.verify_mode = CERT_REQUIRED

        tls_client_cert_file = kwargs.get("_tls_client_cert_file")
        # Fixed: this used to read the misspelled "__tls_client_cert_key_file" key, so a
        # separately supplied private-key file was silently ignored.
        tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
        tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
        if tls_client_cert_file:
            ssl_context.load_cert_chain(
                certfile=tls_client_cert_file,
                keyfile=tls_client_cert_key_file,
                password=tls_client_cert_key_password
            )
        return ssl_context

    @staticmethod
    def _build_authorization_header(access_token, kwargs):
        """Return the ``Authorization`` header value.

        Basic auth from ``_username``/``_password`` takes precedence over the bearer token.

        :param access_token: the bearer token passed to ``__init__`` (may be falsy).
        :param kwargs: the ``**kwargs`` mapping passed to ``__init__``.
        :raises ValueError: if neither Basic credentials nor an access token are provided.
        """
        if kwargs.get("_username") and kwargs.get("_password"):
            auth_credentials = "{username}:{password}".format(
                username=kwargs.get("_username"), password=kwargs.get("_password")
            ).encode("UTF-8")
            auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
                "UTF-8"
            )
            return "Basic {}".format(auth_credentials_base64)
        if access_token:
            return "Bearer {}".format(access_token)
        raise ValueError("No valid authentication settings.")

    def __enter__(self):
        """Transport should already be opened by __init__"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Call close"""
        self.close()

    def close(self):
        """Close the underlying session and Thrift transport.

        The transport is closed even if the ``CloseSession`` call raises. A non-success status
        from the server is raised only after the transport has been released.
        """
        req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
        try:
            response = self._client.CloseSession(req)
        finally:
            self._transport.close()
        _check_status(response)

    def commit(self):
        """No-op because Databricks does not support transactions"""
        pass

    def cursor(self, *args, **kwargs):
        """Return a new :py:class:`Cursor` object using the connection."""
        return Cursor(self, *args, **kwargs)

    def rollback(self):
        """Always raises: Databricks does not support transactions."""
        raise NotSupportedError("Databricks does not have transactions")
234
+
235
+
236
class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a fetch
    operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
    visible by other cursors or connections.
    """

    def __init__(self, connection, arraysize=10000):
        """Create a cursor that runs statements on ``connection``'s Thrift session.

        :param connection: the owning :py:class:`Connection`.
        :param arraysize: maximum number of rows requested per ``FetchResults`` call.
        """
        # Set before the base constructor runs, presumably because it calls _reset_state(),
        # which reads this attribute. NOTE(review): confirm against common.DBAPICursor.
        self._operationHandle = None
        super(Cursor, self).__init__()
        self._arraysize = arraysize
        self._connection = connection

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query.

        Closes the server-side operation, if any. The handle is cleared even if
        ``CloseOperation`` fails, so the same handle is never closed twice.
        """
        super(Cursor, self)._reset_state()
        self._description = None
        if self._operationHandle is not None:
            request = ttypes.TCloseOperationReq(self._operationHandle)
            try:
                response = self._connection._client.CloseOperation(request)
                _check_status(response)
            finally:
                self._operationHandle = None

    @property
    def arraysize(self):
        """Maximum number of rows requested per ``FetchResults`` call."""
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        """Array size cannot be None, and should be an integer.

        ``None`` (or anything else ``int()`` rejects with ``TypeError``) and ``0`` fall back to
        the default of 10000.
        """
        default_arraysize = 10000
        try:
            self._arraysize = int(value) or default_arraysize
        except TypeError:
            self._arraysize = default_arraysize

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        This attribute will be ``None`` for operations that do not return rows or if the cursor has
        not had an operation invoked via the :py:meth:`execute` method yet.

        The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
        section below.
        """
        if self._operationHandle is None or not self._operationHandle.hasResultSet:
            return None
        if self._description is None:
            req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
            response = self._connection._client.GetResultSetMetadata(req)
            _check_status(response)
            columns = response.schema.columns
            self._description = []
            for col in columns:
                primary_type_entry = col.typeDesc.types[0]
                if primary_type_entry.primitiveEntry is None:
                    # All fancy stuff maps to string
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
                else:
                    type_id = primary_type_entry.primitiveEntry.type
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
                # Stripping "_TYPE" and converting to lowercase makes TTypeIds consistent with
                # Spark types.
                if type_code.endswith("_TYPE"):
                    type_code = type_code[:-5]
                type_code = type_code.lower()

                self._description.append((
                    col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName,
                    type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code,
                    None, None, None, None, True
                ))
        return self._description

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the operation handle"""
        self._reset_state()

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        ``parameters`` (if given) are escaped and interpolated into ``operation`` with ``%``.
        The statement is submitted asynchronously and this call polls until it leaves the
        initialized/pending/running states.

        Return values are not defined.
        """
        # Prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = operation % _escaper.escape_args(parameters)

        self._reset_state()

        self._state = self._STATE_RUNNING
        _logger.info('%s', sql)

        req = ttypes.TExecuteStatementReq(self._connection._sessionHandle,
                                          sql, runAsync=True)
        _logger.debug(req)
        response = self._connection._client.ExecuteStatement(req)
        _check_status(response)
        self._operationHandle = response.operationHandle

        self.__wait_for_query_completion()

    def cancel(self):
        """Ask the server to cancel the current operation."""
        req = ttypes.TCancelOperationReq(
            operationHandle=self._operationHandle,
        )
        response = self._connection._client.CancelOperation(req)
        _check_status(response)

    def _fetch_more(self):
        """Send another TFetchResultsReq and update state"""
        assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
        assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
        if not self._operationHandle.hasResultSet:
            raise ProgrammingError("No result set")
        req = ttypes.TFetchResultsReq(
            operationHandle=self._operationHandle,
            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
            maxRows=self.arraysize,
        )
        response = self._connection._client.FetchResults(req)
        _check_status(response)
        schema = self.description
        assert not response.results.rows, 'expected data in columnar format'
        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
                   zip(response.results.columns, schema)]
        new_data = list(zip(*columns))
        self._data += new_data
        # response.hasMoreRows seems to always be False, so we instead check the number of rows
        # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
        # if not response.hasMoreRows:
        if not new_data:
            self._state = self._STATE_FINISHED

    def __wait_for_query_completion(self):
        """Poll the operation status until it leaves the initialized/pending/running states.

        :raises OperationalError: if the operation ends in an error, unknown, canceled or
            timed-out state.
        """
        state = ttypes.TOperationState.INITIALIZED_STATE
        while state in [ttypes.TOperationState.INITIALIZED_STATE, ttypes.TOperationState.PENDING_STATE, ttypes.TOperationState.RUNNING_STATE]:
            resp = self.__poll()
            state = resp.operationState

        # UKNOWN_STATE is the (misspelled) name used by the generated TCLIService ttypes.
        if state in [ttypes.TOperationState.ERROR_STATE,
                     ttypes.TOperationState.UKNOWN_STATE,
                     ttypes.TOperationState.CANCELED_STATE,
                     ttypes.TOperationState.TIMEDOUT_STATE]:
            # Fixed: the arguments now follow the placeholder order (code, SQLSTATE, message);
            # previously the SQLSTATE and error message were swapped in the report.
            raise OperationalError("Query execution failed.\nState: {}; Error code: {}; SQLSTATE: {}\nError message: {}\n"
                                   .format(ttypes.TOperationState._VALUES_TO_NAMES[state],
                                           resp.errorCode,
                                           resp.sqlState,
                                           resp.errorMessage))

    def __poll(self, get_progress_update=True):
        """Poll for and return the raw status data provided by the Hive Thrift REST API.
        :returns: ``ttypes.TGetOperationStatusResp``
        :raises: ``ProgrammingError`` when no query has been started
        .. note::
            This is not a part of DB-API.
        """
        if self._state == self._STATE_NONE:
            raise ProgrammingError("No query yet")

        req = ttypes.TGetOperationStatusReq(
            operationHandle=self._operationHandle,
            getProgressUpdate=get_progress_update,
        )
        response = self._connection._client.GetOperationStatus(req)
        _check_status(response)

        return response
425
+
426
+ #
427
+ # Private utilities
428
+ #
429
+
430
+
431
def _unwrap_column(col, type_=None):
    """Return a list of raw values from a TColumn instance.

    The first non-``None`` attribute of ``col`` is taken as the typed value column. Entries
    whose bit is set in its ``nulls`` bitmap (LSB first, eight rows per byte) become ``None``.
    If ``type_`` has an entry in ``TYPES_CONVERTER``, truthy values are passed through that
    converter.

    :raises DataError: if every attribute of ``col`` is ``None``.
    """
    for _attr_name, typed_column in iteritems(col.__dict__):
        if typed_column is None:
            continue
        values = typed_column.values
        null_bitmap = typed_column.nulls  # bit set describing what's null
        assert isinstance(null_bitmap, bytes)
        for byte_index, raw_byte in enumerate(null_bitmap):
            # Iterating bytes yields 1-char strings on Python 2 and ints on Python 3.
            bits = ord(raw_byte) if sys.version_info[0] == 2 else raw_byte
            for bit in range(8):
                if bits & (1 << bit):
                    values[byte_index * 8 + bit] = None
        converter = TYPES_CONVERTER.get(type_, None)
        if converter and type_:
            values = [converter(value) if value else value for value in values]
        return values
    raise DataError("Got empty column value {}".format(col))  # pragma: no cover
448
+
449
+
450
def _check_status(response):
    """Raise an OperationalError if the status is not success.

    The whole response is logged at debug level first. On failure the response object itself is
    passed as the exception argument.
    """
    _logger.debug(response)
    succeeded = response.status.statusCode == ttypes.TStatusCode.SUCCESS_STATUS
    if not succeeded:
        raise OperationalError(response)