sentencepiece 0.0.1 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -1
- data/ext/sentencepiece/sentencepiece.cpp +1 -0
- data/ext/sentencepiece/sentencepiece.hpp +422 -7
- data/lib/sentencepiece/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 52d4bb0f1a0b4a68c9db252911fe0661b7d8042d2d976a17506e0eaf0005d48f
|
4
|
+
data.tar.gz: 793b8b3e47cb6a9c1ab6b05b4cbef264b16eef7b632699983de9f8c40056d9ac
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9702468ce33efdf7ae2ac3735cec983ae58a50677357a2d6cbff88089e73bf40b81ef9872b9860474c948fed3f5920a093252e90fbdea2d5e188c41057c7923e
|
7
|
+
data.tar.gz: ba262d347b32364255ecbbce8306b2d126296b46e6aac7828903c0dc9d6485702ef153ed7aea83a79eb69b5bfc655482787f6bb1c3ad086215f9106658a3a3e8
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
# sentencepiece.rb
|
2
2
|
|
3
|
+
[](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
|
4
|
+
[](https://badge.fury.io/rb/sentencepiece)
|
3
5
|
[](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
|
4
6
|
|
5
7
|
sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
|
6
8
|
an unsupervised text tokenizer and detokenizer for neural network-based text generation.
|
7
9
|
|
8
|
-
It is still under development and may undergo many changes in the future.
|
10
|
+
It is still **under development** and may undergo many changes in the future.
|
9
11
|
|
10
12
|
## Installation
|
11
13
|
|
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
|
|
22
22
|
rb_mSentencePiece = rb_define_module("SentencePiece");
|
23
23
|
rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
|
24
24
|
RbSentencePieceProcessor::define_class(rb_mSentencePiece);
|
25
|
+
RbSentencePieceTrainer::define_class(rb_mSentencePiece);
|
25
26
|
}
|
@@ -27,6 +27,7 @@
|
|
27
27
|
VALUE rb_mSentencePiece;
|
28
28
|
VALUE rb_eSentencePieceError;
|
29
29
|
VALUE rb_cSentencePieceProcessor;
|
30
|
+
VALUE rb_cSentencePieceTrainer;
|
30
31
|
|
31
32
|
class RbSentencePieceProcessor {
|
32
33
|
public:
|
@@ -57,10 +58,27 @@ public:
|
|
57
58
|
rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
|
58
59
|
rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
|
59
60
|
rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
|
61
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
|
62
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
|
63
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
|
64
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
|
65
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
|
66
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
|
60
67
|
rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
|
68
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
|
69
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
|
70
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
|
71
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
|
72
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
|
73
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
|
74
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
|
61
75
|
rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
|
62
76
|
rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
|
63
77
|
rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
|
78
|
+
rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
|
79
|
+
rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
|
80
|
+
rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
|
81
|
+
rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
|
64
82
|
return rb_cSentencePieceProcessor;
|
65
83
|
};
|
66
84
|
|
@@ -193,6 +211,182 @@ private:
|
|
193
211
|
return output;
|
194
212
|
};
|
195
213
|
|
214
|
+
// Tokenizes +text+ with the loaded model and returns the piece ids as an Array of Integers.
// Raises ArgumentError when +text+ is not a String.
static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<int> id_list = processor->EncodeAsIds(StringValueCStr(text));

  VALUE result = rb_ary_new();
  for (std::vector<int>::const_iterator it = id_list.begin(); it != id_list.end(); ++it) {
    rb_ary_push(result, INT2NUM(*it));
  }

  // Keep +text+ alive until after its C string has been consumed.
  RB_GC_GUARD(text);
  return result;
}
|
230
|
+
|
231
|
+
// Tokenizes +text+ with the loaded model and returns the pieces as an Array of UTF-8 Strings.
// Raises ArgumentError when +text+ is not a String.
static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::string> tokens = processor->EncodeAsPieces(StringValueCStr(text));

  VALUE result = rb_ary_new();
  for (size_t i = 0; i < tokens.size(); ++i) {
    rb_ary_push(result, rb_utf8_str_new_cstr(tokens[i].c_str()));
  }

  // Keep +text+ alive until after its C string has been consumed.
  RB_GC_GUARD(text);
  return result;
}
|
247
|
+
|
248
|
+
// Returns the n-best segmentations of +text+ as an Array of Arrays of UTF-8 piece Strings.
// Ruby signature: nbest_encode_as_pieces(text, nbest_size:)
//   text       - String to tokenize (ArgumentError otherwise).
//   nbest_size - required Integer keyword forwarded to NBestEncodeAsPieces (ArgumentError otherwise).
static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_ids[1] = { rb_intern("nbest_size") };
  VALUE opt_vals[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  // required = 1, optional = 0: a missing nbest_size makes rb_get_kwargs raise.
  rb_get_kwargs(opts, opt_ids, 1, 0, opt_vals);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  const VALUE nbest_arg = opt_vals[0];
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::vector<std::string>> candidates = processor->NBestEncodeAsPieces(StringValueCStr(text), nbest_size);

  VALUE result = rb_ary_new();
  for (size_t c = 0; c < candidates.size(); ++c) {
    const std::vector<std::string>& segmentation = candidates[c];
    VALUE row = rb_ary_new();
    for (size_t t = 0; t < segmentation.size(); ++t) {
      rb_ary_push(row, rb_utf8_str_new_cstr(segmentation[t].c_str()));
    }
    rb_ary_push(result, row);
  }

  RB_GC_GUARD(text);
  return result;
}
|
282
|
+
|
283
|
+
// Returns the n-best segmentations of +text+ as an Array of Arrays of Integer piece ids.
// Ruby signature: nbest_encode_as_ids(text, nbest_size:)
//   text       - String to tokenize (ArgumentError otherwise).
//   nbest_size - required Integer keyword forwarded to NBestEncodeAsIds (ArgumentError otherwise).
static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  // required = 1, optional = 0: a missing nbest_size makes rb_get_kwargs raise.
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);

  VALUE output = rb_ary_new_capa(static_cast<long>(id_set.size()));
  // Iterate by const reference: the previous by-value loop deep-copied every candidate vector.
  for (const std::vector<int>& ids : id_set) {
    VALUE sub_output = rb_ary_new_capa(static_cast<long>(ids.size()));
    for (const int idx : ids) {
      rb_ary_push(sub_output, INT2NUM(idx));
    }
    rb_ary_push(output, sub_output);
  }

  RB_GC_GUARD(text);
  return output;
}
|
317
|
+
|
318
|
+
// Samples one segmentation of +text+ and returns it as an Array of UTF-8 piece Strings.
// Ruby signature: sample_encode_as_pieces(text, nbest_size:, alpha:)
//   text       - String to tokenize (ArgumentError otherwise).
//   nbest_size - required Integer keyword forwarded to SampleEncodeAsPieces.
//   alpha      - required Integer or Float keyword, narrowed to float before the call.
static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_ids[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE opt_vals[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  // Both keywords are required (required = 2, optional = 0).
  rb_get_kwargs(opts, opt_ids, 2, 0, opt_vals);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  const VALUE nbest_arg = opt_vals[0];
  const VALUE alpha_arg = opt_vals[1];
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  const bool alpha_is_numeric = RB_INTEGER_TYPE_P(alpha_arg) || RB_FLOAT_TYPE_P(alpha_arg);
  if (!alpha_is_numeric) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  const float alpha = static_cast<float>(NUM2DBL(alpha_arg));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::string> tokens = processor->SampleEncodeAsPieces(StringValueCStr(text), nbest_size, alpha);

  VALUE result = rb_ary_new();
  for (size_t i = 0; i < tokens.size(); ++i) {
    rb_ary_push(result, rb_utf8_str_new_cstr(tokens[i].c_str()));
  }

  RB_GC_GUARD(text);
  return result;
}
|
353
|
+
|
354
|
+
// Samples one segmentation of +text+ and returns it as an Array of Integer piece ids.
// Ruby signature: sample_encode_as_ids(text, nbest_size:, alpha:)
//   text       - String to tokenize (ArgumentError otherwise).
//   nbest_size - required Integer keyword forwarded to SampleEncodeAsIds.
//   alpha      - required Integer or Float keyword, narrowed to float before the call.
static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_ids[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE opt_vals[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  // Both keywords are required (required = 2, optional = 0).
  rb_get_kwargs(opts, opt_ids, 2, 0, opt_vals);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  const VALUE nbest_arg = opt_vals[0];
  const VALUE alpha_arg = opt_vals[1];
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  const bool alpha_is_numeric = RB_INTEGER_TYPE_P(alpha_arg) || RB_FLOAT_TYPE_P(alpha_arg);
  if (!alpha_is_numeric) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  const float alpha = static_cast<float>(NUM2DBL(alpha_arg));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<int> id_list = processor->SampleEncodeAsIds(StringValueCStr(text), nbest_size, alpha);

  VALUE result = rb_ary_new();
  for (std::vector<int>::const_iterator it = id_list.begin(); it != id_list.end(); ++it) {
    rb_ary_push(result, INT2NUM(*it));
  }

  RB_GC_GUARD(text);
  return result;
}
|
389
|
+
|
196
390
|
static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
|
197
391
|
VALUE kw_args = Qnil;
|
198
392
|
ID kw_table[1] = { rb_intern("out_type") };
|
@@ -204,7 +398,7 @@ private:
|
|
204
398
|
VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
|
205
399
|
|
206
400
|
if (!RB_TYPE_P(pieces, T_ARRAY)) {
|
207
|
-
rb_raise(rb_eArgError, "expected
|
401
|
+
rb_raise(rb_eArgError, "expected pieces to be an Array");
|
208
402
|
return Qnil;
|
209
403
|
}
|
210
404
|
if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
|
@@ -232,7 +426,7 @@ private:
|
|
232
426
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
233
427
|
return Qfalse;
|
234
428
|
}
|
235
|
-
output =
|
429
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
236
430
|
} else {
|
237
431
|
output = rb_ary_new();
|
238
432
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -249,7 +443,7 @@ private:
|
|
249
443
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
250
444
|
return Qfalse;
|
251
445
|
}
|
252
|
-
rb_ary_push(output,
|
446
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
253
447
|
}
|
254
448
|
}
|
255
449
|
} else {
|
@@ -265,7 +459,7 @@ private:
|
|
265
459
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
266
460
|
return Qfalse;
|
267
461
|
}
|
268
|
-
output =
|
462
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
269
463
|
} else {
|
270
464
|
output = rb_ary_new();
|
271
465
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -282,7 +476,7 @@ private:
|
|
282
476
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
283
477
|
return Qfalse;
|
284
478
|
}
|
285
|
-
rb_ary_push(output,
|
479
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
286
480
|
}
|
287
481
|
}
|
288
482
|
}
|
@@ -290,6 +484,159 @@ private:
|
|
290
484
|
return output;
|
291
485
|
};
|
292
486
|
|
487
|
+
// Detokenizes an Array of piece Strings into a single UTF-8 String.
//   pieces - Array whose elements are converted with StringValueCStr (TypeError for non-Strings).
// Raises ArgumentError when +pieces+ is not an Array. Raises SentencePiece::Error when decoding
// fails, matching #decode. The previous DecodePieces() call returned "" and silently discarded
// the error status.
static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  const long n_pieces = RARRAY_LEN(pieces);
  std::vector<std::string> pcs;
  pcs.reserve(static_cast<size_t>(n_pieces));
  for (long i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(pieces, i);
    pcs.push_back(StringValueCStr(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  std::string text;
  const sentencepiece::util::Status status = ptr->Decode(pcs, &text);
  if (!status.ok()) {
    rb_raise(rb_eSentencePieceError, "%s", status.message());
    return Qnil;
  }

  // Build from data()/size() so the whole decoded text is kept even if it contains a NUL byte.
  return rb_utf8_str_new(text.data(), static_cast<long>(text.size()));
}
|
506
|
+
|
507
|
+
// Detokenizes an Array of Integer piece ids into a single UTF-8 String.
//   ids - Array whose elements are converted with NUM2INT (TypeError/RangeError on bad values).
// Raises ArgumentError when +ids+ is not an Array. Raises SentencePiece::Error when decoding
// fails (for example, no model loaded or an id the library rejects), matching #decode. The
// previous DecodeIds() call returned "" and silently discarded the error status.
static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  const long n_ids = RARRAY_LEN(ids);
  std::vector<int> id_list;
  id_list.reserve(static_cast<size_t>(n_ids));
  for (long i = 0; i < n_ids; i++) {
    VALUE et = rb_ary_entry(ids, i);
    id_list.push_back(NUM2INT(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  std::string text;
  const sentencepiece::util::Status status = ptr->Decode(id_list, &text);
  if (!status.ok()) {
    rb_raise(rb_eSentencePieceError, "%s", status.message());
    return Qnil;
  }

  // Build from data()/size() so the whole decoded text is kept even if it contains a NUL byte.
  return rb_utf8_str_new(text.data(), static_cast<long>(text.size()));
}
|
526
|
+
|
527
|
+
// Tokenizes +text+ and returns the serialized protobuf produced by EncodeAsSerializedProto
// as a binary String.
// Raises ArgumentError when +text+ is not a String.
static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
  // Protobuf wire data can contain NUL bytes. Copy by explicit length; rb_str_new_cstr(c_str())
  // stopped at the first NUL and returned a truncated, unparsable message.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
540
|
+
|
541
|
+
// Samples one segmentation of +text+ and returns the serialized protobuf produced by
// SampleEncodeAsSerializedProto as a binary String.
// Ruby signature: sample_encode_as_serialized_proto(text, nbest_size:, alpha:)
//   nbest_size - required Integer keyword.
//   alpha      - required Integer or Float keyword, narrowed to float before the call.
static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE kw_values[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  const float alpha = static_cast<float>(NUM2DBL(kw_values[1]));
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
  // Copy by explicit length: protobuf bytes may contain NULs, which c_str()-based copying truncated.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
572
|
+
|
573
|
+
// Returns the serialized protobuf produced by NBestEncodeAsSerializedProto for +text+ as a
// binary String.
// Ruby signature: nbest_encode_as_serialized_proto(text, nbest_size:)
//   nbest_size - required Integer keyword.
static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
  // Copy by explicit length: protobuf bytes may contain NULs, which c_str()-based copying truncated.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
599
|
+
|
600
|
+
// Decodes an Array of piece Strings and returns the serialized protobuf produced by
// DecodePiecesAsSerializedProto as a binary String.
// Raises ArgumentError when +pieces+ is not an Array; non-String elements raise TypeError.
static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  const long n_pieces = RARRAY_LEN(pieces);
  std::vector<std::string> pcs;
  pcs.reserve(static_cast<size_t>(n_pieces));
  for (long i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(pieces, i);
    pcs.push_back(StringValueCStr(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
  // Copy by explicit length: protobuf bytes may contain NULs, which c_str()-based copying truncated.
  return rb_str_new(serialized.data(), static_cast<long>(serialized.size()));
}
|
619
|
+
|
620
|
+
// Decodes an Array of Integer piece ids and returns the serialized protobuf produced by
// DecodeIdsAsSerializedProto as a binary String.
// Raises ArgumentError when +ids+ is not an Array; elements are converted with NUM2INT.
static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  const long n_ids = RARRAY_LEN(ids);
  std::vector<int> id_list;
  id_list.reserve(static_cast<size_t>(n_ids));
  for (long i = 0; i < n_ids; i++) {
    VALUE et = rb_ary_entry(ids, i);
    id_list.push_back(NUM2INT(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(id_list);
  // Copy by explicit length: protobuf bytes may contain NULs, which c_str()-based copying truncated.
  return rb_str_new(serialized.data(), static_cast<long>(serialized.size()));
}
|
639
|
+
|
293
640
|
static VALUE _sentencepiece_processor_piece_size(VALUE self) {
|
294
641
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
295
642
|
return INT2NUM(ptr->GetPieceSize());
|
@@ -325,19 +672,49 @@ private:
|
|
325
672
|
VALUE output = Qnil;
|
326
673
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
327
674
|
if (RB_INTEGER_TYPE_P(ids)) {
|
328
|
-
const
|
675
|
+
const int idx = NUM2INT(ids);
|
676
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
677
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
678
|
+
return Qnil;
|
679
|
+
}
|
680
|
+
const std::string piece = ptr->IdToPiece(idx);
|
329
681
|
output = rb_utf8_str_new_cstr(piece.c_str());
|
330
682
|
} else {
|
331
683
|
const size_t n_ids = RARRAY_LEN(ids);
|
332
684
|
output = rb_ary_new();
|
333
685
|
for (size_t i = 0; i < n_ids; i++) {
|
334
686
|
VALUE et = rb_ary_entry(ids, i);
|
335
|
-
const
|
687
|
+
const int idx = NUM2INT(et);
|
688
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
689
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
690
|
+
return Qnil;
|
691
|
+
}
|
692
|
+
const std::string piece = ptr->IdToPiece(idx);
|
336
693
|
rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
|
337
694
|
}
|
338
695
|
}
|
339
696
|
return output;
|
340
697
|
};
|
698
|
+
|
699
|
+
// Returns the unknown-token id reported by the wrapped processor's unk_id(), as an Integer.
static VALUE _sentencepiece_processor_unk_id(VALUE self) {
  const int piece_id = get_sentencepiece_processor(self)->unk_id();
  return INT2NUM(piece_id);
}
|
703
|
+
|
704
|
+
// Returns the beginning-of-sentence id reported by the wrapped processor's bos_id(), as an Integer.
static VALUE _sentencepiece_processor_bos_id(VALUE self) {
  const int piece_id = get_sentencepiece_processor(self)->bos_id();
  return INT2NUM(piece_id);
}
|
708
|
+
|
709
|
+
// Returns the end-of-sentence id reported by the wrapped processor's eos_id(), as an Integer.
static VALUE _sentencepiece_processor_eos_id(VALUE self) {
  const int piece_id = get_sentencepiece_processor(self)->eos_id();
  return INT2NUM(piece_id);
}
|
713
|
+
|
714
|
+
// Returns the padding id reported by the wrapped processor's pad_id(), as an Integer.
static VALUE _sentencepiece_processor_pad_id(VALUE self) {
  const int piece_id = get_sentencepiece_processor(self)->pad_id();
  return INT2NUM(piece_id);
}
|
341
718
|
};
|
342
719
|
|
343
720
|
const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
|
350
727
|
RUBY_TYPED_FREE_IMMEDIATELY
|
351
728
|
};
|
352
729
|
|
730
|
+
namespace sentencepiece {
|
731
|
+
class SentencePieceTrainerWrapper {
|
732
|
+
public:
|
733
|
+
SentencePieceTrainerWrapper(){};
|
734
|
+
~SentencePieceTrainerWrapper(){};
|
735
|
+
|
736
|
+
util::Status Train(absl::string_view args) {
|
737
|
+
return SentencePieceTrainer::Train(args);
|
738
|
+
};
|
739
|
+
};
|
740
|
+
} // namespace sentencepiece
|
741
|
+
|
742
|
+
class RbSentencePieceTrainer {
|
743
|
+
public:
|
744
|
+
static VALUE define_class(VALUE outer) {
|
745
|
+
rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
|
746
|
+
rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
|
747
|
+
return rb_cSentencePieceTrainer;
|
748
|
+
};
|
749
|
+
|
750
|
+
private:
|
751
|
+
static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
|
752
|
+
if (!RB_TYPE_P(args, T_STRING)) {
|
753
|
+
rb_raise(rb_eArgError, "expected args to be a String");
|
754
|
+
return Qnil;
|
755
|
+
}
|
756
|
+
|
757
|
+
const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
|
758
|
+
if (!status.ok()) {
|
759
|
+
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
760
|
+
return Qnil;
|
761
|
+
}
|
762
|
+
|
763
|
+
RB_GC_GUARD(args);
|
764
|
+
return Qnil;
|
765
|
+
};
|
766
|
+
};
|
767
|
+
|
353
768
|
#endif /* SENTENCEPIECE_HPP */
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: sentencepiece
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.1
|
4
|
+
version: 0.0.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-03-
|
11
|
+
date: 2023-03-26 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: |
|
14
14
|
sentencepiece.rb provides Ruby bindings for the SentencePiece,
|