champions 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- champions/__init__.py +24 -0
- champions/cli.py +39 -0
- champions/model/__init__.py +0 -0
- champions/model/champions.py +35 -0
- champions/model/datacard.py +48 -0
- champions/model/dataframes.py +225 -0
- champions/model/filter.py +37 -0
- champions/model/settings.py +55 -0
- champions/service/__init__.py +0 -0
- champions/service/darkwing.py +96 -0
- champions/service/eval.py +130 -0
- champions/service/train.py +213 -0
- champions-0.9.0.dist-info/METADATA +33 -0
- champions-0.9.0.dist-info/RECORD +17 -0
- champions-0.9.0.dist-info/WHEEL +4 -0
- champions-0.9.0.dist-info/entry_points.txt +2 -0
- champions-0.9.0.dist-info/licenses/LICENSE +674 -0
champions/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
import logging.config
|
2
|
+
import yaml
|
3
|
+
|
4
|
+
# Configure logging when the package is imported.  A ``logging.yaml`` in the
# current working directory takes precedence; when it is missing *or empty*,
# a console handler at INFO level is used instead.
_DEFAULT_LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "level": "INFO",
        }
    },
    "root": {"handlers": ["console"], "level": "INFO"},
}

try:
    # NOTE: resolved relative to the process CWD, not to this package.
    with open("logging.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
except FileNotFoundError:
    config = None

# yaml.safe_load returns None for an empty document and dictConfig(None)
# raises, so fall back to the default configuration in that case as well.
if config is None:
    config = _DEFAULT_LOGGING_CONFIG

logging.config.dictConfig(config)
|
champions/cli.py
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Annotated
|
3
|
+
import typer
|
4
|
+
import yaml
|
5
|
+
|
6
|
+
from champions.model.datacard import DataCard
|
7
|
+
from champions.model.settings import EvalSettings, TrainSettings
|
8
|
+
from champions.service.eval import Eval
|
9
|
+
from champions.service.train import Train
|
10
|
+
|
11
|
+
|
12
|
+
# Typer application; the sub-commands below register themselves on it through
# the ``@app.command()`` decorator.
app = typer.Typer()

# Module-level logger of the command line interface.
logger = logging.getLogger(__name__)
|
18
|
+
@app.command()
def train(
    datacard: Annotated[typer.FileText, typer.Option()],
    trainsettings: Annotated[typer.FileText, typer.Option()],
):
    # CLI entry point: parse the data-card YAML and the training-settings YAML
    # and run the training service.  (Comment, not docstring, so the Typer
    # help text stays unchanged.)
    data_card = DataCard(**yaml.safe_load(datacard))
    train_settings = TrainSettings(**yaml.safe_load(trainsettings))
    Train(dc=data_card, settings=train_settings).run()
|
28
|
+
|
29
|
+
|
30
|
+
@app.command()
def eval(
    datacard: Annotated[typer.FileText, typer.Option()],
    evalsettings: Annotated[typer.FileText, typer.Option()],
):
    # CLI entry point: parse the data-card YAML and the evaluation-settings YAML
    # and run the evaluation service.  NOTE: the function name shadows the
    # builtin ``eval`` at module level; it is kept because it is the command
    # name.
    data_card = DataCard(**yaml.safe_load(datacard))
    eval_settings = EvalSettings(**yaml.safe_load(evalsettings))
    evaluation = Eval(dc=data_card, settings=eval_settings)
    evaluation.run()
|
File without changes
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import logging
|
2
|
+
from pydantic import BaseModel
|
3
|
+
|
4
|
+
logger = logging.getLogger(__name__)


class Spore(BaseModel):
    # One rule of a champion: a list of SQL conditions plus the score emitted
    # for rows matching them (Champion.get_sql joins ``cut`` with AND and uses
    # ``score`` as the value of the WHEN branch).

    # SQL boolean conditions, combined with AND.
    cut: list[str]
    # Score produced for rows matching every condition in ``cut``.
    score: float
    # NOTE(review): typed as str although the name suggests a number; it is not
    # read by the code shown here -- confirm the intended type with the producer.
    depth: str
|
11
|
+
|
12
|
+
|
13
|
+
class Champion(BaseModel):
    # A trained scoring rule set for one target value.  Every spore becomes one
    # WHEN branch of a SQL CASE expression; branches are checked in list order,
    # so the first matching spore wins.

    spore: list[Spore]
    target: str | int

    def get_sql(self, res_name: str) -> str:
        """Return a SQL ``CASE`` expression scoring a row, aliased as *res_name*.

        Rows matched by no spore evaluate to ``NULL`` (there is no ``ELSE``).
        A spore with an empty ``cut`` list matches every row (``WHEN TRUE``)
        and a champion without spores yields a ``NULL`` DOUBLE column, so the
        returned fragment is always valid SQL.

        Args:
            res_name: Alias of the resulting score column.

        Returns:
            A select-list fragment of the form ``CASE ... END AS <res_name>``.
        """
        if not self.spore:
            # "CASE END" is a syntax error; keep the column type consistent.
            return f"CAST(NULL AS DOUBLE) AS {res_name}"

        branches = []
        for spore in self.spore:
            # An empty conjunction would render "WHEN  THEN", a syntax error.
            condition = " AND ".join(spore.cut) if spore.cut else "TRUE"
            branches.append(
                f" WHEN {condition} THEN CAST({spore.score} AS DOUBLE)\n"
            )
        return "CASE \n" + "".join(branches) + f"END AS {res_name}"
|
25
|
+
|
26
|
+
|
27
|
+
class Champions(BaseModel):
    # All champions trained for one target value, keyed by champion name.

    champions: dict[str, Champion]
    target: str

    def get_sql(self) -> dict[str, str]:
        """Map ``res_<name>`` to the SQL ``CASE`` column of each champion."""
        sql_by_column: dict[str, str] = {}
        for champ_name, champ in self.champions.items():
            column = f"res_{champ_name}"
            sql_by_column[column] = champ.get_sql(column)
        return sql_by_column
|
@@ -0,0 +1,48 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Literal
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
logger = logging.getLogger(__name__)


class Feature(BaseModel):
    # One column described by the data card.

    # Column name as used in the data files / SQL queries.
    name: str
    # Statistical kind.  "continus" (sic) is the exact spelling the rest of the
    # code compares against, so data cards must use it verbatim.
    statistical: Literal["category", "continus"]
    # NOTE(review): data-type label; not interpreted by the code shown here --
    # confirm the expected values with the data-card author.
    type: str
|
12
|
+
|
13
|
+
|
14
|
+
class Target(BaseModel):
    # Which column holds the label and which label values are modelled.

    # Name of the label column; it is excluded from the training features.
    feature_name: str
    # Label values for which champions are trained / evaluated.
    values: list[int | str]
|
17
|
+
|
18
|
+
|
19
|
+
class DataCard(BaseModel):
    # Description of a data set: its features, the target and the file lists.

    features: list[Feature]
    infos: dict
    target: Target
    test_files: list[str]
    train_files: list[str]

    @property
    def feature_names(self) -> list[str]:
        """Return every feature name as a double-quoted SQL identifier.

        Embedded ``"`` characters are doubled, as SQL requires, so unusual
        column names cannot break out of the identifier.
        """
        return ['"' + feat.name.replace('"', '""') + '"' for feat in self.features]

    @property
    def train_feature(self) -> list[Feature]:
        """Return all features except the target column, in declaration order."""
        return [feat for feat in self.features if feat.name != self.target.feature_name]

    @property
    def train_cat_feature(self) -> list[Feature]:
        """Return the categorical training features (target excluded)."""
        return [feat for feat in self.train_feature if feat.statistical == "category"]

    @property
    def train_con_feature(self) -> list[Feature]:
        """Return the continuous (``"continus"``) training features (target excluded)."""
        return [feat for feat in self.train_feature if feat.statistical == "continus"]
|
@@ -0,0 +1,225 @@
|
|
1
|
+
import itertools
|
2
|
+
import math
|
3
|
+
import time
|
4
|
+
import polars as pl
|
5
|
+
from pydantic import BaseModel
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from champions.model.datacard import Feature
|
9
|
+
from champions.model.filter import CombineFilter, Filter, SingleFilter
|
10
|
+
import sys
|
11
|
+
|
12
|
+
|
13
|
+
class ContCats(BaseModel):
    # Binning of one continuous feature: ``cuts`` are the interior bin edges and
    # ``labels`` name the resulting ``len(cuts) + 1`` bins.

    cuts: list[float]
    labels: list[str]
    feat_name: str

    def cut(self, series: pl.Series) -> pl.Series:
        """Bin *series* at ``self.cuts`` and label the bins with ``self.labels``.

        Polars' ``Series.cut`` builds right-closed intervals ``(a, b]`` by
        default, which is what :meth:`get_single_filter` reproduces in SQL.
        """
        return series.cut(self.cuts, labels=self.labels)

    def get_single_filter(self, i: int) -> SingleFilter:
        """Return the SQL filter selecting the values of bin *i*.

        Bin 0 is ``<= cuts[0]``, the last bin is ``> cuts[-1]`` and every bin in
        between is ``(cuts[i-1], cuts[i]]``.

        Raises:
            IndexError: if there are no cuts or *i* is not a valid bin index.
                (Previously an out-of-range index could call ``sys.exit()`` and
                terminate the whole process.)
        """
        if not self.cuts:
            raise IndexError(
                f"{self.feat_name}: no cuts defined, cannot build a bin filter"
            )
        last_bin = len(self.cuts)  # bins are numbered 0 .. len(cuts)
        if not 0 <= i <= last_bin:
            raise IndexError(
                f"{self.feat_name}: bin index {i} out of range 0..{last_bin}"
            )

        if i == 0:
            return SingleFilter(
                feat_name=self.feat_name,
                operator="<=",
                value=self.cuts[i],
            )

        if i == last_bin:
            return SingleFilter(
                feat_name=self.feat_name,
                operator=">",
                value=self.cuts[i - 1],
            )

        return SingleFilter(
            feat_name=self.feat_name,
            operator="between",
            value=[self.cuts[i - 1], self.cuts[i]],
        )
|
44
|
+
|
45
|
+
|
46
|
+
class CategorizedFeatureMixin:
    """Shared logic for comparing binned target / non-target distributions.

    Subclasses provide ``cut_list`` (one :class:`ContCats` per binned feature)
    and call :meth:`set_diff_df` to populate ``diff_df``.
    """

    diff_df: pl.DataFrame
    cut_list: list[ContCats]
    feature_list: list[str]

    def calc_diff(
        self,
    ) -> float:
        """Return the total absolute difference between the two distributions."""
        return self.diff_df["diff"].abs().sum()

    def set_diff_df(
        self,
        df_count_target: pl.DataFrame,
        df_count_non_target: pl.DataFrame,
        join_on: list[str] | str,
    ):
        """Join the per-bin shares of both classes and store them in ``diff_df``.

        Both inputs carry the bin column(s) *join_on* and a ``proportion``
        column.  After the join the target share is ``proportion_right``;
        ``diff`` is target share minus non-target share and
        ``max_proportion`` the larger of the two.  Nulls produced by the outer
        join are filled with 0.
        """
        # NOTE: the non-coalescing outer join also keeps the right-hand key
        # columns as "<feature>_right"; get_left_right_filter reads those.
        # (polars >= 1.0 calls this join "full" and deprecates "outer".)
        self.diff_df = (
            df_count_non_target.join(df_count_target, on=join_on, how="outer")
            .fill_null(0)
            .with_columns(
                (pl.col("proportion_right") - pl.col("proportion")).alias("diff"),
                pl.max_horizontal(
                    pl.col("proportion"), pl.col("proportion_right")
                ).alias("max_proportion"),
            )
        )

    def get_left_right_filter(self) -> tuple[Filter, Filter]:
        """Split every bin combination into a non-target and a target filter.

        A combination goes to the target side when its target share exceeds
        its non-target share (``diff > 0``), otherwise -- including
        combinations that were never observed -- to the non-target side.

        Returns:
            ``(non_target_filter, target_filter)``; each is an OR over the
            AND-combined per-feature bin filters.
        """
        right_cols = [f"{cut.feat_name}_right" for cut in self.cut_list]
        target_rows = self.diff_df.filter(pl.col("diff") > 0).select(right_cols)
        # A set turns the per-combination membership test into O(1).
        target_combs = set(zip(*(target_rows[col].to_list() for col in right_cols)))

        all_labels = [cut.labels for cut in self.cut_list]
        target_filters: list[CombineFilter] = []
        non_target_filters: list[CombineFilter] = []

        for labels in itertools.product(*all_labels):
            # NOTE(review): labels are converted with int(), i.e. assumed to be
            # the bin indices "0", "1", ... -- confirm where ContCats are built.
            comb_filter = CombineFilter(
                combine=[
                    cut.get_single_filter(int(label))
                    for label, cut in zip(labels, self.cut_list)
                ]
            )
            if labels in target_combs:
                target_filters.append(comb_filter)
            else:
                non_target_filters.append(comb_filter)

        return (
            Filter(combine=non_target_filters, invert=False),
            Filter(combine=target_filters, invert=False),
        )
|
103
|
+
|
104
|
+
|
105
|
+
@dataclass
class CategorizedFeature(CategorizedFeatureMixin):
    """One binned feature together with its binned target / non-target values.

    On construction the relative bin frequencies of both samples are compared
    and stored in ``diff_df`` via ``set_diff_df``.
    """

    feature: Feature
    # Bin definition that produced the two series below.
    cuts: ContCats
    # Binned feature values of the target rows (presumably named after the
    # feature, since join_on below uses ``feature.name``).
    target_sr: pl.Series
    # Binned feature values of the non-target rows.
    non_target_sr: pl.Series

    def __post_init__(self) -> None:
        # value_counts(normalize=True) yields one row per bin with its share of
        # the sample in a "proportion" column (read by set_diff_df).
        df_count_target = self.target_sr.value_counts(normalize=True)
        df_count_non_target = self.non_target_sr.value_counts(normalize=True)
        self.set_diff_df(
            df_count_target=df_count_target,
            df_count_non_target=df_count_non_target,
            join_on=self.feature.name,
        )


    @property
    def cut_list(self) -> list[ContCats]:
        # A single feature contributes exactly one binning.
        return [self.cuts]

    def is_diff_to_low(self, threshold: float = 0.90) -> bool:
        """Return True when both samples concentrate in the same dominant bin.

        Takes the bin with the largest share in either sample (the first such
        row on ties) and checks whether the *smaller* of its two shares still
        exceeds *threshold*, i.e. target and non-target look alike here.
        NOTE(review): presumably raises IndexError when ``diff_df`` is empty.
        """
        max_wert = self.diff_df["max_proportion"].max()
        # df["a", "b"] selects the two columns (not a row/column pair).
        min_prop_of_max = self.diff_df.filter(pl.col("max_proportion") == max_wert)[
            "proportion", "proportion_right"
        ].min_horizontal()[0]
        return min_prop_of_max > threshold
|
132
|
+
|
133
|
+
|
134
|
+
class CombinedCategorizedFeature(CategorizedFeatureMixin):
    """Joint distribution of several binned features.

    Rows are grouped by the combination of their bins and each combination's
    count is divided by the given sample size to obtain its share.
    """

    def __init__(
        self,
        train_features: tuple[CategorizedFeature],
        non_target_size: int,
        target_size: int,
    ):
        key_columns = [item.feature.name for item in train_features]

        def _shares(columns: list[pl.Series], size: int) -> pl.DataFrame:
            # Count every bin combination and normalise by the sample size.
            counts = pl.DataFrame(columns).group_by(key_columns).len(name="proportion")
            return counts.with_columns(pl.col("proportion") / size)

        non_target_counts = _shares(
            [item.non_target_sr for item in train_features], non_target_size
        )
        target_counts = _shares(
            [item.target_sr for item in train_features], target_size
        )

        self.set_diff_df(
            df_count_target=target_counts,
            df_count_non_target=non_target_counts,
            join_on=key_columns,
        )
        self.cut_list = [item.cuts for item in train_features]
|
161
|
+
|
162
|
+
|
163
|
+
class TrainDataframes:
    """Target and non-target training rows, each split into two parts.

    The first ``n_count_*`` rows of every sample (see :meth:`_calc_split`)
    form the *count* part used to compare bin distributions.  The remaining
    rows form the *group* part, concatenated into ``df_group`` with a
    ``weight`` column so that each class carries a total weight of 0.5.
    """

    def __init__(
        self,
        target_df: pl.DataFrame,
        non_target_df: pl.DataFrame,
        frac_eval_cat: float,
        min_size: int,
    ):
        self.target_df_size = target_df.height
        self.non_target_df_size = non_target_df.height

        self.min_size = min_size

        self.n_count_target, n_group_target = self._calc_split(target_df.height, frac_eval_cat)
        self.target_df_count = target_df.head(self.n_count_target)
        # slice(offset) keeps every row after the count part.  The former
        # tail(-n) returned *no* rows for n == 0 because tail(-0) == tail(0).
        target_df_group = target_df.slice(self.n_count_target)

        self.n_count_non_target, n_group_non_target = self._calc_split(non_target_df.height, frac_eval_cat)
        self.non_target_df_count = non_target_df.head(self.n_count_non_target)
        non_target_df_group = non_target_df.slice(self.n_count_non_target)

        # Always define the attribute; it stays None when one class has no rows
        # left for grouping (previously the attribute was simply missing).
        self.df_group: pl.DataFrame | None = None
        if n_group_non_target > 0 and n_group_target > 0:
            self.df_group = pl.concat(
                [
                    target_df_group.with_columns(pl.lit(0.5 / n_group_target).alias("weight")),
                    non_target_df_group.with_columns(pl.lit(0.5 / n_group_non_target).alias("weight")),
                ]
            )
        self.train_features: list[CategorizedFeature] = []

    def create_categorized_features(
        self, feat: Feature, cuts: ContCats
    ) -> CategorizedFeature:
        """Bin *feat* in both count samples and wrap the result."""
        target_sr = cuts.cut(series=self.target_df_count[feat.name])
        non_target_sr = cuts.cut(series=self.non_target_df_count[feat.name])
        return CategorizedFeature(
            feature=feat, cuts=cuts, target_sr=target_sr, non_target_sr=non_target_sr
        )

    def _calc_split(self, n: int, frac: float):
        """Return ``(n_count, n_group)`` for a sample of *n* rows.

        ``split = round(n * frac)``; ``split < 1`` gives ``(1, 0)`` and
        ``split >= n`` gives ``(n - 1, 1)``, otherwise ``(split, n - split)``.
        """
        split = round(n * frac)
        if split < 1:
            return 1, 0
        if split >= n:
            return n - 1, 1
        return split, n - split

    def score(self) -> float:
        """Return ``(n_target - n_non_target) / (n_target + n_non_target)``, or 0.0 without rows."""
        if self.non_target_df_size + self.target_df_size == 0:
            return 0.0
        return (self.target_df_size - self.non_target_df_size) / (
            self.non_target_df_size + self.target_df_size
        )

    def is_final_size(self) -> bool:
        """Return True when either class has fewer than ``min_size`` rows."""
        return (
            self.target_df_size < self.min_size
            or self.non_target_df_size < self.min_size
        )
|
221
|
+
|
222
|
+
|
223
|
+
class EvalDataframe:
    """Thin wrapper holding an evaluation DataFrame."""

    def __init__(self, df: pl.DataFrame, target: str | int) -> None:
        # NOTE(review): ``target`` is accepted but neither stored nor used here.
        self.df = df
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from typing import Literal
|
2
|
+
from pydantic import BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
class SingleFilter(BaseModel):
    # One comparison of a single feature against a value, rendered as SQL.

    feat_name: str
    operator: Literal["=", "<", "<=", ">", ">=", "in", "between"]
    value: str | int | float | list[str] | list[int] | list[float]

    def sql(self) -> str:
        """Render the comparison as a SQL boolean expression.

        ``between`` is the half-open range ``(value[0], value[1]]``.  List
        values render as ``"<feature>" <op> ( v1, v2, ...)`` and scalars as
        ``"<feature>"<op><value>``.  Values are inserted verbatim, so string
        values must already be valid SQL literals.
        """
        # Quote the identifier in every branch (the list branch used to leave
        # it bare) and double embedded quotes as SQL requires.
        column = '"' + self.feat_name.replace('"', '""') + '"'
        if self.operator == "between":
            return f"( {column} > {self.value[0]} AND {column} <= {self.value[1]} )"
        if isinstance(self.value, list):
            # str() lets numeric lists be joined too; join() needs strings.
            items = ", ".join(str(v) for v in self.value)
            return f"{column} {self.operator} ( {items})"
        return f"{column}{self.operator}{self.value}"
|
16
|
+
|
17
|
+
|
18
|
+
class CombineFilter(BaseModel):
    # Conjunction (AND) of single-feature filters, optionally negated.

    combine: list[SingleFilter]
    invert: bool = False

    def sql(self, do_invert: bool = False) -> str:
        """Render the AND of all filters as SQL.

        The result is negated when exactly one of ``self.invert`` and
        *do_invert* is set.  An empty conjunction renders as ``TRUE``.
        """
        prefix = "not" if self.invert ^ do_invert else ""  # python xor
        if not self.combine:
            # "( )" would be a syntax error; the empty AND is vacuously true.
            return f" {prefix} TRUE "
        if len(self.combine) == 1:
            return f" {prefix} {self.combine[0].sql()} "
        return f" {prefix} ( {' AND '.join(sf.sql() for sf in self.combine)} )"
|
27
|
+
|
28
|
+
|
29
|
+
class Filter(BaseModel):
    # Disjunction (OR) of CombineFilters, optionally negated.

    combine: list[CombineFilter]
    invert: bool

    def sql(self, do_invert: bool = False) -> str:
        """Render the OR of all combined filters as SQL.

        The result is negated when exactly one of ``self.invert`` and
        *do_invert* is set.  An empty disjunction renders as ``FALSE`` (it
        matches no row), so e.g. a target filter without any combination still
        yields valid SQL.
        """
        prefix = "not" if self.invert ^ do_invert else ""  # python xor
        if not self.combine:
            # "( )" would be a syntax error; the empty OR is false.
            return f" {prefix} FALSE "
        if len(self.combine) == 1:
            return f" {prefix} {self.combine[0].sql()} "
        return f" {prefix} ( {' OR '.join(sf.sql() for sf in self.combine)} )"
|
@@ -0,0 +1,55 @@
|
|
1
|
+
import logging
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from pydantic import BaseModel, Field
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
from champions.model.champions import Champion, Champions
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)


