sentencepiece 0.0.1 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5fea5ff13a05fa657877f461944ccb7f5f4b8b6568d25c7581f795ad19514d9a
4
- data.tar.gz: ecfb8603781a0b9129ea68d94272739c558d07a1aaff8f35b5b76c1bcbb4ee87
3
+ metadata.gz: 63aaf69a72781f0087072db8197b9ae3a126b63e632b52bded5b483ff7b2c6b4
4
+ data.tar.gz: 5ffbd48ed587ef57fd3b427615cc21f4d79804ea90598f9db2178ceff702ca57
5
5
  SHA512:
6
- metadata.gz: fe928e1d7810a77f34ccce58b4feea2754ba533fbe1a3e4a0f8ccfbbfb0bdb5b0366550ce5844e6df50ad68a261fe8e9a6c047dde415683bb072d91c39f91868
7
- data.tar.gz: 5247986788a8a7c6678b65d23d235647518e25f32e1d888246255d556b20e82668331de0b7a8d71998be06163b1cded08be51af7a2b23719499f2aac79e6cd72
6
+ metadata.gz: 61b91c838bbdb2f7c47166036b890b545060361e91a66e6cf416206f53d84595cc74052a38df083ad507a6f219d0b2b8ba3667095d2d3514b5ba0c3660967ee5
7
+ data.tar.gz: 467c2f252e23fbfe684486d42af9fe0eedd058dd99361fc5652bb83412ae801571e5e25879b2f1bb6734c997b9c04f4287378c595afa3698b06885631c8fdcae
data/CHANGELOG.md CHANGED
@@ -1,5 +1,15 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.1.0] - 2023-03-26
4
+
5
+ - Add API documentation.
6
+ - Add type signatures.
7
+
8
+ ## [0.0.2] - 2023-03-26
9
+
10
+ - Add SentencePieceTrainer class.
11
+ - Add some encoding and decoding methods to SentencePieceProcessor.
12
+
3
13
  ## [0.0.1] - 2023-03-21
4
14
 
5
15
  - Initial release
data/README.md CHANGED
@@ -1,12 +1,13 @@
1
1
  # sentencepiece.rb
2
2
 
