lightningrod-ai 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. lightningrod/__init__.py +66 -0
  2. lightningrod/_display.py +204 -0
  3. lightningrod/_errors.py +67 -0
  4. lightningrod/_generated/__init__.py +8 -0
  5. lightningrod/_generated/api/__init__.py +1 -0
  6. lightningrod/_generated/api/datasets/__init__.py +1 -0
  7. lightningrod/_generated/api/datasets/create_dataset_datasets_post.py +133 -0
  8. lightningrod/_generated/api/datasets/get_dataset_datasets_dataset_id_get.py +168 -0
  9. lightningrod/_generated/api/datasets/get_dataset_samples_datasets_dataset_id_samples_get.py +209 -0
  10. lightningrod/_generated/api/datasets/upload_samples_datasets_dataset_id_samples_post.py +190 -0
  11. lightningrod/_generated/api/file_sets/__init__.py +1 -0
  12. lightningrod/_generated/api/file_sets/add_file_to_set_filesets_file_set_id_files_post.py +190 -0
  13. lightningrod/_generated/api/file_sets/create_file_set_filesets_post.py +174 -0
  14. lightningrod/_generated/api/file_sets/get_file_set_filesets_file_set_id_get.py +168 -0
  15. lightningrod/_generated/api/file_sets/list_file_sets_filesets_get.py +173 -0
  16. lightningrod/_generated/api/file_sets/list_files_in_set_filesets_file_set_id_files_get.py +209 -0
  17. lightningrod/_generated/api/files/__init__.py +1 -0
  18. lightningrod/_generated/api/files/create_file_upload_files_post.py +174 -0
  19. lightningrod/_generated/api/open_ai_compatible/__init__.py +1 -0
  20. lightningrod/_generated/api/open_ai_compatible/chat_completions_openai_chat_completions_post.py +174 -0
  21. lightningrod/_generated/api/organizations/__init__.py +1 -0
  22. lightningrod/_generated/api/organizations/get_balance_organizations_balance_get.py +131 -0
  23. lightningrod/_generated/api/samples/__init__.py +1 -0
  24. lightningrod/_generated/api/samples/validate_sample_samples_validate_post.py +174 -0
  25. lightningrod/_generated/api/transform_jobs/__init__.py +1 -0
  26. lightningrod/_generated/api/transform_jobs/cost_estimation_transform_jobs_cost_estimation_post.py +174 -0
  27. lightningrod/_generated/api/transform_jobs/create_transform_job_transform_jobs_post.py +174 -0
  28. lightningrod/_generated/api/transform_jobs/get_transform_job_metrics_transform_jobs_job_id_metrics_get.py +172 -0
  29. lightningrod/_generated/api/transform_jobs/get_transform_job_transform_jobs_job_id_get.py +168 -0
  30. lightningrod/_generated/client.py +268 -0
  31. lightningrod/_generated/errors.py +16 -0
  32. lightningrod/_generated/models/__init__.py +147 -0
  33. lightningrod/_generated/models/answer_type.py +129 -0
  34. lightningrod/_generated/models/answer_type_enum.py +11 -0
  35. lightningrod/_generated/models/balance_response.py +61 -0
  36. lightningrod/_generated/models/chat_completion_request.py +216 -0
  37. lightningrod/_generated/models/chat_completion_response.py +146 -0
  38. lightningrod/_generated/models/chat_message.py +69 -0
  39. lightningrod/_generated/models/choice.py +97 -0
  40. lightningrod/_generated/models/create_dataset_response.py +61 -0
  41. lightningrod/_generated/models/create_file_set_file_request.py +101 -0
  42. lightningrod/_generated/models/create_file_set_file_request_metadata_type_0.py +46 -0
  43. lightningrod/_generated/models/create_file_set_request.py +83 -0
  44. lightningrod/_generated/models/create_file_upload_request.py +91 -0
  45. lightningrod/_generated/models/create_file_upload_response.py +165 -0
  46. lightningrod/_generated/models/create_file_upload_response_metadata_type_0.py +46 -0
  47. lightningrod/_generated/models/create_transform_job_request.py +312 -0
  48. lightningrod/_generated/models/dataset_metadata.py +69 -0
  49. lightningrod/_generated/models/estimate_cost_request.py +243 -0
  50. lightningrod/_generated/models/estimate_cost_response.py +117 -0
  51. lightningrod/_generated/models/event_usage_summary.py +80 -0
  52. lightningrod/_generated/models/file_set.py +128 -0
  53. lightningrod/_generated/models/file_set_file.py +203 -0
  54. lightningrod/_generated/models/file_set_file_metadata_type_0.py +57 -0
  55. lightningrod/_generated/models/file_set_query_seed_generator.py +136 -0
  56. lightningrod/_generated/models/file_set_seed_generator.py +126 -0
  57. lightningrod/_generated/models/filter_criteria.py +83 -0
  58. lightningrod/_generated/models/forward_looking_question.py +130 -0
  59. lightningrod/_generated/models/forward_looking_question_generator.py +217 -0
  60. lightningrod/_generated/models/gdelt_seed_generator.py +103 -0
  61. lightningrod/_generated/models/http_validation_error.py +79 -0
  62. lightningrod/_generated/models/job_usage.py +185 -0
  63. lightningrod/_generated/models/job_usage_by_step_type_0.py +59 -0
  64. lightningrod/_generated/models/label.py +143 -0
  65. lightningrod/_generated/models/list_file_set_files_response.py +113 -0
  66. lightningrod/_generated/models/list_file_sets_response.py +75 -0
  67. lightningrod/_generated/models/llm_model_usage_summary.py +98 -0
  68. lightningrod/_generated/models/mock_transform_config.py +243 -0
  69. lightningrod/_generated/models/mock_transform_config_metadata_additions.py +46 -0
  70. lightningrod/_generated/models/model_config.py +316 -0
  71. lightningrod/_generated/models/model_source_type.py +16 -0
  72. lightningrod/_generated/models/news_context.py +82 -0
  73. lightningrod/_generated/models/news_context_generator.py +127 -0
  74. lightningrod/_generated/models/news_seed_generator.py +220 -0
  75. lightningrod/_generated/models/paginated_samples_response.py +113 -0
  76. lightningrod/_generated/models/pipeline_metrics_response.py +99 -0
  77. lightningrod/_generated/models/question.py +74 -0
  78. lightningrod/_generated/models/question_and_label_generator.py +217 -0
  79. lightningrod/_generated/models/question_generator.py +217 -0
  80. lightningrod/_generated/models/question_pipeline.py +417 -0
  81. lightningrod/_generated/models/question_renderer.py +123 -0
  82. lightningrod/_generated/models/rag_context.py +82 -0
  83. lightningrod/_generated/models/response_message.py +69 -0
  84. lightningrod/_generated/models/rollout.py +130 -0
  85. lightningrod/_generated/models/rollout_generator.py +139 -0
  86. lightningrod/_generated/models/rollout_parsed_output_type_0.py +46 -0
  87. lightningrod/_generated/models/sample.py +323 -0
  88. lightningrod/_generated/models/sample_meta.py +46 -0
  89. lightningrod/_generated/models/seed.py +135 -0
  90. lightningrod/_generated/models/step_cost_breakdown.py +109 -0
  91. lightningrod/_generated/models/transform_job.py +268 -0
  92. lightningrod/_generated/models/transform_job_status.py +11 -0
  93. lightningrod/_generated/models/transform_step_metrics_response.py +131 -0
  94. lightningrod/_generated/models/transform_type.py +25 -0
  95. lightningrod/_generated/models/upload_samples_request.py +75 -0
  96. lightningrod/_generated/models/upload_samples_response.py +69 -0
  97. lightningrod/_generated/models/usage.py +77 -0
  98. lightningrod/_generated/models/usage_summary.py +102 -0
  99. lightningrod/_generated/models/usage_summary_events.py +59 -0
  100. lightningrod/_generated/models/usage_summary_llm_by_model.py +59 -0
  101. lightningrod/_generated/models/validate_sample_response.py +69 -0
  102. lightningrod/_generated/models/validation_error.py +90 -0
  103. lightningrod/_generated/models/web_search_labeler.py +120 -0
  104. lightningrod/_generated/py.typed +1 -0
  105. lightningrod/_generated/types.py +54 -0
  106. lightningrod/client.py +48 -0
  107. lightningrod/datasets/__init__.py +4 -0
  108. lightningrod/datasets/client.py +174 -0
  109. lightningrod/datasets/dataset.py +255 -0
  110. lightningrod/files/__init__.py +0 -0
  111. lightningrod/files/client.py +58 -0
  112. lightningrod/filesets/__init__.py +0 -0
  113. lightningrod/filesets/client.py +106 -0
  114. lightningrod/organization/__init__.py +0 -0
  115. lightningrod/organization/client.py +17 -0
  116. lightningrod/py.typed +0 -0
  117. lightningrod/transforms/__init__.py +0 -0
  118. lightningrod/transforms/client.py +154 -0
  119. lightningrod_ai-0.1.6.dist-info/METADATA +122 -0
  120. lightningrod_ai-0.1.6.dist-info/RECORD +123 -0
  121. lightningrod_ai-0.1.6.dist-info/WHEEL +5 -0
  122. lightningrod_ai-0.1.6.dist-info/licenses/LICENSE +23 -0
  123. lightningrod_ai-0.1.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+ from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
