genlayer-test 0.11.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/METADATA +67 -1
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/RECORD +13 -6
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/WHEEL +1 -1
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/entry_points.txt +1 -0
- gltest/direct/__init__.py +31 -0
- gltest/direct/loader.py +288 -0
- gltest/direct/pytest_plugin.py +117 -0
- gltest/direct/sdk_loader.py +260 -0
- gltest/direct/types.py +18 -0
- gltest/direct/vm.py +432 -0
- gltest/direct/wasi_mock.py +258 -0
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {genlayer_test-0.11.0.dist-info → genlayer_test-0.13.0.dist-info}/top_level.txt +0 -0
gltest/direct/sdk_loader.py
ADDED
@@ -0,0 +1,260 @@
+"""
+SDK version loader for direct test runner.
+
+Handles downloading and extracting the correct genlayer-py-std version
+based on contract header dependencies, similar to genvm-linter.
+"""
+
+import os
+import re
+import sys
+import json
+import tarfile
+import tempfile
+import urllib.request
+from pathlib import Path
+from typing import Optional, Dict, List
+
+CACHE_DIR = Path.home() / ".cache" / "gltest-direct"
+GITHUB_RELEASES_URL = "https://github.com/genlayerlabs/genvm/releases"
+
+RUNNER_TYPE = "py-genlayer"
+STD_LIB_TYPE = "py-lib-genlayer-std"
+EMBEDDINGS_TYPE = "py-lib-genlayer-embeddings"
+PROTOBUF_TYPE = "py-lib-protobuf"
+
+
+def parse_contract_header(contract_path: Path) -> Dict[str, str]:
+    """
+    Parse contract file header to extract dependency hashes.
+
+    Returns dict mapping dependency name to hash.
+    """
+    deps = {}
+    with open(contract_path, "r") as f:
+        content = f.read(2000)
+
+    pattern = r'"Depends":\s*"([^:]+):([^"]+)"'
+    for match in re.finditer(pattern, content):
+        name, hash_val = match.groups()
+        deps[name] = hash_val
+
+    return deps
+
+
+def get_latest_version() -> str:
+    """Get latest genvm release version from GitHub."""
+    try:
+        req = urllib.request.Request(
+            f"{GITHUB_RELEASES_URL}/latest",
+            method="HEAD",
+        )
+        req.add_header("User-Agent", "gltest-direct")
+        with urllib.request.urlopen(req, timeout=10) as resp:
+            final_url = resp.url
+            version = final_url.split("/")[-1]
+            return version
+    except Exception:
+        return "v0.2.12"
+
+
+def list_cached_versions() -> List[str]:
+    """List all cached genvm versions."""
+    if not CACHE_DIR.exists():
+        return []
+
+    versions = []
+    for f in CACHE_DIR.glob("genvm-universal-*.tar.xz"):
+        match = re.search(r"genvm-universal-(.+)\.tar\.xz", f.name)
+        if match:
+            versions.append(match.group(1))
+    return sorted(versions, reverse=True)
+
+
+def download_artifacts(version: str) -> Path:
+    """Download genvm release tarball if not cached."""
+    CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+    tarball_name = f"genvm-universal-{version}.tar.xz"
+    tarball_path = CACHE_DIR / tarball_name
+
+    if tarball_path.exists():
+        return tarball_path
+
+    url = f"{GITHUB_RELEASES_URL}/download/{version}/genvm-universal.tar.xz"
+    print(f"Downloading {url}...")
+
+    req = urllib.request.Request(url)
+    req.add_header("User-Agent", "gltest-direct")
+
+    with urllib.request.urlopen(req, timeout=300) as resp:
+        total = int(resp.headers.get("Content-Length", 0))
+        downloaded = 0
+
+        with tempfile.NamedTemporaryFile(delete=False, dir=CACHE_DIR) as tmp:
+            while True:
+                chunk = resp.read(1024 * 1024)
+                if not chunk:
+                    break
+                tmp.write(chunk)
+                downloaded += len(chunk)
+                if total:
+                    pct = downloaded * 100 // total
+                    print(f"\r {pct}% ({downloaded // 1024 // 1024}MB)", end="", flush=True)
+
+            tmp_path = tmp.name
+
+    print()
+    os.rename(tmp_path, tarball_path)
+    return tarball_path
+
+
+def extract_runner(
+    tarball_path: Path,
+    runner_type: str,
+    runner_hash: Optional[str] = None,
+    version: Optional[str] = None,
+) -> Path:
+    """Extract a runner from the tarball."""
+    if version is None:
+        match = re.search(r"genvm-universal-(.+)\.tar\.xz", tarball_path.name)
+        version = match.group(1) if match else "unknown"
+
+    extract_base = CACHE_DIR / "extracted" / version / runner_type
+
+    # Fast path: if hash specified and already extracted, skip tarball entirely
+    if runner_hash and runner_hash.lower() != "latest":
+        extract_dir = extract_base / runner_hash
+        if extract_dir.exists():
+            return extract_dir
+
+    # Check if any version already extracted (for "latest" case)
+    if extract_base.exists():
+        existing = sorted(extract_base.iterdir(), reverse=True)
+        if existing and (not runner_hash or runner_hash.lower() == "latest"):
+            return existing[0]
+
+    # Need to open tarball - this is slow (~13s for xz)
+    with tarfile.open(tarball_path, "r:xz") as outer_tar:
+        prefix = f"runners/{runner_type}/"
+        runner_tars = [
+            m.name for m in outer_tar.getmembers()
+            if m.name.startswith(prefix) and m.name.endswith(".tar")
+        ]
+
+        if not runner_tars:
+            raise ValueError(f"No {runner_type} runners found in tarball")
+
+        # Treat "latest" as no specific hash
+        if runner_hash and runner_hash.lower() != "latest":
+            target = f"runners/{runner_type}/{runner_hash[:2]}/{runner_hash[2:]}.tar"
+            if target not in runner_tars:
+                raise ValueError(f"Runner hash {runner_hash} not found")
+            runner_tar_name = target
+            extract_dir = extract_base / runner_hash
+        else:
+            runner_tar_name = sorted(runner_tars)[-1]
+            parts = runner_tar_name.split("/")
+            runner_hash = parts[-2] + parts[-1].replace(".tar", "")
+            extract_dir = extract_base / runner_hash
+
+        if extract_dir.exists():
+            return extract_dir
+
+        inner_tar_file = outer_tar.extractfile(runner_tar_name)
+        if inner_tar_file is None:
+            raise ValueError(f"Failed to read {runner_tar_name}")
+
+        extract_dir.mkdir(parents=True, exist_ok=True)
+
+        with tarfile.open(fileobj=inner_tar_file, mode="r:") as inner_tar:
+            inner_tar.extractall(extract_dir)
+
+    return extract_dir
+
+
+def parse_runner_manifest(runner_dir: Path) -> Dict[str, str]:
+    """Parse runner.json to get transitive dependencies."""
+    manifest_path = runner_dir / "runner.json"
+    if not manifest_path.exists():
+        return {}
+
+    with open(manifest_path) as f:
+        manifest = json.load(f)
+
+    deps = {}
+    seq = manifest.get("Seq", [])
+    for item in seq:
+        if "Depends" in item:
+            dep = item["Depends"]
+            if ":" in dep:
+                name, hash_val = dep.split(":", 1)
+                deps[name] = hash_val
+
+    return deps
+
+
+def setup_sdk_paths(
+    contract_path: Optional[Path] = None,
+    version: Optional[str] = None,
+) -> List[Path]:
+    """
+    Setup sys.path with correct SDK versions for a contract.
+
+    Returns list of paths added to sys.path.
+    """
+    contract_deps = {}
+    if contract_path and contract_path.exists():
+        contract_deps = parse_contract_header(contract_path)
+
+    if version is None:
+        cached = list_cached_versions()
+        version = cached[0] if cached else get_latest_version()
+
+    tarball = download_artifacts(version)
+
+    runner_hash = contract_deps.get(RUNNER_TYPE)
+    runner_dir = extract_runner(tarball, RUNNER_TYPE, runner_hash, version)
+
+    runner_deps = parse_runner_manifest(runner_dir)
+
+    std_hash = runner_deps.get(STD_LIB_TYPE)
+    std_dir: Optional[Path] = None
+    if std_hash:
+        std_dir = extract_runner(tarball, STD_LIB_TYPE, std_hash, version)
+
+    embeddings_hash = contract_deps.get(EMBEDDINGS_TYPE)
+    embeddings_dir: Optional[Path] = None
+    proto_dir: Optional[Path] = None
+    if embeddings_hash:
+        embeddings_dir = extract_runner(tarball, EMBEDDINGS_TYPE, embeddings_hash, version)
+        proto_hash = runner_deps.get(PROTOBUF_TYPE)
+        if proto_hash:
+            proto_dir = extract_runner(tarball, PROTOBUF_TYPE, proto_hash, version)
+
+    added_paths = []
+
+    # Helper to add path - tries both 'src' subdirectory and direct directory
+    def add_sdk_path(sdk_dir: Path) -> None:
+        src_path = sdk_dir / "src"
+        if src_path.exists():
+            path_to_add = src_path
+        else:
+            path_to_add = sdk_dir
+
+        if str(path_to_add) not in sys.path:
+            sys.path.insert(0, str(path_to_add))
+            added_paths.append(path_to_add)
+
+    add_sdk_path(runner_dir)
+
+    if std_dir:
+        add_sdk_path(std_dir)
+
+    if embeddings_dir:
+        add_sdk_path(embeddings_dir)
+
+    if proto_dir:
+        add_sdk_path(proto_dir)
+
+    return added_paths
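Usage note (not part of the diff): a minimal sketch of how the loader above might be driven, assuming genlayer-test 0.13.0 is installed. The contract path is hypothetical, and the first call needs network access to download and cache the genvm release tarball under ~/.cache/gltest-direct.

# Illustrative sketch only; "contracts/token.py" is a made-up path.
from pathlib import Path

from gltest.direct.sdk_loader import setup_sdk_paths

# Reads the contract's "Depends" header entries (if the file exists),
# downloads/extracts the matching genvm runners, and prepends their
# source directories to sys.path so the contract can be imported.
added = setup_sdk_paths(contract_path=Path("contracts/token.py"))
for p in added:
    print("added to sys.path:", p)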
gltest/direct/types.py
ADDED
@@ -0,0 +1,18 @@
+"""
+Type definitions for direct test runner.
+
+Reuses MockedLLMResponse and MockedWebResponse from gltest.types.
+"""
+
+# Re-export from parent package for convenience
+from ..types import (
+    MockedLLMResponse,
+    MockedWebResponse,
+    MockedWebResponseData,
+)
+
+__all__ = [
+    "MockedLLMResponse",
+    "MockedWebResponse",
+    "MockedWebResponseData",
+]
gltest/direct/vm.py
ADDED
@@ -0,0 +1,432 @@
+"""
+VMContext - Foundry-style test VM for GenLayer contracts.
+
+Provides cheatcodes for:
+- Setting sender/value (vm.sender, vm.value)
+- Snapshots and reverts (vm.snapshot(), vm.revert())
+- Mocking nondet operations (vm.mock_web(), vm.mock_llm())
+- Expecting reverts (vm.expect_revert())
+- Pranking (vm.prank(), vm.startPrank(), vm.stopPrank())
+"""
+
+from __future__ import annotations
+
+import re
+import sys
+import hashlib
+from contextlib import contextmanager, ExitStack
+from dataclasses import dataclass, field
+from typing import Any, Optional, Pattern, List, Tuple, Dict
+from unittest.mock import patch
+
+from ..types import MockedWebResponseData
+
+
+@dataclass
+class Snapshot:
+    """Storage snapshot for revert functionality."""
+    id: int
+    storage_data: Dict[bytes, bytes]
+    balances: Dict[bytes, int]
+
+
+class InmemManager:
+    """
+    In-memory storage manager compatible with genlayer.py.storage.
+    """
+
+    def __init__(self):
+        self._parts: Dict[bytes, Tuple["Slot", bytearray]] = {}
+
+    def get_store_slot(self, slot_id: bytes) -> "Slot":
+        res = self._parts.get(slot_id)
+        if res is None:
+            slot = Slot(slot_id, self)
+            self._parts[slot_id] = (slot, bytearray())
+            return slot
+        return res[0]
+
+    def do_read(self, slot_id: bytes, off: int, length: int) -> bytes:
+        res = self._parts.get(slot_id)
+        if res is None:
+            slot = Slot(slot_id, self)
+            mem = bytearray()
+            self._parts[slot_id] = (slot, mem)
+        else:
+            _, mem = res
+
+        needed = off + length
+        if len(mem) < needed:
+            mem.extend(b'\x00' * (needed - len(mem)))
+
+        return bytes(memoryview(mem)[off:off + length])
+
+    def do_write(self, slot_id: bytes, off: int, what: bytes) -> None:
+        res = self._parts.get(slot_id)
+        if res is None:
+            slot = Slot(slot_id, self)
+            mem = bytearray()
+            self._parts[slot_id] = (slot, mem)
+        else:
+            _, mem = res
+
+        what_view = memoryview(what)
+        length = len(what_view)
+
+        needed = off + length
+        if len(mem) < needed:
+            mem.extend(b'\x00' * (needed - len(mem)))
+
+        memoryview(mem)[off:off + length] = what_view
+
+    def snapshot(self) -> Dict[bytes, bytes]:
+        return {
+            slot_id: bytes(mem)
+            for slot_id, (_, mem) in self._parts.items()
+        }
+
+    def restore(self, data: Dict[bytes, bytes]) -> None:
+        self._parts.clear()
+        for slot_id, mem_data in data.items():
+            slot = Slot(slot_id, self)
+            self._parts[slot_id] = (slot, bytearray(mem_data))
+
+
+class Slot:
+    """Storage slot compatible with genlayer.py.storage."""
+
+    __slots__ = ('id', 'manager', '_indir_cache')
+
+    def __init__(self, slot_id: bytes, manager: InmemManager):
+        self.id = slot_id
+        self.manager = manager
+        self._indir_cache = hashlib.sha3_256(slot_id)
+
+    def read(self, off: int, length: int) -> bytes:
+        return self.manager.do_read(self.id, off, length)
+
+    def write(self, off: int, what: bytes) -> None:
+        self.manager.do_write(self.id, off, what)
+
+    def indirect(self, off: int) -> "Slot":
+        hasher = self._indir_cache.copy()
+        hasher.update(off.to_bytes(4, 'little'))
+        return self.manager.get_store_slot(hasher.digest())
+
+
+ROOT_SLOT_ID = b'\x00' * 32
+
+
+@dataclass
+class VMContext:
+    """
+    Test VM context providing Foundry-style cheatcodes.
+
+    Usage:
+        vm = VMContext()
+        vm.sender = Address("0x" + "a" * 40)
+        vm.mock_web("api.example.com", {"status": 200, "body": "{}"})
+
+        with vm.activate():
+            contract = deploy_contract("Token.py", vm, owner)
+            contract.transfer(bob, 100)
+    """
+
+    # Message context
+    _sender: Optional[Any] = None
+    _origin: Optional[Any] = None
+    _contract_address: Optional[Any] = None
+    _value: int = 0
+    _chain_id: int = 1
+    _datetime: str = "2024-01-01T00:00:00Z"
+
+    # Storage
+    _storage: InmemManager = field(default_factory=InmemManager)
+    _balances: Dict[bytes, int] = field(default_factory=dict)
+
+    # Snapshots
+    _snapshots: Dict[int, Snapshot] = field(default_factory=dict)
+    _snapshot_counter: int = 0
+
+    # Mocks
+    _web_mocks: List[Tuple[Pattern, MockedWebResponseData]] = field(default_factory=list)
+    _llm_mocks: List[Tuple[Pattern, str]] = field(default_factory=list)
+
+    # Expect revert
+    _expect_revert: Optional[str] = None
+    _expect_revert_any: bool = False
+
+    # Prank stack
+    _prank_stack: List[Any] = field(default_factory=list)
+
+    # Return value capture
+    _return_value: Any = None
+    _returned: bool = False
+
+    # Debug tracing
+    _traces: List[str] = field(default_factory=list)
+    _trace_enabled: bool = True
+
+    @property
+    def sender(self) -> Any:
+        if self._prank_stack:
+            return self._prank_stack[-1]
+        return self._sender
+
+    @sender.setter
+    def sender(self, addr: Any) -> None:
+        self._sender = addr
+        self._refresh_gl_message()
+
+    @property
+    def value(self) -> int:
+        return self._value
+
+    @value.setter
+    def value(self, val: int) -> None:
+        self._value = val
+
+    @property
+    def origin(self) -> Any:
+        return self._origin or self._sender
+
+    @origin.setter
+    def origin(self, addr: Any) -> None:
+        self._origin = addr
+
+    def warp(self, timestamp: str) -> None:
+        """Set block timestamp (ISO format)."""
+        self._datetime = timestamp
+
+    def deal(self, address: Any, amount: int) -> None:
+        """Set balance for an address."""
+        addr_bytes = self._to_bytes(address)
+        self._balances[addr_bytes] = amount
+
+    def snapshot(self) -> int:
+        """Take a snapshot of current state. Returns snapshot ID."""
+        snap_id = self._snapshot_counter
+        self._snapshot_counter += 1
+
+        self._snapshots[snap_id] = Snapshot(
+            id=snap_id,
+            storage_data=self._storage.snapshot(),
+            balances=dict(self._balances),
+        )
+
+        return snap_id
+
+    def revert(self, snapshot_id: int) -> None:
+        """Revert to a previous snapshot."""
+        if snapshot_id not in self._snapshots:
+            raise ValueError(f"Snapshot {snapshot_id} not found")
+
+        snap = self._snapshots[snapshot_id]
+        self._storage.restore(snap.storage_data)
+        self._balances = dict(snap.balances)
+
+        self._snapshots = {
+            k: v for k, v in self._snapshots.items()
+            if k <= snapshot_id
+        }
+
+    def mock_web(
+        self,
+        url_pattern: str,
+        response: MockedWebResponseData,
+    ) -> None:
+        """Mock web requests matching URL pattern."""
+        pattern = re.compile(url_pattern)
+        self._web_mocks.append((pattern, response))
+
+    def mock_llm(self, prompt_pattern: str, response: str) -> None:
+        """Mock LLM prompts matching pattern."""
+        pattern = re.compile(prompt_pattern)
+        self._llm_mocks.append((pattern, response))
+
+    def clear_mocks(self) -> None:
+        """Clear all registered mocks."""
+        self._web_mocks.clear()
+        self._llm_mocks.clear()
+
+    @contextmanager
+    def expect_revert(self, message: Optional[str] = None):
+        """Context manager expecting the next call to revert."""
+        self._expect_revert = message
+        self._expect_revert_any = message is None
+
+        try:
+            yield
+            raise AssertionError(
+                f"Expected revert{f' with message: {message}' if message else ''}, but call succeeded"
+            )
+        except Exception as e:
+            from .wasi_mock import ContractRollback
+
+            if isinstance(e, ContractRollback):
+                if message is not None and message not in e.message:
+                    raise AssertionError(
+                        f"Expected revert with message '{message}', got '{e.message}'"
+                    )
+            elif isinstance(e, AssertionError):
+                raise
+            else:
+                if message is not None and message not in str(e):
+                    raise
+        finally:
+            self._expect_revert = None
+            self._expect_revert_any = False
+
+    @contextmanager
+    def prank(self, address: Any):
+        """Context manager to temporarily change sender."""
+        self._prank_stack.append(address)
+        self._refresh_gl_message()
+        try:
+            yield
+        finally:
+            self._prank_stack.pop()
+            self._refresh_gl_message()
+
+    def startPrank(self, address: Any) -> None:
+        """Start pranking as address (persists until stopPrank)."""
+        self._prank_stack.append(address)
+        self._refresh_gl_message()
+
+    def stopPrank(self) -> None:
+        """Stop the current prank."""
+        if self._prank_stack:
+            self._prank_stack.pop()
+            self._refresh_gl_message()
+        else:
+            raise RuntimeError("No active prank to stop")
+
+    @contextmanager
+    def activate(self):
+        """
+        Activate this VM context for contract execution.
+        Uses proper cleanup via ExitStack for resource management.
+        """
+        from . import wasi_mock
+
+        with ExitStack() as stack:
+            wasi_mock.set_vm(self)
+            sys.modules['_genlayer_wasi'] = wasi_mock
+
+            stack.enter_context(
+                patch('os.fdopen', wasi_mock.patched_fdopen)
+            )
+            stack.callback(self._cleanup_after_deactivate)
+
+            try:
+                yield self
+            finally:
+                if '_genlayer_wasi' in sys.modules:
+                    del sys.modules['_genlayer_wasi']
+                wasi_mock.clear_vm()
+
+    def _cleanup_after_deactivate(self) -> None:
+        """Clean up resources after VM deactivation."""
+        modules_to_remove = [
+            key for key in sys.modules.keys()
+            if key.startswith('genlayer') or key.startswith('_contract_')
+        ]
+        for mod in modules_to_remove:
+            del sys.modules[mod]
+
+    def _match_web_mock(self, url: str, method: str = "GET") -> Optional[MockedWebResponseData]:
+        for pattern, response in self._web_mocks:
+            if pattern.search(url):
+                if response.get("method", "GET") == method:
+                    return response
+        return None
+
+    def _match_llm_mock(self, prompt: str) -> Optional[str]:
+        for pattern, response in self._llm_mocks:
+            if pattern.search(prompt):
+                return response
+        return None
+
+    def _trace(self, message: str) -> None:
+        if self._trace_enabled:
+            self._traces.append(message)
+
+    def _to_bytes(self, addr: Any) -> bytes:
+        if isinstance(addr, bytes):
+            return addr
+        if hasattr(addr, 'as_bytes'):
+            return addr.as_bytes
+        if hasattr(addr, '__bytes__'):
+            return bytes(addr)
+        if isinstance(addr, str):
+            if addr.startswith("0x"):
+                return bytes.fromhex(addr[2:])
+            return bytes.fromhex(addr)
+        raise ValueError(f"Cannot convert {type(addr)} to bytes")
+
+    def _refresh_gl_message(self) -> None:
+        """
+        Refresh gl.message and gl.message_raw to reflect current sender.
+
+        GenLayer SDK caches gl.message at import time. This method updates
+        the cached values so contracts see the current vm.sender.
+
+        Only updates if genlayer.gl is already imported - we must not trigger
+        a fresh import as that would read from stdin before message is injected.
+        """
+        # Only proceed if genlayer.gl is already loaded
+        if 'genlayer.gl' not in sys.modules:
+            return
+
+        try:
+            gl = sys.modules['genlayer.gl']
+            from genlayer.py.types import Address, u256
+
+            # Convert sender to Address if needed
+            sender = self.sender
+            if sender is not None and not isinstance(sender, Address):
+                if isinstance(sender, bytes):
+                    sender = Address(sender)
+                elif hasattr(sender, 'as_bytes'):
+                    sender = Address(sender.as_bytes)
+
+            origin = self.origin
+            if origin is not None and not isinstance(origin, Address):
+                if isinstance(origin, bytes):
+                    origin = Address(origin)
+                elif hasattr(origin, 'as_bytes'):
+                    origin = Address(origin.as_bytes)
+
+            # Update message_raw dict (mutable)
+            if hasattr(gl, 'message_raw') and gl.message_raw is not None:
+                gl.message_raw['sender_address'] = sender
+                gl.message_raw['origin_address'] = origin
+
+            # Replace gl.message with new NamedTuple (immutable, must recreate)
+            if hasattr(gl, 'message') and gl.message is not None:
+                gl.message = gl.MessageType(
+                    contract_address=gl.message.contract_address,
+                    sender_address=sender,
+                    origin_address=origin,
+                    value=u256(self._value),
+                    chain_id=u256(self._chain_id),
+                )
+        except ImportError:
+            # genlayer not loaded yet, nothing to update
+            pass
+
+    def get_message_raw(self) -> Dict[str, Any]:
+        """Get MessageRawType dict for stdin injection."""
+        return {
+            "contract_address": self._contract_address,
+            "sender_address": self.sender,
+            "origin_address": self.origin,
+            "stack": [],
+            "value": self._value,
+            "datetime": self._datetime,
+            "is_init": False,
+            "chain_id": self._chain_id,
+            "entry_kind": 0,
+            "entry_data": b"",
+            "entry_stage_data": None,
+        }
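Usage note (not part of the diff): a minimal sketch exercising the cheatcodes defined above, assuming genlayer-test 0.13.0 is installed. The addresses, mock patterns, and balance values are made up for illustration; contract deployment via the docstring's deploy_contract() is omitted.

# Illustrative sketch only; no contract is deployed here.
from gltest.direct.vm import VMContext

vm = VMContext()

# Message context and balances
vm.sender = b"\xaa" * 20              # raw 20-byte address; Address objects also work
vm.deal("0x" + "bb" * 20, 10**18)     # set a balance for another (made-up) address
vm.warp("2024-06-01T00:00:00Z")       # set the block timestamp

# Register nondet mocks; URL and prompt patterns are regexes
vm.mock_web(r"api\.example\.com", {"status": 200, "body": "{}"})
vm.mock_llm(r"classify", "positive")

# Snapshot, prank, revert
snap = vm.snapshot()                  # checkpoint storage and balances
with vm.prank(b"\xcc" * 20):          # vm.sender reads the pranked address inside the block
    assert vm.sender == b"\xcc" * 20
vm.revert(snap)                       # restore the checkpointed state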