sentencepiece 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5fea5ff13a05fa657877f461944ccb7f5f4b8b6568d25c7581f795ad19514d9a
4
- data.tar.gz: ecfb8603781a0b9129ea68d94272739c558d07a1aaff8f35b5b76c1bcbb4ee87
3
+ metadata.gz: 52d4bb0f1a0b4a68c9db252911fe0661b7d8042d2d976a17506e0eaf0005d48f
4
+ data.tar.gz: 793b8b3e47cb6a9c1ab6b05b4cbef264b16eef7b632699983de9f8c40056d9ac
5
5
  SHA512:
6
- metadata.gz: fe928e1d7810a77f34ccce58b4feea2754ba533fbe1a3e4a0f8ccfbbfb0bdb5b0366550ce5844e6df50ad68a261fe8e9a6c047dde415683bb072d91c39f91868
7
- data.tar.gz: 5247986788a8a7c6678b65d23d235647518e25f32e1d888246255d556b20e82668331de0b7a8d71998be06163b1cded08be51af7a2b23719499f2aac79e6cd72
6
+ metadata.gz: 9702468ce33efdf7ae2ac3735cec983ae58a50677357a2d6cbff88089e73bf40b81ef9872b9860474c948fed3f5920a093252e90fbdea2d5e188c41057c7923e
7
+ data.tar.gz: ba262d347b32364255ecbbce8306b2d126296b46e6aac7828903c0dc9d6485702ef153ed7aea83a79eb69b5bfc655482787f6bb1c3ad086215f9106658a3a3e8
data/CHANGELOG.md CHANGED
@@ -1,5 +1,10 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.0.2] - 2023-03-26
4
+
5
+ - Add SentencePieceTrainer class.
6
+ - Add some encoding and decoding methods to SentencePieceProcessor.
7
+
3
8
  ## [0.0.1] - 2023-03-21
4
9
 
5
10
  - Initial release
data/README.md CHANGED
@@ -1,11 +1,13 @@
1
1
  # sentencepiece.rb
2
2
 
