modaic 0.10.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modaic/precompiled.py ADDED
@@ -0,0 +1,608 @@
1
+ import importlib
2
+ import inspect
3
+ import json
4
+ import os
5
+ import pathlib
6
+ import sys
7
+ import warnings
8
+ from abc import ABC, abstractmethod
9
+ from pathlib import Path
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Optional,
14
+ Type,
15
+ TypeVar,
16
+ )
17
+
18
+ import dspy
19
+ from pydantic import BaseModel
20
+
21
+ from modaic.observability import Trackable, track_modaic_obj
22
+
23
+ from .exceptions import MissingSecretError
24
+ from .hub import Commit, load_repo, sync_and_push
25
+
26
+ C = TypeVar("C", bound="PrecompiledConfig")
27
+ A = TypeVar("A", bound="PrecompiledProgram")
28
+ R = TypeVar("R", bound="Retriever")
29
+
30
+ _DSPY_SIGNATURE_PREFIX = "__dspy_signature__:"
31
+
32
+
33
class PrecompiledConfig(BaseModel):
    """
    Base config for precompiled programs and retrievers.

    A pydantic model with JSON (de)serialization helpers. Fields holding DSPy
    ``Signature`` classes (also nested inside dicts, lists and tuples) are stored as
    reference strings of the form ``"__dspy_signature__:<module>.<ClassName>"`` and
    re-imported when the config is loaded.
    """

    model: Optional[str] = None

    @staticmethod
    def _is_dspy_signature(obj: Any) -> bool:
        """Return True if ``obj`` is a DSPy Signature class (not an instance)."""
        try:
            return isinstance(obj, type) and issubclass(obj, dspy.Signature)
        except (TypeError, AttributeError):
            return False

    @classmethod
    def _get_signature_module_path(cls, sig_class: Type) -> str:
        """
        Return the dotted ``module.ClassName`` reference for a DSPy signature class.

        Classes defined in ``__main__`` are mapped to the running script's module path
        relative to the project root (from ``resolve_project_root``). If that fails,
        only the bare class name is returned; such a reference cannot be re-imported.
        """
        from .module_utils import resolve_project_root

        module_name = sig_class.__module__

        # If it's defined in __main__, try to resolve the actual module path
        if module_name == "__main__":
            main_file = getattr(sys.modules[module_name], "__file__", None)
            if main_file:
                try:
                    project_root = resolve_project_root()
                    rel_path = Path(main_file).relative_to(project_root).with_suffix("")
                    # Join the path components instead of replacing "/" so Windows
                    # separators ("\\") are converted to dots as well.
                    module_path = ".".join(rel_path.parts)
                    return f"{module_path}.{sig_class.__name__}"
                except (ValueError, FileNotFoundError):
                    # Fallback to just using the class name
                    return sig_class.__name__

        return f"{module_name}.{sig_class.__name__}"

    @classmethod
    def _serialize_dspy_signatures(cls, obj: Any) -> Any:
        """Recursively replace DSPy Signature classes with prefixed reference strings."""
        if cls._is_dspy_signature(obj):
            module_path = cls._get_signature_module_path(obj)
            return f"{_DSPY_SIGNATURE_PREFIX}{module_path}"
        elif isinstance(obj, dict):
            return {key: cls._serialize_dspy_signatures(value) for key, value in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return type(obj)(cls._serialize_dspy_signatures(item) for item in obj)
        else:
            return obj

    @classmethod
    def _deserialize_dspy_signatures(cls, obj: Any) -> Any:
        """
        Recursively replace prefixed reference strings with the imported signature classes.

        References that cannot be imported (missing module or attribute, or a bare
        class name without a module part) are kept as strings and a warning is emitted.
        """
        if isinstance(obj, str) and obj.startswith(_DSPY_SIGNATURE_PREFIX):
            module_path = obj[len(_DSPY_SIGNATURE_PREFIX) :]
            module_name, _, class_name = module_path.rpartition(".")
            try:
                # import_module("") raises ValueError (bare class-name fallback) and a
                # leading-dot name raises TypeError; both mean "not importable".
                module = importlib.import_module(module_name)
                return getattr(module, class_name)
            except (ImportError, AttributeError, ValueError, TypeError) as e:
                warnings.warn(
                    f"Failed to import DSPy signature '{module_path}': {e}. Returning the serialized string instead.",
                    stacklevel=2,
                )
                return obj
        elif isinstance(obj, dict):
            return {key: cls._deserialize_dspy_signatures(value) for key, value in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return type(obj)(cls._deserialize_dspy_signatures(item) for item in obj)
        else:
            return obj

    def save_precompiled(
        self,
        path: str | Path,
    ) -> None:
        """
        Saves the config to a config.json file in the given local folder.

        Only config.json is written; auto_classes.json is written by programs and
        retrievers through ``_save_precompiled``.

        Args:
            path: The local folder to save the config to. Created if missing.
        """
        # NOTE: since we don't allow PrecompiledConfig.push_to_hub(), a standalone config
        # never needs an auto_classes.json, so no extra auto classes are passed here.
        self._save_precompiled(path)

    def _save_precompiled(self, path: str | Path, extra_auto_classes: Optional[Dict[str, object]] = None) -> None:
        """
        Saves the config to a config.json file in the given local folder.
        When ``extra_auto_classes`` is given, also saves auto_classes.json mapping
        "AutoConfig" and every extra auto class name to its module path.

        Args:
            path: The local folder to save the config to. Created if missing.
            extra_auto_classes: An argument used internally to add extra auto classes to program repo.
                When None, auto_classes.json is not written.
        """
        from .module_utils import _module_path

        path = pathlib.Path(path)
        path.mkdir(parents=True, exist_ok=True)

        with open(path / "config.json", "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

        if extra_auto_classes is None:
            return

        auto_classes = {"AutoConfig": self}
        auto_classes.update(extra_auto_classes)

        auto_classes_paths = {k: _module_path(obj) for k, obj in auto_classes.items()}

        with open(path / "auto_classes.json", "w", encoding="utf-8") as f:
            json.dump(auto_classes_paths, f, indent=2)

    @classmethod
    def from_precompiled(
        cls: Type[C],
        path: str | Path,
        access_token: Optional[str] = None,
        rev: str = "main",
        **kwargs,
    ) -> C:
        """
        Loads the config from a config.json file in the given path. The path can be a local directory or a repo on Modaic Hub.

        Args:
            path: The path to load the config from. Can be a local directory or a repo on Modaic Hub.
            access_token: Your Modaic access token, forwarded to ``load_repo``.
            rev: The revision to load when ``path`` is a hub repo. Defaults to "main".
            **kwargs: Additional keyword arguments used to override the default config.

        Returns:
            An instance of the PrecompiledConfig class.
        """
        local = is_local_path(path)
        local_dir, _ = load_repo(path, access_token=access_token, is_local=local, rev=rev)
        config_path = local_dir / "config.json"
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def from_dict(cls: Type[C], dict: Dict, **kwargs) -> C:
        """
        Loads the config from a dictionary.

        Args:
            dict: A dictionary containing the config. Signature reference strings are re-imported.
            **kwargs: Additional keyword arguments used to override the values in ``dict``.

        Returns:
            An instance of the PrecompiledConfig class.
        """
        deserialized_dict = cls._deserialize_dspy_signatures(dict)
        deserialized_kwargs = cls._deserialize_dspy_signatures(kwargs)
        return cls(**{**deserialized_dict, **deserialized_kwargs})

    @classmethod
    def from_json(cls: Type[C], path: str, **kwargs) -> C:
        """
        Loads the config from a json file.

        Args:
            path: The path to load the config from.
            **kwargs: Additional keyword arguments used to override the default config.

        Returns:
            An instance of the PrecompiledConfig class.
        """
        with open(path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self) -> Dict:
        """
        Convert the config to a dict via ``model_dump()``, replacing DSPy Signature
        classes (also nested ones) with importable reference strings.
        """
        result = self.model_dump()
        return self._serialize_dspy_signatures(result)

    def to_json(self) -> str:
        """
        Converts the config to a json string using pydantic's ``model_dump_json()``.

        NOTE(review): unlike ``to_dict`` (used for config.json), this does not apply the
        DSPy-signature reference encoding, so its output may differ from config.json.
        """
        return self.model_dump_json()
220
+
221
+
222
class PrecompiledProgram(dspy.Module):
    """
    Bases: `dspy.Module`

    Base class for DSPy programs that are saved and loaded together with a typed config.
    Subclasses declare their config class with a ``config: MyConfig`` class annotation;
    ``__init__`` and ``from_precompiled`` use it to build and validate the config.

    PrecompiledProgram supports observability tracking through DSPy callbacks.

    Attributes:
        config: The config for the program.
        retriever: The retriever for the program.
        _source: The local directory the program was loaded from.
        _source_commit: The commit the program was loaded from.
        _from_auto: Whether the program was loaded from AutoProgram.
    """

    config: PrecompiledConfig
    retriever: Optional["Retriever"]
    _source: Optional[Path] = None
    _source_commit: Optional[Commit] = None
    _from_auto: bool = False

    def __init__(
        self,
        config: Optional[PrecompiledConfig | dict] = None,
        *,
        retriever: Optional["Retriever"] = None,
        **kwargs,
    ):
        """
        Args:
            config: An instance of exactly this program's config class, a dict of field
                values used to build one, or None to use the config class defaults.
            retriever: Optional retriever used by the program.
            **kwargs: Accepted so subclasses can forward extra arguments; ignored here.

        Raises:
            ValueError: If ``config`` is any other type, including a subclass of the
                declared config class.
        """
        config_class = self.__annotations__.get("config", PrecompiledConfig)
        if config is None:
            config = config_class()
        elif isinstance(config, dict):
            config = config_class(**config)
        elif type(config) is not config_class:
            raise ValueError(
                f"config must be an instance of {self.__class__.__name__}'s config class ({config_class}). Subclasses are not allowed."
            )
        self.config = config  # type: ignore
        super().__init__()
        self.retriever = retriever
        # TODO: throw a warning if the config of the retriever has different values than the config of the program

    def forward(self, **kwargs) -> str:
        """
        Forward pass for the program. Must be implemented by subclasses.

        Args:
            **kwargs: Additional keyword arguments.

        Returns:
            Forward pass result.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError(
            "Forward pass for PrecompiledProgram is not implemented. You must implement a forward method in your subclass."
        )

    def save_precompiled(self, path: str | Path, _with_auto_classes: bool = False) -> None:
        """
        Saves the program.json and the config.json to the given local folder.
        Secret ``lm`` values (e.g. API keys) in program.json are masked after saving.

        Args:
            path: The local folder to save the program and config to. Must be a local path.
            _with_auto_classes: Internally used argument used to configure whether to save the auto classes mapping.
        """
        path = pathlib.Path(path)
        extra_auto_classes = None
        if _with_auto_classes:
            extra_auto_classes = {"AutoProgram": self}
            if self.retriever is not None:
                extra_auto_classes["AutoRetriever"] = self.retriever
        self.config._save_precompiled(path, extra_auto_classes)
        program_path = path / "program.json"
        self.save(program_path)
        _clean_secrets(program_path)

    @classmethod
    def from_precompiled(
        cls: Type[A],
        path: str | Path,
        config: Optional[PrecompiledConfig | dict] = None,
        access_token: Optional[str] = None,
        api_key: Optional[str | dict[str, str]] = None,
        hf_token: Optional[str | dict[str, str]] = None,
        rev: str = "main",
        **kwargs,
    ) -> A:
        """
        Loads the program and the config from the given path.

        Args:
            path: The path to load the program and config from. Can be a local path or a path on Modaic Hub.
            config: Overrides for the saved config: a dict of field values, or a config
                instance whose field values are all used as overrides.
            access_token: Your Modaic access token, forwarded to ``load_repo``.
            api_key: Your API key for your LM (NOT YOUR MODAIC ACCESS TOKEN). A string is used
                for every predictor; a dict maps named predictors to keys.
            hf_token: Your Hugging Face token (string or per-predictor dict, like ``api_key``).
            rev: The revision to load when ``path`` is a hub repo. Defaults to "main".
            **kwargs: Additional keyword arguments forwarded to the PrecompiledProgram's constructor.

        Returns:
            An instance of the PrecompiledProgram class.

        Raises:
            ValueError: If called on PrecompiledProgram itself instead of a subclass.
        """
        if cls is PrecompiledProgram:
            raise ValueError("from_precompiled() can only be used on a subclass of PrecompiledProgram.")

        ConfigClass: Type[PrecompiledConfig] = cls.__annotations__.get("config", PrecompiledConfig)  # noqa: N806
        local = is_local_path(path)
        local_dir, source_commit = load_repo(path, access_token=access_token, is_local=local, rev=rev)
        # A config instance cannot be **-unpacked; use its field values as overrides.
        if isinstance(config, PrecompiledConfig):
            overrides = config.model_dump()
        else:
            overrides = config or {}
        config = ConfigClass.from_precompiled(local_dir, **overrides)

        # Only pass the config if the subclass constructor accepts one.
        sig = inspect.signature(cls.__init__)
        if "config" in sig.parameters:
            program = cls(config=config, **kwargs)
        else:
            program = cls(**kwargs)

        # Support new (program.json) and legacy (agent.json) naming
        program_state_path = local_dir / "program.json"
        agent_state_path = local_dir / "agent.json"
        # TODO: deprecate agent.json in next major release
        state_path = program_state_path if program_state_path.exists() else agent_state_path

        if state_path.exists():
            secrets = {"api_key": api_key, "hf_token": hf_token}
            state = _get_state_with_secrets(state_path, secrets)
            program.load_state(state)

        # Record where the program was loaded from and at which commit.
        program._source = local_dir
        program._source_commit = source_commit
        return program

    def push_to_hub(
        self,
        repo_path: str,
        access_token: Optional[str] = None,
        commit_message: str = "(no commit message)",
        with_code: Optional[bool] = True,
        private: bool = False,
        branch: str = "main",
        tag: Optional[str] = None,
    ) -> Commit:
        """
        Pushes the program and the config to the given repo_path.

        Args:
            repo_path: The path on Modaic hub to save the program and config to.
            access_token: Your Modaic access token.
            commit_message: The commit message to use when pushing to the hub.
            with_code: Whether to save the code along with the program.json and config.json.
                Defaults to True. Pass None to decide automatically: True if the program
                was loaded via AutoProgram (``_from_auto``), otherwise False.
            private: Forwarded to ``sync_and_push``.
            branch: The branch to push to. Defaults to "main".
            tag: Optional tag forwarded to ``sync_and_push``.

        Returns:
            The Commit returned by ``sync_and_push``.
        """
        if with_code is None:
            with_code = self._from_auto

        return sync_and_push(
            self,
            repo_path,
            access_token=access_token,
            commit_message=commit_message,
            private=private,
            branch=branch,
            tag=tag,
            with_code=with_code,
        )
389
+
390
+
391
class Retriever(ABC, Trackable):
    """
    Abstract base class for retrievers that are saved and loaded together with a typed config.

    Subclasses declare their config class with a ``config: MyConfig`` class annotation
    and implement ``retrieve``.

    Attributes:
        config: The config for the retriever.
        _source: The local directory the retriever was loaded from.
        _source_commit: The commit the retriever was loaded from.
        _from_auto: Whether the retriever was loaded via AutoRetriever.
    """

    config: PrecompiledConfig
    _source: Optional[Path] = None
    _source_commit: Optional[Commit] = None
    _from_auto: bool = False

    def __init__(self, config: Optional[PrecompiledConfig | dict] = None, **kwargs):
        """
        Args:
            config: An instance of exactly this retriever's config class, a dict of field
                values used to build one, or None to use the config class defaults.
            **kwargs: Forwarded to ``Trackable.__init__``.

        Raises:
            ValueError: If ``config`` is any other type, including a subclass of the
                declared config class.
        """
        ABC.__init__(self)
        Trackable.__init__(self, **kwargs)
        config_class = self.__annotations__.get("config", PrecompiledConfig)
        if config is None:
            config = config_class()
        elif isinstance(config, dict):
            config = config_class(**config)
        elif type(config) is not config_class:
            raise ValueError(
                f"config must be an instance of {self.__class__.__name__}'s config class ({config_class}). Subclasses are not allowed."
            )
        self.config = config  # type: ignore

    @track_modaic_obj
    @abstractmethod
    def retrieve(self, query: str, **kwargs):
        """Retrieve results for ``query``. Implemented by subclasses."""
        pass

    @classmethod
    def from_precompiled(
        cls: Type[R],
        path: str | Path,
        config: Optional[PrecompiledConfig | dict] = None,
        access_token: Optional[str] = None,
        rev: str = "main",
        **kwargs,
    ) -> R:
        """
        Loads the retriever and the config from the given path.

        Args:
            path: A local directory or a repo on Modaic Hub.
            config: Overrides for the saved config: a dict of field values, or a config
                instance whose field values are all used as overrides.
            access_token: Your Modaic access token, forwarded to ``load_repo``.
            rev: The revision to load when ``path`` is a hub repo. Defaults to "main".
            **kwargs: Additional keyword arguments forwarded to the retriever's constructor.

        Returns:
            An instance of the Retriever subclass.

        Raises:
            ValueError: If called on Retriever itself instead of a subclass.
        """
        if cls is Retriever:
            raise ValueError("from_precompiled() can only be used on a subclass of Retriever.")

        # Fall back to the base config class when a subclass declares no `config`
        # annotation (indexing with ["config"] raised KeyError in that case).
        ConfigClass: Type[PrecompiledConfig] = cls.__annotations__.get("config", PrecompiledConfig)  # noqa: N806
        local = is_local_path(path)
        local_dir, source_commit = load_repo(path, access_token=access_token, is_local=local, rev=rev)
        # A config instance cannot be **-unpacked; use its field values as overrides.
        if isinstance(config, PrecompiledConfig):
            overrides = config.model_dump()
        else:
            overrides = config or {}
        config = ConfigClass.from_precompiled(local_dir, **overrides)

        # Only pass the config if the subclass constructor accepts one.
        sig = inspect.signature(cls.__init__)
        if "config" in sig.parameters:
            retriever = cls(config=config, **kwargs)
        else:
            retriever = cls(**kwargs)

        # Record where the retriever was loaded from and at which commit.
        retriever._source = local_dir
        retriever._source_commit = source_commit
        return retriever

    def save_precompiled(self, path: str | Path, _with_auto_classes: bool = False) -> None:
        """
        Saves the retriever configuration to the given path.

        Args:
            path: The path to save the retriever configuration and auto classes mapping.
            _with_auto_classes: Internal argument used to configure whether to save the auto classes mapping.
        """
        path_obj = pathlib.Path(path)
        extra_auto_classes = None
        if _with_auto_classes:
            extra_auto_classes = {"AutoRetriever": self}
        self.config._save_precompiled(path_obj, extra_auto_classes)

    def push_to_hub(
        self,
        repo_path: str,
        access_token: Optional[str] = None,
        commit_message: str = "(no commit message)",
        with_code: Optional[bool] = None,
        private: bool = False,
        branch: str = "main",
        tag: Optional[str] = None,
    ) -> Commit:
        """
        Pushes the retriever and the config to the given repo_path.

        Args:
            repo_path: The path on Modaic hub to save the retriever and config to.
            access_token: Your Modaic access token.
            commit_message: The commit message to use when pushing to the hub.
            with_code: Whether to save the code along with config.json.
                Defaults to None, meaning True if the Retriever was loaded via
                AutoRetriever (``_from_auto``), otherwise False.
            private: Forwarded to ``sync_and_push``.
            branch: The branch to push to. Defaults to "main".
            tag: Optional tag forwarded to ``sync_and_push``.

        Returns:
            The Commit returned by ``sync_and_push``.
        """
        if with_code is None:
            with_code = self._from_auto

        return sync_and_push(
            self,
            repo_path,
            access_token=access_token,
            commit_message=commit_message,
            private=private,
            branch=branch,
            tag=tag,
            with_code=with_code,
        )
495
+
496
+
497
class Indexer(Retriever):
    """
    A Retriever that additionally declares an abstract ``index`` method for adding
    contents to whatever store the subclass retrieves from.
    """

    config: PrecompiledConfig

    @abstractmethod
    def index(self, contents: Any, **kwargs):
        """Add ``contents`` to the index. Implemented by subclasses."""
503
+
504
+
505
+ def is_local_path(s: str | Path) -> bool:
506
+ # absolute or relative filesystem path
507
+ if isinstance(s, Path):
508
+ return True
509
+ s = str(s)
510
+
511
+ if os.path.isabs(s) or s.startswith((".", "/", "\\")):
512
+ return True
513
+ parts = s.split("/")
514
+ # hub IDs: "repo" or "user/repo"
515
+ if len(parts) == 1:
516
+ raise ValueError(
517
+ f"Invalid repo: '{s}'. Please prefix local paths with './', '/', or '../' . And use 'user/repo' format for hub paths."
518
+ )
519
+ elif len(parts) == 2 and all(parts):
520
+ return False
521
+ return True
522
+
523
+
524
+ SECRET_MASK = "********"
525
+ COMMON_SECRETS = ["api_key", "hf_token"]
526
+
527
+
528
+ def _clean_secrets(path: Path, extra_secrets: Optional[list[str]] = None):
529
+ """
530
+ Removes all secret keys from `lm` dict in program.json file
531
+ """
532
+ secret_keys = COMMON_SECRETS + (extra_secrets or [])
533
+
534
+ with open(path, "r") as f:
535
+ d = json.load(f)
536
+
537
+ for predictor in d.values():
538
+ lm = predictor.get("lm", None)
539
+ if lm is None:
540
+ continue
541
+ for k in lm.keys():
542
+ if k in secret_keys:
543
+ lm[k] = SECRET_MASK
544
+
545
+ with open(path, "w") as f:
546
+ json.dump(d, f, indent=2)
547
+
548
+
549
+ def _get_state_with_secrets(path: Path, secrets: dict[str, str | dict[str, str] | None]):
550
+ """`
551
+ Fills secret keys in `lm` dict in program.json file
552
+
553
+ Args:
554
+ path: The path to the program.json file.
555
+ secrets: A dictionary containing the secrets to fill in the `lm` dict.
556
+ - Dict[k,v] where k is the name of a secret (e.g. "api_key") and v is the value of the secret
557
+ - If v is a string, every lm will use v for k
558
+ - if v is a dict, each key of v should be the name of a named predictor
559
+ (e.g. "my_program.predict", "my_program.summarizer") mapping to the secret value for that predictor
560
+ Returns:
561
+ A dictionary containing the state of the program.json file with the secrets filled in.
562
+ """
563
+ with open(path, "r") as f:
564
+ named_predictors = json.load(f)
565
+
566
+ def _get_secret(predictor_name: str, secret_name: str) -> Optional[str]:
567
+ if secret_val := secrets.get(secret_name):
568
+ if isinstance(secret_val, str):
569
+ return secret_val
570
+ elif isinstance(secret_val, dict):
571
+ return secret_val.get(predictor_name)
572
+ return None
573
+
574
+ for predictor_name, predictor in named_predictors.items():
575
+ lm = predictor.get("lm", {}) or {}
576
+ for kw, arg in lm.items():
577
+ if kw in COMMON_SECRETS and arg != "" and arg != SECRET_MASK:
578
+ warnings.warn(
579
+ f"{str(path)} exposes the secret key {kw}. Please remove it or ensure this file is not made public.",
580
+ stacklevel=2,
581
+ )
582
+ secret = _get_secret(predictor_name, kw)
583
+ if secret is not None and arg != "" and arg != SECRET_MASK:
584
+ raise ValueError(
585
+ f"Failed to fill insert secret value for {predictor_name}['lm']['{kw}']. It is already set to {arg}"
586
+ )
587
+ elif secret is None and kw in COMMON_SECRETS:
588
+ raise MissingSecretError(f"Please specify a value for {kw} in the secrets dictionary", kw)
589
+ elif secret is not None:
590
+ lm[kw] = secret
591
+ return named_predictors
592
+
593
+
594
# Deprecated alias kept for backward compatibility.
# NOTE(review): because this name is bound at module level, the module-level
# ``__getattr__`` below is never consulted for "PrecompiledAgent" (PEP 562 hooks
# only run when normal attribute lookup fails), so its DeprecationWarning does not
# fire. Removing this alias would enable the warning but could affect modules that
# import the name eagerly — confirm before changing.
PrecompiledAgent = PrecompiledProgram


def __getattr__(name: str):
    """Module attribute hook (PEP 562) that resolves deprecated names with a warning."""
    if name != "PrecompiledAgent":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    warnings.warn(
        "PrecompiledAgent is deprecated and will be removed in a future version. "
        "Please use PrecompiledProgram instead for better parity with DSPy.",
        DeprecationWarning,
        stacklevel=2,
    )
    return PrecompiledProgram
@@ -0,0 +1 @@
1
+ from .predict import Predict, PredictConfig # noqa: F401
@@ -0,0 +1,51 @@
1
+ import warnings
2
+ from typing import Optional
3
+
4
+ import dspy
5
+
6
+ from ..hub import Commit
7
+ from ..precompiled import PrecompiledConfig, PrecompiledProgram
8
+ from ..serializers import SerializableLM, SerializableSignature
9
+
10
+
11
class PredictConfig(PrecompiledConfig):
    """
    Config for ``Predict``.

    Stores the signature and also the LM explicitly, because ``dspy.configure`` does
    not always set the LM that ends up being serialized.
    """

    signature: SerializableSignature
    lm: SerializableLM
15
+
16
+
17
class Predict(PrecompiledProgram):
    """
    A precompiled wrapper around a single ``dspy.Predict`` whose signature and LM come
    from a ``PredictConfig``, so both are saved with and restored from the config.
    """

    config: PredictConfig

    def __init__(self, config: PredictConfig, **kwargs):
        """
        Args:
            config: The predictor config. A dict of field values is also accepted; the
                base class validates it into a ``PredictConfig``.
            **kwargs: Forwarded to ``PrecompiledProgram.__init__``.
        """
        super().__init__(config, **kwargs)
        # Read from the normalized self.config (not the raw argument) so dict configs work too.
        cfg = self.config
        self.predictor = dspy.Predict(cfg.signature)
        self.predictor.set_lm(lm=cfg.lm)
        self.set_lm(lm=cfg.lm)

    def forward(self, **kwargs) -> dspy.Prediction:
        """Run the wrapped predictor with the given input fields."""
        return self.predictor(**kwargs)

    def push_to_hub(
        self,
        repo_path: str,
        access_token: Optional[str] = None,
        commit_message: str = "(no commit message)",
        with_code: Optional[bool] = None,
        private: bool = False,
        branch: str = "main",
        tag: Optional[str] = None,
    ) -> Commit:
        """
        Push the program and config to ``repo_path`` without code.

        ``with_code`` is not supported: any explicit value triggers a warning, and code
        is never pushed. The other arguments are forwarded to ``PrecompiledProgram.push_to_hub``.
        """
        if with_code is not None:
            warnings.warn(
                "push_to_hub(with_code=...) is not supported for modaic.Predict, it will be ignored", stacklevel=2
            )
        return super().push_to_hub(
            repo_path=repo_path,
            access_token=access_token,
            commit_message=commit_message,
            with_code=False,
            private=private,
            branch=branch,
            tag=tag,
        )
@@ -0,0 +1,35 @@
1
+ from typing import Any
2
+
3
+ from modaic import Indexer, PrecompiledConfig, PrecompiledProgram
4
+
5
+ from .registry import builtin_config, builtin_indexer, builtin_program
6
+
7
program_name = "basic-rag"


@builtin_config(program_name)
class RAGProgramConfig(PrecompiledConfig):
    """Config for the built-in "basic-rag" program; adds no fields beyond PrecompiledConfig."""

    # Do not override __init__ without calling super().__init__(**data): pydantic
    # models must run BaseModel.__init__ to be usable, and PrecompiledConfig.from_dict
    # constructs configs with keyword arguments (e.g. model=...).

    def forward(self, query: str) -> str:
        # NOTE(review): a `forward` method on a config looks like a copy-paste
        # placeholder from RAGProgram; kept for backward compatibility.
        return "hello"
17
+
18
+
19
@builtin_indexer(program_name)
class RAGIndexer(Indexer):
    """Placeholder indexer for the built-in "basic-rag" program."""

    # Declare the config class: Retriever.__init__ resolves it from this annotation
    # and otherwise falls back to PrecompiledConfig, whose exact-type check would
    # reject RAGProgramConfig instances.
    config: RAGProgramConfig

    def __init__(self, config: RAGProgramConfig):
        super().__init__(config)

    def index(self, contents: Any, **kwargs):
        # Placeholder. Accepts **kwargs to match the Indexer.index signature.
        pass

    def retrieve(self, query: str, **kwargs):
        # Placeholder for the method Retriever declares with @abstractmethod,
        # so this class does not remain abstract.
        pass
26
+
27
+
28
@builtin_program(program_name)
class RAGProgram(PrecompiledProgram):
    """Placeholder "basic-rag" program that holds an indexer and returns a fixed answer."""

    # Declare the config class: PrecompiledProgram.__init__ and from_precompiled
    # resolve it from this annotation; without it they fall back to PrecompiledConfig
    # and the exact-type check rejects RAGProgramConfig instances.
    config: RAGProgramConfig

    def __init__(self, config: RAGProgramConfig, indexer: RAGIndexer):
        super().__init__(config)
        self.indexer = indexer

    def forward(self, query: str) -> str:
        """Return a fixed placeholder response; ``query`` is currently unused."""
        return "hello"