islpy 2023.2.5__cp312-cp312-musllinux_1_1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
islpy/__init__.py ADDED
@@ -0,0 +1,1278 @@
1
+ __copyright__ = "Copyright (C) 2011-20 Andreas Kloeckner"
2
+
3
+ __license__ = """
4
+ Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ of this software and associated documentation files (the "Software"), to deal
6
+ in the Software without restriction, including without limitation the rights
7
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
+ copies of the Software, and to permit persons to whom the Software is
9
+ furnished to do so, subject to the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be included in
12
+ all copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
+ THE SOFTWARE.
21
+ """
22
+
23
+ from typing import Optional, TypeVar, cast, Callable, Any
24
+
25
+ import islpy._isl as _isl
26
+ from islpy.version import VERSION, VERSION_TEXT # noqa
27
+
28
+ # {{{ copied verbatim from pytools to avoid numpy/pytools dependency
29
+
30
F = TypeVar("F", bound=Callable[..., Any])


class _HasKwargs:
    """Marker placed at the front of memoization keys that were built from
    keyword arguments, to tell them apart from purely positional keys."""
35
+
36
+
37
def _memoize_on_first_arg(function: F, cache_dict_name: Optional[str] = None) -> F:
    """Like :func:`memoize_method`, but for functions that take the object
    in which the memoization information is stored as first argument.

    Results are kept in a dictionary stored as an attribute of the first
    argument. The attribute is named *cache_dict_name*, or is derived from
    the function's module and name if that is *None*. Positional and
    keyword arguments both participate in the cache key.

    Supports cache deletion via ``function_name.clear_cache(self)``.
    Clearing an object on which nothing has been cached yet is a no-op.
    """
    from sys import intern

    if cache_dict_name is None:
        cache_dict_name = intern(
            f"_memoize_dic_{function.__module__}{function.__name__}"
        )

    def wrapper(obj, *args, **kwargs):
        if kwargs:
            # _HasKwargs tags keys that include keyword arguments, so they
            # are distinguishable from keys built from positional args only.
            key = (_HasKwargs, frozenset(kwargs.items())) + args
        else:
            key = args

        try:
            return getattr(obj, cache_dict_name)[key]
        except AttributeError:
            attribute_error = True
        except KeyError:
            attribute_error = False

        result = function(obj, *args, **kwargs)
        if attribute_error:
            # First cached value on this object: create the cache dict.
            # object.__setattr__ bypasses any __setattr__ override on obj.
            object.__setattr__(obj, cache_dict_name, {key: result})
        else:
            getattr(obj, cache_dict_name)[key] = result
        return result

    def clear_cache(obj):
        try:
            object.__delattr__(obj, cache_dict_name)
        except AttributeError:
            # Nothing has been cached on obj yet, so there is nothing to clear.
            pass

    from functools import update_wrapper
    new_wrapper = update_wrapper(wrapper, function)

    # type-ignore because mypy has a point here, stuffing random attributes
    # into the function's dict is moderately sketchy.
    new_wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]

    return cast(F, new_wrapper)
82
+
83
+ # }}}
84
+
85
+
86
Error = _isl.Error

# {{{ name imports

# Each name below is re-exported exactly once from the C extension.

Context = _isl.Context
IdList = _isl.IdList
ValList = _isl.ValList
BasicSetList = _isl.BasicSetList
BasicMapList = _isl.BasicMapList
SetList = _isl.SetList
MapList = _isl.MapList
UnionSetList = _isl.UnionSetList
ConstraintList = _isl.ConstraintList
AffList = _isl.AffList
PwAffList = _isl.PwAffList
PwMultiAffList = _isl.PwMultiAffList
AstExprList = _isl.AstExprList
AstNodeList = _isl.AstNodeList

QPolynomialList = _isl.QPolynomialList
PwQPolynomialList = _isl.PwQPolynomialList
PwQPolynomialFoldList = _isl.PwQPolynomialFoldList

UnionPwAffList = _isl.UnionPwAffList
UnionPwMultiAffList = _isl.UnionPwMultiAffList
UnionMapList = _isl.UnionMapList

IdToAstExpr = _isl.IdToAstExpr
Printer = _isl.Printer
Val = _isl.Val
MultiVal = _isl.MultiVal
Vec = _isl.Vec
Mat = _isl.Mat
FixedBox = _isl.FixedBox
Aff = _isl.Aff
PwAff = _isl.PwAff
UnionPwAff = _isl.UnionPwAff
MultiAff = _isl.MultiAff
MultiPwAff = _isl.MultiPwAff
PwMultiAff = _isl.PwMultiAff
UnionPwMultiAff = _isl.UnionPwMultiAff
MultiUnionPwAff = _isl.MultiUnionPwAff
Id = _isl.Id
MultiId = _isl.MultiId
Constraint = _isl.Constraint
Space = _isl.Space
LocalSpace = _isl.LocalSpace
BasicSet = _isl.BasicSet
BasicMap = _isl.BasicMap
Set = _isl.Set
Map = _isl.Map
UnionMap = _isl.UnionMap
UnionSet = _isl.UnionSet
Point = _isl.Point
Vertex = _isl.Vertex
Cell = _isl.Cell
Vertices = _isl.Vertices
StrideInfo = _isl.StrideInfo
QPolynomialFold = _isl.QPolynomialFold
PwQPolynomialFold = _isl.PwQPolynomialFold
UnionPwQPolynomialFold = _isl.UnionPwQPolynomialFold
UnionPwQPolynomial = _isl.UnionPwQPolynomial
QPolynomial = _isl.QPolynomial
PwQPolynomial = _isl.PwQPolynomial
Term = _isl.Term
ScheduleConstraints = _isl.ScheduleConstraints
ScheduleNode = _isl.ScheduleNode
Schedule = _isl.Schedule
AccessInfo = _isl.AccessInfo
Flow = _isl.Flow
Restriction = _isl.Restriction
UnionAccessInfo = _isl.UnionAccessInfo
UnionFlow = _isl.UnionFlow
AstExpr = _isl.AstExpr
AstNode = _isl.AstNode
AstPrintOptions = _isl.AstPrintOptions
AstBuild = _isl.AstBuild