3
+ [![Build Status](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
4
+ [![Gem Version](https://badge.fury.io/rb/sentencepiece.svg)](https://badge.fury.io/rb/sentencepiece)
3
5
  [![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
6
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/sentencepiece.rb/doc/)
4
7
 
5
8
  sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
6
9
  an unsupervised text tokenizer and detokenizer for neural network-based text generation.
7
10
 
8
- It is still under development and may undergo many changes in the future.
9
-
10
11
  ## Installation
11
12
 
12
13
  Install SentencePiece using your OS package manager;
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
22
22
  rb_mSentencePiece = rb_define_module("SentencePiece");
23
23
  rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
24
24
  RbSentencePieceProcessor::define_class(rb_mSentencePiece);
25
+ RbSentencePieceTrainer::define_class(rb_mSentencePiece);
25
26
  }
@@ -27,6 +27,7 @@
27
27
  VALUE rb_mSentencePiece;
28
28
  VALUE rb_eSentencePieceError;
29
29
  VALUE rb_cSentencePieceProcessor;
30
+ VALUE rb_cSentencePieceTrainer;
30
31
 
31
32
  class RbSentencePieceProcessor {
32
33
  public:
@@ -57,10 +58,27 @@ public:
57
58
  rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
58
59
  rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
59
60
  rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
61
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
62
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
63
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
64
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
65
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
66
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
60
67
  rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
68
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
69
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
70
+ rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
71
+ rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
72
+ rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
73
+ rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
74
+ rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
61
75
  rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
62
76
  rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
63
77
  rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
78
+ rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
79
+ rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
80
+ rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
81
+ rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
64
82
  return rb_cSentencePieceProcessor;
65
83
  };
66
84
 
@@ -193,6 +211,182 @@ private:
193
211
  return output;
194
212
  };
195
213
 
214
+ static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
215
+ if (!RB_TYPE_P(text, T_STRING)) {
216
+ rb_raise(rb_eArgError, "expected text to be a String");
217
+ return Qnil;
218
+ }
219
+
220
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
221
+ const std::vector<int> ids = ptr->EncodeAsIds(StringValueCStr(text));
222
+ VALUE output = rb_ary_new();
223
+ for (const int idx : ids) {
224
+ rb_ary_push(output, INT2NUM(idx));
225
+ }
226
+
227
+ RB_GC_GUARD(text);
228
+ return output;
229
+ };
230
+
231
+ static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
232
+ if (!RB_TYPE_P(text, T_STRING)) {
233
+ rb_raise(rb_eArgError, "expected text to be a String");
234
+ return Qnil;
235
+ }
236
+
237
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
238
+ const std::vector<std::string> pieces = ptr->EncodeAsPieces(StringValueCStr(text));
239
+ VALUE output = rb_ary_new();
240
+ for (const std::string& token : pieces) {
241
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
242
+ }
243
+
244
+ RB_GC_GUARD(text);
245
+ return output;
246
+ };
247
+
248
+ static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
249
+ VALUE kw_args = Qnil;
250
+ ID kw_table[1] = { rb_intern("nbest_size") };
251
+ VALUE kw_values[1] = { Qundef };
252
+
253
+ VALUE text = Qnil;
254
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
255
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
256
+
257
+ if (!RB_TYPE_P(text, T_STRING)) {
258
+ rb_raise(rb_eArgError, "expected text to be a String");
259
+ return Qnil;
260
+ }
261
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
262
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
263
+ return Qnil;
264
+ }
265
+
266
+ const int nbest_size = NUM2INT(kw_values[0]);
267
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
268
+ const std::vector<std::vector<std::string>> pieces = ptr->NBestEncodeAsPieces(StringValueCStr(text), nbest_size);
269
+
270
+ VALUE output = rb_ary_new();
271
+ for (const std::vector<std::string>& tokens : pieces) {
272
+ VALUE sub_output = rb_ary_new();
273
+ for (const std::string& token : tokens) {
274
+ rb_ary_push(sub_output, rb_utf8_str_new_cstr(token.c_str()));
275
+ }
276
+ rb_ary_push(output, sub_output);
277
+ }
278
+
279
+ RB_GC_GUARD(text);
280
+ return output;
281
+ };
282
+
283
+ static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
284
+ VALUE kw_args = Qnil;
285
+ ID kw_table[1] = { rb_intern("nbest_size") };
286
+ VALUE kw_values[1] = { Qundef };
287
+
288
+ VALUE text = Qnil;
289
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
291
+
292
+ if (!RB_TYPE_P(text, T_STRING)) {
293
+ rb_raise(rb_eArgError, "expected text to be a String");
294
+ return Qnil;
295
+ }
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
297
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
298
+ return Qnil;
299
+ }
300
+
301
+ const int nbest_size = NUM2INT(kw_values[0]);
302
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
303
+ const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);
304
+
305
+ VALUE output = rb_ary_new();
306
+ for (const std::vector<int> ids : id_set) {
307
+ VALUE sub_output = rb_ary_new();
308
+ for (const int idx : ids) {
309
+ rb_ary_push(sub_output, INT2NUM(idx));
310
+ }
311
+ rb_ary_push(output, sub_output);
312
+ }
313
+
314
+ RB_GC_GUARD(text);
315
+ return output;
316
+ };
317
+
318
+ static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
319
+ VALUE kw_args = Qnil;
320
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
321
+ VALUE kw_values[2] = { Qundef, Qundef };
322
+
323
+ VALUE text = Qnil;
324
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
325
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
326
+
327
+ if (!RB_TYPE_P(text, T_STRING)) {
328
+ rb_raise(rb_eArgError, "expected text to be a String");
329
+ return Qnil;
330
+ }
331
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
332
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
333
+ return Qnil;
334
+ }
335
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
336
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
337
+ return Qnil;
338
+ }
339
+
340
+ const int nbest_size = NUM2INT(kw_values[0]);
341
+ const float alpha = NUM2DBL(kw_values[1]);
342
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
343
+ const std::vector<std::string> pieces = ptr->SampleEncodeAsPieces(StringValueCStr(text), nbest_size, alpha);
344
+
345
+ VALUE output = rb_ary_new();
346
+ for (const std::string& token : pieces) {
347
+ rb_ary_push(output, rb_utf8_str_new_cstr(token.c_str()));
348
+ }
349
+
350
+ RB_GC_GUARD(text);
351
+ return output;
352
+ };
353
+
354
+ static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
355
+ VALUE kw_args = Qnil;
356
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
357
+ VALUE kw_values[2] = { Qundef, Qundef };
358
+
359
+ VALUE text = Qnil;
360
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
361
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
362
+
363
+ if (!RB_TYPE_P(text, T_STRING)) {
364
+ rb_raise(rb_eArgError, "expected text to be a String");
365
+ return Qnil;
366
+ }
367
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
368
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
369
+ return Qnil;
370
+ }
371
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
372
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
373
+ return Qnil;
374
+ }
375
+
376
+ const int nbest_size = NUM2INT(kw_values[0]);
377
+ const float alpha = NUM2DBL(kw_values[1]);
378
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
379
+ const std::vector<int> ids = ptr->SampleEncodeAsIds(StringValueCStr(text), nbest_size, alpha);
380
+
381
+ VALUE output = rb_ary_new();
382
+ for (const int idx : ids) {
383
+ rb_ary_push(output, INT2NUM(idx));
384
+ }
385
+
386
+ RB_GC_GUARD(text);
387
+ return output;
388
+ };
389
+
196
390
  static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
