FastText update

pull/1/head
Ogoun 5 years ago
parent 21e57e3d5f
commit f25acac14f

@ -1,4 +1,5 @@
using Newtonsoft.Json;
using System;
using ZeroLevel;
using ZeroLevel.Logging;

@ -0,0 +1,73 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
namespace ZeroLevel.Services.Semantic.CValue
{
/// <summary>
/// A text document in the C-Value term-extraction pipeline: its location,
/// its split sentences, the tokens of each sentence, and the candidate
/// terms extracted from it.
/// </summary>
public class Document
{
    private String path;                       // file-system path of the source document
    private List<String> sentenceList;         // raw sentences; null until assigned
    private List<LinkedList<Token>> tokenList; // tokens per sentence; null until assigned
    private String name;                       // display name of the document
    private List<Term> termList;               // extracted terms (empty right after construction)

    public Document(String pPath, String pName)
    {
        path = pPath;
        name = pName;
        termList = new List<Term>();
    }

    public String getPath()
    {
        return path;
    }

    public void setPath(String path)
    {
        this.path = path;
    }

    public List<String> getSentenceList()
    {
        return sentenceList;
    }

    public void setSentenceList(List<String> sentenceList)
    {
        this.sentenceList = sentenceList;
    }

    public List<LinkedList<Token>> getTokenList()
    {
        return tokenList;
    }

    // Correctly named setter; the original accessor was mistakenly called "List".
    public void setTokenList(List<LinkedList<Token>> tokenList)
    {
        this.tokenList = tokenList;
    }

    // Kept so existing callers of the misnamed method keep compiling.
    [Obsolete("Use setTokenList instead.")]
    public void List(List<LinkedList<Token>> tokenList)
    {
        setTokenList(tokenList);
    }

    public String getName()
    {
        return name;
    }

    public void setName(String name)
    {
        this.name = name;
    }

    public List<Term> getTermList()
    {
        return termList;
    }

    public void setTermList(List<Term> termList)
    {
        this.termList = termList;
    }
}
}

@ -0,0 +1,72 @@
using System;
namespace ZeroLevel.Services.Semantic.CValue
{
/// <summary>
/// A candidate term with its C-Value score. Two terms are equal when their
/// text matches case-insensitively (ordinal); a score of -1 means "not yet scored".
/// </summary>
public class Term
{
    private String term;
    private float score;

    public Term()
    {
    }

    public Term(String pTerm)
    {
        term = pTerm;
        score = -1; // sentinel: not scored yet
    }

    public Term(String pTerm, float pScore)
    {
        term = pTerm;
        score = pScore;
    }

    public String getTerm()
    {
        return term;
    }

    public void setTerm(String term)
    {
        this.term = term;
    }

    public float getScore()
    {
        return score;
    }

    public void setScore(float score)
    {
        this.score = score;
    }

    public override string ToString()
    {
        return term + "\t" + score;
    }

    public override bool Equals(object obj)
    {
        return Equals(obj as Term);
    }

    private bool Equals(Term other)
    {
        if (other == null) return false;
        // string.Equals is null-safe (the instance method would throw if term is null).
        return string.Equals(this.term, other.term, StringComparison.OrdinalIgnoreCase);
    }

    public override int GetHashCode()
    {
        // Must agree with Equals: the original used the case-sensitive
        // term.GetHashCode(), so equal terms could land in different hash buckets.
        return term == null ? 0 : StringComparer.OrdinalIgnoreCase.GetHashCode(term);
    }
}
}

@ -0,0 +1,83 @@
using System;
namespace ZeroLevel.Services.Semantic.CValue
{
/// <summary>
/// A single token of a sentence: the surface word form plus optional
/// part-of-speech tag, chunker tag and lemma.
/// </summary>
public class Token
{
    private String wordForm;
    private String posTag;
    private String chunkerTag;
    private String lemma;
    private int pos; // position inside the sentence? (unused so far)

    // The shorter constructors simply leave the remaining annotations null.
    public Token(String pWordForm) : this(pWordForm, null, null, null)
    {
    }

    public Token(String pWordForm, String pPostag) : this(pWordForm, pPostag, null, null)
    {
    }

