tetra-rp 0.6.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. tetra_rp/__init__.py +109 -19
  2. tetra_rp/cli/commands/__init__.py +1 -0
  3. tetra_rp/cli/commands/apps.py +143 -0
  4. tetra_rp/cli/commands/build.py +1082 -0
  5. tetra_rp/cli/commands/build_utils/__init__.py +1 -0
  6. tetra_rp/cli/commands/build_utils/handler_generator.py +176 -0
  7. tetra_rp/cli/commands/build_utils/lb_handler_generator.py +309 -0
  8. tetra_rp/cli/commands/build_utils/manifest.py +430 -0
  9. tetra_rp/cli/commands/build_utils/mothership_handler_generator.py +75 -0
  10. tetra_rp/cli/commands/build_utils/scanner.py +596 -0
  11. tetra_rp/cli/commands/deploy.py +580 -0
  12. tetra_rp/cli/commands/init.py +123 -0
  13. tetra_rp/cli/commands/resource.py +108 -0
  14. tetra_rp/cli/commands/run.py +296 -0
  15. tetra_rp/cli/commands/test_mothership.py +458 -0
  16. tetra_rp/cli/commands/undeploy.py +533 -0
  17. tetra_rp/cli/main.py +97 -0
  18. tetra_rp/cli/utils/__init__.py +1 -0
  19. tetra_rp/cli/utils/app.py +15 -0
  20. tetra_rp/cli/utils/conda.py +127 -0
  21. tetra_rp/cli/utils/deployment.py +530 -0
  22. tetra_rp/cli/utils/ignore.py +143 -0
  23. tetra_rp/cli/utils/skeleton.py +184 -0
  24. tetra_rp/cli/utils/skeleton_template/.env.example +4 -0
  25. tetra_rp/cli/utils/skeleton_template/.flashignore +40 -0
  26. tetra_rp/cli/utils/skeleton_template/.gitignore +44 -0
  27. tetra_rp/cli/utils/skeleton_template/README.md +263 -0
  28. tetra_rp/cli/utils/skeleton_template/main.py +44 -0
  29. tetra_rp/cli/utils/skeleton_template/mothership.py +55 -0
  30. tetra_rp/cli/utils/skeleton_template/pyproject.toml +58 -0
  31. tetra_rp/cli/utils/skeleton_template/requirements.txt +1 -0
  32. tetra_rp/cli/utils/skeleton_template/workers/__init__.py +0 -0
  33. tetra_rp/cli/utils/skeleton_template/workers/cpu/__init__.py +19 -0
  34. tetra_rp/cli/utils/skeleton_template/workers/cpu/endpoint.py +36 -0
  35. tetra_rp/cli/utils/skeleton_template/workers/gpu/__init__.py +19 -0
  36. tetra_rp/cli/utils/skeleton_template/workers/gpu/endpoint.py +61 -0
  37. tetra_rp/client.py +136 -33
  38. tetra_rp/config.py +29 -0
  39. tetra_rp/core/api/runpod.py +591 -39
  40. tetra_rp/core/deployment.py +232 -0
  41. tetra_rp/core/discovery.py +425 -0
  42. tetra_rp/core/exceptions.py +50 -0
  43. tetra_rp/core/resources/__init__.py +27 -9
  44. tetra_rp/core/resources/app.py +738 -0
  45. tetra_rp/core/resources/base.py +139 -4
  46. tetra_rp/core/resources/constants.py +21 -0
  47. tetra_rp/core/resources/cpu.py +115 -13
  48. tetra_rp/core/resources/gpu.py +182 -16
  49. tetra_rp/core/resources/live_serverless.py +153 -16
  50. tetra_rp/core/resources/load_balancer_sls_resource.py +440 -0
  51. tetra_rp/core/resources/network_volume.py +126 -31
  52. tetra_rp/core/resources/resource_manager.py +436 -35
  53. tetra_rp/core/resources/serverless.py +537 -120
  54. tetra_rp/core/resources/serverless_cpu.py +201 -0
  55. tetra_rp/core/resources/template.py +1 -59
  56. tetra_rp/core/utils/constants.py +10 -0
  57. tetra_rp/core/utils/file_lock.py +260 -0
  58. tetra_rp/core/utils/http.py +67 -0
  59. tetra_rp/core/utils/lru_cache.py +75 -0
  60. tetra_rp/core/utils/singleton.py +36 -1
  61. tetra_rp/core/validation.py +44 -0
  62. tetra_rp/execute_class.py +301 -0
  63. tetra_rp/protos/remote_execution.py +98 -9
  64. tetra_rp/runtime/__init__.py +1 -0
  65. tetra_rp/runtime/circuit_breaker.py +274 -0
  66. tetra_rp/runtime/config.py +12 -0
  67. tetra_rp/runtime/exceptions.py +49 -0
  68. tetra_rp/runtime/generic_handler.py +206 -0
  69. tetra_rp/runtime/lb_handler.py +189 -0
  70. tetra_rp/runtime/load_balancer.py +160 -0
  71. tetra_rp/runtime/manifest_fetcher.py +192 -0
  72. tetra_rp/runtime/metrics.py +325 -0
  73. tetra_rp/runtime/models.py +73 -0
  74. tetra_rp/runtime/mothership_provisioner.py +512 -0
  75. tetra_rp/runtime/production_wrapper.py +266 -0
  76. tetra_rp/runtime/reliability_config.py +149 -0
  77. tetra_rp/runtime/retry_manager.py +118 -0
  78. tetra_rp/runtime/serialization.py +124 -0
  79. tetra_rp/runtime/service_registry.py +346 -0
  80. tetra_rp/runtime/state_manager_client.py +248 -0
  81. tetra_rp/stubs/live_serverless.py +35 -17
  82. tetra_rp/stubs/load_balancer_sls.py +357 -0
  83. tetra_rp/stubs/registry.py +145 -19
  84. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/METADATA +398 -60
  85. tetra_rp-0.24.0.dist-info/RECORD +99 -0
  86. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/WHEEL +1 -1
  87. tetra_rp-0.24.0.dist-info/entry_points.txt +2 -0
  88. tetra_rp/core/pool/cluster_manager.py +0 -177
  89. tetra_rp/core/pool/dataclass.py +0 -18
  90. tetra_rp/core/pool/ex.py +0 -38
  91. tetra_rp/core/pool/job.py +0 -22
  92. tetra_rp/core/pool/worker.py +0 -19
  93. tetra_rp/core/resources/utils.py +0 -50
  94. tetra_rp/core/utils/json.py +0 -33
  95. tetra_rp-0.6.0.dist-info/RECORD +0 -39
  96. /tetra_rp/{core/pool → cli}/__init__.py +0 -0
  97. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  import hashlib
