package it.dmi.unict.ferrolab.DataMining.CrossValidation;

import it.dmi.unict.ferrolab.DataMining.Bridge.CrossValidationBridge;
import it.dmi.unict.ferrolab.DataMining.CrossValidation.Results.ConfusionMatrix;
import it.dmi.unict.ferrolab.DataMining.CrossValidation.Results.CrossValidationResult;
import it.dmi.unict.ferrolab.DataMining.CrossValidation.Results.FoldResult;
import it.dmi.unict.ferrolab.DataMining.CrossValidation.Results.SingleResult;
import it.dmi.unict.ferrolab.DataMining.Matrix.MatrixImpl.Matrix;
import it.dmi.unict.ferrolab.DataMining.Matrix.MatrixImpl.MatrixElement;
import it.dmi.unict.ferrolab.DataMining.Matrix.Partition.Column.MatrixColumnPart;
import it.dmi.unict.ferrolab.DataMining.Matrix.Partition.Column.MatrixColumnPartition;
import it.dmi.unict.ferrolab.DataMining.Workflow.ClassificationWorkflow;
import it.dmi.unict.ferrolab.DataMining.Workflow.WorkflowExecutionException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;

/* loaded from: input_file:it/dmi/unict/ferrolab/DataMining/CrossValidation/KFoldCrossValidation.class */
/**
 * K-fold cross validation over the columns of a {@link Matrix}.
 *
 * <p>At construction time the columns of the input matrix are distributed over the folds
 * class by class, in round-robin order, so that the columns of every class are spread as
 * evenly as possible across the folds. {@link #validate} then uses each fold in turn as the
 * test set while the remaining folds form the training set.
 *
 * <p>The {@code numFolds} and {@code columnsByFold} members and the {@code init(int)} method
 * used here are not declared in this class; presumably they come from
 * {@link AbstractCrossValidation} — confirm there.
 */
public class KFoldCrossValidation extends AbstractCrossValidation implements CrossValidationInterface {
    /** Number of entries in {@link #uniqueClasses}. */
    protected int numClasses;
    /** Class label to numeric class id, as returned by {@code Matrix.getUniqueClasses()}. */
    protected HashMap<String, Integer> uniqueClasses;
    /** Numeric class id to class label, as returned by {@code Matrix.getUniqueClassesIds()}. */
    protected HashMap<Integer, String> uniqueClassesId;
    /** Row count of the input matrix, captured at construction time. */
    protected int matrixNumRows;
    /** Column count of the input matrix, captured at construction time. */
    protected int matrixNumCols;
    /** The matrix whose columns are being partitioned into folds. */
    protected Matrix inputMatrix;

    /**
     * Builds the fold partition for the given matrix.
     *
     * @param matrix the matrix whose columns are split into folds
     * @param i      the number of folds; passed to {@code init(int)} and used as the
     *               round-robin period when assigning columns to folds
     */
    public KFoldCrossValidation(Matrix matrix, int i) {
        this.uniqueClasses = matrix.getUniqueClasses();
        this.uniqueClassesId = matrix.getUniqueClassesIds();
        this.numClasses = this.uniqueClasses.size();
        this.matrixNumRows = matrix.getNumRows();
        this.matrixNumCols = matrix.getNumCols();
        this.inputMatrix = matrix;
        init(i);

        final int lastFold = i - 1;
        for (String classLabel : this.uniqueClasses.keySet()) {
            // Restart at fold 0 for every class so each class is dealt out independently.
            int targetFold = 0;
            for (int column = 0; column < this.matrixNumCols; column++) {
                if (!matrix.getColumnClass(column).equals(classLabel)) {
                    continue;
                }
                this.columnsByFold[targetFold].add(matrix.getColumn(column), column, classLabel);
                // Advance to the next fold, wrapping around after the last one.
                targetFold = (targetFold == lastFold) ? 0 : targetFold + 1;
            }
        }
    }

