lightningrod-ai 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightningrod/__init__.py +66 -0
- lightningrod/_display.py +204 -0
- lightningrod/_errors.py +67 -0
- lightningrod/_generated/__init__.py +8 -0
- lightningrod/_generated/api/__init__.py +1 -0
- lightningrod/_generated/api/datasets/__init__.py +1 -0
- lightningrod/_generated/api/datasets/create_dataset_datasets_post.py +133 -0
- lightningrod/_generated/api/datasets/get_dataset_datasets_dataset_id_get.py +168 -0
- lightningrod/_generated/api/datasets/get_dataset_samples_datasets_dataset_id_samples_get.py +209 -0
- lightningrod/_generated/api/datasets/upload_samples_datasets_dataset_id_samples_post.py +190 -0
- lightningrod/_generated/api/file_sets/__init__.py +1 -0
- lightningrod/_generated/api/file_sets/add_file_to_set_filesets_file_set_id_files_post.py +190 -0
- lightningrod/_generated/api/file_sets/create_file_set_filesets_post.py +174 -0
- lightningrod/_generated/api/file_sets/get_file_set_filesets_file_set_id_get.py +168 -0
- lightningrod/_generated/api/file_sets/list_file_sets_filesets_get.py +173 -0
- lightningrod/_generated/api/file_sets/list_files_in_set_filesets_file_set_id_files_get.py +209 -0
- lightningrod/_generated/api/files/__init__.py +1 -0
- lightningrod/_generated/api/files/create_file_upload_files_post.py +174 -0
- lightningrod/_generated/api/open_ai_compatible/__init__.py +1 -0
- lightningrod/_generated/api/open_ai_compatible/chat_completions_openai_chat_completions_post.py +174 -0
- lightningrod/_generated/api/organizations/__init__.py +1 -0
- lightningrod/_generated/api/organizations/get_balance_organizations_balance_get.py +131 -0
- lightningrod/_generated/api/samples/__init__.py +1 -0
- lightningrod/_generated/api/samples/validate_sample_samples_validate_post.py +174 -0
- lightningrod/_generated/api/transform_jobs/__init__.py +1 -0
- lightningrod/_generated/api/transform_jobs/cost_estimation_transform_jobs_cost_estimation_post.py +174 -0
- lightningrod/_generated/api/transform_jobs/create_transform_job_transform_jobs_post.py +174 -0
- lightningrod/_generated/api/transform_jobs/get_transform_job_metrics_transform_jobs_job_id_metrics_get.py +172 -0
- lightningrod/_generated/api/transform_jobs/get_transform_job_transform_jobs_job_id_get.py +168 -0
- lightningrod/_generated/client.py +268 -0
- lightningrod/_generated/errors.py +16 -0
- lightningrod/_generated/models/__init__.py +147 -0
- lightningrod/_generated/models/answer_type.py +129 -0
- lightningrod/_generated/models/answer_type_enum.py +11 -0
- lightningrod/_generated/models/balance_response.py +61 -0
- lightningrod/_generated/models/chat_completion_request.py +216 -0
- lightningrod/_generated/models/chat_completion_response.py +146 -0
- lightningrod/_generated/models/chat_message.py +69 -0
- lightningrod/_generated/models/choice.py +97 -0
- lightningrod/_generated/models/create_dataset_response.py +61 -0
- lightningrod/_generated/models/create_file_set_file_request.py +101 -0
- lightningrod/_generated/models/create_file_set_file_request_metadata_type_0.py +46 -0
- lightningrod/_generated/models/create_file_set_request.py +83 -0
- lightningrod/_generated/models/create_file_upload_request.py +91 -0
- lightningrod/_generated/models/create_file_upload_response.py +165 -0
- lightningrod/_generated/models/create_file_upload_response_metadata_type_0.py +46 -0
- lightningrod/_generated/models/create_transform_job_request.py +312 -0
- lightningrod/_generated/models/dataset_metadata.py +69 -0
- lightningrod/_generated/models/estimate_cost_request.py +243 -0
- lightningrod/_generated/models/estimate_cost_response.py +117 -0
- lightningrod/_generated/models/event_usage_summary.py +80 -0
- lightningrod/_generated/models/file_set.py +128 -0
- lightningrod/_generated/models/file_set_file.py +203 -0
- lightningrod/_generated/models/file_set_file_metadata_type_0.py +57 -0
- lightningrod/_generated/models/file_set_query_seed_generator.py +136 -0
- lightningrod/_generated/models/file_set_seed_generator.py +126 -0
- lightningrod/_generated/models/filter_criteria.py +83 -0
- lightningrod/_generated/models/forward_looking_question.py +130 -0
- lightningrod/_generated/models/forward_looking_question_generator.py +217 -0
- lightningrod/_generated/models/gdelt_seed_generator.py +103 -0
- lightningrod/_generated/models/http_validation_error.py +79 -0
- lightningrod/_generated/models/job_usage.py +185 -0
- lightningrod/_generated/models/job_usage_by_step_type_0.py +59 -0
- lightningrod/_generated/models/label.py +143 -0
- lightningrod/_generated/models/list_file_set_files_response.py +113 -0
- lightningrod/_generated/models/list_file_sets_response.py +75 -0
- lightningrod/_generated/models/llm_model_usage_summary.py +98 -0
- lightningrod/_generated/models/mock_transform_config.py +243 -0
- lightningrod/_generated/models/mock_transform_config_metadata_additions.py +46 -0
- lightningrod/_generated/models/model_config.py +316 -0
- lightningrod/_generated/models/model_source_type.py +16 -0
- lightningrod/_generated/models/news_context.py +82 -0
- lightningrod/_generated/models/news_context_generator.py +127 -0
- lightningrod/_generated/models/news_seed_generator.py +220 -0
- lightningrod/_generated/models/paginated_samples_response.py +113 -0
- lightningrod/_generated/models/pipeline_metrics_response.py +99 -0
- lightningrod/_generated/models/question.py +74 -0
- lightningrod/_generated/models/question_and_label_generator.py +217 -0
- lightningrod/_generated/models/question_generator.py +217 -0
- lightningrod/_generated/models/question_pipeline.py +417 -0
- lightningrod/_generated/models/question_renderer.py +123 -0
- lightningrod/_generated/models/rag_context.py +82 -0
- lightningrod/_generated/models/response_message.py +69 -0
- lightningrod/_generated/models/rollout.py +130 -0
- lightningrod/_generated/models/rollout_generator.py +139 -0
- lightningrod/_generated/models/rollout_parsed_output_type_0.py +46 -0
- lightningrod/_generated/models/sample.py +323 -0
- lightningrod/_generated/models/sample_meta.py +46 -0
- lightningrod/_generated/models/seed.py +135 -0
- lightningrod/_generated/models/step_cost_breakdown.py +109 -0
- lightningrod/_generated/models/transform_job.py +268 -0
- lightningrod/_generated/models/transform_job_status.py +11 -0
- lightningrod/_generated/models/transform_step_metrics_response.py +131 -0
- lightningrod/_generated/models/transform_type.py +25 -0
- lightningrod/_generated/models/upload_samples_request.py +75 -0
- lightningrod/_generated/models/upload_samples_response.py +69 -0
- lightningrod/_generated/models/usage.py +77 -0
- lightningrod/_generated/models/usage_summary.py +102 -0
- lightningrod/_generated/models/usage_summary_events.py +59 -0
- lightningrod/_generated/models/usage_summary_llm_by_model.py +59 -0
- lightningrod/_generated/models/validate_sample_response.py +69 -0
- lightningrod/_generated/models/validation_error.py +90 -0
- lightningrod/_generated/models/web_search_labeler.py +120 -0
- lightningrod/_generated/py.typed +1 -0
- lightningrod/_generated/types.py +54 -0
- lightningrod/client.py +48 -0
- lightningrod/datasets/__init__.py +4 -0
- lightningrod/datasets/client.py +174 -0
- lightningrod/datasets/dataset.py +255 -0
- lightningrod/files/__init__.py +0 -0
- lightningrod/files/client.py +58 -0
- lightningrod/filesets/__init__.py +0 -0
- lightningrod/filesets/client.py +106 -0
- lightningrod/organization/__init__.py +0 -0
- lightningrod/organization/client.py +17 -0
- lightningrod/py.typed +0 -0
- lightningrod/transforms/__init__.py +0 -0
- lightningrod/transforms/client.py +154 -0
- lightningrod_ai-0.1.6.dist-info/METADATA +122 -0
- lightningrod_ai-0.1.6.dist-info/RECORD +123 -0
- lightningrod_ai-0.1.6.dist-info/WHEEL +5 -0
- lightningrod_ai-0.1.6.dist-info/licenses/LICENSE +23 -0
- lightningrod_ai-0.1.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
|
|
5
|
+
|
|
6
|
+
from attrs import define as _attrs_define
|
|
7
|
+
from attrs import field as _attrs_field
|
|
8
|
+
|
|
9
|
+
from ..types import UNSET, Unset
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..models.answer_type import AnswerType
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
T = TypeVar("T", bound="WebSearchLabeler")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@_attrs_define
class WebSearchLabeler:
    """Configuration model for a labeling step that resolves questions via web search.

    Generated API model (attrs-based). Unknown keys from the wire are preserved
    round-trip in ``additional_properties``.

    Attributes:
        config_type (Literal['WEB_SEARCH_LABELER'] | Unset): Type of transform configuration Default:
            'WEB_SEARCH_LABELER'.
        confidence_threshold (float | Unset): Minimum confidence threshold for including questions Default: 0.9.
        answer_type (AnswerType | None | Unset): The type of answer expected, used to guide the labeler
        resolve_redirects (bool | Unset): Resolve redirect URLs to actual destinations Default: False.
    """

    config_type: Literal["WEB_SEARCH_LABELER"] | Unset = "WEB_SEARCH_LABELER"
    confidence_threshold: float | Unset = 0.9
    answer_type: AnswerType | None | Unset = UNSET
    resolve_redirects: bool | Unset = False
    # Catch-all for keys not declared above; populated by from_dict, merged back in to_dict.
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting fields that are UNSET and merging additional_properties."""
        # Local import to avoid a circular import at module load time.
        from ..models.answer_type import AnswerType

        config_type = self.config_type

        confidence_threshold = self.confidence_threshold

        # answer_type may be a nested model (serialize it), UNSET (omit below), or None (pass through).
        answer_type: dict[str, Any] | None | Unset
        if isinstance(self.answer_type, Unset):
            answer_type = UNSET
        elif isinstance(self.answer_type, AnswerType):
            answer_type = self.answer_type.to_dict()
        else:
            answer_type = self.answer_type

        resolve_redirects = self.resolve_redirects

        # Unknown keys first so declared fields below take precedence on collision.
        field_dict: dict[str, Any] = {}
        field_dict.update(self.additional_properties)
        field_dict.update({})  # no-op; retained from code generation
        if config_type is not UNSET:
            field_dict["config_type"] = config_type
        if confidence_threshold is not UNSET:
            field_dict["confidence_threshold"] = confidence_threshold
        if answer_type is not UNSET:
            field_dict["answer_type"] = answer_type
        if resolve_redirects is not UNSET:
            field_dict["resolve_redirects"] = resolve_redirects

        return field_dict

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Deserialize from a mapping, validating the config_type const.

        Keys not consumed by the declared fields are stored in
        ``additional_properties``.

        Raises:
            ValueError: if ``config_type`` is present but not 'WEB_SEARCH_LABELER'.
        """
        # Local import to avoid a circular import at module load time.
        from ..models.answer_type import AnswerType

        d = dict(src_dict)
        config_type = cast(Literal["WEB_SEARCH_LABELER"] | Unset, d.pop("config_type", UNSET))
        # A missing config_type is allowed (defaults apply); a wrong one is rejected.
        if config_type != "WEB_SEARCH_LABELER" and not isinstance(config_type, Unset):
            raise ValueError(f"config_type must match const 'WEB_SEARCH_LABELER', got '{config_type}'")

        confidence_threshold = d.pop("confidence_threshold", UNSET)

        def _parse_answer_type(data: object) -> AnswerType | None | Unset:
            # None and UNSET pass through unchanged.
            if data is None:
                return data
            if isinstance(data, Unset):
                return data
            # Best-effort: try the nested-model shape, fall back to raw data.
            try:
                if not isinstance(data, dict):
                    raise TypeError()
                answer_type_type_0 = AnswerType.from_dict(data)

                return answer_type_type_0
            except (TypeError, ValueError, AttributeError, KeyError):
                pass
            return cast(AnswerType | None | Unset, data)

        answer_type = _parse_answer_type(d.pop("answer_type", UNSET))

        resolve_redirects = d.pop("resolve_redirects", UNSET)

        web_search_labeler = cls(
            config_type=config_type,
            confidence_threshold=confidence_threshold,
            answer_type=answer_type,
            resolve_redirects=resolve_redirects,
        )

        # Whatever is left in d was not a declared field; keep it for round-tripping.
        web_search_labeler.additional_properties = d
        return web_search_labeler

    @property
    def additional_keys(self) -> list[str]:
        """Names of the undeclared keys captured during deserialization."""
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        return key in self.additional_properties
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Marker file for PEP 561
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Contains some shared types for properties"""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping, MutableMapping
|
|
4
|
+
from http import HTTPStatus
|
|
5
|
+
from typing import IO, BinaryIO, Generic, Literal, TypeVar
|
|
6
|
+
|
|
7
|
+
from attrs import define
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Unset:
    """Sentinel type marking a field that was not provided (distinct from None)."""

    def __bool__(self) -> Literal[False]:
        # An unset value is always falsy, so `if value:` skips unset fields.
        return False


# Module-level singleton; compared with `is UNSET` / `isinstance(x, Unset)` by the generated models.
UNSET: Unset = Unset()

# The types that `httpx.Client(files=)` can accept, copied from that library.
FileContent = IO[bytes] | bytes | str
FileTypes = (
    # (filename, file (or bytes), content_type)
    tuple[str | None, FileContent, str | None]
    # (filename, file (or bytes), content_type, headers)
    | tuple[str | None, FileContent, str | None, Mapping[str, str]]
)
RequestFiles = list[tuple[str, FileTypes]]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@define
class File:
    """A file to be sent as part of a multipart/form-data upload."""

    payload: BinaryIO
    file_name: str | None = None
    mime_type: str | None = None

    def to_tuple(self) -> FileTypes:
        """Return the ``(filename, file, content_type)`` triple expected by httpx."""
        return (self.file_name, self.payload, self.mime_type)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Generic payload type carried by Response.parsed.
T = TypeVar("T")


@define
class Response(Generic[T]):
    """A response from an endpoint.

    Attributes:
        status_code: HTTP status code returned by the server.
        content: Raw response body bytes.
        headers: Response headers.
        parsed: Deserialized body model, or None when no parsed value is available.
    """

    status_code: HTTPStatus
    content: bytes
    headers: MutableMapping[str, str]
    parsed: T | None


__all__ = ["UNSET", "File", "FileTypes", "RequestFiles", "Response", "Unset"]
|
lightningrod/client.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from lightningrod._generated.client import AuthenticatedClient
|
|
4
|
+
from lightningrod._generated.models.sample import Sample
|
|
5
|
+
from lightningrod.datasets.client import DatasetSamplesClient, DatasetsClient
|
|
6
|
+
from lightningrod.datasets.dataset import Dataset
|
|
7
|
+
from lightningrod.files.client import FilesClient
|
|
8
|
+
from lightningrod.filesets.client import FileSetsClient
|
|
9
|
+
from lightningrod.organization.client import OrganizationsClient
|
|
10
|
+
from lightningrod.transforms.client import TransformsClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LightningRod:
    """Synchronous entry point for the Lightning Rod API.

    One instance wraps a single authenticated HTTP client that is shared by
    every sub-client exposed as an attribute (``transforms``, ``datasets``,
    ``organization``).

    Args:
        api_key: Your Lightning Rod API key
        base_url: Base URL for the API (defaults to production)

    Example:
        >>> lr = LightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = lr.transforms.run(config)
        >>> samples = dataset.to_samples()
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.lightningrod.ai/api/public/v1"
    ):
        self.api_key: str = api_key
        # Strip a trailing slash so generated path joining never doubles it.
        self.base_url: str = base_url.rstrip("/")

        self._generated_client: AuthenticatedClient = AuthenticatedClient(
            base_url=self.base_url,
            token=api_key,
            prefix="Bearer",
            auth_header_name="Authorization",
        )

        # The samples client is shared: both transforms and datasets need to
        # page through dataset samples.
        self._dataset_samples: DatasetSamplesClient = DatasetSamplesClient(self._generated_client)
        self.organization: OrganizationsClient = OrganizationsClient(self._generated_client)
        self.transforms: TransformsClient = TransformsClient(self._generated_client, self._dataset_samples)
        self.datasets: DatasetsClient = DatasetsClient(self._generated_client, self._dataset_samples)
        # TODO(filesets): Enable when filesets are publicly supported
        # self.files: FilesClient = FilesClient(self._generated_client)
        # self.filesets: FileSetsClient = FileSetsClient(self._generated_client, self.files)
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from lightningrod._generated.models import (
|
|
4
|
+
HTTPValidationError,
|
|
5
|
+
UploadSamplesRequest,
|
|
6
|
+
)
|
|
7
|
+
from lightningrod._generated.models.sample import Sample
|
|
8
|
+
from lightningrod._generated.api.datasets import (
|
|
9
|
+
create_dataset_datasets_post,
|
|
10
|
+
get_dataset_datasets_dataset_id_get,
|
|
11
|
+
get_dataset_samples_datasets_dataset_id_samples_get,
|
|
12
|
+
upload_samples_datasets_dataset_id_samples_post,
|
|
13
|
+
)
|
|
14
|
+
from lightningrod._generated.types import Unset
|
|
15
|
+
from lightningrod._generated.client import AuthenticatedClient
|
|
16
|
+
from lightningrod.datasets.dataset import Dataset
|
|
17
|
+
from lightningrod._errors import handle_response_error
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class DatasetSamplesClient:
    """Client for listing and uploading the samples inside a dataset."""

    def __init__(self, client: AuthenticatedClient):
        self._client: AuthenticatedClient = client

    def list(self, dataset_id: str) -> List[Sample]:
        """Fetch every sample in the dataset, following pagination cursors."""
        collected: List[Sample] = []
        cursor: Optional[str] = None

        while True:
            response = get_dataset_samples_datasets_dataset_id_samples_get.sync_detailed(
                dataset_id=dataset_id,
                client=self._client,
                limit=100,
                cursor=cursor,
            )
            page = handle_response_error(response, "fetch samples")
            collected.extend(page.samples)

            next_cursor = page.next_cursor
            # Stop when the server says there is no more data, or when it
            # failed to hand back a usable cursor.
            if not page.has_more or isinstance(next_cursor, Unset) or next_cursor is None:
                break
            cursor = str(next_cursor)

        return collected

    def upload(
        self,
        dataset_id: str,
        samples: List[Sample],
    ) -> None:
        """
        Upload samples to an existing dataset.

        Args:
            dataset_id: ID of the dataset to upload samples to
            samples: List of Sample objects to upload

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> samples = [Sample(seed=Seed(...), ...), ...]
            >>> lr.datasets.upload(samples)
        """
        body = UploadSamplesRequest(samples=samples)

        response = upload_samples_datasets_dataset_id_samples_post.sync_detailed(
            dataset_id=dataset_id,
            client=self._client,
            body=body,
        )

        # Raises on any non-success response; nothing to return on success.
        handle_response_error(response, "upload samples")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class DatasetsClient:
    """Client for creating and retrieving datasets.

    The fetch-metadata-then-wrap sequence that ``create``,
    ``create_from_samples`` and ``get`` previously each repeated inline is
    factored into the private helpers below.
    """

    def __init__(self, client: AuthenticatedClient, dataset_samples_client: DatasetSamplesClient):
        self._client: AuthenticatedClient = client
        self._dataset_samples_client: DatasetSamplesClient = dataset_samples_client

    def _fetch_dataset_result(self, dataset_id: str, error_context: str):
        """Fetch raw dataset metadata; ``error_context`` names the operation in raised errors."""
        dataset_response = get_dataset_datasets_dataset_id_get.sync_detailed(
            dataset_id=dataset_id,
            client=self._client,
        )
        return handle_response_error(dataset_response, error_context)

    def _to_dataset(self, dataset_result) -> Dataset:
        """Wrap a raw API dataset result in a Dataset bound to the shared samples client."""
        return Dataset(
            id=dataset_result.id,
            num_rows=dataset_result.num_rows,
            datasets_client=self._dataset_samples_client
        )

    def create(self) -> Dataset:
        """
        Create a new empty dataset.

        Returns:
            Dataset object representing the newly created dataset

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.datasets.create()
            >>> print(f"Created dataset: {dataset.id}")
        """
        response = create_dataset_datasets_post.sync_detailed(
            client=self._client,
        )
        create_result = handle_response_error(response, "create dataset")

        # Re-fetch so the returned Dataset reflects the server-side state.
        return self._to_dataset(self._fetch_dataset_result(create_result.id, "get dataset"))

    def create_from_samples(
        self,
        samples: List[Sample],
        batch_size: int = 1000,
    ) -> Dataset:
        """
        Create a new dataset and upload samples to it.

        This is a convenience method that creates a dataset and uploads all samples
        in batches. Useful for creating input datasets from a collection of seeds.

        Args:
            samples: List of Sample objects to upload
            batch_size: Number of samples to upload per batch (default: 1000)

        Returns:
            Dataset object with all samples uploaded

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> samples = [Sample(seed=Seed(...), ...), ...]
            >>> dataset = lr.datasets.create_from_samples(samples, batch_size=1000)
            >>> print(f"Created dataset with {dataset.num_rows} samples")
        """
        dataset = self.create()

        for i in range(0, len(samples), batch_size):
            self._dataset_samples_client.upload(dataset.id, samples[i:i + batch_size])

        # Refresh the row count now that all uploads have completed.
        dataset.num_rows = self._fetch_dataset_result(dataset.id, "refresh dataset").num_rows
        return dataset

    def get(self, dataset_id: str) -> Dataset:
        """
        Get a dataset by ID.

        Args:
            dataset_id: ID of the dataset to retrieve

        Returns:
            Dataset object

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.datasets.get("dataset-id-here")
        """
        return self._to_dataset(self._fetch_dataset_result(dataset_id, "get dataset"))
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
from typing import List, Optional, Dict, Any, TYPE_CHECKING
|
|
2
|
+
import asyncio
|
|
3
|
+
|
|
4
|
+
from lightningrod._generated.models.sample import Sample
|
|
5
|
+
from lightningrod._generated.models.forward_looking_question import ForwardLookingQuestion
|
|
6
|
+
from lightningrod._generated.models.question import Question
|
|
7
|
+
from lightningrod._generated.models.news_context import NewsContext
|
|
8
|
+
from lightningrod._generated.models.rag_context import RAGContext
|
|
9
|
+
from lightningrod._generated.models.sample_meta import SampleMeta
|
|
10
|
+
from lightningrod._generated.types import UNSET, Unset
|
|
11
|
+
|
|
12
|
+
# avoid circular import
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from lightningrod.datasets.client import DatasetSamplesClient
|
|
15
|
+
|
|
16
|
+
class Dataset:
    """
    Represents a dataset in Lightning Rod.

    A dataset contains rows of sample data. Use this class to access
    dataset metadata and download the actual samples.

    Note: Datasets should only be created through LightningRod methods,
    not instantiated directly.

    Attributes:
        id: Unique identifier for the dataset
        num_rows: Number of rows in the dataset

    Example:
        >>> lr = LightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = lr.transforms.run(config)
        >>> samples = dataset.to_samples()
        >>> print(f"Dataset has {len(samples)} samples")
    """

    def __init__(
        self,
        id: str,
        num_rows: int,
        datasets_client: "DatasetSamplesClient"
    ):
        self.id: str = id
        self.num_rows: int = num_rows
        self._datasets_client: "DatasetSamplesClient" = datasets_client
        # Cache of downloaded samples; None means "not downloaded yet"
        # (an empty list is a valid, cached result).
        self._samples: Optional[List[Sample]] = None

    def download(self) -> List[Sample]:
        """
        Download all samples from the dataset via the paginated API.

        Always hits the network and refreshes the local cache.

        Returns:
            List of Sample objects

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.transforms.run(config)
            >>> samples = dataset.download()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        self._samples = self._datasets_client.list(self.id)
        return self._samples

    def samples(self) -> List[Sample]:
        """
        Get all samples from the dataset.
        Automatically downloads the samples if they haven't been downloaded yet.

        Returns:
            List of Sample objects
        """
        # Compare against None rather than truthiness: an empty dataset caches
        # [] and must not trigger a fresh network download on every call.
        if self._samples is None:
            self.download()
        return self._samples

    def to_samples(self) -> List[Sample]:
        """
        Download all samples from the dataset via the paginated API.

        Returns:
            List of Sample objects

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = lr.transforms.run(config)
            >>> samples = dataset.to_samples()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        return self.samples()

    def flattened(self) -> List[Dict[str, Any]]:
        """
        Convert all samples to a list of dictionaries.
        Automatically downloads the samples if they haven't been downloaded yet.

        Handles different question types (Question, ForwardLookingQuestion) and
        extracts relevant fields from labels, seeds, and prompts.

        Returns:
            List of dictionaries, each representing a sample row

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = lr.transforms.run(config)
            >>> rows = dataset.flattened()
            >>> import pandas as pd
            >>> df = pd.DataFrame(rows)
        """
        samples = self.samples()
        return [self._sample_to_dict(sample) for sample in samples]

    def _sample_to_dict(self, sample: Sample) -> Dict[str, Any]:
        """Flatten one Sample into a dotted-key row dict, skipping unset fields."""
        row: Dict[str, Any] = {}

        if sample.question and not isinstance(sample.question, Unset):
            if isinstance(sample.question, ForwardLookingQuestion):
                row['question.question_text'] = sample.question.question_text
                row['question.date_close'] = sample.question.date_close.isoformat()
                row['question.event_date'] = sample.question.event_date.isoformat()
                row['question.resolution_criteria'] = sample.question.resolution_criteria
                if sample.question.prediction_date is not None and not isinstance(sample.question.prediction_date, Unset):
                    row['question.prediction_date'] = sample.question.prediction_date.isoformat()
            elif isinstance(sample.question, Question):
                row['question.question_text'] = sample.question.question_text
            else:
                # Unknown question variant: pull the text out if it exposes one.
                question_text = getattr(sample.question, 'question_text', None)
                if question_text is not None:
                    row['question.question_text'] = question_text

        if sample.label and not isinstance(sample.label, Unset):
            row['label.label'] = sample.label.label
            row['label.label_confidence'] = sample.label.label_confidence
            if sample.label.resolution_date is not None and not isinstance(sample.label.resolution_date, Unset):
                row['label.resolution_date'] = sample.label.resolution_date.isoformat()
            if sample.label.reasoning is not None and not isinstance(sample.label.reasoning, Unset):
                row['label.reasoning'] = sample.label.reasoning
            if sample.label.answer_sources is not None and not isinstance(sample.label.answer_sources, Unset):
                row['label.answer_sources'] = sample.label.answer_sources

        if sample.prompt and not isinstance(sample.prompt, Unset):
            row['prompt'] = sample.prompt

        if sample.seed and not isinstance(sample.seed, Unset):
            row['seed.seed_text'] = sample.seed.seed_text
            if sample.seed.url is not None and not isinstance(sample.seed.url, Unset):
                row['seed.url'] = sample.seed.url
            if sample.seed.seed_creation_date is not None and not isinstance(sample.seed.seed_creation_date, Unset):
                row['seed.seed_creation_date'] = sample.seed.seed_creation_date.isoformat()
            if sample.seed.search_query is not None and not isinstance(sample.seed.search_query, Unset):
                row['seed.search_query'] = sample.seed.search_query

        if sample.is_valid is not None and not isinstance(sample.is_valid, Unset):
            row['is_valid'] = sample.is_valid

        if sample.context is not None and not isinstance(sample.context, Unset):
            # Context entries are indexed so multiple contexts survive flattening.
            for idx, ctx in enumerate(sample.context):
                if isinstance(ctx, NewsContext):
                    row[f'context.{idx}.rendered_context'] = ctx.rendered_context
                    row[f'context.{idx}.search_query'] = ctx.search_query
                    row[f'context.{idx}.context_type'] = ctx.context_type
                elif isinstance(ctx, RAGContext):
                    row[f'context.{idx}.rendered_context'] = ctx.rendered_context
                    row[f'context.{idx}.document_id'] = ctx.document_id
                    row[f'context.{idx}.context_type'] = ctx.context_type

        if sample.meta is not None and not isinstance(sample.meta, Unset):
            if isinstance(sample.meta, SampleMeta):
                for key, value in sample.meta.additional_properties.items():
                    row[f'meta.{key}'] = value

        if sample.additional_properties:
            for key, value in sample.additional_properties.items():
                row[f'additional_properties.{key}'] = value

        return row
