sentencepiece 0.0.1 → 0.0.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5fea5ff13a05fa657877f461944ccb7f5f4b8b6568d25c7581f795ad19514d9a
4
- data.tar.gz: ecfb8603781a0b9129ea68d94272739c558d07a1aaff8f35b5b76c1bcbb4ee87
3
+ metadata.gz: 52d4bb0f1a0b4a68c9db252911fe0661b7d8042d2d976a17506e0eaf0005d48f
4
+ data.tar.gz: 793b8b3e47cb6a9c1ab6b05b4cbef264b16eef7b632699983de9f8c40056d9ac
5
5
  SHA512:
6
- metadata.gz: fe928e1d7810a77f34ccce58b4feea2754ba533fbe1a3e4a0f8ccfbbfb0bdb5b0366550ce5844e6df50ad68a261fe8e9a6c047dde415683bb072d91c39f91868
7
- data.tar.gz: 5247986788a8a7c6678b65d23d235647518e25f32e1d888246255d556b20e82668331de0b7a8d71998be06163b1cded08be51af7a2b23719499f2aac79e6cd72
6
+ metadata.gz: 9702468ce33efdf7ae2ac3735cec983ae58a50677357a2d6cbff88089e73bf40b81ef9872b9860474c948fed3f5920a093252e90fbdea2d5e188c41057c7923e
7
+ data.tar.gz: ba262d347b32364255ecbbce8306b2d126296b46e6aac7828903c0dc9d6485702ef153ed7aea83a79eb69b5bfc655482787f6bb1c3ad086215f9106658a3a3e8
data/CHANGELOG.md CHANGED
@@ -1,5 +1,10 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.0.2] - 2023-03-26
4
+
5
+ - Add SentencePieceTrainer class.
6
+ - Add some encoding and decoding methods to SentencePieceProcessor.
7
+
3
8
  ## [0.0.1] - 2023-03-21
4
9
 
5
10
  - Initial release
data/README.md CHANGED
@@ -1,11 +1,13 @@
1
1
  # sentencepiece.rb
2
2
 