    /**
     * Runs the classification workflow once per fold and collects the outcomes.
     *
     * <p>For each fold the workflow is re-initialised, trained on all other folds and run on
     * the fold's columns. Every row of the workflow's result is turned into a
     * {@link SingleResult}; results whose {@code simpleEvaluation} outcome is 0 or 1 are also
     * recorded in the confusion matrix.
     *
     * <p>When a bridge is supplied it receives progress logs and status notifications. If a
     * fold fails, the bridge is reset to its idle state (no current fold, cross validation
     * inactive) before the exception is propagated, mirroring what happens on success.
     *
     * @param crossValidationBridge  optional progress/notification sink; may be {@code null}
     * @param classificationWorkflow the workflow to train and run for every fold
     * @param hashMap                parameters forwarded to the workflow via {@code setParameters}
     * @return the confusion matrix together with the per-fold results
     * @throws RuntimeException if the workflow throws a {@link WorkflowExecutionException};
     *                          the original exception is attached as the cause
     */
    @Override // it.dmi.unict.ferrolab.DataMining.CrossValidation.CrossValidationInterface
    public CrossValidationResult validate(CrossValidationBridge crossValidationBridge, ClassificationWorkflow classificationWorkflow, HashMap<String, Object> hashMap) {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(this.numClasses, this.uniqueClassesId);
        // NOTE(review): kept as raw types on purpose; the generic signatures of
        // CrossValidationResult and MatrixColumnPartition are not visible from this file.
        ArrayList foldResults = new ArrayList(this.numFolds);
        MatrixColumnPartition partition = new MatrixColumnPartition();
        Collections.addAll(partition, this.columnsByFold);

        if (crossValidationBridge != null) {
            crossValidationBridge.setCrossValidationActive(true).notifyStatusChange().addLog("Starting K-Fold cross validation.").notifyLogAdded();
        }

        for (int fold = 0; fold < this.numFolds; fold++) {
            if (crossValidationBridge != null) {
                crossValidationBridge.setCurrentFold(fold).notifyCurrentFoldChanged().addLog("- Validating with k-fold cross validation fold " + (fold + 1) + "... ").notifyLogAdded();
            }

            // The current fold is the test set; excluding it leaves the training folds.
            MatrixColumnPart<MatrixElement> testPart = this.columnsByFold[fold];
            partition.exclude(fold);
            Matrix trainingSet = new Matrix((MatrixColumnPartition<MatrixElement>) partition);

            try {
                classificationWorkflow.init().setBridge(crossValidationBridge).setParameters(hashMap);
                classificationWorkflow.setTrainingSet(trainingSet).setTestSet((MatrixColumnPart) testPart).run();
                double[][] scores = classificationWorkflow.getResults();

                FoldResult foldResult = new FoldResult(fold);
                for (int row = 0; row < scores.length; row++) {
                    String expectedLabel = testPart.get(row).getOriginalClassification();
                    int expectedId = this.uniqueClasses.get(expectedLabel).intValue();
                    SingleResult singleResult = new SingleResult(expectedId, expectedLabel, testPart.get(row).getOriginalIndex());
                    int outcome = singleResult.simpleEvaluation(scores[row], this.uniqueClassesId);
                    // NOTE(review): only outcomes 0 and 1 reach the confusion matrix; the meaning
                    // of the codes is defined by SingleResult.simpleEvaluation — confirm there.
                    if (outcome == 1 || outcome == 0) {
                        confusionMatrix.classified(expectedId, singleResult.getClassifiedId());
                    }
                    foldResult.add(singleResult);
                }
                foldResults.add(foldResult);
                partition.clearExclusion();
            } catch (WorkflowExecutionException e) {
                if (crossValidationBridge != null) {
                    crossValidationBridge.addLog("Unable to complete cross validation. An error occurred: " + e.getMessage()).notifyLogAdded();
                    // Do not leave the bridge reporting an active run stuck on the failed fold.
                    crossValidationBridge.setCurrentFold(-1).setCrossValidationActive(false).notifyStatusChange().notifyCurrentFoldChanged();
                }
                throw new RuntimeException("Error during cross validation.", e);
            }
        }

        CrossValidationResult crossValidationResult = new CrossValidationResult(confusionMatrix, foldResults);
        if (crossValidationBridge != null) {
            crossValidationBridge.setCurrentFold(-1).setCrossValidationActive(false).notifyStatusChange().notifyCurrentFoldChanged().addLog("K-Fold cross validation finished without errors.").notifyLogAdded().setOutput(crossValidationResult).notifyOutputReady();
        }
        return crossValidationResult;
    }
}
