pull/1/head
unknown 3 years ago
parent 2d1e4f9d5b
commit ac20e9cacb

@ -1,5 +1,6 @@
using HNSWDemo.Tests; using HNSWDemo.Tests;
using System; using System;
using ZeroLevel.Services.Web;
namespace HNSWDemo namespace HNSWDemo
{ {
@ -7,6 +8,8 @@ namespace HNSWDemo
{ {
static void Main(string[] args) static void Main(string[] args)
{ {
var uri = new Uri("https://hack33d.ru/bpla/upload.php?path=128111&get=0J/QuNC70LjQv9C10L3QutC+INCS0LvQsNC00LjQvNC40YAg0JzQuNGF0LDQudC70L7QstC40Yc7MDQuMDkuMTk1NCAoNjYg0LvQtdGCKTvQnNC+0YHQutC+0LLRgdC60LDRjzsxMjgxMTE7TEFfUkVaVVM7RkxZXzAy");
var parts = UrlUtility.ParseQueryString(uri.Query);
new AutoClusteringMNISTTest().Run(); new AutoClusteringMNISTTest().Run();
//new HistogramTest().Run(); //new HistogramTest().Run();
Console.WriteLine("Completed"); Console.WriteLine("Completed");

@ -81,13 +81,30 @@ namespace HNSWDemo.Tests
var links = world.GetLinks().SelectMany(pair => pair.Value.Select(p=> distance(pair.Key, p))).ToList(); var links = world.GetLinks().SelectMany(pair => pair.Value.Select(p=> distance(pair.Key, p))).ToList();
var exists = links.Where(n => n > 0).ToArray(); var exists = links.Where(n => n > 0).ToArray();
var histogram = new Histogram(HistogramMode.SQRT, links); var histogram = new Histogram(HistogramMode.LOG, links);
DrawHistogram(histogram, @"D:\Mnist\histogram.jpg"); DrawHistogram(histogram, @"D:\Mnist\histogram.jpg");
var clusters = AutomaticGraphClusterer.DetectClusters(world); var clusters = AutomaticGraphClusterer.DetectClusters(world);
Console.WriteLine($"Found {clusters.Count} clusters"); Console.WriteLine($"Found {clusters.Count} clusters");
while (clusters.Count > 10)
{
var last = clusters[clusters.Count - 1];
var testDistance = clusters[0].MinDistance(distance, last);
var index = 0;
for (int i = 1; i < clusters.Count - 1; i++)
{
var d = clusters[i].MinDistance(distance, last);
if (d < testDistance)
{
testDistance = d;
index = i;
}
}
clusters[index].Merge(last);
clusters.RemoveAt(clusters.Count - 1);
}
for (int i = 0; i < clusters.Count; i++) for (int i = 0; i < clusters.Count; i++)
{ {
var ouput = Path.Combine(folder, i.ToString("D3")); var ouput = Path.Combine(folder, i.ToString("D3"));

@ -1,13 +1,89 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
namespace ZeroLevel.HNSW.Services namespace ZeroLevel.HNSW.Services
{ {
public static class AutomaticGraphClusterer public class Cluster
: IEnumerable<int>
{ {
private const int HALF_LONG_BITS = 32; private HashSet<int> _elements = new HashSet<int>();
public int Count => _elements.Count;
public bool Contains(int id) => _elements.Contains(id);
public bool Add(int id) => _elements.Add(id);
public IEnumerator<int> GetEnumerator()
{
return _elements.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return _elements.GetEnumerator();
}
public void Merge(Cluster cluster)
{
foreach (var e in cluster)
{
this._elements.Add(e);
}
}
public float MaxDistance(Func<int, int, float> distance, Cluster other)
{
var max = float.MinValue;
foreach (var e in this._elements)
{
foreach (var o in other)
{
var d = distance(e, o);
if (d > max)
{
max = d;
}
}
}
return max;
}
public float MinDistance(Func<int, int, float> distance, Cluster other)
{
var min = float.MaxValue;
foreach (var e in this._elements)
{
foreach (var o in other)
{
var d = distance(e, o);
if (d < min)
{
min = d;
}
}
}
return min;
}
public float AvgDistance(Func<int, int, float> distance, Cluster other)
{
var dist = new List<float>();
foreach (var e in this._elements)
{
foreach (var o in other)
{
dist.Add(distance(e, o));
}
}
return dist.Average();
}
}
public static class AutomaticGraphClusterer
{
private class Link private class Link
{ {
public int Id1; public int Id1;
@ -15,13 +91,13 @@ namespace ZeroLevel.HNSW.Services
public float Distance; public float Distance;
} }
public static List<HashSet<int>> DetectClusters<T>(SmallWorld<T> world) public static List<Cluster> DetectClusters<T>(SmallWorld<T> world)
{ {
var distance = world.DistanceFunction; var distance = world.DistanceFunction;
var links = world.GetLinks().SelectMany(pair => pair.Value.Select(id => new Link { Id1 = pair.Key, Id2 = id, Distance = distance(pair.Key, id) })).ToList(); var links = world.GetLinks().SelectMany(pair => pair.Value.Select(id => new Link { Id1 = pair.Key, Id2 = id, Distance = distance(pair.Key, id) })).ToList();
// 1. Find R - bound between intra-cluster distances and out-of-cluster distances // 1. Find R - bound between intra-cluster distances and out-of-cluster distances
var histogram = new Histogram(HistogramMode.SQRT, links.Select(l => l.Distance)); var histogram = new Histogram(HistogramMode.LOG, links.Select(l => l.Distance));
int threshold = histogram.CuttOff(); int threshold = histogram.CuttOff();
var min = histogram.Bounds[threshold - 1]; var min = histogram.Bounds[threshold - 1];
var max = histogram.Bounds[threshold]; var max = histogram.Bounds[threshold];
@ -39,7 +115,7 @@ namespace ZeroLevel.HNSW.Services
} }
// 3. Extract clusters // 3. Extract clusters
List<HashSet<int>> clusters = new List<HashSet<int>>(); List<Cluster> clusters = new List<Cluster>();
foreach (var l in resultLinks) foreach (var l in resultLinks)
{ {
var id1 = l.Id1; var id1 = l.Id1;
@ -62,7 +138,7 @@ namespace ZeroLevel.HNSW.Services
} }
if (found == false) if (found == false)
{ {
var c = new HashSet<int>(); var c = new Cluster();
c.Add(id1); c.Add(id1);
c.Add(id2); c.Add(id2);
clusters.Add(c); clusters.Add(c);

@ -7,7 +7,10 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="System.Numerics.Vectors" Version="4.5.0" /> <PackageReference Include="System.Numerics.Vectors" Version="4.5.0" />
<PackageReference Include="ZeroLevel" Version="3.3.5.6" /> </ItemGroup>
<ItemGroup>
<ProjectReference Include="..\ZeroLevel\ZeroLevel.csproj" />
</ItemGroup> </ItemGroup>
</Project> </Project>

Loading…
Cancel
Save

Powered by TurnKey Linux.