islpy 2025.1.4__cp310-cp310-macosx_11_0_arm64.whl → 2025.2__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
islpy/__init__.py CHANGED
@@ -20,969 +20,135 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
20
  THE SOFTWARE.
21
21
  """
22
22
 
23
- from typing import Any, Callable, Optional, TypeVar, cast
23
+ from collections.abc import Collection, Sequence
24
+ from typing import Literal, TypeAlias, TypeVar
24
25
 
25
- import islpy._isl as _isl
26
- from islpy.version import VERSION, VERSION_TEXT # noqa
26
+ from islpy.version import VERSION, VERSION_TEXT
27
27
 
28
28
 
29
29
  __version__ = VERSION_TEXT
30
30
 
31
- # {{{ copied verbatim from pytools to avoid numpy/pytools dependency
32
-
33
- F = TypeVar("F", bound=Callable[..., Any])
34
-
35
-
36
- class _HasKwargs:
37
- pass
38
-
39
-
40
- def _memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) -> F:
41
- """Like :func:`memoize_method`, but for functions that take the object
42
- in which do memoization information is stored as first argument.
43
-
44
- Supports cache deletion via ``function_name.clear_cache(self)``.
45
- """
46
- from sys import intern
47
-
48
- if cache_dict_name is None:
49
- cache_dict_name = intern(
50
- f"_memoize_dic_{function.__module__}{function.__name__}"
51
- )
52
-
53
- def wrapper(obj, *args, **kwargs):
54
- if kwargs:
55
- key = (_HasKwargs, frozenset(kwargs.items()), *args)
56
- else:
57
- key = args
58
-
59
- try:
60
- return getattr(obj, cache_dict_name)[key]
61
- except AttributeError:
62
- attribute_error = True
63
- except KeyError:
64
- attribute_error = False
65
-
66
- result = function(obj, *args, **kwargs)
67
- if attribute_error:
68
- object.__setattr__(obj, cache_dict_name, {key: result})
69
- return result
70
- else:
71
- getattr(obj, cache_dict_name)[key] = result
72
- return result
73
-
74
- def clear_cache(obj):
75
- object.__delattr__(obj, cache_dict_name)
76
-
77
- from functools import update_wrapper
78
- new_wrapper = update_wrapper(wrapper, function)
31
+ # {{{ name imports
79
32
 
80
- # type-ignore because mypy has a point here, stuffing random attributes
81
- # into the function's dict is moderately sketchy.
82
- new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
33
+ from islpy._isl import (
34
+ AccessInfo,
35
+ Aff,
36
+ AffList,
37
+ AstBuild,
38
+ AstExpr,
39
+ AstExprList,
40
+ AstNode,
41
+ AstNodeList,
42
+ AstPrintOptions,
43
+ BasicMap,
44
+ BasicMapList,
45
+ BasicSet,
46
+ BasicSetList,
47
+ Cell,
48
+ Constraint,
49
+ ConstraintList,
50
+ Context,
51
+ Error,
52
+ FixedBox,
53
+ Flow,
54
+ Id,
55
+ IdList,
56
+ IdToAstExpr,
57
+ LocalSpace,
58
+ Map,
59
+ MapList,
60
+ Mat,
61
+ MultiAff,
62
+ MultiId,
63
+ MultiPwAff,
64
+ MultiUnionPwAff,
65
+ MultiVal,
66
+ Point,
67
+ Printer,
68
+ PwAff,
69
+ PwAffList,
70
+ PwMultiAff,
71
+ PwMultiAffList,
72
+ PwQPolynomial,
73
+ PwQPolynomialFold,
74
+ PwQPolynomialFoldList,
75
+ PwQPolynomialList,
76
+ QPolynomial,
77
+ QPolynomialFold,
78
+ QPolynomialList,
79
+ Restriction,
80
+ Schedule,
81
+ ScheduleConstraints,
82
+ ScheduleNode,
83
+ Set,
84
+ SetList,
85
+ Space,
86
+ StrideInfo,
87
+ Term,
88
+ UnionAccessInfo,
89
+ UnionFlow,
90
+ UnionMap,
91
+ UnionMapList,
92
+ UnionPwAff,
93
+ UnionPwAffList,
94
+ UnionPwMultiAff,
95
+ UnionPwMultiAffList,
96
+ UnionPwQPolynomial,
97
+ UnionPwQPolynomialFold,
98
+ UnionSet,
99
+ UnionSetList,
100
+ Val,
101
+ ValList,
102
+ Vec,
103
+ Vertex,
104
+ Vertices,
105
+ ast_expr_op_type,
106
+ ast_expr_type,
107
+ ast_loop_type,
108
+ ast_node_type,
109
+ bound,
110
+ dim_type,
111
+ error,
112
+ fold,
113
+ format,
114
+ isl_version,
115
+ on_error,
116
+ schedule_algorithm,
117
+ schedule_node_type,
118
+ stat,
119
+ yaml_style,
120
+ )
121
+
122
+ # importing _monkeypatch has the side effect of actually monkeypatching
123
+ from islpy._monkeypatch import _CHECK_DIM_TYPES, EXPR_CLASSES
83
124
 
84
- return cast(F, new_wrapper)
85
125
 
86
126
  # }}}
87
127
 
88
128
 
89
- Error = _isl.Error
90
-
91
- # {{{ name imports
129
+ # {{{ typing helpers
92
130
 
93
- isl_version = _isl.isl_version
94
-
95
- Context = _isl.Context
96
- IdList = _isl.IdList
97
- ValList = _isl.ValList
98
- BasicSetList = _isl.BasicSetList
99
- BasicMapList = _isl.BasicMapList
100
- SetList = _isl.SetList
101
- MapList = _isl.MapList
102
- UnionSetList = _isl.UnionSetList
103
- ConstraintList = _isl.ConstraintList
104
- AffList = _isl.AffList
105
- PwAffList = _isl.PwAffList
106
- PwMultiAffList = _isl.PwMultiAffList
107
- AstExprList = _isl.AstExprList
108
- AstNodeList = _isl.AstNodeList
109
-
110
- QPolynomialList = _isl.QPolynomialList
111
- PwQPolynomialList = _isl.PwQPolynomialList
112
- PwQPolynomialFoldList = _isl.PwQPolynomialFoldList
113
-
114
- UnionPwAffList = _isl.UnionPwAffList
115
- UnionPwMultiAffList = _isl.UnionPwMultiAffList
116
- UnionMapList = _isl.UnionMapList
117
- UnionSetList = _isl.UnionSetList
118
-
119
- IdToAstExpr = _isl.IdToAstExpr
120
- Printer = _isl.Printer
121
- Val = _isl.Val
122
- MultiVal = _isl.MultiVal
123
- Vec = _isl.Vec
124
- Mat = _isl.Mat
125
- FixedBox = _isl.FixedBox
126
- Aff = _isl.Aff
127
- PwAff = _isl.PwAff
128
- UnionPwAff = _isl.UnionPwAff
129
- MultiAff = _isl.MultiAff
130
- MultiPwAff = _isl.MultiPwAff
131
- PwMultiAff = _isl.PwMultiAff
132
- UnionPwMultiAff = _isl.UnionPwMultiAff
133
- UnionPwAffList = _isl.UnionPwAffList
134
- MultiUnionPwAff = _isl.MultiUnionPwAff
135
- Id = _isl.Id
136
- MultiId = _isl.MultiId
137
- Constraint = _isl.Constraint
138
- Space = _isl.Space
139
- LocalSpace = _isl.LocalSpace
140
- BasicSet = _isl.BasicSet
141
- BasicMap = _isl.BasicMap
142
- Set = _isl.Set
143
- Map = _isl.Map
144
- UnionMap = _isl.UnionMap
145
- UnionSet = _isl.UnionSet
146
- Point = _isl.Point
147
- Vertex = _isl.Vertex
148
- Cell = _isl.Cell
149
- Vertices = _isl.Vertices
150
- StrideInfo = _isl.StrideInfo
151
- QPolynomialFold = _isl.QPolynomialFold
152
- PwQPolynomialFold = _isl.PwQPolynomialFold
153
- UnionPwQPolynomialFold = _isl.UnionPwQPolynomialFold
154
- UnionPwQPolynomial = _isl.UnionPwQPolynomial
155
- QPolynomial = _isl.QPolynomial
156
- PwQPolynomial = _isl.PwQPolynomial
157
- Term = _isl.Term
158
- ScheduleConstraints = _isl.ScheduleConstraints
159
- ScheduleNode = _isl.ScheduleNode
160
- Schedule = _isl.Schedule
161
- AccessInfo = _isl.AccessInfo
162
- Flow = _isl.Flow
163
- Restriction = _isl.Restriction
164
- UnionAccessInfo = _isl.UnionAccessInfo
165
- UnionFlow = _isl.UnionFlow
166
- AstExpr = _isl.AstExpr
167
- AstNode = _isl.AstNode
168
- AstPrintOptions = _isl.AstPrintOptions
169
- AstBuild = _isl.AstBuild
170
-
171
- error = _isl.error
172
- stat = _isl.stat
173
- dim_type = _isl.dim_type
174
- schedule_node_type = _isl.schedule_node_type
175
- ast_expr_op_type = _isl.ast_expr_op_type
176
- ast_expr_type = _isl.ast_expr_type
177
- ast_node_type = _isl.ast_node_type
178
- ast_loop_type = _isl.ast_loop_type
179
- fold = _isl.fold
180
- format = _isl.format
181
- yaml_style = _isl.yaml_style
182
- bound = _isl.bound
183
- on_error = _isl.on_error
184
- schedule_algorithm = _isl.schedule_algorithm
185
-
186
- # backward compatibility
187
- ast_op_type = _isl.ast_expr_op_type
131
+ Alignable: TypeAlias = (
132
+ Space
133
+ | Set | Map
134
+ | BasicSet | BasicMap
135
+ | Aff | PwAff
136
+ )
137
+ AlignableT = TypeVar("AlignableT", bound=Alignable)
188
138
 
189
139
  # }}}
190
140
 
191
141
 
192
- _CHECK_DIM_TYPES = [
193
- dim_type.in_, dim_type.param, dim_type.set]
194
-
195
- ALL_CLASSES = tuple(getattr(_isl, cls) for cls in dir(_isl) if cls[0].isupper())
196
- EXPR_CLASSES = tuple(cls for cls in ALL_CLASSES
197
- if "Aff" in cls.__name__ or "Polynomial" in cls.__name__)
198
-
199
142
  DEFAULT_CONTEXT = Context()
200
143
 
201
144
 
202
- def _get_default_context():
145
+ def _get_default_context() -> Context:
203
146
  """A callable to get the default context for the benefit of Python's
