matrice-compute 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,30 @@
1
- """Module providing scaling functionality."""
1
+
2
2
 
3
3
  import os
4
4
  import logging
5
+ import json
6
+ import psutil
5
7
  from matrice_common.utils import log_errors
8
+ from kafka import KafkaProducer, KafkaConsumer
9
+ import uuid
10
+ import time
11
+ import base64
12
+ import threading
13
+ import platform
14
+ import subprocess
15
+
6
16
 
7
17
  class Scaling:
8
18
 
9
19
  """Class providing scaling functionality for compute instances."""
10
20
 
11
- def __init__(self, session, instance_id=None, enable_kafka=False):
21
+ def __init__(self, session, instance_id=None, enable_kafka=True):
12
22
  """Initialize Scaling instance.
13
23
 
14
24
  Args:
15
25
  session: Session object for making RPC calls
16
26
  instance_id: ID of the compute instance
17
- enable_kafka: Deprecated parameter, kept for backward compatibility (ignored)
27
+ enable_kafka: Enable Kafka communication (default True)
18
28
 
19
29
  Raises:
20
30
  Exception: If instance_id is not provided
@@ -29,11 +39,91 @@ class Scaling:
29
39
  used_ports_str = os.environ.get("USED_PORTS", "")
30
40
  self.used_ports = set(int(p) for p in used_ports_str.split(",") if p.strip())
31
41
 
42
+ # Kafka configuration and initialization
43
+ self.enable_kafka = enable_kafka
44
+ self.kafka_producer = None
45
+ self.kafka_consumer = None
46
+ self.kafka_thread = None
47
+ self.kafka_running = False
48
+
49
+ # Maps correlation_id to threading.Event for request/response matching
50
+ self.pending_requests = {}
51
+ # Maps correlation_id to response data
52
+ self.response_map = {}
53
+ self.response_lock = threading.Lock()
54
+
55
+ if self.enable_kafka:
56
+ try:
57
+ self.kafka_config = {
58
+ "bootstrap_servers": self.get_kafka_bootstrap_servers(),
59
+ "action_request_topic": "action_requests",
60
+ "action_response_topic": "action_responses",
61
+ "compute_request_topic": "compute_requests",
62
+ "compute_response_topic": "compute_responses"
63
+ }
64
+
65
+ # Initialize single producer
66
+ self.kafka_producer = KafkaProducer(
67
+ bootstrap_servers=self.kafka_config["bootstrap_servers"],
68
+ value_serializer=lambda v: json.dumps(v).encode("utf-8"),
69
+ max_block_ms=5000 # Timeout if Kafka is down
70
+ )
71
+
72
+ # Initialize single consumer for both response topics
73
+ self.kafka_consumer = KafkaConsumer(
74
+ self.kafka_config["action_response_topic"],
75
+ self.kafka_config["compute_response_topic"],
76
+ bootstrap_servers=self.kafka_config["bootstrap_servers"],
77
+ group_id=f"py_compute_{instance_id}",
78
+ value_deserializer=lambda m: json.loads(m.decode("utf-8")),
79
+ auto_offset_reset='latest',
80
+ enable_auto_commit=True,
81
+ consumer_timeout_ms=1000, # Poll timeout
82
+ session_timeout_ms=60000, # Increase session timeout to 60s (default 30s)
83
+ heartbeat_interval_ms=3000, # Send heartbeat every 3s
84
+ max_poll_interval_ms=300000 # Max time between polls: 5 minutes
85
+ )
86
+
87
+ # Start background thread to handle responses
88
+ self.kafka_running = True
89
+ self.kafka_thread = threading.Thread(target=self._kafka_response_listener, daemon=True)
90
+ self.kafka_thread.start()
91
+
92
+ logging.info(f"Kafka enabled with bootstrap servers: {self.kafka_config['bootstrap_servers']}")
93
+ except Exception as e:
94
+ logging.warning(f"Failed to initialize Kafka, will use REST API only: {e}")
95
+ self.enable_kafka = False
96
+ self.kafka_producer = None
97
+ self.kafka_consumer = None
98
+
32
99
  logging.info(
33
- "Initialized Scaling with instance_id: %s (REST API only)",
34
- instance_id
100
+ "Initialized Scaling with instance_id: %s, Kafka enabled: %s",
101
+ instance_id,
102
+ self.enable_kafka
35
103
  )
36
104
 
105
+ @log_errors(default_return=None, log_error=True)
106
+ def get_kafka_bootstrap_servers(self):
107
+ """Get Kafka bootstrap servers from API and decode base64 fields.
108
+
109
+ Returns:
110
+ str: Kafka bootstrap servers in format "ip:port"
111
+
112
+ Raises:
113
+ ValueError: If unable to fetch Kafka configuration
114
+ """
115
+ path = "/v1/actions/get_kafka_info"
116
+ response = self.rpc.get(path=path)
117
+ if not response or not response.get("success"):
118
+ raise ValueError(f"Failed to fetch Kafka config: {response.get('message', 'No response')}")
119
+ encoded_ip = response["data"]["ip"]
120
+ encoded_port = response["data"]["port"]
121
+ ip = base64.b64decode(encoded_ip).decode("utf-8")
122
+ port = base64.b64decode(encoded_port).decode("utf-8")
123
+ bootstrap_servers = f"{ip}:{port}"
124
+ # logging.info(f"Retrieved Kafka bootstrap servers: {bootstrap_servers}")
125
+ return bootstrap_servers
126
+
37
127
  @log_errors(default_return=(None, "Error processing response", "Response processing failed"), log_error=True)
38
128
  def handle_response(self, resp, success_message, error_message):
39
129
  """Helper function to handle API response.
@@ -58,52 +148,266 @@ class Scaling:
58
148
  logging.error("%s: %s", message, error)
59
149
  return data, error, message
60
150
 
