ag2 0.9.1__py3-none-any.whl → 0.9.1.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ag2 has been flagged as possibly problematic.

Files changed (357)
  1. {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info}/METADATA +264 -73
  2. ag2-0.9.1.post0.dist-info/RECORD +392 -0
  3. {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info}/WHEEL +1 -2
  4. autogen/__init__.py +89 -0
  5. autogen/_website/__init__.py +3 -0
  6. autogen/_website/generate_api_references.py +427 -0
  7. autogen/_website/generate_mkdocs.py +1174 -0
  8. autogen/_website/notebook_processor.py +476 -0
  9. autogen/_website/process_notebooks.py +656 -0
  10. autogen/_website/utils.py +412 -0
  11. autogen/agentchat/__init__.py +44 -0
  12. autogen/agentchat/agent.py +182 -0
  13. autogen/agentchat/assistant_agent.py +85 -0
  14. autogen/agentchat/chat.py +309 -0
  15. autogen/agentchat/contrib/__init__.py +5 -0
  16. autogen/agentchat/contrib/agent_eval/README.md +7 -0
  17. autogen/agentchat/contrib/agent_eval/agent_eval.py +108 -0
  18. autogen/agentchat/contrib/agent_eval/criterion.py +43 -0
  19. autogen/agentchat/contrib/agent_eval/critic_agent.py +44 -0
  20. autogen/agentchat/contrib/agent_eval/quantifier_agent.py +39 -0
  21. autogen/agentchat/contrib/agent_eval/subcritic_agent.py +45 -0
  22. autogen/agentchat/contrib/agent_eval/task.py +42 -0
  23. autogen/agentchat/contrib/agent_optimizer.py +429 -0
  24. autogen/agentchat/contrib/capabilities/__init__.py +5 -0
  25. autogen/agentchat/contrib/capabilities/agent_capability.py +20 -0
  26. autogen/agentchat/contrib/capabilities/generate_images.py +301 -0
  27. autogen/agentchat/contrib/capabilities/teachability.py +393 -0
  28. autogen/agentchat/contrib/capabilities/text_compressors.py +66 -0
  29. autogen/agentchat/contrib/capabilities/tools_capability.py +22 -0
  30. autogen/agentchat/contrib/capabilities/transform_messages.py +93 -0
  31. autogen/agentchat/contrib/capabilities/transforms.py +566 -0
  32. autogen/agentchat/contrib/capabilities/transforms_util.py +122 -0
  33. autogen/agentchat/contrib/capabilities/vision_capability.py +214 -0
  34. autogen/agentchat/contrib/captainagent/__init__.py +9 -0
  35. autogen/agentchat/contrib/captainagent/agent_builder.py +790 -0
  36. autogen/agentchat/contrib/captainagent/captainagent.py +512 -0
  37. autogen/agentchat/contrib/captainagent/tool_retriever.py +335 -0
  38. autogen/agentchat/contrib/captainagent/tools/README.md +44 -0
  39. autogen/agentchat/contrib/captainagent/tools/__init__.py +5 -0
  40. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +40 -0
  41. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +28 -0
  42. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +28 -0
  43. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +28 -0
  44. autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +21 -0
  45. autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +30 -0
  46. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +27 -0
  47. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +53 -0
  48. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +53 -0
  49. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +38 -0
  50. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +21 -0
  51. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +34 -0
  52. autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +60 -0
  53. autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +61 -0
  54. autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +47 -0
  55. autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +33 -0
  56. autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +21 -0
  57. autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +35 -0
  58. autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +21 -0
  59. autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +18 -0
  60. autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +28 -0
  61. autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +31 -0
  62. autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +16 -0
  63. autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +25 -0
  64. autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +23 -0
  65. autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +27 -0
  66. autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +28 -0
  67. autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +34 -0
  68. autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +39 -0
  69. autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +23 -0
  70. autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +36 -0
  71. autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +15 -0
  72. autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +15 -0
  73. autogen/agentchat/contrib/captainagent/tools/requirements.txt +10 -0
  74. autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +34 -0
  75. autogen/agentchat/contrib/gpt_assistant_agent.py +526 -0
  76. autogen/agentchat/contrib/graph_rag/__init__.py +9 -0
  77. autogen/agentchat/contrib/graph_rag/document.py +29 -0
  78. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +170 -0
  79. autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +103 -0
  80. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +53 -0
  81. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +63 -0
  82. autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +268 -0
  83. autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py +83 -0
  84. autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +210 -0
  85. autogen/agentchat/contrib/graph_rag/neo4j_native_graph_rag_capability.py +93 -0
  86. autogen/agentchat/contrib/img_utils.py +397 -0
  87. autogen/agentchat/contrib/llamaindex_conversable_agent.py +117 -0
  88. autogen/agentchat/contrib/llava_agent.py +187 -0
  89. autogen/agentchat/contrib/math_user_proxy_agent.py +464 -0
  90. autogen/agentchat/contrib/multimodal_conversable_agent.py +125 -0
  91. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +324 -0
  92. autogen/agentchat/contrib/rag/__init__.py +10 -0
  93. autogen/agentchat/contrib/rag/chromadb_query_engine.py +272 -0
  94. autogen/agentchat/contrib/rag/llamaindex_query_engine.py +198 -0
  95. autogen/agentchat/contrib/rag/mongodb_query_engine.py +329 -0
  96. autogen/agentchat/contrib/rag/query_engine.py +74 -0
  97. autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
  98. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +703 -0
  99. autogen/agentchat/contrib/society_of_mind_agent.py +199 -0
  100. autogen/agentchat/contrib/swarm_agent.py +1425 -0
  101. autogen/agentchat/contrib/text_analyzer_agent.py +79 -0
  102. autogen/agentchat/contrib/vectordb/__init__.py +5 -0
  103. autogen/agentchat/contrib/vectordb/base.py +232 -0
  104. autogen/agentchat/contrib/vectordb/chromadb.py +315 -0
  105. autogen/agentchat/contrib/vectordb/couchbase.py +407 -0
  106. autogen/agentchat/contrib/vectordb/mongodb.py +550 -0
  107. autogen/agentchat/contrib/vectordb/pgvectordb.py +928 -0
  108. autogen/agentchat/contrib/vectordb/qdrant.py +320 -0
  109. autogen/agentchat/contrib/vectordb/utils.py +126 -0
  110. autogen/agentchat/contrib/web_surfer.py +303 -0
  111. autogen/agentchat/conversable_agent.py +4020 -0
  112. autogen/agentchat/group/__init__.py +64 -0
  113. autogen/agentchat/group/available_condition.py +91 -0
  114. autogen/agentchat/group/context_condition.py +77 -0
  115. autogen/agentchat/group/context_expression.py +238 -0
  116. autogen/agentchat/group/context_str.py +41 -0
  117. autogen/agentchat/group/context_variables.py +192 -0
  118. autogen/agentchat/group/group_tool_executor.py +202 -0
  119. autogen/agentchat/group/group_utils.py +591 -0
  120. autogen/agentchat/group/handoffs.py +244 -0
  121. autogen/agentchat/group/llm_condition.py +93 -0
  122. autogen/agentchat/group/multi_agent_chat.py +237 -0
  123. autogen/agentchat/group/on_condition.py +58 -0
  124. autogen/agentchat/group/on_context_condition.py +54 -0
  125. autogen/agentchat/group/patterns/__init__.py +18 -0
  126. autogen/agentchat/group/patterns/auto.py +159 -0
  127. autogen/agentchat/group/patterns/manual.py +176 -0
  128. autogen/agentchat/group/patterns/pattern.py +288 -0
  129. autogen/agentchat/group/patterns/random.py +106 -0
  130. autogen/agentchat/group/patterns/round_robin.py +117 -0
  131. autogen/agentchat/group/reply_result.py +26 -0
  132. autogen/agentchat/group/speaker_selection_result.py +41 -0
  133. autogen/agentchat/group/targets/__init__.py +4 -0
  134. autogen/agentchat/group/targets/group_chat_target.py +132 -0
  135. autogen/agentchat/group/targets/group_manager_target.py +151 -0
  136. autogen/agentchat/group/targets/transition_target.py +413 -0
  137. autogen/agentchat/group/targets/transition_utils.py +6 -0
  138. autogen/agentchat/groupchat.py +1694 -0
  139. autogen/agentchat/realtime/__init__.py +3 -0
  140. autogen/agentchat/realtime/experimental/__init__.py +20 -0
  141. autogen/agentchat/realtime/experimental/audio_adapters/__init__.py +8 -0
  142. autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py +148 -0
  143. autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py +139 -0
  144. autogen/agentchat/realtime/experimental/audio_observer.py +42 -0
  145. autogen/agentchat/realtime/experimental/clients/__init__.py +15 -0
  146. autogen/agentchat/realtime/experimental/clients/gemini/__init__.py +7 -0
  147. autogen/agentchat/realtime/experimental/clients/gemini/client.py +274 -0
  148. autogen/agentchat/realtime/experimental/clients/oai/__init__.py +8 -0
  149. autogen/agentchat/realtime/experimental/clients/oai/base_client.py +220 -0
  150. autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py +243 -0
  151. autogen/agentchat/realtime/experimental/clients/oai/utils.py +48 -0
  152. autogen/agentchat/realtime/experimental/clients/realtime_client.py +190 -0
  153. autogen/agentchat/realtime/experimental/function_observer.py +85 -0
  154. autogen/agentchat/realtime/experimental/realtime_agent.py +158 -0
  155. autogen/agentchat/realtime/experimental/realtime_events.py +42 -0
  156. autogen/agentchat/realtime/experimental/realtime_observer.py +100 -0
  157. autogen/agentchat/realtime/experimental/realtime_swarm.py +475 -0
  158. autogen/agentchat/realtime/experimental/websockets.py +21 -0
  159. autogen/agentchat/realtime_agent/__init__.py +21 -0
  160. autogen/agentchat/user_proxy_agent.py +111 -0
  161. autogen/agentchat/utils.py +206 -0
  162. autogen/agents/__init__.py +3 -0
  163. autogen/agents/contrib/__init__.py +10 -0
  164. autogen/agents/contrib/time/__init__.py +8 -0
  165. autogen/agents/contrib/time/time_reply_agent.py +73 -0
  166. autogen/agents/contrib/time/time_tool_agent.py +51 -0
  167. autogen/agents/experimental/__init__.py +27 -0
  168. autogen/agents/experimental/deep_research/__init__.py +7 -0
  169. autogen/agents/experimental/deep_research/deep_research.py +52 -0
  170. autogen/agents/experimental/discord/__init__.py +7 -0
  171. autogen/agents/experimental/discord/discord.py +66 -0
  172. autogen/agents/experimental/document_agent/__init__.py +19 -0
  173. autogen/agents/experimental/document_agent/chroma_query_engine.py +316 -0
  174. autogen/agents/experimental/document_agent/docling_doc_ingest_agent.py +118 -0
  175. autogen/agents/experimental/document_agent/document_agent.py +461 -0
  176. autogen/agents/experimental/document_agent/document_conditions.py +50 -0
  177. autogen/agents/experimental/document_agent/document_utils.py +380 -0
  178. autogen/agents/experimental/document_agent/inmemory_query_engine.py +220 -0
  179. autogen/agents/experimental/document_agent/parser_utils.py +130 -0
  180. autogen/agents/experimental/document_agent/url_utils.py +426 -0
  181. autogen/agents/experimental/reasoning/__init__.py +7 -0
  182. autogen/agents/experimental/reasoning/reasoning_agent.py +1178 -0
  183. autogen/agents/experimental/slack/__init__.py +7 -0
  184. autogen/agents/experimental/slack/slack.py +73 -0
  185. autogen/agents/experimental/telegram/__init__.py +7 -0
  186. autogen/agents/experimental/telegram/telegram.py +77 -0
  187. autogen/agents/experimental/websurfer/__init__.py +7 -0
  188. autogen/agents/experimental/websurfer/websurfer.py +62 -0
  189. autogen/agents/experimental/wikipedia/__init__.py +7 -0
  190. autogen/agents/experimental/wikipedia/wikipedia.py +90 -0
  191. autogen/browser_utils.py +309 -0
  192. autogen/cache/__init__.py +10 -0
  193. autogen/cache/abstract_cache_base.py +75 -0
  194. autogen/cache/cache.py +203 -0
  195. autogen/cache/cache_factory.py +88 -0
  196. autogen/cache/cosmos_db_cache.py +144 -0
  197. autogen/cache/disk_cache.py +102 -0
  198. autogen/cache/in_memory_cache.py +58 -0
  199. autogen/cache/redis_cache.py +123 -0
  200. autogen/code_utils.py +596 -0
  201. autogen/coding/__init__.py +22 -0
  202. autogen/coding/base.py +119 -0
  203. autogen/coding/docker_commandline_code_executor.py +268 -0
  204. autogen/coding/factory.py +47 -0
  205. autogen/coding/func_with_reqs.py +202 -0
  206. autogen/coding/jupyter/__init__.py +23 -0
  207. autogen/coding/jupyter/base.py +36 -0
  208. autogen/coding/jupyter/docker_jupyter_server.py +167 -0
  209. autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
  210. autogen/coding/jupyter/import_utils.py +82 -0
  211. autogen/coding/jupyter/jupyter_client.py +231 -0
  212. autogen/coding/jupyter/jupyter_code_executor.py +160 -0
  213. autogen/coding/jupyter/local_jupyter_server.py +172 -0
  214. autogen/coding/local_commandline_code_executor.py +405 -0
  215. autogen/coding/markdown_code_extractor.py +45 -0
  216. autogen/coding/utils.py +56 -0
  217. autogen/doc_utils.py +34 -0
  218. autogen/events/__init__.py +7 -0
  219. autogen/events/agent_events.py +1010 -0
  220. autogen/events/base_event.py +99 -0
  221. autogen/events/client_events.py +167 -0
  222. autogen/events/helpers.py +36 -0
  223. autogen/events/print_event.py +46 -0
  224. autogen/exception_utils.py +73 -0
  225. autogen/extensions/__init__.py +5 -0
  226. autogen/fast_depends/__init__.py +16 -0
  227. autogen/fast_depends/_compat.py +80 -0
  228. autogen/fast_depends/core/__init__.py +14 -0
  229. autogen/fast_depends/core/build.py +225 -0
  230. autogen/fast_depends/core/model.py +576 -0
  231. autogen/fast_depends/dependencies/__init__.py +15 -0
  232. autogen/fast_depends/dependencies/model.py +29 -0
  233. autogen/fast_depends/dependencies/provider.py +39 -0
  234. autogen/fast_depends/library/__init__.py +10 -0
  235. autogen/fast_depends/library/model.py +46 -0
  236. autogen/fast_depends/py.typed +6 -0
  237. autogen/fast_depends/schema.py +66 -0
  238. autogen/fast_depends/use.py +280 -0
  239. autogen/fast_depends/utils.py +187 -0
  240. autogen/formatting_utils.py +83 -0
  241. autogen/function_utils.py +13 -0
  242. autogen/graph_utils.py +178 -0
  243. autogen/import_utils.py +526 -0
  244. autogen/interop/__init__.py +22 -0
  245. autogen/interop/crewai/__init__.py +7 -0
  246. autogen/interop/crewai/crewai.py +88 -0
  247. autogen/interop/interoperability.py +71 -0
  248. autogen/interop/interoperable.py +46 -0
  249. autogen/interop/langchain/__init__.py +8 -0
  250. autogen/interop/langchain/langchain_chat_model_factory.py +155 -0
  251. autogen/interop/langchain/langchain_tool.py +82 -0
  252. autogen/interop/litellm/__init__.py +7 -0
  253. autogen/interop/litellm/litellm_config_factory.py +113 -0
  254. autogen/interop/pydantic_ai/__init__.py +7 -0
  255. autogen/interop/pydantic_ai/pydantic_ai.py +168 -0
  256. autogen/interop/registry.py +69 -0
  257. autogen/io/__init__.py +15 -0
  258. autogen/io/base.py +151 -0
  259. autogen/io/console.py +56 -0
  260. autogen/io/processors/__init__.py +12 -0
  261. autogen/io/processors/base.py +21 -0
  262. autogen/io/processors/console_event_processor.py +56 -0
  263. autogen/io/run_response.py +293 -0
  264. autogen/io/thread_io_stream.py +63 -0
  265. autogen/io/websockets.py +213 -0
  266. autogen/json_utils.py +43 -0
  267. autogen/llm_config.py +379 -0
  268. autogen/logger/__init__.py +11 -0
  269. autogen/logger/base_logger.py +128 -0
  270. autogen/logger/file_logger.py +261 -0
  271. autogen/logger/logger_factory.py +42 -0
  272. autogen/logger/logger_utils.py +57 -0
  273. autogen/logger/sqlite_logger.py +523 -0
  274. autogen/math_utils.py +339 -0
  275. autogen/mcp/__init__.py +7 -0
  276. autogen/mcp/mcp_client.py +208 -0
  277. autogen/messages/__init__.py +7 -0
  278. autogen/messages/agent_messages.py +948 -0
  279. autogen/messages/base_message.py +107 -0
  280. autogen/messages/client_messages.py +171 -0
  281. autogen/messages/print_message.py +49 -0
  282. autogen/oai/__init__.py +53 -0
  283. autogen/oai/anthropic.py +714 -0
  284. autogen/oai/bedrock.py +628 -0
  285. autogen/oai/cerebras.py +299 -0
  286. autogen/oai/client.py +1435 -0
  287. autogen/oai/client_utils.py +169 -0
  288. autogen/oai/cohere.py +479 -0
  289. autogen/oai/gemini.py +990 -0
  290. autogen/oai/gemini_types.py +129 -0
  291. autogen/oai/groq.py +305 -0
  292. autogen/oai/mistral.py +303 -0
  293. autogen/oai/oai_models/__init__.py +11 -0
  294. autogen/oai/oai_models/_models.py +16 -0
  295. autogen/oai/oai_models/chat_completion.py +87 -0
  296. autogen/oai/oai_models/chat_completion_audio.py +32 -0
  297. autogen/oai/oai_models/chat_completion_message.py +86 -0
  298. autogen/oai/oai_models/chat_completion_message_tool_call.py +37 -0
  299. autogen/oai/oai_models/chat_completion_token_logprob.py +63 -0
  300. autogen/oai/oai_models/completion_usage.py +60 -0
  301. autogen/oai/ollama.py +643 -0
  302. autogen/oai/openai_utils.py +881 -0
  303. autogen/oai/together.py +370 -0
  304. autogen/retrieve_utils.py +491 -0
  305. autogen/runtime_logging.py +160 -0
  306. autogen/token_count_utils.py +267 -0
  307. autogen/tools/__init__.py +20 -0
  308. autogen/tools/contrib/__init__.py +9 -0
  309. autogen/tools/contrib/time/__init__.py +7 -0
  310. autogen/tools/contrib/time/time.py +41 -0
  311. autogen/tools/dependency_injection.py +254 -0
  312. autogen/tools/experimental/__init__.py +43 -0
  313. autogen/tools/experimental/browser_use/__init__.py +7 -0
  314. autogen/tools/experimental/browser_use/browser_use.py +161 -0
  315. autogen/tools/experimental/crawl4ai/__init__.py +7 -0
  316. autogen/tools/experimental/crawl4ai/crawl4ai.py +153 -0
  317. autogen/tools/experimental/deep_research/__init__.py +7 -0
  318. autogen/tools/experimental/deep_research/deep_research.py +328 -0
  319. autogen/tools/experimental/duckduckgo/__init__.py +7 -0
  320. autogen/tools/experimental/duckduckgo/duckduckgo_search.py +109 -0
  321. autogen/tools/experimental/google/__init__.py +14 -0
  322. autogen/tools/experimental/google/authentication/__init__.py +11 -0
  323. autogen/tools/experimental/google/authentication/credentials_hosted_provider.py +43 -0
  324. autogen/tools/experimental/google/authentication/credentials_local_provider.py +91 -0
  325. autogen/tools/experimental/google/authentication/credentials_provider.py +35 -0
  326. autogen/tools/experimental/google/drive/__init__.py +9 -0
  327. autogen/tools/experimental/google/drive/drive_functions.py +124 -0
  328. autogen/tools/experimental/google/drive/toolkit.py +88 -0
  329. autogen/tools/experimental/google/model.py +17 -0
  330. autogen/tools/experimental/google/toolkit_protocol.py +19 -0
  331. autogen/tools/experimental/google_search/__init__.py +8 -0
  332. autogen/tools/experimental/google_search/google_search.py +93 -0
  333. autogen/tools/experimental/google_search/youtube_search.py +181 -0
  334. autogen/tools/experimental/messageplatform/__init__.py +17 -0
  335. autogen/tools/experimental/messageplatform/discord/__init__.py +7 -0
  336. autogen/tools/experimental/messageplatform/discord/discord.py +288 -0
  337. autogen/tools/experimental/messageplatform/slack/__init__.py +7 -0
  338. autogen/tools/experimental/messageplatform/slack/slack.py +391 -0
  339. autogen/tools/experimental/messageplatform/telegram/__init__.py +7 -0
  340. autogen/tools/experimental/messageplatform/telegram/telegram.py +275 -0
  341. autogen/tools/experimental/perplexity/__init__.py +7 -0
  342. autogen/tools/experimental/perplexity/perplexity_search.py +260 -0
  343. autogen/tools/experimental/tavily/__init__.py +7 -0
  344. autogen/tools/experimental/tavily/tavily_search.py +183 -0
  345. autogen/tools/experimental/web_search_preview/__init__.py +7 -0
  346. autogen/tools/experimental/web_search_preview/web_search_preview.py +114 -0
  347. autogen/tools/experimental/wikipedia/__init__.py +7 -0
  348. autogen/tools/experimental/wikipedia/wikipedia.py +287 -0
  349. autogen/tools/function_utils.py +411 -0
  350. autogen/tools/tool.py +187 -0
  351. autogen/tools/toolkit.py +86 -0
  352. autogen/types.py +29 -0
  353. autogen/version.py +7 -0
  354. ag2-0.9.1.dist-info/RECORD +0 -6
  355. ag2-0.9.1.dist-info/top_level.txt +0 -1
  356. {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info/licenses}/LICENSE +0 -0
  357. {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info/licenses}/NOTICE.md +0 -0
autogen/oai/client.py ADDED
@@ -0,0 +1,1435 @@
1
+ # Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
+ # SPDX-License-Identifier: MIT
7
+ from __future__ import annotations
8
+
9
+ import inspect
10
+ import json
11
+ import logging
12
+ import re
13
+ import sys
14
+ import uuid
15
+ import warnings
16
+ from functools import lru_cache
17
+ from typing import Any, Callable, Literal, Optional, Protocol, Union
18
+
19
+ from pydantic import BaseModel, Field, HttpUrl, ValidationInfo, field_validator
20
+ from pydantic.type_adapter import TypeAdapter
21
+
22
+ from ..cache import Cache
23
+ from ..doc_utils import export_module
24
+ from ..events.client_events import StreamEvent, UsageSummaryEvent
25
+ from ..exception_utils import ModelToolNotSupportedError
26
+ from ..import_utils import optional_import_block, require_optional_import
27
+ from ..io.base import IOStream
28
+ from ..llm_config import LLMConfigEntry, register_llm_config
29
+ from ..logger.logger_utils import get_current_ts
30
+ from ..runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
31
+ from ..token_count_utils import count_token
32
+ from .client_utils import FormatterProtocol, logging_formatter
33
+ from .openai_utils import OAI_PRICE1K, get_key, is_valid_api_key
34
+
35
+ TOOL_ENABLED = False
36
+ with optional_import_block() as openai_result:
37
+ import openai
38
+
39
+ if openai_result.is_successful:
40
+ # raises exception if openai>=1 is installed and something is wrong with imports
41
+ from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
42
+ from openai import __version__ as openai_version
43
+ from openai.lib._parsing._completions import type_to_response_format_param
44
+ from openai.types.chat import ChatCompletion
45
+ from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
46
+ from openai.types.chat.chat_completion_chunk import (
47
+ ChoiceDeltaFunctionCall,
48
+ ChoiceDeltaToolCall,
49
+ ChoiceDeltaToolCallFunction,
50
+ )
51
+ from openai.types.completion import Completion
52
+ from openai.types.completion_usage import CompletionUsage
53
+
54
+ if openai.__version__ >= "1.1.0":
55
+ TOOL_ENABLED = True
56
+ ERROR = None
57
+ from openai.lib._pydantic import _ensure_strict_json_schema
58
+ else:
59
+ ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
60
+
61
+ # OpenAI = object
62
+ # AzureOpenAI = object
63
+
64
+ with optional_import_block() as cerebras_result:
65
+ from cerebras.cloud.sdk import ( # noqa
66
+ AuthenticationError as cerebras_AuthenticationError,
67
+ InternalServerError as cerebras_InternalServerError,
68
+ RateLimitError as cerebras_RateLimitError,
69
+ )
70
+
71
+ from .cerebras import CerebrasClient
72
+
73
+ if cerebras_result.is_successful:
74
+ cerebras_import_exception: Optional[ImportError] = None
75
+ else:
76
+ cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception # noqa: N816
77
+ cerebras_import_exception = ImportError("cerebras_cloud_sdk not found")
78
+
79
+ with optional_import_block() as gemini_result:
80
+ from google.api_core.exceptions import ( # noqa
81
+ InternalServerError as gemini_InternalServerError,
82
+ ResourceExhausted as gemini_ResourceExhausted,
83
+ )
84
+
85
+ from .gemini import GeminiClient
86
+
87
+ if gemini_result.is_successful:
88
+ gemini_import_exception: Optional[ImportError] = None
89
+ else:
90
+ gemini_InternalServerError = gemini_ResourceExhausted = Exception # noqa: N816
91
+ gemini_import_exception = ImportError("google-genai not found")
92
+
93
+ with optional_import_block() as anthropic_result:
94
+ from anthropic import ( # noqa
95
+ InternalServerError as anthorpic_InternalServerError,
96
+ RateLimitError as anthorpic_RateLimitError,
97
+ )
98
+
99
+ from .anthropic import AnthropicClient
100
+
101
+ if anthropic_result.is_successful:
102
+ anthropic_import_exception: Optional[ImportError] = None
103
+ else:
104
+ anthorpic_InternalServerError = anthorpic_RateLimitError = Exception # noqa: N816
105
+ anthropic_import_exception = ImportError("anthropic not found")
106
+
107
+ with optional_import_block() as mistral_result:
108
+ from mistralai.models import ( # noqa
109
+ HTTPValidationError as mistral_HTTPValidationError,
110
+ SDKError as mistral_SDKError,
111
+ )
112
+
113
+ from .mistral import MistralAIClient
114
+
115
+ if mistral_result.is_successful:
116
+ mistral_import_exception: Optional[ImportError] = None
117
+ else:
118
+ mistral_SDKError = mistral_HTTPValidationError = Exception # noqa: N816
119
+ mistral_import_exception = ImportError("mistralai not found")
120
+
121
+ with optional_import_block() as together_result:
122
+ from together.error import TogetherException as together_TogetherException
123
+
124
+ from .together import TogetherClient
125
+
126
+ if together_result.is_successful:
127
+ together_import_exception: Optional[ImportError] = None
128
+ else:
129
+ together_TogetherException = Exception # noqa: N816
130
+ together_import_exception = ImportError("together not found")
131
+
132
+ with optional_import_block() as groq_result:
133
+ from groq import ( # noqa
134
+ APIConnectionError as groq_APIConnectionError,
135
+ InternalServerError as groq_InternalServerError,
136
+ RateLimitError as groq_RateLimitError,
137
+ )
138
+
139
+ from .groq import GroqClient
140
+
141
+ if groq_result.is_successful:
142
+ groq_import_exception: Optional[ImportError] = None
143
+ else:
144
+ groq_InternalServerError = groq_RateLimitError = groq_APIConnectionError = Exception # noqa: N816
145
+ groq_import_exception = ImportError("groq not found")
146
+
147
+ with optional_import_block() as cohere_result:
148
+ from cohere.errors import ( # noqa
149
+ InternalServerError as cohere_InternalServerError,
150
+ ServiceUnavailableError as cohere_ServiceUnavailableError,
151
+ TooManyRequestsError as cohere_TooManyRequestsError,
152
+ )
153
+
154
+ from .cohere import CohereClient
155
+
156
+ if cohere_result.is_successful:
157
+ cohere_import_exception: Optional[ImportError] = None
158
+ else:
159
+ cohere_InternalServerError = cohere_TooManyRequestsError = cohere_ServiceUnavailableError = Exception # noqa: N816
160
+ cohere_import_exception = ImportError("cohere not found")
161
+
162
+ with optional_import_block() as ollama_result:
163
+ from ollama import ( # noqa
164
+ RequestError as ollama_RequestError,
165
+ ResponseError as ollama_ResponseError,
166
+ )
167
+
168
+ from .ollama import OllamaClient
169
+
170
+ if ollama_result.is_successful:
171
+ ollama_import_exception: Optional[ImportError] = None
172
+ else:
173
+ ollama_RequestError = ollama_ResponseError = Exception # noqa: N816
174
+ ollama_import_exception = ImportError("ollama not found")
175
+
176
+ with optional_import_block() as bedrock_result:
177
+ from botocore.exceptions import ( # noqa
178
+ BotoCoreError as bedrock_BotoCoreError,
179
+ ClientError as bedrock_ClientError,
180
+ )
181
+
182
+ from .bedrock import BedrockClient
183
+
184
+ if bedrock_result.is_successful:
185
+ bedrock_import_exception: Optional[ImportError] = None
186
+ else:
187
+ bedrock_BotoCoreError = bedrock_ClientError = Exception # noqa: N816
188
+ bedrock_import_exception = ImportError("botocore not found")
189
+
190
+ logger = logging.getLogger(__name__)
191
+ if not logger.handlers:
192
+ # Add the console handler.
193
+ _ch = logging.StreamHandler(stream=sys.stdout)
194
+ _ch.setFormatter(logging_formatter)
195
+ logger.addHandler(_ch)
196
+
197
+ LEGACY_DEFAULT_CACHE_SEED = 41
198
+ LEGACY_CACHE_DIR = ".cache"
199
+ OPEN_API_BASE_URL_PREFIX = "https://api.openai.com"
200
+
201
+ OPENAI_FALLBACK_KWARGS = {
202
+ "api_key",
203
+ "organization",
204
+ "project",
205
+ "base_url",
206
+ "websocket_base_url",
207
+ "timeout",
208
+ "max_retries",
209
+ "default_headers",
210
+ "default_query",
211
+ "http_client",
212
+ "_strict_response_validation",
213
+ }
214
+
215
+ AOPENAI_FALLBACK_KWARGS = {
216
+ "azure_endpoint",
217
+ "azure_deployment",
218
+ "api_version",
219
+ "api_key",
220
+ "azure_ad_token",
221
+ "azure_ad_token_provider",
222
+ "organization",
223
+ "websocket_base_url",
224
+ "timeout",
225
+ "max_retries",
226
+ "default_headers",
227
+ "default_query",
228
+ "http_client",
229
+ "_strict_response_validation",
230
+ "base_url",
231
+ "project",
232
+ }
233
+
234
+
235
+ @lru_cache(maxsize=128)
236
+ def log_cache_seed_value(cache_seed_value: Union[str, int], client: "ModelClient") -> None:
237
+ logger.debug(f"Using cache with seed value {cache_seed_value} for client {client.__class__.__name__}")
238
+
239
+
240
+ @register_llm_config
241
+ class OpenAILLMConfigEntry(LLMConfigEntry):
242
+ api_type: Literal["openai"] = "openai"
243
+ top_p: Optional[float] = None
244
+ price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
245
+ tool_choice: Optional[Literal["none", "auto", "required"]] = None
246
+ user: Optional[str] = None
247
+ extra_body: Optional[dict[str, Any]] = (
248
+ None # For VLLM - See here: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
249
+ )
250
+ # reasoning models - see: https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
251
+ reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
252
+ max_completion_tokens: Optional[int] = None
253
+
254
+ def create_client(self) -> "ModelClient":
255
+ raise NotImplementedError("create_client method must be implemented in the derived class.")
256
+
257
+
258
+ @register_llm_config
259
+ class AzureOpenAILLMConfigEntry(LLMConfigEntry):
260
+ api_type: Literal["azure"] = "azure"
261
+ top_p: Optional[float] = None
262
+ azure_ad_token_provider: Optional[Union[str, Callable[[], str]]] = None
263
+ tool_choice: Optional[Literal["none", "auto", "required"]] = None
264
+ user: Optional[str] = None
265
+ # reasoning models - see:
266
+ # - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning
267
+ # - https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview
268
+ reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
269
+ max_completion_tokens: Optional[int] = None
270
+
271
+ def create_client(self) -> "ModelClient":
272
+ raise NotImplementedError
273
+
274
+
275
+ @register_llm_config
276
+ class DeepSeekLLMConfigEntry(LLMConfigEntry):
277
+ api_type: Literal["deepseek"] = "deepseek"
278
+ base_url: HttpUrl = HttpUrl("https://api.deepseek.com/v1")
279
+ temperature: float = Field(0.5, ge=0.0, le=1.0)
280
+ max_tokens: int = Field(8192, ge=1, le=8192)
281
+ top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
282
+ tool_choice: Optional[Literal["none", "auto", "required"]] = None
283
+
284
+ @field_validator("top_p", mode="before")
285
+ @classmethod
286
+ def check_top_p(cls, v: Any, info: ValidationInfo) -> Any:
287
+ if v is not None and info.data.get("temperature") is not None:
288
+ raise ValueError("temperature and top_p cannot be set at the same time.")
289
+ return v
290
+
291
+ def create_client(self) -> None: # type: ignore [override]
292
+ raise NotImplementedError("DeepSeekLLMConfigEntry.create_client is not implemented.")
293
+
294
+
295
+ @export_module("autogen")
296
+ class ModelClient(Protocol):
297
+ """A client class must implement the following methods:
298
+ - create must return a response object that implements the ModelClientResponseProtocol
299
+ - cost must return the cost of the response
300
+ - get_usage must return a dict with the following keys:
301
+ - prompt_tokens
302
+ - completion_tokens
303
+ - total_tokens
304
+ - cost
305
+ - model
306
+
307
+ This class is used to create a client that can be used by OpenAIWrapper.
308
+ The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
309
+ The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
310
+ """
311
+
312
+ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
313
+
314
+ class ModelClientResponseProtocol(Protocol):
315
+ class Choice(Protocol):
316
+ class Message(Protocol):
317
+ content: Optional[str]
318
+
319
+ message: Message
320
+
321
+ choices: list[Choice]
322
+ model: str
323
+
324
+ def create(self, params: dict[str, Any]) -> ModelClientResponseProtocol: ... # pragma: no cover
325
+
326
+ def message_retrieval(
327
+ self, response: ModelClientResponseProtocol
328
+ ) -> Union[list[str], list[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
329
+ """Retrieve and return a list of strings or a list of Choice.Message from the response.
330
+
331
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
332
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
333
+ """
334
+ ... # pragma: no cover
335
+
336
+ def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover
337
+
338
+ @staticmethod
339
+ def get_usage(response: ModelClientResponseProtocol) -> dict:
340
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
341
+ ... # pragma: no cover
342
+
343
+
344
+ class PlaceHolderClient:
345
+ def __init__(self, config):
346
+ self.config = config
347
+
348
+
349
+ @require_optional_import("openai>=1.66.2", "openai")
350
+ class OpenAIClient:
351
+ """Follows the Client protocol and wraps the OpenAI client."""
352
+
353
+ def __init__(
354
+ self, client: Union[OpenAI, AzureOpenAI], response_format: Union[BaseModel, dict[str, Any], None] = None
355
+ ):
356
+ self._oai_client = client
357
+ self.response_format = response_format
358
+ if (
359
+ not isinstance(client, openai.AzureOpenAI)
360
+ and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
361
+ and not is_valid_api_key(self._oai_client.api_key)
362
+ ):
363
+ logger.warning(
364
+ "The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
365
+ )
366
+
367
+ def message_retrieval(
368
+ self, response: Union[ChatCompletion, Completion]
369
+ ) -> Union[list[str], list[ChatCompletionMessage]]:
370
+ """Retrieve the messages from the response.
371
+
372
+ Args:
373
+ response (ChatCompletion | Completion): The response from openai.
374
+
375
+
376
+ Returns:
377
+ The message from the response.
378
+ """
379
+ choices = response.choices
380
+ if isinstance(response, Completion):
381
+ return [choice.text for choice in choices] # type: ignore [union-attr]
382
+
383
+ def _format_content(content: str) -> str:
384
+ return (
385
+ self.response_format.model_validate_json(content).format()
386
+ if isinstance(self.response_format, FormatterProtocol)
387
+ else content
388
+ )
389
+
390
+ if TOOL_ENABLED:
391
+ return [ # type: ignore [return-value]
392
+ (
393
+ choice.message # type: ignore [union-attr]
394
+ if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
395
+ else _format_content(choice.message.content)
396
+ ) # type: ignore [union-attr]
397
+ for choice in choices
398
+ ]
399
+ else:
400
+ return [ # type: ignore [return-value]
401
+ choice.message if choice.message.function_call is not None else _format_content(choice.message.content) # type: ignore [union-attr]
402
+ for choice in choices
403
+ ]
404
+
405
+ @staticmethod
406
+ def _is_agent_name_error_message(message: str) -> bool:
407
+ pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.")
408
+ return bool(pattern.match(message))
409
+
410
+ @staticmethod
411
+ def _move_system_message_to_beginning(messages: list[dict[str, Any]]) -> None:
412
+ for msg in messages:
413
+ if msg["role"] == "system":
414
+ messages.insert(0, messages.pop(messages.index(msg)))
415
+ break
416
+
417
+ @staticmethod
418
+ def _patch_messages_for_deepseek_reasoner(**kwargs: Any) -> Any:
419
+ if (
420
+ "model" not in kwargs
421
+ or kwargs["model"] != "deepseek-reasoner"
422
+ or "messages" not in kwargs
423
+ or len(kwargs["messages"]) == 0
424
+ ):
425
+ return kwargs
426
+
427
+ # The system message of deepseek-reasoner must be put on the beginning of the message sequence.
428
+ OpenAIClient._move_system_message_to_beginning(kwargs["messages"])
429
+
430
+ new_messages = []
431
+ previous_role = None
432
+ for message in kwargs["messages"]:
433
+ if "role" in message:
434
+ current_role = message["role"]
435
+
436
+ # This model requires alternating roles
437
+ if current_role == previous_role:
438
+ # Swap the role
439
+ if current_role == "user":
440
+ message["role"] = "assistant"
441
+ elif current_role == "assistant":
442
+ message["role"] = "user"
443
+
444
+ previous_role = message["role"]
445
+
446
+ new_messages.append(message)
447
+
448
+ # The last message of deepseek-reasoner must be a user message
449
+ # , or an assistant message with prefix mode on (but this is supported only for beta api)
450
+ if new_messages[-1]["role"] != "user":
451
+ new_messages.append({"role": "user", "content": "continue"})
452
+
453
+ kwargs["messages"] = new_messages
454
+
455
+ return kwargs
456
+
457
+ @staticmethod
458
+ def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]:
459
+ def wrapper(*args: Any, **kwargs: Any):
460
+ try:
461
+ kwargs = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
462
+ return func(*args, **kwargs)
463
+ except openai.BadRequestError as e:
464
+ response_json = e.response.json()
465
+ # Check if the error message is related to the agent name. If so, raise a ValueError with a more informative message.
466
+ if (
467
+ "error" in response_json
468
+ and "message" in response_json["error"]
469
+ and OpenAIClient._is_agent_name_error_message(response_json["error"]["message"])
470
+ ):
471
+ error_message = (
472
+ f"This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.\n"
473
+ "Please ensure that your agent name follows the correct format and doesn't include any unsupported characters.\n"
474
+ "Check the agent name and try again.\n"
475
+ f"Here is the full BadRequestError from openai:\n{e.message}."
476
+ )
477
+ raise ValueError(error_message)
478
+
479
+ raise e
480
+
481
+ return wrapper
482
+
483
+ @staticmethod
484
+ def _convert_system_role_to_user(messages: list[dict[str, Any]]) -> None:
485
+ for msg in messages:
486
+ if msg.get("role", "") == "system":
487
+ msg["role"] = "user"
488
+
489
+ def create(self, params: dict[str, Any]) -> ChatCompletion:
490
+ """Create a completion for a given config using openai's client.
491
+
492
+ Args:
493
+ params: The params for the completion.
494
+
495
+ Returns:
496
+ The completion.
497
+ """
498
+ iostream = IOStream.get_default()
499
+
500
+ if self.response_format is not None or "response_format" in params:
501
+
502
+ def _create_or_parse(*args, **kwargs):
503
+ if "stream" in kwargs:
504
+ kwargs.pop("stream")
505
+
506
+ if isinstance(kwargs["response_format"], dict):
507
+ kwargs["response_format"] = {
508
+ "type": "json_schema",
509
+ "json_schema": {
510
+ "schema": _ensure_strict_json_schema(
511
+ kwargs["response_format"], path=(), root=kwargs["response_format"]
512
+ ),
513
+ "name": "response_format",
514
+ "strict": True,
515
+ },
516
+ }
517
+ else:
518
+ kwargs["response_format"] = type_to_response_format_param(
519
+ self.response_format or params["response_format"]
520
+ )
521
+
522
+ return self._oai_client.chat.completions.create(*args, **kwargs)
523
+
524
+ create_or_parse = _create_or_parse
525
+ else:
526
+ completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
527
+ create_or_parse = completions.create
528
+ # Wrap _create_or_parse with exception handling
529
+ create_or_parse = OpenAIClient._handle_openai_bad_request_error(create_or_parse)
530
+
531
+ # needs to be updated when the o3 is released to generalize
532
+ is_o1 = "model" in params and params["model"].startswith("o1")
533
+
534
+ is_mistral = "model" in params and "mistral" in params["model"]
535
+ if is_mistral:
536
+ OpenAIClient._convert_system_role_to_user(params["messages"])
537
+
538
+ # If streaming is enabled and has messages, then iterate over the chunks of the response.
539
+ if params.get("stream", False) and "messages" in params and not is_o1:
540
+ response_contents = [""] * params.get("n", 1)
541
+ finish_reasons = [""] * params.get("n", 1)
542
+ completion_tokens = 0
543
+
544
+ # Prepare for potential function call
545
+ full_function_call: Optional[dict[str, Any]] = None
546
+ full_tool_calls: Optional[list[Optional[dict[str, Any]]]] = None
547
+
548
+ # Send the chat completion request to OpenAI's API and process the response in chunks
549
+ for chunk in create_or_parse(**params):
550
+ if chunk.choices:
551
+ for choice in chunk.choices:
552
+ content = choice.delta.content
553
+ tool_calls_chunks = choice.delta.tool_calls
554
+ finish_reasons[choice.index] = choice.finish_reason
555
+
556
+ # todo: remove this after function calls are removed from the API
557
+ # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail
558
+ # begin block
559
+ function_call_chunk = (
560
+ choice.delta.function_call if hasattr(choice.delta, "function_call") else None
561
+ )
562
+ # Handle function call
563
+ if function_call_chunk:
564
+ # Handle function call
565
+ if function_call_chunk:
566
+ full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
567
+ function_call_chunk, full_function_call, completion_tokens
568
+ )
569
+ if not content:
570
+ continue
571
+ # end block
572
+
573
+ # Handle tool calls
574
+ if tool_calls_chunks:
575
+ for tool_calls_chunk in tool_calls_chunks:
576
+ # the current tool call to be reconstructed
577
+ ix = tool_calls_chunk.index
578
+ if full_tool_calls is None:
579
+ full_tool_calls = []
580
+ if ix >= len(full_tool_calls):
581
+ # in case ix is not sequential
582
+ full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
583
+
584
+ full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
585
+ tool_calls_chunk, full_tool_calls[ix], completion_tokens
586
+ )
587
+ if not content:
588
+ continue
589
+
590
+ # End handle tool calls
591
+
592
+ # If content is present, print it to the terminal and update response variables
593
+ if content is not None:
594
+ iostream.send(StreamEvent(content=content))
595
+ response_contents[choice.index] += content
596
+ completion_tokens += 1
597
+ else:
598
+ pass
599
+
600
+ # Prepare the final ChatCompletion object based on the accumulated data
601
+ model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
602
+ prompt_tokens = count_token(params["messages"], model)
603
+ response = ChatCompletion(
604
+ id=chunk.id,
605
+ model=chunk.model,
606
+ created=chunk.created,
607
+ object="chat.completion",
608
+ choices=[],
609
+ usage=CompletionUsage(
610
+ prompt_tokens=prompt_tokens,
611
+ completion_tokens=completion_tokens,
612
+ total_tokens=prompt_tokens + completion_tokens,
613
+ ),
614
+ )
615
+ for i in range(len(response_contents)):
616
+ if openai_version >= "1.5": # pragma: no cover
617
+ # OpenAI versions 1.5.0 and above
618
+ choice = Choice(
619
+ index=i,
620
+ finish_reason=finish_reasons[i],
621
+ message=ChatCompletionMessage(
622
+ role="assistant",
623
+ content=response_contents[i],
624
+ function_call=full_function_call,
625
+ tool_calls=full_tool_calls,
626
+ ),
627
+ logprobs=None,
628
+ )
629
+ else:
630
+ # OpenAI versions below 1.5.0
631
+ choice = Choice( # type: ignore [call-arg]
632
+ index=i,
633
+ finish_reason=finish_reasons[i],
634
+ message=ChatCompletionMessage(
635
+ role="assistant",
636
+ content=response_contents[i],
637
+ function_call=full_function_call,
638
+ tool_calls=full_tool_calls,
639
+ ),
640
+ )
641
+
642
+ response.choices.append(choice)
643
+ else:
644
+ # If streaming is not enabled, send a regular chat completion request
645
+ params = params.copy()
646
+ if is_o1:
647
+ # add a warning that model does not support stream
648
+ if params.get("stream", False):
649
+ warnings.warn(
650
+ f"The {params.get('model')} model does not support streaming. The stream will be set to False."
651
+ )
652
+ if params.get("tools", False):
653
+ raise ModelToolNotSupportedError(params.get("model"))
654
+ self._process_reasoning_model_params(params)
655
+ params["stream"] = False
656
+ response = create_or_parse(**params)
657
+ # remove the system_message from the response and add it in the prompt at the start.
658
+ if is_o1:
659
+ for msg in params["messages"]:
660
+ if msg["role"] == "user" and msg["content"].startswith("System message: "):
661
+ msg["role"] = "system"
662
+ msg["content"] = msg["content"][len("System message: ") :]
663
+
664
+ return response
665
+
666
+ def _process_reasoning_model_params(self, params: dict[str, Any]) -> None:
667
+ """Cater for the reasoning model (o1, o3..) parameters
668
+ please refer: https://platform.openai.com/docs/guides/reasoning#limitations
669
+ """
670
+ # Unsupported parameters
671
+ unsupported_params = [
672
+ "temperature",
673
+ "frequency_penalty",
674
+ "presence_penalty",
675
+ "top_p",
676
+ "logprobs",
677
+ "top_logprobs",
678
+ "logit_bias",
679
+ ]
680
+ model_name = params.get("model")
681
+ for param in unsupported_params:
682
+ if param in params:
683
+ warnings.warn(f"`{param}` is not supported with {model_name} model and will be ignored.")
684
+ params.pop(param)
685
+ # Replace max_tokens with max_completion_tokens as reasoning tokens are now factored in
686
+ # and max_tokens isn't valid
687
+ if "max_tokens" in params:
688
+ params["max_completion_tokens"] = params.pop("max_tokens")
689
+
690
+ # TODO - When o1-mini and o1-preview point to newer models (e.g. 2024-12-...), remove them from this list but leave the 2024-09-12 dated versions
691
+ system_not_allowed = model_name in ("o1-mini", "o1-preview", "o1-mini-2024-09-12", "o1-preview-2024-09-12")
692
+
693
+ if "messages" in params and system_not_allowed:
694
+ # o1-mini (2024-09-12) and o1-preview (2024-09-12) don't support role='system' messages, only 'user' and 'assistant'
695
+ # replace the system messages with user messages preappended with "System message: "
696
+ for msg in params["messages"]:
697
+ if msg["role"] == "system":
698
+ msg["role"] = "user"
699
+ msg["content"] = f"System message: {msg['content']}"
700
+
701
+ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
702
+ """Calculate the cost of the response."""
703
+ model = response.model
704
+ if model not in OAI_PRICE1K:
705
+ # log warning that the model is not found
706
+ logger.warning(
707
+ f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
708
+ )
709
+ return 0
710
+
711
+ n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
712
+ n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
713
+ if n_output_tokens is None:
714
+ n_output_tokens = 0
715
+ tmp_price1K = OAI_PRICE1K[model] # noqa: N806
716
+ # First value is input token rate, second value is output token rate
717
+ if isinstance(tmp_price1K, tuple):
718
+ return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return]
719
+ return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]
720
+
721
+ @staticmethod
722
+ def get_usage(response: Union[ChatCompletion, Completion]) -> dict:
723
+ return {
724
+ "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
725
+ "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
726
+ "total_tokens": response.usage.total_tokens if response.usage is not None else 0,
727
+ "cost": response.cost if hasattr(response, "cost") else 0,
728
+ "model": response.model,
729
+ }
730
+
731
+
732
+ @export_module("autogen")
733
+ class OpenAIWrapper:
734
+ """A wrapper class for openai client."""
735
+
736
+ extra_kwargs = {
737
+ "agent",
738
+ "cache",
739
+ "cache_seed",
740
+ "filter_func",
741
+ "allow_format_str_template",
742
+ "context",
743
+ "api_version",
744
+ "api_type",
745
+ "tags",
746
+ "price",
747
+ }
748
+
749
+ @property
750
+ def openai_kwargs(self) -> set[str]:
751
+ if openai_result.is_successful:
752
+ return set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) | set(
753
+ inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs
754
+ )
755
+ else:
756
+ return OPENAI_FALLBACK_KWARGS | AOPENAI_FALLBACK_KWARGS
757
+
758
+ total_usage_summary: Optional[dict[str, Any]] = None
759
+ actual_usage_summary: Optional[dict[str, Any]] = None
760
+
761
+ def __init__(
762
+ self,
763
+ *,
764
+ config_list: Optional[list[dict[str, Any]]] = None,
765
+ **base_config: Any,
766
+ ):
767
+ """Initialize the OpenAIWrapper.
768
+
769
+ Args:
770
+ config_list: a list of config dicts to override the base_config.
771
+ They can contain additional kwargs as allowed in the [create](https://docs.ag2.ai/latest/docs/api-reference/autogen/OpenAIWrapper/#autogen.OpenAIWrapper.create) method. E.g.,
772
+
773
+ ```python
774
+ config_list = [
775
+ {
776
+ "model": "gpt-4",
777
+ "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
778
+ "api_type": "azure",
779
+ "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
780
+ "api_version": "2024-02-01",
781
+ },
782
+ {
783
+ "model": "gpt-3.5-turbo",
784
+ "api_key": os.environ.get("OPENAI_API_KEY"),
785
+ "base_url": "https://api.openai.com/v1",
786
+ },
787
+ {
788
+ "model": "llama-7B",
789
+ "base_url": "http://127.0.0.1:8080",
790
+ },
791
+ ]
792
+ ```
793
+
794
+ base_config: base config. It can contain both keyword arguments for openai client
795
+ and additional kwargs.
796
+ When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `base_config` or in each config of `config_list`.
797
+ """
798
+ if logging_enabled():
799
+ log_new_wrapper(self, locals())
800
+ openai_config, extra_kwargs = self._separate_openai_config(base_config)
801
+ # It's OK if "model" is not provided in base_config or config_list
802
+ # Because one can provide "model" at `create` time.
803
+
804
+ self._clients: list[ModelClient] = []
805
+ self._config_list: list[dict[str, Any]] = []
806
+
807
+ if config_list:
808
+ config_list = [config.copy() for config in config_list] # make a copy before modifying
809
+ for config in config_list:
810
+ self._register_default_client(config, openai_config) # could modify the config
811
+ self._config_list.append({
812
+ **extra_kwargs,
813
+ **{k: v for k, v in config.items() if k not in self.openai_kwargs},
814
+ })
815
+ else:
816
+ self._register_default_client(extra_kwargs, openai_config)
817
+ self._config_list = [extra_kwargs]
818
+ self.wrapper_id = id(self)
819
+
820
+ def _separate_openai_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
821
+ """Separate the config into openai_config and extra_kwargs."""
822
+ openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
823
+ extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
824
+ return openai_config, extra_kwargs
825
+
826
+ def _separate_create_config(self, config: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
827
+ """Separate the config into create_config and extra_kwargs."""
828
+ create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
829
+ extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
830
+ return create_config, extra_kwargs
831
+
832
+ def _configure_azure_openai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
833
+ openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
834
+ if openai_config["azure_deployment"] is not None:
835
+ openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
836
+ openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
837
+
838
+ # Create a default Azure token provider if requested
839
+ if openai_config.get("azure_ad_token_provider") == "DEFAULT":
840
+ import azure.identity
841
+
842
+ openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
843
+ azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
844
+ )
845
+
846
+ def _configure_openai_config_for_bedrock(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
847
+ """Update openai_config with AWS credentials from config."""
848
+ required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
849
+ optional_keys = ["aws_session_token", "aws_profile_name"]
850
+ for key in required_keys:
851
+ if key in config:
852
+ openai_config[key] = config[key]
853
+ for key in optional_keys:
854
+ if key in config:
855
+ openai_config[key] = config[key]
856
+
857
+ def _configure_openai_config_for_vertextai(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
858
+ """Update openai_config with Google credentials from config."""
859
+ required_keys = ["gcp_project_id", "gcp_region", "gcp_auth_token"]
860
+ for key in required_keys:
861
+ if key in config:
862
+ openai_config[key] = config[key]
863
+
864
+    def _register_default_client(self, config: dict[str, Any], openai_config: dict[str, Any]) -> None:
+        """Create a client with the given config to override openai_config,
+        after removing extra kwargs.
+
+        For Azure models/deployment names there is a convenience modification that removes dots
+        from the model value (Azure deployment names can't contain dots). I.e. if you have an Azure
+        deployment named "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config, the function
+        removes the dot from the name and creates a client that connects to the "gpt-35-turbo"
+        Azure deployment.
+        """
+        openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
+        api_type = config.get("api_type")
+        model_client_cls_name = config.get("model_client_cls")
+        response_format = config.get("response_format")
+        if model_client_cls_name is not None:
+            # a config for a custom client is set
+            # add a placeholder until register_model_client is called with the appropriate class
+            self._clients.append(PlaceHolderClient(config))
+            logger.info(
+                f"Detected custom model client in config: {model_client_cls_name}; the model client cannot be used until register_model_client is called."
+            )
+            # TODO: logging for custom client
+        else:
+            if api_type is not None and api_type.startswith("azure"):
+
+                @require_optional_import("openai>=1.66.2", "openai")
+                def create_azure_openai_client() -> "AzureOpenAI":
+                    self._configure_azure_openai(config, openai_config)
+                    client = AzureOpenAI(**openai_config)
+                    self._clients.append(OpenAIClient(client, response_format=response_format))
+                    return client
+
+                client = create_azure_openai_client()
+            elif api_type is not None and api_type.startswith("cerebras"):
+                if cerebras_import_exception:
+                    raise ImportError("Please install `cerebras_cloud_sdk` to use the Cerebras OpenAI API.")
+                client = CerebrasClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("google"):
+                if gemini_import_exception:
+                    raise ImportError("Please install `google-genai` and `vertexai` to use Google's API.")
+                client = GeminiClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("anthropic"):
+                if "api_key" not in config and "aws_region" in config:
+                    self._configure_openai_config_for_bedrock(config, openai_config)
+                elif "api_key" not in config and "gcp_region" in config:
+                    self._configure_openai_config_for_vertextai(config, openai_config)
+                if anthropic_import_exception:
+                    raise ImportError("Please install `anthropic` to use the Anthropic API.")
+                client = AnthropicClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("mistral"):
+                if mistral_import_exception:
+                    raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
+                client = MistralAIClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("together"):
+                if together_import_exception:
+                    raise ImportError("Please install `together` to use the Together.AI API.")
+                client = TogetherClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("groq"):
+                if groq_import_exception:
+                    raise ImportError("Please install `groq` to use the Groq API.")
+                client = GroqClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("cohere"):
+                if cohere_import_exception:
+                    raise ImportError("Please install `cohere` to use the Cohere API.")
+                client = CohereClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("ollama"):
+                if ollama_import_exception:
+                    raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
+                client = OllamaClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            elif api_type is not None and api_type.startswith("bedrock"):
+                self._configure_openai_config_for_bedrock(config, openai_config)
+                if bedrock_import_exception:
+                    raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
+                client = BedrockClient(response_format=response_format, **openai_config)
+                self._clients.append(client)
+            else:
+
+                @require_optional_import("openai>=1.66.2", "openai")
+                def create_openai_client() -> "OpenAI":
+                    client = OpenAI(**openai_config)
+                    self._clients.append(OpenAIClient(client, response_format))
+                    return client
+
+                client = create_openai_client()
+
+            if logging_enabled():
+                log_new_client(client, self, openai_config)
+
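For orientation, hedged example entries showing how api_type selects a client class in the dispatch above (models, URLs, and keys are placeholders):

    config_list = [
        {"api_type": "azure", "model": "gpt-4o", "base_url": "https://EXAMPLE.openai.azure.com/", "api_key": "..."},
        {"api_type": "anthropic", "model": "claude-3-5-sonnet-20240620", "api_key": "..."},
        {"api_type": "ollama", "model": "llama3.1"},  # local server, no key needed
        {"model": "gpt-4o-mini", "api_key": "..."},   # no api_type -> plain OpenAI client
    ]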
+    def register_model_client(self, model_client_cls: ModelClient, **kwargs: Any):
+        """Register a model client.
+
+        Args:
+            model_client_cls: A custom client class that follows the ModelClient interface
+            kwargs: The kwargs for the custom client class to be initialized with
+        """
+        existing_client_class = False
+        for i, client in enumerate(self._clients):
+            if isinstance(client, PlaceHolderClient):
+                placeholder_config = client.config
+
+                if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
+                    self._clients[i] = model_client_cls(placeholder_config, **kwargs)
+                    return
+            elif isinstance(client, model_client_cls):
+                existing_client_class = True
+
+        if existing_client_class:
+            logger.warning(
+                f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients."
+            )
+        else:
+            raise ValueError(
+                f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. '
+                f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"'
+            )
+
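A hedged sketch of the registration flow; `MyModelClient` is a hypothetical class implementing the ModelClient protocol defined earlier in this file:

    config_list = [{"model": "my-local-model", "model_client_cls": "MyModelClient"}]
    wrapper = OpenAIWrapper(config_list=config_list)  # stores a PlaceHolderClient for the entry
    wrapper.register_model_client(MyModelClient)      # replaces the placeholder with a real client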
+    @classmethod
+    def instantiate(
+        cls,
+        template: Optional[Union[str, Callable[[dict[str, Any]], str]]],
+        context: Optional[dict[str, Any]] = None,
+        allow_format_str_template: Optional[bool] = False,
+    ) -> Optional[str]:
+        if not context or template is None:
+            return template  # type: ignore [return-value]
+        if isinstance(template, str):
+            return template.format(**context) if allow_format_str_template else template
+        return template(context)
+
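Behavior sketch (values illustrative): a plain string is returned verbatim unless allow_format_str_template is set, and a callable template receives the context dict:

    OpenAIWrapper.instantiate("Hi {name}", {"name": "Ada"})                      # -> "Hi {name}"
    OpenAIWrapper.instantiate("Hi {name}", {"name": "Ada"}, True)                # -> "Hi Ada"
    OpenAIWrapper.instantiate(lambda ctx: f"Hi {ctx['name']}", {"name": "Ada"})  # -> "Hi Ada"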
+    def _construct_create_params(self, create_config: dict[str, Any], extra_kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Prime the create_config with additional_kwargs."""
+        # Validate the config
+        prompt: Optional[str] = create_config.get("prompt")
+        messages: Optional[list[dict[str, Any]]] = create_config.get("messages")
+        if (prompt is None) == (messages is None):
+            raise ValueError("Either prompt or messages should be in create config but not both.")
+        context = extra_kwargs.get("context")
+        if context is None:
+            # No need to instantiate if no context is provided.
+            return create_config
+        # Instantiate the prompt or messages
+        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
+        # Make a copy of the config
+        params = create_config.copy()
+        if prompt is not None:
+            # Instantiate the prompt
+            params["prompt"] = self.instantiate(prompt, context, allow_format_str_template)
+        elif context:
+            # Instantiate the messages
+            params["messages"] = [
+                (
+                    {
+                        **m,
+                        "content": self.instantiate(m["content"], context, allow_format_str_template),
+                    }
+                    if m.get("content")
+                    else m
+                )
+                for m in messages  # type: ignore [union-attr]
+            ]
+        return params
+
+    def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
+        """Make a completion for a given config using the available clients.
+
+        Besides the kwargs allowed by the underlying OpenAI (or other) client, wrapper-level kwargs
+        such as `context`, `cache`/`cache_seed`, `filter_func`, `allow_format_str_template`, `price`,
+        and `agent` are accepted; their handling is shown below. For each client, the call-time
+        config is merged with that client's stored config, with the stored config taking precedence.
+
+        Args:
+            **config: The config for the completion.
+
+        Raises:
+            RuntimeError: If all declared custom model clients are not registered
+            APIError: If any model client create call raises an APIError
+        """
+        # if ERROR:
+        #     raise ERROR
+        invocation_id = str(uuid.uuid4())
+        last = len(self._clients) - 1
+        # Check that all configs in the config list are activated
+        non_activated = [
+            client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient)
+        ]
+        if non_activated:
+            raise RuntimeError(
+                f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out from the config list."
+            )
+        for i, client in enumerate(self._clients):
+            # merge the input config with the i-th config in the config list
+            full_config = {**config, **self._config_list[i]}
+            # separate the config into create_config and extra_kwargs
+            create_config, extra_kwargs = self._separate_create_config(full_config)
+            api_type = extra_kwargs.get("api_type")
+            if api_type and api_type.startswith("azure") and "model" in create_config:
+                create_config["model"] = create_config["model"].replace(".", "")
+            # construct the create params
+            params = self._construct_create_params(create_config, extra_kwargs)
+            # get the cache_seed, filter_func and context
+            cache_seed = extra_kwargs.get("cache_seed")
+            cache = extra_kwargs.get("cache")
+            filter_func = extra_kwargs.get("filter_func")
+            context = extra_kwargs.get("context")
+            agent = extra_kwargs.get("agent")
+            price = extra_kwargs.get("price", None)
+            if isinstance(price, list):
+                price = tuple(price)
+            elif isinstance(price, (float, int)):
+                logger.warning(
+                    "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
+                )
+                price = (price, price)
+
+            total_usage = None
+            actual_usage = None
+
+            cache_client = None
+            if cache is not None:
+                # Use the cache object if provided.
+                cache_client = cache
+            elif cache_seed is not None:
+                # Legacy cache behavior: if cache_seed is given, use DiskCache.
+                cache_client = Cache.disk(cache_seed, LEGACY_CACHE_DIR)
+
+            log_cache_seed_value(cache if cache is not None else cache_seed, client=client)
+
+            if cache_client is not None:
+                with cache_client as cache:
+                    # Try to get the response from cache
+                    key = get_key(
+                        {
+                            **params,
+                            **{"response_format": json.dumps(TypeAdapter(params["response_format"]).json_schema())},
+                        }
+                        if "response_format" in params and not isinstance(params["response_format"], dict)
+                        else params
+                    )
+                    request_ts = get_current_ts()
+
+                    response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
+
+                    if response is not None:
+                        response.message_retrieval_function = client.message_retrieval
+                        try:
+                            response.cost  # type: ignore [attr-defined]
+                        except AttributeError:
+                            # update the attribute if the cost has not been calculated
+                            response.cost = client.cost(response)
+                            cache.set(key, response)
+                        total_usage = client.get_usage(response)
+
+                        if logging_enabled():
+                            # Log the cache hit
+                            # TODO: log the config_id and pass_filter etc.
+                            log_chat_completion(
+                                invocation_id=invocation_id,
+                                client_id=id(client),
+                                wrapper_id=id(self),
+                                agent=agent,
+                                request=params,
+                                response=response,
+                                is_cached=1,
+                                cost=response.cost,
+                                start_time=request_ts,
+                            )
+
+                        # check the filter
+                        pass_filter = filter_func is None or filter_func(context=context, response=response)
+                        if pass_filter or i == last:
+                            # Return the response if it passes the filter or this is the last client
+                            response.config_id = i
+                            response.pass_filter = pass_filter
+                            self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
+                            return response
+                        continue  # filter is not passed; try the next config
+            try:
+                request_ts = get_current_ts()
+                response = client.create(params)
+            except (
+                gemini_InternalServerError,
+                gemini_ResourceExhausted,
+                anthorpic_InternalServerError,
+                anthorpic_RateLimitError,
+                mistral_SDKError,
+                mistral_HTTPValidationError,
+                together_TogetherException,
+                groq_InternalServerError,
+                groq_RateLimitError,
+                groq_APIConnectionError,
+                cohere_InternalServerError,
+                cohere_TooManyRequestsError,
+                cohere_ServiceUnavailableError,
+                ollama_RequestError,
+                ollama_ResponseError,
+                bedrock_BotoCoreError,
+                bedrock_ClientError,
+                cerebras_AuthenticationError,
+                cerebras_InternalServerError,
+                cerebras_RateLimitError,
+            ):
+                # NOTE: this clause must precede the generic handler below, or it would never run.
+                # logger.debug(f"config {i} failed", exc_info=True)
+                if i == last:
+                    raise
+            except Exception as e:
+                if APITimeoutError is not None and isinstance(e, APITimeoutError):
+                    # logger.debug(f"config {i} timed out", exc_info=True)
+                    if i == last:
+                        raise TimeoutError(
+                            "OpenAI API call timed out. This could be due to congestion or too small a timeout value. The timeout can be specified by setting the 'timeout' value (in seconds) in the llm_config (if you are using agents) or the OpenAIWrapper constructor (if you are using the OpenAIWrapper directly)."
+                        ) from e
+                elif APIError is not None and isinstance(e, APIError):
+                    error_code = getattr(e, "code", None)
+                    if logging_enabled():
+                        log_chat_completion(
+                            invocation_id=invocation_id,
+                            client_id=id(client),
+                            wrapper_id=id(self),
+                            agent=agent,
+                            request=params,
+                            response=f"error_code:{error_code}, config {i} failed",
+                            is_cached=0,
+                            cost=0,
+                            start_time=request_ts,
+                        )
+
+                    if error_code == "content_filter":
+                        # raise the error for content_filter
+                        raise
+                    # logger.debug(f"config {i} failed", exc_info=True)
+                    if i == last:
+                        raise
+                else:
+                    raise
+            else:
+                # add the cost calculation before caching, whether or not the filter passes
+                if price is not None:
+                    response.cost = self._cost_with_customized_price(response, price)
+                else:
+                    response.cost = client.cost(response)
+                actual_usage = client.get_usage(response)
+                total_usage = actual_usage.copy() if actual_usage is not None else total_usage
+                self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
+
+                if cache_client is not None:
+                    # Cache the response
+                    with cache_client as cache:
+                        cache.set(key, response)
+
+                if logging_enabled():
+                    # TODO: log the config_id and pass_filter etc.
+                    log_chat_completion(
+                        invocation_id=invocation_id,
+                        client_id=id(client),
+                        wrapper_id=id(self),
+                        agent=agent,
+                        request=params,
+                        response=response,
+                        is_cached=0,
+                        cost=response.cost,
+                        start_time=request_ts,
+                    )
+
+                response.message_retrieval_function = client.message_retrieval
+                # check the filter
+                pass_filter = filter_func is None or filter_func(context=context, response=response)
+                if pass_filter or i == last:
+                    # Return the response if it passes the filter or this is the last client
+                    response.config_id = i
+                    response.pass_filter = pass_filter
+                    return response
+                continue  # filter is not passed; try the next config
+        raise RuntimeError("Should not reach here.")
+
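A hedged end-to-end sketch of create() with templating, caching, and a filter (message content and filter are illustrative):

    response = wrapper.create(
        messages=[{"role": "user", "content": "Summarize {topic}."}],
        context={"topic": "RFC 9110"},
        allow_format_str_template=True,
        cache_seed=42,  # legacy disk cache keyed on the request
        filter_func=lambda context, response: True,  # accept everything; otherwise fail over to the next config
    )
    text = wrapper.extract_text_or_completion_object(response)[0]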
+    @staticmethod
+    def _cost_with_customized_price(
+        response: ModelClient.ModelClientResponseProtocol, price_1k: tuple[float, float]
+    ) -> float:
+        """Compute the cost from a user-supplied (prompt, completion) price pair, expressed per 1k tokens."""
+        n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
+        n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
+        if n_output_tokens is None:
+            n_output_tokens = 0
+        return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
+
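Worked example (numbers illustrative): with price_1k = (0.5, 1.5) dollars per 1k tokens, a response with 1,000 prompt tokens and 500 completion tokens costs (1000 * 0.5 + 500 * 1.5) / 1000 = 1.25.

    assert (1000 * 0.5 + 500 * 1.5) / 1000 == 1.25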
+    @staticmethod
+    def _update_dict_from_chunk(chunk: BaseModel, d: dict[str, Any], field: str) -> int:
+        """Update the dict from the chunk.
+
+        Reads `chunk.field` and, if present, updates `d[field]` accordingly.
+
+        Args:
+            chunk: The chunk.
+            d: The dict to be updated in place.
+            field: The field.
+
+        Returns:
+            1 if a non-string value was assigned, otherwise 0 (string fragments are
+            concatenated onto `d[field]` without changing the count).
+        """
+        completion_tokens = 0
+        assert isinstance(d, dict), d
+        if hasattr(chunk, field) and getattr(chunk, field) is not None:
+            new_value = getattr(chunk, field)
+            if isinstance(new_value, (list, dict)):
+                raise NotImplementedError(
+                    f"Field {field} is a list or dict, which is currently not supported. "
+                    "Only strings and numbers are supported."
+                )
+            if field not in d:
+                d[field] = ""
+            if isinstance(new_value, str):
+                d[field] += getattr(chunk, field)
+            else:
+                d[field] = new_value
+                completion_tokens = 1
+
+        return completion_tokens
+
+    @staticmethod
+    def _update_function_call_from_chunk(
+        function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
+        full_function_call: Optional[dict[str, Any]],
+        completion_tokens: int,
+    ) -> tuple[dict[str, Any], int]:
+        """Update the function call from the chunk.
+
+        Args:
+            function_call_chunk: The function call chunk.
+            full_function_call: The full function call.
+            completion_tokens: The number of completion tokens.
+
+        Returns:
+            The updated full function call and the updated number of completion tokens.
+        """
+        # Handle function call
+        if function_call_chunk:
+            if full_function_call is None:
+                full_function_call = {}
+            for field in ["name", "arguments"]:
+                completion_tokens += OpenAIWrapper._update_dict_from_chunk(
+                    function_call_chunk, full_function_call, field
+                )
+
+        if full_function_call:
+            return full_function_call, completion_tokens
+        else:
+            raise RuntimeError("Function call is not found, this should not happen.")
+
+    @staticmethod
+    def _update_tool_calls_from_chunk(
+        tool_calls_chunk: ChoiceDeltaToolCall,
+        full_tool_call: Optional[dict[str, Any]],
+        completion_tokens: int,
+    ) -> tuple[dict[str, Any], int]:
+        """Update the tool call from the chunk.
+
+        Args:
+            tool_calls_chunk: The tool call chunk.
+            full_tool_call: The full tool call.
+            completion_tokens: The number of completion tokens.
+
+        Returns:
+            The updated full tool call and the updated number of completion tokens.
+        """
+        # future-proofing for when tool calls other than function calls are supported
+        if tool_calls_chunk.type and tool_calls_chunk.type != "function":
+            raise NotImplementedError(
+                f"Tool call type {tool_calls_chunk.type} is currently not supported. Only function calls are supported."
+            )
+
+        # Handle tool call
+        assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
+        if tool_calls_chunk:
+            if full_tool_call is None:
+                full_tool_call = {}
+            for field in ["index", "id", "type"]:
+                completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
+
+            if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
+                if "function" not in full_tool_call:
+                    full_tool_call["function"] = None
+
+                full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+                    tool_calls_chunk.function, full_tool_call["function"], completion_tokens
+                )
+
+        if full_tool_call:
+            return full_tool_call, completion_tokens
+        else:
+            raise RuntimeError("Tool call is not found, this should not happen.")
+
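A hedged sketch of accumulating a streamed tool call with these helpers, using openai's delta types (field values illustrative):

    from openai.types.chat.chat_completion_chunk import (
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )

    chunks = [
        ChoiceDeltaToolCall(index=0, id="call_1", type="function",
                            function=ChoiceDeltaToolCallFunction(name="get_weather", arguments='{"city": ')),
        ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris"}')),
    ]
    full, n_tokens = None, 0
    for chunk in chunks:
        full, n_tokens = OpenAIWrapper._update_tool_calls_from_chunk(chunk, full, n_tokens)
    # full["function"]["arguments"] == '{"city": "Paris"}'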
+    def _update_usage(self, actual_usage, total_usage):
+        def update_usage(usage_summary, response_usage):
+            # go through RESPONSE_USAGE_KEYS; if any is missing from response_usage, just return usage_summary
+            for key in ModelClient.RESPONSE_USAGE_KEYS:
+                if key not in response_usage:
+                    return usage_summary
+
+            model = response_usage["model"]
+            cost = response_usage["cost"]
+            prompt_tokens = response_usage["prompt_tokens"]
+            completion_tokens = response_usage["completion_tokens"]
+            if completion_tokens is None:
+                completion_tokens = 0
+            total_tokens = response_usage["total_tokens"]
+
+            if usage_summary is None:
+                usage_summary = {"total_cost": cost}
+            else:
+                usage_summary["total_cost"] += cost
+
+            usage_summary[model] = {
+                "cost": usage_summary.get(model, {}).get("cost", 0) + cost,
+                "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens,
+                "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens,
+                "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens,
+            }
+            return usage_summary
+
+        if total_usage is not None:
+            self.total_usage_summary = update_usage(self.total_usage_summary, total_usage)
+        if actual_usage is not None:
+            self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage)
+
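For reference, a hedged sketch of the per-model summary shape this builds (numbers illustrative):

    {
        "total_cost": 0.0042,
        "gpt-4o-mini": {"cost": 0.0042, "prompt_tokens": 120, "completion_tokens": 80, "total_tokens": 200},
    }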
+    def print_usage_summary(self, mode: Union[str, list[str]] = ["actual", "total"]) -> None:
+        """Print the usage summary."""
+        iostream = IOStream.get_default()
+
+        if isinstance(mode, list):
+            if len(mode) == 0 or len(mode) > 2:
+                raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
+            if "actual" in mode and "total" in mode:
+                mode = "both"
+            elif "actual" in mode:
+                mode = "actual"
+            elif "total" in mode:
+                mode = "total"
+
+        iostream.send(
+            UsageSummaryEvent(
+                actual_usage_summary=self.actual_usage_summary, total_usage_summary=self.total_usage_summary, mode=mode
+            )
+        )
+
+    def clear_usage_summary(self) -> None:
+        """Clear the usage summary."""
+        self.total_usage_summary = None
+        self.actual_usage_summary = None
+
+    @classmethod
+    def extract_text_or_completion_object(
+        cls, response: ModelClient.ModelClientResponseProtocol
+    ) -> Union[list[str], list[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
+        """Extract the text or ChatCompletion objects from a completion or chat response.
+
+        Args:
+            response (ChatCompletion | Completion): The response from openai.
+
+        Returns:
+            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
+        """
+        return response.message_retrieval_function(response)
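Hedged illustration: for a plain chat response the result is the assistant text of each choice; when function_call/tool_calls are present, the message objects themselves are returned:

    outputs = OpenAIWrapper.extract_text_or_completion_object(response)
    if isinstance(outputs[0], str):
        print(outputs[0])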