speedy-utils 1.0.14__py3-none-any.whl → 1.0.16__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
llm_utils/__init__.py CHANGED
@@ -1,20 +1,16 @@
  from .chat_format import (
-     transform_messages,
-     transform_messages_to_chatml,
-     show_chat,
-     get_conversation_one_turn,
-     show_string_diff,
-     display_conversations,
      build_chatml_input,
-     format_msgs,
      display_chat_messages_as_html,
+     display_conversations,
+     format_msgs,
+     get_conversation_one_turn,
+     show_chat,
+     show_string_diff,
+     transform_messages,
+     transform_messages_to_chatml,
  )
- from .lm.lm import LM, LMReasoner
+ from .lm.lm import LM, LLMTask
  from .lm.alm import AsyncLM
- from .group_messages import (
-     split_indices_by_length,
-     group_messages_by_len,
- )

  __all__ = [
      "transform_messages",
@@ -25,10 +21,9 @@ __all__ = [
      "display_conversations",
      "build_chatml_input",
      "format_msgs",
-     "split_indices_by_length",
-     "group_messages_by_len",
+     # "group_messages_by_len",
      "LM",
-     "LMReasoner",
      "AsyncLM",
      "display_chat_messages_as_html",
+     "LLMTask",
  ]
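
Taken together, these `__init__` changes swap `LMReasoner` for `LLMTask` at the package root and stop re-exporting the `group_messages` helpers. A minimal sketch of the resulting import surface follows; the submodule path in the comment is inferred from the removed relative import and is not otherwise confirmed by this diff:

```python
# Sketch of the 1.0.16 top-level imports implied by the __init__ diff above.
from llm_utils import LM, LLMTask, AsyncLM, show_chat

# The grouping helpers dropped from __all__ would now come from the submodule
# directly (path inferred from the removed "from .group_messages import ..."):
# from llm_utils.group_messages import group_messages_by_len, split_indices_by_length
```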
@@ -1,5 +1,5 @@
  from __future__ import annotations
- from typing import List, Tuple, Sequence, Any, Dict, Optional
+ from typing import Any, Optional
  from IPython.display import HTML, display
  from difflib import SequenceMatcher

@@ -1,6 +1,5 @@
  from __future__ import annotations
  from copy import deepcopy
- from typing import Callable, Dict, List, Sequence


  def identify_format(item):
@@ -114,7 +113,7 @@ def transform_messages(

  def transform_messages_to_chatml(input_data, input_format="auto"):
      if input_format == "auto":
-         input_data = raw_data = deepcopy(input_data)
+         input_data = deepcopy(input_data)
          if isinstance(input_data, list):
              input_format = "chatlm"
          assert (
@@ -76,7 +76,7 @@ def group_messages_by_len(
      """
      if messages is None:
          raise ValueError("messages parameter cannot be None")
-     from transformers.models.auto.tokenization_auto import AutoTokenizer
+     from transformers.models.auto.tokenization_auto import AutoTokenizer  # type: ignore

      tokenizer = AutoTokenizer.from_pretrained(model_name)

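The only change in the hunk above is a `# type: ignore` annotation, but it highlights the deferred-import pattern the function relies on: `transformers` is imported inside the function body, so it stays an optional dependency that is only loaded when token counting is actually needed. A minimal sketch of the same pattern; the function name and signature here are illustrative, not the package's API:

```python
# Illustrative only: lazy transformers import, mirroring group_messages_by_len.
def count_tokens(text: str, model_name: str = "gpt2") -> int:
    # Imported at call time so transformers remains an optional dependency.
    from transformers.models.auto.tokenization_auto import AutoTokenizer  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return len(tokenizer.encode(text))
```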
llm_utils/lm/alm.py CHANGED
@@ -1,5 +1,3 @@
- from __future__ import annotations
-
  """An **asynchronous** drop‑in replacement for the original `LM` class.

  Usage example (Python ≥3.8):
@@ -15,26 +13,30 @@ Usage example (Python ≥3.8):
      asyncio.run(main())
  """

- import asyncio
  import base64
  import hashlib
  import json
  import os
+ from abc import ABC
+ from functools import lru_cache
  from typing import (
      Any,
      Dict,
      List,
+     Literal,
      Optional,
      Sequence,
      Type,
      TypeVar,
      Union,
-     overload,
      cast,
+     overload,
  )

  from httpx import URL
- from openai import AsyncOpenAI, AuthenticationError, RateLimitError
+ from loguru import logger
+ from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
+ from openai.pagination import AsyncPage as AsyncSyncPage

  # from openai.pagination import AsyncSyncPage
  from openai.types.chat import (
@@ -44,11 +46,10 @@ from openai.types.chat import (
      ChatCompletionToolMessageParam,
      ChatCompletionUserMessageParam,
  )
- from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
  from openai.types.model import Model
  from pydantic import BaseModel
- from loguru import logger
- from openai.pagination import AsyncPage as AsyncSyncPage
+
+ from llm_utils.chat_format.display import get_conversation_one_turn

  # --------------------------------------------------------------------------- #
  # type helpers
@@ -67,10 +68,20 @@ def _color(code: int, text: str) -> str:
      return f"\x1b[{code}m{text}\x1b[0m"


- _red = lambda t: _color(31, t)
- _green = lambda t: _color(32, t)
- _blue = lambda t: _color(34, t)
- _yellow = lambda t: _color(33, t)
+ def _red(t):
+     return _color(31, t)
+
+
+ def _green(t):
+     return _color(32, t)
+
+
+ def _blue(t):
+     return _color(34, t)
+
+
+ def _yellow(t):
+     return _color(33, t)


  class AsyncLM:
@@ -100,6 +111,7 @@ class AsyncLM:
          self.openai_kwargs = openai_kwargs
          self.do_cache = cache
          self.ports = ports
+         self._init_port = port  # <-- store the port provided at init

          # Async client

@@ -108,6 +120,7 @@
          # if have multiple ports
          if self.ports:
              import random
+
              port = random.choice(self.ports)
              api_base = f"http://{self.host}:{port}/v1"
              logger.debug(f"Using port: {port}")
@@ -213,6 +226,13 @@ class AsyncLM:
              self._cache_key(messages, kw, response_format) if use_cache else None
          )
          if cache_key and (hit := self._load_cache(cache_key)) is not None:
+             # Check if cached value is an error
+             if isinstance(hit, dict) and hit.get("error"):
+                 error_type = hit.get("error_type", "Unknown")
+                 error_msg = hit.get("error_message", "Cached error")
+                 logger.warning(f"Found cached error ({error_type}): {error_msg}")
+                 # Re-raise as a ValueError with meaningful message
+                 raise ValueError(f"Cached {error_type}: {error_msg}")
              return hit

          try:
@@ -230,8 +250,21 @@ class AsyncLM:
                  **kw,
              )

-         except (AuthenticationError, RateLimitError) as exc:
-             logger.error(exc)
+         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+             logger.error(error_msg)
+
+             # Cache the error if it's a BadRequestError to avoid repeated calls
+             if isinstance(exc, BadRequestError) and cache_key:
+                 error_response = {
+                     "error": True,
+                     "error_type": "BadRequestError",
+                     "error_message": str(exc),
+                     "choices": [],
+                 }
+                 self._dump_cache(cache_key, error_response)
+                 logger.debug(f"Cached BadRequestError for key: {cache_key}")
+
              raise

          if cache_key:
@@ -354,10 +387,182 @@ class AsyncLM:
          except Exception:
              return None

+     # ------------------------------------------------------------------ #
+     # Missing methods from LM class
+     # ------------------------------------------------------------------ #
+     async def parse(
+         self,
+         response_model: Type[BaseModel],
+         instruction: Optional[str] = None,
+         prompt: Optional[str] = None,
+         messages: Optional[RawMsgs] = None,
+         think: Literal[True, False, None] = None,
+         add_json_schema_to_instruction: bool = False,
+         temperature: Optional[float] = None,
+         max_tokens: Optional[int] = None,
+         return_openai_response: bool = False,
+         cache: Optional[bool] = True,
+         **kwargs,
+     ):
+         """Parse response using guided JSON generation."""
+         if messages is None:
+             assert instruction is not None, "Instruction must be provided."
+             assert prompt is not None, "Prompt must be provided."
+             messages = [
+                 {
+                     "role": "system",
+                     "content": instruction,
+                 },
+                 {
+                     "role": "user",
+                     "content": prompt,
+                 },
+             ]  # type: ignore
+
+         post_fix = ""
+         json_schema = response_model.model_json_schema()
+         if add_json_schema_to_instruction and response_model:
+             _schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+             post_fix += _schema
+
+         if think:
+             post_fix += "\n\n/think"
+         elif not think:
+             post_fix += "\n\n/no_think"
+
+         assert isinstance(messages, list), "Messages must be a list."
+         assert len(messages) > 0, "Messages cannot be empty."
+         assert messages[0]["role"] == "system", (
+             "First message must be a system message with instruction."
+         )
+         messages[0]["content"] += post_fix  # type: ignore
+
+         model_kwargs = {}
+         if temperature is not None:
+             model_kwargs["temperature"] = temperature
+         if max_tokens is not None:
+             model_kwargs["max_tokens"] = max_tokens
+         model_kwargs.update(kwargs)
+
+         use_cache = self.do_cache if cache is None else cache
+         cache_key = None
+         if use_cache:
+             cache_data = {
+                 "messages": messages,
+                 "model_kwargs": model_kwargs,
+                 "guided_json": json_schema,
+                 "response_format": response_model.__name__,
+             }
+             cache_key = self._cache_key(cache_data, {}, response_model)
+             cached_response = self._load_cache(cache_key)
+             self.last_log = [prompt, messages, cached_response]
+             if cached_response is not None:
+                 if return_openai_response:
+                     return cached_response
+                 return self._parse_complete_output(cached_response, response_model)
+
+         completion = await self.client.chat.completions.create(
+             model=self.model,  # type: ignore
+             messages=messages,  # type: ignore
+             extra_body={"guided_json": json_schema},
+             **model_kwargs,
+         )
+
+         if cache_key:
+             self._dump_cache(cache_key, completion)
+
+         self.last_log = [prompt, messages, completion]
+         if return_openai_response:
+             return completion
+         return self._parse_complete_output(completion, response_model)
+
+     def _parse_complete_output(
+         self, completion: Any, response_model: Type[BaseModel]
+     ) -> BaseModel:
+         """Parse completion output to response model."""
+         if hasattr(completion, "model_dump"):
+             completion = completion.model_dump()
+
+         if "choices" not in completion or not completion["choices"]:
+             raise ValueError("No choices in OpenAI response")
+
+         content = completion["choices"][0]["message"]["content"]
+         if not content:
+             raise ValueError("Empty content in response")
+
+         try:
+             data = json.loads(content)
+             return response_model.model_validate(data)
+         except Exception as exc:
+             raise ValueError(
+                 f"Failed to parse response as {response_model.__name__}: {content}"
+             ) from exc
+
+     async def inspect_word_probs(
+         self,
+         messages: Optional[List[Dict[str, Any]]] = None,
+         tokenizer: Optional[Any] = None,
+         do_print=True,
+         add_think: bool = True,
+     ) -> tuple[List[Dict[str, Any]], Any, str]:
+         """
+         Inspect word probabilities in a language model response.
+
+         Args:
+             tokenizer: Tokenizer instance to encode words.
+             messages: List of messages to analyze.
+
+         Returns:
+             A tuple containing:
+             - List of word probabilities with their log probabilities.
+             - Token log probability dictionaries.
+             - Rendered string with colored word probabilities.
+         """
+         if messages is None:
+             messages = await self.last_messages(add_think=add_think)
+             if messages is None:
+                 raise ValueError("No messages provided and no last messages available.")
+
+         if tokenizer is None:
+             tokenizer = get_tokenizer(self.model)
+
+         ret = await inspect_word_probs_async(self, tokenizer, messages)
+         if do_print:
+             print(ret[-1])
+         return ret
+
+     async def last_messages(
+         self, add_think: bool = True
+     ) -> Optional[List[Dict[str, str]]]:
+         """Get the last conversation messages including assistant response."""
+         if not hasattr(self, "last_log"):
+             return None
+
+         last_conv = self.last_log
+         messages = last_conv[1] if len(last_conv) > 1 else None
+         last_msg = last_conv[2]
+         if not isinstance(last_msg, dict):
+             last_conv[2] = last_conv[2].model_dump()  # type: ignore
+         msg = last_conv[2]
+         # Ensure msg is a dict
+         if hasattr(msg, "model_dump"):
+             msg = msg.model_dump()
+         message = msg["choices"][0]["message"]
+         reasoning = message.get("reasoning_content")
+         answer = message.get("content")
+         if reasoning and add_think:
+             final_answer = f"<think>{reasoning}</think>\n{answer}"
+         else:
+             final_answer = f"<think>\n\n</think>\n{answer}"
+         assistant = {"role": "assistant", "content": final_answer}
+         messages = messages + [assistant]  # type: ignore
+         return messages if messages else None
+
      # ------------------------------------------------------------------ #
      # Utility helpers
      # ------------------------------------------------------------------ #
      async def inspect_history(self) -> None:
+         """Inspect the conversation history with proper formatting."""
          if not hasattr(self, "last_log"):
              raise ValueError("No history available. Please call the model first.")

@@ -445,3 +650,210 @@ class AsyncLM:
          except Exception as exc:
              logger.error(f"Failed to list models: {exc}")
              return []
+
+
+ # --------------------------------------------------------------------------- #
+ # Module-level utility functions (async versions)
+ # --------------------------------------------------------------------------- #
+
+
+ @lru_cache(maxsize=10)
+ def get_tokenizer(model_name: str) -> Any:
+     """Get tokenizer for the given model."""
+     from transformers import AutoTokenizer  # type: ignore
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     return tokenizer
+
+
+ async def inspect_word_probs_async(lm, tokenizer, messages):
+     """Async version of inspect_word_probs."""
+
+     import numpy as np
+
+     async def compute_word_log_probs(
+         tokenizer: Any,
+         lm_client: Any,
+     ) -> tuple[List[Dict[str, Any]], Any]:
+         # Build a prompt that preserves literal newlines
+         prompt = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,  # Don't tokenize yet, we need raw text
+             add_generation_prompt=False,  # No generation prompt needed
+         )
+
+         # Request token logprobs
+         response = await lm_client.client.completions.create(
+             model=lm_client.model,  # type: ignore
+             prompt=prompt,
+             max_tokens=1,
+             logprobs=1,
+             extra_body={"prompt_logprobs": 0},
+         )
+         token_logprob_dicts = response.choices[0].prompt_logprobs  # type: ignore
+
+         # Override first token to known start marker
+         start_id = tokenizer.encode("<|im_start|>")[0]
+         token_logprob_dicts[0] = {
+             str(start_id): {
+                 "logprob": -1,
+                 "rank": 1,
+                 "decoded_token": "<|im_start|>",
+             }
+         }
+
+         # Flatten tokens
+         tokens: List[Dict[str, Any]] = [
+             {"id": int(tid), **tdata}
+             for td in token_logprob_dicts
+             for tid, tdata in td.items()
+         ]
+
+         # Validate tokenization
+         tokenized = tokenizer.tokenize(prompt)
+         if len(tokenized) != len(tokens):
+             raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
+         for idx, tok in enumerate(tokens):
+             if tokenized[idx] != tok["decoded_token"]:
+                 raise AssertionError(
+                     f"Token mismatch at {idx}: "
+                     f"{tokenized[idx]} != {tok['decoded_token']}"
+                 )
+
+         # Split on newline sentinel
+         split_prompt = prompt.replace("\n", " <NL> ")
+         words = split_prompt.split()
+
+         word_log_probs: List[Dict[str, Any]] = []
+         token_idx = 0
+
+         for word in words:
+             # Map sentinel back to actual newline for encoding
+             target = "\n" if word == "<NL>" else word
+             sub_ids = tokenizer.encode(target, add_special_tokens=False)
+             count = len(sub_ids)
+             if count == 0:
+                 continue
+
+             subs = tokens[token_idx : token_idx + count]
+             avg_logprob = sum(s["logprob"] for s in subs) / count
+             prob = float(np.exp(avg_logprob))
+             word_log_probs.append({"word": target, "probability": prob})
+             token_idx += count
+
+         return word_log_probs, token_logprob_dicts  # type: ignore
+
+     def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
+         """
+         Return an ANSI-colored string for word probabilities (red → green).
+         """
+         if not word_log_probs:
+             return ""
+
+         probs = [entry["probability"] for entry in word_log_probs]
+         min_p, max_p = min(probs), max(probs)
+         parts: List[str] = []
+
+         for entry in word_log_probs:
+             word = entry["word"]
+             # Preserve actual line breaks
+             if word == "\n":
+                 parts.append("\n")
+                 continue
+
+             p = entry["probability"]
+             norm = (p - min_p) / (max_p - min_p or 1.0)
+             r = int(255 * (1 - norm))  # red component (high when prob is low)
+             g = int(255 * norm)  # green component (high when prob is high)
+             b = 0  # no blue for red-green gradient
+             colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
+             parts.append(colored + " ")
+
+         return "".join(parts).rstrip()
+
+     word_probs, token_logprob_dicts = await compute_word_log_probs(tokenizer, lm)
+     return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
+
+
+ # --------------------------------------------------------------------------- #
+ # Async LLMTask class
+ # --------------------------------------------------------------------------- #
+
+
+ class AsyncLLMTask(ABC):
+     """
+     Async callable wrapper around an AsyncLM endpoint.
+
+     Sub-classes must set:
+         • lm – the async language-model instance
+         • InputModel – a Pydantic input class
+         • OutputModel – a Pydantic output class
+
+     Optional flags:
+         • temperature – float (default 0.6)
+         • think – bool (if the backend supports "chain-of-thought")
+         • add_json_schema – bool (include schema in the instruction)
+
+     The **docstring** of each sub-class is sent as the LM instruction.
+     Example
+     ```python
+     class DemoTask(AsyncLLMTask):
+         "TODO: SYSTEM_PROMPT_INSTURCTION HERE"
+
+         lm = AsyncLM(port=8130, cache=False, model="gpt-3.5-turbo")
+
+         class InputModel(BaseModel):
+             text_to_translate:str
+
+         class OutputModel(BaseModel):
+             translation:str
+             glossary_use:str
+
+         temperature = 0.6
+         think=False
+
+     demo_task = DemoTask()
+     result = await demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
+     ```
+     """
+
+     lm: "AsyncLM"
+     InputModel: Type[BaseModel]
+     OutputModel: Type[BaseModel]
+
+     temperature: float = 0.6
+     think: bool = False
+     add_json_schema: bool = False
+
+     async def __call__(self, data: BaseModel | dict) -> BaseModel:
+         if (
+             not hasattr(self, "InputModel")
+             or not hasattr(self, "OutputModel")
+             or not hasattr(self, "lm")
+         ):
+             raise NotImplementedError(
+                 f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
+             )
+
+         item = data if isinstance(data, BaseModel) else self.InputModel(**data)
+
+         return await self.lm.parse(
+             prompt=item.model_dump_json(),
+             instruction=self.__doc__ or "",
+             response_model=self.OutputModel,
+             temperature=self.temperature,
+             think=self.think,
+             add_json_schema_to_instruction=self.add_json_schema,
+         )
+
+     def generate_training_data(
+         self, input_dict: Dict[str, Any], output: Dict[str, Any]
+     ):
+         """Return share gpt like format"""
+         system_prompt = self.__doc__ or ""
+         user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
+         assistant_msg = self.OutputModel(**output).model_dump_json()  # type: ignore[attr-defined]
+         messages = get_conversation_one_turn(
+             system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
+         )
+         return {"messages": messages}
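
To make the new surface concrete, here is a minimal, hypothetical usage sketch of the `AsyncLLMTask` helper added above. The port and model name are placeholders, only the attributes the diff shows (`lm`, `InputModel`, `OutputModel`, `temperature`, `think`) are relied on, and the import path follows the file header (`llm_utils/lm/alm.py`).

```python
# Hypothetical sketch of the AsyncLLMTask API added in this diff.
import asyncio
from typing import List

from pydantic import BaseModel

from llm_utils.lm.alm import AsyncLLMTask, AsyncLM


class SummarizeTask(AsyncLLMTask):
    """Summarize the given text into a short title and bullet points."""

    # Placeholder endpoint/model; the docstring above becomes the system prompt.
    lm = AsyncLM(port=8130, cache=True, model="Qwen/Qwen2.5-7B-Instruct")

    class InputModel(BaseModel):
        text: str

    class OutputModel(BaseModel):
        title: str
        bullet_points: List[str]

    temperature = 0.2
    think = False


async def main() -> None:
    task = SummarizeTask()
    # __call__ serializes the input model, sends the docstring as the instruction,
    # and returns an OutputModel parsed via AsyncLM.parse (guided JSON).
    result = await task({"text": "speedy-utils 1.0.16 reworks the async LM client."})
    print(result)


asyncio.run(main())
```

Note that, per the hunks above, plain chat-completion calls now also cache `BadRequestError`s and replay them as `ValueError`s on subsequent cache hits.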