llama-cpp-python-win 0.3.16__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. bin/convert_hf_to_gguf.py +8751 -0
  2. bin/ggml-base.dll +0 -0
  3. bin/ggml-cpu.dll +0 -0
  4. bin/ggml.dll +0 -0
  5. bin/llama-mtmd-cli.exe +0 -0
  6. bin/llama.dll +0 -0
  7. bin/mtmd.dll +0 -0
  8. include/ggml-alloc.h +76 -0
  9. include/ggml-backend.h +354 -0
  10. include/ggml-blas.h +25 -0
  11. include/ggml-cann.h +123 -0
  12. include/ggml-cpp.h +39 -0
  13. include/ggml-cpu.h +145 -0
  14. include/ggml-cuda.h +47 -0
  15. include/ggml-metal.h +66 -0
  16. include/ggml-opt.h +256 -0
  17. include/ggml-rpc.h +33 -0
  18. include/ggml-sycl.h +49 -0
  19. include/ggml-vulkan.h +29 -0
  20. include/ggml-webgpu.h +19 -0
  21. include/ggml.h +2467 -0
  22. include/gguf.h +202 -0
  23. include/llama-cpp.h +30 -0
  24. include/llama.h +1482 -0
  25. include/mtmd-helper.h +91 -0
  26. include/mtmd.h +298 -0
  27. lib/cmake/ggml/ggml-config.cmake +328 -0
  28. lib/cmake/ggml/ggml-version.cmake +65 -0
  29. lib/cmake/llama/llama-config.cmake +54 -0
  30. lib/cmake/llama/llama-version.cmake +65 -0
  31. lib/ggml-base.lib +0 -0
  32. lib/ggml-cpu.lib +0 -0
  33. lib/ggml.lib +0 -0
  34. lib/llama.lib +0 -0
  35. lib/mtmd.lib +0 -0
  36. lib/pkgconfig/llama.pc +10 -0
  37. llama_cpp/__init__.py +4 -0
  38. llama_cpp/_ctypes_extensions.py +131 -0
  39. llama_cpp/_ggml.py +12 -0
  40. llama_cpp/_internals.py +856 -0
  41. llama_cpp/_logger.py +47 -0
  42. llama_cpp/_utils.py +78 -0
  43. llama_cpp/lib/ggml-base.dll +0 -0
  44. llama_cpp/lib/ggml-base.lib +0 -0
  45. llama_cpp/lib/ggml-cpu.dll +0 -0
  46. llama_cpp/lib/ggml-cpu.lib +0 -0
  47. llama_cpp/lib/ggml.dll +0 -0
  48. llama_cpp/lib/ggml.lib +0 -0
  49. llama_cpp/lib/llama.dll +0 -0
  50. llama_cpp/lib/llama.lib +0 -0
  51. llama_cpp/lib/mtmd.dll +0 -0
  52. llama_cpp/lib/mtmd.lib +0 -0
  53. llama_cpp/llama.py +2422 -0
  54. llama_cpp/llama_cache.py +155 -0
  55. llama_cpp/llama_chat_format.py +3962 -0
  56. llama_cpp/llama_cpp.py +4374 -0
  57. llama_cpp/llama_grammar.py +953 -0
  58. llama_cpp/llama_speculative.py +64 -0
  59. llama_cpp/llama_tokenizer.py +120 -0
  60. llama_cpp/llama_types.py +316 -0
  61. llama_cpp/llava_cpp.py +158 -0
  62. llama_cpp/mtmd_cpp.py +280 -0
  63. llama_cpp/py.typed +0 -0
  64. llama_cpp/server/__init__.py +0 -0
  65. llama_cpp/server/__main__.py +100 -0
  66. llama_cpp/server/app.py +597 -0
  67. llama_cpp/server/cli.py +97 -0
  68. llama_cpp/server/errors.py +212 -0
  69. llama_cpp/server/model.py +312 -0
  70. llama_cpp/server/settings.py +240 -0
  71. llama_cpp/server/types.py +316 -0
  72. llama_cpp_python_win-0.3.16.dist-info/METADATA +856 -0
  73. llama_cpp_python_win-0.3.16.dist-info/RECORD +75 -0
  74. llama_cpp_python_win-0.3.16.dist-info/WHEEL +5 -0
  75. llama_cpp_python_win-0.3.16.dist-info/licenses/LICENSE.md +9 -0
@@ -0,0 +1,3962 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ import ctypes
7
+ import dataclasses
8
+ import random
9
+ import string
10
+
11
+ from datetime import datetime
12
+ from contextlib import ExitStack
13
+ from typing import (
14
+ Any,
15
+ Dict,
16
+ Iterator,
17
+ List,
18
+ Literal,
19
+ Optional,
20
+ Tuple,
21
+ Union,
22
+ Protocol,
23
+ cast,
24
+ )
25
+
26
+ import jinja2
27
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
28
+
29
+ import numpy as np
30
+ import numpy.typing as npt
31
+
32
+ import llama_cpp.llama_cpp as llama_cpp
33
+ import llama_cpp.llama as llama
34
+ import llama_cpp.llama_types as llama_types
35
+ import llama_cpp.llama_grammar as llama_grammar
36
+
37
+ from ._logger import logger
38
+ from ._utils import suppress_stdout_stderr, Singleton
39
+
40
### Common Chat Templates and Special Tokens ###
# These literals are compared by exact string equality against the GGUF
# "tokenizer.chat_template" metadata value in
# guess_chat_format_from_gguf_metadata, so they must stay byte-for-byte
# identical to the upstream tokenizer_config.json entries they were copied from.

# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
CHATML_BOS_TOKEN = "<s>"
CHATML_EOS_TOKEN = "<|im_end|>"

# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"

# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
# Differs from the Mistral template only in the assistant branch
# (no trailing space after eos_token).
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
57
+
58
+ ### Chat Completion Handler ###
59
+
60
+
61
class LlamaChatCompletionHandler(Protocol):
    """Base Protocol for a llama chat completion handler.

    Very generic protocol that can be used to implement any chat format.
    The only hard requirement is that it must return a ChatCompletion when
    stream=False and an iterator of ChatCompletionChunks when stream=True.

    All arguments are keyword-only. Implementations receive the model instance
    as ``llama`` plus the OpenAI-style request fields and the llama.cpp
    sampling parameters listed below; unknown extras arrive in ``**kwargs``."""

    def __call__(
        self,
        *,
        # llama.cpp instance
        llama: llama.Llama,
        # openai api parameters
        messages: List[llama_types.ChatCompletionRequestMessage],
        # `functions` / `function_call` are the legacy counterparts of
        # `tools` / `tool_choice`.
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        # NOTE(review): mutable default; implementations must not mutate it in place.
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        model: Optional[str] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        # llama.cpp parameters
        min_p: float = 0.05,
        typical_p: float = 1.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]: ...
110
+
111
+
112
class LlamaChatCompletionHandlerNotFoundException(Exception):
    """Raised when a chat completion handler is looked up by an unknown name."""
114
+
115
+
116
class LlamaChatCompletionHandlerRegistry(Singleton):
    """Name -> chat completion handler registry.

    Handlers are stored in a class-level dict, so every instance shares the
    same table. (``Singleton`` comes from ``._utils``; its exact semantics are
    not visible here.)
    """

    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}

    def register_chat_completion_handler(
        self,
        name: str,
        chat_handler: LlamaChatCompletionHandler,
        overwrite: bool = False,
    ):
        """Store ``chat_handler`` under ``name``.

        Raises:
            ValueError: if ``name`` is already taken and ``overwrite`` is false.
        """
        already_registered = name in self._chat_handlers
        if already_registered and not overwrite:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_handlers[name] = chat_handler

    def unregister_chat_handler(self, name: str):
        """Remove the handler stored under ``name``.

        Raises:
            ValueError: if nothing is registered under ``name``.
        """
        if name not in self._chat_handlers:
            raise ValueError(f"No formatter registered under the name '{name}'.")
        del self._chat_handlers[name]

    def get_chat_completion_handler_by_name(
        self, name: str
    ) -> LlamaChatCompletionHandler:
        """Return the handler registered under ``name``.

        Raises:
            LlamaChatCompletionHandlerNotFoundException: if ``name`` is unknown;
                the message lists the currently registered names.
        """
        try:
            return self._chat_handlers[name]
        except KeyError:
            known_names = list(self._chat_handlers.keys())
            raise LlamaChatCompletionHandlerNotFoundException(
                f"Invalid chat handler: {name} (valid formats: {known_names})"
            )
147
+
148
+
149
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    """Look up a registered chat completion handler by its format name."""
    registry = LlamaChatCompletionHandlerRegistry()
    return registry.get_chat_completion_handler_by_name(name)
153
+
154
+
155
def register_chat_completion_handler(name: str):
    """Decorator factory: register the decorated handler under ``name``.

    The decorated callable is returned unchanged. Registration uses the
    registry's default ``overwrite=False``, so a duplicate name raises
    ``ValueError`` at decoration time.
    """

    def _register(f: LlamaChatCompletionHandler):
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(name, f)
        return f

    return _register
161
+
162
+
163
+ ### Chat Formatter ###
164
+
165
+
166
@dataclasses.dataclass
class ChatFormatterResponse:
    """Dataclass that stores completion parameters for a given chat format and
    create_chat_completion request.

    prompt contains the formatted prompt generated from the chat format and messages.
    stop contains the stop token or list of stop tokens to use for the chat format."""

    # Fully rendered prompt text.
    prompt: str
    # Stop string(s) contributed by the chat format; None means "no extra stops".
    stop: Optional[Union[str, List[str]]] = None
    # Optional token-level stopping criteria to pass to the completion call.
    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
    # True when the prompt already contains its special tokens; the handler
    # built by chat_formatter_to_chat_completion_handler then tokenizes with
    # add_bos=False.
    added_special: bool = False
178
+
179
+
180
class ChatFormatter(Protocol):
    """Base Protocol for a chat formatter. A chat formatter is a function that
    takes a list of messages and returns a chat format response which can be used
    to generate a completion. The response can also include a stop token or list
    of stop tokens to use for the completion."""

    def __call__(
        self,
        *,
        # Conversation to render, in request order.
        messages: List[llama_types.ChatCompletionRequestMessage],
        # Extra request fields (the generic handler in this file passes
        # functions, function_call, tools and tool_choice this way).
        **kwargs: Any,
    ) -> ChatFormatterResponse: ...
192
+
193
+
194
class Jinja2ChatFormatter(ChatFormatter):
    """Chat formatter that renders the prompt from a jinja2 chat template."""

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
        stop_token_ids: Optional[List[int]] = None,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt."""
        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt
        # Stored as a set for membership tests in the stopping criterion.
        if stop_token_ids is None:
            self.stop_token_ids = None
        else:
            self.stop_token_ids = set(stop_token_ids)

        # The template is compiled once, inside a sandboxed environment.
        # Despite its name, `_environment` holds the compiled template.
        sandbox = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        )
        self._environment = sandbox.from_string(self.template)

    @staticmethod
    def strftime_now(f: str) -> str:
        """Format the current local time with the strftime pattern ``f``."""
        return datetime.now().strftime(f)

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render the template for ``messages`` and return the prompt.

        The response always uses ``eos_token`` as the stop string and sets
        ``added_special=True``. When ``stop_token_ids`` was given, a stopping
        criterion that fires on those token ids is attached as well.
        """

        def raise_exception(message: str):
            # Exposed to templates so they can reject invalid conversations.
            raise ValueError(message)

        template_context = {
            "messages": messages,
            "eos_token": self.eos_token,
            "bos_token": self.bos_token,
            "raise_exception": raise_exception,
            "add_generation_prompt": self.add_generation_prompt,
            "functions": functions,
            "function_call": function_call,
            "tools": tools,
            "tool_choice": tool_choice,
            "strftime_now": self.strftime_now,
        }
        prompt = self._environment.render(**template_context)

        if self.stop_token_ids is None:
            stopping_criteria = None
        else:

            def stop_on_last_token(
                tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
            ) -> bool:
                return tokens[-1] in self.stop_token_ids

            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

        return ChatFormatterResponse(
            prompt=prompt,
            stop=[self.eos_token],
            stopping_criteria=stopping_criteria,
            added_special=True,
        )

    def to_chat_handler(self) -> LlamaChatCompletionHandler:
        """Wrap this formatter in a generic chat completion handler."""
        return chat_formatter_to_chat_completion_handler(self)
267
+
268
+
269
+ def _convert_text_completion_logprobs_to_chat(
270
+ logprobs: Optional[llama_types.CompletionLogprobs],
271
+ ) -> llama_types.ChatCompletionLogprobs:
272
+ if logprobs is None:
273
+ return None
274
+
275
+ return {
276
+ "content": [
277
+ {
278
+ "token": token,
279
+ "bytes": None,
280
+ "logprob": logprob,
281
+ "top_logprobs": [
282
+ {
283
+ "token": top_token,
284
+ "logprob": top_logprob,
285
+ "bytes": None,
286
+ }
287
+ for top_token, top_logprob in top_logprobs.items()
288
+ ],
289
+ } for (token, logprob, top_logprobs) in zip(logprobs["tokens"], logprobs["token_logprobs"], logprobs["top_logprobs"])
290
+ ],
291
+ "refusal": None,
292
+ }
293
+
294
def _convert_text_completion_to_chat(
    completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
    """Re-shape a finished text completion as a chat completion.

    Only the first choice is carried over; its text becomes the assistant
    message content. The id is prefixed with "chat".
    """
    assert "usage" in completion
    chat_id = "chat" + completion["id"]
    created_at = completion["created"]
    model_name = completion["model"]
    first_choice = completion["choices"][0]

    assistant_message = {
        "role": "assistant",
        "content": first_choice["text"],
    }
    chat_choice = {
        "index": 0,
        "message": assistant_message,
        "logprobs": _convert_text_completion_logprobs_to_chat(
            first_choice["logprobs"]
        ),
        "finish_reason": first_choice["finish_reason"],
    }
    return {
        "id": chat_id,
        "object": "chat.completion",
        "created": created_at,
        "model": model_name,
        "choices": [chat_choice],
        "usage": completion["usage"],
    }
316
+
317
+
318
def _convert_text_completion_chunks_to_chat(
    chunks: Iterator[llama_types.CreateCompletionStreamResponse],
) -> Iterator[llama_types.ChatCompletionChunk]:
    """Translate streamed text-completion chunks into chat-completion chunks.

    Before the first translated chunk, one extra chunk is emitted whose delta
    only announces the assistant role. A chunk that carries a finish reason
    gets an empty delta.
    """
    role_announced = False
    for piece in chunks:
        if not role_announced:
            role_announced = True
            yield {
                "id": "chat" + piece["id"],
                "model": piece["model"],
                "created": piece["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                        },
                        "logprobs": None,
                        "finish_reason": None,
                    }
                ],
            }

        chat_id = "chat" + piece["id"]
        model_name = piece["model"]
        created_at = piece["created"]
        source_choice = piece["choices"][0]
        if source_choice["finish_reason"] is None:
            delta = {"content": source_choice["text"]}
        else:
            delta = {}
        yield {
            "id": chat_id,
            "model": model_name,
            "created": created_at,
            "object": "chat.completion.chunk",
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "logprobs": _convert_text_completion_logprobs_to_chat(
                        source_choice["logprobs"]
                    ),
                    "finish_reason": source_choice["finish_reason"],
                }
            ],
        }