3
+ [![Build Status](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
4
+ [![Gem Version](https://badge.fury.io/rb/sentencepiece.svg)](https://badge.fury.io/rb/sentencepiece)
3
5
  [![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
4
6
 
5
7
  sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
6
8
  an unsupervised text tokenizer and detokenizer for neural network-based text generation.
7
9
 
8
- It is still under development and may undergo many changes in the future.
10
+ It is still **under development** and may undergo many changes in the future.
9
11
 
10
12
  ## Installation
11
13
 
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
22
22
  rb_mSentencePiece = rb_define_module("SentencePiece");
23
23
  rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
24
24
  RbSentencePieceProcessor::define_class(rb_mSentencePiece);
25
+ RbSentencePieceTrainer::define_class(rb_mSentencePiece);
25
26
  }
@@ -27,6 +27,7 @@
27
27
  VALUE rb_mSentencePiece;
28
28
  VALUE rb_eSentencePieceError;
29
29
  VALUE rb_cSentencePieceProcessor;
30
+ VALUE rb_cSentencePieceTrainer;
30
31
 
31
32
  class RbSentencePieceProcessor {
32
33
  public:
@@ -57,10 +58,27 @@ public:
57
58
  rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
58
59
  rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
59
60
  rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
61
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
62
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
63
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
64
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
65
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
66
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
60
67
  rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
68
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
69
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
70
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
71
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
72
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
73
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
74
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
61
75
  rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
62
76
  rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
63
77
  rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
78
+ rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
79
+ rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
80
+ rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
81
+ rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
64
82
  return rb_cSentencePieceProcessor;
65
83
  };
66
84
 
@@ -193,6 +211,182 @@ private:
193
211
  return output;
194
212
  };
195
213
 
214
+ static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
215
+ if (!RB_TYPE_P(text, T_STRING)) {
216
+ rb_raise(rb_eArgError, "expected text to be a String");
217
+ return Qnil;
218
+ }
219
+
220
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
221
+ const std::vector<int> ids = ptr->EncodeAsIds(StringValueCStr(text));
222
+ VALUE output = rb_ary_new();
223
+ for (const int idx : ids) {
224
+ rb_ary_push(output, INT2NUM(idx));
225
+ }
226
+
227
+ RB_GC_GUARD(text);
228
+ return output;
229
+ };
230
+
231
+ static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
232
+ if (!RB_TYPE_P(text, T_STRING)) {
233
+ rb_raise(rb_eArgError, "expected text to be a String");
234
+ return Qnil;
235
+ }
236
+
237
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
238
+ const std::vector<std::string> pieces = ptr->EncodeAsPieces(StringValueCStr(text));
239
+ VALUE output = rb_ary_new();
240
+ for (const std::string& token : pieces) {
241
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
242
+ }
243
+
244
+ RB_GC_GUARD(text);
245
+ return output;
246
+ };
247
+
248
+ static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
249
+ VALUE kw_args = Qnil;
250
+ ID kw_table[1] = { rb_intern("nbest_size") };
251
+ VALUE kw_values[1] = { Qundef };
252
+
253
+ VALUE text = Qnil;
254
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
255
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
256
+
257
+ if (!RB_TYPE_P(text, T_STRING)) {
258
+ rb_raise(rb_eArgError, "expected text to be a String");
259
+ return Qnil;
260
+ }
261
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
262
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
263
+ return Qnil;
264
+ }
265
+
266
+ const int nbest_size = NUM2INT(kw_values[0]);
267
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
268
+ const std::vector<std::vector<std::string>> pieces = ptr->NBestEncodeAsPieces(StringValueCStr(text), nbest_size);
269
+
270
+ VALUE output = rb_ary_new();
271
+ for (const std::vector<std::string>& tokens : pieces) {
272
+ VALUE sub_output = rb_ary_new();
273
+ for (const std::string& token : tokens) {
274
+ rb_ary_push(sub_output, rb_utf8_str_new_cstr(token.c_str()));
275
+ }
276
+ rb_ary_push(output, sub_output);
277
+ }
278
+
279
+ RB_GC_GUARD(text);
280
+ return output;
281
+ };
282
+
283
+ static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
284
+ VALUE kw_args = Qnil;
285
+ ID kw_table[1] = { rb_intern("nbest_size") };
286
+ VALUE kw_values[1] = { Qundef };
287
+
288
+ VALUE text = Qnil;
289
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
291
+
292
+ if (!RB_TYPE_P(text, T_STRING)) {
293
+ rb_raise(rb_eArgError, "expected text to be a String");
294
+ return Qnil;
295
+ }
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
297
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
298
+ return Qnil;
299
+ }
300
+
301
+ const int nbest_size = NUM2INT(kw_values[0]);
302
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
303
+ const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);
304
+
305
+ VALUE output = rb_ary_new();
306
+ for (const std::vector<int> ids : id_set) {
307
+ VALUE sub_output = rb_ary_new();
308
+ for (const int idx : ids) {
309
+ rb_ary_push(sub_output, INT2NUM(idx));
310
+ }
311
+ rb_ary_push(output, sub_output);
312
+ }
313
+
314
+ RB_GC_GUARD(text);
315
+ return output;
316
+ };
317
+
318
+ static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
319
+ VALUE kw_args = Qnil;
320
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
321
+ VALUE kw_values[2] = { Qundef, Qundef };
322
+
323
+ VALUE text = Qnil;
324
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
325
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
326
+
327
+ if (!RB_TYPE_P(text, T_STRING)) {
328
+ rb_raise(rb_eArgError, "expected text to be a String");
329
+ return Qnil;
330
+ }
331
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
332
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
333
+ return Qnil;
334
+ }
335
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
336
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
337
+ return Qnil;
338
+ }
339
+
340
+ const int nbest_size = NUM2INT(kw_values[0]);
341
+ const float alpha = NUM2DBL(kw_values[1]);
342
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
343
+ const std::vector<std::string> pieces = ptr->SampleEncodeAsPieces(StringValueCStr(text), nbest_size, alpha);
344
+
345
+ VALUE output = rb_ary_new();
346
+ for (const std::string& token : pieces) {
347
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
348
+ }
349
+
350
+ RB_GC_GUARD(text);
351
+ return output;
352
+ };
353
+
354
+ static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
355
+ VALUE kw_args = Qnil;
356
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
357
+ VALUE kw_values[2] = { Qundef, Qundef };
358
+
359
+ VALUE text = Qnil;
360
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
361
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
362
+
363
+ if (!RB_TYPE_P(text, T_STRING)) {
364
+ rb_raise(rb_eArgError, "expected text to be a String");
365
+ return Qnil;
366
+ }
367
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
368
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
369
+ return Qnil;
370
+ }
371
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
372
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
373
+ return Qnil;
374
+ }
375
+
376
+ const int nbest_size = NUM2INT(kw_values[0]);
377
+ const float alpha = NUM2DBL(kw_values[1]);
378
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
379
+ const std::vector<int> ids = ptr->SampleEncodeAsIds(StringValueCStr(text), nbest_size, alpha);
380
+
381
+ VALUE output = rb_ary_new();
382
+ for (const int idx : ids) {
383
+ rb_ary_push(output, INT2NUM(idx));
384
+ }
385
+
386
+ RB_GC_GUARD(text);
387
+ return output;
388
+ };
389
+
196
390
  static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
