otdf-python 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
otdf_python/kas_client.py CHANGED
@@ -12,7 +12,7 @@ from dataclasses import dataclass
12
12
 
13
13
  import jwt
14
14
 
15
- from .asym_decryption import AsymDecryption
15
+ from .asym_crypto import AsymDecryption
16
16
  from .crypto_utils import CryptoUtils
17
17
  from .kas_connect_rpc_client import KASConnectRPCClient
18
18
  from .kas_key_cache import KASKeyCache
@@ -25,6 +25,7 @@ class KeyAccess:
25
25
  url: str
26
26
  wrapped_key: str
27
27
  ephemeral_public_key: str | None = None
28
+ header: bytes | None = None # For NanoTDF: entire header including ephemeral key
28
29
 
29
30
 
30
31
  class KASClient:
@@ -139,16 +140,16 @@ class KASClient:
139
140
  except Exception as e:
140
141
  raise SDKException("error creating KAS address", e)
141
142
 
142
- def _create_signed_request_jwt(self, policy_json, client_public_key, key_access): # noqa: C901
143
+ def _get_wrapped_key_base64(self, key_access):
143
144
  """
144
- Create a signed JWT for the rewrap request.
145
- The JWT is signed with the DPoP private key.
146
- """
147
- # Handle both ManifestKeyAccess (new camelCase and old snake_case) and simple KeyAccess (for tests)
148
- # TODO: This can probably be simplified to only camelCase
145
+ Extract and normalize the wrapped key to base64-encoded string.
146
+
147
+ Args:
148
+ key_access: KeyAccess object
149
149
 
