From 1f967908b906c5d05693c342159b0684568dd54c Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 7 Dec 2021 01:13:44 +0300 Subject: [PATCH] HNSW Fix algorithm --- TestHNSW/HNSWDemo/Program.cs | 25 +-- ZeroLevel.HNSW/Layer.cs | 240 +++++++++++---------------- ZeroLevel.HNSW/Services/VectorSet.cs | 34 +++- ZeroLevel.HNSW/SmallWorld.cs | 201 +++++++++++++++++----- 4 files changed, 302 insertions(+), 198 deletions(-) diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index 07df108..40760bc 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -84,6 +84,7 @@ namespace HNSWDemo var sw = new Stopwatch(); var world = new SmallWorld(NSWOptions.Create(6, 4, 120, 120, CosineDistance.ForUnits)); + for (int i = 0; i < (count / batchSize); i++) { var batch = samples.Skip(i * batchSize).Take(batchSize).ToArray(); @@ -101,18 +102,18 @@ namespace HNSWDemo //HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; })); -/*var fackupCount = 0; - foreach (var v in vectors) - { - var result = world.Search(v, 10, filter); - foreach (var r in result) - { - if (_database[r.Item1].Age <= 45 || _database[r.Item1].Gender != Gender.Feemale) - { - Interlocked.Increment(ref fackupCount); - } - } - }*/ + /*var fackupCount = 0; + foreach (var v in vectors) + { + var result = world.Search(v, 10, filter); + foreach (var r in result) + { + if (_database[r.Item1].Age <= 45 || _database[r.Item1].Gender != Gender.Feemale) + { + Interlocked.Increment(ref fackupCount); + } + } + }*/ //Console.WriteLine($"Completed. Fackup count: {fackupCount}"); Console.ReadKey(); diff --git a/ZeroLevel.HNSW/Layer.cs b/ZeroLevel.HNSW/Layer.cs index 7fb8f7c..0a919de 100644 --- a/ZeroLevel.HNSW/Layer.cs +++ b/ZeroLevel.HNSW/Layer.cs @@ -13,6 +13,11 @@ namespace ZeroLevel.HNSW private readonly VectorSet _vectors; private CompactBiDirectionalLinksSet _links = new CompactBiDirectionalLinksSet(); + /// + /// Count nodes at layer + /// + public int Count => (_links.Count >> 1); + public Layer(NSWOptions options, VectorSet vectors) { _options = options; @@ -50,148 +55,107 @@ namespace ZeroLevel.HNSW } } - public int GetEntryPointFor(int q) - { - var randomLinkId = DefaultRandomGenerator.Instance.Next(0, _links.Count); - var entryId = _links[randomLinkId].Item1; - var v = new VisitedBitSet(_vectors._set.Count, _options.M); - // v ← ep // set of visited elements - var (ep, ed) = DFS_SearchMinFrom(entryId, q, v); - return ep; - } - - private (int, float) DFS_SearchMinFrom(int entryId, int id, VisitedBitSet visited) - { - visited.Add(entryId); - int candidate = entryId; - var candidateDistance = _options.Distance(_vectors[entryId], _vectors[id]); - int counter = 0; - do - { - var (mid, dist) = GetMinNearest(visited, entryId, candidate, candidateDistance); - if (dist > candidateDistance) - { - break; - } - candidate = mid; - candidateDistance = dist; - - counter++; - } while (counter < _options.EFConstruction); - return (candidate, candidateDistance); - } - - private (int, float) GetMinNearest(VisitedBitSet visited, int entryId, int id, float entryDistance) - { - var minId = entryId; - var minDist = entryDistance; - foreach (var candidate in _links.FindLinksForId(entryId).Select(l => l.Item2)) - { - if (visited.Contains(candidate) == false) - { - var dist = _options.Distance(_vectors[candidate], _vectors[id]); - if (dist < minDist) - { - minDist = dist; - minId = candidate; - } - visited.Add(candidate); - } - } - return (minId, minDist); - } - #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf - /// /// Algorithm 2 /// /// query element /// enter points ep /// Output: ef closest neighbors to q - public IDictionary SEARCH_LAYER(int q, int ep, int ef) + public void RunKnnAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef) { - var v = new VisitedBitSet(_vectors._set.Count, _options.M); + /* + * v ← ep // set of visited elements + * C ← ep // set of candidates + * W ← ep // dynamic list of found nearest neighbors + * while │C│ > 0 + * c ← extract nearest element from C to q + * f ← get furthest element from W to q + * if distance(c, q) > distance(f, q) + * break // all elements in W are evaluated + * for each e ∈ neighbourhood(c) at layer lc // update C and W + * if e ∉ v + * v ← v ⋃ e + * f ← get furthest element from W to q + * if distance(e, q) < distance(f, q) or │W│ < ef + * C ← C ⋃ e + * W ← W ⋃ e + * if │W│ > ef + * remove furthest element from W to q + * return W + */ + + var v = new VisitedBitSet(_vectors.Count, _options.M); // v ← ep // set of visited elements - v.Add(ep); + v.Add(entryPointId); // C ← ep // set of candidates var C = new Dictionary(); - C.Add(ep, _options.Distance(_vectors[ep], _vectors[q])); + C.Add(entryPointId, targetCosts(entryPointId)); // W ← ep // dynamic list of found nearest neighbors - var W = new Dictionary(); - W.Add(ep, C[ep]); - // while │C│ > 0 + W.Add(entryPointId, C[entryPointId]); + + + // run bfs while (C.Count > 0) { - // c ← extract nearest element from C to q - var nearest = W.OrderBy(p => p.Value).First(); - var c = nearest.Key; - var md = nearest.Value; - // var (c, md) = GetMinimalDistanceIndex(C, q); - C.Remove(c); - // f ← get furthest element from W to q - var f = W.OrderBy(p => p.Value).First().Key; - //var f = GetMaximalDistanceIndex(W, q); - // if distance(c, q) > distance(f, q) - if (_options.Distance(_vectors[c], _vectors[q]) > _options.Distance(_vectors[f], _vectors[q])) + // get next candidate to check and expand + var toExpand = popCandidate(); + var farthestResult = fartherFromResult(); + if (toExpand.Item2 > farthestResult.Item2) { - // break // all elements in W are evaluated + // the closest candidate is farther than farthest result break; } - // for each e ∈ neighbourhood(c) at layer lc // update C and W - foreach (var l in _links.FindLinksForId(c)) + + // expand candidate + var neighboursIds = GetNeighbors(toExpand.Item1).ToArray(); + for (int i = 0; i < neighboursIds.Length; ++i) { - var e = l.Item2; - // if e ∉ v - if (v.Contains(e) == false) + int neighbourId = neighboursIds[i]; + if (!v.Contains(neighbourId)) { - // v ← v ⋃ e - v.Add(e); - // f ← get furthest element from W to q - f = W.OrderByDescending(p => p.Value).First().Key; - //f = GetMaximalDistanceIndex(W, q); - // if distance(e, q) < distance(f, q) or │W│ < ef - var ed = _options.Distance(_vectors[e], _vectors[q]); - if (ed > _options.Distance(_vectors[f], _vectors[q]) - || W.Count < ef) + // enqueue perspective neighbours to expansion list + farthestResult = fartherFromResult(); + + var neighbourDistance = targetCosts(neighbourId); + if (W.Count < ef || neighbourDistance < farthestResult.Item2) { - // C ← C ⋃ e - C.Add(e, ed); - // W ← W ⋃ e - W.Add(e, ed); - // if │W│ > ef + C.Add(neighbourId, neighbourDistance); + W.Add(neighbourId, neighbourDistance); if (W.Count > ef) { - // remove furthest element from W to q - f = W.OrderByDescending(p => p.Value).First().Key; - //f = GetMaximalDistanceIndex(W, q); - W.Remove(f); + fartherPopFromResult(); } } + v.Add(neighbourId); } } } - // return W - return W; + C.Clear(); + v.Clear(); } /// /// Algorithm 3 /// - /// base element - /// candidate elements - /// Output: M nearest elements to q - public IDictionary SELECT_NEIGHBORS_SIMPLE(int q, IDictionary C) + public IDictionary SELECT_NEIGHBORS_SIMPLE(Func distance, IDictionary candidates, int M) { - if (C.Count <= _options.M) + var bestN = M; + var W = new Dictionary(candidates); + if (W.Count > bestN) { - return new Dictionary(C); + var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); + while (W.Count > bestN) + { + popFarther(); + } } - var output = new Dictionary(); // return M nearest elements from C to q - return new Dictionary(C.OrderBy(p => p.Value).Take(_options.M)); + return W; } + + /// /// Algorithm 4 /// @@ -200,41 +164,56 @@ namespace ZeroLevel.HNSW /// flag indicating whether or not to extend candidate list /// flag indicating whether or not to add discarded elements /// Output: M elements selected by the heuristic - public IDictionary SELECT_NEIGHBORS_HEURISTIC(int q, IDictionary C, bool extendCandidates, bool keepPrunedConnections) + public IDictionary SELECT_NEIGHBORS_HEURISTIC(Func distance, IDictionary candidates, int M, bool extendCandidates, bool keepPrunedConnections) { // R ← ∅ var R = new Dictionary(); // W ← C // working queue for the candidates - var W = new List(C.Select(p => p.Key)); + var W = new Dictionary(candidates); // if extendCandidates // extend candidates by their neighbors if (extendCandidates) { + var extendBuffer = new HashSet(); // for each e ∈ C - foreach (var e in C) + foreach (var e in W) { + var neighbors = GetNeighbors(e.Key); // for each e_adj ∈ neighbourhood(e) at layer lc - foreach (var l in _links.FindLinksForId(e.Key)) + foreach (var e_adj in neighbors) { - var e_adj = l.Item2; // if eadj ∉ W - if (W.Contains(e_adj) == false) + if (extendBuffer.Contains(e_adj) == false) { - // W ← W ⋃ eadj - W.Add(e_adj); + extendBuffer.Add(e_adj); } } } + // W ← W ⋃ eadj + foreach (var id in extendBuffer) + { + W.Add(id, distance(id)); + } } + // Wd ← ∅ // queue for the discarded candidates var Wd = new Dictionary(); + + + var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); + var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); + var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); + + // while │W│ > 0 and │R│< M - while (W.Count > 0 && R.Count < _options.M) + while (W.Count > 0 && R.Count < M) { // e ← extract nearest element from W to q - var (e, ed) = GetMinimalDistanceIndex(W, q); - W.Remove(e); + var (e, ed) = popCandidate(); + var (fe, fd) = fartherFromResult(); + // if e is closer to q compared to any element from R - if (ed < R.Min(pair => pair.Value)) + if (R.Count == 0 || + ed < fd) { // R ← R ⋃ e R.Add(e, ed); @@ -248,37 +227,20 @@ namespace ZeroLevel.HNSW if (keepPrunedConnections) { // while │Wd│> 0 and │R│< M - while (Wd.Count > 0 && R.Count < _options.M) + while (Wd.Count > 0 && R.Count < M) { // R ← R ⋃ extract nearest element from Wd to q - var nearest = Wd.Aggregate((l, r) => l.Value < r.Value ? l : r); - Wd.Remove(nearest.Key); - R.Add(nearest.Key, nearest.Value); + var nearest = popNearestDiscarded(); + R.Add(nearest.Item1, nearest.Item2); } } } // return R return R; } - - #endregion - - private (int, float) GetMinimalDistanceIndex(IList self, int q) - { - float min = _options.Distance(_vectors[self[0]], _vectors[q]); - int minIndex = 0; - for (int i = 1; i < self.Count; ++i) - { - var dist = _options.Distance(_vectors[self[i]], _vectors[q]); - if (dist < min) - { - min = self[i]; - minIndex = i; - } - } - return (minIndex, min); - } + + private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2); } -} +} \ No newline at end of file diff --git a/ZeroLevel.HNSW/Services/VectorSet.cs b/ZeroLevel.HNSW/Services/VectorSet.cs index fd5d38f..227a2c8 100644 --- a/ZeroLevel.HNSW/Services/VectorSet.cs +++ b/ZeroLevel.HNSW/Services/VectorSet.cs @@ -1,15 +1,16 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Threading; namespace ZeroLevel.HNSW { public class VectorSet { - public IList _set = new List(); + private List _set = new List(); + private SpinLock _lock = new SpinLock(); public T this[int index] => _set[index]; - - SpinLock _lock = new SpinLock(); + public int Count => _set.Count; public int Append(T vector) { @@ -27,5 +28,30 @@ namespace ZeroLevel.HNSW if (gotLock) _lock.Exit(); } } + + public int[] Append(IEnumerable vectors) + { + bool gotLock = false; + int startIndex, endIndex; + gotLock = false; + try + { + _lock.Enter(ref gotLock); + startIndex = _set.Count; + _set.AddRange(vectors); + endIndex = _set.Count; + } + finally + { + // Only give up the lock if you actually acquired it + if (gotLock) _lock.Exit(); + } + var ids = new int[endIndex - startIndex]; + for (int i = startIndex, j = 0; i < endIndex; i++, j++) + { + ids[j] = i; + } + return ids; + } } } diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index d177a69..f93449c 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -1,9 +1,38 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; namespace ZeroLevel.HNSW { + public class ProbabilityLayerNumberGenerator + { + private const float DIVIDER = 4.362f; + private readonly float[] _probabilities; + + public ProbabilityLayerNumberGenerator(int maxLayers, int M) + { + _probabilities = new float[maxLayers]; + var probability = 1.0f / DIVIDER; + for (int i = 0; i < maxLayers; i++) + { + _probabilities[i] = probability; + probability /= DIVIDER; + } + } + + public int GetRandomLayer() + { + var probability = DefaultRandomGenerator.Instance.NextFloat(); + for (int i = 0; i < _probabilities.Length; i++) + { + if (probability > _probabilities[i]) + return i; + } + return 0; + } + } + public class SmallWorld { private readonly NSWOptions _options; @@ -12,12 +41,17 @@ namespace ZeroLevel.HNSW private Layer EnterPointsLayer => _layers[_layers.Length - 1]; private Layer LastLayer => _layers[0]; + private int EntryPoint = -1; + private int MaxLayer = -1; + private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator; + private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim(); public SmallWorld(NSWOptions options) { _options = options; _vectors = new VectorSet(); _layers = new Layer[_options.LayersCount]; + _layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M); for (int i = 0; i < _options.LayersCount; i++) { _layers[i] = new Layer(_options, _vectors); @@ -31,93 +65,174 @@ namespace ZeroLevel.HNSW public int[] AddItems(IEnumerable vectors) { - var insert = vectors.ToArray(); - var ids = new int[insert.Length]; - for (int i = 0; i < insert.Length; i++) + _lockGraph.EnterWriteLock(); + try { - var item = insert[i]; - ids[i] = Insert(item); + var ids = _vectors.Append(vectors); + for (int i = 0; i < ids.Length; i++) + { + INSERT(ids[i]); + } + return ids; + } + finally + { + _lockGraph.ExitWriteLock(); } - return ids; } - public int Insert(TItem item) + public void TestLevelGenerator() { - var id = _vectors.Append(item); - INSERT(id); - return id; + var levels = new Dictionary(); + for (int i = 0; i < 10000; i++) + { + var level = _layerLevelGenerator.GetRandomLayer(); + if (levels.ContainsKey(level) == false) + { + levels.Add(level, 1); + } + else + { + levels[level] += 1.0f; + } + } + foreach (var pair in levels.OrderBy(l => l.Key)) + { + Console.WriteLine($"[{pair.Key}]: {pair.Value / 100.0f}% ({pair.Value})"); + } } #region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf /// /// Algorithm 1 /// - /// new element public void INSERT(int q) { + var distance = new Func(candidate => _options.Distance(_vectors[q], _vectors[candidate])); + // W ← ∅ // list for the currently found nearest elements - IDictionary W; + IDictionary W = new Dictionary(); // ep ← get enter point for hnsw - var ep = EnterPointsLayer.GetEntryPointFor(q); + var ep = EntryPoint == -1 ? 0 : EntryPoint; + var epDist = 0.0f; // L ← level of ep // top layer for hnsw - var L = _layers.Length - 1; + var L = MaxLayer; // l ← ⌊-ln(unif(0..1))∙mL⌋ // new element’s level - int l = DefaultRandomGenerator.Instance.Next(0, _options.LayersCount - 1); + int l = _layerLevelGenerator.GetRandomLayer(); + if (L == -1) + { + L = l; + MaxLayer = l; + } // for lc ← L … l+1 - for (int lc = L; lc > l; lc--) + // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа + for (int lc = L; lc > l; --lc) { // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - W = _layers[lc].SEARCH_LAYER(q, ep, 1); + _layers[lc].RunKnnAtLayer(ep, distance, W, 1); // ep ← get the nearest element from W to q - ep = W.OrderBy(p => p.Value).First().Key; + var nearest = W.OrderBy(p => p.Value).First(); + ep = nearest.Key; + epDist = nearest.Value; + W.Clear(); } //for lc ← min(L, l) … 0 - for (int lc = Math.Min(L, l); lc >= 0; lc--) + // connecting new node to the small world + for (int lc = Math.Min(L, l); lc >= 0; --lc) { // W ← SEARCH - LAYER(q, ep, efConstruction, lc) - W = _layers[lc].SEARCH_LAYER(q, ep, _options.EFConstruction); + _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 - var neighbors = _layers[lc].SELECT_NEIGHBORS_SIMPLE(q, W); + var neighbors = SelectBestForConnecting(lc, distance, W);; // add bidirectionall connectionts from neighbors to q at layer lc // for each e ∈ neighbors // shrink connections if needed foreach (var e in neighbors) { // eConn ← neighbourhood(e) at layer lc _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value); + // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer + if (e.Value < epDist) + { + ep = e.Key; + epDist = e.Value; + } } // ep ← W ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); + } + // if l > L + if (l > L) + { + // set enter point for hnsw to q + L = l; + MaxLayer = l; + EntryPoint = ep; } - // if l > L - // set enter point for hnsw to q + } + + /// + /// Get maximum allowed connections for the given level. + /// + /// + /// Article: Section 4.1: + /// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also + /// has a strong influence on the search performance, especially in case of high quality(high recall) search. + /// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors + /// selection heuristic is not used) leads to a very strong performance penalty at high recall. + /// Simulations also suggest that 2∙M is a good choice for Mmax0; + /// setting the parameter higher leads to performance degradation and excessive memory usage." + /// + /// The level of the layer. + /// The maximum number of connections. + internal int GetM(int layer) + { + return layer == 0 ? 2 * _options.M : _options.M; + } + + private IDictionary SelectBestForConnecting(int layer, Func distance, IDictionary candidates) + { + return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer)); } /// /// Algorithm 5 /// - /// query element - /// number of nearest neighbors to return - /// : K nearest elements to q - public IList K_NN_SEARCH(int q, int K) + internal IEnumerable<(int, float)> KNearest(TItem q, int k) { - // W ← ∅ // set for the current nearest elements - IDictionary W; - // ep ← get enter point for hnsw - var ep = EnterPointsLayer.GetEntryPointFor(q); - // L ← level of ep // top layer for hnsw - var L = _options.LayersCount - 1; - // for lc ← L … 1 - for (var lc = L; lc > 0; lc--) + _lockGraph.EnterReadLock(); + try { - // W ← SEARCH-LAYER(q, ep, ef = 1, lc) - W = _layers[lc].SEARCH_LAYER(q, ep, 1); - // ep ← get nearest element from W to q - ep = W.OrderBy(p => p.Value).First().Key; + if (_vectors.Count == 0) + { + return Enumerable.Empty<(int, float)>(); + } + var distance = new Func(candidate => _options.Distance(q, _vectors[candidate])); + + // W ← ∅ // set for the current nearest elements + var W = new Dictionary(k + 1); + // ep ← get enter point for hnsw + var ep = EntryPoint; + // L ← level of ep // top layer for hnsw + var L = MaxLayer; + // for lc ← L … 1 + for (int layer = L; layer > 0; --layer) + { + // W ← SEARCH-LAYER(q, ep, ef = 1, lc) + _layers[layer].RunKnnAtLayer(ep, distance, W, 1); + // ep ← get nearest element from W to q + ep = W.OrderBy(p => p.Value).First().Key; + W.Clear(); + } + // W ← SEARCH-LAYER(q, ep, ef, lc =0) + _layers[0].RunKnnAtLayer(ep, distance, W, k); + // return K nearest elements from W to q + return W.Select(p => (p.Key, p.Value)); + } + finally + { + _lockGraph.ExitReadLock(); } - // W ← SEARCH-LAYER(q, ep, ef, lc =0) - W = LastLayer.SEARCH_LAYER(q, ep, _options.EF); - // return K nearest elements from W to q - return W.OrderBy(p => p.Value).Take(K).Select(p => p.Key).ToList(); } #endregion }