360 lines
13 KiB
C#
360 lines
13 KiB
C#
using System.Text;
|
|
using Microsoft.ML.OnnxRuntime;
|
|
using Microsoft.ML.OnnxRuntime.Tensors;
|
|
using VectorSearchApp.Configuration;
|
|
using VectorSearchApp.Models;
|
|
|
|
namespace VectorSearchApp.Services;
|
|
|
|
/// <summary>
/// Produces a dense vector embedding for a piece of text.
/// </summary>
public interface IEmbeddingService
{
    /// <summary>
    /// Generates an embedding vector for <paramref name="text"/>.
    /// </summary>
    /// <param name="text">The input text to embed.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The embedding as a float array; length depends on the configured model dimension.</returns>
    Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default);
}
|
|
|
|
/// <summary>
/// Generates text embeddings either with a local ONNX model (all-MiniLM-L6-v2)
/// or through the HuggingFace Inference API, depending on configuration.
/// Owns an <see cref="HttpClient"/> (remote mode) or an <see cref="InferenceSession"/>
/// (local mode), so it implements <see cref="IDisposable"/>.
/// </summary>
public class EmbeddingService : IEmbeddingService, IDisposable
{
    private readonly HttpClient? _httpClient;              // non-null only in remote mode
    private readonly string _modelName;
    private readonly int _dimension;
    private readonly bool _useLocalInference;
    private readonly InferenceSession? _inferenceSession;  // non-null only in local mode
    private readonly Dictionary<string, int>? _vocabulary;

    // BERT vocabulary for all-MiniLM-L6-v2 (subset for demo).
    // NOTE(review): these IDs do not match the real all-MiniLM-L6-v2/BERT vocabulary
    // (where e.g. [CLS]=101, [SEP]=102), so local embeddings are demo-quality only — confirm intent.
    private static readonly Dictionary<string, int> BaseVocabulary = new()
    {
        ["[PAD]"] = 0, ["[UNK]"] = 1, ["[CLS]"] = 2, ["[SEP]"] = 3, ["[MASK]"] = 4,
        ["!"] = 5, ["\""] = 6, ["#"] = 7, ["$"] = 8, ["%"] = 9, ["&"] = 10, ["'"] = 11,
        ["("] = 12, [")"] = 13, ["*"] = 14, ["+"] = 15, [","] = 16, ["-"] = 17, ["."] = 18,
        ["/"] = 19, [":"] = 20, [";"] = 21, ["="] = 22, ["?"] = 23, ["@"] = 24, ["["] = 25,
        ["]"] = 26, ["_"] = 27, ["`"] = 28, ["{"] = 29, ["}"] = 30, ["~"] = 31, ["0"] = 32,
        ["1"] = 33, ["2"] = 34, ["3"] = 35, ["4"] = 36, ["5"] = 37, ["6"] = 38, ["7"] = 39,
        ["8"] = 40, ["9"] = 41, ["a"] = 42, ["b"] = 43, ["c"] = 44, ["d"] = 45, ["e"] = 46,
        ["f"] = 47, ["g"] = 48, ["h"] = 49, ["i"] = 50, ["j"] = 51, ["k"] = 52, ["l"] = 53,
        ["m"] = 54, ["n"] = 55, ["o"] = 56, ["p"] = 57, ["q"] = 58, ["r"] = 59, ["s"] = 60,
        ["t"] = 61, ["u"] = 62, ["v"] = 63, ["w"] = 64, ["x"] = 65, ["y"] = 66, ["z"] = 67,
        ["A"] = 68, ["B"] = 69, ["C"] = 70, ["D"] = 71, ["E"] = 72, ["F"] = 73, ["G"] = 74,
        ["H"] = 75, ["I"] = 76, ["J"] = 77, ["K"] = 78, ["L"] = 79, ["M"] = 80, ["N"] = 81,
        ["O"] = 82, ["P"] = 83, ["Q"] = 84, ["R"] = 85, ["S"] = 86, ["T"] = 87, ["U"] = 88,
        ["V"] = 89, ["W"] = 90, ["X"] = 91, ["Y"] = 92, ["Z"] = 93, ["the"] = 94, ["of"] = 95,
        ["and"] = 96, ["to"] = 97, ["in"] = 98, ["is"] = 99, ["it"] = 100, ["that"] = 101,
        ["for"] = 102, ["was"] = 103, ["on"] = 104, ["with"] = 105, ["as"] = 106, ["at"] = 107,
        ["by"] = 108, ["this"] = 109, ["be"] = 110, ["not"] = 111, ["from"] = 112, ["or"] = 113,
        ["which"] = 114, ["but"] = 115, ["are"] = 116, ["we"] = 117, ["have"] = 118, ["an"] = 119,
        ["had"] = 120, ["they"] = 121, ["his"] = 122, ["her"] = 123, ["she"] = 124, ["he"] = 125,
        ["you"] = 126, ["their"] = 127, ["what"] = 128, ["all"] = 129, ["were"] = 130, ["when"] = 131,
        ["your"] = 132, ["can"] = 133, ["there"] = 134, ["has"] = 135, ["been"] = 136, ["if"] = 137,
        ["who"] = 138, ["more"] = 139, ["will"] = 140, ["up"] = 141, ["one"] = 142, ["time"] = 143,
        ["out"] = 144, ["about"] = 145, ["into"] = 146, ["just"] = 147, ["him"] = 148, ["them"] = 149,
        ["me"] = 150, ["my"] = 151, ["than"] = 152, ["now"] = 153, ["do"] = 154, ["would"] = 155,
        ["so"] = 156, ["get"] = 157, ["no"] = 158, ["see"] = 159, ["way"] = 160, ["could"] = 161,
        ["like"] = 162, ["other"] = 163, ["then"] = 164, ["first"] = 165, ["also"] = 166, ["back"] = 167,
        ["after"] = 168, ["use"] = 169, ["two"] = 170, ["how"] = 171, ["our"] = 172, ["work"] = 173,
        ["well"] = 174, ["even"] = 175, ["new"] = 176, ["want"] = 177, ["because"] = 178, ["any"] = 179,
        ["these"] = 180, ["give"] = 181, ["day"] = 182, ["most"] = 183, ["us"] = 184, ["address"] = 185,
        ["street"] = 186, ["city"] = 187, ["state"] = 188, ["zip"] = 189, ["code"] = 190, ["name"] = 191,
        ["number"] = 192
    };