204
147
  ``__reduce__`` protocol.
205
148
  """
206
149
  return DEFAULT_CONTEXT
207
150
 
208
151
 
209
- def _read_from_str_wrapper(cls, context, s, dims_with_apostrophes):
210
- """A callable to reconstitute instances from strings for the benefit
211
- of Python's ``__reduce__`` protocol.
212
- """
213
- cls_from_str = cls.read_from_str(context, s)
214
-
215
- # Apostrophes in dim names have been lost, put them back
216
- for dim_name, (dim_type, dim_idx) in dims_with_apostrophes.items():
217
- cls_from_str = cls_from_str.set_dim_name(dim_type, dim_idx, dim_name)
218
-
219
- return cls_from_str
220
-
221
-
222
- def _add_functionality():
223
- import islpy._isl as _isl # noqa
224
-
225
- # {{{ dim_type
226
-
227
- def dim_type_reduce(v):
228
- return (dim_type, (int(v),))
229
-
230
- dim_type.__reduce__ = dim_type_reduce
231
-
232
- # }}}
233
-
234
- # {{{ Context
235
-
236
- def context_reduce(self):
237
- if self._wraps_same_instance_as(DEFAULT_CONTEXT):
238
- return (_get_default_context, ())
239
- else:
240
- return (Context, ())
241
-
242
- def context_eq(self, other):
243
- return isinstance(other, Context) and self._wraps_same_instance_as(other)
244
-
245
- def context_ne(self, other):
246
- return not self.__eq__(other)
247
-
248
- Context.__reduce__ = context_reduce
249
- Context.__eq__ = context_eq
250
- Context.__ne__ = context_ne
251
-
252
- # }}}
253
-
254
- # {{{ generic initialization, pickling
255
-
256
- def generic_reduce(self):
257
- ctx = self.get_ctx()
258
- prn = Printer.to_str(ctx)
259
- prn = getattr(prn, f"print_{self._base_name}")(self)
260
-
261
- # Reconstructing from string will remove apostrophes in dim names,
262
- # so keep track of dim names with apostrophes
263
- dims_with_apostrophes = {
264
- dname: pos for dname, pos in self.get_var_dict().items()
265
- if "'" in dname}
266
-
267
- return (
268
- _read_from_str_wrapper,
269
- (type(self), ctx, prn.get_str(), dims_with_apostrophes))
270
-
271
- for cls in ALL_CLASSES:
272
- if hasattr(cls, "read_from_str"):
273
- cls.__reduce__ = generic_reduce
274
-
275
- # }}}
276
-
277
- # {{{ printing
278
-
279
- def generic_str(self):
280
- prn = Printer.to_str(self.get_ctx())
281
- getattr(prn, f"print_{self._base_name}")(self)
282
- return prn.get_str()
283
-
284
- def generic_repr(self):
285
- prn = Printer.to_str(self.get_ctx())
286
- getattr(prn, f"print_{self._base_name}")(self)
287
- return f'{type(self).__name__}("{prn.get_str()}")'
288
-
289
- for cls in ALL_CLASSES:
290
- if (hasattr(cls, "_base_name")
291
- and hasattr(Printer, f"print_{cls._base_name}")):
292
- cls.__str__ = generic_str
293
- cls.__repr__ = generic_repr
294
-
295
- if not hasattr(cls, "__hash__"):
296
- raise AssertionError(f"not hashable: {cls}")
297
-
298
- # }}}
299
-
300
- # {{{ Python set-like behavior
301
-
302
- def obj_or(self, other):
303
- try:
304
- return self.union(other)
305
- except TypeError:
306
- return NotImplemented
307
-
308
- def obj_and(self, other):
309
- try:
310
- return self.intersect(other)
311
- except TypeError:
312
- return NotImplemented
313
-
314
- def obj_sub(self, other):
315
- try:
316
- return self.subtract(other)
317
- except TypeError:
318
- return NotImplemented
319
-
320
- for cls in [BasicSet, BasicMap, Set, Map]:
321
- cls.__or__ = obj_or
322
- cls.__ror__ = obj_or
323
- cls.__and__ = obj_and
324
- cls.__rand__ = obj_and
325
- cls.__sub__ = obj_sub
326
-
327
- # }}}
328
-
329
- # {{{ Space
330
-
331
- def space_get_id_dict(self, dimtype=None):
332
- """Return a dictionary mapping variable :class:`Id` instances to tuples
333
- of (:class:`dim_type`, index).
334
-
335
- :param dimtype: None to get all variables, otherwise
336
- one of :class:`dim_type`.
337
- """
338
- result = {}
339
-
340
- def set_dim_id(name, tp, idx):
341
- if name in result:
342
- raise RuntimeError(f"non-unique var id '{name}' encountered")
343
- result[name] = tp, idx
344
-
345
- if dimtype is None:
346
- types = _CHECK_DIM_TYPES
347
- else:
348
- types = [dimtype]
349
-
350
- for tp in types:
351
- for i in range(self.dim(tp)):
352
- name = self.get_dim_id(tp, i)
353
- if name is not None:
354
- set_dim_id(name, tp, i)
355
-
356
- return result
357
-
358
- def space_get_var_dict(self, dimtype=None, ignore_out=False):
359
- """Return a dictionary mapping variable names to tuples of
360
- (:class:`dim_type`, index).
361
-
362
- :param dimtype: None to get all variables, otherwise
363
- one of :class:`dim_type`.
364
- """
365
- result = {}
366
-
367
- def set_dim_name(name, tp, idx):
368
- if name in result:
369
- raise RuntimeError(f"non-unique var name '{name}' encountered")
370
- result[name] = tp, idx
371
-
372
- if dimtype is None:
373
- types = _CHECK_DIM_TYPES
374
- if ignore_out:
375
- types = types[:]
376
- types.remove(dim_type.out)
377
- else:
378
- types = [dimtype]
379
-
380
- for tp in types:
381
- for i in range(self.dim(tp)):
382
- name = self.get_dim_name(tp, i)
383
- if name is not None:
384
- set_dim_name(name, tp, i)
385
-
386
- return result
387
-
388
- def space_create_from_names(ctx, set=None, in_=None, out=None, params=()):
389
- """Create a :class:`Space` from lists of variable names.
390
-
391
- :param set_: names of `set`-type variables.
392
- :param in_: names of `in`-type variables.
393
- :param out: names of `out`-type variables.
394
- :param params: names of parameter-type variables.
395
- """
396
- dt = dim_type
397
-
398
- if set is not None:
399
- if in_ is not None or out is not None:
400
- raise RuntimeError("must pass only one of set / (in_,out)")
401
-
402
- result = Space.set_alloc(ctx, nparam=len(params),
403
- dim=len(set))
404
-
405
- for i, name in enumerate(set):
406
- result = result.set_dim_name(dt.set, i, name)
407
-
408
- elif in_ is not None and out is not None:
409
- if set is not None:
410
- raise RuntimeError("must pass only one of set / (in_,out)")
411
-
412
- result = Space.alloc(ctx, nparam=len(params),
413
- n_in=len(in_), n_out=len(out))
414
-
415
- for i, name in enumerate(in_):
416
- result = result.set_dim_name(dt.in_, i, name)
417
-
418
- for i, name in enumerate(out):
419
- result = result.set_dim_name(dt.out, i, name)
420
- else:
421
- raise RuntimeError("invalid parameter combination")
422
-
423
- for i, name in enumerate(params):
424
- result = result.set_dim_name(dt.param, i, name)
425
-
426
- return result
427
-
428
- Space.create_from_names = staticmethod(space_create_from_names)
429
- Space.get_var_dict = space_get_var_dict
430
- Space.get_id_dict = space_get_id_dict
431
-
432
- # }}}
433
-
434
- # {{{ coefficient wrangling
435
-
436
- def obj_set_coefficients(self, dim_tp, args):
437
- """
438
- :param dim_tp: :class:`dim_type`
439
- :param args: :class:`list` of coefficients, for indices `0..len(args)-1`.
440
-
441
- .. versionchanged:: 2011.3
442
- New for :class:`Aff`
443
- """
444
- for i, coeff in enumerate(args):
445
- self = self.set_coefficient_val(dim_tp, i, coeff)
446
-
447
- return self
448
-
449
- def obj_set_coefficients_by_name(self, iterable, name_to_dim=None):
450
- """Set the coefficients and the constant.
451
-
452
- :param iterable: a :class:`dict` or iterable of :class:`tuple`
453
- instances mapping variable names to their coefficients.
454
- The constant is set to the value of the key '1'.
455
-
456
- .. versionchanged:: 2011.3
457
- New for :class:`Aff`
458
- """
459
- try:
460
- iterable = list(iterable.items())
461
- except AttributeError:
462
- pass
463
-
464
- if name_to_dim is None:
465
- name_to_dim = self.get_space().get_var_dict()
466
-
467
- for name, coeff in iterable:
468
- if name == 1:
469
- self = self.set_constant_val(coeff)
470
- else:
471
- tp, idx = name_to_dim[name]
472
- self = self.set_coefficient_val(tp, idx, coeff)
473
-
474
- return self
475
-
476
- def obj_get_coefficients_by_name(self, dimtype=None, dim_to_name=None):
477
- """Return a dictionary mapping variable names to coefficients.
478
-
479
- :param dimtype: None to get all variables, otherwise
480
- one of :class:`dim_type`.
481
-
482
- .. versionchanged:: 2011.3
483
- New for :class:`Aff`
484
- """
485
- if dimtype is None:
486
- types = _CHECK_DIM_TYPES
487
- else:
488
- types = [dimtype]
489
-
490
- result = {}
491
- for tp in types:
492
- for i in range(self.get_space().dim(tp)):
493
- coeff = self.get_coefficient_val(tp, i)
494
- if coeff:
495
- if dim_to_name is None:
496
- name = self.get_dim_name(tp, i)
497
- else:
498
- name = dim_to_name[tp, i]
499
-
500
- result[name] = coeff
501
-
502
- const = self.get_constant_val()
503
- if const:
504
- result[1] = const
505
-
506
- return result
507
-
508
- for coeff_class in [Constraint, Aff]:
509
- coeff_class.set_coefficients = obj_set_coefficients
510
- coeff_class.set_coefficients_by_name = obj_set_coefficients_by_name
511
- coeff_class.get_coefficients_by_name = obj_get_coefficients_by_name
512
-
513
- # }}}
514
-
515
- # {{{ Id
516
-
517
- Id.user = property(Id.get_user)
518
- Id.name = property(Id.get_name)
519
-
520
- # }}}
521
-
522
- # {{{ Constraint
523
-
524
- def eq_from_names(space, coefficients=None):
525
- """Create a constraint `const + coeff_1*var_1 +... == 0`.
526
-
527
- :param space: :class:`Space`
528
- :param coefficients: a :class:`dict` or iterable of :class:`tuple`
529
- instances mapping variable names to their coefficients
530
- The constant is set to the value of the key '1'.
531
-
532
- .. versionchanged:: 2011.3
533
- Eliminated the separate *const* parameter.
534
- """
535
- if coefficients is None:
536
- coefficients = {}
537
- c = Constraint.equality_alloc(space)
538
- return c.set_coefficients_by_name(coefficients)
539
-
540
- def ineq_from_names(space, coefficients=None):
541
- """Create a constraint `const + coeff_1*var_1 +... >= 0`.
542
-
543
- :param space: :class:`Space`
544
- :param coefficients: a :class:`dict` or iterable of :class:`tuple`
545
- instances mapping variable names to their coefficients
546
- The constant is set to the value of the key '1'.
547
-
548
- .. versionchanged:: 2011.3
549
- Eliminated the separate *const* parameter.
550
- """
551
- if coefficients is None:
552
- coefficients = {}
553
- c = Constraint.inequality_alloc(space)
554
- return c.set_coefficients_by_name(coefficients)
555
-
556
- Constraint.eq_from_names = staticmethod(eq_from_names)
557
- Constraint.ineq_from_names = staticmethod(ineq_from_names)
558
-
559
- # }}}
560
-
561
- def basic_obj_get_constraints(self):
562
- """Get a list of constraints."""
563
- result = []
564
- self.foreach_constraint(result.append)
565
- return result
566
-
567
- # {{{ BasicSet
568
-
569
- BasicSet.get_constraints = basic_obj_get_constraints
570
-
571
- # }}}
572
-
573
- # {{{ BasicMap
574
-
575
- BasicMap.get_constraints = basic_obj_get_constraints
576
-
577
- # }}}
578
-
579
- # {{{ Set
580
-
581
- def set_get_basic_sets(self):
582
- """Get the list of :class:`BasicSet` instances in this :class:`Set`."""
583
- result = []
584
- self.foreach_basic_set(result.append)
585
- return result
586
-
587
- Set.get_basic_sets = set_get_basic_sets
588
- BasicSet.get_basic_sets = set_get_basic_sets
589
-
590
- # }}}
591
-
592
- # {{{ Map
593
-
594
- def map_get_basic_maps(self):
595
- """Get the list of :class:`BasicMap` instances in this :class:`Map`."""
596
- result = []
597
- self.foreach_basic_map(result.append)
598
- return result
599
-
600
- Map.get_basic_maps = map_get_basic_maps
601
-
602
- # }}}
603
-
604
- # {{{ common functionality
605
-
606
- def obj_get_id_dict(self, dimtype=None):
607
- """Return a dictionary mapping :class:`Id` instances to tuples of
608
- (:class:`dim_type`, index).
609
-
610
- :param dimtype: None to get all variables, otherwise
611
- one of :class:`dim_type`.
612
- """
613
- return self.get_space().get_id_dict(dimtype)
614
-
615
- @_memoize_on_first_arg
616
- def obj_get_var_dict(self, dimtype=None):
617
- """Return a dictionary mapping variable names to tuples of
618
- (:class:`dim_type`, index).
619
-
620
- :param dimtype: None to get all variables, otherwise
621
- one of :class:`dim_type`.
622
- """
623
- return self.get_space().get_var_dict(
624
- dimtype, ignore_out=isinstance(self, EXPR_CLASSES))
625
-
626
- def obj_get_var_ids(self, dimtype):
627
- """Return a list of :class:`Id` instances for :class:`dim_type` *dimtype*."""
628
- return [self.get_dim_name(dimtype, i) for i in range(self.dim(dimtype))]
629
-
630
- @_memoize_on_first_arg
631
- def obj_get_var_names(self, dimtype):
632
- """Return a list of dim names (in order) for :class:`dim_type` *dimtype*."""
633
- return [self.get_dim_name(dimtype, i) for i in range(self.dim(dimtype))]
634
-
635
- for cls in ALL_CLASSES:
636
- if hasattr(cls, "get_space") and cls is not Space:
637
- cls.get_id_dict = obj_get_id_dict
638
- cls.get_var_dict = obj_get_var_dict
639
- cls.get_var_ids = obj_get_var_ids
640
- cls.get_var_names = obj_get_var_names
641
- cls.space = property(cls.get_space)
642
-
643
- # }}}
644
-
645
- # {{{ piecewise
646
-
647
- def pwaff_get_pieces(self):
648
- """
649
- :return: list of (:class:`Set`, :class:`Aff`)
650
- """
651
-
652
- result = []
653
-
654
- def append_tuple(*args):
655
- result.append(args)
656
-
657
- self.foreach_piece(append_tuple)
658
- return result
659
-
660
- def pwqpolynomial_get_pieces(self):
661
- """
662
- :return: list of (:class:`Set`, :class:`QPolynomial`)
663
- """
664
-
665
- result = []
666
-
667
- def append_tuple(*args):
668
- result.append(args)
669
-
670
- self.foreach_piece(append_tuple)
671
- return result
672
-
673
- def pw_get_aggregate_domain(self):
674
- """
675
- :return: a :class:`Set` that is the union of the domains of all pieces
676
- """
677
-
678
- result = Set.empty(self.get_domain_space())
679
- for dom, _ in self.get_pieces():
680
- result = result.union(dom)
681
-
682
- return result
683
-
684
- PwAff.get_pieces = pwaff_get_pieces
685
- Aff.get_pieces = pwaff_get_pieces
686
- PwAff.get_aggregate_domain = pw_get_aggregate_domain
687
-
688
- PwQPolynomial.get_pieces = pwqpolynomial_get_pieces
689
- PwQPolynomial.get_aggregate_domain = pw_get_aggregate_domain
690
-
691
- # }}}
692
-
693
- # {{{ QPolynomial
694
-
695
- def qpolynomial_get_terms(self):
696
- """Get the list of :class:`Term` instances in this :class:`QPolynomial`."""
697
- result = []
698
- self.foreach_term(result.append)
699
- return result
700
-
701
- QPolynomial.get_terms = qpolynomial_get_terms
702
-
703
- # }}}
704
-
705
- # {{{ PwQPolynomial
706
-
707
- def pwqpolynomial_eval_with_dict(self, value_dict):
708
- """Evaluates *self* for the parameters specified by
709
- *value_dict*, which maps parameter names to their values.
710
- """
711
-
712
- pt = Point.zero(self.space.params())
713
-
714
- for i in range(self.space.dim(dim_type.param)):
715
- par_name = self.space.get_dim_name(dim_type.param, i)
716
- pt = pt.set_coordinate_val(
717
- dim_type.param, i, value_dict[par_name])
718
-
719
- return self.eval(pt).to_python()
720
-
721
- PwQPolynomial.eval_with_dict = pwqpolynomial_eval_with_dict
722
-
723
- # }}}
724
-
725
- # {{{ arithmetic
726
-
727
- def _number_to_expr_like(template, num):
728
- number_aff = Aff.zero_on_domain(template.get_domain_space())
729
- number_aff = number_aff.set_constant_val(num)
730
-
731
- if isinstance(template, Aff):
732
- return number_aff
733
- if isinstance(template, QPolynomial):
734
- return QPolynomial.from_aff(number_aff)
735
-
736
- # everything else is piecewise
737
-
738
- if template.get_pieces():
739
- number_pw_aff = PwAff.empty(template.get_space())
740
- for set, _ in template.get_pieces():
741
- number_pw_aff = set.indicator_function().cond(
742
- number_aff, number_pw_aff)
743
- else:
744
- number_pw_aff = PwAff.alloc(
745
- Set.universe(template.domain().space),
746
- number_aff)
747
-
748
- if isinstance(template, PwAff):
749
- return number_pw_aff
750
-
751
- elif isinstance(template, PwQPolynomial):
752
- return PwQPolynomial.from_pw_aff(number_pw_aff)
753
-
754
- else:
755
- raise TypeError("unexpected template type")
756
-
757
- ARITH_CLASSES = (Aff, PwAff, QPolynomial, PwQPolynomial) # noqa
758
-
759
- def expr_like_add(self, other):
760
- if not isinstance(other, ARITH_CLASSES):
761
- other = _number_to_expr_like(self, other)
762
-
763
- try:
764
- return self.add(other)
765
- except TypeError:
766
- return NotImplemented
767
-
768
- def expr_like_sub(self, other):
769
- if not isinstance(other, ARITH_CLASSES):
770
- other = _number_to_expr_like(self, other)
771
-
772
- try:
773
- return self.sub(other)
774
- except TypeError:
775
- return NotImplemented
776
-
777
- def expr_like_rsub(self, other):
778
- if not isinstance(other, ARITH_CLASSES):
779
- other = _number_to_expr_like(self, other)
780
-
781
- return -self + other
782
-
783
- def expr_like_mul(self, other):
784
- if not isinstance(other, ARITH_CLASSES):
785
- other = _number_to_expr_like(self, other)
786
-
787
- try:
788
- return self.mul(other)
789
- except TypeError:
790
- return NotImplemented
791
-
792
- def expr_like_floordiv(self, other):
793
- return self.scale_down_val(other).floor()
794
-
795
- for expr_like_class in ARITH_CLASSES:
796
- expr_like_class.__add__ = expr_like_add
797
- expr_like_class.__radd__ = expr_like_add
798
- expr_like_class.__sub__ = expr_like_sub
799
- expr_like_class.__rsub__ = expr_like_rsub
800
- expr_like_class.__mul__ = expr_like_mul
801
- expr_like_class.__rmul__ = expr_like_mul
802
- expr_like_class.__neg__ = expr_like_class.neg
803
-
804
- for qpoly_class in [QPolynomial, PwQPolynomial]:
805
- qpoly_class.__pow__ = qpoly_class.pow
806
-
807
- for aff_class in [Aff, PwAff]:
808
- aff_class.__mod__ = aff_class.mod_val
809
- aff_class.__floordiv__ = expr_like_floordiv
810
-
811
- # }}}
812
-
813
- # {{{ Val
814
-
815
- def val_rsub(self, other):
816
- return -self + other
817
-
818
- def val_bool(self):
819
- return not self.is_zero()
820
-
821
- def val_repr(self):
822
- return f'{type(self).__name__}("{self.to_str()}")'
823
-
824
- def val_to_python(self):
825
- if not self.is_int():
826
- raise ValueError("can only convert integer Val to python")
827
-
828
- return int(self.to_str())
829
-
830
- Val.__add__ = Val.add
831
- Val.__radd__ = Val.add
832
- Val.__sub__ = Val.sub
833
- Val.__rsub__ = val_rsub
834
- Val.__mul__ = Val.mul
835
- Val.__rmul__ = Val.mul
836
- Val.__neg__ = Val.neg
837
- Val.__mod__ = Val.mod
838
- Val.__bool__ = Val.__nonzero__ = val_bool
839
-
840
- Val.__lt__ = Val.lt
841
- Val.__gt__ = Val.gt
842
- Val.__le__ = Val.le
843
- Val.__ge__ = Val.ge
844
- Val.__eq__ = Val.eq
845
- Val.__ne__ = Val.ne
846
-
847
- Val.__repr__ = val_repr
848
- Val.__str__ = Val.to_str
849
- Val.to_python = val_to_python
850
-
851
- # }}}
852
-
853
- # {{{ rich comparisons
854
-
855
- def obj_eq(self, other):
856
- assert self.get_ctx() == other.get_ctx(), (
857
- "Equality-comparing two objects from different ISL Contexts "
858
- "will likely lead to entertaining (but never useful) results. "
859
- "In particular, Spaces with matching names will no longer be "
860
- "equal.")
861
-
862
- return self.is_equal(other)
863
-
864
- def obj_ne(self, other):
865
- return not self.__eq__(other)
866
-
867
- for cls in ALL_CLASSES:
868
- if hasattr(cls, "is_equal"):
869
- cls.__eq__ = obj_eq
870
- cls.__ne__ = obj_ne
871
-
872
- def obj_lt(self, other):
873
- return self.is_strict_subset(other)
874
-
875
- def obj_le(self, other):
876
- return self.is_subset(other)
877
-
878
- def obj_gt(self, other):
879
- return other.is_strict_subset(self)
880
-
881
- def obj_ge(self, other):
882
- return other.is_subset(self)
883
-
884
- for cls in [BasicSet, BasicMap, Set, Map]:
885
- cls.__lt__ = obj_lt
886
- cls.__le__ = obj_le
887
- cls.__gt__ = obj_gt
888
- cls.__ge__ = obj_ge
889
-
890
- # }}}
891
-
892
- # {{{ project_out_except
893
-
894
- def obj_project_out_except(obj, names, types):
895
- """
896
- :param types: list of :class:`dim_type` determining
897
- the types of axes to project out
898
- :param names: names of axes matching the above which
899
- should be left alone by the projection
900
-
901
- .. versionadded:: 2011.3
902
- """
903
-
904
- for tp in types:
905
- while True:
906
- space = obj.get_space()
907
- var_dict = space.get_var_dict(tp)
908
-
909
- all_indices = set(range(space.dim(tp)))
910
- leftover_indices = {var_dict[name][1] for name in names
911
- if name in var_dict}
912
- project_indices = all_indices-leftover_indices
913
- if not project_indices:
914
- break
915
-
916
- min_index = min(project_indices)
917
- count = 1
918
- while min_index+count in project_indices:
919
- count += 1
920
-
921
- obj = obj.project_out(tp, min_index, count)
922
-
923
- return obj
924
-
925
- # }}}
926
-
927
- # {{{ eliminate_except
928
-
929
- def obj_eliminate_except(obj, names, types):
930
- """
931
- :param types: list of :class:`dim_type` determining
932
- the types of axes to eliminate
933
- :param names: names of axes matching the above which
934
- should be left alone by the eliminate
935
-
936
- .. versionadded:: 2011.3
937
- """
938
-
939
- for tp in types:
940
- space = obj.get_space()
941
- var_dict = space.get_var_dict(tp)
942
- to_eliminate = (
943
- set(range(space.dim(tp)))
944
- - {var_dict[name][1] for name in names
945
- if name in var_dict})
946
-
947
- while to_eliminate:
948
- min_index = min(to_eliminate)
949
- count = 1
950
- while min_index+count in to_eliminate:
951
- count += 1
952
-
953
- obj = obj.eliminate(tp, min_index, count)
954
-
955
- to_eliminate -= set(range(min_index, min_index+count))
956
-
957
- return obj
958
-
959
- # }}}
960
-
961
- # {{{ add_constraints
962
-
963
- def obj_add_constraints(obj, constraints):
964
- """
965
- .. versionadded:: 2011.3
966
- """
967
-
968
- for cns in constraints:
969
- obj = obj.add_constraint(cns)
970
-
971
- return obj
972
-
973
- # }}}
974
-
975
- for c in [BasicSet, BasicMap, Set, Map]:
976
- c.project_out_except = obj_project_out_except
977
- c.add_constraints = obj_add_constraints
978
-
979
- for c in [BasicSet, Set]:
980
- c.eliminate_except = obj_eliminate_except
981
-
982
-
983
- _add_functionality()
984
-
985
-
986
152
  def _back_to_basic(new_obj, old_obj):
987
153
  # Work around set_dim_id not being available for Basic{Set,Map}
988
154
  if isinstance(old_obj, BasicSet) and isinstance(new_obj, Set):
@@ -1012,8 +178,14 @@ def _set_dim_id(obj, dt, idx, id):
1012
178
  return _back_to_basic(obj.set_dim_id(dt, idx, id), obj)
1013
179
 
1014
180
 
1015
- def _align_dim_type(template_dt, obj, template, obj_bigger_ok, obj_names,
1016
- template_names):
181
+ def _align_dim_type(
182
+ template_dt: dim_type,
183
+ obj: AlignableT,
184
+ template: AlignableT,
185
+ obj_bigger_ok: bool,
186
+ obj_names: Collection[str],
187
+ template_names: Collection[str],
188
+ ) -> AlignableT:
1017
189
 
1018
190
  # {{{ deal with Aff, PwAff
1019
191
 
@@ -1093,7 +265,11 @@ def _align_dim_type(template_dt, obj, template, obj_bigger_ok, obj_names,
1093
265
  return obj
1094
266
 
1095
267
 
1096
- def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None):
268
+ def align_spaces(
269
+ obj: AlignableT,
270
+ template: Alignable,
271
+ obj_bigger_ok: bool = False,
272
+ ) -> AlignableT:
1097
273
  """
