@marimo-team/frontend 0.14.18-dev24 → 0.14.18-dev26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. package/dist/assets/{ConnectedDataExplorerComponent-Dbu2xXc5.js → ConnectedDataExplorerComponent-CPwF1ckF.js} +1 -1
  2. package/dist/assets/{ImageComparisonComponent-BkZIjGHA.js → ImageComparisonComponent-BuFzvNj5.js} +1 -1
  3. package/dist/assets/{VegaLite-Ca7AXGyA.js → VegaLite-DI-xPLgU.js} +1 -1
  4. package/dist/assets/{_baseEach-H5Qk1V2B.js → _baseEach-BvIVAeE-.js} +1 -1
  5. package/dist/assets/_baseMap-DrbRr9s_.js +1 -0
  6. package/dist/assets/{_baseUniq-UAmxGez2.js → _baseUniq-Cz-AR5wG.js} +1 -1
  7. package/dist/assets/{_createAggregator-H-t5qYSG.js → _createAggregator-fFJIgkzg.js} +1 -1
  8. package/dist/assets/{any-language-editor-DUt-pIdy.js → any-language-editor-yovSkXXH.js} +1 -1
  9. package/dist/assets/{architectureDiagram-SUXI7LT5-DWP05q8_.js → architectureDiagram-SUXI7LT5-vhSKxFkz.js} +1 -1
  10. package/dist/assets/{blockDiagram-6J76NXCF-DB57b3LI.js → blockDiagram-6J76NXCF-D8FkmW5U.js} +1 -1
  11. package/dist/assets/{c4Diagram-6F6E4RAY-ByeAGPmY.js → c4Diagram-6F6E4RAY-DMz49WYv.js} +1 -1
  12. package/dist/assets/channel-C5D8_wdC.js +1 -0
  13. package/dist/assets/{chunk-353BL4L5-ngKqumF3.js → chunk-353BL4L5-C_h8QREg.js} +1 -1
  14. package/dist/assets/{chunk-67H74DCK-DTbVgB4A.js → chunk-67H74DCK-vQz9ujXD.js} +1 -1
  15. package/dist/assets/{chunk-AACKK3MU-t9Kf_p_V.js → chunk-AACKK3MU-CUxeSfqv.js} +1 -1
  16. package/dist/assets/{chunk-BFAMUDN2-B0dcgdMs.js → chunk-BFAMUDN2-CLGtBusC.js} +1 -1
  17. package/dist/assets/{chunk-E2GYISFI-C7n5SQvq.js → chunk-E2GYISFI-DGILtsZz.js} +1 -1
  18. package/dist/assets/{chunk-OW32GOEJ-DEvBFIh8.js → chunk-OW32GOEJ-D-HRiiaD.js} +1 -1
  19. package/dist/assets/{chunk-SKB7J2MH-DX71TF_n.js → chunk-SKB7J2MH-xs1kEiQZ.js} +1 -1
  20. package/dist/assets/{chunk-SZ463SBG-CO9fYjVe.js → chunk-SZ463SBG-DKY2KPhj.js} +1 -1
  21. package/dist/assets/{circle-play-B-ZWQLkW.js → circle-play-ub1A5yYA.js} +1 -1
  22. package/dist/assets/classDiagram-M3E45YP4-Bzxtjbca.js +1 -0
  23. package/dist/assets/classDiagram-v2-YAWTLIQI-Bzxtjbca.js +1 -0
  24. package/dist/assets/clone-DngVHtKT.js +1 -0
  25. package/dist/assets/{compile-lv6gzALt.js → compile-Eo2s_HHz.js} +1 -1
  26. package/dist/assets/{dagre-JOIXM2OF-B7ZtMRIJ.js → dagre-JOIXM2OF-BNvsGquX.js} +1 -1
  27. package/dist/assets/{data-grid-overlay-editor-DTAq6v9P.js → data-grid-overlay-editor-Vk_PgvbS.js} +1 -1
  28. package/dist/assets/{diagram-5UYTHUR4-CboeMF3G.js → diagram-5UYTHUR4-6gGyo4aQ.js} +1 -1
  29. package/dist/assets/{diagram-VMROVX33-BcnQcTFp.js → diagram-VMROVX33-Cr9VHqxg.js} +1 -1
  30. package/dist/assets/{diagram-ZTM2IBQH-CaFT3B8g.js → diagram-ZTM2IBQH-BFP88P3p.js} +1 -1
  31. package/dist/assets/{edit-page-CvWZ8wVZ.js → edit-page-Vrbmp55N.js} +53 -53
  32. package/dist/assets/{erDiagram-3M52JZNH-BFPc1-ng.js → erDiagram-3M52JZNH-C0U8FCGV.js} +1 -1
  33. package/dist/assets/{flowDiagram-KYDEHFYC-CAZBj4fC.js → flowDiagram-KYDEHFYC-3GT8PdlC.js} +1 -1
  34. package/dist/assets/{ganttDiagram-EK5VF46D-CUxffTMj.js → ganttDiagram-EK5VF46D-CrQQiHRJ.js} +1 -1
  35. package/dist/assets/{gitGraphDiagram-GW3U2K7C-Tx-ZfecV.js → gitGraphDiagram-GW3U2K7C-BuMcEd-g.js} +1 -1
  36. package/dist/assets/{glide-data-editor-BP8l5q3f.js → glide-data-editor-BB_hULql.js} +11 -11
  37. package/dist/assets/{graph-Bf6eoL3M.js → graph-B67OAr_1.js} +1 -1
  38. package/dist/assets/{home-page--ixzxF-w.js → home-page-D4saUb6P.js} +1 -1
  39. package/dist/assets/index-B3Q6PaCG.css +1 -0
  40. package/dist/assets/{index-g-JZ5Z-_.js → index-BCsh52ZA.js} +1 -1
  41. package/dist/assets/{index-DJ1FNShi.js → index-BUmTew_7.js} +1 -1
  42. package/dist/assets/{index-CmJm6kj5.js → index-BbaurO1n.js} +1 -1
  43. package/dist/assets/{index-D1vinr8C.js → index-BcNCs1Ec.js} +1 -1
  44. package/dist/assets/{index-Cv01rAR1.js → index-BcOm_dWX.js} +1 -1
  45. package/dist/assets/{index-dvYU8Vev.js → index-Bg-oxR2t.js} +94 -94
  46. package/dist/assets/{index-DleA5-JR.js → index-CT6WEyVV.js} +1 -1
  47. package/dist/assets/{index-kaWF-HKt.js → index-CWnyv0Mb.js} +1 -1
  48. package/dist/assets/{index-BMLuYuQT.js → index-CYZYUN8a.js} +1 -1
  49. package/dist/assets/{index-DupjaVjo.js → index-CfQw831p.js} +1 -1
  50. package/dist/assets/{index-C4iPAevr.js → index-CydDDozP.js} +1 -1
  51. package/dist/assets/{index-q4AwmgtF.js → index-D-aN-b2W.js} +1 -1
  52. package/dist/assets/{index-Uh_QGNQm.js → index-DFBulMSA.js} +1 -1
  53. package/dist/assets/{index-U0KSor5u.js → index-DeKrdrEq.js} +1 -1
  54. package/dist/assets/{index-Q73xW9dd.js → index-DqZ9_rHU.js} +1 -1
  55. package/dist/assets/{index-mvfeSH0B.js → index-Dtw1TXyN.js} +1 -1
  56. package/dist/assets/{index-B39Q37Jh.js → index-FrFiygIp.js} +1 -1
  57. package/dist/assets/{index-DhU7edG1.js → index-IDJxuXSR.js} +1 -1
  58. package/dist/assets/{index-CM8rF_ge.js → index-LcqVorg9.js} +1 -1
  59. package/dist/assets/infoDiagram-LHK5PUON-Bc9fde05.js +2 -0
  60. package/dist/assets/{journeyDiagram-EWQZEKCU-DBorgqR8.js → journeyDiagram-EWQZEKCU-CQkZyIS5.js} +1 -1
  61. package/dist/assets/{kanban-definition-ZSS6B67P-rOei7rdW.js → kanban-definition-ZSS6B67P-CEV1b6rk.js} +1 -1
  62. package/dist/assets/{layout-D5U1vfFv.js → layout-DAmvOShj.js} +1 -1
  63. package/dist/assets/{linear-BPaO6rYC.js → linear-ktW_rVCS.js} +1 -1
  64. package/dist/assets/links-C8gYI3jG.js +18 -0
  65. package/dist/assets/{mermaid-DuiiiGkf.js → mermaid-Dn9IUXi8.js} +4 -4
  66. package/dist/assets/{min-GM8d8p3k.js → min-BmsA5VGH.js} +1 -1
  67. package/dist/assets/{mindmap-definition-6CBA2TL7-DFNMYbU5.js → mindmap-definition-6CBA2TL7-B0yAK13C.js} +1 -1
  68. package/dist/assets/{number-overlay-editor-CcD_5O9P.js → number-overlay-editor-DbvoJDEj.js} +1 -1
  69. package/dist/assets/{pieDiagram-NIOCPIFQ-BAnWQSTZ.js → pieDiagram-NIOCPIFQ-BaPxK1u0.js} +1 -1
  70. package/dist/assets/{quadrantDiagram-2OG54O6I-CLzghZ4d.js → quadrantDiagram-2OG54O6I-PGZoquB4.js} +1 -1
  71. package/dist/assets/{react-plotly-qIomJONw.js → react-plotly-CzEXKtwr.js} +1 -1
  72. package/dist/assets/{requirementDiagram-QOLK2EJ7-ClKsOuLQ.js → requirementDiagram-QOLK2EJ7-BdGXi15P.js} +1 -1
  73. package/dist/assets/{run-page-B9ntqQci.js → run-page-BnZ5haku.js} +1 -1
  74. package/dist/assets/{sankeyDiagram-4UZDY2LN-CMUyusMd.js → sankeyDiagram-4UZDY2LN-Dk80vg6t.js} +1 -1
  75. package/dist/assets/{sequenceDiagram-SKLFT4DO-BUL7OokF.js → sequenceDiagram-SKLFT4DO-CqLVyCT3.js} +1 -1
  76. package/dist/assets/{slides-component-DxNxYl9E.js → slides-component-D_39bH-j.js} +1 -1
  77. package/dist/assets/{sortBy-Bo672N53.js → sortBy-CI9MGbno.js} +1 -1
  78. package/dist/assets/{stateDiagram-MI5ZYTHO-CK5D03xc.js → stateDiagram-MI5ZYTHO-C3tpbdm0.js} +1 -1
  79. package/dist/assets/stateDiagram-v2-5AN5P6BG-BXclnbG1.js +1 -0
  80. package/dist/assets/{storage-DCGJ86_2.js → storage-CxOI7c0Z.js} +3 -3
  81. package/dist/assets/{terminal-C9ZYVCQk.js → terminal-9Nyd6Apx.js} +1 -1
  82. package/dist/assets/{time-DUBsogDP.js → time-Do4-nT0Q.js} +1 -1
  83. package/dist/assets/{timeline-definition-MYPXXCX6-DlOzuKHL.js → timeline-definition-MYPXXCX6-CtllVQGu.js} +1 -1
  84. package/dist/assets/{tracing-DkB9iogQ.js → tracing-BWy-UTCz.js} +2 -2
  85. package/dist/assets/{trash-CuNNSzF1.js → trash-BPgHERlA.js} +1 -1
  86. package/dist/assets/{treemap-75Q7IDZK-DUsN_Z4F.js → treemap-75Q7IDZK-BXkwu4nl.js} +1 -1
  87. package/dist/assets/{vega-component-DByprFwW.js → vega-component-COxYVQBE.js} +1 -1
  88. package/dist/assets/{xychartDiagram-H2YORKM3-BaUqYAb6.js → xychartDiagram-H2YORKM3-DfzCBqlv.js} +1 -1
  89. package/dist/index.html +2 -2
  90. package/package.json +1 -1
  91. package/src/components/ai/ai-model-dropdown.tsx +288 -0
  92. package/src/components/ai/ai-provider-icon.tsx +7 -4
  93. package/src/components/app-config/ai-config.tsx +100 -76
  94. package/src/components/app-config/app-config-button.tsx +10 -1
  95. package/src/components/app-config/constants.ts +0 -34
  96. package/src/components/app-config/incorrect-model-id.tsx +4 -2
  97. package/src/components/app-config/user-config-form.tsx +12 -5
  98. package/src/components/chat/chat-panel.tsx +12 -26
  99. package/src/components/slides/slides.css +0 -1
  100. package/src/core/ai/__tests__/model-registry.test.ts +357 -0
  101. package/src/{utils/ai → core/ai/ids}/__tests__/ids.test.ts +2 -1
  102. package/src/{utils/ai → core/ai/ids}/ids.ts +18 -10
  103. package/src/core/ai/model-registry.ts +164 -0
  104. package/src/core/cells/cells.ts +1 -1
  105. package/src/core/cells/effects.ts +1 -1
  106. package/src/plugins/layout/carousel/CarouselPlugin.tsx +0 -2
  107. package/src/utils/__tests__/multi-map.test.ts +295 -0
  108. package/src/utils/multi-map.ts +71 -0
  109. package/dist/assets/_baseMap-lEtQfieX.js +0 -1
  110. package/dist/assets/channel-CJdgPvjM.js +0 -1
  111. package/dist/assets/classDiagram-M3E45YP4-Tb8oQ03C.js +0 -1
  112. package/dist/assets/classDiagram-v2-YAWTLIQI-Tb8oQ03C.js +0 -1
  113. package/dist/assets/clone-DygFoMzB.js +0 -1
  114. package/dist/assets/index-BlxPam9h.css +0 -1
  115. package/dist/assets/infoDiagram-LHK5PUON-Cz497oaY.js +0 -2
  116. package/dist/assets/links-Cxxlu7np.js +0 -17
  117. package/dist/assets/stateDiagram-v2-5AN5P6BG-Deqw4NDh.js +0 -1
