speedy-utils 1.0.14__py3-none-any.whl → 1.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +10 -15
- llm_utils/chat_format/display.py +1 -1
- llm_utils/chat_format/transform.py +1 -2
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/alm.py +426 -14
- llm_utils/lm/chat_html.py +246 -0
- llm_utils/lm/lm.py +386 -78
- llm_utils/lm/lm_json.py +68 -0
- llm_utils/lm/utils.py +1 -1
- llm_utils/scripts/README.md +48 -0
- llm_utils/scripts/vllm_load_balancer.py +0 -1
- speedy_utils/__init__.py +96 -5
- speedy_utils/common/function_decorator.py +1 -4
- speedy_utils/common/logger.py +1 -1
- speedy_utils/common/notebook_utils.py +63 -0
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +7 -7
- speedy_utils/common/utils_misc.py +1 -2
- speedy_utils/common/utils_print.py +2 -65
- speedy_utils/multi_worker/process.py +9 -4
- speedy_utils/scripts/mpython.py +4 -4
- speedy_utils/scripts/openapi_client_codegen.py +1 -5
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/METADATA +1 -1
- speedy_utils-1.0.16.dist-info/RECORD +37 -0
- speedy_utils-1.0.14.dist-info/RECORD +0 -33
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED
|
@@ -1,20 +1,16 @@
|
|
|
1
1
|
from .chat_format import (
|
|
2
|
-
transform_messages,
|
|
3
|
-
transform_messages_to_chatml,
|
|
4
|
-
show_chat,
|
|
5
|
-
get_conversation_one_turn,
|
|
6
|
-
show_string_diff,
|
|
7
|
-
display_conversations,
|
|
8
2
|
build_chatml_input,
|
|
9
|
-
format_msgs,
|
|
10
3
|
display_chat_messages_as_html,
|
|
4
|
+
display_conversations,
|
|
5
|
+
format_msgs,
|
|
6
|
+
get_conversation_one_turn,
|
|
7
|
+
show_chat,
|
|
8
|
+
show_string_diff,
|
|
9
|
+
transform_messages,
|
|
10
|
+
transform_messages_to_chatml,
|
|
11
11
|
)
|
|
12
|
-
from .lm.lm import LM,
|
|
12
|
+
from .lm.lm import LM, LLMTask
|
|
13
13
|
from .lm.alm import AsyncLM
|
|
14
|
-
from .group_messages import (
|
|
15
|
-
split_indices_by_length,
|
|
16
|
-
group_messages_by_len,
|
|
17
|
-
)
|
|
18
14
|
|
|
19
15
|
__all__ = [
|
|
20
16
|
"transform_messages",
|
|
@@ -25,10 +21,9 @@ __all__ = [
|
|
|
25
21
|
"display_conversations",
|
|
26
22
|
"build_chatml_input",
|
|
27
23
|
"format_msgs",
|
|
28
|
-
"
|
|
29
|
-
"group_messages_by_len",
|
|
24
|
+
# "group_messages_by_len",
|
|
30
25
|
"LM",
|
|
31
|
-
"LMReasoner",
|
|
32
26
|
"AsyncLM",
|
|
33
27
|
"display_chat_messages_as_html",
|
|
28
|
+
"LLMTask",
|
|
34
29
|
]
|
llm_utils/chat_format/display.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Callable, Dict, List, Sequence
|
|
4
3
|
|
|
5
4
|
|
|
6
5
|
def identify_format(item):
|
|
@@ -114,7 +113,7 @@ def transform_messages(
|
|
|
114
113
|
|
|
115
114
|
def transform_messages_to_chatml(input_data, input_format="auto"):
|
|
116
115
|
if input_format == "auto":
|
|
117
|
-
input_data =
|
|
116
|
+
input_data = deepcopy(input_data)
|
|
118
117
|
if isinstance(input_data, list):
|
|
119
118
|
input_format = "chatlm"
|
|
120
119
|
assert (
|
llm_utils/group_messages.py
CHANGED
|
@@ -76,7 +76,7 @@ def group_messages_by_len(
|
|
|
76
76
|
"""
|
|
77
77
|
if messages is None:
|
|
78
78
|
raise ValueError("messages parameter cannot be None")
|
|
79
|
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
79
|
+
from transformers.models.auto.tokenization_auto import AutoTokenizer # type: ignore
|
|
80
80
|
|
|
81
81
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
82
82
|
|
llm_utils/lm/alm.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
"""An **asynchronous** drop‑in replacement for the original `LM` class.
|
|
4
2
|
|
|
5
3
|
Usage example (Python ≥3.8):
|
|
@@ -15,26 +13,30 @@ Usage example (Python ≥3.8):
|
|
|
15
13
|
asyncio.run(main())
|
|
16
14
|
"""
|
|
17
15
|
|
|
18
|
-
import asyncio
|
|
19
16
|
import base64
|
|
20
17
|
import hashlib
|
|
21
18
|
import json
|
|
22
19
|
import os
|
|
20
|
+
from abc import ABC
|
|
21
|
+
from functools import lru_cache
|
|
23
22
|
from typing import (
|
|
24
23
|
Any,
|
|
25
24
|
Dict,
|
|
26
25
|
List,
|
|
26
|
+
Literal,
|
|
27
27
|
Optional,
|
|
28
28
|
Sequence,
|
|
29
29
|
Type,
|
|
30
30
|
TypeVar,
|
|
31
31
|
Union,
|
|
32
|
-
overload,
|
|
33
32
|
cast,
|
|
33
|
+
overload,
|
|
34
34
|
)
|
|
35
35
|
|
|
36
36
|
from httpx import URL
|
|
37
|
-
from
|
|
37
|
+
from loguru import logger
|
|
38
|
+
from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
|
|
39
|
+
from openai.pagination import AsyncPage as AsyncSyncPage
|
|
38
40
|
|
|
39
41
|
# from openai.pagination import AsyncSyncPage
|
|
40
42
|
from openai.types.chat import (
|
|
@@ -44,11 +46,10 @@ from openai.types.chat import (
|
|
|
44
46
|
ChatCompletionToolMessageParam,
|
|
45
47
|
ChatCompletionUserMessageParam,
|
|
46
48
|
)
|
|
47
|
-
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
|
|
48
49
|
from openai.types.model import Model
|
|
49
50
|
from pydantic import BaseModel
|
|
50
|
-
|
|
51
|
-
from
|
|
51
|
+
|
|
52
|
+
from llm_utils.chat_format.display import get_conversation_one_turn
|
|
52
53
|
|
|
53
54
|
# --------------------------------------------------------------------------- #
|
|
54
55
|
# type helpers
|
|
@@ -67,10 +68,20 @@ def _color(code: int, text: str) -> str:
|
|
|
67
68
|
return f"\x1b[{code}m{text}\x1b[0m"
|
|
68
69
|
|
|
69
70
|
|
|
70
|
-
_red
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
71
|
+
def _red(t):
|
|
72
|
+
return _color(31, t)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _green(t):
|
|
76
|
+
return _color(32, t)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _blue(t):
|
|
80
|
+
return _color(34, t)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _yellow(t):
|
|
84
|
+
return _color(33, t)
|
|
74
85
|
|
|
75
86
|
|
|
76
87
|
class AsyncLM:
|
|
@@ -100,6 +111,7 @@ class AsyncLM:
|
|
|
100
111
|
self.openai_kwargs = openai_kwargs
|
|
101
112
|
self.do_cache = cache
|
|
102
113
|
self.ports = ports
|
|
114
|
+
self._init_port = port # <-- store the port provided at init
|
|
103
115
|
|
|
104
116
|
# Async client
|
|
105
117
|
|
|
@@ -108,6 +120,7 @@ class AsyncLM:
|
|
|
108
120
|
# if have multiple ports
|
|
109
121
|
if self.ports:
|
|
110
122
|
import random
|
|
123
|
+
|
|
111
124
|
port = random.choice(self.ports)
|
|
112
125
|
api_base = f"http://{self.host}:{port}/v1"
|
|
113
126
|
logger.debug(f"Using port: {port}")
|
|
@@ -213,6 +226,13 @@ class AsyncLM:
|
|
|
213
226
|
self._cache_key(messages, kw, response_format) if use_cache else None
|
|
214
227
|
)
|
|
215
228
|
if cache_key and (hit := self._load_cache(cache_key)) is not None:
|
|
229
|
+
# Check if cached value is an error
|
|
230
|
+
if isinstance(hit, dict) and hit.get("error"):
|
|
231
|
+
error_type = hit.get("error_type", "Unknown")
|
|
232
|
+
error_msg = hit.get("error_message", "Cached error")
|
|
233
|
+
logger.warning(f"Found cached error ({error_type}): {error_msg}")
|
|
234
|
+
# Re-raise as a ValueError with meaningful message
|
|
235
|
+
raise ValueError(f"Cached {error_type}: {error_msg}")
|
|
216
236
|
return hit
|
|
217
237
|
|
|
218
238
|
try:
|
|
@@ -230,8 +250,21 @@ class AsyncLM:
|
|
|
230
250
|
**kw,
|
|
231
251
|
)
|
|
232
252
|
|
|
233
|
-
except (AuthenticationError, RateLimitError) as exc:
|
|
234
|
-
|
|
253
|
+
except (AuthenticationError, RateLimitError, BadRequestError) as exc:
|
|
254
|
+
error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
|
|
255
|
+
logger.error(error_msg)
|
|
256
|
+
|
|
257
|
+
# Cache the error if it's a BadRequestError to avoid repeated calls
|
|
258
|
+
if isinstance(exc, BadRequestError) and cache_key:
|
|
259
|
+
error_response = {
|
|
260
|
+
"error": True,
|
|
261
|
+
"error_type": "BadRequestError",
|
|
262
|
+
"error_message": str(exc),
|
|
263
|
+
"choices": [],
|
|
264
|
+
}
|
|
265
|
+
self._dump_cache(cache_key, error_response)
|
|
266
|
+
logger.debug(f"Cached BadRequestError for key: {cache_key}")
|
|
267
|
+
|
|
235
268
|
raise
|
|
236
269
|
|
|
237
270
|
if cache_key:
|
|
@@ -354,10 +387,182 @@ class AsyncLM:
|
|
|
354
387
|
except Exception:
|
|
355
388
|
return None
|
|
356
389
|
|
|
390
|
+
# ------------------------------------------------------------------ #
|
|
391
|
+
# Missing methods from LM class
|
|
392
|
+
# ------------------------------------------------------------------ #
|
|
393
|
+
async def parse(
|
|
394
|
+
self,
|
|
395
|
+
response_model: Type[BaseModel],
|
|
396
|
+
instruction: Optional[str] = None,
|
|
397
|
+
prompt: Optional[str] = None,
|
|
398
|
+
messages: Optional[RawMsgs] = None,
|
|
399
|
+
think: Literal[True, False, None] = None,
|
|
400
|
+
add_json_schema_to_instruction: bool = False,
|
|
401
|
+
temperature: Optional[float] = None,
|
|
402
|
+
max_tokens: Optional[int] = None,
|
|
403
|
+
return_openai_response: bool = False,
|
|
404
|
+
cache: Optional[bool] = True,
|
|
405
|
+
**kwargs,
|
|
406
|
+
):
|
|
407
|
+
"""Parse response using guided JSON generation."""
|
|
408
|
+
if messages is None:
|
|
409
|
+
assert instruction is not None, "Instruction must be provided."
|
|
410
|
+
assert prompt is not None, "Prompt must be provided."
|
|
411
|
+
messages = [
|
|
412
|
+
{
|
|
413
|
+
"role": "system",
|
|
414
|
+
"content": instruction,
|
|
415
|
+
},
|
|
416
|
+
{
|
|
417
|
+
"role": "user",
|
|
418
|
+
"content": prompt,
|
|
419
|
+
},
|
|
420
|
+
] # type: ignore
|
|
421
|
+
|
|
422
|
+
post_fix = ""
|
|
423
|
+
json_schema = response_model.model_json_schema()
|
|
424
|
+
if add_json_schema_to_instruction and response_model:
|
|
425
|
+
_schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
|
|
426
|
+
post_fix += _schema
|
|
427
|
+
|
|
428
|
+
if think:
|
|
429
|
+
post_fix += "\n\n/think"
|
|
430
|
+
elif not think:
|
|
431
|
+
post_fix += "\n\n/no_think"
|
|
432
|
+
|
|
433
|
+
assert isinstance(messages, list), "Messages must be a list."
|
|
434
|
+
assert len(messages) > 0, "Messages cannot be empty."
|
|
435
|
+
assert messages[0]["role"] == "system", (
|
|
436
|
+
"First message must be a system message with instruction."
|
|
437
|
+
)
|
|
438
|
+
messages[0]["content"] += post_fix # type: ignore
|
|
439
|
+
|
|
440
|
+
model_kwargs = {}
|
|
441
|
+
if temperature is not None:
|
|
442
|
+
model_kwargs["temperature"] = temperature
|
|
443
|
+
if max_tokens is not None:
|
|
444
|
+
model_kwargs["max_tokens"] = max_tokens
|
|
445
|
+
model_kwargs.update(kwargs)
|
|
446
|
+
|
|
447
|
+
use_cache = self.do_cache if cache is None else cache
|
|
448
|
+
cache_key = None
|
|
449
|
+
if use_cache:
|
|
450
|
+
cache_data = {
|
|
451
|
+
"messages": messages,
|
|
452
|
+
"model_kwargs": model_kwargs,
|
|
453
|
+
"guided_json": json_schema,
|
|
454
|
+
"response_format": response_model.__name__,
|
|
455
|
+
}
|
|
456
|
+
cache_key = self._cache_key(cache_data, {}, response_model)
|
|
457
|
+
cached_response = self._load_cache(cache_key)
|
|
458
|
+
self.last_log = [prompt, messages, cached_response]
|
|
459
|
+
if cached_response is not None:
|
|
460
|
+
if return_openai_response:
|
|
461
|
+
return cached_response
|
|
462
|
+
return self._parse_complete_output(cached_response, response_model)
|
|
463
|
+
|
|
464
|
+
completion = await self.client.chat.completions.create(
|
|
465
|
+
model=self.model, # type: ignore
|
|
466
|
+
messages=messages, # type: ignore
|
|
467
|
+
extra_body={"guided_json": json_schema},
|
|
468
|
+
**model_kwargs,
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
if cache_key:
|
|
472
|
+
self._dump_cache(cache_key, completion)
|
|
473
|
+
|
|
474
|
+
self.last_log = [prompt, messages, completion]
|
|
475
|
+
if return_openai_response:
|
|
476
|
+
return completion
|
|
477
|
+
return self._parse_complete_output(completion, response_model)
|
|
478
|
+
|
|
479
|
+
def _parse_complete_output(
|
|
480
|
+
self, completion: Any, response_model: Type[BaseModel]
|
|
481
|
+
) -> BaseModel:
|
|
482
|
+
"""Parse completion output to response model."""
|
|
483
|
+
if hasattr(completion, "model_dump"):
|
|
484
|
+
completion = completion.model_dump()
|
|
485
|
+
|
|
486
|
+
if "choices" not in completion or not completion["choices"]:
|
|
487
|
+
raise ValueError("No choices in OpenAI response")
|
|
488
|
+
|
|
489
|
+
content = completion["choices"][0]["message"]["content"]
|
|
490
|
+
if not content:
|
|
491
|
+
raise ValueError("Empty content in response")
|
|
492
|
+
|
|
493
|
+
try:
|
|
494
|
+
data = json.loads(content)
|
|
495
|
+
return response_model.model_validate(data)
|
|
496
|
+
except Exception as exc:
|
|
497
|
+
raise ValueError(
|
|
498
|
+
f"Failed to parse response as {response_model.__name__}: {content}"
|
|
499
|
+
) from exc
|
|
500
|
+
|
|
501
|
+
async def inspect_word_probs(
|
|
502
|
+
self,
|
|
503
|
+
messages: Optional[List[Dict[str, Any]]] = None,
|
|
504
|
+
tokenizer: Optional[Any] = None,
|
|
505
|
+
do_print=True,
|
|
506
|
+
add_think: bool = True,
|
|
507
|
+
) -> tuple[List[Dict[str, Any]], Any, str]:
|
|
508
|
+
"""
|
|
509
|
+
Inspect word probabilities in a language model response.
|
|
510
|
+
|
|
511
|
+
Args:
|
|
512
|
+
tokenizer: Tokenizer instance to encode words.
|
|
513
|
+
messages: List of messages to analyze.
|
|
514
|
+
|
|
515
|
+
Returns:
|
|
516
|
+
A tuple containing:
|
|
517
|
+
- List of word probabilities with their log probabilities.
|
|
518
|
+
- Token log probability dictionaries.
|
|
519
|
+
- Rendered string with colored word probabilities.
|
|
520
|
+
"""
|
|
521
|
+
if messages is None:
|
|
522
|
+
messages = await self.last_messages(add_think=add_think)
|
|
523
|
+
if messages is None:
|
|
524
|
+
raise ValueError("No messages provided and no last messages available.")
|
|
525
|
+
|
|
526
|
+
if tokenizer is None:
|
|
527
|
+
tokenizer = get_tokenizer(self.model)
|
|
528
|
+
|
|
529
|
+
ret = await inspect_word_probs_async(self, tokenizer, messages)
|
|
530
|
+
if do_print:
|
|
531
|
+
print(ret[-1])
|
|
532
|
+
return ret
|
|
533
|
+
|
|
534
|
+
async def last_messages(
|
|
535
|
+
self, add_think: bool = True
|
|
536
|
+
) -> Optional[List[Dict[str, str]]]:
|
|
537
|
+
"""Get the last conversation messages including assistant response."""
|
|
538
|
+
if not hasattr(self, "last_log"):
|
|
539
|
+
return None
|
|
540
|
+
|
|
541
|
+
last_conv = self.last_log
|
|
542
|
+
messages = last_conv[1] if len(last_conv) > 1 else None
|
|
543
|
+
last_msg = last_conv[2]
|
|
544
|
+
if not isinstance(last_msg, dict):
|
|
545
|
+
last_conv[2] = last_conv[2].model_dump() # type: ignore
|
|
546
|
+
msg = last_conv[2]
|
|
547
|
+
# Ensure msg is a dict
|
|
548
|
+
if hasattr(msg, "model_dump"):
|
|
549
|
+
msg = msg.model_dump()
|
|
550
|
+
message = msg["choices"][0]["message"]
|
|
551
|
+
reasoning = message.get("reasoning_content")
|
|
552
|
+
answer = message.get("content")
|
|
553
|
+
if reasoning and add_think:
|
|
554
|
+
final_answer = f"<think>{reasoning}</think>\n{answer}"
|
|
555
|
+
else:
|
|
556
|
+
final_answer = f"<think>\n\n</think>\n{answer}"
|
|
557
|
+
assistant = {"role": "assistant", "content": final_answer}
|
|
558
|
+
messages = messages + [assistant] # type: ignore
|
|
559
|
+
return messages if messages else None
|
|
560
|
+
|
|
357
561
|
# ------------------------------------------------------------------ #
|
|
358
562
|
# Utility helpers
|
|
359
563
|
# ------------------------------------------------------------------ #
|
|
360
564
|
async def inspect_history(self) -> None:
|
|
565
|
+
"""Inspect the conversation history with proper formatting."""
|
|
361
566
|
if not hasattr(self, "last_log"):
|
|
362
567
|
raise ValueError("No history available. Please call the model first.")
|
|
363
568
|
|
|
@@ -445,3 +650,210 @@ class AsyncLM:
|
|
|
445
650
|
except Exception as exc:
|
|
446
651
|
logger.error(f"Failed to list models: {exc}")
|
|
447
652
|
return []
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
# --------------------------------------------------------------------------- #
|
|
656
|
+
# Module-level utility functions (async versions)
|
|
657
|
+
# --------------------------------------------------------------------------- #
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
@lru_cache(maxsize=10)
|
|
661
|
+
def get_tokenizer(model_name: str) -> Any:
|
|
662
|
+
"""Get tokenizer for the given model."""
|
|
663
|
+
from transformers import AutoTokenizer # type: ignore
|
|
664
|
+
|
|
665
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
666
|
+
return tokenizer
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
async def inspect_word_probs_async(lm, tokenizer, messages):
|
|
670
|
+
"""Async version of inspect_word_probs."""
|
|
671
|
+
|
|
672
|
+
import numpy as np
|
|
673
|
+
|
|
674
|
+
async def compute_word_log_probs(
|
|
675
|
+
tokenizer: Any,
|
|
676
|
+
lm_client: Any,
|
|
677
|
+
) -> tuple[List[Dict[str, Any]], Any]:
|
|
678
|
+
# Build a prompt that preserves literal newlines
|
|
679
|
+
prompt = tokenizer.apply_chat_template(
|
|
680
|
+
messages,
|
|
681
|
+
tokenize=False, # Don't tokenize yet, we need raw text
|
|
682
|
+
add_generation_prompt=False, # No generation prompt needed
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
# Request token logprobs
|
|
686
|
+
response = await lm_client.client.completions.create(
|
|
687
|
+
model=lm_client.model, # type: ignore
|
|
688
|
+
prompt=prompt,
|
|
689
|
+
max_tokens=1,
|
|
690
|
+
logprobs=1,
|
|
691
|
+
extra_body={"prompt_logprobs": 0},
|
|
692
|
+
)
|
|
693
|
+
token_logprob_dicts = response.choices[0].prompt_logprobs # type: ignore
|
|
694
|
+
|
|
695
|
+
# Override first token to known start marker
|
|
696
|
+
start_id = tokenizer.encode("<|im_start|>")[0]
|
|
697
|
+
token_logprob_dicts[0] = {
|
|
698
|
+
str(start_id): {
|
|
699
|
+
"logprob": -1,
|
|
700
|
+
"rank": 1,
|
|
701
|
+
"decoded_token": "<|im_start|>",
|
|
702
|
+
}
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
# Flatten tokens
|
|
706
|
+
tokens: List[Dict[str, Any]] = [
|
|
707
|
+
{"id": int(tid), **tdata}
|
|
708
|
+
for td in token_logprob_dicts
|
|
709
|
+
for tid, tdata in td.items()
|
|
710
|
+
]
|
|
711
|
+
|
|
712
|
+
# Validate tokenization
|
|
713
|
+
tokenized = tokenizer.tokenize(prompt)
|
|
714
|
+
if len(tokenized) != len(tokens):
|
|
715
|
+
raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
|
|
716
|
+
for idx, tok in enumerate(tokens):
|
|
717
|
+
if tokenized[idx] != tok["decoded_token"]:
|
|
718
|
+
raise AssertionError(
|
|
719
|
+
f"Token mismatch at {idx}: "
|
|
720
|
+
f"{tokenized[idx]} != {tok['decoded_token']}"
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
# Split on newline sentinel
|
|
724
|
+
split_prompt = prompt.replace("\n", " <NL> ")
|
|
725
|
+
words = split_prompt.split()
|
|
726
|
+
|
|
727
|
+
word_log_probs: List[Dict[str, Any]] = []
|
|
728
|
+
token_idx = 0
|
|
729
|
+
|
|
730
|
+
for word in words:
|
|
731
|
+
# Map sentinel back to actual newline for encoding
|
|
732
|
+
target = "\n" if word == "<NL>" else word
|
|
733
|
+
sub_ids = tokenizer.encode(target, add_special_tokens=False)
|
|
734
|
+
count = len(sub_ids)
|
|
735
|
+
if count == 0:
|
|
736
|
+
continue
|
|
737
|
+
|
|
738
|
+
subs = tokens[token_idx : token_idx + count]
|
|
739
|
+
avg_logprob = sum(s["logprob"] for s in subs) / count
|
|
740
|
+
prob = float(np.exp(avg_logprob))
|
|
741
|
+
word_log_probs.append({"word": target, "probability": prob})
|
|
742
|
+
token_idx += count
|
|
743
|
+
|
|
744
|
+
return word_log_probs, token_logprob_dicts # type: ignore
|
|
745
|
+
|
|
746
|
+
def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
|
|
747
|
+
"""
|
|
748
|
+
Return an ANSI-colored string for word probabilities (red → green).
|
|
749
|
+
"""
|
|
750
|
+
if not word_log_probs:
|
|
751
|
+
return ""
|
|
752
|
+
|
|
753
|
+
probs = [entry["probability"] for entry in word_log_probs]
|
|
754
|
+
min_p, max_p = min(probs), max(probs)
|
|
755
|
+
parts: List[str] = []
|
|
756
|
+
|
|
757
|
+
for entry in word_log_probs:
|
|
758
|
+
word = entry["word"]
|
|
759
|
+
# Preserve actual line breaks
|
|
760
|
+
if word == "\n":
|
|
761
|
+
parts.append("\n")
|
|
762
|
+
continue
|
|
763
|
+
|
|
764
|
+
p = entry["probability"]
|
|
765
|
+
norm = (p - min_p) / (max_p - min_p or 1.0)
|
|
766
|
+
r = int(255 * (1 - norm)) # red component (high when prob is low)
|
|
767
|
+
g = int(255 * norm) # green component (high when prob is high)
|
|
768
|
+
b = 0 # no blue for red-green gradient
|
|
769
|
+
colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
|
|
770
|
+
parts.append(colored + " ")
|
|
771
|
+
|
|
772
|
+
return "".join(parts).rstrip()
|
|
773
|
+
|
|
774
|
+
word_probs, token_logprob_dicts = await compute_word_log_probs(tokenizer, lm)
|
|
775
|
+
return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
# --------------------------------------------------------------------------- #
|
|
779
|
+
# Async LLMTask class
|
|
780
|
+
# --------------------------------------------------------------------------- #
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
class AsyncLLMTask(ABC):
|
|
784
|
+
"""
|
|
785
|
+
Async callable wrapper around an AsyncLM endpoint.
|
|
786
|
+
|
|
787
|
+
Sub-classes must set:
|
|
788
|
+
• lm – the async language-model instance
|
|
789
|
+
• InputModel – a Pydantic input class
|
|
790
|
+
• OutputModel – a Pydantic output class
|
|
791
|
+
|
|
792
|
+
Optional flags:
|
|
793
|
+
• temperature – float (default 0.6)
|
|
794
|
+
• think – bool (if the backend supports "chain-of-thought")
|
|
795
|
+
• add_json_schema – bool (include schema in the instruction)
|
|
796
|
+
|
|
797
|
+
The **docstring** of each sub-class is sent as the LM instruction.
|
|
798
|
+
Example
|
|
799
|
+
```python
|
|
800
|
+
class DemoTask(AsyncLLMTask):
|
|
801
|
+
"TODO: SYSTEM_PROMPT_INSTURCTION HERE"
|
|
802
|
+
|
|
803
|
+
lm = AsyncLM(port=8130, cache=False, model="gpt-3.5-turbo")
|
|
804
|
+
|
|
805
|
+
class InputModel(BaseModel):
|
|
806
|
+
text_to_translate:str
|
|
807
|
+
|
|
808
|
+
class OutputModel(BaseModel):
|
|
809
|
+
translation:str
|
|
810
|
+
glossary_use:str
|
|
811
|
+
|
|
812
|
+
temperature = 0.6
|
|
813
|
+
think=False
|
|
814
|
+
|
|
815
|
+
demo_task = DemoTask()
|
|
816
|
+
result = await demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
|
|
817
|
+
```
|
|
818
|
+
"""
|
|
819
|
+
|
|
820
|
+
lm: "AsyncLM"
|
|
821
|
+
InputModel: Type[BaseModel]
|
|
822
|
+
OutputModel: Type[BaseModel]
|
|
823
|
+
|
|
824
|
+
temperature: float = 0.6
|
|
825
|
+
think: bool = False
|
|
826
|
+
add_json_schema: bool = False
|
|
827
|
+
|
|
828
|
+
async def __call__(self, data: BaseModel | dict) -> BaseModel:
|
|
829
|
+
if (
|
|
830
|
+
not hasattr(self, "InputModel")
|
|
831
|
+
or not hasattr(self, "OutputModel")
|
|
832
|
+
or not hasattr(self, "lm")
|
|
833
|
+
):
|
|
834
|
+
raise NotImplementedError(
|
|
835
|
+
f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
item = data if isinstance(data, BaseModel) else self.InputModel(**data)
|
|
839
|
+
|
|
840
|
+
return await self.lm.parse(
|
|
841
|
+
prompt=item.model_dump_json(),
|
|
842
|
+
instruction=self.__doc__ or "",
|
|
843
|
+
response_model=self.OutputModel,
|
|
844
|
+
temperature=self.temperature,
|
|
845
|
+
think=self.think,
|
|
846
|
+
add_json_schema_to_instruction=self.add_json_schema,
|
|
847
|
+
)
|
|
848
|
+
|
|
849
|
+
def generate_training_data(
|
|
850
|
+
self, input_dict: Dict[str, Any], output: Dict[str, Any]
|
|
851
|
+
):
|
|
852
|
+
"""Return share gpt like format"""
|
|
853
|
+
system_prompt = self.__doc__ or ""
|
|
854
|
+
user_msg = self.InputModel(**input_dict).model_dump_json() # type: ignore[attr-defined]
|
|
855
|
+
assistant_msg = self.OutputModel(**output).model_dump_json() # type: ignore[attr-defined]
|
|
856
|
+
messages = get_conversation_one_turn(
|
|
857
|
+
system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
|
|
858
|
+
)
|
|
859
|
+
return {"messages": messages}
|