synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +56 -26
- examples/swe/task_app/hosted/rollout.py +42 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +799 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/models/supported.py +1 -0
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +48 -59
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +1922 -190
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +9 -9
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +24 -5
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +138 -39
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +56 -0
- synth_ai/task/rubrics/loaders.py +152 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +413 -6
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +66 -43
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,862 @@
|
|
|
1
|
+
from io import BytesIO
|
|
2
|
+
from PIL import Image
|
|
3
|
+
import os
|
|
4
|
+
import base64
|
|
5
|
+
import random
|
|
6
|
+
import time
|
|
7
|
+
import logging
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Union, List, Dict, Any, Optional
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
# Set up module logging
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# Import LLM logger
|
|
16
|
+
from utils.llm_logger import log_llm_interaction, log_llm_error
|
|
17
|
+
|
|
18
|
+
# Define the retry decorator with exponential backoff
|
|
19
|
+
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = (Exception,),
    max_delay: Optional[float] = None,
):
    """Wrap ``func`` so failures matching ``errors`` are retried with exponential backoff.

    Used as a bare decorator (``@retry_with_exponential_backoff``), so ``func``
    is the only positional argument; the tuning knobs are keyword arguments.

    Args:
        func: Callable to wrap.
        initial_delay: Starting delay in seconds. It is scaled by the backoff
            factor *before* the first sleep.
        exponential_base: Multiplicative growth factor applied on every retry.
        jitter: When true, each growth factor is additionally multiplied by a
            random value in ``[1, 2)``.
        max_retries: Number of retries allowed after the first failed attempt,
            i.e. at most ``max_retries + 1`` calls are made.
        errors: Exception types that trigger a retry. Any other exception
            propagates immediately.
        max_delay: Optional upper bound in seconds on any single sleep. ``None``
            (the default) leaves the delay uncapped.

    Returns:
        The wrapped callable, with ``func``'s name and docstring preserved.

    Raises:
        Exception: When all retries are exhausted. The last underlying error
            is attached as ``__cause__``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except errors as e:
                num_retries += 1
                if num_retries > max_retries:
                    # Chain the original failure so callers and logs can see
                    # why the retries ran out.
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    ) from e
                # Grow the delay exponentially, with optional random jitter.
                delay *= exponential_base * (1 + jitter * random.random())
                if max_delay is not None:
                    delay = min(delay, max_delay)
                time.sleep(delay)

    return wrapper
|
|
44
|
+
|
|
45
|
+
class VLMBackend(ABC):
    """Common interface that every vision-language model backend implements.

    A concrete backend must answer two kinds of request: an image paired with
    a text prompt, and a text-only prompt. Each returns the model's reply.
    """

    @abstractmethod
    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Send ``img`` together with the ``text`` prompt and return the reply.

        ``module_name`` is a caller label that subclasses may use when logging.
        """
        ...

    @abstractmethod
    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Send the text-only prompt ``text`` and return the reply.

        ``module_name`` is a caller label that subclasses may use when logging.
        """
        ...
|
|
57
|
+
|
|
58
|
+
class OpenAIBackend(VLMBackend):
    """VLM backend for the OpenAI Chat Completions API.

    Requires the ``openai`` package and the ``OPENAI_API_KEY`` environment
    variable. Each successful call is recorded with ``log_llm_interaction``.
    Each failed call is recorded with ``log_llm_error`` and then re-raised.
    """

    def __init__(self, model_name: str, **kwargs):
        """Create an OpenAI client for ``model_name``.

        Extra keyword arguments are accepted for signature compatibility with
        other backends and are ignored.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If ``OPENAI_API_KEY`` is unset or empty.
        """
        try:
            import openai
            from openai import OpenAI
        except ImportError as err:
            raise ImportError("OpenAI package not found. Install with: pip install openai") from err

        self.model_name = model_name
        self.api_key = os.getenv("OPENAI_API_KEY")

        if not self.api_key:
            raise ValueError("Error: OpenAI API key is missing! Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=self.api_key)
        # NOTE(review): this tuple is not consulted by the retry decorator on
        # _call_completion, which uses the decorator's own default filter.
        self.errors = (openai.RateLimitError,)

    @retry_with_exponential_backoff
    def _call_completion(self, messages):
        """Call ``chat.completions.create`` with exponential-backoff retries."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )

    @staticmethod
    def _encode_image(img: Union[Image.Image, np.ndarray]) -> str:
        """Return ``img`` encoded as PNG, then base64, as an ASCII string.

        Accepts a PIL image (anything with ``convert``) or an array-like with a
        ``shape`` attribute, which is converted via ``Image.fromarray``.

        Raises:
            ValueError: For any other input type.
        """
        if hasattr(img, 'convert'):  # PIL Image
            image = img
        elif hasattr(img, 'shape'):  # numpy array
            image = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def _extract_token_usage(response) -> Dict[str, Any]:
        """Return prompt, completion and total token counts, or ``{}`` if unavailable."""
        # ``usage`` can exist but be None. A plain hasattr() check would then
        # raise AttributeError and turn a successful call into a logged error.
        usage = getattr(response, "usage", None)
        if usage is None:
            return {}
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    def _run_query(self, messages, text: str, module_name: str, has_image: bool, start_time: float) -> str:
        """Send ``messages``, record the outcome, and return the reply content.

        ``start_time`` is supplied by the caller so the logged duration also
        covers any preprocessing done before this call, such as image encoding.
        """
        try:
            response = self._call_completion(messages)
            result = response.choices[0].message.content
            duration = time.time() - start_time
            token_usage = self._extract_token_usage(response)

            log_llm_interaction(
                interaction_type=f"openai_{module_name}",
                prompt=text,
                response=result,
                duration=duration,
                metadata={"model": self.model_name, "backend": "openai", "has_image": has_image, "token_usage": token_usage},
                model_info={"model": self.model_name, "backend": "openai"},
            )
            return result
        except Exception as e:
            duration = time.time() - start_time
            log_llm_error(
                interaction_type=f"openai_{module_name}",
                prompt=text,
                error=str(e),
                metadata={"model": self.model_name, "backend": "openai", "duration": duration, "has_image": has_image},
            )
            logger.error("OpenAI API error: %s", e)
            raise

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Send an image plus text prompt to the OpenAI API and return the reply.

        Raises:
            ValueError: If ``img`` is neither a PIL image nor an array-like.
            Exception: Any API error, re-raised after being logged.
        """
        start_time = time.time()
        image_base64 = self._encode_image(img)

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }]
        return self._run_query(messages, text, module_name, has_image=True, start_time=start_time)

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Send a text-only prompt to the OpenAI API and return the reply.

        Raises:
            Exception: Any API error, re-raised after being logged.
        """
        start_time = time.time()
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }]
        return self._run_query(messages, text, module_name, has_image=False, start_time=start_time)
