You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Zero/ZeroLevel.HNSW/Services/CompactBiDirectionalLinksSe...

245 lines
7.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.HNSW
{
/// <summary>
/// Lock-guarded set of weighted links between pairs of integer ids.
/// Each link (id1, id2) is stored under a packed 64-bit key: id1 shifted into the high 32 bits,
/// plus id2. <c>Add</c> stores both directions (id1, id2) and (id2, id1), so every link can be
/// found by searching for keys whose high part equals either endpoint.
/// </summary>
internal sealed class CompactBiDirectionalLinksSet
    : IBinarySerializable, IDisposable
{
    private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim();
    private const int HALF_LONG_BITS = 32;
    private SortedList<long, float> _set = new SortedList<long, float>();

    /// <summary>
    /// Direct access to the underlying sorted key/distance list.
    /// NOTE(review): this bypasses the internal lock; reading or mutating it concurrently with
    /// <c>Add</c>/<c>RemoveIndex</c>/<c>Deserialize</c> is unsafe.
    /// </summary>
    internal SortedList<long, float> Links => _set;

    /// <summary>
    /// Decodes the key stored at sorted position <paramref name="index"/> into its (id1, id2) pair.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown by the underlying key list when <paramref name="index"/> is outside [0, Count).
    /// </exception>
    internal (int, int) this[int index]
    {
        get
        {
            _rwLock.EnterReadLock();
            try
            {
                return Decode(_set.Keys[index]);
            }
            finally
            {
                _rwLock.ExitReadLock();
            }
        }
    }

    /// <summary>Number of stored keys (each non-self link occupies two keys).</summary>
    internal int Count => _set.Count;

    /// <summary>
    /// Returns the stored links whose key has <paramref name="id"/> as its first (high) component,
    /// as (id, otherId, distance) tuples in key order.
    /// </summary>
    /// <remarks>
    /// The result is a snapshot built under the read lock. The lock is never held while caller
    /// code runs between iterations, so the caller may freely call back into this set.
    /// When the set holds exactly one entry, that entry is also matched by its second component
    /// and returned as (otherId, id, distance).
    /// </remarks>
    internal IEnumerable<(int, int, float)> FindLinksForId(int id)
    {
        var result = new List<(int, int, float)>();
        _rwLock.EnterReadLock();
        try
        {
            if (_set.Count == 1)
            {
                // With a single entry, match it by either component.
                var (id1, id2) = Decode(_set.Keys[0]);
                var distance = _set.Values[0];
                if (id1 == id)
                {
                    result.Add((id, id2, distance));
                }
                else if (id2 == id)
                {
                    result.Add((id1, id, distance));
                }
            }
            else if (_set.Count > 1)
            {
                var keys = _set.Keys;
                var values = _set.Values;
                // Keys sharing the high part 'id' are contiguous and start at the lower bound of id << 32.
                for (var i = LowerBound(keys, ((long)id) << HALF_LONG_BITS); i < keys.Count; i++)
                {
                    var (id1, id2) = Decode(keys[i]);
                    if (id1 != id)
                    {
                        break;
                    }
                    result.Add((id1, id2, values[i]));
                }
            }
        }
        finally
        {
            _rwLock.ExitReadLock();
        }
        return result;
    }

    /// <summary>
    /// Returns a snapshot of every stored key as (id1, id2, distance) tuples in key order.
    /// Both directions of each link are included.
    /// </summary>
    internal IEnumerable<(int, int, float)> Items()
    {
        _rwLock.EnterReadLock();
        try
        {
            var result = new List<(int, int, float)>(_set.Count);
            foreach (var pair in _set)
            {
                var (id1, id2) = Decode(pair.Key);
                result.Add((id1, id2, pair.Value));
            }
            return result;
        }
        finally
        {
            _rwLock.ExitReadLock();
        }
    }

    /// <summary>Removes the link between <paramref name="id1"/> and <paramref name="id2"/> in both directions, if present.</summary>
    internal void RemoveIndex(int id1, int id2)
    {
        var forward = Encode(id1, id2);
        var backward = Encode(id2, id1);
        _rwLock.EnterWriteLock();
        try
        {
            // SortedList.Remove is a no-op (returns false) for absent keys, so no ContainsKey probe is needed.
            _set.Remove(forward);
            _set.Remove(backward);
        }
        finally
        {
            _rwLock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Adds the link (id1, id2) and its reverse with the given distance.
    /// </summary>
    /// <returns><c>true</c> if the link was added; <c>false</c> if the (id1, id2) key already existed.</returns>
    internal bool Add(int id1, int id2, float distance)
    {
        var forward = Encode(id1, id2);
        var backward = Encode(id2, id1);
        _rwLock.EnterWriteLock();
        try
        {
            if (_set.ContainsKey(forward))
            {
                return false;
            }
            _set.Add(forward, distance);
            if (forward != backward)
            {
                // Indexer assignment instead of Add: if only the reverse key is present (e.g. from
                // deserialized data or edits through Links), Add would throw after the forward
                // key had already been inserted, leaving the set half-updated.
                _set[backward] = distance;
            }
            return true;
        }
        finally
        {
            _rwLock.ExitWriteLock();
        }
    }

    /// <summary>Builds a histogram over all stored distances.</summary>
    public Histogram CalculateHistogram(HistogramMode mode)
    {
        _rwLock.EnterReadLock();
        try
        {
            // NOTE(review): the live Values view is passed in; the lock only covers construction.
            // If Histogram keeps the reference and enumerates it later, that access is unguarded.
            return new Histogram(mode, _set.Values);
        }
        finally
        {
            _rwLock.ExitReadLock();
        }
    }

    /// <summary>
    /// Returns the stored distance for the directed key (id1, id2), or <see cref="float.MaxValue"/> if absent.
    /// </summary>
    internal float Distance(int id1, int id2)
    {
        var key = Encode(id1, id2);
        _rwLock.EnterReadLock();
        try
        {
            return _set.TryGetValue(key, out var distance) ? distance : float.MaxValue;
        }
        finally
        {
            _rwLock.ExitReadLock();
        }
    }

    /// <summary>Releases the lock and drops the stored links. Safe to call more than once.</summary>
    public void Dispose()
    {
        _rwLock.Dispose();
        // Null-conditional so a repeated Dispose does not throw NullReferenceException.
        _set?.Clear();
        _set = null;
    }

    /// <summary>
    /// Writes a format flag (<c>true</c> = weighted set), the entry count, then each key/distance pair in key order.
    /// </summary>
    public void Serialize(IBinaryWriter writer)
    {
        _rwLock.EnterReadLock();
        try
        {
            writer.WriteBoolean(true); // true - set with weights
            writer.WriteInt32(_set.Count);
            foreach (var record in _set)
            {
                writer.WriteLong(record.Key);
                writer.WriteFloat(record.Value);
            }
        }
        finally
        {
            _rwLock.ExitReadLock();
        }
    }

    /// <summary>
    /// Replaces the contents with data written by <c>Serialize</c>.
    /// The new set is fully read before being swapped in, so a read failure leaves the current contents intact.
    /// </summary>
    /// <exception cref="InvalidOperationException">The stream was not written as a weighted set.</exception>
    public void Deserialize(IBinaryReader reader)
    {
        if (!reader.ReadBoolean())
        {
            throw new InvalidOperationException("Incompatible data format. The set does not contain weights.");
        }
        var count = reader.ReadInt32();
        var loaded = new SortedList<long, float>(count + 1);
        for (int i = 0; i < count; i++)
        {
            var key = reader.ReadLong();
            var value = reader.ReadFloat();
            loaded.Add(key, value);
        }
        _rwLock.EnterWriteLock();
        try
        {
            var previous = _set;
            _set = loaded;
            previous?.Clear();
        }
        finally
        {
            _rwLock.ExitWriteLock();
        }
    }

    /// <summary>Packs an ordered id pair into a single sortable key: (id1 &lt;&lt; 32) + id2.</summary>
    private static long Encode(int id1, int id2) => (((long)id1) << HALF_LONG_BITS) + id2;

    /// <summary>
    /// Splits a packed key into (high part, remainder). Reverses <c>Encode</c> when id2 is non-negative.
    /// </summary>
    private static (int, int) Decode(long key)
    {
        var id1 = (int)(key >> HALF_LONG_BITS);
        var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS));
        return (id1, id2);
    }

    /// <summary>
    /// Binary search for the first index in the ascending <paramref name="keys"/> whose value is
    /// not less than <paramref name="target"/>. Returns <c>keys.Count</c> if every key is smaller.
    /// </summary>
    private static int LowerBound(IList<long> keys, long target)
    {
        var lo = 0;
        var hi = keys.Count;
        while (lo < hi)
        {
            // Overflow-safe midpoint.
            var mid = lo + ((hi - lo) >> 1);
            if (keys[mid] < target)
            {
                lo = mid + 1;
            }
            else
            {
                hi = mid;
            }
        }
        return lo;
    }
}
}

Powered by TurnKey Linux.