ai.muna.muna 0.0.44 → 0.0.46
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Editor/MunaMenu.cs +17 -7
- package/Plugins/Android/Muna.aar +0 -0
- package/Plugins/macOS/Function.dylib.meta +26 -25
- package/README.md +1 -1
- package/Runtime/API/DotNetClient.cs +0 -3
- package/Runtime/Beta/BetaClient.cs +14 -1
- package/Runtime/Beta/OpenAI/AudioService.cs +38 -0
- package/Runtime/Beta/OpenAI/AudioService.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/ChatService.cs +38 -0
- package/Runtime/Beta/OpenAI/ChatService.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/CompletionService.cs +117 -0
- package/Runtime/Beta/OpenAI/CompletionService.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/EmbeddingService.cs +252 -0
- package/Runtime/Beta/OpenAI/EmbeddingService.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/OpenAIClient.cs +50 -0
- package/Runtime/Beta/OpenAI/OpenAIClient.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/SpeechService.cs +256 -0
- package/Runtime/Beta/OpenAI/SpeechService.cs.meta +11 -0
- package/Runtime/Beta/OpenAI/Types.cs +364 -0
- package/Runtime/Beta/OpenAI/Types.cs.meta +11 -0
- package/Runtime/Beta/OpenAI.meta +8 -0
- package/Runtime/Beta/Remote/RemotePredictionService.cs +50 -70
- package/Runtime/Beta/{Value.cs → Types/Value.cs} +3 -4
- package/Runtime/Beta/Types.meta +8 -0
- package/Runtime/C/Configuration.cs +1 -1
- package/Runtime/C/Function.cs +1 -1
- package/Runtime/C/Prediction.cs +1 -1
- package/Runtime/C/PredictionStream.cs +1 -1
- package/Runtime/C/Predictor.cs +1 -1
- package/Runtime/C/Value.cs +3 -2
- package/Runtime/C/ValueMap.cs +1 -1
- package/Runtime/Muna.cs +2 -2
- package/Runtime/Types/Parameter.cs +75 -0
- package/Runtime/Types/Parameter.cs.meta +11 -0
- package/Runtime/Types/Predictor.cs +0 -53
- package/Unity/API/PredictionCacheClient.cs +1 -1
- package/Unity/Converters/Color.cs +46 -0
- package/Unity/Converters/Color.cs.meta +2 -0
- package/Unity/Converters/Rect.cs +230 -0
- package/Unity/Converters/Rect.cs.meta +2 -0
- package/Unity/Converters/Vector2.cs +44 -0
- package/Unity/Converters/Vector2.cs.meta +2 -0
- package/Unity/Converters/Vector3.cs +45 -0
- package/Unity/Converters/Vector3.cs.meta +2 -0
- package/Unity/Converters/Vector4.cs +46 -0
- package/Unity/Converters/Vector4.cs.meta +2 -0
- package/Unity/Converters.meta +8 -0
- package/Unity/MunaUnity.cs +67 -19
- package/package.json +1 -1
- /package/Runtime/Beta/{Value.cs.meta → Types/Value.cs.meta} +0 -0
package/Editor/MunaMenu.cs
CHANGED
|
@@ -5,28 +5,38 @@
|
|
|
5
5
|
|
|
6
6
|
namespace Muna.Editor {
|
|
7
7
|
|
|
8
|
+
using System.IO;
|
|
8
9
|
using UnityEditor;
|
|
9
10
|
|
|
10
11
|
internal static class MunaMenu {
|
|
11
12
|
|
|
12
13
|
private const int BasePriority = -50;
|
|
13
|
-
|
|
14
|
-
[MenuItem(@"Muna/Muna " + Muna.Version, false, BasePriority)]
|
|
14
|
+
|
|
15
|
+
[MenuItem(@"Tools/Muna/Muna " + Muna.Version, false, BasePriority)]
|
|
15
16
|
private static void Version() { }
|
|
16
17
|
|
|
17
|
-
[MenuItem(@"Muna/Muna " + Muna.Version, true, BasePriority)]
|
|
18
|
+
[MenuItem(@"Tools/Muna/Muna " + Muna.Version, true, BasePriority)]
|
|
18
19
|
private static bool EnableVersion() => false;
|
|
19
20
|
|
|
20
|
-
[MenuItem(@"Muna/Get Access Key", false, BasePriority + 1)]
|
|
21
|
+
[MenuItem(@"Tools/Muna/Get Access Key", false, BasePriority + 1)]
|
|
21
22
|
private static void GetAccessKey() => Help.BrowseURL(@"https://muna.ai/settings/developer");
|
|
22
23
|
|
|
23
|
-
[MenuItem(@"Muna/Explore Predictors", false, BasePriority + 2)]
|
|
24
|
+
[MenuItem(@"Tools/Muna/Explore Predictors", false, BasePriority + 2)]
|
|
24
25
|
private static void OpenExplore() => Help.BrowseURL(@"https://muna.ai/explore");
|
|
25
26
|
|
|
26
|
-
[MenuItem(@"Muna/View the Docs", false, BasePriority + 3)]
|
|
27
|
+
[MenuItem(@"Tools/Muna/View the Docs", false, BasePriority + 3)]
|
|
27
28
|
private static void OpenDocs() => Help.BrowseURL(@"https://docs.muna.ai");
|
|
28
29
|
|
|
29
|
-
[MenuItem(@"Muna/Report an Issue", false, BasePriority + 4)]
|
|
30
|
+
[MenuItem(@"Tools/Muna/Report an Issue", false, BasePriority + 4)]
|
|
30
31
|
private static void ReportIssue() => Help.BrowseURL(@"https://github.com/muna-ai/muna-unity");
|
|
32
|
+
|
|
33
|
+
[MenuItem(@"Tools/Muna/Clear Predictor Cache", false, BasePriority + 5)]
|
|
34
|
+
private static void ClearPredictorCache() {
|
|
35
|
+
Directory.Delete(
|
|
36
|
+
global::Muna.API.PredictionCacheClient.PredictorCachePath,
|
|
37
|
+
true
|
|
38
|
+
);
|
|
39
|
+
UnityEngine.Debug.Log("Muna: Cleared predictor cache.");
|
|
40
|
+
}
|
|
31
41
|
}
|
|
32
42
|
}
|
package/Plugins/Android/Muna.aar
CHANGED
|
Binary file
|
|
@@ -2,7 +2,7 @@ fileFormatVersion: 2
|
|
|
2
2
|
guid: 62aec3fdf35a340f78ab3a076fc7ca3e
|
|
3
3
|
PluginImporter:
|
|
4
4
|
externalObjects: {}
|
|
5
|
-
serializedVersion:
|
|
5
|
+
serializedVersion: 2
|
|
6
6
|
iconMap: {}
|
|
7
7
|
executionOrder: {}
|
|
8
8
|
defineConstraints: []
|
|
@@ -11,52 +11,53 @@ PluginImporter:
|
|
|
11
11
|
isExplicitlyReferenced: 0
|
|
12
12
|
validateReferences: 1
|
|
13
13
|
platformData:
|
|
14
|
-
|
|
14
|
+
- first:
|
|
15
|
+
:
|
|
16
|
+
second:
|
|
15
17
|
enabled: 0
|
|
16
|
-
settings:
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
Any:
|
|
18
|
+
settings: {}
|
|
19
|
+
- first:
|
|
20
|
+
: Any
|
|
21
|
+
second:
|
|
21
22
|
enabled: 0
|
|
22
23
|
settings:
|
|
23
|
-
Exclude Android: 1
|
|
24
24
|
Exclude Editor: 0
|
|
25
25
|
Exclude Linux64: 1
|
|
26
26
|
Exclude OSXUniversal: 0
|
|
27
|
-
Exclude WebGL: 1
|
|
28
27
|
Exclude Win: 1
|
|
29
28
|
Exclude Win64: 1
|
|
30
|
-
|
|
31
|
-
|
|
29
|
+
- first:
|
|
30
|
+
Editor: Editor
|
|
31
|
+
second:
|
|
32
32
|
enabled: 1
|
|
33
33
|
settings:
|
|
34
34
|
CPU: ARM64
|
|
35
35
|
DefaultValueInitialized: true
|
|
36
36
|
OS: OSX
|
|
37
|
-
|
|
37
|
+
- first:
|
|
38
|
+
Standalone: Linux64
|
|
39
|
+
second:
|
|
38
40
|
enabled: 0
|
|
39
41
|
settings:
|
|
40
|
-
CPU:
|
|
41
|
-
|
|
42
|
+
CPU: x86_64
|
|
43
|
+
- first:
|
|
44
|
+
Standalone: OSXUniversal
|
|
45
|
+
second:
|
|
42
46
|
enabled: 1
|
|
43
47
|
settings:
|
|
44
48
|
CPU: ARM64
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
CPU: AnyCPU
|
|
49
|
-
Win64:
|
|
49
|
+
- first:
|
|
50
|
+
Standalone: Win
|
|
51
|
+
second:
|
|
50
52
|
enabled: 0
|
|
51
53
|
settings:
|
|
52
|
-
CPU:
|
|
53
|
-
|
|
54
|
+
CPU: x86
|
|
55
|
+
- first:
|
|
56
|
+
Standalone: Win64
|
|
57
|
+
second:
|
|
54
58
|
enabled: 0
|
|
55
59
|
settings:
|
|
56
|
-
|
|
57
|
-
CPU: AnyCPU
|
|
58
|
-
CompileFlags:
|
|
59
|
-
FrameworkDependencies:
|
|
60
|
+
CPU: x86_64
|
|
60
61
|
userData:
|
|
61
62
|
assetBundleName:
|
|
62
63
|
assetBundleVariant:
|
package/README.md
CHANGED
|
@@ -26,9 +26,6 @@ namespace Muna.API {
|
|
|
26
26
|
/// </summary>
|
|
27
27
|
/// <param name="url">Muna API URL.</param>
|
|
28
28
|
/// <param name="accessKey">Muna access key.</param>
|
|
29
|
-
/// <param name="clientId">Client identifier.</param>
|
|
30
|
-
/// <param name="deviceId">Device model identifier.</param>
|
|
31
|
-
/// <param name="cachePath">Prediction resource cache path.</param>
|
|
32
29
|
public DotNetClient(
|
|
33
30
|
string url,
|
|
34
31
|
string? accessKey = default
|
|
@@ -8,7 +8,10 @@
|
|
|
8
8
|
namespace Muna.Beta {
|
|
9
9
|
|
|
10
10
|
using API;
|
|
11
|
+
using OpenAI;
|
|
11
12
|
using Services;
|
|
13
|
+
using PredictorService = global::Muna.Services.PredictorService;
|
|
14
|
+
using EdgePredictionService = global::Muna.Services.PredictionService;
|
|
12
15
|
|
|
13
16
|
/// <summary>
|
|
14
17
|
/// Client for incubating features.
|
|
@@ -20,13 +23,23 @@ namespace Muna.Beta {
|
|
|
20
23
|
/// Make predictions.
|
|
21
24
|
/// </summary>
|
|
22
25
|
public readonly PredictionService Predictions;
|
|
26
|
+
|
|
27
|
+
/// <summary>
|
|
28
|
+
/// OpenAI client.
|
|
29
|
+
/// </summary>
|
|
30
|
+
public readonly OpenAIClient OpenAI;
|
|
23
31
|
#endregion
|
|
24
32
|
|
|
25
33
|
|
|
26
34
|
#region --Operations--
|
|
27
35
|
|
|
28
|
-
internal BetaClient(
|
|
36
|
+
internal BetaClient(
|
|
37
|
+
MunaClient client,
|
|
38
|
+
PredictorService predictors,
|
|
39
|
+
EdgePredictionService predictions
|
|
40
|
+
) {
|
|
29
41
|
this.Predictions = new PredictionService(client);
|
|
42
|
+
this.OpenAI = new OpenAIClient(predictors, predictions, Predictions.Remote);
|
|
30
43
|
}
|
|
31
44
|
#endregion
|
|
32
45
|
}
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Muna
|
|
3
|
+
* Copyright © 2025 NatML Inc. All rights reserved.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#nullable enable
|
|
7
|
+
|
|
8
|
+
namespace Muna.Beta.OpenAI {
|
|
9
|
+
|
|
10
|
+
using Services;
|
|
11
|
+
using PredictorService = global::Muna.Services.PredictorService;
|
|
12
|
+
using EdgePredictionService = global::Muna.Services.PredictionService;
|
|
13
|
+
|
|
14
|
+
/// <summary>
|
|
15
|
+
/// Create speech and transcriptions.
|
|
16
|
+
/// </summary>
|
|
17
|
+
public sealed class AudioService {
|
|
18
|
+
|
|
19
|
+
#region --Client API--
|
|
20
|
+
/// <summary>
|
|
21
|
+
/// Create speech.
|
|
22
|
+
/// </summary>
|
|
23
|
+
public readonly SpeechService Speech;
|
|
24
|
+
#endregion
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
#region --Operations--
|
|
28
|
+
|
|
29
|
+
internal AudioService(
|
|
30
|
+
PredictorService predictors,
|
|
31
|
+
EdgePredictionService predictions,
|
|
32
|
+
RemotePredictionService remotePredictions
|
|
33
|
+
) {
|
|
34
|
+
Speech = new SpeechService(predictors, predictions, remotePredictions);
|
|
35
|
+
}
|
|
36
|
+
#endregion
|
|
37
|
+
}
|
|
38
|
+
}
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Muna
|
|
3
|
+
* Copyright © 2025 NatML Inc. All rights reserved.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#nullable enable
|
|
7
|
+
|
|
8
|
+
namespace Muna.Beta.OpenAI {
|
|
9
|
+
|
|
10
|
+
using Services;
|
|
11
|
+
using PredictorService = global::Muna.Services.PredictorService;
|
|
12
|
+
using EdgePredictionService = global::Muna.Services.PredictionService;
|
|
13
|
+
|
|
14
|
+
/// <summary>
|
|
15
|
+
/// Create chat conversations.
|
|
16
|
+
/// </summary>
|
|
17
|
+
public sealed class ChatService {
|
|
18
|
+
|
|
19
|
+
#region --Client API--
|
|
20
|
+
/// <summary>
|
|
21
|
+
/// Create completions.
|
|
22
|
+
/// </summary>
|
|
23
|
+
public readonly ChatCompletionService Completions;
|
|
24
|
+
#endregion
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
#region --Operations--
|
|
28
|
+
|
|
29
|
+
internal ChatService(
|
|
30
|
+
PredictorService predictors,
|
|
31
|
+
EdgePredictionService predictions,
|
|
32
|
+
RemotePredictionService remotePredictions
|
|
33
|
+
) {
|
|
34
|
+
Completions = new ChatCompletionService(predictors, predictions, remotePredictions);
|
|
35
|
+
}
|
|
36
|
+
#endregion
|
|
37
|
+
}
|
|
38
|
+
}
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Muna
|
|
3
|
+
* Copyright © 2025 NatML Inc. All rights reserved.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#nullable enable
|
|
7
|
+
|
|
8
|
+
namespace Muna.Beta.OpenAI {
|
|
9
|
+
|
|
10
|
+
using System;
|
|
11
|
+
using System.Collections.Generic;
|
|
12
|
+
using System.Threading.Tasks;
|
|
13
|
+
using Newtonsoft.Json.Linq;
|
|
14
|
+
using Services;
|
|
15
|
+
using PredictorService = global::Muna.Services.PredictorService;
|
|
16
|
+
using EdgePredictionService = global::Muna.Services.PredictionService;
|
|
17
|
+
|
|
18
|
+
/// <summary>
|
|
19
|
+
/// Create chat conversations.
|
|
20
|
+
/// </summary>
|
|
21
|
+
public sealed class ChatCompletionService {
|
|
22
|
+
|
|
23
|
+
#region --Client API--
|
|
24
|
+
/// <summary>
|
|
25
|
+
/// Create a chat completion.
|
|
26
|
+
/// </summary>
|
|
27
|
+
/// <param name="model">Chat model predictor tag.</param>
|
|
28
|
+
/// <param name="messages">Messages comprising the conversation so far.</param>
|
|
29
|
+
/// <param name="maxTokens">Maximum output tokens.</param>
|
|
30
|
+
/// <param name="acceleration">Prediction acceleration.</param>
|
|
31
|
+
public async Task<ChatCompletion> Create(
|
|
32
|
+
string model,
|
|
33
|
+
ChatMessage[] messages,
|
|
34
|
+
int? maxTokens = null,
|
|
35
|
+
object? acceleration = null
|
|
36
|
+
) {
|
|
37
|
+
var prediction = await CreatePrediction(
|
|
38
|
+
model,
|
|
39
|
+
new() {
|
|
40
|
+
[@"messages"] = messages,
|
|
41
|
+
[@"max_tokens"] = maxTokens
|
|
42
|
+
},
|
|
43
|
+
acceleration: acceleration ?? Acceleration.Auto
|
|
44
|
+
);
|
|
45
|
+
if (prediction.error != null)
|
|
46
|
+
throw new InvalidOperationException(prediction.error);
|
|
47
|
+
var completion = (prediction.results![0] as JObject)!.ToObject<ChatCompletion>()!;
|
|
48
|
+
return completion;
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
/// <summary>
|
|
52
|
+
/// Stream a chat completion.
|
|
53
|
+
/// </summary>
|
|
54
|
+
/// <param name="model">Chat model predictor tag.</param>
|
|
55
|
+
/// <param name="messages">Messages comprising the conversation so far.</param>
|
|
56
|
+
/// <param name="maxTokens">Maximum output tokens.</param>
|
|
57
|
+
/// <param name="acceleration">Prediction acceleration.</param>
|
|
58
|
+
public async IAsyncEnumerable<ChatCompletionChunk> Stream(
|
|
59
|
+
string model,
|
|
60
|
+
ChatMessage[] messages,
|
|
61
|
+
int? maxTokens = null,
|
|
62
|
+
object? acceleration = null
|
|
63
|
+
) {
|
|
64
|
+
var stream = StreamPrediction(
|
|
65
|
+
model,
|
|
66
|
+
new() {
|
|
67
|
+
[@"messages"] = messages,
|
|
68
|
+
[@"max_tokens"] = maxTokens
|
|
69
|
+
},
|
|
70
|
+
acceleration: acceleration ?? Acceleration.Auto
|
|
71
|
+
);
|
|
72
|
+
await foreach (var prediction in stream) {
|
|
73
|
+
if (prediction.error != null)
|
|
74
|
+
throw new InvalidOperationException(prediction.error);
|
|
75
|
+
var chunk = (prediction.results![0] as JObject)!.ToObject<ChatCompletionChunk>()!;
|
|
76
|
+
yield return chunk;
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
#endregion
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
#region --Operations--
|
|
83
|
+
private readonly PredictorService predictors;
|
|
84
|
+
private readonly EdgePredictionService predictions;
|
|
85
|
+
private readonly RemotePredictionService remotePredictions;
|
|
86
|
+
|
|
87
|
+
internal ChatCompletionService(
|
|
88
|
+
PredictorService predictors,
|
|
89
|
+
EdgePredictionService predictions,
|
|
90
|
+
RemotePredictionService remotePredictions
|
|
91
|
+
) {
|
|
92
|
+
this.predictors = predictors;
|
|
93
|
+
this.predictions = predictions;
|
|
94
|
+
this.remotePredictions = remotePredictions;
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
private Task<Prediction> CreatePrediction(
|
|
98
|
+
string tag,
|
|
99
|
+
Dictionary<string, object?> inputs,
|
|
100
|
+
object acceleration
|
|
101
|
+
) => acceleration switch {
|
|
102
|
+
Acceleration acc => predictions.Create(tag, inputs, acc),
|
|
103
|
+
RemoteAcceleration acc => remotePredictions.Create(tag, inputs, acc),
|
|
104
|
+
_ => throw new InvalidOperationException($"Cannot create {tag} prediction because acceleration is invalid: {acceleration}")
|
|
105
|
+
};
|
|
106
|
+
|
|
107
|
+
private IAsyncEnumerable<Prediction> StreamPrediction(
|
|
108
|
+
string tag,
|
|
109
|
+
Dictionary<string, object?> inputs,
|
|
110
|
+
object acceleration
|
|
111
|
+
) => acceleration switch {
|
|
112
|
+
Acceleration acc => predictions.Stream(tag, inputs, acc),
|
|
113
|
+
_ => throw new InvalidOperationException($"Cannot stream {tag} prediction because acceleration is invalid: {acceleration}")
|
|
114
|
+
};
|
|
115
|
+
#endregion
|
|
116
|
+
}
|
|
117
|
+
}
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Muna
|
|
3
|
+
* Copyright © 2025 NatML Inc. All rights reserved.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#nullable enable
|
|
7
|
+
|
|
8
|
+
namespace Muna.Beta.OpenAI {
|
|
9
|
+
|
|
10
|
+
using System;
|
|
11
|
+
using System.Collections.Generic;
|
|
12
|
+
using System.Linq;
|
|
13
|
+
using System.Runtime.Serialization;
|
|
14
|
+
using System.Threading.Tasks;
|
|
15
|
+
using Newtonsoft.Json;
|
|
16
|
+
using Newtonsoft.Json.Converters;
|
|
17
|
+
using Newtonsoft.Json.Linq;
|
|
18
|
+
using Services;
|
|
19
|
+
using PredictorService = global::Muna.Services.PredictorService;
|
|
20
|
+
using EdgePredictionService = global::Muna.Services.PredictionService;
|
|
21
|
+
|
|
22
|
+
/// <summary>
|
|
23
|
+
/// Create embeddings.
|
|
24
|
+
/// </summary>
|
|
25
|
+
public sealed class EmbeddingService {
|
|
26
|
+
|
|
27
|
+
#region --Client API--
|
|
28
|
+
/// <summary>
|
|
29
|
+
/// Embedding encoding format.
|
|
30
|
+
/// </summary>
|
|
31
|
+
[JsonConverter(typeof(StringEnumConverter))]
|
|
32
|
+
public enum EncodingFormat {
|
|
33
|
+
/// <summary>
|
|
34
|
+
/// Float array.
|
|
35
|
+
/// </summary>
|
|
36
|
+
[EnumMember(Value = @"float")]
|
|
37
|
+
Float = 1,
|
|
38
|
+
/// <summary>
|
|
39
|
+
/// Base64 string.
|
|
40
|
+
/// </summary>
|
|
41
|
+
[EnumMember(Value = @"base64")]
|
|
42
|
+
Base64 = 2
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
/// <summary>
|
|
46
|
+
/// Create an embedding vector representing the input text.
|
|
47
|
+
/// </summary>
|
|
48
|
+
/// <param name="input">Input text to embed. The input must not exceed the max input tokens for the model.</param>
|
|
49
|
+
/// <param name="model">Embedding model predictor tag.</param>
|
|
50
|
+
/// <param name="dimensions">The number of dimensions the resulting output embeddings should have. Only supported by Matryoshka embedding models.</param>
|
|
51
|
+
/// <param name="encodingFormat">The format to return the embeddings in.</param>
|
|
52
|
+
/// <param name="acceleration">Prediction acceleration.</param>
|
|
53
|
+
/// <returns>Embeddings.</returns>
|
|
54
|
+
public Task<CreateEmbeddingResponse> Create(
|
|
55
|
+
string model,
|
|
56
|
+
string input,
|
|
57
|
+
int? dimensions = null,
|
|
58
|
+
EncodingFormat encodingFormat = EncodingFormat.Float,
|
|
59
|
+
object? acceleration = null
|
|
60
|
+
) => Create(
|
|
61
|
+
model,
|
|
62
|
+
new[] { input },
|
|
63
|
+
dimensions: dimensions,
|
|
64
|
+
encodingFormat: encodingFormat,
|
|
65
|
+
acceleration: acceleration
|
|
66
|
+
);
|
|
67
|
+
|
|
68
|
+
/// <summary>
|
|
69
|
+
/// Create an embedding vector representing the input text.
|
|
70
|
+
/// </summary>
|
|
71
|
+
/// <param name="input">Input text to embed. The input must not exceed the max input tokens for the model.</param>
|
|
72
|
+
/// <param name="model">Embedding model predictor tag.</param>
|
|
73
|
+
/// <param name="dimensions">The number of dimensions the resulting output embeddings should have. Only supported by Matryoshka embedding models.</param>
|
|
74
|
+
/// <param name="encodingFormat">The format to return the embeddings in.</param>
|
|
75
|
+
/// <param name="acceleration">Prediction acceleration.</param>
|
|
76
|
+
/// <returns>Embeddings.</returns>
|
|
77
|
+
public async Task<CreateEmbeddingResponse> Create(
|
|
78
|
+
string model,
|
|
79
|
+
string[] input,
|
|
80
|
+
int? dimensions = null,
|
|
81
|
+
EncodingFormat encodingFormat = EncodingFormat.Float,
|
|
82
|
+
object? acceleration = null
|
|
83
|
+
) {
|
|
84
|
+
// Ensure we have a delegate
|
|
85
|
+
if (!cache.ContainsKey(model)) {
|
|
86
|
+
var @delegate = await CreateEmbeddingDelegate(model);
|
|
87
|
+
cache.Add(model, @delegate);
|
|
88
|
+
}
|
|
89
|
+
// Make prediction
|
|
90
|
+
var handler = cache[model];
|
|
91
|
+
var result = await handler(
|
|
92
|
+
model,
|
|
93
|
+
input,
|
|
94
|
+
dimensions,
|
|
95
|
+
encodingFormat,
|
|
96
|
+
acceleration: acceleration ?? Acceleration.Auto
|
|
97
|
+
);
|
|
98
|
+
// Return
|
|
99
|
+
return result;
|
|
100
|
+
}
|
|
101
|
+
#endregion
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
#region --Operations--
|
|
105
|
+
private readonly PredictorService predictors;
|
|
106
|
+
private readonly EdgePredictionService predictions;
|
|
107
|
+
private readonly RemotePredictionService remotePredictions;
|
|
108
|
+
private readonly Dictionary<string, EmbeddingDelegate> cache;
|
|
109
|
+
private delegate Task<CreateEmbeddingResponse> EmbeddingDelegate(
|
|
110
|
+
string model,
|
|
111
|
+
string[] input,
|
|
112
|
+
int? dimensions,
|
|
113
|
+
EncodingFormat encodingFormat,
|
|
114
|
+
object acceleration
|
|
115
|
+
);
|
|
116
|
+
|
|
117
|
+
internal EmbeddingService(
|
|
118
|
+
PredictorService predictors,
|
|
119
|
+
EdgePredictionService predictions,
|
|
120
|
+
RemotePredictionService remotePredictions
|
|
121
|
+
) {
|
|
122
|
+
this.predictors = predictors;
|
|
123
|
+
this.predictions = predictions;
|
|
124
|
+
this.remotePredictions = remotePredictions;
|
|
125
|
+
this.cache = new();
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
private async Task<EmbeddingDelegate> CreateEmbeddingDelegate(string tag) {
|
|
129
|
+
// Retrieve predictor
|
|
130
|
+
var predictor = await predictors.Retrieve(tag);
|
|
131
|
+
if (predictor == null)
|
|
132
|
+
throw new ArgumentException($"{tag} cannot be used for OpenAI embedding API because the predictor could not be found.");
|
|
133
|
+
// Get required inputs
|
|
134
|
+
var signature = predictor.signature!;
|
|
135
|
+
var requiredInputParams = signature.inputs.Where(parameter => parameter.optional == false).ToArray();
|
|
136
|
+
if (requiredInputParams.Length != 1)
|
|
137
|
+
throw new InvalidOperationException($"{tag} cannot be used with OpenAI embedding API because it does not have exactly one required input parameter.");
|
|
138
|
+
// Check the text input parameter
|
|
139
|
+
var inputParam = requiredInputParams[0];
|
|
140
|
+
if (inputParam.type != Dtype.List)
|
|
141
|
+
throw new InvalidOperationException($"{tag} cannot be used with OpenAI embedding API because it does not have the required text embedding input parameter.");
|
|
142
|
+
// Get the Matryoshka dim parameter (optional)
|
|
143
|
+
var matryoshkaParam = signature.inputs.FirstOrDefault(parameter =>
|
|
144
|
+
new[] {
|
|
145
|
+
Dtype.Int8, Dtype.Int16, Dtype.Int32, Dtype.Int64,
|
|
146
|
+
Dtype.Uint8, Dtype.Uint16, Dtype.Uint32, Dtype.Uint64
|
|
147
|
+
}.Contains(parameter.type) &&
|
|
148
|
+
parameter.denotation == "embedding.dims"
|
|
149
|
+
);
|
|
150
|
+
// Get the embedding output parameter
|
|
151
|
+
var (embeddingParamIdx, embeddingParam) = signature.outputs
|
|
152
|
+
.Select((parameter, idx) => (idx, parameter))
|
|
153
|
+
.Where(pair =>
|
|
154
|
+
pair.parameter.type == Dtype.Float32 &&
|
|
155
|
+
pair.parameter.denotation == "embedding"
|
|
156
|
+
)
|
|
157
|
+
.FirstOrDefault();
|
|
158
|
+
if (embeddingParam == null)
|
|
159
|
+
throw new InvalidOperationException($"{tag} cannot be used with OpenAI embedding API because it has no outputs with an `embedding` denotation.");
|
|
160
|
+
// Get the index of the usage output (optional)
|
|
161
|
+
var (usageParamIdx, usageParam) = signature.outputs
|
|
162
|
+
.Select((parameter, idx) => (idx, parameter))
|
|
163
|
+
.Where(pair =>
|
|
164
|
+
pair.parameter.type == Dtype.Dict &&
|
|
165
|
+
pair.parameter.denotation == "openai.embedding.usage"
|
|
166
|
+
)
|
|
167
|
+
.FirstOrDefault();
|
|
168
|
+
// Create delegate
|
|
169
|
+
EmbeddingDelegate result = async (
|
|
170
|
+
string model,
|
|
171
|
+
string[] input,
|
|
172
|
+
int? dimensions,
|
|
173
|
+
EncodingFormat encodingFormat,
|
|
174
|
+
object acceleration
|
|
175
|
+
) => {
|
|
176
|
+
// Build prediction input map
|
|
177
|
+
var inputMap = new Dictionary<string, object?> {
|
|
178
|
+
[inputParam.name] = input
|
|
179
|
+
};
|
|
180
|
+
if (dimensions != null && matryoshkaParam != null)
|
|
181
|
+
inputMap[matryoshkaParam.name] = dimensions.Value;
|
|
182
|
+
// Create prediction
|
|
183
|
+
var prediction = await CreatePrediction(
|
|
184
|
+
model,
|
|
185
|
+
inputs: inputMap,
|
|
186
|
+
acceleration: acceleration
|
|
187
|
+
);
|
|
188
|
+
// Check for error
|
|
189
|
+
if (prediction.error != null)
|
|
190
|
+
throw new InvalidOperationException(prediction.error);
|
|
191
|
+
// Check returned embedding
|
|
192
|
+
var rawEmbeddingMatrix = prediction.results![embeddingParamIdx]!;
|
|
193
|
+
if (!(rawEmbeddingMatrix is Tensor<float> embeddingMatrix))
|
|
194
|
+
throw new InvalidOperationException($"{tag} cannot be used with OpenAI embedding API because it returned an object of type {rawEmbeddingMatrix.GetType()} instead of an embedding matrix.");
|
|
195
|
+
if (embeddingMatrix.shape.Length != 2) {
|
|
196
|
+
var shapeStr = "(" + string.Join(",", embeddingMatrix.shape) + ")";
|
|
197
|
+
throw new InvalidOperationException($"{tag} cannot be used with OpenAI embedding API because it returned an embedding matrix with invalid shape: {shapeStr}");
|
|
198
|
+
}
|
|
199
|
+
// Create embedding response
|
|
200
|
+
var embeddings = Enumerable
|
|
201
|
+
.Range(0, embeddingMatrix.shape[0])
|
|
202
|
+
.Select(idx => ParseEmbedding(embeddingMatrix, idx, encodingFormat))
|
|
203
|
+
.ToArray();
|
|
204
|
+
var usage = usageParam != null ?
|
|
205
|
+
(prediction.results![usageParamIdx]! as JObject)!.ToObject<CreateEmbeddingResponse.UsageInfo>() :
|
|
206
|
+
default;
|
|
207
|
+
var response = new CreateEmbeddingResponse {
|
|
208
|
+
Object = "list",
|
|
209
|
+
Model = model,
|
|
210
|
+
Data = embeddings,
|
|
211
|
+
Usage = usage
|
|
212
|
+
};
|
|
213
|
+
// Return
|
|
214
|
+
return response;
|
|
215
|
+
};
|
|
216
|
+
// Return
|
|
217
|
+
return result;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
private Task<Prediction> CreatePrediction(
|
|
221
|
+
string tag,
|
|
222
|
+
Dictionary<string, object?> inputs,
|
|
223
|
+
object acceleration
|
|
224
|
+
) => acceleration switch {
|
|
225
|
+
Acceleration acc => predictions.Create(tag, inputs, acc),
|
|
226
|
+
RemoteAcceleration acc => remotePredictions.Create(tag, inputs, acc),
|
|
227
|
+
_ => throw new InvalidOperationException($"Cannot create {tag} prediction because acceleration is invalid: {acceleration}")
|
|
228
|
+
};
|
|
229
|
+
|
|
230
|
+
private unsafe Embedding ParseEmbedding(
|
|
231
|
+
Tensor<float> matrix,
|
|
232
|
+
int index,
|
|
233
|
+
EncodingFormat format
|
|
234
|
+
) {
|
|
235
|
+
fixed (float* data = matrix) {
|
|
236
|
+
var baseAddress = data + index * matrix.shape[1];
|
|
237
|
+
var floatSpan = new ReadOnlySpan<float>(baseAddress, matrix.shape[1]);
|
|
238
|
+
var byteSpan = new ReadOnlySpan<byte>(baseAddress, matrix.shape[1] * sizeof(float));
|
|
239
|
+
var embeddingVector = format == EncodingFormat.Float ? floatSpan.ToArray() : null;
|
|
240
|
+
var base64Rep = format == EncodingFormat.Base64 ? Convert.ToBase64String(byteSpan) : null;
|
|
241
|
+
var embedding = new Embedding {
|
|
242
|
+
Object = @"embedding",
|
|
243
|
+
Floats = embeddingVector,
|
|
244
|
+
Index = index,
|
|
245
|
+
Base64 = base64Rep
|
|
246
|
+
};
|
|
247
|
+
return embedding;
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
#endregion
|
|
251
|
+
}
|
|
252
|
+
}
|