Added functionality of active nodes to HNSW

Append HNSW tests in HNSWDemo project
pull/1/head
unknown 3 years ago
parent db076efbb3
commit 4132147657

@ -98,33 +98,78 @@ namespace HNSWDemo
static void Main(string[] args) static void Main(string[] args)
{ {
var dimensionality = 128; FilterTest();
var testCount = 1000; Console.ReadKey();
}
static void FilterTest()
{
var count = 5000; var count = 5000;
var testCount = 1000;
var dimensionality = 128;
var samples = Person.GenerateRandom(dimensionality, count); var samples = Person.GenerateRandom(dimensionality, count);
var sw = new Stopwatch();
var test = new VectorsDirectCompare(samples.Select(s => s.Item1).ToList(), CosineDistance.ForUnits);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple)); var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
var batch = samples.ToArray(); var ids = world.AddItems(samples.Select(i => i.Item1).ToArray());
for (int bi = 0; bi < samples.Count; bi++)
var ids = world.AddItems(batch.Select(i => i.Item1).ToArray());
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
for (int bi = 0; bi < batch.Length; bi++)
{ {
_database.Add(ids[bi], batch[bi].Item2); _database.Add(ids[bi], samples[bi].Item2);
} }
Console.WriteLine("Start test"); Console.WriteLine("Start test");
int K = 200; int K = 200;
var vectors = RandomVectors(dimensionality, testCount); var vectors = RandomVectors(dimensionality, testCount);
var activeNodes = _database.Where(pair => pair.Value.Age > 20 && pair.Value.Age < 50 && pair.Value.Gender == Gender.Feemale).Select(pair => pair.Key).ToHashSet();
var hits = 0;
var miss = 0;
foreach (var v in vectors)
{
var result = world.Search(v, K, activeNodes);
foreach (var r in result)
{
var record = _database[r.Item1];
if (record.Gender == Gender.Feemale && record.Age > 20 && record.Age < 50)
{
hits++;
}
else
{
miss++;
}
}
}
Console.WriteLine($"SUCCESS: {hits}");
Console.WriteLine($"ERROR: {miss}");
}
static void AccuracityTest()
{
int K = 200;
var count = 5000;
var testCount = 1000;
var dimensionality = 128;
var totalHits = new List<int>(); var totalHits = new List<int>();
var timewatchesHNSW = new List<float>();
var timewatchesNP = new List<float>(); var timewatchesNP = new List<float>();
foreach (var v in vectors) var timewatchesHNSW = new List<float>();
var samples = RandomVectors(dimensionality, count);
var sw = new Stopwatch();
var test = new VectorsDirectCompare(samples, CosineDistance.ForUnits);
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 15, 200, 200, CosineDistance.ForUnits, true, true, selectionHeuristic: NeighbourSelectionHeuristic.SelectSimple));
sw.Start();
var ids = world.AddItems(samples.ToArray());
sw.Stop();
Console.WriteLine($"Insert {ids.Length} items on {sw.ElapsedMilliseconds} ms");
Console.WriteLine("Start test");
var test_vectors = RandomVectors(dimensionality, testCount);
foreach (var v in test_vectors)
{ {
sw.Restart(); sw.Restart();
var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2); var gt = test.KNearest(v, K).ToDictionary(p => p.Item1, p => p.Item2);
@ -156,26 +201,6 @@ namespace HNSWDemo
Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms"); Console.WriteLine($"MIN NP TIME: {timewatchesNP.Min()} ms");
Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms"); Console.WriteLine($"AVG NP TIME: {timewatchesNP.Average()} ms");
Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms"); Console.WriteLine($"MAX NP TIME: {timewatchesNP.Max()} ms");
//HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; }));
/*var fackupCount = 0;
foreach (var v in vectors)
{
var result = world.Search(v, 10, filter);
foreach (var r in result)
{
if (_database[r.Item1].Age <= 45 || _database[r.Item1].Gender != Gender.Feemale)
{
Interlocked.Increment(ref fackupCount);
}
}
}*/
//Console.WriteLine($"Completed. Fackup count: {fackupCount}");
Console.ReadKey();
} }
} }
} }

