modaic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modaic might be problematic. Click here for more details.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/hub.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
|
+
|
|
5
|
+
import git
|
|
6
|
+
import requests
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
|
|
9
|
+
from .exceptions import AuthenticationError, RepositoryExistsError, RepositoryNotFoundError
|
|
10
|
+
from .utils import compute_cache_dir
|
|
11
|
+
|
|
12
|
+
load_dotenv()

# Hub access token; None when MODAIC_TOKEN is set neither in the environment nor in .env.
MODAIC_TOKEN = os.getenv("MODAIC_TOKEN")
# Host (plus optional path) of the git server with no scheme and no trailing slash:
# every caller prepends "https://" itself, so any scheme the user wrote ("https://",
# "http://", ...) is dropped here to avoid URLs like "https://http://host/...".
MODAIC_GIT_URL = os.getenv("MODAIC_GIT_URL", "git.modaic.dev").split("://", 1)[-1].rstrip("/")
MODAIC_CACHE = compute_cache_dir()
# Root directory for cached checkouts of hub repositories.
AGENTS_CACHE = Path(MODAIC_CACHE) / "agents"

# Use GitHub's REST API conventions (endpoints, headers, payload) when the host is github.com.
USE_GITHUB = "github.com" in MODAIC_GIT_URL

# Holds the user-info dict most recently cached by get_user_info().
user_info = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_remote_repo(repo_path: str, access_token: str, exist_ok: bool = False) -> None:
    """
    Creates a remote repository in modaic hub on the given repo_path. e.g. "user/repo"

    Only the last path component is sent to the API, so the repository is created under
    the account that owns ``access_token`` (organisation paths are not implemented yet).

    Args:
        repo_path: The path on Modaic hub to create the remote repository.
        access_token: User's access token for authentication.
        exist_ok: If True, return silently when the repository already exists instead
            of raising RepositoryExistsError.

    Raises:
        RepositoryExistsError: If the repository already exists and ``exist_ok`` is False.
        AuthenticationError: If authentication fails (HTTP 401) or access is denied (HTTP 403).
        ValueError: If ``repo_path`` is empty, its repository name is empty, or the name is
            longer than 100 characters.
        Exception: For any other API failure, or if the HTTP request itself fails.
    """
    if not repo_path or not repo_path.strip():
        raise ValueError("Repository ID cannot be empty")

    repo_name = repo_path.strip().split("/")[-1]
    if not repo_name:
        # e.g. "user/" - would otherwise send an empty name to the API.
        raise ValueError("Repository name cannot be empty")
    if len(repo_name) > 100:
        raise ValueError("Repository name too long (max 100 characters)")

    api_url = get_repos_endpoint()
    headers = get_headers(access_token)
    payload = get_repo_payload(repo_name)
    # TODO: Implement orgs path. Also switch to using gitea's push-to-create

    try:
        response = requests.post(api_url, json=payload, headers=headers, timeout=30)
    except requests.exceptions.RequestException as e:
        raise Exception(f"Request failed: {e}") from e

    if response.status_code == 201:
        return

    # Best-effort extraction of the server's error message: the body may be empty,
    # not JSON, or JSON that is not an object, and "message" may be null.
    try:
        error_data = response.json()
    except ValueError:
        error_data = None
    message = error_data.get("message") if isinstance(error_data, dict) else None
    error_message = str(message) if message else f"HTTP {response.status_code}"

    if response.status_code in (409, 422) or "already exists" in error_message.lower():
        if exist_ok:
            return
        raise RepositoryExistsError(f"Repository '{repo_name}' already exists")
    if response.status_code == 401:
        raise AuthenticationError("Invalid access token or authentication failed")
    if response.status_code == 403:
        raise AuthenticationError("Access denied - insufficient permissions")
    raise Exception(f"Failed to create repository: {error_message}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def push_folder_to_hub(
    folder: str,
    repo_path: str,
    access_token: Optional[str] = None,
    commit_message: str = "(no commit message)",
) -> None:
    """
    Pushes a local directory as a commit to a remote git repository.

    Steps:
        1. If local folder is not a git repository, initialize it.
        2. Checkout to a temporary 'snapshot' branch.
        3. Add and commit all files in the local folder.
        4. Point 'origin' at the hub repository (adding it or updating its URL) and fetch it.
        5. Switch to the 'main' branch at origin/main.
        6. Use `git restore --source=snapshot --staged --worktree .` to sync the working
           tree of 'main' to 'snapshot' and stage the changes on 'main'.
        7. Commit changes to 'main' with the custom commit message.
        8. Fast-forward push to origin/main.
        9. Switch back to 'main' and delete the 'snapshot' branch (always attempted,
           best-effort).

    Args:
        folder: The local folder to push to the remote repository.
        repo_path: Hub path of the remote repository, e.g. "user/repo" or "org/repo".
        access_token: The access token to use for authentication. Defaults to MODAIC_TOKEN.
        commit_message: The message to use for the commit.

    Raises:
        AuthenticationError: If no access token is given and MODAIC_TOKEN is not set.
        NotImplementedError: If ``repo_path`` has no namespace part (no '/').
        ValueError: If ``repo_path`` contains more than one '/'.
        RepositoryNotFoundError: If fetching from the remote repository fails.

    Warning:
        This is not the standard pull/push workflow. No merging/rebasing is done.
        This simply pushes new changes to make main mirror the local directory.

    Warning:
        The remote repository is created first if it does not exist, and any existing
        'origin' remote of ``folder`` is re-pointed at the hub repository.
    """
    if not access_token:
        if not MODAIC_TOKEN:
            raise AuthenticationError("MODAIC_TOKEN is not set")
        access_token = MODAIC_TOKEN

    if "/" not in repo_path:
        raise NotImplementedError(
            "Modaic fast paths not yet implemented. Please load agents with 'user/repo' or 'org/repo' format"
        )
    if repo_path.count("/") > 1:
        raise ValueError(f"Extra '/' in repo_path: {repo_path}")

    create_remote_repo(repo_path, access_token, exist_ok=True)
    username = get_user_info(access_token)["login"]

    # 1) Initialise (no-op for an existing repo) OUTSIDE the try so that the cleanup in
    #    `finally` always has a repository object to work with.
    local_repo = git.Repo.init(folder)
    try:
        # 2) Checkout to a temporary 'snapshot' branch (create or reset if exists).
        local_repo.git.switch("-C", "snapshot")
        # 3) Add and commit all files in the local folder.
        if local_repo.is_dirty(untracked_files=True):
            local_repo.git.add("-A")
            local_repo.git.commit("-m", "Local snapshot before transplant")

        # 4) Point origin at the hub repo and fetch it.
        # NOTE(security): the token is embedded in the URL and therefore stored in plain
        # text in <folder>/.git/config.
        remote_url = f"https://{username}:{access_token}@{MODAIC_GIT_URL}/{repo_path}.git"
        if "origin" in [remote.name for remote in local_repo.remotes]:
            # An existing origin may hold a stale token or point at another repository;
            # re-point it so we never fetch from / push to the wrong place.
            local_repo.remotes.origin.set_url(remote_url)
        else:
            local_repo.create_remote("origin", remote_url)
        try:
            local_repo.git.fetch("origin")
        except git.exc.GitCommandError:
            raise RepositoryNotFoundError(f"Repository '{repo_path}' does not exist") from None

        # 5) Switch to the 'main' branch at origin/main.
        local_repo.git.switch("-C", "main", "origin/main")

        # 6) Make main's index + working tree EXACTLY match snapshot (incl. deletions).
        local_repo.git.restore("--source=snapshot", "--staged", "--worktree", ".")

        # 7) One commit that transforms the remote contents into the local snapshot.
        if local_repo.is_dirty(untracked_files=True):
            local_repo.git.commit("-m", commit_message)

        # 8) Fast-forward push: preserves prior remote history + the single new commit.
        local_repo.git.push("-u", "origin", "main")
    finally:
        # 9) Clean up.
        _cleanup_snapshot_branch(local_repo)


