entroplain 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,272 @@
1
+ """
2
+ Entropy Monitor — Core entropy tracking and early exit logic.
3
+ """
4
+
5
+ import math
6
+ from typing import List, Tuple, Optional, Dict, Any, Callable
7
+ from dataclasses import dataclass, field
8
+ from enum import Enum
9
+
10
+
11
class ExitCondition(Enum):
    """Strategy that EntropyMonitor.should_exit uses to decide convergence."""

    # Exit once the spacing between the most recent entropy valleys is regular.
    VALLEYS_PLATEAU = "valleys_plateau"
    # Exit once the latest token's entropy is below the entropy threshold.
    ENTROPY_DROP = "entropy_drop"
    # Exit once the token-to-token entropy change is below the velocity threshold.
    VELOCITY_ZERO = "velocity_zero"
    # Exit when (entropy low OR valleys plateau) AND velocity stable.
    COMBINED = "combined"
16
+
17
+
18
@dataclass
class EntropyPoint:
    """
    One tracked token together with its entropy measurements.

    ``is_valley`` starts out False and may be flipped to True later, once a
    following token shows this point to be a strict local minimum.
    """

    # Position of the token in the tracked sequence (starts at 0).
    index: int
    # The generated token text.
    token: str
    # Entropy value supplied for this token.
    entropy: float
    # True once this point is known to be a strict local minimum.
    is_valley: bool = field(default=False)
    # Absolute entropy change from the previous point (0.0 for the first one).
    velocity: float = field(default=0.0)
26
+
27
+
28
@dataclass
class MonitorConfig:
    """
    Configuration for the entropy monitor.

    ``exit_condition`` accepts either an ``ExitCondition`` member or its
    string value (e.g. ``"combined"``). Strings are converted to the enum
    member in ``__post_init__``. Without that conversion, a string would never
    compare equal to an enum member, and the monitor's exit check would
    silently never fire.

    Raises:
        ValueError: If ``exit_condition`` is not a valid ``ExitCondition``
            member or value.
    """

    # Entropy below this value counts as "low" (ENTROPY_DROP / COMBINED).
    entropy_threshold: float = 0.15
    # Minimum number of detected valleys before any exit is allowed.
    min_valleys: int = 2
    # Entropy velocity below this value counts as "stable".
    velocity_threshold: float = 0.05
    # Minimum number of tracked tokens before any exit is allowed.
    min_tokens: int = 50
    # NOTE(review): not read by EntropyMonitor in this module — confirm intended use.
    valley_window: int = 5
    # Number of most recent valleys whose spacing is checked for a plateau.
    plateau_threshold: int = 3
    # Which convergence strategy the monitor applies.
    exit_condition: ExitCondition = ExitCondition.COMBINED

    def __post_init__(self) -> None:
        # ExitCondition(member) returns the member unchanged; a string is
        # looked up by value and raises ValueError if it is unknown.
        self.exit_condition = ExitCondition(self.exit_condition)