151
+ def _kafka_response_listener(self):
152
+ """
153
+ Background thread that continuously polls for Kafka responses.
154
+
155
+ This thread runs in the background and listens for responses from both
156
+ action_responses and compute_responses topics. When a response is received,
157
+ it matches the correlation ID to pending requests and wakes up the waiting thread.
158
+ """
159
+ logging.info("Kafka response listener thread started")
160
+
161
+ while self.kafka_running:
162
+ try:
163
+ # Poll for messages with 1 second timeout
164
+ message_batch = self.kafka_consumer.poll(timeout_ms=1000)
165
+
166
+ if message_batch:
167
+ for topic_partition, messages in message_batch.items():
168
+ for message in messages:
169
+ try:
170
+ msg = message.value
171
+ correlation_id = msg.get("correlationId")
172
+
173
+ if correlation_id:
174
+ with self.response_lock:
175
+ if correlation_id in self.pending_requests:
176
+ # Store response and signal waiting thread
177
+ self.response_map[correlation_id] = msg
178
+ self.pending_requests[correlation_id].set()
179
+ logging.debug(f"Received Kafka response for correlation_id: {correlation_id}")
180
+ else:
181
+ logging.warning(f"Received Kafka message without correlationId: {msg}")
182
+ except Exception as e:
183
+ logging.error(f"Error processing Kafka message: {e}")
184
+
185
+ except Exception as e:
186
+ if self.kafka_running: # Only log if not shutting down
187
+ logging.error(f"Error in Kafka response listener: {e}")
188
+ time.sleep(1) # Avoid tight loop on persistent errors
189
+
190
+ logging.info("Kafka response listener thread stopped")
191
+
192
+ def _send_kafka_request(self, api, payload, request_topic, response_topic, timeout=5):
193
+ """
194
+ Send a request via Kafka and wait for response using the persistent consumer.
195
+
196
+ Args:
197
+ api: API name to call
198
+ payload: Request payload dictionary
199
+ request_topic: Kafka topic to send request to
200
+ response_topic: Kafka topic to receive response from (not used, kept for signature)
201
+ timeout: Timeout in seconds to wait for response
202
+
203
+ Returns:
204
+ Tuple of (data, error, message, kafka_success)
205
+ kafka_success is True if response received, False if timeout/error
206
+ """
207
+ if not self.enable_kafka or not self.kafka_producer:
208
+ return None, "Kafka not enabled", "Kafka not available", False
209
+
210
+ correlation_id = str(uuid.uuid4())
211
+ request_message = {
212
+ "correlationId": correlation_id,
213
+ "api": api,
214
+ "payload": payload,
215
+ }
216
+
217
+ # Create event for this request
218
+ event = threading.Event()
219
+
220
+ with self.response_lock:
221
+ self.pending_requests[correlation_id] = event
222
+
223
+ try:
224
+ # Add auth token if available
225
+ headers = None
226
+ if hasattr(self.session.rpc, 'AUTH_TOKEN'):
227
+ self.session.rpc.AUTH_TOKEN.set_bearer_token()
228
+ auth_token = self.session.rpc.AUTH_TOKEN.bearer_token
229
+ auth_token = auth_token.replace("Bearer ", "")
230
+ headers = [("Authorization", bytes(f"{auth_token}", "utf-8"))]
231
+
232
+ # Send request
233
+ self.kafka_producer.send(request_topic, request_message, headers=headers)
234
+ logging.info(f"Sent Kafka request for {api} with correlation_id: {correlation_id}")
235
+
236
+ # Wait for response with timeout
237
+ if event.wait(timeout=timeout):
238
+ # Response received
239
+ with self.response_lock:
240
+ response = self.response_map.pop(correlation_id, None)
241
+ self.pending_requests.pop(correlation_id, None)
242
+
243
+ if response:
244
+ if response.get("status") == "success":
245
+ data = response.get("data")
246
+ logging.info(f"Kafka success for {api}")
247
+ return data, None, f"Fetched via Kafka for {api}", True
248
+ else:
249
+ error = response.get("error", "Unknown error")
250
+ logging.error(f"Kafka error response for {api}: {error}")
251
+ return None, error, f"Kafka error response for {api}", True
252
+ else:
253
+ logging.warning(f"Kafka response received but missing data for {api}")
254
+ return None, "Response missing data", "Kafka response error", False
255
+ else:
256
+ # Timeout
257
+ with self.response_lock:
258
+ self.pending_requests.pop(correlation_id, None)
259
+ logging.warning(f"Kafka response timeout for {api} after {timeout} seconds")
260
+ return None, "Kafka response timeout", "Kafka response timeout", False
261
+
262
+ except Exception as e:
263
+ # Cleanup on error
264
+ with self.response_lock:
265
+ self.pending_requests.pop(correlation_id, None)
266
+ logging.error(f"Kafka send error for {api}: {e}")
267
+ return None, f"Kafka error: {e}", "Kafka send failed", False
268
+
269
+ def _hybrid_request(self, api, payload, request_topic, response_topic, rest_fallback_func):
270
+ """
271
+ Hybrid request method: try Kafka first, fallback to REST, cache if both fail.
272
+
273
+ Args:
274
+ api: API name
275
+ payload: Request payload
276
+ request_topic: Kafka request topic
277
+ response_topic: Kafka response topic
278
+ rest_fallback_func: Function to call for REST fallback (should return same format as handle_response)
279
+
280
+ Returns:
281
+ Tuple of (data, error, message) matching the API response pattern
282
+ """
283
+ # Try Kafka first
284
+ if self.enable_kafka:
285
+ data, error, message, kafka_success = self._send_kafka_request(
286
+ api, payload, request_topic, response_topic, timeout=5
287
+ )
288
+
289
+ if kafka_success and error is None:
290
+ # Kafka succeeded
291
+ return data, error, message
292
+
293
+ # Kafka returned an error response (not transport error)
294
+ if kafka_success and error is not None:
295
+ logging.warning(f"Kafka returned error for {api}, falling back to REST")
296
+
297
+ # Kafka failed or disabled, try REST
298
+ logging.info(f"Using REST API for {api}")
299
+ try:
300
+ rest_response = rest_fallback_func()
301
+
302
+ # Return REST response (success or failure)
303
+ if rest_response and len(rest_response) == 3:
304
+ return rest_response
305
+ else:
306
+ # Unexpected REST response format
307
+ logging.error(f"REST API returned unexpected format for {api}")
308
+ return None, "Unexpected REST response format", "REST API error"
309
+
310
+ except Exception as e:
311
+ # REST failed
312
+ logging.error(f"REST API failed for {api}: {e}")
313
+ return None, str(e), "REST API failed"
314
+
315
+ def shutdown(self):
316
+ """Gracefully shutdown Kafka connections."""
317
+ if self.kafka_running:
318
+ logging.info("Shutting down Kafka connections...")
319
+ self.kafka_running = False
320
+
321
+ if self.kafka_thread:
322
+ self.kafka_thread.join(timeout=5)
323
+
324
+ if self.kafka_consumer:
325
+ self.kafka_consumer.close()
326
+
327
+ if self.kafka_producer:
328
+ self.kafka_producer.close()
329
+
330
+ logging.info("Kafka connections closed")
61
331
 
