scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,412 +1,412 @@
1
- # -*- coding: utf-8 -*-
2
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- # Many elements of this code were developed in scikit-learn. These elements
4
- # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
- # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
- """Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
7
- import inspect
8
- import pprint
9
- from collections import OrderedDict
10
-
11
- from skbase.base import BaseObject
12
-
13
- # from skbase.config import get_config
14
- from skbase.utils._check import _is_scalar_nan
15
-
16
-
17
- class KeyValTuple(tuple):
18
- """Dummy class for correctly rendering key-value tuples from dicts."""
19
-
20
- def __repr__(self):
21
- """Represent as string."""
22
- # needed for _dispatch[tuple.__repr__] not to be overridden
23
- return super().__repr__()
24
-
25
-
26
- class KeyValTupleParam(KeyValTuple):
27
- """Dummy class for correctly rendering key-value tuples from parameters."""
28
-
29
- pass
30
-
31
-
32
- def _changed_params(base_object):
33
- """Return dict (param_name: value) of parameters with non-default values."""
34
- params = base_object.get_params(deep=False)
35
- init_func = getattr(
36
- base_object.__init__, "deprecated_original", base_object.__init__
37
- )
38
- init_params = inspect.signature(init_func).parameters
39
- init_params = {name: param.default for name, param in init_params.items()}
40
-
41
- def has_changed(k, v):
42
- if k not in init_params: # happens if k is part of a **kwargs
43
- return True
44
- if init_params[k] == inspect._empty: # k has no default value
45
- return True
46
- # try to avoid calling repr on nested BaseObjects
47
- if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
48
- return True
49
- # Use repr as a last resort. It may be expensive.
50
- if repr(v) != repr(init_params[k]) and not (
51
- _is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
52
- ):
53
- return True
54
- return False
55
-
56
- return {k: v for k, v in params.items() if has_changed(k, v)}
57
-
58
-
59
- class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
60
- """Pretty Printer class for BaseObjects.
61
-
62
- This extends the pprint.PrettyPrinter class similar to scikit-learn's
63
- implementation, so that:
64
-
65
- - BaseObjects are printed with their parameters, e.g.
66
- BaseObject(param1=value1, ...) which is not supported by default.
67
- - the 'compact' parameter of PrettyPrinter is ignored for dicts, which
68
- may lead to very long representations that we want to avoid.
69
-
70
- Quick overview of pprint.PrettyPrinter (see also
71
- https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
72
-
73
- - the entry point is the _format() method which calls format() (overridden
74
- here)
75
- - format() directly calls _safe_repr() for a first try at rendering the
76
- object
77
- - _safe_repr formats the whole object recursively, only calling itself,
78
- not caring about line length or anything
79
- - back to _format(), if the output string is too long, _format() then calls
80
- the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
81
- the type of the object. This where the line length and the compact
82
- parameters are taken into account.
83
- - those _pprint_TYPE() methods will internally use the format() method for
84
- rendering the nested objects of an object (e.g. the elements of a list)
85
-
86
- In the end, everything has to be implemented twice: in _safe_repr and in
87
- the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
88
- straightforward to extend (especially when we want a compact output), so
89
- the code is a bit convoluted.
90
-
91
- This class overrides:
92
- - format() to support the changed_only parameter
93
- - _safe_repr to support printing of BaseObjects that fit on a single line
94
- - _format_dict_items so that dict are correctly 'compacted'
95
- - _format_items so that ellipsis is used on long lists and tuples
96
-
97
- When BaseObjects cannot be printed on a single line, the builtin _format()
98
- will call _pprint_object() because it was registered to do so (see
99
- _dispatch[BaseObject.__repr__] = _pprint_object).
100
-
101
- both _format_dict_items() and _pprint_Object() use the
102
- _format_params_or_dict_items() method that will format parameters and
103
- key-value pairs respecting the compact parameter. This method needs another
104
- subroutine _pprint_key_val_tuple() used when a parameter or a key-value
105
- pair is too long to fit on a single line. This subroutine is called in
106
- _format() and is registered as well in the _dispatch dict (just like
107
- _pprint_object). We had to create the two classes KeyValTuple and
108
- KeyValTupleParam for this.
109
- """
110
-
111
- def __init__(
112
- self,
113
- indent=1,
114
- width=80,
115
- depth=None,
116
- stream=None,
117
- *,
118
- compact=False,
119
- indent_at_name=True,
120
- n_max_elements_to_show=None,
121
- changed_only=True,
122
- ):
123
- super().__init__(indent, width, depth, stream, compact=compact)
124
- self._indent_at_name = indent_at_name
125
- if self._indent_at_name:
126
- self._indent_per_level = 1 # ignore indent param
127
- self.changed_only = changed_only
128
- # Max number of elements in a list, dict, tuple until we start using
129
- # ellipsis. This also affects the number of arguments of a BaseObject
130
- # (they are treated as dicts)
131
- self.n_max_elements_to_show = n_max_elements_to_show
132
-
133
- def format(self, obj, context, maxlevels, level): # noqa
134
- return _safe_repr(
135
- obj, context, maxlevels, level, changed_only=self.changed_only
136
- )
137
-
138
- def _pprint_object(self, obj, stream, indent, allowance, context, level):
139
- stream.write(obj.__class__.__name__ + "(")
140
- if self._indent_at_name:
141
- indent += len(obj.__class__.__name__)
142
-
143
- if self.changed_only:
144
- params = _changed_params(obj)
145
- else:
146
- params = obj.get_params(deep=False)
147
-
148
- params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
149
-
150
- self._format_params(
151
- params.items(), stream, indent, allowance + 1, context, level
152
- )
153
- stream.write(")")
154
-
155
- def _format_dict_items(self, items, stream, indent, allowance, context, level):
156
- return self._format_params_or_dict_items(
157
- items, stream, indent, allowance, context, level, is_dict=True
158
- )
159
-
160
- def _format_params(self, items, stream, indent, allowance, context, level):
161
- return self._format_params_or_dict_items(
162
- items, stream, indent, allowance, context, level, is_dict=False
163
- )
164
-
165
- def _format_params_or_dict_items(
166
- self, obj, stream, indent, allowance, context, level, is_dict
167
- ):
168
- """Format dict items or parameters respecting the compact=True parameter.
169
-
170
- For some reason, the builtin rendering of dict items doesn't
171
- respect compact=True and will use one line per key-value if all cannot
172
- fit in a single line.
173
- Dict items will be rendered as <'key': value> while params will be
174
- rendered as <key=value>. The implementation is mostly copy/pasting from
175
- the builtin _format_items().
176
- This also adds ellipsis if the number of items is greater than
177
- self.n_max_elements_to_show.
178
- """
179
- write = stream.write
180
- indent += self._indent_per_level
181
- delimnl = ",\n" + " " * indent
182
- delim = ""
183
- width = max_width = self._width - indent + 1
184
- it = iter(obj)
185
- try:
186
- next_ent = next(it)
187
- except StopIteration:
188
- return
189
- last = False
190
- n_items = 0
191
- while not last:
192
- if n_items == self.n_max_elements_to_show:
193
- write(", ...")
194
- break
195
- n_items += 1
196
- ent = next_ent
197
- try:
198
- next_ent = next(it)
199
- except StopIteration:
200
- last = True
201
- max_width -= allowance
202
- width -= allowance
203
- if self._compact:
204
- k, v = ent
205
- krepr = self._repr(k, context, level)
206
- vrepr = self._repr(v, context, level)
207
- if not is_dict:
208
- krepr = krepr.strip("'")
209
- middle = ": " if is_dict else "="
210
- rep = krepr + middle + vrepr
211
- w = len(rep) + 2
212
- if width < w:
213
- width = max_width
214
- if delim:
215
- delim = delimnl
216
- if width >= w:
217
- width -= w
218
- write(delim)
219
- delim = ", "
220
- write(rep)
221
- continue
222
- write(delim)
223
- delim = delimnl
224
- class_ = KeyValTuple if is_dict else KeyValTupleParam
225
- self._format(
226
- class_(ent), stream, indent, allowance if last else 1, context, level
227
- )
228
-
229
- def _format_items(self, items, stream, indent, allowance, context, level):
230
- """Format the items of an iterable (list, tuple...).
231
-
232
- Same as the built-in _format_items, with support for ellipsis if the
233
- number of elements is greater than self.n_max_elements_to_show.
234
- """
235
- write = stream.write
236
- indent += self._indent_per_level
237
- if self._indent_per_level > 1:
238
- write((self._indent_per_level - 1) * " ")
239
- delimnl = ",\n" + " " * indent
240
- delim = ""
241
- width = max_width = self._width - indent + 1
242
- it = iter(items)
243
- try:
244
- next_ent = next(it)
245
- except StopIteration:
246
- return
247
- last = False
248
- n_items = 0
249
- while not last:
250
- if n_items == self.n_max_elements_to_show:
251
- write(", ...")
252
- break
253
- n_items += 1
254
- ent = next_ent
255
- try:
256
- next_ent = next(it)
257
- except StopIteration:
258
- last = True
259
- max_width -= allowance
260
- width -= allowance
261
- if self._compact:
262
- rep = self._repr(ent, context, level)
263
- w = len(rep) + 2
264
- if width < w:
265
- width = max_width
266
- if delim:
267
- delim = delimnl
268
- if width >= w:
269
- width -= w
270
- write(delim)
271
- delim = ", "
272
- write(rep)
273
- continue
274
- write(delim)
275
- delim = delimnl
276
- self._format(ent, stream, indent, allowance if last else 1, context, level)
277
-
278
- def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
279
- """Pretty printing for key-value tuples from dict or parameters."""
280
- k, v = obj
281
- rep = self._repr(k, context, level)
282
- if isinstance(obj, KeyValTupleParam):
283
- rep = rep.strip("'")
284
- middle = "="
285
- else:
286
- middle = ": "
287
- stream.write(rep)
288
- stream.write(middle)
289
- self._format(
290
- v, stream, indent + len(rep) + len(middle), allowance, context, level
291
- )
292
-
293
- # Follow what scikit-learn did here and copy _dispatch to prevent instances
294
- # of the builtin PrettyPrinter class to call methods of
295
- # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
296
- # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
297
- _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
298
- _dispatch[BaseObject.__repr__] = _pprint_object
299
- _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
300
-
301
-
302
- def _safe_repr(obj, context, maxlevels, level, changed_only=False):
303
- """Safe string representation logic.
304
-
305
- Same as the builtin _safe_repr, with added support for BaseObjects.
306
- """
307
- typ = type(obj)
308
-
309
- if typ in pprint._builtin_scalars:
310
- return repr(obj), True, False
311
-
312
- r = getattr(typ, "__repr__", None)
313
- if issubclass(typ, dict) and r is dict.__repr__:
314
- if not obj:
315
- return "{}", True, False
316
- objid = id(obj)
317
- if maxlevels and level >= maxlevels:
318
- return "{...}", False, objid in context
319
- if objid in context:
320
- return pprint._recursion(obj), False, True
321
- context[objid] = 1
322
- readable = True
323
- recursive = False
324
- components = []
325
- append = components.append
326
- level += 1
327
- saferepr = _safe_repr
328
- items = sorted(obj.items(), key=pprint._safe_tuple)
329
- for k, v in items:
330
- krepr, kreadable, krecur = saferepr(
331
- k, context, maxlevels, level, changed_only=changed_only
332
- )
333
- vrepr, vreadable, vrecur = saferepr(
334
- v, context, maxlevels, level, changed_only=changed_only
335
- )
336
- append("%s: %s" % (krepr, vrepr))
337
- readable = readable and kreadable and vreadable
338
- if krecur or vrecur:
339
- recursive = True
340
- del context[objid]
341
- return "{%s}" % ", ".join(components), readable, recursive
342
-
343
- if (issubclass(typ, list) and r is list.__repr__) or (
344
- issubclass(typ, tuple) and r is tuple.__repr__
345
- ):
346
- if issubclass(typ, list):
347
- if not obj:
348
- return "[]", True, False
349
- format_ = "[%s]"
350
- elif len(obj) == 1:
351
- format_ = "(%s,)"
352
- else:
353
- if not obj:
354
- return "()", True, False
355
- format_ = "(%s)"
356
- objid = id(obj)
357
- if maxlevels and level >= maxlevels:
358
- return format_ % "...", False, objid in context
359
- if objid in context:
360
- return pprint._recursion(obj), False, True
361
- context[objid] = 1
362
- readable = True
363
- recursive = False
364
- components = []
365
- append = components.append
366
- level += 1
367
- for o in obj:
368
- orepr, oreadable, orecur = _safe_repr(
369
- o, context, maxlevels, level, changed_only=changed_only
370
- )
371
- append(orepr)
372
- if not oreadable:
373
- readable = False
374
- if orecur:
375
- recursive = True
376
- del context[objid]
377
- return format_ % ", ".join(components), readable, recursive
378
-
379
- if issubclass(typ, BaseObject):
380
- objid = id(obj)
381
- if maxlevels and level >= maxlevels:
382
- return "{...}", False, objid in context
383
- if objid in context:
384
- return pprint._recursion(obj), False, True
385
- context[objid] = 1
386
- readable = True
387
- recursive = False
388
- if changed_only:
389
- params = _changed_params(obj)
390
- else:
391
- params = obj.get_params(deep=False)
392
- components = []
393
- append = components.append
394
- level += 1
395
- saferepr = _safe_repr
396
- items = sorted(params.items(), key=pprint._safe_tuple)
397
- for k, v in items:
398
- krepr, kreadable, krecur = saferepr(
399
- k, context, maxlevels, level, changed_only=changed_only
400
- )
401
- vrepr, vreadable, vrecur = saferepr(
402
- v, context, maxlevels, level, changed_only=changed_only
403
- )
404
- append("%s=%s" % (krepr.strip("'"), vrepr))
405
- readable = readable and kreadable and vreadable
406
- if krecur or vrecur:
407
- recursive = True
408
- del context[objid]
409
- return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
410
-
411
- rep = repr(obj)
412
- return rep, (rep and not rep.startswith("<")), False
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Many elements of this code were developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
7
+ import inspect
8
+ import pprint
9
+ from collections import OrderedDict
10
+
11
+ from skbase.base import BaseObject
12
+
13
+ # from skbase.config import get_config
14
+ from skbase.utils._check import _is_scalar_nan
15
+
16
+
17
+ class KeyValTuple(tuple):
18
+ """Dummy class for correctly rendering key-value tuples from dicts."""
19
+
20
+ def __repr__(self):
21
+ """Represent as string."""
22
+ # needed for _dispatch[tuple.__repr__] not to be overridden
23
+ return super().__repr__()
24
+
25
+
26
+ class KeyValTupleParam(KeyValTuple):
27
+ """Dummy class for correctly rendering key-value tuples from parameters."""
28
+
29
+ pass
30
+
31
+
32
+ def _changed_params(base_object):
33
+ """Return dict (param_name: value) of parameters with non-default values."""
34
+ params = base_object.get_params(deep=False)
35
+ init_func = getattr(
36
+ base_object.__init__, "deprecated_original", base_object.__init__
37
+ )
38
+ init_params = inspect.signature(init_func).parameters
39
+ init_params = {name: param.default for name, param in init_params.items()}
40
+
41
+ def has_changed(k, v):
42
+ if k not in init_params: # happens if k is part of a **kwargs
43
+ return True
44
+ if init_params[k] == inspect._empty: # k has no default value
45
+ return True
46
+ # try to avoid calling repr on nested BaseObjects
47
+ if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
48
+ return True
49
+ # Use repr as a last resort. It may be expensive.
50
+ if repr(v) != repr(init_params[k]) and not (
51
+ _is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
52
+ ):
53
+ return True
54
+ return False
55
+
56
+ return {k: v for k, v in params.items() if has_changed(k, v)}
57
+
58
+
59
+ class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
60
+ """Pretty Printer class for BaseObjects.
61
+
62
+ This extends the pprint.PrettyPrinter class similar to scikit-learn's
63
+ implementation, so that:
64
+
65
+ - BaseObjects are printed with their parameters, e.g.
66
+ BaseObject(param1=value1, ...) which is not supported by default.
67
+ - the 'compact' parameter of PrettyPrinter is ignored for dicts, which
68
+ may lead to very long representations that we want to avoid.
69
+
70
+ Quick overview of pprint.PrettyPrinter (see also
71
+ https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
72
+
73
+ - the entry point is the _format() method which calls format() (overridden
74
+ here)
75
+ - format() directly calls _safe_repr() for a first try at rendering the
76
+ object
77
+ - _safe_repr formats the whole object recursively, only calling itself,
78
+ not caring about line length or anything
79
+ - back to _format(), if the output string is too long, _format() then calls
80
+ the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
81
+ the type of the object. This where the line length and the compact
82
+ parameters are taken into account.
83
+ - those _pprint_TYPE() methods will internally use the format() method for
84
+ rendering the nested objects of an object (e.g. the elements of a list)
85
+
86
+ In the end, everything has to be implemented twice: in _safe_repr and in
87
+ the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
88
+ straightforward to extend (especially when we want a compact output), so
89
+ the code is a bit convoluted.
90
+
91
+ This class overrides:
92
+ - format() to support the changed_only parameter
93
+ - _safe_repr to support printing of BaseObjects that fit on a single line
94
+ - _format_dict_items so that dict are correctly 'compacted'
95
+ - _format_items so that ellipsis is used on long lists and tuples
96
+
97
+ When BaseObjects cannot be printed on a single line, the builtin _format()
98
+ will call _pprint_object() because it was registered to do so (see
99
+ _dispatch[BaseObject.__repr__] = _pprint_object).
100
+
101
+ both _format_dict_items() and _pprint_Object() use the
102
+ _format_params_or_dict_items() method that will format parameters and
103
+ key-value pairs respecting the compact parameter. This method needs another
104
+ subroutine _pprint_key_val_tuple() used when a parameter or a key-value
105
+ pair is too long to fit on a single line. This subroutine is called in
106
+ _format() and is registered as well in the _dispatch dict (just like
107
+ _pprint_object). We had to create the two classes KeyValTuple and
108
+ KeyValTupleParam for this.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ indent=1,
114
+ width=80,
115
+ depth=None,
116
+ stream=None,
117
+ *,
118
+ compact=False,
119
+ indent_at_name=True,
120
+ n_max_elements_to_show=None,
121
+ changed_only=True,
122
+ ):
123
+ super().__init__(indent, width, depth, stream, compact=compact)
124
+ self._indent_at_name = indent_at_name
125
+ if self._indent_at_name:
126
+ self._indent_per_level = 1 # ignore indent param
127
+ self.changed_only = changed_only
128
+ # Max number of elements in a list, dict, tuple until we start using
129
+ # ellipsis. This also affects the number of arguments of a BaseObject
130
+ # (they are treated as dicts)
131
+ self.n_max_elements_to_show = n_max_elements_to_show
132
+
133
+ def format(self, obj, context, maxlevels, level): # noqa
134
+ return _safe_repr(
135
+ obj, context, maxlevels, level, changed_only=self.changed_only
136
+ )
137
+
138
+ def _pprint_object(self, obj, stream, indent, allowance, context, level):
139
+ stream.write(obj.__class__.__name__ + "(")
140
+ if self._indent_at_name:
141
+ indent += len(obj.__class__.__name__)
142
+
143
+ if self.changed_only:
144
+ params = _changed_params(obj)
145
+ else:
146
+ params = obj.get_params(deep=False)
147
+
148
+ params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
149
+
150
+ self._format_params(
151
+ params.items(), stream, indent, allowance + 1, context, level
152
+ )
153
+ stream.write(")")
154
+
155
+ def _format_dict_items(self, items, stream, indent, allowance, context, level):
156
+ return self._format_params_or_dict_items(
157
+ items, stream, indent, allowance, context, level, is_dict=True
158
+ )
159
+
160
+ def _format_params(self, items, stream, indent, allowance, context, level):
161
+ return self._format_params_or_dict_items(
162
+ items, stream, indent, allowance, context, level, is_dict=False
163
+ )
164
+
165
+ def _format_params_or_dict_items(
166
+ self, obj, stream, indent, allowance, context, level, is_dict
167
+ ):
168
+ """Format dict items or parameters respecting the compact=True parameter.
169
+
170
+ For some reason, the builtin rendering of dict items doesn't
171
+ respect compact=True and will use one line per key-value if all cannot
172
+ fit in a single line.
173
+ Dict items will be rendered as <'key': value> while params will be
174
+ rendered as <key=value>. The implementation is mostly copy/pasting from
175
+ the builtin _format_items().
176
+ This also adds ellipsis if the number of items is greater than
177
+ self.n_max_elements_to_show.
178
+ """
179
+ write = stream.write
180
+ indent += self._indent_per_level
181
+ delimnl = ",\n" + " " * indent
182
+ delim = ""
183
+ width = max_width = self._width - indent + 1
184
+ it = iter(obj)
185
+ try:
186
+ next_ent = next(it)
187
+ except StopIteration:
188
+ return
189
+ last = False
190
+ n_items = 0
191
+ while not last:
192
+ if n_items == self.n_max_elements_to_show:
193
+ write(", ...")
194
+ break
195
+ n_items += 1
196
+ ent = next_ent
197
+ try:
198
+ next_ent = next(it)
199
+ except StopIteration:
200
+ last = True
201
+ max_width -= allowance
202
+ width -= allowance
203
+ if self._compact:
204
+ k, v = ent
205
+ krepr = self._repr(k, context, level)
206
+ vrepr = self._repr(v, context, level)
207
+ if not is_dict:
208
+ krepr = krepr.strip("'")
209
+ middle = ": " if is_dict else "="
210
+ rep = krepr + middle + vrepr
211
+ w = len(rep) + 2
212
+ if width < w:
213
+ width = max_width
214
+ if delim:
215
+ delim = delimnl
216
+ if width >= w:
217
+ width -= w
218
+ write(delim)
219
+ delim = ", "
220
+ write(rep)
221
+ continue
222
+ write(delim)
223
+ delim = delimnl
224
+ class_ = KeyValTuple if is_dict else KeyValTupleParam
225
+ self._format(
226
+ class_(ent), stream, indent, allowance if last else 1, context, level
227
+ )
228
+
229
+ def _format_items(self, items, stream, indent, allowance, context, level):
230
+ """Format the items of an iterable (list, tuple...).
231
+
232
+ Same as the built-in _format_items, with support for ellipsis if the
233
+ number of elements is greater than self.n_max_elements_to_show.
234
+ """
235
+ write = stream.write
236
+ indent += self._indent_per_level
237
+ if self._indent_per_level > 1:
238
+ write((self._indent_per_level - 1) * " ")
239
+ delimnl = ",\n" + " " * indent
240
+ delim = ""
241
+ width = max_width = self._width - indent + 1
242
+ it = iter(items)
243
+ try:
244
+ next_ent = next(it)
245
+ except StopIteration:
246
+ return
247
+ last = False
248
+ n_items = 0
249
+ while not last:
250
+ if n_items == self.n_max_elements_to_show:
251
+ write(", ...")
252
+ break
253
+ n_items += 1
254
+ ent = next_ent
255
+ try:
256
+ next_ent = next(it)
257
+ except StopIteration:
258
+ last = True
259
+ max_width -= allowance
260
+ width -= allowance
261
+ if self._compact:
262
+ rep = self._repr(ent, context, level)
263
+ w = len(rep) + 2
264
+ if width < w:
265
+ width = max_width
266
+ if delim:
267
+ delim = delimnl
268
+ if width >= w:
269
+ width -= w
270
+ write(delim)
271
+ delim = ", "
272
+ write(rep)
273
+ continue
274
+ write(delim)
275
+ delim = delimnl
276
+ self._format(ent, stream, indent, allowance if last else 1, context, level)
277
+
278
+ def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
279
+ """Pretty printing for key-value tuples from dict or parameters."""
280
+ k, v = obj
281
+ rep = self._repr(k, context, level)
282
+ if isinstance(obj, KeyValTupleParam):
283
+ rep = rep.strip("'")
284
+ middle = "="
285
+ else:
286
+ middle = ": "
287
+ stream.write(rep)
288
+ stream.write(middle)
289
+ self._format(
290
+ v, stream, indent + len(rep) + len(middle), allowance, context, level
291
+ )
292
+
293
+ # Follow what scikit-learn did here and copy _dispatch to prevent instances
294
+ # of the builtin PrettyPrinter class to call methods of
295
+ # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
296
+ # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
297
+ _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
298
+ _dispatch[BaseObject.__repr__] = _pprint_object
299
+ _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
300
+
301
+
302
+ def _safe_repr(obj, context, maxlevels, level, changed_only=False):
303
+ """Safe string representation logic.
304
+
305
+ Same as the builtin _safe_repr, with added support for BaseObjects.
306
+ """
307
+ typ = type(obj)
308
+
309
+ if typ in pprint._builtin_scalars:
310
+ return repr(obj), True, False
311
+
312
+ r = getattr(typ, "__repr__", None)
313
+ if issubclass(typ, dict) and r is dict.__repr__:
314
+ if not obj:
315
+ return "{}", True, False
316
+ objid = id(obj)
317
+ if maxlevels and level >= maxlevels:
318
+ return "{...}", False, objid in context
319
+ if objid in context:
320
+ return pprint._recursion(obj), False, True
321
+ context[objid] = 1
322
+ readable = True
323
+ recursive = False
324
+ components = []
325
+ append = components.append
326
+ level += 1
327
+ saferepr = _safe_repr
328
+ items = sorted(obj.items(), key=pprint._safe_tuple)
329
+ for k, v in items:
330
+ krepr, kreadable, krecur = saferepr(
331
+ k, context, maxlevels, level, changed_only=changed_only
332
+ )
333
+ vrepr, vreadable, vrecur = saferepr(
334
+ v, context, maxlevels, level, changed_only=changed_only
335
+ )
336
+ append("%s: %s" % (krepr, vrepr))
337
+ readable = readable and kreadable and vreadable
338
+ if krecur or vrecur:
339
+ recursive = True
340
+ del context[objid]
341
+ return "{%s}" % ", ".join(components), readable, recursive
342
+
343
+ if (issubclass(typ, list) and r is list.__repr__) or (
344
+ issubclass(typ, tuple) and r is tuple.__repr__
345
+ ):
346
+ if issubclass(typ, list):
347
+ if not obj:
348
+ return "[]", True, False
349
+ format_ = "[%s]"
350
+ elif len(obj) == 1:
351
+ format_ = "(%s,)"
352
+ else:
353
+ if not obj:
354
+ return "()", True, False
355
+ format_ = "(%s)"
356
+ objid = id(obj)
357
+ if maxlevels and level >= maxlevels:
358
+ return format_ % "...", False, objid in context
359
+ if objid in context:
360
+ return pprint._recursion(obj), False, True
361
+ context[objid] = 1
362
+ readable = True
363
+ recursive = False
364
+ components = []
365
+ append = components.append
366
+ level += 1
367
+ for o in obj:
368
+ orepr, oreadable, orecur = _safe_repr(
369
+ o, context, maxlevels, level, changed_only=changed_only
370
+ )
371
+ append(orepr)
372
+ if not oreadable:
373
+ readable = False
374
+ if orecur:
375
+ recursive = True
376
+ del context[objid]
377
+ return format_ % ", ".join(components), readable, recursive
378
+
379
+ if issubclass(typ, BaseObject):
380
+ objid = id(obj)
381
+ if maxlevels and level >= maxlevels:
382
+ return "{...}", False, objid in context
383
+ if objid in context:
384
+ return pprint._recursion(obj), False, True
385
+ context[objid] = 1
386
+ readable = True
387
+ recursive = False
388
+ if changed_only:
389
+ params = _changed_params(obj)
390
+ else:
391
+ params = obj.get_params(deep=False)
392
+ components = []
393
+ append = components.append
394
+ level += 1
395
+ saferepr = _safe_repr
396
+ items = sorted(params.items(), key=pprint._safe_tuple)
397
+ for k, v in items:
398
+ krepr, kreadable, krecur = saferepr(
399
+ k, context, maxlevels, level, changed_only=changed_only
400
+ )
401
+ vrepr, vreadable, vrecur = saferepr(
402
+ v, context, maxlevels, level, changed_only=changed_only
403
+ )
404
+ append("%s=%s" % (krepr.strip("'"), vrepr))
405
+ readable = readable and kreadable and vreadable
406
+ if krecur or vrecur:
407
+ recursive = True
408
+ del context[objid]
409
+ return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
410
+
411
+ rep = repr(obj)
412
+ return rep, (rep and not rep.startswith("<")), False