more-compute 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,316 @@
1
+ from pydantic import BaseModel
2
+ from datetime import datetime
3
+ import httpx
4
+ from fastapi import HTTPException
5
+
6
class EnvVar(BaseModel):
    """A single environment variable expressed as a key/value pair.

    Referenced by ``PodConfig.envVars``.
    """

    key: str
    value: str
9
+
10
class PodConfig(BaseModel):
    """Pod settings carried in the ``pod`` field of ``CreatePodRequest``.

    Field names are camelCase and are serialized as-is by ``model_dump``.
    """

    # Fields the caller must always provide (no default).
    name: str
    cloudId: str
    gpuType: str
    socket: str

    # Has a default, so it may be omitted; one GPU is requested in that case.
    gpuCount: int = 1

    # Optional settings. They default to None; ``create_pod`` dumps the request
    # with ``exclude_none=True``, so unset values are left out of the payload.
    diskSize: int | None = None
    vcpus: int | None = None
    memory: int | None = None
    maxPrice: float | None = None
    image: str | None = None
    customTemplateId: str | None = None
    dataCenterId: str | None = None
    country: str | None = None
    security: str | None = None
    envVars: list[EnvVar] | None = None
    jupyterPassword: str | None = None
    autoRestart: bool | None = None
31
+
32
+
33
class ProviderConfig(BaseModel):
    """Provider section of a pod creation request."""

    # Provider identifier; "runpod" is used when the caller does not set one.
    type: str = "runpod"
35
+
36
+
37
class TeamConfig(BaseModel):
    """Team section of a pod creation request."""

    # Optional team identifier; None when no team is specified.
    teamId: str | None = None
39
+
40
+
41
class CreatePodRequest(BaseModel):
    """Request body accepted by ``PrimeIntellectService.create_pod``.

    ``pod`` and ``provider`` are mandatory sections; ``team`` may be omitted.
    """

    pod: PodConfig
    provider: ProviderConfig
    team: TeamConfig | None = None
45
+
46
+
47
class PodResponse(BaseModel):
    """Pod record validated from the API's pod endpoints.

    Returned by ``PrimeIntellectService.create_pod`` and ``get_pod``.
    """

    id: str
    userId: str
    # Nullable but declared without a default: under pydantic v2 the key has to
    # be present in the validated data even when its value is null.
    teamId: str | None
    name: str
    status: str
    gpuName: str
    gpuCount: int
    priceHr: float
    # Same as teamId: nullable, no default.
    sshConnection: str | None
    ip: str | None
    createdAt: datetime
    updatedAt: datetime
60
+
61
+
62
class AvailabilityQuery(BaseModel):
    """Optional filters for a GPU availability lookup.

    The fields carry the same names as the keyword arguments of
    ``PrimeIntellectService.get_gpu_availability``; every filter defaults to
    None (not applied).
    """

    regions: list[str] | None = None
    gpu_count: int | None = None
    gpu_type: str | None = None
    security: str | None = None
