chaine 3.13.1__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chaine might be problematic. Click here for more details.
- chaine/__init__.py +2 -0
- chaine/_core/crf.cpp +19854 -0
- chaine/_core/crf.cpython-313-x86_64-linux-gnu.so +0 -0
- chaine/_core/crf.pyx +271 -0
- chaine/_core/crfsuite/COPYING +27 -0
- chaine/_core/crfsuite/README +183 -0
- chaine/_core/crfsuite/include/crfsuite.h +1077 -0
- chaine/_core/crfsuite/include/crfsuite.hpp +649 -0
- chaine/_core/crfsuite/include/crfsuite_api.hpp +406 -0
- chaine/_core/crfsuite/include/os.h +65 -0
- chaine/_core/crfsuite/lib/cqdb/COPYING +28 -0
- chaine/_core/crfsuite/lib/cqdb/include/cqdb.h +518 -0
- chaine/_core/crfsuite/lib/cqdb/src/cqdb.c +639 -0
- chaine/_core/crfsuite/lib/cqdb/src/lookup3.c +1271 -0
- chaine/_core/crfsuite/lib/cqdb/src/main.c +184 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d.h +354 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d_context.c +788 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d_encode.c +1020 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d_feature.c +382 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d_model.c +1085 -0
- chaine/_core/crfsuite/lib/crf/src/crf1d_tag.c +582 -0
- chaine/_core/crfsuite/lib/crf/src/crfsuite.c +500 -0
- chaine/_core/crfsuite/lib/crf/src/crfsuite_internal.h +233 -0
- chaine/_core/crfsuite/lib/crf/src/crfsuite_train.c +302 -0
- chaine/_core/crfsuite/lib/crf/src/dataset.c +115 -0
- chaine/_core/crfsuite/lib/crf/src/dictionary.c +127 -0
- chaine/_core/crfsuite/lib/crf/src/holdout.c +83 -0
- chaine/_core/crfsuite/lib/crf/src/json.c +1497 -0
- chaine/_core/crfsuite/lib/crf/src/json.h +120 -0
- chaine/_core/crfsuite/lib/crf/src/logging.c +85 -0
- chaine/_core/crfsuite/lib/crf/src/logging.h +49 -0
- chaine/_core/crfsuite/lib/crf/src/params.c +370 -0
- chaine/_core/crfsuite/lib/crf/src/params.h +84 -0
- chaine/_core/crfsuite/lib/crf/src/quark.c +180 -0
- chaine/_core/crfsuite/lib/crf/src/quark.h +46 -0
- chaine/_core/crfsuite/lib/crf/src/rumavl.c +1178 -0
- chaine/_core/crfsuite/lib/crf/src/rumavl.h +144 -0
- chaine/_core/crfsuite/lib/crf/src/train_arow.c +409 -0
- chaine/_core/crfsuite/lib/crf/src/train_averaged_perceptron.c +237 -0
- chaine/_core/crfsuite/lib/crf/src/train_l2sgd.c +491 -0
- chaine/_core/crfsuite/lib/crf/src/train_lbfgs.c +323 -0
- chaine/_core/crfsuite/lib/crf/src/train_passive_aggressive.c +442 -0
- chaine/_core/crfsuite/lib/crf/src/vecmath.h +360 -0
- chaine/_core/crfsuite/swig/crfsuite.cpp +1 -0
- chaine/_core/crfsuite_api.pxd +67 -0
- chaine/_core/liblbfgs/COPYING +22 -0
- chaine/_core/liblbfgs/README +71 -0
- chaine/_core/liblbfgs/include/lbfgs.h +745 -0
- chaine/_core/liblbfgs/lib/arithmetic_ansi.h +142 -0
- chaine/_core/liblbfgs/lib/arithmetic_sse_double.h +303 -0
- chaine/_core/liblbfgs/lib/arithmetic_sse_float.h +312 -0
- chaine/_core/liblbfgs/lib/lbfgs.c +1531 -0
- chaine/_core/tagger_wrapper.hpp +58 -0
- chaine/_core/trainer_wrapper.cpp +32 -0
- chaine/_core/trainer_wrapper.hpp +26 -0
- chaine/crf.py +505 -0
- chaine/logging.py +214 -0
- chaine/optimization/__init__.py +10 -0
- chaine/optimization/metrics.py +129 -0
- chaine/optimization/spaces.py +394 -0
- chaine/optimization/trial.py +103 -0
- chaine/optimization/utils.py +119 -0
- chaine/training.py +184 -0
- chaine/typing.py +18 -0
- chaine/validation.py +43 -0
- chaine-3.13.1.dist-info/METADATA +348 -0
- chaine-3.13.1.dist-info/RECORD +68 -0
- chaine-3.13.1.dist-info/WHEEL +6 -0
chaine/logging.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
chaine.logging
|
|
3
|
+
~~~~~~~~~~~~~~
|
|
4
|
+
|
|
5
|
+
This module implements a basic logger.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import sys
|
|
10
|
+
from logging import Formatter, StreamHandler
|
|
11
|
+
|
|
12
|
+
# Numeric log levels supported by this module (identical to the values the
# standard library uses for the levels of the same name).
DEBUG, INFO, WARNING, ERROR = 10, 20, 30, 40

