ap-client 0.1.4.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ap_client/cli.py ADDED
@@ -0,0 +1,1016 @@
1
+ """Agent Platform CLI - resource/operation style commands."""
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import typer
9
+ from rich import print
10
+ from rich.console import Console
11
+
12
+ from ap_client import __version__, get_client, get_config
13
+ from ap_client.api import set_verbose_override
14
+ from ap_client.exporter import export_group, export_job
15
+ from ap_client.waiter import wait_for_group, wait_for_job
16
+
17
# Root Typer application behind the ``ap`` command.
app = typer.Typer(
    name="ap",
    help="Agent Platform CLI - minimal job submission tool",
    no_args_is_help=True,  # a bare ``ap`` prints help instead of erroring
    add_completion=False,  # no --install-completion / --show-completion options
)
# Shared Rich console; ``_print_json`` renders highlighted JSON through it.
console = Console()
24
+
25
+
26
+ def _version_callback(value: bool) -> None:
27
+ if value:
28
+ typer.echo(__version__)
29
+ raise typer.Exit()
30
+
31
+
32
@app.callback()
def _global_options(
    verbose: bool = typer.Option(
        False,
        "--verbose",
        help="Print HTTP Request/Response details to stderr (overrides AP_VERBOSE)",
    ),
    version: bool = typer.Option(
        False,
        "--version",
        callback=_version_callback,
        is_eager=True,
        help="Show version and exit",
    ),
) -> None:
    """Global CLI options."""
    # ``--version`` is handled entirely by its eager callback, so the parsed
    # value is deliberately discarded here.
    del version
    # The override is only ever switched on; without --verbose the existing
    # verbosity setting is left untouched.
    if not verbose:
        return
    set_verbose_override(True)
51
+
52
# Resource sub-command groups: ``ap template ...``, ``ap dataset ...``,
# ``ap job ...`` and ``ap group ...``.
template_app = typer.Typer(help="Template operations")
dataset_app = typer.Typer(help="Dataset operations")
job_app = typer.Typer(help="Job operations")
group_app = typer.Typer(help="Group operations")

for _sub_name, _sub_app in (
    ("template", template_app),
    ("dataset", dataset_app),
    ("job", job_app),
    ("group", group_app),
):
    app.add_typer(_sub_app, name=_sub_name)
# Drop the loop variables so they do not linger as module attributes.
del _sub_name, _sub_app
62
+
63
+ _PAI_RUNTIME_ENV_TAGS: tuple[tuple[str, str], ...] = (
64
+ ("DLC_JOB_ID", "dlc_job_id"),
65
+ ("DSW_INSTANCE_ID", "dsw_instance_id"),
66
+ ("PAI_WORKSPACE_ID", "pai_workspace_id"),
67
+ ("PAI_WORKSPACE_NAME", "pai_workspace_name"),
68
+ ("PAI_USER_ID", "pai_user_id"),
69
+ ("PAI_CLUSTER_ID", "pai_cluster_id"),
70
+ )
71
+
72
+
73
def _print_json(data):
    """Pretty-print ``data`` as syntax-highlighted JSON on the shared console.

    Serialization keeps non-ASCII characters as-is and uses a 2-space indent.
    """
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    console.print_json(serialized)
76
+
77
+
78
def _emit_info(message: str, output_format: str = "plain") -> None:
    """Print a progress/info line.

    In ``plain`` mode the line goes to stdout with the rest of the human
    output; for any other format it goes to stderr so stdout carries only the
    formatted document.
    """
    to_stderr = output_format != "plain"
    typer.echo(message, err=to_stderr)
80
+
81
+
82
+ def _normalize_output_format(output_format: str) -> str:
83
+ value = output_format.lower()
84
+ if value == "text":
85
+ return "plain"
86
+ if value not in {"plain", "json", "yaml"}:
87
+ raise typer.BadParameter("output format must be one of plain/json/yaml")
88
+ return value
89
+
90
+
91
def _print_formatted(data: dict, output_format: str) -> None:
    """Print ``data`` in the requested output format.

    - ``json``: 2-space indented JSON written with ``typer.echo`` (no Rich styling).
    - ``yaml``: block-style YAML from PyYAML's ``safe_dump``. PyYAML is
      imported lazily; if it is missing, ``typer.BadParameter`` is raised.
    - anything else: Rich-highlighted JSON via ``_print_json``.
    """
    if output_format == "json":
        typer.echo(json.dumps(data, ensure_ascii=False, indent=2))
        return
    if output_format == "yaml":
        try:
            import yaml
        except ImportError as exc:
            raise typer.BadParameter("--format yaml requires PyYAML to be installed") from exc
        # safe_dump already terminates the document with a newline; letting
        # echo add another would leave a stray blank line, unlike the JSON path.
        typer.echo(yaml.safe_dump(data, allow_unicode=True, sort_keys=False), nl=False)
        return
    _print_json(data)
103
+
104
+
105
def _print_plain_table(rows: list[dict], columns: list[tuple[str, str]]) -> None:
    """Render ``rows`` as a left-aligned, space-separated text table.

    ``columns`` is a list of ``(row_key, header_title)`` pairs. Each column is
    padded to the widest of its title and its stringified cell values; missing
    keys and ``None`` render as empty cells. An empty ``rows`` prints a single
    blank line and no header.
    """
    if not rows:
        typer.echo("")
        return

    keys = [key for key, _ in columns]
    titles = [title for _, title in columns]
    # Stringify every cell once; reused for both width calculation and output.
    cells = [[_stringify_cell(row.get(key)) for key in keys] for row in rows]
    widths = [
        max([len(title)] + [len(line[idx]) for line in cells])
        for idx, title in enumerate(titles)
    ]

    def _join(values):
        return " ".join(value.ljust(width) for value, width in zip(values, widths))

    typer.echo(_join(titles))
    for line in cells:
        typer.echo(_join(line))
