dbworkload 0.6.5__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,7 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: dbworkload
3
- Version: 0.6.5
3
+ Version: 0.7.0
4
4
  Summary: Workload framework
5
- Home-page: https://dbworkload.github.io/dbworkload/
6
5
  License: GPLv3+
7
6
  Author: Fabio Ghirardello
8
7
  Requires-Python: >=3.8,<4.0
@@ -45,6 +44,7 @@ Requires-Dist: pyyaml
45
44
  Requires-Dist: sqlparse
46
45
  Requires-Dist: tabulate
47
46
  Requires-Dist: typer[all]
47
+ Project-URL: Homepage, https://dbworkload.github.io/dbworkload/
48
48
  Project-URL: Repository, https://github.com/dbworkload/dbworkload
49
49
  Description-Content-Type: text/markdown
50
50
 
@@ -1,5 +1,5 @@
1
1
  import logging
2
-
2
+ import time
3
3
  from importlib import metadata
4
4
 
5
5
  try:
@@ -10,10 +10,15 @@ except:
10
10
  del metadata # optional, avoids polluting the results of dir(__package__)
11
11
 
12
12
  logger = logging.getLogger("dbworkload")
13
- # logger.setLevel(logging.INFO)
13
+
14
14
  sh = logging.StreamHandler()
15
15
  formatter = logging.Formatter(
16
- "%(asctime)s [%(levelname)s] (%(processName)s %(threadName)s) %(module)s:%(lineno)d: %(message)s"
16
+ "%(asctime)s [%(levelname)s] (%(processName)s %(threadName)s) %(module)s:%(lineno)d: %(message)s",
17
17
  )
18
+
19
+ # set the formatter to use UTC and show microseconds
20
+ formatter.converter = time.gmtime
21
+ formatter.default_msec_format = "%s.%06d"
22
+
18
23
  sh.setFormatter(formatter)
19
24
  logger.addHandler(sh)
@@ -3,7 +3,7 @@
3
3
  from .. import __version__
4
4
  import typer
5
5
 
6
- EPILOG = "GitHub: <https://github.com/fabiog1901/dbworkload>"
6
+ EPILOG = "Docs: <https://dbworkload.github.io/dbworkload/>"
7
7
 
8
8
 
9
9
  class ConnInfo:
@@ -17,7 +17,7 @@ import platform
17
17
  import sys
18
18
  import typer
19
19
  import yaml
20
-
20
+ import pandas as pd
21
21
 
22
22
  logger = logging.getLogger("dbworkload")
23
23
 
@@ -97,6 +97,12 @@ def run(
97
97
  help="Duration in seconds. Defaults to <ad infinitum>.",
98
98
  show_default=False,
99
99
  ),
100
+ max_rate: int = typer.Option(
101
+ None,
102
+ "--max-rate",
103
+ show_default=False,
104
+ help="Set the max-rate to have dbworkload manage concurrency. Defaults to None.",
105
+ ),
100
106
  conn_duration: int = typer.Option(
101
107
  None,
102
108
  "-k",
@@ -134,6 +140,11 @@ def run(
134
140
  show_default=False,
135
141
  help="Save stats to CSV files.",
136
142
  ),
143
+ schedule: str = typer.Option(
144
+ None,
145
+ "--schedule",
146
+ help="schedule JSON string or filepath to the schedule file.",
147
+ ),
137
148
  log_level: LogLevel = Param.LogLevel,
138
149
  ):
139
150
  logger.setLevel(log_level.upper())
