kotoshu 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.rubocop.yml +18 -0
- data/CHANGELOG.md +182 -0
- data/CLAUDE.md +172 -0
- data/CODE_OF_CONDUCT.md +132 -0
- data/LICENSE +31 -0
- data/README.adoc +955 -0
- data/Rakefile +12 -0
- data/SECURITY.md +93 -0
- data/examples/01_basic_word_checking.rb +38 -0
- data/examples/02_text_document_checking.rb +77 -0
- data/examples/03_dictionary_backends.rb +137 -0
- data/examples/04_trie_data_structure.rb +146 -0
- data/examples/05_suggestion_algorithms.rb +239 -0
- data/examples/06_configuration_advanced.rb +287 -0
- data/examples/07_multi_language_dictionaries.rb +278 -0
- data/exe/kotoshu +6 -0
- data/lib/kotoshu/algorithms/capitalization.rb +276 -0
- data/lib/kotoshu/algorithms/lookup.rb +876 -0
- data/lib/kotoshu/algorithms/ngram_suggest.rb +270 -0
- data/lib/kotoshu/algorithms/permutations.rb +283 -0
- data/lib/kotoshu/algorithms/phonet_suggest.rb +167 -0
- data/lib/kotoshu/algorithms/suggest.rb +575 -0
- data/lib/kotoshu/algorithms.rb +14 -0
- data/lib/kotoshu/analyzers/semantic_analyzer.rb +295 -0
- data/lib/kotoshu/cache/base_cache.rb +596 -0
- data/lib/kotoshu/cache/cache.rb +91 -0
- data/lib/kotoshu/cache/frequency_cache.rb +224 -0
- data/lib/kotoshu/cache/language_cache.rb +454 -0
- data/lib/kotoshu/cache/lookup_cache.rb +166 -0
- data/lib/kotoshu/cache/model_cache.rb +513 -0
- data/lib/kotoshu/cache/suggestion_cache.rb +113 -0
- data/lib/kotoshu/cache.rb +40 -0
- data/lib/kotoshu/cli/auto_setup.rb +71 -0
- data/lib/kotoshu/cli/batch_reporter.rb +315 -0
- data/lib/kotoshu/cli/cache_command.rb +356 -0
- data/lib/kotoshu/cli/display_formatter.rb +431 -0
- data/lib/kotoshu/cli/errors.rb +36 -0
- data/lib/kotoshu/cli/interactive_reviewer.rb +319 -0
- data/lib/kotoshu/cli/language_resolver.rb +91 -0
- data/lib/kotoshu/cli/navigation_manager.rb +272 -0
- data/lib/kotoshu/cli/progress_reporter.rb +114 -0
- data/lib/kotoshu/cli/status_report.rb +130 -0
- data/lib/kotoshu/cli.rb +627 -0
- data/lib/kotoshu/commands/cache_command.rb +424 -0
- data/lib/kotoshu/commands/check_command.rb +312 -0
- data/lib/kotoshu/commands/model_command.rb +295 -0
- data/lib/kotoshu/components/passthrough_spell_checker.rb +72 -0
- data/lib/kotoshu/components/pos_tagger.rb +98 -0
- data/lib/kotoshu/components/spell_checker.rb +73 -0
- data/lib/kotoshu/components/synthesizer.rb +60 -0
- data/lib/kotoshu/components/tokenizer.rb +58 -0
- data/lib/kotoshu/components/whitespace_tokenizer.rb +96 -0
- data/lib/kotoshu/configuration/builder.rb +209 -0
- data/lib/kotoshu/configuration/resolver.rb +124 -0
- data/lib/kotoshu/configuration.rb +702 -0
- data/lib/kotoshu/core/exceptions.rb +165 -0
- data/lib/kotoshu/core/indexed_dictionary.rb +291 -0
- data/lib/kotoshu/core/models/affix_rule.rb +260 -0
- data/lib/kotoshu/core/models/result/document_result.rb +263 -0
- data/lib/kotoshu/core/models/result/word_result.rb +203 -0
- data/lib/kotoshu/core/models/word.rb +142 -0
- data/lib/kotoshu/core/trie/builder.rb +119 -0
- data/lib/kotoshu/core/trie/node.rb +94 -0
- data/lib/kotoshu/core/trie/trie.rb +249 -0
- data/lib/kotoshu/core.rb +28 -0
- data/lib/kotoshu/data/common_words/de.yml +1800 -0
- data/lib/kotoshu/data/common_words/en.yml +1215 -0
- data/lib/kotoshu/data/common_words/es.yml +750 -0
- data/lib/kotoshu/data/common_words/fr.yml +1015 -0
- data/lib/kotoshu/data/common_words/pt.yml +870 -0
- data/lib/kotoshu/data/common_words/ru.yml +484 -0
- data/lib/kotoshu/data/common_words_loader.rb +152 -0
- data/lib/kotoshu/data_structures/bloom_filter.rb +176 -0
- data/lib/kotoshu/debug_logger.rb +146 -0
- data/lib/kotoshu/debug_mode.rb +134 -0
- data/lib/kotoshu/defaults.rb +86 -0
- data/lib/kotoshu/dictionaries/catalog.rb +817 -0
- data/lib/kotoshu/dictionary/base.rb +237 -0
- data/lib/kotoshu/dictionary/cspell.rb +254 -0
- data/lib/kotoshu/dictionary/custom.rb +224 -0
- data/lib/kotoshu/dictionary/hunspell.rb +526 -0
- data/lib/kotoshu/dictionary/plain_text.rb +282 -0
- data/lib/kotoshu/dictionary/repository.rb +248 -0
- data/lib/kotoshu/dictionary/unified.rb +260 -0
- data/lib/kotoshu/dictionary/unix_words.rb +218 -0
- data/lib/kotoshu/documents/asciidoc_document.rb +441 -0
- data/lib/kotoshu/documents/document.rb +229 -0
- data/lib/kotoshu/documents/location.rb +139 -0
- data/lib/kotoshu/documents/markdown_document.rb +389 -0
- data/lib/kotoshu/documents/plain_text_document.rb +147 -0
- data/lib/kotoshu/embeddings/embedding_pipeline.rb +244 -0
- data/lib/kotoshu/embeddings/lru_cache.rb +233 -0
- data/lib/kotoshu/embeddings/onnx_runtime_model.rb +388 -0
- data/lib/kotoshu/embeddings/protocol.rb +83 -0
- data/lib/kotoshu/embeddings/protocols.rb +17 -0
- data/lib/kotoshu/embeddings/registry.rb +182 -0
- data/lib/kotoshu/embeddings/search.rb +192 -0
- data/lib/kotoshu/embeddings/similarity_engine.rb +248 -0
- data/lib/kotoshu/embeddings/similarity_search.rb +331 -0
- data/lib/kotoshu/embeddings/vocabulary.rb +257 -0
- data/lib/kotoshu/embeddings.rb +97 -0
- data/lib/kotoshu/fluent_checker.rb +91 -0
- data/lib/kotoshu/grammar/pattern_matchers/base_matcher.rb +48 -0
- data/lib/kotoshu/grammar/pattern_matchers/double_negative_matcher.rb +105 -0
- data/lib/kotoshu/grammar/pattern_matchers/possessive_context_matcher.rb +77 -0
- data/lib/kotoshu/grammar/pattern_matchers/vowel_sound_matcher.rb +83 -0
- data/lib/kotoshu/grammar/rule.rb +95 -0
- data/lib/kotoshu/grammar/rule_engine.rb +111 -0
- data/lib/kotoshu/grammar/rule_loader.rb +31 -0
- data/lib/kotoshu/grammar.rb +18 -0
- data/lib/kotoshu/integrity/audit_log.rb +88 -0
- data/lib/kotoshu/integrity/manifest.rb +117 -0
- data/lib/kotoshu/integrity/net_http.rb +46 -0
- data/lib/kotoshu/integrity.rb +25 -0
- data/lib/kotoshu/keyboard/layout.rb +115 -0
- data/lib/kotoshu/keyboard/layouts/azerty.rb +57 -0
- data/lib/kotoshu/keyboard/layouts/dvorak.rb +56 -0
- data/lib/kotoshu/keyboard/layouts/jcuken.rb +59 -0
- data/lib/kotoshu/keyboard/layouts/qwerty.rb +54 -0
- data/lib/kotoshu/keyboard/layouts/qwertz.rb +57 -0
- data/lib/kotoshu/keyboard/registry.rb +146 -0
- data/lib/kotoshu/keyboard.rb +60 -0
- data/lib/kotoshu/language/detector.rb +242 -0
- data/lib/kotoshu/language/identifier.rb +378 -0
- data/lib/kotoshu/language/languages/base.rb +256 -0
- data/lib/kotoshu/language/normalizer/base.rb +137 -0
- data/lib/kotoshu/language/registry.rb +147 -0
- data/lib/kotoshu/language/resources/ar/common_words.txt +6753 -0
- data/lib/kotoshu/language/resources/ar/confusion_sets.txt +11 -0
- data/lib/kotoshu/language/resources/de/common_words.txt +10003 -0
- data/lib/kotoshu/language/resources/de/confusion_sets.txt +246 -0
- data/lib/kotoshu/language/resources/en/common_words.txt +9979 -0
- data/lib/kotoshu/language/resources/en/confusion_sets.txt +871 -0
- data/lib/kotoshu/language/resources/es/common_words.txt +9992 -0
- data/lib/kotoshu/language/resources/es/confusion_sets.txt +17 -0
- data/lib/kotoshu/language/resources/fr/common_words.txt +9993 -0
- data/lib/kotoshu/language/resources/fr/confusion_sets.txt +76 -0
- data/lib/kotoshu/language/resources/pt/common_words.txt +9977 -0
- data/lib/kotoshu/language/resources/pt/confusion_sets.txt +18 -0
- data/lib/kotoshu/language/resources/ru/common_words.txt +9951 -0
- data/lib/kotoshu/language/resources/ru/confusion_sets.txt +5 -0
- data/lib/kotoshu/language/tokenizer/base.rb +170 -0
- data/lib/kotoshu/language/tokenizer/french_tokenizer.rb +170 -0
- data/lib/kotoshu/language/tokenizer/german_tokenizer.rb +41 -0
- data/lib/kotoshu/language/tokenizer/japanese_tokenizer.rb +60 -0
- data/lib/kotoshu/language/tokenizer/latin_tokenizer.rb +141 -0
- data/lib/kotoshu/language/tokenizer/portuguese_tokenizer.rb +160 -0
- data/lib/kotoshu/language/tokenizer/russian_tokenizer.rb +95 -0
- data/lib/kotoshu/language/tokenizer/spanish_tokenizer.rb +122 -0
- data/lib/kotoshu/language.rb +99 -0
- data/lib/kotoshu/languages/de/language.rb +546 -0
- data/lib/kotoshu/languages/en/language.rb +448 -0
- data/lib/kotoshu/languages/es/language.rb +459 -0
- data/lib/kotoshu/languages/fr/language.rb +493 -0
- data/lib/kotoshu/languages/ja/language.rb +477 -0
- data/lib/kotoshu/languages/pt/language.rb +423 -0
- data/lib/kotoshu/languages/ru/language.rb +404 -0
- data/lib/kotoshu/languages.rb +43 -0
- data/lib/kotoshu/metrics_collector.rb +222 -0
- data/lib/kotoshu/metrics_module.rb +110 -0
- data/lib/kotoshu/models/context.rb +119 -0
- data/lib/kotoshu/models/embedding_model.rb +182 -0
- data/lib/kotoshu/models/fasttext_model.rb +220 -0
- data/lib/kotoshu/models/nearest_neighbor.rb +87 -0
- data/lib/kotoshu/models/onnx_model.rb +333 -0
- data/lib/kotoshu/models/semantic_error.rb +165 -0
- data/lib/kotoshu/models/suggestion.rb +106 -0
- data/lib/kotoshu/models/word_embedding.rb +107 -0
- data/lib/kotoshu/paths.rb +53 -0
- data/lib/kotoshu/personal_dictionary.rb +94 -0
- data/lib/kotoshu/plugins/plugin.rb +61 -0
- data/lib/kotoshu/plugins/registry.rb +120 -0
- data/lib/kotoshu/project_config.rb +76 -0
- data/lib/kotoshu/readers/aff_data.rb +356 -0
- data/lib/kotoshu/readers/aff_reader.rb +375 -0
- data/lib/kotoshu/readers/condition_checker.rb +142 -0
- data/lib/kotoshu/readers/dic_reader.rb +118 -0
- data/lib/kotoshu/readers/file_reader.rb +347 -0
- data/lib/kotoshu/readers/lookup_builder.rb +299 -0
- data/lib/kotoshu/readers/readers.rb +6 -0
- data/lib/kotoshu/readers.rb +9 -0
- data/lib/kotoshu/resource_bundle.rb +30 -0
- data/lib/kotoshu/resource_manager.rb +295 -0
- data/lib/kotoshu/results/result.rb +165 -0
- data/lib/kotoshu/scripts/fasttext_to_onnx.py +275 -0
- data/lib/kotoshu/source_registry.rb +74 -0
- data/lib/kotoshu/spellchecker/parallel_checker.rb +90 -0
- data/lib/kotoshu/spellchecker.rb +298 -0
- data/lib/kotoshu/string_metrics.rb +153 -0
- data/lib/kotoshu/suggestions/context.rb +55 -0
- data/lib/kotoshu/suggestions/generator.rb +175 -0
- data/lib/kotoshu/suggestions/pipeline.rb +135 -0
- data/lib/kotoshu/suggestions/strategies/base_strategy.rb +296 -0
- data/lib/kotoshu/suggestions/strategies/composite_strategy.rb +140 -0
- data/lib/kotoshu/suggestions/strategies/edit_distance_strategy.rb +671 -0
- data/lib/kotoshu/suggestions/strategies/keyboard_proximity_strategy.rb +228 -0
- data/lib/kotoshu/suggestions/strategies/ngram_strategy.rb +130 -0
- data/lib/kotoshu/suggestions/strategies/phonetic_strategy.rb +329 -0
- data/lib/kotoshu/suggestions/strategies/semantic_strategy.rb +316 -0
- data/lib/kotoshu/suggestions/strategies/symspell_strategy.rb +275 -0
- data/lib/kotoshu/suggestions/suggestion.rb +174 -0
- data/lib/kotoshu/suggestions/suggestion_set.rb +238 -0
- data/lib/kotoshu/version.rb +5 -0
- data/lib/kotoshu.rb +493 -0
- data/script/validate_all_dictionaries.rb +444 -0
- data/sig/kotoshu.rbs +4 -0
- data/test_oop.rb +79 -0
- metadata +298 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative 'protocol'
|
|
4
|
+
|
|
5
|
+
# OnnxRuntimeModel - ONNX Runtime wrapper for FastText embeddings
|
|
6
|
+
#
|
|
7
|
+
# Provides embedding inference using ONNX Runtime. Supports single lookups,
|
|
8
|
+
# batch inference, and vocabulary-aware operations.
|
|
9
|
+
#
|
|
10
|
+
# @example Single embedding lookup
|
|
11
|
+
# model = OnnxRuntimeModel.from_file('fasttext.en.onnx', language_code: 'en')
|
|
12
|
+
# model.load!
|
|
13
|
+
# embedding = model.get_embedding(1234)
|
|
14
|
+
#
|
|
15
|
+
# @example Batch lookup
|
|
16
|
+
# embeddings = model.get_embeddings([1, 2, 3, 4, 5])
|
|
17
|
+
#
|
|
18
|
+
# @example With vocabulary
|
|
19
|
+
# embedding = model.get_embedding_for_word('hello', vocabulary)
|
|
20
|
+
#
|
|
21
|
+
class OnnxRuntimeModel
  include EmbeddingModelProtocol

  # Default dimension for FastText models
  DEFAULT_DIMENSION = 300

  # Maximum number of indices sent to ONNX Runtime in one inference call
  BATCH_SIZE = 32

  # @return [String] Language code (ISO 639-1)
  attr_reader :language_code

  # @return [Integer] Embedding dimension
  attr_reader :dimension

  # @return [String] Path to ONNX model file
  attr_reader :onnx_path

  # @return [Boolean] Whether the model is loaded
  attr_reader :loaded
  # Predicate form of #loaded; EmbeddingModelProtocol lists :loaded? as required.
  alias loaded? loaded

  # @return [Integer] Number of inference calls
  attr_reader :inference_count

  # Create a new ONNX Runtime model. Nothing is loaded until #load! (or the
  # first inference call) runs.
  #
  # @param language_code [String] ISO 639-1 language code
  # @param onnx_path [String] Path to .onnx file
  # @param dimension [Integer] Embedding dimension (default: 300)
  #
  def initialize(language_code:, onnx_path:, dimension: DEFAULT_DIMENSION)
    @language_code = language_code
    @onnx_path = onnx_path
    @dimension = dimension
    @session = nil
    @loaded = false
    @input_name = nil
    @output_name = nil
    @inference_count = 0
  end

  # Load the ONNX model into an inference session. No-op if already loaded.
  #
  # @return [self]
  #
  # @raise [Kotoshu::Models::OnnxModel::OnnxUnavailable] if
  #   Kotoshu::Models::OnnxModel::ONNX_LOADED is false (onnxruntime unavailable)
  # @raise [ArgumentError] if the model file doesn't exist
  #
  # NOTE(review): this file does not require the file defining
  # Kotoshu::Models::OnnxModel; assumes it is loaded beforehand — confirm.
  #
  def load!
    return self if @loaded

    raise Kotoshu::Models::OnnxModel::OnnxUnavailable unless Kotoshu::Models::OnnxModel::ONNX_LOADED
    raise ArgumentError, "ONNX file not found: #{@onnx_path}" unless File.exist?(@onnx_path)

    @session = OnnxRuntime::InferenceSession.new(@onnx_path)

    # Tensor names come from the model itself, with conventional fallbacks.
    @input_name = detect_input_name
    @output_name = detect_output_name

    @loaded = true
    self
  end

  # Drop the inference session and detected tensor names.
  #
  # @return [self]
  #
  def unload!
    @session = nil
    @input_name = nil
    @output_name = nil
    @loaded = false
    self
  end

  # Check if the model is ready for inference.
  #
  # @return [Boolean]
  #
  def ready?
    @loaded && !@session.nil?
  end

  # Get the embedding for a single word index (loads the model if needed).
  #
  # @param index [Integer] Word index in vocabulary
  # @return [Array<Float>] Embedding vector
  #
  # @raise [ArgumentError] if index is not a non-negative Integer
  #
  def get_embedding(index)
    ensure_loaded

    raise ArgumentError, "Invalid word index: #{index}" unless valid_index?(index)

    output = @session.run([@output_name], { @input_name => [index] })
    @inference_count += 1

    vector = extract_embedding(output.first)
    # The index is fed as a one-element batch, so the output is normally a
    # single-row matrix ([[v1, v2, ...]]); unwrap it to a flat vector. A
    # model that already returns a flat vector passes through unchanged.
    vector.first.is_a?(Array) ? vector.first : vector
  end

  # Get embeddings for multiple indices, running inference in slices of
  # BATCH_SIZE.
  #
  # Invalid indices (non-Integer or negative) are skipped. The result can then
  # be shorter than +indices+ and no longer position-aligned with it.
  #
  # @param indices [Array<Integer>] Word indices
  # @return [Array<Array<Float>>] Array of embedding vectors
  #
  def get_embeddings(indices)
    ensure_loaded
    return [] if indices.nil? || indices.empty?

    valid_indices = indices.select { |i| valid_index?(i) }
    return [] if valid_indices.empty?

    valid_indices.each_slice(BATCH_SIZE).flat_map { |batch| run_batch_inference(batch) }
  end

  # Compute embeddings for every index 0...vocabulary.size.
  #
  # @param vocabulary [Vocabulary] Vocabulary with complete word list
  # @return [Hash<Integer, Array<Float>>] Index to embedding mapping
  #
  def preload_embeddings!(vocabulary)
    ensure_loaded

    all_indices = (0...vocabulary.size).to_a
    embeddings = get_embeddings(all_indices)

    # All indices here are valid, so the results line up with all_indices.
    all_indices.zip(embeddings).to_h
  end

  # Get the embedding for a word via the vocabulary's word-to-index mapping.
  #
  # @param word [String] The word to lookup
  # @param vocabulary [Vocabulary] Vocabulary for word-to-index mapping
  # @return [Array<Float>, nil] Embedding vector or nil if word not found
  #
  def get_embedding_for_word(word, vocabulary)
    index = vocabulary.lookup(word)
    return nil unless index

    get_embedding(index)
  end

  # Get embeddings for multiple words (one inference call per known word).
  #
  # @param words [Array<String>] Words to lookup
  # @param vocabulary [Vocabulary] Vocabulary for word-to-index mapping
  # @return [Hash<String, Array<Float>>] Word to embedding mapping; unknown words are omitted
  #
  def get_embeddings_for_words(words, vocabulary)
    words.each_with_object({}) do |word, result|
      embedding = get_embedding_for_word(word, vocabulary)
      result[word] = embedding if embedding
    end
  end

  # Check if batching is supported.
  #
  # @return [Boolean]
  #
  def supports_batching?
    true
  end

  # Get batch size for batch inference.
  #
  # @return [Integer]
  #
  def batch_size
    BATCH_SIZE
  end

  # Get model type identifier.
  #
  # @return [String]
  #
  def model_type
    'onnx'
  end

  # Get model information.
  #
  # @return [Hash]
  #
  def model_info
    {
      type: 'onnx',
      language: @language_code,
      dimension: @dimension,
      path: @onnx_path,
      loaded: @loaded,
      inference_count: @inference_count
    }
  end

  # Create a model from a file (not loaded yet).
  #
  # @param onnx_path [String] Path to .onnx file
  # @param language_code [String] Language code (derived from the file name if nil)
  # @param dimension [Integer] Embedding dimension
  # @return [OnnxRuntimeModel]
  #
  # @raise [ArgumentError] if the file doesn't exist
  #
  def self.from_file(onnx_path, language_code: nil, dimension: nil)
    raise ArgumentError, "ONNX file not found: #{onnx_path}" unless File.exist?(onnx_path)

    language_code ||= detect_language_from_path(onnx_path)
    dimension ||= DEFAULT_DIMENSION

    new(
      language_code: language_code,
      onnx_path: onnx_path,
      dimension: dimension
    )
  end

  # Create a model from the model cache.
  #
  # @param language_code [String] ISO 639-1 language code
  # @param cache [Cache::ModelCache] Cache instance
  # @return [OnnxRuntimeModel, nil] nil when the cache has no ONNX model for the language
  #
  def self.from_cache(language_code, cache = nil)
    require_relative '../cache/model_cache'

    cache ||= begin
      # This class is top-level, so a bare `Cache::ModelCache` resolves to
      # ::Cache. Prefer the Kotoshu-namespaced class when it is defined, and
      # otherwise fall back to the original lookup.
      # NOTE(review): confirm which namespace model_cache.rb defines.
      cache_class = defined?(Kotoshu::Cache::ModelCache) ? Kotoshu::Cache::ModelCache : Cache::ModelCache
      cache_class.new
    end

    onnx_path = cache.get_onnx_model(language_code)
    return nil unless onnx_path

    from_file(onnx_path, language_code: language_code)
  end

  # Derive a language code from a file name such as "fasttext.en.onnx".
  # Takes the first two-letter segment between dots; otherwise returns 'en'.
  #
  # @param path [String]
  # @return [String]
  #
  def self.detect_language_from_path(path)
    match = File.basename(path).match(/\.([a-z]{2})\./i)
    match ? match[1].downcase : 'en'
  end

  # String representation
  #
  # @return [String]
  #
  def to_s
    "OnnxRuntimeModel(language: #{@language_code}, dimension: #{@dimension}, loaded: #{@loaded})"
  end
  alias inspect to_s

  private

  # Load the model on first use.
  def ensure_loaded
    load! unless @loaded
  end

  # A valid index is a non-negative Integer (no upper bound is checked here).
  def valid_index?(index)
    index.is_a?(Integer) && index >= 0
  end

  # Run one inference call for a slice of indices.
  #
  # @param indices [Array<Integer>] Word indices
  # @return [Array<Array<Float>>] Embedding vectors, one per index
  #
  def run_batch_inference(indices)
    output = @session.run([@output_name], { @input_name => indices.flatten })
    @inference_count += 1

    result = output.first
    return result if result.is_a?(Array)

    # Non-Array outputs (e.g. Numo arrays or other tensor wrappers): extract one row per index.
    Array.new(indices.length) { |i| extract_single_embedding(result, i) }
  end

  # Convert an ONNX output tensor to a Ruby Array.
  #
  # Uses duck typing so no optional tensor library constant has to be loaded.
  #
  # @param output [Object] ONNX output
  # @return [Array]
  #
  def extract_embedding(output)
    return output if output.is_a?(Array)

    output.respond_to?(:to_a) ? output.to_a : Array(output)
  end

  # Extract row +index+ from a batch output.
  #
  # @param output [Object] ONNX batch output
  # @param index [Integer] Index in batch
  # @return [Array<Float>]
  #
  def extract_single_embedding(output, index)
    return output[index] if output.is_a?(Array)

    # Numo-style arrays read a single integer subscript as a flat element
    # index, so a multi-dimensional output needs [row, true] for the whole row.
    row = output.respond_to?(:ndim) && output.ndim > 1 ? output[index, true] : output[index]
    row.to_a
  end

  # Name of the model's first input, or 'word_index' if it cannot be read.
  #
  # @return [String]
  #
  def detect_input_name
    @session.inputs&.first&.dig(:name) || 'word_index'
  end

  # Name of the model's first output, or 'embedding' if it cannot be read.
  #
  # @return [String]
  #
  def detect_output_name
    @session.outputs&.first&.dig(:name) || 'embedding'
  end