5
+
6
+ from attrs import define as _attrs_define
7
+ from attrs import field as _attrs_field
8
+
9
+ from ..types import UNSET, Unset
10
+
11
+ if TYPE_CHECKING:
12
+ from ..models.answer_type import AnswerType
13
+
14
+
15
T = TypeVar("T", bound="WebSearchLabeler")


@_attrs_define
class WebSearchLabeler:
    """
    Attributes:
        config_type (Literal['WEB_SEARCH_LABELER'] | Unset): Type of transform configuration Default:
            'WEB_SEARCH_LABELER'.
        confidence_threshold (float | Unset): Minimum confidence threshold for including questions Default: 0.9.
        answer_type (AnswerType | None | Unset): The type of answer expected, used to guide the labeler
        resolve_redirects (bool | Unset): Resolve redirect URLs to actual destinations Default: False.
    """

    config_type: Literal["WEB_SEARCH_LABELER"] | Unset = "WEB_SEARCH_LABELER"
    confidence_threshold: float | Unset = 0.9
    answer_type: AnswerType | None | Unset = UNSET
    resolve_redirects: bool | Unset = False
    # Undeclared JSON keys captured by from_dict are kept here so they round-trip.
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this model to a plain dict, omitting UNSET fields."""
        # Imported inside the method (not at module top) to avoid an import
        # cycle between generated model modules.
        from ..models.answer_type import AnswerType

        config_type = self.config_type

        confidence_threshold = self.confidence_threshold

        # answer_type is a union: nested model, explicit None, or UNSET (absent).
        answer_type: dict[str, Any] | None | Unset
        if isinstance(self.answer_type, Unset):
            answer_type = UNSET
        elif isinstance(self.answer_type, AnswerType):
            answer_type = self.answer_type.to_dict()
        else:
            answer_type = self.answer_type

        resolve_redirects = self.resolve_redirects

        field_dict: dict[str, Any] = {}
        # Passthrough keys first; declared fields below overwrite on collision.
        field_dict.update(self.additional_properties)
        field_dict.update({})
        if config_type is not UNSET:
            field_dict["config_type"] = config_type
        if confidence_threshold is not UNSET:
            field_dict["confidence_threshold"] = confidence_threshold
        if answer_type is not UNSET:
            field_dict["answer_type"] = answer_type
        if resolve_redirects is not UNSET:
            field_dict["resolve_redirects"] = resolve_redirects

        return field_dict

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build a WebSearchLabeler from a mapping.

        Keys not matching a declared field are preserved in
        ``additional_properties``.

        Raises:
            ValueError: if ``config_type`` is present but is not the
                constant 'WEB_SEARCH_LABELER'.
        """
        # Imported inside the method (not at module top) to avoid an import
        # cycle between generated model modules.
        from ..models.answer_type import AnswerType

        d = dict(src_dict)
        config_type = cast(Literal["WEB_SEARCH_LABELER"] | Unset, d.pop("config_type", UNSET))
        if config_type != "WEB_SEARCH_LABELER" and not isinstance(config_type, Unset):
            raise ValueError(f"config_type must match const 'WEB_SEARCH_LABELER', got '{config_type}'")

        confidence_threshold = d.pop("confidence_threshold", UNSET)

        def _parse_answer_type(data: object) -> AnswerType | None | Unset:
            # None and UNSET pass through unchanged.
            if data is None:
                return data
            if isinstance(data, Unset):
                return data
            # Try the nested-model branch of the union first.
            try:
                if not isinstance(data, dict):
                    raise TypeError()
                answer_type_type_0 = AnswerType.from_dict(data)

                return answer_type_type_0
            except (TypeError, ValueError, AttributeError, KeyError):
                pass
            # Fall back to the raw value for the remaining union members.
            return cast(AnswerType | None | Unset, data)

        answer_type = _parse_answer_type(d.pop("answer_type", UNSET))

        resolve_redirects = d.pop("resolve_redirects", UNSET)

        web_search_labeler = cls(
            config_type=config_type,
            confidence_threshold=confidence_threshold,
            answer_type=answer_type,
            resolve_redirects=resolve_redirects,
        )

        # Whatever keys remain were not declared fields; keep them verbatim.
        web_search_labeler.additional_properties = d
        return web_search_labeler

    @property
    def additional_keys(self) -> list[str]:
        # Names of the passthrough keys captured by from_dict.
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        return key in self.additional_properties
@@ -0,0 +1 @@
1
+ # Marker file for PEP 561
@@ -0,0 +1,54 @@
1
+ """Contains some shared types for properties"""
2
+
3
+ from collections.abc import Mapping, MutableMapping
4
+ from http import HTTPStatus
5
+ from typing import IO, BinaryIO, Generic, Literal, TypeVar
6
+
7
+ from attrs import define
8
+
9
+
10
class Unset:
    """Sentinel type marking values that were not provided (distinct from None)."""

    def __bool__(self) -> Literal[False]:
        # Always falsy, so `if value:` treats a missing field like an empty one.
        return False


