scylla_cqlsh-6.0.29-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- copyutil.cp310-win_amd64.pyd +0 -0
- cqlsh/__init__.py +1 -0
- cqlsh/__main__.py +11 -0
- cqlsh/cqlsh.py +2736 -0
- cqlshlib/__init__.py +90 -0
- cqlshlib/_version.py +34 -0
- cqlshlib/authproviderhandling.py +176 -0
- cqlshlib/copyutil.py +2762 -0
- cqlshlib/cql3handling.py +1670 -0
- cqlshlib/cqlhandling.py +333 -0
- cqlshlib/cqlshhandling.py +314 -0
- cqlshlib/displaying.py +128 -0
- cqlshlib/formatting.py +601 -0
- cqlshlib/helptopics.py +190 -0
- cqlshlib/pylexotron.py +562 -0
- cqlshlib/saferscanner.py +91 -0
- cqlshlib/sslhandling.py +109 -0
- cqlshlib/tracing.py +90 -0
- cqlshlib/util.py +183 -0
- cqlshlib/wcwidth.py +379 -0
- scylla_cqlsh-6.0.29.dist-info/METADATA +108 -0
- scylla_cqlsh-6.0.29.dist-info/RECORD +26 -0
- scylla_cqlsh-6.0.29.dist-info/WHEEL +5 -0
- scylla_cqlsh-6.0.29.dist-info/entry_points.txt +2 -0
- scylla_cqlsh-6.0.29.dist-info/licenses/LICENSE.txt +204 -0
- scylla_cqlsh-6.0.29.dist-info/top_level.txt +3 -0
cqlshlib/copyutil.py
ADDED
@@ -0,0 +1,2762 @@
# cython: profile=True

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import datetime
import json
import glob
import multiprocessing as mp
import os
import platform
import random
import re
import signal
import struct
import sys
import threading
import time
import traceback

# Python 3.14 changed the default to 'forkserver', which is not compatible
# with our relocatable python. It execs our Python binary, but without our
# ld.so. Change it back to 'fork' to avoid issues.
mp.set_start_method('fork')


from bisect import bisect_right
from calendar import timegm
from collections import defaultdict, namedtuple
from decimal import Decimal
from random import randint
from io import StringIO
from select import select
from uuid import UUID

import configparser
from queue import Queue

from cassandra import OperationTimedOut
from cassandra.cluster import Cluster, DefaultConnection
from cassandra.cqltypes import ReversedType, UserType, VarcharType
from cassandra.metadata import protect_name, protect_names, protect_value
from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, FallthroughRetryPolicy
from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
from cassandra.util import Date, Time
from cqlshlib.util import profile_on, profile_off

from cqlshlib.cql3handling import CqlRuleSet
from cqlshlib.displaying import NO_COLOR_MAP
from cqlshlib.formatting import format_value_default, CqlType, DateTimeFormat, EMPTY, get_formatter, BlobType
from cqlshlib.sslhandling import ssl_settings

PROFILE_ON = False
STRACE_ON = False
DEBUG = False  # This may be set to True when initializing the task
# TODO: review this for MacOS, maybe use in ('Linux', 'Darwin')
IS_LINUX = platform.system() == 'Linux'

CopyOptions = namedtuple('CopyOptions', 'copy dialect unrecognized')


def safe_normpath(fname):
    """
    :return the normalized path but only if there is a filename, we don't want to convert
    an empty string (which means no file name) to a dot. Also expand any user variables such as ~ to the full path
    """
    return os.path.normpath(os.path.expanduser(fname)) if fname else fname


def printdebugmsg(msg):
    if DEBUG:
        printmsg(msg)


def printmsg(msg, eol='\n'):
    sys.stdout.write(msg)
    sys.stdout.write(eol)
    sys.stdout.flush()


def noop(*arg, **kwargs):
    pass


class OneWayPipe(object):
    """
    A one way pipe protected by two process level locks, one for reading and one for writing.
    """
    def __init__(self):
        self.reader, self.writer = mp.Pipe(duplex=False)
        self.rlock = mp.Lock()
        self.wlock = mp.Lock()

    def send(self, obj):
        with self.wlock:
            self.writer.send(obj)

    def recv(self):
        with self.rlock:
            return self.reader.recv()

    def close(self):
        self.reader.close()
        self.writer.close()
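
# A minimal usage sketch of OneWayPipe (illustrative only, not part of the packaged
# module); the values sent here are arbitrary examples, any picklable object works:
#
#     pipe = OneWayPipe()
#     pipe.send(('chunk', 1))   # wlock serializes concurrent writers
#     print(pipe.recv())        # ('chunk', 1); rlock serializes concurrent readers
#     pipe.close()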


class ReceivingChannel(object):
    """
    A one way channel that wraps a pipe to receive messages.
    """
    def __init__(self, pipe):
        self.pipe = pipe

    def recv(self):
        return self.pipe.recv()

    def close(self):
        self.pipe.close()


class SendingChannel(object):
    """
    A one way channel that wraps a pipe and provides a feeding thread to send messages asynchronously.
    """
    def __init__(self, pipe):
        self.pipe = pipe
        self.pending_messages = Queue()

        def feed():
            while True:
                try:
                    msg = self.pending_messages.get()
                    self.pipe.send(msg)
                except Exception as e:
                    printmsg('%s: %s' % (e.__class__.__name__, e.message if hasattr(e, 'message') else str(e)))

        feeding_thread = threading.Thread(target=feed)
        feeding_thread.setDaemon(True)
        feeding_thread.start()

    def send(self, obj):
        self.pending_messages.put(obj)

    def num_pending(self):
        return self.pending_messages.qsize() if self.pending_messages else 0

    def close(self):
        self.pipe.close()
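
# Illustrative note (not part of the packaged module): SendingChannel.send() only enqueues the
# message; the daemon feeder thread drains the queue and performs the blocking pipe write, so
# callers never stall on a slow reader. Roughly:
#
#     ch = SendingChannel(OneWayPipe())
#     ch.send({'rows': 100})   # returns immediately
#     ch.num_pending()         # 0 or 1 depending on whether the feeder thread has run yet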


class SendingChannels(object):
    """
    A group of one way channels for sending messages.
    """
    def __init__(self, num_channels):
        self.pipes = [OneWayPipe() for _ in range(num_channels)]
        self.channels = [SendingChannel(p) for p in self.pipes]
        self.num_channels = num_channels
        self._readers = [p.reader for p in self.pipes]

    def release_readers(self):
        for reader in self._readers:
            reader.close()

    def close(self):
        for ch in self.channels:
            try:
                ch.close()
            except ValueError:
                pass


class ReceivingChannels(object):
    """
    A group of one way channels for receiving messages.
    """
    def __init__(self, num_channels):
        self.pipes = [OneWayPipe() for _ in range(num_channels)]
        self.channels = [ReceivingChannel(p) for p in self.pipes]
        self._readers = [p.reader for p in self.pipes]
        self._writers = [p.writer for p in self.pipes]
        self._rlocks = [p.rlock for p in self.pipes]
        self._rlocks_by_readers = dict([(p.reader, p.rlock) for p in self.pipes])
        self.num_channels = num_channels

        self.recv = self.recv_select if IS_LINUX else self.recv_polling

    def release_writers(self):
        for writer in self._writers:
            writer.close()

    def recv_select(self, timeout):
        """
        Implementation of the recv method for Linux, where select is available. Receive an object from
        all pipes that are ready for reading without blocking.
        """
        while True:
            try:
                readable, _, _ = select(self._readers, [], [], timeout)
            except OSError:
                raise
            else:
                break
        for r in readable:
            with self._rlocks_by_readers[r]:
                try:
                    yield r.recv()
                except EOFError:
                    continue

    def recv_polling(self, timeout):
        """
        Implementation of the recv method for platforms where select() is not available for pipes.
        We poll on all of the readers with a very small timeout. We stop when the timeout specified
        has been received but we may exceed it since we check all processes during each sweep.
        """
        start = time.time()
        while True:
            for i, r in enumerate(self._readers):
                with self._rlocks[i]:
                    if r.poll(0.000000001):
                        try:
                            yield r.recv()
                        except EOFError:
                            continue

            if time.time() - start > timeout:
                break

    def close(self):
        for ch in self.channels:
            try:
                ch.close()
            except ValueError:
                pass
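
# Illustrative note (not part of the packaged module): recv() is bound to recv_select() on Linux,
# where select() works on pipe file descriptors, and to recv_polling() elsewhere (for example
# Windows, which this cp310-win_amd64 wheel targets). Both are generators, so callers drain them:
#
#     channels = ReceivingChannels(num_channels=4)   # example channel count
#     for msg in channels.recv(timeout=0.1):
#         handle(msg)                                # handle() is a placeholder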


class CopyTask(object):
    """
    A base class for ImportTask and ExportTask
    """
    def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, direction):
        self.shell = shell
        self.ks = ks
        self.table = table
        self.table_meta = self.shell.get_table_meta(self.ks, self.table)
        self.host = shell.conn.get_control_connection_host()
        self.fname = safe_normpath(fname)
        self.protocol_version = protocol_version
        self.config_file = config_file

        # if cqlsh is invoked with --debug then set the global debug flag to True
        if shell.debug:
            global DEBUG
            DEBUG = True

        # do not display messages when exporting to STDOUT unless --debug is set
        self.printmsg = printmsg if self.fname is not None or direction == 'from' or DEBUG else noop
        self.options = self.parse_options(opts, direction)

        self.num_processes = self.options.copy['numprocesses']
        self.encoding = self.options.copy['encoding']
        self.printmsg('Using %d child processes' % (self.num_processes,))

        if direction == 'from':
            self.num_processes += 1  # add the feeder process

        self.processes = []
        self.inmsg = ReceivingChannels(self.num_processes)
        self.outmsg = SendingChannels(self.num_processes)

        self.columns = CopyTask.get_columns(shell, ks, table, columns)
        self.time_start = time.time()

    def maybe_read_config_file(self, opts, direction):
        """
        Read optional sections from a configuration file that was specified in the command options or from the default
        cqlshrc configuration file if none was specified.
        """
        config_file = opts.pop('configfile', '')
        if not config_file:
            config_file = self.config_file

        if not os.path.isfile(config_file):
            return opts

        configs = configparser.RawConfigParser()
        configs.read_file(open(config_file))

        ret = dict()
        config_sections = list(['copy', 'copy-%s' % (direction,),
                                'copy:%s.%s' % (self.ks, self.table),
                                'copy-%s:%s.%s' % (direction, self.ks, self.table)])

        for section in config_sections:
            if configs.has_section(section):
                options = dict(configs.items(section))
                self.printmsg("Reading options from %s:[%s]: %s" % (config_file, section, options))
                ret.update(options)

        # Update this last so the command line options take precedence over the configuration file options
        if opts:
            self.printmsg("Reading options from the command line: %s" % (opts,))
            ret.update(opts)

        if self.shell.debug:  # this is important for testing, do not remove
            self.printmsg("Using options: '%s'" % (ret,))

        return ret

    @staticmethod
    def clean_options(opts):
        """
        Convert all option values to valid string literals unless they are path names
        """
        return dict([(k, v if k not in ['errfile', 'ratefile'] else v)
                     for k, v, in opts.items()])

    def parse_options(self, opts, direction):
        """
        Parse options for import (COPY FROM) and export (COPY TO) operations.
        Extract from opts csv and dialect options.

        :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options.
        """
        shell = self.shell
        opts = self.clean_options(self.maybe_read_config_file(opts, direction))

        dialect_options = dict()
        dialect_options['quotechar'] = opts.pop('quote', '"')
        dialect_options['escapechar'] = opts.pop('escape', '\\')
        dialect_options['delimiter'] = opts.pop('delimiter', ',')
        if dialect_options['quotechar'] == dialect_options['escapechar']:
            dialect_options['doublequote'] = True
            del dialect_options['escapechar']
        else:
            dialect_options['doublequote'] = False

        copy_options = dict()
        copy_options['nullval'] = opts.pop('null', '')
        copy_options['header'] = bool(opts.pop('header', '').lower() == 'true')
        copy_options['encoding'] = opts.pop('encoding', 'utf8')
        copy_options['maxrequests'] = int(opts.pop('maxrequests', 6))
        copy_options['pagesize'] = int(opts.pop('pagesize', 1000))
        # by default the page timeout is 10 seconds per 1000 entries
        # in the page size or 10 seconds if pagesize is smaller
        copy_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (copy_options['pagesize'] / 1000))))
        copy_options['maxattempts'] = int(opts.pop('maxattempts', 5))
        copy_options['dtformats'] = DateTimeFormat(opts.pop('datetimeformat', shell.display_timestamp_format),
                                                   shell.display_date_format, shell.display_nanotime_format,
                                                   milliseconds_only=True)
        copy_options['floatprecision'] = int(opts.pop('floatprecision', '5'))
        copy_options['doubleprecision'] = int(opts.pop('doubleprecision', '12'))
        copy_options['chunksize'] = int(opts.pop('chunksize', 5000))
        copy_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
        copy_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
        copy_options['minbatchsize'] = int(opts.pop('minbatchsize', 10))
        copy_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
        copy_options['consistencylevel'] = shell.consistency_level
        copy_options['decimalsep'] = opts.pop('decimalsep', '.')
        copy_options['thousandssep'] = opts.pop('thousandssep', '')
        copy_options['boolstyle'] = [s.strip() for s in opts.pop('boolstyle', 'True, False').split(',')]
        copy_options['numprocesses'] = int(opts.pop('numprocesses', self.get_num_processes(16)))
        copy_options['begintoken'] = opts.pop('begintoken', '')
        copy_options['endtoken'] = opts.pop('endtoken', '')
        copy_options['maxrows'] = int(opts.pop('maxrows', '-1'))
        copy_options['skiprows'] = int(opts.pop('skiprows', '0'))
        copy_options['skipcols'] = opts.pop('skipcols', '')
        copy_options['maxparseerrors'] = int(opts.pop('maxparseerrors', '-1'))
        copy_options['maxinserterrors'] = int(opts.pop('maxinserterrors', '1000'))
        copy_options['errfile'] = safe_normpath(opts.pop('errfile', 'import_%s_%s.err' % (self.ks, self.table,)))
        copy_options['ratefile'] = safe_normpath(opts.pop('ratefile', ''))
        copy_options['maxoutputsize'] = int(opts.pop('maxoutputsize', '-1'))
        copy_options['preparedstatements'] = bool(opts.pop('preparedstatements', 'true').lower() == 'true')
        copy_options['ttl'] = int(opts.pop('ttl', -1))

        # Hidden properties, they do not appear in the documentation but can be set in config files
        # or on the cmd line but w/o completion
        copy_options['maxinflightmessages'] = int(opts.pop('maxinflightmessages', '512'))
        copy_options['maxbackoffattempts'] = int(opts.pop('maxbackoffattempts', '12'))
        copy_options['maxpendingchunks'] = int(opts.pop('maxpendingchunks', '24'))
        # set requesttimeout to a value high enough so that maxbatchsize rows will never timeout if the server
        # responds: here we set it to 1 sec per 10 rows but no less than 60 seconds
        copy_options['requesttimeout'] = int(opts.pop('requesttimeout', max(60, 1 * copy_options['maxbatchsize'] / 10)))
        # set childtimeout higher than requesttimeout so that child processes have a chance to report request timeouts
        copy_options['childtimeout'] = int(opts.pop('childtimeout', copy_options['requesttimeout'] + 30))

        self.check_options(copy_options)
        return CopyOptions(copy=copy_options, dialect=dialect_options, unrecognized=opts)
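
    # Illustrative note (not part of the packaged module): parse_options() pops recognized keys
    # and leaves the rest in `unrecognized`, which the import/export tasks later report as an
    # error. A rough sketch of the mapping, with made-up input values (`task` stands for a
    # CopyTask instance):
    #
    #     opts = {'chunksize': '100', 'header': 'TRUE', 'bogus': '1'}
    #     copy, dialect, unrecognized = task.parse_options(opts, 'from')
    #     # copy['chunksize'] == 100, copy['header'] is True, unrecognized == {'bogus': '1'}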

    @staticmethod
    def check_options(copy_options):
        """
        Check any options that require a sanity check beyond a simple type conversion and if required
        raise a value error:

        - boolean styles must be exactly 2, they must be different and they cannot be empty
        """
        bool_styles = copy_options['boolstyle']
        if len(bool_styles) != 2 or bool_styles[0] == bool_styles[1] or not bool_styles[0] or not bool_styles[1]:
            raise ValueError("Invalid boolean styles %s" % copy_options['boolstyle'])

    @staticmethod
    def get_num_processes(cap):
        """
        Pick a reasonable number of child processes. We need to leave at
        least one core for the parent or feeder process.
        """
        return max(1, min(cap, CopyTask.get_num_cores() - 1))

    @staticmethod
    def get_num_cores():
        """
        Return the number of cores if available. If the test environment variable
        is set, then return the number carried by this variable. This is to test single-core
        machine more easily.
        """
        try:
            num_cores_for_testing = os.environ.get('CQLSH_COPY_TEST_NUM_CORES', '')
            ret = int(num_cores_for_testing) if num_cores_for_testing else mp.cpu_count()
            printdebugmsg("Detected %d core(s)" % (ret,))
            return ret
        except NotImplementedError:
            printdebugmsg("Failed to detect number of cores, returning 1")
            return 1

    @staticmethod
    def describe_interval(seconds):
        desc = []
        for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
            num = int(seconds) / length
            if num > 0:
                desc.append('%d %s' % (num, unit))
                if num > 1:
                    desc[-1] += 's'
            seconds %= length
        words = '%.03f seconds' % seconds
        if len(desc) > 1:
            words = ', '.join(desc) + ', and ' + words
        elif len(desc) == 1:
            words = desc[0] + ' and ' + words
        return words

    @staticmethod
    def get_columns(shell, ks, table, columns):
        """
        Return all columns if none were specified or only the columns specified.
        Possible enhancement: introduce a regex like syntax (^) to allow users
        to specify all columns except a few.
        """
        return shell.get_column_names(ks, table) if not columns else columns

    def close(self):
        self.stop_processes()
        self.inmsg.close()
        self.outmsg.close()

    def num_live_processes(self):
        return sum(1 for p in self.processes if p.is_alive())

    @staticmethod
    def get_pid():
        return os.getpid() if hasattr(os, 'getpid') else None

    @staticmethod
    def trace_process(pid):
        if pid and STRACE_ON:
            os.system("strace -vvvv -c -o strace.{pid}.out -e trace=all -p {pid}&".format(pid=pid))

    def start_processes(self):
        for i, process in enumerate(self.processes):
            process.start()
            self.trace_process(process.pid)
        self.inmsg.release_writers()
        self.outmsg.release_readers()
        self.trace_process(self.get_pid())

    def stop_processes(self):
        for process in self.processes:
            process.terminate()

    def make_params(self):
        """
        Return a dictionary of parameters to be used by the worker processes.
        On platforms using 'spawn' as the default multiprocessing start method,
        this dictionary must be picklable.
        """
        shell = self.shell

        return dict(ks=self.ks,
                    table=self.table,
                    local_dc=self.host.datacenter,
                    columns=self.columns,
                    options=self.options,
                    connect_timeout=shell.conn.connect_timeout,
                    hostname=self.host.address,
                    port=shell.port,
                    ssl=shell.ssl,
                    auth_provider=shell.auth_provider,
                    cql_version=shell.conn.cql_version,
                    config_file=self.config_file,
                    protocol_version=self.protocol_version,
                    debug=shell.debug,
                    coverage=shell.coverage,
                    coveragerc_path=shell.coveragerc_path
                    )

    def validate_columns(self):
        shell = self.shell

        if not self.columns:
            shell.printerr("No column specified")
            return False

        for c in self.columns:
            if c not in self.table_meta.columns:
                shell.printerr('Invalid column name %s' % (c,))
                return False

        return True

    def update_params(self, params, i):
        """
        Add the communication pipes to the parameters to be passed to the worker process:
        inpipe is the message pipe flowing from parent to child process, so outpipe from the parent point
        of view and, vice-versa, outpipe is the message pipe flowing from child to parent, so inpipe
        from the parent point of view, hence the two are swapped below.
        """
        params['inpipe'] = self.outmsg.pipes[i]
        params['outpipe'] = self.inmsg.pipes[i]
        return params


class ExportWriter(object):
    """
    A class that writes to one or more csv files, or STDOUT
    """

    def __init__(self, fname, shell, columns, options):
        self.fname = fname
        self.shell = shell
        self.columns = columns
        self.options = options
        self.header = options.copy['header']
        self.max_output_size = int(options.copy['maxoutputsize'])
        self.current_dest = None
        self.num_files = 0

        if self.max_output_size > 0:
            if fname is not None:
                self.write = self._write_with_split
                self.num_written = 0
            else:
                shell.printerr("WARNING: maxoutputsize {} ignored when writing to STDOUT".format(self.max_output_size))
                self.write = self._write_without_split
        else:
            self.write = self._write_without_split

    def open(self):
        self.current_dest = self._get_dest(self.fname)
        if self.current_dest is None:
            return False

        if self.header:
            writer = csv.writer(self.current_dest.output, **self.options.dialect)
            writer.writerow([str(c) for c in self.columns])

        return True

    def close(self):
        self._close_current_dest()

    def _next_dest(self):
        self._close_current_dest()
        self.current_dest = self._get_dest(self.fname + '.%d' % (self.num_files,))

    def _get_dest(self, source_name):
        """
        Open the output file if any or else use stdout. Return a namedtuple
        containing the out and a boolean indicating if the output should be closed.
        """
        CsvDest = namedtuple('CsvDest', 'output close')

        if self.fname is None:
            return CsvDest(output=sys.stdout, close=False)
        else:
            try:
                ret = CsvDest(output=open(source_name, 'w'), close=True)
                self.num_files += 1
                return ret
            except IOError as e:
                self.shell.printerr("Can't open %r for writing: %s" % (source_name, e))
                return None

    def _close_current_dest(self):
        if self.current_dest and self.current_dest.close:
            self.current_dest.output.close()
            self.current_dest = None

    def _write_without_split(self, data, _):
        """
        Write the data to the current destination output.
        """
        self.current_dest.output.write(data)

    def _write_with_split(self, data, num):
        """
        Write the data to the current destination output if we still
        haven't reached the maximum number of rows. Otherwise split
        the rows between the current destination and the next.
        """
        if (self.num_written + num) > self.max_output_size:
            num_remaining = self.max_output_size - self.num_written
            last_switch = 0
            for i, row in enumerate([_f for _f in data.split(os.linesep) if _f]):
                if i == num_remaining:
                    self._next_dest()
                    last_switch = i
                    num_remaining += self.max_output_size
                self.current_dest.output.write(row + '\n')

            self.num_written = num - last_switch
        else:
            self.num_written += num
            self.current_dest.output.write(data)
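
# Illustrative note (not part of the packaged module): with MAXOUTPUTSIZE set, _write_with_split
# rolls over to numbered files once max_output_size rows have been written, so with a made-up
# fname of 'ks.table.csv' the output lands in 'ks.table.csv', 'ks.table.csv.1', 'ks.table.csv.2',
# and so on. A rough sketch of the call pattern (names and data are examples):
#
#     writer = ExportWriter('ks.table.csv', shell, ['a', 'b'], options)
#     writer.open()
#     writer.write('1,x\n2,y\n', 2)   # data is pre-rendered csv text, num is the row count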


class ExportTask(CopyTask):
    """
    A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
    """
    def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file):
        CopyTask.__init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, 'to')

        options = self.options
        self.begin_token = int(options.copy['begintoken']) if options.copy['begintoken'] else None
        self.end_token = int(options.copy['endtoken']) if options.copy['endtoken'] else None
        self.writer = ExportWriter(fname, shell, columns, options)

    def run(self):
        """
        Initiates the export by starting the worker processes.
        Then hand over control to export_records.
        """
        shell = self.shell

        if self.options.unrecognized:
            shell.printerr('Unrecognized COPY TO options: %s' % ', '.join(list(self.options.unrecognized.keys())))
            return

        if not self.validate_columns():
            return 0

        ranges = self.get_ranges()
        if not ranges:
            return 0

        if not self.writer.open():
            return 0

        columns = "[" + ", ".join(self.columns) + "]"
        self.printmsg("\nStarting copy of %s.%s with columns %s." % (self.ks, self.table, columns))

        params = self.make_params()
        for i in range(self.num_processes):
            self.processes.append(ExportProcess(self.update_params(params, i)))

        self.start_processes()

        try:
            self.export_records(ranges)
        finally:
            self.close()

    def close(self):
        CopyTask.close(self)
        self.writer.close()

    def get_ranges(self):
        """
        return a queue of tuples, where the first tuple entry is a token range (from, to]
        and the second entry is a list of hosts that own that range. Each host is responsible
        for all the tokens in the range (from, to].

        The ring information comes from the driver metadata token map, which is built by
        querying System.PEERS.

        We only consider replicas that are in the local datacenter. If there are no local replicas
        we use the cqlsh session host.
        """
        shell = self.shell
        hostname = self.host.address
        local_dc = self.host.datacenter
        ranges = dict()
        min_token = self.get_min_token()
        begin_token = self.begin_token
        end_token = self.end_token

        def make_range(prev, curr):
            """
            Return the intersection of (prev, curr) and (begin_token, end_token),
            return None if the intersection is empty
            """
            ret = (prev, curr)
            if begin_token:
                if curr < begin_token:
                    return None
                elif (prev is None) or (prev < begin_token):
                    ret = (begin_token, curr)

            if end_token:
                if (ret[0] is not None) and (ret[0] > end_token):
                    return None
                elif (curr is not None) and (curr > end_token):
                    ret = (ret[0], end_token)

            return ret

        def make_range_data(replicas=None):
            hosts = []
            if replicas:
                for r in replicas:
                    if r.is_up is not False and r.datacenter == local_dc:
                        hosts.append(r.address)
            if not hosts:
                hosts.append(hostname)  # fallback to default host if no replicas in current dc
            return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0, 'workerno': -1}

        if begin_token and begin_token < min_token:
            shell.printerr('Begin token %d must be bigger or equal to min token %d' % (begin_token, min_token))
            return ranges

        if begin_token and end_token and begin_token > end_token:
            shell.printerr('Begin token %d must be smaller than end token %d' % (begin_token, end_token))
            return ranges

        if shell.conn.metadata.token_map is None or min_token is None:
            ranges[(begin_token, end_token)] = make_range_data()
            return ranges

        ring = list(shell.get_ring(self.ks).items())
        ring.sort()

        if not ring:
            # If the ring is empty we get the entire ring from the host we are currently connected to
            ranges[(begin_token, end_token)] = make_range_data()
        elif len(ring) == 1:
            # If there is only one token we get the entire ring from the replicas for that token
            ranges[(begin_token, end_token)] = make_range_data(ring[0][1])
        else:
            # else we loop on the ring
            first_range_data = None
            previous = None
            for token, replicas in ring:
                if not first_range_data:
                    first_range_data = make_range_data(replicas)  # we use it at the end when wrapping around

                if token.value == min_token:
                    continue  # avoids looping entire ring

                current_range = make_range(previous, token.value)
                if not current_range:
                    continue

                ranges[current_range] = make_range_data(replicas)
                previous = token.value

            # For the last ring interval we query the same replicas that hold the first token in the ring
            if previous is not None and (not end_token or previous < end_token):
                ranges[(previous, end_token)] = first_range_data
            # TODO: fix this logic added in 4.0: if previous is None, then it can't be compared with less than
            elif previous is None and (not end_token or previous < end_token):
                previous = begin_token if begin_token else min_token
                ranges[(previous, end_token)] = first_range_data

        if not ranges:
            shell.printerr('Found no ranges to query, check begin and end tokens: %s - %s' % (begin_token, end_token))

        return ranges

    def get_min_token(self):
        """
        :return the minimum token, which depends on the partitioner.
        For partitioners that do not support tokens we return None, in
        these cases we will not work in parallel, we'll just send all requests
        to the cqlsh session host.
        """
        partitioner = self.shell.conn.metadata.partitioner

        if partitioner.endswith('RandomPartitioner'):
            return -1
        elif partitioner.endswith('Murmur3Partitioner'):
            return -(2 ** 63)  # Long.MIN_VALUE in Java
        else:
            return None
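
    # Illustrative note (not part of the packaged module): for the Murmur3 partitioner the full
    # token domain is [-2**63, 2**63 - 1], so get_ranges() effectively splits that interval along
    # the ring, one (prev, curr] slice per ring token, e.g.:
    #
    #     >>> -(2 ** 63)
    #     -9223372036854775808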

    def send_work(self, ranges, tokens_to_send):
        prev_worker_no = ranges[tokens_to_send[0]]['workerno']
        i = prev_worker_no + 1 if -1 <= prev_worker_no < (self.num_processes - 1) else 0

        for token_range in tokens_to_send:
            ranges[token_range]['workerno'] = i
            self.outmsg.channels[i].send((token_range, ranges[token_range]))
            ranges[token_range]['attempts'] += 1

            i = i + 1 if i < self.num_processes - 1 else 0

    def export_records(self, ranges):
        """
        Send records to child processes and monitor them by collecting their results
        or any errors. We terminate when we have processed all the ranges or when one child
        process has died (since in this case we will never get any ACK for the ranges
        processed by it and at the moment we don't keep track of which ranges a
        process is handling).
        """
        shell = self.shell
        processes = self.processes
        meter = RateMeter(log_fcn=self.printmsg,
                          update_interval=self.options.copy['reportfrequency'],
                          log_file=self.options.copy['ratefile'])
        total_requests = len(ranges)
        max_attempts = self.options.copy['maxattempts']

        self.send_work(ranges, list(ranges.keys()))

        num_processes = len(processes)
        succeeded = 0
        failed = 0
        while (failed + succeeded) < total_requests and self.num_live_processes() == num_processes:
            for token_range, result in self.inmsg.recv(timeout=0.1):
                if token_range is None and result is None:  # a request has finished
                    succeeded += 1
                elif isinstance(result, Exception):  # an error occurred
                    # This token_range failed, retry up to max_attempts if no rows received yet,
                    # If rows were already received we'd risk duplicating data.
                    # Note that there is still a slight risk of duplicating data, even if we have
                    # an error with no rows received yet, it's just less likely. To avoid retrying on
                    # all timeouts would however mean we could risk not exporting some rows.
                    if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0:
                        shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
                                       % (token_range, result, ranges[token_range]['attempts'], max_attempts))
                        self.send_work(ranges, [token_range])
                    else:
                        shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)'
                                       % (token_range, result, ranges[token_range]['rows'],
                                          ranges[token_range]['attempts']))
                        failed += 1
                else:  # partial result received
                    data, num = result
                    self.writer.write(data, num)
                    meter.increment(n=num)
                    ranges[token_range]['rows'] += num

        if self.num_live_processes() < len(processes):
            for process in processes:
                if not process.is_alive():
                    shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode))

        if succeeded < total_requests:
            shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
                           % (succeeded, total_requests))

        self.printmsg("\n%d rows exported to %d files in %s." %
                      (meter.get_total_records(),
                       self.writer.num_files,
                       self.describe_interval(time.time() - self.time_start)))
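
# Illustrative note (not part of the packaged module): send_work() hands token ranges to worker
# channels round-robin, resuming after the worker that received the previous batch, so with three
# workers the ranges go to workers 0, 1, 2, 0, 1, ... and a retried range, e.g.
#
#     task.send_work(ranges, [token_range])
#
# lands on the next worker in the cycle rather than the one that just failed.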


class FilesReader(object):
    """
    A wrapper around a csv reader to keep track of when we have
    exhausted reading input files. We are passed a comma separated
    list of paths, where each path is a valid glob expression.
    We generate a source generator and we read each source one
    by one.
    """
    def __init__(self, fname, options):
        self.chunk_size = options.copy['chunksize']
        self.header = options.copy['header']
        self.max_rows = options.copy['maxrows']
        self.skip_rows = options.copy['skiprows']
        self.fname = fname
        self.sources = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.num_sources = 0
        self.current_source = None
        self.num_read = 0

    @staticmethod
    def get_source(paths):
        """
        Return a source generator. Each source is a named tuple
        wrapping the source input, file name and a boolean indicating
        if it requires closing.
        """
        def make_source(fname):
            try:
                return open(fname, 'r')
            except IOError as e:
                raise IOError("Can't open %r for reading: %s" % (fname, e))

        for path in paths.split(','):
            path = path.strip()
            if os.path.isfile(path):
                yield make_source(path)
            else:
                result = glob.glob(path)
                if len(result) == 0:
                    raise IOError("Can't open %r for reading: no matching file found" % (path,))

                for f in result:
                    yield make_source(f)

    def start(self):
        self.sources = self.get_source(self.fname)
        self.next_source()

    @property
    def exhausted(self):
        return not self.current_source

    def next_source(self):
        """
        Close the current source, if any, and open the next one. Return true
        if there is another source, false otherwise.
        """
        self.close_current_source()
        while self.current_source is None:
            try:
                self.current_source = next(self.sources)
                if self.current_source:
                    self.num_sources += 1
            except StopIteration:
                return False

        if self.header:
            next(self.current_source)

        return True

    def close_current_source(self):
        if not self.current_source:
            return

        self.current_source.close()
        self.current_source = None

    def close(self):
        self.close_current_source()

    def read_rows(self, max_rows):
        if not self.current_source:
            return []

        rows = []
        for i in range(min(max_rows, self.chunk_size)):
            try:
                row = next(self.current_source)
                self.num_read += 1

                if 0 <= self.max_rows < self.num_read:
                    self.next_source()
                    break

                if self.num_read > self.skip_rows:
                    rows.append(row)

            except StopIteration:
                self.next_source()
                break

        return [_f for _f in rows if _f]
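
# Illustrative note (not part of the packaged module): FilesReader accepts a comma separated list
# of glob patterns, so a COPY FROM of several files might be driven roughly like this (the paths
# are made up):
#
#     reader = FilesReader('data/part1.csv, data/more-*.csv', options)
#     reader.start()
#     while not reader.exhausted:
#         chunk = reader.read_rows(options.copy['chunksize'])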


class PipeReader(object):
    """
    A class for reading rows received on a pipe, this is used for reading input from STDIN
    """
    def __init__(self, inpipe, options):
        self.inpipe = inpipe
        self.chunk_size = options.copy['chunksize']
        self.header = options.copy['header']
        self.max_rows = options.copy['maxrows']
        self.skip_rows = options.copy['skiprows']
        self.num_read = 0
        self.exhausted = False
        self.num_sources = 1

    def start(self):
        pass

    def read_rows(self, max_rows):
        rows = []
        for i in range(min(max_rows, self.chunk_size)):
            row = self.inpipe.recv()
            if row is None:
                self.exhausted = True
                break

            self.num_read += 1
            if 0 <= self.max_rows < self.num_read:
                self.exhausted = True
                break  # max rows exceeded

            if self.header or self.num_read < self.skip_rows:
                self.header = False  # skip header or initial skip_rows rows
                continue

            rows.append(row)

        return rows


class ImportProcessResult(object):
    """
    An object sent from ImportProcess instances to the parent import task in order to indicate progress.
    """
    def __init__(self, imported=0):
        self.imported = imported


class FeedingProcessResult(object):
    """
    An object sent from FeedingProcess instances to the parent import task in order to indicate progress.
    """
    def __init__(self, sent, reader):
        self.sent = sent
        self.num_sources = reader.num_sources
        self.skip_rows = reader.skip_rows


class ImportTaskError(object):
    """
    An object sent from child processes (feeder or workers) to the parent import task to indicate an error.
    """
    def __init__(self, name, msg, rows=None, attempts=1, final=True):
        self.name = name
        self.msg = msg
        self.rows = rows if rows else []
        self.attempts = attempts
        self.final = final

    def is_parse_error(self):
        """
        We treat read and parse errors as unrecoverable and we have different global counters for giving up when
        a maximum has been reached. We consider value and type errors as parse errors as well since they
        are typically non recoverable.
        """
        name = self.name
        return name.startswith('ValueError') or name.startswith('TypeError') or \
            name.startswith('ParseError') or name.startswith('IndexError') or name.startswith('ReadError')
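
# Illustrative note (not part of the packaged module): is_parse_error() keys off the error *name*
# string, so a worker that catches, say, a ValueError while converting a field would report it as
#
#     ImportTaskError('ValueError', str(exc), rows=batch)   # `batch` stands for the failed rows
#
# and the parent counts it against MAXPARSEERRORS rather than MAXINSERTERRORS.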


class ImportErrorHandler(object):
    """
    A class for managing import errors
    """
    def __init__(self, task):
        self.shell = task.shell
        self.options = task.options
        self.max_attempts = self.options.copy['maxattempts']
        self.max_parse_errors = self.options.copy['maxparseerrors']
        self.max_insert_errors = self.options.copy['maxinserterrors']
        self.err_file = self.options.copy['errfile']
        self.parse_errors = 0
        self.insert_errors = 0
        self.num_rows_failed = 0

        if os.path.isfile(self.err_file):
            now = datetime.datetime.now()
            old_err_file = self.err_file + now.strftime('.%Y%m%d_%H%M%S')
            printdebugmsg("Renaming existing %s to %s\n" % (self.err_file, old_err_file))
            os.rename(self.err_file, old_err_file)

    def max_exceeded(self):
        if self.insert_errors > self.max_insert_errors >= 0:
            self.shell.printerr("Exceeded maximum number of insert errors %d" % self.max_insert_errors)
            return True

        if self.parse_errors > self.max_parse_errors >= 0:
            self.shell.printerr("Exceeded maximum number of parse errors %d" % self.max_parse_errors)
            return True

        return False

    def add_failed_rows(self, rows):
        self.num_rows_failed += len(rows)

        with open(self.err_file, "a") as f:
            writer = csv.writer(f, **self.options.dialect)
            for row in rows:
                writer.writerow(row)

    def handle_error(self, err):
        """
        Handle an error by printing the appropriate error message and incrementing the correct counter.
        """
        shell = self.shell

        if err.is_parse_error():
            self.parse_errors += len(err.rows)
            self.add_failed_rows(err.rows)
            shell.printerr("Failed to import %d rows: %s - %s, given up without retries"
                           % (len(err.rows), err.name, err.msg))
        else:
            if not err.final:
                shell.printerr("Failed to import %d rows: %s - %s, will retry later, attempt %d of %d"
                               % (len(err.rows), err.name, err.msg, err.attempts, self.max_attempts))
            else:
                self.insert_errors += len(err.rows)
                self.add_failed_rows(err.rows)
                shell.printerr("Failed to import %d rows: %s - %s, given up after %d attempts"
                               % (len(err.rows), err.name, err.msg, err.attempts))
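
# Illustrative note (not part of the packaged module): failed rows are appended to ERRFILE
# (import_<ks>_<table>.err by default) using the same csv dialect as the input, and an existing
# err file is rotated on startup by appending a timestamp suffix, e.g.
# 'import_ks_table.err.20240101_120000' (keyspace, table and timestamp are made up).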


class ImportTask(CopyTask):
    """
    A class to import data from .csv by instantiating one or more processes
    that work in parallel (ImportProcess).
    """
    def __init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file):
        CopyTask.__init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, 'from')

        options = self.options
        self.skip_columns = [c.strip() for c in self.options.copy['skipcols'].split(',')]
        self.valid_columns = [c for c in self.columns if c not in self.skip_columns]
        self.receive_meter = RateMeter(log_fcn=self.printmsg,
                                       update_interval=options.copy['reportfrequency'],
                                       log_file=options.copy['ratefile'])
        self.error_handler = ImportErrorHandler(self)
        self.feeding_result = None
        self.sent = 0

    def make_params(self):
        ret = CopyTask.make_params(self)
        ret['skip_columns'] = self.skip_columns
        ret['valid_columns'] = self.valid_columns
        return ret

    def validate_columns(self):
        if not CopyTask.validate_columns(self):
            return False

        shell = self.shell
        if not self.valid_columns:
            shell.printerr("No valid column specified")
            return False

        for c in self.table_meta.primary_key:
            if c.name not in self.valid_columns:
                shell.printerr("Primary key column '%s' missing or skipped" % (c.name,))
                return False

        return True

    def run(self):
        shell = self.shell

        if self.options.unrecognized:
            shell.printerr('Unrecognized COPY FROM options: %s' % ', '.join(list(self.options.unrecognized.keys())))
            return

        if not self.validate_columns():
            return 0

        columns = "[" + ", ".join(self.valid_columns) + "]"
        self.printmsg("\nStarting copy of %s.%s with columns %s." % (self.ks, self.table, columns))

        try:
            params = self.make_params()

            for i in range(self.num_processes - 1):
                self.processes.append(ImportProcess(self.update_params(params, i)))

            feeder = FeedingProcess(self.outmsg.pipes[-1], self.inmsg.pipes[-1],
                                    self.outmsg.pipes[:-1], self.fname, self.options)
            self.processes.append(feeder)

            self.start_processes()

            pr = profile_on() if PROFILE_ON else None

            self.import_records()

            if pr:
                profile_off(pr, file_name='parent_profile_%d.txt' % (os.getpid(),))

        except Exception as exc:
            shell.printerr(str(exc))
            if shell.debug:
                traceback.print_exc()
            return 0
        finally:
            self.close()

    def send_stdin_rows(self):
        """
        We need to pass stdin rows to the feeder process as it is not safe to pickle or share stdin
        directly (in case of file the child process would close it). This is a very primitive support
        for STDIN import in that we won't start reporting progress until STDIN is fully consumed. I
        think this is reasonable.
        """
        shell = self.shell

        self.printmsg("[Use . on a line by itself to end input]")
        for row in shell.use_stdin_reader(prompt='[copy] ', until=r'.'):
            self.outmsg.channels[-1].send(row)

        self.outmsg.channels[-1].send(None)
        if shell.tty:
            print()

    def import_records(self):
        """
        Keep on running until we have stuff to receive or send and until all processes are running.
        Send data (batches or retries) up to the max ingest rate. If we are waiting for stuff to
        receive check the incoming queue.
        """
        if not self.fname:
            self.send_stdin_rows()

        child_timeout = self.options.copy['childtimeout']
        last_recv_num_records = 0
        last_recv_time = time.time()

        while self.feeding_result is None or self.receive_meter.total_records < self.feeding_result.sent:
            self.receive_results()

            if self.feeding_result is not None:
                if self.receive_meter.total_records != last_recv_num_records:
                    last_recv_num_records = self.receive_meter.total_records
                    last_recv_time = time.time()
                elif (time.time() - last_recv_time) > child_timeout:
                    self.shell.printerr("No records inserted in {} seconds, aborting".format(child_timeout))
                    break

            if self.error_handler.max_exceeded() or not self.all_processes_running():
                break

        if self.error_handler.num_rows_failed:
            self.shell.printerr("Failed to process %d rows; failed rows written to %s" %
                                (self.error_handler.num_rows_failed,
                                 self.error_handler.err_file))

        if not self.all_processes_running():
            self.shell.printerr("{} child process(es) died unexpectedly, aborting"
                                .format(self.num_processes - self.num_live_processes()))
        else:
            if self.error_handler.max_exceeded():
                self.processes[-1].terminate()  # kill the feeder

            for i, _ in enumerate(self.processes):
                if self.processes[i].is_alive():
                    self.outmsg.channels[i].send(None)

            # allow time for worker processes to exit cleanly
            attempts = 50  # 100 milliseconds per attempt, so 5 seconds total
            while attempts > 0 and self.num_live_processes() > 0:
                time.sleep(0.1)
                attempts -= 1

        self.printmsg("\n%d rows imported from %d files in %s (%d skipped)." %
                      (self.receive_meter.get_total_records() - self.error_handler.num_rows_failed,
                       self.feeding_result.num_sources if self.feeding_result else 0,
                       self.describe_interval(time.time() - self.time_start),
                       self.feeding_result.skip_rows if self.feeding_result else 0))

    def all_processes_running(self):
        return self.num_live_processes() == len(self.processes)

    def receive_results(self):
        """
        Receive results from the worker processes, which will send the number of rows imported
        or from the feeder process, which will send the number of rows sent when it has finished sending rows.
        """
        aggregate_result = ImportProcessResult()
        try:
            for result in self.inmsg.recv(timeout=0.1):
                if isinstance(result, ImportProcessResult):
                    aggregate_result.imported += result.imported
                elif isinstance(result, ImportTaskError):
                    self.error_handler.handle_error(result)
                elif isinstance(result, FeedingProcessResult):
                    self.feeding_result = result
                else:
                    raise ValueError("Unexpected result: %s" % (result,))
        finally:
            self.receive_meter.increment(aggregate_result.imported)
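
# Illustrative note (not part of the packaged module): the import loop keeps running until the
# number of acknowledged rows reaches FeedingProcessResult.sent, which the feeder reports only
# after it has finished reading every source; CHILDTIMEOUT (requesttimeout + 30 seconds by
# default) bounds how long the parent waits without any progress before aborting.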
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
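
# Editor's note: a minimal, self-contained sketch of the "poison pill" shutdown pattern the
# loop above relies on (send None to each live worker, then allow a grace period). The
# Worker/Pipe names here are illustrative stand-ins, not the module's own channel classes.
#
#     import multiprocessing as mp
#
#     def _worker(conn):
#         while True:
#             item = conn.recv()
#             if item is None:          # poison pill: exit cleanly
#                 break
#             # ... process the work item ...
#
#     if __name__ == '__main__':
#         parent_conn, child_conn = mp.Pipe()
#         proc = mp.Process(target=_worker, args=(child_conn,))
#         proc.start()
#         parent_conn.send({'id': 1, 'rows': []})   # normal work item
#         parent_conn.send(None)                    # poison pill
#         proc.join(timeout=5)                      # grace period, as in the loop above
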
class FeedingProcess(mp.Process):
    """
    A process that reads from import sources and sends chunks to worker processes.
    """
    def __init__(self, inpipe, outpipe, worker_pipes, fname, options):
        super(FeedingProcess, self).__init__(target=self.run)
        self.inpipe = inpipe
        self.outpipe = outpipe
        self.worker_pipes = worker_pipes
        self.inmsg = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.outmsg = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.worker_channels = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.reader = FilesReader(fname, options) if fname else PipeReader(inpipe, options)
        self.send_meter = RateMeter(log_fcn=None, update_interval=1)
        self.ingest_rate = options.copy['ingestrate']
        self.num_worker_processes = options.copy['numprocesses']
        self.max_pending_chunks = options.copy['maxpendingchunks']
        self.chunk_id = 0

    def on_fork(self):
        """
        Create the channels and release any parent connections after forking,
        see CASSANDRA-11749 for details.
        """
        self.inmsg = ReceivingChannel(self.inpipe)
        self.outmsg = SendingChannel(self.outpipe)
        self.worker_channels = [SendingChannel(p) for p in self.worker_pipes]

    def run(self):
        pr = profile_on() if PROFILE_ON else None

        self.inner_run()

        if pr:
            profile_off(pr, file_name='feeder_profile_%d.txt' % (os.getpid(),))

    def inner_run(self):
        """
        Send one batch per worker process to the queue unless we have exceeded the ingest rate.
        In the export case we queue everything and let the worker processes throttle using max_requests,
        here we throttle using the ingest rate in the feeding process because of memory usage concerns.
        When finished we send back to the parent process the total number of rows sent.
        """

        self.on_fork()

        reader = self.reader
        try:
            reader.start()
        except IOError as exc:
            self.outmsg.send(
                ImportTaskError(exc.__class__.__name__, exc.message if hasattr(exc, 'message') else str(exc)))

        channels = self.worker_channels
        max_pending_chunks = self.max_pending_chunks
        sent = 0
        failed_attempts = 0

        while not reader.exhausted:
            channels_eligible = [c for c in channels if c.num_pending() < max_pending_chunks]
            if not channels_eligible:
                failed_attempts += 1
                delay = randint(1, pow(2, failed_attempts))
                printdebugmsg("All workers busy, sleeping for %d second(s)" % (delay,))
                time.sleep(delay)
                continue
            elif failed_attempts > 0:
                failed_attempts = 0

            for ch in channels_eligible:
                try:
                    max_rows = self.ingest_rate - self.send_meter.current_record
                    if max_rows <= 0:
                        self.send_meter.maybe_update(sleep=False)
                        continue

                    rows = reader.read_rows(max_rows)
                    if rows:
                        sent += self.send_chunk(ch, rows)
                except Exception as exc:
                    self.outmsg.send(
                        ImportTaskError(exc.__class__.__name__, exc.message if hasattr(exc, 'message') else str(exc)))

                if reader.exhausted:
                    break

        # send back to the parent process the number of rows sent to the worker processes
        self.outmsg.send(FeedingProcessResult(sent, reader))

        # wait for poison pill (None)
        self.inmsg.recv()

    def send_chunk(self, ch, rows):
        self.chunk_id += 1
        num_rows = len(rows)
        self.send_meter.increment(num_rows)
        ch.send({'id': self.chunk_id, 'rows': rows, 'imported': 0, 'num_rows_sent': num_rows})
        return num_rows

    def close(self):
        self.reader.close()
        self.inmsg.close()
        self.outmsg.close()

        for ch in self.worker_channels:
            ch.close()

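
# Editor's note: a small, runnable illustration of the randomised exponential back-off used in
# FeedingProcess.inner_run() above when every worker channel already has max_pending_chunks
# outstanding: the sleep is drawn from [1, 2**failed_attempts] seconds, so repeated congestion
# backs the feeder off quickly. This is a stand-alone sketch, not part of the packaged module.
#
#     from random import randint
#
#     for failed_attempts in range(1, 6):
#         delay = randint(1, pow(2, failed_attempts))
#         print("attempt %d -> sleep between 1 and %d s (chose %d)"
#               % (failed_attempts, pow(2, failed_attempts), delay))
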
class ChildProcess(mp.Process):
    """
    A child worker process, this is for common functionality between ImportProcess and ExportProcess.
    """

    def __init__(self, params, target):
        super(ChildProcess, self).__init__(target=target)
        self.inpipe = params['inpipe']
        self.outpipe = params['outpipe']
        self.inmsg = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.outmsg = None  # might be initialised directly here? (see CASSANDRA-17350)
        self.ks = params['ks']
        self.table = params['table']
        self.local_dc = params['local_dc']
        self.columns = params['columns']
        self.debug = params['debug']
        self.port = params['port']
        self.hostname = params['hostname']
        self.connect_timeout = params['connect_timeout']
        self.cql_version = params['cql_version']
        self.auth_provider = params['auth_provider']
        self.ssl = params['ssl']
        self.protocol_version = params['protocol_version']
        self.config_file = params['config_file']

        options = params['options']
        self.date_time_format = options.copy['dtformats']
        self.consistency_level = options.copy['consistencylevel']
        self.decimal_sep = options.copy['decimalsep']
        self.thousands_sep = options.copy['thousandssep']
        self.boolean_styles = options.copy['boolstyle']
        self.max_attempts = options.copy['maxattempts']
        self.encoding = options.copy['encoding']
        # Here we inject some failures for testing purposes, only if this environment variable is set
        if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
            self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
        else:
            self.test_failures = None
        # attributes for coverage
        self.coverage = params['coverage']
        self.coveragerc_path = params['coveragerc_path']
        self.coverage_collection = None
        self.sigterm_handler = None
        self.sighup_handler = None

    def on_fork(self):
        """
        Create the channels and release any parent connections after forking, see CASSANDRA-11749 for details.
        """
        self.inmsg = ReceivingChannel(self.inpipe)
        self.outmsg = SendingChannel(self.outpipe)

    def close(self):
        printdebugmsg("Closing queues...")
        self.inmsg.close()
        self.outmsg.close()

    def start_coverage(self):
        import coverage
        self.coverage_collection = coverage.Coverage(config_file=self.coveragerc_path)
        self.coverage_collection.start()

        # save current handlers for SIGTERM and SIGHUP
        self.sigterm_handler = signal.getsignal(signal.SIGTERM)
        self.sighup_handler = signal.getsignal(signal.SIGTERM)

        def handle_sigterm():
            self.stop_coverage()
            self.close()
            self.terminate()

        # set custom handler for SIGHUP and SIGTERM
        # needed to make sure coverage data is saved
        signal.signal(signal.SIGTERM, handle_sigterm)
        signal.signal(signal.SIGHUP, handle_sigterm)

    def stop_coverage(self):
        self.coverage_collection.stop()
        self.coverage_collection.save()
        signal.signal(signal.SIGTERM, self.sigterm_handler)
        signal.signal(signal.SIGHUP, self.sighup_handler)

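
# Editor's note: an illustrative, stand-alone sketch of the save/override/restore pattern that
# start_coverage()/stop_coverage() above apply to SIGTERM and SIGHUP so coverage data is flushed
# even when a worker is terminated. The function name is made up for illustration; only SIGTERM
# is shown since SIGHUP is POSIX-only.
#
#     import signal
#
#     def install_flush_on_terminate(flush):
#         previous = signal.getsignal(signal.SIGTERM)    # save the current handler
#
#         def handler(signum, frame):
#             flush()                                    # persist whatever must not be lost
#             signal.signal(signal.SIGTERM, previous)    # restore the saved handler
#
#         signal.signal(signal.SIGTERM, handler)
#         return previous
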
class ExpBackoffRetryPolicy(RetryPolicy):
    """
    A retry policy with exponential back-off for read timeouts and write timeouts
    """
    def __init__(self, parent_process):
        RetryPolicy.__init__(self)
        self.max_attempts = parent_process.max_attempts

    def on_read_timeout(self, query, consistency, required_responses,
                        received_responses, data_retrieved, retry_num):
        return self._handle_timeout(consistency, retry_num)

    def on_write_timeout(self, query, consistency, write_type,
                         required_responses, received_responses, retry_num):
        return self._handle_timeout(consistency, retry_num)

    def _handle_timeout(self, consistency, retry_num):
        delay = self.backoff(retry_num)
        if delay > 0:
            printdebugmsg("Timeout received, retrying after %d seconds" % (delay,))
            time.sleep(delay)
            return self.RETRY, consistency
        elif delay == 0:
            printdebugmsg("Timeout received, retrying immediately")
            return self.RETRY, consistency
        else:
            printdebugmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
            return self.RETHROW, None

    def backoff(self, retry_num):
        """
        Perform exponential back-off up to a maximum number of times, where
        this maximum is per query.
        To back-off we should wait a random number of seconds
        between 0 and 2^c - 1, where c is the number of total failures.

        :return : the number of seconds to wait for, -1 if we should not retry
        """
        if retry_num >= self.max_attempts:
            return -1

        delay = randint(0, pow(2, retry_num + 1) - 1)
        return delay

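
# Editor's note: a worked example of ExpBackoffRetryPolicy.backoff() above. With max_attempts = 5,
# the third retry (retry_num == 2) waits a uniform random number of seconds in [0, 2**3 - 1],
# i.e. [0, 7]; once retry_num reaches max_attempts the method returns -1 and the timeout is
# rethrown. Stand-alone sketch, not part of the packaged module.
#
#     from random import randint
#
#     max_attempts = 5
#     for retry_num in range(7):
#         delay = -1 if retry_num >= max_attempts else randint(0, pow(2, retry_num + 1) - 1)
#         print(retry_num, delay)
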
class ExportSession(object):
    """
    A class for connecting to a cluster and storing the number
    of requests that this connection is processing. It wraps the methods
    for executing a query asynchronously and for shutting down the
    connection to the cluster.
    """
    def __init__(self, cluster, export_process):
        session = cluster.connect(export_process.ks)
        session.row_factory = tuple_factory
        session.default_fetch_size = export_process.options.copy['pagesize']
        session.default_timeout = export_process.options.copy['pagetimeout']

        printdebugmsg("Created connection to %s with page size %d and timeout %d seconds per page"
                      % (cluster.contact_points, session.default_fetch_size, session.default_timeout))

        self.cluster = cluster
        self.session = session
        self.requests = 1
        self.lock = threading.Lock()
        self.consistency_level = export_process.consistency_level

    def add_request(self):
        with self.lock:
            self.requests += 1

    def complete_request(self):
        with self.lock:
            self.requests -= 1

    def num_requests(self):
        with self.lock:
            return self.requests

    def execute_async(self, query):
        return self.session.execute_async(SimpleStatement(query, consistency_level=self.consistency_level))

    def shutdown(self):
        self.cluster.shutdown()

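
# Editor's note: ExportSession above is just a reference-counted wrapper around a driver session.
# A typical caller (see ExportProcess.connect()/attach_callbacks() below) brackets each async
# query roughly like this, assuming `session` is an already-constructed ExportSession:
#
#     session.add_request()                                      # done in connect() when reusing a session
#     future = session.execute_async("SELECT ...")
#     future.add_callbacks(callback=on_rows, errback=on_error)   # both paths call session.complete_request()
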
class ExportProcess(ChildProcess):
    """
    A child worker process for the export task, ExportTask.
    """

    def __init__(self, params):
        ChildProcess.__init__(self, params=params, target=self.run)
        options = params['options']
        self.float_precision = options.copy['floatprecision']
        self.double_precision = options.copy['doubleprecision']
        self.nullval = options.copy['nullval']
        self.max_requests = options.copy['maxrequests']

        self.hosts_to_sessions = dict()
        self.formatters = dict()
        self.options = options

    def run(self):
        if self.coverage:
            self.start_coverage()
        try:
            self.inner_run()
        finally:
            if self.coverage:
                self.stop_coverage()
            self.close()

    def inner_run(self):
        """
        The parent sends us (range, info) on the inbound queue (inmsg)
        in order to request us to process a range, for which we can
        select any of the hosts in info, which also contains other information for this
        range such as the number of attempts already performed. We can signal errors
        on the outbound queue (outmsg) by sending (range, error) or
        we can signal a global error by sending (None, error).
        We terminate when the inbound queue is closed.
        """

        self.on_fork()

        while True:
            if self.num_requests() > self.max_requests:
                time.sleep(0.001)  # 1 millisecond
                continue

            token_range, info = self.inmsg.recv()
            self.start_request(token_range, info)

    @staticmethod
    def get_error_message(err, print_traceback=False):
        if isinstance(err, str):
            msg = err
        elif isinstance(err, BaseException):
            msg = "%s - %s" % (err.__class__.__name__, err)
            if print_traceback and sys.exc_info()[1] == err:
                traceback.print_exc()
        else:
            msg = str(err)
        return msg

    def report_error(self, err, token_range):
        msg = self.get_error_message(err, print_traceback=self.debug)
        printdebugmsg(msg)
        self.send((token_range, Exception(msg)))

    def send(self, response):
        self.outmsg.send(response)

    def start_request(self, token_range, info):
        """
        Begin querying a range by executing an async query that
        will later on invoke the callbacks attached in attach_callbacks.
        """
        session = self.get_session(info['hosts'], token_range)
        if session:
            metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.table]
            query = self.prepare_query(metadata.partition_key, token_range, info['attempts'])
            future = session.execute_async(query)
            self.attach_callbacks(token_range, future, session)

    def num_requests(self):
        return sum(session.num_requests() for session in list(self.hosts_to_sessions.values()))

    def get_session(self, hosts, token_range):
        """
        We return a session connected to one of the hosts passed in, which are valid replicas for
        the token range. We sort replicas by favouring those without any active requests yet or with the
        smallest number of requests. If we fail to connect we report an error so that the token will
        be retried again later.

        :return: An ExportSession connected to the chosen host.
        """
        # sorted replicas favouring those with no connections yet
        hosts = sorted(hosts,
                       key=lambda hh: 0 if hh not in self.hosts_to_sessions else self.hosts_to_sessions[hh].requests)

        errors = []
        ret = None
        for host in hosts:
            try:
                ret = self.connect(host)
            except Exception as e:
                errors.append(self.get_error_message(e))

            if ret:
                if errors:
                    printdebugmsg("Warning: failed to connect to some replicas: %s" % (errors,))
                return ret

        self.report_error("Failed to connect to all replicas %s for %s, errors: %s" % (hosts, token_range, errors),
                          token_range)
        return None

    def connect(self, host):
        if host in list(self.hosts_to_sessions.keys()):
            session = self.hosts_to_sessions[host]
            session.add_request()
            return session

        new_cluster = Cluster(
            contact_points=(host,),
            port=self.port,
            cql_version=self.cql_version,
            protocol_version=self.protocol_version,
            auth_provider=self.auth_provider,
            ssl_context=ssl_settings(host, self.config_file) if self.ssl else None,
            load_balancing_policy=WhiteListRoundRobinPolicy([host]),
            default_retry_policy=ExpBackoffRetryPolicy(self),
            compression=None,
            control_connection_timeout=self.connect_timeout,
            connect_timeout=self.connect_timeout,
            idle_heartbeat_interval=0)
        session = ExportSession(new_cluster, self)
        self.hosts_to_sessions[host] = session
        return session

    def attach_callbacks(self, token_range, future, session):
        metadata = session.cluster.metadata
        ks_meta = metadata.keyspaces[self.ks]
        table_meta = ks_meta.tables[self.table]
        cql_types = [CqlType(table_meta.columns[c].cql_type, ks_meta) for c in self.columns]

        def result_callback(rows):
            if future.has_more_pages:
                future.start_fetching_next_page()
                self.write_rows_to_csv(token_range, rows, cql_types)
            else:
                self.write_rows_to_csv(token_range, rows, cql_types)
                self.send((None, None))
                session.complete_request()

        def err_callback(err):
            self.report_error(err, token_range)
            session.complete_request()

        future.add_callbacks(callback=result_callback, errback=err_callback)

    def write_rows_to_csv(self, token_range, rows, cql_types):
        if not rows:
            return  # no rows in this range

        try:
            output = StringIO()
            writer = csv.writer(output, **self.options.dialect)

            for row in rows:
                writer.writerow(list(map(self.format_value, row, cql_types)))

            data = (output.getvalue(), len(rows))
            self.send((token_range, data))
            output.close()

        except Exception as e:
            self.report_error(e, token_range)

    def format_value(self, val, cqltype):
        if val is None or val == EMPTY:
            return format_value_default(self.nullval, colormap=NO_COLOR_MAP)

        formatter = self.formatters.get(cqltype, None)
        if not formatter:
            formatter = get_formatter(val, cqltype)
            self.formatters[cqltype] = formatter

        if not hasattr(cqltype, 'precision'):
            cqltype.precision = self.double_precision if cqltype.type_name == 'double' else self.float_precision

        formatted = formatter(val, cqltype=cqltype,
                              encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format,
                              float_precision=cqltype.precision, nullval=self.nullval, quote=False,
                              decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep,
                              boolean_styles=self.boolean_styles)
        return formatted

    def close(self):
        ChildProcess.close(self)
        for session in list(self.hosts_to_sessions.values()):
            session.shutdown()

    def prepare_query(self, partition_key, token_range, attempts):
        """
        Return the export query or a fake query with some failure injected.
        """
        if self.test_failures:
            return self.maybe_inject_failures(partition_key, token_range, attempts)
        else:
            return self.prepare_export_query(partition_key, token_range)

    def maybe_inject_failures(self, partition_key, token_range, attempts):
        """
        Examine self.test_failures and see if token_range is either a token range
        supposed to cause a failure (failing_range) or to terminate the worker process
        (exit_range). If not then call prepare_export_query(), which implements the
        normal behavior.
        """
        start_token, end_token = token_range

        if not start_token or not end_token:
            # exclude first and last ranges to make things simpler
            return self.prepare_export_query(partition_key, token_range)

        if 'failing_range' in self.test_failures:
            failing_range = self.test_failures['failing_range']
            if start_token >= failing_range['start'] and end_token <= failing_range['end']:
                if attempts < failing_range['num_failures']:
                    return 'SELECT * from bad_table'

        if 'exit_range' in self.test_failures:
            exit_range = self.test_failures['exit_range']
            if start_token >= exit_range['start'] and end_token <= exit_range['end']:
                sys.exit(1)

        return self.prepare_export_query(partition_key, token_range)

    def prepare_export_query(self, partition_key, token_range):
        """
        Return a query where we select all the data for this token range
        """
        pk_cols = ", ".join(protect_names(col.name for col in partition_key))
        columnlist = ', '.join(protect_names(self.columns))
        start_token, end_token = token_range
        query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.table))
        if start_token is not None or end_token is not None:
            query += ' WHERE'
        if start_token is not None:
            query += ' token(%s) > %s' % (pk_cols, start_token)
        if start_token is not None and end_token is not None:
            query += ' AND'
        if end_token is not None:
            query += ' token(%s) <= %s' % (pk_cols, end_token)
        return query

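
# Editor's note: an example of what prepare_export_query() above emits for an interior token
# range; the keyspace, table and column names are made up for illustration.
#
#     SELECT a, b, c FROM ks.tbl WHERE token(pk) > -3074457345618258603 AND token(pk) <= 3074457345618258602
#
# For the first and last ranges one of the two bounds is None and the corresponding predicate
# is simply omitted, so those ranges are covered by open-ended queries.
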
class ParseError(Exception):
    """ We failed to parse an import record """
    pass


class ImmutableDict(frozenset):
    """
    Immutable dictionary implementation to represent map types.
    We need to pass BoundStatement.bind() a dict() because it calls iteritems(),
    except we can't create a dict with another dict as the key, hence we use a class
    that adds iteritems to a frozen set of tuples (which is how dict are normally made
    immutable in python).
    Must be declared in the top level of the module to be available for pickling.
    """
    iteritems = frozenset.__iter__

    def items(self):
        for k, v in self.iteritems():
            yield k, v

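
# Editor's note: a minimal demonstration of ImmutableDict above: it is just a frozenset of
# (key, value) tuples that additionally exposes iteritems()/items(), which is what
# BoundStatement.bind() expects, while staying hashable so it can itself be used as a map key.
#
#     d = ImmutableDict([('a', 1), ('b', 2)])
#     assert dict(d.items()) == {'a': 1, 'b': 2}
#     assert hash(d) == hash(ImmutableDict([('b', 2), ('a', 1)]))   # order-insensitive, like a frozenset
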
class ImportConversion(object):
|
|
1856
|
+
"""
|
|
1857
|
+
A class for converting strings to values when importing from csv, used by ImportProcess,
|
|
1858
|
+
the parent.
|
|
1859
|
+
"""
|
|
1860
|
+
def __init__(self, parent, table_meta, statement=None):
|
|
1861
|
+
self.ks = parent.ks
|
|
1862
|
+
self.table = parent.table
|
|
1863
|
+
self.columns = parent.valid_columns
|
|
1864
|
+
self.nullval = parent.nullval
|
|
1865
|
+
self.decimal_sep = parent.decimal_sep
|
|
1866
|
+
self.thousands_sep = parent.thousands_sep
|
|
1867
|
+
self.boolean_styles = parent.boolean_styles
|
|
1868
|
+
self.date_time_format = parent.date_time_format.timestamp_format
|
|
1869
|
+
self.debug = parent.debug
|
|
1870
|
+
self.encoding = parent.encoding
|
|
1871
|
+
|
|
1872
|
+
self.table_meta = table_meta
|
|
1873
|
+
self.primary_key_indexes = [self.columns.index(col.name) for col in self.table_meta.primary_key]
|
|
1874
|
+
self.partition_key_indexes = [self.columns.index(col.name) for col in self.table_meta.partition_key]
|
|
1875
|
+
|
|
1876
|
+
if statement is None:
|
|
1877
|
+
self.use_prepared_statements = False
|
|
1878
|
+
statement = self._get_primary_key_statement(parent, table_meta)
|
|
1879
|
+
else:
|
|
1880
|
+
self.use_prepared_statements = True
|
|
1881
|
+
|
|
1882
|
+
self.is_counter = parent.is_counter(table_meta)
|
|
1883
|
+
self.proto_version = statement.protocol_version
|
|
1884
|
+
|
|
1885
|
+
# the cql types and converters for the prepared statement, either the full statement or only the primary keys
|
|
1886
|
+
self.cqltypes = [c.type for c in statement.column_metadata]
|
|
1887
|
+
self.converters = [self._get_converter(c.type) for c in statement.column_metadata]
|
|
1888
|
+
|
|
1889
|
+
# the cql types for the entire statement, these are the same as the types above but
|
|
1890
|
+
# only when using prepared statements
|
|
1891
|
+
self.coltypes = [table_meta.columns[name].cql_type for name in parent.valid_columns]
|
|
1892
|
+
# these functions are used for non-prepared statements to protect values with quotes if required
|
|
1893
|
+
self.protectors = [self._get_protector(t) for t in self.coltypes]
|
|
1894
|
+
|
|
1895
|
+
@staticmethod
|
|
1896
|
+
def _get_protector(t):
|
|
1897
|
+
if t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet'):
|
|
1898
|
+
return lambda v: protect_value(v)
|
|
1899
|
+
else:
|
|
1900
|
+
return lambda v: v
|
|
1901
|
+
|
|
1902
|
+
@staticmethod
|
|
1903
|
+
def _get_primary_key_statement(parent, table_meta):
|
|
1904
|
+
"""
|
|
1905
|
+
We prepare a query statement to find out the types of the partition key columns so we can
|
|
1906
|
+
route the update query to the correct replicas. As far as I understood this is the easiest
|
|
1907
|
+
way to find out the types of the partition columns, we will never use this prepared statement
|
|
1908
|
+
"""
|
|
1909
|
+
where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name)) for c in table_meta.partition_key])
|
|
1910
|
+
select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(parent.ks),
|
|
1911
|
+
protect_name(parent.table),
|
|
1912
|
+
where_clause)
|
|
1913
|
+
return parent.session.prepare(select_query)
|
|
1914
|
+
|
|
1915
|
+
@staticmethod
|
|
1916
|
+
def unprotect(v):
|
|
1917
|
+
if v is not None:
|
|
1918
|
+
return CqlRuleSet.dequote_value(v)
|
|
1919
|
+
|
|
1920
|
+
def _get_converter(self, cql_type):
|
|
1921
|
+
"""
|
|
1922
|
+
Return a function that converts a string into a value the can be passed
|
|
1923
|
+
into BoundStatement.bind() for the given cql type. See cassandra.cqltypes
|
|
1924
|
+
for more details.
|
|
1925
|
+
"""
|
|
1926
|
+
unprotect = self.unprotect
|
|
1927
|
+
|
|
1928
|
+
def convert(t, v):
|
|
1929
|
+
v = unprotect(v)
|
|
1930
|
+
if v == self.nullval:
|
|
1931
|
+
return self.get_null_val()
|
|
1932
|
+
return converters.get(t.typename, convert_unknown)(v, ct=t)
|
|
1933
|
+
|
|
1934
|
+
def convert_mandatory(t, v):
|
|
1935
|
+
v = unprotect(v)
|
|
1936
|
+
# we can't distinguish between empty strings and null values in csv. Null values are not supported in
|
|
1937
|
+
# collections, so it must be an empty string.
|
|
1938
|
+
if v == self.nullval and not issubclass(t, VarcharType):
|
|
1939
|
+
raise ParseError('Empty values are not allowed')
|
|
1940
|
+
return converters.get(t.typename, convert_unknown)(v, ct=t)
|
|
1941
|
+
|
|
1942
|
+
def convert_blob(v, **_):
|
|
1943
|
+
if sys.version_info.major >= 3:
|
|
1944
|
+
return bytes.fromhex(v[2:])
|
|
1945
|
+
else:
|
|
1946
|
+
return BlobType(v[2:].decode("hex"))
|
|
1947
|
+
|
|
1948
|
+
def convert_text(v, **_):
|
|
1949
|
+
return str(v)
|
|
1950
|
+
|
|
1951
|
+
def convert_uuid(v, **_):
|
|
1952
|
+
return UUID(v)
|
|
1953
|
+
|
|
1954
|
+
def convert_bool(v, **_):
|
|
1955
|
+
return True if v.lower() == self.boolean_styles[0].lower() else False
|
|
1956
|
+
|
|
1957
|
+
def get_convert_integer_fcn(adapter=int):
|
|
1958
|
+
"""
|
|
1959
|
+
Return a slow and a fast integer conversion function depending on self.thousands_sep
|
|
1960
|
+
"""
|
|
1961
|
+
if self.thousands_sep:
|
|
1962
|
+
return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, ''))
|
|
1963
|
+
else:
|
|
1964
|
+
return lambda v, ct=cql_type: adapter(v)
|
|
1965
|
+
|
|
1966
|
+
def get_convert_decimal_fcn(adapter=float):
|
|
1967
|
+
"""
|
|
1968
|
+
Return a slow and a fast decimal conversion function depending on self.thousands_sep and self.decimal_sep
|
|
1969
|
+
"""
|
|
1970
|
+
empty_str = ''
|
|
1971
|
+
dot_str = '.'
|
|
1972
|
+
if self.thousands_sep and self.decimal_sep:
|
|
1973
|
+
return lambda v, ct=cql_type: \
|
|
1974
|
+
adapter(v.replace(self.thousands_sep, empty_str).replace(self.decimal_sep, dot_str))
|
|
1975
|
+
elif self.thousands_sep:
|
|
1976
|
+
return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, empty_str))
|
|
1977
|
+
elif self.decimal_sep:
|
|
1978
|
+
return lambda v, ct=cql_type: adapter(v.replace(self.decimal_sep, dot_str))
|
|
1979
|
+
else:
|
|
1980
|
+
return lambda v, ct=cql_type: adapter(v)
|
|
1981
|
+
|
|
1982
|
+
def split(val, sep=','):
|
|
1983
|
+
"""
|
|
1984
|
+
Split "val" into a list of values whenever the separator "sep" is found, but
|
|
1985
|
+
ignore separators inside parentheses or single quotes, except for the two
|
|
1986
|
+
outermost parentheses, which will be ignored. This method is called when parsing composite
|
|
1987
|
+
types, "val" should be at least 2 characters long, the first char should be an
|
|
1988
|
+
open parenthesis and the last char should be a matching closing parenthesis. We could also
|
|
1989
|
+
check exactly which parenthesis type depending on the caller, but I don't want to enforce
|
|
1990
|
+
too many checks that don't necessarily provide any additional benefits, and risk breaking
|
|
1991
|
+
data that could previously be imported, even if strictly speaking it is incorrect CQL.
|
|
1992
|
+
For example, right now we accept sets that start with '[' and ']', I don't want to break this
|
|
1993
|
+
by enforcing '{' and '}' in a minor release.
|
|
1994
|
+
"""
|
|
1995
|
+
def is_open_paren(cc):
|
|
1996
|
+
return cc == '{' or cc == '[' or cc == '('
|
|
1997
|
+
|
|
1998
|
+
def is_close_paren(cc):
|
|
1999
|
+
return cc == '}' or cc == ']' or cc == ')'
|
|
2000
|
+
|
|
2001
|
+
def paren_match(c1, c2):
|
|
2002
|
+
return (c1 == '{' and c2 == '}') or (c1 == '[' and c2 == ']') or (c1 == '(' and c2 == ')')
|
|
2003
|
+
|
|
2004
|
+
if len(val) < 2 or not paren_match(val[0], val[-1]):
|
|
2005
|
+
raise ParseError('Invalid composite string, it should start and end with matching parentheses: {}'
|
|
2006
|
+
.format(val))
|
|
2007
|
+
|
|
2008
|
+
ret = []
|
|
2009
|
+
last = 1
|
|
2010
|
+
level = 0
|
|
2011
|
+
quote = False
|
|
2012
|
+
for i, c in enumerate(val):
|
|
2013
|
+
if c == '\'':
|
|
2014
|
+
quote = not quote
|
|
2015
|
+
elif not quote:
|
|
2016
|
+
if is_open_paren(c):
|
|
2017
|
+
level += 1
|
|
2018
|
+
elif is_close_paren(c):
|
|
2019
|
+
level -= 1
|
|
2020
|
+
elif c == sep and level == 1:
|
|
2021
|
+
ret.append(val[last:i])
|
|
2022
|
+
last = i + 1
|
|
2023
|
+
else:
|
|
2024
|
+
if last < len(val) - 1:
|
|
2025
|
+
ret.append(val[last:-1])
|
|
2026
|
+
|
|
2027
|
+
return ret
|
|
2028
|
+
|
|
2029
|
+
# this should match all possible CQL and CQLSH datetime formats
|
|
2030
|
+
p = re.compile(r"(\d{4})-(\d{2})-(\d{2})\s?(?:'T')?" # YYYY-MM-DD[( |'T')]
|
|
2031
|
+
+ r"(?:(\d{2}):(\d{2})(?::(\d{2})(?:\.(\d{1,6}))?))?" # [HH:MM[:SS[.NNNNNN]]]
|
|
2032
|
+
+ r"(?:([+\-])(\d{2}):?(\d{2}))?") # [(+|-)HH[:]MM]]
|
|
2033
|
+
|
|
2034
|
+
def convert_datetime(val, **_):
|
|
2035
|
+
try:
|
|
2036
|
+
dtval = datetime.datetime.strptime(val, self.date_time_format)
|
|
2037
|
+
return dtval.timestamp() * 1000
|
|
2038
|
+
except ValueError:
|
|
2039
|
+
pass # if it's not in the default format we try CQL formats
|
|
2040
|
+
|
|
2041
|
+
m = p.match(val)
|
|
2042
|
+
if not m:
|
|
2043
|
+
try:
|
|
2044
|
+
# in case of overflow COPY TO prints dates as milliseconds from the epoch, see
|
|
2045
|
+
# deserialize_date_fallback_int in cqlsh.py
|
|
2046
|
+
return int(val)
|
|
2047
|
+
except ValueError:
|
|
2048
|
+
raise ValueError("can't interpret %r as a date with format %s or as int" % (val,
|
|
2049
|
+
self.date_time_format))
|
|
2050
|
+
|
|
2051
|
+
# https://docs.python.org/3/library/time.html#time.struct_time
|
|
2052
|
+
tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)), # year, month, day
|
|
2053
|
+
int(m.group(4)) if m.group(4) else 0, # hour
|
|
2054
|
+
int(m.group(5)) if m.group(5) else 0, # minute
|
|
2055
|
+
int(m.group(6)) if m.group(6) else 0, # second
|
|
2056
|
+
0, 1, -1)) # day of week, day of year, dst-flag
|
|
2057
|
+
|
|
2058
|
+
# convert sub-seconds (a number between 1 and 6 digits) to milliseconds
|
|
2059
|
+
milliseconds = 0 if not m.group(7) else int(m.group(7)) * pow(10, 3 - len(m.group(7)))
|
|
2060
|
+
|
|
2061
|
+
if m.group(8):
|
|
2062
|
+
offset = (int(m.group(9)) * 3600 + int(m.group(10)) * 60) * int(m.group(8) + '1')
|
|
2063
|
+
else:
|
|
2064
|
+
offset = -time.timezone
|
|
2065
|
+
|
|
2066
|
+
# scale seconds to millis for the raw value
|
|
2067
|
+
return ((timegm(tval) + offset) * 1000) + milliseconds
|
|
2068
|
+
|
|
2069
|
+
def convert_date(v, **_):
|
|
2070
|
+
return Date(v)
|
|
2071
|
+
|
|
2072
|
+
def convert_time(v, **_):
|
|
2073
|
+
return Time(v)
|
|
2074
|
+
|
|
2075
|
+
def convert_tuple(val, ct=cql_type):
|
|
2076
|
+
return tuple(convert_mandatory(t, v) for t, v in zip(ct.subtypes, split(val)))
|
|
2077
|
+
|
|
2078
|
+
def convert_list(val, ct=cql_type):
|
|
2079
|
+
return tuple(convert_mandatory(ct.subtypes[0], v) for v in split(val))
|
|
2080
|
+
|
|
2081
|
+
def convert_set(val, ct=cql_type):
|
|
2082
|
+
return frozenset(convert_mandatory(ct.subtypes[0], v) for v in split(val))
|
|
2083
|
+
|
|
2084
|
+
def convert_map(val, ct=cql_type):
|
|
2085
|
+
"""
|
|
2086
|
+
See ImmutableDict above for a discussion of why a special object is needed here.
|
|
2087
|
+
"""
|
|
2088
|
+
split_format_str = '{%s}'
|
|
2089
|
+
sep = ':'
|
|
2090
|
+
return ImmutableDict(frozenset((convert_mandatory(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
|
|
2091
|
+
for v in [split(split_format_str % vv, sep=sep) for vv in split(val)]))
|
|
2092
|
+
|
|
2093
|
+
def convert_user_type(val, ct=cql_type):
|
|
2094
|
+
"""
|
|
2095
|
+
A user type is a dictionary except that we must convert each key into
|
|
2096
|
+
an attribute, so we are using named tuples. It must also be hashable,
|
|
2097
|
+
so we cannot use dictionaries. Maybe there is a way to instantiate ct
|
|
2098
|
+
directly but I could not work it out.
|
|
2099
|
+
Also note that it is possible that the subfield names in the csv are in the
|
|
2100
|
+
wrong order, so we must sort them according to ct.fieldnames, see CASSANDRA-12959.
|
|
2101
|
+
"""
|
|
2102
|
+
split_format_str = '{%s}'
|
|
2103
|
+
sep = ':'
|
|
2104
|
+
vals = [v for v in [split(split_format_str % vv, sep=sep) for vv in split(val)]]
|
|
2105
|
+
dict_vals = dict((unprotect(v[0]), v[1]) for v in vals)
|
|
2106
|
+
sorted_converted_vals = [(n, convert(t, dict_vals[n]) if n in dict_vals else self.get_null_val())
|
|
2107
|
+
for n, t in zip(ct.fieldnames, ct.subtypes)]
|
|
2108
|
+
ret_type = namedtuple(ct.typename, [v[0] for v in sorted_converted_vals])
|
|
2109
|
+
return ret_type(*tuple(v[1] for v in sorted_converted_vals))
|
|
2110
|
+
|
|
2111
|
+
def convert_single_subtype(val, ct=cql_type):
|
|
2112
|
+
return converters.get(ct.subtypes[0].typename, convert_unknown)(val, ct=ct.subtypes[0])
|
|
2113
|
+
|
|
2114
|
+
def convert_unknown(val, ct=cql_type):
|
|
2115
|
+
if issubclass(ct, UserType):
|
|
2116
|
+
return convert_user_type(val, ct=ct)
|
|
2117
|
+
elif issubclass(ct, ReversedType):
|
|
2118
|
+
return convert_single_subtype(val, ct=ct)
|
|
2119
|
+
|
|
2120
|
+
printdebugmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
|
|
2121
|
+
return val
|
|
2122
|
+
|
|
2123
|
+
converters = {
|
|
2124
|
+
'blob': convert_blob,
|
|
2125
|
+
'decimal': get_convert_decimal_fcn(adapter=Decimal),
|
|
2126
|
+
'uuid': convert_uuid,
|
|
2127
|
+
'boolean': convert_bool,
|
|
2128
|
+
'tinyint': get_convert_integer_fcn(),
|
|
2129
|
+
'ascii': convert_text,
|
|
2130
|
+
'float': get_convert_decimal_fcn(),
|
|
2131
|
+
'double': get_convert_decimal_fcn(),
|
|
2132
|
+
'bigint': get_convert_integer_fcn(adapter=int),
|
|
2133
|
+
'int': get_convert_integer_fcn(),
|
|
2134
|
+
'varint': get_convert_integer_fcn(),
|
|
2135
|
+
'inet': convert_text,
|
|
2136
|
+
'counter': get_convert_integer_fcn(adapter=int),
|
|
2137
|
+
'timestamp': convert_datetime,
|
|
2138
|
+
'timeuuid': convert_uuid,
|
|
2139
|
+
'date': convert_date,
|
|
2140
|
+
'smallint': get_convert_integer_fcn(),
|
|
2141
|
+
'time': convert_time,
|
|
2142
|
+
'text': convert_text,
|
|
2143
|
+
'varchar': convert_text,
|
|
2144
|
+
'list': convert_list,
|
|
2145
|
+
'set': convert_set,
|
|
2146
|
+
'map': convert_map,
|
|
2147
|
+
'tuple': convert_tuple,
|
|
2148
|
+
'frozen': convert_single_subtype,
|
|
2149
|
+
}
|
|
2150
|
+
|
|
2151
|
+
return converters.get(cql_type.typename, convert_unknown)
|
|
2152
|
+
|
|
2153
|
+
def get_null_val(self):
|
|
2154
|
+
"""
|
|
2155
|
+
Return the null value that is inserted for fields that are missing from csv files.
|
|
2156
|
+
For counters we should return zero so that the counter value won't be incremented.
|
|
2157
|
+
For everything else we return nulls, this means None if we use prepared statements
|
|
2158
|
+
or "NULL" otherwise. Note that for counters we never use prepared statements, so we
|
|
2159
|
+
only check is_counter when use_prepared_statements is false.
|
|
2160
|
+
"""
|
|
2161
|
+
return None if self.use_prepared_statements else ("0" if self.is_counter else "NULL")
|
|
2162
|
+
|
|
2163
|
+
def convert_row(self, row):
|
|
2164
|
+
"""
|
|
2165
|
+
Convert the row into a list of parsed values if using prepared statements, else simply apply the
|
|
2166
|
+
protection functions to escape values with quotes when required. Also check on the row length and
|
|
2167
|
+
make sure primary partition key values aren't missing.
|
|
2168
|
+
"""
|
|
2169
|
+
converters = self.converters if self.use_prepared_statements else self.protectors
|
|
2170
|
+
|
|
2171
|
+
if len(row) != len(converters):
|
|
2172
|
+
raise ParseError('Invalid row length %d should be %d' % (len(row), len(converters)))
|
|
2173
|
+
|
|
2174
|
+
for i in self.primary_key_indexes:
|
|
2175
|
+
if row[i] == self.nullval:
|
|
2176
|
+
raise ParseError(self.get_null_primary_key_message(i))
|
|
2177
|
+
|
|
2178
|
+
def convert(c, v):
|
|
2179
|
+
try:
|
|
2180
|
+
return c(v) if v != self.nullval else self.get_null_val()
|
|
2181
|
+
except Exception as e:
|
|
2182
|
+
# if we could not convert an empty string, then self.nullval has been set to a marker
|
|
2183
|
+
# because the user needs to import empty strings, except that the converters for some types
|
|
2184
|
+
# will fail to convert an empty string, in this case the null value should be inserted
|
|
2185
|
+
# see CASSANDRA-12794
|
|
2186
|
+
if v == '':
|
|
2187
|
+
return self.get_null_val()
|
|
2188
|
+
|
|
2189
|
+
if self.debug:
|
|
2190
|
+
traceback.print_exc()
|
|
2191
|
+
raise ParseError("Failed to parse %s : %s" % (v, e.message if hasattr(e, 'message') else str(e)))
|
|
2192
|
+
|
|
2193
|
+
return [convert(conv, val) for conv, val in zip(converters, row)]
|
|
2194
|
+
|
|
2195
|
+
def get_null_primary_key_message(self, idx):
|
|
2196
|
+
message = "Cannot insert null value for primary key column '%s'." % (self.columns[idx],)
|
|
2197
|
+
if self.nullval == '':
|
|
2198
|
+
message += " If you want to insert empty strings, consider using" \
|
|
2199
|
+
" the WITH NULL=<marker> option for COPY."
|
|
2200
|
+
return message
|
|
2201
|
+
|
|
2202
|
+
def get_row_partition_key_values_fcn(self):
|
|
2203
|
+
"""
|
|
2204
|
+
Return a function to convert a row into a string composed of the partition key values serialized
|
|
2205
|
+
and binary packed (the tokens on the ring). Depending on whether we are using prepared statements, we
|
|
2206
|
+
may have to convert the primary key values first, so we have two different serialize_value implementations.
|
|
2207
|
+
We also return different functions depending on how many partition key indexes we have (single or multiple).
|
|
2208
|
+
See also BoundStatement.routing_key.
|
|
2209
|
+
"""
|
|
2210
|
+
def serialize_value_prepared(n, v):
|
|
2211
|
+
return self.cqltypes[n].serialize(v, self.proto_version)
|
|
2212
|
+
|
|
2213
|
+
def serialize_value_not_prepared(n, v):
|
|
2214
|
+
return self.cqltypes[n].serialize(self.converters[n](self.unprotect(v)), self.proto_version)
|
|
2215
|
+
|
|
2216
|
+
partition_key_indexes = self.partition_key_indexes
|
|
2217
|
+
serialize = serialize_value_prepared if self.use_prepared_statements else serialize_value_not_prepared
|
|
2218
|
+
|
|
2219
|
+
def serialize_row_single(row):
|
|
2220
|
+
return serialize(partition_key_indexes[0], row[partition_key_indexes[0]])
|
|
2221
|
+
|
|
2222
|
+
def serialize_row_multiple(row):
|
|
2223
|
+
pk_values = []
|
|
2224
|
+
for i in partition_key_indexes:
|
|
2225
|
+
val = serialize(i, row[i])
|
|
2226
|
+
length = len(val)
|
|
2227
|
+
pk_values.append(struct.pack(">H%dsB" % length, length, val, 0))
|
|
2228
|
+
|
|
2229
|
+
return b"".join(pk_values)
|
|
2230
|
+
|
|
2231
|
+
if len(partition_key_indexes) == 1:
|
|
2232
|
+
return serialize_row_single
|
|
2233
|
+
return serialize_row_multiple
|
|
2234
|
+
|
|
2235
|
+
|
|
2236
|
+
class TokenMap(object):
|
|
2237
|
+
"""
|
|
2238
|
+
A wrapper around the metadata token map to speed things up by caching ring token *values* and
|
|
2239
|
+
replicas. It is very important that we use the token values, which are primitive types, rather
|
|
2240
|
+
than the tokens classes when calling bisect_right() in split_batches(). If we use primitive values,
|
|
2241
|
+
the bisect is done in compiled code whilst with token classes each comparison requires a call
|
|
2242
|
+
into the interpreter to perform the cmp operation defined in Python. A simple test with 1 million bisect
|
|
2243
|
+
operations on an array of 2048 tokens was done in 0.37 seconds with primitives and 2.25 seconds with
|
|
2244
|
+
token classes. This is significant for large datasets because we need to do a bisect for each single row,
|
|
2245
|
+
and if VNODES are used, the size of the token map can get quite large too.
|
|
2246
|
+
"""
|
|
2247
|
+
def __init__(self, ks, hostname, local_dc, session):
|
|
2248
|
+
|
|
2249
|
+
self.ks = ks
|
|
2250
|
+
self.hostname = hostname
|
|
2251
|
+
self.local_dc = local_dc
|
|
2252
|
+
self.metadata = session.cluster.metadata
|
|
2253
|
+
|
|
2254
|
+
self._initialize_ring()
|
|
2255
|
+
|
|
2256
|
+
# Note that refresh metadata is disabled by default and we currently do not intercept it
|
|
2257
|
+
# If hosts are added, removed or moved during a COPY operation our token map is no longer optimal
|
|
2258
|
+
# However we can cope with hosts going down and up since we filter for replicas that are up when
|
|
2259
|
+
# making each batch
|
|
2260
|
+
|
|
2261
|
+
def _initialize_ring(self):
|
|
2262
|
+
token_map = self.metadata.token_map
|
|
2263
|
+
if token_map is None:
|
|
2264
|
+
self.ring = [0]
|
|
2265
|
+
self.replicas = [(self.metadata.get_host(self.hostname),)]
|
|
2266
|
+
self.pk_to_token_value = lambda pk: 0
|
|
2267
|
+
return
|
|
2268
|
+
|
|
2269
|
+
token_map.rebuild_keyspace(self.ks, build_if_absent=True)
|
|
2270
|
+
tokens_to_hosts = token_map.tokens_to_hosts_by_ks.get(self.ks, None)
|
|
2271
|
+
from_key = token_map.token_class.from_key
|
|
2272
|
+
|
|
2273
|
+
self.ring = [token.value for token in token_map.ring]
|
|
2274
|
+
self.replicas = [tuple(tokens_to_hosts[token]) for token in token_map.ring]
|
|
2275
|
+
self.pk_to_token_value = lambda pk: from_key(pk).value
|
|
2276
|
+
|
|
2277
|
+
@staticmethod
|
|
2278
|
+
def get_ring_pos(ring, val):
|
|
2279
|
+
idx = bisect_right(ring, val)
|
|
2280
|
+
return idx if idx < len(ring) else 0
|
|
2281
|
+
|
|
2282
|
+
def filter_replicas(self, hosts):
|
|
2283
|
+
shuffled = tuple(sorted(hosts, key=lambda k: random.random()))
|
|
2284
|
+
return [r for r in shuffled if r.is_up is not False and r.datacenter == self.local_dc] if hosts else ()
|
|
2285
|
+
|
|
2286
|
+
|
|
2287
|
+
class FastTokenAwarePolicy(DCAwareRoundRobinPolicy):
|
|
2288
|
+
"""
|
|
2289
|
+
Send to any replicas attached to the query, or else fall back to DCAwareRoundRobinPolicy. Perform
|
|
2290
|
+
exponential back-off if too many in flight requests to all replicas are already in progress.
|
|
2291
|
+
"""
|
|
2292
|
+
|
|
2293
|
+
def __init__(self, parent):
|
|
2294
|
+
DCAwareRoundRobinPolicy.__init__(self, parent.local_dc, 0)
|
|
2295
|
+
self.max_backoff_attempts = parent.max_backoff_attempts
|
|
2296
|
+
self.max_inflight_messages = parent.max_inflight_messages
|
|
2297
|
+
|
|
2298
|
+
def make_query_plan(self, working_keyspace=None, query=None):
|
|
2299
|
+
"""
|
|
2300
|
+
Extend TokenAwarePolicy.make_query_plan() so that we choose the same replicas in preference
|
|
2301
|
+
and most importantly we avoid repeating the (slow) bisect. We also implement a backoff policy
|
|
2302
|
+
by sleeping an exponentially larger delay in case all connections to eligible replicas have
|
|
2303
|
+
too many in flight requests.
|
|
2304
|
+
"""
|
|
2305
|
+
connections = ConnectionWrapper.connections
|
|
2306
|
+
replicas = list(query.replicas) if hasattr(query, 'replicas') else []
|
|
2307
|
+
replicas.extend([r for r in DCAwareRoundRobinPolicy.make_query_plan(self, working_keyspace, query)
|
|
2308
|
+
if r not in replicas])
|
|
2309
|
+
|
|
2310
|
+
if replicas:
|
|
2311
|
+
def replica_is_not_overloaded(r):
|
|
2312
|
+
if r.address in connections:
|
|
2313
|
+
conn = connections[r.address]
|
|
2314
|
+
return conn.in_flight < min(conn.max_request_id, self.max_inflight_messages)
|
|
2315
|
+
return True
|
|
2316
|
+
|
|
2317
|
+
for i in range(self.max_backoff_attempts):
|
|
2318
|
+
for r in filter(replica_is_not_overloaded, replicas):
|
|
2319
|
+
yield r
|
|
2320
|
+
|
|
2321
|
+
# the back-off starts at 10 ms (0.01) and it can go up to to 2^max_backoff_attempts,
|
|
2322
|
+
# which is currently 12, so 2^12 = 4096 = ~40 seconds when dividing by 0.01
|
|
2323
|
+
delay = randint(1, pow(2, i + 1)) * 0.01
|
|
2324
|
+
printdebugmsg("All replicas busy, sleeping for %d second(s)..." % (delay,))
|
|
2325
|
+
time.sleep(delay)
|
|
2326
|
+
|
|
2327
|
+
printdebugmsg("Replicas too busy, given up")
|
|
2328
|
+
|
|
2329
|
+
|
|
2330
|
+
class ConnectionWrapper(DefaultConnection):
|
|
2331
|
+
"""
|
|
2332
|
+
A wrapper to the driver default connection that helps in keeping track of messages in flight.
|
|
2333
|
+
The newly created connection is registered into a global dictionary so that FastTokenAwarePolicy
|
|
2334
|
+
is able to determine if a connection has too many in flight requests.
|
|
2335
|
+
"""
|
|
2336
|
+
connections = {}
|
|
2337
|
+
|
|
2338
|
+
def __init__(self, *args, **kwargs):
|
|
2339
|
+
DefaultConnection.__init__(self, *args, **kwargs)
|
|
2340
|
+
self.connections[self.host] = self
|
|
2341
|
+
|
|
2342
|
+
|
|
2343
|
+
class ImportProcess(ChildProcess):
|
|
2344
|
+
|
|
2345
|
+
def __init__(self, params):
|
|
2346
|
+
ChildProcess.__init__(self, params=params, target=self.run)
|
|
2347
|
+
|
|
2348
|
+
self.skip_columns = params['skip_columns']
|
|
2349
|
+
self.valid_columns = [c for c in params['valid_columns']]
|
|
2350
|
+
self.skip_column_indexes = [i for i, c in enumerate(self.columns) if c in self.skip_columns]
|
|
2351
|
+
|
|
2352
|
+
options = params['options']
|
|
2353
|
+
self.nullval = options.copy['nullval']
|
|
2354
|
+
self.max_attempts = options.copy['maxattempts']
|
|
2355
|
+
self.min_batch_size = options.copy['minbatchsize']
|
|
2356
|
+
self.max_batch_size = options.copy['maxbatchsize']
|
|
2357
|
+
self.use_prepared_statements = options.copy['preparedstatements']
|
|
2358
|
+
self.ttl = options.copy['ttl']
|
|
2359
|
+
self.max_inflight_messages = options.copy['maxinflightmessages']
|
|
2360
|
+
self.max_backoff_attempts = options.copy['maxbackoffattempts']
|
|
2361
|
+
self.request_timeout = options.copy['requesttimeout']
|
|
2362
|
+
|
|
2363
|
+
self.dialect_options = options.dialect
|
|
2364
|
+
self._session = None
|
|
2365
|
+
self.query = None
|
|
2366
|
+
self.conv = None
|
|
2367
|
+
self.make_statement = None
|
|
2368
|
+
|
|
2369
|
+
@property
|
|
2370
|
+
def session(self):
|
|
2371
|
+
if not self._session:
|
|
2372
|
+
cluster = Cluster(
|
|
2373
|
+
contact_points=(self.hostname,),
|
|
2374
|
+
port=self.port,
|
|
2375
|
+
cql_version=self.cql_version,
|
|
2376
|
+
protocol_version=self.protocol_version,
|
|
2377
|
+
auth_provider=self.auth_provider,
|
|
2378
|
+
load_balancing_policy=FastTokenAwarePolicy(self),
|
|
2379
|
+
ssl_context=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
|
|
2380
|
+
default_retry_policy=FallthroughRetryPolicy(), # we throw on timeouts and retry in the error callback
|
|
2381
|
+
compression=None,
|
|
2382
|
+
control_connection_timeout=self.connect_timeout,
|
|
2383
|
+
connect_timeout=self.connect_timeout,
|
|
2384
|
+
idle_heartbeat_interval=0,
|
|
2385
|
+
connection_class=ConnectionWrapper)
|
|
2386
|
+
|
|
2387
|
+
self._session = cluster.connect(self.ks)
|
|
2388
|
+
self._session.default_timeout = self.request_timeout
|
|
2389
|
+
return self._session
|
|
2390
|
+
|
|
2391
|
+
def run(self):
|
|
2392
|
+
if self.coverage:
|
|
2393
|
+
self.start_coverage()
|
|
2394
|
+
|
|
2395
|
+
try:
|
|
2396
|
+
pr = profile_on() if PROFILE_ON else None
|
|
2397
|
+
|
|
2398
|
+
self.on_fork()
|
|
2399
|
+
self.inner_run(*self.make_params())
|
|
2400
|
+
|
|
2401
|
+
if pr:
|
|
2402
|
+
profile_off(pr, file_name='worker_profile_%d.txt' % (os.getpid(),))
|
|
2403
|
+
|
|
2404
|
+
except Exception as exc:
|
|
2405
|
+
self.report_error(exc)
|
|
2406
|
+
|
|
2407
|
+
finally:
|
|
2408
|
+
if self.coverage:
|
|
2409
|
+
self.stop_coverage()
|
|
2410
|
+
self.close()
|
|
2411
|
+
|
|
2412
|
+
def close(self):
|
|
2413
|
+
if self._session:
|
|
2414
|
+
self._session.cluster.shutdown()
|
|
2415
|
+
ChildProcess.close(self)
|
|
2416
|
+
|
|
2417
|
+
def is_counter(self, table_meta):
|
|
2418
|
+
return "counter" in [table_meta.columns[name].cql_type for name in self.valid_columns]
|
|
2419
|
+
|
|
2420
|
+
def make_params(self):
|
|
2421
|
+
metadata = self.session.cluster.metadata
|
|
2422
|
+
table_meta = metadata.keyspaces[self.ks].tables[self.table]
|
|
2423
|
+
|
|
2424
|
+
prepared_statement = None
|
|
2425
|
+
if self.is_counter(table_meta):
|
|
2426
|
+
query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.table))
|
|
2427
|
+
make_statement = self.wrap_make_statement(self.make_counter_batch_statement)
|
|
2428
|
+
elif self.use_prepared_statements:
|
|
2429
|
+
query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
|
|
2430
|
+
protect_name(self.table),
|
|
2431
|
+
', '.join(protect_names(self.valid_columns),),
|
|
2432
|
+
', '.join(['?' for _ in self.valid_columns]))
|
|
2433
|
+
if self.ttl >= 0:
|
|
2434
|
+
query += 'USING TTL %s' % (self.ttl,)
|
|
2435
|
+
query = self.session.prepare(query)
|
|
2436
|
+
query.consistency_level = self.consistency_level
|
|
2437
|
+
prepared_statement = query
|
|
2438
|
+
make_statement = self.wrap_make_statement(self.make_prepared_batch_statement)
|
|
2439
|
+
else:
|
|
2440
|
+
query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (protect_name(self.ks),
|
|
2441
|
+
protect_name(self.table),
|
|
2442
|
+
', '.join(protect_names(self.valid_columns),))
|
|
2443
|
+
if self.ttl >= 0:
|
|
2444
|
+
query += 'USING TTL %s' % (self.ttl,)
|
|
2445
|
+
make_statement = self.wrap_make_statement(self.make_non_prepared_batch_statement)
|
|
2446
|
+
|
|
2447
|
+
conv = ImportConversion(self, table_meta, prepared_statement)
|
|
2448
|
+
tm = TokenMap(self.ks, self.hostname, self.local_dc, self.session)
|
|
2449
|
+
return query, conv, tm, make_statement
|
|
2450
|
+
|
|
2451
|
+
def inner_run(self, query, conv, tm, make_statement):
|
|
2452
|
+
"""
|
|
2453
|
+
Main run method. Note that we bind self methods that are called inside loops
|
|
2454
|
+
for performance reasons.
|
|
2455
|
+
"""
|
|
2456
|
+
self.query = query
|
|
2457
|
+
self.conv = conv
|
|
2458
|
+
self.make_statement = make_statement
|
|
2459
|
+
|
|
2460
|
+
convert_rows = self.convert_rows
|
|
2461
|
+
split_into_batches = self.split_into_batches
|
|
2462
|
+
result_callback = self.result_callback
|
|
2463
|
+
err_callback = self.err_callback
|
|
2464
|
+
session = self.session
|
|
2465
|
+
|
|
2466
|
+
while True:
|
|
2467
|
+
chunk = self.inmsg.recv()
|
|
2468
|
+
if chunk is None:
|
|
2469
|
+
break
|
|
2470
|
+
|
|
2471
|
+
try:
|
|
2472
|
+
chunk['rows'] = convert_rows(conv, chunk)
|
|
2473
|
+
for replicas, batch in split_into_batches(chunk, conv, tm):
|
|
2474
|
+
statement = make_statement(query, conv, chunk, batch, replicas)
|
|
2475
|
+
if statement:
|
|
2476
|
+
future = session.execute_async(statement)
|
|
2477
|
+
future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
|
|
2478
|
+
errback=err_callback, errback_args=(batch, chunk, replicas))
|
|
2479
|
+
# do not handle else case, if a statement could not be created, the exception is handled
|
|
2480
|
+
# in self.wrap_make_statement and the error is reported, if a failure is injected that
|
|
2481
|
+
# causes the statement to be None, then we should not report the error so that we can test
|
|
2482
|
+
# the parent process handling missing batches from child processes
|
|
2483
|
+
|
|
2484
|
+
except Exception as exc:
|
|
2485
|
+
self.report_error(exc, chunk, chunk['rows'])

    def wrap_make_statement(self, inner_make_statement):
        def make_statement(query, conv, chunk, batch, replicas):
            try:
                return inner_make_statement(query, conv, batch, replicas)
            except Exception as exc:
                print("Failed to make batch statement: {}".format(exc))
                self.report_error(exc, chunk, batch['rows'])
                return None

        def make_statement_with_failures(query, conv, chunk, batch, replicas):
            failed_batch, apply_failure = self.maybe_inject_failures(batch)
            if apply_failure:
                return failed_batch
            return make_statement(query, conv, chunk, batch, replicas)

        return make_statement_with_failures if self.test_failures else make_statement

    def make_counter_batch_statement(self, query, conv, batch, replicas):
        statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
        statement.replicas = replicas
        statement.keyspace = self.ks
        for row in batch['rows']:
            where_clause = []
            set_clause = []
            for i, value in enumerate(row):
                if i in conv.primary_key_indexes:
                    where_clause.append("{}={}".format(self.valid_columns[i], str(value)))
                else:
                    set_clause.append("{}={}+{}".format(self.valid_columns[i], self.valid_columns[i], str(value)))

            full_query_text = query % (','.join(set_clause), ' AND '.join(where_clause))
            statement.add(full_query_text)
        return statement
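
    # Rough sketch of the text produced above for one row, assuming a hypothetical counter table
    # with primary key column "pk" and counter column "c" (so valid_columns is ["pk", "c"] and
    # conv.primary_key_indexes contains only index 0): a converted row ['0', '1'] gives
    # set_clause == ['c=c+1'] and where_clause == ['pk=0'], and the two query placeholders expand
    # to something along the lines of:
    #
    #     UPDATE ks1.t1 SET c=c+1 WHERE pk=0
    #
    # Counter columns can only be incremented or decremented, hence the "c=c+<value>" form.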

    def make_prepared_batch_statement(self, query, _, batch, replicas):
        """
        Return a batch statement. This is an optimized version of:

            statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
            for row in batch['rows']:
                statement.add(query, row)

        We could optimize further by removing bound_statements altogether but we'd have to duplicate much
        more driver's code (BoundStatement.bind()).
        """
        statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
        statement.replicas = replicas
        statement.keyspace = self.ks
        statement._statements_and_parameters = [(True, query.query_id, query.bind(r).values) for r in batch['rows']]
        return statement

    def make_non_prepared_batch_statement(self, query, _, batch, replicas):
        statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
        statement.replicas = replicas
        statement.keyspace = self.ks
        field_sep = ','
        statement._statements_and_parameters = [(False, query % (field_sep.join(r),), ()) for r in batch['rows']]
        return statement
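
    # Both helpers above bypass BatchStatement.add() and fill the driver's internal
    # _statements_and_parameters list directly, one (is_prepared, statement, parameters) tuple per
    # row. As a rough illustration (hypothetical keyspace/table names again), a converted row
    # ['0', "'a'"] in the non-prepared path contributes the tuple:
    #
    #     (False, "INSERT INTO ks1.t1 (pk, v) VALUES (0,'a')", ())
    #
    # whereas the prepared path stores the prepared query id plus the bound values instead.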

    def convert_rows(self, conv, chunk):
        """
        Return converted rows and report any errors during conversion.
        """
        def filter_row_values(row):
            return [v for i, v in enumerate(row) if i not in self.skip_column_indexes]

        if self.skip_column_indexes:
            rows = [filter_row_values(r) for r in list(csv.reader(chunk['rows'], **self.dialect_options))]
        else:
            rows = list(csv.reader(chunk['rows'], **self.dialect_options))

        errors = defaultdict(list)

        def convert_row(r):
            try:
                return conv.convert_row(r)
            except Exception as err:
                errors[err.message if hasattr(err, 'message') else str(err)].append(r)
                return None

        converted_rows = [_f for _f in [convert_row(r) for r in rows] if _f]

        if errors:
            for msg, rows in errors.items():
                self.report_error(ParseError(msg), chunk, rows)
        return converted_rows

    def maybe_inject_failures(self, batch):
        """
        Examine self.test_failures and see if the batch is a batch
        supposed to cause a failure (failing_batch), or to terminate the worker process
        (exit_batch), or not to be sent (unsent_batch).

        @return any statement that will cause a failure or None if the statement should not be sent
        plus a boolean indicating if a failure should be applied at all
        """
        if 'failing_batch' in self.test_failures:
            failing_batch = self.test_failures['failing_batch']
            if failing_batch['id'] == batch['id']:
                if batch['attempts'] < failing_batch['failures']:
                    statement = SimpleStatement("INSERT INTO badtable (a, b) VALUES (1, 2)",
                                                consistency_level=self.consistency_level)
                    return statement, True  # use this statement, which will cause an error

        if 'exit_batch' in self.test_failures:
            exit_batch = self.test_failures['exit_batch']
            if exit_batch['id'] == batch['id']:
                sys.exit(1)

        if 'unsent_batch' in self.test_failures:
            unsent_batch = self.test_failures['unsent_batch']
            if unsent_batch['id'] == batch['id']:
                return None, True  # do not send this batch, which will cause missing acks in the parent process

        return None, False  # carry on as normal, do not apply any failures
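
    # The test hooks above expect self.test_failures to look roughly like
    # {'failing_batch': {'id': 3, 'failures': 2}}, {'exit_batch': {'id': 5}} or
    # {'unsent_batch': {'id': 7}} (the ids and counts here are made-up values); when test_failures
    # is not set, wrap_make_statement() never routes statements through this path.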

    @staticmethod
    def make_batch(batch_id, rows, attempts=1):
        return {'id': batch_id, 'rows': rows, 'attempts': attempts}

    def split_into_batches(self, chunk, conv, tm):
        """
        Batch rows by ring position or replica.
        If there are at least min_batch_size rows for a ring position then split these rows into
        groups of max_batch_size and send a batch for each group, using all replicas for this ring position.
        Otherwise, we are forced to batch by replica, and here unfortunately we can only choose one replica to
        guarantee common replicas across partition keys. We are typically able
        to batch by ring position for small clusters or when VNODES are not used. For large clusters with VNODES
        it may not be possible, in this case it helps to increase the CHUNK SIZE but up to a limit, otherwise
        we may choke the cluster.
        """

        rows_by_ring_pos = defaultdict(list)
        errors = defaultdict(list)

        min_batch_size = self.min_batch_size
        max_batch_size = self.max_batch_size
        ring = tm.ring

        get_row_partition_key_values = conv.get_row_partition_key_values_fcn()
        pk_to_token_value = tm.pk_to_token_value
        get_ring_pos = tm.get_ring_pos
        make_batch = self.make_batch

        for row in chunk['rows']:
            try:
                pk = get_row_partition_key_values(row)
                rows_by_ring_pos[get_ring_pos(ring, pk_to_token_value(pk))].append(row)
            except Exception as e:
                errors[e.message if hasattr(e, 'message') else str(e)].append(row)

        if errors:
            for msg, rows in errors.items():
                self.report_error(ParseError(msg), chunk, rows)

        replicas = tm.replicas
        filter_replicas = tm.filter_replicas
        rows_by_replica = defaultdict(list)
        for ring_pos, rows in rows_by_ring_pos.items():
            if len(rows) > min_batch_size:
                for i in range(0, len(rows), max_batch_size):
                    yield filter_replicas(replicas[ring_pos]), make_batch(chunk['id'], rows[i:i + max_batch_size])
            else:
                # select only the first valid replica to guarantee more overlap or none at all
                # TODO: revisit tuple wrapper
                rows_by_replica[tuple(filter_replicas(replicas[ring_pos])[:1])].extend(rows)

        # Now send the batches by replica
        for replicas, rows in rows_by_replica.items():
            for i in range(0, len(rows), max_batch_size):
                yield replicas, make_batch(chunk['id'], rows[i:i + max_batch_size])
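
    # Worked example of the batching rule above (the numbers are illustrative): with
    # min_batch_size=10 and max_batch_size=20, a ring position that collected 45 rows yields three
    # batches of 20, 20 and 5 rows, each targeted at all filtered replicas for that position; a
    # ring position with only 3 rows is instead keyed on its first valid replica and merged with
    # other small groups, and those merged rows are flushed in max_batch_size slices by the final
    # loop above.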

    def result_callback(self, _, batch, chunk):
        self.update_chunk(batch['rows'], chunk)

    def err_callback(self, response, batch, chunk, replicas):
        if isinstance(response, OperationTimedOut) and chunk['imported'] == chunk['num_rows_sent']:
            return  # occasionally the driver sends false timeouts for rows already processed (PYTHON-652)
        err_is_final = batch['attempts'] >= self.max_attempts
        self.report_error(response, chunk, batch['rows'], batch['attempts'], err_is_final)
        if not err_is_final:
            batch['attempts'] += 1
            statement = self.make_statement(self.query, self.conv, chunk, batch, replicas)
            future = self.session.execute_async(statement)
            future.add_callbacks(callback=self.result_callback, callback_args=(batch, chunk),
                                 errback=self.err_callback, errback_args=(batch, chunk, replicas))
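
    # A failed batch is re-submitted with the same rows until batch['attempts'] reaches
    # self.max_attempts; on the final failure report_error() is called with final=True, which
    # still routes the rows through update_chunk() so the chunk's accounting completes even when
    # rows are ultimately dropped.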

    # TODO: review why this is defined twice
    def report_error(self, err, chunk=None, rows=None, attempts=1, final=True):
        if self.debug and sys.exc_info()[1] == err:
            traceback.print_exc()
        err_msg = err.message if hasattr(err, 'message') else str(err)
        self.outmsg.send(ImportTaskError(err.__class__.__name__, err_msg, rows, attempts, final))
        if final and chunk is not None:
            self.update_chunk(rows, chunk)

    def update_chunk(self, rows, chunk):
        chunk['imported'] += len(rows)
        if chunk['imported'] == chunk['num_rows_sent']:
            self.outmsg.send(ImportProcessResult(chunk['num_rows_sent']))


class RateMeter(object):

    def __init__(self, log_fcn, update_interval=0.25, log_file=''):
        self.log_fcn = log_fcn  # the function for logging, may be None to disable logging
        self.update_interval = update_interval  # how often we update in seconds
        self.log_file = log_file  # an optional file where to log statistics in addition to stdout
        self.start_time = time.time()  # the start time
        self.last_checkpoint_time = self.start_time  # last time we logged
        self.current_rate = 0.0  # rows per second
        self.current_record = 0  # number of records since we last updated
        self.total_records = 0  # total number of records

        if os.path.isfile(self.log_file):
            os.unlink(self.log_file)

    def increment(self, n=1):
        self.current_record += n
        self.maybe_update()

    def maybe_update(self, sleep=False):
        if self.current_record == 0:
            return

        new_checkpoint_time = time.time()
        time_difference = new_checkpoint_time - self.last_checkpoint_time
        if time_difference >= self.update_interval:
            self.update(new_checkpoint_time)
            self.log_message()
        elif sleep:
            remaining_time = time_difference - self.update_interval
            if remaining_time > 0.000001:
                time.sleep(remaining_time)

    def update(self, new_checkpoint_time):
        time_difference = new_checkpoint_time - self.last_checkpoint_time
        if time_difference >= 1e-09:
            self.current_rate = self.get_new_rate(self.current_record / time_difference)

        self.last_checkpoint_time = new_checkpoint_time
        self.total_records += self.current_record
        self.current_record = 0

    def get_new_rate(self, new_rate):
        """
        return the rate of the last period: this is the new rate but
        averaged with the last rate to smooth a bit
        """
        if self.current_rate == 0.0:
            return new_rate
        else:
            return (self.current_rate + new_rate) / 2.0
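
    # For instance (purely illustrative numbers): if the previous smoothed rate was 100 rows/s and
    # the last interval measured 200 rows/s, get_new_rate returns (100 + 200) / 2.0 = 150.0, i.e. a
    # simple exponential smoothing that gives a weight of 0.5 to the newest sample.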

    def get_avg_rate(self):
        """
        return the average rate since we started measuring
        """
        time_difference = time.time() - self.start_time
        return self.total_records / time_difference if time_difference >= 1e-09 else 0

    def log_message(self):
        if not self.log_fcn:
            return

        output = 'Processed: %d rows; Rate: %7.0f rows/s; Avg. rate: %7.0f rows/s\r' % \
                 (self.total_records, self.current_rate, self.get_avg_rate())
        self.log_fcn(output, eol='\r')
        if self.log_file:
            with open(self.log_file, "a") as f:
                f.write(output + '\n')

    def get_total_records(self):
        self.update(time.time())
        self.log_message()
        return self.total_records
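
A quick usage sketch for RateMeter as defined above (the printer function is only a stand-in for
the logging callback that cqlsh normally supplies; note that the callback must accept an eol
keyword, since log_message passes eol='\r'):

    def printer(msg, eol='\n'):
        print(msg, end=eol)

    meter = RateMeter(log_fcn=printer, update_interval=0.25)
    for _ in range(1000):
        meter.increment()              # one call per imported row
    total = meter.get_total_records()  # flushes the counters and logs a final line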