HNSW. Append inactive nodes

Active and inactive nodes placed in SearchContext
pull/1/head
unknown 3 years ago
parent f8b72e38e3
commit c128396ac1

@ -99,7 +99,7 @@ namespace HNSWDemo
static void Main(string[] args) static void Main(string[] args)
{ {
TransformToCompactWorldTestWithAccuracity(); FilterTest();
Console.ReadKey(); Console.ReadKey();
} }
@ -362,8 +362,8 @@ namespace HNSWDemo
static void FilterTest() static void FilterTest()
{ {
var count = 5000; var count = 1000;
var testCount = 1000; var testCount = 100;
var dimensionality = 128; var dimensionality = 128;
var samples = Person.GenerateRandom(dimensionality, count); var samples = Person.GenerateRandom(dimensionality, count);
@ -379,13 +379,13 @@ namespace HNSWDemo
int K = 200; int K = 200;
var vectors = RandomVectors(dimensionality, testCount); var vectors = RandomVectors(dimensionality, testCount);
var activeNodes = _database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key).ToHashSet(); var context = new SearchContext().SetActiveNodes(_database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key));
var hits = 0; var hits = 0;
var miss = 0; var miss = 0;
foreach (var v in vectors) foreach (var v in vectors)
{ {
var result = world.Search(v, K, activeNodes); var result = world.Search(v, K, context);
foreach (var r in result) foreach (var r in result)
{ {
var record = _database[r.Item1]; var record = _database[r.Item1];

@ -0,0 +1,81 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
namespace ZeroLevel.HNSW
{
public sealed class SearchContext
{
enum Mode
{
None,
ActiveCheck,
InactiveCheck,
ActiveInactiveCheck
}
private HashSet<int> _activeNodes;
private HashSet<int> _inactiveNodes;
private Mode _mode;
public SearchContext()
{
_mode = Mode.None;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal bool IsActiveNode(int nodeId)
{
switch (_mode)
{
case Mode.ActiveCheck: return _activeNodes.Contains(nodeId);
case Mode.InactiveCheck: return _inactiveNodes.Contains(nodeId) == false;
case Mode.ActiveInactiveCheck: return _inactiveNodes.Contains(nodeId) == false && _activeNodes.Contains(nodeId);
}
return nodeId >= 0;
}
public SearchContext SetActiveNodes(IEnumerable<int> activeNodes)
{
if (activeNodes != null && activeNodes.Any())
{
if (_mode == Mode.ActiveCheck || _mode == Mode.ActiveInactiveCheck)
{
throw new InvalidOperationException("Active nodes are already defined");
}
_activeNodes = new HashSet<int>(activeNodes);
if (_mode == Mode.None)
{
_mode = Mode.ActiveCheck;
}
else if (_mode == Mode.InactiveCheck)
{
_mode = Mode.ActiveInactiveCheck;
}
}
return this;
}
public SearchContext SetInactiveNodes(IEnumerable<int> inactiveNodes)
{
if (inactiveNodes != null && inactiveNodes.Any())
{
if (_mode == Mode.InactiveCheck || _mode == Mode.ActiveInactiveCheck)
{
throw new InvalidOperationException("Inctive nodes are already defined");
}
_inactiveNodes = new HashSet<int>(inactiveNodes);
if (_mode == Mode.None)
{
_mode = Mode.InactiveCheck;
}
else if (_mode == Mode.ActiveCheck)
{
_mode = Mode.ActiveInactiveCheck;
}
}
return this;
}
}
}

@ -37,9 +37,9 @@ namespace ZeroLevel.HNSW
} }
} }
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes) public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context)
{ {
if (activeNodes == null) if (context == null)
{ {
foreach (var pair in KNearest(vector, k)) foreach (var pair in KNearest(vector, k))
{ {
@ -48,7 +48,7 @@ namespace ZeroLevel.HNSW
} }
else else
{ {
foreach (var pair in KNearest(vector, k, activeNodes)) foreach (var pair in KNearest(vector, k, context))
{ {
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
} }
@ -87,7 +87,7 @@ namespace ZeroLevel.HNSW
return W.Select(p => (p.Key, p.Value)); return W.Select(p => (p.Key, p.Value));
} }
private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes) private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context)
{ {
if (_vectors.Count == 0) if (_vectors.Count == 0)
{ {
@ -111,7 +111,7 @@ namespace ZeroLevel.HNSW
W.Clear(); W.Clear();
} }
// W ← SEARCH-LAYER(q, ep, ef, lc =0) // W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); _layers[0].KNearestAtLayer(ep, distance, W, k, context);
// return K nearest elements from W to q // return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value)); return W.Select(p => (p.Key, p.Value));
} }
@ -143,7 +143,7 @@ namespace ZeroLevel.HNSW
_layers = new ReadOnlyLayer<TItem>[countLayers]; _layers = new ReadOnlyLayer<TItem>[countLayers];
for (int i = 0; i < countLayers; i++) for (int i = 0; i < countLayers; i++)
{ {
_layers[i] = new ReadOnlyLayer<TItem>(_options, _vectors); _layers[i] = new ReadOnlyLayer<TItem>(_vectors);
_layers[i].Deserialize(reader); _layers[i].Deserialize(reader);
} }
} }