    public Token(String pWordForm, String pPostag, String pLemma) : this(pWordForm, pPostag, pLemma, null)
    {
    }

    public Token(String pWordForm, String pPostag, String pLemma, String pChunker)
    {
        wordForm = pWordForm;
        posTag = pPostag;
        lemma = pLemma;
        chunkerTag = pChunker;
    }

    public String getWordForm() => wordForm;

    public void setWordForm(String wordForm) => this.wordForm = wordForm;

    public String getPosTag() => posTag;

    public void setPosTag(String posTag) => this.posTag = posTag;

    /// <summary>Renders the token as "wordForm&lt;TAB&gt;posTag".</summary>
    public override string ToString() => wordForm + "\t" + posTag;

    public String getLemma() => lemma;

    public void setLemma(String lemma) => this.lemma = lemma;

    public String getChunkerTag() => chunkerTag;

    public void setChunkerTag(String chunkerTag) => this.chunkerTag = chunkerTag;
}

@ -1,8 +1,16 @@
namespace ZeroLevel.Services.Semantic.Fasttext
using System;
using System.Text;
using ZeroLevel.Services.Serialization;
namespace ZeroLevel.Services.Semantic.Fasttext
{
public class FTArgs
: IBinarySerializable
{
#region Args
public string input;
public string output;
public double lr;
public int lrUpdateRate;
public int dim;
@ -160,5 +168,370 @@
" -dsub size of each sub-vector [" + dsub + "]\n";
}
#endregion
/// <summary>
/// Parses fastText-style command line arguments into this instance.
/// args[1] selects the training mode ("supervised" or "cbow"); option/value
/// pairs start at args[2]. Boolean flags take no value (the index is stepped
/// back so the loop's += 2 advances by one). Numeric values are parsed with
/// the invariant culture so "0.1" reads the same on every machine.
/// </summary>
/// <exception cref="Exception">Thrown when the input or output path is missing.</exception>
public void parseArgs(params string[] args)
{
    var inv = System.Globalization.CultureInfo.InvariantCulture;
    var command = args[1];
    if (command.Equals("supervised", System.StringComparison.OrdinalIgnoreCase))
    {
        // Supervised defaults differ from the unsupervised ones.
        model = model_name.sup;
        loss = loss_name.softmax;
        minCount = 1;
        minn = 0;
        maxn = 0;
        lr = 0.1;
    }
    else if (command.Equals("cbow", System.StringComparison.OrdinalIgnoreCase))
    {
        model = model_name.cbow;
    }
    for (int ai = 2; ai < args.Length; ai += 2)
    {
        if (args[ai][0] != '-')
        {
            Log.Warning("Provided argument without a dash! Usage: " + printHelp());
        }
        try
        {
            if (args[ai].Equals("-h", System.StringComparison.OrdinalIgnoreCase))
            {
                Log.Warning("Here is the help! Usage: " + printHelp());
            }
            else if (args[ai].Equals("-input", System.StringComparison.OrdinalIgnoreCase))
            {
                input = args[ai + 1];
            }
            else if (args[ai].Equals("-output", System.StringComparison.OrdinalIgnoreCase))
            {
                output = args[ai + 1];
            }
            else if (args[ai].Equals("-lr", System.StringComparison.OrdinalIgnoreCase))
            {
                lr = double.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-lrUpdateRate", System.StringComparison.OrdinalIgnoreCase))
            {
                lrUpdateRate = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-dim", System.StringComparison.OrdinalIgnoreCase))
            {
                dim = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-ws", System.StringComparison.OrdinalIgnoreCase))
            {
                ws = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-epoch", System.StringComparison.OrdinalIgnoreCase))
            {
                epoch = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-minCount", System.StringComparison.OrdinalIgnoreCase))
            {
                minCount = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-minCountLabel", System.StringComparison.OrdinalIgnoreCase))
            {
                minCountLabel = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-neg", System.StringComparison.OrdinalIgnoreCase))
            {
                neg = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-wordNgrams", System.StringComparison.OrdinalIgnoreCase))
            {
                wordNgrams = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-loss", System.StringComparison.OrdinalIgnoreCase))
            {
                if (args[ai + 1].Equals("hs", System.StringComparison.OrdinalIgnoreCase))
                {
                    loss = loss_name.hs;
                }
                else if (args[ai + 1].Equals("ns", System.StringComparison.OrdinalIgnoreCase))
                {
                    loss = loss_name.ns;
                }
                else if (args[ai + 1].Equals("softmax", System.StringComparison.OrdinalIgnoreCase))
                {
                    loss = loss_name.softmax;
                }
                else
                {
                    loss = loss_name.ns;
                    Log.Warning("Unknown loss! Usage: " + printHelp());
                }
            }
            else if (args[ai].Equals("-bucket", System.StringComparison.OrdinalIgnoreCase))
            {
                bucket = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-minn", System.StringComparison.OrdinalIgnoreCase))
            {
                minn = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-maxn", System.StringComparison.OrdinalIgnoreCase))
            {
                maxn = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-thread", System.StringComparison.OrdinalIgnoreCase))
            {
                thread = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-t", System.StringComparison.OrdinalIgnoreCase))
            {
                t = double.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-label", System.StringComparison.OrdinalIgnoreCase))
            {
                label = args[ai + 1];
            }
            else if (args[ai].Equals("-verbose", System.StringComparison.OrdinalIgnoreCase))
            {
                verbose = int.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-pretrainedVectors", System.StringComparison.OrdinalIgnoreCase))
            {
                pretrainedVectors = args[ai + 1];
            }
            else if (args[ai].Equals("-saveOutput", System.StringComparison.OrdinalIgnoreCase))
            {
                saveOutput = true;
                ai--; // flag without a value
            }
            else if (args[ai].Equals("-qnorm", System.StringComparison.OrdinalIgnoreCase))
            {
                qnorm = true;
                ai--; // flag without a value
            }
            else if (args[ai].Equals("-retrain", System.StringComparison.OrdinalIgnoreCase))
            {
                retrain = true;
                ai--; // flag without a value
            }
            else if (args[ai].Equals("-qout", System.StringComparison.OrdinalIgnoreCase))
            {
                qout = true;
                ai--; // flag without a value
            }
            else if (args[ai].Equals("-cutoff", System.StringComparison.OrdinalIgnoreCase))
            {
                cutoff = ulong.Parse(args[ai + 1], inv);
            }
            else if (args[ai].Equals("-dsub", System.StringComparison.OrdinalIgnoreCase))
            {
                // Was a case-sensitive "==" comparison, unlike every other option.
                dsub = ulong.Parse(args[ai + 1], inv);
            }
            else
            {
                Log.Warning("Unknown argument: " + args[ai] + "! Usage: " + printHelp());
            }
        }
        catch (Exception ex)
        {
            Log.Error(ex, "");
        }
    }
    if (string.IsNullOrWhiteSpace(input) || string.IsNullOrWhiteSpace(output))
    {
        throw new Exception("Empty input or output path.\r\n" + printHelp());
    }
    if (wordNgrams <= 1 && maxn == 0)
    {
        // Neither word n-grams nor char n-grams are used: no hash buckets needed.
        bucket = 0;
    }
}
/// <summary>
/// Reads fastText hyper-parameters from an IConfiguration. Every option is
/// accepted both with and without a leading dash (the original was missing
/// the dash variants of qout/cutoff/dsub and contained a dead empty case).
/// </summary>
/// <exception cref="Exception">Thrown when the input or output path is missing.</exception>
public void parseArgs(IConfiguration config)
{
    if (config.Contains("supervised"))
    {
        // Supervised defaults differ from the unsupervised ones.
        model = model_name.sup;
        loss = loss_name.softmax;
        minCount = 1;
        minn = 0;
        maxn = 0;
        lr = 0.1;
    }
    else if (config.Contains("cbow"))
    {
        model = model_name.cbow;
    }
    foreach (var key in config.Keys)
    {
        switch (key)
        {
            case "input":
            case "-input":
                input = config.First(key);
                break;
            case "output":
            case "-output":
                output = config.First(key);
                break;
            case "lr":
            case "-lr":
                lr = config.First<double>(key);
                break;
            case "lrUpdateRate":
            case "-lrUpdateRate":
                lrUpdateRate = config.First<int>(key);
                break;
            case "dim":
            case "-dim":
                dim = config.First<int>(key);
                break;
            case "ws":
            case "-ws":
                ws = config.First<int>(key);
                break;
            case "epoch":
            case "-epoch":
                epoch = config.First<int>(key);
                break;
            case "minCount":
            case "-minCount":
                minCount = config.First<int>(key);
                break;
            case "minCountLabel":
            case "-minCountLabel":
                minCountLabel = config.First<int>(key);
                break;
            case "neg":
            case "-neg":
                neg = config.First<int>(key);
                break;
            case "wordNgrams":
            case "-wordNgrams":
                wordNgrams = config.First<int>(key);
                break;
            case "loss":
            case "-loss":
                switch (config.First(key))
                {
                    case "hs": loss = loss_name.hs; break;
                    case "ns": loss = loss_name.ns; break;
                    case "softmax": loss = loss_name.softmax; break;
                    default: Log.Warning("Unknown loss! Usage: " + printHelp()); break;
                }
                break;
            case "bucket":
            case "-bucket":
                bucket = config.First<int>(key);
                break;
            case "minn":
            case "-minn":
                minn = config.First<int>(key);
                break;
            case "maxn":
            case "-maxn":
                maxn = config.First<int>(key);
                break;
            case "thread":
            case "-thread":
                thread = config.First<int>(key);
                break;
            case "t":
            case "-t":
                t = config.First<double>(key);
                break;
            case "label":
            case "-label":
                label = config.First(key);
                break;
            case "verbose":
            case "-verbose":
                verbose = config.First<int>(key);
                break;
            case "pretrainedVectors":
            case "-pretrainedVectors":
                pretrainedVectors = config.First(key);
                break;
            case "saveOutput":
            case "-saveOutput":
                saveOutput = true;
                break;
            case "qnorm":
            case "-qnorm":
                qnorm = true;
                break;
            case "retrain":
            case "-retrain":
                retrain = true;
                break;
            case "qout":
            case "-qout":
                qout = true;
                break;
            case "cutoff":
            case "-cutoff":
                cutoff = config.First<ulong>(key);
                break;
            case "dsub":
            case "-dsub":
                dsub = config.First<ulong>(key);
                break;
        }
    }
    if (string.IsNullOrWhiteSpace(input) || string.IsNullOrWhiteSpace(output))
    {
        throw new Exception("Empty input or output path.\r\n" + printHelp());
    }
    if (wordNgrams <= 1 && maxn == 0)
    {
        // Neither word n-grams nor char n-grams are used: no hash buckets needed.
        bucket = 0;
    }
}
/// <summary>
/// Writes the hyper-parameters needed to reload a trained model.
/// The field order here MUST stay in sync with Deserialize.
/// Training-only settings (lr, thread, input/output paths, ...) are
/// intentionally not persisted.
/// </summary>
public void Serialize(IBinaryWriter writer)
{
writer.WriteInt32(dim);
writer.WriteInt32(ws);
writer.WriteInt32(epoch);
writer.WriteInt32(minCount);
writer.WriteInt32(neg);
writer.WriteInt32(wordNgrams);
// Enums are persisted by their underlying int value.
writer.WriteInt32((int)loss);
writer.WriteInt32((int)model);
writer.WriteInt32(bucket);
writer.WriteInt32(minn);
writer.WriteInt32(maxn);
writer.WriteInt32(lrUpdateRate);
writer.WriteDouble(t);
}
/// <summary>
/// Restores the hyper-parameters written by Serialize.
/// The read order here MUST stay in sync with Serialize.
/// </summary>
public void Deserialize(IBinaryReader reader)
{
dim = reader.ReadInt32();
ws = reader.ReadInt32();
epoch = reader.ReadInt32();
minCount = reader.ReadInt32();
neg = reader.ReadInt32();
wordNgrams = reader.ReadInt32();
// Enums were persisted by their underlying int value.
loss = (loss_name)reader.ReadInt32();
model = (model_name)reader.ReadInt32();
bucket = reader.ReadInt32();
minn = reader.ReadInt32();
maxn = reader.ReadInt32();
lrUpdateRate = reader.ReadInt32();
t = reader.ReadDouble();
}
/// <summary>
/// Renders the persisted hyper-parameters as one "name value" line each,
/// matching the order used by Serialize.
/// </summary>
public string dump()
{
    var sb = new StringBuilder();
    sb.AppendLine($"dim {dim}")
      .AppendLine($"ws {ws}")
      .AppendLine($"epoch {epoch}")
      .AppendLine($"minCount {minCount}")
      .AppendLine($"neg {neg}")
      .AppendLine($"wordNgrams {wordNgrams}")
      .AppendLine($"loss {lossToString(loss)}")
      .AppendLine($"model {modelToString(model)}")
      .AppendLine($"bucket {bucket}")
      .AppendLine($"minn {minn}")
      .AppendLine($"maxn {maxn}")
      .AppendLine($"lrUpdateRate {lrUpdateRate}")
      .AppendLine($"t {t}");
    return sb.ToString();
}
}
}

@ -0,0 +1,487 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
namespace ZeroLevel.Services.Semantic.Fasttext
{
// One vocabulary entry of the fastText dictionary.
internal class FTEntry
{
// Surface form of the token (word or label).
public string word;
// Number of occurrences accumulated while building the vocabulary.
public long count;
// Whether this entry is a regular word or a label (see entry_type).
public entry_type type;
// Character n-gram bucket ids for this word; assigned by FTDictionary.initNgrams.
public List<int> subwords;
}
/// <summary>
/// Vocabulary of a fastText model: maps words/labels to ids via an
/// open-addressing hash table, tracks counts, computes character n-gram
/// subwords and performs sub-sampling. Ported from fastText's dictionary.cc;
/// the original C# file still contained raw C++ in several methods
/// (readWord, readFromFile, threshold, getLine) which are ported here.
/// </summary>
internal class FTDictionary
{
    // Capacity limits taken from the reference fastText implementation.
    const int MAX_VOCAB_SIZE = 30000000;
    const int MAX_LINE_SIZE = 1024;

    const string EOS = "</s>"; // end-of-sentence pseudo-token
    const string BOW = "<";    // begin-of-word marker for char n-grams
    const string EOW = ">";    // end-of-word marker for char n-grams

    private readonly FTArgs _args;
    // Open-addressing hash table: slot -> index into 'words' (-1 = free slot).
    private List<int> word2int;
    // Vocabulary entries; after threshold() words come first, labels after.
    private List<FTEntry> words;

    float[] pdiscard;   // per-entry sub-sampling keep thresholds (initTableDiscard)
    int size;           // total number of entries (nwords + nlabels)
    int nwords;
    int nlabels;
    long ntokens;       // total tokens consumed from the training stream

    long pruneidx_size; // -1: no pruning; 0: prune everything; >0: remap via pruneidx
    Dictionary<int, int> pruneidx;

    public FTDictionary(FTArgs args)
    {
        _args = args;
        words = new List<FTEntry>();
        pruneidx = new Dictionary<int, int>();
        pdiscard = new float[0];
        // The table must be pre-sized and filled with -1: the original left it
        // empty, so find() would have taken a modulo by zero.
        word2int = new List<int>(MAX_VOCAB_SIZE);
        for (var i = 0; i < MAX_VOCAB_SIZE; i++)
        {
            word2int.Add(-1);
        }
        size = 0;
        nwords = 0;
        nlabels = 0;
        ntokens = 0;
        pruneidx_size = -1;
    }

    public FTDictionary(FTArgs args, Stream stream)
        : this(args)
    {
        load(stream);
    }

    // NOTE(review): the original called a load() that did not exist; binary
    // deserialization of a saved dictionary is still TODO.
    private void load(Stream stream)
    {
        throw new NotImplementedException("FTDictionary.load is not implemented yet");
    }

    public int find(string w) => find(w, hash(w));

    // Linear probing: returns the slot occupied by w, or the first free slot
    // where w would be inserted.
    public int find(string w, uint h)
    {
        int word2intsize = word2int.Count;
        int id = (int)(h % word2intsize);
        while (word2int[id] != -1 && words[word2int[id]].word != w)
        {
            id = (id + 1) % word2intsize;
        }
        return id;
    }

    // Registers one occurrence of w, creating the entry on first sight.
    public void add(string w)
    {
        int h = find(w);
        ntokens++;
        if (word2int[h] == -1)
        {
            var e = new FTEntry
            {
                word = w,
                count = 1,
                type = getType(w)
            };
            words.Add(e);
            word2int[h] = size++;
        }
        else
        {
            words[word2int[h]].count++;
        }
    }

    public List<int> getSubwords(int id)
    {
        // The original check was inverted ("id >= 0 || id < nwords") and
        // threw for every valid id.
        if (id < 0 || id >= nwords)
        {
            throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {nwords}");
        }
        return words[id].subwords;
    }

    public List<int> getSubwords(string word)
    {
        int i = getId(word);
        if (i >= 0)
        {
            return getSubwords(i);
        }
        // Out-of-vocabulary word: represented only by its character n-grams.
        var ngrams = new List<int>();
        if (word != EOS)
        {
            computeSubwords(BOW + word + EOW, ngrams);
        }
        return ngrams;
    }

    public void getSubwords(string word,
                            List<int> ngrams,
                            List<string> substrings)
    {
        int i = getId(word);
        ngrams.Clear();
        substrings.Clear();
        if (i >= 0)
        {
            ngrams.Add(i);
            substrings.Add(words[i].word);
        }
        if (word != EOS)
        {
            computeSubwords(BOW + word + EOW, ngrams, substrings);
        }
    }

    // True when the word should be dropped by sub-sampling; rand is expected
    // to be uniform in [0,1).
    public bool discard(int id, float rand)
    {
        // The original check was inverted and threw for every valid id.
        if (id < 0 || id >= nwords)
        {
            throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {nwords}");
        }
        if (_args.model == model_name.sup) return false;
        return rand > pdiscard[id];
    }

    // FNV-1a over the UTF-16 code units (matches fastText for ASCII input).
    public uint hash(string str)
    {
        uint h = 2166136261;
        for (var i = 0; i < str.Length; i++)
        {
            h = h ^ str[i];
            h = h * 16777619;
        }
        return h;
    }

    public int getId(string w, uint h)
    {
        int id = find(w, h);
        return word2int[id];
    }

    public int getId(string w)
    {
        int h = find(w);
        return word2int[h];
    }

    public entry_type getType(int id)
    {
        // The original check was inverted and threw for every valid id.
        if (id < 0 || id >= size)
        {
            throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {size}");
        }
        return words[id].type;
    }

    // A token is a label iff it starts with the configured label prefix.
    // StartsWith(Ordinal) replaces the culture-sensitive IndexOf of the original.
    public entry_type getType(string w)
    {
        return w.StartsWith(_args.label, StringComparison.Ordinal)
            ? entry_type.label
            : entry_type.word;
    }

    public string getWord(int id)
    {
        // The original check was inverted and threw for every valid id.
        if (id < 0 || id >= size)
        {
            throw new IndexOutOfRangeException($"Id ({id}) must be between 0 and {size}");
        }
        return words[id].word;
    }

    // Collects all character n-grams of 'word' between minn and maxn chars.
    // NOTE(review): the 0xC0/0x80 masks are UTF-8 continuation-byte checks
    // applied to UTF-16 chars; kept as in the original port, which only
    // matches fastText exactly when each char corresponds to a single byte.
    public void computeSubwords(string word, List<int> ngrams, List<string> substrings)
    {
        for (var i = 0; i < word.Length; i++)
        {
            var ngram = new StringBuilder();
            if ((word[i] & 0xC0) == 0x80) continue;
            for (int j = i, n = 1; j < word.Length && n <= _args.maxn; n++)
            {
                ngram.Append(word[j++]);
                while (j < word.Length && (word[j] & 0xC0) == 0x80)
                {
                    ngram.Append(word[j++]);
                }
                if (n >= _args.minn && !(n == 1 && (i == 0 || j == word.Length)))
                {
                    var sw = ngram.ToString();
                    var h = hash(sw) % _args.bucket;
                    ngrams.Add((int)(nwords + h));
                    substrings.Add(sw);
                }
            }
        }
    }

    // Same as above, but resolves each n-gram through pushHash (pruning-aware).
    public void computeSubwords(string word, List<int> ngrams)
    {
        for (var i = 0; i < word.Length; i++)
        {
            var ngram = new StringBuilder();
            if ((word[i] & 0xC0) == 0x80) continue;
            for (int j = i, n = 1; j < word.Length && n <= _args.maxn; n++)
            {
                ngram.Append(word[j++]);
                while (j < word.Length && (word[j] & 0xC0) == 0x80)
                {
                    ngram.Append(word[j++]);
                }
                if (n >= _args.minn && !(n == 1 && (i == 0 || j == word.Length)))
                {
                    var sw = ngram.ToString();
                    var h = (int)(hash(sw) % _args.bucket);
                    pushHash(ngrams, h);
                }
            }
        }
    }

    // Maps an n-gram bucket id into 'hashes', honoring quantization pruning.
    public void pushHash(List<int> hashes, int id)
    {
        if (pruneidx_size == 0 || id < 0) return;
        if (pruneidx_size > 0)
        {
            if (pruneidx.ContainsKey(id))
            {
                id = pruneidx[id];
            }
            else
            {
                return;
            }
        }
        hashes.Add(nwords + id);
    }

    // Rewinds the stream to the beginning (requires a seekable stream).
    public void reset(Stream stream)
    {
        if (stream.Position > 0)
        {
            stream.Position = 0;
        }
    }

    public string getLabel(int lid)
    {
        if (lid < 0 || lid >= nlabels)
        {
            throw new Exception($"Label id is out of range [0, {nlabels}]");
        }
        // Labels are stored after the nwords word entries.
        return words[lid + nwords].word;
    }

    // (Re)computes the subword id list of every vocabulary entry.
    public void initNgrams()
    {
        for (var i = 0; i < size; i++)
        {
            string word = BOW + words[i].word + EOW;
            // add() leaves subwords null, so allocate instead of Clear().
            words[i].subwords = new List<int> { i };
            if (words[i].word != EOS)
            {
                computeSubwords(word, words[i].subwords);
            }
        }
    }

    // Reads the next whitespace-delimited token from the stream byte-wise.
    // Emits EOS for an end of line; returns false only at end of stream.
    // Ported from the C++ original (which used streambuf::sbumpc/sungetc).
    public bool readWord(Stream stream, out string word)
    {
        var sb = new StringBuilder();
        int c;
        while ((c = stream.ReadByte()) != -1)
        {
            if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
                c == '\f' || c == '\0')
            {
                if (sb.Length == 0)
                {
                    if (c == '\n')
                    {
                        word = EOS;
                        return true;
                    }
                    continue; // skip leading whitespace
                }
                if (c == '\n')
                {
                    // Push the newline back so the next call emits EOS.
                    stream.Seek(-1, SeekOrigin.Current);
                }
                word = sb.ToString();
                return true;
            }
            sb.Append((char)c);
        }
        word = sb.ToString();
        return sb.Length > 0;
    }

    // Builds the vocabulary from a whitespace-tokenized training stream.
    // Ported from the C++ original.
    public void readFromFile(Stream stream)
    {
        string word;
        long minThreshold = 1;
        while (readWord(stream, out word))
        {
            add(word);
            // Keep the table sparse enough: progressively drop rare entries.
            if (size > 0.75 * MAX_VOCAB_SIZE)
            {
                minThreshold++;
                threshold(minThreshold, minThreshold);
            }
        }
        threshold(_args.minCount, _args.minCountLabel);
        initTableDiscard();
        initNgrams();
        if (size == 0)
        {
            throw new InvalidOperationException("Empty vocabulary. Try a smaller -minCount value.");
        }
    }

    // Drops words with count < t and labels with count < tl, then rebuilds
    // the lookup table. Ported from the C++ original.
    public void threshold(long t, long tl)
    {
        // Sort by type (words before labels), then by descending count.
        words.Sort((e1, e2) =>
        {
            if (e1.type != e2.type) return e1.type.CompareTo(e2.type);
            return e2.count.CompareTo(e1.count);
        });
        words.RemoveAll(e => (e.type == entry_type.word && e.count < t) ||
                             (e.type == entry_type.label && e.count < tl));
        words.TrimExcess();
        size = 0;
        nwords = 0;
        nlabels = 0;
        for (var i = 0; i < word2int.Count; i++)
        {
            word2int[i] = -1;
        }
        foreach (var e in words)
        {
            int h = find(e.word);
            word2int[h] = size++;
            if (e.type == entry_type.word) nwords++;
            if (e.type == entry_type.label) nlabels++;
        }
    }

    // Precomputes per-entry sub-sampling thresholds: sqrt(t/f) + t/f.
    public void initTableDiscard()
    {
        pdiscard = new float[size];
        for (var i = 0; i < size; i++)
        {
            var f = ((float)words[i].count) / (float)(ntokens);
            pdiscard[i] = (float)(Math.Sqrt(_args.t / f) + _args.t / f);
        }
    }

    public List<long> getCounts(entry_type type)
    {
        var counts = new List<long>();
        foreach (var w in words)
        {
            if (w.type == type) counts.Add(w.count);
        }
        return counts;
    }

    // Appends hashed word n-gram features built from consecutive word hashes.
    public void addWordNgrams(List<int> line, List<int> hashes, int n)
    {
        for (var i = 0; i < hashes.Count; i++)
        {
            var h = hashes[i];
            for (var j = i + 1; j < hashes.Count && j < i + n; j++)
            {
                h = h * 116049371 + hashes[j];
                pushHash(line, h % _args.bucket);
            }
        }
    }

    public void addSubwords(List<int> line, string token, int wid)
    {
        if (wid < 0)
        {
            // Out of vocabulary: represented only by its character n-grams.
            if (token != EOS)
            {
                computeSubwords(BOW + token + EOW, line);
            }
        }
        else if (_args.maxn <= 0)
        {
            // In vocabulary, subwords disabled.
            line.Add(wid);
        }
        else
        {
            // In vocabulary with subwords.
            line.AddRange(getSubwords(wid));
        }
    }

    // Reads one line for unsupervised training, applying sub-sampling.
    // Returns the number of tokens consumed. Ported from the C++ original.
    public int getLine(Stream stream, List<int> words, Random rng)
    {
        string token;
        int ntokens = 0;
        reset(stream);
        words.Clear();
        while (readWord(stream, out token))
        {
            int h = find(token);
            int wid = word2int[h];
            if (wid < 0) continue;
            ntokens++;
            if (getType(wid) == entry_type.word && !discard(wid, (float)rng.NextDouble()))
            {
                words.Add(wid);
            }
            if (ntokens > MAX_LINE_SIZE || token == EOS) break;
        }
        return ntokens;
    }

    // Reads one line for supervised training: fills word/subword ids and
    // label ids, then appends word n-gram features. Ported from the C++ original.
    public int getLine(Stream stream, List<int> words, List<int> labels)
    {
        var word_hashes = new List<int>();
        string token;
        int ntokens = 0;
        reset(stream);
        words.Clear();
        labels.Clear();
        while (readWord(stream, out token))
        {
            uint h = hash(token);
            int wid = getId(token, h);
            entry_type type = wid < 0 ? getType(token) : getType(wid);
            ntokens++;
            if (type == entry_type.word)
            {
                addSubwords(words, token, wid);
                word_hashes.Add((int)h);
            }
            else if (type == entry_type.label && wid >= 0)
            {
                // Label ids are 0-based offsets past the word entries.
                labels.Add(wid - nwords);
            }
            if (token == EOS) break;
        }
        addWordNgrams(words, word_hashes, _args.wordNgrams);
        return ntokens;
    }
}
}

@ -2,4 +2,5 @@
{
public enum model_name : int { cbow = 1, sg, sup };
public enum loss_name : int { hs = 1, ns, softmax };
public enum entry_type : byte { word=0, label=1};
}

Loading…
Cancel
Save

Powered by TurnKey Linux.