flash-ansr 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. flash_ansr/__init__.py +29 -0
  2. flash_ansr/__main__.py +347 -0
  3. flash_ansr/baselines/__init__.py +7 -0
  4. flash_ansr/baselines/brute_force_model.py +306 -0
  5. flash_ansr/baselines/skeleton_pool_model.py +361 -0
  6. flash_ansr/benchmarks/__init__.py +5 -0
  7. flash_ansr/benchmarks/fastsrb.py +519 -0
  8. flash_ansr/compat/__init__.py +1 -0
  9. flash_ansr/compat/convert_data.py +349 -0
  10. flash_ansr/compat/evaluation_nesymres.py +85 -0
  11. flash_ansr/compat/evaluation_pysr.py +95 -0
  12. flash_ansr/compat/nesymres.py +74 -0
  13. flash_ansr/data/__init__.py +3 -0
  14. flash_ansr/data/collate.py +197 -0
  15. flash_ansr/data/data.py +514 -0
  16. flash_ansr/data/streaming.py +390 -0
  17. flash_ansr/decoding/mcts.py +552 -0
  18. flash_ansr/eval/__init__.py +2 -0
  19. flash_ansr/eval/core.py +109 -0
  20. flash_ansr/eval/data_sources.py +963 -0
  21. flash_ansr/eval/engine.py +250 -0
  22. flash_ansr/eval/evaluation.py +156 -0
  23. flash_ansr/eval/evaluation_fastsrb.py +287 -0
  24. flash_ansr/eval/metrics/__init__.py +10 -0
  25. flash_ansr/eval/metrics/bootstrap.py +31 -0
  26. flash_ansr/eval/metrics/token_prediction.py +464 -0
  27. flash_ansr/eval/metrics/zss.py +42 -0
  28. flash_ansr/eval/model_adapters.py +885 -0
  29. flash_ansr/eval/result_store.py +117 -0
  30. flash_ansr/eval/run_config.py +679 -0
  31. flash_ansr/eval/sample_metadata.py +83 -0
  32. flash_ansr/expressions/__init__.py +1 -0
  33. flash_ansr/expressions/compilation.py +37 -0
  34. flash_ansr/expressions/distributions.py +147 -0
  35. flash_ansr/expressions/holdout.py +85 -0
  36. flash_ansr/expressions/normalization.py +73 -0
  37. flash_ansr/expressions/prior_factory.py +31 -0
  38. flash_ansr/expressions/skeleton_pool.py +755 -0
  39. flash_ansr/expressions/skeleton_sampling.py +129 -0
  40. flash_ansr/expressions/structure.py +32 -0
  41. flash_ansr/expressions/support_sampling.py +457 -0
  42. flash_ansr/expressions/token_ops.py +128 -0
  43. flash_ansr/flash_ansr.py +1018 -0
  44. flash_ansr/generation/__init__.py +3 -0
  45. flash_ansr/generation/beam.py +47 -0
  46. flash_ansr/generation/mcts.py +127 -0
  47. flash_ansr/model/__init__.py +24 -0
  48. flash_ansr/model/common/__init__.py +9 -0
  49. flash_ansr/model/common/components.py +118 -0
  50. flash_ansr/model/decoders/__init__.py +10 -0
  51. flash_ansr/model/decoders/components.py +239 -0
  52. flash_ansr/model/decoders/transformer.py +84 -0
  53. flash_ansr/model/encoders/__init__.py +9 -0
  54. flash_ansr/model/encoders/base.py +78 -0
  55. flash_ansr/model/encoders/set_transformer.py +454 -0
  56. flash_ansr/model/factory.py +41 -0
  57. flash_ansr/model/flash_ansr_model.py +835 -0
  58. flash_ansr/model/manage.py +37 -0
  59. flash_ansr/model/pre_encoder.py +29 -0
  60. flash_ansr/model/tokenizer.py +279 -0
  61. flash_ansr/preprocessing/__init__.py +14 -0
  62. flash_ansr/preprocessing/feature_extractor.py +834 -0
  63. flash_ansr/preprocessing/pipeline.py +245 -0
  64. flash_ansr/preprocessing/prompt_serialization.py +240 -0
  65. flash_ansr/preprocessing/schemas.py +23 -0
  66. flash_ansr/refine.py +531 -0
  67. flash_ansr/results.py +142 -0
  68. flash_ansr/train/__init__.py +3 -0
  69. flash_ansr/train/optimizers.py +14 -0
  70. flash_ansr/train/schedules.py +21 -0
  71. flash_ansr/train/train.py +903 -0
  72. flash_ansr/utils/__init__.py +22 -0
  73. flash_ansr/utils/config_io.py +155 -0
  74. flash_ansr/utils/generation.py +262 -0
  75. flash_ansr/utils/numeric.py +105 -0
  76. flash_ansr/utils/paths.py +39 -0
  77. flash_ansr/utils/tensor_ops.py +59 -0
  78. flash_ansr-0.4.2.dist-info/METADATA +152 -0
  79. flash_ansr-0.4.2.dist-info/RECORD +83 -0
  80. flash_ansr-0.4.2.dist-info/WHEEL +5 -0
  81. flash_ansr-0.4.2.dist-info/entry_points.txt +2 -0
  82. flash_ansr-0.4.2.dist-info/licenses/LICENSE +21 -0
  83. flash_ansr-0.4.2.dist-info/top_level.txt +1 -0
