sentencepiece 0.0.1 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +3 -2
- data/ext/sentencepiece/sentencepiece.cpp +1 -0
- data/ext/sentencepiece/sentencepiece.hpp +422 -7
- data/lib/sentencepiece/version.rb +3 -1
- data/sig/sentencepiece.rbs +60 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 63aaf69a72781f0087072db8197b9ae3a126b63e632b52bded5b483ff7b2c6b4
|
4
|
+
data.tar.gz: 5ffbd48ed587ef57fd3b427615cc21f4d79804ea90598f9db2178ceff702ca57
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 61b91c838bbdb2f7c47166036b890b545060361e91a66e6cf416206f53d84595cc74052a38df083ad507a6f219d0b2b8ba3667095d2d3514b5ba0c3660967ee5
|
7
|
+
data.tar.gz: 467c2f252e23fbfe684486d42af9fe0eedd058dd99361fc5652bb83412ae801571e5e25879b2f1bb6734c997b9c04f4287378c595afa3698b06885631c8fdcae
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,15 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.1.0] - 2023-03-26
|
4
|
+
|
5
|
+
- Add API documentation.
|
6
|
+
- Add type signatures.
|
7
|
+
|
8
|
+
## [0.0.2] - 2023-03-26
|
9
|
+
|
10
|
+
- Add SentencePieceTrainer class.
|
11
|
+
- Add some encoding and decoding methods to SentencePieceProcessor.
|
12
|
+
|
3
13
|
## [0.0.1] - 2023-03-21
|
4
14
|
|
5
15
|
- Initial release
|
data/README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
# sentencepiece.rb
|
2
2
|
|
3
|
+
[![Build Status](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml/badge.svg)](https://github.com/yoshoku/sentencepiece.rb/actions/workflows/main.yml)
|
4
|
+
[![Gem Version](https://badge.fury.io/rb/sentencepiece.svg)](https://badge.fury.io/rb/sentencepiece)
|
3
5
|
[![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/yoshoku/sentencepiece.rb/blob/main/LICENSE.txt)
|
6
|
+
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/sentencepiece.rb/doc/)
|
4
7
|
|
5
8
|
sentencepiece.rb provides Ruby bindings for the [SentencePiece](https://github.com/google/sentencepiece),
|
6
9
|
an unsupervised text tokenizer and detokenizer for neural network-based text generation.
|
7
10
|
|
8
|
-
It is still under development and may undergo many changes in the future.
|
9
|
-
|
10
11
|
## Installation
|
11
12
|
|
12
13
|
Install SentencePiece using your OS package manager;
|
@@ -22,4 +22,5 @@ extern "C" void Init_sentencepiece(void) {
|
|
22
22
|
rb_mSentencePiece = rb_define_module("SentencePiece");
|
23
23
|
rb_eSentencePieceError = rb_define_class_under(rb_mSentencePiece, "Error", rb_eRuntimeError);
|
24
24
|
RbSentencePieceProcessor::define_class(rb_mSentencePiece);
|
25
|
+
RbSentencePieceTrainer::define_class(rb_mSentencePiece);
|
25
26
|
}
|
@@ -27,6 +27,7 @@
|
|
27
27
|
// Ruby-side handles shared by the bindings in this header: the SentencePiece
// module and its Error class (created in Init_sentencepiece), and the two
// wrapped classes (assigned by the respective define_class functions).
// NOTE(review): these are non-extern definitions in a header; including it
// from more than one translation unit would produce duplicate symbols.
VALUE rb_mSentencePiece;
VALUE rb_eSentencePieceError;
VALUE rb_cSentencePieceProcessor;
VALUE rb_cSentencePieceTrainer;
|
30
31
|
|
31
32
|
class RbSentencePieceProcessor {
|
32
33
|
public:
|
@@ -57,10 +58,27 @@ public:
|
|
57
58
|
rb_define_method(rb_cSentencePieceProcessor, "initialize", RUBY_METHOD_FUNC(_sentencepiece_processor_init), -1);
|
58
59
|
rb_define_method(rb_cSentencePieceProcessor, "load", RUBY_METHOD_FUNC(_sentencepiece_processor_load), 1);
|
59
60
|
rb_define_method(rb_cSentencePieceProcessor, "encode", RUBY_METHOD_FUNC(_sentencepiece_processor_encode), -1);
|
61
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_ids), 1);
|
62
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_pieces), 1);
|
63
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_pieces), -1);
|
64
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_ids), -1);
|
65
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_pieces), -1);
|
66
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_ids), -1);
|
60
67
|
rb_define_method(rb_cSentencePieceProcessor, "decode", RUBY_METHOD_FUNC(_sentencepiece_processor_decode), -1);
|
68
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces), 1);
|
69
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids), 1);
|
70
|
+
rb_define_method(rb_cSentencePieceProcessor, "encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_encode_as_serialized_proto), 1);
|
71
|
+
rb_define_method(rb_cSentencePieceProcessor, "sample_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_sample_encode_as_serialized_proto), -1);
|
72
|
+
rb_define_method(rb_cSentencePieceProcessor, "nbest_encode_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_nbest_encode_as_serialized_proto), -1);
|
73
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_pieces_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_pieces_as_serialized_proto), 1);
|
74
|
+
rb_define_method(rb_cSentencePieceProcessor, "decode_ids_as_serialized_proto", RUBY_METHOD_FUNC(_sentencepiece_processor_decode_ids_as_serialized_proto), 1);
|
61
75
|
rb_define_method(rb_cSentencePieceProcessor, "piece_size", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_size), 0);
|
62
76
|
rb_define_method(rb_cSentencePieceProcessor, "piece_to_id", RUBY_METHOD_FUNC(_sentencepiece_processor_piece_to_id), 1);
|
63
77
|
rb_define_method(rb_cSentencePieceProcessor, "id_to_piece", RUBY_METHOD_FUNC(_sentencepiece_processor_id_to_piece), 1);
|
78
|
+
rb_define_method(rb_cSentencePieceProcessor, "unk_id", RUBY_METHOD_FUNC(_sentencepiece_processor_unk_id), 0);
|
79
|
+
rb_define_method(rb_cSentencePieceProcessor, "bos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_bos_id), 0);
|
80
|
+
rb_define_method(rb_cSentencePieceProcessor, "eos_id", RUBY_METHOD_FUNC(_sentencepiece_processor_eos_id), 0);
|
81
|
+
rb_define_method(rb_cSentencePieceProcessor, "pad_id", RUBY_METHOD_FUNC(_sentencepiece_processor_pad_id), 0);
|
64
82
|
return rb_cSentencePieceProcessor;
|
65
83
|
};
|
66
84
|
|
@@ -193,6 +211,182 @@ private:
|
|
193
211
|
return output;
|
194
212
|
};
|
195
213
|
|
214
|
+
/**
 * Implements SentencePieceProcessor#encode_as_ids(text).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param text String to tokenize.
 * @return Array of Integer ids produced by EncodeAsIds.
 * @raise ArgumentError when text is not a String.
 */
