chaine-3.13.1-cp313-cp313-win_amd64.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chaine might be problematic.

Files changed (68)
  1. chaine/__init__.py +2 -0
  2. chaine/_core/crf.cp313-win_amd64.pyd +0 -0
  3. chaine/_core/crf.cpp +19854 -0
  4. chaine/_core/crf.pyx +271 -0
  5. chaine/_core/crfsuite/COPYING +27 -0
  6. chaine/_core/crfsuite/README +183 -0
  7. chaine/_core/crfsuite/include/crfsuite.h +1077 -0
  8. chaine/_core/crfsuite/include/crfsuite.hpp +649 -0
  9. chaine/_core/crfsuite/include/crfsuite_api.hpp +406 -0
  10. chaine/_core/crfsuite/include/os.h +65 -0
  11. chaine/_core/crfsuite/lib/cqdb/COPYING +28 -0
  12. chaine/_core/crfsuite/lib/cqdb/include/cqdb.h +518 -0
  13. chaine/_core/crfsuite/lib/cqdb/src/cqdb.c +639 -0
  14. chaine/_core/crfsuite/lib/cqdb/src/lookup3.c +1271 -0
  15. chaine/_core/crfsuite/lib/cqdb/src/main.c +184 -0
  16. chaine/_core/crfsuite/lib/crf/src/crf1d.h +354 -0
  17. chaine/_core/crfsuite/lib/crf/src/crf1d_context.c +788 -0
  18. chaine/_core/crfsuite/lib/crf/src/crf1d_encode.c +1020 -0
  19. chaine/_core/crfsuite/lib/crf/src/crf1d_feature.c +382 -0
  20. chaine/_core/crfsuite/lib/crf/src/crf1d_model.c +1085 -0
  21. chaine/_core/crfsuite/lib/crf/src/crf1d_tag.c +582 -0
  22. chaine/_core/crfsuite/lib/crf/src/crfsuite.c +500 -0
  23. chaine/_core/crfsuite/lib/crf/src/crfsuite_internal.h +233 -0
  24. chaine/_core/crfsuite/lib/crf/src/crfsuite_train.c +302 -0
  25. chaine/_core/crfsuite/lib/crf/src/dataset.c +115 -0
  26. chaine/_core/crfsuite/lib/crf/src/dictionary.c +127 -0
  27. chaine/_core/crfsuite/lib/crf/src/holdout.c +83 -0
  28. chaine/_core/crfsuite/lib/crf/src/json.c +1497 -0
  29. chaine/_core/crfsuite/lib/crf/src/json.h +120 -0
  30. chaine/_core/crfsuite/lib/crf/src/logging.c +85 -0
  31. chaine/_core/crfsuite/lib/crf/src/logging.h +49 -0
  32. chaine/_core/crfsuite/lib/crf/src/params.c +370 -0
  33. chaine/_core/crfsuite/lib/crf/src/params.h +84 -0
  34. chaine/_core/crfsuite/lib/crf/src/quark.c +180 -0
  35. chaine/_core/crfsuite/lib/crf/src/quark.h +46 -0
  36. chaine/_core/crfsuite/lib/crf/src/rumavl.c +1178 -0
  37. chaine/_core/crfsuite/lib/crf/src/rumavl.h +144 -0
  38. chaine/_core/crfsuite/lib/crf/src/train_arow.c +409 -0
  39. chaine/_core/crfsuite/lib/crf/src/train_averaged_perceptron.c +237 -0
  40. chaine/_core/crfsuite/lib/crf/src/train_l2sgd.c +491 -0
  41. chaine/_core/crfsuite/lib/crf/src/train_lbfgs.c +323 -0
  42. chaine/_core/crfsuite/lib/crf/src/train_passive_aggressive.c +442 -0
  43. chaine/_core/crfsuite/lib/crf/src/vecmath.h +360 -0
  44. chaine/_core/crfsuite/swig/crfsuite.cpp +1 -0
  45. chaine/_core/crfsuite_api.pxd +67 -0
  46. chaine/_core/liblbfgs/COPYING +22 -0
  47. chaine/_core/liblbfgs/README +71 -0
  48. chaine/_core/liblbfgs/include/lbfgs.h +745 -0
  49. chaine/_core/liblbfgs/lib/arithmetic_ansi.h +142 -0
  50. chaine/_core/liblbfgs/lib/arithmetic_sse_double.h +303 -0
  51. chaine/_core/liblbfgs/lib/arithmetic_sse_float.h +312 -0
  52. chaine/_core/liblbfgs/lib/lbfgs.c +1531 -0
  53. chaine/_core/tagger_wrapper.hpp +58 -0
  54. chaine/_core/trainer_wrapper.cpp +32 -0
  55. chaine/_core/trainer_wrapper.hpp +26 -0
  56. chaine/crf.py +505 -0
  57. chaine/logging.py +214 -0
  58. chaine/optimization/__init__.py +10 -0
  59. chaine/optimization/metrics.py +129 -0
  60. chaine/optimization/spaces.py +394 -0
  61. chaine/optimization/trial.py +103 -0
  62. chaine/optimization/utils.py +119 -0
  63. chaine/training.py +184 -0
  64. chaine/typing.py +18 -0
  65. chaine/validation.py +43 -0
  66. chaine-3.13.1.dist-info/METADATA +348 -0
  67. chaine-3.13.1.dist-info/RECORD +68 -0
  68. chaine-3.13.1.dist-info/WHEEL +4 -0
chaine/_core/tagger_wrapper.hpp ADDED
@@ -0,0 +1,58 @@
+ #ifndef TAGGER_WRAPPER_H
+ #define TAGGER_WRAPPER_H 1
+
+ #include <stdio.h>
+ #include <errno.h>
+ #include <stdexcept>
+ #include "crfsuite_api.hpp"
+
+ namespace CRFSuiteWrapper
+ {
+     class Tagger : public CRFSuite::Tagger
+     {
+     public:
+         void dump_states(int fileno)
+         {
+             if (model == NULL)
+             {
+                 throw std::runtime_error("Tagger is closed");
+             }
+
+             FILE *file = fdopen(fileno, "w");
+             if (!file)
+             {
+                 throw std::runtime_error("Cannot open file");
+             }
+
+             model->dump_states(model, file);
+
+             if (fclose(file))
+             {
+                 throw std::runtime_error("Cannot close file");
+             };
+         }
+
+     public:
+         void dump_transitions(int fileno)
+         {
+             if (model == NULL)
+             {
+                 throw std::runtime_error("Tagger is closed");
+             }
+
+             FILE *file = fdopen(fileno, "w");
+             if (!file)
+             {
+                 throw std::runtime_error("Cannot open file");
+             }
+
+             model->dump_transitions(model, file);
+
+             if (fclose(file))
+             {
+                 throw std::runtime_error("Cannot close file");
+             };
+         }
+     };
+ } // namespace CRFSuiteWrapper
+ #endif
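Both dump methods take a raw OS file descriptor, fdopen() it, and fclose() the resulting stream themselves, which also closes the descriptor. A minimal sketch of how such an interface can be driven from Python; the dump_states call mirrors the wrapper above, while the helper function itself is illustrative:

    import os

    def dump_states_to(tagger, path: str) -> None:
        # hand the wrapper an OS-level descriptor; the C++ side fdopen()s it,
        # writes the state features, and fclose()s the stream
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
        tagger.dump_states(fd)
        # no os.close(fd) here: fclose() in the wrapper already released it

Note the ownership transfer: once the descriptor is handed to the wrapper, the caller must not close it again.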
chaine/_core/trainer_wrapper.cpp ADDED
@@ -0,0 +1,32 @@
+ #include <iostream>
+ #include "Python.h"
+ #include "trainer_wrapper.hpp"
+ #include <stdexcept>
+
+ namespace CRFSuiteWrapper
+ {
+     void Trainer::set_handler(PyObject *obj, messagefunc handler)
+     {
+         this->m_obj = obj;
+         this->handler = handler;
+     }
+
+     void Trainer::message(const std::string &msg)
+     {
+         if (this->m_obj == NULL)
+         {
+             std::cerr << "** Trainer invalid state: obj [" << this->m_obj << "]\n";
+             return;
+         }
+         PyObject *result = handler(this->m_obj, msg);
+         if (result == NULL)
+         {
+             throw std::runtime_error("AAAaaahhhhHHhh!!!!!");
+         }
+     }
+
+     void Trainer::_init_trainer()
+     {
+         Trainer::init();
+     }
+ } // namespace CRFSuiteWrapper
chaine/_core/trainer_wrapper.hpp ADDED
@@ -0,0 +1,26 @@
+ #ifndef TRAINER_WRAPPER_H
+ #define TRAINER_WRAPPER_H 1
+
+ #include <string>
+ #include "crfsuite_api.hpp"
+
+ struct _object;
+ typedef _object PyObject;
+
+ namespace CRFSuiteWrapper
+ {
+     typedef PyObject *(*messagefunc)(PyObject *self, std::string message);
+
+     class Trainer : public CRFSuite::Trainer
+     {
+     protected:
+         PyObject *m_obj;
+         messagefunc handler;
+
+     public:
+         void set_handler(PyObject *obj, messagefunc handler);
+         virtual void message(const std::string &msg);
+         void _init_trainer();
+     };
+ } // namespace CRFSuiteWrapper
+ #endif
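The header completes the callback plumbing implemented in trainer_wrapper.cpp: the Cython layer registers a C-level messagefunc via set_handler(), and every log line CRFsuite emits during training is routed through Trainer::message() to a Python object. In pure-Python terms the contract amounts to the sketch below; the class and method names here are illustrative, not chaine's actual Cython code:

    class MessageHandler:
        # stands in for the PyObject passed to set_handler(); the registered
        # messagefunc ends up calling a method like this one
        def message(self, text: str) -> None:
            # invoked once per CRFsuite log line while training runs,
            # e.g. to forward progress output to a logger
            print(text, end="")

If the Python callback raises (the handler returns NULL on the C side), Trainer::message() throws, aborting training rather than swallowing the error.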
chaine/crf.py ADDED
@@ -0,0 +1,505 @@
+ """
+ chaine.crf
+ ~~~~~~~~~~
+
+ This module implements the trainer, optimizer and model.
+ """
+
+ import json
+ import random
+ import tempfile
+ import uuid
+ from functools import cached_property
+ from operator import itemgetter
+ from pathlib import Path
+
+ from chaine._core.crf import Model as _Model
+ from chaine._core.crf import Trainer as _Trainer
+ from chaine.logging import Logger, set_verbosity
+ from chaine.optimization.spaces import (
+     APSearchSpace,
+     AROWSearchSpace,
+     L2SGDSearchSpace,
+     LBFGSSearchSpace,
+     PASearchSpace,
+     SearchSpace,
+ )
+ from chaine.optimization.trial import OptimizationTrial
+ from chaine.optimization.utils import cross_validation, downsample
+ from chaine.typing import Filepath, Iterable, Labels, Sequence
+ from chaine.validation import is_valid_sequence
+
+ LOGGER = Logger(__name__)
+
+
+ class Trainer:
+     """Trainer for conditional random fields.
+
+     Parameters
+     ----------
+     algorithm : str
+         The following optimization algorithms are available:
+             * lbfgs: Limited-Memory BFGS with L1/L2 regularization
+             * l2sgd: Stochastic Gradient Descent with L2 regularization
+             * ap: Averaged Perceptron
+             * pa: Passive Aggressive
+             * arow: Adaptive Regularization of Weights
+
+     Limited-Memory BFGS Parameters (lbfgs)
+     --------------------------------------
+     min_freq : float, optional (default=0)
+         Threshold value for minimum frequency of a feature occurring in training data.
+     all_possible_states : bool, optional (default=False)
+         Generate state features that do not even occur in the training data.
+     all_possible_transitions : bool, optional (default=False)
+         Generate transition features that do not even occur in the training data.
+     max_iterations : int, optional (default=None)
+         Maximum number of iterations (unlimited by default).
+     num_memories : int, optional (default=6)
+         Number of limited memories for approximating the inverse Hessian matrix.
+     c1 : float, optional (default=0)
+         Coefficient for L1 regularization.
+     c2 : float, optional (default=1.0)
+         Coefficient for L2 regularization.
+     epsilon : float, optional (default=1e-5)
+         Parameter that determines the condition of convergence.
+     period : int, optional (default=10)
+         Threshold value for iterations to test the stopping criterion.
+     delta : float, optional (default=1e-5)
+         Stop iterating when the improvement of the log likelihood is not greater than this.
+     linesearch : str, optional (default="MoreThuente")
+         Line search algorithm used in updates:
+             * MoreThuente: More and Thuente's method
+             * Backtracking: Backtracking method with regular Wolfe condition
+             * StrongBacktracking: Backtracking method with strong Wolfe condition
+     max_linesearch : int, optional (default=20)
+         Maximum number of trials for the line search algorithm.
+
+     SGD with L2 Parameters (l2sgd)
+     ------------------------------
+     min_freq : float, optional (default=0)
+         Threshold value for minimum frequency of a feature occurring in training data.
+     all_possible_states : bool, optional (default=False)
+         Generate state features that do not even occur in the training data.
+     all_possible_transitions : bool, optional (default=False)
+         Generate transition features that do not even occur in the training data.
+     max_iterations : int, optional (default=None)
+         Maximum number of iterations (1000 by default).
+     c2 : float, optional (default=1.0)
+         Coefficient for L2 regularization.
+     period : int, optional (default=10)
+         Threshold value for iterations to test the stopping criterion.
+     delta : float, optional (default=1e-5)
+         Stop iterating when the improvement of the log likelihood is not greater than this.
+     calibration_eta : float, optional (default=0.1)
+         Initial value of learning rate (eta) used for calibration.
+     calibration_rate : float, optional (default=2.0)
+         Rate of increase/decrease of learning rate for calibration.
+     calibration_samples : int, optional (default=1000)
+         Number of instances used for calibration.
+     calibration_candidates : int, optional (default=10)
+         Number of candidates of learning rate.
+     calibration_max_trials : int, optional (default=20)
+         Maximum number of trials of learning rates for calibration.
+
+     Averaged Perceptron Parameters (ap)
+     -----------------------------------
+     min_freq : float, optional (default=0)
+         Threshold value for minimum frequency of a feature occurring in training data.
+     all_possible_states : bool, optional (default=False)
+         Generate state features that do not even occur in the training data.
+     all_possible_transitions : bool, optional (default=False)
+         Generate transition features that do not even occur in the training data.
+     max_iterations : int, optional (default=None)
+         Maximum number of iterations (100 by default).
+     epsilon : float, optional (default=1e-5)
+         Parameter that determines the condition of convergence.
+
+     Passive Aggressive Parameters (pa)
+     ----------------------------------
+     min_freq : float, optional (default=0)
+         Threshold value for minimum frequency of a feature occurring in training data.
+     all_possible_states : bool, optional (default=False)
+         Generate state features that do not even occur in the training data.
+     all_possible_transitions : bool, optional (default=False)
+         Generate transition features that do not even occur in the training data.
+     max_iterations : int, optional (default=None)
+         Maximum number of iterations (100 by default).
+     epsilon : float, optional (default=1e-5)
+         Parameter that determines the condition of convergence.
+     pa_type : int, optional (default=1)
+         Strategy for updating feature weights:
+             * 0: PA without slack variables
+             * 1: PA type I
+             * 2: PA type II
+     c : float, optional (default=1)
+         Aggressiveness parameter (used only for PA-I and PA-II).
+     error_sensitive : bool, optional (default=True)
+         Include the square root of the number of predicted incorrect labels in the optimization routine.
+     averaging : bool, optional (default=True)
+         Compute the average of feature weights at all updates.
+
+     Adaptive Regularization of Weights Parameters (arow)
+     ----------------------------------------------------
+     min_freq : float, optional (default=0)
+         Threshold value for minimum frequency of a feature occurring in training data.
+     all_possible_states : bool, optional (default=False)
+         Generate state features that do not even occur in the training data.
+     all_possible_transitions : bool, optional (default=False)
+         Generate transition features that do not even occur in the training data.
+     max_iterations : int, optional (default=None)
+         Maximum number of iterations (100 by default).
+     epsilon : float, optional (default=1e-5)
+         Parameter that determines the condition of convergence.
+     variance : float, optional (default=1)
+         Initial variance of every feature weight.
+     gamma : float, optional (default=1)
+         Trade-off between loss function and changes of feature weights.
+     """
+
+     def __init__(self, algorithm: str = "l2sgd", **kwargs):
+         self.algorithm = algorithm
+         self._trainer = _Trainer(algorithm, **kwargs)
+
+     def __repr__(self):
+         return f"<Trainer ({self.algorithm}): {self.params}>"
+
+     def train(
+         self,
+         dataset: Iterable[Sequence],
+         labels: Iterable[Labels],
+         *,
+         model_filepath: Filepath,
+     ):
+         """Start training on the given data set.
+
+         Parameters
+         ----------
+         dataset : Iterable[Sequence]
+             Data set consisting of sequences of feature sets.
+         labels : Iterable[Labels]
+             Labels corresponding to each instance in the data set.
+         model_filepath : Filepath
+             Path to the model location.
+         """
+         LOGGER.info("Loading data set")
+         for i, (sequence, labels_) in enumerate(zip(dataset, labels)):
+             if not is_valid_sequence(sequence):
+                 raise ValueError(f"Invalid format: {sequence}")
+
+             # log progress every 100 data points
+             if i > 0 and i % 100 == 0:
+                 LOGGER.debug(f"{i} processed data points")
+
+             try:
+                 self._trainer.append(sequence, labels_)
+             except Exception as message:
+                 LOGGER.error(message)
+                 LOGGER.debug(f"Sequence: {json.dumps(sequence)}")
+                 LOGGER.debug(f"Labels: {json.dumps(labels_)}")
+
+         # fire!
+         LOGGER.info("Start training")
+         self._trainer.train(model_filepath)
+
+     @cached_property
+     def params(self) -> dict[str, str | int | float | bool]:
+         """Parameters of the trainer.
+
+         Returns
+         -------
+         dict[str, str | int | float | bool]
+             Parameters of the trainer.
+         """
+         return {
+             self._trainer.param2kwarg.get(name, name): self._trainer.get_param(name)
+             for name in self._trainer.params
+         }
+
+
+ class HyperparameterOptimizer:
+     def __init__(
+         self,
+         trials: int = 10,
+         seed: int | None = None,
+         metric: str = "f1",
+         folds: int = 5,
+         spaces: list[SearchSpace] = [
+             AROWSearchSpace(),
+             APSearchSpace(),
+             LBFGSSearchSpace(),
+             L2SGDSearchSpace(),
+             PASearchSpace(),
+         ],
+     ):
+         """Optimize hyperparameters in a randomized manner.
+
+         Parameters
+         ----------
+         trials : int, optional
+             Number of trials for an algorithm, by default 10.
+         seed : int | None, optional
+             Random seed, by default None.
+         metric : str, optional
+             Metric to sort the results by, by default "f1".
+         folds : int, optional
+             Number of folds to split the data set into, by default 5.
+         spaces : list[SearchSpace], optional
+             Search spaces to select hyperparameters from, by default [AROWSearchSpace(),
+             APSearchSpace(), LBFGSSearchSpace(), L2SGDSearchSpace(), PASearchSpace()].
+         """
+         self.trials = trials
+         self.seed = seed
+         self.metric = metric
+         self.folds = folds
+         self.spaces = spaces
+         self.results = []
+         self.baselines = []
+         self.logger = Logger("hyperparameter-optimization")
+
+     def optimize_hyperparameters(
+         self,
+         dataset: Iterable[Sequence],
+         labels: Iterable[Labels],
+         sample_size: int | None = None,
+     ) -> list[dict[str, dict]]:
+         """Optimize hyperparameters on the given data set.
+
+         Parameters
+         ----------
+         dataset : Iterable[Sequence]
+             Data set to train models on.
+         labels : Iterable[Labels]
+             Labels to train models on.
+         sample_size : int | None
+             Number of instances to sample from the data set.
+
+         Returns
+         -------
+         list[dict[str, dict]]
+             Sorted list of hyperparameters and evaluation scores.
+         """
+         # disable logging
+         set_verbosity(0)
+
+         # set random seed
+         random.seed(self.seed)
+
+         # optional downsampling
+         if sample_size:
+             dataset, labels = downsample(dataset, labels, sample_size, self.seed)
+
+         # split data set for cross validation
+         splits = list(cross_validation(dataset, labels, k=self.folds))
+
+         for i, space in enumerate(self.spaces):
+             self.logger.info(f"Starting with {space.algorithm} ({i + 1}/{len(self.spaces)})")
+             self.logger.info(f"Baseline for {space.algorithm}")
+
+             with OptimizationTrial(splits, space, is_baseline=True) as trial:
+                 self.results.append(trial)
+                 self.baselines.append(trial["stats"])
+
+             for j in range(self.trials):
+                 self.logger.info(f"Trial {j + 1}/{self.trials} for {space.algorithm}")
+
+                 with OptimizationTrial(splits, space, is_baseline=False) as trial:
+                     self.results.append(trial)
+
+         self.logger.info(f"Best baseline model: {self._best_baseline_score}")
+         self.logger.info(f"Best optimized model: {self._best_optimized_score}")
+
+         self.logger.info("Finished hyperparameter optimization")
+         self.logger.info(f"Trained {len(self.results)} models with different hyperparameters")
+
+         # re-enable logging
+         set_verbosity(1)
+
+         # return sorted results
+         return sorted(self.results, key=self._metric, reverse=True)
+
+     @property
+     def _best_baseline_score(self) -> str | float:
+         """Best evaluation score with default hyperparameters.
+
+         Returns
+         -------
+         str | float
+             Score (or 'n/a' if no results are available).
+         """
+         if self.baselines:
+             best = sorted(self.baselines, key=itemgetter(f"mean_{self.metric}"), reverse=True)[0]
+             return best[f"mean_{self.metric}"]
+
+         return "n/a"
+
+     @property
+     def _best_optimized_score(self) -> str | float:
+         """Best evaluation score with optimized hyperparameters.
+
+         Returns
+         -------
+         str | float
+             Score (or 'n/a' if no results are available).
+         """
+         if self.results:
+             best = sorted(self.results, key=self._metric, reverse=True)[0]
+             return best["stats"][f"mean_{self.metric}"]
+
+         return "n/a"
+
+     def _metric(self, trial: dict[str, dict]) -> float:
+         """Metric to select for sorting.
+
+         Parameters
+         ----------
+         trial : dict[str, dict]
+             Optimization trial result.
+
+         Returns
+         -------
+         float
+             Metric score.
+         """
+         return trial["stats"][f"mean_{self.metric}"]
+
+
+ class Model:
+     """Linear-chain conditional random field.
+
+     Parameters
+     ----------
+     filepath : Filepath
+         Path to the trained model.
+     """
+
+     def __init__(self, filepath: Filepath):
+         self._model = _Model(filepath)
+
+     def __repr__(self):
+         return f"<Model: {self.labels}>"
+
+     @cached_property
+     def labels(self) -> set[str]:
+         """Labels the model is trained on."""
+         return set(self._model.labels)
+
+     @cached_property
+     def transitions(self) -> dict[str, float]:
+         """Learned transition weights."""
+         # get temporary file to dump the transitions
+         filepath = Path(tempfile.gettempdir(), str(uuid.uuid4()))
+
+         # dump the transitions to disk
+         self.dump_transitions(filepath)
+
+         # load the dumped transitions
+         transitions = json.loads(filepath.read_text())
+
+         # cleanup
+         filepath.unlink()
+
+         return transitions
+
+     @cached_property
+     def states(self) -> dict[str, float]:
+         """Learned state feature weights."""
+         # get temporary file to dump the states
+         filepath = Path(tempfile.gettempdir(), str(uuid.uuid4()))
+
+         # dump the states to disk
+         self.dump_states(filepath)
+
+         # load the dumped states
+         states = json.loads(filepath.read_text())
+
+         # cleanup
+         filepath.unlink()
+
+         return states
+
+     def predict_single(self, sequence: Sequence) -> list[str]:
+         """Predict most likely labels for a given sequence of tokens.
+
+         Parameters
+         ----------
+         sequence : Sequence
+             Sequence of tokens represented as feature dictionaries.
+
+         Returns
+         -------
+         list[str]
+             Most likely label sequence.
+         """
+         if not is_valid_sequence(sequence):
+             raise ValueError(f"Invalid format: {sequence}")
+
+         return self._model.predict_single(sequence)
+
+     def predict(self, sequences: Iterable[Sequence]) -> list[list[str]]:
+         """Predict most likely labels for a batch of sequences.
+
+         Parameters
+         ----------
+         sequences : Iterable[Sequence]
+             Batch of sequences of tokens represented as feature dictionaries.
+
+         Returns
+         -------
+         list[list[str]]
+             Most likely label sequences.
+         """
+         return [self.predict_single(sequence) for sequence in sequences]
+
+     def predict_proba_single(self, sequence: Sequence) -> list[dict[str, float]]:
+         """Predict probabilities over all labels for each token in a sequence.
+
+         Parameters
+         ----------
+         sequence : Sequence
+             Sequence of tokens represented as feature dictionaries.
+
+         Returns
+         -------
+         list[dict[str, float]]
+             Probability distributions over all labels for each token.
+         """
+         if not is_valid_sequence(sequence):
+             raise ValueError(f"Invalid format: {sequence}")
+
+         return self._model.predict_proba_single(sequence)
+
+     def predict_proba(self, sequences: Iterable[Sequence]) -> list[list[dict[str, float]]]:
+         """Predict probabilities over all labels for each token in a batch of sequences.
+
+         Parameters
+         ----------
+         sequences : Iterable[Sequence]
+             Batch of sequences of tokens represented as feature dictionaries.
+
+         Returns
+         -------
+         list[list[dict[str, float]]]
+             Probability distributions over all labels for each token in the sequences.
+         """
+         return [self.predict_proba_single(sequence) for sequence in sequences]
+
+     def dump_transitions(self, filepath: Filepath):
+         """Dump learned transitions with weights as JSON.
+
+         Parameters
+         ----------
+         filepath : Filepath
+             File to dump transitions to.
+         """
+         self._model.dump_transitions(filepath)
+
+     def dump_states(self, filepath: Filepath):
+         """Dump learned states with weights as JSON.
+
+         Parameters
+         ----------
+         filepath : Filepath
+             File to dump states to.
+         """
+         self._model.dump_states(filepath)
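Taken together, crf.py defines a train-then-load workflow: Trainer serializes a fitted model to disk, and Model reads it back for decoding. A minimal usage sketch against the API shown above; the toy feature dictionaries, labels, and the model.chaine path are invented for illustration:

    from chaine.crf import HyperparameterOptimizer, Model, Trainer

    # toy data set: one sequence of two tokens, each token a feature dictionary
    dataset = [
        [{"text": "John", "is_capitalized": True}, {"text": "walks", "is_capitalized": False}]
    ]
    labels = [["B-PER", "O"]]

    # train with the default algorithm (l2sgd) and serialize the model to disk
    trainer = Trainer("l2sgd", max_iterations=100)
    trainer.train(dataset, labels, model_filepath="model.chaine")

    # load the serialized model and decode
    model = Model("model.chaine")
    print(model.predict(dataset))        # most likely label sequence per input
    print(model.predict_proba(dataset))  # per-token distributions over labels

    # alternatively, a randomized search over the five algorithm-specific spaces
    optimizer = HyperparameterOptimizer(trials=5, folds=3)
    results = optimizer.optimize_hyperparameters(dataset, labels)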