entroplain 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,20 @@
1
1
  """
2
2
  Entropy Monitor — Core entropy tracking and early exit logic.
3
+
4
+ Supports multiple exit strategies:
5
+ - Valleys plateau: Exit when reasoning milestones stabilize
6
+ - Entropy drop: Exit when model confidence is high
7
+ - Velocity zero: Exit when entropy stops changing
8
+ - Combined: Multiple conditions with AND/OR logic
9
+ - Repetition: Exit when model starts repeating
10
+ - Confidence: Exit when top token probability > threshold for N tokens
3
11
  """
4
12
 
5
13
  import math
6
14
  from typing import List, Tuple, Optional, Dict, Any, Callable
7
15
  from dataclasses import dataclass, field
8
16
  from enum import Enum
17
+ from collections import Counter
9
18
 
10
19
 
11
20
  class ExitCondition(Enum):
@@ -13,6 +22,10 @@ class ExitCondition(Enum):
13
22
  ENTROPY_DROP = "entropy_drop"
14
23
  VELOCITY_ZERO = "velocity_zero"
15
24
  COMBINED = "combined"
25
+ # New strategies
26
+ REPETITION = "repetition"
27
+ CONFIDENCE = "confidence"
28
+ SEMANTIC = "semantic"
16
29
 
17
30
 
18
31
  @dataclass
@@ -23,6 +36,7 @@ class EntropyPoint:
23
36
  entropy: float
24
37
  is_valley: bool = False
25
38
  velocity: float = 0.0
39
+ confidence: float = 0.0 # Top token probability
26
40
 
27
41
 
28
42
  @dataclass
@@ -35,21 +49,25 @@ class MonitorConfig:
35
49
  valley_window: int = 5
36
50
  plateau_threshold: int = 3
37
51
  exit_condition: ExitCondition = ExitCondition.COMBINED
52
+ # New config options
53
+ repetition_window: int = 20 # Window to check for repetition
54
+ repetition_threshold: float = 0.3 # 30% repetition = exit
55
+ confidence_threshold: float = 0.95 # 95% confidence = exit
56
+ confidence_min_tokens: int = 5 # Min tokens at high confidence
38
57
 
39
58
 
40
59
  class EntropyMonitor:
41
60
  """
42
61
  Monitor entropy trajectory and detect reasoning convergence.
43
-
62
+
44
63
  Usage:
45
64
  monitor = EntropyMonitor()
46
-
47
65
  for token, entropy in stream:
48
66
  monitor.track(token, entropy)
49
67
  if monitor.should_exit():
50
68
  break
51
69
  """
52
-
70
+
53
71
  def __init__(
54
72
  self,
55
73
  entropy_threshold: float = 0.15,
@@ -58,7 +76,12 @@ class EntropyMonitor:
58
76
  min_tokens: int = 50,
59
77
  valley_window: int = 5,
60
78
  plateau_threshold: int = 3,
61
- exit_condition: str = "combined"
79
+ exit_condition: str = "combined",
80
+ # New parameters
81
+ repetition_window: int = 20,
82
+ repetition_threshold: float = 0.3,
83
+ confidence_threshold: float = 0.95,
84
+ confidence_min_tokens: int = 5,
62
85
  ):
63
86
  self.config = MonitorConfig(
64
87
  entropy_threshold=entropy_threshold,
@@ -67,60 +90,66 @@ class EntropyMonitor:
67
90
  min_tokens=min_tokens,
68
91
  valley_window=valley_window,
69
92
  plateau_threshold=plateau_threshold,
70
- exit_condition=ExitCondition(exit_condition)
93
+ exit_condition=ExitCondition(exit_condition),
94
+ repetition_window=repetition_window,
95
+ repetition_threshold=repetition_threshold,
96
+ confidence_threshold=confidence_threshold,
97
+ confidence_min_tokens=confidence_min_tokens,
71
98
  )
72
99
  self._trajectory: List[EntropyPoint] = []
73
100
  self._valleys: List[EntropyPoint] = []
74
101
  self._index = 0