359
+
360
+
361
def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool = False,
) -> Union[
    llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk]
]:
    """Dispatch to the streaming or non-streaming chat conversion.

    ``stream`` tells which of the two shapes ``completion_or_chunks`` has; it
    is not inspected at runtime.
    """
    if not stream:
        finished: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(finished)

    streamed: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore
    return _convert_text_completion_chunks_to_chat(streamed)
376
+
377
+
378
def _convert_completion_to_chat_function(
    tool_name: str,
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool,
):
    """Present a text completion as a call to the tool ``tool_name``.

    The generated text is used verbatim as the call arguments and is reported
    both as a legacy ``function_call`` and as a single-entry ``tool_calls``
    list. With ``stream=False`` a chat completion dict is returned; with
    ``stream=True`` a generator of chat completion chunks is returned.
    """
    if stream:
        chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore

        def _stream_response_to_function_stream(
            chunks: Iterator[llama_types.CreateCompletionStreamResponse],
        ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
            # Values captured from the first chunk; reused for the closing chunk.
            id_ = None
            created = None
            model = None
            tool_id = None
            awaiting_first = True
            for chunk in chunks:
                if awaiting_first:
                    id_ = "chat" + chunk["id"]
                    created = chunk["created"]
                    model = chunk["model"]
                    tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
                    # blank first message announcing the assistant role
                    yield {
                        "id": id_,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "finish_reason": None,
                                "logprobs": None,
                                "delta": {
                                    "role": "assistant",
                                    "content": None,
                                    "function_call": None,
                                    "tool_calls": None,
                                },
                            }
                        ],
                    }
                    awaiting_first = False

                # Every chunk (including the first) carries its text as
                # argument fragments for the same tool call id.
                yield {
                    "id": "chat" + chunk["id"],
                    "object": "chat.completion.chunk",
                    "created": chunk["created"],
                    "model": chunk["model"],
                    "choices": [
                        {
                            "index": 0,
                            "finish_reason": None,
                            "logprobs": _convert_text_completion_logprobs_to_chat(
                                chunk["choices"][0]["logprobs"]
                            ),
                            "delta": {
                                "role": None,
                                "content": None,
                                "function_call": {
                                    "name": tool_name,
                                    "arguments": chunk["choices"][0]["text"],
                                },
                                "tool_calls": [
                                    {
                                        "index": 0,
                                        "id": tool_id,
                                        "type": "function",
                                        "function": {
                                            "name": tool_name,
                                            "arguments": chunk["choices"][0]["text"],
                                        },
                                    }
                                ],
                            },
                        }
                    ],
                }

            # Closing chunk; skipped when the stream produced nothing.
            if id_ is not None and created is not None and model is not None:
                yield {
                    "id": id_,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "finish_reason": "tool_calls",
                            "logprobs": None,
                            "delta": {
                                "role": None,
                                "content": None,
                                "function_call": None,
                                "tool_calls": None,
                            },
                        }
                    ],
                }

        return _stream_response_to_function_stream(chunks)

    completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
    assert "usage" in completion
    tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
    chat_id = "chat" + completion["id"]
    created_at = completion["created"]
    model_name = completion["model"]
    first_choice = completion["choices"][0]
    arguments = first_choice["text"]
    # TODO: Fix for legacy function calls
    chat_completion: llama_types.CreateChatCompletionResponse = {
        "id": chat_id,
        "object": "chat.completion",
        "created": created_at,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": None,
                    "function_call": {
                        "name": tool_name,
                        "arguments": arguments,
                    },
                    "tool_calls": [
                        {
                            "id": tool_id,
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": arguments,
                            },
                        }
                    ],
                },
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    first_choice["logprobs"]
                ),
                "finish_reason": "tool_calls",
            }
        ],
        "usage": completion["usage"],
    }
    return chat_completion
