synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/__init__.py +1 -0
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/RECORD +294 -258
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -1,213 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Hello World: Banking77 intent classification with in-context injection
|
|
3
|
-
|
|
4
|
-
This script shows a minimal text-classification pipeline over the
|
|
5
|
-
Hugging Face Banking77 dataset using the Synth LM interface. It also
|
|
6
|
-
demonstrates a simple pre-send prompt-injection step as outlined in
|
|
7
|
-
`synth_ai/learning/prompts/injection_plan.txt`.
|
|
8
|
-
|
|
9
|
-
Notes
|
|
10
|
-
- Network access is required to download the dataset and call the model.
|
|
11
|
-
- Defaults to Groq with model `openai/gpt-oss-20b`.
|
|
12
|
-
- Export your key: `export GROQ_API_KEY=...`
|
|
13
|
-
- Override if needed: `export MODEL=openai/gpt-oss-20b VENDOR=groq`
|
|
14
|
-
|
|
15
|
-
Run
|
|
16
|
-
- `python -m synth_ai.learning.prompts.hello_world_in_context_injection_ex`
|
|
17
|
-
|
|
18
|
-
What "in-context injection" means here
|
|
19
|
-
- The script applies ordered substring replacements to the outgoing
|
|
20
|
-
`messages` array before calling the model. This mirrors the algorithm
|
|
21
|
-
described in `injection_plan.txt` without importing any non-existent
|
|
22
|
-
helper yet. You can adapt `INJECTION_RULES` to your needs.
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
from __future__ import annotations
|
|
26
|
-
|
|
27
|
-
import asyncio
|
|
28
|
-
import os
|
|
29
|
-
import random
|
|
30
|
-
|
|
31
|
-
from datasets import load_dataset
|
|
32
|
-
|
|
33
|
-
# Use the v3 LM class present in this repo
|
|
34
|
-
from synth_ai.lm.core.main_v3 import LM, build_messages
|
|
35
|
-
|
|
36
|
-
# Use Overrides context to demonstrate matching by content
|
|
37
|
-
from synth_ai.lm.overrides import LMOverridesContext
|
|
38
|
-
from synth_ai.tracing_v3.abstractions import LMCAISEvent
|
|
39
|
-
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
40
|
-
|
|
41
|
-
INJECTION_RULES = [
|
|
42
|
-
{"find": "accnt", "replace": "account"},
|
|
43
|
-
{"find": "atm", "replace": "ATM"},
|
|
44
|
-
{"find": "txn", "replace": "transaction"},
|
|
45
|
-
]
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
async def classify_sample(lm: LM, text: str, label_names: list[str]) -> str:
|
|
49
|
-
"""Classify one Banking77 utterance and return the predicted label name."""
|
|
50
|
-
labels_joined = ", ".join(label_names)
|
|
51
|
-
system_message = (
|
|
52
|
-
"You are an intent classifier for the Banking77 dataset. "
|
|
53
|
-
"Given a customer message, respond with exactly one label from the list. "
|
|
54
|
-
"Return only the label text with no extra words.\n\n"
|
|
55
|
-
f"Valid labels: {labels_joined}"
|
|
56
|
-
)
|
|
57
|
-
user_message = f"Message: {text}\nLabel:"
|
|
58
|
-
|
|
59
|
-
# Build canonical messages; injection will be applied inside the vendor via context
|
|
60
|
-
messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
|
|
61
|
-
resp = await lm.respond_async(messages=messages)
|
|
62
|
-
raw = (resp.raw_response or "").strip()
|
|
63
|
-
return raw
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
async def main() -> None:
|
|
67
|
-
# Configurable model/provider via env, with sensible defaults
|
|
68
|
-
# Default to Groq hosting `openai/gpt-oss-20b`
|
|
69
|
-
model = os.getenv("MODEL", "openai/gpt-oss-20b")
|
|
70
|
-
vendor = os.getenv("VENDOR", "groq")
|
|
71
|
-
|
|
72
|
-
# Construct LM
|
|
73
|
-
lm = LM(model=model, vendor=vendor, temperature=0.0)
|
|
74
|
-
|
|
75
|
-
# Load Banking77 dataset
|
|
76
|
-
# Columns: {"text": str, "label": int}; label names at ds.features["label"].names
|
|
77
|
-
print("Loading Banking77 dataset (split='test')...")
|
|
78
|
-
ds = load_dataset("banking77", split="test")
|
|
79
|
-
label_names: list[str] = ds.features["label"].names # type: ignore
|
|
80
|
-
|
|
81
|
-
# Sample a few items for a quick demo
|
|
82
|
-
n = int(os.getenv("N_SAMPLES", "8"))
|
|
83
|
-
idxs = random.sample(range(len(ds)), k=min(n, len(ds)))
|
|
84
|
-
|
|
85
|
-
correct = 0
|
|
86
|
-
# Apply overrides for all calls in this block (match by content)
|
|
87
|
-
overrides = [
|
|
88
|
-
{"match": {"contains": "atm", "role": "user"}, "injection_rules": INJECTION_RULES},
|
|
89
|
-
{"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
|
|
90
|
-
]
|
|
91
|
-
with LMOverridesContext(overrides):
|
|
92
|
-
for i, idx in enumerate(idxs, start=1):
|
|
93
|
-
text: str = ds[idx]["text"] # type: ignore
|
|
94
|
-
gold_label_idx: int = int(ds[idx]["label"]) # type: ignore
|
|
95
|
-
gold_label = label_names[gold_label_idx]
|
|
96
|
-
|
|
97
|
-
try:
|
|
98
|
-
pred = await classify_sample(lm, text, label_names)
|
|
99
|
-
except Exception as e:
|
|
100
|
-
print(f"[{i}] Error calling model: {e}")
|
|
101
|
-
break
|
|
102
|
-
|
|
103
|
-
# Normalize and check exact match; if not exact, attempt a loose fallback
|
|
104
|
-
norm_pred = pred.strip().lower()
|
|
105
|
-
label_lookup = {ln.lower(): ln for ln in label_names}
|
|
106
|
-
pred_label = label_lookup.get(norm_pred)
|
|
107
|
-
if pred_label is None:
|
|
108
|
-
# Fallback: pick the label with highest substring overlap (very naive)
|
|
109
|
-
# This avoids extra deps; feel free to replace with a better matcher.
|
|
110
|
-
def score(cand: str) -> int:
|
|
111
|
-
c = cand.lower()
|
|
112
|
-
return sum(1 for w in c.split() if w in norm_pred)
|
|
113
|
-
|
|
114
|
-
pred_label = max(label_names, key=score)
|
|
115
|
-
|
|
116
|
-
is_correct = pred_label == gold_label
|
|
117
|
-
correct += int(is_correct)
|
|
118
|
-
print(
|
|
119
|
-
f"[{i}] text={text!r}\n gold={gold_label}\n pred={pred} -> mapped={pred_label} {'✅' if is_correct else '❌'}"
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
if idxs:
|
|
123
|
-
acc = correct / len(idxs)
|
|
124
|
-
print(f"\nSamples: {len(idxs)} | Correct: {correct} | Accuracy: {acc:.2%}")
|
|
125
|
-
|
|
126
|
-
# ------------------------------
|
|
127
|
-
# Integration tests (three paths)
|
|
128
|
-
# ------------------------------
|
|
129
|
-
print("\nRunning integration tests with in-context injection...")
|
|
130
|
-
test_text = "I used the atm to withdraw cash."
|
|
131
|
-
|
|
132
|
-
# 1) LM path with v3 tracing: verify substitution in traced messages
|
|
133
|
-
tracer = SessionTracer()
|
|
134
|
-
await tracer.start_session(metadata={"test": "lm_injection"})
|
|
135
|
-
await tracer.start_timestep(step_id="lm_test")
|
|
136
|
-
# Use a tracer-bound LM instance
|
|
137
|
-
lm_traced = LM(model=model, vendor=vendor, temperature=0.0, session_tracer=tracer)
|
|
138
|
-
with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
|
|
139
|
-
_ = await classify_sample(lm_traced, test_text, label_names)
|
|
140
|
-
# inspect trace
|
|
141
|
-
events = [
|
|
142
|
-
e
|
|
143
|
-
for e in (tracer.current_session.event_history if tracer.current_session else [])
|
|
144
|
-
if isinstance(e, LMCAISEvent)
|
|
145
|
-
]
|
|
146
|
-
assert events, "No LMCAISEvent recorded by SessionTracer"
|
|
147
|
-
cr = events[-1].call_records[0]
|
|
148
|
-
traced_user = ""
|
|
149
|
-
for m in cr.input_messages:
|
|
150
|
-
if m.role == "user":
|
|
151
|
-
for part in m.parts:
|
|
152
|
-
if getattr(part, "type", None) == "text":
|
|
153
|
-
traced_user += part.text or ""
|
|
154
|
-
assert "ATM" in traced_user, f"Expected substitution in traced prompt; got: {traced_user!r}"
|
|
155
|
-
print("LM path trace verified: substitution present in traced prompt.")
|
|
156
|
-
await tracer.end_timestep()
|
|
157
|
-
await tracer.end_session()
|
|
158
|
-
|
|
159
|
-
# 2) OpenAI wrapper path (AsyncOpenAI to Groq): ensure apply_injection is active
|
|
160
|
-
try:
|
|
161
|
-
import synth_ai.lm.provider_support.openai as _synth_openai_patch # noqa: F401
|
|
162
|
-
from openai import AsyncOpenAI
|
|
163
|
-
|
|
164
|
-
base_url = os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")
|
|
165
|
-
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
|
|
166
|
-
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
|
167
|
-
messages = [
|
|
168
|
-
{"role": "system", "content": "Echo user label."},
|
|
169
|
-
{"role": "user", "content": f"Please classify: {test_text}"},
|
|
170
|
-
]
|
|
171
|
-
with LMOverridesContext(
|
|
172
|
-
[{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
|
|
173
|
-
):
|
|
174
|
-
_ = await client.chat.completions.create(
|
|
175
|
-
model=model, messages=messages, temperature=0
|
|
176
|
-
)
|
|
177
|
-
# Not all models echo input; instead, verify that our injected expectation matches
|
|
178
|
-
expected_user = messages[1]["content"].replace("atm", "ATM")
|
|
179
|
-
if messages[1]["content"] == expected_user:
|
|
180
|
-
print("OpenAI wrapper: input already normalized; skipping assertion.")
|
|
181
|
-
else:
|
|
182
|
-
print("OpenAI wrapper: sent message contains substitution expectation:", expected_user)
|
|
183
|
-
except Exception as e:
|
|
184
|
-
print("OpenAI wrapper test skipped due to error:", e)
|
|
185
|
-
|
|
186
|
-
# 3) Anthropic wrapper path (AsyncClient): ensure apply_injection is active
|
|
187
|
-
try:
|
|
188
|
-
import anthropic
|
|
189
|
-
import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch # noqa: F401
|
|
190
|
-
|
|
191
|
-
a_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
|
|
192
|
-
a_key = os.getenv("ANTHROPIC_API_KEY")
|
|
193
|
-
if a_key:
|
|
194
|
-
a_client = anthropic.AsyncClient(api_key=a_key)
|
|
195
|
-
with LMOverridesContext(
|
|
196
|
-
[{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
|
|
197
|
-
):
|
|
198
|
-
_ = await a_client.messages.create(
|
|
199
|
-
model=a_model,
|
|
200
|
-
system="Echo user label.",
|
|
201
|
-
max_tokens=64,
|
|
202
|
-
temperature=0,
|
|
203
|
-
messages=[{"role": "user", "content": [{"type": "text", "text": test_text}]}],
|
|
204
|
-
)
|
|
205
|
-
print("Anthropic wrapper call completed (cannot reliably assert echo).")
|
|
206
|
-
else:
|
|
207
|
-
print("Anthropic wrapper test skipped: ANTHROPIC_API_KEY not set.")
|
|
208
|
-
except Exception as e:
|
|
209
|
-
print("Anthropic wrapper test skipped due to error:", e)
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
if __name__ == "__main__":
|
|
213
|
-
asyncio.run(main())
|
|
@@ -1,289 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
MIPROv2-style prompt optimizer (modular, DSPy-inspired).
|
|
3
|
-
|
|
4
|
-
This module provides a modular implementation of the MIPROv2 pseudocode from DSPy,
|
|
5
|
-
adapted to a provider-agnostic "program" interface. The goal is to keep the
|
|
6
|
-
bootstrapping and search process pluggable so it can be swapped for alternatives.
|
|
7
|
-
|
|
8
|
-
Key ideas
|
|
9
|
-
- Program adapter: unify how we set instructions/demos and run predictions.
|
|
10
|
-
- Demo bootstrapping: gather high-confidence examples (by metric) as candidates.
|
|
11
|
-
- Instruction proposals: generated by a prompt model from contextual summaries.
|
|
12
|
-
- Search (placeholder): random/Bayesian-like search over (instructions × demos).
|
|
13
|
-
|
|
14
|
-
Notes
|
|
15
|
-
- The implementation is intentionally lightweight and dependency-free.
|
|
16
|
-
- "BayesOpt" here is a placeholder randomized proposer that uses history; you
|
|
17
|
-
can plug in a real optimizer later.
|
|
18
|
-
"""
|
|
19
|
-
|
|
20
|
-
from __future__ import annotations
|
|
21
|
-
|
|
22
|
-
import random
|
|
23
|
-
from collections.abc import Callable, Sequence
|
|
24
|
-
from dataclasses import dataclass, replace
|
|
25
|
-
from typing import Any, Protocol
|
|
26
|
-
|
|
27
|
-
# ---------------------------
|
|
28
|
-
# Program adapter and protocols
|
|
29
|
-
# ---------------------------
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class PredictProgram(Protocol):
|
|
33
|
-
"""Minimal protocol a program must satisfy for MIPRO.
|
|
34
|
-
|
|
35
|
-
You can adapt your own pipeline to this by implementing these methods or
|
|
36
|
-
by wrapping it with `ProgramAdapter` below.
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
def deepcopy(self) -> PredictProgram: ...
|
|
40
|
-
|
|
41
|
-
def run(self, x: Any, *, model: Any | None = None) -> Any: ...
|
|
42
|
-
|
|
43
|
-
def with_instructions(self, instructions: dict[str, str]) -> PredictProgram: ...
|
|
44
|
-
|
|
45
|
-
def with_demos(self, demos: list[tuple[Any, Any]]) -> PredictProgram: ...
|
|
46
|
-
|
|
47
|
-
@property
|
|
48
|
-
def predictors(self) -> list[str]: ...
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@dataclass
class ProgramAdapter:
    """Adapter that turns a set of callables/state into a `PredictProgram`.

    Attributes:
        run_fn: Invoked as ``run_fn(x, model)`` to produce an output.
        state: Arbitrary dict; supports `instructions` and `demos` keys.
        _predictors: Predictor identifiers (e.g., names of prompt blocks).
        set_instructions: Applies new per-predictor instructions to a copy of
            the state and returns the updated dict.
        set_demos: Applies new demos (global or per predictor) to a copy of
            the state and returns the updated dict.
    """

    run_fn: Callable[[Any, Any | None], Any]
    state: dict[str, Any]
    _predictors: list[str]
    set_instructions: Callable[[dict[str, str], dict[str, Any]], dict[str, Any]]
    set_demos: Callable[[list[tuple[Any, Any]], dict[str, Any]], dict[str, Any]]

    def deepcopy(self) -> ProgramAdapter:
        """Return a clone sharing the callables but owning a copy of `state`."""
        # NOTE(review): this is a shallow copy — values inside `state` are
        # still shared with the original.
        return replace(self, state=dict(self.state))

    def run(self, x: Any, *, model: Any | None = None) -> Any:
        """Execute the wrapped pipeline on input `x`."""
        return self.run_fn(x, model)

    def with_instructions(self, instructions: dict[str, str]) -> ProgramAdapter:
        """Return a copy whose state reflects the new instructions."""
        updated = self.set_instructions(instructions, dict(self.state))
        return replace(self, state=updated)

    def with_demos(self, demos: list[tuple[Any, Any]]) -> ProgramAdapter:
        """Return a copy whose state reflects the new demos."""
        updated = self.set_demos(demos, dict(self.state))
        return replace(self, state=updated)

    @property
    def predictors(self) -> list[str]:
        """A fresh list of predictor identifiers."""
        return list(self._predictors)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
# ---------------------------
|
|
88
|
-
# Utility helpers
|
|
89
|
-
# ---------------------------
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
def summarize_dataset(trainset: Sequence[tuple[Any, Any]], max_items: int = 50) -> str:
    """Return a one-line textual summary of `trainset` for prompt context.

    Shows the dataset size plus up to `max_items` truncated example inputs
    sampled evenly across the whole dataset.

    Args:
        trainset: Sequence of `(input, target)` pairs.
        max_items: Maximum number of example inputs to include.
    """
    n = len(trainset)
    if n == 0 or max_items <= 0:
        return f"Dataset size: {n}. Example inputs: "
    # Ceil-division stride keeps the sample count <= max_items while covering
    # the full dataset. (The previous version stopped the range at
    # min(max_items, n), so for large datasets it only sampled the prefix and
    # showed far fewer than max_items examples.)
    step = -(-n // max_items)
    ex = ", ".join(repr(trainset[i][0])[:40] for i in range(0, n, step))
    return f"Dataset size: {n}. Example inputs: {ex}"
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def summarize_program(prog: PredictProgram) -> str:
    """Return a one-line description of the program's predictor names."""
    return "Program predictors: {}".format(prog.predictors)
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def random_tip(rng: random.Random) -> str:
    """Pick one generic prompt-writing tip using the supplied RNG."""
    # Keep the order fixed so a given seed always yields the same tip.
    tips = (
        "Be concise.",
        "Focus on the task definition.",
        "Use the provided examples as guidance.",
        "Avoid unnecessary verbosity.",
    )
    return rng.choice(tips)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def choose(items: Sequence[Any], rng: random.Random | None = None) -> Any:
    """Pick one element of `items`, using `rng` if given, else the module RNG."""
    picker = rng if rng is not None else random
    return picker.choice(items)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
# ---------------------------
|
|
120
|
-
# Evaluator
|
|
121
|
-
# ---------------------------
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@dataclass
class EvalResult:
    """Aggregate evaluation outcome.

    Attributes:
        score: Mean of `subscores` (0.0 when the dataset was empty).
        subscores: Per-example metric values, in dataset order.
    """

    score: float
    subscores: list[float]


def evaluate_program(
    program: PredictProgram, dataset: Sequence[tuple[Any, Any]], metric: Callable[[Any, Any], float]
) -> EvalResult:
    """Run `program` over `dataset` and average `metric` across examples."""
    per_example = [metric(program.run(x), y) for x, y in dataset]
    mean = float(sum(per_example)) / max(1, len(per_example))
    return EvalResult(score=mean, subscores=per_example)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
# ---------------------------
|
|
141
|
-
# MIPROv2 compile
|
|
142
|
-
# ---------------------------
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
def mipro_v2_compile(
    student: PredictProgram,
    trainset: Sequence[tuple[Any, Any]],
    valset: Sequence[tuple[Any, Any]],
    metric: Callable[[Any, Any], float],
    *,
    prompt_model: Any,
    task_model: Any,
    max_bootstrapped_demos: int = 8,
    max_labeled_demos: int = 4,
    num_candidates: int = 8,
    num_trials: int = 20,
    minibatch: bool = True,
    minibatch_size: int = 16,
    minibatch_full_eval_steps: int = 5,
    seed: int = 0,
    auto: str = "light",
    program_aware: bool = True,
    data_aware: bool = True,
    tip_aware: bool = True,
    fewshot_aware: bool = True,
) -> tuple[PredictProgram, list[dict[str, Any]]]:
    """MIPROv2-style optimizer.

    Arguments mirror the DSPy pseudocode but remain provider-agnostic. The
    `prompt_model` must expose `generate_instructions(ctx, k)`; the `student`
    program must implement the `PredictProgram` protocol.

    Returns:
        `(best_program, records)` where `records` logs every trial's
        configuration and score.

    Notes:
        * `auto` is currently unused and kept for interface compatibility.
        * All randomness is driven by `seed`, so runs are reproducible.
    """
    rng = random.Random(seed)
    program = student.deepcopy()

    def _flatten(ds: dict[str, Any]) -> list[tuple[Any, Any]]:
        # A demo set combines bootstrapped (model-verified) and labeled demos.
        return list(ds.get("boot", [])) + list(ds.get("labeled", []))

    def _record(trial: int, evaluation: str, score: float, theta: dict[str, Any]) -> dict[str, Any]:
        # Uniform shape for every trial-log entry.
        return {
            "trial": trial,
            "evaluation": evaluation,
            "score": score,
            "intervention": {
                "instructions": theta.get("instructions"),
                "demo_set": theta.get("demo_set"),
            },
        }

    # Step 1: bootstrap few-shot example candidates.
    demo_candidates: list[dict[str, Any]] = []
    # BUG FIX: cap attempts so a metric that never returns 1 cannot hang the
    # bootstrap loop forever; also guard against an empty trainset, where
    # rng.choice would raise IndexError.
    max_attempts = max(1, max_bootstrapped_demos) * 50
    for _ in range(num_candidates):
        boot: list[tuple[Any, Any]] = []
        attempts = 0
        # Collect bootstrapped, self-consistent demos.
        while trainset and len(boot) < max_bootstrapped_demos and attempts < max_attempts:
            attempts += 1
            x, y = rng.choice(trainset)
            yhat = program.run(x, model=task_model)
            if metric(yhat, y) == 1:  # perfect match
                boot.append((x, y))
        labeled = rng.sample(list(trainset), k=min(max_labeled_demos, len(trainset)))
        demo_candidates.append({"boot": boot, "labeled": labeled})

    # Step 2: propose instruction candidates per predictor.
    instr_candidates: dict[str, list[str]] = {}
    for pred in program.predictors or ["predictor"]:
        ctx: dict[str, Any] = {}
        if data_aware:
            ctx["dataset_summary"] = summarize_dataset(trainset)
        if program_aware:
            ctx["program_summary"] = summarize_program(program)
        if fewshot_aware and demo_candidates:
            ctx["examples"] = choose(demo_candidates, rng)
        if tip_aware:
            ctx["tip"] = random_tip(rng)
        cand = prompt_model.generate_instructions(ctx, k=num_candidates)
        instr_candidates[pred] = list(cand)

    # Step 3: Bayesian-optimization-like search (random proposer placeholder).
    history: list[tuple[dict[str, Any], float]] = []
    records: list[dict[str, Any]] = []
    best_score = -1.0
    best_cfg: dict[str, Any] | None = None

    def propose(history_: list[tuple[dict[str, Any], float]]) -> dict[str, Any]:
        # Placeholder: randomly sample from the cartesian product.
        instructions = {pred: choose(instr_candidates[pred], rng) for pred in instr_candidates}
        demos = choose(demo_candidates, rng) if demo_candidates else None
        return {"instructions": instructions, "demo_set": demos}

    for t in range(1, num_trials + 1):
        theta = propose(history)
        program_t = program.with_instructions(theta["instructions"])
        if theta.get("demo_set") is not None:
            program_t = program_t.with_demos(_flatten(theta["demo_set"]))

        # BUG FIX: use the seeded `rng` (not the module-level `random`) so the
        # minibatch selection is reproducible for a given `seed`.
        batch = (
            valset
            if not minibatch
            else rng.sample(list(valset), k=min(minibatch_size, len(valset)))
        )
        batch_res = evaluate_program(program_t, batch, metric)
        s_t = batch_res.score
        history.append((theta, s_t))
        records.append(_record(t, "batch" if minibatch else "full", s_t, theta))

        # Periodically (or on every trial when not minibatching) evaluate on
        # the full valset; only full evaluations update the incumbent best.
        if (not minibatch) or (t % max(1, minibatch_full_eval_steps) == 0):
            full_res = evaluate_program(program_t, valset, metric)
            s_full = full_res.score
            if s_full > best_score:
                best_score = s_full
                best_cfg = theta
            records.append(_record(t, "full", s_full, theta))

    if best_cfg is None:
        # No full evaluation ever ran (e.g. zero trials): return the copy unchanged.
        return program, records

    best_program = program.with_instructions(best_cfg["instructions"])
    if best_cfg.get("demo_set") is not None:
        best_program = best_program.with_demos(_flatten(best_cfg["demo_set"]))
    return best_program, records
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
# Public API. EvalResult is exported because it is the return type of the
# exported `evaluate_program`.
__all__ = [
    "EvalResult",
    "PredictProgram",
    "ProgramAdapter",
    "evaluate_program",
    "mipro_v2_compile",
]
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
class ExampleTwoStepDag:
    """Placeholder for a two-step DAG example program.

    A -> B
    """
|