75
-
102
+ self._high_confidence_count = 0 # Track consecutive high confidence
103
+
76
104
  def calculate_entropy(self, logprobs: List[float], from_probs: bool = False) -> float:
77
105
  """
78
106
  Calculate Shannon entropy from log probabilities or probabilities.
79
-
107
+
80
108
  Args:
81
109
  logprobs: List of log probabilities (natural log) or probabilities
82
110
  from_probs: If True, treat input as probabilities (will convert)
83
-
111
+
84
112
  Returns:
85
113
  Shannon entropy in bits
86
114
  """
87
115
  if not logprobs:
88
116
  return 0.0
89
-
117
+
90
118
  entropy = 0.0
91
119
  for lp in logprobs:
92
120
  if from_probs:
93
121
  prob = lp
94
122
  else:
95
123
  prob = math.exp(lp)
96
-
97
124
  if prob > 0:
98
125
  entropy -= prob * math.log2(prob + 1e-10)
99
-
126
+
100
127
  return entropy
101
-
102
- def track(self, token: str, entropy: float) -> EntropyPoint:
128
+
129
+ def track(self, token: str, entropy: float, confidence: float = 0.0) -> EntropyPoint:
103
130
  """
104
131
  Track a token and its entropy value.
105
-
132
+
106
133
  Args:
107
134
  token: The generated token
108
135
  entropy: Calculated entropy for this token
109
-
136
+ confidence: Top token probability (optional, for confidence strategy)
137
+
110
138
  Returns:
111
139
  EntropyPoint with valley detection
112
140
  """
113
141
  point = EntropyPoint(
114
142
  index=self._index,
115
143
  token=token,
116
- entropy=entropy
144
+ entropy=entropy,
145
+ confidence=confidence
117
146
  )
118
-
147
+
119
148
  # Calculate velocity
120
149
  if len(self._trajectory) > 0:
121
150
  prev = self._trajectory[-1]
122
151
  point.velocity = abs(entropy - prev.entropy)
123
-
152
+
124
153
  # Detect valley (local minimum)
125
154
  if len(self._trajectory) >= 2:
126
155
  prev2 = self._trajectory[-2]
@@ -128,117 +157,174 @@ class EntropyMonitor:
128
157
  if prev1.entropy < prev2.entropy and prev1.entropy < entropy:
129
158
  prev1.is_valley = True
130
159
  self._valleys.append(prev1)
131
-
160
+
132
161
  self._trajectory.append(point)
133
162
  self._index += 1
134
-
163
+
164
+ # Track high confidence
165
+ if confidence >= self.config.confidence_threshold:
166
+ self._high_confidence_count += 1
167
+ else:
168
+ self._high_confidence_count = 0
169
+
135
170
  return point
136
-
171
+
137
172
  def get_valleys(self) -> List[Tuple[int, float]]:
138
173
  """Get all entropy valleys (local minima) as (index, entropy) tuples."""
139
174
  return [(v.index, v.entropy) for v in self._valleys]
140
-
175
+
141
176
  def get_velocity(self) -> float:
142
177
  """Get current entropy velocity (rate of change)."""
143
178
  if len(self._trajectory) < 2:
144
179
  return 0.0
145
180
  return self._trajectory[-1].velocity
146
-
181
+
147
182
  def get_mean_entropy(self) -> float:
148
183
  """Get mean entropy over the trajectory."""
149
184
  if not self._trajectory:
150
185
  return 0.0
151
186
  return sum(p.entropy for p in self._trajectory) / len(self._trajectory)
152
-
187
+
153
188
  def get_valley_count(self) -> int:
154
189
  """Get the number of detected valleys."""
155
190
  return len(self._valleys)
156
-
191
+
157
192
  def is_valleys_plateau(self) -> bool:
158
193
  """Check if valley count has plateaued."""
159
194
  if len(self._valleys) < self.config.min_valleys:
160
195
  return False
161
-
196
+
162
197
  # Check if last N valleys have similar spacing
163
198
  recent = self._valleys[-self.config.plateau_threshold:]
164
199
  if len(recent) < self.config.plateau_threshold:
