@marimo-team/frontend 0.14.18-dev24 → 0.14.18-dev26
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/assets/{ConnectedDataExplorerComponent-Dbu2xXc5.js → ConnectedDataExplorerComponent-CPwF1ckF.js} +1 -1
- package/dist/assets/{ImageComparisonComponent-BkZIjGHA.js → ImageComparisonComponent-BuFzvNj5.js} +1 -1
- package/dist/assets/{VegaLite-Ca7AXGyA.js → VegaLite-DI-xPLgU.js} +1 -1
- package/dist/assets/{_baseEach-H5Qk1V2B.js → _baseEach-BvIVAeE-.js} +1 -1
- package/dist/assets/_baseMap-DrbRr9s_.js +1 -0
- package/dist/assets/{_baseUniq-UAmxGez2.js → _baseUniq-Cz-AR5wG.js} +1 -1
- package/dist/assets/{_createAggregator-H-t5qYSG.js → _createAggregator-fFJIgkzg.js} +1 -1
- package/dist/assets/{any-language-editor-DUt-pIdy.js → any-language-editor-yovSkXXH.js} +1 -1
- package/dist/assets/{architectureDiagram-SUXI7LT5-DWP05q8_.js → architectureDiagram-SUXI7LT5-vhSKxFkz.js} +1 -1
- package/dist/assets/{blockDiagram-6J76NXCF-DB57b3LI.js → blockDiagram-6J76NXCF-D8FkmW5U.js} +1 -1
- package/dist/assets/{c4Diagram-6F6E4RAY-ByeAGPmY.js → c4Diagram-6F6E4RAY-DMz49WYv.js} +1 -1
- package/dist/assets/channel-C5D8_wdC.js +1 -0
- package/dist/assets/{chunk-353BL4L5-ngKqumF3.js → chunk-353BL4L5-C_h8QREg.js} +1 -1
- package/dist/assets/{chunk-67H74DCK-DTbVgB4A.js → chunk-67H74DCK-vQz9ujXD.js} +1 -1
- package/dist/assets/{chunk-AACKK3MU-t9Kf_p_V.js → chunk-AACKK3MU-CUxeSfqv.js} +1 -1
- package/dist/assets/{chunk-BFAMUDN2-B0dcgdMs.js → chunk-BFAMUDN2-CLGtBusC.js} +1 -1
- package/dist/assets/{chunk-E2GYISFI-C7n5SQvq.js → chunk-E2GYISFI-DGILtsZz.js} +1 -1
- package/dist/assets/{chunk-OW32GOEJ-DEvBFIh8.js → chunk-OW32GOEJ-D-HRiiaD.js} +1 -1
- package/dist/assets/{chunk-SKB7J2MH-DX71TF_n.js → chunk-SKB7J2MH-xs1kEiQZ.js} +1 -1
- package/dist/assets/{chunk-SZ463SBG-CO9fYjVe.js → chunk-SZ463SBG-DKY2KPhj.js} +1 -1
- package/dist/assets/{circle-play-B-ZWQLkW.js → circle-play-ub1A5yYA.js} +1 -1
- package/dist/assets/classDiagram-M3E45YP4-Bzxtjbca.js +1 -0
- package/dist/assets/classDiagram-v2-YAWTLIQI-Bzxtjbca.js +1 -0
- package/dist/assets/clone-DngVHtKT.js +1 -0
- package/dist/assets/{compile-lv6gzALt.js → compile-Eo2s_HHz.js} +1 -1
- package/dist/assets/{dagre-JOIXM2OF-B7ZtMRIJ.js → dagre-JOIXM2OF-BNvsGquX.js} +1 -1
- package/dist/assets/{data-grid-overlay-editor-DTAq6v9P.js → data-grid-overlay-editor-Vk_PgvbS.js} +1 -1
- package/dist/assets/{diagram-5UYTHUR4-CboeMF3G.js → diagram-5UYTHUR4-6gGyo4aQ.js} +1 -1
- package/dist/assets/{diagram-VMROVX33-BcnQcTFp.js → diagram-VMROVX33-Cr9VHqxg.js} +1 -1
- package/dist/assets/{diagram-ZTM2IBQH-CaFT3B8g.js → diagram-ZTM2IBQH-BFP88P3p.js} +1 -1
- package/dist/assets/{edit-page-CvWZ8wVZ.js → edit-page-Vrbmp55N.js} +53 -53
- package/dist/assets/{erDiagram-3M52JZNH-BFPc1-ng.js → erDiagram-3M52JZNH-C0U8FCGV.js} +1 -1
- package/dist/assets/{flowDiagram-KYDEHFYC-CAZBj4fC.js → flowDiagram-KYDEHFYC-3GT8PdlC.js} +1 -1
- package/dist/assets/{ganttDiagram-EK5VF46D-CUxffTMj.js → ganttDiagram-EK5VF46D-CrQQiHRJ.js} +1 -1
- package/dist/assets/{gitGraphDiagram-GW3U2K7C-Tx-ZfecV.js → gitGraphDiagram-GW3U2K7C-BuMcEd-g.js} +1 -1
- package/dist/assets/{glide-data-editor-BP8l5q3f.js → glide-data-editor-BB_hULql.js} +11 -11
- package/dist/assets/{graph-Bf6eoL3M.js → graph-B67OAr_1.js} +1 -1
- package/dist/assets/{home-page--ixzxF-w.js → home-page-D4saUb6P.js} +1 -1
- package/dist/assets/index-B3Q6PaCG.css +1 -0
- package/dist/assets/{index-g-JZ5Z-_.js → index-BCsh52ZA.js} +1 -1
- package/dist/assets/{index-DJ1FNShi.js → index-BUmTew_7.js} +1 -1
- package/dist/assets/{index-CmJm6kj5.js → index-BbaurO1n.js} +1 -1
- package/dist/assets/{index-D1vinr8C.js → index-BcNCs1Ec.js} +1 -1
- package/dist/assets/{index-Cv01rAR1.js → index-BcOm_dWX.js} +1 -1
- package/dist/assets/{index-dvYU8Vev.js → index-Bg-oxR2t.js} +94 -94
- package/dist/assets/{index-DleA5-JR.js → index-CT6WEyVV.js} +1 -1
- package/dist/assets/{index-kaWF-HKt.js → index-CWnyv0Mb.js} +1 -1
- package/dist/assets/{index-BMLuYuQT.js → index-CYZYUN8a.js} +1 -1
- package/dist/assets/{index-DupjaVjo.js → index-CfQw831p.js} +1 -1
- package/dist/assets/{index-C4iPAevr.js → index-CydDDozP.js} +1 -1
- package/dist/assets/{index-q4AwmgtF.js → index-D-aN-b2W.js} +1 -1
- package/dist/assets/{index-Uh_QGNQm.js → index-DFBulMSA.js} +1 -1
- package/dist/assets/{index-U0KSor5u.js → index-DeKrdrEq.js} +1 -1
- package/dist/assets/{index-Q73xW9dd.js → index-DqZ9_rHU.js} +1 -1
- package/dist/assets/{index-mvfeSH0B.js → index-Dtw1TXyN.js} +1 -1
- package/dist/assets/{index-B39Q37Jh.js → index-FrFiygIp.js} +1 -1
- package/dist/assets/{index-DhU7edG1.js → index-IDJxuXSR.js} +1 -1
- package/dist/assets/{index-CM8rF_ge.js → index-LcqVorg9.js} +1 -1
- package/dist/assets/infoDiagram-LHK5PUON-Bc9fde05.js +2 -0
- package/dist/assets/{journeyDiagram-EWQZEKCU-DBorgqR8.js → journeyDiagram-EWQZEKCU-CQkZyIS5.js} +1 -1
- package/dist/assets/{kanban-definition-ZSS6B67P-rOei7rdW.js → kanban-definition-ZSS6B67P-CEV1b6rk.js} +1 -1
- package/dist/assets/{layout-D5U1vfFv.js → layout-DAmvOShj.js} +1 -1
- package/dist/assets/{linear-BPaO6rYC.js → linear-ktW_rVCS.js} +1 -1
- package/dist/assets/links-C8gYI3jG.js +18 -0
- package/dist/assets/{mermaid-DuiiiGkf.js → mermaid-Dn9IUXi8.js} +4 -4
- package/dist/assets/{min-GM8d8p3k.js → min-BmsA5VGH.js} +1 -1
- package/dist/assets/{mindmap-definition-6CBA2TL7-DFNMYbU5.js → mindmap-definition-6CBA2TL7-B0yAK13C.js} +1 -1
- package/dist/assets/{number-overlay-editor-CcD_5O9P.js → number-overlay-editor-DbvoJDEj.js} +1 -1
- package/dist/assets/{pieDiagram-NIOCPIFQ-BAnWQSTZ.js → pieDiagram-NIOCPIFQ-BaPxK1u0.js} +1 -1
- package/dist/assets/{quadrantDiagram-2OG54O6I-CLzghZ4d.js → quadrantDiagram-2OG54O6I-PGZoquB4.js} +1 -1
- package/dist/assets/{react-plotly-qIomJONw.js → react-plotly-CzEXKtwr.js} +1 -1
- package/dist/assets/{requirementDiagram-QOLK2EJ7-ClKsOuLQ.js → requirementDiagram-QOLK2EJ7-BdGXi15P.js} +1 -1
- package/dist/assets/{run-page-B9ntqQci.js → run-page-BnZ5haku.js} +1 -1
- package/dist/assets/{sankeyDiagram-4UZDY2LN-CMUyusMd.js → sankeyDiagram-4UZDY2LN-Dk80vg6t.js} +1 -1
- package/dist/assets/{sequenceDiagram-SKLFT4DO-BUL7OokF.js → sequenceDiagram-SKLFT4DO-CqLVyCT3.js} +1 -1
- package/dist/assets/{slides-component-DxNxYl9E.js → slides-component-D_39bH-j.js} +1 -1
- package/dist/assets/{sortBy-Bo672N53.js → sortBy-CI9MGbno.js} +1 -1
- package/dist/assets/{stateDiagram-MI5ZYTHO-CK5D03xc.js → stateDiagram-MI5ZYTHO-C3tpbdm0.js} +1 -1
- package/dist/assets/stateDiagram-v2-5AN5P6BG-BXclnbG1.js +1 -0
- package/dist/assets/{storage-DCGJ86_2.js → storage-CxOI7c0Z.js} +3 -3
- package/dist/assets/{terminal-C9ZYVCQk.js → terminal-9Nyd6Apx.js} +1 -1
- package/dist/assets/{time-DUBsogDP.js → time-Do4-nT0Q.js} +1 -1
- package/dist/assets/{timeline-definition-MYPXXCX6-DlOzuKHL.js → timeline-definition-MYPXXCX6-CtllVQGu.js} +1 -1
- package/dist/assets/{tracing-DkB9iogQ.js → tracing-BWy-UTCz.js} +2 -2
- package/dist/assets/{trash-CuNNSzF1.js → trash-BPgHERlA.js} +1 -1
- package/dist/assets/{treemap-75Q7IDZK-DUsN_Z4F.js → treemap-75Q7IDZK-BXkwu4nl.js} +1 -1
- package/dist/assets/{vega-component-DByprFwW.js → vega-component-COxYVQBE.js} +1 -1
- package/dist/assets/{xychartDiagram-H2YORKM3-BaUqYAb6.js → xychartDiagram-H2YORKM3-DfzCBqlv.js} +1 -1
- package/dist/index.html +2 -2
- package/package.json +1 -1
- package/src/components/ai/ai-model-dropdown.tsx +288 -0
- package/src/components/ai/ai-provider-icon.tsx +7 -4
- package/src/components/app-config/ai-config.tsx +100 -76
- package/src/components/app-config/app-config-button.tsx +10 -1
- package/src/components/app-config/constants.ts +0 -34
- package/src/components/app-config/incorrect-model-id.tsx +4 -2
- package/src/components/app-config/user-config-form.tsx +12 -5
- package/src/components/chat/chat-panel.tsx +12 -26
- package/src/components/slides/slides.css +0 -1
- package/src/core/ai/__tests__/model-registry.test.ts +357 -0
- package/src/{utils/ai → core/ai/ids}/__tests__/ids.test.ts +2 -1
- package/src/{utils/ai → core/ai/ids}/ids.ts +18 -10
- package/src/core/ai/model-registry.ts +164 -0
- package/src/core/cells/cells.ts +1 -1
- package/src/core/cells/effects.ts +1 -1
- package/src/plugins/layout/carousel/CarouselPlugin.tsx +0 -2
- package/src/utils/__tests__/multi-map.test.ts +295 -0
- package/src/utils/multi-map.ts +71 -0
- package/dist/assets/_baseMap-lEtQfieX.js +0 -1
- package/dist/assets/channel-CJdgPvjM.js +0 -1
- package/dist/assets/classDiagram-M3E45YP4-Tb8oQ03C.js +0 -1
- package/dist/assets/classDiagram-v2-YAWTLIQI-Tb8oQ03C.js +0 -1
- package/dist/assets/clone-DygFoMzB.js +0 -1
- package/dist/assets/index-BlxPam9h.css +0 -1
- package/dist/assets/infoDiagram-LHK5PUON-Cz497oaY.js +0 -2
- package/dist/assets/links-Cxxlu7np.js +0 -17
- package/dist/assets/stateDiagram-v2-5AN5P6BG-Deqw4NDh.js +0 -1
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
/* Copyright 2024 Marimo. All rights reserved. */
|
|
2
2
|
import React from "react";
|
|
3
|
+
import { AiModelId, type QualifiedModelId } from "@/core/ai/ids/ids";
|
|
3
4
|
import { Banner } from "@/plugins/impl/common/error-banner";
|
|
4
|
-
import { AiModelId, type QualifiedModelId } from "@/utils/ai/ids";
|
|
5
5
|
|
|
6
6
|
interface IncorrectModelIdProps {
|
|
7
7
|
value: string | null | undefined;
|
|
8
|
+
includeSuggestion?: boolean;
|
|
8
9
|
}
|
|
9
10
|
|
|
10
11
|
export const IncorrectModelId: React.FC<IncorrectModelIdProps> = ({
|
|
11
12
|
value,
|
|
13
|
+
includeSuggestion = true,
|
|
12
14
|
}) => {
|
|
13
15
|
if (!value) {
|
|
14
16
|
return null;
|
|
@@ -31,7 +33,7 @@ export const IncorrectModelId: React.FC<IncorrectModelIdProps> = ({
|
|
|
31
33
|
provider.
|
|
32
34
|
</span>
|
|
33
35
|
<br />
|
|
34
|
-
{suggestion && (
|
|
36
|
+
{includeSuggestion && suggestion && (
|
|
35
37
|
<span>
|
|
36
38
|
Did you mean <code className="font-bold">{suggestion}</code>?
|
|
37
39
|
</span>
|
|
@@ -12,7 +12,7 @@ import {
|
|
|
12
12
|
MonitorIcon,
|
|
13
13
|
PackageIcon,
|
|
14
14
|
} from "lucide-react";
|
|
15
|
-
import React, { useRef } from "react";
|
|
15
|
+
import React, { useId, useRef } from "react";
|
|
16
16
|
import { useForm } from "react-hook-form";
|
|
17
17
|
import { Button } from "@/components/ui/button";
|
|
18
18
|
import { Checkbox } from "@/components/ui/checkbox";
|
|
@@ -130,6 +130,8 @@ export const UserConfigForm: React.FC = () => {
|
|
|
130
130
|
};
|
|
131
131
|
|
|
132
132
|
const isWasmRuntime = isWasm();
|
|
133
|
+
const htmlCheckboxId = useId();
|
|
134
|
+
const ipynbCheckboxId = useId();
|
|
133
135
|
|
|
134
136
|
const renderBody = () => {
|
|
135
137
|
switch (activeCategory) {
|
|
@@ -168,6 +170,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
168
170
|
<FormLabel>Autosave delay (seconds)</FormLabel>
|
|
169
171
|
<FormControl>
|
|
170
172
|
<NumberField
|
|
173
|
+
aria-label="Autosave delay"
|
|
171
174
|
data-testid="autosave-delay-input"
|
|
172
175
|
className="m-0 w-24"
|
|
173
176
|
isDisabled={
|
|
@@ -205,7 +208,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
205
208
|
<div className="flex gap-4">
|
|
206
209
|
<div className="flex items-center space-x-2">
|
|
207
210
|
<Checkbox
|
|
208
|
-
id=
|
|
211
|
+
id={htmlCheckboxId}
|
|
209
212
|
checked={
|
|
210
213
|
Array.isArray(field.value) &&
|
|
211
214
|
field.value.includes("html")
|
|
@@ -219,11 +222,11 @@ export const UserConfigForm: React.FC = () => {
|
|
|
219
222
|
);
|
|
220
223
|
}}
|
|
221
224
|
/>
|
|
222
|
-
<FormLabel htmlFor=
|
|
225
|
+
<FormLabel htmlFor={htmlCheckboxId}>HTML</FormLabel>
|
|
223
226
|
</div>
|
|
224
227
|
<div className="flex items-center space-x-2">
|
|
225
228
|
<Checkbox
|
|
226
|
-
id=
|
|
229
|
+
id={ipynbCheckboxId}
|
|
227
230
|
checked={
|
|
228
231
|
Array.isArray(field.value) &&
|
|
229
232
|
field.value.includes("ipynb")
|
|
@@ -237,7 +240,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
237
240
|
);
|
|
238
241
|
}}
|
|
239
242
|
/>
|
|
240
|
-
<FormLabel htmlFor=
|
|
243
|
+
<FormLabel htmlFor={ipynbCheckboxId}>
|
|
241
244
|
IPYNB
|
|
242
245
|
</FormLabel>
|
|
243
246
|
</div>
|
|
@@ -295,6 +298,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
295
298
|
<FormLabel>Line length</FormLabel>
|
|
296
299
|
<FormControl>
|
|
297
300
|
<NumberField
|
|
301
|
+
aria-label="Line length"
|
|
298
302
|
data-testid="line-length-input"
|
|
299
303
|
className="m-0 w-24"
|
|
300
304
|
{...field}
|
|
@@ -716,6 +720,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
716
720
|
<FormControl>
|
|
717
721
|
<span className="inline-flex mr-2">
|
|
718
722
|
<NumberField
|
|
723
|
+
aria-label="Code editor font size"
|
|
719
724
|
data-testid="code-editor-font-size-input"
|
|
720
725
|
className="m-0 w-24"
|
|
721
726
|
{...field}
|
|
@@ -848,6 +853,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
848
853
|
<FormLabel>Default table page size</FormLabel>
|
|
849
854
|
<FormControl>
|
|
850
855
|
<NumberField
|
|
856
|
+
aria-label="Default table page size"
|
|
851
857
|
data-testid="default-table-page-size-input"
|
|
852
858
|
className="m-0 w-24"
|
|
853
859
|
{...field}
|
|
@@ -884,6 +890,7 @@ export const UserConfigForm: React.FC = () => {
|
|
|
884
890
|
<FormLabel>Default table max columns</FormLabel>
|
|
885
891
|
<FormControl>
|
|
886
892
|
<NumberField
|
|
893
|
+
aria-label="Default table max columns"
|
|
887
894
|
data-testid="default-table-max-columns-input"
|
|
888
895
|
className="m-0 w-24"
|
|
889
896
|
{...field}
|
|
@@ -36,6 +36,7 @@ import {
|
|
|
36
36
|
SelectValue,
|
|
37
37
|
} from "@/components/ui/select";
|
|
38
38
|
import { addMessageToChat } from "@/core/ai/chat-utils";
|
|
39
|
+
import type { QualifiedModelId } from "@/core/ai/ids/ids";
|
|
39
40
|
import {
|
|
40
41
|
activeChatAtom,
|
|
41
42
|
type Chat,
|
|
@@ -54,7 +55,7 @@ import { cn } from "@/utils/cn";
|
|
|
54
55
|
import { timeAgo } from "@/utils/dates";
|
|
55
56
|
import { Logger } from "@/utils/Logger";
|
|
56
57
|
import { generateUUID } from "@/utils/uuid";
|
|
57
|
-
import {
|
|
58
|
+
import { AIModelDropdown } from "../ai/ai-model-dropdown";
|
|
58
59
|
import { useOpenSettingsToTab } from "../app-config/state";
|
|
59
60
|
import { PromptInput } from "../editor/ai/add-cell-with-ai";
|
|
60
61
|
import { getAICompletionBody } from "../editor/ai/completion-utils";
|
|
@@ -276,7 +277,7 @@ const ChatInputFooter: React.FC<ChatInputFooterProps> = memo(
|
|
|
276
277
|
saveConfig(newConfig);
|
|
277
278
|
};
|
|
278
279
|
|
|
279
|
-
const handleModelChange = async (newModel:
|
|
280
|
+
const handleModelChange = async (newModel: QualifiedModelId) => {
|
|
280
281
|
const newConfig: UserConfig = {
|
|
281
282
|
...userConfig,
|
|
282
283
|
ai: {
|
|
@@ -324,30 +325,15 @@ const ChatInputFooter: React.FC<ChatInputFooterProps> = memo(
|
|
|
324
325
|
</SelectContent>
|
|
325
326
|
</Select>
|
|
326
327
|
</FeatureFlagged>
|
|
327
|
-
<
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
<SelectItem
|
|
337
|
-
key={currentModel}
|
|
338
|
-
value={currentModel}
|
|
339
|
-
className="text-sm"
|
|
340
|
-
>
|
|
341
|
-
{currentModel}
|
|
342
|
-
</SelectItem>
|
|
343
|
-
)}
|
|
344
|
-
{KNOWN_AI_MODELS.map((model) => (
|
|
345
|
-
<SelectItem key={model} value={model} className="text-sm">
|
|
346
|
-
{model}
|
|
347
|
-
</SelectItem>
|
|
348
|
-
))}
|
|
349
|
-
</SelectContent>
|
|
350
|
-
</Select>
|
|
328
|
+
<AIModelDropdown
|
|
329
|
+
value={currentModel}
|
|
330
|
+
placeholder="Model"
|
|
331
|
+
onSelect={handleModelChange}
|
|
332
|
+
triggerClassName="h-6 text-xs shadow-none! ring-0! bg-muted hover:bg-muted/30 rounded-sm"
|
|
333
|
+
iconSize="small"
|
|
334
|
+
showAddCustomModelDocs={true}
|
|
335
|
+
forRole="chat"
|
|
336
|
+
/>
|
|
351
337
|
</div>
|
|
352
338
|
<Button
|
|
353
339
|
variant="ghost"
|
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
/* Copyright 2024 Marimo. All rights reserved. */
|
|
2
|
+
import { beforeEach, describe, expect, it, vi } from "vitest";
|
|
3
|
+
|
|
4
|
+
// Mock the models.json import
|
|
5
|
+
vi.mock("@marimo-team/llm-info/models.json", () => {
|
|
6
|
+
const models: AiModel[] = [
|
|
7
|
+
{
|
|
8
|
+
name: "GPT-4",
|
|
9
|
+
model: "gpt-4",
|
|
10
|
+
description: "OpenAI GPT-4 model",
|
|
11
|
+
providers: ["openai"],
|
|
12
|
+
roles: ["chat", "edit"],
|
|
13
|
+
thinking: false,
|
|
14
|
+
},
|
|
15
|
+
{
|
|
16
|
+
name: "Claude 3",
|
|
17
|
+
model: "claude-3-sonnet",
|
|
18
|
+
description: "Anthropic Claude 3 Sonnet",
|
|
19
|
+
providers: ["anthropic"],
|
|
20
|
+
roles: ["chat", "edit"],
|
|
21
|
+
thinking: false,
|
|
22
|
+
},
|
|
23
|
+
{
|
|
24
|
+
name: "Gemini Pro",
|
|
25
|
+
model: "gemini-pro",
|
|
26
|
+
description: "Google Gemini Pro model",
|
|
27
|
+
providers: ["google"],
|
|
28
|
+
roles: ["chat", "edit"],
|
|
29
|
+
thinking: false,
|
|
30
|
+
},
|
|
31
|
+
{
|
|
32
|
+
name: "Multi Provider Model",
|
|
33
|
+
model: "multi-model",
|
|
34
|
+
description: "Model available on multiple providers",
|
|
35
|
+
providers: ["openai", "anthropic"],
|
|
36
|
+
roles: ["chat", "edit"],
|
|
37
|
+
thinking: false,
|
|
38
|
+
},
|
|
39
|
+
];
|
|
40
|
+
|
|
41
|
+
return {
|
|
42
|
+
models: models,
|
|
43
|
+
};
|
|
44
|
+
});
|
|
45
|
+
|
|
46
|
+
import type { AiModel } from "@marimo-team/llm-info";
|
|
47
|
+
import { AiModelRegistry } from "../model-registry";
|
|
48
|
+
|
|
49
|
+
describe("AiModelRegistry", () => {
|
|
50
|
+
beforeEach(() => {
|
|
51
|
+
vi.clearAllMocks();
|
|
52
|
+
});
|
|
53
|
+
|
|
54
|
+
describe("create", () => {
|
|
55
|
+
it("should create registry with no custom or displayed models", () => {
|
|
56
|
+
const registry = AiModelRegistry.create({});
|
|
57
|
+
|
|
58
|
+
expect(registry).toBeInstanceOf(AiModelRegistry);
|
|
59
|
+
expect(registry.getCustomModels()).toEqual(new Set());
|
|
60
|
+
expect(registry.getDisplayedModels()).toEqual(new Set());
|
|
61
|
+
});
|
|
62
|
+
|
|
63
|
+
it("should create registry with custom models", () => {
|
|
64
|
+
const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
|
|
65
|
+
const registry = AiModelRegistry.create({ customModels });
|
|
66
|
+
|
|
67
|
+
expect(registry.getCustomModels()).toEqual(new Set(customModels));
|
|
68
|
+
expect(registry.getDisplayedModels()).toEqual(new Set());
|
|
69
|
+
});
|
|
70
|
+
|
|
71
|
+
it("should create registry with displayed models", () => {
|
|
72
|
+
const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
|
|
73
|
+
const registry = AiModelRegistry.create({ displayedModels });
|
|
74
|
+
|
|
75
|
+
expect(registry.getCustomModels()).toEqual(new Set());
|
|
76
|
+
expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
|
|
77
|
+
});
|
|
78
|
+
|
|
79
|
+
it("should create registry with both custom and displayed models", () => {
|
|
80
|
+
const customModels = ["openai/custom-gpt"];
|
|
81
|
+
const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
|
|
82
|
+
const registry = AiModelRegistry.create({
|
|
83
|
+
customModels,
|
|
84
|
+
displayedModels,
|
|
85
|
+
});
|
|
86
|
+
|
|
87
|
+
expect(registry.getCustomModels()).toEqual(new Set(customModels));
|
|
88
|
+
expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
|
|
89
|
+
});
|
|
90
|
+
});
|
|
91
|
+
|
|
92
|
+
describe("getModelsByProvider", () => {
|
|
93
|
+
it("should return models for a specific provider", () => {
|
|
94
|
+
const registry = AiModelRegistry.create({});
|
|
95
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
96
|
+
|
|
97
|
+
expect(openaiModels).toHaveLength(2); // gpt-4 and multi-model
|
|
98
|
+
expect(
|
|
99
|
+
openaiModels.every((model) => model.providers.includes("openai")),
|
|
100
|
+
).toBe(true);
|
|
101
|
+
});
|
|
102
|
+
|
|
103
|
+
it("should return empty array for provider with no models", () => {
|
|
104
|
+
const registry = AiModelRegistry.create({});
|
|
105
|
+
const azureModels = registry.getModelsByProvider("azure");
|
|
106
|
+
|
|
107
|
+
expect(azureModels).toEqual([]);
|
|
108
|
+
});
|
|
109
|
+
|
|
110
|
+
it("should include custom models for the provider", () => {
|
|
111
|
+
const customModels = ["openai/custom-gpt"];
|
|
112
|
+
const registry = AiModelRegistry.create({ customModels });
|
|
113
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
114
|
+
|
|
115
|
+
const customModel = openaiModels.find((model) => model.custom);
|
|
116
|
+
expect(customModel).toBeDefined();
|
|
117
|
+
expect(customModel?.name).toBe("openai/custom-gpt");
|
|
118
|
+
expect(customModel?.model).toBe("custom-gpt");
|
|
119
|
+
expect(customModel?.description).toBe("Custom model");
|
|
120
|
+
expect(customModel?.providers).toEqual(["openai"]);
|
|
121
|
+
expect(customModel?.roles).toEqual([]);
|
|
122
|
+
expect(customModel?.thinking).toBe(false);
|
|
123
|
+
});
|
|
124
|
+
|
|
125
|
+
it("should filter models based on displayed models", () => {
|
|
126
|
+
const displayedModels = ["openai/gpt-4"];
|
|
127
|
+
const registry = AiModelRegistry.create({ displayedModels });
|
|
128
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
129
|
+
|
|
130
|
+
expect(openaiModels).toHaveLength(1);
|
|
131
|
+
expect(openaiModels[0].model).toBe("gpt-4");
|
|
132
|
+
});
|
|
133
|
+
|
|
134
|
+
it("should filter custom models based on displayed models", () => {
|
|
135
|
+
const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
|
|
136
|
+
const displayedModels = ["openai/custom-gpt"];
|
|
137
|
+
const registry = AiModelRegistry.create({
|
|
138
|
+
customModels,
|
|
139
|
+
displayedModels,
|
|
140
|
+
});
|
|
141
|
+
|
|
142
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
143
|
+
const anthropicModels = registry.getModelsByProvider("anthropic");
|
|
144
|
+
|
|
145
|
+
expect(
|
|
146
|
+
openaiModels.some(
|
|
147
|
+
(model) => model.custom && model.model === "custom-gpt",
|
|
148
|
+
),
|
|
149
|
+
).toBe(true);
|
|
150
|
+
expect(
|
|
151
|
+
anthropicModels.some(
|
|
152
|
+
(model) => model.custom && model.model === "custom-claude",
|
|
153
|
+
),
|
|
154
|
+
).toBe(false);
|
|
155
|
+
});
|
|
156
|
+
});
|
|
157
|
+
|
|
158
|
+
describe("getGroupedModelsByProvider", () => {
|
|
159
|
+
it("should return all models grouped by provider", () => {
|
|
160
|
+
const registry = AiModelRegistry.create({});
|
|
161
|
+
const groupedModels = registry.getGroupedModelsByProvider();
|
|
162
|
+
|
|
163
|
+
expect(groupedModels.has("openai")).toBe(true);
|
|
164
|
+
expect(groupedModels.has("anthropic")).toBe(true);
|
|
165
|
+
expect(groupedModels.has("google")).toBe(true);
|
|
166
|
+
|
|
167
|
+
const openaiModels = groupedModels.get("openai") || [];
|
|
168
|
+
const anthropicModels = groupedModels.get("anthropic") || [];
|
|
169
|
+
const googleModels = groupedModels.get("google") || [];
|
|
170
|
+
|
|
171
|
+
expect(openaiModels.length).toEqual(2);
|
|
172
|
+
expect(anthropicModels.length).toEqual(2);
|
|
173
|
+
expect(googleModels.length).toEqual(1);
|
|
174
|
+
});
|
|
175
|
+
|
|
176
|
+
it("should include custom models in the grouped results", () => {
|
|
177
|
+
const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
|
|
178
|
+
const registry = AiModelRegistry.create({ customModels });
|
|
179
|
+
const groupedModels = registry.getGroupedModelsByProvider();
|
|
180
|
+
|
|
181
|
+
const openaiModels = groupedModels.get("openai") || [];
|
|
182
|
+
const anthropicModels = groupedModels.get("anthropic") || [];
|
|
183
|
+
|
|
184
|
+
expect(
|
|
185
|
+
openaiModels.some(
|
|
186
|
+
(model) => model.custom && model.model === "custom-gpt",
|
|
187
|
+
),
|
|
188
|
+
).toBe(true);
|
|
189
|
+
expect(
|
|
190
|
+
anthropicModels.some(
|
|
191
|
+
(model) => model.custom && model.model === "custom-claude",
|
|
192
|
+
),
|
|
193
|
+
).toBe(true);
|
|
194
|
+
});
|
|
195
|
+
|
|
196
|
+
it("should respect displayed models filter", () => {
|
|
197
|
+
const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
|
|
198
|
+
const registry = AiModelRegistry.create({ displayedModels });
|
|
199
|
+
const groupedModels = registry.getGroupedModelsByProvider();
|
|
200
|
+
|
|
201
|
+
const openaiModels = groupedModels.get("openai") || [];
|
|
202
|
+
const anthropicModels = groupedModels.get("anthropic") || [];
|
|
203
|
+
const googleModels = groupedModels.get("google") || [];
|
|
204
|
+
|
|
205
|
+
expect(openaiModels.length).toBe(1);
|
|
206
|
+
expect(openaiModels[0].model).toBe("gpt-4");
|
|
207
|
+
expect(anthropicModels.length).toBe(1);
|
|
208
|
+
expect(anthropicModels[0].model).toBe("claude-3-sonnet");
|
|
209
|
+
expect(googleModels.length).toBe(0);
|
|
210
|
+
});
|
|
211
|
+
});
|
|
212
|
+
|
|
213
|
+
describe("getCustomModels", () => {
|
|
214
|
+
it("should return empty set when no custom models", () => {
|
|
215
|
+
const registry = AiModelRegistry.create({});
|
|
216
|
+
expect(registry.getCustomModels()).toEqual(new Set());
|
|
217
|
+
});
|
|
218
|
+
|
|
219
|
+
it("should return set of custom model IDs", () => {
|
|
220
|
+
const customModels = ["openai/custom-gpt", "anthropic/custom-claude"];
|
|
221
|
+
const registry = AiModelRegistry.create({ customModels });
|
|
222
|
+
expect(registry.getCustomModels()).toEqual(new Set(customModels));
|
|
223
|
+
});
|
|
224
|
+
});
|
|
225
|
+
|
|
226
|
+
describe("getDisplayedModels", () => {
|
|
227
|
+
it("should return empty set when no displayed models specified", () => {
|
|
228
|
+
const registry = AiModelRegistry.create({});
|
|
229
|
+
expect(registry.getDisplayedModels()).toEqual(new Set());
|
|
230
|
+
});
|
|
231
|
+
|
|
232
|
+
it("should return set of displayed model IDs", () => {
|
|
233
|
+
const displayedModels = ["openai/gpt-4", "anthropic/claude-3-sonnet"];
|
|
234
|
+
const registry = AiModelRegistry.create({ displayedModels });
|
|
235
|
+
expect(registry.getDisplayedModels()).toEqual(new Set(displayedModels));
|
|
236
|
+
});
|
|
237
|
+
});
|
|
238
|
+
|
|
239
|
+
describe("edge cases", () => {
|
|
240
|
+
it("should handle empty arrays for custom and displayed models", () => {
|
|
241
|
+
const registry = AiModelRegistry.create({
|
|
242
|
+
customModels: [],
|
|
243
|
+
displayedModels: [],
|
|
244
|
+
});
|
|
245
|
+
|
|
246
|
+
expect(registry.getCustomModels()).toEqual(new Set());
|
|
247
|
+
expect(registry.getDisplayedModels()).toEqual(new Set());
|
|
248
|
+
|
|
249
|
+
// Should still load default models
|
|
250
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
251
|
+
expect(openaiModels.length).toBeGreaterThan(0);
|
|
252
|
+
});
|
|
253
|
+
|
|
254
|
+
it("should handle models with multiple providers", () => {
|
|
255
|
+
const registry = AiModelRegistry.create({});
|
|
256
|
+
|
|
257
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
258
|
+
const anthropicModels = registry.getModelsByProvider("anthropic");
|
|
259
|
+
|
|
260
|
+
// The multi-model should appear in both providers
|
|
261
|
+
const multiModelInOpenai = openaiModels.find(
|
|
262
|
+
(model) => model.model === "multi-model",
|
|
263
|
+
);
|
|
264
|
+
const multiModelInAnthropic = anthropicModels.find(
|
|
265
|
+
(model) => model.model === "multi-model",
|
|
266
|
+
);
|
|
267
|
+
|
|
268
|
+
expect(multiModelInOpenai).toBeDefined();
|
|
269
|
+
expect(multiModelInAnthropic).toBeDefined();
|
|
270
|
+
expect(multiModelInOpenai).toEqual(multiModelInAnthropic);
|
|
271
|
+
});
|
|
272
|
+
|
|
273
|
+
it("should handle displayed models filter with non-existent models", () => {
|
|
274
|
+
const displayedModels = [
|
|
275
|
+
"openai/non-existent-model",
|
|
276
|
+
"anthropic/claude-3-sonnet",
|
|
277
|
+
];
|
|
278
|
+
const registry = AiModelRegistry.create({ displayedModels });
|
|
279
|
+
|
|
280
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
281
|
+
const anthropicModels = registry.getModelsByProvider("anthropic");
|
|
282
|
+
|
|
283
|
+
// Should only show the existing model
|
|
284
|
+
expect(openaiModels.length).toBe(0);
|
|
285
|
+
expect(anthropicModels.length).toBe(1);
|
|
286
|
+
expect(anthropicModels[0].model).toBe("claude-3-sonnet");
|
|
287
|
+
});
|
|
288
|
+
});
|
|
289
|
+
|
|
290
|
+
describe("model properties", () => {
|
|
291
|
+
it("should ensure all models have required properties", () => {
|
|
292
|
+
const registry = AiModelRegistry.create({});
|
|
293
|
+
const groupedModels = registry.getGroupedModelsByProvider();
|
|
294
|
+
|
|
295
|
+
for (const [provider, models] of groupedModels.entries()) {
|
|
296
|
+
for (const model of models) {
|
|
297
|
+
expect(model).toHaveProperty("name");
|
|
298
|
+
expect(model).toHaveProperty("model");
|
|
299
|
+
expect(model).toHaveProperty("description");
|
|
300
|
+
expect(model).toHaveProperty("providers");
|
|
301
|
+
expect(model).toHaveProperty("roles");
|
|
302
|
+
expect(model).toHaveProperty("thinking");
|
|
303
|
+
expect(model).toHaveProperty("custom");
|
|
304
|
+
|
|
305
|
+
expect(typeof model.name).toBe("string");
|
|
306
|
+
expect(typeof model.model).toBe("string");
|
|
307
|
+
expect(typeof model.description).toBe("string");
|
|
308
|
+
expect(Array.isArray(model.providers)).toBe(true);
|
|
309
|
+
expect(Array.isArray(model.roles)).toBe(true);
|
|
310
|
+
expect(typeof model.thinking).toBe("boolean");
|
|
311
|
+
expect(typeof model.custom).toBe("boolean");
|
|
312
|
+
|
|
313
|
+
expect(model.providers).toContain(provider);
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
});
|
|
317
|
+
|
|
318
|
+
it("should ensure custom models have correct custom property", () => {
|
|
319
|
+
const customModels = ["openai/custom-gpt"];
|
|
320
|
+
const registry = AiModelRegistry.create({ customModels });
|
|
321
|
+
const openaiModels = registry.getModelsByProvider("openai");
|
|
322
|
+
|
|
323
|
+
const customModel = openaiModels.find((model) => model.custom);
|
|
324
|
+
const defaultModel = openaiModels.find((model) => !model.custom);
|
|
325
|
+
|
|
326
|
+
expect(customModel).toMatchInlineSnapshot(`
|
|
327
|
+
{
|
|
328
|
+
"custom": true,
|
|
329
|
+
"description": "Custom model",
|
|
330
|
+
"model": "custom-gpt",
|
|
331
|
+
"name": "openai/custom-gpt",
|
|
332
|
+
"providers": [
|
|
333
|
+
"openai",
|
|
334
|
+
],
|
|
335
|
+
"roles": [],
|
|
336
|
+
"thinking": false,
|
|
337
|
+
}
|
|
338
|
+
`);
|
|
339
|
+
expect(defaultModel).toMatchInlineSnapshot(`
|
|
340
|
+
{
|
|
341
|
+
"custom": false,
|
|
342
|
+
"description": "OpenAI GPT-4 model",
|
|
343
|
+
"model": "gpt-4",
|
|
344
|
+
"name": "GPT-4",
|
|
345
|
+
"providers": [
|
|
346
|
+
"openai",
|
|
347
|
+
],
|
|
348
|
+
"roles": [
|
|
349
|
+
"chat",
|
|
350
|
+
"edit",
|
|
351
|
+
],
|
|
352
|
+
"thinking": false,
|
|
353
|
+
}
|
|
354
|
+
`);
|
|
355
|
+
});
|
|
356
|
+
});
|
|
357
|
+
});
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
/* Copyright 2024 Marimo. All rights reserved. */
|
|
2
2
|
|
|
3
3
|
import { describe, expect, it } from "vitest";
|
|
4
|
-
import {
|
|
4
|
+
import type { ProviderId } from "../ids";
|
|
5
|
+
import { AiModelId, type ShortModelId } from "../ids";
|
|
5
6
|
|
|
6
7
|
describe("AiModelId", () => {
|
|
7
8
|
describe("constructor", () => {
|
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
/* Copyright 2024 Marimo. All rights reserved. */
|
|
2
2
|
|
|
3
|
-
import type { TypedString } from "
|
|
3
|
+
import type { TypedString } from "@/utils/typed";
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
5
|
+
export const PROVIDERS = [
|
|
6
|
+
"openai",
|
|
7
|
+
"anthropic",
|
|
8
|
+
"google",
|
|
9
|
+
"ollama",
|
|
10
|
+
"bedrock",
|
|
11
|
+
"deepseek",
|
|
12
|
+
"azure",
|
|
13
|
+
] as const;
|
|
14
|
+
export type ProviderId = (typeof PROVIDERS)[number];
|
|
14
15
|
|
|
15
16
|
export type ShortModelId = TypedString<"ShortModelId">;
|
|
16
17
|
|
|
@@ -57,5 +58,12 @@ function guessProviderId(id: string): ProviderId {
|
|
|
57
58
|
if (id.startsWith("gemini") || id.startsWith("google")) {
|
|
58
59
|
return "google";
|
|
59
60
|
}
|
|
61
|
+
if (id.startsWith("deepseek")) {
|
|
62
|
+
return "deepseek";
|
|
63
|
+
}
|
|
60
64
|
return "ollama";
|
|
61
65
|
}
|
|
66
|
+
|
|
67
|
+
export function isKnownAIProvider(providerId: ProviderId): boolean {
|
|
68
|
+
return PROVIDERS.includes(providerId);
|
|
69
|
+
}
|