@jupyterlite/ai 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/index.ts ADDED
@@ -0,0 +1,130 @@
1
+ import {
2
+ ActiveCellManager,
3
+ buildChatSidebar,
4
+ buildErrorWidget,
5
+ IActiveCellManager
6
+ } from '@jupyter/chat';
7
+ import {
8
+ JupyterFrontEnd,
9
+ JupyterFrontEndPlugin
10
+ } from '@jupyterlab/application';
11
+ import { ReactWidget, IThemeManager } from '@jupyterlab/apputils';
12
+ import { ICompletionProviderManager } from '@jupyterlab/completer';
13
+ import { INotebookTracker } from '@jupyterlab/notebook';
14
+ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
15
+ import { ISettingRegistry } from '@jupyterlab/settingregistry';
16
+
17
+ import { ChatHandler } from './chat-handler';
18
+ import { AIProvider } from './provider';
19
+ import { IAIProvider } from './token';
20
+
21
/**
 * The chat panel plugin.
 *
 * Builds the chat sidebar widget, wires it to the AI provider through a
 * `ChatHandler`, and keeps the chat configuration in sync with the
 * setting registry.
 */
const chatPlugin: JupyterFrontEndPlugin<void> = {
  id: '@jupyterlite/ai:chat',
  description: 'LLM chat extension',
  autoStart: true,
  optional: [INotebookTracker, ISettingRegistry, IThemeManager],
  requires: [IAIProvider, IRenderMimeRegistry],
  activate: async (
    app: JupyterFrontEnd,
    aiProvider: IAIProvider,
    rmRegistry: IRenderMimeRegistry,
    notebookTracker: INotebookTracker | null,
    settingsRegistry: ISettingRegistry | null,
    themeManager: IThemeManager | null
  ) => {
    // Track the active notebook cell so the chat can interact with it.
    // Only available when the (optional) notebook tracker is provided.
    let activeCellManager: IActiveCellManager | null = null;
    if (notebookTracker) {
      activeCellManager = new ActiveCellManager({
        tracker: notebookTracker,
        shell: app.shell
      });
    }

    // The chat model connecting the sidebar widget to the AI provider.
    const chatHandler = new ChatHandler({
      aiProvider: aiProvider,
      activeCellManager: activeCellManager
    });

    // Defaults used until the settings are loaded.
    let sendWithShiftEnter = false;
    let enableCodeToolbar = true;

    /**
     * Read the chat options from the settings and propagate them to the
     * chat handler. Also connected to `settings.changed` below.
     */
    function loadSetting(setting: ISettingRegistry.ISettings): void {
      sendWithShiftEnter = setting.get('sendWithShiftEnter')
        .composite as boolean;
      enableCodeToolbar = setting.get('enableCodeToolbar').composite as boolean;
      chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
    }

    // Wait for the application to be restored and for the settings to be
    // loaded (the setting registry is optional, hence the `?.`).
    Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
      .then(([, settings]) => {
        if (!settings) {
          console.warn(
            'The SettingsRegistry is not loaded for the chat extension'
          );
          return;
        }
        loadSetting(settings);
        settings.changed.connect(loadSetting);
      })
      .catch(reason => {
        console.error(
          `Something went wrong when reading the settings.\n${reason}`
        );
      });

    // Build the sidebar, falling back to an error widget if it fails.
    let chatWidget: ReactWidget | null = null;
    try {
      chatWidget = buildChatSidebar({
        model: chatHandler,
        themeManager,
        rmRegistry
      });
      chatWidget.title.caption = 'Codestral Chat';
    } catch (e) {
      chatWidget = buildErrorWidget(themeManager);
    }

    app.shell.add(chatWidget as ReactWidget, 'left', { rank: 2000 });

    console.log('Chat extension initialized');
  }
};
92
+
93
/**
 * The AI provider plugin.
 *
 * Creates the `AIProvider` (which registers its inline completion provider
 * with the completion manager), exposes it through the `IAIProvider` token,
 * and updates its models whenever the plugin settings change.
 */
const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
  id: '@jupyterlite/ai:ai-provider',
  autoStart: true,
  requires: [ICompletionProviderManager, ISettingRegistry],
  provides: IAIProvider,
  activate: (
    app: JupyterFrontEnd,
    manager: ICompletionProviderManager,
    settingRegistry: ISettingRegistry
  ): IAIProvider => {
    const aiProvider = new AIProvider({
      completionProviderManager: manager,
      // Callback used to trigger a new inline completion request.
      requestCompletion: () => app.commands.execute('inline-completer:invoke')
    });

    settingRegistry
      .load(aiProviderPlugin.id)
      .then(settings => {
        // (Re)create the provider's models from the current settings.
        const updateProvider = () => {
          const provider = settings.get('provider').composite as string;
          aiProvider.setModels(provider, settings.composite);
        };

        settings.changed.connect(() => updateProvider());
        updateProvider();
      })
      .catch(reason => {
        console.error(
          `Failed to load settings for ${aiProviderPlugin.id}`,
          reason
        );
      });

    return aiProvider;
  }
};
129
+
130
+ export default [chatPlugin, aiProviderPlugin];
@@ -0,0 +1,38 @@
1
+ import {
2
+ CompletionHandler,
3
+ IInlineCompletionContext
4
+ } from '@jupyterlab/completer';
5
+ import { LLM } from '@langchain/core/language_models/llms';
6
+ import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
7
+
8
/**
 * Interface implemented by the LLM-backed inline completers.
 */
