comfygit-deploy 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,549 @@
1
+ """RunPod REST and GraphQL API client for pod and resource management.
2
+
3
+ REST API v1: https://rest.runpod.io/v1
4
+ GraphQL API: https://api.runpod.io/graphql
5
+ """
6
+
7
import json
from dataclasses import dataclass
from typing import Any

import aiohttp
11
+
12
# RunPod data centers (static list - no REST endpoint available).
# Kept as (id, human-readable name) pairs; every entry is marked available.
_DATA_CENTER_PAIRS: list[tuple[str, str]] = [
    ("US-GA-1", "United States (Georgia)"),
    ("US-GA-2", "United States (Georgia 2)"),
    ("US-IL-1", "United States (Illinois)"),
    ("US-KS-2", "United States (Kansas)"),
    ("US-KS-3", "United States (Kansas 2)"),
    ("US-TX-1", "United States (Texas)"),
    ("US-TX-3", "United States (Texas 2)"),
    ("US-TX-4", "United States (Texas 3)"),
    ("US-WA-1", "United States (Washington)"),
    ("US-CA-2", "United States (California)"),
    ("US-NC-1", "United States (North Carolina)"),
    ("US-DE-1", "United States (Delaware)"),
    ("CA-MTL-1", "Canada (Montreal)"),
    ("CA-MTL-2", "Canada (Montreal 2)"),
    ("CA-MTL-3", "Canada (Montreal 3)"),
    ("EU-CZ-1", "Europe (Czech Republic)"),
    ("EU-FR-1", "Europe (France)"),
    ("EU-NL-1", "Europe (Netherlands)"),
    ("EU-RO-1", "Europe (Romania)"),
    ("EU-SE-1", "Europe (Sweden)"),
    ("EUR-IS-1", "Europe (Iceland)"),
    ("EUR-IS-2", "Europe (Iceland 2)"),
    ("EUR-IS-3", "Europe (Iceland 3)"),
    ("EUR-NO-1", "Europe (Norway)"),
    ("AP-JP-1", "Asia-Pacific (Japan)"),
    ("OC-AU-1", "Oceania (Australia)"),
]

DATA_CENTERS = [
    {"id": dc_id, "name": dc_name, "available": True}
    for dc_id, dc_name in _DATA_CENTER_PAIRS
]
41
+
42
# Common GPU types available on RunPod, grouped by product family.
# GPU_TYPES concatenates the groups in the same order as the original flat list.
_GEFORCE_GPUS = [
    "NVIDIA GeForce RTX 4090",
    "NVIDIA GeForce RTX 3090",
    "NVIDIA GeForce RTX 3090 Ti",
    "NVIDIA GeForce RTX 3080 Ti",
    "NVIDIA GeForce RTX 3080",
    "NVIDIA GeForce RTX 3070",
    "NVIDIA GeForce RTX 4080",
    "NVIDIA GeForce RTX 4080 SUPER",
    "NVIDIA GeForce RTX 4070 Ti",
    "NVIDIA GeForce RTX 5090",
    "NVIDIA GeForce RTX 5080",
]

_DATACENTER_GPUS = [
    "NVIDIA A40",
    "NVIDIA A100 80GB PCIe",
    "NVIDIA A100-SXM4-80GB",
    "NVIDIA A30",
    "NVIDIA H100 80GB HBM3",
    "NVIDIA H100 PCIe",
    "NVIDIA H100 NVL",
    "NVIDIA H200",
    "NVIDIA B200",
    "NVIDIA L40S",
    "NVIDIA L40",
    "NVIDIA L4",
]

_WORKSTATION_GPUS = [
    "NVIDIA RTX A6000",
    "NVIDIA RTX A5000",
    "NVIDIA RTX A4500",
    "NVIDIA RTX A4000",
    "NVIDIA RTX A2000",
    "NVIDIA RTX 6000 Ada Generation",
    "NVIDIA RTX 5000 Ada Generation",
    "NVIDIA RTX 4000 Ada Generation",
    "NVIDIA RTX 4000 SFF Ada Generation",
    "NVIDIA RTX 2000 Ada Generation",
]

