react-native-executorch 0.7.0 → 0.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/common/rnexecutorch/TokenizerModule.cpp +3 -2
- package/common/rnexecutorch/TokenizerModule.h +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +8 -4
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -1
- package/package.json +4 -3
- package/src/modules/computer_vision/TextToImageModule.ts +9 -4
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_model.h +84 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/bpe_tokenizer_base.h +6 -87
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/hf_tokenizer.h +28 -176
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/map_utils.h +174 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/model.h +151 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/normalizer.h +55 -1
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/padding.h +112 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/post_processor.h +101 -42
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/pre_tokenizer.h +25 -9
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/token_decoder.h +33 -6
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/tokenizer.h +2 -2
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/truncation.h +92 -0
- package/third-party/include/executorch/extension/llm/tokenizers/include/pytorch/tokenizers/wordpiece_model.h +74 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +0 -253
- package/common/rnexecutorch/tests/README.md +0 -73
- package/common/rnexecutorch/tests/integration/BaseModelTest.cpp +0 -207
- package/common/rnexecutorch/tests/integration/BaseModelTests.h +0 -120
- package/common/rnexecutorch/tests/integration/ClassificationTest.cpp +0 -117
- package/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +0 -122
- package/common/rnexecutorch/tests/integration/ImageSegmentationTest.cpp +0 -152
- package/common/rnexecutorch/tests/integration/LLMTest.cpp +0 -155
- package/common/rnexecutorch/tests/integration/OCRTest.cpp +0 -128
- package/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +0 -135
- package/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +0 -97
- package/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +0 -112
- package/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +0 -164
- package/common/rnexecutorch/tests/integration/TextToImageTest.cpp +0 -149
- package/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +0 -98
- package/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +0 -238
- package/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +0 -99
- package/common/rnexecutorch/tests/integration/assets/test_audio_float.raw +0 -0
- package/common/rnexecutorch/tests/integration/assets/we_are_software_mansion.jpg +0 -0
- package/common/rnexecutorch/tests/integration/libs/libfbjni.so +0 -0
- package/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +0 -45
- package/common/rnexecutorch/tests/integration/utils/TestUtils.h +0 -36
- package/common/rnexecutorch/tests/run_tests.sh +0 -333
- package/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +0 -32
- package/common/rnexecutorch/tests/unit/LogTest.cpp +0 -529
- package/common/rnexecutorch/tests/unit/NumericalTest.cpp +0 -107
|
@@ -13,8 +13,7 @@ using namespace executorch::extension::constants;
|
|
|
13
13
|
|
|
14
14
|
TokenizerModule::TokenizerModule(
|
|
15
15
|
std::string source, std::shared_ptr<react::CallInvoker> callInvoker)
|
|
16
|
-
: tokenizer(std::make_unique<tokenizers::HFTokenizer>())
|
|
17
|
-
memorySizeLowerBound(std::filesystem::file_size(source)) {
|
|
16
|
+
: tokenizer(std::make_unique<tokenizers::HFTokenizer>()) {
|
|
18
17
|
|
|
19
18
|
auto status = tokenizer->load(source);
|
|
20
19
|
|
|
@@ -22,6 +21,8 @@ TokenizerModule::TokenizerModule(
|
|
|
22
21
|
throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError,
|
|
23
22
|
"Unexpected issue occured while loading tokenizer");
|
|
24
23
|
};
|
|
24
|
+
std::filesystem::path modelPath{source};
|
|
25
|
+
memorySizeLowerBound = std::filesystem::file_size(modelPath);
|
|
25
26
|
}
|
|
26
27
|
|
|
27
28
|
void TokenizerModule::ensureTokenizerLoaded(
|
|
@@ -26,7 +26,7 @@ public:
|
|
|
26
26
|
private:
|
|
27
27
|
void ensureTokenizerLoaded(const std::string &methodName) const;
|
|
28
28
|
std::unique_ptr<tokenizers::HFTokenizer> tokenizer;
|
|
29
|
-
|
|
29
|
+
std::size_t memorySizeLowerBound{0};
|
|
30
30
|
};
|
|
31
31
|
|
|
32
32
|
REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import { ResourceFetcher } from '../../utils/ResourceFetcher';
|
|
4
4
|
import { BaseModule } from '../BaseModule';
|
|
5
|
-
import { Buffer } from 'buffer';
|
|
6
5
|
import { PNG } from 'pngjs/browser';
|
|
7
6
|
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
|
|
8
7
|
import { RnExecutorchError } from '../../errors/errorUtils';
|
|
@@ -65,12 +64,17 @@ export class TextToImageModule extends BaseModule {
|
|
|
65
64
|
width: imageSize,
|
|
66
65
|
height: imageSize
|
|
67
66
|
});
|
|
68
|
-
png.data =
|
|
67
|
+
png.data = outputArray;
|
|
69
68
|
const pngBuffer = PNG.sync.write(png, {
|
|
70
69
|
colorType: 6
|
|
71
70
|
});
|
|
72
|
-
const
|
|
73
|
-
|
|
71
|
+
const pngArray = new Uint8Array(pngBuffer);
|
|
72
|
+
let binary = '';
|
|
73
|
+
const chunkSize = 8192;
|
|
74
|
+
for (let i = 0; i < pngArray.length; i += chunkSize) {
|
|
75
|
+
binary += String.fromCharCode(...pngArray.subarray(i, i + chunkSize));
|
|
76
|
+
}
|
|
77
|
+
return btoa(binary);
|
|
74
78
|
}
|
|
75
79
|
|
|
76
80
|
/**
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"names":["ResourceFetcher","BaseModule","
|
|
1
|
+
{"version":3,"names":["ResourceFetcher","BaseModule","PNG","RnExecutorchErrorCode","RnExecutorchError","TextToImageModule","constructor","inferenceCallback","stepIdx","load","model","onDownloadProgressCallback","results","fetch","tokenizerSource","schedulerSource","encoderSource","unetSource","decoderSource","DownloadInterrupted","tokenizerPath","schedulerPath","encoderPath","unetPath","decoderPath","response","schedulerConfig","json","nativeModule","global","loadTextToImage","beta_start","beta_end","num_train_timesteps","steps_offset","forward","input","imageSize","numSteps","seed","output","generate","outputArray","Uint8Array","length","png","width","height","data","pngBuffer","sync","write","colorType","pngArray","binary","chunkSize","i","String","fromCharCode","subarray","btoa","interrupt"],"sourceRoot":"../../../../src","sources":["modules/computer_vision/TextToImageModule.ts"],"mappings":";;AAAA,SAASA,eAAe,QAAQ,6BAA6B;AAE7D,SAASC,UAAU,QAAQ,eAAe;AAE1C,SAASC,GAAG,QAAQ,eAAe;AACnC,SAASC,qBAAqB,QAAQ,yBAAyB;AAC/D,SAASC,iBAAiB,QAAQ,yBAAyB;;AAE3D;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,iBAAiB,SAASJ,UAAU,CAAC;EAGhD;AACF;AACA;AACA;AACA;EACEK,WAAWA,CAACC,iBAA6C,EAAE;IACzD,KAAK,CAAC,CAAC;IACP,IAAI,CAACA,iBAAiB,GAAIC,OAAe,IAAK;MAC5CD,iBAAiB,GAAGC,OAAO,CAAC;IAC9B,CAAC;EACH;;EAEA;AACF;AACA;AACA;AACA;AACA;EACE,MAAMC,IAAIA,CACRC,KAMC,EACDC,0BAAsD,GAAGA,CAAA,KAAM,CAAC,CAAC,EAClD;IACf,MAAMC,OAAO,GAAG,MAAMZ,eAAe,CAACa,KAAK,CACzCF,0BAA0B,EAC1BD,KAAK,CAACI,eAAe,EACrBJ,KAAK,CAACK,eAAe,EACrBL,KAAK,CAACM,aAAa,EACnBN,KAAK,CAACO,UAAU,EAChBP,KAAK,CAACQ,aACR,CAAC;IACD,IAAI,CAACN,OAAO,EAAE;MACZ,MAAM,IAAIR,iBAAiB,CACzBD,qBAAqB,CAACgB,mBAAmB,EACzC,2GACF,CAAC;IACH;IACA,MAAM,CAACC,aAAa,EAAEC,aAAa,EAAEC,WAAW,EAAEC,QAAQ,EAAEC,WAAW,CAAC,GACtEZ,OAAO;IAET,IACE,CAACQ,aAAa,IACd,CAACC,aAAa,IACd,CAACC,WAAW,IACZ,CAACC,QAAQ,IACT,CAACC,WAAW,EACZ;MACA,MAAM,IAAIpB,iBAAiB,CACzBD,qBAAqB,CAACgB,mBAAmB,EACzC,2GACF,CAAC;IACH;IAEA,MAAMM,QAAQ,GAAG,MAAMZ,KAAK,CAAC,SAAS,GAAGQ,aAAa,CAAC;IACvD,MAAMK,eAAe,GAAG,MAAMD,QAAQ,CAACE,IAAI,CAAC,CAAC;IAE7C,IAAI,CAACC,YAAY,GAAGC,MAAM,CAACC,eAAe,CACxCV,aAAa,EACbE,WAAW,EACXC,QAAQ,EACRC,WAAW,EACXE,eAAe,CAACK,UAAU,EAC1BL,eAAe,CAACM,QAAQ,EACxBN,eAAe,CAACO,mBAAmB,EACnCP,eAAe,CAACQ,YAClB,CAAC;EACH;;EAEA;AACF;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;EACE,MAAMC,OAAOA,CACXC,KAAa,EACbC,SAAiB,GAAG,GAAG,EACvBC,QAAgB,GAAG,CAAC,EACpBC,IAAa,EACI;IACjB,MAAMC,MAAM,GAAG,MAAM,IAAI,CAACZ,YAAY,CAACa,QAAQ,CAC7CL,KAAK,EACLC,SAAS,EACTC,QAAQ,EACRC,IAAI,GAAGA,IAAI,GAAG,CAAC,CAAC,EAChB,IAAI,CAAChC,iBACP,CAAC;IACD,MAAMmC,WAAW,GAAG,IAAIC,UAAU,CAACH,MAAM,CAAC;IAC1C,IAAI,CAACE,WAAW,CAACE,MAAM,EAAE;MACvB,OAAO,EAAE;IACX;IACA,MAAMC,GAAG,GAAG,IAAI3C,GAAG,CAAC;MAAE4C,KAAK,EAAET,SAAS;MAAEU,MAAM,EAAEV;IAAU,CAAC,CAAC;IAC5DQ,GAAG,CAACG,IAAI,GAAGN,WAAgC;IAC3C,MAAMO,SAAS,GAAG/C,GAAG,CAACgD,IAAI,CAACC,KAAK,CAACN,GAAG,EAAE;MAAEO,SAAS,EAAE;IAAE,CAAC,CAAC;IACvD,MAAMC,QAAQ,GAAG,IAAIV,UAAU,CAACM,SAAuC,CAAC;IACxE,IAAIK,MAAM,GAAG,EAAE;IACf,MAAMC,SAAS,GAAG,IAAI;IACtB,KAAK,IAAIC,CAAC,GAAG,CAAC,EAAEA,CAAC,GAAGH,QAAQ,CAACT,MAAM,EAAEY,CAAC,IAAID,SAAS,EAAE;MACnDD,MAAM,IAAIG,MAAM,CAACC,YAAY,CAAC,GAAGL,QAAQ,CAACM,QAAQ,CAACH,CAAC,EAAEA,CAAC,GAAGD,SAAS,CAAC,CAAC;IACvE;IACA,OAAOK,IAAI,CAACN,MAAM,CAAC;EACrB;;EAEA;AACF;AACA;EACSO,SAASA,CAAA,EAAS;IACvB,IAAI,CAACjC,YAAY,CAACiC,SAAS,CAAC,CAAC;EAC/B;AACF","ignoreList":[]}
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"TextToImageModule.d.ts","sourceRoot":"","sources":["../../../../src/modules/computer_vision/TextToImageModule.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,cAAc,EAAE,MAAM,oBAAoB,CAAC;AACpD,OAAO,EAAE,UAAU,EAAE,MAAM,eAAe,CAAC;AAM3C;;;;GAIG;AACH,qBAAa,iBAAkB,SAAQ,UAAU;IAC/C,OAAO,CAAC,iBAAiB,CAA4B;IAErD;;;;OAIG;gBACS,iBAAiB,CAAC,EAAE,CAAC,OAAO,EAAE,MAAM,KAAK,IAAI;IAOzD;;;;;OAKG;IACG,IAAI,CACR,KAAK,EAAE;QACL,eAAe,EAAE,cAAc,CAAC;QAChC,eAAe,EAAE,cAAc,CAAC;QAChC,aAAa,EAAE,cAAc,CAAC;QAC9B,UAAU,EAAE,cAAc,CAAC;QAC3B,aAAa,EAAE,cAAc,CAAC;KAC/B,EACD,0BAA0B,GAAE,CAAC,QAAQ,EAAE,MAAM,KAAK,IAAe,GAChE,OAAO,CAAC,IAAI,CAAC;IA8ChB;;;;;;;;;OASG;IACG,OAAO,CACX,KAAK,EAAE,MAAM,EACb,SAAS,GAAE,MAAY,EACvB,QAAQ,GAAE,MAAU,EACpB,IAAI,CAAC,EAAE,MAAM,GACZ,OAAO,CAAC,MAAM,CAAC;
|
|
1
|
+
{"version":3,"file":"TextToImageModule.d.ts","sourceRoot":"","sources":["../../../../src/modules/computer_vision/TextToImageModule.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,cAAc,EAAE,MAAM,oBAAoB,CAAC;AACpD,OAAO,EAAE,UAAU,EAAE,MAAM,eAAe,CAAC;AAM3C;;;;GAIG;AACH,qBAAa,iBAAkB,SAAQ,UAAU;IAC/C,OAAO,CAAC,iBAAiB,CAA4B;IAErD;;;;OAIG;gBACS,iBAAiB,CAAC,EAAE,CAAC,OAAO,EAAE,MAAM,KAAK,IAAI;IAOzD;;;;;OAKG;IACG,IAAI,CACR,KAAK,EAAE;QACL,eAAe,EAAE,cAAc,CAAC;QAChC,eAAe,EAAE,cAAc,CAAC;QAChC,aAAa,EAAE,cAAc,CAAC;QAC9B,UAAU,EAAE,cAAc,CAAC;QAC3B,aAAa,EAAE,cAAc,CAAC;KAC/B,EACD,0BAA0B,GAAE,CAAC,QAAQ,EAAE,MAAM,KAAK,IAAe,GAChE,OAAO,CAAC,IAAI,CAAC;IA8ChB;;;;;;;;;OASG;IACG,OAAO,CACX,KAAK,EAAE,MAAM,EACb,SAAS,GAAE,MAAY,EACvB,QAAQ,GAAE,MAAU,EACpB,IAAI,CAAC,EAAE,MAAM,GACZ,OAAO,CAAC,MAAM,CAAC;IAwBlB;;OAEG;IACI,SAAS,IAAI,IAAI;CAGzB"}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "react-native-executorch",
|
|
3
|
-
"version": "0.7.
|
|
3
|
+
"version": "0.7.2",
|
|
4
4
|
"description": "An easy way to run AI models in React Native with ExecuTorch",
|
|
5
5
|
"source": "./src/index.ts",
|
|
6
6
|
"main": "./lib/module/index.js",
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
"ios",
|
|
15
15
|
"cpp",
|
|
16
16
|
"common",
|
|
17
|
+
"!common/rnexecutorch/tests",
|
|
17
18
|
"*.podspec",
|
|
18
19
|
"third-party/include",
|
|
19
20
|
"third-party",
|
|
@@ -66,8 +67,8 @@
|
|
|
66
67
|
},
|
|
67
68
|
"peerDependencies": {
|
|
68
69
|
"expo": ">=54.0.0",
|
|
69
|
-
"expo-asset": "
|
|
70
|
-
"expo-file-system": "
|
|
70
|
+
"expo-asset": ">=12.0.0",
|
|
71
|
+
"expo-file-system": ">=19.0.0",
|
|
71
72
|
"react": "*",
|
|
72
73
|
"react-native": "*"
|
|
73
74
|
},
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { ResourceFetcher } from '../../utils/ResourceFetcher';
|
|
2
2
|
import { ResourceSource } from '../../types/common';
|
|
3
3
|
import { BaseModule } from '../BaseModule';
|
|
4
|
-
|
|
4
|
+
|
|
5
5
|
import { PNG } from 'pngjs/browser';
|
|
6
6
|
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
|
|
7
7
|
import { RnExecutorchError } from '../../errors/errorUtils';
|
|
@@ -115,10 +115,15 @@ export class TextToImageModule extends BaseModule {
|
|
|
115
115
|
return '';
|
|
116
116
|
}
|
|
117
117
|
const png = new PNG({ width: imageSize, height: imageSize });
|
|
118
|
-
png.data = Buffer
|
|
118
|
+
png.data = outputArray as unknown as Buffer;
|
|
119
119
|
const pngBuffer = PNG.sync.write(png, { colorType: 6 });
|
|
120
|
-
const
|
|
121
|
-
|
|
120
|
+
const pngArray = new Uint8Array(pngBuffer as unknown as ArrayBufferLike);
|
|
121
|
+
let binary = '';
|
|
122
|
+
const chunkSize = 8192;
|
|
123
|
+
for (let i = 0; i < pngArray.length; i += chunkSize) {
|
|
124
|
+
binary += String.fromCharCode(...pngArray.subarray(i, i + chunkSize));
|
|
125
|
+
}
|
|
126
|
+
return btoa(binary);
|
|
122
127
|
}
|
|
123
128
|
|
|
124
129
|
/**
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
// @lint-ignore-every LICENSELINT
|
|
9
|
+
|
|
10
|
+
#pragma once
|
|
11
|
+
|
|
12
|
+
#include <functional>
|
|
13
|
+
#include <memory>
|
|
14
|
+
#include <optional>
|
|
15
|
+
#include <string>
|
|
16
|
+
#include <vector>
|
|
17
|
+
|
|
18
|
+
#include <pytorch/tokenizers/map_utils.h>
|
|
19
|
+
#include <pytorch/tokenizers/model.h>
|
|
20
|
+
#include <pytorch/tokenizers/regex.h>
|
|
21
|
+
#include <pytorch/tokenizers/result.h>
|
|
22
|
+
#include <pytorch/tokenizers/string_integer_map.h>
|
|
23
|
+
|
|
24
|
+
namespace tokenizers {
|
|
25
|
+
|
|
26
|
+
class BPEModel : public Model {
|
|
27
|
+
public:
|
|
28
|
+
explicit BPEModel(detail::TokenMap token_map,
|
|
29
|
+
detail::TokenMap special_token_map,
|
|
30
|
+
std::optional<detail::TokenMap> merge_ranks,
|
|
31
|
+
std::unique_ptr<IRegex> special_token_regex,
|
|
32
|
+
bool byte_fallback, std::optional<uint64_t> unk_token_id,
|
|
33
|
+
std::optional<uint64_t> bos_token_id,
|
|
34
|
+
std::optional<uint64_t> eos_token_id);
|
|
35
|
+
|
|
36
|
+
~BPEModel() override = default;
|
|
37
|
+
|
|
38
|
+
Result<std::vector<uint64_t>>
|
|
39
|
+
tokenize(const std::string &piece) const override;
|
|
40
|
+
|
|
41
|
+
Result<std::string> id_to_piece(uint64_t token) const override;
|
|
42
|
+
Result<uint64_t> piece_to_id(const std::string &token) const override;
|
|
43
|
+
|
|
44
|
+
int32_t vocab_size() const override { return vocab_size_; }
|
|
45
|
+
|
|
46
|
+
bool is_special_token(uint64_t token) const override;
|
|
47
|
+
|
|
48
|
+
bool is_loaded() const override { return initialized_; }
|
|
49
|
+
|
|
50
|
+
std::pair<std::optional<std::string>, std::string>
|
|
51
|
+
split_with_allowed_special_token(const std::string &input,
|
|
52
|
+
size_t offset) const override;
|
|
53
|
+
|
|
54
|
+
uint64_t bos_token_id() const override { return bos_token_id_.value_or(0); }
|
|
55
|
+
|
|
56
|
+
uint64_t eos_token_id() const override { return eos_token_id_.value_or(0); }
|
|
57
|
+
|
|
58
|
+
private:
|
|
59
|
+
Result<std::pair<std::vector<uint64_t>, uint64_t>>
|
|
60
|
+
encode_with_special_token(const std::string &text) const;
|
|
61
|
+
|
|
62
|
+
Result<std::vector<uint64_t>>
|
|
63
|
+
byte_pair_encode(const std::string &piece) const;
|
|
64
|
+
|
|
65
|
+
std::vector<uint64_t>
|
|
66
|
+
byte_pair_merge(const std::string &piece, const detail::TokenMap &ranks,
|
|
67
|
+
std::function<uint64_t(uint64_t, uint64_t)> func) const;
|
|
68
|
+
|
|
69
|
+
// Real state
|
|
70
|
+
detail::TokenMap token_map_;
|
|
71
|
+
detail::TokenMap special_token_map_;
|
|
72
|
+
std::optional<detail::TokenMap> merge_ranks_;
|
|
73
|
+
std::unique_ptr<IRegex> special_token_regex_;
|
|
74
|
+
|
|
75
|
+
bool byte_fallback_ = false;
|
|
76
|
+
std::optional<uint64_t> unk_token_id_;
|
|
77
|
+
std::optional<uint64_t> bos_token_id_;
|
|
78
|
+
std::optional<uint64_t> eos_token_id_;
|
|
79
|
+
|
|
80
|
+
bool initialized_ = false;
|
|
81
|
+
int32_t vocab_size_ = 0;
|
|
82
|
+
};
|
|
83
|
+
|
|
84
|
+
} // namespace tokenizers
|
|
@@ -19,99 +19,18 @@
|
|
|
19
19
|
#include <vector>
|
|
20
20
|
|
|
21
21
|
// Local
|
|
22
|
-
#include
|
|
23
|
-
#include
|
|
24
|
-
#include
|
|
25
|
-
#include
|
|
26
|
-
#include
|
|
22
|
+
#include <pytorch/tokenizers/error.h>
|
|
23
|
+
#include <pytorch/tokenizers/map_utils.h>
|
|
24
|
+
#include <pytorch/tokenizers/regex.h>
|
|
25
|
+
#include <pytorch/tokenizers/result.h>
|
|
26
|
+
#include <pytorch/tokenizers/string_integer_map.h>
|
|
27
|
+
#include <pytorch/tokenizers/tokenizer.h>
|
|
27
28
|
|
|
28
29
|
#include "re2/re2.h"
|
|
29
30
|
|
|
30
31
|
namespace tokenizers {
|
|
31
32
|
namespace detail {
|
|
32
33
|
|
|
33
|
-
using TokenMap = StringIntegerMap<>;
|
|
34
|
-
|
|
35
|
-
template <typename TToken, typename TRank>
|
|
36
|
-
static Result<TokenMap>
|
|
37
|
-
build_token_map(std::vector<std::pair<TToken, TRank>> container) {
|
|
38
|
-
static_assert(std::is_same_v<TToken, std::string> ||
|
|
39
|
-
std::is_same_v<TToken, std::string_view>,
|
|
40
|
-
"TToken must be std::string or std::string_view");
|
|
41
|
-
static_assert(std::is_integral_v<TRank> && std::is_unsigned_v<TRank>,
|
|
42
|
-
"TRank must be an unsigned integer");
|
|
43
|
-
|
|
44
|
-
std::sort(container.begin(), container.end(),
|
|
45
|
-
[](const auto &a, const auto &b) { return a.first < b.first; });
|
|
46
|
-
|
|
47
|
-
auto duplicate_begin = std::unique(
|
|
48
|
-
container.begin(), container.end(),
|
|
49
|
-
[](const auto &a, const auto &b) { return a.first == b.first; });
|
|
50
|
-
|
|
51
|
-
TK_CHECK_OR_RETURN_ERROR(
|
|
52
|
-
duplicate_begin == container.end(), ParseFailure,
|
|
53
|
-
"duplicate token: %s rank: %llu", duplicate_begin->first.c_str(),
|
|
54
|
-
static_cast<unsigned long long>(duplicate_begin->second));
|
|
55
|
-
|
|
56
|
-
std::sort(container.begin(), container.end(),
|
|
57
|
-
[](const auto &a, const auto &b) { return a.second < b.second; });
|
|
58
|
-
|
|
59
|
-
duplicate_begin = std::unique(
|
|
60
|
-
container.begin(), container.end(),
|
|
61
|
-
[](const auto &a, const auto &b) { return a.second == b.second; });
|
|
62
|
-
|
|
63
|
-
TK_CHECK_OR_RETURN_ERROR(
|
|
64
|
-
duplicate_begin == container.end(), ParseFailure,
|
|
65
|
-
"duplicate rank: %llu"
|
|
66
|
-
" token: %s",
|
|
67
|
-
static_cast<unsigned long long>(duplicate_begin->second),
|
|
68
|
-
duplicate_begin->first.c_str());
|
|
69
|
-
|
|
70
|
-
return TokenMap(container);
|
|
71
|
-
};
|
|
72
|
-
|
|
73
|
-
template <typename TContainer, typename TTokenAccessor, typename TRankAccessor>
|
|
74
|
-
static Result<TokenMap> build_token_map(const TContainer &container,
|
|
75
|
-
TTokenAccessor token_accessor,
|
|
76
|
-
TRankAccessor rank_accessor) {
|
|
77
|
-
using TokenType = std::invoke_result_t<TTokenAccessor, const TContainer &>;
|
|
78
|
-
using RankType = std::invoke_result_t<TRankAccessor, const TContainer &>;
|
|
79
|
-
|
|
80
|
-
static_assert(std::is_same_v<TokenType, std::string> ||
|
|
81
|
-
std::is_same_v<TokenType, std::string_view>,
|
|
82
|
-
"TokenType must be std::string or std::string_view");
|
|
83
|
-
static_assert(std::is_integral_v<RankType> && std::is_unsigned_v<RankType>,
|
|
84
|
-
"RankType must be an unsigned integer");
|
|
85
|
-
|
|
86
|
-
std::vector<std::pair<TokenType, RankType>> pairs;
|
|
87
|
-
pairs.reserve(container.size());
|
|
88
|
-
for (const auto &value : container) {
|
|
89
|
-
pairs.emplace_back(token_accessor(value), rank_accessor(value));
|
|
90
|
-
}
|
|
91
|
-
|
|
92
|
-
return build_token_map(std::move(pairs));
|
|
93
|
-
}
|
|
94
|
-
|
|
95
|
-
inline Result<std::unique_ptr<IRegex>>
|
|
96
|
-
build_special_token_regex(const TokenMap &special_token_map) {
|
|
97
|
-
std::string special_pattern;
|
|
98
|
-
const std::size_t count = special_token_map.size();
|
|
99
|
-
|
|
100
|
-
for (std::size_t i = 0; i < count; ++i) {
|
|
101
|
-
const auto &[token, _] = special_token_map.getElement(i);
|
|
102
|
-
if (!special_pattern.empty()) {
|
|
103
|
-
special_pattern += "|";
|
|
104
|
-
}
|
|
105
|
-
special_pattern += re2::RE2::QuoteMeta(std::string(token));
|
|
106
|
-
}
|
|
107
|
-
|
|
108
|
-
if (special_pattern.empty()) {
|
|
109
|
-
return static_cast<std::unique_ptr<IRegex>>(nullptr);
|
|
110
|
-
}
|
|
111
|
-
// Wrap pattern in parentheses for proper grouping
|
|
112
|
-
return create_regex("(" + special_pattern + ")");
|
|
113
|
-
}
|
|
114
|
-
|
|
115
34
|
class BPETokenizerBase : public Tokenizer {
|
|
116
35
|
public:
|
|
117
36
|
Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
|
|
@@ -13,155 +13,26 @@
|
|
|
13
13
|
#pragma once
|
|
14
14
|
|
|
15
15
|
// Standard
|
|
16
|
+
#include <memory>
|
|
16
17
|
#include <string>
|
|
18
|
+
#include <vector>
|
|
17
19
|
|
|
18
20
|
// Local
|
|
19
|
-
#include "bpe_tokenizer_base.h"
|
|
20
|
-
#include "error.h"
|
|
21
|
-
#include "normalizer.h"
|
|
22
|
-
#include "post_processor.h"
|
|
23
|
-
#include "pre_tokenizer.h"
|
|
24
|
-
#include "result.h"
|
|
25
|
-
#include "token_decoder.h"
|
|
26
21
|
#include <nlohmann/json.hpp>
|
|
22
|
+
#include <pytorch/tokenizers/error.h>
|
|
23
|
+
#include <pytorch/tokenizers/model.h>
|
|
24
|
+
#include <pytorch/tokenizers/normalizer.h>
|
|
25
|
+
#include <pytorch/tokenizers/padding.h>
|
|
26
|
+
#include <pytorch/tokenizers/post_processor.h>
|
|
27
|
+
#include <pytorch/tokenizers/pre_tokenizer.h>
|
|
28
|
+
#include <pytorch/tokenizers/result.h>
|
|
29
|
+
#include <pytorch/tokenizers/token_decoder.h>
|
|
30
|
+
#include <pytorch/tokenizers/tokenizer.h>
|
|
31
|
+
#include <pytorch/tokenizers/truncation.h>
|
|
27
32
|
|
|
28
33
|
namespace tokenizers {
|
|
29
|
-
namespace detail {
|
|
30
34
|
|
|
31
|
-
|
|
32
|
-
struct PairHash {
|
|
33
|
-
std::size_t operator()(const std::pair<uint64_t, uint64_t> &p) const {
|
|
34
|
-
return std::hash<uint64_t>{}(p.first) ^
|
|
35
|
-
(std::hash<uint64_t>{}(p.second) << 1);
|
|
36
|
-
}
|
|
37
|
-
};
|
|
38
|
-
|
|
39
|
-
// Type alias for BPE merge map: (token_id_1, token_id_2) -> (rank,
|
|
40
|
-
// merged_token_id)
|
|
41
|
-
using MergeMap = std::unordered_map<std::pair<uint64_t, uint64_t>,
|
|
42
|
-
std::pair<uint64_t, uint64_t>, PairHash>;
|
|
43
|
-
|
|
44
|
-
// Utility function to build merge ranks map from merge rules
|
|
45
|
-
template <typename TMergeMap>
|
|
46
|
-
inline Result<TokenMap> build_merge_ranks_map(const TMergeMap &merge_map,
|
|
47
|
-
const TokenMap &token_map) {
|
|
48
|
-
// Static assertions to verify TMergeMap has the expected key and value types
|
|
49
|
-
using KeyType = typename TMergeMap::key_type;
|
|
50
|
-
using ValueType = typename TMergeMap::mapped_type;
|
|
51
|
-
|
|
52
|
-
static_assert(std::is_same_v<KeyType, std::pair<uint64_t, uint64_t>>,
|
|
53
|
-
"TMergeMap key type must be std::pair<uint64_t, uint64_t>");
|
|
54
|
-
|
|
55
|
-
static_assert(std::is_same_v<ValueType, std::pair<uint64_t, uint64_t>>,
|
|
56
|
-
"TMergeMap value type must be std::pair<uint64_t, uint64_t>");
|
|
57
|
-
|
|
58
|
-
// Use a map to handle duplicates - keep the lowest rank (highest priority)
|
|
59
|
-
std::unordered_map<std::string, uint64_t> unique_merge_ranks;
|
|
60
|
-
|
|
61
|
-
for (const auto &[pair, rank_and_id] : merge_map) {
|
|
62
|
-
uint64_t first_id = pair.first;
|
|
63
|
-
uint64_t second_id = pair.second;
|
|
64
|
-
uint64_t rank = rank_and_id.first;
|
|
65
|
-
|
|
66
|
-
// Get the token strings for the pair
|
|
67
|
-
auto first_token = token_map.tryGetString(first_id);
|
|
68
|
-
auto second_token = token_map.tryGetString(second_id);
|
|
69
|
-
|
|
70
|
-
if (first_token && second_token) {
|
|
71
|
-
std::string merged_token =
|
|
72
|
-
std::string(*first_token) + std::string(*second_token);
|
|
73
|
-
|
|
74
|
-
// Keep the entry with the lowest rank (highest priority in BPE)
|
|
75
|
-
auto it = unique_merge_ranks.find(merged_token);
|
|
76
|
-
if (it == unique_merge_ranks.end() || rank < it->second) {
|
|
77
|
-
unique_merge_ranks[merged_token] = rank;
|
|
78
|
-
}
|
|
79
|
-
}
|
|
80
|
-
}
|
|
81
|
-
|
|
82
|
-
// Convert to vector for buildTokenMap
|
|
83
|
-
std::vector<std::pair<std::string, uint64_t>> merge_rank_pairs;
|
|
84
|
-
merge_rank_pairs.reserve(unique_merge_ranks.size());
|
|
85
|
-
|
|
86
|
-
for (const auto &[token, rank] : unique_merge_ranks) {
|
|
87
|
-
merge_rank_pairs.emplace_back(token, rank);
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
return build_token_map(std::move(merge_rank_pairs));
|
|
91
|
-
}
|
|
92
|
-
|
|
93
|
-
} // namespace detail
|
|
94
|
-
|
|
95
|
-
// Simple Word structure to mimic Rust's Word behavior
|
|
96
|
-
struct HFWord {
|
|
97
|
-
std::vector<uint64_t> tokens;
|
|
98
|
-
std::vector<size_t> byte_lengths;
|
|
99
|
-
|
|
100
|
-
void add(uint64_t token_id, size_t byte_len) {
|
|
101
|
-
tokens.push_back(token_id);
|
|
102
|
-
byte_lengths.push_back(byte_len);
|
|
103
|
-
}
|
|
104
|
-
|
|
105
|
-
size_t size() const { return tokens.size(); }
|
|
106
|
-
|
|
107
|
-
// Apply all possible merges using the merge ranks
|
|
108
|
-
void merge_all(const detail::TokenMap &merge_ranks,
|
|
109
|
-
const detail::TokenMap &token_map) {
|
|
110
|
-
while (tokens.size() > 1) {
|
|
111
|
-
std::optional<std::pair<size_t, uint32_t>> best_merge;
|
|
112
|
-
|
|
113
|
-
// Find the best merge (lowest rank) among adjacent token pairs
|
|
114
|
-
for (size_t i = 0; i < tokens.size() - 1; ++i) {
|
|
115
|
-
// Create the merged token string to look up its rank
|
|
116
|
-
auto first_token = token_map.tryGetString(tokens[i]);
|
|
117
|
-
auto second_token = token_map.tryGetString(tokens[i + 1]);
|
|
118
|
-
|
|
119
|
-
if (first_token && second_token) {
|
|
120
|
-
std::string merged_token =
|
|
121
|
-
std::string(*first_token) + std::string(*second_token);
|
|
122
|
-
auto rank = merge_ranks.tryGetInteger(merged_token);
|
|
123
|
-
|
|
124
|
-
if (rank && (!best_merge || *rank < best_merge->second)) {
|
|
125
|
-
best_merge = std::make_pair(i, static_cast<uint32_t>(*rank));
|
|
126
|
-
}
|
|
127
|
-
}
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
if (!best_merge) {
|
|
131
|
-
break; // No more merges possible
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
// Apply the best merge
|
|
135
|
-
size_t merge_idx = best_merge->first;
|
|
136
|
-
|
|
137
|
-
// Get the merged token ID
|
|
138
|
-
auto first_token = token_map.tryGetString(tokens[merge_idx]);
|
|
139
|
-
auto second_token = token_map.tryGetString(tokens[merge_idx + 1]);
|
|
140
|
-
|
|
141
|
-
if (first_token && second_token) {
|
|
142
|
-
std::string merged_token =
|
|
143
|
-
std::string(*first_token) + std::string(*second_token);
|
|
144
|
-
auto merged_id = token_map.tryGetInteger(merged_token);
|
|
145
|
-
|
|
146
|
-
if (merged_id) {
|
|
147
|
-
// Replace the two tokens with the merged token
|
|
148
|
-
tokens[merge_idx] = *merged_id;
|
|
149
|
-
byte_lengths[merge_idx] += byte_lengths[merge_idx + 1];
|
|
150
|
-
|
|
151
|
-
// Remove the second token
|
|
152
|
-
tokens.erase(tokens.begin() + merge_idx + 1);
|
|
153
|
-
byte_lengths.erase(byte_lengths.begin() + merge_idx + 1);
|
|
154
|
-
} else {
|
|
155
|
-
break; // Merged token not found in vocabulary
|
|
156
|
-
}
|
|
157
|
-
} else {
|
|
158
|
-
break; // Original tokens not found in vocabulary
|
|
159
|
-
}
|
|
160
|
-
}
|
|
161
|
-
}
|
|
162
|
-
};
|
|
163
|
-
|
|
164
|
-
class HFTokenizer : public detail::BPETokenizerBase {
|
|
35
|
+
class HFTokenizer : public Tokenizer {
|
|
165
36
|
public:
|
|
166
37
|
/*-- Public Interface --*/
|
|
167
38
|
|
|
@@ -179,53 +50,34 @@ public:
|
|
|
179
50
|
Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos = 0,
|
|
180
51
|
int8_t eos = 0) const override;
|
|
181
52
|
|
|
182
|
-
|
|
53
|
+
Result<std::string> id_to_piece(uint64_t token) const override;
|
|
54
|
+
Result<uint64_t> piece_to_id(const std::string &text) const override;
|
|
55
|
+
|
|
56
|
+
Result<std::string> decode(uint64_t prev_token, uint64_t token,
|
|
57
|
+
bool skip_special_tokens = false) const override;
|
|
183
58
|
|
|
184
59
|
Result<std::string> decode(const std::vector<uint64_t> &tokens,
|
|
185
|
-
bool skip_special_tokens =
|
|
60
|
+
bool skip_special_tokens = false) const;
|
|
186
61
|
|
|
187
62
|
private:
|
|
188
|
-
Error _encode(const std::string &input, std::vector<uint64_t> &ret,
|
|
189
|
-
uint64_t &last_piece_token_len) const override;
|
|
190
|
-
|
|
191
|
-
void _decode(const std::string &input, std::string &ret) const override;
|
|
192
|
-
|
|
193
|
-
std::vector<std::string>
|
|
194
|
-
_decode(const std::vector<std::string> &pieces) const;
|
|
195
|
-
|
|
196
|
-
Result<std::vector<uint64_t>>
|
|
197
|
-
byte_pair_encode_(const std::string &piece,
|
|
198
|
-
const detail::TokenMap &encoder) const override;
|
|
199
|
-
|
|
200
|
-
// Override the virtual _byte_pair_merge method to use explicit merges
|
|
201
|
-
// specified in tokenizer.json. Different from Tiktoken (another user of
|
|
202
|
-
// BPETokenizerBase, but doesn't use explicit merge rules).
|
|
203
|
-
std::vector<uint64_t> _byte_pair_merge(
|
|
204
|
-
const std::string &piece, const detail::TokenMap &ranks,
|
|
205
|
-
std::function<uint64_t(uint64_t, uint64_t)> func) const override;
|
|
206
|
-
|
|
207
|
-
Error parse_special_tokens(const nlohmann::json &parsed_json);
|
|
208
|
-
Error parse_tokens(const nlohmann::json &parsed_json);
|
|
209
63
|
Error setup_normalizer(const nlohmann::json &parsed_json);
|
|
210
64
|
Error setup_pretokenizer(const nlohmann::json &parsed_json);
|
|
211
65
|
Error setup_postprocessor(const nlohmann::json &parsed_json);
|
|
212
66
|
Error setup_decoder(const nlohmann::json &parsed_json);
|
|
213
|
-
Error
|
|
214
|
-
Error
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
67
|
+
Error setup_truncation(const nlohmann::json &parsed_json);
|
|
68
|
+
Error setup_padding(const nlohmann::json &parsed_json);
|
|
69
|
+
Error setup_model(const nlohmann::json &parsed_json,
|
|
70
|
+
const std::string &model_config_path,
|
|
71
|
+
const std::string &special_tokens_map_path);
|
|
218
72
|
|
|
219
73
|
Normalizer::Ptr _normalizer;
|
|
220
74
|
PreTokenizer::Ptr _pretokenizer;
|
|
221
75
|
PostProcessor::Ptr _postprocessor;
|
|
222
76
|
TokenDecoder::Ptr _decoder;
|
|
77
|
+
Truncation::Ptr _truncation;
|
|
78
|
+
Padding::Ptr _padding;
|
|
223
79
|
|
|
224
|
-
|
|
225
|
-
std::optional<detail::TokenMap>
|
|
226
|
-
merge_ranks_; // Pre-computed merge ranks for BPE
|
|
227
|
-
bool byte_fallback_ = false;
|
|
228
|
-
bool unk_token_is_configured_ = false;
|
|
80
|
+
Model::Ptr _model;
|
|
229
81
|
};
|
|
230
82
|
|
|
231
|
-
} // namespace tokenizers
|
|
83
|
+
} // namespace tokenizers
|