error = _isl.error
stat = _isl.stat
dim_type = _isl.dim_type
schedule_node_type = _isl.schedule_node_type
ast_expr_op_type = _isl.ast_expr_op_type
ast_expr_type = _isl.ast_expr_type
ast_node_type = _isl.ast_node_type
ast_loop_type = _isl.ast_loop_type
fold = _isl.fold
format = _isl.format
yaml_style = _isl.yaml_style
bound = _isl.bound
on_error = _isl.on_error
schedule_algorithm = _isl.schedule_algorithm

# backward compatibility
ast_op_type = _isl.ast_expr_op_type

# ALL_CLASSES is deliberately not listed by hand: it is computed from the
# contents of islpy._isl after this section, so no class can be left out.
195
+
196
+ # }}}
197
+
198
+
199
# Dimension types visited by the var/id/coefficient helpers when the caller
# does not request a specific dim_type. Kept as a list: callers copy it and
# remove entries from the copy.
_CHECK_DIM_TYPES = [dim_type.in_, dim_type.param, dim_type.set]

# Every attribute of the extension module whose name starts with an
# uppercase letter (i.e. the wrapped classes).
ALL_CLASSES = tuple(
    getattr(_isl, attr_name)
    for attr_name in dir(_isl)
    if attr_name[0].isupper())

# "Expression-like" classes: those with "Aff" or "Polynomial" in their name.
EXPR_CLASSES = tuple(
    klass for klass in ALL_CLASSES
    if any(marker in klass.__name__ for marker in ("Aff", "Polynomial")))

DEFAULT_CONTEXT = Context()
207
+
208
+
209
def _get_default_context():
    """Return the module-wide :data:`DEFAULT_CONTEXT`.

    This exists as a module-level callable so that ``__reduce__`` can refer
    to the default context, rather than to a fresh one, when pickling.
    """
    default_ctx = DEFAULT_CONTEXT
    return default_ctx
214
+
215
+
216
def _read_from_str_wrapper(cls, context, s, dims_with_apostrophes):
    """Reconstitute an instance of *cls* from its string form, for the
    benefit of Python's ``__reduce__`` protocol.

    :arg cls: the class to instantiate; must provide ``read_from_str``.
    :arg context: the context passed to ``read_from_str``.
    :arg s: the string to parse.
    :arg dims_with_apostrophes: a :class:`dict` mapping dimension names that
        contain apostrophes to ``(dim_type, index)`` tuples. Parsing loses
        those apostrophes, so the names are re-applied after reading.
    :returns: the reconstructed object.
    """
    obj = cls.read_from_str(context, s)

    # Apostrophes in dim names have been lost, put them back.
    # (The loop variable is named ``dt`` so it does not shadow the
    # module-level ``dim_type`` enum.)
    for dim_name, (dt, dim_idx) in dims_with_apostrophes.items():
        obj = obj.set_dim_name(dt, dim_idx, dim_name)

    return obj
