tigen 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigen/__init__.py +0 -0
- tigen/ai/__init__.py +10 -0
- tigen/ai/brain.py +29 -0
- tigen/ai/context.py +18 -0
- tigen/ai/memory.py +230 -0
- tigen/common/__init__.py +0 -0
- tigen/common/ds/__init__.py +0 -0
- tigen/common/ds/generational.py +266 -0
- tigen/common/ds/running_stats.py +68 -0
- tigen/common/enum.py +75 -0
- tigen/common/extensions.py +87 -0
- tigen/common/formatting.py +41 -0
- tigen/common/logging.py +75 -0
- tigen/common/math.py +196 -0
- tigen/config.py +29 -0
- tigen/ecs/__init__.py +0 -0
- tigen/ecs/component.py +80 -0
- tigen/ecs/core.py +307 -0
- tigen/ecs/query.py +72 -0
- tigen/ecs/system.py +25 -0
- tigen-0.2.1.dist-info/METADATA +54 -0
- tigen-0.2.1.dist-info/RECORD +23 -0
- tigen-0.2.1.dist-info/WHEEL +4 -0
tigen/__init__.py
ADDED
|
File without changes
|
tigen/ai/__init__.py
ADDED
tigen/ai/brain.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import enum
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from tigen.ai import ActionStep
|
|
7
|
+
from tigen.ai.context import BrainContext
|
|
8
|
+
from tigen.ai.memory import Memory
|
|
9
|
+
|
|
10
|
+
# Type parameter for goal enumerations: every goal type handled by a Brain,
# GoalSelector or Planner must be an ``enum.Enum`` subclass.
TGoal = TypeVar("TGoal", bound=enum.Enum)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(slots=True)
class GoalSelector(abc.ABC, Generic[TGoal]):
    """Strategy interface that decides which goal a brain pursues next.

    Concrete subclasses implement :meth:`select_goal`.
    """

    @abc.abstractmethod
    def select_goal(self, ctx: BrainContext) -> TGoal:
        """Return the goal to pursue, given the brain context *ctx*."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(slots=True)
class Planner(abc.ABC, Generic[TGoal]):
    """Strategy interface that turns a selected goal into action steps."""

    @abc.abstractmethod
    def make_plan(self, ctx: BrainContext, goal: TGoal) -> list[ActionStep]:
        """Build a plan for reaching *goal*.

        Args:
            ctx: Current brain context.
            goal: Goal to plan for. It is typed with the class's ``TGoal``
                parameter (previously annotated as ``str``) so a
                ``Planner[X]`` agrees with the ``GoalSelector[X]`` that
                produced the goal.

        Returns:
            The action steps making up the plan.
        """
        ...
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(slots=True)
class Brain(Generic[TGoal]):
    """Bundle of the pluggable components behind an AI agent's decisions.

    Fields, in constructor order:
        goal_selector: picks which goal to pursue.
        planner: turns the chosen goal into action steps.
        memory: memory engine used by the agent.
    """

    goal_selector: GoalSelector[TGoal]  # chooses the active goal
    planner: Planner[TGoal]  # plans for the chosen goal
    memory: Memory  # memory engine
|
tigen/ai/context.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from tigen.ai.memory import Memory, MemoryData
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass(slots=True, frozen=True)
class BrainContext:
    """Immutable snapshot handed to the AI brain when it decides and plans.

    Design note carried over from the original docs: any "sensory" data in
    this context is normalised to the range 0.0-1.0. None of the current
    fields carry such data yet.
    """

    simulation_time: int  # current simulation time
    eid: int  # presumably the entity id of the agent - confirm with callers
    etype: str  # presumably the entity type name - confirm with callers
    memory_data: MemoryData  # memory state the engine below operates on
    memory_engine: Memory  # engine used to read/write ``memory_data``
|
tigen/ai/memory.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import enum
|
|
3
|
+
import heapq
|
|
4
|
+
import math
|
|
5
|
+
import random
|
|
6
|
+
from collections.abc import Callable, Iterator
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from tigen.common.math import Vector, cosine
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(slots=True)
class MemoryFact:
    """A single remembered fact.

    *uid*  : unique and stable identifier
    *tag*  : category label (``MemQuery.tag_eq`` filters on it)
    *t0*   : tick when the fact was first encoded
    *value*: **opaque payload** (Memory never inspects this)
    *ctx*  : optional context vector for similarity search
    """

    uid: str
    tag: str
    t0: int
    value: Any
    # Accepted by the constructor (keyword or fifth positional argument) so a
    # similarity-search vector can be attached when the fact is created;
    # defaults to None, so four-argument construction keeps working.
    ctx: Vector | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MemoryType(str, enum.Enum):
    """Selects which memory engine handles a MemoryData instance.

    Members are also ``str`` instances, so they compare equal to their plain
    string values (``MemoryType.PERFECT == "PERFECT"``).
    """

    PERFECT = "PERFECT"  # dispatched to PerfectMemory
    IMPERFECT = "IMPERFECT"  # dispatched to ImperfectMemory
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(slots=True)
class MemoryData:
    """Mutable memory state (presumably one per agent) that a Memory engine reads and writes."""

    # Long-term store, keyed by fact uid.
    ltm: dict[str, MemoryFact] = field(default_factory=dict)
    # Short-term buffer, in the order facts were appended.
    stm: list[MemoryFact] = field(default_factory=list)
    # Per-uid trace strength (maintained by ImperfectMemory).
    strength: dict[str, float] = field(default_factory=dict)
    # Cue buckets: uid prefix before the first "-" -> uids filed under it.
    cue: dict[str, set[str]] = field(default_factory=dict)
    # Seed for the engines' random.Random; replaced after each use.
    rng_state: int = field(default=0xDEADBEEF)

    def all(self) -> Iterator[MemoryFact]:
        """Yield every stored fact: short-term ones first, then long-term ones."""
        yield from self.stm
        yield from self.ltm.values()

    def exists(self, uid: str) -> bool:
        """Check if a fact with the given UID exists in either STM or LTM."""
        if uid in self.ltm:
            return True
        return any(candidate.uid == uid for candidate in self.stm)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass(slots=True, frozen=True)
class MemQuery:
    """Immutable description of a memory lookup.

    Fields:
        fact_type: optional fact class filter.
        uid: optional exact-uid filter.
        where: optional predicate a fact must satisfy.
        ctx: current context vector, used to boost similar facts.
        k: maximum number of facts to return.

    Whether ``fact_type`` and ``uid`` are honoured depends on the engine's
    ``recall()`` implementation.
    """

    fact_type: type[MemoryFact] | None = None
    uid: str | None = None
    where: Callable[[MemoryFact], bool] | None = None
    ctx: Vector | None = None  # current context for activation boost
    k: int = 5

    @staticmethod
    def tag_eq(tag: str) -> "MemQuery":
        """Build a query matching facts whose ``tag`` equals *tag*."""

        def has_tag(candidate):
            return candidate.tag == tag

        return MemQuery(where=has_tag)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass(slots=True)
class Memory(abc.ABC):
    """Interface for memory engines.

    Every operation receives the MemoryData it works on as ``md``. The
    implementations in this module keep no per-instance state of their own.
    """

    @abc.abstractmethod
    def remember(self, md: MemoryData, fact: MemoryFact) -> None:
        """Encode *fact* into *md*."""

    @abc.abstractmethod
    def recall(self, md: MemoryData, q: MemQuery, now: int) -> list[MemoryFact]:
        """Return facts from *md* answering query *q*; *now* is the current tick."""

    @abc.abstractmethod
    def tick(self, md: MemoryData, dt: float, now: int) -> None:
        """Advance *md* by elapsed time *dt*; *now* is the current tick."""

    @abc.abstractmethod
    def forget(self, md: MemoryData, uid: str) -> None:
        """Remove the fact identified by *uid* from *md*."""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# ———– PerfectMemory ————————————————————————————————————————
|
|
83
|
+
class PerfectMemory(Memory):
    """Lossless memory: facts go straight to long-term storage and never decay."""

    def remember(self, md: MemoryData, fact: MemoryFact) -> None:
        """Store *fact* in long-term memory, replacing any fact with the same uid."""
        md.ltm[fact.uid] = fact

    def recall(self, md: MemoryData, q: MemQuery, now: int) -> list[MemoryFact]:
        """Return up to ``q.k`` facts matching *q*.

        Long-term facts are scanned before short-term ones, in storage order.
        ``q.uid``, ``q.fact_type`` and ``q.where`` are applied as filters.
        ``q.ctx`` and *now* are ignored because perfect memory does not rank.
        A non-positive ``q.k`` yields an empty list.
        """
        if q.k <= 0:
            return []
        results: list[MemoryFact] = []
        for facts in (md.ltm.values(), md.stm):
            for fact in facts:
                if q.uid is not None and fact.uid != q.uid:
                    continue
                if q.fact_type is not None and not isinstance(fact, q.fact_type):
                    continue
                if q.where is not None and not q.where(fact):
                    continue
                results.append(fact)
                # Stop as soon as k facts are found instead of materialising
                # every match and slicing.
                if len(results) >= q.k:
                    return results
        return results

    def forget(self, md: MemoryData, uid: str) -> None:
        """Forget a fact by UID (LTM, STM, cue buckets and strength)."""
        md.ltm.pop(uid, None)
        md.stm = [f for f in md.stm if f.uid != uid]
        for s in md.cue.values():
            s.discard(uid)
        # Keep parity with ImperfectMemory.forget so no stale strength entry
        # survives when the same MemoryData is later used by another engine.
        md.strength.pop(uid, None)

    def tick(self, md: MemoryData, dt: float, now: int) -> None:
        """No-op: perfect memory neither consolidates nor decays."""
        return
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ———– ImperfectMemory knobs & logic ——————————————————————————
|
|
106
|
+
class ImperfectMemory(Memory):
    """Human-like memory: short-term buffer, consolidation, decay and noisy recall.

    All randomness comes from ``random.Random(md.rng_state)``. The state is
    replaced after every call that draws numbers, so results are reproducible
    for a given MemoryData.
    """

    ENCODE_NOISE_P = 0.10  # probability that an encoded fact is distorted
    STM_SPAN_TICKS = 300  # age (ticks) at which an STM fact moves to LTM
    INTERFERENCE_P = 0.40  # NOTE(review): not used by any method yet
    DECAY_HALF_LIFE = 30_000  # strength halves after this much dt
    RETRIEVAL_NOISE_SD = 0.25  # std-dev of Gaussian noise on recall scores
    # Strength below which an LTM trace is dropped during decay.
    _FORGET_THRESHOLD = 0.02
    # Strength assumed for a fact that has no recorded trace.
    _DEFAULT_STRENGTH = 0.01

    def remember(self, md: MemoryData, fact: MemoryFact) -> None:
        """Encode *fact* into short-term memory."""
        rng = random.Random(md.rng_state)
        if rng.random() < self.ENCODE_NOISE_P:
            # TODO: distort the fact here. The draw is kept so the RNG stream
            # does not change once distortion is implemented.
            pass
        md.stm.append(fact)
        md.rng_state = rng.getrandbits(64)

    def tick(self, md: MemoryData, dt: float, now: int) -> None:
        """Advance memory: consolidate old STM facts into LTM, then decay LTM by *dt*."""
        self._consolidate(md, now)
        self._decay(md, dt)

    def _consolidate(self, md: MemoryData, now: int) -> None:
        """Move STM facts at least STM_SPAN_TICKS old into LTM, oldest first."""
        # STM is appended in encoding order, so stop at the first fact that is
        # still too young. (Interference is not implemented yet.)
        while md.stm and now - md.stm[0].t0 >= self.STM_SPAN_TICKS:
            f = md.stm.pop(0)
            bucket = f.uid.split("-", 1)[0]
            md.cue.setdefault(bucket, set()).add(f.uid)
            md.ltm[f.uid] = f
            md.strength[f.uid] = md.strength.get(f.uid, self._DEFAULT_STRENGTH) + 1.0

    def _decay(self, md: MemoryData, dt: float) -> None:
        """Exponentially decay all strengths and drop traces below the threshold."""
        k = 0.5 ** (dt / self.DECAY_HALF_LIFE)
        dead: list[str] = []
        for uid, s in md.strength.items():
            s *= k
            if s < self._FORGET_THRESHOLD:
                dead.append(uid)
            md.strength[uid] = s  # overwriting existing keys is safe while iterating
        for uid in dead:
            md.ltm.pop(uid, None)
            # Remove the trace itself too; otherwise ``strength`` grows forever
            # with entries for facts that no longer exist.
            md.strength.pop(uid, None)
            for bucket in md.cue.values():
                bucket.discard(uid)

    @staticmethod
    def _matches(q: MemQuery, f: MemoryFact) -> bool:
        """Apply the query's uid, fact_type and where filters to *f*."""
        if q.uid is not None and f.uid != q.uid:
            return False
        if q.fact_type is not None and not isinstance(f, q.fact_type):
            return False
        return q.where is None or bool(q.where(f))

    def _score(
        self, md: MemoryData, q: MemQuery, f: MemoryFact, now: int, rng: random.Random
    ) -> float:
        """Activation score: log(strength / age) + context similarity + noise."""
        base = md.strength.get(f.uid, self._DEFAULT_STRENGTH)
        # Clamp the age: age 0 keeps the original 1e-9 floor, and a fact
        # stamped after *now* (negative age) no longer makes math.log raise.
        age = max(now - f.t0, 1e-9)
        act = math.log(base / age)
        sim = 0.0
        if q.ctx and f.ctx is not None:
            sim = cosine(q.ctx, f.ctx)
        noise = rng.gauss(0.0, self.RETRIEVAL_NOISE_SD)
        return act + sim + noise

    def recall(self, md: MemoryData, q: MemQuery, now: int) -> list[MemoryFact]:
        """Return up to ``q.k`` matching facts, highest score first.

        LTM facts are scored before STM facts. One Gaussian draw is made per
        matching fact. A non-positive ``q.k`` returns ``[]`` without drawing.
        """
        if q.k <= 0:
            return []
        rng = random.Random(md.rng_state)
        # Min-heap of (score, seq, fact). The root is the weakest of the
        # current top-k and is evicted when a better candidate appears. ``seq``
        # breaks score ties, so MemoryFact objects (which define no ordering)
        # are never compared.
        heap: list[tuple[float, int, MemoryFact]] = []
        seq = 0
        for facts in (md.ltm.values(), md.stm):
            for f in facts:
                if not self._matches(q, f):
                    continue
                score = self._score(md, q, f, now, rng)
                entry = (score, seq, f)
                seq += 1
                if len(heap) < q.k:
                    heapq.heappush(heap, entry)
                elif score > heap[0][0]:
                    heapq.heapreplace(heap, entry)

        md.rng_state = rng.getrandbits(64)

        # Highest score first; equal scores keep scan order.
        heap.sort(key=lambda e: (-e[0], e[1]))
        return [f for _, _, f in heap]

    def forget(self, md: MemoryData, uid: str) -> None:
        """Forget a fact by UID."""
        md.ltm.pop(uid, None)
        md.stm = [f for f in md.stm if f.uid != uid]
        for s in md.cue.values():
            s.discard(uid)
        md.strength.pop(uid, None)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Shared engine instances. Neither engine class keeps per-instance state, so