2
2
  from abc import ABC, abstractmethod
3
- from typing import Optional
3
+ from typing import Optional, Dict, Any
4
4
  from pydantic import BaseModel, ConfigDict
5
5
 
6
6
 
@@ -17,11 +17,137 @@ class BaseResource(BaseModel):
17
17
 
18
18
  @property
19
19
  def resource_id(self) -> str:
20
- """Unique resource ID based on configuration."""
20
+ """Unique resource ID based on configuration.
21
+
22
+ Computed once and cached to ensure stability across the object's lifetime.
23
+ This prevents hash changes if validators mutate the object after first access.
24
+
25
+ The hash excludes the 'id' field since it's assigned by the provider after
26
+ deployment and should not affect resource identity.
27
+
28
+ If the resource defines _hashed_fields as a class variable, only those fields
29
+ are included in the hash.
30
+ """
31
+ # Use a private attribute in __dict__ for caching (handles pickle correctly)
32
+ cache_key = "_cached_resource_id"
33
+ if cache_key not in self.__dict__:
34
+ resource_type = self.__class__.__name__
35
+ # Check if resource defines _hashed_fields as a set/frozenset or ModelPrivateAttr
36
+ hashed_fields_attr = getattr(self.__class__, "_hashed_fields", None)
37
+ include_fields = None
38
+
39
+ if isinstance(hashed_fields_attr, (set, frozenset)):
40
+ # Direct set/frozenset
41
+ include_fields = hashed_fields_attr - {"id"}
42
+ elif hasattr(hashed_fields_attr, "default") and isinstance(
43
+ hashed_fields_attr.default, (set, frozenset)
44
+ ):
45
+ # Pydantic ModelPrivateAttr with set/frozenset default
46
+ include_fields = hashed_fields_attr.default - {"id"}
47
+
48
+ if include_fields:
49
+ config_str = self.model_dump_json(
50
+ exclude_none=True, include=include_fields
51
+ )
52
+ else:
53
+ # Fallback: Exclude only id field
54
+ config_str = self.model_dump_json(exclude_none=True, exclude={"id"})
55
+
56
+ hash_obj = hashlib.md5(f"{resource_type}:{config_str}".encode())
57
+ self.__dict__[cache_key] = f"{resource_type}_{hash_obj.hexdigest()}"
58
+ return self.__dict__[cache_key]
59
+
60
+ @property
61
+ def config_hash(self) -> str:
62
+ """Get hash of current configuration (excluding id and server-assigned fields).
63
+
64
+ Unlike resource_id which is cached, this always computes fresh hash.
65
+ Used for drift detection.
66
+
67
+ For resources with _input_only set, only those fields are included in the hash
68
+ to avoid drift from server-assigned fields.
69
+ """
70
+ import json
71
+ import logging
72
+
21
73
  resource_type = self.__class__.__name__
