From 0e2f693b879d2508264b8b714288cb2f0ac6a067 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 7 Dec 2021 02:03:27 +0300 Subject: [PATCH] HNSW update Fix bugs --- TestHNSW/HNSWDemo/Program.cs | 8 ++-- ZeroLevel.HNSW/Layer.cs | 32 +++++++------ ZeroLevel.HNSW/Model/NSWOptions.cs | 47 +++++++++++++++++-- .../Services/CompactBiDirectionalLinksSet.cs | 25 +++++++--- ZeroLevel.HNSW/SmallWorld.cs | 4 +- 5 files changed, 84 insertions(+), 32 deletions(-) diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 40760bc..ee4610f 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -77,13 +77,13 @@ namespace HNSWDemo static void Main(string[] args) { var dimensionality = 128; - var testCount = 1000; - var count = 100000; - var batchSize = 5000; + var testCount = 5000; + var count = 10000; + var batchSize = 1000; var samples = Person.GenerateRandom(dimensionality, count); var sw = new Stopwatch(); - var world = new SmallWorld(NSWOptions.Create(6, 4, 120, 120, CosineDistance.ForUnits)); + var world = new SmallWorld(NSWOptions.Create(6, 10, 200, 200, CosineDistance.ForUnits, false, false, selectionHeuristic: NeighbourSelectionHeuristic.SelectHeuristic)); for (int i = 0; i < (count / batchSize); i++) { diff --git a/ZeroLevel.HNSW/Layer.cs b/ZeroLevel.HNSW/Layer.cs index 0a919de..577c6b0 100644 --- a/ZeroLevel.HNSW/Layer.cs +++ b/ZeroLevel.HNSW/Layer.cs @@ -94,7 +94,9 @@ namespace ZeroLevel.HNSW // W ← ep // dynamic list of found nearest neighbors W.Add(entryPointId, C[entryPointId]); - + var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); }); + var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); + var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); // run bfs while (C.Count > 0) { @@ -164,14 +166,14 @@ namespace ZeroLevel.HNSW /// flag indicating whether or not to extend candidate list /// flag indicating whether or not to add discarded elements /// Output: M elements selected by the heuristic - public IDictionary SELECT_NEIGHBORS_HEURISTIC(Func distance, IDictionary candidates, int M, bool extendCandidates, bool keepPrunedConnections) + public IDictionary SELECT_NEIGHBORS_HEURISTIC(Func distance, IDictionary candidates, int M) { // R ← ∅ var R = new Dictionary(); // W ← C // working queue for the candidates var W = new Dictionary(candidates); // if extendCandidates // extend candidates by their neighbors - if (extendCandidates) + if (_options.ExpandBestSelection) { var extendBuffer = new HashSet(); // for each e ∈ C @@ -191,7 +193,7 @@ namespace ZeroLevel.HNSW // W ← W ⋃ eadj foreach (var id in extendBuffer) { - W.Add(id, distance(id)); + W[id] = distance(id); } } @@ -201,7 +203,7 @@ namespace ZeroLevel.HNSW var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); - var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); + var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); }); // while │W│ > 0 and │R│< M @@ -218,21 +220,21 @@ namespace ZeroLevel.HNSW // R ← R ⋃ e R.Add(e, ed); } - // else + else { // Wd ← Wd ⋃ e Wd.Add(e, ed); } - // if keepPrunedConnections // add some of the discarded // connections from Wd - if (keepPrunedConnections) + } + // if keepPrunedConnections // add some of the discarded // connections from Wd + if (_options.KeepPrunedConnections) + { + // while │Wd│> 0 and │R│< M + while (Wd.Count > 0 && R.Count < M) { - // while │Wd│> 0 and │R│< M - while (Wd.Count > 0 && R.Count < M) - { - // R ← R ⋃ extract nearest element from Wd to q - var nearest = popNearestDiscarded(); - R.Add(nearest.Item1, nearest.Item2); - } + // R ← R ⋃ extract nearest element from Wd to q + var nearest = popNearestDiscarded(); + R[nearest.Item1] = nearest.Item2; } } // return R diff --git a/ZeroLevel.HNSW/Model/NSWOptions.cs b/ZeroLevel.HNSW/Model/NSWOptions.cs index c888eaa..b4af8cc 100644 --- a/ZeroLevel.HNSW/Model/NSWOptions.cs +++ b/ZeroLevel.HNSW/Model/NSWOptions.cs @@ -2,10 +2,24 @@ namespace ZeroLevel.HNSW { - public sealed class NSWOptions + /// + /// Type of heuristic to select best neighbours for a node. + /// + public enum NeighbourSelectionHeuristic { - public const int FARTHEST_DIVIDER = 3; + /// + /// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in + /// + SelectSimple, + + /// + /// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in + /// + SelectHeuristic + } + public sealed class NSWOptions + { /// /// Mox node connections on Layer /// @@ -24,19 +38,42 @@ namespace ZeroLevel.HNSW /// public readonly Func Distance; + public readonly bool ExpandBestSelection; + + public readonly bool KeepPrunedConnections; + + public readonly NeighbourSelectionHeuristic SelectionHeuristic; + public readonly int LayersCount; - private NSWOptions(int layersCount, int m, int ef, int ef_construction, Func distance) + private NSWOptions(int layersCount, + int m, + int ef, + int ef_construction, + Func distance, + bool expandBestSelection, + bool keepPrunedConnections, + NeighbourSelectionHeuristic selectionHeuristic) { LayersCount = layersCount; M = m; EF = ef; EFConstruction = ef_construction; Distance = distance; + ExpandBestSelection = expandBestSelection; + KeepPrunedConnections = keepPrunedConnections; + SelectionHeuristic = selectionHeuristic; } - public static NSWOptions Create(int layersCount, int M, int EF, int EF_construction, Func distance) => - new NSWOptions(layersCount, M, EF, EF_construction, distance); + public static NSWOptions Create(int layersCount, + int M, + int EF, + int EF_construction, + Func distance, + bool expandBestSelection = false, + bool keepPrunedConnections = false, + NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) => + new NSWOptions(layersCount, M, EF, EF_construction, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic); } } diff --git a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs index 779acdd..026200b 100644 --- a/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs +++ b/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSet.cs @@ -14,15 +14,14 @@ namespace ZeroLevel.HNSW private SortedList _set = new SortedList(); - public (int, int, float) this[int index] + public (int, int) this[int index] { get { var k = _set.Keys[index]; - var d = _set.Values[index]; var id1 = (int)(k >> HALF_LONG_BITS); var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS)); - return (id1, id2, d); + return (id1, id2); } } @@ -72,10 +71,22 @@ namespace ZeroLevel.HNSW { _set.Remove(k_id1_id2); _set.Remove(k_id2_id1); - _set.Add(k_id_id1, distanceToId1); - _set.Add(k_id1_id, distanceToId1); - _set.Add(k_id_id2, distanceToId2); - _set.Add(k_id2_id, distanceToId2); + if (!_set.ContainsKey(k_id_id1)) + { + _set.Add(k_id_id1, distanceToId1); + } + if (!_set.ContainsKey(k_id1_id)) + { + _set.Add(k_id1_id, distanceToId1); + } + if (!_set.ContainsKey(k_id_id2)) + { + _set.Add(k_id_id2, distanceToId2); + } + if (!_set.ContainsKey(k_id2_id)) + { + _set.Add(k_id2_id, distanceToId2); + } } finally { diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index f93449c..477957d 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -192,7 +192,9 @@ namespace ZeroLevel.HNSW private IDictionary SelectBestForConnecting(int layer, Func distance, IDictionary candidates) { - return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer)); + if(_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple) + return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer)); + return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer)); } ///