skweights 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skweights-0.1.0/PKG-INFO +17 -0
- skweights-0.1.0/README.md +0 -0
- skweights-0.1.0/pyproject.toml +29 -0
- skweights-0.1.0/setup.cfg +4 -0
- skweights-0.1.0/skweights/__init__.py +4 -0
- skweights-0.1.0/skweights/weighter.py +55 -0
- skweights-0.1.0/skweights/wrapper.py +117 -0
- skweights-0.1.0/skweights.egg-info/PKG-INFO +17 -0
- skweights-0.1.0/skweights.egg-info/SOURCES.txt +12 -0
- skweights-0.1.0/skweights.egg-info/dependency_links.txt +1 -0
- skweights-0.1.0/skweights.egg-info/requires.txt +3 -0
- skweights-0.1.0/skweights.egg-info/top_level.txt +1 -0
- skweights-0.1.0/tests/test_weighter.py +94 -0
- skweights-0.1.0/tests/test_wrapper.py +74 -0
skweights-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: skweights
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Scikit-learn compatible meta-estimators for heuristic business rules and feature weighting.
|
|
5
|
+
Author-email: Aron Kipkurui <aronidengeno@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/wizard-hash2/skweights
|
|
7
|
+
Project-URL: Bug Tracker, https://github.com/wizard-hash2/skweights/issues
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Requires-Python: >=3.8
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: scikit-learn>=1.0.0
|
|
16
|
+
Requires-Dist: pandas>=1.0.0
|
|
17
|
+
Requires-Dist: numpy>=1.20.0
|
|
File without changes
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "skweights"  # distribution name used when installing the package
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
authors = [
|
|
9
|
+
{ name="Aron Kipkurui", email="aronidengeno@gmail.com" },
|
|
10
|
+
]
|
|
11
|
+
description = "Scikit-learn compatible meta-estimators for heuristic business rules and feature weighting."
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
requires-python = ">=3.8"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"License :: OSI Approved :: MIT License",
|
|
17
|
+
"Operating System :: OS Independent",
|
|
18
|
+
"Intended Audience :: Developers",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
|
20
|
+
]
|
|
21
|
+
dependencies = [
|
|
22
|
+
"scikit-learn>=1.0.0",
|
|
23
|
+
"pandas>=1.0.0",
|
|
24
|
+
"numpy>=1.20.0"
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.urls]
|
|
28
|
+
"Homepage" = "https://github.com/wizard-hash2/skweights"
|
|
29
|
+
"Bug Tracker" = "https://github.com/wizard-hash2/skweights/issues"
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
4
|
+
from sklearn.utils.validation import check_is_fitted
|
|
5
|
+
|
|
6
|
+
class FeatureWeighter(BaseEstimator, TransformerMixin):
    """
    A transformer that applies a priori scalar weights to specific features.

    Parameters
    ----------
    weights : dict, default=None
        A dictionary mapping column names (for DataFrames) or
        column indices (for NumPy arrays) to their scalar weights.
        ``None`` or an empty dict makes ``transform`` return its input as is.

    Notes
    -----
    Keys that do not match the input are skipped without an error: unknown
    column names for DataFrames; non-integer or out-of-range (including
    negative) indices for NumPy arrays.
    """

    def __init__(self, weights=None):
        self.weights = weights

    def fit(self, X, y=None):
        """
        Stateless fit method. Marks the transformer as fitted and returns self.

        Parameters
        ----------
        X : any
            Ignored; nothing is learned from the data.
        y : any, default=None
            Ignored; accepted so the transformer can sit inside a Pipeline.

        Returns
        -------
        self : FeatureWeighter
        """
        # The attribute exists only so that check_is_fitted() in transform()
        # can tell that fit() has been called.
        self._is_fitted_ = True
        return self

    def transform(self, X):
        """
        Applies the scalar weights to the defined features.

        Parameters
        ----------
        X : pandas.DataFrame or numpy.ndarray
            Data whose selected columns are multiplied by their weights.

        Returns
        -------
        X_transformed : same container type as ``X``
            A weighted copy of ``X``. When ``weights`` is ``None`` or empty,
            ``X`` itself is returned (no copy, no type check).

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called.
        TypeError
            If weights are set and ``X`` is neither a DataFrame nor an ndarray.
        """
        # 1. State validation
        check_is_fitted(self, '_is_fitted_')

        # 2. Bypass logic: nothing to apply, hand the input straight back.
        if not self.weights:
            return X

        # 3. Type routing. The type is checked before copying so that
        #    unsupported inputs always fail with the documented TypeError
        #    rather than with whatever their missing/odd .copy() raises.
        if isinstance(X, pd.DataFrame):
            # Work on a copy so the caller's frame is never mutated.
            X_transformed = X.copy()
            for col, weight in self.weights.items():
                if col in X_transformed.columns:
                    X_transformed[col] = X_transformed[col] * weight
            return X_transformed

        if isinstance(X, np.ndarray):
            # Keep only keys that address an existing column. NumPy integer
            # scalars are accepted alongside plain ints.
            applicable = {
                col_idx: weight
                for col_idx, weight in self.weights.items()
                if isinstance(col_idx, (int, np.integer))
                and 0 <= col_idx < X.shape[1]
            }

            X_transformed = X.copy()

            # Assigning a float product into an integer/bool array would cast
            # it back and silently truncate (e.g. 3 * 0.5 -> 1). Promote such
            # arrays to a dtype able to hold every weighted result first.
            # Floating-point inputs keep their dtype.
            if applicable and X_transformed.dtype.kind in "biu":
                target_dtype = np.result_type(
                    X_transformed.dtype,
                    *(np.asarray(weight).dtype for weight in applicable.values()),
                )
                X_transformed = X_transformed.astype(target_dtype, copy=False)

            for col_idx, weight in applicable.items():
                X_transformed[:, col_idx] = X_transformed[:, col_idx] * weight
            return X_transformed

        raise TypeError("Input must be a Pandas DataFrame or a NumPy array.")
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
|
|
5
|
+
from sklearn.utils.validation import check_is_fitted
|
|
6
|
+
|
|
7
|
+
# Safe mapping to avoid using eval() in production
|
|
8
|
+
# Lookup table from the comparison symbol written in a rule's 'operator'
# field to the function from the `operator` module that implements it.
OPERATOR_MAP = {
    # equality
    "==": operator.eq,
    "!=": operator.ne,
    # strict ordering
    ">": operator.gt,
    "<": operator.lt,
    # inclusive ordering
    ">=": operator.ge,
    "<=": operator.le,
}
|
|
16
|
+
|
|
17
|
+
class RuleConstraintWrapper(BaseEstimator, MetaEstimatorMixin):
    """
    A meta-estimator that evaluates deterministic business rules before
    delegating predictions to an underlying supervised machine learning model.

    Parameters
    ----------
    estimator : estimator object
        The base scikit-learn estimator (e.g., LogisticRegression).
    rules : list of dict, default=None
        A cascade of rules. Format:
        [{'column': 'age', 'operator': '<', 'value': 18, 'outcome': 0}]
        Rules are evaluated in order and the first rule matching a row decides
        that row. ``None`` or an empty list means "no rules".

    Attributes
    ----------
    estimator_ : estimator object
        Clone of ``estimator`` fitted on the rows no rule matched.

    Notes
    -----
    A rule whose ``'column'`` is not present in the input (unknown DataFrame
    column; non-integer or too-large index for arrays) is skipped silently.
    An ``'operator'`` that is not a key of ``OPERATOR_MAP`` raises ``KeyError``.
    """

    def __init__(self, estimator, rules=None):
        self.estimator = estimator
        # Stored exactly as given: scikit-learn's get_params()/clone() contract
        # requires constructor arguments to be kept unmodified. ``None`` is
        # interpreted as "no rules" where the rules are used.
        self.rules = rules

    def _apply_rules(self, X):
        """
        Internal method to evaluate the rule cascade using vectorized masking.

        Parameters
        ----------
        X : pandas.DataFrame or array supporting ``X[:, col]`` indexing

        Returns
        -------
        handled_mask : ndarray of bool, shape (n_samples,)
            True for rows matched by at least one rule.
        rule_predictions : ndarray of object, shape (n_samples,)
            Outcome of the first matching rule for handled rows; entries of
            unhandled rows are left unset and must not be read.
        """
        n_samples = X.shape[0]
        handled_mask = np.zeros(n_samples, dtype=bool)
        rule_predictions = np.empty(n_samples, dtype=object)

        # ``or ()`` covers both rules=None and an empty list.
        for rule in self.rules or ():
            col = rule['column']
            op_func = OPERATOR_MAP[rule['operator']]
            val = rule['value']
            outcome = rule['outcome']

            # Extract the column data; rules addressing a missing column are
            # skipped rather than treated as an error.
            if isinstance(X, pd.DataFrame):
                if col not in X.columns:
                    continue
                col_data = X[col].values
            else:
                # Positional index for array input (plain or NumPy integer).
                if not isinstance(col, (int, np.integer)) or col >= X.shape[1]:
                    continue
                col_data = X[:, col]

            # Vectorized comparison over every row of the column.
            condition_mask = op_func(col_data, val)

            # Restrict to rows not already claimed by an earlier rule, so the
            # first matching rule wins.
            active_mask = condition_mask & ~handled_mask

            rule_predictions[active_mask] = outcome
            handled_mask |= active_mask

        return handled_mask, rule_predictions

    @staticmethod
    def _tighten_dtype(predictions):
        """
        Return ``predictions`` as a native-dtype array when that is lossless.

        The object-dtype buffer used to merge rule outcomes with model output
        is converted to a numeric/bool array, or to a string array when every
        value is a ``str``. Anything else (mixed types, ``None``, nested
        values, empty input) is returned unchanged as an object array.
        """
        values = predictions.tolist()
        if not values:
            return predictions
        try:
            compact = np.array(values)
        except (TypeError, ValueError):
            return predictions
        if compact.shape != predictions.shape:
            return predictions
        if compact.dtype.kind in "biuf":
            return compact
        # NumPy would stringify numbers in a mixed list; only accept a string
        # array when every original value already was a string.
        if compact.dtype.kind == "U" and all(isinstance(v, str) for v in values):
            return compact
        return predictions

    def fit(self, X, y):
        """
        Filters out data that triggers deterministic rules, then fits the
        underlying estimator only on the remaining valid data.

        Parameters
        ----------
        X : pandas.DataFrame or numpy.ndarray
            Training features.
        y : array-like of shape (n_samples,)
            Training targets; converted to an ndarray before filtering.

        Returns
        -------
        self : RuleConstraintWrapper

        Raises
        ------
        ValueError
            If every training sample is matched by a rule.
        """
        self.estimator_ = clone(self.estimator)

        handled_mask, _ = self._apply_rules(X)

        # We only train the model on data that passes the rules
        passed_mask = ~handled_mask
        X_passed = X[passed_mask]
        y_passed = np.array(y)[passed_mask]

        if len(X_passed) == 0:
            raise ValueError("All training samples were filtered out by the business rules.")

        self.estimator_.fit(X_passed, y_passed)
        self._is_fitted_ = True

        return self

    def predict(self, X):
        """
        Predicts outcomes by first applying rules, then delegating the remainder.

        Parameters
        ----------
        X : pandas.DataFrame or numpy.ndarray

        Returns
        -------
        predictions : ndarray of shape (n_samples,)
            Rule outcomes for rows matched by a rule and the fitted estimator's
            predictions for all other rows. The array has a native dtype when
            all values share a numeric, boolean or string type, and object
            dtype otherwise.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called.
        """
        check_is_fitted(self, '_is_fitted_')

        n_samples = X.shape[0]
        # Object buffer so that rule outcomes and model output of any type can
        # be merged before the final dtype is decided.
        final_predictions = np.empty(n_samples, dtype=object)

        # 1. Evaluate rules (The Gatekeeper)
        handled_mask, rule_preds = self._apply_rules(X)
        final_predictions[handled_mask] = rule_preds[handled_mask]

        # 2. Delegate remainder to the underlying model
        unhandled_mask = ~handled_mask
        if np.any(unhandled_mask):
            # Pass only the unhandled rows to the trained estimator
            X_unhandled = X[unhandled_mask] if isinstance(X, pd.DataFrame) else X[unhandled_mask, :]
            model_preds = self.estimator_.predict(X_unhandled)
            final_predictions[unhandled_mask] = model_preds

        # 3. scikit-learn metrics treat non-string object arrays as an
        #    "unknown" target type, so hand back a native dtype when possible.
        return self._tighten_dtype(final_predictions)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: skweights
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Scikit-learn compatible meta-estimators for heuristic business rules and feature weighting.
|
|
5
|
+
Author-email: Aron Kipkurui <aronidengeno@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/wizard-hash2/skweights
|
|
7
|
+
Project-URL: Bug Tracker, https://github.com/wizard-hash2/skweights/issues
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Requires-Python: >=3.8
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: scikit-learn>=1.0.0
|
|
16
|
+
Requires-Dist: pandas>=1.0.0
|
|
17
|
+
Requires-Dist: numpy>=1.20.0
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
skweights/__init__.py
|
|
4
|
+
skweights/weighter.py
|
|
5
|
+
skweights/wrapper.py
|
|
6
|
+
skweights.egg-info/PKG-INFO
|
|
7
|
+
skweights.egg-info/SOURCES.txt
|
|
8
|
+
skweights.egg-info/dependency_links.txt
|
|
9
|
+
skweights.egg-info/requires.txt
|
|
10
|
+
skweights.egg-info/top_level.txt
|
|
11
|
+
tests/test_weighter.py
|
|
12
|
+
tests/test_wrapper.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
skweights
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from sklearn.pipeline import Pipeline
|
|
5
|
+
from sklearn.linear_model import LogisticRegression
|
|
6
|
+
from skweights.weighter import FeatureWeighter
|
|
7
|
+
|
|
8
|
+
# ---------------------------------------------------------
|
|
9
|
+
# Test 1: Mathematical Accuracy & Pandas Integration
|
|
10
|
+
# ---------------------------------------------------------
|
|
11
|
+
def test_pandas_dataframe_weighting():
    """Weighted DataFrame columns are scaled; unlisted columns stay as they were."""
    frame = pd.DataFrame({
        'years_experience': [1.0, 2.0, 3.0],
        'age': [25.0, 30.0, 35.0],
        'github_commits': [100.0, 200.0, 300.0]
    })

    # Double experience, halve commits; 'age' gets no weight at all.
    transformer = FeatureWeighter(
        weights={'years_experience': 2.0, 'github_commits': 0.5}
    )
    result = transformer.fit_transform(frame)

    assert result['years_experience'].tolist() == [2.0, 4.0, 6.0], "Failed to multiply correctly."
    assert result['github_commits'].tolist() == [50.0, 100.0, 150.0], "Failed to multiply fractional weight."
    assert result['age'].tolist() == [25.0, 30.0, 35.0], "Untouched column was altered."
