using System.Text;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using VectorSearchApp.Configuration;
using VectorSearchApp.Models;

namespace VectorSearchApp.Services;

/// <summary>
/// Produces dense vector embeddings for text, either locally via ONNX Runtime
/// or remotely via the HuggingFace Inference API.
/// </summary>
public interface IEmbeddingService
{
    /// <summary>Generates an embedding vector for <paramref name="text"/>.</summary>
    /// <param name="text">Input text to embed.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The embedding as a float array (expected length = configured dimension).</returns>
    Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default);
}

/// <summary>
/// <see cref="IEmbeddingService"/> implementation with two backends selected at
/// construction time: a local ONNX model (simplified BERT word-piece tokenizer +
/// mean pooling) or the HuggingFace feature-extraction HTTP endpoint.
/// </summary>
public class EmbeddingService : IEmbeddingService
{
    private readonly HttpClient? _httpClient;          // non-null only in remote mode
    private readonly string _modelName;
    private readonly int _dimension;
    private readonly bool _useLocalInference;
    private readonly InferenceSession? _inferenceSession; // non-null only in local mode
    private readonly Dictionary<string, int>? _vocabulary; // non-null only in local mode

    // BERT vocabulary for all-MiniLM-L6-v2 (subset for demo).
    // NOTE(review): this demo vocabulary contains no "##"-prefixed continuation
    // pieces, so the continuation branch in WordPieceTokenize is effectively dormant.
    private static readonly Dictionary<string, int> BaseVocabulary = new()
    {
        ["[PAD]"] = 0, ["[UNK]"] = 1, ["[CLS]"] = 2, ["[SEP]"] = 3, ["[MASK]"] = 4,
        ["!"] = 5, ["\""] = 6, ["#"] = 7, ["$"] = 8, ["%"] = 9, ["&"] = 10, ["'"] = 11,
        ["("] = 12, [")"] = 13, ["*"] = 14, ["+"] = 15, [","] = 16, ["-"] = 17, ["."] = 18,
        ["/"] = 19, [":"] = 20, [";"] = 21, ["="] = 22, ["?"] = 23, ["@"] = 24, ["["] = 25,
        ["]"] = 26, ["_"] = 27, ["`"] = 28, ["{"] = 29, ["}"] = 30, ["~"] = 31,
        ["0"] = 32, ["1"] = 33, ["2"] = 34, ["3"] = 35, ["4"] = 36, ["5"] = 37,
        ["6"] = 38, ["7"] = 39, ["8"] = 40, ["9"] = 41,
        ["a"] = 42, ["b"] = 43, ["c"] = 44, ["d"] = 45, ["e"] = 46, ["f"] = 47,
        ["g"] = 48, ["h"] = 49, ["i"] = 50, ["j"] = 51, ["k"] = 52, ["l"] = 53,
        ["m"] = 54, ["n"] = 55, ["o"] = 56, ["p"] = 57, ["q"] = 58, ["r"] = 59,
        ["s"] = 60, ["t"] = 61, ["u"] = 62, ["v"] = 63, ["w"] = 64, ["x"] = 65,
        ["y"] = 66, ["z"] = 67,
        ["A"] = 68, ["B"] = 69, ["C"] = 70, ["D"] = 71, ["E"] = 72, ["F"] = 73,
        ["G"] = 74, ["H"] = 75, ["I"] = 76, ["J"] = 77, ["K"] = 78, ["L"] = 79,
        ["M"] = 80, ["N"] = 81, ["O"] = 82, ["P"] = 83, ["Q"] = 84, ["R"] = 85,
        ["S"] = 86, ["T"] = 87, ["U"] = 88, ["V"] = 89, ["W"] = 90, ["X"] = 91,
        ["Y"] = 92, ["Z"] = 93,
        ["the"] = 94, ["of"] = 95, ["and"] = 96, ["to"] = 97, ["in"] = 98, ["is"] = 99,
        ["it"] = 100, ["that"] = 101, ["for"] = 102, ["was"] = 103, ["on"] = 104,
        ["with"] = 105, ["as"] = 106, ["at"] = 107, ["by"] = 108, ["this"] = 109,
        ["be"] = 110, ["not"] = 111, ["from"] = 112, ["or"] = 113, ["which"] = 114,
        ["but"] = 115, ["are"] = 116, ["we"] = 117, ["have"] = 118, ["an"] = 119,
        ["had"] = 120, ["they"] = 121, ["his"] = 122, ["her"] = 123, ["she"] = 124,
        ["he"] = 125, ["you"] = 126, ["their"] = 127, ["what"] = 128, ["all"] = 129,
        ["were"] = 130, ["when"] = 131, ["your"] = 132, ["can"] = 133, ["there"] = 134,
        ["has"] = 135, ["been"] = 136, ["if"] = 137, ["who"] = 138, ["more"] = 139,
        ["will"] = 140, ["up"] = 141, ["one"] = 142, ["time"] = 143, ["out"] = 144,
        ["about"] = 145, ["into"] = 146, ["just"] = 147, ["him"] = 148, ["them"] = 149,
        ["me"] = 150, ["my"] = 151, ["than"] = 152, ["now"] = 153, ["do"] = 154,
        ["would"] = 155, ["so"] = 156, ["get"] = 157, ["no"] = 158, ["see"] = 159,
        ["way"] = 160, ["could"] = 161, ["like"] = 162, ["other"] = 163, ["then"] = 164,
        ["first"] = 165, ["also"] = 166, ["back"] = 167, ["after"] = 168, ["use"] = 169,
        ["two"] = 170, ["how"] = 171, ["our"] = 172, ["work"] = 173, ["well"] = 174,
        ["even"] = 175, ["new"] = 176, ["want"] = 177, ["because"] = 178, ["any"] = 179,
        ["these"] = 180, ["give"] = 181, ["day"] = 182, ["most"] = 183, ["us"] = 184,
        ["address"] = 185, ["street"] = 186, ["city"] = 187, ["state"] = 188,
        ["zip"] = 189, ["code"] = 190, ["name"] = 191, ["number"] = 192
    };

