<template>
    <b-container>
        <div id="machinelearning"/>
    </b-container>
</template>

<script>
    // tensorflow-js 관련 불러오기
    import * as tf from '@tensorflow/tfjs';
    import * as tfvis from '@tensorflow/tfjs-vis';
    import { MnistData } from '@/MnistData/data.js';
    import { IrisData } from '@/IrisData/data.js';
    import { fitCallbacks } from '@tensorflow/tfjs-vis/dist/show/history';

    // vuex
    import store from '@/store';

    export default {
        name: "ModelTraining",
        data() {
            return {}
        },
        methods: {
            init() {
                tf;
                tfvis;
                MnistData;
            },
            testclick() {
                console.log('test click!!');
            }
        },
        mounted() {
            // Start machine-learning training as soon as the component is mounted. `run` is
            // async, so attach a handler instead of leaving an unhandled rejection.
            run().catch((err) => {
                console.error('Model training failed:', err);
            });

            /**
             * Trains, evaluates and saves a model for the data-input node selected in the store.
             *
             * The data-input node name (get_data_input_name) picks the pipeline: MNIST or IRIS.
             * The trained model is saved to `localstorage://<model_name>` and offered as a
             * browser download. The loading indicator (`call_show`) is shown on entry and is
             * always hidden again on exit, including when training throws or the input node
             * is not recognised.
             */
            async function run() {
                // Name of the data-input node; decides which dataset/model pipeline runs.
                const data_input_name = store.getters.get_data_input_name;
                // Name under which the trained model is saved.
                const model_name = store.getters.get_model_name;

                // TODO: the node-connection check (get_mnist_connect / get_model_save_con) is
                // disabled; re-enable it if training must require connected nodes.
                // TODO: emit 'close_btn' to close the alarm dialog once loading ends; `run`
                // would need the component instance (`this` from `mounted`) to do so.

                store.dispatch('call_show', { show: true });

                try {
                    switch (data_input_name) {
                        case "데이터입력(MNIST)": {
                            const data = new MnistData();
                            await data.load();

                            const model = getModel();
                            tfvis.show.modelSummary({ name: '모델 구조' }, model);

                            await train(model, data);
                            await showAccuracy(model, data);
                            await showConfusion(model, data);

                            await model.save(`localstorage://${model_name}`);
                            await model.save(`downloads://${model_name}`);
                            break;
                        }

                        case "데이터입력(IRIS)": {
                            const irisdata = new IrisData();
                            // Fraction of the dataset passed to getIrisData as the split rate.
                            const splitrate = 0.2;

                            const [xTrain, yTrain, xTest, yTest] = await irisdata.getIrisData(splitrate);
                            try {
                                const iris_model = await getIrisModel(xTrain);
                                tfvis.show.modelSummary({ name: '모델 구조' }, iris_model);

                                await trainModel(iris_model, xTrain, yTrain, xTest, yTest);
                                await irisshowAccuracy(iris_model, xTest, yTest);
                                await irisshowConfusion(iris_model, xTest, yTest);

                                await iris_model.save(`localstorage://${model_name}`);
                                await iris_model.save(`downloads://${model_name}`);
                            } finally {
                                // Release the dataset tensors; nothing reads them after this point.
                                tf.dispose([xTrain, yTrain, xTest, yTest]);
                            }
                            break;
                        }

                        default:
                            console.warn('Unsupported data input node:', data_input_name);
                            break;
                    }
                } finally {
                    // Hide the loading indicator whether training succeeded, failed or was skipped.
                    store.dispatch('call_show', { show: false });
                }
            }

            /**
             * Runs the iris model on the test features and pairs the predictions with the
             * true classes.
             *
             * Both arg-max results are computed inside `tf.tidy`, so the raw model output is
             * released and only the two returned tensors survive; the caller owns them.
             *
             * @returns {Array} `[predictedClassIds, trueClassIds]`, each an arg-max over the last axis.
             */
            function irisPredict(iris_model, xTest, yTest) {
                return tf.tidy(() => {
                    const predictedIds = iris_model.predict(xTest).argMax(-1);
                    const trueIds = yTest.argMax(-1);
                    return [predictedIds, trueIds];
                });
            }

            /**
             * Renders the iris model's confusion matrix on the "평가" (evaluation) visor tab.
             *
             * Tick labels are the class indices as strings ("0", "1", ...). Both tensors
             * returned by irisPredict are disposed once the matrix has been computed, even if
             * computing it fails.
             */
            async function irisshowConfusion(iris_model, xTest, yTest) {
                const [preds, labels] = irisPredict(iris_model, xTest, yTest);
                try {
                    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
                    const container = { name: "컨퓨전 매트릭스(혼동 횡렬)", tab: "평가" };
                    // tfvis reads tick labels from the data object; its third argument is the
                    // render options, so labels passed there are ignored.
                    const tickLabels = confusionMatrix.map((row, classIndex) => String(classIndex));
                    tfvis.render.confusionMatrix(container, { values: confusionMatrix, tickLabels });
                } finally {
                    preds.dispose();
                    labels.dispose();
                }
            }

            /**
             * Shows the iris model's per-class accuracy table on the "평가" (evaluation) visor tab.
             *
             * Rows are labelled by class index ("0", "1", ...). Both tensors returned by
             * irisPredict are disposed afterwards, even if computing the accuracy fails.
             */
            async function irisshowAccuracy(iris_model, xTest, yTest) {
                const [preds, labels] = irisPredict(iris_model, xTest, yTest);
                try {
                    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
                    const container = { name: "정확도", tab: "평가" };
                    // Derive labels from the iris result itself rather than the shared
                    // MNIST digit list, so the table always matches the number of classes.
                    const classLabels = classAccuracy.map((entry, classIndex) => String(classIndex));
                    tfvis.show.perClassAccuracy(container, classAccuracy, classLabels);
                } finally {
                    preds.dispose();
                    labels.dispose();
                }
            }

            /**
             * Builds and compiles the dense model for the iris pipeline.
             *
             * When the selected model node is "Iris DenseNet", layers are read from the node
             * graph stored in sessionStorage:
             *   - `model_numbering<id>` holds the node order;
             *   - `modaldata<id>` holds the node definitions;
             *   - only "dense" nodes become layers.
             * The first layer added receives the input shape. The feature count comes from
             * `xTrain.shape[1]` when xTrain is at least 2-D, and is 4 otherwise.
             *
             * @param {Object} xTrain - training features (only its shape is read here).
             * @returns {Object} the compiled tf.Sequential model.
             * @throws {Error} if the stored graph is missing or the optimizer is unsupported.
             */
            function getIrisModel(xTrain) {
                // Maps the optimizer name chosen in the UI to a tfjs optimizer instance.
                function createIrisOptimizer(name) {
                    switch (name) {
                        case 'adam':
                            return tf.train.adam(0.01);
                        case 'sgd':
                            return tf.train.sgd(0.05);
                        case 'momentum':
                            return tf.train.momentum(0.05, 0.05);
                        case 'adagrad':
                            return tf.train.adagrad(0.05);
                        case 'adadelta':
                            return tf.train.adadelta();
                        case 'adamax':
                            return tf.train.adamax();
                        case 'rmsprop':
                            return tf.train.rmsprop(0.05);
                        default:
                            throw new Error('Unsupported optimizer: ' + name);
                    }
                }

                const model = tf.sequential();

                if (store.getters.get_model_node_name === "Iris DenseNet") {
                    // Features per sample; assumes xTrain is [numSamples, numFeatures].
                    const featureCount = (xTrain && xTrain.shape && xTrain.shape.length > 1)
                        ? xTrain.shape[1]
                        : 4;

                    const nodeId = store.getters.get_model_node_id;
                    const order = JSON.parse(sessionStorage.getItem('model_numbering' + nodeId));
                    const graph = JSON.parse(sessionStorage.getItem("modaldata" + nodeId));
                    if (!Array.isArray(order) || !graph || !graph.nodes) {
                        throw new Error('No stored model graph for model node ' + nodeId);
                    }

                    let isFirstLayer = true;
                    for (const nodeKey of order) {
                        const node = graph.nodes[nodeKey];
                        if (node.name !== "dense") {
                            continue; // only dense nodes become layers in this model
                        }
                        const config = {
                            units: node.data.units,
                            kernelInitializer: node.data.kernelInitializer,
                            activation: node.data.activation,
                        };
                        if (isFirstLayer) {
                            // A Sequential model needs the input shape on its first layer.
                            config.inputShape = [featureCount];
                            isFirstLayer = false;
                        }
                        model.add(tf.layers.dense(config));
                    }
                }

                model.compile({
                    optimizer: createIrisOptimizer(store.getters.get_optimizer),
                    loss: store.getters.get_loss,
                    metrics: ['accuracy'],
                });
                return model;
            }

            /**
             * Fits the iris model and streams loss/accuracy curves to the "모델 훈련" visor panel.
             *
             * Epoch count and batch size come from the store. The held-out split is used as
             * validation data, and samples are not shuffled between epochs.
             *
             * @returns {Promise} the promise returned by `model.fit`.
             */
            async function trainModel(model, xTrain, yTrain, xTest, yTest) {
                const plottedMetrics = ['loss', 'val_loss', 'acc', 'val_acc'];
                const visorPanel = {
                    name: '모델 훈련',
                    styles: { height: '1000px' }
                };
                const callbacks = tfvis.show.fitCallbacks(visorPanel, plottedMetrics);

                const batchSize = Number(store.getters.get_batch_size);
                const fitOptions = {
                    epochs: store.getters.get_epoch,
                    batchSize,
                    validationData: [xTest, yTest],
                    shuffle: false,
                    callbacks
                };
                return model.fit(xTrain, yTrain, fitOptions);
            }

            /**
             * Builds and compiles the image model for the MNIST pipeline.
             *
             * Layers are read from the node graph stored in sessionStorage, where `<id>` is
             * the selected model node:
             *   - `model_numbering<id>` holds the node order;
             *   - `modaldata<id>` holds the node definitions.
             * Nodes named "conv2d", "maxPooling2d", "dense" and "flatten" become layers; any
             * other node (e.g. the data-input node) is skipped. The first layer added receives
             * the input shape built from the store's input-data settings.
             *
             * @returns {Object} the compiled tf.Sequential model.
             * @throws {Error} if the stored graph is missing or the optimizer is unsupported.
             */
            function getModel() {
                const { IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS } = store.getters.get_input_data;
                // NOTE(review): channels-last tfjs layers read inputShape as
                // [height, width, channels]; the width-first order kept here only matters
                // for non-square images.
                const imageShape = [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS];

                // Parses a two-number setting such as "2,2", "2, 2" or [2, 2] into [2, 2]; a
                // single number is used for both dimensions. Multi-digit values are supported.
                function parsePair(value) {
                    const nums = (String(value).match(/\d+/g) || []).map(Number);
                    return nums.length === 1 ? [nums[0], nums[0]] : nums.slice(0, 2);
                }

                // Creates the tfjs layer for one graph node, or returns null when the node is
                // not a layer. `inputShape` is attached only when it is non-null.
                function createLayer(name, data, inputShape) {
                    const withShape = (config) => {
                        if (inputShape) {
                            config.inputShape = inputShape;
                        }
                        return config;
                    };
                    switch (name) {
                        case "conv2d":
                            return tf.layers.conv2d(withShape({
                                kernelSize: data.kernelSize,
                                filters: data.filters,
                                activation: data.activation,
                                kernelInitializer: data.kernelInitializer
                            }));
                        case "maxPooling2d":
                            return tf.layers.maxPooling2d(withShape({
                                poolSize: parsePair(data.poolSize),
                                strides: parsePair(data.strides)
                            }));
                        case "dense":
                            return tf.layers.dense(withShape({
                                units: data.units,
                                kernelInitializer: data.kernelInitializer,
                                activation: data.activation
                            }));
                        case "flatten":
                            return tf.layers.flatten(withShape({}));
                        default:
                            return null;
                    }
                }

                // Maps the optimizer name chosen in the UI to a tfjs optimizer instance.
                function createOptimizer(name) {
                    switch (name) {
                        case 'adam':
                            return tf.train.adam();
                        case 'sgd':
                            return tf.train.sgd(0.05);
                        case 'momentum':
                            return tf.train.momentum(0.05, 0.05);
                        case 'adagrad':
                            return tf.train.adagrad(0.05);
                        case 'adadelta':
                            return tf.train.adadelta();
                        case 'adamax':
                            return tf.train.adamax();
                        case 'rmsprop':
                            return tf.train.rmsprop(0.05);
                        default:
                            throw new Error('Unsupported optimizer: ' + name);
                    }
                }

                const nodeId = store.getters.get_model_node_id;
                const order = JSON.parse(sessionStorage.getItem('model_numbering' + nodeId));
                const graph = JSON.parse(sessionStorage.getItem("modaldata" + nodeId));
                if (!Array.isArray(order) || !graph || !graph.nodes) {
                    throw new Error('No stored model graph for model node ' + nodeId);
                }

                const model = tf.sequential();
                // Only the first layer added to a Sequential model needs the input shape.
                let pendingInputShape = imageShape;
                for (const nodeKey of order) {
                    const node = graph.nodes[nodeKey];
                    const layer = createLayer(node.name, node.data, pendingInputShape);
                    if (layer !== null) {
                        model.add(layer);
                        pendingInputShape = null;
                    }
                }

                model.compile({
                    optimizer: createOptimizer(store.getters.get_optimizer),
                    loss: store.getters.get_loss,
                    metrics: ['accuracy'],
                });
                return model;
            }

            /**
             * Fits the MNIST model on one train batch and validates on one test batch,
             * streaming loss/accuracy curves to the "모델 훈련" visor panel.
             *
             * Batch size, epoch count and sample counts come from the store. The batch
             * tensors created here are disposed once fitting finishes or fails.
             *
             * @returns {Promise} the training history resolved by `model.fit`.
             */
            async function train(model, data) {
                const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
                const container = {
                    name: '모델 훈련',
                    styles: { height: '1000px' }
                };
                // Named `callbacks` so it does not shadow the file-level `fitCallbacks` import.
                const callbacks = tfvis.show.fitCallbacks(container, metrics);

                const BATCH_SIZE = Number(store.getters.get_batch_size);
                const TRAIN_DATA_SIZE = store.getters.get_trainlabels_num;
                // NOTE(review): the test-batch size also reads get_trainlabels_num; confirm
                // whether a separate test-count getter was intended.
                const TEST_DATA_SIZE = store.getters.get_trainlabels_num;

                // Each batch is reshaped to [N, 28, 28, 1]; tidy frees the un-reshaped batch.
                const [trainXs, trainYs] = tf.tidy(() => {
                    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
                    return [d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]), d.labels];
                });
                const [testXs, testYs] = tf.tidy(() => {
                    const d = data.nextTestBatch(TEST_DATA_SIZE);
                    return [d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]), d.labels];
                });

                try {
                    return await model.fit(trainXs, trainYs, {
                        batchSize: BATCH_SIZE,
                        validationData: [testXs, testYs],
                        epochs: store.getters.get_epoch,
                        shuffle: true,
                        callbacks
                    });
                } finally {
                    // These tensors are only used for this fit call.
                    tf.dispose([trainXs, trainYs, testXs, testYs]);
                }
            }
            // Display labels for the ten MNIST digit classes, in class-index order ("0" … "9").
            const classNames = Array.from({ length: 10 }, (_, digit) => String(digit));

            /**
             * Predicts classes for `testDataSize` MNIST test images.
             *
             * The raw test batch, the reshaped input and the model's raw output are all
             * created inside `tf.tidy` and released automatically. Only the two returned
             * tensors survive, and the caller must dispose them.
             *
             * @param {Object} model - trained model exposing `predict`.
             * @param {Object} data - dataset exposing `nextTestBatch(n)`.
             * @param {number} [testDataSize=500] - number of test images to evaluate.
             * @returns {Array} `[predictedClassIds, trueClassIds]`.
             */
            function doPrediction(model, data, testDataSize = 500) {
                const IMAGE_WIDTH = 28;
                const IMAGE_HEIGHT = 28;
                return tf.tidy(() => {
                    const testData = data.nextTestBatch(testDataSize);
                    const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
                    const labels = testData.labels.argMax(-1);
                    const preds = model.predict(testxs).argMax(-1);
                    return [preds, labels];
                });
            }

            /**
             * Shows the MNIST model's per-class accuracy table on the "평가" (evaluation) visor tab,
             * with rows labelled by `classNames`.
             *
             * Both tensors returned by doPrediction are disposed afterwards, even if computing
             * the accuracy fails.
             */
            async function showAccuracy(model, data) {
                const [preds, labels] = doPrediction(model, data);
                try {
                    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
                    const container = { name: "정확도", tab: "평가" };
                    tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
                } finally {
                    preds.dispose();
                    labels.dispose();
                }
            }

            /**
             * Renders the MNIST model's confusion matrix on the "평가" (evaluation) visor tab,
             * with `classNames` as the axis tick labels.
             *
             * Both tensors returned by doPrediction are disposed afterwards, even if computing
             * the matrix fails.
             */
            async function showConfusion(model, data) {
                const [preds, labels] = doPrediction(model, data);
                try {
                    // Pin the class count so the matrix always lines up with the tick labels,
                    // even if some digit is absent from this test batch.
                    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds, classNames.length);
                    const container = { name: "컨퓨전 매트릭스(혼동 횡렬)", tab: "평가" };
                    // tfvis reads tick labels from the data object; its third argument is the
                    // render options, so labels passed there are ignored.
                    tfvis.render.confusionMatrix(container, { values: confusionMatrix, tickLabels: classNames });
                } finally {
                    preds.dispose();
                    labels.dispose();
                }
            }

        },
    }
</script>
<style scoped>
    #machinelearning {
        display: none;
    }
</style>