end
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Protocol - Ruby interface/contract system
|
|
4
|
+
#
|
|
5
|
+
# Provides a simple way to define interfaces with required and optional methods.
|
|
6
|
+
#
|
|
7
|
+
require 'set'

# Protocol - lightweight interface/contract declarations.
#
# Extend a module with Protocol, then declare method names with
# `required_methods` / `optional_methods`. Check an implementation with
# #compliance_errors or #assert_implemented_by!.
module Protocol
  # Declare required method names, or read them back.
  #
  # With arguments, the names are added. Without arguments, the current set is
  # just returned. The reader and the DSL must share one method: Ruby keeps
  # only the last `def` of a name, so separate definitions would overwrite each other.
  #
  # @param names [Array<Symbol, String>] method names to require
  # @return [Set<Symbol>] every required method name declared so far
  def required_methods(*names)
    (@required_methods ||= Set.new).merge(names.map(&:to_sym))
  end

  # Declare optional method names, or read them back (same semantics as
  # #required_methods).
  #
  # @param names [Array<Symbol, String>] method names that may be implemented
  # @return [Set<Symbol>] every optional method name declared so far
  def optional_methods(*names)
    (@optional_methods ||= Set.new).merge(names.map(&:to_sym))
  end

  # List the required methods an implementation lacks.
  #
  # @param klass [Module, Object] a class/module (its public instance methods
  #   are checked) or any object (checked with respond_to?)
  # @return [Array<Symbol>] missing required method names
  def compliance_errors(klass)
    required_methods.to_a.reject { |name| implements?(klass, name) }
  end

  # Raise if an implementation lacks any required method.
  #
  # @param klass [Module, Object] implementation to check
  # @return [nil]
  # @raise [RuntimeError] listing the missing methods
  def assert_implemented_by!(klass)
    errors = compliance_errors(klass)
    raise "Missing methods: #{errors.join(', ')}" unless errors.empty?
  end

  private

  # Classes expose their API as instance methods, so `respond_to?` on the
  # class object would inspect class methods instead.
  def implements?(target, name)
    if target.is_a?(Module)
      target.method_defined?(name)
    else
      target.respond_to?(name)
    end
  end