@@ -1,14 +1,16 @@
1
1
  /* Copyright 2024 Marimo. All rights reserved. */
2
2
  import React from "react";
3
+ import { AiModelId, type QualifiedModelId } from "@/core/ai/ids/ids";
3
4
  import { Banner } from "@/plugins/impl/common/error-banner";
4
- import { AiModelId, type QualifiedModelId } from "@/utils/ai/ids";
5
5
 
6
6
  interface IncorrectModelIdProps {
7
7
  value: string | null | undefined;
8
+ includeSuggestion?: boolean;
8
9
  }
9
10
 
10
11
  export const IncorrectModelId: React.FC<IncorrectModelIdProps> = ({
11
12
  value,
13
+ includeSuggestion = true,
12
14
  }) => {
13
15
  if (!value) {
14
16
  return null;
@@ -31,7 +33,7 @@ export const IncorrectModelId: React.FC<IncorrectModelIdProps> = ({
31
33
  provider.
32
34
  </span>
33
35
  <br />
34
- {suggestion && (
36
+ {includeSuggestion && suggestion && (
35
37
  <span>
36
38
  Did you mean <code className="font-bold">{suggestion}</code>?
37
39
  </span>
@@ -12,7 +12,7 @@ import {
12
12
  MonitorIcon,
13
13
  PackageIcon,
14
14
  } from "lucide-react";
15
- import React, { useRef } from "react";
15
+ import React, { useId, useRef } from "react";
16
16
  import { useForm } from "react-hook-form";
17
17
  import { Button } from "@/components/ui/button";
18
18
  import { Checkbox } from "@/components/ui/checkbox";
@@ -130,6 +130,8 @@ export const UserConfigForm: React.FC = () => {
130
130
  };
131
131
 
132
132
  const isWasmRuntime = isWasm();
133
+ const htmlCheckboxId = useId();
134
+ const ipynbCheckboxId = useId();
133
135
 
134
136
  const renderBody = () => {
135
137
  switch (activeCategory) {
@@ -168,6 +170,7 @@ export const UserConfigForm: React.FC = () => {
168
170
  <FormLabel>Autosave delay (seconds)</FormLabel>
169
171
  <FormControl>
170
172
  <NumberField
173
+ aria-label="Autosave delay"
171
174
  data-testid="autosave-delay-input"
172
175
  className="m-0 w-24"
173
176
  isDisabled={
@@ -205,7 +208,7 @@ export const UserConfigForm: React.FC = () => {
205
208
  <div className="flex gap-4">
206
209
  <div className="flex items-center space-x-2">
207
210
  <Checkbox
208
- id="html-checkbox"
211
+ id={htmlCheckboxId}
209
212
  checked={
210
213
  Array.isArray(field.value) &&
211
214
  field.value.includes("html")
@@ -219,11 +222,11 @@ export const UserConfigForm: React.FC = () => {
219
222
  );
220
223
  }}
221
224
  />
222
- <FormLabel htmlFor="html-checkbox">HTML</FormLabel>
225
+ <FormLabel htmlFor={htmlCheckboxId}>HTML</FormLabel>
223
226
  </div>
224
227
  <div className="flex items-center space-x-2">
225
228
  <Checkbox
226
- id="ipynb-checkbox"
229
+ id={ipynbCheckboxId}
227
230
  checked={
228
231
  Array.isArray(field.value) &&
229
232
  field.value.includes("ipynb")
@@ -237,7 +240,7 @@ export const UserConfigForm: React.FC = () => {
237
240
  );
238
241
  }}
239
242
  />
240
- <FormLabel htmlFor="ipynb-checkbox">
243
+ <FormLabel htmlFor={ipynbCheckboxId}>
241
244
  IPYNB
242
245
  </FormLabel>
243
246
  </div>
@@ -295,6 +298,7 @@ export const UserConfigForm: React.FC = () => {
295
298
  <FormLabel>Line length</FormLabel>
296
299
  <FormControl>
297
300
  <NumberField
301
+ aria-label="Line length"
298
302
  data-testid="line-length-input"
299
303
  className="m-0 w-24"
300
304
  {...field}
@@ -716,6 +720,7 @@ export const UserConfigForm: React.FC = () => {
716
720
  <FormControl>
717
721
  <span className="inline-flex mr-2">
718
722
  <NumberField
723
+ aria-label="Code editor font size"
719
724
  data-testid="code-editor-font-size-input"
720
725
  className="m-0 w-24"
721
726
  {...field}
@@ -848,6 +853,7 @@ export const UserConfigForm: React.FC = () => {
848
853
  <FormLabel>Default table page size</FormLabel>
849
854
  <FormControl>
850
855
  <NumberField
856
+ aria-label="Default table page size"
851
857
  data-testid="default-table-page-size-input"
852
858
  className="m-0 w-24"
853
859
  {...field}
@@ -884,6 +890,7 @@ export const UserConfigForm: React.FC = () => {
884
890
  <FormLabel>Default table max columns</FormLabel>
885
891
  <FormControl>
886
892
  <NumberField
893
+ aria-label="Default table max columns"
887
894
  data-testid="default-table-max-columns-input"
888
895
  className="m-0 w-24"
889
896
  {...field}
@@ -36,6 +36,7 @@ import {
36
36
  SelectValue,
37
37
  } from "@/components/ui/select";
38
38
  import { addMessageToChat } from "@/core/ai/chat-utils";
39
+ import type { QualifiedModelId } from "@/core/ai/ids/ids";
39
40
  import {
40
41
  activeChatAtom,
41
42
  type Chat,
@@ -54,7 +55,7 @@ import { cn } from "@/utils/cn";
54
55
  import { timeAgo } from "@/utils/dates";
55
56
  import { Logger } from "@/utils/Logger";
56
57
  import { generateUUID } from "@/utils/uuid";
57
- import { KNOWN_AI_MODELS } from "../app-config/constants";
58
+ import { AIModelDropdown } from "../ai/ai-model-dropdown";
58
59
  import { useOpenSettingsToTab } from "../app-config/state";
59
60
  import { PromptInput } from "../editor/ai/add-cell-with-ai";
60
61
  import { getAICompletionBody } from "../editor/ai/completion-utils";
@@ -276,7 +277,7 @@ const ChatInputFooter: React.FC<ChatInputFooterProps> = memo(
276
277
  saveConfig(newConfig);
277
278
  };
278
279
 
279
- const handleModelChange = async (newModel: string) => {
280
+ const handleModelChange = async (newModel: QualifiedModelId) => {
280
281
  const newConfig: UserConfig = {
281
282
  ...userConfig,
282
283
  ai: {
@@ -324,30 +325,15 @@ const ChatInputFooter: React.FC<ChatInputFooterProps> = memo(
324
325
  </SelectContent>
325
326
  </Select>
326
327
  </FeatureFlagged>
327
- <Select value={currentModel} onValueChange={handleModelChange}>
328
- <SelectTrigger className="h-6 text-xs border-border shadow-none! ring-0! bg-muted hover:bg-muted/30 py-0 px-2 gap-1">
329
- <SelectValue placeholder="Model" />
330
- </SelectTrigger>
331
- <SelectContent>
332
- {/* Show current model if it's not in the known models list */}
333
- {!(KNOWN_AI_MODELS as readonly string[]).includes(
334
- currentModel,
335
- ) && (
336
- <SelectItem
337
- key={currentModel}
338
- value={currentModel}
339
- className="text-sm"
340
- >
341
- {currentModel}
342
- </SelectItem>
343
- )}
344
- {KNOWN_AI_MODELS.map((model) => (
345
- <SelectItem key={model} value={model} className="text-sm">
346
- {model}
347
- </SelectItem>
348
- ))}
349
- </SelectContent>
350
- </Select>
328
+ <AIModelDropdown
329
+ value={currentModel}
330
+ placeholder="Model"
331
+ onSelect={handleModelChange}
332
+ triggerClassName="h-6 text-xs shadow-none! ring-0! bg-muted hover:bg-muted/30 rounded-sm"
333
+ iconSize="small"
334
+ showAddCustomModelDocs={true}
335
+ forRole="chat"
336
+ />
351
337
  </div>
352
338
  <Button
353
339
  variant="ghost"
@@ -1,6 +1,5 @@
1
1
  @import "swiper/css";
2
2
  @import "swiper/css/virtual";
3
- @import "swiper/css/keyboard";
4
3
  @import "swiper/css/navigation";
5
4
  @import "swiper/css/pagination";
6
5
  @import "swiper/css/scrollbar";
@@ -0,0 +1,357 @@
1
+ /* Copyright 2024 Marimo. All rights reserved. */
2
+ import { beforeEach, describe, expect, it, vi } from "vitest";
3
+
4
+ // Mock the models.json import
5
+ vi.mock("@marimo-team/llm-info/models.json", () => {
6
+ const models: AiModel[] = [
7
+ {
8
+ name: "GPT-4",
9
+ model: "gpt-4",
10
+ description: "OpenAI GPT-4 model",
11
+ providers: ["openai"],
12
+ roles: ["chat", "edit"],
13
+ thinking: false,
14
+ },
15
+ {
16
+ name: "Claude 3",
17
+ model: "claude-3-sonnet",
18
+ description: "Anthropic Claude 3 Sonnet",
19
+ providers: ["anthropic"],
20
+ roles: ["chat", "edit"],
21
+ thinking: false,
22
+ },
23
+ {
24
+ name: "Gemini Pro",
25
+ model: "gemini-pro",
26
+ description: "Google Gemini Pro model",
27
+ providers: ["google"],
28
+ roles: ["chat", "edit"],
29
+ thinking: false,
30
+ },
31
+ {
32
+ name: "Multi Provider Model",
33
+ model: "multi-model",
34
+ description: "Model available on multiple providers",
35
+ providers: ["openai", "anthropic"],
36
+ roles: ["chat", "edit"],
37
+ thinking: false,
38
+ },
39
+ ];
40
+
41
+ return {
42
+ models: models,
43
+ };
44
+ });
45
+
46
+ import type { AiModel } from "@marimo-team/llm-info";
47
+ import { AiModelRegistry } from "../model-registry";
48
+
49
+ describe("AiModelRegistry", () => {
50
+ beforeEach(() => {
51
+ vi.clearAllMocks();
52
+ });
53
+
54
+ describe("create", () => {
55
+ it("should create registry with no custom or displayed models", () => {
56
+ const registry = AiModelRegistry.create({});
57
+
58
+ expect(registry).toBeInstanceOf(AiModelRegistry);
59
+ expect(registry.getCustomModels()).toEqual(new Set());
60
+ expect(registry.getDisplayedModels()).toEqual(new Set());
61
+ });
62
+
63
+ it("should create registry with custom models", () => {
64
+ const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
65
+ const registry = AiModelRegistry.create({ customModels });
66
+
67
+ expect(registry.getCustomModels()).toEqual(new Set(customModels));
68
+ expect(registry.getDisplayedModels()).toEqual(new Set());
69
+ });
70
+
71
+ it("should create registry with displayed models", () => {
72
+ const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
73
+ const registry = AiModelRegistry.create({ displayedModels });
74
+
75
+ expect(registry.getCustomModels()).toEqual(new Set());
76
+ expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
77
+ });
78
+
79
+ it("should create registry with both custom and displayed models", () => {
80
+ const customModels = ["openai/custom-gpt"];
81
+ const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
82
+ const registry = AiModelRegistry.create({
83
+ customModels,
84
+ displayedModels,
85
+ });
86
+
87
+ expect(registry.getCustomModels()).toEqual(new Set(customModels));
88
+ expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
89
+ });
90
+ });
91
+
92
+ describe("getModelsByProvider", () => {
93
+ it("should return models for a specific provider", () => {
94
+ const registry = AiModelRegistry.create({});
95
+ const openaiModels = registry.getModelsByProvider("openai");
96
+
97
+ expect(openaiModels).toHaveLength(2); // gpt-4 and multi-model
98
+ expect(
99
+ openaiModels.every((model) => model.providers.includes("openai")),
100
+ ).toBe(true);
101
+ });
102
+
103
+ it("should return empty array for provider with no models", () => {
104
+ const registry = AiModelRegistry.create({});
105
+ const azureModels = registry.getModelsByProvider("azure");
106
+
107
+ expect(azureModels).toEqual([]);
108
+ });
109
+
110
+ it("should include custom models for the provider", () => {
111
+ const customModels = ["openai/custom-gpt"];
112
+ const registry = AiModelRegistry.create({ customModels });
113
+ const openaiModels = registry.getModelsByProvider("openai");
114
+
115
+ const customModel = openaiModels.find((model) => model.custom);
116
+ expect(customModel).toBeDefined();
117
+ expect(customModel?.name).toBe("openai/custom-gpt");
118
+ expect(customModel?.model).toBe("custom-gpt");
119
+ expect(customModel?.description).toBe("Custom model");
120
+ expect(customModel?.providers).toEqual(["openai"]);
121
+ expect(customModel?.roles).toEqual([]);
122
+ expect(customModel?.thinking).toBe(false);
123
+ });
124
+
125
+ it("should filter models based on displayed models", () => {
126
+ const displayedModels = ["openai/gpt-4"];
127
+ const registry = AiModelRegistry.create({ displayedModels });
128
+ const openaiModels = registry.getModelsByProvider("openai");
129
+
130
+ expect(openaiModels).toHaveLength(1);
131
+ expect(openaiModels[0].model).toBe("gpt-4");
132
+ });
133
+
134
+ it("should filter custom models based on displayed models", () => {
135
+ const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
136
+ const displayedModels = ["openai/custom-gpt"];
137
+ const registry = AiModelRegistry.create({
138
+ customModels,
139
+ displayedModels,
140
+ });
141
+
142
+ const openaiModels = registry.getModelsByProvider("openai");
143
+ const anthropicModels = registry.getModelsByProvider("anthropic");
144
+
145
+ expect(
146
+ openaiModels.some(
147
+ (model) => model.custom && model.model === "custom-gpt",
148
+ ),
149
+ ).toBe(true);
150
+ expect(
151
+ anthropicModels.some(
152
+ (model) => model.custom && model.model === "custom-claude",
153
+ ),
154
+ ).toBe(false);
155
+ });
156
+ });
157
+
158
+ describe("getGroupedModelsByProvider", () => {
159
+ it("should return all models grouped by provider", () => {
160
+ const registry = AiModelRegistry.create({});
161
+ const groupedModels = registry.getGroupedModelsByProvider();
162
+
163
+ expect(groupedModels.has("openai")).toBe(true);
164
+ expect(groupedModels.has("anthropic")).toBe(true);
165
+ expect(groupedModels.has("google")).toBe(true);
166
+
167
+ const openaiModels = groupedModels.get("openai") || [];
168
+ const anthropicModels = groupedModels.get("anthropic") || [];
169
+ const googleModels = groupedModels.get("google") || [];
170
+
171
+ expect(openaiModels.length).toEqual(2);
172
+ expect(anthropicModels.length).toEqual(2);
173
+ expect(googleModels.length).toEqual(1);
174
+ });
175
+
176
+ it("should include custom models in the grouped results", () => {
177
+ const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
178
+ const registry = AiModelRegistry.create({ customModels });
179
+ const groupedModels = registry.getGroupedModelsByProvider();
180
+
181
+ const openaiModels = groupedModels.get("openai") || [];
182
+ const anthropicModels = groupedModels.get("anthropic") || [];
183
+
184
+ expect(
185
+ openaiModels.some(
186
+ (model) => model.custom && model.model === "custom-gpt",
187
+ ),
188
+ ).toBe(true);
189
+ expect(
190
+ anthropicModels.some(
191
+ (model) => model.custom && model.model === "custom-claude",
192
+ ),
193
+ ).toBe(true);
194
+ });
195
+
196
+ it("should respect displayed models filter", () => {
197
+ const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
198
+ const registry = AiModelRegistry.create({ displayedModels });
199
+ const groupedModels = registry.getGroupedModelsByProvider();
200
+
201
+ const openaiModels = groupedModels.get("openai") || [];
202
+ const anthropicModels = groupedModels.get("anthropic") || [];
203
+ const googleModels = groupedModels.get("google") || [];
204
+
205
+ expect(openaiModels.length).toBe(1);
206
+ expect(openaiModels[0].model).toBe("gpt-4");
207
+ expect(anthropicModels.length).toBe(1);
208
+ expect(anthropicModels[0].model).toBe("claude-3-sonnet");
209
+ expect(googleModels.length).toBe(0);
210
+ });
211
+ });
212
+
213
+ describe("getCustomModels", () => {
214
+ it("should return empty set when no custom models", () => {
215
+ const registry = AiModelRegistry.create({});
216
+ expect(registry.getCustomModels()).toEqual(new Set());
217
+ });
218
+
219
+ it("should return set of custom model IDs", () => {
220
+ const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
221
+ const registry = AiModelRegistry.create({ customModels });
222
+ expect(registry.getCustomModels()).toEqual(new Set(customModels));
223
+ });
224
+ });
225
+
226
+ describe("getDisplayedModels", () => {
227
+ it("should return empty set when no displayed models specified", () => {
228
+ const registry = AiModelRegistry.create({});
229
+ expect(registry.getDisplayedModels()).toEqual(new Set());
230
+ });
231
+
232
+ it("should return set of displayed model IDs", () => {
233
+ const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
234
+ const registry = AiModelRegistry.create({ displayedModels });
235
+ expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
236
+ });
237
+ });
238
+
239
+ describe("edge cases", () => {
240
+ it("should handle empty arrays for custom and displayed models", () => {
241
+ const registry = AiModelRegistry.create({
242
+ customModels: [],
243
+ displayedModels: [],
244
+ });
245
+
246
+ expect(registry.getCustomModels()).toEqual(new Set());
247
+ expect(registry.getDisplayedModels()).toEqual(new Set());
248
+
249
+ // Should still load default models
250
+ const openaiModels = registry.getModelsByProvider("openai");
251
+ expect(openaiModels.length).toBeGreaterThan(0);
252
+ });
253
+
254
+ it("should handle models with multiple providers", () => {
255
+ const registry = AiModelRegistry.create({});
256
+
257
+ const openaiModels = registry.getModelsByProvider("openai");
258
+ const anthropicModels = registry.getModelsByProvider("anthropic");
259
+
260
+ // The multi-model should appear in both providers
261
+ const multiModelInOpenai = openaiModels.find(
262
+ (model) => model.model === "multi-model",
263
+ );
264
+ const multiModelInAnthropic = anthropicModels.find(
265
+ (model) => model.model === "multi-model",
266
+ );
267
+
268
+ expect(multiModelInOpenai).toBeDefined();
269
+ expect(multiModelInAnthropic).toBeDefined();
270
+ expect(multiModelInOpenai).toEqual(multiModelInAnthropic);
271
+ });
272
+
273
+ it("should handle displayed models filter with non-existent models", () => {
274
+ const displayedModels = [
275
+ "openai/non-existent-model",
276
+ "anthropic/claude-3-sonnet",
277
+ ];
278
+ const registry = AiModelRegistry.create({ displayedModels });
279
+
280
+ const openaiModels = registry.getModelsByProvider("openai");
281
+ const anthropicModels = registry.getModelsByProvider("anthropic");
282
+
283
+ // Should only show the existing model
284
+ expect(openaiModels.length).toBe(0);
285
+ expect(anthropicModels.length).toBe(1);
286
+ expect(anthropicModels[0].model).toBe("claude-3-sonnet");
287
+ });
288
+ });
289
+
290
+ describe("model properties", () => {
291
+ it("should ensure all models have required properties", () => {
292
+ const registry = AiModelRegistry.create({});
293
+ const groupedModels = registry.getGroupedModelsByProvider();
294
+
295
+ for (const [provider, models] of groupedModels.entries()) {
296
+ for (const model of models) {
297
+ expect(model).toHaveProperty("name");
298
+ expect(model).toHaveProperty("model");
299
+ expect(model).toHaveProperty("description");
300
+ expect(model).toHaveProperty("providers");
301
+ expect(model).toHaveProperty("roles");
302
+ expect(model).toHaveProperty("thinking");
303
+ expect(model).toHaveProperty("custom");
304
+
305
+ expect(typeof model.name).toBe("string");
306
+ expect(typeof model.model).toBe("string");
307
+ expect(typeof model.description).toBe("string");
308
+ expect(Array.isArray(model.providers)).toBe(true);
309
+ expect(Array.isArray(model.roles)).toBe(true);
310
+ expect(typeof model.thinking).toBe("boolean");
311
+ expect(typeof model.custom).toBe("boolean");
312
+
313
+ expect(model.providers).toContain(provider);
314
+ }
315
+ }
316
+ });
317
+
318
+ it("should ensure custom models have correct custom property", () => {
319
+ const customModels = ["openai/custom-gpt"];
320
+ const registry = AiModelRegistry.create({ customModels });
321
+ const openaiModels = registry.getModelsByProvider("openai");
322
+
323
+ const customModel = openaiModels.find((model) => model.custom);
324
+ const defaultModel = openaiModels.find((model) => !model.custom);
325
+
326
+ expect(customModel).toMatchInlineSnapshot(`
327
+ {
328
+ "custom": true,
329
+ "description": "Custom model",
330
+ "model": "custom-gpt",
331
+ "name": "openai/custom-gpt",
332
+ "providers": [
333
+ "openai",
334
+ ],
335
+ "roles": [],
336
+ "thinking": false,
337
+ }
338
+ `);
339
+ expect(defaultModel).toMatchInlineSnapshot(`
340
+ {
341
+ "custom": false,
342
+ "description": "OpenAI GPT-4 model",
343
+ "model": "gpt-4",
344
+ "name": "GPT-4",
345
+ "providers": [
346
+ "openai",
347
+ ],
348
+ "roles": [
349
+ "chat",
350
+ "edit",
351
+ ],
352
+ "thinking": false,
353
+ }
354
+ `);
355
+ });
356
+ });
357
+ });
@@ -1,7 +1,8 @@
1
1
  /* Copyright 2024 Marimo. All rights reserved. */
2
2
 
3
3
  import { describe, expect, it } from "vitest";
4
- import { AiModelId, type ProviderId, type ShortModelId } from "../ids";
4
+ import type { ProviderId } from "../ids";
5
+ import { AiModelId, type ShortModelId } from "../ids";
5
6
 
6
7
  describe("AiModelId", () => {
7
8
  describe("constructor", () => {
@@ -1,16 +1,17 @@
1
1
  /* Copyright 2024 Marimo. All rights reserved. */
2
2
 
3
- import type { TypedString } from "../typed";
3
+ import type { TypedString } from "@/utils/typed";
4
4
 
5
- /**
6
- * Supported providers by the marimo server.
7
- */
8
- export type ProviderId =
9
- | "openai"
10
- | "anthropic"
11
- | "google"
12
- | "ollama"
13
- | "bedrock";
5
+ export const PROVIDERS = [
6
+ "openai",
7
+ "anthropic",
8
+ "google",
9
+ "ollama",
10
+ "bedrock",
11
+ "deepseek",
12
+ "azure",
13
+ ] as const;
14
+ export type ProviderId = (typeof PROVIDERS)[number];
14
15
 
15
16
  export type ShortModelId = TypedString<"ShortModelId">;
16
17
 
@@ -57,5 +58,12 @@ function guessProviderId(id: string): ProviderId {
57
58
  if (id.startsWith("gemini") || id.startsWith("google")) {
58
59
  return "google";
59
60
  }
61
+ if (id.startsWith("deepseek")) {
62
+ return "deepseek";
63
+ }
60
64
  return "ollama";
61
65
  }
66
+
67
+ export function isKnownAIProvider(providerId: ProviderId): boolean {
68
+ return PROVIDERS.includes(providerId);
69
+ }