clickhouse-driver 0.2.10__cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. clickhouse_driver/__init__.py +9 -0
  2. clickhouse_driver/block.py +227 -0
  3. clickhouse_driver/blockstreamprofileinfo.py +22 -0
  4. clickhouse_driver/bufferedreader.cpython-310-aarch64-linux-gnu.so +0 -0
  5. clickhouse_driver/bufferedwriter.cpython-310-aarch64-linux-gnu.so +0 -0
  6. clickhouse_driver/client.py +812 -0
  7. clickhouse_driver/clientinfo.py +119 -0
  8. clickhouse_driver/columns/__init__.py +0 -0
  9. clickhouse_driver/columns/arraycolumn.py +161 -0
  10. clickhouse_driver/columns/base.py +221 -0
  11. clickhouse_driver/columns/boolcolumn.py +7 -0
  12. clickhouse_driver/columns/datecolumn.py +108 -0
  13. clickhouse_driver/columns/datetimecolumn.py +203 -0
  14. clickhouse_driver/columns/decimalcolumn.py +116 -0
  15. clickhouse_driver/columns/enumcolumn.py +129 -0
  16. clickhouse_driver/columns/exceptions.py +12 -0
  17. clickhouse_driver/columns/floatcolumn.py +34 -0
  18. clickhouse_driver/columns/intcolumn.py +157 -0
  19. clickhouse_driver/columns/intervalcolumn.py +33 -0
  20. clickhouse_driver/columns/ipcolumn.py +118 -0
  21. clickhouse_driver/columns/jsoncolumn.py +37 -0
  22. clickhouse_driver/columns/largeint.cpython-310-aarch64-linux-gnu.so +0 -0
  23. clickhouse_driver/columns/lowcardinalitycolumn.py +142 -0
  24. clickhouse_driver/columns/mapcolumn.py +73 -0
  25. clickhouse_driver/columns/nestedcolumn.py +10 -0
  26. clickhouse_driver/columns/nothingcolumn.py +13 -0
  27. clickhouse_driver/columns/nullablecolumn.py +7 -0
  28. clickhouse_driver/columns/nullcolumn.py +15 -0
  29. clickhouse_driver/columns/numpy/__init__.py +0 -0
  30. clickhouse_driver/columns/numpy/base.py +47 -0
  31. clickhouse_driver/columns/numpy/boolcolumn.py +8 -0
  32. clickhouse_driver/columns/numpy/datecolumn.py +19 -0
  33. clickhouse_driver/columns/numpy/datetimecolumn.py +146 -0
  34. clickhouse_driver/columns/numpy/floatcolumn.py +24 -0
  35. clickhouse_driver/columns/numpy/intcolumn.py +43 -0
  36. clickhouse_driver/columns/numpy/lowcardinalitycolumn.py +96 -0
  37. clickhouse_driver/columns/numpy/service.py +58 -0
  38. clickhouse_driver/columns/numpy/stringcolumn.py +78 -0
  39. clickhouse_driver/columns/numpy/tuplecolumn.py +37 -0
  40. clickhouse_driver/columns/service.py +185 -0
  41. clickhouse_driver/columns/simpleaggregatefunctioncolumn.py +7 -0
  42. clickhouse_driver/columns/stringcolumn.py +73 -0
  43. clickhouse_driver/columns/tuplecolumn.py +63 -0
  44. clickhouse_driver/columns/util.py +61 -0
  45. clickhouse_driver/columns/uuidcolumn.py +64 -0
  46. clickhouse_driver/compression/__init__.py +32 -0
  47. clickhouse_driver/compression/base.py +87 -0
  48. clickhouse_driver/compression/lz4.py +21 -0
  49. clickhouse_driver/compression/lz4hc.py +9 -0
  50. clickhouse_driver/compression/zstd.py +20 -0
  51. clickhouse_driver/connection.py +825 -0
  52. clickhouse_driver/context.py +36 -0
  53. clickhouse_driver/dbapi/__init__.py +62 -0
  54. clickhouse_driver/dbapi/connection.py +99 -0
  55. clickhouse_driver/dbapi/cursor.py +370 -0
  56. clickhouse_driver/dbapi/errors.py +40 -0
  57. clickhouse_driver/dbapi/extras.py +73 -0
  58. clickhouse_driver/defines.py +58 -0
  59. clickhouse_driver/errors.py +453 -0
  60. clickhouse_driver/log.py +48 -0
  61. clickhouse_driver/numpy/__init__.py +0 -0
  62. clickhouse_driver/numpy/block.py +8 -0
  63. clickhouse_driver/numpy/helpers.py +28 -0
  64. clickhouse_driver/numpy/result.py +123 -0
  65. clickhouse_driver/opentelemetry.py +43 -0
  66. clickhouse_driver/progress.py +44 -0
  67. clickhouse_driver/protocol.py +130 -0
  68. clickhouse_driver/queryprocessingstage.py +8 -0
  69. clickhouse_driver/reader.py +69 -0
  70. clickhouse_driver/readhelpers.py +26 -0
  71. clickhouse_driver/result.py +144 -0
  72. clickhouse_driver/settings/__init__.py +0 -0
  73. clickhouse_driver/settings/available.py +405 -0
  74. clickhouse_driver/settings/types.py +50 -0
  75. clickhouse_driver/settings/writer.py +34 -0
  76. clickhouse_driver/streams/__init__.py +0 -0
  77. clickhouse_driver/streams/compressed.py +88 -0
  78. clickhouse_driver/streams/native.py +108 -0
  79. clickhouse_driver/util/__init__.py +0 -0
  80. clickhouse_driver/util/compat.py +39 -0
  81. clickhouse_driver/util/escape.py +94 -0
  82. clickhouse_driver/util/helpers.py +173 -0
  83. clickhouse_driver/varint.cpython-310-aarch64-linux-gnu.so +0 -0
  84. clickhouse_driver/writer.py +67 -0
  85. clickhouse_driver-0.2.10.dist-info/METADATA +215 -0
  86. clickhouse_driver-0.2.10.dist-info/RECORD +89 -0
  87. clickhouse_driver-0.2.10.dist-info/WHEEL +7 -0
  88. clickhouse_driver-0.2.10.dist-info/licenses/LICENSE +21 -0
  89. clickhouse_driver-0.2.10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,825 @@
