dh-cli 0.8.0__tar.gz → 0.8.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. {dh_cli-0.8.0 → dh_cli-0.8.2}/PKG-INFO +1 -1
  2. {dh_cli-0.8.0 → dh_cli-0.8.2}/pyproject.toml +1 -1
  3. dh_cli-0.8.2/src/dh_cli/_identity.py +88 -0
  4. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/submit.py +24 -8
  5. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/bedrock/__init__.py +1 -0
  6. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/bedrock/commands.py +13 -21
  7. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/bedrock/cost_report.py +115 -48
  8. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/engine_commands.py +2 -8
  9. dh_cli-0.8.2/src/dh_cli/github_commands.py +752 -0
  10. dh_cli-0.8.2/tests/batch/test_submit_merge.py +220 -0
  11. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/conftest.py +3 -0
  12. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_build_report.py +19 -16
  13. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_classify_arn.py +1 -0
  14. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_cli_exit_codes.py +61 -36
  15. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_cost_calc.py +3 -6
  16. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_cost_command.py +53 -38
  17. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_cur_reconciliation.py +7 -20
  18. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_key_command.py +8 -9
  19. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_render_formats.py +103 -78
  20. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_resolve_base_model.py +1 -0
  21. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/test_s3_walker.py +103 -47
  22. dh_cli-0.8.2/tests/github/__init__.py +0 -0
  23. dh_cli-0.8.2/tests/github/conftest.py +187 -0
  24. dh_cli-0.8.2/tests/github/test_engine_role_cannot_read_github_pat.py +44 -0
  25. dh_cli-0.8.2/tests/github/test_identity.py +70 -0
  26. dh_cli-0.8.2/tests/github/test_login.py +227 -0
  27. dh_cli-0.8.2/tests/github/test_login_error_paths.py +134 -0
  28. dh_cli-0.8.2/tests/github/test_login_security.py +141 -0
  29. dh_cli-0.8.2/tests/github/test_logout.py +44 -0
  30. dh_cli-0.8.2/tests/github/test_rotate.py +198 -0
  31. dh_cli-0.8.2/tests/github/test_status.py +86 -0
  32. dh_cli-0.8.0/src/dh_cli/github_commands.py +0 -275
  33. {dh_cli-0.8.0 → dh_cli-0.8.2}/.gitignore +0 -0
  34. {dh_cli-0.8.0 → dh_cli-0.8.2}/LICENSE +0 -0
  35. {dh_cli-0.8.0 → dh_cli-0.8.2}/README.md +0 -0
  36. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/__init__.py +0 -0
  37. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/__init__.py +0 -0
  38. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/aws_batch.py +0 -0
  39. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/__init__.py +0 -0
  40. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/boltz.py +0 -0
  41. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/cancel.py +0 -0
  42. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/clean.py +0 -0
  43. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/embed_t5.py +0 -0
  44. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/finalize.py +0 -0
  45. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/list_jobs.py +0 -0
  46. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/local.py +0 -0
  47. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/logs.py +0 -0
  48. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/orca.py +0 -0
  49. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/protmpnn.py +0 -0
  50. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/protmpnn_to_boltz.py +0 -0
  51. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/retry.py +0 -0
  52. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/status.py +0 -0
  53. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/train.py +0 -0
  54. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/commands/wait_for.py +0 -0
  55. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/fasta_utils.py +0 -0
  56. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/h5_utils.py +0 -0
  57. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/job_id.py +0 -0
  58. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/manifest.py +0 -0
  59. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/batch/s3_transport.py +0 -0
  60. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/bedrock/pricing.yaml +0 -0
  61. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/cloud_commands.py +0 -0
  62. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/codeartifact.py +0 -0
  63. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/__init__.py +0 -0
  64. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/api_client.py +0 -0
  65. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/auth.py +0 -0
  66. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/progress.py +0 -0
  67. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/ssh_config.py +0 -0
  68. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/engines_studios/studio_commands.py +0 -0
  69. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/__init__.py +0 -0
  70. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/deploy.py +0 -0
  71. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/local.py +0 -0
  72. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/test.py +0 -0
  73. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/tf.py +0 -0
  74. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/hz/users.py +0 -0
  75. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/main.py +0 -0
  76. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/utility_commands.py +0 -0
  77. {dh_cli-0.8.0 → dh_cli-0.8.2}/src/dh_cli/warehouse.py +0 -0
  78. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/batch/__init__.py +0 -0
  79. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/batch/test_aws_batch_resources.py +0 -0
  80. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/batch/test_submit_cpu_only.py +0 -0
  81. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/A_cache_write.json +0 -0
  82. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/B_cache_read.json +0 -0
  83. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/C_plain.json +0 -0
  84. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/D_cursor_user.json +0 -0
  85. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/E_service_role.json +0 -0
  86. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/F_legacy_shared.json +0 -0
  87. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/bedrock/fixtures/G_unknown_model.json +0 -0
  88. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/hz/test_init.py +0 -0
  89. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/hz/test_suites.py +0 -0
  90. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/hz/test_users.py +0 -0
  91. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/test_cloud_gcp.py +0 -0
  92. {dh_cli-0.8.0 → dh_cli-0.8.2}/tests/test_finalize_protmpnn.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dh-cli
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: Dayhoff Labs developer CLI
5
5
  Author-email: Dayhoff Labs <dev@dayhofflabs.com>
