Fix algorithm
pull/1/head
unknown 3 years ago
parent 8b5bf38dd5
commit 1f967908b9

@ -84,6 +84,7 @@ namespace HNSWDemo
var sw = new Stopwatch(); var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 4, 120, 120, CosineDistance.ForUnits)); var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 4, 120, 120, CosineDistance.ForUnits));
for (int i = 0; i < (count / batchSize); i++) for (int i = 0; i < (count / batchSize); i++)
{ {
var batch = samples.Skip(i * batchSize).Take(batchSize).ToArray(); var batch = samples.Skip(i * batchSize).Take(batchSize).ToArray();
@ -101,18 +102,18 @@ namespace HNSWDemo
//HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; })); //HNSWFilter filter = new HNSWFilter(ids => ids.Where(id => { var p = _database[id]; return p.Age > 45 && p.Gender == Gender.Feemale; }));
/*var fackupCount = 0; /*var fackupCount = 0;
foreach (var v in vectors) foreach (var v in vectors)
{ {
var result = world.Search(v, 10, filter); var result = world.Search(v, 10, filter);
foreach (var r in result) foreach (var r in result)
{ {
if (_database[r.Item1].Age <= 45 || _database[r.Item1].Gender != Gender.Feemale) if (_database[r.Item1].Age <= 45 || _database[r.Item1].Gender != Gender.Feemale)
{ {
Interlocked.Increment(ref fackupCount); Interlocked.Increment(ref fackupCount);
} }
} }
}*/ }*/
//Console.WriteLine($"Completed. Fackup count: {fackupCount}"); //Console.WriteLine($"Completed. Fackup count: {fackupCount}");
Console.ReadKey(); Console.ReadKey();

@ -13,6 +13,11 @@ namespace ZeroLevel.HNSW
private readonly VectorSet<TItem> _vectors; private readonly VectorSet<TItem> _vectors;
private CompactBiDirectionalLinksSet _links = new CompactBiDirectionalLinksSet(); private CompactBiDirectionalLinksSet _links = new CompactBiDirectionalLinksSet();
/// <summary>
/// Number of nodes at this layer
/// </summary>
public int Count => (_links.Count >> 1);
public Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors) public Layer(NSWOptions<TItem> options, VectorSet<TItem> vectors)
{ {
_options = options; _options = options;
@ -50,148 +55,107 @@ namespace ZeroLevel.HNSW
} }
} }
/// <summary>
/// Chooses a starting node for <paramref name="q"/> on this layer: takes the first endpoint of a
/// randomly selected link and runs the greedy descent (<c>DFS_SearchMinFrom</c>) from it.
/// </summary>
/// <param name="q">Id of the vector an entry point is wanted for.</param>
/// <returns>The node id produced by the greedy descent.</returns>
public int GetEntryPointFor(int q)
{
    // Random position among the stored links.
    // NOTE(review): presumably the upper bound of Next is exclusive, as in System.Random - confirm.
    var linkIndex = DefaultRandomGenerator.Instance.Next(0, _links.Count);
    var startId = _links[linkIndex].Item1;

    // v ← ep // set of visited elements (fresh for every descent)
    var visited = new VisitedBitSet(_vectors._set.Count, _options.M);

    // Only the node id is needed here; the distance returned alongside it is dropped.
    return DFS_SearchMinFrom(startId, q, visited).Item1;
}
/// <summary>
/// Greedy walk over this layer's links: starting at <paramref name="entryId"/>, repeatedly moves to the
/// not-yet-visited neighbour of the current node that is closest to the vector <paramref name="id"/>,
/// until no neighbour is closer or <c>_options.EFConstruction</c> steps have been made.
/// </summary>
/// <param name="entryId">Node the walk starts from.</param>
/// <param name="id">Id of the target vector distances are measured to.</param>
/// <param name="visited">Visited set; the start node and every inspected neighbour are added to it.</param>
/// <returns>The best node found and its distance to the target vector.</returns>
private (int, float) DFS_SearchMinFrom(int entryId, int id, VisitedBitSet visited)
{
    visited.Add(entryId);
    int candidate = entryId;
    var candidateDistance = _options.Distance(_vectors[entryId], _vectors[id]);
    int counter = 0;
    do
    {
        // Expand the CURRENT best node and measure against the target vector.
        // GetMinNearest's parameters are (visited, node to expand, target id, distance to beat);
        // the previous call passed (entryId, candidate), so it kept re-expanding the start node
        // and compared neighbours with the current candidate instead of with the target.
        var (mid, dist) = GetMinNearest(visited, candidate, id, candidateDistance);

        // GetMinNearest hands back the distance it was given when nothing closer exists,
        // so "no improvement" shows up as equality; stop instead of spinning until the step limit.
        if (dist >= candidateDistance)
        {
            break;
        }
        candidate = mid;
        candidateDistance = dist;
        counter++;
    } while (counter < _options.EFConstruction);
    return (candidate, candidateDistance);
}
/// <summary>
/// Scans the links of <paramref name="entryId"/> and returns the unvisited linked node whose vector is
/// closest to the vector <paramref name="id"/>, provided it beats <paramref name="entryDistance"/>.
/// Every unvisited linked node that gets inspected is marked as visited.
/// </summary>
/// <param name="visited">Visited set; consulted and updated during the scan.</param>
/// <param name="entryId">Node whose links are scanned; also the fallback result.</param>
/// <param name="id">Id of the vector distances are measured to.</param>
/// <param name="entryDistance">Distance a linked node has to beat to be selected.</param>
/// <returns>
/// The selected node and its distance, or (<paramref name="entryId"/>, <paramref name="entryDistance"/>)
/// when no unvisited linked node is strictly closer.
/// </returns>
private (int, float) GetMinNearest(VisitedBitSet visited, int entryId, int id, float entryDistance)
{
    var bestId = entryId;
    var bestDistance = entryDistance;

    foreach (var link in _links.FindLinksForId(entryId))
    {
        var neighbour = link.Item2;
        if (visited.Contains(neighbour))
        {
            // Already inspected earlier in this walk.
            continue;
        }

        var neighbourDistance = _options.Distance(_vectors[neighbour], _vectors[id]);
        if (neighbourDistance < bestDistance)
        {
            bestDistance = neighbourDistance;
            bestId = neighbour;
        }

        // Mark as seen whether or not it improved the result.
        visited.Add(neighbour);
    }

    return (bestId, bestDistance);
}
#region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region Implementation of https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 2 /// Algorithm 2
/// </summary> /// </summary>
/// <param name="q">query element</param> /// <param name="q">query element</param>
/// <param name="ep">enter points ep</param> /// <param name="ep">enter points ep</param>
/// <returns>Output: ef closest neighbors to q</returns> /// <returns>Output: ef closest neighbors to q</returns>
public IDictionary<int, float> SEARCH_LAYER(int q, int ep, int ef) public void RunKnnAtLayer(int entryPointId, Func<int, float> targetCosts, IDictionary<int, float> W, int ef)
{ {
var v = new VisitedBitSet(_vectors._set.Count, _options.M); /*
* v ep // set of visited elements
* C ep // set of candidates
* W ep // dynamic list of found nearest neighbors
* while C > 0
* c extract nearest element from C to q
* f get furthest element from W to q
* if distance(c, q) > distance(f, q)
* break // all elements in W are evaluated
* for each e neighbourhood(c) at layer lc // update C and W
* if e v
* v v e
* f get furthest element from W to q
* if distance(e, q) < distance(f, q) or W < ef
* C C e
* W W e
* if W > ef
* remove furthest element from W to q
* return W
*/
var v = new VisitedBitSet(_vectors.Count, _options.M);
// v ← ep // set of visited elements // v ← ep // set of visited elements
v.Add(ep); v.Add(entryPointId);
// C ← ep // set of candidates // C ← ep // set of candidates
var C = new Dictionary<int, float>(); var C = new Dictionary<int, float>();
C.Add(ep, _options.Distance(_vectors[ep], _vectors[q])); C.Add(entryPointId, targetCosts(entryPointId));
// W ← ep // dynamic list of found nearest neighbors // W ← ep // dynamic list of found nearest neighbors
var W = new Dictionary<int, float>(); W.Add(entryPointId, C[entryPointId]);
W.Add(ep, C[ep]);
// while │C│ > 0
// run bfs
while (C.Count > 0) while (C.Count > 0)
{ {
// c ← extract nearest element from C to q // get next candidate to check and expand
var nearest = W.OrderBy(p => p.Value).First(); var toExpand = popCandidate();
var c = nearest.Key; var farthestResult = fartherFromResult();
var md = nearest.Value; if (toExpand.Item2 > farthestResult.Item2)
// var (c, md) = GetMinimalDistanceIndex(C, q);
C.Remove(c);
// f ← get furthest element from W to q
var f = W.OrderBy(p => p.Value).First().Key;
//var f = GetMaximalDistanceIndex(W, q);
// if distance(c, q) > distance(f, q)
if (_options.Distance(_vectors[c], _vectors[q]) > _options.Distance(_vectors[f], _vectors[q]))
{ {
// break // all elements in W are evaluated // the closest candidate is farther than farthest result
break; break;
} }
// for each e ∈ neighbourhood(c) at layer lc // update C and W
foreach (var l in _links.FindLinksForId(c)) // expand candidate
var neighboursIds = GetNeighbors(toExpand.Item1).ToArray();
for (int i = 0; i < neighboursIds.Length; ++i)
{ {
var e = l.Item2; int neighbourId = neighboursIds[i];
// if e ∉ v if (!v.Contains(neighbourId))
if (v.Contains(e) == false)
{ {
// v ← v e // enqueue perspective neighbours to expansion list
v.Add(e); farthestResult = fartherFromResult();
// f ← get furthest element from W to q
f = W.OrderByDescending(p => p.Value).First().Key; var neighbourDistance = targetCosts(neighbourId);
//f = GetMaximalDistanceIndex(W, q); if (W.Count < ef || neighbourDistance < farthestResult.Item2)
// if distance(e, q) < distance(f, q) or │W│ < ef
var ed = _options.Distance(_vectors[e], _vectors[q]);
if (ed > _options.Distance(_vectors[f], _vectors[q])
|| W.Count < ef)
{ {
// C ← C e C.Add(neighbourId, neighbourDistance);
C.Add(e, ed); W.Add(neighbourId, neighbourDistance);
// W ← W e
W.Add(e, ed);
// if │W│ > ef
if (W.Count > ef) if (W.Count > ef)
{ {
// remove furthest element from W to q fartherPopFromResult();
f = W.OrderByDescending(p => p.Value).First().Key;
//f = GetMaximalDistanceIndex(W, q);
W.Remove(f);
} }
} }
v.Add(neighbourId);
} }
} }
} }
// return W C.Clear();
return W; v.Clear();
} }
/// <summary> /// <summary>
/// Algorithm 3 /// Algorithm 3
/// </summary> /// </summary>
/// <param name="q">base element</param> public IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(Func<int, float> distance, IDictionary<int, float> candidates, int M)
/// <param name="C">candidate elements</param>
/// <returns>Output: M nearest elements to q</returns>
public IDictionary<int, float> SELECT_NEIGHBORS_SIMPLE(int q, IDictionary<int, float> C)
{ {
if (C.Count <= _options.M) var bestN = M;
var W = new Dictionary<int, float>(candidates);
if (W.Count > bestN)
{ {
return new Dictionary<int, float>(C); var popFarther = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
while (W.Count > bestN)
{
popFarther();
}
} }
var output = new Dictionary<int, float>();
// return M nearest elements from C to q // return M nearest elements from C to q
return new Dictionary<int, float>(C.OrderBy(p => p.Value).Take(_options.M)); return W;
} }
/// <summary> /// <summary>
/// Algorithm 4 /// Algorithm 4
/// </summary> /// </summary>
@ -200,41 +164,56 @@ namespace ZeroLevel.HNSW
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param> /// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param> /// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns> /// <returns>Output: M elements selected by the heuristic</returns>
public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(int q, IDictionary<int, float> C, bool extendCandidates, bool keepPrunedConnections) public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M, bool extendCandidates, bool keepPrunedConnections)
{ {
// R ← ∅ // R ← ∅
var R = new Dictionary<int, float>(); var R = new Dictionary<int, float>();
// W ← C // working queue for the candidates // W ← C // working queue for the candidates
var W = new List<int>(C.Select(p => p.Key)); var W = new Dictionary<int, float>(candidates);
// if extendCandidates // extend candidates by their neighbors // if extendCandidates // extend candidates by their neighbors
if (extendCandidates) if (extendCandidates)
{ {
var extendBuffer = new HashSet<int>();
// for each e ∈ C // for each e ∈ C
foreach (var e in C) foreach (var e in W)
{ {
var neighbors = GetNeighbors(e.Key);
// for each e_adj ∈ neighbourhood(e) at layer lc // for each e_adj ∈ neighbourhood(e) at layer lc
foreach (var l in _links.FindLinksForId(e.Key)) foreach (var e_adj in neighbors)
{ {
var e_adj = l.Item2;
// if eadj ∉ W // if eadj ∉ W
if (W.Contains(e_adj) == false) if (extendBuffer.Contains(e_adj) == false)
{ {
// W ← W eadj extendBuffer.Add(e_adj);
W.Add(e_adj);
} }
} }
} }
// W ← W eadj
foreach (var id in extendBuffer)
{
W.Add(id, distance(id));
}
} }
// Wd ← ∅ // queue for the discarded candidates // Wd ← ∅ // queue for the discarded candidates
var Wd = new Dictionary<int, float>(); var Wd = new Dictionary<int, float>();
var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
// while │W│ > 0 and │R│< M // while │W│ > 0 and │R│< M
while (W.Count > 0 && R.Count < _options.M) while (W.Count > 0 && R.Count < M)
{ {
// e ← extract nearest element from W to q // e ← extract nearest element from W to q
var (e, ed) = GetMinimalDistanceIndex(W, q); var (e, ed) = popCandidate();
W.Remove(e); var (fe, fd) = fartherFromResult();
// if e is closer to q compared to any element from R // if e is closer to q compared to any element from R
if (ed < R.Min(pair => pair.Value)) if (R.Count == 0 ||
ed < fd)
{ {
// R ← R e // R ← R e
R.Add(e, ed); R.Add(e, ed);
@ -248,37 +227,20 @@ namespace ZeroLevel.HNSW
if (keepPrunedConnections) if (keepPrunedConnections)
{ {
// while │Wd│> 0 and │R│< M // while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < _options.M) while (Wd.Count > 0 && R.Count < M)
{ {
// R ← R extract nearest element from Wd to q // R ← R extract nearest element from Wd to q
var nearest = Wd.Aggregate((l, r) => l.Value < r.Value ? l : r); var nearest = popNearestDiscarded();
Wd.Remove(nearest.Key); R.Add(nearest.Item1, nearest.Item2);
R.Add(nearest.Key, nearest.Value);
} }
} }
} }
// return R // return R
return R; return R;
} }
#endregion #endregion
private (int, float) GetMinimalDistanceIndex(IList<int> self, int q) private IEnumerable<int> GetNeighbors(int id) => _links.FindLinksForId(id).Select(d => d.Item2);
{
float min = _options.Distance(_vectors[self[0]], _vectors[q]);
int minIndex = 0;
for (int i = 1; i < self.Count; ++i)
{
var dist = _options.Distance(_vectors[self[i]], _vectors[q]);
if (dist < min)
{
min = self[i];
minIndex = i;
}
}
return (minIndex, min);
}
} }
} }

