pycsp3-scheduling 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,495 @@
1
+ """
2
+ Interval expression functions for scheduling models.
3
+
4
+ These functions return expression objects that can be used in constraints
5
+ and objectives. They extract properties from interval variables.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum, auto
12
+ from typing import TYPE_CHECKING, Any, Union
13
+
14
+ if TYPE_CHECKING:
15
+ from pycsp3_scheduling.variables.interval import IntervalVar
16
+
17
+
18
class ExprType(Enum):
    """
    Node kinds that can appear in an ``IntervalExpr`` tree.

    Members fall into three groups: accessors that read a property of
    interval variables, arithmetic combinators, and comparison operators.
    Values come from ``auto()`` in declaration order, so reordering the
    members would change their values.
    """

    # --- Accessors on interval variables ---------------------------------
    START_OF = auto()  # start time of one interval
    END_OF = auto()  # end time of one interval
    SIZE_OF = auto()  # size (duration) of one interval
    LENGTH_OF = auto()  # length of one interval
    PRESENCE_OF = auto()  # presence flag of one interval
    OVERLAP_LENGTH = auto()  # overlap between two intervals

    # --- Arithmetic combinations -----------------------------------------
    ADD = auto()  # binary +
    SUB = auto()  # binary -
    MUL = auto()  # binary *
    DIV = auto()  # binary /
    NEG = auto()  # unary -
    ABS = auto()  # abs(...)
    MIN = auto()  # min over two or more operands
    MAX = auto()  # max over two or more operands

    # --- Comparisons (used to build constraints) -------------------------
    EQ = auto()  # ==
    NE = auto()  # !=
    LT = auto()  # <
    LE = auto()  # <=
    GT = auto()  # >
    GE = auto()  # >=
43
+
44
+
45
@dataclass
class IntervalExpr:
    """
    Base class for interval-related expressions.

    These expressions represent values derived from interval variables
    that can be used in constraints and objectives. Instances form a tree:
    leaves either reference an interval (``interval``) or carry an integer
    constant (``value``); inner nodes combine child ``operands``.

    Attributes:
        expr_type: The type of expression.
        interval: The interval variable (if applicable).
        absent_value: Value to use when interval is absent.
        operands: Child expressions for compound expressions.
        value: Constant value (for literals).

    Note:
        ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=`` build comparison
        expressions instead of returning ``bool``. The results are ordinary
        objects and therefore always truthy. Do not use ``==`` or ``in`` to
        check whether two expressions are the same node; use ``is`` instead.
    """

    expr_type: ExprType
    interval: IntervalVar | None = None
    absent_value: int = 0
    operands: list[IntervalExpr] = field(default_factory=list)
    value: int | None = None
    # Unique per-node identifier; -1 means "assign one in __post_init__".
    _id: int = field(default=-1, repr=False)

    # Lookup tables used by __repr__. They are intentionally unannotated so
    # that @dataclass treats them as plain class attributes, not fields.
    _ACCESSOR_NAMES = {
        ExprType.START_OF: "start_of",
        ExprType.END_OF: "end_of",
        ExprType.SIZE_OF: "size_of",
        ExprType.LENGTH_OF: "length_of",
        ExprType.PRESENCE_OF: "presence_of",
    }
    _INFIX_SYMBOLS = {
        ExprType.ADD: "+",
        ExprType.SUB: "-",
        ExprType.MUL: "*",
        ExprType.DIV: "/",
        ExprType.EQ: "==",
        ExprType.NE: "!=",
        ExprType.LT: "<",
        ExprType.LE: "<=",
        ExprType.GT: ">",
        ExprType.GE: ">=",
    }

    def __post_init__(self) -> None:
        """Assign a unique ID unless one was passed explicitly."""
        if self._id == -1:
            self._id = IntervalExpr._get_next_id()

    @staticmethod
    def _get_next_id() -> int:
        """Return the next unique ID (0, 1, 2, ...)."""
        # The counter lives on the class and is created lazily on first use.
        current = getattr(IntervalExpr, "_id_counter", 0)
        IntervalExpr._id_counter = current + 1
        return current

    def _combine(
        self,
        expr_type: ExprType,
        other: Union[IntervalExpr, int],
        *,
        reflected: bool = False,
    ) -> IntervalExpr:
        """
        Build a binary node of kind ``expr_type`` from ``self`` and ``other``.

        ``other`` is converted with ``_to_expr`` first, so an ``int`` becomes
        a constant leaf. Non-convertible types raise ``TypeError``. With
        ``reflected=True``, the operands are ordered ``[other, self]``, as
        required by the reflected operators of non-commutative operations.
        """
        other_expr = _to_expr(other)
        operands = [other_expr, self] if reflected else [self, other_expr]
        return IntervalExpr(expr_type=expr_type, operands=operands)

    # Arithmetic operators
    def __add__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Add two expressions or expression and constant."""
        return self._combine(ExprType.ADD, other)

    def __radd__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right addition (``int + expr``); stored as ``expr + int``."""
        return self.__add__(other)

    def __sub__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Subtract two expressions or expression and constant."""
        return self._combine(ExprType.SUB, other)

    def __rsub__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right subtraction (``int - expr``); operands are ``[other, self]``."""
        return self._combine(ExprType.SUB, other, reflected=True)

    def __mul__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Multiply two expressions or expression and constant."""
        return self._combine(ExprType.MUL, other)

    def __rmul__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right multiplication (``int * expr``); stored as ``expr * int``."""
        return self.__mul__(other)

    def __truediv__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Divide two expressions or expression and constant."""
        return self._combine(ExprType.DIV, other)

    def __rtruediv__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right division (``int / expr``); operands are ``[other, self]``."""
        # Previously missing: every other binary arithmetic operator has a
        # reflected form, so ``100 / size_of(t)`` raised TypeError.
        return self._combine(ExprType.DIV, other, reflected=True)

    def __neg__(self) -> IntervalExpr:
        """Negate expression."""
        return IntervalExpr(expr_type=ExprType.NEG, operands=[self])

    def __abs__(self) -> IntervalExpr:
        """Absolute value of expression."""
        return IntervalExpr(expr_type=ExprType.ABS, operands=[self])

    # Comparison operators (return constraint expressions)
    def __eq__(self, other: object) -> IntervalExpr:  # type: ignore[override]
        """Equality comparison; returns NotImplemented for unsupported types."""
        if isinstance(other, (IntervalExpr, int)):
            return self._combine(ExprType.EQ, other)
        return NotImplemented

    def __ne__(self, other: object) -> IntervalExpr:  # type: ignore[override]
        """Inequality comparison; returns NotImplemented for unsupported types."""
        if isinstance(other, (IntervalExpr, int)):
            return self._combine(ExprType.NE, other)
        return NotImplemented

    def __lt__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Less than comparison."""
        return self._combine(ExprType.LT, other)

    def __le__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Less than or equal comparison."""
        return self._combine(ExprType.LE, other)

    def __gt__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Greater than comparison."""
        return self._combine(ExprType.GT, other)

    def __ge__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Greater than or equal comparison."""
        return self._combine(ExprType.GE, other)

    def __hash__(self) -> int:
        """Hash based on unique ID."""
        # Must stay explicit: with eq=True and frozen=False, @dataclass would
        # otherwise set __hash__ to None and make instances unhashable.
        return hash(self._id)

    def __repr__(self) -> str:
        """Render the expression tree in a readable, infix-style form."""
        # Constant leaves (built by _to_expr) are identified by ``value``.
        if self.value is not None:
            return str(self.value)
        kind = self.expr_type
        ops = self.operands
        if kind in self._ACCESSOR_NAMES:
            label = self.interval.name if self.interval else "?"
            return f"{self._ACCESSOR_NAMES[kind]}({label})"
        if kind == ExprType.OVERLAP_LENGTH:
            names = [op.interval.name if op.interval else "?" for op in ops]
            return f"overlap_length({names[0]}, {names[1]})"
        if kind in self._INFIX_SYMBOLS:
            return f"({ops[0]} {self._INFIX_SYMBOLS[kind]} {ops[1]})"
        if kind == ExprType.NEG:
            return f"(-{ops[0]})"
        if kind == ExprType.ABS:
            return f"abs({ops[0]})"
        if kind in (ExprType.MIN, ExprType.MAX):
            func = "min" if kind == ExprType.MIN else "max"
            return f"{func}({', '.join(str(op) for op in ops)})"
        return f"IntervalExpr({self.expr_type})"

    def get_intervals(self) -> list[IntervalVar]:
        """
        Get all interval variables referenced by this expression.

        Walks the tree depth-first (own interval first, then operands left
        to right). Duplicates are kept if an interval is referenced twice.
        """
        intervals = []
        if self.interval is not None:
            intervals.append(self.interval)
        for operand in self.operands:
            intervals.extend(operand.get_intervals())
        return intervals

    def is_comparison(self) -> bool:
        """Check if this is a comparison expression (constraint)."""
        return self.expr_type in (
            ExprType.EQ,
            ExprType.NE,
            ExprType.LT,
            ExprType.LE,
            ExprType.GT,
            ExprType.GE,
        )
267
+
268
+
269
def _to_expr(value: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Coerce an operand into an ``IntervalExpr``.

    Existing expressions are returned unchanged (no copy). Integers
    (including ``bool``, a subclass of ``int``) are wrapped in a constant
    leaf whose ``value`` holds the number.

    Raises:
        TypeError: If ``value`` is neither an ``IntervalExpr`` nor an ``int``.
    """
    if isinstance(value, IntervalExpr):
        return value
    if not isinstance(value, int):
        raise TypeError(f"Cannot convert {type(value)} to IntervalExpr")
    # Constant leaf: ``expr_type`` is only a placeholder here. Consumers
    # recognise constants by ``value is not None``.
    constant = IntervalExpr(expr_type=ExprType.ADD, value=value)
    return constant