62
332
  @log_errors(log_error=True)
63
333
  def get_downscaled_ids(self):
64
- """Get IDs of downscaled instances using REST API.
334
+ """Get IDs of downscaled instances using Kafka (with REST fallback).
65
335
 
66
336
  Returns:
67
337
  Tuple of (data, error, message) from API response
68
338
  """
69
339
  logging.info("Getting downscaled ids for instance %s", self.instance_id)
70
- path = f"/v1/compute/down_scaled_ids/{self.instance_id}"
71
- resp = self.rpc.get(path=path)
72
- return self.handle_response(
73
- resp,
74
- "Downscaled ids info fetched successfully",
75
- "Could not fetch the Downscaled ids info",
340
+
341
+ payload = {"instance_id": self.instance_id}
342
+
343
+ def rest_fallback():
344
+ path = f"/v1/compute/down_scaled_ids/{self.instance_id}"
345
+ resp = self.rpc.get(path=path)
346
+ return self.handle_response(
347
+ resp,
348
+ "Downscaled ids info fetched successfully",
349
+ "Could not fetch the Downscaled ids info",
350
+ )
351
+
352
+ return self._hybrid_request(
353
+ api="get_downscaled_ids",
354
+ payload=payload,
355
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
356
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
357
+ rest_fallback_func=rest_fallback
76
358
  )
77
359
 
78
360
  @log_errors(default_return=(None, "API call failed", "Failed to stop instance"), log_error=True)
79
361
  def stop_instance(self):
80
- """Stop the compute instance using REST API.
362
+ """Stop the compute instance using Kafka (with REST fallback).
81
363
 
82
364
  Returns:
83
365
  Tuple of (data, error, message) from API response
84
366
  """
85
367
  logging.info("Stopping instance %s", self.instance_id)
368
+
86
369
  payload = {
87
370
  "_idInstance": self.instance_id,
88
371
  "isForcedStop": False,
89
372
  }
90
- path = "/v1/compute/compute_instance/stop"
91
- resp = self.rpc.put(path=path, payload=payload)
92
- return self.handle_response(
93
- resp,
94
- "Instance stopped successfully",
95
- "Could not stop the instance",
373
+
374
+ def rest_fallback():
375
+ path = "/v1/compute/compute_instance/stop"
376
+ resp = self.rpc.put(path=path, payload=payload)
377
+ return self.handle_response(
378
+ resp,
379
+ "Instance stopped successfully",
380
+ "Could not stop the instance",
381
+ )
382
+
383
+ return self._hybrid_request(
384
+ api="stop_instance",
385
+ payload=payload,
386
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
387
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
388
+ rest_fallback_func=rest_fallback
96
389
  )
97
390
 
98
391
  @log_errors(log_error=True)
99
392
  def update_jupyter_token(self, token=""):
100
- """Update Jupyter notebook token using REST API."""
101
- path = f"/v1/scaling/update_jupyter_notebook_token/{self.instance_id}"
102
- resp = self.rpc.put(path=path, payload={"token": token})
103
- return self.handle_response(
104
- resp,
105
- "Resources updated successfully",
106
- "Could not update the resources",
393
+ """Update Jupyter notebook token using Kafka (with REST fallback)."""
394
+ payload = {"token": token, "instance_id": self.instance_id}
395
+
396
+ def rest_fallback():
397
+ path = f"/v1/compute/update_jupyter_notebook_token/{self.instance_id}"
398
+ resp = self.rpc.put(path=path, payload={"token": token})
399
+ return self.handle_response(
400
+ resp,
401
+ "Resources updated successfully",
402
+ "Could not update the resources",
403
+ )
404
+
405
+ return self._hybrid_request(
406
+ api="update_jupyter_token",
407
+ payload=payload,
408
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
409
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
410
+ rest_fallback_func=rest_fallback
107
411
  )
108
412
 
109
413
  @log_errors(log_error=True)
@@ -122,7 +426,7 @@ class Scaling:
122
426
  createdAt=None,
123
427
  updatedAt=None,
124
428
  ):
125
- """Update status of an action using REST API.
429
+ """Update status of an action using Kafka (with REST fallback).
126
430
 
127
431
  Args:
128
432
  service_provider: Provider of the service
@@ -162,12 +466,21 @@ class Scaling:
162
466
  "updatedAt": updatedAt,
163
467
  }
164
468
 