# Shared singleton sentinel used throughout the generated models.
UNSET: Unset = Unset()
16
+
17
# The types that `httpx.Client(files=)` can accept, copied from that library.
FileContent = IO[bytes] | bytes | str
FileTypes = (
    # (filename, file (or bytes), content_type)
    tuple[str | None, FileContent, str | None]
    # (filename, file (or bytes), content_type, headers)
    | tuple[str | None, FileContent, str | None, Mapping[str, str]]
)
# The (field-name, file-tuple) pairs passed to httpx for multipart uploads.
RequestFiles = list[tuple[str, FileTypes]]
26
+
27
+
28
@define
class File:
    """Contains information for file uploads"""

    # Open binary stream holding the file contents.
    payload: BinaryIO
    # Optional filename reported to the server.
    file_name: str | None = None
    # Optional MIME type, e.g. "application/json".
    mime_type: str | None = None

    def to_tuple(self) -> FileTypes:
        """Return a tuple representation that httpx will accept for multipart/form-data"""
        return self.file_name, self.payload, self.mime_type
39
+
40
+
41
# Generic type of the parsed response body.
T = TypeVar("T")


@define
class Response(Generic[T]):
    """A response from an endpoint"""

    # HTTP status returned by the server.
    status_code: HTTPStatus
    # Raw response body bytes.
    content: bytes
    # Response headers as received.
    headers: MutableMapping[str, str]
    # Deserialized body, or None when parsing was not possible.
    parsed: T | None