def _cleanup_snapshot_branch(local_repo: "git.Repo") -> None:
    """Best-effort: leave ``local_repo`` on 'main' and delete the temporary 'snapshot' branch."""
    try:
        try:
            local_repo.git.switch("main")
        except git.exc.GitCommandError:
            # e.g. no local 'main' exists yet: create it at the current HEAD.
            local_repo.git.switch("-c", "main")
        local_repo.git.branch("-D", "snapshot")
    except git.exc.GitCommandError as e:
        import warnings

        # Cleanup runs in a `finally`; raising here would mask the push's own result/error.
        warnings.warn(f"Could not clean up temporary 'snapshot' branch: {e}", stacklevel=2)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def get_headers(access_token: str) -> Dict[str, str]:
    """
    Build the HTTP headers for authenticated REST calls to the configured git host.

    GitHub receives a Bearer token with its versioned JSON media type; any other host
    receives a ``token`` Authorization header with JSON content negotiation.

    Args:
        access_token: Personal access token placed in the Authorization header.

    Returns:
        A fresh dict mapping header names to values.
    """
    if not USE_GITHUB:
        return {
            "Authorization": f"token {access_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
            "User-Agent": "ModaicClient/1.0",
        }

    return {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {access_token}",
        "X-GitHub-Api-Version": "2022-11-28",
    }
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def get_repos_endpoint() -> str:
    """Return the REST endpoint that creates a repository for the authenticated user."""
    github_endpoint = "https://api.github.com/user/repos"
    hub_endpoint = f"https://{MODAIC_GIT_URL}/api/v1/user/repos"
    return github_endpoint if USE_GITHUB else hub_endpoint
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def get_repo_payload(repo_name: str) -> Dict[str, Any]:
    """
    Build the JSON body for the "create repository" API call.

    The repository is public, auto-initialised, and uses "main" as its default branch.

    Args:
        repo_name: Name of the repository to create (no namespace).

    Returns:
        The request payload as a dict.
    """
    body: Dict[str, Any] = dict(
        name=repo_name,
        description="",
        private=False,
        auto_init=True,
        default_branch="main",
    )
    if USE_GITHUB:
        return body
    # Non-GitHub hosts additionally receive a "trust_model" field.
    body["trust_model"] = "default"
    return body
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# Per-token cache of fetched user info, so a different token never receives another
# account's identity.
_USER_INFO_BY_TOKEN: Dict[str, Dict[str, Any]] = {}


