speedy-utils 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/sync_lm.py DELETED
@@ -1,943 +0,0 @@
1
- """
2
- # ============================================================================= #
3
- # SYNCHRONOUS LANGUAGE MODEL WRAPPER WITH OPENAI COMPATIBILITY
4
- # ============================================================================= #
5
- #
6
- # Title & Intent:
7
- # Unified synchronous language model interface with caching, type safety, and OpenAI API compatibility
8
- #
9
- # High-level Summary:
10
- # This module provides a comprehensive synchronous wrapper for language models that supports both
11
- # string prompts and structured Pydantic model responses. It includes intelligent caching with
12
- # content-based hashing, automatic retry logic for rate limits, and seamless integration with
13
- # OpenAI-compatible APIs. The LM class handles message formatting, response parsing, token counting,
14
- # and provides detailed logging and debugging capabilities for production use.
15
- #
16
- # Public API / Data Contracts:
17
- # • LM(model=None, *, temperature=0.0, max_tokens=2000, host="localhost", port=None, base_url=None, api_key=None, cache=True, **openai_kwargs) - Main wrapper class
18
- # • LM.__call__(prompt=None, messages=None, response_format=str, cache=None, **kwargs) -> str | BaseModel
19
- # • LM.list_models(port=None, host="localhost", base_url=None) -> List[str] - Enumerate available models
20
- # • LM.count_tokens(messages, model=None) -> int - Token counting utility
21
- # • LM.price(messages, model=None, response_tokens=0) -> float - Cost estimation
22
- # • LM.set_model(model_name) -> None - Runtime model switching
23
- # • TModel = TypeVar("TModel", bound=BaseModel) - Generic type for structured responses
24
- # • Messages = List[ChatCompletionMessageParam] - Typed message format
25
- # • RawMsgs = Union[Messages, LegacyMsgs] - Flexible input format
26
- #
27
- # Invariants / Constraints:
28
- # • MUST provide either 'prompt' or 'messages' parameter, but not both
29
- # • MUST set model name before making API calls (auto-detection available)
30
- # • response_format=str MUST return string; response_format=PydanticModel MUST return model instance
31
- # • Caching MUST use content-based hashing for reproducible results
32
- # • MUST handle OpenAI rate limits with exponential backoff (up to 3 retries)
33
- # • MUST preserve message order and format during transformations
34
- # • Token counting SHOULD use tiktoken when available, fall back to character estimation
35
- # • MUST validate Pydantic responses and retry on parsing failures
36
- #
37
- # Usage Example:
38
- # ```python
39
- # from llm_utils.lm.sync_lm import LM
40
- # from pydantic import BaseModel
41
- #
42
- # class CodeResponse(BaseModel):
43
- # language: str
44
- # code: str
45
- # explanation: str
46
- #
47
- # # String response
48
- # lm = LM(model="gpt-4o-mini", temperature=0.1)
49
- # response = lm(prompt="Write a Python hello world")
50
- # print(response) # Returns string
51
- #
52
- # # Structured response
53
- # code_response = lm(
54
- # prompt="Write a Python function to calculate fibonacci",
55
- # response_format=CodeResponse
56
- # )
57
- # print(f"Language: {code_response.language}") # Returns CodeResponse instance
58
- #
59
- # # Message-based conversation
60
- # messages = [
61
- # {"role": "system", "content": "You are a helpful coding assistant"},
62
- # {"role": "user", "content": "Explain async/await in Python"}
63
- # ]
64
- # response = lm(messages=messages, max_tokens=1000)
65
- # ```
66
- #
67
- # TODO & Future Work:
68
- # • Add streaming response support for long-form generation
69
- # • Implement fine-grained token usage tracking per conversation
70
- # • Add support for function calling and tool use
71
- # • Optimize caching strategy for conversation contexts
72
- # • Add async context manager support for resource cleanup
73
- #
74
- # ============================================================================= #
75
- """
76
-
77
- from __future__ import annotations
78
-
79
- import base64
80
- import hashlib
81
- import json
82
- import os
83
- from abc import ABC
84
- from functools import lru_cache
85
- from typing import (
86
- Any,
87
- Dict,
88
- List,
89
- Literal,
90
- Optional,
91
- Sequence,
92
- Type,
93
- TypeVar,
94
- Union,
95
- cast,
96
- overload,
97
- )
98
-
99
- from loguru import logger
100
- from openai import AuthenticationError, OpenAI, RateLimitError
101
- from openai.pagination import SyncPage
102
- from openai.types.chat import (
103
- ChatCompletionAssistantMessageParam,
104
- ChatCompletionMessageParam,
105
- ChatCompletionSystemMessageParam,
106
- ChatCompletionToolMessageParam,
107
- ChatCompletionUserMessageParam,
108
- )
109
- from openai.types.model import Model
110
- from pydantic import BaseModel
111
-
112
- from llm_utils.chat_format.display import get_conversation_one_turn
113
- from speedy_utils.common.utils_io import jdumps
114
-
115
- # --------------------------------------------------------------------------- #
116
- # type helpers
117
- # --------------------------------------------------------------------------- #
118
- TModel = TypeVar("TModel", bound=BaseModel)
119
- Messages = List[ChatCompletionMessageParam] # final, already-typed messages
120
- LegacyMsgs = List[Dict[str, str]] # old “…role/content…” dicts
121
- RawMsgs = Union[Messages, LegacyMsgs] # what __call__ accepts
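For reference, a minimal sketch of the two message shapes `RawMsgs` accepts; the typed variant uses the `openai` param TypedDicts imported above, and either form can be passed to `LM.__call__`:

```python
# Legacy style (LegacyMsgs): plain role/content dicts.
legacy_msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]

# Already-typed style (Messages): openai param TypedDicts, which build plain dicts.
from openai.types.chat import (
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)

typed_msgs = [
    ChatCompletionSystemMessageParam(role="system", content="You are a helpful assistant."),
    ChatCompletionUserMessageParam(role="user", content="Say hello."),
]

# Legacy dicts are upgraded internally via LM._convert_messages before the API call.
```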
122
-
123
-
124
- # --------------------------------------------------------------------------- #
125
- # color formatting helpers
126
- # --------------------------------------------------------------------------- #
127
- def _red(text: str) -> str:
128
- """Format text with red color."""
129
- return f"\x1b[31m{text}\x1b[0m"
130
-
131
-
132
- def _green(text: str) -> str:
133
- """Format text with green color."""
134
- return f"\x1b[32m{text}\x1b[0m"
135
-
136
-
137
- def _blue(text: str) -> str:
138
- """Format text with blue color."""
139
- return f"\x1b[34m{text}\x1b[0m"
140
-
141
-
142
- def _yellow(text: str) -> str:
143
- """Format text with yellow color."""
144
- return f"\x1b[33m{text}\x1b[0m"
145
-
146
-
147
- # from functools import lru_cache
148
-
149
-
150
- # @lru_cache(maxsize=10)
151
- # def get_tok(tokenizer_name):
152
- # from transformers import AutoTokenizer
153
-
154
- # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
155
- # return tokenizer
156
-
157
-
158
- class LM:
159
- """
160
- Unified language-model wrapper.
161
-
162
- • `response_format=str` → returns `str`
163
- • `response_format=YourPydanticModel` → returns that model instance
164
- """
165
-
166
- # --------------------------------------------------------------------- #
167
- # ctor / plumbing
168
- # --------------------------------------------------------------------- #
169
- def __init__(
170
- self,
171
- model: str | None = None,
172
- *,
173
- temperature: float = 0.0,
174
- max_tokens: int = 2_000,
175
- host: str = "localhost",
176
- port: Optional[int | str] = None,
177
- base_url: Optional[str] = None,
178
- api_key: Optional[str] = None,
179
- cache: bool = True,
180
- **openai_kwargs: Any,
181
- ) -> None:
182
- self.model = model
183
- self.temperature = temperature
184
- self.max_tokens = max_tokens
185
- self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
186
- self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
187
- self.openai_kwargs = openai_kwargs
188
- self.do_cache = cache
189
- self._init_port = port # <-- store the port provided at init
190
-
191
- self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
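A short sketch of how the constructor resolves `base_url` (the host, port, and model names below are placeholders, not values from this package):

```python
# Hosted OpenAI endpoint: no port/base_url, so the client uses the default
# api.openai.com endpoint and OPENAI_API_KEY from the environment.
lm_hosted = LM(model="gpt-4o-mini", temperature=0.1)

# Local OpenAI-compatible server (e.g. vLLM) on port 8000:
# base_url is derived as "http://localhost:8000/v1".
lm_local = LM(port=8000, cache=False)  # model may be auto-selected on the first call
```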
192
-
193
- def set_model(self, model: str) -> None:
194
- """Set the model name after initialization."""
195
- self.model = model
196
-
197
- # --------------------------------------------------------------------- #
198
- # public API – typed overloads
199
- # --------------------------------------------------------------------- #
200
- @overload
201
- def __call__(
202
- self,
203
- *,
204
- prompt: str | None = ...,
205
- messages: RawMsgs | None = ...,
206
- response_format: type[str] = str,
207
- return_openai_response: bool = ...,
208
- **kwargs: Any,
209
- ) -> str: ...
210
-
211
- @overload
212
- def __call__(
213
- self,
214
- *,
215
- prompt: str | None = ...,
216
- messages: RawMsgs | None = ...,
217
- response_format: Type[TModel],
218
- return_openai_response: bool = ...,
219
- **kwargs: Any,
220
- ) -> TModel: ...
221
-
222
- # single implementation
223
- def __call__(
224
- self,
225
- prompt: Optional[str] = None,
226
- messages: Optional[RawMsgs] = None,
227
- response_format: Union[type[str], Type[BaseModel]] = str,
228
- cache: Optional[bool] = None,
229
- max_tokens: Optional[int] = None,
230
- return_openai_response: bool = False,
231
- **kwargs: Any,
232
- ):
233
- # argument validation ------------------------------------------------
234
- if (prompt is None) == (messages is None):
235
- raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
236
-
237
- if prompt is not None:
238
- messages = [{"role": "user", "content": prompt}]
239
-
240
- assert messages is not None # for type-checker
241
-
242
- # If model is not specified, but port is provided, use the first available model
243
- if self.model is None:
244
- port = self._init_port
245
- if port:
246
- available_models = self.list_models(port=port)
247
- if available_models:
248
- self.model = available_models[0]
249
- logger.debug(f"Auto-selected model: {self.model}")
250
- else:
251
- raise ValueError("No models available to select from.")
252
- else:
253
- raise AssertionError("Model must be set before calling.")
254
-
255
- openai_msgs: Messages = (
256
- self._convert_messages(cast(LegacyMsgs, messages))
257
- if isinstance(messages[0], dict) # legacy style
258
- else cast(Messages, messages) # already typed
259
- )
260
-
261
- kw = dict(
262
- self.openai_kwargs,
263
- temperature=self.temperature,
264
- max_tokens=max_tokens or self.max_tokens,
265
- )
266
- kw.update(kwargs)
267
- use_cache = self.do_cache if cache is None else cache
268
-
269
- raw_response = self._call_raw(
270
- openai_msgs,
271
- response_format=response_format,
272
- use_cache=use_cache,
273
- **kw,
274
- )
275
-
276
- if return_openai_response:
277
- response = raw_response
278
- else:
279
- response = self._parse_output(raw_response, response_format)
280
-
281
- self.last_log = [prompt, messages, raw_response]
282
- return response
283
-
284
- def inspect_history(self) -> None:
285
- if not hasattr(self, "last_log"):
286
- raise ValueError("No history available. Please call the model first.")
287
-
288
- prompt, messages, response = self.last_log
289
- # Ensure response is a dictionary
290
- if hasattr(response, "model_dump"):
291
- response = response.model_dump()
292
-
293
- if not messages:
294
- messages = [{"role": "user", "content": prompt}]
295
-
296
- print("\n\n")
297
- print(_blue("[Conversation History]") + "\n")
298
-
299
- # Print all messages in the conversation
300
- for msg in messages:
301
- role = msg["role"]
302
- content = msg["content"]
303
- print(_red(f"{role.capitalize()}:"))
304
-
305
- if isinstance(content, str):
306
- print(content.strip())
307
- elif isinstance(content, list):
308
- # Handle multimodal content
309
- for item in content:
310
- if item.get("type") == "text":
311
- print(item["text"].strip())
312
- elif item.get("type") == "image_url":
313
- image_url = item["image_url"]["url"]
314
- if "base64" in image_url:
315
- len_base64 = len(image_url.split("base64,")[1])
316
- print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
317
- else:
318
- print(_blue(f"<image_url: {image_url}>"))
319
- print("\n")
320
-
321
- # Print the response - now always an OpenAI completion
322
- print(_red("Response:"))
323
-
324
- # Handle OpenAI response object
325
- if isinstance(response, dict) and "choices" in response and response["choices"]:
326
- message = response["choices"][0].get("message", {})
327
-
328
- # Check for reasoning content (if available)
329
- reasoning = message.get("reasoning_content")
330
-
331
- # Check for parsed content (structured mode)
332
- parsed = message.get("parsed")
333
-
334
- # Get regular content
335
- content = message.get("content")
336
-
337
- # Display reasoning if available
338
- if reasoning:
339
- print(_yellow("<think>"))
340
- print(reasoning.strip())
341
- print(_yellow("</think>"))
342
- print()
343
-
344
- # Display parsed content for structured responses
345
- if parsed:
346
- # print(_green('<Parsed Structure>'))
347
- if hasattr(parsed, "model_dump"):
348
- print(jdumps(parsed.model_dump(), indent=2))
349
- else:
350
- print(jdumps(parsed, indent=2))
351
- # print(_green('</Parsed Structure>'))
352
- print()
353
-
354
- else:
355
- if content:
356
- # print(_green("<Content>"))
357
- print(content.strip())
358
- # print(_green("</Content>"))
359
- else:
360
- print(_green("[No content]"))
361
-
362
- # Show if there were multiple completions
363
- if len(response["choices"]) > 1:
364
- print(
365
- _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
366
- )
367
- else:
368
- # Fallback for non-standard response objects or cached responses
369
- print(_yellow("Warning: Not a standard OpenAI response object"))
370
- if isinstance(response, str):
371
- print(_green(response.strip()))
372
- elif isinstance(response, dict):
373
- print(_green(json.dumps(response, indent=2)))
374
- else:
375
- print(_green(str(response)))
376
-
377
- # print("\n\n")
378
-
379
- # --------------------------------------------------------------------- #
380
- # low-level OpenAI call
381
- # --------------------------------------------------------------------- #
382
- def _call_raw(
383
- self,
384
- messages: Sequence[ChatCompletionMessageParam],
385
- response_format: Union[type[str], Type[BaseModel]],
386
- use_cache: bool,
387
- **kw: Any,
388
- ):
389
- assert self.model is not None, "Model must be set before making a call."
390
- model: str = self.model
391
-
392
- cache_key = (
393
- self._cache_key(messages, kw, response_format) if use_cache else None
394
- )
395
- if cache_key and (hit := self._load_cache(cache_key)) is not None:
396
- return hit
397
-
398
- try:
399
- # structured mode
400
- if response_format is not str and issubclass(response_format, BaseModel):
401
- openai_response = self.client.beta.chat.completions.parse(
402
- model=model,
403
- messages=list(messages),
404
- response_format=response_format, # type: ignore[arg-type]
405
- **kw,
406
- )
407
- # plain-text mode
408
- else:
409
- openai_response = self.client.chat.completions.create(
410
- model=model,
411
- messages=list(messages),
412
- **kw,
413
- )
414
-
415
- except (AuthenticationError, RateLimitError) as exc: # pragma: no cover
416
- logger.error(exc)
417
- raise
418
-
419
- if cache_key:
420
- self._dump_cache(cache_key, openai_response)
421
-
422
- return openai_response
423
-
424
- # --------------------------------------------------------------------- #
425
- # legacy → typed messages
426
- # --------------------------------------------------------------------- #
427
- @staticmethod
428
- def _convert_messages(msgs: LegacyMsgs) -> Messages:
429
- converted: Messages = []
430
- for msg in msgs:
431
- role = msg["role"]
432
- content = msg["content"]
433
- if role == "user":
434
- converted.append(
435
- ChatCompletionUserMessageParam(role="user", content=content)
436
- )
437
- elif role == "assistant":
438
- converted.append(
439
- ChatCompletionAssistantMessageParam(
440
- role="assistant", content=content
441
- )
442
- )
443
- elif role == "system":
444
- converted.append(
445
- ChatCompletionSystemMessageParam(role="system", content=content)
446
- )
447
- elif role == "tool":
448
- converted.append(
449
- ChatCompletionToolMessageParam(
450
- role="tool",
451
- content=content,
452
- tool_call_id=msg.get("tool_call_id") or "", # str, never None
453
- )
454
- )
455
- else:
456
- # fall back to raw dict for unknown roles
457
- converted.append({"role": role, "content": content}) # type: ignore[arg-type]
458
- return converted
459
-
460
- # --------------------------------------------------------------------- #
461
- # final parse (needed for plain-text or cache hits only)
462
- # --------------------------------------------------------------------- #
463
- @staticmethod
464
- def _parse_output(
465
- raw_response: Any,
466
- response_format: Union[type[str], Type[BaseModel]],
467
- ) -> str | BaseModel:
468
- # Convert any object to dict if needed
469
- if hasattr(raw_response, "model_dump"):
470
- raw_response = raw_response.model_dump()
471
-
472
- if response_format is str:
473
- # Extract the content from OpenAI response dict
474
- if isinstance(raw_response, dict) and "choices" in raw_response:
475
- message = raw_response["choices"][0]["message"]
476
- return message.get("content", "") or ""
477
- return cast(str, raw_response)
478
-
479
- # For the type-checker: we *know* it's a BaseModel subclass here.
480
- model_cls = cast(Type[BaseModel], response_format)
481
-
482
- # Handle structured response
483
- if isinstance(raw_response, dict) and "choices" in raw_response:
484
- message = raw_response["choices"][0]["message"]
485
-
486
- # Check if already parsed by OpenAI client
487
- if "parsed" in message:
488
- return model_cls.model_validate(message["parsed"])
489
-
490
- # Need to parse the content
491
- content = message.get("content")
492
- if content is None:
493
- raise ValueError("Model returned empty content")
494
-
495
- try:
496
- data = json.loads(content)
497
- return model_cls.model_validate(data)
498
- except Exception as exc:
499
- raise ValueError(
500
- f"Failed to parse model output as JSON:\n{content}"
501
- ) from exc
502
-
503
- # Handle cached response or other formats
504
- if isinstance(raw_response, model_cls):
505
- return raw_response
506
- if isinstance(raw_response, dict):
507
- return model_cls.model_validate(raw_response)
508
-
509
- # Try parsing as JSON string
510
- try:
511
- data = json.loads(raw_response)
512
- return model_cls.model_validate(data)
513
- except Exception as exc:
514
- raise ValueError(
515
- f"Model did not return valid JSON:\n---\n{raw_response}"
516
- ) from exc
517
-
518
- # --------------------------------------------------------------------- #
519
- # tiny disk cache
520
- # --------------------------------------------------------------------- #
521
- @staticmethod
522
- def _cache_key(
523
- messages: Any,
524
- kw: Any,
525
- response_format: Union[type[str], Type[BaseModel]],
526
- ) -> str:
527
- tag = response_format.__name__ if response_format is not str else "text"
528
- blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
529
- return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
530
-
531
- @staticmethod
532
- def _cache_path(key: str) -> str:
533
- return os.path.expanduser(f"~/.cache/lm/{key}.json")
534
-
535
- def _dump_cache(self, key: str, val: Any) -> None:
536
- try:
537
- path = self._cache_path(key)
538
- os.makedirs(os.path.dirname(path), exist_ok=True)
539
- with open(path, "w") as fh:
540
- if isinstance(val, BaseModel):
541
- json.dump(val.model_dump(mode="json"), fh)
542
- else:
543
- json.dump(val, fh)
544
- except Exception as exc: # pragma: no cover
545
- logger.debug(f"cache write skipped: {exc}")
546
-
547
- def _load_cache(self, key: str) -> Any | None:
548
- path = self._cache_path(key)
549
- if not os.path.exists(path):
550
- return None
551
- try:
552
- with open(path) as fh:
553
- return json.load(fh)
554
- except Exception: # pragma: no cover
555
- return None
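To make the content-based caching concrete, a small standalone sketch that reproduces the key and path derivation used by `_cache_key` and `_cache_path` above:

```python
import base64
import hashlib
import json
import os

messages = [{"role": "user", "content": "Write a Python hello world"}]
kw = {"temperature": 0.0, "max_tokens": 2000}
tag = "text"  # or the Pydantic model's __name__ for structured calls

blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
key = base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
path = os.path.expanduser(f"~/.cache/lm/{key}.json")
# Identical messages, kwargs, and response format always map to the same file,
# which is what makes cached results reproducible.
```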
556
-
557
- @staticmethod
558
- def list_models(
559
- port=None, host="localhost", base_url: Optional[str] = None
560
- ) -> List[str]:
561
- """List available models from OpenAI-compatible API server."""
562
- try:
563
- client: OpenAI = OpenAI(
564
- api_key=os.getenv("OPENAI_API_KEY", "abc"),
565
- base_url=f"http://{host}:{port}/v1" if port else base_url or None,
566
- )
567
- models: SyncPage[Model] = client.models.list()
568
- return [model.id for model in models.data]
569
- except Exception as exc:
570
- endpoint = f"http://{host}:{port}/v1" if port else base_url
571
- error_msg = str(exc)
572
-
573
- if "404" in error_msg or "Not Found" in error_msg:
574
- raise ValueError(
575
- f"No OpenAI-compatible API found at {endpoint}. "
576
- f"The endpoint appears to be running a different service "
577
- f"(possibly Jupyter Server). Please check the port number."
578
- ) from exc
579
- elif "Connection" in error_msg:
580
- raise ValueError(
581
- f"Cannot connect to {endpoint}. "
582
- f"Please verify the service is running and accessible."
583
- ) from exc
584
- else:
585
- raise ValueError(
586
- f"Failed to list models from {endpoint}: {error_msg}"
587
- ) from exc
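Typical `list_models` usage against a local server (the port and base_url are placeholders; the target service must already be running):

```python
# Enumerate model ids served by a local OpenAI-compatible endpoint, e.g. vLLM on port 8000.
model_ids = LM.list_models(port=8000)

# A remote endpoint can be targeted via base_url instead of host/port;
# OPENAI_API_KEY must be set for authenticated services.
model_ids = LM.list_models(base_url="https://api.openai.com/v1")
```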
588
-
589
- def parse(
590
- self,
591
- response_model: Type[BaseModel],
592
- instruction: Optional[str] = None,
593
- prompt: Optional[str] = None,
594
- messages: Optional[RawMsgs] = None,
595
- think: Literal[True, False, None] = None,
596
- add_json_schema_to_instruction: bool = False,
597
- temperature: Optional[float] = None,
598
- max_tokens: Optional[int] = None,
599
- return_openai_response: bool = False,
600
- cache: Optional[bool] = True,
601
- **kwargs,
602
- ):
603
- if messages is None:
604
- assert instruction is not None, "Instruction must be provided."
605
- assert prompt is not None, "Prompt must be provided."
606
- messages = [
607
- {
608
- "role": "system",
609
- "content": instruction,
610
- },
611
- {
612
- "role": "user",
613
- "content": prompt,
614
- },
615
- ] # type: ignore
616
-
617
- post_fix = ""
618
- json_schema = response_model.model_json_schema()
619
- if add_json_schema_to_instruction and response_model:
620
- _schema = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
621
- post_fix += _schema
622
-
623
- if think:
624
- post_fix += "\n\n/think"
625
- elif think is False:
626
- post_fix += "\n\n/no_think"
627
-
628
- assert isinstance(messages, list), "Messages must be a list."
629
- assert len(messages) > 0, "Messages cannot be empty."
630
- assert messages[0]["role"] == "system", (
631
- "First message must be a system message with instruction."
632
- )
633
- messages[0]["content"] += post_fix # type: ignore
634
-
635
- model_kwargs = {}
636
- if temperature is not None:
637
- model_kwargs["temperature"] = temperature
638
- if max_tokens is not None:
639
- model_kwargs["max_tokens"] = max_tokens
640
- model_kwargs.update(kwargs)
641
-
642
- use_cache = self.do_cache if cache is None else cache
643
- cache_key = None
644
- if use_cache:
645
- cache_data = {
646
- "messages": messages,
647
- "model_kwargs": model_kwargs,
648
- "guided_json": json_schema,
649
- "response_format": response_model.__name__,
650
- }
651
- cache_key = self._cache_key(cache_data, {}, response_model)
652
- cached_response = self._load_cache(cache_key)
653
- self.last_log = [prompt, messages, cached_response]
654
- if cached_response is not None:
655
- if return_openai_response:
656
- return cached_response
657
- return self._parse_complete_output(cached_response, response_model)
658
-
659
- completion = self.client.chat.completions.create(
660
- model=self.model, # type: ignore
661
- messages=messages, # type: ignore
662
- extra_body={"guided_json": json_schema},
663
- **model_kwargs,
664
- )
665
-
666
- if cache_key:
667
- self._dump_cache(cache_key, completion)
668
-
669
- self.last_log = [prompt, messages, completion]
670
- if return_openai_response:
671
- return completion
672
- return self._parse_complete_output(completion, response_model)
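A usage sketch for `parse`. It assumes an OpenAI-compatible backend that honours the `guided_json` extra body (vLLM-style guided decoding); the model name and fields are illustrative:

```python
from pydantic import BaseModel

class Sentiment(BaseModel):
    label: str
    confidence: float

lm = LM(port=8000, model="Qwen/Qwen2.5-7B-Instruct")  # placeholder local model
result = lm.parse(
    response_model=Sentiment,
    instruction="Classify the sentiment of the user's text.",
    prompt="I really enjoyed this library!",
    add_json_schema_to_instruction=True,  # appends the JSON schema to the system message
    think=False,                          # appends "/no_think" for backends that support it
)
print(result.label, result.confidence)    # validated Sentiment instance
```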
673
-
674
- def _parse_complete_output(
675
- self, completion: Any, response_model: Type[BaseModel]
676
- ) -> BaseModel:
677
- """Parse completion output to response model."""
678
- if hasattr(completion, "model_dump"):
679
- completion = completion.model_dump()
680
-
681
- if "choices" not in completion or not completion["choices"]:
682
- raise ValueError("No choices in OpenAI response")
683
-
684
- content = completion["choices"][0]["message"]["content"]
685
- if not content:
686
- raise ValueError("Empty content in response")
687
-
688
- try:
689
- data = json.loads(content)
690
- return response_model.model_validate(data)
691
- except Exception as exc:
692
- raise ValueError(
693
- f"Failed to parse response as {response_model.__name__}: {content}"
694
- ) from exc
695
-
696
- def inspect_word_probs(
697
- self,
698
- messages: Optional[List[Dict[str, Any]]] = None,
699
- tokenizer: Optional[Any] = None,
700
- do_print=True,
701
- add_think: bool = True,
702
- ) -> tuple[List[Dict[str, Any]], Any, str]:
703
- """
704
- Inspect word probabilities in a language model response.
705
-
706
- Args:
707
- tokenizer: Tokenizer instance to encode words.
708
- messages: List of messages to analyze.
709
-
710
- Returns:
711
- A tuple containing:
712
- - List of word probabilities with their log probabilities.
713
- - Token log probability dictionaries.
714
- - Rendered string with colored word probabilities.
715
- """
716
- if messages is None:
717
- messages = self.last_messages(add_think=add_think)
718
- if messages is None:
719
- raise ValueError("No messages provided and no last messages available.")
720
-
721
- if tokenizer is None:
722
- tokenizer = get_tokenizer(self.model)
723
-
724
- ret = inspect_word_probs(self, tokenizer, messages)
725
- if do_print:
726
- print(ret[-1])
727
- return ret
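A usage sketch for word-probability inspection. It assumes a backend that returns `prompt_logprobs` (e.g. vLLM) and that the served model id is also a valid Hugging Face tokenizer name:

```python
lm = LM(port=8000)                        # placeholder local server
lm(prompt="Briefly explain recursion.")   # one call so last_log is populated

# Re-uses the last conversation, colours each word by its average token
# probability, prints the rendered string, and returns the raw data.
word_probs, token_logprob_dicts, rendered = lm.inspect_word_probs()
```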
728
-
729
- def last_messages(self, add_think: bool = True) -> Optional[List[Dict[str, str]]]:
730
- last_conv = self.last_log
731
- messages = last_conv[1] if len(last_conv) > 1 else None
732
- last_msg = last_conv[2]
733
- if not isinstance(last_msg, dict):
734
- last_conv[2] = last_conv[2].model_dump() # type: ignore
735
- msg = last_conv[2]
736
- # Ensure msg is a dict
737
- if hasattr(msg, "model_dump"):
738
- msg = msg.model_dump()
739
- message = msg["choices"][0]["message"]
740
- reasoning = message.get("reasoning_content")
741
- answer = message.get("content")
742
- if reasoning and add_think:
743
- final_answer = f"<think>{reasoning}</think>\n{answer}"
744
- else:
745
- final_answer = f"<think>\n\n</think>\n{answer}"
746
- assistant = {"role": "assistant", "content": final_answer}
747
- messages = messages + [assistant] # type: ignore
748
- return messages if messages else None
749
-
750
-
751
- @lru_cache(maxsize=10)
752
- def get_tokenizer(model_name: str) -> Any:
753
- from transformers import AutoTokenizer # type: ignore
754
-
755
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
756
- return tokenizer
757
-
758
-
759
- def inspect_word_probs(lm, tokenizer, messages):
760
- import numpy as np
761
-
762
- def compute_word_log_probs(
763
- tokenizer: Any,
764
- lm_client: Any,
765
- ) -> tuple[List[Dict[str, Any]], Any]:
766
- # Build a prompt that preserves literal newlines
767
- prompt = tokenizer.apply_chat_template(
768
- messages,
769
- tokenize=False, # Don't tokenize yet, we need raw text
770
- add_generation_prompt=False, # No generation prompt needed
771
- )
772
-
773
- # Request token logprobs
774
- response = lm_client.client.completions.create(
775
- model=lm_client.model, # type: ignore
776
- prompt=prompt,
777
- max_tokens=1,
778
- logprobs=1,
779
- extra_body={"prompt_logprobs": 0},
780
- )
781
- token_logprob_dicts = response.choices[0].prompt_logprobs # type: ignore
782
-
783
- # Override first token to known start marker
784
- start_id = tokenizer.encode("<|im_start|>")[0]
785
- token_logprob_dicts[0] = {
786
- str(start_id): {
787
- "logprob": -1,
788
- "rank": 1,
789
- "decoded_token": "<|im_start|>",
790
- }
791
- }
792
-
793
- # Flatten tokens
794
- tokens: List[Dict[str, Any]] = [
795
- {"id": int(tid), **tdata}
796
- for td in token_logprob_dicts
797
- for tid, tdata in td.items()
798
- ]
799
-
800
- # Validate tokenization
801
- tokenized = tokenizer.tokenize(prompt)
802
- if len(tokenized) != len(tokens):
803
- raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
804
- for idx, tok in enumerate(tokens):
805
- if tokenized[idx] != tok["decoded_token"]:
806
- raise AssertionError(
807
- f"Token mismatch at {idx}: "
808
- f"{tokenized[idx]} != {tok['decoded_token']}"
809
- )
810
-
811
- # Split on newline sentinel
812
- split_prompt = prompt.replace("\n", " <NL> ")
813
- words = split_prompt.split()
814
-
815
- word_log_probs: List[Dict[str, Any]] = []
816
- token_idx = 0
817
-
818
- for word in words:
819
- # Map sentinel back to actual newline for encoding
820
- target = "\n" if word == "<NL>" else word
821
- sub_ids = tokenizer.encode(target, add_special_tokens=False)
822
- count = len(sub_ids)
823
- if count == 0:
824
- continue
825
-
826
- subs = tokens[token_idx : token_idx + count]
827
- avg_logprob = sum(s["logprob"] for s in subs) / count
828
- prob = float(np.exp(avg_logprob))
829
- word_log_probs.append({"word": target, "probability": prob})
830
- token_idx += count
831
-
832
- return word_log_probs, token_logprob_dicts # type: ignore
833
-
834
- def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
835
- """
836
- Return an ANSI-colored string for word probabilities (red → green).
837
- """
838
- if not word_log_probs:
839
- return ""
840
-
841
- probs = [entry["probability"] for entry in word_log_probs]
842
- min_p, max_p = min(probs), max(probs)
843
- parts: List[str] = []
844
-
845
- for entry in word_log_probs:
846
- word = entry["word"]
847
- # Preserve actual line breaks
848
- if word == "\n":
849
- parts.append("\n")
850
- continue
851
-
852
- p = entry["probability"]
853
- norm = (p - min_p) / (max_p - min_p or 1.0)
854
- r = int(255 * (1 - norm)) # red component (high when prob is low)
855
- g = int(255 * norm) # green component (high when prob is high)
856
- b = 0 # no blue for red-green gradient
857
- colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
858
- parts.append(colored + " ")
859
-
860
- return "".join(parts).rstrip()
861
-
862
- word_probs, token_logprob_dicts = compute_word_log_probs(tokenizer, lm)
863
- return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
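The heart of the scoring above is simple: each word's sub-token logprobs are averaged and exponentiated back into a probability. A tiny illustration with made-up numbers:

```python
import numpy as np

# Suppose the word "recursion" was split into three sub-tokens with these logprobs.
sub_token_logprobs = [-0.2, -0.5, -0.1]
avg_logprob = sum(sub_token_logprobs) / len(sub_token_logprobs)  # -0.2667
word_probability = float(np.exp(avg_logprob))                    # about 0.77
```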
864
-
865
-
866
- class LLMTask(ABC):
867
- """
868
- Callable wrapper around an LM endpoint.
869
-
870
- Sub-classes must set:
871
- • lm – the language-model instance
872
- • InputModel – a Pydantic input class
873
- • OutputModel – a Pydantic output class
874
-
875
- Optional flags:
876
- • temperature – float (default 0.6)
877
- • think – bool (if the backend supports “chain-of-thought”)
878
- • add_json_schema – bool (include schema in the instruction)
879
-
880
- The **docstring** of each sub-class is sent as the LM instruction.
881
- Example
882
- ```python
883
- class DemoTask(LLMTask):
884
- "TODO: SYSTEM_PROMPT_INSTURCTION HERE"
885
-
886
- lm = LM(port=8130, cache=False, model="gpt-3.5-turbo")
887
-
888
- class InputModel(BaseModel):
889
- text_to_translate:str
890
-
891
- class OutputModel(BaseModel):
892
- translation:str
893
- glossary_use:str
894
-
895
- temperature = 0.6
896
- think=False
897
-
898
- demo_task = DemoTask()
899
- demo_task({'text_to_translate': 'Translate from english to vietnamese: Hello how are you'})
900
- ```
901
- """
902
-
903
- lm: "LM"
904
- InputModel: Type[BaseModel]
905
- OutputModel: Type[BaseModel]
906
-
907
- temperature: float = 0.6
908
- think: bool = False
909
- add_json_schema: bool = False
910
-
911
- def __call__(self, data: BaseModel | dict) -> BaseModel:
912
- if (
913
- not hasattr(self, "InputModel")
914
- or not hasattr(self, "OutputModel")
915
- or not hasattr(self, "lm")
916
- ):
917
- raise NotImplementedError(
918
- f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
919
- )
920
-
921
- item = data if isinstance(data, BaseModel) else self.InputModel(**data)
922
-
923
- return self.lm.parse(
924
- prompt=item.model_dump_json(),
925
- instruction=self.__doc__ or "",
926
- response_model=self.OutputModel,
927
- temperature=self.temperature,
928
- think=self.think,
929
- add_json_schema_to_instruction=self.add_json_schema,
930
- )
931
-
932
- def generate_training_data(
933
- self, input_dict: Dict[str, Any], output: Dict[str, Any]
934
- ):
935
- "Return share gpt like format"
936
- system_prompt = self.__doc__ or ""
937
- user_msg = self.InputModel(**input_dict).model_dump_json() # type: ignore[attr-defined]
938
- assistant_msg = self.OutputModel(**output).model_dump_json() # type: ignore[attr-defined]
939
- return get_conversation_one_turn(
940
- system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
941
- )
942
-
943
- run = __call__ # alias for compatibility with other LLMTask implementations
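A sketch of a concrete `LLMTask` and of `generate_training_data`, which turns a known-good input/output pair into a ShareGPT-style record; the endpoint, model name, and field values are illustrative:

```python
class TranslateTask(LLMTask):
    "Translate the user's text from English to Vietnamese."

    lm = LM(port=8000, cache=False, model="Qwen/Qwen2.5-7B-Instruct")  # placeholder

    class InputModel(BaseModel):
        text_to_translate: str

    class OutputModel(BaseModel):
        translation: str

    temperature = 0.0
    think = False

task = TranslateTask()
result = task({"text_to_translate": "Hello, how are you?"})  # -> OutputModel instance

record = task.generate_training_data(
    {"text_to_translate": "Hello, how are you?"},
    {"translation": "Xin chào, bạn khỏe không?"},
)  # system/user/assistant messages suitable for fine-tuning data
```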