    /// <summary>
    /// Initializes the service in either local-ONNX or remote (HuggingFace) mode.
    /// </summary>
    /// <param name="config">Embedding configuration: model name, dimension, and inference mode.</param>
    /// <exception cref="InvalidOperationException">Local mode was requested but the ONNX session could not be created.</exception>
    public EmbeddingService(EmbeddingConfiguration config)
    {
        ArgumentNullException.ThrowIfNull(config);

        _modelName = config.ModelName;
        _dimension = config.Dimension;
        _useLocalInference = config.UseLocalInference;

        if (_useLocalInference)
        {
            // Local ONNX inference: no HTTP client needed.
            _httpClient = null;
            try
            {
                var modelPath = GetModelPath(_modelName);
                _inferenceSession = new InferenceSession(modelPath);
                _vocabulary = BaseVocabulary;
            }
            catch (Exception ex)
            {
                // Preserve the original exception as InnerException so the root cause
                // (missing model file, bad ONNX graph, ...) is not lost.
                throw new InvalidOperationException(
                    $"Failed to initialize local inference session: {ex.Message}", ex);
            }
        }
        else
        {
            // Remote mode: call the HuggingFace Inference API.
            _httpClient = new HttpClient
            {
                BaseAddress = new Uri("https://router.huggingface.co/")
            };
            _httpClient.DefaultRequestHeaders.Add("User-Agent", "VectorSearchApp");
        }
    }

    /// <summary>
    /// Maps a supported model name to the on-disk ONNX file, preferring the
    /// application base directory and falling back to a relative "Models" folder.
    /// </summary>
    /// <exception cref="NotSupportedException">The model has no bundled ONNX file.</exception>
    private static string GetModelPath(string modelName)
    {
        var modelFileName = modelName switch
        {
            "sentence-transformers/all-MiniLM-L6-v2" => "all-MiniLM-L6-v2.onnx",
            _ => throw new NotSupportedException($"Model '{modelName}' is not supported for local inference")
        };

        var modelPath = Path.Combine(AppContext.BaseDirectory, "Models", modelFileName);

        if (!File.Exists(modelPath))
        {
            // Fall back to a path relative to the current working directory.
            modelPath = Path.Combine("Models", modelFileName);
        }

        return modelPath;
    }

    /// <inheritdoc />
    public Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
    {
        return _useLocalInference
            ? GenerateEmbeddingLocalAsync(text, cancellationToken)
            : GenerateEmbeddingRemoteAsync(text, cancellationToken);
    }

