S2Generator 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- S2Generator-0.0.1/LICENSE +21 -0
- S2Generator-0.0.1/PKG-INFO +21 -0
- S2Generator-0.0.1/README.md +1 -0
- S2Generator-0.0.1/S2Generator/__init__.py +17 -0
- S2Generator-0.0.1/S2Generator/base.py +326 -0
- S2Generator-0.0.1/S2Generator/encoders.py +246 -0
- S2Generator-0.0.1/S2Generator/generators.py +659 -0
- S2Generator-0.0.1/S2Generator/params.py +126 -0
- S2Generator-0.0.1/S2Generator/visualization.py +56 -0
- S2Generator-0.0.1/S2Generator.egg-info/PKG-INFO +21 -0
- S2Generator-0.0.1/S2Generator.egg-info/SOURCES.txt +14 -0
- S2Generator-0.0.1/S2Generator.egg-info/dependency_links.txt +1 -0
- S2Generator-0.0.1/S2Generator.egg-info/requires.txt +3 -0
- S2Generator-0.0.1/S2Generator.egg-info/top_level.txt +1 -0
- S2Generator-0.0.1/setup.cfg +4 -0
- S2Generator-0.0.1/setup.py +33 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2019 Cuixiaolong
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: S2Generator
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: A series-symbol (S2) dual-modality data generation mechanism, enabling the unrestricted creation of high-quality time series data paired with corresponding symbolic representations.
|
|
5
|
+
Home-page: https://github.com/wwhenxuan/S2Generator
|
|
6
|
+
Author: whenxuan
|
|
7
|
+
Author-email: wwhenxuan@gmail.com
|
|
8
|
+
Keywords: Time Series,Data Generation,Complex System Modeling
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE
|
|
20
|
+
|
|
21
|
+
# S2Generator
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# S2Generator
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Created on 2025/01/23 17:37:24
|
|
4
|
+
@author: Whenxuan Wang
|
|
5
|
+
@email: wwhenxuan@gmail.com
|
|
6
|
+
"""
|
|
7
|
+
# The basic data structure of symbolic expressions
|
|
8
|
+
from .base import Node, NodeList
|
|
9
|
+
|
|
10
|
+
# Parameter control of S2 data generation
|
|
11
|
+
from .params import Params
|
|
12
|
+
|
|
13
|
+
# S2 Data Generator
|
|
14
|
+
from .generators import Generator
|
|
15
|
+
|
|
16
|
+
# Visualize the generated S2 object
|
|
17
|
+
from .visualization import s2plot
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Created on 2025/01/23 18:25:07
|
|
4
|
+
@author: Whenxuan Wang
|
|
5
|
+
@email: wwhenxuan@gmail.com
|
|
6
|
+
"""
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy import ndarray
|
|
9
|
+
import scipy.special
|
|
10
|
+
from typing import Optional, Union, List
|
|
11
|
+
|
|
12
|
+
from S2Generator.params import Params
|
|
13
|
+
|
|
14
|
+
# Arity (number of operands) of each operator in the default set
operators_real = dict(
    add=2,
    sub=2,
    mul=2,
    div=2,
    abs=1,
    inv=1,
    sqrt=1,
    log=1,
    exp=1,
    sin=1,
    arcsin=1,
    cos=1,
    arccos=1,
    tan=1,
    arctan=1,
    pow2=1,
    pow3=1,
)

# Arity of the operators kept apart from the default set
operators_extra = dict(pow=2)

# Leaf values that Node.val resolves through attributes of numpy
math_constants = ["e", "pi", "euler_gamma", "CONSTANT"]

# Default operators first, extra operators last
all_operators = dict(operators_real, **operators_extra)

