pydocket 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydocket might be problematic. Click here for more details.

docket/execution.py CHANGED
@@ -1,10 +1,16 @@
1
+ import abc
2
+ import enum
1
3
  import inspect
4
+ import logging
2
5
  from datetime import datetime
3
- from typing import Any, Awaitable, Callable, Self
6
+ from typing import Any, Awaitable, Callable, Hashable, Literal, Self, cast
4
7
 
5
- import cloudpickle
8
+ import cloudpickle # type: ignore[import]
6
9
 
7
- from docket.annotations import Logged
10
+
11
+ from .annotations import Logged
12
+
13
+ logger: logging.Logger = logging.getLogger(__name__)
8
14
 
9
15
  Message = dict[bytes, bytes]
10
16
 
@@ -31,8 +37,8 @@ class Execution:
31
37
  b"key": self.key.encode(),
32
38
  b"when": self.when.isoformat().encode(),
33
39
  b"function": self.function.__name__.encode(),
34
- b"args": cloudpickle.dumps(self.args),
35
- b"kwargs": cloudpickle.dumps(self.kwargs),
40
+ b"args": cloudpickle.dumps(self.args), # type: ignore[arg-type]
41
+ b"kwargs": cloudpickle.dumps(self.kwargs), # type: ignore[arg-type]
36
42
  b"attempt": str(self.attempt).encode(),
37
43
  }
38
44
 
@@ -72,3 +78,266 @@ class Execution:
72
78
  arguments.append(f"{parameter_name}=...")
73
79
 
74
80
  return f"{function_name}({', '.join(arguments)}){{{self.key}}}"