150
- # Ensure wrappedKey is a base64-encoded string
151
- # Note: wrappedKey from manifest is already base64-encoded
150
+ Returns:
151
+ Base64-encoded wrapped key string
152
+ """
152
153
  wrapped_key = getattr(key_access, "wrappedKey", None) or getattr(
153
154
  key_access, "wrapped_key", None
154
155
  )
@@ -157,11 +158,24 @@ class KASClient:
157
158
 
158
159
  if isinstance(wrapped_key, bytes):
159
160
  # Only encode if it's raw bytes (shouldn't happen from manifest)
160
- wrapped_key = base64.b64encode(wrapped_key).decode("utf-8")
161
+ return base64.b64encode(wrapped_key).decode("utf-8")
161
162
  elif not isinstance(wrapped_key, str):
162
163
  # Convert to string if it's something else
163
- wrapped_key = str(wrapped_key)
164
+ return str(wrapped_key)
164
165
  # If it's already a string (from manifest), use it as-is since it's already base64-encoded
166
+ return wrapped_key
167
+
168
+ def _build_key_access_dict(self, key_access):
169
+ """
170
+ Build key access dictionary from KeyAccess object, handling both old and new field names.
171
+
172
+ Args:
173
+ key_access: KeyAccess object
174
+
175
+ Returns:
176
+ Dictionary with key access information
177
+ """
178
+ wrapped_key = self._get_wrapped_key_base64(key_access)
165
179
 
166
180
  key_access_dict = {
167
181
  "url": key_access.url,
@@ -172,89 +186,162 @@ class KASClient:
172
186
  key_type = getattr(key_access, "type", None) or getattr(
173
187
  key_access, "key_type", None
174
188
  )
175
- if key_type is not None:
176
- key_access_dict["type"] = key_type
177
- else:
178
- key_access_dict["type"] = "wrapped" # Default type for tests
189
+ key_access_dict["type"] = key_type if key_type is not None else "wrapped"
179
190
 
180
191
  protocol = getattr(key_access, "protocol", None)
181
- if protocol is not None:
182
- key_access_dict["protocol"] = protocol
183
- else:
184
- key_access_dict["protocol"] = "kas" # Default protocol for tests
192
+ key_access_dict["protocol"] = protocol if protocol is not None else "kas"
193
+
194
+ # Add optional fields
195
+ self._add_optional_fields(key_access_dict, key_access)
185
196
 
186
- # Optional fields - handle both old and new field names, only include if they exist and are not None
197
+ return key_access_dict
198
+
199
+ def _add_optional_fields(self, key_access_dict, key_access):
200
+ """
201
+ Add optional fields to key access dictionary.
202
+
203
+ Args:
204
+ key_access_dict: Dictionary to add fields to
205
+ key_access: KeyAccess object to extract fields from
206
+ """
207
+ # Policy binding
187
208
  policy_binding = getattr(key_access, "policyBinding", None) or getattr(
188
209
  key_access, "policy_binding", None
189
210
  )
190
211
  if policy_binding is not None:
191
- # Policy binding hash should be kept as base64-encoded
192
- # The server expects base64-encoded hash values in the JWT request
193
212
  key_access_dict["policyBinding"] = policy_binding
194
213
 
214
+ # Encrypted metadata
195
215
  encrypted_metadata = getattr(key_access, "encryptedMetadata", None) or getattr(
196
216
  key_access, "encrypted_metadata", None
197
217
  )
198
218
  if encrypted_metadata is not None:
199
219
  key_access_dict["encryptedMetadata"] = encrypted_metadata
200
220
 
201
- kid = getattr(key_access, "kid", None)
202
- if kid is not None:
203
- key_access_dict["kid"] = kid
204
-
205
- sid = getattr(key_access, "sid", None)
206
- if sid is not None:
207
- key_access_dict["sid"] = sid
221
+ # Simple optional fields
222
+ for field in ["kid", "sid"]:
223
+ value = getattr(key_access, field, None)
224
+ if value is not None:
225
+ key_access_dict[field] = value
208
226
 
227
+ # Schema version
209
228
  schema_version = getattr(key_access, "schemaVersion", None) or getattr(
210
229
  key_access, "schema_version", None
211
230
  )
212
231
  if schema_version is not None:
213
232
  key_access_dict["schemaVersion"] = schema_version
214
233
 
234
+ # Ephemeral public key
215
235
  ephemeral_public_key = getattr(
216
236
  key_access, "ephemeralPublicKey", None
217
237
  ) or getattr(key_access, "ephemeral_public_key", None)
218
238
  if ephemeral_public_key is not None:
219
239
  key_access_dict["ephemeralPublicKey"] = ephemeral_public_key
220
240
 
221
- # Get current timestamp in seconds since epoch (UNIX timestamp)
222
- now = int(time.time())
241
+ # NanoTDF header
242
+ header = getattr(key_access, "header", None)
243
+ if header is not None:
244
+ key_access_dict["header"] = base64.b64encode(header).decode("utf-8")
245
+
246
+ def _get_algorithm_from_session_key_type(self, session_key_type):
247
+ """
248
+ Convert session key type to algorithm string for KAS.
249
+
250
+ Args:
251
+ session_key_type: Session key type (EC_KEY_TYPE or RSA_KEY_TYPE)
223
252
 
224
- # The server expects a JWT with a requestBody field containing the UnsignedRewrapRequest
225
- # Create the request body that matches UnsignedRewrapRequest protobuf structure
226
- # Use the v2 format with explicit policy ID and requests array for cross-tool compatibility
253
+ Returns:
254
+ Algorithm string or None
255
+ """
256
+ if session_key_type == EC_KEY_TYPE:
257
+ return "ec:secp256r1" # Default EC curve for NanoTDF
258
+ elif session_key_type == RSA_KEY_TYPE:
259
+ return "rsa:2048" # Default RSA key size
260
+ return None
261
+
262
+ def _build_rewrap_request(
263
+ self, policy_json, client_public_key, key_access_dict, algorithm, has_header
264
+ ):
265
+ """
266
+ Build the unsigned rewrap request structure.
227
267
 
