tuft 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,9 +8,14 @@ Key Design:
8
8
  - Nested records: {namespace}::{type}::{parent_id}::{nested_type}::{nested_id}
9
9
 
10
10
  Persistence Modes:
11
- - disabled: No persistence, all data is in-memory only
12
- - redis_url: Use external Redis server via URL
13
- - file_redis: Use file-backed storage for tests and demos
11
+ - DISABLE: No persistence, all data is in-memory only
12
+ - REDIS: Use external Redis server via URL
13
+ - FILE: Use file-backed storage for tests and demos
14
+
15
+ Config Validation:
16
+ - On startup, the current config signature is compared with the stored signature
17
+ - If a mismatch is detected, the server stops with an error message
18
+ - Use `tuft clear persistence` to override and clear existing data
14
19
  """
15
20
 
16
21
  from __future__ import annotations
@@ -19,12 +24,12 @@ import logging
19
24
  import os
20
25
  import threading
21
26
  import time
22
- from dataclasses import dataclass
27
+ from datetime import datetime, timezone
23
28
  from enum import Enum
24
29
  from pathlib import Path
25
30
  from typing import Any, TypeVar
26
31
 
27
- from pydantic import BaseModel
32
+ from pydantic import BaseModel, Field
28
33
 
29
34
 
30
35
  logger = logging.getLogger(__name__)
@@ -50,71 +55,104 @@ T = TypeVar("T", bound=BaseModel)
50
55
class PersistenceMode(str, Enum):
    """Backend used to persist records.

    Members subclass ``str``, so each compares equal to its raw value
    (for example ``PersistenceMode.REDIS == "REDIS"``).
    """

    # Persistence switched off; all data stays in memory.
    DISABLE = "DISABLE"
    # External Redis server reached through a URL.
    REDIS = "REDIS"
    # File-backed storage intended for tests and demos.
    FILE = "FILE"
56
61
 
57
62
 
58
63
# Default time-to-live values, expressed in seconds.
# Future records are short-lived: they expire after one day.
DEFAULT_FUTURE_TTL_SECONDS = 24 * 60 * 60
60
65
 
61
66
 
62
- @dataclass
63
- class PersistenceConfig:
67
class ConfigCheckField:
    """Names of AppConfig fields that can be compared on restart.

    Each value is the upper-case spelling of an AppConfig attribute name.
    SUPPORTED_MODELS is mandatory and is always compared, for restore safety.
    """

    SUPPORTED_MODELS = "SUPPORTED_MODELS"
    CHECKPOINT_DIR = "CHECKPOINT_DIR"
    MODEL_OWNER = "MODEL_OWNER"
    TOY_BACKEND_SEED = "TOY_BACKEND_SEED"
    AUTHORIZED_USERS = "AUTHORIZED_USERS"
    TELEMETRY = "TELEMETRY"


# Fields compared by default: only the mandatory SUPPORTED_MODELS entry.
DEFAULT_CHECK_FIELDS: list[str] = [ConfigCheckField.SUPPORTED_MODELS]
84
+
85
+
86
class PersistenceConfig(BaseModel):
    """Settings controlling Redis-style persistence.

    Attributes:
        mode: Backend selector - DISABLE, REDIS, or FILE.
        redis_url: Redis server URL; consulted only when ``mode`` is REDIS.
        file_path: JSON file path; consulted only when ``mode`` is FILE.
        namespace: Prefix for stored keys. Defaults to "persistence-tuft-server".
        future_ttl_seconds: Expiry for future records in seconds; ``None``
            means they never expire.
        check_fields: AppConfig fields validated against the stored
            signature on restart. Defaults to ["SUPPORTED_MODELS"];
            SUPPORTED_MODELS is checked even if omitted here (see
            ``get_check_fields``). Available names: SUPPORTED_MODELS,
            CHECKPOINT_DIR, MODEL_OWNER, TOY_BACKEND_SEED, AUTHORIZED_USERS,
            TELEMETRY.
    """

    # Permit non-pydantic field types such as ``Path``.
    model_config = {"arbitrary_types_allowed": True}

    mode: PersistenceMode = PersistenceMode.DISABLE
    redis_url: str = "redis://localhost:6379/0"
    file_path: Path | None = None
    namespace: str = "persistence-tuft-server"
    future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS
    check_fields: list[str] = Field(default_factory=lambda: DEFAULT_CHECK_FIELDS.copy())

    @property
    def enabled(self) -> bool:
        """True unless ``mode`` is DISABLE."""
        return self.mode != PersistenceMode.DISABLE

    def get_check_fields(self) -> list[str]:
        """Return a copy of ``check_fields`` with SUPPORTED_MODELS guaranteed.

        If SUPPORTED_MODELS is absent, it is prepended; otherwise the list is
        returned unchanged (as a fresh copy).
        """
        requested = list(self.check_fields)
        mandatory = ConfigCheckField.SUPPORTED_MODELS
        if mandatory in requested:
            return requested
        return [mandatory, *requested]

    @classmethod
    def from_redis_url(
        cls,
        redis_url: str,
        namespace: str = "persistence-tuft-server",
        future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS,
        check_fields: list[str] | None = None,
    ) -> "PersistenceConfig":
        """Build a config backed by an external Redis server.

        An empty or missing ``check_fields`` falls back to the defaults.
        """
        selected = check_fields if check_fields else DEFAULT_CHECK_FIELDS.copy()
        return cls(
            mode=PersistenceMode.REDIS,
            redis_url=redis_url,
            namespace=namespace,
            future_ttl_seconds=future_ttl_seconds,
            check_fields=selected,
        )

    @classmethod
    def from_file(
        cls,
        file_path: Path | None = None,
        namespace: str = "persistence-tuft-server",
        future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS,
        check_fields: list[str] | None = None,
    ) -> "PersistenceConfig":
        """Build a config backed by file storage.

        An empty or missing ``check_fields`` falls back to the defaults.
        """
        selected = check_fields if check_fields else DEFAULT_CHECK_FIELDS.copy()
        return cls(
            mode=PersistenceMode.FILE,
            file_path=file_path,
            namespace=namespace,
            future_ttl_seconds=future_ttl_seconds,
            check_fields=selected,
        )
119
157
 
120
158
 
@@ -123,7 +161,7 @@ class RedisStore:
123
161
 
124
162
  Supports two modes:
125
163
  - External Redis server (via redis-py)
126
- - No persistence (disabled mode)
164
+ - No persistence (DISABLE mode)
127
165
  """