@ -166,7 +166,7 @@ namespace ZeroLevel.HNSW
/// <param name="q">query element</param> /// <param name="q">query element</param>
/// <param name="ep">enter points ep</param> /// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns> /// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, HashSet<int> activeNodes) internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, SearchContext context)
{ {
/* /*
* v ep // set of visited elements * v ep // set of visited elements
@ -195,7 +195,7 @@ namespace ZeroLevel.HNSW
var C = new Dictionary<int, float>(); var C = new Dictionary<int, float>();
C.Add(entryPointId, targetCosts(entryPointId)); C.Add(entryPointId, targetCosts(entryPointId));
// W ← ep // dynamic list of found nearest neighbors // W ← ep // dynamic list of found nearest neighbors
if (activeNodes.Contains(entryPointId)) if (context.IsActiveNode(entryPointId))
{ {
W.Add(entryPointId, C[entryPointId]); W.Add(entryPointId, C[entryPointId]);
} }
@ -225,7 +225,7 @@ namespace ZeroLevel.HNSW
{ {
// enqueue perspective neighbours to expansion list // enqueue perspective neighbours to expansion list
var neighbourDistance = targetCosts(neighbourId); var neighbourDistance = targetCosts(neighbourId);
if (activeNodes.Contains(neighbourId)) if (context.IsActiveNode(neighbourId))
{ {
if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance()))
{ {

@ -11,18 +11,15 @@ namespace ZeroLevel.HNSW
internal sealed class ReadOnlyLayer<TItem> internal sealed class ReadOnlyLayer<TItem>
: IBinarySerializable : IBinarySerializable
{ {
private readonly NSWReadOnlyOption<TItem> _options;
private readonly ReadOnlyVectorSet<TItem> _vectors; private readonly ReadOnlyVectorSet<TItem> _vectors;
private readonly ReadOnlyCompactBiDirectionalLinksSet _links; private readonly ReadOnlyCompactBiDirectionalLinksSet _links;
/// <summary> /// <summary>
/// HNSW layer /// HNSW layer
/// </summary> /// </summary>
/// <param name="options">HNSW graph options</param>
/// <param name="vectors">General vector set</param> /// <param name="vectors">General vector set</param>
internal ReadOnlyLayer(NSWReadOnlyOption<TItem> options, ReadOnlyVectorSet<TItem> vectors) internal ReadOnlyLayer(ReadOnlyVectorSet<TItem> vectors)
{ {
_options = options;
_vectors = vectors; _vectors = vectors;
_links = new ReadOnlyCompactBiDirectionalLinksSet(); _links = new ReadOnlyCompactBiDirectionalLinksSet();
} }
@ -114,7 +111,7 @@ namespace ZeroLevel.HNSW
/// <param name="q">query element</param> /// <param name="q">query element</param>
/// <param name="ep">enter points ep</param> /// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns> /// <returns>Output: ef closest neighbors to q</returns>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, HashSet<int> activeNodes) internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, SearchContext context)
{ {
/* /*
* v ep // set of visited elements * v ep // set of visited elements
@ -143,7 +140,7 @@ namespace ZeroLevel.HNSW
var C = new Dictionary<int, float>(); var C = new Dictionary<int, float>();
C.Add(entryPointId, targetCosts(entryPointId)); C.Add(entryPointId, targetCosts(entryPointId));
// W ← ep // dynamic list of found nearest neighbors // W ← ep // dynamic list of found nearest neighbors
if (activeNodes.Contains(entryPointId)) if (context.IsActiveNode(entryPointId))
{ {
W.Add(entryPointId, C[entryPointId]); W.Add(entryPointId, C[entryPointId]);
} }
@ -173,7 +170,7 @@ namespace ZeroLevel.HNSW
{ {
// enqueue perspective neighbours to expansion list // enqueue perspective neighbours to expansion list
var neighbourDistance = targetCosts(neighbourId); var neighbourDistance = targetCosts(neighbourId);
if (activeNodes.Contains(neighbourId)) if (context.IsActiveNode(neighbourId))
{ {
if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance())) if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance()))
{ {
@ -195,110 +192,6 @@ namespace ZeroLevel.HNSW
C.Clear(); C.Clear();
v.Clear(); v.Clear();
} }
/// <summary>
/// Algorithm 3
/// </summary>
internal IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{
var bestN = M;
var W = new Dictionary<int, float>(candidates);
if (W.Count > bestN)
{
var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
while (W.Count > bestN)
{
popFarther();
}
}
// return M nearest elements from C to q
return W;
}
/// <summary>
/// Algorithm 4
/// </summary>
/// <param name="q">base element</param>
/// <param name="C">candidate elements</param>
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns>
internal IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{
// R ← ∅
var R = new Dictionary<int, float>();
// W ← C // working queue for the candidates
var W = new Dictionary<int, float>(candidates);
// if extendCandidates // extend candidates by their neighbors
if (_options.ExpandBestSelection)
{
var extendBuffer = new HashSet<int>();
// for each e ∈ C
foreach (var e in W)
{
var neighbors = GetNeighbors(e.Key);
// for each e_adj ∈ neighbourhood(e) at layer lc
foreach (var e_adj in neighbors)
{
// if eadj ∉ W
if (extendBuffer.Contains(e_adj) == false)
{
extendBuffer.Add(e_adj);
}
}
}
// W ← W eadj
foreach (var id in extendBuffer)
{
W[id] = distance(id);
}
}
// Wd ← ∅ // queue for the discarded candidates
var Wd = new Dictionary<int, float>();
var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); });
// while │W│ > 0 and │R│< M
while (W.Count > 0 && R.Count < M)
{
// e ← extract nearest element from W to q
var (e, ed) = popCandidate();
var (fe, fd) = fartherFromResult();
// if e is closer to q compared to any element from R
if (R.Count == 0 ||
ed < fd)
{
// R ← R e
R.Add(e, ed);
}
else
{
// Wd ← Wd e
Wd.Add(e, ed);
}
}
// if keepPrunedConnections // add some of the discarded // connections from Wd
if (_options.KeepPrunedConnections)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{
// R ← R extract nearest element from Wd to q
var nearest = popNearestDiscarded();
R[nearest.Item1] = nearest.Item2;
}
}
// return R
return R;
}
#endregion #endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id); private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id);

@ -51,9 +51,9 @@ namespace ZeroLevel.HNSW
} }
} }
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes) public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, SearchContext context)
{ {
if (activeNodes == null) if (context == null)
{ {
foreach (var pair in KNearest(vector, k)) foreach (var pair in KNearest(vector, k))
{ {
@ -62,7 +62,7 @@ namespace ZeroLevel.HNSW
} }
else else
{ {
foreach (var pair in KNearest(vector, k, activeNodes)) foreach (var pair in KNearest(vector, k, context))
{ {
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
} }
@ -236,7 +236,7 @@ namespace ZeroLevel.HNSW
_lockGraph.ExitReadLock(); _lockGraph.ExitReadLock();
} }
} }
private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes) private IEnumerable<(int, float)> KNearest(TItem q, int k, SearchContext context)
{ {
_lockGraph.EnterReadLock(); _lockGraph.EnterReadLock();
try try
@ -263,7 +263,7 @@ namespace ZeroLevel.HNSW
W.Clear(); W.Clear();
} }
// W ← SEARCH-LAYER(q, ep, ef, lc =0) // W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes); _layers[0].KNearestAtLayer(ep, distance, W, k, context);
// return K nearest elements from W to q // return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value)); return W.Select(p => (p.Key, p.Value));
} }

Loading…
Cancel
Save

Powered by TurnKey Linux.