@@ -220,6 +231,8 @@ def run(
220
231
 
221
232
  args = load_args(args)
222
233
 
234
+ schedule = load_schedule(schedule)
235
+
223
236
  dbworkload.models.run.run(
224
237
  concurrency,
225
238
  workload_path,
@@ -230,10 +243,12 @@ def run(
230
243
  conn_info,
231
244
  duration,
232
245
  conn_duration,
246
+ max_rate,
233
247
  args,
234
248
  driver,
235
249
  quiet,
236
250
  save,
251
+ schedule,
237
252
  log_level.upper(),
238
253
  )
239
254
 
@@ -278,6 +293,21 @@ def load_args(args: str):
278
293
  return {}
279
294
 
280
295
 
296
+ def load_schedule(schedule_path: str):
297
+ if schedule_path:
298
+ if os.path.exists(schedule_path):
299
+ df = pd.read_csv(schedule_path, dtype="Int64", comment="#").fillna(0)
300
+ # trasform ramp and duration columns from minutes to seconds
301
+ df[["ramp", "duration"]] = df[["ramp", "duration"]] * 60
302
+
303
+ return df.values.tolist()
304
+ else:
305
+ try:
306
+ return json.loads(schedule_path)
307
+ except:
308
+ logger.error(f"couldn't decode {schedule_path} as JSON")
309
+
310
+
281
311
  def _version_callback(value: bool) -> None:
282
312
  if value:
283
313
  typer.echo(f"dbworkload : {__version__}")
@@ -13,7 +13,7 @@ import signal
13
13
  import sys
14
14
  import sys
15
15
  import tabulate
16
- import threading
16
+ from threading import Thread
17
17
  import time
18
18
  import traceback
19
19
 
@@ -92,35 +92,72 @@ def signal_handler(sig, frame):
92
92
  """
93
93
  logger.info("KeyboardInterrupt signal detected. Stopping processes...")
94
94
 
95
- # send the poison pill to each worker.
95
+ # send the poison pill to each proc.
96
96
  # if dbworkload cannot graceful shutdown due
97
97
  # to processes being still in the init phase
98
98
  # when the pill is sent, a subsequent Ctrl+C will cause
99
99
  # the pill to overflow the kill_q
100
100
  # and raise the queue.Full exception, forcing to quit.
101
- for _ in range(concurrency):
101
+ for q in queues.values():
102
102
  try:
103
- kill_q.put(None, timeout=0.1)
103
+ q.put("proc_end", timeout=0.1)
104
104
  except queue.Full:
105
105
  logger.error("Timed out")
106
106
  sys.exit(1)
107
107
 
108
- logger.debug("Sent poison pill to all threads")
108
+ logger.debug("Sent poison pill to all procs")
109
109
 
110
110
 
111
- def ramp_up(
112
- processes: list, interval: float, threads_per_proc: list, init_sleep: int = 0
111
+ def cycle(iterable, backwards=False):
112
+ global current_proc
113
+
114
+ if not backwards:
115
+ current_proc += 1
116
+ return current_proc % iterable
117
+ else:
118
+ v = current_proc % iterable
119
+ current_proc -= 1
120
+ return v
121
+
122
+
123
+ # Launch or kill worker threads based on cc_change value.
124
+ # workers are added or removed evenly across all supervisors.
125
+ # If a ramp time is specified, threads creation or destruction
126
+ # will be paced accordingly.
127
+ def launch_or_kill_workers(
128
+ queues: list,
129
+ ramp_time: int,
130
+ cc_change: int,
131
+ proc_len: list,
132
+ iterations_per_thread,
133
+ concurrency,
113
134
  ):
114
- """Start each process in the list sequentially respecting the interval between each process"""
115
- time.sleep(init_sleep)
116
- for i, p in enumerate(processes):
117
- logger.debug("Starting a new Process...")
118
- p.start()
119
- time.sleep(interval * threads_per_proc[i])
135
+ if cc_change == 0:
136
+ return
137
+
138
+ ramp_interval = ramp_time / abs(cc_change)
139
+ global thread_id
140
+
141
+ if cc_change > 0:
142
+ for _ in range(cc_change):
143
+ queues[cycle(proc_len)].put(
144
+ (
145
+ thread_id,
146
+ iterations_per_thread,
147
+ concurrency,
148
+ )
149
+ )
150
+ thread_id += 1
151
+ time.sleep(ramp_interval)
152
+
153
+ if cc_change < 0:
154
+ for _ in range(abs(cc_change)):
155
+ queues[cycle(proc_len, backwards=True)].put("kill_one")
156
+ time.sleep(ramp_interval)
120
157
 
121
158
 
122
159
  def run(
123
- conc: int,
160
+ concurrency: int,
124
161
  workload_path: str,
125
162
  prom_port: int,
126
163
  iterations: int,
@@ -129,13 +166,15 @@ def run(
129
166
  conn_info: dict,
130
167
  duration: int,
131
168
  conn_duration: int,
169
+ max_rate: int,
132
170
  args: dict,
133
171
  driver: str,
134
172
  quiet: bool,
135
173
  save: bool,
174
+ schedule: list,
136
175
  log_level: str,
137
176
  ):
138
- def gracefully_shutdown():
177
+ def gracefully_shutdown(by_keyinterrupt: bool = False):
139
178
  """
140
179
  wait for final stat reports to come in,
141
180
  then print final stats and quit
@@ -144,10 +183,21 @@ def run(
144
183
  end_time = int(time.time())
145
184
  _s = stats_received
146
185
 
186
+ if not by_keyinterrupt:
187
+ for q in queues.values():
188
+ try:
189
+ q.put("proc_end", timeout=0.1)
190
+ except queue.Full:
191
+ logger.error("Timed out")
192
+ sys.exit(1)
193
+
194
+ for x in supervisors.values():
195
+ if x.is_alive():
196
+ x.join()
197
+
147
198
  while True:
148
199
  try:
149
- msg = q.get(block=True, timeout=2.0)
150
-
200
+ msg = to_main_q.get(block=True, timeout=2.0)
151
201
  if isinstance(msg, list):
152
202
  _s += 1
153
203
  stats.add_tds(msg)
@@ -249,16 +299,7 @@ def run(
249
299
 
250
300
  logger.setLevel(log_level)
251
301
 
252
- global concurrency
253
- concurrency = conc
254
-
255
- global kill_q
256
- global q
257
302
  start_time = int(time.time())
258
-
259
- # the offset registers at what second we want all threads
260
- # to send the stat report, so they all send it at the same time
261
- offset = start_time % FREQUENCY
262
303
  workload = dbworkload.utils.common.import_class_at_runtime(workload_path)
263
304
 
264
305
  run_name = (
@@ -269,6 +310,10 @@ def run(
269
310
 
270
311
  logger.info(f"Starting workload {run_name}")
271
312
 
313
+ # the offset registers at what second we want all threads
314
+ # to send the stat report, so they all send it at the same time
315
+ offset = start_time % FREQUENCY
316
+
272
317
  # open a new csv file and just write the header columns
273
318
  if save:
274
319
  with open(run_name + ".csv", "w") as f:
@@ -281,6 +326,51 @@ def run(
281
326
 
282
327
  prom = dbworkload.utils.common.Prom(prom_port)
283
328
 
329
+ to_main_q = mp.Queue()
330
+
331
+ global queues
332
+ global supervisors
333
+ supervisors = {}
334
+ queues = {}
335
+
336
+ # launch supervisors in a dedicated OS process
337
+ for x in range(procs):
338
+ queues[x] = mp.Queue()
339
+ supervisors[x] = mp.Process(
340
+ target=supervisor,
341
+ args=(
342
+ to_main_q,
343
+ queues[x],
344
+ log_level,
345
+ conn_info,
346
+ driver,
347
+ workload,
348
+ args,
349
+ conn_duration,
350
+ offset,
351
+ x,
352
+ ),
353
+ daemon=True,
354
+ )
355
+ supervisors[x].start()
356
+
357
+ # report time happens 2 seconds after the stats are received.
358
+ # we add this buffer to make sure we get all the stats reports
359
+ # from each thread before we aggregate and display
360
+ report_time = start_time + FREQUENCY + 2
361
+
362
+ returned_procs = 0
363
+ active_connections = 0
364
+ stats_received = 0
365
+
366
+ global current_proc
367
+ global thread_id
368
+
369
+ current_proc = -1
370
+ current_cc = 0
371
+ thread_id = 0
372
+ pause_for_ramp_time = 0
373
+
284
374
  iterations_per_thread = None
285
375
  if iterations:
286
376
  # ensure we don't create more threads than the total number of iterations requested.
@@ -293,232 +383,299 @@ def run(
293
383
  f"You have requested {iterations} iterations on {concurrency} threads. {iterations} modulo {concurrency} = {iterations%concurrency} iterations will not be executed."
294
384
  )
295
385
 
296
- duration_endtime = None
297
- if duration:
298
- duration_endtime = time.time() + duration
386
+ # if no schedule was passed, create a schedule with just 1 line
387
+ if schedule is None:
388
+ schedule = [(concurrency, max_rate, ramp, duration)]
299
389
 
300
- q = mp.Queue(maxsize=0)
301
- kill_q = mp.Queue(maxsize=concurrency)
390
+ # loop through all lines in the schedule
391
+ for i, s in enumerate(schedule):
392
+ cc, max_rate, ramp_time, dur = s
302
393
 
303
- # calculate the ramp up schedule, if any
304
- threads_per_proc = dbworkload.utils.common.get_threads_per_proc(procs, concurrency)
305
- ramp_interval = ramp / concurrency
394
+ # sanitize
395
+ if dur and ramp_time > dur:
396
+ ramp_time = dur
306
397
 
307
- # each Process must generate an ID for each of its threads,
308
- # starting from the id_base_counter and incrementing by 1.
309
- # for each Process' MainThread, the id_base_counter is also its id.
310
- id_base_counter = 0
398
+ logger.info(
399
+ f"Starting schedule {i+1}/{len(schedule)}: cc={cc}, max_rate={max_rate}, ramp={ramp_time}, dur={dur}"
400
+ )
311
401
 
312
- processes: list[mp.Process] = []
313
- for x in threads_per_proc:
314
- processes.append(
315
- mp.Process(
316
- target=worker,
402
+ # always make sure that a duration is specified, even if none was passed
403
+ # in which case it defaults to infinite
404
+ end_schedule_time = time.time() + dur if dur else float("inf")
405
+
406
+ # if max_rate was set instead of concurrency
407
+ # and current_cc = 0,
408
+ # start the workload with 1 thread so that dbworkload
409
+ # has stats to measure on for adding/removing threads
410
+ # as part of the calculations for maintaining
411
+ # the desired max_rate
412
+ if current_cc == 0 and max_rate:
413
+ Thread(
414
+ target=launch_or_kill_workers,
415
+ daemon=True,
317
416
  args=(
318
- x - 1,
319
- ramp_interval,
320
- q,
321
- kill_q,
322
- log_level,
323
- conn_info,
324
- workload,
325
- args,
417
+ queues,
418
+ ramp_time,
419
+ 1,
420
+ procs,
326
421
  iterations_per_thread,
327
- duration_endtime,
328
- conn_duration,
329
422
  concurrency,
330
- offset,
331
- id_base_counter,
332
- id_base_counter,
333
- driver,
334
423
  ),
335
- daemon=True,
336
- )
337
- )
338
- id_base_counter += x
424
+ ).start()
339
425
 
340
- # starting the actual processes is done by the ramp_up method,
341
- # executed asynchronously, in its own thread
342
- threading.Thread(
343
- target=ramp_up, daemon=True, args=(processes, ramp_interval, threads_per_proc)
344
- ).start()
426
+ current_cc = 1
345
427
 
346
- # report time happens 2 seconds after the stats are received.
347
- # we add this buffer to make sure we get all the stats reports
348
- # from each thread before we aggregate and display
349
- report_time = start_time + FREQUENCY + 2
428
+ if not max_rate:
429
+ Thread(
430
+ target=launch_or_kill_workers,
431
+ daemon=True,
432
+ args=(
433
+ queues,
434
+ ramp_time,
435
+ cc - current_cc,
436
+ procs,
437
+ iterations_per_thread,
438
+ concurrency,
439
+ ),
440
+ ).start()
350
441
 
351
- returned_threads = 0
352
- active_connections = 0
353
- stats_received = 0
442
+ current_cc = cc
354
443
 
355
- while True:
356
- try:
357
- # read from the queue for stats or completion messages
358
- msg = q.get(block=False)
359
- # a stats report is a list obj
360
- if isinstance(msg, list):
361
- stats_received += 1
362
- stats.add_tds(msg)
363
- elif msg == "init":
364
- active_connections += 1
365
- else:
366
- # the worker returned
367
- # the mmsg is either a 'task_done' or 'poison_pill',
368
- # depending on the reason why the thread returned
369
- returned_threads += 1
370
- except queue.Empty:
371
- pass
444
+ returned_threads = 0
372
445
 
373
- # once the sum of the completion messages matches
374
- # the count of threads, identify what type of
375
- # completion message it was
376
- if returned_threads > 0 and returned_threads >= active_connections:
377
- if msg == "task_done":
378
- logger.info("Requested iteration/duration limit reached")
379
- gracefully_shutdown()
380
- elif msg == "poison_pill":
381
- gracefully_shutdown()
382
- elif isinstance(msg, Exception):
383
- logger.error(f"error_type={msg.__class__.__name__}, msg={msg}")
384
- sys.exit(1)
385
- else:
386
- logger.error(f"unrecognized message: {msg}")
387
- sys.exit(1)
446
+ # loop for the entire duration of the schedule's current line
447
+ while time.time() < end_schedule_time:
448
+ try:
449
+ # read from the queue for stats or completion messages
450
+ msg = to_main_q.get(block=False)
451
+ # a stats report is a list obj
452
+ if isinstance(msg, list):
453
+ stats_received += 1
454
+ stats.add_tds(msg)
455
+ elif msg == "init":
456
+ active_connections += 1
457
+ elif msg == "got_killed":
458
+ active_connections -= 1
459
+ elif msg == "proc_returned":
460
+ returned_procs += 1
461
+ elif msg == "task_done":
462
+ returned_threads += 1
463
+ except queue.Empty:
464
+ pass
465
+
466
+ # check if all procs returned, then exit
467
+ if returned_procs >= procs or (
468
+ returned_threads > 0 and returned_threads >= active_connections
469
+ ):
470
+ if msg == "task_done":
471
+ logger.info("Requested iteration/duration limit reached")
472
+ gracefully_shutdown()
473
+ elif msg == "proc_returned":
474
+ logger.debug("All procs returned")
475
+ gracefully_shutdown(by_keyinterrupt=True)
476
+ elif isinstance(msg, Exception):
477
+ logger.error(f"error_type={msg.__class__.__name__}, msg={msg}")
478
+ sys.exit(1)
479
+ else:
480
+ logger.error(f"unrecognized message: {msg}")
481
+ sys.exit(1)
388
482
 
389
- if time.time() >= report_time:
390
- if stats_received != active_connections:
391
- logger.warning("didn't receive all stats reports yet")
483
+ if time.time() >= report_time:
484
+ # if stats_received != active_connections:
485
+ # logger.warning("didn't receive all stats reports yet")
486
+
487
+ # remove the 2 seconds added
488
+ endtime = int(time.time()) - 2
489
+
490
+ report = stats.calculate_stats(active_connections, endtime)
491
+
492
+ # if max_rate is specified, try to stick to it.
493
+ # to calculate how to get to the max rate, we need a non-empty report
494
+ if max_rate and report:
495
+ current_rate = report[0][6] # __cycle__ period_ops/s
496
+
497
+ # approximate how many threads are needed to get
498
+ # to the desired max_rate given the current QPS rate
499
+ # and current threads count
500
+ extrapolated_cc = int(max_rate / (current_rate / current_cc))
501
+
502
+ # adjust the thread count if there is a difference
503
+ # between the current thread count and the calculated
504
+ # thread count, but not if there is one such operation already
505
+ # running, that is, not if there's an operation that is slow due
506
+ # to a long ramp_time.
507
+ if (
508
+ extrapolated_cc - current_cc
509
+ and time.time() >= pause_for_ramp_time
510
+ ):
511
+ Thread(
512
+ target=launch_or_kill_workers,
513
+ daemon=True,
514
+ args=(
515
+ queues,
516
+ ramp_time,
517
+ extrapolated_cc - current_cc,
518
+ procs,
519
+ iterations_per_thread,
520
+ concurrency,
521
+ ),
522
+ ).start()
523
+
524
+ # make sure we will not add/remove threads while the newly
525
+ # created thread is still working
526
+ pause_for_ramp_time = time.time() + ramp_time + 2 * FREQUENCY
527
+
528
+ logger.warning(
529
+ f"Calculating max_rate: desired max_rate: {max_rate}, "
530
+ f"current_rate: {report[0][6]}, current_cc = {current_cc}, "
531
+ f"extrapolated_cc = {extrapolated_cc}, "
532
+ f"difference: {extrapolated_cc-current_cc}"
533
+ )
534
+ current_cc = extrapolated_cc
392
535
 
393
- # remove the 2 seconds added
394
- endtime = int(time.time()) - 2
536
+ # ramp_time is only considered for reaching the desired max_rate.
537
+ # For adjustments over time, we want the changes to happen immediately
538
+ # and not smoothed out over the initial ramp_time value
539
+ ramp_time = 0
395
540
 
396
- report = stats.calculate_stats(active_connections, endtime)
541
+ centroids = stats.get_centroids()
397
542
 
398
- centroids = stats.get_centroids()
543
+ stats.new_window(endtime)
544
+ stats_received = 0
399
545
 
400
- stats.new_window(endtime)
401
- stats_received = 0
546
+ if save:
547
+ with open(run_name + ".csv", "a") as f:
548
+ for row in report:
549
+ f.write(str(stats.endtime) + ",")
550
+ for col in row:
551
+ f.write(str(col) + ",")
552
+ np.savetxt(f, next(centroids), newline=";")
553
+ f.write("\n")
402
554
 
403
- if save:
404
- with open(run_name + ".csv", "a") as f:
405
- for row in report:
406
- f.write(str(stats.endtime) + ",")
407
- for col in row:
408
- f.write(str(col) + ",")
409
- np.savetxt(f, next(centroids), newline=";")
410
- f.write("\n")
555
+ if not quiet:
556
+ print_stats(report)
411
557
 
412
- if not quiet:
413
- print_stats(report)
558
+ prom.publish(report)
414
559
 
415
- prom.publish(report)
560
+ report_time += FREQUENCY
416
561
 
417
- report_time += FREQUENCY
562
+ # pause briefly to prevent the loop from overheating the CPU
563
+ time.sleep(0.1)
418
564
 
419
- # pause briefly to prevent the loop from overheating the CPU
420
- time.sleep(0.1)
565
+ gracefully_shutdown()
421
566
 
422
- def worker(
423
- thread_count: int,
424
- interval: int,
425
- q: mp.Queue,
426
- kill_q: mp.Queue,
567
+
568
+ # a supervisor runs in a separate process.
569
+ # The idea is to create as many supervisors as vCPUs.
570
+ # The sole role of the supervisor is to listen for instructions
571
+ # from the MainProcess.
572
+ # Instructions are:
573
+ # - Create a new worker.
574
+ # - Destroy a worker.
575
+ # - Destroy all workers and return.
576
+ def supervisor(
577
+ to_main_q: mp.Queue,
578
+ from_main_q: mp.Queue,
427
579
  log_level: str,
428
580
  conn_info: ConnInfo,
581
+ driver: str,
429
582
  workload: object,
430
583
  args: dict,
431
- iterations: int,
432
- duration_endtime: float,
433
584
  conn_duration: int,
434
- conc: int,
435
585
  offset: int,
436
- id_base_counter: int = 0,
437
- id: int = 0,
438
- driver: str = None,
586
+ id: int,
439
587
  ):
440
- """Process worker function to run the workload in a multiprocessing env
441
-
442
- Args:
443
- thread_count (int): The number of threads to create
444
- q (mp.Queue): queue to report query metrics
445
- kill_q (mp.Queue): queue to handle stopping the worker
446
- log_level (str): log level to set the logger to
447
- conn_info (ConnInfo): connection data
448
- workload (object): workload class object
449
- args (dict): args to init the workload class
450
- iterations (int): count of workload iteration before returning
451
- duration_endtime (float): timestamp at which to stop and return
452
- conn_duration (int): seconds before restarting the database connection
453
- conc: (int): the total number of threads
454
- id_base_counter (int): the base counter to generate ID for each Process
455
- id (int): the ID of the thread
456
- driver (str): the friendly driver name
457
- """
458
-
459
588
  def gracefully_return(msg):
460
- # send notification to MainThread
461
- q.put(msg)
462
- # send final stats
463
- q.put(ws.get_tdigest_ndarray(), block=False)
464
-
465
- # wait for all Processes children threads to return before
589
+ # wait for Threads to return before
466
590
  # letting the Process MainThread return
591
+ # threading.enumerate()
592
+ for x in threads:
593
+ if x.is_alive():
594
+ from_proc_q.put("poison_pill")
595
+
467
596
  for x in threads:
468
597
  if x.is_alive():
469
598
  x.join()
470
599
 
600
+ # send notification to MainThread
601
+ to_main_q.put(msg)
602
+
603
+ logger.debug(f"PROC-{id} terminated")
604
+ return
605
+
471
606
  logger.setLevel(log_level)
607
+ logger.debug(f"PROC-{id} started")
472
608
 
473
- logger.debug(f"My ID is {id}")
474
-
475
- threads: list[threading.Thread] = []
476
-
477
- # execute only if the current thread is the main thread for each process
478
- if thread_count is not None:
479
- # capture KeyboardInterrupt and do nothing
480
- signal.signal(signal.SIGINT, signal.SIG_IGN)
481
-
482
- # only the MainThread of a child Process spawns Threads
483
- for i in range(thread_count):
484
- threads.append(
485
- threading.Thread(
486
- target=worker,
487
- daemon=True,
488
- args=(
489
- None,
490
- 0,
491
- q,
492
- kill_q,
493
- log_level,
494
- conn_info,
495
- workload,
496
- args,
497
- iterations,
498
- duration_endtime,
499
- conn_duration,
500
- conc,
501
- offset,
502
- None,
503
- id_base_counter + i + 1,
504
- driver,
505
- ),
506
- )
609
+ threads: list[Thread] = []
610
+ from_proc_q = mp.Queue()
611
+
612
+ # capture KeyboardInterrupt and do nothing
613
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
614
+
615
+ while True:
616
+ msg = from_main_q.get(block=True)
617
+
618
+ if msg == "proc_end":
619
+ logger.debug(f"PROC-{id} terminating...")
620
+ gracefully_return("proc_returned")
621
+ return
622
+ elif msg == "kill_one":
623
+ from_proc_q.put("poison_pill")
624
+ elif isinstance(msg, tuple):
625
+ t = Thread(
626
+ target=worker,
627
+ daemon=True,
628
+ args=(
629
+ to_main_q,
630
+ from_proc_q,
631
+ log_level,
632
+ conn_info,
633
+ driver,
634
+ workload,
635
+ args,
636
+ conn_duration,
637
+ offset,
638
+ *msg,
639
+ ),
507
640
  )
641
+ t.start()
642
+ threads.append(t)
508
643
 
509
- # starting each Thread is done by the ramp_up in its own thread
510
- threading.Thread(
511
- target=ramp_up,
512
- daemon=True,
513
- args=(threads, interval, [1] * thread_count, interval),
514
- ).start()
644
+
645
+ def worker(
646
+ to_main_q: mp.Queue,
647
+ from_proc_q: mp.Queue,
648
+ log_level: str,
649
+ conn_info: ConnInfo,
650
+ driver: str,
651
+ workload: object,
652
+ args: dict,
653
+ conn_duration: int,
654
+ offset: int,
655
+ id: int = 0,
656
+ iterations: int = 0,
657
+ concurrency: int = 0,
658
+ ):
659
+ def gracefully_return(msg):
660
+ # send notification to MainThread
661
+ to_main_q.put(msg)
662
+ # send final stats
663
+ to_main_q.put(ws.get_tdigest_ndarray(), block=False)
664
+
665
+ logger.debug(f"Thread ID {id} terminated")
666
+
667
+ return
668
+
669
+ logger.setLevel(log_level)
670
+
671
+ logger.debug(f"Thread ID {id} started")
515
672
 
516
673
  # catch exception while instantiating the workload class
517
674
  try:
518
675
  w = workload(args)
519
676
  except Exception as e:
520
677
  stack_lines = traceback.format_exc()
521
- q.put(Exception(stack_lines))
678
+ to_main_q.put(Exception(stack_lines))
522
679
  return
523
680
 
524
681
  c = 0
@@ -530,22 +687,13 @@ def worker(
530
687
  run_init = True
531
688
 
532
689
  # send notification that a new thread has started
533
- q.put("init")
690
+ to_main_q.put("init")
534
691
 
535
692
  while True:
536
693
  if conn_duration:
537
694
  # reconnect every conn_duration +/- 20%
538
695
  conn_endtime = time.time() + int(conn_duration * random.uniform(0.8, 1.2))
539
696
 
540
- # listen for termination messages (poison pill)
541
- try:
542
- kill_q.get(block=False)
543
- logger.debug("Poison pill received")
544
- gracefully_return("poison_pill")
545
- return
546
- except queue.Empty:
547
- pass
548
-
549
697
  try:
550
698
  logger.debug(f"driver: {driver}, params: {conn_info.params}")
551
699
  # with Cluster().connect('bank') as conn:
@@ -560,7 +708,11 @@ def worker(
560
708
  logger.debug("Executing setup() function")
561
709
  run_transaction(
562
710
  conn,
563
- lambda conn: w.setup(conn, id, conc),
711
+ lambda conn: w.setup(
712
+ conn,
713
+ id,
714
+ concurrency,
715
+ ),
564
716
  driver,
565
717
  max_retries=MAX_RETRIES,
566
718
  )
@@ -572,16 +724,14 @@ def worker(
572
724
  while True:
573
725
  # listen for termination messages (poison pill)
574
726
  try:
575
- kill_q.get(block=False)
727
+ from_proc_q.get(block=False)
576
728
  logger.debug("Poison pill received")
577
- return gracefully_return("poison_pill")
729
+ return gracefully_return("got_killed")
578
730
  except queue.Empty:
579
731
  pass
580
732
 
581
- # return if the limits of either iteration count and duration have been reached
582
- if (iterations and c >= iterations) or (
583
- duration_endtime and time.time() >= duration_endtime
584
- ):
733
+ # return if the iteration count has been reached
734
+ if iterations and c >= iterations:
585
735
  logger.debug("Task completed!")
586
736
  gracefully_return("task_done")
587
737
  return
@@ -618,10 +768,10 @@ def worker(
618
768
 
619
769
  ws.add_latency_measurement("__cycle__", time.time() - cycle_start)
620
770
 
621
- if q.full():
771
+ if to_main_q.full():
622
772
  logger.error("=========== Q FULL!!!! ======================")
623
773
  if time.time() >= stat_time:
624
- q.put(ws.get_tdigest_ndarray(), block=False)
774
+ to_main_q.put(ws.get_tdigest_ndarray(), block=False)
625
775
  ws.new_window()
626
776
  stat_time += FREQUENCY
627
777
 
@@ -630,7 +780,7 @@ def worker(
630
780
  import psycopg
631
781
 
632
782
  if isinstance(e, psycopg.errors.UndefinedTable):
633
- q.put(e)
783
+ to_main_q.put(e)
634
784
  return
635
785
  log_and_sleep(e)
636
786
 
@@ -638,26 +788,26 @@ def worker(
638
788
  import mysql.connector.errorcode
639
789
 
640
790
  if e.errno == mysql.connector.errorcode.ER_NO_SUCH_TABLE:
641
- q.put(e)
791
+ to_main_q.put(e)
642
792
  return
643
793
  log_and_sleep(e)
644
794
 
645
795
  elif driver == "maria":
646
796
  if str(e).endswith(" doesn't exist"):
647
- q.put(e)
797
+ to_main_q.put(e)
648
798
  return
649
799
  log_and_sleep(e)
650
800
 
651
801
  elif driver == "oracle":
652
802
  if str(e).startswith("ORA-00942: table or view does not exist"):
653
- q.put(e)
803
+ to_main_q.put(e)
654
804
  return
655
805
  log_and_sleep(e)
656
806
 
657
807
  else:
658
808
  # for all other Exceptions, report and return
659
809
  logger.error(type(e), stack_info=True)
660
- q.put(e)
810
+ to_main_q.put(e)
661
811
  return
662
812
 
663
813
 
@@ -29,8 +29,8 @@ logger.setLevel(logging.INFO)
29
29
 
30
30
 
31
31
  def util_csv(
32
- input: str,
33
- output: str,
32
+ input: PosixPath,
33
+ output: PosixPath,
34
34
  compression: str,
35
35
  procs: int,
36
36
  csv_max_rows: int,
@@ -54,13 +54,11 @@ def util_csv(
54
54
  if os.path.isdir(output_dir):
55
55
  os.rename(
56
56
  output_dir,
57
- output_dir + "." + dt.datetime.utcnow().strftime("%Y%m%d-%H%M%S"),
57
+ str(output_dir)
58
+ + "."
59
+ + dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S"),
58
60
  )
59
61
 
60
- # if the output dir is
61
- if os.path.exists(output_dir):
62
- output_dir += "_dir"
63
-
64
62
  # create new directory
65
63
  os.mkdir(output_dir)
66
64
 
@@ -92,7 +90,7 @@ def util_csv(
92
90
  print()
93
91
 
94
92
 
95
- def util_yaml(input: str, output: str):
93
+ def util_yaml(input: PosixPath, output: PosixPath):
96
94
  """Wrapper around util function ddl_to_yaml() for
97
95
  crafting a data gen definition YAML string from
98
96
  CREATE TABLE statements.
@@ -106,7 +104,12 @@ def util_yaml(input: str, output: str):
106
104
 
107
105
  # backup the current file as to not override
108
106
  if os.path.exists(output):
109
- os.rename(output, output + "." + dt.datetime.utcnow().strftime("%Y%m%d-%H%M%S"))
107
+ os.rename(
108
+ output,
109
+ str(output)
110
+ + "."
111
+ + dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S"),
112
+ )
110
113
 
111
114
  # create new file
112
115
  with open(output, "w") as f:
@@ -153,7 +156,7 @@ def util_merge_sort(input_dir: str, output_dir: str, csv_max_rows: int, compress
153
156
  self.output_dir,
154
157
  str(self.output_dir)
155
158
  + "."
156
- + dt.datetime.utcnow().strftime("%Y%m%d-%H%M%S"),
159
+ + dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d-%H%M%S"),
157
160
  )
158
161
 
159
162
  # create new directory
@@ -654,6 +654,23 @@ def ddl_to_yaml(ddl: str):
654
654
  elif within_brackets > 0 and i == ",":
655
655
  col_def += ":"
656
656
 
657
+ # process the content within parenthesis in the
658
+ # CREATE TABLE stmt char by char to distinguish
659
+ # the comma for separating columns vs the comma
660
+ # included in single quote strings such as those in DEFAULT
661
+ # eg: mycol STRING NULL DEFAULT 'corporate, inc'
662
+ within_quote = False
663
+ col_def_str = col_def
664
+ col_def = ""
665
+ for i in col_def_str:
666
+ if i == "'":
667
+ within_quote = not within_quote
668
+ continue
669
+ if within_quote:
670
+ continue
671
+ else:
672
+ col_def += i
673
+
657
674
  col_def = [x.strip().lower() for x in col_def.split(",")]
658
675
 
659
676
  ll = []
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "dbworkload"
3
- version = "0.6.5"
3
+ version = "0.7.0"
4
4
  description = "Workload framework"
5
5
  authors = ["Fabio Ghirardello"]
6
6
  license = "GPLv3+"
File without changes
File without changes