voltdbclient 14.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
voltdbclient.py
ADDED
|
@@ -0,0 +1,2021 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# This file is part of VoltDB.
|
|
3
|
+
# Copyright (C) 2008-2025 Volt Active Data Inc.
|
|
4
|
+
#
|
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
|
6
|
+
# it under the terms of the GNU Affero General Public License as
|
|
7
|
+
# published by the Free Software Foundation, either version 3 of the
|
|
8
|
+
# License, or (at your option) any later version.
|
|
9
|
+
#
|
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
13
|
+
# GNU Affero General Public License for more details.
|
|
14
|
+
#
|
|
15
|
+
# You should have received a copy of the GNU Affero General Public License
|
|
16
|
+
# along with VoltDB. If not, see <http://www.gnu.org/licenses/>.
|
|
17
|
+
|
|
18
|
+
import sys
|
|
19
|
+
if sys.version_info < (3, 6):
    # Python 3.6 remains the floor so that this API is still usable on
    # older systems; some newer features may require 3.9 or later, and
    # applications built on the API are free to demand a newer interpreter.
    raise Exception("Python version 3.6 or greater is required (3.9+ is preferred).")
|
|
25
|
+
|
|
26
|
+
import array
|
|
27
|
+
import atexit
|
|
28
|
+
import socket
|
|
29
|
+
import base64, textwrap
|
|
30
|
+
import struct
|
|
31
|
+
import datetime
|
|
32
|
+
import decimal
|
|
33
|
+
import hashlib
|
|
34
|
+
import re
|
|
35
|
+
import math
|
|
36
|
+
import os
|
|
37
|
+
import stat
|
|
38
|
+
import time
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
import ssl
|
|
42
|
+
ssl_available = True
|
|
43
|
+
except ImportError as e:
|
|
44
|
+
ssl_available = False
|
|
45
|
+
ssl_exception = e
|
|
46
|
+
|
|
47
|
+
try:
|
|
48
|
+
import jks
|
|
49
|
+
pyjks_available = True
|
|
50
|
+
except ImportError as e:
|
|
51
|
+
pyjks_available = False
|
|
52
|
+
pyjks_exception = e
|
|
53
|
+
|
|
54
|
+
try:
|
|
55
|
+
from cryptography.hazmat.primitives.serialization import pkcs12, Encoding, PrivateFormat, NoEncryption
|
|
56
|
+
pkcs12_available = True
|
|
57
|
+
except ImportError as e:
|
|
58
|
+
pkcs12_available = False
|
|
59
|
+
pkcs12_exception = e
|
|
60
|
+
|
|
61
|
+
try:
|
|
62
|
+
import gssapi
|
|
63
|
+
kerberos_available = True
|
|
64
|
+
except ImportError as e:
|
|
65
|
+
kerberos_available = False
|
|
66
|
+
kerberos_exception = e
|
|
67
|
+
|
|
68
|
+
# Module-wide logger used by error(); None means messages are printed.
logger = None


def use_logging():
    """Send subsequent error() output to the root logger instead of stdout."""
    global logger
    import logging
    logger = logging.getLogger()


def error(text):
    """Report *text* through the configured logger, or print it if none is set."""
    report = logger.error if logger else print
    report(text)
|
|
80
|
+
|
|
81
|
+
# Raise the decimal context precision to 38 significant digits at import time.
# NOTE(review): decimal.getcontext() returns the *current thread's* context, so
# this only affects the importing thread; threads started later begin from
# decimal.DefaultContext instead. Presumably 38 is chosen to hold VoltDB
# DECIMAL values without rounding — confirm against readDecimal/writeDecimal.
decimal.getcontext().prec = 38
|
|
82
|
+
|
|
83
|
+
def int16toBytes(val):
    """Return the two low-order bytes of *val* as ints, most significant first."""
    return [(val >> shift) & 0xff for shift in (8, 0)]
|
|
86
|
+
|
|
87
|
+
def int32toBytes(val):
    """Return the four low-order bytes of *val* as ints, most significant first."""
    return [(val >> shift) & 0xff for shift in range(24, -1, -8)]
|
|
92
|
+
|
|
93
|
+
def int64toBytes(val):
    """Return the eight low-order bytes of *val* as ints, most significant first."""
    return [(val >> shift) & 0xff for shift in range(56, -1, -8)]
|
|
102
|
+
|
|
103
|
+
def isNaN(d):
    """
    Return True if *d* is NaN or an infinity, False otherwise.

    Despite the name, +Inf and -Inf also return True. None returns False.

    :param d: a number accepted by the math module, or None
    """
    # Per IEEE 754 'NaN == NaN' is false, so equality cannot be used;
    # math.isfinite() is False exactly for NaN and +/-Inf.
    if d is None:
        return False
    return not math.isfinite(d)
|
|
110
|
+
|
|
111
|
+
class ReadBuffer(object):
    """
    Read buffer management class.

    Incoming bytes are appended at the end and consumed from the front by
    advancing a read offset. Consumed bytes are kept until clear() is called.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Drop all buffered bytes and reset the read offset to zero."""
        self._buf, self._off = b'', 0

    def buffer_length(self):
        """Total number of bytes held, including bytes already consumed."""
        return len(self._buf)

    def remaining(self):
        """Number of bytes after the read offset."""
        return len(self._buf) - self._off

    def get_buffer(self):
        """Return the underlying bytes object, consumed bytes included."""
        return self._buf

    def append(self, content):
        """Add *content* to the end of the buffer."""
        self._buf = self._buf + content

    def shift(self, size):
        """Advance the read offset by *size* bytes."""
        self._off = self._off + size

    def read(self, size):
        """Return up to *size* bytes at the read offset without consuming them."""
        start = self._off
        return self._buf[start:start + size]

    def unpack(self, format, size):
        """
        Unpack *format* at the read offset, then advance the offset by *size*.

        *size* is trusted to match struct.calcsize(format). On a struct.error
        the failure is reported via error() and the exception re-raised.
        """
        try:
            values = struct.unpack_from(format, self._buf, self._off)
        except struct.error as exc:
            error('Exception unpacking %d bytes using format "%s": %s' % (size, format, str(exc)))
            raise
        self.shift(size)
        return values
|
|
148
|
+
|
|
149
|
+
# Per-process tag (hex microsecond timestamp taken at import) appended to
# scratch file names.
unique_tag = '%x' % int(time.time() * 1000000)
# Directory for converted key/certificate files; chosen on first use.
scratch_dir = None
# Paths of scratch files to delete when the interpreter exits.
temporary_files = []


def remove_temporary_files():
    """
    Best-effort removal of every file recorded in temporary_files.

    Files that are already gone or cannot be removed are skipped. Only
    OSError is ignored (previously a bare except), so programming errors
    and interrupts still surface. The list is reset afterwards.
    """
    global temporary_files
    for tf in temporary_files:
        try:
            os.unlink(tf)
        except OSError:
            pass
    temporary_files = []


atexit.register(remove_temporary_files)
|
|
163
|
+
|
|
164
|
+
class FastSerializer:
|
|
165
|
+
"Primitive type de/serialization in VoltDB formats"
|
|
166
|
+
|
|
167
|
+
# struct-module byte-order prefixes; __compileStructs builds its format
# strings from the chosen prefix, so byte order is explicit per format.
LITTLE_ENDIAN = '<'
BIG_ENDIAN = '>'

ARRAY = -99

# VoltType enumerations (type codes)
VOLTTYPE_NULL = 1
VOLTTYPE_TINYINT = 3             # int8
VOLTTYPE_SMALLINT = 4            # int16
VOLTTYPE_INTEGER = 5             # int32
VOLTTYPE_BIGINT = 6              # int64
VOLTTYPE_FLOAT = 8               # float64
VOLTTYPE_STRING = 9
VOLTTYPE_TIMESTAMP = 11          # 8 byte long
VOLTTYPE_MONEY = 20              # 8 byte long
VOLTTYPE_VOLTTABLE = 21
VOLTTYPE_DECIMAL = 22            # fixed precision decimal
VOLTTYPE_VARBINARY = 25
VOLTTYPE_GEOGRAPHY_POINT = 26
VOLTTYPE_GEOGRAPHY = 27

# SQL NULL sentinels. Integer types use the most negative value of their
# width (-2**(bits-1)); string/varbinary use a length of -1; DECIMAL uses
# the most negative 128-bit value; FLOAT uses -1.7E308.
NULL_STRING_INDICATOR = -1
NULL_TINYINT_INDICATOR = -2 ** 7
NULL_SMALLINT_INDICATOR = -2 ** 15
NULL_INTEGER_INDICATOR = -2 ** 31
NULL_BIGINT_INDICATOR = -2 ** 63
NULL_DECIMAL_INDICATOR = -2 ** 127
NULL_FLOAT_INDICATOR = -1.7E308

# default decimal scale
DEFAULT_DECIMAL_SCALE = 12

# protocol constants
AUTH_HANDSHAKE_VERSION = 2
AUTH_SERVICE_NAME = 4
AUTH_HANDSHAKE = 5

# procedure call result codes
PROC_OK = 0
|
|
212
|
+
|
|
213
|
+
# Baseline TLS settings: no client key/cert, no server certificate check.
# Empty when the ssl module could not be imported.
DEFAULT_SSL_CONFIG = dict(keyfile=None,
                          keypass=None,
                          certfile=None,
                          cert_reqs=ssl.CERT_NONE,
                          ca_certs=None,
                          do_handshake_on_connect=True) if ssl_available else {}