class EvalSettings(BaseModel):
    # Settings of the evaluation run.

    # Folder containing one sub-folder of champion *.yaml files per target value.
    in_folder: Path = Field(description="Name where the champion is stored")
    # Folder the evaluation HTML reports are written to.
    out_folder: Path = Field(description="Name where the evaluation is stored")
    # NOTE(review): not read by the evaluation code shown here.
    roc_points: int = Field(description="Number of points for roc curve", default=500)
|
16
|
+
|
17
|
+
|
18
|
+
class TrainSettings(BaseModel):
    # Settings of the training run.  (Some field descriptions are German; they
    # are runtime schema strings and are kept verbatim -- English meaning is
    # given in the comments.)

    n: int = Field(description="Number of Champions per target value", default=1)
    out_folder: Path = Field(description="Name where the champion is stored")
    n_cat: int = Field(description="Number of categories in one featerue", default=3)
    max_depth: int = Field(description="Max depth of the champion", default=10)

    # Maximum number of rows used for training at once (row LIMIT per class).
    max_eval_fit: int = Field(
        description="Wieviel sollen max für training gleichzeitig benutzt werden",
        default=1000,
    )
    # Minimum number of rows that must be used for training.
    min_eval_fit: int = Field(
        description="Wieviel müssen fürs training mindestens benutzt werden", default=10
    )
    # How many dimensions (features) are used at once.
    n_dims: int = Field(
        description="Wieviele dimensionen sollen auf einmal benutzt werden", default=3
    )
    # How many calculations are made per dimension.
    calcs_per_dim: int | None = Field(
        description="Wieviele berechnungen sollen pro dimension gemacht werden",
        default=5000,
    )
    # Fraction of the rows used for evaluation (the "count" part).
    frac_eval_cat: float = Field(
        description="Wie groß ist der anteil der für eval benutzt werden soll",
        default=0.5,
    )

    def champion_exists(self, target: str | int, n: int) -> bool:
        """Tell whether champion *n* for *target* has already been saved."""
        champion_file = self.out_folder / f"{target}" / f"{n}.yaml"
        # A missing out_folder implies a missing champion file as well.
        return champion_file.exists()

    def save_champion(self, champion: Champion, n: int):
        """Write *champion* as YAML to ``<out_folder>/<target>/<n>.yaml``."""
        target_dir = self.out_folder / f"{champion.target}"
        logger.info(f"save champion {n} for target {champion.target} to {target_dir}")
        target_dir.mkdir(parents=True, exist_ok=True)
        with open(target_dir / f"{n}.yaml", "w") as fh:
            fh.write(yaml.safe_dump(champion.model_dump()))