static VALUE _sentencepiece_processor_encode_as_ids(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<int> token_ids = processor->EncodeAsIds(StringValueCStr(text));

  VALUE result = rb_ary_new();
  for (size_t pos = 0; pos < token_ids.size(); ++pos) {
    rb_ary_push(result, INT2NUM(token_ids[pos]));
  }

  // text's C buffer was handed to EncodeAsIds; keep the String reachable until here.
  RB_GC_GUARD(text);
  return result;
}
|
230
|
+
|
231
|
+
/**
 * Implements SentencePieceProcessor#encode_as_pieces(text).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param text String to tokenize.
 * @return Array of UTF-8 Strings, one per piece produced by EncodeAsPieces.
 * @raise ArgumentError when text is not a String.
 */
static VALUE _sentencepiece_processor_encode_as_pieces(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::string> tokens = processor->EncodeAsPieces(StringValueCStr(text));

  VALUE result = rb_ary_new();
  for (std::vector<std::string>::const_iterator it = tokens.begin(); it != tokens.end(); ++it) {
    rb_ary_push(result, rb_utf8_str_new_cstr(it->c_str()));
  }

  // text's C buffer was handed to EncodeAsPieces; keep the String reachable until here.
  RB_GC_GUARD(text);
  return result;
}
|
247
|
+
|
248
|
+
/**
 * Implements SentencePieceProcessor#nbest_encode_as_pieces(text, nbest_size:).
 *
 * @param argc/argv one positional String plus keyword arguments; only
 *   nbest_size: is accepted.
 * @param self the SentencePieceProcessor receiver.
 * @return Array of candidate segmentations from NBestEncodeAsPieces, each an
 *   Array of UTF-8 piece Strings.
 * @raise ArgumentError when text is not a String, or nbest_size is not an
 *   Integer (this includes nbest_size being omitted, which leaves it Qundef).
 */
