letta-nightly 0.4.1.dev20241004104123__py3-none-any.whl → 0.4.1.dev20241005104008__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic; see the registry's advisory page for this release for more details.

Files changed (34):
  1. letta/cli/cli.py +30 -365
  2. letta/cli/cli_config.py +70 -27
  3. letta/client/client.py +103 -11
  4. letta/config.py +80 -80
  5. letta/constants.py +6 -0
  6. letta/credentials.py +10 -1
  7. letta/errors.py +63 -5
  8. letta/llm_api/llm_api_tools.py +110 -52
  9. letta/local_llm/chat_completion_proxy.py +0 -3
  10. letta/main.py +1 -2
  11. letta/metadata.py +12 -0
  12. letta/providers.py +232 -0
  13. letta/schemas/block.py +1 -1
  14. letta/schemas/letta_request.py +17 -0
  15. letta/schemas/letta_response.py +11 -0
  16. letta/schemas/llm_config.py +18 -2
  17. letta/schemas/message.py +40 -13
  18. letta/server/rest_api/app.py +5 -0
  19. letta/server/rest_api/interface.py +115 -24
  20. letta/server/rest_api/routers/v1/agents.py +36 -3
  21. letta/server/rest_api/routers/v1/llms.py +6 -2
  22. letta/server/server.py +60 -87
  23. letta/server/static_files/assets/index-3ab03d5b.css +1 -0
  24. letta/server/static_files/assets/{index-4d08d8a3.js → index-9a9c449b.js} +69 -69
  25. letta/server/static_files/index.html +2 -2
  26. letta/settings.py +144 -114
  27. letta/utils.py +6 -1
  28. {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/METADATA +1 -1
  29. {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/RECORD +32 -32
  30. letta/local_llm/groq/api.py +0 -97
  31. letta/server/static_files/assets/index-156816da.css +0 -1
  32. {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/LICENSE +0 -0
  33. {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/WHEEL +0 -0
  34. {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/entry_points.txt +0 -0
letta/cli/cli.py CHANGED
@@ -1,26 +1,19 @@
1
- import json
2
1
  import logging
3
- import os
4
2
  import sys
5
3
  from enum import Enum
6
4
  from typing import Annotated, Optional
7
5
 
8
6
  import questionary
9
- import requests
10
7
  import typer
11
8
 
12
9
  import letta.utils as utils
13
10
  from letta import create_client
14
11
  from letta.agent import Agent, save_agent
15
- from letta.cli.cli_config import configure
16
12
  from letta.config import LettaConfig
17
13
  from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR
18
- from letta.credentials import LettaCredentials
19
14
  from letta.log import get_logger
20
15
  from letta.metadata import MetadataStore
21
- from letta.schemas.embedding_config import EmbeddingConfig
22
16
  from letta.schemas.enums import OptionState
23
- from letta.schemas.llm_config import LLMConfig
24
17
  from letta.schemas.memory import ChatMemory, Memory
25
18
  from letta.server.server import logger as server_logger
26
19
 
@@ -33,256 +26,6 @@ from letta.utils import open_folder_in_explorer, printd
33
26
  logger = get_logger(__name__)
34
27
 
35
28
 
36
- class QuickstartChoice(Enum):
37
- openai = "openai"
38
- # azure = "azure"
39
- letta_hosted = "letta"
40
- anthropic = "anthropic"
41
-
42
-
43
- def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice:
44
- try:
45
- return QuickstartChoice[choice_str]
46
- except KeyError:
47
- valid_options = [choice.name for choice in QuickstartChoice]
48
- raise ValueError(f"{choice_str} is not a valid QuickstartChoice. Valid options are: {valid_options}")
49
-
50
-
51
- def set_config_with_dict(new_config: dict) -> (LettaConfig, bool):
52
- """_summary_
53
-
54
- Args:
55
- new_config (dict): Dict of new config values
56
-
57
- Returns:
58
- new_config LettaConfig, modified (bool): Returns the new config and a boolean indicating if the config was modified
59
- """
60
- from letta.utils import printd
61
-
62
- old_config = LettaConfig.load()
63
- modified = False
64
- for k, v in vars(old_config).items():
65
- if k in new_config:
66
- if v != new_config[k]:
67
- printd(f"Replacing config {k}: {v} -> {new_config[k]}")
68
- modified = True
69
- # old_config[k] = new_config[k]
70
- setattr(old_config, k, new_config[k]) # Set the new value using dot notation
71
- else:
72
- printd(f"Skipping new config {k}: {v} == {new_config[k]}")
73
-
74
- # update embedding config
75
- if old_config.default_embedding_config:
76
- for k, v in vars(old_config.default_embedding_config).items():
77
- if k in new_config:
78
- if v != new_config[k]:
79
- printd(f"Replacing config {k}: {v} -> {new_config[k]}")
80
- modified = True
81
- # old_config[k] = new_config[k]
82
- setattr(old_config.default_embedding_config, k, new_config[k])
83
- else:
84
- printd(f"Skipping new config {k}: {v} == {new_config[k]}")
85
- else:
86
- modified = True
87
- fields = ["embedding_model", "embedding_dim", "embedding_chunk_size", "embedding_endpoint", "embedding_endpoint_type"]
88
- args = {}
89
- for field in fields:
90
- if field in new_config:
91
- args[field] = new_config[field]
92
- printd(f"Setting new config {field}: {new_config[field]}")
93
- old_config.default_embedding_config = EmbeddingConfig(**args)
94
-
95
- # update llm config
96
- if old_config.default_llm_config:
97
- for k, v in vars(old_config.default_llm_config).items():
98
- if k in new_config:
99
- if v != new_config[k]:
100
- printd(f"Replacing config {k}: {v} -> {new_config[k]}")
101
- modified = True
102
- # old_config[k] = new_config[k]
103
- setattr(old_config.default_llm_config, k, new_config[k])
104
- else:
105
- printd(f"Skipping new config {k}: {v} == {new_config[k]}")
106
- else:
107
- modified = True
108
- fields = ["model", "model_endpoint", "model_endpoint_type", "model_wrapper", "context_window"]
109
- args = {}
110
- for field in fields:
111
- if field in new_config:
112
- args[field] = new_config[field]
113
- printd(f"Setting new config {field}: {new_config[field]}")
114
- old_config.default_llm_config = LLMConfig(**args)
115
- return (old_config, modified)
116
-
117
-
118
- def quickstart(
119
- backend: Annotated[QuickstartChoice, typer.Option(help="Quickstart setup backend")] = "letta",
120
- latest: Annotated[bool, typer.Option(help="Use --latest to pull the latest config from online")] = False,
121
- debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
122
- terminal: bool = True,
123
- ):
124
- """Set the base config file with a single command
125
-
126
- This function and `configure` should be the ONLY places where LettaConfig.save() is called.
127
- """
128
-
129
- # setup logger
130
- utils.DEBUG = debug
131
- logging.getLogger().setLevel(logging.CRITICAL)
132
- if debug:
133
- logging.getLogger().setLevel(logging.DEBUG)
134
-
135
- # make sure everything is set up properly
136
- LettaConfig.create_config_dir()
137
- credentials = LettaCredentials.load()
138
-
139
- config_was_modified = False
140
- if backend == QuickstartChoice.letta_hosted:
141
- # if latest, try to pull the config from the repo
142
- # fallback to using local
143
- if latest:
144
- # Download the latest letta hosted config
145
- url = "https://raw.githubusercontent.com/cpacker/Letta/main/configs/letta_hosted.json"
146
- response = requests.get(url)
147
-
148
- # Check if the request was successful
149
- if response.status_code == 200:
150
- # Parse the response content as JSON
151
- config = response.json()
152
- # Output a success message and the first few items in the dictionary as a sample
153
- printd("JSON config file downloaded successfully.")
154
- new_config, config_was_modified = set_config_with_dict(config)
155
- else:
156
- typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
157
-
158
- # Load the file from the relative path
159
- script_dir = os.path.dirname(__file__) # Get the directory where the script is located
160
- backup_config_path = os.path.join(script_dir, "..", "configs", "letta_hosted.json")
161
- try:
162
- with open(backup_config_path, "r", encoding="utf-8") as file:
163
- backup_config = json.load(file)
164
- printd("Loaded backup config file successfully.")
165
- new_config, config_was_modified = set_config_with_dict(backup_config)
166
- except FileNotFoundError:
167
- typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
168
- return
169
- else:
170
- # Load the file from the relative path
171
- script_dir = os.path.dirname(__file__) # Get the directory where the script is located
172
- backup_config_path = os.path.join(script_dir, "..", "configs", "letta_hosted.json")
173
- try:
174
- with open(backup_config_path, "r", encoding="utf-8") as file:
175
- backup_config = json.load(file)
176
- printd("Loaded config file successfully.")
177
- new_config, config_was_modified = set_config_with_dict(backup_config)
178
- except FileNotFoundError:
179
- typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
180
- return
181
-
182
- elif backend == QuickstartChoice.openai:
183
- # Make sure we have an API key
184
- api_key = os.getenv("OPENAI_API_KEY")
185
- while api_key is None or len(api_key) == 0:
186
- # Ask for API key as input
187
- api_key = questionary.password("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask()
188
- credentials.openai_key = api_key
189
- credentials.save()
190
-
191
- # if latest, try to pull the config from the repo
192
- # fallback to using local
193
- if latest:
194
- url = "https://raw.githubusercontent.com/cpacker/Letta/main/configs/openai.json"
195
- response = requests.get(url)
196
-
197
- # Check if the request was successful
198
- if response.status_code == 200:
199
- # Parse the response content as JSON
200
- config = response.json()
201
- # Output a success message and the first few items in the dictionary as a sample
202
- new_config, config_was_modified = set_config_with_dict(config)
203
- else:
204
- typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
205
-
206
- # Load the file from the relative path
207
- script_dir = os.path.dirname(__file__) # Get the directory where the script is located
208
- backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
209
- try:
210
- with open(backup_config_path, "r", encoding="utf-8") as file:
211
- backup_config = json.load(file)
212
- printd("Loaded backup config file successfully.")
213
- new_config, config_was_modified = set_config_with_dict(backup_config)
214
- except FileNotFoundError:
215
- typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
216
- return
217
- else:
218
- # Load the file from the relative path
219
- script_dir = os.path.dirname(__file__) # Get the directory where the script is located
220
- backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
221
- try:
222
- with open(backup_config_path, "r", encoding="utf-8") as file:
223
- backup_config = json.load(file)
224
- printd("Loaded config file successfully.")
225
- new_config, config_was_modified = set_config_with_dict(backup_config)
226
- except FileNotFoundError:
227
- typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
228
- return
229
-
230
- elif backend == QuickstartChoice.anthropic:
231
- # Make sure we have an API key
232
- api_key = os.getenv("ANTHROPIC_API_KEY")
233
- while api_key is None or len(api_key) == 0:
234
- # Ask for API key as input
235
- api_key = questionary.password("Enter your Anthropic API key:").ask()
236
- credentials.anthropic_key = api_key
237
- credentials.save()
238
-
239
- script_dir = os.path.dirname(__file__) # Get the directory where the script is located
240
- backup_config_path = os.path.join(script_dir, "..", "configs", "anthropic.json")
241
- try:
242
- with open(backup_config_path, "r", encoding="utf-8") as file:
243
- backup_config = json.load(file)
244
- printd("Loaded config file successfully.")
245
- new_config, config_was_modified = set_config_with_dict(backup_config)
246
- except FileNotFoundError:
247
- typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
248
- return
249
-
250
- else:
251
- raise NotImplementedError(backend)
252
-
253
- if config_was_modified:
254
- printd(f"Saving new config file.")
255
- new_config.save()
256
- typer.secho(f"📖 Letta configuration file updated!", fg=typer.colors.GREEN)
257
- typer.secho(
258
- "\n".join(
259
- [
260
- f"🧠 model\t-> {new_config.default_llm_config.model}",
261
- f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
262
- ]
263
- ),
264
- fg=typer.colors.GREEN,
265
- )
266
- else:
267
- typer.secho(f"📖 Letta configuration file unchanged.", fg=typer.colors.WHITE)
268
- typer.secho(
269
- "\n".join(
270
- [
271
- f"🧠 model\t-> {new_config.default_llm_config.model}",
272
- f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
273
- ]
274
- ),
275
- fg=typer.colors.WHITE,
276
- )
277
-
278
- # 'terminal' = quickstart was run alone, in which case we should guide the user on the next command
279
- if terminal:
280
- if config_was_modified:
281
- typer.secho('⚡ Run "letta run" to create an agent with the new config.', fg=typer.colors.YELLOW)
282
- else:
283
- typer.secho('⚡ Run "letta run" to create an agent.', fg=typer.colors.YELLOW)
284
-
285
-
286
29
  def open_folder():
287
30
  """Open a folder viewer of the Letta home directory"""
288
31
  try:
@@ -302,6 +45,7 @@ def server(
302
45
  port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
303
46
  host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None,
304
47
  debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
48
+ ade: Annotated[bool, typer.Option(help="Allows remote access")] = False,
305
49
  ):
306
50
  """Launch a Letta server process"""
307
51
 
@@ -383,83 +127,8 @@ def run(
383
127
  logger.setLevel(logging.CRITICAL)
384
128
  server_logger.setLevel(logging.CRITICAL)
385
129
 
386
- # from letta.migrate import (
387
- # VERSION_CUTOFF,
388
- # config_is_compatible,
389
- # wipe_config_and_reconfigure,
390
- # )
391
-
392
- # if not config_is_compatible(allow_empty=True):
393
- # typer.secho(f"\nYour current config file is incompatible with Letta versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
394
- # choices = [
395
- # "Run the full config setup (recommended)",
396
- # "Create a new config using defaults",
397
- # "Cancel",
398
- # ]
399
- # selection = questionary.select(
400
- # f"To use Letta, you must either downgrade your Letta version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
401
- # choices=choices,
402
- # default=choices[0],
403
- # ).ask()
404
- # if selection == choices[0]:
405
- # try:
406
- # wipe_config_and_reconfigure()
407
- # except Exception as e:
408
- # typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
409
- # raise
410
- # elif selection == choices[1]:
411
- # try:
412
- # # Don't create a config, so that the next block of code asking about quickstart is run
413
- # wipe_config_and_reconfigure(run_configure=False, create_config=False)
414
- # except Exception as e:
415
- # typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
416
- # raise
417
- # else:
418
- # typer.secho("Letta config regeneration cancelled", fg=typer.colors.RED)
419
- # raise KeyboardInterrupt()
420
-
421
- # typer.secho("Note: if you would like to migrate old agents to the new release, please run `letta migrate`!", fg=typer.colors.GREEN)
422
-
423
- if not LettaConfig.exists():
424
- # if no config, ask about quickstart
425
- # do you want to do:
426
- # - openai (run quickstart)
427
- # - letta hosted (run quickstart)
428
- # - other (run configure)
429
- if yes:
430
- # if user is passing '-y' to bypass all inputs, use letta hosted
431
- # since it can't fail out if you don't have an API key
432
- quickstart(backend=QuickstartChoice.letta_hosted)
433
- config = LettaConfig()
434
-
435
- else:
436
- config_choices = {
437
- "letta": "Use the free Letta endpoints",
438
- "openai": "Use OpenAI (requires an OpenAI API key)",
439
- "other": "Other (OpenAI Azure, custom LLM endpoint, etc)",
440
- }
441
- print()
442
- config_selection = questionary.select(
443
- "How would you like to set up Letta?",
444
- choices=list(config_choices.values()),
445
- default=config_choices["letta"],
446
- ).ask()
447
-
448
- if config_selection == config_choices["letta"]:
449
- print()
450
- quickstart(backend=QuickstartChoice.letta_hosted, debug=debug, terminal=False, latest=False)
451
- elif config_selection == config_choices["openai"]:
452
- print()
453
- quickstart(backend=QuickstartChoice.openai, debug=debug, terminal=False, latest=False)
454
- elif config_selection == config_choices["other"]:
455
- configure()
456
- else:
457
- raise ValueError(config_selection)
458
-
459
- config = LettaConfig.load()
460
-
461
- else: # load config
462
- config = LettaConfig.load()
130
+ # load config file
131
+ config = LettaConfig.load()
463
132
 
464
133
  # read user id from config
465
134
  ms = MetadataStore(config)
@@ -556,40 +225,36 @@ def run(
556
225
  typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE)
557
226
 
558
227
  agent_name = agent if agent else utils.create_random_username()
559
- llm_config = config.default_llm_config
560
- embedding_config = config.default_embedding_config # TODO allow overriding embedding params via CLI run
561
-
562
- # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
563
- if model and model != llm_config.model:
564
- typer.secho(f"{CLI_WARNING_PREFIX}Overriding default model {llm_config.model} with {model}", fg=typer.colors.YELLOW)
565
- llm_config.model = model
566
- if context_window is not None and int(context_window) != llm_config.context_window:
567
- typer.secho(
568
- f"{CLI_WARNING_PREFIX}Overriding default context window {llm_config.context_window} with {context_window}",
569
- fg=typer.colors.YELLOW,
570
- )
571
- llm_config.context_window = context_window
572
- if model_wrapper and model_wrapper != llm_config.model_wrapper:
573
- typer.secho(
574
- f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {llm_config.model_wrapper} with {model_wrapper}",
575
- fg=typer.colors.YELLOW,
576
- )
577
- llm_config.model_wrapper = model_wrapper
578
- if model_endpoint and model_endpoint != llm_config.model_endpoint:
579
- typer.secho(
580
- f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {llm_config.model_endpoint} with {model_endpoint}",
581
- fg=typer.colors.YELLOW,
582
- )
583
- llm_config.model_endpoint = model_endpoint
584
- if model_endpoint_type and model_endpoint_type != llm_config.model_endpoint_type:
585
- typer.secho(
586
- f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type {llm_config.model_endpoint_type} with {model_endpoint_type}",
587
- fg=typer.colors.YELLOW,
588
- )
589
- llm_config.model_endpoint_type = model_endpoint_type
590
228
 
591
229
  # create agent
592
230
  client = create_client()
231
+
232
+ # choose from list of llm_configs
233
+ llm_configs = client.list_llm_configs()
234
+ llm_options = [llm_config.model for llm_config in llm_configs]
235
+ # select model
236
+ if len(llm_options) == 0:
237
+ raise ValueError("No LLM models found. Please enable a provider.")
238
+ elif len(llm_options) == 1:
239
+ llm_model_name = llm_options[0]
240
+ else:
241
+ llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask()
242
+ llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]
243
+
244
+ # choose form list of embedding configs
245
+ embedding_configs = client.list_embedding_configs()
246
+ embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
247
+ # select model
248
+ if len(embedding_options) == 0:
249
+ raise ValueError("No embedding models found. Please enable a provider.")
250
+ elif len(embedding_options) == 1:
251
+ embedding_model_name = embedding_options[0]
252
+ else:
253
+ embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask()
254
+ embedding_config = [
255
+ embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
256
+ ][0]
257
+
593
258
  human_obj = client.get_human(client.get_human_id(name=human))
594
259
  persona_obj = client.get_persona(client.get_persona_id(name=persona))
595
260
  if human_obj is None:
letta/cli/cli_config.py CHANGED
@@ -35,8 +35,6 @@ from letta.local_llm.constants import (
35
35
  DEFAULT_WRAPPER_NAME,
36
36
  )
37
37
  from letta.local_llm.utils import get_available_wrappers
38
- from letta.schemas.embedding_config import EmbeddingConfig
39
- from letta.schemas.llm_config import LLMConfig
40
38
  from letta.server.utils import shorten_key_middle
41
39
 
42
40
  app = typer.Typer()
@@ -71,7 +69,7 @@ def configure_llm_endpoint(config: LettaConfig, credentials: LettaCredentials):
71
69
  model_endpoint_type, model_endpoint = None, None
72
70
 
73
71
  # get default
74
- default_model_endpoint_type = config.default_llm_config.model_endpoint_type if config.default_embedding_config else None
72
+ default_model_endpoint_type = None
75
73
  if (
76
74
  config.default_llm_config
77
75
  and config.default_llm_config.model_endpoint_type is not None
@@ -126,7 +124,41 @@ def configure_llm_endpoint(config: LettaConfig, credentials: LettaCredentials):
126
124
  model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
127
125
  if model_endpoint is None:
128
126
  raise KeyboardInterrupt
129
- provider = "openai"
127
+
128
+ elif provider == "groq":
129
+ groq_user_msg = "Enter your Groq API key (starts with 'gsk-', see https://console.groq.com/keys):"
130
+ # check for key
131
+ if credentials.groq_key is None:
132
+ # allow key to get pulled from env vars
133
+ groq_api_key = os.getenv("GROQ_API_KEY", None)
134
+ # if we still can't find it, ask for it as input
135
+ if groq_api_key is None:
136
+ while groq_api_key is None or len(groq_api_key) == 0:
137
+ # Ask for API key as input
138
+ groq_api_key = questionary.password(groq_user_msg).ask()
139
+ if groq_api_key is None:
140
+ raise KeyboardInterrupt
141
+ credentials.groq_key = groq_api_key
142
+ credentials.save()
143
+ else:
144
+ # Give the user an opportunity to overwrite the key
145
+ default_input = shorten_key_middle(credentials.groq_key) if credentials.groq_key.startswith("gsk-") else credentials.groq_key
146
+ groq_api_key = questionary.password(
147
+ groq_user_msg,
148
+ default=default_input,
149
+ ).ask()
150
+ if groq_api_key is None:
151
+ raise KeyboardInterrupt
152
+ # If the user modified it, use the new one
153
+ if groq_api_key != default_input:
154
+ credentials.groq_key = groq_api_key
155
+ credentials.save()
156
+
157
+ model_endpoint_type = "groq"
158
+ model_endpoint = "https://api.groq.com/openai/v1"
159
+ model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
160
+ if model_endpoint is None:
161
+ raise KeyboardInterrupt
130
162
 
131
163
  elif provider == "azure":
132
164
  # check for necessary vars
@@ -392,6 +424,12 @@ def get_model_options(
392
424
  fetched_model_options = cohere_get_model_list(url=model_endpoint, api_key=credentials.cohere_key)
393
425
  model_options = [obj for obj in fetched_model_options]
394
426
 
427
+ elif model_endpoint_type == "groq":
428
+ if credentials.groq_key is None:
429
+ raise ValueError("Missing Groq API key")
430
+ fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=credentials.groq_key, fix_url=True)
431
+ model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
432
+
395
433
  else:
396
434
  # Attempt to do OpenAI endpoint style model fetching
397
435
  # TODO support local auth with api-key header
@@ -555,10 +593,32 @@ def configure_model(config: LettaConfig, credentials: LettaCredentials, model_en
555
593
  if model is None:
556
594
  raise KeyboardInterrupt
557
595
 
596
+ # Groq support via /chat/completions + function calling endpoints
597
+ elif model_endpoint_type == "groq":
598
+ try:
599
+ fetched_model_options = get_model_options(
600
+ credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
601
+ )
602
+
603
+ except Exception as e:
604
+ # NOTE: if this fails, it means the user's key is probably bad
605
+ typer.secho(
606
+ f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
607
+ )
608
+ raise e
609
+
610
+ model = questionary.select(
611
+ "Select default model:",
612
+ choices=fetched_model_options,
613
+ default=fetched_model_options[0],
614
+ ).ask()
615
+ if model is None:
616
+ raise KeyboardInterrupt
617
+
558
618
  else: # local models
559
619
 
560
620
  # ask about local auth
561
- if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys
621
+ if model_endpoint_type in ["groq-chat-compltions"]: # TODO all llm engines under 'local' that will require api keys
562
622
  use_local_auth = True
563
623
  local_auth_type = "bearer_token"
564
624
  local_auth_key = questionary.password(
@@ -779,7 +839,7 @@ def configure_model(config: LettaConfig, credentials: LettaCredentials, model_en
779
839
  def configure_embedding_endpoint(config: LettaConfig, credentials: LettaCredentials):
780
840
  # configure embedding endpoint
781
841
 
782
- default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type if config.default_embedding_config else None
842
+ default_embedding_endpoint_type = None
783
843
 
784
844
  embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None
785
845
  embedding_provider = questionary.select(
@@ -844,9 +904,7 @@ def configure_embedding_endpoint(config: LettaConfig, credentials: LettaCredenti
844
904
  raise KeyboardInterrupt
845
905
 
846
906
  # get model type
847
- default_embedding_model = (
848
- config.default_embedding_config.embedding_model if config.default_embedding_config else "BAAI/bge-large-en-v1.5"
849
- )
907
+ default_embedding_model = "BAAI/bge-large-en-v1.5"
850
908
  embedding_model = questionary.text(
851
909
  "Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
852
910
  default=default_embedding_model,
@@ -855,7 +913,7 @@ def configure_embedding_endpoint(config: LettaConfig, credentials: LettaCredenti
855
913
  raise KeyboardInterrupt
856
914
 
857
915
  # get model dimentions
858
- default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "1024"
916
+ default_embedding_dim = "1024"
859
917
  embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
860
918
  if embedding_dim is None:
861
919
  raise KeyboardInterrupt
@@ -880,9 +938,7 @@ def configure_embedding_endpoint(config: LettaConfig, credentials: LettaCredenti
880
938
  raise KeyboardInterrupt
881
939
 
882
940
  # get model type
883
- default_embedding_model = (
884
- config.default_embedding_config.embedding_model if config.default_embedding_config else "mxbai-embed-large"
885
- )
941
+ default_embedding_model = "mxbai-embed-large"
886
942
  embedding_model = questionary.text(
887
943
  "Enter Ollama model tag (e.g. mxbai-embed-large):",
888
944
  default=default_embedding_model,
@@ -891,7 +947,7 @@ def configure_embedding_endpoint(config: LettaConfig, credentials: LettaCredenti
891
947
  raise KeyboardInterrupt
892
948
 
893
949
  # get model dimensions
894
- default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "512"
950
+ default_embedding_dim = "512"
895
951
  embedding_dim = questionary.text("Enter embedding model dimensions (e.g. 512):", default=str(default_embedding_dim)).ask()
896
952
  if embedding_dim is None:
897
953
  raise KeyboardInterrupt
@@ -1040,19 +1096,6 @@ def configure():
1040
1096
 
1041
1097
  # TODO: remove most of this (deplicated with User table)
1042
1098
  config = LettaConfig(
1043
- default_llm_config=LLMConfig(
1044
- model=model,
1045
- model_endpoint=model_endpoint,
1046
- model_endpoint_type=model_endpoint_type,
1047
- model_wrapper=model_wrapper,
1048
- context_window=context_window,
1049
- ),
1050
- default_embedding_config=EmbeddingConfig(
1051
- embedding_endpoint_type=embedding_endpoint_type,
1052
- embedding_endpoint=embedding_endpoint,
1053
- embedding_dim=embedding_dim,
1054
- embedding_model=embedding_model,
1055
- ),
1056
1099
  # storage
1057
1100
  archival_storage_type=archival_storage_type,
1058
1101
  archival_storage_uri=archival_storage_uri,