|
|
30
|
+
|
|
31
|
+
# ---------------------------------------------------------
|
|
32
|
+
# Test 2: NumPy Array Resilience
|
|
33
|
+
# ---------------------------------------------------------
|
|
34
|
+
def test_numpy_array_weighting():
    """Integer keys address ndarray columns by position."""
    matrix = np.array([
        [1.0, 25.0, 100.0],
        [2.0, 30.0, 200.0],
        [3.0, 35.0, 300.0]
    ])

    # Column 0 is doubled, column 2 is halved, column 1 carries no weight.
    scaled = FeatureWeighter(weights={0: 2.0, 2: 0.5}).fit_transform(matrix)

    expected_columns = (
        (0, [2.0, 4.0, 6.0]),
        (2, [50.0, 100.0, 150.0]),
        (1, [25.0, 30.0, 35.0]),
    )
    for position, expected in expected_columns:
        np.testing.assert_array_equal(scaled[:, position], np.array(expected))
|
|
52
|
+
|
|
53
|
+
# ---------------------------------------------------------
|
|
54
|
+
# Test 3: The Immutability Principle
|
|
55
|
+
# ---------------------------------------------------------
|
|
56
|
+
def test_original_data_not_mutated():
    """Transforming must leave the caller's DataFrame untouched."""
    source = pd.DataFrame({'feature_a': [10.0, 20.0]})

    # The transformed result is irrelevant here; only the input is inspected.
    FeatureWeighter(weights={'feature_a': 5.0}).fit_transform(source)

    assert source['feature_a'].tolist() == [10.0, 20.0], "Original DataFrame was mutated in place!"