1098
274
  Try to make the space in which *obj* lives the same as that of *template* by
1099
275
  adding/matching named dimensions.
@@ -1102,12 +278,6 @@ def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None):
1102
278
  has more dimensions than *template*.
1103
279
  """
1104
280
 
1105
- if across_dim_types is not None:
1106
- from warnings import warn
1107
- warn("across_dim_types is deprecated and should no longer be used. "
1108
- "It never had any effect anyway.",
1109
- DeprecationWarning, stacklevel=2)
1110
-
1111
281
  have_any_param_domains = (
1112
282
  isinstance(obj, (Set, BasicSet))
1113
283
  and isinstance(template, (Set, BasicSet))
@@ -1119,7 +289,7 @@ def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None):
1119
289
  template = type(template).from_params(template)
1120
290
 
1121
291
  if isinstance(template, EXPR_CLASSES):
1122
- dim_types = _CHECK_DIM_TYPES[:]
292
+ dim_types = list(_CHECK_DIM_TYPES)
1123
293
  dim_types.remove(dim_type.out)
1124
294
  else:
1125
295
  dim_types = _CHECK_DIM_TYPES
@@ -1142,24 +312,25 @@ def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None):
1142
312
  return obj
1143
313
 
1144
314
 
1145
- def align_two(obj1, obj2, across_dim_types=None):
315
+ def align_two(
316
+ obj1: AlignableT,
317
+ obj2: AlignableT,
318
+ ) -> tuple[AlignableT, AlignableT]:
1146
319
  """Align the spaces of two objects, potentially modifying both of them.
