OpenOrchestrator 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,10 @@
1
1
  """This module handles the connection to the database in OpenOrchestrator."""
2
2
 
3
3
  from datetime import datetime
4
- from typing import Callable, TypeVar, ParamSpec
4
+ from typing import TypeVar, ParamSpec
5
+ from uuid import UUID
5
6
 
6
- from croniter import croniter
7
+ from croniter import croniter # type: ignore
7
8
  from sqlalchemy import Engine, create_engine, select, insert, desc
8
9
  from sqlalchemy import exc as alc_exc
9
10
  from sqlalchemy import func as alc_func
@@ -15,12 +16,13 @@ from OpenOrchestrator.database.logs import Log, LogLevel
15
16
  from OpenOrchestrator.database.constants import Constant, Credential
16
17
  from OpenOrchestrator.database.triggers import Trigger, SingleTrigger, ScheduledTrigger, QueueTrigger, TriggerStatus
17
18
  from OpenOrchestrator.database.queues import QueueElement, QueueStatus
19
+ from OpenOrchestrator.database.truncated_string import truncate_message
18
20
 
19
21
  # Type hint helpers for decorators
20
22
  T = TypeVar("T")
21
23
  P = ParamSpec("P")
22
24
 
23
- _connection_engine: Engine = None
25
+ _connection_engine: Engine | None = None
24
26
 
25
27
 
26
28
  def connect(conn_string: str) -> bool:
@@ -48,19 +50,25 @@ def connect(conn_string: str) -> bool:
48
50
  def disconnect() -> None:
49
51
  """Disconnect from the database."""
50
52
  global _connection_engine # pylint: disable=global-statement
51
- _connection_engine.dispose()
53
+ if _connection_engine:
54
+ _connection_engine.dispose()
52
55
  _connection_engine = None
53
56
 
54
57
 
55
- def catch_db_error(func: Callable[P, T]) -> Callable[P, T]:
56
- """A decorator that catches errors in SQL calls."""
57
- def inner(*args, **kwargs) -> T:
58
- if _connection_engine is None:
59
- raise RuntimeError("Not connected to Database")
58
+ def _get_session() -> Session:
59
+ """Check if theres a database connection and return a
60
+ session to it.
60
61
 
61
- return func(*args, **kwargs)
62
+ Raises:
63
+ RuntimeError: If there's no connected database.
62
64
 
63
- return inner
65
+ Returns:
66
+ A database session.
67
+ """
68
+ if not _connection_engine:
69
+ raise RuntimeError("Not connected to database.")
70
+
71
+ return Session(_connection_engine)
64
72
 
65
73
 
66
74
  def get_conn_string() -> str:
@@ -69,25 +77,24 @@ def get_conn_string() -> str:
69
77
  Returns:
70
78
  str: The connection string if any.
71
79
  """
72
- try:
73
- return str(_connection_engine.url)
74
- except AttributeError:
75
- pass
80
+ if not _connection_engine:
81
+ raise RuntimeError("Not connected to database.")
76
82
 
77
- return None
83
+ return str(_connection_engine.url)
78
84
 
79
85
 
80
- @catch_db_error
81
86
  def initialize_database() -> None:
82
87
  """Initializes the database with all the needed tables."""
88
+ if not _connection_engine:
89
+ raise RuntimeError("Not connected to database.")
90
+
83
91
  logs.create_tables(_connection_engine)
84
92
  triggers.create_tables(_connection_engine)
85
93
  constants.create_tables(_connection_engine)
86
94
  queues.create_tables(_connection_engine)
87
95
 
88
96
 
89
- @catch_db_error
90
- def get_trigger(trigger_id: str) -> Trigger:
97
+ def get_trigger(trigger_id: UUID | str) -> Trigger:
91
98
  """Get the trigger with the given id.
92
99
 
93
100
  Args:
@@ -95,24 +102,34 @@ def get_trigger(trigger_id: str) -> Trigger:
95
102
 
96
103
  Returns:
97
104
  Trigger: The trigger with the given id.