197
391
  VALUE kw_args = Qnil;
198
392
  ID kw_table[1] = { rb_intern("out_type") };
@@ -204,7 +398,7 @@ private:
204
398
  VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
205
399
 
206
400
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
207
- rb_raise(rb_eArgError, "expected out_type to be an Array");
401
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
208
402
  return Qnil;
209
403
  }
210
404
  if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
@@ -232,7 +426,7 @@ private:
232
426
  rb_raise(rb_eSentencePieceError, "%s", status.message());
233
427
  return Qfalse;
234
428
  }
235
- output = rb_str_new_cstr(text.c_str());
429
+ output = rb_utf8_str_new_cstr(text.c_str());
236
430
  } else {
237
431
  output = rb_ary_new();
238
432
  for (size_t i = 0; i < n_pieces; i++) {
@@ -249,7 +443,7 @@ private:
249
443
  rb_raise(rb_eSentencePieceError, "%s", status.message());
250
444
  return Qfalse;
251
445
  }
252
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
446
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
253
447
  }
254
448
  }
255
449
  } else {
@@ -265,7 +459,7 @@ private:
265
459
  rb_raise(rb_eSentencePieceError, "%s", status.message());
266
460
  return Qfalse;
267
461
  }
268
- output = rb_str_new_cstr(text.c_str());
462
+ output = rb_utf8_str_new_cstr(text.c_str());
269
463
  } else {
270
464
  output = rb_ary_new();
271
465
  for (size_t i = 0; i < n_pieces; i++) {
@@ -282,7 +476,7 @@ private:
282
476
  rb_raise(rb_eSentencePieceError, "%s", status.message());
283
477
  return Qfalse;
284
478
  }
285
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
479
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
286
480
  }
287
481
  }
288
482
  }
@@ -290,6 +484,159 @@ private:
290
484
  return output;
291
485
  };
292
486
 