# Translates an upper-case level name into its numeric value.
LEVELS = dict(DEBUG=DEBUG, INFO=INFO, WARNING=WARNING, ERROR=ERROR)

# Compact record layout: timestamp, level name and message.
DEFAULT_FORMAT = Formatter("[%(asctime)s] [%(levelname)s] %(message)s")

# Verbose record layout: additionally shows the logger name and line number.
DEBUG_FORMAT = Formatter("[%(asctime)s] %(name)s:%(lineno)d [%(levelname)s] %(message)s")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Logger:
    """Thin wrapper around a named ``logging.Logger`` that writes to stdout.

    The wrapper attaches a single stdout stream handler to the underlying
    logger and keeps the level and formatter of the logger's handlers in sync
    with the logger's own level.
    """

    def __init__(self, name: str):
        """Basic logger

        Parameters
        ----------
        name : str
            Name of the logger
        """
        self.name = name

        # return a logger with the specified name, creating it if necessary
        self._logger = logging.getLogger(name)

        # ``logging.getLogger`` returns the same object for the same name, so
        # reuse a stdout handler attached by an earlier ``Logger(name)``
        # instead of stacking another one (every record would be emitted once
        # per attached handler)
        for handler in self._logger.handlers:
            if isinstance(handler, StreamHandler) and handler.stream is sys.stdout:
                self._stream_handler = handler
                break
        else:
            self._stream_handler = StreamHandler(sys.stdout)
            self._logger.addHandler(self._stream_handler)

        # set level of both the logger and the handler to INFO by default
        self.set_level("INFO")

    def set_level(self, level: str | int):
        """Set the level of the logger and of all handlers attached to it

        Parameters
        ----------
        level : str | int
            Numeric level or (case-insensitive) name of a level in ``LEVELS``

        Raises
        ------
        KeyError
            If ``level`` is a string that is not a key of ``LEVELS``
        """
        # translate string to integer
        if isinstance(level, str):
            level = LEVELS[level.upper()]

        # set the logger's level
        self._logger.setLevel(level)

        # log more context (logger name, line number) when below INFO
        formatter = DEBUG_FORMAT if level < INFO else DEFAULT_FORMAT

        # and all handlers
        for handler in self._logger.handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)

    def debug(self, message: str):
        """Debug log message

        Parameters
        ----------
        message : str
            Message to log
        """
        # stacklevel=2 attributes the record to the caller of this method
        # rather than to this wrapper, so ``%(lineno)d`` is meaningful
        self._logger.debug(message, stacklevel=2)

    def info(self, message: str):
        """Info log message

        Parameters
        ----------
        message : str
            Message to log
        """
        self._logger.info(message, stacklevel=2)

    def warning(self, message: str):
        """Warning log message

        Parameters
        ----------
        message : str
            Message to log
        """
        self._logger.warning(message, stacklevel=2)

    def error(self, message: str | Exception):
        """Error log message

        Parameters
        ----------
        message : str | Exception
            Message to log; if an exception is given, its traceback is
            appended to the record
        """
        # pass the exception object itself: ``exc_info=True`` would read
        # ``sys.exc_info()``, which is empty outside of an ``except`` block
        exc_info = message if isinstance(message, Exception) else None
        self._logger.error(message, exc_info=exc_info, stacklevel=2)

    @property
    def in_debug_mode(self) -> bool:
        """Checks if the logger's level is DEBUG

        Returns
        -------
        bool
            True, if logger is in DEBUG mode, False otherwise
        """
        return self._logger.level == DEBUG

    @property
    def level(self) -> int:
        """Returns the current log level

        Returns
        -------
        int
            Log level.
        """
        return self._logger.level

    def __repr__(self):
        return f"<Logger: {self.name} ({self.level})>"
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def get_logger(name: str) -> logging.Logger:
    """Look up the standard library logger registered under a name

    Parameters
    ----------
    name : str
        Name of the module whose logger is requested

    Returns
    -------
    logging.Logger
        The logger object for ``name``
    """
    requested = logging.getLogger(name)
    return requested
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def logger_exists(name: str) -> bool:
    """Checks if a logger exists for the specified module

    A logger counts as existing when it has at least one handler configured;
    the check is delegated to ``logging.Logger.hasHandlers``.

    Parameters
    ----------
    name : str
        Name of the module to check the logger for

    Returns
    -------
    bool
        True if logger exists, False otherwise
    """
    return logging.getLogger(name).hasHandlers()
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def set_level(name: str, level: int | str):
    """Apply a log level to a named logger and to each of its handlers

    Handlers additionally get the verbose formatter when the level is below
    INFO and the default formatter otherwise.

    Parameters
    ----------
    name : str
        Name of the module whose logger is configured
    level : int | str
        Numeric level, or the (case-insensitive) name of a level in ``LEVELS``
    """
    target = logging.getLogger(name)

    # level names are resolved through the LEVELS table
    numeric_level = LEVELS[level.upper()] if isinstance(level, str) else level

    target.setLevel(numeric_level)

    # keep every attached handler in sync with the logger itself
    for attached in target.handlers:
        attached.setLevel(numeric_level)
        attached.setFormatter(DEBUG_FORMAT if numeric_level < INFO else DEFAULT_FORMAT)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def set_verbosity(level: int):
    """Map a verbosity number onto the log level of the CRF loggers

    Parameters
    ----------
    level : int
        Log only errors (0), info (1) or even debug messages (2); any other
        value leaves the loggers untouched
    """
    for verbosity, level_name in ((0, "ERROR"), (1, "INFO"), (2, "DEBUG")):
        if level == verbosity:
            set_level("chaine._core.crf", level_name)
            set_level("chaine.crf", level_name)
            break
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""
|
|
2
|
+
chaine.optimization.metrics
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
|
+
|
|
5
|
+
This module implements metrics to evaluate the performance of a trained model.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections import Counter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def calculate_precision(true_positives: int, false_positives: int) -> float:
    """Compute precision from true and false positive counts.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_positives : int
        Number of false positives.

    Returns
    -------
    float
        Share of true positives among all positives, or 1.0 if there are no
        positives at all.
    """
    predicted_positives = true_positives + false_positives
    try:
        score = true_positives / predicted_positives
    except ZeroDivisionError:
        # nothing was predicted as positive, so no prediction was wrong
        score = 1.0
    return score
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def calculate_recall(true_positives: int, false_negatives: int) -> float:
    """Compute recall from true positive and false negative counts.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_negatives : int
        Number of false negatives.

    Returns
    -------
    float
        Share of true positives among true positives plus false negatives,
        or 0.0 if both counts are zero.
    """
    actual_positives = true_positives + false_negatives
    try:
        score = true_positives / actual_positives
    except ZeroDivisionError:
        # there was nothing to find, which is scored as zero recall
        score = 0.0
    return score
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def calculate_f1(true_positives: int, false_positives: int, false_negatives: int) -> float:
    """Compute the F1 score, the harmonic mean of precision and recall.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_positives : int
        Number of false positives.
    false_negatives : int
        Number of false negatives.

    Returns
    -------
    float
        F1 score, or 0.0 if precision and recall are both zero.
    """
    p = calculate_precision(true_positives, false_positives)
    r = calculate_recall(true_positives, false_negatives)
    try:
        score = (2 * p * r) / (p + r)
    except ZeroDivisionError:
        # precision and recall are both zero
        score = 0.0
    return score
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def evaluate_predictions(true: list[list[str]], pred: list[list[str]]) -> dict[str, float]:
    """Evaluate the given predictions with the true labels.

    Labels are compared token by token after stripping a leading ``B-`` and
    ``I-`` prefix; the label ``O`` is treated as the negative class.

    Parameters
    ----------
    true : list[list[str]]
        True labels, one list per sequence.
    pred : list[list[str]]
        Predicted labels, one list per sequence.

    Returns
    -------
    dict[str, float]
        Precision, recall and F1 scores under the keys ``precision``,
        ``recall`` and ``f1``.

    Raises
    ------
    ValueError
        If the inputs are not non-empty lists of lists, if the number of
        sequences differs, or if a true and a predicted sequence differ in
        length.
    """
    # validate input; emptiness is checked before indexing so that empty input
    # is reported as a ValueError and not as an IndexError
    if (
        not isinstance(true, list)
        or not isinstance(pred, list)
        or not true
        or not pred
        or not isinstance(true[0], list)
        or not isinstance(pred[0], list)
    ):
        raise ValueError("Input lists are invalid")

    # zip() would silently drop the surplus sequences of the longer input
    if len(true) != len(pred):
        raise ValueError(f"Different number of sequences: {len(true)} vs. {len(pred)}")

    counts = Counter()

    # get true positives, false positives and false negatives (true negatives
    # are not needed for any of the reported scores)
    for true_labels, predicted_labels in zip(true, pred):
        # ignore prefixes
        true_labels = [label.removeprefix("B-").removeprefix("I-") for label in true_labels]
        predicted_labels = [
            label.removeprefix("B-").removeprefix("I-") for label in predicted_labels
        ]

        if len(true_labels) != len(predicted_labels):
            raise ValueError(f"Different lengths: '{true_labels}' vs. '{predicted_labels}'")

        for true_label, predicted_label in zip(true_labels, predicted_labels):
            # the three cases are mutually exclusive
            if true_label != "O" and predicted_label == true_label:
                counts["tp"] += 1
            elif predicted_label != "O" and predicted_label != true_label:
                counts["fp"] += 1
            elif true_label != "O" and predicted_label == "O":
                counts["fn"] += 1

    # calculate precision, recall and f1 score
    return {
        "precision": calculate_precision(counts["tp"], counts["fp"]),
        "recall": calculate_recall(counts["tp"], counts["fn"]),
        "f1": calculate_f1(counts["tp"], counts["fp"], counts["fn"]),
    }
|