105
+
106
+ Raises:
107
+ ValueError: If the trigger doesn't exist.
98
108
  """
99
- with Session(_connection_engine) as session:
109
+ if isinstance(trigger_id, str):
110
+ trigger_id = UUID(trigger_id)
111
+
112
+ with _get_session() as session:
100
113
  query = (
101
114
  select(Trigger)
102
115
  .where(Trigger.id == trigger_id)
103
116
  .options(selectin_polymorphic(Trigger, (ScheduledTrigger, QueueTrigger, SingleTrigger)))
104
117
  )
105
- return session.scalar(query)
118
+ trigger = session.scalar(query)
119
+
120
+ if not trigger:
121
+ raise ValueError(f"No trigger with the given id: {trigger_id}")
122
+
123
+ return trigger
106
124
 
107
125
 
108
- @catch_db_error
109
- def get_all_triggers() -> tuple[Trigger]:
126
+ def get_all_triggers() -> tuple[Trigger, ...]:
110
127
  """Get all triggers in the database.
111
128
 
112
129
  Returns:
113
130
  A tuple of Trigger objects.
114
131
  """
115
- with Session(_connection_engine) as session:
132
+ with _get_session() as session:
116
133
  query = (
117
134
  select(Trigger)
118
135
  .options(selectin_polymorphic(Trigger, (ScheduledTrigger, QueueTrigger, SingleTrigger)))
@@ -120,75 +137,72 @@ def get_all_triggers() -> tuple[Trigger]:
120
137
  return tuple(session.scalars(query))
121
138
 
122
139
 
123
- @catch_db_error
124
140
  def update_trigger(trigger: Trigger):
125
141
  """Updates an existing trigger in the database.
126
142
 
127
143
  Args:
128
144
  trigger: The trigger object with updated values.
129
145
  """
130
- with Session(_connection_engine) as session:
146
+ with _get_session() as session:
131
147
  session.add(trigger)
132
148
  session.commit()
133
149
  session.refresh(trigger)
134
150
 
135
151
 
136
- @catch_db_error
137
- def get_scheduled_triggers() -> tuple[ScheduledTrigger]:
152
+ def get_scheduled_triggers() -> tuple[ScheduledTrigger, ...]:
138
153
  """Get all scheduled triggers from the database.
139
154
 
140
155
  Returns:
141
- tuple[ScheduledTrigger]: A list of all scheduled triggers in the database.
156
+ A list of all scheduled triggers in the database.
142
157
  """
143
- with Session(_connection_engine) as session:
158
+ with _get_session() as session:
144
159
  query = select(ScheduledTrigger)
145
160
  result = session.scalars(query).all()
146
161
  return tuple(result)
147
162
 
148
163
 
149
- @catch_db_error
150
- def get_single_triggers() -> tuple[SingleTrigger]:
164
+ def get_single_triggers() -> tuple[SingleTrigger, ...]:
151
165
  """Get all single triggers from the database.
152
166
 
153
167
  Returns:
154
- tuple[SingleTrigger]: A list of all single triggers in the database.
168
+ A list of all single triggers in the database.
155
169
  """
156
- with Session(_connection_engine) as session:
170
+ with _get_session() as session:
157
171
  query = select(SingleTrigger)
158
172
  result = session.scalars(query).all()
159
173
  return tuple(result)
160
174
 
161
175
 
162
- @catch_db_error
163
- def get_queue_triggers() -> tuple[QueueTrigger]:
176
+ def get_queue_triggers() -> tuple[QueueTrigger, ...]:
164
177
  """Get all queue triggers from the database.
165
178
 
166
179
  Returns:
167
- tuple[QueueTrigger]: A list of all queue triggers in the database.
180
+ A list of all queue triggers in the database.
168
181
  """
169
- with Session(_connection_engine) as session:
182
+ with _get_session() as session:
170
183
  query = select(QueueTrigger)
171
184
  result = session.scalars(query).all()
172
185
  return tuple(result)
173
186
 
174
187
 
175
- @catch_db_error
176
- def delete_trigger(trigger_id: str) -> None:
188
+ def delete_trigger(trigger_id: UUID | str) -> None:
177
189
  """Delete the given trigger from the database.
178
190
 
179
191
  Args:
180
192
  trigger_id: The id of the trigger to delete.
181
193
  """
182
- with Session(_connection_engine) as session:
183
- trigger = session.get(Trigger, trigger_id)
194
+ if isinstance(trigger_id, str):
195
+ trigger_id = UUID(trigger_id)
196
+
197
+ with _get_session() as session:
198
+ trigger = get_trigger(trigger_id)
184
199
  session.delete(trigger)
185
200
  session.commit()
186
201
 
187
202
 
188
- @catch_db_error
189
203
  def get_logs(offset: int, limit: int,
190
- from_date: datetime = None, to_date: datetime = None,
191
- process_name: str = None, log_level: LogLevel = None) -> tuple[Log]:
204
+ from_date: datetime | None = None, to_date: datetime | None = None,
205
+ process_name: str | None = None, log_level: LogLevel | None = None) -> tuple[Log, ...]:
192
206
  """Get the logs from the database using filters and pagination.