    /// <summary>
    /// Initializes the service in either local-ONNX or remote-HTTP mode based on
    /// <see cref="EmbeddingConfiguration.UseLocalInference"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the local inference session cannot be created (model file missing,
    /// unsupported model name, or ONNX Runtime failure).
    /// </exception>
    public EmbeddingService(EmbeddingConfiguration config)
    {
        _modelName = config.ModelName;
        _dimension = config.Dimension;
        _useLocalInference = config.UseLocalInference;

        if (_useLocalInference)
        {
            // Use local ONNX inference.
            _httpClient = null;
            try
            {
                var modelPath = GetModelPath(_modelName);
                _inferenceSession = new InferenceSession(modelPath);
                _vocabulary = BaseVocabulary;
            }
            catch (Exception ex)
            {
                // Preserve the original exception as the inner exception so the
                // root cause and its stack trace are not lost.
                throw new InvalidOperationException(
                    $"Failed to initialize local inference session: {ex.Message}", ex);
            }
        }
        else
        {
            // Use HuggingFace Inference API.
            _httpClient = new HttpClient { BaseAddress = new Uri("https://router.huggingface.co/") };
            _httpClient.DefaultRequestHeaders.Add("User-Agent", "VectorSearchApp");
        }
    }

    /// <summary>
    /// Resolves a known model name to a local .onnx file path, probing several
    /// candidate directories relative to the application base directory and CWD.
    /// </summary>
    /// <exception cref="NotSupportedException">Unknown model name.</exception>
    /// <exception cref="FileNotFoundException">Model file not found in any candidate location.</exception>
    private static string GetModelPath(string modelName)
    {
        var modelFileName = modelName switch
        {
            "sentence-transformers/all-MiniLM-L6-v2" => "all-MiniLM-L6-v2.onnx",
            "jarredparrett/all-MiniLM-L6-v2_tuned_on_deepparse_address_mutations_comb_3" => "address-embedding-model.onnx",
            "custom-all-MiniLM-L6-v2-address" => "address-embedding-model.onnx",
            _ => throw new NotSupportedException($"Model '{modelName}' is not supported for local inference")
        };

        // Check multiple possible locations for the model file (deploy layouts differ).
        var possiblePaths = new[]
        {
            Path.Combine(AppContext.BaseDirectory, "Models", modelFileName),
            Path.Combine(AppContext.BaseDirectory, "Models", "Models", modelFileName),
            Path.Combine("Models", modelFileName),
            Path.Combine("Models", "Models", modelFileName)
        };

        foreach (var modelPath in possiblePaths)
        {
            if (File.Exists(modelPath))
            {
                return modelPath;
            }
        }

        throw new FileNotFoundException(
            $"Model file '{modelFileName}' not found. Searched in: {string.Join(", ", possiblePaths)}");
    }

