modaic-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modaic might be problematic.

Files changed (39)
  1. modaic/__init__.py +25 -0
  2. modaic/agents/rag_agent.py +33 -0
  3. modaic/agents/registry.py +84 -0
  4. modaic/auto_agent.py +228 -0
  5. modaic/context/__init__.py +34 -0
  6. modaic/context/base.py +1064 -0
  7. modaic/context/dtype_mapping.py +25 -0
  8. modaic/context/table.py +585 -0
  9. modaic/context/text.py +94 -0
  10. modaic/databases/__init__.py +35 -0
  11. modaic/databases/graph_database.py +269 -0
  12. modaic/databases/sql_database.py +355 -0
  13. modaic/databases/vector_database/__init__.py +12 -0
  14. modaic/databases/vector_database/benchmarks/baseline.py +123 -0
  15. modaic/databases/vector_database/benchmarks/common.py +48 -0
  16. modaic/databases/vector_database/benchmarks/fork.py +132 -0
  17. modaic/databases/vector_database/benchmarks/threaded.py +119 -0
  18. modaic/databases/vector_database/vector_database.py +722 -0
  19. modaic/databases/vector_database/vendors/milvus.py +408 -0
  20. modaic/databases/vector_database/vendors/mongodb.py +0 -0
  21. modaic/databases/vector_database/vendors/pinecone.py +0 -0
  22. modaic/databases/vector_database/vendors/qdrant.py +1 -0
  23. modaic/exceptions.py +38 -0
  24. modaic/hub.py +305 -0
  25. modaic/indexing.py +127 -0
  26. modaic/module_utils.py +341 -0
  27. modaic/observability.py +275 -0
  28. modaic/precompiled.py +429 -0
  29. modaic/query_language.py +321 -0
  30. modaic/storage/__init__.py +3 -0
  31. modaic/storage/file_store.py +239 -0
  32. modaic/storage/pickle_store.py +25 -0
  33. modaic/types.py +287 -0
  34. modaic/utils.py +21 -0
  35. modaic-0.1.0.dist-info/METADATA +281 -0
  36. modaic-0.1.0.dist-info/RECORD +39 -0
  37. modaic-0.1.0.dist-info/WHEEL +5 -0
  38. modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
  39. modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/context/dtype_mapping.py ADDED
@@ -0,0 +1,25 @@
+ import numpy as np
+
+ INTEGER_DTYPE_MAPPING = {
+     np.int8: "TINYINT",
+     np.int16: "SMALLINT",
+     np.int32: "INT",
+     np.int64: "BIGINT",
+ }
+
+ SPECIAL_INTEGER_DTYPE_MAPPING = {"Int64": "BIGINT", "UInt64": "BIGINT UNSIGNED"}
+
+ FLOAT_DTYPE_MAPPING = {
+     np.float16: "FLOAT",
+     np.float32: "FLOAT",
+     np.float64: "DOUBLE",
+ }
+
+ OTHER_DTYPE_MAPPING = {
+     "boolean": "BOOLEAN",
+     "datetime": "DATETIME",
+     "timedelta": "TIME",
+     "string": "VARCHAR(255)",
+     "category": "VARCHAR(255)",
+     "default": "TEXT",
+ }
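For orientation, a minimal sketch (not part of the package) of how these tables resolve a pandas dtype to a MySQL column type, mirroring the lookup order used by `_pandas_to_mysql_dtype` in modaic/context/table.py below: nullable pandas integers are matched by string name first, then plain numpy dtypes by their scalar type.

```python
import pandas as pd

from modaic.context.dtype_mapping import (
    INTEGER_DTYPE_MAPPING,
    SPECIAL_INTEGER_DTYPE_MAPPING,
)

def integer_sql_type(series: pd.Series) -> str:
    # Nullable pandas integers ("Int64", "UInt64") are keyed by their string name
    if str(series.dtype) in SPECIAL_INTEGER_DTYPE_MAPPING:
        return SPECIAL_INTEGER_DTYPE_MAPPING[str(series.dtype)]
    # Plain numpy dtypes are keyed by their scalar type (np.int8, np.int64, ...)
    return INTEGER_DTYPE_MAPPING.get(series.dtype.type, "INT")

print(integer_sql_type(pd.Series([1, 2], dtype="Int64")))  # BIGINT
print(integer_sql_type(pd.Series([1, 2], dtype="int32")))  # INT
```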
modaic/context/table.py ADDED
@@ -0,0 +1,585 @@
+ import hashlib
+ import random
+ import re
+ import warnings
+ from abc import ABC
+ from contextlib import contextmanager
+ from pathlib import Path
+ from typing import IO, Any, Dict, List, Literal, Optional, Set
+
+ import duckdb
+ import numpy as np
+ import pandas as pd
+ from pydantic import PrivateAttr, ValidatorFunctionWrapHandler, field_validator, model_validator
+
+ from modaic.context.base import Context, HydratedAttr, requires_hydration
+ from modaic.types import Field
+
+ from ..storage.file_store import FileStore
+ from .dtype_mapping import (
+     FLOAT_DTYPE_MAPPING,
+     INTEGER_DTYPE_MAPPING,
+     OTHER_DTYPE_MAPPING,
+     SPECIAL_INTEGER_DTYPE_MAPPING,
+ )
+ from .text import Text
+
+
+ class BaseTable(Context, ABC):
+     name: str
+     _df: pd.DataFrame = PrivateAttr()
+
+     @field_validator("name", mode="before")
+     @classmethod
+     def sanitize_name(cls, name: str) -> str:
+         return sanitize_name(name)
+
+     def column_samples(self, col: str) -> list[Any]:
+         """
+         Return up to 3 distinct sample values from the given column.
+
+         Picks at most three unique, non-null, short (<64 chars) values from
+         the column, favoring speed by sampling after de-duplicating values.
+
+         Args:
+             col: Column name to sample from.
+
+         Returns:
+             A list with up to three JSON-serializable sample values.
+         """
+         # TODO: look up column
+         series = self._df[col]
+
+         valid_values = [x for x in series.dropna().unique() if pd.notnull(x) and len(str(x)) < 64]
+         sample_values = random.sample(valid_values, min(3, len(valid_values)))
+
+         # Convert numpy types to Python native types for JSON serialization
+         converted_values = []
+         for val in sample_values:
+             if hasattr(val, "item"):  # numpy types have .item() method
+                 converted_values.append(val.item())
+             else:
+                 converted_values.append(val)
+
+         return converted_values
+
+     def get_col(self, col_name: str) -> pd.Series:
+         """
+         Gets a single column from the table.
+
+         Args:
+             col_name: Name of the column to get.
+
+         Returns:
+             The specified column as a pandas Series.
+         """
+         return self._df[col_name]
+
+     def schema_info(self) -> dict[str, Any]:
+         column_dict = {}
+         for col in self._df.columns:
+             if isinstance(self._df[col], pd.DataFrame):
+                 raise ValueError(f"Column {col} is a DataFrame, which is not supported.")
+             column_dict[col] = {
+                 "type": _pandas_to_mysql_dtype(self._df[col].dtype),
+                 "sample_values": self.column_samples(col),
+             }
+
+         schema_dict = {"table_name": self.name, "column_dict": column_dict}
+
+         return schema_dict
+
+     def query(self, query: str) -> pd.DataFrame:
+         """
+         Queries the table using DuckDB SQL.
+
+         Notes:
+             - The DataFrame is registered under this table's sanitized `name`,
+               so refer to it by that name in the query.
+
+         Example:
+             ```python
+             # Select a few rows
+             df = table.query(f"SELECT * FROM {table.name} LIMIT 5")
+
+             # Aggregate over a column
+             df = table.query(f"SELECT category, COUNT(*) AS n FROM {table.name} GROUP BY category")
+             ```
+         """
+         return duckdb.query_df(self._df, self.name, query).to_df()
+
+     def embedme(self) -> str:
+         """
+         embedme method for table. Returns a markdown representation of the table.
+         """
+         return self.markdown()
+
+     def markdown(self) -> str:  # TODO: add example
+         """
+         Converts the table to markdown format.
+         Returns a markdown representation of the table with the table name as header.
+         """
+         content = ""
+         content += f"Table name: {self.name}\n"
+
+         # Add header row
+         columns = [str(col) for col in self._df.columns]
+         content += "| " + " | ".join(columns) + " |\n"
+
+         # Add header separator
+         content += "| " + " | ".join(["---"] * len(columns)) + " |\n"
+
+         # Add data rows
+         for _, row in self._df.iterrows():
+             row_values = []
+             for value in row:
+                 if pd.isna(value) or value is None:
+                     row_values.append("")
+                 else:
+                     row_values.append(str(value))
+             content += "| " + " | ".join(row_values) + " |\n"
+
+         return content
+
+     def to_text(self) -> Text:
+         """
+         Converts the table to markdown and returns a Text context object.
+         """
+         return Text(text=self.markdown())
+
+
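As a hedged usage sketch of the BaseTable API above (using the concrete Table subclass defined just below):

```python
import pandas as pd
from modaic.context.table import Table

table = Table(name="cities", df=pd.DataFrame({"city": ["Oslo", "Lima"], "pop_m": [0.7, 10.7]}))
print(table.schema_info())   # column types plus up-to-3 sample values
print(table.markdown())      # the markdown rendering that embedme() returns
big = table.query(f"SELECT city FROM {table.name} WHERE pop_m > 1")  # DuckDB SQL
```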
+ class Table(BaseTable):
+     content: str = ""
+     df: List[Dict[str, Any]] = Field(hidden=True)
+     _df: pd.DataFrame = PrivateAttr()
+
+     @model_validator(mode="wrap")
+     @classmethod
+     def truncate(cls, data: Any, handler: ValidatorFunctionWrapHandler) -> "Table":
+         df = data["df"]
+         if isinstance(df, pd.DataFrame):
+             serialized_df = df.to_dict(orient="records")
+             pd_df = df
+         else:
+             serialized_df = df
+             pd_df = pd.DataFrame(serialized_df)
+         data["df"] = serialized_df
+         self = handler(data)
+         self._df = pd_df
+         return self
+
+
+ class TableFile(BaseTable):
+     """
+     A Context object to represent table documents such as excel, csv and tsv files.
+     """
+
+     file_ref: str
+     file_type: Literal["xls", "xlsx", "csv", "tsv"]
+     sheet_name: Optional[str] = None
+     _df: pd.DataFrame = HydratedAttr()
+
+     @classmethod
+     def from_file(
+         cls,
+         file_ref: str,
+         file: Path | IO,
+         file_type: Literal["xls", "xlsx", "csv", "tsv"] = "xls",
+         name: Optional[str] = None,
+         sheet_name: Optional[str] = None,
+         **kwargs,
+     ) -> "TableFile":
+         # NOTE: Always set a sheet name for excel files
+         if file_type in ["xls", "xlsx"] and sheet_name is None:
+             xls = pd.ExcelFile(file)
+             sheet_name = xls.sheet_names[0]
+         # NOTE: ensure we always have a semantic name for the table
+         if name is None and file_type in ["xls", "xlsx"]:
+             name = sheet_name
+         elif name is None:
+             name = file_ref.split("/")[-1].split(".")[0]
+         instance = cls(
+             name=name,
+             file_ref=file_ref,
+             file_type=file_type,
+             sheet_name=sheet_name,
+             **kwargs,
+         )
+         instance._hydrate_from_file(file)
+         return instance
+
+     @classmethod
+     def from_file_store(cls, file_ref: str, file_store: FileStore, **kwargs) -> "TableFile":
+         file_result = file_store.get(file_ref)
+         if "sheet_name" in file_result.metadata:
+             sheet_name = file_result.metadata["sheet_name"]
+         else:
+             sheet_name = None
+         return cls.from_file(
+             file_ref,
+             file_result.file,
+             file_result.type,
+             name=file_result.name,
+             sheet_name=sheet_name,
+             **kwargs,
+         )
+
+     def hydrate(self, file_store: FileStore):
+         # FileStore.get returns a result object; pass its file handle through
+         file_result = file_store.get(self.file_ref)
+         self._hydrate_from_file(file_result.file)
+
+     def _hydrate_from_file(self, file: Path | IO):
+         if self.file_type in ["xls", "xlsx"]:
+             if self.sheet_name is None:
+                 df = pd.read_excel(file)
+             else:
+                 df = pd.read_excel(file, sheet_name=self.sheet_name)
+         elif self.file_type == "csv":
+             df = pd.read_csv(file)
+         elif self.file_type == "tsv":
+             df = pd.read_csv(file, sep="\t")
+         else:
+             raise ValueError(f"TableFile got unsupported file type: {self.file_type}")
+         self._df = _process_df(df)
+
+     @requires_hydration
+     def column_samples(self, col: str) -> list[Any]:
+         """
+         Returns up to 3 distinct sample values from the given column.
+         """
+         return super().column_samples(col)
+
+     @requires_hydration
+     def get_col(self, col_name: str) -> pd.Series:
+         """
+         Gets a single column from the table.
+         """
+         return super().get_col(col_name)
+
+     @requires_hydration
+     def schema_info(self) -> dict[str, Any]:
+         """
+         Returns the schema information of the table.
+         """
+         return super().schema_info()
+
+     @requires_hydration
+     def query(self, query: str) -> pd.DataFrame:
+         """
+         Queries the table using DuckDB SQL.
+         """
+         return super().query(query)
+
+     @requires_hydration
+     def embedme(self) -> str:
+         """
+         Returns a markdown representation of the table for embedding.
+         """
+         return super().embedme()
+
+     @requires_hydration
+     def markdown(self) -> str:
+         """
+         Converts the table to markdown format.
+         """
+         return super().markdown()
+
+     @requires_hydration
+     def to_text(self) -> Text:
+         """
+         Converts the table to markdown and returns a Text context object.
+         """
+         return super().to_text()
+
+
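A hedged, self-contained sketch of the TableFile flow above, using from_file with a temporary CSV (FileStore-backed loading follows the same path via from_file_store):

```python
import pandas as pd
from modaic.context.table import TableFile

pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}).to_csv("demo.csv", index=False)

with open("demo.csv") as fh:
    tf = TableFile.from_file("demo.csv", fh, file_type="csv")  # hydrates immediately

# name was derived from file_ref ("demo"); @requires_hydration methods now work
print(tf.query(f"SELECT COUNT(*) AS n FROM {tf.name}"))
```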
+ class BaseTabbedTable(Context):
+     names: Set[str]
+     _tables: Optional[Dict[str, pd.DataFrame]] = PrivateAttr()
+     _sql_db: Optional[duckdb.DuckDBPyConnection] = PrivateAttr()
+
+     def init_sql(self):
+         """
+         Initializes an in-memory SQL database for querying.
+         """
+         self._sql_db = duckdb.connect(database=":memory:")
+         for table_name, table in self._tables.items():
+             self._sql_db.register(table_name, table)
+
+     def close_sql(self):
+         """
+         Closes the in-memory SQL database.
+         """
+         self._sql_db.close()
+         self._sql_db = None
+
+     @contextmanager
+     def sql(self):
+         self.init_sql()
+         try:
+             yield self._sql_db
+         finally:
+             self.close_sql()
+
+     def query(self, query: str) -> Table:
+         """
+         Queries the in-memory SQL database.
+         """
+         if self._sql_db is None:
+             raise ValueError(
+                 "Attempted to run query on TabbedTable without initializing the SQL database. Use `with tabbed_table.sql():` or `tabbed_table.init_sql()`"
+             )
+         try:
+             df = self._sql_db.execute(query).fetchdf()
+             return Table(df=df, name="query_result")
+         except Exception as e:
+             raise ValueError("Error querying SQL database") from e
+
+
+ class TabbedTable(BaseTabbedTable):
+     tables: Dict[str, List[Dict[str, Any]]]
+     _tables: Dict[str, pd.DataFrame] = PrivateAttr()
+     _sql_db: duckdb.DuckDBPyConnection = PrivateAttr()
+
+     @field_validator("tables", mode="before")
+     @classmethod
+     def serialize_tables(cls, tables: Dict[str, pd.DataFrame] | Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
+         first_val = next(iter(tables.values()))
+         if isinstance(first_val, pd.DataFrame):
+             serialized_tables = {k: v.to_dict(orient="records") for k, v in tables.items()}
+         else:
+             serialized_tables = tables
+         return serialized_tables
+
+     @model_validator(mode="after")
+     def set_tables(self) -> "TabbedTable":
+         self._tables = {k: pd.DataFrame(v) for k, v in self.tables.items()}
+         return self
+
+
+ class TabbedTableFile(BaseTabbedTable):
+     file_ref: str
+     file_type: Literal["excel"] = "excel"
+     _tables: Dict[str, pd.DataFrame] = HydratedAttr()
+     _sql_db: duckdb.DuckDBPyConnection = HydratedAttr()
+
+     @classmethod
+     def from_file(
+         cls,
+         file_ref: str,
+         file: Path | IO,
+         file_type: Literal["excel"] = "excel",
+         names: Optional[List[str]] = None,
+         **kwargs,
+     ) -> "TabbedTableFile":
+         if file_type == "excel":
+             xls = pd.ExcelFile(file)
+             if names is None:
+                 names = xls.sheet_names
+             else:
+                 for name in names:
+                     if name not in xls.sheet_names:
+                         raise ValueError(f"Sheet name {name} not found in file")
+         elif names is None:
+             raise ValueError(f"names must be provided for file type: {file_type}")
+
+         instance = cls(
+             file_ref=file_ref,
+             file_type=file_type,
+             names=set(names),
+             **kwargs,
+         )
+         instance._hydrate_from_file(file)
+         return instance
+
+     @classmethod
+     def from_file_store(cls, file_ref: str, file_store: FileStore, **kwargs) -> "TabbedTableFile":
+         file_result = file_store.get(file_ref)
+         # `from_file` expects a list of sheet names; fall back to all sheets
+         if "sheet_name" in file_result.metadata:
+             names = [file_result.metadata["sheet_name"]]
+         else:
+             names = None
+         return cls.from_file(
+             file_ref,
+             file_result.file,
+             file_result.type,
+             names=names,
+             **kwargs,
+         )
+
+     def hydrate(self, file_store: FileStore):
+         # FileStore.get returns a result object; pass its file handle through
+         file_result = file_store.get(self.file_ref)
+         self._hydrate_from_file(file_result.file)
+
+     def _hydrate_from_file(self, file: Path | IO):
+         # pandas accepts both paths and file-like objects directly
+         if self.file_type == "excel":
+             df_dict = pd.read_excel(file, sheet_name=list(self.names))
+         else:
+             raise ValueError(f"Unsupported file type: {self.file_type}")
+
+         if isinstance(df_dict, dict):
+             self._tables = {name: _process_df(df) for name, df in df_dict.items()}
+         else:
+             self._tables = {next(iter(self.names)): _process_df(df_dict)}
+
+     @classmethod
+     def from_gsheet(cls):
+         raise NotImplementedError("Not implemented")
+
+     @classmethod
+     def from_sharepoint(cls):
+         raise NotImplementedError("Not implemented")
+
+     @classmethod
+     def from_s3(cls):
+         raise NotImplementedError("Not implemented")
+
+
+ def downcast_pd_series(series: pd.Series) -> pd.Series:
+     try:
+         return pd.to_numeric(series, downcast="integer")
+     except ValueError:
+         pass
+     try:
+         return pd.to_numeric(series, downcast="float")
+     except ValueError:
+         pass
+     try:
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore", UserWarning)
+             return pd.to_datetime(series)
+     except ValueError:
+         pass
+     return series
+
+
+ def _pandas_to_mysql_dtype(dtype: np.dtype) -> str:
+     if pd.api.types.is_integer_dtype(dtype):
+         if str(dtype) in SPECIAL_INTEGER_DTYPE_MAPPING:
+             return SPECIAL_INTEGER_DTYPE_MAPPING[str(dtype)]
+         # look up via dtype.type so plain numpy dtypes match the mapping keys
+         return INTEGER_DTYPE_MAPPING.get(dtype.type, "INT")
+
+     elif pd.api.types.is_float_dtype(dtype):
+         return FLOAT_DTYPE_MAPPING.get(dtype.type, "FLOAT")
+
+     elif pd.api.types.is_bool_dtype(dtype):
+         return OTHER_DTYPE_MAPPING["boolean"]
+
+     elif pd.api.types.is_datetime64_any_dtype(dtype):
+         return OTHER_DTYPE_MAPPING["datetime"]
+
+     elif pd.api.types.is_timedelta64_dtype(dtype):
+         return OTHER_DTYPE_MAPPING["timedelta"]
+
+     elif pd.api.types.is_string_dtype(dtype):
+         return OTHER_DTYPE_MAPPING["string"]
+
+     elif isinstance(dtype, pd.CategoricalDtype):
+         return OTHER_DTYPE_MAPPING["category"]
+
+     else:
+         return OTHER_DTYPE_MAPPING["default"]
+
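A short sketch of how downcast_pd_series (above) tightens the SQL types that _pandas_to_mysql_dtype reports, assuming the functions as defined in this file:

```python
import pandas as pd
from modaic.context.table import _pandas_to_mysql_dtype, downcast_pd_series

raw = pd.Series(["1", "2", "3"])            # object dtype: numbers stored as strings
small = downcast_pd_series(raw)             # parsed and downcast to int8
print(_pandas_to_mysql_dtype(raw.dtype))    # VARCHAR(255) (object falls under "string")
print(_pandas_to_mysql_dtype(small.dtype))  # TINYINT
```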
+
+ def sanitize_name(original_name: str) -> str:  # TODO: also sanitize SQL keywords
+     """
+     Sanitizes names of files and directories.
+
+     Rules:
+     1. Remove the file extension
+     2. Replace illegal characters (including `-`) with underscores
+     3. Collapse consecutive underscores into a single underscore
+     4. Strip leading/trailing underscores
+     5. Convert to lowercase
+     6. If the name starts with a digit, prefix it with "t_"
+     7. If the name is longer than 64 chars, truncate and append a hash suffix
+
+     Args:
+         original_name: The name to sanitize.
+
+     Returns:
+         The sanitized name.
+     """
+     # Remove the file extension (only the last dot-separated segment)
+     name = original_name.rsplit(".", 1)[0]
+
+     # Replace illegal characters with underscores
+     name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+     # Collapse consecutive underscores
+     name = re.sub(r"_+", "_", name)
+
+     # Remove leading/trailing underscores
+     if len(name) > 2:
+         name = name.strip("_")
+
+     # Convert to lowercase
+     name = name.lower()
+
+     # Ensure name is non-empty and does not start with a number
+     if not name or name[0].isdigit():
+         name = "t_" + name
+
+     # If name is longer than 64 chars, truncate and add a hash suffix
+     if len(name) > 64:
+         prefix = name[:20].rstrip("_")
+         hash_suffix = hashlib.md5(name.encode("utf-8")).hexdigest()[:8]
+         name = f"{prefix}_{hash_suffix}"
+
+     return name
+
+
+ def is_valid_table_name(name: str) -> bool:
+     """
+     Checks if a name is a valid table name.
+
+     Args:
+         name: The name to validate.
+
+     Returns:
+         True if the name is valid, False otherwise.
+     """
+     valid = (
+         name.islower()
+         and not name.startswith("_")
+         and not name.endswith("_")
+         and not name[0].isdigit()
+         and len(name) <= 64
+     )
+     return valid
+
+
+ def downcast_column(col: pd.Series) -> pd.Series:
+     """
+     Downcasts a column to the smallest possible dtype.
+     """
+     return downcast_pd_series(col)
+
+
+ def _sanitize_columns(df: pd.DataFrame) -> None:
+     columns = [sanitize_name(col) for col in df.columns]
+     columns = ["No" if i == 0 and (not col or pd.isna(col)) else col for i, col in enumerate(columns)]
+
+     seen = {}
+     new_columns = []
+     for col in columns:
+         if col in seen:
+             seen[col] += 1
+             new_columns.append(f"{col}_{seen[col]}")
+         else:
+             seen[col] = 0
+             new_columns.append(col)
+     df.columns = new_columns
+
+
+ def _process_df(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Processes the dataframe to ensure it is in the correct format.
+     """
+     # Downcast columns
+     df = df.apply(downcast_pd_series)
+     _sanitize_columns(df)
+     return df
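To illustrate the sanitization rules above, a few hedged examples (outputs per the implementation as written):

```python
from modaic.context.table import is_valid_table_name, sanitize_name

print(sanitize_name("Q1 Report-2024.xlsx"))  # q1_report_2024
print(sanitize_name("2024 sales.csv"))       # t_2024_sales
long = sanitize_name("x" * 80)               # 20-char prefix + "_" + 8-char md5 suffix
print(is_valid_table_name(long))             # True
```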
modaic/context/text.py ADDED
@@ -0,0 +1,94 @@
+ from pathlib import Path
+ from typing import IO, Callable, Iterable, Iterator, List, Literal
+
+ from modaic.storage.file_store import FileStore
+
+ from .base import Context, HydratedAttr, requires_hydration
+
+
+ class Text(Context):
+     """
+     Text context class.
+     """
+
+     text: str
+
+     def chunk_text(
+         self,
+         chunk_fn: Callable[[str], Iterable[str | tuple[str, dict]]],
+         kwargs: dict = None,
+     ):
+         def chunk_text_fn(text_context: "Text") -> Iterator["Text"]:
+             for chunk in chunk_fn(text_context.text, **(kwargs or {})):
+                 yield Text(text=chunk)
+
+         self.chunk_with(chunk_text_fn)
+
+     @classmethod
+     def from_file(cls, file: str | Path | IO, type: Literal["txt"] = "txt", params: dict = None) -> "Text":
+         """
+         Load a Text instance from a file.
+         """
+         if isinstance(file, (str, Path)):
+             text = Path(file).read_text()
+         else:
+             # file-like objects can't be reliably isinstance-checked against typing.IO
+             text = file.read()
+         return cls(text=text, **(params or {}))
+
+
+ class TextFile(Context):
+     """
+     Text document context class.
+     """
+
+     _text: str = HydratedAttr()
+     file_ref: str
+     file_type: Literal["txt"] = "txt"
+
+     def hydrate(self, file_store: FileStore) -> None:
+         # FileStore.get returns a result object; read from its file handle
+         file = file_store.get(self.file_ref).file
+         if isinstance(file, Path):
+             self._text = file.read_text()
+         else:
+             self._text = file.read()
+
+     @classmethod
+     def from_file_store(
+         cls,
+         file_ref: str,
+         file_store: FileStore,
+         params: dict = None,
+     ) -> "TextFile":
+         """
+         Load a TextFile instance from a file store.
+
+         Args:
+             file_ref: Reference of the file to load.
+             file_store: The file store to use.
+             params: The parameters to pass to the constructor.
+         """
+         instance = cls(file_ref=file_ref, **(params or {}))
+         instance.hydrate(file_store)
+         return instance
+
+     @requires_hydration
+     def dump(self) -> str:
+         return self._text
+
+     @requires_hydration
+     def chunk_text(
+         self,
+         chunk_fn: Callable[[str], List[str | tuple[str, dict]]],
+         kwargs: dict = None,
+     ):
+         def chunk_text_fn(text_context: "TextFile") -> List["Text"]:
+             chunks = []
+             for chunk in chunk_fn(text_context._text, **(kwargs or {})):
+                 chunks.append(Text(text=chunk))
+             return chunks
+
+         self.apply_to_chunks(chunk_text_fn)
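Finally, a hedged sketch of Text.chunk_text with a simple fixed-size chunker; the exact effect of chunk_with depends on Context in modaic/context/base.py, which this diff does not show:

```python
from modaic.context.text import Text

def fixed_size_chunks(text: str, size: int = 40):
    # Yield consecutive fixed-size slices of the text
    for start in range(0, len(text), size):
        yield text[start : start + size]

doc = Text(text="Modaic contexts wrap raw content so agents can index and retrieve it.")
doc.chunk_text(fixed_size_chunks, kwargs={"size": 40})
```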