speedy-utils 1.0.14__py3-none-any.whl → 1.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +10 -15
- llm_utils/chat_format/display.py +1 -1
- llm_utils/chat_format/transform.py +1 -2
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/alm.py +426 -14
- llm_utils/lm/chat_html.py +246 -0
- llm_utils/lm/lm.py +386 -78
- llm_utils/lm/lm_json.py +68 -0
- llm_utils/lm/utils.py +1 -1
- llm_utils/scripts/README.md +48 -0
- llm_utils/scripts/vllm_load_balancer.py +0 -1
- speedy_utils/__init__.py +96 -5
- speedy_utils/common/function_decorator.py +1 -4
- speedy_utils/common/logger.py +1 -1
- speedy_utils/common/notebook_utils.py +63 -0
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +7 -7
- speedy_utils/common/utils_misc.py +1 -2
- speedy_utils/common/utils_print.py +2 -65
- speedy_utils/multi_worker/process.py +9 -4
- speedy_utils/scripts/mpython.py +4 -4
- speedy_utils/scripts/openapi_client_codegen.py +1 -5
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/METADATA +1 -1
- speedy_utils-1.0.16.dist-info/RECORD +37 -0
- speedy_utils-1.0.14.dist-info/RECORD +0 -33
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.14.dist-info → speedy_utils-1.0.16.dist-info}/entry_points.txt +0 -0
llm_utils/lm/lm.py
CHANGED
@@ -4,25 +4,24 @@ import base64
 import hashlib
 import json
 import os
-from
+from abc import ABC
+from functools import lru_cache
 from typing import (
     Any,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Type,
     TypeVar,
     Union,
-    overload,
     cast,
+    overload,
 )
 
-from httpx import URL
-from huggingface_hub import repo_info
 from loguru import logger
-from
-from openai import OpenAI, AuthenticationError, RateLimitError
+from openai import AuthenticationError, OpenAI, RateLimitError
 from openai.pagination import SyncPage
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam,
@@ -31,10 +30,11 @@ from openai.types.chat import (
     ChatCompletionToolMessageParam,
     ChatCompletionUserMessageParam,
 )
-from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
 from openai.types.model import Model
 from pydantic import BaseModel
-
+
+from llm_utils.chat_format.display import get_conversation_one_turn
+from speedy_utils.common.utils_io import jdumps
 
 # --------------------------------------------------------------------------- #
 # type helpers
@@ -68,6 +68,17 @@ def _yellow(text: str) -> str:
     return f"\x1b[33m{text}\x1b[0m"
 
 
+# from functools import lru_cache
+
+
+# @lru_cache(maxsize=10)
+# def get_tok(tokenizer_name):
+#     from transformers import AutoTokenizer
+
+#     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
+#     return tokenizer
+
+
 class LM:
     """
     Unified language-model wrapper.
@@ -159,7 +170,7 @@ class LM:
             available_models = self.list_models(port=port)
             if available_models:
                 self.model = available_models[0]
-                logger.
+                logger.debug(f"Auto-selected model: {self.model}")
             else:
                 raise ValueError("No models available to select from.")
         else:
@@ -258,9 +269,9 @@
         if parsed:
             # print(_green('<Parsed Structure>'))
             if hasattr(parsed, "model_dump"):
-                print(
+                print(jdumps(parsed.model_dump(), indent=2))
             else:
-                print(
+                print(jdumps(parsed, indent=2))
             # print(_green('</Parsed Structure>'))
             print()
 
@@ -468,91 +479,388 @@
             return None
 
     @staticmethod
-    def list_models(
-        ""
-        """
+    def list_models(
+        port=None, host="localhost", base_url: Optional[str] = None
+    ) -> List[str]:
+        """List available models from OpenAI-compatible API server."""
         try:
-            client: OpenAI =
+            client: OpenAI = OpenAI(
+                api_key=os.getenv("OPENAI_API_KEY", "abc"),
+                base_url=f"http://{host}:{port}/v1" if port else base_url or None,
+            )
             models: SyncPage[Model] = client.models.list()
             return [model.id for model in models.data]
         except Exception as exc:
+            endpoint = f"http://{host}:{port}/v1" if port else base_url
+            error_msg = str(exc)
 
+            if "404" in error_msg or "Not Found" in error_msg:
+                raise ValueError(
+                    f"No OpenAI-compatible API found at {endpoint}. "
+                    f"The endpoint appears to be running a different service "
+                    f"(possibly Jupyter Server). Please check the port number."
+                ) from exc
+            elif "Connection" in error_msg:
+                raise ValueError(
+                    f"Cannot connect to {endpoint}. "
+                    f"Please verify the service is running and accessible."
+                ) from exc
+            else:
+                raise ValueError(
+                    f"Failed to list models from {endpoint}: {error_msg}"
+                ) from exc
 
+    def parse(
+        self,
+        response_model: Type[BaseModel],
+        instruction: Optional[str] = None,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        think: Literal[True, False, None] = None,
+        add_json_schema_to_instruction: bool = False,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
+        cache: Optional[bool] = True,
+        **kwargs,
+    ):
+        if messages is None:
+            assert instruction is not None, "Instruction must be provided."
+            assert prompt is not None, "Prompt must be provided."
+            messages = [
+                {
+                    "role": "system",
+                    "content": instruction,
+                },
+                {
+                    "role": "user",
+                    "content": prompt,
+                },
+            ]  # type: ignore
+
+        post_fix = ""
+        json_schema = response_model.model_json_schema()
+        if add_json_schema_to_instruction and response_model:
+            _schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+            post_fix += _schema
+
+        if think:
+            post_fix += "\n\n/think"
+        elif not think:
+            post_fix += "\n\n/no_think"
+
+        assert isinstance(messages, list), "Messages must be a list."
+        assert len(messages) > 0, "Messages cannot be empty."
+        assert (
+            messages[0]["role"] == "system"
+        ), "First message must be a system message with instruction."
+        messages[0]["content"] += post_fix  # type: ignore
+
+        model_kwargs = {}
+        if temperature is not None:
+            model_kwargs["temperature"] = temperature
+        if max_tokens is not None:
+            model_kwargs["max_tokens"] = max_tokens
+        model_kwargs.update(kwargs)
 
+        use_cache = self.do_cache if cache is None else cache
+        cache_key = None
+        if use_cache:
+            cache_data = {
+                "messages": messages,
+                "model_kwargs": model_kwargs,
+                "guided_json": json_schema,
+                "response_format": response_model.__name__,
+            }
+            cache_key = self._cache_key(cache_data, {}, response_model)
+            cached_response = self._load_cache(cache_key)
+            self.last_log = [prompt, messages, cached_response]
+            if cached_response is not None:
+                if return_openai_response:
+                    return cached_response
+                return self._parse_complete_output(cached_response, response_model)
+
+        completion = self.client.chat.completions.create(
+            model=self.model,  # type: ignore
+            messages=messages,  # type: ignore
+            extra_body={"guided_json": json_schema},
+            **model_kwargs,
+        )
 
+        if cache_key:
+            self._dump_cache(cache_key, completion)
+
+        self.last_log = [prompt, messages, completion]
+        if return_openai_response:
+            return completion
+        return self._parse_complete_output(completion, response_model)
 
-    def
+    def _parse_complete_output(
+        self, completion: Any, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Parse completion output to response model."""
+        if hasattr(completion, "model_dump"):
+            completion = completion.model_dump()
+
+        if "choices" not in completion or not completion["choices"]:
+            raise ValueError("No choices in OpenAI response")
+
+        content = completion["choices"][0]["message"]["content"]
+        if not content:
+            raise ValueError("Empty content in response")
+
+        try:
+            data = json.loads(content)
+            return response_model.model_validate(data)
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to parse response as {response_model.__name__}: {content}"
+            ) from exc
+
+    def inspect_word_probs(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        tokenizer: Optional[Any] = None,
+        do_print=True,
+        add_think: bool = True,
+    ) -> tuple[List[Dict[str, Any]], Any, str]:
         """
+        Inspect word probabilities in a language model response.
 
         Args:
+            tokenizer: Tokenizer instance to encode words.
+            messages: List of messages to analyze.
 
         Returns:
-            A
+            A tuple containing:
+            - List of word probabilities with their log probabilities.
+            - Token log probability dictionaries.
+            - Rendered string with colored word probabilities.
         """
+        if messages is None:
+            messages = self.last_messages(add_think=add_think)
+        if messages is None:
+            raise ValueError("No messages provided and no last messages available.")
+
+        if tokenizer is None:
+            tokenizer = get_tokenizer(self.model)
+
+        ret = inspect_word_probs(self, tokenizer, messages)
+        if do_print:
+            print(ret[-1])
+        return ret
+
+    def last_messages(self, add_think: bool = True) -> Optional[List[Dict[str, str]]]:
+        last_conv = self.last_log
+        messages = last_conv[1] if len(last_conv) > 1 else None
+        last_msg = last_conv[2]
+        if not isinstance(last_msg, dict):
+            last_conv[2] = last_conv[2].model_dump()  # type: ignore
+        msg = last_conv[2]
+        # Ensure msg is a dict
+        if hasattr(msg, "model_dump"):
+            msg = msg.model_dump()
+        message = msg["choices"][0]["message"]
+        reasoning = message.get("reasoning_content")
+        answer = message.get("content")
+        if reasoning and add_think:
+            final_answer = f"<think>{reasoning}</think>\n{answer}"
+        else:
+            final_answer = f"<think>\n\n</think>\n{answer}"
+        assistant = {"role": "assistant", "content": final_answer}
+        messages = messages + [assistant]  # type: ignore
+        return messages if messages else None
 
-        return regex
 
-        prompt: Optional[str] = None,
-        messages: Optional[RawMsgs] = None,
-        **kwargs,
-    ):
+@lru_cache(maxsize=10)
+def get_tokenizer(model_name: str) -> Any:
+    from transformers import AutoTokenizer  # type: ignore
 
-            re.DOTALL,
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    return tokenizer
+
+
+def inspect_word_probs(lm, tokenizer, messages):
+
+    import numpy as np
+
+    def compute_word_log_probs(
+        tokenizer: Any,
+        lm_client: Any,
+    ) -> tuple[List[Dict[str, Any]], Any]:
+        # Build a prompt that preserves literal newlines
+        prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,  # Don't tokenize yet, we need raw text
+            add_generation_prompt=False,  # No generation prompt needed
+        )
+
+        # Request token logprobs
+        response = lm_client.client.completions.create(
+            model=lm_client.model,  # type: ignore
+            prompt=prompt,
+            max_tokens=1,
+            logprobs=1,
+            extra_body={"prompt_logprobs": 0},
         )
+        token_logprob_dicts = response.choices[0].prompt_logprobs  # type: ignore
+
+        # Override first token to known start marker
+        start_id = tokenizer.encode("<|im_start|>")[0]
+        token_logprob_dicts[0] = {
+            str(start_id): {
+                "logprob": -1,
+                "rank": 1,
+                "decoded_token": "<|im_start|>",
+            }
+        }
+
+        # Flatten tokens
+        tokens: List[Dict[str, Any]] = [
+            {"id": int(tid), **tdata}
+            for td in token_logprob_dicts
+            for tid, tdata in td.items()
+        ]
+
+        # Validate tokenization
+        tokenized = tokenizer.tokenize(prompt)
+        if len(tokenized) != len(tokens):
+            raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
+        for idx, tok in enumerate(tokens):
+            if tokenized[idx] != tok["decoded_token"]:
+                raise AssertionError(
+                    f"Token mismatch at {idx}: "
+                    f"{tokenized[idx]} != {tok['decoded_token']}"
+                )
+
+        # Split on newline sentinel
+        split_prompt = prompt.replace("\n", " <NL> ")
+        words = split_prompt.split()
+
+        word_log_probs: List[Dict[str, Any]] = []
+        token_idx = 0
+
+        for word in words:
+            # Map sentinel back to actual newline for encoding
+            target = "\n" if word == "<NL>" else word
+            sub_ids = tokenizer.encode(target, add_special_tokens=False)
+            count = len(sub_ids)
+            if count == 0:
+                continue
+
+            subs = tokens[token_idx : token_idx + count]
+            avg_logprob = sum(s["logprob"] for s in subs) / count
+            prob = float(np.exp(avg_logprob))
+            word_log_probs.append({"word": target, "probability": prob})
+            token_idx += count
 
-        return pydantic_object
+        return word_log_probs, token_logprob_dicts  # type: ignore
 
+    def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
+        """
+        Return an ANSI-colored string for word probabilities (red → green).
+        """
+        if not word_log_probs:
+            return ""
+
+        probs = [entry["probability"] for entry in word_log_probs]
+        min_p, max_p = min(probs), max(probs)
+        parts: List[str] = []
+
+        for entry in word_log_probs:
+            word = entry["word"]
+            # Preserve actual line breaks
+            if word == "\n":
+                parts.append("\n")
+                continue
+
+            p = entry["probability"]
+            norm = (p - min_p) / (max_p - min_p or 1.0)
+            r = int(255 * (1 - norm))  # red component (high when prob is low)
+            g = int(255 * norm)  # green component (high when prob is high)
+            b = 0  # no blue for red-green gradient
+            colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
+            parts.append(colored + " ")
+
+        return "".join(parts).rstrip()
+
+    word_probs, token_logprob_dicts = compute_word_log_probs(tokenizer, lm)
+    return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
+
+
+class LLMTask(ABC):
+    """
+    Callable wrapper around an LM endpoint.
+
+    Sub-classes must set:
+        • lm          – the language-model instance
+        • InputModel  – a Pydantic input class
+        • OutputModel – a Pydantic output class
+
+    Optional flags:
+        • temperature      – float (default 0.6)
+        • think            – bool (if the backend supports “chain-of-thought”)
+        • add_json_schema  – bool (include schema in the instruction)
+
+    The **docstring** of each sub-class is sent as the LM instruction.
+    Example
+    ```python
+    class DemoTask(LLMTask):
+        "TODO: SYSTEM_PROMPT_INSTURCTION HERE"
+
+        lm = LM(port=8130, cache=False, model="gpt-3.5-turbo")
+
+        class InputModel(BaseModel):
+            text_to_translate:str
+
+        class OutputModel(BaseModel):
+            translation:str
+            glossary_use:str
+
+        temperature = 0.6
+        think=False
+
+    demo_task = DemoTask()
+    demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
+    ```
+    """
+
+    lm: "LM"
+    InputModel: Type[BaseModel]
+    OutputModel: Type[BaseModel]
+
+    temperature: float = 0.6
+    think: bool = False
+    add_json_schema: bool = False
+
+    def __call__(self, data: BaseModel | dict) -> BaseModel:
+        if (
+            not hasattr(self, "InputModel")
+            or not hasattr(self, "OutputModel")
+            or not hasattr(self, "lm")
+        ):
+            raise NotImplementedError(
+                f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
+            )
+
+        item = data if isinstance(data, BaseModel) else self.InputModel(**data)
+
+        return self.lm.parse(
+            prompt=item.model_dump_json(),
+            instruction=self.__doc__ or "",
+            response_model=self.OutputModel,
+            temperature=self.temperature,
+            think=self.think,
+            add_json_schema_to_instruction=self.add_json_schema,
+        )
+
+    def generate_training_data(
+        self, input_dict: Dict[str, Any], output: Dict[str, Any]
+    ):
+        "Return share gpt like format"
+        system_prompt = self.__doc__ or ""
+        user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
+        assistant_msg = self.OutputModel(**output).model_dump_json()  # type: ignore[attr-defined]
+        return get_conversation_one_turn(
+            system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
+        )
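The headline addition in this file is the structured-output path: `parse()` builds a system/user message pair, optionally appends the response model's JSON schema and a `/think` or `/no_think` marker to the system message, sends the request with `guided_json` in `extra_body`, caches the raw completion, and validates the reply via `model_validate`. The sketch below shows how a caller might use it; the `Sentiment` model, the prompt text, and the endpoint settings are illustrative assumptions rather than code shipped in the package (the constructor arguments mirror the `LM(port=..., cache=..., model=...)` call in the `LLMTask` docstring above).

```python
from pydantic import BaseModel

from llm_utils.lm.lm import LM


class Sentiment(BaseModel):
    label: str
    confidence: float


# Assumed local OpenAI-compatible endpoint; adjust port/model to your own setup.
lm = LM(port=8130, cache=False, model="gpt-3.5-turbo")

result = lm.parse(
    response_model=Sentiment,
    instruction="Classify the sentiment of the user's text.",
    prompt="I really enjoyed this release!",
    add_json_schema_to_instruction=True,  # appends the JSON schema to the system message
    temperature=0.0,
)
print(result.label, result.confidence)  # a validated Sentiment instance
```

`LLMTask` packages the same call pattern declaratively, while `inspect_word_probs()` reuses `last_log` from the most recent call to color each word of the conversation by its average token probability.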
llm_utils/lm/lm_json.py
ADDED
@@ -0,0 +1,68 @@
+from typing import Any, Optional
+
+
+from llm_utils.lm.lm import LM, RawMsgs
+
+
+class LMJson(LM):
+    "Regex-based reasoning wrapper for LM."
+
+    def __init__(
+        self,
+        model: str | None = None,
+        *,
+        temperature: float = 0.0,
+        max_tokens: int = 2_000,
+        host: str = "localhost",
+        port: Optional[int | str] = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        cache: bool = True,
+        **openai_kwargs: Any,
+    ) -> None:
+        """
+        Initialize the LMJson instance.
+
+        Args:
+            model (str | None): The model name to use.
+            temperature (float): Sampling temperature.
+            max_tokens (int): Maximum number of tokens to generate.
+            host (str): Host for the API.
+            port (int | str, optional): Port for the API.
+            base_url (str, optional): Base URL for the API.
+            api_key (str, optional): API key for authentication.
+            cache (bool): Whether to cache responses.
+            **openai_kwargs: Additional OpenAI parameters.
+        """
+        super().__init__(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            host=host,
+            port=port,
+            base_url=base_url,
+            api_key=api_key,
+            cache=cache,
+            **openai_kwargs,
+        )
+
+    def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        cache: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
+        **kwargs: Any,
+    ):
+
+        output = super().__call__(
+            prompt=prompt,
+            messages=messages,
+            response_format=str,
+            cache=cache,
+            max_tokens=max_tokens,
+            return_openai_response=return_openai_response,
+            **kwargs,
+        )
+        return output
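`LMJson` adds no new request logic; it pins `response_format=str` when delegating to `LM.__call__`, so the call returns the raw completion text. A minimal usage sketch, assuming an OpenAI-compatible server on a local port; the port number and prompt are placeholders.

```python
from llm_utils.lm.lm_json import LMJson

# Assumed local endpoint; replace the port with wherever your server listens.
lm = LMJson(port=8130, temperature=0.0, cache=True)

# __call__ forwards to LM.__call__ with response_format=str,
# so the result is the plain text returned by the model.
text = lm(prompt='Return a JSON object with keys "name" and "age" for: Alice, 30')
print(text)
```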
llm_utils/lm/utils.py
CHANGED
@@ -82,7 +82,7 @@ def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
                 try:
                     return func(self, *args, **kwargs)
                 except exceptions as e:
-                    import litellm
+                    import litellm  # type: ignore
 
                     if isinstance(
                         e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
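Only the lazy `litellm` import gains a `# type: ignore` here, but the hunk context documents the decorator's shape: `retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3)` wraps an instance method (the wrapper receives `self`) and inspects litellm's `APIError` and `Timeout` before retrying. A hedged sketch of applying it, with a placeholder method body:

```python
from llm_utils.lm.utils import retry_on_exception


class MyClient:
    # Defaults mirror the signature shown in the hunk header above.
    @retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3)
    def complete(self, prompt: str) -> str:
        # Placeholder body: a real implementation would call the backing API here
        # and let transient exceptions propagate so the decorator can retry them
        # (the exact retry conditions live in code not shown in this hunk).
        raise TimeoutError("transient failure")
```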
llm_utils/scripts/README.md
ADDED
@@ -0,0 +1,48 @@
+# VLLM Server Examples
+
+This directory contains scripts for working with VLLM servers.
+
+## Files
+
+- `serve_script.sh` - Script to start the VLLM server
+- `example_vllm_client.py` - Beautiful example client for interacting with VLLM
+- `requirements_example.txt` - Python dependencies for the example
+
+## Usage
+
+### 1. Start the VLLM Server
+
+```bash
+bash serve_script.sh
+```
+
+### 2. Install Dependencies
+
+```bash
+pip install -r requirements_example.txt
+```
+
+### 3. Run Examples
+
+```bash
+python example_vllm_client.py
+```
+
+## Features
+
+The example client demonstrates:
+
+- ✅ Basic text generation
+- ✅ Batch processing
+- ✅ Creative writing with high temperature
+- ✅ Code generation with low temperature
+- ✅ Proper error handling
+- ✅ Health checks
+- ✅ Beautiful logging with loguru
+- ✅ Type safety with Pydantic models
+- ✅ Async/await patterns
+
+## Configuration
+
+The client connects to `http://localhost:8140` by default.
+Modify the `VLLMClient` initialization to use different servers.
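The client script itself is not included in this diff, so the snippet below is only a hedged sketch of reaching the same server with the stock `openai` package at the default `http://localhost:8140` address mentioned above; the `/v1` suffix and the model id are assumptions to replace with whatever `serve_script.sh` actually serves.

```python
from openai import OpenAI

# vLLM exposes an OpenAI-compatible API; the key is unused locally but required by the client.
client = OpenAI(base_url="http://localhost:8140/v1", api_key="EMPTY")

# Discover the served model ids before hard-coding one.
print([m.id for m in client.models.list().data])

response = client.chat.completions.create(
    model="my-served-model",  # placeholder: use an id from the list above
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0.7,
)
print(response.choices[0].message.content)
```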