embed-client 2.0.0.0__py3-none-any.whl → 3.1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,12 @@ Async client for Embedding Service API (OpenAPI 3.0.2)
5
5
  - English docstrings and examples
6
6
  - Ready for PyPi
7
7
  - Supports new API format with body, embedding, and chunks
8
+ - Supports all authentication methods (API Key, JWT, Basic Auth, Certificate)
9
+ - Integrates with mcp_security_framework
10
+ - Supports all security modes (HTTP, HTTPS, mTLS)
11
+
12
+ Author: Vasiliy Zdanovskiy
13
+ email: vasilyvz@gmail.com
8
14
  """
9
15
 
10
16
  from typing import Any, Dict, List, Optional, Union
@@ -13,6 +19,12 @@ import asyncio
13
19
  import os
14
20
  import json
15
21
  import logging
22
+ from pathlib import Path
23
+
24
+ # Import authentication, configuration, and SSL systems
25
+ from .auth import ClientAuthManager, create_auth_manager
26
+ from .config import ClientConfig
27
+ from .ssl_manager import ClientSSLManager, create_ssl_manager
16
28
 
17
29
  class EmbeddingServiceError(Exception):
18
30
  """Base exception for EmbeddingServiceAsyncClient."""
@@ -50,17 +62,60 @@ class EmbeddingServiceAsyncClient:
50
62
  - Old format: {"result": {"success": true, "data": {"embeddings": [...]}}}
51
63
  - New format: {"result": {"success": true, "data": {"embeddings": [...], "results": [{"body": "text", "embedding": [...], "tokens": [...], "bm25_tokens": [...]}]}}}
52
64
 
65
+ Supports all authentication methods and security modes:
66
+ - API Key authentication
67
+ - JWT token authentication
68
+ - Basic authentication
69
+ - Certificate authentication (mTLS)
70
+ - HTTP, HTTPS, and mTLS security modes
71
+
53
72
  Args:
54
- base_url (str): Base URL of the embedding service (e.g., "http://localhost").
55
- port (int): Port of the embedding service (e.g., 8001).
73
+ base_url (str, optional): Base URL of the embedding service (e.g., "http://localhost").
74
+ port (int, optional): Port of the embedding service (e.g., 8001).
56
75
  timeout (float): Request timeout in seconds (default: 30).
76
+ config (ClientConfig, optional): Configuration object with authentication and SSL settings.
77
+ config_dict (dict, optional): Configuration dictionary with authentication and SSL settings.
78
+ auth_manager (ClientAuthManager, optional): Authentication manager instance.
79
+
57
80
  Raises:
58
81
  EmbeddingServiceConfigError: If base_url or port is invalid.
59
82
  """
60
- def __init__(self, base_url: Optional[str] = None, port: Optional[int] = None, timeout: float = 30.0):
61
- # Validate and set base_url
83
+ def __init__(self,
84
+ base_url: Optional[str] = None,
85
+ port: Optional[int] = None,
86
+ timeout: float = 30.0,
87
+ config: Optional[ClientConfig] = None,
88
+ config_dict: Optional[Dict[str, Any]] = None,
89
+ auth_manager: Optional[ClientAuthManager] = None):
90
+ # Initialize configuration
91
+ self.config = config
92
+ self.config_dict = config_dict
93
+ self.auth_manager = auth_manager
94
+
95
+ # If config is provided, use it to set base_url and port
96
+ if config:
97
+ self.base_url = config.get("server.host", base_url or os.getenv("EMBEDDING_SERVICE_BASE_URL", "http://localhost"))
98
+ self.port = config.get("server.port", port or int(os.getenv("EMBEDDING_SERVICE_PORT", "8001")))
99
+ self.timeout = config.get("client.timeout", timeout)
100
+ elif config_dict:
101
+ self.base_url = config_dict.get("server", {}).get("host", base_url or os.getenv("EMBEDDING_SERVICE_BASE_URL", "http://localhost"))
102
+ self.port = config_dict.get("server", {}).get("port", port or int(os.getenv("EMBEDDING_SERVICE_PORT", "8001")))
103
+ self.timeout = config_dict.get("client", {}).get("timeout", timeout)
104
+ else:
105
+ # Use provided parameters or environment variables
106
+ try:
107
+ self.base_url = base_url or os.getenv("EMBEDDING_SERVICE_BASE_URL", "http://localhost")
108
+ except (TypeError, AttributeError) as e:
109
+ raise EmbeddingServiceConfigError(f"Invalid base_url configuration: {e}") from e
110
+
111
+ try:
112
+ self.port = port or int(os.getenv("EMBEDDING_SERVICE_PORT", "8001"))
113
+ except (ValueError, TypeError) as e:
114
+ raise EmbeddingServiceConfigError(f"Invalid port configuration: {e}") from e
115
+ self.timeout = timeout
116
+
117
+ # Validate base_url
62
118
  try:
63
- self.base_url = base_url or os.getenv("EMBEDDING_SERVICE_BASE_URL", "http://localhost")
64
119
  if not self.base_url:
65
120
  raise EmbeddingServiceConfigError("base_url must be provided.")
66
121
  if not isinstance(self.base_url, str):
@@ -72,10 +127,8 @@ class EmbeddingServiceAsyncClient:
72
127
  except (TypeError, AttributeError) as e:
73
128
  raise EmbeddingServiceConfigError(f"Invalid base_url configuration: {e}") from e
74
129
 
75
- # Validate and set port
130
+ # Validate port
76
131
  try:
77
- port_env = os.getenv("EMBEDDING_SERVICE_PORT", "8001")
78
- self.port = port if port is not None else int(port_env)
79
132
  if self.port is None:
80
133
  raise EmbeddingServiceConfigError("port must be provided.")
81
134
  if not isinstance(self.port, int) or self.port <= 0 or self.port > 65535:
@@ -85,12 +138,23 @@ class EmbeddingServiceAsyncClient:
85
138
 
86
139
  # Validate timeout
87
140
  try:
88
- self.timeout = float(timeout)
141
+ self.timeout = float(self.timeout)
89
142
  if self.timeout <= 0:
90
143
  raise EmbeddingServiceConfigError("timeout must be positive.")
91
144
  except (ValueError, TypeError) as e:
92
145
  raise EmbeddingServiceConfigError(f"Invalid timeout configuration: {e}") from e
93
146
 
147
+ # Initialize authentication manager if not provided
148
+ if not self.auth_manager and (self.config or self.config_dict):
149
+ config_data = self.config_dict if self.config_dict else self.config.get_all()
150
+ self.auth_manager = create_auth_manager(config_data)
151
+
152
+ # Initialize SSL manager
153
+ self.ssl_manager = None
154
+ if self.config or self.config_dict:
155
+ config_data = self.config_dict if self.config_dict else self.config.get_all()
156
+ self.ssl_manager = create_ssl_manager(config_data)
157
+
94
158
  self._session: Optional[aiohttp.ClientSession] = None
95
159
 
96
160
  def _make_url(self, path: str, base_url: Optional[str] = None, port: Optional[int] = None) -> str:
@@ -294,7 +358,13 @@ class EmbeddingServiceAsyncClient:
294
358
  try:
295
359
  # Create session with timeout configuration
296
360
  timeout = aiohttp.ClientTimeout(total=self.timeout)
297
- self._session = aiohttp.ClientSession(timeout=timeout)
361
+
362
+ # Create SSL connector if SSL manager is available
363
+ connector = None
364
+ if self.ssl_manager:
365
+ connector = self.ssl_manager.create_connector()
366
+
367
+ self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
298
368
  return self
299
369
  except Exception as e:
300
370
  raise EmbeddingServiceError(f"Failed to create HTTP session: {e}") from e
@@ -344,8 +414,9 @@ class EmbeddingServiceAsyncClient:
344
414
  dict: Health status and model info.
345
415
  """
346
416
  url = self._make_url("/health", base_url, port)
417
+ headers = self.get_auth_headers()
347
418
  try:
348
- async with self._session.get(url, timeout=self.timeout) as resp:
419
+ async with self._session.get(url, headers=headers, timeout=self.timeout) as resp:
349
420
  await self._raise_for_status(resp)
350
421
  try:
351
422
  data = await resp.json()
@@ -387,8 +458,9 @@ class EmbeddingServiceAsyncClient:
387
458
  dict: OpenAPI schema.
388
459
  """
389
460
  url = self._make_url("/openapi.json", base_url, port)
461
+ headers = self.get_auth_headers()
390
462
  try:
391
- async with self._session.get(url, timeout=self.timeout) as resp:
463
+ async with self._session.get(url, headers=headers, timeout=self.timeout) as resp:
392
464
  await self._raise_for_status(resp)
