spforge 0.8.4__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/lol/pipeline_transformer_example.py +69 -86
- examples/nba/cross_validation_example.py +4 -11
- examples/nba/feature_engineering_example.py +33 -15
- examples/nba/game_winner_example.py +24 -14
- examples/nba/predictor_transformers_example.py +29 -16
- spforge/__init__.py +1 -0
- spforge/hyperparameter_tuning/__init__.py +12 -0
- spforge/hyperparameter_tuning/_default_search_spaces.py +159 -1
- spforge/hyperparameter_tuning/_tuner.py +192 -0
- spforge/ratings/__init__.py +4 -0
- spforge/ratings/_player_rating.py +11 -0
- spforge/ratings/league_start_rating_optimizer.py +201 -0
- {spforge-0.8.4.dist-info → spforge-0.8.7.dist-info}/METADATA +12 -19
- {spforge-0.8.4.dist-info → spforge-0.8.7.dist-info}/RECORD +23 -19
- tests/end_to_end/test_estimator_hyperparameter_tuning.py +85 -0
- tests/end_to_end/test_league_start_rating_optimizer.py +117 -0
- tests/end_to_end/test_nba_player_ratings_hyperparameter_tuning.py +5 -0
- tests/hyperparameter_tuning/test_estimator_tuner.py +167 -0
- tests/ratings/test_player_rating_generator.py +27 -0
- tests/scorer/test_score.py +90 -0
- {spforge-0.8.4.dist-info → spforge-0.8.7.dist-info}/WHEEL +0 -0
- {spforge-0.8.4.dist-info → spforge-0.8.7.dist-info}/licenses/LICENSE +0 -0
- {spforge-0.8.4.dist-info → spforge-0.8.7.dist-info}/top_level.txt +0 -0
tests/scorer/test_score.py
CHANGED
|
@@ -2048,3 +2048,93 @@ def test_all_scorers_handle_all_nan_targets(df_type):
|
|
|
2048
2048
|
assert np.isnan(score) or score == 0.0
|
|
2049
2049
|
except (ValueError, IndexError):
|
|
2050
2050
|
pass
|
|
2051
|
+
# Parametrized cases for ``test_scorers_respect_validation_column``.
#
# Each case is a pair of factories: one builds a scorer configured with
# ``validation_column="is_validation"``, the other builds a frame whose
# ``is_validation`` column marks some rows 1 and at least one row 0.
# The non-validation rows are meant to carry values that change the score when
# included. This ensures a scorer that silently ignores ``validation_column``
# would fail the comparison instead of passing vacuously.
SCORER_VALIDATION_CASES = [
    pytest.param(
        lambda: MeanBiasScorer(pred_column="pred", target="target", validation_column="is_validation"),
        lambda: pd.DataFrame(
            {
                "pred": [2.0, 0.0],
                "target": [1.0, 2.0],
                "is_validation": [1, 0],
            }
        ),
        id="mean_bias",
    ),
    pytest.param(
        lambda: PWMSE(pred_column="pred", target="target", labels=[0, 1], validation_column="is_validation"),
        lambda: pd.DataFrame(
            {
                "pred": [[0.7, 0.3], [0.4, 0.6]],
                "target": [0, 1],
                "is_validation": [1, 0],
            }
        ),
        id="pwmse",
    ),
    pytest.param(
        lambda: SklearnScorer(
            scorer_function=mean_absolute_error, pred_column="pred", target="target", validation_column="is_validation"
        ),
        lambda: pd.DataFrame(
            {
                "pred": [1.0, 0.0],
                # The non-validation row must have a non-zero error. With
                # target 0.0 here, MAE was 0.0 both with and without
                # filtering, so the case could not detect an ignored
                # validation_column. Now validation-only MAE is 0.0 and
                # unfiltered MAE is 1.0.
                "target": [1.0, 2.0],
                "is_validation": [1, 0],
            }
        ),
        id="sklearn",
    ),
    pytest.param(
        lambda: ProbabilisticMeanBias(
            pred_column="pred", target="target", class_column_name="classes", validation_column="is_validation"
        ),
        lambda: pd.DataFrame(
            {
                "pred": [[0.2, 0.8], [0.6, 0.4]],
                "target": [1, 0],
                "classes": [[0, 1], [0, 1]],
                "is_validation": [1, 0],
            }
        ),
        id="probabilistic_mean_bias",
    ),
    pytest.param(
        lambda: OrdinalLossScorer(pred_column="pred", target="target", classes=[0, 1], validation_column="is_validation"),
        lambda: pd.DataFrame(
            {
                "pred": [[0.2, 0.8], [0.6, 0.4]],
                "target": [1, 0],
                "is_validation": [1, 0],
            }
        ),
        id="ordinal_loss",
    ),
    pytest.param(
        lambda: ThresholdEventScorer(
            dist_column="dist",
            threshold_column="threshold",
            outcome_column="outcome",
            comparator=Operator.GREATER_THAN_OR_EQUALS,
            validation_column="is_validation",
        ),
        lambda: pd.DataFrame(
            {
                "dist": [[0.2, 0.8], [0.6, 0.4], [0.3, 0.7]],
                "threshold": [0.5, 0.2, 0.3],
                "outcome": [1, 0, 1],
                "is_validation": [1, 1, 0],
            }
        ),
        id="threshold_event",
    ),
]
|
|
2131
|
+
|
|
2132
|
+
|
|
2133
|
+
@pytest.mark.parametrize("scorer_factory, df_factory", SCORER_VALIDATION_CASES)
def test_scorers_respect_validation_column(scorer_factory, df_factory):
    """Scoring the whole frame must match scoring only its validation rows.

    Each scorer is built with a ``validation_column``. If the scorer honours
    that column, it is expected to drop rows not flagged ``1`` internally. Its
    result on the full frame should then equal its result on a frame that was
    explicitly pre-filtered to ``is_validation == 1``.
    """
    frame = df_factory()
    validation_mask = frame["is_validation"] == 1
    validation_only = frame[validation_mask]

    # Use a fresh scorer for each call so no state carries over between runs.
    full_frame_score = scorer_factory().score(frame)
    filtered_frame_score = scorer_factory().score(validation_only)

    assert full_frame_score == filtered_frame_score
|
|
File without changes
|
|
File without changes
|
|
File without changes
|