speedy-utils 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to its public registry, and is provided for informational purposes only.
- llm_utils/chat_format/display.py +17 -4
- llm_utils/lm/async_lm/__init__.py +2 -0
- llm_utils/lm/async_lm/_utils.py +198 -0
- llm_utils/lm/async_lm/async_llm_task.py +154 -0
- llm_utils/lm/{async_lm.py → async_lm/async_lm.py} +191 -354
- llm_utils/scripts/vllm_load_balancer.py +220 -135
- {speedy_utils-1.1.4.dist-info → speedy_utils-1.1.6.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.4.dist-info → speedy_utils-1.1.6.dist-info}/RECORD +10 -7
- {speedy_utils-1.1.4.dist-info → speedy_utils-1.1.6.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.4.dist-info → speedy_utils-1.1.6.dist-info}/entry_points.txt +0 -0
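Most of the changes below rework `AsyncLM.parse` in the renamed `async_lm.py`: structured parsing gains a `use_beta` switch, schema injection and `/think` handling move into a new `_build_system_prompt` helper, model selection moves into `_set_model`, and the request/parse round trip is factored into `_call_and_parse_completion`. The following is a rough usage sketch of the 1.1.6 `parse` signature as it appears in the diff; the constructor arguments mirror the example in the removed `AsyncLLMTask` docstring, and the import path, port, model name, and `Answer` model are placeholders rather than anything the package documents.

```python
# Hypothetical sketch of the reworked AsyncLM.parse call (signature taken from
# the diff below); port, model name, and the Answer model are illustrative only.
import asyncio

from pydantic import BaseModel

from llm_utils.lm.async_lm import AsyncLM  # assumed export path for 1.1.6


class Answer(BaseModel):
    reasoning: str
    final_answer: str


async def main() -> None:
    lm = AsyncLM(port=8130, cache=False, model="gpt-3.5-turbo")
    result = await lm.parse(
        response_model=Answer,
        instruction="Answer the question and explain your reasoning.",
        prompt="What is 17 * 24?",
        add_json_schema_to_instruction=True,  # required by parse() when use_beta=False
        use_beta=False,
    )
    print(result.parsed)        # Answer instance
    print(result.messages[-1])  # assistant turn with the <think> block prepended


asyncio.run(main())
```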
llm_utils/lm/{async_lm.py → async_lm/async_lm.py}

@@ -1,28 +1,23 @@
-
+# from ._utils import *
 import base64
 import hashlib
 import json
 import os
-from abc import ABC
-from functools import cache, lru_cache
 from typing import (
     Any,
     Dict,
-    Generic,
     List,
     Literal,
     Optional,
     Sequence,
     Type,
-    TypeVar,
     Union,
     cast,
     overload,
 )
-
+
 from httpx import URL
 from loguru import logger
-from numpy import isin
 from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.pagination import AsyncPage as AsyncSyncPage

@@ -36,49 +31,23 @@ from openai.types.chat import (
 )
 from openai.types.model import Model
 from pydantic import BaseModel
-from pydantic import ValidationError
-from llm_utils.chat_format.display import get_conversation_one_turn
-
-# --------------------------------------------------------------------------- #
-# type helpers
-# --------------------------------------------------------------------------- #
-TModel = TypeVar("TModel", bound=BaseModel)
-Messages = List[ChatCompletionMessageParam]
-LegacyMsgs = List[Dict[str, str]]
-RawMsgs = Union[Messages, LegacyMsgs]
-
-# --------------------------------------------------------------------------- #
-# color helpers (unchanged)
-# --------------------------------------------------------------------------- #
-
-
-def _color(code: int, text: str) -> str:
-    return f"\x1b[{code}m{text}\x1b[0m"
-
-
-def _red(t):
-    return _color(31, t)
-
-
-def _green(t):
-    return _color(32, t)
-
-
-def _blue(t):
-    return _color(34, t)
-
-
-def _yellow(t):
-    return _color(33, t)
-
-
-TParsed = TypeVar("TParsed", bound=BaseModel)

-
-
-
-
-
+from speedy_utils import jloads
+
+from ._utils import (
+    LegacyMsgs,
+    Messages,
+    ParsedOutput,
+    RawMsgs,
+    TModel,
+    TParsed,
+    _blue,
+    _green,
+    _red,
+    _yellow,
+    get_tokenizer,
+    inspect_word_probs_async,
+)


 class AsyncLM:
@@ -153,6 +122,14 @@ class AsyncLM:
         **kwargs: Any,
     ) -> TModel: ...

+    async def _set_model(self) -> None:
+        if not self.model:
+            models = await self.list_models(port=self.port, host=self.host)
+            self.model = models[0] if models else None
+            logger.info(
+                f"No model specified. Using the first available model. {self.model}"
+            )
+
     async def __call__(
         self,
         prompt: Optional[str] = None,
@@ -171,12 +148,8 @@ class AsyncLM:

         assert messages is not None
         # assert self.model is not None, "Model must be set before calling."
-
-
-            self.model = models[0] if models else None
-            logger.info(
-                f"No model specified. Using the first available model. {self.model}"
-            )
+        await self._set_model()
+
         openai_msgs: Messages = (
             self._convert_messages(cast(LegacyMsgs, messages))
             if isinstance(messages[0], dict)
@@ -203,7 +176,7 @@ class AsyncLM:
         else:
             response = self._parse_output(raw_response, response_format)

-        self.
+        self._last_log = [prompt, messages, raw_response]
         return response

     # ------------------------------------------------------------------ #
@@ -390,48 +363,44 @@ class AsyncLM:
     async def parse(
         self,
         response_model: Type[TParsed],
-        instruction
-        prompt
-        messages: Optional[RawMsgs] = None,
+        instruction,
+        prompt,
         think: Literal[True, False, None] = None,
         add_json_schema_to_instruction: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-        cache: Optional[bool] =
+        cache: Optional[bool] = None,
+        use_beta: bool = False,
         **kwargs,
     ) -> ParsedOutput[TParsed]:
         """Parse response using guided JSON generation."""
-
-
-        assert
-
-
-
-                "content": instruction,
-            },
-            {
-                "role": "user",
-                "content": prompt,
-            },
-        ] # type: ignore
-
-        post_fix = ""
+
+        if not use_beta:
+            assert add_json_schema_to_instruction, (
+                "add_json_schema_to_instruction must be True when use_beta is False. otherwise model will not be able to parse the response."
+            )
+
         json_schema = response_model.model_json_schema()
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Build system message content in a single, clear block
+        assert instruction is not None, "Instruction must be provided."
+        assert prompt is not None, "Prompt must be provided."
+        system_content = instruction
+
+        # Add schema if needed
+        system_content = self._build_system_prompt(
+            response_model,
+            add_json_schema_to_instruction,
+            json_schema,
+            system_content,
+            think=think,
         )
-
+
+        # Rebuild messages with updated system message if needed
+        messages = [
+            {"role": "system", "content": system_content},
+            {"role": "user", "content": prompt},
+        ] # type: ignore

         model_kwargs = {}
         if temperature is not None:
@@ -443,38 +412,98 @@ class AsyncLM:
         use_cache = self.do_cache if cache is None else cache
         cache_key = None
         completion = None
+        choice = None
+        parsed = None
+
         if use_cache:
             cache_data = {
                 "messages": messages,
                 "model_kwargs": model_kwargs,
                 "guided_json": json_schema,
                 "response_format": response_model.__name__,
+                "use_beta": use_beta,
             }
             cache_key = self._cache_key(cache_data, {}, response_model)
             completion = self._load_cache(cache_key) # dict
+
         if not completion:
-            completion = await self.
-
-
-
-
+            completion, choice, parsed = await self._call_and_parse_completion(
+                messages,
+                response_model,
+                json_schema,
+                use_beta=use_beta,
+                model_kwargs=model_kwargs,
             )
-
+
             if cache_key:
                 self._dump_cache(cache_key, completion)
+        else:
+            # Extract choice and parsed from cached completion
+            choice = completion["choices"][0]["message"]
+            try:
+                parsed = self._parse_complete_output(completion, response_model)
+            except Exception as e:
+                raise ValueError(
+                    f"Failed to parse cached completion: {e}\nRaw: {choice.get('content')}"
+                ) from e
+
         assert isinstance(completion, dict), (
             "Completion must be a dictionary with OpenAI response format."
         )
-        self.
+        self._last_log = [prompt, messages, completion]
+
+        reasoning_content = choice.get("reasoning_content", "").strip()
+        _content = choice.get("content", "").lstrip("\n")
+        content = f"<think>\n{reasoning_content}\n</think>\n\n{_content}"
+
+        full_messages = messages + [{"role": "assistant", "content": content}]

-        output = cast(TParsed, self._parse_complete_output(completion, response_model))
-        full_messages = messages + [completion]
         return ParsedOutput(
             messages=full_messages,
             completion=completion,
-            parsed=
+            parsed=parsed, # type: ignore
         )

+    def _build_system_prompt(
+        self,
+        response_model,
+        add_json_schema_to_instruction,
+        json_schema,
+        system_content,
+        think,
+    ):
+        if add_json_schema_to_instruction and response_model:
+            schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+            # if schema_block not in system_content:
+            if "<output_json_schema>" in system_content:
+                # remove exsting schema block
+                import re # replace
+
+                system_content = re.sub(
+                    r"<output_json_schema>.*?</output_json_schema>",
+                    "",
+                    system_content,
+                    flags=re.DOTALL,
+                )
+                system_content = system_content.strip()
+            system_content += schema_block
+
+        if think is True:
+            if "/think" in system_content:
+                pass
+            elif "/no_think" in system_content:
+                system_content = system_content.replace("/no_think", "/think")
+            else:
+                system_content += "\n\n/think"
+        elif think is False:
+            if "/no_think" in system_content:
+                pass
+            elif "/think" in system_content:
+                system_content = system_content.replace("/think", "/no_think")
+            else:
+                system_content += "\n\n/no_think"
+        return system_content
+
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
     ) -> BaseModel:
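The `_build_system_prompt` helper added in the hunk above appends the response schema inside an `<output_json_schema>` block and toggles Qwen-style `/think` and `/no_think` markers on the system prompt. Below is a standalone restatement of that toggle for illustration only; `toggle_think` is not part of speedy-utils, and the real method also handles `think=None` by leaving the prompt untouched.

```python
# Standalone restatement of the /think vs /no_think handling in
# _build_system_prompt above; toggle_think is illustrative, not package API.
def toggle_think(system_content: str, think: bool) -> str:
    on, off = ("/think", "/no_think") if think else ("/no_think", "/think")
    if on in system_content:
        return system_content                   # already in the requested mode
    if off in system_content:
        return system_content.replace(off, on)  # flip the existing marker
    return system_content + f"\n\n{on}"         # no marker yet: append one


assert toggle_think("You are helpful.", think=True).endswith("/think")
assert toggle_think("Reply briefly. /think", think=False).endswith("/no_think")
```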
@@ -492,24 +521,24 @@ class AsyncLM:
         # Try to extract tokens from the completion for debugging
         input_tokens = None
         try:
-            input_tokens = completion.get(
+            input_tokens = completion.get("usage", {}).get("prompt_tokens")
         except Exception:
             input_tokens = None

         # Try to get the prompt/messages for tokenization
         prompt = None
         try:
-            prompt = completion.get(
+            prompt = completion.get("messages") or completion.get("prompt")
         except Exception:
             prompt = None

-        tokens_preview =
+        tokens_preview = ""
         if prompt is not None:
             try:
                 tokenizer = get_tokenizer(self.model)
                 if isinstance(prompt, list):
-                    prompt_text =
-                        m.get(
+                    prompt_text = "\n".join(
+                        m.get("content", "") for m in prompt if isinstance(m, dict)
                     )
                 else:
                     prompt_text = str(prompt)
@@ -518,17 +547,17 @@ class AsyncLM:
                 first_100 = tokens[:100]
                 last_100 = tokens[-100:] if n_tokens > 100 else []
                 tokens_preview = (
-                    f
-                    f
-                    f
+                    f"\nInput tokens: {n_tokens}"
+                    f"\nFirst 100 tokens: {first_100}"
+                    f"\nLast 100 tokens: {last_100}"
                 )
             except Exception as exc:
-                tokens_preview = f
+                tokens_preview = f"\n[Tokenization failed: {exc}]"

         raise ValueError(
-            f
-            f
-            f
+            f"Empty content in response."
+            f"\nInput tokens (if available): {input_tokens}"
+            f"{tokens_preview}"
         )

         try:
@@ -579,7 +608,7 @@ class AsyncLM:
         if not hasattr(self, "last_log"):
             return None

-        last_conv = self.
+        last_conv = self._last_log
         messages = last_conv[1] if len(last_conv) > 1 else None
         last_msg = last_conv[2]
         if not isinstance(last_msg, dict):
@@ -607,7 +636,7 @@ class AsyncLM:
         if not hasattr(self, "last_log"):
             raise ValueError("No history available. Please call the model first.")

-        prompt, messages, response = self.
+        prompt, messages, response = self._last_log
         if hasattr(response, "model_dump"):
             response = response.model_dump()
         if not messages:
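The final hunk below removes the old module-level helpers and the `AsyncLLMTask` class (which this release appears to move into the new `_utils.py` and `async_llm_task.py` modules) and adds `_call_and_parse_completion`, which first tries vLLM-style guided decoding via `extra_body={"guided_json": ...}` and falls back to a plain `response_format={"type": "json_object"}` request. The following is a minimal sketch of that request pattern against an OpenAI-compatible server, independent of `AsyncLM`; the base URL, API key, and model name are placeholders.

```python
# Minimal sketch of the guided-JSON request pattern used by
# _call_and_parse_completion below; base_url and model name are placeholders.
import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    reasoning: str
    final_answer: str


async def request_json(client: AsyncOpenAI, model: str, messages: list[dict]) -> dict:
    try:
        # vLLM accepts guided-decoding parameters through extra_body.
        completion = await client.chat.completions.create(
            model=model,
            messages=messages,
            extra_body={"guided_json": Answer.model_json_schema()},
        )
    except Exception:
        # Fallback for servers that reject the extra_body parameters.
        completion = await client.chat.completions.create(
            model=model,
            messages=messages,
            response_format={"type": "json_object"},
        )
    return completion.model_dump()


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    messages = [
        {"role": "system", "content": "Reply with a JSON object matching the schema."},
        {"role": "user", "content": "What is 17 * 24?"},
    ]
    print(await request_json(client, "Qwen/Qwen2.5-7B-Instruct", messages))


asyncio.run(main())
```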
@@ -692,251 +721,59 @@ class AsyncLM:
             logger.error(f"Failed to list models: {exc}")
             return []

-
-# --------------------------------------------------------------------------- #
-# Module-level utility functions (async versions)
-# --------------------------------------------------------------------------- #
-
-
-@lru_cache(maxsize=10)
-def get_tokenizer(model_name: str) -> Any:
-    """Get tokenizer for the given model."""
-    from transformers import AutoTokenizer # type: ignore
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    return tokenizer
-
-
-async def inspect_word_probs_async(lm, tokenizer, messages):
-    """Async version of inspect_word_probs."""
-
-    import numpy as np
-
-
-    async def compute_word_log_probs(
-        tokenizer: Any,
-        lm_client: Any,
-    ) -> tuple[List[Dict[str, Any]], Any]:
-        # Build a prompt that preserves literal newlines
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False, # Don't tokenize yet, we need raw text
-            add_generation_prompt=False, # No generation prompt needed
-        )
-
-        # Request token logprobs
-        response = await lm_client.client.completions.create(
-            model=lm_client.model, # type: ignore
-            prompt=prompt,
-            max_tokens=1,
-            logprobs=1,
-            extra_body={"prompt_logprobs": 0},
-        )
-        token_logprob_dicts = response.choices[0].prompt_logprobs # type: ignore
-
-        # Override first token to known start marker
-        start_id = tokenizer.encode("<|im_start|>")[0]
-        token_logprob_dicts[0] = {
-            str(start_id): {
-                "logprob": -1,
-                "rank": 1,
-                "decoded_token": "<|im_start|>",
-            }
-        }
-
-        # Flatten tokens
-        tokens: List[Dict[str, Any]] = [
-            {"id": int(tid), **tdata}
-            for td in token_logprob_dicts
-            for tid, tdata in td.items()
-        ]
-
-        # Validate tokenization
-        tokenized = tokenizer.tokenize(prompt)
-        if len(tokenized) != len(tokens):
-            raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
-        for idx, tok in enumerate(tokens):
-            if tokenized[idx] != tok["decoded_token"]:
-                raise AssertionError(
-                    f"Token mismatch at {idx}: "
-                    f"{tokenized[idx]} != {tok['decoded_token']}"
-                )
-
-        # Split on newline sentinel
-        split_prompt = prompt.replace("\n", " <NL> ")
-        words = split_prompt.split()
-
-        word_log_probs: List[Dict[str, Any]] = []
-        token_idx = 0
-
-        for word in words:
-            # Map sentinel back to actual newline for encoding
-            target = "\n" if word == "<NL>" else word
-            sub_ids = tokenizer.encode(target, add_special_tokens=False)
-            count = len(sub_ids)
-            if count == 0:
-                continue
-
-            subs = tokens[token_idx : token_idx + count]
-            avg_logprob = sum(s["logprob"] for s in subs) / count
-            prob = float(np.exp(avg_logprob))
-            word_log_probs.append({"word": target, "probability": prob})
-            token_idx += count
-
-        return word_log_probs, token_logprob_dicts # type: ignore
-
-    def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
-        """
-        Return an ANSI-colored string for word probabilities (red → green).
-        """
-        if not word_log_probs:
-            return ""
-
-        probs = [entry["probability"] for entry in word_log_probs]
-        min_p, max_p = min(probs), max(probs)
-        parts: List[str] = []
-
-        for entry in word_log_probs:
-            word = entry["word"]
-            # Preserve actual line breaks
-            if word == "\n":
-                parts.append("\n")
-                continue
-
-            p = entry["probability"]
-            norm = (p - min_p) / (max_p - min_p or 1.0)
-            r = int(255 * (1 - norm)) # red component (high when prob is low)
-            g = int(255 * norm) # green component (high when prob is high)
-            b = 0 # no blue for red-green gradient
-            colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
-            parts.append(colored + " ")
-
-        return "".join(parts).rstrip()
-
-    word_probs, token_logprob_dicts = await compute_word_log_probs(tokenizer, lm)
-    return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
-
-
-# --------------------------------------------------------------------------- #
-# Async LLMTask class
-# --------------------------------------------------------------------------- #
-
-InputModelType = TypeVar("InputModelType", bound=BaseModel)
-OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
-
-
-class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
-    """
-    Async callable wrapper around an AsyncLM endpoint.
-
-    Sub-classes must set:
-        • lm – the async language-model instance
-        • InputModel – a Pydantic input class
-        • OutputModel – a Pydantic output class
-
-    Optional flags:
-        • temperature – float (default 0.6)
-        • think – bool (if the backend supports "chain-of-thought")
-        • add_json_schema – bool (include schema in the instruction)
-
-    The **docstring** of each sub-class is sent as the LM instruction.
-    Example
-    ```python
-    class DemoTask(AsyncLLMTask):
-        "TODO: SYSTEM_PROMPT_INSTURCTION HERE"
-
-        lm = AsyncLM(port=8130, cache=False, model="gpt-3.5-turbo")
-
-        class InputModel(BaseModel):
-            text_to_translate:str
-
-        class OutputModel(BaseModel):
-            translation:str
-            glossary_use:str
-
-        temperature = 0.6
-        think=False
-
-    demo_task = DemoTask()
-    result = await demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
-    ```
-    """
-
-    lm: "AsyncLM"
-    InputModel: InputModelType
-    OutputModel: OutputModelType
-
-    temperature: float = 0.6
-    think: bool = False
-    add_json_schema: bool = False
-    cache: bool = False
-
-    async def __call__(
+    async def _call_and_parse_completion(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        messages: list[dict],
+        response_model: Type[TParsed],
+        json_schema: dict,
+        use_beta: bool,
+        model_kwargs: dict,
+    ) -> tuple[dict, dict, TParsed]:
+        """Call vLLM or OpenAI-compatible endpoint and parse JSON response consistently."""
+        await self._set_model() # Ensure model is set before making the call
+        # Convert messages to proper type
+        converted_messages = self._convert_messages(messages) # type: ignore
+
+        if use_beta:
+            # Use guided JSON for structure enforcement
+            try:
+                completion = await self.client.chat.completions.create(
+                    model=str(self.model), # type: ignore
+                    messages=converted_messages,
+                    extra_body={"guided_json": json_schema}, # type: ignore
+                    **model_kwargs,
+                ) # type: ignore
+            except Exception:
+                # Fallback if extra_body is not supported
+                completion = await self.client.chat.completions.create(
+                    model=str(self.model), # type: ignore
+                    messages=converted_messages,
+                    response_format={"type": "json_object"},
+                    **model_kwargs,
                 )
-        input_model = self.InputModel
-        output_model = self.OutputModel
-
-        # Ensure input_model is a class before calling
-        if isinstance(data, BaseModel):
-            item = data
-        elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
-            item = input_model(**data)
         else:
-
-
-
-
-
+            # Use OpenAI-style structured output
+            completion = await self.client.chat.completions.create(
+                model=str(self.model), # type: ignore
+                messages=converted_messages,
+                response_format={"type": "json_object"},
+                **model_kwargs,
+            )

-
-
-            instruction=self.__doc__ or "",
-            response_model=output_model,
-            temperature=temperature or self.temperature,
-            think=think if think is not None else self.think,
-            add_json_schema_to_instruction=self.add_json_schema,
-            cache=self.cache or cache,
-        )
+        if hasattr(completion, "model_dump"):
+            completion = completion.model_dump()

-
-            cast(OutputModelType, result["parsed"]),  # type: ignore
-            cast(List[dict], result["messages"]),  # type: ignore
-        )
+        choice = completion["choices"][0]["message"]

-
-
-
-
-
-
-
-
-
-
-            return {"messages": messages}
+        try:
+            parsed = (
+                self._parse_complete_output(completion, response_model)
+                if use_beta
+                else response_model.model_validate(jloads(choice.get("content")))
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Failed to parse model response: {e}\nRaw: {choice.get('content')}"
+            ) from e

-
+        return completion, choice, parsed # type: ignore