champions 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
champions/__init__.py ADDED
@@ -0,0 +1,24 @@
+import logging.config
+import yaml
+
+try:
+    with open("logging.yaml", "r") as f:
+        config = yaml.safe_load(f)
+except FileNotFoundError:
+    config = {
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
+        },
+        "handlers": {
+            "console": {
+                "class": "logging.StreamHandler",
+                "formatter": "simple",
+                "level": "INFO",
+            }
+        },
+        "root": {"handlers": ["console"], "level": "INFO"},
+    }
+
+logging.config.dictConfig(config)
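
Importing the package applies this configuration as a side effect; a minimal sketch to confirm (nothing below ships with the package):

    import logging

    import champions  # importing runs the dictConfig above

    logging.getLogger("champions").info("logging is configured")
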
champions/cli.py ADDED
@@ -0,0 +1,38 @@
+import logging
+from typing import Annotated
+import typer
+import yaml
+
+from champions.model.datacard import DataCard
+from champions.model.settings import EvalSettings, TrainSettings
+from champions.service.eval import Eval
+from champions.service.train import Train
+
+
+logger = logging.getLogger(__name__)
+
+app = typer.Typer()
+
+
+@app.command()
+def train(
+    datacard: Annotated[typer.FileText, typer.Option()],
+    trainsettings: Annotated[typer.FileText, typer.Option()],
+):
+    trainer = Train(
+        dc=DataCard(**yaml.safe_load(datacard)),
+        settings=TrainSettings(**yaml.safe_load(trainsettings)),
+    )
+    trainer.run()
+
+
+@app.command()
+def eval(
+    datacard: Annotated[typer.FileText, typer.Option()],
+    evalsettings: Annotated[typer.FileText, typer.Option()],
+):
+    evaluator = Eval(
+        dc=DataCard(**yaml.safe_load(datacard)),
+        settings=EvalSettings(**yaml.safe_load(evalsettings)),
+    )
+    evaluator.run()
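
Because `@app.command()` returns the undecorated function, the commands can also be driven from Python without the Typer runner; a sketch with invented file names:

    from champions.cli import train

    # Hypothetical YAML paths; any files matching the DataCard and
    # TrainSettings schemas would do.
    with open("datacard.yaml") as dc, open("train.yaml") as ts:
        train(datacard=dc, trainsettings=ts)
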
champions/model/champions.py ADDED
@@ -0,0 +1,35 @@
+import logging
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class Spore(BaseModel):
+    cut: list[str]
+    score: float
+    depth: str
+
+
+class Champion(BaseModel):
+    spore: list[Spore]
+    target: str | int
+
+    def get_sql(self, res_name: str) -> str:
+        sql_ref = "CASE\n"
+        for spore in self.spore:
+            sql_ref += (
+                f"  WHEN {' AND '.join(spore.cut)} THEN CAST({spore.score} AS DOUBLE)\n"
+            )
+        sql_ref += f"END AS {res_name}"
+        return sql_ref
+
+
+class Champions(BaseModel):
+    champions: dict[str, Champion]
+    target: str
+
+    def get_sql(self) -> dict[str, str]:
+        return {
+            f"res_{name}": champion.get_sql(f"res_{name}")
+            for name, champion in self.champions.items()
+        }
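
For orientation, a sketch of the SQL a `Champion` emits; the cut expressions and scores below are invented:

    champ = Champion(
        spore=[
            Spore(cut=['"age" <= 30', '"income" > 50000'], score=0.8, depth="2"),
            Spore(cut=['"age" > 30'], score=0.2, depth="1"),
        ],
        target="yes",
    )
    print(champ.get_sql("res_yes"))
    # CASE
    #   WHEN "age" <= 30 AND "income" > 50000 THEN CAST(0.8 AS DOUBLE)
    #   WHEN "age" > 30 THEN CAST(0.2 AS DOUBLE)
    # END AS res_yes
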
champions/model/datacard.py ADDED
@@ -0,0 +1,49 @@
+import logging
+from typing import Literal
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class Feature(BaseModel):
+    name: str
+    # "continus" (sic) is the token used in datacard YAML for continuous features.
+    statistical: Literal["category", "continus"]
+    type: str
+
+
+class Target(BaseModel):
+    feature_name: str
+    values: list[int | str]
+
+
+class DataCard(BaseModel):
+    features: list[Feature]
+    infos: dict
+    target: Target
+    test_files: list[str]
+    train_files: list[str]
+
+    @property
+    def feature_names(self) -> list[str]:
+        return [f'"{feat.name}"' for feat in self.features]
+
+    @property
+    def train_feature(self) -> list[Feature]:
+        return [feat for feat in self.features if feat.name != self.target.feature_name]
+
+    @property
+    def train_cat_feature(self) -> list[Feature]:
+        return [
+            feat
+            for feat in self.features
+            if feat.name != self.target.feature_name and feat.statistical == "category"
+        ]
+
+    @property
+    def train_con_feature(self) -> list[Feature]:
+        return [
+            feat
+            for feat in self.features
+            if feat.name != self.target.feature_name and feat.statistical == "continus"
+        ]
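
A minimal, invented datacard built directly in Python (the CLI would load the same fields from YAML):

    # All names and paths are examples; "continus" is the schema's own spelling.
    card = DataCard(
        features=[
            Feature(name="age", statistical="continus", type="int"),
            Feature(name="color", statistical="category", type="str"),
            Feature(name="label", statistical="category", type="str"),
        ],
        infos={"source": "example"},
        target=Target(feature_name="label", values=["yes", "no"]),
        train_files=["train.csv"],
        test_files=["test.csv"],
    )
    print(card.feature_names)                        # ['"age"', '"color"', '"label"']
    print([f.name for f in card.train_con_feature])  # ['age']
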
champions/model/dataframes.py ADDED
@@ -0,0 +1,229 @@
+import itertools
+from dataclasses import dataclass
+
+import polars as pl
+from pydantic import BaseModel
+
+from champions.model.datacard import Feature
+from champions.model.filter import CombineFilter, Filter, SingleFilter
+
+
+class ContCats(BaseModel):
+    """Cut points and labels that turn a continuous feature into categories."""
+
+    cuts: list[float]
+    labels: list[str]
+    feat_name: str
+
+    def cut(self, series: pl.Series) -> pl.Series:
+        return series.cut(self.cuts, labels=self.labels)
+
+    def get_single_filter(self, i: int) -> SingleFilter:
+        if i == 0:
+            return SingleFilter(
+                feat_name=self.feat_name,
+                operator="<=",
+                value=self.cuts[i],
+            )
+        if i == len(self.labels) - 1:
+            return SingleFilter(
+                feat_name=self.feat_name,
+                operator=">",
+                value=self.cuts[i - 1],
+            )
+        return SingleFilter(
+            feat_name=self.feat_name,
+            operator="between",
+            value=[self.cuts[i - 1], self.cuts[i]],
+        )
+
+
+class CategorizedFeatureMixin:
+    diff_df: pl.DataFrame
+    cut_list: list[ContCats]
+    feature_list: list[str]
+
+    def calc_diff(self) -> float:
+        return self.diff_df["diff"].abs().sum()
+
+    def set_diff_df(
+        self,
+        df_count_target: pl.DataFrame,
+        df_count_non_target: pl.DataFrame,
+        join_on: list[str] | str,
+    ):
+        # Per-category difference between target and non-target distributions.
+        self.diff_df = (
+            df_count_non_target.join(df_count_target, on=join_on, how="outer")
+            .fill_null(0)
+            .with_columns(
+                (pl.col("proportion_right") - pl.col("proportion")).alias("diff"),
+                pl.max_horizontal(
+                    pl.col("proportion"), pl.col("proportion_right")
+                ).alias("max_proportion"),
+            )
+        )
+
+    def get_left_right_filter(self) -> tuple[Filter, Filter]:
+        # Category combinations where the target distribution dominates.
+        target_res_df = self.diff_df.filter(pl.col("diff") > 0)[
+            [f"{cut.feat_name}_right" for cut in self.cut_list]
+        ]
+        target_lists = [
+            target_res_df[f"{cut.feat_name}_right"].to_list() for cut in self.cut_list
+        ]
+        target_list_tuples = [akt_tuple for akt_tuple in zip(*target_lists)]
+
+        all_labels = [cut.labels for cut in self.cut_list]
+        target_filter = []
+        non_target_filter = []
+
+        for possible_combs in itertools.product(*all_labels):
+            single_filter_list = [
+                cut.get_single_filter(int(index))
+                for index, cut in zip(possible_combs, self.cut_list)
+            ]
+            comb_filter = CombineFilter(combine=single_filter_list)
+            if possible_combs in target_list_tuples:
+                target_filter.append(comb_filter)
+            else:
+                non_target_filter.append(comb_filter)
+
+        return Filter(combine=non_target_filter, invert=False), Filter(
+            combine=target_filter, invert=False
+        )
+
+
+@dataclass
+class CategorizedFeature(CategorizedFeatureMixin):
+    feature: Feature
+    cuts: ContCats
+    target_sr: pl.Series
+    non_target_sr: pl.Series
+
+    def __post_init__(self):
+        df_count_target = self.target_sr.value_counts(normalize=True)
+        df_count_non_target = self.non_target_sr.value_counts(normalize=True)
+        self.set_diff_df(
+            df_count_target=df_count_target,
+            df_count_non_target=df_count_non_target,
+            join_on=self.feature.name,
+        )
+
+    @property
+    def cut_list(self) -> list[ContCats]:
+        return [self.cuts]
+
+    def is_diff_too_low(self, threshold: float = 0.90) -> bool:
+        max_value = self.diff_df["max_proportion"].max()
+        min_prop_of_max = self.diff_df.filter(pl.col("max_proportion") == max_value)[
+            "proportion", "proportion_right"
+        ].min_horizontal()[0]
+        return min_prop_of_max > threshold
+
+
+class CombinedCategorizedFeature(CategorizedFeatureMixin):
+    def __init__(
+        self,
+        train_features: tuple[CategorizedFeature, ...],
+        non_target_size: int,
+        target_size: int,
+    ):
+        group_cols = [train.feature.name for train in train_features]
+        non_target_df = pl.DataFrame([train.non_target_sr for train in train_features])
+
+        df_count_non_target = (
+            non_target_df.group_by(group_cols)
+            .len(name="proportion")
+            .with_columns(pl.col("proportion") / non_target_size)
+        )
+        target_df = pl.DataFrame([train.target_sr for train in train_features])
+        df_count_target = (
+            target_df.group_by(group_cols)
+            .len(name="proportion")
+            .with_columns(pl.col("proportion") / target_size)
+        )
+        self.set_diff_df(
+            df_count_target=df_count_target,
+            df_count_non_target=df_count_non_target,
+            join_on=group_cols,
+        )
+        self.cut_list = [train.cuts for train in train_features]
+
+
+class TrainDataframes:
+    def __init__(
+        self,
+        target_df: pl.DataFrame,
+        non_target_df: pl.DataFrame,
+        frac_eval_cat: float,
+        min_size: int,
+    ):
+        self.target_df_size = target_df.height
+        self.non_target_df_size = non_target_df.height
+
+        self.min_size = min_size
+
+        # Split both frames into a "count" part (for categorizing) and a
+        # "group" part (for weighted grouping).
+        self.n_count_target, n_group_target = self._calc_split(
+            target_df.height, frac_eval_cat
+        )
+        self.target_df_count = target_df.head(self.n_count_target)
+        target_df_group = target_df.tail(-self.n_count_target)
+
+        self.n_count_non_target, n_group_non_target = self._calc_split(
+            non_target_df.height, frac_eval_cat
+        )
+        self.non_target_df_count = non_target_df.head(self.n_count_non_target)
+        non_target_df_group = non_target_df.tail(-self.n_count_non_target)
+
+        self.df_group: pl.DataFrame | None = None
+        if n_group_non_target > 0 and n_group_target > 0:
+            self.df_group = pl.concat(
+                [
+                    target_df_group.with_columns(
+                        pl.lit(0.5 / n_group_target).alias("weight")
+                    ),
+                    non_target_df_group.with_columns(
+                        pl.lit(0.5 / n_group_non_target).alias("weight")
+                    ),
+                ]
+            )
+        self.train_features: list[CategorizedFeature] = []
+
+    def create_categorized_features(
+        self, feat: Feature, cuts: ContCats
+    ) -> CategorizedFeature:
+        target_sr = cuts.cut(series=self.target_df_count[feat.name])
+        non_target_sr = cuts.cut(series=self.non_target_df_count[feat.name])
+        return CategorizedFeature(
+            feature=feat, cuts=cuts, target_sr=target_sr, non_target_sr=non_target_sr
+        )
+
+    def _calc_split(self, n: int, frac: float) -> tuple[int, int]:
+        split = round(n * frac)
+        if split < 1:
+            return 1, 0
+        if split >= n:
+            return n - 1, 1
+        return split, n - split
+
+    def score(self) -> float:
+        if self.non_target_df_size + self.target_df_size == 0:
+            return 0.0
+        return (self.target_df_size - self.non_target_df_size) / (
+            self.non_target_df_size + self.target_df_size
+        )
+
+    def is_final_size(self) -> bool:
+        return (
+            self.target_df_size < self.min_size
+            or self.non_target_df_size < self.min_size
+        )
+
+
+class EvalDataframe:
+    def __init__(self, df: pl.DataFrame, target: str | int) -> None:
+        self.df = df
+        self.target = target
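
How `ContCats` maps a category index back to a `SingleFilter`, on invented cut points:

    # Two cut points give three buckets over "age".
    cc = ContCats(cuts=[30.0, 60.0], labels=["0", "1", "2"], feat_name="age")
    print(cc.get_single_filter(0).sql())  # "age" <= 30.0
    print(cc.get_single_filter(1).sql())  # ( "age" > 30.0 AND "age" <= 60.0 )
    print(cc.get_single_filter(2).sql())  # "age" > 60.0
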
champions/model/filter.py ADDED
@@ -0,0 +1,45 @@
+from typing import Literal
+from pydantic import BaseModel
+
+
+class SingleFilter(BaseModel):
+    feat_name: str
+    operator: Literal["=", "<", "<=", ">", ">=", "in", "between"]
+    value: str | int | float | list[str] | list[int] | list[float]
+
+    def sql(self) -> str:
+        if self.operator == "between":
+            return f'( "{self.feat_name}" > {self.value[0]} AND "{self.feat_name}" <= {self.value[1]} )'
+        if isinstance(self.value, list):
+            values = ", ".join(self._quote(v) for v in self.value)
+            return f'"{self.feat_name}" {self.operator} ( {values} )'
+        return f'"{self.feat_name}" {self.operator} {self._quote(self.value)}'
+
+    @staticmethod
+    def _quote(value) -> str:
+        # String values become SQL string literals; numbers pass through.
+        return f"'{value}'" if isinstance(value, str) else str(value)
+
+
+class CombineFilter(BaseModel):
+    combine: list[SingleFilter]
+    invert: bool = False
+
+    def sql(self, do_invert: bool = False) -> str:
+        # XOR: negate when exactly one of invert/do_invert is set.
+        prefix = "not" if self.invert ^ do_invert else ""
+        if len(self.combine) == 1:
+            return f" {prefix} {self.combine[0].sql()} "
+        return f" {prefix} ( {' AND '.join([sf.sql() for sf in self.combine])} )"
+
+
+class Filter(BaseModel):
+    combine: list[CombineFilter]
+    invert: bool
+
+    def sql(self, do_invert: bool = False) -> str:
+        # XOR: negate when exactly one of invert/do_invert is set.
+        prefix = "not" if self.invert ^ do_invert else ""
+        if len(self.combine) == 1:
+            return f" {prefix} {self.combine[0].sql()} "
+        return f" {prefix} ( {' OR '.join([sf.sql() for sf in self.combine])} )"
champions/model/settings.py ADDED
@@ -0,0 +1,52 @@
+import logging
+from pathlib import Path
+
+from pydantic import BaseModel, Field
+import yaml
+
+from champions.model.champions import Champion
+
+logger = logging.getLogger(__name__)
+
+
+class EvalSettings(BaseModel):
+    in_folder: Path = Field(description="Folder where the champions are stored")
+    out_folder: Path = Field(description="Folder where the evaluation is stored")
+    roc_points: int = Field(description="Number of points for the ROC curve", default=500)
+
+
+class TrainSettings(BaseModel):
+    n: int = Field(description="Number of champions per target value", default=1)
+    out_folder: Path = Field(description="Folder where the champions are stored")
+    n_cat: int = Field(description="Number of categories per feature", default=3)
+    max_depth: int = Field(description="Maximum depth of a champion", default=10)
+
+    max_eval_fit: int = Field(
+        description="Maximum number of rows used for training at once",
+        default=1000,
+    )
+    min_eval_fit: int = Field(
+        description="Minimum number of rows required for training", default=10
+    )
+    n_dims: int = Field(
+        description="Number of dimensions to use at once", default=3
+    )
+    calcs_per_dim: int | None = Field(
+        description="Number of calculations per dimension",
+        default=5000,
+    )
+    frac_eval_cat: float = Field(
+        description="Fraction of the data used for evaluation",
+        default=0.5,
+    )
+
+    def champion_exists(self, target: str | int, n: int) -> bool:
+        return (self.out_folder / f"{target}" / f"{n}.yaml").exists()
+
+    def save_champion(self, champion: Champion, n: int):
+        target = champion.target
+        out_folder = self.out_folder / f"{target}"
+        logger.info(f"save champion {n} for target {target} to {out_folder}")
+        out_folder.mkdir(parents=True, exist_ok=True)
+        with open(out_folder / f"{n}.yaml", "w") as f:
+            f.write(yaml.safe_dump(champion.model_dump()))
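
A sketch of constructing the settings directly; only `out_folder` has no default, and the values here are invented:

    # pydantic coerces the string to a pathlib.Path.
    settings = TrainSettings(out_folder="champions_out", n=2, max_depth=5)
    print(settings.champion_exists(target="yes", n=0))
    # False until champions_out/yes/0.yaml has been written by save_champion()
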
champions/service/darkwing.py ADDED
@@ -0,0 +1,91 @@
+import logging
+from typing import Any
+
+import duckdb
+import polars as pl
+from pydantic import BaseModel
+
+from champions.model.champions import Champions
+from champions.model.datacard import DataCard
+from champions.model.dataframes import TrainDataframes
+from champions.model.filter import Filter
+from champions.model.settings import TrainSettings
+
+logger = logging.getLogger(__name__)
+
+
+class Darkwing(BaseModel):
+    dc: DataCard
+    df_train_cache: Any | None = None
+
+    def read_akt_train(
+        self,
+        target_filter: Filter,
+        train_settings: TrainSettings,
+        akt_filters: list[Filter] | None = None,
+    ) -> TrainDataframes:
+        """Build the target and non-target training frames for the current filters."""
+        akt_filters = akt_filters or []
+
+        full_filters_target = [target_filter.sql()] + [f.sql() for f in akt_filters]
+        target_df = self._get_pl_train_df(
+            full_filters=full_filters_target,
+            max_eval_fit=train_settings.max_eval_fit,
+        )
+        full_filters_non_target = [target_filter.sql(do_invert=True)] + [
+            f.sql() for f in akt_filters
+        ]
+        non_target_df = self._get_pl_train_df(
+            full_filters=full_filters_non_target,
+            max_eval_fit=train_settings.max_eval_fit,
+        )
+
+        return TrainDataframes(
+            target_df=target_df,
+            non_target_df=non_target_df,
+            frac_eval_cat=train_settings.frac_eval_cat,
+            min_size=train_settings.min_eval_fit,
+        )
+
+    def _get_pl_train_df(
+        self, full_filters: list[str], max_eval_fit: int
+    ) -> pl.DataFrame:
+        df = self.get_cached_train_df()  # referenced as "df" by DuckDB below
+        sql = f"""
+            SELECT {", ".join(self.dc.feature_names)}
+            FROM df
+            WHERE {" AND ".join(full_filters)}
+            LIMIT {max_eval_fit};
+        """
+        return duckdb.sql(sql).pl().sample(fraction=1, shuffle=True)
+
+    def get_cached_train_df(self) -> pl.DataFrame:
+        if self.df_train_cache is None:
+            # Read and stack all training CSVs once.
+            self.df_train_cache = pl.concat(
+                [pl.read_csv(f) for f in self.dc.train_files]
+            )
+        return self.df_train_cache
+
+    def get_eval_sr(self, champions: Champions) -> pl.Series:
+        df_sum = None
+        norm = 0.0
+
+        for name, champion in champions.champions.items():
+            norm += 1.0
+            col_name = f"res_{name}"
+            case_sql = champion.get_sql(col_name)
+            df = self._get_pl_eval_df(col_sql=case_sql)
+            df_sum = df_sum.hstack(df) if df_sum is not None else df
+
+        # Average the per-champion scores into one column named after the target.
+        return (df_sum.sum_horizontal() / norm).alias(champions.target)
+
+    def _get_pl_eval_df(self, col_sql: str) -> pl.DataFrame:
+        # DuckDB's read_csv accepts a list of files; the Python list
+        # repr interpolates into valid SQL here.
+        sql = f"""
+            SELECT {col_sql}
+            FROM read_csv({self.dc.test_files});
+        """
+        return duckdb.sql(sql).pl()
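
The `FROM df` in `_get_pl_train_df` works because DuckDB resolves the name `df` to the local polars frame (a replacement scan); a self-contained sketch with toy data:

    import duckdb
    import polars as pl

    df = pl.DataFrame({"age": [20, 40, 70], "label": ["yes", "no", "yes"]})
    # DuckDB sees the local variable "df" and scans the polars frame directly.
    out = duckdb.sql('SELECT "age" FROM df WHERE "age" <= 30 LIMIT 10;').pl()
    print(out)  # a one-row frame: age = 20
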
champions/service/eval.py ADDED
@@ -0,0 +1,127 @@
+import logging
+import os
+from typing import Optional
+
+import altair as alt
+import duckdb
+import polars as pl
+import yaml
+from pydantic import BaseModel
+
+from champions.model.champions import Champion, Champions
+from champions.model.datacard import DataCard
+from champions.model.settings import EvalSettings
+from champions.service.darkwing import Darkwing
+
+logger = logging.getLogger(__name__)
+
+
+class Eval(BaseModel):
+    dc: DataCard
+    settings: EvalSettings
+    darkwing: Optional[Darkwing] = None
+
+    def model_post_init(self, __context) -> None:
+        self.darkwing = Darkwing(dc=self.dc)
+        os.makedirs(self.settings.out_folder, exist_ok=True)
+        super().model_post_init(__context)
+
+    def run(self):
+        logger.info("Start Eval")
+
+        df_res = self.darkwing._get_pl_eval_df(col_sql=self.dc.target.feature_name)
+        plots = None
+
+        for target in self.dc.target.values:
+            logger.info(f"Evaluate target {target}")
+            target_champions = self.load_champions(target=f"{target}")
+            sr = self.darkwing.get_eval_sr(champions=target_champions)
+            df_res = df_res.with_columns(sr)
+            new_plot = self.plot_roc(df=df_res, target_label=target)
+            new_plot.save(self.settings.out_folder / f"{target}_roc.html")
+            plots = new_plot if plots is None else plots + new_plot
+
+        plots.save(self.settings.out_folder / "all_roc.html")
+
+        if len(self.dc.target.values) > 1:
+            df_multi_res = self.add_multi_class_result(
+                df=df_res, target_values=self.dc.target.values
+            )
+            df_res = pl.concat(
+                [df_res, df_multi_res],
+                how="horizontal",
+            )
+
+            multi_res_plot = (
+                df_res.group_by("label")
+                .agg(pl.col("correct").sum() / pl.len())
+                .sort("label")
+                .plot.bar(x="label", y="correct")
+            )
+            prec = df_res["correct"].sum() / df_res["correct"].count()
+            all_pred_labels = (
+                alt.Chart(pl.DataFrame({"correct": [prec]}))
+                .mark_rule()
+                .encode(y="correct")
+            )
+            (multi_res_plot + all_pred_labels).save(
+                self.settings.out_folder / "multi_class_result.html"
+            )
+        logger.info(f"{df_res}")
+
+    def add_multi_class_result(self, df: pl.DataFrame, target_values: list[str | int]):
+        # Assumes the target feature is literally named "label". The predicted
+        # label is the target value whose score column is the row-wise maximum;
+        # labels are compared as strings so int and str targets behave alike.
+        greatest_args = ", ".join([f'"{value}"' for value in target_values])
+        cases = "\n".join(
+            f"WHEN \"{value}\" = GREATEST({greatest_args}) THEN '{value}'"
+            for value in target_values
+        )
+        return (
+            duckdb.sql(f"""
+                SELECT
+                    CAST(label AS VARCHAR) AS label,
+                    CASE
+                        {cases}
+                    END AS predicted_label
+                FROM df
+            """)
+            .pl()
+            .with_columns(
+                (pl.col("label") == pl.col("predicted_label"))
+                .cast(pl.Int64)
+                .alias("correct")
+            )
+            .select("predicted_label", "correct")
+        )
+
+    def load_champions(self, target: str) -> Champions:
+        target_folder = self.settings.in_folder / target
+        data = {}
+        for file in target_folder.glob("*.yaml"):
+            with open(file, "r") as f:
+                logger.debug(f"Load Champion from {file}")
+                data[file.stem] = Champion(**yaml.safe_load(f))
+        return Champions(champions=data, target=target)
+
+    def plot_roc(self, df: pl.DataFrame, target_label: str | int):
+        score_feat = f"{target_label}"
+        # Cumulative true/false positive rates ordered by descending score:
+        # each distinct (fp, tp) pair is one point on the ROC curve.
+        df_plot = (
+            duckdb.sql(f"""
+                SELECT fp, tp
+                FROM (
+                    SELECT
+                        ROUND(SUM(CASE WHEN label = '{target_label}' THEN 1 ELSE 0 END) OVER (ORDER BY "{score_feat}" DESC) / SUM(CASE WHEN label = '{target_label}' THEN 1 ELSE 0 END) OVER (), 3) AS tp,
+                        ROUND(SUM(CASE WHEN label = '{target_label}' THEN 0 ELSE 1 END) OVER (ORDER BY "{score_feat}" DESC) / SUM(CASE WHEN label = '{target_label}' THEN 0 ELSE 1 END) OVER (), 3) AS fp
+                    FROM df
+                )
+                GROUP BY fp, tp
+                ORDER BY fp, tp
+            """)
+            .pl()
+            .with_columns(pl.lit(score_feat).alias("target"))
+        )
+        return df_plot.plot.line(x="fp", y="tp", color="target").interactive()
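
The window-function trick in `plot_roc`, reduced to a toy frame with invented scores: cumulative sums over the score-ordered rows yield one (fp, tp) point per row.

    import duckdb
    import polars as pl

    # Toy data: "yes" holds the averaged champion scores for target "yes".
    df = pl.DataFrame({"label": ["yes", "no", "yes", "no"], "yes": [0.9, 0.8, 0.6, 0.1]})
    print(duckdb.sql("""
        SELECT
            SUM(CASE WHEN label = 'yes' THEN 1 ELSE 0 END) OVER (ORDER BY "yes" DESC)
                / SUM(CASE WHEN label = 'yes' THEN 1 ELSE 0 END) OVER () AS tp,
            SUM(CASE WHEN label = 'yes' THEN 0 ELSE 1 END) OVER (ORDER BY "yes" DESC)
                / SUM(CASE WHEN label = 'yes' THEN 0 ELSE 1 END) OVER () AS fp
        FROM df
    """).pl())
    # Sweeping the threshold down the scores traces (fp, tp) from (0.0, 0.5)
    # to (1.0, 1.0) on this toy data.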