scalable-pypeline 1.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,197 @@
1
+ """ Schemas for Pipelines
2
+
3
+ TODO: Add validation that all specified nodes in DAG have corresponding
4
+ node in taskDefinitions
5
+ """
6
+ import yaml
7
+ from marshmallow import Schema, fields, EXCLUDE, validates_schema
8
+ from marshmallow.exceptions import ValidationError
9
+
10
+
11
class ExcludeUnknownSchema(Schema):
    """Base schema configured with ``unknown = EXCLUDE``.

    Schemas that inherit from this class pick up the same ``Meta`` option.
    """

    class Meta:
        """Schema options: unknown input keys are handled with EXCLUDE."""

        unknown = EXCLUDE
18
+
19
+
20
class MetadataSchema(Schema):
    """Schema describing the ``metadata`` object of a pipeline.

    Only ``queue`` is required; ``maxRetry`` and ``maxTtl`` are optional.
    """

    # Required queue name.
    queue = fields.String(
        required=True,
        description="Default queue for all pipeline tasks.",
        example="default-queue-name",
    )

    # Optional retry limit.
    # NOTE(review): ``default`` is presumably marshmallow's dump-time
    # default rather than a load-time one - confirm against the pinned
    # marshmallow version.
    maxRetry = fields.Integer(
        required=False,
        default=3,
        description="Max number of retries for a pipeline.",
        example=3,
    )

    # Optional TTL in seconds.
    maxTtl = fields.Integer(
        required=False,
        default=60,
        description="Max TTL for a pipeline in seconds.",
        example=60,
    )
37
+
38
+
39
class TaskDefinitionsSchema(ExcludeUnknownSchema):
    """Schema for the configuration of one task.

    ``handler`` is mandatory; ``maxTtl`` and ``queue`` may be omitted.
    """

    # Required path to the worker task.
    handler = fields.String(
        required=True,
        description="Path to the worker task definition",
        example="client.workers.my_task",
    )

    # Optional TTL in seconds for this task.
    maxTtl = fields.Integer(
        required=False,
        default=60,
        description="Max TTL for a task in seconds.",
        example=60,
    )

    # Optional queue name for this task.
    queue = fields.String(
        required=False,
        description="Non-default queue for this task.",
        example="custom-queue-name",
    )

    # Field ideas that are not enabled yet:
    # payload_args = fields.List(
    #     fields.Dict(keys=fields.String(),
    #                 values=fields.Nested(PayloadArgKwargSchema)))
    # payload_kwargs = fields.List(
    #     fields.Dict(keys=fields.String(),
    #                 values=fields.Nested(PayloadArgKwargSchema)))
    # model_version = fields.String()
    # arbitrary other stuff passed to task?
62
+
63
+
64
class PipelineConfigSchemaV1(Schema):
    """Version 1 schema for the complete pipeline configuration.

    Three required sections: ``metadata``, ``dagAdjacency`` and
    ``taskDefinitions``.
    """

    # Nested metadata section, validated by MetadataSchema.
    metadata = fields.Nested(
        MetadataSchema,
        required=True,
        description=(
            "Metadata and configuration information for this pipeline."
        ),
    )

    # Mapping of string keys to lists of strings.
    dagAdjacency = fields.Dict(
        required=True,
        description="The DAG Adjacency definition.",
        keys=fields.String(
            required=True,
            description=(
                "Task's node name. *MUST* match key in taskDefinitions dict."
            ),
            example="node_a",
        ),
        values=fields.List(
            fields.String(
                required=True,
                description=(
                    "Task's node name. *Must* match key in taskDefinitions "
                    "dict."
                ),
            ),
        ),
    )

    # Mapping of string keys to nested TaskDefinitionsSchema objects.
    taskDefinitions = fields.Dict(
        required=True,
        description="Configuration for each node defined in DAG.",
        keys=fields.String(
            required=True,
            description=(
                "Task's node name. *Must* match related key in dagAdjacency."
            ),
            example="node_a",
        ),
        values=fields.Nested(
            TaskDefinitionsSchema,
            required=True,
            description="Definition of each task in the pipeline.",
            example={
                'handler': 'abc.task',
                'maxRetry': 1
            },
        ),
    )