end
|
|
39
|
+
|
|
40
|
+
# Protocol error
|
|
41
|
+
class ProtocolError < StandardError
  # @return [Object] the class that failed the protocol check
  # @return [Object] the protocol that was checked
  # @return [Array] names of the methods that are missing
  attr_reader :klass, :protocol, :missing_methods

  # @param klass [Object] offending class
  # @param protocol [Object] protocol being enforced
  # @param missing_methods [Array<Symbol, String>] required methods not implemented
  def initialize(klass, protocol, missing_methods)
    @klass, @protocol, @missing_methods = klass, protocol, missing_methods
    listing = missing_methods.join(', ')
    super("#{klass} missing: #{listing}")
  end
end
|
|
51
|
+
|
|
52
|
+
# EmbeddingModel Protocol
|
|
53
|
+
module EmbeddingModelProtocol
  extend Protocol

  # Core contract: identity, inference, and lifecycle.
  required_methods(*%i[dimension language_code get_embedding get_embeddings
                       load! unload! loaded? ready?])

  # Extras an implementation may offer (batching, preloading, introspection).
  optional_methods(*%i[get_embeddings_batch batch_size preload_embeddings!
                       supports_batching? model_type model_info])
end
|
|
62
|
+
|
|
63
|
+
# SimilarityEngine Protocol
|
|
64
|
+
module SimilarityEngineProtocol
  extend Protocol

  # Distance/similarity metrics plus normalization helpers every engine provides.
  required_methods(*%i[cosine dot_product euclidean manhattan
                       pre_normalize normalize_and_compute])

  # Bulk computation and normalization introspection are optional.
  optional_methods(*%i[cosine_batch compute_all_pairs
                       is_normalized? normalization_required?])