|
File without changes
|
@@ -0,0 +1,96 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
from typing import Any
|
4
|
+
import duckdb
|
5
|
+
from pydantic import BaseModel
|
6
|
+
import polars as pl
|
7
|
+
|
8
|
+
from champions.model.champions import Champions
|
9
|
+
from champions.model.datacard import DataCard
|
10
|
+
from champions.model.dataframes import EvalDataframe, TrainDataframes
|
11
|
+
from champions.model.filter import Filter, SingleFilter
|
12
|
+
from champions.model.settings import TrainSettings
|
13
|
+
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)


class Darkwing(BaseModel):
    # Data-access helper: reads training rows from ``dc.train_files`` (cached in
    # memory after the first read) and scores ``dc.test_files`` with the SQL
    # generated by champions, using DuckDB.

    dc: DataCard
    df_train_cach: Any | None = None

    def read_akt_train(
        self,
        targer_filter: Filter,
        train_settings: TrainSettings,
        akt_filters: list[Filter] | None = None,
    ) -> TrainDataframes:
        """Load the target and non-target training samples for the current node.

        Target rows satisfy *targer_filter*, non-target rows its negation; both
        must also satisfy every filter in *akt_filters*.  At most
        ``train_settings.max_eval_fit`` rows are read per class.

        Args:
            targer_filter: Filter selecting the target class.
            train_settings: Supplies ``max_eval_fit``, ``frac_eval_cat`` and
                ``min_eval_fit``.
            akt_filters: Additional filters applied to both classes
                (``None`` means none; replaces a shared mutable ``[]`` default).

        Returns:
            Both samples wrapped in :class:`TrainDataframes`.
        """
        shared_sql = [f.sql() for f in (akt_filters or [])]
        target_df = self._get_pl_train_df(
            full_filters=[targer_filter.sql()] + shared_sql,
            max_eval_fit=train_settings.max_eval_fit,
        )
        non_target_df = self._get_pl_train_df(
            full_filters=[targer_filter.sql(do_invert=True)] + shared_sql,
            max_eval_fit=train_settings.max_eval_fit,
        )

        return TrainDataframes(
            target_df=target_df,
            non_target_df=non_target_df,
            frac_eval_cat=train_settings.frac_eval_cat,
            min_size=train_settings.min_eval_fit,
        )

    def _get_pl_train_df(self, full_filters: list[str], max_eval_fit: int) -> pl.DataFrame:
        """Return up to *max_eval_fit* training rows matching all filters, shuffled."""
        # DuckDB resolves ``FROM df`` to this local polars DataFrame through its
        # replacement scan, so the variable must stay named ``df``.
        df = self.get_cached_train_df()
        sql = f"""
        SELECT {", ".join(self.dc.feature_names)}
        FROM df
        WHERE {" AND ".join(full_filters)}
        LIMIT {max_eval_fit};
        """
        return duckdb.sql(sql).pl().sample(fraction=1, shuffle=True)

    def get_cached_train_df(self) -> pl.DataFrame:
        """Return all training rows, reading ``dc.train_files`` on first use.

        Raises:
            ValueError: if the data card lists no training files.
        """
        if self.df_train_cach is None:
            if not self.dc.train_files:
                raise ValueError("DataCard.train_files is empty")
            # Read every file separately: joining the paths with "," yielded a
            # single non-existent path as soon as there was more than one file.
            # "vertical_relaxed" tolerates per-file dtype inference differences.
            self.df_train_cach = pl.concat(
                [pl.read_csv(path) for path in self.dc.train_files],
                how="vertical_relaxed",
            )
        return self.df_train_cach

    def get_eval_sr(self, champions: Champions) -> pl.Series:
        """Score the test files with every champion and average the scores.

        All champion ``CASE`` columns are computed in a single query, so the
        test files are scanned once and the columns are row-aligned by
        construction.

        Returns:
            A Series named ``champions.target``: per test row, the row sum of
            the champions' scores (``DataFrame.sum_horizontal``) divided by the
            number of champions.

        Raises:
            ValueError: if *champions* contains no champion.
        """
        case_columns = champions.get_sql()
        if not case_columns:
            raise ValueError(f"no champions to evaluate for target {champions.target}")
        df_scores = self._get_pl_eval_df(col_sql=",\n".join(case_columns.values()))
        return (df_scores.sum_horizontal() / len(case_columns)).alias(champions.target)

    def _get_pl_eval_df(self, col_sql: str) -> pl.DataFrame:
        """Evaluate the SQL select list *col_sql* over all test CSV files."""
        # Build a proper DuckDB list literal.  The former Python list repr
        # switched to double quotes (= SQL identifiers) for paths containing a
        # single quote and doubled every backslash.
        files = ", ".join(
            "'" + path.replace("'", "''") + "'" for path in self.dc.test_files
        )
        sql = f"""
        SELECT {col_sql}\n
        FROM read_csv([{files}]);
        """
        return duckdb.sql(sql).pl()
