synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai has been flagged as potentially problematic; see the registry's advisory page for details.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -491,10 +491,9 @@ class RolloutTracingContext:
|
|
|
491
491
|
getattr(request.record, "trace_format", "compact") or "compact"
|
|
492
492
|
).lower()
|
|
493
493
|
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
494
|
-
|
|
495
|
-
"[TRACE_DEBUG] RolloutTracingContext init: trace_format
|
|
496
|
-
|
|
497
|
-
self.return_trace,
|
|
494
|
+
print(
|
|
495
|
+
f"[TRACE_DEBUG] RolloutTracingContext init: trace_format={self.trace_format} return_trace={self.return_trace}",
|
|
496
|
+
flush=True,
|
|
498
497
|
)
|
|
499
498
|
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
500
499
|
self.session_trace = None
|
|
@@ -518,19 +517,24 @@ class RolloutTracingContext:
|
|
|
518
517
|
|
|
519
518
|
async def start_session(self) -> None:
|
|
520
519
|
if not self.enabled or self.tracer is None:
|
|
520
|
+
print("[TRACE_DEBUG] start_session skipped: tracer disabled", flush=True)
|
|
521
521
|
return
|
|
522
522
|
try:
|
|
523
523
|
await self.tracer.initialize()
|
|
524
|
+
print("[TRACE_DEBUG] tracer initialized", flush=True)
|
|
524
525
|
except Exception as exc:
|
|
525
526
|
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
527
|
+
# Hard fail: tracing requested but cannot initialize
|
|
528
|
+
raise
|
|
526
529
|
try:
|
|
527
530
|
await self.tracer.start_session(
|
|
528
531
|
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
529
532
|
)
|
|
533
|
+
print(f"[TRACE_DEBUG] start_session succeeded for run_id={self.run_id}", flush=True)
|
|
530
534
|
except Exception as exc:
|
|
531
535
|
logger.info("TRACING_START_FAIL: %s", exc)
|
|
532
|
-
|
|
533
|
-
|
|
536
|
+
# Hard fail: tracing requested but cannot start session
|
|
537
|
+
raise
|
|
534
538
|
|
|
535
539
|
async def start_decision(self, turn_number: int) -> None:
|
|
536
540
|
self.current_turn = turn_number
|
|
@@ -595,7 +599,7 @@ class RolloutTracingContext:
|
|
|
595
599
|
# Debug: Check message count
|
|
596
600
|
if self.tracer and self.tracer._current_trace:
|
|
597
601
|
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
598
|
-
|
|
602
|
+
print(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages", flush=True)
|
|
599
603
|
|
|
600
604
|
def _content_to_text(self, content: Any) -> str:
|
|
601
605
|
if isinstance(content, str):
|
|
@@ -669,15 +673,19 @@ class RolloutTracingContext:
|
|
|
669
673
|
return
|
|
670
674
|
if self.enabled and self.tracer is not None:
|
|
671
675
|
try:
|
|
676
|
+
payload = {
|
|
677
|
+
"role": "assistant",
|
|
678
|
+
"tool_calls": tool_calls,
|
|
679
|
+
}
|
|
672
680
|
await self.tracer.record_message(
|
|
673
|
-
content=
|
|
674
|
-
message_type="assistant",
|
|
681
|
+
content=payload,
|
|
682
|
+
message_type="assistant",
|
|
675
683
|
metadata={**self._message_metadata(), "is_tool_call": True},
|
|
676
684
|
)
|
|
677
685
|
if self.tracer._current_trace:
|
|
678
|
-
|
|
679
|
-
"[TRACE_DEBUG] After tool invocation: messages
|
|
680
|
-
|
|
686
|
+
print(
|
|
687
|
+
f"[TRACE_DEBUG] After tool invocation: messages={len(self.tracer._current_trace.markov_blanket_message_history)}",
|
|
688
|
+
flush=True,
|
|
681
689
|
)
|
|
682
690
|
except Exception as exc:
|
|
683
691
|
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
@@ -784,9 +792,33 @@ class RolloutTracingContext:
|
|
|
784
792
|
}
|
|
785
793
|
)
|
|
786
794
|
|
|
795
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
796
|
+
assistant_text = self._content_to_text(assistant_content)
|
|
797
|
+
|
|
798
|
+
if self.enabled and self.tracer is not None:
|
|
799
|
+
assistant_payload: dict[str, Any] = {
|
|
800
|
+
"role": "assistant",
|
|
801
|
+
"content": assistant_structured,
|
|
802
|
+
"text": assistant_text,
|
|
803
|
+
}
|
|
804
|
+
if isinstance(assistant_message, dict):
|
|
805
|
+
if assistant_message.get("tool_calls"):
|
|
806
|
+
assistant_payload["tool_calls"] = assistant_message.get("tool_calls")
|
|
807
|
+
if assistant_message.get("reasoning"):
|
|
808
|
+
assistant_payload["reasoning"] = assistant_message.get("reasoning")
|
|
809
|
+
if assistant_message.get("thinking"):
|
|
810
|
+
assistant_payload["thinking"] = assistant_message.get("thinking")
|
|
811
|
+
try:
|
|
812
|
+
await self.tracer.record_message(
|
|
813
|
+
content=assistant_payload,
|
|
814
|
+
message_type="assistant",
|
|
815
|
+
metadata=self._message_metadata(),
|
|
816
|
+
)
|
|
817
|
+
except Exception as exc:
|
|
818
|
+
logger.debug("TRACING_ASSISTANT_MSG_FAIL: %s", exc)
|
|
819
|
+
|
|
787
820
|
if self.sft_output_dir is not None:
|
|
788
821
|
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
789
|
-
assistant_text = self._content_to_text(assistant_content)
|
|
790
822
|
dialogue_structured: list[dict[str, Any]] = []
|
|
791
823
|
for content in self.latest_system_prompt_content:
|
|
792
824
|
if content is None:
|
|
@@ -951,17 +983,23 @@ class RolloutTracingContext:
|
|
|
951
983
|
# Debug: Check message count before end_session
|
|
952
984
|
if self.tracer._current_trace:
|
|
953
985
|
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
954
|
-
|
|
955
|
-
|
|
986
|
+
print(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace", flush=True)
|
|
987
|
+
|
|
956
988
|
self.session_trace = await self.tracer.end_session()
|
|
957
989
|
|
|
958
990
|
# Debug: Check if session was saved
|
|
959
991
|
if self.session_trace:
|
|
960
|
-
|
|
992
|
+
print(
|
|
993
|
+
f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}",
|
|
994
|
+
flush=True,
|
|
995
|
+
)
|
|
961
996
|
self.session_trace.metadata.update(self.metadata_updates)
|
|
962
|
-
|
|
997
|
+
print(
|
|
998
|
+
f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}",
|
|
999
|
+
flush=True,
|
|
1000
|
+
)
|
|
963
1001
|
else:
|
|
964
|
-
|
|
1002
|
+
print("[TRACE_DEBUG] end_session returned None!", flush=True)
|
|
965
1003
|
except Exception as exc:
|
|
966
1004
|
logger.warning(f"TRACING_END_SESSION_FAIL: {exc}", exc_info=True)
|
|
967
1005
|
self.session_trace = None
|
|
@@ -1001,9 +1039,9 @@ class RolloutTracingContext:
|
|
|
1001
1039
|
if self.trace_format in ("full", "structured"):
|
|
1002
1040
|
payload = session_trace.to_dict()
|
|
1003
1041
|
payload.setdefault("metadata", {}).update(self.metadata_updates)
|
|
1004
|
-
|
|
1005
|
-
"[TRACE_DEBUG] build_trace_payload returning structured trace with messages
|
|
1006
|
-
|
|
1042
|
+
print(
|
|
1043
|
+
f"[TRACE_DEBUG] build_trace_payload returning structured trace with messages={len(payload.get('markov_blanket_message_history') or [])}",
|
|
1044
|
+
flush=True,
|
|
1007
1045
|
)
|
|
1008
1046
|
return payload
|
|
1009
1047
|
|
|
@@ -1943,6 +1981,15 @@ async def execute_rollout(
|
|
|
1943
1981
|
if 'policy_config_snapshot' not in locals():
|
|
1944
1982
|
policy_config_snapshot = {}
|
|
1945
1983
|
|
|
1984
|
+
# Normalize inference URL for trajectory (and ensure no path in query)
|
|
1985
|
+
try:
|
|
1986
|
+
from .utils import force_normalize_chat_completions_url, ensure_chat_completions_url
|
|
1987
|
+
inference_url = force_normalize_chat_completions_url(inference_url)
|
|
1988
|
+
# apply mode-aware normalization too (keeps cid, appends path if missing)
|
|
1989
|
+
inference_url = ensure_chat_completions_url(inference_url, mode=request.mode)
|
|
1990
|
+
except Exception:
|
|
1991
|
+
pass
|
|
1992
|
+
|
|
1946
1993
|
logger.info(
|
|
1947
1994
|
"ROLLOUT_TRAJECTORY: run_id=%s policy_id=%s inference_url=%s trace_id=%s",
|
|
1948
1995
|
request.run_id,
|
|
@@ -2057,6 +2104,16 @@ async def execute_rollout(
|
|
|
2057
2104
|
if metrics.num_steps <= 0:
|
|
2058
2105
|
raise HTTPException(status_code=500, detail="no_steps_executed: avg_turns == 0")
|
|
2059
2106
|
|
|
2107
|
+
# Ensure at least one tool call executed successfully
|
|
2108
|
+
tool_call_executed = any(
|
|
2109
|
+
isinstance(step.tool_calls, list) and len(step.tool_calls) > 0 for step in trajectory_steps
|
|
2110
|
+
)
|
|
2111
|
+
if not tool_call_executed:
|
|
2112
|
+
raise HTTPException(
|
|
2113
|
+
status_code=502,
|
|
2114
|
+
detail="no_tool_calls_executed: model failed to produce actionable tool calls.",
|
|
2115
|
+
)
|
|
2116
|
+
|
|
2060
2117
|
response = RolloutResponse(
|
|
2061
2118
|
run_id=request.run_id,
|
|
2062
2119
|
trajectories=[trajectory],
|
|
@@ -11,6 +11,129 @@ logger = logging.getLogger(__name__)
|
|
|
11
11
|
_CHAT_COMPLETIONS_SUFFIX = "/v1/chat/completions"
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
def force_normalize_chat_completions_url(raw_url: Any) -> str:
    """Repair an arbitrarily malformed inference URL into canonical
    chat-completions form.

    Guarantees on the returned URL:
      - the path ends with ``/v1/chat/completions``
      - the query string contains no ``/`` characters
      - scheme, host, port and query params are preserved where possible

    A path that leaked into the query (e.g. ``?cid=x/v1/chat/completions``)
    is moved back into the path component, and any query params that
    followed it are merged into the real query.

    Examples:
        https://host?cid=trace_123/v1/chat/completions
            -> https://host/v1/chat/completions?cid=trace_123
        https://host:8000?cid=trace_abc/v1/chat/completions&foo=bar
            -> https://host:8000/v1/chat/completions?cid=trace_abc&foo=bar
        https://host?cid=trace_123/v1/chat/completions?other=param
            -> https://host/v1/chat/completions?cid=trace_123&other=param

    Non-string and empty inputs are returned unchanged.
    """
    if not isinstance(raw_url, str):
        return raw_url
    url = raw_url.strip()
    if not url:
        return raw_url

    parts = urlparse(url)
    path = (parts.path or "").rstrip("/")
    query = parts.query or ""

    # Step 1: a '/' in the query means a path segment leaked into it.
    # Split it out and promote it to the real path.
    if query and "/" in query:
        # Everything before the first '/' is genuine query content.
        real_query, leaked = query.split("/", 1)

        # The leaked chunk may carry further params after '&' or a
        # (malformed) second '?'; cut at whichever comes first.
        breaks = [idx for idx in (leaked.find("&"), leaked.find("?")) if idx >= 0]
        stop = min(breaks) if breaks else len(leaked)
        leaked_path = "/" + leaked[:stop]  # restore the leading '/'
        trailing_params = leaked[stop + 1 :] if stop < len(leaked) else ""

        # Fold any trailing params back into the real query string.
        if trailing_params:
            real_query = f"{real_query}&{trailing_params}" if real_query else trailing_params

        # Use the leaked path as-is when it already targets the endpoint,
        # otherwise append the canonical suffix to it.
        if leaked_path.startswith(_CHAT_COMPLETIONS_SUFFIX):
            repaired_path = leaked_path
        else:
            repaired_path = f"{leaked_path.rstrip('/')}{_CHAT_COMPLETIONS_SUFFIX}"

        parts = parts._replace(path=repaired_path, query=real_query)
        url = urlunparse(parts)
        parts = urlparse(url)
        path = parts.path or ""
        query = parts.query or ""

    # Step 2: guarantee the path terminates at the chat-completions endpoint.
    if not path.endswith(_CHAT_COMPLETIONS_SUFFIX):
        appended = f"{path}{_CHAT_COMPLETIONS_SUFFIX}" if path else _CHAT_COMPLETIONS_SUFFIX
        parts = parts._replace(path=appended)
        url = urlunparse(parts)
        parts = urlparse(url)
        path = parts.path or ""
        query = parts.query or ""

    # Step 3: last resort — drop everything after any stray '/' still in the query.
    if query and "/" in query:
        parts = parts._replace(query=query.split("/")[0])
        url = urlunparse(parts)

    return url
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _validate_url_structure(url: str, context: str = "") -> None:
    """Assert that *url* keeps its path before its query string.

    A ``/`` inside the query string means a path segment was appended after
    the query (e.g. ``?cid=x/v1/chat/completions``) — i.e. the URL was built
    incorrectly upstream.

    Args:
        url: The URL to validate.
        context: Optional context for error messages.

    Raises:
        ValueError: If the query string contains path-like segments.
    """
    # Nothing to validate for non-strings or blank strings.
    if not isinstance(url, str) or not url.strip():
        return

    try:
        query = urlparse(url).query or ""
        if query and "/" in query:
            path_segment = query.split("/", 1)[1]
            error_msg = (
                f"FATAL [TASK_APP_URL_VALIDATION]: Malformed inference URL detected!\n"
                f"\n"
                f"URL: {url}\n"
                f"Context: {context}\n"
                f"\n"
                f"The URL has a path-like segment ('/{path_segment}') in the query string.\n"
                f"This indicates incorrect URL construction upstream.\n"
                f"\n"
                f"Expected: https://host/v1/chat/completions?cid=trace_123\n"
                f"Malformed: https://host?cid=trace_123/v1/chat/completions\n"
                f"\n"
                f"This should be caught by the trainer, but if you see this,\n"
                f"the trainer's URL validation may have failed.\n"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
    except ValueError:
        # Propagate the malformed-URL failure unchanged.
        raise
    except Exception as e:
        # An unparseable URL is logged but not fatal here.
        logger.warning(f"[URL_VALIDATION] Failed to parse URL: {url} (context: {context}, error: {e})")
|
|
135
|
+
|
|
136
|
+
|
|
14
137
|
def ensure_chat_completions_url(raw_url: Any, mode: str | None = None) -> Any:
|
|
15
138
|
"""
|
|
16
139
|
Ensure inference URLs point at the chat completions endpoint.
|
|
@@ -43,9 +166,75 @@ def ensure_chat_completions_url(raw_url: Any, mode: str | None = None) -> Any:
|
|
|
43
166
|
|
|
44
167
|
parsed = urlparse(url)
|
|
45
168
|
path = (parsed.path or "").rstrip("/")
|
|
169
|
+
query = parsed.query
|
|
170
|
+
|
|
171
|
+
logger.debug(
|
|
172
|
+
"ensure_chat_completions_url: parsing url=%s -> path=%r query=%r",
|
|
173
|
+
url,
|
|
174
|
+
path,
|
|
175
|
+
query,
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
# CRITICAL: Check for malformed URLs (path in query) and fix them FIRST
|
|
179
|
+
# Example: https://host?cid=trace_123/v1/chat/completions
|
|
180
|
+
# Should be: https://host/v1/chat/completions?cid=trace_123
|
|
181
|
+
if query and "/" in query:
|
|
182
|
+
logger.error(
|
|
183
|
+
f"[URL_FIX] Detected malformed URL in ensure_chat_completions_url: {url}\n"
|
|
184
|
+
f"Path-like segment found in query string. Attempting to fix..."
|
|
185
|
+
)
|
|
186
|
+
# Split query at first "/" to separate query params from path
|
|
187
|
+
query_parts = query.split("/", 1)
|
|
188
|
+
if len(query_parts) == 2:
|
|
189
|
+
# query_parts[0] is the actual query (e.g., "cid=trace_123")
|
|
190
|
+
# query_parts[1] is the path that was incorrectly put in query
|
|
191
|
+
actual_query = query_parts[0]
|
|
192
|
+
path_and_more = query_parts[1] # Could be "v1/chat/completions" or "v1/chat/completions&foo=bar"
|
|
193
|
+
|
|
194
|
+
# Extract the path part (everything before "&" or "?" if present)
|
|
195
|
+
# Handle both "&" (query param separator) and "?" (another malformed query separator)
|
|
196
|
+
if "&" in path_and_more:
|
|
197
|
+
# Path is followed by more query params (separated by &)
|
|
198
|
+
path_segment, extra_query = path_and_more.split("&", 1)
|
|
199
|
+
path_in_query = "/" + path_segment # Restore leading slash
|
|
200
|
+
# Merge extra query params with actual_query
|
|
201
|
+
actual_query = f"{actual_query}&{extra_query}"
|
|
202
|
+
elif "?" in path_and_more:
|
|
203
|
+
# Path is followed by more query params (separated by ?, which is malformed)
|
|
204
|
+
path_segment, extra_query = path_and_more.split("?", 1)
|
|
205
|
+
path_in_query = "/" + path_segment # Restore leading slash
|
|
206
|
+
# Merge extra query params with actual_query (use & as separator)
|
|
207
|
+
actual_query = f"{actual_query}&{extra_query}"
|
|
208
|
+
else:
|
|
209
|
+
# No extra query params, just the path
|
|
210
|
+
path_in_query = "/" + path_and_more # Restore leading slash
|
|
211
|
+
|
|
212
|
+
# If the path_in_query already contains /v1/chat/completions, use it
|
|
213
|
+
# Otherwise, append /v1/chat/completions
|
|
214
|
+
if path_in_query.startswith("/v1/chat/completions"):
|
|
215
|
+
final_path = path_in_query
|
|
216
|
+
else:
|
|
217
|
+
# Append /v1/chat/completions to whatever path we found
|
|
218
|
+
final_path = path_in_query.rstrip("/") + "/v1/chat/completions"
|
|
219
|
+
|
|
220
|
+
# Reconstruct URL correctly: path comes before query
|
|
221
|
+
parsed = parsed._replace(path=final_path, query=actual_query)
|
|
222
|
+
fixed_url = urlunparse(parsed)
|
|
223
|
+
logger.warning(f"[URL_FIX] Fixed malformed URL:\n FROM: {url}\n TO: {fixed_url}")
|
|
224
|
+
url = fixed_url
|
|
225
|
+
# Re-parse after fix
|
|
226
|
+
parsed = urlparse(url)
|
|
227
|
+
path = parsed.path.rstrip("/")
|
|
228
|
+
query = parsed.query
|
|
229
|
+
else:
|
|
230
|
+
# Can't parse - this shouldn't happen but validate will catch it
|
|
231
|
+
logger.error(f"[URL_FIX] Could not parse malformed query: {query}")
|
|
232
|
+
_validate_url_structure(url, context="ensure_chat_completions_url input - cannot fix")
|
|
233
|
+
|
|
46
234
|
if path.endswith("/v1/chat/completions"):
|
|
47
235
|
logger.debug("ensure_chat_completions_url: URL already normalized %s", url)
|
|
48
|
-
#
|
|
236
|
+
# Validate final URL
|
|
237
|
+
_validate_url_structure(url, context="ensure_chat_completions_url output")
|
|
49
238
|
return url
|
|
50
239
|
|
|
51
240
|
if not path:
|
|
@@ -55,6 +244,10 @@ def ensure_chat_completions_url(raw_url: Any, mode: str | None = None) -> Any:
|
|
|
55
244
|
|
|
56
245
|
rebuilt = parsed._replace(path=new_path)
|
|
57
246
|
normalized = urlunparse(rebuilt)
|
|
247
|
+
|
|
248
|
+
# CRITICAL: Validate the normalized URL
|
|
249
|
+
_validate_url_structure(normalized, context="ensure_chat_completions_url output")
|
|
250
|
+
|
|
58
251
|
logger.info(
|
|
59
252
|
"ensure_chat_completions_url: RL mode - normalized inference URL from %s to %s",
|
|
60
253
|
url,
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""GEPA benchmark task apps (HotpotQA, IFBench, HoVer, PUPA)."""
|
|
2
|
+
|
|
3
|
+
# Import modules for side effects (task app registration) when package is imported.
|
|
4
|
+
from . import hotpotqa_task_app # noqa: F401
|
|
5
|
+
from . import hover_task_app # noqa: F401
|
|
6
|
+
from . import ifbench_task_app # noqa: F401
|
|
7
|
+
from . import pupa_task_app # noqa: F401
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""Shared helpers for GEPA benchmark task apps (HotpotQA, IFBench, HoVer, PUPA)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
from typing import Any, Iterable, Mapping, Sequence
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
from fastapi import HTTPException
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _resolve_inference_url(base_url: str) -> str:
|
|
15
|
+
"""Normalise a base inference URL to the chat completions endpoint."""
|
|
16
|
+
|
|
17
|
+
normalised = (base_url or "").rstrip("/")
|
|
18
|
+
if not normalised:
|
|
19
|
+
raise RuntimeError("policy.config.inference_url required")
|
|
20
|
+
if normalised.endswith("/v1/chat/completions"):
|
|
21
|
+
return normalised
|
|
22
|
+
if normalised.endswith("/chat/completions"):
|
|
23
|
+
return normalised
|
|
24
|
+
if normalised.endswith("/v1"):
|
|
25
|
+
return f"{normalised}/chat/completions"
|
|
26
|
+
return f"{normalised}/v1/chat/completions"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
_PLACEHOLDER_PATTERN = re.compile(r"\{([^{}]+)\}")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _substitute_placeholders(text: str, values: Mapping[str, Any]) -> str:
|
|
33
|
+
"""Replace `{placeholder}` tokens in `text` with entries from `values`."""
|
|
34
|
+
|
|
35
|
+
def _replace(match: re.Match[str]) -> str:
|
|
36
|
+
key = match.group(1)
|
|
37
|
+
replacement = values.get(key)
|
|
38
|
+
return str(replacement) if replacement is not None else match.group(0)
|
|
39
|
+
|
|
40
|
+
return _PLACEHOLDER_PATTERN.sub(_replace, text)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def render_messages(
    policy_config: Mapping[str, Any],
    placeholders: Mapping[str, Any],
    default_messages: Sequence[Mapping[str, str]],
) -> list[dict[str, str]]:
    """Render chat messages either from policy prompt patterns or defaults.

    When ``policy_config["prompt"]["messages"]`` is present, each mapping
    entry is rendered ("pattern" wins over "content", role defaults to
    "user"); otherwise *default_messages* are rendered the same way.
    Placeholder tokens are substituted via :func:`_substitute_placeholders`.

    Args:
        policy_config: Policy configuration, possibly containing a
            ``prompt.messages`` override.
        placeholders: Values substituted into ``{placeholder}`` tokens.
        default_messages: Fallback messages when no override is configured.

    Returns:
        A list of ``{"role": ..., "content": ...}`` chat messages.
    """

    def _render(entry: Mapping[str, Any]) -> dict[str, str]:
        # One place for the role/pattern/substitution logic that was
        # previously duplicated across the override and default paths.
        role = str(entry.get("role") or "user")
        pattern = entry.get("pattern") or entry.get("content") or ""
        return {"role": role, "content": _substitute_placeholders(str(pattern), placeholders)}

    rendered: list[dict[str, str]] = []
    prompt_config = policy_config.get("prompt") if isinstance(policy_config, Mapping) else None
    if prompt_config and isinstance(prompt_config, Mapping):
        messages = prompt_config.get("messages")
        if isinstance(messages, Sequence):
            # Non-mapping entries are skipped, matching the original behavior.
            rendered = [_render(entry) for entry in messages if isinstance(entry, Mapping)]
    if not rendered:
        rendered = [_render(entry) for entry in default_messages]
    return rendered
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
async def call_chat_completion(
    policy_config: Mapping[str, Any],
    placeholders: Mapping[str, Any],
    default_messages: Sequence[Mapping[str, str]],
    *,
    tool_spec: Sequence[Mapping[str, Any]] | None = None,
    tool_choice: Mapping[str, Any] | None = None,
    timeout: float = 60.0,
) -> tuple[str, dict[str, Any], list[dict[str, Any]]]:
    """Invoke an OpenAI-compatible chat/completions endpoint.

    Args:
        policy_config: Policy settings; must contain ``model`` and may carry
            ``prompt``, ``temperature``, ``max_tokens``/``max_completion_tokens``
            and ``inference_url``.
        placeholders: Values substituted into ``{placeholder}`` tokens in the
            prompt patterns via ``render_messages``.
        default_messages: Fallback messages when the policy supplies no prompt.
        tool_spec: Optional OpenAI-style tool definitions, forwarded verbatim.
        tool_choice: Optional tool-choice directive, forwarded verbatim.
        timeout: Total HTTP timeout in seconds.

    Returns:
        response_text: The assistant message text (empty string if missing).
        raw_response: The JSON payload from the provider.
        messages: The messages sent to the model (after placeholder substitution).

    Raises:
        RuntimeError: If ``policy_config`` is not a mapping or lacks ``model``.
        HTTPException: 502 for provider 5xx / invalid JSON, 400 for provider 4xx.
    """

    if not isinstance(policy_config, Mapping):
        raise RuntimeError("policy.config must be a mapping for chat completion calls")

    messages = render_messages(policy_config, placeholders, default_messages)

    model = policy_config.get("model")
    if not model:
        raise RuntimeError("policy.config.model required for rollout")

    temperature = policy_config.get("temperature", 0.0)
    max_tokens = policy_config.get("max_tokens")
    # Explicit max_completion_tokens wins; otherwise reuse max_tokens, else 512.
    max_completion_tokens = policy_config.get("max_completion_tokens", max_tokens or 512)

    inference_url = policy_config.get("inference_url") or ""
    final_url = _resolve_inference_url(str(inference_url))

    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_completion_tokens": max_completion_tokens,
    }
    if tool_spec:
        payload["tools"] = list(tool_spec)
    if tool_choice:
        payload["tool_choice"] = tool_choice

    # Prefer provider-specific keys, fall back to SYNTH/OPENAI.
    # Dict insertion order defines precedence: GROQ, then OPENAI, then SYNTH.
    proxy_keys = {
        "GROQ_API_KEY": os.getenv("GROQ_API_KEY"),
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        "SYNTH_API_KEY": os.getenv("SYNTH_API_KEY"),
    }
    api_key = next((value for value in proxy_keys.values() if value), None)

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
        response = await client.post(final_url, json=payload, headers=headers)

    # Parse the body before checking status so error details can quote the payload.
    try:
        data = response.json()
    except json.JSONDecodeError as exc:  # pragma: no cover - defensive
        raise HTTPException(
            status_code=502,
            detail=f"Inference provider returned invalid JSON: {response.text[:800]}",
        ) from exc

    # Map provider 5xx to 502 (upstream failure) and 4xx to 400 (bad request).
    if response.status_code >= 500:
        raise HTTPException(
            status_code=502,
            detail=f"Inference provider returned an error: {data}",
        )
    if response.status_code >= 400:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid inference request: {data}",
        )

    # Extract the first choice's assistant message, tolerating missing fields.
    response_text = ""
    choices = data.get("choices") if isinstance(data, Mapping) else None
    if isinstance(choices, Sequence) and choices:
        message = choices[0].get("message")
        if isinstance(message, Mapping):
            response_text = str(message.get("content") or "")

    return response_text, data, messages
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def normalise_answer(text: str) -> str:
    """Normalise free-form text answers (HotpotQA style)."""

    # Lowercase, strip punctuation, drop articles, then collapse whitespace.
    result = text.lower()
    result = re.sub(r"[^a-z0-9\s]", " ", result)
    result = re.sub(r"\b(a|an|the)\b", " ", result)
    return re.sub(r"\s+", " ", result).strip()
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# Rough single-codepoint emoji matcher over the major emoji Unicode blocks.
# NOTE(review): ZWJ sequences and skin-tone modifiers are not handled, so
# counts derived from this pattern are approximate by design.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "\U00002700-\U000027BF"  # dingbats
    "\U0001F900-\U0001F9FF"  # supplemental symbols & pictographs
    "\U00002600-\U000026FF"  # miscellaneous symbols
    "\U00002B00-\U00002BFF"  # misc symbols & arrows
    "]",
    flags=re.UNICODE,
)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def count_emojis(text: str) -> int:
    """Return rough count of emoji characters."""

    # Count matches lazily rather than materializing the full match list.
    return sum(1 for _ in _EMOJI_PATTERN.finditer(text))
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def tokenize(text: str) -> list[str]:
    """Simple whitespace/token splitter with punctuation stripping."""

    # str.split() with no argument already discards empty fields.
    stripped = re.sub(r"[^\w\s]", " ", text.lower())
    return stripped.split()
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def sentence_split(text: str) -> list[str]:
    """Split text into sentences using punctuation heuristics."""

    sentences: list[str] = []
    # Break after terminal punctuation followed by whitespace.
    for fragment in re.split(r"(?<=[.!?])\s+", text.strip()):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def count_numbers(text: str) -> int:
    """Count occurrences of numeric tokens."""

    # Integers and simple decimals, delimited by word boundaries.
    return sum(1 for _ in re.finditer(r"\b\d+(?:\.\d+)?\b", text))
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def unique_word_count(tokens: Iterable[str]) -> int:
    """Return number of unique tokens."""

    return len(frozenset(tokens))
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# Personal and possessive English pronouns matched by `count_pronouns`.
# NOTE(review): reflexive forms ("myself", "themselves", ...) are absent —
# confirm that omission is intentional.
PRONOUNS = {
    "i",
    "me",
    "you",
    "he",
    "him",
    "she",
    "her",
    "it",
    "we",
    "us",
    "they",
    "them",
    "my",
    "mine",
    "your",
    "yours",
    "his",
    "hers",
    "its",
    "our",
    "ours",
    "their",
    "theirs",
}
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def count_pronouns(tokens: Iterable[str]) -> int:
    """Count pronoun tokens from a predefined list."""

    total = 0
    for token in tokens:
        if token in PRONOUNS:
            total += 1
    return total
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Public API of this helpers module, kept alphabetically sorted.
__all__ = [
    "call_chat_completion",
    "count_emojis",
    "count_numbers",
    "count_pronouns",
    "normalise_answer",
    "render_messages",
    "sentence_split",
    "tokenize",
    "unique_word_count",
]
|