3
+ [![Build Status](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
4
+ [![Gem Version](https://badge.fury.io/rb/sentencepiece.svg)](https://badge.fury.io/rb/sentencepiece)
3
5
  [![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
4
6
 
5
7
  sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
6
8
  an unsupervised text tokenizer and detokenizer for neural network-based text generation.
7
9
 
8
- It is still under development and may undergo many changes in the future.
10
+ It is still **under development** and may undergo many changes in the future.
9
11
 
10
12
  ## Installation
11
13
 
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
22
22
  rb_mSentencePiece = rb_define_module("SentencePiece");
23
23
  rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
24
24
  RbSentencePieceProcessor::define_class(rb_mSentencePiece);
25
+ RbSentencePieceTrainer::define_class(rb_mSentencePiece);
25
26
  }
@@ -27,6 +27,7 @@
27
27
  VALUE rb_mSentencePiece;
28
28
  VALUE rb_eSentencePieceError;
29
29
  VALUE rb_cSentencePieceProcessor;
30
+ VALUE rb_cSentencePieceTrainer;
30
31
 
31
32
  class RbSentencePieceProcessor {
32
33
  public:
@@ -57,10 +58,27 @@ public:
57
58
  rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
58
59
  rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
59
60
  rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
61
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
62
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
63
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
64
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
65
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
66
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
60
67
  rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
68
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
69
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
70
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
71
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
72
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
73
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
74
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
61
75
  rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
62
76
  rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
63
77
  rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
78
+ rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
79
+ rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
80
+ rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
81
+ rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
64
82
  return rb_cSentencePieceProcessor;
65
83
  };
66
84
 
@@ -193,6 +211,182 @@ private:
193
211
  return output;
194
212
  };
195
213
 
214
+ static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
215
+ if (!RB_TYPE_P(text, T_STRING)) {
216
+ rb_raise(rb_eArgError, "expected text to be a String");
217
+ return Qnil;
218
+ }
219
+
220
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
221
+ const std::vector<int> ids = ptr->EncodeAsIds(StringValueCStr(text));
222
+ VALUE output = rb_ary_new();
223
+ for (const int idx : ids) {
224
+ rb_ary_push(output, INT2NUM(idx));
225
+ }
226
+
227
+ RB_GC_GUARD(text);
228
+ return output;
229
+ };
230
+
231
+ static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
232
+ if (!RB_TYPE_P(text, T_STRING)) {
233
+ rb_raise(rb_eArgError, "expected text to be a String");
234
+ return Qnil;
235
+ }
236
+
237
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
238
+ const std::vector<std::string> pieces = ptr->EncodeAsPieces(StringValueCStr(text));
239
+ VALUE output = rb_ary_new();
240
+ for (const std::string& token : pieces) {
241
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
242
+ }
243
+
244
+ RB_GC_GUARD(text);
245
+ return output;
246
+ };
247
+
248
+ static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
249
+ VALUE kw_args = Qnil;
250
+ ID kw_table[1] = { rb_intern("nbest_size") };
251
+ VALUE kw_values[1] = { Qundef };
252
+
253
+ VALUE text = Qnil;
254
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
255
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
256
+
257
+ if (!RB_TYPE_P(text, T_STRING)) {
258
+ rb_raise(rb_eArgError, "expected text to be a String");
259
+ return Qnil;
260
+ }
261
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
262
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
263
+ return Qnil;
264
+ }
265
+
266
+ const int nbest_size = NUM2INT(kw_values[0]);
267
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
268
+ const std::vector<std::vector<std::string>> pieces = ptr->NBestEncodeAsPieces(StringValueCStr(text), nbest_size);
269
+
270
+ VALUE output = rb_ary_new();
271
+ for (const std::vector<std::string>& tokens : pieces) {
272
+ VALUE sub_output = rb_ary_new();
273
+ for (const std::string& token : tokens) {
274
+ rb_ary_push(sub_output, rb_utf8_str_new_cstr(token.c_str()));
275
+ }
276
+ rb_ary_push(output, sub_output);
277
+ }
278
+
279
+ RB_GC_GUARD(text);
280
+ return output;
281
+ };
282
+
283
+ static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
284
+ VALUE kw_args = Qnil;
285
+ ID kw_table[1] = { rb_intern("nbest_size") };
286
+ VALUE kw_values[1] = { Qundef };
287
+
288
+ VALUE text = Qnil;
289
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
291
+
292
+ if (!RB_TYPE_P(text, T_STRING)) {
293
+ rb_raise(rb_eArgError, "expected text to be a String");
294
+ return Qnil;
295
+ }
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
297
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
298
+ return Qnil;
299
+ }
300
+
301
+ const int nbest_size = NUM2INT(kw_values[0]);
302
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
303
+ const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);
304
+
305
+ VALUE output = rb_ary_new();
306
+ for (const std::vector<int> ids : id_set) {
307
+ VALUE sub_output = rb_ary_new();
308
+ for (const int idx : ids) {
309
+ rb_ary_push(sub_output, INT2NUM(idx));
310
+ }
311
+ rb_ary_push(output, sub_output);
312
+ }
313
+
314
+ RB_GC_GUARD(text);
315
+ return output;
316
+ };
317
+
318
+ static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
319
+ VALUE kw_args = Qnil;
320
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
321
+ VALUE kw_values[2] = { Qundef, Qundef };
322
+
323
+ VALUE text = Qnil;
324
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
325
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
326
+
327
+ if (!RB_TYPE_P(text, T_STRING)) {
328
+ rb_raise(rb_eArgError, "expected text to be a String");
329
+ return Qnil;
330
+ }
331
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
332
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
333
+ return Qnil;
334
+ }
335
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
336
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
337
+ return Qnil;
338
+ }
339
+
340
+ const int nbest_size = NUM2INT(kw_values[0]);
341
+ const float alpha = NUM2DBL(kw_values[1]);
342
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
343
+ const std::vector<std::string> pieces = ptr->SampleEncodeAsPieces(StringValueCStr(text), nbest_size, alpha);
344
+
345
+ VALUE output = rb_ary_new();
346
+ for (const std::string& token : pieces) {
347
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
348
+ }
349
+
350
+ RB_GC_GUARD(text);
351
+ return output;
352
+ };
353
+
354
+ static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
355
+ VALUE kw_args = Qnil;
356
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
357
+ VALUE kw_values[2] = { Qundef, Qundef };
358
+
359
+ VALUE text = Qnil;
360
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
361
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
362
+
363
+ if (!RB_TYPE_P(text, T_STRING)) {
364
+ rb_raise(rb_eArgError, "expected text to be a String");
365
+ return Qnil;
366
+ }
367
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
368
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
369
+ return Qnil;
370
+ }
371
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
372
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
373
+ return Qnil;
374
+ }
375
+
376
+ const int nbest_size = NUM2INT(kw_values[0]);
377
+ const float alpha = NUM2DBL(kw_values[1]);
378
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
379
+ const std::vector<int> ids = ptr->SampleEncodeAsIds(StringValueCStr(text), nbest_size, alpha);
380
+
381
+ VALUE output = rb_ary_new();
382
+ for (const int idx : ids) {
383
+ rb_ary_push(output, INT2NUM(idx));
384
+ }
385
+
386
+ RB_GC_GUARD(text);
387
+ return output;
388
+ };
389
+
196
390
  static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