165
- path = "/v1/compute/update_action_status"
166
- resp = self.rpc.put(path=path, payload=payload)
167
- return self.handle_response(
168
- resp,
169
- "Action status details updated successfully",
170
- "Could not update the action status details ",
469
+ def rest_fallback():
470
+ path = "/v1/compute/update_action_status"
471
+ resp = self.rpc.put(path=path, payload=payload)
472
+ return self.handle_response(
473
+ resp,
474
+ "Action status details updated successfully",
475
+ "Could not update the action status details ",
476
+ )
477
+
478
+ return self._hybrid_request(
479
+ api="update_action_status",
480
+ payload=payload,
481
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
482
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
483
+ rest_fallback_func=rest_fallback
171
484
  )
172
485
 
173
486
  @log_errors(log_error=True)
@@ -180,7 +493,7 @@ class Scaling:
180
493
  status,
181
494
  status_description,
182
495
  ):
183
- """Update status of an action using REST API.
496
+ """Update status of an action using Kafka (with REST fallback).
184
497
 
185
498
  Args:
186
499
  action_record_id: ID of the action record
@@ -201,45 +514,78 @@ class Scaling:
201
514
  "statusDescription": status_description,
202
515
  }
203
516
 
204
- url = "/v1/actions"
205
- self.rpc.put(path=url, payload=payload)
206
- return None, None, "Status updated"
517
+ def rest_fallback():
518
+ url = "/v1/actions"
519
+ self.rpc.put(path=url, payload=payload)
520
+ return None, None, "Status updated"
521
+
522
+ return self._hybrid_request(
523
+ api="update_action",
524
+ payload=payload,
525
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
526
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
527
+ rest_fallback_func=rest_fallback
528
+ )
207
529
 
208
530
  @log_errors(log_error=True)
209
531
  def get_shutdown_details(self):
210
- """Get shutdown details for the instance using REST API.
532
+ """Get shutdown details for the instance using Kafka (with REST fallback).
211
533
 
212
534
  Returns:
213
535
  Tuple of (data, error, message) from API response
214
536
  """
215
537
  logging.info("Getting shutdown details for instance %s", self.instance_id)
216
- path = f"/v1/compute/get_shutdown_details/{self.instance_id}"
217
- resp = self.rpc.get(path=path)
218
- return self.handle_response(
219
- resp,
220
- "Shutdown info fetched successfully",
221
- "Could not fetch the shutdown details",
538
+
539
+ payload = {"instance_id": self.instance_id}
540
+
541
+ def rest_fallback():
542
+ path = f"/v1/compute/get_shutdown_details/{self.instance_id}"
543
+ resp = self.rpc.get(path=path)
544
+ return self.handle_response(
545
+ resp,
546
+ "Shutdown info fetched successfully",
547
+ "Could not fetch the shutdown details",
548
+ )
549
+
550
+ return self._hybrid_request(
551
+ api="get_shutdown_details",
552
+ payload=payload,
553
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
554
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
555
+ rest_fallback_func=rest_fallback
222
556
  )
223
557
 
224
558
  @log_errors(log_error=True)
225
559
  def get_tasks_details(self):
226
- """Get task details for the instance using REST API.
560
+ """Get task details for the instance using Kafka (with REST fallback).
227
561
 
228
562
  Returns:
229
563
  Tuple of (data, error, message) from API response
230
564
  """
231
565
  logging.info("Getting tasks details for instance %s", self.instance_id)