81
+
82
+
83
+ class Operator(enum.StrEnum):
84
+ EQUAL = "=="
85
+ NOT_EQUAL = "!="
86
+ GREATER_THAN = ">"
87
+ GREATER_THAN_OR_EQUAL = ">="
88
+ LESS_THAN = "<"
89
+ LESS_THAN_OR_EQUAL = "<="
90
+ BETWEEN = "between"
91
+
92
+
93
+ LiteralOperator = Literal["==", "!=", ">", ">=", "<", "<=", "between"]
94
+
95
+
96
+ class StrikeInstruction(abc.ABC):
97
+ direction: Literal["strike", "restore"]
98
+ operator: Operator
99
+
100
+ def __init__(
101
+ self,
102
+ function: str | None,
103
+ parameter: str | None,
104
+ operator: Operator,
105
+ value: Hashable,
106
+ ) -> None:
107
+ self.function = function
108
+ self.parameter = parameter
109
+ self.operator = operator
110
+ self.value = value
111
+
112
+ def as_message(self) -> Message:
113
+ message: dict[bytes, bytes] = {b"direction": self.direction.encode()}
114
+ if self.function:
115
+ message[b"function"] = self.function.encode()
116
+ if self.parameter:
117
+ message[b"parameter"] = self.parameter.encode()
118
+ message[b"operator"] = self.operator.encode()
119
+ message[b"value"] = cloudpickle.dumps(self.value) # type: ignore[arg-type]
120
+ return message
121
+
122
+ @classmethod
123
+ def from_message(cls, message: Message) -> "StrikeInstruction":
124
+ direction = cast(Literal["strike", "restore"], message[b"direction"].decode())
125
+ function = message[b"function"].decode() if b"function" in message else None
126
+ parameter = message[b"parameter"].decode() if b"parameter" in message else None
127
+ operator = cast(Operator, message[b"operator"].decode())
128
+ value = cloudpickle.loads(message[b"value"])
129
+ if direction == "strike":
130
+ return Strike(function, parameter, operator, value)
131
+ else:
132
+ return Restore(function, parameter, operator, value)
133
+
134
+ def as_span_attributes(self) -> dict[str, str]:
135
+ span_attributes: dict[str, str] = {}
136
+ if self.function:
137
+ span_attributes["docket.function"] = self.function
138
+
139
+ if self.parameter:
140
+ span_attributes["docket.parameter"] = self.parameter
141
+ span_attributes["docket.operator"] = self.operator
142
+ span_attributes["docket.value"] = repr(self.value)
143
+
144
+ return span_attributes
145
+
146
+ def call_repr(self) -> str:
147
+ return (
148
+ f"{self.function or '*'}"
149
+ "("
150
+ f"{self.parameter or '*'}"
151
+ " "
152
+ f"{self.operator}"
153
+ " "
154
+ f"{repr(self.value) if self.parameter else '*'}"
155
+ ")"
156
+ )
157
+
158
+
159
+ class Strike(StrikeInstruction):
160
+ direction: Literal["strike", "restore"] = "strike"
161
+
162
+
163
+ class Restore(StrikeInstruction):
164
+ direction: Literal["strike", "restore"] = "restore"
165
+
166
+
167
+ MinimalStrike = tuple[Operator, Hashable]
168
+ ParameterStrikes = dict[str, set[MinimalStrike]]
169
+ TaskStrikes = dict[str, ParameterStrikes]
170
+
171
+
172
+ class StrikeList:
173
+ task_strikes: TaskStrikes
174
+ parameter_strikes: ParameterStrikes
175
+
176
+ def __init__(self) -> None:
177
+ self.task_strikes = {}
178
+ self.parameter_strikes = {}
179
+
180
+ def is_stricken(self, execution: Execution) -> bool:
181
+ """
182
+ Checks if an execution is stricken based on task name or parameter values.
183
+
184
+ Returns:
185
+ bool: True if the execution is stricken, False otherwise.
186
+ """
187
+ function_name = execution.function.__name__
188
+
189
+ # Check if the entire task is stricken (without parameter conditions)
190
+ task_strikes = self.task_strikes.get(function_name, {})
191
+ if function_name in self.task_strikes and not task_strikes:
192
+ return True
193
+
194
+ sig = inspect.signature(execution.function)
195
+
196
+ try:
197
+ bound_args = sig.bind(*execution.args, **execution.kwargs)
198
+ bound_args.apply_defaults()
199
+ except TypeError:
200
+ # If we can't make sense of the arguments, just assume the task is fine
201
+ return False
202
+
203
+ all_arguments = {
204
+ **bound_args.arguments,
205
+ **{
206
+ k: v
207
+ for k, v in execution.kwargs.items()
208
+ if k not in bound_args.arguments
209
+ },
210
+ }
211
+
212
+ for parameter, argument in all_arguments.items():
213
+ for strike_source in [task_strikes, self.parameter_strikes]:
214
+ if parameter not in strike_source:
215
+ continue
216
+
217
+ for operator, strike_value in strike_source[parameter]:
218
+ if self._is_match(argument, operator, strike_value):
219
+ return True
220
+
221
+ return False
222
+
223
+ def _is_match(self, value: Any, operator: Operator, strike_value: Any) -> bool:
224
+ """Determines if a value matches a strike condition."""
225
+ try:
226
+ match operator:
227
+ case "==":
228
+ return value == strike_value
229
+ case "!=":
230
+ return value != strike_value
231
+ case ">":
232
+ return value > strike_value
233
+ case ">=":
234
+ return value >= strike_value
235
+ case "<":
236
+ return value < strike_value
237
+ case "<=":
238
+ return value <= strike_value
239
+ case "between": # pragma: no branch
240
+ lower, upper = strike_value
241
+ return lower <= value <= upper
242
+ except (ValueError, TypeError):
243
+ # If we can't make the comparison due to incompatible types, just log the
244
+ # error and assume the task is not stricken
245
+ logger.warning(
246
+ "Incompatible type for strike condition: %r %s %r",
247
+ strike_value,
248
+ operator,
249
+ value,
250
+ exc_info=True,
251
+ )
252
+ return False
253
+
254
+ def update(self, instruction: StrikeInstruction) -> None:
255
+ try:
256
+ hash(instruction.value)
257
+ except TypeError:
258
+ logger.warning(
259
+ "Incompatible type for strike condition: %s %r",
260
+ instruction.operator,
261
+ instruction.value,
262
+ )
263
+ return
264
+
265
+ if isinstance(instruction, Strike):
266
+ self._strike(instruction)
267
+ elif isinstance(instruction, Restore): # pragma: no branch
268
+ self._restore(instruction)
269
+
270
+ def _strike(self, strike: Strike) -> None:
271
+ if strike.function and strike.parameter:
272
+ try:
273
+ task_strikes = self.task_strikes[strike.function]
274
+ except KeyError:
275
+ task_strikes = self.task_strikes[strike.function] = {}
276
+
277
+ try:
278
+ parameter_strikes = task_strikes[strike.parameter]
279
+ except KeyError:
280
+ parameter_strikes = task_strikes[strike.parameter] = set()
281
+
282
+ parameter_strikes.add((strike.operator, strike.value))
283
+
284
+ elif strike.function:
285
+ try:
286
+ task_strikes = self.task_strikes[strike.function]
287
+ except KeyError:
288
+ task_strikes = self.task_strikes[strike.function] = {}
289
+
290
+ elif strike.parameter: # pragma: no branch
291
+ try:
292
+ parameter_strikes = self.parameter_strikes[strike.parameter]
293
+ except KeyError:
294
+ parameter_strikes = self.parameter_strikes[strike.parameter] = set()
295
+
296
+ parameter_strikes.add((strike.operator, strike.value))
297
+
298
+ def _restore(self, restore: Restore) -> None:
299
+ if restore.function and restore.parameter:
300
+ try:
301
+ task_strikes = self.task_strikes[restore.function]
302
+ except KeyError:
303
+ return
304
+
305
+ try:
306
+ parameter_strikes = task_strikes[restore.parameter]
307
+ except KeyError:
308
+ task_strikes.pop(restore.parameter, None)
309
+ return
310
+
311
+ try:
312
+ parameter_strikes.remove((restore.operator, restore.value))
313
+ except KeyError:
314
+ pass
315
+
316
+ if not parameter_strikes:
317
+ task_strikes.pop(restore.parameter, None)
318
+ if not task_strikes:
319
+ self.task_strikes.pop(restore.function, None)
320
+
321
+ elif restore.function:
322
+ try:
323
+ task_strikes = self.task_strikes[restore.function]
324
+ except KeyError:
325
+ return
326
+
327
+ # If there are no parameter strikes, this was a full task strike
328
+ if not task_strikes:
329
+ self.task_strikes.pop(restore.function, None)
330
+
331
+ elif restore.parameter: # pragma: no branch
332
+ try:
333
+ parameter_strikes = self.parameter_strikes[restore.parameter]
334
+ except KeyError:
335
+ return
336
+
337
+ try:
338
+ parameter_strikes.remove((restore.operator, restore.value))
339
+ except KeyError:
340
+ pass
341
+
342
+ if not parameter_strikes:
343
+ self.parameter_strikes.pop(restore.parameter, None)
docket/instrumentation.py CHANGED
@@ -33,6 +33,12 @@ TASKS_STARTED = meter.create_counter(
33
33
  unit="1",
34
34
  )
