flash-ansr 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. flash_ansr-0.4.2/LICENSE +21 -0
  2. flash_ansr-0.4.2/PKG-INFO +152 -0
  3. flash_ansr-0.4.2/README.md +104 -0
  4. flash_ansr-0.4.2/pyproject.toml +88 -0
  5. flash_ansr-0.4.2/requirements.txt +17 -0
  6. flash_ansr-0.4.2/setup.cfg +4 -0
  7. flash_ansr-0.4.2/src/flash_ansr/__init__.py +29 -0
  8. flash_ansr-0.4.2/src/flash_ansr/__main__.py +347 -0
  9. flash_ansr-0.4.2/src/flash_ansr/baselines/__init__.py +7 -0
  10. flash_ansr-0.4.2/src/flash_ansr/baselines/brute_force_model.py +306 -0
  11. flash_ansr-0.4.2/src/flash_ansr/baselines/skeleton_pool_model.py +361 -0
  12. flash_ansr-0.4.2/src/flash_ansr/benchmarks/__init__.py +5 -0
  13. flash_ansr-0.4.2/src/flash_ansr/benchmarks/fastsrb.py +519 -0
  14. flash_ansr-0.4.2/src/flash_ansr/compat/__init__.py +1 -0
  15. flash_ansr-0.4.2/src/flash_ansr/compat/convert_data.py +349 -0
  16. flash_ansr-0.4.2/src/flash_ansr/compat/evaluation_nesymres.py +85 -0
  17. flash_ansr-0.4.2/src/flash_ansr/compat/evaluation_pysr.py +95 -0
  18. flash_ansr-0.4.2/src/flash_ansr/compat/nesymres.py +74 -0
  19. flash_ansr-0.4.2/src/flash_ansr/data/__init__.py +3 -0
  20. flash_ansr-0.4.2/src/flash_ansr/data/collate.py +197 -0
  21. flash_ansr-0.4.2/src/flash_ansr/data/data.py +514 -0
  22. flash_ansr-0.4.2/src/flash_ansr/data/streaming.py +390 -0
  23. flash_ansr-0.4.2/src/flash_ansr/decoding/mcts.py +552 -0
  24. flash_ansr-0.4.2/src/flash_ansr/eval/__init__.py +2 -0
  25. flash_ansr-0.4.2/src/flash_ansr/eval/core.py +109 -0
  26. flash_ansr-0.4.2/src/flash_ansr/eval/data_sources.py +963 -0
  27. flash_ansr-0.4.2/src/flash_ansr/eval/engine.py +250 -0
  28. flash_ansr-0.4.2/src/flash_ansr/eval/evaluation.py +156 -0
  29. flash_ansr-0.4.2/src/flash_ansr/eval/evaluation_fastsrb.py +287 -0
  30. flash_ansr-0.4.2/src/flash_ansr/eval/metrics/__init__.py +10 -0
  31. flash_ansr-0.4.2/src/flash_ansr/eval/metrics/bootstrap.py +31 -0
  32. flash_ansr-0.4.2/src/flash_ansr/eval/metrics/token_prediction.py +464 -0
  33. flash_ansr-0.4.2/src/flash_ansr/eval/metrics/zss.py +42 -0
  34. flash_ansr-0.4.2/src/flash_ansr/eval/model_adapters.py +885 -0
  35. flash_ansr-0.4.2/src/flash_ansr/eval/result_store.py +117 -0
  36. flash_ansr-0.4.2/src/flash_ansr/eval/run_config.py +679 -0
  37. flash_ansr-0.4.2/src/flash_ansr/eval/sample_metadata.py +83 -0
  38. flash_ansr-0.4.2/src/flash_ansr/expressions/__init__.py +1 -0
  39. flash_ansr-0.4.2/src/flash_ansr/expressions/compilation.py +37 -0
  40. flash_ansr-0.4.2/src/flash_ansr/expressions/distributions.py +147 -0
  41. flash_ansr-0.4.2/src/flash_ansr/expressions/holdout.py +85 -0
  42. flash_ansr-0.4.2/src/flash_ansr/expressions/normalization.py +73 -0
  43. flash_ansr-0.4.2/src/flash_ansr/expressions/prior_factory.py +31 -0
  44. flash_ansr-0.4.2/src/flash_ansr/expressions/skeleton_pool.py +755 -0
  45. flash_ansr-0.4.2/src/flash_ansr/expressions/skeleton_sampling.py +129 -0
  46. flash_ansr-0.4.2/src/flash_ansr/expressions/structure.py +32 -0
  47. flash_ansr-0.4.2/src/flash_ansr/expressions/support_sampling.py +457 -0
  48. flash_ansr-0.4.2/src/flash_ansr/expressions/token_ops.py +128 -0
  49. flash_ansr-0.4.2/src/flash_ansr/flash_ansr.py +1018 -0
  50. flash_ansr-0.4.2/src/flash_ansr/generation/__init__.py +3 -0
  51. flash_ansr-0.4.2/src/flash_ansr/generation/beam.py +47 -0
  52. flash_ansr-0.4.2/src/flash_ansr/generation/mcts.py +127 -0
  53. flash_ansr-0.4.2/src/flash_ansr/model/__init__.py +24 -0
  54. flash_ansr-0.4.2/src/flash_ansr/model/common/__init__.py +9 -0
  55. flash_ansr-0.4.2/src/flash_ansr/model/common/components.py +118 -0
  56. flash_ansr-0.4.2/src/flash_ansr/model/decoders/__init__.py +10 -0
  57. flash_ansr-0.4.2/src/flash_ansr/model/decoders/components.py +239 -0
  58. flash_ansr-0.4.2/src/flash_ansr/model/decoders/transformer.py +84 -0
  59. flash_ansr-0.4.2/src/flash_ansr/model/encoders/__init__.py +9 -0
  60. flash_ansr-0.4.2/src/flash_ansr/model/encoders/base.py +78 -0
  61. flash_ansr-0.4.2/src/flash_ansr/model/encoders/set_transformer.py +454 -0
  62. flash_ansr-0.4.2/src/flash_ansr/model/factory.py +41 -0
  63. flash_ansr-0.4.2/src/flash_ansr/model/flash_ansr_model.py +835 -0
  64. flash_ansr-0.4.2/src/flash_ansr/model/manage.py +37 -0
  65. flash_ansr-0.4.2/src/flash_ansr/model/pre_encoder.py +29 -0
  66. flash_ansr-0.4.2/src/flash_ansr/model/tokenizer.py +279 -0
  67. flash_ansr-0.4.2/src/flash_ansr/preprocessing/__init__.py +14 -0
  68. flash_ansr-0.4.2/src/flash_ansr/preprocessing/feature_extractor.py +834 -0
  69. flash_ansr-0.4.2/src/flash_ansr/preprocessing/pipeline.py +245 -0
  70. flash_ansr-0.4.2/src/flash_ansr/preprocessing/prompt_serialization.py +240 -0
  71. flash_ansr-0.4.2/src/flash_ansr/preprocessing/schemas.py +23 -0
  72. flash_ansr-0.4.2/src/flash_ansr/refine.py +531 -0
  73. flash_ansr-0.4.2/src/flash_ansr/results.py +142 -0
  74. flash_ansr-0.4.2/src/flash_ansr/train/__init__.py +3 -0
  75. flash_ansr-0.4.2/src/flash_ansr/train/optimizers.py +14 -0
  76. flash_ansr-0.4.2/src/flash_ansr/train/schedules.py +21 -0
  77. flash_ansr-0.4.2/src/flash_ansr/train/train.py +903 -0
  78. flash_ansr-0.4.2/src/flash_ansr/utils/__init__.py +22 -0
  79. flash_ansr-0.4.2/src/flash_ansr/utils/config_io.py +155 -0
  80. flash_ansr-0.4.2/src/flash_ansr/utils/generation.py +262 -0
  81. flash_ansr-0.4.2/src/flash_ansr/utils/numeric.py +105 -0
  82. flash_ansr-0.4.2/src/flash_ansr/utils/paths.py +39 -0
  83. flash_ansr-0.4.2/src/flash_ansr/utils/tensor_ops.py +59 -0
  84. flash_ansr-0.4.2/src/flash_ansr.egg-info/PKG-INFO +152 -0
  85. flash_ansr-0.4.2/src/flash_ansr.egg-info/SOURCES.txt +95 -0
  86. flash_ansr-0.4.2/src/flash_ansr.egg-info/dependency_links.txt +1 -0
  87. flash_ansr-0.4.2/src/flash_ansr.egg-info/entry_points.txt +2 -0
  88. flash_ansr-0.4.2/src/flash_ansr.egg-info/requires.txt +38 -0
  89. flash_ansr-0.4.2/src/flash_ansr.egg-info/top_level.txt +1 -0
  90. flash_ansr-0.4.2/tests/test_data.py +68 -0
  91. flash_ansr-0.4.2/tests/test_decoding.py +97 -0
  92. flash_ansr-0.4.2/tests/test_flash_ansr_prompt.py +184 -0
  93. flash_ansr-0.4.2/tests/test_inference.py +173 -0
  94. flash_ansr-0.4.2/tests/test_numeric.py +80 -0
  95. flash_ansr-0.4.2/tests/test_preprocess_features.py +463 -0
  96. flash_ansr-0.4.2/tests/test_refine.py +77 -0
  97. flash_ansr-0.4.2/tests/test_results_serialization.py +203 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Paul Saegert
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,152 @@
1
+ Metadata-Version: 2.4
2
+ Name: flash_ansr
3
+ Version: 0.4.2
4
+ Summary: Flash Amortized Neural Symbolic Regression
5
+ Author: Paul Saegert
6
+ Project-URL: Github, https://github.com/psaegert/flash-ansr
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: absl-py
11
+ Requires-Dist: datasets
12
+ Requires-Dist: drawdata
13
+ Requires-Dist: editdistance
14
+ Requires-Dist: einops
15
+ Requires-Dist: matplotlib
16
+ Requires-Dist: numpy
17
+ Requires-Dist: pyyaml
18
+ Requires-Dist: schedulefree
19
+ Requires-Dist: scikit-learn
20
+ Requires-Dist: scipy
21
+ Requires-Dist: simplipy
22
+ Requires-Dist: torch
23
+ Requires-Dist: torch_optimizer
24
+ Requires-Dist: tqdm
25
+ Requires-Dist: wandb
26
+ Requires-Dist: zss
27
+ Provides-Extra: dev
28
+ Requires-Dist: pre-commit; extra == "dev"
29
+ Requires-Dist: pytest; extra == "dev"
30
+ Requires-Dist: pytest-cov; extra == "dev"
31
+ Requires-Dist: mypy; extra == "dev"
32
+ Requires-Dist: flake8; extra == "dev"
33
+ Requires-Dist: pygount; extra == "dev"
34
+ Requires-Dist: pylint; extra == "dev"
35
+ Requires-Dist: types-setuptools; extra == "dev"
36
+ Requires-Dist: types-tqdm; extra == "dev"
37
+ Requires-Dist: types-toml; extra == "dev"
38
+ Requires-Dist: types-PyYAML; extra == "dev"
39
+ Requires-Dist: radon; extra == "dev"
40
+ Requires-Dist: mkdocs; extra == "dev"
41
+ Requires-Dist: mkdocs-material; extra == "dev"
42
+ Requires-Dist: mkdocs-autorefs; extra == "dev"
43
+ Requires-Dist: mkdocstrings; extra == "dev"
44
+ Requires-Dist: mkdocs-get-deps; extra == "dev"
45
+ Requires-Dist: mkdocs-material-extensions; extra == "dev"
46
+ Requires-Dist: mkdocstrings-python; extra == "dev"
47
+ Dynamic: license-file
48
+
49
+ <h1 align="center" style="margin-top: 0px;">⚡ANSR:<br>Flash Amortized Neural Symbolic Regression</h1>
50
+
51
+ <div align="center">
52
+
53
+ [![PyPI version](https://img.shields.io/pypi/v/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
54
+ [![PyPI license](https://img.shields.io/pypi/l/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
55
+ [![Documentation Status](https://readthedocs.org/projects/flash-ansr/badge/?version=latest)](https://flash-ansr.readthedocs.io/en/latest/?badge=latest)
56
+
57
+ </div>
58
+
59
+ <div align="center">
60
+
61
+ [![pytest](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml)
62
+ [![quality checks](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml)
63
+ [![CodeQL Advanced](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml)
64
+
65
+ </div>
66
+
67
+ # Papers
68
+ - WIP
69
+
70
+
71
+ # Usage
72
+
73
+ ```sh
74
+ pip install flash-ansr
75
+ ```
76
+
77
+
78
+ ```python
79
+ import torch
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+
82
+ # Import flash_ansr
83
+ from flash_ansr import (
84
+ FlashANSR,
85
+ SoftmaxSamplingConfig,
86
+ install_model,
87
+ get_path,
88
+ )
89
+
90
+ # Specify the model
91
+ # Here: https://huggingface.co/psaegert/flash-ansr-v23.0-120M
92
+ MODEL = "psaegert/flash-ansr-v23.0-120M"
93
+
94
+ # Download the latest snapshot of the model
95
+ # By default, the model is downloaded to the directory `./models/` in the package root
96
+ install_model(MODEL)
97
+
98
+ # Load the model
99
+ model = FlashANSR.load(
100
+ directory=get_path('models', MODEL),
101
+ generation_config=SoftmaxSamplingConfig(choices=32), # or BeamSearchConfig / MCTSGenerationConfig
102
+ n_restarts=8,
103
+ ).to(device)
104
+
105
+ # Define data
106
+ X = ...
107
+ y = ...
108
+
109
+ # Fit the model to the data
110
+ model.fit(X, y, verbose=True)
111
+
112
+ # Show the best expression
113
+ print(model.get_expression())
114
+
115
+ # Predict with the best expression
116
+ y_pred = model.predict(X)
117
+ ```
118
+
119
+ # Overview
120
+
121
+ ### Training
122
+
123
+ <img src="./assets/images/flash-ansr-training.png" width="300">
124
+
125
+ > **⚡ANSR Training on Fully Procedurally Generated Data.** Inspired by NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)).
126
+
127
+ ### Architecture
128
+
129
+ <img src="./assets/images/flash-ansr.png">
130
+
131
+ > **FlashANSR Architecture.** The model consists of an upgraded version of the SetTransformer ([Lee et al. 2019](https://arxiv.org/abs/1810.00825)) as an encoder, and a Pre-Norm Transformer decoder ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) as a generative model over symbolic expressions.
132
+
133
+ ### Results
134
+ Coming soon
135
+ <!-- <img src="./assets/images/test_time_compute_fastsrb.svg">
136
+
137
+ > **Test Time Compute scaling.** ⚡ANSR, NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)), PySR ([Cranmer 2023](https://arxiv.org/abs/2305.01582)), and E2E ([Kamienny et al. 2022](https://arxiv.org/abs/2204.10532)) are evaluated on the FastSRB benchmark with 10 datasets per equation, $n_{support}=512$, noise level 0.0.\
138
+ > AMD 9950X (16C32T), RTX 4090 (24GB). -->
139
+
140
+
141
+
142
+ # Citation
143
+ ```bibtex
144
+ @software{flash-ansr2024,
145
+ author = {Paul Saegert},
146
+ title = {Flash Amortized Neural Symbolic Regression},
147
+ year = 2024,
148
+ publisher = {GitHub},
149
+ version = {0.4.2},
150
+ url = {https://github.com/psaegert/flash-ansr}
151
+ }
152
+ ```
@@ -0,0 +1,104 @@
1
+ <h1 align="center" style="margin-top: 0px;">⚡ANSR:<br>Flash Amortized Neural Symbolic Regression</h1>
2
+
3
+ <div align="center">
4
+
5
+ [![PyPI version](https://img.shields.io/pypi/v/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
6
+ [![PyPI license](https://img.shields.io/pypi/l/flash-ansr.svg)](https://pypi.org/project/flash-ansr/)
7
+ [![Documentation Status](https://readthedocs.org/projects/flash-ansr/badge/?version=latest)](https://flash-ansr.readthedocs.io/en/latest/?badge=latest)
8
+
9
+ </div>
10
+
11
+ <div align="center">
12
+
13
+ [![pytest](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pytest.yml)
14
+ [![quality checks](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/pre-commit.yml)
15
+ [![CodeQL Advanced](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml/badge.svg)](https://github.com/psaegert/flash-ansr/actions/workflows/codeql.yaml)
16
+
17
+ </div>
18
+
19
+ # Papers
20
+ - WIP
21
+
22
+
23
+ # Usage
24
+
25
+ ```sh
26
+ pip install flash-ansr
27
+ ```
28
+
29
+
30
+ ```python
31
+ import torch
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ # Import flash_ansr
35
+ from flash_ansr import (
36
+ FlashANSR,
37
+ SoftmaxSamplingConfig,
38
+ install_model,
39
+ get_path,
40
+ )
41
+
42
+ # Specify the model
43
+ # Here: https://huggingface.co/psaegert/flash-ansr-v23.0-120M
44
+ MODEL = "psaegert/flash-ansr-v23.0-120M"
45
+
46
+ # Download the latest snapshot of the model
47
+ # By default, the model is downloaded to the directory `./models/` in the package root
48
+ install_model(MODEL)
49
+
50
+ # Load the model
51
+ model = FlashANSR.load(
52
+ directory=get_path('models', MODEL),
53
+ generation_config=SoftmaxSamplingConfig(choices=32), # or BeamSearchConfig / MCTSGenerationConfig
54
+ n_restarts=8,
55
+ ).to(device)
56
+
57
+ # Define data
58
+ X = ...
59
+ y = ...
60
+
61
+ # Fit the model to the data
62
+ model.fit(X, y, verbose=True)
63
+
64
+ # Show the best expression
65
+ print(model.get_expression())
66
+
67
+ # Predict with the best expression
68
+ y_pred = model.predict(X)
69
+ ```
70
+
71
+ # Overview
72
+
73
+ ### Training
74
+
75
+ <img src="./assets/images/flash-ansr-training.png" width="300">
76
+
77
+ > **⚡ANSR Training on Fully Procedurally Generated Data.** Inspired by NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)).
78
+
79
+ ### Architecture
80
+
81
+ <img src="./assets/images/flash-ansr.png">
82
+
83
+ > **FlashANSR Architecture.** The model consists of an upgraded version of the SetTransformer ([Lee et al. 2019](https://arxiv.org/abs/1810.00825)) as an encoder, and a Pre-Norm Transformer decoder ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) as a generative model over symbolic expressions.
84
+
85
+ ### Results
86
+ Coming soon
87
+ <!-- <img src="./assets/images/test_time_compute_fastsrb.svg">
88
+
89
+ > **Test Time Compute scaling.** ⚡ANSR, NeSymReS ([Biggio et al. 2021](https://arxiv.org/abs/2106.06427)), PySR ([Cranmer 2023](https://arxiv.org/abs/2305.01582)), and E2E ([Kamienny et al. 2022](https://arxiv.org/abs/2204.10532)) are evaluated on the FastSRB benchmark with 10 datasets per equation, $n_{support}=512$, noise level 0.0.\
90
+ > AMD 9950X (16C32T), RTX 4090 (24GB). -->
91
+
92
+
93
+
94
+ # Citation
95
+ ```bibtex
96
+ @software{flash-ansr2024,
97
+ author = {Paul Saegert},
98
+ title = {Flash Amortized Neural Symbolic Regression},
99
+ year = 2024,
100
+ publisher = {GitHub},
101
+ version = {0.4.2},
102
+ url = {https://github.com/psaegert/flash-ansr}
103
+ }
104
+ ```
@@ -0,0 +1,88 @@
1
+ [project]
2
+ name = "flash_ansr"
3
+ description = "Flash Amortized Neural Symbolic Regression"
4
+ authors = [
5
+ {name = "Paul Saegert"},
6
+ ]
7
+ readme = "README.md"
8
+ requires-python = ">=3.10"
9
+ dynamic = ["dependencies"]
10
+ version = "0.4.2"
11
+ urls = { Github = "https://github.com/psaegert/flash-ansr"}
12
+
13
+ [project.scripts]
14
+ flash_ansr = "flash_ansr.__main__:main"
15
+
16
+
17
+ [tool.setuptools.dynamic]
18
+ dependencies = {file = ["requirements.txt"]}
19
+
20
+
21
+ [project.optional-dependencies]
22
+ dev = [
23
+ "pre-commit",
24
+ "pytest",
25
+ "pytest-cov",
26
+ "mypy",
27
+ "flake8",
28
+ "pygount",
29
+ "pylint",
30
+ "types-setuptools",
31
+ "types-tqdm",
32
+ "types-toml",
33
+ "types-PyYAML",
34
+ "radon",
35
+ "mkdocs",
36
+ "mkdocs-material",
37
+ "mkdocs-autorefs",
38
+ "mkdocstrings",
39
+ "mkdocs-get-deps",
40
+ "mkdocs-material-extensions",
41
+ "mkdocstrings-python"
42
+ ]
43
+
44
+ [tool.setuptools]
45
+ include-package-data = true
46
+
47
+
48
+ [tool.setuptools.packages.find]
49
+ where = ["src"]
50
+
51
+
52
+ [build-system]
53
+ requires = [
54
+ "setuptools>=68",
55
+ "wheel",
56
+ "platformdirs==3.10.0",
57
+ ]
58
+ build-backend = "setuptools.build_meta"
59
+
60
+
61
+ [tool.flake8]
62
+ ignore = ["E501", "W503"]
63
+ exclude = [
64
+ ".git",
65
+ "__pycache__",
66
+ "documentation",
67
+ "build",
68
+ "venv",
69
+ ".venv",
70
+ "env",
71
+ ".env",
72
+ "images",
73
+ ]
74
+ per-file-ignores = "__init__.py:F401"
75
+
76
+
77
+ [tool.mypy]
78
+ no_implicit_optional = false
79
+ disallow_untyped_defs = true
80
+ disallow_incomplete_defs = true
81
+ explicit_package_bases = true
82
+ exclude = "(.venv|tests/|experimental/|src/flash_ansr/compat/nesymres.py)"
83
+ ignore_missing_imports = true
84
+
85
+ [tool.pytest.ini_options]
86
+ filterwarnings = [
87
+ "ignore:This process .*use of fork.*deadlocks in the child.:DeprecationWarning",
88
+ ]
@@ -0,0 +1,17 @@
1
+ absl-py
2
+ datasets
3
+ drawdata
4
+ editdistance
5
+ einops
6
+ matplotlib
7
+ numpy
8
+ pyyaml
9
+ schedulefree
10
+ scikit-learn
11
+ scipy
12
+ simplipy
13
+ torch
14
+ torch_optimizer
15
+ tqdm
16
+ wandb
17
+ zss
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,29 @@
1
+ from .model import (
2
+ ModelFactory,
3
+ FlashANSRModel,
4
+ SetTransformer,
5
+ Tokenizer,
6
+ RotaryEmbedding,
7
+ IEEE75432PreEncoder,
8
+ install_model,
9
+ remove_model,
10
+ )
11
+ from .expressions import SkeletonPool, NoValidSampleFoundError
12
+ from .utils import (
13
+ GenerationConfig,
14
+ GenerationConfigBase,
15
+ BeamSearchConfig,
16
+ SoftmaxSamplingConfig,
17
+ MCTSGenerationConfig,
18
+ create_generation_config,
19
+ get_path,
20
+ load_config,
21
+ save_config,
22
+ substitute_root_path,
23
+ )
24
+ from .eval import Evaluation
25
+ from .refine import Refiner, ConvergenceError
26
+ from .flash_ansr import FlashANSR
27
+ from .baselines import SkeletonPoolModel, BruteForceModel
28
+ from .data.data import FlashANSRDataset
29
+ from .preprocessing import FlashANSRPreprocessor