1
+ import logging
2
+ import socket
3
+ import ssl
4
+ from collections import deque
5
+ from contextlib import contextmanager
6
+ from sys import platform
7
+ from time import time
8
+ from urllib.parse import urlparse
9
+
10
+ try:
11
+ import certifi
12
+ except ImportError:
13
+ certifi = None
14
+
15
+ from . import defines
16
+ from . import errors
17
+ from .block import RowOrientedBlock
18
+ from .blockstreamprofileinfo import BlockStreamProfileInfo
19
+ from .bufferedreader import BufferedSocketReader
20
+ from .bufferedwriter import BufferedSocketWriter
21
+ from .clientinfo import ClientInfo
22
+ from .compression import get_compressor_cls
23
+ from .context import Context
24
+ from .log import log_block
25
+ from .progress import Progress
26
+ from .protocol import Compression, ClientPacketTypes, ServerPacketTypes
27
+ from .queryprocessingstage import QueryProcessingStage
28
+ from .reader import read_binary_str, read_binary_uint64
29
+ from .readhelpers import read_exception
30
+ from .settings.writer import write_settings, SettingsFlags
31
+ from .streams.native import BlockInputStream, BlockOutputStream
32
+ from .util.compat import threading
33
+ from .util.escape import escape_params
34
+ from .varint import write_varint, read_varint
35
+ from .writer import write_binary_str
36
+
37
# Module-level logger; the connection lifecycle, handshake, ping and query
# diagnostics emitted by the classes below go through it.
logger = logging.getLogger(__name__)
38
+
39
+
40
class Packet(object):
    """
    Mutable container for a single packet read from the server.

    ``type`` holds the raw packet type. Depending on that type, at most one
    of the payload attributes (``block``, ``exception``, ``progress``,
    ``profile_info``, ``multistring_message``) is filled in by
    ``Connection.receive_packet``; the others stay ``None``.
    """

    def __init__(self):
        # Everything starts empty; the receiver fills in what applies.
        self.type = self.block = self.exception = None
        self.progress = self.profile_info = self.multistring_message = None

        super().__init__()
50
+
51
+
52
class ServerInfo(object):
    """
    Server description assembled from the server's Hello packet.

    :param name: server name.
    :param version_major: major version number.
    :param version_minor: minor version number.
    :param version_patch: patch version number.
    :param revision: protocol revision reported by the server.
    :param timezone: server timezone (may be ``None``).
    :param display_name: server display name.
    :param used_revision: protocol revision actually negotiated.
    """

    def __init__(self, name, version_major, version_minor, version_patch,
                 revision, timezone, display_name, used_revision):
        self.name = name
        self.version_major, self.version_minor, self.version_patch = (
            version_major, version_minor, version_patch
        )
        self.revision = revision
        self.used_revision = used_revision
        self.timezone = timezone
        # Overrides ``timezone`` once the server sends a session timezone.
        self.session_timezone = None
        self.display_name = display_name

        super().__init__()

    def get_timezone(self):
        """Return the session timezone when set (truthy), else the server's."""
        if self.session_timezone:
            return self.session_timezone
        return self.timezone

    def version_tuple(self):
        """Return ``(major, minor, patch)``."""
        return (self.version_major, self.version_minor, self.version_patch)

    def __repr__(self):
        version = '.'.join(
            str(part) for part in (
                self.version_major, self.version_minor, self.version_patch
            )
        )
        fields = (
            ('name', self.name),
            ('version', version),
            ('revision', self.revision),
            ('used revision', self.used_revision),
            ('timezone', self.timezone),
            ('display_name', self.display_name),
        )
        params = ', '.join('%s=%s' % (key, value) for key, value in fields)
        return '<ServerInfo({})>'.format(params)