280
+
281
+
282
+ # ============================================================================
283
+ # Public API Functions
284
+ # ============================================================================
285
+
286
+
287
def start_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the start time of ``interval``.

    ``absent_value`` is recorded on the expression as the value to use when
    the interval is absent (optional and not selected).

    Args:
        interval: Interval variable whose start is referenced.
        absent_value: Substitute value for an absent interval (default: 0).

    Returns:
        A ``START_OF`` expression usable in constraints and objectives.

    Example:
        >>> task = IntervalVar(size=10, name="task")
        >>> expr = start_of(task)
        >>> # Can be used in constraints: start_of(task) >= 5
    """
    node_kind = ExprType.START_OF
    return IntervalExpr(
        interval=interval,
        absent_value=absent_value,
        expr_type=node_kind,
    )
310
+
311
+
312
def end_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the end time of ``interval``.

    ``absent_value`` is recorded on the expression as the value to use when
    the interval is absent (optional and not selected).

    FIXME: end_of() still returns an internal IntervalExpr; for pycsp3
    objectives use end_time() for now.

    Args:
        interval: Interval variable whose end is referenced.
        absent_value: Substitute value for an absent interval (default: 0).

    Returns:
        An ``END_OF`` expression usable in constraints.

    Example:
        >>> task = IntervalVar(size=10, name="task")
        >>> expr = end_of(task)
        >>> # Can be used in constraints: end_of(task) <= 100
    """
    node_kind = ExprType.END_OF
    return IntervalExpr(
        interval=interval,
        absent_value=absent_value,
        expr_type=node_kind,
    )