static VALUE _sentencepiece_processor_nbest_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  ID keyword_ids[1] = { rb_intern("nbest_size") };
  VALUE keyword_vals[1] = { Qundef };
  VALUE text = Qnil;
  VALUE opts = Qnil;

  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, keyword_ids, 1, 0, keyword_vals);
  const VALUE nbest_arg = keyword_vals[0];

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::vector<std::string>> segmentations =
      processor->NBestEncodeAsPieces(StringValueCStr(text), nbest_size);

  VALUE candidates = rb_ary_new();
  for (auto seg = segmentations.cbegin(); seg != segmentations.cend(); ++seg) {
    VALUE row = rb_ary_new();
    for (auto piece = seg->cbegin(); piece != seg->cend(); ++piece) {
      rb_ary_push(row, rb_utf8_str_new_cstr(piece->c_str()));
    }
    rb_ary_push(candidates, row);
  }

  // text's C buffer was handed to NBestEncodeAsPieces; keep it reachable until here.
  RB_GC_GUARD(text);
  return candidates;
}
|
282
|
+
|
283
|
+
/**
 * Implements SentencePieceProcessor#nbest_encode_as_ids(text, nbest_size:).
 *
 * @param argc/argv one positional String plus keyword arguments; only
 *   nbest_size: is accepted.
 * @param self the SentencePieceProcessor receiver.
 * @return Array of candidate id sequences from NBestEncodeAsIds, each an
 *   Array of Integers.
 * @raise ArgumentError when text is not a String, or nbest_size is not an
 *   Integer (this includes nbest_size being omitted, which leaves it Qundef).
 */
static VALUE _sentencepiece_processor_nbest_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const std::vector<std::vector<int>> id_set = ptr->NBestEncodeAsIds(StringValueCStr(text), nbest_size);

  VALUE output = rb_ary_new();
  // Bind by const reference so each candidate vector is not copied per iteration.
  for (const std::vector<int>& ids : id_set) {
    VALUE sub_output = rb_ary_new();
    for (const int idx : ids) {
      rb_ary_push(sub_output, INT2NUM(idx));
    }
    rb_ary_push(output, sub_output);
  }

  // text's C buffer was handed to NBestEncodeAsIds; keep it reachable until here.
  RB_GC_GUARD(text);
  return output;
};
|
317
|
+
|
318
|
+
/**
 * Implements SentencePieceProcessor#sample_encode_as_pieces(text, nbest_size:, alpha:).
 *
 * @param argc/argv one positional String plus the keywords nbest_size: and alpha:.
 * @param self the SentencePieceProcessor receiver.
 * @return Array of UTF-8 piece Strings from SampleEncodeAsPieces.
 * @raise ArgumentError when text is not a String, nbest_size is not an
 *   Integer, or alpha is neither an Integer nor a Float (omitted keywords
 *   stay Qundef and therefore fail these checks too).
 */
static VALUE _sentencepiece_processor_sample_encode_as_pieces(int argc, VALUE* argv, VALUE self) {
  ID keyword_ids[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE keyword_vals[2] = { Qundef, Qundef };
  VALUE text = Qnil;
  VALUE opts = Qnil;

  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, keyword_ids, 2, 0, keyword_vals);
  const VALUE nbest_arg = keyword_vals[0];
  const VALUE alpha_arg = keyword_vals[1];

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!(RB_INTEGER_TYPE_P(alpha_arg) || RB_FLOAT_TYPE_P(alpha_arg))) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  // SentencePiece takes alpha as float, so the Ruby Float is narrowed here.
  const float alpha = static_cast<float>(NUM2DBL(alpha_arg));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<std::string> tokens = processor->SampleEncodeAsPieces(StringValueCStr(text), nbest_size, alpha);

  VALUE result = rb_ary_new();
  for (std::vector<std::string>::const_iterator it = tokens.begin(); it != tokens.end(); ++it) {
    rb_ary_push(result, rb_utf8_str_new_cstr(it->c_str()));
  }

  // text's C buffer was handed to SampleEncodeAsPieces; keep it reachable until here.
  RB_GC_GUARD(text);
  return result;
}
|
353
|
+
|
354
|
+
/**
 * Implements SentencePieceProcessor#sample_encode_as_ids(text, nbest_size:, alpha:).
 *
 * @param argc/argv one positional String plus the keywords nbest_size: and alpha:.
 * @param self the SentencePieceProcessor receiver.
 * @return Array of Integer ids from SampleEncodeAsIds.
 * @raise ArgumentError when text is not a String, nbest_size is not an
 *   Integer, or alpha is neither an Integer nor a Float (omitted keywords
 *   stay Qundef and therefore fail these checks too).
 */