|
|
225
|
+
|
|
226
|
+
def __init__(self, host = None,
             port = 21212,
             usessl = False,
             username = "",
             password = "",
             kerberos = False,
             dump_file_path = None,
             connect_timeout = 8,
             procedure_timeout = None,
             default_timeout = None,
             ssl_config_file = None,
             default_cacerts = True):
    """
    Set up the serializer and, when both *host* and *port* are given,
    connect (optionally over TLS) and authenticate.

    :param host: host string for connection or None
    :param port: port for connection or None
    :param usessl: switch for use ssl or not
    :param username: authentication user name for connection or None
    :param password: authentication password for connection or None
    :param kerberos: use Kerberos authentication
    :param dump_file_path: path to optional dump file or None
    :param connect_timeout: timeout (secs) or None for authentication (default=8)
    :param procedure_timeout: timeout (secs) or None for procedure calls (default=None)
    :param default_timeout: default timeout (secs) or None for all other operations (default=None)
    :param ssl_config_file: config file that defines java keystore and truststore files
    :param default_cacerts: if true, use installation default cacerts when truststore unspecified
    :raises RuntimeError: Kerberos requested but GSSAPI missing or no ticket

    If wrapping, connecting or authenticating fails, the socket opened here
    is closed before the exception propagates.
    """
    self.wbuf = array.array('B')
    self.host = host
    self.port = port
    self.usessl = usessl
    # kerberos=None is treated the same as kerberos=False
    self.usekerberos = False if kerberos is None else kerberos
    self.kerberosprincipal = None
    # Private copy: __wrap_socket and the keystore helpers mutate
    # self.ssl_config in place. Aliasing the class-level dict let one
    # connection's certificates and cert_reqs leak into later instances.
    self.ssl_config = dict(self.DEFAULT_SSL_CONFIG)
    self.ssl_config_file = ssl_config_file
    self.default_cacerts = default_cacerts and usessl
    if dump_file_path is not None:
        self.dump_file = open(dump_file_path, "wb")
    else:
        self.dump_file = None
    self.default_timeout = default_timeout
    self.procedure_timeout = procedure_timeout

    self.socket = None
    if self.host is not None and self.port is not None:
        ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG)[0]
        # ai = (family, socktype, proto, canonname, sockaddr)
        ss = socket.socket(ai[0], ai[1], ai[2])
        try:
            if self.usessl:
                if not ssl_available:
                    error("ERROR: To use SSL functionality please install the Python ssl module.")
                    raise ssl_exception
                self.socket = self.__wrap_socket(ss)
            else:
                self.socket = ss
            self.socket.setblocking(1)
            # IPPROTO_TCP is the portable option level for TCP_NODELAY
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                self.socket.connect(ai[4])
            except Exception:
                error("ERROR: Failed to connect to %s port %s" % (ai[4][0], ai[4][1]))
                raise
        except BaseException:
            # Do not leak the descriptor when wrapping or connecting fails.
            (ss if self.socket is None else self.socket).close()
            raise

    # input can be big or little endian
    self.inputBOM = self.BIG_ENDIAN # byte order if input stream
    self.localBOM = self.LITTLE_ENDIAN # byte order of host

    # Type to reader/writer mappings
    self.READER = {self.VOLTTYPE_NULL: self.readNull,
                   self.VOLTTYPE_TINYINT: self.readByte,
                   self.VOLTTYPE_SMALLINT: self.readInt16,
                   self.VOLTTYPE_INTEGER: self.readInt32,
                   self.VOLTTYPE_BIGINT: self.readInt64,
                   self.VOLTTYPE_FLOAT: self.readFloat64,
                   self.VOLTTYPE_STRING: self.readString,
                   self.VOLTTYPE_VARBINARY: self.readVarbinary,
                   self.VOLTTYPE_TIMESTAMP: self.readDate,
                   self.VOLTTYPE_DECIMAL: self.readDecimal,
                   self.VOLTTYPE_GEOGRAPHY_POINT: self.readGeographyPoint,
                   self.VOLTTYPE_GEOGRAPHY: self.readGeography}
    self.WRITER = {self.VOLTTYPE_NULL: self.writeNull,
                   self.VOLTTYPE_TINYINT: self.writeByte,
                   self.VOLTTYPE_SMALLINT: self.writeInt16,
                   self.VOLTTYPE_INTEGER: self.writeInt32,
                   self.VOLTTYPE_BIGINT: self.writeInt64,
                   self.VOLTTYPE_FLOAT: self.writeFloat64,
                   self.VOLTTYPE_STRING: self.writeString,
                   self.VOLTTYPE_VARBINARY: self.writeVarbinary,
                   self.VOLTTYPE_TIMESTAMP: self.writeDate,
                   self.VOLTTYPE_DECIMAL: self.writeDecimal,
                   self.VOLTTYPE_GEOGRAPHY_POINT: self.writeGeographyPoint,
                   self.VOLTTYPE_GEOGRAPHY: self.writeGeography}
    self.ARRAY_READER = {self.VOLTTYPE_TINYINT: self.readByteArray,
                         self.VOLTTYPE_SMALLINT: self.readInt16Array,
                         self.VOLTTYPE_INTEGER: self.readInt32Array,
                         self.VOLTTYPE_BIGINT: self.readInt64Array,
                         self.VOLTTYPE_FLOAT: self.readFloat64Array,
                         self.VOLTTYPE_STRING: self.readStringArray,
                         self.VOLTTYPE_TIMESTAMP: self.readDateArray,
                         self.VOLTTYPE_DECIMAL: self.readDecimalArray,
                         self.VOLTTYPE_GEOGRAPHY_POINT: self.readGeographyPointArray,
                         self.VOLTTYPE_GEOGRAPHY: self.readGeographyArray}

    self.__compileStructs()

    # Check if the value of a given type is NULL
    self.NULL_DECIMAL_INDICATOR = \
        self.__intToBytes(self.__class__.NULL_DECIMAL_INDICATOR, 0)
    self.NullCheck = {self.VOLTTYPE_NULL:
                      lambda x: None,
                      self.VOLTTYPE_TINYINT:
                      lambda x: None if x == self.__class__.NULL_TINYINT_INDICATOR else x,
                      self.VOLTTYPE_SMALLINT:
                      lambda x: None if x == self.__class__.NULL_SMALLINT_INDICATOR else x,
                      self.VOLTTYPE_INTEGER:
                      lambda x: None if x == self.__class__.NULL_INTEGER_INDICATOR else x,
                      self.VOLTTYPE_BIGINT:
                      lambda x: None if x == self.__class__.NULL_BIGINT_INDICATOR else x,
                      self.VOLTTYPE_FLOAT:
                      lambda x: None if abs(x - self.__class__.NULL_FLOAT_INDICATOR) < 1e307 else x,
                      self.VOLTTYPE_STRING:
                      lambda x: None if x == self.__class__.NULL_STRING_INDICATOR else x,
                      self.VOLTTYPE_VARBINARY:
                      lambda x: None if x == self.__class__.NULL_STRING_INDICATOR else x,
                      self.VOLTTYPE_DECIMAL:
                      lambda x: None if x == self.NULL_DECIMAL_INDICATOR else x}

    self.read_buffer = ReadBuffer()

    try:
        if self.usekerberos:
            if not kerberos_available:
                raise RuntimeError("Requested Kerberos authentication but unable to import the GSSAPI package.")
            if not self.has_ticket():
                raise RuntimeError("Requested Kerberos authentication but no valid ticket found. Authenticate with Kerberos first.")
            assert self.socket is not None
            self.socket.settimeout(connect_timeout)
            self.authenticate(str(self.kerberosprincipal), "")
        elif username is not None and password is not None and host is not None:
            assert self.socket is not None
            self.socket.settimeout(connect_timeout)
            self.authenticate(username, password)
    except BaseException:
        # A failed constructor is never returned, so release the connection.
        if self.socket is not None:
            self.socket.close()
        raise

    if self.socket:
        self.socket.settimeout(self.default_timeout)
|
|
376
|
+
|
|
377
|
+
# Front end to SSL socket support.
|
|
378
|
+
#
|
|
379
|
+
# The SSL config file can be one of:
|
|
380
|
+
# - a properties file (see below)
|
|
381
|
+
# - a truststore to verify server identity, in Java keystore format
|
|
382
|
+
# - a certificate chain to verify server identity, in PEM format
|
|
383
|
+
#
|
|
384
|
+
# "Java keystore" is either the traditional JKS keystore as created
|
|
385
|
+
# by the keytool program, or else a PKCS12 file, which is now the
|
|
386
|
+
# preferred output from keytool. Both are binary encodings.
|
|
387
|
+
#
|
|
388
|
+
# We initially look at the first few bytes of the file to decide
|
|
389
|
+
# what sort of file we are dealing with.
|
|
390
|
+
#
|
|
391
|
+
# A properties file contains a sequence of key=value lines which
|
|
392
|
+
# are used to provide arguments to the SSL context methods:
|
|
393
|
+
# Key Value Used in call to
|
|
394
|
+
# ------------------ ------------------------ ------------------
|
|
395
|
+
# keystore path to keystore load_cert_chain
|
|
396
|
+
# keystorepassword password for keystore --
|
|
397
|
+
# truststore path to truststore load_verify_locations
|
|
398
|
+
# truststorepassword password for truststore --
|
|
399
|
+
# cacerts path to PEM cert chain load_verify_locations
|
|
400
|
+
# ssl_version ignored --
|
|
401
|
+
#
|
|
402
|
+
# Thus keystore identifies the client (needed if mutual authentication
|
|
403
|
+
# is required by the server side), whereas truststore and cacerts
|
|
404
|
+
# identify the server. If truststore and cacerts are both specified,
|
|
405
|
+
# cacerts takes precedence.
|
|
406
|
+
#
|
|
407
|
+
# Keystore can only be set via the properties file, and it can
|
|
408
|
+
# be a Java keystore with a single private key entry, or a PEM
|
|
409
|
+
# file containing a PKCS#8 private key and a chain of X.509
|
|
410
|
+
# certificates. Truststore can be set directly as the SSL
|
|
411
|
+
# config file, or via the properties file. Either way it can
|
|
412
|
+
# be a Java keystore or a PEM file, containing a chain of
|
|
413
|
+
# X.509 certificates.
|
|
414
|
+
#
|
|
415
|
+
# If the ssl_config_file name is absent or empty, then:
|
|
416
|
+
# - if default_cacerts is true, certificate checks will use the
|
|
417
|
+
# installation default cacerts
|
|
418
|
+
# - if default_cacerts is false, no certificate checks will be
|
|
419
|
+
# done; we will blindly accept the server's cert
|
|
420
|
+
|
|
421
|
+
def __wrap_socket(self, ss):
    """
    Wrap the not-yet-connected socket *ss* in a TLS socket.

    Settings start from self.ssl_config and are extended by the file named in
    self.ssl_config_file, if any. Non-PEM key/trust stores are first converted
    to PEM files. The TLS_ENABLED_PROTOCOLS (Python 3.7+) and
    TLS_PREFERRED_CIPHERS environment variables can restrict the handshake.

    :param ss: plain socket to wrap
    :returns: the wrapped ssl.SSLSocket
    :raises RuntimeError: inconsistent or unusable keystore/truststore/cacerts
    """
    parsed = self.__process_ssl_config_file() if self.ssl_config_file else {}

    # Identify the formats of whichever stores were named.
    ks_kind = self.__classify(parsed['keystore']) if parsed.get('keystore') else None
    ts_kind = self.__classify(parsed['truststore']) if parsed.get('truststore') else None

    kind = ks_kind or ts_kind
    if kind:
        if ks_kind and ts_kind and ks_kind != ts_kind:
            raise RuntimeError("keystore file and truststore file must have same format")
        if kind == 'prop':
            raise RuntimeError ("keystore or truststore path from property file is a property file")
        converters = {
            'pem': self.__set_up_pem_files,
            'jks': self.__convert_jks_files,
            'bin': self.__convert_pkcs12_files,  # binary stores assumed to be PKCS12
        }
        convert = converters.get(kind)
        if convert is None:
            raise RuntimeError("internal error, unknown store type '%s'" % kind)
        convert(parsed)

    cfg = self.ssl_config

    # An explicit cacerts PEM file supersedes any truststore set above.
    cacerts = parsed.get('cacerts')
    if cacerts:
        if self.__classify(cacerts) != 'pem':
            raise RuntimeError("cacerts file %s is not PEM format" % cacerts)
        cfg['ca_certs'] = cacerts
        cfg['cert_reqs'] = ssl.CERT_REQUIRED

    if self.default_cacerts:
        cfg['cert_reqs'] = ssl.CERT_REQUIRED

    # Translate the settings into an SSLContext.
    context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS)
    context.verify_mode = cfg['cert_reqs']

    if cfg['certfile']:
        context.load_cert_chain(certfile=cfg['certfile'],
                                keyfile=cfg['keyfile'],
                                password=cfg['keypass'])

    if cfg['cert_reqs'] != ssl.CERT_NONE:
        if cfg['ca_certs']:
            context.load_verify_locations(cafile=cfg['ca_certs'])
        else:
            context.load_default_certs()

    # minimum_version / maximum_version need Python 3.7+
    if sys.hexversion >= 0x03070000:
        wanted = os.getenv('TLS_ENABLED_PROTOCOLS')
        if wanted:
            lowest, highest = self.__select_protocols(wanted)
            if lowest and highest:
                context.minimum_version = lowest
                context.maximum_version = highest
            else:
                print("TLS_ENABLED_PROTOCOLS ignored: no supported versions found")

    suites = os.getenv('TLS_PREFERRED_CIPHERS')
    if suites:
        try:
            context.set_ciphers(self.__select_ciphers(context, suites))
        except ssl.SSLError as ex:
            print("TLS_PREFERRED_CIPHERS ignored: %s" % ex)

    return context.wrap_socket(ss)
|
|
495
|
+
|
|
496
|
+
def __process_ssl_config_file(self):
    """
    Interpret self.ssl_config_file (after ~ and $VAR expansion) as a
    configuration dictionary.

    A properties file is parsed and its paths resolved. A PEM file becomes
    {'cacerts': path}. A JKS or other binary store becomes
    {'truststore': path}.

    :raises RuntimeError: for an unrecognized classification
    """
    path = os.path.expandvars(os.path.expanduser(self.ssl_config_file))
    kind = self.__classify(path)
    if kind == 'prop':
        return resolve_paths(read_properties_file(path, True))
    if kind == 'pem':
        return {'cacerts': path}
    if kind in ('jks', 'bin'):
        return {'truststore': path}
    raise RuntimeError('Unexpected file classification: %s' % kind)
|
|
508
|
+
|
|
509
|
+
def __classify(self, file):
    """
    Guess the format of *file* from its first 64 bytes.

    :returns: one of
        'jks'  - begins with the JKS magic number 0xFEEDFEED
        'bin'  - contains a C0/C1 control byte other than CR, LF or TAB
                 (callers treat this as PKCS12)
        'pem'  - begins with '-----BEGIN'
        'prop' - anything else, including files shorter than 64 bytes
    """
    sample_size = 64
    with open(file, mode='rb') as f:
        buff = f.read(sample_size)
    if len(buff) == sample_size:
        if buff[:4] == b'\xfe\xed\xfe\xed':  # JKS magic
            return 'jks'
        for b in buff:  # binary if contains C0/C1 control except CR, LF, tab
            if (b & 0x7f) < 0x20 and b not in (0x0d, 0x0a, 0x09):
                return 'bin'
        # Compare bytes directly: decoding a fixed-size sample can split a
        # multi-byte UTF-8 character (or meet non-UTF-8 text) and raise
        # UnicodeDecodeError for a perfectly valid properties file.
        if buff.startswith(b'-----BEGIN'):
            return 'pem'
    return 'prop'  # compatible default