228
- # Use "policy" as policy ID for compatibility with otdfctl
268
+ Args:
269
+ policy_json: Policy JSON string
270
+ client_public_key: Client public key PEM string
271
+ key_access_dict: Key access dictionary
272
+ algorithm: Algorithm string (e.g., "ec:secp256r1" or "rsa:2048")
273
+ has_header: Whether NanoTDF header is present
274
+
275
+ Returns:
276
+ Dictionary with unsigned rewrap request
277
+ """
229
278
  import json
230
279
 
231
280
  policy_uuid = "policy" # otdfctl uses "policy" as the policy ID
232
-
233
- # For v2 format, the policy body must be base64-encoded
234
281
  policy_base64 = base64.b64encode(policy_json.encode("utf-8")).decode("utf-8")
235
282
 
236
- unsigned_rewrap_request = {
237
- "clientPublicKey": client_public_key, # Maps to client_public_key
238
- "requests": [
239
- { # Maps to requests array (v2 format)
240
- "keyAccessObjects": [
241
- {
242
- "keyAccessObjectId": "kao-0", # Standard KAO ID
243
- "keyAccessObject": key_access_dict,
244
- }
245
- ],
246
- "policy": {
247
- "id": policy_uuid, # Use the UUID from policy as the policy ID
248
- "body": policy_base64, # Base64-encoded policy JSON
249
- },
283
+ # Build the request object
284
+ request_item = {
285
+ "keyAccessObjects": [
286
+ {
287
+ "keyAccessObjectId": "kao-0", # Standard KAO ID
288
+ "keyAccessObject": key_access_dict,
250
289
  }
251
290
  ],
291
+ "policy": {
292
+ "id": policy_uuid,
293
+ },
294
+ }
295
+
296
+ # Only include policy body if header is NOT provided (standard TDF)
297
+ if not has_header:
298
+ request_item["policy"]["body"] = policy_base64
299
+
300
+ # Add algorithm if provided (required for NanoTDF/ECDH)
301
+ if algorithm:
302
+ request_item["algorithm"] = algorithm
303
+
304
+ unsigned_rewrap_request = {
305
+ "clientPublicKey": client_public_key,
306
+ "requests": [request_item],
252
307
  "keyAccess": key_access_dict,
253
- "policy": policy_base64,
254
308
  }
255
309
 
256
- # Convert to JSON string
257
- request_body_json = json.dumps(unsigned_rewrap_request)
310
+ # Only include legacy policy field for standard TDF (not NanoTDF with header)
311
+ if not has_header:
312
+ unsigned_rewrap_request["policy"] = policy_base64
313
+
314
+ return json.dumps(unsigned_rewrap_request)
315
+
316
+ def _create_signed_request_jwt(
317
+ self, policy_json, client_public_key, key_access, session_key_type=None
318
+ ):
319
+ """
320
+ Create a signed JWT for the rewrap request.
321
+ The JWT is signed with the DPoP private key.
322
+
323
+ Args:
324
+ policy_json: Policy JSON string
325
+ client_public_key: Client public key PEM string
326
+ key_access: KeyAccess object
327
+ session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE)
328
+ """
329
+ # Build key access dictionary handling both old and new field names
330
+ key_access_dict = self._build_key_access_dict(key_access)
331
+
332
+ # Get current timestamp
333
+ now = int(time.time())
334
+
335
+ # Convert session_key_type to algorithm string for KAS
336
+ algorithm = self._get_algorithm_from_session_key_type(session_key_type)
337
+
338
+ # Check if header is present (for NanoTDF)
339
+ has_header = getattr(key_access, "header", None) is not None
340
+
341
+ # Build the unsigned rewrap request
342
+ request_body_json = self._build_rewrap_request(
343
+ policy_json, client_public_key, key_access_dict, algorithm, has_header
344
+ )
258
345
 
259
346
  # JWT payload with requestBody field containing the JSON string
260
347
  payload = {
@@ -264,9 +351,7 @@ class KASClient:
264
351
  }
265
352
 
266
353
  # Sign the JWT with the DPoP private key (RS256)
267
- signed_jwt = jwt.encode(payload, self._dpop_private_key_pem, algorithm="RS256")
268
-
269
- return signed_jwt
354
+ return jwt.encode(payload, self._dpop_private_key_pem, algorithm="RS256")
270
355
 
271
356
  def _create_connect_rpc_signed_token(self, key_access, policy_json):
272
357
  """
