diff --git a/TestHNSW/HNSWDemo/Program.cs b/TestHNSW/HNSWDemo/Program.cs index df4d556..090f6a6 100644 --- a/TestHNSW/HNSWDemo/Program.cs +++ b/TestHNSW/HNSWDemo/Program.cs @@ -99,7 +99,7 @@ namespace HNSWDemo static void Main(string[] args) { - TransformToCompactWorldTestWithAccuracity(); + FilterTest(); Console.ReadKey(); } @@ -362,8 +362,8 @@ namespace HNSWDemo static void FilterTest() { - var count = 5000; - var testCount = 1000; + var count = 1000; + var testCount = 100; var dimensionality = 128; var samples = Person.GenerateRandom(dimensionality, count); @@ -379,13 +379,13 @@ namespace HNSWDemo int K = 200; var vectors = RandomVectors(dimensionality, testCount); - var activeNodes = _database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key).ToHashSet(); + var context = new SearchContext().SetActiveNodes(_database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key)); var hits = 0; var miss = 0; foreach (var v in vectors) { - var result = world.Search(v, K, activeNodes); + var result = world.Search(v, K, context); foreach (var r in result) { var record = _database[r.Item1]; diff --git a/ZeroLevel.HNSW/Model/SearchContext.cs b/ZeroLevel.HNSW/Model/SearchContext.cs new file mode 100644 index 0000000..4cfce40 --- /dev/null +++ b/ZeroLevel.HNSW/Model/SearchContext.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; + +namespace ZeroLevel.HNSW +{ + public sealed class SearchContext + { + enum Mode + { + None, + ActiveCheck, + InactiveCheck, + ActiveInactiveCheck + } + + private HashSet _activeNodes; + private HashSet _inactiveNodes; + private Mode _mode; + + public SearchContext() + { + _mode = Mode.None; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal bool IsActiveNode(int nodeId) + { + switch (_mode) + { + case Mode.ActiveCheck: return _activeNodes.Contains(nodeId); + case Mode.InactiveCheck: return _inactiveNodes.Contains(nodeId) == false; + case Mode.ActiveInactiveCheck: return _inactiveNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId); + } + return nodeId >= 0; + } + + public SearchContext SetActiveNodes(IEnumerable activeNodes) + { + if (activeNodes != null && activeNodes.Any()) + { + if (_mode == Mode.ActiveCheck || _mode == Mode.ActiveInactiveCheck) + { + throw new InvalidOperationException("Active nodes are already defined"); + } + _activeNodes = new HashSet(activeNodes); + if (_mode == Mode.None) + { + _mode = Mode.ActiveCheck; + } + else if (_mode == Mode.InactiveCheck) + { + _mode = Mode.ActiveInactiveCheck; + } + } + return this; + } + + public SearchContext SetInactiveNodes(IEnumerable inactiveNodes) + { + if (inactiveNodes != null && inactiveNodes.Any()) + { + if (_mode == Mode.InactiveCheck || _mode == Mode.ActiveInactiveCheck) + { + throw new InvalidOperationException("Inctive nodes are already defined"); + } + _inactiveNodes = new HashSet(inactiveNodes); + if (_mode == Mode.None) + { + _mode = Mode.InactiveCheck; + } + else if (_mode == Mode.ActiveCheck) + { + _mode = Mode.ActiveInactiveCheck; + } + } + return this; + } + } +} diff --git a/ZeroLevel.HNSW/ReadOnlySmallWorld.cs b/ZeroLevel.HNSW/ReadOnlySmallWorld.cs index 7f6929e..8241734 100644 --- a/ZeroLevel.HNSW/ReadOnlySmallWorld.cs +++ b/ZeroLevel.HNSW/ReadOnlySmallWorld.cs @@ -37,9 +37,9 @@ namespace ZeroLevel.HNSW } } - public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes) + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context) { - if (activeNodes == null) + if (context == null) { foreach (var pair in KNearest(vector, k)) { @@ -48,7 +48,7 @@ namespace ZeroLevel.HNSW } else { - foreach (var pair in KNearest(vector, k, activeNodes)) + foreach (var pair in KNearest(vector, k, context)) { yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); } @@ -87,7 +87,7 @@ namespace ZeroLevel.HNSW return W.Select(p => (p.Key, p.Value)); } - private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes) + private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context) { if (_vectors.Count == 0) { @@ -111,7 +111,7 @@ namespace ZeroLevel.HNSW W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); + _layers[0].KNearestAtLayer(ep, distance, W, k, context); // return K nearest elements from W to q return W.Select(p => (p.Key, p.Value)); } @@ -143,7 +143,7 @@ namespace ZeroLevel.HNSW _layers = new ReadOnlyLayer[countLayers]; for (int i = 0; i < countLayers; i++) { - _layers[i] = new ReadOnlyLayer(_options, _vectors); + _layers[i] = new ReadOnlyLayer(_vectors); _layers[i].Deserialize(reader); } } diff --git a/ZeroLevel.HNSW/Services/Layer.cs b/ZeroLevel.HNSW/Services/Layer.cs index 65d90e2..6e5c5ea 100644 --- a/ZeroLevel.HNSW/Services/Layer.cs +++ b/ZeroLevel.HNSW/Services/Layer.cs @@ -166,7 +166,7 @@ namespace ZeroLevel.HNSW /// query element /// enter points ep /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, HashSet activeNodes) + internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, SearchContext context) { /* * v ← ep // set of visited elements @@ -195,7 +195,7 @@ namespace ZeroLevel.HNSW var C = new Dictionary(); C.Add(entryPointId, targetCosts(entryPointId)); // W ← ep // dynamic list of found nearest neighbors - if (activeNodes.Contains(entryPointId)) + if (context.IsActiveNode(entryPointId)) { W.Add(entryPointId, C[entryPointId]); } @@ -225,7 +225,7 @@ namespace ZeroLevel.HNSW { // enqueue perspective neighbours to expansion list var neighbourDistance = targetCosts(neighbourId); - if (activeNodes.Contains(neighbourId)) + if (context.IsActiveNode(neighbourId)) { if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) { diff --git a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs index f35905e..8273cfb 100644 --- a/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs +++ b/ZeroLevel.HNSW/Services/ReadOnly/ReadOnlyLayer.cs @@ -11,18 +11,15 @@ namespace ZeroLevel.HNSW internal sealed class ReadOnlyLayer : IBinarySerializable { - private readonly NSWReadOnlyOption _options; private readonly ReadOnlyVectorSet _vectors; private readonly ReadOnlyCompactBiDirectionalLinksSet _links; /// /// HNSW layer /// - /// HNSW graph options /// General vector set - internal ReadOnlyLayer(NSWReadOnlyOption options, ReadOnlyVectorSet vectors) + internal ReadOnlyLayer(ReadOnlyVectorSet vectors) { - _options = options; _vectors = vectors; _links = new ReadOnlyCompactBiDirectionalLinksSet(); } @@ -114,7 +111,7 @@ namespace ZeroLevel.HNSW /// query element /// enter points ep /// Output: ef closest neighbors to q - internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, HashSet activeNodes) + internal void KNearestAtLayer(int entryPointId, Func targetCosts, IDictionary W, int ef, SearchContext context) { /* * v ← ep // set of visited elements @@ -143,7 +140,7 @@ namespace ZeroLevel.HNSW var C = new Dictionary(); C.Add(entryPointId, targetCosts(entryPointId)); // W ← ep // dynamic list of found nearest neighbors - if (activeNodes.Contains(entryPointId)) + if (context.IsActiveNode(entryPointId)) { W.Add(entryPointId, C[entryPointId]); } @@ -173,7 +170,7 @@ namespace ZeroLevel.HNSW { // enqueue perspective neighbours to expansion list var neighbourDistance = targetCosts(neighbourId); - if (activeNodes.Contains(neighbourId)) + if (context.IsActiveNode(neighbourId)) { if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) { @@ -195,110 +192,6 @@ namespace ZeroLevel.HNSW C.Clear(); v.Clear(); } - - /// - /// Algorithm 3 - /// - internal IDictionary SELECT_NEIGHBORS_SIMPLE(Func distance, IDictionary candidates, int M) - { - var bestN = M; - var W = new Dictionary(candidates); - if (W.Count > bestN) - { - var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); }); - while (W.Count > bestN) - { - popFarther(); - } - } - // return M nearest elements from C to q - return W; - } - - - - /// - /// Algorithm 4 - /// - /// base element - /// candidate elements - /// flag indicating whether or not to extend candidate list - /// flag indicating whether or not to add discarded elements - /// Output: M elements selected by the heuristic - internal IDictionary SELECT_NEIGHBORS_HEURISTIC(Func distance, IDictionary candidates, int M) - { - // R ← ∅ - var R = new Dictionary(); - // W ← C // working queue for the candidates - var W = new Dictionary(candidates); - // if extendCandidates // extend candidates by their neighbors - if (_options.ExpandBestSelection) - { - var extendBuffer = new HashSet(); - // for each e ∈ C - foreach (var e in W) - { - var neighbors = GetNeighbors(e.Key); - // for each e_adj ∈ neighbourhood(e) at layer lc - foreach (var e_adj in neighbors) - { - // if eadj ∉ W - if (extendBuffer.Contains(e_adj) == false) - { - extendBuffer.Add(e_adj); - } - } - } - // W ← W ⋃ eadj - foreach (var id in extendBuffer) - { - W[id] = distance(id); - } - } - - // Wd ← ∅ // queue for the discarded candidates - var Wd = new Dictionary(); - - - var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); }); - var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); }); - var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); }); - - - // while │W│ > 0 and │R│< M - while (W.Count > 0 && R.Count < M) - { - // e ← extract nearest element from W to q - var (e, ed) = popCandidate(); - var (fe, fd) = fartherFromResult(); - - // if e is closer to q compared to any element from R - if (R.Count == 0 || - ed < fd) - { - // R ← R ⋃ e - R.Add(e, ed); - } - else - { - // Wd ← Wd ⋃ e - Wd.Add(e, ed); - } - } - // if keepPrunedConnections // add some of the discarded // connections from Wd - if (_options.KeepPrunedConnections) - { - // while │Wd│> 0 and │R│< M - while (Wd.Count > 0 && R.Count < M) - { - // R ← R ⋃ extract nearest element from Wd to q - var nearest = popNearestDiscarded(); - R[nearest.Item1] = nearest.Item2; - } - } - // return R - return R; - } #endregion private IEnumerable GetNeighbors(int id) => _links.FindLinksForId(id); diff --git a/ZeroLevel.HNSW/SmallWorld.cs b/ZeroLevel.HNSW/SmallWorld.cs index fe51c9e..fa57eef 100644 --- a/ZeroLevel.HNSW/SmallWorld.cs +++ b/ZeroLevel.HNSW/SmallWorld.cs @@ -51,9 +51,9 @@ namespace ZeroLevel.HNSW } } - public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet activeNodes) + public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context) { - if (activeNodes == null) + if (context == null) { foreach (var pair in KNearest(vector, k)) { @@ -62,7 +62,7 @@ namespace ZeroLevel.HNSW } else { - foreach (var pair in KNearest(vector, k, activeNodes)) + foreach (var pair in KNearest(vector, k, context)) { yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); } @@ -236,7 +236,7 @@ namespace ZeroLevel.HNSW _lockGraph.ExitReadLock(); } } - private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet activeNodes) + private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context) { _lockGraph.EnterReadLock(); try @@ -263,7 +263,7 @@ namespace ZeroLevel.HNSW W.Clear(); } // W ← SEARCH-LAYER(q, ep, ef, lc =0) - _layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); + _layers[0].KNearestAtLayer(ep, distance, W, k, context); // return K nearest elements from W to q return W.Select(p => (p.Key, p.Value)); }