mm-balance 0.5.1__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. mm_balance-0.6.0/.claude/settings.local.json +12 -0
  2. {mm_balance-0.5.1 → mm_balance-0.6.0}/.gitignore +1 -0
  3. mm_balance-0.6.0/CLAUDE.md +13 -0
  4. mm_balance-0.6.0/PKG-INFO +10 -0
  5. {mm_balance-0.5.1 → mm_balance-0.6.0}/justfile +1 -1
  6. {mm_balance-0.5.1 → mm_balance-0.6.0}/pyproject.toml +13 -13
  7. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/cli.py +2 -2
  8. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/command_runner.py +1 -1
  9. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/config/example.toml +1 -1
  10. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/config.py +8 -4
  11. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/diff.py +1 -1
  12. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/output/formats/table_format.py +8 -8
  13. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/price.py +1 -1
  14. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/result.py +5 -5
  15. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/token_decimals.py +2 -2
  16. mm_balance-0.6.0/src/mm_balance/utils.py +163 -0
  17. mm_balance-0.6.0/tests/test_share_expression.py +88 -0
  18. mm_balance-0.6.0/uv.lock +2226 -0
  19. mm_balance-0.5.1/PKG-INFO +0 -10
  20. mm_balance-0.5.1/dict.dic +0 -10
  21. mm_balance-0.5.1/src/mm_balance/utils.py +0 -30
  22. mm_balance-0.5.1/tests/test_dummy.py +0 -2
  23. mm_balance-0.5.1/uv.lock +0 -1901
  24. {mm_balance-0.5.1 → mm_balance-0.6.0}/README.md +0 -0
  25. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/__init__.py +0 -0
  26. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/balance_fetcher.py +0 -0
  27. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/constants.py +0 -0
  28. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/output/__init__.py +0 -0
  29. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/output/formats/__init__.py +0 -0
  30. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/output/formats/json_format.py +0 -0
  31. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/output/utils.py +0 -0
  32. {mm_balance-0.5.1 → mm_balance-0.6.0}/src/mm_balance/rpc.py +0 -0
  33. {mm_balance-0.5.1 → mm_balance-0.6.0}/tests/__init__.py +0 -0
  34. {mm_balance-0.5.1 → mm_balance-0.6.0}/tests/conftest.py +0 -0
@@ -0,0 +1,12 @@
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "mcp__ide__getDiagnostics",
5
+ "Bash(uv run pytest:*)",
6
+ "Bash(just lint)",
7
+ "Bash(uv run ruff check:*)"
8
+ ],
9
+ "deny": [],
10
+ "ask": []
11
+ }
12
+ }
@@ -13,3 +13,4 @@ pip-wheel-metadata
13
13
  /build
14
14
  /tmp
15
15
  .DS_Store
16
+ requirements.txt
@@ -0,0 +1,13 @@
1
+ # Claude Guidelines
2
+
3
+ ## Critical Guidelines
4
+
5
+ 1. **Always communicate in English** - Regardless of the language the user speaks, always respond in English. All code, comments, and documentation must be in English.
6
+
7
+ 2. **Minimal documentation** - Only add comments/documentation when it simplifies understanding and isn't obvious from the code itself. Keep it strictly relevant and concise.
8
+
9
+ 3. **Critical thinking** - Always critically evaluate user ideas. Users can make mistakes. Think first about whether the user's idea is good before implementing.
10
+
11
+ 4. **Lint after changes** - After making code changes, always run `just lint` to verify code quality and fix any linter issues.
12
+
13
+ 5. **No disabling linter rules** - Never use special disabling comments (like `# noqa`, `# type: ignore`, `# ruff: noqa`, etc.) to turn off linter rules without explicit permission. If you believe a rule should be disabled, ask first.
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: mm-balance
3
+ Version: 0.6.0
4
+ Requires-Python: >=3.13
5
+ Requires-Dist: deepdiff==8.6.1
6
+ Requires-Dist: mm-apt==0.5.0
7
+ Requires-Dist: mm-btc==0.5.5
8
+ Requires-Dist: mm-concurrency~=0.1.0
9
+ Requires-Dist: mm-eth==0.7.3
10
+ Requires-Dist: mm-sol==0.7.3
@@ -21,7 +21,7 @@ lint: format
21
21
 