static VALUE _sentencepiece_processor_sample_encode_as_ids(int argc, VALUE* argv, VALUE self) {
  ID keyword_ids[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE keyword_vals[2] = { Qundef, Qundef };
  VALUE text = Qnil;
  VALUE opts = Qnil;

  rb_scan_args(argc, argv, "1:", &text, &opts);
  rb_get_kwargs(opts, keyword_ids, 2, 0, keyword_vals);
  const VALUE nbest_arg = keyword_vals[0];
  const VALUE alpha_arg = keyword_vals[1];

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(nbest_arg)) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!(RB_INTEGER_TYPE_P(alpha_arg) || RB_FLOAT_TYPE_P(alpha_arg))) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(nbest_arg);
  // SentencePiece takes alpha as float, so the Ruby Float is narrowed here.
  const float alpha = static_cast<float>(NUM2DBL(alpha_arg));
  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::vector<int> token_ids = processor->SampleEncodeAsIds(StringValueCStr(text), nbest_size, alpha);

  VALUE result = rb_ary_new();
  for (size_t pos = 0; pos < token_ids.size(); ++pos) {
    rb_ary_push(result, INT2NUM(token_ids[pos]));
  }

  // text's C buffer was handed to SampleEncodeAsIds; keep it reachable until here.
  RB_GC_GUARD(text);
  return result;
}
|
389
|
+
|
196
390
|
static VALUE _sentencepiece_processor_decode(int argc, VALUE* argv, VALUE self) {
|
197
391
|
VALUE kw_args = Qnil;
|
198
392
|
ID kw_table[1] = { rb_intern("out_type") };
|
@@ -204,7 +398,7 @@ private:
|
|
204
398
|
VALUE out_type = kw_values[0] != Qundef ? kw_values[0] : rb_str_new_cstr("int");
|
205
399
|
|
206
400
|
if (!RB_TYPE_P(pieces, T_ARRAY)) {
|
207
|
-
rb_raise(rb_eArgError, "expected
|
401
|
+
rb_raise(rb_eArgError, "expected pieces to be an Array");
|
208
402
|
return Qnil;
|
209
403
|
}
|
210
404
|
if (strcmp(StringValueCStr(out_type), "str") != 0 && strcmp(StringValueCStr(out_type), "int") != 0) {
|
@@ -232,7 +426,7 @@ private:
|
|
232
426
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
233
427
|
return Qfalse;
|
234
428
|
}
|
235
|
-
output =
|
429
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
236
430
|
} else {
|
237
431
|
output = rb_ary_new();
|
238
432
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -249,7 +443,7 @@ private:
|
|
249
443
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
250
444
|
return Qfalse;
|
251
445
|
}
|
252
|
-
rb_ary_push(output,
|
446
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
253
447
|
}
|
254
448
|
}
|
255
449
|
} else {
|
@@ -265,7 +459,7 @@ private:
|
|
265
459
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
266
460
|
return Qfalse;
|
267
461
|
}
|
268
|
-
output =
|
462
|
+
output = rb_utf8_str_new_cstr(text.c_str());
|
269
463
|
} else {
|
270
464
|
output = rb_ary_new();
|
271
465
|
for (size_t i = 0; i < n_pieces; i++) {
|
@@ -282,7 +476,7 @@ private:
|
|
282
476
|
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
283
477
|
return Qfalse;
|
284
478
|
}
|
285
|
-
rb_ary_push(output,
|
479
|
+
rb_ary_push(output, rb_utf8_str_new_cstr(text.c_str()));
|
286
480
|
}
|
287
481
|
}
|
288
482
|
}
|
@@ -290,6 +484,159 @@ private:
|
|
290
484
|
return output;
|
291
485
|
};
|
292
486
|
|
487
|
+
/**
 * Implements SentencePieceProcessor#decode_pieces(pieces).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param pieces Array whose elements are read with StringValueCStr.
 * @return UTF-8 String produced by DecodePieces.
 * @raise ArgumentError when pieces is not an Array.
 */