_OTHER_GPUS = [
    "Tesla V100-PCIE-16GB",
    "Tesla V100-FHHL-16GB",
    "Tesla V100-SXM2-16GB",
    "Tesla V100-SXM2-32GB",
    "AMD Instinct MI300X OAM",
]

GPU_TYPES = _GEFORCE_GPUS + _DATACENTER_GPUS + _WORKSTATION_GPUS + _OTHER_GPUS
83
+
84
+
85
class RunPodAPIError(Exception):
    """Error raised for a failed RunPod REST or GraphQL API call.

    Attributes:
        message: Human-readable error message from the API (or synthesized).
        status_code: HTTP status code of the failing response (400 is used
            for GraphQL-level errors, which arrive with HTTP 200).
    """

    # NOTE: the previous @dataclass implementation set __hash__ = None
    # (eq=True, frozen=False), making instances unhashable and skipping
    # BaseException.__init__.  A plain __init__ that forwards to super()
    # keeps .args populated so repr/pickling behave like a normal exception.
    def __init__(self, message: str, status_code: int):
        super().__init__(message, status_code)
        self.message = message
        self.status_code = status_code

    def __str__(self) -> str:
        return f"RunPod API Error ({self.status_code}): {self.message}"
94
+
95
+
96
class RunPodClient:
    """Async client for the RunPod REST (v1) and GraphQL APIs.

    REST calls authenticate with a Bearer token; the GraphQL endpoint
    authenticates via an ``api_key`` URL parameter instead (RunPod quirk).
    A fresh aiohttp session is opened per request — stateless and simple,
    at the cost of a new connection each call (fine for low call volume).
    """

    base_url = "https://rest.runpod.io/v1"
    graphql_url = "https://api.runpod.io/graphql"

    def __init__(self, api_key: str):
        """Initialize client with API key.

        Args:
            api_key: RunPod API key (starts with rpa_ or rps_)

        Raises:
            ValueError: If api_key is empty
        """
        if not api_key:
            raise ValueError("API key required")
        self.api_key = api_key

    def _headers(self) -> dict[str, str]:
        """Get request headers with authorization."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def _get(
        self, path: str, params: dict | None = None, operation: str = "get"
    ) -> Any:
        """Make GET request to the REST API and return the JSON body.

        Raises:
            RunPodAPIError: On any HTTP status >= 400.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.base_url}{path}",
                params=params,
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                return await response.json()

    async def _post(
        self, path: str, data: dict | None = None, operation: str = "post"
    ) -> Any:
        """Make POST request to the REST API and return the JSON body.

        Returns None for 204 responses or when the body is not valid JSON
        (best-effort: some endpoints respond with an empty body).

        Raises:
            RunPodAPIError: On any HTTP status >= 400.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.base_url}{path}",
                json=data,
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                if response.status == 204:
                    return None
                try:
                    return await response.json()
                except Exception:
                    # Deliberate best-effort: a successful status with a
                    # non-JSON/empty body is treated as "no payload".
                    return None

    async def _delete(self, path: str, operation: str = "delete") -> None:
        """Make DELETE request to the REST API.

        Raises:
            RunPodAPIError: On any HTTP status >= 400.
        """
        async with aiohttp.ClientSession() as session:
            async with session.delete(
                f"{self.base_url}{path}",
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)

    async def _patch(
        self, path: str, data: dict | None = None, operation: str = "patch"
    ) -> Any:
        """Make PATCH request to the REST API and return the JSON body.

        Raises:
            RunPodAPIError: On any HTTP status >= 400.
        """
        async with aiohttp.ClientSession() as session:
            async with session.patch(
                f"{self.base_url}{path}",
                json=data,
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                return await response.json()

    async def _handle_error(self, response: aiohttp.ClientResponse) -> None:
        """Translate an HTTP error response into RunPodAPIError.

        Prefers the API's "message"/"error" JSON fields; falls back to the
        raw body text, then to a generic "HTTP <status>" string.
        """
        try:
            error_body = await response.json()
        except Exception:
            message = await response.text() or f"HTTP {response.status}"
        else:
            if isinstance(error_body, dict):
                message = error_body.get(
                    "message", error_body.get("error", str(error_body))
                )
            else:
                # JSON but not an object (e.g. a bare list/string).
                message = str(error_body)
        raise RunPodAPIError(message, response.status)

    async def _graphql_query(
        self, query: str, variables: dict | None = None, operation: str = "graphql"
    ) -> dict:
        """Execute a GraphQL query against RunPod API.

        Note: RunPod GraphQL uses API key as URL parameter, not Bearer token,
        so the key ends up in the request URL (visible to any URL logging).
        """
        url = f"{self.graphql_url}?api_key={self.api_key}"
        payload: dict[str, Any] = {"query": query}
        if variables:
            payload["variables"] = variables

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            ) as response:
                return await response.json()

    def _handle_graphql_errors(
        self, result: dict, operation: str = "graphql"
    ) -> None:
        """Raise RunPodAPIError if the GraphQL response contains errors.

        GraphQL errors arrive with HTTP 200, so a synthetic 400 is used.
        """
        errors = result.get("errors")
        if errors:  # truthy check also guards against an empty errors list
            first = errors[0]
            if isinstance(first, dict):
                error_msg = first.get("message", "GraphQL error")
            else:
                error_msg = str(first)
            raise RunPodAPIError(error_msg, 400)

    # =========================================================================
    # User Info / Connection Test
    # =========================================================================

    async def get_user_info(self) -> dict[str, Any]:
        """Get user account info including credit balance.

        Raises:
            RunPodAPIError: On GraphQL errors or when the key is rejected
                (missing "myself" data -> 401).
        """
        query = """
        query {
            myself {
                id
                clientBalance
                currentSpendPerHr
                spendLimit
            }
        }
        """
        result = await self._graphql_query(query, operation="get_user_info")
        self._handle_graphql_errors(result, "get_user_info")

        if not result.get("data") or not result["data"].get("myself"):
            raise RunPodAPIError("Invalid API key or unauthorized", 401)

        return result["data"]["myself"]

    async def test_connection(self) -> dict[str, Any]:
        """Test API key validity and return account info.

        Returns:
            {"success": True, "credit_balance": float} on success
            {"success": False, "error": "message"} on failure
        """
        try:
            user_info = await self.get_user_info()
            return {
                "success": True,
                "credit_balance": user_info.get("clientBalance", 0),
            }
        except RunPodAPIError as e:
            return {"success": False, "error": e.message}
        except Exception as e:
            # Network-level failures (DNS, timeout, ...) also report cleanly.
            return {"success": False, "error": str(e)}

    # =========================================================================
    # Pod Operations
    # =========================================================================

    async def list_pods(
        self,
        desired_status: str | None = None,
        gpu_type_id: str | None = None,
        include_machine: bool = False,
    ) -> list[dict]:
        """List all pods, optionally filtered by status and GPU type."""
        params = {}
        if desired_status:
            params["desiredStatus"] = desired_status
        if gpu_type_id:
            params["gpuTypeId"] = gpu_type_id
        if include_machine:
            params["includeMachine"] = "true"

        return await self._get("/pods", params=params or None, operation="list_pods")

    async def get_pod(self, pod_id: str, include_machine: bool = False) -> dict:
        """Get pod by ID."""
        params = {}
        if include_machine:
            params["includeMachine"] = "true"
        return await self._get(
            f"/pods/{pod_id}", params=params or None, operation="get_pod"
        )

    async def create_pod(
        self,
        name: str,
        image_name: str,
        gpu_type_id: str,
        gpu_count: int = 1,
        volume_in_gb: int = 20,
        container_disk_in_gb: int = 50,
        cloud_type: str = "SECURE",
        ports: list[str] | None = None,
        env: dict[str, str] | None = None,
        docker_start_cmd: list[str] | None = None,
        network_volume_id: str | None = None,
        data_center_ids: list[str] | None = None,
        interruptible: bool = False,
    ) -> dict:
        """Create a new pod.

        Args:
            name: Display name for the pod.
            image_name: Docker image to run.
            gpu_type_id: GPU type id (sent as a one-element gpuTypeIds list).
            gpu_count: Number of GPUs to attach.
            volume_in_gb: Persistent volume size.
            container_disk_in_gb: Container disk size.
            cloud_type: "SECURE" or "COMMUNITY".
            ports: Optional port specs (e.g. "8188/http").
            env: Optional environment variables.
            docker_start_cmd: Optional container start command.
            network_volume_id: Optional network volume to attach.
            data_center_ids: Optional data center restriction.
            interruptible: Whether a spot/interruptible pod is acceptable.
        """
        data = {
            "name": name,
            "imageName": image_name,
            "gpuTypeIds": [gpu_type_id],
            "gpuCount": gpu_count,
            "volumeInGb": volume_in_gb,
            "containerDiskInGb": container_disk_in_gb,
            "cloudType": cloud_type,
            "interruptible": interruptible,
        }

        if ports:
            data["ports"] = ports
        if env:
            data["env"] = env
        if docker_start_cmd:
            data["dockerStartCmd"] = docker_start_cmd
        if network_volume_id:
            data["networkVolumeId"] = network_volume_id
        if data_center_ids:
            data["dataCenterIds"] = data_center_ids

        return await self._post("/pods", data=data, operation="create_pod")

    async def delete_pod(self, pod_id: str) -> bool:
        """Delete a pod. Returns True on success (errors raise)."""
        await self._delete(f"/pods/{pod_id}", operation="delete_pod")
        return True

    async def start_pod(self, pod_id: str) -> dict:
        """Start a stopped pod using GraphQL podResume mutation."""
        # json.dumps produces a valid GraphQL string literal, so quotes or
        # backslashes in pod_id cannot break out of the query (injection-safe).
        query = f"""
        mutation {{
            podResume(input: {{ podId: {json.dumps(pod_id)} }}) {{
                id
                desiredStatus
                costPerHr
            }}
        }}
        """
        result = await self._graphql_query(query, operation="start_pod")
        self._handle_graphql_errors(result, "start_pod")
        return result["data"]["podResume"]

    async def stop_pod(self, pod_id: str) -> dict:
        """Stop a running pod using GraphQL podStop mutation."""
        # Safe string-literal quoting — see start_pod.
        query = f"""
        mutation {{
            podStop(input: {{ podId: {json.dumps(pod_id)} }}) {{
                id
                desiredStatus
            }}
        }}
        """
        result = await self._graphql_query(query, operation="stop_pod")
        self._handle_graphql_errors(result, "stop_pod")
        return result["data"]["podStop"]

    async def restart_pod(self, pod_id: str) -> bool:
        """Restart a pod. Returns True on success (errors raise)."""
        await self._post(f"/pods/{pod_id}/restart", operation="restart_pod")
        return True

    # =========================================================================
    # Network Volume Operations
    # =========================================================================

    async def list_network_volumes(self) -> list[dict]:
        """List network volumes."""
        return await self._get("/networkvolumes", operation="list_network_volumes")

    async def get_network_volume(self, volume_id: str) -> dict:
        """Get network volume by ID."""
        return await self._get(
            f"/networkvolumes/{volume_id}", operation="get_network_volume"
        )

    async def create_network_volume(
        self, name: str, size_gb: int, data_center_id: str
    ) -> dict:
        """Create a network volume in the given data center."""
        data = {
            "name": name,
            "size": size_gb,
            "dataCenterId": data_center_id,
        }
        return await self._post(
            "/networkvolumes", data=data, operation="create_network_volume"
        )

    async def delete_network_volume(self, volume_id: str) -> bool:
        """Delete a network volume. Returns True on success (errors raise)."""
        await self._delete(
            f"/networkvolumes/{volume_id}", operation="delete_network_volume"
        )
        return True

    # =========================================================================
    # GPU Types and Data Centers
    # =========================================================================

    async def get_gpu_types_with_pricing(
        self, data_center_id: str | None = None
    ) -> list[dict]:
        """Get GPU types with pricing and availability.

        Args:
            data_center_id: Optional data center to scope the lowestPrice
                lookup to; otherwise prices are global.
        """
        if data_center_id:
            # json.dumps keeps the id a proper GraphQL string literal.
            lowest_price_input = (
                f"input: {{ gpuCount: 1, dataCenterId: {json.dumps(data_center_id)} }}"
            )
        else:
            lowest_price_input = "input: { gpuCount: 1 }"

        query = f"""
        query {{
            gpuTypes {{
                id
                displayName
                memoryInGb
                secureCloud
                communityCloud
                securePrice
                communityPrice
                secureSpotPrice
                communitySpotPrice
                lowestPrice({lowest_price_input}) {{
                    minimumBidPrice
                    uninterruptablePrice
                    stockStatus
                }}
            }}
        }}
        """
        result = await self._graphql_query(
            query, operation="get_gpu_types_with_pricing"
        )
        self._handle_graphql_errors(result, "get_gpu_types_with_pricing")
        return result["data"]["gpuTypes"]

    async def get_data_centers(self) -> list[dict]:
        """Get available data centers (uses static fallback if API fails).

        The fallback returns per-entry copies so callers may mutate the
        result without corrupting the module-level DATA_CENTERS constant.
        """
        try:
            query = """
            query {
                myself {
                    datacenters {
                        id
                        name
                        location
                        storageSupport
                        listed
                        region
                    }
                }
            }
            """
            result = await self._graphql_query(query, operation="get_data_centers")
            self._handle_graphql_errors(result, "get_data_centers")
            raw_dcs = result["data"]["myself"]["datacenters"]
            return [
                {
                    "id": dc.get("id"),
                    "name": dc.get("name") or dc.get("location", dc.get("id")),
                    "location": dc.get("location"),
                    "region": dc.get("region"),
                    "available": dc.get("listed", True)
                    and dc.get("storageSupport", True),
                }
                for dc in raw_dcs
                if dc.get("listed", True)
            ]
        except Exception:
            # Deliberate best-effort: any API/shape failure falls back to
            # the static list. Copy each dict so callers can't mutate it.
            return [dc.copy() for dc in DATA_CENTERS]

    # =========================================================================
    # Static Helper Methods
    # =========================================================================

    @staticmethod
    def get_comfyui_url(pod: dict, port: int = 8188) -> str | None:
        """Get ComfyUI proxy URL for a running pod, or None if not running."""
        if pod.get("desiredStatus") != "RUNNING":
            return None

        pod_id = pod.get("id")
        if not pod_id:
            return None

        return f"https://{pod_id}-{port}.proxy.runpod.net"

    @staticmethod
    def get_ssh_command(pod: dict) -> str | None:
        """Get SSH command for connecting to a pod, or None if unavailable."""
        public_ip = pod.get("publicIp")
        port_mappings = pod.get("portMappings", {})

        if not public_ip or not port_mappings:
            return None

        ssh_port = port_mappings.get("22")
        if not ssh_port:
            return None

        return f"ssh root@{public_ip} -p {ssh_port}"

    @staticmethod
    def _estimate_gpu_memory(gpu_id: str) -> int:
        """Estimate GPU memory in GB based on the model name.

        Unknown models fall back to 24 GB.
        """
        memory_map = {
            "RTX 4090": 24,
            "RTX 5090": 32,
            "RTX 5080": 16,
            "RTX 4080": 16,
            "RTX 3090": 24,
            "RTX 3080": 10,
            "RTX 3070": 8,
            "RTX 4070": 12,
            "A100 80GB": 80,
            "A100-SXM4-80GB": 80,
            "H100 80GB": 80,
            "H100 PCIe": 80,
            "H100 NVL": 94,
            "H200": 141,
            "B200": 192,
            "A40": 48,
            "A30": 24,
            "L40S": 48,
            "L40": 48,
            "L4": 24,
            "RTX A6000": 48,
            "RTX A5000": 24,
            "RTX A4500": 20,
            "RTX A4000": 16,
            "RTX 6000 Ada": 48,
            "RTX 5000 Ada": 32,
            "RTX 4000 Ada": 20,
            "V100": 16,
            "V100-SXM2-32GB": 32,
            "MI300X": 192,
        }
        # Match longest key first so specific models ("V100-SXM2-32GB")
        # win over generic substrings ("V100") regardless of dict order.
        for key in sorted(memory_map, key=len, reverse=True):
            if key in gpu_id:
                return memory_map[key]
        return 24  # Default for unrecognized models
@@ -0,0 +1 @@
1
+ """Startup script generation for deployments."""