def get_user_info(access_token: str) -> Dict[str, Any]:
    """
    Returns the user info for the given access token.

    Results are cached per access token for the lifetime of the process. The value
    returned most recently is also stored in the module-level ``user_info`` variable.

    Args:
        access_token: The access token to get the user info for.

    Returns:
        {
            "login": str,
            "email": str,
            "avatar_url": str,
            "name": str,
        }
        Fields other than "login" may be None if the host omits them.

    Raises:
        AuthenticationError: If the host rejects the token (HTTP 401/403).
        requests.HTTPError: For any other unsuccessful HTTP status.
        requests.RequestException: On network failure or timeout.
    """
    global user_info
    cached = _USER_INFO_BY_TOKEN.get(access_token)
    if cached is not None:
        user_info = cached
        return cached

    if USE_GITHUB:
        url = "https://api.github.com/user"
        name_key = "name"
    else:
        url = f"https://{MODAIC_GIT_URL}/api/v1/user"
        name_key = "full_name"

    # The timeout keeps an unreachable host from hanging the caller indefinitely.
    response = requests.get(url, headers=get_headers(access_token), timeout=30)
    if response.status_code in (401, 403):
        raise AuthenticationError(f"Could not fetch user info: HTTP {response.status_code}")
    response.raise_for_status()
    data = response.json()

    info = {
        "login": data["login"],
        "email": data.get("email"),
        "avatar_url": data.get("avatar_url"),
        "name": data.get(name_key),
    }
    _USER_INFO_BY_TOKEN[access_token] = info
    user_info = info
    return info
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def git_snapshot(
    repo_path: str,
    *,
    rev: str = "main",
    access_token: Optional[str] = None,
) -> Path:
    """
    Ensure a local cached checkout of a hub repository and return its path.

    On first use the repository is cloned into ``AGENTS_CACHE/repo_path``. On later calls
    the cached clone is fetched and forcibly reset to ``rev``; local changes are discarded.

    Args:
        repo_path: Hub path ("user/repo").
        rev: Branch, tag, or commit SHA to checkout; defaults to "main". A branch is
            checked out as a local branch at ``origin/<rev>``; a tag or SHA is checked
            out as a detached HEAD.
        access_token: Token used to authenticate; defaults to MODAIC_TOKEN.

    Returns:
        Path to the local cached repository under AGENTS_CACHE/repo_path.

    Raises:
        ValueError: If no access token is given and MODAIC_TOKEN is not set.
        git.exc.GitCommandError: If cloning, fetching, or checking out ``rev`` fails.
    """
    if access_token is None and MODAIC_TOKEN is not None:
        access_token = MODAIC_TOKEN
    elif access_token is None:
        raise ValueError("Access token is required")

    repo_dir = Path(AGENTS_CACHE) / repo_path
    username = get_user_info(access_token)["login"]
    # NOTE(security): the token is embedded in the URL and therefore stored in plain text
    # in the cached clone's .git/config.
    remote_url = f"https://{username}:{access_token}@{MODAIC_GIT_URL}/{repo_path}.git"
    repo_dir.parent.mkdir(parents=True, exist_ok=True)

    # Only a directory this call creates may be deleted on failure; an existing cache
    # is left untouched.
    is_fresh_clone = not repo_dir.exists()
    try:
        if is_fresh_clone:
            repo = git.Repo.clone_from(remote_url, repo_dir)
        else:
            repo = git.Repo(repo_dir)
            if "origin" in [remote.name for remote in repo.remotes]:
                repo.remotes.origin.set_url(remote_url)
            else:
                repo.create_remote("origin", remote_url)
            repo.remotes.origin.fetch()
        _checkout_rev(repo, rev)
    except Exception:
        if is_fresh_clone:
            import shutil

            # Remove the partial clone (if any) so the next call starts from scratch.
            shutil.rmtree(repo_dir, ignore_errors=True)
        raise
    return repo_dir


