synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +20 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -262
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev5.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Environment implementations."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# wraps hosted app
|
|
@@ -0,0 +1,522 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import logging
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
|
11
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
|
12
|
+
|
|
13
|
+
from ...utils import convert_numpy_to_python
|
|
14
|
+
from .shared import CRAFTER_ACTIONS, _format_semantic_map_view
|
|
15
|
+
from .tools import TOOLS_SCHEMA
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _encode_image_to_base64(image_array: Any) -> dict[str, Any] | None:
|
|
21
|
+
"""Encode an RGB ndarray into a base64 PNG payload with metadata."""
|
|
22
|
+
|
|
23
|
+
if not isinstance(image_array, np.ndarray):
|
|
24
|
+
return None
|
|
25
|
+
if image_array.ndim != 3 or image_array.shape[-1] not in (1, 3, 4):
|
|
26
|
+
return None
|
|
27
|
+
try:
|
|
28
|
+
# Ensure uint8 for PIL compatibility
|
|
29
|
+
array_uint8 = (
|
|
30
|
+
image_array.astype("uint8")
|
|
31
|
+
if image_array.dtype != np.uint8
|
|
32
|
+
else image_array # pragma: no cover - fast path
|
|
33
|
+
)
|
|
34
|
+
mode = "L" if array_uint8.shape[-1] == 1 else "RGB"
|
|
35
|
+
if array_uint8.shape[-1] == 4:
|
|
36
|
+
mode = "RGBA"
|
|
37
|
+
img = Image.fromarray(array_uint8, mode=mode)
|
|
38
|
+
buffer = BytesIO()
|
|
39
|
+
img.save(buffer, format="PNG")
|
|
40
|
+
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
41
|
+
width = int(array_uint8.shape[1])
|
|
42
|
+
height = int(array_uint8.shape[0])
|
|
43
|
+
return {
|
|
44
|
+
"format": "png",
|
|
45
|
+
"width": width,
|
|
46
|
+
"height": height,
|
|
47
|
+
"data": encoded,
|
|
48
|
+
"data_url": f"data:image/png;base64,{encoded}",
|
|
49
|
+
}
|
|
50
|
+
except Exception:
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class CrafterEnvironmentWrapper:
|
|
55
|
+
"""Host-side environment wrapper matching the sketch contract.
|
|
56
|
+
|
|
57
|
+
Bridges our HTTP routes to a synth-ai `StatefulEnvironment` instance.
|
|
58
|
+
|
|
59
|
+
Contract (see sketch.txt):
|
|
60
|
+
- initialize() -> observation dict
|
|
61
|
+
- step(tool_calls: List[EnvToolCall]) -> observation dict plus optional done/reward/truncated/info
|
|
62
|
+
- snapshot()/restore() handled at route level; this wrapper exposes checkpoint via synth-ai
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(self, env: StatefulEnvironment, seed: int | None = None) -> None:
    """Bind the wrapper to ``env`` and start with empty bookkeeping.

    Args:
        env: The environment instance this wrapper delegates to.
        seed: Optional seed, stored on the wrapper exactly as given.
    """
    self.env, self.seed = env, seed
    # Nothing has been observed or stepped yet.
    self.last_info: dict[str, Any] | None = None
    self.last_observation: dict[str, Any] | None = None
    self.step_idx = 0
|
|
71
|
+
|
|
72
|
+
async def initialize(self) -> dict[str, Any]:
    """Initialize the wrapped environment and return the first payload.

    Awaits ``self.env.initialize()``, resets ``step_idx`` to 0 and caches the
    raw observation and info on the wrapper. When the engine's public state
    can be read, a 7x7 ``semantic_map_patch7`` grid centred on the player is
    attached to the observation dict; cells outside the map are 0.

    Returns:
        A dict with ``observation``, ``info`` (``None`` when the info is
        missing or empty) and ``step_idx`` keys.
    """
    obs = await self.env.initialize()
    # synth-ai InternalObservation expected to expose .observation (dict-like)
    self.step_idx = 0
    self.last_observation = getattr(obs, "observation", obs)  # tolerate dict-like
    self.last_info = getattr(obs, "info", None)
    out_obs = self._prepare_observation(self.last_observation)

    # Attach a 7x7 semantic map patch centered on player for client-side rendering
    try:
        pub = self.env.engine._get_public_state_from_env()  # type: ignore[attr-defined]
        sem = pub.semantic_map
        px, py = (int(coord) for coord in pub.player_position)
        size = 7
        half = size // 2
        # The map is read as sem[x][y]: the outer length bounds x and the
        # inner length bounds y.
        x_limit = len(sem) if hasattr(sem, "__len__") else 0
        y_limit = len(sem[0]) if x_limit and hasattr(sem[0], "__len__") else 0
        patch = []
        for dy in range(-half, half + 1):
            row = []
            for dx in range(-half, half + 1):
                x, y = px + dx, py + dy
                in_bounds = 0 <= x < x_limit and 0 <= y < y_limit
                row.append(int(sem[x][y]) if in_bounds else 0)
            patch.append(row)
        if isinstance(out_obs, dict):
            out_obs["semantic_map_patch7"] = patch
    except Exception:
        # The patch is optional, so failures stay non-fatal. They are logged
        # rather than silently dropped so a missing patch can be diagnosed.
        logger.debug("Could not attach semantic_map_patch7 to observation", exc_info=True)

    return {
        "observation": out_obs,
        "info": convert_numpy_to_python(self.last_info) if self.last_info else None,
        "step_idx": self.step_idx,
    }
|
|
107
|
+
|
|
108
|
+
async def step(self, tool_calls: list[dict[str, Any]] | list[EnvToolCall]) -> dict[str, Any]:
|
|
109
|
+
# Normalize JSON tool_calls into EnvToolCall instances if needed
|
|
110
|
+
# Underlying synth-ai environment expects only tool="interact" with args={"action": <action_name>}.
|
|
111
|
+
# LLM may emit:
|
|
112
|
+
# - interact_many with {actions: [...]}
|
|
113
|
+
# - direct tool names like "make_wood_pickaxe" or "do"
|
|
114
|
+
# - or even tool_name "do" with arguments {"action": "make_wood_pickaxe"}
|
|
115
|
+
# We normalize all these into a sequence of EnvToolCall(tool="interact", args={"action": <resolved_action>}).
|
|
116
|
+
allowed_actions = set(
|
|
117
|
+
TOOLS_SCHEMA[0]["function"]["parameters"]["properties"]["actions"]["items"]["enum"]
|
|
118
|
+
)
|
|
119
|
+
normalized: list[EnvToolCall] = []
|
|
120
|
+
|
|
121
|
+
def _action_to_int(action: Any) -> int | None:
|
|
122
|
+
# Handle invalid actions gracefully instead of failing
|
|
123
|
+
if isinstance(action, int):
|
|
124
|
+
return action
|
|
125
|
+
action_str = str(action)
|
|
126
|
+
if action_str not in CRAFTER_ACTIONS:
|
|
127
|
+
logger.warning("Unknown Crafter action: %s - ignoring", action_str)
|
|
128
|
+
return None # Signal to skip this action
|
|
129
|
+
return CRAFTER_ACTIONS[action_str]
|
|
130
|
+
|
|
131
|
+
for tc in tool_calls:
|
|
132
|
+
if isinstance(tc, EnvToolCall):
|
|
133
|
+
# Expand interact_many; otherwise coerce non-interact tools into interact(action=tool)
|
|
134
|
+
if tc.tool == "interact_many":
|
|
135
|
+
actions = tc.args.get("actions", [])
|
|
136
|
+
for action in actions:
|
|
137
|
+
action_int = _action_to_int(action)
|
|
138
|
+
if action_int is not None: # Skip invalid actions
|
|
139
|
+
normalized.append(
|
|
140
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
141
|
+
)
|
|
142
|
+
elif tc.tool != "interact":
|
|
143
|
+
candidate_action = tc.args.get("action") if isinstance(tc.args, dict) else None
|
|
144
|
+
resolved_action = (
|
|
145
|
+
candidate_action if candidate_action in allowed_actions else tc.tool
|
|
146
|
+
)
|
|
147
|
+
action_int = _action_to_int(resolved_action)
|
|
148
|
+
if action_int is not None: # Skip invalid actions
|
|
149
|
+
normalized.append(EnvToolCall(tool="interact", args={"action": action_int}))
|
|
150
|
+
else:
|
|
151
|
+
normalized.append(tc)
|
|
152
|
+
else:
|
|
153
|
+
# Dict input: handle both "tool" and "tool_name" keys
|
|
154
|
+
tool_name = tc.get("tool") or tc.get("tool_name")
|
|
155
|
+
if not tool_name:
|
|
156
|
+
raise ValueError(f"Tool call missing tool name: {tc}")
|
|
157
|
+
# Extract/parse args (may be JSON string from some clients)
|
|
158
|
+
args = tc.get("arguments") or tc.get("args") or {}
|
|
159
|
+
if isinstance(args, str):
|
|
160
|
+
import json as _json
|
|
161
|
+
|
|
162
|
+
try:
|
|
163
|
+
args = _json.loads(args)
|
|
164
|
+
except Exception:
|
|
165
|
+
args = {}
|
|
166
|
+
# Expand interact_many into multiple interacts
|
|
167
|
+
if tool_name == "interact_many":
|
|
168
|
+
for action in args.get("actions") or []:
|
|
169
|
+
action_int = _action_to_int(action)
|
|
170
|
+
if action_int is not None: # Skip invalid actions
|
|
171
|
+
normalized.append(
|
|
172
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
173
|
+
)
|
|
174
|
+
else:
|
|
175
|
+
# For any non-interact tool, resolve to an interact action.
|
|
176
|
+
# Support a packed list of actions under 'actions' for convenience.
|
|
177
|
+
if (
|
|
178
|
+
isinstance(args, dict)
|
|
179
|
+
and isinstance(args.get("actions"), list)
|
|
180
|
+
and args.get("actions")
|
|
181
|
+
):
|
|
182
|
+
for action in args.get("actions"):
|
|
183
|
+
action_int = _action_to_int(action)
|
|
184
|
+
if action_int is not None:
|
|
185
|
+
normalized.append(
|
|
186
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
187
|
+
)
|
|
188
|
+
else:
|
|
189
|
+
candidate_action = None
|
|
190
|
+
if isinstance(args, dict) and "action" in args:
|
|
191
|
+
candidate_action = args["action"]
|
|
192
|
+
# If the caller provided a numeric action id, accept it directly
|
|
193
|
+
action_int: int | None
|
|
194
|
+
if isinstance(candidate_action, int) or (
|
|
195
|
+
isinstance(candidate_action, str)
|
|
196
|
+
and candidate_action in allowed_actions
|
|
197
|
+
):
|
|
198
|
+
action_int = _action_to_int(candidate_action)
|
|
199
|
+
else:
|
|
200
|
+
# Fallback: interpret the tool name itself as the action label
|
|
201
|
+
action_int = _action_to_int(tool_name)
|
|
202
|
+
if action_int is not None:
|
|
203
|
+
normalized.append(
|
|
204
|
+
EnvToolCall(tool="interact", args={"action": action_int})
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
# Ensure we have at least one valid action; default to noop if none provided
|
|
208
|
+
if not normalized:
|
|
209
|
+
logger.info("No valid actions provided, defaulting to noop")
|
|
210
|
+
normalized.append(EnvToolCall(tool="interact", args={"action": 0})) # noop action
|
|
211
|
+
|
|
212
|
+
# Pre-step logging: capture current public state and print concise summary
|
|
213
|
+
before_state: dict[str, Any] | None = None
|
|
214
|
+
try:
|
|
215
|
+
pub_before = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
216
|
+
before_state = {
|
|
217
|
+
"inventory": pub_before.inventory,
|
|
218
|
+
"achievements_status": pub_before.achievements_status,
|
|
219
|
+
"player_position": list(pub_before.player_position),
|
|
220
|
+
"player_direction": pub_before.player_direction,
|
|
221
|
+
"semantic_map": pub_before.semantic_map,
|
|
222
|
+
}
|
|
223
|
+
actions_printable = [
|
|
224
|
+
(tc.args.get("action") if isinstance(tc.args, dict) else None)
|
|
225
|
+
if isinstance(tc, EnvToolCall)
|
|
226
|
+
else None
|
|
227
|
+
for tc in normalized
|
|
228
|
+
]
|
|
229
|
+
logger.info(
|
|
230
|
+
"Crafter BEFORE seed=%s step_idx=%s pos=%s inv=%s ach=%s actions=%s",
|
|
231
|
+
str(self.seed),
|
|
232
|
+
self.step_idx,
|
|
233
|
+
before_state.get("player_position"),
|
|
234
|
+
{k: v for k, v in before_state["inventory"].items() if v},
|
|
235
|
+
[k for k, v in before_state["achievements_status"].items() if v],
|
|
236
|
+
actions_printable,
|
|
237
|
+
)
|
|
238
|
+
logger.info(
|
|
239
|
+
"Surroundings BEFORE (seed=%s):\n%s",
|
|
240
|
+
str(self.seed),
|
|
241
|
+
_format_semantic_map_view(before_state),
|
|
242
|
+
)
|
|
243
|
+
except Exception as _:
|
|
244
|
+
# Logging should not interfere with stepping; fail-fast elsewhere
|
|
245
|
+
pass
|
|
246
|
+
|
|
247
|
+
if not normalized:
|
|
248
|
+
raise ValueError("No valid actions provided to CrafterEnvironmentWrapper.step()")
|
|
249
|
+
|
|
250
|
+
# Execute actions sequentially so multi-action tool calls actually advance the world
|
|
251
|
+
last_obs: Any = None
|
|
252
|
+
for single_call in normalized:
|
|
253
|
+
last_obs = await self.env.step(single_call)
|
|
254
|
+
self.step_idx += 1
|
|
255
|
+
|
|
256
|
+
obs = last_obs
|
|
257
|
+
observation = getattr(obs, "observation", obs)
|
|
258
|
+
info = getattr(obs, "info", None)
|
|
259
|
+
done = getattr(obs, "done", False) # Default to False if None
|
|
260
|
+
reward = getattr(obs, "reward", None)
|
|
261
|
+
truncated = getattr(obs, "truncated", None)
|
|
262
|
+
self.last_observation = observation
|
|
263
|
+
self.last_info = info
|
|
264
|
+
|
|
265
|
+
# Post-step logging: capture new public state and print concise summary
|
|
266
|
+
ach_added_latest: list[str] | None = None
|
|
267
|
+
try:
|
|
268
|
+
pub_after = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
269
|
+
after_dict: dict[str, Any] = {
|
|
270
|
+
"inventory": pub_after.inventory,
|
|
271
|
+
"achievements_status": pub_after.achievements_status,
|
|
272
|
+
"player_position": list(pub_after.player_position),
|
|
273
|
+
"player_direction": pub_after.player_direction,
|
|
274
|
+
"semantic_map": pub_after.semantic_map,
|
|
275
|
+
}
|
|
276
|
+
logger.info(
|
|
277
|
+
"Crafter AFTER seed=%s step_idx=%s pos=%s inv=%s ach=%s done=%s reward=%s",
|
|
278
|
+
str(self.seed),
|
|
279
|
+
self.step_idx,
|
|
280
|
+
after_dict.get("player_position"),
|
|
281
|
+
{k: v for k, v in after_dict["inventory"].items() if v},
|
|
282
|
+
[k for k, v in after_dict["achievements_status"].items() if v],
|
|
283
|
+
bool(done) if done is not None else False,
|
|
284
|
+
reward,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
# Changes/diff summary (position and inventory)
|
|
288
|
+
if before_state is not None:
|
|
289
|
+
try:
|
|
290
|
+
# Position delta
|
|
291
|
+
pb = before_state.get("player_position", [0, 0])
|
|
292
|
+
pa = after_dict.get("player_position", [0, 0])
|
|
293
|
+
pb_t = (int(pb[0]), int(pb[1])) if isinstance(pb, list | tuple) else (0, 0)
|
|
294
|
+
pa_t = (int(pa[0]), int(pa[1])) if isinstance(pa, list | tuple) else (0, 0)
|
|
295
|
+
delta = (pa_t[0] - pb_t[0], pa_t[1] - pb_t[1])
|
|
296
|
+
|
|
297
|
+
# Inventory changes
|
|
298
|
+
inv_b = before_state.get("inventory", {}) or {}
|
|
299
|
+
inv_a = after_dict.get("inventory", {}) or {}
|
|
300
|
+
changed_items = []
|
|
301
|
+
all_keys = set(inv_b.keys()) | set(inv_a.keys())
|
|
302
|
+
for key in sorted(all_keys):
|
|
303
|
+
vb = int(inv_b.get(key, 0) or 0)
|
|
304
|
+
va = int(inv_a.get(key, 0) or 0)
|
|
305
|
+
if vb != va:
|
|
306
|
+
changed_items.append(f"{key}:{vb}->{va}(Δ{va - vb})")
|
|
307
|
+
inv_changes = ", ".join(changed_items) if changed_items else "none"
|
|
308
|
+
|
|
309
|
+
# Achievements gained/lost
|
|
310
|
+
ach_b = {
|
|
311
|
+
k
|
|
312
|
+
for k, v in (before_state.get("achievements_status", {}) or {}).items()
|
|
313
|
+
if v
|
|
314
|
+
}
|
|
315
|
+
ach_a = {
|
|
316
|
+
k for k, v in (after_dict.get("achievements_status", {}) or {}).items() if v
|
|
317
|
+
}
|
|
318
|
+
ach_added = sorted(ach_a - ach_b)
|
|
319
|
+
ach_added_latest = ach_added
|
|
320
|
+
ach_removed = sorted(ach_b - ach_a)
|
|
321
|
+
|
|
322
|
+
logger.info(
|
|
323
|
+
"Changes: pos %s->%s Δ=%s | inv %s | ach +%s -%s",
|
|
324
|
+
pb_t,
|
|
325
|
+
pa_t,
|
|
326
|
+
delta,
|
|
327
|
+
inv_changes,
|
|
328
|
+
ach_added if ach_added else [],
|
|
329
|
+
ach_removed if ach_removed else [],
|
|
330
|
+
)
|
|
331
|
+
# Reward shaping immediately so logs and response reflect it
|
|
332
|
+
if reward is None and ach_added_latest:
|
|
333
|
+
try:
|
|
334
|
+
reward = float(len(ach_added_latest))
|
|
335
|
+
logger.info(
|
|
336
|
+
"Reward shaping applied: +%s (achievements added)",
|
|
337
|
+
len(ach_added_latest),
|
|
338
|
+
)
|
|
339
|
+
except Exception:
|
|
340
|
+
pass
|
|
341
|
+
except Exception:
|
|
342
|
+
pass
|
|
343
|
+
logger.info(
|
|
344
|
+
"Surroundings AFTER (seed=%s):\n%s",
|
|
345
|
+
str(self.seed),
|
|
346
|
+
_format_semantic_map_view(after_dict),
|
|
347
|
+
)
|
|
348
|
+
except Exception as _:
|
|
349
|
+
pass
|
|
350
|
+
result: dict[str, Any] = {
|
|
351
|
+
"observation": self._prepare_observation(observation),
|
|
352
|
+
"step_idx": self.step_idx,
|
|
353
|
+
"done": bool(done) if done is not None else False, # Ensure boolean
|
|
354
|
+
}
|
|
355
|
+
# Attach a 7x7 semantic map patch centered on player for client-side rendering
|
|
356
|
+
try:
|
|
357
|
+
sem = after_dict.get("semantic_map")
|
|
358
|
+
pos = after_dict.get("player_position") or [0, 0]
|
|
359
|
+
px, py = int(pos[0]), int(pos[1])
|
|
360
|
+
size = 7
|
|
361
|
+
half = size // 2
|
|
362
|
+
patch = []
|
|
363
|
+
height = len(sem) if hasattr(sem, "__len__") else 0
|
|
364
|
+
width = len(sem[0]) if height and hasattr(sem[0], "__len__") else 0
|
|
365
|
+
for dy in range(-half, half + 1):
|
|
366
|
+
row = []
|
|
367
|
+
for dx in range(-half, half + 1):
|
|
368
|
+
x, y = px + dx, py + dy
|
|
369
|
+
if 0 <= x < height and 0 <= y < width:
|
|
370
|
+
row.append(int(sem[x][y]))
|
|
371
|
+
else:
|
|
372
|
+
row.append(0)
|
|
373
|
+
patch.append(row)
|
|
374
|
+
obs_out = result.get("observation")
|
|
375
|
+
if isinstance(obs_out, dict):
|
|
376
|
+
obs_out["semantic_map_patch7"] = patch
|
|
377
|
+
except Exception:
|
|
378
|
+
pass
|
|
379
|
+
result_info = convert_numpy_to_python(info) if info is not None else {}
|
|
380
|
+
# Attach achievements delta for downstream metrics if useful
|
|
381
|
+
if ach_added_latest is not None:
|
|
382
|
+
try:
|
|
383
|
+
if not isinstance(result_info, dict):
|
|
384
|
+
result_info = {"_raw_info": result_info}
|
|
385
|
+
result_info["achievements_added"] = ach_added_latest
|
|
386
|
+
except Exception:
|
|
387
|
+
pass
|
|
388
|
+
if result_info:
|
|
389
|
+
result["info"] = result_info
|
|
390
|
+
if reward is not None:
|
|
391
|
+
result["reward"] = convert_numpy_to_python(reward)
|
|
392
|
+
# Also expose last-step reward inside observation for stepwise consumers
|
|
393
|
+
try:
|
|
394
|
+
obs_out = result.get("observation")
|
|
395
|
+
if isinstance(obs_out, dict):
|
|
396
|
+
obs_out.setdefault("reward_last_step", convert_numpy_to_python(reward))
|
|
397
|
+
except Exception:
|
|
398
|
+
pass
|
|
399
|
+
if truncated is not None:
|
|
400
|
+
result["truncated"] = truncated
|
|
401
|
+
|
|
402
|
+
# Aggregated step summary: action frequencies and achievement stats
|
|
403
|
+
try:
|
|
404
|
+
# Build reverse action map for readability
|
|
405
|
+
int_to_action = {v: k for k, v in CRAFTER_ACTIONS.items()}
|
|
406
|
+
from collections import Counter
|
|
407
|
+
|
|
408
|
+
action_ids = []
|
|
409
|
+
for tc in normalized:
|
|
410
|
+
if isinstance(tc, EnvToolCall) and isinstance(tc.args, dict):
|
|
411
|
+
a = tc.args.get("action")
|
|
412
|
+
if isinstance(a, int):
|
|
413
|
+
action_ids.append(a)
|
|
414
|
+
action_names = [int_to_action.get(a, str(a)) for a in action_ids]
|
|
415
|
+
action_freq = Counter(action_names)
|
|
416
|
+
|
|
417
|
+
# Public achievements after step
|
|
418
|
+
pub_after = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
419
|
+
unlocked = [name for name, on in pub_after.achievements_status.items() if on]
|
|
420
|
+
ach_freq = Counter(unlocked)
|
|
421
|
+
|
|
422
|
+
# Private achievement values (means)
|
|
423
|
+
priv_after = self.env.engine._get_private_state_from_env(0.0, False, False) # type: ignore[attr-defined]
|
|
424
|
+
values = list((priv_after.achievements_current_values or {}).values())
|
|
425
|
+
mean_all = (sum(values) / len(values)) if values else 0.0
|
|
426
|
+
nonzero = [v for v in values if v]
|
|
427
|
+
mean_nonzero = (sum(nonzero) / len(nonzero)) if nonzero else 0.0
|
|
428
|
+
|
|
429
|
+
logger.info(
|
|
430
|
+
"Step summary: seed=%s | actions=%s | achievements=%s | mean_ach_all=%.3f mean_ach_nonzero=%.3f",
|
|
431
|
+
str(self.seed),
|
|
432
|
+
dict(action_freq),
|
|
433
|
+
dict(ach_freq),
|
|
434
|
+
mean_all,
|
|
435
|
+
mean_nonzero,
|
|
436
|
+
)
|
|
437
|
+
except Exception:
|
|
438
|
+
pass
|
|
439
|
+
|
|
440
|
+
return result
|
|
441
|
+
|
|
442
|
+
def _prepare_observation(self, observation: Any) -> dict[str, Any]:
    """Build the dict form of an observation that is handed back to callers.

    When ``observation`` is a dict, its ``"observation_image"`` entry is
    passed to ``_encode_image_to_base64`` and left out of the converted
    payload; the input dict itself is not modified. Any other value is
    converted directly. A converted result that is not a dict is wrapped as
    ``{"value": <result>}``.

    Args:
        observation: Raw observation; either a dict or an arbitrary value.

    Returns:
        The converted observation dict. If an encoded image payload was
        produced, its fields are added under ``observation_image_*`` keys.
    """
    encoded: dict[str, Any] | None = None
    source: Any = observation

    if isinstance(observation, dict):
        encoded = _encode_image_to_base64(observation.get("observation_image"))
        # Filter into a fresh dict so the caller's mapping is left untouched.
        source = {
            key: value
            for key, value in observation.items()
            if key != "observation_image"
        }

    payload = convert_numpy_to_python(source) or {}
    if not isinstance(payload, dict):
        payload = {"value": payload}

    if encoded:
        # (output key suffix, field of the encoded payload), in emission order.
        image_fields = (
            ("base64", "data"),
            ("format", "format"),
            ("width", "width"),
            ("height", "height"),
            ("data_url", "data_url"),
        )
        for suffix, field in image_fields:
            payload["observation_image_" + suffix] = encoded[field]

    return payload
|
|
468
|
+
|
|
469
|
+
async def checkpoint(self) -> dict[str, Any]:
    """Await ``self.env.checkpoint()`` and package its result as a dict.

    Returns:
        A dict with ``"observation"`` (the result's ``observation``
        attribute, or the result itself when that attribute is absent, run
        through ``convert_numpy_to_python``), ``"info"`` (converted ``info``
        attribute, or ``None`` when it is missing or falsy) and
        ``"step_idx"`` (the wrapper's current step counter).
    """
    snapshot = await self.env.checkpoint()
    raw_observation = getattr(snapshot, "observation", snapshot)
    raw_info = getattr(snapshot, "info", None)

    payload: dict[str, Any] = {}
    payload["observation"] = convert_numpy_to_python(raw_observation)
    payload["info"] = None if not raw_info else convert_numpy_to_python(raw_info)
    payload["step_idx"] = self.step_idx
    return payload
|
|
478
|
+
|
|
479
|
+
async def terminate(self) -> dict[str, Any]:
    """Await ``self.env.terminate()`` and package its result as a dict.

    Returns:
        A dict with ``"observation"`` (the result's ``observation``
        attribute, or the result itself when that attribute is absent, run
        through ``convert_numpy_to_python``), ``"info"`` (converted ``info``
        attribute, or ``None`` when it is missing or falsy) and
        ``"step_idx"`` (the wrapper's current step counter).
    """
    final = await self.env.terminate()
    raw_observation = getattr(final, "observation", final)
    raw_info = getattr(final, "info", None)

    converted_observation = convert_numpy_to_python(raw_observation)
    if raw_info:
        converted_info = convert_numpy_to_python(raw_info)
    else:
        converted_info = None

    return dict(
        observation=converted_observation,
        info=converted_info,
        step_idx=self.step_idx,
    )
|
|
488
|
+
|
|
489
|
+
def state_dict(self) -> dict[str, Any]:
    """Return the wrapper's bookkeeping attributes as a plain dict.

    Returns:
        A dict with the keys ``"seed"``, ``"step_idx"``,
        ``"last_observation"`` and ``"last_info"``, each holding the value
        of the attribute of the same name.
    """
    tracked = ("seed", "step_idx", "last_observation", "last_info")
    return {attr: getattr(self, attr) for attr in tracked}
|
|
496
|
+
|
|
497
|
+
def load_state_dict(self, state: dict[str, Any]) -> None:
    """Restore the wrapper's bookkeeping attributes from ``state``.

    Args:
        state: Mapping that must contain ``"seed"``, ``"step_idx"``,
            ``"last_observation"`` and ``"last_info"``. ``"step_idx"`` is
            coerced with ``int``; the other values are stored as given.

    Raises:
        KeyError: If one of the required keys is missing. Attributes
            listed before the missing key have already been assigned.
    """
    # (attribute name, optional coercion), applied in this exact order.
    plan = (
        ("seed", None),
        ("step_idx", int),
        ("last_observation", None),
        ("last_info", None),
    )
    for attr, coerce in plan:
        value = state[attr]
        setattr(self, attr, value if coerce is None else coerce(value))
|
|
502
|
+
|
|
503
|
+
async def serialize(self) -> dict[str, Any]:
    """Describe this wrapper as a dict with name, config and state.

    Returns:
        A dict with ``"name"`` fixed to ``"crafter"``, ``"config"`` holding
        the seed, and ``"state"`` holding the result of ``state_dict()``.
    """
    snapshot: dict[str, Any] = {"name": "crafter"}
    snapshot["config"] = {"seed": self.seed}
    snapshot["state"] = self.state_dict()
    return snapshot
|
|
509
|
+
|
|
510
|
+
@classmethod
async def deserialize(
    cls,
    payload: dict[str, Any],
    env: StatefulEnvironment,
) -> CrafterEnvironmentWrapper:
    """Rebuild a wrapper around ``env`` from a serialized payload.

    Args:
        payload: Dict providing ``payload["config"]["seed"]`` and
            ``payload["state"]``; the latter is handed to
            ``load_state_dict``.
        env: Environment instance the new wrapper is constructed with.

    Returns:
        A new instance of ``cls`` built with ``env`` and the stored seed,
        with its state loaded from ``payload["state"]``.
    """
    restored = cls(env=env, seed=payload["config"]["seed"])
    restored.load_state_dict(payload["state"])
    return restored
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
# Names exported by `from <module> import *`: only the wrapper class.
__all__ = ["CrafterEnvironmentWrapper"]
|