diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 7263796..b993ef6 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -99,77 +99,499 @@ namespace HNSWDemo } } + public class QVectorsDirectCompare + { + private const int HALF_LONG_BITS = 32; + private readonly IList _vectors; + private readonly Func _distance; + + public QVectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(byte[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)(i)) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } + } + + public class QLVectorsDirectCompare + { + private const int HALF_LONG_BITS = 32; + private readonly IList _vectors; + private readonly Func _distance; + + public QLVectorsDirectCompare(List vectors, Func distance) + { + _vectors = vectors; + _distance = distance; + } + + public IEnumerable<(int, float)> KNearest(long[] v, int k) + { + var weights = new Dictionary(); + for (int i = 0; i < _vectors.Count; i++) + { + var d = _distance(v, _vectors[i]); + weights[i] = d; + } + return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); + } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)(i)) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } + } + public enum Gender { - Unknown, Male, Feemale + Unknown, Male, Feemale + } + + public class Person + { + public Gender Gender { get; set; } + public int Age { get; set; } + public long Number { get; set; } + + private static (float[], Person) Generate(int vector_size) + { + var rnd = new Random((int)Environment.TickCount); + var vector = new float[vector_size]; + DefaultRandomGenerator.Instance.NextFloats(vector); + VectorUtils.NormalizeSIMD(vector); + var p = new Person(); + p.Age = rnd.Next(15, 80); + var gr = rnd.Next(0, 3); + p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? Gender.Feemale : Gender.Unknown; + p.Number = CreateNumber(rnd); + return (vector, p); + } + + public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount) + { + var vectors = new List<(float[], Person)>(); + for (int i = 0; i < vectorsCount; i++) + { + vectors.Add(Generate(vectorSize)); + } + return vectors; + } + + static HashSet _exists = new HashSet(); + private static long CreateNumber(Random rnd) + { + long start_number; + do + { + start_number = 79600000000L; + start_number = start_number + rnd.Next(4, 8) * 10000000; + start_number += rnd.Next(0, 1000000); + } + while (_exists.Add(start_number) == false); + return start_number; + } + } + + private static List RandomVectors(int vectorSize, int vectorsCount) + { + var vectors = new List(); + for (int i = 0; i < vectorsCount; i++) + { + var vector = new float[vectorSize]; + DefaultRandomGenerator.Instance.NextFloats(vector); + VectorUtils.NormalizeSIMD(vector); + vectors.Add(vector); + } + return vectors; + } + + + static void Main(string[] args) + { + QuantizatorTest(); + Console.WriteLine("Completed"); + Console.ReadKey(); + } + + static void QAccuracityTest() + { + int K = 200; + var count = 5000; + var testCount = 500; + var dimensionality = 128; + var totalHits = new List(); + var timewatchesNP = new List(); + var timewatchesHNSW = new List(); + var q = new Quantizator(-1f, 1f); + + var samples = RandomVectors(dimensionality, count).Select(v => q.QuantizeToLong(v)).ToList(); + + var sw = new Stopwatch(); + + var test = new QLVectorsDirectCompare(samples, CosineDistance.NonOptimized); + var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, CosineDistance.NonOptimized)); + + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + + Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + + var test_vectors = RandomVectors(dimensionality, testCount).Select(v => q.QuantizeToLong(v)).ToList(); + foreach (var v in test_vectors) + { + sw.Restart(); + var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = world.Search(v, K); + sw.Stop(); + + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalHits.Add(hits); + } + + Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + } + + static void QInsertTimeExplosionTest() + { + var count = 10000; + var iterationCount = 100; + var dimensionality = 128; + var sw = new Stopwatch(); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); + var q = new Quantizator(-1f, 1f); + for (int i = 0; i < iterationCount; i++) + { + var samples = RandomVectors(dimensionality, count); + sw.Restart(); + var ids = world.AddItems(samples.Select(v => q.QuantizeToLong(v)).ToArray()); + sw.Stop(); + Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); + } + } + + static void AccuracityTest() + { + int K = 200; + var count = 3000; + var testCount = 500; + var dimensionality = 128; + var totalHits = new List(); + var timewatchesNP = new List(); + var timewatchesHNSW = new List(); + + var samples = RandomVectors(dimensionality, count); + + var sw = new Stopwatch(); + + var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized); + var world = new SmallWorld(NSWOptions.Create(8, 12, 100, 100, CosineDistance.NonOptimized)); + + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + + /* + byte[] dump; + using (var ms = new MemoryStream()) + { + world.Serialize(ms); + dump = ms.ToArray(); + } + Console.WriteLine($"Full dump size: {dump.Length} bytes"); + + + ReadOnlySmallWorld world; + using (var ms = new MemoryStream(dump)) + { + world = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(100, CosineDistance.NonOptimized, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + */ + + Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); + + + var test_vectors = RandomVectors(dimensionality, testCount); + foreach (var v in test_vectors) + { + sw.Restart(); + var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = world.Search(v, K); + sw.Stop(); + + timewatchesHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalHits.Add(hits); + } + + Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); + Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); + Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + } + + static void QuantizatorTest() + { + var samples = RandomVectors(128, 500000); + var min = samples.SelectMany(s => s).Min(); + var max = samples.SelectMany(s => s).Max(); + var q = new Quantizator(min, max); + var q_samples = samples.Select(s => q.QuantizeToLong(s)).ToArray(); + + // comparing + var list = new List(); + for (int i = 0; i < samples.Count - 1; i++) + { + var v1 = samples[i]; + var v2 = samples[i + 1]; + var dist = CosineDistance.NonOptimized(v1, v2); + + var qv1 = q_samples[i]; + var qv2 = q_samples[i + 1]; + var qdist = CosineDistance.NonOptimized(qv1, qv2); + + list.Add(Math.Abs(dist - qdist)); + } + + Console.WriteLine($"Min diff: {list.Min()}"); + Console.WriteLine($"Avg diff: {list.Average()}"); + Console.WriteLine($"Max diff: {list.Max()}"); } - public class Person + static void SaveRestoreTest() { - public Gender Gender { get; set; } - public int Age { get; set; } - public long Number { get; set; } + var count = 1000; + var dimensionality = 128; + var samples = RandomVectors(dimensionality, count); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); + var sw = new Stopwatch(); + sw.Start(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); + Console.WriteLine("Start test"); - private static (float[], Person) Generate(int vector_size) + byte[] dump; + using (var ms = new MemoryStream()) { - var rnd = new Random((int)Environment.TickCount); - var vector = new float[vector_size]; - DefaultRandomGenerator.Instance.NextFloats(vector); - VectorUtils.NormalizeSIMD(vector); - var p = new Person(); - p.Age = rnd.Next(15, 80); - var gr = rnd.Next(0, 3); - p.Gender = (gr == 0) ? Gender.Male : (gr == 1) ? Gender.Feemale : Gender.Unknown; - p.Number = CreateNumber(rnd); - return (vector, p); + world.Serialize(ms); + dump = ms.ToArray(); } + Console.WriteLine($"Full dump size: {dump.Length} bytes"); - public static List<(float[], Person)> GenerateRandom(int vectorSize, int vectorsCount) + byte[] testDump; + var restoredWorld = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); + using (var ms = new MemoryStream(dump)) { - var vectors = new List<(float[], Person)>(); - for (int i = 0; i < vectorsCount; i++) - { - vectors.Add(Generate(vectorSize)); - } - return vectors; + restoredWorld.Deserialize(ms); } - static HashSet _exists = new HashSet(); - private static long CreateNumber(Random rnd) + using (var ms = new MemoryStream()) { - long start_number; - do - { - start_number = 79600000000L; - start_number = start_number + rnd.Next(4, 8) * 10000000; - start_number += rnd.Next(0, 1000000); - } - while (_exists.Add(start_number) == false); - return start_number; + restoredWorld.Serialize(ms); + testDump = ms.ToArray(); } - } - - private static List RandomVectors(int vectorSize, int vectorsCount) - { - var vectors = new List(); - for (int i = 0; i < vectorsCount; i++) + if (testDump.Length != dump.Length) { - var vector = new float[vectorSize]; - DefaultRandomGenerator.Instance.NextFloats(vector); - VectorUtils.NormalizeSIMD(vector); - vectors.Add(vector); + Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}"); + return; } - return vectors; } - - static void Main(string[] args) + static void InsertTimeExplosionTest() { - InsertTimeExplosionTest(); - Console.WriteLine("Completed"); - Console.ReadKey(); + var count = 10000; + var iterationCount = 100; + var dimensionality = 128; + var sw = new Stopwatch(); + var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized)); + for (int i = 0; i < iterationCount; i++) + { + var samples = RandomVectors(dimensionality, count); + sw.Restart(); + var ids = world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); + } } + /* static void TestOnMnist() { int imageCount, rowCount, colCount; @@ -227,77 +649,74 @@ namespace HNSWDemo } - static void AutoClusteringTest() - { - var vectors = RandomVectors(128, 3000); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - world.AddItems(vectors); - var clusters = AutomaticGraphClusterer.DetectClusters(world); - Console.WriteLine($"Found {clusters.Count} clusters"); - for (int i = 0; i < clusters.Count; i++) - { - Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); - } - } + static void AutoClusteringTest() + { + var vectors = RandomVectors(128, 3000); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + world.AddItems(vectors); + var clusters = AutomaticGraphClusterer.DetectClusters(world); + Console.WriteLine($"Found {clusters.Count} clusters"); + for (int i = 0; i < clusters.Count; i++) + { + Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); + } + } - static void HistogramTest() - { - var vectors = RandomVectors(128, 3000); - var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - world.AddItems(vectors); - var histogram = world.GetHistogram(); + static void HistogramTest() + { + var vectors = RandomVectors(128, 3000); + var world = SmallWorld.CreateWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + world.AddItems(vectors); + var histogram = world.GetHistogram(); - int threshold = histogram.OTSU(); - var min = histogram.Bounds[threshold - 1]; - var max = histogram.Bounds[threshold]; - var R = (max + min) / 2; + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; - DrawHistogram(histogram, @"D:\hist.jpg"); - } + DrawHistogram(histogram, @"D:\hist.jpg"); + } static void DrawHistogram(Histogram histogram, string filename) { - /* while (histogram.CountSignChanges() > 3) - { - histogram.Smooth(); - }*/ - var wb = 1200 / histogram.Values.Length; - var k = 600.0f / (float)histogram.Values.Max(); + var wb = 1200 / histogram.Values.Length; + var k = 600.0f / (float)histogram.Values.Max(); - var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m); - int threshold = histogram.OTSU(); + var maxes = histogram.GetMaximums().ToDictionary(m => m.Index, m => m); + int threshold = histogram.OTSU(); using (var bmp = new Bitmap(1200, 600)) { - using (var g = Graphics.FromImage(bmp)) { - for (int i = 0; i < histogram.Values.Length; i++) + using (var g = Graphics.FromImage(bmp)) + { + for (int i = 0; i(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); var ids = world.AddItems(samples.ToArray()); Console.WriteLine("Start test"); @@ -313,7 +732,7 @@ namespace HNSWDemo ReadOnlySmallWorld compactWorld; using (var ms = new MemoryStream(dump)) { - compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits), ms); } // Compare worlds outputs @@ -379,7 +798,7 @@ namespace HNSWDemo var samples = RandomVectors(dimensionality, count); var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits); - var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits)); var ids = world.AddItems(samples.ToArray()); Console.WriteLine("Start test"); @@ -394,7 +813,7 @@ namespace HNSWDemo ReadOnlySmallWorld compactWorld; using (var ms = new MemoryStream(dump)) { - compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits), ms); } // Compare worlds outputs @@ -493,63 +912,8 @@ namespace HNSWDemo Console.WriteLine($"AVG HNSWCompact Accuracity: {totalHitsHNSWCompact.Average() * 100 / K}%"); Console.WriteLine($"MAX HNSWCompact Accuracity: {totalHitsHNSWCompact.Max() * 100 / K}%"); } - - static void SaveRestoreTest() - { - var count = 1000; - var dimensionality = 128; - var samples = RandomVectors(dimensionality, count); - var world = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - var sw = new Stopwatch(); - sw.Start(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); - Console.WriteLine("Start test"); - - byte[] dump; - using (var ms = new MemoryStream()) - { - world.Serialize(ms); - dump = ms.ToArray(); - } - Console.WriteLine($"Full dump size: {dump.Length} bytes"); - - byte[] testDump; - var restoredWorld = new SmallWorld(NSWOptions.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - using (var ms = new MemoryStream(dump)) - { - restoredWorld.Deserialize(ms); - } - - using (var ms = new MemoryStream()) - { - restoredWorld.Serialize(ms); - testDump = ms.ToArray(); - } - if (testDump.Length != dump.Length) - { - Console.WriteLine($"Incorrect restored size. Got {testDump.Length}. Expected: {dump.Length}"); - return; - } - - ReadOnlySmallWorld compactWorld; - using (var ms = new MemoryStream(dump)) - { - compactWorld = SmallWorld.CreateReadOnlyWorldFrom(NSWReadOnlyOption.Create(200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); - } - - byte[] smallWorldDump; - using (var ms = new MemoryStream()) - { - compactWorld.Serialize(ms); - smallWorldDump = ms.ToArray(); - } - var p = smallWorldDump.Length * 100.0f / dump.Length; - Console.WriteLine($"Compact dump size: {smallWorldDump.Length} bytes. Decrease: {100 - p}%"); - } - - static void FilterTest() + + static void FilterTest() { var count = 1000; var testCount = 100; @@ -598,82 +962,6 @@ namespace HNSWDemo Console.WriteLine($"SUCCESS: {hits}"); Console.WriteLine($"ERROR: {miss}"); } - - static void AccuracityTest() - { - int K = 200; - var count = 2000; - var testCount = 1000; - var dimensionality = 128; - var totalHits = new List(); - var timewatchesNP = new List(); - var timewatchesHNSW = new List(); - - var samples = RandomVectors(dimensionality, count); - - var sw = new Stopwatch(); - - var test = new VectorsDirectCompare(samples, CosineDistance.NonOptimized); - var world = new SmallWorld(NSWOptions.Create(6, 12, 100, 100, CosineDistance.NonOptimized, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - - sw.Start(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); - Console.WriteLine("Start test"); - - - var test_vectors = RandomVectors(dimensionality, testCount); - foreach (var v in test_vectors) - { - sw.Restart(); - var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); - sw.Stop(); - timewatchesNP.Add(sw.ElapsedMilliseconds); - - sw.Restart(); - var result = world.Search(v, K); - sw.Stop(); - - timewatchesHNSW.Add(sw.ElapsedMilliseconds); - var hits = 0; - foreach (var r in result) - { - if (gt.ContainsKey(r.Item1)) - { - hits++; - } - } - totalHits.Add(hits); - } - Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); - Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); - Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); - - Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); - Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); - Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); - - Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); - Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); - Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); - } - - static void InsertTimeExplosionTest() - { - var count = 20000; - var iterationCount = 100; - var dimensionality = 128; - var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 8, 150, 150, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - for (int i = 0; i < iterationCount; i++) - { - var samples = RandomVectors(dimensionality, count); - sw.Restart(); - var ids = world.AddItems(samples.ToArray()); - sw.Stop(); - Console.WriteLine($"ITERATION: [{i.ToString("D4")}] COUNT: [{ids.Length}] ELAPSED [{sw.ElapsedMilliseconds} ms]"); - } - } + */ } } diff --git a/ZeroLevel.HNSW/HNSWMap.cs b/ZeroLevel.HNSW/HNSWMap.cs index f6986d8..1d3141d 100644 --- a/ZeroLevel.HNSW/HNSWMap.cs +++ b/ZeroLevel.HNSW/HNSWMap.cs @@ -1,5 +1,4 @@ -using System.Collections.Concurrent; -using System.Collections.Generic; +using System.Collections.Generic; using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW @@ -10,8 +9,23 @@ namespace ZeroLevel.HNSW public class HNSWMap : IBinarySerializable { - private ConcurrentDictionary _map = new ConcurrentDictionary(); - private ConcurrentDictionary _reverse_map = new ConcurrentDictionary(); + private Dictionary _map; + private Dictionary _reverse_map; + + public HNSWMap(int capacity = -1) + { + if (capacity > 0) + { + _map = new Dictionary(capacity); + _reverse_map = new Dictionary(capacity); + } + else + { + _map = new Dictionary(); + _reverse_map = new Dictionary(); + + } + } public void Append(TFeature feature, int vectorId) { @@ -45,8 +59,8 @@ namespace ZeroLevel.HNSW public void Deserialize(IBinaryReader reader) { - this._map = reader.ReadDictionaryAsConcurrent(); - this._reverse_map = reader.ReadDictionaryAsConcurrent(); + this._map = reader.ReadDictionary(); + this._reverse_map = reader.ReadDictionary(); } public void Serialize(IBinaryWriter writer) diff --git a/ZeroLevel.HNSW/Model/Histogram.cs b/ZeroLevel.HNSW/Model/Histogram.cs index ed49a43..4ed35cc 100644 --- a/ZeroLevel.HNSW/Model/Histogram.cs +++ b/ZeroLevel.HNSW/Model/Histogram.cs @@ -21,12 +21,13 @@ namespace ZeroLevel.HNSW public float[] Bounds { get; } public int[] Values { get; } - public Histogram(HistogramMode mode, IList data) + public Histogram(HistogramMode mode, IEnumerable data) { Mode = mode; Min = data.Min(); Max = data.Max(); - int M = mode == HistogramMode.LOG ? (int)(1f + 3.2f * Math.Log(data.Count)) : (int)(Math.Sqrt(data.Count)); + int count = data.Count(); + int M = mode == HistogramMode.LOG ? (int)(1f + 3.2f * Math.Log(count)) : (int)(Math.Sqrt(count)); BoundsPeriod = (Max - Min) / M; Bounds = new float[M - 1]; diff --git a/ZeroLevel.HNSW/Model/NSWOptions.cs b/ZeroLevel.HNSW/Model/NSWOptions.cs index 344b042..7cb0597 100644 --- a/ZeroLevel.HNSW/Model/NSWOptions.cs +++ b/ZeroLevel.HNSW/Model/NSWOptions.cs @@ -21,12 +21,6 @@ namespace ZeroLevel.HNSW /// public readonly Func Distance; - public readonly bool ExpandBestSelection; - - public readonly bool KeepPrunedConnections; - - public readonly NeighbourSelectionHeuristic SelectionHeuristic; - public readonly int LayersCount; @@ -34,29 +28,20 @@ namespace ZeroLevel.HNSW int m, int ef, int ef_construction, - Func distance, - bool expandBestSelection, - bool keepPrunedConnections, - NeighbourSelectionHeuristic selectionHeuristic) + Func distance) { LayersCount = layersCount; M = m; EF = ef; EFConstruction = ef_construction; Distance = distance; - ExpandBestSelection = expandBestSelection; - KeepPrunedConnections = keepPrunedConnections; - SelectionHeuristic = selectionHeuristic; } public static NSWOptions Create(int layersCount, int M, int EF, int EF_construction, - Func distance, - bool expandBestSelection = false, - bool keepPrunedConnections = false, - NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) => - new NSWOptions(layersCount, M, EF, EF_construction, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic); + Func distance) => + new NSWOptions(layersCount, M, EF, EF_construction, distance); } } diff --git a/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs b/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs deleted file mode 100644 index bdc9d97..0000000 --- a/ZeroLevel.HNSW/Model/NSWReadOnlyOption.cs +++ /dev/null @@ -1,44 +0,0 @@ -using System; - -namespace ZeroLevel.HNSW -{ - public sealed class NSWReadOnlyOption - { - /// - /// Max search buffer - /// - public readonly int EF; - /// - /// Distance function beetween vectors - /// - public readonly Func Distance; - - public readonly bool ExpandBestSelection; - - public readonly bool KeepPrunedConnections; - - public readonly NeighbourSelectionHeuristic SelectionHeuristic; - - private NSWReadOnlyOption( - int ef, - Func distance, - bool expandBestSelection, - bool keepPrunedConnections, - NeighbourSelectionHeuristic selectionHeuristic) - { - EF = ef; - Distance = distance; - ExpandBestSelection = expandBestSelection; - KeepPrunedConnections = keepPrunedConnections; - SelectionHeuristic = selectionHeuristic; - } - - public static NSWReadOnlyOption Create( - int EF, - Func distance, - bool expandBestSelection = false, - bool keepPrunedConnections = false, - NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) => - new NSWReadOnlyOption(EF, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic); - } -} diff --git a/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs b/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs deleted file mode 100644 index 48b88fb..0000000 --- a/ZeroLevel.HNSW/Model/NeighbourSelectionHeuristic.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace ZeroLevel.HNSW -{ - /// - /// Type of heuristic to select best neighbours for a node. - /// - public enum NeighbourSelectionHeuristic - { - /// - /// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in - /// - SelectSimple, - - /// - /// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in - /// - SelectHeuristic - } -} diff --git a/ZeroLevel.HNSW/ReadOnlySmallWorld.cs b/ZeroLevel.HNSW/ReadOnlySmallWorld.cs deleted file mode 100644 index 8241734..0000000 --- a/ZeroLevel.HNSW/ReadOnlySmallWorld.cs +++ /dev/null @@ -1,152 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using ZeroLevel.Services.Serialization; - -namespace ZeroLevel.HNSW -{ - public class ReadOnlySmallWorld - { - private readonly NSWReadOnlyOption _options; - private ReadOnlyVectorSet _vectors; - private ReadOnlyLayer[] _layers; - private int EntryPoint = 0; - private int MaxLayer = 0; - - private ReadOnlySmallWorld() { } - - internal ReadOnlySmallWorld(NSWReadOnlyOption options, Stream stream) - { - _options = options; - Deserialize(stream); - } - - /// - /// Search in the graph K for vectors closest to a given vector - /// - /// Given vector - /// Count of elements for search - /// - /// - public IEnumerable<(int, TItem, float)> Search(TItem vector, int k) - { - foreach (var pair in KNearest(vector, k)) - { - yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); - } - } - - public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context) - { - if (context == null) - { - foreach (var pair in KNearest(vector, k)) - { - yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); - } - } - else - { - foreach (var pair in KNearest(vector, k, context)) - { - yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); - } - } - } - - /// - /// Algorithm 5 - /// - private IEnumerable<(int, float)> KNearest(TItem q, int k) - { - if (_vectors.Count == 0) - { - return Enumerable.Empty<(int, float)>(); - } - var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); - - // W ← ∅ // set for the current nearest elements - var W = new Dictionary(k + 1); - // ep ← get enter point for hnsw - var ep = EntryPoint; - // L ← level of ep // top layer for hnsw - var L = MaxLayer; - // for lc ← L … 1 - for (int layer = L; layer > 0; --layer) - { - // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[layer].KNearestAtLayer(ep, distance, W, 1); - // ep ← get nearest element from W to q - ep = W.OrderBy(p => p.Value).First().Key; - W.Clear(); - } - // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k); - // return K nearest elements from W to q - return W.Select(p => (p.Key, p.Value)); - } - - private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context) - { - if (_vectors.Count == 0) - { - return Enumerable.Empty<(int, float)>(); - } - var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); - - // W ← ∅ // set for the current nearest elements - var W = new Dictionary(k + 1); - // ep ← get enter point for hnsw - var ep = EntryPoint; - // L ← level of ep // top layer for hnsw - var L = MaxLayer; - // for lc ← L … 1 - for (int layer = L; layer > 0; --layer) - { - // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[layer].KNearestAtLayer(ep, distance, W, 1); - // ep ← get nearest element from W to q - ep = W.OrderBy(p => p.Value).First().Key; - W.Clear(); - } - // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k, context); - // return K nearest elements from W to q - return W.Select(p => (p.Key, p.Value)); - } - - public void Serialize(Stream stream) - { - using (var writer = new MemoryStreamWriter(stream)) - { - writer.WriteInt32(EntryPoint); - writer.WriteInt32(MaxLayer); - _vectors.Serialize(writer); - writer.WriteInt32(_layers.Length); - foreach (var l in _layers) - { - l.Serialize(writer); - } - } - } - - public void Deserialize(Stream stream) - { - using (var reader = new MemoryStreamReader(stream)) - { - this.EntryPoint = reader.ReadInt32(); - this.MaxLayer = reader.ReadInt32(); - _vectors = new ReadOnlyVectorSet(); - _vectors.Deserialize(reader); - var countLayers = reader.ReadInt32(); - _layers = new ReadOnlyLayer[countLayers]; - for (int i = 0; i < countLayers; i++) - { - _layers[i] = new ReadOnlyLayer(_vectors); - _layers[i].Deserialize(reader); - } - } - } - } -} diff --git a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs index 05ec8f1..b799e52 100644 --- a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs +++ b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs @@ -6,7 +6,7 @@ namespace ZeroLevel.HNSW.Services { private const int HALF_LONG_BITS = 32; - public static List> DetectClusters(SmallWorld world) + /*public static List> DetectClusters(SmallWorld world) { var links = world.GetNSWLinks(); // 1. Find R - bound between intra-cluster distances and out-of-cluster distances @@ -60,6 +60,6 @@ namespace ZeroLevel.HNSW.Services } } return clusters; - } + }*/ } } diff --git a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs index 10b5f42..5e4a99e 100644 --- a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs +++ b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs @@ -30,75 +30,6 @@ namespace ZeroLevel.HNSW internal int Count => _set.Count; - /// - /// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1 - /// - internal void Relink(int id1, int id2, int id, float distance) - { - long k1old = (((long)(id1)) << HALF_LONG_BITS) + id2; - long k2old = (((long)(id2)) << HALF_LONG_BITS) + id1; - - long k1new = (((long)(id1)) << HALF_LONG_BITS) + id; - long k2new = (((long)(id)) << HALF_LONG_BITS) + id1; - - _rwLock.EnterWriteLock(); - try - { - _set.Remove(k1old); - _set.Remove(k2old); - if (!_set.ContainsKey(k1new)) - _set.Add(k1new, distance); - if (!_set.ContainsKey(k2new)) - _set.Add(k2new, distance); - } - finally - { - _rwLock.ExitWriteLock(); - } - } - - /// - /// Разрывает связи id1 - id2 и id2 - id1, и строит новые id1 - id, id - id1, id2 - id, id - id2 - /// - internal void Relink(int id1, int id2, int id, float distanceToId1, float distanceToId2) - { - long k_id1_id2 = (((long)(id1)) << HALF_LONG_BITS) + id2; - long k_id2_id1 = (((long)(id2)) << HALF_LONG_BITS) + id1; - - long k_id_id1 = (((long)(id)) << HALF_LONG_BITS) + id1; - long k_id1_id = (((long)(id1)) << HALF_LONG_BITS) + id; - - long k_id_id2 = (((long)(id)) << HALF_LONG_BITS) + id2; - long k_id2_id = (((long)(id2)) << HALF_LONG_BITS) + id; - - _rwLock.EnterWriteLock(); - try - { - _set.Remove(k_id1_id2); - _set.Remove(k_id2_id1); - if (!_set.ContainsKey(k_id_id1)) - { - _set.Add(k_id_id1, distanceToId1); - } - if (!_set.ContainsKey(k_id1_id)) - { - _set.Add(k_id1_id, distanceToId1); - } - if (!_set.ContainsKey(k_id_id2)) - { - _set.Add(k_id_id2, distanceToId2); - } - if (!_set.ContainsKey(k_id2_id)) - { - _set.Add(k_id2_id, distanceToId2); - } - } - finally - { - _rwLock.ExitWriteLock(); - } - } - internal IEnumerable<(int, int, float)> FindLinksForId(int id) { _rwLock.EnterReadLock(); @@ -147,43 +78,6 @@ namespace ZeroLevel.HNSW } } - internal void RemoveIndex(int id) - { - long[] forward; - long[] backward; - _rwLock.EnterReadLock(); - try - { - forward = Search(_set, id).Select(pair => pair.Item1).ToArray(); - backward = forward.Select(k => - { - var id1 = k >> HALF_LONG_BITS; - var id2 = k - (id1 << HALF_LONG_BITS); - return (id2 << HALF_LONG_BITS) + id1; - }).ToArray(); - } - finally - { - _rwLock.ExitReadLock(); - } - _rwLock.EnterWriteLock(); - try - { - foreach (var k in forward) - { - _set.Remove(k); - } - foreach (var k in backward) - { - _set.Remove(k); - } - } - finally - { - _rwLock.ExitWriteLock(); - } - } - internal void RemoveIndex(int id1, int id2) { long k1 = (((long)(id1)) << HALF_LONG_BITS) + id2; @@ -230,40 +124,46 @@ namespace ZeroLevel.HNSW return false; } + /* + +function binary_search(A, n, T) is + L := 0 + R := n − 1 + while L ≤ R do + m := floor((L + R) / 2) + if A[m] < T then + L := m + 1 + else if A[m] > T then + R := m − 1 + else: + return m + return unsuccessful + + */ + private static IEnumerable<(long, float)> Search(SortedList set, int index) { - long k = ((long)index) << HALF_LONG_BITS; + long k = ((long)index) << HALF_LONG_BITS; // T int left = 0; int right = set.Count - 1; int mid; long test; - while (left < right) + while (left <= right) { - mid = (right + left) / 2; - test = (set.Keys[mid] >> HALF_LONG_BITS) << HALF_LONG_BITS; + mid = (int)Math.Floor((right + left) / 2d); + test = (set.Keys[mid] >> HALF_LONG_BITS) << HALF_LONG_BITS; // A[m] - if (left == mid || right == mid) + if (test < k) { - if (test == k) - { - return SearchByPosition(set, k, mid); - } - break; + left = mid + 1; } - if (test < k) + else if (test > k) { - left = mid; + right = mid - 1; } else { - if (test == k) - { - return SearchByPosition(set, k, mid); - } - else - { - right = mid; - } + return SearchByPosition(set, k, mid); } } return Enumerable.Empty<(long, float)>(); diff --git a/ZeroLevel.HNSW/Services/Layer.cs b/ZeroLevel.HNSW/Services/Layer.cs index 50d8168..e3a098f 100644 --- a/ZeroLevel.HNSW/Services/Layer.cs +++ b/ZeroLevel.HNSW/Services/Layer.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using ZeroLevel.HNSW.Services; +using ZeroLevel.Services.Pools; using ZeroLevel.Services.Serialization; namespace ZeroLevel.HNSW @@ -14,146 +15,80 @@ namespace ZeroLevel.HNSW { private readonly NSWOptions _options; private readonly VectorSet _vectors; - private readonly CompactBiDirectionalLinksSet _links; - internal SortedList Links => _links.Links; + private readonly LinksSet _links; + //internal SortedList Links => _links.Links; /// /// There are links е the layer /// internal bool HasLinks => (_links.Count > 0); + private int GetM(bool nswLayer) + { + return nswLayer ? 2 * _options.M : _options.M; + } + /// /// HNSW layer /// /// HNSW graph options /// General vector set - internal Layer(NSWOptions options, VectorSet vectors) + internal Layer(NSWOptions options, VectorSet vectors, bool nswLayer) { _options = options; _vectors = vectors; - _links = new CompactBiDirectionalLinksSet(); + _links = new LinksSet(GetM(nswLayer), (id1, id2) => options.Distance(_vectors[id1], _vectors[id2])); } - /// - /// Adding new bidirectional link - /// - /// New node - /// The node with which the connection will be made - /// - /// - /*internal void AddBidirectionallConnections(int q, int p, float qpDistance, bool isMapLayer) + internal int FindEntryPointAtLayer(Func targetCosts) { - // поиск в ширину ближайших узлов к найденному - var nearest = _links.FindLinksForId(p).ToArray(); - // если у найденного узла максимальное количество связей - // if │eConn│ > Mmax // shrink connections of e - if (nearest.Length >= (isMapLayer ? _options.M * 2 : _options.M)) - { - // ищем связь с самой большой дистанцией - float distance = nearest[0].Item3; - int index = 0; - for (int ni = 1; ni < nearest.Length; ni++) - { - // Если осталась ссылка узла на себя, удаляем ее в первую очередь - if (nearest[ni].Item1 == nearest[ni].Item2) - { - index = ni; - break; - } - if (nearest[ni].Item3 > distance) - { - index = ni; - distance = nearest[ni].Item3; - } - } - // делаем перелинковку вставляя новый узел между найденными - var id1 = nearest[index].Item1; - var id2 = nearest[index].Item2; - _links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q])); - } - else + if (_links.Count == 0) return EntryPoint; + var set = new HashSet(_links.Items().Select(p => p.Item1)); + int minId = -1; + float minDist = float.MaxValue; + foreach (var id in set) { - if (nearest.Length == 1 && nearest[0].Item1 == nearest[0].Item2) - { - // убираем связи на самих себя - var id1 = nearest[0].Item1; - var id2 = nearest[0].Item2; - _links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q])); - } - else + var d = targetCosts(id); + if (d < minDist && Math.Abs(d) > float.Epsilon) { - // добавляем связь нового узла к найденному - _links.Add(q, p, qpDistance); + minDist = d; + minId = id; } } - }*/ - - internal void AddBidirectionallConnections(int q, int p, float qpDistance) - { - _links.Add(q, p, qpDistance); + return minId; } - internal void TrimLinks(int q, bool isMapLayer) + internal void AddBidirectionallConnections(int q, int p) { - var M = (isMapLayer ? _options.M * 2 : _options.M); - // поиск в ширину ближайших узлов к найденному - var nearest = _links.FindLinksForId(q).ToArray(); - - if (nearest.Length <= M && nearest.Length > 1) + if (q == p) { - foreach (var l in nearest) + if (EntryPoint >= 0) { - if (l.Item1 == l.Item2) - { - _links.RemoveIndex(l.Item1, l.Item2); - } + _links.Add(q, EntryPoint); } - } - else if (nearest.Length > M) - { - var removeCount = nearest.Length - M; - foreach (var l in nearest.OrderByDescending(n => n.Item3).Take(removeCount)) + else { - _links.RemoveIndex(l.Item1, l.Item2); + EntryPoint = q; } } + else + { + _links.Add(q, p); + } } + private int EntryPoint = -1; - /// - /// Adding a node with a connection to itself - /// - /// - internal void Append(int q) - { - _links.Add(q, q, 0); - } + internal void Trim(int id) => _links.Trim(id); #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf - internal int FindEntryPointAtLayer(Func targetCosts) - { - var set = new HashSet(_links.Items().Select(p => p.Item1)); - int minId = -1; - float minDist = float.MaxValue; - foreach (var id in set) - { - var d = targetCosts(id); - if (d < minDist && Math.Abs(d) > float.Epsilon) - { - minDist = d; - minId = id; - } - } - return minId; - } - /// /// Algorithm 2 /// /// query element /// enter points ep /// Output: ef closest neighbors to q - internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, IEnumerable<(int, float)> w, int ef) + internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, int ef) { /* * v ← ep // set of visited elements @@ -175,55 +110,60 @@ namespace ZeroLevel.HNSW * remove furthest element from W to q * return W */ - var v = new VisitedBitSet(_vectors.Count, _options.M); - // v ← ep // set of visited elements - v.Add(entryPointId); - var W = new MaxHeap(ef + 1); - foreach (var i in w) W.Push(i); + int farthestId; + float farthestDistance; var d = targetCosts(entryPointId); - // C ← ep // set of candidates + + var v = new VisitedBitSet(_vectors.Count, _options.M); + // * v ← ep // set of visited elements + v.Add(entryPointId); + // * C ← ep // set of candidates var C = new MinHeap(ef); C.Push((entryPointId, d)); - // W ← ep // dynamic list of found nearest neighbors + // * W ← ep // dynamic list of found nearest neighbors + var W = new MaxHeap(ef + 1); W.Push((entryPointId, d)); - int farthestId; - float farthestDistance; - - // run bfs + // * while │C│ > 0 while (C.Count > 0) { - // get next candidate to check and expand - var toExpand = C.Pop(); - if (W.TryPeek(out _, out farthestDistance) && toExpand.Item2 > farthestDistance) + // * c ← extract nearest element from C to q + var c = C.Pop(); + // * f ← get furthest element from W to q + // * if distance(c, q) > distance(f, q) + if (W.TryPeek(out _, out farthestDistance) && c.Item2 > farthestDistance) { - // the closest candidate is farther than farthest result + // * break // all elements in W are evaluated break; } - // expand candidate - var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); - for (int i = 0; i < neighboursIds.Length; ++i) + // * for each e ∈ neighbourhood(c) at layer lc // update C and W + foreach (var e in GetNeighbors(c.Item1)) { - int neighbourId = neighboursIds[i]; - if (!v.Contains(neighbourId)) + // * if e ∉ v + if (!v.Contains(e)) { - // enqueue perspective neighbours to expansion list + // * v ← v ⋃ e + v.Add(e); + // * f ← get furthest element from W to q W.TryPeek(out farthestId, out farthestDistance); - var neighbourDistance = targetCosts(neighbourId); - if (W.Count < ef || (farthestId >= 0 && neighbourDistance < farthestDistance)) + var eDistance = targetCosts(e); + // * if distance(e, q) < distance(f, q) or │W│ < ef + if (W.Count < ef || (farthestId >= 0 && eDistance < farthestDistance)) { - C.Push((neighbourId, neighbourDistance)); - - W.Push((neighbourId, neighbourDistance)); + // * C ← C ⋃ e + C.Push((e, eDistance)); + // * W ← W ⋃ e + W.Push((e, eDistance)); + // * if │W│ > ef if (W.Count > ef) { + // * remove furthest element from W to q W.Pop(); } } - v.Add(neighbourId); } } } @@ -238,40 +178,22 @@ namespace ZeroLevel.HNSW /// query element /// enter points ep /// Output: ef closest neighbors to q - internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, IEnumerable<(int, float)> w, int ef, SearchContext context) + /* + internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, int ef, SearchContext context) { - /* - * v ← ep // set of visited elements - * C ← ep // set of candidates - * W ← ep // dynamic list of found nearest neighbors - * while │C│ > 0 - * c ← extract nearest element from C to q - * f ← get furthest element from W to q - * if distance(c, q) > distance(f, q) - * break // all elements in W are evaluated - * for each e ∈ neighbourhood(c) at layer lc // update C and W - * if e ∉ v - * v ← v ⋃ e - * f ← get furthest element from W to q - * if distance(e, q) < distance(f, q) or │W│ < ef - * C ← C ⋃ e - * W ← W ⋃ e - * if │W│ > ef - * remove furthest element from W to q - * return W - */ + int farthestId; + float farthestDistance; + var d = targetCosts(entryPointId); + var v = new VisitedBitSet(_vectors.Count, _options.M); // v ← ep // set of visited elements v.Add(entryPointId); - - var W = new MaxHeap(ef + 1); - foreach (var i in w) W.Push(i); - // C ← ep // set of candidates var C = new MinHeap(ef); - var d = targetCosts(entryPointId); C.Push((entryPointId, d)); // W ← ep // dynamic list of found nearest neighbors + var W = new MaxHeap(ef + 1); + // W ← ep // dynamic list of found nearest neighbors if (context.IsActiveNode(entryPointId)) { W.Push((entryPointId, d)); @@ -281,14 +203,10 @@ namespace ZeroLevel.HNSW { // get next candidate to check and expand var toExpand = C.Pop(); - if (W.Count > 0) + if (W.TryPeek(out _, out farthestDistance) && toExpand.Item2 > farthestDistance) { - if (W.TryPeek(out _, out var dist)) - if (toExpand.Item2 > dist) - { - // the closest candidate is farther than farthest result - break; - } + // the closest candidate is farther than farthest result + break; } // expand candidate @@ -298,11 +216,12 @@ namespace ZeroLevel.HNSW int neighbourId = neighboursIds[i]; if (!v.Contains(neighbourId)) { + W.TryPeek(out farthestId, out farthestDistance); // enqueue perspective neighbours to expansion list var neighbourDistance = targetCosts(neighbourId); if (context.IsActiveNode(neighbourId)) { - if (W.Count < ef || (W.Count > 0 && (W.TryPeek(out _, out var dist) && neighbourDistance < dist))) + if (W.Count < ef || (farthestId >= 0 && neighbourDistance < farthestDistance)) { W.Push((neighbourId, neighbourDistance)); if (W.Count > ef) @@ -311,7 +230,7 @@ namespace ZeroLevel.HNSW } } } - if (W.Count < ef) + if (W.TryPeek(out _, out farthestDistance) && neighbourDistance < farthestDistance) { C.Push((neighbourId, neighbourDistance)); } @@ -323,6 +242,7 @@ namespace ZeroLevel.HNSW v.Clear(); return W; } + */ /// /// Algorithm 2, modified for LookAlike @@ -330,28 +250,9 @@ namespace ZeroLevel.HNSW /// query element /// enter points ep /// Output: ef closest neighbors to q + /* internal IEnumerable<(int, float)> KNearestAtLayer(IEnumerable<(int, float)> w, int ef, SearchContext context) { - /* - * v ← ep // set of visited elements - * C ← ep // set of candidates - * W ← ep // dynamic list of found nearest neighbors - * while │C│ > 0 - * c ← extract nearest element from C to q - * f ← get furthest element from W to q - * if distance(c, q) > distance(f, q) - * break // all elements in W are evaluated - * for each e ∈ neighbourhood(c) at layer lc // update C and W - * if e ∉ v - * v ← v ⋃ e - * f ← get furthest element from W to q - * if distance(e, q) < distance(f, q) or │W│ < ef - * C ← C ⋃ e - * W ← W ⋃ e - * if │W│ > ef - * remove furthest element from W to q - * return W - */ // v ← ep // set of visited elements var v = new VisitedBitSet(_vectors.Count, _options.M); // C ← ep // set of candidates @@ -445,106 +346,10 @@ namespace ZeroLevel.HNSW v.Clear(); return W; } - - /// - /// Algorithm 3 - /// - internal MaxHeap SELECT_NEIGHBORS_SIMPLE(IEnumerable<(int, float)> w, int M) - { - var W = new MaxHeap(w.Count()); - foreach (var i in w) W.Push(i); - var bestN = M; - if (W.Count > bestN) - { - while (W.Count > bestN) - { - W.Pop(); - } - } - return W; - } - - - - /// - /// Algorithm 4 - /// - /// base element - /// candidate elements - /// flag indicating whether or not to extend candidate list - /// flag indicating whether or not to add discarded elements - /// Output: M elements selected by the heuristic - internal MaxHeap SELECT_NEIGHBORS_HEURISTIC(Func distance, IEnumerable<(int, float)> w, int M) - { - // R ← ∅ - var R = new MaxHeap(_options.EFConstruction); - // W ← C // working queue for the candidates - var W = new MaxHeap(_options.EFConstruction + 1); - foreach (var i in w) W.Push(i); - // if extendCandidates // extend candidates by their neighbors - if (_options.ExpandBestSelection) - { - var extendBuffer = new HashSet(); - // for each e ∈ C - foreach (var e in W) - { - var neighbors = GetNeighbors(e.Item1); - // for each e_adj ∈ neighbourhood(e) at layer lc - foreach (var e_adj in neighbors) - { - // if eadj ∉ W - if (extendBuffer.Contains(e_adj) == false) - { - extendBuffer.Add(e_adj); - } - } - } - // W ← W ⋃ eadj - foreach (var id in extendBuffer) - { - W.Push((id, distance(id))); - } - } - - // Wd ← ∅ // queue for the discarded candidates - var Wd = new MinHeap(_options.EFConstruction); - // while │W│ > 0 and │R│< M - while (W.Count > 0 && R.Count < M) - { - // e ← extract nearest element from W to q - var (e, ed) = W.Pop(); - var (fe, fd) = R.Pop(); - - // if e is closer to q compared to any element from R - if (R.Count == 0 || - ed < fd) - { - // R ← R ⋃ e - R.Push((e, ed)); - } - else - { - // Wd ← Wd ⋃ e - Wd.Push((e, ed)); - } - } - // if keepPrunedConnections // add some of the discarded // connections from Wd - if (_options.KeepPrunedConnections) - { - // while │Wd│> 0 and │R│< M - while (Wd.Count > 0 && R.Count < M) - { - // R ← R ⋃ extract nearest element from Wd to q - var nearest = Wd.Pop(); - R.Push((nearest.Item1, nearest.Item2)); - } - } - // return R - return R; - } + */ #endregion - private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2); + private IEnumerable GetNeighbors(int id) => _links.FindNeighbors(id); public void Serialize(IBinaryWriter writer) { @@ -556,6 +361,6 @@ namespace ZeroLevel.HNSW _links.Deserialize(reader); } - internal Histogram GetHistogram(HistogramMode mode) => _links.CalculateHistogram(mode); + // internal Histogram GetHistogram(HistogramMode mode) => _links.CalculateHistogram(mode); } } diff --git a/ZeroLevel.HNSW/Services/LinksSet.cs b/ZeroLevel.HNSW/Services/LinksSet.cs new file mode 100644 index 0000000..0201b0a --- /dev/null +++ b/ZeroLevel.HNSW/Services/LinksSet.cs @@ -0,0 +1,279 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW +{ + /* + internal struct Link + : IEquatable + { + public int Id; + public float Distance; + + public override int GetHashCode() + { + return Id.GetHashCode(); + } + + public override bool Equals(object obj) + { + if (obj is Link) + return this.Equals((Link)obj); + return false; + } + + public bool Equals(Link other) + { + return this.Id == other.Id; + } + } + + public class LinksSetWithCachee + { + private ConcurrentDictionary> _set = new ConcurrentDictionary>(); + + internal int Count => _set.Count; + + private readonly int _M; + private readonly Func _distance; + + public LinksSetWithCachee(int M, Func distance) + { + _distance = distance; + _M = M; + } + + internal IEnumerable FindNeighbors(int id) + { + if (_set.ContainsKey(id)) + { + return _set[id].Select(l=>l.Id); + } + return Enumerable.Empty(); + } + + internal void RemoveIndex(int id1, int id2) + { + var link1 = new Link { Id = id1 }; + var link2 = new Link { Id = id2 }; + + _set[id1].Remove(link2); + _set[id2].Remove(link1); + } + + internal bool Add(int id1, int id2, float distance) + { + if (!_set.ContainsKey(id1)) + { + _set[id1] = new HashSet(); + } + if (!_set.ContainsKey(id2)) + { + _set[id2] = new HashSet(); + } + var r1 = _set[id1].Add(new Link { Id = id2, Distance = distance }); + var r2 = _set[id2].Add(new Link { Id = id1, Distance = distance }); + + //TrimSet(_set[id1]); + TrimSet(id2, _set[id2]); + + return r1 || r2; + } + + internal void Trim(int id) => TrimSet(id, _set[id]); + + private void TrimSet(int id, HashSet set) + { + if (set.Count > _M) + { + var removeCount = set.Count - _M; + var removeLinks = set.OrderByDescending(n => n.Distance).Take(removeCount).ToArray(); + foreach (var l in removeLinks) + { + set.Remove(l); + } + } + } + + public void Dispose() + { + _set.Clear(); + _set = null; + } + + private const int HALF_LONG_BITS = 32; + public void Serialize(IBinaryWriter writer) + { + writer.WriteBoolean(false); // true - set with weights + writer.WriteInt32(_set.Sum(pair => pair.Value.Count)); + foreach (var record in _set) + { + var id = record.Key; + foreach (var r in record.Value) + { + var key = (((long)(id)) << HALF_LONG_BITS) + r; + writer.WriteLong(key); + } + } + } + + public void Deserialize(IBinaryReader reader) + { + if (reader.ReadBoolean() == false) + { + throw new InvalidOperationException("Incompatible data format. The set does not contain weights."); + } + _set.Clear(); + _set = null; + var count = reader.ReadInt32(); + _set = new ConcurrentDictionary>(); + for (int i = 0; i < count; i++) + { + var key = reader.ReadLong(); + + var id1 = (int)(key >> HALF_LONG_BITS); + var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS)); + + if (!_set.ContainsKey(id1)) + { + _set[id1] = new HashSet(); + } + _set[id1].Add(id2); + } + } + } + */ + + public class LinksSet + { + private ConcurrentDictionary> _set = new ConcurrentDictionary>(); + + internal IDictionary> Links => _set; + + internal int Count => _set.Count; + + private readonly int _M; + private readonly Func _distance; + + public LinksSet(int M, Func distance) + { + _distance = distance; + _M = M; + } + + internal IEnumerable<(int, int)> FindLinksForId(int id) + { + if (_set.ContainsKey(id)) + { + return _set[id].Select(v => (id, v)); + } + return Enumerable.Empty<(int, int)>(); + } + + internal IEnumerable FindNeighbors(int id) + { + if (_set.ContainsKey(id)) + { + return _set[id]; + } + return Enumerable.Empty(); + } + + internal IEnumerable<(int, int)> Items() + { + return _set + .SelectMany(pair => _set[pair.Key] + .Select(v => (pair.Key, v))); + } + + internal void RemoveIndex(int id1, int id2) + { + _set[id1].Remove(id2); + _set[id2].Remove(id1); + } + + internal bool Add(int id1, int id2) + { + if (!_set.ContainsKey(id1)) + { + _set[id1] = new HashSet(_M + 1); + } + if (!_set.ContainsKey(id2)) + { + _set[id2] = new HashSet(_M + 1); + } + var r1 = _set[id1].Add(id2); + var r2 = _set[id2].Add(id1); + + TrimSet(id1, _set[id1]); + TrimSet(id2, _set[id2]); + + return r1 || r2; + } + + internal void Trim(int id) => TrimSet(id, _set[id]); + + private void TrimSet(int id, HashSet set) + { + if (set.Count > _M) + { + var removeCount = set.Count - _M; + var removeLinks = set.OrderByDescending(n => _distance(id, n)).Take(removeCount).ToArray(); + foreach (var l in removeLinks) + { + set.Remove(l); + } + } + } + + public void Dispose() + { + _set.Clear(); + _set = null; + } + + private const int HALF_LONG_BITS = 32; + public void Serialize(IBinaryWriter writer) + { + writer.WriteBoolean(false); // true - set with weights + writer.WriteInt32(_set.Sum(pair => pair.Value.Count)); + foreach (var record in _set) + { + var id = record.Key; + foreach (var r in record.Value) + { + var key = (((long)(id)) << HALF_LONG_BITS) + r; + writer.WriteLong(key); + } + } + } + + public void Deserialize(IBinaryReader reader) + { + if (reader.ReadBoolean() == false) + { + throw new InvalidOperationException("Incompatible data format. The set does not contain weights."); + } + _set.Clear(); + _set = null; + var count = reader.ReadInt32(); + _set = new ConcurrentDictionary>(); + for (int i = 0; i < count; i++) + { + var key = reader.ReadLong(); + + var id1 = (int)(key >> HALF_LONG_BITS); + var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS)); + + if (!_set.ContainsKey(id1)) + { + _set[id1] = new HashSet(); + } + _set[id1].Add(id2); + } + } + } +} diff --git a/ZeroLevel.HNSW/Services/Pool.cs b/ZeroLevel.HNSW/Services/Pool.cs new file mode 100644 index 0000000..e26f698 --- /dev/null +++ b/ZeroLevel.HNSW/Services/Pool.cs @@ -0,0 +1,287 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; + +namespace ZeroLevel.Services.Pools +{ + public enum LoadingMode { Eager, Lazy, LazyExpanding }; + + public enum AccessMode { FIFO, LIFO, Circular }; + + public sealed class Pool : IDisposable + { + private bool isDisposed; + private Func, T> factory; + private LoadingMode loadingMode; + private IItemStore itemStore; + private int size; + private int count; + private Semaphore sync; + + public Pool(int size, Func, T> factory) + : this(size, factory, LoadingMode.Lazy, AccessMode.FIFO) + { + } + + public Pool(int size, Func, T> factory, + LoadingMode loadingMode, AccessMode accessMode) + { + if (size <= 0) + throw new ArgumentOutOfRangeException("size", size, + "Argument 'size' must be greater than zero."); + if (factory == null) + throw new ArgumentNullException("factory"); + + this.size = size; + this.factory = factory; + sync = new Semaphore(size, size); + this.loadingMode = loadingMode; + this.itemStore = CreateItemStore(accessMode, size); + if (loadingMode == LoadingMode.Eager) + { + PreloadItems(); + } + } + + public T Acquire() + { + sync.WaitOne(); + switch (loadingMode) + { + case LoadingMode.Eager: + return AcquireEager(); + case LoadingMode.Lazy: + return AcquireLazy(); + default: + Debug.Assert(loadingMode == LoadingMode.LazyExpanding, + "Unknown LoadingMode encountered in Acquire method."); + return AcquireLazyExpanding(); + } + } + + public void Release(T item) + { + lock (itemStore) + { + itemStore.Store(item); + } + sync.Release(); + } + + public void Dispose() + { + if (isDisposed) + { + return; + } + isDisposed = true; + if (typeof(IDisposable).IsAssignableFrom(typeof(T))) + { + lock (itemStore) + { + while (itemStore.Count > 0) + { + IDisposable disposable = (IDisposable)itemStore.Fetch(); + disposable.Dispose(); + } + } + } + sync.Close(); + } + + #region Acquisition + + private T AcquireEager() + { + lock (itemStore) + { + return itemStore.Fetch(); + } + } + + private T AcquireLazy() + { + lock (itemStore) + { + if (itemStore.Count > 0) + { + return itemStore.Fetch(); + } + } + Interlocked.Increment(ref count); + return factory(this); + } + + private T AcquireLazyExpanding() + { + bool shouldExpand = false; + if (count < size) + { + int newCount = Interlocked.Increment(ref count); + if (newCount <= size) + { + shouldExpand = true; + } + else + { + // Another thread took the last spot - use the store instead + Interlocked.Decrement(ref count); + } + } + if (shouldExpand) + { + return factory(this); + } + else + { + lock (itemStore) + { + return itemStore.Fetch(); + } + } + } + + private void PreloadItems() + { + for (int i = 0; i < size; i++) + { + T item = factory(this); + itemStore.Store(item); + } + count = size; + } + + #endregion + + #region Collection Wrappers + + interface IItemStore + { + T Fetch(); + void Store(T item); + int Count { get; } + } + + private IItemStore CreateItemStore(AccessMode mode, int capacity) + { + switch (mode) + { + case AccessMode.FIFO: + return new QueueStore(capacity); + case AccessMode.LIFO: + return new StackStore(capacity); + default: + Debug.Assert(mode == AccessMode.Circular, + "Invalid AccessMode in CreateItemStore"); + return new CircularStore(capacity); + } + } + + class QueueStore : Queue, IItemStore + { + public QueueStore(int capacity) : base(capacity) + { + } + + public T Fetch() + { + return Dequeue(); + } + + public void Store(T item) + { + Enqueue(item); + } + } + + class StackStore : Stack, IItemStore + { + public StackStore(int capacity) : base(capacity) + { + } + + public T Fetch() + { + return Pop(); + } + + public void Store(T item) + { + Push(item); + } + } + + class CircularStore : IItemStore + { + private List slots; + private int freeSlotCount; + private int position = -1; + + public CircularStore(int capacity) + { + slots = new List(capacity); + } + + public T Fetch() + { + if (Count == 0) + throw new InvalidOperationException("The buffer is empty."); + + int startPosition = position; + do + { + Advance(); + Slot slot = slots[position]; + if (!slot.IsInUse) + { + slot.IsInUse = true; + --freeSlotCount; + return slot.Item; + } + } while (startPosition != position); + throw new InvalidOperationException("No free slots."); + } + + public void Store(T item) + { + Slot slot = slots.Find(s => object.Equals(s.Item, item)); + if (slot == null) + { + slot = new Slot(item); + slots.Add(slot); + } + slot.IsInUse = false; + ++freeSlotCount; + } + + public int Count + { + get { return freeSlotCount; } + } + + private void Advance() + { + position = (position + 1) % slots.Count; + } + + class Slot + { + public Slot(T item) + { + this.Item = item; + } + + public T Item { get; private set; } + public bool IsInUse { get; set; } + } + } + + #endregion + + public bool IsDisposed + { + get { return isDisposed; } + } + } +} \ No newline at end of file diff --git a/ZeroLevel.HNSW/Services/Quantizator.cs b/ZeroLevel.HNSW/Services/Quantizator.cs new file mode 100644 index 0000000..0364af0 --- /dev/null +++ b/ZeroLevel.HNSW/Services/Quantizator.cs @@ -0,0 +1,86 @@ +using System; + +namespace ZeroLevel.HNSW.Services +{ + public class Quantizator + { + private readonly float _min; + private readonly float _max; + private readonly float _diff; + + public Quantizator(float min, float max) + { + _min = min; + _max = max; + _diff = _max - _min; + } + + public byte[] Quantize(float[] v) + { + var result = new byte[v.Length]; + for (int i = 0; i < v.Length; i++) + { + result[i] = _quantizeInRange(v[i]); + } + return result; + } + + public int[] QuantizeToInt(float[] v) + { + if (v.Length % 4 != 0) + { + throw new ArgumentOutOfRangeException("v.Length % 4 must be zero!"); + } + var result = new int[v.Length / 4]; + byte[] buf = new byte[4]; + for (int i = 0; i < v.Length; i += 4) + { + buf[0] = _quantizeInRange(v[i]); + buf[1] = _quantizeInRange(v[i + 1]); + buf[2] = _quantizeInRange(v[i + 2]); + buf[3] = _quantizeInRange(v[i + 3]); + result[(i >> 2)] = BitConverter.ToInt32(buf); + } + return result; + } + + public long[] QuantizeToLong(float[] v) + { + if (v.Length % 8 != 0) + { + throw new ArgumentOutOfRangeException("v.Length % 8 must be zero!"); + } + var result = new long[v.Length / 8]; + byte[] buf = new byte[8]; + for (int i = 0; i < v.Length; i += 8) + { + buf[0] = _quantizeInRange(v[i + 0]); + buf[1] = _quantizeInRange(v[i + 1]); + buf[2] = _quantizeInRange(v[i + 2]); + buf[3] = _quantizeInRange(v[i + 3]); + buf[4] = _quantizeInRange(v[i + 4]); + buf[5] = _quantizeInRange(v[i + 5]); + buf[6] = _quantizeInRange(v[i + 6]); + buf[7] = _quantizeInRange(v[i + 7]); + + result[(i >> 3)] = BitConverter.ToInt64(buf); + } + return result; + } + + //Map x in [0,1] to {0, 1, ..., 255} + private byte _quantize(float x) + { + x = (int)Math.Floor(256 * x); + if (x < 0) return 0; + else if (x > 255) return 255; + else return (byte)x; + } + + //Map x in [min,max] to {0, 1, ..., 255} + private byte _quantizeInRange(float x) + { + return _quantize((x - _min) / (_diff)); + } + } +} diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs deleted file mode 100644 index 2b190ff..0000000 --- a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyCompactBiDirectionalLinksSet.cs +++ /dev/null @@ -1,107 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using ZeroLevel.Services.Serialization; - -namespace ZeroLevel.HNSW -{ - internal sealed class ReadOnlyCompactBiDirectionalLinksSet - : IBinarySerializable, IDisposable - { - private const int HALF_LONG_BITS = 32; - - private Dictionary _set = new Dictionary(); - - internal int Count => _set.Count; - - internal IEnumerable FindLinksForId(int id) - { - if (_set.ContainsKey(id)) - { - return _set[id]; - } - return Enumerable.Empty(); - } - - public void Dispose() - { - _set.Clear(); - _set = null; - } - - public void Serialize(IBinaryWriter writer) - { - writer.WriteBoolean(false); // false - set without weights - writer.WriteInt32(_set.Count); - foreach (var record in _set) - { - writer.WriteInt32(record.Key); - writer.WriteCollection(record.Value); - } - } - - public void Deserialize(IBinaryReader reader) - { - _set.Clear(); - _set = null; - if (reader.ReadBoolean() == false) - { - var count = reader.ReadInt32(); - _set = new Dictionary(count); - for (int i = 0; i < count; i++) - { - var key = reader.ReadInt32(); - var value = reader.ReadInt32Array(); - _set.Add(key, value); - } - } - else - { - var count = reader.ReadInt32(); - _set = new Dictionary(count); - - // hack, We know that an sortedset has been saved - long key; - int id1, id2; - var prevId = -1; - var set = new HashSet(); - - for (int i = 0; i < count; i++) - { - key = reader.ReadLong(); - id1 = (int)(key >> HALF_LONG_BITS); - id2 = (int)(key - (((long)id1) << HALF_LONG_BITS)); - - reader.ReadFloat(); // SKIP - - if (prevId == -1) - { - prevId = id1; - if (id1 != id2) - { - set.Add(id2); - } - } - else if (prevId != id1) - { - _set.Add(prevId, set.ToArray()); - set.Clear(); - prevId = id1; - } - else - { - if (id1 != id2) - { - set.Add(id2); - } - } - } - if (set.Count > 0) - { - _set.Add(prevId, set.ToArray()); - set.Clear(); - } - } - } - } -} diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs deleted file mode 100644 index 8273cfb..0000000 --- a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs +++ /dev/null @@ -1,209 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using ZeroLevel.Services.Serialization; - -namespace ZeroLevel.HNSW -{ - /// - /// NSW graph - /// - internal sealed class ReadOnlyLayer - : IBinarySerializable - { - private readonly ReadOnlyVectorSet _vectors; - private readonly ReadOnlyCompactBiDirectionalLinksSet _links; - - /// - /// HNSW layer - /// - /// General vector set - internal ReadOnlyLayer(ReadOnlyVectorSet vectors) - { - _vectors = vectors; - _links = new ReadOnlyCompactBiDirectionalLinksSet(); - } - - #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf - /// - /// Algorithm 2 - /// - /// query element - /// enter points ep - /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef) - { - /* - * v ← ep // set of visited elements - * C ← ep // set of candidates - * W ← ep // dynamic list of found nearest neighbors - * while │C│ > 0 - * c ← extract nearest element from C to q - * f ← get furthest element from W to q - * if distance(c, q) > distance(f, q) - * break // all elements in W are evaluated - * for each e ∈ neighbourhood(c) at layer lc // update C and W - * if e ∉ v - * v ← v ⋃ e - * f ← get furthest element from W to q - * if distance(e, q) < distance(f, q) or │W│ < ef - * C ← C ⋃ e - * W ← W ⋃ e - * if │W│ > ef - * remove furthest element from W to q - * return W - */ - var v = new VisitedBitSet(_vectors.Count, 1); - // v ← ep // set of visited elements - v.Add(entryPointId); - // C ← ep // set of candidates - var C = new Dictionary(); - C.Add(entryPointId, targetCosts(entryPointId)); - // W ← ep // dynamic list of found nearest neighbors - W.Add(entryPointId, C[entryPointId]); - - var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); }); - var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); - var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); - // run bfs - while (C.Count > 0) - { - // get next candidate to check and expand - var toExpand = popCandidate(); - var farthestResult = fartherFromResult(); - if (toExpand.Item2 > farthestResult.Item2) - { - // the closest candidate is farther than farthest result - break; - } - - // expand candidate - var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); - for (int i = 0; i < neighboursIds.Length; ++i) - { - int neighbourId = neighboursIds[i]; - if (!v.Contains(neighbourId)) - { - // enqueue perspective neighbours to expansion list - farthestResult = fartherFromResult(); - - var neighbourDistance = targetCosts(neighbourId); - if (W.Count < ef || neighbourDistance < farthestResult.Item2) - { - C.Add(neighbourId, neighbourDistance); - W.Add(neighbourId, neighbourDistance); - if (W.Count > ef) - { - fartherPopFromResult(); - } - } - v.Add(neighbourId); - } - } - } - C.Clear(); - v.Clear(); - } - - /// - /// Algorithm 2 - /// - /// query element - /// enter points ep - /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, SearchContext context) - { - /* - * v ← ep // set of visited elements - * C ← ep // set of candidates - * W ← ep // dynamic list of found nearest neighbors - * while │C│ > 0 - * c ← extract nearest element from C to q - * f ← get furthest element from W to q - * if distance(c, q) > distance(f, q) - * break // all elements in W are evaluated - * for each e ∈ neighbourhood(c) at layer lc // update C and W - * if e ∉ v - * v ← v ⋃ e - * f ← get furthest element from W to q - * if distance(e, q) < distance(f, q) or │W│ < ef - * C ← C ⋃ e - * W ← W ⋃ e - * if │W│ > ef - * remove furthest element from W to q - * return W - */ - var v = new VisitedBitSet(_vectors.Count, 1); - // v ← ep // set of visited elements - v.Add(entryPointId); - // C ← ep // set of candidates - var C = new Dictionary(); - C.Add(entryPointId, targetCosts(entryPointId)); - // W ← ep // dynamic list of found nearest neighbors - if (context.IsActiveNode(entryPointId)) - { - W.Add(entryPointId, C[entryPointId]); - } - var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); }); - var farthestDistance = new Func(() => { var pair = W.OrderByDescending(e => e.Value).First(); return pair.Value; }); - var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); - // run bfs - while (C.Count > 0) - { - // get next candidate to check and expand - var toExpand = popCandidate(); - if (W.Count > 0) - { - if (toExpand.Item2 > farthestDistance()) - { - // the closest candidate is farther than farthest result - break; - } - } - - // expand candidate - var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); - for (int i = 0; i < neighboursIds.Length; ++i) - { - int neighbourId = neighboursIds[i]; - if (!v.Contains(neighbourId)) - { - // enqueue perspective neighbours to expansion list - var neighbourDistance = targetCosts(neighbourId); - if (context.IsActiveNode(neighbourId)) - { - if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) - { - W.Add(neighbourId, neighbourDistance); - if (W.Count > ef) - { - fartherPopFromResult(); - } - } - } - if (W.Count < ef) - { - C.Add(neighbourId, neighbourDistance); - } - v.Add(neighbourId); - } - } - } - C.Clear(); - v.Clear(); - } - #endregion - - private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id); - - public void Serialize(IBinaryWriter writer) - { - _links.Serialize(writer); - } - - public void Deserialize(IBinaryReader reader) - { - _links.Deserialize(reader); - } - } -} diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs deleted file mode 100644 index 50801f2..0000000 --- a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyVectorSet.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.Collections.Generic; -using ZeroLevel.Services.Serialization; - -namespace ZeroLevel.HNSW -{ - internal sealed class ReadOnlyVectorSet - : IBinarySerializable - { - private List _set = new List(); - - internal T this[int index] => _set[index]; - internal int Count => _set.Count; - - public void Deserialize(IBinaryReader reader) - { - int count = reader.ReadInt32(); - _set = new List(count + 1); - for (int i = 0; i < count; i++) - { - _set.Add(reader.ReadCompatible()); - } - } - - public void Serialize(IBinaryWriter writer) - { - writer.WriteInt32(_set.Count); - foreach (var r in _set) - { - writer.WriteCompatible(r); - } - } - } -} diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index 1be1e5f..74aa09e 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -17,7 +17,6 @@ namespace ZeroLevel.HNSW private int MaxLayer = 0; private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); - internal SortedList GetNSWLinks() => _layers[0].Links; public SmallWorld(NSWOptions options) { @@ -27,7 +26,7 @@ namespace ZeroLevel.HNSW _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M); for (int i = 0; i < _options.LayersCount; i++) { - _layers[i] = new Layer(_options, _vectors); + _layers[i] = new Layer(_options, _vectors, i == 0); } } @@ -51,7 +50,7 @@ namespace ZeroLevel.HNSW yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); } } - + /* public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context) { if (context == null) @@ -84,6 +83,7 @@ namespace ZeroLevel.HNSW } } } + */ /// /// Adding vectors batch @@ -119,21 +119,23 @@ namespace ZeroLevel.HNSW var W = new MinHeap(_options.EFConstruction + 1); // ep ← get enter point for hnsw var ep = _layers[MaxLayer].FindEntryPointAtLayer(distance); - if (ep == -1) ep = EntryPoint; + if (ep == -1) + ep = EntryPoint; + var epDist = distance(ep); // L ← level of ep // top layer for hnsw var L = MaxLayer; // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level int l = _layerLevelGenerator.GetRandomLayer(); - // for lc ← L … l+1 + // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа - int id; float value; + // for lc ← L … l+1 for (int lc = L; lc > l; --lc) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, W, 1)) + foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, 1)) { W.Push(i); } @@ -143,7 +145,6 @@ namespace ZeroLevel.HNSW ep = id; epDist = value; } - _layers[lc].TrimLinks(q, lc == 0); W.Clear(); } //for lc ← min(L, l) … 0 @@ -152,12 +153,12 @@ namespace ZeroLevel.HNSW { if (_layers[lc].HasLinks == false) { - _layers[lc].Append(q); + _layers[lc].AddBidirectionallConnections(q, q); } else { // W ← SEARCH - LAYER(q, ep, efConstruction, lc) - foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction)) + foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, _options.EFConstruction)) { W.Push(i); } @@ -170,13 +171,13 @@ namespace ZeroLevel.HNSW } // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 - var neighbors = SelectBestForConnecting(lc, distance, W); + var neighbors = SelectBestForConnecting(lc, W); // add bidirectionall connectionts from neighbors to q at layer lc // for each e ∈ neighbors // shrink connections if needed foreach (var e in neighbors) { // eConn ← neighbourhood(e) at layer lc - _layers[lc].AddBidirectionallConnections(q, e.Item1, e.Item2); + _layers[lc].AddBidirectionallConnections(q, e.Item1); // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer if (e.Item2 < epDist) { @@ -184,8 +185,6 @@ namespace ZeroLevel.HNSW epDist = e.Item2; } } - - _layers[lc].TrimLinks(q, lc == 0); W.Clear(); } } @@ -218,11 +217,11 @@ namespace ZeroLevel.HNSW return layer == 0 ? 2 * _options.M : _options.M; } - private IEnumerable<(int, float)> SelectBestForConnecting(int layer, Func distance, IEnumerable<(int, float)> candidates) + private IEnumerable<(int, float)> SelectBestForConnecting(int layer, MinHeap candidates) { - if (_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) - return _layers[layer].SELECT_NEIGHBORS_SIMPLE(candidates, GetM(layer)); - return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer)); + int count = GetM(layer); + while (count >= 0 && candidates.Count > 0) + yield return candidates.Pop(); } /// @@ -252,7 +251,7 @@ namespace ZeroLevel.HNSW for (int layer = L; layer > 0; --layer) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, W, 1)) + foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, 1)) { W.Push(i); } @@ -264,7 +263,7 @@ namespace ZeroLevel.HNSW W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - foreach (var i in _layers[0].KNearestAtLayer(ep, distance, W, k)) + foreach (var i in _layers[0].KNearestAtLayer(ep, distance, k)) { W.Push(i); } @@ -276,7 +275,7 @@ namespace ZeroLevel.HNSW _lockGraph.ExitReadLock(); } } - + /* private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context) { _lockGraph.EnterReadLock(); @@ -300,7 +299,7 @@ namespace ZeroLevel.HNSW for (int layer = L; layer > 0; --layer) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, W, 1)) + foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, 1)) { W.Push(i); } @@ -312,7 +311,7 @@ namespace ZeroLevel.HNSW W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - foreach (var i in _layers[0].KNearestAtLayer(ep, distance, W, k, context)) + foreach (var i in _layers[0].KNearestAtLayer(ep, distance, k, context)) { W.Push(i); } @@ -324,7 +323,9 @@ namespace ZeroLevel.HNSW _lockGraph.ExitReadLock(); } } + */ + /* private IEnumerable<(int, float)> KNearest(int k, SearchContext context) { _lockGraph.EnterReadLock(); @@ -349,6 +350,7 @@ namespace ZeroLevel.HNSW _lockGraph.ExitReadLock(); } } + */ #endregion public void Serialize(Stream stream) @@ -378,13 +380,13 @@ namespace ZeroLevel.HNSW _layers = new Layer[countLayers]; for (int i = 0; i < countLayers; i++) { - _layers[i] = new Layer(_options, _vectors); + _layers[i] = new Layer(_options, _vectors, i == 0); _layers[i].Deserialize(reader); } } } - public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT) - => _layers[0].GetHistogram(mode); + /*public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT) + => _layers[0].GetHistogram(mode);*/ } } diff --git a/ZeroLevel.HNSW/SmallWorldFactory.cs b/ZeroLevel.HNSW/SmallWorldFactory.cs index cbbdd1a..204bb68 100644 --- a/ZeroLevel.HNSW/SmallWorldFactory.cs +++ b/ZeroLevel.HNSW/SmallWorldFactory.cs @@ -6,6 +6,5 @@ namespace ZeroLevel.HNSW { public static SmallWorld CreateWorld(NSWOptions options) => new SmallWorld(options); public static SmallWorld CreateWorldFrom(NSWOptions options, Stream stream) => new SmallWorld(options, stream); - public static ReadOnlySmallWorld CreateReadOnlyWorldFrom(NSWReadOnlyOption options, Stream stream) => new ReadOnlySmallWorld(options, stream); } } diff --git a/ZeroLevel.HNSW/Utils/CosineDistance.cs b/ZeroLevel.HNSW/Utils/CosineDistance.cs index 5531294..11b032f 100644 --- a/ZeroLevel.HNSW/Utils/CosineDistance.cs +++ b/ZeroLevel.HNSW/Utils/CosineDistance.cs @@ -47,6 +47,121 @@ namespace ZeroLevel.HNSW return 1 - similarity; } + public static float NonOptimized(byte[] u, byte[] v) + { + if (u.Length != v.Length) + { + throw new ArgumentException("Vectors have non-matching dimensions"); + } + + float dot = 0.0f; + float nru = 0.0f; + float nrv = 0.0f; + for (int i = 0; i < u.Length; ++i) + { + dot += (float)(u[i] * v[i]); + nru += (float)(u[i] * u[i]); + nrv += (float)(v[i] * v[i]); + } + + var similarity = dot / (float)(Math.Sqrt(nru) * Math.Sqrt(nrv)); + return 1 - similarity; + } + + public static float NonOptimized(int[] u, int[] v) + { + if (u.Length != v.Length) + { + throw new ArgumentException("Vectors have non-matching dimensions"); + } + + float dot = 0.0f; + float nru = 0.0f; + float nrv = 0.0f; + byte[] bu; + byte[] bv; + + for (int i = 0; i < u.Length; ++i) + { + bu = BitConverter.GetBytes(u[i]); + bv = BitConverter.GetBytes(v[i]); + + dot += (float)(bu[0] * bv[0]); + nru += (float)(bu[0] * bu[0]); + nrv += (float)(bv[0] * bv[0]); + + dot += (float)(bu[1] * bv[1]); + nru += (float)(bu[1] * bu[1]); + nrv += (float)(bv[1] * bv[1]); + + dot += (float)(bu[2] * bv[2]); + nru += (float)(bu[2] * bu[2]); + nrv += (float)(bv[2] * bv[2]); + + dot += (float)(bu[3] * bv[3]); + nru += (float)(bu[3] * bu[3]); + nrv += (float)(bv[3] * bv[3]); + } + + var similarity = dot / (float)(Math.Sqrt(nru) * Math.Sqrt(nrv)); + return 1 - similarity; + } + + public static float NonOptimized(long[] u, long[] v) + { + if (u.Length != v.Length) + { + throw new ArgumentException("Vectors have non-matching dimensions"); + } + + float dot = 0.0f; + float nru = 0.0f; + float nrv = 0.0f; + byte[] bu; + byte[] bv; + + for (int i = 0; i < u.Length; ++i) + { + bu = BitConverter.GetBytes(u[i]); + bv = BitConverter.GetBytes(v[i]); + + dot += (float)(bu[0] * bv[0]); + nru += (float)(bu[0] * bu[0]); + nrv += (float)(bv[0] * bv[0]); + + dot += (float)(bu[1] * bv[1]); + nru += (float)(bu[1] * bu[1]); + nrv += (float)(bv[1] * bv[1]); + + dot += (float)(bu[2] * bv[2]); + nru += (float)(bu[2] * bu[2]); + nrv += (float)(bv[2] * bv[2]); + + dot += (float)(bu[3] * bv[3]); + nru += (float)(bu[3] * bu[3]); + nrv += (float)(bv[3] * bv[3]); + + dot += (float)(bu[4] * bv[4]); + nru += (float)(bu[4] * bu[4]); + nrv += (float)(bv[4] * bv[4]); + + dot += (float)(bu[5] * bv[5]); + nru += (float)(bu[5] * bu[5]); + nrv += (float)(bv[5] * bv[5]); + + dot += (float)(bu[6] * bv[6]); + nru += (float)(bu[6] * bu[6]); + nrv += (float)(bv[6] * bv[6]); + + dot += (float)(bu[7] * bv[7]); + nru += (float)(bu[7] * bu[7]); + nrv += (float)(bv[7] * bv[7]); + } + + var similarity = dot / (float)(Math.Sqrt(nru) * Math.Sqrt(nrv)); + return 1 - similarity; + } + /// /// Calculates cosine distance with assumption that u and v are unit vectors. /// diff --git a/ZeroLevel.sln b/ZeroLevel.sln index 8c9eed5..c49ad82 100644 --- a/ZeroLevel.sln +++ b/ZeroLevel.sln @@ -63,10 +63,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ZeroLevel.HNSW", "ZeroLevel EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSWDemo", "TestHNSW\HNSWDemo\HNSWDemo.csproj", "{E0E9EC21-B958-4018-AE30-67DB88EFCB90}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "temp", "temp\temp.csproj", "{DFE59EBC-B6BC-450C-9D81-394CCAE30498}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "temp2", "temp2\temp2.csproj", "{DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}" -EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -317,30 +313,6 @@ Global {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x64.Build.0 = Release|Any CPU {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x86.ActiveCfg = Release|Any CPU {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x86.Build.0 = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|Any CPU.Build.0 = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x64.ActiveCfg = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x64.Build.0 = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x86.ActiveCfg = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x86.Build.0 = Debug|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|Any CPU.ActiveCfg = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|Any CPU.Build.0 = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x64.ActiveCfg = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x64.Build.0 = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x86.ActiveCfg = Release|Any CPU - {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x86.Build.0 = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|Any CPU.Build.0 = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x64.ActiveCfg = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x64.Build.0 = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x86.ActiveCfg = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x86.Build.0 = Debug|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|Any CPU.ActiveCfg = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|Any CPU.Build.0 = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x64.ActiveCfg = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x64.Build.0 = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x86.ActiveCfg = Release|Any CPU - {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/ZeroLevel/Services/Network/Utils/RequestBuffer.cs b/ZeroLevel/Services/Network/Utils/RequestBuffer.cs index 2f5d2e7..49bc773 100644 --- a/ZeroLevel/Services/Network/Utils/RequestBuffer.cs +++ b/ZeroLevel/Services/Network/Utils/RequestBuffer.cs @@ -8,11 +8,11 @@ namespace ZeroLevel.Network internal sealed class RequestBuffer { private ConcurrentDictionary _requests = new ConcurrentDictionary(); - private static ObjectPool _ri_pool = new ObjectPool(() => new RequestInfo()); + private static Pool _ri_pool = new Pool(128, (p) => new RequestInfo()); public void RegisterForFrame(int identity, Action callback, Action fail = null) { - var ri = _ri_pool.Allocate(); + var ri = _ri_pool.Acquire(); ri.Reset(callback, fail); _requests[identity] = ri; } @@ -23,7 +23,7 @@ namespace ZeroLevel.Network if (_requests.TryRemove(frameId, out ri)) { ri.Fail(message); - _ri_pool.Free(ri); + _ri_pool.Release(ri); } } @@ -33,7 +33,7 @@ namespace ZeroLevel.Network if (_requests.TryRemove(frameId, out ri)) { ri.Success(data); - _ri_pool.Free(ri); + _ri_pool.Release(ri); } } @@ -53,7 +53,7 @@ namespace ZeroLevel.Network { if (_requests.TryRemove(frameIds[i], out ri)) { - _ri_pool.Free(ri); + _ri_pool.Release(ri); } } }