6
6
  License: # PolyForm Noncommercial License 1.0.0
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "dh-cli"
7
- version = "0.8.0"
7
+ version = "0.8.2"
8
8
  description = "Dayhoff Labs developer CLI"
9
9
  requires-python = ">=3.11"
10
10
  readme = "README.md"
@@ -0,0 +1,88 @@
1
+ """Identity resolution for `dh` commands that read per-developer secrets.
2
+
3
+ The `github_commands` and (future) `bedrock` commands both key per-user
4
+ Secrets Manager entries on the caller's Dayhoff handle. This module
5
+ resolves that handle from the current SSO session; the server-side
6
+ resource policy makes the matching decision using `aws:userid` suffix-
7
+ matching on the same handle (see
8
+ blueprints/terraform/environments/dev/github_pat_secrets.tf header for
9
+ the full story of why aws:userid, not aws:PrincipalTag/Email or
10
+ aws:username or aws:PrincipalArn).
11
+
12
+ Design note — where the handle comes from:
13
+
14
+ There is no SDK API that lets a session read its own `aws:userid`
15
+ directly in a structured way. The closest observable is the
16
+ assumed-role ARN's RoleSessionName, which IAM Identity Center sets
17
+ to the login username (the handle) for every DeveloperAccess
18
+ session — exactly the same string that IAM populates as the suffix
19
+ of `aws:userid` during policy evaluation. So "what handle am I?"
20
+ (this function) and "which secret can I read?" (the server-side
21
+ policy) are answered by the same identity fact by construction.
22
+
23
+ For historical reasons the RoleSessionName in this org is a bare
24
+ handle like `dma`, not `dma@dayhofflabs.com`. If Identity Center is
25
+ ever reconfigured to use emails as session names — or if ABAC is
26
+ turned on and the server-side policy flips to
27
+ aws:PrincipalTag/Email — the optional `domain` argument still
28
+ handles the email-style case (strips the suffix) without code
29
+ changes here.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import re
35
+
36
+ # Kept for the email-style RoleSessionName case; a no-op in the current
37
+ # org (handles are bare). It is the default for the `domain` argument of
38
+ # resolve_handle_from_session, so suffix-stripping is always attempted.
39
+ DEFAULT_DOMAIN = "dayhofflabs.com"
40
+
41
+ _SSO_ASSUMED_ROLE_RE = re.compile(r"^arn:aws:sts::\d+:assumed-role/AWSReservedSSO_[^/]+/(?P<session>.+)$")
42
+
43
+
44
+ class HandleResolutionError(RuntimeError):
45
+ """Raised when the current session's handle can't be determined.
46
+
47
+ The caller is expected to turn this into a user-facing error
48
+ pointing at `awslogin dev-devaccess`.
49
+ """
50
+
51
+
52
+ def resolve_handle_from_session(session, *, domain: str = DEFAULT_DOMAIN) -> str:
53
+ """Return the dev handle (session-name portion of aws:userid).
54
+
55
+ Args:
56
+ session: a boto3 Session configured with the caller's SSO
57
+ credentials. The function calls `sts:GetCallerIdentity` on
58
+ it; if that call would fall through to the engine instance
59
+ role instead of the dev's SSO creds, the caller must have
60
+ already detected and errored on that before getting here
61
+ (see `github_commands._sso_session`).
62
+ domain: email domain to strip from the session name, for orgs
63
+ where IAM Identity Center uses emails as RoleSessionNames.
64
+ No-op for bare-handle RoleSessionNames (the current org).
65
+
66
+ Returns:
67
+ The handle (e.g. `"dma"`). This is the same string that appears
68
+ as the suffix of `aws:userid` during IAM policy evaluation for
69
+ the caller's session — i.e. the value the server-side policy
70
+ matches against in the `StringLike "aws:userid": "*:<handle>"`
71
+ condition.
72
+
73
+ Raises:
74
+ HandleResolutionError: the caller's ARN doesn't look like an
75
+ Identity Center DeveloperAccess session.
76
+ """
77
+ arn = session.client("sts").get_caller_identity()["Arn"]
78
+ match = _SSO_ASSUMED_ROLE_RE.match(arn)
79
+ if not match:
80
+ raise HandleResolutionError(
81
+ f"Caller ARN does not look like an AWS SSO session: {arn}. "
82
+ f"Run `awslogin dev-devaccess` (or pass --handle explicitly)."
83
+ )
84
+ session_name = match.group("session")
85
+ suffix = f"@{domain}"
86
+ if session_name.endswith(suffix):
87
+ return session_name[: -len(suffix)]
88
+ return session_name
@@ -2,6 +2,7 @@
2
2
 
3
3
  import click
4
4
  import yaml
5
+ from click.core import ParameterSource
5
6
 
6
7
  from ..aws_batch import BatchClient, BatchError, resolve_dependency
7
8
  from ..job_id import generate_job_id, get_aws_username
@@ -34,7 +35,9 @@ DEFAULT_QUEUE = "t4-1x-spot"
34
35
  @click.option("--dry-run", is_flag=True, help="Show plan without submitting")
35
36
  @click.option("--base-path", default=BATCH_JOBS_BASE, help="Base path for job data")
36
37
  @click.option("--after", "after", multiple=True, help="Job ID(s) to wait for before starting")
38
+ @click.pass_context
37
39
  def submit(
40
+ ctx,
38
41
  config_file,
39
42
  command,
40
43
  queue,
@@ -52,7 +55,10 @@ def submit(
52
55
  ):
53
56
  """Submit a custom batch job.
