using CNTK;
using NeuralNetwork.Model.Neutral.SciSharp;
using NeuralNetwork.Model.Neutral.SciSharp.CNTK;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace NeuralNetwork.Model.Neutral.CNTK
{
/// <summary>
/// A single labelled training sample: the digit it represents plus its flattened image data.
/// </summary>
public class Digit
{
    /// <summary>The digit class of this sample.</summary>
    public int Label;

    /// <summary>
    /// Flattened image values. The network in this file is built for 784 (= 28 * 28) inputs.
    /// </summary>
    public float[] Image;
}
/// <summary>
/// Two-layer fully connected CNTK classifier:
/// input (784) -> dense (1500) + ReLU -> dense (10 logits).
/// Trained with plain SGD on softmax cross-entropy as soon as the object is constructed.
/// </summary>
public class NeutralNetwork
{
    readonly int inputDim = 784;      // 28 * 28 source image size
    readonly int outputDim = 10;      // number of digit classes
    readonly int hiddenDim = 1500;    // width of the single hidden layer
    readonly int minibatchSize = 64;
    readonly int numMinibatchesToTrain = 2000;
    readonly double learningRatePerSample = 0.02;
    readonly DeviceDescriptor device = DeviceDescriptor.CPUDevice;

    // Public fields are kept as-is: MNIST.GetNumber reads them to run inference.
    public NDShape inputShape;
    readonly NDShape outputShape;
    public Variable features;
    readonly Variable label;
    readonly Parameter W1;
    readonly Parameter b1;
    readonly Parameter W2;
    readonly Parameter b2;
    public Function z;
    readonly Function loss;
    readonly Function evalError;

    /// <summary>
    /// Builds the network graph and immediately trains it on <paramref name="trainData"/>.
    /// </summary>
    /// <param name="trainData">Labelled samples; each Image must hold <c>inputDim</c> values.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="trainData"/> is null.</exception>
    /// <exception cref="System.ArgumentException"><paramref name="trainData"/> is empty.</exception>
    public NeutralNetwork(Digit[] trainData)
    {
        // Validate before building the (expensive) graph rather than failing afterwards.
        if (trainData == null)
            throw new System.ArgumentNullException(nameof(trainData));
        if (trainData.Length == 0)
            throw new System.ArgumentException("Training data must contain at least one sample.", nameof(trainData));

        inputShape = new NDShape(1, inputDim);
        outputShape = new NDShape(1, outputDim);
        features = Variable.InputVariable(inputShape, DataType.Float);
        label = Variable.InputVariable(outputShape, DataType.Float);

        W1 = new Parameter(new int[] { hiddenDim, inputDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w1");
        b1 = new Parameter(new int[] { hiddenDim }, DataType.Float, 0, device, "b1");
        W2 = new Parameter(new int[] { outputDim, hiddenDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w2");
        b2 = new Parameter(new int[] { outputDim }, DataType.Float, 0, device, "b2");

        // z produces raw logits; the softmax is folded into the loss below.
        z = CNTKLib.Times(W2, CNTKLib.ReLU(CNTKLib.Times(W1, features) + b1)) + b2;
        loss = CNTKLib.CrossEntropyWithSoftmax(z, label);
        evalError = CNTKLib.ClassificationError(z, label);

        Train(trainData);
    }

    /// <summary>
    /// Runs <c>numMinibatchesToTrain</c> SGD steps over minibatches drawn from <paramref name="trainData"/>.
    /// </summary>
    private void Train(Digit[] trainData)
    {
        // Minibatch size 1 in the schedule makes the rate a per-sample rate.
        var learningRate = new TrainingParameterScheduleDouble(learningRatePerSample, 1);
        IList<Learner> parameterLearners = new List<Learner>
        {
            Learner.SGDLearner(z.Parameters(), learningRate)
        };

        // Trainer and Value wrap native CNTK objects; dispose them so native memory is released
        // deterministically instead of waiting for finalizers.
        using (var trainer = Trainer.CreateTrainer(z, loss, evalError, parameterLearners))
        {
            var featureBatches = new BatchSource<float[]>(trainData.Select(d => d.Image).ToArray(), minibatchSize);
            var labelBatches = new BatchSource<float[]>(
                trainData.Select(d => d.Label.ToOneHot10(outputDim).ToFloatArray()).ToArray(),
                minibatchSize);

            for (int step = 0; step < numMinibatchesToTrain; step++)
            {
                // Both sources hold the same number of items with the same batch size,
                // so advancing them together keeps images and labels aligned.
                featureBatches.MoveNext();
                labelBatches.MoveNext();

                using (Value featureValue = Value.CreateBatchOfSequences<float>(inputShape, featureBatches.Current, device))
                using (Value labelValue = Value.CreateBatchOfSequences<float>(outputShape, labelBatches.Current, device))
                {
                    var arguments = new Dictionary<Variable, Value>
                    {
                        { features, featureValue },
                        { label, labelValue }
                    };
                    trainer.TrainMinibatch(arguments, false, device);
                }
            }
        }
    }
}
/// <summary>
/// Static facade holding one trained <see cref="NeutralNetwork"/>. Use it to classify images and
/// to append newly labelled samples to a CSV file.
/// </summary>
public static class MNIST
{
    private static string BaseDataPath;
    private static string LearnDataPath;
    private static readonly DeviceDescriptor device = DeviceDescriptor.CPUDevice;
    private static NeutralNetwork network;

    /// <summary>
    /// Stores the data paths, then loads <paramref name="baseDataPath"/> and trains the network on it.
    /// </summary>
    /// <param name="baseDataPath">CSV file with rows of the form <c>label,v1,v2,...</c>.</param>
    /// <param name="learnDataPath">CSV file that <see cref="WriteNumber"/> appends to; null or empty disables recording.</param>
    /// <exception cref="FileNotFoundException"><paramref name="baseDataPath"/> does not exist.</exception>
    public static void Init(string baseDataPath, string learnDataPath)
    {
        BaseDataPath = baseDataPath;
        LearnDataPath = learnDataPath;
        network = new NeutralNetwork(LoadData(BaseDataPath));
    }

    /// <summary>
    /// Classifies one image and returns the index of the highest network output.
    /// </summary>
    /// <param name="input">Flattened image values, scaled the same way as the training data.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
    public static int GetNumber(float[] input)
    {
        if (input == null)
            throw new System.ArgumentNullException(nameof(input));

        // Take a local copy so a concurrent Init cannot swap the network mid-evaluation.
        NeutralNetwork net = network;
        if (net == null)
            throw new System.InvalidOperationException("MNIST.Init must be called before GetNumber.");

        // Both Values wrap native CNTK buffers; dispose them so repeated calls do not leak.
        using (Value inputValue = Value.CreateBatch(net.inputShape, input, device))
        {
            var inputs = new Dictionary<Variable, Value> { { net.features, inputValue } };
            var outputs = new Dictionary<Variable, Value> { { net.z, null } };
            net.z.Evaluate(inputs, outputs, device);

            using (Value outputValue = outputs[net.z])
            {
                // GetDenseData copies into managed lists, so the result outlives the Value.
                var scores = outputValue.GetDenseData<float>(net.z).First();
                return scores.MaxIndex();
            }
        }
    }

    /// <summary>
    /// Reads <paramref name="path"/> as comma-separated rows: the first column is the label and the
    /// remaining columns are image values, each divided by 256. Blank lines are skipped.
    /// </summary>
    /// <exception cref="FileNotFoundException">The file does not exist.</exception>
    private static Digit[] LoadData(string path)
    {
        if (!File.Exists(path))
            throw new FileNotFoundException($"File {path} not found.", path);

        return File.ReadLines(path)
            .Where(line => !string.IsNullOrWhiteSpace(line))
            .Select(ParseDigit)
            .ToArray();
    }

    /// <summary>Parses one CSV row into a <see cref="Digit"/>.</summary>
    private static Digit ParseDigit(string line)
    {
        // Materialize once: the label and the pixels both come from the same parsed row.
        // Invariant culture keeps parsing independent of the machine's locale settings.
        float[] values = line.Split(',')
            .Select(s => float.Parse(s, System.Globalization.CultureInfo.InvariantCulture))
            .ToArray();

        return new Digit
        {
            Label = (int)values[0],
            Image = values.Skip(1).Select(v => v / 256f).ToArray()
        };
    }

    /// <summary>
    /// Appends one labelled sample to the learn-data file in the same CSV layout that
    /// <see cref="LoadData"/> reads (each value multiplied by 256 and rounded).
    /// Does nothing when no learn-data path was configured.
    /// </summary>
    /// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null while recording is enabled.</exception>
    public static void WriteNumber(float[] input, int number)
    {
        if (string.IsNullOrEmpty(LearnDataPath))
            return;
        if (input == null)
            throw new System.ArgumentNullException(nameof(input));

        using (StreamWriter writer = new StreamWriter(LearnDataPath, true))
        {
            writer.Write(number);
            foreach (float value in input)
            {
                writer.Write(',');
                writer.Write(ToStoredValue(value));
            }
            writer.WriteLine();
        }
    }

    /// <summary>Scales a value back to the 0..255 integer range used in the CSV file.</summary>
    private static int ToStoredValue(float value)
    {
        double scaled = System.Math.Round(value * 256f);
        if (double.IsNaN(scaled))
            return 0;
        // Clamp instead of the previous unchecked byte cast, which wrapped 1.0f (-> 256) to 0.
        return (int)System.Math.Max(0.0, System.Math.Min(255.0, scaled));
    }
}
}