197
391
  VALUE kw_args = Qnil;
198
392
  ID kw_table[1] = { rb_intern("out_type") };
@@ -204,7 +398,7 @@ private:
204
398
  VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
205
399
 
206
400
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
207
- rb_raise(rb_eArgError, "expected out_type to be an Array");
401
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
208
402
  return Qnil;
209
403
  }
210
404
  if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
@@ -232,7 +426,7 @@ private:
232
426
  rb_raise(rb_eSentencePieceError, "%s", status.message());
233
427
  return Qfalse;
234
428
  }
235
- output = rb_str_new_cstr(text.c_str());
429
+ output = rb_utf8_str_new_cstr(text.c_str());
236
430
  } else {
237
431
  output = rb_ary_new();
238
432
  for (size_t i = 0; i < n_pieces; i++) {
@@ -249,7 +443,7 @@ private:
249
443
  rb_raise(rb_eSentencePieceError, "%s", status.message());
250
444
  return Qfalse;
251
445
  }
252
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
446
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
253
447
  }
254
448
  }
255
449
  } else {
@@ -265,7 +459,7 @@ private:
265
459
  rb_raise(rb_eSentencePieceError, "%s", status.message());
266
460
  return Qfalse;
267
461
  }
268
- output = rb_str_new_cstr(text.c_str());
462
+ output = rb_utf8_str_new_cstr(text.c_str());
269
463
  } else {
270
464
  output = rb_ary_new();
271
465
  for (size_t i = 0; i < n_pieces; i++) {
@@ -282,7 +476,7 @@ private:
282
476
  rb_raise(rb_eSentencePieceError, "%s", status.message());
283
477
  return Qfalse;
284
478
  }
285
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
479
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
286
480
  }
287
481
  }
288
482
  }
@@ -290,6 +484,159 @@ private:
290
484
  return output;
291
485
  };
292
486
 