|
@@ -0,0 +1,130 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
from typing import Optional
|
4
|
+
import duckdb
|
5
|
+
from pydantic import BaseModel
|
6
|
+
import yaml
|
7
|
+
from sklearn.metrics import roc_curve
|
8
|
+
import altair as alt
|
9
|
+
|
10
|
+
from champions.model.champions import Champion, Champions
|
11
|
+
from champions.model.datacard import DataCard
|
12
|
+
from champions.model.settings import EvalSettings
|
13
|
+
from champions.service.darkwing import Darkwing
|
14
|
+
import polars as pl
|
15
|
+
import logging
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)


def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + name.replace('"', '""') + '"'


class Eval(BaseModel):
    # Evaluates saved champions on the test files: one ROC chart per target value
    # and, for more than one value, a multi-class accuracy report.

    dc: DataCard
    settings: EvalSettings
    darkwing: Optional[Darkwing] = None

    def model_post_init(self, __context) -> None:
        """Create the data-access helper and the output folder."""
        self.darkwing = Darkwing(dc=self.dc)
        os.makedirs(self.settings.out_folder, exist_ok=True)
        super().model_post_init(__context)

    def run(self):
        """Score every target value and write the HTML reports to ``out_folder``."""
        logger.info("Start Eval")

        # The true labels live in the data card's target column (it used to be
        # hard-coded as "label" in several places).
        label_col = self.dc.target.feature_name
        df_res = self.darkwing._get_pl_eval_df(col_sql=_quote_identifier(label_col))
        plots = None

        for target in self.dc.target.values:
            logger.info(f"Evaluate target {target}")
            target_champions = self.load_champions(target=f"{target}")
            sr = self.darkwing.get_eval_sr(champions=target_champions)
            df_res = df_res.with_columns(sr)
            new_plot = self.plot_roc(df=df_res, target_label=target)
            new_plot.save(self.settings.out_folder / f"{target}_roc.html")
            plots = new_plot if plots is None else plots + new_plot

        # Without target values there is nothing to combine (plots is None).
        if plots is not None:
            plots.save(self.settings.out_folder / "all_roc.html")

        if len(self.dc.target.values) > 1:
            df_multi_res = self.add_multi_class_result(
                df=df_res, target_values=self.dc.target.values
            )
            df_res = pl.concat(
                [df_res, df_multi_res],
                how="horizontal",
            )

            # Per-label accuracy; pl.len() replaces the deprecated pl.count().
            multi_res_plot = (
                df_res.group_by(label_col)
                .agg(pl.col("correct").sum() / pl.len())
                .sort(label_col)
                .plot.bar(x=label_col, y="correct")
            )
            prec = df_res["correct"].sum() / df_res["correct"].count()
            all_pred_labels = (
                alt.Chart(pl.DataFrame({"correct": [prec]}))
                .mark_rule()
                .encode(y="correct")
            )
            (multi_res_plot + all_pred_labels).save(
                self.settings.out_folder / "multi_class_result.html"
            )
        logger.info(f"{df_res}")

    def add_multi_class_result(self, df: pl.DataFrame, target_values: list[str | int]):
        """Predict, per row, the target value whose score column is largest.

        Returns:
            A DataFrame with ``predicted_label`` and ``correct`` (1 when the
            prediction equals the true label, else 0).  Row alignment with *df*
            relies on DuckDB preserving the scan order.
        """
        label_col = self.dc.target.feature_name
        greatest = ", ".join(_quote_identifier(f"{value}") for value in target_values)
        # Joined outside the f-string: a backslash inside an f-string expression
        # is a SyntaxError before Python 3.12.
        # NOTE(review): {value} after THEN is emitted unquoted, so string target
        # values must be valid SQL literals; kept for compatibility.
        cases = " \n ".join(
            f"WHEN {_quote_identifier(f'{value}')} = GREATEST({greatest}) THEN {value}"
            for value in target_values
        )
        # ``FROM df`` is resolved to the parameter via DuckDB's replacement scan.
        return (
            duckdb.sql(f"""
            SELECT
                {_quote_identifier(label_col)},
                CASE
                    {cases}
                END AS predicted_label
            FROM df
            """)
            .pl()
            .with_columns(
                (pl.col(label_col) == pl.col("predicted_label"))
                .cast(pl.Int64)
                .alias("correct")
            )
            .select("predicted_label", "correct")
        )

    def load_champions(self, target: str) -> Champions:
        """Load every ``*.yaml`` champion stored in ``in_folder/<target>``."""
        target_folder = self.settings.in_folder / target
        data = {}
        for file in target_folder.glob("*.yaml"):
            with open(file, "r") as f:
                logger.debug(f"Load Champion from {file}")
                # Path.stem drops the suffix; str.strip(".yaml") removed any of
                # the characters ".yaml" from both ends (e.g. "my.yaml" -> "").
                data[file.stem] = Champion(**yaml.safe_load(f))
        return Champions(champions=data, target=target)

    def plot_roc(self, df: pl.DataFrame, target_label: str | int):
        """Return an interactive ROC line chart for the score column of *target_label*.

        Rows are ranked by score (descending); cumulative true- and
        false-positive rates are rounded to 3 decimals and de-duplicated.
        """
        score_feat = f"{target_label}"
        score_col = _quote_identifier(score_feat)
        label_literal = "'" + score_feat.replace("'", "''") + "'"
        is_target = f"{_quote_identifier(self.dc.target.feature_name)} == {label_literal}"
        df_plot = (
            duckdb.sql(f"""
            SELECT fp,tp
            FROM (
                SELECT
                    ROUND(SUM(CASE WHEN {is_target} THEN 1 ELSE 0 END) OVER (ORDER BY {score_col} DESC) / SUM(CASE WHEN {is_target} THEN 1 ELSE 0 END) OVER (), 3) AS tp,
                    ROUND(SUM(CASE WHEN {is_target} THEN 0 ELSE 1 END) OVER (ORDER BY {score_col} DESC) / SUM(CASE WHEN {is_target} THEN 0 ELSE 1 END) OVER (), 3) AS fp
                FROM df
            )
            GROUP BY fp,tp
            ORDER BY fp,tp
            """)
            .pl()
            .with_columns(pl.lit(score_feat).alias("target"))
        )
        return df_plot.plot.line(x="fp", y="tp", color="target").interactive()
|