end
|
|
73
|
+
|
|
74
|
+
# Vocabulary Protocol
|
|
75
|
+
module VocabularyProtocol
  extend Protocol

  # Word <-> index mapping and basic enumeration every vocabulary provides.
  required_methods(*%i[lookup get_word include? size words
                       valid_index? common_words to_h])

  # Sampling, slicing, persistence and language metadata are optional.
  optional_methods(*%i[sample sub_vocabulary words_starting_with
                       save_to_file language_code])
end
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Protocols index - all protocol definitions
|
|
4
|
+
#
|
|
5
|
+
# This file provides access to all protocol modules.
|
|
6
|
+
# Protocols define contracts that implementations must follow.
|
|
7
|
+
|
|
8
|
+
# All protocol modules (EmbeddingModelProtocol, SimilarityEngineProtocol,
# VocabularyProtocol) are defined in protocol.rb. The gem ships no protocols/
# subdirectory, so that file is the single dependency here.
require_relative 'protocol'

# Namespace for protocol definitions (for backward compatibility).
# Each constant aliases the corresponding protocol module.
module EmbeddingProtocols
  EmbeddingModel = ::EmbeddingModelProtocol
  SimilarityEngine = ::SimilarityEngineProtocol
  Vocabulary = ::VocabularyProtocol
