ebm4subjects 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebm4subjects/__init__.py +0 -0
- ebm4subjects/analyzer.py +57 -0
- ebm4subjects/chunker.py +173 -0
- ebm4subjects/duckdb_client.py +329 -0
- ebm4subjects/ebm_logging.py +203 -0
- ebm4subjects/ebm_model.py +715 -0
- ebm4subjects/embedding_generator.py +63 -0
- ebm4subjects/prepare_data.py +82 -0
- ebm4subjects-0.4.1.dist-info/METADATA +134 -0
- ebm4subjects-0.4.1.dist-info/RECORD +12 -0
- ebm4subjects-0.4.1.dist-info/WHEEL +4 -0
- ebm4subjects-0.4.1.dist-info/licenses/LICENSE +287 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import xgboost
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EbmLogger:
    """
    A custom logger class.

    Configures the module-level logger (``logging.getLogger(__name__)``) with
    a log level and a file handler that writes to ``<log_path>/ebm.log``, and
    exposes that logger through :meth:`get_logger`.

    Because ``logging.getLogger(__name__)`` always returns the same logger
    object, every ``EbmLogger`` instance shares one underlying logger; the
    most recently constructed instance determines its level.

    Attributes:
        logger (logging.Logger): The configured logger instance.
    """

    def __init__(self, log_path: str, level: str = "info") -> None:
        """
        Initializes the logger.

        Args:
            log_path (str): Directory in which the ``ebm.log`` file is written.
            level (str): The log level: "error", "warning", "info" or "debug"
                (default: "info"). Any other value sets ``logging.NOTSET``.
        """
        # Function-scope import so this class does not depend on a
        # module-level ``import os`` being present in the file.
        import os

        self.logger = logging.getLogger(__name__)

        # Map the textual level onto the logging constant; unknown names
        # fall back to NOTSET, exactly as the if/elif chain used to do.
        level_by_name = {
            "error": logging.ERROR,
            "warning": logging.WARNING,
            "info": logging.INFO,
            "debug": logging.DEBUG,
        }
        self.logger.setLevel(level_by_name.get(level, logging.NOTSET))

        log_file = f"{log_path}/ebm.log"
        # FileHandler stores the absolute path in ``baseFilename``, so the
        # comparison below has to use the absolute form as well.
        resolved_log_file = os.path.abspath(log_file)

        # The logger is shared between instances. Without this check each new
        # EbmLogger would attach one more handler for the same file, so every
        # record would be written several times and file handles would leak.
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == resolved_log_file
            for handler in self.logger.handlers
        )
        if already_attached:
            return

        # Create a file handler to log messages to a file
        log_file_handler = logging.FileHandler(log_file)
        log_file_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s %(levelname)s: %(message)s",
                "%Y-%m-%d %H:%M:%S",
            )
        )

        # Add the file handler to the logger
        self.logger.addHandler(log_file_handler)

    def get_logger(self) -> logging.Logger:
        """
        Returns the logger instance.

        Returns:
            logging.Logger: The logger instance.
        """
        return self.logger
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class NullLogger:
    """
    Stand-in logger that silently discards every message.

    Offers the logging methods ``debug``, ``info``, ``warn``, ``warning``,
    ``error`` and ``critical`` so it can be passed wherever those calls are
    made, for use when no logging is needed.
    """

    def __init__(self) -> None:
        """
        Creates the null logger; there is no state to set up.
        """
        return None

    def debug(self, *args, **kwargs):
        """
        Discards a debug message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None

    def info(self, *args, **kwargs):
        """
        Discards an info message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None

    def warn(self, *args, **kwargs):
        """
        Discards a warn message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None

    def warning(self, *args, **kwargs):
        """
        Discards a warning message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None

    def error(self, *args, **kwargs):
        """
        Discards an error message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None

    def critical(self, *args, **kwargs):
        """
        Discards a critical message.

        Args:
            *args: The message to log (ignored).
            **kwargs: Additional keyword arguments (ignored).
        """
        return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class XGBLogging(xgboost.callback.TrainingCallback):
    """
    Custom XGBoost training callback for logging model performance during training.

    Args:
        logger (logging.Logger): Logger instance to use for logging.
        epoch_log_interval (int, optional): Interval at which to log model performance
            (default: 100).

    Attributes:
        logger (logging.Logger): Logger instance used for logging.
        epoch_log_interval (int): Interval at which to log model performance.
    """

    def __init__(
        self,
        logger: logging.Logger,
        epoch_log_interval: int = 100,
    ) -> None:
        """
        Initializes the XGBLogger.

        Args:
            logger (logging.Logger): Logger instance to use for logging.
            epoch_log_interval (int, optional): Interval at which to log model
                performance (default: 100).
        """
        # Let the base callback run its own initialisation
        super().__init__()
        # Logger instance used for logging
        self.logger = logger
        # Interval at which to log model performance
        self.epoch_log_interval = epoch_log_interval

    def after_iteration(
        self,
        model: xgboost.Booster,
        epoch: int,
        evals_log: dict,
    ) -> bool:
        """
        Callback function called after each iteration of the XGBoost training process.

        Every ``epoch_log_interval`` epochs (including epoch 0), writes one
        info-level line per entry of ``evals_log`` containing the most recent
        value of each of its metrics.

        Args:
            model (xgboost.Booster): XGBoost model instance (not used here).
            epoch (int): Current epoch number.
            evals_log (dict): Mapping of a data name to a mapping of metric
                name to a sequence of values; the last element of each
                sequence is the value that gets logged.

        Returns:
            bool: Always returns False, as specified by the XGBoost callback API.
        """
        # Log model performance at the specified interval
        if epoch % self.epoch_log_interval == 0:
            for data, metric in evals_log.items():
                # One "name: latest value" entry per metric. The entries are
                # comma-separated so that several metrics stay readable
                # (previously they were concatenated with no separator).
                metrics_str = ", ".join(
                    f"{metric_key}: {history[-1]}"
                    for metric_key, history in metric.items()
                )

                # Log the model performance using the specified logger
                self.logger.info(f"Epoch: {epoch}, {data}: {metrics_str}")

        # Always return False, as specified by the XGBoost callback API
        return False
|