102
+
103
+
104
class BasePipelineSchema(ExcludeUnknownSchema):
    """Version-agnostic envelope schema for a pipeline.

    Direct subclasses declare ``__schema_version__``; the class methods
    below look subclasses up by that attribute.
    """

    # The base class has no version of its own; subclasses override this.
    __schema_version__ = None

    name = fields.String(required=True, description="Pipeline name")
    description = fields.String(required=False, missing=None,
                                description="Description of the pipeline.",
                                example="A valuable pipeline.")
    schemaVersion = fields.Integer(required=True)
    config = fields.Dict(required=True)

    @classmethod
    def get_by_version(cls, version):
        """Return the direct subclass whose version equals ``version``.

        Returns ``None`` when no direct subclass matches.
        """
        for subclass in cls.__subclasses__():
            if subclass.__schema_version__ == version:
                return subclass

        return None

    @classmethod
    def get_latest(cls):
        """Return the direct subclass with the highest schema version.

        Only versions greater than 0 are considered; subclasses without a
        version (``None``) are skipped. Returns ``None`` if nothing
        qualifies.
        """
        max_version = 0
        max_class = None
        for subclass in cls.__subclasses__():
            version = subclass.__schema_version__
            if version is not None and version > max_version:
                # Track the running maximum so that a later, lower-versioned
                # subclass cannot replace the current best candidate.
                max_version = version
                max_class = subclass

        return max_class

    @validates_schema
    def validate_pipeline(self, data, **kwargs):
        """Validate ``data`` against the schema matching its schemaVersion.

        Raises:
            ValidationError: if no schema is registered for the requested
                version, or if the versioned schema rejects ``data``.
        """
        schema_version = data['schemaVersion']
        PipelineSchema = BasePipelineSchema.get_by_version(schema_version)
        if PipelineSchema is None:
            # Without this guard the call below would be ``None(...)`` and
            # fail with an unhelpful TypeError.
            raise ValidationError(
                f"Unsupported schemaVersion: {schema_version}",
                field_name='schemaVersion')
        schema = PipelineSchema(exclude=['name', 'description'])
        schema.load(data)
139
+
140
class PipelineSchemaV1(BasePipelineSchema):
    """Pipeline schema for ``schemaVersion`` 1.

    ``config`` is validated with PipelineConfigSchemaV1.
    """

    __schema_version__ = 1

    class Meta:
        """Schema options: unknown input keys are handled with EXCLUDE."""

        unknown = EXCLUDE

    config = fields.Nested(
        PipelineConfigSchemaV1,
        required=True,
        description=(
            "Metadata and configuration information for this pipeline."
        ),
    )

    def validate_pipeline(self, data, **kwargs):
        """Intentionally do nothing.

        BasePipelineSchema validates through a method of the same name that
        loads this schema; overriding it here avoids infinite recursion.
        """
        return None
157
+
158
+
159
class PipelineConfigValidator(object):
    """ Validate a pipeline configuration.

    This is stored as a string in the database under `PipelineConfig.config`
    in order to keep it easy for custom features to be added over time.
    This model represents the required / valid features so we can
    programmatically validate when saving, updating, viewing.

    Attributes set by the constructor:
        config: the configuration as a dictionary.
        is_valid: ``True`` once the configuration loaded without errors.
        validated_config: the loaded configuration (empty dict on failure).
        validation_errors: marshmallow error messages (empty on success).
    """
    def __init__(self, config_dict: dict = None, config_yaml: str = None,
                 schema_version: int = None):
        """Load and validate the configuration.

        Args:
            config_dict: configuration as a dictionary; takes precedence
                over ``config_yaml`` when both are given.
            config_yaml: configuration as a YAML string.
            schema_version: schema version to validate against; the latest
                registered version is used when omitted.

        Raises:
            ValueError: if neither ``config_dict`` nor ``config_yaml`` is
                provided, or no schema exists for ``schema_version``.
            ValidationError: if the configuration does not match the schema.
        """
        super().__init__()

        self.is_valid = False
        self.validated_config = {}
        self.validation_errors = {}

        # We validate this as a dictionary. Turn into dictionary if provided
        # as yaml.
        if config_dict is not None:
            self.config = config_dict
        elif config_yaml is not None:
            self.config = yaml.safe_load(config_yaml)
        else:
            # Previously ``self.config`` was left unset and the failure only
            # surfaced later as an AttributeError.
            raise ValueError(
                "Either config_dict or config_yaml must be provided.")

        if schema_version is None:
            PipelineSchema = BasePipelineSchema.get_latest()
        else:
            PipelineSchema = BasePipelineSchema.get_by_version(schema_version)

        if PipelineSchema is None:
            raise ValueError(
                f"No pipeline schema found for version: {schema_version}")

        try:
            # https://github.com/marshmallow-code/marshmallow/issues/377
            # See issue above when migrating to marshmallow 3
            pcs = PipelineSchema._declared_fields['config'].schema
            self.validated_config = pcs.load(self.config)
        except ValidationError as e:
            # Keep the messages available on the instance, then propagate.
            self.validation_errors = e.messages
            raise
        else:
            self.is_valid = True