static VALUE _sentencepiece_processor_decode_pieces(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  const long count = RARRAY_LEN(pieces);
  std::vector<std::string> tokens;
  tokens.reserve(static_cast<size_t>(count));
  for (long i = 0; i < count; ++i) {
    VALUE entry = rb_ary_entry(pieces, i);
    tokens.emplace_back(StringValueCStr(entry));
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::string decoded = processor->DecodePieces(tokens);
  return rb_utf8_str_new_cstr(decoded.c_str());
}
|
506
|
+
|
507
|
+
/**
 * Implements SentencePieceProcessor#decode_ids(ids).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param ids Array whose elements are converted with NUM2INT.
 * @return UTF-8 String produced by DecodeIds.
 * @raise ArgumentError when ids is not an Array.
 */
static VALUE _sentencepiece_processor_decode_ids(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  const long count = RARRAY_LEN(ids);
  std::vector<int> id_list;
  id_list.reserve(static_cast<size_t>(count));
  for (long i = 0; i < count; ++i) {
    VALUE entry = rb_ary_entry(ids, i);
    id_list.push_back(NUM2INT(entry));
  }

  sentencepiece::SentencePieceProcessor* processor = get_sentencepiece_processor(self);
  const std::string decoded = processor->DecodeIds(id_list);
  return rb_utf8_str_new_cstr(decoded.c_str());
}
|
526
|
+
|
527
|
+
/**
 * Implements SentencePieceProcessor#encode_as_serialized_proto(text).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param text String to tokenize.
 * @return String (no encoding, i.e. ASCII-8BIT) holding every byte returned
 *   by EncodeAsSerializedProto.
 * @raise ArgumentError when text is not a String.
 */
static VALUE _sentencepiece_processor_encode_as_serialized_proto(VALUE self, VALUE text) {
  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->EncodeAsSerializedProto(StringValueCStr(text));
  // A serialized message is binary and can contain NUL bytes; copy by explicit
  // length, because a C-string copy would stop at the first NUL and truncate it.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  // text's C buffer was handed to EncodeAsSerializedProto; keep it reachable until here.
  RB_GC_GUARD(text);
  return output;
};
|
540
|
+
|
541
|
+
/**
 * Implements SentencePieceProcessor#sample_encode_as_serialized_proto(text, nbest_size:, alpha:).
 *
 * @param argc/argv one positional String plus the keywords nbest_size: and alpha:.
 * @param self the SentencePieceProcessor receiver.
 * @return String (no encoding, i.e. ASCII-8BIT) holding every byte returned
 *   by SampleEncodeAsSerializedProto.
 * @raise ArgumentError when text is not a String, nbest_size is not an
 *   Integer, or alpha is neither an Integer nor a Float (omitted keywords
 *   stay Qundef and therefore fail these checks too).
 */
static VALUE _sentencepiece_processor_sample_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("nbest_size"), rb_intern("alpha") };
  VALUE kw_values[2] = { Qundef, Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1]) && !RB_FLOAT_TYPE_P(kw_values[1])) {
    rb_raise(rb_eArgError, "expected alpha to be a Float");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  // SentencePiece takes alpha as float, so the Ruby Float is narrowed here.
  const float alpha = static_cast<float>(NUM2DBL(kw_values[1]));
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->SampleEncodeAsSerializedProto(StringValueCStr(text), nbest_size, alpha);
  // A serialized message is binary and can contain NUL bytes; copy by explicit
  // length, because a C-string copy would stop at the first NUL and truncate it.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  // text's C buffer was handed to SampleEncodeAsSerializedProto; keep it reachable until here.
  RB_GC_GUARD(text);
  return output;
};
|
572
|
+
|
573
|
+
/**
 * Implements SentencePieceProcessor#nbest_encode_as_serialized_proto(text, nbest_size:).
 *
 * @param argc/argv one positional String plus keyword arguments; only
 *   nbest_size: is accepted.
 * @param self the SentencePieceProcessor receiver.
 * @return String (no encoding, i.e. ASCII-8BIT) holding every byte returned
 *   by NBestEncodeAsSerializedProto.
 * @raise ArgumentError when text is not a String, or nbest_size is not an
 *   Integer (this includes nbest_size being omitted, which leaves it Qundef).
 */
