databricks-sql-connector 3.2.0__tar.gz → 3.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/CHANGELOG.md +14 -0
  2. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/PKG-INFO +3 -3
  3. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/pyproject.toml +4 -4
  4. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/__init__.py +1 -1
  5. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/auth.py +15 -18
  6. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/authenticators.py +0 -15
  7. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/retry.py +16 -8
  8. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/client.py +14 -12
  9. databricks_sql_connector-3.3.0/src/databricks/sql/cloudfetch/download_manager.py +107 -0
  10. databricks_sql_connector-3.3.0/src/databricks/sql/cloudfetch/downloader.py +177 -0
  11. {databricks_sql_connector-3.2.0/src/databricks → databricks_sql_connector-3.3.0/src/databricks/sql/parameters}/py.typed +0 -0
  12. {databricks_sql_connector-3.2.0/src/databricks/sql/parameters → databricks_sql_connector-3.3.0/src/databricks/sql}/py.typed +0 -0
  13. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_backend.py +7 -3
  14. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/utils.py +15 -8
  15. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/README.sqlalchemy.md +1 -2
  16. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/README.tests.md +3 -3
  17. databricks_sql_connector-3.3.0/src/databricks/sqlalchemy/py.typed +0 -0
  18. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/requirements.py +1 -1
  19. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/overrides/_componentreflectiontest.py +1 -1
  20. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/__init__.py +1 -1
  21. databricks_sql_connector-3.2.0/src/databricks/sql/cloudfetch/download_manager.py +0 -215
  22. databricks_sql_connector-3.2.0/src/databricks/sql/cloudfetch/downloader.py +0 -173
  23. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/LICENSE +0 -0
  24. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/README.md +0 -0
  25. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/__init__.py +0 -0
  26. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/__init__.py +0 -0
  27. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/endpoint.py +0 -0
  28. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/oauth.py +0 -0
  29. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
  30. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/auth/thrift_http_client.py +0 -0
  31. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/exc.py +0 -0
  32. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/experimental/__init__.py +0 -0
  33. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
  34. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/parameters/__init__.py +0 -0
  35. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/parameters/native.py +0 -0
  36. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
  37. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
  38. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
  39. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
  40. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
  41. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/thrift_api/__init__.py +0 -0
  42. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sql/types.py +0 -0
  43. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/__init__.py +0 -0
  44. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/_ddl.py +0 -0
  45. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/_parse.py +0 -0
  46. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/_types.py +0 -0
  47. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/base.py +0 -0
  48. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/setup.cfg +0 -0
  49. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/_extra.py +0 -0
  50. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/_future.py +0 -0
  51. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/_regression.py +0 -0
  52. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/_unsupported.py +0 -0
  53. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/conftest.py +0 -0
  54. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/overrides/_ctetest.py +0 -0
  55. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test/test_suite.py +0 -0
  56. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/conftest.py +0 -0
  57. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx +0 -0
  58. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +0 -0
  59. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/test_ddl.py +0 -0
  60. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/test_parsing.py +0 -0
  61. {databricks_sql_connector-3.2.0 → databricks_sql_connector-3.3.0}/src/databricks/sqlalchemy/test_local/test_types.py +0 -0
@@ -1,5 +1,19 @@
1
1
  # Release History
2
2
 
