spforge 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spforge/__init__.py CHANGED
@@ -2,6 +2,7 @@ from .autopipeline import AutoPipeline as AutoPipeline
2
2
  from .data_structures import ColumnNames as ColumnNames, GameColumnNames as GameColumnNames
3
3
  from .features_generator_pipeline import FeatureGeneratorPipeline as FeatureGeneratorPipeline
4
4
  from .hyperparameter_tuning import (
5
+ EstimatorHyperparameterTuner as EstimatorHyperparameterTuner,
5
6
  OptunaResult as OptunaResult,
6
7
  ParamSpec as ParamSpec,
7
8
  RatingHyperparameterTuner as RatingHyperparameterTuner,
@@ -1,9 +1,15 @@
1
1
  from spforge.hyperparameter_tuning._default_search_spaces import (
2
+ get_default_estimator_search_space,
3
+ get_default_lgbm_search_space,
4
+ get_default_negative_binomial_search_space,
5
+ get_default_normal_distribution_search_space,
2
6
  get_default_player_rating_search_space,
3
7
  get_default_search_space,
8
+ get_default_student_t_search_space,
4
9
  get_default_team_rating_search_space,
5
10
  )
6
11
  from spforge.hyperparameter_tuning._tuner import (
12
+ EstimatorHyperparameterTuner,
7
13
  OptunaResult,
8
14
  ParamSpec,
9
15
  RatingHyperparameterTuner,
@@ -11,9 +17,15 @@ from spforge.hyperparameter_tuning._tuner import (
11
17
 
12
18
# Public API of spforge.hyperparameter_tuning, alphabetized for easy scanning.
__all__ = [
    "EstimatorHyperparameterTuner",
    "OptunaResult",
    "ParamSpec",
    "RatingHyperparameterTuner",
    "get_default_estimator_search_space",
    "get_default_lgbm_search_space",
    "get_default_negative_binomial_search_space",
    "get_default_normal_distribution_search_space",
    "get_default_player_rating_search_space",
    "get_default_search_space",
    "get_default_student_t_search_space",
    "get_default_team_rating_search_space",
]
@@ -1,5 +1,126 @@
1
1
  from spforge.hyperparameter_tuning._tuner import ParamSpec
2
2
  from spforge.ratings import PlayerRatingGenerator, TeamRatingGenerator
3
+ from spforge.distributions import (
4
+ NegativeBinomialEstimator,
5
+ NormalDistributionPredictor,
6
+ StudentTDistributionEstimator,
7
+ )
8
+
9
+
10
+ def _is_lightgbm_estimator(obj: object) -> bool:
11
+ mod = (getattr(type(obj), "__module__", "") or "").lower()
12
+ name = type(obj).__name__
13
+ if "lightgbm" in mod:
14
+ return True
15
+ return bool(name.startswith("LGBM"))
16
+
17
+
18
def get_default_lgbm_search_space() -> dict[str, ParamSpec]:
    """Return the default Optuna search space for LightGBM-style boosters.

    Ranges follow common LightGBM tuning practice; capacity-related
    parameters (trees, leaves, child samples) and regularization strengths
    are sampled on a log scale.
    """
    space: dict[str, ParamSpec] = {}
    space["n_estimators"] = ParamSpec(param_type="int", low=50, high=800, log=True)
    space["num_leaves"] = ParamSpec(param_type="int", low=16, high=256, log=True)
    space["max_depth"] = ParamSpec(param_type="int", low=3, high=12)
    space["min_child_samples"] = ParamSpec(param_type="int", low=10, high=200, log=True)
    space["subsample"] = ParamSpec(param_type="float", low=0.6, high=1.0)
    space["subsample_freq"] = ParamSpec(param_type="int", low=1, high=7)
    space["reg_alpha"] = ParamSpec(param_type="float", low=1e-8, high=10.0, log=True)
    space["reg_lambda"] = ParamSpec(param_type="float", low=1e-8, high=10.0, log=True)
    return space
66
+
67
+
68
def get_default_negative_binomial_search_space() -> dict[str, ParamSpec]:
    """Return the default search space for ``NegativeBinomialEstimator``."""
    space: dict[str, ParamSpec] = {}
    space["predicted_r_weight"] = ParamSpec(param_type="float", low=0.0, high=1.0)
    space["r_rolling_mean_window"] = ParamSpec(param_type="int", low=10, high=120)
    space["predicted_r_iterations"] = ParamSpec(param_type="int", low=2, high=12)
    return space
86
+
87
+
88
def get_default_normal_distribution_search_space() -> dict[str, ParamSpec]:
    """Return the default search space for ``NormalDistributionPredictor``.

    Only the standard deviation is tuned, log-scaled over a wide range.
    """
    return {"sigma": ParamSpec(param_type="float", low=0.5, high=30.0, log=True)}
97
+
98
+
99
def get_default_student_t_search_space() -> dict[str, ParamSpec]:
    """Return the default search space for ``StudentTDistributionEstimator``."""
    space: dict[str, ParamSpec] = {}
    space["df"] = ParamSpec(param_type="float", low=3.0, high=30.0, log=True)
    space["min_sigma"] = ParamSpec(param_type="float", low=0.5, high=10.0, log=True)
    space["sigma_bins"] = ParamSpec(param_type="int", low=4, high=12)
    space["min_bin_rows"] = ParamSpec(param_type="int", low=10, high=100)
    return space
3
124
 
4
125
 
5
126
  def get_default_player_rating_search_space() -> dict[str, ParamSpec]:
@@ -120,3 +241,15 @@ def get_default_search_space(
120
241
  f"Unsupported rating generator type: {type(rating_generator)}. "
121
242
  "Expected PlayerRatingGenerator or TeamRatingGenerator."
122
243
  )
244
+
245
+
246
def get_default_estimator_search_space(estimator: object) -> dict[str, ParamSpec]:
    """Resolve a default search space for a recognized estimator type.

    Checks LightGBM first (duck-typed), then the spforge distribution
    estimators, and returns an empty dict for anything unrecognized.
    """
    dispatch = [
        (_is_lightgbm_estimator, get_default_lgbm_search_space),
        (lambda e: isinstance(e, NegativeBinomialEstimator), get_default_negative_binomial_search_space),
        (lambda e: isinstance(e, NormalDistributionPredictor), get_default_normal_distribution_search_space),
        (lambda e: isinstance(e, StudentTDistributionEstimator), get_default_student_t_search_space),
    ]
    for matches, build_space in dispatch:
        if matches(estimator):
            return build_space()
    return {}
@@ -45,6 +45,8 @@ class ParamSpec:
45
45
  elif self.param_type == "int":
46
46
  if self.low is None or self.high is None:
47
47
  raise ValueError(f"int parameter '{name}' requires low and high bounds")
48
+ if self.step is None:
49
+ return trial.suggest_int(name, int(self.low), int(self.high))
48
50
  return trial.suggest_int(name, int(self.low), int(self.high), step=self.step)
49
51
  elif self.param_type == "categorical":
50
52
  if self.choices is None:
@@ -272,3 +274,193 @@ class RatingHyperparameterTuner:
272
274
  raise ValueError("Scorer returned invalid values in dict")
273
275
  return float(np.mean(values))
274
276
  return float(score)
277
+
278
+
279
+ def _is_estimator(obj: object) -> bool:
280
+ return hasattr(obj, "get_params") and hasattr(obj, "set_params")
281
+
282
+
283
+ def _get_leaf_estimator_paths(estimator: Any) -> dict[str, Any]:
284
+ if not _is_estimator(estimator):
285
+ raise ValueError("estimator must implement get_params and set_params")
286
+
287
+ params = estimator.get_params(deep=True)
288
+ estimator_keys = [k for k, v in params.items() if _is_estimator(v)]
289
+
290
+ if not estimator_keys:
291
+ return {"": estimator}
292
+
293
+ leaves: list[str] = []
294
+ for key in estimator_keys:
295
+ if not any(other != key and other.startswith(f"{key}__") for other in estimator_keys):
296
+ leaves.append(key)
297
+
298
+ return {key: params[key] for key in sorted(leaves)}
299
+
300
+
301
+ def _build_search_space_for_targets(
302
+ targets: dict[str, dict[str, ParamSpec]],
303
+ ) -> dict[str, ParamSpec]:
304
+ search_space: dict[str, ParamSpec] = {}
305
+ for path, params in targets.items():
306
+ for param_name, param_spec in params.items():
307
+ full_name = f"{path}__{param_name}" if path else param_name
308
+ if full_name in search_space:
309
+ raise ValueError(f"Duplicate parameter name detected: {full_name}")
310
+ search_space[full_name] = param_spec
311
+ return search_space
312
+
313
+
314
+ def _enqueue_predicted_r_weight_zero(study: optuna.Study, search_space: dict[str, ParamSpec]):
315
+ zero_params: dict[str, float] = {}
316
+ for name, spec in search_space.items():
317
+ if not name.endswith("predicted_r_weight"):
318
+ continue
319
+ if spec.param_type not in {"float", "int"}:
320
+ continue
321
+ if spec.low is None or spec.high is None:
322
+ continue
323
+ if spec.low <= 0 <= spec.high:
324
+ zero_params[name] = 0.0
325
+
326
+ if zero_params:
327
+ study.enqueue_trial(zero_params)
328
+
329
+
330
class EstimatorHyperparameterTuner:
    """
    Hyperparameter tuner for sklearn-compatible estimators.

    Supports nested estimators and can target deepest leaf estimators.

    For every Optuna trial the tuner deep-copies the estimator and
    cross-validator, applies suggested parameters via ``set_params``,
    cross-validates on the provided dataframe, and aggregates the scorer's
    output into a single float objective.
    """

    def __init__(
        self,
        estimator: Any,
        cross_validator: MatchKFoldCrossValidator,
        scorer: BaseScorer,
        direction: Literal["minimize", "maximize"],
        param_search_space: dict[str, ParamSpec] | None = None,
        param_targets: dict[str, dict[str, ParamSpec]] | None = None,
        n_trials: int = 50,
        n_jobs: int = 1,
        storage: str | None = None,
        study_name: str | None = None,
        timeout: float | None = None,
        show_progress_bar: bool = True,
        sampler: optuna.samplers.BaseSampler | None = None,
        pruner: optuna.pruners.BasePruner | None = None,
    ):
        # estimator must expose sklearn's get_params/set_params protocol;
        # nested estimators are discovered through get_params(deep=True).
        self.estimator = estimator
        self.cross_validator = cross_validator
        self.scorer = scorer
        self.direction = direction
        # At most one of param_search_space / param_targets may be given.
        # param_search_space applies the same specs to every leaf estimator;
        # param_targets maps explicit leaf paths to their own specs. When
        # both are None, per-type defaults are resolved in optimize().
        self.param_search_space = param_search_space
        self.param_targets = param_targets
        self.n_trials = n_trials
        self.n_jobs = n_jobs
        self.storage = storage
        self.study_name = study_name
        self.timeout = timeout
        self.show_progress_bar = show_progress_bar
        self.sampler = sampler
        self.pruner = pruner

        if direction not in ["minimize", "maximize"]:
            raise ValueError(f"direction must be 'minimize' or 'maximize', got: {direction}")

        if storage is not None and study_name is None:
            raise ValueError("study_name is required when using storage")

        if param_search_space is not None and param_targets is not None:
            raise ValueError("param_search_space and param_targets cannot both be provided")

    def optimize(self, df: IntoFrameT) -> OptunaResult:
        """Run the Optuna study over *df* and return the best trial's result.

        Raises:
            ValueError: if param_targets names unknown leaf paths, or if no
                usable search space can be resolved.
        """
        # Function-level import — presumably avoids a circular import with
        # the default-search-space module (TODO confirm).
        from spforge.hyperparameter_tuning._default_search_spaces import (
            get_default_estimator_search_space,
        )

        leaf_estimators = _get_leaf_estimator_paths(self.estimator)
        default_targets = {
            path: get_default_estimator_search_space(est)
            for path, est in leaf_estimators.items()
        }
        # Keep only leaves for which a default search space is known.
        default_targets = {path: space for path, space in default_targets.items() if space}

        if self.param_targets is not None:
            unknown = set(self.param_targets) - set(leaf_estimators)
            if unknown:
                raise ValueError(f"param_targets contains unknown estimator paths: {unknown}")
            targets = self.param_targets
        elif self.param_search_space is not None:
            # The same user-provided specs are applied to every leaf path.
            targets = {path: self.param_search_space for path in leaf_estimators}
        elif default_targets:
            targets = default_targets
        else:
            raise ValueError(
                "param_search_space is required when no default search space is available"
            )

        search_space = _build_search_space_for_targets(targets)
        if not search_space:
            raise ValueError("Resolved search space is empty")

        study = optuna.create_study(
            direction=self.direction,
            sampler=self.sampler,
            pruner=self.pruner,
            storage=self.storage,
            study_name=self.study_name,
            # Resuming only makes sense with persistent storage.
            load_if_exists=True if self.storage else False,
        )

        # Guarantee the predicted_r_weight == 0 baseline is evaluated first.
        _enqueue_predicted_r_weight_zero(study, search_space)

        study.optimize(
            lambda trial: self._objective(trial, df, search_space),
            n_trials=self.n_trials,
            n_jobs=self.n_jobs,
            timeout=self.timeout,
            show_progress_bar=self.show_progress_bar,
        )

        return OptunaResult(
            best_params=study.best_params,
            best_value=study.best_value,
            best_trial=study.best_trial,
            study=study,
        )

    def _objective(
        self, trial: optuna.Trial, df: IntoFrameT, search_space: dict[str, ParamSpec]
    ) -> float:
        """Score one trial; any failure yields the worst possible objective."""
        try:
            trial_params = self._suggest_params(trial, search_space)

            # Deep-copy so the tuner's own estimator/CV remain untouched
            # across trials (and across parallel workers when n_jobs > 1).
            copied_estimator = copy.deepcopy(self.estimator)
            copied_estimator.set_params(**trial_params)

            cv = copy.deepcopy(self.cross_validator)
            cv.estimator = copied_estimator

            validation_df = cv.generate_validation_df(df)
            score = self.scorer.score(validation_df)
            # Reuse the rating tuner's aggregation (mean over dict values).
            score_value = RatingHyperparameterTuner._aggregate_score(score)

            if math.isnan(score_value) or math.isinf(score_value):
                logger.warning(f"Trial {trial.number} returned invalid score: {score_value}")
                return float("inf") if self.direction == "minimize" else float("-inf")

            return score_value

        except Exception as e:
            # A single failing trial must not abort the whole study; report
            # the worst value so the sampler steers away from this region.
            logger.warning(f"Trial {trial.number} failed with error: {e}")
            return float("inf") if self.direction == "minimize" else float("-inf")

    def _suggest_params(
        self, trial: optuna.Trial, search_space: dict[str, ParamSpec]
    ) -> dict[str, Any]:
        """Draw a concrete value from *trial* for every spec in the space."""
        params: dict[str, Any] = {}
        for param_name, param_spec in search_space.items():
            params[param_name] = param_spec.suggest(trial, param_name)
        return params
spforge/scorer/_score.py CHANGED
@@ -1391,4 +1391,6 @@ class ThresholdEventScorer(BaseScorer):
1391
1391
  df, self.outcome_column, labels, self.naive_granularity
1392
1392
  )
1393
1393
  naive_score = self._score_with_probabilities(df, naive_list)
1394
+ if isinstance(score, dict) and isinstance(naive_score, dict):
1395
+ return {k: naive_score[k] - score[k] for k in score.keys()}
1394
1396
  return float(naive_score - score)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: spforge
3
- Version: 0.8.3
3
+ Version: 0.8.5
4
4
  Summary: A flexible framework for generating features, ratings, and building machine learning or other models for training and inference on sports data.
5
5
  Author-email: Mathias Holmstrøm <mathiasholmstom@gmail.com>
6
6
  License: See LICENSE file
@@ -17,7 +17,7 @@ Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
18
  Requires-Dist: numpy>=1.23.4
19
19
  Requires-Dist: optuna>=3.4.0
20
- Requires-Dist: pandas>=2.0.0
20
+ Requires-Dist: pandas<3.0.0,>=2.0.0
21
21
  Requires-Dist: pendulum>=1.0.0
22
22
  Requires-Dist: scikit-learn>=1.4.0
23
23
  Requires-Dist: lightgbm>=4.0.0
@@ -13,7 +13,7 @@ examples/nba/predictor_transformers_example.py,sha256=mPXRVPx4J5VZtxYH89k7pwh7_E
13
13
  examples/nba/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  examples/nba/data/game_player_subsample.parquet,sha256=ODJxHC-mUYbJ7r-ScUFtPU7hrFuxLUbbDSobmpCkw0w,279161
15
15
  examples/nba/data/utils.py,sha256=41hxLQ1d6ZgBEcHa5MI0-fG5KbsRi07cclMPQZM95ek,509
16
- spforge/__init__.py,sha256=5d9zzBxaaXj2JeBNwfUwuV7Ll5FERHyXONsFiuKhHSQ,402
16
+ spforge/__init__.py,sha256=8vZhy7XUpzqWkVKpXqwqOLDkQlNytRhyf4qjwObfXgU,468
17
17
  spforge/autopipeline.py,sha256=ZUwv6Q6O8cD0u5TiSqG6lhW0j16RlSb160AzuOeL2R8,23186
18
18
  spforge/base_feature_generator.py,sha256=RbD00N6oLCQQcEb_VF5wbwZztl-X8k9B0Wlaj9Os1iU,668
19
19
  spforge/data_structures.py,sha256=k82v5r79vl0_FAVvsxVF9Nbzb5FoHqVrlHZlEXGc5gQ,7298
@@ -43,9 +43,9 @@ spforge/feature_generator/_rolling_mean_binary.py,sha256=lmODy-o9Dd9pb8IlA7g4UyA
43
43
  spforge/feature_generator/_rolling_mean_days.py,sha256=EZQmFmYVQB-JjZV5k8bOWnaTxNpPDCZAjdfdhiiG4r4,8415
44
44
  spforge/feature_generator/_rolling_window.py,sha256=HT8LezsRIPNAlMEoP9oTPW2bKFu55ZSRnQZGST7fncw,8836
45
45
  spforge/feature_generator/_utils.py,sha256=KDn33ia1OYJTK8THFpvc_uRiH_Bl3fImGqqbfzs0YA4,9654
46
- spforge/hyperparameter_tuning/__init__.py,sha256=pp7aWzydObRawFLcGiaUrUduEQIjln2uif9nKCTk6l4,509
47
- spforge/hyperparameter_tuning/_default_search_spaces.py,sha256=19sHW8zlyG88xZdyqSrp9gFI5oLb-f6THlbhYAtTfmY,3534
48
- spforge/hyperparameter_tuning/_tuner.py,sha256=S70IEmHxl36LaUPl_wc_2mo46qUuH8t0eH0aXuCuGfA,9586
46
+ spforge/hyperparameter_tuning/__init__.py,sha256=N2sKG4SvG41hlsFT2kx_DQYMmXsQr-8031Tu_rxlxyY,1015
47
+ spforge/hyperparameter_tuning/_default_search_spaces.py,sha256=entdE7gtj8JM5C47-lLd93CoEsXjw8YfcWeWS8d0AZk,6882
48
+ spforge/hyperparameter_tuning/_tuner.py,sha256=uovhGqhe8-fdhi79aErUmE2h5NCycFQEIRv5WCjpC7E,16732
49
49
  spforge/performance_transformers/__init__.py,sha256=U6d7_kltbUMLYCGBk4QAFVPJTxXD3etD9qUftV-O3q4,422
50
50
  spforge/performance_transformers/_performance_manager.py,sha256=KwAga6dGhNkXi-MDW6LPjwk6VZwCcjo5L--jnk9aio8,9706
51
51
  spforge/performance_transformers/_performances_transformers.py,sha256=0lxuWjAfWBRXRgQsNJHjw3P-nlTtHBu4_bOVdoy7hq4,15536
@@ -61,7 +61,7 @@ spforge/ratings/team_performance_predictor.py,sha256=ThQOmYQUqKBB46ONYHOMM2arXFH
61
61
  spforge/ratings/team_start_rating_generator.py,sha256=ZJe84sTvE4Yep3d4wKJMMJn2Q4PhcCwkO7Wyd5nsYUA,5110
62
62
  spforge/ratings/utils.py,sha256=qms5J5SD-FyXDR2G8giDMbu_AoLgI135pjW4nghxROg,3940
63
63
  spforge/scorer/__init__.py,sha256=wj8PCvYIl6742Xwmt86c3oy6iqE8Ss-OpwHud6kd9IY,256
64
- spforge/scorer/_score.py,sha256=f_0SiBYdlxbjuK6frnCf8fUJ7Tbi7XL1Rx1_1khHfNg,56042
64
+ spforge/scorer/_score.py,sha256=TR0T9nJj0aeVgGfOE0fZmXlO66CELulYwxhi7ZAxhvY,56184
65
65
  spforge/transformers/__init__.py,sha256=IPCsMcsgBqG52d0ttATLCY4HvFCQZddExlLt74U-zuI,390
66
66
  spforge/transformers/_base.py,sha256=-smr_McQF9bYxM5-Agx6h7Xv_fhZzPfpAdQV-qK18bs,1134
67
67
  spforge/transformers/_net_over_predicted.py,sha256=5dC8pvA1DNO0yXPSgJSMGU8zAHi-maUELm7FqFQVo-U,2321
@@ -70,12 +70,13 @@ spforge/transformers/_other_transformer.py,sha256=xLfaFIhkFsigAoitB4x3F8An2j9ymd
70
70
  spforge/transformers/_predictor.py,sha256=2sE6gfVrilXzPVcBurSrtqHw33v2ljygQcEYXt9LhZc,3119
71
71
  spforge/transformers/_simple_transformer.py,sha256=zGUFNQYMeoDSa2CoQejQNiNmKCBN5amWTvyOchiUHj0,5660
72
72
  spforge/transformers/_team_ratio_predictor.py,sha256=g8_bR53Yyv0iNCtol1O9bgJSeZcIco_AfbQuUxQJkeY,6884
73
- spforge-0.8.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
73
+ spforge-0.8.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
74
  tests/test_autopipeline.py,sha256=WXHeqBdjQD6xaXVkzvS8ocz0WVP9R7lN0PiHJ2iD8nA,16911
75
75
  tests/test_autopipeline_context.py,sha256=IuRUY4IA6uMObvbl2pXSaXO2_tl3qX6wEbTZY0dkTMI,1240
76
76
  tests/test_feature_generator_pipeline.py,sha256=CAgBknWqawqYi5_hxcPmpxrLVa5elMHVv1VrSVRKXEA,17705
77
77
  tests/cross_validator/test_cross_validator.py,sha256=itCGhNY8-NbDbKbhxHW20wiLuRst7-Rixpmi3FSKQtA,17474
78
78
  tests/distributions/test_distribution.py,sha256=aU8hfCgliM80TES4WGjs9KFXpV8XghBGF7Hu9sqEVSE,10982
79
+ tests/end_to_end/test_estimator_hyperparameter_tuning.py,sha256=fZCJ9rrED2vT68B9ovmVA1cIG2pHRTjy9xzZLxxpEBo,2513
79
80
  tests/end_to_end/test_lol_player_kills.py,sha256=RJSYUbPrZ-RzSxGggj03yN0JKYeTB1JghVGYFMYia3Y,11891
80
81
  tests/end_to_end/test_nba_player_points.py,sha256=kyzjo7QIcvpteps29Wix6IS_eJG9d1gHLeWtIHpkWMs,9066
81
82
  tests/end_to_end/test_nba_player_ratings_hyperparameter_tuning.py,sha256=eOsTSVWv16bc0l_nCxH4x8jF-gsmn4Ttfv92mHqSXzc,6303
@@ -87,13 +88,14 @@ tests/feature_generator/test_rolling_against_opponent.py,sha256=20kH1INrWy6DV7AS
87
88
  tests/feature_generator/test_rolling_mean_binary.py,sha256=KuIavJ37Pt8icAb50B23lxdWEPVSHQ7NZHisD1BDpmU,16216
88
89
  tests/feature_generator/test_rolling_mean_days.py,sha256=EyOvdJDnmgPfe13uQBOkwo7fAteBQx-tnyuGM4ng2T8,18884
89
90
  tests/feature_generator/test_rolling_window.py,sha256=YBJo36OK3ILYeXrH06ylXqviUcCaGYaVQaK5RJzwM7Y,23239
91
+ tests/hyperparameter_tuning/test_estimator_tuner.py,sha256=iewME41d6LR2aQ0OtohGFtN_ocJUwTeqvs6L0QDmfG4,4413
90
92
  tests/hyperparameter_tuning/test_rating_tuner.py,sha256=PyCFP3KPc4Iy9E_X9stCVxra14uMgC1tuRwuQ30rO_o,13195
91
93
  tests/performance_transformers/test_performance_manager.py,sha256=bfC5GiBuzHw-mLmKeEzBUUPuKm0ayax2bsF1j88W8L0,10120
92
94
  tests/performance_transformers/test_performances_transformers.py,sha256=A-tGiCx7kXrj1cVj03Bc7prOeZ1_Ryz8YFx9uj3eK6w,11064
93
95
  tests/ratings/test_player_rating_generator.py,sha256=3mjqlX159QqOlBoY3r_TFkvLwpE4zlLE0fiqpbfk3ps,58547
94
96
  tests/ratings/test_ratings_property.py,sha256=ckyfGILXa4tfQvsgyXEzBDNr2DUmHwFRV13N60w66iE,6561
95
97
  tests/ratings/test_team_rating_generator.py,sha256=cDnf1zHiYC7pkgydE3MYr8wSTJIq-bPfSqhIRI_4Tic,95357
96
- tests/scorer/test_score.py,sha256=whsHBI0VGes_RGZXlcSRQz5h2aMtTDMzSJGyMeFm-H8,67864
98
+ tests/scorer/test_score.py,sha256=KTrGJypQEpU8tmgJ6LU8wK1SRC3PLUXFzZIyiA-UY7U,71749
97
99
  tests/scorer/test_score_aggregation_granularity.py,sha256=h-hyFOLzwp-92hYVU7CwvlRJ8jhB4DzXCtqgI-zcoqM,13677
98
100
  tests/transformers/test_estimator_transformer_context.py,sha256=5GOHbuWCWBMFwwOTJOuD4oNDsv-qDR0OxNZYGGuMdag,1819
99
101
  tests/transformers/test_net_over_predicted.py,sha256=vh7O1iRRPf4vcW9aLhOMAOyatfM5ZnLsQBKNAYsR3SU,3363
@@ -101,7 +103,7 @@ tests/transformers/test_other_transformer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
101
103
  tests/transformers/test_predictor_transformer.py,sha256=N1aBYLjN3ldpYZLwjih_gTFYSMitrZu-PNK78W6RHaQ,6877
102
104
  tests/transformers/test_simple_transformer.py,sha256=wWR0qjLb_uS4HXrJgGdiqugOY1X7kwd1_OPS02IT2b8,4676
103
105
  tests/transformers/test_team_ratio_predictor.py,sha256=fOUP_JvNJi-3kom3ZOs1EdG0I6Z8hpLpYKNHu1eWtOw,8562
104
- spforge-0.8.3.dist-info/METADATA,sha256=koQFZ1LxNPJVtmYcOLm1EZVRPUx-VyWETLA27kTGt2o,20219
105
- spforge-0.8.3.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
106
- spforge-0.8.3.dist-info/top_level.txt,sha256=6UW2M5a7WKOeaAi900qQmRKNj5-HZzE8-eUD9Y9LTq0,23
107
- spforge-0.8.3.dist-info/RECORD,,
106
+ spforge-0.8.5.dist-info/METADATA,sha256=bqArRdOKZYvSc47sa9cJsOhsDxh0q4T6GoF_xIBkjpA,20226
107
+ spforge-0.8.5.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
108
+ spforge-0.8.5.dist-info/top_level.txt,sha256=6UW2M5a7WKOeaAi900qQmRKNj5-HZzE8-eUD9Y9LTq0,23
109
+ spforge-0.8.5.dist-info/RECORD,,
@@ -0,0 +1,85 @@
1
+ import polars as pl
2
+ from sklearn.linear_model import LogisticRegression
3
+ from sklearn.metrics import mean_absolute_error
4
+
5
+ from examples import get_sub_sample_nba_data
6
+ from spforge import AutoPipeline, ColumnNames, EstimatorHyperparameterTuner, ParamSpec
7
+ from spforge.cross_validator import MatchKFoldCrossValidator
8
+ from spforge.scorer import SklearnScorer
9
+
10
+
11
def test_nba_estimator_hyperparameter_tuning__workflow_completes():
    """End-to-end smoke test: tuning an AutoPipeline on NBA sample data completes.

    Verifies the full workflow (feature prep -> CV -> scorer -> Optuna study)
    runs and that the bare "C" spec is namespaced onto the leaf estimator.
    """
    df = get_sub_sample_nba_data(as_polars=True, as_pandas=False)
    column_names = ColumnNames(
        team_id="team_id",
        match_id="game_id",
        start_date="start_date",
        player_id="player_id",
        participation_weight="minutes_ratio",
    )

    # Deterministic chronological ordering for the match-based CV splits.
    df = df.sort(
        [
            column_names.start_date,
            column_names.match_id,
            column_names.team_id,
            column_names.player_id,
        ]
    )

    # Derive the feature (per-game minutes share) and a binary target.
    df = df.with_columns(
        [
            (pl.col("minutes") / pl.col("minutes").sum().over("game_id")).alias(
                "minutes_ratio"
            ),
            (pl.col("points") > pl.lit(10)).cast(pl.Int64).alias("points_over_10"),
        ]
    )

    estimator = AutoPipeline(
        estimator=LogisticRegression(max_iter=200),
        estimator_features=["minutes", "minutes_ratio"],
    )

    cv = MatchKFoldCrossValidator(
        match_id_column_name=column_names.match_id,
        date_column_name=column_names.start_date,
        target_column="points_over_10",
        estimator=estimator,
        prediction_column_name="points_pred",
        n_splits=2,
        features=estimator.required_features,
    )

    scorer = SklearnScorer(
        scorer_function=mean_absolute_error,
        pred_column="points_pred",
        target="points_over_10",
        validation_column="is_validation",
    )

    tuner = EstimatorHyperparameterTuner(
        estimator=estimator,
        cross_validator=cv,
        scorer=scorer,
        direction="minimize",
        param_search_space={
            "C": ParamSpec(
                param_type="float",
                low=0.1,
                high=2.0,
                log=True,
            ),
        },
        # Keep the trial count tiny: this is a workflow test, not a tuning run.
        n_trials=3,
        show_progress_bar=False,
    )

    result = tuner.optimize(df)

    # "C" targets the AutoPipeline's nested estimator, hence "estimator__C".
    assert result.best_params is not None
    assert isinstance(result.best_params, dict)
    assert "estimator__C" in result.best_params
    assert isinstance(result.best_value, float)
    assert result.best_trial is not None
    assert result.study is not None
@@ -0,0 +1,167 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import pytest
4
+ from sklearn.base import BaseEstimator
5
+ from sklearn.linear_model import LogisticRegression
6
+
7
+ from spforge import EstimatorHyperparameterTuner, ParamSpec
8
+ from spforge.cross_validator import MatchKFoldCrossValidator
9
+ from spforge.estimator import SkLearnEnhancerEstimator
10
+ from spforge.scorer import MeanBiasScorer
11
+
12
+
13
class FakeLGBMClassifier(BaseEstimator):
    """Stand-in for LightGBM's classifier so tests run without lightgbm.

    The spoofed ``__module__`` makes the tuner's duck-typed LightGBM
    detection treat it as a real LGBM model; predictions are constants so
    scoring is fast and deterministic.
    """

    # Spoof the module path the tuner inspects for LightGBM detection.
    __module__ = "lightgbm.sklearn"

    def __init__(
        self,
        n_estimators: int = 100,
        num_leaves: int = 31,
        max_depth: int = 5,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 1,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
    ):
        # Mirror the real LGBM hyperparameters so default search spaces and
        # sklearn's get_params/set_params clone semantics both work.
        self.n_estimators = n_estimators
        self.num_leaves = num_leaves
        self.max_depth = max_depth
        self.min_child_samples = min_child_samples
        self.subsample = subsample
        self.subsample_freq = subsample_freq
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda

    def fit(self, X, y):
        """Record the class labels; no actual learning happens."""
        self.classes_ = np.unique(y)
        return self

    def predict_proba(self, X):
        """Return fixed [0.4, 0.6] probabilities per row (single-class: all ones)."""
        n = len(X)
        if len(self.classes_) < 2:
            return np.ones((n, 1))
        return np.tile([0.4, 0.6], (n, 1))

    def predict(self, X):
        """Return the argmax class of the constant probabilities."""
        n = len(X)
        if len(self.classes_) == 1:
            return np.full(n, self.classes_[0])
        proba = self.predict_proba(X)
        idx = np.argmax(proba, axis=1)
        return np.array(self.classes_)[idx]
53
+
54
+
55
@pytest.fixture
def sample_df():
    """Twelve daily rows, two rows per match, with an alternating binary target."""
    dates = pd.date_range("2024-01-01", periods=12, freq="D")
    rows = []
    for i, date in enumerate(dates):
        rows.append(
            {
                # Two consecutive rows share one match id (M0, M1, ...).
                "mid": f"M{i // 2}",
                "date": date,
                "x1": float(i),
                "y": 1 if i % 2 == 0 else 0,
            }
        )
    return pd.DataFrame(rows)
69
+
70
+
71
@pytest.fixture
def scorer():
    """Bias scorer wired to the prediction/target/validation columns used in tests."""
    return MeanBiasScorer(
        pred_column="y_pred",
        target="y",
        validation_column="is_validation",
    )
78
+
79
+
80
def test_estimator_tuner_requires_search_space(sample_df, scorer):
    """With no user space and no known defaults, optimize() must raise.

    LogisticRegression has no registered default search space, so the tuner
    cannot resolve anything to optimize.
    """
    estimator = LogisticRegression()

    cv = MatchKFoldCrossValidator(
        match_id_column_name="mid",
        date_column_name="date",
        target_column="y",
        estimator=estimator,
        prediction_column_name="y_pred",
        n_splits=2,
        features=["x1"],
    )

    tuner = EstimatorHyperparameterTuner(
        estimator=estimator,
        cross_validator=cv,
        scorer=scorer,
        direction="minimize",
        n_trials=2,
        show_progress_bar=False,
    )

    with pytest.raises(ValueError, match="param_search_space is required"):
        tuner.optimize(sample_df)
104
+
105
+
106
def test_estimator_tuner_custom_search_space(sample_df, scorer):
    """A user-supplied space is applied to the leaf estimator of a wrapper.

    The bare "C" spec should be namespaced as "estimator__C" because the
    LogisticRegression is nested inside SkLearnEnhancerEstimator.
    """
    estimator = SkLearnEnhancerEstimator(estimator=LogisticRegression())

    cv = MatchKFoldCrossValidator(
        match_id_column_name="mid",
        date_column_name="date",
        target_column="y",
        estimator=estimator,
        prediction_column_name="y_pred",
        n_splits=2,
        features=["x1"],
    )

    tuner = EstimatorHyperparameterTuner(
        estimator=estimator,
        cross_validator=cv,
        scorer=scorer,
        direction="minimize",
        param_search_space={
            "C": ParamSpec(
                param_type="float",
                low=0.1,
                high=2.0,
                log=True,
            )
        },
        n_trials=2,
        show_progress_bar=False,
    )

    result = tuner.optimize(sample_df)

    assert "estimator__C" in result.best_params
    assert isinstance(result.best_value, float)
141
+
142
def test_estimator_tuner_lgbm_defaults(sample_df, scorer):
    """An LGBM-looking estimator gets the default LightGBM search space.

    FakeLGBMClassifier spoofs lightgbm's module path, so no explicit
    param_search_space is needed; default LGBM params must be tuned.
    """
    estimator = FakeLGBMClassifier()

    cv = MatchKFoldCrossValidator(
        match_id_column_name="mid",
        date_column_name="date",
        target_column="y",
        estimator=estimator,
        prediction_column_name="y_pred",
        n_splits=2,
        features=["x1"],
    )

    tuner = EstimatorHyperparameterTuner(
        estimator=estimator,
        cross_validator=cv,
        scorer=scorer,
        direction="minimize",
        n_trials=2,
        show_progress_bar=False,
    )

    result = tuner.optimize(sample_df)

    # "n_estimators" comes from the default LGBM space; the estimator is the
    # root (no nesting), so the name is not path-prefixed.
    assert "n_estimators" in result.best_params
    assert isinstance(result.best_value, float)
@@ -1892,6 +1892,129 @@ def test_pwmse__accepts_ndarray_predictions(df_type):
1892
1892
  assert score >= 0
1893
1893
 
1894
1894
 
1895
+ # ============================================================================
1896
+ # ThresholdEventScorer with granularity and compare_to_naive Tests
1897
+ # ============================================================================
1898
+
1899
+
1900
@pytest.mark.parametrize("df_type", [pl.DataFrame, pd.DataFrame])
def test_threshold_event_scorer__granularity_with_compare_to_naive(df_type):
    """Regression test: compare_to_naive combined with granularity used to fail.

    Bug: When granularity is set, binary_scorer.score() returns a dict, but
    the naive comparison tries to do dict - dict which fails with:
    'unsupported operand type(s) for -: 'dict' and 'dict''
    """
    # Six plays across two quarters; "dist" is a per-row probability
    # distribution over the outcome labels [0, 1, 2, 3].
    df = create_dataframe(
        df_type,
        {
            "qtr": [1, 1, 1, 2, 2, 2],
            "dist": [
                [0.1, 0.2, 0.3, 0.4],
                [0.2, 0.3, 0.3, 0.2],
                [0.3, 0.4, 0.2, 0.1],
                [0.4, 0.3, 0.2, 0.1],
                [0.1, 0.1, 0.4, 0.4],
                [0.2, 0.2, 0.3, 0.3],
            ],
            "ydstogo": [2.0, 3.0, 1.0, 2.0, 1.0, 3.0],
            "rush_yards": [3, 2, 0, 1, 2, 4],
        },
    )

    scorer = ThresholdEventScorer(
        dist_column="dist",
        threshold_column="ydstogo",
        outcome_column="rush_yards",
        labels=[0, 1, 2, 3],
        compare_to_naive=True,
        granularity=["qtr"],
    )

    result = scorer.score(df)

    # One naive-vs-model score per quarter, keyed by the granularity tuple.
    assert isinstance(result, dict)
    assert len(result) == 2
    assert (1,) in result
    assert (2,) in result
    assert all(isinstance(v, float) for v in result.values())
1941
+
1942
+
1943
@pytest.mark.parametrize("df_type", [pl.DataFrame, pd.DataFrame])
def test_threshold_event_scorer__granularity_with_compare_to_naive_and_naive_granularity(df_type):
    """ThresholdEventScorer with both granularity and naive_granularity.

    The naive baseline is computed per team while the reported scores are
    grouped per quarter; the two groupings must compose without error.
    """
    df = create_dataframe(
        df_type,
        {
            "qtr": [1, 1, 1, 2, 2, 2],
            "team": ["A", "A", "B", "A", "B", "B"],
            "dist": [
                [0.1, 0.2, 0.3, 0.4],
                [0.2, 0.3, 0.3, 0.2],
                [0.3, 0.4, 0.2, 0.1],
                [0.4, 0.3, 0.2, 0.1],
                [0.1, 0.1, 0.4, 0.4],
                [0.2, 0.2, 0.3, 0.3],
            ],
            "ydstogo": [2.0, 3.0, 1.0, 2.0, 1.0, 3.0],
            "rush_yards": [3, 2, 0, 1, 2, 4],
        },
    )

    scorer = ThresholdEventScorer(
        dist_column="dist",
        threshold_column="ydstogo",
        outcome_column="rush_yards",
        labels=[0, 1, 2, 3],
        compare_to_naive=True,
        naive_granularity=["team"],
        granularity=["qtr"],
    )

    result = scorer.score(df)

    # Output keys follow the reporting granularity (qtr), not naive_granularity.
    assert isinstance(result, dict)
    assert len(result) == 2
    assert (1,) in result
    assert (2,) in result
    assert all(isinstance(v, float) for v in result.values())
1981
+
1982
+
1983
@pytest.mark.parametrize("df_type", [pl.DataFrame, pd.DataFrame])
def test_threshold_event_scorer__multi_column_granularity_with_compare_to_naive(df_type):
    """ThresholdEventScorer with multi-column granularity and compare_to_naive.

    Grouping on (qtr, half) yields two distinct groups here since the columns
    move together; keys are 2-tuples.
    """
    df = create_dataframe(
        df_type,
        {
            "qtr": [1, 1, 2, 2],
            "half": [1, 1, 2, 2],
            "dist": [
                [0.1, 0.2, 0.3, 0.4],
                [0.2, 0.3, 0.3, 0.2],
                [0.4, 0.3, 0.2, 0.1],
                [0.1, 0.1, 0.4, 0.4],
            ],
            "ydstogo": [2.0, 3.0, 2.0, 1.0],
            "rush_yards": [3, 2, 1, 2],
        },
    )

    scorer = ThresholdEventScorer(
        dist_column="dist",
        threshold_column="ydstogo",
        outcome_column="rush_yards",
        labels=[0, 1, 2, 3],
        compare_to_naive=True,
        granularity=["qtr", "half"],
    )

    result = scorer.score(df)

    assert isinstance(result, dict)
    assert len(result) == 2
    assert all(isinstance(v, float) for v in result.values())
2016
+
2017
+
1895
2018
  @pytest.mark.parametrize("df_type", [pl.DataFrame, pd.DataFrame])
1896
2019
  def test_all_scorers_handle_all_nan_targets(df_type):
1897
2020
  """All scorers handle case where all targets are NaN"""