scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,412 +1,412 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
# Many elements of this code were developed in scikit-learn. These elements
|
4
|
-
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
|
-
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
|
-
"""Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
|
7
|
-
import inspect
|
8
|
-
import pprint
|
9
|
-
from collections import OrderedDict
|
10
|
-
|
11
|
-
from skbase.base import BaseObject
|
12
|
-
|
13
|
-
# from skbase.config import get_config
|
14
|
-
from skbase.utils._check import _is_scalar_nan
|
15
|
-
|
16
|
-
|
17
|
-
class KeyValTuple(tuple):
|
18
|
-
"""Dummy class for correctly rendering key-value tuples from dicts."""
|
19
|
-
|
20
|
-
def __repr__(self):
|
21
|
-
"""Represent as string."""
|
22
|
-
# needed for _dispatch[tuple.__repr__] not to be overridden
|
23
|
-
return super().__repr__()
|
24
|
-
|
25
|
-
|
26
|
-
class KeyValTupleParam(KeyValTuple):
|
27
|
-
"""Dummy class for correctly rendering key-value tuples from parameters."""
|
28
|
-
|
29
|
-
pass
|
30
|
-
|
31
|
-
|
32
|
-
def _changed_params(base_object):
|
33
|
-
"""Return dict (param_name: value) of parameters with non-default values."""
|
34
|
-
params = base_object.get_params(deep=False)
|
35
|
-
init_func = getattr(
|
36
|
-
base_object.__init__, "deprecated_original", base_object.__init__
|
37
|
-
)
|
38
|
-
init_params = inspect.signature(init_func).parameters
|
39
|
-
init_params = {name: param.default for name, param in init_params.items()}
|
40
|
-
|
41
|
-
def has_changed(k, v):
|
42
|
-
if k not in init_params: # happens if k is part of a **kwargs
|
43
|
-
return True
|
44
|
-
if init_params[k] == inspect._empty: # k has no default value
|
45
|
-
return True
|
46
|
-
# try to avoid calling repr on nested BaseObjects
|
47
|
-
if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
|
48
|
-
return True
|
49
|
-
# Use repr as a last resort. It may be expensive.
|
50
|
-
if repr(v) != repr(init_params[k]) and not (
|
51
|
-
_is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
|
52
|
-
):
|
53
|
-
return True
|
54
|
-
return False
|
55
|
-
|
56
|
-
return {k: v for k, v in params.items() if has_changed(k, v)}
|
57
|
-
|
58
|
-
|
59
|
-
class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
|
60
|
-
"""Pretty Printer class for BaseObjects.
|
61
|
-
|
62
|
-
This extends the pprint.PrettyPrinter class similar to scikit-learn's
|
63
|
-
implementation, so that:
|
64
|
-
|
65
|
-
- BaseObjects are printed with their parameters, e.g.
|
66
|
-
BaseObject(param1=value1, ...) which is not supported by default.
|
67
|
-
- the 'compact' parameter of PrettyPrinter is ignored for dicts, which
|
68
|
-
may lead to very long representations that we want to avoid.
|
69
|
-
|
70
|
-
Quick overview of pprint.PrettyPrinter (see also
|
71
|
-
https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
|
72
|
-
|
73
|
-
- the entry point is the _format() method which calls format() (overridden
|
74
|
-
here)
|
75
|
-
- format() directly calls _safe_repr() for a first try at rendering the
|
76
|
-
object
|
77
|
-
- _safe_repr formats the whole object recursively, only calling itself,
|
78
|
-
not caring about line length or anything
|
79
|
-
- back to _format(), if the output string is too long, _format() then calls
|
80
|
-
the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
|
81
|
-
the type of the object. This where the line length and the compact
|
82
|
-
parameters are taken into account.
|
83
|
-
- those _pprint_TYPE() methods will internally use the format() method for
|
84
|
-
rendering the nested objects of an object (e.g. the elements of a list)
|
85
|
-
|
86
|
-
In the end, everything has to be implemented twice: in _safe_repr and in
|
87
|
-
the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
|
88
|
-
straightforward to extend (especially when we want a compact output), so
|
89
|
-
the code is a bit convoluted.
|
90
|
-
|
91
|
-
This class overrides:
|
92
|
-
- format() to support the changed_only parameter
|
93
|
-
- _safe_repr to support printing of BaseObjects that fit on a single line
|
94
|
-
- _format_dict_items so that dict are correctly 'compacted'
|
95
|
-
- _format_items so that ellipsis is used on long lists and tuples
|
96
|
-
|
97
|
-
When BaseObjects cannot be printed on a single line, the builtin _format()
|
98
|
-
will call _pprint_object() because it was registered to do so (see
|
99
|
-
_dispatch[BaseObject.__repr__] = _pprint_object).
|
100
|
-
|
101
|
-
both _format_dict_items() and _pprint_Object() use the
|
102
|
-
_format_params_or_dict_items() method that will format parameters and
|
103
|
-
key-value pairs respecting the compact parameter. This method needs another
|
104
|
-
subroutine _pprint_key_val_tuple() used when a parameter or a key-value
|
105
|
-
pair is too long to fit on a single line. This subroutine is called in
|
106
|
-
_format() and is registered as well in the _dispatch dict (just like
|
107
|
-
_pprint_object). We had to create the two classes KeyValTuple and
|
108
|
-
KeyValTupleParam for this.
|
109
|
-
"""
|
110
|
-
|
111
|
-
def __init__(
|
112
|
-
self,
|
113
|
-
indent=1,
|
114
|
-
width=80,
|
115
|
-
depth=None,
|
116
|
-
stream=None,
|
117
|
-
*,
|
118
|
-
compact=False,
|
119
|
-
indent_at_name=True,
|
120
|
-
n_max_elements_to_show=None,
|
121
|
-
changed_only=True,
|
122
|
-
):
|
123
|
-
super().__init__(indent, width, depth, stream, compact=compact)
|
124
|
-
self._indent_at_name = indent_at_name
|
125
|
-
if self._indent_at_name:
|
126
|
-
self._indent_per_level = 1 # ignore indent param
|
127
|
-
self.changed_only = changed_only
|
128
|
-
# Max number of elements in a list, dict, tuple until we start using
|
129
|
-
# ellipsis. This also affects the number of arguments of a BaseObject
|
130
|
-
# (they are treated as dicts)
|
131
|
-
self.n_max_elements_to_show = n_max_elements_to_show
|
132
|
-
|
133
|
-
def format(self, obj, context, maxlevels, level): # noqa
|
134
|
-
return _safe_repr(
|
135
|
-
obj, context, maxlevels, level, changed_only=self.changed_only
|
136
|
-
)
|
137
|
-
|
138
|
-
def _pprint_object(self, obj, stream, indent, allowance, context, level):
|
139
|
-
stream.write(obj.__class__.__name__ + "(")
|
140
|
-
if self._indent_at_name:
|
141
|
-
indent += len(obj.__class__.__name__)
|
142
|
-
|
143
|
-
if self.changed_only:
|
144
|
-
params = _changed_params(obj)
|
145
|
-
else:
|
146
|
-
params = obj.get_params(deep=False)
|
147
|
-
|
148
|
-
params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
|
149
|
-
|
150
|
-
self._format_params(
|
151
|
-
params.items(), stream, indent, allowance + 1, context, level
|
152
|
-
)
|
153
|
-
stream.write(")")
|
154
|
-
|
155
|
-
def _format_dict_items(self, items, stream, indent, allowance, context, level):
|
156
|
-
return self._format_params_or_dict_items(
|
157
|
-
items, stream, indent, allowance, context, level, is_dict=True
|
158
|
-
)
|
159
|
-
|
160
|
-
def _format_params(self, items, stream, indent, allowance, context, level):
|
161
|
-
return self._format_params_or_dict_items(
|
162
|
-
items, stream, indent, allowance, context, level, is_dict=False
|
163
|
-
)
|
164
|
-
|
165
|
-
def _format_params_or_dict_items(
|
166
|
-
self, obj, stream, indent, allowance, context, level, is_dict
|
167
|
-
):
|
168
|
-
"""Format dict items or parameters respecting the compact=True parameter.
|
169
|
-
|
170
|
-
For some reason, the builtin rendering of dict items doesn't
|
171
|
-
respect compact=True and will use one line per key-value if all cannot
|
172
|
-
fit in a single line.
|
173
|
-
Dict items will be rendered as <'key': value> while params will be
|
174
|
-
rendered as <key=value>. The implementation is mostly copy/pasting from
|
175
|
-
the builtin _format_items().
|
176
|
-
This also adds ellipsis if the number of items is greater than
|
177
|
-
self.n_max_elements_to_show.
|
178
|
-
"""
|
179
|
-
write = stream.write
|
180
|
-
indent += self._indent_per_level
|
181
|
-
delimnl = ",\n" + " " * indent
|
182
|
-
delim = ""
|
183
|
-
width = max_width = self._width - indent + 1
|
184
|
-
it = iter(obj)
|
185
|
-
try:
|
186
|
-
next_ent = next(it)
|
187
|
-
except StopIteration:
|
188
|
-
return
|
189
|
-
last = False
|
190
|
-
n_items = 0
|
191
|
-
while not last:
|
192
|
-
if n_items == self.n_max_elements_to_show:
|
193
|
-
write(", ...")
|
194
|
-
break
|
195
|
-
n_items += 1
|
196
|
-
ent = next_ent
|
197
|
-
try:
|
198
|
-
next_ent = next(it)
|
199
|
-
except StopIteration:
|
200
|
-
last = True
|
201
|
-
max_width -= allowance
|
202
|
-
width -= allowance
|
203
|
-
if self._compact:
|
204
|
-
k, v = ent
|
205
|
-
krepr = self._repr(k, context, level)
|
206
|
-
vrepr = self._repr(v, context, level)
|
207
|
-
if not is_dict:
|
208
|
-
krepr = krepr.strip("'")
|
209
|
-
middle = ": " if is_dict else "="
|
210
|
-
rep = krepr + middle + vrepr
|
211
|
-
w = len(rep) + 2
|
212
|
-
if width < w:
|
213
|
-
width = max_width
|
214
|
-
if delim:
|
215
|
-
delim = delimnl
|
216
|
-
if width >= w:
|
217
|
-
width -= w
|
218
|
-
write(delim)
|
219
|
-
delim = ", "
|
220
|
-
write(rep)
|
221
|
-
continue
|
222
|
-
write(delim)
|
223
|
-
delim = delimnl
|
224
|
-
class_ = KeyValTuple if is_dict else KeyValTupleParam
|
225
|
-
self._format(
|
226
|
-
class_(ent), stream, indent, allowance if last else 1, context, level
|
227
|
-
)
|
228
|
-
|
229
|
-
def _format_items(self, items, stream, indent, allowance, context, level):
|
230
|
-
"""Format the items of an iterable (list, tuple...).
|
231
|
-
|
232
|
-
Same as the built-in _format_items, with support for ellipsis if the
|
233
|
-
number of elements is greater than self.n_max_elements_to_show.
|
234
|
-
"""
|
235
|
-
write = stream.write
|
236
|
-
indent += self._indent_per_level
|
237
|
-
if self._indent_per_level > 1:
|
238
|
-
write((self._indent_per_level - 1) * " ")
|
239
|
-
delimnl = ",\n" + " " * indent
|
240
|
-
delim = ""
|
241
|
-
width = max_width = self._width - indent + 1
|
242
|
-
it = iter(items)
|
243
|
-
try:
|
244
|
-
next_ent = next(it)
|
245
|
-
except StopIteration:
|
246
|
-
return
|
247
|
-
last = False
|
248
|
-
n_items = 0
|
249
|
-
while not last:
|
250
|
-
if n_items == self.n_max_elements_to_show:
|
251
|
-
write(", ...")
|
252
|
-
break
|
253
|
-
n_items += 1
|
254
|
-
ent = next_ent
|
255
|
-
try:
|
256
|
-
next_ent = next(it)
|
257
|
-
except StopIteration:
|
258
|
-
last = True
|
259
|
-
max_width -= allowance
|
260
|
-
width -= allowance
|
261
|
-
if self._compact:
|
262
|
-
rep = self._repr(ent, context, level)
|
263
|
-
w = len(rep) + 2
|
264
|
-
if width < w:
|
265
|
-
width = max_width
|
266
|
-
if delim:
|
267
|
-
delim = delimnl
|
268
|
-
if width >= w:
|
269
|
-
width -= w
|
270
|
-
write(delim)
|
271
|
-
delim = ", "
|
272
|
-
write(rep)
|
273
|
-
continue
|
274
|
-
write(delim)
|
275
|
-
delim = delimnl
|
276
|
-
self._format(ent, stream, indent, allowance if last else 1, context, level)
|
277
|
-
|
278
|
-
def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
|
279
|
-
"""Pretty printing for key-value tuples from dict or parameters."""
|
280
|
-
k, v = obj
|
281
|
-
rep = self._repr(k, context, level)
|
282
|
-
if isinstance(obj, KeyValTupleParam):
|
283
|
-
rep = rep.strip("'")
|
284
|
-
middle = "="
|
285
|
-
else:
|
286
|
-
middle = ": "
|
287
|
-
stream.write(rep)
|
288
|
-
stream.write(middle)
|
289
|
-
self._format(
|
290
|
-
v, stream, indent + len(rep) + len(middle), allowance, context, level
|
291
|
-
)
|
292
|
-
|
293
|
-
# Follow what scikit-learn did here and copy _dispatch to prevent instances
|
294
|
-
# of the builtin PrettyPrinter class to call methods of
|
295
|
-
# _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
|
296
|
-
# mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
|
297
|
-
_dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
|
298
|
-
_dispatch[BaseObject.__repr__] = _pprint_object
|
299
|
-
_dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
|
300
|
-
|
301
|
-
|
302
|
-
def _safe_repr(obj, context, maxlevels, level, changed_only=False):
|
303
|
-
"""Safe string representation logic.
|
304
|
-
|
305
|
-
Same as the builtin _safe_repr, with added support for BaseObjects.
|
306
|
-
"""
|
307
|
-
typ = type(obj)
|
308
|
-
|
309
|
-
if typ in pprint._builtin_scalars:
|
310
|
-
return repr(obj), True, False
|
311
|
-
|
312
|
-
r = getattr(typ, "__repr__", None)
|
313
|
-
if issubclass(typ, dict) and r is dict.__repr__:
|
314
|
-
if not obj:
|
315
|
-
return "{}", True, False
|
316
|
-
objid = id(obj)
|
317
|
-
if maxlevels and level >= maxlevels:
|
318
|
-
return "{...}", False, objid in context
|
319
|
-
if objid in context:
|
320
|
-
return pprint._recursion(obj), False, True
|
321
|
-
context[objid] = 1
|
322
|
-
readable = True
|
323
|
-
recursive = False
|
324
|
-
components = []
|
325
|
-
append = components.append
|
326
|
-
level += 1
|
327
|
-
saferepr = _safe_repr
|
328
|
-
items = sorted(obj.items(), key=pprint._safe_tuple)
|
329
|
-
for k, v in items:
|
330
|
-
krepr, kreadable, krecur = saferepr(
|
331
|
-
k, context, maxlevels, level, changed_only=changed_only
|
332
|
-
)
|
333
|
-
vrepr, vreadable, vrecur = saferepr(
|
334
|
-
v, context, maxlevels, level, changed_only=changed_only
|
335
|
-
)
|
336
|
-
append("%s: %s" % (krepr, vrepr))
|
337
|
-
readable = readable and kreadable and vreadable
|
338
|
-
if krecur or vrecur:
|
339
|
-
recursive = True
|
340
|
-
del context[objid]
|
341
|
-
return "{%s}" % ", ".join(components), readable, recursive
|
342
|
-
|
343
|
-
if (issubclass(typ, list) and r is list.__repr__) or (
|
344
|
-
issubclass(typ, tuple) and r is tuple.__repr__
|
345
|
-
):
|
346
|
-
if issubclass(typ, list):
|
347
|
-
if not obj:
|
348
|
-
return "[]", True, False
|
349
|
-
format_ = "[%s]"
|
350
|
-
elif len(obj) == 1:
|
351
|
-
format_ = "(%s,)"
|
352
|
-
else:
|
353
|
-
if not obj:
|
354
|
-
return "()", True, False
|
355
|
-
format_ = "(%s)"
|
356
|
-
objid = id(obj)
|
357
|
-
if maxlevels and level >= maxlevels:
|
358
|
-
return format_ % "...", False, objid in context
|
359
|
-
if objid in context:
|
360
|
-
return pprint._recursion(obj), False, True
|
361
|
-
context[objid] = 1
|
362
|
-
readable = True
|
363
|
-
recursive = False
|
364
|
-
components = []
|
365
|
-
append = components.append
|
366
|
-
level += 1
|
367
|
-
for o in obj:
|
368
|
-
orepr, oreadable, orecur = _safe_repr(
|
369
|
-
o, context, maxlevels, level, changed_only=changed_only
|
370
|
-
)
|
371
|
-
append(orepr)
|
372
|
-
if not oreadable:
|
373
|
-
readable = False
|
374
|
-
if orecur:
|
375
|
-
recursive = True
|
376
|
-
del context[objid]
|
377
|
-
return format_ % ", ".join(components), readable, recursive
|
378
|
-
|
379
|
-
if issubclass(typ, BaseObject):
|
380
|
-
objid = id(obj)
|
381
|
-
if maxlevels and level >= maxlevels:
|
382
|
-
return "{...}", False, objid in context
|
383
|
-
if objid in context:
|
384
|
-
return pprint._recursion(obj), False, True
|
385
|
-
context[objid] = 1
|
386
|
-
readable = True
|
387
|
-
recursive = False
|
388
|
-
if changed_only:
|
389
|
-
params = _changed_params(obj)
|
390
|
-
else:
|
391
|
-
params = obj.get_params(deep=False)
|
392
|
-
components = []
|
393
|
-
append = components.append
|
394
|
-
level += 1
|
395
|
-
saferepr = _safe_repr
|
396
|
-
items = sorted(params.items(), key=pprint._safe_tuple)
|
397
|
-
for k, v in items:
|
398
|
-
krepr, kreadable, krecur = saferepr(
|
399
|
-
k, context, maxlevels, level, changed_only=changed_only
|
400
|
-
)
|
401
|
-
vrepr, vreadable, vrecur = saferepr(
|
402
|
-
v, context, maxlevels, level, changed_only=changed_only
|
403
|
-
)
|
404
|
-
append("%s=%s" % (krepr.strip("'"), vrepr))
|
405
|
-
readable = readable and kreadable and vreadable
|
406
|
-
if krecur or vrecur:
|
407
|
-
recursive = True
|
408
|
-
del context[objid]
|
409
|
-
return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
|
410
|
-
|
411
|
-
rep = repr(obj)
|
412
|
-
return rep, (rep and not rep.startswith("<")), False
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
# Many elements of this code were developed in scikit-learn. These elements
|
4
|
+
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
|
+
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
|
+
"""Utility functionality for pretty-printing objects used in BaseObject.__repr__."""
|
7
|
+
import inspect
|
8
|
+
import pprint
|
9
|
+
from collections import OrderedDict
|
10
|
+
|
11
|
+
from skbase.base import BaseObject
|
12
|
+
|
13
|
+
# from skbase.config import get_config
|
14
|
+
from skbase.utils._check import _is_scalar_nan
|
15
|
+
|
16
|
+
|
17
|
+
class KeyValTuple(tuple):
|
18
|
+
"""Dummy class for correctly rendering key-value tuples from dicts."""
|
19
|
+
|
20
|
+
def __repr__(self):
|
21
|
+
"""Represent as string."""
|
22
|
+
# needed for _dispatch[tuple.__repr__] not to be overridden
|
23
|
+
return super().__repr__()
|
24
|
+
|
25
|
+
|
26
|
+
class KeyValTupleParam(KeyValTuple):
|
27
|
+
"""Dummy class for correctly rendering key-value tuples from parameters."""
|
28
|
+
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
def _changed_params(base_object):
|
33
|
+
"""Return dict (param_name: value) of parameters with non-default values."""
|
34
|
+
params = base_object.get_params(deep=False)
|
35
|
+
init_func = getattr(
|
36
|
+
base_object.__init__, "deprecated_original", base_object.__init__
|
37
|
+
)
|
38
|
+
init_params = inspect.signature(init_func).parameters
|
39
|
+
init_params = {name: param.default for name, param in init_params.items()}
|
40
|
+
|
41
|
+
def has_changed(k, v):
|
42
|
+
if k not in init_params: # happens if k is part of a **kwargs
|
43
|
+
return True
|
44
|
+
if init_params[k] == inspect._empty: # k has no default value
|
45
|
+
return True
|
46
|
+
# try to avoid calling repr on nested BaseObjects
|
47
|
+
if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__:
|
48
|
+
return True
|
49
|
+
# Use repr as a last resort. It may be expensive.
|
50
|
+
if repr(v) != repr(init_params[k]) and not (
|
51
|
+
_is_scalar_nan(init_params[k]) and _is_scalar_nan(v)
|
52
|
+
):
|
53
|
+
return True
|
54
|
+
return False
|
55
|
+
|
56
|
+
return {k: v for k, v in params.items() if has_changed(k, v)}
|
57
|
+
|
58
|
+
|
59
|
+
class _BaseObjectPrettyPrinter(pprint.PrettyPrinter):
|
60
|
+
"""Pretty Printer class for BaseObjects.
|
61
|
+
|
62
|
+
This extends the pprint.PrettyPrinter class similar to scikit-learn's
|
63
|
+
implementation, so that:
|
64
|
+
|
65
|
+
- BaseObjects are printed with their parameters, e.g.
|
66
|
+
BaseObject(param1=value1, ...) which is not supported by default.
|
67
|
+
- the 'compact' parameter of PrettyPrinter is ignored for dicts, which
|
68
|
+
may lead to very long representations that we want to avoid.
|
69
|
+
|
70
|
+
Quick overview of pprint.PrettyPrinter (see also
|
71
|
+
https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers):
|
72
|
+
|
73
|
+
- the entry point is the _format() method which calls format() (overridden
|
74
|
+
here)
|
75
|
+
- format() directly calls _safe_repr() for a first try at rendering the
|
76
|
+
object
|
77
|
+
- _safe_repr formats the whole object recursively, only calling itself,
|
78
|
+
not caring about line length or anything
|
79
|
+
- back to _format(), if the output string is too long, _format() then calls
|
80
|
+
the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on
|
81
|
+
the type of the object. This where the line length and the compact
|
82
|
+
parameters are taken into account.
|
83
|
+
- those _pprint_TYPE() methods will internally use the format() method for
|
84
|
+
rendering the nested objects of an object (e.g. the elements of a list)
|
85
|
+
|
86
|
+
In the end, everything has to be implemented twice: in _safe_repr and in
|
87
|
+
the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not
|
88
|
+
straightforward to extend (especially when we want a compact output), so
|
89
|
+
the code is a bit convoluted.
|
90
|
+
|
91
|
+
This class overrides:
|
92
|
+
- format() to support the changed_only parameter
|
93
|
+
- _safe_repr to support printing of BaseObjects that fit on a single line
|
94
|
+
- _format_dict_items so that dict are correctly 'compacted'
|
95
|
+
- _format_items so that ellipsis is used on long lists and tuples
|
96
|
+
|
97
|
+
When BaseObjects cannot be printed on a single line, the builtin _format()
|
98
|
+
will call _pprint_object() because it was registered to do so (see
|
99
|
+
_dispatch[BaseObject.__repr__] = _pprint_object).
|
100
|
+
|
101
|
+
both _format_dict_items() and _pprint_Object() use the
|
102
|
+
_format_params_or_dict_items() method that will format parameters and
|
103
|
+
key-value pairs respecting the compact parameter. This method needs another
|
104
|
+
subroutine _pprint_key_val_tuple() used when a parameter or a key-value
|
105
|
+
pair is too long to fit on a single line. This subroutine is called in
|
106
|
+
_format() and is registered as well in the _dispatch dict (just like
|
107
|
+
_pprint_object). We had to create the two classes KeyValTuple and
|
108
|
+
KeyValTupleParam for this.
|
109
|
+
"""
|
110
|
+
|
111
|
+
def __init__(
|
112
|
+
self,
|
113
|
+
indent=1,
|
114
|
+
width=80,
|
115
|
+
depth=None,
|
116
|
+
stream=None,
|
117
|
+
*,
|
118
|
+
compact=False,
|
119
|
+
indent_at_name=True,
|
120
|
+
n_max_elements_to_show=None,
|
121
|
+
changed_only=True,
|
122
|
+
):
|
123
|
+
super().__init__(indent, width, depth, stream, compact=compact)
|
124
|
+
self._indent_at_name = indent_at_name
|
125
|
+
if self._indent_at_name:
|
126
|
+
self._indent_per_level = 1 # ignore indent param
|
127
|
+
self.changed_only = changed_only
|
128
|
+
# Max number of elements in a list, dict, tuple until we start using
|
129
|
+
# ellipsis. This also affects the number of arguments of a BaseObject
|
130
|
+
# (they are treated as dicts)
|
131
|
+
self.n_max_elements_to_show = n_max_elements_to_show
|
132
|
+
|
133
|
+
def format(self, obj, context, maxlevels, level): # noqa
|
134
|
+
return _safe_repr(
|
135
|
+
obj, context, maxlevels, level, changed_only=self.changed_only
|
136
|
+
)
|
137
|
+
|
138
|
+
def _pprint_object(self, obj, stream, indent, allowance, context, level):
|
139
|
+
stream.write(obj.__class__.__name__ + "(")
|
140
|
+
if self._indent_at_name:
|
141
|
+
indent += len(obj.__class__.__name__)
|
142
|
+
|
143
|
+
if self.changed_only:
|
144
|
+
params = _changed_params(obj)
|
145
|
+
else:
|
146
|
+
params = obj.get_params(deep=False)
|
147
|
+
|
148
|
+
params = OrderedDict((name, val) for (name, val) in sorted(params.items()))
|
149
|
+
|
150
|
+
self._format_params(
|
151
|
+
params.items(), stream, indent, allowance + 1, context, level
|
152
|
+
)
|
153
|
+
stream.write(")")
|
154
|
+
|
155
|
+
def _format_dict_items(self, items, stream, indent, allowance, context, level):
|
156
|
+
return self._format_params_or_dict_items(
|
157
|
+
items, stream, indent, allowance, context, level, is_dict=True
|
158
|
+
)
|
159
|
+
|
160
|
+
def _format_params(self, items, stream, indent, allowance, context, level):
|
161
|
+
return self._format_params_or_dict_items(
|
162
|
+
items, stream, indent, allowance, context, level, is_dict=False
|
163
|
+
)
|
164
|
+
|
165
|
+
def _format_params_or_dict_items(
|
166
|
+
self, obj, stream, indent, allowance, context, level, is_dict
|
167
|
+
):
|
168
|
+
"""Format dict items or parameters respecting the compact=True parameter.
|
169
|
+
|
170
|
+
For some reason, the builtin rendering of dict items doesn't
|
171
|
+
respect compact=True and will use one line per key-value if all cannot
|
172
|
+
fit in a single line.
|
173
|
+
Dict items will be rendered as <'key': value> while params will be
|
174
|
+
rendered as <key=value>. The implementation is mostly copy/pasting from
|
175
|
+
the builtin _format_items().
|
176
|
+
This also adds ellipsis if the number of items is greater than
|
177
|
+
self.n_max_elements_to_show.
|
178
|
+
"""
|
179
|
+
write = stream.write
|
180
|
+
indent += self._indent_per_level
|
181
|
+
delimnl = ",\n" + " " * indent
|
182
|
+
delim = ""
|
183
|
+
width = max_width = self._width - indent + 1
|
184
|
+
it = iter(obj)
|
185
|
+
try:
|
186
|
+
next_ent = next(it)
|
187
|
+
except StopIteration:
|
188
|
+
return
|
189
|
+
last = False
|
190
|
+
n_items = 0
|
191
|
+
while not last:
|
192
|
+
if n_items == self.n_max_elements_to_show:
|
193
|
+
write(", ...")
|
194
|
+
break
|
195
|
+
n_items += 1
|
196
|
+
ent = next_ent
|
197
|
+
try:
|
198
|
+
next_ent = next(it)
|
199
|
+
except StopIteration:
|
200
|
+
last = True
|
201
|
+
max_width -= allowance
|
202
|
+
width -= allowance
|
203
|
+
if self._compact:
|
204
|
+
k, v = ent
|
205
|
+
krepr = self._repr(k, context, level)
|
206
|
+
vrepr = self._repr(v, context, level)
|
207
|
+
if not is_dict:
|
208
|
+
krepr = krepr.strip("'")
|
209
|
+
middle = ": " if is_dict else "="
|
210
|
+
rep = krepr + middle + vrepr
|
211
|
+
w = len(rep) + 2
|
212
|
+
if width < w:
|
213
|
+
width = max_width
|
214
|
+
if delim:
|
215
|
+
delim = delimnl
|
216
|
+
if width >= w:
|
217
|
+
width -= w
|
218
|
+
write(delim)
|
219
|
+
delim = ", "
|
220
|
+
write(rep)
|
221
|
+
continue
|
222
|
+
write(delim)
|
223
|
+
delim = delimnl
|
224
|
+
class_ = KeyValTuple if is_dict else KeyValTupleParam
|
225
|
+
self._format(
|
226
|
+
class_(ent), stream, indent, allowance if last else 1, context, level
|
227
|
+
)
|
228
|
+
|
229
|
+
def _format_items(self, items, stream, indent, allowance, context, level):
|
230
|
+
"""Format the items of an iterable (list, tuple...).
|
231
|
+
|
232
|
+
Same as the built-in _format_items, with support for ellipsis if the
|
233
|
+
number of elements is greater than self.n_max_elements_to_show.
|
234
|
+
"""
|
235
|
+
write = stream.write
|
236
|
+
indent += self._indent_per_level
|
237
|
+
if self._indent_per_level > 1:
|
238
|
+
write((self._indent_per_level - 1) * " ")
|
239
|
+
delimnl = ",\n" + " " * indent
|
240
|
+
delim = ""
|
241
|
+
width = max_width = self._width - indent + 1
|
242
|
+
it = iter(items)
|
243
|
+
try:
|
244
|
+
next_ent = next(it)
|
245
|
+
except StopIteration:
|
246
|
+
return
|
247
|
+
last = False
|
248
|
+
n_items = 0
|
249
|
+
while not last:
|
250
|
+
if n_items == self.n_max_elements_to_show:
|
251
|
+
write(", ...")
|
252
|
+
break
|
253
|
+
n_items += 1
|
254
|
+
ent = next_ent
|
255
|
+
try:
|
256
|
+
next_ent = next(it)
|
257
|
+
except StopIteration:
|
258
|
+
last = True
|
259
|
+
max_width -= allowance
|
260
|
+
width -= allowance
|
261
|
+
if self._compact:
|
262
|
+
rep = self._repr(ent, context, level)
|
263
|
+
w = len(rep) + 2
|
264
|
+
if width < w:
|
265
|
+
width = max_width
|
266
|
+
if delim:
|
267
|
+
delim = delimnl
|
268
|
+
if width >= w:
|
269
|
+
width -= w
|
270
|
+
write(delim)
|
271
|
+
delim = ", "
|
272
|
+
write(rep)
|
273
|
+
continue
|
274
|
+
write(delim)
|
275
|
+
delim = delimnl
|
276
|
+
self._format(ent, stream, indent, allowance if last else 1, context, level)
|
277
|
+
|
278
|
+
def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level):
|
279
|
+
"""Pretty printing for key-value tuples from dict or parameters."""
|
280
|
+
k, v = obj
|
281
|
+
rep = self._repr(k, context, level)
|
282
|
+
if isinstance(obj, KeyValTupleParam):
|
283
|
+
rep = rep.strip("'")
|
284
|
+
middle = "="
|
285
|
+
else:
|
286
|
+
middle = ": "
|
287
|
+
stream.write(rep)
|
288
|
+
stream.write(middle)
|
289
|
+
self._format(
|
290
|
+
v, stream, indent + len(rep) + len(middle), allowance, context, level
|
291
|
+
)
|
292
|
+
|
293
|
+
# Follow what scikit-learn did here and copy _dispatch to prevent instances
|
294
|
+
# of the builtin PrettyPrinter class to call methods of
|
295
|
+
# _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906)
|
296
|
+
# mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch"
|
297
|
+
_dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore
|
298
|
+
_dispatch[BaseObject.__repr__] = _pprint_object
|
299
|
+
_dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
|
300
|
+
|
301
|
+
|
302
|
+
def _safe_repr(obj, context, maxlevels, level, changed_only=False):
|
303
|
+
"""Safe string representation logic.
|
304
|
+
|
305
|
+
Same as the builtin _safe_repr, with added support for BaseObjects.
|
306
|
+
"""
|
307
|
+
typ = type(obj)
|
308
|
+
|
309
|
+
if typ in pprint._builtin_scalars:
|
310
|
+
return repr(obj), True, False
|
311
|
+
|
312
|
+
r = getattr(typ, "__repr__", None)
|
313
|
+
if issubclass(typ, dict) and r is dict.__repr__:
|
314
|
+
if not obj:
|
315
|
+
return "{}", True, False
|
316
|
+
objid = id(obj)
|
317
|
+
if maxlevels and level >= maxlevels:
|
318
|
+
return "{...}", False, objid in context
|
319
|
+
if objid in context:
|
320
|
+
return pprint._recursion(obj), False, True
|
321
|
+
context[objid] = 1
|
322
|
+
readable = True
|
323
|
+
recursive = False
|
324
|
+
components = []
|
325
|
+
append = components.append
|
326
|
+
level += 1
|
327
|
+
saferepr = _safe_repr
|
328
|
+
items = sorted(obj.items(), key=pprint._safe_tuple)
|
329
|
+
for k, v in items:
|
330
|
+
krepr, kreadable, krecur = saferepr(
|
331
|
+
k, context, maxlevels, level, changed_only=changed_only
|
332
|
+
)
|
333
|
+
vrepr, vreadable, vrecur = saferepr(
|
334
|
+
v, context, maxlevels, level, changed_only=changed_only
|
335
|
+
)
|
336
|
+
append("%s: %s" % (krepr, vrepr))
|
337
|
+
readable = readable and kreadable and vreadable
|
338
|
+
if krecur or vrecur:
|
339
|
+
recursive = True
|
340
|
+
del context[objid]
|
341
|
+
return "{%s}" % ", ".join(components), readable, recursive
|
342
|
+
|
343
|
+
if (issubclass(typ, list) and r is list.__repr__) or (
|
344
|
+
issubclass(typ, tuple) and r is tuple.__repr__
|
345
|
+
):
|
346
|
+
if issubclass(typ, list):
|
347
|
+
if not obj:
|
348
|
+
return "[]", True, False
|
349
|
+
format_ = "[%s]"
|
350
|
+
elif len(obj) == 1:
|
351
|
+
format_ = "(%s,)"
|
352
|
+
else:
|
353
|
+
if not obj:
|
354
|
+
return "()", True, False
|
355
|
+
format_ = "(%s)"
|
356
|
+
objid = id(obj)
|
357
|
+
if maxlevels and level >= maxlevels:
|
358
|
+
return format_ % "...", False, objid in context
|
359
|
+
if objid in context:
|
360
|
+
return pprint._recursion(obj), False, True
|
361
|
+
context[objid] = 1
|
362
|
+
readable = True
|
363
|
+
recursive = False
|
364
|
+
components = []
|
365
|
+
append = components.append
|
366
|
+
level += 1
|
367
|
+
for o in obj:
|
368
|
+
orepr, oreadable, orecur = _safe_repr(
|
369
|
+
o, context, maxlevels, level, changed_only=changed_only
|
370
|
+
)
|
371
|
+
append(orepr)
|
372
|
+
if not oreadable:
|
373
|
+
readable = False
|
374
|
+
if orecur:
|
375
|
+
recursive = True
|
376
|
+
del context[objid]
|
377
|
+
return format_ % ", ".join(components), readable, recursive
|
378
|
+
|
379
|
+
if issubclass(typ, BaseObject):
|
380
|
+
objid = id(obj)
|
381
|
+
if maxlevels and level >= maxlevels:
|
382
|
+
return "{...}", False, objid in context
|
383
|
+
if objid in context:
|
384
|
+
return pprint._recursion(obj), False, True
|
385
|
+
context[objid] = 1
|
386
|
+
readable = True
|
387
|
+
recursive = False
|
388
|
+
if changed_only:
|
389
|
+
params = _changed_params(obj)
|
390
|
+
else:
|
391
|
+
params = obj.get_params(deep=False)
|
392
|
+
components = []
|
393
|
+
append = components.append
|
394
|
+
level += 1
|
395
|
+
saferepr = _safe_repr
|
396
|
+
items = sorted(params.items(), key=pprint._safe_tuple)
|
397
|
+
for k, v in items:
|
398
|
+
krepr, kreadable, krecur = saferepr(
|
399
|
+
k, context, maxlevels, level, changed_only=changed_only
|
400
|
+
)
|
401
|
+
vrepr, vreadable, vrecur = saferepr(
|
402
|
+
v, context, maxlevels, level, changed_only=changed_only
|
403
|
+
)
|
404
|
+
append("%s=%s" % (krepr.strip("'"), vrepr))
|
405
|
+
readable = readable and kreadable and vreadable
|
406
|
+
if krecur or vrecur:
|
407
|
+
recursive = True
|
408
|
+
del context[objid]
|
409
|
+
return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive)
|
410
|
+
|
411
|
+
rep = repr(obj)
|
412
|
+
return rep, (rep and not rep.startswith("<")), False
|