553
+
554
+
555
def chat_formatter_to_chat_completion_handler(
    chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler around a prompt-only ``chat_formatter``.

    The returned handler renders the messages with ``chat_formatter``,
    tokenizes the prompt, merges stop sequences, optionally constrains the
    output with a grammar (JSON response format or a forced tool call), runs
    ``llama.create_completion`` and converts the result to the chat shape.
    """

    def chat_completion_handler(
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        # Render the conversation. NOTE(review): the formatter receives the
        # caller's original tools/tool_choice; the legacy functions ->
        # tools conversion below happens only after this call.
        result = chat_formatter(
            messages=messages,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )
        # Tokenize here so special tokens in the rendered text are parsed
        # (special=True); BOS is added only if the formatter did not already
        # include its special tokens.
        prompt = llama.tokenize(
            result.prompt.encode("utf-8"),
            add_bos=not result.added_special,
            special=True,
        )
        # Merge caller stops with the formatter's stops. A new list is built,
        # so the mutable default for `stop` is never modified in place.
        if result.stop is not None:
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
            stop = stop + rstop

        stopping_criteria = None
        if result.stopping_criteria is not None:
            stopping_criteria = result.stopping_criteria

        # JSON mode replaces any caller-supplied grammar.
        # (_grammar_for_response_format is defined elsewhere in this file.)
        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(
                response_format, verbose=llama.verbose
            )

        # Convert legacy functions to tools
        if functions is not None:
            tools = [
                {
                    "type": "function",
                    "function": function,
                }
                for function in functions
            ]

        # Convert legacy function_call to tool_choice
        if function_call is not None:
            if isinstance(function_call, str) and (
                function_call == "none" or function_call == "auto"
            ):
                tool_choice = function_call
            if isinstance(function_call, dict) and "name" in function_call:
                tool_choice = {
                    "type": "function",
                    "function": {
                        "name": function_call["name"],
                    },
                }

        # A dict tool_choice forces one specific tool: constrain generation to
        # that tool's parameter schema. This overrides the grammar chosen above.
        tool = None
        if (
            tool_choice is not None
            and isinstance(tool_choice, dict)
            and tools is not None
        ):
            name = tool_choice["function"]["name"]
            tool = next((t for t in tools if t["function"]["name"] == name), None)
            if tool is None:
                raise ValueError(f"Tool choice '{name}' not found in tools.")
            schema = tool["function"]["parameters"]
            try:
                # create grammar from json schema
                grammar = llama_grammar.LlamaGrammar.from_json_schema(
                    json.dumps(schema), verbose=llama.verbose
                )
            except Exception as e:
                # Best effort: if the schema cannot be converted, fall back to
                # the generic JSON grammar instead of failing the request.
                if llama.verbose:
                    print(str(e), file=sys.stderr)
                grammar = llama_grammar.LlamaGrammar.from_string(
                    llama_grammar.JSON_GBNF, verbose=llama.verbose
                )

        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            # The completion API takes a count: pass top_logprobs only when
            # the boolean chat-level `logprobs` flag is set.
            logprobs=top_logprobs if logprobs else None,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            grammar=grammar,
            logit_bias=logit_bias,
        )
        # A forced tool call is reported as function_call/tool_calls;
        # everything else is reported as a plain assistant message.
        if tool is not None:
            tool_name = tool["function"]["name"]
            return _convert_completion_to_chat_function(
                tool_name, completion_or_chunks, stream
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    return chat_completion_handler
702
+
703
+
704
def hf_autotokenizer_to_chat_formatter(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> ChatFormatter:
    """Create a chat formatter backed by a Hugging Face ``AutoTokenizer``.

    ``transformers`` is imported lazily, so it is only required when this
    function is called. The tokenizer is loaded once, up front.
    """
    # https://huggingface.co/docs/transformers/main/chat_templating
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
    from transformers import AutoTokenizer  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # type: ignore

    def format_autotokenizer(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        tokenizer.use_default_system_prompt = False  # type: ignore
        rendered: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
        assert isinstance(rendered, str)
        # The tokenizer's EOS token is the default stop sequence.
        return ChatFormatterResponse(
            prompt=rendered,
            stop=tokenizer.eos_token,
            added_special=True,
        )

    return format_autotokenizer
727
+
728
+
729
def hf_autotokenizer_to_chat_completion_handler(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler from a Hugging Face tokenizer's chat template."""
    return chat_formatter_to_chat_completion_handler(
        hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
    )
734
+
735
+
736
def hf_tokenizer_config_to_chat_formatter(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> ChatFormatter:
    """Create a chat formatter from a parsed ``tokenizer_config.json`` dict.

    The config must contain string values for ``chat_template``, ``bos_token``
    and ``eos_token`` (checked with assertions). The template is compiled once
    in a sandboxed jinja2 environment.
    """
    assert isinstance(tokenizer_config, dict)

    required_values = []
    for key in ("chat_template", "bos_token", "eos_token"):
        assert key in tokenizer_config
        assert isinstance(tokenizer_config[key], str)
        required_values.append(tokenizer_config[key])
    chat_template, bos_token, eos_token = required_values

    sandbox = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    )
    env = sandbox.from_string(chat_template)

    def format_tokenizer_config(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # TODO: verify this is correct
        # Add a blank assistant message to the end of the messages to prompt the model to generate a response
        if add_generation_prompt:
            blank_reply = llama_types.ChatCompletionRequestAssistantMessage(
                role="assistant", content=""
            )
            messages = [*messages, blank_reply]
        prompt = env.render(
            messages=messages,
            bos_token=bos_token,
            eos_token=eos_token,
        )
        return ChatFormatterResponse(
            prompt=prompt,
            stop=[eos_token, bos_token],
            added_special=True,
        )

    return format_tokenizer_config
782
+
783
+
784
def hf_tokenizer_config_to_chat_completion_handler(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler from a parsed ``tokenizer_config.json`` dict."""
    formatter = hf_tokenizer_config_to_chat_formatter(
        tokenizer_config,
        add_generation_prompt=add_generation_prompt,
    )
    handler = chat_formatter_to_chat_completion_handler(formatter)
    return handler
792
+
793
+
794
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
    """Guess the built-in chat format name from GGUF metadata.

    The ``tokenizer.chat_template`` value must equal one of the known template
    strings exactly. Returns ``None`` when the key is missing or the template
    is not recognised.
    """
    try:
        template = metadata["tokenizer.chat_template"]
    except KeyError:
        return None

    # Checked in order; Mistral and Mixtral share one format name.
    known_templates = (
        (CHATML_CHAT_TEMPLATE, "chatml"),
        (MISTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (MIXTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (LLAMA3_INSTRUCT_CHAT_TEMPLATE, "llama-3"),
    )
    for reference, format_name in known_templates:
        if template == reference:
            return format_name
    return None
811
+
812
+
813
+ ### Utility functions for formatting chat prompts ###
814
+ # TODO: Replace these with jinja2 templates
815
+
816
+
817
+ def _get_system_message(
818
+ messages: List[llama_types.ChatCompletionRequestMessage],
819
+ ) -> str:
820
+ """Get the first system message."""
821
+ for message in messages:
822
+ if message["role"] == "system":
823
+ return message["content"] or ""
824
+ return ""
825
+
826
+
827
+ def _map_roles(
828
+ messages: List[llama_types.ChatCompletionRequestMessage],
829
+ role_map: Dict[str, str],
830
+ ) -> List[Tuple[str, Optional[str]]]:
831
+ """Map the message roles."""
832
+ output: List[Tuple[str, Optional[str]]] = []
833
+ for message in messages:
834
+ role = message["role"]
835
+ if role in role_map:
836
+ content: str | None = (
837
+ message["content"] if isinstance(message["content"], str) else None
838
+ )
839
+ output.append((role_map[role], content))
840
+ return output
841
+
842
+
843
+ def _format_llama2(
844
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
845
+ ) -> str:
846
+ """Format the prompt with the llama2 style."""
847
+ seps = [sep, sep2]
848
+ ret = system_message + sep
849
+ for i, (role, message) in enumerate(messages):
850
+ if system_message and i == 0:
851
+ m = message or ""
852
+ ret += m + seps[i % 2]
853
+ elif message:
854
+ ret += role + message + " " + seps[i % 2]
855
+ else:
856
+ ret += role + " "
857
+ return ret
858
+
859
+
860
+ def _format_add_colon_single(
861
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
862
+ ) -> str:
863
+ """Format the prompt with the add-colon-single style."""
864
+ ret = system_message + sep
865
+ for role, message in messages:
866
+ if message:
867
+ ret += role + ": " + message + sep
868
+ else:
869
+ ret += role + ":"
870
+ return ret
871
+
872
+
873
+ def _format_add_colon_two(
874
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
875
+ ) -> str:
876
+ """Format the prompt with the add-colon-two style."""
877
+ seps = [sep, sep2]
878
+ ret = system_message + seps[0]
879
+ for i, (role, message) in enumerate(messages):
880
+ if message:
881
+ ret += role + ": " + message + seps[i % 2]
882
+ else:
883
+ ret += role + ":"
884
+ return ret
885
+
886
+
887
+ def _format_no_colon_single(
888
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
889
+ ) -> str:
890
+ """Format the prompt with the no-colon-single style."""
891
+ ret = system_message
892
+ for role, message in messages:
893
+ if message:
894
+ ret += role + message + sep
895
+ else:
896
+ ret += role
897
+ return ret
898
+
899
+
900
+ def _format_add_colon_space_single(
901
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
902
+ ) -> str:
903
+ """Format the prompt with the add-colon-space-single style."""
904
+ ret = system_message + sep
905
+ for role, message in messages:
906
+ if message:
907
+ ret += role + ": " + message + sep
908
+ else:
909
+ ret += role + ": " # must be end with a space
910
+ return ret
911
+
912
+
913
+ def _format_chatml(
914
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
915
+ ) -> str:
916
+ """Format the prompt with the chatml style."""
917
+ ret = "" if system_message == "" else system_message + sep + "\n"
918
+ for role, message in messages:
919
+ if message:
920
+ ret += role + "\n" + message + sep + "\n"
921
+ else:
922
+ ret += role + "\n"
923
+ return ret
924
+
925
+
926
+ def _format_chatglm3(
927
+ system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
928
+ ) -> str:
929
+ """Format the prompt with the chatglm3 style."""
930
+ ret = ""
931
+ if system_message:
932
+ ret += system_message
933
+ for role, message in messages:
934
+ if message:
935
+ ret += role + "\n" + " " + message
936
+ else:
937
+ ret += role
938
+ return ret
939
+
940
+
941
def _grammar_for_json(verbose: bool = False):
    """Build a grammar from the generic ``llama_grammar.JSON_GBNF`` definition."""
    json_gbnf = llama_grammar.JSON_GBNF
    return llama_grammar.LlamaGrammar.from_string(json_gbnf, verbose=verbose)
945
+
946
+
947
def _grammar_for_json_schema(
    schema: str, verbose: bool = False, fallback_to_json: bool = True
):
    """Build a grammar from a JSON schema string.

    If ``LlamaGrammar.from_json_schema`` raises, either fall back to the
    generic JSON grammar (``fallback_to_json=True``, the default) or re-raise
    the error.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
    except Exception as e:
        if not fallback_to_json:
            raise e
        return _grammar_for_json(verbose=verbose)
957
+
958
+
959
def _grammar_for_response_format(
    response_format: llama_types.ChatCompletionRequestResponseFormat,
    verbose: bool = False,
):
    """Pick the grammar that matches a chat-completion ``response_format``.

    Returns ``None`` unless the type is ``"json_object"``. With a ``"schema"``
    entry the schema is serialized and turned into a schema-specific grammar;
    without one the generic JSON grammar is returned.
    """
    if response_format["type"] != "json_object":
        return None

    if "schema" not in response_format:
        return _grammar_for_json(verbose=verbose)

    schema_text = json.dumps(response_format["schema"])
    return _grammar_for_json_schema(schema_text, verbose=verbose)
972
+
973
+
974
+ ### Chat Formats ###
975
+
976
+
977
def register_chat_format(name: str):
    """Return a decorator that registers a chat formatter under ``name``.

    The decorated formatter is wrapped with
    ``chat_formatter_to_chat_completion_handler`` and the resulting handler is
    registered on ``LlamaChatCompletionHandlerRegistry``. The formatter itself
    is returned unchanged, so the decorated name still refers to it.
    """

    def decorator(f: ChatFormatter):
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(
            name, chat_formatter_to_chat_completion_handler(f)
        )
        return f

    return decorator
986
+
987
+
988
+ # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
989
+ # system prompt is "embedded" in the first message
990
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-2 chat prompt.

    A non-empty system message is wrapped in ``<<SYS>>`` markers and handed to
    ``_format_llama2`` as the system block. User and assistant turns are tagged
    with ``<s>[INST]`` and ``[/INST]``, and ``[/INST]`` is appended at the end.
    """
    role_tags = {"user": "<s>[INST]", "assistant": "[/INST]"}
    turns = _map_roles(messages, role_tags)
    system_block = _get_system_message(messages)
    if system_block:
        system_block = f"[INST] <<SYS>>\n{system_block}\n<</SYS>>"
    body = _format_llama2(system_block, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=body + "[/INST]")
1003
+
1004
+
1005
+ # Chat format for Llama-3 models, see more details at:
1006
+ # https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
1007
@register_chat_format("llama-3")
def format_llama3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-3 chat prompt.

    System, user and assistant turns each get a
    ``<|start_header_id|>role<|end_header_id|>`` header and are terminated
    with ``<|eot_id|>``, which is also returned as the stop string. An open
    assistant header is appended for the model to complete.
    """
    role_headers = {
        role: f"<|start_header_id|>{role}<|end_header_id|>\n\n"
        for role in ("system", "user", "assistant")
    }
    end_of_turn = "<|eot_id|>"
    turns = _map_roles(messages, role_headers)
    turns.append((role_headers["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_no_colon_single("", turns, end_of_turn),
        stop=end_of_turn,
    )
1022
+
1023
+
1024
@register_chat_format("alpaca")
def format_alpaca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an Alpaca-style prompt.

    Uses the first system message as the preamble and renders user/assistant
    turns as ``### Instruction`` / ``### Response`` with alternating
    ``"\\n\\n"`` and ``</s>`` separators. No open assistant turn is appended.
    """
    role_labels = {"user": "### Instruction", "assistant": "### Response"}
    preamble = _get_system_message(messages)
    turns = _map_roles(messages, role_labels)
    return ChatFormatterResponse(
        prompt=_format_add_colon_two(preamble, turns, "\n\n", "</s>")
    )
1036
+
1037
+
1038
@register_chat_format("qwen")
def format_qwen(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Qwen chat prompt in chatml layout.

    Falls back to ``"You are a helpful assistant."`` when no system message is
    supplied. Turns are terminated with ``<|im_end|>`` and the stop string is
    ``<|endoftext|>``.
    """
    role_tags = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    system_text = _get_system_message(messages) or "You are a helpful assistant."
    system_block = f"<|im_start|>system\n{system_text}"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_chatml(system_block, turns, "<|im_end|>")
    return ChatFormatterResponse(prompt=rendered, stop="<|endoftext|>")
1053
+
1054
+
1055
@register_chat_format("vicuna")
def format(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Vicuna chat prompt.

    A fixed system preamble is always used; any system message in
    ``messages`` is ignored. Turns are rendered as ``USER: ...`` /
    ``ASSISTANT: ...`` with alternating ``" "`` and ``</s>`` separators.

    NOTE: this function's name shadows the builtin ``format`` at module level;
    the name is kept as-is to avoid changing the module's public surface.
    """
    preamble = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
    role_labels = {"user": "USER", "assistant": "ASSISTANT"}
    turns = _map_roles(messages, role_labels)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_labels["assistant"], None))
    rendered = _format_add_colon_two(preamble, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=rendered)
1069
+
1070
+
1071
@register_chat_format("oasst_llama")
def format_oasst_llama(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenAssistant-LLaMA prompt.

    The system message (possibly empty) is always wrapped in the
    ``[INST] <<SYS>> ... <</SYS>>`` block. Turns use the ``<|prompter|>`` /
    ``<|assistant|>`` tags and are terminated with ``</s>``.
    """
    role_tags = {"user": "<|prompter|>", "assistant": "<|assistant|>"}
    system_block = f"[INST] <<SYS>>\n{_get_system_message(messages)}\n<</SYS>>\n\n"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_no_colon_single(system_block, turns, "</s>")
    return ChatFormatterResponse(prompt=rendered)
1085
+
1086
+
1087
@register_chat_format("baichuan-2")
def format_baichuan2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Baichuan-2 prompt.

    The system message is emitted verbatim, followed by the turns tagged with
    ``<reserved_106>`` (user) and ``<reserved_107>`` (assistant), with no
    separator between turns.
    """
    role_tags = {"user": "<reserved_106>", "assistant": "<reserved_107>"}
    # The system template is just "{system_message}", i.e. the message itself.
    system_block = "{system_message}".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_no_colon_single(system_block, turns, "")
    return ChatFormatterResponse(prompt=rendered)
1101
+
1102
+
1103
@register_chat_format("baichuan")
def format_baichuan(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Baichuan prompt.

    The system message is emitted verbatim, followed by the turns tagged with
    ``<reserved_102>`` (user) and ``<reserved_103>`` (assistant), with no
    separator between turns.
    """
    role_tags = {"user": "<reserved_102>", "assistant": "<reserved_103>"}
    # The system template is just "{system_message}", i.e. the message itself.
    system_block = "{system_message}".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_no_colon_single(system_block, turns, "")
    return ChatFormatterResponse(prompt=rendered)
1117
+
1118
+
1119
@register_chat_format("openbuddy")
def format_openbuddy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenBuddy prompt.

    A fixed system preamble is always used; any system message in
    ``messages`` is ignored. Turns are rendered as ``User: ...`` /
    ``Assistant: ...`` separated by newlines.
    """
    preamble = """You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.
Always answer as helpfully and logically as possible, while being safe. Your answers should not include any harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
You can speak fluently in many languages, for example: English, Chinese.
You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.
You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.

"""
    role_labels = {"user": "User", "assistant": "Assistant"}
    turns = _map_roles(messages, role_labels)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_labels["assistant"], None))
    rendered = _format_add_colon_single(preamble, turns, "\n")
    return ChatFormatterResponse(prompt=rendered)
1139
+
1140
+
1141
@register_chat_format("redpajama-incite")
def format_redpajama_incite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a RedPajama-INCITE chat prompt.

    Uses the first system message as the preamble and renders turns as
    ``<human>: ...`` / ``<bot>: ...`` separated by newlines. Generation stops
    at ``<human>``.
    """
    role_tags = {"user": "<human>", "assistant": "<bot>"}
    preamble = _get_system_message(messages)
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_add_colon_single(preamble, turns, "\n")
    return ChatFormatterResponse(prompt=rendered, stop="<human>")
1155
+
1156
+
1157
@register_chat_format("snoozy")
def format_snoozy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a GPT4All "snoozy" prompt.

    The system message (or a default instruction when none is supplied) is
    wrapped in the ``### Instruction:`` template. Turns are rendered as
    ``### Prompt: ...`` / ``### Response: ...`` separated by newlines, and
    generation stops at ``###``.
    """
    system_template = "### Instruction:\n{system_message}"
    default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response."
    # Fall back to the default instruction when no system message was given.
    instruction = _get_system_message(messages) or default_system_message
    # Pass the *templated* text on to the formatter; using the raw message
    # here would drop the "### Instruction:" header from the prompt.
    system_message = system_template.format(system_message=instruction)
    _roles = dict(user="### Prompt", assistant="### Response")
    _sep = "\n"
    _stop = "###"
    _messages = _map_roles(messages, _roles)
    # Leave an open assistant turn for the model to fill in.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_stop)
1177
+
1178
+
1179
@register_chat_format("phind")
def format_phind(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Phind prompt.

    A fixed ``### System Prompt`` preamble is always used; any system message
    in ``messages`` is ignored. Turns are rendered as ``### User Message: ...``
    / ``### Assistant: ...`` separated by blank lines.
    """
    role_labels = {"user": "### User Message", "assistant": "### Assistant"}
    preamble = "### System Prompt\nYou are an intelligent programming assistant."
    turns = _map_roles(messages, role_labels)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_labels["assistant"], None))
    rendered = _format_add_colon_single(preamble, turns, "\n\n")
    return ChatFormatterResponse(prompt=rendered)
1191
+
1192
+
1193
@register_chat_format("intel")
def format_intel(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an Intel neural-chat style prompt.

    The first system message is substituted into the ``### System:`` template,
    and turns are rendered with the ``### User:`` / ``### Assistant:`` labels
    separated by newlines.
    """
    _roles = dict(user="### User:", assistant="### Assistant:")
    _sep = "\n"
    _system_template = "### System:\n{system_message}"
    # Fill in the template; passing it through unformatted would leak the
    # literal "{system_message}" placeholder into the prompt.
    system_message = _system_template.format(
        system_message=_get_system_message(messages)
    )
    _messages = _map_roles(messages, _roles)
    # Leave an open assistant turn for the model to fill in.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt)
1205
+
1206
+
1207
@register_chat_format("open-orca")
def format_open_orca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenOrca prompt.

    A fixed system preamble is always used; any system message in
    ``messages`` is ignored. Turns are rendered as ``User: ...`` /
    ``Assistant: ...`` terminated by ``<|end_of_turn|>`` plus a newline, and
    generation stops at ``User``.
    """
    system_template = "{system_message}"
    system_message = (
        "You are a helpful assistant. Please answer truthfully and write out your "
        "thinking step by step to be sure you get the right answer. If you make a mistake or encounter "
        "an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
        "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
        "and physicist. You will also act as the most appropriate type of expert to answer any particular "
        "question or solve the relevant problem; state which expert type your are, if so. Also think of "
        "any particular named expert that would be ideal to answer the relevant question or solve the "
        "relevant problem; name and act as them, if appropriate."
    )
    # Keys must be the lowercase request roles. Keying the map on the display
    # labels ("User"/"Assistant") matches no message, so every turn is dropped.
    _roles = dict(user="User", assistant="Assistant")
    sep = "<|end_of_turn|>\n"
    # stop_token_ids=[32000, 32001],  # "<|end_of_turn|>"
    stop_str = "User"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    # Leave an open assistant turn for the model to fill in.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_space_single(system_message, _messages, sep)
    return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
1232
+
1233
+
1234
@register_chat_format("mistrallite")
def format_mistrallite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a MistralLite prompt.

    The system message is wrapped as ``<|system|>...</s>``. User turns are
    tagged ``<|prompter|>`` and assistant turns ``</s>\\n<|assistant|>``, each
    followed by a single space separator.
    """
    role_tags = {"user": "<|prompter|>", "assistant": "</s>\n<|assistant|>"}
    system_block = f"<|system|>{_get_system_message(messages)}</s>"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_no_colon_single(system_block, turns, " ")
    return ChatFormatterResponse(prompt=rendered)
1248
+
1249
+
1250
@register_chat_format("zephyr")
def format_zephyr(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Zephyr prompt in chatml layout.

    The system message follows a ``<|system|>`` line. Turns use the
    ``<|user|>`` / ``<|assistant|>`` tags and are terminated with ``</s>``,
    which is also the stop string.
    """
    system_block = f"<|system|>\n{_get_system_message(messages)}"
    role_tags = {"user": "<|user|>\n", "assistant": "<|assistant|>\n"}
    end_of_turn = "</s>"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_chatml(system_block, turns, end_of_turn)
    return ChatFormatterResponse(prompt=rendered, stop=end_of_turn)
1265
+
1266
+
1267
@register_chat_format("pygmalion")
def format_pygmalion(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Pygmalion prompt in chatml layout.

    The system message is prefixed with ``<|system|>``. Turns use the
    ``<|user|>`` / ``<|model|>`` tags and a newline is used both as the turn
    separator and as the stop string.
    """
    system_block = f"<|system|>{_get_system_message(messages)}"
    role_tags = {"user": "<|user|>", "assistant": "<|model|>"}
    newline = "\n"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_chatml(system_block, turns, newline)
    return ChatFormatterResponse(prompt=rendered, stop=newline)
1281
+
1282
+
1283
@register_chat_format("chatml")
def format_chatml(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a ChatML prompt.

    The system message follows a ``<|im_start|>system`` line. Turns use the
    ``<|im_start|>user`` / ``<|im_start|>assistant`` tags and are terminated
    with ``<|im_end|>``, which is also the stop string.
    """
    system_block = f"<|im_start|>system\n{_get_system_message(messages)}"
    role_tags = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    end_of_turn = "<|im_end|>"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_chatml(system_block, turns, end_of_turn)
    return ChatFormatterResponse(prompt=rendered, stop=end_of_turn)
1298
+
1299
+
1300
@register_chat_format("mistral-instruct")
def format_mistral_instruct(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Mistral-Instruct prompt.

    User messages with string content are emitted as ``[INST] <content>``;
    assistant messages with content are emitted as ``" [/INST]" + content +
    "</s>"``. All other messages are ignored. The prompt always ends with
    ``" [/INST]"`` and generation stops at ``</s>``.
    """
    eos = "</s>"
    segments = []
    for message in messages:
        role = message["role"]
        # isinstance() is False for None, so this also rules out missing content.
        if role == "user" and isinstance(message["content"], str):
            segments.append("[INST] " + message["content"])
        elif role == "assistant" and message["content"] is not None:
            segments.append(" [/INST]" + message["content"] + eos)
    segments.append(" [/INST]")
    return ChatFormatterResponse(prompt="".join(segments), stop=eos)
1319
+
1320
+
1321
@register_chat_format("chatglm3")
def format_chatglm3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a ChatGLM3 prompt.

    The system message follows a ``<|system|>`` line. Turns use the
    ``<|user|>`` / ``<|assistant|>`` tags, and ``</s>`` is returned as the
    stop string.
    """
    system_block = f"<|system|>\n{_get_system_message(messages)}"
    role_tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    end_of_sequence = "</s>"
    turns = _map_roles(messages, role_tags)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_tags["assistant"], None))
    rendered = _format_chatglm3(system_block, turns, end_of_sequence)
    return ChatFormatterResponse(prompt=rendered, stop=end_of_sequence)
1336
+
1337
+
1338
@register_chat_format("openchat")
def format_openchat(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenChat prompt in chatml layout.

    The system message is followed by ``<|end_of_turn|>``. Turns use the
    ``GPT4 Correct User: `` / ``<|end_of_turn|>GPT4 Correct Assistant: ``
    labels and ``<|end_of_turn|>`` is both the separator and the stop string.
    """
    system_block = f"{_get_system_message(messages)}<|end_of_turn|>"
    role_labels = {
        "user": "GPT4 Correct User: ",
        "assistant": "<|end_of_turn|>GPT4 Correct Assistant: ",
    }
    end_of_turn = "<|end_of_turn|>"
    turns = _map_roles(messages, role_labels)
    # Leave an open assistant turn for the model to fill in.
    turns.append((role_labels["assistant"], None))
    rendered = _format_chatml(system_block, turns, end_of_turn)
    return ChatFormatterResponse(prompt=rendered, stop=end_of_turn)
1354
+
1355
+
1356
+ # Chat format for Saiga models, see more details and available models:
1357
+ # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
1358
@register_chat_format("saiga")
def format_saiga(
    messages: list[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Saiga prompt.

    Each message with content is rendered as ``<s>{role}\\n{content}</s>``; a
    message without content is rendered as an open ``<s>{role}\\n`` turn. The
    prompt ends with an open ``<s>bot`` turn and is stripped of surrounding
    whitespace.
    """
    _message_template = "<s>{role}\n{content}</s>"
    # Chat-completion requests label model turns "assistant"; map them to the
    # "bot" role so previous replies are kept in the history instead of being
    # dropped. The "bot" key is kept for callers that already use it.
    _roles = dict(user="user", assistant="bot", bot="bot", system="system")
    _messages = _map_roles(messages, _roles)

    _prompt = ""
    for role, content in _messages:
        if content:
            _prompt += _message_template.format(role=role, content=content)
        else:
            _prompt += f"<s>{role}\n"
    # Response template
    _prompt += "<s>bot"
    return ChatFormatterResponse(prompt=_prompt.strip())
1376
+
1377
+
1378
+ # Chat format for Google's Gemma models, see more details and available models:
1379
+ # https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
1380
@register_chat_format("gemma")
def format_gemma(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Gemma prompt.

    System messages are not included in the prompt; if one is present a debug
    message is logged. Turns use the ``<start_of_turn>user`` /
    ``<start_of_turn>model`` headers and are terminated with
    ``<end_of_turn>`` plus a newline, which is also the stop string.
    """
    if _get_system_message(messages) != "":
        logger.debug(
            "`role='system'` messages are not allowed on Google's Gemma models."
        )
    role_headers = {
        "user": "<start_of_turn>user\n",
        "assistant": "<start_of_turn>model\n",
    }
    end_of_turn = "<end_of_turn>\n"
    turns = _map_roles(messages, role_headers)
    # Leave an open model turn for generation.
    turns.append((role_headers["assistant"], None))
    rendered = _format_no_colon_single(
        system_message="", messages=turns, sep=end_of_turn
    )
    return ChatFormatterResponse(prompt=rendered, stop=end_of_turn)
1396
+
1397
+
1398
+ # Tricky chat formats that require custom chat handlers
1399
+
1400
+
1401
+ @register_chat_completion_handler("functionary")
1402
+ def functionary_chat_handler(
1403
+ llama: llama.Llama,
1404
+ messages: List[llama_types.ChatCompletionRequestMessage],
1405
+ functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
1406
+ function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
1407
+ tools: Optional[List[llama_types.ChatCompletionTool]] = None,
1408
+ tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
1409
+ temperature: float = 0.2,
1410
+ top_p: float = 0.95,
1411
+ top_k: int = 40,
1412
+ min_p: float = 0.05,
1413
+ typical_p: float = 1.0,
1414
+ stream: bool = False,
1415
+ stop: Optional[Union[str, List[str]]] = [],
1416
+ response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
1417
+ max_tokens: Optional[int] = None,
1418
+ presence_penalty: float = 0.0,
1419
+ frequency_penalty: float = 0.0,
1420
+ repeat_penalty: float = 1.1,
1421
+ tfs_z: float = 1.0,
1422
+ mirostat_mode: int = 0,
1423
+ mirostat_tau: float = 5.0,
1424
+ mirostat_eta: float = 0.1,
1425
+ model: Optional[str] = None,
1426
+ logits_processor: Optional[llama.LogitsProcessorList] = None,
1427
+ grammar: Optional[llama.LlamaGrammar] = None,
1428
+ **kwargs, # type: ignore
1429
+ ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
1430
+ SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
1431
+
1432
+ def generate_type_definition(
1433
+ param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
1434
+ ) -> str:
1435
+ indent = " " * indent_level
1436
+ if "$ref" in param:
1437
+ # Reference to a shared definition
1438
+ ref_name = param["$ref"].split("/")[
1439
+ -1
1440
+ ] # Extract the type name from the reference
1441
+ return ref_name
1442
+ elif param.get("type") == "array":
1443
+ items = param.get("items", {})
1444
+ item_type = generate_type_definition(items, indent_level + 1, shared_defs)
1445
+ return f"Array<{item_type}>"
1446
+ elif param.get("type") == "object":
1447
+ properties = param.get("properties", {})
1448
+ nested_schema = "{\n"
1449
+ for nested_param_name, nested_param in properties.items():
1450
+ nested_param_type = generate_type_definition(
1451
+ nested_param, indent_level + 1, shared_defs
1452
+ )
1453
+ nested_schema += (
1454
+ f"{indent} {nested_param_name}: {nested_param_type},\n"
1455
+ )
1456
+ nested_schema += indent + "}"
1457
+ return nested_schema
1458
+ elif "enum" in param:
1459
+ # Enum type
1460
+ return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
1461
+ else:
1462
+ # Simple type
1463
+ return param.get("type", "any")
1464
+
1465
+ def generate_shared_definitions(shared_defs, indent_level: int) -> str:
1466
+ indent = " " * indent_level
1467
+ shared_definitions = ""
1468
+ for def_name, def_properties in shared_defs.items():
1469
+ shared_definitions += f"{indent}type {def_name} = "
1470
+ if def_properties.get("type") == "object":
1471
+ shared_definitions += generate_type_definition(
1472
+ def_properties, indent_level, shared_defs
1473
+ )
1474
+ elif "enum" in def_properties:
1475
+ # Enum type
1476
+ shared_definitions += " | ".join(
1477
+ [f'"{enum_value}"' for enum_value in def_properties["enum"]]
1478
+ )
1479
+ shared_definitions += ";\n"
1480
+ return shared_definitions
1481
+
1482
+ def generate_schema_from_functions(functions, namespace="functions") -> str:
1483
+ schema = (
1484
+ "// Supported function definitions that should be called when necessary.\n"
1485
+ )
1486
+ schema += f"namespace {namespace} {{\n\n"
1487
+
1488
+ # Generate shared definitions
1489
+ shared_definitions = {}
1490
+ for function in functions:
1491
+ parameters = function.get("parameters", {})
1492
+ shared_definitions.update(parameters.get("$defs", {}))
1493
+
1494
+ schema += generate_shared_definitions(shared_definitions, 1)
1495
+
1496
+ for function in functions:
1497
+ function_name = function["name"]
1498
+ description = function.get("description", "")
1499
+ parameters = function.get("parameters", {})
1500
+ required_params = parameters.get("required", [])
1501
+
1502
+ schema += f" // {description}\n"
1503
+ schema += f" type {function_name} = (_: {{\n"
1504
+
1505
+ for param_name, param in parameters.get("properties", {}).items():
1506
+ param_description = param.get("description", "")
1507
+ param_type = generate_type_definition(param, 2, shared_definitions)
1508
+ optional_indicator = "" if param_name in required_params else "?"
1509
+ schema += f" // {param_description}\n"
1510
+ schema += f" {param_name}{optional_indicator}: {param_type},\n"
1511
+ schema += " }) => any;\n\n"
1512
+
1513
+ schema += "}} // namespace {}\n".format(namespace)
1514
+ return schema
1515
+
1516
+ def prepare_messages_for_inference(
1517
+ messages: List[llama_types.ChatCompletionRequestMessage],
1518
+ functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
1519
+ tools: Optional[List[llama_types.ChatCompletionTool]] = None,
1520
+ ):
1521
+ all_messages: List[llama_types.ChatCompletionRequestMessage] = []
1522
+ if functions is not None:
1523
+ all_messages.append(
1524
+ llama_types.ChatCompletionRequestSystemMessage(
1525
+ role="system", content=generate_schema_from_functions(functions)
1526
+ )
1527
+ )
1528
+
1529
+ if tools is not None:
1530
+ all_messages.append(
1531
+ llama_types.ChatCompletionRequestSystemMessage(
1532
+ role="system",
1533
+ content=generate_schema_from_functions(
1534
+ [
1535
+ tool["function"]
1536
+ for tool in tools
1537
+ if tool["type"] == "function"
1538
+ ]
1539
+ ),
1540
+ )
1541
+ )
1542
+
1543
+ all_messages.append(
1544
+ llama_types.ChatCompletionRequestSystemMessage(
1545
+ role="system", content=SYSTEM_MESSAGE
1546
+ )
1547
+ )
1548
+
1549
+ for message in messages:
1550
+ # Function call responses
1551
+ if message["role"] == "function" and "name" in message:
1552
+ message["name"] = f"functions.{message['name']}"
1553
+ # Function call requests by assistant
1554
+ if "function_call" in message:
1555
+ message["function_call"][
1556
+ "name"
1557
+ ] = f"functions.{message['function_call']['name']}"
1558
+ all_messages.append(message)
1559
+
1560
+ all_messages.append(
1561
+ llama_types.ChatCompletionRequestAssistantMessage(
1562
+ role="assistant", content=None
1563
+ )
1564
+ )
1565
+
1566
+ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
1567
+ if msg["role"] == "system":
1568
+ return f"system:\n{msg['content']}\n"
1569
+
1570
+ elif msg["role"] == "function" and "name" in msg:
1571
+ return f"function name={msg['name']}:\n{msg['content']}\n"
1572
+ elif msg["role"] == "function" and "function_call" in msg:
1573
+ return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
1574
+ elif msg["role"] == "tool":
1575
+ if msg["content"] is not None:
1576
+ return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
1577
+ else:
1578
+ return f"function name={msg['tool_call_id']}\n"
1579
+ elif msg["role"] == "user":
1580
+ if msg["content"] is None:
1581
+ return "user:\n</s></s>\n"
1582
+ else:
1583
+ return f"user:\n</s>{msg['content']}</s>\n"
1584
+ elif msg["role"] == "assistant":
1585
+ if msg["content"] is not None and "function_call" in msg:
1586
+ return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
1587
+ elif "function_call" in msg:
1588
+ return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
1589
+ elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
1590
+ for tool_call in msg[
1591
+ "tool_calls"
1592
+ ]: # NOTE: probably doesn't work with the functionary model
1593
+ return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
1594
+ elif msg["content"] is None:
1595
+ return "assistant"
1596
+ else:
1597
+ return f"assistant:\n{msg['content']}\n"
1598
+ else:
1599
+ raise ValueError(f"Unsupported role: {msg['role']}")
1600
+
1601
+ return "".join([message_to_str(msg) for msg in all_messages])
1602
+
1603
+ if tools is not None:
1604
+ functions = [tool["function"] for tool in tools if tool["type"] == "function"]
1605
+
1606
+ if tool_choice is not None:
1607
+ function_call = (
1608
+ tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
1609
+ )
1610
+
1611
+ prompt = prepare_messages_for_inference(messages, functions, tools)
1612
+
1613
+ if function_call is None and (functions is None or len(functions) == 0):
1614
+ completion_or_completion_chunks = llama.create_completion(
1615
+ prompt=prompt + ":\n",
1616
+ temperature=temperature,
1617
+ top_p=top_p,
1618
+ top_k=top_k,
1619
+ min_p=min_p,
1620
+ typical_p=typical_p,
1621
+ stream=stream,
1622
+ stop=["user:", "</s>"],
1623
+ max_tokens=max_tokens,
1624
+ presence_penalty=presence_penalty,
1625
+ frequency_penalty=frequency_penalty,
1626
+ repeat_penalty=repeat_penalty,
1627
+ tfs_z=tfs_z,
1628
+ mirostat_mode=mirostat_mode,
1629
+ mirostat_tau=mirostat_tau,
1630
+ mirostat_eta=mirostat_eta,
1631
+ model=model,
1632
+ logits_processor=logits_processor,
1633
+ grammar=grammar,
1634
+ )
1635
+ return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
1636
+
1637
+ if function_call is None or (
1638
+ isinstance(function_call, str) and function_call == "auto"
1639
+ ):
1640
+ stop = "\n"
1641
+ completion: llama_types.Completion = llama.create_completion(
1642
+ prompt=prompt, stop=stop, stream=False
1643
+ ) # type: ignore
1644
+ completion_text = completion["choices"][0]["text"]
1645
+ # strip " to=functions." and ending ":"
1646
+ function_call = completion_text.split(".")[-1][:-1]
1647
+ new_prompt = prompt + completion_text + stop
1648
+ elif isinstance(function_call, str) and function_call != "none":
1649
+ new_prompt = prompt + ":\n"
1650
+ elif isinstance(function_call, dict):
1651
+ new_prompt = prompt + f" to=functions.{function_call['name']}:\n"
1652
+ function_call = function_call["name"]
1653
+ else:
1654
+ new_prompt = prompt + ":\n"
1655
+
1656
+ function_body = None
1657
+ for function in functions or []:
1658
+ if function["name"] == function_call:
1659
+ function_body = function["parameters"]
1660
+ break
1661
+ for tool in tools or []:
1662
+ if tool["type"] == "function" and tool["function"]["name"] == function_call:
1663
+ function_body = tool["function"]["parameters"]
1664
+ break
1665
+
1666
+ if function_body is not None:
1667
+ try:
1668
+ with suppress_stdout_stderr(disable=llama.verbose):
1669
+ grammar_text = llama_grammar.json_schema_to_gbnf(
1670
+ json.dumps(function_body)
1671
+ )
1672
+ grammar = llama_grammar.LlamaGrammar.from_string(
1673
+ llama_grammar.json_schema_to_gbnf(json.dumps(function_body)),
1674
+ verbose=llama.verbose,
1675
+ )
1676
+ print(grammar_text)
1677
+ except Exception as e:
1678
+ if llama.verbose:
1679
+ print(
1680
+ "Failed to parse function body as JSON schema, falling back to default grammar"
1681
+ )
1682
+ print(e)
1683
+ with suppress_stdout_stderr(disable=llama.verbose):
1684
+ grammar = llama_grammar.LlamaGrammar.from_string(
1685
+ llama_grammar.JSON_GBNF,
1686
+ verbose=llama.verbose,
1687
+ )
1688
+ else:
1689
+ with suppress_stdout_stderr(disable=llama.verbose):
1690
+ grammar = llama_grammar.LlamaGrammar.from_string(
1691
+ llama_grammar.JSON_GBNF, verbose=llama.verbose
1692
+ )
1693
+
1694
+ completion: llama_types.Completion = llama.create_completion(
1695
+ prompt=new_prompt,
1696
+ stop=["user:", "</s>"],
1697
+ stream=False,
1698
+ grammar=grammar,
1699
+ max_tokens=max_tokens,
1700
+ temperature=temperature,
1701
+ top_p=top_p,
1702
+ top_k=top_k,
1703
+ min_p=min_p,
1704
+ typical_p=typical_p,
1705
+ presence_penalty=presence_penalty,
1706
+ frequency_penalty=frequency_penalty,
1707
+ repeat_penalty=repeat_penalty,
1708
+ tfs_z=tfs_z,
1709
+ mirostat_mode=mirostat_mode,
1710
+ mirostat_tau=mirostat_tau,
1711
+ mirostat_eta=mirostat_eta,
1712
+ model=model,
1713
+ logits_processor=logits_processor,
1714
+ ) # type: ignore
1715
+
1716
+ assert "usage" in completion
1717
+ assert isinstance(function_call, str)
1718
+ assert stream is False # TODO: support stream mode
1719
+
1720
+ if llama.verbose:
1721
+ print(new_prompt)
1722
+ print(completion["choices"][0]["text"])
1723
+
1724
+ # TODO: support stream mode
1725
+ return llama_types.CreateChatCompletionResponse(
1726
+ id="chat" + completion["id"],
1727
+ object="chat.completion",
1728
+ created=completion["created"],
1729
+ model=completion["model"],
1730
+ choices=[
1731
+ {
1732
+ "index": 0,
1733
+ "message": {
1734
+ "role": "assistant",
1735
+ "content": None,
1736
+ "function_call": {
1737
+ "name": function_call,
1738
+ "arguments": completion["choices"][0]["text"],
1739
+ },
1740
+ "tool_calls": [
1741
+ {
1742
+ "id": function_call,
1743
+ "type": "function",
1744
+ "function": {
1745
+ "name": function_call,
1746
+ "arguments": completion["choices"][0]["text"],
1747
+ },
1748
+ }
1749
+ ],
1750
+ },
1751
+ "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
1752
+ "finish_reason": "tool_calls",
1753
+ }
1754
+ ],
1755
+ usage=completion["usage"],
1756
+ )
1757
+
1758
+
1759
@register_chat_completion_handler("functionary-v1")
@register_chat_completion_handler("functionary-v2")
def functionary_v1_v2_chat_handler(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    # NOTE: the list default is never mutated here; `stop` is only rebound.
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    """Chat completion handler registered as "functionary-v1" and "functionary-v2".

    The prompt is rendered with the Hugging Face chat template of
    ``llama.tokenizer_.hf_tokenizer``; the prompt format ("v1" or "v2") is
    chosen by checking whether ``<|START_OF_FUNCTION_CALL|>`` is one of the
    tokenizer's additional special tokens.

    Args:
        llama: Model wrapper; must expose ``tokenizer_`` with an
            ``hf_tokenizer`` attribute, ``verbose`` and ``create_completion``.
        messages: Conversation so far. The caller's dicts are not modified.
        functions: Legacy function definitions. Replaced by the functions
            extracted from ``tools`` when ``tools`` is given.
        function_call: Legacy function selection; used only when
            ``tool_choice`` is None. Defaults to "auto" when both are None.
        tools: Tool definitions; only entries of type "function" are used.
        tool_choice: "auto", "none", or a dict whose "function" entry names
            the function that must be called.
        stream: When not False, a generator of chat completion chunks is
            returned (only supported for the v2 format once tools are in use).
        stop: Not read; stop sequences are derived from the prompt format.
        response_format: Not read by this handler.
        grammar: Grammar used for plain-text generation; replaced by a
            schema-derived grammar while function arguments are generated.
        temperature, top_p, top_k, min_p, typical_p, max_tokens,
        presence_penalty, frequency_penalty, repeat_penalty, tfs_z,
        mirostat_mode, mirostat_tau, mirostat_eta, model, logits_processor:
            Forwarded unchanged to ``llama.create_completion``.

    Returns:
        A chat completion response, or an iterator of chat completion chunks
        when streaming.

    Raises:
        AssertionError: If the tokenizer has no ``hf_tokenizer`` attribute, or
            if streaming with tools is requested for the v1 format.
    """
    SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

    tokenizer = llama.tokenizer_
    assert hasattr(
        tokenizer, "hf_tokenizer"
    ), "Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class"
    from transformers import AutoTokenizer

    # The v1 tokenizer is recognised by its dedicated function-call token.
    if "<|START_OF_FUNCTION_CALL|>" in tokenizer.hf_tokenizer.additional_special_tokens:
        version = "v1"
        END_SYSTEM_TOKEN = "<|END_OF_SYSTEM|>"
        END_USER_TOKEN = "<|END_OF_USER|>"
        END_ASSISTANT_TOKEN = "<|END_OF_ASSISTANT|>"
        END_FUNCTION_RESULT_TOKEN = "<|END_OF_FUNCTION_RESULT|>"
        START_FUNCTION_CALL_TOKEN = "<|START_OF_FUNCTION_CALL|>"
        END_FUNCTION_CALL_TOKEN = "<|END_OF_FUNCTION_CALL|>"
    else:
        version = "v2"
        RECIPIENT_TOKEN = "<|recipient|>"
        FROM_TOKEN = "<|from|>"
        STOP_TOKEN = "<|stop|>"
        CONTENT_TOKEN = "<|content|>"

    def generate_type_definition(
        param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
    ) -> str:
        """Render one JSON-schema node as a TypeScript-like type expression.

        ``shared_defs`` is only passed through to recursive calls.
        """
        if "$ref" in param:
            # Reference to a shared definition: keep only the type name.
            return param["$ref"].split("/")[-1]
        if param.get("type") == "array":
            items = param.get("items", {})
            item_type = generate_type_definition(items, indent_level + 1, shared_defs)
            return f"Array<{item_type}>"
        if param.get("type") == "object":
            indent = " " * indent_level
            properties = param.get("properties", {})
            nested_schema = "{\n"
            for nested_param_name, nested_param in properties.items():
                nested_param_type = generate_type_definition(
                    nested_param, indent_level + 1, shared_defs
                )
                nested_schema += (
                    f"{indent} {nested_param_name}: {nested_param_type},\n"
                )
            nested_schema += indent + "}"
            return nested_schema
        if "enum" in param:
            # Enum type: a union of quoted literals.
            return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
        # Simple type. JSON schema also allows a list of type names; render it
        # as a union so the declared ``str`` return type always holds.
        simple_type = param.get("type", "any")
        if isinstance(simple_type, list):
            return " | ".join(str(type_name) for type_name in simple_type)
        return simple_type

    def generate_shared_definitions(shared_defs, indent_level: int) -> str:
        """Render every shared ``$defs`` entry as a ``type Name = ...;`` line."""
        indent = " " * indent_level
        shared_definitions = ""
        for def_name, def_properties in shared_defs.items():
            shared_definitions += f"{indent}type {def_name} = "
            if def_properties.get("type") == "object":
                shared_definitions += generate_type_definition(
                    def_properties, indent_level, shared_defs
                )
            elif "enum" in def_properties:
                # Enum type
                shared_definitions += " | ".join(
                    [f'"{enum_value}"' for enum_value in def_properties["enum"]]
                )
            else:
                # Arrays, references and simple types previously produced an
                # empty right-hand side ("type X = ;").
                shared_definitions += generate_type_definition(
                    def_properties, indent_level, shared_defs
                )
            shared_definitions += ";\n"
        return shared_definitions

    def generate_schema_from_functions(functions, namespace="functions") -> str:
        """Build the namespace block that describes the callable functions."""
        schema = (
            "// Supported function definitions that should be called when necessary.\n"
        )
        schema += f"namespace {namespace} {{\n\n"

        # Generate shared definitions
        shared_definitions = {}
        for function in functions:
            parameters = function.get("parameters", {})
            shared_definitions.update(parameters.get("$defs", {}))

        schema += generate_shared_definitions(shared_definitions, 1)

        for function in functions:
            function_name = function["name"]
            description = function.get("description", "")
            parameters = function.get("parameters", {})
            required_params = parameters.get("required", [])

            schema += f"// {description}\n"
            schema += f"type {function_name} = (_: {{\n"

            for param_name, param in parameters.get("properties", {}).items():
                param_description = param.get("description", "")
                param_type = generate_type_definition(param, 2, shared_definitions)
                optional_indicator = "" if param_name in required_params else "?"
                schema += f"// {param_description}\n"
                schema += f"{param_name}{optional_indicator}: {param_type},\n"
            schema += "}) => any;\n\n"

        schema += "}} // namespace {}".format(namespace)
        return schema

    def prepare_messages_for_inference(
        messages: List[llama_types.ChatCompletionRequestMessage],
        tokenizer: AutoTokenizer,
        version: Literal["v1", "v2"],
        functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Union[Dict, str] = "auto",
    ):
        """Render the conversation into a prompt ending with the assistant prefix.

        Two system messages (function schema, then SYSTEM_MESSAGE) are placed
        in front of the conversation before the chat template is applied.
        """
        all_messages: List[llama_types.ChatCompletionRequestMessage] = []
        if tool_choice == "none":
            all_messages.append(
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system", content=generate_schema_from_functions([])
                )
            )
        elif functions is not None:
            all_messages.append(
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system", content=generate_schema_from_functions(functions)
                )
            )
        elif tools is not None:
            all_messages.append(
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system",
                    content=generate_schema_from_functions(
                        [
                            tool["function"]
                            for tool in tools
                            if tool["type"] == "function"
                        ]
                    ),
                )
            )

        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=SYSTEM_MESSAGE
            )
        )

        for original_message in messages:
            # Work on copies: editing the caller's dicts in place would stack
            # another "functions." prefix each time the same history is sent.
            message = dict(original_message)  # type: ignore
            # Function call responses
            if message["role"] == "function" and "name" in message:
                message["name"] = f"functions.{message['name']}"
            # Function call requests by assistant
            if "function_call" in message:
                message["function_call"] = {
                    **message["function_call"],
                    "name": f"functions.{message['function_call']['name']}",
                }
            all_messages.append(message)  # type: ignore

        if version == "v1":
            suffix = "assistant:\n"
        else:
            suffix = "<|from|>assistant\n<|recipient|>"

        return (
            tokenizer.hf_tokenizer.apply_chat_template(all_messages, tokenize=False)
            + suffix
        )

    if tools is not None:
        functions = [tool["function"] for tool in tools if tool["type"] == "function"]

    # tool_choice wins over the legacy function_call; default is "auto".
    if tool_choice is not None:
        function_call = (
            tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
        )
    elif function_call is None:
        function_call = "auto"

    prompt = prepare_messages_for_inference(
        messages, tokenizer, version, functions, tools, function_call
    )

    # If no tools/functions are provided
    if function_call == "none" or functions is None or len(functions) == 0:
        if version == "v1":
            stop = END_ASSISTANT_TOKEN
        else:
            stop = STOP_TOKEN
            # v2 addresses plain text to the "all" recipient.
            prompt += "all\n<|content|>"

        completion_or_completion_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        )
        if stream is False:
            completion_or_completion_chunks["choices"][0]["text"] = (
                completion_or_completion_chunks["choices"][0]["text"].lstrip()
            )
        return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore

    def get_grammar(function_call):
        """Return a grammar for the arguments of the function named ``function_call``.

        Falls back to the generic JSON grammar when the function is unknown or
        its schema cannot be converted.
        """
        function_body = None
        for function in functions or []:
            if function["name"] == function_call:
                function_body = function["parameters"]
                break
        for tool in tools or []:
            if tool["type"] == "function" and tool["function"]["name"] == function_call:
                function_body = tool["function"]["parameters"]
                break

        # No schema available: go straight to the generic grammar instead of
        # pushing "null" through the schema converter.
        if function_body is not None:
            try:
                with suppress_stdout_stderr(disable=llama.verbose):
                    # Convert once and reuse the text for parsing and logging.
                    grammar_text = llama_grammar.json_schema_to_gbnf(
                        json.dumps(function_body)
                    )
                    schema_grammar = llama_grammar.LlamaGrammar.from_string(
                        grammar_text
                    )
                    if llama.verbose:
                        print(grammar_text)
                return schema_grammar
            except Exception as e:
                if llama.verbose:
                    print(
                        "Failed to parse function body as JSON schema, falling back to default grammar"
                    )
                    print(e)

        with suppress_stdout_stderr(disable=llama.verbose):
            return llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )

    def create_completion(prompt, stop, grammar):
        """Call ``llama.create_completion`` with the handler's sampling settings."""
        completion = cast(
            llama_types.Completion,
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                typical_p=typical_p,
                stream=stream,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repeat_penalty=repeat_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                model=model,
                logits_processor=logits_processor,
                grammar=grammar,
            ),
        )

        return completion

    def strip_next_turn_marker(text):
        """Drop a trailing ``<|from|>assistant`` turn header from ``text``, if any."""
        for marker in ("\n<|from|>assistant\n", "\n<|from|> assistant\n"):
            if text.endswith(marker):
                return text[: -len(marker)]
        return text

    def new_tool_call_id():
        """Return "call_" followed by 24 random letters/digits."""
        alphabet = string.ascii_letters + string.digits
        return "call_" + "".join([random.choice(alphabet) for _ in range(24)])

    content = ""
    function_calls, function_bodies = [], []
    completion_tokens = 0

    def generate_streaming(tools, functions, function_call, prompt):
        """Yield chat completion chunks for the v2 format."""
        assert version == "v2", "Streaming for v1 is not supported"

        chunk_id, chunk_created = None, None

        # If tool_choice/function_call is provided
        if isinstance(function_call, dict):
            prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
            grammar = get_grammar(function_call["name"])
            stops = [STOP_TOKEN, FROM_TOKEN]
            tool_id = new_tool_call_id()
            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
            first = True
            for chunk in completion:
                # Yield the tool/function name first
                if first:
                    if tools is not None:
                        func_call_dict = {
                            "tool_calls": [
                                {
                                    "index": 0,
                                    "id": tool_id,
                                    "type": "function",
                                    "function": {
                                        "name": function_call["name"],
                                        "arguments": "",
                                    },
                                }
                            ]
                        }
                    else:
                        func_call_dict = {
                            "function_call": {
                                "name": function_call["name"],
                                "arguments": "",
                            }
                        }
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk["id"],
                        object="chat.completion.chunk",
                        created=chunk["created"],
                        model=chunk["model"],
                        choices=[
                            {
                                "index": 0,
                                "logprobs": None,
                                "delta": {
                                    "role": None,
                                    "content": None,
                                    **func_call_dict,
                                },
                            }
                        ],
                    )
                    first = False
                if tools is not None:
                    func_call_dict = {
                        "tool_calls": [
                            {
                                "index": 0,
                                "id": tool_id,
                                "type": "function",
                                "function": {
                                    "name": None,
                                    "arguments": chunk["choices"][0]["text"].rstrip(),
                                },
                            }
                        ]
                    }
                else:
                    func_call_dict = {
                        "function_call": {
                            "name": None,
                            "arguments": chunk["choices"][0]["text"].rstrip(),
                        }
                    }
                if len(chunk["choices"][0]["text"].rstrip()) > 0:
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk["id"],
                        object="chat.completion.chunk",
                        created=chunk["created"],
                        model=chunk["model"],
                        choices=[
                            {
                                "index": 0,
                                "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                "delta": {
                                    "role": None,
                                    "content": None,
                                    **func_call_dict,
                                },
                            }
                        ],
                    )
            # Yield tool_call/function_call stop message
            # NOTE(review): reuses the last `chunk`; assumes the completion
            # stream produced at least one chunk.
            yield llama_types.CreateChatCompletionStreamResponse(
                id="chat" + chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=[
                    {
                        "index": 0,
                        "finish_reason": (
                            "tool_calls" if tools is not None else "function_call"
                        ),
                        "logprobs": None,
                        "delta": {
                            "role": None,
                            "content": None,
                            "function_call": None,
                            "tool_calls": None,
                        },
                    }
                ],
            )
        # If "auto" or no tool_choice/function_call
        elif isinstance(function_call, str) and function_call == "auto":
            tool_index = 0
            while True:
                # Generate function name first
                grammar = None
                stops = CONTENT_TOKEN
                completion = create_completion(
                    prompt=prompt, stop=stops, grammar=grammar
                )
                completion_text = ""
                for chunk in completion:
                    completion_text += chunk["choices"][0]["text"]
                    # All chunks of one response share the first id/timestamp.
                    if chunk_id is None:
                        chunk_id = chunk["id"]
                    if chunk_created is None:
                        chunk_created = chunk["created"]
                function_name = completion_text.strip()
                if function_name == "all":
                    prompt += "all\n<|content|>"
                    # Yield the first empty message for content
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        model=chunk["model"],
                        created=chunk_created,
                        object="chat.completion.chunk",
                        choices=[
                            {
                                "index": 0,
                                "delta": {"role": "assistant", "content": ""},
                                "logprobs": None,
                                "finish_reason": None,
                            }
                        ],
                    )
                else:
                    prompt += f"{function_name}\n<|content|>"
                    grammar = get_grammar(function_name)
                    tool_id = new_tool_call_id()
                    if tools is not None:
                        func_call_dict = {
                            "tool_calls": [
                                {
                                    "index": tool_index,
                                    "id": tool_id,
                                    "type": "function",
                                    "function": {
                                        "name": function_name,
                                        "arguments": "",
                                    },
                                }
                            ]
                        }
                    else:
                        func_call_dict = {
                            "function_call": {"name": function_name, "arguments": ""}
                        }
                    # Stream function name
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        object="chat.completion.chunk",
                        created=chunk_created,
                        model=chunk["model"],
                        choices=[
                            {
                                "index": 0,
                                "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                "delta": {
                                    "role": "assistant",
                                    "content": None,
                                    **func_call_dict,
                                },
                            }
                        ],
                    )
                # Generate content
                stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                completion = create_completion(
                    prompt=prompt, stop=stops, grammar=grammar
                )
                if function_name == "all":
                    completion_text = ""
                    # Text after a newline is held back in `buffer` while it
                    # could still be the start of the next-turn header, so the
                    # header itself is never streamed as content.
                    stop_sequence, buffer, is_end = (
                        "\n<|from|>assistant\n<|recipient|>",
                        [],
                        False,
                    )
                    for i, chunk in enumerate(completion):
                        completion_text += chunk["choices"][0]["text"]
                        if is_end:
                            buffer.append(chunk["choices"][0]["text"].strip(" "))
                            if stop_sequence.startswith("".join(buffer)):
                                continue
                            else:
                                # Not the header after all: flush what was held
                                # back, then handle the current chunk normally.
                                buffer.pop()
                                while len(buffer) > 0:
                                    yield llama_types.CreateChatCompletionStreamResponse(
                                        id="chat" + chunk_id,
                                        object="chat.completion.chunk",
                                        created=chunk_created,
                                        model=chunk["model"],
                                        choices=[
                                            {
                                                "index": 0,
                                                "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                                "delta": {
                                                    "role": "assistant",
                                                    "content": buffer.pop(0),
                                                },
                                            }
                                        ],
                                    )
                                is_end = False
                        elif chunk["choices"][0]["text"] == "\n":
                            is_end = True
                            buffer.append(chunk["choices"][0]["text"].strip(" "))
                            continue

                        if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0:
                            yield llama_types.CreateChatCompletionStreamResponse(
                                id="chat" + chunk_id,
                                object="chat.completion.chunk",
                                created=chunk_created,
                                model=chunk["model"],
                                choices=[
                                    {
                                        "index": 0,
                                        "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                        "delta": {
                                            "role": "assistant",
                                            "content": (
                                                chunk["choices"][0]["text"]
                                                if i > 0
                                                else chunk["choices"][0][
                                                    "text"
                                                ].lstrip()
                                            ),
                                        },
                                    }
                                ],
                            )
                    # Check whether the model wants to generate another turn
                    if (
                        "<|from|> assistant" in completion_text
                        or "<|from|>assistant" in completion_text
                    ):
                        cleaned_completion_text = strip_next_turn_marker(
                            completion_text
                        ).strip()
                        prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                    else:
                        # Yield stop message
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            model=chunk["model"],
                            created=chunk_created,
                            object="chat.completion.chunk",
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {},
                                    "logprobs": None,
                                    "finish_reason": "stop",
                                }
                            ],
                        )
                        break
                else:
                    # Stream the function arguments
                    completion_text = ""
                    for chunk in completion:
                        completion_text += chunk["choices"][0]["text"]
                        if len(chunk["choices"][0]["text"].rstrip()) > 0:
                            if tools is not None:
                                func_call_dict = {
                                    "tool_calls": [
                                        {
                                            "index": tool_index,
                                            "id": tool_id,
                                            "type": "function",
                                            "function": {
                                                "name": None,
                                                "arguments": chunk["choices"][0][
                                                    "text"
                                                ].rstrip(),
                                            },
                                        }
                                    ]
                                }
                            else:
                                func_call_dict = {
                                    "function_call": {
                                        "name": None,
                                        "arguments": chunk["choices"][0][
                                            "text"
                                        ].rstrip(),
                                    }
                                }
                            yield llama_types.CreateChatCompletionStreamResponse(
                                id="chat" + chunk_id,
                                object="chat.completion.chunk",
                                created=chunk_created,
                                model=chunk["model"],
                                choices=[
                                    {
                                        "index": 0,
                                        "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                        "delta": {
                                            "role": None,
                                            "content": None,
                                            **func_call_dict,
                                        },
                                    }
                                ],
                            )
                    # Check whether the model wants to generate another turn
                    prompt += completion_text.strip()
                    grammar = None
                    completion = create_completion(
                        prompt=prompt, stop=stops, grammar=grammar
                    )
                    completion_text += "".join(
                        [chunk["choices"][0]["text"] for chunk in completion]
                    )
                    if (
                        "<|from|> assistant" in completion_text
                        or "<|from|>assistant" in completion_text
                    ) and tools is not None:
                        prompt += "\n<|from|>assistant\n<|recipient|>"
                        tool_index += 1
                    else:
                        # Yield tool_call/function_call stop message
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            object="chat.completion.chunk",
                            created=chunk_created,
                            model=chunk["model"],
                            choices=[
                                {
                                    "index": 0,
                                    "finish_reason": (
                                        "tool_calls"
                                        if tools is not None
                                        else "function_call"
                                    ),
                                    "logprobs": None,
                                    "delta": {
                                        "role": None,
                                        "content": None,
                                        "function_call": None,
                                        "tool_calls": None,
                                    },
                                }
                            ],
                        )
                        break

    if stream is not False:
        return generate_streaming(
            tools=tools, functions=functions, function_call=function_call, prompt=prompt
        )

    if version == "v1":
        # If tool_choice/function_call is provided
        if isinstance(function_call, dict):
            prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
            stops = END_FUNCTION_CALL_TOKEN
            function_call = function_call["name"]
            function_calls.append(function_call)
            grammar = get_grammar(function_call)
        # If no or "auto" tool_choice/function_call
        else:
            stops = ["\n", END_ASSISTANT_TOKEN]

        completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
        completion_text = completion["choices"][0]["text"]
        completion_tokens += completion["usage"]["completion_tokens"]

        # If the generation does not involve a function call
        if (
            START_FUNCTION_CALL_TOKEN not in prompt
            and START_FUNCTION_CALL_TOKEN not in completion_text
        ):
            completion["usage"]["completion_tokens"] = completion_tokens
            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
        # If the generation involves a function call in completion, generate the parameters
        elif (
            START_FUNCTION_CALL_TOKEN not in prompt
            and START_FUNCTION_CALL_TOKEN in completion_text
        ):
            prompt += (
                completion_text.replace(
                    f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN
                )
                + "\n"
            )
            function_calls.append(
                completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
            )
            grammar = get_grammar(function_calls[-1])
            completion = create_completion(
                prompt=prompt, stop=END_FUNCTION_CALL_TOKEN, grammar=grammar
            )
            completion_tokens += completion["usage"]["completion_tokens"]
            function_bodies.append(completion["choices"][0]["text"].strip())
        # If the prompt involves a function call, just append generated parameters to function_bodies
        else:
            function_bodies.append(completion_text.strip())
    else:
        # If tool_choice/function_call is provided
        if isinstance(function_call, dict):
            prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
            function_call = function_call["name"]
            function_calls.append(function_call)
            grammar = get_grammar(function_call)
            stops = [STOP_TOKEN, FROM_TOKEN]
            completion = create_completion(
                prompt=prompt, stop=stops, grammar=grammar
            )
            completion_text = completion["choices"][0]["text"]
            completion_tokens += completion["usage"]["completion_tokens"]
            function_bodies.append(completion_text.strip())
        # If "auto" or no tool_choice/function_call
        elif isinstance(function_call, str) and function_call == "auto":
            while True:
                # Generate function name first
                grammar = None
                stops = CONTENT_TOKEN
                completion = create_completion(
                    prompt=prompt, stop=stops, grammar=grammar
                )
                completion_text = completion["choices"][0]["text"]
                completion_tokens += completion["usage"]["completion_tokens"]
                function_name = completion_text.strip()
                if function_name == "all":
                    prompt += "all\n<|content|>"
                else:
                    prompt += f"{function_name}\n<|content|>"
                    function_calls.append(function_name)
                    grammar = get_grammar(function_name)
                # Generate content
                stops = [RECIPIENT_TOKEN, STOP_TOKEN]
                completion = create_completion(
                    prompt=prompt, stop=stops, grammar=grammar
                )
                completion_text = completion["choices"][0]["text"]
                completion_tokens += completion["usage"]["completion_tokens"]
                if function_name == "all":
                    # Append the text exactly once, without the next-turn
                    # header. The previous if/if/else chain appended it twice
                    # for one marker and kept a single character for the other.
                    content += strip_next_turn_marker(completion_text)
                    content = content.lstrip()
                    # Check whether the model wants to generate another turn
                    if (
                        "<|from|> assistant" in completion_text
                        or "<|from|>assistant" in completion_text
                    ):
                        cleaned_completion_text = strip_next_turn_marker(
                            completion_text
                        ).strip()
                        prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                    else:
                        break
                else:
                    function_bodies.append(completion_text.strip())
                    # Check whether the model wants to generate another turn
                    prompt += completion_text.strip()
                    grammar = None
                    completion = create_completion(
                        prompt=prompt, stop=stops, grammar=grammar
                    )
                    completion_tokens += completion["usage"]["completion_tokens"]
                    if (
                        "<|from|> assistant" in completion["choices"][0]["text"]
                        or "<|from|>assistant" in completion["choices"][0]["text"]
                    ):
                        prompt += "\n<|from|>assistant\n<|recipient|>"
                    else:
                        break

    assert "usage" in completion
    assert len(function_calls) == len(function_bodies)

    tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
    for called_name, called_arguments in zip(function_calls, function_bodies):
        tool_calls.append(
            {
                "id": new_tool_call_id(),
                "type": "function",
                "function": {
                    "name": called_name,
                    "arguments": called_arguments,
                },
            }
        )

    # TODO: support stream mode
    function_call_dict: Union[
        Dict[str, str],
        Dict[
            Literal["function_call"],
            llama_types.ChatCompletionRequestAssistantMessageFunctionCall,
        ],
    ] = {}
    if len(tool_calls) > 0:
        if tools is not None:
            function_call_dict["tool_calls"] = tool_calls
        else:
            # The legacy function_call field carries only the first call.
            function_call_dict["function_call"] = {
                "name": tool_calls[0]["function"]["name"],
                "arguments": tool_calls[0]["function"]["arguments"],
            }
    # Report the tokens of every completion made for this response.
    completion["usage"]["completion_tokens"] = completion_tokens
    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
        created=completion["created"],
        model=completion["model"],
        choices=[
            {
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
                "message": {
                    "role": "assistant",
                    "content": None if content == "" else content,
                    **function_call_dict,
                },
                "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
            }
        ],
        usage=completion["usage"],
    )
