flash-ansr 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash_ansr/__init__.py +29 -0
- flash_ansr/__main__.py +347 -0
- flash_ansr/baselines/__init__.py +7 -0
- flash_ansr/baselines/brute_force_model.py +306 -0
- flash_ansr/baselines/skeleton_pool_model.py +361 -0
- flash_ansr/benchmarks/__init__.py +5 -0
- flash_ansr/benchmarks/fastsrb.py +519 -0
- flash_ansr/compat/__init__.py +1 -0
- flash_ansr/compat/convert_data.py +349 -0
- flash_ansr/compat/evaluation_nesymres.py +85 -0
- flash_ansr/compat/evaluation_pysr.py +95 -0
- flash_ansr/compat/nesymres.py +74 -0
- flash_ansr/data/__init__.py +3 -0
- flash_ansr/data/collate.py +197 -0
- flash_ansr/data/data.py +514 -0
- flash_ansr/data/streaming.py +390 -0
- flash_ansr/decoding/mcts.py +552 -0
- flash_ansr/eval/__init__.py +2 -0
- flash_ansr/eval/core.py +109 -0
- flash_ansr/eval/data_sources.py +963 -0
- flash_ansr/eval/engine.py +250 -0
- flash_ansr/eval/evaluation.py +156 -0
- flash_ansr/eval/evaluation_fastsrb.py +287 -0
- flash_ansr/eval/metrics/__init__.py +10 -0
- flash_ansr/eval/metrics/bootstrap.py +31 -0
- flash_ansr/eval/metrics/token_prediction.py +464 -0
- flash_ansr/eval/metrics/zss.py +42 -0
- flash_ansr/eval/model_adapters.py +885 -0
- flash_ansr/eval/result_store.py +117 -0
- flash_ansr/eval/run_config.py +679 -0
- flash_ansr/eval/sample_metadata.py +83 -0
- flash_ansr/expressions/__init__.py +1 -0
- flash_ansr/expressions/compilation.py +37 -0
- flash_ansr/expressions/distributions.py +147 -0
- flash_ansr/expressions/holdout.py +85 -0
- flash_ansr/expressions/normalization.py +73 -0
- flash_ansr/expressions/prior_factory.py +31 -0
- flash_ansr/expressions/skeleton_pool.py +755 -0
- flash_ansr/expressions/skeleton_sampling.py +129 -0
- flash_ansr/expressions/structure.py +32 -0
- flash_ansr/expressions/support_sampling.py +457 -0
- flash_ansr/expressions/token_ops.py +128 -0
- flash_ansr/flash_ansr.py +1018 -0
- flash_ansr/generation/__init__.py +3 -0
- flash_ansr/generation/beam.py +47 -0
- flash_ansr/generation/mcts.py +127 -0
- flash_ansr/model/__init__.py +24 -0
- flash_ansr/model/common/__init__.py +9 -0
- flash_ansr/model/common/components.py +118 -0
- flash_ansr/model/decoders/__init__.py +10 -0
- flash_ansr/model/decoders/components.py +239 -0
- flash_ansr/model/decoders/transformer.py +84 -0
- flash_ansr/model/encoders/__init__.py +9 -0
- flash_ansr/model/encoders/base.py +78 -0
- flash_ansr/model/encoders/set_transformer.py +454 -0
- flash_ansr/model/factory.py +41 -0
- flash_ansr/model/flash_ansr_model.py +835 -0
- flash_ansr/model/manage.py +37 -0
- flash_ansr/model/pre_encoder.py +29 -0
- flash_ansr/model/tokenizer.py +279 -0
- flash_ansr/preprocessing/__init__.py +14 -0
- flash_ansr/preprocessing/feature_extractor.py +834 -0
- flash_ansr/preprocessing/pipeline.py +245 -0
- flash_ansr/preprocessing/prompt_serialization.py +240 -0
- flash_ansr/preprocessing/schemas.py +23 -0
- flash_ansr/refine.py +531 -0
- flash_ansr/results.py +142 -0
- flash_ansr/train/__init__.py +3 -0
- flash_ansr/train/optimizers.py +14 -0
- flash_ansr/train/schedules.py +21 -0
- flash_ansr/train/train.py +903 -0
- flash_ansr/utils/__init__.py +22 -0
- flash_ansr/utils/config_io.py +155 -0
- flash_ansr/utils/generation.py +262 -0
- flash_ansr/utils/numeric.py +105 -0
- flash_ansr/utils/paths.py +39 -0
- flash_ansr/utils/tensor_ops.py +59 -0
- flash_ansr-0.4.2.dist-info/METADATA +152 -0
- flash_ansr-0.4.2.dist-info/RECORD +83 -0
- flash_ansr-0.4.2.dist-info/WHEEL +5 -0
- flash_ansr-0.4.2.dist-info/entry_points.txt +2 -0
- flash_ansr-0.4.2.dist-info/licenses/LICENSE +21 -0
- flash_ansr-0.4.2.dist-info/top_level.txt +1 -0
flash_ansr/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .model import (
|
|
2
|
+
ModelFactory,
|
|
3
|
+
FlashANSRModel,
|
|
4
|
+
SetTransformer,
|
|
5
|
+
Tokenizer,
|
|
6
|
+
RotaryEmbedding,
|
|
7
|
+
IEEE75432PreEncoder,
|
|
8
|
+
install_model,
|
|
9
|
+
remove_model,
|
|
10
|
+
)
|
|
11
|
+
from .expressions import SkeletonPool, NoValidSampleFoundError
|
|
12
|
+
from .utils import (
|
|
13
|
+
GenerationConfig,
|
|
14
|
+
GenerationConfigBase,
|
|
15
|
+
BeamSearchConfig,
|
|
16
|
+
SoftmaxSamplingConfig,
|
|
17
|
+
MCTSGenerationConfig,
|
|
18
|
+
create_generation_config,
|
|
19
|
+
get_path,
|
|
20
|
+
load_config,
|
|
21
|
+
save_config,
|
|
22
|
+
substitute_root_path,
|
|
23
|
+
)
|
|
24
|
+
from .eval import Evaluation
|
|
25
|
+
from .refine import Refiner, ConvergenceError
|
|
26
|
+
from .flash_ansr import FlashANSR
|
|
27
|
+
from .baselines import SkeletonPoolModel, BruteForceModel
|
|
28
|
+
from .data.data import FlashANSRDataset
|
|
29
|
+
from .preprocessing import FlashANSRPreprocessor
|
flash_ansr/__main__.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import argparse
|
|
3
|
+
import sys
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main(argv: str = None) -> None:
|
|
8
|
+
parser = argparse.ArgumentParser(description='Neural Symbolic Regression')
|
|
9
|
+
subparsers = parser.add_subparsers(dest='command_name', required=True)
|
|
10
|
+
|
|
11
|
+
generate_skeleton_pool_parser = subparsers.add_parser("generate-skeleton-pool")
|
|
12
|
+
generate_skeleton_pool_parser.add_argument('-s', '--size', type=str, required=True, help='Size of the skeleton pool')
|
|
13
|
+
generate_skeleton_pool_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
|
|
14
|
+
generate_skeleton_pool_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the configuration file')
|
|
15
|
+
generate_skeleton_pool_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
16
|
+
generate_skeleton_pool_parser.add_argument('--output-reference', type=str, default='relative', help='Reference type for the output directory')
|
|
17
|
+
generate_skeleton_pool_parser.add_argument('--output-recursive', type=bool, default=True, help='Whether to recursively save the configuration')
|
|
18
|
+
|
|
19
|
+
import_test_data_parser = subparsers.add_parser("import-data")
|
|
20
|
+
import_test_data_parser.add_argument('-i', '--input', type=str, required=True, help='Path to the dataset file (CSV or YAML) from Biggio et al. or other benchmarks')
|
|
21
|
+
import_test_data_parser.add_argument('-b', '--base-skeleton-pool', type=str, required=True, help='Path to the base skeleton pool')
|
|
22
|
+
import_test_data_parser.add_argument('-p', '--parser', type=str, required=True, help='Name of the parser to use')
|
|
23
|
+
import_test_data_parser.add_argument('-e', '--simplipy-engine', type=str, required=True, help='Path to the expression space configuration file')
|
|
24
|
+
import_test_data_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
|
|
25
|
+
import_test_data_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
26
|
+
|
|
27
|
+
filter_skeleton_pool_parser = subparsers.add_parser("filter-skeleton-pool")
|
|
28
|
+
filter_skeleton_pool_parser.add_argument('-s', '--source', type=str, required=True, help='Path to the source skeleton pool')
|
|
29
|
+
filter_skeleton_pool_parser.add_argument('-f', '--holdouts', nargs='+', required=True, help='Paths to the holdout skeleton pools')
|
|
30
|
+
filter_skeleton_pool_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
|
|
31
|
+
filter_skeleton_pool_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
32
|
+
|
|
33
|
+
split_skeleton_pool_parser = subparsers.add_parser("split-skeleton-pool")
|
|
34
|
+
split_skeleton_pool_parser.add_argument('-i', '--input', type=str, required=True, help='Path to the input skeleton pool')
|
|
35
|
+
split_skeleton_pool_parser.add_argument('-t', '--train-size', type=float, default=0.8, help='Size of the training set')
|
|
36
|
+
split_skeleton_pool_parser.add_argument('-r', '--random-state', type=int, default=None, help='Random seed for shuffling')
|
|
37
|
+
|
|
38
|
+
train_parser = subparsers.add_parser("train")
|
|
39
|
+
train_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the configuration file')
|
|
40
|
+
train_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
41
|
+
train_parser.add_argument('-o', '--output-dir', type=str, default='.', help='Path to the output directory')
|
|
42
|
+
train_parser.add_argument('-ci', '--checkpoint-interval', type=int, default=None, help='Interval for saving checkpoints')
|
|
43
|
+
train_parser.add_argument('-vi', '--validate-interval', type=int, default=None, help='Interval for validating the model')
|
|
44
|
+
train_parser.add_argument('-w', '--num_workers', type=int, default=None, help='Number of worker processes for data generation')
|
|
45
|
+
train_parser.add_argument('--project', type=str, default='neural-symbolic-regression', help='Name of the wandb project')
|
|
46
|
+
train_parser.add_argument('--entity', type=str, default='psaegert', help='Name of the wandb entity')
|
|
47
|
+
train_parser.add_argument('--name', type=str, default=None, help='Name of the wandb run')
|
|
48
|
+
train_parser.add_argument('--mode', type=str, default='online', help='Mode for wandb logging')
|
|
49
|
+
train_parser.add_argument('--resume-from', type=str, default=None, help='Path to a checkpoint directory to resume from')
|
|
50
|
+
train_parser.add_argument('--resume-step', type=int, default=None, help='Override the inferred resume step when resuming')
|
|
51
|
+
|
|
52
|
+
evaluate_run_parser = subparsers.add_parser("evaluate-run", help="Run an evaluation from a unified config")
|
|
53
|
+
evaluate_run_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the evaluation run config file')
|
|
54
|
+
evaluate_run_parser.add_argument('-n', '--limit', type=int, default=None, help='Override the sample limit specified in the config')
|
|
55
|
+
evaluate_run_parser.add_argument('-o', '--output-file', type=str, default=None, help='Override the output file path from the config')
|
|
56
|
+
evaluate_run_parser.add_argument('--save-every', type=int, default=None, help='Override periodic save frequency')
|
|
57
|
+
evaluate_run_parser.add_argument('--no-resume', action='store_true', help='Ignore previous results even if the output file exists')
|
|
58
|
+
evaluate_run_parser.add_argument('--experiment', type=str, default=None, help='Name of the experiment defined in the config to execute')
|
|
59
|
+
evaluate_run_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
60
|
+
|
|
61
|
+
wandb_stats_parser = subparsers.add_parser("wandb-stats")
|
|
62
|
+
wandb_stats_parser.add_argument('--project', type=str, default='neural-symbolic-regression', help='Name of the wandb project')
|
|
63
|
+
wandb_stats_parser.add_argument('--entity', type=str, default='psaegert', help='Name of the wandb entity')
|
|
64
|
+
wandb_stats_parser.add_argument('-o', '--output-file', type=str, default='wandb_stats.csv', help='Path to the output file')
|
|
65
|
+
|
|
66
|
+
benchmark_parser = subparsers.add_parser("benchmark")
|
|
67
|
+
benchmark_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the dataset configuration file')
|
|
68
|
+
benchmark_parser.add_argument('-n', '--samples', type=int, default=10_000, help='Number of samples to evaluate')
|
|
69
|
+
benchmark_parser.add_argument('-b', '--batch-size', type=int, default=128, help='Batch size for the dataset')
|
|
70
|
+
benchmark_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
71
|
+
|
|
72
|
+
install_parser = subparsers.add_parser("install", help="Install a model")
|
|
73
|
+
install_parser.add_argument("model", type=str, help="Model identifier to install")
|
|
74
|
+
|
|
75
|
+
remove_parser = subparsers.add_parser("remove", help="Remove a model")
|
|
76
|
+
remove_parser.add_argument("path", type=str, help="Path to the model to remove")
|
|
77
|
+
|
|
78
|
+
find_simplifications_parser = subparsers.add_parser("find-simplifications")
|
|
79
|
+
find_simplifications_parser.add_argument('-e', '--simplipy-engine', type=str, required=True, help='Path to the expression space configuration file')
|
|
80
|
+
find_simplifications_parser.add_argument('-n', '--max_n_rules', type=int, default=None, help='Maximum number of rules to find')
|
|
81
|
+
find_simplifications_parser.add_argument('-l', '--max_pattern_length', type=int, default=7, help='Maximum length of the patterns to find')
|
|
82
|
+
find_simplifications_parser.add_argument('-t', '--timeout', type=int, default=None, help='Timeout for the search of simplifications in seconds')
|
|
83
|
+
find_simplifications_parser.add_argument('-d', '--dummy-variables', type=int, nargs='+', default=None, help='Dummy variables to use in the simplifications')
|
|
84
|
+
find_simplifications_parser.add_argument('-m', '--max-simplify-steps', type=int, default=5, help='Maximum number of simplification steps')
|
|
85
|
+
find_simplifications_parser.add_argument('-x', '--X', type=int, default=1024, help='Number of samples to use for comparison of images')
|
|
86
|
+
find_simplifications_parser.add_argument('-c', '--C', type=int, default=1024, help='Number of samples of constants to put in to placeholders')
|
|
87
|
+
find_simplifications_parser.add_argument('-r', '--constants-fit-retries', type=int, default=5, help='Number of retries for fitting the constants')
|
|
88
|
+
find_simplifications_parser.add_argument('-o', '--output-file', type=str, required=True, help='Path to the output json file')
|
|
89
|
+
find_simplifications_parser.add_argument('-s', '--save-every', type=int, default=100, help='Save the simplifications every n rules')
|
|
90
|
+
find_simplifications_parser.add_argument('--reset-rules', action='store_true', help='Reset the rules before finding new ones')
|
|
91
|
+
find_simplifications_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
92
|
+
|
|
93
|
+
# Evaluate input
|
|
94
|
+
args = parser.parse_args(argv)
|
|
95
|
+
|
|
96
|
+
# Execute the command
|
|
97
|
+
match args.command_name:
|
|
98
|
+
case 'generate-skeleton-pool':
|
|
99
|
+
if args.verbose:
|
|
100
|
+
print(f'Generating skeleton pool from {args.config}')
|
|
101
|
+
from flash_ansr.expressions import SkeletonPool
|
|
102
|
+
|
|
103
|
+
skeleton_pool = SkeletonPool.from_config(args.config)
|
|
104
|
+
skeleton_pool.create(size=int(args.size), verbose=args.verbose)
|
|
105
|
+
|
|
106
|
+
if args.verbose:
|
|
107
|
+
print(f"Saving skeleton pool to {args.output_dir}")
|
|
108
|
+
skeleton_pool.save(directory=args.output_dir, config=args.config, reference=args.output_reference, recursive=args.output_recursive)
|
|
109
|
+
|
|
110
|
+
case 'import-data':
|
|
111
|
+
if args.verbose:
|
|
112
|
+
print(f'Importing data from {args.input}')
|
|
113
|
+
from simplipy import SimpliPyEngine
|
|
114
|
+
from flash_ansr.expressions import SkeletonPool
|
|
115
|
+
from flash_ansr.compat import ParserFactory
|
|
116
|
+
from flash_ansr.utils.config_io import load_config
|
|
117
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
118
|
+
|
|
119
|
+
import pandas as pd
|
|
120
|
+
import yaml
|
|
121
|
+
from pathlib import Path
|
|
122
|
+
|
|
123
|
+
simplipy_engine = SimpliPyEngine.load(args.simplipy_engine, install=True)
|
|
124
|
+
base_skeleton_pool = SkeletonPool.from_config(args.base_skeleton_pool)
|
|
125
|
+
input_path = substitute_root_path(args.input)
|
|
126
|
+
path_obj = Path(input_path)
|
|
127
|
+
|
|
128
|
+
if path_obj.suffix.lower() in {'.yaml', '.yml'}:
|
|
129
|
+
with open(input_path, 'r', encoding='utf-8') as handle:
|
|
130
|
+
raw_data = yaml.safe_load(handle)
|
|
131
|
+
|
|
132
|
+
if not isinstance(raw_data, dict):
|
|
133
|
+
raise ValueError('Expected YAML benchmark file to contain a mapping of equation identifiers to entries.')
|
|
134
|
+
|
|
135
|
+
records = []
|
|
136
|
+
for identifier, payload in raw_data.items():
|
|
137
|
+
if not isinstance(payload, dict):
|
|
138
|
+
continue
|
|
139
|
+
|
|
140
|
+
record = {'id': identifier}
|
|
141
|
+
record.update(payload)
|
|
142
|
+
if 'prepared' in record and record['prepared'] is None:
|
|
143
|
+
# Normalise missing prepared expressions to empty strings for downstream filtering.
|
|
144
|
+
record['prepared'] = ''
|
|
145
|
+
records.append(record)
|
|
146
|
+
|
|
147
|
+
df = pd.DataFrame.from_records(records)
|
|
148
|
+
else:
|
|
149
|
+
df = pd.read_csv(input_path)
|
|
150
|
+
|
|
151
|
+
data_parser = ParserFactory.get_parser(args.parser)
|
|
152
|
+
test_skeleton_pool: SkeletonPool = data_parser.parse_data(df, simplipy_engine, base_skeleton_pool, verbose=args.verbose)
|
|
153
|
+
|
|
154
|
+
if args.verbose:
|
|
155
|
+
print(f"Saving test set to {args.output_dir}")
|
|
156
|
+
|
|
157
|
+
test_skeleton_pool.save(directory=args.output_dir, config=args.base_skeleton_pool, reference='relative', recursive=True)
|
|
158
|
+
|
|
159
|
+
case 'split-skeleton-pool':
|
|
160
|
+
print(f'Splitting skeleton pool from {args.input}')
|
|
161
|
+
import os
|
|
162
|
+
from flash_ansr.expressions import SkeletonPool
|
|
163
|
+
|
|
164
|
+
print(f"Loading skeleton pool from {args.input}")
|
|
165
|
+
|
|
166
|
+
config, skeleton_pool = SkeletonPool.load(args.input)
|
|
167
|
+
train_skeleton_pool, val_skeleton_pool = skeleton_pool.split(train_size=args.train_size, random_state=args.random_state)
|
|
168
|
+
|
|
169
|
+
train_path = os.path.join(args.input, 'train')
|
|
170
|
+
val_path = os.path.join(args.input, 'val')
|
|
171
|
+
|
|
172
|
+
train_config = deepcopy(config)
|
|
173
|
+
val_config = deepcopy(config)
|
|
174
|
+
|
|
175
|
+
print(f"Saving training pool to {train_path}")
|
|
176
|
+
print(f"Saving validation pool to {val_path}")
|
|
177
|
+
|
|
178
|
+
train_skeleton_pool.save(directory=train_path, config=train_config, reference='relative', recursive=True)
|
|
179
|
+
val_skeleton_pool.save(directory=val_path, config=val_config, reference='relative', recursive=True)
|
|
180
|
+
|
|
181
|
+
case 'train':
|
|
182
|
+
if args.verbose:
|
|
183
|
+
print(f'Training model from {args.config}')
|
|
184
|
+
from flash_ansr.train.train import Trainer
|
|
185
|
+
from flash_ansr.utils.config_io import load_config, save_config
|
|
186
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
187
|
+
|
|
188
|
+
trainer = Trainer.from_config(args.config)
|
|
189
|
+
|
|
190
|
+
config = load_config(args.config)
|
|
191
|
+
|
|
192
|
+
try:
|
|
193
|
+
trainer.run(
|
|
194
|
+
project_name=args.project,
|
|
195
|
+
entity=args.entity,
|
|
196
|
+
name=args.name,
|
|
197
|
+
steps=config['steps'],
|
|
198
|
+
preprocess=config.get('preprocess', False),
|
|
199
|
+
device=config['device'],
|
|
200
|
+
compile_mode=config.get('compile_mode'),
|
|
201
|
+
checkpoint_interval=args.checkpoint_interval,
|
|
202
|
+
checkpoint_directory=substitute_root_path(args.output_dir),
|
|
203
|
+
validate_interval=args.validate_interval,
|
|
204
|
+
validate_size=config.get('val_size', None),
|
|
205
|
+
validate_batch_size=config.get('val_batch_size', None),
|
|
206
|
+
wandb_watch_log=config.get('wandb_watch_log', None),
|
|
207
|
+
wandb_watch_log_freq=config.get('wandb_watch_log_freq', 1000),
|
|
208
|
+
wandb_mode=args.mode,
|
|
209
|
+
num_workers=args.num_workers,
|
|
210
|
+
resume_from=args.resume_from,
|
|
211
|
+
resume_step=args.resume_step,
|
|
212
|
+
verbose=args.verbose,
|
|
213
|
+
)
|
|
214
|
+
except KeyboardInterrupt:
|
|
215
|
+
print("Training interrupted. Saving model...")
|
|
216
|
+
|
|
217
|
+
trainer.model.save(directory=args.output_dir, errors='ignore')
|
|
218
|
+
|
|
219
|
+
save_config(
|
|
220
|
+
load_config(args.config, resolve_paths=True),
|
|
221
|
+
directory=substitute_root_path(args.output_dir),
|
|
222
|
+
filename='train.yaml',
|
|
223
|
+
reference='relative',
|
|
224
|
+
recursive=True,
|
|
225
|
+
resolve_paths=True)
|
|
226
|
+
|
|
227
|
+
print(f"Saved model to {args.output_dir}")
|
|
228
|
+
|
|
229
|
+
case 'evaluate-run':
|
|
230
|
+
from flash_ansr.eval.run_config import build_evaluation_run, EvaluationRunPlan
|
|
231
|
+
from flash_ansr.utils.config_io import load_config
|
|
232
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
233
|
+
|
|
234
|
+
config_path = substitute_root_path(args.config)
|
|
235
|
+
if args.verbose:
|
|
236
|
+
print(f"Running evaluation plan from {config_path}")
|
|
237
|
+
|
|
238
|
+
raw_config = load_config(config_path)
|
|
239
|
+
experiment_map = raw_config.get("experiments") if isinstance(raw_config, dict) else None
|
|
240
|
+
|
|
241
|
+
def _execute_plan(plan: EvaluationRunPlan, experiment_name: str | None = None) -> None:
|
|
242
|
+
label = f"[{experiment_name}] " if experiment_name else ""
|
|
243
|
+
if plan.completed or plan.engine is None:
|
|
244
|
+
if args.verbose:
|
|
245
|
+
target = plan.total_limit or 'configured'
|
|
246
|
+
print(f"{label}Evaluation already completed ({plan.existing_results}/{target}). Nothing to do.")
|
|
247
|
+
return
|
|
248
|
+
|
|
249
|
+
plan.engine.run(
|
|
250
|
+
limit=plan.remaining,
|
|
251
|
+
save_every=plan.save_every,
|
|
252
|
+
output_path=plan.output_path,
|
|
253
|
+
verbose=args.verbose,
|
|
254
|
+
progress=args.verbose,
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
if args.verbose:
|
|
258
|
+
total = plan.engine.result_store.size
|
|
259
|
+
destination = plan.output_path or 'memory'
|
|
260
|
+
print(f"{label}Evaluation finished with {total} samples (saved to {destination}).")
|
|
261
|
+
|
|
262
|
+
if experiment_map and args.experiment is None:
|
|
263
|
+
experiment_names = list(experiment_map.keys())
|
|
264
|
+
if args.verbose:
|
|
265
|
+
count = len(experiment_names)
|
|
266
|
+
print(f"No --experiment provided; running all {count} experiments defined in config.")
|
|
267
|
+
for experiment_name in experiment_names:
|
|
268
|
+
if args.verbose:
|
|
269
|
+
print(f"--> {experiment_name}")
|
|
270
|
+
plan = build_evaluation_run(
|
|
271
|
+
config=config_path,
|
|
272
|
+
limit_override=args.limit,
|
|
273
|
+
output_override=args.output_file,
|
|
274
|
+
save_every_override=args.save_every,
|
|
275
|
+
resume=None if not args.no_resume else False,
|
|
276
|
+
experiment=experiment_name,
|
|
277
|
+
)
|
|
278
|
+
_execute_plan(plan, experiment_name)
|
|
279
|
+
else:
|
|
280
|
+
plan = build_evaluation_run(
|
|
281
|
+
config=config_path,
|
|
282
|
+
limit_override=args.limit,
|
|
283
|
+
output_override=args.output_file,
|
|
284
|
+
save_every_override=args.save_every,
|
|
285
|
+
resume=None if not args.no_resume else False,
|
|
286
|
+
experiment=args.experiment,
|
|
287
|
+
)
|
|
288
|
+
_execute_plan(plan, args.experiment)
|
|
289
|
+
|
|
290
|
+
case 'wandb-stats':
|
|
291
|
+
print(f'Fetching stats from wandb project {args.project} and entity {args.entity}')
|
|
292
|
+
import os
|
|
293
|
+
import wandb
|
|
294
|
+
import pandas as pd
|
|
295
|
+
|
|
296
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
297
|
+
|
|
298
|
+
api = wandb.Api() # type: ignore
|
|
299
|
+
|
|
300
|
+
runs = api.runs(f'{args.entity}/{args.project}')
|
|
301
|
+
runs = {run.id: {'run': run} for run in runs}
|
|
302
|
+
|
|
303
|
+
for key, value in runs.items():
|
|
304
|
+
start_time = datetime.datetime.strptime(value['run'].created_at, '%Y-%m-%dT%H:%M:%S') + datetime.timedelta(hours=2) # HACK: This is a hack to convert to CET
|
|
305
|
+
end_time = datetime.datetime.strptime(value['run'].heartbeatAt, '%Y-%m-%dT%H:%M:%S') + datetime.timedelta(hours=2)
|
|
306
|
+
runs[key]['start_time'] = start_time
|
|
307
|
+
runs[key]['end_time'] = end_time
|
|
308
|
+
runs[key]['duration'] = end_time - start_time
|
|
309
|
+
runs[key]['name'] = value['run'].name
|
|
310
|
+
|
|
311
|
+
df = pd.DataFrame.from_dict(runs, orient='index').drop(columns=['run'])
|
|
312
|
+
|
|
313
|
+
save_path = substitute_root_path(args.output_file)
|
|
314
|
+
if save_path:
|
|
315
|
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
316
|
+
df.to_csv(save_path)
|
|
317
|
+
|
|
318
|
+
case 'benchmark':
|
|
319
|
+
if args.verbose:
|
|
320
|
+
print(f'Benchmarking dataset {args.config}')
|
|
321
|
+
from flash_ansr.data import FlashANSRDataset
|
|
322
|
+
from flash_ansr.utils.config_io import load_config, save_config
|
|
323
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
324
|
+
import pandas as pd
|
|
325
|
+
|
|
326
|
+
dataset = FlashANSRDataset.from_config(substitute_root_path(args.config))
|
|
327
|
+
|
|
328
|
+
results = dataset._benchmark(n_samples=args.samples, batch_size=args.batch_size, verbose=args.verbose)
|
|
329
|
+
|
|
330
|
+
print(f'Iteration time: {1e3 * results["mean_iteration_time"]:.0f} ± {1e3 * results["std_iteration_time"]:.0f} ms')
|
|
331
|
+
print(f'Range: {1e3 * results["min_iteration_time"]:.0f} - {1e3 * results["max_iteration_time"]:.0f} ms')
|
|
332
|
+
|
|
333
|
+
case 'install':
|
|
334
|
+
from flash_ansr.model.manage import install_model
|
|
335
|
+
install_model(args.model)
|
|
336
|
+
|
|
337
|
+
case 'remove':
|
|
338
|
+
from flash_ansr.model.manage import remove_model
|
|
339
|
+
remove_model(args.path)
|
|
340
|
+
|
|
341
|
+
case _:
|
|
342
|
+
parser.print_help()
|
|
343
|
+
sys.exit(1)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
if __name__ == "__main__":
    # Executed for `python -m flash_ansr`, which runs this __main__ module as a script.
    main()
|