128
166
 
129
167
  _instance: "RedisStore | None" = None
@@ -166,7 +204,7 @@ class RedisStore:
166
204
  if self._redis is None or self._pid != current_pid:
167
205
  self._close_connections()
168
206
 
169
- if self._config.mode in (PersistenceMode.REDIS_URL, PersistenceMode.FILE_REDIS):
207
+ if self._config.mode in (PersistenceMode.REDIS, PersistenceMode.FILE):
170
208
  logger.info("Redis connection begin")
171
209
  self._redis = self._create_redis_client()
172
210
 
@@ -181,7 +219,7 @@ class RedisStore:
181
219
  if self._config is None:
182
220
  return None
183
221
  try:
184
- if self._config.mode == PersistenceMode.FILE_REDIS:
222
+ if self._config.mode == PersistenceMode.FILE:
185
223
  from .file_redis import FileRedis
186
224
 
187
225
  file_path = self._config.file_path or (
@@ -201,7 +239,7 @@ class RedisStore:
201
239
 
202
240
  @property
203
241
  def namespace(self) -> str:
204
- return self._config.namespace if self._config else "tuft"
242
+ return self._config.namespace if self._config else "persistence-tuft-server"
205
243
 
206
244
  @property
207
245
  def future_ttl(self) -> int | None:
@@ -486,3 +524,286 @@ def is_persistence_enabled() -> bool:
486
524
def get_redis_store() -> RedisStore:
    """Return the global store obtained from ``RedisStore.get_instance()``."""
    store = RedisStore.get_instance()
    return store
527
+
528
+
529
+ def _now() -> datetime:
530
+ return datetime.now(timezone.utc)
531
+
532
+
533
class ConfigSignature(BaseModel):
    """Snapshot of the serialized AppConfig, compared on restart to detect drift.

    ``config_data`` holds the output of ``AppConfig.get_config_for_persistence()``.
    That output excludes the runtime-only persistence settings.
    """

    # Serialized AppConfig payload (persistence settings excluded).
    config_data: dict[str, Any] = Field(default_factory=dict)

    # Bookkeeping metadata.
    created_at: datetime = Field(default_factory=_now)
    namespace: str = "persistence-tuft-server"

    @classmethod
    def from_app_config(cls, config: Any) -> "ConfigSignature":
        """Build a signature from an AppConfig instance."""
        payload = config.get_config_for_persistence()
        if config.persistence:
            ns = config.persistence.namespace
        else:
            ns = "persistence-tuft-server"
        return cls(config_data=payload, namespace=ns)

    def _get_field_value(self, field_name: str) -> Any:
        """Fetch a field from ``config_data`` using its lower-cased name."""
        return self.config_data.get(field_name.lower())

    def _normalize_for_comparison(self, value: Any) -> Any:
        """Make list values order-insensitive; other values pass through unchanged.

        Dict elements become sorted item tuples. The list is then sorted by
        each element's string form.
        """
        if not isinstance(value, list):
            return value
        flattened = [
            tuple(sorted(element.items())) if isinstance(element, dict) else element
            for element in value
        ]
        # Sorting on str() keeps mixed element types orderable.
        return sorted(flattened, key=str)

    def _compare_field(self, other: "ConfigSignature", field_name: str) -> bool:
        """Return True when both signatures agree on ``field_name``."""
        normalize = self._normalize_for_comparison
        mine = normalize(self._get_field_value(field_name))
        theirs = normalize(other._get_field_value(field_name))
        return mine == theirs

    def _get_field_diff(self, other: "ConfigSignature", field_name: str) -> dict[str, Any] | None:
        """Describe how ``field_name`` differs between the two signatures.

        Returns:
            ``{"current": value, "stored": value}`` holding the raw (un-normalized)
            values when they differ, otherwise None.
        """
        raw_current = self._get_field_value(field_name)
        raw_stored = other._get_field_value(field_name)
        normalize = self._normalize_for_comparison
        if normalize(raw_current) == normalize(raw_stored):
            return None
        return {"current": raw_current, "stored": raw_stored}

    def matches(
        self,
        other: "ConfigSignature",
        check_fields: list[str] | None = None,
    ) -> bool:
        """Check whether every selected field agrees with ``other``.

        Args:
            other: Signature to compare against.
            check_fields: Field names to compare. None selects DEFAULT_CHECK_FIELDS;
                otherwise SUPPORTED_MODELS is added if missing.

        Returns:
            True if all selected fields match, else False.
        """
        return all(
            self._compare_field(other, name)
            for name in self._get_fields_to_check(check_fields)
        )

    def get_diff(
        self,
        other: "ConfigSignature",
        check_fields: list[str] | None = None,
    ) -> dict[str, dict[str, Any]]:
        """Collect per-field differences against ``other``.

        Args:
            other: Signature to compare against.
            check_fields: Field names to compare. None selects DEFAULT_CHECK_FIELDS;
                otherwise SUPPORTED_MODELS is added if missing.

        Returns:
            Mapping of differing field name to ``{"current": ..., "stored": ...}``.
        """
        differences: dict[str, dict[str, Any]] = {}
        for name in self._get_fields_to_check(check_fields):
            entry = self._get_field_diff(other, name)
            if entry is not None:
                differences[name] = entry
        return differences

    def _get_fields_to_check(self, check_fields: list[str] | None) -> list[str]:
        """Resolve the field list, making sure SUPPORTED_MODELS is present."""
        if check_fields is None:
            return DEFAULT_CHECK_FIELDS.copy()
        selected = list(check_fields)
        mandatory = ConfigCheckField.SUPPORTED_MODELS
        if mandatory in selected:
            return selected
        return [mandatory, *selected]
656
+
657
+
658
# Logical key handed to RedisStore.build_key() to locate the stored config signature.
CONFIG_SIGNATURE_KEY: str = "config_signature"
659
+
660
+
661
def save_config_signature(config: Any) -> bool:
    """Write a signature of ``config`` to the active store.

    Args:
        config: The AppConfig to snapshot.

    Returns:
        The result of ``store.set`` on success. Returns False when persistence
        is disabled or when serializing/writing raises (the error is logged).
        Errors while building the signature or the key are not caught.
    """
    store = RedisStore.get_instance()
    if store.is_enabled:
        signature = ConfigSignature.from_app_config(config)
        key = store.build_key(CONFIG_SIGNATURE_KEY)
        try:
            return store.set(key, signature.model_dump_json())
        except Exception:
            logger.exception("Failed to save config signature to Redis")
            return False
    return False
683
+
684
+
685
def load_config_signature() -> ConfigSignature | None:
    """Read the stored config signature back from the active store.

    Returns:
        The parsed ConfigSignature, or None when persistence is disabled,
        nothing is stored, or reading/parsing raises (the error is logged).
    """
    store = RedisStore.get_instance()
    if not store.is_enabled:
        return None

    key = store.build_key(CONFIG_SIGNATURE_KEY)
    try:
        raw = store.get(key)
        return None if raw is None else ConfigSignature.model_validate_json(raw)
    except Exception:
        logger.exception("Failed to load config signature from Redis")
        return None
705
+
706
+
707
def has_existing_data() -> bool:
    """Report whether any key exists under the current namespace.

    Returns:
        False when persistence is disabled. Otherwise, True if ``store.keys``
        finds at least one key matching ``"<namespace>::*"``.
    """
    store = RedisStore.get_instance()
    if not store.is_enabled:
        return False
    matching = store.keys(f"{store.namespace}::*")
    return len(matching) != 0
720
+
721
+
722
def validate_config_signature(config: Any) -> bool:
    """Check the current config against the stored signature (read-only).

    Nothing is written here; call save_config_signature() after a
    successful restore. The fields compared come from
    ``config.persistence.get_check_fields()`` when persistence settings
    exist. SUPPORTED_MODELS is always among them.

    Outcomes:
        * No signature and no data in the namespace: fresh start, returns True.
        * No signature but data present: corrupted/legacy state, raises.
        * Signature present and matching: returns False (not a fresh start).
        * Signature present but differing: raises.

    Args:
        config: The current AppConfig.

    Returns:
        True for a fresh start, False when a matching signature exists.

    Raises:
        ConfigMismatchError: On a mismatch or a namespace with data but no signature.
    """
    from tuft.exceptions import ConfigMismatchError

    stored = load_config_signature()

    if stored is not None:
        fields = config.persistence.get_check_fields() if config.persistence else None
        current = ConfigSignature.from_app_config(config)
        if not current.matches(stored, check_fields=fields):
            raise ConfigMismatchError(current.get_diff(stored, check_fields=fields))
        logger.debug("Config signature validated successfully")
        return False

    if not has_existing_data():
        # Empty namespace: nothing to protect, so this is a fresh start.
        logger.info("No stored config signature found - fresh start")
        return True

    # Keys exist but the signature is missing: refuse to restore blindly.
    logger.warning(
        "Redis namespace has data but no config signature. "
        "This indicates a corrupted or incompatible persistence state."
    )
    raise ConfigMismatchError(
        diff={
            "_state": {
                "current": "valid configuration",
                "stored": "missing signature (corrupted or legacy data)",
            }
        }
    )
782
+
783
+
784
def get_current_namespace() -> str:
    """Return the namespace used to prefix persisted keys.

    Returns:
        The namespace reported by the global RedisStore. When the store has
        no configuration, this is the fallback ``"persistence-tuft-server"``
        (not ``'tuft'``, which was the default before 0.1.3).
    """
    store = RedisStore.get_instance()
    return store.namespace
792
+
793
+
794
def flush_all_data() -> tuple[int, str]:
    """Delete every key under the current namespace.

    This is destructive: it removes all keys matching ``"<namespace>::*"``.

    Returns:
        ``(deleted_count, namespace)``. When persistence is disabled nothing
        is deleted and the count is 0.
    """
    store = RedisStore.get_instance()
    if store.is_enabled:
        deleted = store.delete_pattern(f"{store.namespace}::*")
    else:
        deleted = 0
    return deleted, store.namespace
@@ -20,12 +20,17 @@ from .exceptions import (
20
20
  CheckpointAccessDeniedException,
21
21
  CheckpointNotFoundException,
22
22
  MissingSequenceIDException,
23
- SequenceConflictException,
24
23
  SessionNotFoundException,
25
24
  UnknownModelException,
26
25
  UserMismatchException,
27
26
  )
28
- from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
27
+ from .persistence import (
28
+ get_redis_store,
29
+ is_persistence_enabled,
30
+ load_record,
31
+ save_record,
32
+ )
33
+ from .sequence_executor import SequenceExecutor
29
34
  from .telemetry.metrics import get_metrics
30
35
  from .telemetry.tracing import get_tracer
31
36
 
@@ -64,7 +69,12 @@ class SamplingSessionRecord(BaseModel):
64
69
  model_path: str | None = None
65
70
  session_seq_id: int
66
71
  last_seq_id: int = -1
72
+ max_submitted_seq_id: int = -1
67
73
  history: list[SamplingHistoryEntry] = Field(default_factory=list)
74
+ executor: SequenceExecutor = Field(default_factory=SequenceExecutor, exclude=True)
75
+
76
+ class Config:
77
+ arbitrary_types_allowed = True
68
78
 
69
79
 
70
80
  class SamplingController:
@@ -100,6 +110,9 @@ class SamplingController:
100
110
  if record.base_model and record.base_model not in self._base_backends:
101
111
  invalid_sessions.append(record.sampling_session_id)
102
112
  continue
113
+ # Initialize executor with next_sequence_id based on max_submitted_seq_id
114
+ # to avoid hanging on new requests after restore
115
+ record.executor = SequenceExecutor(next_sequence_id=record.max_submitted_seq_id + 1)
103
116
  self.sampling_sessions[record.sampling_session_id] = record
104
117
  for session_id in invalid_sessions:
105
118
  store.delete(self._build_key(session_id))
@@ -181,8 +194,10 @@ class SamplingController:
181
194
  if model_path:
182
195
  # model_path should have higher priority than base_model
183
196
  try:
197
+ assert self.config.checkpoint_dir is not None
184
198
  parsed_checkpoint = CheckpointRecord.from_tinker_path(
185
- model_path, self.config.checkpoint_dir
199
+ model_path,
200
+ self.config.checkpoint_dir,
186
201
  )
187
202
  except FileNotFoundError as exc:
188
203
  raise CheckpointNotFoundException(checkpoint_id=model_path) from exc
@@ -221,7 +236,8 @@ class SamplingController:
221
236
  model_path=str(adapter_path) if adapter_path else None,
222
237
  session_seq_id=session_seq_id,
223
238
  )