|
|
522
|
+
|
|
523
|
+
def __unique_name(self, path, cksum):
    """
    Build a scratch-directory path for a converted copy of *path*.

    The name is '<scratch_dir>/<basename>-<cksum>-<unique_tag>'; an empty
    basename is replaced by 'noname'. The scratch directory is chosen on
    first use.
    """
    if not scratch_dir:
        self.__set_scratch_dir()
    stem = os.path.basename(path) or 'noname'
    return '-'.join((scratch_dir + '/' + stem, cksum, unique_tag))
|
|
528
|
+
|
|
529
|
+
def __set_scratch_dir(self):
    """
    Choose the directory for converted key/certificate files and store it
    in the module-level scratch_dir.

    Prefers ~/.voltssl (created owner-only, mode 0700, if missing). Falls
    back to $TMPDIR or /tmp when the home directory cannot be resolved,
    the directory cannot be created, or the path exists but is not a
    directory.
    """
    global scratch_dir
    candidate = os.path.expanduser('~/.voltssl')
    if not candidate.startswith('~'):  # home directory resolved
        try:
            os.mkdir(candidate, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        except FileExistsError:
            pass  # already there (possibly created concurrently)
        except OSError:
            candidate = None  # cannot create it; use the fallback
        if candidate and os.path.isdir(candidate):
            scratch_dir = candidate
    if not scratch_dir:
        scratch_dir = os.getenv('TMPDIR', '/tmp')
|
|
541
|
+
|
|
542
|
+
def __set_up_pem_files(self, pem_config):
    """
    Point self.ssl_config at PEM keystore/truststore files.

    A keystore that already contains a certificate (or is the only file
    given) is used as both key and certificate chain. Otherwise the keystore
    supplies the key and the truststore the certificate. A truststore also
    becomes the CA file, and server certificates are then required.
    """
    cfg = self.ssl_config
    keystore = pem_config.get('keystore')
    truststore = pem_config.get('truststore')
    if keystore:
        combined = self.__has_cert(keystore) or not truststore
        if combined:
            cfg['keyfile'], cfg['certfile'] = None, keystore
        else:  # key and cert in separate files
            cfg['keyfile'], cfg['certfile'] = keystore, truststore
        cfg['keypass'] = pem_config.get('keystorepassword')
    if truststore:
        cfg['ca_certs'] = truststore
        cfg['cert_reqs'] = ssl.CERT_REQUIRED
|
|
556
|
+
|
|
557
|
+
def __has_cert(self, keyfile):
    """Return True if the text of *keyfile* contains the word CERTIFICATE."""
    with open(keyfile, 'r') as pem:
        return 'CERTIFICATE' in pem.read()
|
|
561
|
+
|
|
562
|
+
def __convert_jks_files(self, jks_config):
    """
    Convert the Java KeyStore files named in *jks_config* to PEM scratch
    files and point self.ssl_config at them.

    The keystore's private key(s) and certificate chain go to
    '<name>.key.pem' / '<name>.cert.pem'. The truststore's certificates go
    to '<name>.ca.cert.pem', and server certificates are then required.
    Converted files are reused if they already exist.

    :raises: the original ImportError when pyjks is not installed
    """
    if not pyjks_available:
        if os.getenv('VOLTDB_CONTAINER'):
            print("Java KeyStore support is unavailable in this container.\n" +
                  "You can use --ssl=nocheck to skip verification of the server certificate.\n")
        else:
            error("To use Java KeyStore please install the 'pyjks' module.\n" +
                  "It may be more convenient to use a PEM file instead.\n")
        raise pyjks_exception

    def load_keystore(filename, password):
        # Parsed store plus an MD5 of the raw bytes, used only to derive a
        # cache file name (not for security).
        with open(filename, 'rb') as f:
            data = f.read()
        ks = jks.KeyStore.loads(data, password)
        cksum = hashlib.md5(data).hexdigest()
        return ks, cksum

    def write_pem(der_bytes, type, f):
        # One PEM block with the base64 body wrapped at 64 columns. Lines are
        # joined with '\n' because f is a text-mode file that already applies
        # the platform newline; '\r\n' here produced mixed line endings
        # (and '\r\r\n' on Windows).
        f.write("-----BEGIN %s-----\n" % type)
        f.write("\n".join(textwrap.wrap(base64.b64encode(der_bytes).decode('ascii'), 64)))
        f.write("\n-----END %s-----\n" % type)

    # extract key and certs from jks keystore with caching
    use_key_cert = False
    if jks_config.get('keystore'):
        kpass = jks_config.get('keystorepassword')
        ks, cksum = load_keystore(jks_config['keystore'], kpass)
        kfname = self.__unique_name(jks_config['keystore'], cksum)
        keyfilename = kfname + '.key.pem'
        certfilename = kfname + '.cert.pem'

        if os.path.exists(keyfilename) and os.path.exists(certfilename):
            use_key_cert = os.path.getsize(certfilename) > 0
        else:
            # 'with' guarantees both files are closed even if writing fails
            with self.__create_temp(keyfilename) as keyfile, \
                 self.__create_temp(certfilename) as certfile:
                for pk in ks.private_keys.values():
                    if pk.algorithm_oid == jks.util.RSA_ENCRYPTION_OID:
                        write_pem(pk.pkey, "RSA PRIVATE KEY", keyfile)
                    else:
                        write_pem(pk.pkey_pkcs8, "PRIVATE KEY", keyfile)
                    for c in pk.cert_chain:
                        write_pem(c[1], "CERTIFICATE", certfile)
                    use_key_cert = True

        if use_key_cert:
            self.ssl_config['keyfile'] = keyfilename
            self.ssl_config['certfile'] = certfilename

    # extract ca certs from jks truststore with caching
    use_ca_cert = False
    if jks_config.get('truststore'):
        tpass = jks_config.get('truststorepassword')
        ts, cksum = load_keystore(jks_config['truststore'], tpass)
        tfname = self.__unique_name(jks_config['truststore'], cksum)
        cafilename = tfname + '.ca.cert.pem'

        if os.path.exists(cafilename):
            use_ca_cert = os.path.getsize(cafilename) > 0
        else:
            with self.__create_temp(cafilename) as cafile:
                for c in ts.certs.values():
                    write_pem(c.cert, "CERTIFICATE", cafile)
                    use_ca_cert = True

        if use_ca_cert:
            self.ssl_config['ca_certs'] = cafilename
            self.ssl_config['cert_reqs'] = ssl.CERT_REQUIRED
|
|
635
|
+
|
|
636
|
+
def __convert_pkcs12_files(self, p12_config):
    """
    Convert the PKCS12 files named in *p12_config* to PEM scratch files and
    point self.ssl_config at them.

    The keystore yields '<name>.key.pem' (PKCS#8, unencrypted) and
    '<name>.cert.pem' (certificate plus any additional certificates). The
    truststore yields '<name>.ca.cert.pem', and server certificates are then
    required. Converted files are reused if they already exist.

    :raises RuntimeError: keystore lacks a key/certificate, or truststore
        has no certificates
    :raises: the original ImportError when 'cryptography' is not installed
    """
    if not pkcs12_available:
        if os.getenv('VOLTDB_CONTAINER'):
            print("PKCS12 certificate support is unavailable in this container.\n" +
                  "You can use --ssl=nocheck to skip verification of the server certificate.\n")
        else:
            error("To use PKCS12 certificate support please install the 'cryptography' module.\n" +
                  "It may be more convenient to use a PEM file instead.\n")
        raise pkcs12_exception

    def checksum(fname):
        # MD5 of the file contents, used only to derive a cache file name.
        with open(fname, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()

    def encode_password(password):
        # cryptography requires bytes or None. An empty password ('') must
        # become b''; previously it was passed through as str and rejected.
        return None if password is None else password.encode('utf-8')

    def write_private_key(fname, key):
        with self.__create_temp(fname) as f:
            f.write(key.private_bytes(Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()).decode('ascii'))

    def write_cert_chain(fname, cert, more):
        with self.__create_temp(fname) as f:
            if cert:
                f.write(cert.public_bytes(Encoding.PEM).decode('ascii'))
            for extra in more or ():
                f.write(extra.public_bytes(Encoding.PEM).decode('ascii'))

    # pkcs12 keystore file to PEM files: private key and certificate
    use_key_cert = False
    keystore_path = p12_config.get('keystore')
    if keystore_path:
        kpass = encode_password(p12_config.get('keystorepassword'))
        kfname = self.__unique_name(keystore_path, checksum(keystore_path))
        keyfilename = kfname + '.key.pem'
        certfilename = kfname + '.cert.pem'

        if os.path.exists(keyfilename) and os.path.exists(certfilename):
            use_key_cert = os.path.getsize(certfilename) > 0
        else:
            with open(keystore_path, 'rb') as f:
                key, cert, more = pkcs12.load_key_and_certificates(f.read(), kpass)
            if not key:
                raise RuntimeError('No private key entry in keystore')
            if not cert:
                raise RuntimeError('No certificate entry in keystore')
            write_private_key(keyfilename, key)
            write_cert_chain(certfilename, cert, more)
            use_key_cert = True

        if use_key_cert:
            self.ssl_config['keyfile'] = keyfilename
            self.ssl_config['certfile'] = certfilename

    # pkcs12 truststore to single PEM file with cert chain
    use_ca_cert = False
    truststore_path = p12_config.get('truststore')
    if truststore_path:
        tpass = encode_password(p12_config.get('truststorepassword'))
        tfname = self.__unique_name(truststore_path, checksum(truststore_path))
        cafilename = tfname + '.ca.cert.pem'

        if os.path.exists(cafilename):
            use_ca_cert = os.path.getsize(cafilename) > 0
        else:
            with open(truststore_path, 'rb') as f:
                _, cert, more = pkcs12.load_key_and_certificates(f.read(), tpass)
            if not (cert or more):
                raise RuntimeError('No certificates in truststore')
            write_cert_chain(cafilename, cert, more)
            use_ca_cert = True

        if use_ca_cert:
            self.ssl_config['ca_certs'] = cafilename
            self.ssl_config['cert_reqs'] = ssl.CERT_REQUIRED
|
|
710
|
+
|
|
711
|
+
# Protocol names accepted in TLS_ENABLED_PROTOCOLS, mapped to ssl.TLSVersion
# members. ssl.TLSVersion needs Python 3.7+, so the table stays empty on older
# interpreters or when ssl could not be imported.
__protodict = ({'TLSv1.2': ssl.TLSVersion.TLSv1_2,
                'TLSv1.3': ssl.TLSVersion.TLSv1_3}
               if ssl_available and sys.hexversion >= 0x03070000 else {})
|
|
715
|
+
|
|
716
|
+
def __select_protocols(self, protstr):
    """
    Turn a comma-separated list of protocol names (e.g. "TLSv1.2,TLSv1.3")
    into a (minimum, maximum) pair of values from __protodict.

    Whitespace around each name is ignored; previously "TLSv1.2, TLSv1.3"
    silently dropped the second entry. Unknown names are skipped.

    :returns: (min, max), or (None, None) if no name is recognized
    """
    names = (part.strip() for part in protstr.split(','))
    proto = [self.__protodict[p] for p in names if p in self.__protodict]
    return (min(proto), max(proto)) if proto else (None, None)
|
|
720
|
+
|
|
721
|
+
def __select_ciphers(self, context, namestr):
    """Build an OpenSSL cipher string from a comma-separated cipher list.

    A name is kept as-is when ``context`` supports it. Otherwise it is
    translated with ``__java_to_openssl_cipher`` and kept if that form is
    supported. Unsupported names are dropped, and whitespace around each name
    is ignored.

    :param context: object providing ``get_ciphers()`` (e.g. ``ssl.SSLContext``).
    :param namestr: comma-separated cipher names.
    :returns: the selected names joined with ``':'`` (possibly empty).
    """
    # A set gives O(1) membership tests for each requested name.
    supported = {c['name'] for c in context.get_ciphers()}
    selected = []
    for name in namestr.split(','):
        name = name.strip()
        if name not in supported:
            name = self.__java_to_openssl_cipher(name)
        if name in supported:
            selected.append(name)
    return ':'.join(selected)
|
|
732
|
+
|
|
733
|
+
# Java/IANA suite names have the shape TLS_<kex>_WITH_<cipher>_<mac>. The
# greedy groups leave only the last underscore-separated token as the MAC.
__tlspattern = re.compile(r'^TLS_(.*)_WITH_(.*)_(.*)$')
# Splits an AES cipher such as "AES-128-GCM" into ("AES", "128", "-GCM").
__aespattern = re.compile(r'^(AES)-(\d*)(-.*)$')
|
|
735
|
+
|
|
736
|
+
def __java_to_openssl_cipher(self, name):
    """Convert a Java/IANA cipher suite name to OpenSSL's spelling.

    For example, ``TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256`` becomes
    ``ECDHE-RSA-AES128-GCM-SHA256``. Names that do not match
    ``__tlspattern`` are returned unchanged.
    """
    parts = self.__tlspattern.match(name)
    if parts is None:
        return name
    kex, cipher, mac = parts.groups()
    kex = kex.replace('_', '-')
    cipher = cipher.replace('_', '-')
    aes = self.__aespattern.match(cipher)
    if aes:
        # OpenSSL writes the key length glued to the algorithm name: AES128.
        cipher = ''.join(aes.groups())
    return '-'.join((kex, cipher, mac))
|
|
747
|
+
|
|
748
|
+
def __create_temp(self, filename):
    """Create (or truncate) ``filename`` with owner-only permissions for writing.

    The file is registered in ``temporary_files`` and returned open in text
    write mode. It is created with mode 0600 from the start, so there is no
    window in which it exists with umask-default permissions. This matters
    because these files can hold private keys. ``chmod`` is still applied so
    that a pre-existing file is also restricted.
    """
    mode = stat.S_IRUSR | stat.S_IWUSR
    fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    try:
        os.chmod(filename, mode)
        f = os.fdopen(fd, 'w')
    except BaseException:
        # Do not leak the descriptor if chmod or fdopen fails.
        os.close(fd)
        raise
    temporary_files.append(filename)
    return f
|
|
753
|
+
|
|
754
|
+
def __compileStructs(self):
    """(Re)build the struct format factories for each wire type.

    Each attribute set here is a callable taking an element count and returning
    a ``struct`` format string prefixed with ``self.inputBOM``. The byte-order
    attribute is read when the factory is called, not when it is created.
    """
    def factory(code):
        return lambda length: '%c%d%s' % (self.inputBOM, length, code)

    for attr, code in (('byteType', 'b'), ('ubyteType', 'B'),
                       ('int16Type', 'h'), ('int32Type', 'i'),
                       ('int64Type', 'q'), ('uint64Type', 'Q'),
                       ('float64Type', 'd'), ('stringType', 's'),
                       ('varbinaryType', 's')):
        setattr(self, attr, factory(code))
|
|
765
|
+
|
|
766
|
+
def close(self):
    """Close the dump file (if any) and the server socket.

    A missing socket (``None``, the state ``flush`` treats as "not
    connected") is tolerated. The socket is closed even if closing the dump
    file raises.
    """
    try:
        if self.dump_file is not None:
            self.dump_file.close()
    finally:
        if self.socket is not None:
            self.socket.close()
|
|
770
|
+
|
|
771
|
+
def authenticate(self, username, password):
    """Send the login request and consume the server's login response.

    A login message is required even when authentication is turned off. It
    carries:
      - protocol version 1 and hash scheme 1 (SHA-256);
      - the requested service ("kerberos" or "database");
      - the username or Kerberos principal;
      - the SHA-256 digest of the UTF-8 password (``None`` is treated as an
        empty password).
    When the server answers with ``AUTH_HANDSHAKE_VERSION``, a GSSAPI/Kerberos
    token exchange follows before the final status is checked.

    :raises RuntimeError: if the login is rejected, the Kerberos exchange
        fails, or a response is not received within the socket timeout.
    :raises ssl.SSLError: on a TLS failure during the first response (a
        diagnostic is printed first).
    :raises IOError: on other transport failures during the first response
        (a diagnostic is printed first).
    """
    ioerror_message = "ERROR: Connection failed. Please check that the host, port, and ssl settings are correct."
    sslerror_message = "ERROR: Connection failed due to a problem with TLS/SSL, possibly a configuration mismatch."

    def read_response():
        """Buffer one response message and return its (version, status) bytes."""
        try:
            self.bufferForRead()
        except socket.timeout:
            # socket.timeout is an OSError (IOError) subclass, so it must be
            # handled before the generic IOError clause to be reachable.
            raise RuntimeError("Authentication timed out after %d seconds."
                               % self.socket.gettimeout())
        except ssl.SSLError:
            error(sslerror_message)
            raise
        except IOError:
            error(ioerror_message)
            raise
        return self.readByte(), self.readByte()

    # protocol version
    self.writeByte(1)
    # password hash scheme: sha256
    self.writeByte(1)
    # service requested
    self.writeString("kerberos" if self.usekerberos else "database")
    # utf8 encoded username or kerberos principal; an empty string (length 0) if none
    self.writeString(username if username else "")
    # sha-256 hash of the utf8 encoded password
    pw_hash = hashlib.sha256((password or "").encode("utf-8")).digest()
    self.wbuf.extend(bytearray(pw_hash))

    self.prependLength()
    self.flush()

    # A length, version number, and status code is returned. If there was a
    # problem with ssl handshaking it will probably show up on this first read.
    version, status = read_response()

    if version == self.AUTH_HANDSHAKE_VERSION:
        # service name supplied by VoltDB Server
        service_string = self.readString().encode('ascii', 'ignore')
        try:
            service_name = gssapi.Name(service_string, name_type=gssapi.NameType.kerberos_principal)
            ctx = gssapi.SecurityContext(name=service_name, mech=gssapi.MechType.kerberos)
            out_token = ctx.step(None)
            while not ctx.complete:
                self.writeByte(self.AUTH_HANDSHAKE_VERSION)
                self.writeByte(self.AUTH_HANDSHAKE)
                self.wbuf.extend(out_token)
                self.prependLength()
                self.flush()

                version, status = read_response()
                if version != self.AUTH_HANDSHAKE_VERSION or status != self.AUTH_HANDSHAKE:
                    raise RuntimeError("Authentication failed.")

                in_token = self.readVarbinaryContent(self.read_buffer.remaining()).tobytes()
                out_token = ctx.step(in_token)

            # final login response once the security context is established
            version, status = read_response()
        except Exception as e:
            raise RuntimeError("Authentication failed.") from e

    if status != 0:
        reason = "Authentication failed."
        # Must match assignments in Constants.java
        status_text = ("Server has too many connections.",
                       "Connection timed out during authentication.",
                       "Wire protocol format violation error.",
                       "Failed to authenticate to rejoining node.",
                       "Export not enabled for server.",
                       "Server requires use of TLS/SSL.",
                       "Client certificate required for mutual authentication.")
        if 0 < status <= len(status_text):
            reason = status_text[status - 1]
        raise RuntimeError(reason)

    # Consume and discard the rest of the login response: int32, int64,
    # int64, int32, then an int32 count followed by that many bytes.
    # NOTE(review): presumably host id, connection id, cluster start time,
    # leader address and build string -- confirm against the wire protocol.
    self.readInt32()
    self.readInt64()
    self.readInt64()
    self.readInt32()
    for _ in range(self.readInt32()):
        self.readByte()
|
|
889
|
+
|
|
890
|
+
def has_ticket(self):
    '''
    Checks to see if the user has a valid ticket.

    On success, stores the default initiator credential's name in
    ``self.kerberosprincipal`` and returns True. Otherwise prints a
    diagnostic and returns False.
    '''
    try:
        cred = gssapi.creds.Credentials(usage='initiate')
        if not cred.lifetime > 0:
            error("ERROR: Kerberos principal found but login expired.")
            return False
        self.kerberosprincipal = str(cred.name)
        return True
    except gssapi.raw.misc.GSSError:
        error("ERROR: unable to find default principal from Kerberos cache.")
        return False
|
|
906
|
+
|
|
907
|
+
def setInputByteOrder(self, bom):
    """Select the byte order used to decode incoming values, then rebuild formats.

    ``bom == 1`` selects ``LITTLE_ENDIAN``; any other value selects ``BIG_ENDIAN``.
    """
    self.inputBOM = self.LITTLE_ENDIAN if bom == 1 else self.BIG_ENDIAN
    # the struct format factories embed the byte order, so refresh them
    self.__compileStructs()
|
|
916
|
+
|
|
917
|
+
def prependLength(self):
    """Prefix the write buffer with its current length in bytes.

    The 32-bit length counts the bytes already in ``self.wbuf`` and does NOT
    include the four prefix bytes. The prefix bytes are whatever
    ``int32toBytes`` produces; the wire format expects network order.
    """
    length = len(self.wbuf) * self.wbuf.itemsize
    # A single slice assignment replaces four O(n) inserts at index 0 (and a
    # list comprehension that was used only for its side effects).
    self.wbuf[0:0] = array.array('B', int32toBytes(length))
|
|
925
|
+
|
|
926
|
+
def size(self):
    """Returns the size of the write buffer, in bytes."""
    item_count = len(self.wbuf)
    return item_count * self.wbuf.itemsize
|
|
931
|
+
|
|
932
|
+
def flush(self):
    """Send the whole write buffer to the server and start a fresh one.

    When a dump file is configured, the outgoing bytes followed by a newline
    are written to it before sending.

    :raises IOError: if there is no socket.
    """
    sock = self.socket
    if sock is None:
        error("ERROR: not connected to server.")
        raise IOError("No Connection")

    dump = self.dump_file
    if dump is not None:
        dump.write(self.wbuf)
        dump.write(b"\n")
    sock.sendall(self.wbuf.tobytes())
    self.wbuf = array.array('B')
|
|
942
|
+
|
|
943
|
+
def bufferForRead(self):
    """Read one length-prefixed message from the socket into ``read_buffer``.

    The 4-byte prefix is decoded with ``self.int32Type``. Exactly that many
    following bytes are then accumulated in ``self.read_buffer``. When a dump
    file is configured, the prefix and the body (plus a newline) are echoed
    to it.

    :raises IOError: if there is no socket, or if the peer closes the
        connection before a complete message has arrived.
    """
    if self.socket is None:
        error("ERROR: not connected to server.")
        raise IOError("No Connection")

    # recv() returns b'' only at end-of-stream. Check every chunk (not just
    # the first) so a half-received prefix or body cannot spin forever.
    responseprefix = b''
    while len(responseprefix) < 4:
        chunk = self.socket.recv(4 - len(responseprefix))
        if not chunk:
            raise IOError("Connection broken")
        responseprefix += chunk
    if self.dump_file is not None:
        self.dump_file.write(responseprefix)

    responseLength = struct.unpack(self.int32Type(1), responseprefix)[0]
    self.read_buffer.clear()
    remaining = responseLength
    while remaining > 0:
        message = self.socket.recv(remaining)
        if not message:
            raise IOError("Connection broken")
        self.read_buffer.append(message)
        remaining = responseLength - self.read_buffer.buffer_length()

    if self.dump_file is not None:
        self.dump_file.write(self.read_buffer.get_buffer())
        self.dump_file.write(b"\n")
|
|
967
|
+
|
|
968
|
+
def read(self, type):
    """Decode one value of wire ``type`` with its registered reader.

    :raises IOError: if no reader is registered for ``type``.
    """
    if type in self.READER:
        return self.READER[type]()
    message = "ERROR: can't read wire type(%d) yet." % (type)
    error(message)
    raise IOError(message)
|
|
974
|
+
|
|
975
|
+
def write(self, type, value):
    """Encode ``value`` as wire ``type`` with its registered writer.

    :returns: whatever the writer returns.
    :raises IOError: if no writer is registered for ``type``.
    """
    if type in self.WRITER:
        return self.WRITER[type](value)
    message = "ERROR: can't write wire type(%d) yet." % (type)
    error(message)
    raise IOError(message)
|
|
981
|
+
|
|
982
|
+
def readWireType(self):
    """Read a one-byte type tag, then decode the value of that type."""
    return self.read(self.readByte())
|
|
985
|
+
|
|
986
|
+
def writeWireType(self, type, value):
    """Write a one-byte type tag followed by ``value`` encoded as that type.

    :raises IOError: if no writer is registered for ``type`` (nothing is written).
    """
    if type not in self.WRITER:
        message = "ERROR: can't write wire type(%d) yet." % (type)
        error(message)
        raise IOError(message)
    self.writeByte(type)
    return self.write(type, value)
|
|
993
|
+
|
|
994
|
+
def getRawBytes(self):
    """Return the live write buffer itself (not a copy)."""
    buf = self.wbuf
    return buf
|
|
996
|
+
|
|
997
|
+
def writeRawBytes(self, value):
    """Appends the given raw bytes (any iterable of byte values) to the end of the write buffer."""
    buf = self.wbuf
    buf.extend(value)
|
|
1002
|
+
|
|
1003
|
+
def __str__(self):
    """Return the repr of the pending write buffer (debugging aid)."""
    return '%r' % (self.wbuf,)
|
|
1005
|
+
|
|
1006
|
+
def readArray(self, type):
    """Read an array whose elements have wire ``type``.

    :raises IOError: if no array reader is registered for ``type``.
    """
    if type not in self.ARRAY_READER:
        # The exception text previously said "can't write" although this is
        # the read path; it now matches the printed diagnostic.
        message = "ERROR: can't read wire type(%d) yet." % (type)
        error(message)
        raise IOError(message)
    return self.ARRAY_READER[type]()
|
|
1012
|
+
|
|
1013
|
+
def readNull(self):
    """Return None without reading anything from the buffer."""
    return None
|
|
1015
|
+
|
|
1016
|
+
def writeNull(self, value):
    """Write nothing; ``value`` is ignored."""
    return None
|
|
1018
|
+
|
|
1019
|
+
def writeArray(self, type, array):
    """Write an array's length followed by each of its elements.

    TINYINT arrays (byte strings) use a 32-bit length; all other element
    types use a 16-bit length. Elements are written with
    ``self.WRITER[type]``.

    Nothing is written when ``array`` is None or ``type`` is falsy. An empty
    sequence is written as length 0. Previously it wrote nothing at all,
    which left a caller's element-type byte (e.g. from writeWireTypeArray)
    with no length after it and desynchronised the stream.

    :raises IOError: if ``type`` is not a supported array element type.
    """
    if array is None or not type:
        return

    if type not in self.ARRAY_READER:
        error("ERROR: Unsupported data type (%d)." % (type))
        raise IOError("ERROR: Unsupported data type (%d)." % (type))

    # serialize arrays of bytes as larger values to support
    # strings and varbinary input
    if type == self.VOLTTYPE_TINYINT:
        self.writeInt32(len(array))
    else:
        self.writeInt16(len(array))

    writer = self.WRITER[type]
    for element in array:
        writer(element)
|
|
1036
|
+
|
|
1037
|
+
def writeWireTypeArray(self, type, array):
    """Write the element-type byte, then the array via ``writeArray``.

    :raises IOError: if ``type`` has no array reader (nothing is written).
    """
    if type not in self.ARRAY_READER:
        message = "ERROR: can't write wire type(%d) yet." % (type)
        error(message)
        raise IOError(message)
    self.writeByte(type)
    self.writeArray(type, array)
|
|
1044
|
+
|
|
1045
|
+
# byte
|
|
1046
|
+
def readByteArrayContent(self, cnt):
    """Unpack ``cnt`` signed bytes from the read buffer (no NULL mapping).

    The format string comes from ``self.byteType``.
    """
    nbytes = struct.calcsize('b') * cnt
    return self.read_buffer.unpack(self.byteType(cnt), nbytes)
|
|
1049
|
+
|
|
1050
|
+
def readByteArray(self):
    """Read a 32-bit count, then that many bytes; TINYINT NULL markers become None."""
    count = self.readInt32()
    raw = self.readByteArrayContent(count)
    check = self.NullCheck[self.VOLTTYPE_TINYINT]
    return [check(b) for b in raw]
|
|
1055
|
+
|
|
1056
|
+
def readByte(self):
    """Read one signed byte; the TINYINT NULL marker becomes None."""
    raw = self.readByteArrayContent(1)[0]
    check = self.NullCheck[self.VOLTTYPE_TINYINT]
    return check(raw)
|
|
1059
|
+
|
|
1060
|
+
def readByteRaw(self):
    """Read one byte as a signed integer, without NULL mapping.

    Values above 127 are folded into the range -128..-1.
    """
    raw = self.readByteArrayContent(1)[0]
    return raw - 256 if raw > 127 else raw
|
|
1066
|
+
|
|
1067
|
+
def writeByte(self, value):
    """Append one signed byte; None writes the TINYINT NULL marker."""
    if value is None:
        value = self.__class__.NULL_TINYINT_INDICATOR
    # wbuf stores unsigned bytes, so negatives go in as two's complement.
    self.wbuf.append(value + 256 if value < 0 else value)
|
|
1073
|
+
|
|
1074
|
+
# int16
|
|
1075
|
+
def readInt16ArrayContent(self, cnt):
    """Unpack ``cnt`` 16-bit signed integers from the read buffer (no NULL mapping)."""
    nbytes = struct.calcsize('h') * cnt
    return self.read_buffer.unpack(self.int16Type(cnt), nbytes)
|
|
1078
|
+
|
|
1079
|
+
def readInt16Array(self):
    """Read a 16-bit count, then that many int16 values; NULL markers become None."""
    count = self.readInt16()
    raw = self.readInt16ArrayContent(count)
    check = self.NullCheck[self.VOLTTYPE_SMALLINT]
    return [check(v) for v in raw]
|
|
1084
|
+
|
|
1085
|
+
def readInt16(self):
    """Read one 16-bit integer; the SMALLINT NULL marker becomes None."""
    raw = self.readInt16ArrayContent(1)[0]
    check = self.NullCheck[self.VOLTTYPE_SMALLINT]
    return check(raw)
|
|
1088
|
+
|
|
1089
|
+
def writeInt16(self, value):
    """Append a 16-bit integer; None writes the SMALLINT NULL marker."""
    val = self.__class__.NULL_SMALLINT_INDICATOR if value is None else value
    self.wbuf.extend(int16toBytes(val))
|
|
1095
|
+
|
|
1096
|
+
# int32
|
|
1097
|
+
def readInt32ArrayContent(self, cnt):
    """Unpack ``cnt`` 32-bit signed integers from the read buffer (no NULL mapping)."""
    nbytes = struct.calcsize('i') * cnt
    return self.read_buffer.unpack(self.int32Type(cnt), nbytes)
|
|
1100
|
+
|
|
1101
|
+
def readInt32Array(self):
    """Read a 16-bit count, then that many int32 values; NULL markers become None."""
    count = self.readInt16()
    raw = self.readInt32ArrayContent(count)
    check = self.NullCheck[self.VOLTTYPE_INTEGER]
    return [check(v) for v in raw]
|
|
1106
|
+
|
|
1107
|
+
def readInt32(self):
    """Read one 32-bit integer; the INTEGER NULL marker becomes None."""
    raw = self.readInt32ArrayContent(1)[0]
    check = self.NullCheck[self.VOLTTYPE_INTEGER]
    return check(raw)
|
|
1110
|
+
|
|
1111
|
+
def writeInt32(self, value):
    """Append a 32-bit integer; None writes the INTEGER NULL marker."""
    val = self.__class__.NULL_INTEGER_INDICATOR if value is None else value
    self.wbuf.extend(int32toBytes(val))
|
|
1117
|
+
|
|
1118
|
+
# int64
|
|
1119
|
+
def readInt64ArrayContent(self, cnt):
    """Unpack ``cnt`` 64-bit signed integers from the read buffer (no NULL mapping)."""
    nbytes = struct.calcsize('q') * cnt
    return self.read_buffer.unpack(self.int64Type(cnt), nbytes)
|
|
1122
|
+
|
|
1123
|
+
def readInt64Array(self):
    """Read a 16-bit count, then that many int64 values; NULL markers become None."""
    count = self.readInt16()
    raw = self.readInt64ArrayContent(count)
    check = self.NullCheck[self.VOLTTYPE_BIGINT]
    return [check(v) for v in raw]
|
|
1128
|
+
|
|
1129
|
+
def readInt64(self):
    """Read one 64-bit integer; the BIGINT NULL marker becomes None."""
    raw = self.readInt64ArrayContent(1)[0]
    check = self.NullCheck[self.VOLTTYPE_BIGINT]
    return check(raw)
|
|
1132
|
+
|
|
1133
|
+
def writeInt64(self, value):
    """Append a 64-bit integer; None writes the BIGINT NULL marker."""
    val = self.__class__.NULL_BIGINT_INDICATOR if value is None else value
    self.wbuf.extend(int64toBytes(val))
|
|
1139
|
+
|
|
1140
|
+
# float64
|
|
1141
|
+
def readFloat64ArrayContent(self, cnt):
    """Unpack ``cnt`` 64-bit floats from the read buffer (no NULL mapping)."""
    nbytes = struct.calcsize('d') * cnt
    return self.read_buffer.unpack(self.float64Type(cnt), nbytes)
|
|
1144
|
+
|
|
1145
|
+
def readFloat64Array(self):
    """Read a 16-bit count, then that many float64 values; NULL markers become None."""
    count = self.readInt16()
    raw = self.readFloat64ArrayContent(count)
    check = self.NullCheck[self.VOLTTYPE_FLOAT]
    return [check(v) for v in raw]
|
|
1150
|
+
|
|
1151
|
+
def readFloat64(self):
    """Read one 64-bit float; the FLOAT NULL marker becomes None."""
    raw = self.readFloat64ArrayContent(1)[0]
    check = self.NullCheck[self.VOLTTYPE_FLOAT]
    return check(raw)
|
|
1154
|
+
|
|
1155
|
+
def writeFloat64(self, value):
    """Append an 8-byte float packed with ``self.float64Type``.

    None writes the FLOAT NULL marker; any other value goes through ``float()``.
    """
    val = self.__class__.NULL_FLOAT_INDICATOR if value is None else float(value)
    packed = struct.pack(self.float64Type(1), val)
    self.wbuf.extend(bytearray(packed))
|
|
1162
|
+
|
|
1163
|
+
# string
|
|
1164
|
+
def readStringContent(self, cnt):
    """Read ``cnt`` bytes and decode them as UTF-8; returns "" when ``cnt`` is 0."""
    if cnt == 0:
        return ""
    nbytes = cnt * struct.calcsize('c')
    raw = self.read_buffer.unpack(self.stringType(cnt), nbytes)[0]
    return raw.decode("utf-8")
|
|
1171
|
+
|
|
1172
|
+
def readString(self):
    """Read a 4-byte length-prefixed UTF-8 string; a NULL length yields None."""
    length = self.readInt32()
    is_null = self.NullCheck[self.VOLTTYPE_STRING](length) is None
    return None if is_null else self.readStringContent(length)
|
|
1178
|
+
|
|
1179
|
+
def readStringArray(self):
    """Read a 16-bit count followed by that many strings; returns a tuple."""
    count = self.readInt16()
    return tuple(self.readString() for _ in range(count))
|
|
1187
|
+
|
|
1188
|
+
def writeString(self, value):
    """Write ``value`` as a 32-bit byte length followed by its UTF-8 bytes.

    None writes only the NULL string length marker.
    """
    if value is None:
        self.writeInt32(self.NULL_STRING_INDICATOR)
        return
    data = value.encode("utf-8")
    self.writeInt32(len(data))
    self.wbuf.extend(bytearray(data))
|
|
1197
|
+
|
|
1198
|
+
# varbinary
|
|
1199
|
+
def readVarbinaryContent(self, cnt):
    """Read ``cnt`` raw bytes as an ``array.array('B')`` (empty when ``cnt`` is 0)."""
    if cnt == 0:
        return array.array('B', [])
    nbytes = cnt * struct.calcsize('c')
    payload = self.read_buffer.unpack(self.varbinaryType(cnt), nbytes)[0]
    return array.array('B', payload)
|
|
1207
|
+
|
|
1208
|
+
def readVarbinary(self):
    """Read a 4-byte length-prefixed byte string; a NULL length yields None."""
    length = self.readInt32()
    is_null = self.NullCheck[self.VOLTTYPE_VARBINARY](length) is None
    return None if is_null else self.readVarbinaryContent(length)
|
|
1214
|
+
|
|
1215
|
+
def writeVarbinary(self, value):
    """Write ``value`` as a 32-bit length followed by its raw bytes.

    None writes only the NULL length marker.
    """
    if value is None:
        self.writeInt32(self.NULL_STRING_INDICATOR)
    else:
        self.writeInt32(len(value))
        self.wbuf.extend(value)
|
|
1222
|
+
|
|
1223
|
+
# date
|
|
1224
|
+
# The timestamp we receive from the server is a 64-bit integer representing
|
|
1225
|
+
# microseconds since the epoch. It will be converted to a datetime object in
|
|
1226
|
+
# the local timezone.
|
|
1227
|
+
def readDate(self):
    """Read a TIMESTAMP as a naive ``datetime`` in the local timezone.

    The wire value is a 64-bit count of microseconds since Jan 1, 1970 UTC;
    a NULL value yields None. Whole seconds and the microsecond remainder are
    converted separately. Dividing by 1e6 as a float would lose microsecond
    precision for timestamps far from the epoch.
    """
    raw = self.readInt64()
    if raw is None:
        return None
    # divmod floors, so the remainder is always 0..999999, including before 1970.
    seconds, micros = divmod(raw, 1000000)
    return datetime.datetime.fromtimestamp(seconds).replace(microsecond=micros)
|
|
1233
|
+
|
|
1234
|
+
def readDateArray(self):
    """Read an array of TIMESTAMPs as a tuple of naive local-time datetimes.

    Each element counts microseconds since the Unix epoch (UTC); NULL
    elements become None. Seconds and microseconds are converted separately
    to avoid floating-point precision loss far from the epoch.
    """
    retval = []
    for micros in self.readInt64Array():
        if micros is None:
            retval.append(None)
            continue
        seconds, rem = divmod(micros, 1000000)
        retval.append(datetime.datetime.fromtimestamp(seconds).replace(microsecond=rem))
    return tuple(retval)
|
|
1245
|
+
|
|
1246
|
+
def writeDate(self, value):
    """Write a ``datetime`` as a 64-bit count of microseconds since the Unix epoch.

    Naive datetimes are interpreted as local time. Aware datetimes use their
    own tzinfo. None writes the BIGINT NULL marker.

    ``datetime.timestamp()`` replaces ``strftime("%s")``. That format code is
    a glibc-only extension: it fails on other platforms and silently ignores
    tzinfo. Microseconds are stripped before ``timestamp()`` so the float is
    an exact whole number of seconds, then added back as an integer.
    """
    if value is None:
        val = self.__class__.NULL_BIGINT_INDICATOR
    else:
        seconds = int(value.replace(microsecond=0).timestamp())
        val = seconds * 1000000 + value.microsecond
    self.wbuf.extend(int64toBytes(val))
|
|
1253
|
+
|
|
1254
|
+
def readDecimal(self):
    """Read a 16-byte DECIMAL and return it as ``decimal.Decimal`` (or None).

    The bytes are a big-endian two's complement unscaled integer with
    ``DEFAULT_DECIMAL_SCALE`` implied fractional digits. NULL is detected
    by peeking at the bytes; in that case they are skipped and None is
    returned.
    """
    width = 16 * struct.calcsize('b')
    if self.NullCheck[self.VOLTTYPE_DECIMAL](self.read_buffer.read(width)) is None:
        self.read_buffer.shift(width)
        return None
    raw = bytes(self.read_buffer.unpack(self.ubyteType(16), width))
    unscaled = int.from_bytes(raw, 'big', signed=True)
    digits = tuple(int(ch) for ch in str(abs(unscaled)))
    return decimal.Decimal((unscaled < 0, digits,
                            -self.__class__.DEFAULT_DECIMAL_SCALE))
|
|
1272
|
+
|
|
1273
|
+
def readDecimalArray(self):
    """Read a 16-bit count followed by that many DECIMALs; returns a tuple."""
    count = self.readInt16()
    return tuple(self.readDecimal() for _ in range(count))
|
|
1279
|
+
|
|
1280
|
+
def __intToBytes(self, value, sign):
    """Encode a magnitude and sign as a 16-byte big-endian two's complement value.

    :param value: non-negative unscaled magnitude.
    :param sign: 1 for negative, 0 otherwise (as from ``Decimal.as_tuple()``).
    :returns: exactly 16 ``bytes``.
    :raises ValueError: if the signed value does not fit in 128 bits.

    Fixes relative to the hand-rolled byte loop:
      - negative zero now encodes as 0 instead of all-0xff (-1);
      - magnitudes in [2**127, 2**128) now raise instead of silently
        wrapping to a negative number;
      - an unreachable sign-fill branch is removed.
    """
    if sign == 1:
        value = -value
    try:
        return value.to_bytes(16, 'big', signed=True)
    except OverflowError:
        raise ValueError("Precision of this decimal is >38 digits") from None
|
|
1306
|
+
|
|
1307
|
+
def writeDecimal(self, num):
    """Serialize a DECIMAL as a 16-byte two's complement unscaled integer.

    The value is rescaled to ``DEFAULT_DECIMAL_SCALE`` fractional digits and
    encoded by ``__intToBytes``. None writes ``NULL_DECIMAL_INDICATOR``.

    :raises TypeError: if ``num`` is not a ``decimal.Decimal``.
    :raises ValueError: if ``num`` is NaN or infinite, has more fractional
        digits than the scale allows, or has more than 26 integer digits.
    """
    if num is None:
        self.wbuf.extend(self.NULL_DECIMAL_INDICATOR)
        return
    if not isinstance(num, decimal.Decimal):
        raise TypeError("num must be of the type decimal.Decimal")
    if not num.is_finite():
        # as_tuple() reports a string exponent for these, which previously
        # surfaced as an obscure TypeError from the unary minus below.
        raise ValueError("Cannot serialize non-finite decimal %s" % (num,))
    max_scale = self.__class__.DEFAULT_DECIMAL_SCALE
    (sign, digits, exponent) = num.as_tuple()
    precision = len(digits)
    scale = -exponent
    if scale > max_scale:
        raise ValueError("Scale of this decimal is %d and the max is %d"
                         % (scale, max_scale))
    rest = precision - scale
    if rest > 26:
        raise ValueError("Precision to the left of the decimal point is %d"
                         " and the max is 26" % (rest))
    scale_factor = max_scale - scale
    unscaled_int = int(decimal.Decimal((0, digits, scale_factor)))
    data = self.__intToBytes(unscaled_int, sign)
    self.wbuf.extend(data)
|
|
1327
|
+
|
|
1328
|
+
def writeDecimalString(self, num):
    """Write a Decimal as its engineering-notation string; None writes a NULL string.

    :raises TypeError: if ``num`` is neither None nor a ``decimal.Decimal``.
    """
    if num is not None and not isinstance(num, decimal.Decimal):
        raise TypeError("num must be of type decimal.Decimal")
    self.writeString(None if num is None else num.to_eng_string())
|
|
1335
|
+
|
|
1336
|
+
# cash!
|
|
1337
|
+
def readMoney(self):
    """Read a money amount as a 64-bit integer equal to money-unit * 10,000."""
    amount = self.readInt64()
    return amount
|
|
1340
|
+
|
|
1341
|
+
def readGeographyPoint(self):
    """Read a GEOGRAPHY_POINT and return it as a ``(longitude, latitude)`` tuple.

    Longitude is read first, then latitude. None is returned when both equal
    ``Geography.NULL_COORD``. The check previously referenced an undefined
    name (``lon``) and raised NameError for NULL points.
    """
    lng = self.readFloat64()
    lat = self.readFloat64()
    if lat == Geography.NULL_COORD and lng == Geography.NULL_COORD:
        return None
    return (lng, lat)
|
|
1348
|
+
|
|
1349
|
+
def readGeographyPointArray(self):
    """Read a 16-bit count followed by that many geography points; returns a tuple."""
    count = self.readInt16()
    return tuple(self.readGeographyPoint() for _ in range(count))
|
|
1355
|
+
|
|
1356
|
+
def writeGeographyPoint(self, point):
    """Write a GEOGRAPHY_POINT given as a 2-tuple of floats.

    ``point[0]`` is written first and ``point[1]`` second. This matches the
    (longitude, latitude) order that readGeographyPoint reads. None writes
    ``Geography.NULL_COORD`` for both coordinates.

    :raises TypeError: if ``point`` is not a tuple of length 2.
    """
    if point is None:
        self.writeFloat64(Geography.NULL_COORD)
        self.writeFloat64(Geography.NULL_COORD)
        return
    # Previously checked an undefined name (``num``) and took len() of the
    # ``tuple`` type itself, so every non-None call failed.
    if not isinstance(point, tuple) or len(point) != 2:
        raise TypeError("point must be a 2-tuple of floats")
    self.writeFloat64(point[0])
    self.writeFloat64(point[1])
|
|
1367
|
+
|
|
1368
|
+
def readGeography(self):
    """Read a GEOGRAPHY value by delegating to ``Geography.unflatten``."""
    decode = Geography.unflatten
    return decode(self)
|
|
1370
|
+
|
|
1371
|
+
def readGeographyArray(self):
    """Read a 16-bit count followed by that many GEOGRAPHY values; returns a tuple."""
    count = self.readInt16()
    return tuple(Geography.unflatten(self) for _ in range(count))
|
|
1377
|
+
|
|
1378
|
+
def writeGeography(self, geo):
    """Serialize a GEOGRAPHY value via ``geo.flatten``; None writes the NULL length.

    The NULL branch previously called ``writeInt32`` and read
    ``NULL_STRING_INDICATOR`` without ``self``, which raised NameError.
    """
    if geo is None:
        self.writeInt32(self.NULL_STRING_INDICATOR)
    else:
        geo.flatten(self)
|
|
1383
|
+
|
|
1384
|
+
class XYZPoint(object):
    """
    Google's S2 geometry library uses (x, y, z) representation of polygon vertices,
    but the interface we expose to users is (lat, lng). This class is the
    internal representation for vertices.

    NOTE(review): ``toGeogrpahyPoint`` returns (longitude, latitude), whereas
    ``fromGeographyPoint`` reads its argument as (latitude, longitude). The
    two are therefore not inverses of each other; confirm the intended tuple
    order with callers before relying on a round trip.
    """
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    @staticmethod
    def fromGeographyPoint(p):
        """Return the unit-sphere point for ``p``, read as (latitude, longitude) in degrees."""
        latRadians = p[0] * (math.pi / 180) # AKA phi
        lngRadians = p[1] * (math.pi / 180) # AKA theta

        cosPhi = math.cos(latRadians)
        x = math.cos(lngRadians) * cosPhi
        y = math.sin(lngRadians) * cosPhi
        z = math.sin(latRadians)

        return XYZPoint(x, y, z)

    def toGeogrpahyPoint(self):
        """Return this point as a (longitude, latitude) tuple in degrees."""
        latRadians = math.atan2(self.z, math.sqrt(self.x * self.x + self.y * self.y))
        lngRadians = math.atan2(self.y, self.x)

        latDegrees = latRadians * (180 / math.pi)
        lngDegrees = lngRadians * (180 / math.pi)
        return (lngDegrees, latDegrees)

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    toGeographyPoint = toGeogrpahyPoint

    def __eq__(self, other):
        """Points are equal when ``self`` is an instance of ``other``'s class and all attributes match."""
        if isinstance(self, other.__class__):
            return self.__dict__ == other.__dict__
        return False

    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        return not self.__eq__(other)

    def __str__(self):
        """Format as "(lng,lat)" in degrees."""
        p = self.toGeogrpahyPoint()
        return "(%s,%s)" % (p[0], p[1])
|
|
1428
|
+
|
|
1429
|
+
class Geography(object):
|
|
1430
|
+
"""
|
|
1431
|
+
S2-esque geography element representing a polygon for now
|
|
1432
|
+
"""
|
|
1433
|
+
|
|
1434
|
+
# Small floating-point tolerance (presumably for coordinate comparisons;
# confirm at its use sites).
EPSILON = 1.0e-12
# Sentinel coordinate, outside any valid latitude or longitude, used for NULL
# geography points and for the placeholder bounding boxes.
NULL_COORD = 360.0
|
|
1436
|
+
|
|
1437
|
+
def __init__(self, loops=None):
    """Create a polygon from a list of loops, each a list of XYZPoint vertices.

    The first loop is treated as the outer ring when serialized. ``loops``
    defaults to a new empty list per instance. The previous ``[]`` default
    was a single list shared by every Geography created without arguments.
    """
    self.loops = [] if loops is None else loops
|
|
1439
|
+
|
|
1440
|
+
# Serialization format for polygons.
#
# This is the format used by S2 in the EE. Most of the metadata (especially
# the lat/lng rect bounding boxes) is ignored by this client.
#
#   1 byte    encoding version
#   1 byte    boolean owns_loops
#   1 byte    boolean has_holes
#   4 bytes   number of loops
#   then, for each loop:
#     1 byte    encoding version
#     4 bytes   number of vertices
#     (number of vertices) * sizeof(double) * 3 bytes: vertices as XYZPoints
#     1 byte    boolean origin_inside
#     4 bytes   depth (nesting level of loop)
#     33 bytes  bounding box of the loop
#   33 bytes  bounding box of the whole polygon
#
# We use S2 in the EE for all geometric computation, so polygons sent to
# the EE will be missing bounding box and other info. We indicate this
# by passing INCOMPLETE_ENCODING_FROM_JAVA in the version field. This
# tells the EE to compute bounding boxes and other metadata before storing
# the polygon to memory.
# NOTE(review): flatten() writes 0 as the version byte; confirm that 0 is
# the INCOMPLETE_ENCODING_FROM_JAVA value.

# A bound: 1 encoding byte + lat min, lat max, lng min, lng max as doubles.
BOUND_LENGTH_IN_BYTES = 1 + 4 * 8
# Polygon header (version, owns_loops, has_holes, loop count) + polygon bound.
POLYGON_OVERHEAD_IN_BYTES = 1 + 1 + 1 + 4 + BOUND_LENGTH_IN_BYTES
# Per-loop fixed part (version, vertex count, origin_inside, depth) + loop
# bound; the vertices themselves are counted separately.
LOOP_OVERHEAD_IN_BYTES = 1 + 4 + 1 + 4 + BOUND_LENGTH_IN_BYTES
# One vertex: x, y and z as doubles.
VERTEX_SIZE_IN_BYTES = 3 * 8
|
|
1476
|
+
|
|
1477
|
+
def serializedSize(self):
    """Return the number of bytes flatten() writes after its 4-byte length prefix.

    The constants were previously referenced as bare names (NameError), and
    the per-loop helper returned None. The per-loop size is computed inline
    here so this method is correct on its own.
    """
    length = self.POLYGON_OVERHEAD_IN_BYTES
    for loop in self.loops:
        length += self.LOOP_OVERHEAD_IN_BYTES + len(loop) * self.VERTEX_SIZE_IN_BYTES
    return length
|
|
1482
|
+
|
|
1483
|
+
@staticmethod
def loopSerializedSize(loop):
    """Return the serialized byte length of one loop of XYZPoint vertices.

    The expression was previously computed without ``return`` (so the method
    returned None), and it used bare class-constant names.
    """
    return (Geography.LOOP_OVERHEAD_IN_BYTES
            + len(loop) * Geography.VERTEX_SIZE_IN_BYTES)
|
|
1486
|
+
|
|
1487
|
+
@staticmethod
def unflatten(fs):
    """Read a serialized polygon from ``fs`` and return a Geography (None if NULL).

    The loop at depth 0 (the outer ring) is moved to the front of the list.
    flatten() assumes this order when it writes depth 0 for the first loop
    and 1 for the rest. Previously the depth-0 index was tracked but never
    used.
    """
    length = fs.readInt32()  # size
    if length == fs.NULL_STRING_INDICATOR:
        return None

    fs.readByteRaw()  # encoding version
    fs.readByteRaw()  # owns loops
    fs.readByteRaw()  # has holes
    numLoops = fs.readInt32()
    loops = []
    indexOfOuterRing = 0
    for i in range(numLoops):
        depth, loop = Geography.__unflattenLoop(fs)
        if depth == 0:
            indexOfOuterRing = i
        loops.append(loop)

    # The serialized order need not put the outer ring first, so reorder it.
    if indexOfOuterRing != 0:
        loops.insert(0, loops.pop(indexOfOuterRing))

    Geography.__unflattenBound(fs)

    return Geography(loops)
|
|
1508
|
+
|
|
1509
|
+
@staticmethod
def __unflattenLoop(fs):
    """Read one loop from ``fs`` and return ``(depth, [XYZPoint, ...])``.

    Layout: 1 byte encoding version, 4-byte vertex count, three float64s per
    vertex (x, y, z), 1 byte origin_inside, 4-byte depth, then a bound.
    """
    fs.readByteRaw()  # encoding version
    count = fs.readInt32()
    vertices = []
    for _ in range(count):
        coords = (fs.readFloat64(), fs.readFloat64(), fs.readFloat64())
        vertices.append(XYZPoint(*coords))
    fs.readByteRaw()  # origin_inside_
    depth = fs.readInt32()
    Geography.__unflattenBound(fs)
    return depth, vertices
|
|
1531
|
+
|
|
1532
|
+
@staticmethod
def __unflattenBound(fs):
    """Skip a serialized bounding rectangle.

    Consumes one encoding-version byte followed by four float64 values,
    all of which are discarded.
    """
    fs.readByteRaw()  # encoding version
    for _ in range(4):
        fs.readFloat64()
|
|
1539
|
+
|
|
1540
|
+
def flatten(self, fs):
    """Serialize this polygon to fs, preceded by its serializedSize().

    The header holds an encoding version of 0, owns_loops = 1, a has_holes
    flag (set when there is more than one loop) and the loop count. The
    loops come next, and an empty bounding rectangle comes last.
    """
    fs.writeInt32(self.serializedSize())  # prepend length

    fs.writeByte(0)  # encoding version
    fs.writeByte(1)  # owns_loops
    fs.writeByte(1 if len(self.loops) > 1 else 0)  # has_holes
    fs.writeInt32(len(self.loops))
    # The first loop is written as the outer ring (depth 0); every later
    # loop is written as a hole (depth 1).
    for position, loop in enumerate(self.loops):
        Geography.__flattenLoop(loop, 0 if position == 0 else 1, fs)
    Geography.__flattenEmptyBound(fs)
|
|
1556
|
+
|
|
1557
|
+
@staticmethod
def __flattenLoop(loop, depth, fs):
    """Write one loop of XYZPoint vertices to fs.

    Fields are written in this order:
      1 byte      encoding version (0)
      4 bytes     number of vertices
      24 bytes    per vertex: x, y, z as float64
      1 byte      origin_inside (always 0)
      4 bytes     depth
      bound       an empty bounding rectangle
    """
    fs.writeByte(0)  # encoding version
    fs.writeInt32(len(loop))
    for vertex in loop:
        for coordinate in (vertex.x, vertex.y, vertex.z):
            fs.writeFloat64(coordinate)

    fs.writeByte(0)  # origin_inside
    fs.writeInt32(depth)  # depth
    Geography.__flattenEmptyBound(fs)
|
|
1575
|
+
|
|
1576
|
+
@staticmethod
def __flattenEmptyBound(fs):
    """Write an empty bounding rectangle: version byte 0, then four NULL_COORD values."""
    fs.writeByte(0)  # encoding version
    for _ in range(4):
        fs.writeFloat64(Geography.NULL_COORD)
|
|
1583
|
+
|
|
1584
|
+
@staticmethod
def formatPoint(point):
    """Format a point as the WKT coordinate pair "lng lat".

    point may be an XYZPoint, which is converted first, or an indexable
    (lng, lat) pair. A coordinate whose absolute value is below
    Geography.EPSILON is printed as 0.0.
    """
    # auto convert XYZ points
    if isinstance(point, XYZPoint):
        point = point.toGeogrpahyPoint()

    fmt = "{}"

    # Force coordinates within EPSILON of zero to exactly 0.0. Otherwise two
    # points that differ only in the least significant bits would print
    # differently (e.g. 1e-17 vs 0.0, or -0.0 vs 0.0). Compare the
    # magnitude: a plain `x < EPSILON` would also zero every negative value.
    lng = point[0]
    if abs(lng) < Geography.EPSILON:
        lng = 0.0
    lat = point[1]
    if abs(lat) < Geography.EPSILON:
        lat = 0.0
    return fmt.format(lng) + " " + fmt.format(lat)
|
|
1604
|
+
|
|
1605
|
+
@staticmethod
def pointToWKT(point):
    """Return point as a WKT "POINT (lng lat)" string.

    point may be an XYZPoint or an indexable (lng, lat) pair.
    """
    # auto convert XYZ points
    if isinstance(point, XYZPoint):
        point = point.toGeogrpahyPoint()

    # This is not GEOGRAPHY_POINT. This is wkt syntax.
    # formatPoint is the coordinate formatter defined on this class.
    return "POINT (" + Geography.formatPoint(point) + ")"
|
|
1613
|
+
|
|
1614
|
+
|
|
1615
|
+
# Matches a WKT "POINT (lng lat)" case-insensitively. Each coordinate is an
# optional minus sign, one or more digits, and an optional fractional part.
wktPointMatcher = re.compile(r"^\s*point\s*\(\s*(-?\d+(?:\.\d*)?)\s+(-?\d+(?:\.\d*)?)\s*\)", flags=re.IGNORECASE)

@staticmethod
def pointFromWKT(wkt):
    """Parse a WKT "POINT (lng lat)" string into a (lng, lat) tuple of floats.

    Raises ValueError if wkt is None. Returns None if wkt does not match
    wktPointMatcher.
    """
    if wkt is None:
        raise ValueError("None passed to pointFromWKT")
    match = Geography.wktPointMatcher.search(wkt)
    if match is None:
        return None
    lngStr, latStr = match.group(1, 2)
    return (float(lngStr), float(latStr))
|
|
1628
|
+
|
|
1629
|
+
@staticmethod
def geographyFromWKT(wkt):
    """Parse a WKT polygon into a Geography.

    Not implemented: wkt is ignored and None is always returned.
    """
    return None
|
|
1632
|
+
|
|
1633
|
+
def __str__(self):
    """Return this polygon in Well Known Text (WKT) form.

    Each ring begins and ends with its first vertex, as WKT requires. The
    first loop lists its remaining vertices in stored order. Every other
    loop lists them in reverse order, from the last vertex back to index 1.
    """
    rings = []
    for position, loop in enumerate(self.loops):
        first = loop[0]
        if position == 0:
            middle = loop[1:]
        else:
            middle = loop[:0:-1]  # last vertex down to index 1
        ordered = [first] + list(middle) + [first]
        rings.append("(" + ", ".join(Geography.formatPoint(p) for p in ordered) + ")")
    return "POLYGON (" + ", ".join(rings) + ")"
|
|
1664
|
+
|
|
1665
|
+
def __repr__(self):
    """Use the WKT text produced by __str__ as the repr."""
    return str(self)
|
|
1667
|
+
|
|
1668
|
+
class VoltColumn:
    """Definition of one VoltDB table column: a type code and a name."""

    def __init__(self, fser = None, type = None, name = None):
        # `type` shadows the builtin inside this method; the parameter name
        # is kept so existing keyword callers keep working.
        if fser is not None:
            # Deserializing: only the type byte is read here. The name is
            # filled in later by readName().
            self.type = fser.readByte()
            self.name = None
        else:
            # Always define both attributes so __str__/__eq__ never fail on a
            # missing attribute when only one (or neither) value was given.
            self.type = type
            self.name = name

    def __str__(self):
        # If the name is empty, use the default "modified tuples". Has to do
        # this because HSQLDB doesn't return a column name if the table is
        # empty.
        return "(%s: %d)" % (self.name or "modified tuples", self.type)

    # Defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other):
        if not isinstance(other, VoltColumn):
            return NotImplemented
        # For now, if we've been through the query on a column with no name,
        # just assume that there's no way the types are matching up cleanly,
        # so an unnamed column compares equal to any column.
        if not self.name or not other.name:
            return True
        return self.type == other.type and self.name == other.name

    def readName(self, fser):
        """Read this column's name from fser."""
        self.name = fser.readString()

    def writeType(self, fser):
        """Write this column's type code to fser as a single byte."""
        fser.writeByte(self.type)

    def writeName(self, fser):
        """Write this column's name to fser."""
        fser.writeString(self.name)
|
|
1701
|
+
|
|
1702
|
+
class VoltTable:
    """Definition and content of one VoltDB table.

    columns holds VoltColumn definitions. tuples holds the rows, each a
    list of values in column order.
    """
    def __init__(self, fser):
        self.fser = fser
        self.columns = []  # column definitions
        self.tuples = []   # rows: one list of column values per row

    def __str__(self):
        result = ""

        result += "column count: %d\n" % (len(self.columns))
        result += "row count: %d\n" % (len(self.tuples))
        result += "cols: "
        result += ", ".join([str(x) for x in self.columns])
        result += "\n"
        result += "rows -\n"
        # None values are shown as "NULL".
        result += "\n".join([str(["NULL" if y is None else y for y in x]) for x in self.tuples])

        return result

    # Pickling keeps only the columns and rows; the serializer is dropped.
    def __getstate__(self):
        return (self.columns, self.tuples)

    def __setstate__(self, state):
        self.fser = None
        self.columns, self.tuples = state

    def __eq__(self, other):
        if not isinstance(other, VoltTable):
            return NotImplemented
        # Column definitions are only compared when this table has rows.
        if len(self.tuples) > 0:
            return (self.columns == other.columns) and \
                   (self.tuples == other.tuples)
        return (self.tuples == other.tuples)

    # The VoltTable is always serialized in big-endian order.
    #
    # How to read a table off the wire.
    # 1. Read the length of the whole table
    # 2. Read the columns
    #    a. read the column header size
    #    b. read the status code and column count
    #    c. read column types, then column names
    # 3. Read the tuples
    #    a. read the row count
    #    b. read each row's size, then its values by column type
    def readFromSerializer(self):
        """Read one table from self.fser into columns/tuples and return self."""
        fser = self.fser
        # 1.
        tablesize = fser.readInt32()
        limit_position = fser.read_buffer._off + tablesize
        # 2.
        fser.readInt32()  # header size (not needed to decode the header)
        fser.readByte()   # status code (not used)
        columncount = fser.readInt16()
        for _ in range(columncount):
            self.columns.append(VoltColumn(fser = fser))
        # All column names follow all column types.
        for column in self.columns:
            column.readName(fser)

        # 3.
        rowcount = fser.readInt32()
        for _ in range(rowcount):
            fser.readInt32()  # row size (values are decoded by type instead)
            row = [fser.read(self.columns[j].type) for j in range(columncount)]
            self.tuples.append(row)

        # Reposition to the end of the table declared by its length prefix,
        # regardless of how many bytes row decoding consumed.
        if fser.read_buffer._off != limit_position:
            fser.read_buffer._off = limit_position

        return self

    def writeToSerializer(self):
        """Serialize columns and tuples, appending the table to self.fser."""
        table_fser = FastSerializer()

        # We have to pack the header into a buffer first so that we can
        # calculate the size
        header_fser = FastSerializer()

        header_fser.writeByte(0)  # status code
        header_fser.writeInt16(len(self.columns))
        for column in self.columns:
            column.writeType(header_fser)
        for column in self.columns:
            column.writeName(header_fser)

        # NOTE(review): the "- 4" presumably discounts 4 bytes that
        # FastSerializer.size() counts beyond the header itself - confirm.
        table_fser.writeInt32(header_fser.size() - 4)
        table_fser.writeRawBytes(header_fser.getRawBytes())

        table_fser.writeInt32(len(self.tuples))
        for row in self.tuples:
            row_fser = FastSerializer()
            for index, value in enumerate(row):
                row_fser.write(self.columns[index].type, value)
            table_fser.writeInt32(row_fser.size())
            table_fser.writeRawBytes(row_fser.getRawBytes())

        table_fser.prependLength()
        self.fser.writeRawBytes(table_fser.getRawBytes())
|
|
1801
|
+
|
|
1802
|
+
|
|
1803
|
+
class VoltException:
    """Exception details deserialized from a VoltDB procedure response.

    type holds one of the VOLTEXCEPTION_* codes below, typestr its readable
    name, and message the message text. Depending on type, deserialize()
    also sets error_code (EE exceptions) or sql_state_bytes (SQL exceptions
    and constraint failures). For constraint failures it additionally sets
    constraint_type, table_name and buffer.
    """
    # Volt SerializableException enumerations
    VOLTEXCEPTION_NONE = 0
    VOLTEXCEPTION_EEEXCEPTION = 1
    VOLTEXCEPTION_SQLEXCEPTION = 2
    VOLTEXCEPTION_CONSTRAINTFAILURE = 3
    VOLTEXCEPTION_GENERIC = 4

    def __init__(self, fser):
        self.type = self.VOLTEXCEPTION_NONE
        self.typestr = "None"
        self.message = ""

        if fser is not None:
            self.deserialize(fser)

    def deserialize(self, fser):
        """Read a length-prefixed serialized exception from fser.

        A zero length, or a VOLTEXCEPTION_NONE type byte, ends the read.
        For an unknown type the remaining bytes are skipped and error() is
        called.
        """
        self.length = fser.readInt32()
        if self.length == 0:
            self.type = self.VOLTEXCEPTION_NONE
            return
        self.type = fser.readByte()
        # quick and dirty exception skipping
        if self.type == self.VOLTEXCEPTION_NONE:
            return

        # Message: int32 length, then that many bytes, each turned into a
        # character with chr().
        self.message_len = fser.readInt32()
        self.message = ''.join([chr(fser.readByte()) for _ in range(self.message_len)])

        if self.type == self.VOLTEXCEPTION_GENERIC:
            self.typestr = "Generic"
        elif self.type == self.VOLTEXCEPTION_EEEXCEPTION:
            self.typestr = "EE Exception"
            # serialized size from EEException.java is 4 bytes
            self.error_code = fser.readInt32()
        elif self.type == self.VOLTEXCEPTION_SQLEXCEPTION or \
             self.type == self.VOLTEXCEPTION_CONSTRAINTFAILURE:
            self.sql_state_bytes = ''.join([chr(fser.readByte()) for _ in range(5)])

            if self.type == self.VOLTEXCEPTION_SQLEXCEPTION:
                self.typestr = "SQL Exception"
            else:
                self.typestr = "Constraint Failure"
                self.constraint_type = fser.readInt32()
                self.table_name = fser.readString()
                self.buffer_size = fser.readInt32()
                self.buffer = [fser.readByte() for _ in range(self.buffer_size)]
        else:
            # Skip the rest of the payload: length also counts the type byte
            # (1), the message length field (4) and the message bytes.
            for _ in range(self.length - 3 - 2 - self.message_len):
                fser.readByte()
            error("Python client deserialized unknown VoltException.")

    def __str__(self):
        msgstr = "VoltException: type: %s\n" % self.typestr
        if self.type == self.VOLTEXCEPTION_EEEXCEPTION:
            msgstr += " Error code: %d\n" % self.error_code
        elif self.type == self.VOLTEXCEPTION_SQLEXCEPTION:
            msgstr += " SQL code: "
            msgstr += self.sql_state_bytes
        elif self.type == self.VOLTEXCEPTION_CONSTRAINTFAILURE:
            # Previously a duplicated SQLEXCEPTION test made this branch
            # unreachable, and it concatenated instead of %-formatting.
            msgstr += " Constraint violation type: %d\n" % self.constraint_type
            msgstr += " on table: %s\n" % self.table_name
        return msgstr
|
|
1874
|
+
|
|
1875
|
+
class VoltResponse:
    """Response to a called VoltDB procedure (ClientResponse.java).

    When constructed with a serializer, the response is read immediately.
    Passing None gives a placeholder that keeps the default field values;
    VoltProcedure.call uses it to report timeouts and I/O errors.
    """
    def __init__(self, fser):
        self.fser = fser
        self.version = -1
        self.clientHandle = -1
        self.status = -1
        self.statusString = ""
        self.appStatus = -1
        self.appStatusString = ""
        self.roundtripTime = -1
        self.exception = None
        self.tables = None

        if fser is not None:
            self.deserialize(fser)

    def deserialize(self, fser):
        """Read one response from fser and populate this object.

        Wire order: version, client handle, present-fields bitmask, status,
        optional status string, app status, optional app status string,
        round-trip time, optional exception, table count, then the tables.
        """
        statusStringBit = 1 << 5
        exceptionBit = 1 << 6
        appStatusStringBit = 1 << 7

        fser.bufferForRead()
        self.version = fser.readByte()
        self.clientHandle = fser.readInt64()
        present = fser.readByteRaw()

        self.status = fser.readByte()
        self.statusString = fser.readString() if present & statusStringBit else None

        self.appStatus = fser.readByte()
        self.appStatusString = fser.readString() if present & appStatusStringBit else None

        self.roundtripTime = fser.readInt32()
        self.exception = VoltException(fser) if present & exceptionBit else None

        # tables[]
        count = fser.readInt16()
        self.tables = []
        for _ in range(count):
            self.tables.append(VoltTable(fser).readFromSerializer())

    def __str__(self):
        tablestr = "" if self.tables is None else "\n\n".join([str(t) for t in self.tables])
        if self.exception is None:
            return "Status: %d\nInformation: %s\n%s" % (self.status,
                                                        self.statusString,
                                                        tablestr)
        msgstr = "Status: %d\nInformation: %s\n%s\n" % (self.status,
                                                        self.statusString,
                                                        tablestr)
        return msgstr + "Exception: %s" % (self.exception)
|
|
1936
|
+
|
|
1937
|
+
class VoltProcedure:
    """Client-side handle for invoking one VoltDB stored procedure.

    fser is the FastSerializer used to talk to the server, name is the
    procedure name, and paramtypes is the list of wire type codes for its
    parameters.
    """
    def __init__(self, fser, name, paramtypes = None):
        self.fser = fser  # FastSerializer object
        self.name = name  # procedure class name
        # A fresh list per instance: a shared mutable default ([]) would let
        # a mutation through one procedure leak into every other one.
        self.paramtypes = [] if paramtypes is None else paramtypes  # list of fser.WIRE_* values

    def call(self, params = None, response = True, timeout = None):
        """Invoke the procedure and read its response.

        params is indexed in step with self.paramtypes. timeout (seconds)
        defaults to self.fser.procedure_timeout. The response is always read
        from the socket. It is returned when response is true; otherwise
        None is returned. A socket timeout or IOError while reading is
        reported as a VoltResponse whose statusString describes the failure.
        """
        fser = self.fser
        fser.writeByte(0)  # version number
        fser.writeString(self.name)
        fser.writeInt64(1)  # client handle
        fser.writeInt16(len(self.paramtypes))
        for i, paramtype in enumerate(self.paramtypes):
            param = params[i]
            if self.as_array(paramtype, param):
                fser.writeByte(FastSerializer.ARRAY)
                fser.writeByte(paramtype)
                fser.writeArray(paramtype, param)
            else:
                fser.writeWireType(paramtype, param)
        fser.prependLength()  # prepend the total length of the invocation
        fser.flush()

        # The timeout in effect for the call is the timeout argument if not
        # None, otherwise fser.procedure_timeout. Exceeding it yields a
        # response describing the timeout rather than raising. The socket's
        # original timeout is restored afterwards. Passing None cannot
        # override a configured procedure_timeout.
        if timeout is None:
            timeout = fser.procedure_timeout
        original_timeout = fser.socket.gettimeout()
        fser.socket.settimeout(timeout)
        try:
            try:
                res = VoltResponse(fser)
            except socket.timeout:
                res = VoltResponse(None)
                res.statusString = "timeout: procedure call took longer than %d seconds" % timeout
            except IOError as err:
                res = VoltResponse(None)
                res.statusString = str(err)
        finally:
            fser.socket.settimeout(original_timeout)
        return res if response else None

    def as_array(self, paramtype, param):
        """Return True if param should be serialized as an array of paramtype.

        Non-iterables and str are scalars. bytes/bytearray are scalars only
        when paramtype is FastSerializer.VOLTTYPE_VARBINARY. Every other
        iterable is an array.
        """
        try:
            iter(param)  # throws TypeError if not a python array type
        except TypeError:
            return False
        if isinstance(param, str):
            return False
        if isinstance(param, (bytes, bytearray)):
            return paramtype != FastSerializer.VOLTTYPE_VARBINARY  # as non-array if we want varbinary
        return True
|
|
1990
|
+
|
|
1991
|
+
# Reads a properties file that is broadly compatible
|
|
1992
|
+
# with the forms supported by Java, in particular for
|
|
1993
|
+
# allowable separators between key and value. Note,
|
|
1994
|
+
# the key can optionally be converted to lower case
|
|
1995
|
+
# for compatibility with the previous implementation.
|
|
1996
|
+
|
|
1997
|
+
def read_properties_file(filename, lowerkeys=False):
    """Parse a Java-style properties file into a dict of str to str.

    The file is read as UTF-8 and each line is stripped. Blank lines and
    lines starting with '#' or '!' are skipped. Key and value are split at
    the leftmost separator, which is either '=' or ':' with any surrounding
    whitespace, or a run of whitespace. When lowerkeys is true, keys are
    lower-cased.

    Raises ValueError for a line with no separator, or with nothing before
    its separator.
    """
    splitter = re.compile(r'\s*[=:]\s*|\s+')
    result = {}
    with open(filename, mode='rt', encoding='utf-8') as handle:
        for lineno, raw in enumerate(handle, start=1):
            text = raw.strip()
            if not text or text[0] in '#!':
                continue
            found = splitter.search(text)
            if found is None or found.start() == 0:
                raise ValueError('Malformed property at line %s in %s: %s' % (lineno, filename, text))
            key = text[:found.start()]
            result[key.lower() if lowerkeys else key] = text[found.end():]
    return result
|
|
2014
|
+
|
|
2015
|
+
# Expand leading ~ in known pathnames from properties
|
|
2016
|
+
|
|
2017
|
+
def resolve_paths(properties):
    """Expand a leading '~' in the keystore, truststore and cacerts entries.

    properties is updated in place and also returned.
    """
    for name in ('keystore', 'truststore', 'cacerts'):
        if name not in properties:
            continue
        value = properties[name]
        if value.startswith('~'):
            properties[name] = os.path.expanduser(value)
    return properties
|