kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_generator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py
CHANGED
@@ -1,12 +1,23 @@
+"""
+See our docs for details about our datamodel: https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html
+"""
+
 from __future__ import annotations
 
 import json
+import math
+import random
 from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Dict, List, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union
 
 import jsonschema
 import jsonschema.exceptions
-from pydantic import BaseModel, Field, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    ValidationInfo,
+    model_validator,
+)
 from typing_extensions import Self
 
 from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
@@ -14,6 +25,8 @@ from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
 from .basemodel import (
     ID_FIELD,
     ID_TYPE,
+    NAME_FIELD,
+    SHORT_NAME_FIELD,
     KilnBaseModel,
     KilnParentedModel,
     KilnParentModel,
@@ -39,27 +52,23 @@ __all__ = [
     "TaskOutputRatingType",
     "TaskRequirement",
     "TaskDeterminism",
+    "strict_mode",
+    "set_strict_mode",
 ]
 
 
-#
-#
-
+# We want to be hard on ourselves for data completeness generated by the Kiln App, but don't want to make it hard for users to use the datamodel/library.
+# Strict mode enables extra validations that we want to enforce in the Kiln App (and any other client that wants best practices), but not in the library (unless they opt in).
+_strict_mode: bool = False
 
-# Filename compatible names
-NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
-NAME_FIELD = Field(
-    min_length=1,
-    max_length=120,
-    pattern=NAME_REGEX,
-    description="A name for this entity.",
-)
-SHORT_NAME_FIELD = Field(
-    min_length=1,
-    max_length=32,
-    pattern=NAME_REGEX,
-    description="A name for this entity",
-)
+
+def strict_mode() -> bool:
+    return _strict_mode
+
+
+def set_strict_mode(value: bool) -> None:
+    global _strict_mode
+    _strict_mode = value
 
 
 class Priority(IntEnum):
@@ -137,8 +146,9 @@ class TaskOutput(KilnBaseModel):
     output: str = Field(
         description="The output of the task. JSON formatted for structured output, plaintext for unstructured output."
     )
-    source: DataSource = Field(
-        description="The source of the output: human or synthetic."
+    source: DataSource | None = Field(
+        description="The source of the output: human or synthetic.",
+        default=None,
     )
     rating: TaskOutputRating | None = Field(
         default=None, description="The rating of the output"
@@ -155,6 +165,83 @@ class TaskOutput(KilnBaseModel):
             raise ValueError(f"Output does not match task output schema: {e}")
         return self
 
+    @model_validator(mode="after")
+    def validate_output_source(self, info: ValidationInfo) -> Self:
+        # In strict mode, and when not loaded from file, we validate that the output source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.source is None:
+            raise ValueError("Output source is required when strict mode is enabled")
+        return self
+
+
+class FineTuneStatusType(str, Enum):
+    """
+    The status type of a fine-tune (running, completed, failed, etc).
+    """
+
+    unknown = "unknown"  # server error
+    pending = "pending"
+    running = "running"
+    completed = "completed"
+    failed = "failed"
+
+
+class Finetune(KilnParentedModel):
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the fine-tune for you and your team. Not used in training.",
+    )
+    provider: str = Field(
+        description="The provider to use for the fine-tune (e.g. 'openai')."
+    )
+    base_model_id: str = Field(
+        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
+    )
+    provider_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.",
+    )
+    fine_tune_model_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.",
+    )
+    dataset_split_id: str = Field(
+        description="The ID of the dataset split to use for this fine-tune.",
+    )
+    train_split_name: str = Field(
+        default="train",
+        description="The name of the training split to use for this fine-tune.",
+    )
+    validation_split_name: str | None = Field(
+        default=None,
+        description="The name of the validation split to use for this fine-tune. Optional.",
+    )
+    parameters: dict[str, str | int | float | bool] = Field(
+        default={},
+        description="The parameters to use for this fine-tune. These are provider-specific.",
+    )
+    system_message: str = Field(
+        description="The system message to use for this fine-tune.",
+    )
+    latest_status: FineTuneStatusType = Field(
+        default=FineTuneStatusType.unknown,
+        description="The latest known status of this fine-tune. Not updated in real time.",
+    )
+    properties: Dict[str, str | int | float] = Field(
+        default={},
+        description="Properties of the fine-tune. Different providers may use different properties.",
+    )
+
+    def parent_task(self) -> Task | None:
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
 
 class DataSourceType(str, Enum):
     """
@@ -277,8 +364,8 @@ class TaskRun(KilnParentedModel):
     input: str = Field(
         description="The inputs to the task. JSON formatted for structured input, plaintext for unstructured input."
     )
-    input_source: DataSource = Field(
-        description="The source of the input: human or synthetic."
+    input_source: DataSource | None = Field(
+        default=None, description="The source of the input: human or synthetic."
     )
 
     output: TaskOutput = Field(description="The output of the task run.")
@@ -343,6 +430,172 @@ class TaskRun(KilnParentedModel):
         )
         return self
 
+    @model_validator(mode="after")
+    def validate_input_source(self, info: ValidationInfo) -> Self:
+        # In strict mode, and when not loaded from file, we validate that input_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.input_source is None:
+            raise ValueError("input_source is required when strict mode is enabled")
+        return self
+
+
+# Define the type alias for clarity
+DatasetFilter = Callable[[TaskRun], bool]
+
+
+def AllDatasetFilter(_: TaskRun) -> bool:
+    return True
+
+
+def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+    if task_run.output is None or task_run.output.rating is None:
+        return False
+    return task_run.output.rating.is_high_quality()
+
+
+class DatasetSplitDefinition(BaseModel):
+    """
+    A definition of a split in a dataset.
+
+    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    percentage: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="The percentage of the dataset that this split represents (between 0 and 1).",
+    )
+
+
+AllSplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="all", percentage=1.0)
+]
+Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+]
+Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.6),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
+
+
+class DatasetSplit(KilnParentedModel):
+    """
+    A collection of task runs, with optional splits (train, test, validation).
+
+    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
+
+    Maintains a list of IDs for each split, to avoid data duplication.
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    splits: list[DatasetSplitDefinition] = Field(
+        default_factory=list,
+        description="The splits in the dataset.",
+    )
+    split_contents: dict[str, list[str]] = Field(
+        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+    )
+
+    @model_validator(mode="after")
+    def validate_split_percentages(self) -> "DatasetSplit":
+        total = sum(split.percentage for split in self.splits)
+        if not math.isclose(total, 1.0, rel_tol=1e-9):
+            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+        return self
+
+    @classmethod
+    def from_task(
+        cls,
+        name: str,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter = AllDatasetFilter,
+        description: str | None = None,
+    ):
+        """
+        Build a dataset split from a task.
+        """
+        split_contents = cls.build_split_contents(task, splits, filter)
+        return cls(
+            parent=task,
+            name=name,
+            description=description,
+            splits=splits,
+            split_contents=split_contents,
+        )
+
+    @classmethod
+    def build_split_contents(
+        cls,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter,
+    ) -> dict[str, list[str]]:
+        valid_ids = []
+        for task_run in task.runs():
+            if filter(task_run):
+                valid_ids.append(task_run.id)
+
+        # Shuffle and split by split percentage
+        random.shuffle(valid_ids)
+        split_contents = {}
+        start_idx = 0
+        remaining_items = len(valid_ids)
+
+        # Handle all splits except the last one
+        for split in splits[:-1]:
+            split_size = round(len(valid_ids) * split.percentage)
+            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+            start_idx += split_size
+            remaining_items -= split_size
+
+        # Last split gets all remaining items (for rounding)
+        if splits:
+            split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+        return split_contents
+
+    def parent_task(self) -> "Task | None":
+        # inline import to avoid circular import
+        from kiln_ai.datamodel import Task
+
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        parent = self.parent_task()
+        if parent is None:
+            raise ValueError("DatasetSplit has no parent task")
+
+        runs = parent.runs()
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
+
 
 class TaskRequirement(BaseModel):
     """
@@ -376,7 +629,11 @@ class TaskDeterminism(str, Enum):
 class Task(
     KilnParentedModel,
     KilnParentModel,
-    parent_of={"runs": TaskRun},
+    parent_of={
+        "runs": TaskRun,
+        "dataset_splits": DatasetSplit,
+        "finetunes": Finetune,
+    },
 ):
     """
     Represents a specific task to be performed, with associated requirements and validation rules.
@@ -416,6 +673,12 @@ class Task(
     def runs(self) -> list[TaskRun]:
         return super().runs()  # type: ignore
 
+    def dataset_splits(self) -> list[DatasetSplit]:
+        return super().dataset_splits()  # type: ignore
+
+    def finetunes(self) -> list[Finetune]:
+        return super().finetunes()  # type: ignore
+
 
 class Project(KilnParentModel, parent_of={"tasks": Task}):
    """
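Taken together, the strict-mode toggle, dataset splits, and `Finetune` records added above compose as follows. A minimal sketch, assuming a `Task` already saved to disk; the `task_path`, split name, fine-tune name, and `base_model_id` values are illustrative, not defaults from the package:

```python
from kiln_ai.datamodel import (
    DatasetSplit,
    Finetune,
    HighRatingDatasetFilter,
    Task,
    Train80Test20SplitDefinition,
    set_strict_mode,
    strict_mode,
)

my_task = Task.load_from_file("my_project/tasks/abc123/task.kiln")  # hypothetical path

# Strict mode is off by default so the library can load imperfect/legacy data.
# The Kiln App opts in, which makes e.g. TaskOutput.source and TaskRun.input_source
# required on newly created (not file-loaded) records.
assert strict_mode() is False
set_strict_mode(True)

# Freeze the task's current runs into reproducible train/test splits,
# keeping only runs whose output rating is high quality.
split = DatasetSplit.from_task(
    name="frozen-v1",  # illustrative name
    task=my_task,
    splits=Train80Test20SplitDefinition,
    filter=HighRatingDatasetFilter,
)
split.save_to_file()

# split_contents maps split names to task-run IDs: {"train": [...], "test": [...]}
print({name: len(ids) for name, ids in split.split_contents.items()})

# Record a fine-tune job against the frozen split.
finetune = Finetune(
    parent=my_task,
    name="my-finetune",                      # illustrative name
    provider="openai",
    base_model_id="gpt-4o-mini-2024-07-18",  # illustrative provider model ID
    dataset_split_id=str(split.id),
    system_message="You are a helpful assistant.",
)
finetune.save_to_file()
```

`validate_split_percentages` rejects splits whose percentages don't sum to 1.0, and `missing_count()` reports IDs persisted in the split that no longer exist among the task's runs.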
kiln_ai/datamodel/basemodel.py
CHANGED
@@ -1,4 +1,6 @@
 import json
+import os
+import re
 import shutil
 import uuid
 from abc import ABCMeta
@@ -6,7 +8,6 @@ from builtins import classmethod
 from datetime import datetime
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -20,12 +21,14 @@ from pydantic import (
     ConfigDict,
     Field,
     ValidationError,
+    ValidationInfo,
     computed_field,
     model_validator,
 )
 from pydantic_core import ErrorDetails
 from typing_extensions import Self
 
+from kiln_ai.datamodel.model_cache import ModelCache
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.formatting import snake_case
 
@@ -39,6 +42,35 @@ T = TypeVar("T", bound="KilnBaseModel")
 PT = TypeVar("PT", bound="KilnParentedModel")
 
 
+# Naming conventions:
+# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
+# 2) Descriptions are for Kiln users to describe/understand the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
+
+# Filename compatible names
+NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
+NAME_FIELD = Field(
+    min_length=1,
+    max_length=120,
+    pattern=NAME_REGEX,
+    description="A name for this entity.",
+)
+SHORT_NAME_FIELD = Field(
+    min_length=1,
+    max_length=32,
+    pattern=NAME_REGEX,
+    description="A name for this entity",
+)
+
+
+def string_to_valid_name(name: str) -> str:
+    # Replace any character not allowed by NAME_REGEX with an underscore
+    valid_name = re.sub(r"[^A-Za-z0-9 _-]", "_", name)
+    # Replace consecutive underscores with a single underscore
+    valid_name = re.sub(r"_+", "_", valid_name)
+    # Remove leading and trailing underscores or whitespace
+    return valid_name.strip("_").strip()
+
+
 class KilnBaseModel(BaseModel):
     """Base model for all Kiln data models with common functionality for persistence and versioning.
 
@@ -58,6 +90,8 @@ class KilnBaseModel(BaseModel):
     created_at: datetime = Field(default_factory=datetime.now)
     created_by: str = Field(default_factory=lambda: Config.shared().user_id)
 
+    _loaded_from_file: bool = False
+
     @computed_field()
     def model_type(self) -> str:
         return self.type_name()
@@ -86,7 +120,7 @@
         return cls.load_from_file(path)
 
     @classmethod
-    def load_from_file(cls: Type[T], path: Path) -> T:
+    def load_from_file(cls: Type[T], path: Path | str) -> T:
         """Load a model instance from a specific file path.
 
         Args:
@@ -97,15 +131,28 @@
 
         Raises:
             ValueError: If the loaded model is not of the expected type or version
+            FileNotFoundError: If the file does not exist
         """
+        if isinstance(path, str):
+            path = Path(path)
+        cached_model = ModelCache.shared().get_model(path, cls)
+        if cached_model is not None:
+            return cached_model
         with open(path, "r") as file:
+            # Modified time of the file, for cache invalidation. Taken from the file descriptor so it's atomic with the read.
+            mtime_ns = os.fstat(file.fileno()).st_mtime_ns
             file_data = file.read()
             # TODO P2 perf: parsing the JSON twice here.
             # Once for model_type, once for model. Can't call model_validate with parsed json because enum types break; they get strings instead of enums.
             parsed_json = json.loads(file_data)
-            m = cls.model_validate_json(file_data, strict=True)
+            m = cls.model_validate_json(
+                file_data,
+                strict=True,
+                context={"loading_from_file": True},
+            )
             if not isinstance(m, cls):
                 raise ValueError(f"Loaded model is not of type {cls.__name__}")
+            m._loaded_from_file = True
             file_data = None
         m.path = path
         if m.v > m.max_schema_version():
@@ -120,8 +167,21 @@
             f"Class: {m.__class__.__name__}, id: {getattr(m, 'id', None)}, path: {path}, "
             f"version: {m.v}, max version: {m.max_schema_version()}"
         )
+        ModelCache.shared().set_model(path, m, mtime_ns)
         return m
 
+    def loaded_from_file(self, info: ValidationInfo | None = None) -> bool:
+        # Two ways of indicating the model was loaded from file:
+        # 1) info.context.get("loading_from_file") -> during actual loading, before we can set _loaded_from_file
+        # 2) self._loaded_from_file -> after loading, set by the loader
+        if (
+            info is not None
+            and info.context is not None
+            and info.context.get("loading_from_file", False)
+        ):
+            return True
+        return self._loaded_from_file
+
     def save_to_file(self) -> None:
         """Save the model instance to a file.
 
@@ -140,6 +200,9 @@
             file.write(json_data)
         # save the path so even if something like name changes, the file doesn't move
         self.path = path
+        # We could save to the cache here, but invalidating will trigger a load on next use.
+        # This ensures everything in the cache was loaded from disk, and the cache perfectly reflects what's on disk.
+        ModelCache.shared().invalidate(path)
 
     def delete(self) -> None:
         if self.path is None:
@@ -148,6 +211,7 @@
         if dir_path is None:
             raise ValueError("Cannot delete model because path is not set")
         shutil.rmtree(dir_path)
+        ModelCache.shared().invalidate(self.path)
         self.path = None
 
     def build_path(self) -> Path | None:
@@ -167,51 +231,44 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
     including parent reference handling and file system organization.
 
     Attributes:
-        parent (KilnBaseModel): Reference to the parent model instance
+        parent (KilnBaseModel): Reference to the parent model instance. Not persisted, just in memory.
     """
 
-
+    # Parent is an in-memory-only reference to the parent. If it's set we use that. If not, we'll try to load it from disk based on the path.
+    # We don't persist the parent reference to disk. See the accessors below for how we make it a clean API (the parent accessor will lazy load from disk).
+    parent: Optional[KilnBaseModel] = Field(default=None, exclude=True)
 
-
-
-
+    def __getattribute__(self, name: str) -> Any:
+        if name == "parent":
+            return self.load_parent()
+        return super().__getattribute__(name)
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        if "parent" in data:
-            self.parent = data["parent"]
+    def cached_parent(self) -> Optional[KilnBaseModel]:
+        return object.__getattribute__(self, "parent")
 
-    @property
-    def parent(self) -> Optional[KilnBaseModel]:
+    def load_parent(self) -> Optional[KilnBaseModel]:
         """Get the parent model instance, loading it from disk if necessary.
 
         Returns:
             Optional[KilnBaseModel]: The parent model instance or None if not set
         """
-        if self._parent is not None:
-            return self._parent
+        cached_parent = self.cached_parent()
+        if cached_parent is not None:
+            return cached_parent
+
         # lazy load parent from path
         if self.path is None:
            return None
-        #
+        # Note: this only works with base_filename. If we ever support custom names, we need to change this.
         parent_path = (
             self.path.parent.parent.parent
             / self.__class__.parent_type().base_filename()
         )
         if parent_path is None:
            return None
-        return self.__class__.parent_type().load_from_file(parent_path)
-
-
-    @parent.setter
-    def parent(self, value: Optional[KilnBaseModel]):
-        if value is not None:
-            expected_parent_type = self.__class__.parent_type()
-            if not isinstance(value, expected_parent_type):
-                raise ValueError(
-                    f"Parent must be of type {expected_parent_type}, but was {type(value)}"
-                )
-        self._parent = value
+        loaded_parent = self.__class__.parent_type().load_from_file(parent_path)
+        self.parent = loaded_parent
+        return loaded_parent
 
     # Dynamically implemented by KilnParentModel method injection
     @classmethod
@@ -225,11 +282,12 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
 
     @model_validator(mode="after")
     def check_parent_type(self) -> Self:
-        if self._parent is not None:
+        cached_parent = self.cached_parent()
+        if cached_parent is not None:
             expected_parent_type = self.__class__.parent_type()
-            if not isinstance(self._parent, expected_parent_type):
+            if not isinstance(cached_parent, expected_parent_type):
                 raise ValueError(
-                    f"Parent must be of type {expected_parent_type}, but was {type(self._parent)}"
+                    f"Parent must be of type {expected_parent_type}, but was {type(cached_parent)}"
                 )
         return self
 
@@ -268,9 +326,7 @@
         )
 
     @classmethod
-    def all_children_of_parent_path(
-        cls: Type[PT], parent_path: Path | None
-    ) -> list[PT]:
+    def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | None):
         if parent_path is None:
             # children are disk based. If not saved, they don't exist
             return []
@@ -292,13 +348,41 @@
             return []
 
         # Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
-        children = []
         for child_file in relationship_folder.glob(f"**/{cls.base_filename()}"):
-            child = cls.load_from_file(child_file)
-            children.append(child)
+            yield child_file
 
+    @classmethod
+    def all_children_of_parent_path(
+        cls: Type[PT], parent_path: Path | None
+    ) -> list[PT]:
+        children = []
+        for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
+            children.append(cls.load_from_file(child_path))
         return children
 
+    @classmethod
+    def from_id_and_parent_path(
+        cls: Type[PT], id: str, parent_path: Path | None
+    ) -> PT | None:
+        """
+        Fast search by ID using the cache. Avoids the model_copy overhead on all but the exact match.
+
+        Uses the cache, so it's still slow on first load.
+        """
+        if parent_path is None:
+            return None
+
+        # Note: we're using the in-file ID. We could make this faster using the path ID if this becomes a perf bottleneck, but it's better to have one source of truth.
+        for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
+            child_id = ModelCache.shared().get_model_id(child_path, cls)
+            if child_id == id:
+                return cls.load_from_file(child_path)
+            if child_id is None:
+                child = cls.load_from_file(child_path)
+                if child.id == id:
+                    return child
+        return None
+
 
     # Parent create methods for all child relationships
     # You must pass in parent_of in the subclass definition, defining the child relationships