HNSW update

Fix bugs
pull/1/head
unknown 3 years ago
parent 1f967908b9
commit 0e2f693b87

@ -77,13 +77,13 @@ namespace HNSWDemo
static void Main(string[] args)
{
var dimensionality = 128;
var testCount = 1000;
var count = 100000;
var batchSize = 5000;
var testCount = 5000;
var count = 10000;
var batchSize = 1000;
var samples = Person.GenerateRandom(dimensionality, count);
var sw = new Stopwatch();
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 4, 120, 120, CosineDistance.ForUnits));
var world = new SmallWorld<float[]>(NSWOptions<float[]>.Create(6, 10, 200, 200, CosineDistance.ForUnits, false, false, selectionHeuristic: NeighbourSelectionHeuristic.SelectHeuristic));
for (int i = 0; i < (count / batchSize); i++)
{

@ -94,7 +94,9 @@ namespace ZeroLevel.HNSW
// W ← ep // dynamic list of found nearest neighbors
W.Add(entryPointId, C[entryPointId]);
var popCandidate = new Func<(int, float)>(() => { var pair = C.OrderBy(e => e.Value).First(); C.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { var pair = W.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var fartherPopFromResult = new Action(() => { var pair = W.OrderByDescending(e => e.Value).First(); W.Remove(pair.Key); });
// run bfs
while (C.Count > 0)
{
@ -164,14 +166,14 @@ namespace ZeroLevel.HNSW
/// <param name="extendCandidates">flag indicating whether or not to extend candidate list</param>
/// <param name="keepPrunedConnections">flag indicating whether or not to add discarded elements</param>
/// <returns>Output: M elements selected by the heuristic</returns>
public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M, bool extendCandidates, bool keepPrunedConnections)
public IDictionary<int, float> SELECT_NEIGHBORS_HEURISTIC(Func<int, float> distance, IDictionary<int, float> candidates, int M)
{
// R ← ∅
var R = new Dictionary<int, float>();
// W ← C // working queue for the candidates
var W = new Dictionary<int, float>(candidates);
// if extendCandidates // extend candidates by their neighbors
if (extendCandidates)
if (_options.ExpandBestSelection)
{
var extendBuffer = new HashSet<int>();
// for each e ∈ C
@ -191,7 +193,7 @@ namespace ZeroLevel.HNSW
// W ← W eadj
foreach (var id in extendBuffer)
{
W.Add(id, distance(id));
W[id] = distance(id);
}
}
@ -201,7 +203,7 @@ namespace ZeroLevel.HNSW
var popCandidate = new Func<(int, float)>(() => { var pair = W.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var fartherFromResult = new Func<(int, float)>(() => { if (R.Count == 0) return (-1, 0f); var pair = R.OrderByDescending(e => e.Value).First(); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); W.Remove(pair.Key); return (pair.Key, pair.Value); });
var popNearestDiscarded = new Func<(int, float)>(() => { var pair = Wd.OrderBy(e => e.Value).First(); Wd.Remove(pair.Key); return (pair.Key, pair.Value); });
// while │W│ > 0 and │R│< M
@ -218,21 +220,21 @@ namespace ZeroLevel.HNSW
// R ← R e
R.Add(e, ed);
}
// else
else
{
// Wd ← Wd e
Wd.Add(e, ed);
}
// if keepPrunedConnections // add some of the discarded // connections from Wd
if (keepPrunedConnections)
}
// if keepPrunedConnections // add some of the discarded // connections from Wd
if (_options.KeepPrunedConnections)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{
// while │Wd│> 0 and │R│< M
while (Wd.Count > 0 && R.Count < M)
{
// R ← R extract nearest element from Wd to q
var nearest = popNearestDiscarded();
R.Add(nearest.Item1, nearest.Item2);
}
// R ← R extract nearest element from Wd to q
var nearest = popNearestDiscarded();
R[nearest.Item1] = nearest.Item2;
}
}
// return R

@ -2,10 +2,24 @@
namespace ZeroLevel.HNSW
{
public sealed class NSWOptions<TItem>
/// <summary>
/// Type of heuristic to select best neighbours for a node.
/// </summary>
public enum NeighbourSelectionHeuristic
{
public const int FARTHEST_DIVIDER = 3;
/// <summary>
/// Marker for the Algorithm 3 (SELECT-NEIGHBORS-SIMPLE) from the article. Implemented in <see cref="Algorithms.Algorithm3{TItem, TDistance}"/>
/// </summary>
SelectSimple,
/// <summary>
/// Marker for the Algorithm 4 (SELECT-NEIGHBORS-HEURISTIC) from the article. Implemented in <see cref="Algorithms.Algorithm4{TItem, TDistance}"/>
/// </summary>
SelectHeuristic
}
public sealed class NSWOptions<TItem>
{
/// <summary>
/// Mox node connections on Layer
/// </summary>
@ -24,19 +38,42 @@ namespace ZeroLevel.HNSW
/// </summary>
public readonly Func<TItem, TItem, float> Distance;
public readonly bool ExpandBestSelection;
public readonly bool KeepPrunedConnections;
public readonly NeighbourSelectionHeuristic SelectionHeuristic;
public readonly int LayersCount;
private NSWOptions(int layersCount, int m, int ef, int ef_construction, Func<TItem, TItem, float> distance)
private NSWOptions(int layersCount,
int m,
int ef,
int ef_construction,
Func<TItem, TItem, float> distance,
bool expandBestSelection,
bool keepPrunedConnections,
NeighbourSelectionHeuristic selectionHeuristic)
{
LayersCount = layersCount;
M = m;
EF = ef;
EFConstruction = ef_construction;
Distance = distance;
ExpandBestSelection = expandBestSelection;
KeepPrunedConnections = keepPrunedConnections;
SelectionHeuristic = selectionHeuristic;
}
public static NSWOptions<TItem> Create(int layersCount, int M, int EF, int EF_construction, Func<TItem, TItem, float> distance) =>
new NSWOptions<TItem>(layersCount, M, EF, EF_construction, distance);
public static NSWOptions<TItem> Create(int layersCount,
int M,
int EF,
int EF_construction,
Func<TItem, TItem, float> distance,
bool expandBestSelection = false,
bool keepPrunedConnections = false,
NeighbourSelectionHeuristic selectionHeuristic = NeighbourSelectionHeuristic.SelectSimple) =>
new NSWOptions<TItem>(layersCount, M, EF, EF_construction, distance, expandBestSelection, keepPrunedConnections, selectionHeuristic);
}
}

@ -14,15 +14,14 @@ namespace ZeroLevel.HNSW
private SortedList<long, float> _set = new SortedList<long, float>();
public (int, int, float) this[int index]
public (int, int) this[int index]
{
get
{
var k = _set.Keys[index];
var d = _set.Values[index];
var id1 = (int)(k >> HALF_LONG_BITS);
var id2 = (int)(k - (((long)id1) << HALF_LONG_BITS));
return (id1, id2, d);
return (id1, id2);
}
}
@ -72,10 +71,22 @@ namespace ZeroLevel.HNSW
{
_set.Remove(k_id1_id2);
_set.Remove(k_id2_id1);
_set.Add(k_id_id1, distanceToId1);
_set.Add(k_id1_id, distanceToId1);
_set.Add(k_id_id2, distanceToId2);
_set.Add(k_id2_id, distanceToId2);
if (!_set.ContainsKey(k_id_id1))
{
_set.Add(k_id_id1, distanceToId1);
}
if (!_set.ContainsKey(k_id1_id))
{
_set.Add(k_id1_id, distanceToId1);
}
if (!_set.ContainsKey(k_id_id2))
{
_set.Add(k_id_id2, distanceToId2);
}
if (!_set.ContainsKey(k_id2_id))
{
_set.Add(k_id2_id, distanceToId2);
}
}
finally
{

@ -192,7 +192,9 @@ namespace ZeroLevel.HNSW
private IDictionary<int, float> SelectBestForConnecting(int layer, Func<int, float> distance, IDictionary<int, float> candidates)
{
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer));
if(_options.SelectionHeuristic == NeighbourSelectionHeuristic.SelectSimple)
return _layers[layer].SELECT_NEIGHBORS_SIMPLE(distance, candidates, GetM(layer));
return _layers[layer].SELECT_NEIGHBORS_HEURISTIC(distance, candidates, GetM(layer));
}
/// <summary>

Loading…
Cancel
Save

Powered by TurnKey Linux.