@ -11,20 +11,33 @@ namespace ZeroLevel.HNSW
{ {
private readonly NSWOptions<TItem> _options; private readonly NSWOptions<TItem> _options;
private readonly VectorSet<TItem> _vectors; private readonly VectorSet<TItem> _vectors;
private CompactBiDirectionalLinksSet _links = new CompactBiDirectionalLinksSet(); private readonly CompactBiDirectionalLinksSet _links;
/// <summary> /// <summary>
/// Count nodes at layer /// There are links in the layer
/// </summary> /// </summary>
public int CountLinks => (_links.Count); internal bool HasLinks => (_links.Count > 0);
public Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors) /// <summary>
/// HNSW layer
/// </summary>
/// <param name="options">HNSW graph options</param>
/// <param name="vectors">General vector set</param>
internal Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors)
{ {
_options = options; _options = options;
_vectors = vectors; _vectors = vectors;
_links = new CompactBiDirectionalLinksSet();
} }
public void AddBidirectionallConnectionts(int q, int p, float qpDistance, bool isMapLayer) /// <summary>
/// Adding new bidirectional link
/// </summary>
/// <param name="q">New node</param>
/// <param name="p">The node with which the connection will be made</param>
/// <param name="qpDistance"></param>
/// <param name="isMapLayer"></param>
internal void AddBidirectionallConnections(int q, int p, float qpDistance, bool isMapLayer)
{ {
// breadth-first search for the nodes nearest to the found one // breadth-first search for the nodes nearest to the found one
var nearest = _links.FindLinksForId(p).ToArray(); var nearest = _links.FindLinksForId(p).ToArray();
@ -55,12 +68,15 @@ namespace ZeroLevel.HNSW
} }
} }
public void Append(int q) /// <summary>
/// Adding a node with a connection to itself
/// </summary>
/// <param name="q"></param>
internal void Append(int q)
{ {
_links.Add(q, q, 0); _links.Add(q, q, 0);
} }
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 2 /// Algorithm 2
@ -68,7 +84,7 @@ namespace ZeroLevel.HNSW
/// <param name="q">query element</param> /// <param name="q">query element</param>
/// <param name="ep">enter points ep</param> /// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns> /// <returns>Output: ef closest neighbors to q</returns>
public void RunKnnAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef) internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef)
{ {
/* /*
* v ep // set of visited elements * v ep // set of visited elements
@ -90,7 +106,6 @@ namespace ZeroLevel.HNSW
* remove furthest element from W to q * remove furthest element from W to q
* return W * return W
*/ */
var v = new VisitedBitSet(_vectors.Count, _options.M); var v = new VisitedBitSet(_vectors.Count, _options.M);
// v ← ep // set of visited elements // v ← ep // set of visited elements
v.Add(entryPointId); v.Add(entryPointId);
@ -143,10 +158,98 @@ namespace ZeroLevel.HNSW
v.Clear(); v.Clear();
} }
/// <summary>
/// Algorithm 2 (SEARCH-LAYER) with filtering of the result by a set of active nodes.
/// Walks the layer starting from <paramref name="entryPointId"/> and puts into <paramref name="W"/>
/// only nodes whose identifiers are contained in <paramref name="activeNodes"/>.
/// </summary>
/// <param name="entryPointId">Identifier of the node the traversal starts from</param>
/// <param name="targetCosts">Returns the distance from the query element to the node with the given identifier</param>
/// <param name="W">Output: found nodes (identifier -> distance to the query element), filled in place</param>
/// <param name="ef">Size limit of <paramref name="W"/>; the farthest element is removed when the limit is exceeded</param>
/// <param name="activeNodes">Identifiers of nodes allowed in the result; when null, no filtering is applied and the unfiltered overload is used</param>
internal void KNearestAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef, HashSet<int> activeNodes)
{
    // Callers pass null when no filter was requested (the argument is optional upstream).
    // Without this guard the Contains calls below throw NullReferenceException.
    if (activeNodes == null)
    {
        KNearestAtLayer(entryPointId, targetCosts, W, ef);
        return;
    }
    /*
     * Base scheme from the paper:
     * v ← ep // set of visited elements
     * C ← ep // set of candidates
     * W ← ep // dynamic list of found nearest neighbors
     * while │C│ > 0
     *   c ← extract nearest element from C to q
     *   f ← get furthest element from W to q
     *   if distance(c, q) > distance(f, q)
     *     break // all elements in W are evaluated
     *   for each e ∈ neighbourhood(c) at layer lc // update C and W
     *     if e ∉ v
     *       v ← v ⋃ e
     *       f ← get furthest element from W to q
     *       if distance(e, q) < distance(f, q) or │W│ < ef
     *         C ← C ⋃ e
     *         W ← W ⋃ e
     *         if │W│ > ef
     *           remove furthest element from W to q
     * return W
     *
     * Differences of this filtered variant:
     *  - an element is put into W only if it is contained in activeNodes;
     *  - a neighbour is queued into C (active or not) only while W holds fewer than ef elements.
     */
    var v = new VisitedBitSet(_vectors.Count, _options.M);
    // v ← ep // set of visited elements
    v.Add(entryPointId);
    // C ← ep // set of candidates
    var C = new Dictionary<int, float>();
    C.Add(entryPointId, targetCosts(entryPointId));
    // W ← ep // dynamic list of found nearest neighbors (only if the entry point passes the filter)
    if (activeNodes.Contains(entryPointId))
    {
        W.Add(entryPointId, C[entryPointId]);
    }

    // removes and returns the candidate closest to the query
    var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); });
    // distance of the farthest element currently in W; must be called only when W is not empty
    var farthestDistance = new Func<float>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return pair.Value; });
    // removes the farthest element from W
    var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });

    // run bfs
    while (C.Count > 0)
    {
        // get next candidate to check and expand
        var toExpand = popCandidate();
        // W may still be empty here when no active node has been reached yet
        if (W.Count > 0)
        {
            if (toExpand.Item2 > farthestDistance())
            {
                // the closest candidate is farther than farthest result
                break;
            }
        }

        // expand candidate
        var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
        for (int i = 0; i < neighboursIds.Length; ++i)
        {
            int neighbourId = neighboursIds[i];
            if (!v.Contains(neighbourId))
            {
                var neighbourDistance = targetCosts(neighbourId);
                // only active nodes may get into the result
                if (activeNodes.Contains(neighbourId))
                {
                    if (W.Count < ef || (W.Count > 0 && neighbourDistance < farthestDistance()))
                    {
                        W.Add(neighbourId, neighbourDistance);
                        if (W.Count > ef)
                        {
                            fartherPopFromResult();
                        }
                    }
                }
                // enqueue perspective neighbours to expansion list while the result is not full;
                // inactive nodes are queued too, so the traversal can pass through them
                if (W.Count < ef)
                {
                    C.Add(neighbourId, neighbourDistance);
                }
                v.Add(neighbourId);
            }
        }
    }
    C.Clear();
    v.Clear();
}
/// <summary> /// <summary>
/// Algorithm 3 /// Algorithm 3
/// </summary> /// </summary>
public IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(Func<int, float> distance, IDictionary<int, float> candidates, int M) internal IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{ {
var bestN = M; var bestN = M;
var W = new Dictionary<int, float>(candidates); var W = new Dictionary<int, float>(candidates);
@ -172,7 +275,7 @@ namespace ZeroLevel.HNSW
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param> /// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param> /// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns> /// <returns>Output: M elements selected by the heuristic</returns>
public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M) internal IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{ {
// R ← ∅ // R ← ∅
var R = new Dictionary<int, float>(); var R = new Dictionary<int, float>();
@ -248,7 +351,6 @@ namespace ZeroLevel.HNSW
} }
#endregion #endregion
private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2); private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2);
} }
} }

