speedy-utils 1.1.35__py3-none-any.whl → 1.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/llm.py +86 -0
- llm_utils/lm/mixins.py +204 -0
- {speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/RECORD +6 -6
- {speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/entry_points.txt +1 -0
- {speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/WHEEL +0 -0
llm_utils/lm/llm.py
CHANGED
@@ -435,3 +435,89 @@ class LLM(
             vllm_reuse=vllm_reuse,
             **model_kwargs,
         )
+from typing import Any, Dict, List, Optional, Type, Union
+from pydantic import BaseModel
+from .llm import LLM, Messages
+
+class LLM_NEMOTRON3(LLM):
+    """
+    Custom implementation for NVIDIA Nemotron-3 reasoning models.
+    Supports thinking budget control and native reasoning tags.
+    """
+
+    def __init__(
+        self,
+        model: str = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
+        thinking_budget: int = 1024,
+        enable_thinking: bool = True,
+        **kwargs
+    ):
+        # Force reasoning_model to True to enable reasoning_content extraction
+        kwargs['is_reasoning_model'] = True
+        super().__init__(**kwargs)
+
+        self.model_kwargs['model'] = model
+        self.thinking_budget = thinking_budget
+        self.enable_thinking = enable_thinking
+
+    def _prepare_input(self, input_data: str | BaseModel | list[dict]) -> Messages:
+        """Override to ensure Nemotron chat template requirements are met."""
+        messages = super()._prepare_input(input_data)
+        return messages
+
+    def __call__(
+        self,
+        input_data: str | BaseModel | list[dict],
+        thinking_budget: Optional[int] = None,
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        budget = thinking_budget or self.thinking_budget
+
+        if not self.enable_thinking:
+            # Simple pass with thinking disabled in template
+            return super().__call__(
+                input_data,
+                chat_template_kwargs={"enable_thinking": False},
+                **kwargs
+            )
+
+        # --- STEP 1: Generate Thinking Trace ---
+        # We manually append <think> to force the reasoning MoE layers
+        messages = self._prepare_input(input_data)
+
+        # We use the raw text completion for the budget phase
+        # Stop at the closing tag or budget limit
+        thinking_response = self.text_completion(
+            input_data,
+            max_tokens=budget,
+            stop=["</think>"],
+            **kwargs
+        )[0]
+
+        reasoning_content = thinking_response['parsed']
+
+        # Ensure proper tag closing for the second pass
+        if "</think>" not in reasoning_content:
+            reasoning_content = f"{reasoning_content}\n</think>"
+        elif not reasoning_content.endswith("</think>"):
+            # Ensure it ends exactly with the tag for continuity
+            reasoning_content = reasoning_content.split("</think>")[0] + "</think>"
+
+        # --- STEP 2: Generate Final Answer ---
+        # Append the thought to the assistant role and continue
+        final_messages = messages + [
+            {"role": "assistant", "content": f"<think>\n{reasoning_content}\n"}
+        ]
+
+        # Use continue_final_message to prevent the model from repeating the header
+        results = super().__call__(
+            final_messages,
+            extra_body={"continue_final_message": True},
+            **kwargs
+        )
+
+        # Inject the reasoning back into the result for the UI/API
+        for res in results:
+            res['reasoning_content'] = reasoning_content
+
+        return results
llm_utils/lm/mixins.py
CHANGED
@@ -469,6 +469,210 @@ class TokenizationMixin:
         data = response.json()
         return data['prompt']
 
+    def generate(
+        self,
+        input_context: str | list[int],
+        *,
+        max_tokens: int = 512,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        min_p: float = 0.0,
+        repetition_penalty: float = 1.0,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        n: int = 1,
+        stop: str | list[str] | None = None,
+        stop_token_ids: list[int] | None = None,
+        ignore_eos: bool = False,
+        min_tokens: int = 0,
+        skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        logprobs: int | None = None,
+        prompt_logprobs: int | None = None,
+        seed: int | None = None,
+        return_token_ids: bool = False,
+        return_text: bool = True,
+        stream: bool = False,
+        **kwargs,
+    ) -> dict[str, Any] | list[dict[str, Any]]:
+        """
+        Generate text using HuggingFace Transformers-style interface.
+
+        This method provides a low-level generation interface similar to
+        HuggingFace's model.generate(), working directly with token IDs
+        and the /inference/v1/generate endpoint.
+
+        Args:
+            input_context: Input as text (str) or token IDs (list[int])
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature (higher = more random)
+            top_p: Nucleus sampling probability threshold
+            top_k: Top-k sampling parameter (-1 to disable)
+            min_p: Minimum probability threshold
+            repetition_penalty: Penalty for repeating tokens
+            presence_penalty: Presence penalty for token diversity
+            frequency_penalty: Frequency penalty for token diversity
+            n: Number of sequences to generate
+            stop: Stop sequences (string or list of strings)
+            stop_token_ids: Token IDs to stop generation
+            ignore_eos: Whether to ignore EOS token
+            min_tokens: Minimum number of tokens to generate
+            skip_special_tokens: Skip special tokens in output
+            spaces_between_special_tokens: Add spaces between special tokens
+            logprobs: Number of top logprobs to return
+            prompt_logprobs: Number of prompt logprobs to return
+            seed: Random seed for reproducibility
+            return_token_ids: If True, include token IDs in output
+            return_text: If True, include decoded text in output
+            stream: If True, stream the response (not fully implemented)
+            **kwargs: Additional parameters passed to the API
+
+        Returns:
+            Dictionary with generation results containing:
+            - 'text': Generated text (if return_text=True)
+            - 'token_ids': Generated token IDs (if return_token_ids=True)
+            - 'logprobs': Log probabilities (if logprobs is set)
+            If n > 1, returns list of result dictionaries
+        """
+        import requests
+
+        # Convert text input to token IDs if needed
+        if isinstance(input_context, str):
+            token_ids = self.encode(input_context, add_special_tokens=True)
+        else:
+            token_ids = input_context
+
+        # Get base_url (generate endpoint is at root level like /inference/v1/generate)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        # Build sampling params matching the API schema
+        sampling_params = {
+            'max_tokens': max_tokens,
+            'temperature': temperature,
+            'top_p': top_p,
+            'top_k': top_k,
+            'min_p': min_p,
+            'repetition_penalty': repetition_penalty,
+            'presence_penalty': presence_penalty,
+            'frequency_penalty': frequency_penalty,
+            'n': n,
+            'stop': stop or [],
+            'stop_token_ids': stop_token_ids or [],
+            'ignore_eos': ignore_eos,
+            'min_tokens': min_tokens,
+            'skip_special_tokens': skip_special_tokens,
+            'spaces_between_special_tokens': spaces_between_special_tokens,
+            'logprobs': logprobs,
+            'prompt_logprobs': prompt_logprobs,
+        }
+
+        if seed is not None:
+            sampling_params['seed'] = seed
+
+        # Build request payload
+        request_data = {
+            'token_ids': token_ids,
+            'sampling_params': sampling_params,
+            'stream': stream,
+        }
+
+        # Add any additional kwargs
+        request_data.update(kwargs)
+
+        # Make API request
+        response = requests.post(
+            f'{base_url}/inference/v1/generate',
+            json=request_data,
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        # Process response
+        # The API may return different structures, handle common cases
+        if n == 1:
+            result = {}
+
+            # Extract from choices format
+            if 'choices' in data and len(data['choices']) > 0:
+                choice = data['choices'][0]
+
+                # Get token IDs first
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        # Decode token IDs to text
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                # Include logprobs if requested
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                # Include finish reason
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+            # Fallback to direct fields
+            elif 'text' in data and return_text:
+                result['text'] = data['text']
+            elif 'token_ids' in data:
+                if return_token_ids:
+                    result['token_ids'] = data['token_ids']
+                if return_text:
+                    result['text'] = self.decode(data['token_ids'])
+
+            # Store raw response for debugging
+            result['_raw_response'] = data
+
+            return result
+        else:
+            # Multiple generations (n > 1)
+            results = []
+            choices = data.get('choices', [])
+
+            for i in range(min(n, len(choices))):
+                choice = choices[i]
+                result = {}
+
+                # Get token IDs
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+                result['_raw_response'] = choice
+                results.append(result)
+
+            return results
+
 
 class ModelUtilsMixin:
     """Mixin for model utility methods."""
{speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/RECORD
CHANGED
@@ -6,10 +6,10 @@ llm_utils/chat_format/transform.py,sha256=PJ2g9KT1GSbWuAs7giEbTpTAffpU9QsIXyRlbf
 llm_utils/chat_format/utils.py,sha256=M2EctZ6NeHXqFYufh26Y3CpSphN0bdZm5xoNaEJj5vg,1251
 llm_utils/lm/__init__.py,sha256=4jYMy3wPH3tg-tHFyWEWOqrnmX4Tu32VZCdzRGMGQsI,778
 llm_utils/lm/base_prompt_builder.py,sha256=_TzYMsWr-SsbA_JNXptUVN56lV5RfgWWTrFi-E8LMy4,12337
-llm_utils/lm/llm.py,sha256=
+llm_utils/lm/llm.py,sha256=2vq8BScwp4gWb89EmUPaiBCzkBSr0x2B3qJLaPM11_M,19644
 llm_utils/lm/llm_signature.py,sha256=vV8uZgLLd6ZKqWbq0OPywWvXAfl7hrJQnbtBF-VnZRU,1244
 llm_utils/lm/lm_base.py,sha256=Bk3q34KrcCK_bC4Ryxbc3KqkiPL39zuVZaBQ1i6wJqs,9437
-llm_utils/lm/mixins.py,sha256=
+llm_utils/lm/mixins.py,sha256=Nz7CwJFBOvbZNbODUlJC04Pcbac3zWnT8vy7sZG_MVI,24906
 llm_utils/lm/openai_memoize.py,sha256=rYrSFPpgO7adsjK1lVdkJlhqqIw_13TCW7zU8eNwm3o,5185
 llm_utils/lm/signature.py,sha256=K1hvCAqoC5CmsQ0Y_ywnYy2fRb5JzmIK8OS-hjH-5To,9971
 llm_utils/lm/utils.py,sha256=dEKFta8S6Mm4LjIctcpFlEGL9RnmLm5DHd2TA70UWuA,12649
@@ -50,7 +50,7 @@ vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
 vision_utils/__init__.py,sha256=hF54sT6FAxby8kDVhOvruy4yot8O-Ateey5n96O1pQM,284
 vision_utils/io_utils.py,sha256=pI0Va6miesBysJcllK6NXCay8HpGZsaMWwlsKB2DMgA,26510
 vision_utils/plot.py,sha256=HkNj3osA3moPuupP1VguXfPPOW614dZO5tvC-EFKpKM,12028
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
-speedy_utils-1.1.
+speedy_utils-1.1.36.dist-info/METADATA,sha256=yZYfOkBwR1aiMwAKAxn78sYCbmNnBt-5lmGY6d11hPI,8048
+speedy_utils-1.1.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+speedy_utils-1.1.36.dist-info/entry_points.txt,sha256=rwn89AYfBUh9SRJtFbpp-u2JIKiqmZ2sczvqyO6s9cI,289
+speedy_utils-1.1.36.dist-info/RECORD,,
{speedy_utils-1.1.35.dist-info → speedy_utils-1.1.36.dist-info}/WHEEL
File without changes