tuft 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__main__.py +7 -0
- tuft/backends/hf_training_model.py +184 -64
- tuft/cli.py +161 -8
- tuft/config.py +63 -59
- tuft/exceptions.py +66 -0
- tuft/futures.py +22 -2
- tuft/loss_fn/__init__.py +33 -0
- tuft/persistence/__init__.py +10 -2
- tuft/persistence/redis_store.py +352 -31
- tuft/sampling_controller.py +37 -11
- tuft/sequence_executor.py +72 -0
- tuft/server.py +9 -2
- tuft/state.py +3 -0
- tuft/training_controller.py +20 -5
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/METADATA +10 -66
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/RECORD +19 -17
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/WHEEL +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/entry_points.txt +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/licenses/LICENSE +0 -0
tuft/persistence/redis_store.py
CHANGED
|
@@ -8,9 +8,14 @@ Key Design:
|
|
|
8
8
|
- Nested records: {namespace}::{type}::{parent_id}::{nested_type}::{nested_id}
|
|
9
9
|
|
|
10
10
|
Persistence Modes:
|
|
11
|
-
-
|
|
12
|
-
-
|
|
13
|
-
-
|
|
11
|
+
- DISABLE: No persistence, all data is in-memory only
|
|
12
|
+
- REDIS: Use external Redis server via URL
|
|
13
|
+
- FILE: Use file-backed storage for tests and demos
|
|
14
|
+
|
|
15
|
+
Config Validation:
|
|
16
|
+
- On startup, the current config signature is compared with the stored signature
|
|
17
|
+
- If mismatch is detected, server stops with an error message
|
|
18
|
+
- Use `tuft clear persistence` to override and clear existing data
|
|
14
19
|
"""
|
|
15
20
|
|
|
16
21
|
from __future__ import annotations
|
|
@@ -19,12 +24,12 @@ import logging
|
|
|
19
24
|
import os
|
|
20
25
|
import threading
|
|
21
26
|
import time
|
|
22
|
-
from
|
|
27
|
+
from datetime import datetime, timezone
|
|
23
28
|
from enum import Enum
|
|
24
29
|
from pathlib import Path
|
|
25
30
|
from typing import Any, TypeVar
|
|
26
31
|
|
|
27
|
-
from pydantic import BaseModel
|
|
32
|
+
from pydantic import BaseModel, Field
|
|
28
33
|
|
|
29
34
|
|
|
30
35
|
logger = logging.getLogger(__name__)
|
|
@@ -50,71 +55,104 @@ T = TypeVar("T", bound=BaseModel)
|
|
|
50
55
|
class PersistenceMode(str, Enum):
    """Where persisted server state is kept."""

    DISABLE = "DISABLE"  # Persistence off; all data stays in memory
    REDIS = "REDIS"  # External Redis server reached via URL
    FILE = "FILE"  # File-backed store, intended for tests and demos


# Default TTL values in seconds
DEFAULT_FUTURE_TTL_SECONDS = 24 * 60 * 60  # Future records are short-lived: keep them one day
|
|
60
65
|
|
|
61
66
|
|
|
62
|
-
|
|
63
|
-
|
|
67
|
+
class ConfigCheckField:
    """Names of AppConfig settings that can be compared on restart.

    Each name is the upper-case spelling of an AppConfig attribute
    (ConfigSignature lower-cases it to look the value up). SUPPORTED_MODELS is
    mandatory and is always compared, for restore safety.
    """

    SUPPORTED_MODELS = "SUPPORTED_MODELS"
    CHECKPOINT_DIR = "CHECKPOINT_DIR"
    MODEL_OWNER = "MODEL_OWNER"
    TOY_BACKEND_SEED = "TOY_BACKEND_SEED"
    AUTHORIZED_USERS = "AUTHORIZED_USERS"
    TELEMETRY = "TELEMETRY"


# Fields compared when nothing else is configured (supported_models is mandatory).
DEFAULT_CHECK_FIELDS: list[str] = [ConfigCheckField.SUPPORTED_MODELS]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class PersistenceConfig(BaseModel):
    """Settings that control whether and where server state is persisted.

    Attributes:
        mode: Which backend to use - DISABLE, REDIS, or FILE.
        redis_url: URL of the Redis server; only consulted when mode=REDIS.
        file_path: Path of the JSON file; only consulted when mode=FILE.
        namespace: Prefix for stored keys. Defaults to "persistence-tuft-server".
        future_ttl_seconds: Expiry for future records in seconds; None means they never expire.
        check_fields: AppConfig fields compared against the stored signature on
            restart. Defaults to ["SUPPORTED_MODELS"]; SUPPORTED_MODELS is
            compared even when it is omitted here, for restore safety.
            Available fields: SUPPORTED_MODELS, CHECKPOINT_DIR, MODEL_OWNER,
            TOY_BACKEND_SEED, AUTHORIZED_USERS, TELEMETRY.
    """

    # Permit non-pydantic field types such as Path.
    model_config = {"arbitrary_types_allowed": True}

    mode: PersistenceMode = PersistenceMode.DISABLE
    redis_url: str = "redis://localhost:6379/0"
    file_path: Path | None = None
    namespace: str = "persistence-tuft-server"  # Default namespace for Redis keys
    future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS
    check_fields: list[str] = Field(default_factory=lambda: list(DEFAULT_CHECK_FIELDS))

    @property
    def enabled(self) -> bool:
        """Whether any persistence backend is active (mode is not DISABLE)."""
        return self.mode != PersistenceMode.DISABLE

    def get_check_fields(self) -> list[str]:
        """Return a copy of check_fields, prepending SUPPORTED_MODELS if it is absent."""
        required = ConfigCheckField.SUPPORTED_MODELS
        if required in self.check_fields:
            return list(self.check_fields)
        return [required, *self.check_fields]

    @classmethod
    def from_redis_url(
        cls,
        redis_url: str,
        namespace: str = "persistence-tuft-server",
        future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS,
        check_fields: list[str] | None = None,
    ) -> "PersistenceConfig":
        """Build a config that persists to the Redis server at ``redis_url``.

        An empty or missing ``check_fields`` falls back to DEFAULT_CHECK_FIELDS.
        """
        if not check_fields:
            check_fields = list(DEFAULT_CHECK_FIELDS)
        return cls(
            mode=PersistenceMode.REDIS,
            redis_url=redis_url,
            namespace=namespace,
            future_ttl_seconds=future_ttl_seconds,
            check_fields=check_fields,
        )

    @classmethod
    def from_file(
        cls,
        file_path: Path | None = None,
        namespace: str = "persistence-tuft-server",
        future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS,
        check_fields: list[str] | None = None,
    ) -> "PersistenceConfig":
        """Build a config that persists to a file-backed store at ``file_path``.

        An empty or missing ``check_fields`` falls back to DEFAULT_CHECK_FIELDS.
        """
        if not check_fields:
            check_fields = list(DEFAULT_CHECK_FIELDS)
        return cls(
            mode=PersistenceMode.FILE,
            file_path=file_path,
            namespace=namespace,
            future_ttl_seconds=future_ttl_seconds,
            check_fields=check_fields,
        )
|
|
119
157
|
|
|
120
158
|
|
|
@@ -123,7 +161,7 @@ class RedisStore:
|
|
|
123
161
|
|
|
124
162
|
Supports two modes:
|
|
125
163
|
- External Redis server (via redis-py)
|
|
126
|
-
- No persistence (
|
|
164
|
+
- No persistence (DISABLE mode)
|
|
127
165
|
"""
|
|
128
166
|
|
|
129
167
|
_instance: "RedisStore | None" = None
|
|
@@ -166,7 +204,7 @@ class RedisStore:
|
|
|
166
204
|
if self._redis is None or self._pid != current_pid:
|
|
167
205
|
self._close_connections()
|
|
168
206
|
|
|
169
|
-
if self._config.mode in (PersistenceMode.
|
|
207
|
+
if self._config.mode in (PersistenceMode.REDIS, PersistenceMode.FILE):
|
|
170
208
|
logger.info("Redis connection begin")
|
|
171
209
|
self._redis = self._create_redis_client()
|
|
172
210
|
|
|
@@ -181,7 +219,7 @@ class RedisStore:
|
|
|
181
219
|
if self._config is None:
|
|
182
220
|
return None
|
|
183
221
|
try:
|
|
184
|
-
if self._config.mode == PersistenceMode.
|
|
222
|
+
if self._config.mode == PersistenceMode.FILE:
|
|
185
223
|
from .file_redis import FileRedis
|
|
186
224
|
|
|
187
225
|
file_path = self._config.file_path or (
|
|
@@ -201,7 +239,7 @@ class RedisStore:
|
|
|
201
239
|
|
|
202
240
|
@property
|
|
203
241
|
def namespace(self) -> str:
|
|
204
|
-
return self._config.namespace if self._config else "tuft"
|
|
242
|
+
return self._config.namespace if self._config else "persistence-tuft-server"
|
|
205
243
|
|
|
206
244
|
@property
|
|
207
245
|
def future_ttl(self) -> int | None:
|
|
@@ -486,3 +524,286 @@ def is_persistence_enabled() -> bool:
|
|
|
486
524
|
def get_redis_store() -> RedisStore:
    """Return the process-wide RedisStore singleton (RedisStore.get_instance())."""
    singleton = RedisStore.get_instance()
    return singleton
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class ConfigSignature(BaseModel):
    """Snapshot of an AppConfig used to detect config changes across restarts.

    ``config_data`` holds whatever ``AppConfig.get_config_for_persistence()``
    returns (intended to exclude the runtime-only persistence section); the
    remaining fields record when and under which namespace it was taken.
    """

    # Serialized AppConfig data (excludes persistence)
    config_data: dict[str, Any] = Field(default_factory=dict)

    # Metadata
    created_at: datetime = Field(default_factory=_now)
    namespace: str = "persistence-tuft-server"

    @classmethod
    def from_app_config(cls, config: Any) -> "ConfigSignature":
        """Create a signature by serializing the AppConfig.

        The result is normalized through the same JSON round-trip a stored
        signature goes through, so a fresh signature compares type-for-type
        with one loaded from storage.
        """
        config_data = config.get_config_for_persistence()
        namespace = (
            config.persistence.namespace if config.persistence else "persistence-tuft-server"
        )
        signature = cls(config_data=config_data, namespace=namespace)
        # A stored signature has been through model_dump_json/model_validate_json,
        # which turns tuples into lists and non-JSON values (e.g. paths) into
        # strings. Apply the same round-trip here so an unchanged config is not
        # reported as a mismatch; keep the raw data if it cannot be serialized.
        try:
            return cls.model_validate_json(signature.model_dump_json())
        except ValueError:
            return signature

    def _get_field_value(self, field_name: str) -> Any:
        """Look up a check-field name (e.g. "SUPPORTED_MODELS") in config_data by its lower-case key."""
        return self.config_data.get(field_name.lower())

    def _normalize_for_comparison(self, value: Any) -> Any:
        """Make lists order-independent; dict items inside a list are compared as sorted item tuples."""
        if isinstance(value, list):
            normalized_items = []
            for item in value:
                if isinstance(item, dict):
                    normalized_items.append(tuple(sorted(item.items())))
                else:
                    normalized_items.append(item)
            # Sort for order-independent comparison
            return sorted(normalized_items, key=lambda x: str(x))
        return value

    def _get_field_diff(self, other: "ConfigSignature", field_name: str) -> dict[str, Any] | None:
        """Get the difference for a single field.

        Returns:
            {"current": value, "stored": value} if different, None otherwise.
        """
        current_value = self._get_field_value(field_name)
        other_value = other._get_field_value(field_name)
        if self._normalize_for_comparison(current_value) != self._normalize_for_comparison(
            other_value
        ):
            return {"current": current_value, "stored": other_value}
        return None

    def _compare_field(self, other: "ConfigSignature", field_name: str) -> bool:
        """Return True when a single field is equal in both signatures."""
        return self._get_field_diff(other, field_name) is None

    def matches(
        self,
        other: "ConfigSignature",
        check_fields: list[str] | None = None,
    ) -> bool:
        """Check if this signature matches another signature.

        Args:
            other: The other signature to compare against.
            check_fields: List of field names to check. If None, uses DEFAULT_CHECK_FIELDS.
                SUPPORTED_MODELS is always included (mandatory).

        Returns:
            True if all specified fields match, False otherwise.
        """
        return all(
            self._compare_field(other, field_name)
            for field_name in self._get_fields_to_check(check_fields)
        )

    def get_diff(
        self,
        other: "ConfigSignature",
        check_fields: list[str] | None = None,
    ) -> dict[str, dict[str, Any]]:
        """Get the differences between this signature and another.

        Args:
            other: The other signature to compare against.
            check_fields: List of field names to check. If None, uses DEFAULT_CHECK_FIELDS.
                SUPPORTED_MODELS is always included (mandatory).

        Returns:
            Dict mapping field names to {"current": ..., "stored": ...} for
            every checked field that differs.
        """
        diff: dict[str, dict[str, Any]] = {}
        for field_name in self._get_fields_to_check(check_fields):
            field_diff = self._get_field_diff(other, field_name)
            if field_diff is not None:
                diff[field_name] = field_diff
        return diff

    def _get_fields_to_check(self, check_fields: list[str] | None) -> list[str]:
        """Get the list of fields to check, ensuring mandatory fields are included."""
        if check_fields is None:
            return DEFAULT_CHECK_FIELDS.copy()

        # Ensure SUPPORTED_MODELS is always included (mandatory)
        fields = list(check_fields)
        if ConfigCheckField.SUPPORTED_MODELS not in fields:
            fields.insert(0, ConfigCheckField.SUPPORTED_MODELS)
        return fields
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
# Key name passed to RedisStore.build_key() for the stored config signature.
CONFIG_SIGNATURE_KEY = "config_signature"
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def save_config_signature(config: Any) -> bool:
    """Persist a signature of ``config`` under the namespaced config-signature key.

    Args:
        config: The AppConfig to snapshot.

    Returns:
        The result of the store write, or False when persistence is disabled or
        the write raised (the error is logged).
    """
    store = RedisStore.get_instance()
    if store.is_enabled:
        signature = ConfigSignature.from_app_config(config)
        key = store.build_key(CONFIG_SIGNATURE_KEY)
        try:
            return store.set(key, signature.model_dump_json())
        except Exception:
            logger.exception("Failed to save config signature to Redis")
            return False
    return False
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def load_config_signature() -> ConfigSignature | None:
    """Read the stored config signature back from the store.

    Returns:
        The parsed ConfigSignature, or None when persistence is disabled, no
        signature is stored, or reading/parsing failed (the error is logged).
    """
    store = RedisStore.get_instance()
    if not store.is_enabled:
        return None

    key = store.build_key(CONFIG_SIGNATURE_KEY)
    try:
        raw = store.get(key)
        return None if raw is None else ConfigSignature.model_validate_json(raw)
    except Exception:
        logger.exception("Failed to load config signature from Redis")
        return None
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def has_existing_data() -> bool:
    """Report whether any key lives under the current namespace.

    Returns:
        True if at least one ``<namespace>::*`` key exists; False when there are
        none or persistence is disabled.
    """
    store = RedisStore.get_instance()
    if store.is_enabled:
        matching = store.keys(f"{store.namespace}::*")
        return len(matching) > 0
    return False
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def validate_config_signature(config: Any) -> bool:
    """Check the running config against the signature stored in Redis.

    Read-only: nothing is written here. The signature is expected to be saved
    with save_config_signature() once a restore has succeeded.

    The compared fields come from config.persistence.check_fields;
    SUPPORTED_MODELS is compared in every case.

    Outcomes:
        * signature present and matching   -> returns False (existing state)
        * signature present, mismatching   -> raises ConfigMismatchError
        * no signature, namespace empty    -> returns True (fresh start)
        * no signature, namespace has data -> raises ConfigMismatchError
          (corrupted or legacy state)

    Args:
        config: The current AppConfig to validate.

    Returns:
        True for a fresh start (nothing stored yet), False otherwise.

    Raises:
        ConfigMismatchError: If the configs don't match or state is corrupted.
    """
    from tuft.exceptions import ConfigMismatchError

    stored = load_config_signature()

    if stored is not None:
        check_fields = config.persistence.get_check_fields() if config.persistence else None
        current = ConfigSignature.from_app_config(config)
        if not current.matches(stored, check_fields=check_fields):
            raise ConfigMismatchError(current.get_diff(stored, check_fields=check_fields))
        logger.debug("Config signature validated successfully")
        return False

    if not has_existing_data():
        # Nothing stored at all: brand-new deployment.
        logger.info("No stored config signature found - fresh start")
        return True

    # Keys exist but were never signed: refuse to reuse them.
    logger.warning(
        "Redis namespace has data but no config signature. "
        "This indicates a corrupted or incompatible persistence state."
    )
    raise ConfigMismatchError(
        diff={
            "_state": {
                "current": "valid configuration",
                "stored": "missing signature (corrupted or legacy data)",
            }
        }
    )
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def get_current_namespace() -> str:
    """Return the key namespace of the global RedisStore.

    Returns:
        The configured namespace, or "persistence-tuft-server" when the store
        has no configuration.
    """
    return RedisStore.get_instance().namespace
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def flush_all_data() -> tuple[int, str]:
    """Delete every ``<namespace>::*`` key of the current namespace.

    Destructive and irreversible - use with caution! Does nothing when
    persistence is disabled.

    Returns:
        (number of keys deleted, namespace that was cleared).
    """
    store = RedisStore.get_instance()
    deleted = 0
    if store.is_enabled:
        deleted = store.delete_pattern(f"{store.namespace}::*")
    return deleted, store.namespace
|
tuft/sampling_controller.py
CHANGED
|
@@ -20,12 +20,17 @@ from .exceptions import (
|
|
|
20
20
|
CheckpointAccessDeniedException,
|
|
21
21
|
CheckpointNotFoundException,
|
|
22
22
|
MissingSequenceIDException,
|
|
23
|
-
SequenceConflictException,
|
|
24
23
|
SessionNotFoundException,
|
|
25
24
|
UnknownModelException,
|
|
26
25
|
UserMismatchException,
|
|
27
26
|
)
|
|
28
|
-
from .persistence import
|
|
27
|
+
from .persistence import (
|
|
28
|
+
get_redis_store,
|
|
29
|
+
is_persistence_enabled,
|
|
30
|
+
load_record,
|
|
31
|
+
save_record,
|
|
32
|
+
)
|
|
33
|
+
from .sequence_executor import SequenceExecutor
|
|
29
34
|
from .telemetry.metrics import get_metrics
|
|
30
35
|
from .telemetry.tracing import get_tracer
|
|
31
36
|
|
|
@@ -64,7 +69,12 @@ class SamplingSessionRecord(BaseModel):
|
|
|
64
69
|
model_path: str | None = None
|
|
65
70
|
session_seq_id: int
|
|
66
71
|
last_seq_id: int = -1
|
|
72
|
+
max_submitted_seq_id: int = -1
|
|
67
73
|
history: list[SamplingHistoryEntry] = Field(default_factory=list)
|
|
74
|
+
executor: SequenceExecutor = Field(default_factory=SequenceExecutor, exclude=True)
|
|
75
|
+
|
|
76
|
+
class Config:
|
|
77
|
+
arbitrary_types_allowed = True
|
|
68
78
|
|
|
69
79
|
|
|
70
80
|
class SamplingController:
|
|
@@ -100,6 +110,9 @@ class SamplingController:
|
|
|
100
110
|
if record.base_model and record.base_model not in self._base_backends:
|
|
101
111
|
invalid_sessions.append(record.sampling_session_id)
|
|
102
112
|
continue
|
|
113
|
+
# Initialize executor with next_sequence_id based on max_submitted_seq_id
|
|
114
|
+
# to avoid hanging on new requests after restore
|
|
115
|
+
record.executor = SequenceExecutor(next_sequence_id=record.max_submitted_seq_id + 1)
|
|
103
116
|
self.sampling_sessions[record.sampling_session_id] = record
|
|
104
117
|
for session_id in invalid_sessions:
|
|
105
118
|
store.delete(self._build_key(session_id))
|
|
@@ -181,8 +194,10 @@ class SamplingController:
|
|
|
181
194
|
if model_path:
|
|
182
195
|
# model_path should have higher priority than base_model
|
|
183
196
|
try:
|
|
197
|
+
assert self.config.checkpoint_dir is not None
|
|
184
198
|
parsed_checkpoint = CheckpointRecord.from_tinker_path(
|
|
185
|
-
model_path,
|
|
199
|
+
model_path,
|
|
200
|
+
self.config.checkpoint_dir,
|
|
186
201
|
)
|
|
187
202
|
except FileNotFoundError as exc:
|
|
188
203
|
raise CheckpointNotFoundException(checkpoint_id=model_path) from exc
|
|
@@ -221,7 +236,8 @@ class SamplingController:
|
|
|
221
236
|
model_path=str(adapter_path) if adapter_path else None,
|
|
222
237
|
session_seq_id=session_seq_id,
|
|
223
238
|
)
|
|
224
|
-
|
|
239
|
+
loop = asyncio.get_event_loop()
|
|
240
|
+
await loop.run_in_executor(None, self._save_session, sampling_session_id)
|
|
225
241
|
|
|
226
242
|
# Update metrics
|
|
227
243
|
get_metrics().sampling_sessions_active.add(1, {"base_model": base_model_ref or ""})
|
|
@@ -236,11 +252,9 @@ class SamplingController:
|
|
|
236
252
|
tokens = ",".join(str(token) for token in prompt.to_ints())
|
|
237
253
|
return hashlib.sha1(tokens.encode("utf-8")).hexdigest()[:16]
|
|
238
254
|
|
|
239
|
-
def _record_sequence(
|
|
255
|
+
async def _record_sequence(
|
|
240
256
|
self, record: SamplingSessionRecord, seq_id: int, prompt: types.ModelInput
|
|
241
257
|
) -> None:
|
|
242
|
-
if seq_id <= record.last_seq_id:
|
|
243
|
-
raise SequenceConflictException(expected=record.last_seq_id + 1, got=seq_id)
|
|
244
258
|
record.last_seq_id = seq_id
|
|
245
259
|
entry = SamplingHistoryEntry(
|
|
246
260
|
seq_id=seq_id,
|
|
@@ -248,9 +262,10 @@ class SamplingController:
|
|
|
248
262
|
prompt_hash=self._hash_prompt(prompt),
|
|
249
263
|
)
|
|
250
264
|
record.history.append(entry)
|
|
251
|
-
|
|
265
|
+
loop = asyncio.get_event_loop()
|
|
266
|
+
await loop.run_in_executor(None, self._save_session, record.sampling_session_id)
|
|
252
267
|
|
|
253
|
-
def _resolve_backend(
|
|
268
|
+
async def _resolve_backend(
|
|
254
269
|
self, request: types.SampleRequest, user_id: str
|
|
255
270
|
) -> Tuple[BaseSamplingBackend, str | None]:
|
|
256
271
|
"""Resolve the appropriate backend for the sampling request.
|
|
@@ -269,7 +284,18 @@ class SamplingController:
|
|
|
269
284
|
raise UserMismatchException()
|
|
270
285
|
if request.seq_id is None:
|
|
271
286
|
raise MissingSequenceIDException()
|
|
272
|
-
|
|
287
|
+
# Track the maximum submitted seq_id for recovery purposes
|
|
288
|
+
if request.seq_id > record.max_submitted_seq_id:
|
|
289
|
+
record.max_submitted_seq_id = request.seq_id
|
|
290
|
+
loop = asyncio.get_event_loop()
|
|
291
|
+
await loop.run_in_executor(None, self._save_session, record.sampling_session_id)
|
|
292
|
+
await record.executor.submit(
|
|
293
|
+
sequence_id=request.seq_id,
|
|
294
|
+
func=self._record_sequence,
|
|
295
|
+
record=record,
|
|
296
|
+
seq_id=request.seq_id,
|
|
297
|
+
prompt=request.prompt,
|
|
298
|
+
)
|
|
273
299
|
if record.base_model not in self._base_backends:
|
|
274
300
|
raise UnknownModelException(model_name=record.base_model)
|
|
275
301
|
if record.model_path is None:
|
|
@@ -297,7 +323,7 @@ class SamplingController:
|
|
|
297
323
|
logger.info("Sampling begin for %s", sampling_session_id)
|
|
298
324
|
start_time = time.perf_counter()
|
|
299
325
|
|
|
300
|
-
backend, lora_id = self._resolve_backend(request, user_id=user_id)
|
|
326
|
+
backend, lora_id = await self._resolve_backend(request, user_id=user_id)
|
|
301
327
|
prompt = request.prompt
|
|
302
328
|
sampling_params = request.sampling_params
|
|
303
329
|
num_samples = request.num_samples
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import heapq
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from .exceptions import SequenceConflictException, SequenceTimeoutException
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SequenceExecutor:
    """Run submitted coroutine functions strictly in ascending ``sequence_id`` order.

    Tasks that arrive ahead of their turn are buffered in a min-heap and run
    once every lower id has been processed. Each id can be used only once: ids
    below ``next_sequence_id``, and ids already waiting in the buffer, are
    rejected with ``SequenceConflictException``.
    """

    def __init__(self, timeout: float = 900, next_sequence_id: int = 0) -> None:
        """Create an executor.

        Args:
            timeout: Seconds a submitter waits for its task to run and finish
                before ``SequenceTimeoutException`` is raised.
            next_sequence_id: The first sequence id that will be executed.
        """
        # Entries are (sequence_id, func, kwargs, future). Ids in the heap are kept
        # unique (see _pending_ids), so tuple comparison never falls through to
        # comparing ``func``, which is not orderable.
        self.pending_heap = []
        self.heap_lock = asyncio.Lock()
        self.next_sequence_id = next_sequence_id
        self._processing = False
        self.timeout = timeout
        # Sequence ids currently buffered in pending_heap.
        self._pending_ids: set[int] = set()
        # Strong reference to the worker task: the event loop only keeps weak
        # references to tasks, so an unreferenced task may be garbage-collected.
        self._task = None

    async def submit(self, sequence_id: int, func: Callable, **kwargs) -> Any:
        """Submit a task with a specific sequence_id and wait for its result.

        Args:
            sequence_id (int): The sequence ID of the task.
            func (Callable): The async function to execute.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            Any: The result of the function execution. An exception raised by
            ``func`` is re-raised here.

        Raises:
            SequenceTimeoutException: If the task does not complete within ``timeout``.
            SequenceConflictException: If ``sequence_id`` has already been processed
                or is already waiting to be processed.
        """
        future = asyncio.get_running_loop().create_future()
        async with self.heap_lock:
            # Check under the lock so the check and the push see the same state.
            if sequence_id < self.next_sequence_id or sequence_id in self._pending_ids:
                raise SequenceConflictException(expected=self.next_sequence_id, got=sequence_id)
            heapq.heappush(self.pending_heap, (sequence_id, func, kwargs, future))
            self._pending_ids.add(sequence_id)
            # Start processing if not already running
            if not self._processing:
                self._processing = True
                self._task = asyncio.create_task(self._process_tasks())
        try:
            return await asyncio.wait_for(future, timeout=self.timeout)
        except asyncio.TimeoutError as exc:
            # Only a timeout is translated; cancellation of the caller propagates
            # unchanged so asyncio cancellation semantics keep working.
            raise SequenceTimeoutException(sequence_id) from exc

    async def _process_tasks(self) -> None:
        """Run buffered tasks while the smallest buffered id is the next expected one.

        Only one worker runs at a time (guarded by ``_processing``). It stops when
        the heap is empty or the next id has not been submitted yet; a later
        ``submit`` starts a new worker.
        """
        try:
            while True:
                async with self.heap_lock:
                    if not self.pending_heap or self.pending_heap[0][0] != self.next_sequence_id:
                        # Nothing runnable: idle, or waiting for a missing id.
                        self._processing = False
                        return
                    sequence_id, func, kwargs, future = heapq.heappop(self.pending_heap)
                    self._pending_ids.discard(sequence_id)
                    self.next_sequence_id += 1
                # Run outside the lock so new submissions are not blocked meanwhile.
                try:
                    result = await func(**kwargs)
                except Exception as exc:
                    if not future.done():
                        future.set_exception(exc)
                else:
                    # The waiter may already have timed out (future cancelled).
                    if not future.done():
                        future.set_result(result)
        except BaseException:
            # Abnormal exit (e.g. this task was cancelled): clear the flag so a
            # later submit() can start a new worker instead of stalling forever.
            self._processing = False
            raise