ebm4subjects 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,203 @@
1
+ import logging
2
+
3
+ import xgboost
4
+
5
+
6
class EbmLogger:
    """
    A custom logger class that writes messages to ``<log_path>/ebm.log``.

    This class configures the module-level logger returned by
    ``logging.getLogger(__name__)`` with a level and a file handler, and
    exposes it via ``get_logger``. Because that logger is shared by every
    instance in the process, constructing several ``EbmLogger`` objects for
    the same ``log_path`` attaches the file handler only once, so messages
    are not written multiple times.

    Attributes:
        logger (logging.Logger): The logger instance.
        log_path (str): The directory containing the log file.
        level (str): The requested log level (default: "info").
    """

    # Accepted level names (matched case-insensitively) and their
    # ``logging`` equivalents; anything else maps to ``logging.NOTSET``.
    _LEVELS = {
        "critical": logging.CRITICAL,
        "error": logging.ERROR,
        "warning": logging.WARNING,
        "info": logging.INFO,
        "debug": logging.DEBUG,
    }

    def __init__(self, log_path: str, level: str = "info") -> None:
        """
        Initializes the logger.

        Args:
            log_path (str): The directory in which ``ebm.log`` is written.
            level (str): The log level: "critical", "error", "warning",
                "info" or "debug", case-insensitive (default: "info").
                Any other value sets ``logging.NOTSET``, so the effective
                level is inherited from the parent logger.
        """
        import os

        self.log_path = log_path
        self.level = level
        self.logger = logging.getLogger(__name__)

        # Non-string levels cannot match a name; map them to NOTSET like the
        # unknown-name case instead of failing on an unhashable key.
        level_key = level.lower() if isinstance(level, str) else ""
        self.logger.setLevel(self._LEVELS.get(level_key, logging.NOTSET))

        # FileHandler stores the absolute path in ``baseFilename``; compare
        # against that so equivalent spellings of the same path match.
        log_file = os.path.abspath(os.path.join(log_path, "ebm.log"))

        # The logger is a process-wide singleton, so only attach a handler
        # for this file if one is not already present. Otherwise every
        # instantiation would duplicate each log line and leak a file handle.
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_file
            for handler in self.logger.handlers
        )
        if not already_attached:
            log_file_handler = logging.FileHandler(log_file, encoding="utf-8")
            log_file_handler.setFormatter(
                logging.Formatter(
                    "%(asctime)s %(levelname)s: %(message)s",
                    "%Y-%m-%d %H:%M:%S",
                )
            )
            self.logger.addHandler(log_file_handler)

    def get_logger(self) -> logging.Logger:
        """
        Returns the logger instance.

        Returns:
            logging.Logger: The configured (shared) logger instance.
        """
        return self.logger
61
+
62
+
63
class NullLogger:
    """
    A null logger class that does nothing.

    This class is used when no logging is needed. It mirrors the message
    methods of ``logging.Logger`` (``debug``, ``info``, ``warn``/``warning``,
    ``error``, ``exception``, ``critical`` and ``log``), so it can be passed
    wherever a logger is expected. Every call is discarded and returns
    ``None``.
    """

    def __init__(self) -> None:
        """
        Initializes the null logger.
        """

    def debug(self, *args, **kwargs) -> None:
        """
        Does nothing when debug message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def info(self, *args, **kwargs) -> None:
        """
        Does nothing when info message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def warn(self, *args, **kwargs) -> None:
        """
        Does nothing when warn message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def warning(self, *args, **kwargs) -> None:
        """
        Does nothing when warning message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def error(self, *args, **kwargs) -> None:
        """
        Does nothing when error message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def exception(self, *args, **kwargs) -> None:
        """
        Does nothing when an exception message is logged.

        Provided so callers can use ``logger.exception(...)`` inside an
        ``except`` block without special-casing the null logger.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments (e.g. ``exc_info``).
        """

    def critical(self, *args, **kwargs) -> None:
        """
        Does nothing when critical message is logged.

        Args:
            *args: The message to log.
            **kwargs: Additional keyword arguments.
        """

    def log(self, *args, **kwargs) -> None:
        """
        Does nothing when a message is logged at an explicit level.

        Args:
            *args: The level followed by the message to log.
            **kwargs: Additional keyword arguments.
        """
134
+
135
+
136
class XGBLogging(xgboost.callback.TrainingCallback):
    """
    Custom XGBoost training callback for logging model performance during training.

    On every boosting round whose index is a multiple of
    ``epoch_log_interval`` (including round 0), the latest value of each
    evaluation metric is logged at INFO level, one line per evaluation
    dataset, e.g. ``Epoch: 100, train: rmse: 0.42, mae: 0.31``.

    Args:
        logger (logging.Logger): Logger instance to use for logging. Only its
            ``info`` method is called, so a ``NullLogger`` also works.
        epoch_log_interval (int, optional): Interval at which to log model
            performance; must be positive (default: 100).

    Attributes:
        logger (logging.Logger): Logger instance used for logging.
        epoch_log_interval (int): Interval at which to log model performance.
    """

    def __init__(
        self,
        logger: logging.Logger,
        epoch_log_interval: int = 100,
    ) -> None:
        """
        Initializes the XGBLogging callback.

        Args:
            logger (logging.Logger): Logger instance to use for logging.
            epoch_log_interval (int, optional): Interval at which to log model
                performance (default: 100).

        Raises:
            ValueError: If ``epoch_log_interval`` is not positive. A zero
                interval would otherwise raise ``ZeroDivisionError`` in the
                middle of training.
        """
        super().__init__()
        if epoch_log_interval <= 0:
            raise ValueError(
                f"epoch_log_interval must be positive, got {epoch_log_interval}"
            )
        # Logger instance used for logging
        self.logger = logger
        # Interval at which to log model performance
        self.epoch_log_interval = epoch_log_interval

    def after_iteration(
        self,
        model: xgboost.Booster,
        epoch: int,
        evals_log: dict,
    ) -> bool:
        """
        Callback function called after each iteration of the XGBoost training process.

        Logs model performance at the specified interval.

        Args:
            model (xgboost.Booster): XGBoost model instance (unused).
            epoch (int): Current epoch number.
            evals_log (dict): Evaluation history, mapping each evaluation
                dataset name to a mapping of metric name -> sequence of
                recorded values. The last entry of each sequence is logged.

        Returns:
            bool: Always False. Returning True would ask XGBoost to stop
            training, and this callback never does that.
        """
        if epoch % self.epoch_log_interval == 0:
            for data, metrics in evals_log.items():
                # Report the most recent value of every metric, separated so
                # the line stays readable when several metrics are tracked.
                metrics_str = ", ".join(
                    f"{metric_key}: {history[-1]}"
                    for metric_key, history in metrics.items()
                )
                self.logger.info(f"Epoch: {epoch}, {data}: {metrics_str}")

        # False = continue training.
        return False