22
22
  audit:
23
23
  uv export --no-dev --all-extras --format requirements-txt --no-emit-project > requirements.txt
24
- uv run pip-audit -r requirements.txt --disable-pip
24
+ uv run pip-audit -r requirements.txt --disable-pip --ignore-vuln GHSA-wj6h-64fc-37mp
25
25
  rm requirements.txt
26
26
  uv run bandit --silent --recursive --configfile "pyproject.toml" src
27
27
 
@@ -1,15 +1,15 @@
1
1
  [project]
2
2
  name = "mm-balance"
3
- version = "0.5.1"
3
+ version = "0.6.0"
4
4
  description = ""
5
5
  requires-python = ">=3.13"
6
6
  dependencies = [
7
7
  "mm-concurrency~=0.1.0",
8
- "mm-apt==0.4.2",
9
- "mm-btc==0.5.3",
10
- "mm-eth==0.7.1",
11
- "mm-sol==0.7.1",
12
- "deepdiff==8.5.0",
8
+ "mm-apt==0.5.0",
9
+ "mm-btc==0.5.5",
10
+ "mm-eth==0.7.3",
11
+ "mm-sol==0.7.3",
12
+ "deepdiff==8.6.1",
13
13
  ]
14
14
  [project.scripts]
15
15
  mm-balance = "mm_balance.cli:app"
@@ -18,14 +18,14 @@ mm-balance = "mm_balance.cli:app"
18
18
  requires = ["hatchling"]
19
19
  build-backend = "hatchling.build"
20
20
 