1147
320
 
1148
321
  See also :func:`align_spaces`.
1149
322
  """
1150
323
 
1151
- if across_dim_types is not None:
1152
- from warnings import warn
1153
- warn("across_dim_types is deprecated and should no longer be used. "
1154
- "It never had any effect anyway.",
1155
- DeprecationWarning, stacklevel=2)
1156
-
1157
324
  obj1 = align_spaces(obj1, obj2, obj_bigger_ok=True)
1158
325
  obj2 = align_spaces(obj2, obj1, obj_bigger_ok=True)
1159
326
  return (obj1, obj2)
1160
327
 
1161
328
 
1162
- def make_zero_and_vars(set_vars, params=(), ctx=None):
329
+ def make_zero_and_vars(
330
+ set_vars: Sequence[str],
331
+ params: Sequence[str] = (),
332
+ ctx: Context | None = None
333
+ ) -> dict[str | Literal[0], PwAff]:
1163
334
  """
1164
335
  :arg set_vars: an iterable of variable names, or a comma-separated string
1165
336
  :arg params: an iterable of variable names, or a comma-separated string
@@ -1200,7 +371,7 @@ def make_zero_and_vars(set_vars, params=(), ctx=None):
1200
371
  return affs_from_space(space)
1201
372
 
1202
373
 
