MNIST.cs

using CNTK;
using NeuralNetwork.Model.Neutral.SciSharp;
using NeuralNetwork.Model.Neutral.SciSharp.CNTK;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace NeuralNetwork.Model.Neutral.CNTK
{
    public class Digit
    {
        public float[] Image;
        public int Label;
    }

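    // A minimal fully connected MNIST classifier:
    // 784 pixel inputs -> ReLU hidden layer -> 10 output logits,
    // trained with softmax cross-entropy and plain SGD on the CPU.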
    public class NeutralNetwork
    {
        int inputDim = 784; // 28 * 28, the flattened size of an input image
        int outputDim = 10; // 10 digit classes (0-9)

        int minibatchSize = 64;
        int numMinibatchesToTrain = 2000; // number of SGD iterations (minibatches), not full epochs

        DeviceDescriptor device = DeviceDescriptor.CPUDevice;

        public NDShape inputShape;
        NDShape outputShape;

        public Variable features;
        Variable label;

        Parameter W1;
        Parameter b1;
        Parameter W2;
        Parameter b2;

        public Function z;

        Function loss;
        Function evalError;

        public NeutralNetwork(Digit[] trainData)
        {
            inputShape = new NDShape(1, inputDim);
            outputShape = new NDShape(1, outputDim);

            features = Variable.InputVariable(inputShape, DataType.Float);
            label = Variable.InputVariable(outputShape, DataType.Float);

            int hiddenDim = 1500; // width of the single hidden layer

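            // Glorot-uniform weight matrices and zero-initialized biases for the two affine layers.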
            W1 = new Parameter(new int[] { hiddenDim, inputDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w1");
            b1 = new Parameter(new int[] { hiddenDim }, DataType.Float, 0, device, "b1");
            W2 = new Parameter(new int[] { outputDim, hiddenDim }, DataType.Float, CNTKLib.GlorotUniformInitializer(), device, "w2");
            b2 = new Parameter(new int[] { outputDim }, DataType.Float, 0, device, "b2");

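            // Forward model: z = W2 * ReLU(W1 * x + b1) + b2. z holds raw logits; softmax is applied inside the loss.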
            z = CNTKLib.Times(W2, CNTKLib.ReLU(CNTKLib.Times(W1, features) + b1)) + b2;

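            // Training objective: softmax cross-entropy. Evaluation metric: top-1 classification error.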
            loss = CNTKLib.CrossEntropyWithSoftmax(z, label);
            evalError = CNTKLib.ClassificationError(z, label);

            Train(trainData);
        }

        private void Train(Digit[] trainData)
        {
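            // Constant learning rate of 0.02 per sample, used by a plain SGD learner over all model parameters.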
            TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.02, 1);

            IList<Learner> parameterLearners =
                new List<Learner>()
                {
                    Learner.SGDLearner(z.Parameters(), learningRatePerSample)
                };
            var trainer = Trainer.CreateTrainer(z, loss, evalError, parameterLearners);

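            // Batch the images and one-hot encoded labels into groups of minibatchSize
            // (BatchSource comes from the project's SciSharp helpers).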
            var feat = new BatchSource<float[]>((from x in trainData select x.Image).ToArray(), minibatchSize);
            var labl = new BatchSource<float[]>((from x in trainData select x.Label.ToOneHot10(10).ToFloatArray()).ToArray(), minibatchSize);

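            // Each iteration pulls the next minibatch, wraps it into CNTK Values, and performs one SGD step.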
            for (int ep = 0; ep < numMinibatchesToTrain; ep++)
            {
                Value ft, lb;

                feat.MoveNext(); labl.MoveNext();

                ft = Value.CreateBatchOfSequences<float>(inputShape, feat.Current, device);
                lb = Value.CreateBatchOfSequences<float>(outputShape, labl.Current, device);

                trainer.TrainMinibatch(
                    new Dictionary<Variable, Value>() { { features, ft }, { label, lb } }, false, device);

                //if (ep % 50 == 0)
                //{
                //    var _loss = trainer.PreviousMinibatchLossAverage();
                //    var _eval = trainer.PreviousMinibatchEvaluationAverage();
                //    WriteLine($"Epoch={ep}, loss={_loss}, eval={_eval}");
                //}
            }
        }
    }

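    // Static facade over the network: Init loads the training data and trains the model once,
    // GetNumber classifies a single 784-pixel image, WriteNumber appends a labeled sample
    // to the learning data file. Typical usage (file names are placeholders, not from the original project):
    //   MNIST.Init("mnist_train.csv", "learned_samples.csv");
    //   int digit = MNIST.GetNumber(pixels); // pixels: 784 floats scaled to [0, 1)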
    public static class MNIST
    {
        private static string BaseDataPath;
        private static string LearnDataPath;
        private static DeviceDescriptor device = DeviceDescriptor.CPUDevice;
        private static NeutralNetwork network;

        public static void Init(string baseDataPath, string learnDataPath)
        {
            BaseDataPath = baseDataPath;
            LearnDataPath = learnDataPath;

            network = new NeutralNetwork(LoadData(BaseDataPath));
        }

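        // Runs a forward pass for a single image and returns the index of the largest output, i.e. the predicted digit.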
        public static int GetNumber(float[] input)
        {
            var imap = new Dictionary<Variable, Value> { { network.features, Value.CreateBatch(network.inputShape, input, device) } };
            var omap = new Dictionary<Variable, Value> { { network.z, null } };
            network.z.Evaluate(imap, omap, device);
            var o = omap[network.z].GetDenseData<float>(network.z).First();
            var res = o.MaxIndex();
            return res;
        }

        static MNIST()
        {
            
        }

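        // Reads a CSV file where each non-empty line is "label,p0,p1,...,p783" with pixel values in the 0-255 range.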
        private static Digit[] LoadData(string path)
        {
            //WriteLine("Reading data");
            if (!File.Exists(path))
                throw new FileNotFoundException($"File {path} not found.", path);
            var f = File.ReadLines(path);
            var data = from t in f
                       where !string.IsNullOrEmpty(t)
                       let zz = t.Split(',').Select(float.Parse).ToArray() // materialize so each line is parsed only once
                       select new Digit
                       {
                           Label = (int)zz.First(),
                           Image = zz.Skip(1).Select(x => x / 256f).ToArray() // scale raw pixels to [0, 1)
                       };
            return data.ToArray();
        }

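        // Appends one labeled sample to the learning data file, using the same CSV layout that LoadData expects.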
        public static void WriteNumber(float[] input, int number)
        {
            if (!string.IsNullOrEmpty(LearnDataPath))
            {
                // Clamp before casting so values at or near 1.0 cannot overflow the byte range.
                var byteList = input.Select(e => (byte)System.Math.Min(255d, System.Math.Round(e * 256f))).ToList();
                using (StreamWriter wr = new StreamWriter(LearnDataPath, true))
                {
                    wr.Write(number);
                    byteList.ForEach(e => wr.Write("," + e));
                    wr.WriteLine();
                }
            }
        }
    }
}