197
391
  VALUE kw_args = Qnil;
198
392
  ID kw_table[1] = { rb_intern("out_type") };
@@ -204,7 +398,7 @@ private:
204
398
  VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
205
399
 
206
400
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
207
- rb_raise(rb_eArgError, "expected out_type to be an Array");
401
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
208
402
  return Qnil;
209
403
  }
210
404
  if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
@@ -232,7 +426,7 @@ private:
232
426
  rb_raise(rb_eSentencePieceError, "%s", status.message());
233
427
  return Qfalse;
234
428
  }
235
- output = rb_str_new_cstr(text.c_str());
429
+ output = rb_utf8_str_new_cstr(text.c_str());
236
430
  } else {
237
431
  output = rb_ary_new();
238
432
  for (size_t i = 0; i < n_pieces; i++) {
@@ -249,7 +443,7 @@ private:
249
443
  rb_raise(rb_eSentencePieceError, "%s", status.message());
250
444
  return Qfalse;
251
445
  }
252
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
446
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
253
447
  }
254
448
  }
255
449
  } else {
@@ -265,7 +459,7 @@ private:
265
459
  rb_raise(rb_eSentencePieceError, "%s", status.message());
266
460
  return Qfalse;
267
461
  }
268
- output = rb_str_new_cstr(text.c_str());
462
+ output = rb_utf8_str_new_cstr(text.c_str());
269
463
  } else {
270
464
  output = rb_ary_new();
271
465
  for (size_t i = 0; i < n_pieces; i++) {
@@ -282,7 +476,7 @@ private:
282
476
  rb_raise(rb_eSentencePieceError, "%s", status.message());
283
477
  return Qfalse;
284
478
  }
285
- rb_ary_push(output, rb_str_new_cstr(text.c_str()));
479
+ rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
286
480
  }
287
481
  }
288
482
  }
@@ -290,6 +484,159 @@ private:
290
484
  return output;
291
485
  };
292
486
 