487
+ static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
488
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
489
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
490
+ return Qnil;
491
+ }
492
+
493
+ std::vector<std::string> pcs;
494
+ const size_t n_pieces = RARRAY_LEN(pieces);
495
+ for (size_t i = 0; i < n_pieces; i++) {
496
+ VALUE et = rb_ary_entry(pieces, i);
497
+ pcs.push_back(StringValueCStr(et));
498
+ }
499
+
500
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
501
+ const std::string text = ptr->DecodePieces(pcs);
502
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
503
+
504
+ return output;
505
+ };
506
+
507
+ static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
508
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
509
+ rb_raise(rb_eArgError, "expected ids to be an Array");
510
+ return Qnil;
511
+ }
512
+
513
+ std::vector<int> pcs;
514
+ const size_t n_pieces = RARRAY_LEN(ids);
515
+ for (size_t i = 0; i < n_pieces; i++) {
516
+ VALUE et = rb_ary_entry(ids, i);
517
+ pcs.push_back(NUM2INT(et));
518
+ }
519
+
520
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
521
+ const std::string text = ptr->DecodeIds(pcs);
522
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
523
+
524
+ return output;
525
+ };
526
+
527
+ static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
528
+ if (!RB_TYPE_P(text, T_STRING)) {
529
+ rb_raise(rb_eArgError, "expected text to be a String");
530
+ return Qnil;
531
+ }
532
+
533
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
534
+ const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
535
+ VALUE output = rb_str_new_cstr(serialized.c_str());
536
+
537
+ RB_GC_GUARD(text);
538
+ return output;
539
+ };
540
+
541
+ static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
542
+ VALUE kw_args = Qnil;
543
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
544
+ VALUE kw_values[2] = { Qundef, Qundef };
545
+
546
+ VALUE text = Qnil;
547
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
548
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
549
+
550
+ if (!RB_TYPE_P(text, T_STRING)) {
551
+ rb_raise(rb_eArgError, "expected text to be a String");
552
+ return Qnil;
553
+ }
554
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
555
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
556
+ return Qnil;
557
+ }
558
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
559
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
560
+ return Qnil;
561
+ }
562
+
563
+ const int nbest_size = NUM2INT(kw_values[0]);
564
+ const float alpha = NUM2DBL(kw_values[1]);
565
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
566
+ const sentencepiece::util::bytes serialized = ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
567
+ VALUE output = rb_str_new_cstr(serialized.c_str());
568
+
569
+ RB_GC_GUARD(text);
570
+ return output;
571
+ };
572
+
573
+ static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
574
+ VALUE kw_args = Qnil;
575
+ ID kw_table[1] = { rb_intern("nbest_size") };
576
+ VALUE kw_values[1] = { Qundef };
577
+
578
+ VALUE text = Qnil;
579
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
580
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
581
+
582
+ if (!RB_TYPE_P(text, T_STRING)) {
583
+ rb_raise(rb_eArgError, "expected text to be a String");
584
+ return Qnil;
585
+ }
586
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
587
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
588
+ return Qnil;
589
+ }
590
+
591
+ const int nbest_size = NUM2INT(kw_values[0]);
592
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
593
+ const sentencepiece::util::bytes serialized = ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
594
+ VALUE output = rb_str_new_cstr(serialized.c_str());
595
+
596
+ RB_GC_GUARD(text);
597
+ return output;
598
+ };
599
+
600
+ static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
601
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
602
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
603
+ return Qnil;
604
+ }
605
+
606
+ std::vector<std::string> pcs;
607
+ const size_t n_pieces = RARRAY_LEN(pieces);
608
+ for (size_t i = 0; i < n_pieces; i++) {
609
+ VALUE et = rb_ary_entry(pieces, i);
610
+ pcs.push_back(StringValueCStr(et));
611
+ }
612
+
613
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
614
+ const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
615
+ VALUE output = rb_str_new_cstr(serialized.c_str());
616
+
617
+ return output;
618
+ };
619
+
620
+ static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
621
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
622
+ rb_raise(rb_eArgError, "expected ids to be an Array");
623
+ return Qnil;
624
+ }
625
+
626
+ std::vector<int> pcs;
627
+ const size_t n_pieces = RARRAY_LEN(ids);
628
+ for (size_t i = 0; i < n_pieces; i++) {
629
+ VALUE et = rb_ary_entry(ids, i);
630
+ pcs.push_back(NUM2INT(et));
631
+ }
632
+
633
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
634
+ const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(pcs);
635
+ VALUE output = rb_str_new_cstr(serialized.c_str());
636
+
637
+ return output;
638
+ };
639
+
293
640
  static VALUE _sentencepiece_processor_piece_size(VALUE self) {
294
641
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
295
642
  return INT2NUM(ptr->GetPieceSize());
@@ -325,19 +672,49 @@ private:
325
672
  VALUE output = Qnil;
326
673
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
327
674
  if (RB_INTEGER_TYPE_P(ids)) {
328
- const std::string piece = ptr->IdToPiece(NUM2INT(ids));
675
+ const int idx = NUM2INT(ids);
676
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
677
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
678
+ return Qnil;
679
+ }
680
+ const std::string piece = ptr->IdToPiece(idx);
329
681
  output = rb_utf8_str_new_cstr(piece.c_str());
330
682
  } else {
331
683
  const size_t n_ids = RARRAY_LEN(ids);
332
684
  output = rb_ary_new();
333
685
  for (size_t i = 0; i < n_ids; i++) {
334
686
  VALUE et = rb_ary_entry(ids, i);
335
- const std::string piece = ptr->IdToPiece(NUM2INT(et));
687
+ const int idx = NUM2INT(et);
688
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
689
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
690
+ return Qnil;
691
+ }
692
+ const std::string piece = ptr->IdToPiece(idx);
336
693
  rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
337
694
  }
