using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using ZeroLevel.Services.Serialization;

namespace ZeroLevel.HNSW
{
    /// <summary>
    /// Weighted set of bidirectional links between integer ids.
    /// Every link (id1, id2) is stored under two 64-bit keys, "id1:id2" and "id2:id1"
    /// (first id in the upper 32 bits, second id added into the lower 32 bits), so all
    /// links of one id occupy a contiguous key range of the sorted list.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the key arithmetic round-trips only when the second id is non-negative
    /// (a negative value borrows from the upper half) - ids are presumably non-negative; confirm.
    /// <see cref="FindLinksForId"/>, <see cref="Items"/>, <see cref="RemoveIndex"/> and
    /// <see cref="Add"/> take the internal reader/writer lock; the indexer, <see cref="Count"/>,
    /// <see cref="Links"/>, <see cref="Distance"/>, <see cref="CalculateHistogram"/>,
    /// <see cref="Serialize"/> and <see cref="Deserialize"/> access the list without locking.
    /// </remarks>
    internal sealed class CompactBiDirectionalLinksSet
        : IBinarySerializable, IDisposable
    {
        private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim();

        /// <summary>Number of bits one id occupies inside a packed 64-bit key.</summary>
        private const int HALF_LONG_BITS = 32;

        private SortedList<long, float> _set = new SortedList<long, float>();

        /// <summary>Underlying sorted list (packed key -> distance), exposed without locking.</summary>
        internal SortedList<long, float> Links => _set;

        /// <summary>Returns the pair of ids encoded in the key stored at position <paramref name="index"/>.</summary>
        internal (int, int) this[int index] => Unpack(_set.Keys[index]);

        /// <summary>Number of stored keys (a link between two different ids contributes two keys).</summary>
        internal int Count => _set.Count;

        /// <summary>Packs two ids into a single key: id1 in the upper half, id2 added to the lower half.</summary>
        private static long Pack(int id1, int id2) => (((long)id1) << HALF_LONG_BITS) + id2;

        /// <summary>Splits a packed key back into its two ids (inverse of <see cref="Pack"/>).</summary>
        private static (int, int) Unpack(long key)
        {
            var id1 = (int)(key >> HALF_LONG_BITS);
            var id2 = (int)(key - (((long)id1) << HALF_LONG_BITS));
            return (id1, id2);
        }

        /// <summary>Clears the lower half of a key, leaving only the part that identifies the first id.</summary>
        private static long Prefix(long key) => (key >> HALF_LONG_BITS) << HALF_LONG_BITS;

        /// <summary>
        /// Lazily enumerates the links of <paramref name="id"/> as (id1, id2, distance).
        /// </summary>
        /// <remarks>
        /// This is an iterator: the read lock is taken when enumeration starts and released
        /// when enumeration completes or the enumerator is disposed.
        /// </remarks>
        internal IEnumerable<(int, int, float)> FindLinksForId(int id)
        {
            _rwLock.EnterReadLock();
            try
            {
                if (_set.Count == 1)
                {
                    // Single record: compare both halves of the key directly instead of searching.
                    var (id1, id2) = Unpack(_set.Keys[0]);
                    var distance = _set.Values[0];
                    if (id1 == id)
                    {
                        yield return (id, id2, distance);
                    }
                    else if (id2 == id)
                    {
                        yield return (id1, id, distance);
                    }
                }
                else if (_set.Count > 1)
                {
                    // Only keys whose upper half equals id are returned here.
                    foreach (var (key, distance) in Search(_set, id))
                    {
                        var (id1, id2) = Unpack(key);
                        yield return (id1, id2, distance);
                    }
                }
            }
            finally
            {
                _rwLock.ExitReadLock();
            }
        }

        /// <summary>
        /// Lazily enumerates every stored record as (id1, id2, distance), in key order.
        /// The read lock is held for the duration of the enumeration.
        /// </summary>
        internal IEnumerable<(int, int, float)> Items()
        {
            _rwLock.EnterReadLock();
            try
            {
                foreach (var pair in _set)
                {
                    var (id1, id2) = Unpack(pair.Key);
                    yield return (id1, id2, pair.Value);
                }
            }
            finally
            {
                _rwLock.ExitReadLock();
            }
        }

        /// <summary>Removes both directions of the link between <paramref name="id1"/> and <paramref name="id2"/>, if present.</summary>
        internal void RemoveIndex(int id1, int id2)
        {
            long k1 = Pack(id1, id2);
            long k2 = Pack(id2, id1);
            _rwLock.EnterWriteLock();
            try
            {
                // Remove is a no-op returning false for a missing key, so no ContainsKey pre-check is needed.
                _set.Remove(k1);
                _set.Remove(k2);
            }
            finally
            {
                _rwLock.ExitWriteLock();
            }
        }

        /// <summary>
        /// Adds the link in both directions with the given distance.
        /// </summary>
        /// <returns><c>true</c> if the link was added; <c>false</c> if the key id1:id2 already exists.</returns>
        internal bool Add(int id1, int id2, float distance)
        {
            long k1 = Pack(id1, id2);
            long k2 = Pack(id2, id1);
            _rwLock.EnterWriteLock();
            try
            {
                if (_set.ContainsKey(k1))
                {
                    return false;
                }
                _set.Add(k1, distance);
                if (k1 != k2)
                {
                    // A self-link (id1 == id2) has a single key; otherwise store the mirrored key too.
                    _set.Add(k2, distance);
                }
                return true;
            }
            finally
            {
                _rwLock.ExitWriteLock();
            }
        }

        /*
        function binary_search(A, n, T) is
            L := 0
            R := n - 1
            while L <= R do
                m := floor((L + R) / 2)
                if A[m] < T then
                    L := m + 1
                else if A[m] > T then
                    R := m - 1
                else:
                    return m
            return unsuccessful
        */

        /// <summary>
        /// Binary-searches <paramref name="set"/> for any key whose upper half equals
        /// <paramref name="index"/> and returns all records sharing that upper half.
        /// </summary>
        /// <returns>The matching (key, distance) records, or an empty sequence when none exist.</returns>
        private static IEnumerable<(long, float)> Search(SortedList<long, float> set, int index)
        {
            long target = ((long)index) << HALF_LONG_BITS; // T
            int left = 0;
            int right = set.Count - 1;
            while (left <= right)
            {
                // Integer midpoint; equals floor((left + right) / 2) and cannot overflow.
                int mid = left + ((right - left) / 2);
                long probe = Prefix(set.Keys[mid]); // A[m]
                if (probe < target)
                {
                    left = mid + 1;
                }
                else if (probe > target)
                {
                    right = mid - 1;
                }
                else
                {
                    return SearchByPosition(set, target, mid);
                }
            }
            return Enumerable.Empty<(long, float)>();
        }

        /// <summary>
        /// Starting from a known matching <paramref name="position"/>, expands in both directions
        /// over neighbouring keys with the same upper half <paramref name="k"/> and yields
        /// the whole run in key order.
        /// </summary>
        private static IEnumerable<(long, float)> SearchByPosition(SortedList<long, float> set, long k, int position)
        {
            int start = position;
            while (start > 0 && Prefix(set.Keys[start - 1]) == k)
            {
                start--;
            }

            int end = position;
            while (end < set.Count - 1 && Prefix(set.Keys[end + 1]) == k)
            {
                end++;
            }

            for (int i = start; i <= end; i++)
            {
                yield return (set.Keys[i], set.Values[i]);
            }
        }

        /// <summary>Builds a histogram over all stored distance values using the given mode.</summary>
        public Histogram CalculateHistogram(HistogramMode mode)
        {
            return new Histogram(mode, _set.Values);
        }

        /// <summary>
        /// Returns the distance stored for the key id1:id2, or <see cref="float.MaxValue"/>
        /// when there is no such key. Does not take the lock.
        /// </summary>
        internal float Distance(int id1, int id2)
        {
            return _set.TryGetValue(Pack(id1, id2), out var distance)
                ? distance
                : float.MaxValue;
        }

        /// <summary>Releases the lock and drops the stored links. Safe to call more than once.</summary>
        public void Dispose()
        {
            _rwLock.Dispose();
            // Null-safe so that a repeated Dispose does not throw.
            _set?.Clear();
            _set = null;
        }

        /// <summary>
        /// Writes a <c>true</c> marker, the record count, and then each (key, distance) pair in key order.
        /// </summary>
        public void Serialize(IBinaryWriter writer)
        {
            writer.WriteBoolean(true); // true - set with weights
            writer.WriteInt32(_set.Count);
            foreach (var record in _set)
            {
                writer.WriteLong(record.Key);
                writer.WriteFloat(record.Value);
            }
        }

        /// <summary>
        /// Replaces the current content with records read from <paramref name="reader"/>
        /// (the format written by <see cref="Serialize"/>).
        /// </summary>
        /// <exception cref="InvalidOperationException">The leading marker is <c>false</c> (data without weights).</exception>
        public void Deserialize(IBinaryReader reader)
        {
            if (reader.ReadBoolean() == false)
            {
                throw new InvalidOperationException("Incompatible data format. The set does not contain weights.");
            }

            var count = reader.ReadInt32();
            // Fill a fresh list first so a failed read does not leave this instance without a list.
            var restored = new SortedList<long, float>(count + 1);
            for (int i = 0; i < count; i++)
            {
                var key = reader.ReadLong();
                var value = reader.ReadFloat();
                restored.Add(key, value);
            }

            _set?.Clear();
            _set = restored;
        }
    }
}