1203
- def affs_from_space(space):
374
+ def affs_from_space(space: Space) -> dict[Literal[0] | str, PwAff]:
1204
375
  """
1205
376
  :return: a dictionary from variable names (in *set_vars* and *params*)
1206
377
  to :class:`PwAff` instances that represent each of the
@@ -1240,32 +411,97 @@ def affs_from_space(space):
1240
411
  return result
1241
412
 
1242
413
 
1243
- class SuppressedWarnings:
1244
- def __init__(self, ctx):
1245
- from warnings import warn
1246
- warn("islpy.SuppressedWarnings is a deprecated no-op and will be removed "
1247
- "in 2023. Simply remove the use of it to avoid this warning.",
1248
- DeprecationWarning, stacklevel=1)
1249
-
1250
- def __enter__(self):
1251
- pass
1252
-
1253
- def __exit__(self, type, value, traceback):
1254
- pass
1255
-
1256
-
1257
- # {{{ give sphinx something to import so we can produce docs
1258
-
1259
- def _define_doc_link_names():
1260
- class Div:
1261
- pass
1262
-
1263
- _isl.Div = Div
1264
-
1265
-
1266
- _define_doc_link_names()
1267
-
1268
- # }}}
1269
-
414
+ __all__ = (
415
+ "VERSION",
416
+ "VERSION_TEXT",
417
+ "AccessInfo",
418
+ "Aff",
419
+ "AffList",
420
+ "AstBuild",
421
+ "AstExpr",
422
+ "AstExprList",
423
+ "AstNode",
424
+ "AstNodeList",
425
+ "AstPrintOptions",
426
+ "BasicMap",
427
+ "BasicMapList",
428
+ "BasicSet",
429
+ "BasicSetList",
430
+ "Cell",
431
+ "Constraint",
432
+ "ConstraintList",
433
+ "Context",
434
+ "Error",
435
+ "FixedBox",
436
+ "Flow",
437
+ "Id",
438
+ "IdList",
439
+ "IdToAstExpr",
440
+ "LocalSpace",
441
+ "Map",
442
+ "MapList",
443
+ "Mat",
444
+ "MultiAff",
445
+ "MultiId",
446
+ "MultiPwAff",
447
+ "MultiUnionPwAff",
448
+ "MultiVal",
449
+ "Point",
450
+ "Printer",
451
+ "PwAff",
452
+ "PwAffList",
453
+ "PwMultiAff",
454
+ "PwMultiAffList",
455
+ "PwQPolynomial",
456
+ "PwQPolynomialFold",
457
+ "PwQPolynomialFoldList",
458
+ "PwQPolynomialList",
459
+ "QPolynomial",
460
+ "QPolynomialFold",
461
+ "QPolynomialList",
462
+ "Restriction",
463
+ "Schedule",
464
+ "ScheduleConstraints",
465
+ "ScheduleNode",
466
+ "Set",
467
+ "SetList",
468
+ "Space",
469
+ "StrideInfo",
470
+ "Term",
471
+ "UnionAccessInfo",
472
+ "UnionFlow",
473
+ "UnionMap",
474
+ "UnionMapList",
475
+ "UnionPwAff",
476
+ "UnionPwAffList",
477
+ "UnionPwAffList",
478
+ "UnionPwMultiAff",
479
+ "UnionPwMultiAffList",
480
+ "UnionPwQPolynomial",
481
+ "UnionPwQPolynomialFold",
482
+ "UnionSet",
483
+ "UnionSetList",
484
+ "UnionSetList",
485
+ "Val",
486
+ "ValList",
487
+ "Vec",
488
+ "Vertex",
489
+ "Vertices",
490
+ "ast_expr_op_type",
491
+ "ast_expr_type",
492
+ "ast_loop_type",
493
+ "ast_node_type",
494
+ "bound",
495
+ "dim_type",
496
+ "error",
497
+ "fold",
498
+ "format",
499
+ "isl_version",
500
+ "on_error",
501
+ "schedule_algorithm",
502
+ "schedule_node_type",
503
+ "stat",
504
+ "yaml_style",
505
+ )
1270
506
 
1271
507
  # vim: foldmethod=marker