487
+ static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
488
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
489
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
490
+ return Qnil;
491
+ }
492
+
493
+ std::vector<std::string> pcs;
494
+ const size_t n_pieces = RARRAY_LEN(pieces);
495
+ for (size_t i = 0; i < n_pieces; i++) {
496
+ VALUE et = rb_ary_entry(pieces, i);
497
+ pcs.push_back(StringValueCStr(et));
498
+ }
499
+
500
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
501
+ const std::string text = ptr->DecodePieces(pcs);
502
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
503
+
504
+ return output;
505
+ };
506
+
507
+ static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
508
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
509
+ rb_raise(rb_eArgError, "expected ids to be an Array");
510
+ return Qnil;
511
+ }
512
+
513
+ std::vector<int> pcs;
514
+ const size_t n_pieces = RARRAY_LEN(ids);
515
+ for (size_t i = 0; i < n_pieces; i++) {
516
+ VALUE et = rb_ary_entry(ids, i);
517
+ pcs.push_back(NUM2INT(et));
518
+ }
519
+
520
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
521
+ const std::string text = ptr->DecodeIds(pcs);
522
+ VALUE output = rb_utf8_str_new_cstr(text.c_str());
523
+
524
+ return output;
525
+ };
526
+
527
+ static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
528
+ if (!RB_TYPE_P(text, T_STRING)) {
529
+ rb_raise(rb_eArgError, "expected text to be a String");
530
+ return Qnil;
531
+ }
532
+
533
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
534
+ const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
535
+ VALUE output = rb_str_new_cstr(serialized.c_str());
536
+
537
+ RB_GC_GUARD(text);
538
+ return output;
539
+ };
540
+
541
+ static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
542
+ VALUE kw_args = Qnil;
543
+ ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
544
+ VALUE kw_values[2] = { Qundef, Qundef };
545
+
546
+ VALUE text = Qnil;
547
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
548
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
549
+
550
+ if (!RB_TYPE_P(text, T_STRING)) {
551
+ rb_raise(rb_eArgError, "expected text to be a String");
552
+ return Qnil;
553
+ }
554
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
555
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
556
+ return Qnil;
557
+ }
558
+ if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
559
+ rb_raise(rb_eArgError, "expected alpha to be a Float");
560
+ return Qnil;
561
+ }
562
+
563
+ const int nbest_size = NUM2INT(kw_values[0]);
564
+ const float alpha = NUM2DBL(kw_values[1]);
565
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
566
+ const sentencepiece::util::bytes serialized = ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
567
+ VALUE output = rb_str_new_cstr(serialized.c_str());
568
+
569
+ RB_GC_GUARD(text);
570
+ return output;
571
+ };
572
+
573
+ static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
574
+ VALUE kw_args = Qnil;
575
+ ID kw_table[1] = { rb_intern("nbest_size") };
576
+ VALUE kw_values[1] = { Qundef };
577
+
578
+ VALUE text = Qnil;
579
+ rb_scan_args(argc, argv, "1:", &text, &kw_args);
580
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
581
+
582
+ if (!RB_TYPE_P(text, T_STRING)) {
583
+ rb_raise(rb_eArgError, "expected text to be a String");
584
+ return Qnil;
585
+ }
586
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
587
+ rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
588
+ return Qnil;
589
+ }
590
+
591
+ const int nbest_size = NUM2INT(kw_values[0]);
592
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
593
+ const sentencepiece::util::bytes serialized = ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
594
+ VALUE output = rb_str_new_cstr(serialized.c_str());
595
+
596
+ RB_GC_GUARD(text);
597
+ return output;
598
+ };
599
+
600
+ static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
601
+ if (!RB_TYPE_P(pieces, T_ARRAY)) {
602
+ rb_raise(rb_eArgError, "expected pieces to be an Array");
603
+ return Qnil;
604
+ }
605
+
606
+ std::vector<std::string> pcs;
607
+ const size_t n_pieces = RARRAY_LEN(pieces);
608
+ for (size_t i = 0; i < n_pieces; i++) {
609
+ VALUE et = rb_ary_entry(pieces, i);
610
+ pcs.push_back(StringValueCStr(et));
611
+ }
612
+
613
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
614
+ const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
615
+ VALUE output = rb_str_new_cstr(serialized.c_str());
616
+
617
+ return output;
618
+ };
619
+
620
+ static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
621
+ if (!RB_TYPE_P(ids, T_ARRAY)) {
622
+ rb_raise(rb_eArgError, "expected ids to be an Array");
623
+ return Qnil;
624
+ }
625
+
626
+ std::vector<int> pcs;
627
+ const size_t n_pieces = RARRAY_LEN(ids);
628
+ for (size_t i = 0; i < n_pieces; i++) {
629
+ VALUE et = rb_ary_entry(ids, i);
630
+ pcs.push_back(NUM2INT(et));
631
+ }
632
+
633
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
634
+ const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(pcs);
635
+ VALUE output = rb_str_new_cstr(serialized.c_str());
636
+
637
+ return output;
638
+ };
639
+
293
640
  static VALUE _sentencepiece_processor_piece_size(VALUE self) {
294
641
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
295
642
  return INT2NUM(ptr->GetPieceSize());
@@ -325,19 +672,49 @@ private:
325
672
  VALUE output = Qnil;
326
673
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
327
674
  if (RB_INTEGER_TYPE_P(ids)) {
328
- const std::string piece = ptr->IdToPiece(NUM2INT(ids));
675
+ const int idx = NUM2INT(ids);
676
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
677
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
678
+ return Qnil;
679
+ }
680
+ const std::string piece = ptr->IdToPiece(idx);
329
681
  output = rb_utf8_str_new_cstr(piece.c_str());
330
682
  } else {
331
683
  const size_t n_ids = RARRAY_LEN(ids);
332
684
  output = rb_ary_new();
333
685
  for (size_t i = 0; i < n_ids; i++) {
334
686
  VALUE et = rb_ary_entry(ids, i);
335
- const std::string piece = ptr->IdToPiece(NUM2INT(et));
687
+ const int idx = NUM2INT(et);
688
+ if (idx < 0 || idx >= ptr->GetPieceSize()) {
689
+ rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
690
+ return Qnil;
691
+ }
692
+ const std::string piece = ptr->IdToPiece(idx);
336
693
  rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
337
694
  }
338
695
  }