232
- path = f"/v1/actions/fetch_instance_action_details/{self.instance_id}/action_details"
233
- resp = self.rpc.get(path=path)
234
- return self.handle_response(
235
- resp,
236
- "Task details fetched successfully",
237
- "Could not fetch the task details",
566
+
567
+ payload = {"instance_id": self.instance_id}
568
+
569
+ def rest_fallback():
570
+ path = f"/v1/actions/fetch_instance_action_details/{self.instance_id}/action_details"
571
+ resp = self.rpc.get(path=path)
572
+ return self.handle_response(
573
+ resp,
574
+ "Task details fetched successfully",
575
+ "Could not fetch the task details",
576
+ )
577
+
578
+ return self._hybrid_request(
579
+ api="get_tasks_details",
580
+ payload=payload,
581
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
582
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
583
+ rest_fallback_func=rest_fallback
238
584
  )
239
585
 
240
586
  @log_errors(log_error=True)
241
587
  def get_action_details(self, action_status_id):
242
- """Get details for a specific action using REST API.
588
+ """Get details for a specific action using Kafka (with REST fallback).
243
589
 
244
590
  Args:
245
591
  action_status_id: ID of the action status to fetch
@@ -248,12 +594,24 @@ class Scaling:
248
594
  Tuple of (data, error, message) from API response
249
595
  """
250
596
  logging.info("Getting action details for action %s", action_status_id)
251
- path = f"/v1/actions/action/{action_status_id}/details"
252
- resp = self.rpc.get(path=path)
253
- return self.handle_response(
254
- resp,
255
- "Task details fetched successfully",
256
- "Could not fetch the task details",
597
+
598
+ payload = {"actionRecordId": action_status_id}
599
+
600
+ def rest_fallback():
601
+ path = f"/v1/actions/action/{action_status_id}/details"
602
+ resp = self.rpc.get(path=path)
603
+ return self.handle_response(
604
+ resp,
605
+ "Task details fetched successfully",
606
+ "Could not fetch the task details",
607
+ )
608
+
609
+ return self._hybrid_request(
610
+ api="get_action_details",
611
+ payload=payload,
612
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
613
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
614
+ rest_fallback_func=rest_fallback
257
615
  )
258
616
 
259
617
  @log_errors(log_error=True)
@@ -268,7 +626,7 @@ class Scaling:
268
626
  service="",
269
627
  job_params=None,
270
628
  ):
271
- """Update an action using REST API.
629
+ """Update an action using Kafka (with REST fallback).
272
630
 
273
631
  Args:
274
632
  id: Action ID
@@ -299,12 +657,21 @@ class Scaling:
299
657
  "jobParams": job_params,
300
658
  }
301
659
 
302
- path = "/v1/actions"
303
- resp = self.rpc.put(path=path, payload=payload)
304
- return self.handle_response(
305
- resp,
306
- "Error logged successfully",
307
- "Could not log the errors",
660
+ def rest_fallback():
661
+ path = "/v1/actions"
662
+ resp = self.rpc.put(path=path, payload=payload)
663
+ return self.handle_response(
664
+ resp,
665
+ "Error logged successfully",
666
+ "Could not log the errors",
667
+ )
668
+
669
+ return self._hybrid_request(
670
+ api="update_action",
671
+ payload=payload,
672
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
673
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
674
+ rest_fallback_func=rest_fallback
308
675
  )
309
676
 
310
677
 
@@ -342,7 +709,7 @@ class Scaling:
342
709
  availableMemory=0,
343
710
  availableGPUMemory=0,
344
711
  ):
345
- """Update available resources for the instance using REST API.
712
+ """Update available resources for the instance using Kafka (with REST fallback).
346
713
 
347
714
  Args:
348
715
  availableCPU: Available CPU resources
@@ -362,17 +729,28 @@ class Scaling:
362
729
  "availableGPU": availableGPU,
363
730
  }
364
731
 
365
- path = f"/v1/compute/update_available_resources/{self.instance_id}"
366
- resp = self.rpc.put(path=path, payload=payload)
367
- return self.handle_response(
368
- resp,
369
- "Resources updated successfully",
370
- "Could not update the resources",
732
+ # Define REST fallback function
733
+ def rest_fallback():
734
+ path = f"/v1/compute/update_available_resources/{self.instance_id}"
735
+ resp = self.rpc.put(path=path, payload=payload)
736
+ return self.handle_response(
737
+ resp,
738
+ "Resources updated successfully",
739
+ "Could not update the resources",
740
+ )
741
+
742
+ # Use hybrid approach: Kafka first, REST fallback, cache if both fail
743
+ return self._hybrid_request(
744
+ api="update_available_resources",
745
+ payload=payload,
746
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
747
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
748
+ rest_fallback_func=rest_fallback
371
749
  )
372
750
 
373
751
  @log_errors(log_error=True)
374
752
  def update_action_docker_logs(self, action_record_id, log_content):
375
- """Update docker logs for an action using REST API.
753
+ """Update docker logs for an action using Kafka (with REST fallback).
376
754
 
377
755
  Args:
378
756
  action_record_id: ID of the action record
@@ -388,43 +766,100 @@ class Scaling:
388
766
  "logContent": log_content,
389
767
  }
390
768
 
391
- path = "/v1/actions/update_action_docker_logs"
769
+ def rest_fallback():
770
+ path = "/v1/actions/update_action_docker_logs"
771
+ resp = self.rpc.put(path=path, payload=payload)
772
+ return self.handle_response(
773
+ resp,
774
+ "Docker logs updated successfully",
775
+ "Could not update the docker logs",
776
+ )
777
+
778
+ return self._hybrid_request(
779
+ api="update_action_docker_logs",
780
+ payload=payload,
781
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
782
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
783
+ rest_fallback_func=rest_fallback
784
+ )
785
+
786
+ def update_action_container_id(self, action_record_id, container_id):
787
+ """Update container ID for an action using Kafka (with REST fallback).
788
+
789
+ Args:
790
+ action_record_id: ID of the action record
791
+ container_id: Container ID to update
792
+
793
+ Returns:
794
+ Tuple of (data, error, message) from API response
795
+ """
796
+ logging.info("Updating container ID for action %s", action_record_id)
797
+
798
+ payload = {
799
+ "actionRecordId": action_record_id,
800
+ "containerId": container_id,
801
+ }
802
+
803
+ path = "/v1/actions/update_action_container_id"
392
804
  resp = self.rpc.put(path=path, payload=payload)
393
805
  return self.handle_response(
394
- resp,
395
- "Docker logs updated successfully",
396
- "Could not update the docker logs",
806
+ resp,
807
+ "Container ID updated successfully",
808
+ "Could not update the container ID",
397
809
  )
398
810
 
399
811
  @log_errors(log_error=True)
400
812
  def get_docker_hub_credentials(self):
401
- """Get Docker Hub credentials using REST API.
813
+ """Get Docker Hub credentials using Kafka (with REST fallback).
402
814
 
403
815
  Returns:
404
816
  Tuple of (data, error, message) from API response
405
817
  """
406
818
  logging.info("Getting docker credentials")
407
- path = "/v1/compute/get_docker_hub_credentials"
408
- resp = self.rpc.get(path=path)
409
- return self.handle_response(
410
- resp,
411
- "Docker credentials fetched successfully",
412
- "Could not fetch the docker credentials",
819
+
820
+ payload = {}
821
+
822
+ def rest_fallback():
823
+ path = "/v1/compute/get_docker_hub_credentials"
824
+ resp = self.rpc.get(path=path)
825
+ return self.handle_response(
826
+ resp,
827
+ "Docker credentials fetched successfully",
828
+ "Could not fetch the docker credentials",
829
+ )
830
+
831
+ return self._hybrid_request(
832
+ api="get_docker_hub_credentials",
833
+ payload=payload,
834
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
835
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
836
+ rest_fallback_func=rest_fallback
413
837
  )
414
838
 
415
839
  @log_errors(log_error=True)
416
840
  def get_open_ports_config(self):
417
- """Get open ports configuration using REST API.
841
+ """Get open ports configuration using Kafka (with REST fallback).
418
842
 
419
843
  Returns:
420
844
  Tuple of (data, error, message) from API response
421
845
  """
422
- path = f"/v1/compute/get_open_ports/{self.instance_id}"
423
- resp = self.rpc.get(path=path)
424
- return self.handle_response(
425
- resp,
426
- "Open ports config fetched successfully",
427
- "Could not fetch the open ports config",
846
+ payload = {"instance_id": self.instance_id}
847
+
848
+ def rest_fallback():
849
+ path = f"/v1/compute/get_open_ports/{self.instance_id}"
850
+ resp = self.rpc.get(path=path)
851
+ return self.handle_response(
852
+ resp,
853
+ "Open ports config fetched successfully",
854
+ "Could not fetch the open ports config",
855
+ )
856
+
857
+ return self._hybrid_request(
858
+ api="get_open_ports_config",
859
+ payload=payload,
860
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
861
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
862
+ rest_fallback_func=rest_fallback
428
863
  )
429
864
 
430
865
  @log_errors(default_return=None, log_error=True)
@@ -475,7 +910,7 @@ class Scaling:
475
910
 
476
911
  @log_errors(log_error=True)
477
912
  def get_model_secret_keys(self, secret_name):
478
- """Get model secret keys using REST API.
913
+ """Get model secret keys using Kafka (with REST fallback).
479
914
 
480
915
  Args:
481
916
  secret_name: Name of the secret
@@ -483,12 +918,23 @@ class Scaling:
483
918
  Returns:
484
919
  Tuple of (data, error, message) from API response
485
920
  """
486
- path = f"/v1/compute/get_models_secret_keys?secret_name={secret_name}"
487
- resp = self.rpc.get(path=path)
488
- return self.handle_response(
489
- resp,
490
- "Secret keys fetched successfully",
491
- "Could not fetch the secret keys",
921
+ payload = {"secret_name": secret_name}
922
+
923
+ def rest_fallback():
924
+ path = f"/v1/compute/get_models_secret_keys?secret_name={secret_name}"
925
+ resp = self.rpc.get(path=path)
926
+ return self.handle_response(
927
+ resp,
928
+ "Secret keys fetched successfully",
929
+ "Could not fetch the secret keys",
930
+ )
931
+
932
+ return self._hybrid_request(
933
+ api="get_model_secret_keys",
934
+ payload=payload,
935
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
936
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
937
+ rest_fallback_func=rest_fallback
492
938
  )
493
939
 
494
940
  @log_errors(log_error=True)
@@ -589,7 +1035,7 @@ class Scaling:
589
1035
 
590
1036
  @log_errors(log_error=True)
591
1037
  def stop_account_compute(self, account_number, alias):
592
- """Stop a compute instance for an account using REST API.
1038
+ """Stop a compute instance for an account using Kafka (with REST fallback).
593
1039
 
594
1040
  Args:
595
1041
  account_number: Account number
@@ -599,17 +1045,32 @@ class Scaling:
599
1045
  Tuple of (data, error, message) from API response
600
1046
  """
601
1047
  logging.info("Stopping account compute for %s/%s", account_number, alias)
602
- path = f"/v1/compute/stop_account_compute/{account_number}/{alias}"
603
- resp = self.rpc.put(path=path)
604
- return self.handle_response(
605
- resp,
606
- "Compute instance stopped successfully",
607
- "Could not stop the compute instance",
1048
+
1049
+ payload = {
1050
+ "account_number": account_number,
1051
+ "alias": alias,
1052
+ }
1053
+
1054
+ def rest_fallback():
1055
+ path = f"/v1/compute/stop_account_compute/{account_number}/{alias}"
1056
+ resp = self.rpc.put(path=path)
1057
+ return self.handle_response(
1058
+ resp,
1059
+ "Compute instance stopped successfully",
1060
+ "Could not stop the compute instance",
1061
+ )
1062
+
1063
+ return self._hybrid_request(
1064
+ api="stop_account_compute",
1065
+ payload=payload,
1066
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
1067
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
1068
+ rest_fallback_func=rest_fallback
608
1069
  )
609
1070
 
610
1071
  @log_errors(log_error=True)
611
1072
  def restart_account_compute(self, account_number, alias):
612
- """Restart a compute instance for an account using REST API.
1073
+ """Restart a compute instance for an account using Kafka (with REST fallback).
613
1074
 
614
1075
  Args:
615
1076
  account_number: Account number
@@ -619,12 +1080,27 @@ class Scaling:
619
1080
  Tuple of (data, error, message) from API response
620
1081
  """
621
1082
  logging.info("Restarting account compute for %s/%s", account_number, alias)
622
- path = f"/v1/compute/restart_account_compute/{account_number}/{alias}"
623
- resp = self.rpc.put(path=path)
624
- return self.handle_response(
625
- resp,
626
- "Compute instance restarted successfully",
627
- "Could not restart the compute instance",
1083
+
1084
+ payload = {
1085
+ "account_number": account_number,
1086
+ "alias": alias,
1087
+ }
1088
+
1089
+ def rest_fallback():
1090
+ path = f"/v1/compute/restart_account_compute/{account_number}/{alias}"
1091
+ resp = self.rpc.put(path=path)
1092
+ return self.handle_response(
1093
+ resp,
1094
+ "Compute instance restarted successfully",
1095
+ "Could not restart the compute instance",
1096
+ )
1097
+
1098
+ return self._hybrid_request(
1099
+ api="restart_account_compute",
1100
+ payload=payload,
1101
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
1102
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
1103
+ rest_fallback_func=rest_fallback
628
1104
  )
629
1105
 
630
1106
  @log_errors(log_error=True)
@@ -648,37 +1124,59 @@ class Scaling:
648
1124
 
649
1125
  @log_errors(log_error=True)
650
1126
  def get_all_instances_type(self):
651
- """Get all instance types using REST API.
1127
+ """Get all instance types using Kafka (with REST fallback).
652
1128
 
653
1129
  Returns:
654
1130
  Tuple of (data, error, message) from API response
655
1131
  """
656
- path = "/v1/compute/get_all_instances_type"
657
- resp = self.rpc.get(path=path)
658
- return self.handle_response(
659
- resp,
660
- "All instance types fetched successfully",
661
- "Could not fetch the instance types",
1132
+ payload = {}
1133
+
1134
+ def rest_fallback():
1135
+ path = "/v1/compute/get_all_instances_type"
1136
+ resp = self.rpc.get(path=path)
1137
+ return self.handle_response(
1138
+ resp,
1139
+ "All instance types fetched successfully",
1140
+ "Could not fetch the instance types",
1141
+ )
1142
+
1143
+ return self._hybrid_request(
1144
+ api="get_all_instances_type",
1145
+ payload=payload,
1146
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
1147
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
1148
+ rest_fallback_func=rest_fallback
662
1149
  )
663
1150
 
664
1151
  @log_errors(log_error=True)
665
1152
  def get_compute_details(self):
666
- """Get compute instance details using REST API.
1153
+ """Get compute instance details using Kafka (with REST fallback).
667
1154
 
668
1155
  Returns:
669
1156
  Tuple of (data, error, message) from API response
670
1157
  """
671
- path = f"/v1/scaling/get_compute_details/{self.instance_id}"
672
- resp = self.rpc.get(path=path)
673
- return self.handle_response(
674
- resp,
675
- "Compute details fetched successfully",
676
- "Could not fetch the compute details",
677
- )
1158
+ payload = {"instance_id": self.instance_id}
1159
+
1160
+ def rest_fallback():
1161
+ path = f"/v1/compute/get_compute_details/{self.instance_id}"
1162
+ resp = self.rpc.get(path=path)
1163
+ return self.handle_response(
1164
+ resp,
1165
+ "Compute details fetched successfully",
1166
+ "Could not fetch the compute details",
1167
+ )
678
1168
 
1169
+ return self._hybrid_request(
1170
+ api="get_compute_details",
1171
+ payload=payload,
1172
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
1173
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
1174
+ rest_fallback_func=rest_fallback
1175
+ )
1176
+
679
1177
  @log_errors(log_error=True)
680
1178
  def get_user_access_key_pair(self, user_id):
681
- """Get user access key pair using REST API.
1179
+ """Get user access key pair using Kafka (with REST fallback).
682
1180
 
683
1181
  Args:
684
1182
  user_id: ID of the user
@@ -686,17 +1184,190 @@ class Scaling:
686
1184
  Returns:
687
1185
  Tuple of (data, error, message) from API response
688
1186
  """
689
- path = f"/v1/compute/get_user_access_key_pair/{user_id}/{self.instance_id}"
690
- resp = self.rpc.get(path=path)
1187
+ payload = {"user_id": user_id, "instance_id": self.instance_id}
1188
+
1189
+ def rest_fallback():
1190
+ path = f"/v1/compute/get_user_access_key_pair/{user_id}/{self.instance_id}"
1191
+ resp = self.rpc.get(path=path)
1192
+ return self.handle_response(
1193
+ resp,
1194
+ "User access key pair fetched successfully",
1195
+ "Could not fetch the user access key pair",
1196
+ )
1197
+
1198
+ return self._hybrid_request(
1199
+ api="get_user_access_key_pair",
1200
+ payload=payload,
1201
+ request_topic=self.kafka_config["compute_request_topic"] if self.enable_kafka else None,
1202
+ response_topic=self.kafka_config["compute_response_topic"] if self.enable_kafka else None,
1203
+ rest_fallback_func=rest_fallback
1204
+ )
1205
+
1206
+
1207
+
1208
+ def report_architecture_info(self):
1209
+ """Collects and sends architecture info to the compute service."""
1210
+ cpu_arch = platform.machine()
1211
+ cpu_name = None
1212
+ total_memory_gb = None
1213
+ gpu_provider = None
1214
+ gpu_arch = None
1215
+ cuda_version = None
1216
+ is_jetson = False
1217
+ gpu_arch_family = None
1218
+ gpu_compute_cap = None
1219
+
1220
+ if cpu_arch== "x86_64":
1221
+ cpu_arch = "x86"
1222
+ elif cpu_arch == "aarch64":
1223
+ cpu_arch = "arm64"
1224
+
1225
+ # Get CPU name
1226
+ try:
1227
+ cpu_info = subprocess.run(["lscpu"], capture_output=True, text=True)
1228
+ if cpu_info.returncode == 0:
1229
+ for line in cpu_info.stdout.splitlines():
1230
+ if "Model name:" in line:
1231
+ cpu_name = line.split("Model name:")[-1].strip()
1232
+ break
1233
+ # Fallback for systems without lscpu
1234
+ if not cpu_name:
1235
+ try:
1236
+ with open("/proc/cpuinfo", "r") as f:
1237
+ for line in f:
1238
+ if "model name" in line:
1239
+ cpu_name = line.split(":")[-1].strip()
1240
+ break
1241
+ except Exception:
1242
+ pass
1243
+ except Exception:
1244
+ pass
1245
+
1246
+ # Get total memory in GB
1247
+ try:
1248
+ total_memory_bytes = psutil.virtual_memory().total
1249
+ total_memory_gb = round(total_memory_bytes / (1024 ** 3), 2)
1250
+ except Exception:
1251
+ try:
1252
+ # Fallback using /proc/meminfo
1253
+ with open("/proc/meminfo", "r") as f:
1254
+ for line in f:
1255
+ if "MemTotal:" in line:
1256
+ mem_kb = int(line.split()[1])
1257
+ total_memory_gb = round(mem_kb / (1024 ** 2), 2)
1258
+ break
1259
+ except Exception:
1260
+ pass
1261
+
1262
+ # Jetson detection first (avoid nvidia-smi on Jetson)
1263
+ try:
1264
+ with open("/proc/device-tree/model") as f:
1265
+ model = f.read().lower()
1266
+ if "jetson" in model or "tegra" in model:
1267
+ is_jetson = True
1268
+ gpu_provider = "NVIDIA"
1269
+ try:
1270
+ cuda_result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
1271
+ if cuda_result.returncode == 0:
1272
+ for line in cuda_result.stdout.splitlines():
1273
+ if "release" in line:
1274
+ cuda_version = line.split("release")[-1].split(",")[0].strip()
1275
+ break
1276
+ except Exception:
1277
+ pass
1278
+ except Exception:
1279
+ pass
1280
+
1281
+ # If not Jetson, try NVIDIA (nvidia-smi)
1282
+ if not is_jetson:
1283
+ try:
1284
+ result = subprocess.run(["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"], capture_output=True, text=True)
1285
+ if result.returncode == 0:
1286
+ gpu_provider = "NVIDIA"
1287
+ gpu_info = result.stdout.strip().split("\n")[0].split(",")
1288
+ gpu_arch = gpu_info[0].strip() if len(gpu_info) > 0 else None
1289
+ gpu_compute_cap = gpu_info[1].strip() if len(gpu_info) > 1 else None
1290
+ # Map compute capability to arch family
1291
+ if gpu_compute_cap:
1292
+ major = int(gpu_compute_cap.split(".")[0])
1293
+ if major == 5:
1294
+ gpu_arch_family = "Maxwell"
1295
+ elif major == 6:
1296
+ gpu_arch_family = "Pascal"
1297
+ elif major == 7:
1298
+ gpu_arch_family = "Volta"
1299
+ elif major == 8:
1300
+ gpu_arch_family = "Ampere"
1301
+ elif major == 9:
1302
+ gpu_arch_family = "Hopper"
1303
+ elif major == 10:
1304
+ gpu_arch_family = "Blackwell"
1305
+ else:
1306
+ gpu_arch_family = "Unknown"
1307
+ # Get CUDA version
1308
+ cuda_result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
1309
+ if cuda_result.returncode == 0:
1310
+ for line in cuda_result.stdout.splitlines():
1311
+ if "release" in line:
1312
+ cuda_version = line.split("release")[-1].split(",")[0].strip()
1313
+ break
1314
+ except FileNotFoundError:
1315
+ pass
1316
+
1317
+ # Try AMD if NVIDIA not found
1318
+ if gpu_provider is None:
1319
+ try:
1320
+ result = subprocess.run(["lspci"], capture_output=True, text=True)
1321
+ if result.returncode == 0:
1322
+ for line in result.stdout.splitlines():
1323
+ if "AMD" in line or "Advanced Micro Devices" in line:
1324
+ gpu_provider = "AMD"
1325
+ gpu_arch = line.strip()
1326
+ break
1327
+ except FileNotFoundError:
1328
+ pass
1329
+
1330
+ # Only send if provider is NVIDIA or AMD
1331
+ if gpu_provider in ("NVIDIA", "AMD"):
1332
+ payload = {
1333
+ "instance_id": self.instance_id,
1334
+ "cpu_architecture": cpu_arch,
1335
+ "cpu_name": cpu_name if cpu_name else "Unknown",
1336
+ "total_memory_gb": total_memory_gb if total_memory_gb else 0,
1337
+ "gpu_provider": gpu_provider,
1338
+ "gpu_architecture": gpu_arch_family if gpu_arch_family else "Unknown",
1339
+ "gpu": gpu_arch,
1340
+ "cuda_version": cuda_version if cuda_version else "N/A",
1341
+ "is_jetson": is_jetson
1342
+ }
1343
+ else:
1344
+ payload = {
1345
+ "instance_id": self.instance_id,
1346
+ "cpu_architecture": cpu_arch,
1347
+ "cpu_name": cpu_name if cpu_name else "Unknown",
1348
+ "total_memory_gb": total_memory_gb if total_memory_gb else 0,
1349
+ "gpu_provider": "None",
1350
+ "gpu_architecture": "None",
1351
+ "gpu": "None",
1352
+ "cuda_version": "N/A",
1353
+ "is_jetson": False
1354
+ }
1355
+
1356
+ #report for a simple cpu only instance
1357
+
1358
+ path = "/v1/compute/report_architecture_info"
1359
+ resp = self.rpc.post(path=path, payload=payload)
691
1360
  return self.handle_response(
692
1361
  resp,
693
- "User access key pair fetched successfully",
694
- "Could not fetch the user access key pair",
1362
+ "Architecture info reported successfully",
1363
+ "Could not report architecture info",
695
1364
  )
1365
+
1366
+
696
1367
 
697
1368
  @log_errors(log_error=True)
698
1369
  def get_internal_api_key(self, action_id):
699
- """Get internal API key using REST API.
1370
+ """Get internal API key using Kafka (with REST fallback).
700
1371
 
701
1372
  Args:
702
1373
  action_id: ID of the action
@@ -704,11 +1375,21 @@ class Scaling:
704
1375
  Returns:
705
1376
  Tuple of (data, error, message) from API response
706
1377
  """
707
- path = f"/v1/actions/get_internal_api_key/{action_id}/{self.instance_id}"
708
- resp = self.rpc.get(path=path)
709
- return self.handle_response(
710
- resp,
711
- "internal keys fetched successfully",
712
- "Could not fetch internal keys",
713
- )
1378
+ payload = {"action_id": action_id, "instance_id": self.instance_id}
1379
+
1380
+ def rest_fallback():
1381
+ path = f"/v1/actions/get_internal_api_key/{action_id}/{self.instance_id}"
1382
+ resp = self.rpc.get(path=path)
1383
+ return self.handle_response(
1384
+ resp,
1385
+ "internal keys fetched successfully",
1386
+ "Could not fetch internal keys",
1387
+ )
714
1388
 
1389
+ return self._hybrid_request(
1390
+ api="get_internal_api_key",
1391
+ payload=payload,
1392
+ request_topic=self.kafka_config["action_request_topic"] if self.enable_kafka else None,
1393
+ response_topic=self.kafka_config["action_response_topic"] if self.enable_kafka else None,
1394
+ rest_fallback_func=rest_fallback
1395
+ )