# Non-operator tokens: sequence markers, padding, brackets and out-of-domain placeholders
SPECIAL_WORDS = (
    ["<EOS>", "<X>", "</X>", "<Y>", "</Y>", "</POINTS>"]
    + ["<INPUT_PAD>", "<OUTPUT_PAD>", "<PAD>"]
    + ["(", ")", "SPECIAL"]
    + ["OOD_unary_op", "OOD_binary_op", "OOD_constant"]
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Node(object):
    """A node of a symbolic expression tree.

    ``value`` holds an operator name, a variable name of the form ``x_<i>``,
    a constant name or a number; ``children`` holds the operand sub-trees
    (none for a leaf, one for a unary operator, two for a binary operator).
    """

    def __init__(self, value: Union[str, int], params: "Params", children: list = None) -> None:
        # The specific value stored in the current node
        self.value = value
        # Operand sub-trees; a new list is created when none (or an empty one) is given
        self.children = children if children else []
        # Generation parameters; ``val`` reads ``params.use_abs``
        self.params = params

    def push_child(self, child: "Node") -> None:
        """Append ``child`` as the last operand of this node."""
        self.children.append(child)

    def prefix(self) -> str:
        """Return the tree in prefix (pre-order) notation, values separated by commas."""
        return ",".join([str(self.value)] + [c.prefix() for c in self.children])

    def qtree_prefix(self) -> str:
        """Return the tree as a nested bracket string ``[.$value$ <children>]``.

        NOTE(review): the format looks like the LaTeX ``qtree`` syntax - confirm.
        """
        return "[.$" + str(self.value) + "$ " + "".join(c.qtree_prefix() for c in self.children) + "]"

    def infix(self) -> str:
        """Return the expression rooted at this node as an infix string."""
        name = str(self.value)  # ``value`` may be an int, so always go through str()
        operands = [c.infix() for c in self.children]
        if not operands:
            # Leaf node: variable, constant or number
            return name
        if len(operands) == 1:
            # Unary operator: powers use ``**``, everything else is written f(x)
            if name == "pow2":
                return "(" + operands[0] + ")**2"
            if name == "pow3":
                return "(" + operands[0] + ")**3"
            return name + "(" + operands[0] + ")"
        # Binary operator: operands joined by the operator name inside parentheses
        return "(" + (" " + name + " ").join(operands) + ")"

    @staticmethod
    def _nans(size: int) -> ndarray:
        """Return a 1-D float array of length ``size`` filled with NaN."""
        return np.full((size,), np.nan)

    def _leaf_val(self, x: ndarray, deterministic: Optional[bool]) -> ndarray:
        """Evaluate a leaf node: a variable, a random term, a named constant or a number."""
        name = str(self.value)
        if name.startswith("x_"):
            # Variable x_<dim>: column ``dim`` of the input (a view, not a copy)
            _, dim = name.split("_")
            return x[:, int(dim)]
        if name == "rand":
            # Random term: zeros when deterministic, standard normal samples otherwise
            if deterministic:
                return np.zeros((x.shape[0],))
            return np.random.randn(x.shape[0])
        if name in math_constants:
            # NOTE(review): "CONSTANT" is listed in math_constants but numpy has no
            # attribute of that name, so this lookup raises AttributeError for it.
            return getattr(np, name) * np.ones((x.shape[0],))
        return float(self.value) * np.ones((x.shape[0],))

    def val(self, x: ndarray, deterministic: Optional[bool] = True) -> ndarray:
        """Evaluate the expression rooted at this node on the input ``x``.

        :param x: input array; variable ``x_<i>`` reads ``x[:, i]`` and the
            output length is ``x.shape[0]``. The array is never modified.
        :param deterministic: when true, ``rand`` leaves evaluate to zeros;
            otherwise they are drawn from ``np.random.randn``. The flag is
            passed down to every sub-tree.
        :return: the evaluated values; positions outside an operator's domain
            (division by zero, log/sqrt of invalid values) are NaN.
        :raises AssertionError: if the operator name cannot be resolved.
        """
        if len(self.children) == 0:
            return self._leaf_val(x, deterministic)

        def operand(index: int) -> ndarray:
            # Evaluate one child, forwarding the ``deterministic`` flag
            return self.children[index].val(x, deterministic)

        op = self.value

        # Binary operators
        if op == "add":
            return operand(0) + operand(1)
        if op == "sub":
            return operand(0) - operand(1)
        if op == "mul":
            m1, m2 = operand(0), operand(1)
            try:
                return m1 * m2
            except Exception:
                return self._nans(m1.shape[0])
        if op == "pow":
            m1, m2 = operand(0), operand(1)
            try:
                return np.power(m1, m2)
            except Exception:
                return self._nans(m1.shape[0])
        if op == "max":
            return np.maximum(operand(0), operand(1))
        if op == "min":
            return np.minimum(operand(0), operand(1))
        if op == "div":
            # Zero denominators become NaN. np.where builds a new array, so a
            # child result that is a view of ``x`` is left untouched.
            denominator = operand(1)
            denominator = np.where(denominator == 0.0, np.nan, denominator)
            try:
                return operand(0) / denominator
            except Exception:
                return self._nans(denominator.shape[0])

        # Unary operators
        if op == "inv":
            denominator = operand(0)
            denominator = np.where(denominator == 0.0, np.nan, denominator)
            try:
                return 1 / denominator
            except Exception:
                return self._nans(denominator.shape[0])
        if op == "log":
            numerator = operand(0)
            if self.params.use_abs:
                # log(abs(.)) when requested
                numerator = np.abs(numerator)
            else:
                # Non-positive inputs are outside the domain
                numerator = np.where(numerator <= 0.0, np.nan, numerator)
            try:
                return np.log(numerator)
            except Exception:
                return self._nans(numerator.shape[0])
        if op == "sqrt":
            numerator = operand(0)
            if self.params.use_abs:
                # sqrt(abs(.)) when requested
                numerator = np.abs(numerator)
            else:
                # Negative inputs are outside the domain
                numerator = np.where(numerator < 0.0, np.nan, numerator)
            try:
                return np.sqrt(numerator)
            except Exception:
                return self._nans(numerator.shape[0])
        if op == "pow2":
            numerator = operand(0)
            try:
                return numerator ** 2
            except Exception:
                return self._nans(numerator.shape[0])
        if op == "pow3":
            numerator = operand(0)
            try:
                return numerator ** 3
            except Exception:
                return self._nans(numerator.shape[0])
        if op == "abs":
            return np.abs(operand(0))
        if op == "sign":
            # +1 where the operand is >= 0, -1 elsewhere
            return (operand(0) >= 0) * 2.0 - 1.0
        if op == "step":
            # Element-wise "v if v > 0 else 0"
            values = operand(0)
            return np.where(values > 0, values, 0)
        if op == "id":
            return operand(0)
        if op == "fresnel":
            return scipy.special.fresnel(operand(0))[0]
        if op.startswith("eval"):
            # NOTE(review): the trailing character is passed on as a string and
            # the result is indexed with [0]; kept as written - confirm intent.
            n = op[-1]
            return getattr(scipy.special, op[:-1])(n, operand(0))[0]

        # Fallback: resolve the operator name in numpy, then in scipy.special
        fn = getattr(np, op, None)
        if fn is not None:
            try:
                return fn(operand(0))
            except Exception:
                return self._nans(x.shape[0])
        fn = getattr(scipy.special, op, None)
        if fn is not None:
            return fn(operand(0))
        # Raised explicitly so the check is not stripped under ``python -O``
        raise AssertionError("Could not find function")

    def get_recurrence_degree(self) -> int:
        """Return the largest index ``i`` among the ``x_<i>`` leaves below this node (0 if none)."""
        if len(self.children) == 0:
            if str(self.value).startswith("x_"):
                _, offset = str(self.value).split("_")
                return max(int(offset), 0)
            return 0
        return max(child.get_recurrence_degree() for child in self.children)

    def replace_node_value(self, old_value: str, new_value: str) -> None:
        """Replace ``old_value`` by ``new_value`` in this node and in every descendant."""
        if self.value == old_value:
            self.value = new_value
        for child in self.children:
            child.replace_node_value(old_value, new_value)

    def __len__(self) -> int:
        """Return the number of nodes in the tree rooted at this node."""
        return 1 + sum(len(c) for c in self.children)

    def __str__(self) -> str:
        # The infix form is the default printed representation
        return self.infix()

    def __repr__(self) -> str:
        # Same text as __str__
        return str(self)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class NodeList(object):
    """Container holding the root nodes of a multivariate symbolic expression"""

    def __init__(self, nodes: List["Node"]) -> None:
        # Own list of the root nodes, in the order given
        self.nodes = list(nodes)
        # Parameters are read from the first root node
        self.params = nodes[0].params

    def infix(self) -> str:
        """Infix form of every root expression, joined by ' | '"""
        parts = [root.infix() for root in self.nodes]
        return " | ".join(parts)

    def prefix(self) -> str:
        """Prefix form of every root expression, joined by ',|,'"""
        parts = [root.prefix() for root in self.nodes]
        return ",|,".join(parts)

    def val(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray:
        """Evaluate every root expression on its own copy of xs and stack the results on a new last axis"""
        columns = []
        for root in self.nodes:
            # Each root receives a fresh copy of the input
            columns.append(root.val(np.copy(xs), deterministic=deterministic))
        return np.stack(columns, axis=-1)

    def replace_node_value(self, old_value: str, new_value: str) -> None:
        """Replace old_value by new_value in every root expression"""
        for root in self.nodes:
            root.replace_node_value(old_value, new_value)

    def __len__(self) -> int:
        # Total number of nodes over all root expressions
        return sum(map(len, self.nodes))

    def __str__(self) -> str:
        """String form of the multivariate expression (its infix form)"""
        return self.infix()

    def __repr__(self) -> str:
        return str(self)
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Created on 2025/01/23 17:37:24
|
|
4
|
+
@author: Whenxuan Wang
|
|
5
|
+
@email: wwhenxuan@gmail.com
|
|
6
|
+
"""
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy import ndarray
|
|
9
|
+
from S2Generator.base import Node, NodeList
|
|
10
|
+
from S2Generator.params import Params
|
|
11
|
+
|
|
12
|
+
from typing import Union, List, Dict, Tuple
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GeneralEncoder(object):
    """Bundle of the two encoders used for S2 data: one for numbers, one for expressions"""

    def __init__(self, params: Params, symbols: List[str], all_operators: Dict[str, str]) -> None:
        # Numerical encoder; built first because the symbolic encoder receives it
        numeric = FloatSequences(params)
        self.float_encoder = numeric
        # Symbolic encoder, sharing the numerical encoder created above
        symbolic = Equation(params, symbols, numeric, all_operators)
        self.equation_encoder = symbolic
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FloatSequences(object):
    """Encoder/decoder turning floats into sign, mantissa and exponent tokens"""

    def __init__(self, params: "Params") -> None:
        # Number of digits after the decimal point in scientific notation
        self.float_precision = params.float_precision
        # Number of mantissa tokens per encoded value
        self.mantissa_len = params.mantissa_len
        # Exponent tokens cover -max_exponent .. max_exponent
        self.max_exponent = params.max_exponent
        # Number of digits carried by one mantissa token
        self.base = (self.float_precision + 1) // self.mantissa_len
        # Number of distinct mantissa tokens
        self.max_token = 10 ** self.base

        # Vocabulary: signs, then zero-padded mantissa tokens, then exponent tokens
        mantissa_tokens = ["N" + f"%0{self.base}d" % i for i in range(self.max_token)]
        exponent_tokens = ["E" + str(i) for i in range(-self.max_exponent, self.max_exponent + 1)]
        self.symbols = ["+", "-"] + mantissa_tokens + exponent_tokens

    def encode(self, values: ndarray) -> List:
        """Encode an array of floats; arrays that are not 1-D are encoded row by row"""
        if len(values.shape) != 1:
            # One token list per entry along the first axis
            encoded = [self.encode(values[0])]
            encoded.extend(self.encode(values[row]) for row in range(1, values.shape[0]))
            return encoded

        digits_after_point = self.float_precision
        width = self.base
        out = []
        for number in values:
            # Infinite values cannot be encoded
            assert number not in (-np.inf, np.inf)
            sign_token = "-" if not number >= 0 else "+"
            # Scientific notation, split into mantissa and exponent text
            mantissa_text, exponent_text = (f"%.{digits_after_point}e" % number).split("e")
            whole, fraction = mantissa_text.lstrip("-").split(".")
            digits = whole + fraction
            # Cut the digit string into groups of ``base`` characters
            groups = [digits[k: k + width] for k in range(0, len(digits), width)]
            # The exponent is shifted because the mantissa is stored as an integer
            exponent = int(exponent_text) - digits_after_point
            if exponent < -self.max_exponent:
                # Too small to represent: written as zero
                groups = ["0" * width] * self.mantissa_len
                exponent = int(0)
            out.append(sign_token)
            out.extend("N" + group for group in groups)
            out.append("E" + str(exponent))
        return out

    def decode(self, lst: List):
        """Decode a token list made of consecutive encoded floats into a list of floats"""
        if len(lst) == 0:
            # Nothing to decode
            return None
        # One value takes a sign token, the mantissa tokens and an exponent token
        width = 2 + self.mantissa_len
        decoded = []
        for start in range(0, len(lst), width):
            group = lst[start: start + width]
            # Any token that is not a sign, mantissa or exponent token aborts decoding
            if any(token[0] not in ["-", "+", "E", "N"] for token in group):
                return np.nan
            try:
                sign = -1 if group[0] != "+" else 1
                # Concatenate the digits of the mantissa tokens (prefix letter removed)
                mantissa = int("".join(token[1:] for token in group[1:-1]))
                exponent = int(group[-1][1:])
                number = float(sign * mantissa * (10 ** exponent))
            except Exception:
                number = np.nan
            decoded.append(number)
        return decoded
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class Equation(object):
    """Symbolic expression encoder/decoder for S2 data.

    ``encode`` turns a tree into a flat token list; ``decode`` rebuilds the
    tree(s) from such a list.
    """

    def __init__(self, params: "Params", symbols: List[str], float_encoder: "FloatSequences",
                 all_operators: Dict[str, int]) -> None:
        self.params = params
        # Base used by ``write_int`` / ``parse_int`` for integer tokens
        self.max_int = self.params.max_int
        # Tokens that ``_decode`` accepts as single leaf nodes
        self.symbols = symbols

        # Extra operators arrive as comma-separated strings; "" means none
        extra_unary = params.extra_unary_operators
        self.extra_unary_operators = extra_unary.split(",") if extra_unary != "" else []
        extra_binary = params.extra_binary_operators
        self.extra_binary_operators = extra_binary.split(",") if extra_binary != "" else []

        # Encoder used for floating-point constants
        self.float_encoder = float_encoder
        # Operator name -> arity (``_decode`` reads that many operands)
        self.all_operators = all_operators

    def encode(self, tree: "NodeList") -> List[str]:
        """Encode ``tree`` as a token list.

        The prefix string of the tree is split on commas. Integer elements go
        through ``write_int``, other numeric elements through the float
        encoder, and everything else is kept as a token of its own.
        """
        res = []
        for elem in tree.prefix().split(","):
            try:
                val = float(elem)
            except ValueError:
                # Not a number: operator, variable or separator
                res.append(elem)
                continue
            if elem.lstrip("-").isdigit():
                res.extend(self.write_int(int(elem)))
            else:
                res.extend(self.float_encoder.encode(np.array([val])))
        return res

    def decode(self, lst):
        """Rebuild a ``NodeList`` from a token list whose expressions are separated by "|".

        Returns ``None`` as soon as one expression cannot be decoded.
        """
        trees = []
        for tokens in self.split_at_value(lst, "|"):
            tree = self._decode(tokens)[0]
            if tree is None:
                return None
            trees.append(tree)
        return NodeList(trees)

    def _float_token_count(self) -> int:
        """Number of tokens of one encoded float: sign, mantissa tokens, exponent."""
        # Falls back to a single mantissa token if the encoder does not expose the attribute
        return 2 + getattr(self.float_encoder, "mantissa_len", 1)

    def _decode(self, lst: List) -> Tuple[Union["Node", None], int]:
        """Decode the expression that starts at ``lst[0]``.

        :return: ``(node, consumed)`` where ``consumed`` is the number of
            tokens read, or ``(None, n)`` when decoding fails.
        """
        if len(lst) == 0:
            return None, 0
        head = lst[0]
        if "OOD" in head:
            # Out-of-domain placeholder tokens cannot be decoded
            return None, 0
        if head in self.all_operators.keys():
            # Operator: decode as many operands as its arity, one after the other
            res = Node(head, self.params)
            pos = 1
            for _ in range(self.all_operators[head]):
                child, length = self._decode(lst[pos:])
                if child is None:
                    return None, pos
                res.push_child(child)
                pos += length
            return res, pos
        if head.startswith("INT"):
            val, length = self.parse_int(lst)
            return Node(str(val), self.params), length
        if head == "+" or head == "-":
            # Encoded float: the width follows the float encoder instead of
            # assuming exactly one mantissa token
            width = self._float_token_count()
            try:
                val = self.float_encoder.decode(lst[:width])[0]
            except Exception:
                return None, 0
            return Node(str(val), self.params), width
        if head.startswith("CONSTANT") or head == "y":  # Added this manually, be careful!!
            return Node(head, self.params), 1
        if head in self.symbols:
            return Node(head, self.params), 1
        try:
            float(head)  # If number, return leaf
        except (TypeError, ValueError):
            return None, 0
        return Node(head, self.params), 1

    @staticmethod
    def split_at_value(lst: List, value: Union[str, int]) -> List:
        """Split ``lst`` at every element equal to ``value`` (separators are dropped).

        Always returns at least one segment; segments are slices of ``lst``.
        """
        res = []
        start = 0
        for i, x in enumerate(lst):
            if x == value:
                res.append(lst[start:i])
                start = i + 1
        res.append(lst[start:])
        return res

    def parse_int(self, lst: List) -> Tuple[int, int]:
        """
        Parse a list that starts with an integer.
        Return the integer value, and the position it ends in the list.
        """
        base = self.max_int
        val = 0
        i = 0
        for x in lst[1:]:
            # NOTE(review): rstrip("-") does not accept a leading minus sign;
            # lstrip("-") may have been intended - confirm before changing.
            if not (x.rstrip("-").isdigit()):
                break
            val = val * base + int(x)
            i += 1
        if base > 0 and lst[0] == "INT-":
            val = -val
        # +1 accounts for the leading "INT+" / "INT-" token
        return val, i + 1

    def write_int(self, val):
        """Convert a decimal integer to a representation in the given base.

        Without ``params.use_sympy`` the integer is emitted as one plain token.
        Otherwise the result is "INT+" or "INT-" followed by the digits in base
        ``max_int``, most significant first.
        """
        if not self.params.use_sympy:
            return [str(val)]

        base = self.max_int
        res = []
        max_digit = abs(base)
        neg = val < 0
        val = -val if neg else val
        while True:
            rem = val % base
            val = val // base
            if rem < 0 or rem > max_digit:
                rem -= base
                val += 1
            res.append(str(rem))
            if val == 0:
                break
        res.append("INT-" if neg else "INT+")
        # Digits were produced least significant first
        return res[::-1]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def chunks(lst: List, n: int) -> List:
    """Lazily yield consecutive slices of lst of length n (the last one may be shorter)."""
    starts = range(0, len(lst), n)
    yield from (lst[begin: begin + n] for begin in starts)
|