# one of each serves every MemoryData.
_perfect_memory = PerfectMemory()
_imperfect_memory = ImperfectMemory()


def get_memory(memory_type: MemoryType) -> Memory:
    """Return the shared engine for *memory_type*.

    Anything other than ``MemoryType.PERFECT`` (including unknown values)
    maps to the imperfect engine.
    """
    return _perfect_memory if memory_type == MemoryType.PERFECT else _imperfect_memory
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# ———– façade dispatchers ——————————————————————————————————————
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class MemWriter:
|
|
206
|
+
@staticmethod
|
|
207
|
+
def write(md: MemoryData, memory_type: str, fact: MemoryFact) -> None:
|
|
208
|
+
if memory_type == MemoryType.PERFECT:
|
|
209
|
+
_perfect_memory.remember(md, fact)
|
|
210
|
+
elif memory_type == MemoryType.IMPERFECT:
|
|
211
|
+
_imperfect_memory.remember(md, fact)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class MemReader:
|
|
215
|
+
@staticmethod
|
|
216
|
+
def read(md: MemoryData, memory_type: str, q: MemQuery, now: int) -> list[MemoryFact]:
|
|
217
|
+
if memory_type == MemoryType.PERFECT:
|
|
218
|
+
return _perfect_memory.recall(md, q, now)
|
|
219
|
+
if memory_type == MemoryType.IMPERFECT:
|
|
220
|
+
return _imperfect_memory.recall(md, q, now)
|
|
221
|
+
return [] # empty list if memory type is unknown
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class MemHousekeeping:
|
|
225
|
+
@staticmethod
|
|
226
|
+
def tick(md: MemoryData, memory_type: str, dt: float, now: int) -> None:
|
|
227
|
+
if memory_type == MemoryType.PERFECT:
|
|
228
|
+
_perfect_memory.tick(md, dt, now)
|
|
229
|
+
else:
|
|
230
|
+
_imperfect_memory.tick(md, dt, now)
|
tigen/common/__init__.py
ADDED
|
File without changes
|
|
tigen/common/ds/__init__.py
ADDED
|
File without changes
|
|
tigen/common/ds/generational.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterator
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import (
|
|
5
|
+
Generic,
|
|
6
|
+
TypeVar,
|
|
7
|
+
cast,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Generic type parameters shared by the containers in this module:
#   T    - item type stored in a GenerationalContainer
#   K, V - key and value types of a GenerationalDict
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

# Reference to a container slot: (slot index, generation of that slot).
# A handle goes stale once its slot is reused, because reuse bumps the generation.
Handle = tuple[int, int]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class IsolationLevel(Enum):
    """How ``smart_iter`` reacts to mutations made while it is iterating.

    NONE            -- live iteration: every modification is visible.
    ALLOW_DELETIONS -- fixed-range iteration: new insertions are ignored and
                       deletions show up as gaps.
    FULL            -- "immutable" iteration: the snapshot range is locked,
                       new insertions are ignored, and deletions at indices
                       ahead of the iterator are deferred (still yielded).
    """

    NONE = 1
    ALLOW_DELETIONS = 2
    FULL = 3
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(slots=True)
class _ImmutableIterationContext(Generic[T]):
    """Bookkeeping for one FULL-mode iterator over a GenerationalContainer.

    The container registers one context per active FULL iterator. When a slot
    ahead of that iterator is removed, the removed item is parked here so the
    iterator can still yield it.
    """

    # Container length when iteration began; slots at or beyond it are never visited.
    snapshot_length: int
    # Index the iterator reached most recently.
    current_position: int = 0
    # Slot index -> item that was present when the slot was deleted.
    deferred: dict[int, T] = field(default_factory=dict)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(slots=True)
class GenerationalContainer(Generic[T]):
    """
    A generic container that stores items in a list with per-slot generation counters.
    Insertions reuse free slots (bumping the generation), and deletions mark a slot as None.

    Handles are ``(index, generation)`` pairs. A handle becomes stale once its
    slot is reused, because reuse bumps that slot's generation.

    The smart_iter() method returns an iterator that obeys one of the three modes
    as controlled by IsolationLevel.

    For FULL mode, each iterator creates its own _ImmutableIterationContext,
    and the container keeps a list of active contexts. In remove(), if a deletion
    occurs at an index greater than the context's current_position (and within its
    snapshot), the current (old) value is recorded in that context's deferred buffer.
    The FULL iterator yields that recorded value even if the slot has since been
    refilled. It also ignores slots that were empty when iteration began, so new
    insertions never appear.
    """

    _items: list[T | None] = field(default_factory=list)
    _generations: list[int] = field(default_factory=list)
    _free_indices: list[int] = field(default_factory=list)
    _immutable_contexts: list[_ImmutableIterationContext[T]] = field(default_factory=list)

    def _valid_index(self, idx: int) -> bool:
        """Return True if *idx* addresses an existing slot.

        Negative indices are rejected explicitly. Plain list indexing would
        otherwise silently wrap them to the end of the list.
        """
        return 0 <= idx < len(self._items)

    def insert(self, item: T) -> Handle:
        """Store *item* and return its handle, reusing a freed slot when possible."""
        if self._free_indices:
            idx = self._free_indices.pop()
            self._items[idx] = item
            self._generations[idx] += 1  # Invalidate any old handle.
        else:
            idx = len(self._items)
            self._items.append(item)
            self._generations.append(0)

        return (idx, self._generations[idx])

    def remove(self, handle: Handle) -> None:
        """Delete the item referenced by *handle*.

        Raises:
            ValueError: if the handle is out of range or stale, or the slot is
                already empty.
        """
        idx, gen = handle
        if not self._valid_index(idx) or self._generations[idx] != gen:
            raise ValueError("Invalid or stale handle")
        item = self._items[idx]
        if item is None:
            raise ValueError("Item already deleted")
        # For each active FULL iterator that has not reached this slot yet (and
        # whose snapshot covers it), keep the removed item so it is still
        # yielded. The first removal of a slot wins.
        for ctx in self._immutable_contexts:
            if ctx.current_position < idx < ctx.snapshot_length and idx not in ctx.deferred:
                ctx.deferred[idx] = cast(T, item)
        self._items[idx] = None
        self._free_indices.append(idx)

    def get(self, handle: Handle) -> T | None:
        """Return the item for *handle*, or None if the handle is invalid, stale or deleted."""
        idx, gen = handle
        if not self._valid_index(idx) or self._generations[idx] != gen:
            return None
        return self._items[idx]

    def smart_iter(
        self,
        allowed_mutation: IsolationLevel = IsolationLevel.NONE,
        skip_empty: bool = True,
    ) -> Iterator[T | None]:
        """Iterate items under the given isolation mode.

        Args:
            allowed_mutation: which IsolationLevel semantics to apply.
            skip_empty: when False, empty slots are yielded as None (only for
                slots the chosen mode visits).

        Raises:
            ValueError: on an unknown mode (raised on the first ``next()``).
        """
        if allowed_mutation == IsolationLevel.NONE:
            yield from self._iter_live(skip_empty)
        elif allowed_mutation == IsolationLevel.ALLOW_DELETIONS:
            yield from self._iter_fixed_range(skip_empty)
        elif allowed_mutation == IsolationLevel.FULL:
            yield from self._iter_snapshot(skip_empty)
        else:
            raise ValueError("Unknown mutation allowance mode")

    def _iter_live(self, skip_empty: bool) -> Iterator[T | None]:
        """NONE mode: walk the live list, so every modification is visible."""
        for item in self._items:
            if skip_empty and item is None:
                continue
            yield item

    def _iter_fixed_range(self, skip_empty: bool) -> Iterator[T | None]:
        """ALLOW_DELETIONS mode: fixed length; slots free at the start are skipped."""
        snapshot_length = len(self._items)
        frozen_free = {i for i, item in enumerate(self._items) if item is None}
        for i in range(snapshot_length):
            if i in frozen_free:
                continue
            item = self._items[i]
            if skip_empty and item is None:
                continue
            yield item

    def _iter_snapshot(self, skip_empty: bool) -> Iterator[T | None]:
        """FULL mode: yield the contents as they were when iteration began."""
        snapshot_length = len(self._items)
        frozen_free = {i for i, item in enumerate(self._items) if item is None}
        local_ctx = _ImmutableIterationContext[T](snapshot_length)
        self._immutable_contexts.append(local_ctx)
        try:
            for i in range(snapshot_length):
                local_ctx.current_position = i
                if i in local_ctx.deferred:
                    # Deleted after the snapshot. Yield the original item even
                    # if the slot has since been refilled by a new insertion.
                    item = local_ctx.deferred[i]
                elif i in frozen_free:
                    # Empty at snapshot time: any current content is a new insertion.
                    item = None
                else:
                    item = self._items[i]
                if skip_empty and item is None:
                    continue
                yield item
        finally:
            # Remove by identity. Dataclass ``==`` compares field values, so
            # list.remove() could unregister another iterator's equal context.
            for pos, ctx in enumerate(self._immutable_contexts):
                if ctx is local_ctx:
                    del self._immutable_contexts[pos]
                    break

    def __iter__(self) -> Iterator[T | None]:
        """Default iteration uses NONE (live iteration)."""
        return self.smart_iter(allowed_mutation=IsolationLevel.NONE)

    def __len__(self) -> int:
        """Return the number of occupied slots."""
        return len(self._items) - len(self._free_indices)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@dataclass(slots=True, frozen=True)
class Entry(Generic[K, V]):
    """Immutable key/value pair stored in a GenerationalDict's container."""

    key: K  # dictionary key this entry belongs to
    value: V  # value bound to ``key``
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class GenerationalDict(Generic[K, V]):
    """Dictionary that stores ``Entry(key, value)`` objects in a GenerationalContainer.

    It keeps a mapping from external keys to container handles for O(1)
    lookups. ``items()``, ``keys()`` and ``values()`` delegate to the
    container's ``smart_iter()`` and accept the same IsolationLevel modes.
    Iterating the dict itself yields keys in NONE (live) mode.
    """

    __slots__ = ("_container", "_key_to_handle")
    _container: GenerationalContainer[Entry[K, V]]
    _key_to_handle: dict[K, Handle]

    def __init__(self) -> None:
        self._container: GenerationalContainer[Entry[K, V]] = GenerationalContainer()
        self._key_to_handle: dict[K, Handle] = {}

    def __setitem__(self, key: K, value: V) -> None:
        """Bind *key* to *value*.

        Rebinding removes the old entry and inserts a new one, so the key
        receives a fresh handle. Re-assigning the very same object is a no-op.
        """
        handle = self._key_to_handle.get(key)
        if handle is not None:
            existing = self._container.get(handle)
            # Compare the stored *value* (not the Entry wrapper) by identity.
            # With ``==``, an equal-but-distinct object (1.0 replacing 1, a new
            # list) would be silently dropped, and array-like values would make
            # the truth test ambiguous.
            if existing is not None and existing.value is value:
                return
            self.delete(key)  # Remove old entry if it exists.
        self.add(key, value)

    def __getitem__(self, key: K) -> V:
        """Return the value for *key*; raise KeyError if missing or stale."""
        try:
            idx, gen = self._key_to_handle[key]  # O(1) hash-lookup
            if self._container._generations[idx] == gen:  # generation still valid
                entry = self._container._items[idx]
                if entry is not None:  # slot not deleted
                    return entry.value
        except (KeyError, IndexError):
            pass  # fall through to the KeyError below
        raise KeyError(key)

    def add(self, key: K, value: V) -> None:
        """Insert a new entry for *key* (no check for an existing binding)."""
        entry = Entry(key, value)
        handle = self._container.insert(entry)
        self._key_to_handle[key] = handle

    def delete(self, key: K) -> None:
        """Remove *key* if present; missing keys are ignored."""
        handle = self._key_to_handle.pop(key, None)
        if handle is not None:
            self._container.remove(handle)

    def get(self, key: K, default: V | None = None) -> V | None:
        """Return the value for *key*, or *default* (None unless given) if absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def items(
        self,
        allowed_mutation: IsolationLevel = IsolationLevel.NONE,
        skip_empty: bool = True,
    ) -> Iterator[tuple[K, V]]:
        """Yield ``(key, value)`` pairs under the given isolation mode."""
        for entry in self._container.smart_iter(allowed_mutation=allowed_mutation, skip_empty=skip_empty):
            if entry is not None:
                yield (entry.key, entry.value)

    def keys(
        self,
        allowed_mutation: IsolationLevel = IsolationLevel.NONE,
        skip_empty: bool = True,
    ) -> Iterator[K]:
        """Yield keys under the given isolation mode."""
        for key, _ in self.items(allowed_mutation=allowed_mutation, skip_empty=skip_empty):
            yield key

    def values(
        self,
        allowed_mutation: IsolationLevel = IsolationLevel.NONE,
        skip_empty: bool = True,
    ) -> Iterator[V]:
        """Yield values under the given isolation mode."""
        for _, value in self.items(allowed_mutation=allowed_mutation, skip_empty=skip_empty):
            yield value

    def __iter__(self) -> Iterator[K]:
        for key, _ in self.items(allowed_mutation=IsolationLevel.NONE, skip_empty=True):
            yield key

    def __len__(self) -> int:
        return len(self._key_to_handle)

    def __contains__(self, key: K) -> bool:
        return key in self._key_to_handle
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class GenerationalDefaultDict(GenerationalDict[K, V]):
    """GenerationalDict that supplies defaults for missing keys.

    Subscript access (``d[key]``) on a missing key calls ``default_factory``,
    stores the result under *key* and returns it. ``get()`` never creates
    entries, matching ``collections.defaultdict.get``.
    """

    __slots__ = ("default_factory",)

    def __init__(self, default_factory: Callable[[], V]) -> None:
        super().__init__()
        self.default_factory = default_factory

    def __getitem__(self, key: K) -> V:
        """Return the value for *key*, inserting ``default_factory()`` if missing."""
        if key in self._key_to_handle:
            return super().__getitem__(key)

        default_value = self.default_factory()
        self[key] = default_value
        return default_value

    def get(self, key: K, default: V | None = None) -> V | None:
        """Return the value for *key*, or *default* if absent, without inserting.

        The inherited ``get`` went through ``self[key]``, which for this class
        silently created and stored a default entry.
        """
        if key not in self._key_to_handle:
            return default
        try:
            return super().__getitem__(key)
        except KeyError:
            return default
|
|
tigen/common/ds/running_stats.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for maintaining online statistics with O(1) memory usage.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
|
|
9
|
+
from tigen.common.math import MovingAverageFunction, moving_average
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(slots=True)
class RunningMean:
    """Running mean kept in constant memory.

    Thin wrapper over a ``moving_average`` callable: calling it records a
    value, ``current()`` reads the mean and ``reset()`` clears it. The exact
    averaging rule lives in ``tigen.common.math`` (confirm there).
    """

    avg: MovingAverageFunction = field(default_factory=moving_average)

    def add(self, value: float) -> None:
        """Record *value* in the mean."""
        self.avg(value)

    def value(self) -> float:
        """Return the current mean without changing it."""
        return self.avg.current()

    def reset(self) -> None:
        """Discard every recorded value."""
        self.avg.reset()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(slots=True)
|
|
32
|
+
class RunningDistribution:
|
|
33
|
+
"""
|
|
34
|
+
Incrementally maintains a probability distribution over categorical values.
|
|
35
|
+
Supports O(1) updates and probability queries.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
counts: dict[str, int] = field(default_factory=lambda: defaultdict(int))
|
|
39
|
+
total: int = 0
|
|
40
|
+
|
|
41
|
+
def add(self, key: str, n: int = 1) -> None:
|
|
42
|
+
"""
|
|
43
|
+
Record n occurrences of key.
|
|
44
|
+
|
|
45
|
+
:param key: Category to increment
|
|
46
|
+
:param n: Count to add (default: 1)
|
|
47
|
+
"""
|
|
48
|
+
self.counts[key] += n
|
|
49
|
+
self.total += n
|
|
50
|
+
|
|
51
|
+
def prob(self, key: str) -> float:
|
|
52
|
+
"""
|
|
53
|
+
Get P(key) = count(key) / total.
|
|
54
|
+
Returns 0.0 if no samples recorded.
|
|
55
|
+
|
|
56
|
+
:param key: Category to query
|
|
57
|
+
:return: Probability of this category
|
|
58
|
+
"""
|
|
59
|
+
return self.counts[key] / self.total if self.total else 0.0
|
|
60
|
+
|
|
61
|
+
def keys(self) -> Iterable[str]:
|
|
62
|
+
"""Get recorded categories."""
|
|
63
|
+
return self.counts.keys()
|
|
64
|
+
|
|
65
|
+
def reset(self) -> None:
|
|
66
|
+
"""Clear all counts."""
|
|
67
|
+
self.counts.clear()
|
|
68
|
+
self.total = 0
|