mathformer 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 MathFormer Authors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: mathformer
3
+ Version: 1.0.0
4
+ Summary: A transformer-based math library
5
+ Author-email: JeremySu0818 <xinghong.su0818@gmail.com>
6
+ Project-URL: Homepage, https://github.com/JeremySu0818/MathFormer-API
7
+ Project-URL: Bug Tracker, https://github.com/JeremySu0818/MathFormer-API/issues
8
+ Project-URL: Repository, https://github.com/JeremySu0818/MathFormer-API
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: torch>=2.0.0
16
+ Requires-Dist: transformers>=4.30.0
17
+ Requires-Dist: safetensors>=0.3.0
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
20
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
+ Dynamic: license-file
22
+
23
+ # MathFormer
24
+
25
+ MathFormer is a Python library for mathematical operations using transformer architectures.
26
+
27
+ ## Installation
28
+
29
+ ```bash
30
+ pip install mathformer
31
+ ```
32
+
33
+ ## Usage
34
+
35
+ ```python
36
+ import mathformer
37
+
38
+ # Example usage
39
+ # mathformer.do_something()
40
+ ```
@@ -0,0 +1,18 @@
1
+ # MathFormer
2
+
3
+ MathFormer is a Python library for mathematical operations using transformer architectures.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install mathformer
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import mathformer
15
+
16
+ # Example usage
17
+ # mathformer.do_something()
18
+ ```
@@ -0,0 +1,47 @@
1
+
2
+ [tool.setuptools.package-dir]
3
+ "" = "src"
4
+
5
+ [tool.setuptools.packages.find]
6
+ where = ["src"]
7
+
8
+ [build-system]
9
+ requires = ["setuptools>=61.0"]
10
+ build-backend = "setuptools.build_meta"
11
+
12
+ [project]
13
+ name = "mathformer"
14
+ version = "1.0.0"
15
+ description = "A transformer-based math library"
16
+ readme = "README.md"
17
+ authors = [
18
+ { name = "JeremySu0818", email = "xinghong.su0818@gmail.com" },
19
+ ]
20
+ classifiers = [
21
+ "Programming Language :: Python :: 3",
22
+ "License :: OSI Approved :: MIT License",
23
+ "Operating System :: OS Independent",
24
+ ]
25
+ requires-python = ">=3.8"
26
+ dependencies = [
27
+ "torch>=2.0.0",
28
+ "transformers>=4.30.0",
29
+ "safetensors>=0.3.0",
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = [
34
+ "pytest>=7.0.0",
35
+ "pytest-cov>=4.0.0",
36
+ ]
37
+
38
+ [project.urls]
39
+ "Homepage" = "https://github.com/JeremySu0818/MathFormer-API"
40
+ "Bug Tracker" = "https://github.com/JeremySu0818/MathFormer-API/issues"
41
+ "Repository" = "https://github.com/JeremySu0818/MathFormer-API"
42
+
43
+ [tool.pytest.ini_options]
44
+ testpaths = ["tests"]
45
+ python_files = ["test_*.py"]
46
+ python_classes = ["Test*"]
47
+ python_functions = ["test_*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,46 @@
1
+ from typing import Union, Optional
2
+ from .api import MathFormerAPI, MathFormer
3
+ from .tokenizer import MathTokenizer
4
+
5
+
6
+ __version__ = "1.0.0"
7
+
8
+
9
+ _default_api = MathFormerAPI()
10
+
11
+
12
+ def add(*args: Union[str, int]) -> str:
13
+ return _default_api.add(*args)
14
+
15
+
16
+ def sub(*args: Union[str, int]) -> str:
17
+ return _default_api.sub(*args)
18
+
19
+
20
+ def mul(*args: Union[str, int]) -> str:
21
+ return _default_api.mul(*args)
22
+
23
+
24
+ def div(*args: Union[str, int]) -> str:
25
+ return _default_api.div(*args)
26
+
27
+
28
+ def calculate(operation: str, a, b) -> str:
29
+ return _default_api.calculate(operation, a, b)
30
+
31
+
32
+ def unload_models():
33
+ _default_api.unload_all()
34
+
35
+
36
+ __all__ = [
37
+ "MathFormerAPI",
38
+ "MathFormer",
39
+ "MathTokenizer",
40
+ "add",
41
+ "sub",
42
+ "mul",
43
+ "div",
44
+ "calculate",
45
+ "unload_models",
46
+ ]
@@ -0,0 +1,580 @@
1
+ import re
2
+ from typing import Optional, Dict, Any, List, Union, Tuple
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from transformers import LlamaForCausalLM, logging
7
+
8
+ import os
9
+ import warnings
10
+
11
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
12
+ warnings.filterwarnings("ignore")
13
+ logging.set_verbosity_error()
14
+ logging.disable_progress_bar()
15
+
16
+ from .tokenizer import MathTokenizer
17
+
18
+
19
+ _BASE_DIR = Path(__file__).parent
20
+ _DEFAULT_MODEL_PATHS = {
21
+ "add": _BASE_DIR / "addformer",
22
+ "sub": _BASE_DIR / "subformer",
23
+ "mul": _BASE_DIR / "mulformer",
24
+ "div": _BASE_DIR / "divformer",
25
+ }
26
+
27
+
28
+ _OPERATION_SYMBOLS = {
29
+ "add": "+",
30
+ "sub": "-",
31
+ "mul": "*",
32
+ "div": "/",
33
+ }
34
+
35
+
36
+ class MathFormer:
37
+
38
+ def __init__(
39
+ self,
40
+ model_path: str,
41
+ device: Optional[str] = None,
42
+ max_new_tokens: int = 32,
43
+ ):
44
+ self.model_path = Path(model_path)
45
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
46
+ self.max_new_tokens = max_new_tokens
47
+
48
+ self._model: Optional[LlamaForCausalLM] = None
49
+ self._tokenizer: Optional[MathTokenizer] = None
50
+ self._loaded = False
51
+
52
+ def load(self) -> "MathFormer":
53
+ if self._loaded:
54
+ return self
55
+
56
+ self._tokenizer = MathTokenizer.from_pretrained(str(self.model_path))
57
+ self._model = LlamaForCausalLM.from_pretrained(str(self.model_path))
58
+ self._model.to(self.device)
59
+ self._model.eval()
60
+ self._loaded = True
61
+ return self
62
+
63
+ def unload(self) -> None:
64
+ if self._model is not None:
65
+ del self._model
66
+ self._model = None
67
+ if self._tokenizer is not None:
68
+ del self._tokenizer
69
+ self._tokenizer = None
70
+ self._loaded = False
71
+ if torch.cuda.is_available():
72
+ torch.cuda.empty_cache()
73
+
74
+ @property
75
+ def is_loaded(self) -> bool:
76
+ return self._loaded
77
+
78
+ def predict(self, expression: str) -> str:
79
+ if not self._loaded:
80
+ self.load()
81
+
82
+ if "=" not in expression:
83
+ expression += "="
84
+
85
+ inputs = self._tokenizer(expression, return_tensors="pt")
86
+ input_ids = inputs["input_ids"].to(self.device)
87
+ attention_mask = inputs["attention_mask"].to(self.device)
88
+
89
+ with torch.no_grad():
90
+ outputs = self._model.generate(
91
+ input_ids=input_ids,
92
+ attention_mask=attention_mask,
93
+ max_new_tokens=self.max_new_tokens,
94
+ pad_token_id=self._tokenizer.pad_token_id,
95
+ eos_token_id=self._tokenizer.eos_token_id,
96
+ do_sample=False,
97
+ repetition_penalty=1.1,
98
+ )
99
+
100
+ generated_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
101
+
102
+ if "=" in generated_text:
103
+ answer = generated_text.split("=", 1)[1].strip()
104
+ else:
105
+ answer = generated_text.strip()
106
+
107
+ return answer
108
+
109
+ def __call__(self, expression: str) -> str:
110
+ return self.predict(expression)
111
+
112
+ def __enter__(self) -> "MathFormer":
113
+ self.load()
114
+ return self
115
+
116
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
117
+ self.unload()
118
+
119
+
120
+ class MathFormerAPI:
121
+
122
+ def __init__(
123
+ self,
124
+ model_paths: Optional[Dict[str, str]] = None,
125
+ device: Optional[str] = None,
126
+ max_new_tokens: int = 32,
127
+ lazy_load: bool = True,
128
+ ):
129
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
130
+ self.max_new_tokens = max_new_tokens
131
+
132
+ paths = model_paths or {}
133
+ self._model_paths = {
134
+ op: Path(paths.get(op, _DEFAULT_MODEL_PATHS[op]))
135
+ for op in ["add", "sub", "mul", "div"]
136
+ }
137
+
138
+ self.models: Dict[str, MathFormer] = {
139
+ op: MathFormer(
140
+ model_path=str(path),
141
+ device=self.device,
142
+ max_new_tokens=self.max_new_tokens,
143
+ )
144
+ for op, path in self._model_paths.items()
145
+ }
146
+
147
+ if not lazy_load:
148
+ self.load_all()
149
+
150
+ def load_all(self) -> "MathFormerAPI":
151
+ for model in self.models.values():
152
+ model.load()
153
+ return self
154
+
155
+ def unload_all(self) -> None:
156
+ for model in self.models.values():
157
+ model.unload()
158
+
159
+ def load(self, operation: str) -> "MathFormerAPI":
160
+ if operation in self.models:
161
+ self.models[operation].load()
162
+ return self
163
+
164
+ def unload(self, operation: str) -> None:
165
+ if operation in self.models:
166
+ self.models[operation].unload()
167
+
168
+ def _raw_predict(self, operation: str, expression: str) -> str:
169
+ if operation not in self.models:
170
+ raise ValueError(
171
+ f"Unknown operation type: {operation}. Available: {list(self.models.keys())}"
172
+ )
173
+ return self.models[operation].predict(expression)
174
+
175
+ def _single_add(self, a: int, b: int) -> Tuple[int, int]:
176
+ result_str = self._raw_predict("add", f"{a}+{b}")
177
+ result = int(result_str)
178
+ return result % 10, result // 10
179
+
180
+ def _single_sub(self, a: int, b: int, borrow: int = 0) -> Tuple[int, int]:
181
+ a_actual = a - borrow
182
+
183
+ if a_actual >= b:
184
+ result_str = self._raw_predict("sub", f"{a_actual}-{b}")
185
+ return int(result_str), 0
186
+ else:
187
+ a_with_borrow = a_actual + 10
188
+ result_str = self._raw_predict("sub", f"{a_with_borrow}-{b}")
189
+ return int(result_str), 1
190
+
191
+ def _single_mul(self, a: int, b: int) -> int:
192
+ result_str = self._raw_predict("mul", f"{a}*{b}")
193
+ return int(result_str)
194
+
195
+ def _single_div(self, a: int, b: int) -> Tuple[int, int]:
196
+ result_str = self._raw_predict("div", f"{a}/{b}")
197
+ match = re.match(r"Q(\d+)R(\d+)", result_str)
198
+ if match:
199
+ return int(match.group(1)), int(match.group(2))
200
+ return int(result_str), 0
201
+
202
+ def _multi_add(self, a: int, b: int) -> int:
203
+ if a < 0 or b < 0:
204
+ if a < 0 and b < 0:
205
+ return -self._multi_add(-a, -b)
206
+ elif a < 0:
207
+ return self._multi_sub(b, -a)
208
+ else:
209
+ return self._multi_sub(a, -b)
210
+
211
+ digits_a = [int(d) for d in str(a)[::-1]]
212
+ digits_b = [int(d) for d in str(b)[::-1]]
213
+
214
+ max_len = max(len(digits_a), len(digits_b))
215
+ digits_a.extend([0] * (max_len - len(digits_a)))
216
+ digits_b.extend([0] * (max_len - len(digits_b)))
217
+
218
+ result = []
219
+ carry = 0
220
+
221
+ for i in range(max_len):
222
+ sum_with_carry = digits_a[i] + carry
223
+ carry_from_first = 0
224
+
225
+ if sum_with_carry > 9:
226
+ sum_with_carry = sum_with_carry - 10
227
+ carry_from_first = 1
228
+
229
+ digit_result, new_carry = self._single_add(sum_with_carry, digits_b[i])
230
+
231
+ result.append(digit_result)
232
+ carry = new_carry + carry_from_first
233
+
234
+ if carry > 0:
235
+ result.append(carry)
236
+
237
+ return int("".join(str(d) for d in result[::-1]))
238
+
239
+ def _multi_sub(self, a: int, b: int) -> int:
240
+ if a < 0 and b < 0:
241
+ return self._multi_sub(-b, -a)
242
+ elif a < 0:
243
+ return -self._multi_add(-a, b)
244
+ elif b < 0:
245
+ return self._multi_add(a, -b)
246
+
247
+ if a < b:
248
+ return -self._multi_sub(b, a)
249
+
250
+ digits_a = [int(d) for d in str(a)[::-1]]
251
+ digits_b = [int(d) for d in str(b)[::-1]]
252
+
253
+ digits_b.extend([0] * (len(digits_a) - len(digits_b)))
254
+
255
+ result = []
256
+ borrow = 0
257
+
258
+ for i in range(len(digits_a)):
259
+ digit_a = digits_a[i]
260
+ digit_b = digits_b[i]
261
+
262
+ digit_result, new_borrow = self._single_sub(digit_a, digit_b, borrow)
263
+ result.append(digit_result)
264
+ borrow = new_borrow
265
+
266
+ while len(result) > 1 and result[-1] == 0:
267
+ result.pop()
268
+
269
+ return int("".join(str(d) for d in result[::-1]))
270
+
271
+ def _multi_mul(self, a: int, b: int) -> int:
272
+ negative = (a < 0) ^ (b < 0)
273
+ a, b = abs(a), abs(b)
274
+
275
+ if a == 0 or b == 0:
276
+ return 0
277
+
278
+ digits_a = [int(d) for d in str(a)[::-1]]
279
+ digits_b = [int(d) for d in str(b)[::-1]]
280
+
281
+ result = [0] * (len(digits_a) + len(digits_b))
282
+
283
+ for i, digit_b in enumerate(digits_b):
284
+ carry = 0
285
+ for j, digit_a in enumerate(digits_a):
286
+ product = self._single_mul(digit_a, digit_b)
287
+
288
+ total = product + carry + result[i + j]
289
+
290
+ result[i + j] = total % 10
291
+ carry = total // 10
292
+
293
+ k = i + len(digits_a)
294
+ while carry > 0:
295
+ total = carry + result[k]
296
+ result[k] = total % 10
297
+ carry = total // 10
298
+ k += 1
299
+
300
+ while len(result) > 1 and result[-1] == 0:
301
+ result.pop()
302
+
303
+ final_result = int("".join(str(d) for d in result[::-1]))
304
+ return -final_result if negative else final_result
305
+
306
+ def _trial_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
307
+ quotient = 0
308
+
309
+ for q in range(9, -1, -1):
310
+ product = self._multi_mul(divisor, q)
311
+ if product <= dividend:
312
+ quotient = q
313
+ break
314
+
315
+ product = self._multi_mul(divisor, quotient)
316
+ remainder = self._multi_sub(dividend, product)
317
+
318
+ return quotient, remainder
319
+
320
+ def _multi_div(self, a: int, b: int) -> Tuple[int, int]:
321
+ if b == 0:
322
+ raise ZeroDivisionError("Divisor cannot be zero")
323
+
324
+ negative = (a < 0) ^ (b < 0)
325
+ a, b = abs(a), abs(b)
326
+
327
+ if a < b:
328
+ return 0, a
329
+
330
+ if a == 0:
331
+ return 0, 0
332
+
333
+ digits_a = [int(d) for d in str(a)]
334
+
335
+ quotient_digits = []
336
+ remainder = 0
337
+
338
+ for digit in digits_a:
339
+ current = remainder * 10 + digit
340
+
341
+ if current < b:
342
+ quotient_digits.append(0)
343
+ remainder = current
344
+ else:
345
+ if b <= 9 and current <= 89:
346
+ q, r = self._single_div(current, b)
347
+ else:
348
+ q, r = self._trial_division(current, b)
349
+
350
+ quotient_digits.append(q)
351
+ remainder = r
352
+
353
+ while len(quotient_digits) > 1 and quotient_digits[0] == 0:
354
+ quotient_digits.pop(0)
355
+
356
+ quotient = int("".join(str(d) for d in quotient_digits))
357
+
358
+ if negative:
359
+ quotient = -quotient
360
+
361
+ return quotient, remainder
362
+
363
+ def _parse_expression(self, expression: str, operation: str) -> Tuple[int, int]:
364
+ expression = expression.replace(" ", "").replace("=", "")
365
+
366
+ if operation == "add":
367
+ parts = expression.split("+")
368
+ elif operation == "sub":
369
+ if expression.startswith("-"):
370
+ rest = expression[1:]
371
+ if "-" in rest:
372
+ idx = rest.index("-")
373
+ parts = ["-" + rest[:idx], rest[idx + 1 :]]
374
+ else:
375
+ raise ValueError(f"Cannot parse expression: {expression}")
376
+ else:
377
+ parts = expression.split("-")
378
+ elif operation == "mul":
379
+ expression = expression.replace("×", "*")
380
+ parts = expression.split("*")
381
+ elif operation == "div":
382
+ expression = expression.replace("÷", "/")
383
+ parts = expression.split("/")
384
+ else:
385
+ raise ValueError(f"Unknown operation type: {operation}")
386
+
387
+ if len(parts) != 2:
388
+ raise ValueError(f"Cannot parse expression: {expression}")
389
+
390
+ return int(parts[0]), int(parts[1])
391
+
392
+ def add(self, *args: Union[str, int]) -> str:
393
+ values = []
394
+ if len(args) == 0:
395
+ raise ValueError("At least one argument is required")
396
+
397
+ if len(args) == 1 and isinstance(args[0], str) and "+" in args[0]:
398
+ expression = args[0].replace(" ", "").replace("=", "")
399
+ parts = expression.split("+")
400
+ try:
401
+ values = [int(p) for p in parts]
402
+ except ValueError:
403
+ raise ValueError(f"Cannot parse expression: {expression}")
404
+ else:
405
+ try:
406
+ values = [int(a) for a in args]
407
+ except ValueError:
408
+ raise ValueError(
409
+ f"Arguments contain values that cannot be converted to integers: {args}"
410
+ )
411
+
412
+ if not values:
413
+ return "0"
414
+
415
+ result = values[0]
416
+ for val in values[1:]:
417
+ result = self._multi_add(result, val)
418
+
419
+ return str(result)
420
+
421
+ def sub(self, *args: Union[str, int]) -> str:
422
+ values = []
423
+ if len(args) == 0:
424
+ raise ValueError("At least one argument is required")
425
+
426
+ if len(args) == 1 and isinstance(args[0], str) and "-" in args[0].lstrip("-"):
427
+ expression = args[0].replace(" ", "").replace("=", "")
428
+ if expression.startswith("-"):
429
+ temp_expr = expression[1:]
430
+ parts = temp_expr.split("-")
431
+ values = [-int(parts[0])] + [int(p) for p in parts[1:]]
432
+ else:
433
+ parts = expression.split("-")
434
+ values = [int(p) for p in parts]
435
+ else:
436
+ try:
437
+ values = [int(a) for a in args]
438
+ except ValueError:
439
+ raise ValueError(
440
+ f"Arguments contain values that cannot be converted to integers: {args}"
441
+ )
442
+
443
+ if not values:
444
+ return "0"
445
+
446
+ result = values[0]
447
+ for val in values[1:]:
448
+ result = self._multi_sub(result, val)
449
+
450
+ return str(result)
451
+
452
+ def mul(self, *args: Union[str, int]) -> str:
453
+ values = []
454
+ if len(args) == 0:
455
+ raise ValueError("At least one argument is required")
456
+
457
+ if (
458
+ len(args) == 1
459
+ and isinstance(args[0], str)
460
+ and any(op in args[0] for op in ["*", "×"])
461
+ ):
462
+ expression = args[0].replace(" ", "").replace("=", "").replace("×", "*")
463
+ parts = expression.split("*")
464
+ try:
465
+ values = [int(p) for p in parts]
466
+ except ValueError:
467
+ raise ValueError(f"Cannot parse expression: {expression}")
468
+ else:
469
+ try:
470
+ values = [int(a) for a in args]
471
+ except ValueError:
472
+ raise ValueError(
473
+ f"Arguments contain values that cannot be converted to integers: {args}"
474
+ )
475
+
476
+ if not values:
477
+ return "0"
478
+
479
+ result = values[0]
480
+ for val in values[1:]:
481
+ result = self._multi_mul(result, val)
482
+
483
+ return str(result)
484
+
485
+ def div(self, *args: Union[str, int]) -> str:
486
+ values = []
487
+ if len(args) == 0:
488
+ raise ValueError("At least one argument is required")
489
+
490
+ if (
491
+ len(args) == 1
492
+ and isinstance(args[0], str)
493
+ and any(op in args[0] for op in ["/", "÷"])
494
+ ):
495
+ expression = args[0].replace(" ", "").replace("=", "").replace("÷", "/")
496
+ parts = expression.split("/")
497
+ try:
498
+ values = [int(p) for p in parts]
499
+ except ValueError:
500
+ raise ValueError(f"Cannot parse expression: {expression}")
501
+ else:
502
+ try:
503
+ values = [int(a) for a in args]
504
+ except ValueError:
505
+ raise ValueError(
506
+ f"Arguments contain values that cannot be converted to integers: {args}"
507
+ )
508
+
509
+ if not values:
510
+ return "0"
511
+
512
+ result_q = values[0]
513
+ result_r = 0
514
+
515
+ for val in values[1:]:
516
+ result_q, result_r = self._multi_div(result_q, val)
517
+
518
+ if result_r == 0:
519
+ return str(result_q)
520
+ else:
521
+ return f"Q{result_q}R{result_r}"
522
+
523
+ def calculate(
524
+ self, operation: str, a: Union[int, float, str], b: Union[int, float, str]
525
+ ) -> str:
526
+ a_int = int(a)
527
+ b_int = int(b)
528
+
529
+ if operation == "add":
530
+ result = self._multi_add(a_int, b_int)
531
+ return str(result)
532
+ elif operation == "sub":
533
+ result = self._multi_sub(a_int, b_int)
534
+ return str(result)
535
+ elif operation == "mul":
536
+ result = self._multi_mul(a_int, b_int)
537
+ return str(result)
538
+ elif operation == "div":
539
+ quotient, remainder = self._multi_div(a_int, b_int)
540
+ if remainder == 0:
541
+ return str(quotient)
542
+ else:
543
+ return f"Q{quotient}R{remainder}"
544
+ else:
545
+ raise ValueError(f"Unknown operation type: {operation}")
546
+
547
+ def batch_predict(
548
+ self,
549
+ operation: str,
550
+ expressions: List[str],
551
+ ) -> List[str]:
552
+ results = []
553
+ for expr in expressions:
554
+ if operation == "add":
555
+ results.append(self.add(expr))
556
+ elif operation == "sub":
557
+ results.append(self.sub(expr))
558
+ elif operation == "mul":
559
+ results.append(self.mul(expr))
560
+ elif operation == "div":
561
+ results.append(self.div(expr))
562
+ else:
563
+ raise ValueError(f"Unknown operation type: {operation}")
564
+ return results
565
+
566
+ def get_model_info(self) -> Dict[str, Any]:
567
+ return {
568
+ op: {
569
+ "path": str(model.model_path),
570
+ "loaded": model.is_loaded,
571
+ "device": model.device,
572
+ }
573
+ for op, model in self.models.items()
574
+ }
575
+
576
+ def __enter__(self) -> "MathFormerAPI":
577
+ return self
578
+
579
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
580
+ self.unload_all()
@@ -0,0 +1,101 @@
1
+ import os
2
+ import json
3
+ from typing import List, Dict, Union, Optional
4
+ import torch
5
+
6
+
7
+ class MathTokenizer:
8
+
9
+ def __init__(self, model_max_length: int = 64):
10
+ self.chars = [
11
+ "<pad>", "<s>", "</s>", "<unk>",
12
+ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
13
+ "+", "-", "*", "/", "=", ".", "(", ")", "^", "%", " ",
14
+ "Q", "R",
15
+ ]
16
+ self.token_to_id = {c: i for i, c in enumerate(self.chars)}
17
+ self.id_to_token = {i: c for i, c in enumerate(self.chars)}
18
+ self.pad_token_id = self.token_to_id["<pad>"]
19
+ self.eos_token_id = self.token_to_id["</s>"]
20
+ self.bos_token_id = self.token_to_id["<s>"]
21
+ self.unk_token_id = self.token_to_id["<unk>"]
22
+ self.padding_side = "left"
23
+ self.model_max_length = model_max_length
24
+
25
+ @classmethod
26
+ def from_pretrained(cls, path: str) -> "MathTokenizer":
27
+ config_path = os.path.join(path, "tokenizer_config.json")
28
+ vocab_path = os.path.join(path, "vocab.json")
29
+
30
+ tokenizer = cls()
31
+ if os.path.exists(config_path):
32
+ with open(config_path, "r", encoding="utf-8") as f:
33
+ config = json.load(f)
34
+ tokenizer.model_max_length = config.get("model_max_length", 64)
35
+ tokenizer.padding_side = config.get("padding_side", "left")
36
+
37
+ if os.path.exists(vocab_path):
38
+ with open(vocab_path, "r", encoding="utf-8") as f:
39
+ vocab = json.load(f)
40
+ tokenizer.token_to_id = vocab
41
+ tokenizer.id_to_token = {int(v): k for k, v in vocab.items()}
42
+
43
+ return tokenizer
44
+
45
+ def __call__(
46
+ self,
47
+ texts: Union[str, List[str]],
48
+ return_tensors: Optional[str] = None,
49
+ padding: bool = True,
50
+ ) -> Dict:
51
+ if isinstance(texts, str):
52
+ texts = [texts]
53
+
54
+ input_ids_list = []
55
+ attention_mask_list = []
56
+
57
+ for text in texts:
58
+ ids = [self.token_to_id.get(c, self.unk_token_id) for c in text]
59
+ input_ids_list.append(ids)
60
+ attention_mask_list.append([1] * len(ids))
61
+
62
+ if return_tensors == "pt":
63
+ max_len = max(len(x) for x in input_ids_list)
64
+ padded_ids = []
65
+ padded_mask = []
66
+
67
+ for ids, mask in zip(input_ids_list, attention_mask_list):
68
+ pad_len = max_len - len(ids)
69
+ if self.padding_side == "left":
70
+ ids = [self.pad_token_id] * pad_len + ids
71
+ mask = [0] * pad_len + mask
72
+ else:
73
+ ids = ids + [self.pad_token_id] * pad_len
74
+ mask = mask + [0] * pad_len
75
+ padded_ids.append(ids)
76
+ padded_mask.append(mask)
77
+
78
+ return {
79
+ "input_ids": torch.tensor(padded_ids, dtype=torch.long),
80
+ "attention_mask": torch.tensor(padded_mask, dtype=torch.long),
81
+ }
82
+
83
+ return {"input_ids": input_ids_list, "attention_mask": attention_mask_list}
84
+
85
+ def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = False) -> str:
86
+ result = ""
87
+ if isinstance(token_ids, torch.Tensor):
88
+ token_ids = token_ids.tolist()
89
+
90
+ for idx in token_ids:
91
+ char = self.id_to_token.get(idx, "<unk>")
92
+ if skip_special_tokens and char in ["<pad>", "<s>", "</s>"]:
93
+ continue
94
+ result += char
95
+ return result
96
+
97
+ def batch_decode(self, sequences: List, skip_special_tokens: bool = False) -> List[str]:
98
+ return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
99
+
100
+ def __len__(self) -> int:
101
+ return len(self.chars)
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: mathformer
3
+ Version: 1.0.0
4
+ Summary: A transformer-based math library
5
+ Author-email: JeremySu0818 <xinghong.su0818@gmail.com>
6
+ Project-URL: Homepage, https://github.com/JeremySu0818/MathFormer-API
7
+ Project-URL: Bug Tracker, https://github.com/JeremySu0818/MathFormer-API/issues
8
+ Project-URL: Repository, https://github.com/JeremySu0818/MathFormer-API
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: torch>=2.0.0
16
+ Requires-Dist: transformers>=4.30.0
17
+ Requires-Dist: safetensors>=0.3.0
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
20
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
21
+ Dynamic: license-file
22
+
23
+ # MathFormer
24
+
25
+ MathFormer is a Python library for mathematical operations using transformer architectures.
26
+
27
+ ## Installation
28
+
29
+ ```bash
30
+ pip install mathformer
31
+ ```
32
+
33
+ ## Usage
34
+
35
+ ```python
36
+ import mathformer
37
+
38
+ # Example usage
39
+ # mathformer.do_something()
40
+ ```
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/mathformer/__init__.py
5
+ src/mathformer/api.py
6
+ src/mathformer/tokenizer.py
7
+ src/mathformer.egg-info/PKG-INFO
8
+ src/mathformer.egg-info/SOURCES.txt
9
+ src/mathformer.egg-info/dependency_links.txt
10
+ src/mathformer.egg-info/requires.txt
11
+ src/mathformer.egg-info/top_level.txt
12
+ tests/test_api.py
@@ -0,0 +1,7 @@
1
+ torch>=2.0.0
2
+ transformers>=4.30.0
3
+ safetensors>=0.3.0
4
+
5
+ [dev]
6
+ pytest>=7.0.0
7
+ pytest-cov>=4.0.0
@@ -0,0 +1 @@
1
+ mathformer
@@ -0,0 +1,107 @@
1
+ """
2
+ Unit tests for MathFormer API
3
+ """
4
+
5
+ import pytest
6
+ import mathformer
7
+
8
+
9
+ class TestAddition:
10
+ """Test cases for addition operations"""
11
+
12
+ def test_add_two_integers(self):
13
+ """Test adding two integers"""
14
+ result = mathformer.add(1, 2)
15
+ assert result == "3"
16
+
17
+ def test_add_multiple_integers(self):
18
+ """Test adding multiple integers"""
19
+ result = mathformer.add(1, 2, 3)
20
+ assert result == "6"
21
+
22
+ def test_add_with_strings(self):
23
+ """Test adding numbers passed as strings"""
24
+ result = mathformer.add("10", "20")
25
+ assert result == "30"
26
+
27
+
28
+ class TestSubtraction:
29
+ """Test cases for subtraction operations"""
30
+
31
+ def test_sub_two_integers(self):
32
+ """Test subtracting two integers"""
33
+ result = mathformer.sub(5, 3)
34
+ assert result == "2"
35
+
36
+ def test_sub_multiple_integers(self):
37
+ """Test subtracting multiple integers"""
38
+ result = mathformer.sub(10, 3, 2)
39
+ assert result == "5"
40
+
41
+
42
+ class TestMultiplication:
43
+ """Test cases for multiplication operations"""
44
+
45
+ def test_mul_two_integers(self):
46
+ """Test multiplying two integers"""
47
+ result = mathformer.mul(3, 4)
48
+ assert result == "12"
49
+
50
+ def test_mul_multiple_integers(self):
51
+ """Test multiplying multiple integers"""
52
+ result = mathformer.mul(2, 3, 4)
53
+ assert result == "24"
54
+
55
+
56
+ class TestDivision:
57
+ """Test cases for division operations"""
58
+
59
+ def test_div_two_integers(self):
60
+ """Test dividing two integers"""
61
+ result = mathformer.div(10, 2)
62
+ assert result == "5"
63
+
64
+
65
+ class TestCalculate:
66
+ """Test cases for the calculate function"""
67
+
68
+ def test_calculate_add(self):
69
+ """Test calculate with add operation"""
70
+ result = mathformer.calculate("add", 5, 3)
71
+ assert result == "8"
72
+
73
+ def test_calculate_sub(self):
74
+ """Test calculate with sub operation"""
75
+ result = mathformer.calculate("sub", 10, 4)
76
+ assert result == "6"
77
+
78
+ def test_calculate_mul(self):
79
+ """Test calculate with mul operation"""
80
+ result = mathformer.calculate("mul", 6, 7)
81
+ assert result == "42"
82
+
83
+ def test_calculate_div(self):
84
+ """Test calculate with div operation"""
85
+ result = mathformer.calculate("div", 20, 5)
86
+ assert result == "4"
87
+
88
+
89
+ class TestModuleExports:
90
+ """Test that all expected exports are available"""
91
+
92
+ def test_mathformerapi_exists(self):
93
+ """Test MathFormerAPI class is exported"""
94
+ assert hasattr(mathformer, "MathFormerAPI")
95
+
96
+ def test_mathformer_exists(self):
97
+ """Test MathFormer class is exported"""
98
+ assert hasattr(mathformer, "MathFormer")
99
+
100
+ def test_tokenizer_exists(self):
101
+ """Test MathTokenizer class is exported"""
102
+ assert hasattr(mathformer, "MathTokenizer")
103
+
104
+ def test_version_exists(self):
105
+ """Test version is defined"""
106
+ assert hasattr(mathformer, "__version__")
107
+ assert mathformer.__version__ == "1.0.0"