scikit-base 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff shows the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.7.2
3
+ Version: 0.7.4
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.7.2 is now available. Check out our
117
+ :rocket: Version 0.7.4 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
@@ -1,16 +1,16 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=sQyqEwB4g0eTD28WSUR-p2oMyfVkfk4lapeLppVk7Ks,345
2
+ skbase/__init__.py,sha256=abk0HlHOHt1z9B3iOIIUbVD4EdH8nJ3T-1HR7UuI9u0,345
3
3
  skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=J-6Wwor9yn02zg4vl2WAZfE-kwdhIggfV1xYSpMwP-0,53249
6
+ skbase/base/_base.py,sha256=1MJgavydCw-4TNqA4Na_7LMVoh4w4D5q81l15SbKJUM,53490
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
10
10
  skbase/base/_pretty_printing/_object_html_repr.py,sha256=0DHcM3AHIRkV1fCRi-G7lzDmiSTR2-MjU40iXUuV2AM,11538
11
11
  skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
12
12
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
13
- skbase/lookup/_lookup.py,sha256=YYsqz71VyPCAbqTiAm-WAuHyLUfO5pOgNlaKMdFGCOU,39824
13
+ skbase/lookup/_lookup.py,sha256=7L1JIMCzpMdSF5ZqHNDeIaHu4QRwXoLJ4DgM1Z_uFts,39864
14
14
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
15
15
  skbase/lookup/tests/test_lookup.py,sha256=_VDReGKnJF52UtFbvg_D2vlAkVvREypwM-9jR7DPAXQ,38218
16
16
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
@@ -19,7 +19,7 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
19
19
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
20
20
  skbase/testing/utils/inspect.py,sha256=XcPdm1-J3YXCTxsrqeJlStPvbC0vH1cgaApN5lzRI2c,741
21
21
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
22
- skbase/tests/conftest.py,sha256=WB8aoyiOtrQealUDCBwU35oq-KiSuz_CF235xdTgoSY,9106
22
+ skbase/tests/conftest.py,sha256=F-D3fqengjnaVSk2L4mYh8Wg_o0kS7L3wmGi2vU1B94,9272
23
23
  skbase/tests/test_base.py,sha256=-kyVDOQRdXYsBmSTqNjZ06mjnt_OWoY2i2i71qx3TF8,50648
24
24
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
25
25
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
@@ -34,14 +34,14 @@ skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
34
34
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
35
35
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
36
36
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
37
- skbase/utils/deep_equals/_deep_equals.py,sha256=KUr1Qat7kL1CuW78aSALPFinUiRQZUUSoBxdRbZMY4E,17318
38
- skbase/utils/dependencies/__init__.py,sha256=89TNnES--f1PeoPm-_h6a2mCtoGXt6mAd-n89FdusMM,352
39
- skbase/utils/dependencies/_dependencies.py,sha256=zmlWZ10HtiHE2PK2T7AzsFNKIA95O8hmIySlz2t4Mrs,10359
37
+ skbase/utils/deep_equals/_deep_equals.py,sha256=-blJhvTGdk4WjiSjBo8t954LysODZwPZoPHk2SBPzCQ,17615
38
+ skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
39
+ skbase/utils/dependencies/_dependencies.py,sha256=L3_ghGBHzaHX964b0bCw7H_Q5X4ILZ5LsQYCEAmZq5U,14501
40
40
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
41
41
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
42
42
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
43
43
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
44
- skbase/utils/tests/test_deep_equals.py,sha256=kRCQ87Tb0lIjP-3kf29wND2J4pCZpBVKC8iaKd8k7YY,2635
44
+ skbase/utils/tests/test_deep_equals.py,sha256=ZKrnCR4Ph14FgBhlIoxxpn8Pki7TGKbYYtymoJz0Fqk,2786
45
45
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
46
46
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
47
47
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
@@ -52,9 +52,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
52
52
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
53
53
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
54
54
  skbase/validate/tests/test_type_validations.py,sha256=G-qwFjXk-8WvXoeOvo2omfFKKjbpWhP-sPf6hsw8q30,14131