165
200
  return False
166
-
201
+
167
202
  # Calculate spacing between recent valleys
168
- spacings = [recent[i+1].index - recent[i].index for i in range(len(recent)-1)]
203
+ spacings = [
204
+ recent[i + 1].index - recent[i].index
205
+ for i in range(len(recent) - 1)
206
+ ]
169
207
  if not spacings:
170
208
  return False
171
-
209
+
172
210
  mean_spacing = sum(spacings) / len(spacings)
173
- variance = sum((s - mean_spacing)**2 for s in spacings) / len(spacings)
174
-
211
+ variance = sum((s - mean_spacing) ** 2 for s in spacings) / len(spacings)
212
+
175
213
  # Low variance in spacing = plateau
176
214
  return variance < 10 # Threshold tuned empirically
177
-
215
+
178
216
  def is_entropy_low(self) -> bool:
179
217
  """Check if current entropy is below threshold."""
180
218
  if not self._trajectory:
181
219
  return False
182
220
  return self._trajectory[-1].entropy < self.config.entropy_threshold
183
-
221
+
184
222
  def is_velocity_stable(self) -> bool:
185
223
  """Check if velocity is below threshold."""
186
224
  return self.get_velocity() < self.config.velocity_threshold
187
-
225
+
226
+ def is_repeating(self) -> bool:
227
+ """
228
+ Check if the model is repeating itself.
229
+
230
+ Returns True if the repetition ratio in the recent window
231
+ exceeds the threshold.
232
+ """
233
+ if len(self._trajectory) < self.config.repetition_window:
234
+ return False
235
+
236
+ # Get recent tokens
237
+ recent_tokens = [
238
+ p.token for p in self._trajectory[-self.config.repetition_window :]
239
+ ]
240
+
241
+ # Count unique vs total
242
+ counter = Counter(recent_tokens)
243
+ unique_count = len(counter)
244
+ total_count = len(recent_tokens)
245
+
246
+ # Calculate repetition ratio
247
+ repetition_ratio = 1.0 - (unique_count / total_count)
248
+
249
+ return repetition_ratio >= self.config.repetition_threshold
250
+
251
+ def is_confident(self) -> bool:
252
+ """
253
+ Check if model has been highly confident for consecutive tokens.
254
+
255
+ Returns True if the last N tokens had confidence >= threshold.
256
+ """
257
+ return self._high_confidence_count >= self.config.confidence_min_tokens
258
+
188
259
  def should_exit(self) -> bool:
189
260
  """
190
261
  Determine if reasoning has converged and we should exit.
191
-
262
+
192
263
  Uses the configured exit condition:
193
264
  - valleys_plateau: Exit when valley count plateaus
194
265
  - entropy_drop: Exit when entropy drops below threshold
195
266
  - velocity_zero: Exit when velocity stabilizes
196
267
  - combined: Use all conditions with AND logic
268
+ - repetition: Exit when model starts repeating
269
+ - confidence: Exit when confidence is high for N tokens
197
270
  """
198
271
  # Always require minimum tokens
199
272
  if len(self._trajectory) < self.config.min_tokens:
200
273
  return False
201
-
202
- # Always require minimum valleys
274
+
275
+ # Always require minimum valleys (for most strategies)
276
+ condition = self.config.exit_condition
277
+
278
+ if condition == ExitCondition.REPETITION:
279
+ # Repetition doesn't require valleys
280
+ return self.is_repeating()
281
+
282
+ if condition == ExitCondition.CONFIDENCE:
283
+ # Confidence doesn't require valleys
284
+ return self.is_confident()
285
+
286
+ # For other strategies, require minimum valleys
203
287
  if len(self._valleys) < self.config.min_valleys:
204
288
  return False
205
-
206
- condition = self.config.exit_condition
207
-
289
+
208
290
  if condition == ExitCondition.VALLEYS_PLATEAU:
209
291
  return self.is_valleys_plateau()
210
-
292
+
211
293
  if condition == ExitCondition.ENTROPY_DROP:
212
294
  return self.is_entropy_low()