export interface IBaseCompleter {
  /**
   * The LLM completer.
   */
  provider: LLM;

  /**
   * The function to fetch a new completion.
   * May be used to re-trigger a request, e.g. when a previous one was
   * dropped (see the Codestral completer's `fetchAgain` handling).
   */
  requestCompletion?: () => void;

  /**
   * The fetch request for the LLM completer.
   *
   * @param request - the completion request (text and cursor offset).
   * @param context - the inline completion context.
   * @returns a promise resolving to the completion items.
   */
  fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ): Promise<any>;
}
27
+
28
/**
 * The namespace for the base completer.
 */
export namespace BaseCompleter {
  /**
   * The options for the constructor of a completer.
   */
  export interface IOptions {
    /**
     * The settings forwarded to the underlying language model.
     */
    settings: ReadonlyPartialJSONObject;
  }
}
@@ -0,0 +1,113 @@
1
+ import {
2
+ CompletionHandler,
3
+ IInlineCompletionContext
4
+ } from '@jupyterlab/completer';
5
+ import { LLM } from '@langchain/core/language_models/llms';
6
+ import { MistralAI } from '@langchain/mistralai';
7
+ import { Throttler } from '@lumino/polling';
8
+ import { CompletionRequest } from '@mistralai/mistralai';
9
+
10
+ import { BaseCompleter, IBaseCompleter } from './base-completer';
11
+
12
/**
 * The Mistral API has a rate limit of 1 request per second, so enforce a
 * minimum delay (in ms) between two completion requests.
 */
const INTERVAL = 1000;

/**
 * Timeout (in ms) after which a pending completion request is abandoned,
 * to avoid endless requests.
 */
const REQUEST_TIMEOUT = 3000;
21
+
22
/**
 * Inline completer backed by the Mistral (Codestral) completion API.
 *
 * Requests are throttled to respect the API rate limit, raced against a
 * timeout, and their results discarded when the prompt has changed since
 * the request was issued.
 */
export class CodestralCompleter implements IBaseCompleter {
  constructor(options: BaseCompleter.IOptions) {
    // NOTE(review): dead code — the completion callback is set through the
    // `requestCompletion` setter instead; confirm before removing.
    // this._requestCompletion = options.requestCompletion;
    this._mistralProvider = new MistralAI({ ...options.settings });
    this._throttler = new Throttler(
      async (data: CompletionRequest) => {
        // Snapshot the request data at invocation time so the response can
        // later be compared against the most recent request.
        const invokedData = data;

        // Request completion.
        const request = this._mistralProvider.completionWithRetry(
          data,
          {},
          false
        );
        // Resolves to null after REQUEST_TIMEOUT ms, bounding the wait.
        const timeoutPromise = new Promise<null>(resolve => {
          return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
        });

        // Fetch again if the request is too long or if the prompt has changed.
        const response = await Promise.race([request, timeoutPromise]);
        if (
          response === null ||
          invokedData.prompt !== this._currentData?.prompt
        ) {
          // Timed out or stale: return no items and ask the caller to
          // trigger a fresh request.
          return {
            items: [],
            fetchAgain: true
          };
        }

        // Extract results of completion request.
        const items = response.choices.map((choice: any) => {
          return { insertText: choice.message.content as string };
        });

        return {
          items
        };
      },
      // At most one request per INTERVAL ms (API rate limit).
      { limit: INTERVAL }
    );
  }

  /**
   * The underlying LLM used by this completer.
   */
  get provider(): LLM {
    return this._mistralProvider;
  }

  /**
   * Set the callback used to trigger a fresh completion request after a
   * previous one was dropped (timeout or stale prompt).
   */
  set requestCompletion(value: () => void) {
    this._requestCompletion = value;
  }