@@ -0,0 +1,210 @@
1
+ """ Schemas for Schedule Configuration
2
+ """
3
+ import re
4
+ from celery.schedules import crontab_parser
5
+ from croniter import croniter
6
+ from marshmallow.validate import OneOf
7
+ from marshmallow.exceptions import ValidationError
8
+ from marshmallow import Schema, fields, EXCLUDE, pre_load, validates_schema
9
+
10
+
11
class ExcludeUnknownSchema(Schema):
    """Base schema configured with ``unknown = EXCLUDE``.

    # TODO this seems to be just ignoring and letting through vs excluding...
    """

    class Meta:
        """Schema options shared by every subclass."""

        unknown = EXCLUDE
18
+
19
+
20
class IntervalScheduleSchema(Schema):
    """Schema for an interval schedule: a count plus a period unit."""

    # Required integer count.
    every = fields.Integer(required=True)

    # Required unit; must be one of the listed period names.
    period = fields.String(
        required=True,
        validate=OneOf([
            'microseconds',
            'seconds',
            'minutes',
            'hours',
            'days',
        ]),
    )
26
+
27
+
28
class CrontabScheduleSchema(Schema):
    """Schema for a crontab schedule made of five required string fields."""

    minute = fields.String(required=True)
    hour = fields.String(required=True)
    dayOfWeek = fields.String(required=True)
    dayOfMonth = fields.String(required=True)
    monthOfYear = fields.String(required=True)

    @validates_schema
    def validate_values(self, data, **kwargs):
        """Reject empty values and invalid cron expressions.

        Raises:
            ValidationError: if any field is ``None`` or if croniter does
                not accept the assembled expression.
        """
        # Order of the five fields in the expression handed to croniter.
        cron_order = ('minute', 'hour', 'dayOfMonth', 'monthOfYear',
                      'dayOfWeek')

        if any(data[name] is None for name in cron_order):
            raise ValidationError("Empty crontab value")

        test_cron_expression = " ".join(
            str(data[name]) for name in cron_order)

        if not croniter.is_valid(test_cron_expression):
            # This must be raised; returning the exception object would
            # silently accept an invalid expression.
            raise ValidationError("Invalid crontab value")
48
+
49
+
50
class Schedule(fields.Dict):
    """Dict field whose content is loaded through a schedule schema.

    The schema is picked from the ``scheduleType`` entry of the data being
    loaded.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value`` unchanged."""
        return value

    def _deserialize(self, value, attr, data, **kwargs):
        """Load ``value`` with the schema matching ``data['scheduleType']``.

        ``'crontab'`` selects CrontabScheduleSchema; any other value selects
        IntervalScheduleSchema.
        """
        is_crontab = data['scheduleType'] == 'crontab'
        schema_cls = (CrontabScheduleSchema if is_crontab
                      else IntervalScheduleSchema)
        return schema_cls().load(value)
60
+
61
+
62
class ScheduleConfigSchemaV1(ExcludeUnknownSchema):
    """ Definition of a single schedule entry

    TODO: Add validation based on schedule_type and the relevant optional fields
    TODO: Add validation that each name is unique
    """

    scheduleType = fields.String(
        required=True,
        validate=OneOf(['interval', 'crontab']),
        description="The Celery schedule type of this entry.",
        example="interval",
        data_key='scheduleType')

    queue = fields.String(required=True,
                          description="Name of queue on which to place task.",
                          example="my-default-queue")
    task = fields.String(required=True,
                         description="Path to task to invoke.",
                         example="my_app.module.method")
    exchange = fields.String(
        required=False,
        description="Exchange for the task. Celery default "
        "used if not set, which is recommended.",
        example="tasks")
    routing_key = fields.String(
        required=False,
        description="Routing key for the task. Celery "
        "default used if not set, which is recommended.",
        example="task.default",
        data_key='routingKey')
    expires = fields.Integer(
        required=False,
        description="Number of seconds after which task "
        "expires if not executed. Default: no expiration.",
        example=60)

    schedule = Schedule(required=True)

    @pre_load
    def validate_string_fields(self, item, **kwargs):
        """ Ensure string fields with no OneOf validation conform to patterns

        For crontab entries, each crontab field of ``schedule`` is also
        parsed with the matching celery ``crontab_parser``.

        Returns:
            The unmodified ``item``.

        Raises:
            ValidationError: if ``item`` is ``None``, a checked field is not
                a string matching its pattern, ``scheduleType`` is missing,
                or a crontab field cannot be parsed.
        """
        if item is None:
            raise ValidationError("NoneType provided, check input.")

        validation_map = {
            'name': r'^[\w\d\-\_\.\s]+$',
            'queue': r'^[\w\d\-\_\.]+$',
            'task': r'^[\w\d\-\_\.]+$',
            'exchange': r'^[\w\d\-\_\.]+$',
            'routing_key': r'^[\w\d\-\_\.]+$',
            # The routing key arrives under its data_key in raw input, so it
            # has to be checked under that name to be validated at all.
            'routingKey': r'^[\w\d\-\_\.]+$',
        }
        for field, pattern in validation_map.items():
            value = item.get(field, None)
            if value is None:
                continue
            # Non-string values cannot match; report them as a validation
            # failure rather than letting re.match raise a TypeError.
            if not isinstance(value, str) or not re.match(pattern, value):
                raise ValidationError(
                    f"Invalid {field}: `{value}``. Must match pattern: "
                    f"{pattern}")

        if 'scheduleType' not in item:
            raise ValidationError('Missing required field scheduleType')

        if item['scheduleType'] == 'crontab':
            schedule = item.get('schedule')
            if not isinstance(schedule, dict):
                raise ValidationError(
                    'Missing or invalid schedule for crontab scheduleType')

            cron_validation_map = {
                'minute': crontab_parser(60),
                'hour': crontab_parser(24),
                'dayOfWeek': crontab_parser(7),
                'dayOfMonth': crontab_parser(31, 1),
                'monthOfYear': crontab_parser(12, 1)
            }

            for field, parser in cron_validation_map.items():
                # A missing key yields None, which then fails parsing and is
                # reported like any other invalid value.
                cron_value = schedule.get(field)
                try:
                    parser.parse(cron_value)
                except Exception as exc:
                    raise ValidationError(
                        f"Invalid {field}: `{cron_value}`. Must "
                        "be valid crontab pattern.") from exc

        return item