193
207
 
194
208
  Args:
@@ -200,7 +214,7 @@ def get_logs(offset: int, limit: int,
200
214
  log_level: The log level to filter on. If none the filter is disabled.
201
215
 
202
216
  Returns:
203
- tuple[Log]: A list of logs matching the given filters.
217
+ A list of logs matching the given filters.
204
218
  """
205
219
  query = (
206
220
  select(Log)
@@ -221,12 +235,11 @@ def get_logs(offset: int, limit: int,
221
235
  if log_level:
222
236
  query = query.where(Log.log_level == log_level)
223
237
 
224
- with Session(_connection_engine) as session:
238
+ with _get_session() as session:
225
239
  result = session.scalars(query).all()
226
240
  return tuple(result)
227
241
 
228
242
 
229
- @catch_db_error
230
243
  def create_log(process_name: str, level: LogLevel, message: str) -> None:
231
244
  """Create a log in the logs table in the database.
232
245
 
@@ -235,22 +248,21 @@ def create_log(process_name: str, level: LogLevel, message: str) -> None:
235
248
  level: The level of the log.
236
249
  message: The message of the log.
237
250
  """
238
- with Session(_connection_engine) as session:
251
+ with _get_session() as session:
239
252
  log = Log(
240
253
  log_level = level,
241
254
  process_name = process_name,
242
- log_message = message
255
+ log_message = truncate_message(message)
243
256
  )
244
257
  session.add(log)
245
258
  session.commit()
246
259
 
247
260
 
248
- @catch_db_error
249
- def get_unique_log_process_names() -> tuple[str]:
261
+ def get_unique_log_process_names() -> tuple[str, ...]:
250
262
  """Get a list of unique process names in the logs database.
251
263
 
252
264
  Returns:
253
- tuple[str]: A list of unique process names.
265
+ A list of unique process names.
254
266
  """
255
267
 
256
268
  query = (
@@ -259,12 +271,11 @@ def get_unique_log_process_names() -> tuple[str]:
259
271
  .order_by(Log.process_name)
260
272
  )
261
273
 
262
- with Session(_connection_engine) as session:
274
+ with _get_session() as session:
263
275
  result = session.scalars(query).all()
264
276
  return tuple(result)
265
277
 
266
278
 
267
- @catch_db_error
268
279
  def create_single_trigger(trigger_name: str, process_name: str, next_run: datetime,
269
280
  process_path: str, process_args: str, is_git_repo: bool, is_blocking: bool) -> None:
270
281
  """Create a new single trigger in the database.
@@ -278,7 +289,7 @@ def create_single_trigger(trigger_name: str, process_name: str, next_run: dateti
278
289
  is_git_repo: If the process_path points to a git repo.
279
290
  is_blocking: If the process should be blocking.
280
291
  """
281
- with Session(_connection_engine) as session:
292
+ with _get_session() as session:
282
293
  trigger = SingleTrigger(
283
294
  trigger_name= trigger_name,
284
295
  process_name = process_name,
@@ -292,7 +303,6 @@ def create_single_trigger(trigger_name: str, process_name: str, next_run: dateti
292
303
  session.commit()
293
304
 
294
305
 
295
- @catch_db_error
296
306
  def create_scheduled_trigger(trigger_name: str, process_name: str, cron_expr: str, next_run: datetime,
297
307
  process_path: str, process_args: str, is_git_repo: bool,
298
308
  is_blocking: bool) -> None:
@@ -308,7 +318,7 @@ def create_scheduled_trigger(trigger_name: str, process_name: str, cron_expr: st
308
318
  is_git_repo: If the process_path points to a git repo.
309
319
  is_blocking: If the process should be blocking.
310
320
  """
311
- with Session(_connection_engine) as session:
321
+ with _get_session() as session:
312
322
  trigger = ScheduledTrigger(
313
323
  trigger_name= trigger_name,
314
324
  process_name = process_name,
@@ -323,7 +333,6 @@ def create_scheduled_trigger(trigger_name: str, process_name: str, cron_expr: st
323
333
  session.commit()
324
334
 
325
335
 
326
- @catch_db_error
327
336
  def create_queue_trigger(trigger_name: str, process_name: str, queue_name: str, process_path: str,
328
337
  process_args: str, is_git_repo: bool, is_blocking: bool,
329
338
  min_batch_size: int) -> None:
@@ -339,7 +348,7 @@ def create_queue_trigger(trigger_name: str, process_name: str, queue_name: str,
339
348
  is_blocking: The is_blocking value of the process.
340
349
  min_batch_size: The minimum number of queue elements before triggering.
341
350
  """
342
- with Session(_connection_engine) as session:
351
+ with _get_session() as session:
343
352
  trigger = QueueTrigger(
344
353
  trigger_name= trigger_name,
345
354
  process_name = process_name,
@@ -354,7 +363,6 @@ def create_queue_trigger(trigger_name: str, process_name: str, queue_name: str,
354
363
  session.commit()
355
364
 
356
365
 
357
- @catch_db_error
358
366
  def get_constant(name: str) -> Constant:
359
367
  """Get a constant from the database.
360
368
 
@@ -367,27 +375,25 @@ def get_constant(name: str) -> Constant:
367
375
  Raises:
368
376
  ValueError: If no constant with the given name exists.
369
377
  """
370
- with Session(_connection_engine) as session:
378
+ with _get_session() as session:
371
379
  constant = session.get(Constant, name)
372
380
  if constant is None:
373
381
  raise ValueError(f"No constant with name '{name}' was found.")
374
382
  return constant
375
383
 
376
384
 
377
- @catch_db_error
378
- def get_constants() -> tuple[Constant]:
385
+ def get_constants() -> tuple[Constant, ...]:
379
386
  """Get all constants in the database.
380
387
 
381
388
  Returns:
382
389
  tuple[Constants]: A list of constants.
383
390
  """
384
- with Session(_connection_engine) as session:
391
+ with _get_session() as session:
385
392
  query = select(Constant).order_by(Constant.name)
386
393
  result = session.scalars(query).all()
387
394
  return tuple(result)
388
395
 
389
396
 
390
- @catch_db_error
391
397
  def create_constant(name: str, value: str) -> None:
392
398
  """Create a new constant in the database.
393
399
 
@@ -395,13 +401,12 @@ def create_constant(name: str, value: str) -> None:
395
401
  name: The name of the constant.
396
402
  value: The value of the constant.
397
403
  """
398
- with Session(_connection_engine) as session:
404
+ with _get_session() as session:
399
405
  constant = Constant(name = name, value = value)
400
406
  session.add(constant)
401
407
  session.commit()
402
408
 
403
409
 
404
- @catch_db_error
405
410
  def update_constant(name: str, new_value: str) -> None:
406
411
  """Updates an existing constant with a new value.
407
412
 
@@ -409,32 +414,35 @@ def update_constant(name: str, new_value: str) -> None:
409
414
  name: The name of the constant to update.
410
415
  new_value: The new value of the constant.
411
416
  """
412
- with Session(_connection_engine) as session:
417
+ with _get_session() as session:
413
418
  constant = session.get(Constant, name)
419
+
420
+ if not constant:
421
+ raise ValueError(f"No constant with name '{name}' was found.")
422
+
414
423
  constant.value = new_value
415
424
  session.commit()
416
425
 
417
426
 
418
- @catch_db_error
419
427
  def delete_constant(name: str) -> None:
420
428
  """Delete the constant with the given name from the database.
421
429
 
422
430
  Args:
423
431
  name: The name of the constant to delete.
424
432
  """
425
- with Session(_connection_engine) as session:
433
+ with _get_session() as session:
426
434
  constant = session.get(Constant, name)
427
435
  session.delete(constant)
428
436
  session.commit()
429
437
 
430
438
 
431
- @catch_db_error
432
- def get_credential(name: str) -> Credential:
439
+ def get_credential(name: str, decrypt_password: bool = True) -> Credential:
433
440
  """Get a credential from the database.
434
441
  The password of the credential is decrypted.
435
442
 
436
443
  Args:
437
444
  name: The name of the credential.
445
+ decrypt_password: Whether to decrypt the credential password or not.
438
446
 
439
447
  Returns:
440
448
  Credential: The credential with the given name.
@@ -442,31 +450,31 @@ def get_credential(name: str) -> Credential:
442
450
  Raises:
443
451
  ValueError: If no credential with the given name exists.
444
452
  """
445
- with Session(_connection_engine) as session:
453
+ with _get_session() as session:
446
454
  credential = session.get(Credential, name)
447
455
 
448
456
  if credential is None:
449
457
  raise ValueError(f"No credential with name '{name}' was found.")
450
458
 
451
- credential.password = crypto_util.decrypt_string(credential.password)
459
+ if decrypt_password:
460
+ credential.password = crypto_util.decrypt_string(credential.password)
461
+
452
462
  return credential
453
463
 
454
464
 
455
- @catch_db_error
456
- def get_credentials() -> tuple[Credential]:
465
+ def get_credentials() -> tuple[Credential, ...]:
457
466
  """Get all credentials in the database.
458
467
  The passwords of the credentials are encrypted.
459
468
 
460
469
  Returns:
461
470
  tuple[Credential]: A list of credentials.
462
471
  """
463
- with Session(_connection_engine) as session:
472
+ with _get_session() as session:
464
473
  query = select(Credential).order_by(Credential.name)
465
474
  result = session.scalars(query).all()
466
475
  return tuple(result)
467
476
 
468
477
 
469
- @catch_db_error
470
478
  def create_credential(name: str, username: str, password: str) -> None:
471
479
  """Create a new credential in the database.
472
480
  The password is encrypted before sending it to the database.
@@ -479,7 +487,7 @@ def create_credential(name: str, username: str, password: str) -> None:
479
487
 
480
488
  password = crypto_util.encrypt_string(password)
481
489
 
482
- with Session(_connection_engine) as session:
490
+ with _get_session() as session:
483
491
  credential = Credential(
484
492
  name = name,
485
493
  username= username,
@@ -489,7 +497,6 @@ def create_credential(name: str, username: str, password: str) -> None:
489
497
  session.commit()
490
498
 
491
499
 
492
- @catch_db_error
493
500
  def update_credential(name: str, new_username: str, new_password: str) -> None:
494
501
  """Updates an existing credential with a new value.
495
502
 
@@ -500,28 +507,30 @@ def update_credential(name: str, new_username: str, new_password: str) -> None:
500
507
  """
501
508
  new_password = crypto_util.encrypt_string(new_password)
502
509
 
503
- with Session(_connection_engine) as session:
510
+ with _get_session() as session:
504
511
  credential = session.get(Credential, name)
512
+
513
+ if not credential:
514
+ raise ValueError(f"No credential with name '{name}' was found.")
515
+
505
516
  credential.username = new_username
506
517
  credential.password = new_password
507
518
  session.commit()
508
519
 
509
520
 
510
- @catch_db_error
511
521
  def delete_credential(name: str) -> None:
512
522
  """Delete the credential with the given name from the database.
513
523
 
514
524
  Args:
515
525
  name: The name of the credential to delete.
516
526
  """
517
- with Session(_connection_engine) as session:
527
+ with _get_session() as session:
518
528
  constant = session.get(Credential, name)
519
529
  session.delete(constant)
520
530
  session.commit()
521
531
 
522
532
 
523
- @catch_db_error
524
- def begin_single_trigger(trigger_id: str) -> bool:
533
+ def begin_single_trigger(trigger_id: UUID | str) -> bool:
525
534
  """Set the status of a single trigger to 'running' and
526
535
  set the last run time to the current time.
527
536
 
@@ -531,9 +540,15 @@ def begin_single_trigger(trigger_id: str) -> bool:
531
540
  Returns:
532
541
  bool: True if the trigger was 'idle' and now 'running'.
533
542
  """
534
- with Session(_connection_engine) as session:
543
+ if isinstance(trigger_id, str):
544
+ trigger_id = UUID(trigger_id)
545
+
546
+ with _get_session() as session:
535
547
  trigger = session.get(SingleTrigger, trigger_id)
536
548
 
549
+ if not trigger:
550
+ raise ValueError("No trigger with the given id was found.")
551
+
537
552
  if trigger.process_status != TriggerStatus.IDLE:
538
553
  return False
539
554
 
@@ -544,14 +559,13 @@ def begin_single_trigger(trigger_id: str) -> bool:
544
559
  return True
545
560
 
546
561
 
547
- @catch_db_error
548
562
  def get_next_single_trigger() -> SingleTrigger | None:
549
563
  """Get the single trigger that should trigger next.
550
564
 
551
565
  Returns:
552
- SingleTrigger | None: The next single trigger to run if any.
566
+ The next single trigger to run if any.
553
567
  """
554
- with Session(_connection_engine) as session:
568
+ with _get_session() as session:
555
569
  query = (
556
570
  select(SingleTrigger)
557
571
  .where(SingleTrigger.process_status == TriggerStatus.IDLE)
@@ -562,14 +576,13 @@ def get_next_single_trigger() -> SingleTrigger | None:
562
576
  return session.scalar(query)
563
577
 
564
578
 
565
- @catch_db_error
566
579
  def get_next_scheduled_trigger() -> ScheduledTrigger | None:
567
580
  """Get the scheduled trigger that should trigger next.
568
581
 
569
582
  Returns:
570
- ScheduledTrigger | None: The next scheduled trigger to run if any.
583
+ The next scheduled trigger to run if any.
571
584
  """
572
- with Session(_connection_engine) as session:
585
+ with _get_session() as session:
573
586
  query = (
574
587
  select(ScheduledTrigger)
575
588
  .where(ScheduledTrigger.process_status == TriggerStatus.IDLE)
@@ -580,8 +593,7 @@ def get_next_scheduled_trigger() -> ScheduledTrigger | None:
580
593
  return session.scalar(query)
581
594
 
582
595
 
583
- @catch_db_error
584
- def begin_scheduled_trigger(trigger_id: str) -> bool:
596
+ def begin_scheduled_trigger(trigger_id: UUID | str) -> bool:
585
597
  """Set the status of a scheduled trigger to 'running',
586
598
  set the last run time to the current time,
587
599
  and set the next run time to the given datetime.
@@ -593,9 +605,15 @@ def begin_scheduled_trigger(trigger_id: str) -> bool:
593
605
  Returns:
594
606
  bool: True if the trigger was 'idle' and now 'running'.
595
607
  """
596
- with Session(_connection_engine) as session:
608
+ if isinstance(trigger_id, str):
609
+ trigger_id = UUID(trigger_id)
610
+
611
+ with _get_session() as session:
597
612
  trigger = session.get(ScheduledTrigger, trigger_id)
598
613
 
614
+ if not trigger:
615
+ raise ValueError("No trigger with the given id was found.")
616
+
599
617
  if trigger.process_status != TriggerStatus.IDLE:
600
618
  return False
601
619
 
@@ -607,7 +625,6 @@ def begin_scheduled_trigger(trigger_id: str) -> bool:
607
625
  return True
608
626
 
609
627
 
610
- @catch_db_error
611
628
  def get_next_queue_trigger() -> QueueTrigger | None:
612
629
  """Get the next queue trigger to run.
613
630
  This functions loops through the queue triggers and checks
@@ -618,7 +635,7 @@ def get_next_queue_trigger() -> QueueTrigger | None:
618
635
  QueueTrigger | None: The next queue trigger to run if any.
619
636
  """
620
637
 
621
- with Session(_connection_engine) as session:
638
+ with _get_session() as session:
622
639
 
623
640
  sub_query = (
624
641
  select(alc_func.count()) # pylint: disable=not-callable
@@ -636,8 +653,7 @@ def get_next_queue_trigger() -> QueueTrigger | None:
636
653
  return session.scalar(query)
637
654
 
638
655
 
639
- @catch_db_error
640
- def begin_queue_trigger(trigger_id: str) -> None:
656
+ def begin_queue_trigger(trigger_id: UUID | str) -> bool:
641
657
  """Set the status of a queue trigger to 'running' and
642
658
  set the last run time to the current time.
643
659
 
@@ -647,9 +663,15 @@ def begin_queue_trigger(trigger_id: str) -> None:
647
663
  Returns:
648
664
  bool: True if the trigger was 'idle' and now 'running'.
649
665
  """
650
- with Session(_connection_engine) as session:
666
+ if isinstance(trigger_id, str):
667
+ trigger_id = UUID(trigger_id)
668
+
669
+ with _get_session() as session:
651
670
  trigger = session.get(QueueTrigger, trigger_id)
652
671
 
672
+ if not trigger:
673
+ raise ValueError("No trigger with the given id was found.")
674
+
653
675
  if trigger.process_status != TriggerStatus.IDLE:
654
676
  return False
655
677
 
@@ -660,22 +682,27 @@ def begin_queue_trigger(trigger_id: str) -> None:
660
682
  return True
661
683
 
662
684
 
663
- @catch_db_error
664
- def set_trigger_status(trigger_id: str, status: TriggerStatus) -> None:
685
+ def set_trigger_status(trigger_id: UUID | str, status: TriggerStatus) -> None:
665
686
  """Set the status of a trigger.
666
687
 
667
688
  Args:
668
689
  trigger_id: The id of the trigger.
669
690
  status: The new status of the trigger.
670
691
  """
671
- with Session(_connection_engine) as session:
692
+ if isinstance(trigger_id, str):
693
+ trigger_id = UUID(trigger_id)
694
+
695
+ with _get_session() as session:
672
696
  trigger = session.get(Trigger, trigger_id)
697
+
698
+ if not trigger:
699
+ raise ValueError("No trigger with the given id was found.")
700
+
673
701
  trigger.process_status = status
674
702
  session.commit()
675
703
 
676
704
 
677
- @catch_db_error
678
- def create_queue_element(queue_name: str, reference: str = None, data: str = None, created_by: str = None) -> QueueElement:
705
+ def create_queue_element(queue_name: str, reference: str | None = None, data: str | None = None, created_by: str | None = None) -> QueueElement:
679
706
  """Adds a queue element to the given queue.
680
707
 
681
708
  Args:
@@ -687,7 +714,7 @@ def create_queue_element(queue_name: str, reference: str = None, data: str = Non
687
714
  Returns:
688
715
  QueueElement: The created queue element.
689
716
  """
690
- with Session(_connection_engine) as session:
717
+ with _get_session() as session:
691
718
  q_element = QueueElement(
692
719
  queue_name = queue_name,
693
720
  data = data,
@@ -701,8 +728,7 @@ def create_queue_element(queue_name: str, reference: str = None, data: str = Non
701
728
  return q_element
702
729
 
703
730
 
704
- @catch_db_error
705
- def bulk_create_queue_elements(queue_name: str, references: tuple[str], data: tuple[str], created_by: str = None) -> None:
731
+ def bulk_create_queue_elements(queue_name: str, references: tuple[str | None, ...], data: tuple[str | None, ...], created_by: str | None = None) -> None:
706
732
  """Insert multiple queue elements into a queue in an optimized manner.
707
733
  The lengths of both 'references' and 'data' must be equal to the number of elements to insert.
708
734
 
@@ -734,13 +760,12 @@ def bulk_create_queue_elements(queue_name: str, references: tuple[str], data: tu
734
760
  for ref, dat in zip(references, data)
735
761
  )
736
762
 
737
- with Session(_connection_engine) as session:
738
- session.execute(insert(QueueElement), q_elements)
763
+ with _get_session() as session:
764
+ session.execute(insert(QueueElement), q_elements) # type: ignore
739
765
  session.commit()
740
766
 
741
767
 
742
- @catch_db_error
743
- def get_next_queue_element(queue_name: str, reference: str = None, set_status: bool = True) -> QueueElement | None:
768
+ def get_next_queue_element(queue_name: str, reference: str | None = None, set_status: bool = True) -> QueueElement | None:
744
769
  """Gets the next queue element from the given queue that has the status 'new'.
745
770
 
746
771
  Args:
@@ -752,7 +777,7 @@ def get_next_queue_element(queue_name: str, reference: str = None, set_status: b
752
777
  QueueElement | None: The next queue element in the queue if any.
753
778
  """
754
779
 
755
- with Session(_connection_engine) as session:
780
+ with _get_session() as session:
756
781
  query = (
757
782
  select(QueueElement)
758
783
  .where(QueueElement.queue_name == queue_name)
@@ -775,9 +800,9 @@ def get_next_queue_element(queue_name: str, reference: str = None, set_status: b
775
800
  return q_element
776
801
 
777
802
 
778
- @catch_db_error
779
- def get_queue_elements(queue_name: str, reference: str = None, status: QueueStatus = None,
780
- offset: int = 0, limit: int = 100) -> tuple[QueueElement]:
803
+ def get_queue_elements(queue_name: str, reference: str | None = None, status: QueueStatus | None = None,
804
+ from_date: datetime | None = None, to_date: datetime | None = None,
805
+ offset: int = 0, limit: int = 100) -> tuple[QueueElement, ...]:
781
806
  """Get multiple queue elements from a queue. The elements are ordered by created_date.
782
807
 
783
808
  Args:
@@ -790,7 +815,7 @@ def get_queue_elements(queue_name: str, reference: str = None, status: QueueStat
790
815
  Returns:
791
816
  tuple[QueueElement]: A tuple of queue elements.
792
817
  """
793
- with Session(_connection_engine) as session:
818
+ with _get_session() as session:
794
819
  query = (
795
820
  select(QueueElement)
796
821
  .where(QueueElement.queue_name == queue_name)
@@ -798,6 +823,13 @@ def get_queue_elements(queue_name: str, reference: str = None, status: QueueStat
798
823
  .offset(offset)
799
824
  .limit(limit)
800
825
  )
826
+
827
+ if from_date:
828
+ query = query.where(QueueElement.created_date >= from_date)
829
+
830
+ if to_date:
831
+ query = query.where(QueueElement.created_date <= to_date)
832
+
801
833
  if reference is not None:
802
834
  query = query.where(QueueElement.reference == reference)
803
835
 
@@ -808,14 +840,13 @@ def get_queue_elements(queue_name: str, reference: str = None, status: QueueStat
808
840
  return tuple(result)
809
841
 
810
842
 
811
- @catch_db_error
812
843
  def get_queue_count() -> dict[str, dict[QueueStatus, int]]:
813
844
  """Count the number of queue elements of each status for every queue.
814
845
 
815
846
  Returns:
816
847
  A dict for each queue with the count for each status. E.g. result[queue_name][status] => count.
817
848
  """
818
- with Session(_connection_engine) as session:
849
+ with _get_session() as session:
819
850
  query = (
820
851
  select(QueueElement.queue_name, QueueElement.status, alc_func.count()) # pylint: disable=not-callable
821
852
  .group_by(QueueElement.queue_name)
@@ -834,19 +865,25 @@ def get_queue_count() -> dict[str, dict[QueueStatus, int]]:
834
865
  return result
835
866
 
836
867
 
837
- @catch_db_error
838
- def set_queue_element_status(element_id: str, status: QueueStatus, message: str = None) -> None:
868
+ def set_queue_element_status(element_id: UUID | str, status: QueueStatus, message: str | None = None) -> None:
839
869
  """Set the status of a queue element.
840
870
  If the new status is 'in progress' the start date is noted.
841
- If the new status is 'Done' or 'Failed' the end date is noted.
871
+ If the new status is 'Done', 'Failed' or 'Abandoned' the end date is noted.
842
872
 
843
873
  Args:
844
874
  element_id: The id of the queue element to change status on.
845
875
  status: The new status of the queue element.
846
876
  message (Optional): The message to attach to the queue element. This overrides any existing messages.
847
877
  """
848
- with Session(_connection_engine) as session:
878
+ if isinstance(element_id, str):
879
+ element_id = UUID(element_id)
880
+
881
+ with _get_session() as session:
849
882
  q_element = session.get(QueueElement, element_id)
883
+
884
+ if not q_element:
885
+ raise ValueError("No queue element with the given id was found.")
886
+
850
887
  q_element.status = status
851
888
 
852
889
  if message is not None:
@@ -855,20 +892,21 @@ def set_queue_element_status(element_id: str, status: QueueStatus, message: str
855
892
  match status:
856
893
  case QueueStatus.IN_PROGRESS:
857
894
  q_element.start_date = datetime.now()
858
- case QueueStatus.DONE | QueueStatus.FAILED:
895
+ case QueueStatus.DONE | QueueStatus.FAILED | QueueStatus.ABANDONED:
859
896
  q_element.end_date = datetime.now()
897
+ case _:
898
+ pass
860
899
 
861
900
  session.commit()
862
901
 
863
902
 
864
- @catch_db_error
865
- def delete_queue_element(element_id: str) -> None:
903
+ def delete_queue_element(element_id: UUID | str) -> None:
866
904
  """Delete a queue element from the database.
867
905
 
868
906
  Args:
869
907
  element_id: The id of the queue element.
870
908
  """
871
- with Session(_connection_engine) as session:
909
+ with _get_session() as session:
872
910
  q_element = session.get(QueueElement, element_id)
873
911
  session.delete(q_element)
874
912
  session.commit()