From adf09b08c85804787bb68f7097f8dfa3a0ceebcd Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 24 Dec 2021 05:54:16 +0300 Subject: [PATCH] HNSW Fix INSERT algorithm --- TestHNSW/HNSWDemo/Model/Gender.cs | 7 + TestHNSW/HNSWDemo/Model/Person.cs | 51 ++ TestHNSW/HNSWDemo/Program.cs | 851 +----------------- TestHNSW/HNSWDemo/Tests/AccuracityTest.cs | 75 ++ TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs | 26 + TestHNSW/HNSWDemo/Tests/HistogramTest.cs | 69 ++ TestHNSW/HNSWDemo/Tests/ITest.cs | 7 + .../HNSWDemo/Tests/InsertTimeExplosionTest.cs | 28 + TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs | 43 + .../HNSWDemo/Tests/QuantizeAccuracityTest.cs | 79 ++ .../HNSWDemo/Tests/QuantizeHistogramTest.cs | 71 ++ .../Tests/QuantizeInsertTimeExplosionTest.cs | 31 + TestHNSW/HNSWDemo/Tests/SaveRestoreTest.cs | 52 ++ .../HNSWDemo/Utils/QLVectorsDirectCompare.cs | 95 ++ .../HNSWDemo/Utils/QVectorsDirectCompare.cs | 95 ++ .../HNSWDemo/Utils/VectorsDirectCompare.cs | 95 ++ ZeroLevel.HNSW/Model/NSWOptions.cs | 6 + .../Services/AutomaticGraphClusterer.cs | 37 +- ZeroLevel.HNSW/Services/Layer.cs | 110 ++- ZeroLevel.HNSW/Services/LinksSet.cs | 166 +--- ZeroLevel.HNSW/Services/VectorSet.cs | 15 +- ZeroLevel.HNSW/SmallWorld.cs | 105 +-- ZeroLevel.HNSW/Utils/CosineDistance.cs | 1 + ZeroLevel.HNSW/Utils/VectorUtils.cs | 13 + 24 files changed, 1026 insertions(+), 1102 deletions(-) create mode 100644 TestHNSW/HNSWDemo/Model/Gender.cs create mode 100644 TestHNSW/HNSWDemo/Model/Person.cs create mode 100644 TestHNSW/HNSWDemo/Tests/AccuracityTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/HistogramTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/ITest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs create mode 100644 TestHNSW/HNSWDemo/Tests/SaveRestoreTest.cs create mode 100644 TestHNSW/HNSWDemo/Utils/QLVectorsDirectCompare.cs create mode 100644 TestHNSW/HNSWDemo/Utils/QVectorsDirectCompare.cs create mode 100644 TestHNSW/HNSWDemo/Utils/VectorsDirectCompare.cs diff --git a/TestHNSW/HNSWDemo/Model/Gender.cs b/TestHNSW/HNSWDemo/Model/Gender.cs new file mode 100644 index 0000000..518e5a3 --- /dev/null +++ b/TestHNSW/HNSWDemo/Model/Gender.cs @@ -0,0 +1,7 @@ +namespace HNSWDemo.Model +{ + public enum Gender + { + Unknown, Male, Feemale + } +} diff --git a/TestHNSW/HNSWDemo/Model/Person.cs b/TestHNSW/HNSWDemo/Model/Person.cs new file mode 100644 index 0000000..bcd031c --- /dev/null +++ b/TestHNSW/HNSWDemo/Model/Person.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Model +{ + public class Person + { + public Gender Gender { get; set; } + public int Age { get; set; } + public long Number { get; set; } + + private static (float[], Person) Generate(int vector_size) + { + var rnd = new Random((int)Environment.TickCount); + var vector = new float[vector_size]; + DefaultRandomGenerator.Instance.NextFloats(vector); + VectorUtils.NormalizeSIMD(vector); + var p = new Person(); + p.Age = rnd.Next(15, 80); + var gr = rnd.Next(0, 3); + p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? 
Gender.Feemale : Gender.Unknown; + p.Number = CreateNumber(rnd); + return (vector, p); + } + + public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount) + { + var vectors = new List<(float[], Person)>(); + for (int i = 0; i < vectorsCount; i++) + { + vectors.Add(Generate(vectorSize)); + } + return vectors; + } + + static HashSet _exists = new HashSet(); + private static long CreateNumber(Random rnd) + { + long start_number; + do + { + start_number = 79600000000L; + start_number = start_number + rnd.Next(4, 8) * 10000000; + start_number += rnd.Next(0, 1000000); + } + while (_exists.Add(start_number) == false); + return start_number; + } + } +} diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index b993ef6..487917c 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -1,596 +1,17 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Drawing; -using System.IO; -using System.Linq; -using ZeroLevel.HNSW; -using ZeroLevel.HNSW.Services; +using HNSWDemo.Tests; +using System; namespace HNSWDemo { class Program { - public class VectorsDirectCompare - { - private const int HALF_LONG_BITS = 32; - private readonly IList _vectors; - private readonly Func _distance; - - public VectorsDirectCompare(List vectors, Func distance) - { - _vectors = vectors; - _distance = distance; - } - - public IEnumerable<(int, float)> KNearest(float[] v, int k) - { - var weights = new Dictionary(); - for (int i = 0; i < _vectors.Count; i++) - { - var d = _distance(v, _vectors[i]); - weights[i] = d; - } - return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); - } - - public List> DetectClusters() - { - var links = new SortedList(); - for (int i = 0; i < _vectors.Count; i++) - { - for (int j = i + 1; j < _vectors.Count; j++) - { - long k = (((long)(i)) << HALF_LONG_BITS) + j; - links.Add(k, _distance(_vectors[i], _vectors[j])); - } - } - - // 1. Find R - bound between intra-cluster distances and out-of-cluster distances - var histogram = new Histogram(HistogramMode.SQRT, links.Values); - int threshold = histogram.OTSU(); - var min = histogram.Bounds[threshold - 1]; - var max = histogram.Bounds[threshold]; - var R = (max + min) / 2; - - // 2. Get links with distances less than R - var resultLinks = new SortedList(); - foreach (var pair in links) - { - if (pair.Value < R) - { - resultLinks.Add(pair.Key, pair.Value); - } - } - - // 3. 
Extract clusters - List> clusters = new List>(); - foreach (var pair in resultLinks) - { - var k = pair.Key; - var id1 = (int)(k >> HALF_LONG_BITS); - var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - - bool found = false; - foreach (var c in clusters) - { - if (c.Contains(id1)) - { - c.Add(id2); - found = true; - break; - } - else if (c.Contains(id2)) - { - c.Add(id1); - found = true; - break; - } - } - if (found == false) - { - var c = new HashSet(); - c.Add(id1); - c.Add(id2); - clusters.Add(c); - } - } - return clusters; - } - } - - public class QVectorsDirectCompare - { - private const int HALF_LONG_BITS = 32; - private readonly IList _vectors; - private readonly Func _distance; - - public QVectorsDirectCompare(List vectors, Func distance) - { - _vectors = vectors; - _distance = distance; - } - - public IEnumerable<(int, float)> KNearest(byte[] v, int k) - { - var weights = new Dictionary(); - for (int i = 0; i < _vectors.Count; i++) - { - var d = _distance(v, _vectors[i]); - weights[i] = d; - } - return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); - } - - public List> DetectClusters() - { - var links = new SortedList(); - for (int i = 0; i < _vectors.Count; i++) - { - for (int j = i + 1; j < _vectors.Count; j++) - { - long k = (((long)(i)) << HALF_LONG_BITS) + j; - links.Add(k, _distance(_vectors[i], _vectors[j])); - } - } - - // 1. Find R - bound between intra-cluster distances and out-of-cluster distances - var histogram = new Histogram(HistogramMode.SQRT, links.Values); - int threshold = histogram.OTSU(); - var min = histogram.Bounds[threshold - 1]; - var max = histogram.Bounds[threshold]; - var R = (max + min) / 2; - - // 2. Get links with distances less than R - var resultLinks = new SortedList(); - foreach (var pair in links) - { - if (pair.Value < R) - { - resultLinks.Add(pair.Key, pair.Value); - } - } - - // 3. Extract clusters - List> clusters = new List>(); - foreach (var pair in resultLinks) - { - var k = pair.Key; - var id1 = (int)(k >> HALF_LONG_BITS); - var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - - bool found = false; - foreach (var c in clusters) - { - if (c.Contains(id1)) - { - c.Add(id2); - found = true; - break; - } - else if (c.Contains(id2)) - { - c.Add(id1); - found = true; - break; - } - } - if (found == false) - { - var c = new HashSet(); - c.Add(id1); - c.Add(id2); - clusters.Add(c); - } - } - return clusters; - } - } - - public class QLVectorsDirectCompare - { - private const int HALF_LONG_BITS = 32; - private readonly IList _vectors; - private readonly Func _distance; - - public QLVectorsDirectCompare(List vectors, Func distance) - { - _vectors = vectors; - _distance = distance; - } - - public IEnumerable<(int, float)> KNearest(long[] v, int k) - { - var weights = new Dictionary(); - for (int i = 0; i < _vectors.Count; i++) - { - var d = _distance(v, _vectors[i]); - weights[i] = d; - } - return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); - } - - public List> DetectClusters() - { - var links = new SortedList(); - for (int i = 0; i < _vectors.Count; i++) - { - for (int j = i + 1; j < _vectors.Count; j++) - { - long k = (((long)(i)) << HALF_LONG_BITS) + j; - links.Add(k, _distance(_vectors[i], _vectors[j])); - } - } - - // 1. 
Find R - bound between intra-cluster distances and out-of-cluster distances - var histogram = new Histogram(HistogramMode.SQRT, links.Values); - int threshold = histogram.OTSU(); - var min = histogram.Bounds[threshold - 1]; - var max = histogram.Bounds[threshold]; - var R = (max + min) / 2; - - // 2. Get links with distances less than R - var resultLinks = new SortedList(); - foreach (var pair in links) - { - if (pair.Value < R) - { - resultLinks.Add(pair.Key, pair.Value); - } - } - - // 3. Extract clusters - List> clusters = new List>(); - foreach (var pair in resultLinks) - { - var k = pair.Key; - var id1 = (int)(k >> HALF_LONG_BITS); - var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - - bool found = false; - foreach (var c in clusters) - { - if (c.Contains(id1)) - { - c.Add(id2); - found = true; - break; - } - else if (c.Contains(id2)) - { - c.Add(id1); - found = true; - break; - } - } - if (found == false) - { - var c = new HashSet(); - c.Add(id1); - c.Add(id2); - clusters.Add(c); - } - } - return clusters; - } - } - - public enum Gender - { - Unknown, Male, Feemale - } - - public class Person - { - public Gender Gender { get; set; } - public int Age { get; set; } - public long Number { get; set; } - - private static (float[], Person) Generate(int vector_size) - { - var rnd = new Random((int)Environment.TickCount); - var vector = new float[vector_size]; - DefaultRandomGenerator.Instance.NextFloats(vector); - VectorUtils.NormalizeSIMD(vector); - var p = new Person(); - p.Age = rnd.Next(15, 80); - var gr = rnd.Next(0, 3); - p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? Gender.Feemale : Gender.Unknown; - p.Number = CreateNumber(rnd); - return (vector, p); - } - - public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount) - { - var vectors = new List<(float[], Person)>(); - for (int i = 0; i < vectorsCount; i++) - { - vectors.Add(Generate(vectorSize)); - } - return vectors; - } - - static HashSet _exists = new HashSet(); - private static long CreateNumber(Random rnd) - { - long start_number; - do - { - start_number = 79600000000L; - start_number = start_number + rnd.Next(4, 8) * 10000000; - start_number += rnd.Next(0, 1000000); - } - while (_exists.Add(start_number) == false); - return start_number; - } - } - - private static List RandomVectors(int vectorSize, int vectorsCount) - { - var vectors = new List(); - for (int i = 0; i < vectorsCount; i++) - { - var vector = new float[vectorSize]; - DefaultRandomGenerator.Instance.NextFloats(vector); - VectorUtils.NormalizeSIMD(vector); - vectors.Add(vector); - } - return vectors; - } - - static void Main(string[] args) { - QuantizatorTest(); + new AutoClusteringTest().Run(); Console.WriteLine("Completed"); Console.ReadKey(); } - - static void QAccuracityTest() - { - int K = 200; - var count = 5000; - var testCount = 500; - var dimensionality = 128; - var totalHits = new List(); - var timewatchesNP = new List(); - var timewatchesHNSW = new List(); - var q = new Quantizator(-1f, 1f); - - var samples = RandomVectors(dimensionality, count).Select(v => q.QuantizeToLong(v)).ToList(); - - var sw = new Stopwatch(); - - var test = new QLVectorsDirectCompare(samples, CosineDistance.NonOptimized); - var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, CosineDistance.NonOptimized)); - - sw.Start(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - - Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); - Console.WriteLine("Start test"); - - - var test_vectors = 
RandomVectors(dimensionality, testCount).Select(v => q.QuantizeToLong(v)).ToList(); - foreach (var v in test_vectors) - { - sw.Restart(); - var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); - sw.Stop(); - timewatchesNP.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var result = world.Search(v, K); - sw.Stop(); - - timewatchesHNSW.Add(sw.ElapsedMilliseconds); - var hits = 0; - foreach (var r in result) - { - if (gt.ContainsKey(r.Item1)) - { - hits++; - } - } - totalHits.Add(hits); - } - - Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); - Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); - Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); - - Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); - Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); - Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); - - Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); - Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); - Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); - } - - static void QInsertTimeExplosionTest() - { - var count = 10000; - var iterationCount = 100; - var dimensionality = 128; - var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); - var q = new Quantizator(-1f, 1f); - for (int i = 0; i < iterationCount; i++) - { - var samples = RandomVectors(dimensionality, count); - sw.Restart(); - var ids = world.AddItems(samples.Select(v => q.QuantizeToLong(v)).ToArray()); - sw.Stop(); - Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); - } - } - - static void AccuracityTest() - { - int K = 200; - var count = 3000; - var testCount = 500; - var dimensionality = 128; - var totalHits = new List(); - var timewatchesNP = new List(); - var timewatchesHNSW = new List(); - - var samples = RandomVectors(dimensionality, count); - - var sw = new Stopwatch(); - - var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized); - var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, CosineDistance.NonOptimized)); - - sw.Start(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - - /* - byte[] dump; - using (var ms = new MemoryStream()) - { - world.Serialize(ms); - dump = ms.ToArray(); - } - Console.WriteLine($"Full dump size: {dump.Length} bytes"); - - - ReadOnlySmallWorld world; - using (var ms = new MemoryStream(dump)) - { - world = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(100, CosineDistance.NonOptimized, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); - } - */ - - Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); - Console.WriteLine("Start test"); - - - var test_vectors = RandomVectors(dimensionality, testCount); - foreach (var v in test_vectors) - { - sw.Restart(); - var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); - sw.Stop(); - timewatchesNP.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var result = world.Search(v, K); - sw.Stop(); - - timewatchesHNSW.Add(sw.ElapsedMilliseconds); - var hits = 0; - foreach (var r in result) - { - if (gt.ContainsKey(r.Item1)) - { - hits++; - } - } - totalHits.Add(hits); - } - - Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); - Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); - Console.WriteLine($"MAX Accuracity: 
{totalHits.Max() * 100 / K}%"); - - Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); - Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); - Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); - - Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); - Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); - Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); - } - - static void QuantizatorTest() - { - var samples = RandomVectors(128, 500000); - var min = samples.SelectMany(s => s).Min(); - var max = samples.SelectMany(s => s).Max(); - var q = new Quantizator(min, max); - var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray(); - - // comparing - var list = new List(); - for (int i = 0; i < samples.Count - 1; i++) - { - var v1 = samples[i]; - var v2 = samples[i + 1]; - var dist = CosineDistance.NonOptimized(v1, v2); - - var qv1 = q_samples[i]; - var qv2 = q_samples[i + 1]; - var qdist = CosineDistance.NonOptimized(qv1, qv2); - - list.Add(Math.Abs(dist - qdist)); - } - - Console.WriteLine($"Min diff: {list.Min()}"); - Console.WriteLine($"Avg diff: {list.Average()}"); - Console.WriteLine($"Max diff: {list.Max()}"); - } - - static void SaveRestoreTest() - { - var count = 1000; - var dimensionality = 128; - var samples = RandomVectors(dimensionality, count); - var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); - var sw = new Stopwatch(); - sw.Start(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); - Console.WriteLine("Start test"); - - byte[] dump; - using (var ms = new MemoryStream()) - { - world.Serialize(ms); - dump = ms.ToArray(); - } - Console.WriteLine($"Full dump size: {dump.Length} bytes"); - - byte[] testDump; - var restoredWorld = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); - using (var ms = new MemoryStream(dump)) - { - restoredWorld.Deserialize(ms); - } - - using (var ms = new MemoryStream()) - { - restoredWorld.Serialize(ms); - testDump = ms.ToArray(); - } - if (testDump.Length != dump.Length) - { - Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. 
Expected: {dump.Length}"); - return; - } - } - - static void InsertTimeExplosionTest() - { - var count = 10000; - var iterationCount = 100; - var dimensionality = 128; - var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); - for (int i = 0; i < iterationCount; i++) - { - var samples = RandomVectors(dimensionality, count); - sw.Restart(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); - } - } - + /* static void TestOnMnist() { @@ -646,272 +67,8 @@ namespace HNSWDemo { Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); } - - } - - static void AutoClusteringTest() - { - var vectors = RandomVectors(128, 3000); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - world.AddItems(vectors); - var clusters = AutomaticGraphClusterer.DetectClusters(world); - Console.WriteLine($"Found {clusters.Count} clusters"); - for (int i = 0; i < clusters.Count; i++) - { - Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); - } - } - - static void HistogramTest() - { - var vectors = RandomVectors(128, 3000); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - world.AddItems(vectors); - var histogram = world.GetHistogram(); - - int threshold = histogram.OTSU(); - var min = histogram.Bounds[threshold - 1]; - var max = histogram.Bounds[threshold]; - var R = (max + min) / 2; - - DrawHistogram(histogram, @"D:\hist.jpg"); - } - - static void DrawHistogram(Histogram histogram, string filename) - { - var wb = 1200 / histogram.Values.Length; - var k = 600.0f / (float)histogram.Values.Max(); - - var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m); - int threshold = histogram.OTSU(); - - using (var bmp = new Bitmap(1200, 600)) - { - using (var g = Graphics.FromImage(bmp)) - { - for (int i = 0; i(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); - var ids = world.AddItems(samples.ToArray()); - - Console.WriteLine("Start test"); - - byte[] dump; - using (var ms = new MemoryStream()) - { - world.Serialize(ms); - dump = ms.ToArray(); - } - Console.WriteLine($"Full dump size: {dump.Length} bytes"); - - ReadOnlySmallWorld compactWorld; - using (var ms = new MemoryStream(dump)) - { - compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits), ms); - } - - // Compare worlds outputs - int K = 200; - var hits = 0; - var miss = 0; - var testCount = 1000; - var sw = new Stopwatch(); - var timewatchesHNSW = new List(); - var timewatchesHNSWCompact = new List(); - var test_vectors = RandomVectors(dimensionality, testCount); - - foreach (var v in test_vectors) - { - sw.Restart(); - var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet(); - sw.Stop(); - timewatchesHNSW.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet(); - sw.Stop(); - timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds); - - foreach (var r in result) - { - if (gt.Contains(r)) - { - hits++; - } - else - { - miss++; - } - } - } - - byte[] smallWorldDump; - using (var ms = new MemoryStream()) - { - compactWorld.Serialize(ms); - smallWorldDump = ms.ToArray(); - } - var p = 
smallWorldDump.Length * 100.0f / dump.Length; - Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%"); - - Console.WriteLine($"HITS: {hits}"); - Console.WriteLine($"MISSES: {miss}"); - - Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); - Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); - Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); - - Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms"); - Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms"); - Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms"); } - static void TransformToCompactWorldTestWithAccuracity() - { - var count = 10000; - var dimensionality = 128; - var samples = RandomVectors(dimensionality, count); - - var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits); - var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); - var ids = world.AddItems(samples.ToArray()); - - Console.WriteLine("Start test"); - - byte[] dump; - using (var ms = new MemoryStream()) - { - world.Serialize(ms); - dump = ms.ToArray(); - } - - ReadOnlySmallWorld compactWorld; - using (var ms = new MemoryStream(dump)) - { - compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits), ms); - } - - // Compare worlds outputs - int K = 200; - var hits = 0; - var miss = 0; - - var testCount = 2000; - var sw = new Stopwatch(); - var timewatchesNP = new List(); - var timewatchesHNSW = new List(); - var timewatchesHNSWCompact = new List(); - var test_vectors = RandomVectors(dimensionality, testCount); - - var totalHitsHNSW = new List(); - var totalHitsHNSWCompact = new List(); - - foreach (var v in test_vectors) - { - var npHitsHNSW = 0; - var npHitsHNSWCompact = 0; - - sw.Restart(); - var gtNP = test.KNearest(v, K).Select(p => p.Item1).ToHashSet(); - sw.Stop(); - timewatchesNP.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var gt = world.Search(v, K).Select(e => e.Item1).ToHashSet(); - sw.Stop(); - timewatchesHNSW.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var result = compactWorld.Search(v, K).Select(e => e.Item1).ToHashSet(); - sw.Stop(); - timewatchesHNSWCompact.Add(sw.ElapsedMilliseconds); - - foreach (var r in result) - { - if (gt.Contains(r)) - { - hits++; - } - else - { - miss++; - } - if (gtNP.Contains(r)) - { - npHitsHNSWCompact++; - } - } - - foreach (var r in gt) - { - if (gtNP.Contains(r)) - { - npHitsHNSW++; - } - } - - totalHitsHNSW.Add(npHitsHNSW); - totalHitsHNSWCompact.Add(npHitsHNSWCompact); - } - - byte[] smallWorldDump; - using (var ms = new MemoryStream()) - { - compactWorld.Serialize(ms); - smallWorldDump = ms.ToArray(); - } - var p = smallWorldDump.Length * 100.0f / dump.Length; - Console.WriteLine($"Full dump size: {dump.Length} bytes"); - Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. 
Decrease: {100 - p}%"); - - Console.WriteLine($"HITS: {hits}"); - Console.WriteLine($"MISSES: {miss}"); - - Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); - Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); - Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); - - Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); - Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); - Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); - - Console.WriteLine($"MIN HNSWCompact TIME: {timewatchesHNSWCompact.Min()} ms"); - Console.WriteLine($"AVG HNSWCompact TIME: {timewatchesHNSWCompact.Average()} ms"); - Console.WriteLine($"MAX HNSWCompact TIME: {timewatchesHNSWCompact.Max()} ms"); - - Console.WriteLine($"MIN HNSW Accuracity: {totalHitsHNSW.Min() * 100 / K}%"); - Console.WriteLine($"AVG HNSW Accuracity: {totalHitsHNSW.Average() * 100 / K}%"); - Console.WriteLine($"MAX HNSW Accuracity: {totalHitsHNSW.Max() * 100 / K}%"); - - Console.WriteLine($"MIN HNSWCompact Accuracity: {totalHitsHNSWCompact.Min() * 100 / K}%"); - Console.WriteLine($"AVG HNSWCompact Accuracity: {totalHitsHNSWCompact.Average() * 100 / K}%"); - Console.WriteLine($"MAX HNSWCompact Accuracity: {totalHitsHNSWCompact.Max() * 100 / K}%"); - } static void FilterTest() { diff --git a/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs b/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs new file mode 100644 index 0000000..fe229c5 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/AccuracityTest.cs @@ -0,0 +1,75 @@ +using HNSWDemo.Utils; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Tests +{ + public class AccuracityTest + : ITest + { + private static int K = 200; + private static int count = 3000; + private static int testCount = 500; + private static int dimensionality = 128; + + public void Run() + { + var totalHits = new List(); + var timewatchesNP = new List(); + var timewatchesHNSW = new List(); + + var samples = VectorUtils.RandomVectors(dimensionality, count); + + var sw = new Stopwatch(); + + var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized); + var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, CosineDistance.NonOptimized)); + + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + + Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + var test_vectors = VectorUtils.RandomVectors(dimensionality, testCount); + foreach (var v in test_vectors) + { + sw.Restart(); + var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = world.Search(v, K); + sw.Stop(); + + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalHits.Add(hits); + } + + Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + 
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs b/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs new file mode 100644 index 0000000..9bdbbe0 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/AutoClusteringTest.cs @@ -0,0 +1,26 @@ +using System; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services; + +namespace HNSWDemo.Tests +{ + public class AutoClusteringTest + : ITest + { + private static int Count = 3000; + private static int Dimensionality = 128; + + public void Run() + { + var vectors = VectorUtils.RandomVectors(Dimensionality, Count); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean)); + world.AddItems(vectors); + var clusters = AutomaticGraphClusterer.DetectClusters(world); + Console.WriteLine($"Found {clusters.Count} clusters"); + for (int i = 0; i < clusters.Count; i++) + { + Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/HistogramTest.cs b/TestHNSW/HNSWDemo/Tests/HistogramTest.cs new file mode 100644 index 0000000..1b804c3 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/HistogramTest.cs @@ -0,0 +1,69 @@ +using System; +using System.Drawing; +using System.Linq; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Tests +{ + public class HistogramTest + : ITest + { + private static int Count = 3000; + private static int Dimensionality = 128; + private static int Width = 3000; + private static int Height = 3000; + + public void Run() + { + var vectors = VectorUtils.RandomVectors(Dimensionality, Count); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean)); + world.AddItems(vectors); + + var distance = new Func((id1, id2) => Metrics.L2Euclidean(world.GetVector(id1), world.GetVector(id2))); + var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id))); + var histogram = new Histogram(HistogramMode.SQRT, weights); + histogram.Smooth(); + + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + DrawHistogram(histogram, @"D:\hist.jpg"); + } + + static void DrawHistogram(Histogram histogram, string filename) + { + var wb = Width / histogram.Values.Length; + var k = ((float)Height) / (float)histogram.Values.Max(); + + var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m); + int threshold = histogram.OTSU(); + + using (var bmp = new Bitmap(Width, Height)) + { + using (var g = Graphics.FromImage(bmp)) + { + for (int i = 0; i < histogram.Values.Length; i++) + { + var height = (int)(histogram.Values[i] * k); + if (maxes.ContainsKey(i)) + { + g.DrawRectangle(Pens.Red, i * wb, bmp.Height - height, wb, height); + g.DrawRectangle(Pens.Red, i * wb + 1, bmp.Height - height, wb - 1, height); + } + else + { + g.DrawRectangle(Pens.Blue, i * wb, bmp.Height - height, wb, height); + } + if (i == threshold) + { + g.DrawLine(Pens.Green, i * wb + wb / 2, 0, i * wb + wb / 2, bmp.Height); + } + } + } + bmp.Save(filename); + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/ITest.cs b/TestHNSW/HNSWDemo/Tests/ITest.cs new file mode 100644 index 0000000..6ee3d4c --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/ITest.cs @@ -0,0 +1,7 @@ +namespace HNSWDemo.Tests +{ + public interface ITest + { + void Run(); + } +} diff --git a/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs b/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs new file mode 
100644 index 0000000..2641b36 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/InsertTimeExplosionTest.cs @@ -0,0 +1,28 @@ +using System; +using System.Diagnostics; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Tests +{ + public class InsertTimeExplosionTest + : ITest + { + private static int Count = 10000; + private static int IterationCount = 100; + private static int Dimensionality = 128; + + public void Run() + { + var sw = new Stopwatch(); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); + for (int i = 0; i < IterationCount; i++) + { + var samples = VectorUtils.RandomVectors(Dimensionality, Count); + sw.Restart(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs new file mode 100644 index 0000000..a6dee3d --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/QuantizatorTest.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services; + +namespace HNSWDemo.Tests +{ + public class QuantizatorTest + : ITest + { + private static int Count = 500000; + private static int Dimensionality = 128; + + public void Run() + { + var samples = VectorUtils.RandomVectors(Dimensionality, Count); + var min = samples.SelectMany(s => s).Min(); + var max = samples.SelectMany(s => s).Max(); + var q = new Quantizator(min, max); + var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray(); + + // comparing + var list = new List(); + for (int i = 0; i < samples.Count - 1; i++) + { + var v1 = samples[i]; + var v2 = samples[i + 1]; + var dist = CosineDistance.NonOptimized(v1, v2); + + var qv1 = q_samples[i]; + var qv2 = q_samples[i + 1]; + var qdist = CosineDistance.NonOptimized(qv1, qv2); + + list.Add(Math.Abs(dist - qdist)); + } + + Console.WriteLine($"Min diff: {list.Min()}"); + Console.WriteLine($"Avg diff: {list.Average()}"); + Console.WriteLine($"Max diff: {list.Max()}"); + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs new file mode 100644 index 0000000..6e3bcc3 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/QuantizeAccuracityTest.cs @@ -0,0 +1,79 @@ +using HNSWDemo.Utils; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services; + +namespace HNSWDemo.Tests +{ + public class QuantizeAccuracityTest + : ITest + { + private static int Count = 5000; + private static int Dimensionality = 128; + private static int K = 200; + private static int TestCount =500; + + public void Run() + { + var totalHits = new List(); + var timewatchesNP = new List(); + var timewatchesHNSW = new List(); + var q = new Quantizator(-1f, 1f); + + var s = VectorUtils.RandomVectors(Dimensionality, Count); + var samples = s.Select(v => q.QuantizeToLong(v)).ToList(); + + var sw = new Stopwatch(); + + var test = new VectorsDirectCompare(s, CosineDistance.NonOptimized); + var world = new SmallWorld(NSWOptions.Create(6, 8, 100, 100, CosineDistance.NonOptimized)); + + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + + Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + var tv = VectorUtils.RandomVectors(Dimensionality, TestCount); + 
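
The QuantizatorTest and the quantized variants of the accuracy and insert tests all rely on Quantizator.QuantizeToLong to turn a float vector into a long vector before indexing. The quantizer itself is not part of this patch, so the following is only a sketch of what such a scalar quantizer could look like, assuming a linear mapping of floats in [min, max] to bytes with eight bytes packed per long; the class name and packing order are illustrative, not the actual ZeroLevel.HNSW implementation.

    using System;

    // Sketch only: linear scalar quantization of floats in [min, max] to bytes,
    // packed eight per long. The real Quantizator may use a different layout.
    public sealed class QuantizatorSketch
    {
        private readonly float _min;
        private readonly float _range;

        public QuantizatorSketch(float min, float max)
        {
            _min = min;
            _range = max - min;
        }

        private byte QuantizeValue(float v)
        {
            // Clamp to [min, max] and scale into the byte range 0..255.
            var normalized = (v - _min) / _range;
            if (normalized < 0f) normalized = 0f;
            if (normalized > 1f) normalized = 1f;
            return (byte)(normalized * 255f);
        }

        public long[] QuantizeToLong(float[] vector)
        {
            // Assumes the dimensionality is a multiple of 8 (128 in the demo tests).
            var packed = new long[vector.Length / 8];
            for (int i = 0; i < packed.Length; i++)
            {
                long value = 0;
                for (int j = 0; j < 8; j++)
                {
                    value = (value << 8) | QuantizeValue(vector[i * 8 + j]);
                }
                packed[i] = value;
            }
            return packed;
        }
    }
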
var test_vectors = tv.Select(v => q.QuantizeToLong(v)).ToList(); + for (int i = 0; i < tv.Count; i++) + { + sw.Restart(); + var gt = test.KNearest(tv[i], K).ToDictionary(p => p.Item1, p => p.Item2); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = world.Search(test_vectors[i], K); + sw.Stop(); + + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalHits.Add(hits); + } + + Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs b/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs new file mode 100644 index 0000000..4e81ee2 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/QuantizeHistogramTest.cs @@ -0,0 +1,71 @@ +using System; +using System.Drawing; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services; + +namespace HNSWDemo.Tests +{ + public class QuantizeHistogramTest + : ITest + { + private static int Count = 3000; + private static int Dimensionality = 128; + private static int Width = 3000; + private static int Height = 3000; + + public void Run() + { + var vectors = VectorUtils.RandomVectors(Dimensionality, Count); + var q = new Quantizator(-1f, 1f); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, CosineDistance.NonOptimized)); + world.AddItems(vectors.Select(v => q.QuantizeToLong(v)).ToList()); + + var distance = new Func((id1, id2) => CosineDistance.NonOptimized(world.GetVector(id1), world.GetVector(id2))); + var weights = world.GetLinks().SelectMany(pair => pair.Value.Select(id => distance(pair.Key, id))); + var histogram = new Histogram(HistogramMode.SQRT, weights); + histogram.Smooth(); + + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + DrawHistogram(histogram, @"D:\hist.jpg"); + } + + static void DrawHistogram(Histogram histogram, string filename) + { + var wb = Width / histogram.Values.Length; + var k = ((float)Height) / (float)histogram.Values.Max(); + + var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m); + int threshold = histogram.OTSU(); + + using (var bmp = new Bitmap(Width, Height)) + { + using (var g = Graphics.FromImage(bmp)) + { + for (int i = 0; i < histogram.Values.Length; i++) + { + var height = (int)(histogram.Values[i] * k); + if (maxes.ContainsKey(i)) + { + g.DrawRectangle(Pens.Red, i * wb, bmp.Height - height, wb, height); + g.DrawRectangle(Pens.Red, i * wb + 1, bmp.Height - height, wb - 1, height); + } + else + { + g.DrawRectangle(Pens.Blue, i * wb, bmp.Height - height, wb, height); + } + if (i == threshold) + { + g.DrawLine(Pens.Green, i * wb + wb / 2, 0, i * wb + wb / 2, bmp.Height); + } + } + } + bmp.Save(filename); + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs 
b/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs new file mode 100644 index 0000000..70fe553 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/QuantizeInsertTimeExplosionTest.cs @@ -0,0 +1,31 @@ +using System; +using System.Diagnostics; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services; + +namespace HNSWDemo.Tests +{ + public class QuantizeInsertTimeExplosionTest + : ITest + { + private static int Count = 10000; + private static int IterationCount = 100; + private static int Dimensionality = 128; + + public void Run() + { + var sw = new Stopwatch(); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); + var q = new Quantizator(-1f, 1f); + for (int i = 0; i < IterationCount; i++) + { + var samples = VectorUtils.RandomVectors(Dimensionality, Count); + sw.Restart(); + var ids = world.AddItems(samples.Select(v => q.QuantizeToLong(v)).ToArray()); + sw.Stop(); + Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Tests/SaveRestoreTest.cs b/TestHNSW/HNSWDemo/Tests/SaveRestoreTest.cs new file mode 100644 index 0000000..f135ea6 --- /dev/null +++ b/TestHNSW/HNSWDemo/Tests/SaveRestoreTest.cs @@ -0,0 +1,52 @@ +using System; +using System.Diagnostics; +using System.IO; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Tests +{ + public class SaveRestoreTest + : ITest + { + private static int Count = 1000; + private static int Dimensionality = 128; + + public void Run() + { + var samples = VectorUtils.RandomVectors(Dimensionality, Count); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); + var sw = new Stopwatch(); + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + byte[] dump; + using (var ms = new MemoryStream()) + { + world.Serialize(ms); + dump = ms.ToArray(); + } + Console.WriteLine($"Full dump size: {dump.Length} bytes"); + + byte[] testDump; + var restoredWorld = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); + using (var ms = new MemoryStream(dump)) + { + restoredWorld.Deserialize(ms); + } + + using (var ms = new MemoryStream()) + { + restoredWorld.Serialize(ms); + testDump = ms.ToArray(); + } + if (testDump.Length != dump.Length) + { + Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. 
Expected: {dump.Length}"); + return; + } + } + } +} diff --git a/TestHNSW/HNSWDemo/Utils/QLVectorsDirectCompare.cs b/TestHNSW/HNSWDemo/Utils/QLVectorsDirectCompare.cs new file mode 100644 index 0000000..196118a --- /dev/null +++ b/TestHNSW/HNSWDemo/Utils/QLVectorsDirectCompare.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Utils +{ + public class QLVectorsDirectCompare + { + private const int HALF_LONG_BITS = 32; + private readonly IList _vectors; + private readonly Func _distance; + + public QLVectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(long[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)(i)) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } + } +} diff --git a/TestHNSW/HNSWDemo/Utils/QVectorsDirectCompare.cs b/TestHNSW/HNSWDemo/Utils/QVectorsDirectCompare.cs new file mode 100644 index 0000000..92c88d7 --- /dev/null +++ b/TestHNSW/HNSWDemo/Utils/QVectorsDirectCompare.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Utils +{ + public class QVectorsDirectCompare + { + private const int HALF_LONG_BITS = 32; + private readonly IList _vectors; + private readonly Func _distance; + + public QVectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(byte[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)i) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. 
Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } + } +} diff --git a/TestHNSW/HNSWDemo/Utils/VectorsDirectCompare.cs b/TestHNSW/HNSWDemo/Utils/VectorsDirectCompare.cs new file mode 100644 index 0000000..a000c45 --- /dev/null +++ b/TestHNSW/HNSWDemo/Utils/VectorsDirectCompare.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.HNSW; + +namespace HNSWDemo.Utils +{ + public class VectorsDirectCompare + { + private const int HALF_LONG_BITS = 32; + private readonly IList _vectors; + private readonly Func _distance; + + public VectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(float[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)(i)) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. 
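
The *VectorsDirectCompare helpers added under Utils all index an unordered vector pair (i, j), i < j, with a single long key: the first id is stored in the upper 32 bits and the second in the lower 32 bits. A minimal round-trip check of that encoding, using the same shift arithmetic as the helpers, looks like this:

    using System;

    public static class PairKeyDemo
    {
        private const int HALF_LONG_BITS = 32;

        public static void Main()
        {
            int i = 123, j = 4567;                        // any pair with 0 <= i < j

            // Pack: i occupies the high half of the key, j the low half.
            long key = (((long)i) << HALF_LONG_BITS) + j;

            // Unpack exactly as DetectClusters does.
            var id1 = (int)(key >> HALF_LONG_BITS);
            var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS));

            Console.WriteLine(id1 == i && id2 == j);      // True
        }
    }
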
Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } + } +} diff --git a/ZeroLevel.HNSW/Model/NSWOptions.cs b/ZeroLevel.HNSW/Model/NSWOptions.cs index 7cb0597..1620a75 100644 --- a/ZeroLevel.HNSW/Model/NSWOptions.cs +++ b/ZeroLevel.HNSW/Model/NSWOptions.cs @@ -16,6 +16,12 @@ namespace ZeroLevel.HNSW /// Max search buffer for inserting /// public readonly int EFConstruction; + + public static NSWOptions Create(int v1, int v2, int v3, int v4, Func l2Euclidean, object selectionHeuristic) + { + throw new NotImplementedException(); + } + /// /// Distance function beetween vectors /// diff --git a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs index b799e52..b78fe02 100644 --- a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs +++ b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs @@ -1,4 +1,6 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; namespace ZeroLevel.HNSW.Services { @@ -6,11 +8,20 @@ namespace ZeroLevel.HNSW.Services { private const int HALF_LONG_BITS = 32; - /*public static List> DetectClusters(SmallWorld world) + private class Link { - var links = world.GetNSWLinks(); + public int Id1; + public int Id2; + public float Distance; + } + + public static List> DetectClusters(SmallWorld world) + { + var distance = world.DistanceFunction; + var links = world.GetLinks().SelectMany(pair => pair.Value.Select(id => new Link { Id1 = pair.Key, Id2 = id, Distance = distance(pair.Key, id) })).ToList(); + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances - var histogram = new Histogram(HistogramMode.SQRT, links.Values); + var histogram = new Histogram(HistogramMode.SQRT, links.Select(l => l.Distance)); int threshold = histogram.OTSU(); var min = histogram.Bounds[threshold - 1]; var max = histogram.Bounds[threshold]; @@ -18,23 +29,21 @@ namespace ZeroLevel.HNSW.Services // 2. Get links with distances less than R - var resultLinks = new SortedList(); - foreach (var pair in links) + var resultLinks = new List(); + foreach (var l in links) { - if (pair.Value < R) + if (l.Distance < R) { - resultLinks.Add(pair.Key, pair.Value); + resultLinks.Add(l); } } // 3. 
Extract clusters List> clusters = new List>(); - foreach (var pair in resultLinks) + foreach (var l in resultLinks) { - var k = pair.Key; - var id1 = (int)(k >> HALF_LONG_BITS); - var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - + var id1 = l.Id1; + var id2 = l.Id2; bool found = false; foreach (var c in clusters) { @@ -60,6 +69,6 @@ namespace ZeroLevel.HNSW.Services } } return clusters; - }*/ + } } } diff --git a/ZeroLevel.HNSW/Services/Layer.cs b/ZeroLevel.HNSW/Services/Layer.cs index e3a098f..3966202 100644 --- a/ZeroLevel.HNSW/Services/Layer.cs +++ b/ZeroLevel.HNSW/Services/Layer.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using ZeroLevel.HNSW.Services; -using ZeroLevel.Services.Pools; using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW @@ -16,20 +15,28 @@ namespace ZeroLevel.HNSW private readonly NSWOptions _options; private readonly VectorSet _vectors; private readonly LinksSet _links; - //internal SortedList Links => _links.Links; + public readonly int M; + private readonly Dictionary connections; + internal IDictionary> Links => _links.Links; /// /// There are links е the layer /// internal bool HasLinks => (_links.Count > 0); - private int GetM(bool nswLayer) - { - return nswLayer ? 2 * _options.M : _options.M; - } + internal IEnumerable this[int vector_index] => _links.FindNeighbors(vector_index); /// /// HNSW layer + /// + /// Article: Section 4.1: + /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also + /// has a strong influence on the search performance, especially in case of high quality(high recall) search. + /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors + /// selection heuristic is not used) leads to a very strong performance penalty at high recall. + /// Simulations also suggest that 2∙M is a good choice for Mmax0; + /// setting the parameter higher leads to performance degradation and excessive memory usage." + /// /// /// HNSW graph options /// General vector set @@ -37,7 +44,9 @@ namespace ZeroLevel.HNSW { _options = options; _vectors = vectors; - _links = new LinksSet(GetM(nswLayer), (id1, id2) => options.Distance(_vectors[id1], _vectors[id2])); + M = nswLayer ? 
2 * _options.M : _options.M; + _links = new LinksSet(M); + connections = new Dictionary(M + 1); } internal int FindEntryPointAtLayer(Func targetCosts) @@ -58,13 +67,89 @@ namespace ZeroLevel.HNSW return minId; } - internal void AddBidirectionallConnections(int q, int p) + internal void Push(int q, int ep, MinHeap W, Func distance) + { + if (HasLinks == false) + { + AddBidirectionallConnections(q, q); + } + else + { + // W ← SEARCH - LAYER(q, ep, efConstruction, lc) + foreach (var i in KNearestAtLayer(ep, distance, _options.EFConstruction)) + { + W.Push(i); + } + + int count = 0; + connections.Clear(); + while (count < M && W.Count > 0) + { + var nearest = W.Pop(); + var nearest_nearest = GetNeighbors(nearest.Item1).ToArray(); + if (nearest_nearest.Length < M) + { + if (AddBidirectionallConnections(q, nearest.Item1)) + { + connections.Add(nearest.Item1, nearest.Item2); + count++; + } + } + else + { + if ((M - count) < 2) + { + // remove link q - max_q + var max = connections.OrderBy(pair => pair.Value).First(); + RemoveBidirectionallConnections(q, max.Key); + connections.Remove(max.Key); + } + // get nearest_nearest candidate + var mn_id = -1; + var mn_d = float.MinValue; + for (int i = 0; i < nearest_nearest.Length; i++) + { + var d = _options.Distance(_vectors[nearest.Item1], _vectors[nearest_nearest[i]]); + if (q != nearest_nearest[i] && connections.ContainsKey(nearest_nearest[i]) == false) + { + if (mn_id == -1 || d > mn_d) + { + mn_d = d; + mn_id = nearest_nearest[i]; + } + } + } + // remove link neareset - nearest_nearest + RemoveBidirectionallConnections(nearest.Item1, mn_id); + // add link q - neareset + if (AddBidirectionallConnections(q, nearest.Item1)) + { + connections.Add(nearest.Item1, nearest.Item2); + count++; + } + // add link q - max_nearest_nearest + if (AddBidirectionallConnections(q, mn_id)) + { + connections.Add(mn_id, mn_d); + count++; + } + } + } + } + } + + internal void RemoveBidirectionallConnections(int q, int p) + { + _links.RemoveIndex(q, p); + } + + internal bool AddBidirectionallConnections(int q, int p) { if (q == p) { if (EntryPoint >= 0) { - _links.Add(q, EntryPoint); + return _links.Add(q, EntryPoint); } else { @@ -73,14 +158,13 @@ namespace ZeroLevel.HNSW } else { - _links.Add(q, p); + return _links.Add(q, p); } + return false; } private int EntryPoint = -1; - internal void Trim(int id) => _links.Trim(id); - #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf /// /// Algorithm 2 @@ -349,7 +433,7 @@ namespace ZeroLevel.HNSW */ #endregion - private IEnumerable GetNeighbors(int id) => _links.FindNeighbors(id); + internal IEnumerable GetNeighbors(int id) => _links.FindNeighbors(id); public void Serialize(IBinaryWriter writer) { diff --git a/ZeroLevel.HNSW/Services/LinksSet.cs b/ZeroLevel.HNSW/Services/LinksSet.cs index 0201b0a..c3631ff 100644 --- a/ZeroLevel.HNSW/Services/LinksSet.cs +++ b/ZeroLevel.HNSW/Services/LinksSet.cs @@ -6,161 +6,15 @@ using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { - /* - internal struct Link - : IEquatable - { - public int Id; - public float Distance; - - public override int GetHashCode() - { - return Id.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is Link) - return this.Equals((Link)obj); - return false; - } - - public bool Equals(Link other) - { - return this.Id == other.Id; - } - } - - public class LinksSetWithCachee - { - private ConcurrentDictionary> _set = new ConcurrentDictionary>(); - - internal int Count => _set.Count; - - private readonly int 
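
The new Layer.Push is the heart of this commit: instead of adding every candidate link and trimming afterwards, it connects the inserted node to at most M neighbors and rewires a neighbor's farthest link when that neighbor is already full. A minimal usage sketch of the fixed insert path follows; the generic parameters (float[]) are an assumption, since the listing has them stripped, and the degree check simply mirrors the TestWorld verification added to SmallWorld further down.

    using System;
    using System.Linq;
    using ZeroLevel.HNSW;

    public static class InsertSketch
    {
        public static void Main()
        {
            // Build a world with the same options the demo tests use.
            var vectors = VectorUtils.RandomVectors(128, 1000);
            var world = new SmallWorld<float[]>(
                NSWOptions<float[]>.Create(6, 12, 100, 100, CosineDistance.NonOptimized));

            var ids = world.AddItems(vectors.ToArray());
            Console.WriteLine($"Inserted {ids.Length} vectors");

            // After the fix no node on layer 0 is expected to exceed the layer bound M
            // (2 * options.M on the zero layer), which is what TestWorld verifies.
            var maxDegree = world.GetLinks().Max(pair => pair.Value.Count);
            Console.WriteLine($"Max degree on layer 0: {maxDegree}");
        }
    }
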
_M; - private readonly Func _distance; - - public LinksSetWithCachee(int M, Func distance) - { - _distance = distance; - _M = M; - } - - internal IEnumerable FindNeighbors(int id) - { - if (_set.ContainsKey(id)) - { - return _set[id].Select(l=>l.Id); - } - return Enumerable.Empty(); - } - - internal void RemoveIndex(int id1, int id2) - { - var link1 = new Link { Id = id1 }; - var link2 = new Link { Id = id2 }; - - _set[id1].Remove(link2); - _set[id2].Remove(link1); - } - - internal bool Add(int id1, int id2, float distance) - { - if (!_set.ContainsKey(id1)) - { - _set[id1] = new HashSet(); - } - if (!_set.ContainsKey(id2)) - { - _set[id2] = new HashSet(); - } - var r1 = _set[id1].Add(new Link { Id = id2, Distance = distance }); - var r2 = _set[id2].Add(new Link { Id = id1, Distance = distance }); - - //TrimSet(_set[id1]); - TrimSet(id2, _set[id2]); - - return r1 || r2; - } - - internal void Trim(int id) => TrimSet(id, _set[id]); - - private void TrimSet(int id, HashSet set) - { - if (set.Count > _M) - { - var removeCount = set.Count - _M; - var removeLinks = set.OrderByDescending(n => n.Distance).Take(removeCount).ToArray(); - foreach (var l in removeLinks) - { - set.Remove(l); - } - } - } - - public void Dispose() - { - _set.Clear(); - _set = null; - } - - private const int HALF_LONG_BITS = 32; - public void Serialize(IBinaryWriter writer) - { - writer.WriteBoolean(false); // true - set with weights - writer.WriteInt32(_set.Sum(pair => pair.Value.Count)); - foreach (var record in _set) - { - var id = record.Key; - foreach (var r in record.Value) - { - var key = (((long)(id)) << HALF_LONG_BITS) + r; - writer.WriteLong(key); - } - } - } - - public void Deserialize(IBinaryReader reader) - { - if (reader.ReadBoolean() == false) - { - throw new InvalidOperationException("Incompatible data format. 
The set does not contain weights."); - } - _set.Clear(); - _set = null; - var count = reader.ReadInt32(); - _set = new ConcurrentDictionary>(); - for (int i = 0; i < count; i++) - { - var key = reader.ReadLong(); - - var id1 = (int)(key >> HALF_LONG_BITS); - var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS)); - - if (!_set.ContainsKey(id1)) - { - _set[id1] = new HashSet(); - } - _set[id1].Add(id2); - } - } - } - */ - public class LinksSet { private ConcurrentDictionary> _set = new ConcurrentDictionary>(); - internal IDictionary> Links => _set; - internal int Count => _set.Count; - private readonly int _M; - private readonly Func _distance; - public LinksSet(int M, Func distance) + public LinksSet(int M) { - _distance = distance; _M = M; } @@ -207,27 +61,9 @@ namespace ZeroLevel.HNSW } var r1 = _set[id1].Add(id2); var r2 = _set[id2].Add(id1); - - TrimSet(id1, _set[id1]); - TrimSet(id2, _set[id2]); - return r1 || r2; } - internal void Trim(int id) => TrimSet(id, _set[id]); - - private void TrimSet(int id, HashSet set) - { - if (set.Count > _M) - { - var removeCount = set.Count - _M; - var removeLinks = set.OrderByDescending(n => _distance(id, n)).Take(removeCount).ToArray(); - foreach (var l in removeLinks) - { - set.Remove(l); - } - } - } public void Dispose() { diff --git a/ZeroLevel.HNSW/Services/VectorSet.cs b/ZeroLevel.HNSW/Services/VectorSet.cs index 761973c..07d1550 100644 --- a/ZeroLevel.HNSW/Services/VectorSet.cs +++ b/ZeroLevel.HNSW/Services/VectorSet.cs @@ -1,11 +1,12 @@ -using System.Collections.Generic; +using System.Collections; +using System.Collections.Generic; using System.Threading; using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW { internal sealed class VectorSet - : IBinarySerializable + : IEnumerable, IBinarySerializable { private List _set = new List(); private SpinLock _lock = new SpinLock(); @@ -73,5 +74,15 @@ namespace ZeroLevel.HNSW writer.WriteCompatible(r); } } + + public IEnumerator GetEnumerator() + { + return _set.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _set.GetEnumerator(); + } } } diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index 74aa09e..2da0dd6 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -18,12 +18,19 @@ namespace ZeroLevel.HNSW private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); + public readonly Func DistanceFunction; + public TItem GetVector(int id) => _vectors[id]; + public IDictionary> GetLinks() => _layers[0].Links; + public SmallWorld(NSWOptions options) { _options = options; _vectors = new VectorSet(); _layers = new Layer[_options.LayersCount]; _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M); + + DistanceFunction = new Func((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2])); + for (int i = 0; i < _options.LayersCount; i++) { _layers[i] = new Layer(_options, _vectors, i == 0); @@ -127,7 +134,7 @@ namespace ZeroLevel.HNSW var L = MaxLayer; // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level int l = _layerLevelGenerator.GetRandomLayer(); - + // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа int id; float value; @@ -151,42 +158,14 @@ namespace ZeroLevel.HNSW // connecting new node to the small world for (int lc = Math.Min(L, l); lc >= 0; --lc) { - if (_layers[lc].HasLinks == false) - { - _layers[lc].AddBidirectionallConnections(q, q); - } - else + 
diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs
index 74aa09e..2da0dd6 100644
--- a/ZeroLevel.HNSW/SmallWorld.cs
+++ b/ZeroLevel.HNSW/SmallWorld.cs
@@ -18,12 +18,19 @@ namespace ZeroLevel.HNSW
         private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
         private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
 
+        public readonly Func<int, int, float> DistanceFunction;
+        public TItem GetVector(int id) => _vectors[id];
+        public IDictionary<int, HashSet<int>> GetLinks() => _layers[0].Links;
+
         public SmallWorld(NSWOptions<TItem> options)
         {
             _options = options;
             _vectors = new VectorSet<TItem>();
             _layers = new Layer<TItem>[_options.LayersCount];
             _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
+
+            DistanceFunction = new Func<int, int, float>((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2]));
+
             for (int i = 0; i < _options.LayersCount; i++)
             {
                 _layers[i] = new Layer<TItem>(_options, _vectors, i == 0);
@@ -127,7 +134,7 @@ namespace ZeroLevel.HNSW
             var L = MaxLayer;
             // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level
             int l = _layerLevelGenerator.GetRandomLayer();
-
+            // Walk down from the top layer to the layer where the new element appears, to find its entry point
             int id;
             float value;
@@ -151,42 +158,14 @@ namespace ZeroLevel.HNSW
             // connecting new node to the small world
             for (int lc = Math.Min(L, l); lc >= 0; --lc)
             {
-                if (_layers[lc].HasLinks == false)
-                {
-                    _layers[lc].AddBidirectionallConnections(q, q);
-                }
-                else
+                _layers[lc].Push(q, ep, W, distance);
+                // ep ← W
+                if (W.TryPeek(out id, out value))
                 {
-                    // W ← SEARCH - LAYER(q, ep, efConstruction, lc)
-                    foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, _options.EFConstruction))
-                    {
-                        W.Push(i);
-                    }
-
-                    // ep ← W
-                    if (W.TryPeek(out id, out value))
-                    {
-                        ep = id;
-                        epDist = value;
-                    }
-
-                    // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
-                    var neighbors = SelectBestForConnecting(lc, W);
-                    // add bidirectionall connectionts from neighbors to q at layer lc
-                    // for each e ∈ neighbors // shrink connections if needed
-                    foreach (var e in neighbors)
-                    {
-                        // eConn ← neighbourhood(e) at layer lc
-                        _layers[lc].AddBidirectionallConnections(q, e.Item1);
-                        // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
-                        if (e.Item2 < epDist)
-                        {
-                            ep = e.Item1;
-                            epDist = e.Item2;
-                        }
-                    }
-                    W.Clear();
+                    ep = id;
+                    epDist = value;
                 }
+                W.Clear();
             }
             // if l > L
             if (l > L)
             {
@@ -198,30 +177,37 @@ namespace ZeroLevel.HNSW
             }
         }
 
-        /// <summary>
-        /// Get maximum allowed connections for the given level.
-        /// </summary>
-        /// <remarks>
-        /// Article: Section 4.1:
-        /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
-        /// has a strong influence on the search performance, especially in case of high quality(high recall) search.
-        /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
-        /// selection heuristic is not used) leads to a very strong performance penalty at high recall.
-        /// Simulations also suggest that 2∙M is a good choice for Mmax0;
-        /// setting the parameter higher leads to performance degradation and excessive memory usage."
-        /// </remarks>
-        /// <param name="layer">The level of the layer.</param>
-        /// <returns>The maximum number of connections.</returns>
-        private int GetM(int layer)
-        {
-            return layer == 0 ? 2 * _options.M : _options.M;
-        }
-
-        private IEnumerable<(int, float)> SelectBestForConnecting(int layer, MinHeap candidates)
+        public void TestWorld()
         {
-            int count = GetM(layer);
-            while (count >= 0 && candidates.Count > 0)
-                yield return candidates.Pop();
+            for (var v = 0; v < _vectors.Count; v++)
+            {
+                var nearest = _layers[0][v].ToArray();
+                if (nearest.Length > _layers[0].M)
+                {
+                    Console.WriteLine($"V{v}. Count of links ({nearest.Length}) more than max ({_layers[0].M})");
+                }
+            }
+            // coverage test
+            var ep = 0;
+            var visited = new HashSet<int>();
+            var next = new Stack<int>();
+            next.Push(ep);
+            while (next.Count > 0)
+            {
+                ep = next.Pop();
+                visited.Add(ep);
+                foreach (var n in _layers[0].GetNeighbors(ep))
+                {
+                    if (visited.Contains(n) == false)
+                    {
+                        next.Push(n);
+                    }
+                }
+            }
+            if (visited.Count != _vectors.Count)
+            {
+                Console.Write($"BFS visited nodes count ({visited.Count}) less than vectors count ({_vectors.Count})");
+            }
         }
 
         /// <summary>
@@ -385,8 +371,5 @@ namespace ZeroLevel.HNSW
                 }
             }
         }
-
-        /*public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT)
-            => _layers[0].GetHistogram(mode);*/
     }
 }
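Taken together, the SmallWorld changes expose a small diagnostic surface: DistanceFunction, GetVector, GetLinks and TestWorld. A hedged end-to-end sketch of how demo code might exercise them; NSWOptions.Create, CosineDistance.ForUnits and AddItems are assumed from the rest of the library rather than shown in this patch, and the parameter values are placeholders:

    // Sketch only: build a world from random vectors and run the new diagnostics.
    var options = NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits); // assumed factory and signature
    var world = new SmallWorld<float[]>(options);

    world.AddItems(VectorUtils.RandomVectors(128, 10000));    // AddItems assumed to exist outside this patch

    float d01 = world.DistanceFunction(0, 1);                 // distance between stored vectors by id
    float[] v0 = world.GetVector(0);                          // raw vector back by id

    IDictionary<int, HashSet<int>> links = world.GetLinks();  // layer-0 adjacency (id -> neighbour ids)
    Console.WriteLine($"node 0 degree: {links[0].Count}, d(0,1) = {d01}");

    world.TestWorld();                                        // degree-cap and BFS coverage checks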
diff --git a/ZeroLevel.HNSW/Utils/CosineDistance.cs b/ZeroLevel.HNSW/Utils/CosineDistance.cs
index 11b032f..3b7317c 100644
--- a/ZeroLevel.HNSW/Utils/CosineDistance.cs
+++ b/ZeroLevel.HNSW/Utils/CosineDistance.cs
@@ -1,6 +1,7 @@
 using System;
 using System.Numerics;
 using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 
 namespace ZeroLevel.HNSW
 {
diff --git a/ZeroLevel.HNSW/Utils/VectorUtils.cs b/ZeroLevel.HNSW/Utils/VectorUtils.cs
index c4a72eb..7ffb988 100644
--- a/ZeroLevel.HNSW/Utils/VectorUtils.cs
+++ b/ZeroLevel.HNSW/Utils/VectorUtils.cs
@@ -6,6 +6,19 @@ namespace ZeroLevel.HNSW
 {
     public static class VectorUtils
     {
+        public static List<float[]> RandomVectors(int vectorSize, int vectorsCount)
+        {
+            var vectors = new List<float[]>();
+            for (int i = 0; i < vectorsCount; i++)
+            {
+                var vector = new float[vectorSize];
+                DefaultRandomGenerator.Instance.NextFloats(vector);
+                VectorUtils.NormalizeSIMD(vector);
+                vectors.Add(vector);
+            }
+            return vectors;
+        }
+
         public static float Magnitude(IList<float> vector)
         {
             float magnitude = 0.0f;