2657
+
2658
+
2659
class Llava15ChatHandler:
    """Chat handler for LLaVA-1.5 style prompts with image inputs.

    ``CHAT_FORMAT`` is a Jinja2 template over ``messages`` and
    ``add_generation_prompt``. Image parts are rendered as their URL text so
    that they can be located in the rendered prompt afterwards.
    """

    # System prompt used when the conversation supplies none; subclasses set
    # this to ``None`` to disable the fallback.
    DEFAULT_SYSTEM_MESSAGE: Optional[str] = (
        "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    )

    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is string %}"
        "\nUSER: {{ message.content }}"
        "{% endif %}"
        # Jinja's `iterable` test is also true for strings, so strings must be
        # excluded here; otherwise a plain-text user turn would emit a second,
        # empty "USER: " header after the one rendered above.
        "{% if message.content is iterable and message.content is not string %}"
        "\nUSER: "
        # All images of the turn come first ...
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        # ... followed by all text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "{% endif %}"
        "{% if message.role == 'assistant' and message.content is not none %}"
        "\nASSISTANT: {{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "\nASSISTANT: "
        "{% endif %}"
    )
2698
+
2699
def __init__(self, clip_model_path: str, verbose: bool = True):
    """Record the projector model path and prepare lazy mtmd state.

    The native mtmd context is not created here; ``mtmd_ctx`` starts as
    ``None``.

    Raises:
        ValueError: if ``clip_model_path`` does not exist on disk.
    """
    # Imported lazily so the bindings are only loaded when a handler is built.
    import llama_cpp.mtmd_cpp as mtmd_cpp

    self._mtmd_cpp = mtmd_cpp
    self._exit_stack = ExitStack()
    self.mtmd_ctx: Optional[mtmd_cpp.mtmd_context_p] = None
    self.verbose = verbose
    self.clip_model_path = clip_model_path

    path_is_present = os.path.exists(clip_model_path)
    if not path_is_present:
        raise ValueError(f"Clip model path does not exist: {clip_model_path}")