@@ -506,11 +591,13 @@ class KASClient:
506
591
  self.decryptor = AsymDecryption(private_key_pem)
507
592
  self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key)
508
593
  else:
509
- # For EC keys, generate fresh key pair each time
510
- # TODO: Implement proper EC key handling
511
- private_key, public_key = CryptoUtils.generate_rsa_keypair()
512
- private_key_pem = CryptoUtils.get_rsa_private_key_pem(private_key)
513
- self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key)
594
+ # For EC keys (NanoTDF/ECDH), still need RSA keypair for encrypting the rewrap response
595
+ # KAS uses client public key to encrypt the symmetric key it derived via ECDH
596
+ if self.decryptor is None:
597
+ private_key, public_key = CryptoUtils.generate_rsa_keypair()
598
+ private_key_pem = CryptoUtils.get_rsa_private_key_pem(private_key)
599
+ self.decryptor = AsymDecryption(private_key_pem)
600
+ self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key)
514
601
 
515
602
  def _parse_and_decrypt_response(self, response):
516
603
  """
@@ -559,14 +646,22 @@ class KASClient:
559
646
  policy_json,
560
647
  self.client_public_key,
561
648
  key_access, # Use ephemeral key, not DPoP key
649
+ session_key_type, # Pass algorithm type for NanoTDF
562
650
  )
563
651
 
564
652
  # Call Connect RPC unwrap
565
- return self._unwrap_with_connect_rpc(key_access, signed_token)
653
+ return self._unwrap_with_connect_rpc(key_access, signed_token, session_key_type)
566
654
 
567
- def _unwrap_with_connect_rpc(self, key_access, signed_token) -> bytes:
655
+ def _unwrap_with_connect_rpc(
656
+ self, key_access, signed_token, session_key_type=None
657
+ ) -> bytes:
568
658
  """
569
659
  Connect RPC method for unwrapping keys.
660
+
661
+ Args:
662
+ key_access: KeyAccess object
663
+ signed_token: Signed JWT token
664
+ session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE)
570
665
  """
571
666
 
572
667
  # Get access token for authentication if token source is available
@@ -586,12 +681,23 @@ class KASClient:
586
681
  normalized_kas_url, key_access, signed_token, access_token
587
682
  )
588
683
 
589
- # Decrypt the wrapped key
684
+ # Both ECDH and RSA modes return an RSA-encrypted key
685
+ # For ECDH (EC_KEY_TYPE): KAS performs ECDH to derive symmetric key, then RSA-encrypts it with client public key
686
+ # For RSA (RSA_KEY_TYPE): KAS RSA-decrypts wrapped key, then RSA-encrypts it with client public key
687
+ # In both cases, we need to RSA-decrypt using our client private key
590
688
  if not self.decryptor:
591
689
  raise SDKException("Decryptor not initialized")
592
690
 
593
691
  result = self.decryptor.decrypt(entity_wrapped_key)
594
- logging.info("Connect RPC rewrap succeeded")
692
+
693
+ if session_key_type == EC_KEY_TYPE:
694
+ logging.info(
695
+ f"Connect RPC rewrap succeeded (ECDH - KAS derived key via ECDH, length={len(result)} bytes)"
696
+ )
697
+ else:
698
+ logging.info(
699
+ f"Connect RPC rewrap succeeded (RSA - length={len(result)} bytes)"
700
+ )
595
701
  return result
596
702
 
597
703
  except Exception as e: