hafnia 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hafnia/experiment/command_builder.py +686 -0
- {hafnia-0.5.1.dist-info → hafnia-0.5.2.dist-info}/METADATA +2 -1
- {hafnia-0.5.1.dist-info → hafnia-0.5.2.dist-info}/RECORD +6 -5
- {hafnia-0.5.1.dist-info → hafnia-0.5.2.dist-info}/WHEEL +0 -0
- {hafnia-0.5.1.dist-info → hafnia-0.5.2.dist-info}/entry_points.txt +0 -0
- {hafnia-0.5.1.dist-info → hafnia-0.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,686 @@
|
|
|
1
|
+
"""
|
|
2
|
+
We have a concept in Hafnia Training-aaS called the "Command Builder" that helps the user construct
|
|
3
|
+
a CLI command that will be passed to and executed on the trainer.
|
|
4
|
+
|
|
5
|
+
An example of this could be `python scripts/train.py --batch-size 32`, that tells the trainer to
|
|
6
|
+
launch the `scripts/train.py` python script and set the `batch_size` parameter to `32`.
|
|
7
|
+
|
|
8
|
+
The "Command Builder" is a dynamic form-based interface that is specific for each training script and
|
|
9
|
+
should contain available parameters that the script accepts. Each parameter is described by
|
|
10
|
+
its name and type and optionally parameter title, description, and default value.
|
|
11
|
+
|
|
12
|
+
The CommandBuilderSchema is a json description/configuration for creating a command builder form.
|
|
13
|
+
This python module provides functionality for constructing or automatically generating such a schema from
|
|
14
|
+
a CLI function (e.g. the `main` function in `scripts/train.py`), and save it as a JSON file.
|
|
15
|
+
|
|
16
|
+
The convention is that if the CLI function is defined in `./scripts/train.py`, the corresponding
|
|
17
|
+
CommandBuilderSchema JSON file should be saved as `./scripts/train.json`.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import inspect
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import (
|
|
25
|
+
Annotated,
|
|
26
|
+
Any,
|
|
27
|
+
Callable,
|
|
28
|
+
Dict,
|
|
29
|
+
List,
|
|
30
|
+
Literal,
|
|
31
|
+
Optional,
|
|
32
|
+
Tuple,
|
|
33
|
+
Type,
|
|
34
|
+
get_args,
|
|
35
|
+
get_origin,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
import docstring_parser
|
|
39
|
+
import flatten_dict
|
|
40
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model, model_validator
|
|
41
|
+
from pydantic.fields import FieldInfo
|
|
42
|
+
from pydantic_core import PydanticUndefined
|
|
43
|
+
|
|
44
|
+
# Defaults for how a CommandBuilderSchema formats form values as command line arguments.
# They are the field defaults of 'CommandBuilderSchema' and the last fallback used by
# 'CommandBuilderSchema.from_function' when no preset/override provides a value.
DEFAULT_N_POSITIONAL_ARGS: int = 0  # no leading positional arguments
DEFAULT_CASE_CONVERSION: Literal["none", "kebab"] = "none"  # keep parameter names unchanged
DEFAULT_PARAMETER_PREFIX: str = "--"  # e.g. '--batch_size'
DEFAULT_NESTED_PARAMETER_HANDLING: Literal["dot", "not-supported"] = "dot"  # e.g. '--optimizer.lr'
DEFAULT_ASSIGNMENT_SEPARATOR: Literal["space", "equals"] = "space"  # e.g. '--lr 0.1' rather than '--lr=0.1'
DEFAULT_BOOL_HANDLING: Literal["none", "flag-negation"] = "none"  # e.g. '--flag True'
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def auto_save_command_builder_schema(
    cli_function: Callable[..., Any],
    n_positional_args: int = DEFAULT_N_POSITIONAL_ARGS,
    cli_tool: Optional[str] = "cyclopts",
    ignore_params: tuple[str, ...] = (),
    handle_union_types: bool = True,
    path_schema: Optional[Path] = None,
    cmd: Optional[str] = None,
    case_conversion: Optional[Literal["none", "kebab"]] = None,
    parameter_prefix: Optional[str] = None,
    nested_parameter_handling: Optional[Literal["dot", "not-supported"]] = None,
    assignment_separator: Optional[Literal["space", "equals"]] = None,
    bool_handling: Optional[Literal["none", "flag-negation"]] = None,
) -> Path:
    """
    Magic function to create and save a CommandBuilderSchema as a JSON file from a CLI function.
    If the CLI function is defined in e.g. 'scripts/train.py', the schema is saved as
    'scripts/train.json' next to the script file (unless 'path_schema' is given).

    The auto-generated schema might not work for all CLI functions and settings. In that case,
    you can manually create the CommandBuilderSchema using the 'CommandBuilderSchema.from_function(...)'
    method and save it using the 'to_json_file(...)' method. Or directly create the CommandBuilderSchema
    by providing the 'json_schema' argument as shown in the example below.

    ```python
    function_schema = schema_from_cli_function(main_cli_function)
    command_builder = CommandBuilderSchema(
        cmd="python scripts/train.py",
        json_schema=function_schema,
        case_conversion="kebab",
        parameter_prefix="--",
        nested_parameter_handling="dot",
        assignment_separator="space",
        n_positional_args=0,
        bool_handling="flag-negation",
    )
    command_builder.to_json_file(Path("scripts/train.json"))
    ```

    Args:
        cli_function: The CLI function to describe (e.g. 'main' in 'scripts/train.py').
        n_positional_args: Number of leading parameters passed as positional arguments.
        cli_tool: CLI library preset for the formatting options (see 'CommandBuilderSchema.from_function').
        ignore_params: Parameters of 'cli_function' to leave out of the schema.
        handle_union_types: Add display titles to 'anyOf' (Union/Optional) options in the schema.
        path_schema: Where to write the JSON file. Defaults to the script path (relative to the
            current working directory) with a '.json' suffix.
        cmd: Command used to launch the script. Defaults to 'python <script path>'.
        case_conversion, parameter_prefix, nested_parameter_handling, assignment_separator,
        bool_handling: Optional overrides of the CLI formatting options.

    Returns:
        Path: The path of the written JSON schema file.
    """

    launch_schema = CommandBuilderSchema.from_function(
        cli_function,
        cli_tool=cli_tool,
        ignore_params=ignore_params,
        handle_union_types=handle_union_types,
        cmd=cmd,
        n_positional_args=n_positional_args,
        case_conversion=case_conversion,
        parameter_prefix=parameter_prefix,
        nested_parameter_handling=nested_parameter_handling,
        assignment_separator=assignment_separator,
        bool_handling=bool_handling,
    )

    path_schema = path_schema or path_of_function(cli_function).with_suffix(".json")
    launch_schema.to_json_file(path_schema)
    return path_schema
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class CommandBuilderSchema(BaseModel):
    """
    A schema or configuration for the Command Builder Form. It is a data model that describes
    the command to execute, which parameters are available for the command, and how the command line
    arguments should be formatted.

    The 'CommandBuilderSchema' data model can be constructed manually by providing the 'cmd' and
    'json_schema' arguments, or it can be automatically generated from a CLI function using the
    'from_function(...)' method.
    """

    cmd: str = Field(
        ...,
        description=(
            "The command to execute the trainer launcher 'python scripts/train.py', "
            "'uv run scripts/train.py', 'trainer-classification'"
        ),
    )

    json_schema: Dict[str, Any] = Field(
        default_factory=dict,
        description="JSON schema representing the command parameters. ",
    )

    n_positional_args: int = Field(
        DEFAULT_N_POSITIONAL_ARGS,
        description=(
            "Number of positional arguments in the command. To handle that some CLI tools "
            "accept positional arguments before options, this specifies how many positional arguments are expected. "
            "These will be filled in order before any named arguments. "
            "E.g. for 'python train.py data/train --batch-size 32', n_positional_args would be 1 ('data/train')."
        ),
    )

    # Below defines how the specified parameters/arguments are formatted as command line arguments
    case_conversion: Literal["none", "kebab"] = Field(
        DEFAULT_CASE_CONVERSION,
        description=(
            "CLI format: The command line arguments will use kebab-case (e.g. --batch-size). "
            "This handles what is commonly done in CLI tools where function argument names are converted from "
            "'snake_case' to '--kebab-case' when used as command line arguments."
        ),
    )
    parameter_prefix: str = Field(
        DEFAULT_PARAMETER_PREFIX,
        description="CLI format: The prefix used for command line arguments. E.g. '--' for '--batch-size' ",
    )

    nested_parameter_handling: Literal["dot", "not-supported"] = Field(
        DEFAULT_NESTED_PARAMETER_HANDLING,
        description=(
            "CLI format: Separator used for nested parameters in the command line interface. "
            "E.g. 'optimizer.lr' to set the learning rate inside an optimizer configuration."
        ),
    )

    assignment_separator: Literal["space", "equals"] = Field(
        DEFAULT_ASSIGNMENT_SEPARATOR,
        description=(
            "CLI format: Separator between parameter names and their values in the command line. "
            "E.g. ' ' for '--batch-size 32' or '=' for '--batch-size=32'."
        ),
    )

    bool_handling: Literal["none", "flag-negation"] = Field(
        DEFAULT_BOOL_HANDLING,
        description=(
            "CLI format: How boolean parameters are handled in the command line. "
            "'none' means boolean flags must be explicitly set to True or False "
            "e.g. '--flag True' or '--flag False'. "
            "'flag-negation' uses flag negation where '--flag' sets True and '--no-flag' sets False."
        ),
    )

    def to_json_file(self, path: Path) -> Path:
        """Write the launcher schema to a JSON file.

        Args:
            path (Path): The file path to write the schema to. Must have a '.json' suffix.

        Returns:
            Path: The same 'path' that was written.

        Raises:
            ValueError: If 'path' does not have a '.json' suffix.
        """
        if path.suffix != ".json":
            raise ValueError(f"Expected a .json file path, got: {path}")

        # Explicit UTF-8 (the JSON standard encoding) so non-ASCII descriptions do not depend on the locale.
        path.write_text(self.model_dump_json(indent=4), encoding="utf-8")
        return path

    @model_validator(mode="after")
    def validate_json_schema(self) -> "CommandBuilderSchema":
        """Basic consistency checks between 'json_schema' and the CLI format options.

        Raises:
            ValueError: If 'json_schema' has no 'properties', or if it contains '$defs' (nested
                parameters) while 'nested_parameter_handling' is 'not-supported'.
        """
        if "properties" not in self.json_schema:
            raise ValueError("The 'json_schema' must contain a 'properties' field.")

        if "$defs" in self.json_schema and self.nested_parameter_handling == "not-supported":
            raise ValueError(
                "The 'json_schema' contains '$defs' indicating nested parameters, "
                "but 'nested_parameter_handling' is set to 'not-supported'. "
                "Either remove nested parameters from the schema or change "
                "'nested_parameter_handling' to support nested parameters."
            )
        return self

    @staticmethod
    def from_json_file(path: Path) -> "CommandBuilderSchema":
        """Load and validate a CommandBuilderSchema from a UTF-8 encoded JSON file."""
        return CommandBuilderSchema.model_validate_json(path.read_text(encoding="utf-8"))

    @staticmethod
    def from_function(
        cli_function: Callable[..., Any],
        cli_tool: Optional[str] = "cyclopts",
        ignore_params: tuple[str, ...] = (),
        cmd: Optional[str] = None,
        handle_union_types: bool = True,
        n_positional_args: int = DEFAULT_N_POSITIONAL_ARGS,
        case_conversion: Optional[Literal["none", "kebab"]] = None,
        parameter_prefix: Optional[str] = None,
        nested_parameter_handling: Optional[Literal["dot", "not-supported"]] = None,
        assignment_separator: Optional[Literal["space", "equals"]] = None,
        bool_handling: Optional[Literal["none", "flag-negation"]] = None,
    ) -> "CommandBuilderSchema":
        """Create a CommandBuilderSchema from a CLI function.

        Each CLI format option is resolved in this order: the explicit (non-None) argument,
        then the preset for 'cli_tool', then the module-level DEFAULT_* value.

        Args:
            cli_function: The CLI function to describe (e.g. 'main' in 'scripts/train.py').
            cli_tool: Preset for the CLI format options: "cyclopts", "typer", or None for no
                preset (only explicit arguments and the module defaults are used).
            ignore_params: Parameters of 'cli_function' to leave out of the schema.
            cmd: Command used to launch the script. Defaults to 'python <script path>'.
            handle_union_types: Add display titles to 'anyOf' options in the schema.
            n_positional_args: Number of leading parameters passed as positional arguments.
            case_conversion, parameter_prefix, nested_parameter_handling, assignment_separator,
            bool_handling: Optional overrides of the CLI format options.

        Returns:
            CommandBuilderSchema: The generated schema.

        Raises:
            ValueError: If 'cli_tool' is not "cyclopts", "typer" or None.
        """
        cmd = cmd or f"python {path_of_function(cli_function).as_posix()}"

        # Presets per CLI tool; they stay None when no tool preset is requested.
        preset_case_conversion: Optional[str] = None
        preset_parameter_prefix: Optional[str] = None
        preset_nested_parameter_handling: Optional[str] = None
        preset_assignment_separator: Optional[str] = None
        preset_bool_handling: Optional[str] = None
        if cli_tool == "cyclopts":
            preset_case_conversion = "kebab"
            preset_parameter_prefix = "--"
            preset_nested_parameter_handling = "dot"
            preset_assignment_separator = "space"
            preset_bool_handling = "flag-negation"
        elif cli_tool == "typer":
            preset_case_conversion = "kebab"
            preset_parameter_prefix = "--"
            preset_nested_parameter_handling = "not-supported"
            preset_assignment_separator = "space"
            preset_bool_handling = "flag-negation"
        elif cli_tool is not None:
            raise ValueError(f"Unsupported cli_tool: {cli_tool}")

        def first_not_none(*candidates: Any) -> Any:
            # Unlike 'a or b', this keeps explicit falsy overrides such as parameter_prefix="".
            # The last candidate is always a DEFAULT_* constant, which is never None.
            return next(candidate for candidate in candidates if candidate is not None)

        function_schema = schema_from_cli_function(
            cli_function,
            ignore_params=ignore_params,
            handle_union_types=handle_union_types,
        )
        return CommandBuilderSchema(
            cmd=cmd,
            json_schema=function_schema,
            case_conversion=first_not_none(case_conversion, preset_case_conversion, DEFAULT_CASE_CONVERSION),
            parameter_prefix=first_not_none(parameter_prefix, preset_parameter_prefix, DEFAULT_PARAMETER_PREFIX),
            nested_parameter_handling=first_not_none(
                nested_parameter_handling, preset_nested_parameter_handling, DEFAULT_NESTED_PARAMETER_HANDLING
            ),
            assignment_separator=first_not_none(
                assignment_separator, preset_assignment_separator, DEFAULT_ASSIGNMENT_SEPARATOR
            ),
            bool_handling=first_not_none(bool_handling, preset_bool_handling, DEFAULT_BOOL_HANDLING),
            n_positional_args=n_positional_args,
        )

    def command_args_from_form_data(self, form_data: Dict[str, Any], include_cmd: bool = True) -> List[str]:
        """Convert form data dictionary to command line arguments list.

        Args:
            form_data (Dict[str, Any]): The form data dictionary (parameter name -> value). Nested
                dictionaries represent nested parameters, e.g. {"optimizer": {"lr": 0.01}}.
            include_cmd (bool): If True, the whitespace-split 'cmd' is placed before the arguments.

        Returns:
            List[str]: The command line arguments list.

        Raises:
            ValueError: If nested parameters are found while 'nested_parameter_handling' is
                'not-supported', or if a CLI format option has an unsupported value.
        """
        if self.nested_parameter_handling == "dot":
            # Flatten nested parameters ({"optimizer": {"lr": 0.01}} -> {"optimizer.lr": 0.01}).
            # 'flatten_dict.flatten' yields tuple keys by default; they are joined with '.' here.
            flat_form_data = flatten_dict.flatten(form_data)
            form_data = {".".join(keys): value for keys, value in flat_form_data.items()}
        elif self.nested_parameter_handling == "not-supported":
            if any(isinstance(value, dict) for value in form_data.values()):
                raise ValueError("Nested parameters are not supported, but found nested parameter in form data.")
        else:
            raise ValueError(f"Unsupported nested_parameter_handling: {self.nested_parameter_handling}")

        # 'split()' collapses repeated whitespace, so "python  train.py" yields no empty-string arguments.
        cmd_args: List[str] = self.cmd.split() if include_cmd else []

        for position, (name, value) in enumerate(form_data.items()):
            # The first 'n_positional_args' entries (in dict order) are passed positionally.
            cmd_args.extend(
                name_and_value_as_command_line_arguments(
                    name=name,
                    value=value,
                    is_positional=position < self.n_positional_args,
                    case_conversion=self.case_conversion,
                    parameter_prefix=self.parameter_prefix,
                    assignment_separator=self.assignment_separator,
                    bool_handling=self.bool_handling,
                )
            )
        return cmd_args
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def name_and_value_as_command_line_arguments(
    name: str,
    value: Any,
    is_positional: bool,
    case_conversion: Literal["none", "kebab"],
    parameter_prefix: str,
    assignment_separator: Literal["space", "equals"],
    bool_handling: Literal["none", "flag-negation"],
) -> List[str]:
    """Format one parameter name/value pair as command line argument(s).

    Args:
        name: Parameter name (e.g. 'batch_size' or, after flattening, 'optimizer.lr').
        value: Parameter value.
        is_positional: If True, only the value is emitted (no name, prefix or flag handling).
        case_conversion: "kebab" converts '_' to '-' in the name; "none" keeps it.
        parameter_prefix: Prefix put in front of the name, e.g. '--'.
        assignment_separator: "space" -> ['--name', 'value']; "equals" -> ['--name=value'].
        bool_handling: "flag-negation" -> '--name' for True and '--no-name' for False;
            "none" -> booleans are formatted like any other value.

    Returns:
        List[str]: The argument(s) for this parameter.

    Raises:
        ValueError: If 'case_conversion', 'assignment_separator' or (for booleans) 'bool_handling'
            has an unsupported value.
    """
    # Uses 'repr' (instead of 'str') to convert values. This is important when handling string parameter
    # values that include space. E.g. this '--some_str Some String' would not work and
    # should be "--some_str 'Some String'" to be parsed correctly by the CLI.
    value_str = repr(value)

    if case_conversion == "kebab":
        name = snake_case_to_kebab_case(name)
    elif case_conversion != "none":
        raise ValueError(f"Unsupported case_conversion: {case_conversion}")

    # Positional arguments are a bare value. This check must come before the boolean flag
    # handling, otherwise a positional bool would be emitted as an option flag ('--flag').
    if is_positional:
        return [value_str]

    if isinstance(value, bool):
        if bool_handling == "flag-negation":
            flag_name = name if value else f"no-{name}"  # E.g. '--no-flag' for False
            return [f"{parameter_prefix}{flag_name}"]
        if bool_handling != "none":
            raise ValueError(f"Unsupported bool_handling: {bool_handling}")

    option = f"{parameter_prefix}{name}"
    if assignment_separator == "equals":  # Combined into one arg
        return [f"{option}={value_str}"]
    if assignment_separator == "space":  # Separated into two args
        return [option, value_str]
    raise ValueError(f"Unsupported assignment_separator: {assignment_separator}")
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def path_of_function(cli_function: Callable[..., Any]) -> Path:
    """Get the path of the script that defines 'cli_function', relative to the current working directory.

    The script path is read from the function's module globals ('__file__'). Both the script path
    and the working directory are resolved first, so a relative '__file__' works too.

    Raises:
        KeyError: If the function's globals have no '__file__' (e.g. defined interactively).
        ValueError: If the script is not located inside the current working directory.
    """
    path_script = Path(cli_function.__globals__["__file__"]).resolve()
    cwd = Path.cwd().resolve()
    try:
        return path_script.relative_to(cwd)
    except ValueError as err:
        raise ValueError(
            f"The script '{path_script}' defining '{cli_function.__name__}' is not located inside the "
            f"current working directory '{cwd}'. Run from the project root or pass paths explicitly."
        ) from err
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def simulate_form_data(function: Callable[..., Any], user_args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simulate the data a user would submit through the Command Builder form.

    'user_args' is validated against a Pydantic model generated from 'function', and the
    validated values are returned as a JSON-compatible dictionary.

    The function is mostly used for testing and validation purposes.
    """
    args_model = pydantic_model_from_cli_function(function)
    validated_args = args_model(**user_args)  # raises a pydantic ValidationError on invalid input
    return validated_args.model_dump(mode="json")
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def schema_from_cli_function(
    func: Callable[..., Any],
    ignore_params: tuple[str, ...] = (),
    handle_union_types: bool = True,
) -> Dict[str, Any]:
    """
    Generate a JSON schema from a CLI function signature.

    A Pydantic model is built from the parameters of 'func' (see 'pydantic_model_from_cli_function'),
    keeping type annotations, default values, and descriptions from Annotated types and the
    function docstring. The JSON schema of that model is returned.

    Args:
        func (Callable[..., Any]): The function to generate the schema from.
        ignore_params (tuple[str, ...], optional): Parameters to ignore when generating the schema. Defaults to ().
        handle_union_types (bool, optional): If True, add display titles to the options of 'anyOf'
            (Union/Optional) parameters via 'convert_union_types_in_schema'. Defaults to True.

    Returns:
        Dict[str, Any]: The JSON schema describing the function's parameters.
    """

    pydantic_model: Type[BaseModel] = pydantic_model_from_cli_function(func, ignore_params=ignore_params)
    function_schema = pydantic_model.model_json_schema()
    if handle_union_types:
        function_schema = convert_union_types_in_schema(function_schema)
    return function_schema
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def convert_union_types_in_schema(schema: Dict) -> Dict:
    """
    Add display titles to the options of 'anyOf' parameters (typically from 'Union[...]' or
    'Optional[...]') in a JSON schema to improve form readability.

    The input schema is not modified; a converted copy is returned. Existing titles are kept,
    and 'anyOf' lists nested at any depth (including inside 'anyOf' options) are converted too.

    Example 1:
        The python type:
            Optional[int]

        Will be represented in the schema as:
            {
                "anyOf": [
                    {"type": "integer"},
                    {"type": "null"}
                ]
            }
        and is converted to:
            {
                "anyOf": [
                    {"type": "integer", "title": "Select 'Integer'"},
                    {"type": "null", "title": "Select None"}
                ],
            }

    Example 2:
        Python type:
            float_or_int: Annotated[float | int, "A float or int value"] = 3.14,
        Schema before conversion:
            {"anyOf": [{"type": "number"}, {"type": "integer"}]}
        Schema after conversion:
            {
                "anyOf": [
                    {"type": "number", "title": "Select 'Number'"},
                    {"type": "integer", "title": "Select 'Integer'"}
                ],
            }

    Args:
        schema (Dict): The JSON schema dictionary.

    Returns:
        Dict: A copy of the JSON schema dictionary with titles added to 'anyOf' options.
    """

    def type_to_title(type_name: str) -> str:
        """Convert a JSON schema type to a display title."""
        if type_name == "null":
            return "Select None"
        return f"Select '{type_name.capitalize()}'"

    def add_title_to_option(option: Any) -> Any:
        """Convert one 'anyOf' option (recursing into it first) and add a title to typed options."""
        converted_option = add_titles_to_anyof(option)
        if isinstance(converted_option, dict) and "type" in converted_option and "title" not in converted_option:
            converted_option["title"] = type_to_title(converted_option["type"])
        return converted_option

    def add_titles_to_anyof(obj: Any) -> Any:
        """Recursively copy 'obj', adding titles to the options of every 'anyOf' list."""
        if isinstance(obj, dict):
            converted: Dict[str, Any] = {}
            for key, value in obj.items():
                if key == "anyOf" and isinstance(value, list):
                    converted[key] = [add_title_to_option(option) for option in value]
                else:
                    converted[key] = add_titles_to_anyof(value)
            return converted
        if isinstance(obj, list):
            return [add_titles_to_anyof(item) for item in obj]
        return obj

    return add_titles_to_anyof(schema)
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def pydantic_model_from_cli_function(
    func: Callable[..., Any],
    *,
    model_name: Optional[str] = None,
    ignore_params: tuple[str, ...] = (),
    forbid_varargs: bool = True,
    forbid_positional_only: bool = True,
    use_docstring_as_description: bool = True,
    config: Optional[ConfigDict] = None,
) -> type[BaseModel]:
    """
    Create a Pydantic BaseModel from a function signature.

    Preserves Annotated[...] metadata (e.g. Field constraints), because the
    parameter annotation is forwarded into pydantic.create_model().

    Args:
        func: The function whose parameters become model fields.
        model_name: Name of the generated model. Defaults to '<function name>Args'.
        ignore_params: Parameter names to leave out of the model.
        forbid_varargs: Raise TypeError for '*args' / '**kwargs' parameters.
        forbid_positional_only: Raise TypeError for positional-only parameters.
        use_docstring_as_description: Use the function docstring as the model docstring.
        config: Optional pydantic ConfigDict for the generated model.

    Returns:
        type[BaseModel]: The generated model class.

    Raises:
        TypeError: For forbidden parameter kinds or parameters without a type annotation.

    Notes:
        - Parameters must be type-annotated. String annotations (e.g. from
          'from __future__ import annotations' in the CLI script) are evaluated
          in the function's globals.
        - Required parameters become required model fields.
        - Defaults become model defaults.
        - Missing parameter descriptions are taken from the function docstring.
    """
    # 'eval_str=True' turns string annotations into real types; without it, 'get_origin' cannot
    # detect 'Annotated[...]' and pydantic cannot resolve types defined in the CLI script.
    sig = inspect.signature(func, eval_str=True)

    params = [param for param in sig.parameters.values() if param.name not in ignore_params]

    # Extract docstring parameter descriptions and add them to Field(...) if not already present.
    function_docstring = (func.__doc__ or "").strip()
    parsed_doc_string = docstring_parser.parse(function_docstring)
    docstring_params = {p.arg_name: p.description for p in parsed_doc_string.params}

    # Collect fields as: name -> (annotation, default)
    fields: dict[str, tuple[Any, Any]] = {}
    for param in params:
        # Reject both '*args' and '**kwargs' arguments
        if forbid_varargs and param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            raise TypeError(
                f"{func.__name__}: varargs (*args/**kwargs) are not supported for model generation "
                f"(parameter '{param.name}')"
            )
        # Reject positional-only parameters. E.g. for 'def func(var0, /, var1)' - 'var0' is positional-only.
        if forbid_positional_only and param.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise TypeError(
                f"{func.__name__}: positional-only parameters are not supported for model generation "
                f"(parameter '{param.name}')"
            )
        # Ensure parameter has a type annotation 'param: int'
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(
                f"{func.__name__}: parameter '{param.name}' must have a type annotation ('{param.name}: [TYPE]'). "
                f"An example of this would be '{param.name}: int' or '{param.name}: Optional[str] = None'."
            )

        base_type, field_info = function_param_as_type_and_field_info(param)

        default = derive_default_value(field_info, param)
        if field_info.description is None and param.name in docstring_params:
            # NOTE(review): the FieldInfo is rebuilt from its 'attributes' only; confirm that
            # constraint metadata (e.g. ge/le) held in 'metadata' is not lost by this rebuild.
            field_info_attributes = field_info.asdict()["attributes"]
            field_info_attributes["description"] = docstring_params[param.name]
            field_info = Field(**field_info_attributes)
        if not isinstance(field_info, FieldInfo):
            raise TypeError(f"{func.__name__}: expected FieldInfo in Annotated metadata (parameter '{param.name}')")
        annotation = Annotated[base_type, field_info]  # type: ignore[valid-type]
        fields[param.name] = (annotation, default)

    name = model_name or f"{func.__name__}Args"

    model_kwargs: dict[str, Any] = {}
    if config is not None:
        model_kwargs["__config__"] = config
    if use_docstring_as_description and function_docstring:
        model_kwargs["__doc__"] = inspect.getdoc(func)

    return create_model(name, **model_kwargs, **fields)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def snake_case_to_kebab_case(s: str) -> str:
    """Convert 'snake_case' to 'kebab-case' by turning every underscore into a hyphen."""
    return "-".join(s.split("_"))
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def derive_default_value(field_info: FieldInfo, param: inspect.Parameter) -> Any:
    """Determine the default value of a function parameter for model generation.

    The default value of an argument can be defined in two ways for Annotated types:
        1) After the equal sign in the function signature. Example: "param: int = 5"
        2) In the Field(...) info inside the Annotated metadata. Example: "param: Annotated[int, Field(default=5)]"

    Returns:
        The default value, or '...' (Ellipsis) when no default is defined, which marks the
        field as required in 'pydantic.create_model'.

    Raises:
        ValueError: If both ways define a default and the two values differ.
    """
    default_from_equal_sign = param.default
    if default_from_equal_sign is inspect.Parameter.empty:
        default_from_equal_sign = PydanticUndefined
    default_from_field_info = field_info.default

    has_equal_sign_default = default_from_equal_sign is not PydanticUndefined
    has_field_info_default = default_from_field_info is not PydanticUndefined

    if not has_equal_sign_default and not has_field_info_default:
        return ...  # No default defined

    if has_equal_sign_default and has_field_info_default and default_from_field_info != default_from_equal_sign:
        raise ValueError(
            f"The default value for '{param.name}' has been defined in two ways, and they are conflicting: "
            f"'Field(default={default_from_field_info!r})' and ' = {default_from_equal_sign!r}'"
        )
    # The value after the equal sign takes precedence over the Field(...) default.
    return default_from_equal_sign if has_equal_sign_default else default_from_field_info
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def is_field_info(obj: Any) -> bool:
    """Return True if 'obj' looks like a pydantic 'FieldInfo' (matched by class name only)."""
    if not hasattr(obj, "__class__"):
        return False
    return obj.__class__.__name__ == "FieldInfo"
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def is_cyclopts_parameter(obj: Any) -> bool:
    """Return True if 'obj' is a 'cyclopts.Parameter'.

    The check uses the class name and module, so 'cyclopts' is not needed as a dependency.
    """
    if not hasattr(obj, "__class__"):
        return False
    klass = obj.__class__
    return klass.__name__ == "Parameter" and "cyclopts" in klass.__module__
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def is_typer(obj: Any) -> bool:
    """Return True if 'obj' is a 'typer.Option' info object.

    The check uses the class name and module, so 'typer' is not needed as a dependency.

    Raises:
        TypeError: If 'obj' is a 'typer.Argument' (a positional argument), which is not supported.
    """
    if not hasattr(obj, "__class__"):
        return False
    klass_name = obj.__class__.__name__
    if klass_name not in ("ArgumentInfo", "OptionInfo") or "typer" not in obj.__class__.__module__:
        return False
    if klass_name == "ArgumentInfo":
        raise TypeError(
            "typer.Argument is not supported for schema generation, "
            "as it represents a positional argument. Use typer.Option(...) instead."
        )
    return True
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def function_param_as_type_and_field_info(p: inspect.Parameter) -> Tuple[Type, FieldInfo]:
    """Split a parameter annotation into a type and its primary pydantic FieldInfo.

    For a plain annotation ('x: int') the annotation itself and an empty 'Field()' are returned.
    For 'Annotated[T, ...]' the metadata is first normalized with 'convert_annotations_to_field_info'
    (strings / cyclopts / typer help texts become 'Field(description=...)'). The first FieldInfo
    is returned separately. All other metadata items (further FieldInfos, constraints such as
    'annotated_types.Gt', validators, ...) are kept on the returned type as 'Annotated[T, ...]' so
    they still reach pydantic, instead of being dropped.

    Returns:
        Tuple[Type, FieldInfo]: The (possibly Annotated) type and the primary FieldInfo
            (an empty 'Field()' if none was found).
    """
    if get_origin(p.annotation) is not Annotated:
        return p.annotation, Field()

    inner_type, *raw_metadata = get_args(p.annotation)
    metadata = convert_annotations_to_field_info(tuple(raw_metadata))

    field_info: Optional[FieldInfo] = None
    remaining_metadata: List[Any] = []
    for ann in metadata:
        if field_info is None and is_field_info(ann):
            field_info = ann  # The first FieldInfo found is the primary one
        else:
            remaining_metadata.append(ann)

    # 'Annotated[(T, *extras)]' is the tuple-subscript form of 'Annotated[T, extra1, extra2, ...]'.
    base_type = Annotated[(inner_type, *remaining_metadata)] if remaining_metadata else inner_type
    return base_type, (field_info if field_info is not None else Field())
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def convert_annotations_to_field_info(annotations: Tuple[Any, ...]) -> List[Any]:
    """Normalize 'Annotated[...]' metadata so that descriptions are pydantic 'Field's.

    - a plain string becomes 'Field(description=<string>)'
    - a 'cyclopts.Parameter' or 'typer.Option' becomes 'Field(description=<its help>)'
    - any other item is returned unchanged

    Raises:
        TypeError: Propagated from 'is_typer' when a 'typer.Argument' is encountered.
    """

    def _normalize(item: Any) -> Any:
        if isinstance(item, str):
            return Field(description=item)
        if is_cyclopts_parameter(item) or is_typer(item):
            # The help text is read from a 'help' attribute when present.
            return Field(description=getattr(item, "help", PydanticUndefined))
        return item

    return [_normalize(item) for item in annotations]
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hafnia
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.2
|
|
4
4
|
Summary: Python SDK for communication with Hafnia platform.
|
|
5
5
|
Author-email: Milestone Systems <hafniaplatform@milestone.dk>
|
|
6
6
|
License-File: LICENSE
|
|
7
7
|
Requires-Python: >=3.10
|
|
8
8
|
Requires-Dist: boto3>=1.35.91
|
|
9
9
|
Requires-Dist: click>=8.1.8
|
|
10
|
+
Requires-Dist: docstring-parser>=0.17.0
|
|
10
11
|
Requires-Dist: emoji>=2.14.1
|
|
11
12
|
Requires-Dist: flatten-dict>=0.4.2
|
|
12
13
|
Requires-Dist: keyring>=25.6.0
|
|
@@ -33,6 +33,7 @@ hafnia/dataset/primitives/primitive.py,sha256=Wvby0sCGgYj8ec39PLcHsmip5VKL96ZSCz
|
|
|
33
33
|
hafnia/dataset/primitives/segmentation.py,sha256=hnXIUklkuDMxYkUaff1bgRTcXI_2b32RIbobxR3ejzk,2017
|
|
34
34
|
hafnia/dataset/primitives/utils.py,sha256=3gT1as-xXEj8CamoIuBb9gQwUN9Ae9qnqtqF_uEe0zo,1993
|
|
35
35
|
hafnia/experiment/__init__.py,sha256=OEFE6HqhO5zcTCLZcPcPVjIg7wMFFnvZ1uOtAVhRz7M,85
|
|
36
|
+
hafnia/experiment/command_builder.py,sha256=9_Of8WsJyyaCSK4C0vJMLVaUU6qEhnVNaIUAEHc80rI,27580
|
|
36
37
|
hafnia/experiment/hafnia_logger.py,sha256=BHIOLAds_3JxT0cev_ikUH0XQVIxBJTkcBSx2Q_SIk0,10894
|
|
37
38
|
hafnia/platform/__init__.py,sha256=L_Q7CNpsJ0HMNPy_rLlLK5RhmuCU7IF4BchxKv6amYc,782
|
|
38
39
|
hafnia/platform/builder.py,sha256=kUEuj5-qtL1uk5v2tUvOCREn5yV-G4Fr6F31haIAb5E,5808
|
|
@@ -55,8 +56,8 @@ hafnia_cli/keychain.py,sha256=bNyjjULVQu7kV338wUC65UvbCwmSGOmEjKWPLIQjT0k,2555
|
|
|
55
56
|
hafnia_cli/profile_cmds.py,sha256=yTyOsPsUssLCzFIxURkxbKrFEhYIVDlUC0G2s5Uks-U,3476
|
|
56
57
|
hafnia_cli/runc_cmds.py,sha256=7P5TjF6KA9K4OKPG1qC_0gteXfLJbXlA858WWrosoGQ,5098
|
|
57
58
|
hafnia_cli/trainer_package_cmds.py,sha256=hUBc6gCMV28fcAA0xQdXKL1z-a3aL9lMWcVqjvHO1Uo,2326
|
|
58
|
-
hafnia-0.5.
|
|
59
|
-
hafnia-0.5.
|
|
60
|
-
hafnia-0.5.
|
|
61
|
-
hafnia-0.5.
|
|
62
|
-
hafnia-0.5.
|
|
59
|
+
hafnia-0.5.2.dist-info/METADATA,sha256=U8Qy_hlNHR-hfgFqsvAkM8dfzS28v9nRI9ne5oXA87o,19312
|
|
60
|
+
hafnia-0.5.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
61
|
+
hafnia-0.5.2.dist-info/entry_points.txt,sha256=j2jsj1pqajLAiSOnF7sq66A3d1SVeHPKVTVyIFzipSA,52
|
|
62
|
+
hafnia-0.5.2.dist-info/licenses/LICENSE,sha256=wLZw1B7_mod_CO1H8LXqQgfqlWD6QceJR8--LJYRZGE,1078
|
|
63
|
+
hafnia-0.5.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|