sentencepiece 0.0.1 → 0.0.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -1
- data/ext/sentencepiece/sentencepiece.cpp +1 -0
- data/ext/sentencepiece/sentencepiece.hpp +422 -7
- data/lib/sentencepiece/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 52d4bb0f1a0b4a68c9db252911fe0661b7d8042d2d976a17506e0eaf0005d48f
|
4
|
+
data.tar.gz: 793b8b3e47cb6a9c1ab6b05b4cbef264b16eef7b632699983de9f8c40056d9ac
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9702468ce33efdf7ae2ac3735cec983ae58a50677357a2d6cbff88089e73bf40b81ef9872b9860474c948fed3f5920a093252e90fbdea2d5e188c41057c7923e
|
7
|
+
data.tar.gz: ba262d347b32364255ecbbce8306b2d126296b46e6aac7828903c0dc9d6485702ef153ed7aea83a79eb69b5bfc655482787f6bb1c3ad086215f9106658a3a3e8
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
# sentencepiece.rb
|
2
2
|
|
3
|
+
[![Build Status](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
|
4
|
+
[![Gem Version](https://badge.fury.io/rb/sentencepiece.svg)](https://badge.fury.io/rb/sentencepiece)
|
3
5
|
[![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
|
4
6
|
|
5
7
|
sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
|
6
8
|
an unsupervised text tokenizer and detokenizer for neural network-based text generation.
|
7
9
|
|
8
|
-
It is still under development and may undergo many changes in the future.
|
10
|
+
It is still **under development** and may undergo many changes in the future.
|
9
11
|
|
10
12
|
## Installation
|
11
13
|
|
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
|
|
22
22
|
rb_mSentencePiece = rb_define_module("SentencePiece");
|
23
23
|
rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
|
24
24
|
RbSentencePieceProcessor::define_class(rb_mSentencePiece);
|
25
|
+
RbSentencePieceTrainer::define_class(rb_mSentencePiece);
|
25
26
|
}
|
@@ -27,6 +27,7 @@
|
|
27
27
|
VALUE rb_mSentencePiece;
|
28
28
|
VALUE rb_eSentencePieceError;
|
29
29
|
VALUE rb_cSentencePieceProcessor;
|
30
|
+
VALUE rb_cSentencePieceTrainer;
|
30
31
|
|
31
32
|
class RbSentencePieceProcessor {
|
32
33
|
public:
|
@@ -57,10 +58,27 @@ public:
|
|
57
58
|
rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
|
58
59
|
rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
|
59
60
|
rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
|
61
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
|
62
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
|
63
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
|
64
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
|
65
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
|
66
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
|
60
67
|
rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
|
68
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
|
69
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
|
70
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
|
71
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
|
72
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
|
73
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
|
74
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
|
61
75
|
rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
|
62
76
|
rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
|
63
77
|
rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
|
78
|
+
rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
|
79
|
+
rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
|
80
|
+
rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
|
81
|
+
rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
|
64
82
|
return rb_cSentencePieceProcessor;
|
65
83
|
};
|
66
84
|
|
@@ -193,6 +211,182 @@ private:
|
|
193
211
|
return output;
|
194
212
|
};
|
195
213
|
|
214
|
+
// Ruby: encode_as_ids(text) -> Array<Integer>
// Tokenizes +text+ with the wrapped processor's EncodeAsIds and returns the ids.
// Raises ArgumentError unless +text+ is a String.
// NOTE(review): the value-returning EncodeAsIds overload is used, so a failure
// status from the processor is not surfaced here — confirm that is intended.
static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
  if (RB_TYPE_P(text, T_STRING)) {
    sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
    const std::vector<int> encoded = processor->EncodeAsIds(StringValueCStr(text));

    VALUE result = rb_ary_new();
    for (std::vector<int>::const_iterator it = encoded.begin(); it != encoded.end(); ++it) {
      rb_ary_push(result, INT2NUM(*it));
    }

    RB_GC_GUARD(text);
    return result;
  }

  rb_raise(rb_eArgError, "expected text to be a String");
  return Qnil;
}
|
230
|
+
|
231
|
+
// Ruby: encode_as_pieces(text) -> Array<String>
// Tokenizes +text+ with the wrapped processor's EncodeAsPieces and returns the
// pieces as UTF-8 Strings. Raises ArgumentError unless +text+ is a String.
static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
  if (RB_TYPE_P(text, T_STRING)) {
    sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
    const std::vector<std::string> tokens = processor->EncodeAsPieces(StringValueCStr(text));

    VALUE result = rb_ary_new();
    for (size_t k = 0; k < tokens.size(); ++k) {
      rb_ary_push(result, rb_utf8_str_new_cstr(tokens[k].c_str()));
    }

    RB_GC_GUARD(text);
    return result;
  }

  rb_raise(rb_eArgError, "expected text to be a String");
  return Qnil;
}
|
247
|
+
|
248
|
+
// Ruby: nbest_encode_as_pieces(text, nbest_size:) -> Array<Array<String>>
// Returns the candidate piece sequences produced by NBestEncodeAsPieces.
// +nbest_size:+ is a required keyword. Raises ArgumentError unless +text+ is a
// String and +nbest_size+ is an Integer.
static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_names[1] = { rb_intern("nbest_size") };
  VALUE opt_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, opt_names, 1, 0, opt_values);

  const VALUE nbest_opt = opt_values[0];
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_opt)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest = NUM2INT(nbest_opt);
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::vector<std::string>> candidates =
      processor->NBestEncodeAsPieces(StringValueCStr(text), nbest);

  VALUE result = rb_ary_new();
  for (size_t c = 0; c < candidates.size(); ++c) {
    const std::vector<std::string>& candidate = candidates[c];
    VALUE row = rb_ary_new();
    for (size_t k = 0; k < candidate.size(); ++k) {
      rb_ary_push(row, rb_utf8_str_new_cstr(candidate[k].c_str()));
    }
    rb_ary_push(result, row);
  }

  RB_GC_GUARD(text);
  return result;
}
|
282
|
+
|
283
|
+
// Ruby: nbest_encode_as_ids(text, nbest_size:) -> Array<Array<Integer>>
// Returns the candidate id sequences produced by NBestEncodeAsIds.
// +nbest_size:+ is a required keyword. Raises ArgumentError unless +text+ is a
// String and +nbest_size+ is an Integer.
static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);

  VALUE output = rb_ary_new();
  // Bind each candidate by const reference; iterating by value copied every
  // inner vector on each pass.
  for (const std::vector<int>& ids : id_set) {
    VALUE sub_output = rb_ary_new();
    for (const int idx : ids) {
      rb_ary_push(sub_output, INT2NUM(idx));
    }
    rb_ary_push(output, sub_output);
  }

  RB_GC_GUARD(text);
  return output;
}
|
317
|
+
|
318
|
+
// Ruby: sample_encode_as_pieces(text, nbest_size:, alpha:) -> Array<String>
// Returns the piece sequence produced by SampleEncodeAsPieces for +text+.
// Both keywords are required. Raises ArgumentError unless +text+ is a String,
// +nbest_size+ is an Integer and +alpha+ is an Integer or Float (passed as float).
static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_names[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE opt_values[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, opt_names, 2, 0, opt_values);

  const VALUE nbest_opt = opt_values[0];
  const VALUE alpha_opt = opt_values[1];
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_opt)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!(RB_INTEGER_TYPE_P(alpha_opt) || RB_FLOAT_TYPE_P(alpha_opt))) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest = NUM2INT(nbest_opt);
  const float smoothing = static_cast<float>(NUM2DBL(alpha_opt));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::string> sampled =
      processor->SampleEncodeAsPieces(StringValueCStr(text), nbest, smoothing);

  VALUE result = rb_ary_new();
  for (std::vector<std::string>::const_iterator it = sampled.begin(); it != sampled.end(); ++it) {
    rb_ary_push(result, rb_utf8_str_new_cstr(it->c_str()));
  }

  RB_GC_GUARD(text);
  return result;
}
|
353
|
+
|
354
|
+
// Ruby: sample_encode_as_ids(text, nbest_size:, alpha:) -> Array<Integer>
// Returns the id sequence produced by SampleEncodeAsIds for +text+.
// Both keywords are required. Raises ArgumentError unless +text+ is a String,
// +nbest_size+ is an Integer and +alpha+ is an Integer or Float (passed as float).
static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID opt_names[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE opt_values[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, opt_names, 2, 0, opt_values);

  const VALUE nbest_opt = opt_values[0];
  const VALUE alpha_opt = opt_values[1];
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_opt)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!(RB_INTEGER_TYPE_P(alpha_opt) || RB_FLOAT_TYPE_P(alpha_opt))) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest = NUM2INT(nbest_opt);
  const float smoothing = static_cast<float>(NUM2DBL(alpha_opt));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<int> sampled = processor->SampleEncodeAsIds(StringValueCStr(text), nbest, smoothing);

  VALUE result = rb_ary_new();
  for (size_t k = 0; k < sampled.size(); ++k) {
    rb_ary_push(result, INT2NUM(sampled[k]));
  }

  RB_GC_GUARD(text);
  return result;
}
|
389
|
+
|
196
390
|
static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
|
197
391
|
VALUE kw_args = Qnil;
|
198
392
|
ID kw_table[1] = { rb_intern("out_type") };
|
@@ -204,7 +398,7 @@ private:
|
|
204
398
|
VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
|
205
399
|
|
206
400
|
if (!RB_TYPE_P(pieces, T_ARRAY)) {
|
207
|
-
rb_raise(rb_eArgError, "expected
|
401
|
+
rb_raise(rb_eArgError, "expected pieces to be an Array");
|
208
402
|
return Qnil;
|
209
403
|
}
|
210
404
|
if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
|
@@ -232,7 +426,7 @@ private:
|
|
232
426
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
233
427
|
return Qfalse;
|
234
428
|
}
|
235
|
-
output =
|
429
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
236
430
|
} else {
|
237
431
|
output = rb_ary_new();
|
238
432
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -249,7 +443,7 @@ private:
|
|
249
443
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
250
444
|
return Qfalse;
|
251
445
|
}
|
252
|
-
rb_ary_push(output,
|
446
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
253
447
|
}
|
254
448
|
}
|
255
449
|
} else {
|
@@ -265,7 +459,7 @@ private:
|
|
265
459
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
266
460
|
return Qfalse;
|
267
461
|
}
|
268
|
-
output =
|
462
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
269
463
|
} else {
|
270
464
|
output = rb_ary_new();
|
271
465
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -282,7 +476,7 @@ private:
|
|
282
476
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
283
477
|
return Qfalse;
|
284
478
|
}
|
285
|
-
rb_ary_push(output,
|
479
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
286
480
|
}
|
287
481
|
}
|
288
482
|
}
|
@@ -290,6 +484,159 @@ private:
|
|
290
484
|
return output;
|
291
485
|
};
|
292
486
|
|
487
|
+
// Ruby: decode_pieces(pieces) -> String
// Detokenizes an Array of piece Strings into a UTF-8 String.
// Raises ArgumentError unless +pieces+ is an Array. Elements go through
// StringValueCStr, so its conversion errors propagate. Raises
// SentencePiece::Error when the processor reports a failure status, matching
// the error handling of +decode+.
static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  // Convert every element with Ruby APIs first, while no C++ object with a
  // non-trivial destructor is alive: Ruby raises via longjmp, which would skip
  // (and leak) such objects.
  const long n_pieces = RARRAY_LEN(pieces);
  VALUE strs = rb_ary_new();
  for (long i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(pieces, i);
    (void)StringValueCStr(et);
    rb_ary_push(strs, et);
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  VALUE output = Qnil;
  VALUE error_message = Qnil;
  {
    std::vector<std::string> pcs;
    pcs.reserve(static_cast<size_t>(n_pieces));
    for (long i = 0; i < n_pieces; i++) {
      const VALUE et = rb_ary_entry(strs, i);
      pcs.emplace_back(RSTRING_PTR(et), static_cast<size_t>(RSTRING_LEN(et)));
    }

    // Use the status-returning overload; the value-returning DecodePieces only
    // yields a std::string, so failures were silently lost.
    std::string text;
    const sentencepiece::util::Status status = ptr->Decode(pcs, &text);
    if (status.ok()) {
      output = rb_utf8_str_new(text.data(), static_cast<long>(text.size()));
    } else {
      // Copy the message out so the Status can be destroyed before raising.
      error_message = rb_str_new_cstr(status.message());
    }
  }

  RB_GC_GUARD(strs);
  if (!NIL_P(error_message)) {
    rb_exc_raise(rb_exc_new_str(rb_eSentencePieceError, error_message));
    return Qnil;
  }
  return output;
}
|
506
|
+
|
507
|
+
// Ruby: decode_ids(ids) -> String
// Detokenizes an Array of Integer piece ids into a UTF-8 String.
// Raises ArgumentError unless +ids+ is an Array. Elements go through NUM2INT,
// so its conversion errors propagate. Raises SentencePiece::Error when the
// processor reports a failure status, matching the error handling of +decode+.
// NOTE(review): presumably this includes out-of-range ids — confirm against the
// bundled SentencePiece version.
static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  // Convert every element with Ruby APIs first, while no C++ object with a
  // non-trivial destructor is alive: Ruby raises via longjmp, which would skip
  // (and leak) such objects.
  const long n_ids = RARRAY_LEN(ids);
  VALUE ints = rb_ary_new();
  for (long i = 0; i < n_ids; i++) {
    const VALUE et = rb_ary_entry(ids, i);
    rb_ary_push(ints, INT2NUM(NUM2INT(et)));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  VALUE output = Qnil;
  VALUE error_message = Qnil;
  {
    std::vector<int> pcs;
    pcs.reserve(static_cast<size_t>(n_ids));
    for (long i = 0; i < n_ids; i++) {
      const VALUE et = rb_ary_entry(ints, i);
      pcs.push_back(NUM2INT(et));  // cannot raise: every value was range-checked above
    }

    // Use the status-returning overload; the value-returning DecodeIds only
    // yields a std::string, so failures were silently lost.
    std::string text;
    const sentencepiece::util::Status status = ptr->Decode(pcs, &text);
    if (status.ok()) {
      output = rb_utf8_str_new(text.data(), static_cast<long>(text.size()));
    } else {
      // Copy the message out so the Status can be destroyed before raising.
      error_message = rb_str_new_cstr(status.message());
    }
  }

  RB_GC_GUARD(ints);
  if (!NIL_P(error_message)) {
    rb_exc_raise(rb_exc_new_str(rb_eSentencePieceError, error_message));
    return Qnil;
  }
  return output;
}
|
526
|
+
|
527
|
+
// Ruby: encode_as_serialized_proto(text) -> String (binary)
// Encodes +text+ and returns the serialized protobuf produced by
// EncodeAsSerializedProto. Raises ArgumentError unless +text+ is a String.
static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
  // Protobuf wire data is binary and may contain NUL bytes; build the (binary)
  // Ruby String from data+size so it is not truncated at the first NUL, as
  // rb_str_new_cstr(c_str()) would do.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