2710
+
2711
def _init_mtmd_context(self, llama_model: llama.Llama):
    """Initialize the mtmd context for ``llama_model`` on first use.

    Does nothing when ``self.mtmd_ctx`` is already set. On success a cleanup
    callback that frees the context is registered on ``self._exit_stack``.

    Raises:
        ValueError: if the context cannot be loaded from
            ``self.clip_model_path`` or the loaded model reports no vision
            support. In the latter case the context is freed and
            ``self.mtmd_ctx`` is reset to ``None`` before raising.
    """
    if self.mtmd_ctx is not None:
        return  # Already initialized

    with suppress_stdout_stderr(disable=self.verbose):
        # Get default parameters
        ctx_params = self._mtmd_cpp.mtmd_context_params_default()
        ctx_params.use_gpu = True  # TODO: Make this configurable
        ctx_params.print_timings = self.verbose
        ctx_params.n_threads = llama_model.n_threads
        ctx_params.verbosity = 2 if self.verbose else 0  # GGML_LOG_LEVEL_INFO = 2

        # Initialize mtmd context
        self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file(
            self.clip_model_path.encode(),
            llama_model.model,
            ctx_params
        )

        if self.mtmd_ctx is None:
            raise ValueError(f"Failed to load mtmd context from: {self.clip_model_path}")

        # Check if vision is supported
        if not self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx):
            # Free and forget the context before failing. Leaving it set would
            # leak it (no cleanup callback exists yet) and would make the next
            # call return early as "already initialized", skipping this check.
            self._mtmd_cpp.mtmd_free(self.mtmd_ctx)
            self.mtmd_ctx = None
            raise ValueError("Vision is not supported by this model")

        def mtmd_free():
            with suppress_stdout_stderr(disable=self.verbose):
                if self.mtmd_ctx is not None:
                    self._mtmd_cpp.mtmd_free(self.mtmd_ctx)
                    self.mtmd_ctx = None

        self._exit_stack.callback(mtmd_free)