def _checkout_rev(repo: "git.Repo", rev: str) -> None:
    """Make ``repo``'s HEAD and working tree match ``rev`` exactly, discarding local edits."""
    remote_branches = {ref.remote_head for ref in repo.remotes.origin.refs}
    if rev in remote_branches:
        # Branch: (re)create the local branch at origin/<rev> and hard-reset to it.
        repo.git.switch("-C", rev, f"origin/{rev}")
        repo.git.reset("--hard", f"origin/{rev}")
    else:
        # Tag or commit SHA: there is no branch to track, so detach HEAD at it.
        repo.git.switch("--detach", rev)
        repo.git.reset("--hard", rev)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def load_repo(repo_path: str, is_local: bool = False) -> Path:
    """
    Resolve ``repo_path`` to a path on disk.

    Args:
        repo_path: A filesystem path when ``is_local`` is True, otherwise a hub path
            such as "user/repo".
        is_local: Treat ``repo_path`` as an existing local path instead of a hub repository.

    Returns:
        The local path as given, or the cached checkout produced by ``git_snapshot``.

    Raises:
        FileNotFoundError: If ``is_local`` is True and the path does not exist.
    """
    if not is_local:
        return git_snapshot(repo_path)

    local_path = Path(repo_path)
    if local_path.exists():
        return local_path
    raise FileNotFoundError(f"Local repo path {repo_path} does not exist")