@ -1,15 +1,16 @@
using System.Collections.Generic; using System;
using System.Collections.Generic;
using System.Threading; using System.Threading;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
public class VectorSet<T> public class VectorSet<T>
{ {
public IList<T> _set = new List<T>(); private List<T> _set = new List<T>();
private SpinLock _lock = new SpinLock();
public T this[int index] => _set[index]; public T this[int index] => _set[index];
public int Count => _set.Count;
SpinLock _lock = new SpinLock();
public int Append(T vector) public int Append(T vector)
{ {
@ -27,5 +28,30 @@ namespace ZeroLevel.HNSW
if (gotLock) _lock.Exit(); if (gotLock) _lock.Exit();
} }
} }
/// <summary>
/// Appends all <paramref name="vectors"/> to the set under the spin lock and returns the
/// indices they were stored at.
/// </summary>
/// <param name="vectors">Vectors to append, in order.</param>
/// <returns>Consecutive indices of the appended vectors, in insertion order.</returns>
public int[] Append(IEnumerable<T> vectors)
{
    int firstId, nextId;
    bool lockTaken = false;
    try
    {
        _lock.Enter(ref lockTaken);
        firstId = _set.Count;
        _set.AddRange(vectors);
        nextId = _set.Count;
    }
    finally
    {
        // Release only if Enter actually acquired the lock.
        if (lockTaken)
        {
            _lock.Exit();
        }
    }

    // The new items occupy the contiguous index range [firstId, nextId).
    var ids = new int[nextId - firstId];
    for (int offset = 0; offset < ids.Length; offset++)
    {
        ids[offset] = firstId + offset;
    }
    return ids;
}
} }
} }