55
- scikit_base-0.7.2.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
56
- scikit_base-0.7.2.dist-info/METADATA,sha256=PVrJAa0Bch0rbC4ks09SBgYepftCp_QZUttyN0eGBsI,8704
57
- scikit_base-0.7.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
58
- scikit_base-0.7.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
59
- scikit_base-0.7.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
60
- scikit_base-0.7.2.dist-info/RECORD,,
55
+ scikit_base-0.7.4.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
56
+ scikit_base-0.7.4.dist-info/METADATA,sha256=yktJpyUY8DuNNcflKdRmVroKTBQ1pbb-1tZldt3vGsk,8704
57
+ scikit_base-0.7.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
58
+ scikit_base-0.7.4.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
59
+ scikit_base-0.7.4.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
60
+ scikit_base-0.7.4.dist-info/RECORD,,
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.7.2"
9
+ __version__: str = "0.7.4"
skbase/base/_base.py CHANGED
@@ -63,7 +63,7 @@ from skbase._exceptions import NotFittedError
63
63
  from skbase.base._pretty_printing._object_html_repr import _object_html_repr
64
64
  from skbase.base._tagmanager import _FlagManager
65
65
 
66
- __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
66
+ __author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
67
67
  __all__: List[str] = ["BaseEstimator", "BaseObject"]
68
68
 
69
69
 
@@ -157,113 +157,11 @@ class BaseObject(_FlagManager):
157
157
  -----
158
158
  If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
159
159
  """
160
- self_params = self.get_params(deep=False)
161
- self_clone = self._clone(self)
162
-
163
- # if checking the clone is turned off, return now
164
- if not self.get_config()["check_clone"]:
165
- return self_clone
166
-
167
- from skbase.utils.deep_equals import deep_equals
168
-
169
- # check that all attributes are written to the clone
170
- for attrname in self_params.keys():
171
- if not hasattr(self_clone, attrname):
172
- raise RuntimeError(
173
- f"error in {self}.clone, __init__ must write all arguments "
174
- f"to self and not mutate them, but {attrname} was not found. "
175
- f"Please check __init__ of {self}."
176
- )
177
-
178
- clone_attrs = {attr: getattr(self_clone, attr) for attr in self_params.keys()}
179
-
180
- # check equality of parameters post-clone and pre-clone
181
- clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
182
- if not clone_attrs_valid:
183
- raise RuntimeError(
184
- f"error in {self}.clone, __init__ must write all arguments "
185
- f"to self and not mutate them, but this is not the case. "
186
- f"Error on equality check of arguments (x) vs parameters (y): {msg}"
187
- )
188
-
160
+ self_clone = _clone(self)
161
+ if self.get_config()["check_clone"]:
162
+ _check_clone(original=self, clone=self_clone)
189
163
  return self_clone
190
164
 
191
- # copied from sklearn
192
- def _clone(self, estimator, *, safe=True):
193
- """Construct a new unfitted estimator with the same parameters.
194
-
195
- Clone does a deep copy of the model in an estimator
196
- without actually copying attached data. It returns a new estimator
197
- with the same parameters that has not been fitted on any data.
198
-
199
- Parameters
200
- ----------
201
- estimator : {list, tuple, set} of estimator instance or a single \
202
- estimator instance
203
- The estimator or group of estimators to be cloned.
204
- safe : bool, default=True
205
- If safe is False, clone will fall back to a deep copy on objects
206
- that are not estimators.
207
-
208
- Returns
209
- -------
210
- estimator : object
211
- The deep copy of the input, an estimator if input is an estimator.
212
-
213
- Notes
214
- -----
215
- If the estimator's `random_state` parameter is an integer (or if the
216
- estimator doesn't have a `random_state` parameter), an *exact clone* is
217
- returned: the clone and the original estimator will give the exact same
218
- results. Otherwise, *statistical clone* is returned: the clone might
219
- return different results from the original estimator. More details can be
220
- found in :ref:`randomness`.
221
- """
222
- estimator_type = type(estimator)
223
- # XXX: not handling dictionaries
224
- if estimator_type in (list, tuple, set, frozenset):
225
- return estimator_type([self._clone(e, safe=safe) for e in estimator])
226
- elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
227
- if not safe:
228
- return deepcopy(estimator)
229
- else:
230
- if isinstance(estimator, type):
231
- raise TypeError(
232
- "Cannot clone object. "
233
- + "You should provide an instance of "
234
- + "scikit-learn estimator instead of a class."
235
- )
236
- else:
237
- raise TypeError(
238
- "Cannot clone object '%s' (type %s): "
239
- "it does not seem to be a scikit-learn "
240
- "estimator as it does not implement a "
241
- "'get_params' method." % (repr(estimator), type(estimator))
242
- )
243
-
244
- klass = estimator.__class__
245
- new_object_params = estimator.get_params(deep=False)
246
- for name, param in new_object_params.items():
247
- new_object_params[name] = self._clone(param, safe=False)
248
- new_object = klass(**new_object_params)
249
- params_set = new_object.get_params(deep=False)
250
-
251
- # quick sanity check of the parameters of the clone
252
- for name in new_object_params:
253
- param1 = new_object_params[name]
254
- param2 = params_set[name]
255
- if param1 is not param2:
256
- raise RuntimeError(
257
- "Cannot clone object %s, as the constructor "
258
- "either does not set or modifies parameter %s" % (estimator, name)
259
- )
260
-
261
- # This is an extension to the original sklearn implementation
262
- if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
263
- new_object.set_config(**estimator.get_config())
264
-
265
- return new_object
266
-
267
165
  @classmethod
268
166
  def _get_init_signature(cls):
269
167
  """Get class init signature.