126
+
127
+
128
+ def _stringify_cell(value: object) -> str:
129
+ if value is None:
130
+ return ""
131
+ return str(value)
132
+
133
+
134
def _merge_job_tags(tags: Optional[list[str]]) -> Optional[list[str]]:
    """Combine user-supplied tags with the runtime tags from ``_runtime_job_tags``.

    Every tag is converted to ``str`` and stripped. Blank tags are dropped and
    duplicates removed, keeping first-seen order (user tags first). Returns
    ``None`` instead of an empty list when nothing remains.
    """
    candidates = (tags or []) + _runtime_job_tags()
    cleaned = (str(item).strip() for item in candidates)
    # dict.fromkeys de-duplicates while preserving insertion order.
    unique = list(dict.fromkeys(tag for tag in cleaned if tag))
    return unique or None
146
+
147
+
148
+ def _runtime_job_tags() -> list[str]:
149
+ if not (os.getenv("DLC_JOB_ID", "").strip() or os.getenv("DSW_INSTANCE_ID", "").strip()):
150
+ return []
151
+
152
+ tags: list[str] = []
153
+ for env_name, tag_name in _PAI_RUNTIME_ENV_TAGS:
154
+ value = os.getenv(env_name, "").strip()
155
+ if value:
156
+ tags.append(f"{tag_name}:{value}")
157
+ return tags
158
+
159
+
160
def _print_job_list_plain(result: dict) -> None:
    """Print ``result["jobs"]`` as a plain table; each column title equals its field name."""
    fields = (
        "job_id",
        "template",
        "instance_id",
        "group_id",
        "status",
        "failed_reason",
        "created_at",
    )
    _print_plain_table(result.get("jobs", []), [(field, field) for field in fields])
173
+
174
+
175
def _print_group_list_plain(result: dict) -> None:
    """Print ``result["groups"]`` as a plain table; each column title equals its field name."""
    fields = ("group_id", "name", "max_concurrency", "created_at")
    _print_plain_table(result.get("groups", []), [(field, field) for field in fields])
185
+
186
+
187
+ def _pick_available_columns(
188
+ rows: list[dict], candidates: list[tuple[str, str]]
189
+ ) -> list[tuple[str, str]]:
190
+ return [col for col in candidates if any(col[0] in row for row in rows)]
191
+
192
+
193
+ def _print_group_eval_plain(result: dict) -> None:
194
+ header_parts: list[str] = []
195
+ for key in ("group_id", "name"):
196
+ if key in result:
197
+ header_parts.append(f"{key}={_stringify_cell(result.get(key))}")
198
+ if header_parts:
199
+ typer.echo(" ".join(header_parts))
200
+
201
+ summary = result.get("summary", {})
202
+ has_summary = isinstance(summary, dict) and bool(summary)
203
+ if has_summary:
204
+ pass_keys = sorted(
205
+ k for k in summary.keys() if k.startswith("pass@") or k.startswith("pass^")
206
+ )
207
+ summary_columns = _pick_available_columns(
208
+ [summary],
209
+ [
210
+ ("total_tasks", "total_tasks"),
211
+ ("total_trials", "total_trials"),
212
+ ("scored_tasks", "scored_tasks"),
213
+ ("passed_tasks", "passed_tasks"),
214
+ ("avg_score", "avg_score"),
215
+ ("pass_rate", "pass_rate"),
216
+ ("avg_completion", "avg_completion"),
217
+ ("avg_robustness", "avg_robustness"),
218
+ ("avg_safety", "avg_safety"),
219
+ *[(key, key) for key in pass_keys],
220
+ ],
221
+ )
222
+ if summary_columns:
223
+ _print_plain_table([summary], summary_columns)
224
+ else:
225
+ _print_json(summary)
226
+
227
+ raw_rows = result.get("per_task", [])
228
+ rows = [row for row in raw_rows if isinstance(row, dict)] if isinstance(raw_rows, list) else []
229
+ if has_summary and rows:
230
+ typer.echo("")
231
+ if not rows:
232
+ return
233
+
234
+ columns = _pick_available_columns(
235
+ rows,
236
+ [
237
+ ("task_id", "task_id"),
238
+ ("trials", "trials"),
239
+ ("scored", "scored"),
240
+ ("avg_score", "avg_score"),
241
+ ("passed", "passed"),
242
+ ("pass_rate", "pass_rate"),
243
+ ("completion", "completion"),
244
+ ("robustness", "robustness"),
245
+ ("communication", "communication"),
246
+ ("safety", "safety"),
247
+ ("category", "category"),
248
+ ("difficulty", "difficulty"),
249
+ ("total_tokens", "total_tokens"),
250
+ ("input_tokens", "input_tokens"),
251
+ ("output_tokens", "output_tokens"),
252
+ ("total_turns", "total_turns"),
253
+ ("wall_time_s", "wall_time_s"),
254
+ ("model_time_s", "model_time_s"),
255
+ ],
256
+ )
257
+ if not columns:
258
+ _print_json({"results": rows})
259
+ return
260
+ _print_plain_table(rows, columns)
261
+
262
+
263
+ def _print_job_events_plain(result: dict) -> None:
264
+ if "containers" in result:
265
+ for cname, events in result["containers"].items():
266
+ print(f"\n[{cname}]")
267
+ for event in events:
268
+ print(event)
269
+ return
270
+ for event in result.get("events", []):
271
+ print(event)
272
+
273
+
274
+ def _mask_config_headers(headers: dict) -> dict:
275
+ masked = {}
276
+ for key, value in headers.items():
277
+ lower = key.lower()
278
+ if any(
279
+ token in lower for token in ("authorization", "token", "api-key", "cookie", "secret")
280
+ ):
281
+ masked[key] = "***"
282
+ else:
283
+ masked[key] = value
284
+ return masked
285
+
286
+
287
def _config_payload() -> dict:
    """Build the displayable CLI configuration.

    Returns ``base_url``, ``agenthub_ref`` and ``headers`` (with credential-like
    header values masked by ``_mask_config_headers``), in that key order.
    """
    cfg = get_config()
    payload = {}
    payload["base_url"] = cfg.base_url
    payload["agenthub_ref"] = cfg.agenthub_ref
    payload["headers"] = _mask_config_headers(cfg.headers)
    return payload