213
-
295
+
214
296
  if condition == ExitCondition.VELOCITY_ZERO:
215
297
  return self.is_velocity_stable()
216
-
298
+
217
299
  if condition == ExitCondition.COMBINED:
218
300
  # Combined: require entropy low OR valleys plateau, AND velocity stable
219
301
  return (self.is_entropy_low() or self.is_valleys_plateau()) and self.is_velocity_stable()
220
-
302
+
303
+ if condition == ExitCondition.SEMANTIC:
304
+ # Placeholder for future semantic convergence detection
305
+ # Would use embeddings to detect when output stabilizes semantically
306
+ return False
307
+
221
308
  return False
222
-
309
+
223
310
  def is_converged(self) -> bool:
224
311
  """Alias for should_exit()."""
225
312
  return self.should_exit()
226
-
313
+
227
314
  def get_trajectory(self) -> List[float]:
228
315
  """Get full entropy trajectory as list of floats."""
229
316
  return [p.entropy for p in self._trajectory]
230
-
317
+
231
318
  def get_tokens(self) -> List[str]:
232
319
  """Get all tracked tokens."""
233
320
  return [p.token for p in self._trajectory]
234
-
321
+
235
322
  def get_stats(self) -> Dict[str, Any]:
236
323
  """Get summary statistics."""
237
324
  if not self._trajectory:
238
325
  return {}
239
-
326
+
240
327
  entropies = [p.entropy for p in self._trajectory]
241
-
242
328
  return {
243
329
  "token_count": len(self._trajectory),
244
330
  "valley_count": len(self._valleys),
@@ -247,26 +333,58 @@ class EntropyMonitor:
247
333
  "max_entropy": max(entropies),
248
334
  "current_entropy": entropies[-1],
249
335
  "current_velocity": self.get_velocity(),
250
- "is_converged": self.should_exit()
336
+ "is_converged": self.should_exit(),
337
+ "exit_reason": self._get_exit_reason(),
251
338
  }
252
-
339
+
340
+ def _get_exit_reason(self) -> Optional[str]:
341
+ """Get the reason for early exit (if triggered)."""
342
+ if not self.should_exit():
343
+ return None
344
+
345
+ condition = self.config.exit_condition
346
+
347
+ if condition == ExitCondition.REPETITION:
348
+ return "repetition_detected"
349
+ if condition == ExitCondition.CONFIDENCE:
350
+ return "high_confidence"
351
+ if condition == ExitCondition.ENTROPY_DROP:
352
+ return "entropy_below_threshold"
353
+ if condition == ExitCondition.VELOCITY_ZERO:
354
+ return "velocity_stable"
355
+ if condition == ExitCondition.VALLEYS_PLATEAU:
356
+ return "valleys_plateau"
357
+ if condition == ExitCondition.COMBINED:
358
+ if self.is_entropy_low() and self.is_velocity_stable():
359
+ return "entropy_low_velocity_stable"
360
+ if self.is_valleys_plateau() and self.is_velocity_stable():
361
+ return "valleys_plateau_velocity_stable"
362
+ return "combined"
363
+
364
+ return "unknown"
365
+
253
366
  def reset(self) -> None:
254
367
  """Clear all tracked data."""
255
368
  self._trajectory.clear()
256
369
  self._valleys.clear()
257
370
  self._index = 0
371
+ self._high_confidence_count = 0
258
372
 
259
373
 
260
- def calculate_entropy(logprobs: List[float], from_probs: bool = False) -> float:
374
# Convenience wrapper for one-shot entropy computation.
def calculate_entropy_from_logprobs(logprobs: List[float]) -> float:
    """
    Compute Shannon entropy (in bits) from natural-log probabilities.

    Args:
        logprobs: Log probabilities (natural log).

    Returns:
        Shannon entropy in bits.
    """
    total = 0.0
    for logprob in logprobs:
        prob = math.exp(logprob)
        # Epsilon keeps log2 finite for vanishingly small probabilities.
        if prob > 0:
            total -= prob * math.log2(prob + 1e-10)
    return total