@@ -687,16 +585,18 @@ class BaseObject(_FlagManager):
687
585
  `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
688
586
  `create_test_instance` uses the first (or only) dictionary in `params`
689
587
  """
588
+ params_with_defaults = set(cls.get_param_defaults().keys())
589
+ all_params = set(cls.get_param_names())
590
+ params_without_defaults = all_params - params_with_defaults
591
+
690
592
  # if non-default parameters are required, but none have been found, raise error
691
- if hasattr(cls, "_required_parameters"):
692
- required_parameters = getattr(cls, "_required_parameters", [])
693
- if len(required_parameters) > 0:
694
- raise ValueError(
695
- f"Estimator: {cls} requires "
696
- f"non-default parameters for construction, "
697
- f"but none were given. Please set them "
698
- f"as given in the extension template"
699
- )
593
+ if len(params_without_defaults) > 0:
594
+ raise ValueError(
595
+ f"Estimator: {cls} has parameters without default values, "
596
+ f"but these are not set in get_test_params. "
597
+ f"Please set them in get_test_params, or provide default values. "
598
+ f"Also see the respective extension template, if applicable."
599
+ )
700
600
 
701
601
  # construct with parameter configuration for testing, otherwise construct with
702
602
  # default parameters (empty dict)
@@ -737,7 +637,7 @@ class BaseObject(_FlagManager):
737
637
  "get_test_params should either return a dict or list of dict."
738
638
  )
739
639
 
740
- return cls(**params)
640
+ return cls._safe_init_test_params(params)
741
641
 
742
642
  @classmethod
743
643
  def create_test_instances_and_names(cls, parameter_set="default"):
@@ -757,9 +657,6 @@ class BaseObject(_FlagManager):
757
657
  i-th element is name of i-th instance of obj in tests
758
658
  convention is {cls.__name__}-{i} if more than one instance
759
659
  otherwise {cls.__name__}
760
- parameter_set : str, default="default"
761
- Name of the set of test parameters to return, for use in tests. If no
762
- special parameters are defined for a value, will return `"default"` set.
763
660
  """
764
661
  if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
765
662
  param_list = cls.get_test_params(parameter_set=parameter_set)
@@ -780,7 +677,7 @@ class BaseObject(_FlagManager):
780
677
  f"Error in {cls.__name__}.get_test_params, "
781
678
  "return must be param dict for class, or list thereof"
782
679
  )
783
- objs += [cls(**params)]
680
+ objs += [cls._safe_init_test_params(params)]
784
681
 
785
682
  num_instances = len(param_list)
786
683
  if num_instances > 1:
@@ -790,6 +687,22 @@ class BaseObject(_FlagManager):
790
687
 
791
688
  return objs, names
792
689
 
690
+ @classmethod
691
+ def _safe_init_test_params(cls, params):
692
+ """Safe init of cls with params for testing.
693
+
694
+ Will raise informative error message if params are not valid.
695
+ """
696
+ try:
697
+ return cls(**params)
698
+ except Exception as e:
699
+ raise type(e)(
700
+ f"Error in {cls.__name__}.get_test_params, "
701
+ "return must be valid param dict for class, or list thereof, "
702
+ "but attempted construction raised a exception. "
703
+ f"Problematic parameter set: {params}. Exception raised: {e}"
704
+ ) from e
705
+
793
706
  @classmethod
794
707
  def _has_implementation_of(cls, method):
795
708
  """Check if method has a concrete implementation in this class.
@@ -1387,3 +1300,106 @@ class BaseEstimator(BaseObject):
1387
1300
  fitted parameters, keyed by names of fitted parameter
1388
1301
  """
1389
1302
  return self._get_fitted_params_default()
1303
+
1304
+
1305
+ # Adapted from sklearn's `_clone_parametrized()`
1306
+ def _clone(estimator, *, safe=True):
1307
+ """Construct a new unfitted estimator with the same parameters.
1308
+
1309
+ Clone does a deep copy of the model in an estimator
1310
+ without actually copying attached data. It returns a new estimator
1311
+ with the same parameters that has not been fitted on any data.
1312
+
1313
+ Parameters
1314
+ ----------
1315
+ estimator : {list, tuple, set} of estimator instance or a single \
1316
+ estimator instance
1317
+ The estimator or group of estimators to be cloned.
1318
+ safe : bool, default=True
1319
+ If safe is False, clone will fall back to a deep copy on objects
1320
+ that are not estimators.
1321
+
1322
+ Returns
1323
+ -------
1324
+ estimator : object
1325
+ The deep copy of the input, an estimator if input is an estimator.
1326
+
1327
+ Notes
1328
+ -----
1329
+ If the estimator's `random_state` parameter is an integer (or if the
1330
+ estimator doesn't have a `random_state` parameter), an *exact clone* is
1331
+ returned: the clone and the original estimator will give the exact same
1332
+ results. Otherwise, *statistical clone* is returned: the clone might
1333
+ return different results from the original estimator. More details can be
1334
+ found in :ref:`randomness`.
1335
+ """
1336
+ estimator_type = type(estimator)
1337
+ # XXX: not handling dictionaries
1338
+ if estimator_type in (list, tuple, set, frozenset):
1339
+ return estimator_type([_clone(e, safe=safe) for e in estimator])
1340
+ elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
1341
+ if not safe:
1342
+ return deepcopy(estimator)
1343
+ else:
1344
+ if isinstance(estimator, type):
1345
+ raise TypeError(
1346
+ "Cannot clone object. "
1347
+ + "You should provide an instance of "
1348
+ + "scikit-learn estimator instead of a class."
1349
+ )
1350
+ else:
1351
+ raise TypeError(
1352
+ "Cannot clone object '%s' (type %s): "
1353
+ "it does not seem to be a scikit-learn "
1354
+ "estimator as it does not implement a "
1355
+ "'get_params' method." % (repr(estimator), type(estimator))
1356
+ )
1357
+
1358
+ klass = estimator.__class__
1359
+ new_object_params = estimator.get_params(deep=False)
1360
+ for name, param in new_object_params.items():
1361
+ new_object_params[name] = _clone(param, safe=False)
1362
+ new_object = klass(**new_object_params)
1363
+ params_set = new_object.get_params(deep=False)
1364
+
1365
+ # quick sanity check of the parameters of the clone
1366
+ for name in new_object_params:
1367
+ param1 = new_object_params[name]
1368
+ param2 = params_set[name]
1369
+ if param1 is not param2:
1370
+ raise RuntimeError(
1371
+ "Cannot clone object %s, as the constructor "
1372
+ "either does not set or modifies parameter %s" % (estimator, name)
1373
+ )
1374
+
1375
+ # This is an extension to the original sklearn implementation
1376
+ if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
1377
+ new_object.set_config(**estimator.get_config())
1378
+
1379
+ return new_object
1380
+
1381
+
1382
+ def _check_clone(original, clone):
1383
+ from skbase.utils.deep_equals import deep_equals
1384
+
1385
+ self_params = original.get_params(deep=False)
1386
+
1387
+ # check that all attributes are written to the clone
1388
+ for attrname in self_params.keys():
1389
+ if not hasattr(clone, attrname):
1390
+ raise RuntimeError(
1391
+ f"error in {original}.clone, __init__ must write all arguments "
1392
+ f"to self and not mutate them, but {attrname} was not found. "
1393
+ f"Please check __init__ of {original}."
1394
+ )
1395
+
1396
+ clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
1397
+
1398
+ # check equality of parameters post-clone and pre-clone
1399
+ clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
1400
+ if not clone_attrs_valid:
1401
+ raise RuntimeError(
1402
+ f"error in {original}.clone, __init__ must write all arguments "
1403
+ f"to self and not mutate them, but this is not the case. "
1404
+ f"Error on equality check of arguments (x) vs parameters (y): {msg}"
1405
+ )
skbase/lookup/_lookup.py CHANGED
@@ -693,7 +693,6 @@ def all_objects(
693
693
  object_types=None,
694
694
  filter_tags=None,
695
695
  exclude_objects=None,
696
- exclude_estimators=None,
697
696
  return_names=True,
698
697
  as_dataframe=False,
699
698
  return_tags=None,
@@ -701,7 +700,6 @@ def all_objects(
701
700
  package_name="skbase",
702
701
  path: Optional[str] = None,
703
702
  modules_to_ignore=None,
704
- ignore_modules=None,
705
703
  class_lookup=None,
706
704
  ):
707
705
  """Get a list of all objects in a package with name `package_name`.
@@ -825,9 +823,12 @@ def all_objects(
825
823
  return name.startswith("_") or name.startswith("Base")
826
824
 
827
825
  def _is_estimator(name, klass):
828
- # Check if klass is subclass of base estimators, not an base class itself and
826
+ # Check if klass is subclass of base estimators, not a base class itself and
829
827
  # not an abstract class
830
- return issubclass(klass, BaseObject) and not _is_base_class(name)
828
+ if object_types is None:
829
+ return issubclass(klass, BaseObject) and not _is_base_class(name)
830
+ else:
831
+ return not _is_base_class(name)
831
832
 
832
833
  # Ignore deprecation warnings triggered at import time and from walking packages
833
834
  with warnings.catch_warnings():
skbase/tests/conftest.py CHANGED
@@ -178,6 +178,10 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
178
178
  SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
179
179
  SKBASE_FUNCTIONS_BY_MODULE.update(
180
180
  {
181
+ "skbase.base._base": (
182
+ "_clone",
183
+ "_check_clone",
184
+ ),
181
185
  "skbase.base._pretty_printing._object_html_repr": (
182
186
  "_get_visual_block",
183
187
  "_object_html_repr",
@@ -205,6 +209,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
205
209
  "skbase.utils.dependencies": (
206
210
  "_check_soft_dependencies",
207
211
  "_check_python_version",
212
+ "_check_estimator_deps",
208
213
  ),
209
214
  "skbase.utils._iter": (
210
215
  "_format_seq_to_str",
@@ -240,6 +245,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
240
245
  "skbase.utils.dependencies._dependencies": (
241
246
  "_check_soft_dependencies",
242
247
  "_check_python_version",
248
+ "_check_estimator_deps",
243
249
  ),
244
250
  "skbase.utils.random_state": (
245
251
  "check_random_state",
@@ -480,6 +480,14 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
480
480
  if res is not None:
481
481
  return res
482
482
 
483
+ # if the object x and y have a len() then compare of x and y lengths else continue
484
+ if hasattr(x, "__len__") and hasattr(y, "__len__"):
485
+ if len(x) != len(y):
486
+ return ret(
487
+ False,
488
+ f".len, x.len = {len(x)} != y.len = {len(y)}",
489
+ )
490
+
483
491
  # this if covers case where != is boolean
484
492
  # some types return a vector upon !=, this is covered in the next elif
485
493
  if isinstance(x == y, bool):
@@ -4,8 +4,13 @@
4
4
  """Utility functionality used through `skbase`."""
5
5
 
6
6
  from skbase.utils.dependencies._dependencies import (
7
+ _check_estimator_deps,
7
8
  _check_python_version,
8
9
  _check_soft_dependencies,
9
10
  )
10
11
 
11
- __all__ = ["_check_python_version", "_check_soft_dependencies"]
12
+ __all__ = [
13
+ "_check_python_version",
14
+ "_check_soft_dependencies",
15
+ "_check_estimator_deps",
16
+ ]
@@ -18,6 +18,7 @@ def _check_soft_dependencies(
18
18
  package_import_alias=None,
19
19
  severity="error",
20
20
  obj=None,
21
+ msg=None,
21
22
  suppress_import_stdout=False,
22
23
  ):
23
24
  """Check if required soft dependencies are installed and raise error or warning.
@@ -40,7 +41,7 @@ def _check_soft_dependencies(
40
41
  should be provided if import name differs from package name
41
42
  severity : str, "error" (default), "warning", "none"
42
43
  behaviour for raising errors or warnings
43
- "error" - raises a `ModuleNotFoundException` if one of packages is not installed
44
+ "error" - raises a `ModuleNotFoundError` if one of packages is not installed
44
45
  "warning" - raises a warning if one of packages is not installed
45
46
  function returns False if one of packages is not installed, otherwise True
46
47
  "none" - does not raise exception or warning
@@ -50,6 +51,8 @@ def _check_soft_dependencies(
50
51
  or a class is passed when it is called at the start of a single-class module,
51
52
  the error message is more informative and will refer to the class/object;
52
53
  if str is passed, will be used as name of the class/object or module
54
+ msg : str, or None, default=None
55
+ if str, will override the error message or warning shown with msg
53
56
  suppress_import_stdout : bool, optional. Default=False
54
57
  whether to suppress stdout printout upon import.
55
58
 
@@ -65,17 +68,24 @@ def _check_soft_dependencies(
65
68
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
66
69
  packages = packages[0]
67
70
  if not all(isinstance(x, str) for x in packages):
68
- raise TypeError("packages must be str or tuple of str")
71
+ raise TypeError(
72
+ "packages argument of _check_soft_dependencies must be str or tuple of "
73
+ f"str, but found packages argument of type {type(packages)}"
74
+ )
69
75
 
70
76
  if package_import_alias is None:
71
77
  package_import_alias = {}
72
- msg = "package_import_alias must be a dict with str keys and values"
78
+ msg_pkg_import_alias = (
79
+ "package_import_alias argument of _check_soft_dependencies must "
80
+ "be a dict with str keys and values, but found "
81
+ f"package_import_alias of type {type(package_import_alias)}"
82
+ )
73
83
  if not isinstance(package_import_alias, dict):
74
- raise TypeError(msg)
84
+ raise TypeError(msg_pkg_import_alias)
75
85
  if not all(isinstance(x, str) for x in package_import_alias.keys()):
76
- raise TypeError(msg)
86
+ raise TypeError(msg_pkg_import_alias)
77
87
  if not all(isinstance(x, str) for x in package_import_alias.values()):
78
- raise TypeError(msg)
88
+ raise TypeError(msg_pkg_import_alias)
79
89
 
80
90
  if obj is None:
81
91
  class_name = "This functionality"
@@ -86,7 +96,17 @@ def _check_soft_dependencies(
86
96
  elif isinstance(obj, str):
87
97
  class_name = obj
88
98
  else:
89
- raise TypeError("obj must be a class, an object, a str, or None")
99
+ raise TypeError(
100
+ "obj argument of _check_soft_dependencies must be a class, an object,"
101
+ " a str, or None, but found obj of type"
102
+ f" {type(obj)}"
103
+ )
104
+
105
+ if msg is not None and not isinstance(msg, str):
106
+ raise TypeError(
107
+ "msg argument of _check_soft_dependencies must be a str, "
108
+ f"or None, but found msg of type {type(msg)}"
109
+ )
90
110
 
91
111
  for package in packages:
92
112
  try:
@@ -94,6 +114,7 @@ def _check_soft_dependencies(
94
114
  except InvalidRequirement:
95
115
  msg_version = (
96
116
  f"wrong format for package requirement string, "
117
+ f"passed via packages argument of _check_soft_dependencies, "
97
118
  f'must be PEP 440 compatible requirement string, e.g., "pandas"'
98
119
  f' or "pandas>1.1", but found {package!r}'
99
120
  )
@@ -118,20 +139,23 @@ def _check_soft_dependencies(
118
139
  pkg_ref = import_module(package_import_name)
119
140
  # if package cannot be imported, make the user aware of installation requirement
120
141
  except ModuleNotFoundError as e:
121
- msg = (
122
- f"{e}. "
123
- f"{class_name} requires package {package!r} to be present "
124
- f"in the python environment, but {package!r} was not found. "
125
- )
126
- if obj is not None:
142
+ if msg is None:
143
+ msg = (
144
+ f"{e}. "
145
+ f"{class_name} requires package {package!r} to be present "
146
+ f"in the python environment, but {package!r} was not found. "
147
+ )
148
+ if obj is not None:
149
+ msg = msg + (
150
+ f"{package!r} is a dependency of {class_name} and required "
151
+ f"to construct it. "
152
+ )
127
153
  msg = msg + (
128
- f"{package!r} is a dependency of {class_name} and required "
129
- f"to construct it. "
154
+ f"Please run: `pip install {package}` to "
155
+ f"install the {package} package. "
130
156
  )
131
- msg = msg + (
132
- f"Please run: `pip install {package}` to "
133
- f"install the {package} package. "
134
- )
157
+ # if msg is not None, none of the above is executed,
158
+ # so if msg is passed it overrides the default messages
135
159
 
136
160
  if severity == "error":
137
161
  raise ModuleNotFoundError(msg) from e
@@ -227,10 +251,14 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
227
251
  if sys_version in est_specifier:
228
252
  return True
229
253
  # now we know that est_version is not compatible with sys_version
254
+ if isclass(obj):
255
+ class_name = obj.__name__
256
+ else:
257
+ class_name = type(obj).__name__
230
258
 
231
259
  if not isinstance(msg, str):
232
260
  msg = (
233
- f"{type(obj).__name__} requires python version to be {est_specifier},"
261
+ f"{class_name} requires python version to be {est_specifier},"
234
262
  f" but system python version is {sys.version}."
235
263
  )
236
264
 
@@ -251,3 +279,67 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
251
279
  f'argument must be "error", "warning", or "none", found {severity!r}.'
252
280
  )
253
281
  return True
282
+
283
+
284
+ def _check_estimator_deps(obj, msg=None, severity="error"):
285
+ """Check if object/estimator's package & python requirements are met by python env.
286
+
287
+ Convenience wrapper around `_check_python_version` and `_check_soft_dependencies`,
288
+ checking against estimator tags `"python_version"`, `"python_dependencies"`.
289
+
290
+ Checks whether dependency requirements of `BaseObject`-s in `obj`
291
+ are satisfied by the current python environment.
292
+
293
+ Parameters
294
+ ----------
295
+ obj : `BaseObject` descendant, instance or class, or list/tuple thereof
296
+ object(s) that this function checks compatibility of, with the python env
297
+ msg : str, optional, default = default message (msg below)
298
+ error message to be returned in the `ModuleNotFoundError`, overrides default
299
+ severity : str, "error" (default), "warning", or "none"
300
+ behaviour for raising errors or warnings
301
+ "error" - raises a `ModuleNotFoundError` if environment is incompatible
302
+ "warning" - raises a warning if environment is incompatible
303
+ function returns False if environment is incompatible, otherwise True
304
+ "none" - does not raise exception or warning
305
+ function returns False if environment is incompatible, otherwise True
306
+
307
+ Returns
308
+ -------
309
+ compatible : bool, whether `obj` is compatible with python environment
310
+ False is returned only if no exception is raised by the function
311
+ checks for python version using the python_version tag of obj
312
+ checks for soft dependencies present using the python_dependencies tag of obj
313
+ if `obj` contains multiple `BaseObject`-s, checks whether all are compatible
314
+
315
+ Raises
316
+ ------
317
+ ModuleNotFoundError
318
+ User friendly error if obj has python_version tag that is
319
+ incompatible with the system python version.
320
+ Compatible python versions are determined by the "python_version" tag of obj.
321
+ User friendly error if obj has package dependencies that are not satisfied.
322
+ Packages are determined based on the "python_dependencies" tag of obj.
323
+ """
324
+ compatible = True
325
+
326
+ # if list or tuple, recurse & iterate over element, and return conjunction
327
+ if isinstance(obj, (list, tuple)):
328
+ for x in obj:
329
+ x_chk = _check_estimator_deps(x, msg=msg, severity=severity)
330
+ compatible = compatible and x_chk
331
+ return compatible
332
+
333
+ compatible = compatible and _check_python_version(obj, severity=severity)
334
+
335
+ pkg_deps = obj.get_class_tag("python_dependencies", None)
336
+ pck_alias = obj.get_class_tag("python_dependencies_alias", None)
337
+ if pkg_deps is not None and not isinstance(pkg_deps, list):
338
+ pkg_deps = [pkg_deps]
339
+ if pkg_deps is not None:
340
+ pkg_deps_ok = _check_soft_dependencies(
341
+ *pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
342
+ )
343
+ compatible = compatible and pkg_deps_ok
344
+
345
+ return compatible
@@ -23,6 +23,7 @@ if _check_soft_dependencies("numpy", severity="none"):
23
23
  EXAMPLES += [
24
24
  np.array([2, 3, 4]),
25
25
  np.array([2, 4, 5]),
26
+ np.array([2, 4, 5, 4]),
26
27
  np.nan,
27
28
  # these cases test that plugins are passed to recursions
28
29
  # in this case, the numpy equality plugin
@@ -31,6 +32,7 @@ if _check_soft_dependencies("numpy", severity="none"):
31
32
  # test case to cover branch re dtype and equal_nan
32
33
  np.array([0.1, 1], dtype="object"),
33
34
  np.array([0.2, 1], dtype="object"),
35
+ np.array([0.2, 1, 4], dtype="object"),
34
36
  ]
35
37
 
36
38
  if _check_soft_dependencies("pandas", severity="none"):
@@ -39,12 +41,14 @@ if _check_soft_dependencies("pandas", severity="none"):
39
41
  EXAMPLES += [
40
42
  pd.DataFrame({"a": [4, 2]}),
41
43
  pd.DataFrame({"a": [4, 3]}),
44
+ pd.DataFrame({"a": [4, 3, 5]}),
42
45
  pd.DataFrame({"a": ["4", "3"]}),
43
46
  (np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]),
44
47
  {"foo": [42], "bar": pd.Series([1, 2])},
45
48
  {"bar": [42], "foo": pd.Series([1, 2])},
46
49
  pd.Index([1, 2, 3]),
47
50
  pd.Index([2, 3, 4]),
51
+ pd.Index([2, 3, 4, 6]),
48
52
  ]
49
53
 
50
54
  # nested DataFrame example