2745
+
2746
def load_image(self, image_url: str) -> bytes:
    """Return the bytes for ``image_url`` by delegating to ``_load_image``."""
    image_bytes = self._load_image(image_url)
    return image_bytes
2748
+
2749
def _create_bitmap_from_bytes(self, image_bytes: bytes):
    """Build an mtmd bitmap from encoded image bytes.

    Raises:
        ValueError: if the mtmd context has not been initialized, or the
            helper returns ``None`` for the given bytes.
    """
    if self.mtmd_ctx is None:
        raise ValueError("mtmd context not initialized")

    byte_count = len(image_bytes)
    with suppress_stdout_stderr(disable=self.verbose):
        # from_buffer needs a writable object, hence the bytearray copy.
        raw_buffer = (ctypes.c_uint8 * byte_count).from_buffer(bytearray(image_bytes))
        handle = self._mtmd_cpp.mtmd_helper_bitmap_init_from_buf(
            self.mtmd_ctx,
            raw_buffer,
            byte_count,
        )
        if handle is not None:
            return handle
        raise ValueError("Failed to create bitmap from image bytes")
2766
+
2767
def __call__(
    self,
    *,
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[
        llama_types.ChatCompletionRequestResponseFormat
    ] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Run a chat completion over a conversation that may contain images.

    Steps, in order:
      1. Ensure the mtmd context exists and prepend ``DEFAULT_SYSTEM_MESSAGE``
         when the conversation has no system prompt.
      2. Render ``CHAT_FORMAT`` and replace every image URL in the rendered
         text with the mtmd media marker.
      3. Load each image, tokenize text and images together, reset the llama
         context and evaluate the resulting chunks one by one.
      4. Derive a grammar from ``response_format`` / the selected tool, then
         call ``llama.create_completion`` with the evaluated token ids as the
         prompt and convert the result to chat format.

    Raises:
        ValueError: if chunk creation, tokenization or chunk evaluation
            fails, if the prompt would exceed ``llama.n_ctx()``, or if the
            tool named by ``tool_choice`` is not present in ``tools``.
    """
    # Initialize mtmd context
    self._init_mtmd_context(llama)
    assert self.mtmd_ctx is not None

    system_prompt = _get_system_message(messages)
    if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None:
        messages = [
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=self.DEFAULT_SYSTEM_MESSAGE
            )
        ] + messages

    image_urls = self.get_image_urls(messages)
    template = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    ).from_string(self.CHAT_FORMAT)

    # Get the default media marker
    media_marker = self._mtmd_cpp.mtmd_default_marker().decode('utf-8')

    # Render the prompt; image parts appear in the output as their URL text
    text = template.render(
        messages=messages,
        add_generation_prompt=True,
        eos_token=llama.detokenize([llama.token_eos()]),
        bos_token=llama.detokenize([llama.token_bos()]),
    )

    # Replace image URLs in text with media markers
    for image_url in image_urls:
        text = text.replace(image_url, media_marker)

    if self.verbose:
        print(text, file=sys.stderr)

    # Create bitmaps from images
    bitmaps = []
    # Tracks the same bitmaps as `bitmaps`; iterated in the outer `finally`
    # so every bitmap created so far is freed even when a later step raises.
    bitmap_cleanup = []
    try:
        for image_url in image_urls:
            image_bytes = self.load_image(image_url)
            bitmap = self._create_bitmap_from_bytes(image_bytes)
            bitmaps.append(bitmap)
            bitmap_cleanup.append(bitmap)

        # Create input text structure
        input_text = self._mtmd_cpp.mtmd_input_text()
        input_text.text = text.encode('utf-8')
        input_text.add_special = True
        input_text.parse_special = True

        # Create input chunks
        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
        if chunks is None:
            raise ValueError("Failed to create input chunks")

        try:
            # Tokenize text and images together
            bitmap_array = (self._mtmd_cpp.mtmd_bitmap_p_ctypes * len(bitmaps))(*bitmaps)
            result = self._mtmd_cpp.mtmd_tokenize(
                self.mtmd_ctx,
                chunks,
                ctypes.byref(input_text),
                bitmap_array,
                len(bitmaps)
            )

            if result != 0:
                raise ValueError(f"Failed to tokenize input: error code {result}")

            # Reset llama context
            llama.reset()
            llama._ctx.kv_cache_clear()

            # Process each chunk
            # NOTE(review): n_past is never read below; positions are taken
            # from llama.n_tokens instead.
            n_past = llama_cpp.llama_pos(0)
            n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks)

            for i in range(n_chunks):
                chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i)
                if chunk is None:
                    continue

                chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk)

                if chunk_type == self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_TEXT:
                    # Handle text chunk
                    n_tokens_out = ctypes.c_size_t()
                    tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text(
                        chunk, ctypes.byref(n_tokens_out)
                    )

                    if tokens_ptr and n_tokens_out.value > 0:
                        # Convert ctypes array to Python list
                        tokens = [tokens_ptr[j] for j in range(n_tokens_out.value)]

                        if llama.n_tokens + len(tokens) > llama.n_ctx():
                            raise ValueError(
                                f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}"
                            )
                        llama.eval(tokens)

                elif chunk_type in [self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_AUDIO]:
                    # Handle image/audio chunk using helper
                    chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk)

                    if llama.n_tokens + chunk_n_tokens > llama.n_ctx():
                        raise ValueError(
                            f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}"
                        )

                    # Passed by reference; presumably filled in by the helper
                    # with the position after this chunk — TODO confirm
                    new_n_past = llama_cpp.llama_pos(0)
                    result = self._mtmd_cpp.mtmd_helper_eval_chunk_single(
                        self.mtmd_ctx,
                        llama._ctx.ctx,
                        chunk,
                        llama_cpp.llama_pos(llama.n_tokens),
                        llama_cpp.llama_seq_id(0),
                        llama.n_batch,
                        False,  # logits_last
                        ctypes.byref(new_n_past)
                    )

                    if result != 0:
                        raise ValueError(f"Failed to evaluate chunk: error code {result}")

                    # Update llama's token count
                    llama.n_tokens = new_n_past.value

            # Get prompt tokens to avoid a cache miss
            prompt = llama.input_ids[: llama.n_tokens].tolist()

        finally:
            self._mtmd_cpp.mtmd_input_chunks_free(chunks)

    finally:
        # Cleanup bitmaps
        for bitmap in bitmap_cleanup:
            self._mtmd_cpp.mtmd_bitmap_free(bitmap)

    # A JSON response format overrides any caller-supplied grammar
    if response_format is not None and response_format["type"] == "json_object":
        grammar = _grammar_for_response_format(response_format)

    # Convert legacy functions to tools
    if functions is not None:
        tools = [
            {
                "type": "function",
                "function": function,
            }
            for function in functions
        ]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and (
            function_call == "none" or function_call == "auto"
        ):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {
                "type": "function",
                "function": {
                    "name": function_call["name"],
                },
            }

    # Only a dict-valued tool_choice selects a specific tool; the chosen
    # tool's parameter schema then constrains generation through a grammar.
    tool = None
    if (
        tool_choice is not None
        and isinstance(tool_choice, dict)
        and tools is not None
    ):
        name = tool_choice["function"]["name"]
        tool = next((t for t in tools if t["function"]["name"] == name), None)
        if tool is None:
            raise ValueError(f"Tool choice '{name}' not found in tools.")
        schema = tool["function"]["parameters"]
        try:
            # create grammar from json schema
            grammar = llama_grammar.LlamaGrammar.from_json_schema(
                json.dumps(schema), verbose=llama.verbose
            )
        except Exception as e:
            # Fall back to the generic JSON grammar when the schema cannot
            # be converted.
            if llama.verbose:
                print(str(e), file=sys.stderr)
            grammar = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )

    completion_or_chunks = llama.create_completion(
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=top_logprobs if logprobs else None,
        stream=stream,
        stop=stop,
        seed=seed,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )

    if tool is not None:
        tool_name = tool["function"]["name"]
        return _convert_completion_to_chat_function(
            tool_name, completion_or_chunks, stream
        )
    return _convert_completion_to_chat(completion_or_chunks, stream=stream)
3029
+
3030
+ @staticmethod
3031
+ def _load_image(image_url: str) -> bytes:
3032
+ # TODO: Add Pillow support for other image formats beyond (jpg, png)
3033
+ if image_url.startswith("data:"):
3034
+ import base64
3035
+ image_bytes = base64.b64decode(image_url.split(",")[1])
3036
+ return image_bytes
3037
+ else:
3038
+ import urllib.request
3039
+ with urllib.request.urlopen(image_url) as f:
3040
+ image_bytes = f.read()
3041
+ return image_bytes
3042
+
3043
+ @staticmethod
3044
+ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
3045
+ image_urls: List[str] = []
3046
+ for message in messages:
3047
+ if message["role"] == "user":
3048
+ if message["content"] is None:
3049
+ continue
3050
+ for content in message["content"]:
3051
+ if isinstance(content, dict) and "type" in content:
3052
+ if content["type"] == "image_url":
3053
+ if (
3054
+ isinstance(content["image_url"], dict)
3055
+ and "url" in content["image_url"]
3056
+ ):
3057
+ image_urls.append(content["image_url"]["url"])
3058
+ else:
3059
+ image_urls.append(content["image_url"])
3060
+ return image_urls
3061
+
3062
+ @staticmethod
3063
+ def split_text_on_image_urls(text: str, image_urls: List[str]):
3064
+ """This method is no longer used in the new implementation."""
3065
+ def find_first(s: str, substrs: List[str]):
3066
+ for i, substr in enumerate(substrs):
3067
+ pos = s.find(substr)
3068
+ if pos != -1:
3069
+ return pos, i
3070
+ return None, None
3071
+
3072
+ split_text: List[Tuple[Literal["text", "image_url"], str]] = []
3073
+ remaining = text
3074
+ while remaining:
3075
+ # Find first image_url
3076
+ pos, i = find_first(remaining, image_urls)
3077
+ if pos is not None and i is not None:
3078
+ if pos > 0:
3079
+ split_text.append(("text", remaining[:pos]))
3080
+ split_text.append(("image_url", image_urls[i]))
3081
+ remaining = remaining[pos + len(image_urls[i]) :]
3082
+ else:
3083
+ split_text.append(("text", remaining))
3084
+ remaining = ""
3085
+ return split_text
3086
+
3087
@classmethod
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    local_dir: Optional[Union[str, os.PathLike[str]]] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    **kwargs: Any,
) -> "Llava15ChatHandler":
    """Download a projector file from the Hugging Face Hub and build a handler.

    Args:
        repo_id: Hub repository to search.
        filename: ``fnmatch`` pattern that must match exactly one file among
            the entries returned by ``HfFileSystem().ls(repo_id)``.
        local_dir: Optional download directory; when given, the returned
            handler points at ``os.path.join(local_dir, <file name>)``.
        local_dir_use_symlinks: Forwarded to ``hf_hub_download``.
        cache_dir: Forwarded to ``hf_hub_download``.
        **kwargs: Forwarded to the handler constructor.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
        ValueError: if the pattern matches no file or more than one file.
    """
    import fnmatch
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, HfFileSystem  # type: ignore
        from huggingface_hub.utils import validate_repo_id  # type: ignore
    except ImportError as exc:
        # Name the class actually being used and keep the original cause.
        raise ImportError(
            f"{cls.__name__}.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        ) from exc

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    files = [
        file["name"] if isinstance(file, dict) else file
        for file in hffs.ls(repo_id)  # type: ignore
    ]

    # Express every listed entry relative to the repository root
    file_list: List[str] = [str(Path(file).relative_to(repo_id)) for file in files]

    matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

    if len(matching_files) == 0:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        # Show the same repo-relative names the pattern was matched against.
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    (matching_file,) = matching_files

    subfolder = str(Path(matching_file).parent)
    filename = Path(matching_file).name

    # download the file
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=cast(Union[str, Path, None], local_dir),
        local_dir_use_symlinks=local_dir_use_symlinks,
        cache_dir=cast(Union[str, Path, None], cache_dir),
    )

    if local_dir is None:
        # Resolve the cached path without touching the network again.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cast(Union[str, Path, None], cache_dir),
            local_files_only=True,
        )
    else:
        model_path = os.path.join(local_dir, filename)

    return cls(
        clip_model_path=model_path,
        **kwargs,
    )
3170
+
3171
+
3172
class ObsidianChatHandler(Llava15ChatHandler):
    """Handler whose template emits ChatML-style turns separated by ``###``."""

    # Prompt Format
    # The model follows the ChatML format, but with ### as the separator

    # <|im_start|>user
    # What is this sign about?\n<image>
    # ###
    # <|im_start|>assistant
    # The sign is about bullying, and it is placed on a black background with a red background.
    # ###

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}\n"
        "###\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # Structured content: image URLs are rendered first, then text parts
        "{% if message.content is iterable %}"
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "###\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "###\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
3226
+
3227
+
3228
class MoondreamChatHandler(Llava15ChatHandler):
    """Handler whose template renders turns as ``Question:`` / ``Answer:`` pairs."""

    # Chat Format:
    # f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is iterable %}"
        # <image>: each image URL is followed by a blank line
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: one line per text part of a structured message
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "Question: {{ content.text }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string message content
        "{% if message.content is string %}"
        "Question: {{ message.content }}\n\n"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "Answer:{{ message.content }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "Answer:"
        "{% endif %}"
    )
