@ictechgy/context-guard 0.4.7 → 0.4.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from __future__ import annotations
11
11
  import argparse
12
12
  import base64
13
13
  import binascii
14
+ import errno
14
15
  try:
15
16
  import fcntl
16
17
  except ImportError: # pragma: no cover - fcntl is unavailable on Windows.
@@ -49,6 +50,8 @@ DEFAULT_SAFETY_FACTOR = 1.25
49
50
DEFAULT_LARGE_SECTION_BYTES = 64_000
MAX_LEDGER_ROWS = 20_000
# 64 KiB; judging by the name, the first window used when reading the end of
# the ledger (NOTE(review): confirm against the caller).
LEDGER_TAIL_INITIAL_BYTES = 65_536
# Bounded retry when opening the ledger for append: an OSError with errno
# ENOENT is retried until LEDGER_OPEN_RETRY_ATTEMPTS total attempts have been
# made, sleeping LEDGER_OPEN_RETRY_SECONDS seconds between attempts.
LEDGER_OPEN_RETRY_ATTEMPTS = 5
LEDGER_OPEN_RETRY_SECONDS = 0.01
# TTL labels mapped to their length in seconds: "5m" = 5 minutes, "1h" = 1 hour.
TTL_SECONDS = {"5m": 300, "1h": 3_600}
ANTHROPIC_DOCS_URL = "https://docs.anthropic.com/en/build-with-claude/prompt-caching"
ANTHROPIC_PRICING_URL = "https://platform.claude.com/docs/en/about-claude/pricing"
@@ -58,6 +61,10 @@ ALLOWED_FIRST_COMPONENT_SYMLINKS = {
58
61
  }
59
62
# True when os.open accepts a dir_fd argument on this platform.
DIR_FD_OPEN_SUPPORTED = os.open in getattr(os, "supports_dir_fd", set())
# True when the platform defines the O_NOFOLLOW open flag.
NO_FOLLOW_SUPPORTED = getattr(os, "O_NOFOLLOW", None) is not None
# True only when os.stat supports BOTH dir_fd and follow_symlinks=False, i.e. a
# leaf can be lstat'ed relative to an already-opened directory descriptor.
DIR_FD_STAT_NOFOLLOW_SUPPORTED = all(
    os.stat in getattr(os, capability_set, set())
    for capability_set in ("supports_dir_fd", "supports_follow_symlinks")
)
61
68
 