339
696
  return output;
340
697
  };
698
+
699
+ static VALUE _sentencepiece_processor_unk_id(VALUE self) {
700
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
701
+ return INT2NUM(ptr->unk_id());
702
+ };
703
+
704
+ static VALUE _sentencepiece_processor_bos_id(VALUE self) {
705
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
706
+ return INT2NUM(ptr->bos_id());
707
+ };
708
+
709
+ static VALUE _sentencepiece_processor_eos_id(VALUE self) {
710
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
711
+ return INT2NUM(ptr->eos_id());
712
+ };
713
+
714
+ static VALUE _sentencepiece_processor_pad_id(VALUE self) {
715
+ sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
716
+ return INT2NUM(ptr->pad_id());
717
+ };
341
718
  };
342
719
 
343
720
  const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
350
727
  RUBY_TYPED_FREE_IMMEDIATELY
351
728
  };
352
729
 
730
+ namespace sentencepiece {
731
+ class SentencePieceTrainerWrapper {
732
+ public:
733
+ SentencePieceTrainerWrapper(){};
734
+ ~SentencePieceTrainerWrapper(){};
735
+
736
+ util::Status Train(absl::string_view args) {
737
+ return SentencePieceTrainer::Train(args);
738
+ };
739
+ };
740
+ } // namespace sentencepiece
741
+
742
+ class RbSentencePieceTrainer {
743
+ public:
744
+ static VALUE define_class(VALUE outer) {
745
+ rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
746
+ rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
747
+ return rb_cSentencePieceTrainer;
748
+ };
749
+
750
+ private:
751
+ static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
752
+ if (!RB_TYPE_P(args, T_STRING)) {
753
+ rb_raise(rb_eArgError, "expected args to be a String");
754
+ return Qnil;
755
+ }
756
+
757
+ const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
758
+ if (!status.ok()) {
759
+ rb_raise(rb_eSentencePieceError, "%s", status.message());
760
+ return Qnil;
761
+ }
762
+
763
+ RB_GC_GUARD(args);
764
+ return Qnil;
765
+ };
766
+ };
767
+
353
768
  #endif /* SENTENCEPIECE_HPP */
