Added the ability to filter search results by similarity score in a couple of ways: a score threshold and a configurable result limit.

This commit is contained in:
2026-01-13 23:57:09 -05:00
parent 4dac36605b
commit b2c6898140
3 changed files with 25 additions and 7 deletions

View File

@@ -5,4 +5,5 @@ public class AddressEmbedding
public Guid Id { get; set; } public Guid Id { get; set; }
public string FullAddress { get; set; } = string.Empty; public string FullAddress { get; set; } = string.Empty;
public float[] Vector { get; set; } = Array.Empty<float>(); public float[] Vector { get; set; } = Array.Empty<float>();
public float Score { get; set; }
} }

View File

@@ -156,13 +156,29 @@ async Task SearchAddressesAsync(IEmbeddingService embeddingService, IQdrantServi
return; return;
} }
Console.Write("Enter similarity threshold (0.0-1.0, lower = stricter, press Enter for 0.7): ");
var thresholdInput = Console.ReadLine()?.Trim();
float threshold = 0.7f;
if (!string.IsNullOrEmpty(thresholdInput) && float.TryParse(thresholdInput, out var parsedThreshold))
{
threshold = Math.Clamp(parsedThreshold, 0f, 1f);
}
Console.Write("Enter max results to return (press Enter for 5, or '2' for top 2 most similar): ");
var limitInput = Console.ReadLine()?.Trim();
int limit = 5;
if (!string.IsNullOrEmpty(limitInput) && int.TryParse(limitInput, out var parsedLimit))
{
limit = Math.Max(1, parsedLimit);
}
Console.WriteLine("Generating query embedding..."); Console.WriteLine("Generating query embedding...");
try try
{ {
var queryEmbedding = await embeddingService.GenerateEmbeddingAsync(query); var queryEmbedding = await embeddingService.GenerateEmbeddingAsync(query);
Console.WriteLine($"Searching for similar addresses..."); Console.WriteLine($"Searching for similar addresses (threshold: {threshold:F2}, max results: {limit})...");
var results = await qdrantService.SearchSimilarAddressesAsync(queryEmbedding, limit: 5); var results = await qdrantService.SearchSimilarAddressesAsync(queryEmbedding, limit: limit, scoreThreshold: threshold);
if (results.Count == 0) if (results.Count == 0)
{ {
@@ -173,7 +189,7 @@ async Task SearchAddressesAsync(IEmbeddingService embeddingService, IQdrantServi
Console.WriteLine($"\nFound {results.Count} similar address(es):"); Console.WriteLine($"\nFound {results.Count} similar address(es):");
for (int i = 0; i < results.Count; i++) for (int i = 0; i < results.Count; i++)
{ {
Console.WriteLine($" {i + 1}. {results[i].FullAddress}"); Console.WriteLine($" {i + 1}. {results[i].FullAddress} (Score: {results[i].Score:F4})");
} }
} }
catch (Exception ex) catch (Exception ex)

View File

@@ -9,7 +9,7 @@ public interface IQdrantService
{ {
Task InitializeCollectionAsync(CancellationToken cancellationToken = default); Task InitializeCollectionAsync(CancellationToken cancellationToken = default);
Task StoreAddressAsync(Address address, float[] embedding, CancellationToken cancellationToken = default); Task StoreAddressAsync(Address address, float[] embedding, CancellationToken cancellationToken = default);
Task<List<AddressEmbedding>> SearchSimilarAddressesAsync(float[] queryEmbedding, int limit = 5, CancellationToken cancellationToken = default); Task<List<AddressEmbedding>> SearchSimilarAddressesAsync(float[] queryEmbedding, int limit = 5, float scoreThreshold = 0.0f, CancellationToken cancellationToken = default);
} }
public class QdrantService : IQdrantService public class QdrantService : IQdrantService
@@ -55,15 +55,16 @@ public class QdrantService : IQdrantService
await _client.UpsertAsync(_collectionName, new[] { point }, cancellationToken: cancellationToken); await _client.UpsertAsync(_collectionName, new[] { point }, cancellationToken: cancellationToken);
} }
public async Task<List<AddressEmbedding>> SearchSimilarAddressesAsync(float[] queryEmbedding, int limit = 5, CancellationToken cancellationToken = default) public async Task<List<AddressEmbedding>> SearchSimilarAddressesAsync(float[] queryEmbedding, int limit = 5, float scoreThreshold = 0.0f, CancellationToken cancellationToken = default)
{ {
var results = await _client.SearchAsync(_collectionName, queryEmbedding, limit: (ulong)limit, cancellationToken: cancellationToken); var results = await _client.SearchAsync(_collectionName, queryEmbedding, limit: (ulong)limit, scoreThreshold: scoreThreshold, cancellationToken: cancellationToken);
return results.Select(r => new AddressEmbedding return results.Select(r => new AddressEmbedding
{ {
Id = Guid.Parse(r.Id.Uuid), Id = Guid.Parse(r.Id.Uuid),
FullAddress = r.Payload["address"].StringValue, FullAddress = r.Payload["address"].StringValue,
Vector = Array.Empty<float>() Vector = Array.Empty<float>(),
Score = r.Score
}).ToList(); }).ToList();
} }
} }