    /// <summary>
    /// Runs the local ONNX model: tokenize, pad/truncate to 128 tokens with
    /// [CLS]/[SEP], run the session, then mean-pool the last hidden state over
    /// attended positions.
    /// </summary>
    private Task<float[]> GenerateEmbeddingLocalAsync(string text, CancellationToken cancellationToken)
    {
        if (_inferenceSession == null)
        {
            throw new InvalidOperationException("Local inference session is not initialized");
        }

        // Honor cancellation before doing any work (the token was previously ignored).
        cancellationToken.ThrowIfCancellationRequested();

        // Tokenize the input text using word-piece tokenization.
        var tokens = Tokenize(text);

        const int maxLength = 128;
        var inputIds = new long[maxLength];
        var attentionMask = new long[maxLength];

        // [CLS] at position 0.
        inputIds[0] = 2; // [CLS]
        attentionMask[0] = 1;

        // Token ids occupy positions 1..maxLength-2, leaving room for [CLS] and [SEP].
        for (int i = 0; i < Math.Min(tokens.Count, maxLength - 2); i++)
        {
            inputIds[i + 1] = tokens[i];
            attentionMask[i + 1] = 1;
        }

        // [SEP] immediately after the last real token (or at maxLength-1 when truncated).
        var tokenCount = Math.Min(tokens.Count + 2, maxLength);
        inputIds[tokenCount - 1] = 3; // [SEP]
        attentionMask[tokenCount - 1] = 1;

        var inputIdsTensor = new DenseTensor<long>(inputIds, new[] { 1, maxLength });
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, new[] { 1, maxLength });