@ -1,9 +1,38 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading;
namespace ZeroLevel.HNSW namespace ZeroLevel.HNSW
{ {
/// <summary>
/// Draws a random layer number for a new element. Layer <c>i</c> has the threshold
/// <c>DIVIDER^-(i + 1)</c>; a random draw is mapped to the first layer whose threshold it exceeds,
/// so higher layers are selected progressively less often.
/// </summary>
public class ProbabilityLayerNumberGenerator
{
    // Ratio between the thresholds of two neighbouring layers.
    private const float DIVIDER = 4.362f;
    // _probabilities[i] == DIVIDER^-(i + 1); strictly decreasing.
    private readonly float[] _probabilities;

    /// <summary>
    /// Precomputes the per-layer thresholds.
    /// </summary>
    /// <param name="maxLayers">Number of layers; generated layer numbers lie in [0, maxLayers - 1].</param>
    /// <param name="M">Not used by the current implementation; kept for signature compatibility.</param>
    public ProbabilityLayerNumberGenerator(int maxLayers, int M)
    {
        _probabilities = new float[maxLayers];
        var probability = 1.0f / DIVIDER;
        for (int i = 0; i < maxLayers; i++)
        {
            _probabilities[i] = probability;
            probability /= DIVIDER;
        }
    }

    /// <summary>
    /// Returns a random layer number.
    /// </summary>
    /// <returns>
    /// The index of the first threshold the random draw exceeds; the highest layer when the draw is
    /// not above any threshold; 0 when the generator was created with no layers.
    /// </returns>
    public int GetRandomLayer()
    {
        // NOTE(review): presumably NextFloat is uniform in [0, 1) - confirm against DefaultRandomGenerator.
        var probability = DefaultRandomGenerator.Instance.NextFloat();
        for (int i = 0; i < _probabilities.Length; i++)
        {
            if (probability > _probabilities[i])
            {
                return i;
            }
        }

        // The draw is at or below the smallest threshold - the rarest outcome. It belongs to the
        // top layer (the level is capped there), not to layer 0 as the previous fall-through did.
        return _probabilities.Length > 0 ? _probabilities.Length - 1 : 0;
    }
}
public class SmallWorld<TItem> public class SmallWorld<TItem>
{ {
private readonly NSWOptions<TItem> _options; private readonly NSWOptions<TItem> _options;
@ -12,12 +41,17 @@ namespace ZeroLevel.HNSW
private Layer<TItem> EnterPointsLayer => _layers[_layers.Length - 1]; private Layer<TItem> EnterPointsLayer => _layers[_layers.Length - 1];
private Layer<TItem> LastLayer => _layers[0]; private Layer<TItem> LastLayer => _layers[0];
private int EntryPoint = -1;
private int MaxLayer = -1;
private readonly ProbabilityLayerNumberGenerator _layerLevelGenerator;
private ReaderWriterLockSlim _lockGraph = new ReaderWriterLockSlim();
public SmallWorld(NSWOptions<TItem> options) public SmallWorld(NSWOptions<TItem> options)
{ {
_options = options; _options = options;
_vectors = new VectorSet<TItem>(); _vectors = new VectorSet<TItem>();
_layers = new Layer<TItem>[_options.LayersCount]; _layers = new Layer<TItem>[_options.LayersCount];
_layerLevelGenerator = new ProbabilityLayerNumberGenerator(_options.LayersCount, _options.M);
for (int i = 0; i < _options.LayersCount; i++) for (int i = 0; i < _options.LayersCount; i++)
{ {
_layers[i] = new Layer<TItem>(_options, _vectors); _layers[i] = new Layer<TItem>(_options, _vectors);
@ -31,93 +65,174 @@ namespace ZeroLevel.HNSW
public int[] AddItems(IEnumerable<TItem> vectors) public int[] AddItems(IEnumerable<TItem> vectors)
{ {
var insert = vectors.ToArray(); _lockGraph.EnterWriteLock();
var ids = new int[insert.Length]; try
for (int i = 0; i < insert.Length; i++)
{ {
var item = insert[i]; var ids = _vectors.Append(vectors);
ids[i] = Insert(item); for (int i = 0; i < ids.Length; i++)
{
INSERT(ids[i]);
}
return ids;
}
finally
{
_lockGraph.ExitWriteLock();
} }
return ids;
} }
public int Insert(TItem item) public void TestLevelGenerator()
{ {
var id = _vectors.Append(item); var levels = new Dictionary<int, float>();
INSERT(id); for (int i = 0; i < 10000; i++)
return id; {
var level = _layerLevelGenerator.GetRandomLayer();
if (levels.ContainsKey(level) == false)
{
levels.Add(level, 1);
}
else
{
levels[level] += 1.0f;
}
}
foreach (var pair in levels.OrderBy(l => l.Key))
{
Console.WriteLine($"[{pair.Key}]: {pair.Value / 100.0f}% ({pair.Value})");
}
} }
#region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf #region https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf
/// <summary> /// <summary>
/// Algorithm 1 /// Algorithm 1
/// </summary> /// </summary>
/// <param name="q">new element</param>
public void INSERT(int q) public void INSERT(int q)
{ {
var distance = new Func<int, float>(candidate => _options.Distance(_vectors[q], _vectors[candidate]));
// W ← ∅ // list for the currently found nearest elements // W ← ∅ // list for the currently found nearest elements
IDictionary<int, float> W; IDictionary<int, float> W = new Dictionary<int, float>();
// ep ← get enter point for hnsw // ep ← get enter point for hnsw
var ep = EnterPointsLayer.GetEntryPointFor(q); var ep = EntryPoint == -1 ? 0 : EntryPoint;
var epDist = 0.0f;
// L ← level of ep // top layer for hnsw // L ← level of ep // top layer for hnsw
var L = _layers.Length - 1; var L = MaxLayer;
// l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level // l ← ⌊-ln(unif(0..1))∙mL⌋ // new elements level
int l = DefaultRandomGenerator.Instance.Next(0, _options.LayersCount - 1); int l = _layerLevelGenerator.GetRandomLayer();
if (L == -1)
{
L = l;
MaxLayer = l;
}
// for lc ← L … l+1 // for lc ← L … l+1
for (int lc = L; lc > l; lc--) // Проход с верхнего уровня до уровня где появляется элемент, для нахождения точки входа
for (int lc = L; lc > l; --lc)
{ {
// W ← SEARCH-LAYER(q, ep, ef = 1, lc) // W ← SEARCH-LAYER(q, ep, ef = 1, lc)
W = _layers[lc].SEARCH_LAYER(q, ep, 1); _layers[lc].RunKnnAtLayer(ep, distance, W, 1);
// ep ← get the nearest element from W to q // ep ← get the nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key; var nearest = W.OrderBy(p => p.Value).First();
ep = nearest.Key;
epDist = nearest.Value;
W.Clear();
} }
//for lc ← min(L, l) … 0 //for lc ← min(L, l) … 0
for (int lc = Math.Min(L, l); lc >= 0; lc--) // connecting new node to the small world
for (int lc = Math.Min(L, l); lc >= 0; --lc)
{ {
// W ← SEARCH - LAYER(q, ep, efConstruction, lc) // W ← SEARCH - LAYER(q, ep, efConstruction, lc)
W = _layers[lc].SEARCH_LAYER(q, ep, _options.EFConstruction); _layers[lc].RunKnnAtLayer(ep, distance, W, _options.EFConstruction);
// neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4 // neighbors ← SELECT-NEIGHBORS(q, W, M, lc) // alg. 3 or alg. 4
var neighbors = _layers[lc].SELECT_NEIGHBORS_SIMPLE(q, W); var neighbors = SelectBestForConnecting(lc, distance, W);;
// add bidirectionall connectionts from neighbors to q at layer lc // add bidirectionall connectionts from neighbors to q at layer lc
// for each e ∈ neighbors // shrink connections if needed // for each e ∈ neighbors // shrink connections if needed
foreach (var e in neighbors) foreach (var e in neighbors)
{ {
// eConn ← neighbourhood(e) at layer lc // eConn ← neighbourhood(e) at layer lc
_layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value); _layers[lc].AddBidirectionallConnectionts(q, e.Key, e.Value);
// if distance from newNode to newNeighbour is better than to bestPeer => update bestPeer
if (e.Value < epDist)
{
ep = e.Key;
epDist = e.Value;
}
} }
// ep ← W // ep ← W
ep = W.OrderBy(p => p.Value).First().Key; ep = W.OrderBy(p => p.Value).First().Key;
W.Clear();
}
// if l > L
if (l > L)
{
// set enter point for hnsw to q
L = l;
MaxLayer = l;
EntryPoint = ep;
} }
// if l > L }
// set enter point for hnsw to q
/// <summary>
/// Get maximum allowed connections for the given level.
/// </summary>
/// <remarks>
/// Article: Section 4.1:
/// "Selection of the Mmax0 (the maximum number of connections that an element can have in the zero layer) also
/// has a strong influence on the search performance, especially in case of high quality(high recall) search.
/// Simulations show that setting Mmax0 to M(this corresponds to kNN graphs on each layer if the neighbors
/// selection heuristic is not used) leads to a very strong performance penalty at high recall.
/// Simulations also suggest that 2∙M is a good choice for Mmax0;
/// setting the parameter higher leads to performance degradation and excessive memory usage."
/// </remarks>
/// <param name="layer">The level of the layer.</param>
/// <returns>The maximum number of connections.</returns>
internal int GetM(int layer)
{
    var baseConnections = _options.M;

    // The zero layer is allowed twice the regular number of connections (Mmax0 = 2 * M).
    if (layer == 0)
    {
        return 2 * baseConnections;
    }

    return baseConnections;
}
private IDictionary<int, float> SelectBestForConnecting(int layer, Func<int, float> distance, IDictionary<int, float> candidates)
{
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer));
} }
/// <summary> /// <summary>
/// Algorithm 5 /// Algorithm 5
/// </summary> /// </summary>
/// <param name="q">query element</param> internal IEnumerable<(int, float)> KNearest(TItem q, int k)
/// <param name="K">number of nearest neighbors to return</param>
/// <returns>: K nearest elements to q</returns>
public IList<int> K_NN_SEARCH(int q, int K)
{ {
// W ← ∅ // set for the current nearest elements _lockGraph.EnterReadLock();
IDictionary<int, float> W; try
// ep ← get enter point for hnsw
var ep = EnterPointsLayer.GetEntryPointFor(q);
// L ← level of ep // top layer for hnsw
var L = _options.LayersCount - 1;
// for lc ← L … 1
for (var lc = L; lc > 0; lc--)
{ {
// W ← SEARCH-LAYER(q, ep, ef = 1, lc) if (_vectors.Count == 0)
W = _layers[lc].SEARCH_LAYER(q, ep, 1); {
// ep ← get nearest element from W to q return Enumerable.Empty<(int, float)>();
ep = W.OrderBy(p => p.Value).First().Key; }
var distance = new Func<int, float>(candidate => _options.Distance(q, _vectors[candidate]));
// W ← ∅ // set for the current nearest elements
var W = new Dictionary<int, float>(k + 1);
// ep ← get enter point for hnsw
var ep = EntryPoint;
// L ← level of ep // top layer for hnsw
var L = MaxLayer;
// for lc ← L … 1
for (int layer = L; layer > 0; --layer)
{
// W ← SEARCH-LAYER(q, ep, ef = 1, lc)
_layers[layer].RunKnnAtLayer(ep, distance, W, 1);
// ep ← get nearest element from W to q
ep = W.OrderBy(p => p.Value).First().Key;
W.Clear();
}
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
_layers[0].RunKnnAtLayer(ep, distance, W, k);
// return K nearest elements from W to q
return W.Select(p => (p.Key, p.Value));
}
finally
{
_lockGraph.ExitReadLock();
} }
// W ← SEARCH-LAYER(q, ep, ef, lc =0)
W = LastLayer.SEARCH_LAYER(q, ep, _options.EF);
// return K nearest elements from W to q
return W.OrderBy(p => p.Value).Take(K).Select(p => p.Key).ToList();
} }
#endregion #endregion
} }

Loading…
Cancel
Save

Powered by TurnKey Linux.