3268
+
3269
+
3270
class Llava16ChatHandler(Llava15ChatHandler):
    """Handler that overrides the default system message and chat template."""

    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "

    # Example prompt
    # "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
    # NOTE(review): the template below emits no "USER:" / "ASSISTANT:" markers
    # and ends the prompt with "Answer:", which differs from the example
    # prompt above — confirm which form the model expects.

    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is iterable %}"
        # <image>: each image URL is followed by a newline
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: text parts of a structured message
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string message content
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "Answer:"
        "{% endif %}"
    )
3316
+
3317
+
3318
class NanoLlavaChatHandler(Llava15ChatHandler):
    """Handler whose template emits ChatML turns with no newline after ``<|im_end|>``."""

    # Prompt Format
    # The model follows the ChatML standard, but without \n at the end of <|im_end|>:

    # <|im_start|>system
    # Answer the question<|im_end|><|im_start|>user
    # <image>
    # What is the picture about?<|im_end|><|im_start|>assistant
    DEFAULT_SYSTEM_MESSAGE = "Answer the question"

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # Structured content: image URLs are rendered first, then text parts
        "{% if message.content is iterable %}"
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<|im_end|>"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
3371
+
3372
+
3373
class Llama3VisionAlphaChatHandler(Llava15ChatHandler):
    """Handler whose template uses ``<|start_header_id|>`` / ``<|eot_id|>`` framing."""

    # question = "<image>" + q

    # prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # No default system prompt is configured for this handler.
    DEFAULT_SYSTEM_MESSAGE = None

    # NOTE(review): only 'user' and 'assistant' roles render a header and
    # body; any other role still emits "<|start_header_id|>" and "<|eot_id|>"
    # with nothing in between — confirm this is intended.
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "<|start_header_id|>"
        "{% if message.role == 'user' %}"
        "user<|end_header_id|>\n\n"
        "{% if message.content is iterable %}"
        # <image>
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: text parts of a structured message
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string message content
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "assistant<|end_header_id|>\n\n"
        "{{ message.content }}"
        "{% endif %}"
        "<|eot_id|>"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "{% endif %}"
    )
