tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/loss_fn/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Callable, Dict, Tuple
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
from ..exceptions import (
|
|
7
|
+
LossFunctionInputShapeMismatchException,
|
|
8
|
+
LossFunctionMissingInputException,
|
|
9
|
+
LossFunctionNotFoundException,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
LossFnType: TypeAlias = Callable[
|
|
14
|
+
[Dict[str, Tensor], Dict[str, float]], Tuple[Tensor, Dict[str, float]]
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
LOSS_FN = {
|
|
18
|
+
"cispo": "tuft.loss_fn.cispo.cispo_loss",
|
|
19
|
+
"cross_entropy": "tuft.loss_fn.cross_entropy.cross_entropy_loss",
|
|
20
|
+
"dro": "tuft.loss_fn.dro.dro_loss",
|
|
21
|
+
"importance_sampling": "tuft.loss_fn.importance_sampling.importance_sampling_loss",
|
|
22
|
+
"ppo": "tuft.loss_fn.ppo.ppo_loss",
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_loss_fn(loss_fn_name: str) -> LossFnType:
    """Retrieve the loss function registered under ``loss_fn_name``.

    The implementation module listed in ``LOSS_FN`` is imported on demand, so
    only the loss functions actually requested are ever loaded.

    Args:
        loss_fn_name: A key of ``LOSS_FN`` (e.g. ``"ppo"``).

    Returns:
        The callable found at the registered dotted path.

    Raises:
        LossFunctionNotFoundException: ``loss_fn_name`` is not registered.
    """
    # Local import keeps this package's top-level imports unchanged.
    import importlib

    try:
        dotted_path = LOSS_FN[loss_fn_name]
    except KeyError:
        # Suppress the KeyError context; the domain exception is the whole story.
        raise LossFunctionNotFoundException(loss_fn_name) from None

    module_path, func_name = dotted_path.rsplit(".", 1)
    # importlib.import_module returns the leaf module directly, unlike bare
    # __import__, which needs the ``fromlist`` trick to do the same.
    module = importlib.import_module(module_path)
    return getattr(module, func_name)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _check_loss_fn_inputs(
|
|
37
|
+
loss_fn_inputs: Dict[str, Tensor], required_keys: Tuple[str, ...], check_shapes: bool = False
|
|
38
|
+
) -> None:
|
|
39
|
+
"""Check if all required keys are present in loss_fn_inputs and optionally
|
|
40
|
+
check if their shapes match."""
|
|
41
|
+
for key in required_keys:
|
|
42
|
+
if key not in loss_fn_inputs:
|
|
43
|
+
raise LossFunctionMissingInputException(key)
|
|
44
|
+
|
|
45
|
+
if check_shapes:
|
|
46
|
+
shapes = [loss_fn_inputs[key].shape for key in required_keys]
|
|
47
|
+
if not all(shape == shapes[0] for shape in shapes):
|
|
48
|
+
raise LossFunctionInputShapeMismatchException(shapes)
|
tuft/loss_fn/cispo.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _check_loss_fn_inputs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def cispo_loss(
    loss_fn_inputs: Dict[str, torch.Tensor], loss_fn_config: Dict[str, float]
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the Clipped Importance Sampling Policy Optimization (CISPO) loss.

    Each target log-probability is multiplied by its advantage and by the
    importance ratio ``exp(target_logprobs - logprobs)``. The ratio is clipped
    to ``[clip_low_threshold, clip_high_threshold]`` and detached, so it acts as
    a constant weight and gradient flows only through ``target_logprobs``.

    Args:
        loss_fn_inputs: Tensors under "target_logprobs", "logprobs" and
            "advantages". All three must share one shape.
        loss_fn_config: Optional "clip_low_threshold" (default 0.9) and
            "clip_high_threshold" (default 1.1).

    Returns:
        ``(loss, metrics)``. ``loss`` is the negated sum of the weighted
        objective. ``metrics`` is ``{"loss:sum": <loss as a Python float>}``.
    """
    required = ("target_logprobs", "logprobs", "advantages")
    _check_loss_fn_inputs(loss_fn_inputs, required, check_shapes=True)
    new_logp, old_logp, adv = (loss_fn_inputs[name] for name in required)
    low = loss_fn_config.get("clip_low_threshold", 0.9)
    high = loss_fn_config.get("clip_high_threshold", 1.1)

    # Clipped importance weight, treated as a constant (stop-gradient).
    weight = torch.exp(new_logp - old_logp).clamp(low, high).detach()
    loss = -(weight * new_logp * adv).sum()
    return loss, {"loss:sum": loss.item()}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _check_loss_fn_inputs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def cross_entropy_loss(
    loss_fn_inputs: Dict[str, torch.Tensor], loss_fn_config: Dict[str, float]
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the weighted cross-entropy (negative log-likelihood) loss.

    Args:
        loss_fn_inputs: Tensors under "target_logprobs" and "weights", which
            must share one shape.
        loss_fn_config: Accepted for a uniform loss-function signature; not read.

    Returns:
        ``(loss, metrics)``. ``loss`` is ``-sum(target_logprobs * weights)``.
        ``metrics`` is ``{"loss:sum": <loss as a Python float>}``.
    """
    required = ("target_logprobs", "weights")
    _check_loss_fn_inputs(loss_fn_inputs, required, check_shapes=True)
    logp, per_token_weight = (loss_fn_inputs[name] for name in required)

    weighted_logp = logp * per_token_weight
    loss = -weighted_logp.sum()
    return loss, {"loss:sum": loss.item()}
|
tuft/loss_fn/dro.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _check_loss_fn_inputs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def dro_loss(
    loss_fn_inputs: Dict[str, torch.Tensor], loss_fn_config: Dict[str, float]
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the Distributionally Robust Optimization (DRO) loss.

    The objective is ``target_logprobs * advantages`` minus a quadratic penalty
    ``0.5 * beta * (target_logprobs - logprobs) ** 2``. The loss is the negated
    sum of that objective.

    Args:
        loss_fn_inputs: Tensors under "target_logprobs", "logprobs" and
            "advantages". All three must share one shape.
        loss_fn_config: Optional "beta" penalty coefficient (default 0.01).

    Returns:
        ``(loss, metrics)``. ``metrics`` is
        ``{"loss:sum": <loss as a Python float>}``.
    """
    required = ("target_logprobs", "logprobs", "advantages")
    _check_loss_fn_inputs(loss_fn_inputs, required, check_shapes=True)
    new_logp, old_logp, adv = (loss_fn_inputs[name] for name in required)
    beta = loss_fn_config.get("beta", 0.01)

    # Quadratic pull of the target log-probs back towards the sampling log-probs.
    penalty = 0.5 * beta * (new_logp - old_logp) ** 2
    loss = -(new_logp * adv - penalty).sum()
    return loss, {"loss:sum": loss.item()}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _check_loss_fn_inputs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def importance_sampling_loss(
    loss_fn_inputs: Dict[str, torch.Tensor], loss_fn_config: Dict[str, float]
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the (unclipped) importance sampling policy-gradient loss.

    Advantages are reweighted by the ratio ``exp(target_logprobs - logprobs)``.
    The loss is the negated sum of the reweighted advantages.

    Args:
        loss_fn_inputs: Tensors under "target_logprobs", "logprobs" and
            "advantages". All three must share one shape.
        loss_fn_config: Accepted for a uniform loss-function signature; not read.

    Returns:
        ``(loss, metrics)``. ``metrics`` is
        ``{"loss:sum": <loss as a Python float>}``.
    """
    required = ("target_logprobs", "logprobs", "advantages")
    _check_loss_fn_inputs(loss_fn_inputs, required, check_shapes=True)
    new_logp, old_logp, adv = (loss_fn_inputs[name] for name in required)

    ratio = torch.exp(new_logp - old_logp)
    loss = -(ratio * adv).sum()
    return loss, {"loss:sum": loss.item()}
|
tuft/loss_fn/ppo.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _check_loss_fn_inputs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def ppo_loss(
    loss_fn_inputs: Dict[str, torch.Tensor], loss_fn_config: Dict[str, float]
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the Proximal Policy Optimization (PPO) clipped-surrogate loss.

    For each element, the surrogate is the smaller of ``ratio * advantages``
    and ``clip(ratio) * advantages``. Here ``ratio = exp(target_logprobs -
    logprobs)`` and the clip range is
    ``[clip_low_threshold, clip_high_threshold]``. The loss is the negated sum
    of the surrogate.

    Args:
        loss_fn_inputs: Tensors under "target_logprobs", "logprobs" and
            "advantages". All three must share one shape.
        loss_fn_config: Optional "clip_low_threshold" (default 0.9) and
            "clip_high_threshold" (default 1.1).

    Returns:
        ``(loss, metrics)``. ``metrics`` is
        ``{"loss:sum": <loss as a Python float>}``.
    """
    required = ("target_logprobs", "logprobs", "advantages")
    _check_loss_fn_inputs(loss_fn_inputs, required, check_shapes=True)
    new_logp, old_logp, adv = (loss_fn_inputs[name] for name in required)
    low = loss_fn_config.get("clip_low_threshold", 0.9)
    high = loss_fn_config.get("clip_high_threshold", 1.1)

    ratio = torch.exp(new_logp - old_logp)
    # The element-wise minimum keeps the more pessimistic of the two surrogates.
    surrogate = torch.min(ratio * adv, ratio.clamp(low, high) * adv)
    loss = -surrogate.sum()
    return loss, {"loss:sum": loss.item()}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Persistence package exports."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from .redis_store import (
|
|
6
|
+
DEFAULT_FUTURE_TTL_SECONDS,
|
|
7
|
+
PersistenceConfig,
|
|
8
|
+
PersistenceMode,
|
|
9
|
+
RedisPipeline,
|
|
10
|
+
RedisStore,
|
|
11
|
+
delete_record,
|
|
12
|
+
get_redis_store,
|
|
13
|
+
is_persistence_enabled,
|
|
14
|
+
load_record,
|
|
15
|
+
save_record,
|
|
16
|
+
save_records_atomic,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Public API of the persistence package: exactly the names re-exported from
# ``.redis_store`` by the import above. Keep the two lists in sync so
# ``from tuft.persistence import *`` exposes the same names.
__all__ = [
    "DEFAULT_FUTURE_TTL_SECONDS",
    "PersistenceConfig",
    "PersistenceMode",
    "RedisPipeline",
    "RedisStore",
    "delete_record",
    "get_redis_store",
    "is_persistence_enabled",
    "load_record",
    "save_record",
    "save_records_atomic",
]
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""File-backed Redis-like store for small demos and tests.
|
|
2
|
+
|
|
3
|
+
This module implements a minimal subset of redis-py behaviors with a JSON
|
|
4
|
+
backing file. It is designed for low-volume usage where performance is not a
|
|
5
|
+
concern. All write operations flush the full in-memory state to disk.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from tuft.persistence.file_redis import FileRedis
|
|
11
|
+
|
|
12
|
+
store = FileRedis(Path("~/.cache/tuft/file_redis.json").expanduser())
|
|
13
|
+
store.set("alpha", "1")
|
|
14
|
+
store.setex("beta", 5, "2")
|
|
15
|
+
assert store.get("alpha") == "1"
|
|
16
|
+
for key in store.scan_iter(match="a*"):
|
|
17
|
+
print(key)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import json
import logging
import threading
import time
from dataclasses import dataclass
from fnmatch import fnmatch, fnmatchcase
from pathlib import Path
from typing import Iterable
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)


@dataclass
class _FileRedisValue:
    """A stored string value plus its optional absolute expiry time."""

    value: str
    # Absolute ``time.time()`` timestamp at or after which the key counts as
    # expired; ``None`` means the key never expires.
    expires_at: float | None


class FileRedis:
    """Tiny file-backed Redis-like store for tests and demos.

    The whole store is held in memory as a dict and rewritten to ``file_path``
    as JSON after every mutation. A temporary file plus ``Path.replace`` is
    used, so readers never see a half-written file. All public methods take a
    single non-reentrant ``threading.Lock``.

    Args:
        file_path: Path to the JSON file used for persistence. Parent
            directories are created if missing.

    Example:
        from pathlib import Path

        store = FileRedis(Path("/tmp/file_redis.json"))
        store.set("key", "value")
        assert store.get("key") == "value"
    """

    def __init__(self, file_path: Path) -> None:
        self._file_path = Path(file_path)
        self._file_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._data: dict[str, _FileRedisValue] = {}
        self._load()

    def _load(self) -> None:
        """Load persisted data from disk into memory (best effort).

        An unreadable file, undecodable bytes, invalid JSON, or a top-level
        value that is not a JSON object is logged, and the store starts empty.
        Individual entries that are not objects, or whose ``expires_at`` is not
        numeric, are skipped.
        """
        if not self._file_path.exists():
            return
        try:
            raw = json.loads(self._file_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            logger.exception("Failed to load FileRedis data from %s", self._file_path)
            return
        if not isinstance(raw, dict):
            logger.error(
                "Ignoring FileRedis data in %s: expected a JSON object, got %s",
                self._file_path,
                type(raw).__name__,
            )
            return
        for key, payload in raw.items():
            if not isinstance(payload, dict):
                continue
            expires_at = payload.get("expires_at")
            if expires_at is not None:
                try:
                    expires_at = float(expires_at)
                except (TypeError, ValueError):
                    # A non-numeric expiry would later break the <= comparison
                    # in _purge_expired, so drop the entry now.
                    logger.warning(
                        "Skipping FileRedis key %r with invalid expires_at %r", key, expires_at
                    )
                    continue
            self._data[key] = _FileRedisValue(
                value=str(payload.get("value", "")),
                expires_at=expires_at,
            )

    def _dump(self) -> None:
        """Write the in-memory store to disk as JSON via an atomic rename."""
        payload = {
            key: {"value": entry.value, "expires_at": entry.expires_at}
            for key, entry in self._data.items()
        }
        tmp_path = self._file_path.with_suffix(self._file_path.suffix + ".tmp")
        tmp_path.write_text(json.dumps(payload, ensure_ascii=True), encoding="utf-8")
        tmp_path.replace(self._file_path)

    def _purge_expired(self) -> None:
        """Remove expired keys; persist only if something was removed.

        Must be called with ``self._lock`` held.
        """
        now = time.time()
        expired = [
            key
            for key, entry in self._data.items()
            if entry.expires_at is not None and entry.expires_at <= now
        ]
        if expired:
            for key in expired:
                self._data.pop(key, None)
            self._dump()

    def set(self, key: str, value: str) -> bool:
        """Set a key to a string value with no expiry.

        Args:
            key: Key to set.
            value: String value to store.

        Returns:
            True on success.
        """
        with self._lock:
            self._data[key] = _FileRedisValue(value=value, expires_at=None)
            self._dump()
        return True

    def setex(self, key: str, ttl_seconds: int | float, value: str) -> bool:
        """Set a key that expires ``ttl_seconds`` from now.

        Args:
            key: Key to set.
            ttl_seconds: Time-to-live in seconds.
            value: String value to store.

        Returns:
            True on success.
        """
        with self._lock:
            expires_at = time.time() + float(ttl_seconds)
            self._data[key] = _FileRedisValue(value=value, expires_at=expires_at)
            self._dump()
        return True

    def get(self, key: str) -> str | None:
        """Get a string value by key.

        Args:
            key: Key to retrieve.

        Returns:
            The stored value, or None if missing or expired.
        """
        with self._lock:
            self._purge_expired()
            entry = self._data.get(key)
            return entry.value if entry else None

    def delete(self, *keys: str) -> int:
        """Delete one or more keys.

        Expired keys are purged first, so, as with Redis ``DEL``, they do not
        count as removed.

        Args:
            *keys: Keys to delete.

        Returns:
            Number of live keys removed.
        """
        with self._lock:
            self._purge_expired()
            removed = 0
            for key in keys:
                if self._data.pop(key, None) is not None:
                    removed += 1
            if removed:
                self._dump()
        return removed

    def exists(self, key: str) -> int:
        """Check if a key exists.

        Args:
            key: Key to check.

        Returns:
            1 if the key exists and has not expired, otherwise 0.
        """
        with self._lock:
            self._purge_expired()
            return 1 if key in self._data else 0

    def scan_iter(self, match: str | None = None) -> Iterable[str]:
        """Iterate over keys matching a glob pattern.

        Matching is case-sensitive on every platform, like Redis
        ``SCAN ... MATCH``. Note that Python's glob dialect negates a character
        class with ``[!x]`` rather than Redis's ``[^x]``. The key list is
        snapshotted under the lock and the lock is released before yielding,
        so callers may modify the store (e.g. delete keys) while iterating.

        Args:
            match: Optional glob pattern (e.g., "prefix:*"). Defaults to "*".

        Returns:
            An iterator of matching keys.
        """
        with self._lock:
            self._purge_expired()
            snapshot = list(self._data)
        pattern = match or "*"
        for key in snapshot:
            # fnmatchcase avoids the OS-dependent case folding done by fnmatch.
            if fnmatchcase(key, pattern):
                yield key

    def pipeline(self, transaction: bool = True) -> "FileRedisPipeline":
        """Create a pipeline for batched operations.

        Args:
            transaction: Ignored; kept for compatibility.

        Returns:
            A FileRedisPipeline instance.
        """
        _ = transaction  # kept for signature compatibility
        return FileRedisPipeline(self)

    def close(self) -> None:
        """No-op close for API compatibility."""
        return None
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class FileRedisPipeline:
    """Minimal pipeline that queues writes and applies them with one flush.

    Queued operations are applied while holding the owning store's lock and
    followed by a single disk write. Used as a context manager, the queue is
    executed automatically on a clean exit and discarded if the ``with`` body
    raises.

    Example:
        with store.pipeline() as pipe:
            pipe.set("a", "1")
            pipe.setex("b", 10, "2")
            pipe.delete("c")
    """

    def __init__(self, store: FileRedis) -> None:
        self._store = store
        # Queued (operation name, positional args) pairs, in call order.
        self._ops: list[tuple[str, tuple]] = []

    def __enter__(self) -> "FileRedisPipeline":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type is None:
            self._execute()
        else:
            # Drop the half-built batch so a reused pipeline cannot replay it.
            self._ops.clear()

    def set(self, key: str, value: str) -> "FileRedisPipeline":
        """Queue a SET operation."""
        self._ops.append(("set", (key, value)))
        return self

    def setex(self, key: str, ttl_seconds: int | float, value: str) -> "FileRedisPipeline":
        """Queue a SETEX operation; the TTL is measured from execution time."""
        self._ops.append(("setex", (key, ttl_seconds, value)))
        return self

    def delete(self, *keys: str) -> "FileRedisPipeline":
        """Queue a DELETE operation."""
        self._ops.append(("delete", keys))
        return self

    def _execute(self) -> list[object]:
        """Apply queued operations under the store lock and flush once.

        Returns:
            One result per queued operation, in order, mirroring redis-py:
            ``True`` for SET/SETEX and the number of live (non-expired) keys
            removed for DELETE.
        """
        results: list[object] = []
        data = self._store._data
        with self._store._lock:
            for op, args in self._ops:
                now = time.time()
                if op == "set":
                    key, value = args
                    data[key] = _FileRedisValue(value=value, expires_at=None)
                    results.append(True)
                elif op == "setex":
                    key, ttl_seconds, value = args
                    data[key] = _FileRedisValue(value=value, expires_at=now + float(ttl_seconds))
                    results.append(True)
                elif op == "delete":
                    removed = 0
                    for key in args:
                        entry = data.pop(key, None)
                        # Expired-but-unpurged entries do not count, as with Redis DEL.
                        if entry is not None and (
                            entry.expires_at is None or entry.expires_at > now
                        ):
                            removed += 1
                    results.append(removed)
            if self._ops:
                self._store._dump()
            self._ops.clear()
        return results

    def execute(self) -> list[object]:
        """Execute queued operations and return one result per operation.

        Results follow redis-py: ``True`` for SET/SETEX and the number of keys
        removed for DELETE.
        """
        return self._execute()
|