hexdag 0.5.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hexdag/__init__.py +116 -0
- hexdag/__main__.py +30 -0
- hexdag/adapters/executors/__init__.py +5 -0
- hexdag/adapters/executors/local_executor.py +316 -0
- hexdag/builtin/__init__.py +6 -0
- hexdag/builtin/adapters/__init__.py +51 -0
- hexdag/builtin/adapters/anthropic/__init__.py +5 -0
- hexdag/builtin/adapters/anthropic/anthropic_adapter.py +151 -0
- hexdag/builtin/adapters/database/__init__.py +6 -0
- hexdag/builtin/adapters/database/csv/csv_adapter.py +249 -0
- hexdag/builtin/adapters/database/pgvector/__init__.py +5 -0
- hexdag/builtin/adapters/database/pgvector/pgvector_adapter.py +478 -0
- hexdag/builtin/adapters/database/sqlalchemy/sqlalchemy_adapter.py +252 -0
- hexdag/builtin/adapters/database/sqlite/__init__.py +5 -0
- hexdag/builtin/adapters/database/sqlite/sqlite_adapter.py +410 -0
- hexdag/builtin/adapters/local/README.md +59 -0
- hexdag/builtin/adapters/local/__init__.py +7 -0
- hexdag/builtin/adapters/local/local_observer_manager.py +696 -0
- hexdag/builtin/adapters/memory/__init__.py +47 -0
- hexdag/builtin/adapters/memory/file_memory_adapter.py +297 -0
- hexdag/builtin/adapters/memory/in_memory_memory.py +216 -0
- hexdag/builtin/adapters/memory/schemas.py +57 -0
- hexdag/builtin/adapters/memory/session_memory.py +178 -0
- hexdag/builtin/adapters/memory/sqlite_memory_adapter.py +215 -0
- hexdag/builtin/adapters/memory/state_memory.py +280 -0
- hexdag/builtin/adapters/mock/README.md +89 -0
- hexdag/builtin/adapters/mock/__init__.py +15 -0
- hexdag/builtin/adapters/mock/hexdag.toml +50 -0
- hexdag/builtin/adapters/mock/mock_database.py +225 -0
- hexdag/builtin/adapters/mock/mock_embedding.py +223 -0
- hexdag/builtin/adapters/mock/mock_llm.py +177 -0
- hexdag/builtin/adapters/mock/mock_tool_adapter.py +192 -0
- hexdag/builtin/adapters/mock/mock_tool_router.py +232 -0
- hexdag/builtin/adapters/openai/__init__.py +5 -0
- hexdag/builtin/adapters/openai/openai_adapter.py +634 -0
- hexdag/builtin/adapters/secret/__init__.py +7 -0
- hexdag/builtin/adapters/secret/local_secret_adapter.py +248 -0
- hexdag/builtin/adapters/unified_tool_router.py +280 -0
- hexdag/builtin/macros/__init__.py +17 -0
- hexdag/builtin/macros/conversation_agent.py +390 -0
- hexdag/builtin/macros/llm_macro.py +151 -0
- hexdag/builtin/macros/reasoning_agent.py +423 -0
- hexdag/builtin/macros/tool_macro.py +380 -0
- hexdag/builtin/nodes/__init__.py +38 -0
- hexdag/builtin/nodes/_discovery.py +123 -0
- hexdag/builtin/nodes/agent_node.py +696 -0
- hexdag/builtin/nodes/base_node_factory.py +242 -0
- hexdag/builtin/nodes/composite_node.py +926 -0
- hexdag/builtin/nodes/data_node.py +201 -0
- hexdag/builtin/nodes/expression_node.py +487 -0
- hexdag/builtin/nodes/function_node.py +454 -0
- hexdag/builtin/nodes/llm_node.py +491 -0
- hexdag/builtin/nodes/loop_node.py +920 -0
- hexdag/builtin/nodes/mapped_input.py +518 -0
- hexdag/builtin/nodes/port_call_node.py +269 -0
- hexdag/builtin/nodes/tool_call_node.py +195 -0
- hexdag/builtin/nodes/tool_utils.py +390 -0
- hexdag/builtin/prompts/__init__.py +68 -0
- hexdag/builtin/prompts/base.py +422 -0
- hexdag/builtin/prompts/chat_prompts.py +303 -0
- hexdag/builtin/prompts/error_correction_prompts.py +320 -0
- hexdag/builtin/prompts/tool_prompts.py +160 -0
- hexdag/builtin/tools/builtin_tools.py +84 -0
- hexdag/builtin/tools/database_tools.py +164 -0
- hexdag/cli/__init__.py +17 -0
- hexdag/cli/__main__.py +7 -0
- hexdag/cli/commands/__init__.py +27 -0
- hexdag/cli/commands/build_cmd.py +812 -0
- hexdag/cli/commands/create_cmd.py +208 -0
- hexdag/cli/commands/docs_cmd.py +293 -0
- hexdag/cli/commands/generate_types_cmd.py +252 -0
- hexdag/cli/commands/init_cmd.py +188 -0
- hexdag/cli/commands/pipeline_cmd.py +494 -0
- hexdag/cli/commands/plugin_dev_cmd.py +529 -0
- hexdag/cli/commands/plugins_cmd.py +441 -0
- hexdag/cli/commands/studio_cmd.py +101 -0
- hexdag/cli/commands/validate_cmd.py +221 -0
- hexdag/cli/main.py +84 -0
- hexdag/core/__init__.py +83 -0
- hexdag/core/config/__init__.py +20 -0
- hexdag/core/config/loader.py +479 -0
- hexdag/core/config/models.py +150 -0
- hexdag/core/configurable.py +294 -0
- hexdag/core/context/__init__.py +37 -0
- hexdag/core/context/execution_context.py +378 -0
- hexdag/core/docs/__init__.py +26 -0
- hexdag/core/docs/extractors.py +678 -0
- hexdag/core/docs/generators.py +890 -0
- hexdag/core/docs/models.py +120 -0
- hexdag/core/domain/__init__.py +10 -0
- hexdag/core/domain/dag.py +1225 -0
- hexdag/core/exceptions.py +234 -0
- hexdag/core/expression_parser.py +569 -0
- hexdag/core/logging.py +449 -0
- hexdag/core/models/__init__.py +17 -0
- hexdag/core/models/base.py +138 -0
- hexdag/core/orchestration/__init__.py +46 -0
- hexdag/core/orchestration/body_executor.py +481 -0
- hexdag/core/orchestration/components/__init__.py +97 -0
- hexdag/core/orchestration/components/adapter_lifecycle_manager.py +113 -0
- hexdag/core/orchestration/components/checkpoint_manager.py +134 -0
- hexdag/core/orchestration/components/execution_coordinator.py +360 -0
- hexdag/core/orchestration/components/health_check_manager.py +176 -0
- hexdag/core/orchestration/components/input_mapper.py +143 -0
- hexdag/core/orchestration/components/lifecycle_manager.py +583 -0
- hexdag/core/orchestration/components/node_executor.py +377 -0
- hexdag/core/orchestration/components/secret_manager.py +202 -0
- hexdag/core/orchestration/components/wave_executor.py +158 -0
- hexdag/core/orchestration/constants.py +17 -0
- hexdag/core/orchestration/events/README.md +312 -0
- hexdag/core/orchestration/events/__init__.py +104 -0
- hexdag/core/orchestration/events/batching.py +330 -0
- hexdag/core/orchestration/events/decorators.py +139 -0
- hexdag/core/orchestration/events/events.py +573 -0
- hexdag/core/orchestration/events/observers/__init__.py +30 -0
- hexdag/core/orchestration/events/observers/core_observers.py +690 -0
- hexdag/core/orchestration/events/observers/models.py +111 -0
- hexdag/core/orchestration/events/taxonomy.py +269 -0
- hexdag/core/orchestration/hook_context.py +237 -0
- hexdag/core/orchestration/hooks.py +437 -0
- hexdag/core/orchestration/models.py +418 -0
- hexdag/core/orchestration/orchestrator.py +910 -0
- hexdag/core/orchestration/orchestrator_factory.py +275 -0
- hexdag/core/orchestration/port_wrappers.py +327 -0
- hexdag/core/orchestration/prompt/__init__.py +32 -0
- hexdag/core/orchestration/prompt/template.py +332 -0
- hexdag/core/pipeline_builder/__init__.py +21 -0
- hexdag/core/pipeline_builder/component_instantiator.py +386 -0
- hexdag/core/pipeline_builder/include_tag.py +265 -0
- hexdag/core/pipeline_builder/pipeline_config.py +133 -0
- hexdag/core/pipeline_builder/py_tag.py +223 -0
- hexdag/core/pipeline_builder/tag_discovery.py +268 -0
- hexdag/core/pipeline_builder/yaml_builder.py +1196 -0
- hexdag/core/pipeline_builder/yaml_validator.py +569 -0
- hexdag/core/ports/__init__.py +65 -0
- hexdag/core/ports/api_call.py +133 -0
- hexdag/core/ports/database.py +489 -0
- hexdag/core/ports/embedding.py +215 -0
- hexdag/core/ports/executor.py +237 -0
- hexdag/core/ports/file_storage.py +117 -0
- hexdag/core/ports/healthcheck.py +87 -0
- hexdag/core/ports/llm.py +551 -0
- hexdag/core/ports/memory.py +70 -0
- hexdag/core/ports/observer_manager.py +130 -0
- hexdag/core/ports/secret.py +145 -0
- hexdag/core/ports/tool_router.py +94 -0
- hexdag/core/ports_builder.py +623 -0
- hexdag/core/protocols.py +273 -0
- hexdag/core/resolver.py +304 -0
- hexdag/core/schema/__init__.py +9 -0
- hexdag/core/schema/generator.py +742 -0
- hexdag/core/secrets.py +242 -0
- hexdag/core/types.py +413 -0
- hexdag/core/utils/async_warnings.py +206 -0
- hexdag/core/utils/schema_conversion.py +78 -0
- hexdag/core/utils/sql_validation.py +86 -0
- hexdag/core/validation/secure_json.py +148 -0
- hexdag/core/yaml_macro.py +517 -0
- hexdag/mcp_server.py +3120 -0
- hexdag/studio/__init__.py +10 -0
- hexdag/studio/build_ui.py +92 -0
- hexdag/studio/server/__init__.py +1 -0
- hexdag/studio/server/main.py +100 -0
- hexdag/studio/server/routes/__init__.py +9 -0
- hexdag/studio/server/routes/execute.py +208 -0
- hexdag/studio/server/routes/export.py +558 -0
- hexdag/studio/server/routes/files.py +207 -0
- hexdag/studio/server/routes/plugins.py +419 -0
- hexdag/studio/server/routes/validate.py +220 -0
- hexdag/studio/ui/index.html +13 -0
- hexdag/studio/ui/package-lock.json +2992 -0
- hexdag/studio/ui/package.json +31 -0
- hexdag/studio/ui/postcss.config.js +6 -0
- hexdag/studio/ui/public/hexdag.svg +5 -0
- hexdag/studio/ui/src/App.tsx +251 -0
- hexdag/studio/ui/src/components/Canvas.tsx +408 -0
- hexdag/studio/ui/src/components/ContextMenu.tsx +187 -0
- hexdag/studio/ui/src/components/FileBrowser.tsx +123 -0
- hexdag/studio/ui/src/components/Header.tsx +181 -0
- hexdag/studio/ui/src/components/HexdagNode.tsx +193 -0
- hexdag/studio/ui/src/components/NodeInspector.tsx +512 -0
- hexdag/studio/ui/src/components/NodePalette.tsx +262 -0
- hexdag/studio/ui/src/components/NodePortsSection.tsx +403 -0
- hexdag/studio/ui/src/components/PluginManager.tsx +347 -0
- hexdag/studio/ui/src/components/PortsEditor.tsx +481 -0
- hexdag/studio/ui/src/components/PythonEditor.tsx +195 -0
- hexdag/studio/ui/src/components/ValidationPanel.tsx +105 -0
- hexdag/studio/ui/src/components/YamlEditor.tsx +196 -0
- hexdag/studio/ui/src/components/index.ts +8 -0
- hexdag/studio/ui/src/index.css +92 -0
- hexdag/studio/ui/src/main.tsx +10 -0
- hexdag/studio/ui/src/types/index.ts +123 -0
- hexdag/studio/ui/src/vite-env.d.ts +1 -0
- hexdag/studio/ui/tailwind.config.js +29 -0
- hexdag/studio/ui/tsconfig.json +37 -0
- hexdag/studio/ui/tsconfig.node.json +13 -0
- hexdag/studio/ui/vite.config.ts +35 -0
- hexdag/visualization/__init__.py +69 -0
- hexdag/visualization/dag_visualizer.py +1020 -0
- hexdag-0.5.0.dev1.dist-info/METADATA +369 -0
- hexdag-0.5.0.dev1.dist-info/RECORD +261 -0
- hexdag-0.5.0.dev1.dist-info/WHEEL +4 -0
- hexdag-0.5.0.dev1.dist-info/entry_points.txt +4 -0
- hexdag-0.5.0.dev1.dist-info/licenses/LICENSE +190 -0
- hexdag_plugins/.gitignore +43 -0
- hexdag_plugins/README.md +73 -0
- hexdag_plugins/__init__.py +1 -0
- hexdag_plugins/azure/LICENSE +21 -0
- hexdag_plugins/azure/README.md +414 -0
- hexdag_plugins/azure/__init__.py +21 -0
- hexdag_plugins/azure/azure_blob_adapter.py +450 -0
- hexdag_plugins/azure/azure_cosmos_adapter.py +383 -0
- hexdag_plugins/azure/azure_keyvault_adapter.py +314 -0
- hexdag_plugins/azure/azure_openai_adapter.py +415 -0
- hexdag_plugins/azure/pyproject.toml +107 -0
- hexdag_plugins/azure/tests/__init__.py +1 -0
- hexdag_plugins/azure/tests/test_azure_blob_adapter.py +350 -0
- hexdag_plugins/azure/tests/test_azure_cosmos_adapter.py +323 -0
- hexdag_plugins/azure/tests/test_azure_keyvault_adapter.py +330 -0
- hexdag_plugins/azure/tests/test_azure_openai_adapter.py +329 -0
- hexdag_plugins/hexdag_etl/README.md +168 -0
- hexdag_plugins/hexdag_etl/__init__.py +53 -0
- hexdag_plugins/hexdag_etl/examples/01_simple_pandas_transform.py +270 -0
- hexdag_plugins/hexdag_etl/examples/02_simple_pandas_only.py +149 -0
- hexdag_plugins/hexdag_etl/examples/03_file_io_pipeline.py +109 -0
- hexdag_plugins/hexdag_etl/examples/test_pandas_transform.py +84 -0
- hexdag_plugins/hexdag_etl/hexdag.toml +25 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/__init__.py +48 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/__init__.py +13 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/api_extract.py +230 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/base_node_factory.py +181 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/file_io.py +415 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/outlook.py +492 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/pandas_transform.py +563 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/sql_extract_load.py +112 -0
- hexdag_plugins/hexdag_etl/pyproject.toml +82 -0
- hexdag_plugins/hexdag_etl/test_transform.py +54 -0
- hexdag_plugins/hexdag_etl/tests/test_plugin_integration.py +62 -0
- hexdag_plugins/mysql_adapter/LICENSE +21 -0
- hexdag_plugins/mysql_adapter/README.md +224 -0
- hexdag_plugins/mysql_adapter/__init__.py +6 -0
- hexdag_plugins/mysql_adapter/mysql_adapter.py +408 -0
- hexdag_plugins/mysql_adapter/pyproject.toml +93 -0
- hexdag_plugins/mysql_adapter/tests/test_mysql_adapter.py +259 -0
- hexdag_plugins/storage/README.md +184 -0
- hexdag_plugins/storage/__init__.py +19 -0
- hexdag_plugins/storage/file/__init__.py +5 -0
- hexdag_plugins/storage/file/local.py +325 -0
- hexdag_plugins/storage/ports/__init__.py +5 -0
- hexdag_plugins/storage/ports/vector_store.py +236 -0
- hexdag_plugins/storage/sql/__init__.py +7 -0
- hexdag_plugins/storage/sql/base.py +187 -0
- hexdag_plugins/storage/sql/mysql.py +27 -0
- hexdag_plugins/storage/sql/postgresql.py +27 -0
- hexdag_plugins/storage/tests/__init__.py +1 -0
- hexdag_plugins/storage/tests/test_local_file_storage.py +161 -0
- hexdag_plugins/storage/tests/test_sql_adapters.py +212 -0
- hexdag_plugins/storage/vector/__init__.py +7 -0
- hexdag_plugins/storage/vector/chromadb.py +223 -0
- hexdag_plugins/storage/vector/in_memory.py +285 -0
- hexdag_plugins/storage/vector/pgvector.py +502 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""Tests for Azure Key Vault adapter."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from unittest.mock import MagicMock, patch
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from hexdag_plugins.azure.azure_keyvault_adapter import AzureKeyVaultAdapter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def keyvault_adapter():
    """Build a Key Vault adapter that authenticates via managed identity.

    No cache options are passed, so the adapter's caching defaults apply.
    """
    adapter = AzureKeyVaultAdapter(
        vault_url="https://test-vault.vault.azure.net",
        use_managed_identity=True,
    )
    return adapter
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.fixture
def keyvault_adapter_service_principal():
    """Build a Key Vault adapter configured for service-principal authentication."""
    sp_credentials = {
        "tenant_id": "test-tenant-id",
        "client_id": "test-client-id",
        "client_secret": "test-client-secret",
    }
    return AzureKeyVaultAdapter(
        vault_url="https://test-vault.vault.azure.net",
        use_managed_identity=False,
        **sp_credentials,
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def keyvault_adapter_no_cache():
    """Build a managed-identity adapter whose secret cache is switched off."""
    adapter = AzureKeyVaultAdapter(
        vault_url="https://test-vault.vault.azure.net",
        cache_secrets=False,
        use_managed_identity=True,
    )
    return adapter
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.asyncio
async def test_adapter_initialization(keyvault_adapter):
    """Constructor arguments and defaults are stored on the adapter."""
    adapter = keyvault_adapter
    assert adapter.vault_url == "https://test-vault.vault.azure.net"
    assert adapter.use_managed_identity is True
    # Defaults when no cache options are given: caching on, TTL of 300.
    assert adapter.cache_secrets is True
    assert adapter.cache_ttl == 300
    # No SDK client exists until one is first requested.
    assert adapter._client is None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.mark.asyncio
async def test_adapter_initialization_service_principal(keyvault_adapter_service_principal):
    """Service-principal settings are stored on the adapter."""
    adapter = keyvault_adapter_service_principal
    assert adapter.use_managed_identity is False
    for attr_name, expected in (
        ("tenant_id", "test-tenant-id"),
        ("client_id", "test-client-id"),
    ):
        assert getattr(adapter, attr_name) == expected
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.mark.asyncio
async def test_get_secret_success(keyvault_adapter):
    """aget() returns the value of the secret fetched through the client."""
    fake_client = MagicMock()
    fake_client.get_secret.return_value = MagicMock(value="secret-value-123")

    with patch.object(keyvault_adapter, "_get_client", return_value=fake_client):
        value = await keyvault_adapter.aget("MY-SECRET")

    assert value == "secret-value-123"
    fake_client.get_secret.assert_called_once_with("MY-SECRET")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@pytest.mark.asyncio
async def test_get_secret_from_cache(keyvault_adapter):
    """A fresh cache entry is returned without asking the client."""
    # Seed the cache with an entry stamped "now", so it is still within the TTL.
    keyvault_adapter._cache["CACHED-SECRET"] = ("cached-value", time.time())
    fake_client = MagicMock()

    with patch.object(keyvault_adapter, "_get_client", return_value=fake_client):
        value = await keyvault_adapter.aget("CACHED-SECRET")

    assert value == "cached-value"
    # The API must not be hit when the cache can answer.
    fake_client.get_secret.assert_not_called()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@pytest.mark.asyncio
async def test_get_secret_expired_cache(keyvault_adapter):
    """Test that a stale cache entry is ignored and the secret is re-fetched."""
    # Age the entry 100 past the configured TTL, so the test keeps working
    # if the adapter's default TTL changes (with the default of 300 this is
    # the same 400-second offset used before).
    stale_timestamp = time.time() - (keyvault_adapter.cache_ttl + 100)
    keyvault_adapter._cache["EXPIRED-SECRET"] = ("old-value", stale_timestamp)

    mock_secret = MagicMock()
    mock_secret.value = "new-value"

    mock_client = MagicMock()
    mock_client.get_secret.return_value = mock_secret

    with patch.object(keyvault_adapter, "_get_client", return_value=mock_client):
        result = await keyvault_adapter.aget("EXPIRED-SECRET")

    assert result == "new-value"
    # Pin the requested name, not just the call count.
    mock_client.get_secret.assert_called_once_with("EXPIRED-SECRET")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.mark.asyncio
async def test_get_secret_no_cache(keyvault_adapter_no_cache):
    """Test that with caching disabled every aget() call goes to the client."""
    mock_secret = MagicMock()
    mock_secret.value = "secret-value"

    mock_client = MagicMock()
    mock_client.get_secret.return_value = mock_secret

    with patch.object(keyvault_adapter_no_cache, "_get_client", return_value=mock_client):
        # First call
        first = await keyvault_adapter_no_cache.aget("NO-CACHE-SECRET")
        # Second call should also hit the API
        second = await keyvault_adapter_no_cache.aget("NO-CACHE-SECRET")

    # Both calls must still return the secret's value, not just run.
    assert first == "secret-value"
    assert second == "secret-value"
    assert mock_client.get_secret.call_count == 2
    mock_client.get_secret.assert_called_with("NO-CACHE-SECRET")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@pytest.mark.asyncio
async def test_get_secret_not_found(keyvault_adapter):
    """A client lookup failure surfaces from aget() as a ValueError mentioning 'not found'."""
    failing_client = MagicMock()
    failing_client.get_secret.side_effect = Exception("SecretNotFound: Secret not found")

    with patch.object(keyvault_adapter, "_get_client", return_value=failing_client):
        with pytest.raises(ValueError, match="not found"):
            await keyvault_adapter.aget("MISSING-SECRET")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@pytest.mark.asyncio
async def test_get_secret_null_value(keyvault_adapter):
    """A secret whose value is None is rejected by aget() with a ValueError."""
    # Dotted keyword names configure the child mock's return value in one step.
    null_client = MagicMock(**{"get_secret.return_value": MagicMock(value=None)})

    with patch.object(keyvault_adapter, "_get_client", return_value=null_client):
        with pytest.raises(ValueError, match="has no value"):
            await keyvault_adapter.aget("NULL-SECRET")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@pytest.mark.asyncio
async def test_get_batch_success(keyvault_adapter):
    """aget_batch() returns one entry per requested secret, keyed by name."""
    mock_client = MagicMock()
    # Each lookup yields a secret whose value encodes the requested name.
    mock_client.get_secret.side_effect = lambda name: MagicMock(value=f"value-for-{name}")

    with patch.object(keyvault_adapter, "_get_client", return_value=mock_client):
        results = await keyvault_adapter.aget_batch(["SECRET1", "SECRET2"])

    assert len(results) == 2
    for secret_name in ("SECRET1", "SECRET2"):
        assert results[secret_name] == f"value-for-{secret_name}"
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@pytest.mark.asyncio
async def test_get_batch_partial_success(keyvault_adapter):
    """Test that batch retrieval skips missing secrets and keeps the rest intact."""

    def mock_get_secret(name):
        # Simulate one secret that does not exist in the vault.
        if name == "MISSING":
            raise Exception("SecretNotFound")
        mock = MagicMock()
        mock.value = f"value-for-{name}"
        return mock

    mock_client = MagicMock()
    mock_client.get_secret.side_effect = mock_get_secret

    with patch.object(keyvault_adapter, "_get_client", return_value=mock_client):
        results = await keyvault_adapter.aget_batch(["SECRET1", "MISSING", "SECRET2"])

    assert len(results) == 2
    assert "MISSING" not in results
    # The secrets that were found must still carry their own values.
    assert results["SECRET1"] == "value-for-SECRET1"
    assert results["SECRET2"] == "value-for-SECRET2"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@pytest.mark.asyncio
async def test_set_secret(keyvault_adapter):
    """Test that aset() writes through the client and caches the written value."""
    # MagicMock auto-creates `set_secret`; no explicit stub is needed.
    mock_client = MagicMock()

    with patch.object(keyvault_adapter, "_get_client", return_value=mock_client):
        await keyvault_adapter.aset("NEW-SECRET", "secret-value")

    mock_client.set_secret.assert_called_once_with("NEW-SECRET", "secret-value")
    # Verify it's also cached, and with the value that was written.
    # NOTE(review): assumes the first element of a cache entry is the value,
    # matching the (value, timestamp) tuples this module seeds the cache with.
    assert "NEW-SECRET" in keyvault_adapter._cache
    assert keyvault_adapter._cache["NEW-SECRET"][0] == "secret-value"
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@pytest.mark.asyncio
async def test_delete_secret(keyvault_adapter):
    """adelete() calls the client's begin_delete_secret and evicts the cache entry."""
    target = "TO-DELETE"
    keyvault_adapter._cache[target] = ("value", time.time())
    # begin_delete_secret is auto-created as a child MagicMock on first access.
    fake_client = MagicMock()

    with patch.object(keyvault_adapter, "_get_client", return_value=fake_client):
        await keyvault_adapter.adelete(target)

    fake_client.begin_delete_secret.assert_called_once_with(target)
    assert target not in keyvault_adapter._cache
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@pytest.mark.asyncio
async def test_list_secrets(keyvault_adapter):
    """alist() returns the names reported by list_properties_of_secrets, in order."""
    secret_properties = []
    for secret_name in ("SECRET1", "SECRET2"):
        props = MagicMock()
        # `name` must be assigned after construction: MagicMock(name=...) names
        # the mock itself rather than setting a `.name` attribute.
        props.name = secret_name
        secret_properties.append(props)

    fake_client = MagicMock()
    fake_client.list_properties_of_secrets.return_value = secret_properties

    with patch.object(keyvault_adapter, "_get_client", return_value=fake_client):
        names = await keyvault_adapter.alist()

    assert names == ["SECRET1", "SECRET2"]
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@pytest.mark.asyncio
async def test_clear_cache(keyvault_adapter):
    """clear_cache() drops every cached entry."""
    for index in (1, 2):
        keyvault_adapter._cache[f"SECRET{index}"] = (f"value{index}", time.time())

    keyvault_adapter.clear_cache()

    assert len(keyvault_adapter._cache) == 0
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@pytest.mark.asyncio
async def test_health_check_healthy(keyvault_adapter):
    """A successful listing call yields a healthy status with vault details."""
    fake_client = MagicMock(**{"list_properties_of_secrets.return_value": []})

    with patch.object(keyvault_adapter, "_get_client", return_value=fake_client):
        status = await keyvault_adapter.ahealth_check()

    assert status.status == "healthy"
    assert status.adapter_name == "AzureKeyVault"
    details = status.details
    assert details["vault_url"] == "https://test-vault.vault.azure.net"
    assert details["auth_method"] == "managed_identity"
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@pytest.mark.asyncio
async def test_health_check_service_principal(keyvault_adapter_service_principal):
    """The health report identifies service-principal authentication."""
    adapter = keyvault_adapter_service_principal
    fake_client = MagicMock(**{"list_properties_of_secrets.return_value": []})

    with patch.object(adapter, "_get_client", return_value=fake_client):
        status = await adapter.ahealth_check()

    assert status.details["auth_method"] == "service_principal"
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@pytest.mark.asyncio
async def test_health_check_unhealthy(keyvault_adapter):
    """A failure while obtaining the client is reported as an unhealthy status."""
    connection_error = Exception("Connection failed")

    with patch.object(keyvault_adapter, "_get_client", side_effect=connection_error):
        status = await keyvault_adapter.ahealth_check()

    assert status.status == "unhealthy"
    assert "error" in status.details
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@pytest.mark.asyncio
async def test_to_dict(keyvault_adapter):
    """Test serialization includes configuration but excludes credentials."""
    config = keyvault_adapter.to_dict()

    assert "vault_url" in config
    assert "cache_secrets" in config
    assert "client_secret" not in config  # Credentials excluded


@pytest.mark.asyncio
async def test_to_dict_service_principal_excludes_secret(keyvault_adapter_service_principal):
    """Test that serialization never leaks a configured client secret.

    The managed-identity fixture sets no client_secret at all, so only an
    adapter that actually holds a secret can prove it is stripped.
    """
    config = keyvault_adapter_service_principal.to_dict()

    assert "client_secret" not in config
    # Check the rendered form too, so the secret is not present under any
    # other key or nested structure.
    assert "test-client-secret" not in str(config)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@pytest.mark.asyncio
async def test_service_principal_requires_all_credentials():
    """Service-principal mode refuses to build a client with partial credentials."""
    # Only tenant_id is supplied; client_id and client_secret are omitted.
    partial_credentials = {"tenant_id": "tenant"}
    adapter = AzureKeyVaultAdapter(
        vault_url="https://test.vault.azure.net",
        use_managed_identity=False,
        **partial_credentials,
    )

    expected_message = "tenant_id, client_id, and client_secret"
    with pytest.raises(ValueError, match=expected_message):
        adapter._get_client()
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@pytest.mark.asyncio
async def test_client_lazy_initialization(keyvault_adapter):
    """Test that the SDK client is created on first _get_client() call and retained."""
    assert keyvault_adapter._client is None

    # Patch the SDK names inside the adapter module so no real credential or
    # SecretClient is constructed.
    with (
        patch("hexdag_plugins.azure.azure_keyvault_adapter.DefaultAzureCredential"),
        patch("hexdag_plugins.azure.azure_keyvault_adapter.SecretClient") as mock_secret_client_cls,
    ):
        keyvault_adapter._get_client()
        mock_secret_client_cls.assert_called_once()

    # NOTE(review): lazy initialization implies the created client is kept on
    # `_client` (it starts as None above). Confirm against the adapter.
    assert keyvault_adapter._client is not None
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""Tests for Azure OpenAI adapter."""
|
|
2
|
+
|
|
3
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
from hexdag.core.ports.llm import Message
|
|
7
|
+
|
|
8
|
+
from hexdag_plugins.azure.azure_openai_adapter import AzureOpenAIAdapter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def azure_adapter():
    """Build an Azure OpenAI adapter pointed at a fake resource and deployment."""
    settings = {
        "api_key": "test-key",
        "resource_name": "test-resource",
        "deployment_id": "gpt-4",
        "api_version": "2024-02-15-preview",
        "temperature": 0.7,
    }
    return AzureOpenAIAdapter(**settings)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.mark.asyncio
async def test_adapter_initialization(azure_adapter):
    """Constructor arguments, and the api_base derived from them, are stored on the adapter."""
    expected_attrs = {
        "api_key": "test-key",
        "resource_name": "test-resource",
        "deployment_id": "gpt-4",
        "api_version": "2024-02-15-preview",
        "temperature": 0.7,
        # Built from resource_name; it was not passed explicitly.
        "api_base": "https://test-resource.openai.azure.com",
    }
    for attr_name, expected in expected_attrs.items():
        assert getattr(azure_adapter, attr_name) == expected, attr_name
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.mark.asyncio
async def test_aresponse_success(azure_adapter):
    """aresponse() returns the content of the first completion choice."""
    # Fake a chat-completion response with a single choice.
    choice = MagicMock()
    choice.message.content = "Hello from Azure!"
    completion = MagicMock()
    completion.choices = [choice]

    fake_client = AsyncMock()
    fake_client.chat.completions.create = AsyncMock(return_value=completion)

    conversation = [Message(role="user", content="Hello")]
    with patch.object(azure_adapter, "_get_client", return_value=fake_client):
        reply = await azure_adapter.aresponse(conversation)

    assert reply == "Hello from Azure!"
    fake_client.chat.completions.create.assert_called_once()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.mark.asyncio
async def test_aresponse_with_tools(azure_adapter):
    """aresponse_with_tools() maps a tool-calling completion into an LLMResponse."""
    from hexdag.core.ports.llm import LLMResponse

    # A single tool call requested by the model.
    tool_call = MagicMock()
    tool_call.id = "call_123"
    tool_call.function.name = "search"
    # NOTE(review): the OpenAI API delivers `function.arguments` as a JSON
    # string; a dict is used here. Confirm which form the adapter must accept.
    tool_call.function.arguments = {"query": "test"}

    choice = MagicMock()
    choice.message.content = "Let me search"
    choice.message.tool_calls = [tool_call]
    choice.finish_reason = "tool_calls"
    completion = MagicMock()
    completion.choices = [choice]

    fake_client = AsyncMock()
    fake_client.chat.completions.create = AsyncMock(return_value=completion)

    search_tool = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search",
            "parameters": {"type": "object", "properties": {}},
        },
    }
    conversation = [Message(role="user", content="Search for cats")]

    with patch.object(azure_adapter, "_get_client", return_value=fake_client):
        response = await azure_adapter.aresponse_with_tools(conversation, [search_tool])

    assert isinstance(response, LLMResponse)
    assert response.content == "Let me search"
    assert len(response.tool_calls) == 1
    assert response.tool_calls[0].name == "search"
    assert response.finish_reason == "tool_calls"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@pytest.mark.asyncio
async def test_aresponse_error_handling(azure_adapter):
    """aresponse() returns None instead of raising when the API call fails."""
    failing_client = AsyncMock()
    failing_client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))

    conversation = [Message(role="user", content="Hello")]
    with patch.object(azure_adapter, "_get_client", return_value=failing_client):
        reply = await azure_adapter.aresponse(conversation)

    assert reply is None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@pytest.mark.asyncio
async def test_health_check_healthy(azure_adapter):
    """Test health check returns healthy status with deployment details."""
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "OK"

    mock_client = AsyncMock()
    mock_client.chat.completions.create = AsyncMock(return_value=mock_response)

    with patch.object(azure_adapter, "_get_client", return_value=mock_client):
        status = await azure_adapter.ahealth_check()

    assert status.status == "healthy"
    assert status.adapter_name == "AzureOpenAI[gpt-4]"
    # The mocked call returns immediately, so the measured latency may be 0
    # if the adapter's timer has coarse resolution. Requiring "> 0" made this
    # test flaky, so only require a non-negative measurement.
    assert status.latency_ms is not None
    assert status.latency_ms >= 0
    assert status.details["resource"] == "test-resource"
    assert status.details["deployment"] == "gpt-4"
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@pytest.mark.asyncio
async def test_health_check_unhealthy(azure_adapter):
    """An exception from the completion probe yields an unhealthy status with error details."""
    broken_client = AsyncMock()
    broken_client.chat.completions.create = AsyncMock(
        side_effect=Exception("Connection failed")
    )

    with patch.object(azure_adapter, "_get_client", return_value=broken_client):
        health = await azure_adapter.ahealth_check()

    assert health.status == "unhealthy"
    assert "error" in health.details
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@pytest.mark.asyncio
async def test_client_lazy_initialization(azure_adapter):
    """No SDK client exists until _get_client() is called; the first call builds and caches one."""
    assert azure_adapter._client is None

    # Patch the name inside the adapter module, where it was imported.
    sdk_path = "hexdag_plugins.azure.azure_openai_adapter.AsyncAzureOpenAI"
    with patch(sdk_path) as azure_cls:
        sentinel_client = MagicMock()
        azure_cls.return_value = sentinel_client

        created = azure_adapter._get_client()

        assert created is sentinel_client
        assert azure_adapter._client is sentinel_client

        expected_kwargs = {
            "api_key": "test-key",
            "api_version": "2024-02-15-preview",
            "azure_endpoint": "https://test-resource.openai.azure.com",
            "timeout": 30.0,
        }
        azure_cls.assert_called_once_with(**expected_kwargs)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@pytest.mark.asyncio
async def test_client_reuse():
    """Repeated _get_client() calls hand back the single cached client instance."""
    lazy_adapter = AzureOpenAIAdapter(
        api_key="test-key",
        resource_name="test-resource",
        deployment_id="gpt-4",
    )

    # Stub the SDK constructor so instantiations can be counted.
    with patch("hexdag_plugins.azure.azure_openai_adapter.AsyncAzureOpenAI") as azure_cls:
        stub_client = MagicMock()
        azure_cls.return_value = stub_client

        first = lazy_adapter._get_client()
        assert azure_cls.call_count == 1

        second = lazy_adapter._get_client()
        assert azure_cls.call_count == 1  # constructor was not invoked a second time

        assert first is second
        assert first is stub_client
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@pytest.mark.asyncio
async def test_custom_timeout(azure_adapter):
    """Test custom timeout is respected.

    Checks both that the value is stored on the adapter and that it is
    forwarded as ``timeout=`` to the ``AsyncAzureOpenAI`` constructor when the
    client is lazily created. The ``azure_adapter`` fixture is not used; a
    dedicated adapter with a non-default timeout is built instead.
    """
    custom_adapter = AzureOpenAIAdapter(
        api_key="test-key",
        resource_name="test-resource",
        deployment_id="gpt-4",
        timeout=60.0,
    )

    assert custom_adapter.timeout == 60.0

    # Storing the attribute is not enough; it must reach the SDK client.
    with patch("hexdag_plugins.azure.azure_openai_adapter.AsyncAzureOpenAI") as mock_azure:
        custom_adapter._get_client()

    mock_azure.assert_called_once()
    assert mock_azure.call_args.kwargs["timeout"] == 60.0
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@pytest.mark.asyncio
async def test_max_tokens_parameter(azure_adapter):
    """A max_tokens value set on the adapter is forwarded to chat.completions.create."""
    azure_adapter.max_tokens = 100

    choice = MagicMock()
    choice.message.content = "Short response"
    completion = MagicMock()
    completion.choices = [choice]

    client = AsyncMock()
    client.chat.completions.create = AsyncMock(return_value=completion)

    with patch.object(azure_adapter, "_get_client", return_value=client):
        await azure_adapter.aresponse([Message(role="user", content="Hello")])

    sent_kwargs = client.chat.completions.create.call_args.kwargs
    assert sent_kwargs["max_tokens"] == 100
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# ========== Embedding Tests ==========
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@pytest.fixture
def azure_embedding_adapter():
    """Adapter configured with a chat deployment plus an embedding deployment (1536 dims)."""
    adapter = AzureOpenAIAdapter(
        api_key="test-key",
        resource_name="test-resource",
        deployment_id="gpt-4",
        embedding_deployment_id="text-embedding-3-small",
        embedding_dimensions=1536,
    )
    return adapter
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@pytest.mark.asyncio
async def test_aembed_success(azure_embedding_adapter):
    """aembed() returns the vector carried in the mocked embeddings response."""
    item = MagicMock()
    item.embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
    item.index = 0

    reply = MagicMock()
    reply.data = [item]

    client = AsyncMock()
    client.embeddings.create = AsyncMock(return_value=reply)

    with patch.object(azure_embedding_adapter, "_get_client", return_value=client):
        vector = await azure_embedding_adapter.aembed("Hello, world!")

    assert vector == [0.1, 0.2, 0.3, 0.4, 0.5]
    client.embeddings.create.assert_called_once()
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@pytest.mark.asyncio
async def test_aembed_without_deployment_id():
    """aembed() on an adapter without an embedding deployment raises ValueError."""
    chat_only_adapter = AzureOpenAIAdapter(
        api_key="test-key",
        resource_name="test-resource",
        deployment_id="gpt-4",
    )  # embedding_deployment_id deliberately omitted

    with pytest.raises(ValueError, match="embedding_deployment_id must be set"):
        await chat_only_adapter.aembed("test")
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@pytest.mark.asyncio
async def test_aembed_batch_success(azure_embedding_adapter):
    """aembed_batch() returns one vector per input text, matching the mocked data items."""
    items = []
    for position, values in enumerate(([0.1, 0.2, 0.3], [0.4, 0.5, 0.6])):
        entry = MagicMock()
        entry.embedding = values
        entry.index = position
        items.append(entry)

    reply = MagicMock()
    reply.data = items

    client = AsyncMock()
    client.embeddings.create = AsyncMock(return_value=reply)

    with patch.object(azure_embedding_adapter, "_get_client", return_value=client):
        vectors = await azure_embedding_adapter.aembed_batch(["Hello", "World"])

    assert len(vectors) == 2
    first, second = vectors[0], vectors[1]
    assert first == [0.1, 0.2, 0.3]
    assert second == [0.4, 0.5, 0.6]
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@pytest.mark.asyncio
async def test_aembed_with_dimensions(azure_embedding_adapter):
    """The configured embedding_dimensions (1536) reaches embeddings.create as ``dimensions``."""
    item = MagicMock()
    item.embedding = [0.1] * 1536
    item.index = 0

    client = AsyncMock()
    client.embeddings.create = AsyncMock(return_value=MagicMock(data=[item]))

    with patch.object(azure_embedding_adapter, "_get_client", return_value=client):
        await azure_embedding_adapter.aembed("test")

    assert client.embeddings.create.call_args.kwargs["dimensions"] == 1536
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@pytest.mark.asyncio
async def test_aembed_image_not_implemented(azure_embedding_adapter):
    """Single-image embedding is unsupported and raises NotImplementedError."""
    unsupported_msg = "Azure OpenAI does not support image embeddings"
    with pytest.raises(NotImplementedError, match=unsupported_msg):
        await azure_embedding_adapter.aembed_image("image.jpg")
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@pytest.mark.asyncio
async def test_aembed_image_batch_not_implemented(azure_embedding_adapter):
    """Batch image embedding is unsupported and raises NotImplementedError."""
    image_paths = ["image1.jpg", "image2.jpg"]
    unsupported_msg = "Azure OpenAI does not support image embeddings"
    with pytest.raises(NotImplementedError, match=unsupported_msg):
        await azure_embedding_adapter.aembed_image_batch(image_paths)
|