|
|
64
|
+
|
|
65
|
+
# ---------------------------------------------------------
|
|
66
|
+
# Test 4: Pipeline Passthrough (None or Empty Weights)
|
|
67
|
+
# ---------------------------------------------------------
|
|
68
|
+
def test_empty_weights_passthrough():
    """Without weights the transformer hands the data back unchanged."""
    frame = pd.DataFrame({'feature_a': [1.0, 2.0]})

    passthrough = FeatureWeighter(weights=None)
    result = passthrough.fit_transform(frame)

    assert result['feature_a'].tolist() == [1.0, 2.0]
|
|
76
|
+
|
|
77
|
+
# ---------------------------------------------------------
|
|
78
|
+
# Test 5: End-to-End Pipeline Integration
|
|
79
|
+
# ---------------------------------------------------------
|
|
80
|
+
def test_pipeline_integration_runs():
    """FeatureWeighter can be used as a step of a scikit-learn Pipeline."""
    features = pd.DataFrame({
        'feature_a': [1.0, 2.0, 3.0, 4.0],
        'feature_b': [4.0, 3.0, 2.0, 1.0],
    })
    labels = np.array([0, 0, 1, 1])

    steps = [
        ('weighter', FeatureWeighter(weights={'feature_a': 10.0})),
        ('classifier', LogisticRegression()),
    ]
    pipeline = Pipeline(steps)

    # Smoke test: fitting and predicting must complete without shape or type
    # errors and yield one prediction per input row.
    pipeline.fit(features, labels)
    assert len(pipeline.predict(features)) == 4
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from sklearn.linear_model import LogisticRegression
|
|
5
|
+
from skweights.wrapper import RuleConstraintWrapper
|
|
6
|
+
|
|
7
|
+
# ---------------------------------------------------------
|
|
8
|
+
# Test 1: The Gatekeeper Intercept (Hard Constraint)
|
|
9
|
+
# ---------------------------------------------------------
|
|
10
|
+
def test_rule_cascade_blocking():
    """A matching rule fixes the prediction regardless of the training label."""
    # Only row 0 has laptop_status == 0 and therefore triggers the rule.
    features = pd.DataFrame({
        'laptop_status': [0, 1, 1, 1],
        'experience': [5.0, 2.0, 10.0, 1.0]
    })
    # Rows 1-3 reach the classifier, so they must contain both classes for
    # LogisticRegression to be fittable.
    labels = np.array([1, 1, 0, 0])

    no_laptop_rule = {'column': 'laptop_status', 'operator': '==', 'value': 0, 'outcome': 0}
    model = RuleConstraintWrapper(estimator=LogisticRegression(), rules=[no_laptop_rule])
    model.fit(features, labels)

    predictions = model.predict(features)

    # Row 0 was labelled 1 in training, yet the rule forces outcome 0.
    assert predictions[0] == 0
    # One prediction per input row.
    assert len(predictions) == 4
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------
|
|
40
|
+
# Test 2: NumPy Array Fallback
|
|
41
|
+
# ---------------------------------------------------------
|
|
42
|
+
def test_numpy_rule_evaluation():
    """Rules can address ndarray columns through an integer index."""
    blocked_row = [0.0, 5.0]
    passed_rows = [[1.0, 2.0], [1.0, 10.0]]
    matrix = np.array([blocked_row, *passed_rows])
    labels = np.array([1, 0, 1])

    # 'column' is a position (0) rather than a column name.
    index_rule = {'column': 0, 'operator': '==', 'value': 0.0, 'outcome': 0}
    model = RuleConstraintWrapper(estimator=LogisticRegression(), rules=[index_rule])
    model.fit(matrix, labels)

    predictions = model.predict(matrix)

    assert predictions[0] == 0
    assert len(predictions) == 3
|
|
61
|
+
|
|
62
|
+
# ---------------------------------------------------------
|
|
63
|
+
# Test 3: Empty Rules (Standard Estimator Behavior)
|
|
64
|
+
# ---------------------------------------------------------
|
|
65
|
+
def test_empty_rules_pass_through():
    """With no rules every row is delegated to the wrapped estimator."""
    features = pd.DataFrame({'feature': [1, 2, 3, 4]})
    labels = np.array([0, 0, 1, 1])

    # An empty rule list leaves all rows to LogisticRegression.
    model = RuleConstraintWrapper(estimator=LogisticRegression(), rules=[])
    model.fit(features, labels)

    assert len(model.predict(features)) == 4
|