487
+ static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
488
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
489
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
490
+ return Qnil;
491
+ }
492
+
493
+ std::vector<std::string> pcs;
494
+ const size_t n_pieces = RARRAY_LEN(pieces);
495
+ for (size_t i = 0; i < n_pieces; i++) {
496
+ VALUE et = rb_ary_entry(pieces, i);
497
+ pcs.push_back(StringValueCStr(et));
498
+ }
499
+
500
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
501
+ const std::string text = ptr->DecodePieces(pcs);
502
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
503
+
504
+ return output;
505
+ };
506
+
507
+ static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
508
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
509
+ rb_raise(rb_eArgError, "expected ids to be an Array");
510
+ return Qnil;
511
+ }
512
+
513
+ std::vector<int> pcs;
514
+ const size_t n_pieces = RARRAY_LEN(ids);
515
+ for (size_t i = 0; i < n_pieces; i++) {
516
+ VALUE et = rb_ary_entry(ids, i);
517
+ pcs.push_back(NUM2INT(et));
518
+ }
519
+
520
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
521
+ const std::string text = ptr->DecodeIds(pcs);
522
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
523
+
524
+ return output;
525
+ };
526
+
527
+ static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
528
+ if (!RB_TYPE_P(text, T_STRING)) {
529
+ rb_raise(rb_eArgError, "expected text to be a String");
530
+ return Qnil;
531
+ }
532
+
533
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
534
+ const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
535
+ VALUE output = rb_str_new_cstr(serialized.c_str());
536
+
537
+ RB_GC_GUARD(text);
538
+ return output;
539
+ };
540
+
541
+ static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
542
+ VALUE kw_args = Qnil;
543
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
544
+ VALUE kw_values[2] = { Qundef, Qundef };
545
+
546
+ VALUE text = Qnil;
547
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
548
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
549
+
550
+ if (!RB_TYPE_P(text, T_STRING)) {
551
+ rb_raise(rb_eArgError, "expected text to be a String");
552
+ return Qnil;
553
+ }
554
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
555
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
556
+ return Qnil;
557
+ }
558
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
559
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
560
+ return Qnil;
561
+ }
562
+
563
+ const int nbest_size = NUM2INT(kw_values[0]);
564
+ const float alpha = NUM2DBL(kw_values[1]);
565
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
566
+ const sentencepiece::util::bytes serialized = ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
567
+ VALUE output = rb_str_new_cstr(serialized.c_str());
568
+
569
+ RB_GC_GUARD(text);
570
+ return output;
571
+ };
572
+
573
+ static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
574
+ VALUE kw_args = Qnil;
575
+ ID kw_table[1] = { rb_intern("nbest_size") };
576
+ VALUE kw_values[1] = { Qundef };
577
+
578
+ VALUE text = Qnil;
579
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
580
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
581
+
582
+ if (!RB_TYPE_P(text, T_STRING)) {
583
+ rb_raise(rb_eArgError, "expected text to be a String");
584
+ return Qnil;
585
+ }
586
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
587
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
588
+ return Qnil;
589
+ }
590
+
591
+ const int nbest_size = NUM2INT(kw_values[0]);
592
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
593
+ const sentencepiece::util::bytes serialized = ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
594
+ VALUE output = rb_str_new_cstr(serialized.c_str());
595
+
596
+ RB_GC_GUARD(text);
597
+ return output;
598
+ };
599
+
600
+ static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
601
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
602
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
603
+ return Qnil;
604
+ }
605
+
606
+ std::vector<std::string> pcs;
607
+ const size_t n_pieces = RARRAY_LEN(pieces);
608
+ for (size_t i = 0; i < n_pieces; i++) {
609
+ VALUE et = rb_ary_entry(pieces, i);
610
+ pcs.push_back(StringValueCStr(et));
611
+ }
612
+
613
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
614
+ const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
615
+ VALUE output = rb_str_new_cstr(serialized.c_str());
616
+
617
+ return output;
618
+ };
619
+
620
+ static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
621
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
622
+ rb_raise(rb_eArgError, "expected ids to be an Array");
623
+ return Qnil;
624
+ }
625
+
626
+ std::vector<int> pcs;
627
+ const size_t n_pieces = RARRAY_LEN(ids);
628
+ for (size_t i = 0; i < n_pieces; i++) {
629
+ VALUE et = rb_ary_entry(ids, i);
630
+ pcs.push_back(NUM2INT(et));
631
+ }
632
+
633
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
634
+ const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(pcs);
635
+ VALUE output = rb_str_new_cstr(serialized.c_str());
636
+
637
+ return output;
638
+ };
639
+
293
640
  static VALUE _sentencepiece_processor_piece_size(VALUE self) {
294
641
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
295
642
  return INT2NUM(ptr->GetPieceSize());
@@ -325,19 +672,49 @@ private:
325
672
  VALUE output = Qnil;
326
673
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
327
674
  if (RB_INTEGER_TYPE_P(ids)) {
328
- const std::string piece = ptr->IdToPiece(NUM2INT(ids));
675
+ const int idx = NUM2INT(ids);
676
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
677
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
678
+ return Qnil;
679
+ }
680
+ const std::string piece = ptr->IdToPiece(idx);
329
681
  output = rb_utf8_str_new_cstr(piece.c_str());
330
682
  } else {
331
683
  const size_t n_ids = RARRAY_LEN(ids);
332
684
  output = rb_ary_new();
333
685
  for (size_t i = 0; i < n_ids; i++) {
334
686
  VALUE et = rb_ary_entry(ids, i);
335
- const std::string piece = ptr->IdToPiece(NUM2INT(et));
687
+ const int idx = NUM2INT(et);
688
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
689
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
690
+ return Qnil;
691
+ }
692
+ const std::string piece = ptr->IdToPiece(idx);
336
693
  rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
337
694
  }