337
+
338
+
339
def size_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the size (duration) of ``interval``.

    ``absent_value`` is recorded on the expression as the value to use when
    the interval is absent (optional and not selected).

    Args:
        interval: Interval variable whose size is referenced.
        absent_value: Substitute value for an absent interval (default: 0).

    Returns:
        A ``SIZE_OF`` expression usable in constraints and objectives.

    Example:
        >>> task = IntervalVar(size=(5, 20), name="task")
        >>> expr = size_of(task)
        >>> # Can be used in constraints: size_of(task) >= 10
    """
    node_kind = ExprType.SIZE_OF
    return IntervalExpr(
        interval=interval,
        absent_value=absent_value,
        expr_type=node_kind,
    )
362
+
363
+
364
def length_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the length of ``interval``.

    Length can differ from size when intensity functions are used.
    ``absent_value`` is recorded on the expression as the value to use when
    the interval is absent.

    Args:
        interval: Interval variable whose length is referenced.
        absent_value: Substitute value for an absent interval (default: 0).

    Returns:
        A ``LENGTH_OF`` expression usable in constraints and objectives.

    Example:
        >>> task = IntervalVar(size=10, length=(8, 12), name="task")
        >>> expr = length_of(task)
    """
    node_kind = ExprType.LENGTH_OF
    return IntervalExpr(
        interval=interval,
        absent_value=absent_value,
        expr_type=node_kind,
    )