|
|
189
|
+
|
|
190
|
+
class OpenRouterBackend(VLMBackend):
    """VLM backend for OpenRouter's OpenAI-compatible API.

    Uses the ``openai`` client pointed at ``https://openrouter.ai/api/v1`` and
    authenticates with the ``OPENROUTER_API_KEY`` environment variable. The
    prompt and the reply, both truncated, are written to the module logger.
    """

    # Maximum characters of prompt / reply echoed to the log.
    _PROMPT_PREVIEW_CHARS = 2000
    _RESPONSE_PREVIEW_CHARS = 1000

    def __init__(self, model_name: str, **kwargs):
        """Create an OpenRouter client for ``model_name``.

        Extra keyword arguments are accepted for signature compatibility with
        other backends and are ignored.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If ``OPENROUTER_API_KEY`` is unset or empty.
        """
        try:
            from openai import OpenAI
        except ImportError as err:
            raise ImportError("OpenAI package not found. Install with: pip install openai") from err

        self.model_name = model_name
        self.api_key = os.getenv("OPENROUTER_API_KEY")

        if not self.api_key:
            raise ValueError("Error: OpenRouter API key is missing! Set OPENROUTER_API_KEY environment variable.")

        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=self.api_key,
        )

    @retry_with_exponential_backoff
    def _call_completion(self, messages):
        """Call ``chat.completions.create`` with exponential-backoff retries."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )

    @staticmethod
    def _preview(value: Optional[str], limit: int) -> str:
        """Truncate ``value`` to ``limit`` characters for logging, appending "..." when cut.

        ``None`` is rendered as ``"<no content>"`` so that a reply without text
        content does not crash the logging code.
        """
        if value is None:
            return "<no content>"
        return value[:limit] + "..." if len(value) > limit else value

    def _run_query(self, messages, text: str, module_name: str, query_kind: str) -> str:
        """Log the prompt, send ``messages``, log the reply, and return its content.

        The returned content may be ``None`` if the API returned no text.
        """
        logger.info("[%s] OPENROUTER VLM %s QUERY:", module_name, query_kind)
        logger.info("[%s] PROMPT: %s", module_name, self._preview(text, self._PROMPT_PREVIEW_CHARS))

        response = self._call_completion(messages)
        result = response.choices[0].message.content

        logger.info("[%s] RESPONSE: %s", module_name, self._preview(result, self._RESPONSE_PREVIEW_CHARS))
        logger.info("[%s] ---", module_name)
        return result

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Send an image plus text prompt to OpenRouter and return the reply.

        Raises:
            ValueError: If ``img`` is neither a PIL image nor an array-like.
        """
        if hasattr(img, 'convert'):  # PIL Image
            image = img
        elif hasattr(img, 'shape'):  # numpy array
            image = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }]
        return self._run_query(messages, text, module_name, "IMAGE")

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Send a text-only prompt to OpenRouter and return the reply."""
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }]
        return self._run_query(messages, text, module_name, "TEXT")
|
|
276
|
+
|
|
277
|
+
class LocalHuggingFaceBackend(VLMBackend):
    """Local HuggingFace transformers backend with optional bitsandbytes 4-bit quantization."""

    def __init__(self, model_name: str, device: str = "auto", load_in_4bit: bool = False, **kwargs):
        """Load the processor and image-text-to-text model for ``model_name``.

        Args:
            model_name: HuggingFace model id or local path, loaded with
                ``AutoProcessor`` and ``AutoModelForImageTextToText``.
            device: Passed as ``device_map``. When it is not ``"auto"`` and the
                model is not 4-bit quantized, the model is also moved there
                with ``.to(device)``.
            load_in_4bit: Load weights with a bitsandbytes NF4 config
                (double quantization, float16 compute dtype).
            **kwargs: Accepted for constructor compatibility with the other
                backends; not used.

        Raises:
            ImportError: If torch, transformers or Pillow cannot be imported.
            Exception: Any error while loading the processor/model is logged
                and re-raised.
        """
        try:
            import torch
            from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
            from PIL import Image  # noqa: F401 -- imported only to verify Pillow is installed
        except ImportError as e:
            raise ImportError(f"Required packages not found. Install with: pip install torch transformers bitsandbytes accelerate. Error: {e}") from e

        self.model_name = model_name
        self.device = device
        # Keep a handle on torch so later methods can use it without a module-level import.
        self.torch = torch

        logger.info(f"Loading local VLM model: {model_name}")

        # Configure quantization if requested
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            logger.info("Using 4-bit quantization with bitsandbytes")

        # Load processor and model
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map=device,  # either "auto" or an explicit device string
                # Quantized weights get their dtype from the bitsandbytes config.
                torch_dtype=torch.float16 if not load_in_4bit else None,
                # NOTE(review): trust_remote_code executes Python shipped with the
                # model repository; only point this at repositories you trust.
                trust_remote_code=True,
            )

            if device != "auto" and not load_in_4bit:
                self.model = self.model.to(device)

            logger.info(f"Model loaded successfully on {device}")

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            raise

    def _model_device(self):
        """Return the device that holds the model weights (also for wrapped models)."""
        if hasattr(self.model, 'device'):
            return self.model.device
        if hasattr(self.model, 'module') and hasattr(self.model.module, 'device'):
            return self.model.module.device
        return next(self.model.parameters()).device

    def _decode_completion(self, generated_ids, model_inputs: Dict[str, Any], text: str) -> str:
        """Decode the first generated sequence, returning only the newly generated text.

        ``generate`` on decoder-only models returns prompt + completion tokens.
        When the output starts with the prompt's ``input_ids``, those tokens are
        sliced off before decoding. Otherwise the method falls back to removing
        the prompt text from the decoded string.
        """
        sequence = generated_ids[0]
        prompt_ids = model_inputs.get("input_ids")
        if prompt_ids is not None and hasattr(prompt_ids, "shape"):
            prompt_len = prompt_ids.shape[-1]
            if prompt_len <= sequence.shape[-1] and self.torch.equal(sequence[:prompt_len], prompt_ids[0]):
                return self.processor.decode(sequence[prompt_len:], skip_special_tokens=True).strip()

        # Fallback: text-based prompt removal. This is fragile because chat-template
        # residue or re-tokenisation can make `text` absent from the decoded output.
        generated_text = self.processor.decode(sequence, skip_special_tokens=True)
        if text in generated_text:
            return generated_text.split(text)[-1].strip()
        return generated_text.strip()

    def _generate_response(self, inputs: Dict[str, Any], text: str, module_name: str) -> str:
        """Generate a reply from processor ``inputs`` and return the decoded text.

        Args:
            inputs: Processor output; values with a ``.to`` method are moved to
                the model's device before generation.
            text: The user prompt, used for logging and fallback prompt removal.
            module_name: Tag prefixed to every log line for this call.

        Raises:
            Exception: Any generation/decoding error is logged and re-raised.
        """
        try:
            # Log the prompt
            prompt_preview = text[:2000] + "..." if len(text) > 2000 else text
            logger.info(f"[{module_name}] LOCAL HF VLM QUERY:")
            logger.info(f"[{module_name}] PROMPT: {prompt_preview}")

            with self.torch.no_grad():
                device = self._model_device()
                # Move tensors to the model device; pass other values through unchanged.
                inputs_on_device = {
                    key: value.to(device) if hasattr(value, 'to') else value
                    for key, value in inputs.items()
                }

                generated_ids = self.model.generate(
                    **inputs_on_device,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            result = self._decode_completion(generated_ids, inputs_on_device, text)

            # Log the response
            result_preview = result[:1000] + "..." if len(result) > 1000 else result
            logger.info(f"[{module_name}] RESPONSE: {result_preview}")
            logger.info(f"[{module_name}] ---")

            return result

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Process an image and text prompt using the local HuggingFace model.

        Raises:
            ValueError: If ``img`` is neither a PIL image nor an array-like with ``.shape``.
        """
        # Handle both PIL Images and numpy arrays
        if hasattr(img, 'convert'):  # It's a PIL Image
            image = img
        elif hasattr(img, 'shape'):  # It's a numpy array
            image = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

        # Prepare messages with proper chat template format
        messages = [
            {"role": "user",
             "content": [
                 {"type": "image", "image": image},
                 {"type": "text", "text": text},
             ]}
        ]
        formatted_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=formatted_text, images=image, return_tensors="pt")

        return self._generate_response(inputs, text, module_name)

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Process a text-only prompt using the local HuggingFace model."""
        # Use the same typed content-part format as get_query. Some multimodal
        # chat templates iterate content parts by "type" and silently drop a
        # plain-string content, which would yield an empty user turn.
        messages = [
            {"role": "user", "content": [{"type": "text", "text": text}]}
        ]
        formatted_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=formatted_text, return_tensors="pt")

        return self._generate_response(inputs, text, module_name)
|
|
414
|
+
|
|
415
|
+
class LegacyOllamaBackend(VLMBackend):
    """Legacy Ollama backend for backward compatibility.

    Sends chat-completion requests through the ``openai`` client to the
    OpenAI-compatible endpoint ``http://localhost:{port}/v1``.
    """

    def __init__(self, model_name: str, port: int = 8010, **kwargs):
        """Create the client for the local server.

        Args:
            model_name: Model name sent with every request.
            port: Local port of the OpenAI-compatible server.
            **kwargs: Accepted for constructor compatibility; not used.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError("OpenAI package not found. Install with: pip install openai") from e

        self.model_name = model_name
        self.port = port
        # An empty API key is sent; presumably the local server does not
        # authenticate requests -- confirm for your deployment.
        self.client = OpenAI(api_key='', base_url=f'http://localhost:{port}/v1')

    @retry_with_exponential_backoff
    def _call_completion(self, messages):
        """Calls the completions.create method with exponential backoff."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )

    def _send_and_log(self, messages, text: str, module_name: str, header: str) -> str:
        """Log the prompt, send ``messages``, then log and return the reply text.

        Returns ``""`` when the reply has no content (``message.content is None``),
        which previously crashed while building the log preview.
        """
        # Log the prompt
        prompt_preview = text[:2000] + "..." if len(text) > 2000 else text
        logger.info(f"[{module_name}] {header}")
        logger.info(f"[{module_name}] PROMPT: {prompt_preview}")

        response = self._call_completion(messages)
        result = response.choices[0].message.content or ""

        # Log the response
        result_preview = result[:1000] + "..." if len(result) > 1000 else result
        logger.info(f"[{module_name}] RESPONSE: {result_preview}")
        logger.info(f"[{module_name}] ---")

        return result

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Process an image and text prompt using legacy Ollama backend.

        Raises:
            ValueError: If ``img`` is neither a PIL image nor an array-like with ``.shape``.
        """
        # Handle both PIL Images and numpy arrays
        if hasattr(img, 'convert'):  # It's a PIL Image
            image = img
        elif hasattr(img, 'shape'):  # It's a numpy array
            image = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                # The MIME type must match the PNG encoding used above.
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }]

        return self._send_and_log(messages, text, module_name, "OLLAMA VLM IMAGE QUERY:")

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Process a text-only prompt using legacy Ollama backend."""
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }]

        return self._send_and_log(messages, text, module_name, "OLLAMA VLM TEXT QUERY:")
|
|
494
|
+
|
|
495
|
+
class VertexBackend(VLMBackend):
    """Google Gemini models served through Vertex AI using the ``google-genai`` SDK.

    Optional constructor keyword arguments:
        project: Google Cloud project id (defaults to ``DEFAULT_PROJECT``).
        location: Vertex AI region (defaults to ``DEFAULT_LOCATION``).
    """

    # Values that were previously hard-coded; kept as defaults for compatibility.
    DEFAULT_PROJECT = 'pokeagent-011'
    DEFAULT_LOCATION = 'us-central1'
    DEFAULT_MODEL = 'gemini-2.5-flash'

    def __init__(self, model_name: str, **kwargs):
        """Create a Vertex AI ``genai.Client``.

        Args:
            model_name: Gemini model id used for every request. Falls back to
                ``DEFAULT_MODEL`` when empty.
            **kwargs: ``project`` / ``location`` override the GCP project and
                region; other keys are ignored.

        Raises:
            ImportError: If ``google.genai`` cannot be imported.
        """
        try:
            from google import genai
        except ImportError as e:
            # `from google import genai` is provided by the `google-genai` distribution.
            raise ImportError("Google GenAI package not found. Install with: pip install google-genai") from e

        self.model_name = model_name
        self.project = kwargs.get('project') or self.DEFAULT_PROJECT
        self.location = kwargs.get('location') or self.DEFAULT_LOCATION

        # Initialize the client
        self.client = genai.Client(
            vertexai=True,
            project=self.project,
            location=self.location,
        )
        self.genai = genai

        logger.info(
            f"Vertex backend initialized with model: {self.model_name or self.DEFAULT_MODEL} "
            f"(project={self.project}, location={self.location})"
        )

    def _prepare_image(self, img: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Prepare image for Gemini API (PIL images pass through, arrays are converted)."""
        if hasattr(img, 'convert'):  # It's a PIL Image
            return img
        elif hasattr(img, 'shape'):  # It's a numpy array
            return Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

    @staticmethod
    def _hit_safety_filter(response) -> bool:
        """Return True when the first candidate's finish_reason equals 12.

        NOTE(review): google-genai exposes ``finish_reason`` as an enum; confirm
        that comparing it against the integer 12 can ever match.
        """
        if not (hasattr(response, 'candidates') and response.candidates):
            return False
        candidate = response.candidates[0]
        return hasattr(candidate, 'finish_reason') and candidate.finish_reason == 12

    @retry_with_exponential_backoff
    def _call_generate_content(self, content_parts):
        """Calls the generate_content method with exponential backoff."""
        # Use the configured model instead of an unconditional hard-coded id.
        return self.client.models.generate_content(
            model=self.model_name or self.DEFAULT_MODEL,
            contents=content_parts,
        )

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Process an image and text prompt using Gemini on Vertex AI.

        On a safety block or any error, falls back to ``get_text_query``. If
        that fallback also raises, the original error is re-raised.
        """
        try:
            image = self._prepare_image(img)

            # Prepare content for Gemini
            content_parts = [text, image]

            # Log the prompt
            prompt_preview = text[:2000] + "..." if len(text) > 2000 else text
            logger.info(f"[{module_name}] GEMINI VLM IMAGE QUERY:")
            logger.info(f"[{module_name}] PROMPT: {prompt_preview}")

            response = self._call_generate_content(content_parts)

            if self._hit_safety_filter(response):
                logger.warning(f"[{module_name}] Gemini safety filter triggered (finish_reason=12). Trying text-only fallback.")
                return self.get_text_query(text, module_name)

            result = response.text

            # Log the response
            result_preview = result[:1000] + "..." if len(result) > 1000 else result
            logger.info(f"[{module_name}] RESPONSE: {result_preview}")
            logger.info(f"[{module_name}] ---")

            return result

        except Exception as e:
            logger.error(f"Error in Gemini image query: {e}")
            # Try text-only fallback for any Gemini error
            try:
                logger.info(f"[{module_name}] Attempting text-only fallback due to error: {e}")
                return self.get_text_query(text, module_name)
            except Exception as fallback_error:
                logger.error(f"[{module_name}] Text-only fallback also failed: {fallback_error}")
                raise e

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Process a text-only prompt using Gemini on Vertex AI.

        Never raises for API errors. On a safety block or any exception it
        returns a fixed default reply instead.
        """
        try:
            # Log the prompt
            prompt_preview = text[:2000] + "..." if len(text) > 2000 else text
            logger.info(f"[{module_name}] GEMINI VLM TEXT QUERY:")
            logger.info(f"[{module_name}] PROMPT: {prompt_preview}")

            response = self._call_generate_content([text])

            if self._hit_safety_filter(response):
                logger.warning(f"[{module_name}] Gemini safety filter triggered (finish_reason=12). Returning default response.")
                return "I cannot analyze this content due to safety restrictions. I'll proceed with a basic action: press 'A' to continue."

            result = response.text

            # Log the response
            result_preview = result[:1000] + "..." if len(result) > 1000 else result
            logger.info(f"[{module_name}] RESPONSE: {result_preview}")
            logger.info(f"[{module_name}] ---")

            return result

        except Exception as e:
            logger.error(f"Error in Gemini text query: {e}")
            # Return a safe default response
            logger.warning(f"[{module_name}] Returning default response due to error: {e}")
            return "I encountered an error processing the request. I'll proceed with a basic action: press 'A' to continue."
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
class GeminiBackend(VLMBackend):
    """Backend for the public Gemini API through the ``google.generativeai`` SDK.

    The API key is read from ``GEMINI_API_KEY`` (or ``GOOGLE_API_KEY``). Each
    successful call is recorded via ``log_llm_interaction`` together with its
    duration and token usage.
    """

    def __init__(self, model_name: str, **kwargs):
        """Configure the SDK from the environment and build the generative model.

        Raises:
            ImportError: ``google.generativeai`` is not installed.
            ValueError: Neither API-key environment variable is set.
        """
        try:
            import google.generativeai as genai
        except ImportError:
            raise ImportError("Google Generative AI package not found. Install with: pip install google-generativeai")

        self.model_name = model_name
        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("Error: Gemini API key is missing! Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")

        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(model_name)
        self.genai = genai

        logger.info(f"Gemini backend initialized with model: {model_name}")

    def _prepare_image(self, img: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Return ``img`` as a PIL image; numpy-style arrays are converted, other types rejected."""
        if hasattr(img, 'convert'):
            return img
        if hasattr(img, 'shape'):
            return Image.fromarray(img)
        raise ValueError(f"Unsupported image type: {type(img)}")

    @retry_with_exponential_backoff
    def _call_generate_content(self, content_parts):
        """Send ``content_parts`` to the model; retried with exponential backoff by the decorator."""
        reply = self.model.generate_content(content_parts)
        reply.resolve()
        return reply

    @staticmethod
    def _clip(value: str, limit: int) -> str:
        """Cut ``value`` to ``limit`` characters, marking the cut with '...'."""
        return value if len(value) <= limit else value[:limit] + "..."

    @staticmethod
    def _blocked_by_safety(response) -> bool:
        """True when the first candidate reports finish_reason 12 (handled as a safety block)."""
        if not (hasattr(response, 'candidates') and response.candidates):
            return False
        first = response.candidates[0]
        return hasattr(first, 'finish_reason') and first.finish_reason == 12

    @staticmethod
    def _usage_of(response) -> dict:
        """Collect prompt/completion/total token counts from ``usage_metadata`` (0 when missing)."""
        if not hasattr(response, 'usage_metadata'):
            return {}
        meta = response.usage_metadata
        return {
            "prompt_tokens": getattr(meta, 'prompt_token_count', 0),
            "completion_tokens": getattr(meta, 'candidates_token_count', 0),
            "total_tokens": getattr(meta, 'total_token_count', 0),
        }

    def _complete(self, response, text: str, module_name: str, started: float, has_image: bool) -> str:
        """Extract the reply text, record the interaction, log a preview and return the text."""
        answer = response.text
        elapsed = time.time() - started
        usage = self._usage_of(response)

        log_llm_interaction(
            interaction_type=f"gemini_{module_name}",
            prompt=text,
            response=answer,
            duration=elapsed,
            metadata={"model": self.model_name, "backend": "gemini", "has_image": has_image, "token_usage": usage},
            model_info={"model": self.model_name, "backend": "gemini"},
        )

        logger.info(f"[{module_name}] RESPONSE: {self._clip(answer, 1000)}")
        logger.info(f"[{module_name}] ---")
        return answer

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Ask the model about ``img`` with prompt ``text``.

        A safety block or any failure triggers a text-only retry through
        ``get_text_query``. If that retry raises, the original error propagates.
        """
        started = time.time()
        try:
            parts = [text, self._prepare_image(img)]

            preview = self._clip(text, 2000)
            logger.info(f"[{module_name}] GEMINI VLM IMAGE QUERY:")
            logger.info(f"[{module_name}] PROMPT: {preview}")

            response = self._call_generate_content(parts)

            if self._blocked_by_safety(response):
                logger.warning(f"[{module_name}] Gemini safety filter triggered (finish_reason=12). Trying text-only fallback.")
                return self.get_text_query(text, module_name)

            return self._complete(response, text, module_name, started, has_image=True)

        except Exception as e:
            logger.error(f"Error in Gemini image query: {e}")
            try:
                logger.info(f"[{module_name}] Attempting text-only fallback due to error: {e}")
                return self.get_text_query(text, module_name)
            except Exception as fallback_error:
                logger.error(f"[{module_name}] Text-only fallback also failed: {fallback_error}")
                raise e

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Ask the model a text-only prompt.

        Never propagates errors: a safety block or any exception yields a
        fixed default reply instead.
        """
        started = time.time()
        try:
            preview = self._clip(text, 2000)
            logger.info(f"[{module_name}] GEMINI VLM TEXT QUERY:")
            logger.info(f"[{module_name}] PROMPT: {preview}")

            response = self._call_generate_content([text])

            if self._blocked_by_safety(response):
                logger.warning(f"[{module_name}] Gemini safety filter triggered (finish_reason=12). Returning default response.")
                return "I cannot analyze this content due to safety restrictions. I'll proceed with a basic action: press 'A' to continue."

            return self._complete(response, text, module_name, started, has_image=False)

        except Exception as e:
            logger.error(f"Error in Gemini text query: {e}")
            logger.warning(f"[{module_name}] Returning default response due to error: {e}")
            return "I encountered an error processing the request. I'll proceed with a basic action: press 'A' to continue."
|
|
772
|
+
|
|
773
|
+
class VLM:
    """Main VLM class that supports multiple backends.

    Picks a backend class from ``BACKENDS`` and forwards image/text queries to
    it. Backend exceptions are recorded with ``log_llm_error`` and re-raised.
    """

    BACKENDS = {
        'openai': OpenAIBackend,
        'openrouter': OpenRouterBackend,
        'local': LocalHuggingFaceBackend,
        'gemini': GeminiBackend,
        'ollama': LegacyOllamaBackend,  # Legacy support
        'vertex': VertexBackend,  # Added Vertex backend
    }

    def __init__(self, model_name: str, backend: str = 'openai', port: int = 8010, **kwargs):
        """
        Initialize VLM with specified backend

        Args:
            model_name: Name of the model to use
            backend: Backend type, case-insensitive: 'openai', 'openrouter',
                'local', 'gemini', 'ollama', 'vertex', or 'auto' to infer the
                backend from ``model_name``
            port: Port for Ollama backend (legacy)
            **kwargs: Additional arguments passed to backend

        Raises:
            ValueError: If the resolved backend name is not in ``BACKENDS``.
        """
        self.model_name = model_name
        self.backend_type = backend.lower()

        # Auto-detect on the normalised name so 'AUTO' / 'Auto' work too;
        # comparing the raw argument made those raise "Unsupported backend".
        if self.backend_type == 'auto':
            self.backend_type = self._auto_detect_backend(model_name)

        if self.backend_type not in self.BACKENDS:
            raise ValueError(f"Unsupported backend: {self.backend_type}. Available: {list(self.BACKENDS.keys())}")

        # Initialize the appropriate backend
        backend_class = self.BACKENDS[self.backend_type]

        # Only the legacy Ollama backend takes a port parameter
        if self.backend_type == 'ollama':
            self.backend = backend_class(model_name, port=port, **kwargs)
        else:
            self.backend = backend_class(model_name, **kwargs)

        logger.info(f"VLM initialized with {self.backend_type} backend using model: {model_name}")

    def _auto_detect_backend(self, model_name: str) -> str:
        """Guess a backend key from substrings of ``model_name`` (defaults to 'openai')."""
        model_lower = model_name.lower()

        if any(x in model_lower for x in ['gpt', 'o4-mini', 'o3', 'claude']):
            return 'openai'
        elif any(x in model_lower for x in ['gemini', 'palm']):
            return 'gemini'
        elif any(x in model_lower for x in ['llama', 'mistral', 'qwen', 'phi']):
            return 'local'
        else:
            # Default to OpenAI for unknown models
            return 'openai'

    def _record_error(self, text: str, error: Exception, module_name: str, has_image: bool) -> None:
        """Report a backend failure through ``log_llm_error``.

        The duration is logged as 0; the backend itself tracks real timing.
        """
        backend_name = self.backend.__class__.__name__
        log_llm_error(
            interaction_type=f"{backend_name.lower()}_{module_name}",
            prompt=text,
            error=str(error),
            metadata={"model": self.model_name, "backend": backend_name, "duration": 0, "has_image": has_image},
        )

    def get_query(self, img: Union[Image.Image, np.ndarray], text: str, module_name: str = "Unknown") -> str:
        """Process an image and text prompt; backend errors are logged and re-raised."""
        try:
            # Backend handles its own logging, so we don't duplicate it here
            return self.backend.get_query(img, text, module_name)
        except Exception as e:
            self._record_error(text, e, module_name, has_image=True)
            raise

    def get_text_query(self, text: str, module_name: str = "Unknown") -> str:
        """Process a text-only prompt; backend errors are logged and re-raised."""
        try:
            # Backend handles its own logging, so we don't duplicate it here
            return self.backend.get_text_query(text, module_name)
        except Exception as e:
            self._record_error(text, e, module_name, has_image=False)
            raise
|