ctranslate2-4.7.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctranslate2/.dylibs/libctranslate2.4.7.0.dylib +0 -0
- ctranslate2/__init__.py +66 -0
- ctranslate2/_ext.cpython-314-darwin.so +0 -0
- ctranslate2/converters/__init__.py +8 -0
- ctranslate2/converters/converter.py +109 -0
- ctranslate2/converters/eole_ct2.py +353 -0
- ctranslate2/converters/fairseq.py +347 -0
- ctranslate2/converters/marian.py +315 -0
- ctranslate2/converters/openai_gpt2.py +95 -0
- ctranslate2/converters/opennmt_py.py +361 -0
- ctranslate2/converters/opennmt_tf.py +455 -0
- ctranslate2/converters/opus_mt.py +44 -0
- ctranslate2/converters/transformers.py +3721 -0
- ctranslate2/converters/utils.py +127 -0
- ctranslate2/extensions.py +589 -0
- ctranslate2/logging.py +45 -0
- ctranslate2/models/__init__.py +18 -0
- ctranslate2/specs/__init__.py +18 -0
- ctranslate2/specs/attention_spec.py +98 -0
- ctranslate2/specs/common_spec.py +66 -0
- ctranslate2/specs/model_spec.py +767 -0
- ctranslate2/specs/transformer_spec.py +797 -0
- ctranslate2/specs/wav2vec2_spec.py +72 -0
- ctranslate2/specs/wav2vec2bert_spec.py +97 -0
- ctranslate2/specs/whisper_spec.py +77 -0
- ctranslate2/version.py +3 -0
- ctranslate2-4.7.0.dist-info/METADATA +180 -0
- ctranslate2-4.7.0.dist-info/RECORD +31 -0
- ctranslate2-4.7.0.dist-info/WHEEL +6 -0
- ctranslate2-4.7.0.dist-info/entry_points.txt +8 -0
- ctranslate2-4.7.0.dist-info/top_level.txt +1 -0
ctranslate2/extensions.py
ADDED
@@ -0,0 +1,589 @@
import asyncio
import collections
import itertools
import queue
import threading

from typing import AsyncIterable, Callable, Iterable, List, Optional, Union

from ctranslate2._ext import (
    GenerationResult,
    GenerationStepResult,
    Generator,
    ScoringResult,
    TranslationResult,
    Translator,
)


def register_extensions():
    """Registers additional attributes to compiled modules."""
    setattr(Translator, "translate_iterable", translator_translate_iterable)
    setattr(Translator, "score_iterable", translator_score_iterable)
    setattr(Translator, "generate_tokens", translator_generate_tokens)
    setattr(Generator, "generate_iterable", generator_generate_iterable)
    setattr(Generator, "score_iterable", generator_score_iterable)
    setattr(Generator, "generate_tokens", generator_generate_tokens)
    setattr(Generator, "async_generate_tokens", generator_async_generate_tokens)


def translator_translate_iterable(
    translator: Translator,
    source: Iterable[List[str]],
    target_prefix: Optional[Iterable[List[str]]] = None,
    max_batch_size: int = 32,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[TranslationResult]:
    """Translates an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Translator.translate_batch`
    to efficiently translate an arbitrarily large stream of data. It enables the
    following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel translations (if the translator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      source: An iterable of tokenized source examples.
      target_prefix: An optional iterable of tokenized target prefixes.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any translation options accepted by
        :meth:`ctranslate2.Translator.translate_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.TranslationResult` instances.

    Example:
      This method can be used to efficiently translate text files:

      .. code-block:: python

          # Replace by your own tokenization and detokenization functions.
          tokenize_fn = lambda line: line.strip().split()
          detokenize_fn = lambda tokens: " ".join(tokens)

          with open("input.txt") as input_file:
              source = map(tokenize_fn, input_file)
              results = translator.translate_iterable(source, max_batch_size=64)

              for result in results:
                  tokens = result.hypotheses[0]
                  target = detokenize_fn(tokens)
                  print(target)
    """
    iterables = [source]
    if target_prefix is not None:
        iterables.append(target_prefix)

    yield from _process_iterable(
        translator.translate_batch,
        iterables,
        max_batch_size,
        batch_type,
        **kwargs,
    )


def translator_score_iterable(
    translator: Translator,
    source: Iterable[List[str]],
    target: Iterable[List[str]],
    max_batch_size: int = 64,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[ScoringResult]:
    """Scores an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Translator.score_batch`
    to efficiently score an arbitrarily large stream of data. It enables the
    following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel scoring (if the translator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      source: An iterable of tokenized source examples.
      target: An iterable of tokenized target examples.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any scoring options accepted by
        :meth:`ctranslate2.Translator.score_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.ScoringResult` instances.
    """
    yield from _process_iterable(
        translator.score_batch,
        [source, target],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def generator_generate_iterable(
    generator: Generator,
    start_tokens: Iterable[List[str]],
    max_batch_size: int = 32,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[GenerationResult]:
    """Generates from an iterable of tokenized prompts.

    This method is built on top of :meth:`ctranslate2.Generator.generate_batch`
    to efficiently run generation on an arbitrarily large stream of data. It enables
    the following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel generations (if the generator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      start_tokens: An iterable of tokenized prompts.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any generation options accepted by
        :meth:`ctranslate2.Generator.generate_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationResult` instances.
    """
    yield from _process_iterable(
        generator.generate_batch,
        [start_tokens],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def generator_score_iterable(
    generator: Generator,
    tokens: Iterable[List[str]],
    max_batch_size: int = 64,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[ScoringResult]:
    """Scores an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Generator.score_batch`
    to efficiently score an arbitrarily large stream of data. It enables
    the following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel scoring (if the generator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      tokens: An iterable of tokenized examples.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any scoring options accepted by
        :meth:`ctranslate2.Generator.score_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.ScoringResult` instances.
    """
    yield from _process_iterable(
        generator.score_batch,
        [tokens],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def translator_generate_tokens(
    translator: Translator,
    source: List[str],
    target_prefix: Optional[List[str]] = None,
    *,
    max_decoding_length: int = 256,
    min_decoding_length: int = 1,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    max_input_length: int = 1024,
    use_vmap: bool = False,
) -> Iterable[GenerationStepResult]:
    """Yields tokens as they are generated by the model.

    Arguments:
      source: Source tokens.
      target_prefix: Optional target prefix tokens.
      max_decoding_length: Maximum prediction length.
      min_decoding_length: Minimum prediction length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      max_input_length: Truncate inputs after this many tokens (set 0 to disable).
      use_vmap: Use the vocabulary mapping file saved in this model.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete decoding.
    """
    yield from _generate_tokens(
        translator.translate_batch,
        [source],
        [target_prefix] if target_prefix is not None else None,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_decoding_length=max_decoding_length,
        min_decoding_length=min_decoding_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        max_input_length=max_input_length,
        use_vmap=use_vmap,
    )


def generator_generate_tokens(
    generator: Generator,
    prompt: Union[List[str], List[List[str]]],
    max_batch_size: int = 0,
    batch_type: str = "examples",
    *,
    max_length: int = 512,
    min_length: int = 0,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    static_prompt: Optional[List[str]] = None,
    cache_static_prompt: bool = True,
    callback: Callable[[GenerationStepResult], bool] = None,
) -> Iterable[GenerationStepResult]:
    """Yields tokens as they are generated by the model.

    Arguments:
      prompt: Batch of start tokens. If the decoder starts from a
        special start token like <s>, this token should be added to this input.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      max_length: Maximum generation length.
      min_length: Minimum generation length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      static_prompt: If the model expects a static prompt (a.k.a. system prompt)
        it can be set here to simplify the inputs and optionally cache the model
        state for this prompt to accelerate future generations.
      cache_static_prompt: Cache the model state after the static prompt and
        reuse it for future generations using the same static prompt.
      callback: Optional function that is called for each generated token when
        :obj:`beam_size` is 1. If the callback function returns ``True``, the
        decoding will stop for this batch index.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete decoding.
    """
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
    yield from _generate_tokens(
        generator.generate_batch,
        prompt,
        max_batch_size=max_batch_size,
        batch_type=batch_type,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_length=max_length,
        min_length=min_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        static_prompt=static_prompt,
        cache_static_prompt=cache_static_prompt,
        include_prompt_in_result=False,
        callback=callback,
    )


async def generator_async_generate_tokens(
    generator: Generator,
    prompt: Union[List[str], List[List[str]]],
    max_batch_size: int = 0,
    batch_type: str = "examples",
    *,
    max_length: int = 512,
    min_length: int = 0,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    static_prompt: Optional[List[str]] = None,
    cache_static_prompt: bool = True,
    callback: Callable[[GenerationStepResult], bool] = None,
) -> AsyncIterable[GenerationStepResult]:
    """Yields tokens asynchronously as they are generated by the model.

    Arguments:
      prompt: Batch of start tokens. If the decoder starts from a
        special start token like <s>, this token should be added to this input.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      max_length: Maximum generation length.
      min_length: Minimum generation length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      static_prompt: If the model expects a static prompt (a.k.a. system prompt)
        it can be set here to simplify the inputs and optionally cache the model
        state for this prompt to accelerate future generations.
      cache_static_prompt: Cache the model state after the static prompt and
        reuse it for future generations using the same static prompt.
      callback: Optional function that is called for each generated token when
        :obj:`beam_size` is 1. If the callback function returns ``True``, the
        decoding will stop for this batch index.

    Returns:
      An async generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete decoding.
    """
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
    async for step_result in AsyncGenerator(
        generator.generate_batch,
        prompt,
        max_batch_size=max_batch_size,
        batch_type=batch_type,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_length=max_length,
        min_length=min_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        static_prompt=static_prompt,
        cache_static_prompt=cache_static_prompt,
        include_prompt_in_result=False,
        callback=callback,
    ):
        yield step_result


class AsyncGenerator:
    def __init__(self, process_func, *args, **kwargs):
        self.queue = asyncio.Queue()
        self.shutdown_event = threading.Event()
        self.iterator_task = None
        self.process_func = process_func
        self.args = args
        self.kwargs = kwargs

    async def producer(self):
        # Data generation logic here
        for step_result in _generate_tokens(
            self.process_func, *self.args, **self.kwargs
        ):
            await self.queue.put(step_result)
            await asyncio.sleep(0.0001)
            # async sleep, otherwise this doesn't yield any result
            if self.shutdown_event.is_set():
                break
        await self.queue.put(None)

    def __aiter__(self):
        self.iterator_task = asyncio.create_task(self.producer())
        return self

    async def __anext__(self):
        if self.shutdown_event.is_set():
            raise StopAsyncIteration

        try:
            item = await self.queue.get()
            if item is None:
                self.shutdown_event.set()
                raise StopAsyncIteration
            return item
        except asyncio.CancelledError:
            self.shutdown_event.set()
            raise StopAsyncIteration


def _generate_tokens(process_func, *args, **kwargs):
    step_results = queue.Queue()
    generator_closed = threading.Event()

    user_callback = kwargs.get("callback", None)
    if user_callback is None:
        user_callback = lambda step_result: False

    def _callback(step_result):
        user_callback_result = user_callback(step_result)
        step_results.put(step_result)

        return generator_closed.is_set() or user_callback_result

    kwargs.update(
        {
            "asynchronous": True,
            "beam_size": 1,
            "callback": _callback,
        }
    )

    async_results = process_func(*args, **kwargs)

    def _catch_exception():
        try:
            for result in async_results:
                result.result()
        except Exception as e:
            step_results.put(e)
        step_results.put(None)

    thread = threading.Thread(target=_catch_exception, daemon=True)
    thread.start()

    while True:
        step_result = step_results.get()

        if step_result is None:
            break

        if isinstance(step_result, Exception):
            raise step_result

        try:
            yield step_result
        except GeneratorExit:
            generator_closed.set()
            break

    # Wait for the job to terminate before exiting.
    thread.join()


def _process_iterable(process_func, iterables, max_batch_size, batch_type, **kwargs):
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")

    if len(iterables) == 1:
        iterable = iterables[0]
    else:
        iterable = itertools.zip_longest(*iterables)

    kwargs.update(
        {
            "max_batch_size": max_batch_size,
            "batch_type": batch_type,
            "asynchronous": True,
        }
    )

    read_batch_size = max_batch_size * 16 if max_batch_size > 1 else max_batch_size
    queue = collections.deque()

    for streams in _batch_iterator(iterable, read_batch_size, batch_type):
        queue.extend(process_func(*streams, **kwargs))

        while queue and queue[0].done():
            yield queue.popleft().result()

    while queue:
        yield queue.popleft().result()


def _batch_iterator(iterable, batch_size, batch_type):
    streams = None
    max_length = 0

    for example in iterable:
        if not isinstance(example, tuple):
            example = (example,)

        if batch_type == "examples":
            if streams and len(streams[0]) == batch_size:
                yield streams
                streams = None

        elif batch_type == "tokens":
            max_length = max(max_length, len(example[0]))

            if streams and (len(streams[0]) + 1) * max_length > batch_size:
                yield streams
                streams = None
                max_length = len(example[0])

        else:
            raise ValueError("Invalid batch type %s" % batch_type)

        if streams is None:
            streams = tuple([] for _ in example)
        for batch, element in zip(streams, example):
            if element is None and len(streams) > 1:
                raise ValueError("Input iterables do not have the same length")
            batch.append(element)

    if streams is not None:
        yield streams
ctranslate2/logging.py
ADDED
@@ -0,0 +1,45 @@
import logging

from ctranslate2 import _ext

_PYTHON_TO_CT2_LEVEL = {
    logging.CRITICAL: _ext.LogLevel.Critical,
    logging.ERROR: _ext.LogLevel.Error,
    logging.WARNING: _ext.LogLevel.Warning,
    logging.INFO: _ext.LogLevel.Info,
    logging.DEBUG: _ext.LogLevel.Debug,
    logging.NOTSET: _ext.LogLevel.Trace,
}

_CT2_TO_PYTHON_LEVEL = {v: k for k, v in _PYTHON_TO_CT2_LEVEL.items()}


def set_log_level(level: int):
    """Sets the CTranslate2 logging level from a Python logging level.

    Arguments:
      level: A Python logging level.

    Example:

        >>> import logging
        >>> ctranslate2.set_log_level(logging.INFO)

    Note:
      The argument is a Python logging level for convenience, but this function
      controls the C++ logs of the library.
    """
    ct2_level = _PYTHON_TO_CT2_LEVEL.get(level)
    if ct2_level is None:
        raise ValueError("Level %d is not a valid logging level" % level)
    _ext.set_log_level(ct2_level)


def get_log_level() -> int:
    """Returns the current logging level.

    Returns:
      A Python logging level.
    """
    ct2_level = _ext.get_log_level()
    return _CT2_TO_PYTHON_LEVEL[ct2_level]
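
A minimal sketch of the round trip these two helpers provide, assuming only the standard library logging module:

    import logging

    from ctranslate2.logging import get_log_level, set_log_level

    set_log_level(logging.DEBUG)    # enable verbose C++ logs from the library
    assert get_log_level() == logging.DEBUG
    set_log_level(logging.WARNING)  # restrict the C++ logs to warnings and above
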
ctranslate2/models/__init__.py
ADDED
@@ -0,0 +1,18 @@
"""A collection of models which don't fit in the generic classes :class:`ctranslate2.Translator`
and :class:`ctranslate2.Generator`.
"""

try:
    from ctranslate2._ext import (
        Wav2Vec2,
        Wav2Vec2Bert,
        Whisper,
        WhisperGenerationResult,
        WhisperGenerationResultAsync,
    )
except ImportError as e:
    # Allow using the Python package without the compiled extension.
    if "No module named" in str(e):
        pass
    else:
        raise
ctranslate2/specs/__init__.py
ADDED
@@ -0,0 +1,18 @@
from ctranslate2.specs.attention_spec import RotaryScalingType
from ctranslate2.specs.common_spec import Activation, EmbeddingsMerge
from ctranslate2.specs.model_spec import (
    LanguageModelSpec,
    LayerSpec,
    ModelSpec,
    SequenceToSequenceModelSpec,
)
from ctranslate2.specs.transformer_spec import (
    TransformerDecoderModelSpec,
    TransformerDecoderSpec,
    TransformerEncoderModelSpec,
    TransformerEncoderSpec,
    TransformerSpec,
)
from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec
from ctranslate2.specs.whisper_spec import WhisperSpec