338
695
  }
339
696
  return output;
340
697
  };
698
+
699
+ static VALUE _sentencepiece_processor_unk_id(VALUE self) {
700
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
701
+ return INT2NUM(ptr->unk_id());
702
+ };
703
+
704
+ static VALUE _sentencepiece_processor_bos_id(VALUE self) {
705
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
706
+ return INT2NUM(ptr->bos_id());
707
+ };
708
+
709
+ static VALUE _sentencepiece_processor_eos_id(VALUE self) {
710
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
711
+ return INT2NUM(ptr->eos_id());
712
+ };
713
+
714
+ static VALUE _sentencepiece_processor_pad_id(VALUE self) {
715
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
716
+ return INT2NUM(ptr->pad_id());
717
+ };
341
718
  };
342
719
 
343
720
  const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
350
727
  RUBY_TYPED_FREE_IMMEDIATELY
351
728
  };
352
729
 
730
+ namespace sentencepiece {
731
+ class SentencePieceTrainerWrapper {
732
+ public:
733
+ SentencePieceTrainerWrapper(){};
734
+ ~SentencePieceTrainerWrapper(){};
735
+
736
+ util::Status Train(absl::string_view args) {
737
+ return SentencePieceTrainer::Train(args);
738
+ };
739
+ };
740
+ } // namespace sentencepiece
741
+
742
+ class RbSentencePieceTrainer {
743
+ public:
744
+ static VALUE define_class(VALUE outer) {
745
+ rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
746
+ rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
747
+ return rb_cSentencePieceTrainer;
748
+ };
749
+
750
+ private:
751
+ static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
752
+ if (!RB_TYPE_P(args, T_STRING)) {
753
+ rb_raise(rb_eArgError, "expected args to be a String");
754
+ return Qnil;
755
+ }
756
+
757
+ const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
758
+ if (!status.ok()) {
759
+ rb_raise(rb_eSentencePieceError, "%s", status.message());
760
+ return Qnil;
761
+ }
762
+
763
+ RB_GC_GUARD(args);
764
+ return Qnil;
765
+ };
766
+ };
767
+
353
768
  #endif /* SENTENCEPIECE_HPP */
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module SentencePiece
4
- VERSION = '0.0.1'
4
+ VERSION = '0.0.2'
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: sentencepiece
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-21 00:00:00.000000000 Z
11
+ date: 2023-03-26 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: |
14
14
  sentencepiece.rb provides Ruby bindings for the SentencePiece,