khoj 1.16.1.dev47__py3-none-any.whl → 1.17.1.dev216__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. khoj/configure.py +6 -6
  2. khoj/database/adapters/__init__.py +47 -2
  3. khoj/database/migrations/0053_agent_style_color_agent_style_icon.py +61 -0
  4. khoj/database/models/__init__.py +35 -0
  5. khoj/interface/web/assets/icons/favicon-128x128.png +0 -0
  6. khoj/interface/web/assets/icons/favicon-256x256.png +0 -0
  7. khoj/interface/web/assets/icons/khoj-logo-sideways-200.png +0 -0
  8. khoj/interface/web/assets/icons/khoj-logo-sideways-500.png +0 -0
  9. khoj/interface/web/assets/icons/khoj-logo-sideways.svg +31 -5384
  10. khoj/interface/web/assets/icons/khoj.svg +26 -0
  11. khoj/interface/web/chat.html +5 -5
  12. khoj/interface/web/content_source_computer_input.html +3 -3
  13. khoj/interface/web/content_source_github_input.html +1 -1
  14. khoj/interface/web/content_source_notion_input.html +1 -1
  15. khoj/interface/web/public_conversation.html +1 -1
  16. khoj/interface/web/search.html +2 -2
  17. khoj/interface/web/{config.html → settings.html} +30 -30
  18. khoj/interface/web/utils.html +1 -1
  19. khoj/processor/content/docx/docx_to_entries.py +4 -9
  20. khoj/processor/content/github/github_to_entries.py +1 -3
  21. khoj/processor/content/images/image_to_entries.py +4 -9
  22. khoj/processor/content/markdown/markdown_to_entries.py +4 -9
  23. khoj/processor/content/notion/notion_to_entries.py +1 -3
  24. khoj/processor/content/org_mode/org_to_entries.py +4 -9
  25. khoj/processor/content/pdf/pdf_to_entries.py +4 -9
  26. khoj/processor/content/plaintext/plaintext_to_entries.py +4 -9
  27. khoj/processor/content/text_to_entries.py +1 -3
  28. khoj/processor/tools/online_search.py +4 -4
  29. khoj/routers/api.py +49 -4
  30. khoj/routers/api_agents.py +3 -1
  31. khoj/routers/api_chat.py +80 -88
  32. khoj/routers/api_content.py +538 -0
  33. khoj/routers/api_model.py +156 -0
  34. khoj/routers/helpers.py +308 -7
  35. khoj/routers/notion.py +2 -8
  36. khoj/routers/web_client.py +43 -256
  37. khoj/search_type/text_search.py +5 -4
  38. khoj/utils/fs_syncer.py +3 -1
  39. khoj/utils/rawconfig.py +6 -1
  40. {khoj-1.16.1.dev47.dist-info → khoj-1.17.1.dev216.dist-info}/METADATA +2 -2
  41. {khoj-1.16.1.dev47.dist-info → khoj-1.17.1.dev216.dist-info}/RECORD +44 -42
  42. khoj/routers/api_config.py +0 -434
  43. khoj/routers/indexer.py +0 -349
  44. {khoj-1.16.1.dev47.dist-info → khoj-1.17.1.dev216.dist-info}/WHEEL +0 -0
  45. {khoj-1.16.1.dev47.dist-info → khoj-1.17.1.dev216.dist-info}/entry_points.txt +0 -0
  46. {khoj-1.16.1.dev47.dist-info → khoj-1.17.1.dev216.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,538 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import math
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ from asgiref.sync import sync_to_async
8
+ from fastapi import (
9
+ APIRouter,
10
+ Depends,
11
+ Header,
12
+ HTTPException,
13
+ Request,
14
+ Response,
15
+ UploadFile,
16
+ )
17
+ from pydantic import BaseModel
18
+ from starlette.authentication import requires
19
+
20
+ from khoj.database import adapters
21
+ from khoj.database.adapters import (
22
+ EntryAdapters,
23
+ get_user_github_config,
24
+ get_user_notion_config,
25
+ )
26
+ from khoj.database.models import Entry as DbEntry
27
+ from khoj.database.models import (
28
+ GithubConfig,
29
+ GithubRepoConfig,
30
+ KhojUser,
31
+ LocalMarkdownConfig,
32
+ LocalOrgConfig,
33
+ LocalPdfConfig,
34
+ LocalPlaintextConfig,
35
+ NotionConfig,
36
+ )
37
+ from khoj.routers.helpers import (
38
+ ApiIndexedDataLimiter,
39
+ CommonQueryParams,
40
+ configure_content,
41
+ get_user_config,
42
+ update_telemetry_state,
43
+ )
44
+ from khoj.utils import constants, state
45
+ from khoj.utils.config import SearchModels
46
+ from khoj.utils.helpers import get_file_type
47
+ from khoj.utils.rawconfig import (
48
+ ContentConfig,
49
+ FullConfig,
50
+ GithubContentConfig,
51
+ NotionContentConfig,
52
+ SearchConfig,
53
+ )
54
+ from khoj.utils.state import SearchType
55
+ from khoj.utils.yaml import save_config_to_file_updated_state
56
+
57
logger = logging.getLogger(__name__)

# Router for the content endpoints in this module: uploading files for indexing,
# configuring GitHub/Notion content sources, and deleting indexed content.
api_content = APIRouter()
60
+
61
+
62
class File(BaseModel):
    """A single file payload: its path plus its text or binary content."""

    # Path/name identifying the file
    path: str
    # Decoded text or raw bytes of the file
    content: Union[str, bytes]
65
+
66
+
67
class IndexBatchRequest(BaseModel):
    """A batch of ``File`` payloads (the name suggests a bulk indexing request — confirm with callers)."""

    files: list[File]
69
+
70
+
71
class IndexerInput(BaseModel):
    """Uploaded file contents grouped by content type, each keyed by filename.

    ``indexer`` builds one of these from the uploaded files and passes its
    ``model_dump()`` to ``configure_content``. Text formats map filename -> decoded
    ``str``; pdf, image and docx map filename -> raw ``bytes``. A ``None`` field means
    no files of that type were provided.
    """

    org: Optional[dict[str, str]] = None
    markdown: Optional[dict[str, str]] = None
    pdf: Optional[dict[str, bytes]] = None
    plaintext: Optional[dict[str, str]] = None
    image: Optional[dict[str, bytes]] = None
    docx: Optional[dict[str, bytes]] = None
78
+
79
+
80
@api_content.put("")
@requires(["authenticated"])
async def put_content(
    request: Request,
    files: List[UploadFile] = [],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=75,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Index the uploaded files with ``regenerate=True``.

    Size limits are configured on the ``ApiIndexedDataLimiter`` dependency, which FastAPI
    resolves before this handler runs. The actual work is delegated to ``indexer``.
    """
    return await indexer(
        request,
        files,
        t,
        regenerate=True,
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )
100
+
101
+
102
@api_content.patch("")
@requires(["authenticated"])
async def patch_content(
    request: Request,
    files: List[UploadFile] = [],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=75,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Index the uploaded files incrementally (``regenerate=False``).

    Size limits are configured on the ``ApiIndexedDataLimiter`` dependency, which FastAPI
    resolves before this handler runs. The actual work is delegated to ``indexer``.
    """
    return await indexer(
        request,
        files,
        t,
        regenerate=False,
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )
122
+
123
+
124
@api_content.get("/github", response_class=Response)
@requires(["authenticated"])
def get_content_github(request: Request) -> Response:
    """Return the user's settings payload plus their stored GitHub content config as JSON.

    The ``current_config`` key holds the serialized ``GithubContentConfig`` (PAT token and
    repos), or an empty object when the user has no GitHub config.
    """
    user = request.user.object
    user_config = get_user_config(user, request)
    # The Request object cannot be encoded by json.dumps, so drop it from the payload
    del user_config["request"]

    current_config: dict = {}
    github_config = get_user_github_config(user)
    if github_config:
        # Build plain dicts rather than Django GithubRepoConfig model instances: pydantic
        # validates nested model fields from dicts, whereas arbitrary ORM objects are
        # rejected unless the model enables ``from_attributes``.
        repos = [
            {"name": repo.name, "owner": repo.owner, "branch": repo.branch}
            for repo in github_config.githubrepoconfig.all()
        ]
        content_config = GithubContentConfig(pat_token=github_config.pat_token, repos=repos)
        current_config = json.loads(content_config.model_dump_json())

    user_config["current_config"] = current_config

    # Return config data as a JSON response
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
156
+
157
+
158
@api_content.get("/notion", response_class=Response)
@requires(["authenticated"])
def get_content_notion(request: Request) -> Response:
    """Return the user's settings payload plus their Notion content config as JSON.

    ``current_config`` always holds a serialized ``NotionContentConfig``; its token is the
    empty string when the user has no Notion config.
    """
    user = request.user.object
    payload = get_user_config(user, request)
    # The Request object cannot be encoded by json.dumps
    del payload["request"]

    notion_config = get_user_notion_config(user)
    notion_token = "" if not notion_config else notion_config.token
    payload["current_config"] = json.loads(NotionContentConfig(token=notion_token).model_dump_json())

    return Response(content=json.dumps(payload), media_type="application/json", status_code=200)
174
+
175
+
176
@api_content.post("/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Persist the user's GitHub PAT token and repo list, then record telemetry.

    Raises:
        HTTPException: 500 if saving the config fails for any reason.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_user_github_config(
            user=user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")

    telemetry_metadata = {"content_type": "github"}
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata=telemetry_metadata,
    )
    return {"status": "ok"}
206
+
207
+
208
@api_content.post("/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Persist the user's Notion integration token, then record telemetry.

    Raises:
        HTTPException: 500 if saving the config fails for any reason.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_notion_config(user=user, token=updated_config.token)
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Notion config")

    telemetry_metadata = {"content_type": "notion"}
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata=telemetry_metadata,
    )
    return {"status": "ok"}
237
+
238
+
239
@api_content.delete("/file", status_code=201)
@requires(["authenticated"])
async def delete_content_files(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete the user's indexed entries for ``filename`` via ``EntryAdapters.adelete_entry_by_file``."""
    owner = request.user.object

    update_telemetry_state(request=request, telemetry_type="api", api="delete_file", client=client)

    await EntryAdapters.adelete_entry_by_file(owner, filename)
    return {"status": "ok"}
258
+
259
+
260
class DeleteFilesRequest(BaseModel):
    """Request body listing the filenames whose indexed entries should be deleted."""

    files: List[str]
262
+
263
+
264
@api_content.delete("/files", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    files: DeleteFilesRequest,
    client: Optional[str] = None,
):
    """Delete the user's indexed entries for every filename in ``files.files``.

    Returns the status together with the count reported by
    ``EntryAdapters.adelete_entries_by_filenames``.
    """
    owner = request.user.object

    update_telemetry_state(request=request, telemetry_type="api", api="delete_file", client=client)

    return {
        "status": "ok",
        "deleted_count": await EntryAdapters.adelete_entries_by_filenames(owner, files.files),
    }
283
+
284
+
285
@api_content.get("/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams, client: Optional[str] = None):
    """Return the size of the user's indexed data in MB, rounded up, as JSON."""
    owner = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(owner)
    body = {"indexed_data_size_in_mb": math.ceil(size_in_mb)}
    return Response(content=json.dumps(body), media_type="application/json", status_code=200)
295
+
296
+
297
@api_content.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_content_types(request: Request, client: Optional[str] = None):
    """List the content types available to the user.

    The result combines the file types in the user's indexed entries, ``"all"``, and any
    non-null content types in the server config, restricted to valid ``SearchType`` values.
    """
    owner = request.user.object
    valid_types = {search_type.value for search_type in SearchType}

    enabled_types = set(EntryAdapters.get_unique_file_types(owner))
    enabled_types.add("all")

    server_content = state.config.content_type if state.config else None
    if server_content:
        # Iterating the dumped dict yields its keys, i.e. the configured content type names
        enabled_types.update(server_content.model_dump(exclude_none=True))

    return list(enabled_types & valid_types)
310
+
311
+
312
@api_content.get("/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Return the names of the user's indexed files that came from ``content_source``."""
    owner = request.user.object

    update_telemetry_state(request=request, telemetry_type="api", api="get_all_filenames", client=client)

    # presumably a lazy queryset; it is materialized in a worker thread via sync_to_async
    filenames = EntryAdapters.get_all_filenames_by_source(owner, content_source)
    return await sync_to_async(list)(filenames)  # type: ignore[call-arg]
329
+
330
+
331
@api_content.delete("/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete all of the user's indexed entries from ``content_source``.

    For GitHub and Notion sources, the user's stored config row for that source is
    deleted as well. Computer-uploaded content has no config row to remove.

    Raises:
        HTTPException: 400 if ``content_source`` is not a source known to
            ``map_config_to_object``.
    """
    user = request.user.object

    content_object = map_config_to_object(content_source)
    if content_object is None:
        # Client error: report a 400 rather than letting a ValueError become a 500
        raise HTTPException(status_code=400, detail=f"Invalid content source: {content_source}")
    if content_object != "Computer":
        # content_object is the config model (GithubConfig / NotionConfig) for this source
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )

    return {"status": "ok"}
362
+
363
+
364
async def indexer(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    regenerate: bool = False,
    client: Optional[str] = None,
    user_agent: Optional[str] = None,
    referer: Optional[str] = None,
    host: Optional[str] = None,
):
    """Index uploaded files into the requesting user's content index and record telemetry.

    Uploaded files are bucketed by the type reported by ``get_file_type``. Files of
    unsupported types are skipped with a warning. ``configure_content`` runs in the
    default thread-pool executor so the event loop is not blocked.

    Args:
        request: Incoming request; ``request.user.object`` is the user being indexed for.
        files: Uploaded files to index.
        t: Search/content type passed through to ``configure_content``.
        regenerate: ``True`` to regenerate, ``False`` to sync (used in logs as the method name).
        client: Name of the calling client, used in logs and telemetry.
        user_agent: User-Agent header value forwarded to telemetry.
        referer: Referer header value forwarded to telemetry.
        host: Host header value forwarded to telemetry.

    Returns:
        A 200 ``Response`` whose body is the comma-separated names of the accepted files,
        or a 500 ``Response`` with body ``"Failed"`` if indexing raised or reported failure.
    """
    user = request.user.object
    method = "regenerate" if regenerate else "sync"
    # Text formats hold decoded str, binary formats (pdf/image/docx) hold raw bytes
    index_files: Dict[str, Dict[str, Union[str, bytes]]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            # UploadFile.read() is async; file.file.read() would block the event loop
            # when the upload has spooled to disk
            file_content = await file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")

        indexer_input = IndexerInput(
            org=index_files["org"],
            markdown=index_files["markdown"],
            pdf=index_files["pdf"],
            plaintext=index_files["plaintext"],
            image=index_files["image"],
            docx=index_files["docx"],
        )

        if state.config is None:
            logger.info("📬 Initializing content index on first run.")
            default_full_config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config = default_full_config
            default_content_config = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            state.config.content_type = default_content_config
            save_config_to_file_updated_state()
            # NOTE(review): the SearchModels returned here is discarded, so this call does
            # not update state.search_models — confirm whether that is intended.
            configure_search(state.search_models, state.config.search_type)

        loop = asyncio.get_running_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            regenerate,
            t,
            user,
        )
        if not success:
            raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
        logger.info(f"Finished {method} {t} data sent by {client} client into content index")
    except Exception as e:
        # Log the failure once, with traceback
        logger.error(
            f"🚨 Failed to {method} {t} data sent by {client} client into content index: {e}",
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)

    indexing_metadata = {
        "num_org": len(index_files["org"]),
        "num_markdown": len(index_files["markdown"]),
        "num_pdf": len(index_files["pdf"]),
        "num_plaintext": len(index_files["plaintext"]),
        "num_image": len(index_files["image"]),
        "num_docx": len(index_files["docx"]),
    }

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )

    logger.info(f"📪 Content index updated via API call by {client} client")

    indexed_filenames = ",".join(filename for bucket in index_files.values() for filename in bucket)
    return Response(content=indexed_filenames, status_code=200)
469
+
470
+
471
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Return ``search_models``, or a fresh ``SearchModels()`` when it is ``None``.

    ``search_config`` is accepted for interface compatibility but is not used.
    """
    return SearchModels() if search_models is None else search_models
477
+
478
+
479
def map_config_to_object(content_source: str):
    """Map a content source name to its per-user config model.

    Returns ``GithubConfig`` or ``NotionConfig`` for those sources, the marker string
    ``"Computer"`` for computer uploads (which have no config model), and ``None`` for
    any other value.
    """
    source_targets = (
        (DbEntry.EntrySource.GITHUB, GithubConfig),
        (DbEntry.EntrySource.NOTION, NotionConfig),
        (DbEntry.EntrySource.COMPUTER, "Computer"),
    )
    for source, target in source_targets:
        if content_source == source:
            return target
    return None
486
+
487
+
488
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Persist each content source present in ``config.content_type`` for ``user``.

    Each local file source (org, markdown, pdf, plaintext) replaces any existing config
    row for the user with a new one. GitHub and Notion settings are saved through the
    adapter helpers. Sources that are unset in ``config`` are left untouched.
    """
    content = config.content_type
    if not content:
        return

    # Local file sources share the same fields, so handle them uniformly, in this order
    local_sources = (
        (content.org, LocalOrgConfig),
        (content.markdown, LocalMarkdownConfig),
        (content.pdf, LocalPdfConfig),
        (content.plaintext, LocalPlaintextConfig),
    )
    for source_config, config_model in local_sources:
        if not source_config:
            continue
        await config_model.objects.filter(user=user).adelete()
        await config_model.objects.acreate(
            input_files=source_config.input_files,
            input_filter=source_config.input_filter,
            index_heading_entries=source_config.index_heading_entries,
            user=user,
        )

    if content.github:
        await adapters.set_user_github_config(
            user=user,
            pat_token=content.github.pat_token,
            repos=content.github.repos,
        )
    if content.notion:
        await adapters.set_notion_config(user=user, token=content.notion.token)
533
+
534
+
535
def _initialize_config():
    """Create a default in-memory ``FullConfig`` on first use; no-op if one already exists."""
    if state.config is not None:
        return
    state.config = FullConfig()
    state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@@ -0,0 +1,156 @@
1
+ import json
2
+ import logging
3
+ from typing import Dict, Optional, Union
4
+
5
+ from fastapi import APIRouter, HTTPException, Request
6
+ from fastapi.requests import Request
7
+ from fastapi.responses import Response
8
+ from starlette.authentication import has_required_scope, requires
9
+
10
+ from khoj.database import adapters
11
+ from khoj.database.adapters import ConversationAdapters, EntryAdapters
12
+ from khoj.routers.helpers import update_telemetry_state
13
+
14
# Router for listing chat models and for selecting per-user chat, voice, search and paint models.
api_model = APIRouter()
logger = logging.getLogger(__name__)
16
+
17
+
18
@api_model.get("/chat/options", response_model=list[Dict[str, Union[str, int]]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """Return every available chat model as a JSON list of ``{"chat_model", "id"}`` objects.

    ``response_model`` declares a list, matching the payload actually returned; the old
    dict declaration produced a wrong OpenAPI schema.
    """
    conversation_options = ConversationAdapters.get_conversation_processor_options().all()

    all_conversation_options = [
        {"chat_model": option.chat_model, "id": option.id} for option in conversation_options
    ]

    return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
30
+
31
+
32
@api_model.get("/chat")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the user's chat model as JSON ``{"id", "chat_model"}``.

    Falls back to the server default when the user has not selected one. Responds 404
    with a JSON error if neither exists, instead of crashing on ``None``.
    """
    user = request.user.object

    chat_model = ConversationAdapters.get_conversation_config(user)
    if chat_model is None:
        chat_model = ConversationAdapters.get_default_conversation_config()

    if chat_model is None:
        return Response(
            status_code=404,
            content=json.dumps({"status": "error", "message": "Model not found"}),
            media_type="application/json",
        )

    # Declare JSON explicitly, like the other JSON endpoints in this module
    return Response(
        status_code=200,
        content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}),
        media_type="application/json",
    )
46
+
47
+
48
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's chat model to the one with primary key ``id``.

    Telemetry is recorded regardless of outcome. Returns an error status payload (still
    HTTP 200) when no such model exists.
    """
    requester = request.user.object
    selected = await ConversationAdapters.aset_user_conversation_processor(requester, int(id))

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )

    return {"status": "ok"} if selected is not None else {"status": "error", "message": "Model not found"}
71
+
72
+
73
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's voice model to ``id``.

    Telemetry is recorded regardless of outcome. Responds 202 on success, or 404 with a
    JSON error body when the model is not found.
    """
    requester = request.user.object
    selected = await ConversationAdapters.aset_user_voice_model(requester, id)

    update_telemetry_state(request=request, telemetry_type="api", api="set_voice_model", client=client)

    if selected is not None:
        return Response(status_code=202, content=json.dumps({"status": "ok"}))
    return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
95
+
96
+
97
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the user's search model to the one with primary key ``id``.

    When the model actually changes, or when the user had no explicit model (was on the
    default), all of the user's indexed entries are deleted. Presumably this is because
    they were indexed with the previous model. Nothing is deleted when the requested
    model is not found.

    Returns:
        ``{"status": "ok"}`` on success, or ``{"status": "error", "message": "Model not found"}``.
    """
    user = request.user.object
    new_model_id = int(id)

    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, new_model_id)

    if new_config is None:
        # Keep the user's entries when the requested model does not exist
        return {"status": "error", "message": "Model not found"}

    # If the user was just using the default config, or picked a different model,
    # delete all the entries indexed under the previous model.
    if not prev_config or prev_config.id != new_model_id:
        await EntryAdapters.adelete_all_entries(user)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
128
+
129
+
130
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's text-to-image model to the one with primary key ``id``.

    Raises:
        HTTPException: 403 if the user lacks the ``premium`` scope.

    Returns:
        ``{"status": "ok"}`` on success, or ``{"status": "error", "message": "Model not found"}``.
    """
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])

    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")

    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

    # Check before telemetry: the telemetry metadata dereferences new_config.setting,
    # which would raise AttributeError for an unknown model.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )

    return {"status": "ok"}