62
69
  SECRET_RE = re.compile(
63
70
  r"(?is)("
@@ -148,25 +155,68 @@ def token_proxy_obj(data: Any) -> int:
148
155
  return token_proxy_text(json_bytes(data))
149
156
 
150
157
 
158
def read_bounded_regular_path(path: str | Path, *, max_bytes: int, label: str) -> tuple[str, bool]:
    """Read at most ``max_bytes`` bytes of a regular file without following symlinks.

    The path is first checked with ``reject_symlink_components``. Its parent
    directory is then opened with ``open_directory_no_follow``. The leaf is
    lstat'ed and opened relative to that directory descriptor, with
    ``O_NOFOLLOW``.

    The file type is checked twice:
    - before the open, from the lstat result;
    - after the open, from ``fstat``.

    The opened descriptor must also refer to the same inode that was
    validated. ``O_NONBLOCK`` and ``O_NOCTTY`` are added when the platform
    defines them.

    Args:
        path: File to read.
        max_bytes: Maximum number of bytes returned; must be in
            ``1..MAX_MAX_BYTES``.
        label: Human-readable name used in error messages.

    Returns:
        ``(text, truncated)``:
        - ``text`` is at most ``max_bytes`` bytes decoded as UTF-8, with
          invalid sequences replaced.
        - ``truncated`` is true when the file held more than ``max_bytes``
          bytes.

    Errors are reported through ``fail``:
    - invalid bounds;
    - missing platform support;
    - non-regular or swapped files;
    - OS errors.

    ``CostGuardError`` raised inside the read is propagated unchanged.
    """
    if max_bytes < 1 or max_bytes > MAX_MAX_BYTES:
        fail(f"max bytes must be between 1 and {MAX_MAX_BYTES}")
    p = reject_symlink_components(Path(path), label=label)
    leaf_name = _private_leaf_name(p, label=label)
    # Nothing below can work without dir_fd-relative lstat, so fail before
    # acquiring any descriptor instead of after opening the parent directory.
    if not DIR_FD_STAT_NOFOLLOW_SUPPORTED:
        fail(f"{label} requires dir_fd stat support for symlink-safe regular-file validation")
    parent_fd = -1
    fd = -1
    try:
        parent_fd = open_directory_no_follow(p.parent, label=f"{label} parent")
        try:
            pre_st = os.stat(leaf_name, dir_fd=parent_fd, follow_symlinks=False)
        except OSError as exc:
            fail(f"could not inspect {label}: {os_error_detail(exc)}")
        # Reject non-regular leaves (FIFOs, devices, directories, symlinks)
        # before calling open on them at all.
        if not stat.S_ISREG(pre_st.st_mode):
            fail(f"{label} must be a regular file")
        flags = _base_open_flags() | _no_follow_flag(label=label)
        # If a non-regular file is swapped in after the check above, these
        # flags keep the open from blocking on a FIFO or making a tty the
        # controlling terminal.
        if hasattr(os, "O_NONBLOCK"):
            flags |= os.O_NONBLOCK
        if hasattr(os, "O_NOCTTY"):
            flags |= os.O_NOCTTY
        fd = os.open(leaf_name, flags, dir_fd=parent_fd)
        post_st = os.fstat(fd)
        if not stat.S_ISREG(post_st.st_mode):
            fail(f"{label} must be a regular file")
        # The leaf may have been renamed/replaced between the lstat and the
        # open; only read the file that was actually validated.
        if (post_st.st_dev, post_st.st_ino) != (pre_st.st_dev, pre_st.st_ino):
            fail(f"{label} changed while it was being opened")
        chunks: list[bytes] = []
        remaining = max_bytes + 1  # one extra byte reveals truncation
        while remaining > 0:
            chunk = os.read(fd, min(64 * 1024, remaining))
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        raw = b"".join(chunks)
    except CostGuardError:
        raise
    except OSError as exc:
        fail(f"could not read {label}: {os_error_detail(exc)}")
    finally:
        # Close the file first, then its parent directory; cleanup is
        # best-effort so it never masks the primary error.
        for owned_fd in (fd, parent_fd):
            if owned_fd >= 0:
                try:
                    os.close(owned_fd)
                except OSError:
                    pass
    truncated = len(raw) > max_bytes
    if truncated:
        raw = raw[:max_bytes]
    return raw.decode("utf-8", errors="replace"), truncated
211
+
212
+
151
213
  def read_text_path(path: str, *, max_bytes: int = DEFAULT_MAX_BYTES) -> tuple[str, bool]:
152
214
  if max_bytes < 1 or max_bytes > MAX_MAX_BYTES:
153
215
  fail(f"max bytes must be between 1 and {MAX_MAX_BYTES}")
154
216
  if path == "-":
155
217
  raw = sys.stdin.buffer.read(max_bytes + 1)
156
218
  else:
157
- p = Path(path)
158
- try:
159
- st = p.stat()
160
- except OSError as exc:
161
- fail(f"could not read input file: {exc}")
162
- if not stat.S_ISREG(st.st_mode):
163
- fail("input path must be a regular file")
164
- if st.st_size > max_bytes + 1:
165
- # Read only the bounded prefix so large requests cannot exhaust memory.
166
- with p.open("rb") as fh:
167
- raw = fh.read(max_bytes + 1)
168
- else:
169
- raw = p.read_bytes()
219
+ return read_bounded_regular_path(path, max_bytes=max_bytes, label="input file")
170
220
  truncated = len(raw) > max_bytes
171
221
  if truncated:
172
222
  raw = raw[:max_bytes]
@@ -494,20 +544,20 @@ def _base_open_flags() -> int:
494
544
  return flags
495
545
 
496
546
 
497
def _no_follow_flag(*, label: str = "private local cost storage") -> int:
    """Return ``os.O_NOFOLLOW`` for symlink-refusing opens.

    When the platform lacks the flag (``NO_FOLLOW_SUPPORTED`` is false), this
    reports the problem through ``fail``. The message names ``label`` as the
    feature that needs it.
    """
    if NO_FOLLOW_SUPPORTED:
        return os.O_NOFOLLOW
    fail(f"{label} requires O_NOFOLLOW support")
501
551
 
502
552
 
503
- def _directory_open_flags(*, follow_final: bool = False) -> int:
553
+ def _directory_open_flags(*, follow_final: bool = False, label: str = "private local cost storage") -> int:
504
554
  flags = os.O_RDONLY
505
555
  if hasattr(os, "O_CLOEXEC"):
506
556
  flags |= os.O_CLOEXEC
507
557
  if hasattr(os, "O_DIRECTORY"):
508
558
  flags |= os.O_DIRECTORY
509
559
  if not follow_final:
510
- flags |= _no_follow_flag()
560
+ flags |= _no_follow_flag(label=label)
511
561
  return flags
512
562
 
513
563
 
@@ -572,18 +622,18 @@ def reject_symlink_components(path: Path, *, label: str) -> Path:
572
622
  return path
573
623
 
574
624
 
575
- def open_private_directory(path: Path, *, label: str) -> int:
625
+ def open_directory_no_follow(path: Path, *, label: str) -> int:
576
626
  """Open an existing directory without following symlink path components."""
577
627
 
578
628
  if not dir_fd_open_supported():
579
- fail(f"{label} requires dir_fd support for symlink-safe private storage")
629
+ fail(f"{label} requires dir_fd support for symlink-safe directory traversal")
580
630
  path = reject_symlink_components(path, label=label)
581
- flags = _directory_open_flags()
631
+ flags = _directory_open_flags(label=label)
582
632
  if path.is_absolute():
583
633
  anchor = path.anchor or os.sep
584
634
  parts = path.parts[1:]
585
635
  try:
586
- current_fd = os.open(anchor, _directory_open_flags(follow_final=True))
636
+ current_fd = os.open(anchor, _directory_open_flags(follow_final=True, label=label))
587
637
  except OSError as exc:
588
638
  fail(f"could not inspect {label}: {os_error_detail(exc)}")
589
639
  else:
@@ -635,6 +685,12 @@ def open_private_directory(path: Path, *, label: str) -> int:
635
685
  pass
636
686
 
637
687
 
688
def open_private_directory(path: Path, *, label: str) -> int:
    """Return a descriptor for an existing private-storage directory.

    This is an alias of ``open_directory_no_follow`` kept for the
    private-storage call sites. That function rejects symlinked path
    components and uses ``label`` in its error messages. The caller owns the
    returned descriptor and must close it.
    """
    directory_fd = open_directory_no_follow(path, label=label)
    return directory_fd
692
+
693
+
638
694
  def fsync_directory_fd(fd: int) -> None:
639
695
  if os.name != "posix":
640
696
  return
@@ -676,7 +732,7 @@ def open_private_regular_fd_for_read(path: Path, *, label: str) -> int:
676
732
  fd = -1
677
733
  try:
678
734
  parent_fd = open_private_directory(path.parent, label=f"{label} parent")
679
- fd = os.open(leaf_name, _base_open_flags() | _no_follow_flag(), dir_fd=parent_fd)
735
+ fd = os.open(leaf_name, _base_open_flags() | _no_follow_flag(label=label), dir_fd=parent_fd)
680
736
  st = os.fstat(fd)
681
737
  if not stat.S_ISREG(st.st_mode):
682
738
  fail(f"{label} must be a regular file")
@@ -1138,41 +1194,47 @@ def open_private_regular_file_for_append(path: Path, *, label: str) -> int:
1138
1194
  flags = os.O_WRONLY | os.O_CREAT | os.O_APPEND | _no_follow_flag()
1139
1195
  if hasattr(os, "O_CLOEXEC"):
1140
1196
  flags |= os.O_CLOEXEC
1141
- parent_fd = -1
1142
- fd = -1
1143
- try:
1144
- parent_fd = open_private_directory(path.parent, label=f"{label} parent")
1145
- fd = os.open(leaf_name, flags, 0o600, dir_fd=parent_fd)
1146
- st = os.fstat(fd)
1147
- if not stat.S_ISREG(st.st_mode):
1148
- fail(f"{label} must be a regular file")
1149
- try:
1150
- os.fchmod(fd, 0o600)
1151
- except (AttributeError, OSError):
1152
- pass
1153
- st = os.fstat(fd)
1154
- if os.name == "posix" and stat.S_IMODE(st.st_mode) != 0o600:
1155
- fail(f"could not verify {label} privacy: expected mode 0600")
1156
- owned_fd = fd
1197
+ for attempt in range(LEDGER_OPEN_RETRY_ATTEMPTS):
1198
+ parent_fd = -1
1157
1199
  fd = -1
1158
- return owned_fd
1159
- except CostGuardError:
1160
- raise
1161
- except OSError as exc:
1162
- fail(f"could not open {label}: {os_error_detail(exc)}")
1163
- finally:
1164
- if fd >= 0:
1165
- # Ownership transfers to the caller only on the successful return
1166
- # above. On errors, close before surfacing a deterministic message.
1167
- try:
1168
- os.close(fd)
1169
- except OSError:
1170
- pass
1171
- if parent_fd >= 0:
1200
+ try:
1201
+ parent_fd = open_private_directory(path.parent, label=f"{label} parent")
1202
+ fd = os.open(leaf_name, flags, 0o600, dir_fd=parent_fd)
1203
+ st = os.fstat(fd)
1204
+ if not stat.S_ISREG(st.st_mode):
1205
+ fail(f"{label} must be a regular file")
1172
1206
  try:
1173
- os.close(parent_fd)
1174
- except OSError:
1207
+ os.fchmod(fd, 0o600)
1208
+ except (AttributeError, OSError):
1175
1209
  pass
1210
+ st = os.fstat(fd)
1211
+ if os.name == "posix" and stat.S_IMODE(st.st_mode) != 0o600:
1212
+ fail(f"could not verify {label} privacy: expected mode 0600")
1213
+ owned_fd = fd
1214
+ fd = -1
1215
+ return owned_fd
1216
+ except CostGuardError:
1217
+ raise
1218
+ except OSError as exc:
1219
+ if exc.errno == errno.ENOENT and attempt + 1 < LEDGER_OPEN_RETRY_ATTEMPTS:
1220
+ time.sleep(LEDGER_OPEN_RETRY_SECONDS)
1221
+ continue
1222
+ fail(f"could not open {label}: {os_error_detail(exc)}")
1223
+ finally:
1224
+ if fd >= 0:
1225
+ # Ownership transfers to the caller only on the successful
1226
+ # return above. On errors, close before surfacing a
1227
+ # deterministic message.
1228
+ try:
1229
+ os.close(fd)
1230
+ except OSError:
1231
+ pass
1232
+ if parent_fd >= 0:
1233
+ try:
1234
+ os.close(parent_fd)
1235
+ except OSError:
1236
+ pass
1237
+ raise AssertionError("unreachable: append retry loop exits via return or fail")
1176
1238
 
1177
1239
 
1178
1240
  def load_ledger(store_dir: Path) -> list[dict[str, Any]]:
@@ -1280,7 +1342,7 @@ def default_pricing_profile() -> dict[str, Any]:
1280
1342
  }
