speedy-utils 1.0.13__py3-none-any.whl → 1.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +10 -15
- llm_utils/lm/alm.py +24 -3
- llm_utils/lm/chat_html.py +244 -0
- llm_utils/lm/lm.py +390 -74
- llm_utils/lm/lm_json.py +72 -0
- llm_utils/scripts/README.md +48 -0
- llm_utils/scripts/example_vllm_client.py +269 -0
- llm_utils/scripts/requirements_example.txt +3 -0
- llm_utils/scripts/serve_script.sh +2 -0
- speedy_utils/__init__.py +96 -5
- speedy_utils/common/notebook_utils.py +63 -0
- speedy_utils/common/utils_cache.py +3 -3
- speedy_utils/common/utils_print.py +2 -65
- speedy_utils/multi_worker/process.py +7 -0
- speedy_utils/scripts/__init__.py +0 -0
- speedy_utils/scripts/mpython.py +3 -2
- speedy_utils/scripts/openapi_client_codegen.py +258 -0
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/RECORD +21 -12
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/entry_points.txt +1 -0
- {speedy_utils-1.0.13.dist-info → speedy_utils-1.0.15.dist-info}/WHEEL +0 -0
llm_utils/lm/lm.py
CHANGED
|
@@ -4,25 +4,27 @@ import base64
|
|
|
4
4
|
import hashlib
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
|
-
|
|
7
|
+
import warnings
|
|
8
|
+
from abc import ABC
|
|
8
9
|
from typing import (
|
|
9
10
|
Any,
|
|
10
11
|
Dict,
|
|
11
12
|
List,
|
|
13
|
+
Literal,
|
|
12
14
|
Optional,
|
|
13
15
|
Sequence,
|
|
14
16
|
Type,
|
|
15
17
|
TypeVar,
|
|
16
18
|
Union,
|
|
17
|
-
overload,
|
|
18
19
|
cast,
|
|
20
|
+
overload,
|
|
19
21
|
)
|
|
20
22
|
|
|
21
23
|
from httpx import URL
|
|
22
24
|
from huggingface_hub import repo_info
|
|
23
25
|
from loguru import logger
|
|
24
26
|
from numpy import isin
|
|
25
|
-
from openai import
|
|
27
|
+
from openai import AuthenticationError, OpenAI, RateLimitError
|
|
26
28
|
from openai.pagination import SyncPage
|
|
27
29
|
from openai.types.chat import (
|
|
28
30
|
ChatCompletionAssistantMessageParam,
|
|
@@ -34,7 +36,8 @@ from openai.types.chat import (
|
|
|
34
36
|
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
|
|
35
37
|
from openai.types.model import Model
|
|
36
38
|
from pydantic import BaseModel
|
|
37
|
-
|
|
39
|
+
|
|
40
|
+
from speedy_utils.common.utils_io import jdumps
|
|
38
41
|
|
|
39
42
|
# --------------------------------------------------------------------------- #
|
|
40
43
|
# type helpers
|
|
@@ -68,6 +71,17 @@ def _yellow(text: str) -> str:
|
|
|
68
71
|
return f"\x1b[33m{text}\x1b[0m"
|
|
69
72
|
|
|
70
73
|
|
|
74
|
+
# from functools import lru_cache
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# @lru_cache(maxsize=10)
|
|
78
|
+
# def get_tok(tokenizer_name):
|
|
79
|
+
# from transformers import AutoTokenizer
|
|
80
|
+
|
|
81
|
+
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
|
|
82
|
+
# return tokenizer
|
|
83
|
+
|
|
84
|
+
|
|
71
85
|
class LM:
|
|
72
86
|
"""
|
|
73
87
|
Unified language-model wrapper.
|
|
@@ -159,7 +173,7 @@ class LM:
|
|
|
159
173
|
available_models = self.list_models(port=port)
|
|
160
174
|
if available_models:
|
|
161
175
|
self.model = available_models[0]
|
|
162
|
-
logger.
|
|
176
|
+
logger.debug(f"Auto-selected model: {self.model}")
|
|
163
177
|
else:
|
|
164
178
|
raise ValueError("No models available to select from.")
|
|
165
179
|
else:
|
|
@@ -258,9 +272,9 @@ class LM:
|
|
|
258
272
|
if parsed:
|
|
259
273
|
# print(_green('<Parsed Structure>'))
|
|
260
274
|
if hasattr(parsed, "model_dump"):
|
|
261
|
-
print(
|
|
275
|
+
print(jdumps(parsed.model_dump(), indent=2))
|
|
262
276
|
else:
|
|
263
|
-
print(
|
|
277
|
+
print(jdumps(parsed, indent=2))
|
|
264
278
|
# print(_green('</Parsed Structure>'))
|
|
265
279
|
print()
|
|
266
280
|
|
|
@@ -468,91 +482,393 @@ class LM:
|
|
|
468
482
|
return None
|
|
469
483
|
|
|
470
484
|
@staticmethod
def list_models(
    port=None, host="localhost", base_url: Optional[str] = None
) -> List[str]:
    """List the model ids served by an OpenAI-compatible API server.

    Args:
        port: When given, the server is addressed as
            ``http://{host}:{port}/v1`` and ``base_url`` is ignored.
        host: Hostname combined with ``port``.
        base_url: Full base URL, used only when ``port`` is not given. When
            both are omitted, the OpenAI client's own default endpoint is used.

    Returns:
        The ``id`` of every model returned by ``client.models.list()``.

    Raises:
        ValueError: Whenever listing fails. The message distinguishes a 404
            (some other service on that port), a connection failure, and any
            other error; the original exception is chained as ``__cause__``.
    """
    # Resolve the endpoint once so the client and the error messages agree.
    endpoint = f"http://{host}:{port}/v1" if port else base_url
    try:
        client: OpenAI = OpenAI(
            # The client requires some key; "abc" is a placeholder fallback.
            api_key=os.getenv("OPENAI_API_KEY", "abc"),
            base_url=endpoint or None,
        )
        models: SyncPage[Model] = client.models.list()
        return [model.id for model in models.data]
    except Exception as exc:
        # Avoid reporting "at None" when no explicit endpoint was configured.
        where = endpoint or "the default OpenAI endpoint"
        error_msg = str(exc)
        # Classification is based on the exception text only.
        if "404" in error_msg or "Not Found" in error_msg:
            raise ValueError(
                f"No OpenAI-compatible API found at {where}. "
                "The endpoint appears to be running a different service "
                "(possibly Jupyter Server). Please check the port number."
            ) from exc
        if "Connection" in error_msg:
            raise ValueError(
                f"Cannot connect to {where}. "
                "Please verify the service is running and accessible."
            ) from exc
        raise ValueError(
            f"Failed to list models from {where}: {error_msg}"
        ) from exc
|
|
485
515
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
516
|
+
def parse(
    self,
    response_model: "Type[BaseModel]",
    instruction: Optional[str] = None,
    prompt: Optional[str] = None,
    messages: "Optional[RawMsgs]" = None,
    think: Literal[True, False, None] = None,
    add_json_schema_to_instruction: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    return_openai_response: bool = False,
    cache: Optional[bool] = True,
    **kwargs,
):
    """Run a chat completion whose output is constrained to ``response_model``.

    The model's JSON schema is sent as ``extra_body={"guided_json": ...}``,
    and the reply is converted via ``self._parse_complete_output``.

    Args:
        response_model: Pydantic model class describing the expected output.
        instruction: System prompt; required when ``messages`` is None.
        prompt: User prompt; required when ``messages`` is None.
        messages: Full chat history whose first entry must be the system
            message. The caller's list and message dicts are not modified.
        think: True appends ``/think`` to the system prompt, False appends
            ``/no_think``, None appends nothing.
        add_json_schema_to_instruction: Also embed the JSON schema in the
            system prompt inside ``<output_json_schema>`` tags.
        temperature: Sampling temperature; not sent when None.
        max_tokens: Completion token limit; not sent when None.
        return_openai_response: Return the raw completion object instead of
            the parsed model.
        cache: True/False force response caching on/off; None defers to
            ``self.do_cache``. NOTE(review): because the default is True, the
            instance-level cache setting is ignored unless None is passed.
        **kwargs: Extra keyword arguments forwarded to
            ``chat.completions.create`` (and included in the cache key).

    Returns:
        A ``response_model`` instance, or the raw completion when
        ``return_openai_response`` is True (on a cache hit, whatever
        ``self._load_cache`` returned).
    """
    if messages is None:
        assert instruction is not None, "Instruction must be provided."
        assert prompt is not None, "Prompt must be provided."
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": prompt},
        ]  # type: ignore

    json_schema = response_model.model_json_schema()

    # Suffix appended to the system message.
    post_fix = ""
    if add_json_schema_to_instruction:
        post_fix += (
            f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n"
            "</output_json_schema>"
        )
    if think:
        post_fix += "\n\n/think"
    elif think is not None:
        # An explicit False disables thinking; None leaves the prompt alone.
        post_fix += "\n\n/no_think"

    assert isinstance(messages, list), "Messages must be a list."
    assert len(messages) > 0, "Messages cannot be empty."
    assert (
        messages[0]["role"] == "system"
    ), "First message must be a system message with instruction."
    # Copy before appending: mutating the caller's system message in place
    # would make every reuse of the same `messages` accumulate another suffix.
    messages = [dict(messages[0]), *messages[1:]]
    messages[0]["content"] += post_fix  # type: ignore

    model_kwargs = {}
    if temperature is not None:
        model_kwargs["temperature"] = temperature
    if max_tokens is not None:
        model_kwargs["max_tokens"] = max_tokens
    model_kwargs.update(kwargs)

    use_cache = self.do_cache if cache is None else cache
    cache_key = None
    if use_cache:
        cache_data = {
            "messages": messages,
            "model_kwargs": model_kwargs,
            "guided_json": json_schema,
            "response_format": response_model.__name__,
        }
        cache_key = self._cache_key(cache_data, {}, response_model)
        cached_response = self._load_cache(cache_key)
        self.last_log = [prompt, messages, cached_response]
        if cached_response is not None:
            if return_openai_response:
                return cached_response
            return self._parse_complete_output(cached_response, response_model)

    completion = self.client.chat.completions.create(
        model=self.model,  # type: ignore
        messages=messages,  # type: ignore
        extra_body={"guided_json": json_schema},
        **model_kwargs,
    )

    if cache_key:
        self._dump_cache(cache_key, completion)

    self.last_log = [prompt, messages, completion]
    if return_openai_response:
        return completion
    return self._parse_complete_output(completion, response_model)
|
|
493
600
|
|
|
601
|
+
def _parse_complete_output(
    self, completion: Any, response_model: Type[BaseModel]
) -> BaseModel:
    """Validate the first choice of a chat completion as ``response_model``.

    ``completion`` may be an object exposing ``model_dump()`` (such as an
    OpenAI response) or an equivalent plain dict.

    Raises:
        ValueError: If there are no choices, the first message content is
            empty, or the content is not valid JSON for ``response_model``.
    """
    payload = (
        completion.model_dump() if hasattr(completion, "model_dump") else completion
    )

    choices = payload["choices"] if "choices" in payload else None
    if not choices:
        raise ValueError("No choices in OpenAI response")

    text = choices[0]["message"]["content"]
    if not text:
        raise ValueError("Empty content in response")

    try:
        return response_model.model_validate(json.loads(text))
    except Exception as exc:
        # Covers both malformed JSON and schema-validation failures.
        raise ValueError(
            f"Failed to parse response as {response_model.__name__}: {text}"
        ) from exc
|
|
622
|
+
|
|
623
|
+
def inspect_word_probs(
    self,
    messages: Optional[List[Dict[str, Any]]] = None,
    tokenizer: Optional[Any] = None,
    do_print=True,
    add_think: bool = True,
) -> tuple[List[Dict[str, Any]], Any, str]:
    """
    Inspect per-word probabilities of a conversation under this LM.

    Args:
        messages: Conversation to score. Defaults to the last logged
            conversation from ``self.last_messages(add_think=add_think)``.
        tokenizer: Tokenizer matching ``self.model``. Defaults to
            ``get_tokenizer(self.model)``.
        do_print: Print the colored rendering before returning.
        add_think: Passed to ``last_messages`` when ``messages`` is None.

    Returns:
        A tuple of (word probability dicts, token log-probability dicts,
        ANSI-colored rendering string).

    Raises:
        ValueError: If no messages were given and none were logged.
    """
    msgs = messages if messages is not None else self.last_messages(add_think=add_think)
    if msgs is None:
        raise ValueError("No messages provided and no last messages available.")

    tok = tokenizer if tokenizer is not None else get_tokenizer(self.model)

    # Inside the class body this name resolves to the module-level function.
    result = inspect_word_probs(self, tok, msgs)
    if do_print:
        print(result[-1])
    return result
|
|
655
|
+
|
|
656
|
+
def last_messages(self, add_think: bool = True) -> Optional[List[Dict[str, str]]]:
    """Rebuild the last logged conversation with the assistant reply appended.

    ``self.last_log`` is expected to be ``[prompt, messages, completion]``
    (the shape ``parse`` stores). The completion may be an object with
    ``model_dump()`` or an already-dumped dict; ``self.last_log`` is not
    modified.

    Args:
        add_think: When the reply carries ``reasoning_content``, wrap it in
            ``<think>...</think>`` before the answer. Otherwise, or when there
            is no reasoning, an empty ``<think>\\n\\n</think>`` block is used.

    Returns:
        The logged messages plus one assistant message, or None when there
        is no complete log entry.
    """
    log = self.last_log
    # Nothing usable has been logged yet (or the request never completed).
    if not log or len(log) < 3 or log[1] is None or log[2] is None:
        return None

    messages, completion = log[1], log[2]
    if hasattr(completion, "model_dump"):
        completion = completion.model_dump()

    message = completion["choices"][0]["message"]
    reasoning = message.get("reasoning_content")
    # A None content would otherwise be rendered as the literal text "None".
    answer = message.get("content") or ""

    if reasoning and add_think:
        final_answer = f"<think>{reasoning}</think>\n{answer}"
    else:
        final_answer = f"<think>\n\n</think>\n{answer}"

    return [*messages, {"role": "assistant", "content": final_answer}]
|
|
510
676
|
|
|
511
|
-
return regex
|
|
512
677
|
|
|
513
|
-
|
|
514
|
-
self,
|
|
515
|
-
response_format: type[BaseModel],
|
|
516
|
-
prompt: Optional[str] = None,
|
|
517
|
-
messages: Optional[RawMsgs] = None,
|
|
518
|
-
**kwargs,
|
|
519
|
-
):
|
|
678
|
+
from functools import lru_cache
|
|
520
679
|
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
pattern = re.compile(
|
|
546
|
-
r"<think>\n(?P<think>.*?)\n</think>\n\n(?P<json>\{.*\})",
|
|
547
|
-
re.DOTALL,
|
|
680
|
+
|
|
681
|
+
@lru_cache(maxsize=10)
def get_tokenizer(model_name: str) -> Any:
    """Return the Hugging Face tokenizer for ``model_name``.

    Results are memoised for up to 10 model names. ``transformers`` is
    imported lazily so the dependency is only needed when this is called.

    NOTE(review): ``trust_remote_code=True`` lets ``from_pretrained`` run
    code shipped with the model repository; only pass trusted model names.
    """
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def inspect_word_probs(lm, tokenizer, messages):
    """Score every word of a chat transcript by the model's likelihood.

    The conversation is rendered with ``tokenizer.apply_chat_template``. It is
    then sent to ``lm.client.completions.create`` with
    ``extra_body={"prompt_logprobs": 0}`` to obtain a log-probability for every
    prompt token. Tokens are regrouped into whitespace-separated words
    (newlines kept as their own entries), and each word gets
    ``exp(mean token logprob)``.

    Args:
        lm: Object exposing ``client`` (an OpenAI-style client) and ``model``.
        tokenizer: Tokenizer matching ``lm.model``; must provide
            ``apply_chat_template``, ``tokenize``, ``encode`` and
            ``convert_tokens_to_ids``.
        messages: Chat messages to score.

    Returns:
        ``(word_probs, token_logprob_dicts, rendered)``. Here ``word_probs`` is
        a list of ``{"word", "probability"}`` dicts, ``token_logprob_dicts``
        is the server's per-token list (first entry replaced by a
        placeholder), and ``rendered`` is an ANSI-colored string.

    Raises:
        ValueError: If the rendered prompt is empty, or the server and local
            token counts differ.
        AssertionError: If a server token differs from the local tokenization.
    """
    import math

    def compute_word_log_probs(tokenizer, lm_client):
        """Return ``(word_log_probs, token_logprob_dicts)`` for ``messages``."""
        # Raw text (not token ids) so literal newlines survive for splitting.
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        tokenized = tokenizer.tokenize(prompt)
        if not tokenized:
            raise ValueError("Chat template produced an empty prompt.")

        # Generate one token; only the prompt token log-probs are used.
        response = lm_client.client.completions.create(
            model=lm_client.model,  # type: ignore
            prompt=prompt,
            max_tokens=1,
            logprobs=1,
            extra_body={"prompt_logprobs": 0},
        )
        token_logprob_dicts = response.choices[0].prompt_logprobs  # type: ignore

        # The first prompt entry is overwritten with a placeholder (the server
        # presumably has no score for it). Use the prompt's real first token
        # instead of assuming a ChatML "<|im_start|>" marker, which breaks for
        # other chat templates and for tokenizers whose encode() adds a BOS.
        first_token = tokenized[0]
        token_logprob_dicts[0] = {
            str(tokenizer.convert_tokens_to_ids(first_token)): {
                "logprob": -1,
                "rank": 1,
                "decoded_token": first_token,
            }
        }

        # One entry per prompt token: "id" plus the server-provided fields.
        tokens = [
            {"id": int(tid), **tdata}
            for td in token_logprob_dicts
            for tid, tdata in td.items()
        ]

        # Word mapping below assumes server and local tokenization agree.
        if len(tokenized) != len(tokens):
            raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
        for idx, (local, remote) in enumerate(zip(tokenized, tokens)):
            if local != remote["decoded_token"]:
                raise AssertionError(
                    f"Token mismatch at {idx}: {local} != {remote['decoded_token']}"
                )

        # Whitespace split, with each newline kept as its own "<NL>" word.
        words = prompt.replace("\n", " <NL> ").split()

        word_log_probs = []
        token_idx = 0
        for word in words:
            target = "\n" if word == "<NL>" else word
            # Approximation: each word is encoded in isolation, which may not
            # match its in-context tokenization exactly.
            count = len(tokenizer.encode(target, add_special_tokens=False))
            if count == 0:
                continue
            subs = tokens[token_idx : token_idx + count]
            if not subs:
                # Counts drifted past the end of the token stream; stop instead
                # of reporting a fabricated probability of 1.0.
                break
            avg_logprob = sum(s["logprob"] for s in subs) / len(subs)
            word_log_probs.append(
                {"word": target, "probability": math.exp(avg_logprob)}
            )
            token_idx += count

        return word_log_probs, token_logprob_dicts

    def render_by_logprob(word_log_probs):
        """
        Return an ANSI-colored string for word probabilities (red → green).
        """
        if not word_log_probs:
            return ""

        probs = [entry["probability"] for entry in word_log_probs]
        min_p, max_p = min(probs), max(probs)
        parts = []

        for entry in word_log_probs:
            word = entry["word"]
            # Preserve actual line breaks
            if word == "\n":
                parts.append("\n")
                continue

            p = entry["probability"]
            norm = (p - min_p) / (max_p - min_p or 1.0)
            r = int(255 * (1 - norm))  # red component (high when prob is low)
            g = int(255 * norm)  # green component (high when prob is high)
            b = 0  # no blue for red-green gradient
            colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
            parts.append(colored + " ")

        return "".join(parts).rstrip()

    word_probs, token_logprob_dicts = compute_word_log_probs(tokenizer, lm)
    return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
class LLMTask(ABC):
    """
    Callable wrapper around an LM endpoint.

    Sub-classes must set:
        • lm          – the language-model instance
        • InputModel  – a Pydantic input class
        • OutputModel – a Pydantic output class

    Optional flags:
        • temperature     – float (default 0.6)
        • think           – bool  (if the backend supports “chain-of-thought”)
        • add_json_schema – bool  (include schema in the instruction)

    The **docstring** of each sub-class is sent verbatim as the LM
    instruction (a sub-class without a docstring sends an empty instruction,
    since ``__doc__`` is not inherited). Response caching follows the ``lm``
    instance's own cache setting.

    Example
    ```python
    class DemoTask(LLMTask):
        "TODO: SYSTEM_PROMPT_INSTRUCTION HERE"

        lm = LM(port=8130, cache=False, model="gpt-3.5-turbo")

        class InputModel(BaseModel):
            text_to_translate:str

        class OutputModel(BaseModel):
            translation:str
            glossary_use:str

        temperature = 0.6
        think=False

    demo_task = DemoTask()
    demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
    ```
    """

    # Language-model instance used for every request (set by sub-classes).
    lm: "LM"
    # Pydantic model that validates/serializes the call input.
    InputModel: Type[BaseModel]
    # Pydantic model the LM reply is parsed into.
    OutputModel: Type[BaseModel]

    temperature: float = 0.6
    think: bool = False
    add_json_schema: bool = False

    def __call__(self, data: BaseModel | dict) -> BaseModel:
        """Run the task on ``data`` and return the parsed ``OutputModel``.

        Args:
            data: A Pydantic model instance (sent as-is) or a dict used to
                build ``InputModel``.

        Raises:
            NotImplementedError: If ``lm``, ``InputModel`` or ``OutputModel``
                is not defined on the sub-class.
        """
        if (
            not hasattr(self, "InputModel")
            or not hasattr(self, "OutputModel")
            or not hasattr(self, "lm")
        ):
            raise NotImplementedError(
                f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
            )

        item = data if isinstance(data, BaseModel) else self.InputModel(**data)

        return self.lm.parse(
            prompt=item.model_dump_json(),
            instruction=self.__doc__ or "",
            response_model=self.OutputModel,
            temperature=self.temperature,
            think=self.think,
            add_json_schema_to_instruction=self.add_json_schema,
            # None makes parse() use the LM's own cache setting; its default
            # (True) would otherwise cache even for LM(..., cache=False).
            cache=None,
        )

    def generate_training_data(
        self, input_dict: Dict[str, Any], output: Dict[str, Any]
    ):
        """Return one training conversation in a ShareGPT-like format.

        The system message is the class docstring, the user message is
        ``InputModel(**input_dict)`` as JSON, and the assistant message is
        ``OutputModel(**output)`` as JSON.

        NOTE(review): ``get_conversation_one_turn`` is not defined or imported
        in the visible parts of this module — confirm it is in scope.
        """
        system_prompt = self.__doc__ or ""
        user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
        assistant_msg = self.OutputModel(**output).model_dump_json()  # type: ignore[attr-defined]
        return get_conversation_one_turn(
            system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
        )
|
llm_utils/lm/lm_json.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from functools import cache
|
|
4
|
+
from typing import *
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from llm_utils.lm.lm import LM, RawMsgs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LMJson(LM):
    """LM variant whose calls always request plain-text (``str``) output.

    Every call is forwarded to ``LM.__call__`` with ``response_format=str``.
    No regex or JSON post-processing is performed in this class.
    """

    def __init__(
        self,
        model: str | None = None,
        *,
        temperature: float = 0.0,
        max_tokens: int = 2_000,
        host: str = "localhost",
        port: Optional[int | str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        cache: bool = True,
        **openai_kwargs: Any,
    ) -> None:
        """
        Initialize the LMJson instance.

        All arguments are passed unchanged to ``LM.__init__``.

        Args:
            model (str | None): The model name to use.
            temperature (float): Sampling temperature.
            max_tokens (int): Maximum number of tokens to generate.
            host (str): Host for the API.
            port (int | str, optional): Port for the API.
            base_url (str, optional): Base URL for the API.
            api_key (str, optional): API key for authentication.
            cache (bool): Whether to cache responses.
            **openai_kwargs: Additional OpenAI parameters.
        """
        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            host=host,
            port=port,
            base_url=base_url,
            api_key=api_key,
            cache=cache,
            **openai_kwargs,
        )

    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[RawMsgs] = None,
        cache: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        return_openai_response: bool = False,
        **kwargs: Any,
    ):
        """Call the underlying LM with ``response_format=str``.

        Arguments are forwarded unchanged to ``LM.__call__``. The return value
        is whatever that method produces for a ``str`` response format
        (presumably the reply text, or the raw response when
        ``return_openai_response`` is True — confirm in ``LM.__call__``).
        """
        output = super().__call__(
            prompt=prompt,
            messages=messages,
            response_format=str,
            cache=cache,
            max_tokens=max_tokens,
            return_openai_response=return_openai_response,
            **kwargs,
        )
        return output
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# vLLM Server Examples
|
|
2
|
+
|
|
3
|
+
This directory contains scripts for working with vLLM servers.
|
|
4
|
+
|
|
5
|
+
## Files
|
|
6
|
+
|
|
7
|
+
- `serve_script.sh` - Script to start the vLLM server
|
|
8
|
+
- `example_vllm_client.py` - Example client for interacting with a vLLM server
|
|
9
|
+
- `requirements_example.txt` - Python dependencies for the example
|
|
10
|
+
|
|
11
|
+
## Usage
|
|
12
|
+
|
|
13
|
+
### 1. Start the vLLM Server
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
bash serve_script.sh
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
### 2. Install Dependencies
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
pip install -r requirements_example.txt
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
### 3. Run Examples
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
python example_vllm_client.py
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
## Features
|
|
32
|
+
|
|
33
|
+
The example client demonstrates:
|
|
34
|
+
|
|
35
|
+
- ✅ Basic text generation
|
|
36
|
+
- ✅ Batch processing
|
|
37
|
+
- ✅ Creative writing with high temperature
|
|
38
|
+
- ✅ Code generation with low temperature
|
|
39
|
+
- ✅ Proper error handling
|
|
40
|
+
- ✅ Health checks
|
|
41
|
+
- ✅ Readable console logging with loguru
|
|
42
|
+
- ✅ Type safety with Pydantic models
|
|
43
|
+
- ✅ Async/await patterns
|
|
44
|
+
|
|
45
|
+
## Configuration
|
|
46
|
+
|
|
47
|
+
The client connects to `http://localhost:8140` by default.
|
|
48
|
+
To use a different server, modify the `VLLMClient` initialization in `example_vllm_client.py`.
|