        // Token type IDs are all zeros for single-sentence input.
        var tokenTypeIds = new long[maxLength];

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor),
            NamedOnnxValue.CreateFromTensor("token_type_ids", new DenseTensor<long>(tokenTypeIds, new[] { 1, maxLength }))
        };

        // Run() returns a disposable collection backed by native memory; it was
        // previously leaked on every call.
        using var results = _inferenceSession.Run(inputs);

        var output = results.FirstOrDefault(r => r.Name.Contains("last_hidden_state"))
                     ?? results.First();

        var hiddenStates = output.AsEnumerable<float>().ToArray();

        // Mean pooling over the sequence dimension, counting only attended positions.
        var pooledOutput = new float[_dimension];
        var actualLength = attentionMask.Sum(m => m);

        if (actualLength > 0)
        {
            for (int i = 0; i < _dimension; i++)
            {
                float sum = 0;
                for (int j = 0; j < maxLength; j++)
                {
                    if (attentionMask[j] > 0)
                    {
                        sum += hiddenStates[j * _dimension + i];
                    }
                }
                pooledOutput[i] = sum / actualLength;
            }
        }

        return Task.FromResult(pooledOutput);
    }

    /// <summary>
    /// Converts raw text to vocabulary token ids: whitespace/punctuation word split,
    /// then per-word lookup (single chars directly, longer words via word-piece).
    /// Unknown tokens map to [UNK] (id 1).
    /// </summary>
    private List<int> Tokenize(string text)
    {
        var tokens = new List<int>();
        var words = WordTokenize(text);

        foreach (var word in words)
        {
            if (word.Length == 1)
            {
                // Single character: exact lookup first, then lower-cased, else [UNK].
                if (_vocabulary!.TryGetValue(word, out var id))
                {
                    tokens.Add(id);
                }
                else if (_vocabulary.TryGetValue(word.ToLower(), out id))
                {
                    tokens.Add(id);
                }
                else
                {
                    tokens.Add(1); // [UNK]
                }
            }
            else
            {
                // Multi-character word: greedy word-piece tokenization.
                tokens.AddRange(WordPieceTokenize(word, _vocabulary!));
            }
        }

        return tokens;
    }

    /// <summary>
    /// Splits text into words on whitespace and punctuation; separator characters
    /// themselves are dropped.
    /// </summary>
    private static List<string> WordTokenize(string text)
    {
        var result = new List<string>();
        var sb = new StringBuilder();

        foreach (var c in text)
        {
            if (char.IsWhiteSpace(c) || char.IsPunctuation(c))
            {
                if (sb.Length > 0)
                {
                    result.Add(sb.ToString());
                    sb.Clear();
                }
            }
            else
            {
                sb.Append(c);
            }
        }

        if (sb.Length > 0)
        {
            result.Add(sb.ToString());
        }

        return result;
    }

    /// <summary>
    /// Greedy longest-match word-piece tokenization: repeatedly consumes the longest
    /// vocabulary entry ("##"-prefixed entries match continuations) from the front of
    /// the word. When nothing matches, the remaining text is emitted as a whole-word
    /// lookup or [UNK].
    /// </summary>
    private static List<int> WordPieceTokenize(string word, Dictionary<string, int> vocab)
    {
        var tokens = new List<int>();
        var remaining = word;

        while (!string.IsNullOrEmpty(remaining))
        {
            // Find the longest matching subword. NOTE: this scans the whole vocabulary
            // per step, which is fine for the small demo vocabulary.
            var match = "";
            var matchLength = 0;

            foreach (var kvp in vocab)
            {
                if (kvp.Key.StartsWith("##") && remaining.StartsWith(kvp.Key.Substring(2)))
                {
                    if (kvp.Key.Length > matchLength)
                    {
                        match = kvp.Key;
                        matchLength = kvp.Key.Length;
                    }
                }
                else if (remaining.StartsWith(kvp.Key) && !kvp.Key.StartsWith("##"))
                {
                    if (kvp.Key.Length > matchLength)
                    {
                        match = kvp.Key;
                        matchLength = kvp.Key.Length;
                    }
                }
            }

            if (matchLength == 0)
            {
                // No match: fall back on the UNCONSUMED remainder (previously this
                // looked up the whole original word, re-emitting already-consumed
                // prefixes).
                if (vocab.TryGetValue(remaining.ToLower(), out var wordId))
                {
                    tokens.Add(wordId);
                }
                else
                {
                    tokens.Add(1); // [UNK]
                }
                break;
            }

            if (match.StartsWith("##"))
            {
                tokens.Add(vocab[match]);
                remaining = remaining.Substring(match.Length - 2);
            }
            else
            {
                tokens.Add(vocab[match]);
                remaining = remaining.Substring(match.Length);
            }
        }

        return tokens;
    }

    /// <summary>
    /// Calls the HuggingFace feature-extraction pipeline and returns the first
    /// embedding row from the response.
    /// </summary>
    /// <exception cref="InvalidOperationException">Non-success HTTP status, unparseable JSON, or an empty result.</exception>
    private async Task<float[]> GenerateEmbeddingRemoteAsync(string text, CancellationToken cancellationToken)
    {
        var url = $"pipeline/feature-extraction/{_modelName}";
        var request = new
        {
            inputs = text,
            options = new
            {
                wait_for_model = true
            }
        };

        var jsonContent = System.Text.Json.JsonSerializer.Serialize(request);
        var httpContent = new StringContent(jsonContent, System.Text.Encoding.UTF8, "application/json");

        // Dispose the response (previously leaked).
        using var response = await _httpClient!.PostAsync(url, httpContent, cancellationToken);

        if (!response.IsSuccessStatusCode)
        {
            var errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
            throw new InvalidOperationException($"Failed to generate embedding: {response.StatusCode} - {errorContent}");
        }

        // Read the body ONCE as a string. The previous code deserialized from the
        // content stream and then, in the catch block, tried to re-read the same
        // already-consumed stream for the error message.
        var rawContent = await response.Content.ReadAsStringAsync(cancellationToken);

        try
        {
            var result = System.Text.Json.JsonSerializer.Deserialize<float[][]>(rawContent);

            if (result?.Length > 0 && result[0].Length > 0)
            {
                if (result[0].Length != _dimension)
                {
                    Console.WriteLine($"Warning: Embedding dimension ({result[0].Length}) differs from expected ({_dimension})");
                }
                return result[0];
            }
        }
        catch (System.Text.Json.JsonException ex)
        {
            throw new InvalidOperationException($"Failed to parse embedding response: {ex.Message}. Raw response: {rawContent.Substring(0, Math.Min(500, rawContent.Length))}");
        }

        throw new InvalidOperationException("Failed to generate embedding: empty result");
    }

    /// <summary>
    /// Releases the owned <see cref="HttpClient"/> and/or <see cref="InferenceSession"/>
    /// (only one is non-null, depending on the configured mode).
    /// </summary>
    public void Dispose()
    {
        _httpClient?.Dispose();
        _inferenceSession?.Dispose();
        GC.SuppressFinalize(this);
    }
}
|