1281
1343
 
1282
1344
 
1283
- def load_pricing_profile(raw: str | None) -> dict[str, Any]:
1345
+ def load_pricing_profile(raw: str | None, *, max_bytes: int = DEFAULT_MAX_BYTES) -> dict[str, Any]:
1284
1346
  profile = default_pricing_profile()
1285
1347
  if not raw:
1286
1348
  return profile
@@ -1288,7 +1350,12 @@ def load_pricing_profile(raw: str | None) -> dict[str, Any]:
1288
1350
  if raw.lstrip().startswith("{"):
1289
1351
  override = json.loads(raw, parse_constant=reject_json_constant)
1290
1352
  else:
1291
- override = json.loads(Path(raw).read_text(encoding="utf-8"), parse_constant=reject_json_constant)
1353
+ text, truncated = read_bounded_regular_path(raw, max_bytes=max_bytes, label="pricing profile")
1354
+ if truncated:
1355
+ fail("pricing profile exceeded max bytes")
1356
+ override = json.loads(text, parse_constant=reject_json_constant)
1357
+ except CostGuardError:
1358
+ raise
1292
1359
  except (OSError, json.JSONDecodeError, ValueError) as exc:
1293
1360
  fail(f"could not load pricing profile: {exc}")
1294
1361
  if not isinstance(override, dict):
