// MNIST.cs

using CNTK;
using NeuralNetwork.Model.Neutral.SciSharp;
using NeuralNetwork.Model.Neutral.SciSharp.CNTK;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace NeuralNetwork.Model.Neutral.CNTK
{
    /// <summary>
    /// One labelled training sample: flattened pixel values plus the digit they depict.
    /// </summary>
    public class Digit
    {
        /// <summary>Flattened pixel intensities for one image (as produced by the loader).</summary>
        public float[] Image;
        /// <summary>Class label, i.e. the digit the image shows.</summary>
        public int Label;
    }

    /// <summary>
    /// A fully connected network with one ReLU hidden layer, built with CNTK to classify
    /// flattened 28x28 images into 10 digit classes. The constructor builds the graph and
    /// immediately trains it with plain SGD on the supplied samples.
    /// </summary>
    public class NeutralNetwork
    {
        readonly int inputDim = 784; // 28 * 28 source image size, flattened
        readonly int outputDim = 10; // number of digit classes

        readonly int minibatchSize = 64;
        // Number of minibatch updates performed (steps, not full passes over the data).
        readonly int numMinibatchesToTrain = 2000;

        readonly DeviceDescriptor device = DeviceDescriptor.CPUDevice;

        /// <summary>Shape of a single input sample: {inputDim}.</summary>
        public NDShape inputShape;
        readonly NDShape outputShape;

        /// <summary>Input variable the network reads image data from.</summary>
        public Variable features;
        readonly Variable label;

        readonly Parameter W1;
        readonly Parameter b1;
        readonly Parameter W2;
        readonly Parameter b2;

        /// <summary>
        /// Network output: W2 * ReLU(W1 * x + b1) + b2, i.e. unnormalized per-class scores
        /// (softmax is applied only inside the training loss).
        /// </summary>
        public Function z;

        readonly Function loss;
        readonly Function evalError;

        /// <summary>
        /// Builds the network graph and trains it on <paramref name="trainData"/>.
        /// </summary>
        /// <param name="trainData">
        /// Labelled samples; each <c>Image</c> is expected to match the input shape (784 values).
        /// </param>
        /// <exception cref="System.ArgumentNullException"><paramref name="trainData"/> is null.</exception>
        /// <exception cref="System.ArgumentException"><paramref name="trainData"/> is empty.</exception>
        public NeutralNetwork(Digit[] trainData)
        {
            // Validate before building the graph so bad input fails fast and clearly.
            if (trainData == null)
            {
                throw new System.ArgumentNullException(nameof(trainData));
            }
            if (trainData.Length == 0)
            {
                throw new System.ArgumentException("Training data must contain at least one sample.", nameof(trainData));
            }

            inputShape = new NDShape(1, inputDim);
            outputShape = new NDShape(1, outputDim);

            features = Variable.InputVariable(inputShape, DataType.Float);
            label = Variable.InputVariable(outputShape, DataType.Float);

            int hiddenDim = 1500;

            W1 = new Parameter(new int[] { hiddenDim, inputDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w1");
            b1 = new Parameter(new int[] { hiddenDim }, DataType.Float, 0, device, "b1");
            W2 = new Parameter(new int[] { outputDim, hiddenDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w2");
            b2 = new Parameter(new int[] { outputDim }, DataType.Float, 0, device, "b2");

            z = CNTKLib.Times(W2, CNTKLib.ReLU(CNTKLib.Times(W1, features) + b1)) + b2;

            loss = CNTKLib.CrossEntropyWithSoftmax(z, label);
            evalError = CNTKLib.ClassificationError(z, label);

            Train(trainData);
        }

        /// <summary>
        /// Runs <c>numMinibatchesToTrain</c> SGD updates, each on a minibatch of
        /// <c>minibatchSize</c> samples drawn from <paramref name="trainData"/>.
        /// </summary>
        private void Train(Digit[] trainData)
        {
            // Learning rate 0.02 with a reference minibatch size of 1 (per-sample rate).
            var learningRatePerSample = new TrainingParameterScheduleDouble(0.02, 1);

            IList<Learner> parameterLearners = new List<Learner>
            {
                Learner.SGDLearner(z.Parameters(), learningRatePerSample)
            };
            var trainer = Trainer.CreateTrainer(z, loss, evalError, parameterLearners);

            // Features and one-hot labels come from two separate sources over index-aligned arrays.
            // NOTE(review): assumes BatchSource iterates deterministically so both stay in step.
            var feat = new BatchSource<float[]>(trainData.Select(x => x.Image).ToArray(), minibatchSize);
            var labl = new BatchSource<float[]>(trainData.Select(x => x.Label.ToOneHot10(10).ToFloatArray()).ToArray(), minibatchSize);

            for (int step = 0; step < numMinibatchesToTrain; step++)
            {
                feat.MoveNext();
                labl.MoveNext();

                // CNTK Value objects are IDisposable; release each minibatch's buffers
                // deterministically instead of leaving thousands of them to finalization.
                using (Value ft = Value.CreateBatchOfSequences<float>(inputShape, feat.Current, device))
                using (Value lb = Value.CreateBatchOfSequences<float>(outputShape, labl.Current, device))
                {
                    var arguments = new Dictionary<Variable, Value> { { features, ft }, { label, lb } };
                    trainer.TrainMinibatch(arguments, false, device);
                }

                // Progress can be inspected here via trainer.PreviousMinibatchLossAverage()
                // and trainer.PreviousMinibatchEvaluationAverage().
            }
        }
    }

    /// <summary>
    /// Static entry point that owns one trained <see cref="NeutralNetwork"/>:
    /// call <see cref="Init"/> to train it from a CSV file, then <see cref="GetNumber"/>
    /// to classify individual images.
    /// </summary>
    public static class MNIST
    {
        private static DeviceDescriptor device = DeviceDescriptor.CPUDevice;
        private static NeutralNetwork network;

        /// <summary>
        /// Loads labelled samples from <paramref name="path"/> (e.g. App_Data\train.csv) and trains
        /// a new network on them, replacing any previously trained one.
        /// </summary>
        /// <param name="path">Path to the training CSV; see LoadData for the expected layout.</param>
        /// <exception cref="FileNotFoundException">No file exists at <paramref name="path"/>.</exception>
        public static void Init(string path)
        {
            network = new NeutralNetwork(LoadData(path));
        }

        /// <summary>
        /// Classifies one image and returns the index of the highest-scoring output class,
        /// i.e. the predicted digit.
        /// </summary>
        /// <param name="input">
        /// Pixel values for one image, scaled like the training data (LoadData divides raw values by 256).
        /// </param>
        /// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null.</exception>
        /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
        public static int GetNumber(float[] input)
        {
            if (input == null)
            {
                throw new System.ArgumentNullException(nameof(input));
            }

            // Read the field once so a concurrent Init cannot swap the network mid-evaluation.
            var net = network;
            if (net == null)
            {
                throw new System.InvalidOperationException("MNIST.Init must be called before GetNumber.");
            }

            // Both the input and the evaluated output Values are IDisposable CNTK objects.
            using (Value inputValue = Value.CreateBatch(net.inputShape, input, device))
            {
                var imap = new Dictionary<Variable, Value> { { net.features, inputValue } };
                var omap = new Dictionary<Variable, Value> { { net.z, null } };
                net.z.Evaluate(imap, omap, device);

                using (Value outputValue = omap[net.z])
                {
                    // First (and only) sample in the batch; GetDenseData copies into managed lists.
                    var scores = outputValue.GetDenseData<float>(net.z).First();
                    return scores.MaxIndex();
                }
            }
        }

        /// <summary>
        /// Reads a CSV of labelled digits. The first row is a header and is skipped, and blank
        /// lines are ignored. Every other row holds the label followed by raw pixel values.
        /// Pixel values are scaled by 1/256.
        /// </summary>
        /// <exception cref="FileNotFoundException">No file exists at <paramref name="path"/>.</exception>
        /// <exception cref="System.FormatException">A field is not a valid number.</exception>
        private static Digit[] LoadData(string path)
        {
            if (!File.Exists(path))
            {
                throw new FileNotFoundException($"File {path} not found.", path);
            }

            var invariant = System.Globalization.CultureInfo.InvariantCulture;
            return File.ReadLines(path)
                .Skip(1) // header row
                .Where(line => !string.IsNullOrWhiteSpace(line))
                .Select(line =>
                {
                    // Materialize once so each row is parsed a single time (label + pixels).
                    // Invariant culture: the data is machine-written, independent of the host locale.
                    float[] values = line.Split(',')
                        .Select(field => float.Parse(field, invariant))
                        .ToArray();
                    return new Digit
                    {
                        Label = (int)values[0],
                        // NOTE(review): dividing by 256 (not 255) maps 255 to ~0.996; callers of
                        // GetNumber must use the same scaling.
                        Image = values.Skip(1).Select(x => x / 256f).ToArray()
                    };
                })
                .ToArray();
        }
    }
}