22
- config_str = self.model_dump_json(exclude_none=True)
74
+
75
+ # If resource defines input_only fields, use only those for hash
76
+ if hasattr(self, "_input_only"):
77
+ # Include only user-provided input fields, not server-assigned ones
78
+ include_fields = self._input_only - {"id"} # Exclude id from input fields
79
+ config_dict = self.model_dump(
80
+ exclude_none=True, include=include_fields, mode="json"
81
+ )
82
+ else:
83
+ # Fallback: exclude only id field
84
+ config_dict = self.model_dump(
85
+ exclude_none=True, exclude={"id"}, mode="json"
86
+ )
87
+
88
+ # Convert to JSON string for hashing
89
+ config_str = json.dumps(config_dict, sort_keys=True)
23
90
  hash_obj = hashlib.md5(f"{resource_type}:{config_str}".encode())
24
- return f"{resource_type}_{hash_obj.hexdigest()}"
91
+ hash_value = hash_obj.hexdigest()
92
+
93
+ # Debug logging to see what's being hashed
94
+ log = logging.getLogger(__name__)
95
+ if hasattr(self, "name"):
96
+ log.debug(
97
+ f"CONFIG HASH for {self.name} ({resource_type}):\n"
98
+ f" Fields included: {sorted(config_dict.keys())}\n"
99
+ f" Config dict: {config_str}\n"
100
+ f" Hash: {hash_value}"
101
+ )
102
+
103
+ return hash_value
104
+
105
+ def get_resource_key(self) -> str:
106
+ """Get stable resource key for tracking.
107
+
108
+ Format: {ResourceType}:{name}
109
+ This provides stable identity even when config changes.
110
+ """
111
+ resource_type = self.__class__.__name__
112
+ name = getattr(self, "name", None)
113
+ if name:
114
+ return f"{resource_type}:{name}"
115
+ # Fallback to resource_id for resources without names
116
+ return self.resource_id
117
+
118
+ def __getstate__(self) -> Dict[str, Any]:
119
+ """Get state for pickling, excluding non-pickleable items."""
120
+ import weakref as weakref_module
121
+
122
+ state = self.__dict__.copy()
123
+
124
+ # Remove any weakrefs from the state dict
125
+ # This handles cases where threading.Lock or similar objects leak weakrefs
126
+ keys_to_remove = []
127
+ for key, value in state.items():
128
+ # Direct weakref
129
+ if isinstance(value, weakref_module.ref):
130
+ keys_to_remove.append(key)
131
+ continue
132
+
133
+ # Check if value holds weakrefs in its __dict__
134
+ if hasattr(value, "__dict__"):
135
+ try:
136
+ for sub_value in value.__dict__.values():
137
+ if isinstance(sub_value, weakref_module.ref):
138
+ keys_to_remove.append(key)
139
+ break
140
+ except Exception:
141
+ pass
142
+
143
+ for key in keys_to_remove:
144
+ del state[key]
145
+
146
+ return state
147
+
148
+ def __setstate__(self, state: Dict[str, Any]) -> None:
149
+ """Restore state from pickling."""
150
+ self.__dict__.update(state)
25
151
 