3
+ # 3.3.0 (2024-07-18)
4
+
5
+ - Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo)
6
+ - Remove username/password (aka "basic") auth option (databricks/databricks-sql-python#409 by @jackyhu-db)
7
+ - Refactor CloudFetch handler to fix numerous issues with it (databricks/databricks-sql-python#405 by @kravets-levko)
8
+ - Add option to disable SSL verification for CloudFetch links (databricks/databricks-sql-python#414 by @kravets-levko)
9
+
10
+ Databricks-managed passwords reached end of life on July 10, 2024. Therefore, Basic auth support was removed from
11
+ the library. See https://docs.databricks.com/en/security/auth-authz/password-deprecation.html
12
+
13
+ The existing option `_tls_no_verify=True` of `sql.connect(...)` will now also disable SSL cert verification
14
+ (but not the SSL itself) for CloudFetch links. This option should be used as a workaround only, when other ways
15
+ to fix SSL certificate errors didn't work.
16
+
3
17
  # 3.2.0 (2024-06-06)
4
18
 
5
19
  - Update proxy authentication (databricks/databricks-sql-python#354 by @amir-haroun)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: databricks-sql-connector
3
- Version: 3.2.0
3
+ Version: 3.3.0
4
4
  Summary: Databricks SQL Connector for Python
5
5
  License: Apache-2.0
6
6
  Author: Databricks
@@ -17,8 +17,8 @@ Provides-Extra: alembic
17
17
  Provides-Extra: sqlalchemy
18
18
  Requires-Dist: alembic (>=1.0.11,<2.0.0) ; extra == "alembic"
19
19
  Requires-Dist: lz4 (>=4.0.2,<5.0.0)
20
- Requires-Dist: numpy (>=1.16.6) ; python_version >= "3.8" and python_version < "3.11"
21
- Requires-Dist: numpy (>=1.23.4) ; python_version >= "3.11"
20
+ Requires-Dist: numpy (>=1.16.6,<2.0.0) ; python_version >= "3.8" and python_version < "3.11"
21
+ Requires-Dist: numpy (>=1.23.4,<2.0.0) ; python_version >= "3.11"
22
22
  Requires-Dist: oauthlib (>=3.1.0,<4.0.0)
23
23
  Requires-Dist: openpyxl (>=3.0.10,<4.0.0)
24
24
  Requires-Dist: pandas (>=1.2.5,<2.2.0) ; python_version >= "3.8"
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "databricks-sql-connector"
3
- version = "3.2.0"
3
+ version = "3.3.0"
4
4
  description = "Databricks SQL Connector for Python"
5
5
  authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
6
6
  license = "Apache-2.0"
@@ -20,8 +20,8 @@ lz4 = "^4.0.2"
20
20
  requests = "^2.18.1"
21
21
  oauthlib = "^3.1.0"
22
22
  numpy = [
23
- { version = ">=1.16.6", python = ">=3.8,<3.11" },
24
- { version = ">=1.23.4", python = ">=3.11" },
23
+ { version = "^1.16.6", python = ">=3.8,<3.11" },
24
+ { version = "^1.23.4", python = ">=3.11" },
25
25
  ]
26
26
  sqlalchemy = { version = ">=2.0.21", optional = true }
27
27
  openpyxl = "^3.0.10"
@@ -34,7 +34,7 @@ alembic = ["sqlalchemy", "alembic"]
34
34
 
35
35
  [tool.poetry.dev-dependencies]
36
36
  pytest = "^7.1.2"
37
- mypy = "^0.981"
37
+ mypy = "^1.10.1"
38
38
  pylint = ">=2.12.0"
39
39
  black = "^22.3.0"
40
40
  pytest-dotenv = "^0.5.2"
@@ -68,7 +68,7 @@ DATETIME = DBAPITypeObject("timestamp")
68
68
  DATE = DBAPITypeObject("date")
69
69
  ROWID = DBAPITypeObject()
70
70
 
71
- __version__ = "3.2.0"
71
+ __version__ = "3.3.0"
72
72
  USER_AGENT_NAME = "PyDatabricksSqlConnector"
73
73
 
74
74
  # These two functions are pyhive legacy
@@ -1,10 +1,9 @@
1
1
  from enum import Enum
2
- from typing import List
2
+ from typing import Optional, List
3
3
 
4
4
  from databricks.sql.auth.authenticators import (
5
5
  AuthProvider,
6
6
  AccessTokenAuthProvider,
7
- BasicAuthProvider,
8
7
  ExternalAuthProvider,
9
8
  DatabricksOAuthProvider,
10
9
  )
@@ -13,7 +12,7 @@ from databricks.sql.auth.authenticators import (
13
12
  class AuthType(Enum):
14
13
  DATABRICKS_OAUTH = "databricks-oauth"
15
14
  AZURE_OAUTH = "azure-oauth"
16
- # other supported types (access_token, user/pass) can be inferred
15
+ # other supported types (access_token) can be inferred
17
16
  # we can add more types as needed later
18
17
 
19
18
 
@@ -21,21 +20,17 @@ class ClientContext:
21
20
  def __init__(
22
21
  self,
23
22
  hostname: str,
24
- username: str = None,
25
- password: str = None,
26
- access_token: str = None,
27
- auth_type: str = None,
28
- oauth_scopes: List[str] = None,
29
- oauth_client_id: str = None,
30
- oauth_redirect_port_range: List[int] = None,
31
- use_cert_as_auth: str = None,
32
- tls_client_cert_file: str = None,
23
+ access_token: Optional[str] = None,
24
+ auth_type: Optional[str] = None,
25
+ oauth_scopes: Optional[List[str]] = None,
26
+ oauth_client_id: Optional[str] = None,
27
+ oauth_redirect_port_range: Optional[List[int]] = None,
28
+ use_cert_as_auth: Optional[str] = None,
29
+ tls_client_cert_file: Optional[str] = None,
33
30
  oauth_persistence=None,
34
31
  credentials_provider=None,
35
32
  ):
36
33
  self.hostname = hostname
37
- self.username = username
38
- self.password = password
39
34
  self.access_token = access_token
40
35
  self.auth_type = auth_type
41
36
  self.oauth_scopes = oauth_scopes
@@ -65,8 +60,6 @@ def get_auth_provider(cfg: ClientContext):
65
60
  )
66
61
  elif cfg.access_token is not None:
67
62
  return AccessTokenAuthProvider(cfg.access_token)
68
- elif cfg.username is not None and cfg.password is not None:
69
- return BasicAuthProvider(cfg.username, cfg.password)
70
63
  elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
71
64
  # no op authenticator. authentication is performed using ssl certificate outside of headers
72
65
  return AuthProvider()
@@ -100,12 +93,16 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
100
93
  (client_id, redirect_port_range) = get_client_id_and_redirect_port(
101
94
  auth_type == AuthType.AZURE_OAUTH.value
102
95
  )
96
+ if kwargs.get("username") or kwargs.get("password"):
97
+ raise ValueError(
98
+ "Username/password authentication is no longer supported. "
99
+ "Please use OAuth or access token instead."
100
+ )
101
+
103
102
  cfg = ClientContext(
104
103
  hostname=normalize_host_name(hostname),
105
104
  auth_type=auth_type,
106
105
  access_token=kwargs.get("access_token"),
107
- username=kwargs.get("_username"),
108
- password=kwargs.get("_password"),
109
106
  use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
110
107
  tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
111
108
  oauth_scopes=PYSQL_OAUTH_SCOPES,
@@ -43,21 +43,6 @@ class AccessTokenAuthProvider(AuthProvider):
43
43
  request_headers["Authorization"] = self.__authorization_header_value
44
44
 
45
45
 
46
- # Private API: this is an evolving interface and it will change in the future.
47
- # Please must not depend on it in your applications.
48
- class BasicAuthProvider(AuthProvider):
49
- def __init__(self, username: str, password: str):
50
- auth_credentials = f"{username}:{password}".encode("UTF-8")
51
- auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
52
- "UTF-8"
53
- )
54
-
55
- self.__authorization_header_value = f"Basic {auth_credentials_base64}"
56
-
57
- def add_headers(self, request_headers: Dict[str, str]):
58
- request_headers["Authorization"] = self.__authorization_header_value
59
-
60
-
61
46
  # Private API: this is an evolving interface and it will change in the future.
62
47
  # Please must not depend on it in your applications.
63
48
  class DatabricksOAuthProvider(AuthProvider):
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
7
7
  # We only use this import for type hinting
8
8
  try:
9
9
  # If urllib3~=2.0 is installed
10
- from urllib3 import BaseHTTPResponse # type: ignore
10
+ from urllib3 import BaseHTTPResponse
11
11
  except ImportError:
12
12
  # If urllib3~=1.0 is installed
13
13
  from urllib3 import HTTPResponse as BaseHTTPResponse
@@ -129,7 +129,7 @@ class DatabricksRetryPolicy(Retry):
129
129
  urllib3_kwargs.update(**_urllib_kwargs_we_care_about)
130
130
 
131
131
  super().__init__(
132
- **urllib3_kwargs, # type: ignore
132
+ **urllib3_kwargs,
133
133
  )
134
134
 
135
135
  @classmethod
@@ -162,7 +162,9 @@ class DatabricksRetryPolicy(Retry):
162
162
  new_object.command_type = command_type
163
163
  return new_object
164
164
 
165
- def new(self, **urllib3_incremented_counters: typing.Any) -> Retry:
165
+ def new(
166
+ self, **urllib3_incremented_counters: typing.Any
167
+ ) -> "DatabricksRetryPolicy":
166
168
  """This method is responsible for passing the entire Retry state to its next iteration.
167
169
 
168
170
  urllib3 calls Retry.new() between successive requests as part of its `.increment()` method
@@ -210,7 +212,7 @@ class DatabricksRetryPolicy(Retry):
210
212
  other=self.other,
211
213
  allowed_methods=self.allowed_methods,
212
214
  status_forcelist=self.status_forcelist,
213
- backoff_factor=self.backoff_factor, # type: ignore
215
+ backoff_factor=self.backoff_factor,
214
216
  raise_on_redirect=self.raise_on_redirect,
215
217
  raise_on_status=self.raise_on_status,
216
218
  history=self.history,
@@ -222,7 +224,7 @@ class DatabricksRetryPolicy(Retry):
222
224
  urllib3_init_params.update(**urllib3_incremented_counters)
223
225
 
224
226
  # Include urllib3's current state in our __init__ params
225
- databricks_init_params["urllib3_kwargs"].update(**urllib3_init_params) # type: ignore
227
+ databricks_init_params["urllib3_kwargs"].update(**urllib3_init_params) # type: ignore[attr-defined]
226
228
 
227
229
  return type(self).__private_init__(
228
230
  retry_start_time=self._retry_start_time,
@@ -274,7 +276,7 @@ class DatabricksRetryPolicy(Retry):
274
276
  f"Retry request would exceed Retry policy max retry duration of {self.stop_after_attempts_duration} seconds"
275
277
  )
276
278
 
277
- def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: # type: ignore
279
+ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
278
280
  """Sleeps for the duration specified in the response Retry-After header, if present
279
281
 
280
282
  A MaxRetryDurationError will be raised if doing so would exceed self.max_attempts_duration
@@ -325,7 +327,8 @@ class DatabricksRetryPolicy(Retry):
325
327
  default, this means ExecuteStatement is only retried for codes 429 and 503.
326
328
  This limit prevents automatically retrying non-idempotent commands that could
327
329
  be destructive.
328
- 5. The request received a 403 response, because this can never succeed.
330
+ 5. The request received a 401 response, because this can never succeed.
331
+ 6. The request received a 403 response, because this can never succeed.
329
332
 
330
333
 
331
334
  Q: What about OSErrors and Redirects?
@@ -339,6 +342,11 @@ class DatabricksRetryPolicy(Retry):
339
342
  if status_code == 200:
340
343
  return False, "200 codes are not retried"
341
344
 
345
+ if status_code == 401:
346
+ raise NonRecoverableNetworkError(
347
+ "Received 401 - UNAUTHORIZED. Confirm your authentication credentials."
348
+ )
349
+
342
350
  if status_code == 403:
343
351
  raise NonRecoverableNetworkError(
344
352
  "Received 403 - FORBIDDEN. Confirm your authentication credentials."
@@ -349,7 +357,7 @@ class DatabricksRetryPolicy(Retry):
349
357
  raise NonRecoverableNetworkError("Received code 501 from server.")
350
358
 
351
359
  # Request failed and this method is not retryable. We only retry POST requests.
352
- if not self._is_method_retryable(method): # type: ignore
360
+ if not self._is_method_retryable(method):
353
361
  return False, "Only POST requests are retried"
354
362
 
355
363
  # Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry.
@@ -59,7 +59,7 @@ class Connection:
59
59
  http_path: str,
60
60
  access_token: Optional[str] = None,
61
61
  http_headers: Optional[List[Tuple[str, str]]] = None,
62
- session_configuration: Dict[str, Any] = None,
62
+ session_configuration: Optional[Dict[str, Any]] = None,
63
63
  catalog: Optional[str] = None,
64
64
  schema: Optional[str] = None,
65
65
  _use_arrow_native_complex_types: Optional[bool] = True,
@@ -163,16 +163,16 @@ class Connection:
163
163
  # Internal arguments in **kwargs:
164
164
  # _user_agent_entry
165
165
  # Tag to add to User-Agent header. For use by partners.
166
- # _username, _password
167
- # Username and password Basic authentication (no official support)
168
166
  # _use_cert_as_auth
169
- # Use a TLS cert instead of a token or username / password (internal use only)
167
+ # Use a TLS cert instead of a token
170
168
  # _enable_ssl
171
169
  # Connect over HTTP instead of HTTPS
172
170
  # _port
173
171
  # Which port to connect to
174
172
  # _skip_routing_headers:
175
173
  # Don't set routing headers if set to True (for use when connecting directly to server)
174
+ # _tls_no_verify
175
+ # Set to True (Boolean) to completely disable SSL verification.
176
176
  # _tls_verify_hostname
177
177
  # Set to False (Boolean) to disable SSL hostname verification, but check certificate.
178
178
  # _tls_trusted_ca_file
@@ -460,9 +460,9 @@ class Cursor:
460
460
  output: List[TDbsqlParameter] = []
461
461
  for p in params:
462
462
  if isinstance(p, DbsqlParameterBase):
463
- output.append(p) # type: ignore
463
+ output.append(p)
464
464
  else:
465
- output.append(dbsql_parameter_from_primitive(value=p)) # type: ignore
465
+ output.append(dbsql_parameter_from_primitive(value=p))
466
466
 
467
467
  return output
468
468
 
@@ -640,7 +640,7 @@ class Cursor:
640
640
  )
641
641
 
642
642
  def _handle_staging_put(
643
- self, presigned_url: str, local_file: str, headers: dict = None
643
+ self, presigned_url: str, local_file: str, headers: Optional[dict] = None
644
644
  ):
645
645
  """Make an HTTP PUT request
646
646
 
@@ -655,7 +655,7 @@ class Cursor:
655
655
 
656
656
  # fmt: off
657
657
  # Design borrowed from: https://stackoverflow.com/a/2342589/5093960
658
-
658
+
659
659
  OK = requests.codes.ok # 200
660
660
  CREATED = requests.codes.created # 201
661
661
  ACCEPTED = requests.codes.accepted # 202
@@ -675,7 +675,7 @@ class Cursor:
675
675
  )
676
676
 
677
677
  def _handle_staging_get(
678
- self, local_file: str, presigned_url: str, headers: dict = None
678
+ self, local_file: str, presigned_url: str, headers: Optional[dict] = None
679
679
  ):
680
680
  """Make an HTTP GET request, create a local file with the received data
681
681
 
@@ -697,7 +697,9 @@ class Cursor:
697
697
  with open(local_file, "wb") as fp:
698
698
  fp.write(r.content)
699
699
 
700
- def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
700
+ def _handle_staging_remove(
701
+ self, presigned_url: str, headers: Optional[dict] = None
702
+ ):
701
703
  """Make an HTTP DELETE request to the presigned_url"""
702
704
 
703
705
  r = requests.delete(url=presigned_url, headers=headers)
@@ -757,7 +759,7 @@ class Cursor:
757
759
  normalized_parameters = self._normalize_tparametercollection(parameters)
758
760
  param_structure = self._determine_parameter_structure(normalized_parameters)
759
761
  transformed_operation = transform_paramstyle(
760
- operation, normalized_parameters, param_structure # type: ignore
762
+ operation, normalized_parameters, param_structure
761
763
  )
762
764
  prepared_operation, prepared_params = self._prepare_native_parameters(
763
765
  transformed_operation, normalized_parameters, param_structure
@@ -861,7 +863,7 @@ class Cursor:
861
863
  catalog_name: Optional[str] = None,
862
864
  schema_name: Optional[str] = None,
863
865
  table_name: Optional[str] = None,
864
- table_types: List[str] = None,
866
+ table_types: Optional[List[str]] = None,
865
867
  ) -> "Cursor":
866
868
  """
867
869
  Get tables corresponding to the catalog_name, schema_name and table_name.
@@ -0,0 +1,107 @@
1
+ import logging
2
+
3
+ from ssl import SSLContext
4
+ from concurrent.futures import ThreadPoolExecutor, Future
5
+ from typing import List, Union
6
+
7
+ from databricks.sql.cloudfetch.downloader import (
8
+ ResultSetDownloadHandler,
9
+ DownloadableResultSettings,
10
+ DownloadedFile,
11
+ )
12
+ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ResultFileDownloadManager:
18
+ def __init__(
19
+ self,
20
+ links: List[TSparkArrowResultLink],
21
+ max_download_threads: int,
22
+ lz4_compressed: bool,
23
+ ssl_context: SSLContext,
24
+ ):
25
+ self._pending_links: List[TSparkArrowResultLink] = []
26
+ for link in links:
27
+ if link.rowCount <= 0:
28
+ continue
29
+ logger.debug(
30
+ "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
31
+ link.startRowOffset, link.rowCount
32
+ )
33
+ )
34
+ self._pending_links.append(link)
35
+
36
+ self._download_tasks: List[Future[DownloadedFile]] = []
37
+ self._max_download_threads: int = max_download_threads
38
+ self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
39
+
40
+ self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
41
+ self._ssl_context = ssl_context
42
+
43
+ def get_next_downloaded_file(
44
+ self, next_row_offset: int
45
+ ) -> Union[DownloadedFile, None]:
46
+ """
47
+ Get next file that starts at given offset.
48
+
49
+ This function gets the next downloaded file in which its rows start at the specified next_row_offset
50
+ in relation to the full result. File downloads are scheduled if not already, and once the correct
51
+ download handler is located, the function waits for the download status and returns the resulting file.
52
+ If there are no more downloads, a download was not successful, or the correct file could not be located,
53
+ this function shuts down the thread pool and returns None.
54
+
55
+ Args:
56
+ next_row_offset (int): The offset of the starting row of the next file we want data from.
57
+ """
58
+
59
+ # Make sure the download queue is always full
60
+ self._schedule_downloads()
61
+
62
+ # No more files to download from this batch of links
63
+ if len(self._download_tasks) == 0:
64
+ self._shutdown_manager()
65
+ return None
66
+
67
+ task = self._download_tasks.pop(0)
68
+ # Future's `result()` method will wait for the call to complete, and return
69
+ # the value returned by the call. If the call throws an exception - `result()`
70
+ # will throw the same exception
71
+ file = task.result()
72
+ if (next_row_offset < file.start_row_offset) or (
73
+ next_row_offset > file.start_row_offset + file.row_count
74
+ ):
75
+ logger.debug(
76
+ "ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format(
77
+ next_row_offset, file.start_row_offset, file.row_count
78
+ )
79
+ )
80
+
81
+ return file
82
+
83
+ def _schedule_downloads(self):
84
+ """
85
+ While download queue has a capacity, peek pending links and submit them to thread pool.
86
+ """
87
+ logger.debug("ResultFileDownloadManager: schedule downloads")
88
+ while (len(self._download_tasks) < self._max_download_threads) and (
89
+ len(self._pending_links) > 0
90
+ ):
91
+ link = self._pending_links.pop(0)
92
+ logger.debug(
93
+ "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
94
+ )
95
+ handler = ResultSetDownloadHandler(
96
+ settings=self._downloadable_result_settings,
97
+ link=link,
98
+ ssl_context=self._ssl_context,
99
+ )
100
+ task = self._thread_pool.submit(handler.run)
101
+ self._download_tasks.append(task)
102
+
103
+ def _shutdown_manager(self):
104
+ # Clear download handlers and shutdown the thread pool
105
+ self._pending_links = []
106
+ self._download_tasks = []
107
+ self._thread_pool.shutdown(wait=False)
@@ -0,0 +1,177 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+
4
+ import requests
5
+ from requests.adapters import HTTPAdapter, Retry
6
+ from ssl import SSLContext, CERT_NONE
7
+ import lz4.frame
8
+ import time
9
+
10
+ from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
11
+
12
+ from databricks.sql.exc import Error
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library.
17
+ # But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests
18
+ retryPolicy = Retry(
19
+ total=5, # max retry attempts
20
+ backoff_factor=1, # min delay, 1 second
21
+ # TODO: `backoff_max` is supported since `urllib3` v2.0.0, but we allow >= 1.26.
22
+ # The default value (120 seconds) used since v1.26 looks reasonable enough
23
+ # backoff_max=60, # max delay, 60 seconds
24
+ # retry all status codes below 100, 429 (Too Many Requests), and all codes above 500,
25
+ # excluding 501 Not implemented
26
+ status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)],
27
+ )
28
+
29
+
30
+ @dataclass
31
+ class DownloadedFile:
32
+ """
33
+ Class for the result file and metadata.
34
+
35
+ Attributes:
36
+ file_bytes (bytes): Downloaded file in bytes.
37
+ start_row_offset (int): The offset of the starting row in relation to the full result.
38
+ row_count (int): Number of rows the file represents in the result.
39
+ """
40
+
41
+ file_bytes: bytes
42
+ start_row_offset: int
43
+ row_count: int
44
+
45
+
46
+ @dataclass
47
+ class DownloadableResultSettings:
48
+ """
49
+ Class for settings common to each download handler.
50
+
51
+ Attributes:
52
+ is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
53
+ link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
54
+ download_timeout (int): Timeout for download requests. Default 60 secs.
55
+ max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
56
+ """
57
+
58
+ is_lz4_compressed: bool
59
+ link_expiry_buffer_secs: int = 0
60
+ download_timeout: int = 60
61
+ max_consecutive_file_download_retries: int = 0
62
+
63
+
64
+ class ResultSetDownloadHandler:
65
+ def __init__(
66
+ self,
67
+ settings: DownloadableResultSettings,
68
+ link: TSparkArrowResultLink,
69
+ ssl_context: SSLContext,
70
+ ):
71
+ self.settings = settings
72
+ self.link = link
73
+ self._ssl_context = ssl_context
74
+
75
+ def run(self) -> DownloadedFile:
76
+ """
77
+ Download the file described in the cloud fetch link.
78
+
79
+ This function checks if the link has or is expiring, gets the file via a requests session, decompresses the
80
+ file, and signals to waiting threads that the download is finished and whether it was successful.
81
+ """
82
+
83
+ logger.debug(
84
+ "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
85
+ self.link.startRowOffset, self.link.rowCount
86
+ )
87
+ )
88
+
89
+ # Check if link is already expired or is expiring
90
+ ResultSetDownloadHandler._validate_link(
91
+ self.link, self.settings.link_expiry_buffer_secs
92
+ )
93
+
94
+ session = requests.Session()
95
+ session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
96
+ session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
97
+
98
+ ssl_verify = self._ssl_context.verify_mode != CERT_NONE
99
+
100
+ try:
101
+ # Get the file via HTTP request
102
+ response = session.get(
103
+ self.link.fileLink,
104
+ timeout=self.settings.download_timeout,
105
+ verify=ssl_verify,
106
+ )
107
+ response.raise_for_status()
108
+
109
+ # Save (and decompress if needed) the downloaded file
110
+ compressed_data = response.content
111
+ decompressed_data = (
112
+ ResultSetDownloadHandler._decompress_data(compressed_data)
113
+ if self.settings.is_lz4_compressed
114
+ else compressed_data
115
+ )
116
+
117
+ # The size of the downloaded file should match the size specified from TSparkArrowResultLink
118
+ if len(decompressed_data) != self.link.bytesNum:
119
+ logger.debug(
120
+ "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format(
121
+ len(decompressed_data), self.link.bytesNum
122
+ )
123
+ )
124
+
125
+ logger.debug(
126
+ "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format(
127
+ self.link.startRowOffset, self.link.rowCount
128
+ )
129
+ )
130
+
131
+ return DownloadedFile(
132
+ decompressed_data,
133
+ self.link.startRowOffset,
134
+ self.link.rowCount,
135
+ )
136
+ finally:
137
+ if session:
138
+ session.close()
139
+
140
+ @staticmethod
141
+ def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
142
+ """
143
+ Check if a link has expired or will expire.
144
+
145
+ Expiry buffer can be set to avoid downloading files that has not expired yet when the function is called,
146
+ but may expire before the file has fully downloaded.
147
+ """
148
+ current_time = int(time.time())
149
+ if (
150
+ link.expiryTime <= current_time
151
+ or link.expiryTime - current_time <= expiry_buffer_secs
152
+ ):
153
+ raise Error("CloudFetch link has expired")
154
+
155
+ @staticmethod
156
+ def _decompress_data(compressed_data: bytes) -> bytes:
157
+ """
158
+ Decompress lz4 frame compressed data.
159
+
160
+ Decompresses data that has been lz4 compressed, either via the whole frame or by series of chunks.
161
+ """
162
+ uncompressed_data, bytes_read = lz4.frame.decompress(
163
+ compressed_data, return_bytes_read=True
164
+ )
165
+ # The last cloud fetch file of the entire result is commonly punctuated by frequent end-of-frame markers.
166
+ # Full frame decompression above will short-circuit, so chunking is necessary
167
+ if bytes_read < len(compressed_data):
168
+ d_context = lz4.frame.create_decompression_context()
169
+ start = 0
170
+ uncompressed_data = bytearray()
171
+ while start < len(compressed_data):
172
+ data, num_bytes, is_end = lz4.frame.decompress_chunk(
173
+ d_context, compressed_data[start:]
174
+ )
175
+ uncompressed_data += data
176
+ start += num_bytes
177
+ return uncompressed_data