52
+
53
+
54
+ __all__ = ["UNSET", "File", "FileTypes", "RequestFiles", "Response", "Unset"]
lightningrod/client.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import List
2
+
3
+ from lightningrod._generated.client import AuthenticatedClient
4
+ from lightningrod._generated.models.sample import Sample
5
+ from lightningrod.datasets.client import DatasetSamplesClient, DatasetsClient
6
+ from lightningrod.datasets.dataset import Dataset
7
+ from lightningrod.files.client import FilesClient
8
+ from lightningrod.filesets.client import FileSetsClient
9
+ from lightningrod.organization.client import OrganizationsClient
10
+ from lightningrod.transforms.client import TransformsClient
11
+
12
+
13
class LightningRod:
    """
    Python SDK for the Lightning Rod API.

    Builds one authenticated HTTP client and wires it into the per-resource
    clients exposed as attributes (``transforms``, ``datasets``,
    ``organization``).

    Args:
        api_key: Your Lightning Rod API key
        base_url: Base URL for the API (defaults to production)

    Example:
        >>> lr = LightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = lr.transforms.run(config)
        >>> samples = dataset.to_samples()
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.lightningrod.ai/api/public/v1"
    ):
        self.api_key: str = api_key
        # Normalize so later path joins never produce a double slash.
        self.base_url: str = base_url.rstrip("/")
        self._generated_client: AuthenticatedClient = AuthenticatedClient(
            base_url=self.base_url,
            token=api_key,
            prefix="Bearer",
            auth_header_name="Authorization",
        )

        # The samples client is shared by both the transforms and datasets clients.
        self._dataset_samples: DatasetSamplesClient = DatasetSamplesClient(self._generated_client)
        self.organization: OrganizationsClient = OrganizationsClient(self._generated_client)
        self.datasets: DatasetsClient = DatasetsClient(self._generated_client, self._dataset_samples)
        self.transforms: TransformsClient = TransformsClient(self._generated_client, self._dataset_samples)
        # TODO(filesets): Enable when filesets are publicly supported
        # self.files: FilesClient = FilesClient(self._generated_client)
        # self.filesets: FileSetsClient = FileSetsClient(self._generated_client, self.files)
@@ -0,0 +1,4 @@
1
+ from lightningrod.datasets.client import DatasetsClient, DatasetSamplesClient
2
+ from lightningrod.datasets.dataset import Dataset, AsyncDataset
3
+
4
+ __all__ = ["DatasetsClient", "DatasetSamplesClient", "Dataset", "AsyncDataset"]
@@ -0,0 +1,174 @@
1
+ from typing import List, Optional
2
+
3
+ from lightningrod._generated.models import (
4
+ HTTPValidationError,
5
+ UploadSamplesRequest,
6
+ )
7
+ from lightningrod._generated.models.sample import Sample
8
+ from lightningrod._generated.api.datasets import (
9
+ create_dataset_datasets_post,
10
+ get_dataset_datasets_dataset_id_get,
11
+ get_dataset_samples_datasets_dataset_id_samples_get,
12
+ upload_samples_datasets_dataset_id_samples_post,
13
+ )
14
+ from lightningrod._generated.types import Unset
15
+ from lightningrod._generated.client import AuthenticatedClient
16
+ from lightningrod.datasets.dataset import Dataset
17
+ from lightningrod._errors import handle_response_error
18
+
19
+
20
class DatasetSamplesClient:
    """Low-level client for reading and writing the samples inside a dataset."""

    def __init__(self, client: AuthenticatedClient):
        self._client: AuthenticatedClient = client

    def list(self, dataset_id: str) -> List[Sample]:
        """Return every sample in the dataset by walking the paginated endpoint."""
        collected: List[Sample] = []
        cursor: Optional[str] = None

        while True:
            raw = get_dataset_samples_datasets_dataset_id_samples_get.sync_detailed(
                dataset_id=dataset_id,
                client=self._client,
                limit=100,
                cursor=cursor,
            )
            page = handle_response_error(raw, "fetch samples")
            collected.extend(page.samples)

            # Stop once the server reports no more pages, or it cannot supply
            # a cursor to continue from.
            cursor_missing = isinstance(page.next_cursor, Unset) or page.next_cursor is None
            if not page.has_more or cursor_missing:
                return collected

            cursor = str(page.next_cursor)

    def upload(
        self,
        dataset_id: str,
        samples: List[Sample],
    ) -> None:
        """
        Upload samples to an existing dataset.

        Args:
            dataset_id: ID of the dataset to upload samples to
            samples: List of Sample objects to upload

        Example:
            >>> samples = [Sample(seed=Seed(...), ...), ...]
            >>> samples_client.upload("dataset-id", samples)
        """
        body = UploadSamplesRequest(samples=samples)

        response = upload_samples_datasets_dataset_id_samples_post.sync_detailed(
            dataset_id=dataset_id,
            client=self._client,
            body=body,
        )

        handle_response_error(response, "upload samples")
74
+
75
+
76
class DatasetsClient:
    """High-level operations on datasets: create, fetch, and bulk upload."""

    def __init__(self, client: AuthenticatedClient, dataset_samples_client: DatasetSamplesClient):
        self._client: AuthenticatedClient = client
        self._dataset_samples_client: DatasetSamplesClient = dataset_samples_client

    def _fetch(self, dataset_id: str, context: str):
        """Fetch the server-side record for a dataset; ``context`` labels errors."""
        response = get_dataset_datasets_dataset_id_get.sync_detailed(
            dataset_id=dataset_id,
            client=self._client,
        )
        return handle_response_error(response, context)

    def _wrap(self, result) -> Dataset:
        """Build a Dataset facade around a fetched dataset record."""
        return Dataset(
            id=result.id,
            num_rows=result.num_rows,
            datasets_client=self._dataset_samples_client
        )

    def create(self) -> Dataset:
        """
        Create a new empty dataset.

        Returns:
            Dataset object representing the newly created dataset

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.datasets.create()
            >>> print(f"Created dataset: {dataset.id}")
        """
        create_response = create_dataset_datasets_post.sync_detailed(
            client=self._client,
        )
        created = handle_response_error(create_response, "create dataset")

        # Re-fetch so the returned Dataset carries server-side metadata.
        return self._wrap(self._fetch(created.id, "get dataset"))

    def create_from_samples(
        self,
        samples: List[Sample],
        batch_size: int = 1000,
    ) -> Dataset:
        """
        Create a new dataset and upload samples to it.

        Convenience wrapper: creates an empty dataset, uploads the samples in
        batches, then refreshes the row count from the server.

        Args:
            samples: List of Sample objects to upload
            batch_size: Number of samples to upload per batch (default: 1000)

        Returns:
            Dataset object with all samples uploaded

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> samples = [Sample(seed=Seed(...), ...), ...]
            >>> dataset = lr.datasets.create_from_samples(samples, batch_size=1000)
            >>> print(f"Created dataset with {dataset.num_rows} samples")
        """
        dataset = self.create()

        for start in range(0, len(samples), batch_size):
            self._dataset_samples_client.upload(dataset.id, samples[start:start + batch_size])

        # Pick up the server's row count after all batches landed.
        dataset.num_rows = self._fetch(dataset.id, "refresh dataset").num_rows
        return dataset

    def get(self, dataset_id: str) -> Dataset:
        """
        Get a dataset by ID.

        Args:
            dataset_id: ID of the dataset to retrieve

        Returns:
            Dataset object

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.datasets.get("dataset-id-here")
        """
        return self._wrap(self._fetch(dataset_id, "get dataset"))
@@ -0,0 +1,255 @@
1
+ from typing import List, Optional, Dict, Any, TYPE_CHECKING
2
+ import asyncio
3
+
4
+ from lightningrod._generated.models.sample import Sample
5
+ from lightningrod._generated.models.forward_looking_question import ForwardLookingQuestion
6
+ from lightningrod._generated.models.question import Question
7
+ from lightningrod._generated.models.news_context import NewsContext
8
+ from lightningrod._generated.models.rag_context import RAGContext
9
+ from lightningrod._generated.models.sample_meta import SampleMeta
10
+ from lightningrod._generated.types import UNSET, Unset
11
+
12
+ # avoid circular import
13
+ if TYPE_CHECKING:
14
+ from lightningrod.datasets.client import DatasetSamplesClient
15
+
16
class Dataset:
    """
    Represents a dataset in Lightning Rod.

    A dataset contains rows of sample data. Use this class to access
    dataset metadata and download the actual samples. Downloaded samples
    are cached on the instance; call download() to force a refresh.

    Note: Datasets should only be created through LightningRod methods,
    not instantiated directly.

    Attributes:
        id: Unique identifier for the dataset
        num_rows: Number of rows in the dataset

    Example:
        >>> lr = LightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = lr.transforms.run(config)
        >>> samples = dataset.to_samples()
        >>> print(f"Dataset has {len(samples)} samples")
    """

    def __init__(
        self,
        id: str,
        num_rows: int,
        datasets_client: "DatasetSamplesClient"
    ):
        self.id: str = id
        self.num_rows: int = num_rows
        self._datasets_client: "DatasetSamplesClient" = datasets_client
        # None means "never downloaded"; an empty list is a valid cached result.
        self._samples: "Optional[List[Sample]]" = None

    def download(self) -> "List[Sample]":
        """
        Download all samples from the dataset via the paginated API.

        Always hits the API and replaces the local cache.

        Returns:
            List of Sample objects

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> dataset = lr.transforms.run(config)
            >>> samples = dataset.download()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        self._samples = self._datasets_client.list(self.id)
        return self._samples

    def samples(self) -> "List[Sample]":
        """
        Get all samples from the dataset.
        Automatically downloads the samples if they haven't been downloaded yet.

        Returns:
            List of Sample objects
        """
        # Compare against None rather than truthiness: with `if not self._samples`
        # an empty dataset would be re-downloaded on every call.
        if self._samples is None:
            self.download()
        return self._samples

    def to_samples(self) -> "List[Sample]":
        """
        Download all samples from the dataset via the paginated API.

        Returns:
            List of Sample objects

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = lr.transforms.run(config)
            >>> samples = dataset.to_samples()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        return self.samples()

    def flattened(self) -> List[Dict[str, Any]]:
        """
        Convert all samples to a list of dictionaries.
        Automatically downloads the samples if they haven't been downloaded yet.

        Handles different question types (Question, ForwardLookingQuestion) and
        extracts relevant fields from labels, seeds, and prompts.

        Returns:
            List of dictionaries, each representing a sample row

        Example:
            >>> lr = LightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = lr.transforms.run(config)
            >>> rows = dataset.flattened()
            >>> import pandas as pd
            >>> df = pd.DataFrame(rows)
        """
        samples = self.samples()
        return [self._sample_to_dict(sample) for sample in samples]

    def _sample_to_dict(self, sample: "Sample") -> Dict[str, Any]:
        """Flatten one Sample into a single dict row keyed by dotted paths."""
        row: Dict[str, Any] = {}

        if sample.question and not isinstance(sample.question, Unset):
            if isinstance(sample.question, ForwardLookingQuestion):
                row['question.question_text'] = sample.question.question_text
                row['question.date_close'] = sample.question.date_close.isoformat()
                row['question.event_date'] = sample.question.event_date.isoformat()
                row['question.resolution_criteria'] = sample.question.resolution_criteria
                if sample.question.prediction_date is not None and not isinstance(sample.question.prediction_date, Unset):
                    row['question.prediction_date'] = sample.question.prediction_date.isoformat()
            elif isinstance(sample.question, Question):
                row['question.question_text'] = sample.question.question_text
            else:
                # Unknown question type: salvage the text field if it exists.
                question_text = getattr(sample.question, 'question_text', None)
                if question_text is not None:
                    row['question.question_text'] = question_text

        if sample.label and not isinstance(sample.label, Unset):
            row['label.label'] = sample.label.label
            row['label.label_confidence'] = sample.label.label_confidence
            if sample.label.resolution_date is not None and not isinstance(sample.label.resolution_date, Unset):
                row['label.resolution_date'] = sample.label.resolution_date.isoformat()
            if sample.label.reasoning is not None and not isinstance(sample.label.reasoning, Unset):
                row['label.reasoning'] = sample.label.reasoning
            if sample.label.answer_sources is not None and not isinstance(sample.label.answer_sources, Unset):
                row['label.answer_sources'] = sample.label.answer_sources

        if sample.prompt and not isinstance(sample.prompt, Unset):
            row['prompt'] = sample.prompt

        if sample.seed and not isinstance(sample.seed, Unset):
            row['seed.seed_text'] = sample.seed.seed_text
            if sample.seed.url is not None and not isinstance(sample.seed.url, Unset):
                row['seed.url'] = sample.seed.url
            if sample.seed.seed_creation_date is not None and not isinstance(sample.seed.seed_creation_date, Unset):
                row['seed.seed_creation_date'] = sample.seed.seed_creation_date.isoformat()
            if sample.seed.search_query is not None and not isinstance(sample.seed.search_query, Unset):
                row['seed.search_query'] = sample.seed.search_query

        if sample.is_valid is not None and not isinstance(sample.is_valid, Unset):
            row['is_valid'] = sample.is_valid

        if sample.context is not None and not isinstance(sample.context, Unset):
            # Context entries are indexed so multiple contexts don't collide.
            for idx, ctx in enumerate(sample.context):
                if isinstance(ctx, NewsContext):
                    row[f'context.{idx}.rendered_context'] = ctx.rendered_context
                    row[f'context.{idx}.search_query'] = ctx.search_query
                    row[f'context.{idx}.context_type'] = ctx.context_type
                elif isinstance(ctx, RAGContext):
                    row[f'context.{idx}.rendered_context'] = ctx.rendered_context
                    row[f'context.{idx}.document_id'] = ctx.document_id
                    row[f'context.{idx}.context_type'] = ctx.context_type

        if sample.meta is not None and not isinstance(sample.meta, Unset):
            if isinstance(sample.meta, SampleMeta):
                for key, value in sample.meta.additional_properties.items():
                    row[f'meta.{key}'] = value

        if sample.additional_properties:
            for key, value in sample.additional_properties.items():
                row[f'additional_properties.{key}'] = value

        return row
181
+
182
class AsyncDataset:
    """
    Async wrapper for Dataset.

    Delegates every operation to a wrapped synchronous Dataset, executing it
    in a worker thread via asyncio.to_thread so the event loop never blocks.

    Note: AsyncDatasets should only be created through AsyncLightningRod methods,
    not instantiated directly.

    Attributes:
        id: Unique identifier for the dataset
        num_rows: Number of rows in the dataset

    Example:
        >>> lr = AsyncLightningRod(api_key="your-api-key")
        >>> config = QuestionPipeline(...)
        >>> dataset = await lr.transforms.run(config)
        >>> samples = await dataset.to_samples()
        >>> print(f"Dataset has {len(samples)} samples")
    """

    def __init__(self, sync_dataset: Dataset):
        self._sync_dataset: Dataset = sync_dataset

    @property
    def id(self) -> str:
        # Mirrors the wrapped dataset's identifier.
        return self._sync_dataset.id

    @property
    def num_rows(self) -> int:
        # Mirrors the wrapped dataset's row count.
        return self._sync_dataset.num_rows

    async def flattened(self) -> List[Dict[str, Any]]:
        """
        Convert all samples to a list of dictionaries, downloading them first
        if needed. Runs in a thread pool so the event loop is not blocked.

        Handles different question types (Question, ForwardLookingQuestion) and
        extracts relevant fields from labels, seeds, and prompts.

        Returns:
            List of dictionaries, each representing a sample row

        Example:
            >>> lr = AsyncLightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = await lr.transforms.run(config)
            >>> rows = await dataset.flattened()
            >>> import pandas as pd
            >>> df = pd.DataFrame(rows)
        """
        return await asyncio.to_thread(self._sync_dataset.flattened)

    async def to_samples(self) -> List[Sample]:
        """
        Download all samples from the dataset via the paginated API.
        Runs in a thread pool so the event loop is not blocked.

        Returns:
            List of Sample objects

        Example:
            >>> lr = AsyncLightningRod(api_key="your-api-key")
            >>> config = QuestionPipeline(...)
            >>> dataset = await lr.transforms.run(config)
            >>> samples = await dataset.to_samples()
            >>> for sample in samples:
            ...     print(sample.seed.seed_text)
        """
        return await asyncio.to_thread(self._sync_dataset.to_samples)
File without changes