144
+
145
+
146
class BaseScheduleSchema(ExcludeUnknownSchema):
    """Version-agnostic envelope schema for a schedule entry.

    Direct subclasses declare ``__schema_version__``; the class methods
    below look subclasses up by that attribute.
    """

    __schema_version__ = 0

    name = fields.String(required=True,
                         description="Name of schedule entry.",
                         example="My Scheduled Task")
    schemaVersion = fields.Integer(required=True)
    config = fields.Dict(required=True)
    enabled = fields.Boolean(required=True,
                             description="Whether entry is enabled.",
                             example=True)
    # TODO Figure out where that wonky timestamp format is coming from and
    # update this and in celery_beat.py.
    lastRunAt = fields.DateTime(allow_none=True,
                                missing=None,
                                description="Timestamp of last run time.",
                                example="Tue, 18 Aug 2020 01:36:06 GMT",
                                data_key='lastRunAt')
    totalRunCount = fields.Integer(
        allow_none=True,
        missing=0,
        description="Count of number of executions.",
        example=12345,
        data_key='totalRunCount')

    @classmethod
    def get_by_version(cls, version):
        """Return the direct subclass whose version equals ``version``.

        Returns ``None`` when no direct subclass matches.
        """
        for subclass in cls.__subclasses__():
            if subclass.__schema_version__ == version:
                return subclass

        return None

    @classmethod
    def get_latest(cls):
        """Return the direct subclass with the highest schema version.

        Only versions greater than 0 are considered; subclasses without a
        version (``None``) are skipped. Returns ``None`` if nothing
        qualifies.
        """
        max_version = 0
        max_class = None
        for subclass in cls.__subclasses__():
            version = subclass.__schema_version__
            if version is not None and version > max_version:
                # Track the running maximum so that a later, lower-versioned
                # subclass cannot replace the current best candidate.
                max_version = version
                max_class = subclass

        return max_class

    @validates_schema
    def validate_scheduled_tasks(self, data, **kwargs):
        """Validate ``data`` against the schema matching its schemaVersion.

        Raises:
            ValidationError: if no schema is registered for the requested
                version, or if the versioned schema rejects ``data``.
        """
        schema_version = data['schemaVersion']
        TaskSchema = BaseScheduleSchema.get_by_version(schema_version)
        if TaskSchema is None:
            # Without this guard the call below would be ``None()`` and fail
            # with an unhelpful TypeError.
            raise ValidationError(
                f"Unsupported schemaVersion: {schema_version}",
                field_name='schemaVersion')
        # NOTE(review): ``data`` is presumably the already-deserialized
        # payload, so values such as ``lastRunAt`` may no longer be strings
        # when reloaded here - confirm.
        schema = TaskSchema()
        schema.load(data)
196
+
197
+
198
class ScheduleSchemaV1(BaseScheduleSchema):
    """Schedule schema for ``schemaVersion`` 1.

    ``config`` is validated with ScheduleConfigSchemaV1.
    """

    __schema_version__ = 1

    config = fields.Nested(
        ScheduleConfigSchemaV1,
        required=True,
        description="Configuration information for this schedule.",
    )

    def validate_scheduled_tasks(self, data, **kwargs):
        """Intentionally do nothing.

        BaseScheduleSchema validates through a method of the same name that
        loads this schema; overriding it here avoids infinite recursion.
        """
        return None