@@ -1,5 +1,7 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ # sentencepiece.rb provides Ruby bindings for the SentencePiece.
3
4
  module SentencePiece
4
- VERSION = '0.0.1'
5
+ # The version of sentencepiece.rb you install.
6
+ VERSION = '0.1.0'
5
7
  end
@@ -1,4 +1,63 @@
1
1
  module SentencePiece
2
2
  VERSION: String
3
- # See the writing guide of rbs: https://github.com/ruby/rbs#guides
3
+
4
+ class SentencePieceTrainer
5
+ def self.train: (String args) -> void
6
+ end
7
+
8
+ class SentencePieceProcessor
9
+ def initialize: (?model_file: String model_file) -> void
10
+
11
+ def load: (String model_file) -> void
12
+
13
+ def encode: (String text, ?out_type: String out_type) -> (Array[Integer] | Array[String])
14
+ | (Array[String] text, ?out_type: String out_type) -> (Array[Array[Integer]] | Array[Array[String]])
15
+
16
+ def encode_as_ids: (String text) -> Array[Integer]
17
+
18
+ def encode_as_pieces: (String text) -> Array[String]
19
+
20
+ def encode_as_serialized_proto: (String text) -> String
21
+
22
+ def nbest_encode_as_ids: (String text, nbest_size: Integer nbest_size) -> Array[Array[Integer]]
23
+
24
+ def nbest_encode_as_pieces: (String text, nbest_size: Integer nbest_size) -> Array[Array[String]]
25
+
26
+ def nbest_encode_as_serialized_proto: (String text, nbest_size: Integer nbest_size) -> String
27
+
28
+ def sample_encode_as_ids: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> Array[Integer]
29
+
30
+ def sample_encode_as_pieces: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> Array[String]
31
+
32
+ def sample_encode_as_serialized_proto: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> String
33
+
34
+ def decode: (Array[Integer], ?out_type: String out_type) -> String
35
+ | (Array[Array[Integer]], ?out_type: String out_type) -> Array[String]
36
+ | (Array[String], ?out_type: String out_type) -> String
37
+ | (Array[Array[String]], ?out_type: String out_type) -> Array[String]
38
+
39
+ def decode_ids: (Array[Integer]) -> String
40
+
41
+ def decode_ids_as_serialized_proto: (Array[Integer] ids) -> String
42
+
43
+ def decode_pieces: (Array[String]) -> String
44
+
45
+ def decode_pieces_as_serialized_proto: (Array[String] pieces) -> String
46
+
47
+ def id_to_piece: (Integer id) -> String
48
+ | (Array[Integer] ids) -> Array[String]
49
+
50
+ def piece_to_id: (String piece) -> Integer
51
+ | (Array[String] pieces) -> Array[Integer]
52
+
53
+ def piece_size: () -> Integer
54
+
55
+ def bos_id: () -> Integer
56
+
57
+ def eos_id: () -> Integer
58
+
59
+ def pad_id: () -> Integer
60
+
61
+ def unk_id: () -> Integer
62
+ end
4
63
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: sentencepiece
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.1
4
+ version: 0.1.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-21 00:00:00.000000000 Z
11
+ date: 2023-03-26 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: |
14
14
  sentencepiece.rb provides Ruby bindings for the SentencePiece,
@@ -36,6 +36,7 @@ metadata:
36
36
  homepage_uri: https://github.com/yoshoku/sentencepiece.rb
37
37
  source_code_uri: https://github.com/yoshoku/sentencepiece.rb
38
38
  changelog_uri: https://github.com/yoshoku/sentencepiece.rb/blob/main/CHANGELOG.md
39
+ documentation_uri: https://yoshoku.github.io/sentencepiece.rb/doc/
39
40
  rubygems_mfa_required: 'true'
40
41
  post_install_message:
41
42
  rdoc_options: []