static VALUE _sentencepiece_processor_nbest_encode_as_serialized_proto(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("nbest_size") };
  VALUE kw_values[1] = { Qundef };

  VALUE text = Qnil;
  rb_scan_args(argc, argv, "1:", &text, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_TYPE_P(text, T_STRING)) {
    rb_raise(rb_eArgError, "expected text to be a String");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "expected nbest_size to be an Integer");
    return Qnil;
  }

  const int nbest_size = NUM2INT(kw_values[0]);
  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->NBestEncodeAsSerializedProto(StringValueCStr(text), nbest_size);
  // A serialized message is binary and can contain NUL bytes; copy by explicit
  // length, because a C-string copy would stop at the first NUL and truncate it.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  // text's C buffer was handed to NBestEncodeAsSerializedProto; keep it reachable until here.
  RB_GC_GUARD(text);
  return output;
};
|
599
|
+
|
600
|
+
/**
 * Implements SentencePieceProcessor#decode_pieces_as_serialized_proto(pieces).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param pieces Array whose elements are read with StringValueCStr.
 * @return String (no encoding, i.e. ASCII-8BIT) holding every byte returned
 *   by DecodePiecesAsSerializedProto.
 * @raise ArgumentError when pieces is not an Array.
 */
static VALUE _sentencepiece_processor_decode_pieces_as_serialized_proto(VALUE self, VALUE pieces) {
  if (!RB_TYPE_P(pieces, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected pieces to be an Array");
    return Qnil;
  }

  std::vector<std::string> pcs;
  const size_t n_pieces = RARRAY_LEN(pieces);
  for (size_t i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(pieces, i);
    pcs.push_back(StringValueCStr(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->DecodePiecesAsSerializedProto(pcs);
  // A serialized message is binary and can contain NUL bytes; copy by explicit
  // length, because a C-string copy would stop at the first NUL and truncate it.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  return output;
};
|
619
|
+
|
620
|
+
/**
 * Implements SentencePieceProcessor#decode_ids_as_serialized_proto(ids).
 *
 * @param self the SentencePieceProcessor receiver.
 * @param ids Array whose elements are converted with NUM2INT.
 * @return String (no encoding, i.e. ASCII-8BIT) holding every byte returned
 *   by DecodeIdsAsSerializedProto.
 * @raise ArgumentError when ids is not an Array.
 */
static VALUE _sentencepiece_processor_decode_ids_as_serialized_proto(VALUE self, VALUE ids) {
  if (!RB_TYPE_P(ids, T_ARRAY)) {
    rb_raise(rb_eArgError, "expected ids to be an Array");
    return Qnil;
  }

  std::vector<int> pcs;
  const size_t n_pieces = RARRAY_LEN(ids);
  for (size_t i = 0; i < n_pieces; i++) {
    VALUE et = rb_ary_entry(ids, i);
    pcs.push_back(NUM2INT(et));
  }

  sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
  const sentencepiece::util::bytes serialized = ptr->DecodeIdsAsSerializedProto(pcs);
  // A serialized message is binary and can contain NUL bytes; copy by explicit
  // length, because a C-string copy would stop at the first NUL and truncate it.
  VALUE output = rb_str_new(serialized.data(), static_cast<long>(serialized.size()));

  return output;
};
|
639
|
+
|
293
640
|
static VALUE _sentencepiece_processor_piece_size(VALUE self) {
|
294
641
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
295
642
|
return INT2NUM(ptr->GetPieceSize());
|
@@ -325,19 +672,49 @@ private:
|
|
325
672
|
VALUE output = Qnil;
|
326
673
|
sentencepiece::SentencePieceProcessor* ptr = get_sentencepiece_processor(self);
|
327
674
|
if (RB_INTEGER_TYPE_P(ids)) {
|
328
|
-
const
|
675
|
+
const int idx = NUM2INT(ids);
|
676
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
677
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
678
|
+
return Qnil;
|
679
|
+
}
|
680
|
+
const std::string piece = ptr->IdToPiece(idx);
|
329
681
|
output = rb_utf8_str_new_cstr(piece.c_str());
|
330
682
|
} else {
|
331
683
|
const size_t n_ids = RARRAY_LEN(ids);
|
332
684
|
output = rb_ary_new();
|
333
685
|
for (size_t i = 0; i < n_ids; i++) {
|
334
686
|
VALUE et = rb_ary_entry(ids, i);
|
335
|
-
const
|
687
|
+
const int idx = NUM2INT(et);
|
688
|
+
if (idx < 0 || idx >= ptr->GetPieceSize()) {
|
689
|
+
rb_raise(rb_eIndexError, "piece id %i is out of range", idx);
|
690
|
+
return Qnil;
|
691
|
+
}
|
692
|
+
const std::string piece = ptr->IdToPiece(idx);
|
336
693
|
rb_ary_push(output, rb_utf8_str_new_cstr(piece.c_str()));
|
337
694
|
}
|
338
695
|
}
|
339
696
|
return output;
|
340
697
|
};
|
698
|
+
|
699
|
+
/** Implements SentencePieceProcessor#unk_id; returns the processor's unk_id() as an Integer. */
static VALUE _sentencepiece_processor_unk_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->unk_id();
  return INT2NUM(id);
}
|
703
|
+
|
704
|
+
/** Implements SentencePieceProcessor#bos_id; returns the processor's bos_id() as an Integer. */
static VALUE _sentencepiece_processor_bos_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->bos_id();
  return INT2NUM(id);
}
|
708
|
+
|
709
|
+
/** Implements SentencePieceProcessor#eos_id; returns the processor's eos_id() as an Integer. */
static VALUE _sentencepiece_processor_eos_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->eos_id();
  return INT2NUM(id);
}
|
713
|
+
|
714
|
+
/** Implements SentencePieceProcessor#pad_id; returns the processor's pad_id() as an Integer. */
static VALUE _sentencepiece_processor_pad_id(VALUE self) {
  const int id = get_sentencepiece_processor(self)->pad_id();
  return INT2NUM(id);
}
|
341
718
|
};
|
342
719
|
|
343
720
|
const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
@@ -350,4 +727,42 @@ const rb_data_type_t RbSentencePieceProcessor::sentencepiece_processor_type = {
|
|
350
727
|
RUBY_TYPED_FREE_IMMEDIATELY
|
351
728
|
};
|
352
729
|
|
730
|
+
namespace sentencepiece {
|
731
|
+
class SentencePieceTrainerWrapper {
|
732
|
+
public:
|
733
|
+
SentencePieceTrainerWrapper(){};
|
734
|
+
~SentencePieceTrainerWrapper(){};
|
735
|
+
|
736
|
+
util::Status Train(absl::string_view args) {
|
737
|
+
return SentencePieceTrainer::Train(args);
|
738
|
+
};
|
739
|
+
};
|
740
|
+
} // namespace sentencepiece
|
741
|
+
|
742
|
+
class RbSentencePieceTrainer {
|
743
|
+
public:
|
744
|
+
static VALUE define_class(VALUE outer) {
|
745
|
+
rb_cSentencePieceTrainer = rb_define_class_under(outer, "SentencePieceTrainer", rb_cObject);
|
746
|
+
rb_define_singleton_method(rb_cSentencePieceTrainer, "train", RUBY_METHOD_FUNC(_sentencepiece_trainer_train), 1);
|
747
|
+
return rb_cSentencePieceTrainer;
|
748
|
+
};
|
749
|
+
|
750
|
+
private:
|
751
|
+
static VALUE _sentencepiece_trainer_train(VALUE self, VALUE args) {
|
752
|
+
if (!RB_TYPE_P(args, T_STRING)) {
|
753
|
+
rb_raise(rb_eArgError, "expected args to be a String");
|
754
|
+
return Qnil;
|
755
|
+
}
|
756
|
+
|
757
|
+
const sentencepiece::util::Status status = sentencepiece::SentencePieceTrainer::Train(StringValueCStr(args));
|
758
|
+
if (!status.ok()) {
|
759
|
+
rb_raise(rb_eSentencePieceError, "%s", status.message());
|
760
|
+
return Qnil;
|
761
|
+
}
|
762
|
+
|
763
|
+
RB_GC_GUARD(args);
|
764
|
+
return Qnil;
|
765
|
+
};
|
766
|
+
};
|
767
|
+
|
353
768
|
#endif /* SENTENCEPIECE_HPP */
|
data/sig/sentencepiece.rbs
CHANGED
@@ -1,4 +1,63 @@
|
|
1
1
|
# Type signatures for the sentencepiece.rb native extension.
module SentencePiece
  VERSION: String

  # Binding for SentencePieceTrainer; +train+ forwards its argument string to the
  # trainer and raises SentencePiece::Error when training fails.
  class SentencePieceTrainer
    def self.train: (String args) -> void
  end

  # Tokenizer / detokenizer backed by a SentencePiece model file.
  class SentencePieceProcessor
    def initialize: (?model_file: String model_file) -> void

    # Loads a model file into this processor.
    def load: (String model_file) -> void

    # Encodes one String, or each String of an Array; +out_type+ chooses between
    # the Integer and String result forms.
    def encode: (String text, ?out_type: String out_type) -> (Array[Integer] | Array[String])
              | (Array[String] text, ?out_type: String out_type) -> (Array[Array[Integer]] | Array[Array[String]])

    # Single-best encoding.
    def encode_as_ids: (String text) -> Array[Integer]

    def encode_as_pieces: (String text) -> Array[String]

    def encode_as_serialized_proto: (String text) -> String

    # N-best encoding: one inner Array per candidate.
    def nbest_encode_as_ids: (String text, nbest_size: Integer nbest_size) -> Array[Array[Integer]]

    def nbest_encode_as_pieces: (String text, nbest_size: Integer nbest_size) -> Array[Array[String]]

    def nbest_encode_as_serialized_proto: (String text, nbest_size: Integer nbest_size) -> String

    # Sampled encoding controlled by nbest_size and alpha.
    def sample_encode_as_ids: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> Array[Integer]

    def sample_encode_as_pieces: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> Array[String]

    def sample_encode_as_serialized_proto: (String text, nbest_size: Integer nbest_size, alpha: Float alpha) -> String

    # Decodes ids or pieces (single sequence or batch) back into text.
    def decode: (Array[Integer] ids, ?out_type: String out_type) -> String
              | (Array[Array[Integer]] ids, ?out_type: String out_type) -> Array[String]
              | (Array[String] pieces, ?out_type: String out_type) -> String
              | (Array[Array[String]] pieces, ?out_type: String out_type) -> Array[String]

    def decode_ids: (Array[Integer] ids) -> String

    def decode_ids_as_serialized_proto: (Array[Integer] ids) -> String

    def decode_pieces: (Array[String] pieces) -> String

    def decode_pieces_as_serialized_proto: (Array[String] pieces) -> String

    # Vocabulary lookups.
    def id_to_piece: (Integer id) -> String
                   | (Array[Integer] ids) -> Array[String]

    def piece_to_id: (String piece) -> Integer
                   | (Array[String] pieces) -> Array[Integer]

    def piece_size: () -> Integer

    # Special-symbol ids of the loaded model.
    def bos_id: () -> Integer

    def eos_id: () -> Integer

    def pad_id: () -> Integer

    def unk_id: () -> Integer
  end
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: sentencepiece
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.0
|
4
|
+
version: 0.1.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-03-
|
11
|
+
date: 2023-03-26 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: |
|
14
14
|
sentencepiece.rb provides Ruby bindings for the SentencePiece,
|
@@ -36,6 +36,7 @@ metadata:
|
|
36
36
|
homepage_uri: https://github.com/yoshoku/sentencepiece.rb
|
37
37
|
source_code_uri: https://github.com/yoshoku/sentencepiece.rb
|
38
38
|
changelog_uri: https://github.com/yoshoku/sentencepiece.rb/blob/main/CHANGELOG.md
|
39
|
+
documentation_uri: https://yoshoku.github.io/sentencepiece.rb/doc/
|
39
40
|
rubygems_mfa_required: 'true'
|
40
41
|
post_install_message:
|
41
42
|
rdoc_options: []
|