88
+
89
+
90
class Connection(object):
    """
    Represents connection between client and ClickHouse server.

    :param host: host with running ClickHouse server.
    :param port: port ClickHouse server is bound to.
                 Defaults to ``9000`` if connection is not secured and
                 to ``9440`` if connection is secured.
    :param database: database to connect to. Defaults to ``'default'``.
    :param user: database user. Defaults to ``'default'``.
    :param password: user's password. Defaults to ``''`` (no password).
    :param client_name: this name will appear in server logs.
                        Defaults to ``'python-driver'``.
    :param connect_timeout: timeout for establishing connection.
                            Defaults to ``10`` seconds.
    :param send_receive_timeout: timeout for sending and receiving data.
                                 Defaults to ``300`` seconds.
    :param sync_request_timeout: timeout for server ping.
                                 Defaults to ``5`` seconds.
    :param compress_block_size: size of compressed block to send.
                                Defaults to ``1048576``.
    :param compression: specifies whether or not to use compression.
                        Defaults to ``False``. Possible choices:

                            * ``True`` is equivalent to ``'lz4'``.
                            * ``'lz4'``.
                            * ``'lz4hc'`` high-compression variant of
                              ``'lz4'``.
                            * ``'zstd'``.

    :param secure: establish secure connection. Defaults to ``False``.
    :param verify: specifies whether a certificate is required and whether it
                   will be validated after connection.
                   Defaults to ``True``.
    :param ssl_version: protocol passed to :class:`ssl.SSLContext`.
                        Defaults to :data:`ssl.PROTOCOL_TLS_CLIENT`.
    :param ca_certs: CA bundle passed to
                     :meth:`ssl.SSLContext.load_verify_locations`.
                     Defaults to the ``certifi`` bundle when ``certifi`` is
                     installed; otherwise, when ``verify`` is enabled, the
                     default system certificates are loaded.
    :param ciphers: cipher string passed to
                    :meth:`ssl.SSLContext.set_ciphers`.
    :param keyfile: client private key, see
                    :meth:`ssl.SSLContext.load_cert_chain`.
    :param keypass: password for ``keyfile``, see
                    :meth:`ssl.SSLContext.load_cert_chain`.
    :param certfile: client certificate, see
                     :meth:`ssl.SSLContext.load_cert_chain`. ``keyfile`` and
                     ``keypass`` are only used together with it.
    :param check_hostname: see :attr:`ssl.SSLContext.check_hostname`.
                           Defaults to ``True``. Forced to ``False`` when
                           ``verify`` is ``False``.
    :param server_hostname: Hostname to use in SSL Wrapper construction.
                            Defaults to `None` which will send the passed
                            host param during SSL initialization. This param
                            may be used when connecting over an SSH tunnel
                            to correctly identify the desired server via SNI.
    :param alt_hosts: comma-separated list of alternative hosts for
                      connection. Entries without a port use the default
                      port; empty entries are ignored.
                      Example: alt_hosts=host1:port1,host2:port2.
    :param settings_is_important: ``False`` means unknown settings will be
                                  ignored, ``True`` means that the query will
                                  fail with UNKNOWN_SETTING error.
                                  Defaults to ``False``.
    :param tcp_keepalive: enables `TCP keepalive <https://tldp.org/HOWTO/
                          TCP-Keepalive-HOWTO/overview.html>`_ on established
                          connection. If set to ``True``, system keepalive
                          settings are used. You can also specify custom
                          keepalive setting with tuple:
                          ``(idle_time_sec, interval_sec, probes)``.
                          Defaults to ``False``.
    :param client_revision: can be used for client version downgrading.
                            Values above the driver's own revision are
                            capped to it. Defaults to ``None``.
    :param disable_reconnect: disable automatic reconnect in case of
                              failed ``ping``, helpful when every reconnect
                              needs to be caught in calling code.
                              Defaults to ``False``.
    """

    def __init__(
            self, host, port=None,
            database=defines.DEFAULT_DATABASE,
            user=defines.DEFAULT_USER, password=defines.DEFAULT_PASSWORD,
            client_name=defines.CLIENT_NAME,
            connect_timeout=defines.DBMS_DEFAULT_CONNECT_TIMEOUT_SEC,
            send_receive_timeout=defines.DBMS_DEFAULT_TIMEOUT_SEC,
            sync_request_timeout=defines.DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC,
            compress_block_size=defines.DEFAULT_COMPRESS_BLOCK_SIZE,
            compression=False,
            secure=False,
            # Secure socket parameters.
            verify=True, ssl_version=None, ca_certs=None, ciphers=None,
            keyfile=None, keypass=None, certfile=None, check_hostname=True,
            server_hostname=None,
            alt_hosts=None,
            settings_is_important=False,
            tcp_keepalive=False,
            client_revision=None,
            disable_reconnect=False,
    ):
        if secure:
            default_port = defines.DEFAULT_SECURE_PORT
        else:
            default_port = defines.DEFAULT_PORT

        # Candidate (host, port) pairs. connect() rotates past hosts that
        # fail, so the host that connected last is tried first next time.
        self.hosts = deque([(host, port or default_port)])

        if alt_hosts:
            # Distinct loop name: do not shadow the ``host`` argument.
            for alt_host in alt_hosts.split(','):
                alt_host = alt_host.strip()
                # An empty entry (e.g. trailing comma) would parse to
                # hostname None, which getaddrinfo treats as loopback.
                if not alt_host:
                    continue
                url = urlparse('clickhouse://' + alt_host)
                self.hosts.append((url.hostname, url.port or default_port))

        self.database = database
        self.user = user
        self.password = password
        self.client_name = defines.DBMS_NAME + ' ' + client_name
        self.connect_timeout = connect_timeout
        self.send_receive_timeout = send_receive_timeout
        self.sync_request_timeout = sync_request_timeout
        self.settings_is_important = settings_is_important
        self.tcp_keepalive = tcp_keepalive
        # Never advertise a revision newer than this driver implements.
        self.client_revision = min(
            client_revision or defines.CLIENT_REVISION, defines.CLIENT_REVISION
        )
        self.disable_reconnect = disable_reconnect

        self.secure_socket = secure
        self.verify_cert = verify

        # Prefer the certifi CA bundle when it is installed.
        if certifi is not None:
            ca_certs = ca_certs or certifi.where()

        # Forward only the options that were explicitly provided.
        ssl_options = {}
        if ssl_version is not None:
            ssl_options['ssl_version'] = ssl_version
        if ca_certs is not None:
            ssl_options['ca_certs'] = ca_certs
        if ciphers is not None:
            ssl_options['ciphers'] = ciphers
        if keyfile is not None:
            ssl_options['keyfile'] = keyfile
        if keypass is not None:
            ssl_options['keypass'] = keypass
        if certfile is not None:
            ssl_options['certfile'] = certfile

        self.ssl_options = ssl_options

        # Hostname checking requires certificate verification.
        self.check_hostname = check_hostname if self.verify_cert else False
        self.server_hostname = server_hostname

        # Use LZ4 compression by default.
        if compression is True:
            compression = 'lz4'

        if compression is False:
            self.compression = Compression.DISABLED
            self.compressor_cls = None
            self.compress_block_size = None
        else:
            self.compression = Compression.ENABLED
            self.compressor_cls = get_compressor_cls(compression)
            self.compress_block_size = compress_block_size

        # Address of the connected host: set by _init_connection(), cleared
        # by reset_state(). Initialised here so get_description() works
        # before the first connect.
        self.host = None
        self.port = None

        self.socket = None
        self.fin = None
        self.fout = None

        self.connected = False

        self.client_trace_context = None
        self.server_info = None
        self.context = Context()

        # Block writer/reader
        self.block_in = None
        self.block_out = None
        self.block_in_raw = None  # log blocks are always not compressed

        self._lock = threading.Lock()
        self.is_query_executing = False

        super(Connection, self).__init__()

    def __repr__(self):
        dsn = '%s://%s:***@%s:%s/%s' % (
            'clickhouses' if self.secure_socket else 'clickhouse',
            self.user, self.host, self.port, self.database
        ) if self.connected else '(not connected)'

        return '<Connection(dsn=%s, compression=%s)>' % (dsn, self.compression)

    def get_description(self):
        """Return ``'host:port'`` of the current host (``None:None`` if unset)."""
        return '{}:{}'.format(self.host, self.port)

    def force_connect(self):
        """
        Ensure a usable connection: connect if not connected, otherwise
        ping and reconnect when the ping fails.

        :raises errors.PartiallyConsumedQueryError: a query is still marked
            as executing.
        :raises errors.NetworkError: ping failed and ``disable_reconnect``
            is set.
        """
        self.check_query_execution()

        if not self.connected:
            self.connect()

        elif not self.ping():
            if self.disable_reconnect:
                raise errors.NetworkError(
                    "Connection was closed, reconnect is disabled."
                )

            logger.warning('Connection was closed, reconnecting.')
            self.connect()

    def _create_socket(self, host, port):
        """
        Acts like socket.create_connection, but wraps socket with SSL
        if connection is secure.

        Every address returned by getaddrinfo is tried in order; the last
        socket error is re-raised if none of them connects.
        """
        addresses = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)

        ssl_context = None
        if self.secure_socket:
            ssl_options = self.ssl_options.copy()
            ssl_options['cert_reqs'] = (
                ssl.CERT_REQUIRED if self.verify_cert else ssl.CERT_NONE
            )
            # The context does not depend on the address (and may load a
            # CA bundle), so build it once rather than per address.
            ssl_context = self._create_ssl_context(ssl_options)

        err = None
        for af, socktype, proto, _canonname, sa in addresses:
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)
                sock.settimeout(self.connect_timeout)

                if ssl_context is not None:
                    sock = ssl_context.wrap_socket(
                        sock, server_hostname=self.server_hostname or host)

                sock.connect(sa)
                return sock

            except socket.error as e:
                err = e
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise socket.error("getaddrinfo returns an empty list")

    def _create_ssl_context(self, ssl_options):
        """
        Build an :class:`ssl.SSLContext` from ``ssl_options``.

        Recognised keys: ``ssl_version``, ``ca_certs``, ``ciphers``,
        ``cert_reqs``, ``certfile``, ``keyfile`` and ``keypass``.
        """
        purpose = ssl.Purpose.SERVER_AUTH

        version = ssl_options.get('ssl_version', ssl.PROTOCOL_TLS_CLIENT)
        context = ssl.SSLContext(version)
        # Set before verify_mode: CERT_NONE is rejected while hostname
        # checking is still enabled.
        context.check_hostname = self.check_hostname

        if 'ca_certs' in ssl_options:
            context.load_verify_locations(ssl_options['ca_certs'])
        elif ssl_options.get('cert_reqs') != ssl.CERT_NONE:
            context.load_default_certs(purpose)
        if 'ciphers' in ssl_options:
            context.set_ciphers(ssl_options['ciphers'])

        if 'cert_reqs' in ssl_options:
            context.verify_mode = ssl_options['cert_reqs']

        if 'certfile' in ssl_options:
            keyfile = ssl_options.get('keyfile')
            keypass = ssl_options.get('keypass')
            context.load_cert_chain(
                ssl_options['certfile'],
                keyfile=keyfile, password=keypass
            )

        return context

    def _init_connection(self, host, port):
        """
        Open the socket to ``host:port`` and perform the Hello handshake,
        then set up the block streams.
        """
        self.socket = self._create_socket(host, port)
        self.connected = True
        self.host, self.port = host, port
        self.socket.settimeout(self.send_receive_timeout)

        # performance tweak
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.tcp_keepalive:
            self._set_keepalive()

        self.fin = BufferedSocketReader(self.socket, defines.BUFFER_SIZE)
        self.fout = BufferedSocketWriter(self.socket, defines.BUFFER_SIZE)

        self.send_hello()
        self.receive_hello()

        revision = self.server_info.used_revision
        if revision >= defines.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM:
            self.send_addendum()

        self.block_in = self.get_block_in_stream()
        self.block_in_raw = BlockInputStream(self.fin, self.context)
        self.block_out = self.get_block_out_stream()

    def _set_keepalive(self):
        """
        Enable SO_KEEPALIVE and, when ``tcp_keepalive`` is a
        ``(idle_time_sec, interval_sec, probes)`` tuple, apply those values
        on platforms that support them.
        """
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        if not isinstance(self.tcp_keepalive, tuple):
            return

        idle_time_sec, interval_sec, probes = self.tcp_keepalive

        if platform == 'linux' or platform == 'win32':
            # This should also work for Windows
            # starting with Windows 10, version 1709.
            self.socket.setsockopt(
                socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle_time_sec
            )
            self.socket.setsockopt(
                socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval_sec
            )
            self.socket.setsockopt(
                socket.IPPROTO_TCP, socket.TCP_KEEPCNT, probes
            )

        elif platform == 'darwin':
            # On macOS, TCP_KEEPALIVE (0x10) is the idle time before the
            # first probe (the analogue of Linux TCP_KEEPIDLE), so it takes
            # idle_time_sec, not interval_sec.
            tcp_keepalive_opt = getattr(socket, 'TCP_KEEPALIVE', 0x10)
            self.socket.setsockopt(
                socket.IPPROTO_TCP, tcp_keepalive_opt, idle_time_sec
            )
            # Interval and probe count are only set when this Python build
            # exposes the constants.
            if hasattr(socket, 'TCP_KEEPINTVL'):
                self.socket.setsockopt(
                    socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval_sec
                )
            if hasattr(socket, 'TCP_KEEPCNT'):
                self.socket.setsockopt(
                    socket.IPPROTO_TCP, socket.TCP_KEEPCNT, probes
                )

    def _format_connection_error(self, e, host, port):
        """Return ``'<strerror> (host:port)'`` (strerror omitted if empty)."""
        err = (e.strerror + ' ') if e.strerror else ''
        return err + '({}:{})'.format(host, port)

    def connect(self):
        """
        (Re)connect, trying each configured host in turn.

        On failure a host is rotated to the back of ``self.hosts``; on
        success it stays first.

        :raises errors.SocketTimeoutError: the last host timed out.
        :raises errors.NetworkError: the last host failed with a socket
            error.

        Any other error raised during the handshake is re-raised after the
        connection has been torn down.
        """
        if self.connected:
            self.disconnect()

        logger.debug(
            'Connecting. Database: %s. User: %s', self.database, self.user
        )

        err = None
        for _ in range(len(self.hosts)):
            host, port = self.hosts[0]
            logger.debug('Connecting to %s:%s', host, port)

            try:
                return self._init_connection(host, port)

            except socket.timeout as e:
                self.disconnect()
                logger.warning(
                    'Failed to connect to %s:%s', host, port, exc_info=True
                )
                err_str = self._format_connection_error(e, host, port)
                err = errors.SocketTimeoutError(err_str)

            except socket.error as e:
                self.disconnect()
                logger.warning(
                    'Failed to connect to %s:%s', host, port, exc_info=True
                )
                err_str = self._format_connection_error(e, host, port)
                err = errors.NetworkError(err_str)

            except BaseException:
                # Non-socket handshake failures (e.g. an exception packet
                # in reply to Hello) would otherwise leave the socket open,
                # ``connected`` True and ``is_query_executing`` stuck.
                self.disconnect()
                raise

            self.hosts.rotate(-1)

        if err is not None:
            raise err

    def reset_state(self):
        """Drop all per-connection state (does not close the socket)."""
        self.host = None
        self.port = None
        self.socket = None
        self.fin = None
        self.fout = None

        self.connected = False

        self.client_trace_context = None
        self.server_info = None

        self.block_in = None
        self.block_in_raw = None
        self.block_out = None

        self.is_query_executing = False

    def disconnect(self):
        """
        Closes connection between server and client.
        Frees resources: e.g. closes socket.
        """

        if self.connected:
            # There can be errors on shutdown.
            # We need to close socket and reset state even if it happens.
            try:
                self.socket.shutdown(socket.SHUT_RDWR)

            except socket.error as e:
                logger.warning('Error on socket shutdown: %s', e)

            self.socket.close()

        # Socket can be constructed but not connected.
        elif self.socket:
            self.socket.close()

        self.reset_state()

    def send_hello(self):
        """Send the client Hello packet and flush it."""
        write_varint(ClientPacketTypes.HELLO, self.fout)
        write_binary_str(self.client_name, self.fout)
        write_varint(defines.CLIENT_VERSION_MAJOR, self.fout)
        write_varint(defines.CLIENT_VERSION_MINOR, self.fout)
        # NOTE For backward compatibility of the protocol,
        # client cannot send its version_patch.
        write_varint(self.client_revision, self.fout)
        write_binary_str(self.database, self.fout)
        write_binary_str(self.user, self.fout)
        write_binary_str(self.password, self.fout)

        self.fout.flush()

    def receive_hello(self):
        """
        Read the server Hello and store the resulting ``ServerInfo``.

        Fields are read conditionally on the negotiated revision
        (the smaller of client and server revisions).

        :raises: the server exception if the server replies with one.
        :raises errors.UnexpectedPacketFromServerError: any other packet
            (the connection is closed first).
        """
        packet_type = read_varint(self.fin)

        if packet_type == ServerPacketTypes.HELLO:
            server_name = read_binary_str(self.fin)
            server_version_major = read_varint(self.fin)
            server_version_minor = read_varint(self.fin)
            server_revision = read_varint(self.fin)

            used_revision = min(self.client_revision, server_revision)

            server_timezone = None
            if used_revision >= \
                    defines.DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE:
                server_timezone = read_binary_str(self.fin)

            server_display_name = ''
            if used_revision >= \
                    defines.DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME:
                server_display_name = read_binary_str(self.fin)

            server_version_patch = server_revision
            if used_revision >= \
                    defines.DBMS_MIN_REVISION_WITH_VERSION_PATCH:
                server_version_patch = read_varint(self.fin)

            if used_revision >= defines. \
                    DBMS_MIN_PROTOCOL_VERSION_WITH_PASSWORD_COMPLEXITY_RULES:
                rules_size = read_varint(self.fin)
                for _i in range(rules_size):
                    read_binary_str(self.fin)  # original_pattern
                    read_binary_str(self.fin)  # exception_message

            if used_revision >= defines. \
                    DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2:
                read_binary_uint64(self.fin)  # read_nonce

            self.server_info = ServerInfo(
                server_name, server_version_major, server_version_minor,
                server_version_patch, server_revision,
                server_timezone, server_display_name, used_revision
            )
            self.context.server_info = self.server_info

            logger.debug(
                'Connected to %s server version %s.%s.%s, revision: %s',
                server_name, server_version_major, server_version_minor,
                server_version_patch, server_revision
            )

        elif packet_type == ServerPacketTypes.EXCEPTION:
            raise self.receive_exception()

        else:
            message = self.unexpected_packet_message('Hello or Exception',
                                                     packet_type)
            self.disconnect()
            raise errors.UnexpectedPacketFromServerError(message)

    def send_addendum(self):
        """Write the post-Hello addendum (the quota key), unflushed."""
        revision = self.server_info.used_revision

        if revision >= defines.DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY:
            write_binary_str(
                self.context.client_settings['quota_key'], self.fout
            )

    def ping(self):
        """
        Send PING and wait for PONG, skipping PROGRESS packets.

        :return: ``True`` on PONG, ``False`` on a socket error or EOF,
            ``None`` when there is no socket.
        :raises errors.UnexpectedPacketFromServerError: any other packet.
        """
        if not self.socket:
            return None

        timeout = self.sync_request_timeout

        with self.timeout_setter(timeout):
            try:
                write_varint(ClientPacketTypes.PING, self.fout)
                self.fout.flush()

                packet_type = read_varint(self.fin)
                while packet_type == ServerPacketTypes.PROGRESS:
                    self.receive_progress()
                    packet_type = read_varint(self.fin)

                if packet_type != ServerPacketTypes.PONG:
                    msg = self.unexpected_packet_message('Pong', packet_type)
                    raise errors.UnexpectedPacketFromServerError(msg)

            except errors.Error:
                # Driver errors propagate unchanged.
                raise

            except (socket.error, EOFError) as e:
                # It's just a warning now.
                # Current connection will be closed, new will be established.
                logger.warning(
                    'Error on %s ping: %s', self.get_description(), e
                )
                return False

        return True

    def receive_packet(self):
        """
        Read one packet from the server and decode its payload by type.

        :raises errors.UnknownPacketFromServerError: unrecognised packet
            type (the connection is closed first).
        """
        packet = Packet()

        packet.type = packet_type = read_varint(self.fin)

        if packet_type == ServerPacketTypes.DATA:
            packet.block = self.receive_data(may_be_use_numpy=True)

        elif packet_type == ServerPacketTypes.EXCEPTION:
            packet.exception = self.receive_exception()

        elif packet_type == ServerPacketTypes.PROGRESS:
            packet.progress = self.receive_progress()

        elif packet_type == ServerPacketTypes.PROFILE_INFO:
            packet.profile_info = self.receive_profile_info()

        elif packet_type == ServerPacketTypes.TOTALS:
            packet.block = self.receive_data()

        elif packet_type == ServerPacketTypes.EXTREMES:
            packet.block = self.receive_data()

        elif packet_type == ServerPacketTypes.LOG:
            packet.block = self.receive_data(may_be_compressed=False)
            log_block(packet.block)

        elif packet_type == ServerPacketTypes.END_OF_STREAM:
            self.is_query_executing = False

        elif packet_type == ServerPacketTypes.TABLE_COLUMNS:
            packet.multistring_message = self.receive_multistring_message(
                packet_type
            )

        elif packet_type == ServerPacketTypes.PART_UUIDS:
            packet.block = self.receive_data()

        elif packet_type == ServerPacketTypes.READ_TASK_REQUEST:
            packet.block = self.receive_data()

        elif packet_type == ServerPacketTypes.PROFILE_EVENTS:
            packet.block = self.receive_data(may_be_compressed=False)

        elif packet_type == ServerPacketTypes.TIMEZONE_UPDATE:
            timezone = read_binary_str(self.fin)
            if timezone:
                logger.info('Server timezone changed to %s', timezone)
                self.server_info.session_timezone = timezone

        else:
            message = 'Unknown packet {} from server {}'.format(
                packet_type, self.get_description()
            )
            self.disconnect()
            raise errors.UnknownPacketFromServerError(message)

        return packet

    def get_block_in_stream(self):
        """Return a (compressed if enabled) block input stream on ``fin``."""
        if self.compression:
            from .streams.compressed import CompressedBlockInputStream

            return CompressedBlockInputStream(self.fin, self.context)
        else:
            return BlockInputStream(self.fin, self.context)

    def get_block_out_stream(self):
        """Return a (compressed if enabled) block output stream on ``fout``."""
        if self.compression:
            from .streams.compressed import CompressedBlockOutputStream

            return CompressedBlockOutputStream(
                self.compressor_cls, self.compress_block_size,
                self.fout, self.context
            )
        else:
            return BlockOutputStream(self.fout, self.context)

    def receive_data(self, may_be_compressed=True, may_be_use_numpy=False):
        """
        Read one data block.

        :param may_be_compressed: read through the (possibly compressed)
            main stream; ``False`` uses the raw stream.
        :param may_be_use_numpy: when ``False``, numpy decoding is disabled;
            otherwise ``use_numpy=None`` is passed to the reader (presumably
            deferring to settings; verify in the stream implementation).
        """
        revision = self.server_info.used_revision

        if revision >= defines.DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES:
            read_binary_str(self.fin)  # table name, unused

        reader = self.block_in if may_be_compressed else self.block_in_raw
        use_numpy = False if not may_be_use_numpy else None
        return reader.read(use_numpy=use_numpy)

    def receive_exception(self):
        """Read and return (not raise) a server exception."""
        return read_exception(self.fin)

    def receive_progress(self):
        """Read and return a ``Progress`` packet payload."""
        progress = Progress()
        progress.read(self.server_info, self.fin)
        return progress

    def receive_profile_info(self):
        """Read and return a ``BlockStreamProfileInfo`` payload."""
        profile_info = BlockStreamProfileInfo()
        profile_info.read(self.fin)
        return profile_info

    def receive_multistring_message(self, packet_type):
        """Read the fixed number of strings carried by ``packet_type``."""
        num = ServerPacketTypes.strings_in_message(packet_type)
        return [read_binary_str(self.fin) for _i in range(num)]

    def send_data(self, block, table_name=''):
        """Send a DATA packet carrying ``block`` for ``table_name``."""
        start = time()
        write_varint(ClientPacketTypes.DATA, self.fout)

        revision = self.server_info.used_revision
        if revision >= defines.DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES:
            write_binary_str(table_name, self.fout)

        self.block_out.write(block)
        logger.debug('Block "%s" send time: %f', table_name, time() - start)

    def send_query(self, query, query_id=None, params=None):
        """
        Send a QUERY packet for ``query``, connecting first if needed.

        :param query_id: optional query id (empty string when omitted).
        :param params: query parameters, sent to the server only when the
            ``server_side_params`` client setting is enabled.
        """
        if not self.connected:
            self.connect()

        write_varint(ClientPacketTypes.QUERY, self.fout)

        write_binary_str(query_id or '', self.fout)

        revision = self.server_info.used_revision
        if revision >= defines.DBMS_MIN_REVISION_WITH_CLIENT_INFO:
            client_info = ClientInfo(self.client_name, self.context,
                                     client_revision=self.client_revision)
            client_info.query_kind = ClientInfo.QueryKind.INITIAL_QUERY

            client_info.write(revision, self.fout)

        settings_as_strings = (
            revision >= defines
            .DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS
        )
        settings_flags = 0
        if self.settings_is_important:
            settings_flags |= SettingsFlags.IMPORTANT
        write_settings(self.context.settings, self.fout, settings_as_strings,
                       settings_flags)

        if revision >= defines.DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET:
            write_binary_str('', self.fout)

        write_varint(QueryProcessingStage.COMPLETE, self.fout)
        write_varint(self.compression, self.fout)

        write_binary_str(query, self.fout)

        if revision >= defines.DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS:
            if self.context.client_settings['server_side_params']:
                # Always settings_as_strings = True
                escaped = escape_params(
                    params or {}, self.context, for_server=True
                )
            else:
                escaped = {}
            write_settings(escaped, self.fout, True, SettingsFlags.CUSTOM)

        logger.debug('Query: %s', query)

        self.fout.flush()

    def send_cancel(self):
        """Send a CANCEL packet and flush it."""
        write_varint(ClientPacketTypes.CANCEL, self.fout)

        self.fout.flush()

    def send_external_tables(self, tables, types_check=False):
        """
        Send each external table as a data block, followed by an empty block
        that marks the end of data.

        :raises ValueError: a table has an empty ``structure``.
        """
        for table in tables or []:
            if not table['structure']:
                raise ValueError(
                    'Empty table "{}" structure'.format(table['name'])
                )

            data = table['data']
            block_cls = RowOrientedBlock

            if self.context.client_settings['use_numpy']:
                from .numpy.block import NumpyColumnOrientedBlock

                columns = [x[0] for x in table['structure']]
                data = [data[column].values for column in columns]

                block_cls = NumpyColumnOrientedBlock

            block = block_cls(table['structure'], data,
                              types_check=types_check)
            self.send_data(block, table_name=table['name'])

        # Empty block, end of data transfer.
        self.send_data(RowOrientedBlock())

    @contextmanager
    def timeout_setter(self, new_timeout):
        """
        Temporarily set the socket timeout to ``new_timeout``.

        The previous timeout is restored even if the body raises, unless the
        body replaced or dropped the socket (e.g. via ``disconnect()``).
        """
        sock = self.socket
        old_timeout = sock.gettimeout()
        sock.settimeout(new_timeout)

        try:
            yield
        finally:
            if self.socket is sock:
                sock.settimeout(old_timeout)

    def unexpected_packet_message(self, expected, packet_type):
        """Build the error message for an unexpected server packet."""
        packet_type = ServerPacketTypes.to_str(packet_type)

        return (
            'Unexpected packet from server {} (expected {}, got {})'
            .format(self.get_description(), expected, packet_type)
        )

    def check_query_execution(self):
        """
        Mark the connection as executing a query.

        The flag is cleared on END_OF_STREAM and by ``reset_state()``.

        :raises errors.PartiallyConsumedQueryError: a query is already
            marked as executing.
        """
        # The lock is held only for the check-and-set, so a blocking acquire
        # is cheap. ``with`` guarantees release on the raise path and never
        # releases a lock this thread does not own.
        with self._lock:
            if self.is_query_executing:
                raise errors.PartiallyConsumedQueryError()

            self.is_query_executing = True