flash_ansr/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from .model import (
2
+ ModelFactory,
3
+ FlashANSRModel,
4
+ SetTransformer,
5
+ Tokenizer,
6
+ RotaryEmbedding,
7
+ IEEE75432PreEncoder,
8
+ install_model,
9
+ remove_model,
10
+ )
11
+ from .expressions import SkeletonPool, NoValidSampleFoundError
12
+ from .utils import (
13
+ GenerationConfig,
14
+ GenerationConfigBase,
15
+ BeamSearchConfig,
16
+ SoftmaxSamplingConfig,
17
+ MCTSGenerationConfig,
18
+ create_generation_config,
19
+ get_path,
20
+ load_config,
21
+ save_config,
22
+ substitute_root_path,
23
+ )
24
+ from .eval import Evaluation
25
+ from .refine import Refiner, ConvergenceError
26
+ from .flash_ansr import FlashANSR
27
+ from .baselines import SkeletonPoolModel, BruteForceModel
28
+ from .data.data import FlashANSRDataset
29
+ from .preprocessing import FlashANSRPreprocessor
flash_ansr/__main__.py ADDED
@@ -0,0 +1,347 @@
1
+ import datetime
2
+ import argparse
3
+ import sys
4
+ from copy import deepcopy
5
+
6
+
7
+ def main(argv: str = None) -> None:
8
+ parser = argparse.ArgumentParser(description='Neural Symbolic Regression')
9
+ subparsers = parser.add_subparsers(dest='command_name', required=True)
10
+
11
+ generate_skeleton_pool_parser = subparsers.add_parser("generate-skeleton-pool")
12
+ generate_skeleton_pool_parser.add_argument('-s', '--size', type=str, required=True, help='Size of the skeleton pool')
13
+ generate_skeleton_pool_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
14
+ generate_skeleton_pool_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the configuration file')
15
+ generate_skeleton_pool_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
16
+ generate_skeleton_pool_parser.add_argument('--output-reference', type=str, default='relative', help='Reference type for the output directory')
17
+ generate_skeleton_pool_parser.add_argument('--output-recursive', type=bool, default=True, help='Whether to recursively save the configuration')
18
+
19
+ import_test_data_parser = subparsers.add_parser("import-data")
20
+ import_test_data_parser.add_argument('-i', '--input', type=str, required=True, help='Path to the dataset file (CSV or YAML) from Biggio et al. or other benchmarks')
21
+ import_test_data_parser.add_argument('-b', '--base-skeleton-pool', type=str, required=True, help='Path to the base skeleton pool')
22
+ import_test_data_parser.add_argument('-p', '--parser', type=str, required=True, help='Name of the parser to use')
23
+ import_test_data_parser.add_argument('-e', '--simplipy-engine', type=str, required=True, help='Path to the expression space configuration file')
24
+ import_test_data_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
25
+ import_test_data_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
26
+
27
+ filter_skeleton_pool_parser = subparsers.add_parser("filter-skeleton-pool")
28
+ filter_skeleton_pool_parser.add_argument('-s', '--source', type=str, required=True, help='Path to the source skeleton pool')
29
+ filter_skeleton_pool_parser.add_argument('-f', '--holdouts', nargs='+', required=True, help='Paths to the holdout skeleton pools')
30
+ filter_skeleton_pool_parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to the output directory')
31
+ filter_skeleton_pool_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
32
+
33
+ split_skeleton_pool_parser = subparsers.add_parser("split-skeleton-pool")
34
+ split_skeleton_pool_parser.add_argument('-i', '--input', type=str, required=True, help='Path to the input skeleton pool')
35
+ split_skeleton_pool_parser.add_argument('-t', '--train-size', type=float, default=0.8, help='Size of the training set')
36
+ split_skeleton_pool_parser.add_argument('-r', '--random-state', type=int, default=None, help='Random seed for shuffling')
37
+
38
+ train_parser = subparsers.add_parser("train")
39
+ train_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the configuration file')
40
+ train_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
41
+ train_parser.add_argument('-o', '--output-dir', type=str, default='.', help='Path to the output directory')
42
+ train_parser.add_argument('-ci', '--checkpoint-interval', type=int, default=None, help='Interval for saving checkpoints')
43
+ train_parser.add_argument('-vi', '--validate-interval', type=int, default=None, help='Interval for validating the model')
44
+ train_parser.add_argument('-w', '--num_workers', type=int, default=None, help='Number of worker processes for data generation')
45
+ train_parser.add_argument('--project', type=str, default='neural-symbolic-regression', help='Name of the wandb project')
46
+ train_parser.add_argument('--entity', type=str, default='psaegert', help='Name of the wandb entity')
47
+ train_parser.add_argument('--name', type=str, default=None, help='Name of the wandb run')
48
+ train_parser.add_argument('--mode', type=str, default='online', help='Mode for wandb logging')
49
+ train_parser.add_argument('--resume-from', type=str, default=None, help='Path to a checkpoint directory to resume from')
50
+ train_parser.add_argument('--resume-step', type=int, default=None, help='Override the inferred resume step when resuming')
51
+
52
+ evaluate_run_parser = subparsers.add_parser("evaluate-run", help="Run an evaluation from a unified config")
53
+ evaluate_run_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the evaluation run config file')
54
+ evaluate_run_parser.add_argument('-n', '--limit', type=int, default=None, help='Override the sample limit specified in the config')
55
+ evaluate_run_parser.add_argument('-o', '--output-file', type=str, default=None, help='Override the output file path from the config')
56
+ evaluate_run_parser.add_argument('--save-every', type=int, default=None, help='Override periodic save frequency')
57
+ evaluate_run_parser.add_argument('--no-resume', action='store_true', help='Ignore previous results even if the output file exists')
58
+ evaluate_run_parser.add_argument('--experiment', type=str, default=None, help='Name of the experiment defined in the config to execute')
59
+ evaluate_run_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
60
+
61
+ wandb_stats_parser = subparsers.add_parser("wandb-stats")
62
+ wandb_stats_parser.add_argument('--project', type=str, default='neural-symbolic-regression', help='Name of the wandb project')
63
+ wandb_stats_parser.add_argument('--entity', type=str, default='psaegert', help='Name of the wandb entity')
64
+ wandb_stats_parser.add_argument('-o', '--output-file', type=str, default='wandb_stats.csv', help='Path to the output file')
65
+
66
+ benchmark_parser = subparsers.add_parser("benchmark")
67
+ benchmark_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the dataset configuration file')
68
+ benchmark_parser.add_argument('-n', '--samples', type=int, default=10_000, help='Number of samples to evaluate')
69
+ benchmark_parser.add_argument('-b', '--batch-size', type=int, default=128, help='Batch size for the dataset')
70
+ benchmark_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
71
+
72
+ install_parser = subparsers.add_parser("install", help="Install a model")
73
+ install_parser.add_argument("model", type=str, help="Model identifier to install")
74
+
75
+ remove_parser = subparsers.add_parser("remove", help="Remove a model")
76
+ remove_parser.add_argument("path", type=str, help="Path to the model to remove")
77
+
78
+ find_simplifications_parser = subparsers.add_parser("find-simplifications")
79
+ find_simplifications_parser.add_argument('-e', '--simplipy-engine', type=str, required=True, help='Path to the expression space configuration file')
80
+ find_simplifications_parser.add_argument('-n', '--max_n_rules', type=int, default=None, help='Maximum number of rules to find')
81
+ find_simplifications_parser.add_argument('-l', '--max_pattern_length', type=int, default=7, help='Maximum length of the patterns to find')
82
+ find_simplifications_parser.add_argument('-t', '--timeout', type=int, default=None, help='Timeout for the search of simplifications in seconds')
83
+ find_simplifications_parser.add_argument('-d', '--dummy-variables', type=int, nargs='+', default=None, help='Dummy variables to use in the simplifications')
84
+ find_simplifications_parser.add_argument('-m', '--max-simplify-steps', type=int, default=5, help='Maximum number of simplification steps')
85
+ find_simplifications_parser.add_argument('-x', '--X', type=int, default=1024, help='Number of samples to use for comparison of images')
86
+ find_simplifications_parser.add_argument('-c', '--C', type=int, default=1024, help='Number of samples of constants to put in to placeholders')
87
+ find_simplifications_parser.add_argument('-r', '--constants-fit-retries', type=int, default=5, help='Number of retries for fitting the constants')
88
+ find_simplifications_parser.add_argument('-o', '--output-file', type=str, required=True, help='Path to the output json file')
89
+ find_simplifications_parser.add_argument('-s', '--save-every', type=int, default=100, help='Save the simplifications every n rules')
90
+ find_simplifications_parser.add_argument('--reset-rules', action='store_true', help='Reset the rules before finding new ones')
91
+ find_simplifications_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
92
+
93
+ # Evaluate input
94
+ args = parser.parse_args(argv)
95
+
96
+ # Execute the command
97
+ match args.command_name:
98
+ case 'generate-skeleton-pool':
99
+ if args.verbose:
100
+ print(f'Generating skeleton pool from {args.config}')
101
+ from flash_ansr.expressions import SkeletonPool
102
+
103
+ skeleton_pool = SkeletonPool.from_config(args.config)
104
+ skeleton_pool.create(size=int(args.size), verbose=args.verbose)
105
+
106
+ if args.verbose:
107
+ print(f"Saving skeleton pool to {args.output_dir}")
108
+ skeleton_pool.save(directory=args.output_dir, config=args.config, reference=args.output_reference, recursive=args.output_recursive)
109
+
110
+ case 'import-data':
111
+ if args.verbose:
112
+ print(f'Importing data from {args.input}')
113
+ from simplipy import SimpliPyEngine
114
+ from flash_ansr.expressions import SkeletonPool
115
+ from flash_ansr.compat import ParserFactory
116
+ from flash_ansr.utils.config_io import load_config
117
+ from flash_ansr.utils.paths import substitute_root_path
118
+
119
+ import pandas as pd
120
+ import yaml
121
+ from pathlib import Path
122
+
123
+ simplipy_engine = SimpliPyEngine.load(args.simplipy_engine, install=True)
124
+ base_skeleton_pool = SkeletonPool.from_config(args.base_skeleton_pool)
125
+ input_path = substitute_root_path(args.input)
126
+ path_obj = Path(input_path)
127
+
128
+ if path_obj.suffix.lower() in {'.yaml', '.yml'}:
129
+ with open(input_path, 'r', encoding='utf-8') as handle:
130
+ raw_data = yaml.safe_load(handle)
131
+
132
+ if not isinstance(raw_data, dict):
133
+ raise ValueError('Expected YAML benchmark file to contain a mapping of equation identifiers to entries.')
134
+
135
+ records = []
136
+ for identifier, payload in raw_data.items():
137
+ if not isinstance(payload, dict):
138
+ continue
139
+
140
+ record = {'id': identifier}
141
+ record.update(payload)
142
+ if 'prepared' in record and record['prepared'] is None:
143
+ # Normalise missing prepared expressions to empty strings for downstream filtering.
144
+ record['prepared'] = ''
145
+ records.append(record)
146
+
147
+ df = pd.DataFrame.from_records(records)
148
+ else:
149
+ df = pd.read_csv(input_path)
150
+
151
+ data_parser = ParserFactory.get_parser(args.parser)
152
+ test_skeleton_pool: SkeletonPool = data_parser.parse_data(df, simplipy_engine, base_skeleton_pool, verbose=args.verbose)
153
+
154
+ if args.verbose:
155
+ print(f"Saving test set to {args.output_dir}")
156
+
157
+ test_skeleton_pool.save(directory=args.output_dir, config=args.base_skeleton_pool, reference='relative', recursive=True)
158
+
159
+ case 'split-skeleton-pool':
160
+ print(f'Splitting skeleton pool from {args.input}')
161
+ import os
162
+ from flash_ansr.expressions import SkeletonPool
163
+
164
+ print(f"Loading skeleton pool from {args.input}")
165
+
166
+ config, skeleton_pool = SkeletonPool.load(args.input)
167
+ train_skeleton_pool, val_skeleton_pool = skeleton_pool.split(train_size=args.train_size, random_state=args.random_state)
168
+
169
+ train_path = os.path.join(args.input, 'train')
170
+ val_path = os.path.join(args.input, 'val')
171
+
172
+ train_config = deepcopy(config)
173
+ val_config = deepcopy(config)
174
+
175
+ print(f"Saving training pool to {train_path}")
176
+ print(f"Saving validation pool to {val_path}")
177
+
178
+ train_skeleton_pool.save(directory=train_path, config=train_config, reference='relative', recursive=True)
179
+ val_skeleton_pool.save(directory=val_path, config=val_config, reference='relative', recursive=True)
180
+
181
+ case 'train':
182
+ if args.verbose:
183
+ print(f'Training model from {args.config}')
184
+ from flash_ansr.train.train import Trainer
185
+ from flash_ansr.utils.config_io import load_config, save_config
186
+ from flash_ansr.utils.paths import substitute_root_path
187
+
188
+ trainer = Trainer.from_config(args.config)
189
+
190
+ config = load_config(args.config)
191
+
192
+ try:
193
+ trainer.run(
194
+ project_name=args.project,
195
+ entity=args.entity,
196
+ name=args.name,
197
+ steps=config['steps'],
198
+ preprocess=config.get('preprocess', False),
199
+ device=config['device'],
200
+ compile_mode=config.get('compile_mode'),
201
+ checkpoint_interval=args.checkpoint_interval,
202
+ checkpoint_directory=substitute_root_path(args.output_dir),
203
+ validate_interval=args.validate_interval,
204
+ validate_size=config.get('val_size', None),
205
+ validate_batch_size=config.get('val_batch_size', None),
206
+ wandb_watch_log=config.get('wandb_watch_log', None),
207
+ wandb_watch_log_freq=config.get('wandb_watch_log_freq', 1000),
208
+ wandb_mode=args.mode,
209
+ num_workers=args.num_workers,
210
+ resume_from=args.resume_from,
211
+ resume_step=args.resume_step,
212
+ verbose=args.verbose,
213
+ )
214
+ except KeyboardInterrupt:
215
+ print("Training interrupted. Saving model...")
216
+
217
+ trainer.model.save(directory=args.output_dir, errors='ignore')
218
+
219
+ save_config(
220
+ load_config(args.config, resolve_paths=True),
221
+ directory=substitute_root_path(args.output_dir),
222
+ filename='train.yaml',
223
+ reference='relative',
224
+ recursive=True,
225
+ resolve_paths=True)
226
+
227
+ print(f"Saved model to {args.output_dir}")
228
+
229
+ case 'evaluate-run':
230
+ from flash_ansr.eval.run_config import build_evaluation_run, EvaluationRunPlan
231
+ from flash_ansr.utils.config_io import load_config
232
+ from flash_ansr.utils.paths import substitute_root_path
233
+
234
+ config_path = substitute_root_path(args.config)
235
+ if args.verbose:
236
+ print(f"Running evaluation plan from {config_path}")
237
+
238
+ raw_config = load_config(config_path)
239
+ experiment_map = raw_config.get("experiments") if isinstance(raw_config, dict) else None
240
+
241
+ def _execute_plan(plan: EvaluationRunPlan, experiment_name: str | None = None) -> None:
242
+ label = f"[{experiment_name}] " if experiment_name else ""
243
+ if plan.completed or plan.engine is None:
244
+ if args.verbose:
245
+ target = plan.total_limit or 'configured'
246
+ print(f"{label}Evaluation already completed ({plan.existing_results}/{target}). Nothing to do.")
247
+ return
248
+
249
+ plan.engine.run(
250
+ limit=plan.remaining,
251
+ save_every=plan.save_every,
252
+ output_path=plan.output_path,
253
+ verbose=args.verbose,
254
+ progress=args.verbose,
255
+ )
256
+
257
+ if args.verbose:
258
+ total = plan.engine.result_store.size
259
+ destination = plan.output_path or 'memory'
260
+ print(f"{label}Evaluation finished with {total} samples (saved to {destination}).")
261
+
262
+ if experiment_map and args.experiment is None:
263
+ experiment_names = list(experiment_map.keys())
264
+ if args.verbose:
265
+ count = len(experiment_names)
266
+ print(f"No --experiment provided; running all {count} experiments defined in config.")
267
+ for experiment_name in experiment_names:
268
+ if args.verbose:
269
+ print(f"--> {experiment_name}")
270
+ plan = build_evaluation_run(
271
+ config=config_path,
272
+ limit_override=args.limit,
273
+ output_override=args.output_file,
274
+ save_every_override=args.save_every,
275
+ resume=None if not args.no_resume else False,
276
+ experiment=experiment_name,
277
+ )
278
+ _execute_plan(plan, experiment_name)
279
+ else:
280
+ plan = build_evaluation_run(
281
+ config=config_path,
282
+ limit_override=args.limit,
283
+ output_override=args.output_file,
284
+ save_every_override=args.save_every,
285
+ resume=None if not args.no_resume else False,
286
+ experiment=args.experiment,
287
+ )
288
+ _execute_plan(plan, args.experiment)
289
+
290
+ case 'wandb-stats':
291
+ print(f'Fetching stats from wandb project {args.project} and entity {args.entity}')
292
+ import os
293
+ import wandb
294
+ import pandas as pd
295
+
296
+ from flash_ansr.utils.paths import substitute_root_path
297
+
298
+ api = wandb.Api() # type: ignore
299
+
300
+ runs = api.runs(f'{args.entity}/{args.project}')
301
+ runs = {run.id: {'run': run} for run in runs}
302
+
303
+ for key, value in runs.items():
304
+ start_time = datetime.datetime.strptime(value['run'].created_at, '%Y-%m-%dT%H:%M:%S') + datetime.timedelta(hours=2) # HACK: This is a hack to convert to CET
305
+ end_time = datetime.datetime.strptime(value['run'].heartbeatAt, '%Y-%m-%dT%H:%M:%S') + datetime.timedelta(hours=2)
306
+ runs[key]['start_time'] = start_time
307
+ runs[key]['end_time'] = end_time
308
+ runs[key]['duration'] = end_time - start_time
309
+ runs[key]['name'] = value['run'].name
310
+
311
+ df = pd.DataFrame.from_dict(runs, orient='index').drop(columns=['run'])
312
+
313
+ save_path = substitute_root_path(args.output_file)
314
+ if save_path:
315
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
316
+ df.to_csv(save_path)
317
+
318
+ case 'benchmark':
319
+ if args.verbose:
320
+ print(f'Benchmarking dataset {args.config}')
321
+ from flash_ansr.data import FlashANSRDataset
322
+ from flash_ansr.utils.config_io import load_config, save_config
323
+ from flash_ansr.utils.paths import substitute_root_path
324
+ import pandas as pd
325
+
326
+ dataset = FlashANSRDataset.from_config(substitute_root_path(args.config))
327
+
328
+ results = dataset._benchmark(n_samples=args.samples, batch_size=args.batch_size, verbose=args.verbose)
329
+
330
+ print(f'Iteration time: {1e3 * results["mean_iteration_time"]:.0f} ± {1e3 * results["std_iteration_time"]:.0f} ms')
331
+ print(f'Range: {1e3 * results["min_iteration_time"]:.0f} - {1e3 * results["max_iteration_time"]:.0f} ms')
332
+
333
+ case 'install':
334
+ from flash_ansr.model.manage import install_model
335
+ install_model(args.model)
336
+
337
+ case 'remove':
338
+ from flash_ansr.model.manage import remove_model
339
+ remove_model(args.path)
340
+
341
+ case _:
342
+ parser.print_help()
343
+ sys.exit(1)
344
+
345
+
346
+ if __name__ == '__main__':
347
+ main()
flash_ansr/baselines/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .skeleton_pool_model import SkeletonPoolModel
2
+ from .brute_force_model import BruteForceModel
3
+
4
+ __all__ = [
5
+ "SkeletonPoolModel",
6
+ "BruteForceModel",
7
+ ]