21
- [tool.uv]
22
- dev-dependencies = [
23
- "pytest~=8.4.0",
24
- "pytest-xdist~=3.7.0",
25
- "ruff~=0.11.13",
21
+ [dependency-groups]
22
+ dev = [
23
+ "pytest~=8.4.2",
24
+ "pytest-xdist~=3.8.0",
25
+ "ruff~=0.14.2",
26
26
  "pip-audit~=2.9.0",
27
- "bandit~=1.8.3",
28
- "mypy~=1.16.0",
27
+ "bandit~=1.8.6",
28
+ "mypy~=1.18.2",
29
29
  ]
30
30
 
31
31
  [tool.mypy]
@@ -25,8 +25,8 @@ def example_callback(value: bool) -> None:
25
25
  if value:
26
26
  data = pkgutil.get_data(__name__, "config/example.toml")
27
27
  if data is None:
28
- mm_print.fatal("Example config not found")
29
- mm_print.toml(toml=data.decode("utf-8"))
28
+ mm_print.exit_with_error("Example config not found")
29
+ mm_print.toml(data.decode("utf-8"))
30
30
  raise typer.Exit
31
31
 
32
32
 
@@ -63,7 +63,7 @@ async def run(params: CommandParameters) -> None:
63
63
  elif config.settings.print_format is PrintFormat.JSON:
64
64
  json_format.print_result(config, token_decimals, prices, workers, result)
65
65
  else:
66
- mm_print.fatal("Unsupported print format")
66
+ mm_print.exit_with_error("Unsupported print format")
67
67
 
68
68
  if params.save_balances:
69
69
  BalancesDict.from_balances_result(result).save_to_path(params.save_balances)
@@ -17,7 +17,7 @@ addresses = """
17
17
  bc1qgdjqv0av3q56jvd82tkdjpy7gdp9ut8tlqmgrpmv24sq90ecnvqqjwvw97 # bitfinex
18
18
  bc1ql49ydapnjafl5t2cp9zqpjwe6pdgmxy98859v2 # robinhood
19
19
  """
20
- share = 0.1
20
+ share = "0.5(total - 100.5)" # or: "total - 1000", "0.3total + 50", etc.
21
21
 
22
22
  [[coins]]
23
23
  ticker = "ETH"
@@ -10,7 +10,7 @@ from mm_web3 import ConfigValidators, Web3CliConfig
10
10
  from pydantic import BeforeValidator, Field, StringConstraints, model_validator
11
11
 
12
12
  from mm_balance.constants import DEFAULT_NODES, TOKEN_ADDRESS, Network
13
- from mm_balance.utils import PrintFormat
13
+ from mm_balance.utils import PrintFormat, evaluate_share_expression
14
14
 
15
15
 
16
16
  class Validators(ConfigValidators):
@@ -32,7 +32,7 @@ class AssetGroup(Web3CliConfig):
32
32
  decimals: int | None = None
33
33
  coingecko_id: str | None = None
34
34
  addresses: Annotated[list[str], BeforeValidator(Validators.addresses(deduplicate=True))]
35
- share: Decimal = Decimal(1)
35
+ share: str = "total"
36
36
 
37
37
  @property
38
38
  def name(self) -> str:
@@ -42,6 +42,10 @@ class AssetGroup(Web3CliConfig):
42
42
  result += " / " + self.network
43
43
  return result
44
44
 
45
+ def evaluate_share(self, balance_sum: Decimal) -> Decimal:
46
+ """Evaluate share expression with actual balance_sum value."""
47
+ return evaluate_share_expression(self.share, balance_sum)
48
+
45
49
  @model_validator(mode="after")
46
50
  def final_validator(self) -> Self:
47
51
  if self.token is None:
@@ -58,7 +62,7 @@ class AssetGroup(Web3CliConfig):
58
62
  if path.is_file():
59
63
  result += path.read_text().strip().splitlines()
60
64
  else:
61
- mm_print.fatal(f"File with addresses not found: {path}")
65
+ mm_print.exit_with_error(f"File with addresses not found: {path}")
62
66
  elif line.startswith("group:"):
63
67
  group_name = line.removeprefix("group:").strip()
64
68
  address_group = next((ag for ag in address_groups if ag.name == group_name), None)
@@ -96,7 +100,7 @@ class Config(Web3CliConfig):
96
100
  settings: Settings = Field(default_factory=Settings) # type: ignore[arg-type]
97
101
 
98
102
  def has_share(self) -> bool:
99
- return any(g.share != Decimal(1) for g in self.groups)
103
+ return any(g.share != "total" for g in self.groups)
100
104
 
101
105
  def networks(self) -> list[Network]:
102
106
  return pydash.uniq([group.network for group in self.groups])
@@ -97,7 +97,7 @@ class Diff(BaseModel):
97
97
  for address in self.balance_changed[network][ticker]:
98
98
  old_value, new_value = self.balance_changed[network][ticker][address]
99
99
  rows.append([network, ticker, address, old_value, new_value, new_value - old_value])
100
- mm_print.table("", ["Network", "Ticker", "Address", "Old", "New", "Change"], rows)
100
+ mm_print.table(["Network", "Ticker", "Address", "Old", "New", "Change"], rows)
101
101
 
102
102
  def _print_json(self) -> None:
103
103
  # mm_print.json(data=self.model_dump(), type_handlers=str) ?? default?
@@ -14,18 +14,18 @@ def print_nodes(config: Config) -> None:
14
14
  rows = []
15
15
  for network, nodes in config.nodes.items():
16
16
  rows.append([network, "\n".join(nodes)])
17
- mm_print.table("Nodes", ["network", "nodes"], rows)
17
+ mm_print.table(["network", "nodes"], rows, title="Nodes")
18
18
 
19
19
 
20
20
  def print_proxy_count(config: Config) -> None:
21
- mm_print.table("Proxies", ["count"], [[len(config.settings.proxies)]])
21
+ mm_print.table(["count"], [[len(config.settings.proxies)]], title="Proxies")
22
22
 
23
23
 
24
24
  def print_token_decimals(token_decimals: TokenDecimals) -> None:
25
25
  rows = []
26
26
  for network, decimals in token_decimals.items():
27
27
  rows.append([network, decimals])
28
- mm_print.table("Token Decimals", ["network", "decimals"], rows)
28
+ mm_print.table(["network", "decimals"], rows, title="Token Decimals")
29
29
 
30
30
 
31
31
  def print_prices(config: Config, prices: Prices) -> None:
@@ -35,7 +35,7 @@ def print_prices(config: Config, prices: Prices) -> None:
35
35
  rows.append(
36
36
  [ticker, format_number(price, config.settings.format_number_separator, "$", config.settings.round_ndigits)]
37
37
  )
38
- mm_print.table("Prices", ["coin", "usd"], rows)
38
+ mm_print.table(["coin", "usd"], rows, title="Prices")
39
39
 
40
40
 
41
41
  def print_result(config: Config, result: BalancesResult, workers: BalanceFetcher) -> None:
@@ -57,7 +57,7 @@ def _print_errors(config: Config, workers: BalanceFetcher) -> None:
57
57
  for task in error_tasks:
58
58
  group = config.groups[task.group_index]
59
59
  rows.append([group.ticker + " / " + group.network, task.wallet_address, task.balance.error]) # type: ignore[union-attr]
60
- mm_print.table("Errors", ["coin", "address", "error"], rows)
60
+ mm_print.table(["coin", "address", "error"], rows, title="Errors")
61
61
 
62
62
 
63
63
  def _print_total(config: Config, total: Total, is_share_total: bool) -> None:
@@ -80,7 +80,7 @@ def _print_total(config: Config, total: Total, is_share_total: bool) -> None:
80
80
  rows.append(["stablecoin_sum", format_number(total.stablecoin_sum, config.settings.format_number_separator, "$")])
81
81
  rows.append(["total_usd_sum", format_number(total.total_usd_sum, config.settings.format_number_separator, "$")])
82
82
 
83
- mm_print.table(table_name, headers, rows)
83
+ mm_print.table(headers, rows, title=table_name)
84
84
 
85
85
 
86
86
  def _print_group(config: Config, group: GroupResult) -> None:
@@ -108,7 +108,7 @@ def _print_group(config: Config, group: GroupResult) -> None:
108
108
  sum_row.append(format_number(group.usd_sum, config.settings.format_number_separator, "$"))
109
109
  rows.append(sum_row)
110
110
 
111
- if group.share < Decimal(1):
111
+ if group.share != "total":
112
112
  sum_share_str = format_number(group.balance_sum_share, config.settings.format_number_separator)
113
113
  sum_share_row = [f"sum_share, {group.share}", sum_share_str]
114
114
  if config.settings.price:
@@ -118,4 +118,4 @@ def _print_group(config: Config, group: GroupResult) -> None:
118
118
  table_headers = ["address", "balance"]
119
119
  if config.settings.price:
120
120
  table_headers += ["usd"]
121
- mm_print.table(group_name, table_headers, rows)
121
+ mm_print.table(table_headers, rows, title=group_name)
@@ -34,7 +34,7 @@ async def get_prices(config: Config) -> Prices:
34
34
  if res.status_code != 200:
35
35
  continue
36
36
 
37
- json_body = res.parse_json_body() or {}
37
+ json_body = res.parse_json()
38
38
 
39
39
  for ticker, coingecko_id in coingecko_map.items():
40
40
  if coingecko_id in json_body:
@@ -26,12 +26,12 @@ class GroupResult:
26
26
  ticker: str
27
27
  network: Network
28
28
  comment: str
29
- share: Decimal
29
+ share: str
30
30
  addresses: list[AddressBalance]
31
31
  balance_sum: Decimal # sum of all balances in the group
32
32
  usd_sum: Decimal # sum of all usd values in the group
33
- balance_sum_share: Decimal # sum of all balances in the group multiplied by share
34
- usd_sum_share: Decimal # sum of all usd values in the group multiplied by share
33
+ balance_sum_share: Decimal # calculated from share expression
34
+ usd_sum_share: Decimal # proportional to balance_sum_share
35
35
 
36
36
 
37
37
  @dataclass
@@ -113,8 +113,8 @@ def _create_group_result(config: Config, group: AssetGroup, tasks: list[Task], p
113
113
  balance = task.balance.unwrap_err()
114
114
  addresses.append(AddressBalance(address=task.wallet_address, balance=balance))
115
115
 
116
- balance_sum_share = balance_sum * group.share
117
- usd_sum_share = usd_sum * group.share
116
+ balance_sum_share = group.evaluate_share(balance_sum)
117
+ usd_sum_share = usd_sum * (balance_sum_share / balance_sum) if balance_sum > 0 else Decimal(0)
118
118
 
119
119
  return GroupResult(
120
120
  ticker=group.ticker,
@@ -30,7 +30,7 @@ async def get_token_decimals(config: Config) -> TokenDecimals:
30
30
  elif group.network in (NETWORK_BITCOIN, NETWORK_APTOS):
31
31
  result[group.network][None] = 8
32
32
  else:
33
- mm_print.fatal(f"Can't get token decimals for native token on network: {group.network}")
33
+ mm_print.exit_with_error(f"Can't get token decimals for native token on network: {group.network}")
34
34
  continue
35
35
 
36
36
  # get token_decimals via RPC
@@ -46,7 +46,7 @@ async def get_token_decimals(config: Config) -> TokenDecimals:
46
46
  msg = f"can't get decimals for token {group.ticker} / {group.token}, error={res.unwrap_err()}"
47
47
  if config.settings.print_debug:
48
48
  msg += f"\n{res.extra}"
49
- mm_print.fatal(msg)
49
+ mm_print.exit_with_error(msg)
50
50
 
51
51
  result[group.network][group.token] = res.unwrap()
52
52
 
@@ -0,0 +1,163 @@
1
+ import re
2
+ from decimal import Decimal
3
+ from enum import StrEnum, unique
4
+
5
+
6
+ def fnumber(value: Decimal, separator: str, extra: str | None = None) -> str:
7
+ str_value = f"{value:,}".replace(",", separator)
8
+ if extra == "$":
9
+ return "$" + str_value
10
+ if extra == "%":
11
+ return str_value + "%"
12
+ return str_value
13
+
14
+
15
def scale_and_round(value: int, decimals: int, round_ndigits: int) -> Decimal:
    """Convert an integer amount in base units to a rounded Decimal.

    ``value`` is divided by ``10**decimals`` and rounded to ``round_ndigits``
    fractional digits (Decimal's context rounding, ROUND_HALF_EVEN by default).

    The scaling is done in Decimal arithmetic. Going through float true
    division (``value / 10**decimals``) keeps only ~17 significant digits,
    which corrupts large integer amounts and mis-rounds half-way values.

    Zero is returned as a plain ``Decimal(0)``.
    """
    if value == 0:
        return Decimal(0)
    # scaleb shifts the decimal exponent exactly (subject only to the
    # 28-digit context precision) instead of dividing via float.
    return round(Decimal(value).scaleb(-decimals), round_ndigits)
19
+
20
+
21
def round_decimal(value: Decimal, round_ndigits: int) -> Decimal:
    """Round ``value`` to ``round_ndigits`` digits; any zero comes back as a plain ``Decimal(0)``."""
    is_zero = value == Decimal(0)
    return Decimal(0) if is_zero else round(value, round_ndigits)
25
+
26
+
27
+ class _ExpressionParser:
28
+ """Recursive descent parser for arithmetic expressions.
29
+
30
+ Parses expressions by reading left-to-right, respecting operator precedence:
31
+ 1. Parentheses and unary +/- (highest)
32
+ 2. Multiplication and division
33
+ 3. Addition and subtraction (lowest)
34
+
35
+ Uses self.pos as a cursor tracking current position in the string.
36
+ """
37
+
38
+ def __init__(self, expr: str) -> None:
39
+ self.expr = expr.replace(" ", "") # Expression string with spaces removed
40
+ self.pos = 0 # Current position/index in the string (starts at beginning)
41
+
42
+ def parse(self) -> Decimal:
43
+ result = self._parse_expression()
44
+ if self.pos < len(self.expr):
45
+ raise ValueError(f"Unexpected character at position {self.pos}: '{self.expr[self.pos]}'")
46
+ return result
47
+
48
+ def _parse_expression(self) -> Decimal:
49
+ """Parse addition and subtraction (lowest precedence)."""
50
+ result = self._parse_term()
51
+
52
+ while self.pos < len(self.expr):
53
+ if self._peek() == "+":
54
+ self.pos += 1
55
+ result = result + self._parse_term()
56
+ elif self._peek() == "-":
57
+ self.pos += 1
58
+ result = result - self._parse_term()
59
+ else:
60
+ break
61
+
62
+ return result
63
+
64
+ def _parse_term(self) -> Decimal:
65
+ """Parse multiplication and division (medium precedence)."""
66
+ result = self._parse_factor()
67
+
68
+ while self.pos < len(self.expr):
69
+ if self._peek() == "*":
70
+ self.pos += 1
71
+ result = result * self._parse_factor()
72
+ elif self._peek() == "/":
73
+ self.pos += 1
74
+ divisor = self._parse_factor()
75
+ if divisor == 0:
76
+ raise ValueError("Division by zero")
77
+ result = result / divisor
78
+ else:
79
+ break
80
+
81
+ return result
82
+
83
+ def _parse_factor(self) -> Decimal:
84
+ """Parse unary +/-, numbers, and parentheses (highest precedence)."""
85
+ if self._peek() == "+":
86
+ self.pos += 1
87
+ return self._parse_factor()
88
+ if self._peek() == "-":
89
+ self.pos += 1
90
+ return -self._parse_factor()
91
+
92
+ if self._peek() == "(":
93
+ self.pos += 1
94
+ result = self._parse_expression()
95
+ if self._peek() != ")":
96
+ raise ValueError(f"Expected ')' at position {self.pos}")
97
+ self.pos += 1
98
+ return result
99
+
100
+ return self._parse_number()
101
+
102
+ def _parse_number(self) -> Decimal:
103
+ """Parse a numeric value.
104
+
105
+ Scans digits, optionally followed by a decimal point and more digits.
106
+ Advances self.pos past the entire number.
107
+ """
108
+ start = self.pos
109
+
110
+ # Scan integer part: consecutive digits
111
+ while self.pos < len(self.expr) and self.expr[self.pos].isdigit():
112
+ self.pos += 1
113
+
114
+ # If there's a decimal point, scan fractional part
115
+ if self.pos < len(self.expr) and self.expr[self.pos] == ".":
116
+ self.pos += 1
117
+ while self.pos < len(self.expr) and self.expr[self.pos].isdigit():
118
+ self.pos += 1
119
+
120
+ # Ensure we consumed at least one digit
121
+ if start == self.pos:
122
+ raise ValueError(f"Expected number at position {self.pos}")
123
+
124
+ # Extract the substring from start to current position
125
+ return Decimal(self.expr[start : self.pos])
126
+
127
+ def _peek(self) -> str | None:
128
+ """Peek at the current character without consuming it."""
129
+ if self.pos < len(self.expr):
130
+ return self.expr[self.pos]
131
+ return None
132
+
133
+
134
+ def evaluate_share_expression(expression: str, balance_sum: Decimal) -> Decimal:
135
+ """Evaluate share expression with actual balance_sum value.
136
+
137
+ Supports expressions like:
138
+ - "total" -> full balance
139
+ - "0.5total" -> 50% of balance
140
+ - "0.5(total - 100)" -> 50% of (balance - 100)
141
+ - "total - 1000" -> balance minus 1000
142
+ """
143
+ if not re.match(r"^[0-9+\-*/.() total]+$", expression):
144
+ raise ValueError(f"Invalid share expression '{expression}': contains invalid characters")
145
+
146
+ # Insert * before ( when preceded by digit or )
147
+ expr = re.sub(r"(\d|\))\(", r"\1*(", expression)
148
+ # Insert * before 'total' when preceded by digit or )
149
+ expr = re.sub(r"(\d|\))total", r"\1*total", expr)
150
+ # Replace 'total' with actual value in parentheses
151
+ expr = expr.replace("total", f"({balance_sum})")
152
+ try:
153
+ parser = _ExpressionParser(expr)
154
+ return parser.parse()
155
+ except Exception as e:
156
+ raise ValueError(f"Invalid share expression '{expression}': {e}") from e
157
+
158
+
159
+ @unique
160
+ class PrintFormat(StrEnum):
161
+ PLAIN = "plain"
162
+ TABLE = "table"
163
+ JSON = "json"
@@ -0,0 +1,88 @@
1
+ from decimal import Decimal
2
+
3
+ import pytest
4
+
5
+ from mm_balance.utils import evaluate_share_expression
6
+
7
+
8
def test_total_expression():
    """A bare 'total' yields the balance unchanged."""
    for balance in (Decimal(100), Decimal("1234.56"), Decimal(0)):
        assert evaluate_share_expression("total", balance) == balance
13
+
14
+
15
def test_percentage_expressions():
    """A coefficient written directly before 'total' multiplies the balance."""
    cases = (
        ("0.5total", Decimal(100), Decimal(50)),
        ("0.1total", Decimal(1000), Decimal(100)),
        ("0.25total", Decimal(200), Decimal(50)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
20
+
21
+
22
def test_subtraction_expressions():
    """Constants can be subtracted from the balance or from a fraction of it."""
    cases = (
        ("total - 100", Decimal(500), Decimal(400)),
        ("total - 50.5", Decimal(200), Decimal("149.5")),
        ("0.5total - 100", Decimal(1000), Decimal(400)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
27
+
28
+
29
def test_parentheses_expressions():
    """A coefficient before a parenthesised group multiplies the whole group."""
    cases = (
        ("0.5(total - 100)", Decimal(500), Decimal(200)),
        ("0.5(total - 50.44)", Decimal(200), Decimal("74.78")),
        ("(total - 1000)", Decimal(5000), Decimal(4000)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
34
+
35
+
36
def test_addition_expressions():
    """Constants can be added to the balance or to a fraction of it."""
    cases = (
        ("total + 100", Decimal(500), Decimal(600)),
        ("0.3total + 50", Decimal(100), Decimal(80)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
40
+
41
+
42
def test_complex_expressions():
    """Explicit '*' and '/' and a leading integer coefficient all evaluate correctly."""
    cases = (
        ("(total - 100) * 0.5", Decimal(500), Decimal(200)),
        ("total / 2", Decimal(100), Decimal(50)),
        ("2(total - 50)", Decimal(100), Decimal(100)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
47
+
48
+
49
def test_nested_parentheses():
    """Nested groups combine with implicit multiplication."""
    result = evaluate_share_expression("0.5((total - 100) + 50)", Decimal(500))
    assert result == Decimal(225)
52
+
53
+
54
def test_zero_balance():
    """With a zero balance only the constant parts of an expression survive."""
    zero = Decimal(0)
    expectations = {
        "total": Decimal(0),
        "0.5total": Decimal(0),
        "total - 100": Decimal(-100),
    }
    for expression, expected in expectations.items():
        assert evaluate_share_expression(expression, zero) == expected, expression
59
+
60
+
61
def test_negative_results():
    """Results are not clamped: subtracting more than the balance goes negative."""
    cases = (
        ("total - 1000", Decimal(500), Decimal(-500)),
        ("0.5total - 100", Decimal(100), Decimal(-50)),
    )
    for expression, balance, expected in cases:
        assert evaluate_share_expression(expression, balance) == expected, expression
65
+
66
+
67
def test_invalid_characters():
    """Anything outside the character whitelist is refused before parsing."""
    suspicious = ("total; import os", "__import__('os')", "total & 1")
    for expression in suspicious:
        with pytest.raises(ValueError, match="invalid characters"):
            evaluate_share_expression(expression, Decimal(100))
77
+
78
+
79
def test_invalid_syntax():
    """Malformed expressions made of allowed characters are refused by the parser."""
    malformed = ("total +", "(total", "total total")
    for expression in malformed:
        with pytest.raises(ValueError, match="Invalid share expression"):
            evaluate_share_expression(expression, Decimal(100))