  /**
   * Fetch completion items for the given request.
   *
   * @param request - the completion request (full text and cursor offset).
   * @param context - the inline completion context (unused here).
   * @returns the completion items, or an empty list on error.
   */
  async fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ) {
    const { text, offset: cursorOffset } = request;
    // Fill-in-the-middle: text before the cursor is the prompt, text after
    // the cursor is the suffix.
    const prompt = text.slice(0, cursorOffset);
    const suffix = text.slice(cursorOffset);

    const data = {
      prompt,
      suffix,
      model: this._mistralProvider.model,
      // temperature: 0,
      // top_p: 1,
      // max_tokens: 1024,
      // min_tokens: 0,
      stream: false,
      // random_seed: 1337,
      stop: []
    };

    try {
      // Record the latest request so in-flight responses can detect that
      // they have become stale.
      this._currentData = data;
      const completionResult = await this._throttler.invoke(data);
      if (completionResult.fetchAgain) {
        if (this._requestCompletion) {
          this._requestCompletion();
        }
      }
      return { items: completionResult.items };
    } catch (error) {
      console.error('Error fetching completions', error);
      return { items: [] };
    }
  }

  // Callback used to re-trigger a request after a dropped one.
  private _requestCompletion?: () => void;
  // Throttler enforcing the API rate limit.
  private _throttler: Throttler;
  private _mistralProvider: MistralAI;
  // Data of the most recently issued request (used for staleness checks).
  private _currentData: CompletionRequest | null = null;
}
@@ -0,0 +1,3 @@
1
// Public API of the llm-models module: the completer interfaces, the
// Codestral completer and the provider helper functions.
export * from './base-completer';
export * from './codestral-completer';
export * from './utils';
@@ -0,0 +1,41 @@
1
+ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
2
+ import { ChatMistralAI } from '@langchain/mistralai';
3
+ import { IBaseCompleter } from './base-completer';
4
+ import { CodestralCompleter } from './codestral-completer';
5
+ import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
6
+
7
+ /**
8
+ * Get an LLM completer from the name.
9
+ */
10
+ export function getCompleter(
11
+ name: string,
12
+ settings: ReadonlyPartialJSONObject
13
+ ): IBaseCompleter | null {
14
+ if (name === 'MistralAI') {
15
+ return new CodestralCompleter({ settings });
16
+ }
17
+ return null;
18
+ }
19
+
20
+ /**
21
+ * Get an LLM chat model from the name.
22
+ */
23
+ export function getChatModel(
24
+ name: string,
25
+ settings: ReadonlyPartialJSONObject
26
+ ): BaseChatModel | null {
27
+ if (name === 'MistralAI') {
28
+ return new ChatMistralAI({ ...settings });
29
+ }
30
+ return null;
31
+ }
32
+
33
+ /**
34
+ * Get the error message from provider.
35
+ */
36
+ export function getErrorMessage(name: string, error: any): string {
37
+ if (name === 'MistralAI') {
38
+ return error.message;
39
+ }
40
+ return 'Unknown provider';
41
+ }
@@ -0,0 +1,154 @@
1
+ import { ICompletionProviderManager } from '@jupyterlab/completer';
2
+ import { BaseLanguageModel } from '@langchain/core/language_models/base';
3
+ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
4
+ import { ISignal, Signal } from '@lumino/signaling';
5
+ import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
6
+
7
+ import { CompletionProvider } from './completion-provider';
8
+ import { getChatModel, IBaseCompleter } from './llm-models';
9
+ import { IAIProvider } from './token';
10
+
11
/**
 * Provider owning both the inline completion provider and the chat model,
 * and recreating/reconfiguring them when the settings change.
 */
export class AIProvider implements IAIProvider {
  constructor(options: AIProvider.IOptions) {
    // Start with a placeholder completer ('None') and register it once in
    // the completion manager; it is reconfigured later via `setModels`.
    this._completionProvider = new CompletionProvider({
      name: 'None',
      settings: {},
      requestCompletion: options.requestCompletion
    });
    options.completionProviderManager.registerInlineProvider(
      this._completionProvider
    );
  }

  /**
   * Get the name of the current provider (default 'None').
   */
  get name(): string {
    return this._name;
  }

  /**
   * Get the current completer of the completion provider.
   */
  get completer(): IBaseCompleter | null {
    // NOTE(review): `_name` is a string initialized to 'None' and never set
    // to null, so this guard looks dead — should it compare to 'None'?
    if (this._name === null) {
      return null;
    }
    return this._completionProvider.completer;
  }

  /**
   * Get the current llm chat model.
   */
  get chatModel(): BaseChatModel | null {
    // NOTE(review): same dead-looking null guard as in `completer`.
    if (this._name === null) {
      return null;
    }
    return this._llmChatModel;
  }

  /**
   * Get the current chat error ('' when the last creation succeeded).
   */
  get chatError(): string {
    return this._chatError;
  }

  /**
   * Get the current completer error ('' when the last creation succeeded).
   */
  get completerError(): string {
    return this._completerError;
  }