38
+
39
+
40
class EntropyMonitor:
    """
    Monitor entropy trajectory and detect reasoning convergence.

    Feed every generated token and its entropy to :meth:`track`. The monitor
    records the trajectory and marks strict local minima ("valleys").
    :meth:`should_exit` then answers according to the configured exit
    condition.

    Usage:
        monitor = EntropyMonitor()

        for token, entropy in stream:
            monitor.track(token, entropy)
            if monitor.should_exit():
                break
    """

    # Variance of the spacing (in tokens) between the most recent valleys
    # below which the valley pattern counts as a plateau. Tuned empirically;
    # override on a subclass or an instance to adjust.
    PLATEAU_SPACING_VARIANCE: float = 10.0

    def __init__(
        self,
        entropy_threshold: float = 0.15,
        min_valleys: int = 2,
        velocity_threshold: float = 0.05,
        min_tokens: int = 50,
        valley_window: int = 5,
        plateau_threshold: int = 3,
        exit_condition: str = "combined"
    ):
        """
        Args:
            entropy_threshold: Entropy below this counts as "low".
            min_valleys: Valleys required before any exit is allowed.
            velocity_threshold: Velocity below this counts as "stable".
            min_tokens: Tracked tokens required before any exit is allowed.
            valley_window: Stored in the config; not read by this class.
            plateau_threshold: Number of most recent valleys examined by
                :meth:`is_valleys_plateau`.
            exit_condition: Value of an ``ExitCondition`` (e.g. ``"combined"``);
                an ``ExitCondition`` member is accepted as well.

        Raises:
            ValueError: If ``exit_condition`` is not a valid ``ExitCondition``.
        """
        self.config = MonitorConfig(
            entropy_threshold=entropy_threshold,
            min_valleys=min_valleys,
            velocity_threshold=velocity_threshold,
            min_tokens=min_tokens,
            valley_window=valley_window,
            plateau_threshold=plateau_threshold,
            exit_condition=ExitCondition(exit_condition)
        )
        self._trajectory: List[EntropyPoint] = []
        self._valleys: List[EntropyPoint] = []
        self._index = 0

    def calculate_entropy(self, logprobs: List[float], from_probs: bool = False) -> float:
        """
        Calculate Shannon entropy from log probabilities or probabilities.

        Values are used as given, with no renormalisation. Zero probabilities
        (for example a log probability of ``-inf``) contribute nothing.

        Args:
            logprobs: List of log probabilities (natural log) or probabilities
            from_probs: If True, treat input as probabilities directly

        Returns:
            Shannon entropy in bits (0.0 for empty input)
        """
        if not logprobs:
            return 0.0

        entropy = 0.0
        for value in logprobs:
            prob = value if from_probs else math.exp(value)
            # prob > 0 is guaranteed inside the branch, so log2 is exact. An
            # added epsilon would bias the result and make a certain outcome
            # (p == 1) report a slightly negative entropy.
            if prob > 0:
                entropy -= prob * math.log2(prob)

        return entropy

    def track(self, token: str, entropy: float) -> EntropyPoint:
        """
        Record a token and its entropy, updating velocity and valleys.

        A point can only be recognised as a valley after the *next* token has
        been tracked. The returned point therefore always has
        ``is_valley == False``. When this call reveals the previous point to be
        a strict local minimum, that point is flagged and appended to the
        valley list.

        Args:
            token: The generated token
            entropy: Calculated entropy for this token

        Returns:
            The newly appended EntropyPoint (velocity already set)
        """
        point = EntropyPoint(index=self._index, token=token, entropy=entropy)

        if self._trajectory:
            # Velocity is the absolute entropy change from the previous token.
            point.velocity = abs(entropy - self._trajectory[-1].entropy)

        if len(self._trajectory) >= 2:
            before, candidate = self._trajectory[-2], self._trajectory[-1]
            if candidate.entropy < before.entropy and candidate.entropy < entropy:
                candidate.is_valley = True
                self._valleys.append(candidate)

        self._trajectory.append(point)
        self._index += 1

        return point

    def get_valleys(self) -> List[Tuple[int, float]]:
        """Get all entropy valleys (local minima) as (index, entropy) tuples."""
        return [(v.index, v.entropy) for v in self._valleys]

    def get_velocity(self) -> float:
        """Get the absolute entropy change between the last two tracked tokens."""
        if len(self._trajectory) < 2:
            return 0.0
        return self._trajectory[-1].velocity

    def get_mean_entropy(self) -> float:
        """Get mean entropy over the trajectory (0.0 when empty)."""
        if not self._trajectory:
            return 0.0
        return sum(p.entropy for p in self._trajectory) / len(self._trajectory)

    def get_valley_count(self) -> int:
        """Get the number of detected valleys."""
        return len(self._valleys)

    def is_valleys_plateau(self) -> bool:
        """
        Check whether the most recent valleys are evenly spaced.

        Looks at the last ``plateau_threshold`` valleys. Returns True when the
        variance of the gaps (in tokens) between consecutive valleys is below
        ``PLATEAU_SPACING_VARIANCE``. Returns False unless at least
        ``min_valleys`` valleys exist and the window holds at least two
        valleys, which is the minimum needed to form one gap.
        """
        if len(self._valleys) < self.config.min_valleys:
            return False

        window = self.config.plateau_threshold
        # A spacing needs two valleys. This guard also avoids the ``[-0:]``
        # slice, which would select the whole list instead of an empty window.
        if window < 2 or len(self._valleys) < window:
            return False

        recent = self._valleys[-window:]
        spacings = [b.index - a.index for a, b in zip(recent, recent[1:])]

        mean_spacing = sum(spacings) / len(spacings)
        variance = sum((s - mean_spacing) ** 2 for s in spacings) / len(spacings)

        # Low variance in spacing = plateau
        return variance < self.PLATEAU_SPACING_VARIANCE

    def is_entropy_low(self) -> bool:
        """Check if the latest token's entropy is below threshold."""
        if not self._trajectory:
            return False
        return self._trajectory[-1].entropy < self.config.entropy_threshold

    def is_velocity_stable(self) -> bool:
        """Check if velocity is below threshold (True with fewer than 2 tokens)."""
        return self.get_velocity() < self.config.velocity_threshold

    def should_exit(self) -> bool:
        """
        Determine if reasoning has converged and we should exit.

        Always returns False until at least ``min_tokens`` tokens and
        ``min_valleys`` valleys have been seen. After that it applies the
        configured exit condition:
        - valleys_plateau: Exit when recent valley spacing is regular
        - entropy_drop: Exit when entropy drops below threshold
        - velocity_zero: Exit when velocity stabilizes
        - combined: Exit when (entropy low OR valleys plateau) AND velocity stable
        """
        if len(self._trajectory) < self.config.min_tokens:
            return False

        if len(self._valleys) < self.config.min_valleys:
            return False

        condition = self.config.exit_condition

        if condition == ExitCondition.VALLEYS_PLATEAU:
            return self.is_valleys_plateau()

        if condition == ExitCondition.ENTROPY_DROP:
            return self.is_entropy_low()

        if condition == ExitCondition.VELOCITY_ZERO:
            return self.is_velocity_stable()

        if condition == ExitCondition.COMBINED:
            return (self.is_entropy_low() or self.is_valleys_plateau()) and self.is_velocity_stable()

        return False

    def is_converged(self) -> bool:
        """Alias for should_exit()."""
        return self.should_exit()

    def get_trajectory(self) -> List[float]:
        """Get full entropy trajectory as list of floats."""
        return [p.entropy for p in self._trajectory]

    def get_tokens(self) -> List[str]:
        """Get all tracked tokens."""
        return [p.token for p in self._trajectory]

    def get_stats(self) -> Dict[str, Any]:
        """Get summary statistics (empty dict when nothing has been tracked)."""
        if not self._trajectory:
            return {}

        entropies = [p.entropy for p in self._trajectory]

        return {
            "token_count": len(self._trajectory),
            "valley_count": len(self._valleys),
            "mean_entropy": sum(entropies) / len(entropies),
            "min_entropy": min(entropies),
            "max_entropy": max(entropies),
            "current_entropy": entropies[-1],
            "current_velocity": self.get_velocity(),
            "is_converged": self.should_exit()
        }

    def reset(self) -> None:
        """Clear all tracked data."""
        self._trajectory.clear()
        self._valleys.clear()
        self._index = 0
258
+
259
+
260
def calculate_entropy(logprobs: List[float], from_probs: bool = False) -> float:
    """
    Standalone function to calculate Shannon entropy.

    Values are used as given, with no renormalisation. Zero probabilities
    (for example a log probability of ``-inf``) contribute nothing.

    Args:
        logprobs: List of log probabilities (natural log) or probabilities
        from_probs: If True, treat input as probabilities directly

    Returns:
        Shannon entropy in bits (0.0 for empty input)
    """
    # Computed directly rather than by constructing an EntropyMonitor: the
    # calculation needs no monitor state.
    if not logprobs:
        return 0.0

    entropy = 0.0
    for value in logprobs:
        prob = value if from_probs else math.exp(value)
        # prob > 0 here, so log2 is exact; no epsilon is needed (one would
        # bias the result and make p == 1 give a slightly negative entropy).
        if prob > 0:
            entropy -= prob * math.log2(prob)

    return entropy