palimpzest 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.1.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.4.dist-info/RECORD +0 -83
  69. palimpzest-0.5.4.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/WHEEL +0 -0
palimpzest/core/data/datasources.py (deleted)
@@ -1,369 +0,0 @@
- from __future__ import annotations
-
- import abc
- import base64
- import json
- import os
- import sys
- from io import BytesIO
- from typing import Any, Callable
-
- import modal
- import pandas as pd
- from bs4 import BeautifulSoup
- from papermage import Document
-
- from palimpzest import constants
- from palimpzest.core.elements.records import DataRecord
- from palimpzest.core.lib.schemas import (
-     DefaultSchema,
-     File,
-     ImageFile,
-     PDFFile,
-     Schema,
-     TextFile,
-     WebPage,
-     XLSFile,
- )
- from palimpzest.tools.pdfparser import get_text_from_pdf
-
-
- # First level of abstraction
- class AbstractDataSource(abc.ABC):
-     """
-     An AbstractDataSource is an Iterable which yields DataRecords adhering to a given schema.
-
-     This base class must have its `__iter__` method implemented by a subclass, with each
-     subclass reading data files from some real-world source (i.e. a directory, an S3 prefix,
-     etc.).
-
-     Many (if not all) DataSources should use Schemas from `palimpzest.elements.core`.
-     In the future, programmers can implement their own DataSources using custom Schemas.
-     """
-
-     def __init__(self, schema: Schema) -> None:
-         self._schema = schema
-
-     def __str__(self) -> str:
-         return f"{self.__class__.__name__}(schema={self.schema})"
-
-     def __eq__(self, __value: object) -> bool:
-         return self.__dict__ == __value.__dict__
-
-     @abc.abstractmethod
-     def __len__(self) -> int: ...
-
-     @abc.abstractmethod
-     def get_item(self, idx: int) -> DataRecord: ...
-
-     @abc.abstractmethod
-     def get_size(self) -> int: ...
-
-     @property
-     def schema(self) -> Schema:
-         return self._schema
-
-     def copy(self) -> AbstractDataSource:
-         raise NotImplementedError("You are calling this method from an abstract class.")
-
-     def serialize(self) -> dict[str, Any]:
-         return {"schema": self._schema.json_schema()}
-
-
- class DataSource(AbstractDataSource):
-     def __init__(self, schema: Schema, dataset_id: str) -> None:
-         super().__init__(schema)
-         self.dataset_id = dataset_id
-
-     def universal_identifier(self):
-         """Return a unique identifier for this Set."""
-         return self.dataset_id
-
-     def __str__(self) -> str:
-         return f"{self.__class__.__name__}(schema={self.schema}, dataset_id={self.dataset_id})"
-
-
- # Second level of abstraction
- class DirectorySource(DataSource):
-     """DirectorySource returns multiple File objects from a real-world source (a directory on disk)"""
-
-     def __init__(self, path: str, dataset_id: str, schema: Schema) -> None:
-         self.filepaths = [
-             os.path.join(path, filename)
-             for filename in sorted(os.listdir(path))
-             if os.path.isfile(os.path.join(path, filename))
-         ]
-         self.path = path
-         super().__init__(schema, dataset_id)
-
-     def serialize(self) -> dict[str, Any]:
-         return {
-             "schema": self.schema.json_schema(),
-             "path": self.path,
-             "source_type": "directory",
-         }
-
-     def __len__(self):
-         return len(self.filepaths)
-
-     def get_size(self):
-         # Get the memory size of the files in the directory
-         return sum([os.path.getsize(filepath) for filepath in self.filepaths])
-
-     def get_item(self, idx: int):
-         raise NotImplementedError("You are calling this method from an abstract class.")
-
-
- class FileSource(DataSource):
-     """FileSource returns a single File object from a single real-world local file"""
-
-     def __init__(self, path: str, dataset_id: str) -> None:
-         super().__init__(File, dataset_id)
-         self.filepath = path
-
-     def copy(self):
-         return FileSource(self.filepath, self.dataset_id)
-
-     def serialize(self) -> dict[str, Any]:
-         return {
-             "schema": self.schema.json_schema(),
-             "path": self.filepath,
-             "source_type": "file",
-         }
-
-     def __len__(self):
-         return 1
-
-     def get_size(self):
-         # Get the memory size of the filepath
-         return os.path.getsize(self.filepath)
-
-     def get_item(self, idx: int) -> DataRecord:
-         dr = DataRecord(self.schema, source_id=self.filepath)
-         dr.filename = self.filepath
-         with open(self.filepath, "rb") as f:
-             dr.contents = f.read()
-
-         return dr
-
-
- class MemorySource(DataSource):
-     """MemorySource returns multiple objects that reflect contents of an in-memory Python list
-     TODO(gerardo): Add support for other types of in-memory data structures (he has some code
-     for subclassing MemorySource on his branch)
-     """
-
-     def __init__(self, vals: Any, dataset_id: str = "default_memory_input"):
-         if isinstance(vals, (str, int, float)):
-             self.vals = [vals]
-         elif isinstance(vals, tuple):
-             self.vals = list(vals)
-         else:
-             self.vals = vals
-         schema = Schema.from_df(self.vals) if isinstance(self.vals, pd.DataFrame) else DefaultSchema
-         super().__init__(schema, dataset_id)
-
-     def copy(self):
-         return MemorySource(self.vals, self.dataset_id)
-
-     def __len__(self):
-         return len(self.vals)
-
-     def get_size(self):
-         return sum([sys.getsizeof(self.get_item(idx)) for idx in range(len(self))])
-
-     def get_item(self, idx: int) -> DataRecord:
-         dr = DataRecord(self.schema, source_id=idx)
-         if isinstance(self.vals, pd.DataFrame):
-             row = self.vals.iloc[idx]
-             for field_name in row.index:
-                 field_name_str = f"column_{field_name}" if isinstance(field_name, (int, float)) else str(field_name)
-                 setattr(dr, field_name_str, row[field_name])
-         else:
-             dr.value = self.vals[idx]
-
-         return dr
-
-
- # Third level of abstraction
- class HTMLFileDirectorySource(DirectorySource):
-     def __init__(self, path: str, dataset_id: str) -> None:
-         super().__init__(path=path, dataset_id=dataset_id, schema=WebPage)
-         assert all([filename.endswith(tuple(constants.HTML_EXTENSIONS)) for filename in self.filepaths])
-
-     def copy(self):
-         return HTMLFileDirectorySource(self.path, self.dataset_id)
-
-     def html_to_text_with_links(self, html):
-         # Parse the HTML content
-         soup = BeautifulSoup(html, "html.parser")
-
-         # Find all hyperlink tags
-         for a in soup.find_all("a"):
-             # Check if the hyperlink tag has an 'href' attribute
-             if a.has_attr("href"):
-                 # Replace the hyperlink with its text and URL in parentheses
-                 a.replace_with(f"{a.text} ({a['href']})")
-
-         # Extract text from the modified HTML
-         text = soup.get_text(separator="\n", strip=True)
-         return text
-
-     def get_item(self, idx: int) -> DataRecord:
-         filepath = self.filepaths[idx]
-         dr = DataRecord(self.schema, source_id=filepath)
-         dr.filename = os.path.basename(filepath)
-         with open(filepath) as f:
-             text_content = f.read()
-
-         html = text_content
-         tokens = html.split()[: constants.MAX_HTML_ROWS]
-         dr.html = " ".join(tokens)
-
-         stripped_html = self.html_to_text_with_links(text_content)
-         tokens = stripped_html.split()[: constants.MAX_HTML_ROWS]
-         dr.text = " ".join(tokens)
-
-         return dr
-
-
- class ImageFileDirectorySource(DirectorySource):
-     def __init__(self, path: str, dataset_id: str) -> None:
-         super().__init__(path=path, dataset_id=dataset_id, schema=ImageFile)
-         assert all([filename.endswith(tuple(constants.IMAGE_EXTENSIONS)) for filename in self.filepaths])
-
-     def copy(self):
-         return ImageFileDirectorySource(self.path, self.dataset_id)
-
-     def get_item(self, idx: int) -> DataRecord:
-         filepath = self.filepaths[idx]
-         dr = DataRecord(self.schema, source_id=filepath)
-         dr.filename = os.path.basename(filepath)
-         with open(filepath, "rb") as f:
-             dr.contents = base64.b64encode(f.read())
-         return dr
-
-
- class PDFFileDirectorySource(DirectorySource):
-     def __init__(
-         self,
-         path: str,
-         dataset_id: str,
-         pdfprocessor: str = "modal",
-         file_cache_dir: str = "/tmp",
-     ) -> None:
-         super().__init__(path=path, dataset_id=dataset_id, schema=PDFFile)
-         assert all([filename.endswith(tuple(constants.PDF_EXTENSIONS)) for filename in self.filepaths])
-         self.pdfprocessor = pdfprocessor
-         self.file_cache_dir = file_cache_dir
-
-     def copy(self):
-         return PDFFileDirectorySource(self.path, self.dataset_id)
-
-     def get_item(self, idx: int) -> DataRecord:
-         filepath = self.filepaths[idx]
-         pdf_filename = os.path.basename(filepath)
-         with open(filepath, "rb") as f:
-             pdf_bytes = f.read()
-
-         if self.pdfprocessor == "modal":
-             print("handling PDF processing remotely")
-             remote_func = modal.Function.lookup("palimpzest.tools", "processPapermagePdf")
-         else:
-             remote_func = None
-
-         # generate text_content from PDF
-         if remote_func is not None:
-             doc_json_str = remote_func.remote([pdf_bytes])
-             docdict = json.loads(doc_json_str[0])
-             doc = Document.from_json(docdict)
-             text_content = ""
-             for p in doc.pages:
-                 text_content += p.text
-         else:
-             text_content = get_text_from_pdf(pdf_filename, pdf_bytes, pdfprocessor=self.pdfprocessor, file_cache_dir=self.file_cache_dir)
-
-         # construct data record
-         dr = DataRecord(self.schema, source_id=filepath)
-         dr.filename = pdf_filename
-         dr.contents = pdf_bytes
-         dr.text_contents = text_content
-
-         return dr
-
-
- class TextFileDirectorySource(DirectorySource):
-     def __init__(self, path: str, dataset_id: str) -> None:
-         super().__init__(path=path, dataset_id=dataset_id, schema=TextFile)
-
-     def copy(self):
-         return TextFileDirectorySource(self.path, self.dataset_id)
-
-     def get_item(self, idx: int) -> DataRecord:
-         filepath = self.filepaths[idx]
-         dr = DataRecord(self.schema, source_id=filepath)
-         dr.filename = os.path.basename(filepath)
-         with open(filepath) as f:
-             dr.contents = f.read()
-         return dr
-
-
- class XLSFileDirectorySource(DirectorySource):
-     def __init__(self, path: str, dataset_id: str) -> None:
-         super().__init__(path=path, dataset_id=dataset_id, schema=XLSFile)
-         assert all([filename.endswith(tuple(constants.XLS_EXTENSIONS)) for filename in self.filepaths])
-
-     def copy(self):
-         return XLSFileDirectorySource(self.path, self.dataset_id)
-
-     def get_item(self, idx: int) -> DataRecord:
-         filepath = self.filepaths[idx]
-         dr = DataRecord(self.schema, source_id=filepath)
-         dr.filename = os.path.basename(filepath)
-         with open(filepath, "rb") as f:
-             dr.contents = f.read()
-
-         xls = pd.ExcelFile(BytesIO(dr.contents), engine="openpyxl")
-         dr.number_sheets = len(xls.sheet_names)
-         dr.sheet_names = xls.sheet_names
-         return dr
-
-
- # User-defined datasources
- class UserSource(DataSource):
-     """UserSource is a DataSource that is created by the user and not loaded from a file"""
-
-     def __init__(self, schema: Schema, dataset_id: str) -> None:
-         super().__init__(schema, dataset_id)
-
-     def serialize(self) -> dict[str, Any]:
-         return {
-             "schema": self.schema.json_schema(),
-             "source_type": "user-defined:" + self.__class__.__name__,
-         }
-
-     def __len__(self):
-         raise NotImplementedError("User needs to implement this method")
-
-     def get_size(self):
-         raise NotImplementedError("User may optionally implement this method.")
-
-     def get_item(self, idx: int) -> DataRecord:
-         raise NotImplementedError("User needs to implement this method.")
-
-     def copy(self):
-         raise NotImplementedError("User needs to implement this method.")
-
- class ValidationDataSource(UserSource):
-     """
-     TODO: update this class interface (and comment)
-     """
-
-     def get_val_length(self) -> int:
-         raise NotImplementedError("User needs to implement this method.")
-
-     def get_field_to_metric_fn(self) -> Callable:
-         raise NotImplementedError("User needs to implement this method.")
-
-     def get_item(self, idx: int, val: bool = False, include_label: bool = False) -> DataRecord:
-         raise NotImplementedError("User needs to implement this method.")
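
For readers tracking this removal: the block below is a minimal, hypothetical sketch of how a custom source was typically written against the 0.5.4 `UserSource` interface deleted above (the class name `GreetingSource` and its contents are invented for illustration). In 0.6.1 this role appears to be taken over by the new `palimpzest/core/data/datareaders.py` added in this release.

# Hypothetical 0.5.4-style user-defined source (illustrative only; not part of either release).
from palimpzest.core.data.datasources import UserSource
from palimpzest.core.elements.records import DataRecord
from palimpzest.core.lib.schemas import TextFile

class GreetingSource(UserSource):
    """Yields a few in-memory strings as TextFile records."""

    def __init__(self, dataset_id: str = "greetings") -> None:
        super().__init__(TextFile, dataset_id)
        self.items = ["hello", "bonjour", "hola"]

    def __len__(self) -> int:
        return len(self.items)

    def get_size(self) -> int:
        # rough in-memory size, mirroring MemorySource.get_size()
        return sum(len(s) for s in self.items)

    def get_item(self, idx: int) -> DataRecord:
        dr = DataRecord(self.schema, source_id=idx)
        dr.filename = f"greeting_{idx}.txt"
        dr.contents = self.items[idx]
        return dr

    def copy(self) -> "GreetingSource":
        return GreetingSource(self.dataset_id)

In 0.5.4 such a source would then be registered via `DataDirectory.register_user_source(...)`; see the `datamanager.py` diff below.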
palimpzest/datamanager/__init__.py: file without changes (empty file removed)
palimpzest/datamanager/datamanager.py (deleted)
@@ -1,300 +0,0 @@
- import os
- import pickle
- from threading import Lock
-
- import pandas as pd
- import yaml
-
- from palimpzest import constants
- from palimpzest.config import Config
- from palimpzest.constants import DEFAULT_DATASET_ID_CHARS, MAX_DATASET_ID_CHARS, PZ_DIR
- from palimpzest.core.data.datasources import (
-     DataSource,
-     FileSource,
-     HTMLFileDirectorySource,
-     ImageFileDirectorySource,
-     MemorySource,
-     PDFFileDirectorySource,
-     TextFileDirectorySource,
-     UserSource,
-     XLSFileDirectorySource,
- )
- from palimpzest.utils.hash_helpers import hash_for_id
-
-
- class DataDirectorySingletonMeta(type):
-     _instances = {}
-     _lock: Lock = Lock()
-
-     def __call__(cls, *args, **kwargs):
-         with cls._lock:
-             if cls not in cls._instances:
-                 instance = super().__call__(*args, **kwargs)
-                 cls._instances[cls] = instance
-             return cls._instances[cls]
-
-
- class CacheService:
-     """This class manages the cache for the DataDirectory and other misc PZ components.
-     Eventually modify this to be durable and to have expiration policies."""
-
-     def __init__(self):
-         self.all_caches = {}
-
-     def get_cached_data(self, cache_name, cache_key):
-         return self.all_caches.setdefault(cache_name, {}).get(cache_key, None)
-
-     def put_cached_data(self, cache_name, cache_key, cache_val):
-         self.all_caches.setdefault(cache_name, {})[cache_key] = cache_val
-
-     def rm_cached_data(self, cache_name):
-         if cache_name in self.all_caches:
-             del self.all_caches[cache_name]
-
-     def rm_cache(self):
-         self.all_caches = {}
-
-
- class DataDirectory(metaclass=DataDirectorySingletonMeta):
-     """The DataDirectory is a registry of data sources."""
-
-     def __init__(self):
-         self._registry = {}
-         self._tempRegistry = {}
-         self._cache = {}
-         self._tempCache = {}
-         self.cacheService = CacheService()
-
-         # set up data directory
-         self._dir = PZ_DIR
-         current_config_path = os.path.join(self._dir, "current_config.yaml")
-         if not os.path.exists(self._dir):
-             os.makedirs(self._dir)
-             os.makedirs(self._dir + "/data/registered")
-             os.makedirs(self._dir + "/data/cache")
-             with open(self._dir + "/data/cache/registry.pkl", "wb") as f:
-                 pickle.dump(self._registry, f)
-
-             # create default config
-             default_config = Config("default")
-             default_config.set_current_config()
-
-         # read current config (and dict. of configs) from disk
-         self._current_config = None
-         if os.path.exists(current_config_path):
-             with open(current_config_path) as f:
-                 current_config_dict = yaml.safe_load(f)
-                 self._current_config = Config(current_config_dict["current_config_name"])
-
-         # initialize the file cache directory, defaulting to the system's temporary directory "tmp/pz"
-         pz_file_cache_dir = self.current_config.get("filecachedir")
-         if pz_file_cache_dir and not os.path.exists(pz_file_cache_dir):
-             os.makedirs(pz_file_cache_dir)
-
-         # Unpickle the registry of data sources
-         if os.path.exists(self._dir + "/data/cache/registry.pkl"):
-             with open(self._dir + "/data/cache/registry.pkl", "rb") as f:
-                 self._registry = pickle.load(f)
-
-         # Iterate through all items in the cache directory, and rebuild the table of entries
-         for root, _, files in os.walk(self._dir + "/data/cache"):
-             for file in files:
-                 if file.endswith(".cached"):
-                     cache_id = file[:-7]
-                     self._cache[cache_id] = root + "/" + file
-
-     @property
-     def current_config(self):
-         if not self._current_config:
-             raise Exception("No current config found.")
-         return self._current_config
-
-     def get_cache_service(self):
-         return self.cacheService
-
-     def get_config(self):
-         return self.current_config._load_config()
-
-     def get_file_cache_dir(self):
-         return self.current_config.get("filecachedir")
-
-     #
-     # These methods handle properly registered data files, meant to be kept over the long haul
-     #
-     def register_local_directory(self, path, dataset_id):
-         """Register a local directory as a data source."""
-         self._registry[dataset_id] = ("dir", path)
-         with open(self._dir + "/data/cache/registry.pkl", "wb") as f:
-             pickle.dump(self._registry, f)
-
-     def register_local_file(self, path, dataset_id):
-         """Register a local file as a data source."""
-         self._registry[dataset_id] = ("file", path)
-         with open(self._dir + "/data/cache/registry.pkl", "wb") as f:
-             pickle.dump(self._registry, f)
-
-     def get_or_register_local_source(self, dataset_id_or_path):
-         """Return a dataset from the registry."""
-         if dataset_id_or_path in self._tempRegistry or dataset_id_or_path in self._registry:
-             return self.get_registered_dataset(dataset_id_or_path)
-         else:
-             if os.path.isfile(dataset_id_or_path):
-                 self.register_local_file(dataset_id_or_path, dataset_id_or_path)
-             elif os.path.isdir(dataset_id_or_path):
-                 self.register_local_directory(dataset_id_or_path, dataset_id_or_path)
-             else:
-                 raise Exception(f"Path {dataset_id_or_path} is invalid. Does not point to a file or directory.")
-             return self.get_registered_dataset(dataset_id_or_path)
-
-     #TODO: need to revisit how to best leverage cache for memory sources.
-     def register_memory_source(self, vals, dataset_id):
-         """Register an in-memory dataset as a data source"""
-         self._tempRegistry[dataset_id] = ("memory", vals)
-
-     def get_or_register_memory_source(self, vals):
-         dataset_id = hash_for_id(str(vals), max_chars=DEFAULT_DATASET_ID_CHARS)
-         if dataset_id in self._tempRegistry:
-             return self.get_registered_dataset(dataset_id)
-         else:
-             self.register_memory_source(vals, dataset_id)
-             return self.get_registered_dataset(dataset_id)
-
-     def register_user_source(self, src: UserSource, dataset_id: str):
-         """Register a user source as a data source."""
-         # user sources are always ephemeral
-         self._tempRegistry[dataset_id] = ("user", src)
-
-     def get_registered_dataset(self, dataset_id):
-         """Return a dataset from the registry."""
-         if dataset_id in self._tempRegistry:
-             entry, rock = self._tempRegistry[dataset_id]
-         elif dataset_id in self._registry:
-             entry, rock = self._registry[dataset_id]
-         else:
-             raise Exception(f"Dataset {dataset_id} not found in the registry.")
-
-         if entry == "dir":
-             if all([f.endswith(tuple(constants.IMAGE_EXTENSIONS)) for f in os.listdir(rock)]):
-                 return ImageFileDirectorySource(rock, dataset_id)
-             elif all([f.endswith(tuple(constants.PDF_EXTENSIONS)) for f in os.listdir(rock)]):
-                 pdfprocessor = self.current_config.get("pdfprocessor")
-                 if not pdfprocessor:
-                     raise Exception("No PDF processor found in the current config.")
-                 file_cache_dir = self.get_file_cache_dir()
-                 if not file_cache_dir:
-                     raise Exception("No file cache directory found.")
-                 return PDFFileDirectorySource(
-                     path=rock, dataset_id=dataset_id, pdfprocessor=pdfprocessor, file_cache_dir=file_cache_dir
-                 )
-             elif all([f.endswith(tuple(constants.XLS_EXTENSIONS)) for f in os.listdir(rock)]):
-                 return XLSFileDirectorySource(rock, dataset_id)
-             elif all([f.endswith(tuple(constants.HTML_EXTENSIONS)) for f in os.listdir(rock)]):
-                 return HTMLFileDirectorySource(rock, dataset_id)
-             else:
-                 return TextFileDirectorySource(rock, dataset_id)
-
-         elif entry == "file":
-             return FileSource(rock, dataset_id)
-         elif entry == "memory":
-             return MemorySource(rock, dataset_id)
-         elif entry == "user":
-             src = rock
-             return src
-         else:
-             raise Exception("Unknown entry type")
-
-     def get_registered_dataset_type(self, dataset_id):
-         """Return the type of the given dataset in the registry."""
-         if dataset_id in self._tempRegistry:
-             entry, _ = self._tempRegistry[dataset_id]
-         elif dataset_id in self._registry:
-             entry, _ = self._registry[dataset_id]
-         else:
-             raise Exception("Cannot find dataset", dataset_id, "in the registry.")
-
-         return entry
-
-     def list_registered_datasets(self):
-         """Return a list of registered datasets."""
-         return self._registry.items()
-
-     def rm_registered_dataset(self, dataset_id):
-         """Remove a dataset from the registry."""
-         del self._registry[dataset_id]
-         with open(self._dir + "/data/cache/registry.pkl", "wb") as f:
-             pickle.dump(self._registry, f)
-
-     #
-     # These methods handle cached results. They are meant to be persisted for performance reasons,
-     # but can always be recomputed if necessary.
-     #
-     def get_cached_result(self, cache_id):
-         """Return a cached result."""
-         cached_result = None
-         if cache_id not in self._cache:
-             return cached_result
-
-         with open(self._cache[cache_id], "rb") as f:
-             cached_result = pickle.load(f)
-
-         return MemorySource(cached_result, cache_id)
-
-     def clear_cache(self, keep_registry=False):
-         """Clear the cache."""
-         self._cache = {}
-         self._tempCache = {}
-
-         # Delete all files in the cache directory (except registry.pkl if keep_registry=True)
-         for root, _, files in os.walk(self._dir + "/data/cache"):
-             for file in files:
-                 if os.path.basename(file) != "registry.pkl" or keep_registry is False:
-                     os.remove(root + "/" + file)
-
-     def has_cached_answer(self, cache_id):
-         """Check if a dataset is in the cache."""
-         return cache_id in self._cache
-
-     def open_cache(self, cache_id):
-         if cache_id is not None and cache_id not in self._cache and cache_id not in self._tempCache:
-             self._tempCache[cache_id] = []
-             return True
-         return False
-
-     def append_cache(self, cache_id, data):
-         self._tempCache[cache_id].append(data)
-
-     def close_cache(self, cache_id):
-         """Close the cache."""
-         filename = self._dir + "/data/cache/" + cache_id + ".cached"
-         try:
-             with open(filename, "wb") as f:
-                 pickle.dump(self._tempCache[cache_id], f)
-         except pickle.PicklingError:
-             print("Warning: Failed to save cache due to pickling error")
-             os.remove(filename)
-         del self._tempCache[cache_id]
-         self._cache[cache_id] = filename
-
-     def exists(self, dataset_id):
-         return dataset_id in self._registry
-
-     def get_path(self, dataset_id):
-         if dataset_id not in self._registry:
-             raise Exception("Cannot find dataset", dataset_id, "in the registry.")
-         entry, path = self._registry[dataset_id]
-         return path
-
-     def get_or_register_dataset(self, source: str | list | pd.DataFrame | DataSource):
-         if isinstance(source, str):
-             if len(source) > MAX_DATASET_ID_CHARS:
-                 raise Exception(f"""Dataset ID {source} is too long. Maximum length is {MAX_DATASET_ID_CHARS} characters.
-                 If you're passing a string data source, please wrap it in a list or pd.DataFrame.""")
-             source = self.get_or_register_local_source(source)
-         elif isinstance(source, (list, pd.DataFrame)):
-             source = self.get_or_register_memory_source(source)
-         elif isinstance(source, DataSource):
-             pass
-         else:
-             raise Exception(f"Invalid source type: {type(source)}, We only support pd.DataFrame, list, and str")
-
-         return source
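
The removed `DataDirectory` was a process-wide singleton that persisted a dataset registry under `PZ_DIR` and resolved registered paths to the directory sources deleted above. Below is a minimal, hypothetical usage sketch of that 0.5.4 API (the path and dataset ID are invented for illustration):

# Hypothetical usage of the 0.5.4 DataDirectory removed in this release.
from palimpzest.datamanager.datamanager import DataDirectory

dd = DataDirectory()  # singleton via DataDirectorySingletonMeta; state lives under PZ_DIR
dd.register_local_directory("/data/papers", "papers")  # persisted to data/cache/registry.pkl
source = dd.get_registered_dataset("papers")           # dispatches to a *FileDirectorySource by file extension
print(len(source), source.get_size())

In 0.6.1 this registry, along with `palimpzest/config.py` and the `cli/` entry point it depended on, is removed entirely (see the file list above).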