|
|
181
|
+
|
|
182
|
+
class AsyncDataset:
    """Async facade over a synchronous Dataset.

    Every operation delegates to the wrapped Dataset inside
    ``asyncio.to_thread`` so the blocking HTTP work never stalls the
    event loop.

    Note: AsyncDatasets should only be created through AsyncLightningRod
    methods, not instantiated directly.

    Attributes:
        id: Unique identifier for the dataset
        num_rows: Number of rows in the dataset

    Example:
        >>> lr = AsyncLightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = await lr.transforms.run(config)
        >>> samples = await dataset.to_samples()
        >>> print(f"Dataset has {len(samples)} samples")
    """

    def __init__(self, sync_dataset: Dataset):
        self._sync_dataset: Dataset = sync_dataset

    @property
    def id(self) -> str:
        # Delegates to the wrapped synchronous dataset.
        return self._sync_dataset.id

    @property
    def num_rows(self) -> int:
        # Delegates to the wrapped synchronous dataset.
        return self._sync_dataset.num_rows

    async def to_samples(self) -> List[Sample]:
        """Download every sample in the dataset via the paginated API.

        The blocking download runs in a worker thread so the event loop
        stays responsive.

        Returns:
            List of Sample objects

        Example:
            >>> lr = AsyncLightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = await lr.transforms.run(config)
            >>> samples = await dataset.to_samples()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        return await asyncio.to_thread(self._sync_dataset.to_samples)

    async def flattened(self) -> List[Dict[str, Any]]:
        """Flatten every sample into a dict row, downloading first if needed.

        The blocking work runs in a worker thread so the event loop stays
        responsive. Question, label, seed and prompt fields are extracted
        the same way as Dataset.flattened.

        Returns:
            List of dictionaries, each representing a sample row

        Example:
            >>> lr = AsyncLightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = await lr.transforms.run(config)
            >>> rows = await dataset.flattened()
            >>> import pandas as pd
            >>> df = pd.DataFrame(rows)
        """
        return await asyncio.to_thread(self._sync_dataset.flattened)
|
|
File without changes
|