speedy-utils 1.1.34__py3-none-any.whl → 1.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .lm_base import LMBase, get_model_name
 from .mixins import (
     ModelUtilsMixin,
     TemperatureRangeMixin,
+    TokenizationMixin,
     TwoStepPydanticMixin,
     VLLMMixin,
 )
@@ -14,19 +15,20 @@ from .signature import Input, InputField, Output, OutputField, Signature


 __all__ = [
-    "LMBase",
-    "LLM",
-    "AsyncLM",
-    "AsyncLLMTask",
-    "BasePromptBuilder",
-    "LLMSignature",
-    "Signature",
-    "InputField",
-    "OutputField",
-    "Input",
-    "Output",
-    "TemperatureRangeMixin",
-    "TwoStepPydanticMixin",
-    "VLLMMixin",
-    "ModelUtilsMixin",
+    'LMBase',
+    'LLM',
+    'AsyncLM',
+    'AsyncLLMTask',
+    'BasePromptBuilder',
+    'LLMSignature',
+    'Signature',
+    'InputField',
+    'OutputField',
+    'Input',
+    'Output',
+    'TemperatureRangeMixin',
+    'TwoStepPydanticMixin',
+    'VLLMMixin',
+    'ModelUtilsMixin',
+    'TokenizationMixin',
 ]
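Note: the substantive change in this file is that TokenizationMixin is now re-exported from the package root (the quote-style churn in __all__ is cosmetic). Assuming the wheel is installed, downstream code should then be able to import it directly, e.g.:

    from llm_utils.lm import LLM, TokenizationMixin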
llm_utils/lm/llm.py CHANGED
@@ -20,6 +20,7 @@ from .base_prompt_builder import BasePromptBuilder
 from .mixins import (
     ModelUtilsMixin,
     TemperatureRangeMixin,
+    TokenizationMixin,
     TwoStepPydanticMixin,
     VLLMMixin,
 )
@@ -47,6 +48,7 @@ class LLM(
     TwoStepPydanticMixin,
     VLLMMixin,
     ModelUtilsMixin,
+    TokenizationMixin,
 ):
     """LLM task with structured input/output handling."""

@@ -433,3 +435,89 @@ class LLM(
             vllm_reuse=vllm_reuse,
             **model_kwargs,
         )
+from typing import Any, Dict, List, Optional, Type, Union
+from pydantic import BaseModel
+from .llm import LLM, Messages
+
+class LLM_NEMOTRON3(LLM):
+    """
+    Custom implementation for NVIDIA Nemotron-3 reasoning models.
+    Supports thinking budget control and native reasoning tags.
+    """
+
+    def __init__(
+        self,
+        model: str = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
+        thinking_budget: int = 1024,
+        enable_thinking: bool = True,
+        **kwargs
+    ):
+        # Force reasoning_model to True to enable reasoning_content extraction
+        kwargs['is_reasoning_model'] = True
+        super().__init__(**kwargs)
+
+        self.model_kwargs['model'] = model
+        self.thinking_budget = thinking_budget
+        self.enable_thinking = enable_thinking
+
+    def _prepare_input(self, input_data: str | BaseModel | list[dict]) -> Messages:
+        """Override to ensure Nemotron chat template requirements are met."""
+        messages = super()._prepare_input(input_data)
+        return messages
+
+    def __call__(
+        self,
+        input_data: str | BaseModel | list[dict],
+        thinking_budget: Optional[int] = None,
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        budget = thinking_budget or self.thinking_budget
+
+        if not self.enable_thinking:
+            # Simple pass with thinking disabled in template
+            return super().__call__(
+                input_data,
+                chat_template_kwargs={"enable_thinking": False},
+                **kwargs
+            )
+
+        # --- STEP 1: Generate Thinking Trace ---
+        # We manually append <think> to force the reasoning MoE layers
+        messages = self._prepare_input(input_data)
+
+        # We use the raw text completion for the budget phase
+        # Stop at the closing tag or budget limit
+        thinking_response = self.text_completion(
+            input_data,
+            max_tokens=budget,
+            stop=["</think>"],
+            **kwargs
+        )[0]
+
+        reasoning_content = thinking_response['parsed']
+
+        # Ensure proper tag closing for the second pass
+        if "</think>" not in reasoning_content:
+            reasoning_content = f"{reasoning_content}\n</think>"
+        elif not reasoning_content.endswith("</think>"):
+            # Ensure it ends exactly with the tag for continuity
+            reasoning_content = reasoning_content.split("</think>")[0] + "</think>"
+
+        # --- STEP 2: Generate Final Answer ---
+        # Append the thought to the assistant role and continue
+        final_messages = messages + [
+            {"role": "assistant", "content": f"<think>\n{reasoning_content}\n"}
+        ]
+
+        # Use continue_final_message to prevent the model from repeating the header
+        results = super().__call__(
+            final_messages,
+            extra_body={"continue_final_message": True},
+            **kwargs
+        )
+
+        # Inject the reasoning back into the result for the UI/API
+        for res in results:
+            res['reasoning_content'] = reasoning_content
+
+        return results
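Taken together, the appended LLM_NEMOTRON3 class implements a two-pass scheme: a budget-capped raw completion that stops at </think> to collect a reasoning trace, then a second chat call that continues a pre-seeded assistant message (continue_final_message) to produce the final answer, with the trace copied into each result under 'reasoning_content'. A minimal usage sketch follows; it assumes the class is importable from llm_utils.lm.llm and that the base LLM constructor accepts whatever endpoint/model configuration the rest of the package uses — none of which is shown in this diff, so the construction below is illustrative only:

    from llm_utils.lm.llm import LLM_NEMOTRON3

    # Hypothetical construction: base-class kwargs (client, base_url, ...) are not part of this diff.
    lm = LLM_NEMOTRON3(thinking_budget=512, enable_thinking=True)

    # Pass 1 drafts a <think>...</think> trace capped at thinking_budget tokens;
    # pass 2 continues the assistant turn to produce the final answer.
    results = lm("Explain why the sky is blue in two sentences.")
    print(results[0]["reasoning_content"])  # trace injected by LLM_NEMOTRON3
    print(results[0])                       # remaining fields come from the base LLM.__call__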
llm_utils/lm/mixins.py CHANGED
@@ -396,6 +396,284 @@ class VLLMMixin:
         return _kill_vllm_on_port(port)


+class TokenizationMixin:
+    """Mixin for tokenization operations (encode/decode)."""
+
+    def encode(
+        self,
+        text: str,
+        *,
+        add_special_tokens: bool = True,
+        return_token_strs: bool = False,
+    ) -> list[int] | tuple[list[int], list[str]]:
+        """
+        Encode text to token IDs using the model's tokenizer.
+
+        Args:
+            text: Text to tokenize
+            add_special_tokens: Whether to add special tokens (e.g., BOS)
+            return_token_strs: If True, also return token strings
+
+        Returns:
+            List of token IDs, or tuple of (token IDs, token strings)
+        """
+        import requests
+
+        # Get base_url from client and remove /v1 suffix if present
+        # (tokenize endpoint is at root level, not under /v1)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        response = requests.post(
+            f'{base_url}/tokenize',
+            json={
+                'prompt': text,
+                'add_special_tokens': add_special_tokens,
+                'return_token_strs': return_token_strs,
+            },
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        if return_token_strs:
+            return data['tokens'], data.get('token_strs', [])
+        return data['tokens']
+
+    def decode(
+        self,
+        token_ids: list[int],
+    ) -> str:
+        """
+        Decode token IDs to text using the model's tokenizer.
+
+        Args:
+            token_ids: List of token IDs to decode
+
+        Returns:
+            Decoded text string
+        """
+        import requests
+
+        # Get base_url from client and remove /v1 suffix if present
+        # (detokenize endpoint is at root level, not under /v1)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        response = requests.post(
+            f'{base_url}/detokenize',
+            json={'tokens': token_ids},
+        )
+        response.raise_for_status()
+        data = response.json()
+        return data['prompt']
+
+    def generate(
+        self,
+        input_context: str | list[int],
+        *,
+        max_tokens: int = 512,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        min_p: float = 0.0,
+        repetition_penalty: float = 1.0,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        n: int = 1,
+        stop: str | list[str] | None = None,
+        stop_token_ids: list[int] | None = None,
+        ignore_eos: bool = False,
+        min_tokens: int = 0,
+        skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        logprobs: int | None = None,
+        prompt_logprobs: int | None = None,
+        seed: int | None = None,
+        return_token_ids: bool = False,
+        return_text: bool = True,
+        stream: bool = False,
+        **kwargs,
+    ) -> dict[str, Any] | list[dict[str, Any]]:
+        """
+        Generate text using HuggingFace Transformers-style interface.
+
+        This method provides a low-level generation interface similar to
+        HuggingFace's model.generate(), working directly with token IDs
+        and the /inference/v1/generate endpoint.
+
+        Args:
+            input_context: Input as text (str) or token IDs (list[int])
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature (higher = more random)
+            top_p: Nucleus sampling probability threshold
+            top_k: Top-k sampling parameter (-1 to disable)
+            min_p: Minimum probability threshold
+            repetition_penalty: Penalty for repeating tokens
+            presence_penalty: Presence penalty for token diversity
+            frequency_penalty: Frequency penalty for token diversity
+            n: Number of sequences to generate
+            stop: Stop sequences (string or list of strings)
+            stop_token_ids: Token IDs to stop generation
+            ignore_eos: Whether to ignore EOS token
+            min_tokens: Minimum number of tokens to generate
+            skip_special_tokens: Skip special tokens in output
+            spaces_between_special_tokens: Add spaces between special tokens
+            logprobs: Number of top logprobs to return
+            prompt_logprobs: Number of prompt logprobs to return
+            seed: Random seed for reproducibility
+            return_token_ids: If True, include token IDs in output
+            return_text: If True, include decoded text in output
+            stream: If True, stream the response (not fully implemented)
+            **kwargs: Additional parameters passed to the API
+
+        Returns:
+            Dictionary with generation results containing:
+            - 'text': Generated text (if return_text=True)
+            - 'token_ids': Generated token IDs (if return_token_ids=True)
+            - 'logprobs': Log probabilities (if logprobs is set)
+            If n > 1, returns list of result dictionaries
+        """
+        import requests
+
+        # Convert text input to token IDs if needed
+        if isinstance(input_context, str):
+            token_ids = self.encode(input_context, add_special_tokens=True)
+        else:
+            token_ids = input_context
+
+        # Get base_url (generate endpoint is at root level like /inference/v1/generate)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        # Build sampling params matching the API schema
+        sampling_params = {
+            'max_tokens': max_tokens,
+            'temperature': temperature,
+            'top_p': top_p,
+            'top_k': top_k,
+            'min_p': min_p,
+            'repetition_penalty': repetition_penalty,
+            'presence_penalty': presence_penalty,
+            'frequency_penalty': frequency_penalty,
+            'n': n,
+            'stop': stop or [],
+            'stop_token_ids': stop_token_ids or [],
+            'ignore_eos': ignore_eos,
+            'min_tokens': min_tokens,
+            'skip_special_tokens': skip_special_tokens,
+            'spaces_between_special_tokens': spaces_between_special_tokens,
+            'logprobs': logprobs,
+            'prompt_logprobs': prompt_logprobs,
+        }
+
+        if seed is not None:
+            sampling_params['seed'] = seed
+
+        # Build request payload
+        request_data = {
+            'token_ids': token_ids,
+            'sampling_params': sampling_params,
+            'stream': stream,
+        }
+
+        # Add any additional kwargs
+        request_data.update(kwargs)
+
+        # Make API request
+        response = requests.post(
+            f'{base_url}/inference/v1/generate',
+            json=request_data,
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        # Process response
+        # The API may return different structures, handle common cases
+        if n == 1:
+            result = {}
+
+            # Extract from choices format
+            if 'choices' in data and len(data['choices']) > 0:
+                choice = data['choices'][0]
+
+                # Get token IDs first
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        # Decode token IDs to text
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                # Include logprobs if requested
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                # Include finish reason
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+            # Fallback to direct fields
+            elif 'text' in data and return_text:
+                result['text'] = data['text']
+            elif 'token_ids' in data:
+                if return_token_ids:
+                    result['token_ids'] = data['token_ids']
+                if return_text:
+                    result['text'] = self.decode(data['token_ids'])
+
+            # Store raw response for debugging
+            result['_raw_response'] = data
+
+            return result
+        else:
+            # Multiple generations (n > 1)
+            results = []
+            choices = data.get('choices', [])
+
+            for i in range(min(n, len(choices))):
+                choice = choices[i]
+                result = {}
+
+                # Get token IDs
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+                result['_raw_response'] = choice
+                results.append(result)
+
+            return results
+
+
 class ModelUtilsMixin:
     """Mixin for model utility methods."""

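Because LLM now lists TokenizationMixin as a base class, these helpers become available on LLM instances whose client points at a vLLM-style server exposing /tokenize, /detokenize, and /inference/v1/generate at the server root. A minimal sketch, assuming such a server and an already-configured instance (constructor arguments are not shown in this diff and are hypothetical here):

    from llm_utils.lm import LLM

    lm = LLM()  # hypothetical: real model/endpoint kwargs omitted

    # Round-trip text through the server-side tokenizer
    token_ids = lm.encode("Hello, world!", add_special_tokens=True)  # POST {base_url}/tokenize
    restored = lm.decode(token_ids)                                  # POST {base_url}/detokenize

    # Low-level, HF generate()-style call against /inference/v1/generate
    out = lm.generate(token_ids, max_tokens=32, temperature=0.0, return_token_ids=True)
    print(out["text"], out.get("finish_reason"))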
{speedy_utils-1.1.34.dist-info → speedy_utils-1.1.36.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.34
+Version: 1.1.36
 Summary: Fast and easy-to-use package for data science
 Project-URL: Homepage, https://github.com/anhvth/speedy
 Project-URL: Repository, https://github.com/anhvth/speedy
{speedy_utils-1.1.34.dist-info → speedy_utils-1.1.36.dist-info}/RECORD RENAMED
@@ -4,12 +4,12 @@ llm_utils/chat_format/__init__.py,sha256=a7BKtBVktgLMq2Do4iNu3YfdDdTG1v9M_BkmaEo
 llm_utils/chat_format/display.py,sha256=Lffjzna9_vV3QgfiXZM2_tuVb3wqA-WxwrmoAjsJigw,17356
 llm_utils/chat_format/transform.py,sha256=PJ2g9KT1GSbWuAs7giEbTpTAffpU9QsIXyRlbfpTZUQ,5351
 llm_utils/chat_format/utils.py,sha256=M2EctZ6NeHXqFYufh26Y3CpSphN0bdZm5xoNaEJj5vg,1251
-llm_utils/lm/__init__.py,sha256=lFE2DZRpj6eRMo11kx7oRLyYOP2FuDmz08mAcq-cYew,730
+llm_utils/lm/__init__.py,sha256=4jYMy3wPH3tg-tHFyWEWOqrnmX4Tu32VZCdzRGMGQsI,778
 llm_utils/lm/base_prompt_builder.py,sha256=_TzYMsWr-SsbA_JNXptUVN56lV5RfgWWTrFi-E8LMy4,12337
-llm_utils/lm/llm.py,sha256=C8Z8l6Ljs7uVX-zabLcDCdTf3fpGxfljaYRM0patHUQ,16469
+llm_utils/lm/llm.py,sha256=2vq8BScwp4gWb89EmUPaiBCzkBSr0x2B3qJLaPM11_M,19644
 llm_utils/lm/llm_signature.py,sha256=vV8uZgLLd6ZKqWbq0OPywWvXAfl7hrJQnbtBF-VnZRU,1244
 llm_utils/lm/lm_base.py,sha256=Bk3q34KrcCK_bC4Ryxbc3KqkiPL39zuVZaBQ1i6wJqs,9437
-llm_utils/lm/mixins.py,sha256=on83g-JO2SpZ0digOpU8mooqFBX6w7Bc-DeGzVoVCX8,14536
+llm_utils/lm/mixins.py,sha256=Nz7CwJFBOvbZNbODUlJC04Pcbac3zWnT8vy7sZG_MVI,24906
 llm_utils/lm/openai_memoize.py,sha256=rYrSFPpgO7adsjK1lVdkJlhqqIw_13TCW7zU8eNwm3o,5185
 llm_utils/lm/signature.py,sha256=K1hvCAqoC5CmsQ0Y_ywnYy2fRb5JzmIK8OS-hjH-5To,9971
 llm_utils/lm/utils.py,sha256=dEKFta8S6Mm4LjIctcpFlEGL9RnmLm5DHd2TA70UWuA,12649
@@ -50,7 +50,7 @@ vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
 vision_utils/__init__.py,sha256=hF54sT6FAxby8kDVhOvruy4yot8O-Ateey5n96O1pQM,284
 vision_utils/io_utils.py,sha256=pI0Va6miesBysJcllK6NXCay8HpGZsaMWwlsKB2DMgA,26510
 vision_utils/plot.py,sha256=HkNj3osA3moPuupP1VguXfPPOW614dZO5tvC-EFKpKM,12028
-speedy_utils-1.1.34.dist-info/METADATA,sha256=diZ6MTVGRDDhsbxoK9eBydHrbW2I6rvYG8lXXzJnJEU,8048
-speedy_utils-1.1.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-speedy_utils-1.1.34.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
-speedy_utils-1.1.34.dist-info/RECORD,,
+speedy_utils-1.1.36.dist-info/METADATA,sha256=yZYfOkBwR1aiMwAKAxn78sYCbmNnBt-5lmGY6d11hPI,8048
+speedy_utils-1.1.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+speedy_utils-1.1.36.dist-info/entry_points.txt,sha256=rwn89AYfBUh9SRJtFbpp-u2JIKiqmZ2sczvqyO6s9cI,289
+speedy_utils-1.1.36.dist-info/RECORD,,
{speedy_utils-1.1.34.dist-info → speedy_utils-1.1.36.dist-info}/entry_points.txt RENAMED
@@ -1,4 +1,5 @@
 [console_scripts]
+fast-vllm = llm_utils.scripts.fast_vllm:main
 mpython = speedy_utils.scripts.mpython:main
 openapi_client_codegen = speedy_utils.scripts.openapi_client_codegen:main
 svllm = llm_utils.scripts.vllm_serve:main