224
- self._save_session(sampling_session_id)
239
+ loop = asyncio.get_event_loop()
240
+ await loop.run_in_executor(None, self._save_session, sampling_session_id)
225
241
 
226
242
  # Update metrics
227
243
  get_metrics().sampling_sessions_active.add(1, {"base_model": base_model_ref or ""})
@@ -236,11 +252,9 @@ class SamplingController:
236
252
  tokens = ",".join(str(token) for token in prompt.to_ints())
237
253
  return hashlib.sha1(tokens.encode("utf-8")).hexdigest()[:16]
238
254
 
239
- def _record_sequence(
255
+ async def _record_sequence(
240
256
  self, record: SamplingSessionRecord, seq_id: int, prompt: types.ModelInput
241
257
  ) -> None:
242
- if seq_id <= record.last_seq_id:
243
- raise SequenceConflictException(expected=record.last_seq_id + 1, got=seq_id)
244
258
  record.last_seq_id = seq_id
245
259
  entry = SamplingHistoryEntry(
246
260
  seq_id=seq_id,
@@ -248,9 +262,10 @@ class SamplingController:
248
262
  prompt_hash=self._hash_prompt(prompt),
249
263
  )
250
264
  record.history.append(entry)
251
- self._save_session(record.sampling_session_id)
265
+ loop = asyncio.get_event_loop()
266
+ await loop.run_in_executor(None, self._save_session, record.sampling_session_id)
252
267
 
253
- def _resolve_backend(
268
+ async def _resolve_backend(
254
269
  self, request: types.SampleRequest, user_id: str
255
270
  ) -> Tuple[BaseSamplingBackend, str | None]:
256
271
  """Resolve the appropriate backend for the sampling request.
@@ -269,7 +284,18 @@ class SamplingController:
269
284
  raise UserMismatchException()
270
285
  if request.seq_id is None:
271
286
  raise MissingSequenceIDException()
272
- self._record_sequence(record, request.seq_id, request.prompt)
287
+ # Track the maximum submitted seq_id for recovery purposes
288
+ if request.seq_id > record.max_submitted_seq_id:
289
+ record.max_submitted_seq_id = request.seq_id
290
+ loop = asyncio.get_event_loop()
291
+ await loop.run_in_executor(None, self._save_session, record.sampling_session_id)
292
+ await record.executor.submit(
293
+ sequence_id=request.seq_id,
294
+ func=self._record_sequence,
295
+ record=record,
296
+ seq_id=request.seq_id,
297
+ prompt=request.prompt,
298
+ )
273
299
  if record.base_model not in self._base_backends:
274
300
  raise UnknownModelException(model_name=record.base_model)
275
301
  if record.model_path is None:
@@ -297,7 +323,7 @@ class SamplingController:
297
323
  logger.info("Sampling begin for %s", sampling_session_id)
298
324
  start_time = time.perf_counter()
299
325
 
300
- backend, lora_id = self._resolve_backend(request, user_id=user_id)
326
+ backend, lora_id = await self._resolve_backend(request, user_id=user_id)
301
327
  prompt = request.prompt
302
328
  sampling_params = request.sampling_params
303
329
  num_samples = request.num_samples
@@ -0,0 +1,72 @@
1
+ import asyncio
2
+ import heapq
3
+ from typing import Any, Callable
4
+
5
+ from .exceptions import SequenceConflictException, SequenceTimeoutException
6
+
7
+
8
class SequenceExecutor:
    """Run submitted coroutines strictly in ``sequence_id`` order.

    Out-of-order tasks are buffered in a min-heap until every lower sequence
    id has been processed. A given ``sequence_id`` may be submitted at most
    once: ids already processed or already pending are rejected.
    """

    def __init__(self, timeout: float = 900, next_sequence_id: int = 0) -> None:
        """
        Args:
            timeout: Seconds each ``submit`` call waits for its result.
            next_sequence_id: The first sequence id that will be executed.
        """
        # Entries are (sequence_id, func, kwargs, future). Ids in the heap are
        # unique (enforced in ``submit``), so tuple comparison never falls
        # through to the non-orderable ``func`` element.
        self.pending_heap = []
        self.heap_lock = asyncio.Lock()
        self.next_sequence_id = next_sequence_id
        self._processing = False
        self.timeout = timeout
        # Ids currently buffered in ``pending_heap``; used to reject duplicates.
        self._pending_ids = set()
        # Strong reference to the drain task. The event loop keeps only weak
        # references to tasks, so an unreferenced task can be garbage-collected.
        self._process_task = None

    async def submit(self, sequence_id: int, func: Callable, **kwargs) -> Any:
        """Submit a task with a specific sequence_id.

        Args:
            sequence_id (int): The sequence ID of the task.
            func (Callable): The async function to execute.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            Any: The result of the function execution.

        Raises:
            SequenceTimeoutException: If the task times out waiting for its turn,
                or the wait is cancelled.
            SequenceConflictException: If ``sequence_id`` has already been
                processed or is already pending.
        """
        if sequence_id < self.next_sequence_id:
            raise SequenceConflictException(expected=self.next_sequence_id, got=sequence_id)
        future = asyncio.get_running_loop().create_future()
        async with self.heap_lock:
            # Re-check under the lock: next_sequence_id may have advanced while we
            # waited to acquire it. A stale or duplicate id left in the heap
            # would sit at the head forever and block every later task.
            if sequence_id < self.next_sequence_id or sequence_id in self._pending_ids:
                raise SequenceConflictException(expected=self.next_sequence_id, got=sequence_id)
            heapq.heappush(self.pending_heap, (sequence_id, func, kwargs, future))
            self._pending_ids.add(sequence_id)
            # Start draining if no drain task is currently running.
            if not self._processing:
                self._processing = True
                self._process_task = asyncio.create_task(self._process_tasks())
        try:
            result = await asyncio.wait_for(future, timeout=self.timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError) as e:
            # NOTE(review): cancellation of the caller is also reported as a
            # timeout; this preserves the exception type callers already catch.
            raise SequenceTimeoutException(sequence_id) from e
        return result

    async def _process_tasks(self) -> None:
        """Execute buffered tasks while the heap head is the next expected id."""
        try:
            while True:
                async with self.heap_lock:
                    if not self.pending_heap or self.pending_heap[0][0] != self.next_sequence_id:
                        # Nothing runnable yet; a later submit restarts draining.
                        self._processing = False
                        return
                    sequence_id, func, kwargs, future = heapq.heappop(self.pending_heap)
                    self._pending_ids.discard(sequence_id)
                    self.next_sequence_id += 1
                # Run the task outside the lock so new submissions are not blocked.
                try:
                    result = await func(**kwargs)
                except asyncio.CancelledError:
                    # Release the waiting caller, then propagate the cancellation.
                    if not future.done():
                        future.cancel()
                    raise
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)
                else:
                    if not future.done():
                        future.set_result(result)
        except BaseException:
            # Unexpected exit (e.g. this task was cancelled). Clear the flag so a
            # later submit can start a new drain instead of hanging forever.
            self._processing = False
            raise