@@ -1542,7 +1609,7 @@ def annotate_cache_state(
1542
1609
  def preflight_command(args: argparse.Namespace) -> int:
1543
1610
  request_raw, _truncated = load_json_input(args.request, max_bytes=args.max_bytes)
1544
1611
  request = require_json_object(request_raw, "request")
1545
- profile = load_pricing_profile(args.pricing_profile)
1612
+ profile = load_pricing_profile(args.pricing_profile, max_bytes=args.max_bytes)
1546
1613
  if args.usd_to_krw is not None:
1547
1614
  profile["usd_to_krw"] = usd_to_krw(profile, args.usd_to_krw)
1548
1615
  if args.budget_usd is not None:
@@ -1809,7 +1876,7 @@ def observe_command(args: argparse.Namespace) -> int:
1809
1876
  usage = usage_raw
1810
1877
  if not isinstance(usage, dict):
1811
1878
  fail("usage must be a JSON object or an object containing a usage object")
1812
- profile = load_pricing_profile(args.pricing_profile)
1879
+ profile = load_pricing_profile(args.pricing_profile, max_bytes=args.max_bytes)
1813
1880
  if args.usd_to_krw is not None:
1814
1881
  profile["usd_to_krw"] = usd_to_krw(profile, args.usd_to_krw)
1815
1882
  model = str(args.model or (usage_raw.get("model") if isinstance(usage_raw, dict) else "") or "unknown")
@@ -2217,7 +2284,7 @@ def emit(data: dict[str, Any], *, json_mode: bool) -> None:
2217
2284
  def add_common_cost_args(parser: argparse.ArgumentParser) -> None:
2218
2285
  parser.add_argument("--pricing-profile", help="JSON string or file with input/output rates, cache multipliers, and usd_to_krw")
2219
2286
  parser.add_argument("--usd-to-krw", type=float, help="override USD→KRW exchange rate used for estimates")
2220
- parser.add_argument("--max-bytes", type=int, default=DEFAULT_MAX_BYTES, help=f"maximum JSON input bytes (default: {DEFAULT_MAX_BYTES})")
2287
+ parser.add_argument("--max-bytes", type=int, default=DEFAULT_MAX_BYTES, help=f"maximum JSON input and pricing profile file bytes (default: {DEFAULT_MAX_BYTES})")
2221
2288
  parser.add_argument("--json", action="store_true", help="emit machine-readable JSON")
2222
2289
 
2223
2290