227
+
228
+
229
def _add_functionality():
    """Install Pythonic conveniences on the classes wrapped by
    :mod:`islpy._isl`.

    This monkeypatches the following onto the extension classes:

    - pickling support (``__reduce__``);
    - ``__str__`` and ``__repr__``;
    - set-like and arithmetic operators;
    - rich comparisons;
    - helper methods: coefficient access by name, enumeration of
      pieces/constraints/basic sets, and projection/elimination helpers.

    It is invoked once, right after its definition, at import time.
    """
    import islpy._isl as _isl  # noqa

    # {{{ dim_type

    def dim_type_reduce(v):
        # pickle a dim_type as a call dim_type(<integer value>)
        return (dim_type, (int(v),))

    dim_type.__reduce__ = dim_type_reduce

    # }}}

    # {{{ Context

    def context_reduce(self):
        # The default context unpickles to DEFAULT_CONTEXT itself; any other
        # context unpickles to a freshly created Context.
        if self._wraps_same_instance_as(DEFAULT_CONTEXT):
            return (_get_default_context, ())
        else:
            return (Context, ())

    def context_eq(self, other):
        return isinstance(other, Context) and self._wraps_same_instance_as(other)

    def context_ne(self, other):
        return not self.__eq__(other)

    Context.__reduce__ = context_reduce
    Context.__eq__ = context_eq
    Context.__ne__ = context_ne

    # }}}

    # {{{ generic initialization, pickling

    def generic_reduce(self):
        ctx = self.get_ctx()
        prn = Printer.to_str(ctx)
        prn = getattr(prn, f"print_{self._base_name}")(self)

        # Reconstructing from string will remove apostrophes in dim names,
        # so keep track of dim names with apostrophes.
        # get_var_dict is only installed (below) on classes that have
        # get_space. Space-less classes that can still be read from a string
        # (presumably e.g. Val) have no named dims to restore.
        get_var_dict = getattr(self, "get_var_dict", None)
        if get_var_dict is None:
            dims_with_apostrophes = {}
        else:
            dims_with_apostrophes = {
                dname: pos for dname, pos in get_var_dict().items()
                if "'" in dname}

        return (
            _read_from_str_wrapper,
            (type(self), ctx, prn.get_str(), dims_with_apostrophes))

    for cls in ALL_CLASSES:
        if hasattr(cls, "read_from_str"):
            cls.__reduce__ = generic_reduce

    # }}}

    # {{{ printing

    def generic_str(self):
        prn = Printer.to_str(self.get_ctx())
        # Keep the printer returned by print_*, as generic_reduce does.
        prn = getattr(prn, f"print_{self._base_name}")(self)
        return prn.get_str()

    def generic_repr(self):
        return f'{type(self).__name__}("{generic_str(self)}")'

    for cls in ALL_CLASSES:
        if (hasattr(cls, "_base_name")
                and hasattr(Printer, f"print_{cls._base_name}")):
            cls.__str__ = generic_str
            cls.__repr__ = generic_repr

        # NOTE(review): hasattr(cls, "__hash__") is true even when
        # cls.__hash__ is None, so this check cannot fire as written.
        if not hasattr(cls, "__hash__"):
            raise AssertionError(f"not hashable: {cls}")

    # }}}

    # {{{ Python set-like behavior

    def obj_or(self, other):
        try:
            return self.union(other)
        except TypeError:
            return NotImplemented

    def obj_and(self, other):
        try:
            return self.intersect(other)
        except TypeError:
            return NotImplemented

    def obj_sub(self, other):
        try:
            return self.subtract(other)
        except TypeError:
            return NotImplemented

    for cls in [BasicSet, BasicMap, Set, Map]:
        cls.__or__ = obj_or
        cls.__ror__ = obj_or
        cls.__and__ = obj_and
        cls.__rand__ = obj_and
        cls.__sub__ = obj_sub

    # }}}

    # {{{ Space

    def space_get_id_dict(self, dimtype=None):
        """Return a dictionary mapping variable :class:`Id` instances to tuples
        of (:class:`dim_type`, index).

        :param dimtype: None to get all variables, otherwise
            one of :class:`dim_type`.
        :raises RuntimeError: if the same :class:`Id` occurs more than once.
        """
        result = {}

        def set_dim_id(name, tp, idx):
            if name in result:
                raise RuntimeError(f"non-unique var id '{name}' encountered")
            result[name] = tp, idx

        if dimtype is None:
            types = _CHECK_DIM_TYPES
        else:
            types = [dimtype]

        for tp in types:
            for i in range(self.dim(tp)):
                name = self.get_dim_id(tp, i)
                if name is not None:
                    set_dim_id(name, tp, i)

        return result

    def space_get_var_dict(self, dimtype=None, ignore_out=False):
        """Return a dictionary mapping variable names to tuples of
        (:class:`dim_type`, index).

        :param dimtype: None to get all variables, otherwise
            one of :class:`dim_type`.
        :param ignore_out: if *True* and *dimtype* is None, skip the
            ``out`` dimensions.
        :raises RuntimeError: if the same name occurs more than once.
        """
        result = {}

        def set_dim_name(name, tp, idx):
            if name in result:
                raise RuntimeError(f"non-unique var name '{name}' encountered")
            result[name] = tp, idx

        if dimtype is None:
            types = _CHECK_DIM_TYPES
            if ignore_out:
                # Copy so the shared module-level list is not modified.
                # _CHECK_DIM_TYPES lists dim_type.set; removing dim_type.out
                # relies on isl's set and out dim types comparing equal.
                types = types[:]
                types.remove(dim_type.out)
        else:
            types = [dimtype]

        for tp in types:
            for i in range(self.dim(tp)):
                name = self.get_dim_name(tp, i)
                if name is not None:
                    set_dim_name(name, tp, i)

        return result

    def space_create_from_names(ctx, set=None, in_=None, out=None, params=()):
        """Create a :class:`Space` from lists of variable names.

        Pass either *set* (for a set space) or both *in_* and *out* (for a
        map space).

        :param set: names of `set`-type variables.
        :param in_: names of `in`-type variables.
        :param out: names of `out`-type variables.
        :param params: names of parameter-type variables.
        :raises RuntimeError: for any other combination of arguments.
        """
        dt = dim_type

        if set is not None:
            if in_ is not None or out is not None:
                raise RuntimeError("must pass only one of set / (in_,out)")

            result = Space.set_alloc(ctx, nparam=len(params),
                    dim=len(set))

            for i, name in enumerate(set):
                result = result.set_dim_name(dt.set, i, name)

        elif in_ is not None and out is not None:
            # set is known to be None in this branch.
            result = Space.alloc(ctx, nparam=len(params),
                    n_in=len(in_), n_out=len(out))

            for i, name in enumerate(in_):
                result = result.set_dim_name(dt.in_, i, name)

            for i, name in enumerate(out):
                result = result.set_dim_name(dt.out, i, name)
        else:
            raise RuntimeError("invalid parameter combination")

        for i, name in enumerate(params):
            result = result.set_dim_name(dt.param, i, name)

        return result

    Space.create_from_names = staticmethod(space_create_from_names)
    Space.get_var_dict = space_get_var_dict
    Space.get_id_dict = space_get_id_dict

    # }}}

    # {{{ coefficient wrangling

    def obj_set_coefficients(self, dim_tp, args):
        """
        :param dim_tp: :class:`dim_type`
        :param args: :class:`list` of coefficients, for indices `0..len(args)-1`.

        .. versionchanged:: 2011.3
            New for :class:`Aff`
        """
        for i, coeff in enumerate(args):
            self = self.set_coefficient_val(dim_tp, i, coeff)

        return self

    def obj_set_coefficients_by_name(self, iterable, name_to_dim=None):
        """Set the coefficients and the constant.

        :param iterable: a :class:`dict` or iterable of :class:`tuple`
            instances mapping variable names to their coefficients.
            The constant is set to the value of the (integer) key ``1``.
        :param name_to_dim: optional mapping from names to
            (:class:`dim_type`, index) tuples. Defaults to the
            ``get_var_dict()`` of the object's space.

        .. versionchanged:: 2011.3
            New for :class:`Aff`
        """
        try:
            iterable = list(iterable.items())
        except AttributeError:
            pass

        if name_to_dim is None:
            name_to_dim = self.get_space().get_var_dict()

        for name, coeff in iterable:
            if name == 1:
                self = self.set_constant_val(coeff)
            else:
                tp, idx = name_to_dim[name]
                self = self.set_coefficient_val(tp, idx, coeff)

        return self

    def obj_get_coefficients_by_name(self, dimtype=None, dim_to_name=None):
        """Return a dictionary mapping variable names to coefficients.

        Only nonzero coefficients are included. A nonzero constant is stored
        under the (integer) key ``1``.

        :param dimtype: None to get all variables, otherwise
            one of :class:`dim_type`.
        :param dim_to_name: optional mapping from (:class:`dim_type`, index)
            tuples to names. By default, names are queried from the object.

        .. versionchanged:: 2011.3
            New for :class:`Aff`
        """
        if dimtype is None:
            types = _CHECK_DIM_TYPES
        else:
            types = [dimtype]

        result = {}
        for tp in types:
            for i in range(self.get_space().dim(tp)):
                coeff = self.get_coefficient_val(tp, i)
                if coeff:
                    if dim_to_name is None:
                        name = self.get_dim_name(tp, i)
                    else:
                        name = dim_to_name[(tp, i)]

                    result[name] = coeff

        const = self.get_constant_val()
        if const:
            result[1] = const

        return result

    for coeff_class in [Constraint, Aff]:
        coeff_class.set_coefficients = obj_set_coefficients
        coeff_class.set_coefficients_by_name = obj_set_coefficients_by_name
        coeff_class.get_coefficients_by_name = obj_get_coefficients_by_name

    # }}}

    # {{{ Id

    Id.user = property(Id.get_user)
    Id.name = property(Id.get_name)

    # }}}

    # {{{ Constraint

    def eq_from_names(space, coefficients=None):
        """Create a constraint `const + coeff_1*var_1 +... == 0`.

        :param space: :class:`Space`
        :param coefficients: a :class:`dict` or iterable of :class:`tuple`
            instances mapping variable names to their coefficients.
            The constant is set to the value of the (integer) key ``1``.

        .. versionchanged:: 2011.3
            Eliminated the separate *const* parameter.
        """
        if coefficients is None:
            coefficients = {}
        c = Constraint.equality_alloc(space)
        return c.set_coefficients_by_name(coefficients)

    def ineq_from_names(space, coefficients=None):
        """Create a constraint `const + coeff_1*var_1 +... >= 0`.

        :param space: :class:`Space`
        :param coefficients: a :class:`dict` or iterable of :class:`tuple`
            instances mapping variable names to their coefficients.
            The constant is set to the value of the (integer) key ``1``.

        .. versionchanged:: 2011.3
            Eliminated the separate *const* parameter.
        """
        if coefficients is None:
            coefficients = {}
        c = Constraint.inequality_alloc(space)
        return c.set_coefficients_by_name(coefficients)

    Constraint.eq_from_names = staticmethod(eq_from_names)
    Constraint.ineq_from_names = staticmethod(ineq_from_names)

    # }}}

    def basic_obj_get_constraints(self):
        """Get a list of constraints."""
        result = []
        self.foreach_constraint(result.append)
        return result

    # {{{ BasicSet

    BasicSet.get_constraints = basic_obj_get_constraints

    # }}}

    # {{{ BasicMap

    BasicMap.get_constraints = basic_obj_get_constraints

    # }}}

    # {{{ Set

    def set_get_basic_sets(self):
        """Get the list of :class:`BasicSet` instances in this :class:`Set`."""
        result = []
        self.foreach_basic_set(result.append)
        return result

    Set.get_basic_sets = set_get_basic_sets
    BasicSet.get_basic_sets = set_get_basic_sets

    # }}}

    # {{{ Map

    def map_get_basic_maps(self):
        """Get the list of :class:`BasicMap` instances in this :class:`Map`."""
        result = []
        self.foreach_basic_map(result.append)
        return result

    Map.get_basic_maps = map_get_basic_maps

    # }}}

    # {{{ common functionality

    def obj_get_id_dict(self, dimtype=None):
        """Return a dictionary mapping :class:`Id` instances to tuples of
        (:class:`dim_type`, index).

        :param dimtype: None to get all variables, otherwise
            one of :class:`dim_type`.
        """
        return self.get_space().get_id_dict(dimtype)

    @_memoize_on_first_arg
    def obj_get_var_dict(self, dimtype=None):
        """Return a dictionary mapping variable names to tuples of
        (:class:`dim_type`, index).

        :param dimtype: None to get all variables, otherwise
            one of :class:`dim_type`.
        """
        return self.get_space().get_var_dict(
                dimtype, ignore_out=isinstance(self, EXPR_CLASSES))

    def obj_get_var_ids(self, dimtype):
        """Return a list of :class:`Id` instances for :class:`dim_type` *dimtype*."""
        # NOTE(review): despite its name and docstring, this returns dim
        # *names* (via get_dim_name), exactly like get_var_names. It is kept
        # as-is for backward compatibility; get_dim_id may have been intended.
        return [self.get_dim_name(dimtype, i) for i in range(self.dim(dimtype))]

    @_memoize_on_first_arg
    def obj_get_var_names(self, dimtype):
        """Return a list of dim names (in order) for :class:`dim_type` *dimtype*."""
        return [self.get_dim_name(dimtype, i) for i in range(self.dim(dimtype))]

    for cls in ALL_CLASSES:
        if hasattr(cls, "get_space") and cls is not Space:
            cls.get_id_dict = obj_get_id_dict
            cls.get_var_dict = obj_get_var_dict
            cls.get_var_ids = obj_get_var_ids
            cls.get_var_names = obj_get_var_names
            cls.space = property(cls.get_space)

    # }}}

    # {{{ piecewise

    def pwaff_get_pieces(self):
        """
        :return: list of (:class:`Set`, :class:`Aff`)
        """

        result = []

        def append_tuple(*args):
            result.append(args)

        self.foreach_piece(append_tuple)
        return result

    def pwqpolynomial_get_pieces(self):
        """
        :return: list of (:class:`Set`, :class:`QPolynomial`)
        """

        result = []

        def append_tuple(*args):
            result.append(args)

        self.foreach_piece(append_tuple)
        return result

    def pw_get_aggregate_domain(self):
        """
        :return: a :class:`Set` that is the union of the domains of all pieces
        """

        result = Set.empty(self.get_domain_space())
        for dom, _ in self.get_pieces():
            result = result.union(dom)

        return result

    PwAff.get_pieces = pwaff_get_pieces
    Aff.get_pieces = pwaff_get_pieces
    PwAff.get_aggregate_domain = pw_get_aggregate_domain

    PwQPolynomial.get_pieces = pwqpolynomial_get_pieces
    PwQPolynomial.get_aggregate_domain = pw_get_aggregate_domain

    # }}}

    # {{{ QPolynomial

    def qpolynomial_get_terms(self):
        """Get the list of :class:`Term` instances in this :class:`QPolynomial`."""
        result = []
        self.foreach_term(result.append)
        return result

    QPolynomial.get_terms = qpolynomial_get_terms

    # }}}

    # {{{ PwQPolynomial

    def pwqpolynomial_eval_with_dict(self, value_dict):
        """Evaluates *self* for the parameters specified by
        *value_dict*, which maps parameter names to their values.
        """

        pt = Point.zero(self.space.params())

        for i in range(self.space.dim(dim_type.param)):
            par_name = self.space.get_dim_name(dim_type.param, i)
            pt = pt.set_coordinate_val(
                dim_type.param, i, value_dict[par_name])

        return self.eval(pt).to_python()

    PwQPolynomial.eval_with_dict = pwqpolynomial_eval_with_dict

    # }}}

    # {{{ arithmetic

    def _number_to_expr_like(template, num):
        """Convert the number *num* into an object of the same kind as
        *template* (Aff, QPolynomial, PwAff or PwQPolynomial), defined on
        *template*'s domain.
        """
        number_aff = Aff.zero_on_domain(template.get_domain_space())
        number_aff = number_aff.set_constant_val(num)

        if isinstance(template, Aff):
            return number_aff
        if isinstance(template, QPolynomial):
            return QPolynomial.from_aff(number_aff)

        # everything else is piecewise

        pieces = template.get_pieces()
        if pieces:
            number_pw_aff = PwAff.empty(template.get_space())
            # 'dom' rather than 'set', to avoid shadowing the builtin
            for dom, _ in pieces:
                number_pw_aff = dom.indicator_function().cond(
                        number_aff, number_pw_aff)
        else:
            number_pw_aff = PwAff.alloc(
                    Set.universe(template.domain().space),
                    number_aff)

        if isinstance(template, PwAff):
            return number_pw_aff

        elif isinstance(template, PwQPolynomial):
            return PwQPolynomial.from_pw_aff(number_pw_aff)

        else:
            raise TypeError("unexpected template type")

    ARITH_CLASSES = (Aff, PwAff, QPolynomial, PwQPolynomial)  # noqa

    def expr_like_add(self, other):
        if not isinstance(other, ARITH_CLASSES):
            other = _number_to_expr_like(self, other)

        try:
            return self.add(other)
        except TypeError:
            return NotImplemented

    def expr_like_sub(self, other):
        if not isinstance(other, ARITH_CLASSES):
            other = _number_to_expr_like(self, other)

        try:
            return self.sub(other)
        except TypeError:
            return NotImplemented

    def expr_like_rsub(self, other):
        if not isinstance(other, ARITH_CLASSES):
            other = _number_to_expr_like(self, other)

        return -self + other

    def expr_like_mul(self, other):
        if not isinstance(other, ARITH_CLASSES):
            other = _number_to_expr_like(self, other)

        try:
            return self.mul(other)
        except TypeError:
            return NotImplemented

    def expr_like_floordiv(self, other):
        return self.scale_down_val(other).floor()

    for expr_like_class in ARITH_CLASSES:
        expr_like_class.__add__ = expr_like_add
        expr_like_class.__radd__ = expr_like_add
        expr_like_class.__sub__ = expr_like_sub
        expr_like_class.__rsub__ = expr_like_rsub
        expr_like_class.__mul__ = expr_like_mul
        expr_like_class.__rmul__ = expr_like_mul
        expr_like_class.__neg__ = expr_like_class.neg

    for qpoly_class in [QPolynomial, PwQPolynomial]:
        qpoly_class.__pow__ = qpoly_class.pow

    for aff_class in [Aff, PwAff]:
        aff_class.__mod__ = aff_class.mod_val
        aff_class.__floordiv__ = expr_like_floordiv

    # }}}

    # {{{ Val

    def val_rsub(self, other):
        return -self + other

    def val_bool(self):
        return not self.is_zero()

    def val_repr(self):
        return f'{type(self).__name__}("{self.to_str()}")'

    def val_to_python(self):
        if not self.is_int():
            raise ValueError("can only convert integer Val to python")

        return int(self.to_str())

    Val.__add__ = Val.add
    Val.__radd__ = Val.add
    Val.__sub__ = Val.sub
    Val.__rsub__ = val_rsub
    Val.__mul__ = Val.mul
    Val.__rmul__ = Val.mul
    Val.__neg__ = Val.neg
    Val.__mod__ = Val.mod
    Val.__bool__ = Val.__nonzero__ = val_bool

    Val.__lt__ = Val.lt
    Val.__gt__ = Val.gt
    Val.__le__ = Val.le
    Val.__ge__ = Val.ge
    Val.__eq__ = Val.eq
    Val.__ne__ = Val.ne

    Val.__repr__ = val_repr
    Val.__str__ = Val.to_str
    Val.to_python = val_to_python

    # }}}

    # {{{ rich comparisons

    def obj_eq(self, other):
        # Objects without an isl context (None, numbers, strings, ...) are not
        # comparable with isl objects. Let Python fall back to its default
        # comparison instead of raising AttributeError.
        try:
            other_ctx = other.get_ctx()
        except AttributeError:
            return NotImplemented

        assert self.get_ctx() == other_ctx, (
            "Equality-comparing two objects from different ISL Contexts "
            "will likely lead to entertaining (but never useful) results. "
            "In particular, Spaces with matching names will no longer be "
            "equal.")

        try:
            return self.is_equal(other)
        except TypeError:
            # Incompatible isl types (argument rejected by is_equal); this
            # matches the TypeError handling of the set-like operators above.
            return NotImplemented

    def obj_ne(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            # must not be negated: NotImplemented is not a boolean
            return result
        return not result

    for cls in ALL_CLASSES:
        if hasattr(cls, "is_equal"):
            cls.__eq__ = obj_eq
            cls.__ne__ = obj_ne

    def obj_lt(self, other):
        return self.is_strict_subset(other)

    def obj_le(self, other):
        return self.is_subset(other)

    def obj_gt(self, other):
        return other.is_strict_subset(self)

    def obj_ge(self, other):
        return other.is_subset(self)

    for cls in [BasicSet, BasicMap, Set, Map]:
        cls.__lt__ = obj_lt
        cls.__le__ = obj_le
        cls.__gt__ = obj_gt
        cls.__ge__ = obj_ge

    # }}}

    # {{{ project_out_except

    def obj_project_out_except(obj, names, types):
        """
        :param types: list of :class:`dim_type` determining
            the types of axes to project out
        :param names: names of axes matching the above which
            should be left alone by the projection

        .. versionadded:: 2011.3
        """

        for tp in types:
            while True:
                space = obj.get_space()
                var_dict = space.get_var_dict(tp)

                all_indices = set(range(space.dim(tp)))
                leftover_indices = {var_dict[name][1] for name in names
                        if name in var_dict}
                project_indices = all_indices-leftover_indices
                if not project_indices:
                    break

                # project out the first contiguous run of unwanted indices,
                # then recompute (indices shift after projection)
                min_index = min(project_indices)
                count = 1
                while min_index+count in project_indices:
                    count += 1

                obj = obj.project_out(tp, min_index, count)

        return obj

    # }}}

    # {{{ eliminate_except

    def obj_eliminate_except(obj, names, types):
        """
        :param types: list of :class:`dim_type` determining
            the types of axes to eliminate
        :param names: names of axes matching the above which
            should be left alone by the eliminate

        .. versionadded:: 2011.3
        """

        for tp in types:
            space = obj.get_space()
            var_dict = space.get_var_dict(tp)
            to_eliminate = (
                    set(range(space.dim(tp)))
                    - {var_dict[name][1] for name in names
                        if name in var_dict})

            while to_eliminate:
                min_index = min(to_eliminate)
                count = 1
                while min_index+count in to_eliminate:
                    count += 1

                obj = obj.eliminate(tp, min_index, count)

                to_eliminate -= set(range(min_index, min_index+count))

        return obj

    # }}}

    # {{{ add_constraints

    def obj_add_constraints(obj, constraints):
        """
        .. versionadded:: 2011.3
        """

        for cns in constraints:
            obj = obj.add_constraint(cns)

        return obj

    # }}}

    for c in [BasicSet, BasicMap, Set, Map]:
        c.project_out_except = obj_project_out_except
        c.add_constraints = obj_add_constraints

    for c in [BasicSet, Set]:
        c.eliminate_except = obj_eliminate_except


_add_functionality()
991
+
992
+
993
def _back_to_basic(new_obj, old_obj):
    """Convert *new_obj* back to the basic type of *old_obj*, if needed.

    Some operations (e.g. ``set_dim_id``) are not available for
    :class:`BasicSet`/:class:`BasicMap`, so applying them to a basic object
    produces a :class:`Set`/:class:`Map`. This function undoes that upcast.

    :returns: a :class:`BasicSet` (resp. :class:`BasicMap`) if *old_obj* is
        one and *new_obj* is a :class:`Set` (resp. :class:`Map`). Otherwise,
        *new_obj* unchanged.
    :raises ValueError: if *new_obj* consists of more than one basic
        set/map.
    """
    # Work around set_dim_id not being available for Basic{Set,Map}
    if isinstance(old_obj, BasicSet) and isinstance(new_obj, Set):
        bsets = new_obj.get_basic_sets()

        if not bsets:
            # Build the empty basic set directly. isl offers no complement for
            # basic sets, so universe(...).complement() cannot be relied on to
            # yield a BasicSet.
            return BasicSet.empty(new_obj.space)

        bset, = bsets
        return bset

    if isinstance(old_obj, BasicMap) and isinstance(new_obj, Map):
        bmaps = new_obj.get_basic_maps()

        if not bmaps:
            # same reasoning as for BasicSet above
            return BasicMap.empty(new_obj.space)

        bmap, = bmaps
        return bmap

    return new_obj
1016
+
1017
+
1018
def _set_dim_id(obj, dt, idx, id):
    """Set the :class:`Id` of dimension *idx* of type *dt* on *obj*.

    The result is converted back to the basic type of *obj* when needed,
    since ``set_dim_id`` is not available for Basic{Set,Map}.
    """
    renamed = obj.set_dim_id(dt, idx, id)
    return _back_to_basic(renamed, obj)
1020
+
1021
+
1022
def _align_dim_type(template_dt, obj, template, obj_bigger_ok, obj_names,
        template_names):
    """Make the *template_dt* dimensions of *obj* match those of *template*,
    pairing dimensions by name.

    Walking the template's dimensions in order:

    - a dimension whose name occurs in both objects is moved into the
      template's position in *obj* and given the template's :class:`Id`;
    - a dimension that only the template has is inserted into *obj*.

    :arg template_dt: the :class:`dim_type` being aligned.
    :arg obj: the object to align.
    :arg template: an isl object or a :class:`Space` supplying the target
        dimensions.
    :arg obj_bigger_ok: if false, raise :class:`Error` when *obj* still has
        extra *template_dt* dimensions after alignment.
    :arg obj_names: dimension names of *obj* (may contain *None*). These are
        presumably the names of the *template_dt* dims, computed by the
        caller -- TODO confirm.
    :arg template_names: likewise, for *template*.
    :returns: the aligned object.
    :raises Error: if *template* has unnamed dimensions (unless both name
        lists are entirely unnamed), or if *obj* has leftover dimensions
        and *obj_bigger_ok* is false.
    """

    # {{{ deal with Aff, PwAff

    # The technique below will not work for PwAff et al, because there is *only*
    # the 'param' dim_type, and we are not allowed to move dims around in there.
    # We'll make isl do the work, using align_params.

    if template_dt == dim_type.param and isinstance(obj, (Aff, PwAff)):
        if not isinstance(template, Space):
            template_space = template.space
        else:
            template_space = template

        if not obj_bigger_ok:
            # obj must not have params that the template lacks
            if (obj.dim(template_dt) > template.dim(template_dt)
                    or not set(obj.get_var_dict()) <= set(template.get_var_dict())):
                raise Error("obj has leftover dimensions after alignment")
        return obj.align_params(template_space)

    # }}}

    if None in template_names:
        all_nones = [None] * len(template_names)
        if template_names == all_nones and obj_names == all_nones:
            # that's ok
            return obj

        raise Error("template may not contain any unnamed dimensions")

    obj_names = set(obj_names) - {None}
    template_names = set(template_names) - {None}

    names_in_both = obj_names & template_names

    # Invariant: after each iteration, dimension tgt_idx of obj carries the
    # template's Id for that position.
    tgt_idx = 0
    while tgt_idx < template.dim(template_dt):
        tgt_id = template.get_dim_id(template_dt, tgt_idx)
        tgt_name = tgt_id.name

        if tgt_name in names_in_both:
            if (obj.dim(template_dt) > tgt_idx
                    and tgt_name == obj.get_dim_name(template_dt, tgt_idx)):
                # already at the right position
                pass

            else:
                src_dt, src_idx = obj.get_var_dict()[tgt_name]

                if src_dt == template_dt:
                    # positions before tgt_idx are already aligned, so the
                    # source must lie further right
                    assert src_idx > tgt_idx

                    # isl requires move_dims to be between different types.
                    # Not sure why. Let's make it happy.
                    other_dt = dim_type.param
                    if src_dt == other_dt:
                        other_dt = dim_type.out

                    # park the dim at the end of other_dt, then move it into place
                    other_dt_dim = obj.dim(other_dt)
                    obj = obj.move_dims(other_dt, other_dt_dim, src_dt, src_idx, 1)
                    obj = obj.move_dims(
                            template_dt, tgt_idx, other_dt, other_dt_dim, 1)
                else:
                    obj = obj.move_dims(template_dt, tgt_idx, src_dt, src_idx, 1)

            # names are same, make Ids the same, too
            obj = _set_dim_id(obj, template_dt, tgt_idx, tgt_id)

            tgt_idx += 1
        else:
            # dim only exists in the template: insert a fresh one into obj
            obj = obj.insert_dims(template_dt, tgt_idx, 1)
            obj = _set_dim_id(obj, template_dt, tgt_idx, tgt_id)

            tgt_idx += 1

    if tgt_idx < obj.dim(template_dt) and not obj_bigger_ok:
        raise Error("obj has leftover dimensions after alignment")

    return obj
1101
+
1102
+
1103
def align_spaces(obj, template, obj_bigger_ok=False, across_dim_types=None):
    """
    Try to make the space in which *obj* lives the same as that of *template*
    by adding/matching named dimensions.

    :arg obj: the object whose dimensions are to be aligned.
    :arg template: the object whose dimension names and order *obj* should
        follow.
    :param obj_bigger_ok: If *True*, no error is raised if the resulting *obj*
        has more dimensions than *template*.
    :arg across_dim_types: deprecated and ignored; passing anything other
        than *None* emits a :class:`DeprecationWarning`.
    :returns: the aligned version of *obj*.

    If both arguments are :class:`Set`/:class:`BasicSet` instances, any of
    them that is a parameter domain (``is_params()``) is first converted with
    ``type(x).from_params(x)``.
    """
    if across_dim_types is not None:
        from warnings import warn
        warn("across_dim_types is deprecated and should no longer be used. "
                "It never had any effect anyway.",
                DeprecationWarning, stacklevel=2)

    set_like = (Set, BasicSet)
    if isinstance(obj, set_like) and isinstance(template, set_like):
        # Lift parameter domains to ordinary sets before matching names.
        if obj.is_params():
            obj = type(obj).from_params(obj)
        if template.is_params():
            template = type(template).from_params(template)

    if not isinstance(template, EXPR_CLASSES):
        dim_types = _CHECK_DIM_TYPES
    else:
        # Expression templates do not take part in 'out' alignment.
        dim_types = list(_CHECK_DIM_TYPES)
        dim_types.remove(dim_type.out)

    def names_of(o):
        return [o.get_dim_name(dt, i)
                for dt in dim_types
                for i in range(o.dim(dt))]

    # Both name lists are taken before any alignment happens.
    obj_names = names_of(obj)
    template_names = names_of(template)

    for dt in dim_types:
        obj = _align_dim_type(
                dt, obj, template, obj_bigger_ok, obj_names, template_names)

    return obj
1150
+
1151
+
1152
def align_two(obj1, obj2, across_dim_types=None):
    """Align the spaces of two objects, potentially modifying both of them.

    *obj1* is first aligned to *obj2*; then *obj2* is aligned to the
    already-aligned *obj1*. Both steps use ``obj_bigger_ok=True``.

    :arg across_dim_types: deprecated and ignored; passing anything other
        than *None* emits a :class:`DeprecationWarning`.
    :returns: a tuple ``(obj1, obj2)`` of the aligned objects.

    See also :func:`align_spaces`.
    """
    if across_dim_types is not None:
        from warnings import warn
        warn("across_dim_types is deprecated and should no longer be used. "
                "It never had any effect anyway.",
                DeprecationWarning, stacklevel=2)

    first = align_spaces(obj1, obj2, obj_bigger_ok=True)
    second = align_spaces(obj2, first, obj_bigger_ok=True)
    return first, second
1167
+
1168
+
1169
def make_zero_and_vars(set_vars, params=(), ctx=None):
    """
    :arg set_vars: an iterable of variable names, or a comma-separated string
    :arg params: an iterable of variable names, or a comma-separated string
    :arg ctx: the context in which to create the space; if *None*,
        ``DEFAULT_CONTEXT`` is used.

    :return: a dictionary from variable names (in *set_vars* and *params*)
        to :class:`PwAff` instances that represent each of the
        variables. The integer key ``0`` is also included and represents
        a :class:`PwAff` zero constant.

    When a comma-separated string is given, whitespace around each name is
    stripped and empty entries (e.g. from ``""`` or a trailing comma) are
    ignored.

    .. versionadded:: 2016.1.1

    This function is intended to make it relatively easy to construct sets
    programmatically without resorting to string manipulation.

    Usage example::

        v = isl.make_zero_and_vars("i,j,k", "n")

        myset = (
                v[0].le_set(v["i"] + v["j"])
                &
                (v["i"] + v["j"]).lt_set(v["n"])
                &
                (v[0].le_set(v["i"]))
                &
                (v["i"].le_set(13 + v["n"]))
                )
    """
    if ctx is None:
        ctx = DEFAULT_CONTEXT

    def split_names(names):
        # Non-string iterables are passed through unchanged.
        if not isinstance(names, str):
            return names
        # "".split(",") == [""], so without filtering, params="" (or a
        # trailing comma) would create a dimension with an empty name.
        stripped = (s.strip() for s in names.split(","))
        return [name for name in stripped if name]

    space = Space.create_from_names(
            ctx, set=split_names(set_vars), params=split_names(params))
    return affs_from_space(space)
1208
+
1209
+
1210
def affs_from_space(space):
    """
    :arg space: the space whose variables are to be represented.

    :return: a dictionary from variable names to :class:`PwAff` instances
        that represent each of those variables. The names are those reported
        by ``get_var_dict()`` for a zero :class:`Aff` on the domain
        ``LocalSpace.from_space(space)``. The integer key ``0`` is also
        included and represents a :class:`PwAff` zero constant.

    .. versionadded:: 2016.2

    This function is intended to make it relatively easy to construct sets
    programmatically without resorting to string manipulation.

    Usage example::

        s = isl.Set("[n] -> {[i,j,k]: 0<=i,j,k<n}")
        v = isl.affs_from_space(s.space)

        myset = (
                v[0].le_set(v["i"] + v["j"])
                &
                (v["i"] + v["j"]).lt_set(v["n"])
                &
                (v[0].le_set(v["i"]))
                &
                (v["i"].le_set(13 + v["n"]))
                )
    """
    zero = Aff.zero_on_domain(LocalSpace.from_space(space))

    result = {0: PwAff.from_aff(zero)}

    # Each variable is the zero expression with coefficient 1 on that
    # variable's (dim_type, index) position.
    for name, (dt, idx) in zero.get_var_dict().items():
        result[name] = PwAff.from_aff(zero.set_coefficient_val(dt, idx, 1))

    return result
1248
+
1249
+
1250
class SuppressedWarnings:
    """Deprecated no-op context manager.

    Constructing an instance emits a :class:`DeprecationWarning`; entering
    and exiting it does nothing.
    """

    def __init__(self, ctx):
        # *ctx* is accepted for backward compatibility only; it is unused.
        from warnings import warn
        # stacklevel=2 attributes the warning to the caller's line, so users
        # can find the usage to remove. With stacklevel=1 it pointed into
        # this module, where the default filters hide DeprecationWarning.
        warn("islpy.SuppressedWarnings is a deprecated no-op and will be removed "
                "in 2023. Simply remove the use of it to avoid this warning.",
                DeprecationWarning, stacklevel=2)

    def __enter__(self):
        pass

    def __exit__(self, type, value, traceback):
        # Returning None (falsy) lets any exception propagate.
        pass
1262
+
1263
+
1264
+ # {{{ give sphinx something to import so we can produce docs
1265
+
1266
def _define_doc_link_names():
    """Install placeholder names on :mod:`islpy._isl` for the docs build.

    Attaches an empty class as ``_isl.Div`` so there is something to import
    when producing documentation.
    """
    class Div:
        ...

    setattr(_isl, "Div", Div)


_define_doc_link_names()
1274
+
1275
+ # }}}
1276
+
1277
+
1278
+ # vim: foldmethod=marker