26
152
 
27
153
  class DeployableResource(BaseResource, ABC):
@@ -45,3 +171,12 @@ class DeployableResource(BaseResource, ABC):
45
171
  async def deploy(self) -> "DeployableResource":
46
172
  """Deploy the resource."""
47
173
  raise NotImplementedError("Subclasses should implement this method.")
174
+
175
+ @abstractmethod
176
+ async def undeploy(self) -> bool:
177
+ """Undeploy/delete the resource.
178
+
179
+ Returns:
180
+ True if successfully undeployed, False otherwise
181
+ """
182
+ raise NotImplementedError("Subclasses should implement this method.")
@@ -1,4 +1,25 @@
1
1
  import os
2
+ from urllib.parse import urlparse
3
+
4
+ import runpod
2
5
 
3
6
  CONSOLE_BASE_URL = os.environ.get("CONSOLE_BASE_URL", "https://console.runpod.io")
4
7
  CONSOLE_URL = f"{CONSOLE_BASE_URL}/serverless/user/endpoint/%s"
8
+
9
+
10
+ def _endpoint_domain_from_base_url(base_url: str) -> str:
11
+ if not base_url:
12
+ return "api.runpod.ai"
13
+ if "://" not in base_url:
14
+ base_url = f"https://{base_url}"
15
+ parsed = urlparse(base_url)
16
+ return parsed.netloc or "api.runpod.ai"
17
+
18
+
19
+ ENDPOINT_DOMAIN = _endpoint_domain_from_base_url(runpod.endpoint_url_base)
20
+
21
+ # Flash app artifact upload constants
22
+ TARBALL_CONTENT_TYPE = "application/gzip"
23
+ MAX_TARBALL_SIZE_MB = 500 # Maximum tarball size in megabytes
24
+ VALID_TARBALL_EXTENSIONS = (".tar.gz", ".tgz") # Valid tarball file extensions
25
+ GZIP_MAGIC_BYTES = (0x1F, 0x8B) # Magic bytes for gzip files
@@ -1,4 +1,5 @@
1
1
  from enum import Enum
2
+ from typing import List, Optional
2
3
 
3
4
 
4
5
  class CpuInstanceType(str, Enum):
@@ -12,23 +13,124 @@ class CpuInstanceType(str, Enum):
12
13
  - cpu3g: 4.0 (1 vCPU = 4GB, 2 vCPU = 8GB, etc.)
13
14
  - cpu3c: 2.0 (1 vCPU = 2GB, 2 vCPU = 4GB, etc.)
14
15
  - cpu5c: 2.0 (1 vCPU = 2GB, 2 vCPU = 4GB, etc.)