end
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# EmbeddingRegistry - Plugin system for embedding components
|
|
4
|
+
#
|
|
5
|
+
# Provides a centralized registry for embedding models, similarity engines,
|
|
6
|
+
# and vocabularies. Allows registering and retrieving custom implementations.
|
|
7
|
+
#
|
|
8
|
+
# @example Registering a custom model
|
|
9
|
+
#   EmbeddingRegistry.register_model(:my_model, MyCustomModel)
|
|
10
|
+
#
|
|
11
|
+
# @example Creating from registry
|
|
12
|
+
#   model = EmbeddingRegistry.create_model(:my_model, vectors: my_vectors)
|
|
13
|
+
#
|
|
14
|
+
# @example Listing available implementations
|
|
15
|
+
#   EmbeddingRegistry.models.keys # => [:onnx, :my_model]
|
|
16
|
+
#
|
|
17
|
+
class EmbeddingRegistry
  class << self
    # @return [Hash{Symbol => Class}, nil] Registered models (nil until #init or #reset! runs)
    attr_reader :models

    # @return [Hash{Symbol => Class}, nil] Registered similarity engines
    attr_reader :engines

    # @return [Hash{Symbol => Class}, nil] Registered vocabularies
    attr_reader :vocabularies

    # Create the lookup tables if they don't exist yet; existing entries are kept.
    #
    # @return [Hash] the vocabularies table
    def init
      @models ||= {}
      @engines ||= {}
      @vocabularies ||= {}
    end

    # Register an embedding model class under +name+.
    #
    # @param name [Symbol] Model identifier
    # @param klass [Class] Model class (expected to follow EmbeddingModelProtocol)
    # @return [Class] klass
    def register_model(name, klass)
      init
      models.store(name, klass)
    end

    # Register a similarity engine class under +name+.
    #
    # @param name [Symbol] Engine identifier
    # @param klass [Class] Engine class (expected to follow SimilarityEngineProtocol)
    # @return [Class] klass
    def register_engine(name, klass)
      init
      engines.store(name, klass)
    end

    # Register a vocabulary class under +name+.
    #
    # @param name [Symbol] Vocabulary identifier
    # @param klass [Class] Vocabulary class (expected to follow VocabularyProtocol)
    # @return [Class] klass
    def register_vocabulary(name, klass)
      init
      vocabularies.store(name, klass)
    end

    # @param name [Symbol] Model identifier
    # @return [Class, nil] the registered model class, if any
    def model(name)
      init
      models[name]
    end

    # @param name [Symbol] Engine identifier
    # @return [Class, nil] the registered engine class, if any
    def engine(name)
      init
      engines[name]
    end

    # @param name [Symbol] Vocabulary identifier
    # @return [Class, nil] the registered vocabulary class, if any
    def vocabulary(name)
      init
      vocabularies[name]
    end

    # Instantiate the model registered under +name+.
    #
    # @param name [Symbol] Model identifier
    # @param kwargs [Hash] Model constructor arguments
    # @return [Object] new model instance
    # @raise [ArgumentError] if nothing is registered under +name+
    def create_model(name, **kwargs)
      model_class = model(name)
      raise ArgumentError, "Unknown model: #{name}" unless model_class

      model_class.new(**kwargs)
    end

    # Instantiate the engine registered under +name+.
    #
    # @param name [Symbol] Engine identifier
    # @param kwargs [Hash] Engine constructor arguments
    # @return [Object] new engine instance
    # @raise [ArgumentError] if nothing is registered under +name+
    def create_engine(name, **kwargs)
      engine_class = engine(name)
      raise ArgumentError, "Unknown engine: #{name}" unless engine_class

      engine_class.new(**kwargs)
    end

    # Instantiate the vocabulary registered under +name+.
    #
    # @param name [Symbol] Vocabulary identifier
    # @param kwargs [Hash] Vocabulary constructor arguments
    # @return [Object] new vocabulary instance
    # @raise [ArgumentError] if nothing is registered under +name+
    def create_vocabulary(name, **kwargs)
      vocabulary_class = vocabulary(name)
      raise ArgumentError, "Unknown vocabulary: #{name}" unless vocabulary_class

      vocabulary_class.new(**kwargs)
    end

    # @return [Array<Symbol>] names of all registered models
    def model_names
      init
      models.keys
    end

    # @return [Array<Symbol>] names of all registered engines
    def engine_names
      init
      engines.keys
    end

    # @return [Array<Symbol>] names of all registered vocabularies
    def vocabulary_names
      init
      vocabularies.keys
    end

    # Drop every registration, replacing each table with an empty one.
    #
    # @return [Hash] the new (empty) vocabularies table
    def reset!
      @engines = {}
      @models = {}
      @vocabularies = {}
    end
  end
end
|
|
174
|
+
|
|
175
|
+
# Register built-in implementations
|
|
176
|
+
require_relative 'onnx_runtime_model'
|
|
177
|
+
require_relative 'similarity_engine'
|
|
178
|
+
require_relative 'vocabulary'
|
|
179
|
+
|
|
180
|
+
EmbeddingRegistry.tap do |registry|
  # ONNX Runtime-backed embedding model
  registry.register_model(:onnx, OnnxRuntimeModel)
  # Default similarity engine
  registry.register_engine(:cosine, SimilarityEngine)
  # JSON-backed vocabulary
  registry.register_vocabulary(:json, Vocabulary)
end
|