35
35
 
36
+ TASKS_STRICKEN = meter.create_counter(
37
+ "docket_tasks_stricken",
38
+ description="How many tasks have been stricken from executing",
39
+ unit="1",
40
+ )
41
+
36
42
  TASKS_COMPLETED = meter.create_counter(
37
43
  "docket_tasks_completed",
38
44
  description="How many tasks that have completed in any state",
@@ -75,6 +81,18 @@ TASKS_RUNNING = meter.create_up_down_counter(
75
81
  unit="1",
76
82
  )
77
83
 
84
+ REDIS_DISRUPTIONS = meter.create_counter(
85
+ "docket_redis_disruptions",
86
+ description="How many times the Redis connection has been disrupted",
87
+ unit="1",
88
+ )
89
+
90
+ STRIKES_IN_EFFECT = meter.create_up_down_counter(
91
+ "docket_strikes_in_effect",
92
+ description="How many strikes are currently in effect",
93
+ unit="1",
94
+ )
95
+
78
96
  Message = dict[bytes, bytes]
79
97
 
80
98
 
docket/tasks.py CHANGED
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import logging
2
3
  from datetime import datetime, timezone
3
4
 
@@ -44,7 +45,13 @@ async def fail(
44
45
  )
45
46
 
46
47
 
48
+ async def sleep(seconds: float) -> None:
49
+ logger.info("Sleeping for %s seconds", seconds)
50
+ await asyncio.sleep(seconds)
51
+
52
+
47
53
  standard_tasks: TaskCollection = [
48
54
  trace,
49
55
  fail,
56
+ sleep,
50
57
  ]