387
+
388
+
389
def presence_of(interval: IntervalVar) -> IntervalExpr:
    """
    Build a boolean (0/1) expression for whether ``interval`` is present.

    For mandatory intervals this is always true; for optional intervals it
    is a decision variable.

    Args:
        interval: Interval variable whose presence is referenced.

    Returns:
        A ``PRESENCE_OF`` expression.

    Example:
        >>> task = IntervalVar(size=10, optional=True, name="task")
        >>> expr = presence_of(task)
        >>> # Can be used: presence_of(task) == 1 means task is selected
    """
    node_kind = ExprType.PRESENCE_OF
    # absent_value has no meaning for a presence flag; store the neutral 0.
    return IntervalExpr(
        interval=interval,
        absent_value=0,
        expr_type=node_kind,
    )
412
+
413
+
414
def overlap_length(
    interval1: IntervalVar,
    interval2: IntervalVar,
    absent_value: int = 0,
) -> IntervalExpr:
    """
    Build an expression for the overlap length between two intervals.

    The overlap is max(0, min(end1, end2) - max(start1, start2)).
    ``absent_value`` is recorded as the value to use if either interval
    is absent.

    Args:
        interval1: First interval variable.
        interval2: Second interval variable.
        absent_value: Substitute value when either interval is absent.

    Returns:
        An ``OVERLAP_LENGTH`` expression whose two operands reference
        ``interval1`` and ``interval2`` in that order.

    Example:
        >>> task1 = IntervalVar(size=10, name="task1")
        >>> task2 = IntervalVar(size=15, name="task2")
        >>> expr = overlap_length(task1, task2)
        >>> # expr == 0 means no overlap
    """
    # The operands are START_OF nodes used only as carriers of the interval
    # references; they are created first-then-second, as before.
    carriers = [
        IntervalExpr(expr_type=ExprType.START_OF, interval=iv)
        for iv in (interval1, interval2)
    ]
    return IntervalExpr(
        absent_value=absent_value,
        operands=carriers,
        expr_type=ExprType.OVERLAP_LENGTH,
    )
453
+
454
+
455
+ # ============================================================================
456
+ # Utility Functions
457
+ # ============================================================================
458
+
459
+
460
def expr_min(*args: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Build an expression for the minimum of several operands.

    Args:
        *args: Two or more expressions or integers; integers are wrapped
            as constant leaves.

    Returns:
        A ``MIN`` expression whose operands follow the argument order.

    Raises:
        ValueError: If fewer than two arguments are given.
        TypeError: If an argument is neither an expression nor an ``int``.
    """
    if len(args) >= 2:
        operands = list(map(_to_expr, args))
        return IntervalExpr(expr_type=ExprType.MIN, operands=operands)
    raise ValueError("expr_min requires at least 2 arguments")
477
+
478
+
479
def expr_max(*args: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Build an expression for the maximum of several operands.

    Args:
        *args: Two or more expressions or integers; integers are wrapped
            as constant leaves.

    Returns:
        A ``MAX`` expression whose operands follow the argument order.

    Raises:
        ValueError: If fewer than two arguments are given.
        TypeError: If an argument is neither an expression nor an ``int``.
    """
    if len(args) >= 2:
        operands = list(map(_to_expr, args))
        return IntervalExpr(expr_type=ExprType.MAX, operands=operands)
    raise ValueError("expr_max requires at least 2 arguments")