diff --git a/TestHNSW/HNSWDemo/HNSWDemo.csproj b/TestHNSW/HNSWDemo/HNSWDemo.csproj index b4e60da..7b3611a 100644 --- a/TestHNSW/HNSWDemo/HNSWDemo.csproj +++ b/TestHNSW/HNSWDemo/HNSWDemo.csproj @@ -1,4 +1,4 @@ - + Exe @@ -13,4 +13,10 @@ + + + Always + + + diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index a122c12..3c112b5 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using ZeroLevel.HNSW; using ZeroLevel.HNSW.Services; +using ZeroLevel.HNSW.Services.OPT; namespace HNSWDemo { @@ -13,6 +14,7 @@ namespace HNSWDemo { public class VectorsDirectCompare { + private const int HALF_LONG_BITS = 32; private readonly IList _vectors; private readonly Func _distance; @@ -32,6 +34,70 @@ namespace HNSWDemo } return weights.OrderBy(p => p.Value).Take(k).Select(p => (p.Key, p.Value)); } + + public List> DetectClusters() + { + var links = new SortedList(); + for (int i = 0; i < _vectors.Count; i++) + { + for (int j = i + 1; j < _vectors.Count; j++) + { + long k = (((long)(i)) << HALF_LONG_BITS) + j; + links.Add(k, _distance(_vectors[i], _vectors[j])); + } + } + + // 1. Find R - bound between intra-cluster distances and out-of-cluster distances + var histogram = new Histogram(HistogramMode.SQRT, links.Values); + int threshold = histogram.OTSU(); + var min = histogram.Bounds[threshold - 1]; + var max = histogram.Bounds[threshold]; + var R = (max + min) / 2; + + // 2. Get links with distances less than R + var resultLinks = new SortedList(); + foreach (var pair in links) + { + if (pair.Value < R) + { + resultLinks.Add(pair.Key, pair.Value); + } + } + + // 3. Extract clusters + List> clusters = new List>(); + foreach (var pair in resultLinks) + { + var k = pair.Key; + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + + bool found = false; + foreach (var c in clusters) + { + if (c.Contains(id1)) + { + c.Add(id2); + found = true; + break; + } + else if (c.Contains(id2)) + { + c.Add(id1); + found = true; + break; + } + } + if (found == false) + { + var c = new HashSet(); + c.Add(id1); + c.Add(id2); + clusters.Add(c); + } + } + return clusters; + } } public enum Gender @@ -91,7 +157,7 @@ namespace HNSWDemo { var vector = new float[vectorSize]; DefaultRandomGenerator.Instance.NextFloats(vector); - //VectorUtils.NormalizeSIMD(vector); + VectorUtils.NormalizeSIMD(vector); vectors.Add(vector); } return vectors; @@ -100,11 +166,107 @@ namespace HNSWDemo static void Main(string[] args) { - AutoClusteringTest(); + var samples = RandomVectors(128, 600); + var opt_world = new OptWorld(NSWOptions.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + opt_world.AddItems(samples); + //AccuracityTest(); Console.WriteLine("Completed"); Console.ReadKey(); } + static void BinaryHeapTest() + { + var heap = new BinaryHeap(); + heap.Push(1, .03f); + heap.Push(2, .05f); + heap.Push(3, .01f); + heap.Push(4, 1.03f); + heap.Push(5, 2.03f); + heap.Push(6, .73f); + + var n = heap.Nearest; + Console.WriteLine($"Nearest: [{n.Item1}] = {n.Item2}"); + var f = heap.Farthest; + Console.WriteLine($"Farthest: [{f.Item1}] = {f.Item2}"); + + Console.WriteLine("From nearest to farthest"); + while (heap.Count > 0) + { + var i = heap.PopNearest(); + Console.WriteLine($"[{i.Item1}] = {i.Item2}"); + } + + heap.Push(1, .03f); + heap.Push(2, .05f); + heap.Push(3, .01f); + heap.Push(4, 1.03f); + heap.Push(5, 2.03f); + heap.Push(6, .73f); + Console.WriteLine("From farthest to nearest"); + while (heap.Count > 0) + { + var i = heap.PopFarthest(); + Console.WriteLine($"[{i.Item1}] = {i.Item2}"); + } + } + + static void TestOnMnist() + { + int imageCount, rowCount, colCount; + var buf = new byte[4]; + var image = new byte[28 * 28]; + var vectors = new List(); + using (var fs = new FileStream("t10k-images.idx3-ubyte", FileMode.Open, FileAccess.Read, FileShare.None)) + { + // first 4 bytes is a magic number + fs.Read(buf, 0, 4); + // second 4 bytes is the number of images + fs.Read(buf, 0, 4); + imageCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0); + // third 4 bytes is the row count + fs.Read(buf, 0, 4); + rowCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0); + // fourth 4 bytes is the column count + fs.Read(buf, 0, 4); + colCount = BitConverter.ToInt32(buf.Reverse().ToArray(), 0); + + for (int i = 0; i < imageCount; i++) + { + fs.Read(image, 0, image.Length); + vectors.Add(image.Select(b => (float)b).ToArray()); + } + } + + //var direct = new VectorsDirectCompare(vectors, Metrics.L2Euclidean); + + var options = NSWOptions.Create(8, 16, 200, 200, Metrics.L2Euclidean, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple); + SmallWorld world; + if (File.Exists("graph.bin")) + { + using (var fs = new FileStream("graph.bin", FileMode.Open, FileAccess.Read, FileShare.None)) + { + world = SmallWorld.CreateWorldFrom(options, fs); + } + } + else + { + world = SmallWorld.CreateWorld(options); + world.AddItems(vectors); + using (var fs = new FileStream("graph.bin", FileMode.Create, FileAccess.Write, FileShare.None)) + { + world.Serialize(fs); + } + } + + var clusters = AutomaticGraphClusterer.DetectClusters(world); + Console.WriteLine($"Found {clusters.Count} clusters"); + for (int i = 0; i < clusters.Count; i++) + { + Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); + } + + } + static void AutoClusteringTest() { var vectors = RandomVectors(128, 3000); @@ -114,7 +276,7 @@ namespace HNSWDemo Console.WriteLine($"Found {clusters.Count} clusters"); for (int i = 0; i < clusters.Count; i++) { - Console.WriteLine($"Cluster {i+1} countains {clusters[i].Count} items"); + Console.WriteLine($"Cluster {i + 1} countains {clusters[i].Count} items"); } } @@ -135,10 +297,10 @@ namespace HNSWDemo static void DrawHistogram(Histogram histogram, string filename) { - /* while (histogram.CountSignChanges() > 3) - { - histogram.Smooth(); - }*/ + /* while (histogram.CountSignChanges() > 3) + { + histogram.Smooth(); + }*/ var wb = 1200 / histogram.Values.Length; var k = 600.0f / (float)histogram.Values.Max(); @@ -147,8 +309,7 @@ namespace HNSWDemo using (var bmp = new Bitmap(1200, 600)) { - using (var g = Graphics.FromImage(bmp)) - { + using (var g = Graphics.FromImage(bmp)) { for (int i = 0; i < histogram.Values.Length; i++) { var height = (int)(histogram.Values[i] * k); @@ -481,26 +642,37 @@ namespace HNSWDemo static void AccuracityTest() { int K = 200; - var count = 5000; + var count = 2000; var testCount = 1000; var dimensionality = 128; var totalHits = new List(); var timewatchesNP = new List(); var timewatchesHNSW = new List(); + + var totalOptHits = new List(); + var timewatchesOptHNSW = new List(); + var samples = RandomVectors(dimensionality, count); var sw = new Stopwatch(); - var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits); - var world = new SmallWorld(NSWOptions.Create(8, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + var test = new VectorsDirectCompare(samples, Metrics.L2Euclidean); + var world = new SmallWorld(NSWOptions.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); + + var opt_world = new OptWorld(NSWOptions.Create(8, 15, 200, 200, Metrics.L2Euclidean, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); sw.Start(); var ids = world.AddItems(samples.ToArray()); sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items: {sw.ElapsedMilliseconds} ms"); - Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms"); + sw.Restart(); + opt_world.AddItems(samples.ToArray()); + sw.Stop(); + Console.WriteLine($"Insert {ids.Length} items in OPT: {sw.ElapsedMilliseconds} ms"); Console.WriteLine("Start test"); + var test_vectors = RandomVectors(dimensionality, testCount); foreach (var v in test_vectors) { @@ -512,6 +684,7 @@ namespace HNSWDemo sw.Restart(); var result = world.Search(v, K); sw.Stop(); + timewatchesHNSW.Add(sw.ElapsedMilliseconds); var hits = 0; foreach (var r in result) @@ -522,15 +695,39 @@ namespace HNSWDemo } } totalHits.Add(hits); + + + sw.Restart(); + result = opt_world.Search(v, K); + sw.Stop(); + + timewatchesOptHNSW.Add(sw.ElapsedMilliseconds); + hits = 0; + foreach (var r in result) + { + if (gt.ContainsKey(r.Item1)) + { + hits++; + } + } + totalOptHits.Add(hits); } Console.WriteLine($"MIN Accuracity: {totalHits.Min() * 100 / K}%"); Console.WriteLine($"AVG Accuracity: {totalHits.Average() * 100 / K}%"); Console.WriteLine($"MAX Accuracity: {totalHits.Max() * 100 / K}%"); + Console.WriteLine($"MIN Opt Accuracity: {totalOptHits.Min() * 100 / K}%"); + Console.WriteLine($"AVG Opt Accuracity: {totalOptHits.Average() * 100 / K}%"); + Console.WriteLine($"MAX Opt Accuracity: {totalOptHits.Max() * 100 / K}%"); + Console.WriteLine($"MIN HNSW TIME: {timewatchesHNSW.Min()} ms"); Console.WriteLine($"AVG HNSW TIME: {timewatchesHNSW.Average()} ms"); Console.WriteLine($"MAX HNSW TIME: {timewatchesHNSW.Max()} ms"); + Console.WriteLine($"MIN Opt HNSW TIME: {timewatchesOptHNSW.Min()} ms"); + Console.WriteLine($"AVG Opt HNSW TIME: {timewatchesOptHNSW.Average()} ms"); + Console.WriteLine($"MAX Opt HNSW TIME: {timewatchesOptHNSW.Max()} ms"); + Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); diff --git a/TestHNSW/HNSWDemo/t10k-images.idx3-ubyte b/TestHNSW/HNSWDemo/t10k-images.idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/TestHNSW/HNSWDemo/t10k-images.idx3-ubyte differ diff --git a/ZeroLevel.HNSW/Model/Histogram.cs b/ZeroLevel.HNSW/Model/Histogram.cs index 8ab2a89..ed49a43 100644 --- a/ZeroLevel.HNSW/Model/Histogram.cs +++ b/ZeroLevel.HNSW/Model/Histogram.cs @@ -21,7 +21,7 @@ namespace ZeroLevel.HNSW public float[] Bounds { get; } public int[] Values { get; } - internal Histogram(HistogramMode mode, IList data) + public Histogram(HistogramMode mode, IList data) { Mode = mode; Min = data.Min(); @@ -171,6 +171,13 @@ namespace ZeroLevel.HNSW threshold = k; } } + /* + var local_max = Values[threshold]; + for (int i = threshold + 1; i < Values.Length; i++) + { + + } + */ return threshold; } #endregion diff --git a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs index 60177ef..05ec8f1 100644 --- a/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs +++ b/ZeroLevel.HNSW/Services/AutomaticGraphClusterer.cs @@ -16,6 +16,7 @@ namespace ZeroLevel.HNSW.Services var max = histogram.Bounds[threshold]; var R = (max + min) / 2; + // 2. Get links with distances less than R var resultLinks = new SortedList(); foreach (var pair in links) diff --git a/ZeroLevel.HNSW/Services/BinaryHeap.cs b/ZeroLevel.HNSW/Services/BinaryHeap.cs index 675f541..c03318a 100644 --- a/ZeroLevel.HNSW/Services/BinaryHeap.cs +++ b/ZeroLevel.HNSW/Services/BinaryHeap.cs @@ -1,80 +1,90 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; -using System.Text; -using System.Threading.Tasks; -namespace ZeroLevel.HNSW.Services +namespace ZeroLevel.HNSW { /// /// Binary heap wrapper around the /// It's a max-heap implementation i.e. the maximum element is always on top. - /// But the order of elements can be customized by providing instance. /// /// The type of the items in the source list. - public class BinaryHeap + public class BinaryHeap : + IEnumerable<(int, float)> { - /// - /// Initializes a new instance of the class. - /// - /// The buffer to store heap items. - public BinaryHeap(IList buffer) - : this(buffer, Comparer.Default) - { - } + private static BinaryHeap _empty = new BinaryHeap(); - /// - /// Initializes a new instance of the class. - /// - /// The buffer to store heap items. - /// The comparer which defines order of items. - public BinaryHeap(IList buffer, IComparer comparer) + public static BinaryHeap Empty => _empty; + + private readonly List<(int, float)> _data; + + private bool _frozen = false; + public (int, float) Nearest => _data[_data.Count - 1]; + public (int, float) Farthest => _data[0]; + + public (int, float) PopNearest() { - if (buffer == null) + if (this._data.Any()) { - throw new ArgumentNullException(nameof(buffer)); + var result = this._data[this._data.Count - 1]; + this._data.RemoveAt(this._data.Count - 1); + return result; } + return (-1, -1); + } - this.Buffer = buffer; - this.Comparer = comparer; - for (int i = 1; i < this.Buffer.Count; ++i) + public (int, float) PopFarthest() + { + if (this._data.Any()) { - this.SiftUp(i); + var result = this._data.First(); + this._data[0] = this._data.Last(); + this._data.RemoveAt(this._data.Count - 1); + this.SiftDown(0); + return result; } + return (-1, -1); } - /// - /// Gets the heap comparer. - /// - public IComparer Comparer { get; private set; } + public int Count => _data.Count; + public void Clear() => _data.Clear(); /// - /// Gets the buffer of the heap. + /// Initializes a new instance of the class. /// - public IList Buffer { get; private set; } + /// The buffer to store heap items. + public BinaryHeap(int k = -1, bool frozen = false) + { + _frozen = frozen; + if (k > 0) + _data = new List<(int, float)>(k); + else + _data = new List<(int, float)>(); + } /// /// Pushes item to the heap. /// /// The item to push. - public void Push(T item) + public void Push(int item, float distance) { - this.Buffer.Add(item); - this.SiftUp(this.Buffer.Count - 1); + this._data.Add((item, distance)); + this.SiftUp(this._data.Count - 1); } /// /// Pops the item from the heap. /// /// The popped item. - public T Pop() + public (int, float) Pop() { - if (this.Buffer.Any()) + if (this._data.Any()) { - var result = this.Buffer.First(); + var result = this._data.First(); - this.Buffer[0] = this.Buffer.Last(); - this.Buffer.RemoveAt(this.Buffer.Count - 1); + this._data[0] = this._data.Last(); + this._data.RemoveAt(this._data.Count - 1); this.SiftDown(0); return result; @@ -90,21 +100,19 @@ namespace ZeroLevel.HNSW.Services /// The position of item where heap property is violated. private void SiftDown(int i) { - while (i < this.Buffer.Count) + while (i < this._data.Count) { int l = (2 * i) + 1; int r = l + 1; - if (l >= this.Buffer.Count) + if (l >= this._data.Count) { break; } - - int m = r < this.Buffer.Count && this.Comparer.Compare(this.Buffer[l], this.Buffer[r]) < 0 ? r : l; - if (this.Comparer.Compare(this.Buffer[m], this.Buffer[i]) <= 0) + int m = ((r < this._data.Count) && this._data[l].Item2 < this._data[r].Item2) ? r : l; + if (this._data[m].Item2 <= this._data[i].Item2) { break; } - this.Swap(i, m); i = m; } @@ -120,11 +128,10 @@ namespace ZeroLevel.HNSW.Services while (i > 0) { int p = (i - 1) / 2; - if (this.Comparer.Compare(this.Buffer[i], this.Buffer[p]) <= 0) + if (this._data[i].Item2 <= this._data[p].Item2) { break; } - this.Swap(i, p); i = p; } @@ -137,9 +144,19 @@ namespace ZeroLevel.HNSW.Services /// The second index. private void Swap(int i, int j) { - var temp = this.Buffer[i]; - this.Buffer[i] = this.Buffer[j]; - this.Buffer[j] = temp; + var temp = this._data[i]; + this._data[i] = this._data[j]; + this._data[j] = temp; + } + + public IEnumerator<(int, float)> GetEnumerator() + { + return _data.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _data.GetEnumerator(); } } } diff --git a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs index de44419..72ff287 100644 --- a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs +++ b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs @@ -104,11 +104,23 @@ namespace ZeroLevel.HNSW _rwLock.EnterReadLock(); try { - foreach (var (k, v) in Search(_set, id)) + if (_set.Count == 1) { + var k = _set.Keys[0]; + var v = _set[k]; var id1 = (int)(k >> HALF_LONG_BITS); var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - yield return (id1, id2, v); + if (id1 == id) yield return (id, id2, v); + else if (id2 == id) yield return (id1, id, v); + } + else if (_set.Count > 1) + { + foreach (var (k, v) in Search(_set, id)) + { + var id1 = (int)(k >> HALF_LONG_BITS); + var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); + yield return (id1, id2, v); + } } } finally diff --git a/ZeroLevel.HNSW/Services/Layer.cs b/ZeroLevel.HNSW/Services/Layer.cs index ace3a65..a840cea 100644 --- a/ZeroLevel.HNSW/Services/Layer.cs +++ b/ZeroLevel.HNSW/Services/Layer.cs @@ -53,6 +53,12 @@ namespace ZeroLevel.HNSW int index = 0; for (int ni = 1; ni < nearest.Length; ni++) { + // Если осталась ссылка узла на себя, удаляем ее в первую очередь + if (nearest[ni].Item1 == nearest[ni].Item2) + { + index = ni; + break; + } if (nearest[ni].Item3 > distance) { index = ni; @@ -81,6 +87,23 @@ namespace ZeroLevel.HNSW } #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf + internal int FingEntryPointAtLayer(Func targetCosts) + { + var set = new HashSet(_links.Items().Select(p => p.Item1)); + int minId = -1; + float minDist = float.MaxValue; + foreach (var id in set) + { + var d = targetCosts(id); + if (d < minDist && Math.Abs(d) > float.Epsilon) + { + minDist = d; + minId = id; + } + } + return minId; + } + /// /// Algorithm 2 /// diff --git a/ZeroLevel.HNSW/Services/OPT/OptLayer.cs b/ZeroLevel.HNSW/Services/OPT/OptLayer.cs new file mode 100644 index 0000000..e8af302 --- /dev/null +++ b/ZeroLevel.HNSW/Services/OPT/OptLayer.cs @@ -0,0 +1,491 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW.Services.OPT +{ + /// + /// NSW graph + /// + internal sealed class OptLayer + : IBinarySerializable + { + private readonly NSWOptions _options; + private readonly VectorSet _vectors; + private readonly CompactBiDirectionalLinksSet _links; + internal SortedList Links => _links.Links; + + /// + /// There are links е the layer + /// + internal bool HasLinks => (_links.Count > 0); + + /// + /// HNSW layer + /// + /// HNSW graph options + /// General vector set + internal OptLayer(NSWOptions options, VectorSet vectors) + { + _options = options; + _vectors = vectors; + _links = new CompactBiDirectionalLinksSet(); + } + + /// + /// Adding new bidirectional link + /// + /// New node + /// The node with which the connection will be made + /// + /// + internal void AddBidirectionallConnections(int q, int p, float qpDistance, bool isMapLayer) + { + // поиск в ширину ближайших узлов к найденному + var nearest = _links.FindLinksForId(p).ToArray(); + // если у найденного узла максимальное количество связей + // if │eConn│ > Mmax // shrink connections of e + if (nearest.Length >= (isMapLayer ? _options.M * 2 : _options.M)) + { + // ищем связь с самой большой дистанцией + float distance = nearest[0].Item3; + int index = 0; + for (int ni = 1; ni < nearest.Length; ni++) + { + // Если осталась ссылка узла на себя, удаляем ее в первую очередь + if (nearest[ni].Item1 == nearest[ni].Item2) + { + index = ni; + break; + } + if (nearest[ni].Item3 > distance) + { + index = ni; + distance = nearest[ni].Item3; + } + } + // делаем перелинковку вставляя новый узел между найденными + var id1 = nearest[index].Item1; + var id2 = nearest[index].Item2; + _links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q])); + } + else + { + if (nearest.Length == 1 && nearest[0].Item1 == nearest[0].Item2) + { + // убираем связи на самих себя + var id1 = nearest[0].Item1; + var id2 = nearest[0].Item2; + _links.Relink(id1, id2, q, qpDistance, _options.Distance(_vectors[id2], _vectors[q])); + } + else + { + // добавляем связь нового узла к найденному + _links.Add(q, p, qpDistance); + } + } + } + + /// + /// Adding a node with a connection to itself + /// + /// + internal void Append(int q) + { + _links.Add(q, q, 0); + } + + #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf + /// + /// Algorithm 2 + /// + /// query element + /// enter points ep + /// Output: ef closest neighbors to q + internal void KNearestAtLayer(int entryPointId, Func targetCosts, BinaryHeap W, int ef) + { + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + var v = new VisitedBitSet(_vectors.Count, _options.M); + // v ← ep // set of visited elements + v.Add(entryPointId); + + var d = targetCosts(entryPointId); + // C ← ep // set of candidates + var C = new BinaryHeap(); + C.Push(entryPointId, d); + // W ← ep // dynamic list of found nearest neighbors + W.Push(entryPointId, d); + + // run bfs + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = C.PopNearest(); + var farthestResult = W.Farthest; + if (toExpand.Item2 > farthestResult.Item2) + { + // the closest candidate is farther than farthest result + break; + } + + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) + { + // enqueue perspective neighbours to expansion list + farthestResult = W.Farthest; + + var neighbourDistance = targetCosts(neighbourId); + if (W.Count < ef || neighbourDistance < farthestResult.Item2) + { + C.Push(neighbourId, neighbourDistance); + W.Push(neighbourId, neighbourDistance); + if (W.Count > ef) + { + W.PopFarthest(); + } + } + v.Add(neighbourId); + } + } + } + C.Clear(); + v.Clear(); + } + + /// + /// Algorithm 2 + /// + /// query element + /// enter points ep + /// Output: ef closest neighbors to q + internal void KNearestAtLayer(int entryPointId, Func targetCosts, BinaryHeap W, int ef, SearchContext context) + { + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + var v = new VisitedBitSet(_vectors.Count, _options.M); + // v ← ep // set of visited elements + v.Add(entryPointId); + // C ← ep // set of candidates + var C = new BinaryHeap(); + var d = targetCosts(entryPointId); + C.Push(entryPointId, d); + // W ← ep // dynamic list of found nearest neighbors + if (context.IsActiveNode(entryPointId)) + { + W.Push(entryPointId, d); + } + // run bfs + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = C.PopNearest(); + if (W.Count > 0) + { + if (toExpand.Item2 > W.Farthest.Item2) + { + // the closest candidate is farther than farthest result + break; + } + } + + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) + { + // enqueue perspective neighbours to expansion list + var neighbourDistance = targetCosts(neighbourId); + if (context.IsActiveNode(neighbourId)) + { + if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2)) + { + W.Push(neighbourId, neighbourDistance); + if (W.Count > ef) + { + W.PopFarthest(); + } + } + } + if (W.Count < ef) + { + C.Push(neighbourId, neighbourDistance); + } + v.Add(neighbourId); + } + } + } + C.Clear(); + v.Clear(); + } + + /// + /// Algorithm 2, modified for LookAlike + /// + /// query element + /// enter points ep + /// Output: ef closest neighbors to q + internal void KNearestAtLayer(BinaryHeap W, int ef, SearchContext context) + { + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + // v ← ep // set of visited elements + var v = new VisitedBitSet(_vectors.Count, _options.M); + // C ← ep // set of candidates + var C = new BinaryHeap(); + foreach (var ep in context.EntryPoints) + { + var neighboursIds = GetNeighbors(ep).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + C.Push(ep, _links.Distance(ep, neighboursIds[i])); + } + v.Add(ep); + } + // W ← ep // dynamic list of found nearest neighbors + // run bfs + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = C.PopNearest(); + if (W.Count > 0) + { + if (toExpand.Item2 > W.Farthest.Item2) + { + // the closest candidate is farther than farthest result + break; + } + } + if (context.IsActiveNode(toExpand.Item1)) + { + if (W.Count < ef || W.Count == 0 || (W.Count > 0 && toExpand.Item2 < W.Farthest.Item2)) + { + W.Push(toExpand.Item1, toExpand.Item2); + if (W.Count > ef) + { + W.PopFarthest(); + } + } + } + } + if (W.Count > ef) + { + while (W.Count > ef) + { + W.PopFarthest(); + } + return; + } + else + { + foreach (var c in W) + { + C.Push(c.Item1, c.Item2); + } + } + while (C.Count > 0) + { + // get next candidate to check and expand + var toExpand = C.PopNearest(); + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) + { + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) + { + // enqueue perspective neighbours to expansion list + var neighbourDistance = _links.Distance(toExpand.Item1, neighbourId); + if (context.IsActiveNode(neighbourId)) + { + if (W.Count < ef || (W.Count > 0 && neighbourDistance < W.Farthest.Item2)) + { + W.Push(neighbourId, neighbourDistance); + if (W.Count > ef) + { + W.PopFarthest(); + } + } + } + if (W.Count < ef) + { + C.Push(neighbourId, neighbourDistance); + } + v.Add(neighbourId); + } + } + } + C.Clear(); + v.Clear(); + } + + /// + /// Algorithm 3 + /// + internal BinaryHeap SELECT_NEIGHBORS_SIMPLE(BinaryHeap W, int M) + { + var bestN = M; + if (W.Count > bestN) + { + while (W.Count > bestN) + { + W.PopFarthest(); + } + } + return W; + } + + + + /// + /// Algorithm 4 + /// + /// base element + /// candidate elements + /// flag indicating whether or not to extend candidate list + /// flag indicating whether or not to add discarded elements + /// Output: M elements selected by the heuristic + internal BinaryHeap SELECT_NEIGHBORS_HEURISTIC(Func distance, BinaryHeap W, int M) + { + // R ← ∅ + var R = new BinaryHeap(); + // W ← C // working queue for the candidates + // if extendCandidates // extend candidates by their neighbors + if (_options.ExpandBestSelection) + { + var extendBuffer = new HashSet(); + // for each e ∈ C + foreach (var e in W) + { + var neighbors = GetNeighbors(e.Item1); + // for each e_adj ∈ neighbourhood(e) at layer lc + foreach (var e_adj in neighbors) + { + // if eadj ∉ W + if (extendBuffer.Contains(e_adj) == false) + { + extendBuffer.Add(e_adj); + } + } + } + // W ← W ⋃ eadj + foreach (var id in extendBuffer) + { + W.Push(id, distance(id)); + } + } + + // Wd ← ∅ // queue for the discarded candidates + var Wd = new BinaryHeap(); + // while │W│ > 0 and │R│< M + while (W.Count > 0 && R.Count < M) + { + // e ← extract nearest element from W to q + var (e, ed) = W.PopNearest(); + var (fe, fd) = R.PopFarthest(); + + // if e is closer to q compared to any element from R + if (R.Count == 0 || + ed < fd) + { + // R ← R ⋃ e + R.Push(e, ed); + } + else + { + // Wd ← Wd ⋃ e + Wd.Push(e, ed); + } + } + // if keepPrunedConnections // add some of the discarded // connections from Wd + if (_options.KeepPrunedConnections) + { + // while │Wd│> 0 and │R│< M + while (Wd.Count > 0 && R.Count < M) + { + // R ← R ⋃ extract nearest element from Wd to q + var nearest = Wd.PopNearest(); + R.Push(nearest.Item1, nearest.Item2); + } + } + // return R + return R; + } + #endregion + + private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2); + + public void Serialize(IBinaryWriter writer) + { + _links.Serialize(writer); + } + + public void Deserialize(IBinaryReader reader) + { + _links.Deserialize(reader); + } + + internal Histogram GetHistogram(HistogramMode mode) => _links.CalculateHistogram(mode); + } +} diff --git a/ZeroLevel.HNSW/Services/OPT/OptWorld.cs b/ZeroLevel.HNSW/Services/OPT/OptWorld.cs new file mode 100644 index 0000000..4ce962c --- /dev/null +++ b/ZeroLevel.HNSW/Services/OPT/OptWorld.cs @@ -0,0 +1,347 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using ZeroLevel.Services.Serialization; + +namespace ZeroLevel.HNSW.Services.OPT +{ + public class OptWorld + { + private readonly NSWOptions _options; + private VectorSet _vectors; + private OptLayer[] _layers; + private int EntryPoint = 0; + private int MaxLayer = 0; + private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; + private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); + internal SortedList GetNSWLinks() => _layers[0].Links; + + public OptWorld(NSWOptions options) + { + _options = options; + _vectors = new VectorSet(); + _layers = new OptLayer[_options.LayersCount]; + _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M); + for (int i = 0; i < _options.LayersCount; i++) + { + _layers[i] = new OptLayer(_options, _vectors); + } + } + + internal OptWorld(NSWOptions options, Stream stream) + { + _options = options; + Deserialize(stream); + } + + /// + /// Search in the graph K for vectors closest to a given vector + /// + /// Given vector + /// Count of elements for search + /// + /// + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k) + { + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context) + { + if (context == null) + { + foreach (var pair in KNearest(vector, k)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + else + { + foreach (var pair in KNearest(vector, k, context)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + } + + public IEnumerable<(int, TItem, float)> Search(int k, SearchContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + else + { + foreach (var pair in KNearest(k, context)) + { + yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); + } + } + } + + /// + /// Adding vectors batch + /// + /// Vectors + /// Vector identifiers in a graph + public int[] AddItems(IEnumerable vectors) + { + _lockGraph.EnterWriteLock(); + try + { + var ids = _vectors.Append(vectors); + for (int i = 0; i < ids.Length; i++) + { + INSERT(ids[i]); + } + return ids; + } + finally + { + _lockGraph.ExitWriteLock(); + } + } + + #region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf + /// + /// Algorithm 1 + /// + private void INSERT(int q) + { + var distance = new Func(candidate => _options.Distance(_vectors[q], _vectors[candidate])); + // W ← ∅ // list for the currently found nearest elements + var W = new BinaryHeap(); + // ep ← get enter point for hnsw + //var ep = _layers[MaxLayer].FingEntryPointAtLayer(distance); + //if(ep == -1) ep = EntryPoint; + var ep = EntryPoint; + var epDist = distance(ep); + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level + int l = _layerLevelGenerator.GetRandomLayer(); + // for lc ← L … l+1 + // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа + for (int lc = L; lc > l; --lc) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[lc].KNearestAtLayer(ep, distance, W, 1); + // ep ← get the nearest element from W to q + var nearest = W.Nearest; + ep = nearest.Item1; + epDist = nearest.Item2; + W.Clear(); + } + //for lc ← min(L, l) … 0 + // connecting new node to the small world + for (int lc = Math.Min(L, l); lc >= 0; --lc) + { + if (_layers[lc].HasLinks == false) + { + _layers[lc].Append(q); + } + else + { + // W ← SEARCH - LAYER(q, ep, efConstruction, lc) + _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction); + + // ep ← W + var nearest = W.Nearest; + ep = nearest.Item1; + epDist = nearest.Item2; + + // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 + var neighbors = SelectBestForConnecting(lc, distance, W); + // add bidirectionall connectionts from neighbors to q at layer lc + // for each e ∈ neighbors // shrink connections if needed + foreach (var e in neighbors) + { + // eConn ← neighbourhood(e) at layer lc + _layers[lc].AddBidirectionallConnections(q, e.Item1, e.Item2, lc == 0); + // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer + if (e.Item2 < epDist) + { + ep = e.Item1; + epDist = e.Item2; + } + } + W.Clear(); + } + } + // if l > L + if (l > L) + { + // set enter point for hnsw to q + L = l; + MaxLayer = l; + EntryPoint = ep; + } + } + + /// + /// Get maximum allowed connections for the given level. + /// + /// + /// Article: Section 4.1: + /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also + /// has a strong influence on the search performance, especially in case of high quality(high recall) search. + /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors + /// selection heuristic is not used) leads to a very strong performance penalty at high recall. + /// Simulations also suggest that 2∙M is a good choice for Mmax0; + /// setting the parameter higher leads to performance degradation and excessive memory usage." + /// + /// The level of the layer. + /// The maximum number of connections. + private int GetM(int layer) + { + return layer == 0 ? 2 * _options.M : _options.M; + } + + private BinaryHeap SelectBestForConnecting(int layer, Func distance, BinaryHeap candidates) + { + if (_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) + return _layers[layer].SELECT_NEIGHBORS_SIMPLE(candidates, GetM(layer)); + return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer)); + } + + /// + /// Algorithm 5 + /// + private BinaryHeap KNearest(TItem q, int k) + { + _lockGraph.EnterReadLock(); + try + { + if (_vectors.Count == 0) + { + return BinaryHeap.Empty; + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new BinaryHeap(k + 1); + // ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].KNearestAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.Nearest.Item1; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(ep, distance, W, k); + // return K nearest elements from W to q + return W; + } + finally + { + _lockGraph.ExitReadLock(); + } + } + private BinaryHeap KNearest(TItem q, int k, SearchContext context) + { + _lockGraph.EnterReadLock(); + try + { + if (_vectors.Count == 0) + { + return BinaryHeap.Empty; + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new BinaryHeap(k + 1); + // ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].KNearestAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.Nearest.Item1; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(ep, distance, W, k, context); + // return K nearest elements from W to q + return W; + } + finally + { + _lockGraph.ExitReadLock(); + } + } + + private BinaryHeap KNearest(int k, SearchContext context) + { + _lockGraph.EnterReadLock(); + try + { + if (_vectors.Count == 0) + { + return BinaryHeap.Empty; + } + var distance = new Func((id1, id2) => _options.Distance(_vectors[id1], _vectors[id2])); + // W ← ∅ // set for the current nearest elements + var W = new BinaryHeap(k + 1); + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].KNearestAtLayer(W, k, context); + // return K nearest elements from W to q + return W; + } + finally + { + _lockGraph.ExitReadLock(); + } + } + #endregion + + public void Serialize(Stream stream) + { + using (var writer = new MemoryStreamWriter(stream)) + { + writer.WriteInt32(EntryPoint); + writer.WriteInt32(MaxLayer); + _vectors.Serialize(writer); + writer.WriteInt32(_layers.Length); + foreach (var l in _layers) + { + l.Serialize(writer); + } + } + } + + public void Deserialize(Stream stream) + { + using (var reader = new MemoryStreamReader(stream)) + { + this.EntryPoint = reader.ReadInt32(); + this.MaxLayer = reader.ReadInt32(); + _vectors = new VectorSet(); + _vectors.Deserialize(reader); + var countLayers = reader.ReadInt32(); + _layers = new OptLayer[countLayers]; + for (int i = 0; i < countLayers; i++) + { + _layers[i] = new OptLayer(_options, _vectors); + _layers[i].Deserialize(reader); + } + } + } + + public Histogram GetHistogram(HistogramMode mode = HistogramMode.SQRT) + => _layers[0].GetHistogram(mode); + } +} diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index 26c9606..bc65ff8 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -11,7 +11,7 @@ namespace ZeroLevel.HNSW public class SmallWorld { private readonly NSWOptions _options; - private readonly VectorSet _vectors; + private VectorSet _vectors; private Layer[] _layers; private int EntryPoint = 0; private int MaxLayer = 0; @@ -118,6 +118,8 @@ namespace ZeroLevel.HNSW // W ← ∅ // list for the currently found nearest elements IDictionary W = new Dictionary(); // ep ← get enter point for hnsw + //var ep = _layers[MaxLayer].FingEntryPointAtLayer(distance); + //if(ep == -1) ep = EntryPoint; var ep = EntryPoint; var epDist = distance(ep); // L ← level of ep // top layer for hnsw @@ -334,6 +336,7 @@ namespace ZeroLevel.HNSW { this.EntryPoint = reader.ReadInt32(); this.MaxLayer = reader.ReadInt32(); + _vectors = new VectorSet(); _vectors.Deserialize(reader); var countLayers = reader.ReadInt32(); _layers = new Layer[countLayers]; diff --git a/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj b/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj index 3c39f1c..69295c1 100644 --- a/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj +++ b/ZeroLevel.HNSW/ZeroLevel.HNSW.csproj @@ -1,7 +1,8 @@ - + net5.0 + AnyCPU;x64