  /**
   * Set the models (chat model and completer).
   * Creates the models if the name has changed, otherwise only updates their config.
   *
   * @param name - the name of the model to use.
   * @param settings - the settings for the models.
   */
  setModels(name: string, settings: ReadonlyPartialJSONObject) {
    // Errors from each model are recorded independently so one failing
    // model does not prevent the other from being updated.
    try {
      this._completionProvider.setCompleter(name, settings);
      this._completerError = '';
    } catch (e: any) {
      this._completerError = e.message;
    }
    try {
      this._llmChatModel = getChatModel(name, settings);
      this._chatError = '';
    } catch (e: any) {
      this._chatError = e.message;
      this._llmChatModel = null;
    }
    this._name = name;
    // Notify listeners (e.g. the chat UI) that the models changed.
    this._modelChange.emit();
  }

  /**
   * A signal emitted when the models are updated.
   */
  get modelChange(): ISignal<IAIProvider, void> {
    return this._modelChange;
  }

  private _completionProvider: CompletionProvider;
  private _llmChatModel: BaseChatModel | null = null;
  private _name: string = 'None';
  private _modelChange = new Signal<IAIProvider, void>(this);
  private _chatError: string = '';
  private _completerError: string = '';
}
97
+
98
+ export namespace AIProvider {
99
+ /**
100
+ * The options for the LLM provider.
101
+ */
102
+ export interface IOptions {
103
+ /**
104
+ * The completion provider manager in which register the LLM completer.
105
+ */
106
+ completionProviderManager: ICompletionProviderManager;
107
+ /**
108
+ * The application commands registry.
109
+ */
110
+ requestCompletion: () => void;
111
+ }
112
+
113
+ /**
114
+ * This function indicates whether a key is writable in an object.
115
+ * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript
116
+ *
117
+ * @param obj - An object extending the BaseLanguageModel interface.
118
+ * @param key - A string as a key of the object.
119
+ * @returns a boolean whether the key is writable or not.
120
+ */
121
+ export function isWritable<T extends BaseLanguageModel>(
122
+ obj: T,
123
+ key: keyof T
124
+ ) {
125
+ const desc =
126
+ Object.getOwnPropertyDescriptor(obj, key) ||
127
+ Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) ||
128
+ {};
129
+ return Boolean(desc.writable);
130
+ }
131
+
132
+ /**
133
+ * Update the config of a language model.
134
+ * It only updates the writable attributes of the model.
135
+ *
136
+ * @param model - the model to update.
137
+ * @param settings - the configuration s a JSON object.
138
+ */
139
+ export function updateConfig<T extends BaseLanguageModel>(
140
+ model: T,
141
+ settings: ReadonlyPartialJSONObject
142
+ ) {
143
+ Object.entries(settings).forEach(([key, value], index) => {
144
+ if (key in model) {
145
+ const modelKey = key as keyof typeof model;
146
+ if (isWritable(model, modelKey)) {
147
+ // eslint-disable-next-line @typescript-eslint/ban-ts-comment
148
+ // @ts-ignore
149
+ model[modelKey] = value;
150
+ }
151
+ }
152
+ });
153
+ }
154
+ }
package/src/token.ts ADDED
@@ -0,0 +1,19 @@
1
+ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
2
+ import { Token } from '@lumino/coreutils';
3
+ import { ISignal } from '@lumino/signaling';
4
+
5
+ import { IBaseCompleter } from './llm-models';
6
+
7
/**
 * Provider for the chat model and the inline completer.
 */
export interface IAIProvider {
  /**
   * The name of the current LLM provider (e.g. 'MistralAI').
   */
  name: string;
  /**
   * The current inline completer, or `null` if none is set.
   */
  completer: IBaseCompleter | null;
  /**
   * The current chat model, or `null` if none is set.
   */
  chatModel: BaseChatModel | null;
  /**
   * A signal emitted when the models are updated.
   */
  modelChange: ISignal<IAIProvider, void>;
  /**
   * The latest error raised while creating the chat model ('' if none).
   */
  chatError: string;
  /**
   * The latest error raised while creating the completer ('' if none).
   */
  completerError: string;
}
15
+
16
/**
 * The token used to request the AI provider from other plugins.
 */
export const IAIProvider = new Token<IAIProvider>(
  '@jupyterlite/ai:AIProvider',
  'Provider for chat and completion LLM provider'
);
package/style/base.css ADDED
@@ -0,0 +1,7 @@
1
+ /*
2
+ See the JupyterLab Developer Guide for useful CSS Patterns:
3
+
4
+ https://jupyterlab.readthedocs.io/en/stable/developer/css.html
5
+ */
6
+
7
+ @import url('@jupyter/chat/style/index.css');
@@ -0,0 +1 @@
1
+ @import url('base.css');
package/style/index.js ADDED
@@ -0,0 +1 @@
1
+ import './base.css';