exchange-keyshare 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,86 @@
1
+ """Configuration file handling for exchange-keyshare."""
2
+
3
+ import os
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import yaml
9
+
10
+
11
def _default_config_path() -> Path:
    """Return the location of the config file.

    A non-empty ``EXCHANGE_KEYSHARE_CONFIG`` environment variable takes
    precedence; otherwise ``~/.config/exchange-keyshare/config.yaml`` is used.
    """
    override = os.environ.get("EXCHANGE_KEYSHARE_CONFIG")
    if not override:
        return Path.home().joinpath(".config", "exchange-keyshare", "config.yaml")
    return Path(override)
19
+
20
+
21
@dataclass
class Config:
    """In-memory view of the exchange-keyshare config file.

    Every setting is optional and ``None`` means "not configured".
    ``config_path`` defaults to the value of ``_default_config_path()``.
    """

    config_path: Path = field(default_factory=_default_config_path)
    bucket: str | None = None
    region: str | None = None
    stack_name: str | None = None
    role_arn: str | None = None
    external_id: str | None = None
    kms_key_arn: str | None = None

    # Settings read from / written to the YAML file, in write order.
    # (No annotation, so the dataclass does not treat this as a field.)
    _PERSISTED_KEYS = (
        "bucket",
        "region",
        "stack_name",
        "role_arn",
        "external_id",
        "kms_key_arn",
    )

    def load(self) -> None:
        """Populate the settings from ``config_path``.

        A key missing from the file resets the matching attribute to ``None``.
        """
        stored = load_config(self.config_path)
        for name in self._PERSISTED_KEYS:
            setattr(self, name, stored.get(name))

    def save(self) -> None:
        """Write the settings to ``config_path``.

        Settings that are ``None`` or otherwise falsy (e.g. ``""``) are left
        out of the file.
        """
        payload: dict[str, Any] = {
            name: getattr(self, name)
            for name in self._PERSISTED_KEYS
            if getattr(self, name)
        }
        save_config(self.config_path, payload)
59
+
60
+
61
def load_config(path: Path) -> dict[str, Any]:
    """Load config from a YAML file.

    Args:
        path: Location of the config file.

    Returns:
        The parsed mapping, or an empty dict when the file does not exist or
        holds no data.

    Raises:
        ValueError: If the file parses to something other than a mapping
            (e.g. a list or a bare string). Callers such as ``Config.load``
            need to look keys up in the result.
    """
    # EAFP: avoids the race between an exists() check and the open().
    try:
        # Explicit UTF-8, because the platform default encoding differs on Windows.
        f = open(path, encoding="utf-8")
    except FileNotFoundError:
        return {}

    with f:
        data = yaml.safe_load(f)

    if not data:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"config file {path} must contain a YAML mapping, "
            f"got {type(data).__name__}"
        )
    return data
69
+
70
+
71
def save_config(path: Path, data: dict[str, Any]) -> None:
    """Save config to a YAML file, creating parent directories if needed.

    The file is restricted to owner read/write (0600) because it may hold
    sensitive data such as ``external_id``. The mode is enforced both when
    the file is first created and when an existing file is overwritten, so a
    previously world-readable file is tightened on the next save.

    Args:
        path: Destination file.
        data: Mapping to serialize as block-style YAML.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    # The mode argument only takes effect when O_CREAT actually creates the file.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        # Tighten an already-existing file before anything is written to it.
        # os.fchmod is not available on every platform (e.g. older Windows).
        if hasattr(os, "fchmod"):
            os.fchmod(fd, 0o600)
        f = os.fdopen(fd, "w", encoding="utf-8")
    except Exception:
        # fdopen has not taken ownership of fd yet, so close it here.
        os.close(fd)
        raise

    # The file object now owns fd and closes it on exit. The descriptor must
    # not be closed a second time: that would raise EBADF and mask the real
    # error from yaml.dump.
    with f:
        yaml.dump(data, f, default_flow_style=False)
@@ -0,0 +1,106 @@
1
+ """Key management operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import secrets
6
+ import string
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, cast
9
+
10
+ import boto3
11
+ import yaml
12
+
13
+ if TYPE_CHECKING:
14
+ from mypy_boto3_s3 import S3Client
15
+
16
+ from exchange_keyshare.schema import CredentialSchema, validate_credential
17
+
18
+
19
def generate_s3_key(exchange: str) -> str:
    """Build a fresh object key: ``exchange-credentials/<exchange>-<suffix>.yaml``.

    The suffix is 12 characters drawn with ``secrets`` from lowercase ASCII
    letters and digits, which makes collisions between keys unlikely.
    """
    alphabet = string.ascii_lowercase + string.digits
    token = "".join([secrets.choice(alphabet) for _ in range(12)])
    return f"exchange-credentials/{exchange}-{token}.yaml"
