nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
# Explicit paper-defined variant (avoid inheriting repo default `hope_hybrid`).
|
|
7
|
+
block_variant: hope_attention
|
|
8
|
+
# Paper-faithful: treat "surprise" as the (scaled) teach signal itself, without threshold gating.
|
|
9
|
+
surprise_threshold: null
|
|
10
|
+
# The paper also updates on the final (possibly partial) chunk; flush it when seq_len is not a multiple of the update chunk size.
|
|
11
|
+
cms_flush_partial_at_end: true
|
|
12
|
+
# Paper: q is non-adaptive and uses a fixed projection.
|
|
13
|
+
self_mod_adaptive_q: false
|
|
14
|
+
# Paper: local causal conv in the HOPE self-mod module.
|
|
15
|
+
self_mod_local_conv_window: 4
|
|
16
|
+
|
|
17
|
+
data:
|
|
18
|
+
# Paper-faithful semantics: CMS/TITAN fast state is per-context; this repo currently treats
|
|
19
|
+
# each *batch* as a single shared context when batch_size>1, so batch_size is pinned to 1 here.
|
|
20
|
+
batch_size: 1
|
|
21
|
+
|
|
22
|
+
train:
|
|
23
|
+
algorithm_mode: two_pass_stopgrad_updates
|
|
24
|
+
# Keep this explicit (instead of inherited) so paper-faithful behavior is visible in one file.
|
|
25
|
+
online_updates: true
|
|
26
|
+
# Paper: re-initialize fast memories per context (sequence).
|
|
27
|
+
use_fast_state: true
|
|
28
|
+
strict_streaming_contract: true
|
|
29
|
+
# Use explicit boundary-token supervision (no overlap approximation).
|
|
30
|
+
online_boundary_targets: true
|
|
31
|
+
# Carry attention state across chunks during online updates.
|
|
32
|
+
online_carry_attention_cache: true
|
|
33
|
+
# Fail fast if DDP would silently disable paper-critical features.
|
|
34
|
+
fail_if_paper_faithful_disabled: true
|
|
35
|
+
|
|
36
|
+
optim:
|
|
37
|
+
# Ensure meta-learning updates include memory module initial states (paper §8.2).
|
|
38
|
+
param_policy: all
|
|
39
|
+
|
|
40
|
+
logging:
|
|
41
|
+
run_name: pilot-paper-faithful
|
|
42
|
+
path: logs/pilot_paper_faithful_metrics.json
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot_paper_faithful
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
# Chunk update cadence (paper §8.2): other memories update more often than M_memory.
|
|
8
|
+
self_mod_chunk_size: 8
|
|
9
|
+
self_mod_chunk_size_memory: 64
|
|
10
|
+
self_mod_use_skip: false
|
|
11
|
+
|
|
12
|
+
train:
|
|
13
|
+
checkpoint:
|
|
14
|
+
dir: artifacts/checkpoints/pilot_selfmod_paper_faithful
|
|
15
|
+
|
|
16
|
+
logging:
|
|
17
|
+
run_name: pilot-selfmod-paper-faithful
|
|
18
|
+
path: logs/pilot_selfmod_paper_faithful_metrics.json
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
|
|
8
|
+
model:
|
|
9
|
+
vocab_size: 32000
|
|
10
|
+
dim: 128
|
|
11
|
+
num_layers: 2
|
|
12
|
+
heads: 4
|
|
13
|
+
titan_level:
|
|
14
|
+
name: titan
|
|
15
|
+
update_period: 8
|
|
16
|
+
optimizer_key: titan_opt
|
|
17
|
+
cms_levels:
|
|
18
|
+
- name: cms_fast
|
|
19
|
+
update_period: 1
|
|
20
|
+
optimizer_key: cms_opt
|
|
21
|
+
- name: cms_mid
|
|
22
|
+
update_period: 4
|
|
23
|
+
optimizer_key: cms_opt
|
|
24
|
+
- name: cms_slow
|
|
25
|
+
update_period: 16
|
|
26
|
+
optimizer_key: cms_opt
|
|
27
|
+
optimizers:
|
|
28
|
+
titan_opt:
|
|
29
|
+
type: deep_momentum
|
|
30
|
+
lr: 1.0e-3
|
|
31
|
+
params:
|
|
32
|
+
beta: 0.9
|
|
33
|
+
beta2: 0.999
|
|
34
|
+
cms_opt:
|
|
35
|
+
type: deep_momentum
|
|
36
|
+
lr: 5.0e-4
|
|
37
|
+
params:
|
|
38
|
+
beta: 0.9
|
|
39
|
+
beta2: 0.999
|
|
40
|
+
|
|
41
|
+
data:
|
|
42
|
+
source: synthetic
|
|
43
|
+
vocab_size: 32000
|
|
44
|
+
seq_len: 64
|
|
45
|
+
dataset_size: 1024
|
|
46
|
+
batch_size: 4
|
|
47
|
+
num_workers: 0
|
|
48
|
+
|
|
49
|
+
train:
|
|
50
|
+
strict_streaming_contract: false
|
|
51
|
+
online_updates: true
|
|
52
|
+
online_chunk_size: 0
|
|
53
|
+
online_boundary_targets: false
|
|
54
|
+
online_carry_attention_cache: false
|
|
55
|
+
per_layer_teach_signal: true
|
|
56
|
+
steps: 10
|
|
57
|
+
log_interval: 1
|
|
58
|
+
device: "cpu"
|
|
59
|
+
seed: 1234
|
|
60
|
+
deterministic: true
|
|
61
|
+
mixed_precision:
|
|
62
|
+
enabled: false
|
|
63
|
+
dtype: bf16
|
|
64
|
+
compile:
|
|
65
|
+
enable: false
|
|
66
|
+
checkpoint:
|
|
67
|
+
enable: true
|
|
68
|
+
dir: artifacts/checkpoints/pilot_smoke
|
|
69
|
+
save_interval: 10
|
|
70
|
+
save_last: true
|
|
71
|
+
|
|
72
|
+
optim:
|
|
73
|
+
type: adamw
|
|
74
|
+
lr: 3.0e-4
|
|
75
|
+
fused: false
|
|
76
|
+
|
|
77
|
+
logging:
|
|
78
|
+
enabled: true
|
|
79
|
+
backend: json
|
|
80
|
+
path: logs/pilot_smoke.json
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
model:
|
|
8
|
+
vocab_size: 32000
|
|
9
|
+
dim: 384
|
|
10
|
+
num_layers: 8
|
|
11
|
+
heads: 6
|
|
12
|
+
teach_scale: 0.1
|
|
13
|
+
teach_clip: 5.0
|
|
14
|
+
self_mod_lr: 0.001
|
|
15
|
+
teach_schedule:
|
|
16
|
+
warmup_steps: 2000
|
|
17
|
+
decay_start: 120000
|
|
18
|
+
decay_duration: 20000
|
|
19
|
+
titan_level:
|
|
20
|
+
name: titan
|
|
21
|
+
update_period: 8
|
|
22
|
+
optimizer_key: titan_opt
|
|
23
|
+
cms_levels:
|
|
24
|
+
- name: cms_fast
|
|
25
|
+
update_period: 8
|
|
26
|
+
optimizer_key: cms_opt
|
|
27
|
+
- name: cms_mid
|
|
28
|
+
update_period: 32
|
|
29
|
+
optimizer_key: cms_opt
|
|
30
|
+
- name: cms_slow
|
|
31
|
+
update_period: 128
|
|
32
|
+
optimizer_key: cms_opt
|
|
33
|
+
- name: cms_ultra
|
|
34
|
+
update_period: 512
|
|
35
|
+
optimizer_key: cms_opt
|
|
36
|
+
optimizers:
|
|
37
|
+
titan_opt:
|
|
38
|
+
type: deep_momentum
|
|
39
|
+
lr: 0.0006
|
|
40
|
+
params:
|
|
41
|
+
beta: 0.9
|
|
42
|
+
beta2: 0.999
|
|
43
|
+
cms_opt:
|
|
44
|
+
type: deep_momentum
|
|
45
|
+
lr: 0.0003
|
|
46
|
+
params:
|
|
47
|
+
beta: 0.9
|
|
48
|
+
beta2: 0.999
|
|
49
|
+
cms_hidden_multiplier: 2
|
|
50
|
+
data:
|
|
51
|
+
source: mixture
|
|
52
|
+
seq_len: 1024
|
|
53
|
+
batch_size: 2
|
|
54
|
+
num_workers: 2
|
|
55
|
+
mixture:
|
|
56
|
+
samples_per_epoch: 65536
|
|
57
|
+
seed: 1337
|
|
58
|
+
sources:
|
|
59
|
+
- name: refinedweb
|
|
60
|
+
shards_dir: data/shards/refinedweb_filtered
|
|
61
|
+
weight: 0.4
|
|
62
|
+
- name: wikipedia
|
|
63
|
+
shards_dir: data/shards/wikipedia_filtered
|
|
64
|
+
weight: 0.2
|
|
65
|
+
- name: c4
|
|
66
|
+
shards_dir: data/shards/c4_filtered
|
|
67
|
+
weight: 0.15
|
|
68
|
+
- name: redpajama
|
|
69
|
+
shards_dir: data/shards/redpajama_filtered
|
|
70
|
+
weight: 0.15
|
|
71
|
+
- name: code
|
|
72
|
+
shards_dir: data/shards/code_filtered
|
|
73
|
+
weight: 0.1
|
|
74
|
+
train:
|
|
75
|
+
online_updates: true
|
|
76
|
+
online_chunk_size: 0
|
|
77
|
+
per_layer_teach_signal: true
|
|
78
|
+
steps: 5000
|
|
79
|
+
log_interval: 25
|
|
80
|
+
device: cuda:1
|
|
81
|
+
seed: 1337
|
|
82
|
+
deterministic: false
|
|
83
|
+
mixed_precision:
|
|
84
|
+
enabled: true
|
|
85
|
+
dtype: bf16
|
|
86
|
+
compile:
|
|
87
|
+
enable: false
|
|
88
|
+
mode: max-autotune
|
|
89
|
+
checkpoint:
|
|
90
|
+
enable: true
|
|
91
|
+
dir: artifacts/checkpoints/pilot_cms_sparse
|
|
92
|
+
save_interval: 1000
|
|
93
|
+
save_last: true
|
|
94
|
+
resume_path: null
|
|
95
|
+
resume_tag: null
|
|
96
|
+
optim:
|
|
97
|
+
type: adamw
|
|
98
|
+
lr: 0.00025
|
|
99
|
+
fused: auto
|
|
100
|
+
logging:
|
|
101
|
+
enabled: true
|
|
102
|
+
backend: json
|
|
103
|
+
path: logs/pilot_cms_sparse_metrics.json
|
|
104
|
+
project: nested-learning
|
|
105
|
+
run_name: pilot-cms-sparse
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
model:
|
|
2
|
+
vocab_size: 32000
|
|
3
|
+
dim: 512
|
|
4
|
+
num_layers: 12
|
|
5
|
+
heads: 8
|
|
6
|
+
teach_scale: 0.10
|
|
7
|
+
teach_clip: 5.0
|
|
8
|
+
surprise_threshold: 0.02
|
|
9
|
+
freeze_backbone: false
|
|
10
|
+
qk_l2_norm: true
|
|
11
|
+
local_conv_window: 4
|
|
12
|
+
block_variant: hope_attention
|
|
13
|
+
teach_schedule:
|
|
14
|
+
warmup_steps: 2000
|
|
15
|
+
decay_start: 120000
|
|
16
|
+
decay_duration: 20000
|
|
17
|
+
titan_level:
|
|
18
|
+
name: titan
|
|
19
|
+
update_period: 8
|
|
20
|
+
optimizer_key: titan_opt
|
|
21
|
+
cms_levels:
|
|
22
|
+
- name: cms_fast
|
|
23
|
+
update_period: 1
|
|
24
|
+
optimizer_key: cms_opt
|
|
25
|
+
- name: cms_mid
|
|
26
|
+
update_period: 4
|
|
27
|
+
optimizer_key: cms_opt
|
|
28
|
+
- name: cms_slow
|
|
29
|
+
update_period: 32
|
|
30
|
+
optimizer_key: cms_opt
|
|
31
|
+
- name: cms_ultra
|
|
32
|
+
update_period: 128
|
|
33
|
+
optimizer_key: cms_opt
|
|
34
|
+
optimizers:
|
|
35
|
+
titan_opt:
|
|
36
|
+
type: deep_momentum
|
|
37
|
+
lr: 6.0e-4
|
|
38
|
+
params:
|
|
39
|
+
beta: 0.9
|
|
40
|
+
beta2: 0.999
|
|
41
|
+
variant: nl_l2_precond
|
|
42
|
+
cms_opt:
|
|
43
|
+
type: deep_momentum
|
|
44
|
+
lr: 3.0e-4
|
|
45
|
+
params:
|
|
46
|
+
beta: 0.9
|
|
47
|
+
beta2: 0.999
|
|
48
|
+
variant: nl_l2_precond
|
|
49
|
+
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
model:
|
|
2
|
+
vocab_size: 32000
|
|
3
|
+
dim: 512
|
|
4
|
+
num_layers: 12
|
|
5
|
+
heads: 8
|
|
6
|
+
teach_scale: 0.10
|
|
7
|
+
teach_clip: 5.0
|
|
8
|
+
surprise_threshold: 0.02
|
|
9
|
+
freeze_backbone: false
|
|
10
|
+
qk_l2_norm: true
|
|
11
|
+
local_conv_window: 4
|
|
12
|
+
block_variant: transformer
|
|
13
|
+
teach_schedule:
|
|
14
|
+
warmup_steps: 2000
|
|
15
|
+
decay_start: 120000
|
|
16
|
+
decay_duration: 20000
|
|
17
|
+
titan_level:
|
|
18
|
+
name: titan
|
|
19
|
+
update_period: 8
|
|
20
|
+
optimizer_key: titan_opt
|
|
21
|
+
cms_levels:
|
|
22
|
+
- name: cms_fast
|
|
23
|
+
update_period: 1
|
|
24
|
+
optimizer_key: cms_opt
|
|
25
|
+
- name: cms_mid
|
|
26
|
+
update_period: 4
|
|
27
|
+
optimizer_key: cms_opt
|
|
28
|
+
- name: cms_slow
|
|
29
|
+
update_period: 32
|
|
30
|
+
optimizer_key: cms_opt
|
|
31
|
+
- name: cms_ultra
|
|
32
|
+
update_period: 128
|
|
33
|
+
optimizer_key: cms_opt
|
|
34
|
+
optimizers:
|
|
35
|
+
titan_opt:
|
|
36
|
+
type: deep_momentum
|
|
37
|
+
lr: 6.0e-4
|
|
38
|
+
params:
|
|
39
|
+
beta: 0.9
|
|
40
|
+
beta2: 0.999
|
|
41
|
+
variant: nl_l2_precond
|
|
42
|
+
cms_opt:
|
|
43
|
+
type: deep_momentum
|
|
44
|
+
lr: 3.0e-4
|
|
45
|
+
params:
|
|
46
|
+
beta: 0.9
|
|
47
|
+
beta2: 0.999
|
|
48
|
+
variant: nl_l2_precond
|
|
49
|
+
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Iterable, List, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class ClassificationExample:
    """A single immutable (text, label) pair used by the continual-classification harness.

    Instances are frozen, so they are hashable and safe to share between task splits.
    """

    # Input text of the sample.
    text: str
    # Class label of the sample, stored as a string.
    label: str
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
class LoadedClassificationDataset:
    """An in-memory classification split together with its label vocabulary.

    The dataclass itself is frozen (fields cannot be rebound), but the contained
    lists are ordinary mutable lists.
    """

    # Display name of the dataset (identifier, optionally suffixed with ":<config>").
    name: str
    # Split the examples were read from (e.g. "train" / "test").
    split: str
    # Loaded samples, in source order.
    examples: List[ClassificationExample]
    # Known class names for this dataset.
    label_names: List[str]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_hf_classification_dataset(
    dataset: str,
    *,
    split: str,
    text_field: str,
    label_field: str,
    name: str | None = None,
    max_samples: int | None = None,
) -> LoadedClassificationDataset:
    """
    Load a HuggingFace `datasets` text classification dataset into a simple in-memory format.

    This is used by the Phase 4 continual-learning harness (CLINC/Banking/DBpedia).

    Args:
        dataset: Dataset identifier passed to ``datasets.load_dataset``.
        split: Split to load (e.g. ``"train"`` or ``"test"``).
        text_field: Column holding the input text; each value is converted with ``str``.
        label_field: Column holding the label.
        name: Optional dataset configuration name, forwarded to ``load_dataset``
            and appended to the returned dataset name as ``"<dataset>:<name>"``.
        max_samples: If given, stop after this many rows (``0`` or a negative
            value yields no examples).

    Returns:
        A ``LoadedClassificationDataset``. If the label column's feature exposes
        ``names`` (e.g. a ``ClassLabel``), integer labels in ``[0, len(names))``
        are mapped to those names and ``label_names`` is the feature's full name
        list. Any other label, including ids outside that range such as a ``-1``
        placeholder, is stored as ``str(raw_label)``. Without feature names,
        ``label_names`` is the sorted set of labels actually observed.

    Raises:
        RuntimeError: If the ``datasets`` package cannot be imported.
    """
    try:
        from datasets import load_dataset  # type: ignore[import-not-found]
    except Exception as exc:  # pragma: no cover
        raise RuntimeError(
            "`datasets` dependency is required for continual classification."
        ) from exc

    ds = load_dataset(dataset, name=name, split=split)
    features = getattr(ds, "features", None)
    label_names: List[str] = []
    if features is not None and label_field in features:
        feature = features[label_field]
        if getattr(feature, "names", None) is not None:
            label_names = list(feature.names)

    examples: List[ClassificationExample] = []
    for row in ds:
        if max_samples is not None and len(examples) >= max_samples:
            break
        text = str(row[text_field])
        raw_label = row[label_field]
        # Only treat the label as a class index when it is a valid position.
        # A negative id (HF ClassLabel uses -1 for missing labels) would
        # otherwise wrap around to the *last* class name and silently mislabel
        # the row, and an id past the end would raise IndexError.
        if isinstance(raw_label, int) and 0 <= raw_label < len(label_names):
            label = label_names[raw_label]
        else:
            label = str(raw_label)
        examples.append(ClassificationExample(text=text, label=label))

    if not label_names:
        label_names = sorted({ex.label for ex in examples})

    return LoadedClassificationDataset(
        name=dataset if name is None else f"{dataset}:{name}",
        split=split,
        examples=examples,
        label_names=label_names,
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def load_clinc_oos(
    *,
    split: str = "test",
    max_samples: int | None = None,
    name: str | None = "plus",
) -> LoadedClassificationDataset:
    """Load the HF ``clinc_oos`` intent dataset (fields ``text`` and ``intent``).

    The HF ``clinc_oos`` dataset is published as several configurations
    (``"small"``, ``"imbalanced"``, ``"plus"``), and ``load_dataset`` requires
    one to be chosen. The configuration is therefore forwarded via ``name``,
    defaulting to ``"plus"``. Pass ``name=None`` to issue the config-less call.
    The returned dataset name is ``"clinc_oos:<name>"`` when a configuration is given.

    Args:
        split: Split to load.
        max_samples: Optional cap on the number of rows read.
        name: HF configuration name for ``clinc_oos``.
    """
    return load_hf_classification_dataset(
        "clinc_oos",
        split=split,
        text_field="text",
        label_field="intent",
        name=name,
        max_samples=max_samples,
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def load_banking77(
    *,
    split: str = "test",
    max_samples: int | None = None,
) -> LoadedClassificationDataset:
    """Load the HF ``banking77`` dataset.

    Text is read from the ``text`` column and the class from the ``label`` column.
    """
    columns = {"text_field": "text", "label_field": "label"}
    return load_hf_classification_dataset(
        "banking77", split=split, max_samples=max_samples, **columns
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def load_dbpedia14(
    *,
    split: str = "test",
    max_samples: int | None = None,
) -> LoadedClassificationDataset:
    """Load the HF ``dbpedia_14`` dataset.

    Text is read from the ``content`` column and the class from the ``label`` column.
    """
    columns = {"text_field": "content", "label_field": "label"}
    return load_hf_classification_dataset(
        "dbpedia_14", split=split, max_samples=max_samples, **columns
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def unique_labels(examples: Iterable[ClassificationExample]) -> List[str]:
    """Return each distinct label exactly once, in order of first appearance."""
    # dict keys keep insertion order, so fromkeys() de-duplicates stably.
    return list(dict.fromkeys(example.label for example in examples))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def filter_examples_by_labels(
    examples: Sequence[ClassificationExample],
    *,
    allowed: set[str],
) -> List[ClassificationExample]:
    """Return a new list with only the examples whose label is in ``allowed``.

    Input order is preserved; the input sequence is not modified.
    """
    kept: List[ClassificationExample] = []
    for example in examples:
        if example.label not in allowed:
            continue
        kept.append(example)
    return kept
|