chaine 4.0.0b2__cp314-cp314-musllinux_1_2_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. chaine/__init__.py +2 -0
  2. chaine/_core/crf.cpp +19496 -0
  3. chaine/_core/crf.cpython-314-x86_64-linux-musl.so +0 -0
  4. chaine/_core/crfsuite/include/crfsuite.h +1077 -0
  5. chaine/_core/crfsuite/include/crfsuite_api.hpp +406 -0
  6. chaine/_core/crfsuite/lib/cqdb/src/cqdb.c +639 -0
  7. chaine/_core/crfsuite/lib/cqdb/src/lookup3.c +1271 -0
  8. chaine/_core/crfsuite/lib/crf/src/crf1d_context.c +788 -0
  9. chaine/_core/crfsuite/lib/crf/src/crf1d_encode.c +1020 -0
  10. chaine/_core/crfsuite/lib/crf/src/crf1d_feature.c +382 -0
  11. chaine/_core/crfsuite/lib/crf/src/crf1d_model.c +1085 -0
  12. chaine/_core/crfsuite/lib/crf/src/crf1d_tag.c +582 -0
  13. chaine/_core/crfsuite/lib/crf/src/crfsuite.c +500 -0
  14. chaine/_core/crfsuite/lib/crf/src/crfsuite_train.c +302 -0
  15. chaine/_core/crfsuite/lib/crf/src/dataset.c +115 -0
  16. chaine/_core/crfsuite/lib/crf/src/dictionary.c +127 -0
  17. chaine/_core/crfsuite/lib/crf/src/holdout.c +83 -0
  18. chaine/_core/crfsuite/lib/crf/src/json.c +1497 -0
  19. chaine/_core/crfsuite/lib/crf/src/logging.c +85 -0
  20. chaine/_core/crfsuite/lib/crf/src/params.c +370 -0
  21. chaine/_core/crfsuite/lib/crf/src/quark.c +180 -0
  22. chaine/_core/crfsuite/lib/crf/src/rumavl.c +1178 -0
  23. chaine/_core/crfsuite/lib/crf/src/train_arow.c +409 -0
  24. chaine/_core/crfsuite/lib/crf/src/train_averaged_perceptron.c +237 -0
  25. chaine/_core/crfsuite/lib/crf/src/train_l2sgd.c +491 -0
  26. chaine/_core/crfsuite/lib/crf/src/train_lbfgs.c +323 -0
  27. chaine/_core/crfsuite/lib/crf/src/train_passive_aggressive.c +442 -0
  28. chaine/_core/crfsuite/swig/crfsuite.cpp +1 -0
  29. chaine/_core/liblbfgs/lib/lbfgs.c +1531 -0
  30. chaine/_core/tagger_wrapper.hpp +58 -0
  31. chaine/_core/trainer_wrapper.cpp +32 -0
  32. chaine/_core/trainer_wrapper.hpp +26 -0
  33. chaine/crf.py +505 -0
  34. chaine/logging.py +214 -0
  35. chaine/optimization/__init__.py +10 -0
  36. chaine/optimization/metrics.py +129 -0
  37. chaine/optimization/spaces.py +394 -0
  38. chaine/optimization/trial.py +103 -0
  39. chaine/optimization/utils.py +119 -0
  40. chaine/training.py +184 -0
  41. chaine/typing.py +18 -0
  42. chaine/validation.py +43 -0
  43. chaine-4.0.0b2.dist-info/METADATA +343 -0
  44. chaine-4.0.0b2.dist-info/RECORD +50 -0
  45. chaine-4.0.0b2.dist-info/WHEEL +5 -0
  46. chaine-4.0.0b2.dist-info/licenses/LICENSE +22 -0
  47. chaine-4.0.0b2.dist-info/sboms/auditwheel.cdx.json +1 -0
  48. chaine-4.0.0b2.dist-info/top_level.txt +1 -0
  49. chaine.libs/libgcc_s-0cd532bd.so.1 +0 -0
  50. chaine.libs/libstdc++-5d72f927.so.6.0.33 +0 -0
