uchi-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uchi/__init__.py +57 -0
- uchi/discretize.py +307 -0
- uchi/distributional.py +105 -0
- uchi/dual_predictor.py +172 -0
- uchi/forest.py +410 -0
- uchi/generative.py +910 -0
- uchi/hoeffding.py +225 -0
- uchi/long_term_store.py +345 -0
- uchi/node_compressor.py +492 -0
- uchi/online_tokenizer.py +349 -0
- uchi/predictor.py +578 -0
- uchi/semantic_tokenizer.py +48 -0
- uchi/tabular.py +401 -0
- uchi/timeseries.py +445 -0
- uchi_python-0.1.0.dist-info/METADATA +468 -0
- uchi_python-0.1.0.dist-info/RECORD +19 -0
- uchi_python-0.1.0.dist-info/WHEEL +5 -0
- uchi_python-0.1.0.dist-info/licenses/LICENSE +21 -0
- uchi_python-0.1.0.dist-info/top_level.txt +1 -0
uchi/generative.py
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
1
|
+
"""
|
|
2
|
+
generative.py
|
|
3
|
+
=============
|
|
4
|
+
Generative capabilities for the Universal Sequence Predictor.
|
|
5
|
+
|
|
6
|
+
The predictor already stores the full conditional distribution P(next | context).
|
|
7
|
+
Generation is sampling from that distribution instead of taking argmax.
|
|
8
|
+
|
|
9
|
+
Sampling controls
|
|
10
|
+
-----------------
|
|
11
|
+
temperature : float (default 1.0)
|
|
12
|
+
< 1.0 → sharper / more deterministic
|
|
13
|
+
> 1.0 → flatter / more creative / more random
|
|
14
|
+
The distribution is raised to (1/T) then renormalised.
|
|
15
|
+
|
|
16
|
+
top_k : int | None
|
|
17
|
+
Keep only the k most probable tokens before sampling.
|
|
18
|
+
|
|
19
|
+
top_p : float | None (nucleus sampling)
|
|
20
|
+
Keep the smallest set of tokens whose cumulative probability ≥ top_p,
|
|
21
|
+
then sample within that nucleus. Balances diversity and coherence
|
|
22
|
+
better than top_k on large vocabularies.
|
|
23
|
+
|
|
24
|
+
Classes
|
|
25
|
+
-------
|
|
26
|
+
SequenceGenerator — auto-regressive text/symbol generation
|
|
27
|
+
TabularGenerator — synthetic tabular row generation (joint distribution)
|
|
28
|
+
TimeSeriesGenerator — sampled multivariate time series generation
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
import random
|
|
35
|
+
from typing import Any
|
|
36
|
+
|
|
37
|
+
from .predictor import UniversalPredictor
|
|
38
|
+
from .discretize import FeatureDiscretizer, LabelEncoder, _to_rows
|
|
39
|
+
from .tabular import _make_predictor, _apply_order, _build_orders, _LABEL_NS
|
|
40
|
+
from .long_term_store import LongTermStore
|
|
41
|
+
from .online_tokenizer import OnlineTokenizer
|
|
42
|
+
from .semantic_tokenizer import SemanticTokenizer
|
|
43
|
+
|
|
44
|
+
# scikit-learn is an optional dependency.  When it is importable the
# generators inherit from the real ``BaseEstimator``; otherwise an empty
# stand-in class is used so the class definitions below still work.
# ``_SKLEARN`` records which of the two cases applied.
try:
    from sklearn.base import BaseEstimator
except ImportError:
    _SKLEARN = False

    class BaseEstimator:  # minimal placeholder base class
        pass
else:
    _SKLEARN = True
|
|
50
|
+
|
|
51
|
+
# Lazy imports for optional components
|
|
52
|
+
def _get_online_tokenizer():
    """Return the ``OnlineTokenizer`` class, importing it at call time."""
    from .online_tokenizer import OnlineTokenizer as tokenizer_cls
    return tokenizer_cls
|
|
55
|
+
|
|
56
|
+
def _get_dual_predictor():
    """Return the ``DualPredictor`` class, importing it at call time."""
    from .dual_predictor import DualPredictor as predictor_cls
    return predictor_cls
|
|
59
|
+
|
|
60
|
+
def _get_long_term_store():
    """Return the ``LongTermStore`` class, importing it at call time."""
    from .long_term_store import LongTermStore as store_cls
    return store_cls
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
66
|
+
# Shared sampling primitive
|
|
67
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
68
|
+
|
|
69
|
+
def _sample_dist(
    dist: dict,
    temperature: float,
    top_k: int | None,
    top_p: float | None,
    rng: random.Random,
) -> Any:
    """
    Sample one token from a probability distribution.

    Order of operations: temperature → top_k → top_p → normalise → sample.

    Parameters
    ----------
    dist : dict
        Maps token → non-negative weight.  The weights do not have to sum
        to one; they are renormalised before sampling.
    temperature : float
        Weights are raised to ``1 / temperature``.  Values ``<= 0`` (and
        exactly ``1.0``) leave the weights untouched.
    top_k : int | None
        Keep exactly the ``top_k`` heaviest tokens (ties are broken in
        favour of the token that appears first in ``dist``).  ``None`` or a
        non-positive value disables the filter.
    top_p : float | None
        Keep the smallest set of heaviest tokens whose share of the
        *remaining* mass is at least ``top_p``.  ``None`` disables it.
    rng : random.Random
        Source of randomness; exactly one ``rng.random()`` call is made.

    Returns
    -------
    A key of ``dist``, or ``None`` when ``dist`` is empty.
    """
    if not dist:
        return None

    tokens = list(dist.keys())
    weights = list(dist.values())
    indices = range(len(weights))

    # Temperature: w_i ← w_i^(1/T).  Dividing by the largest weight first
    # keeps the peak at exactly 1.0, so a very small T can no longer
    # underflow *every* weight to zero (which used to trigger the uniform
    # fallback below — the opposite of "sharper").
    if temperature != 1.0 and temperature > 0:
        peak = max(weights)
        if peak > 0:
            inv_t = 1.0 / temperature
            weights = [(w / peak) ** inv_t if w > 0 else 0.0 for w in weights]

    # Top-k: keep exactly k entries.  Selecting by index (stable sort)
    # avoids the threshold comparison that kept every tied token.
    if top_k is not None and 0 < top_k < len(weights):
        ranked = sorted(indices, key=weights.__getitem__, reverse=True)
        kept = set(ranked[:top_k])
        weights = [w if i in kept else 0.0 for i, w in enumerate(weights)]

    # Nucleus (top-p): the threshold is a fraction of the mass that is
    # still present, because after temperature / top-k the weights no
    # longer sum to one.
    if top_p is not None:
        mass = sum(weights)
        if mass > 0:
            target = top_p * mass
            ranked = sorted(indices, key=weights.__getitem__, reverse=True)
            nucleus = set()
            running = 0.0
            for i in ranked:
                nucleus.add(i)
                running += weights[i]
                if running >= target:
                    break
            weights = [w if i in nucleus else 0.0 for i, w in enumerate(weights)]

    total = sum(weights)
    if total < 1e-12:
        # Degenerate input (all weights ~0): fall back to a uniform draw.
        probs = [1.0 / len(weights)] * len(weights)
    else:
        probs = [w / total for w in weights]

    # Inverse-CDF sample.  Zero-probability tokens are skipped so that
    # neither r == 0.0 nor float round-off at the tail can select a token
    # that was filtered out above.
    r = rng.random()
    running = 0.0
    chosen = tokens[-1]
    for token, p in zip(tokens, probs):
        if p <= 0.0:
            continue
        running += p
        chosen = token
        if r < running:
            break
    return chosen
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _generate_from_predictor(
    p: UniversalPredictor,
    n_tokens: int,
    seed: list | None,
    temperature: float,
    top_k: int | None,
    top_p: float | None,
    rng: random.Random,
    stop_tokens: set | None = None,
    tokenizer=None,
    long_term_store=None,
) -> list:
    """
    Auto-regressively sample up to n_tokens from predictor p.

    The predictor's ``history`` is extended with the seed and the sampled
    tokens while generating, and is always put back to a copy of its
    original contents afterwards — even if prediction or sampling raises.

    Parameters
    ----------
    p : predictor exposing ``history``, ``observe()``, ``predict()`` and
        ``_last_distribution`` (plus ``k`` and ``_vocab`` when
        ``long_term_store`` is given).
    n_tokens : maximum number of tokens to sample.
    seed : optional starting context, observed before sampling begins.
    temperature, top_k, top_p : forwarded to ``_sample_dist``.
    rng : random source forwarded to ``_sample_dist``.
    stop_tokens : sampling stops right after one of these is produced
        (the stop token itself is included in the output).
    tokenizer : optional object with ``tokenize()`` / ``detokenize()``;
        the seed is tokenized and the result is detokenized.
    long_term_store : optional object whose ``blend(dist, ctx, vocab)``
        result replaces the predictor distribution at each step.

    Returns
    -------
    list of sampled tokens (detokenized when a tokenizer is supplied).
    """
    saved = p.history[:]
    generated = []

    try:
        if seed:
            # Use the same "is not None" test as the detokenize step below
            # so a tokenizer is never applied on one side only.
            seed_tokens = tokenizer.tokenize(seed) if tokenizer is not None else seed
            for tok in seed_tokens:
                p.observe(tok)

        for _ in range(n_tokens):
            p.predict()
            dist = dict(p._last_distribution)

            # Three-layer fallback (Problem 8): blend with long-term store
            if long_term_store is not None and p.history:
                if len(p.history) >= p.k:
                    ctx = tuple(p.history[-p.k:])
                else:
                    ctx = tuple(p.history)
                dist = long_term_store.blend(dist, ctx, p._vocab)

            token = _sample_dist(dist, temperature, top_k, top_p, rng)
            if token is None:
                # Empty distribution: nothing left to sample from.
                break
            generated.append(token)
            if stop_tokens and token in stop_tokens:
                break
            p.observe(token)
    finally:
        # Generation must not leave the predictor's context modified.
        p.history = saved

    # Detokenize if we used a tokenizer
    if tokenizer is not None:
        generated = tokenizer.detokenize(generated)

    return generated
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _train_autoregressive(
    p: UniversalPredictor,
    tokens: list,
    tokenizer=None,
    long_term_store=None,
    use_skip_grams=False,
) -> None:
    """
    Train p on a single sequence, one token at a time.

    The predictor's history is cleared first; then for every token the
    predictor is asked to ``predict()``, the token is ``observe()``d and
    passed to ``feedback()``.  The history is cleared again on exit.

    Parameters
    ----------
    p : predictor exposing ``history``, ``predict()``, ``observe()``,
        ``feedback()`` and ``_last_prediction``.
    tokens : the training sequence.
    tokenizer : optional; ``tokenize(tokens)`` is applied before training
        and ``update(tokens, accuracy)`` is called afterwards with the
        fraction of tokens whose prediction matched.
    long_term_store : optional; ``replay(p, tokens)`` is called after
        training when at least one token was processed.
    use_skip_grams : when true, from the fourth token onwards an extra
        observe/feedback pass is made with probability 0.2 on a copy of the
        context that has one randomly chosen token removed.  Uses the
        module-level ``random`` generator, so it is not tied to any
        estimator's ``random_seed``.
    """
    if tokenizer is not None:
        tokens = tokenizer.tokenize(tokens)

    p.history.clear()
    correct = 0
    total = 0
    for i, token in enumerate(tokens):
        p.predict()
        pred = p._last_prediction

        # Simulated Skip-Gram attention: randomly drop one context token
        # during training.  The ``p.history`` guard avoids asking for a
        # random index into an empty context.
        # NOTE(review): no predict() is issued for the shortened context
        # before feedback() — confirm that is what feedback() expects.
        if use_skip_grams and i > 2 and p.history and random.random() < 0.2:
            saved = p.history[:]
            try:
                p.history.pop(random.randrange(len(p.history)))
                p.observe(token)
                p.feedback(token)
            finally:
                # Always continue regular training from the full context.
                p.history = saved

        p.observe(token)
        p.feedback(token)
        total += 1
        if pred == token:
            correct += 1

    # Update tokenizer merge scores with running accuracy
    if tokenizer is not None and total > 0:
        tokenizer.update(tokens, correct / total)

    # Replay high-confidence patterns into long-term store
    if long_term_store is not None and total > 0:
        long_term_store.replay(p, tokens)

    p.history.clear()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
231
|
+
# SequenceGenerator
|
|
232
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
233
|
+
|
|
234
|
+
class SequenceGenerator(BaseEstimator):
    """
    Auto-regressive sequence generator for any token type.

    Works with text (characters or words), DNA, symbol streams, event logs —
    anything the UniversalPredictor can model.

    Parameters
    ----------
    context_length : int
        Trie depth k — number of preceding tokens used as context.
    temperature : float
        Sampling temperature (default 1.0 = unmodified distribution).
    top_k : int | None
        If set, sample only from the top-k most probable tokens.
    top_p : float | None
        Nucleus sampling threshold (e.g. 0.9 = 90% probability mass).
    learning_rate, cred_max, lambda_power : float
        Forwarded to the predictor factory in fit().
    random_seed : int
        Seed of the ``random.Random`` used for sampling.
    use_online_tokenizer : bool
        If true, fit() creates an online tokenizer (limited to
        ``tokenizer_max_merges`` merges) that is applied to training
        sequences and generation seeds.
    tokenizer_max_merges : int
    use_dual_predictor : bool
        If true, fit() builds a ``DualPredictor`` from ``context_length``
        and ``learning_rate`` instead of the default predictor.
    long_term_store : object | None
        Passed to training and to generation.
    use_similarity_fallback, use_positional_weights : bool
        Forwarded to the default predictor factory.
    use_semantic_hashing : bool
        If true, fit() creates a ``SemanticTokenizer``.
        NOTE(review): within this class only score() applies it; training
        and generation do not — confirm this is intended.
    use_skip_grams : bool
        Forwarded to the training routine.

    API
    ---
    gen.fit(sequences)              — train on one sequence or a list of sequences
    gen.partial_fit(sequences)      — online update
    gen.generate(n, seed, **kwargs) — sample n tokens given optional seed
    gen.generate_text(n, seed, sep) — convenience wrapper: joins tokens with sep
    gen.sample_next(**kwargs)       — sample one next token from current state
    gen.score(sequence)             — average bits-per-token (lower = better)
    """

    def __init__(
        self,
        context_length: int = 6,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = None,
        learning_rate: float = 0.08,
        cred_max: float = 6.05,
        lambda_power: float = 0.65,
        random_seed: int = 42,
        use_online_tokenizer: bool = False,
        tokenizer_max_merges: int = 64,
        use_dual_predictor: bool = False,
        long_term_store: Any = None,
        use_similarity_fallback: bool = False,
        use_positional_weights: bool = False,
        use_semantic_hashing: bool = False,
        use_skip_grams: bool = False,
    ):
        self.context_length = context_length
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.learning_rate = learning_rate
        self.cred_max = cred_max
        self.lambda_power = lambda_power
        self.random_seed = random_seed
        self.use_online_tokenizer = use_online_tokenizer
        self.tokenizer_max_merges = tokenizer_max_merges
        self.use_dual_predictor = use_dual_predictor
        self.long_term_store = long_term_store
        self.use_similarity_fallback = use_similarity_fallback
        self.use_positional_weights = use_positional_weights
        self.use_semantic_hashing = use_semantic_hashing
        self.use_skip_grams = use_skip_grams

    # ── public API ────────────────────────────────────────────────────────────

    def fit(self, sequences, y=None) -> 'SequenceGenerator':
        """
        Train on a sequence or list of sequences, discarding any earlier model.

        sequences: str | list | list-of-lists
            A single string is treated as a character sequence.
        y : ignored; present for estimator-API compatibility.

        Returns self.
        """
        if self.use_dual_predictor:
            DualPredictor = _get_dual_predictor()
            self._pred = DualPredictor(
                self.context_length,
                learning_rate=self.learning_rate,
            )
        else:
            self._pred = _make_predictor(
                self.context_length, self.learning_rate, self.cred_max, self.lambda_power,
                use_similarity_fallback=self.use_similarity_fallback,
                use_positional_weights=self.use_positional_weights,
            )
        self._rng = random.Random(self.random_seed)
        self._tokenizer = None
        if self.use_online_tokenizer:
            OT = _get_online_tokenizer()
            self._tokenizer = OT(max_merges=self.tokenizer_max_merges)
        self.semantic_tokenizer = SemanticTokenizer() if self.use_semantic_hashing else None
        self._train_sequences(sequences)
        # Only flag the estimator as fitted once training has succeeded.
        self.is_fitted_ = True
        return self

    def partial_fit(self, sequences, y=None) -> 'SequenceGenerator':
        """Continue training on more sequences; falls back to fit() if unfitted."""
        if not hasattr(self, '_pred'):
            return self.fit(sequences)
        self._train_sequences(sequences)
        return self

    def generate(
        self,
        n_tokens: int,
        seed: list | str | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        stop_tokens: list | None = None,
    ) -> list:
        """
        Sample n_tokens auto-regressively.

        Parameters
        ----------
        seed : list | str | None
            Starting context.  A string is treated as a list of characters.
        temperature, top_k, top_p : override instance defaults for this call.
            Passing ``None`` means "use the instance value".
        stop_tokens : list | None
            Generation halts early if any of these tokens is sampled.

        Returns list of tokens (same type as training tokens).

        Raises RuntimeError if fit() has not been called.
        """
        self._check_fitted()
        seed_list = list(seed) if isinstance(seed, str) else (seed or [])
        return _generate_from_predictor(
            self._pred, n_tokens, seed_list,
            temperature if temperature is not None else self.temperature,
            top_k if top_k is not None else self.top_k,
            top_p if top_p is not None else self.top_p,
            self._rng,
            set(stop_tokens) if stop_tokens else None,
            tokenizer=getattr(self, '_tokenizer', None),
            long_term_store=self.long_term_store,
        )

    def generate_text(
        self,
        n_tokens: int,
        seed: str | None = None,
        sep: str = '',
        **kwargs,
    ) -> str:
        """Convenience wrapper: generate and join tokens as a string."""
        tokens = self.generate(n_tokens, seed=list(seed) if seed else None, **kwargs)
        return sep.join(str(t) for t in tokens)

    def sample_next(self, temperature=None, top_k=None, top_p=None) -> Any:
        """Sample one next token given the current internal history."""
        self._check_fitted()
        self._pred.predict()
        return _sample_dist(
            dict(self._pred._last_distribution),
            temperature if temperature is not None else self.temperature,
            top_k if top_k is not None else self.top_k,
            top_p if top_p is not None else self.top_p,
            self._rng,
        )

    def observe(self, token) -> 'SequenceGenerator':
        """Advance internal state by one token (does not update trie)."""
        self._check_fitted()
        self._pred.observe(token)
        return self

    def score(self, sequence) -> float:
        """
        Average bits-per-token on a held-out sequence (lower = better).

        Returns ``inf`` for an empty sequence.  A token missing from the
        predicted distribution is scored with probability 1e-12.  The
        predictor's history is restored afterwards, even on error.
        """
        self._check_fitted()
        tokens = list(sequence)
        if not tokens:
            return float('inf')

        semantic = getattr(self, 'semantic_tokenizer', None)
        if semantic:
            tokens = [semantic.tokenize(t) for t in tokens]

        pred = self._pred
        saved = pred.history[:]
        total_bits = 0.0
        try:
            for token in tokens:
                pred.predict()
                prob = max(pred._last_distribution.get(token, 1e-12), 1e-12)
                total_bits -= math.log2(prob)
                pred.observe(token)
        finally:
            # Scoring must never leave the generation context modified.
            pred.history = saved
        return total_bits / len(tokens)

    @property
    def vocab_(self) -> set:
        """Set of all tokens seen during training."""
        return set(self._pred._vocab) if hasattr(self, '_pred') else set()

    # ── internal ──────────────────────────────────────────────────────────────

    def _train_sequences(self, sequences) -> None:
        """Train on one sequence (str or flat list) or on each of several."""
        tok = getattr(self, '_tokenizer', None)
        is_single = isinstance(sequences, str) or (
            sequences and not isinstance(sequences[0], (list, tuple))
        )
        # A string or a flat list is one sequence; anything else is
        # iterated as a collection of sequences.
        batch = [sequences] if is_single else sequences
        for seq in batch:
            _train_autoregressive(
                self._pred, list(seq),
                tokenizer=tok,
                long_term_store=self.long_term_store,
                use_skip_grams=self.use_skip_grams,
            )

    def _check_fitted(self):
        """Raise RuntimeError unless fit() has created the predictor."""
        if not hasattr(self, '_pred'):
            raise RuntimeError("Call fit() first.")
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
445
|
+
# TabularGenerator
|
|
446
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
447
|
+
|
|
448
|
+
class TabularGenerator(BaseEstimator):
    """
    Synthetic tabular row generator via joint distribution modeling.

    Unlike TabularPredictor (which learns P(label | features)), this learns the
    full joint P(f_0, f_1, ..., f_{n-1}, label) by treating every feature token
    AND the label token as one auto-regressive sequence.

    This allows:
      • Unconditional generation   — sample complete rows from the joint
      • Class-conditional generation — fix the label, sample features
      • Feature-conditional generation — fix some features, sample the rest
        (NOTE: not implemented yet; see sample())

    Feature order matters here (earlier features are conditioned on fewer
    preceding tokens).  MI-ascending order puts the most predictive feature
    last so it can condition on the most context.

    Parameters
    ----------
    n_bins : int
        Quantile bins for continuous features.
    context_length : int | None
        Trie depth.  None = n_features + 1 (full row context).
    n_orderings : int
        Ensemble size for diversity.
    n_epochs : int
        Number of shuffled passes over the data in fit().
    temperature : float
    top_k, top_p : sampling controls
    learning_rate, cred_max, lambda_power : float
    random_seed : int
        Seed of the ``random.Random`` used for ordering, shuffling and sampling.
    long_term_store : object | None
        Stored on the instance; not used by this class.
    use_similarity_fallback, use_positional_weights : bool
        Forwarded to the predictor factory.
    """

    def __init__(
        self,
        n_bins: int = 10,
        context_length: int | None = None,
        n_orderings: int = 3,
        n_epochs: int = 1,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = None,
        learning_rate: float = 0.08,
        cred_max: float = 6.05,
        lambda_power: float = 0.65,
        random_seed: int = 42,
        long_term_store: Any = None,
        use_similarity_fallback: bool = False,
        use_positional_weights: bool = False,
    ):
        self.n_bins = n_bins
        self.context_length = context_length
        self.n_orderings = n_orderings
        self.n_epochs = n_epochs
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.learning_rate = learning_rate
        self.cred_max = cred_max
        self.lambda_power = lambda_power
        self.random_seed = random_seed
        self.long_term_store = long_term_store
        self.use_similarity_fallback = use_similarity_fallback
        self.use_positional_weights = use_positional_weights

    # ── public API ────────────────────────────────────────────────────────────

    def fit(self, X, y) -> 'TabularGenerator':
        """
        Fit on a labelled dataset.  Learns the full joint distribution.
        X : feature matrix (numpy, pandas, or list-of-lists)
        y : class labels

        Returns self.
        """
        self._disc = FeatureDiscretizer(n_bins=self.n_bins)
        self._lenc = LabelEncoder()
        self._rng = random.Random(self.random_seed)
        self._preds = []
        self._orders = []

        rows = self._disc.fit_transform(X)
        labels = list(y)
        self._lenc.fit(labels)
        y_enc = [self._lenc.encode(l) for l in labels]

        n_feat = self._disc.n_features
        k = (n_feat + 1) if self.context_length is None else self.context_length
        self._orders = _build_orders(rows, y_enc, n_feat, self.n_orderings, self._rng)
        self._preds = [self._new_predictor(k) for _ in self._orders]

        # A second predictor trained label-first enables class-conditional generation.
        # P(f0,...,fn-1 | label) is modelled correctly only when label precedes features.
        # MI-descending order: most discriminative features immediately follow the label,
        # so even shallow-context fallbacks (label + 1 feature) capture class structure.
        self._cond_pred = self._new_predictor(k)
        self._cond_order = list(reversed(self._orders[0]))  # MI-descending for conditional

        for _ in range(self.n_epochs):
            pairs = list(zip(rows, labels))
            self._rng.shuffle(pairs)
            for tok_row, label in pairs:
                self._train_row(tok_row, label)

        self.is_fitted_ = True
        return self

    def partial_fit(self, X, y) -> 'TabularGenerator':
        """
        Update the fitted model with more rows (no shuffling, single pass).

        Falls back to fit() when the generator has not been fitted yet.
        The discretizer and the feature orderings from fit() are reused.
        """
        if not hasattr(self, '_disc'):
            return self.fit(X, y)
        rows = self._disc.transform(X)
        labels = list(y)
        self._lenc.partial_fit(labels)
        for tok_row, label in zip(rows, labels):
            self._train_row(tok_row, label)
        return self

    def sample(
        self,
        n_rows: int = 1,
        given_label: Any = None,
        given_features: dict | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> list:
        """
        Generate n_rows synthetic rows.

        Parameters
        ----------
        given_label : any class label
            If set, condition on this label (label-first generation).
        given_features : {feature_index: value} | None
            Intended to fix specific feature values.  NOTE: the current
            implementation does not use this argument; a RuntimeWarning is
            issued when it is supplied so the caller is not misled.
        temperature, top_k, top_p : override instance sampling defaults.

        Returns
        -------
        list of dicts: [{'features': [...], 'label': ...}, ...]

        Raises RuntimeError if fit() has not been called.
        """
        self._check_fitted()
        if given_features:
            import warnings
            warnings.warn(
                "given_features is currently ignored by TabularGenerator.sample(); "
                "rows are sampled without feature conditioning.",
                RuntimeWarning,
                stacklevel=2,
            )
        T = temperature if temperature is not None else self.temperature
        K = top_k if top_k is not None else self.top_k
        P = top_p if top_p is not None else self.top_p
        return [
            self._sample_row(given_label, given_features, T, K, P)
            for _ in range(n_rows)
        ]

    def sample_dataframe(self, n_rows: int = 1, **kwargs):
        """Like sample() but returns a pandas DataFrame.  Requires pandas."""
        import pandas as pd
        rows = self.sample(n_rows, **kwargs)
        X_cols = {f'feature_{i}': [r['features'][i] for r in rows]
                  for i in range(self._disc.n_features)}
        X_cols['label'] = [r['label'] for r in rows]
        return pd.DataFrame(X_cols)

    # ── internal ──────────────────────────────────────────────────────────────

    def _new_predictor(self, k):
        """Build one predictor with this generator's hyper-parameters."""
        return _make_predictor(
            k, self.learning_rate, self.cred_max, self.lambda_power,
            use_similarity_fallback=self.use_similarity_fallback,
            use_positional_weights=self.use_positional_weights,
        )

    def _label_token(self, label) -> tuple:
        """Return the namespaced token that represents a class label."""
        return (_LABEL_NS, self._lenc.encode(label))

    def _train_row(self, tok_row, label) -> None:
        """Train every ordering predictor and the label-first predictor on one row."""
        lt = self._label_token(label)
        for p, order in zip(self._preds, self._orders):
            # features (in this predictor's order) followed by the label
            _train_autoregressive(p, _apply_order(tok_row, order) + [lt])
        # label-first sequence for conditional predictor
        _train_autoregressive(self._cond_pred, [lt] + _apply_order(tok_row, self._cond_order))

    @staticmethod
    def _distribution_at(pred, context_tokens) -> dict:
        """Return pred's next-token distribution for a context, restoring its history."""
        saved = pred.history[:]
        pred.history = list(context_tokens)
        try:
            pred.predict()
            return dict(pred._last_distribution)
        finally:
            pred.history = saved

    def _average_distribution(self, context_tokens) -> dict:
        """Sum the ordering predictors' distributions for a context and renormalise."""
        combined: dict = {}
        for pred in self._preds:
            for tok, prob in self._distribution_at(pred, context_tokens).items():
                combined[tok] = combined.get(tok, 0.0) + prob
        total = sum(combined.values())
        if total < 1e-12:
            return combined
        return {t: v / total for t, v in combined.items()}

    def _sample_feature_token(self, dist, col_idx, T, K, P):
        """Sample a token of column col_idx from dist; (col_idx, 0) if none is present."""
        feat_dist = {t: v for t, v in dist.items()
                     if isinstance(t, tuple) and len(t) == 2
                     and isinstance(t[0], int) and t[0] == col_idx}
        if not feat_dist:
            return (col_idx, 0)
        return _sample_dist(feat_dist, T, K, P, self._rng)

    def _sample_row(self, given_label, given_features, T, K, P) -> dict:
        """
        Sample a single row as ``{'features': [...], 'label': ...}``.

        With ``given_label`` the label-first predictor is used and the label
        is returned unchanged; otherwise features are sampled first and the
        label last.  ``given_features`` is accepted but not used.
        """
        feature_values = [None] * self._disc.n_features

        if given_label is not None:
            # Class-conditional generation using the label-first predictor.
            # _cond_pred was trained on [label, f0, f1, ..., fn-1] so
            # P(f_i | label, f_0..f_{i-1}) is correctly modelled here.
            context = [self._label_token(given_label)]
            for col_idx in self._cond_order:
                dist = self._distribution_at(self._cond_pred, context)
                token = self._sample_feature_token(dist, col_idx, T, K, P)
                context.append(token)
                feature_values[col_idx] = self._decode_feature_token(token)
            return {'features': feature_values, 'label': given_label}

        # Unconditional: sample features in the first ordering, then label.
        # NOTE(review): every ordering predictor is queried with a context
        # laid out in the FIRST ordering, although each was trained on its
        # own ordering — confirm this is intended.
        context = []
        for col_idx in self._orders[0]:
            dist = self._average_distribution(context)
            token = self._sample_feature_token(dist, col_idx, T, K, P)
            context.append(token)
            feature_values[col_idx] = self._decode_feature_token(token)

        # Sample label given all features
        dist = self._average_distribution(context)
        lbl_dist = {t: v for t, v in dist.items()
                    if isinstance(t, tuple) and t[0] == _LABEL_NS}
        lbl_token = _sample_dist(lbl_dist, T, K, P, self._rng) if lbl_dist else None
        if lbl_token:
            label = self._lenc.decode(lbl_token[1])
        else:
            # No label token available: fall back to the first known class.
            label = self._lenc.classes_[0]

        return {'features': feature_values, 'label': label}

    def _decode_feature_token(self, token):
        """Convert (col_idx, bin_or_code) back to an approximate feature value."""
        if token is None:
            return None
        col_idx, bin_val = token
        if col_idx >= len(self._disc._types):
            return bin_val
        if self._disc._types[col_idx] == 'numeric':
            return self._disc.bin_center(col_idx, bin_val)
        # Categorical: reverse the int code
        cat_map = self._disc._cat_maps.get(col_idx, {})
        rev = {v: k for k, v in cat_map.items()}
        return rev.get(bin_val, bin_val)

    def _check_fitted(self):
        """Raise RuntimeError unless fit() has created the discretizer."""
        if not hasattr(self, '_disc'):
            raise RuntimeError("Call fit() first.")
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
721
|
+
# TimeSeriesGenerator
|
|
722
|
+
# ══════════════════════════════════════════════════════════════════════════════
|
|
723
|
+
|
|
724
|
+
class TimeSeriesGenerator(BaseEstimator):
    """
    Sampled multivariate time series generator.

    Extends the predictor to draw new sequences from the learned distribution
    instead of returning the argmax/mean (which forecast() does).

    Can be used for data augmentation, simulation, or scenario generation.

    Parameters
    ----------
    n_bins : int
        Bin count handed to the FeatureDiscretizer created in fit().
    context_length : int
        Forwarded to the predictor factory in fit().
    temperature : float
        Default sampling temperature used by generate().
    top_k, top_p : sampling controls (useful for avoiding repetitive sequences)
        Defaults used by generate() when the per-call value is None.
    learning_rate, cred_max, lambda_power : float
        Forwarded to the predictor factory in fit().
    random_seed : int
        Seed of the private random.Random instance created in fit().
    use_dual_predictor, long_term_store, use_similarity_fallback,
    use_positional_weights
        Stored on the instance as given.
        NOTE(review): no method of this class reads these four attributes and
        fit() does not pass them to the predictor factory — confirm whether
        they are meant to be forwarded.

    API
    ---
    gen.fit(X)                        — fit discretizer and build trie
    gen.generate(n_steps, seed, ...)  — sample a new sequence of n_steps
    gen.augment(X, n_copies, ...)     — generate n_copies similar sequences
    gen.score(X)                      — average bits-per-step
    """

    def __init__(
        self,
        n_bins: int = 8,
        context_length: int = 5,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = None,
        learning_rate: float = 0.08,
        cred_max: float = 6.05,
        lambda_power: float = 0.65,
        random_seed: int = 42,
        use_dual_predictor: bool = False,
        long_term_store: Any = None,
        use_similarity_fallback: bool = False,
        use_positional_weights: bool = False,
    ):
        self.n_bins = n_bins
        self.context_length = context_length
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.learning_rate = learning_rate
        self.cred_max = cred_max
        self.lambda_power = lambda_power
        self.random_seed = random_seed
        self.use_dual_predictor = use_dual_predictor
        self.long_term_store = long_term_store
        self.use_similarity_fallback = use_similarity_fallback
        self.use_positional_weights = use_positional_weights

    # ── public API ────────────────────────────────────────────────────────────

    def fit(self, X, y=None) -> 'TimeSeriesGenerator':
        """
        Fit the discretizer on X and run every row through the predictor.

        Each row is encoded, turned into a compound token, and passed through
        predict() / observe() / feedback() in that order.

        y is accepted but not used.  Returns self.  When X yields no rows the
        method returns immediately and the instance stays unfitted.
        """
        from .timeseries import _compound_token, _make_predictor as _ts_make

        rows = self._as_rows(X)
        if not rows:
            return self

        self._n_dims = len(rows[0])
        self._disc = FeatureDiscretizer(n_bins=self.n_bins)
        self._disc.fit(rows)
        self._pred = _ts_make(
            self.context_length, self.learning_rate, self.cred_max, self.lambda_power,
        )
        self._rng = random.Random(self.random_seed)
        # Kept on the instance so generate()/score() need not re-import it.
        self._compound_token = _compound_token

        for row in rows:
            token = _compound_token(self._disc._encode_row(row))
            self._pred.predict()
            self._pred.observe(token)
            self._pred.feedback(token)

        self.is_fitted_ = True
        return self

    def generate(
        self,
        n_steps: int,
        seed: list | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> list:
        """
        Sample a new time series of n_steps steps.

        seed: list of float vectors to prime the context (bare numbers are
              treated as one-dimensional rows).
        temperature, top_k, top_p: per-call overrides; None falls back to the
              value given to the constructor.

        Returns list of float vectors (one per step).  The list is shorter
        than n_steps if the sampler returns None before all steps are drawn.

        The predictor's history is snapshotted first and always restored,
        even when seeding or sampling raises.

        Raises RuntimeError if fit() has not been called.
        """
        self._check_fitted()
        T = temperature if temperature is not None else self.temperature
        K = top_k if top_k is not None else self.top_k
        P = top_p if top_p is not None else self.top_p

        saved = self._pred.history[:]
        results = []
        try:
            if seed:
                for x in seed:
                    row = [x] if isinstance(x, (int, float)) else list(x)
                    self._pred.observe(self._compound_token(self._disc._encode_row(row)))

            for _ in range(n_steps):
                self._pred.predict()
                dist = dict(self._pred._last_distribution)
                token = _sample_dist(dist, T, K, P, self._rng)
                if token is None:
                    break
                results.append(self._decode_token(token))
                self._pred.observe(token)
        finally:
            # Restore the context unconditionally: a failure half-way through
            # must not leave seed/sampled tokens in the fitted predictor.
            self._pred.history = saved
        return results

    def augment(
        self,
        X,
        n_copies: int = 1,
        temperature: float = 1.1,
        **kwargs,
    ) -> list:
        """
        Generate n_copies perturbed versions of X by seeding with X then sampling.

        Each copy is one generate() call with n_steps=len(X) and seed=X.
        temperature > 1.0 adds variety; temperature < 1.0 stays close to X.
        Extra keyword arguments are forwarded to generate(); n_steps and seed
        are already supplied here, so passing them again raises TypeError.
        The random generator is not reseeded between copies.

        Returns list of generated sequences (empty list when X has no rows).
        Raises RuntimeError if fit() has not been called.
        """
        self._check_fitted()
        rows = self._as_rows(X)
        if not rows:
            return []

        return [
            self.generate(
                n_steps=len(rows),
                seed=rows,
                temperature=temperature,
                **kwargs,
            )
            for _ in range(n_copies)
        ]

    def score(self, X, y=None) -> float:
        """
        Average bits-per-step (lower = better).

        For every row the predictor's distribution is queried and the
        probability of the observed token (floored at 1e-12) contributes
        -log2(p).  feedback() is never called and the predictor's history is
        restored afterwards, also when an exception interrupts the loop.

        y is accepted but not used.  Returns float('inf') when X has no rows.
        Raises RuntimeError if fit() has not been called.
        """
        self._check_fitted()
        rows = self._as_rows(X)
        if not rows:
            return float('inf')

        saved = self._pred.history[:]
        total = 0.0
        try:
            for row in rows:
                token = self._compound_token(self._disc._encode_row(row))
                self._pred.predict()
                prob = max(self._pred._last_distribution.get(token, 1e-12), 1e-12)
                total += -math.log2(prob)
                self._pred.observe(token)
        finally:
            # Scoring must not leave the scored rows in the predictor context.
            self._pred.history = saved
        return total / len(rows)

    # ── internal ──────────────────────────────────────────────────────────────

    @staticmethod
    def _as_rows(X):
        """Run X through _to_rows and wrap scalar entries as one-element rows."""
        rows = _to_rows(X)
        if rows and not isinstance(rows[0], (list, tuple)):
            rows = [[v] for v in rows]
        return rows

    def _decode_token(self, token) -> list:
        """
        Map a compound token back to one value per dimension.

        A token that is not a tuple of length _n_dims decodes to the centre of
        the middle bin (n_bins // 2) in every dimension.  Within a well-formed
        token, non-int components decode to 0.0.
        """
        if not isinstance(token, tuple) or len(token) != self._n_dims:
            # Fallback is built only here, so the common path does no extra
            # bin_center() calls.
            mid = self.n_bins // 2
            return [self._disc.bin_center(d, mid) for d in range(self._n_dims)]
        return [self._disc.bin_center(d, b) if isinstance(b, int) else 0.0
                for d, b in enumerate(token)]

    def _check_fitted(self):
        """Raise RuntimeError unless fit() has created the predictor."""
        if not hasattr(self, '_pred'):
            raise RuntimeError("Call fit() first.")
|