App seems to be running for the first time. Can store an embedding, and can search.
This commit is contained in:
167
VectorSearchApp.Tests/EmbeddingServiceIntegrationTests.cs
Normal file
167
VectorSearchApp.Tests/EmbeddingServiceIntegrationTests.cs
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
using FluentAssertions;
|
||||||
|
using VectorSearchApp.Configuration;
|
||||||
|
using VectorSearchApp.Services;
|
||||||
|
|
||||||
|
namespace VectorSearchApp.Tests;
|
||||||
|
|
||||||
|
/// <summary>
/// Integration tests that exercise <see cref="EmbeddingService"/> with local inference enabled
/// (no mocks). The service instance is shared through <see cref="EmbeddingServiceFixture"/> so the
/// inference session is created once for the whole class instead of once per test.
/// </summary>
public class EmbeddingServiceIntegrationTests : IClassFixture<EmbeddingServiceFixture>
{
    // Dimension the service is configured with in the fixture.
    private const int ExpectedDimension = EmbeddingServiceFixture.ExpectedDimension;

    private const string SampleAddress = "123 Main Street, New York, NY 10001";

    private readonly IEmbeddingService _embeddingService;

    public EmbeddingServiceIntegrationTests(EmbeddingServiceFixture fixture)
    {
        _embeddingService = fixture.Service;
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithValidInput_ShouldReturnEmbedding()
    {
        // Act
        var result = await _embeddingService.GenerateEmbeddingAsync(SampleAddress);

        // Assert
        result.Should().NotBeNull();
        result.Should().NotBeEmpty();
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithValidInput_ShouldHaveCorrectDimension()
    {
        // Act
        var result = await _embeddingService.GenerateEmbeddingAsync(SampleAddress);

        // Assert
        result.Should().HaveCount(ExpectedDimension);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithEmptyString_ShouldReturnEmbedding()
    {
        // Arrange
        var text = "";

        // Act
        var result = await _embeddingService.GenerateEmbeddingAsync(text);

        // Assert
        result.Should().NotBeNull();
        result.Should().HaveCount(ExpectedDimension);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithShortText_ShouldReturnEmbedding()
    {
        // Arrange
        var text = "Hello";

        // Act
        var result = await _embeddingService.GenerateEmbeddingAsync(text);

        // Assert
        result.Should().NotBeNull();
        result.Should().HaveCount(ExpectedDimension);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithLongText_ShouldReturnEmbedding()
    {
        // Arrange
        var text = "This is a much longer address that contains more text than typical short addresses. " +
                   "It includes multiple lines and details about the location. " +
                   "1234 Elm Street, Suite 500, Springfield, IL 62701, United States";

        // Act
        var result = await _embeddingService.GenerateEmbeddingAsync(text);

        // Assert
        result.Should().NotBeNull();
        result.Should().HaveCount(ExpectedDimension);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithSimilarAddresses_ShouldReturnDifferentEmbeddings()
    {
        // Arrange
        var address1 = "123 Main Street, New York, NY 10001";
        var address2 = "124 Main Street, New York, NY 10001";

        // Act
        var embedding1 = await _embeddingService.GenerateEmbeddingAsync(address1);
        var embedding2 = await _embeddingService.GenerateEmbeddingAsync(address2);

        // Assert
        embedding1.Should().NotBeNull();
        embedding2.Should().NotBeNull();
        // The inputs differ by one digit, so the vectors may be close but must not be identical.
        embedding1.Should().NotBeEquivalentTo(embedding2);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_WithSameAddress_ShouldReturnSameEmbeddings()
    {
        // Act
        var embedding1 = await _embeddingService.GenerateEmbeddingAsync(SampleAddress);
        var embedding2 = await _embeddingService.GenerateEmbeddingAsync(SampleAddress);

        // Assert
        embedding1.Should().NotBeNull();
        embedding2.Should().NotBeNull();
        embedding1.Should().BeEquivalentTo(embedding2);
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_MultipleAddresses_ShouldReturnEmbeddings()
    {
        // Arrange
        var addresses = new[]
        {
            "123 Main Street, New York, NY 10001",
            "456 Oak Avenue, Los Angeles, CA 90001",
            "789 Pine Road, Chicago, IL 60601",
            "321 Maple Drive, Houston, TX 77001",
            "654 Cedar Lane, Phoenix, AZ 85001"
        };

        // Act & Assert
        foreach (var address in addresses)
        {
            var embedding = await _embeddingService.GenerateEmbeddingAsync(address);
            embedding.Should().NotBeNull();
            embedding.Should().HaveCount(ExpectedDimension);
        }
    }

    [Fact]
    public async Task GenerateEmbeddingAsync_CancellationToken_ShouldWork()
    {
        // Arrange
        using var cts = new CancellationTokenSource();

        // Act: the token is never cancelled, so the call is expected to complete normally.
        var result = await _embeddingService.GenerateEmbeddingAsync(SampleAddress, cts.Token);

        // Assert
        result.Should().NotBeNull();
        result.Should().HaveCount(ExpectedDimension);
    }
}

/// <summary>
/// xUnit class fixture that builds a single <see cref="EmbeddingService"/> configured for local
/// inference and hands it to every test in the class.
/// </summary>
public sealed class EmbeddingServiceFixture
{
    /// <summary>Embedding size the service is configured with (all-MiniLM-L6-v2).</summary>
    public const int ExpectedDimension = 384;

    public EmbeddingServiceFixture()
    {
        var config = new EmbeddingConfiguration
        {
            ModelName = "sentence-transformers/all-MiniLM-L6-v2",
            Dimension = ExpectedDimension,
            UseLocalInference = true
        };

        Service = new EmbeddingService(config);
    }

    /// <summary>The shared service under test.</summary>
    public IEmbeddingService Service { get; }
}
|
||||||
33
VectorSearchApp.Tests/VectorSearchApp.Tests.csproj
Normal file
33
VectorSearchApp.Tests/VectorSearchApp.Tests.csproj
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
<Project Sdk="Microsoft.NET.Sdk">
|
||||||
|
|
||||||
|
<PropertyGroup>
|
||||||
|
<TargetFramework>net8.0</TargetFramework>
|
||||||
|
<ImplicitUsings>enable</ImplicitUsings>
|
||||||
|
<Nullable>enable</Nullable>
|
||||||
|
<IsPackable>false</IsPackable>
|
||||||
|
<IsTestProject>true</IsTestProject>
|
||||||
|
</PropertyGroup>
|
||||||
|
|
||||||
|
<ItemGroup>
|
||||||
|
<PackageReference Include="coverlet.collector" Version="6.0.2">
|
||||||
|
<PrivateAssets>all</PrivateAssets>
|
||||||
|
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||||
|
</PackageReference>
|
||||||
|
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
|
||||||
|
<PackageReference Include="xunit" Version="2.9.2" />
|
||||||
|
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
|
||||||
|
<PrivateAssets>all</PrivateAssets>
|
||||||
|
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||||
|
</PackageReference>
|
||||||
|
<PackageReference Include="FluentAssertions" Version="6.12.2" />
|
||||||
|
</ItemGroup>
|
||||||
|
|
||||||
|
<ItemGroup>
|
||||||
|
<ProjectReference Include="..\VectorSearchApp\VectorSearchApp.csproj" />
|
||||||
|
</ItemGroup>
|
||||||
|
|
||||||
|
<ItemGroup>
|
||||||
|
<Using Include="Xunit" />
|
||||||
|
</ItemGroup>
|
||||||
|
|
||||||
|
</Project>
|
||||||
@@ -19,6 +19,8 @@ public class EmbeddingConfiguration
|
|||||||
{
|
{
|
||||||
public string ModelName { get; set; } = "sentence-transformers/all-MiniLM-L6-v2";
|
public string ModelName { get; set; } = "sentence-transformers/all-MiniLM-L6-v2";
|
||||||
public int Dimension { get; set; } = 384;
|
public int Dimension { get; set; } = 384;
|
||||||
|
public string ApiToken { get; set; } = string.Empty;
|
||||||
|
public bool UseLocalInference { get; set; } = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public class AppSettings
|
public class AppSettings
|
||||||
|
|||||||
BIN
VectorSearchApp/Models/all-MiniLM-L6-v2.onnx
Normal file
BIN
VectorSearchApp/Models/all-MiniLM-L6-v2.onnx
Normal file
Binary file not shown.
45
VectorSearchApp/Models/download-model.ps1
Normal file
45
VectorSearchApp/Models/download-model.ps1
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env pwsh
# Download script for the all-MiniLM-L6-v2 ONNX model.
# Run this script to download the embedding model locally. The file is written next to this
# script (the Models folder), regardless of the directory the script is started from.

# Upstream (non-quantized) model on HuggingFace; shown for reference only, not downloaded.
$ModelUrl = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx"

# The all-MiniLM-L6-v2 ONNX model is available from HuggingFace
# We'll use the transformers.js format which is pre-quantized and optimized
$AltUrl = "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2/dist/quantized/all-MiniLM-L6-v2_quantized.onnx"

# $PSScriptRoot is empty when the content is pasted into a console; fall back to ./Models then.
$ModelsDir = if ($PSScriptRoot) { $PSScriptRoot } else { Join-Path (Get-Location) "Models" }
$OutputPath = Join-Path $ModelsDir "all-MiniLM-L6-v2.onnx"

Write-Host "Downloading all-MiniLM-L6-v2 ONNX model..." -ForegroundColor Cyan
Write-Host "Model URL: $ModelUrl" -ForegroundColor Gray
Write-Host "Download URL: $AltUrl" -ForegroundColor Gray

try {
    Write-Host "Attempting download from jsDelivr CDN..." -ForegroundColor Yellow
    $ProgressPreference = 'SilentlyContinue'

    if (-not (Test-Path $ModelsDir)) {
        New-Item -ItemType Directory -Path $ModelsDir -Force | Out-Null
    }

    # Only accept a real curl executable: in Windows PowerShell "curl" is an alias for
    # Invoke-WebRequest and does not understand -L / -o.
    $curl = Get-Command curl -CommandType Application -ErrorAction SilentlyContinue | Select-Object -First 1

    if ($curl) {
        # --fail makes curl return a non-zero exit code on HTTP errors instead of saving the error page.
        & $curl.Source -L --fail -o $OutputPath $AltUrl
        if ($LASTEXITCODE -ne 0) {
            throw "curl exited with code $LASTEXITCODE"
        }
    } else {
        Invoke-WebRequest -Uri $AltUrl -OutFile $OutputPath -UseBasicParsing -ErrorAction Stop
    }

    if ((Test-Path $OutputPath) -and ((Get-Item $OutputPath).Length -gt 0)) {
        $size = (Get-Item $OutputPath).Length / 1MB
        Write-Host "Successfully downloaded model to $OutputPath" -ForegroundColor Green
        Write-Host "File size: $([math]::Round($size, 2)) MB" -ForegroundColor Gray
    } else {
        throw "Download failed - file not created"
    }
}
catch {
    Write-Host "Error downloading model: $_" -ForegroundColor Red
    Write-Host "" -ForegroundColor White
    Write-Host "Manual download instructions:" -ForegroundColor White
    Write-Host "1. Visit: https://huggingface.co/Xenova/all-MiniLM-L6-v2" -ForegroundColor Gray
    Write-Host "2. Download 'model_quantized.onnx'" -ForegroundColor Gray
    Write-Host "3. Save it to: VectorSearchApp/Models/all-MiniLM-L6-v2.onnx" -ForegroundColor Gray
    exit 1
}
|
||||||
@@ -21,6 +21,16 @@ configuration.GetSection("App").Bind(appConfig.App);
|
|||||||
// Initialize services
|
// Initialize services
|
||||||
Console.WriteLine("Initializing services...");
|
Console.WriteLine("Initializing services...");
|
||||||
var embeddingService = new EmbeddingService(appConfig.Embedding);
|
var embeddingService = new EmbeddingService(appConfig.Embedding);
|
||||||
|
|
||||||
|
// Warn about a missing HuggingFace API token only when the remote API is actually used;
// local ONNX inference does not need a token.
if (!appConfig.Embedding.UseLocalInference && string.IsNullOrEmpty(appConfig.Embedding.ApiToken))
{
    Console.WriteLine("Warning: HuggingFace API token is not configured.");
    Console.WriteLine("The application may have limited functionality without authentication.");
    Console.WriteLine("To get a free token, visit: https://huggingface.co/settings/tokens");
    Console.WriteLine();
}
|
||||||
|
|
||||||
IQdrantService? qdrantService = null;
|
IQdrantService? qdrantService = null;
|
||||||
|
|
||||||
try
|
try
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
using System.Net.Http.Json;
|
using System.Text;
|
||||||
|
using Microsoft.ML.OnnxRuntime;
|
||||||
|
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||||
using VectorSearchApp.Configuration;
|
using VectorSearchApp.Configuration;
|
||||||
using VectorSearchApp.Models;
|
using VectorSearchApp.Models;
|
||||||
|
|
||||||
@@ -11,42 +13,347 @@ public interface IEmbeddingService
|
|||||||
|
|
||||||
public class EmbeddingService : IEmbeddingService
|
public class EmbeddingService : IEmbeddingService
|
||||||
{
|
{
|
||||||
private readonly HttpClient _httpClient;
|
private readonly HttpClient? _httpClient;
|
||||||
private readonly string _modelName;
|
private readonly string _modelName;
|
||||||
private readonly int _dimension;
|
private readonly int _dimension;
|
||||||
|
private readonly bool _useLocalInference;
|
||||||
|
private readonly InferenceSession? _inferenceSession;
|
||||||
|
private readonly Dictionary<string, int>? _vocabulary;
|
||||||
|
|
||||||
|
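// Hand-written demo vocabulary (special tokens, ASCII characters, a few common words).
// NOTE(review): these ids appear to be assigned sequentially here rather than taken from the
// model's own vocab.txt - confirm they match the token ids the ONNX model expects.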
|
||||||
|
private static readonly Dictionary<string, int> BaseVocabulary = new()
|
||||||
|
{
|
||||||
|
["[PAD]"] = 0, ["[UNK]"] = 1, ["[CLS]"] = 2, ["[SEP]"] = 3, ["[MASK]"] = 4,
|
||||||
|
["!"] = 5, ["\""] = 6, ["#"] = 7, ["$"] = 8, ["%"] = 9, ["&"] = 10, ["'"] = 11,
|
||||||
|
["("] = 12, [")"] = 13, ["*"] = 14, ["+"] = 15, [","] = 16, ["-"] = 17, ["."] = 18,
|
||||||
|
["/"] = 19, [":"] = 20, [";"] = 21, ["="] = 22, ["?"] = 23, ["@"] = 24, ["["] = 25,
|
||||||
|
["]"] = 26, ["_"] = 27, ["`"] = 28, ["{"] = 29, ["}"] = 30, ["~"] = 31, ["0"] = 32,
|
||||||
|
["1"] = 33, ["2"] = 34, ["3"] = 35, ["4"] = 36, ["5"] = 37, ["6"] = 38, ["7"] = 39,
|
||||||
|
["8"] = 40, ["9"] = 41, ["a"] = 42, ["b"] = 43, ["c"] = 44, ["d"] = 45, ["e"] = 46,
|
||||||
|
["f"] = 47, ["g"] = 48, ["h"] = 49, ["i"] = 50, ["j"] = 51, ["k"] = 52, ["l"] = 53,
|
||||||
|
["m"] = 54, ["n"] = 55, ["o"] = 56, ["p"] = 57, ["q"] = 58, ["r"] = 59, ["s"] = 60,
|
||||||
|
["t"] = 61, ["u"] = 62, ["v"] = 63, ["w"] = 64, ["x"] = 65, ["y"] = 66, ["z"] = 67,
|
||||||
|
["A"] = 68, ["B"] = 69, ["C"] = 70, ["D"] = 71, ["E"] = 72, ["F"] = 73, ["G"] = 74,
|
||||||
|
["H"] = 75, ["I"] = 76, ["J"] = 77, ["K"] = 78, ["L"] = 79, ["M"] = 80, ["N"] = 81,
|
||||||
|
["O"] = 82, ["P"] = 83, ["Q"] = 84, ["R"] = 85, ["S"] = 86, ["T"] = 87, ["U"] = 88,
|
||||||
|
["V"] = 89, ["W"] = 90, ["X"] = 91, ["Y"] = 92, ["Z"] = 93, ["the"] = 94, ["of"] = 95,
|
||||||
|
["and"] = 96, ["to"] = 97, ["in"] = 98, ["is"] = 99, ["it"] = 100, ["that"] = 101,
|
||||||
|
["for"] = 102, ["was"] = 103, ["on"] = 104, ["with"] = 105, ["as"] = 106, ["at"] = 107,
|
||||||
|
["by"] = 108, ["this"] = 109, ["be"] = 110, ["not"] = 111, ["from"] = 112, ["or"] = 113,
|
||||||
|
["which"] = 114, ["but"] = 115, ["are"] = 116, ["we"] = 117, ["have"] = 118, ["an"] = 119,
|
||||||
|
["had"] = 120, ["they"] = 121, ["his"] = 122, ["her"] = 123, ["she"] = 124, ["he"] = 125,
|
||||||
|
["you"] = 126, ["their"] = 127, ["what"] = 128, ["all"] = 129, ["were"] = 130, ["when"] = 131,
|
||||||
|
["your"] = 132, ["can"] = 133, ["there"] = 134, ["has"] = 135, ["been"] = 136, ["if"] = 137,
|
||||||
|
["who"] = 138, ["more"] = 139, ["will"] = 140, ["up"] = 141, ["one"] = 142, ["time"] = 143,
|
||||||
|
["out"] = 144, ["about"] = 145, ["into"] = 146, ["just"] = 147, ["him"] = 148, ["them"] = 149,
|
||||||
|
["me"] = 150, ["my"] = 151, ["than"] = 152, ["now"] = 153, ["do"] = 154, ["would"] = 155,
|
||||||
|
["so"] = 156, ["get"] = 157, ["no"] = 158, ["see"] = 159, ["way"] = 160, ["could"] = 161,
|
||||||
|
["like"] = 162, ["other"] = 163, ["then"] = 164, ["first"] = 165, ["also"] = 166, ["back"] = 167,
|
||||||
|
["after"] = 168, ["use"] = 169, ["two"] = 170, ["how"] = 171, ["our"] = 172, ["work"] = 173,
|
||||||
|
["well"] = 174, ["even"] = 175, ["new"] = 176, ["want"] = 177, ["because"] = 178, ["any"] = 179,
|
||||||
|
["these"] = 180, ["give"] = 181, ["day"] = 182, ["most"] = 183, ["us"] = 184, ["address"] = 185,
|
||||||
|
["street"] = 186, ["city"] = 187, ["state"] = 188, ["zip"] = 189, ["code"] = 190, ["name"] = 191,
|
||||||
|
["number"] = 192
|
||||||
|
};
|
||||||
|
|
||||||
/// <summary>
/// Creates the service in one of two modes, chosen by <see cref="EmbeddingConfiguration.UseLocalInference"/>:
/// local ONNX inference (loads the model file and the built-in vocabulary) or the remote
/// HuggingFace HTTP API (creates an <see cref="HttpClient"/>).
/// </summary>
/// <param name="config">Model name, expected embedding dimension, mode flag and optional API token.</param>
/// <exception cref="InvalidOperationException">
/// Local inference was requested but the model could not be resolved or loaded; the original
/// failure is available as <see cref="Exception.InnerException"/>.
/// </exception>
public EmbeddingService(EmbeddingConfiguration config)
{
    _modelName = config.ModelName;
    _dimension = config.Dimension;
    _useLocalInference = config.UseLocalInference;

    if (_useLocalInference)
    {
        // Use local ONNX inference
        _httpClient = null;
        try
        {
            var modelPath = GetModelPath(_modelName);
            _inferenceSession = new InferenceSession(modelPath);
            _vocabulary = BaseVocabulary;
        }
        catch (Exception ex)
        {
            // Keep the original exception as InnerException so the root cause and stack trace survive.
            throw new InvalidOperationException($"Failed to initialize local inference session: {ex.Message}", ex);
        }
    }
    else
    {
        // Use HuggingFace Inference API
        _httpClient = new HttpClient
        {
            BaseAddress = new Uri("https://router.huggingface.co/")
        };
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "VectorSearchApp");

        // Send the configured token; without this the ApiToken setting has no effect.
        if (!string.IsNullOrWhiteSpace(config.ApiToken))
        {
            _httpClient.DefaultRequestHeaders.Authorization =
                new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", config.ApiToken);
        }
    }
}
|
||||||
|
|
||||||
public async Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
|
/// <summary>
/// Maps a model name to its ONNX file and returns the first location where that file exists:
/// <c>Models/</c> under the application base directory, then <c>Models/</c> relative to the
/// current directory.
/// </summary>
/// <param name="modelName">Model identifier, e.g. <c>sentence-transformers/all-MiniLM-L6-v2</c>.</param>
/// <returns>Path of an existing model file.</returns>
/// <exception cref="NotSupportedException">The model has no local ONNX file mapping.</exception>
/// <exception cref="FileNotFoundException">The model file exists in neither probed location.</exception>
private static string GetModelPath(string modelName)
{
    var modelFileName = modelName switch
    {
        "sentence-transformers/all-MiniLM-L6-v2" => "all-MiniLM-L6-v2.onnx",
        _ => throw new NotSupportedException($"Model '{modelName}' is not supported for local inference")
    };

    var candidates = new[]
    {
        Path.Combine(AppContext.BaseDirectory, "Models", modelFileName),
        Path.Combine("Models", modelFileName)
    };

    foreach (var candidate in candidates)
    {
        if (File.Exists(candidate))
        {
            return candidate;
        }
    }

    // Fail here with the probed paths instead of handing a non-existent path to the ONNX runtime.
    throw new FileNotFoundException(
        $"Model file '{modelFileName}' was not found. Looked in: {string.Join(", ", candidates)}",
        candidates[0]);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Produces an embedding for <paramref name="text"/>, dispatching to the local or the remote
/// implementation depending on how the service was configured.
/// </summary>
/// <param name="text">Text to embed.</param>
/// <param name="cancellationToken">Token forwarded to the selected implementation.</param>
/// <returns>A task yielding the embedding vector.</returns>
public Task<float[]> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
    if (_useLocalInference)
    {
        return GenerateEmbeddingLocalAsync(text, cancellationToken);
    }

    return GenerateEmbeddingRemoteAsync(text, cancellationToken);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Tokenizes <paramref name="text"/>, runs the local ONNX session on a fixed-length (128) input and
/// mean-pools the per-token outputs over the attended positions into one vector of
/// <c>_dimension</c> floats. The work is synchronous; the result is returned as a completed task.
/// </summary>
/// <param name="text">Text to embed.</param>
/// <param name="cancellationToken">Checked once before inference starts.</param>
/// <returns>A completed task with the pooled embedding, or a cancelled task.</returns>
/// <exception cref="InvalidOperationException">
/// The inference session was not created, or the model output does not have
/// <c>128 * _dimension</c> values.
/// </exception>
private Task<float[]> GenerateEmbeddingLocalAsync(string text, CancellationToken cancellationToken)
{
    if (_inferenceSession == null)
    {
        throw new InvalidOperationException("Local inference session is not initialized");
    }

    if (cancellationToken.IsCancellationRequested)
    {
        return Task.FromCanceled<float[]>(cancellationToken);
    }

    const int maxLength = 128;
    // Ids of the special tokens as declared in BaseVocabulary.
    const long clsTokenId = 2; // [CLS]
    const long sepTokenId = 3; // [SEP]

    // Tokenize the input text using word-piece tokenization
    var tokens = Tokenize(text);

    // Two positions are reserved for [CLS] and [SEP]; longer inputs are truncated.
    var contentLength = Math.Min(tokens.Count, maxLength - 2);
    var sequenceLength = contentLength + 2;

    var inputIds = new long[maxLength];
    var attentionMask = new long[maxLength];
    // Token type IDs (all zeros for single sentence input)
    var tokenTypeIds = new long[maxLength];

    inputIds[0] = clsTokenId;
    for (int i = 0; i < contentLength; i++)
    {
        inputIds[i + 1] = tokens[i];
    }
    inputIds[sequenceLength - 1] = sepTokenId;

    // Attend to [CLS], the content tokens and [SEP]; the padding tail stays 0.
    Array.Fill(attentionMask, 1L, 0, sequenceLength);

    var inputs = new List<NamedOnnxValue>
    {
        NamedOnnxValue.CreateFromTensor("input_ids", new DenseTensor<long>(inputIds, new[] { 1, maxLength })),
        NamedOnnxValue.CreateFromTensor("attention_mask", new DenseTensor<long>(attentionMask, new[] { 1, maxLength })),
        NamedOnnxValue.CreateFromTensor("token_type_ids", new DenseTensor<long>(tokenTypeIds, new[] { 1, maxLength }))
    };

    // The result collection is disposable; release it once the values are copied out.
    using var results = _inferenceSession.Run(inputs);

    var output = results.FirstOrDefault(r => r.Name.Contains("last_hidden_state"))
                 ?? results.First();

    var hiddenStates = output.AsEnumerable<float>().ToArray();

    // The pooling below indexes the output as [position * _dimension + i]; reject any other size
    // rather than reading with the wrong stride or running past the end.
    var expectedLength = maxLength * _dimension;
    if (hiddenStates.Length != expectedLength)
    {
        throw new InvalidOperationException(
            $"Unexpected model output size {hiddenStates.Length}; expected {expectedLength} ({maxLength} x {_dimension})");
    }

    // Apply mean pooling over the attended positions (the first sequenceLength entries).
    var pooledOutput = new float[_dimension];
    for (int i = 0; i < _dimension; i++)
    {
        float sum = 0;
        for (int j = 0; j < sequenceLength; j++)
        {
            sum += hiddenStates[j * _dimension + i];
        }
        pooledOutput[i] = sum / sequenceLength;
    }

    return Task.FromResult(pooledOutput);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Converts <paramref name="text"/> into vocabulary ids: the text is split into words, one-character
/// words are looked up directly (exact case first, then lower-cased), and longer words are broken
/// into vocabulary pieces. Unknown single characters map to the [UNK] id.
/// </summary>
/// <param name="text">Text to tokenize.</param>
/// <returns>Token ids in input order, without [CLS]/[SEP].</returns>
/// <exception cref="InvalidOperationException">The vocabulary was not loaded.</exception>
private List<int> Tokenize(string text)
{
    const int unknownTokenId = 1; // [UNK]

    var vocabulary = _vocabulary
        ?? throw new InvalidOperationException("Vocabulary is not initialized");

    var tokens = new List<int>();

    foreach (var word in WordTokenize(text))
    {
        if (word.Length > 1)
        {
            // Multi-character word - try word-piece tokenization
            tokens.AddRange(WordPieceTokenize(word, vocabulary));
            continue;
        }

        // Single character: exact match first, then a culture-independent lower-case match.
        if (vocabulary.TryGetValue(word, out var id)
            || vocabulary.TryGetValue(word.ToLowerInvariant(), out id))
        {
            tokens.Add(id);
        }
        else
        {
            tokens.Add(unknownTokenId);
        }
    }

    return tokens;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Splits <paramref name="text"/> into words. Whitespace and punctuation characters act as
/// separators and are not returned; empty segments are skipped.
/// </summary>
/// <param name="text">Text to split.</param>
/// <returns>The words in order of appearance.</returns>
private static List<string> WordTokenize(string text)
{
    var words = new List<string>();
    var wordStart = -1; // index where the current word began, or -1 when between words

    for (var index = 0; index < text.Length; index++)
    {
        var current = text[index];
        var isSeparator = char.IsWhiteSpace(current) || char.IsPunctuation(current);

        if (!isSeparator)
        {
            if (wordStart < 0)
            {
                wordStart = index;
            }
            continue;
        }

        if (wordStart >= 0)
        {
            words.Add(text.Substring(wordStart, index - wordStart));
            wordStart = -1;
        }
    }

    // Flush a word that runs to the end of the text.
    if (wordStart >= 0)
    {
        words.Add(text.Substring(wordStart));
    }

    return words;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Greedily breaks <paramref name="word"/> into vocabulary pieces. At each position the piece that
/// consumes the most characters wins; a key starting with <c>##</c> matches on the text after that
/// prefix. All comparisons are ordinal. When nothing matches at the current position, the id of the
/// whole lower-cased word is appended if the vocabulary has it, otherwise [UNK], and processing stops.
/// </summary>
/// <param name="word">Word to split; matching is case-sensitive.</param>
/// <param name="vocab">Piece-to-id table.</param>
/// <returns>Ids of the matched pieces in order.</returns>
private static List<int> WordPieceTokenize(string word, Dictionary<string, int> vocab)
{
    const string continuationPrefix = "##";
    const int unknownTokenId = 1; // [UNK]

    var tokens = new List<int>();
    var position = 0;

    while (position < word.Length)
    {
        // Try to find the piece that consumes the most characters at this position
        var bestId = 0;
        var bestLength = 0;

        foreach (var (piece, id) in vocab)
        {
            var pieceOffset = piece.StartsWith(continuationPrefix, StringComparison.Ordinal)
                ? continuationPrefix.Length
                : 0;

            // Rank by characters actually consumed, not by key length including "##".
            // This also skips zero-length pieces ("" or a bare "##"), which could never advance.
            var consumed = piece.Length - pieceOffset;
            if (consumed <= bestLength)
            {
                continue;
            }

            if (string.CompareOrdinal(word, position, piece, pieceOffset, consumed) == 0)
            {
                bestId = id;
                bestLength = consumed;
            }
        }

        if (bestLength == 0)
        {
            // No match found
            tokens.Add(vocab.TryGetValue(word.ToLowerInvariant(), out var wordId) ? wordId : unknownTokenId);
            break;
        }

        tokens.Add(bestId);
        position += bestLength;
    }

    return tokens;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Requests an embedding for <paramref name="text"/> from the remote feature-extraction endpoint
/// and returns the first row of the returned array of vectors.
/// </summary>
/// <param name="text">Text to embed.</param>
/// <param name="cancellationToken">Token passed to the HTTP call and the body read.</param>
/// <returns>The first embedding vector in the response.</returns>
/// <exception cref="InvalidOperationException">
/// The HTTP client was not created, the request returned a non-success status, the body could not
/// be parsed as an array of float arrays, or the parsed result was empty.
/// </exception>
private async Task<float[]> GenerateEmbeddingRemoteAsync(string text, CancellationToken cancellationToken)
{
    if (_httpClient is null)
    {
        throw new InvalidOperationException("HTTP client is not initialized");
    }

    var url = $"pipeline/feature-extraction/{_modelName}";
    var request = new
    {
        inputs = text,
        options = new
        {
            wait_for_model = true
        }
    };

    var jsonContent = System.Text.Json.JsonSerializer.Serialize(request);
    using var httpContent = new StringContent(jsonContent, System.Text.Encoding.UTF8, "application/json");
    using var response = await _httpClient.PostAsync(url, httpContent, cancellationToken);

    // Read the body once; it is needed for the error message, for parsing and for parse diagnostics.
    var body = await response.Content.ReadAsStringAsync(cancellationToken);

    if (!response.IsSuccessStatusCode)
    {
        throw new InvalidOperationException($"Failed to generate embedding: {response.StatusCode} - {body}");
    }

    float[][]? result;
    try
    {
        result = System.Text.Json.JsonSerializer.Deserialize<float[][]>(body);
    }
    catch (System.Text.Json.JsonException ex)
    {
        throw new InvalidOperationException(
            $"Failed to parse embedding response: {ex.Message}. Raw response: {body.Substring(0, Math.Min(500, body.Length))}",
            ex);
    }

    if (result is { Length: > 0 } && result[0] is { Length: > 0 } embedding)
    {
        if (embedding.Length != _dimension)
        {
            Console.WriteLine($"Warning: Embedding dimension ({embedding.Length}) differs from expected ({_dimension})");
        }

        return embedding;
    }

    throw new InvalidOperationException("Failed to generate embedding: empty result");
}
|
||||||
}
|
}
|
||||||
@@ -18,6 +18,9 @@
|
|||||||
<None Update="appsettings.json">
|
<None Update="appsettings.json">
|
||||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||||
</None>
|
</None>
|
||||||
|
<None Update="Models\all-MiniLM-L6-v2.onnx">
|
||||||
|
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||||
|
</None>
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
|
|
||||||
</Project>
|
</Project>
|
||||||
|
|||||||
@@ -7,7 +7,9 @@
|
|||||||
},
|
},
|
||||||
"Embedding": {
|
"Embedding": {
|
||||||
"ModelName": "sentence-transformers/all-MiniLM-L6-v2",
|
"ModelName": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
"Dimension": 384
|
"Dimension": 384,
|
||||||
|
"ApiToken": "",
|
||||||
|
"UseLocalInference": true
|
||||||
},
|
},
|
||||||
"App": {
|
"App": {
|
||||||
"BatchSize": 10
|
"BatchSize": 10
|
||||||
|
|||||||
Reference in New Issue
Block a user