3420
+
3421
+
3422
# Alias: shorter name bound to the same class as Llama3VisionAlphaChatHandler.
Llama3VisionAlpha = Llama3VisionAlphaChatHandler
3424
+
3425
+
3426
class MiniCPMv26ChatHandler(Llava15ChatHandler):
    """Handler whose template emits ChatML turns (``<|im_start|>`` / ``<|im_end|>``)."""

    DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Insert a default system turn when the conversation does not start with one
        "{% if loop.first and messages[0]['role'] != 'system' %}"
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "{% endif %}"
        "<|im_start|>{{ message['role'] }}\n"
        # Structured content: image URLs are rendered first, then text parts
        "{% if message['content'] is iterable %}"
        "{% for content in message['content'] %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"

        "{% for content in message['content'] %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Plain-string message content
        "{% if message['content'] is string %}"
        "{{ message['content'] }}"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
3462
+
3463
+
3464
class Qwen25VLChatHandler(Llava15ChatHandler):
    """Handler for Qwen2.5-VL style ChatML prompts.

    ``__call__`` clears the llama state before delegating to the parent
    implementation so that repeated calls start from an empty context.
    """

    DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

    CHAT_FORMAT = (
        #"{% set image_count = namespace(value=0) %}"
        #"{% set video_count = namespace(value=0) %}"
        "{% for message in messages %}"
        "{% if loop.first and message['role'] != 'system' %}"
        # The default text is spelled out literally: inside a Jinja template
        # `self` refers to the template itself, not to this handler, so
        # `{{ self.DEFAULT_SYSTEM_MESSAGE }}` would render as an empty string.
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "{% endif %}"
        "<|im_start|>{{ message['role'] }}\n"
        "{% if message['content'] is string %}"
        "{{ message['content'] }}<|im_end|>\n"
        "{% else %}"
        # Structured content: parts are rendered in their original order
        "{% for content in message['content'] %}"
        "{% if content['type'] == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% else %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        #"{% set image_count.value = image_count.value + 1 %}"
        "{% elif content['type'] == 'text' %}"
        "{{ content['text'] }}"
        "{% endif %}"
        "{% endfor %}"
        "<|im_end|>\n"
        "{% endif %}"
        "{% endfor %}"
        # The generation prompt is always appended (add_generation_prompt is
        # not consulted by this template).
        "<|im_start|>assistant\n"
    )

    def __call__(self, **kwargs):
        """Reset llama state, then run the parent chat-completion logic.

        Expects ``kwargs['llama']``; all keyword arguments are forwarded
        unchanged to ``Llava15ChatHandler.__call__``.
        """
        llama = kwargs['llama']

        # Clear state for multiple runs
        llama.reset()
        llama._ctx.kv_cache_clear()
        llama.n_tokens = 0

        if hasattr(llama, 'input_ids'):
            llama.input_ids.fill(0)

        # Clear any handler state
        if hasattr(self, '_last_image_embed'):
            self._last_image_embed = None
            self._last_image_hash = None

        if self.verbose:
            messages = kwargs.get('messages', [])
            image_count = len(self.get_image_urls(messages))
            print(f"Minimal - Cleared state, processing {image_count} images", file=sys.stderr)

        # Use parent implementation
        return super().__call__(**kwargs)
3520
+
3521
+
3522
def _chatml_parse_tool_name(text: str) -> Optional[str]:
    """Extract ``<name>`` from ``functions.<name>`` or ``functions.<name>:``.

    Returns None when *text* is not a ``functions.`` selection.
    """
    prefix = "functions."
    candidate = text.strip()
    if not candidate.startswith(prefix):
        return None
    # The trailing ':' is present whenever ':' was not used as a stop sequence.
    return candidate[len(prefix) :].rstrip(":")


def _chatml_find_tool(tools, tool_name):
    """Return the first tool whose function name equals *tool_name*, else None."""
    return next((tool for tool in tools if tool["function"]["name"] == tool_name), None)


def _chatml_tool_arguments_grammar(tool, verbose: bool):
    """Build a grammar for a tool's arguments from its JSON-schema parameters.

    Falls back to the generic JSON grammar when the schema cannot be converted.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(
            json.dumps(tool["function"]["parameters"]), verbose=verbose
        )
    except Exception as e:
        # Deliberate best-effort: any schema conversion failure degrades to
        # "any valid JSON" rather than aborting the request.
        fallback = llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF, verbose=verbose
        )
        if verbose:
            print(
                "Failed to parse function body as JSON schema, falling back to default grammar"
            )
            print(e)
        return fallback


@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Chat handler that adds function/tool calling on top of a ChatML prompt.

    The handler renders a ChatML prompt and then follows one of three paths:

    1. No tool choice (``tool_choice`` is None/"none" or no tools): a plain
       chat completion.
    2. ``tool_choice`` is a dict naming a tool: the arguments of that tool are
       generated under a grammar derived from the tool's parameters schema.
    3. ``tool_choice == "auto"``: a first grammar-constrained completion picks
       between ``message:`` and ``functions.<name>:``; then either a message
       or one or more tool calls are generated.

    Legacy ``functions`` / ``function_call`` arguments are converted to
    ``tools`` / ``tool_choice``. The remaining parameters are sampling options
    forwarded to ``llama.create_completion``.

    Raises:
        ValueError: if the chosen tool name is not in ``tools``, or if a tool
            call is selected automatically while ``stream`` is true.
    """
    # NOTE(review): the template only has branches for the 'system', 'user'
    # and 'assistant' roles; other roles get just the "<|im_start|>role" header.
    function_calling_template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n"
        # System message
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% if tool_calls %}"
        "\n\nYou have access to the following functions:\n"
        "{% for tool in tools %}"
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\n\nYou can respond to users messages with either a single message or one or more function calls."
        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
        "\n\nmessage:"
        "\n<message>"
        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
        "\n\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "{% endif %}"
        "<|im_end|>\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        ## Regular message
        "{% if message.content and message.content | length > 0 %}"
        "{% if tool_calls %}"
        "message:\n"
        "{% endif %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        ## Function calls
        "{% if 'tool_calls' in message %}"
        "{% for tool_call in message.tool_calls %}"
        "functions.{{ tool_call.function.name }}:\n"
        "{{ tool_call.function.arguments }}"
        "{% endfor %}"
        "<|im_end|>\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    template_renderer = ImmutableSandboxedEnvironment(
        autoescape=jinja2.select_autoescape(["html", "xml"]),
        undefined=jinja2.StrictUndefined,
    ).from_string(function_calling_template)

    # Convert legacy functions to tools
    if functions is not None:
        tools = [{"type": "function", "function": function} for function in functions]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and function_call in ("none", "auto"):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {
                "type": "function",
                "function": {"name": function_call["name"]},
            }

    # Always stop at the ChatML end-of-turn marker. A new list is built, so
    # the caller's (or the default) `stop` object is never mutated.
    if isinstance(stop, str):
        stop = [stop, "<|im_end|>"]
    elif stop:
        stop = [*stop, "<|im_end|>"]
    else:
        stop = ["<|im_end|>"]

    # Sampling options that are forwarded unchanged to every completion call.
    sampling_kwargs = dict(
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
    )

    # Case 1: No tool choice by user
    if (
        tool_choice is None
        or (isinstance(tool_choice, str) and tool_choice == "none")
        or not tools
    ):
        prompt = template_renderer.render(
            messages=messages,
            tools=[],
            tool_calls=None,
            add_generation_prompt=True,
        )

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format)

        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=stream,
                stop=stop,
                max_tokens=max_tokens,
                grammar=grammar,
                logprobs=top_logprobs if logprobs else None,
                **sampling_kwargs,
            ),
            stream=stream,
        )

    # Case 2: Tool choice by user
    if isinstance(tool_choice, dict):
        tool_name = tool_choice["function"]["name"]
        tool = _chatml_find_tool(tools, tool_name)
        if tool is None:
            raise ValueError(f"Tool with name '{tool_name}' not found in tools")
        prompt = template_renderer.render(
            messages=messages,
            tools=tools,
            tool_calls=True,
            add_generation_prompt=True,
        )
        # Pre-fill the function header so the model only has to emit arguments.
        prompt += f"functions.{tool_name}:\n"
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            stream=stream,
            stop=stop,
            max_tokens=max_tokens,
            grammar=_chatml_tool_arguments_grammar(tool, llama.verbose),
            **sampling_kwargs,
        )
        return _convert_completion_to_chat_function(
            tool_name, completion_or_chunks, stream
        )

    # Case 3: Automatic tool choice
    assert isinstance(tool_choice, str) and tool_choice == "auto"
    function_names = " | ".join(
        [f'''"functions.{tool['function']['name']}:"''' for tool in tools]
    )
    initial_gbnf_tool_grammar = (
        """root ::= functions | "message:"\n"""
        f"""functions ::= {function_names}\n"""
    )
    follow_up_gbnf_tool_grammar = (
        """root ::= functions | "<|im_end|>"\n"""
        f"""functions ::= {function_names}\n"""
    )
    prompt = template_renderer.render(
        messages=messages,
        tools=tools,
        tool_calls=True,
        add_generation_prompt=True,
    )
    # Greedy, grammar-constrained selection step; stopping at ':' leaves
    # either "message" or "functions.<name>" in the completion text.
    completion_or_chunks = llama.create_completion(
        prompt=prompt,
        temperature=0,
        stream=False,
        stop=[":"],
        max_tokens=None,
        grammar=llama_grammar.LlamaGrammar.from_string(
            initial_gbnf_tool_grammar, verbose=llama.verbose
        ),
        **sampling_kwargs,
    )
    completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
    text = completion["choices"][0]["text"]
    tool_name = _chatml_parse_tool_name(text)

    # Only take the message path when the model did NOT select a function;
    # a plain substring test would misroute tools such as "send_message".
    if tool_name is None and "message" in text:
        # NOTE(review): this call uses its own stop/max_tokens rather than the
        # caller's `stop` / `max_tokens` (unchanged from the previous behavior).
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt + "message:\n",
                temperature=temperature,
                stream=stream,
                stop=["<|im_end|>"],
                logprobs=top_logprobs if logprobs else None,
                max_tokens=None,
                # Free-form text: use the caller's grammar (if any). The
                # follow-up tool grammar only admits function headers or
                # "<|im_end|>" and therefore cannot express a message.
                grammar=grammar,
                **sampling_kwargs,
            ),
            stream=stream,
        )

    # One or more function calls
    if stream:
        raise ValueError("Automatic streaming tool choice is not supported")

    tool = _chatml_find_tool(tools, tool_name) if tool_name is not None else None
    completions: List[llama_types.CreateCompletionResponse] = []
    completions_tool_name: List[str] = []
    while tool is not None:
        # Generate the arguments for the currently selected tool.
        prompt += f"functions.{tool_name}:\n"
        arguments_completion = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=False,
                stop=stop,
                max_tokens=None,
                grammar=_chatml_tool_arguments_grammar(tool, llama.verbose),
                **sampling_kwargs,
            ),
        )
        completions.append(arguments_completion)
        completions_tool_name.append(tool_name)
        prompt += arguments_completion["choices"][0]["text"]
        prompt += "\n"

        # Ask whether another function call follows or the turn ends.
        response = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=False,
                stop=stop,
                max_tokens=None,
                grammar=llama_grammar.LlamaGrammar.from_string(
                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
                ),
                **sampling_kwargs,
            ),
        )

        # The parser strips the trailing ':' that remains when ':' is not a
        # stop sequence; otherwise the name would never match a tool.
        tool_name = _chatml_parse_tool_name(response["choices"][0]["text"])
        tool = _chatml_find_tool(tools, tool_name) if tool_name is not None else None

    # Merge completions
    # The legacy `function_call` field is only emitted for a single call and
    # must name the tool that was actually called, not the last parse result.
    function_call_dict: Union[
        Dict[str, str],
        Dict[
            Literal["function_call"],
            llama_types.ChatCompletionRequestAssistantMessageFunctionCall,
        ],
    ] = (
        {
            "function_call": {
                "name": completions_tool_name[0],
                "arguments": completions[0]["choices"][0]["text"],
            }
        }
        if len(completions) == 1
        else {}
    )
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "finish_reason": "tool_calls",
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_"
                            + f"_{i}_"
                            + called_name
                            + "_"
                            + call_completion["id"],
                            "type": "function",
                            "function": {
                                "name": called_name,
                                "arguments": call_completion["choices"][0]["text"],
                            },
                        }
                        for i, (called_name, call_completion) in enumerate(
                            zip(completions_tool_name, completions)
                        )
                    ],
                    **function_call_dict,
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
                c["usage"]["completion_tokens"] if "usage" in c else 0
                for c in completions
            ),
            "prompt_tokens": sum(
                c["usage"]["prompt_tokens"] if "usage" in c else 0
                for c in completions
            ),
            "total_tokens": sum(
                c["usage"]["total_tokens"] if "usage" in c else 0
                for c in completions
            ),
        },
    }