54
57
 
55
- Jobs can be defined via a config file (-f) or inline options.
58
+ Jobs can be defined via a config file (-f) or inline options. When
59
+ both are provided, a CLI flag takes precedence over the
60
+ corresponding YAML field only if the user actually passes the
61
+ flag; otherwise the YAML value wins.
56
62
 
57
63
  \b
58
64
  Examples:
@@ -90,13 +96,23 @@ def submit(
90
96
  if not job_command:
91
97
  raise click.UsageError("Must specify --command or provide config file with 'command' field")
92
98
 
93
- job_queue = queue if queue != DEFAULT_QUEUE else config.get("queue", queue)
94
- job_memory = memory if memory != "30G" else config.get("memory", memory)
95
- job_vcpus = vcpus if vcpus != 8 else config.get("vcpus", vcpus)
96
- job_gpus = gpus if gpus != 1 else config.get("gpus", gpus)
97
- job_array = array if array != 1 else config.get("array", array)
98
- job_retry = retry if retry != 3 else config.get("retry", retry)
99
- job_timeout = timeout if timeout != "6h" else config.get("timeout", timeout)
99
+ # Merge CLI flags with YAML. CLI wins iff the user actually passed the
100
+ # flag; otherwise YAML if set; otherwise the Click default. Uses
101
+ # ParameterSource to tell "user typed --gpus 1" from "Click filled in
102
+ # the default 1", which a bare value comparison cannot do.
103
+ def _pick(param_name, cli_value, yaml_key=None):
104
+ yaml_key = yaml_key or param_name
105
+ if ctx.get_parameter_source(param_name) == ParameterSource.COMMANDLINE:
106
+ return cli_value
107
+ return config.get(yaml_key, cli_value)
108
+
109
+ job_queue = _pick("queue", queue)
110
+ job_memory = _pick("memory", memory)
111
+ job_vcpus = _pick("vcpus", vcpus)
112
+ job_gpus = _pick("gpus", gpus)
113
+ job_array = _pick("array", array)
114
+ job_retry = _pick("retry", retry)
115
+ job_timeout = _pick("timeout", timeout)
100
116
  job_image = image or config.get("image")
101
117
 
102
118
  # Parse environment variables
@@ -1,4 +1,5 @@
1
1
  """`dh bedrock` command group — key delivery + per-user cost reporting."""
2
+
2
3
  from .commands import bedrock
3
4
 
4
5
  __all__ = ["bedrock"]
@@ -12,6 +12,7 @@ Two user-facing commands:
12
12
  Both commands default to reading the caller's identity via STS to
13
13
  resolve their own handle, so the common case is parameter-free.
14
14
  """
15
+
15
16
  from __future__ import annotations
16
17
 
17
18
  import datetime as dt
@@ -82,8 +83,7 @@ def _resolve_handle_from_sts() -> str:
82
83
  if principal.principal_type in ("claude-code", "cursor"):
83
84
  return principal.principal_name
84
85
  raise click.ClickException(
85
- f"Couldn't infer a developer handle from your identity ({arn}). "
86
- "Pass --handle explicitly."
86
+ f"Couldn't infer a developer handle from your identity ({arn}). Pass --handle explicitly."
87
87
  )
88
88
 
89
89
 
@@ -153,9 +153,7 @@ def bedrock_key(handle: Optional[str], region: str, mode: str):
153
153
  ak = payload.get("access_key_id", "")
154
154
  sk = payload.get("secret_access_key", "")
155
155
  if not ak or not sk:
156
- raise click.ClickException(
157
- f"Secret `{secret_id}` is missing access_key_id/secret_access_key fields."
158
- )
156
+ raise click.ClickException(f"Secret `{secret_id}` is missing access_key_id/secret_access_key fields.")
159
157
  click.echo(f"export AWS_ACCESS_KEY_ID='{ak}'")
160
158
  click.echo(f"export AWS_SECRET_ACCESS_KEY='{sk}'")
161
159
  click.echo(f"export AWS_DEFAULT_REGION='{payload.get('region', region)}'")
@@ -296,9 +294,7 @@ def bedrock_cost(
296
294
  if start is None:
297
295
  start = end - dt.timedelta(days=days - 1)
298
296
  if start > end:
299
- raise click.BadParameter(
300
- f"--start ({start}) must be on or before --end ({end})."
301
- )
297
+ raise click.BadParameter(f"--start ({start}) must be on or before --end ({end}).")
302
298
 
303
299
  pricing_file = pricing_path or cr.default_pricing_path()
304
300
  try:
@@ -308,8 +304,11 @@ def bedrock_cost(
308
304
  sys.exit(1)
309
305
 
310
306
  import boto3
307
+ from botocore.config import Config
311
308
 
312
- s3 = boto3.client("s3")
309
+ # Match the thread pool used by walk_logs so urllib3 doesn't block
310
+ # or warn when many parallel GETs are in flight.
311
+ s3 = boto3.client("s3", config=Config(max_pool_connections=32))
313
312
 
314
313
  my_handle: Optional[str] = None
315
314
  if me:
@@ -329,10 +328,9 @@ def bedrock_cost(
329
328
  # mode, including 'model' and 'principal_type' which collapse
330
329
  # principal_name to "" in the output rows.
331
330
  records = (
332
- rec for rec in records
333
- if cr.classify_arn(
334
- rec.get("identity", {}).get("arn", "")
335
- ).principal_name == my_handle
331
+ rec
332
+ for rec in records
333
+ if cr.classify_arn(rec.get("identity", {}).get("arn", "")).principal_name == my_handle
336
334
  )
337
335
  report = cr.build_report(records, pricing, group_by=group_by)
338
336
  except cr.UnknownModel as exc:
@@ -393,16 +391,10 @@ def bedrock_cost(
393
391
  # Keep reconcile output minimal in csv/markdown modes so the body
394
392
  # of the output stays pipe-friendly — stderr, not stdout.
395
393
  stream_err = output_format in ("csv", "markdown")
396
- delta_pct = (
397
- f"{result.delta_fraction * 100:.1f}%"
398
- if result.delta_fraction != float("inf")
399
- else "n/a"
400
- )
394
+ delta_pct = f"{result.delta_fraction * 100:.1f}%" if result.delta_fraction != float("inf") else "n/a"
401
395
  status = "OK" if result.ok else "DRIFT"
402
396
  reconcile_line = (
403
- f"\nReconcile: estimate ${estimate_total:,.2f} "
404
- f"Cost Explorer ${ce_total:,.2f} "
405
- f"delta {delta_pct} [{status}]"
397
+ f"\nReconcile: estimate ${estimate_total:,.2f} Cost Explorer ${ce_total:,.2f} delta {delta_pct} [{status}]"
406
398
  )
407
399
  click.echo(reconcile_line, err=stream_err)
408
400
  # Absolute-dollar floor on the drift exit code: below $1 of discrepancy,
@@ -21,11 +21,13 @@ Exported API:
21
21
  fetch_cost_explorer_total(start, end) -> float
22
22
  default_pricing_path() -> Path
23
23
  """
24
+
24
25
  from __future__ import annotations
25
26
 
26
27
  import datetime as dt
27
28
  import gzip
28
29
  import json
30
+ from concurrent.futures import ThreadPoolExecutor
29
31
  from dataclasses import dataclass, field
30
32
  from pathlib import Path
31
33
  from typing import Any, Iterable, Iterator
@@ -108,7 +110,7 @@ def resolve_base_model(model_id: str) -> str:
108
110
  stripped = model_id
109
111
  for prefix in ("us.", "global.", "eu.", "apac."):
110
112
  if stripped.startswith(prefix):
111
- stripped = stripped[len(prefix):]
113
+ stripped = stripped[len(prefix) :]
112
114
  break
113
115
  for base in _BASE_MODELS:
114
116
  if base in stripped:
@@ -165,9 +167,7 @@ def build_report(
165
167
  group_by: str = "user+model",
166
168
  ) -> Report:
167
169
  if group_by not in _VALID_GROUP_BY:
168
- raise ValueError(
169
- f"group_by={group_by!r} is not one of {sorted(_VALID_GROUP_BY)}"
170
- )
170
+ raise ValueError(f"group_by={group_by!r} is not one of {sorted(_VALID_GROUP_BY)}")
171
171
  agg: dict[tuple, dict[str, Any]] = {}
172
172
  for rec in records:
173
173
  # Bedrock emits records for failed validations / throttles with
@@ -264,17 +264,45 @@ def render_markdown(report: Report) -> str:
264
264
  # Columns that are constant/empty for the chosen grouping are dropped so
265
265
  # the output table stays narrow and scannable in a terminal.
266
266
  _PRETTY_COLUMNS_BY_GROUP = {
267
- "user": ("principal_type", "principal_name", "invocations",
268
- "input_tokens", "output_tokens", "cache_read",
269
- "cache_write", "estimated_cost_usd"),
270
- "user+model": ("principal_type", "principal_name", "model", "invocations",
271
- "input_tokens", "output_tokens", "cache_read",
272
- "cache_write", "estimated_cost_usd"),
273
- "model": ("model", "invocations", "input_tokens", "output_tokens",
274
- "cache_read", "cache_write", "estimated_cost_usd"),
275
- "principal_type": ("principal_type", "invocations", "input_tokens",
276
- "output_tokens", "cache_read", "cache_write",
277
- "estimated_cost_usd"),
267
+ "user": (
268
+ "principal_type",
269
+ "principal_name",
270
+ "invocations",
271
+ "input_tokens",
272
+ "output_tokens",
273
+ "cache_read",
274
+ "cache_write",
275
+ "estimated_cost_usd",
276
+ ),
277
+ "user+model": (
278
+ "principal_type",
279
+ "principal_name",
280
+ "model",
281
+ "invocations",
282
+ "input_tokens",
283
+ "output_tokens",
284
+ "cache_read",
285
+ "cache_write",
286
+ "estimated_cost_usd",
287
+ ),
288
+ "model": (
289
+ "model",
290
+ "invocations",
291
+ "input_tokens",
292
+ "output_tokens",
293
+ "cache_read",
294
+ "cache_write",
295
+ "estimated_cost_usd",
296
+ ),
297
+ "principal_type": (
298
+ "principal_type",
299
+ "invocations",
300
+ "input_tokens",
301
+ "output_tokens",
302
+ "cache_read",
303
+ "cache_write",
304
+ "estimated_cost_usd",
305
+ ),
278
306
  }
279
307
 
280
308
  # Nicer column headers for the pretty renderer.
@@ -290,10 +318,16 @@ _PRETTY_HEADERS = {
290
318
  "estimated_cost_usd": "cost",
291
319
  }
292
320
 
293
- _NUMERIC_COLUMNS = frozenset({
294
- "invocations", "input_tokens", "output_tokens",
295
- "cache_read", "cache_write", "estimated_cost_usd",
296
- })
321
+ _NUMERIC_COLUMNS = frozenset(
322
+ {
323
+ "invocations",
324
+ "input_tokens",
325
+ "output_tokens",
326
+ "cache_read",
327
+ "cache_write",
328
+ "estimated_cost_usd",
329
+ }
330
+ )
297
331
 
298
332
 
299
333
  def _format_cell(column: str, row: ReportRow) -> str:
@@ -323,9 +357,7 @@ def render_pretty(report: Report, *, group_by: str = "user+model") -> str:
323
357
  # Totals footer — numeric columns sum, non-numeric columns are
324
358
  # blank except the first one which gets "TOTAL".
325
359
  if report.rows:
326
- totals: dict[str, int] = {
327
- c: 0 for c in _NUMERIC_COLUMNS if c in columns and c != "estimated_cost_usd"
328
- }
360
+ totals: dict[str, int] = {c: 0 for c in _NUMERIC_COLUMNS if c in columns and c != "estimated_cost_usd"}
329
361
  cost_total = 0.0
330
362
  for row in report.rows:
331
363
  for c in totals:
@@ -378,17 +410,19 @@ def render_csv(report: Report) -> str:
378
410
  writer = csv.writer(buf)
379
411
  writer.writerow(_COLUMNS)
380
412
  for row in report.rows:
381
- writer.writerow([
382
- row.principal_type,
383
- row.principal_name,
384
- row.model,
385
- row.invocations,
386
- row.input_tokens,
387
- row.output_tokens,
388
- row.cache_read,
389
- row.cache_write,
390
- f"{row.estimated_cost_usd:.6f}",
391
- ])
413
+ writer.writerow(
414
+ [
415
+ row.principal_type,
416
+ row.principal_name,
417
+ row.model,
418
+ row.invocations,
419
+ row.input_tokens,
420
+ row.output_tokens,
421
+ row.cache_read,
422
+ row.cache_write,
423
+ f"{row.estimated_cost_usd:.6f}",
424
+ ]
425
+ )
392
426
  return buf.getvalue()
393
427
 
394
428
 
@@ -410,14 +444,46 @@ def walk_logs(
410
444
  region: str,
411
445
  start: dt.date,
412
446
  end: dt.date,
447
+ max_workers: int = 32,
413
448
  ) -> Iterator[dict]:
449
+ """Yield every invocation record in `[start, end]` (inclusive, UTC days).
450
+
451
+ Object GETs are parallelised with a thread pool because each day's
452
+ prefix holds hundreds of tiny (~400-byte) gzipped objects and
453
+ per-request latency dominates wall time. Records within a single
454
+ object are yielded in their original NDJSON order; records *across*
455
+ objects may be reordered — downstream aggregation (`build_report`)
456
+ is order-insensitive.
457
+
458
+ `max_workers` caps in-flight S3 GETs per day. The caller's
459
+ `s3_client` should be configured with `max_pool_connections` >=
460
+ `max_workers` (see `botocore.config.Config`) to avoid urllib3
461
+ connection-pool contention.
462
+ """
414
463
  paginator = s3_client.get_paginator("list_objects_v2")
415
464
  seen_keys: set[str] = set()
465
+
466
+ def _fetch_and_parse(key: str) -> list[dict]:
467
+ body = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
468
+ decompressed = gzip.decompress(body)
469
+ out: list[dict] = []
470
+ # Each object is one or more JSON records separated by
471
+ # newlines (NDJSON). Older Bedrock traffic produced
472
+ # one-record objects; multi-record objects appeared in
473
+ # our bucket on 2026-04-20. Parse line-by-line so both
474
+ # shapes work, and tolerate a trailing newline.
475
+ for line in decompressed.splitlines():
476
+ if not line.strip():
477
+ continue
478
+ out.append(json.loads(line))
479
+ return out
480
+
416
481
  for day in _iter_days(start, end):
417
482
  prefix = (
418
483
  f"invocation-logs/AWSLogs/{account}/BedrockModelInvocationLogs/"
419
484
  f"{region}/{day.year:04d}/{day.month:02d}/{day.day:02d}/"
420
485
  )
486
+ keys: list[str] = []
421
487
  for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
422
488
  for obj in page.get("Contents", []) or []:
423
489
  key = obj["Key"]
@@ -428,17 +494,18 @@ def walk_logs(
428
494
  if key in seen_keys:
429
495
  continue
430
496
  seen_keys.add(key)
431
- body = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
432
- decompressed = gzip.decompress(body)
433
- # Each object is one or more JSON records separated by
434
- # newlines (NDJSON). Older Bedrock traffic produced
435
- # one-record objects; multi-record objects appeared in
436
- # our bucket on 2026-04-20. Parse line-by-line so both
437
- # shapes work, and tolerate a trailing newline.
438
- for line in decompressed.splitlines():
439
- if not line.strip():
440
- continue
441
- yield json.loads(line)
497
+ keys.append(key)
498
+ if not keys:
499
+ continue
500
+ # One pool per day caps concurrent in-flight GETs at max_workers.
501
+ # ex.map submits every key up front, so up to one day's parsed
502
+ # objects can be buffered if the consumer is slower than the pool.
503
+ # Results come back in submission order, so the day's records
504
+ # stream out in a stable — though not chronological — order.
505
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
506
+ for records in ex.map(_fetch_and_parse, keys):
507
+ for rec in records:
508
+ yield rec
442
509
 
443
510
 
444
511
  def reconcile_with_cost_explorer(
@@ -451,8 +518,9 @@ def reconcile_with_cost_explorer(
451
518
  if estimate_total == 0:
452
519
  return ReconcileResult(False, float("inf"), estimate_total, ce_total, threshold)
453
520
  if ce_total == 0:
454
- return ReconcileResult(False, abs(ce_total - estimate_total) / estimate_total,
455
- estimate_total, ce_total, threshold)
521
+ return ReconcileResult(
522
+ False, abs(ce_total - estimate_total) / estimate_total, estimate_total, ce_total, threshold
523
+ )
456
524
  delta = abs(ce_total - estimate_total) / estimate_total
457
525
  return ReconcileResult(delta <= threshold, delta, estimate_total, ce_total, threshold)
458
526
 
@@ -479,8 +547,7 @@ def fetch_cost_explorer_total(start: dt.date, end: dt.date) -> float:
479
547
  Dimension="SERVICE",
480
548
  )
481
549
  bedrock_services = [
482
- v["Value"] for v in dim.get("DimensionValues", [])
483
- if v["Value"].endswith("(Amazon Bedrock Edition)")
550
+ v["Value"] for v in dim.get("DimensionValues", []) if v["Value"].endswith("(Amazon Bedrock Edition)")
484
551
  ]
485
552
  if not bedrock_services:
486
553
  # No Bedrock-family spend at all in the window — CE honestly
@@ -606,11 +606,7 @@ def list_engines(env: Optional[str]):
606
606
  return left + mid.join("─" * (w + 1) for w in cols) + right
607
607
 
608
608
  click.echo(border("╭", "┬", "╮"))
609
- click.echo(
610
- "│"
611
- + "│".join(f" {h:{a}{w}}" for h, a, w in zip(headers, aligns, cols))
612
- + "│"
613
- )
609
+ click.echo("│" + "│".join(f" {h:{a}{w}}" for h, a, w in zip(headers, aligns, cols)) + "│")
614
610
  click.echo(border("├", "┼", "┤"))
615
611
 
616
612
  for i, engine in enumerate(engines):
@@ -649,9 +645,7 @@ def list_engines(env: Optional[str]):
649
645
  else:
650
646
  disk_d = f"\033[32m{disk_text:>{dw}}\033[0m"
651
647
 
652
- click.echo(
653
- f"│ {name_d}│ {state_d}│ {user:<{uw}}│ {etype:<{tw}}│ {uptime_d}│ {disk_d}│"
654
- )
648
+ click.echo(f"│ {name_d}│ {state_d}│ {user:<{uw}}│ {etype:<{tw}}│ {uptime_d}│ {disk_d}│")
655
649
 
656
650
  click.echo(border("╰", "┴", "╯"))
657
651
  click.echo(f"Total: {len(engines)}\n")