67
+
68
+
69
class PrimeIntellectService:
    """Async client for the Prime Intellect REST API.

    Wraps the availability, pod and SSH-key endpoints. Every public method
    delegates to ``_make_request``, which converts HTTP and transport failures
    into FastAPI ``HTTPException`` instances.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.primeintellect.ai/api/v1"):
        """Store the credentials and precompute the request headers.

        Args:
            api_key: API key sent as a bearer token on every request.
            base_url: Prefix prepended to every endpoint path.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _decode_body(response) -> dict[str, object]:
        """Return the parsed JSON body of ``response``, or ``{}`` when it is empty.

        Raises:
            HTTPException: 502 when a non-empty body is not valid JSON.
        """
        # A successful response may legitimately carry no body (e.g. 204 No
        # Content); parsing it would raise, so report it as an empty result.
        if response.status_code == 204 or not response.content:
            return {}
        try:
            return response.json()
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            raise HTTPException(
                status_code=502,
                detail="Prime Intellect API returned a non-JSON response",
            ) from e

    @staticmethod
    def _redact_payload(payload: dict[str, object]) -> dict[str, object]:
        """Return a copy of a pod-creation payload that is safe to log.

        Masks ``pod.jupyterPassword`` and the ``value`` of every entry in
        ``pod.envVars``. The input payload is never modified.
        """
        pod = payload.get("pod")
        if not isinstance(pod, dict):
            return payload

        safe_pod = dict(pod)
        if "jupyterPassword" in safe_pod:
            safe_pod["jupyterPassword"] = "***"
        env_vars = safe_pod.get("envVars")
        if isinstance(env_vars, list):
            safe_pod["envVars"] = [
                {**item, "value": "***"} if isinstance(item, dict) else item
                for item in env_vars
            ]
        return {**payload, "pod": safe_pod}

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        params: dict[str, str | int | float | list[str]] | None = None,
        json_data: dict[str, object] | None = None,
    ) -> dict[str, object]:
        """Send one HTTP request to the API and return the decoded JSON body.

        Args:
            method: HTTP verb, e.g. "GET" or "POST".
            endpoint: Path appended verbatim to ``base_url``.
            params: Optional query-string parameters.
            json_data: Optional JSON request body.

        Returns:
            The parsed JSON body, or an empty dict when the response has no body.

        Raises:
            HTTPException: with the upstream status code when the API answers
                with an error status; 503 when the request could not be
                completed; 502 when a non-empty body is not valid JSON.
        """
        url = f"{self.base_url}{endpoint}"

        async with httpx.AsyncClient() as client:
            # Only the network call and the status check sit inside the try so
            # that each handler maps to exactly one failure mode.
            try:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self.headers,
                    params=params,
                    json=json_data,
                    timeout=30.0,
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=e.response.status_code,
                    detail=f"Prime Intellect API error: {e.response.text}",
                ) from e
            except httpx.RequestError as e:
                raise HTTPException(
                    status_code=503,
                    detail=f"Connection error: {str(e)}",
                ) from e

            return self._decode_body(response)

    async def get_gpu_availability(
        self,
        regions: list[str] | None = None,
        gpu_count: int | None = None,
        gpu_type: str | None = None,
        security: str | None = None,
    ) -> dict[str, object]:
        """Get available GPU resources with pricing and specifications.

        Only filters that are not None are sent as query parameters.

        Args:
            regions: Regions to restrict the search to.
            gpu_count: Number of GPUs to filter by.
            gpu_type: GPU model, e.g. H100 or A100.
            security: Secure cloud or community cloud.

        Returns:
            Dict containing the available GPUs for the given filters.
        """
        params: dict[str, str | int | float | list[str]] = {}

        if regions is not None:
            params["regions"] = regions
        if gpu_count is not None:
            params["gpu_count"] = gpu_count
        if gpu_type is not None:
            params["gpu_type"] = gpu_type
        if security is not None:
            params["security"] = security
        return await self._make_request("GET", "/availability/", params=params)

    async def get_cluster(self) -> dict[str, object]:
        """Get available multi-node cluster configurations."""
        return await self._make_request("GET", "/cluster-availability")

    async def create_pod(self, pod_request: CreatePodRequest) -> PodResponse:
        """Create a new pod.

        Args:
            pod_request: Pod, provider and optional team configuration. Fields
                that are None are left out of the request body.

        Returns:
            The API response validated as a ``PodResponse``.
        """
        import sys

        payload = pod_request.model_dump(exclude_none=True)
        # The payload can contain a Jupyter password and env-var values, so only
        # a redacted copy is written to stderr; the request itself is unchanged.
        print(
            f"[PI SERVICE] Creating pod with payload: {self._redact_payload(payload)}",
            file=sys.stderr,
            flush=True,
        )

        response = await self._make_request(
            "POST",
            "/pods/",
            json_data=payload,
        )
        return PodResponse.model_validate(response)

    async def get_pods(
        self,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> dict[str, object]:
        """List pods, optionally filtered by status.

        Args:
            status: Filter by status (running, stopped, etc.); omitted from the
                query when None.
            limit: Maximum number of results.
            offset: Pagination offset.

        Returns:
            Dict with the list of pods.
        """
        params: dict[str, str | int | float | list[str]] = {"limit": limit, "offset": offset}
        if status is not None:
            params["status"] = status

        return await self._make_request("GET", "/pods/", params=params)

    async def get_pod(self, pod_id: str) -> PodResponse:
        """Look up a single pod by its identifier.

        Args:
            pod_id: The pod identifier.

        Returns:
            The API response validated as a ``PodResponse``.
        """
        response = await self._make_request("GET", f"/pods/{pod_id}")
        return PodResponse.model_validate(response)

    async def get_pod_status(self, pod_ids: list[str]) -> dict[str, object]:
        """Look up the status of pods by their identifiers.

        Args:
            pod_ids: Specific pod IDs to check. When the list is empty, no
                ``pod_ids`` filter is sent.

        Returns:
            Dict with status information for the requested pods.
        """
        params: dict[str, str | int | float | list[str]] = {}
        if pod_ids:
            params["pod_ids"] = pod_ids

        return await self._make_request("GET", "/pods/status", params=params)

    async def get_pods_history(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> dict[str, object]:
        """Get historical data for terminated pods.

        Args:
            limit: Maximum number of results.
            offset: Pagination offset.

        Returns:
            Dict with historical pod data.
        """
        params: dict[str, str | int | float | list[str]] = {"limit": limit, "offset": offset}
        return await self._make_request("GET", "/pods/history", params=params)

    async def delete_pod(self, pod_id: str) -> dict[str, object]:
        """Delete a pod.

        Args:
            pod_id: The pod identifier.

        Returns:
            Dict with the deletion confirmation (empty when the API sends no body).
        """
        return await self._make_request("DELETE", f"/pods/{pod_id}")

    async def get_pod_logs(self, pod_id: str) -> dict[str, object]:
        """Retrieve logs for a specific pod.

        Args:
            pod_id: The pod identifier.

        Returns:
            Dict containing the pod logs.
        """
        return await self._make_request("GET", f"/pods/{pod_id}/logs")

    async def add_metrics(self, pod_id: str, metrics: dict[str, object]) -> dict[str, object]:
        """Add custom metrics for a pod.

        Args:
            pod_id: The pod identifier.
            metrics: Metric data, sent as the JSON request body.

        Returns:
            Dict with the confirmation.
        """
        return await self._make_request("POST", f"/pods/{pod_id}/metrics", json_data=metrics)

    async def get_ssh_keys(self) -> dict[str, object]:
        """Get the list of all SSH keys."""
        return await self._make_request("GET", "/ssh-keys/")

    async def upload_ssh_key(self, name: str, public_key: str) -> dict[str, object]:
        """Upload a new SSH public key.

        Args:
            name: Name for the SSH key.
            public_key: The SSH public key content (sent as ``publicKey``).

        Returns:
            Dict with the key information.
        """
        data: dict[str, object] = {"name": name, "publicKey": public_key}
        return await self._make_request("POST", "/ssh-keys/", json_data=data)

    async def delete_ssh_key(self, key_id: str) -> dict[str, object]:
        """Delete an SSH key.

        Args:
            key_id: The SSH key identifier.

        Returns:
            Dict with the deletion confirmation (empty when the API sends no body).
        """
        return await self._make_request("DELETE", f"/ssh-keys/{key_id}")

    async def set_primary_ssh_key(self, key_id: str) -> dict[str, object]:
        """Set an SSH key as primary.

        Args:
            key_id: The SSH key identifier.

        Returns:
            Dict with the confirmation (empty when the API sends no body).
        """
        return await self._make_request("PATCH", f"/ssh-keys/{key_id}/primary")