    /// <inheritdoc />
    public Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
    {
        return _useLocalInference
            ? GenerateEmbeddingLocalAsync(text, cancellationToken)
            : GenerateEmbeddingRemoteAsync(text, cancellationToken);
    }

    /// <summary>
    /// Runs the ONNX model synchronously on the calling thread: tokenizes, builds
    /// BERT-style inputs ([CLS] tokens [SEP], max 128, zero-padded), then mean-pools
    /// the last hidden state over non-padding positions.
    /// </summary>
    /// <exception cref="InvalidOperationException">Local session not initialized.</exception>
    private Task<float[]> GenerateEmbeddingLocalAsync(string text, CancellationToken cancellationToken)
    {
        if (_inferenceSession == null)
        {
            throw new InvalidOperationException("Local inference session is not initialized");
        }

        cancellationToken.ThrowIfCancellationRequested();

        // Tokenize the input text using word-piece tokenization.
        var tokens = Tokenize(text);
        const int maxLength = 128;
        var inputIds = new long[maxLength];
        var attentionMask = new long[maxLength];

        // Add [CLS] at the beginning.
        inputIds[0] = 2; // [CLS]
        attentionMask[0] = 1;

        // Add tokenized words, reserving two slots for [CLS] and [SEP].
        for (int i = 0; i < Math.Min(tokens.Count, maxLength - 2); i++)
        {
            inputIds[i + 1] = tokens[i];
            attentionMask[i + 1] = 1;
        }

        // Add [SEP] at the end of the (possibly truncated) sequence.
        var tokenCount = Math.Min(tokens.Count + 2, maxLength);
        inputIds[tokenCount - 1] = 3; // [SEP]
        attentionMask[tokenCount - 1] = 1;

        var inputIdsTensor = new DenseTensor<long>(inputIds, new[] { 1, maxLength });
        var attentionMaskTensor = new DenseTensor<long>(attentionMask, new[] { 1, maxLength });

        // Token type IDs (all zeros for single sentence input).
        var tokenTypeIds = new long[maxLength];

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor),
            NamedOnnxValue.CreateFromTensor("token_type_ids", new DenseTensor<long>(tokenTypeIds, new[] { 1, maxLength }))
        };

        // Run() returns a disposable collection backed by native memory — dispose it
        // to avoid leaking ONNX Runtime allocations on every call.
        using var results = _inferenceSession.Run(inputs);
        var output = results.FirstOrDefault(r => r.Name.Contains("last_hidden_state")) ?? results.First();
        var hiddenStates = output.AsEnumerable<float>().ToArray();

        // Apply mean pooling over the sequence dimension.
        // NOTE(review): assumes the model's hidden size equals _dimension and the
        // output layout is [seq, hidden] flattened — confirm against the model.
        var pooledOutput = new float[_dimension];
        var actualLength = attentionMask.Sum(m => m);
        if (actualLength > 0)
        {
            for (int i = 0; i < _dimension; i++)
            {
                float sum = 0;
                for (int j = 0; j < maxLength; j++)
                {
                    if (attentionMask[j] > 0)
                    {
                        sum += hiddenStates[j * _dimension + i];
                    }
                }
                pooledOutput[i] = sum / actualLength;
            }
        }

        return Task.FromResult(pooledOutput);
    }

    /// <summary>
    /// Converts text to vocabulary token IDs: single characters are looked up
    /// directly (with a lowercase fallback), longer words go through
    /// <see cref="WordPieceTokenize"/>. Unknown tokens map to [UNK] (1).
    /// </summary>
    private List<int> Tokenize(string text)
    {
        var tokens = new List<int>();
        var words = WordTokenize(text);

        foreach (var word in words)
        {
            if (word.Length == 1)
            {
                // Single character: exact match first, then lowercase fallback.
                if (_vocabulary!.TryGetValue(word, out var id))
                {
                    tokens.Add(id);
                }
                else if (_vocabulary.TryGetValue(word.ToLower(), out id))
                {
                    tokens.Add(id);
                }
                else
                {
                    tokens.Add(1); // [UNK]
                }
            }
            else
            {
                // Multi-character word - try word-piece tokenization.
                var subTokens = WordPieceTokenize(word, _vocabulary!);
                tokens.AddRange(subTokens);
            }
        }

        return tokens;
    }

    /// <summary>
    /// Simple word tokenizer that splits on whitespace and punctuation.
    /// Punctuation characters are dropped (not emitted as tokens).
    /// </summary>
    private static List<string> WordTokenize(string text)
    {
        var result = new List<string>();
        var sb = new StringBuilder();

        foreach (var c in text)
        {
            if (char.IsWhiteSpace(c) || char.IsPunctuation(c))
            {
                if (sb.Length > 0)
                {
                    result.Add(sb.ToString());
                    sb.Clear();
                }
            }
            else
            {
                sb.Append(c);
            }
        }

        if (sb.Length > 0)
        {
            result.Add(sb.ToString());
        }

        return result;
    }

    /// <summary>
    /// Greedy longest-prefix word-piece tokenization against <paramref name="vocab"/>.
    /// O(|vocab| * pieces) per word — acceptable for this small demo vocabulary.
    /// </summary>
    private static List<int> WordPieceTokenize(string word, Dictionary<string, int> vocab)
    {
        var tokens = new List<int>();
        var remaining = word;

        while (!string.IsNullOrEmpty(remaining))
        {
            // Try to find the longest matching subword.
            // NOTE(review): for "##" keys the full key length (including "##") is
            // compared against plain-key lengths, biasing ties toward "##" pieces;
            // and "##" pieces are accepted even at word start. Harmless today since
            // BaseVocabulary has no "##" entries — confirm before extending the vocab.
            var match = "";
            var matchLength = 0;

            foreach (var kvp in vocab)
            {
                if (kvp.Key.StartsWith("##") && remaining.StartsWith(kvp.Key.Substring(2)))
                {
                    if (kvp.Key.Length > matchLength)
                    {
                        match = kvp.Key;
                        matchLength = kvp.Key.Length;
                    }
                }
                else if (remaining.StartsWith(kvp.Key) && !kvp.Key.StartsWith("##"))
                {
                    if (kvp.Key.Length > matchLength)
                    {
                        match = kvp.Key;
                        matchLength = kvp.Key.Length;
                    }
                }
            }

            if (matchLength == 0)
            {
                // No match found: fall back to a whole-word lowercase lookup.
                // NOTE(review): this looks up the original 'word', not 'remaining' —
                // presumably intentional as a whole-word fallback; verify.
                if (vocab.TryGetValue(word.ToLower(), out var wordId))
                {
                    tokens.Add(wordId);
                }
                else
                {
                    tokens.Add(1); // [UNK]
                }
                break;
            }

            if (match.StartsWith("##"))
            {
                tokens.Add(vocab[match]);
                // Consume only the characters after the "##" marker.
                remaining = remaining.Substring(match.Length - 2);
            }
            else
            {
                tokens.Add(vocab[match]);
                remaining = remaining.Substring(match.Length);
            }
        }

        return tokens;
    }

    /// <summary>
    /// Calls the HuggingFace feature-extraction pipeline for <see cref="_modelName"/>
    /// and returns the first embedding row from the response.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Non-success HTTP status, unparseable response, or empty result.
    /// </exception>
    private async Task<float[]> GenerateEmbeddingRemoteAsync(string text, CancellationToken cancellationToken)
    {
        var url = $"pipeline/feature-extraction/{_modelName}";
        var request = new { inputs = text, options = new { wait_for_model = true } };
        var jsonContent = System.Text.Json.JsonSerializer.Serialize(request);
        var httpContent = new StringContent(jsonContent, System.Text.Encoding.UTF8, "application/json");

        var response = await _httpClient!.PostAsync(url, httpContent, cancellationToken);

        if (!response.IsSuccessStatusCode)
        {
            var errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
            throw new InvalidOperationException(
                $"Failed to generate embedding: {response.StatusCode} - {errorContent}");
        }

        // Read the body once as a string; the previous stream-based approach could
        // not safely re-read the content for the error message after a parse failure.
        var responseContent = await response.Content.ReadAsStringAsync(cancellationToken);
        try
        {
            var result = System.Text.Json.JsonSerializer.Deserialize<float[][]>(responseContent);

            if (result?.Length > 0 && result[0].Length > 0)
            {
                if (result[0].Length != _dimension)
                {
                    Console.WriteLine($"Warning: Embedding dimension ({result[0].Length}) differs from expected ({_dimension})");
                }
                return result[0];
            }
        }
        catch (System.Text.Json.JsonException ex)
        {
            throw new InvalidOperationException(
                $"Failed to parse embedding response: {ex.Message}. Raw response: {responseContent.Substring(0, Math.Min(500, responseContent.Length))}", ex);
        }

        throw new InvalidOperationException("Failed to generate embedding: empty result");
    }
}