groundhog-hpc 0.5.6__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groundhog_hpc/__init__.py +4 -0
- groundhog_hpc/app/add.py +74 -0
- groundhog_hpc/app/init.py +54 -10
- groundhog_hpc/app/main.py +5 -1
- groundhog_hpc/app/remove.py +91 -16
- groundhog_hpc/app/run.py +70 -2
- groundhog_hpc/compute.py +16 -1
- groundhog_hpc/configuration/defaults.py +1 -0
- groundhog_hpc/configuration/endpoints.py +38 -171
- groundhog_hpc/configuration/models.py +26 -0
- groundhog_hpc/configuration/pep723.py +278 -2
- groundhog_hpc/configuration/resolver.py +36 -8
- groundhog_hpc/console.py +1 -1
- groundhog_hpc/decorators.py +35 -16
- groundhog_hpc/function.py +53 -19
- groundhog_hpc/future.py +48 -10
- groundhog_hpc/harness.py +15 -19
- groundhog_hpc/logging.py +51 -0
- groundhog_hpc/serialization.py +22 -2
- groundhog_hpc/templates/init_script.py.jinja +4 -5
- groundhog_hpc/templates/shell_command.sh.jinja +15 -1
- groundhog_hpc/templating.py +17 -0
- {groundhog_hpc-0.5.6.dist-info → groundhog_hpc-0.7.0.dist-info}/METADATA +12 -6
- groundhog_hpc-0.7.0.dist-info/RECORD +34 -0
- groundhog_hpc-0.5.6.dist-info/RECORD +0 -33
- {groundhog_hpc-0.5.6.dist-info → groundhog_hpc-0.7.0.dist-info}/WHEEL +0 -0
- {groundhog_hpc-0.5.6.dist-info → groundhog_hpc-0.7.0.dist-info}/entry_points.txt +0 -0
- {groundhog_hpc-0.5.6.dist-info → groundhog_hpc-0.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,6 @@ conform to the EndpointConfig/EndpointVariant models for consistency with
|
|
|
6
6
|
existing configuration parsing logic.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
from dataclasses import dataclass
|
|
10
9
|
from typing import Any
|
|
11
10
|
from uuid import UUID
|
|
12
11
|
|
|
@@ -16,6 +15,9 @@ from groundhog_hpc.compute import get_endpoint_metadata, get_endpoint_schema
|
|
|
16
15
|
KNOWN_ENDPOINTS: dict[str, dict[str, Any]] = {
|
|
17
16
|
"anvil": {
|
|
18
17
|
"uuid": "5aafb4c1-27b2-40d8-a038-a0277611868f",
|
|
18
|
+
"base": {
|
|
19
|
+
"requirements": "",
|
|
20
|
+
},
|
|
19
21
|
"variants": {
|
|
20
22
|
"gpu": {
|
|
21
23
|
"partition": "gpu-debug",
|
|
@@ -26,24 +28,12 @@ KNOWN_ENDPOINTS: dict[str, dict[str, Any]] = {
|
|
|
26
28
|
},
|
|
27
29
|
"tutorial": {
|
|
28
30
|
"uuid": "4b116d3c-1703-4f8f-9f6f-39921e5864df",
|
|
31
|
+
"base": {},
|
|
29
32
|
"variants": {},
|
|
30
33
|
},
|
|
31
34
|
}
|
|
32
35
|
|
|
33
36
|
|
|
34
|
-
@dataclass
|
|
35
|
-
class FormattedEndpoint:
|
|
36
|
-
"""Formatted endpoint configuration with metadata for template rendering.
|
|
37
|
-
|
|
38
|
-
Attributes:
|
|
39
|
-
name: Endpoint name for use in @hog.function(endpoint="name")
|
|
40
|
-
toml_block: Formatted TOML configuration block
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
name: str
|
|
44
|
-
toml_block: str
|
|
45
|
-
|
|
46
|
-
|
|
47
37
|
class EndpointSpec:
|
|
48
38
|
"""Parsed endpoint specification from --endpoint flag.
|
|
49
39
|
|
|
@@ -51,6 +41,7 @@ class EndpointSpec:
|
|
|
51
41
|
name: Table name for [tool.hog.{name}]
|
|
52
42
|
variant: Optional variant name for [tool.hog.{name}.{variant}]
|
|
53
43
|
uuid: Globus Compute endpoint UUID
|
|
44
|
+
base_defaults: Dict of defaults to apply to base endpoint (if known endpoint)
|
|
54
45
|
variant_defaults: Dict of defaults to apply to variant (if known variant)
|
|
55
46
|
"""
|
|
56
47
|
|
|
@@ -59,11 +50,13 @@ class EndpointSpec:
|
|
|
59
50
|
name: str,
|
|
60
51
|
variant: str | None,
|
|
61
52
|
uuid: str,
|
|
53
|
+
base_defaults: dict[str, Any] | None = None,
|
|
62
54
|
variant_defaults: dict[str, Any] | None = None,
|
|
63
55
|
):
|
|
64
56
|
self.name = name
|
|
65
57
|
self.variant = variant
|
|
66
58
|
self.uuid = uuid
|
|
59
|
+
self.base_defaults = base_defaults or {}
|
|
67
60
|
self.variant_defaults = variant_defaults or {}
|
|
68
61
|
|
|
69
62
|
|
|
@@ -115,29 +108,40 @@ def parse_endpoint_spec(spec: str) -> EndpointSpec:
|
|
|
115
108
|
if "." in spec:
|
|
116
109
|
base_name, variant = spec.split(".", 1)
|
|
117
110
|
if base_name not in KNOWN_ENDPOINTS:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
111
|
+
# Stub out unknown endpoint with variant with TODO placeholder
|
|
112
|
+
return EndpointSpec(
|
|
113
|
+
name=base_name,
|
|
114
|
+
variant=variant,
|
|
115
|
+
uuid="TODO: Replace with your endpoint UUID",
|
|
121
116
|
)
|
|
122
117
|
|
|
123
118
|
endpoint_info = KNOWN_ENDPOINTS[base_name]
|
|
124
119
|
uuid = endpoint_info["uuid"]
|
|
120
|
+
base_defaults = endpoint_info.get("base", {})
|
|
125
121
|
variant_defaults = endpoint_info["variants"].get(variant, {})
|
|
126
122
|
|
|
127
123
|
return EndpointSpec(
|
|
128
124
|
name=base_name,
|
|
129
125
|
variant=variant,
|
|
130
126
|
uuid=uuid,
|
|
127
|
+
base_defaults=base_defaults,
|
|
131
128
|
variant_defaults=variant_defaults,
|
|
132
129
|
)
|
|
133
130
|
|
|
134
|
-
# Must be a known endpoint name
|
|
131
|
+
# Must be a known endpoint name, or stub out unknown ones
|
|
135
132
|
if spec not in KNOWN_ENDPOINTS:
|
|
136
|
-
|
|
137
|
-
|
|
133
|
+
# Stub out unknown endpoint with TODO placeholder
|
|
134
|
+
return EndpointSpec(
|
|
135
|
+
name=spec,
|
|
136
|
+
variant=None,
|
|
137
|
+
uuid="TODO: Replace with your endpoint UUID",
|
|
138
|
+
)
|
|
138
139
|
|
|
139
140
|
endpoint_info = KNOWN_ENDPOINTS[spec]
|
|
140
|
-
|
|
141
|
+
base_defaults = endpoint_info.get("base", {})
|
|
142
|
+
return EndpointSpec(
|
|
143
|
+
name=spec, variant=None, uuid=endpoint_info["uuid"], base_defaults=base_defaults
|
|
144
|
+
)
|
|
141
145
|
|
|
142
146
|
|
|
143
147
|
def generate_endpoint_config(spec: EndpointSpec) -> dict[str, dict[str, Any]]:
|
|
@@ -158,10 +162,22 @@ def generate_endpoint_config(spec: EndpointSpec) -> dict[str, dict[str, Any]]:
|
|
|
158
162
|
"""
|
|
159
163
|
result: dict[str, Any] = {}
|
|
160
164
|
|
|
165
|
+
# If UUID is a TODO placeholder, skip schema fetching
|
|
166
|
+
if spec.uuid.startswith("TODO"):
|
|
167
|
+
filtered_base_defaults = spec.base_defaults.copy()
|
|
168
|
+
else:
|
|
169
|
+
# Filter base_defaults to only include fields present in the endpoint schema
|
|
170
|
+
schema = get_endpoint_schema(spec.uuid)
|
|
171
|
+
schema_fields = set(schema.get("properties", {}).keys())
|
|
172
|
+
filtered_base_defaults = {
|
|
173
|
+
k: v for k, v in spec.base_defaults.items() if k in schema_fields
|
|
174
|
+
}
|
|
175
|
+
|
|
161
176
|
# Base configuration
|
|
162
177
|
base_config = {
|
|
163
178
|
"endpoint": spec.uuid,
|
|
164
|
-
|
|
179
|
+
**filtered_base_defaults,
|
|
180
|
+
# Other fields will be added by user, we just provide the endpoint UUID + defaults
|
|
165
181
|
}
|
|
166
182
|
result[spec.name] = base_config
|
|
167
183
|
|
|
@@ -203,152 +219,3 @@ def get_endpoint_schema_comments(endpoint_uuid: str) -> dict[str, str]:
|
|
|
203
219
|
comments[field_name] = ". ".join(parts)
|
|
204
220
|
|
|
205
221
|
return comments
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def get_endpoint_display_name(endpoint_uuid: str) -> str | None:
|
|
209
|
-
"""Fetch endpoint display name from metadata.
|
|
210
|
-
|
|
211
|
-
Args:
|
|
212
|
-
endpoint_uuid: Globus Compute endpoint UUID
|
|
213
|
-
|
|
214
|
-
Returns:
|
|
215
|
-
Display name if available and different from 'name', otherwise None
|
|
216
|
-
"""
|
|
217
|
-
metadata = get_endpoint_metadata(endpoint_uuid)
|
|
218
|
-
|
|
219
|
-
display_name = metadata.get("display_name")
|
|
220
|
-
name = metadata.get("name")
|
|
221
|
-
|
|
222
|
-
# Only return display_name if it's meaningful
|
|
223
|
-
if display_name and display_name != "None" and display_name != name:
|
|
224
|
-
return display_name
|
|
225
|
-
|
|
226
|
-
return None
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def format_endpoint_config_to_toml(
|
|
230
|
-
config_dict: dict[str, dict[str, Any]],
|
|
231
|
-
endpoint_uuid: str,
|
|
232
|
-
include_schema_comments: bool = True,
|
|
233
|
-
) -> str:
|
|
234
|
-
"""Format an endpoint configuration dict as TOML with inline documentation.
|
|
235
|
-
|
|
236
|
-
Args:
|
|
237
|
-
config_dict: Dict with structure {"endpoint_name": {"endpoint": "uuid", ...}}
|
|
238
|
-
endpoint_uuid: UUID for fetching schema documentation
|
|
239
|
-
include_schema_comments: If True, add commented schema fields with docs
|
|
240
|
-
|
|
241
|
-
Returns:
|
|
242
|
-
Formatted TOML string with comments
|
|
243
|
-
"""
|
|
244
|
-
lines = []
|
|
245
|
-
|
|
246
|
-
# Get display name and schema comments
|
|
247
|
-
display_name = get_endpoint_display_name(endpoint_uuid)
|
|
248
|
-
|
|
249
|
-
# Calculate padding for aligned inline comments
|
|
250
|
-
# Align to UUID line length (approx 51 chars: "# endpoint = "uuid..."")
|
|
251
|
-
# For schema comments: "# # field_name = " should align comment to ~column 52
|
|
252
|
-
alignment_column = 52
|
|
253
|
-
|
|
254
|
-
if include_schema_comments:
|
|
255
|
-
comments = get_endpoint_schema_comments(endpoint_uuid)
|
|
256
|
-
|
|
257
|
-
for endpoint_name, config in config_dict.items():
|
|
258
|
-
# Check if this is a base config or has nested variants
|
|
259
|
-
has_variants = any(isinstance(v, dict) for v in config.values())
|
|
260
|
-
|
|
261
|
-
if has_variants:
|
|
262
|
-
# Process base and variants
|
|
263
|
-
base_config = {k: v for k, v in config.items() if not isinstance(v, dict)}
|
|
264
|
-
variants = {k: v for k, v in config.items() if isinstance(v, dict)}
|
|
265
|
-
|
|
266
|
-
# Base config header
|
|
267
|
-
header = f"[tool.hog.{endpoint_name}]"
|
|
268
|
-
if display_name:
|
|
269
|
-
lines.append(f"# {header} # {display_name}")
|
|
270
|
-
else:
|
|
271
|
-
lines.append(f"# {header}")
|
|
272
|
-
|
|
273
|
-
# Base config fields (active, so prefix with # for PEP 723)
|
|
274
|
-
for key, value in base_config.items():
|
|
275
|
-
if isinstance(value, str):
|
|
276
|
-
lines.append(f'# {key} = "{value}"')
|
|
277
|
-
else:
|
|
278
|
-
lines.append(f"# {key} = {value}")
|
|
279
|
-
|
|
280
|
-
# Add schema comments if requested (commented out, so prefix with # #)
|
|
281
|
-
if include_schema_comments:
|
|
282
|
-
comments = get_endpoint_schema_comments(endpoint_uuid)
|
|
283
|
-
for field_name, comment in comments.items():
|
|
284
|
-
# Pad to align inline comments (left-align, pad to alignment_column)
|
|
285
|
-
lines.append(
|
|
286
|
-
f"# # {field_name} = {'':<{alignment_column - 7 - len(field_name)}}# {comment}"
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
# Variant configs (active headers, active fields)
|
|
290
|
-
for variant_name, variant_config in variants.items():
|
|
291
|
-
lines.append("#")
|
|
292
|
-
lines.append(f"# [tool.hog.{endpoint_name}.{variant_name}]")
|
|
293
|
-
for key, value in variant_config.items():
|
|
294
|
-
if isinstance(value, str):
|
|
295
|
-
lines.append(f'# {key} = "{value}"')
|
|
296
|
-
else:
|
|
297
|
-
lines.append(f"# {key} = {value}")
|
|
298
|
-
else:
|
|
299
|
-
# Simple config without variants
|
|
300
|
-
header = f"[tool.hog.{endpoint_name}]"
|
|
301
|
-
if display_name:
|
|
302
|
-
lines.append(f"# {header} # {display_name}")
|
|
303
|
-
else:
|
|
304
|
-
lines.append(f"# {header}")
|
|
305
|
-
|
|
306
|
-
# Active fields (prefix with # for PEP 723)
|
|
307
|
-
for key, value in config.items():
|
|
308
|
-
if isinstance(value, str):
|
|
309
|
-
lines.append(f'# {key} = "{value}"')
|
|
310
|
-
else:
|
|
311
|
-
lines.append(f"# {key} = {value}")
|
|
312
|
-
|
|
313
|
-
# Add schema comments if requested (commented out, so prefix with # #)
|
|
314
|
-
if include_schema_comments:
|
|
315
|
-
comments = get_endpoint_schema_comments(endpoint_uuid)
|
|
316
|
-
for field_name, comment in comments.items():
|
|
317
|
-
# Pad to align inline comments (left-align, pad to alignment_column)
|
|
318
|
-
lines.append(
|
|
319
|
-
f"# # {field_name} = {'':<{alignment_column - 7 - len(field_name)}}# {comment}"
|
|
320
|
-
)
|
|
321
|
-
|
|
322
|
-
return "\n".join(lines)
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
def fetch_and_format_endpoints(endpoint_specs: list[str]) -> list[FormattedEndpoint]:
|
|
326
|
-
"""Parse endpoint specifications and generate formatted TOML blocks.
|
|
327
|
-
|
|
328
|
-
Args:
|
|
329
|
-
endpoint_specs: List of endpoint specification strings
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
List of FormattedEndpoint objects with name and TOML block
|
|
333
|
-
|
|
334
|
-
Raises:
|
|
335
|
-
ValueError: If any endpoint spec is invalid
|
|
336
|
-
RuntimeError: If unable to fetch endpoint metadata
|
|
337
|
-
"""
|
|
338
|
-
endpoints = []
|
|
339
|
-
|
|
340
|
-
for spec_str in endpoint_specs:
|
|
341
|
-
try:
|
|
342
|
-
spec = parse_endpoint_spec(spec_str)
|
|
343
|
-
config_dict = generate_endpoint_config(spec)
|
|
344
|
-
toml_block = format_endpoint_config_to_toml(
|
|
345
|
-
config_dict, spec.uuid, include_schema_comments=True
|
|
346
|
-
)
|
|
347
|
-
endpoints.append(FormattedEndpoint(name=spec.name, toml_block=toml_block))
|
|
348
|
-
except Exception as e:
|
|
349
|
-
# Re-raise with context
|
|
350
|
-
raise RuntimeError(
|
|
351
|
-
f"Failed to process endpoint spec '{spec_str}': {e}"
|
|
352
|
-
) from e
|
|
353
|
-
|
|
354
|
-
return endpoints
|
|
@@ -25,10 +25,12 @@ class EndpointConfig(BaseModel, extra="allow"):
|
|
|
25
25
|
Attributes:
|
|
26
26
|
endpoint: Globus Compute endpoint UUID (required for base configs)
|
|
27
27
|
worker_init: Shell commands to run in worker initialization
|
|
28
|
+
endpoint_setup: Shell commands to run in endpoint setup
|
|
28
29
|
"""
|
|
29
30
|
|
|
30
31
|
endpoint: str | UUID
|
|
31
32
|
worker_init: str | None = None
|
|
33
|
+
endpoint_setup: str | None = None
|
|
32
34
|
|
|
33
35
|
|
|
34
36
|
class EndpointVariant(BaseModel, extra="allow"):
|
|
@@ -45,10 +47,12 @@ class EndpointVariant(BaseModel, extra="allow"):
|
|
|
45
47
|
Attributes:
|
|
46
48
|
endpoint: Always None (variants must inherit endpoint from base)
|
|
47
49
|
worker_init: Additional worker init commands (concatenated with base)
|
|
50
|
+
endpoint_setup: Additional endpoint setup commands (concatenated with base)
|
|
48
51
|
"""
|
|
49
52
|
|
|
50
53
|
endpoint: None = None
|
|
51
54
|
worker_init: str | None = None
|
|
55
|
+
endpoint_setup: str | None = None
|
|
52
56
|
|
|
53
57
|
@model_validator(mode="before")
|
|
54
58
|
@classmethod
|
|
@@ -63,9 +67,31 @@ class EndpointVariant(BaseModel, extra="allow"):
|
|
|
63
67
|
|
|
64
68
|
|
|
65
69
|
class UvMetadata(BaseModel, extra="allow", serialize_by_alias=True):
|
|
70
|
+
"""Configuration for uv package manager via [tool.uv].
|
|
71
|
+
|
|
72
|
+
Common fields are modeled for validation and defaults. Additional uv settings
|
|
73
|
+
are supported via extra="allow" - see https://docs.astral.sh/uv/reference/settings/
|
|
74
|
+
|
|
75
|
+
Note: Environment variables (UV_*) and CLI flags take precedence over TOML config.
|
|
76
|
+
See uv documentation for full precedence hierarchy.
|
|
77
|
+
|
|
78
|
+
Attributes:
|
|
79
|
+
exclude_newer: Limit packages to versions uploaded before cutoff (ISO 8601 timestamp)
|
|
80
|
+
python_preference: Control system vs managed Python ("managed" | "only-managed" | "system" | "only-system")
|
|
81
|
+
index_url: Primary package index URL (default: PyPI)
|
|
82
|
+
extra_index_url: Additional package indexes (searched after index_url)
|
|
83
|
+
python_downloads: Control automatic Python downloads ("automatic" | "manual" | "never")
|
|
84
|
+
offline: Disable all network access (use only cache and local files)
|
|
85
|
+
"""
|
|
86
|
+
|
|
66
87
|
exclude_newer: str | None = Field(
|
|
67
88
|
default_factory=_default_exclude_newer, alias="exclude-newer"
|
|
68
89
|
)
|
|
90
|
+
python_preference: str | None = Field(default="managed", alias="python-preference")
|
|
91
|
+
index_url: str | None = Field(default=None, alias="index-url")
|
|
92
|
+
extra_index_url: list[str] | None = Field(default=None, alias="extra-index-url")
|
|
93
|
+
python_downloads: str | None = Field(default=None, alias="python-downloads")
|
|
94
|
+
offline: bool | None = None
|
|
69
95
|
|
|
70
96
|
|
|
71
97
|
class ToolMetadata(BaseModel, extra="allow"):
|
|
@@ -6,8 +6,10 @@ using the PEP 723 inline script metadata format (# /// script ... # ///).
|
|
|
6
6
|
|
|
7
7
|
import re
|
|
8
8
|
import sys
|
|
9
|
+
from typing import Any, cast
|
|
9
10
|
|
|
10
|
-
import
|
|
11
|
+
import tomlkit
|
|
12
|
+
import tomlkit.items
|
|
11
13
|
|
|
12
14
|
from groundhog_hpc.configuration.models import Pep723Metadata
|
|
13
15
|
|
|
@@ -73,7 +75,7 @@ def write_pep723(metadata: Pep723Metadata) -> str:
|
|
|
73
75
|
metadata_dict = metadata.model_dump(by_alias=True, exclude_none=True)
|
|
74
76
|
|
|
75
77
|
# Convert dict to TOML format
|
|
76
|
-
toml_content =
|
|
78
|
+
toml_content = tomlkit.dumps(metadata_dict)
|
|
77
79
|
|
|
78
80
|
# Format as PEP 723 inline metadata block
|
|
79
81
|
lines = ["# /// script"]
|
|
@@ -137,3 +139,277 @@ def insert_or_update_metadata(script_content: str, metadata: Pep723Metadata) ->
|
|
|
137
139
|
lines.insert(insert_index + 1, "")
|
|
138
140
|
|
|
139
141
|
return "\n".join(lines)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def extract_pep723_toml(
|
|
145
|
+
script: str,
|
|
146
|
+
) -> tuple[tomlkit.TOMLDocument, re.Match] | tuple[None, None]:
|
|
147
|
+
"""Extract TOML document from PEP 723 block using tomlkit for round-trip preservation.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
script: The full text content of a Python script
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Tuple of (tomlkit document, regex match) or (None, None) if no block exists.
|
|
154
|
+
"""
|
|
155
|
+
name = "script"
|
|
156
|
+
matches = list(
|
|
157
|
+
filter(
|
|
158
|
+
lambda m: m.group("type") == name,
|
|
159
|
+
re.finditer(INLINE_METADATA_REGEX, script),
|
|
160
|
+
)
|
|
161
|
+
)
|
|
162
|
+
if len(matches) > 1:
|
|
163
|
+
raise ValueError(f"Multiple {name} blocks found")
|
|
164
|
+
elif len(matches) == 1:
|
|
165
|
+
match = matches[0]
|
|
166
|
+
content = "".join(
|
|
167
|
+
line[2:] if line.startswith("# ") else line[1:]
|
|
168
|
+
for line in match.group("content").splitlines(keepends=True)
|
|
169
|
+
)
|
|
170
|
+
doc = tomlkit.parse(content)
|
|
171
|
+
return doc, match
|
|
172
|
+
else:
|
|
173
|
+
return None, None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def embed_pep723_toml(
|
|
177
|
+
script: str, doc: tomlkit.TOMLDocument, match: re.Match | None
|
|
178
|
+
) -> str:
|
|
179
|
+
"""Replace PEP 723 block with updated TOML document.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
script: The full text content of a Python script
|
|
183
|
+
doc: tomlkit TOMLDocument to embed
|
|
184
|
+
match: regex match from extract_pep723_toml, or None to insert new block
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Updated script content with the new/updated PEP 723 block
|
|
188
|
+
"""
|
|
189
|
+
# Format TOML document as PEP 723 block
|
|
190
|
+
toml_content = tomlkit.dumps(doc)
|
|
191
|
+
lines = ["# /// script"]
|
|
192
|
+
for line in toml_content.splitlines():
|
|
193
|
+
if line.strip():
|
|
194
|
+
lines.append(f"# {line}")
|
|
195
|
+
else:
|
|
196
|
+
lines.append("#")
|
|
197
|
+
lines.append("# ///")
|
|
198
|
+
metadata_block = "\n".join(lines)
|
|
199
|
+
|
|
200
|
+
if match:
|
|
201
|
+
# Replace existing block
|
|
202
|
+
return script[: match.start()] + metadata_block + script[match.end() :]
|
|
203
|
+
else:
|
|
204
|
+
# Insert at the beginning (after shebang/encoding if present)
|
|
205
|
+
script_lines = script.split("\n")
|
|
206
|
+
insert_index = 0
|
|
207
|
+
|
|
208
|
+
# Skip shebang line if present
|
|
209
|
+
if script_lines and script_lines[0].startswith("#!"):
|
|
210
|
+
insert_index = 1
|
|
211
|
+
|
|
212
|
+
# Skip encoding declaration if present
|
|
213
|
+
if insert_index < len(script_lines) and (
|
|
214
|
+
script_lines[insert_index].startswith("# -*- coding:")
|
|
215
|
+
or script_lines[insert_index].startswith("# coding:")
|
|
216
|
+
):
|
|
217
|
+
insert_index += 1
|
|
218
|
+
|
|
219
|
+
# Insert metadata block at the appropriate position
|
|
220
|
+
script_lines.insert(insert_index, metadata_block)
|
|
221
|
+
|
|
222
|
+
# Add blank line after metadata if there isn't one
|
|
223
|
+
if (
|
|
224
|
+
insert_index + 1 < len(script_lines)
|
|
225
|
+
and script_lines[insert_index + 1].strip()
|
|
226
|
+
):
|
|
227
|
+
script_lines.insert(insert_index + 1, "")
|
|
228
|
+
|
|
229
|
+
return "\n".join(script_lines)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def add_endpoint_to_toml(
|
|
233
|
+
doc: tomlkit.TOMLDocument,
|
|
234
|
+
endpoint_name: str,
|
|
235
|
+
endpoint_config: dict[str, Any],
|
|
236
|
+
variant_name: str | None = None,
|
|
237
|
+
variant_config: dict[str, Any] | None = None,
|
|
238
|
+
schema_comments: dict[str, str] | None = None,
|
|
239
|
+
) -> str | None:
|
|
240
|
+
"""Add endpoint config to TOML document in-place.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
doc: tomlkit TOMLDocument to modify
|
|
244
|
+
endpoint_name: Base endpoint name (e.g., "anvil")
|
|
245
|
+
endpoint_config: Base endpoint config dict
|
|
246
|
+
variant_name: Optional variant name (e.g., "gpu")
|
|
247
|
+
variant_config: Optional variant config dict
|
|
248
|
+
schema_comments: Optional dict mapping field names to comment strings
|
|
249
|
+
(e.g., {"account": "Type: string. Your allocation account"})
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
Skip message if endpoint/variant already exists, None on success.
|
|
253
|
+
"""
|
|
254
|
+
# Ensure tool.hog exists
|
|
255
|
+
if "tool" not in doc:
|
|
256
|
+
doc["tool"] = tomlkit.table()
|
|
257
|
+
|
|
258
|
+
tool_table = cast(tomlkit.items.Table, doc["tool"])
|
|
259
|
+
if "hog" not in tool_table:
|
|
260
|
+
tool_table["hog"] = tomlkit.table()
|
|
261
|
+
|
|
262
|
+
hog = cast(tomlkit.items.Table, tool_table["hog"])
|
|
263
|
+
|
|
264
|
+
# Check if we're just adding a variant to an existing base
|
|
265
|
+
if variant_name is not None:
|
|
266
|
+
if endpoint_name in hog:
|
|
267
|
+
# Base exists - check if variant exists
|
|
268
|
+
endpoint_table = cast(tomlkit.items.Table, hog[endpoint_name])
|
|
269
|
+
if variant_name in endpoint_table:
|
|
270
|
+
return f"Variant '{endpoint_name}.{variant_name}' already exists"
|
|
271
|
+
# Add variant to existing base
|
|
272
|
+
endpoint_table[variant_name] = tomlkit.table()
|
|
273
|
+
for key, value in (variant_config or {}).items():
|
|
274
|
+
variant_table = cast(tomlkit.items.Table, endpoint_table[variant_name])
|
|
275
|
+
variant_table[key] = value
|
|
276
|
+
return None
|
|
277
|
+
else:
|
|
278
|
+
# Base doesn't exist - add base + variant
|
|
279
|
+
hog[endpoint_name] = tomlkit.table()
|
|
280
|
+
endpoint_table = cast(tomlkit.items.Table, hog[endpoint_name])
|
|
281
|
+
for key, value in endpoint_config.items():
|
|
282
|
+
endpoint_table[key] = value
|
|
283
|
+
# Add schema comments for fields not already in config
|
|
284
|
+
_add_schema_comments(endpoint_table, endpoint_config, schema_comments)
|
|
285
|
+
endpoint_table[variant_name] = tomlkit.table()
|
|
286
|
+
for key, value in (variant_config or {}).items():
|
|
287
|
+
variant_table = cast(tomlkit.items.Table, endpoint_table[variant_name])
|
|
288
|
+
variant_table[key] = value
|
|
289
|
+
return None
|
|
290
|
+
else:
|
|
291
|
+
# Just adding base endpoint
|
|
292
|
+
if endpoint_name in hog:
|
|
293
|
+
return f"Endpoint '{endpoint_name}' already exists"
|
|
294
|
+
hog[endpoint_name] = tomlkit.table()
|
|
295
|
+
endpoint_table = cast(tomlkit.items.Table, hog[endpoint_name])
|
|
296
|
+
for key, value in endpoint_config.items():
|
|
297
|
+
endpoint_table[key] = value
|
|
298
|
+
# Add schema comments for fields not already in config
|
|
299
|
+
_add_schema_comments(endpoint_table, endpoint_config, schema_comments)
|
|
300
|
+
return None
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _add_schema_comments(
|
|
304
|
+
table: tomlkit.items.Table,
|
|
305
|
+
existing_config: dict[str, Any],
|
|
306
|
+
schema_comments: dict[str, str] | None,
|
|
307
|
+
) -> None:
|
|
308
|
+
"""Add commented-out schema fields to a tomlkit table.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
table: tomlkit Table to add comments to
|
|
312
|
+
existing_config: Dict of fields already in the config (to skip)
|
|
313
|
+
schema_comments: Dict mapping field names to comment strings
|
|
314
|
+
"""
|
|
315
|
+
if not schema_comments:
|
|
316
|
+
return
|
|
317
|
+
|
|
318
|
+
# Align comments to column 52 (matches format_endpoint_config_to_toml)
|
|
319
|
+
alignment_column = 52
|
|
320
|
+
|
|
321
|
+
for field_name, comment in schema_comments.items():
|
|
322
|
+
# Skip fields already in the active config
|
|
323
|
+
if field_name in existing_config:
|
|
324
|
+
continue
|
|
325
|
+
# Add as a commented-out field with documentation
|
|
326
|
+
# Format: # field_name = # Type: string. Description
|
|
327
|
+
# Note: tomlkit.comment adds "# " prefix, embed_pep723_toml adds another "# "
|
|
328
|
+
# so final output is "# # field_name = {padding}# comment"
|
|
329
|
+
padding = " " * max(1, alignment_column - 5 - len(field_name))
|
|
330
|
+
table.add(tomlkit.comment(f"{field_name} ={padding}# {comment}"))
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def remove_endpoint_from_script(
|
|
334
|
+
script_content: str, endpoint_name: str, variant_name: str | None = None
|
|
335
|
+
) -> str:
|
|
336
|
+
"""Remove an endpoint or variant from a script's PEP 723 block.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
script_content: Full script file content
|
|
340
|
+
endpoint_name: Name of the endpoint (e.g., "my_endpoint")
|
|
341
|
+
variant_name: Optional variant name. If provided, only removes that variant.
|
|
342
|
+
If None, removes the entire endpoint and all its variants.
|
|
343
|
+
|
|
344
|
+
Returns:
|
|
345
|
+
Updated script content with the endpoint/variant removed
|
|
346
|
+
"""
|
|
347
|
+
doc, match = extract_pep723_toml(script_content)
|
|
348
|
+
if doc is None or match is None:
|
|
349
|
+
return script_content
|
|
350
|
+
|
|
351
|
+
if "tool" in doc:
|
|
352
|
+
tool_table = cast(tomlkit.items.Table, doc["tool"])
|
|
353
|
+
if "hog" in tool_table:
|
|
354
|
+
hog = cast(tomlkit.items.Table, tool_table["hog"])
|
|
355
|
+
|
|
356
|
+
if variant_name is not None:
|
|
357
|
+
# Remove only the specific variant
|
|
358
|
+
if endpoint_name in hog:
|
|
359
|
+
endpoint_table = cast(tomlkit.items.Table, hog[endpoint_name])
|
|
360
|
+
if variant_name in endpoint_table:
|
|
361
|
+
del endpoint_table[variant_name]
|
|
362
|
+
else:
|
|
363
|
+
# Remove the entire endpoint (and all its variants)
|
|
364
|
+
if endpoint_name in hog:
|
|
365
|
+
del hog[endpoint_name]
|
|
366
|
+
|
|
367
|
+
return embed_pep723_toml(script_content, doc, match)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def add_endpoint_to_script(
|
|
371
|
+
script_content: str,
|
|
372
|
+
endpoint_name: str,
|
|
373
|
+
endpoint_config: dict[str, Any],
|
|
374
|
+
variant_name: str | None = None,
|
|
375
|
+
variant_config: dict[str, Any] | None = None,
|
|
376
|
+
schema_comments: dict[str, str] | None = None,
|
|
377
|
+
) -> tuple[str, str | None]:
|
|
378
|
+
"""Add endpoint config to existing script, preserving formatting.
|
|
379
|
+
|
|
380
|
+
Convenience wrapper that combines extract_pep723_toml, add_endpoint_to_toml,
|
|
381
|
+
and embed_pep723_toml into a single call.
|
|
382
|
+
|
|
383
|
+
Args:
|
|
384
|
+
script_content: Full script file content
|
|
385
|
+
endpoint_name: Base endpoint name (e.g., "anvil")
|
|
386
|
+
endpoint_config: Base endpoint config dict
|
|
387
|
+
variant_name: Optional variant name (e.g., "gpu")
|
|
388
|
+
variant_config: Optional variant config dict
|
|
389
|
+
schema_comments: Optional dict mapping field names to comment strings
|
|
390
|
+
(e.g., {"account": "Type: string. Your allocation account"})
|
|
391
|
+
|
|
392
|
+
Returns:
|
|
393
|
+
Tuple of (updated_content, skip_message)
|
|
394
|
+
skip_message is None if endpoint was added, or info string if skipped
|
|
395
|
+
"""
|
|
396
|
+
doc, match = extract_pep723_toml(script_content)
|
|
397
|
+
|
|
398
|
+
if doc is None:
|
|
399
|
+
# No PEP 723 block exists - create minimal one with defaults
|
|
400
|
+
doc = tomlkit.document()
|
|
401
|
+
metadata = Pep723Metadata()
|
|
402
|
+
doc["requires-python"] = metadata.requires_python
|
|
403
|
+
doc["dependencies"] = []
|
|
404
|
+
|
|
405
|
+
skip_msg = add_endpoint_to_toml(
|
|
406
|
+
doc,
|
|
407
|
+
endpoint_name,
|
|
408
|
+
endpoint_config,
|
|
409
|
+
variant_name,
|
|
410
|
+
variant_config,
|
|
411
|
+
schema_comments,
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
updated_content = embed_pep723_toml(script_content, doc, match)
|
|
415
|
+
return updated_content, skip_msg
|