wapiti 0.0.5 → 0.1.0

Files changed (53)
  1. checksums.yaml +7 -0
  2. data/.simplecov +3 -0
  3. data/Gemfile +25 -2
  4. data/HISTORY.md +5 -1
  5. data/LICENSE +14 -13
  6. data/README.md +9 -16
  7. data/Rakefile +38 -8
  8. data/ext/wapiti/bcd.c +126 -124
  9. data/ext/wapiti/decoder.c +203 -124
  10. data/ext/wapiti/decoder.h +6 -4
  11. data/ext/wapiti/extconf.rb +2 -2
  12. data/ext/wapiti/gradient.c +491 -320
  13. data/ext/wapiti/gradient.h +52 -34
  14. data/ext/wapiti/lbfgs.c +74 -33
  15. data/ext/wapiti/model.c +47 -37
  16. data/ext/wapiti/model.h +22 -20
  17. data/ext/wapiti/native.c +850 -839
  18. data/ext/wapiti/native.h +1 -1
  19. data/ext/wapiti/options.c +52 -20
  20. data/ext/wapiti/options.h +37 -30
  21. data/ext/wapiti/pattern.c +35 -33
  22. data/ext/wapiti/pattern.h +12 -11
  23. data/ext/wapiti/progress.c +14 -13
  24. data/ext/wapiti/progress.h +3 -2
  25. data/ext/wapiti/quark.c +14 -16
  26. data/ext/wapiti/quark.h +6 -5
  27. data/ext/wapiti/reader.c +83 -69
  28. data/ext/wapiti/reader.h +11 -9
  29. data/ext/wapiti/rprop.c +84 -43
  30. data/ext/wapiti/sequence.h +18 -16
  31. data/ext/wapiti/sgdl1.c +45 -43
  32. data/ext/wapiti/thread.c +19 -17
  33. data/ext/wapiti/thread.h +5 -4
  34. data/ext/wapiti/tools.c +7 -7
  35. data/ext/wapiti/tools.h +3 -4
  36. data/ext/wapiti/trainers.h +1 -1
  37. data/ext/wapiti/vmath.c +40 -38
  38. data/ext/wapiti/vmath.h +12 -11
  39. data/ext/wapiti/wapiti.c +159 -37
  40. data/ext/wapiti/wapiti.h +18 -4
  41. data/lib/wapiti.rb +15 -15
  42. data/lib/wapiti/errors.rb +15 -15
  43. data/lib/wapiti/model.rb +92 -84
  44. data/lib/wapiti/options.rb +123 -124
  45. data/lib/wapiti/utility.rb +14 -14
  46. data/lib/wapiti/version.rb +2 -2
  47. data/spec/spec_helper.rb +29 -9
  48. data/spec/wapiti/model_spec.rb +230 -194
  49. data/spec/wapiti/native_spec.rb +7 -8
  50. data/spec/wapiti/options_spec.rb +184 -174
  51. data/wapiti.gemspec +22 -8
  52. metadata +38 -42
  53. data/.gitignore +0 -5

data/ext/wapiti/wapiti.c

@@ -1,7 +1,7 @@
 /*
  * Wapiti - A linear-chain CRF tool
  *
- * Copyright (c) 2009-2011 CNRS
+ * Copyright (c) 2009-2013 CNRS
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -24,19 +24,22 @@
  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  * POSSIBILITY OF SUCH DAMAGE.
  */
+#include <ctype.h>
+#include <inttypes.h>
 #include <stdbool.h>
 #include <stddef.h>
+#include <stdint.h>
 #include <stdlib.h>
 #include <stdio.h>
 #include <string.h>
 
 #include "decoder.h"
-#include "model.h"
 #include "options.h"
 #include "progress.h"
 #include "quark.h"
 #include "reader.h"
 #include "sequence.h"
+#include "model.h"
 #include "tools.h"
 #include "trainers.h"
 #include "wapiti.h"
@@ -44,16 +47,15 @@
 /*******************************************************************************
  * Training
  ******************************************************************************/
-static void trn_auto(mdl_t *mdl) {
-    const int maxiter = mdl->opt->maxiter;
-    mdl->opt->maxiter = 3;
-    trn_sgdl1(mdl);
-    mdl->opt->maxiter = maxiter;
-    trn_lbfgs(mdl);
-}
+static const char *typ_lst[] = {
+    "maxent",
+    "memm",
+    "crf"
+};
+static const uint32_t typ_cnt = sizeof(typ_lst) / sizeof(typ_lst[0]);
 
 static const struct {
-    char *name;
+    const char *name;
     void (* train)(mdl_t *mdl);
 } trn_lst[] = {
     {"l-bfgs", trn_lbfgs},
@@ -62,20 +64,31 @@ static const struct {
     {"rprop", trn_rprop},
     {"rprop+", trn_rprop},
     {"rprop-", trn_rprop},
-    {"auto", trn_auto }
 };
-static const int trn_cnt = sizeof(trn_lst) / sizeof(trn_lst[0]);
+static const uint32_t trn_cnt = sizeof(trn_lst) / sizeof(trn_lst[0]);
 
 void dotrain(mdl_t *mdl) {
-    // Check if the user requested the trainer list. If this is not the
-    // case, search the trainer.
+    // Check if the user requested the type or trainer list. If this is not
+    // the case, search them in the lists.
+    if (!strcmp(mdl->opt->type, "list")) {
+        info("Available types of models:\n");
+        for (uint32_t i = 0; i < typ_cnt; i++)
+            info("\t%s\n", typ_lst[i]);
+        exit(EXIT_SUCCESS);
+    }
     if (!strcmp(mdl->opt->algo, "list")) {
         info("Available training algorithms:\n");
-        for (int i = 0; i < trn_cnt; i++)
+        for (uint32_t i = 0; i < trn_cnt; i++)
             info("\t%s\n", trn_lst[i].name);
         exit(EXIT_SUCCESS);
     }
-    int trn;
+    uint32_t typ, trn;
+    for (typ = 0; typ < typ_cnt; typ++)
+        if (!strcmp(mdl->opt->type, typ_lst[typ]))
+            break;
+    if (typ == typ_cnt)
+        fatal("unknown model type '%s'", mdl->opt->type);
+    mdl->type = typ;
     for (trn = 0; trn < trn_cnt; trn++)
         if (!strcmp(mdl->opt->algo, trn_lst[trn].name))
            break;
@@ -136,12 +149,12 @@ void dotrain(mdl_t *mdl) {
     mdl_sync(mdl);
     // Display some statistics as we all love this.
     info("* Summary\n");
-    info(" nb train: %d\n", mdl->train->nseq);
+    info(" nb train: %"PRIu32"\n", mdl->train->nseq);
     if (mdl->devel != NULL)
-        info(" nb devel: %d\n", mdl->devel->nseq);
-    info(" nb labels: %zu\n", mdl->nlbl);
-    info(" nb blocks: %zu\n", mdl->nobs);
-    info(" nb features: %zu\n", mdl->nftr);
+        info(" nb devel: %"PRIu32"\n", mdl->devel->nseq);
+    info(" nb labels: %"PRIu32"\n", mdl->nlbl);
+    info(" nb blocks: %"PRIu64"\n", mdl->nobs);
+    info(" nb features: %"PRIu64"\n", mdl->nftr);
     // And train the model...
     info("* Train the model with %s\n", mdl->opt->algo);
     uit_setup(mdl);
@@ -149,12 +162,12 @@ void dotrain(mdl_t *mdl) {
     uit_cleanup(mdl);
     // If requested compact the model.
     if (mdl->opt->compact) {
-        const size_t O = mdl->nobs;
-        const size_t F = mdl->nftr;
+        const uint64_t O = mdl->nobs;
+        const uint64_t F = mdl->nftr;
         info("* Compacting the model\n");
         mdl_compact(mdl);
-        info(" %8zu observations removed\n", O - mdl->nobs);
-        info(" %8zu features removed\n", F - mdl->nftr);
+        info(" %8"PRIu64" observations removed\n", O - mdl->nobs);
+        info(" %8"PRIu64" features removed\n", F - mdl->nftr);
     }
     // And save the trained model
     info("* Save the model\n");
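
The counters printed by the summary and compaction messages now come from fixed-width model fields (uint32_t/uint64_t) instead of int/size_t, so the info() calls switch from %d/%zu to the <inttypes.h> format macros added to the includes above. A minimal standalone sketch of that formatting pattern (the variable names and values below are illustrative, not Wapiti's):

    #include <inttypes.h>
    #include <stdio.h>

    int main(void) {
        uint32_t nlbl = 42;               /* e.g. number of labels   */
        uint64_t nftr = 123456789012ULL;  /* e.g. number of features */
        /* PRIu32/PRIu64 expand to the right conversion specifier per platform */
        printf("nb labels:   %" PRIu32 "\n", nlbl);
        printf("nb features: %" PRIu64 "\n", nftr);
        return 0;
    }
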
@@ -209,7 +222,7 @@ void dolabel(mdl_t *mdl) {
 /*******************************************************************************
  * Dumping
  ******************************************************************************/
-void dodump(mdl_t *mdl) {
+static void dodump(mdl_t *mdl) {
     // Load input model file
     info("* Load model\n");
     FILE *fin = stdin;
@@ -230,32 +243,35 @@ void dodump(mdl_t *mdl) {
     }
     // Dump model
     info("* Dump model\n");
-    const size_t Y = mdl->nlbl;
-    const size_t O = mdl->nobs;
+    const uint32_t Y = mdl->nlbl;
+    const uint64_t O = mdl->nobs;
     const qrk_t *Qlbl = mdl->reader->lbl;
     const qrk_t *Qobs = mdl->reader->obs;
-    for (size_t o = 0; o < O; o++) {
+    char fmt[16];
+    sprintf(fmt, "%%.%df\n", mdl->opt->prec);
+    for (uint64_t o = 0; o < O; o++) {
         const char *obs = qrk_id2str(Qobs, o);
         bool empty = true;
         if (mdl->kind[o] & 1) {
             const double *w = mdl->theta + mdl->uoff[o];
-            for (size_t y = 0; y < Y; y++) {
-                if (w[y] == 0.0)
+            for (uint32_t y = 0; y < Y; y++) {
+                if (!mdl->opt->all && w[y] == 0.0)
                     continue;
                 const char *ly = qrk_id2str(Qlbl, y);
-                fprintf(fout, "%s\t#\t%s\t%f\n", obs, ly, w[y]);
+                fprintf(fout, "%s\t#\t%s\t", obs, ly);
+                fprintf(fout, fmt, w[y]);
                 empty = false;
             }
         }
         if (mdl->kind[o] & 2) {
             const double *w = mdl->theta + mdl->boff[o];
-            for (size_t d = 0; d < Y * Y; d++) {
-                if (w[d] == 0.0)
+            for (uint32_t d = 0; d < Y * Y; d++) {
+                if (!mdl->opt->all && w[d] == 0.0)
                     continue;
                 const char *ly = qrk_id2str(Qlbl, d % Y);
                 const char *lyp = qrk_id2str(Qlbl, d / Y);
-                fprintf(fout, "%s\t%s\t%s\t%f\n", obs, lyp, ly,
-                        w[d]);
+                fprintf(fout, "%s\t%s\t%s\t", obs, lyp, ly);
+                fprintf(fout, fmt, w[d]);
                 empty = false;
             }
         }
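
dodump now honours two options visible in this hunk: prec, which sets how many decimals are written for each weight, and all, which keeps zero weights in the output. The precision is applied by building the printf format string at run time. A small self-contained sketch of that two-step sprintf/fprintf trick (the prec and weight values here are made up):

    #include <stdio.h>

    int main(void) {
        int prec = 5;          /* hypothetical value of the prec option */
        double weight = 0.25;  /* hypothetical feature weight           */
        char fmt[16];
        /* build "%.5f\n" at run time, then use it as a format string */
        sprintf(fmt, "%%.%df\n", prec);
        printf("obs\t#\tLABEL\t");
        printf(fmt, weight);
        return 0;
    }
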
@@ -266,6 +282,110 @@ void dodump(mdl_t *mdl) {
     fclose(fout);
 }
 
+
+/*******************************************************************************
+ * Updating
+ ******************************************************************************/
+void doupdt(mdl_t *mdl) {
+    // Load input model file
+    info("* Load model\n");
+    if (mdl->opt->model == NULL)
+        fatal("no model file provided");
+    FILE *Min = fopen(mdl->opt->model, "r");
+    if (Min == NULL)
+        pfatal("cannot open model file %s", mdl->opt->model);
+    mdl_load(mdl, Min);
+    fclose(Min);
+    // Open patch file
+    info("* Update model\n");
+    FILE *fin = stdin;
+    if (mdl->opt->input != NULL) {
+        fin = fopen(mdl->opt->input, "r");
+        if (fin == NULL)
+            pfatal("cannot open update file");
+    }
+    int nline = 0;
+    while (!feof(fin)) {
+        char *raw = rdr_readline(fin);
+        if (raw == NULL)
+            break;
+        char *line = raw;
+        nline++;
+        // First we split the line in space separated tokens. We expect
+        // four of them and skip empty lines.
+        char *toks[4];
+        int ntoks = 0;
+        while (ntoks < 4) {
+            while (isspace(*line))
+                line++;
+            if (*line == '\0')
+                break;
+            toks[ntoks++] = line;
+            while (*line != '\0' && !isspace(*line))
+                line++;
+            if (*line == '\0')
+                break;
+            *line++ = '\0';
+        }
+        if (ntoks == 0) {
+            free(raw);
+            continue;
+        } else if (ntoks != 4) {
+            fatal("invalid line at %d", nline);
+        }
+        // Parse the tokens, the first three should be string maping to
+        // observations and labels and the last should be the weight.
+        uint64_t obs = none, yp = none, y = none;
+        obs = qrk_str2id(mdl->reader->obs, toks[0]);
+        if (obs == none)
+            fatal("bad on observation on line %d", nline);
+        if (strcmp(toks[1], "#")) {
+            yp = qrk_str2id(mdl->reader->lbl, toks[1]);
+            if (yp == none)
+                fatal("bad label <%s> line %d", toks[1], nline);
+        }
+        y = qrk_str2id(mdl->reader->lbl, toks[2]);
+        if (y == none)
+            fatal("bad label <%s> line %d", toks[2], nline);
+        double wgh = 0.0;
+        if (sscanf(toks[3], "%lf", &wgh) != 1)
+            fatal("bad weight on line %d", nline);
+
+        const uint32_t Y = mdl->nlbl;
+        if (yp == none) {
+            double *w = mdl->theta + mdl->uoff[obs];
+            w[y] = wgh;
+        } else {
+            double *w = mdl->theta + mdl->boff[obs];
+            w[yp * Y + y] = wgh;
+        }
+        free(raw);
+    }
+    if (mdl->opt->input != NULL)
+        fclose(fin);
+    // If requested compact the model.
+    if (mdl->opt->compact) {
+        const uint64_t O = mdl->nobs;
+        const uint64_t F = mdl->nftr;
+        info("* Compacting the model\n");
+        mdl_compact(mdl);
+        info(" %8"PRIu64" observations removed\n", O - mdl->nobs);
+        info(" %8"PRIu64" features removed\n", F - mdl->nftr);
+    }
+    // And save the updated model
+    info("* Save the model\n");
+    FILE *file = stdout;
+    if (mdl->opt->output != NULL) {
+        file = fopen(mdl->opt->output, "w");
+        if (file == NULL)
+            pfatal("cannot open output model");
+    }
+    mdl_save(mdl, file);
+    if (mdl->opt->output != NULL)
+        fclose(file);
+    info("* Done\n");
+}
+
 /*******************************************************************************
  * Entry point
  ******************************************************************************/
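
As the parsing code above shows, the new update mode reads patch lines made of four whitespace-separated tokens: an observation, a previous label (or # for a unigram feature), a label, and the new weight, i.e. the same shape as the dump output. A hedged sketch that writes such a patch file (the observation and label strings are invented for illustration and would have to exist in the model being updated):

    #include <stdio.h>

    int main(void) {
        FILE *patch = fopen("patch.txt", "w");
        if (patch == NULL)
            return 1;
        /* unigram feature: observation TAB "#" TAB label TAB weight */
        fprintf(patch, "u:word=the\t#\tDET\t0.25\n");
        /* bigram feature: observation TAB previous-label TAB label TAB weight */
        fprintf(patch, "b:pair\tDET\tNOUN\t-1.5\n");
        fclose(patch);
        return 0;
    }
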
@@ -280,9 +400,11 @@ int wapiti_main(int argc, char *argv[argc]) {
     switch (opt.mode) {
         case 0: dotrain(mdl); break;
         case 1: dolabel(mdl); break;
-        case 2: dodump(mdl); break;
+        case 2: dodump(mdl); break;
+        case 3: doupdt(mdl); break;
     }
     // And cleanup
     mdl_free(mdl);
     return EXIT_SUCCESS;
 }
+

data/ext/wapiti/wapiti.h

@@ -1,7 +1,7 @@
 /*
  * Wapiti - A linear-chain CRF tool
  *
- * Copyright (c) 2009-2011 CNRS
+ * Copyright (c) 2009-2013 CNRS
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -27,7 +27,7 @@
 #ifndef wapiti_h
 #define wapiti_h
 
-#define VERSION "1.2.0"
+#define VERSION "1.5.0"
 
 /* XVM_ANSI:
  * By uncomenting the following define, you can force wapiti to not use SSE2
@@ -36,10 +36,24 @@
 //#define XVM_ANSI
 
 /* MTH_ANSI:
- * By uncomenting the following define, you can disable the use of POSIX
- * threads in the multi-threading part of Wapiti, for non-POSIX systems.
+ * By uncomenting the following define, you can disable the use of POSIX
+ * threads in the multi-threading part of Wapiti, for non-POSIX systems.
  */
 //#define MTH_ANSI
 
+/* ATM_ANSI:
+ * By uncomenting the following define, you can disable the use of atomic
+ * operation to update the gradient. This imply that multi-threaded gradient
+ * computation will require more memory but is more portable.
+ */
+//#define ATM_ANSI
+
+/* Without multi-threading we disable atomic updates as they are not needed and
+ * can only decrease performances in this case.
+ */
+#ifdef MTH_ANSI
+#define ATM_ANSI
+#endif
+
 #endif
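
The new ATM_ANSI switch trades atomic updates of one shared gradient for per-thread gradient copies, which costs memory but avoids relying on atomic operations; it is forced on whenever MTH_ANSI already disables POSIX threads. Below is only a rough sketch of the general pattern such a flag chooses between, not Wapiti's actual gradient code; every name except ATM_ANSI is invented:

    #include <stdatomic.h>

    #ifndef ATM_ANSI
    /* one shared gradient cell, updated concurrently with a CAS retry loop */
    static _Atomic double shared_g;

    static void grad_add(double v) {
        double old = atomic_load(&shared_g);
        /* on failure 'old' is refreshed with the current value, so we retry */
        while (!atomic_compare_exchange_weak(&shared_g, &old, old + v))
            ;
    }
    #else
    /* portable alternative: every worker owns a buffer, summed afterwards */
    static void grad_reduce(double *g, const double *per_thread,
                            int nworkers, int nfeatures) {
        for (int w = 0; w < nworkers; w++)
            for (int n = 0; n < nfeatures; n++)
                g[n] += per_thread[w * nfeatures + n];
    }
    #endif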
 

data/lib/wapiti.rb

@@ -5,20 +5,20 @@ require 'tempfile'
 require 'wapiti/version'
 
 module Wapiti
-
-  Logger = ::Logger.new(STDOUT)
-  Logger.level = ::Logger::WARN
-
-  class << self
-    def log
-      Logger
-    end
-
-    def debug!
-      log.level == ::Logger::DEBUG
-    end
-  end
-
+
+  Logger = ::Logger.new(STDOUT)
+  Logger.level = ::Logger::WARN
+
+  class << self
+    def log
+      Logger
+    end
+
+    def debug!
+      log.level == ::Logger::DEBUG
+    end
+  end
+
 end
 
 require 'wapiti/errors'
@@ -27,4 +27,4 @@ require 'wapiti/native'
 require 'wapiti/options'
 require 'wapiti/model'
 
-require 'wapiti/utility'
+require 'wapiti/utility'

data/lib/wapiti/errors.rb

@@ -1,17 +1,17 @@
 module Wapiti
-
-  class Error < StandardError
-
-    attr_accessor :original
-
-    def initialize(message = '', original = $!)
-      super(message)
-      @original = original
-    end
-
-  end
 
-  class NativeError < Error; end
-  class ConfigurationError < Error; end
-
-end
+  class Error < StandardError
+
+    attr_accessor :original
+
+    def initialize(message = '', original = $!)
+      super(message)
+      @original = original
+    end
+
+  end
+
+  class NativeError < Error; end
+  class ConfigurationError < Error; end
+
+end

data/lib/wapiti/model.rb

@@ -1,85 +1,93 @@
 module Wapiti
-
-  class Model
-
-    class << self
-
-      def train(data, options, &block)
-        config = Options.new(options, &block)
-
-        # check configuration
-        # if config.pattern.empty?
-        #   raise ConfigurationError, 'invalid options: no pattern specified'
-        # end
-
-        unless config.valid?
-          raise ConfigurationError, "invalid options: #{ config.validate.join('; ') }"
-        end
-
-        new(config).train(data)
-      end
-
-      def load(filename)
-        m = new
-        m.path = filename
-        m.load
-        m
-      end
-
-    end
-
-    attr_accessor :path
-
-    attr_reader :token_count, :token_errors, :sequence_count, :sequence_errors
-
-    def pattern
-      options.pattern
-    end
-
-    def pattern=(filename)
-      options.pattern = filename
-    end
-
-    alias native_label label
-
-    def label(input, opts = nil)
-      options.update(opts) unless opts.nil?
-      block_given? ? native_label(input, &Proc.new) : native_label(input)
-    end
-
-    alias native_train train
-
-    def train(input, opts = nil)
-      options.update(opts) unless opts.nil?
-      block_given? ? native_train(input, &Proc.new) : native_train(input)
-    end
-
-
-    def statistics
-      s = {}
-      s[:tokens] = {
-        :total => token_count, :errors => @token_errors,
-        :rate => token_errors.to_f / token_count.to_f * 100.0
-      }
-      s[:sequences] = {
-        :total => sequence_count, :errors => sequence_errors,
-        :rate => sequence_errors.to_f / sequence_count.to_f * 100.0
-      }
-      s
-    end
-
-    alias stats statistics
-
-    def clear_counters
-      @token_count = @token_errors = @sequence_count = @sequence_errors = 0
-    end
-
-    alias clear clear_counters
-
-    # alias native_save save
-
-    private :native_label, :native_train
-
-  end
-
-end
+
+  class Model
+
+    class << self
+
+      def train(data, options, &block)
+        config = Options.new(options, &block)
+
+        # check configuration
+        # if config.pattern.empty?
+        #   raise ConfigurationError, 'invalid options: no pattern specified'
+        # end
+
+        unless config.valid?
+          raise ConfigurationError, "invalid options: #{ config.validate.join('; ') }"
+        end
+
+        new(config).train(data)
+      end
+
+      def load(filename)
+        m = new
+        m.path = filename
+        m.load
+        m
+      end
+
+    end
+
+    attr_accessor :path
+
+    attr_reader :token_count, :token_errors, :sequence_count, :sequence_errors
+
+    def pattern
+      options.pattern
+    end
+
+    def pattern=(filename)
+      options.pattern = filename
+    end
+
+    alias native_label label
+
+    def label(input, opts = nil)
+      options.update(opts) unless opts.nil?
+      block_given? ? native_label(input, &Proc.new) : native_label(input)
+    end
+
+    alias native_train train
+
+    def train(input, opts = nil)
+      options.update(opts) unless opts.nil?
+      block_given? ? native_train(input, &Proc.new) : native_train(input)
+    end
+
+
+    def statistics
+      s = {}
+      s[:tokens] = {
+        :total => token_count, :errors => token_errors, :rate => token_error_rate
+      }
+      s[:sequences] = {
+        :total => sequence_count, :errors => sequence_errors, :rate => sequence_error_rate
+      }
+      s
+    end
+
+    alias stats statistics
+
+    def clear_counters
+      @token_count = @token_errors = @sequence_count = @sequence_errors = 0
+    end
+
+    alias clear clear_counters
+
+    def token_error_rate
+      return 0 if token_errors.zero?
+      token_errors / token_count.to_f * 100.0
+    end
+
+    def sequence_error_rate
+      return 0 if sequence_errors.zero?
+      sequence_errors / sequence_count.to_f * 100.0
+    end
+
+    # alias native_save save
+
+    private :native_label, :native_train
+
+  end
+
+end