540
|
+
|
541
|
+
// Ruby: sample_encode_as_serialized_proto(text, nbest_size:, alpha:) -> String (binary)
// Returns the serialized protobuf produced by SampleEncodeAsSerializedProto.
// Both keywords are required. Raises ArgumentError unless +text+ is a String,
// +nbest_size+ is an Integer and +alpha+ is an Integer or Float (passed as float).
static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE kw_values[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  const float alpha = static_cast<float>(NUM2DBL(kw_values[1]));
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized =
      ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
  // Protobuf wire data is binary and may contain NUL bytes; build the Ruby
  // String from data+size so it is not truncated at the first NUL.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
572
|
+
|
573
|
+
// Ruby: nbest_encode_as_serialized_proto(text, nbest_size:) -> String (binary)
// Returns the serialized protobuf produced by NBestEncodeAsSerializedProto.
// +nbest_size:+ is a required keyword. Raises ArgumentError unless +text+ is a
// String and +nbest_size+ is an Integer.
static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized =
      ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
  // Protobuf wire data is binary and may contain NUL bytes; build the Ruby
  // String from data+size so it is not truncated at the first NUL.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  RB_GC_GUARD(text);
  return output;
}
|
599
|
+
|
600
|
+
// Ruby: decode_pieces_as_serialized_proto(pieces) -> String (binary)
// Returns the serialized protobuf produced by DecodePiecesAsSerializedProto for
// an Array of piece Strings. Raises ArgumentError unless +pieces+ is an Array.
// Elements go through StringValueCStr, so its conversion errors propagate.
static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  // Convert every element with Ruby APIs first, while no C++ object with a
  // non-trivial destructor is alive: Ruby raises via longjmp, which would skip
  // (and leak) such objects.
  const long n_pieces = RARRAY_LEN(pieces);
  VALUE strs = rb_ary_new();
  for (long i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(pieces, i);
    (void)StringValueCStr(et);
    rb_ary_push(strs, et);
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  VALUE output = Qnil;
  {
    std::vector<std::string> pcs;
    pcs.reserve(static_cast<size_t>(n_pieces));
    for (long i = 0; i < n_pieces; i++) {
      const VALUE et = rb_ary_entry(strs, i);
      pcs.emplace_back(RSTRING_PTR(et), static_cast<size_t>(RSTRING_LEN(et)));
    }

    const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
    // Protobuf wire data is binary and may contain NUL bytes; build the Ruby
    // String from data+size so it is not truncated at the first NUL.
    output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));
  }

  RB_GC_GUARD(strs);
  return output;
}
|
619
|
+
|
620
|
+
// Ruby: decode_ids_as_serialized_proto(ids) -> String (binary)
// Returns the serialized protobuf produced by DecodeIdsAsSerializedProto for an
// Array of Integer ids. Raises ArgumentError unless +ids+ is an Array. Elements
// go through NUM2INT, so its conversion errors propagate.
static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  // Convert every element with Ruby APIs first, while no C++ object with a
  // non-trivial destructor is alive: Ruby raises via longjmp, which would skip
  // (and leak) such objects.
  const long n_ids = RARRAY_LEN(ids);
  VALUE ints = rb_ary_new();
  for (long i = 0; i < n_ids; i++) {
    const VALUE et = rb_ary_entry(ids, i);
    rb_ary_push(ints, INT2NUM(NUM2INT(et)));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  VALUE output = Qnil;
  {
    std::vector<int> pcs;
    pcs.reserve(static_cast<size_t>(n_ids));
    for (long i = 0; i < n_ids; i++) {
      const VALUE et = rb_ary_entry(ints, i);
      pcs.push_back(NUM2INT(et));  // cannot raise: every value was range-checked above
    }

    const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(pcs);
    // Protobuf wire data is binary and may contain NUL bytes; build the Ruby
    // String from data+size so it is not truncated at the first NUL.
    output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));
  }

  RB_GC_GUARD(ints);
  return output;
}
|
639
|
+
|
293
640
|
static VALUE _sentencepiece_processor_piece_size(VALUE self) {
|
294
641
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
295
642
|
return INT2NUM(ptr->GetPieceSize());
|
@@ -325,19 +672,49 @@ private:
|
|
325
672
|
VALUE output = Qnil;
|
326
673
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
327
674
|
if (RB_INTEGER_TYPE_P(ids)) {
|
328
|
-
const
|
675
|
+
const int idx = NUM2INT(ids);
|
676
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
677
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
678
|
+
return Qnil;
|
679
|
+
}
|
680
|
+
const std::string piece = ptr->IdToPiece(idx);
|
329
681
|
output = rb_utf8_str_new_cstr(piece.c_str());
|
330
682
|
} else {
|
331
683
|
const size_t n_ids = RARRAY_LEN(ids);
|
332
684
|
output = rb_ary_new();
|
333
685
|
for (size_t i = 0; i < n_ids; i++) {
|
334
686
|
VALUE et = rb_ary_entry(ids, i);
|
335
|
-
const
|
687
|
+
const int idx = NUM2INT(et);
|
688
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
689
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
690
|
+
return Qnil;
|
691
|
+
}
|
692
|
+
const std::string piece = ptr->IdToPiece(idx);
|
336
693
|
rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
|
337
694
|
}
|
338
695
|
}
|
339
696
|
return output;
|
340
697
|
};
|
698
|
+
|
699
|
+
// Ruby: unk_id -> Integer
// Returns the wrapped processor's unk_id() as a Ruby Integer.
static VALUE _sentencepiece_processor_unk_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->unk_id();
  return INT2NUM(id);
}
|
703
|
+
|
704
|
+
// Ruby: bos_id -> Integer
// Returns the wrapped processor's bos_id() as a Ruby Integer.
static VALUE _sentencepiece_processor_bos_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->bos_id();
  return INT2NUM(id);
}
|
708
|
+
|
709
|
+
// Ruby: eos_id -> Integer
// Returns the wrapped processor's eos_id() as a Ruby Integer.
static VALUE _sentencepiece_processor_eos_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->eos_id();
  return INT2NUM(id);
}
|
713
|
+
|
714
|
+
// Ruby: pad_id -> Integer
// Returns the wrapped processor's pad_id() as a Ruby Integer.
static VALUE _sentencepiece_processor_pad_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->pad_id();
  return INT2NUM(id);
}
|
341
718
|
};
|
342
719
|
|
343
720
|
const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
|
350
727
|
RUBY_TYPED_FREE_IMMEDIATELY
|
351
728
|
};
|
352
729
|
|
730
|
+
namespace sentencepiece {
|
731
|
+
class SentencePieceTrainerWrapper {
|
732
|
+
public:
|
733
|
+
SentencePieceTrainerWrapper(){};
|
734
|
+
~SentencePieceTrainerWrapper(){};
|
735
|
+
|
736
|
+
util::Status Train(absl::string_view args) {
|
737
|
+
return SentencePieceTrainer::Train(args);
|
738
|
+
};
|
739
|
+
};
|
740
|
+
} // namespace sentencepiece
|
741
|
+
|
742
|
+
class RbSentencePieceTrainer {
|
743
|
+
public:
|
744
|
+
static VALUE define_class(VALUE outer) {
|
745
|
+
rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
|
746
|
+
rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
|
747
|
+
return rb_cSentencePieceTrainer;
|
748
|
+
};
|
749
|
+
|
750
|
+
private:
|
751
|
+
static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
|
752
|
+
if (!RB_TYPE_P(args, T_STRING)) {
|
753
|
+
rb_raise(rb_eArgError, "expected args to be a String");
|
754
|
+
return Qnil;
|
755
|
+
}
|
756
|
+
|
757
|
+
const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
|
758
|
+
if (!status.ok()) {
|
759
|
+
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
760
|
+
return Qnil;
|
761
|
+
}
|
762
|
+
|
763
|
+
RB_GC_GUARD(args);
|
764
|
+
return Qnil;
|
765
|
+
};
|
766
|
+
};
|
767
|
+
|
353
768
|
#endif /* SENTENCEPIECE_HPP */
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: sentencepiece
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0.1
|
4
|
+
version: 0.0.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-03-
|
11
|
+
date: 2023-03-26 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: |
|
14
14
|
sentencepiece.rb provides Ruby bindings for the SentencePiece,
|