From c37207ad6a2c3176f4aaf967ba3eb98d7602e02c Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 17 Dec 2021 12:52:53 +0300 Subject: [PATCH] HNSW Fix optimized world and level --- TestHNSW/HNSWDemo/Program.cs | 142 +++++++++++++++------ ZeroLevel.HNSW/Services/BinaryHeap.cs | 162 ------------------------ ZeroLevel.HNSW/Services/MaxHeap.cs | 130 +++++++++++++++++++ ZeroLevel.HNSW/Services/MinHeap.cs | 130 +++++++++++++++++++ ZeroLevel.HNSW/Services/OPT/OptLayer.cs | 120 ++++++++++-------- ZeroLevel.HNSW/Services/OPT/OptWorld.cs | 95 ++++++++++---- ZeroLevel.sln | 30 ++++- temp/Program.cs | 45 +++++++ temp/temp.csproj | 12 ++ temp2/Program.cs | 47 +++++++ temp2/temp2.csproj | 12 ++ 11 files changed, 644 insertions(+), 281 deletions(-) delete mode 100644 ZeroLevel.HNSW/Services/BinaryHeap.cs create mode 100644 ZeroLevel.HNSW/Services/MaxHeap.cs create mode 100644 ZeroLevel.HNSW/Services/MinHeap.cs create mode 100644 temp/Program.cs create mode 100644 temp/temp.csproj create mode 100644 temp2/Program.cs create mode 100644 temp2/temp2.csproj diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 3c112b5..86c033d 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -7,6 +7,7 @@ using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; using ZeroLevel.HNSW.Services.OPT; +using ZeroLevel.Services.Serialization; namespace HNSWDemo { @@ -166,50 +167,11 @@ namespace HNSWDemo static void Main(string[] args) { - var samples = RandomVectors(128, 600); - var opt_world = new OptWorld(NSWOptions.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); - opt_world.AddItems(samples); - //AccuracityTest(); + OptAccuracityTest(); Console.WriteLine("Completed"); Console.ReadKey(); } - static void BinaryHeapTest() - { - var heap = new BinaryHeap(); - heap.Push(1, .03f); - heap.Push(2, .05f); - heap.Push(3, .01f); - heap.Push(4, 1.03f); - heap.Push(5, 2.03f); - heap.Push(6, .73f); - - var n = heap.Nearest; - Console.WriteLine($"Nearest: [{n.Item1}] = {n.Item2}"); - var f = heap.Farthest; - Console.WriteLine($"Farthest: [{f.Item1}] = {f.Item2}"); - - Console.WriteLine("From nearest to farthest"); - while (heap.Count > 0) - { - var i = heap.PopNearest(); - Console.WriteLine($"[{i.Item1}] = {i.Item2}"); - } - - heap.Push(1, .03f); - heap.Push(2, .05f); - heap.Push(3, .01f); - heap.Push(4, 1.03f); - heap.Push(5, 2.03f); - heap.Push(6, .73f); - Console.WriteLine("From farthest to nearest"); - while (heap.Count > 0) - { - var i = heap.PopFarthest(); - Console.WriteLine($"[{i.Item1}] = {i.Item2}"); - } - } - static void TestOnMnist() { int imageCount, rowCount, colCount; @@ -732,5 +694,105 @@ namespace HNSWDemo Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); } + + static void OptAccuracityTest() + { + int K = 200; + var count = 5000; + var testCount = 1000; + var dimensionality = 128; + var timewatchesNP = new List(); + var totalOptHits = new List(); + var timewatchesOptHNSW = new List(); + + var totalRestoredHits = new List(); + var timewatchesRestoredHNSW = new List(); + + var samples = RandomVectors(dimensionality, count); + + var sw = new Stopwatch(); + + var test = new VectorsDirectCompare(samples, Metrics.L2Euclidean); + + var opt_world = new OptWorld(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + + sw.Restart(); + var ids = opt_world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items in OPT: {sw.ElapsedMilliseconds} ms"); + + byte[] dump; + using (var ms = new MemoryStream()) + { + opt_world.Serialize(ms); + dump = ms.ToArray(); + } + + SmallWorld compactWorld; + using (var ms = new MemoryStream(dump)) + { + compactWorld = SmallWorld.CreateWorldFrom(NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + + + Console.WriteLine("Start test"); + + + var test_vectors = RandomVectors(dimensionality, testCount); + foreach (var v in test_vectors) + { + sw.Restart(); + var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); + sw.Stop(); + timewatchesNP.Add(sw.ElapsedMilliseconds); + + sw.Restart(); + var result = opt_world.Search(v, K).ToArray(); + sw.Stop(); + timewatchesOptHNSW.Add(sw.ElapsedMilliseconds); + var hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalOptHits.Add(hits); + + sw.Restart(); + result = compactWorld.Search(v, K).ToArray(); + sw.Stop(); + timewatchesRestoredHNSW.Add(sw.ElapsedMilliseconds); + hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalRestoredHits.Add(hits); + } + Console.WriteLine($"MIN Opt Accuracity: {totalOptHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Opt Accuracity: {totalOptHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Opt Accuracity: {totalOptHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN Test Accuracity: {totalRestoredHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Test Accuracity: {totalRestoredHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Test Accuracity: {totalRestoredHits.Max() * 100 / K}%"); + + Console.WriteLine($"MIN Opt HNSW TIME: {timewatchesOptHNSW.Min()} ms"); + Console.WriteLine($"AVG Opt HNSW TIME: {timewatchesOptHNSW.Average()} ms"); + Console.WriteLine($"MAX Opt HNSW TIME: {timewatchesOptHNSW.Max()} ms"); + + Console.WriteLine($"MIN Test HNSW TIME: {timewatchesRestoredHNSW.Min()} ms"); + Console.WriteLine($"AVG Test HNSW TIME: {timewatchesRestoredHNSW.Average()} ms"); + Console.WriteLine($"MAX Test HNSW TIME: {timewatchesRestoredHNSW.Max()} ms"); + + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); + Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); + Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); + } } } diff --git a/ZeroLevel.HNSW/Services/BinaryHeap.cs b/ZeroLevel.HNSW/Services/BinaryHeap.cs deleted file mode 100644 index c03318a..0000000 --- a/ZeroLevel.HNSW/Services/BinaryHeap.cs +++ /dev/null @@ -1,162 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; - -namespace ZeroLevel.HNSW -{ - /// - /// Binary heap wrapper around the - /// It's a max-heap implementation i.e. the maximum element is always on top. - /// - /// The type of the items in the source list. - public class BinaryHeap : - IEnumerable<(int, float)> - { - private static BinaryHeap _empty = new BinaryHeap(); - - public static BinaryHeap Empty => _empty; - - private readonly List<(int, float)> _data; - - private bool _frozen = false; - public (int, float) Nearest => _data[_data.Count - 1]; - public (int, float) Farthest => _data[0]; - - public (int, float) PopNearest() - { - if (this._data.Any()) - { - var result = this._data[this._data.Count - 1]; - this._data.RemoveAt(this._data.Count - 1); - return result; - } - return (-1, -1); - } - - public (int, float) PopFarthest() - { - if (this._data.Any()) - { - var result = this._data.First(); - this._data[0] = this._data.Last(); - this._data.RemoveAt(this._data.Count - 1); - this.SiftDown(0); - return result; - } - return (-1, -1); - } - - public int Count => _data.Count; - public void Clear() => _data.Clear(); - - /// - /// Initializes a new instance of the class. - /// - /// The buffer to store heap items. - public BinaryHeap(int k = -1, bool frozen = false) - { - _frozen = frozen; - if (k > 0) - _data = new List<(int, float)>(k); - else - _data = new List<(int, float)>(); - } - - /// - /// Pushes item to the heap. - /// - /// The item to push. - public void Push(int item, float distance) - { - this._data.Add((item, distance)); - this.SiftUp(this._data.Count - 1); - } - - /// - /// Pops the item from the heap. - /// - /// The popped item. - public (int, float) Pop() - { - if (this._data.Any()) - { - var result = this._data.First(); - - this._data[0] = this._data.Last(); - this._data.RemoveAt(this._data.Count - 1); - this.SiftDown(0); - - return result; - } - - throw new InvalidOperationException("Heap is empty"); - } - - /// - /// Restores the heap property starting from i'th position down to the bottom - /// given that the downstream items fulfill the rule. - /// - /// The position of item where heap property is violated. - private void SiftDown(int i) - { - while (i < this._data.Count) - { - int l = (2 * i) + 1; - int r = l + 1; - if (l >= this._data.Count) - { - break; - } - int m = ((r < this._data.Count) && this._data[l].Item2 < this._data[r].Item2) ? r : l; - if (this._data[m].Item2 <= this._data[i].Item2) - { - break; - } - this.Swap(i, m); - i = m; - } - } - - /// - /// Restores the heap property starting from i'th position up to the head - /// given that the upstream items fulfill the rule. - /// - /// The position of item where heap property is violated. - private void SiftUp(int i) - { - while (i > 0) - { - int p = (i - 1) / 2; - if (this._data[i].Item2 <= this._data[p].Item2) - { - break; - } - this.Swap(i, p); - i = p; - } - } - - /// - /// Swaps items with the specified indicies. - /// - /// The first index. - /// The second index. - private void Swap(int i, int j) - { - var temp = this._data[i]; - this._data[i] = this._data[j]; - this._data[j] = temp; - } - - public IEnumerator<(int, float)> GetEnumerator() - { - return _data.GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return _data.GetEnumerator(); - } - } -} diff --git a/ZeroLevel.HNSW/Services/MaxHeap.cs b/ZeroLevel.HNSW/Services/MaxHeap.cs new file mode 100644 index 0000000..cc66b37 --- /dev/null +++ b/ZeroLevel.HNSW/Services/MaxHeap.cs @@ -0,0 +1,130 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace ZeroLevel.HNSW.Services +{ + /// + /// Max element always on top + /// + public class MaxHeap : + IEnumerable<(int, float)> + { + private readonly List<(int, float)> _elements; + + public MaxHeap(int size = -1) + { + if (size > 0) + _elements = new List<(int, float)>(size); + else + _elements = new List<(int, float)>(); + } + + private int GetLeftChildIndex(int elementIndex) => 2 * elementIndex + 1; + private int GetRightChildIndex(int elementIndex) => 2 * elementIndex + 2; + private int GetParentIndex(int elementIndex) => (elementIndex - 1) / 2; + + private bool HasLeftChild(int elementIndex) => GetLeftChildIndex(elementIndex) < _elements.Count; + private bool HasRightChild(int elementIndex) => GetRightChildIndex(elementIndex) < _elements.Count; + private bool IsRoot(int elementIndex) => elementIndex == 0; + + private (int, float) GetLeftChild(int elementIndex) => _elements[GetLeftChildIndex(elementIndex)]; + private (int, float) GetRightChild(int elementIndex) => _elements[GetRightChildIndex(elementIndex)]; + private (int, float) GetParent(int elementIndex) => _elements[GetParentIndex(elementIndex)]; + + public int Count => _elements.Count; + + public void Clear() + { + _elements.Clear(); + } + + private void Swap(int firstIndex, int secondIndex) + { + var temp = _elements[firstIndex]; + _elements[firstIndex] = _elements[secondIndex]; + _elements[secondIndex] = temp; + } + + public bool IsEmpty() + { + return _elements.Count == 0; + } + + public bool TryPeek(out int id, out float value) + { + if (_elements.Count == 0) + { + id = -1; + value = 0; + return false; + } + id = _elements[0].Item1; + value = _elements[0].Item2; + return true; + } + + public (int, float) Pop() + { + if (_elements.Count == 0) + throw new IndexOutOfRangeException(); + + var result = _elements[0]; + _elements[0] = _elements[_elements.Count - 1]; + _elements.RemoveAt(_elements.Count - 1); + + ReCalculateDown(); + + return result; + } + + public void Push((int, float) element) + { + _elements.Add(element); + + ReCalculateUp(); + } + + private void ReCalculateDown() + { + int index = 0; + while (HasLeftChild(index)) + { + var biggerIndex = GetLeftChildIndex(index); + if (HasRightChild(index) && GetRightChild(index).Item2 > GetLeftChild(index).Item2) + { + biggerIndex = GetRightChildIndex(index); + } + + if (_elements[biggerIndex].Item2 < _elements[index].Item2) + { + break; + } + + Swap(biggerIndex, index); + index = biggerIndex; + } + } + + private void ReCalculateUp() + { + var index = _elements.Count - 1; + while (!IsRoot(index) && _elements[index].Item2 > GetParent(index).Item2) + { + var parentIndex = GetParentIndex(index); + Swap(parentIndex, index); + index = parentIndex; + } + } + + public IEnumerator<(int, float)> GetEnumerator() + { + return _elements.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _elements.GetEnumerator(); + } + } +} diff --git a/ZeroLevel.HNSW/Services/MinHeap.cs b/ZeroLevel.HNSW/Services/MinHeap.cs new file mode 100644 index 0000000..c860da2 --- /dev/null +++ b/ZeroLevel.HNSW/Services/MinHeap.cs @@ -0,0 +1,130 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace ZeroLevel.HNSW.Services +{ + /// + /// Min element always on top + /// + public class MinHeap : + IEnumerable<(int, float)> + { + private readonly List<(int, float)> _elements; + + public MinHeap(int size = -1) + { + if (size > 0) + _elements = new List<(int, float)>(size); + else + _elements = new List<(int, float)>(); + } + + private int GetLeftChildIndex(int elementIndex) => 2 * elementIndex + 1; + private int GetRightChildIndex(int elementIndex) => 2 * elementIndex + 2; + private int GetParentIndex(int elementIndex) => (elementIndex - 1) / 2; + + private bool HasLeftChild(int elementIndex) => GetLeftChildIndex(elementIndex) < _elements.Count; + private bool HasRightChild(int elementIndex) => GetRightChildIndex(elementIndex) < _elements.Count; + private bool IsRoot(int elementIndex) => elementIndex == 0; + + private (int, float) GetLeftChild(int elementIndex) => _elements[GetLeftChildIndex(elementIndex)]; + private (int, float) GetRightChild(int elementIndex) => _elements[GetRightChildIndex(elementIndex)]; + private (int, float) GetParent(int elementIndex) => _elements[GetParentIndex(elementIndex)]; + + public int Count => _elements.Count; + + public void Clear() + { + _elements.Clear(); + } + + private void Swap(int firstIndex, int secondIndex) + { + var temp = _elements[firstIndex]; + _elements[firstIndex] = _elements[secondIndex]; + _elements[secondIndex] = temp; + } + + public bool IsEmpty() + { + return _elements.Count == 0; + } + + public bool TryPeek(out int id, out float value) + { + if (_elements.Count == 0) + { + id = -1; + value = 0; + return false; + } + id = _elements[0].Item1; + value = _elements[0].Item2; + return true; + } + + public (int, float) Pop() + { + if (_elements.Count == 0) + throw new IndexOutOfRangeException(); + + var result = _elements[0]; + _elements[0] = _elements[_elements.Count - 1]; + _elements.RemoveAt(_elements.Count - 1); + + ReCalculateDown(); + + return result; + } + + public void Push((int, float) element) + { + _elements.Add(element); + + ReCalculateUp(); + } + + private void ReCalculateDown() + { + int index = 0; + while (HasLeftChild(index)) + { + var smallerIndex = GetLeftChildIndex(index); + if (HasRightChild(index) && GetRightChild(index).Item2 < GetLeftChild(index).Item2) + { + smallerIndex = GetRightChildIndex(index); + } + + if (_elements[smallerIndex].Item2 >= _elements[index].Item2) + { + break; + } + + Swap(smallerIndex, index); + index = smallerIndex; + } + } + + private void ReCalculateUp() + { + var index = _elements.Count - 1; + while (!IsRoot(index) && _elements[index].Item2 < GetParent(index).Item2) + { + var parentIndex = GetParentIndex(index); + Swap(parentIndex, index); + index = parentIndex; + } + } + + public IEnumerator<(int, float)> GetEnumerator() + { + return _elements.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _elements.GetEnumerator(); + } + } +} \ No newline at end of file diff --git a/ZeroLevel.HNSW/Services/OPT/OptLayer.cs b/ZeroLevel.HNSW/Services/OPT/OptLayer.cs index e8af302..e47dad4 100644 --- a/ZeroLevel.HNSW/Services/OPT/OptLayer.cs +++ b/ZeroLevel.HNSW/Services/OPT/OptLayer.cs @@ -103,7 +103,7 @@ namespace ZeroLevel.HNSW.Services.OPT /// query element /// enter points ep /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, BinaryHeap W, int ef) + internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, IEnumerable<(int, float)> w, int ef) { /* * v ← ep // set of visited elements @@ -128,21 +128,25 @@ namespace ZeroLevel.HNSW.Services.OPT var v = new VisitedBitSet(_vectors.Count, _options.M); // v ← ep // set of visited elements v.Add(entryPointId); + var W = new MaxHeap(ef + 1); + foreach (var i in w) W.Push(i); var d = targetCosts(entryPointId); // C ← ep // set of candidates - var C = new BinaryHeap(); - C.Push(entryPointId, d); + var C = new MinHeap(ef); + C.Push((entryPointId, d)); // W ← ep // dynamic list of found nearest neighbors - W.Push(entryPointId, d); + W.Push((entryPointId, d)); + + int farthestId; + float farthestDistance; // run bfs while (C.Count > 0) { // get next candidate to check and expand - var toExpand = C.PopNearest(); - var farthestResult = W.Farthest; - if (toExpand.Item2 > farthestResult.Item2) + var toExpand = C.Pop(); + if (W.TryPeek(out _, out farthestDistance) && toExpand.Item2 > farthestDistance) { // the closest candidate is farther than farthest result break; @@ -156,16 +160,17 @@ namespace ZeroLevel.HNSW.Services.OPT if (!v.Contains(neighbourId)) { // enqueue perspective neighbours to expansion list - farthestResult = W.Farthest; + W.TryPeek(out farthestId, out farthestDistance); var neighbourDistance = targetCosts(neighbourId); - if (W.Count < ef || neighbourDistance < farthestResult.Item2) + if (W.Count < ef || (farthestId >= 0 && neighbourDistance < farthestDistance)) { - C.Push(neighbourId, neighbourDistance); - W.Push(neighbourId, neighbourDistance); + C.Push((neighbourId, neighbourDistance)); + + W.Push((neighbourId, neighbourDistance)); if (W.Count > ef) { - W.PopFarthest(); + W.Pop(); } } v.Add(neighbourId); @@ -174,6 +179,7 @@ namespace ZeroLevel.HNSW.Services.OPT } C.Clear(); v.Clear(); + return W; } /// @@ -182,7 +188,7 @@ namespace ZeroLevel.HNSW.Services.OPT /// query element /// enter points ep /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, BinaryHeap W, int ef, SearchContext context) + internal IEnumerable<(int, float)> KNearestAtLayer(int entryPointId, Func targetCosts, IEnumerable<(int, float)> w, int ef, SearchContext context) { /* * v ← ep // set of visited elements @@ -207,23 +213,28 @@ namespace ZeroLevel.HNSW.Services.OPT var v = new VisitedBitSet(_vectors.Count, _options.M); // v ← ep // set of visited elements v.Add(entryPointId); + + var W = new MaxHeap(ef + 1); + foreach (var i in w) W.Push(i); + // C ← ep // set of candidates - var C = new BinaryHeap(); + var C = new MinHeap(ef); var d = targetCosts(entryPointId); - C.Push(entryPointId, d); + C.Push((entryPointId, d)); // W ← ep // dynamic list of found nearest neighbors if (context.IsActiveNode(entryPointId)) { - W.Push(entryPointId, d); + W.Push((entryPointId, d)); } // run bfs while (C.Count > 0) { // get next candidate to check and expand - var toExpand = C.PopNearest(); + var toExpand = C.Pop(); if (W.Count > 0) { - if (toExpand.Item2 > W.Farthest.Item2) + if(W.TryPeek(out _, out var dist )) + if (toExpand.Item2 > dist) { // the closest candidate is farther than farthest result break; @@ -241,18 +252,18 @@ namespace ZeroLevel.HNSW.Services.OPT var neighbourDistance = targetCosts(neighbourId); if (context.IsActiveNode(neighbourId)) { - if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2)) + if (W.Count < ef || (W.Count > 0 && (W.TryPeek(out _, out var dist) && neighbourDistance < dist))) { - W.Push(neighbourId, neighbourDistance); + W.Push((neighbourId, neighbourDistance)); if (W.Count > ef) { - W.PopFarthest(); + W.Pop(); } } } if (W.Count < ef) { - C.Push(neighbourId, neighbourDistance); + C.Push((neighbourId, neighbourDistance)); } v.Add(neighbourId); } @@ -260,6 +271,7 @@ namespace ZeroLevel.HNSW.Services.OPT } C.Clear(); v.Clear(); + return W; } /// @@ -268,7 +280,7 @@ namespace ZeroLevel.HNSW.Services.OPT /// query element /// enter points ep /// Output: ef closest neighbors to q - internal void KNearestAtLayer(BinaryHeap W, int ef, SearchContext context) + internal IEnumerable<(int, float)> KNearestAtLayer(IEnumerable<(int, float)> w, int ef, SearchContext context) { /* * v ← ep // set of visited elements @@ -293,25 +305,28 @@ namespace ZeroLevel.HNSW.Services.OPT // v ← ep // set of visited elements var v = new VisitedBitSet(_vectors.Count, _options.M); // C ← ep // set of candidates - var C = new BinaryHeap(); + var C = new MinHeap(ef); foreach (var ep in context.EntryPoints) { var neighboursIds = GetNeighbors(ep).ToArray(); for (int i = 0; i < neighboursIds.Length; ++i) { - C.Push(ep, _links.Distance(ep, neighboursIds[i])); + C.Push((ep, _links.Distance(ep, neighboursIds[i]))); } v.Add(ep); } // W ← ep // dynamic list of found nearest neighbors + var W = new MaxHeap(ef + 1); + foreach (var i in w) W.Push(i); + // run bfs while (C.Count > 0) { // get next candidate to check and expand - var toExpand = C.PopNearest(); + var toExpand = C.Pop(); if (W.Count > 0) { - if (toExpand.Item2 > W.Farthest.Item2) + if (W.TryPeek(out _, out var dist) && toExpand.Item2 > dist) { // the closest candidate is farther than farthest result break; @@ -319,12 +334,12 @@ namespace ZeroLevel.HNSW.Services.OPT } if (context.IsActiveNode(toExpand.Item1)) { - if (W.Count < ef || W.Count == 0 || (W.Count > 0 && toExpand.Item2 < W.Farthest.Item2)) + if (W.Count < ef || W.Count == 0 || (W.Count > 0 && (W.TryPeek(out _, out var dist) && toExpand.Item2 < dist))) { - W.Push(toExpand.Item1, toExpand.Item2); + W.Push((toExpand.Item1, toExpand.Item2)); if (W.Count > ef) { - W.PopFarthest(); + W.Pop(); } } } @@ -333,21 +348,21 @@ namespace ZeroLevel.HNSW.Services.OPT { while (W.Count > ef) { - W.PopFarthest(); + W.Pop(); } - return; + return W; } else { foreach (var c in W) { - C.Push(c.Item1, c.Item2); + C.Push((c.Item1, c.Item2)); } } while (C.Count > 0) { // get next candidate to check and expand - var toExpand = C.PopNearest(); + var toExpand = C.Pop(); // expand candidate var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); for (int i = 0; i < neighboursIds.Length; ++i) @@ -359,18 +374,18 @@ namespace ZeroLevel.HNSW.Services.OPT var neighbourDistance = _links.Distance(toExpand.Item1, neighbourId); if (context.IsActiveNode(neighbourId)) { - if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2)) + if (W.Count < ef || (W.Count > 0 && (W.TryPeek(out _, out var dist) && neighbourDistance < dist))) { - W.Push(neighbourId, neighbourDistance); + W.Push((neighbourId, neighbourDistance)); if (W.Count > ef) { - W.PopFarthest(); + W.Pop(); } } } if (W.Count < ef) { - C.Push(neighbourId, neighbourDistance); + C.Push((neighbourId, neighbourDistance)); } v.Add(neighbourId); } @@ -378,19 +393,22 @@ namespace ZeroLevel.HNSW.Services.OPT } C.Clear(); v.Clear(); + return W; } /// /// Algorithm 3 /// - internal BinaryHeap SELECT_NEIGHBORS_SIMPLE(BinaryHeap W, int M) + internal MaxHeap SELECT_NEIGHBORS_SIMPLE(IEnumerable<(int, float)> w, int M) { + var W = new MaxHeap(w.Count()); + foreach (var i in w) W.Push(i); var bestN = M; if (W.Count > bestN) { while (W.Count > bestN) { - W.PopFarthest(); + W.Pop(); } } return W; @@ -406,11 +424,13 @@ namespace ZeroLevel.HNSW.Services.OPT /// flag indicating whether or not to extend candidate list /// flag indicating whether or not to add discarded elements /// Output: M elements selected by the heuristic - internal BinaryHeap SELECT_NEIGHBORS_HEURISTIC(Func distance, BinaryHeap W, int M) + internal MaxHeap SELECT_NEIGHBORS_HEURISTIC(Func distance, IEnumerable<(int, float)> w, int M) { // R ← ∅ - var R = new BinaryHeap(); + var R = new MaxHeap(_options.EFConstruction); // W ← C // working queue for the candidates + var W = new MaxHeap(_options.EFConstruction + 1); + foreach (var i in w) W.Push(i); // if extendCandidates // extend candidates by their neighbors if (_options.ExpandBestSelection) { @@ -432,30 +452,30 @@ namespace ZeroLevel.HNSW.Services.OPT // W ← W ⋃ eadj foreach (var id in extendBuffer) { - W.Push(id, distance(id)); + W.Push((id, distance(id))); } } // Wd ← ∅ // queue for the discarded candidates - var Wd = new BinaryHeap(); + var Wd = new MinHeap(_options.EFConstruction); // while │W│ > 0 and │R│< M while (W.Count > 0 && R.Count < M) { // e ← extract nearest element from W to q - var (e, ed) = W.PopNearest(); - var (fe, fd) = R.PopFarthest(); + var (e, ed) = W.Pop(); + var (fe, fd) = R.Pop(); // if e is closer to q compared to any element from R if (R.Count == 0 || ed < fd) { // R ← R ⋃ e - R.Push(e, ed); + R.Push((e, ed)); } else { // Wd ← Wd ⋃ e - Wd.Push(e, ed); + Wd.Push((e, ed)); } } // if keepPrunedConnections // add some of the discarded // connections from Wd @@ -465,8 +485,8 @@ namespace ZeroLevel.HNSW.Services.OPT while (Wd.Count > 0 && R.Count < M) { // R ← R ⋃ extract nearest element from Wd to q - var nearest = Wd.PopNearest(); - R.Push(nearest.Item1, nearest.Item2); + var nearest = Wd.Pop(); + R.Push((nearest.Item1, nearest.Item2)); } } // return R diff --git a/ZeroLevel.HNSW/Services/OPT/OptWorld.cs b/ZeroLevel.HNSW/Services/OPT/OptWorld.cs index 4ce962c..c2ba6fe 100644 --- a/ZeroLevel.HNSW/Services/OPT/OptWorld.cs +++ b/ZeroLevel.HNSW/Services/OPT/OptWorld.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading; using ZeroLevel.Services.Serialization; @@ -29,7 +30,7 @@ namespace ZeroLevel.HNSW.Services.OPT } } - internal OptWorld(NSWOptions options, Stream stream) + public OptWorld(NSWOptions options, Stream stream) { _options = options; Deserialize(stream); @@ -114,7 +115,7 @@ namespace ZeroLevel.HNSW.Services.OPT { var distance = new Func(candidate => _options.Distance(_vectors[q], _vectors[candidate])); // W ← ∅ // list for the currently found nearest elements - var W = new BinaryHeap(); + var W = new MinHeap(); // ep ← get enter point for hnsw //var ep = _layers[MaxLayer].FingEntryPointAtLayer(distance); //if(ep == -1) ep = EntryPoint; @@ -126,14 +127,22 @@ namespace ZeroLevel.HNSW.Services.OPT int l = _layerLevelGenerator.GetRandomLayer(); // for lc ← L … l+1 // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа + + int id; + float value; for (int lc = L; lc > l; --lc) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[lc].KNearestAtLayer(ep, distance, W, 1); + foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, W, 1)) + { + W.Push(i); + } // ep ← get the nearest element from W to q - var nearest = W.Nearest; - ep = nearest.Item1; - epDist = nearest.Item2; + if (W.TryPeek(out id, out value)) + { + ep = id; + epDist = value; + } W.Clear(); } //for lc ← min(L, l) … 0 @@ -147,12 +156,17 @@ namespace ZeroLevel.HNSW.Services.OPT else { // W ← SEARCH - LAYER(q, ep, efConstruction, lc) - _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction); + foreach (var i in _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction)) + { + W.Push(i); + } // ep ← W - var nearest = W.Nearest; - ep = nearest.Item1; - epDist = nearest.Item2; + if (W.TryPeek(out id, out value)) + { + ep = id; + epDist = value; + } // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 var neighbors = SelectBestForConnecting(lc, distance, W); @@ -201,7 +215,7 @@ namespace ZeroLevel.HNSW.Services.OPT return layer == 0 ? 2 * _options.M : _options.M; } - private BinaryHeap SelectBestForConnecting(int layer, Func distance, BinaryHeap candidates) + private IEnumerable<(int, float)> SelectBestForConnecting(int layer, Func distance, IEnumerable<(int, float)> candidates) { if (_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) return _layers[layer].SELECT_NEIGHBORS_SIMPLE(candidates, GetM(layer)); @@ -211,19 +225,22 @@ namespace ZeroLevel.HNSW.Services.OPT /// /// Algorithm 5 /// - private BinaryHeap KNearest(TItem q, int k) + private IEnumerable<(int, float)> KNearest(TItem q, int k) { _lockGraph.EnterReadLock(); try { if (_vectors.Count == 0) { - return BinaryHeap.Empty; + return Enumerable.Empty<(int, float)>(); } + + int id; + float value; var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); // W ← ∅ // set for the current nearest elements - var W = new BinaryHeap(k + 1); + var W = new MinHeap(k + 1); // ep ← get enter point for hnsw var ep = EntryPoint; // L ← level of ep // top layer for hnsw @@ -232,13 +249,22 @@ namespace ZeroLevel.HNSW.Services.OPT for (int layer = L; layer > 0; --layer) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[layer].KNearestAtLayer(ep, distance, W, 1); + foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, W, 1)) + { + W.Push(i); + } // ep ← get nearest element from W to q - ep = W.Nearest.Item1; + if (W.TryPeek(out id, out value)) + { + ep = id; + } W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k); + foreach (var i in _layers[0].KNearestAtLayer(ep, distance, W, k)) + { + W.Push(i); + } // return K nearest elements from W to q return W; } @@ -247,19 +273,21 @@ namespace ZeroLevel.HNSW.Services.OPT _lockGraph.ExitReadLock(); } } - private BinaryHeap KNearest(TItem q, int k, SearchContext context) + private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context) { _lockGraph.EnterReadLock(); try { if (_vectors.Count == 0) { - return BinaryHeap.Empty; + return Enumerable.Empty<(int, float)>(); } + int id; + float value; var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); // W ← ∅ // set for the current nearest elements - var W = new BinaryHeap(k + 1); + var W = new MinHeap(k + 1); // ep ← get enter point for hnsw var ep = EntryPoint; // L ← level of ep // top layer for hnsw @@ -268,13 +296,22 @@ namespace ZeroLevel.HNSW.Services.OPT for (int layer = L; layer > 0; --layer) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - _layers[layer].KNearestAtLayer(ep, distance, W, 1); + foreach (var i in _layers[layer].KNearestAtLayer(ep, distance, W, 1)) + { + W.Push(i); + } // ep ← get nearest element from W to q - ep = W.Nearest.Item1; + if (W.TryPeek(out id, out value)) + { + ep = id; + } W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k, context); + foreach (var i in _layers[0].KNearestAtLayer(ep, distance, W, k, context)) + { + W.Push(i); + } // return K nearest elements from W to q return W; } @@ -284,20 +321,22 @@ namespace ZeroLevel.HNSW.Services.OPT } } - private BinaryHeap KNearest(int k, SearchContext context) + private IEnumerable<(int, float)> KNearest(int k, SearchContext context) { _lockGraph.EnterReadLock(); try { if (_vectors.Count == 0) { - return BinaryHeap.Empty; + return Enumerable.Empty<(int, float)>(); } - var distance = new Func((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2])); // W ← ∅ // set for the current nearest elements - var W = new BinaryHeap(k + 1); + var W = new MaxHeap(k + 1); // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(W, k, context); + foreach (var i in _layers[0].KNearestAtLayer(W, k, context)) + { + W.Push(i); + } // return K nearest elements from W to q return W; } diff --git a/ZeroLevel.sln b/ZeroLevel.sln index aa957a7..8c9eed5 100644 --- a/ZeroLevel.sln +++ b/ZeroLevel.sln @@ -61,7 +61,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ZeroLevel.Qdrant", "ZeroLev EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ZeroLevel.HNSW", "ZeroLevel.HNSW\ZeroLevel.HNSW.csproj", "{1EAC0A2C-B00F-4353-94D3-3BB4DC5C92AE}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HNSWDemo", "TestHNSW\HNSWDemo\HNSWDemo.csproj", "{E0E9EC21-B958-4018-AE30-67DB88EFCB90}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HNSWDemo", "TestHNSW\HNSWDemo\HNSWDemo.csproj", "{E0E9EC21-B958-4018-AE30-67DB88EFCB90}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "temp", "temp\temp.csproj", "{DFE59EBC-B6BC-450C-9D81-394CCAE30498}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "temp2", "temp2\temp2.csproj", "{DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -313,6 +317,30 @@ Global {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x64.Build.0 = Release|Any CPU {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x86.ActiveCfg = Release|Any CPU {E0E9EC21-B958-4018-AE30-67DB88EFCB90}.Release|x86.Build.0 = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x64.ActiveCfg = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x64.Build.0 = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x86.ActiveCfg = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Debug|x86.Build.0 = Debug|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|Any CPU.Build.0 = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x64.ActiveCfg = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x64.Build.0 = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x86.ActiveCfg = Release|Any CPU + {DFE59EBC-B6BC-450C-9D81-394CCAE30498}.Release|x86.Build.0 = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x64.ActiveCfg = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x64.Build.0 = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x86.ActiveCfg = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Debug|x86.Build.0 = Debug|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|Any CPU.Build.0 = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x64.ActiveCfg = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x64.Build.0 = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x86.ActiveCfg = Release|Any CPU + {DEBF2F14-E7F8-40A4-A4A8-87C9033D52A4}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/temp/Program.cs b/temp/Program.cs new file mode 100644 index 0000000..17a4f0e --- /dev/null +++ b/temp/Program.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.Services.Serialization; + +namespace temp +{ + class Program + { + static void Main(string[] args) + { + SmallWorld world; + using (var ms = new FileStream(@"F:\graph_test.bin", FileMode.Open, FileAccess.Read, FileShare.None)) + { + world = SmallWorld.CreateWorldFrom(NSWOptions.Create(6, 12, 100, 10, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + var test_vectors = new List(); + using (var ms = new FileStream(@"F:\test_vectors.bin", FileMode.Open, FileAccess.Read, FileShare.None)) + { + using (var reader = new MemoryStreamReader(ms)) + { + var count = reader.ReadInt32(); + for (int i = 0; i < count; i++) + { + test_vectors.Add(reader.ReadFloatArray()); + } + } + } + Forward(world, test_vectors); + Console.WriteLine("Completed"); + } + + static void Forward(SmallWorld world, List test_vectors) + { + int K = 10; + foreach (var v in test_vectors) + { + var result = world.Search(v, K); + Console.WriteLine(result.Count()); + } + } + } +} diff --git a/temp/temp.csproj b/temp/temp.csproj new file mode 100644 index 0000000..b849ee3 --- /dev/null +++ b/temp/temp.csproj @@ -0,0 +1,12 @@ + + + + Exe + net5.0 + + + + + + + diff --git a/temp2/Program.cs b/temp2/Program.cs new file mode 100644 index 0000000..1798fcc --- /dev/null +++ b/temp2/Program.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using ZeroLevel.HNSW; +using ZeroLevel.HNSW.Services.OPT; +using ZeroLevel.Services.Serialization; + +namespace temp2 +{ + class Program + { + static void Main(string[] args) + { + OptWorld world; + using (var ms = new FileStream(@"F:\graph_test.bin", FileMode.Open, FileAccess.Read, FileShare.None)) + { + world = new OptWorld(NSWOptions.Create(6, 12, 100, 10, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple), ms); + } + + var test_vectors = new List(); + using (var ms = new FileStream(@"F:\test_vectors.bin", FileMode.Open, FileAccess.Read, FileShare.None)) + { + using (var reader = new MemoryStreamReader(ms)) + { + var count = reader.ReadInt32(); + for(int i=0;i world, List test_vectors) + { + int K = 10; + foreach (var v in test_vectors) + { + var result = world.Search(v, K); + Console.WriteLine(result.Count()); + } + } + } +} diff --git a/temp2/temp2.csproj b/temp2/temp2.csproj new file mode 100644 index 0000000..b849ee3 --- /dev/null +++ b/temp2/temp2.csproj @@ -0,0 +1,12 @@ + + + + Exe + net5.0 + + + + + + +