speedy-utils 1.1.35__py3-none-any.whl → 1.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/llm.py CHANGED
@@ -435,3 +435,89 @@ class LLM(
             vllm_reuse=vllm_reuse,
             **model_kwargs,
         )
+from typing import Any, Dict, List, Optional, Type, Union
+from pydantic import BaseModel
+from .llm import LLM, Messages
+
+class LLM_NEMOTRON3(LLM):
+    """
+    Custom implementation for NVIDIA Nemotron-3 reasoning models.
+    Supports thinking budget control and native reasoning tags.
+    """
+
+    def __init__(
+        self,
+        model: str = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
+        thinking_budget: int = 1024,
+        enable_thinking: bool = True,
+        **kwargs
+    ):
+        # Force reasoning_model to True to enable reasoning_content extraction
+        kwargs['is_reasoning_model'] = True
+        super().__init__(**kwargs)
+
+        self.model_kwargs['model'] = model
+        self.thinking_budget = thinking_budget
+        self.enable_thinking = enable_thinking
+
+    def _prepare_input(self, input_data: str | BaseModel | list[dict]) -> Messages:
+        """Override to ensure Nemotron chat template requirements are met."""
+        messages = super()._prepare_input(input_data)
+        return messages
+
+    def __call__(
+        self,
+        input_data: str | BaseModel | list[dict],
+        thinking_budget: Optional[int] = None,
+        **kwargs
+    ) -> List[Dict[str, Any]]:
+        budget = thinking_budget or self.thinking_budget
+
+        if not self.enable_thinking:
+            # Simple pass with thinking disabled in template
+            return super().__call__(
+                input_data,
+                chat_template_kwargs={"enable_thinking": False},
+                **kwargs
+            )
+
+        # --- STEP 1: Generate Thinking Trace ---
+        # We manually append <think> to force the reasoning MoE layers
+        messages = self._prepare_input(input_data)
+
+        # We use the raw text completion for the budget phase
+        # Stop at the closing tag or budget limit
+        thinking_response = self.text_completion(
+            input_data,
+            max_tokens=budget,
+            stop=["</think>"],
+            **kwargs
+        )[0]
+
+        reasoning_content = thinking_response['parsed']
+
+        # Ensure proper tag closing for the second pass
+        if "</think>" not in reasoning_content:
+            reasoning_content = f"{reasoning_content}\n</think>"
+        elif not reasoning_content.endswith("</think>"):
+            # Ensure it ends exactly with the tag for continuity
+            reasoning_content = reasoning_content.split("</think>")[0] + "</think>"
+
+        # --- STEP 2: Generate Final Answer ---
+        # Append the thought to the assistant role and continue
+        final_messages = messages + [
+            {"role": "assistant", "content": f"<think>\n{reasoning_content}\n"}
+        ]
+
+        # Use continue_final_message to prevent the model from repeating the header
+        results = super().__call__(
+            final_messages,
+            extra_body={"continue_final_message": True},
+            **kwargs
+        )
+
+        # Inject the reasoning back into the result for the UI/API
+        for res in results:
+            res['reasoning_content'] = reasoning_content
+
+        return results
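
The class added above implements a two-pass flow: STEP 1 runs a budgeted text completion stopped at </think>, and STEP 2 continues the assistant turn from that trace via continue_final_message, injecting the trace back as reasoning_content. A minimal usage sketch (not part of the diff; connection/constructor kwargs are assumptions about the base LLM class, which receives **kwargs unchanged):

    from llm_utils.lm.llm import LLM_NEMOTRON3

    # Hypothetical setup: any client/connection kwargs are passed through to LLM.
    lm = LLM_NEMOTRON3(thinking_budget=2048, enable_thinking=True)

    results = lm("Why is the sky blue?")      # returns List[Dict[str, Any]]
    print(results[0]["reasoning_content"])    # thinking trace injected in STEP 2
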
llm_utils/lm/mixins.py CHANGED
@@ -469,6 +469,210 @@ class TokenizationMixin:
         data = response.json()
         return data['prompt']
 
+    def generate(
+        self,
+        input_context: str | list[int],
+        *,
+        max_tokens: int = 512,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        min_p: float = 0.0,
+        repetition_penalty: float = 1.0,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        n: int = 1,
+        stop: str | list[str] | None = None,
+        stop_token_ids: list[int] | None = None,
+        ignore_eos: bool = False,
+        min_tokens: int = 0,
+        skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        logprobs: int | None = None,
+        prompt_logprobs: int | None = None,
+        seed: int | None = None,
+        return_token_ids: bool = False,
+        return_text: bool = True,
+        stream: bool = False,
+        **kwargs,
+    ) -> dict[str, Any] | list[dict[str, Any]]:
+        """
+        Generate text using HuggingFace Transformers-style interface.
+
+        This method provides a low-level generation interface similar to
+        HuggingFace's model.generate(), working directly with token IDs
+        and the /inference/v1/generate endpoint.
+
+        Args:
+            input_context: Input as text (str) or token IDs (list[int])
+            max_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature (higher = more random)
+            top_p: Nucleus sampling probability threshold
+            top_k: Top-k sampling parameter (-1 to disable)
+            min_p: Minimum probability threshold
+            repetition_penalty: Penalty for repeating tokens
+            presence_penalty: Presence penalty for token diversity
+            frequency_penalty: Frequency penalty for token diversity
+            n: Number of sequences to generate
+            stop: Stop sequences (string or list of strings)
+            stop_token_ids: Token IDs to stop generation
+            ignore_eos: Whether to ignore EOS token
+            min_tokens: Minimum number of tokens to generate
+            skip_special_tokens: Skip special tokens in output
+            spaces_between_special_tokens: Add spaces between special tokens
+            logprobs: Number of top logprobs to return
+            prompt_logprobs: Number of prompt logprobs to return
+            seed: Random seed for reproducibility
+            return_token_ids: If True, include token IDs in output
+            return_text: If True, include decoded text in output
+            stream: If True, stream the response (not fully implemented)
+            **kwargs: Additional parameters passed to the API
+
+        Returns:
+            Dictionary with generation results containing:
+            - 'text': Generated text (if return_text=True)
+            - 'token_ids': Generated token IDs (if return_token_ids=True)
+            - 'logprobs': Log probabilities (if logprobs is set)
+            If n > 1, returns list of result dictionaries
+        """
+        import requests
+
+        # Convert text input to token IDs if needed
+        if isinstance(input_context, str):
+            token_ids = self.encode(input_context, add_special_tokens=True)
+        else:
+            token_ids = input_context
+
+        # Get base_url (generate endpoint is at root level like /inference/v1/generate)
+        base_url = str(self.client.base_url).rstrip('/')
+        if base_url.endswith('/v1'):
+            base_url = base_url[:-3]  # Remove '/v1'
+
+        # Build sampling params matching the API schema
+        sampling_params = {
+            'max_tokens': max_tokens,
+            'temperature': temperature,
+            'top_p': top_p,
+            'top_k': top_k,
+            'min_p': min_p,
+            'repetition_penalty': repetition_penalty,
+            'presence_penalty': presence_penalty,
+            'frequency_penalty': frequency_penalty,
+            'n': n,
+            'stop': stop or [],
+            'stop_token_ids': stop_token_ids or [],
+            'ignore_eos': ignore_eos,
+            'min_tokens': min_tokens,
+            'skip_special_tokens': skip_special_tokens,
+            'spaces_between_special_tokens': spaces_between_special_tokens,
+            'logprobs': logprobs,
+            'prompt_logprobs': prompt_logprobs,
+        }
+
+        if seed is not None:
+            sampling_params['seed'] = seed
+
+        # Build request payload
+        request_data = {
+            'token_ids': token_ids,
+            'sampling_params': sampling_params,
+            'stream': stream,
+        }
+
+        # Add any additional kwargs
+        request_data.update(kwargs)
+
+        # Make API request
+        response = requests.post(
+            f'{base_url}/inference/v1/generate',
+            json=request_data,
+        )
+        response.raise_for_status()
+        data = response.json()
+
+        # Process response
+        # The API may return different structures, handle common cases
+        if n == 1:
+            result = {}
+
+            # Extract from choices format
+            if 'choices' in data and len(data['choices']) > 0:
+                choice = data['choices'][0]
+
+                # Get token IDs first
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        # Decode token IDs to text
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                # Include logprobs if requested
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                # Include finish reason
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+            # Fallback to direct fields
+            elif 'text' in data and return_text:
+                result['text'] = data['text']
+            elif 'token_ids' in data:
+                if return_token_ids:
+                    result['token_ids'] = data['token_ids']
+                if return_text:
+                    result['text'] = self.decode(data['token_ids'])
+
+            # Store raw response for debugging
+            result['_raw_response'] = data
+
+            return result
+        else:
+            # Multiple generations (n > 1)
+            results = []
+            choices = data.get('choices', [])
+
+            for i in range(min(n, len(choices))):
+                choice = choices[i]
+                result = {}
+
+                # Get token IDs
+                generated_token_ids = None
+                if 'token_ids' in choice:
+                    generated_token_ids = choice['token_ids']
+                    if return_token_ids:
+                        result['token_ids'] = generated_token_ids
+
+                # Decode to text if requested
+                if return_text:
+                    if 'text' in choice:
+                        result['text'] = choice['text']
+                    elif generated_token_ids is not None:
+                        result['text'] = self.decode(generated_token_ids)
+                    elif 'message' in choice and 'content' in choice['message']:
+                        result['text'] = choice['message']['content']
+
+                if logprobs is not None and 'logprobs' in choice:
+                    result['logprobs'] = choice['logprobs']
+
+                if 'finish_reason' in choice:
+                    result['finish_reason'] = choice['finish_reason']
+
+                result['_raw_response'] = choice
+                results.append(result)
+
+            return results
+
 
 class ModelUtilsMixin:
     """Mixin for model utility methods."""
llm_utils/scripts/fast_vllm.py ADDED
@@ -0,0 +1,131 @@
+import os
+import sys
+import shutil
+import time
+import argparse
+import subprocess
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+def get_hf_cache_home():
+    """Locate the Hugging Face cache directory."""
+    if "HF_HOME" in os.environ:
+        return Path(os.environ["HF_HOME"]) / "hub"
+    return Path.home() / ".cache" / "huggingface" / "hub"
+
+def resolve_model_path(model_id, cache_dir):
+    """Find the physical snapshot directory for the given model ID."""
+    dir_name = "models--" + model_id.replace("/", "--")
+    model_root = cache_dir / dir_name
+    if not model_root.exists():
+        raise FileNotFoundError(f"Model folder not found at: {model_root}")
+
+    # 1. Try to find hash via refs/main
+    ref_path = model_root / "refs" / "main"
+    if ref_path.exists():
+        with open(ref_path, "r") as f:
+            commit_hash = f.read().strip()
+        snapshot_path = model_root / "snapshots" / commit_hash
+        if snapshot_path.exists():
+            return snapshot_path
+
+    # 2. Fallback to the newest snapshot folder
+    snapshots_dir = model_root / "snapshots"
+    if snapshots_dir.exists():
+        subdirs = [x for x in snapshots_dir.iterdir() if x.is_dir()]
+        if subdirs:
+            return sorted(subdirs, key=lambda x: x.stat().st_mtime, reverse=True)[0]
+
+    raise FileNotFoundError(f"No valid snapshot found in {model_root}")
+
+def copy_worker(src, dst):
+    """Copy a single file, following symlinks to capture actual data."""
+    try:
+        os.makedirs(os.path.dirname(dst), exist_ok=True)
+        # copy2 follows symlinks by default
+        shutil.copy2(src, dst)
+        return os.path.getsize(dst)
+    except Exception as e:
+        return str(e)
+
+def cache_to_ram(model_id, shm_base, workers=64):
+    """Parallel copy from HF cache to the specified RAM directory."""
+    cache_home = get_hf_cache_home()
+    src_path = resolve_model_path(model_id, cache_home)
+
+    safe_name = model_id.replace("/", "_")
+    dst_path = Path(shm_base) / safe_name
+
+    # Check available space in shm
+    shm_stats = shutil.disk_usage(shm_base)
+    print(f"📦 Source: {src_path}", file=sys.stderr)
+    print(f"🚀 Target RAM: {dst_path} (Available: {shm_stats.free/(1024**3):.1f} GB)", file=sys.stderr)
+
+    files_to_copy = []
+    for root, _, files in os.walk(src_path):
+        for file in files:
+            full_src = Path(root) / file
+            rel_path = full_src.relative_to(src_path)
+            files_to_copy.append((full_src, dst_path / rel_path))
+
+    total_bytes = 0
+    start = time.time()
+    with ThreadPoolExecutor(max_workers=workers) as pool:
+        futures = {pool.submit(copy_worker, s, d): s for s, d in files_to_copy}
+        for i, future in enumerate(as_completed(futures)):
+            res = future.result()
+            if isinstance(res, int):
+                total_bytes += res
+            if i % 100 == 0 or i == len(files_to_copy) - 1:
+                print(f" Progress: {i+1}/{len(files_to_copy)} files...", end="\r", file=sys.stderr)
+
+    elapsed = time.time() - start
+    print(f"\n✅ Copied {total_bytes/(1024**3):.2f} GB in {elapsed:.2f}s", file=sys.stderr)
+    return dst_path
+
+def main():
+    parser = argparse.ArgumentParser(description="vLLM RAM-cached loader", add_help=False)
+    parser.add_argument("--model", type=str, required=True, help="HuggingFace Model ID")
+    parser.add_argument("--shm-dir", type=str, default="/dev/shm", help="RAM disk mount point")
+    parser.add_argument("--cache-workers", type=int, default=64, help="Threads for copying")
+    parser.add_argument("--keep-cache", action="store_true", help="Do not delete files from RAM on exit")
+
+    # Capture wrapper args vs vLLM args
+    args, vllm_args = parser.parse_known_args()
+
+    ram_path = None
+    try:
+        # 1. Sync weights to RAM disk
+        ram_path = cache_to_ram(args.model, args.shm_dir, args.cache_workers)
+
+        # 2. Prepare vLLM Command
+        # Point vLLM to the RAM files, but keep the original model ID for the API
+        cmd = [
+            "vllm", "serve", str(ram_path),
+            "--served-model-name", args.model
+        ] + vllm_args
+
+        print(f"\n🔥 Launching vLLM...")
+        print(f" Command: {' '.join(cmd)}\n", file=sys.stderr)
+
+        # 3. Run vLLM and wait
+        subprocess.run(cmd, check=True)
+
+    except KeyboardInterrupt:
+        print("\n👋 Process interrupted by user.", file=sys.stderr)
+    except subprocess.CalledProcessError as e:
+        print(f"\n❌ vLLM exited with error: {e}", file=sys.stderr)
+    except Exception as e:
+        print(f"\n❌ Error: {e}", file=sys.stderr)
+    finally:
+        # 4. Cleanup RAM Disk
+        if ram_path and ram_path.exists() and not args.keep_cache:
+            print(f"🧹 Cleaning up RAM cache: {ram_path}", file=sys.stderr)
+            try:
+                shutil.rmtree(ram_path)
+                print("✨ RAM disk cleared.", file=sys.stderr)
+            except Exception as e:
+                print(f"⚠️ Failed to clean {ram_path}: {e}", file=sys.stderr)
+
+if __name__ == "__main__":
+    main()
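
The new script is exposed as the fast-vllm console script (see the entry_points.txt diff below): it copies the model's Hugging Face snapshot into the RAM disk, launches vllm serve against the RAM copy while keeping the original model ID via --served-model-name, and removes the copy on exit unless --keep-cache is passed. Flags the wrapper does not recognize are forwarded to vllm serve, so an illustrative invocation (model ID is a placeholder) looks like: fast-vllm --model <hf-model-id> --shm-dir /dev/shm followed by the usual vLLM serve flags.
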
speedy_utils-1.1.35.dist-info/METADATA → speedy_utils-1.1.38.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: speedy-utils
-Version: 1.1.35
+Version: 1.1.38
 Summary: Fast and easy-to-use package for data science
 Project-URL: Homepage, https://github.com/anhvth/speedy
 Project-URL: Repository, https://github.com/anhvth/speedy
speedy_utils-1.1.35.dist-info/RECORD → speedy_utils-1.1.38.dist-info/RECORD CHANGED
@@ -6,10 +6,10 @@ llm_utils/chat_format/transform.py,sha256=PJ2g9KT1GSbWuAs7giEbTpTAffpU9QsIXyRlbf
 llm_utils/chat_format/utils.py,sha256=M2EctZ6NeHXqFYufh26Y3CpSphN0bdZm5xoNaEJj5vg,1251
 llm_utils/lm/__init__.py,sha256=4jYMy3wPH3tg-tHFyWEWOqrnmX4Tu32VZCdzRGMGQsI,778
 llm_utils/lm/base_prompt_builder.py,sha256=_TzYMsWr-SsbA_JNXptUVN56lV5RfgWWTrFi-E8LMy4,12337
-llm_utils/lm/llm.py,sha256=yas7Khd0Djc8-GD8jL--B2oPteV9FC3PpfPbr9XCLOQ,16515
+llm_utils/lm/llm.py,sha256=2vq8BScwp4gWb89EmUPaiBCzkBSr0x2B3qJLaPM11_M,19644
 llm_utils/lm/llm_signature.py,sha256=vV8uZgLLd6ZKqWbq0OPywWvXAfl7hrJQnbtBF-VnZRU,1244
 llm_utils/lm/lm_base.py,sha256=Bk3q34KrcCK_bC4Ryxbc3KqkiPL39zuVZaBQ1i6wJqs,9437
-llm_utils/lm/mixins.py,sha256=o0tZiaKW4u1BxBVlT_0yTwnO8h7KnY02HX5TuWipvr0,16735
+llm_utils/lm/mixins.py,sha256=Nz7CwJFBOvbZNbODUlJC04Pcbac3zWnT8vy7sZG_MVI,24906
 llm_utils/lm/openai_memoize.py,sha256=rYrSFPpgO7adsjK1lVdkJlhqqIw_13TCW7zU8eNwm3o,5185
 llm_utils/lm/signature.py,sha256=K1hvCAqoC5CmsQ0Y_ywnYy2fRb5JzmIK8OS-hjH-5To,9971
 llm_utils/lm/utils.py,sha256=dEKFta8S6Mm4LjIctcpFlEGL9RnmLm5DHd2TA70UWuA,12649
@@ -20,6 +20,7 @@ llm_utils/lm/async_lm/async_lm.py,sha256=W8n_S5PKJln9bzO9T525-tqo5DFwBZNXDucz_v-
 llm_utils/lm/async_lm/async_lm_base.py,sha256=ga5nCzows5Ye3yop41zsUxNYxcj_Vpf02DsfJ1eoE9U,8358
 llm_utils/lm/async_lm/lm_specific.py,sha256=PxP54ltrh9NrLJx7BPib52oYo_aCvDOjf7KzMjp1MYg,3929
 llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
+llm_utils/scripts/fast_vllm.py,sha256=00UWajLOfTorSMgmgxUOpssdg55oOHneNUY0lhVuRGQ,5128
 llm_utils/scripts/vllm_load_balancer.py,sha256=eQlH07573EDWIBkwc9ef1WvI59anLr4hQqLfZvQk7xk,37133
 llm_utils/scripts/vllm_serve.py,sha256=tPcRB_MbJ01LcqC83RHQ7W9XDS7b1Ntc0fCRdegsNXU,14761
 llm_utils/vector_cache/__init__.py,sha256=oZXpjgBuutI-Pd_pBNYAQMY7-K2C6xv8Qt6a3p88GBQ,879
@@ -50,7 +51,7 @@ vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
 vision_utils/__init__.py,sha256=hF54sT6FAxby8kDVhOvruy4yot8O-Ateey5n96O1pQM,284
 vision_utils/io_utils.py,sha256=pI0Va6miesBysJcllK6NXCay8HpGZsaMWwlsKB2DMgA,26510
 vision_utils/plot.py,sha256=HkNj3osA3moPuupP1VguXfPPOW614dZO5tvC-EFKpKM,12028
-speedy_utils-1.1.35.dist-info/METADATA,sha256=wsz89syaYNXEeGjJXV8zb0W2ZrTjpN2Lj47tE7LQeEI,8048
-speedy_utils-1.1.35.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-speedy_utils-1.1.35.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
-speedy_utils-1.1.35.dist-info/RECORD,,
+speedy_utils-1.1.38.dist-info/METADATA,sha256=8WgY6bVeosqELf3KSmLIrygeQcYe1uQag4BFwvLfSWM,8048
+speedy_utils-1.1.38.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+speedy_utils-1.1.38.dist-info/entry_points.txt,sha256=rwn89AYfBUh9SRJtFbpp-u2JIKiqmZ2sczvqyO6s9cI,289
+speedy_utils-1.1.38.dist-info/RECORD,,
speedy_utils-1.1.35.dist-info/entry_points.txt → speedy_utils-1.1.38.dist-info/entry_points.txt CHANGED
@@ -1,4 +1,5 @@
 [console_scripts]
+fast-vllm = llm_utils.scripts.fast_vllm:main
 mpython = speedy_utils.scripts.mpython:main
 openapi_client_codegen = speedy_utils.scripts.openapi_client_codegen:main
 svllm = llm_utils.scripts.vllm_serve:main