/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.data;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.data.DataOptions;
import org.tribuo.evaluation.CrossValidation;
import org.tribuo.evaluation.DescriptiveStats;
import org.tribuo.evaluation.Evaluation;
import org.tribuo.evaluation.EvaluationAggregator;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.transform.TransformTrainer;
import org.tribuo.transform.TransformationMap;
import org.tribuo.util.Util;

/**
 * Command line driver which loads a {@link Trainer} (and optionally a {@link TransformationMap}) from an
 * OLCUT config file, trains a {@link Model} on the training data, evaluates it on the test data, prints
 * the evaluation, optionally saves the model, and optionally runs k-fold cross-validation on the training data.
 */
public final class ConfigurableTrainTest {
    private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName());

    private ConfigurableTrainTest() {
    }

    /**
     * Runs the train/test (and optional cross-validation) pipeline described by the command line arguments.
     * <p>
     * Returns without error if {@link ConfigurationManager} raises a {@link UsageException} (the message is
     * logged). Calls {@code System.exit(1)} if the training path, testing path, output factory or trainer
     * is missing, or if the data fails to load.
     * @param args The command line arguments, parsed into {@link ConfigurableTrainTestOptions}.
     * @param <T> The output type shared by the output factory, datasets, trainer and model.
     */
    // The output factory and trainer come from the config file as wildcards; they are cast once to the common
    // output type T. A mismatched config surfaces as a ClassCastException during training/evaluation.
    @SuppressWarnings("unchecked")
    public static <T extends Output<T>> void main(String[] args) {
        LabsLogFormatter.setAllLogFormatters();
        ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null || o.outputFactory == null) {
            logger.info(cm.usage());
            System.exit(1);
            return;
        }
        // Validate the trainer before loading data, so a missing option fails fast instead of after a slow load.
        if (o.trainer == null) {
            logger.warning("No trainer supplied");
            logger.info(cm.usage());
            System.exit(1);
            return;
        }

        Pair<Dataset<T>, Dataset<T>> data;
        try {
            data = o.general.load((OutputFactory<T>) o.outputFactory);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
            // Unreachable in practice; keeps 'data' definitely assigned for the compiler.
            return;
        }
        Dataset<T> train = data.getA();
        Dataset<T> test = data.getB();

        Trainer<T> trainer = (Trainer<T>) o.trainer;
        if (o.transformationMap != null) {
            // Wrap the configured trainer so the transformations are fitted and applied during training.
            trainer = new TransformTrainer<>(trainer, o.transformationMap);
        }
        TrainerProvenance trainerProvenance = trainer.getProvenance();
        logger.info("Trainer is " + trainerProvenance.toString());
        logger.info("Outputs are " + train.getOutputInfo().toReadableString());
        logger.info("Number of features: " + train.getFeatureMap().size());

        long trainStart = System.currentTimeMillis();
        Model<T> model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

        Evaluator<T, ? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
        long testStart = System.currentTimeMillis();
        Evaluation<T> evaluation = evaluator.evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());

        if (o.general.outputPath != null) {
            try {
                o.general.saveModel(model);
            } catch (IOException e) {
                // Saving is best-effort; evaluation results have already been printed.
                logger.log(Level.SEVERE, "Error writing model", e);
            }
        }

        if (o.crossValidation) {
            if (o.numFolds > 1) {
                crossValidate(trainer, train, evaluator, o.numFolds, o.general.seed);
            } else {
                logger.warning("The number of cross-validation folds must be greater than 1, found " + o.numFolds);
            }
        }
    }

    /**
     * Runs k-fold cross-validation over the supplied data and prints, for each metric, the mean and
     * standard deviation across the folds, sorted by metric name.
     * @param trainer The trainer to cross-validate.
     * @param data The data to split into folds.
     * @param evaluator The evaluator applied to each fold.
     * @param numFolds The number of folds; the caller ensures it is greater than 1.
     * @param seed The RNG seed used when splitting the folds.
     * @param <T> The output type.
     */
    private static <T extends Output<T>> void crossValidate(Trainer<T> trainer, Dataset<T> data,
                                                            Evaluator<T, ? extends Evaluation<T>> evaluator,
                                                            int numFolds, long seed) {
        logger.info("Running " + numFolds + " fold cross-validation");
        CrossValidation<T, ? extends Evaluation<T>> cv = new CrossValidation<>(trainer, data, evaluator, numFolds, seed);
        List<? extends Pair<? extends Evaluation<T>, Model<T>>> foldResults = cv.evaluate();
        List<Evaluation<T>> evals = foldResults.stream().map(Pair::getA).collect(Collectors.toList());
        Map<MetricID<T>, DescriptiveStats> summary = EvaluationAggregator.summarize(evals);
        // MetricID is a Pair whose second element is the metric name; sort by it for a stable report.
        List<MetricID<T>> keys = new ArrayList<>(summary.keySet())
                .stream()
                .sorted(Comparator.comparing(Pair::getB))
                .collect(Collectors.toList());
        System.out.println("Summary across the folds:");
        for (MetricID<T> key : keys) {
            DescriptiveStats stats = summary.get(key);
            System.out.printf("%-10s  %.5f (%.5f)%n", key, stats.getMean(), stats.getStandardDeviation());
        }
    }

    /**
     * Command line options for {@link ConfigurableTrainTest}.
     */
    public static class ConfigurableTrainTestOptions
    implements Options {
        /** Data loading and model saving options (training/testing paths, output path, seed). */
        public DataOptions general;
        @Option(charName='t', longName="trainer", usage="Load a trainer from the config file.")
        public Trainer<?> trainer;
        @Option(longName="transformer", usage="Load a transformation map from the config file.")
        public TransformationMap transformationMap;
        @Option(charName='a', longName="output-factory", usage="The output factory to construct.")
        public OutputFactory<?> outputFactory;
        @Option(charName='x', longName="cross-validate", usage="Cross-validate the output metrics.")
        public boolean crossValidation;
        @Option(charName='n', longName="num-folds", usage="The number of cross validation folds.")
        public int numFolds = 5;

        @Override
        public String getOptionsDescription() {
            return "Loads a Trainer from a config file, trains a Model (optionally with cross-validation), tests it and optionally saves it to disk.";
        }
    }
}