393
465
  try:
394
466
  data = await resp.json()
@@ -430,8 +502,9 @@ class EmbeddingServiceAsyncClient:
430
502
  dict: List of commands and their descriptions.
431
503
  """
432
504
  url = self._make_url("/api/commands", base_url, port)
505
+ headers = self.get_auth_headers()
433
506
  try:
434
- async with self._session.get(url, timeout=self.timeout) as resp:
507
+ async with self._session.get(url, headers=headers, timeout=self.timeout) as resp:
435
508
  await self._raise_for_status(resp)
436
509
  try:
437
510
  data = await resp.json()
@@ -533,12 +606,13 @@ class EmbeddingServiceAsyncClient:
533
606
 
534
607
  logger = logging.getLogger('EmbeddingServiceAsyncClient.cmd')
535
608
  url = self._make_url("/cmd", base_url, port)
609
+ headers = self.get_auth_headers()
536
610
  payload = {"command": command}
537
611
  if params is not None:
538
612
  payload["params"] = params
539
- logger.info(f"Sending embedding command: url={url}, payload={payload}")
613
+ logger.info(f"Sending embedding command: url={url}, payload={payload}, headers={headers}")
540
614
  try:
541
- async with self._session.post(url, json=payload, timeout=self.timeout) as resp:
615
+ async with self._session.post(url, json=payload, headers=headers, timeout=self.timeout) as resp:
542
616
  logger.info(f"Embedding service HTTP status: {resp.status}")
543
617
  await self._raise_for_status(resp)
544
618
  try:
@@ -604,4 +678,290 @@ class EmbeddingServiceAsyncClient:
604
678
  finally:
605
679
  self._session = None
606
680
 
607
- # TODO: Add methods for /cmd, /api/commands, etc.
681
+ # TODO: Add methods for /cmd, /api/commands, etc.
682
+
683
@classmethod
def from_config(cls, config: ClientConfig) -> "EmbeddingServiceAsyncClient":
    """
    Build a client whose settings come from a ``ClientConfig`` object.

    Thin alternate constructor: ``config`` is forwarded as ``cls(config=config)``
    and every other constructor argument keeps its default.

    Args:
        config: ClientConfig object with authentication and SSL settings

    Returns:
        EmbeddingServiceAsyncClient instance configured with the provided config
    """
    client = cls(config=config)
    return client
695
+
696
@classmethod
def from_config_dict(cls, config_dict: Dict[str, Any]) -> "EmbeddingServiceAsyncClient":
    """
    Build a client from a plain configuration dictionary.

    Thin alternate constructor: ``config_dict`` is forwarded as
    ``cls(config_dict=config_dict)``; the dictionary is passed through as-is
    (not copied).

    Args:
        config_dict: Configuration dictionary with authentication and SSL settings

    Returns:
        EmbeddingServiceAsyncClient instance configured with the provided config
    """
    instance = cls(config_dict=config_dict)
    return instance
708
+
709
@classmethod
def from_config_file(cls, config_path: Union[str, Path]) -> "EmbeddingServiceAsyncClient":
    """
    Load a configuration file and build a client from it.

    Parsing is delegated entirely to ``ClientConfig.load_config``; the
    resulting object is passed as ``cls(config=...)``.

    Args:
        config_path: Path to the configuration file. The supported formats are
            whatever ``ClientConfig.load_config`` accepts (documented upstream
            as JSON or YAML — TODO confirm against ClientConfig).

    Returns:
        EmbeddingServiceAsyncClient instance configured with the provided config
    """
    return cls(config=ClientConfig.load_config(config_path))
722
+
723
@classmethod
def with_auth(cls,
              base_url: str,
              port: int,
              auth_method: str,
              **kwargs) -> "EmbeddingServiceAsyncClient":
    """
    Create client with authentication configuration.

    Builds a ``config_dict`` with ``server``, ``client``, ``auth`` and
    (optionally) ``ssl`` sections, then returns ``cls(config_dict=...)``.

    Args:
        base_url: Base URL of the embedding service. If ``ssl_enabled`` is not
            given, SSL is turned on when the URL scheme is ``https``
            (compared case-insensitively).
        port: Port of the embedding service
        auth_method: Authentication method ("api_key", "jwt", "basic", "certificate")
        **kwargs: Method-specific credentials plus optional settings:

            - api_key: ``api_keys`` (mapping, copied) or ``api_key``
              (stored under the ``"user"`` id).
            - jwt: ``secret``, ``username``, ``password`` (required);
              ``expiry_hours`` (default 24).
            - basic: ``username``, ``password`` (required).
            - certificate: ``cert_file``, ``key_file`` (required).
            - ``timeout``: client timeout in seconds (default 30.0).
            - SSL: ``ssl_enabled``, ``ca_cert_file``, ``cert_file``,
              ``key_file``, ``verify_mode`` (default "CERT_REQUIRED"),
              ``check_hostname`` (default True), ``check_expiry``
              (default True). An ``ssl`` section is added when SSL is
              enabled or any of ``ssl_enabled``/``ca_cert_file``/
              ``cert_file``/``key_file`` is passed.

    Returns:
        EmbeddingServiceAsyncClient instance with authentication configured

    Raises:
        ValueError: If ``auth_method`` is unsupported or a credential it
            requires is missing.

    Examples:
        # API Key authentication
        client = EmbeddingServiceAsyncClient.with_auth(
            "http://localhost", 8001, "api_key",
            api_keys={"user": "api_key_123"}
        )

        # JWT authentication
        client = EmbeddingServiceAsyncClient.with_auth(
            "http://localhost", 8001, "jwt",
            secret="secret", username="user", password="pass"
        )

        # Basic authentication
        client = EmbeddingServiceAsyncClient.with_auth(
            "http://localhost", 8001, "basic",
            username="user", password="pass"
        )

        # Certificate authentication
        client = EmbeddingServiceAsyncClient.with_auth(
            "https://localhost", 9443, "certificate",
            cert_file="certs/client.crt", key_file="keys/client.key"
        )
    """
    # Credentials each method requires, in the order they are checked; the
    # first missing one is reported.
    required_params = {
        "jwt": ("secret", "username", "password"),
        "basic": ("username", "password"),
        "certificate": ("cert_file", "key_file"),
    }

    config_dict: Dict[str, Any] = {
        "server": {"host": base_url, "port": port},
        "client": {"timeout": kwargs.get("timeout", 30.0)},
        "auth": {"method": auth_method},
    }
    auth_section = config_dict["auth"]

    if auth_method == "api_key":
        if "api_keys" in kwargs:
            # Copy so later edits to the caller's mapping don't alter this client's config.
            auth_section["api_keys"] = dict(kwargs["api_keys"])
        elif "api_key" in kwargs:
            auth_section["api_keys"] = {"user": kwargs["api_key"]}
        else:
            raise ValueError("api_keys or api_key parameter required for api_key authentication")
    elif auth_method in required_params:
        names = required_params[auth_method]
        for param in names:
            if param not in kwargs:
                raise ValueError(f"{param} parameter required for {auth_method} authentication")
        method_section = {param: kwargs[param] for param in names}
        if auth_method == "jwt":
            method_section["expiry_hours"] = kwargs.get("expiry_hours", 24)
        auth_section[auth_method] = method_section
    else:
        raise ValueError(f"Unsupported authentication method: {auth_method}")

    # SSL defaults to on for https URLs; URI schemes are case-insensitive.
    ssl_enabled = kwargs.get("ssl_enabled")
    if ssl_enabled is None:
        ssl_enabled = base_url.lower().startswith("https://")

    ssl_file_keys = ("ca_cert_file", "cert_file", "key_file")
    if ssl_enabled or "ssl_enabled" in kwargs or any(key in kwargs for key in ssl_file_keys):
        ssl_section = {
            "enabled": ssl_enabled,
            "verify_mode": kwargs.get("verify_mode", "CERT_REQUIRED"),
            "check_hostname": kwargs.get("check_hostname", True),
            "check_expiry": kwargs.get("check_expiry", True),
        }
        for key in ssl_file_keys:
            if key in kwargs:
                ssl_section[key] = kwargs[key]
        config_dict["ssl"] = ssl_section

    return cls(config_dict=config_dict)
850
+
851
def get_auth_headers(self) -> Dict[str, str]:
    """
    Get authentication headers for requests.

    Credentials are read from the ``auth`` section of ``self.config`` when it
    is set, otherwise from ``self.config_dict``. Header construction is
    delegated to ``self.auth_manager.get_auth_headers``.

    * ``api_key``: the first value in ``auth.api_keys`` is sent.
    * ``jwt``: a fresh token is minted on every call via
      ``auth_manager.create_jwt_token(username, ["user"])``. The password is
      only checked for presence.
    * ``basic``: ``auth.basic.username`` / ``auth.basic.password``.
    * Any other method (e.g. ``certificate``) produces no headers here.

    Returns:
        Dictionary of authentication headers. It is empty when there is no auth
        manager, the method is ``"none"``, or the needed credentials are
        missing (including ``None``/empty config sections).
    """
    manager = self.auth_manager
    if not manager:
        return {}

    auth_method = manager.get_auth_method()
    if auth_method == "none":
        return {}

    # self.config wins over self.config_dict when both are set.
    if self.config:
        auth_config = self.config.get("auth", {})
    elif self.config_dict:
        auth_config = self.config_dict.get("auth", {})
    else:
        auth_config = {}
    # Tolerate explicit nulls in the configuration (e.g. "auth": null).
    auth_config = auth_config or {}

    if auth_method == "api_key":
        api_keys = auth_config.get("api_keys") or {}
        if api_keys:
            # Only one key can be sent; use the first configured entry.
            first_key = next(iter(api_keys.values()))
            return manager.get_auth_headers("api_key", api_key=first_key)

    elif auth_method == "jwt":
        jwt_config = auth_config.get("jwt") or {}
        username = jwt_config.get("username")
        if username and jwt_config.get("password"):
            token = manager.create_jwt_token(username, ["user"])
            return manager.get_auth_headers("jwt", token=token)

    elif auth_method == "basic":
        basic_config = auth_config.get("basic") or {}
        username = basic_config.get("username")
        password = basic_config.get("password")
        if username and password:
            return manager.get_auth_headers("basic", username=username, password=password)

    return {}
893
+
894
def is_authenticated(self) -> bool:
    """
    Report whether the client has authentication switched on.

    Returns:
        False when ``self.auth_manager`` is None; otherwise whatever
        ``self.auth_manager.is_auth_enabled()`` returns.
    """
    manager = self.auth_manager
    if manager is None:
        return False
    return manager.is_auth_enabled()
902
+
903
def get_auth_method(self) -> str:
    """
    Name of the active authentication method.

    Returns:
        ``"none"`` when no auth manager is attached (falsy ``auth_manager``);
        otherwise the value of ``self.auth_manager.get_auth_method()``.
    """
    manager = self.auth_manager
    return manager.get_auth_method() if manager else "none"
913
+
914
def is_ssl_enabled(self) -> bool:
    """
    Report whether SSL/TLS is enabled for this client.

    Returns:
        False when no SSL manager was created; otherwise the result of
        ``self.ssl_manager.is_ssl_enabled()``.
    """
    if self.ssl_manager:
        return self.ssl_manager.is_ssl_enabled()
    return False
924
+
925
def is_mtls_enabled(self) -> bool:
    """
    Report whether mutual TLS (client-certificate TLS) is enabled.

    Returns:
        False when no SSL manager was created; otherwise the result of
        ``self.ssl_manager.is_mtls_enabled()``.
    """
    manager = self.ssl_manager
    return manager.is_mtls_enabled() if manager else False
935
+
936
def get_ssl_config(self) -> Dict[str, Any]:
    """
    Return the SSL configuration reported by the SSL manager.

    Returns:
        A fresh empty dict when no SSL manager was created; otherwise the
        value of ``self.ssl_manager.get_ssl_config()``.
    """
    if self.ssl_manager:
        return self.ssl_manager.get_ssl_config()
    return {}
946
+
947
def validate_ssl_config(self) -> List[str]:
    """
    Validate the SSL configuration via the SSL manager.

    Returns:
        A list of validation error messages. It is empty (no checks run) when
        no SSL manager was created; otherwise it is whatever
        ``self.ssl_manager.validate_ssl_config()`` returns.
    """
    manager = self.ssl_manager
    return manager.validate_ssl_config() if manager else []
957
+
958
def get_supported_ssl_protocols(self) -> List[str]:
    """
    List the SSL/TLS protocol names the SSL manager supports.

    Returns:
        An empty list when no SSL manager was created; otherwise the value
        of ``self.ssl_manager.get_supported_protocols()``.
    """
    if self.ssl_manager:
        return self.ssl_manager.get_supported_protocols()
    return []