chaine/logging.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ chaine.logging
3
+ ~~~~~~~~~~~~~~
4
+
5
+ This module implements a basic logger.
6
+ """
7
+
8
+ import logging
9
+ import sys
10
+ from logging import Formatter, StreamHandler
11
+
12
# numeric log levels supported by this module
DEBUG, INFO, WARNING, ERROR = 10, 20, 30, 40

# lookup table translating (upper-cased) level names into numeric levels
LEVELS = dict(DEBUG=DEBUG, INFO=INFO, WARNING=WARNING, ERROR=ERROR)

# compact format for regular operation
DEFAULT_FORMAT = Formatter(fmt="[%(asctime)s] [%(levelname)s] %(message)s")

# more verbose format (logger name and line number) used below INFO level
DEBUG_FORMAT = Formatter(
    fmt="[%(asctime)s] %(name)s:%(lineno)d [%(levelname)s] %(message)s"
)
20
+
21
+
22
+ class Logger:
23
+ def __init__(self, name: str):
24
+ """Basic logger
25
+
26
+ Parameters
27
+ ----------
28
+ name : str
29
+ Name of the logger
30
+ """
31
+ self.name = name
32
+
33
+ # return a logger with the specified name, creating it if necessary
34
+ self._logger = logging.getLogger(name)
35
+
36
+ # stream handler to stdout
37
+ self._stream_handler = StreamHandler(sys.stdout)
38
+ self._logger.addHandler(self._stream_handler)
39
+
40
+ # set level of both the logger and the handler to INFO by default
41
+ self.set_level("INFO")
42
+
43
+ def set_level(self, level: str | int):
44
+ # translate string to integer
45
+ if isinstance(level, str):
46
+ level = LEVELS[level.upper()]
47
+
48
+ # set the logger's level
49
+ self._logger.setLevel(level)
50
+
51
+ # and all handlers
52
+ for handler in self._logger.handlers:
53
+ handler.setLevel(level)
54
+
55
+ # optionally change the formatter (log more when in debug mode)
56
+ if level < INFO:
57
+ handler.setFormatter(DEBUG_FORMAT)
58
+ else:
59
+ handler.setFormatter(DEFAULT_FORMAT)
60
+
61
+ def debug(self, message: str):
62
+ """Debug log message
63
+
64
+ Parameters
65
+ ----------
66
+ message : str
67
+ Message to log
68
+ """
69
+ if self._logger.isEnabledFor(DEBUG):
70
+ self._logger._log(DEBUG, message, ())
71
+
72
+ def info(self, message: str):
73
+ """Info log message
74
+
75
+ Parameters
76
+ ----------
77
+ message : str
78
+ Message to log
79
+ """
80
+ if self._logger.isEnabledFor(INFO):
81
+ self._logger._log(INFO, message, ())
82
+
83
+ def warning(self, message: str):
84
+ """Warning log message
85
+
86
+ Parameters
87
+ ----------
88
+ message : str
89
+ Message to log
90
+ """
91
+ if self._logger.isEnabledFor(WARNING):
92
+ self._logger._log(WARNING, message, ())
93
+
94
+ def error(self, message: str | Exception):
95
+ """Error log message
96
+
97
+ Parameters
98
+ ----------
99
+ message : str
100
+ Message to log
101
+ """
102
+ if self._logger.isEnabledFor(ERROR):
103
+ if isinstance(message, Exception):
104
+ # log stacktrace if message is an exception
105
+ self._logger._log(ERROR, message, (), exc_info=True)
106
+ else:
107
+ self._logger._log(ERROR, message, ())
108
+
109
+ @property
110
+ def in_debug_mode(self) -> bool:
111
+ """Checks if the logger's level is DEBUG
112
+
113
+ Returns
114
+ -------
115
+ bool
116
+ True, if logger is in DEBUG mode, False otherwise
117
+ """
118
+ return self._logger.level == DEBUG
119
+
120
+ @property
121
+ def level(self) -> int:
122
+ """Returns the current log level
123
+
124
+ Returns
125
+ -------
126
+ int
127
+ Log level.
128
+ """
129
+ return self._logger.level
130
+
131
+ def __repr__(self):
132
+ return f"<Logger: {self.name} ({self.level})>"
133
+
134
+
135
def get_logger(name: str) -> logging.Logger:
    """Look up the standard library logger registered under the given name

    Parameters
    ----------
    name : str
        Name of the module whose logger is requested

    Returns
    -------
    logging.Logger
        Logger object for the given name
    """
    module_logger = logging.getLogger(name)
    return module_logger
149
+
150
+
151
def logger_exists(name: str) -> bool:
    """Checks if a logger exists for the specified module

    Parameters
    ----------
    name : str
        Name of the module to check the logger for

    Returns
    -------
    bool
        True if logger exists, False otherwise
    """
    # a logger counts as existing as soon as handlers are configured for it;
    # NOTE: hasHandlers() also takes handlers of ancestor loggers into account
    return logging.getLogger(name).hasHandlers()
166
+
167
+
168
+ def set_level(name: str, level: int | str):
169
+ """Sets log level for the specified logger
170
+
171
+ Parameters
172
+ ----------
173
+ name : str
174
+ Name of the module
175
+ level : int | str
176
+ Level to set
177
+ """
178
+ logger = logging.getLogger(name)
179
+
180
+ # translate string to integer
181
+ if isinstance(level, str):
182
+ level = LEVELS[level.upper()]
183
+
184
+ # set the logger's level
185
+ logger.setLevel(level)
186
+
187
+ # and all handlers
188
+ for handler in logger.handlers:
189
+ handler.setLevel(level)
190
+
191
+ # optionally change the formatter (log more when in debug mode)
192
+ if level < INFO:
193
+ handler.setFormatter(DEBUG_FORMAT)
194
+ else:
195
+ handler.setFormatter(DEFAULT_FORMAT)
196
+
197
+
198
def set_verbosity(level: int):
    """Sets verbosity to the given level

    Parameters
    ----------
    level : int
        Log only errors (0), info (1) or even debug messages (2);
        any other value leaves the log levels untouched
    """
    for verbosity, level_name in ((0, "ERROR"), (1, "INFO"), (2, "DEBUG")):
        if level == verbosity:
            # both the extension module and the Python module are adjusted
            for module_name in ("chaine._core.crf", "chaine.crf"):
                set_level(module_name, level_name)
            break
@@ -0,0 +1,10 @@
1
+ from chaine.optimization import utils
2
+ from chaine.optimization.spaces import (
3
+ APSearchSpace,
4
+ AROWSearchSpace,
5
+ L2SGDSearchSpace,
6
+ LBFGSSearchSpace,
7
+ PASearchSpace,
8
+ SearchSpace,
9
+ )
10
+ from chaine.optimization.trial import OptimizationTrial
@@ -0,0 +1,129 @@
1
+ """
2
+ chaine.optimization.metrics
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+
5
+ This module implements metrics to evaluate the performance of a trained model.
6
+ """
7
+
8
+ from collections import Counter
9
+
10
+
11
def calculate_precision(true_positives: int, false_positives: int) -> float:
    """Compute the precision from true and false positive counts.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_positives : int
        Number of false positives.

    Returns
    -------
    float
        Precision score; 1.0 if there are no positive predictions at all.
    """
    predicted_positives = true_positives + false_positives
    if predicted_positives == 0:
        # without any positive prediction, precision is considered perfect
        return 1.0
    return true_positives / predicted_positives
31
+
32
+
33
def calculate_recall(true_positives: int, false_negatives: int) -> float:
    """Compute the recall from true positive and false negative counts.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_negatives : int
        Number of false negatives.

    Returns
    -------
    float
        Recall score; 0.0 if there are neither true positives nor false negatives.
    """
    actual_positives = true_positives + false_negatives
    if actual_positives == 0:
        # nothing to recall is treated as imperfect recall
        return 0.0
    return true_positives / actual_positives
53
+
54
+
55
def calculate_f1(true_positives: int, false_positives: int, false_negatives: int) -> float:
    """Compute the F1 score from the given counts.

    Parameters
    ----------
    true_positives : int
        Number of true positives.
    false_positives : int
        Number of false positives.
    false_negatives : int
        Number of false negatives.

    Returns
    -------
    float
        F1 score; 0.0 if both precision and recall are zero.
    """
    precision = calculate_precision(true_positives, false_positives)
    recall = calculate_recall(true_positives, false_negatives)

    denominator = precision + recall
    if denominator == 0:
        # zero precision and zero recall is zero f1
        return 0.0
    return (2 * precision * recall) / denominator
77
+
78
+
79
def evaluate_predictions(true: list[list[str]], pred: list[list[str]]) -> dict[str, float]:
    """Evaluate the given predictions with the true labels.

    Labels are compared per token after removing the ``B-`` and ``I-``
    prefixes; ``O`` is treated as the negative class.

    Parameters
    ----------
    true : list[list[str]]
        True labels.
    pred : list[list[str]]
        Predicted labels.

    Returns
    -------
    dict[str, float]
        Precision, recall and F1 scores.

    Raises
    ------
    ValueError
        If the inputs are not non-empty lists of lists, if the number of
        sequences differs, or if a pair of sequences differs in length.
    """
    # validate input; emptiness is checked before looking at the first
    # element so that empty input fails with a ValueError, not an IndexError
    if (
        not isinstance(true, list)
        or not isinstance(pred, list)
        or not true
        or not pred
        or not isinstance(true[0], list)
        or not isinstance(pred[0], list)
    ):
        raise ValueError("Input lists are invalid")

    # zip() would silently drop the surplus sequences of the longer input
    if len(true) != len(pred):
        raise ValueError(f"Different number of sequences: {len(true)} vs. {len(pred)}")

    counts = Counter()

    # get true positives, false positives and false negatives
    for true_labels, predicted_labels in zip(true, pred):
        # ignore prefixes
        true_labels = [label.removeprefix("B-").removeprefix("I-") for label in true_labels]
        predicted_labels = [
            label.removeprefix("B-").removeprefix("I-") for label in predicted_labels
        ]

        if len(true_labels) != len(predicted_labels):
            raise ValueError(f"Different lengths: '{true_labels}' vs. '{predicted_labels}'")

        for true_label, predicted_label in zip(true_labels, predicted_labels):
            if predicted_label == true_label:
                # matching "O" labels are true negatives, which no score needs
                if true_label != "O":
                    counts["tp"] += 1
            elif predicted_label == "O":
                counts["fn"] += 1
            else:
                # NOTE(review): a wrong non-"O" prediction for a non-"O" true
                # label is counted as false positive only (not additionally as
                # false negative) -- confirm that this is intended
                counts["fp"] += 1

    # calculate precision, recall and f1 score
    return {
        "precision": calculate_precision(counts["tp"], counts["fp"]),
        "recall": calculate_recall(counts["tp"], counts["fn"]),
        "f1": calculate_f1(counts["tp"], counts["fp"], counts["fn"]),
    }