338
695
  }
339
696
  return output;
340
697
  };
698
+
699
+ static VALUE _sentencepiece_processor_unk_id(VALUE self) {
700
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
701
+ return INT2NUM(ptr->unk_id());
702
+ };
703
+
704
+ static VALUE _sentencepiece_processor_bos_id(VALUE self) {
705
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
706
+ return INT2NUM(ptr->bos_id());
707
+ };
708
+
709
+ static VALUE _sentencepiece_processor_eos_id(VALUE self) {
710
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
711
+ return INT2NUM(ptr->eos_id());
712
+ };
713
+
714
+ static VALUE _sentencepiece_processor_pad_id(VALUE self) {
715
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
716
+ return INT2NUM(ptr->pad_id());
717
+ };
341
718
  };
342
719
 
343
720
  const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
350
727
  RUBY_TYPED_FREE_IMMEDIATELY
351
728
  };
352
729
 
730
+ namespace sentencepiece {
731
+ class SentencePieceTrainerWrapper {
732
+ public:
733
+ SentencePieceTrainerWrapper(){};
734
+ ~SentencePieceTrainerWrapper(){};
735
+
736
+ util::Status Train(absl::string_view args) {
737
+ return SentencePieceTrainer::Train(args);
738
+ };
739
+ };
740
+ } // namespace sentencepiece
741
+
742
+ class RbSentencePieceTrainer {
743
+ public:
744
+ static VALUE define_class(VALUE outer) {
745
+ rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
746
+ rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
747
+ return rb_cSentencePieceTrainer;
748
+ };
749
+
750
+ private:
751
+ static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
752
+ if (!RB_TYPE_P(args, T_STRING)) {
753
+ rb_raise(rb_eArgError, "expected args to be a String");
754
+ return Qnil;
755
+ }
756
+
757
+ const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
758
+ if (!status.ok()) {
759
+ rb_raise(rb_eSentencePieceError, "%s", status.message());
760
+ return Qnil;
761
+ }
762
+
763
+ RB_GC_GUARD(args);
764
+ return Qnil;
765
+ };
766
+ };
767
+
353
768
  #endif /* SENTENCEPIECE_HPP */
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module SentencePiece
4
- VERSION = '0.0.1'
4
+ VERSION = '0.0.2'
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: sentencepiece
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.0.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-21 00:00:00.000000000 Z
11
+ date: 2023-03-26 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: |
14
14
  sentencepiece.rb provides Ruby bindings for the SentencePiece,