|
modaic/indexing.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import dspy
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .context.base import Context
|
|
9
|
+
from .observability import Trackable, track_modaic_obj
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Reranker(ABC, Trackable):
    """
    Abstract base for rerankers that reorder candidate contexts by relevance to a query.

    Subclasses implement :meth:`_rerank`, which operates on plain strings;
    :meth:`__call__` turns Context options into strings and maps the results back.
    """

    def __init__(self, *args, **kwargs):
        ABC.__init__(self)
        # Positional args are accepted but not forwarded; only keyword args reach Trackable.
        Trackable.__init__(self, **kwargs)

    @track_modaic_obj
    def __call__(
        self,
        query: str,
        options: List[Context | Tuple[str, Context]],
        k: int = 10,
        **kwargs,
    ) -> List[Tuple[float, Context]]:
        """
        Reranks the options based on the query.

        Args:
            query: The query to rerank the options for.
            options: The options to rerank. Each option is either a Context (ranked by the
                text from its ``embedme()``) or a ``(embedme_string, Context)`` pair, whose
                string is ranked in place of the Context's own text.
            k: The number of options to return (passed through to ``_rerank``).
            **kwargs: Additional keyword arguments to pass to the reranker.

        Returns:
            A list of ``(score, Context)`` tuples, in the order returned by ``_rerank``.

        Raises:
            ValueError: If an option is neither a Context nor a ``(str, Context)`` pair.
        """
        embedmes: List[str] = []
        payloads: List[Context] = []
        for option in options:
            if isinstance(option, Context):
                embedmes.append(option.embedme())
                payloads.append(option)
            elif isinstance(option, tuple):
                if len(option) != 2 or not isinstance(option[0], str) or not isinstance(option[1], Context):
                    raise ValueError(
                        "Tuple options provided to rerank must be (str, Context) pairs, got "
                        f"({', '.join(type(item).__name__ for item in option)})"
                    )
                embedmes.append(option[0])
                payloads.append(option[1])
            else:
                raise ValueError(f"Invalid option type: {type(option)}. Must be Context or Tuple[str, Context]")

        # _rerank yields (index into embedmes, score); map indices back to their Contexts.
        results = self._rerank(query, embedmes, k, **kwargs)

        return [(score, payloads[idx]) for idx, score in results]

    @abstractmethod
    def _rerank(self, query: str, options: List[str], k: int = 10, **kwargs) -> List[Tuple[int, float]]:
        """
        Reranks the options based on the query.

        Args:
            query: The query to rerank the options for.
            options: The options to rerank. Each option is a string.
            k: The number of options to return.
            **kwargs: Additional keyword arguments to pass to the reranker.

        Returns:
            A list of tuples, where each tuple is (index, score) and index refers to
            a position in ``options``.
        """
        pass
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PineconeReranker(Reranker):
    """
    Reranker backed by Pinecone's ``inference.rerank`` API.

    Args:
        model: Name of the Pinecone rerank model to use.
        api_key: Pinecone API key. When None, the value of the PINECONE_API_KEY
            environment variable is passed instead.
        *args, **kwargs: Forwarded to :class:`Reranker`.

    Raises:
        ImportError: If the ``pinecone`` package is not installed.
    """

    def __init__(self, model: str, api_key: Optional[str] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        try:
            from pinecone import Pinecone
        except ImportError as e:
            raise ImportError("Pinecone is not installed. Please install it with `uv add pinecone`") from e

        if api_key is None:
            api_key = os.getenv("PINECONE_API_KEY")
        self.pinecone = Pinecone(api_key)

    def _rerank(
        self,
        query: str,
        options: List[str],
        k: int = 10,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> List[Tuple[int, float]]:
        """
        Rerank ``options`` against ``query`` with the configured Pinecone model.

        Args:
            query: The query to rerank the options for.
            options: Candidate documents as plain strings.
            k: Maximum number of results to return.
            parameters: Optional model-specific parameters forwarded to Pinecone.

        Returns:
            ``(index, score)`` pairs as returned by Pinecone, where index refers to a
            position in ``options``.
        """
        # Nothing to rank: skip the network round-trip entirely.
        if not options:
            return []

        results = self.pinecone.inference.rerank(
            model=self.model,
            query=query,
            documents=options,
            # Never ask for more results than there are documents.
            top_n=min(k, len(options)),
            return_documents=False,
            parameters=parameters,
        )
        return [(result.index, result.score) for result in results.data]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Embedder(dspy.Embedder):
    """
    Thin wrapper around :class:`dspy.Embedder` that also records ``embedding_dim``.

    When ``embedding_dim`` is not supplied, it is discovered at construction time by
    embedding the probe string "hello" once (this calls the underlying embedder).
    """

    def __init__(self, *args, embedding_dim: Optional[int] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned before probing so the attribute already exists while __call__ runs.
        self.embedding_dim = embedding_dim
        if embedding_dim is not None:
            return

        probe_vector = self("hello")
        # Presumably a 1-D vector for a single string input; its first axis is taken
        # as the embedding dimension.
        self.embedding_dim = probe_vector.shape[0]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class DummyEmbedder(Embedder):
    """
    Embedder stand-in that returns uniform random vectors in [0, 1) via ``np.random.rand``.

    ``__init__`` does not call ``super().__init__``, so neither dspy.Embedder's setup nor
    Embedder's dimension probe runs.
    """

    def __init__(self, embedding_dim: int = 512):
        self.embedding_dim = embedding_dim

    def __call__(self, text: str | List[str]) -> np.ndarray:
        """Return a ``(dim,)`` vector for a string, or a ``(len(text), dim)`` matrix otherwise."""
        if not isinstance(text, str):
            return np.random.rand(len(text), self.embedding_dim)
        return np.random.rand(self.embedding_dim)
|