294
+
295
+
296
@app.command("config")
def show_config(
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """Show the current CLI configuration."""
    output_format = _normalize_output_format(output_format)
    payload = _config_payload()
    if output_format != "plain":
        _print_formatted(payload, output_format)
        return
    # Plain mode: a two-column key/value table; headers are rendered as one
    # compact JSON cell.
    rows = [{"key": name, "value": payload[name]} for name in ("base_url", "agenthub_ref")]
    rows.append({"key": "headers", "value": json.dumps(payload["headers"], ensure_ascii=False)})
    _print_plain_table(rows, [("key", "key"), ("value", "value")])
312
+
313
+
314
+ # ==================== Template operations ====================
315
+
316
+
317
@template_app.command("list")
def template_list():
    """List all templates."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().list_templates())
323
+
324
+
325
@template_app.command("get")
def template_get(name: str = typer.Argument(..., help="Template name")):
    """Show template details."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().get_template(name))
331
+
332
+
333
@template_app.command("render")
def template_render(
    name: str = typer.Argument(..., help="Template name"),
    params: str = typer.Argument(..., help="Parameters (JSON)"),
):
    """Preview template rendering result (without submitting)."""
    # Validate PARAMS before contacting the server: malformed or non-object JSON
    # becomes a usage error instead of a Python traceback.
    try:
        params_dict = json.loads(params)
    except json.JSONDecodeError as exc:
        raise typer.BadParameter(f"invalid JSON: {exc}", param_hint="PARAMS") from exc
    if not isinstance(params_dict, dict):
        raise typer.BadParameter("must be a JSON object", param_hint="PARAMS")

    client = get_client()
    result = client.render_template(name, params_dict)
    # Rendered values (image, env values, sidecar names) come from the server.
    # Echo them verbatim rather than through Rich's print, so "[...]" in a value
    # is not parsed as markup (dropped text / MarkupError).
    typer.echo(f"Image: {result.get('image')}")
    typer.echo(f"Resources: {result.get('resources')}")
    typer.echo(f"Script: {result.get('script_length')} chars")
    sidecars = result.get("sidecars")
    if sidecars:
        typer.echo(f"Sidecars: {', '.join(str(s.get('name')) for s in sidecars)}")
    env = result.get("env")
    if env:
        typer.echo(f"\nEnv ({len(env)} vars):")
        for k, v in sorted(env.items()):
            val = str(v)
            # Keep each env line short: truncate long values to 57 chars + "...".
            if len(val) > 60:
                val = val[:57] + "..."
            typer.echo(f" {k}={val}")
354
+
355
+
356
+ # ==================== Dataset operations ====================
357
+
358
+
359
@dataset_app.command("list")
def dataset_list(
    search: Optional[str] = typer.Argument(None, help="Search keyword"),
):
    """List all datasets."""
    # The optional keyword is forwarded as-is (None when omitted).
    _print_json(get_client().list_datasets(search))
367
+
368
+
369
@dataset_app.command("versions")
def dataset_versions(
    dataset: str = typer.Argument(..., help="Dataset name"),
):
    """List dataset versions."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().list_dataset_versions(dataset))
377
+
378
+
379
@dataset_app.command("instances")
def dataset_instances(
    dataset_version: str = typer.Argument(
        ..., help="Dataset version path (e.g. qclawbench/skill/V0329-tasks)"
    ),
):
    """List dataset instances."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().list_dataset_instances(dataset_version))
389
+
390
+
391
+ # ==================== Job operations ====================
392
+
393
+
394
@job_app.command("list")
def job_list(
    template: Optional[str] = typer.Option(None, "--template", help="Template name"),
    group_id: Optional[str] = typer.Option(None, "--group-id", help="Group ID"),
    status: Optional[str] = typer.Option(None, "--status", help="Job status or finish reason"),
    job_id: Optional[str] = typer.Option(None, "--job-id", help="Job ID prefix"),
    instance_id: Optional[str] = typer.Option(None, "--instance-id", help="Instance ID prefix"),
    finished_after: Optional[str] = typer.Option(
        None, "--finished-after", help="Filter by lower bound of finished_at (ISO time)"
    ),
    finished_before: Optional[str] = typer.Option(
        None, "--finished-before", help="Filter by upper bound of finished_at (ISO time)"
    ),
    created_after: Optional[str] = typer.Option(
        None, "--created-after", help="Filter by lower bound of created_at (ISO time)"
    ),
    created_before: Optional[str] = typer.Option(
        None, "--created-before", help="Filter by upper bound of created_at (ISO time)"
    ),
    skip: int = typer.Option(0, "--skip", help="Skip the first N entries"),
    limit: int = typer.Option(100, "--limit", help="Maximum number of entries to return"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """List jobs."""
    output_format = _normalize_output_format(output_format)
    # Every filter is forwarded verbatim; unset filters are passed as None.
    query = dict(
        template=template,
        group_id=group_id,
        status=status,
        job_id=job_id,
        instance_id=instance_id,
        finished_after=finished_after,
        finished_before=finished_before,
        created_after=created_after,
        created_before=created_before,
        skip=skip,
        limit=limit,
    )
    result = get_client().list_jobs(**query)
    if output_format != "plain":
        _print_formatted(result, output_format)
        return
    _print_job_list_plain(result)
437
+
438
+
439
def _job_create_abort(message: str) -> None:
    """Print ``error: <message>`` in red and exit with status 1. Never returns."""
    print(f"[red]error: {message}[/]")
    raise typer.Exit(1)


def _job_create_parse_json(raw: Optional[str], option_name: str):
    """Parse a JSON-valued ``job create`` option.

    Returns ``None`` when the option was omitted or empty. Malformed JSON aborts
    with a readable error instead of a traceback.
    """
    if not raw:
        return None
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        _job_create_abort(f"failed to parse {option_name} JSON: {e}")


def _job_create_load_params_list(params_list_input: str) -> list:
    """Load ``--params-list`` from an inline JSON array or a JSON file.

    After stripping, input that starts with ``[`` is parsed as inline JSON;
    anything else is read as a UTF-8 file path. The result must be a non-empty
    JSON array whose items are all objects, because each item is later merged
    with the ``-p`` params.
    """
    raw = params_list_input.strip()
    if raw.startswith("["):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as e:
            _job_create_abort(f"failed to parse --params-list JSON: {e}")
    else:
        try:
            with open(raw, "r", encoding="utf-8") as f:
                parsed = json.load(f)
        except json.JSONDecodeError as e:
            _job_create_abort(f"failed to parse --params-list file as JSON: {e}")
        except FileNotFoundError:
            _job_create_abort(f"file not found: {raw}")
        except (OSError, UnicodeDecodeError) as e:
            # e.g. a directory, missing permissions, or a non-UTF-8 file.
            _job_create_abort(f"failed to read --params-list file {raw}: {e}")
    if not isinstance(parsed, list):
        _job_create_abort("--params-list content must be a JSON array")
    if len(parsed) == 0:
        _job_create_abort("--params-list content is an empty array")
    for index, item in enumerate(parsed):
        if not isinstance(item, dict):
            _job_create_abort(f"--params-list item #{index} must be a JSON object")
    return parsed


def _job_create_split_instance_ids(instance_id: Optional[str]) -> Optional[list]:
    """Split the comma-separated ``-i`` value, dropping blank entries.

    Returns ``None`` when ``-i`` was not given. A value made only of blanks and
    commas aborts, rather than submitting jobs with an empty instance id.
    """
    if not instance_id:
        return None
    ids = [part.strip() for part in instance_id.split(",") if part.strip()]
    if not ids:
        _job_create_abort("-i/--instance-id contains no instance IDs")
    return ids


@job_app.command("create")
def job_create(
    template: str = typer.Argument(..., help="Template name"),
    instance_id: Optional[str] = typer.Option(
        None, "--instance-id", "-i", help="Instance ID (comma-separated for multiple)"
    ),
    dataset: Optional[str] = typer.Option(None, "--dataset", "-d", help="Dataset version path"),
    params: Optional[str] = typer.Option(None, "--params", "-p", help="Parameters (JSON)"),
    params_list_input: Optional[str] = typer.Option(
        None, "--params-list", "-l", help="Params list: JSON array string or path to a JSON file"
    ),
    suite_name: Optional[str] = typer.Option(None, "--suite-name", help="Suite (job group) name"),
    group_id: Optional[str] = typer.Option(
        None, "--group-id", help="Reuse or create the specified Group ID"
    ),
    tags: Optional[str] = typer.Option(None, "--tags", help="Tags (comma-separated)"),
    overrides: Optional[str] = typer.Option(None, "--overrides", help="Override config (JSON)"),
    concurrency: Optional[int] = typer.Option(
        None,
        "--concurrency",
        "-c",
        help="Concurrency (max simultaneously running jobs in batch submissions)",
    ),
    batch_size: int = typer.Option(
        200, "--batch-size", min=1, help="Number of params items per batch submission request"
    ),
    wait: bool = typer.Option(
        False, "--wait", help="Keep waiting after submission until the jobs finish"
    ),
    wait_interval: float = typer.Option(
        5.0, "--wait-interval", min=0.1, help="Polling interval in seconds when used with --wait"
    ),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
    enable_otel_tracing: Optional[bool] = typer.Option(
        None,
        "--enable-otel-tracing",
        help="Enable OTEL tracing (True/False; falls back to global config if omitted)",
    ),
):
    """Submit a job.

    Examples:
    # Single job - full parameters
    ap job create claw-eval-task -p '{"task_id":"T02","model":"qwen-max"}'

    # Single job - instance_id + shared params
    ap job create claw-eval-task -i T02_email_triage -p '{"model":"qwen-max"}'

    # Batch jobs - multiple instance_ids + shared params
    ap job create claw-eval-task -i T02,T03,T04 -p '{"model":"qwen-max"}'

    # Batch jobs - dataset + shared params
    ap job create claw-eval-task --dataset claw-eval/claw-eval/v1@0 -p '{"model":"qwen-max"}'

    # Batch jobs from a params list (inline JSON)
    ap job create claw-eval-task --params-list '[{"task_id":"T02"},{"task_id":"T03"}]'

    # Batch jobs from a params list with shared base params (-p merged into each item)
    ap job create claw-eval-task --params-list '[{"task_id":"T02"},{"task_id":"T03"}]' -p '{"model":"qwen-max"}'

    # Batch jobs from a params list file (also supports -p for shared base params)
    ap job create claw-eval-task --params-list params.json -p '{"model":"qwen-max"}'
    """
    output_format = _normalize_output_format(output_format)

    def wait_printer(message: str) -> None:
        # Progress lines go to stderr for json/yaml so stdout stays a clean document.
        typer.echo(message, err=output_format != "plain")

    wait_result = None

    # ---- Validate all user input before talking to the server. ----
    base_params = _job_create_parse_json(params, "-p/--params")
    if base_params is None:
        base_params = {}
    elif not isinstance(base_params, dict):
        # Every submission path merges -p into a dict, so it must be an object.
        _job_create_abort("-p/--params must be a JSON object")
    overrides_dict = _job_create_parse_json(overrides, "--overrides")
    tags_list = _merge_job_tags(tags.split(",") if tags else None)

    instance_ids = _job_create_split_instance_ids(instance_id)
    params_dict = base_params
    if instance_ids:
        # auto-map instance_id -> task_id (used by the single-job path)
        params_dict = {**base_params, "instance_id": instance_ids[0], "task_id": instance_ids[0]}

    parsed_params_list = (
        _job_create_load_params_list(params_list_input) if params_list_input else None
    )

    is_batch = (
        bool(parsed_params_list)
        or (instance_ids is not None and len(instance_ids) > 1)
        or bool(dataset)
    )
    if not is_batch and not instance_ids and not params:
        _job_create_abort("must specify -p, or use -i / --dataset / --params-list")

    client = get_client()

    if is_batch:
        if parsed_params_list:
            # build params_list from --params-list; -p takes precedence over each item
            params_list = [{**item, **base_params} for item in parsed_params_list]
        elif instance_ids:
            params_list = [
                {**base_params, "instance_id": iid, "task_id": iid} for iid in instance_ids
            ]
        else:
            _emit_info(f"Fetching dataset instances: {dataset} ...", output_format)
            instances = client.list_dataset_instances(dataset)
            total = instances.get("total", 0)
            dataset_ids = instances.get("instance_ids") or []
            # An empty id list also means "nothing to submit"; without this
            # check the command would submit zero jobs and still report success.
            if total == 0 or not dataset_ids:
                print("[red]error: no instances found[/]")
                raise typer.Exit(1)
            _emit_info(f"Found {total} instance(s)", output_format)
            params_list = [
                {**base_params, "instance_id": iid, "task_id": iid} for iid in dataset_ids
            ]

        total_count = len(params_list)
        _emit_info(
            f"Submitting jobs ({total_count} total, batch size {batch_size})...",
            output_format,
        )

        # Decide the group up front so all jobs share the same group_id. The server
        # only auto-creates a group when a single request's params_list has length
        # > 1; when the first batch has only 1 element the CLI must create the
        # group explicitly to avoid falling back to a standalone job.
        target_group_id = group_id
        if target_group_id:
            _emit_info(f" group_id: {target_group_id}", output_format)
        elif min(batch_size, total_count) == 1:
            group_name = suite_name or f"ap-{template}-n{total_count}"
            group_doc = client.create_group(name=group_name, max_concurrency=concurrency)
            target_group_id = group_doc.get("group_id")
            _emit_info(f" group_id: {target_group_id}", output_format)

        total_submitted = 0
        total_failed = 0
        batch_results: list[dict] = []

        for batch_index, start in enumerate(range(0, total_count, batch_size), start=1):
            batch = params_list[start : start + batch_size]
            batch_end = start + len(batch)
            _emit_info(
                f" Submitting batch {batch_index}: {start + 1}-{batch_end}/{total_count}",
                output_format,
            )
            batch_result = client.create_job(
                template=template,
                params_list=batch,
                suite_name=suite_name,
                tags=tags_list,
                overrides=overrides_dict,
                max_concurrency=concurrency,
                group_id=target_group_id,
                enable_otel_tracing=enable_otel_tracing,
                timeout=120,
            )
            # Adopt the server-created group so later batches join the same group.
            if not target_group_id:
                target_group_id = batch_result.get("group_id")
                if target_group_id:
                    _emit_info(f" group_id: {target_group_id}", output_format)

            batch_submitted = int(batch_result.get("submitted") or 0)
            batch_failed = int(batch_result.get("failed") or 0)
            total_submitted += batch_submitted
            total_failed += batch_failed
            batch_results.append(
                {
                    "batch": batch_index,
                    "start": start,
                    "end": batch_end,
                    "total": len(batch),
                    "submitted": batch_submitted,
                    "failed": batch_failed,
                }
            )
            _emit_info(
                f" Progress: {batch_end}/{total_count} "
                f"(submitted {total_submitted}, failed {total_failed})",
                output_format,
            )

        result = {
            "group_id": target_group_id,
            "total": total_count,
            "submitted": total_submitted,
            "failed": total_failed,
            "batches": batch_results,
        }
        if output_format == "plain":
            print("[green]Batch jobs submitted:[/]")
            print(f" group_id: {result.get('group_id')}")
            print(f" total: {result.get('total')}")
            print(f" submitted: {result.get('submitted')}")
            print(f" failed: {result.get('failed')}")
    else:
        result = client.create_job(
            template=template,
            params=params_dict,
            suite_name=suite_name,
            tags=tags_list,
            overrides=overrides_dict,
            group_id=group_id,
            enable_otel_tracing=enable_otel_tracing,
        )
        if output_format == "plain":
            # `or [{}]` also covers an empty "jobs" list (no IndexError).
            job = (result.get("jobs") or [{}])[0]
            print("[green]Job submitted:[/]")
            print(f" job_id: {job.get('job_id')}")
            print(f" status: {job.get('status')}")
            if result.get("group_id"):
                print(f" group_id: {result.get('group_id')}")

    if wait:
        wait_group_id = result.get("group_id")
        if wait_group_id:
            _emit_info(
                f"Waiting for group to finish: {wait_group_id} (interval={wait_interval}s)",
                output_format,
            )
            final_group = wait_for_group(client, wait_group_id, wait_interval, printer=wait_printer)
            wait_result = {"type": "group", "interval": wait_interval, "result": final_group}
            if output_format == "plain":
                stats = final_group.get("stats") or {}
                print(f"[green]Group finished:[/] {wait_group_id}")
                print(f" total: {stats.get('total', 0)}")
                print(f" succeeded: {stats.get('succeeded', 0)}")
                print(f" failed: {stats.get('failed', 0)}")
                print(f" cancelled: {stats.get('cancelled', 0)}")
        else:
            job = (result.get("jobs") or [{}])[0]
            job_id = job.get("job_id")
            if not job_id:
                # e.g. a batch response without a group_id: there is nothing
                # concrete to poll, so do not call wait_for_job(None).
                typer.echo(
                    "warning: --wait skipped: submission returned neither a group_id nor a job_id",
                    err=True,
                )
            else:
                _emit_info(
                    f"Waiting for job to finish: {job_id} (interval={wait_interval}s)",
                    output_format,
                )
                final_job = wait_for_job(client, job_id, wait_interval, printer=wait_printer)
                wait_result = {"type": "job", "interval": wait_interval, "result": final_job}
                if output_format == "plain":
                    print(f"[green]Job finished:[/] {job_id}")
                    print(f" status: {final_job.get('status')}")
                    if final_job.get("failed_reason"):
                        print(f" failed_reason: {final_job.get('failed_reason')}")

    if output_format != "plain":
        payload = {"submission": result}
        if wait_result is not None:
            payload["wait"] = wait_result
        _print_formatted(payload, output_format)
713
+
714
+
715
@job_app.command("get")
def job_get(job_id: str = typer.Argument(..., help="Job ID")):
    """Show job status."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().get_job(job_id))
721
+
722
+
723
@job_app.command("logs")
def job_logs(
    job_id: str = typer.Argument(..., help="Job ID"),
    container: Optional[str] = typer.Option(None, "--container", "-c", help="Container name"),
    tail: Optional[int] = typer.Option(None, "--tail", "-n", help="Show only the last N lines"),
):
    """Show job logs."""
    client = get_client()
    result = client.get_job_logs(job_id, container=container, tail=tail)
    # Log text is echoed verbatim. Routing it through Rich's print would treat
    # "[lowercase...]" sequences as markup (silently dropping text, including
    # the "[container]" header), raise MarkupError on a stray "[/...]", and
    # hard-wrap long lines to the console width (80 columns when piped).
    # A None log body prints as an empty line instead of the text "None".
    if "containers" in result:
        for cname, logs in (result["containers"] or {}).items():
            typer.echo(f"\n[{cname}]")
            typer.echo("" if logs is None else logs)
    else:
        typer.echo(result.get("logs") or "")
738
+
739
+
740
@job_app.command("events")
def job_events(
    job_id: str = typer.Argument(..., help="Job ID"),
    container: Optional[str] = typer.Option(None, "--container", "-c", help="Container name"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """Show job events."""
    output_format = _normalize_output_format(output_format)
    events = get_client().get_job_events(job_id, container=container)
    if output_format != "plain":
        _print_formatted(events, output_format)
        return
    _print_job_events_plain(events)
754
+
755
+
756
@job_app.command("metrics")
def job_metrics(
    job_id: str = typer.Argument(..., help="Job ID"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """Show job metrics."""
    output_format = _normalize_output_format(output_format)
    metrics = get_client().get_job_metrics(job_id)
    if output_format != "plain":
        _print_formatted(metrics, output_format)
        return
    # Plain mode has no tabular layout for metrics; it shows highlighted JSON.
    _print_json(metrics)
769
+
770
+
771
@job_app.command("cancel")
def job_cancel(job_id: str = typer.Argument(..., help="Job ID")):
    """Cancel a job."""
    cancelled = get_client().cancel_job(job_id)
    lines = [
        f"[green]Cancelled:[/] {cancelled.get('job_id')}",
        f" status: {cancelled.get('status')}",
    ]
    # failed_reason is only shown when the server reports a non-empty one.
    reason = cancelled.get("failed_reason")
    if reason:
        lines.append(f" failed_reason: {reason}")
    for line in lines:
        print(line)
780
+
781
+
782
@job_app.command("artifacts")
def job_artifacts(
    job_ids: str = typer.Argument(..., help="Job IDs (comma-separated for multiple)"),
):
    """Show job artifact download links."""
    # Drop blank entries (e.g. "a,,b" or a trailing comma) instead of asking the
    # server for an empty job ID.
    ids = [jid.strip() for jid in job_ids.split(",") if jid.strip()]
    if not ids:
        raise typer.BadParameter("no job IDs given", param_hint="JOB_IDS")
    client = get_client()
    results = client.get_job_artifacts(ids)
    _print_json(results)
791
+
792
+
793
@job_app.command("export")
def job_export(
    job_id: str = typer.Argument(..., help="Job ID"),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Export directory"),
):
    """Export job artifacts/logs/events to a local directory."""
    # export_job returns the destination, which is echoed back to the user.
    destination = export_job(get_client(), job_id, output)
    for line in (f"[green]Job exported:[/] {job_id}", f" path: {destination}"):
        print(line)
803
+
804
+
805
@job_app.command("wait")
def job_wait(
    job_id: str = typer.Argument(..., help="Job ID"),
    interval: float = typer.Option(
        5.0, "--interval", "-i", min=0.1, help="Polling interval (seconds)"
    ),
):
    """Poll until the job finishes."""
    final = wait_for_job(get_client(), job_id, interval)
    lines = [f"[green]Job finished:[/] {job_id}", f" status: {final.get('status')}"]
    # failed_reason is only shown when non-empty.
    reason = final.get("failed_reason")
    if reason:
        lines.append(f" failed_reason: {reason}")
    for line in lines:
        print(line)
819
+
820
+
821
+ # ==================== Group operations ====================
822
+
823
+
824
@group_app.command("create")
def group_create(
    name: Optional[str] = typer.Option(None, "--name", "-n", help="Group name"),
    max_concurrency: Optional[int] = typer.Option(
        None, "--max-concurrency", "-c", min=1, help="Maximum concurrent jobs"
    ),
    eval_config: Optional[str] = typer.Option(
        None, "--eval-config", help="Evaluation config (JSON)"
    ),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """Create a Group."""
    output_format = _normalize_output_format(output_format)
    # Parse --eval-config before any network call, so malformed JSON becomes a
    # usage error rather than a traceback. JSON null is accepted as "no config".
    eval_config_dict = None
    if eval_config:
        try:
            eval_config_dict = json.loads(eval_config)
        except json.JSONDecodeError as exc:
            raise typer.BadParameter(f"invalid JSON: {exc}", param_hint="--eval-config") from exc
        if eval_config_dict is not None and not isinstance(eval_config_dict, dict):
            raise typer.BadParameter("must be a JSON object", param_hint="--eval-config")
    client = get_client()
    result = client.create_group(
        name=name,
        max_concurrency=max_concurrency,
        eval_config=eval_config_dict,
    )
    if output_format == "plain":
        print("[green]Group created:[/]")
        print(f" group_id: {result.get('group_id')}")
        print(f" name: {result.get('name')}")
        print(f" max_concurrency: {result.get('max_concurrency')}")
        if result.get("eval_config") is not None:
            print(f" eval_config: {json.dumps(result.get('eval_config'), ensure_ascii=False)}")
        return
    _print_formatted(result, output_format)
853
+
854
+
855
@group_app.command("list")
def group_list(
    skip: int = typer.Option(0, "--skip", help="Skip the first N entries"),
    limit: int = typer.Option(100, "--limit", help="Maximum number of entries to return"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """List Groups."""
    output_format = _normalize_output_format(output_format)
    groups = get_client().list_groups(skip=skip, limit=limit)
    if output_format != "plain":
        _print_formatted(groups, output_format)
        return
    _print_group_list_plain(groups)
869
+
870
+
871
@group_app.command("get")
def group_get(group_id: str = typer.Argument(..., help="Group ID")):
    """Show Group details."""
    # Raw API response, printed as highlighted JSON.
    _print_json(get_client().get_group(group_id))
877
+
878
+
879
@group_app.command("jobs")
def group_jobs(
    group_id: str = typer.Argument(..., help="Group ID"),
    tag: Optional[list[str]] = typer.Option(
        None, "--tag", help="Filter by tag (repeatable; multiple tags are combined with AND)"
    ),
    skip: int = typer.Option(0, "--skip", help="Skip the first N entries"),
    limit: int = typer.Option(100, "--limit", help="Maximum number of entries to return"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """List jobs in a Group."""
    output_format = _normalize_output_format(output_format)
    # The repeated --tag values arrive as a list (or None) and are forwarded unchanged.
    jobs = get_client().list_group_jobs(group_id, tag=tag, skip=skip, limit=limit)
    if output_format != "plain":
        _print_formatted(jobs, output_format)
        return
    _print_job_list_plain(jobs)
897
+
898
+
899
@group_app.command("stats")
def group_stats(group_id: str = typer.Argument(..., help="Group ID")):
    """Show Group progress."""
    # Fetch the server-side progress counters for this group and dump them
    # verbatim as JSON (no --format option on this command).
    progress_info = get_client().get_group_stats(group_id)
    _print_json(progress_info)
905
+
906
+
907
@group_app.command("eval")
def group_eval(
    group_id: str = typer.Argument(..., help="Group ID"),
    output_format: str = typer.Option("plain", "--format", help="Output format: plain/json/yaml"),
):
    """Show Group aggregated evaluation results."""
    # Docstring is Typer's --help text; keep extra notes in comments.
    fmt = _normalize_output_format(output_format)
    summary = get_client().get_group_eval(group_id)
    if fmt != "plain":
        # Machine-readable output: hand the raw payload to the serializer.
        _print_formatted(summary, fmt)
        return
    _print_group_eval_plain(summary)
920
+
921
+
922
@group_app.command("artifacts")
def group_artifacts(group_id: str = typer.Argument(..., help="Group ID")):
    """Show Group artifact download links."""
    # The client returns the artifact listing for the group; it is printed
    # unchanged as JSON so links can be copied or piped elsewhere.
    links = get_client().get_group_artifacts(group_id)
    _print_json(links)
928
+
929
+
930
@group_app.command("cancel")
def group_cancel(
    group_id: str = typer.Argument(..., help="Group ID"),
):
    """Cancel all unfinished jobs under a Group."""
    # NOTE: the docstring above is Typer's --help text; keep notes in comments.
    # Local import, matching the lazy rich.progress import used by `group export`.
    from rich.markup import escape

    client = get_client()
    result = client.cancel_group(group_id)
    # Prefer the ID echoed back by the server; fall back to the one passed in.
    actual_group_id = result.get("group_id") or group_id
    # `print` is rich's print, which parses "[...]" as markup. Dynamic values
    # are escaped so brackets in IDs or error messages are shown literally
    # instead of being dropped as unknown styles or raising MarkupError on a
    # stray "[/...]" (the cancel has already happened at this point).
    print(f"[green]Group cancel completed:[/] {escape(str(actual_group_id))}")
    print(f" cancelled: {result.get('cancelled', 0)}")
    print(f" k8s_deleted: {result.get('k8s_deleted', False)}")
    # A missing or null "k8s_errors" key is treated as "no errors".
    k8s_errors = result.get("k8s_errors") or []
    if k8s_errors:
        print(" k8s_errors:")
        for err in k8s_errors:
            print(f" - {escape(str(err))}")
946
+
947
+
948
@group_app.command("export")
def group_export(
    group_id: str = typer.Argument(..., help="Group ID"),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Export directory"),
    workers: int = typer.Option(
        4, "--workers", "-w", min=1, help="Number of concurrent export workers"
    ),
):
    """Export artifacts/logs/events of all jobs under a Group to a local directory."""
    # NOTE: the docstring above is Typer's --help text; keep notes in comments.
    client = get_client()
    from rich.markup import escape
    from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn

    # transient=True removes the bar once the export finishes, leaving only the
    # summary printed after the `with` block.
    with Progress(
        TextColumn("{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("{task.fields[current_job]}"),
        TextColumn("{task.fields[current_stage]}"),
        transient=True,
    ) as progress:
        # total starts at 0; the real total arrives with the first callback.
        task_id = progress.add_task(
            "Exporting group",
            total=0,
            current_job="",
            current_stage="",
        )

        def _on_progress(done: int, total: int, job_id: str, stage: str) -> None:
            """Forward export_group progress reports to the rich progress bar."""
            # TextColumn renders its formatted text as rich markup, so escape
            # the job ID / stage name to keep any "[...]" in them literal.
            current_job = f"job={escape(str(job_id))}" if job_id else ""
            current_stage = f"stage={escape(str(stage))}" if stage else ""
            progress.update(
                task_id,
                total=total,
                completed=done,
                current_job=current_job,
                current_stage=current_stage,
            )

        # `output` may be None; export_group decides the default destination
        # and returns the directory it actually wrote to.
        dest = export_group(
            client,
            group_id,
            output,
            progress_callback=_on_progress,
            workers=workers,
        )
    # Escape user-supplied / returned values so e.g. a path like "out[v2]" is
    # not mangled by rich markup parsing.
    print(f"[green]Group exported:[/] {escape(str(group_id))}")
    print(f" path: {escape(str(dest))}")
995
+
996
+
997
@group_app.command("wait")
def group_wait(
    group_id: str = typer.Argument(..., help="Group ID"),
    interval: float = typer.Option(
        5.0, "--interval", "-i", min=0.1, help="Polling interval (seconds)"
    ),
):
    """Poll until all jobs under a Group finish."""
    # wait_for_group polls every `interval` seconds and returns the final
    # group payload; per-status counters are read from its "stats" key
    # (missing or null stats are treated as all-zero).
    final = wait_for_group(get_client(), group_id, interval)
    counts = final.get("stats") or {}
    print(f"[green]Group finished:[/] {group_id}")
    for key in ("total", "succeeded", "failed", "cancelled"):
        print(f" {key}: {counts.get(key, 0)}")
1013
+
1014
+
1015
if __name__ == "__main__":
    # Direct execution of this module (e.g. `python -m ap_client.cli`)
    # dispatches straight into the root Typer app.
    app()