15
- - cpu5g: Not available
16
16
  """
17
17
 
18
18
  # 3rd Generation General Purpose (RAM multiplier: 4.0)
19
- CPU3G_1_4 = "cpu3g-1-4" # 1 vCPU, 4GB RAM
20
- CPU3G_2_8 = "cpu3g-2-8" # 2 vCPU, 8GB RAM
21
- CPU3G_4_16 = "cpu3g-4-16" # 4 vCPU, 16GB RAM
22
- CPU3G_8_32 = "cpu3g-8-32" # 8 vCPU, 32GB RAM
19
+
20
+ CPU3G_1_4 = "cpu3g-1-4"
21
+ """1 vCPU, 4GB RAM, max 10GB container disk"""
22
+
23
+ CPU3G_2_8 = "cpu3g-2-8"
24
+ """2 vCPU, 8GB RAM, max 20GB container disk"""
25
+
26
+ CPU3G_4_16 = "cpu3g-4-16"
27
+ """4 vCPU, 16GB RAM, max 40GB container disk"""
28
+
29
+ CPU3G_8_32 = "cpu3g-8-32"
30
+ """8 vCPU, 32GB RAM, max 80GB container disk"""
23
31
 
24
32
  # 3rd Generation Compute-Optimized (RAM multiplier: 2.0)
25
- CPU3C_1_2 = "cpu3c-1-2" # 1 vCPU, 2GB RAM
26
- CPU3C_2_4 = "cpu3c-2-4" # 2 vCPU, 4GB RAM
27
- CPU3C_4_8 = "cpu3c-4-8" # 4 vCPU, 8GB RAM
28
- CPU3C_8_16 = "cpu3c-8-16" # 8 vCPU, 16GB RAM
33
+
34
+ CPU3C_1_2 = "cpu3c-1-2"
35
+ """1 vCPU, 2GB RAM, max 10GB container disk"""
36
+
37
+ CPU3C_2_4 = "cpu3c-2-4"
38
+ """2 vCPU, 4GB RAM, max 20GB container disk"""
39
+
40
+ CPU3C_4_8 = "cpu3c-4-8"
41
+ """4 vCPU, 8GB RAM, max 40GB container disk"""
42
+
43
+ CPU3C_8_16 = "cpu3c-8-16"
44
+ """8 vCPU, 16GB RAM, max 80GB container disk"""
29
45
 
30
46
  # 5th Generation Compute-Optimized (RAM multiplier: 2.0)
31
- CPU5C_1_2 = "cpu5c-1-2" # 1 vCPU, 2GB RAM
32
- CPU5C_2_4 = "cpu5c-2-4" # 2 vCPU, 4GB RAM
33
- CPU5C_4_8 = "cpu5c-4-8" # 4 vCPU, 8GB RAM
34
- CPU5C_8_16 = "cpu5c-8-16" # 8 vCPU, 16GB RAM
47
+
48
+ CPU5C_1_2 = "cpu5c-1-2"
49
+ """1 vCPU, 2GB RAM, max 15GB container disk"""
50
+
51
+ CPU5C_2_4 = "cpu5c-2-4"
52
+ """2 vCPU, 4GB RAM, max 30GB container disk"""
53
+
54
+ CPU5C_4_8 = "cpu5c-4-8"
55
+ """4 vCPU, 8GB RAM, max 60GB container disk"""
56
+
57
+ CPU5C_8_16 = "cpu5c-8-16"
58
+ """8 vCPU, 16GB RAM, max 120GB container disk"""
59
+
60
+
61
+ def calculate_max_disk_size(instance_type: CpuInstanceType) -> int:
62
+ """
63
+ Calculate the maximum container disk size for a CPU instance type.
64
+
65
+ Formula:
66
+ - CPU3G/CPU3C: vCPU count × 10GB
67
+ - CPU5C: vCPU count × 15GB
68
+
69
+ Args:
70
+ instance_type: CPU instance type enum
71
+
72
+ Returns:
73
+ Maximum container disk size in GB
74
+
75
+ Example:
76
+ >>> calculate_max_disk_size(CpuInstanceType.CPU3G_1_4)
77
+ 10
78
+ >>> calculate_max_disk_size(CpuInstanceType.CPU5C_2_4)
79
+ 30
80
+ """
81
+ # Parse the instance type string to extract vCPU count
82
+ # Format: "cpu{generation}{type}-{vcpu}-{memory}"
83
+ instance_str = instance_type.value
84
+ parts = instance_str.split("-")
85
+
86
+ if len(parts) != 3:
87
+ raise ValueError(f"Invalid instance type format: {instance_str}")
88
+
89
+ vcpu_count = int(parts[1])
90
+
91
+ # Determine disk multiplier based on generation
92
+ if instance_str.startswith("cpu5c"):
93
+ disk_multiplier = 15 # CPU5C: 15GB per vCPU
94
+ elif instance_str.startswith(("cpu3g", "cpu3c")):
95
+ disk_multiplier = 10 # CPU3G/CPU3C: 10GB per vCPU
96
+ else:
97
+ raise ValueError(f"Unknown CPU generation/type: {instance_str}")
98
+
99
+ return vcpu_count * disk_multiplier
100
+
101
+
102
+ # CPU Instance Type Disk Limits (calculated programmatically)
103
+ CPU_INSTANCE_DISK_LIMITS = {
104
+ instance_type: calculate_max_disk_size(instance_type)
105
+ for instance_type in CpuInstanceType
106
+ }
107
+
108
+
109
+ def get_max_disk_size_for_instances(
110
+ instance_types: Optional[List[CpuInstanceType]],
111
+ ) -> Optional[int]:
112
+ """
113
+ Calculate the maximum container disk size for a list of CPU instance types.
114
+
115
+ Returns the minimum disk limit across all instance types to ensure compatibility
116
+ with all specified instances.
117
+
118
+ Args:
119
+ instance_types: List of CPU instance types, or None
120
+
121
+ Returns:
122
+ Maximum allowed disk size in GB, or None if no CPU instances specified
123
+
124
+ Example:
125
+ >>> get_max_disk_size_for_instances([CpuInstanceType.CPU3G_1_4])
126
+ 10
127
+ >>> get_max_disk_size_for_instances([CpuInstanceType.CPU3G_1_4, CpuInstanceType.CPU3G_2_8])
128
+ 10
129
+ """
130
+ if not instance_types:
131
+ return None
132
+
133
+ disk_limits = [
134
+ CPU_INSTANCE_DISK_LIMITS[instance_type] for instance_type in instance_types
135
+ ]
136
+ return min(disk_limits)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Optional, List
2
4
  from pydantic import BaseModel
3
5
  from enum import Enum
@@ -8,13 +10,13 @@ class GpuLowestPrice(BaseModel):
8
10
  uninterruptablePrice: Optional[float] = None
9
11
 
10
12
 
11
- class GpuType(BaseModel):
13
+ class GpuTypeModel(BaseModel):
12
14
  id: str
13
15
  displayName: str
14
16
  memoryInGb: int
15
17
 
16
18
 
17
- class GpuTypeDetail(GpuType):
19
+ class GpuTypeDetail(GpuTypeModel):
18
20
  communityCloud: Optional[bool] = None
19
21
  communityPrice: Optional[float] = None
20
22
  communitySpotPrice: Optional[float] = None
@@ -32,22 +34,186 @@ class GpuTypeDetail(GpuType):
32
34
 
33
35
  # TODO: this should be fetched from an API
34
36
  class GpuGroup(Enum):
35
- ANY = "any" # "Any GPU"
36
- ADA_24 = "ADA_24" # "NVIDIA GeForce RTX 4090"
37
- ADA_32_PRO = "ADA_32_PRO" # "NVIDIA GeForce RTX 5090"
38
- ADA_48_PRO = (
39
- "ADA_48_PRO" # "NVIDIA RTX 6000 Ada Generation, NVIDIA L40, NVIDIA L40S"
40
- )
41
- ADA_80_PRO = (
42
- "ADA_80_PRO" # "NVIDIA H100 PCIe, NVIDIA H100 80GB HBM3, NVIDIA H100 NVL"
43
- )
44
- AMPERE_16 = "AMPERE_16" # "NVIDIA RTX A4000, NVIDIA RTX A4500, NVIDIA RTX 4000 Ada Generation, NVIDIA RTX 2000 Ada Generation"
45
- AMPERE_24 = "AMPERE_24" # "NVIDIA RTX A5000, NVIDIA L4, NVIDIA GeForce RTX 3090"
46
- AMPERE_48 = "AMPERE_48" # "NVIDIA A40, NVIDIA RTX A6000"
47
- AMPERE_80 = "AMPERE_80" # "NVIDIA A100 80GB PCIe, NVIDIA A100-SXM4-80GB"
48
- HOPPER_141 = "HOPPER_141" # "NVIDIA H200"
37
+ ANY = "any"
38
+ """Any GPU"""
39
+
40
+ ADA_24 = "ADA_24"
41
+ """NVIDIA GeForce RTX 4090"""
42
+
43
+ ADA_32_PRO = "ADA_32_PRO"
44
+ """NVIDIA GeForce RTX 5090"""
45
+
46
+ ADA_48_PRO = "ADA_48_PRO"
47
+ """NVIDIA RTX 6000 Ada Generation, NVIDIA L40, NVIDIA L40S"""
48
+
49
+ ADA_80_PRO = "ADA_80_PRO"
50
+ """NVIDIA H100 PCIe, NVIDIA H100 80GB HBM3, NVIDIA H100 NVL"""
51
+
52
+ AMPERE_16 = "AMPERE_16"
53
+ """NVIDIA RTX A4000, NVIDIA RTX A4500, NVIDIA RTX 4000 Ada Generation, NVIDIA RTX 2000 Ada Generation"""
54
+
55
+ AMPERE_24 = "AMPERE_24"
56
+ """NVIDIA RTX A5000, NVIDIA L4, NVIDIA GeForce RTX 3090"""
57
+
58
+ AMPERE_48 = "AMPERE_48"
59
+ """NVIDIA A40, NVIDIA RTX A6000"""
60
+
61
+ AMPERE_80 = "AMPERE_80"
62
+ """NVIDIA A100 80GB PCIe, NVIDIA A100-SXM4-80GB"""
63
+
64
+ HOPPER_141 = "HOPPER_141"
65
+ """NVIDIA H200"""
49
66
 
50
67
  @classmethod
51
68
  def all(cls) -> List["GpuGroup"]:
52
69
  """Returns all GPU groups."""
53
70
  return [cls.AMPERE_48] + [g for g in cls if g != cls.ANY]
71
+
72
+ @classmethod
73
+ def to_gpu_ids_str(cls, gpu_types: List[GpuType | GpuGroup]) -> str:
74
+ """
75
+ The API expects a comma-separated list of pool IDs, with GPU ID negations (-4090, -3090, etc.). So to convert a list of GPU types to a string of pool IDs, we need to:
76
+ 1. Convert the GPU types to pool IDs
77
+ 2. Add a negation for each GPU type that is not in the list
78
+ 3. Join the pool IDs with a comma
79
+ """
80
+ pool_ids = set()
81
+ pool_ids_from_groups = set()
82
+ explicit_gpu_types = set()
83
+
84
+ for gpu_type in gpu_types:
85
+ if isinstance(gpu_type, GpuGroup):
86
+ pool_id = gpu_type
87
+ pool_ids_from_groups.add(pool_id)
88
+ else:
89
+ pool_id = _pool_from_gpu_type(gpu_type)
90
+ explicit_gpu_types.add(gpu_type)
91
+
92
+ if pool_id:
93
+ pool_ids.add(pool_id)
94
+
95
+ # only add negations for pools selected via explicit gpu types
96
+ if explicit_gpu_types:
97
+ # iterate over a snapshot because we add negations into the same set
98
+ for pool_id in list(pool_ids):
99
+ if pool_id in pool_ids_from_groups:
100
+ continue
101
+ for gpu_type in POOLS_TO_TYPES.get(pool_id, []):
102
+ if gpu_type not in explicit_gpu_types:
103
+ pool_ids.add(f"-{gpu_type.value}")
104
+
105
+ # normalize to strings for the api
106
+ out = []
107
+ for pool_id in pool_ids:
108
+ if isinstance(pool_id, GpuGroup):
109
+ out.append(pool_id.value)
110
+ else:
111
+ out.append(str(pool_id))
112
+ return ",".join(out)
113
+
114
+ @classmethod
115
+ def from_gpu_ids_str(cls, gpu_ids_str: str) -> List[GpuGroup | GpuType]:
116
+ """
117
+ Convert a comma-separated list of pool IDs to a list of GPU types.
118
+ """
119
+ ids = gpu_ids_str.split(",")
120
+ pool_ids = []
121
+ gpu_types = []
122
+ negated_gpu_types = []
123
+ for id in ids:
124
+ if id.startswith("-") and GpuType.is_gpu_type(id[1:]):
125
+ negated_gpu_types.append(GpuType(id[1:]))
126
+ elif GpuType.is_gpu_type(id):
127
+ gpu_types.append(GpuType(id))
128
+ else:
129
+ pool_ids.append(id)
130
+
131
+ ids = []
132
+
133
+ for pool_id in pool_ids:
134
+ try:
135
+ pool = GpuGroup(pool_id)
136
+ except ValueError:
137
+ # ignore unknown pool ids from backend
138
+ continue
139
+
140
+ pool_gpus = POOLS_TO_TYPES.get(pool, [])
141
+ # check if there are any negated gpu types in the pool
142
+ if any(gpu_type in negated_gpu_types for gpu_type in pool_gpus):
143
+ # add the gpu types that are not in the negated gpu types
144
+ ids.extend(
145
+ [
146
+ gpu_type
147
+ for gpu_type in pool_gpus
148
+ if gpu_type not in negated_gpu_types
149
+ ]
150
+ )
151
+ else:
152
+ ids.append(pool)
153
+
154
+ ids.extend(gpu_types)
155
+ return ids
156
+
157
+
158
+ # TODO: fetch from central registry at some point
159
+ class GpuType(Enum):
160
+ ANY = "any"
161
+ """Any GPU"""
162
+
163
+ NVIDIA_GEFORCE_RTX_4090 = "NVIDIA GeForce RTX 4090"
164
+ NVIDIA_GEFORCE_RTX_5090 = "NVIDIA GeForce RTX 5090"
165
+ NVIDIA_RTX_6000_ADA_GENERATION = "NVIDIA RTX 6000 Ada Generation"
166
+ NVIDIA_H100_80GB_HBM3 = "NVIDIA H100 80GB HBM3"
167
+ NVIDIA_RTX_A4000 = "NVIDIA RTX A4000"
168
+ NVIDIA_RTX_A4500 = "NVIDIA RTX A4500"
169
+ NVIDIA_RTX_4000_ADA_GENERATION = "NVIDIA RTX 4000 Ada Generation"
170
+ NVIDIA_RTX_2000_ADA_GENERATION = "NVIDIA RTX 2000 Ada Generation"
171
+ NVIDIA_RTX_A5000 = "NVIDIA RTX A5000"
172
+ NVIDIA_L4 = "NVIDIA L4"
173
+ NVIDIA_GEFORCE_RTX_3090 = "NVIDIA GeForce RTX 3090"
174
+ NVIDIA_A40 = "NVIDIA A40"
175
+ NVIDIA_RTX_A6000 = "NVIDIA RTX A6000"
176
+ NVIDIA_A100_80GB_PCIe = "NVIDIA A100 80GB PCIe"
177
+ NVIDIA_A100_SXM4_80GB = "NVIDIA A100-SXM4-80GB"
178
+ NVIDIA_H200 = "NVIDIA H200"
179
+
180
+ @classmethod
181
+ def all(cls) -> List["GpuType"]:
182
+ """Returns all GPU types."""
183
+ return [g for g in cls if g != cls.ANY]
184
+
185
+ @classmethod
186
+ def is_gpu_type(cls, gpu_type: str) -> bool:
187
+ """
188
+ Check if a string is a valid GPU type.
189
+ """
190
+ return gpu_type in {m.value for m in cls}
191
+
192
+
193
+ POOLS_TO_TYPES = {
194
+ GpuGroup.ADA_24: [GpuType.NVIDIA_GEFORCE_RTX_4090],
195
+ GpuGroup.ADA_32_PRO: [GpuType.NVIDIA_GEFORCE_RTX_5090],
196
+ GpuGroup.ADA_48_PRO: [GpuType.NVIDIA_RTX_6000_ADA_GENERATION],
197
+ GpuGroup.ADA_80_PRO: [GpuType.NVIDIA_H100_80GB_HBM3],
198
+ GpuGroup.AMPERE_16: [
199
+ GpuType.NVIDIA_RTX_A4000,
200
+ GpuType.NVIDIA_RTX_A4500,
201
+ GpuType.NVIDIA_RTX_4000_ADA_GENERATION,
202
+ GpuType.NVIDIA_RTX_2000_ADA_GENERATION,
203
+ ],
204
+ GpuGroup.AMPERE_24: [
205
+ GpuType.NVIDIA_RTX_A5000,
206
+ GpuType.NVIDIA_L4,
207
+ GpuType.NVIDIA_GEFORCE_RTX_3090,
208
+ ],
209
+ GpuGroup.AMPERE_48: [GpuType.NVIDIA_A40, GpuType.NVIDIA_RTX_A6000],
210
+ GpuGroup.AMPERE_80: [GpuType.NVIDIA_A100_80GB_PCIe, GpuType.NVIDIA_A100_SXM4_80GB],
211
+ GpuGroup.HOPPER_141: [GpuType.NVIDIA_H200],
212
+ }
213
+
214
+
215
+ def _pool_from_gpu_type(gpu_type: GpuType) -> str:
216
+ for group, types in POOLS_TO_TYPES.items():
217
+ if gpu_type in types:
218
+ return group
219
+ return None