scalable-pypeline 1.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,803 @@
1
+ """ Utilities for running and managing tasks inside pipelines.
2
+ """
3
+ import os
4
+ import logging
5
+ import uuid
6
+ from typing import List, Any, Union
7
+ from networkx.classes.digraph import DiGraph
8
+ from celery import chain, signature, chord
9
+
10
+ from pypeline.constants import DEFAULT_TASK_TTL, \
11
+ DEFAULT_REGULATOR_TASK, CHAIN_SUCCESS_MSG, CHAIN_FAILURE_MSG, \
12
+ PIPELINE_RUN_WRAPPER_CACHE_KEY, DEFAULT_SUCCESS_TASK, DEFAULT_RESULT_TTL, \
13
+ PIPELINE_RESULT_CACHE_KEY, DEFAULT_RETRY_TASK, DEFAULT_MAX_RETRY, \
14
+ DEFAULT_RETRY_TASK_MAX_TTL
15
+ from pypeline.utils.graph_utils import get_execution_graph, get_chainable_tasks
16
+ from pypeline.utils.config_utils import retrieve_latest_pipeline_config, \
17
+ load_json_config_from_redis, set_json_config_to_redis
18
+ from pypeline.pipeline_config_schema import PipelineConfigValidator
19
+
20
+ logger = logging.getLogger(__name__)
21
+ WORKER_NAME = os.environ.get('WORKER_NAME', None)
22
+
23
+
24
+ def get_service_config_for_worker(sermos_config: dict,
25
+ worker_name: str = None
26
+ ) -> Union[dict, None]:
27
+ """ For the current WORKER_NAME (which must be present in the environment
28
+ of this worker instance for a valid deployment), return the worker's
29
+ serviceConfig object.
30
+ """
31
+ if sermos_config is None:
32
+ raise ValueError('Sermos config was not provided')
33
+ if worker_name is None:
34
+ worker_name = WORKER_NAME
35
+ if worker_name is None:
36
+ return None
37
+
38
+ service_config = sermos_config.get('serviceConfig', [])
39
+ for service in service_config:
40
+ if service['serviceType'] == 'celery-worker' and service[
41
+ 'name'] == worker_name:
42
+ return service
43
+
44
+ raise ValueError('Could not find a service config for worker '
45
+ f'`{worker_name}`. Make sure you have added the service in'
46
+ f' your sermos.yaml with `name: {worker_name}` and '
47
+ '`serviceType: celery-worker`.')
48
+
49
+
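A rough sketch of the config shape this helper expects, inferred from the keys accessed above (the real sermos.yaml schema may include additional fields; the worker name and queue here are made up):

    sermos_config = {
        'serviceConfig': [
            {
                'name': 'document-worker',        # hypothetical worker name
                'serviceType': 'celery-worker',
                'queue': 'document-work',         # extra fields pass through untouched
            },
        ]
    }
    service = get_service_config_for_worker(sermos_config,
                                            worker_name='document-worker')
    # Raises ValueError if no celery-worker entry matches the worker name.
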
50
+ def get_task_signature(task_path: str,
51
+ queue: str,
52
+ access_key: str = None,
53
+ pipeline_id: str = None,
54
+ execution_id: str = None,
55
+ max_ttl: int = None,
56
+ immutable: bool = True,
57
+ task_config: dict = None,
58
+ custom_event_data: dict = None) -> signature:
59
+ """ Generate a task signature with enforced event keyword
60
+ """
61
+ if task_config is None:
62
+ task_config = dict()
63
+ if custom_event_data is None:
64
+ custom_event_data = dict()
65
+
66
+ if queue is None:
67
+ # Look for a pipeline task configuration, if one was provided then we
68
+ # use queue specified on that task if it's specified.
69
+ queue = task_config.get('queue', None)
70
+
71
+ # If we still have None or 'default' (for backwards compatibility), raise
72
+ # because we now require that a queue is specified.
73
+ if queue in (None, 'default'):
74
+ raise ValueError('Must set queue for a worker or registeredTask.')
75
+
76
+ if max_ttl is None:
77
+ # First look on the pipeline configuration, if a max_ttl is specified,
78
+ # then we're using that regardless.
79
+ max_ttl = task_config.get('maxTtl', None)
80
+
81
+ # If we still have None or 'default', fall back to the default TTL.
82
+ if max_ttl in (None, 'default'):
83
+ max_ttl = DEFAULT_TASK_TTL
84
+
85
+ kwargs = {
86
+ 'event': {
87
+ 'access_key': access_key,
88
+ 'pipeline_id': pipeline_id,
89
+ 'execution_id': execution_id
90
+ }
91
+ }
92
+ if custom_event_data is not None:
93
+ kwargs['event'] = {**kwargs['event'], **custom_event_data}
94
+
95
+ # TODO where do we inject the 'event' data from sermos yaml schema?
96
+
97
+ sig = signature(
98
+ task_path,
99
+ args=(),
100
+ kwargs=kwargs,
101
+ immutable=immutable,
102
+ task_id=str(uuid.uuid4()),
103
+ options={
104
+ 'queue': queue,
105
+ 'expires': 86400, # Expire after 1 day. TODO make tunable.
106
+ 'soft_time_limit': max_ttl,
107
+ 'time_limit': max_ttl + 10, # Add 10s buffer for cleanup
108
+ })
109
+ return sig
110
+
111
+
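get_task_signature() always injects an `event` keyword argument that merges the pipeline identifiers with any `custom_event_data`, so receiving tasks can rely on a uniform payload. A minimal sketch of a one-off dispatch (the task path and queue are hypothetical, not part of this package):

    sig = get_task_signature(
        task_path='my_client.workers.tasks.process_document',  # hypothetical task
        queue='document-work',                                  # hypothetical queue
        custom_event_data={'document_id': 'abc-123'})
    sig.delay()
    # The receiving task is called with:
    #   event={'access_key': None, 'pipeline_id': None,
    #          'execution_id': None, 'document_id': 'abc-123'}
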
112
+ class PipelineRunWrapper:
113
+ """ A wrapper for a single "run" of a Pipeline.
114
+
115
+ A 'run' is defined as a single execution of a pipeline, where a pipeline
116
+ consists of one or more steps in a chain.
117
+
118
+ When a pipeline's run is first executed, the execution id is generated
119
+ as a uuid. Subsequent retries of this 'run' will be able to look up
120
+ the cached run state using that execution id.
121
+
122
+ The primary purpose for the PipelineRunWrapper is to provide a cached
123
+ representation of the full 'run' including retry count and any payload
124
+ that should be accessible to any step in the chain. Remember, a pipeline
125
+ is running asynchronously and, as such, each node in the graph operates
126
+ independently of the others; the cached wrapper keeps them coordinated.
127
+ """
128
+ pipeline_id: str = None
129
+ pipeline_config: dict = None # Pipeline configuration in dictionary format
130
+ dag_config: dict = None
131
+ execution_id: str = None
132
+ current_event: dict = None # For single task when from_event(). NOT cached.
133
+ cache_key: str = None # Set on init
134
+ max_ttl: int = 60 # Overloaded when pipeline_config provided and it's set
135
+ max_retry: int = 3 # Overloaded when pipeline_config provided and it's set
136
+ retry_count: int = 0
137
+ chain_payload: dict = None # Optional data to pass to each step in chain
138
+ execution_graph: DiGraph = None
139
+ good_to_go = False
140
+ loading_message = None
141
+
142
+ def __init__(self,
143
+ pipeline_id: str,
144
+ pipeline_config: dict = None,
145
+ execution_id: str = None,
146
+ max_ttl: int = 60,
147
+ max_retry: int = 3,
148
+ chain_payload: dict = None,
149
+ current_event: dict = None):
150
+ super().__init__()
151
+ self.pipeline_id = pipeline_id
152
+ self.pipeline_config = pipeline_config
153
+
154
+ self.max_ttl = max_ttl
155
+ self.max_retry = max_retry
156
+
157
+ # Execution IDs uniquely identify a single run of a given pipeline.
158
+ # If None is provided, a random id is generated, which will be cached
159
+ # and used downstream in the event of a retry. Initial invocations
160
+ # should generally not set this value manually.
161
+ self.execution_id = execution_id
162
+ if self.execution_id is None:
163
+ self.execution_id = str(uuid.uuid4())
164
+
165
+ self.chain_payload = chain_payload\
166
+ if chain_payload is not None else {}
167
+
168
+ self.current_event = current_event\
169
+ if current_event is not None else {}
170
+
171
+ self.cache_key = PIPELINE_RUN_WRAPPER_CACHE_KEY.format(
172
+ self.pipeline_id, self.execution_id)
173
+
174
+ self.good_to_go = True
175
+
176
+ @property
177
+ def _cachable_keys(self):
178
+ """ For caching purposes, only store json serializable values that are
179
+ required for caching / loading from cache.
180
+
181
+ Note: Several keys are pulled from the pipeline_config where they are
182
+ camelCase and set on this as snake_case. This is done for convenience
183
+ in the wrapper. The switch in style follows the project's naming
184
+ conventions: yaml files use camelCase to conform with k8s, while
185
+ local python variables use snake_case. The extraction of the yaml
186
+ file variables onto the wrapper object is done during the .load()
187
+ stage.
188
+ """
189
+ return ('pipeline_config', 'max_ttl', 'max_retry', 'retry_count',
190
+ 'chain_payload', 'pipeline_id')
191
+
192
+ def _load_from_cache(self, is_retry=False):
193
+ """ Attempt to load this PipelineRunWrapper from cache.
194
+ """
195
+ logger.debug(f"Attempting to load {self.cache_key} from cache")
196
+ try:
197
+ cached_wrapper = load_json_config_from_redis(self.cache_key)
198
+ if cached_wrapper is not None:
199
+ for key in self._cachable_keys:
200
+ setattr(self, key, cached_wrapper[key])
201
+
202
+ msg = f"{self.cache_key} found in cache ..."
203
+ self.loading_message = msg
204
+ logger.debug(msg)
205
+ else:
206
+ raise ValueError(f"Unable to find {self.cache_key} ...")
207
+ except Exception as e:
208
+ if not is_retry:
209
+ self.good_to_go = False
210
+ self.loading_message = e
211
+ logger.exception(e)
212
+
213
+ if self.pipeline_config is None:
214
+ raise ValueError("pipeline_config not set, invalid ...")
215
+
216
+ return
217
+
218
+ def save_to_cache(self):
219
+ """ Save current state of PipelineRunWrapper to cache, json serialized.
220
+ Re-set the key's TTL
221
+
222
+ TODO: Lock this so no race condition on concurrent steps.
223
+ """
224
+ logger.debug(f"Saving {self.cache_key} to cache")
225
+ cached_json = {}
226
+ for key in self._cachable_keys:
227
+ cached_json[key] = getattr(self, key)
228
+ ttl = (self.max_ttl *
229
+ len(self.pipeline_config['taskDefinitions'])) + 10
230
+ set_json_config_to_redis(self.cache_key, cached_json, ttl)
231
+
232
+ @classmethod
233
+ def from_event(cls, event):
234
+ """ Create instance of PipelineRunWrapper from pipeline event.
235
+
236
+ Loads the cached PipelineRunWrapper instance, which is assumed to exist
237
+ when loading from an event (which should only occur inside a pipeline
238
+ node, which means the pipeline has been invoked/generated previously).
239
+
240
+ Usage::
241
+
242
+ pipeline_wrapper = PipelineRunWrapper.from_event(event)
243
+ # pipeline_wrapper.load() # TODO deprecate
244
+ """
245
+ wrapper = cls(pipeline_id=event.get('pipeline_id', None),
246
+ execution_id=event.get('execution_id', None),
247
+ current_event=event)
248
+ wrapper.load()
249
+ return wrapper
250
+
251
+ def load(self,
252
+ verify_retry_count: bool = True,
253
+ allow_deadletter: bool = True,
254
+ is_retry: bool = False):
255
+ """ Loads PipelineRunWrapper from cache
256
+
257
+ If verify_retry_count is True, this will deadletter the task wrapper
258
+ immediately (if allow_deadletter=True) when the retry count is exceeded.
259
+ """
260
+ try:
261
+ # Pipeline config is expected to be provided when first initializing
262
+ # a pipeline run wrapper. On subsequent runs or when loading from
263
+ # an event, the run wrapper can be loaded using only the pipeline
264
+ # id and execution id, the pipeline config is then initialized from
265
+ # the wrapper
266
+ if self.pipeline_config is None or is_retry:
267
+ self._load_from_cache(is_retry=is_retry)
268
+ else:
269
+ # If the pipeline_config is set before .load(), that means
270
+ # this invocation is coming from an initial load, not cache.
271
+ # We don't want to re-set pipeline_config and the retry_count
272
+ # and chain_payload are not going to exist, as they are an
273
+ # artifact of the caching process. We also explicitly skip
274
+ # pipeline_id, max_retry, and max_ttl keys because those are
275
+ # metadata keys in the pipeline_config and are camel case
276
+ # (pipelineId/maxRetry/maxTtl), we set them on this wrapper
277
+ # object purely for convenience and to provide logical defaults.
278
+ for key in self._cachable_keys:
279
+ if key in ('pipeline_config', 'pipeline_id', 'max_retry',
280
+ 'max_ttl', 'retry_count', 'chain_payload'):
281
+ continue
282
+ setattr(self, key, self.pipeline_config[key])
283
+
284
+ # Validate pipeline config
285
+ PipelineConfigValidator(config_dict=self.pipeline_config)
286
+
287
+ # Initialize the actual pipeline configuration and execution graph
288
+ self.dag_config = self.pipeline_config['dagAdjacency']
289
+ self.execution_graph = get_execution_graph(self.pipeline_config)
290
+
291
+ # Overload defaults if explicitly provided
292
+ self.max_ttl = self.pipeline_config['metadata'].get(
293
+ 'maxTtl', self.max_ttl)
294
+ self.max_retry = self.pipeline_config['metadata'].get(
295
+ 'maxRetry', self.max_retry)
296
+
297
+ if is_retry:
298
+ self.increment_retry()
299
+
300
+ if verify_retry_count and self.retry_exceeded:
301
+ msg = "Attempted to retry {}_{}; exceeded retry count."\
302
+ .format(self.pipeline_id, self.execution_id)
303
+ logger.warning(msg)
304
+ self.loading_message = msg
305
+ if allow_deadletter:
306
+ self.deadletter()
307
+ return
308
+
309
+ self.save_to_cache() # Always save back to cache
310
+ except Exception as e:
311
+ logger.exception(e)
312
+ self.loading_message = e
313
+ if allow_deadletter:
314
+ self.deadletter()
315
+ return
316
+
317
+ self.loading_message = "Loaded Successfully."
318
+ return
319
+
320
+ def increment_retry(self, exceed_max: bool = False):
321
+ """ Increment retry_count by 1
322
+
323
+ The object is always re-cached after the increment via save_to_cache().
324
+ `exceed_max` allows an instant kickout of this to deadletter.
325
+ """
326
+ if exceed_max:
327
+ new_count = self.max_retry + 1
328
+ else:
329
+ new_count = self.retry_count + 1
330
+
331
+ logger.debug(f"Incrementing Retry to {new_count}")
332
+ self.retry_count = new_count
333
+ self.save_to_cache()
334
+
335
+ @property
336
+ def retry_exceeded(self):
337
+ """ Determine if retry_count has been exceeded.
338
+ """
339
+ logger.debug(f"Checking retry count: {self.retry_count} / "
340
+ f"{self.max_retry} / {self.retry_count > self.max_retry}")
341
+ if self.retry_count >= self.max_retry:
342
+ return True
343
+ return False
344
+
345
+ def deadletter(self):
346
+ """ Add details of this PipelineTask to a deadletter queue.
347
+
348
+ TODO:
349
+ - add to a system for tracking failed pipeline runs
350
+ - delete task wrapper and all tasks from cache
351
+ """
352
+ self.good_to_go = False
353
+ pr = PipelineResult(
354
+ self.execution_id,
355
+ status='failed',
356
+ result='Pipeline retried and failed {} times.'.format(
357
+ self.retry_count))
358
+ pr.save()
359
+ self.increment_retry(
360
+ exceed_max=True) # Ensure this won't be retried...
361
+ return
362
+
363
+
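Inside a pipeline step, the wrapper is typically reconstructed from the injected `event` to reach state shared across the chain. A sketch of a task body under that assumption (the Celery app and task name are not defined in this module):

    @app.task  # assumes a configured Celery app
    def my_pipeline_step(event):
        wrapper = PipelineRunWrapper.from_event(event)  # loads cached run state
        shared = wrapper.chain_payload                  # payload available to every step
        # ... do the step's actual work using `shared` ...
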
364
+ class PipelineResult:
365
+ """ Standard store for pipeline results.
366
+
367
+ Helps keep standard way to store/retrieve results + status messages
368
+ for pipelines.
369
+
370
+ Can get fancier in the future by tracking retry count, pipeline
371
+ execution time, etc.
372
+ """
373
+ def __init__(self,
374
+ execution_id: str,
375
+ status: str = None,
376
+ result: Any = None,
377
+ result_ttl: int = DEFAULT_RESULT_TTL):
378
+ super().__init__()
379
+ self.execution_id = execution_id
380
+ if self.execution_id is None:
381
+ raise ValueError("Must provide an execution_id!")
382
+ self.status = status
383
+ self.result = result
384
+ self.results = result # TODO Deprecate in future release, keep singular
385
+ self.result_ttl = result_ttl
386
+ self.cache_key =\
387
+ PIPELINE_RESULT_CACHE_KEY.format(self.execution_id)
388
+
389
+ self.valid_status_types = ('pending', 'success', 'failed',
390
+ 'unavailable')
391
+
392
+ # Always validate status
393
+ self._validate_status()
394
+
395
+ def _validate_status(self):
396
+ if self.status and self.status not in self.valid_status_types:
397
+ raise ValueError("{} is not a valid status type ({})".format(
398
+ self.status, self.valid_status_types))
399
+
400
+ def save(self, status: str = None, result: Any = None):
401
+ """ Save the result's current state.
402
+
403
+ If status and/or result are not provided, then the existing instance
404
+ state is used. You can override either by passing to this fn.
405
+
406
+ Typical use case would be to initialize the PipelineResult with only
407
+ the execution ID, then 'save_result()' and pass status/result.
408
+ """
409
+ if status is not None:
410
+ self.status = status
411
+ if result is not None:
412
+ self.result = result
413
+ self.results = result # TODO Deprecate in future release
414
+ set_json_config_to_redis(self.cache_key, self.to_dict(),
415
+ self.result_ttl)
416
+
417
+ def load(self):
418
+ """ Load a pipeline result from cache.
419
+ """
420
+ results = load_json_config_from_redis(self.cache_key)
421
+ if results is not None:
422
+ for k in results:
423
+ setattr(self, k, results[k])
424
+ else:
425
+ self.status = 'unavailable'
426
+ self.result = None
427
+ self.results = None # TODO Deprecate in future release
428
+
429
+ @classmethod
430
+ def from_event(cls, event):
431
+ """ Create initialized instance of PipelineResult from a pipeline event.
432
+
433
+ Usage::
434
+
435
+ pipeline_result = PipelineResult.from_event(event)
436
+ pipeline_result.save(
437
+ result='my result value'
438
+ )
439
+ """
440
+ pr = cls(execution_id=event.get('execution_id', None))
441
+ pr.load()
442
+ return pr
443
+
444
+ def to_dict(self):
445
+ """ Return serializable version of result for storage/retrieval.
446
+ """
447
+ return {
448
+ 'execution_id': self.execution_id,
449
+ 'status': self.status,
450
+ 'result': self.result,
451
+ 'results': self.result, # TODO Deprecate in future release
452
+ 'result_ttl': self.result_ttl
453
+ }
454
+
455
+
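The result store is symmetric: a terminal task can save its output keyed by the execution id, and a caller holding that id can read it back later. A sketch assuming a hypothetical task and work function:

    @app.task  # assumes a configured Celery app
    def summarize(event):
        output = do_summary_work(event)                 # hypothetical work function
        pipeline_result = PipelineResult.from_event(event)
        pipeline_result.save(status='success', result=output)

    # Elsewhere (e.g. an API status endpoint), read it back by execution id:
    pr = PipelineResult(execution_id=some_execution_id)  # id returned when the run started
    pr.load()
    # pr.status is 'unavailable' if nothing was cached (expired or never saved)
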
456
+ class PipelineGenerator(object):
457
+ """ Allows an API endpoint to generate a functional pipeline based on the
458
+ requested pipeline id. Allows API to then issue the tasks asynchronously
459
+ to initiate the pipeline. Thereafter, celery will monitor status and
460
+ handle success/failure modes so the API web worker can return
461
+ immediately.
462
+
463
+ The primary purpose is to unpack the pipeline config, create the
464
+ requisite cached entities to track pipeline progress, and apply the
465
+ chained pipeline tasks asynchronously so Celery can take over.
466
+
467
+ Usage:
468
+ gen = PipelineGenerator(pipeline_id)
469
+ chain = gen.generate_chain()
470
+ chain.on_error(custom_error_task.s()) # Optional add error handling
471
+ chain.delay()
472
+ """
473
+ def __init__(self,
474
+ pipeline_id: str,
475
+ access_key: str = None,
476
+ execution_id: str = None,
477
+ queue: str = None,
478
+ default_task_ttl: int = None,
479
+ regulator_queue: str = None,
480
+ regulator_task: str = None,
481
+ success_queue: str = None,
482
+ success_task: str = None,
483
+ retry_task: str = None,
484
+ add_retry: bool = True,
485
+ default_max_retry: int = None,
486
+ chain_payload: dict = None):
487
+ super().__init__()
488
+ self.pipeline_id = pipeline_id
489
+ self.access_key = access_key
490
+
491
+ pipeline_config_api_resp = retrieve_latest_pipeline_config(
492
+ pipeline_id=self.pipeline_id, access_key=self.access_key)
493
+
494
+ if pipeline_config_api_resp is None:
495
+ raise ValueError("Unable to load Pipeline Configuration for "
496
+ f"pipeline id: {self.pipeline_id} ...")
497
+
498
+ # The only part of the API response used for any 'pipeline config'
499
+ # is the `config` key. The API nests it under `config` to preserve
500
+ # ability to add additional detail at a later date.
501
+ self.pipeline_config = pipeline_config_api_resp.get('config', {})
502
+ schema_version = pipeline_config_api_resp.get('schemaVersion')
503
+ PipelineConfigValidator(config_dict=self.pipeline_config,
504
+ schema_version=schema_version)
505
+
506
+ self.execution_id = execution_id # UUID string
507
+ self.good_to_go = False # Indicates initialization/loading success
508
+ self.loading_message = None # Allows access to success/error messages
509
+ self.is_retry = self.execution_id is not None
510
+ self.add_retry = add_retry
511
+ self.retry_task = retry_task\
512
+ if retry_task is not None else DEFAULT_RETRY_TASK
513
+
514
+ self.default_max_retry = default_max_retry \
515
+ if default_max_retry is not None else \
516
+ self.pipeline_config['metadata'].get('maxRetry', DEFAULT_MAX_RETRY)
517
+
518
+ # Queue on which to place tasks by default and default TTL per task
519
+ # These can be overridden in PipelineConfig.config['taskDefinitions']
520
+ self.queue = queue \
521
+ if queue is not None \
522
+ else self.pipeline_config['metadata']['queue']
523
+ self.default_task_ttl = default_task_ttl \
524
+ if default_task_ttl is not None else \
525
+ self.pipeline_config['metadata'].get('maxTtl', DEFAULT_TASK_TTL)
526
+
527
+ # See docstring in self._get_regulator()
528
+ self.regulator_queue = regulator_queue \
529
+ if regulator_queue is not None \
530
+ else self.pipeline_config['metadata']['queue']
531
+ self.regulator_task = regulator_task\
532
+ if regulator_task is not None else DEFAULT_REGULATOR_TASK
533
+
534
+ # See docstring in self._get_success_task()
535
+ self.success_queue = success_queue \
536
+ if success_queue is not None \
537
+ else self.pipeline_config['metadata']['queue']
538
+ self.success_task = success_task\
539
+ if success_task is not None else DEFAULT_SUCCESS_TASK
540
+
541
+ # Optional data to pass to each step in chain
542
+ self.chain_payload = chain_payload\
543
+ if chain_payload is not None else {}
544
+
545
+ self.pipeline_wrapper = None # Allows access to the PipelineRunWrapper
546
+ self.chain = None # Must be intentionally built with generate_chain()
547
+
548
+ try:
549
+ # Generate our wrapper for this pipeline_id / execution_id
550
+ self.pipeline_wrapper = PipelineRunWrapper(
551
+ pipeline_id=self.pipeline_id,
552
+ pipeline_config=self.pipeline_config,
553
+ execution_id=self.execution_id,
554
+ max_ttl=self.default_task_ttl,
555
+ max_retry=self.default_max_retry,
556
+ chain_payload=self.chain_payload)
557
+
558
+ # Loads pipeline config from remote or cache if it's already there
559
+ # `is_retry` will be True for any PipelineGenerator instantiated
560
+ # with an execution_id. This flag helps the wrapper increment the
561
+ # retry count and determine if this should be deadlettered.
562
+ # This step also saves the valid/initialized run wrapper to cache.
563
+ self.pipeline_wrapper.load(is_retry=self.is_retry)
564
+
565
+ # Set all variables that were established from the run wrapper
566
+ # initialization. Notably, default_task_ttl can be overloaded
567
+ # if the pipeline config has an explicit maxTtl set in metadata.
568
+ self.good_to_go = self.pipeline_wrapper.good_to_go
569
+ self.loading_message = self.pipeline_wrapper.loading_message
570
+ self.execution_id = self.pipeline_wrapper.execution_id
571
+
572
+ except Exception as e:
573
+ fail_msg = "Failed to load Pipeline for id {} ... {}".format(
574
+ self.pipeline_id, e)
575
+ self.loading_message = fail_msg
576
+ logger.error(fail_msg)
577
+ raise e
578
+
579
+ def _get_regulator(self):
580
+ """ Create a chain regulator celery task signature.
581
+
582
+ For a chain(), if each element is a group() then celery does not
583
+ properly adhere to the chain elements occurring sequentially. If you
584
+ insert a task that is not a group() in between, though, then the
585
+ chain operates as expected.
586
+ """
587
+ return signature(self.regulator_task,
588
+ queue=self.regulator_queue,
589
+ immutable=True)
590
+
591
+ def _get_success_task(self):
592
+ """ A final 'success' task that's added to the end of every pipeline.
593
+
594
+ This stores the 'success' state in the cached result. Users can
595
+ set other values by using TaskRunner().save_result()
596
+ """
597
+ return get_task_signature(task_path=self.success_task,
598
+ queue=self.success_queue,
599
+ pipeline_id=self.pipeline_id,
600
+ execution_id=self.execution_id)
601
+
602
+ def _get_retry_task(self):
603
+ """ The retry task will re-invoke a chain.
604
+ """
605
+ return get_task_signature(task_path=self.retry_task,
606
+ queue=self.queue,
607
+ access_key=self.access_key,
608
+ pipeline_id=self.pipeline_id,
609
+ execution_id=self.execution_id,
610
+ max_ttl=DEFAULT_RETRY_TASK_MAX_TTL,
611
+ custom_event_data={
612
+ 'queue': self.queue,
613
+ 'default_task_ttl':
614
+ self.default_task_ttl,
615
+ 'add_retry': self.add_retry,
616
+ 'chain_payload': self.chain_payload
617
+ })
618
+
619
+ def _get_signature(self, node):
620
+ """ Create a celery task signature based on a graph node.
621
+ """
622
+ metadata = self.pipeline_config['metadata']
623
+ node_config = self.pipeline_config['taskDefinitions'][node]
624
+
625
+ # Node config takes precedence, pipeline metadata as default
626
+ queue = node_config.get('queue', metadata['queue'])
627
+ max_ttl = node_config.get('maxTtl', metadata.get('maxTtl', None))
628
+
629
+ # Ensures task signatures include requisite information to retrieve
630
+ # PipelineRunWrapper from cache using the pipeline id, and execution id.
631
+ # We set immutable=True to ensure each client task can be defined
632
+ # with this specific signature (event)
633
+ # http://docs.celeryproject.org/en/master/userguide/canvas.html#immutability
634
+ return get_task_signature(task_path=node_config.get('handler'),
635
+ queue=queue,
636
+ access_key=self.access_key,
637
+ pipeline_id=self.pipeline_id,
638
+ execution_id=self.execution_id,
639
+ max_ttl=max_ttl,
640
+ immutable=True,
641
+ task_config=node_config)
642
+
643
+ def generate_chain(self):
644
+ """ Generate the full pipeline chain.
645
+ """
646
+ logger.debug(f'Starting Pipeline {self.pipeline_id}')
647
+
648
+ if not self.good_to_go:
649
+ logger.info("Chain deemed to be not good to go.")
650
+ if self.loading_message is None:
651
+ self.loading_message = CHAIN_FAILURE_MSG
652
+ return None
653
+
654
+ try:
655
+ # Create the task chain such that all concurrent tasks are grouped
656
+ # and all high level node groups are run serially
657
+ G = self.pipeline_wrapper.execution_graph
658
+
659
+ total_tasks = 0
660
+ pipeline_chain = []
661
+ chainable_tasks = get_chainable_tasks(G, None, [])
662
+
663
+ # Current chord+chain solution based on
664
+ # https://stackoverflow.com/questions/15123772/celery-chaining-groups-and-subtasks-out-of-order-execution
665
+ # Look also at last comment from Nov 7, 2017 here
666
+ # https://github.com/celery/celery/issues/3597
667
+ # Big outstanding bug in Celery related to failures in chords that
668
+ # results in really nasty log output. See
669
+ # https://github.com/celery/celery/issues/4834
670
+ for i, node_group in enumerate(chainable_tasks):
671
+ total_tasks += len(node_group)
672
+ this_group = []
673
+ for node in node_group:
674
+ node_signature = self._get_signature(node)
675
+ this_group.append(node_signature)
676
+
677
+ if len(this_group) <= 1:
678
+ this_group.append(self._get_regulator())
679
+
680
+ the_chord = chord(header=this_group,
681
+ body=self._get_regulator())
682
+
683
+ pipeline_chain.append(the_chord)
684
+
685
+ # Add a 'finished/success' task to the end of all pipelines
686
+ pipeline_chain.append(
687
+ chord(header=self._get_success_task(),
688
+ body=self._get_regulator()))
689
+
690
+ the_chain = chain(*pipeline_chain)
691
+
692
+ # Add retry
693
+ if self.add_retry:
694
+ the_chain.link_error(self._get_retry_task())
695
+
696
+ self.loading_message = CHAIN_SUCCESS_MSG
697
+
698
+ self.chain = the_chain
699
+ except Exception as e:
700
+ self.loading_message = CHAIN_FAILURE_MSG + " {}".format(e)
701
+ logger.exception(e)
702
+ the_chain = None
703
+
704
+ self.chain = the_chain
705
+
706
+ return the_chain
707
+
708
+
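Putting the pieces together, an API handler might kick off a run roughly as in the class docstring above; this sketch adds result polling via the execution id (the pipeline id and payload are placeholders):

    gen = PipelineGenerator('my-pipeline-id',                      # hypothetical id
                            chain_payload={'source': 's3://bucket/key'})
    the_chain = gen.generate_chain()
    if the_chain is not None:
        the_chain.delay()
        execution_id = gen.execution_id    # return this to the caller for polling
    else:
        logger.error(gen.loading_message)  # chain could not be built
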
709
+ class TaskRunner:
710
+ """ Run tasks in Sermos
711
+ """
712
+ @classmethod
713
+ def save_result(cls):
714
+ """ Save a task result
715
+ """
716
+ # TODO Implement
717
+
718
+ @classmethod
719
+ def publish_work(cls,
720
+ task_path: str,
721
+ task_payload: dict,
722
+ queue: str = None,
723
+ max_ttl: int = None):
724
+ """ Uniform way to issue a task to another celery worker.
725
+
726
+ Args:
727
+ task_path (str): Full path to task intended to run. e.g.
728
+ sermos_company_client.workers.my_work.task_name
729
+ task_payload (dict): A dictionary containing whatever payload
730
+ the receiving task expects. This is merged into the `event`
731
+ argument for the receiving task such that any top level
732
+ keys in your `task_payload` are found at event['the_key']
733
+ queue (str): The queue on which to place this task.
734
+ Ensure there are workers available to accept work on
735
+ that queue.
736
+ max_ttl (int): Optional. Max time to live for the issued task.
737
+ If not specified, system default is used.
738
+ """
739
+ try:
740
+ # TODO consider whether to add access key/deployment id here
741
+ worker = get_task_signature(task_path=task_path,
742
+ queue=queue,
743
+ max_ttl=max_ttl,
744
+ custom_event_data=task_payload)
745
+ worker.delay()
746
+ except Exception as e:
747
+ logger.error(f"Failed to publish work ... {e}")
748
+ return False
749
+
750
+ return True
751
+
752
+ @classmethod
753
+ def publish_work_in_batches(cls,
754
+ task_path: str,
755
+ task_payload_list: List[dict],
756
+ queue: str,
757
+ grouping_key: str = 'tasks',
758
+ max_per_task: int = 5,
759
+ max_ttl: int = None):
760
+ """ Uniform way to issue tasks to celery in 'batches'.
761
+
762
+ This allows work to be spread over multiple workers, each worker is
763
+ able to consume one or more messages in a single task.
764
+
765
+ Args:
766
+ task_path (str): Full path to task intended to run. e.g.
767
+ sermos_company_client.workers.my_work.task_name
768
+ task_payload_list (list): A list of dictionaries containing
769
+ whatever payload the receiving task expects. This is broken
770
+ into batches according to `max_per_task` and nested under
771
+ the `grouping_key` in the `event` argument for the receiving
772
+ task such that the payload dicts are found under event[grouping_key]
773
+ queue (str): The queue on which to place this task.
774
+ Ensure there are workers available to accept work on
775
+ that queue.
776
+ grouping_key (str): Default: tasks. Sets the key name under the
777
+ receiving task's `event` where the payload items are found.
778
+ max_per_task (int): Default: 5. Maximum number of tasks from the
779
+ `task_payload_list` that will be bundled under the `grouping_key`
780
+ and issued as a single task to the receiving worker.
781
+ max_ttl (int): Optional. Max time to live for the issued task.
782
+ If not specified, system default is used.
783
+ """
784
+ try:
785
+ if len(task_payload_list) > 0:
786
+ for idx in range(len(task_payload_list)):
787
+ if idx % max_per_task == 0:
788
+ custom_event_data = {
789
+ grouping_key:
790
+ task_payload_list[idx:idx + max_per_task]
791
+ }
792
+ # TODO consider whether to add access key/deployment id here
793
+ worker = get_task_signature(
794
+ task_path=task_path,
795
+ queue=queue,
796
+ max_ttl=max_ttl,
797
+ custom_event_data=custom_event_data)
798
+ worker.delay()
799
+ except Exception as e:
800
+ logger.error(f"Failed to publish work in batches ... {e}")
801
+ return False
802
+
803
+ return True
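
As a worked example of the batching arithmetic above (all names hypothetical): 12 payloads with max_per_task=5 produce three tasks, receiving 5, 5, and 2 payload dicts respectively under event['documents'].

    payloads = [{'document_id': i} for i in range(12)]
    TaskRunner.publish_work_in_batches(
        task_path='my_client.workers.tasks.process_documents',  # hypothetical
        task_payload_list=payloads,
        queue='document-work',                                   # hypothetical
        grouping_key='documents',
        max_per_task=5)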