@ -55,14 +55,26 @@ namespace ZeroLevel.HNSW
} }
} }
/// <summary>
/// Search the graph for the K vectors closest to a given vector
/// </summary>
/// <param name="vector">Given vector</param>
/// <param name="k">Count of elements for search</param>
/// <param name="activeNodes"></param>
/// <returns></returns>
public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes = null) public IEnumerable<(int, TItem, float)> Search(TItem vector, int k, HashSet<int> activeNodes = null)
{ {
foreach (var pair in KNearest(vector, k)) foreach (var pair in KNearest(vector, k, activeNodes))
{ {
yield return (pair.Item1, _vectors[pair.Item1], pair.Item2); yield return (pair.Item1, _vectors[pair.Item1], pair.Item2);
} }
} }
/// <summary>
/// Adding vectors batch
/// </summary>
/// <param name="vectors">Vectors</param>
/// <returns>Vector identifiers in a graph</returns>
public int[] AddItems(IEnumerable<TItem> vectors) public int[] AddItems(IEnumerable<TItem> vectors)
{ {
_lockGraph.EnterWriteLock(); _lockGraph.EnterWriteLock();
@ -81,32 +93,11 @@ namespace ZeroLevel.HNSW
} }
} }
/// <summary>
/// Diagnostic helper: draws 10000 layer numbers from the layer level generator
/// and prints, ordered by layer number, how often each layer was produced
/// (as a percentage and as an absolute count).
/// </summary>
public void TestLevelGenerator()
{
    // layer number -> number of times the generator returned it
    var histogram = new Dictionary<int, float>();
    int remaining = 10000;
    while (remaining-- > 0)
    {
        var layer = _layerLevelGenerator.GetRandomLayer();
        histogram[layer] = histogram.TryGetValue(layer, out var seen) ? seen + 1.0f : 1;
    }
    // 10000 draws, so count / 100 is the share in percent
    foreach (var entry in histogram.OrderBy(h => h.Key))
    {
        Console.WriteLine($"[{entry.Key}]: {entry.Value / 100.0f}% ({entry.Value})");
    }
}
#region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 1 /// Algorithm 1
/// </summary> /// </summary>
public void INSERT(int q) private void INSERT(int q)
{ {
var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate])); var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate]));
// W ← ∅ // list for the currently found nearest elements // W ← ∅ // list for the currently found nearest elements
@ -122,7 +113,7 @@ namespace ZeroLevel.HNSW
// Descend from the top level down to the level where the element appears, to find the entry point // Descend from the top level down to the level where the element appears, to find the entry point
for (int lc = L; lc > l; --lc) for (int lc = L; lc > l; --lc)
{ {
if (_layers[lc].CountLinks == 0) if (_layers[lc].HasLinks == false)
{ {
_layers[lc].Append(q); _layers[lc].Append(q);
ep = q; ep = q;
@ -130,7 +121,7 @@ namespace ZeroLevel.HNSW
else else
{ {
// W ← SEARCH-LAYER(q, ep, ef = 1, lc) // W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[lc].RunKnnAtLayer(ep, distance, W, 1); _layers[lc].KNearestAtLayer(ep, distance, W, 1);
// ep ← get the nearest element from W to q // ep ← get the nearest element from W to q
var nearest = W.OrderBy(p => p.Value).First(); var nearest = W.OrderBy(p => p.Value).First();
ep = nearest.Key; ep = nearest.Key;
@ -142,7 +133,7 @@ namespace ZeroLevel.HNSW
// connecting new node to the small world // connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc) for (int lc = Math.Min(L, l); lc >= 0; --lc)
{ {
if (_layers[lc].CountLinks == 0) if (_layers[lc].HasLinks == false)
{ {
_layers[lc].Append(q); _layers[lc].Append(q);
ep = q; ep = q;
@ -150,7 +141,7 @@ namespace ZeroLevel.HNSW
else else
{ {
// W ← SEARCH - LAYER(q, ep, efConstruction, lc) // W ← SEARCH - LAYER(q, ep, efConstruction, lc)
_layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction); _layers[lc].KNearestAtLayer(ep, distance, W, _options.EFConstruction);
// neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
var neighbors = SelectBestForConnecting(lc, distance, W); var neighbors = SelectBestForConnecting(lc, distance, W);
// add bidirectional connections from neighbors to q at layer lc // add bidirectional connections from neighbors to q at layer lc
@ -158,7 +149,7 @@ namespace ZeroLevel.HNSW
foreach (var e in neighbors) foreach (var e in neighbors)
{ {
// eConn ← neighbourhood(e) at layer lc // eConn ← neighbourhood(e) at layer lc
_layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value, lc == 0); _layers[lc].AddBidirectionallConnections(q, e.Key, e.Value, lc == 0);
// if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer // if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
if (e.Value < epDist) if (e.Value < epDist)
{ {
@ -195,7 +186,7 @@ namespace ZeroLevel.HNSW
/// </remarks> /// </remarks>
/// <param name="layer">The level of the layer.</param> /// <param name="layer">The level of the layer.</param>
/// <returns>The maximum number of connections.</returns> /// <returns>The maximum number of connections.</returns>
internal int GetM(int layer) private int GetM(int layer)
{ {
return layer == 0 ? 2 * _options.M : _options.M; return layer == 0 ? 2 * _options.M : _options.M;
} }
@ -210,7 +201,7 @@ namespace ZeroLevel.HNSW
/// <summary> /// <summary>
/// Algorithm 5 /// Algorithm 5
/// </summary> /// </summary>
internal IEnumerable<(int, float)> KNearest(TItem q, int k) private IEnumerable<(int, float)> KNearest(TItem q, int k, HashSet<int> activeNodes = null)
{ {
_lockGraph.EnterReadLock(); _lockGraph.EnterReadLock();
try try
@ -231,13 +222,13 @@ namespace ZeroLevel.HNSW
for (int layer = L; layer > 0; --layer) for (int layer = L; layer > 0; --layer)
{ {
// W ← SEARCH-LAYER(q, ep, ef = 1, lc) // W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].RunKnnAtLayer(ep, distance, W, 1); _layers[layer].KNearestAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q // ep ← get nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key; ep = W.OrderBy(p => p.Value).First().Key;
W.Clear(); W.Clear();
} }
// W ← SEARCH-LAYER(q, ep, ef, lc =0) // W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].RunKnnAtLayer(ep, distance, W, k); _layers[0].KNearestAtLayer(ep, distance, W, k, activeNodes);
// return K nearest elements from W to q // return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value)); return W.Select(p => (p.Key, p.Value));
} }

Loading…
Cancel
Save

Powered by TurnKey Linux.