24
+
25
+
26
def parse_credential_from_yaml(content: str) -> CredentialSchema:
    """Deserialize a YAML document and validate it as a credential.

    ``yaml.safe_load`` is used, so the document cannot construct arbitrary
    Python objects. Errors raised by ``validate_credential`` propagate as-is.
    """
    return validate_credential(yaml.safe_load(content))
30
+
31
+
32
@dataclass
class CredentialInfo:
    """Listing entry for one stored credential.

    Carries only metadata; the ``credential`` secrets mapping is not included.
    """

    # S3 object key the credential was read from.
    key: str
    # Exchange identifier taken from the credential document.
    exchange: str
    # Optional list of pair strings from the document (None when absent).
    pairs: list[str] | None
    # Optional list of label mappings from the document (None when absent).
    labels: list[dict[str, str]] | None
40
+
41
+
42
def list_credentials(bucket: str, region: str) -> list[CredentialInfo]:
    """Return metadata for every credential object stored in ``bucket``.

    Walks all pages of keys under the ``exchange-credentials/`` prefix,
    downloads each ``.yaml`` object and validates it with
    ``parse_credential_from_yaml``. A malformed object raises and aborts the
    whole listing.

    Args:
        bucket: Name of the S3 bucket holding the credentials.
        region: AWS region used for the S3 client.
    """
    client = cast("S3Client", boto3.client("s3", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    pages = client.get_paginator("list_objects_v2").paginate(
        Bucket=bucket, Prefix="exchange-credentials/"
    )

    found: list[CredentialInfo] = []
    for page in pages:
        for entry in page.get("Contents", []):
            object_key = entry.get("Key")
            # Ignore entries without a key and anything that is not YAML.
            if object_key and object_key.endswith(".yaml"):
                body = client.get_object(Bucket=bucket, Key=object_key)["Body"]
                parsed = parse_credential_from_yaml(body.read().decode("utf-8"))
                found.append(
                    CredentialInfo(
                        key=object_key,
                        exchange=parsed.exchange,
                        pairs=parsed.pairs,
                        labels=parsed.labels,
                    )
                )
    return found
66
+
67
+
68
def upload_credential(
    bucket: str,
    region: str,
    credential: CredentialSchema,
    kms_key_arn: str,
    s3_key: str | None = None,
) -> str:
    """Serialize ``credential`` to YAML and store it in S3 with SSE-KMS.

    Args:
        bucket: Destination bucket.
        region: AWS region for the S3 client.
        credential: Validated credential to store.
        kms_key_arn: KMS key id passed as ``SSEKMSKeyId`` for server-side
            encryption of the object.
        s3_key: Object key to write. When empty or None, a fresh key from
            ``generate_s3_key`` is used.

    Returns:
        The object key that was written.
    """
    target_key = s3_key if s3_key else generate_s3_key(credential.exchange)
    payload = yaml.dump(credential.to_dict(), default_flow_style=False).encode("utf-8")

    client = cast("S3Client", boto3.client("s3", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    client.put_object(
        Bucket=bucket,
        Key=target_key,
        Body=payload,
        ContentType="application/x-yaml",
        ServerSideEncryption="aws:kms",
        SSEKMSKeyId=kms_key_arn,
    )
    return target_key
91
+
92
+
93
def get_credential(bucket: str, region: str, key: str) -> CredentialSchema:
    """Download the object at ``key`` and return it as a validated credential.

    Errors from ``get_object`` and from credential validation propagate
    unchanged.
    """
    client = cast("S3Client", boto3.client("s3", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    raw = client.get_object(Bucket=bucket, Key=key)["Body"].read()
    return parse_credential_from_yaml(raw.decode("utf-8"))
101
+
102
+
103
def delete_credential(bucket: str, region: str, key: str) -> None:
    """Remove the credential object ``key`` from ``bucket``."""
    client = cast("S3Client", boto3.client("s3", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    client.delete_object(Key=key, Bucket=bucket)
@@ -0,0 +1,117 @@
1
+ """Credential schema validation."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, cast
5
+
6
+
7
# Exchanges accepted in the top-level ``exchange`` field of a credential.
SUPPORTED_EXCHANGES = frozenset(("binance", "coinbase", "kraken", "kucoin", "bitget"))

# Exchanges whose credentials must also carry a ``credential.passphrase``.
PASSPHRASE_REQUIRED_EXCHANGES = frozenset(("coinbase", "kucoin", "bitget"))
20
+
21
+
22
class SchemaError(Exception):
    """A credential document failed validation.

    The message describes which check failed,
    e.g. ``missing required field: version``.
    """
25
+
26
+
27
+ @dataclass
28
+ class CredentialSchema:
29
+ """Validated credential data."""
30
+
31
+ version: str
32
+ exchange: str
33
+ credential: dict[str, str]
34
+ pairs: list[str] | None = None
35
+ labels: list[dict[str, str]] | None = None
36
+
37
+ def to_dict(self) -> dict[str, Any]:
38
+ """Convert to dictionary for YAML serialization."""
39
+ data: dict[str, Any] = {
40
+ "version": self.version,
41
+ "exchange": self.exchange,
42
+ "credential": self.credential,
43
+ }
44
+ if self.pairs is not None:
45
+ data["pairs"] = self.pairs
46
+ if self.labels is not None:
47
+ data["labels"] = self.labels
48
+ return data
49
+
50
+
51
def validate_credential(data: object) -> CredentialSchema:
    """Validate parsed credential data and return a structured result.

    Args:
        data: Parsed credential document, typically the output of
            ``yaml.safe_load``. It must be a mapping; ``None`` (an empty
            document), lists and scalars are rejected with ``SchemaError``.

    Returns:
        A ``CredentialSchema`` built from the validated fields. ``version``
        and ``exchange`` are converted with ``str()``.

    Raises:
        SchemaError: If validation fails.
    """
    # yaml.safe_load yields None for an empty document and may yield a list or
    # scalar. Report that as a schema problem rather than a TypeError below.
    if not isinstance(data, dict):
        raise SchemaError("credential data must be a mapping")
    fields = cast(dict[str, Any], data)

    # Check required top-level fields
    for required in ("version", "exchange", "credential"):
        if required not in fields:
            raise SchemaError(f"missing required field: {required}")

    exchange = fields["exchange"]
    # The isinstance check comes first, so an unhashable value (e.g. a list)
    # is reported as unsupported instead of raising TypeError in the set lookup.
    if not isinstance(exchange, str) or exchange not in SUPPORTED_EXCHANGES:
        raise SchemaError(f"unsupported exchange: {exchange}")

    credential_raw = fields["credential"]
    if not isinstance(credential_raw, dict):
        raise SchemaError("credential must be a mapping")
    credential = cast(dict[str, str], credential_raw)
    _require_credential_string(credential, "api_key")
    _require_credential_string(credential, "api_secret")

    # Check passphrase requirement for certain exchanges. Whitespace-only
    # strings are rejected too, consistent with the api_key/api_secret checks.
    if exchange in PASSPHRASE_REQUIRED_EXCHANGES:
        passphrase = credential.get("passphrase")
        if not passphrase or (isinstance(passphrase, str) and not passphrase.strip()):
            raise SchemaError(f"passphrase required for {exchange}")

    # Validate optional pairs field
    pairs: list[str] | None = None
    pairs_raw = fields.get("pairs")
    if pairs_raw is not None:
        if not isinstance(pairs_raw, list):
            raise SchemaError("pairs must be a list")
        for pair in cast(list[Any], pairs_raw):
            if not isinstance(pair, str):
                raise SchemaError("each pair must be a string")
        pairs = cast(list[str], pairs_raw)

    # Validate optional labels field
    labels: list[dict[str, str]] | None = None
    labels_raw = fields.get("labels")
    if labels_raw is not None:
        if not isinstance(labels_raw, list):
            raise SchemaError("labels must be a list")
        for label in cast(list[Any], labels_raw):
            if not isinstance(label, dict) or "key" not in label or "value" not in label:
                raise SchemaError("label must have 'key' and 'value'")
        labels = cast(list[dict[str, str]], labels_raw)

    return CredentialSchema(
        version=str(fields["version"]),
        exchange=str(exchange),
        credential=credential,
        pairs=pairs,
        labels=labels,
    )


def _require_credential_string(credential: dict[str, Any], name: str) -> None:
    """Raise SchemaError unless ``credential[name]`` is present and a non-blank string."""
    if name not in credential:
        raise SchemaError(f"missing required field: credential.{name}")
    value = credential[name]
    if isinstance(value, str):
        if not value.strip():
            raise SchemaError(f"credential.{name} cannot be empty")
    elif not value:
        # None (e.g. a bare ``api_key:`` line in YAML) or another falsy value.
        raise SchemaError(f"credential.{name} cannot be empty")
    else:
        # e.g. YAML parsed an unquoted numeric key as int; .strip() would crash.
        raise SchemaError(f"credential.{name} must be a string")
@@ -0,0 +1,263 @@
1
+ """Setup command logic for creating AWS infrastructure."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import secrets
7
+ import string
8
+ import time
9
+ from collections.abc import Generator
10
+ from dataclasses import dataclass
11
+ from typing import TYPE_CHECKING, cast
12
+
13
+ import boto3
14
+ from botocore.exceptions import ClientError
15
+
16
+ if TYPE_CHECKING:
17
+ from mypy_boto3_cloudformation import CloudFormationClient
18
+
19
+ from exchange_keyshare.cfn import load_template
20
+
21
+
22
def generate_bucket_name() -> str:
    """Return a new bucket name of the form ``exchange-keyshare-<suffix>``.

    The 12-character suffix is drawn with ``secrets`` from lowercase ASCII
    letters and digits only.
    """
    pool = string.ascii_lowercase + string.digits
    picks = [secrets.choice(pool) for _ in range(12)]
    return "exchange-keyshare-" + "".join(picks)
27
+
28
+
29
def get_default_region() -> str:
    """Return the AWS region to use when none is given explicitly.

    Checks ``AWS_REGION`` first, then ``AWS_DEFAULT_REGION``, and falls back
    to ``us-east-1``. Variables that are set but blank are skipped, so an
    exported-but-empty ``AWS_REGION=`` does not produce an empty region name.
    """
    for var in ("AWS_REGION", "AWS_DEFAULT_REGION"):
        value = os.environ.get(var, "").strip()
        if value:
            return value
    return "us-east-1"
32
+
33
+
34
@dataclass
class SetupResult:
    """Values describing the infrastructure created by a setup run."""

    # Bucket name (stack output ``BucketName``).
    bucket: str
    # Region the stack lives in.
    region: str
    # IAM role ARN (stack output ``RoleArn``).
    role_arn: str
    # External ID (stack output ``ExternalId``).
    external_id: str
    # Name of the CloudFormation stack.
    stack_name: str
    # KMS key ARN (stack output ``KmsKeyArn``).
    kms_key_arn: str
44
+
45
+
46
@dataclass
class ResourceStatus:
    """Latest known state of one resource inside the CloudFormation stack."""

    # Logical resource id from the template.
    logical_id: str
    # CloudFormation type string, e.g. "AWS::S3::Bucket".
    resource_type: str
    # Resource status from the stack event, e.g. "CREATE_COMPLETE".
    status: str
    # Status reason reported with the event, if any.
    reason: str | None = None
54
+
55
+
56
@dataclass
class StackProgress:
    """Snapshot of stack creation, as yielded by ``poll_stack_progress``."""

    # Per-resource state keyed by logical resource id.
    resources: dict[str, ResourceStatus]
    # Overall stack status string (or a synthetic value such as "TIMEOUT").
    stack_status: str
    # True once the stack reached a successful terminal state.
    is_complete: bool
    # True once creation is known to have failed or timed out.
    is_failed: bool
    # Human-readable explanation when is_failed is set.
    failure_reason: str | None = None
65
+
66
+
67
# Display labels for the CloudFormation resource types shown during setup.
# Types not listed here are displayed verbatim by get_friendly_type().
RESOURCE_TYPE_NAMES: dict[str, str] = {
    "AWS::S3::Bucket": "S3 Bucket",
    "AWS::S3::BucketPolicy": "Bucket Policy",
    "AWS::KMS::Key": "KMS Key",
    "AWS::KMS::Alias": "KMS Alias",
    "AWS::IAM::Role": "IAM Role",
}


def get_friendly_type(resource_type: str) -> str:
    """Map a CloudFormation resource type to its display label.

    Unknown types are returned unchanged.
    """
    try:
        return RESOURCE_TYPE_NAMES[resource_type]
    except KeyError:
        return resource_type
80
+
81
+
82
def start_stack_creation(
    external_id: str,
    principal_arn: str,
    bucket_name: str | None = None,
    region: str | None = None,
) -> tuple[str, str, CloudFormationClient]:
    """Start creating the credential-storage CloudFormation stack.

    Returns immediately after ``create_stack`` is accepted. Follow up with
    ``poll_stack_progress`` and ``get_stack_outputs``.

    Args:
        external_id: Passed to the template as the ``ExternalId`` parameter.
        principal_arn: Passed to the template as ``ConsumerPrincipalArn``.
        bucket_name: Bucket name to request. A random name from
            ``generate_bucket_name`` is used when omitted. The bucket name is
            also used as the stack name.
        region: Target region; defaults to ``get_default_region()``.

    Returns:
        ``(stack_name, region, cfn_client)``.

    Raises:
        Exception: If a stack with the same name already exists.
        ClientError: For any other CloudFormation API error.
    """
    bucket = bucket_name or generate_bucket_name()
    region = region or get_default_region()
    stack_name = bucket  # Bucket name already has unique suffix

    cfn = cast("CloudFormationClient", boto3.client("cloudformation", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    template = load_template()

    try:
        cfn.create_stack(
            StackName=stack_name,
            TemplateBody=template,
            Parameters=[
                {"ParameterKey": "BucketName", "ParameterValue": bucket},
                {"ParameterKey": "ConsumerPrincipalArn", "ParameterValue": principal_arn},
                {"ParameterKey": "ExternalId", "ParameterValue": external_id},
            ],
            Capabilities=["CAPABILITY_NAMED_IAM"],
            # Delete the whole stack if creation fails.
            OnFailure="DELETE",
        )
    except ClientError as e:
        # Match the structured error code rather than the formatted message text.
        if e.response.get("Error", {}).get("Code") == "AlreadyExistsException":
            raise Exception(f"Stack {stack_name} already exists") from e
        raise

    return stack_name, region, cfn
114
+
115
+
116
def poll_stack_progress(
    stack_name: str,
    cfn: CloudFormationClient,
    poll_interval: float = 2.0,
    max_attempts: int = 300,
) -> Generator[StackProgress, None, None]:
    """Poll stack creation progress and yield a snapshot after every poll.

    The generator stops after yielding a snapshot whose ``is_complete`` or
    ``is_failed`` flag is set. If neither happens within ``max_attempts``
    polls, it yields a final ``TIMEOUT`` failure snapshot.

    Args:
        stack_name: Name of the stack being created.
        cfn: CloudFormation client for the stack's region.
        poll_interval: Seconds to sleep between polls.
        max_attempts: Maximum number of polls before giving up.

    Yields:
        ``StackProgress`` snapshots. Each carries its own copy of the
        resource map, so snapshots kept by the caller are not mutated by
        later polls. Failed snapshots always have a ``failure_reason``.

    Raises:
        ClientError: For CloudFormation API errors other than "stack does not
            exist", which is reported as a failed snapshot instead.
    """
    resources: dict[str, ResourceStatus] = {}
    # First failure reason seen for each resource. Kept separately because a
    # failed resource's latest status can move on (e.g. to DELETE_COMPLETE
    # during cleanup), and the reason would otherwise be lost.
    failure_reasons: dict[str, str] = {}

    for _ in range(max_attempts):
        try:
            stacks_response = cfn.describe_stacks(StackName=stack_name)
            events_response = cfn.describe_stack_events(StackName=stack_name)
        except ClientError as e:
            if not _is_stack_missing(e):
                raise
            # With OnFailure="DELETE", a failed stack is removed and no longer
            # visible by name.
            yield StackProgress(
                resources=dict(resources),
                stack_status="DELETE_IN_PROGRESS",
                is_complete=False,
                is_failed=True,
                failure_reason=_join_reasons(failure_reasons)
                or "Stack creation failed and is being deleted",
            )
            return

        stack = stacks_response["Stacks"][0]
        stack_status = stack.get("StackStatus", "UNKNOWN")

        # CloudFormation returns stack events in reverse chronological order.
        for event in events_response["StackEvents"]:
            logical_id = event.get("LogicalResourceId", "")
            resource_type = event.get("ResourceType", "")
            status = event.get("ResourceStatus", "")
            reason = event.get("ResourceStatusReason")

            # Skip the stack itself
            if resource_type == "AWS::CloudFormation::Stack":
                continue

            if "FAILED" in status and reason:
                failure_reasons.setdefault(logical_id, reason)

            # Update if this is a newer status for this resource
            current = resources.get(logical_id)
            if current is None or _is_newer_status(status, current.status):
                resources[logical_id] = ResourceStatus(
                    logical_id=logical_id,
                    resource_type=resource_type,
                    status=status,
                    reason=reason,
                )

        is_complete = stack_status == "CREATE_COMPLETE"
        is_failed = stack_status in _FAILED_STACK_STATUSES

        failure_reason: str | None = None
        if is_failed:
            failure_reason = (
                _join_reasons(failure_reasons)
                or stack.get("StackStatusReason")
                or "Unknown error"
            )

        yield StackProgress(
            resources=dict(resources),
            stack_status=stack_status,
            is_complete=is_complete,
            is_failed=is_failed,
            failure_reason=failure_reason,
        )

        if is_complete or is_failed:
            return

        time.sleep(poll_interval)

    # Timeout
    yield StackProgress(
        resources=dict(resources),
        stack_status="TIMEOUT",
        is_complete=False,
        is_failed=True,
        failure_reason="Stack creation timed out",
    )


# Stack statuses meaning creation has failed and will not reach
# CREATE_COMPLETE. This covers the delete-on-failure path (DELETE_*) and the
# rollback path (ROLLBACK_*), including their terminal failure states.
_FAILED_STACK_STATUSES = frozenset({
    "CREATE_FAILED",
    "ROLLBACK_IN_PROGRESS",
    "ROLLBACK_FAILED",
    "ROLLBACK_COMPLETE",
    "DELETE_IN_PROGRESS",
    "DELETE_FAILED",
    "DELETE_COMPLETE",
})


def _join_reasons(reasons: dict[str, str], limit: int = 3) -> str:
    """Join up to ``limit`` failure reasons with "; " (empty string if none)."""
    return "; ".join(list(reasons.values())[:limit])


def _is_stack_missing(err: ClientError) -> bool:
    """Return True if ``err`` is CloudFormation's "stack ... does not exist" error."""
    error = err.response.get("Error", {})
    return error.get("Code") == "ValidationError" and "does not exist" in error.get("Message", "")
197
+
198
+
199
def _is_newer_status(new_status: str, old_status: str) -> bool:
    """Return True if ``new_status`` should replace ``old_status`` for a resource.

    Known statuses are ranked by how far along the create/cleanup lifecycle
    they are. A status missing from the ranking is treated as newer than a
    known one, so unfamiliar states still surface. A known status never
    replaces an unknown one, though: the caller walks stack events newest
    first, so in that situation the known status belongs to the older event.
    Identical statuses never replace each other.
    """
    # Status progression order. The DELETE_* terminal states are mutually
    # exclusive, so their relative order does not matter.
    status_order = (
        "CREATE_IN_PROGRESS",
        "CREATE_COMPLETE",
        "CREATE_FAILED",
        "DELETE_IN_PROGRESS",
        "DELETE_FAILED",
        "DELETE_SKIPPED",
        "DELETE_COMPLETE",
    )

    if new_status == old_status:
        return False
    if new_status not in status_order:
        return True
    if old_status not in status_order:
        return False
    return status_order.index(new_status) > status_order.index(old_status)
217
+
218
+
219
def get_stack_outputs(stack_name: str, region: str) -> SetupResult:
    """Read the outputs of a finished stack into a ``SetupResult``.

    The stack must expose ``BucketName``, ``RoleArn``, ``ExternalId`` and
    ``KmsKeyArn`` outputs; a missing one raises ``KeyError``.
    """
    cfn = cast("CloudFormationClient", boto3.client("cloudformation", region_name=region))  # pyright: ignore[reportUnknownMemberType]
    described = cfn.describe_stacks(StackName=stack_name)["Stacks"][0]

    outputs: dict[str, str] = {}
    for item in described.get("Outputs", []):
        # Skip malformed entries that lack either half of the pair.
        if "OutputKey" in item and "OutputValue" in item:
            outputs[item["OutputKey"]] = item["OutputValue"]

    return SetupResult(
        bucket=outputs["BucketName"],
        region=region,
        role_arn=outputs["RoleArn"],
        external_id=outputs["ExternalId"],
        stack_name=stack_name,
        kms_key_arn=outputs["KmsKeyArn"],
    )
239
+
240
+
241
def create_stack(
    external_id: str,
    principal_arn: str,
    bucket_name: str | None = None,
    region: str | None = None,
) -> SetupResult:
    """Create the credential-storage stack and block until it finishes.

    Convenience wrapper around ``start_stack_creation``,
    ``poll_stack_progress`` and ``get_stack_outputs`` for callers that do not
    need progress updates.

    Raises:
        Exception: If the stack already exists, or if creation fails or
            times out.
    """
    stack_name, region, cfn = start_stack_creation(
        external_id, principal_arn, bucket_name, region
    )

    # Wait for the first snapshot that ends the polling, either way.
    terminal = next(
        (p for p in poll_stack_progress(stack_name, cfn) if p.is_failed or p.is_complete),
        None,
    )
    if terminal is not None and terminal.is_failed:
        raise Exception(f"Stack creation failed: {terminal.failure_reason}")

    return get_stack_outputs(stack_name, region)