package evofmj.algorithm;

import edu.uci.lasso.LassoFit;
import edu.uci.lasso.LassoFitGenerator;
import evofmj.evaluation.java.EFMScaledData;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

/* loaded from: input_file:evofmj/algorithm/RegressionEFM.class */
/**
 * Evolutionary feature synthesis for regression.
 *
 * <p>The data matrix holds three groups of feature columns: the original input features
 * (indices {@code [0, numberOfOriginalFeatures)}), followed by evolved columns. Among the evolved
 * columns, those listed in {@code indicesArchive} form the archive that survives between
 * iterations; every other evolved column is regenerated each iteration. One iteration:
 * <ol>
 *   <li>rebuilds each non-archived evolved column by applying a random unary or binary operator
 *       to parents chosen by tournament selection;</li>
 *   <li>fits a lasso path on all columns, scores each evolved column, and keeps the
 *       best-scoring ones as the new archive;</li>
 *   <li>fits a lasso model on the original plus archived columns and remembers it when its
 *       training MSE is the best seen so far.</li>
 * </ol>
 * The search stops when the time budget is exhausted or the best MSE has not improved for more
 * than {@code MAX_STALL_ITERATIONS} consecutive iterations. The best feature set and model are
 * then written to {@code features.txt} and {@code model.txt}.
 */
public class RegressionEFM {
    /** Maximum number of lambda values along every lasso path that is fitted. */
    private static final int MAX_LAMBDAS = 100;

    /** Feature/target matrix: original columns followed by evolved (archive + new) columns. */
    EFMScaledData dataMatrix;
    /** Wall-clock time (epoch ms) at construction; the time budget is measured from here. */
    long startTime;
    /** Absolute deadline in epoch ms; {@code Double.POSITIVE_INFINITY} means no time limit. */
    double TIMEOUT;
    int numberOfOriginalFeatures;
    int numberOfArchiveFeatures;
    int numberOfNewFeatures;
    /** Upper bound passed to the lasso fit of the reported model. */
    int maxFinalFeatures;
    int tournamentSize;
    /** Probability of creating a feature with a binary (rather than unary) operator. */
    double binary_recomb_rate;
    /** Importance score of every column, refreshed by {@code evalAllFeatures}. */
    double[] featureScores;
    /** Column indices of the archived evolved features. */
    int[] indicesArchive;
    /** Weights of the current model, as strings, in model-column order. */
    ArrayList<String> alWeights;
    Random r;
    double lassoIntercept;
    /** Features whose size (as reported by the data matrix) would reach this are not created. */
    int maxFeatureSize;
    double BEST_MSE;
    double CURRENT_MSE;
    ArrayList<String> bestFeatures;
    ArrayList<Double> bestWeights;
    double bestIntercerpt;
    /** Consecutive iterations without an MSE improvement. */
    int STALL_ITERATIONS;
    int indexIteration;
    boolean VERBOSE = false;
    String[] unaryOps = {"mylog", "exp", "mysqrt", "square", "cube", "cos", "sin"};
    String[] binaryOps = {"*", "mydiv", "+", "-"};
    int MAX_STALL_ITERATIONS = 200;
    /** Feature scoring: 0 = path-usage count, 1 = sum of R^2, 2 = sum of inverse training MSE. */
    int FITNESS_BIAS = 1;
    /** Lambda choice for the reported model: 1 = max R^2, 2 = min training MSE, else the first. */
    int MODEL_SELECTION_BIAS = 1;

    /**
     * Pairs a column index with its importance score. The natural ordering is by descending
     * score, so sorting puts the most important column first.
     */
    public static class FeatureScore implements Comparable<FeatureScore> {
        private final int index;
        private final double score;

        public FeatureScore(int i, double d) {
            this.index = i;
            this.score = d;
        }

        public int getIndex() {
            return this.index;
        }

        public double getScore() {
            return this.score;
        }

        /**
         * Orders by descending score. {@code Double.compare} is a total order (NaN included),
         * which {@code Collections.sort} requires; the old hand-written comparison treated NaN
         * as equal to everything, violating transitivity.
         */
        @Override
        public int compareTo(FeatureScore featureScore) {
            return Double.compare(featureScore.getScore(), this.score);
        }
    }

    /**
     * Loads the data and initialises the search state. The random seed is the current time and
     * is printed so that a run can be reproduced.
     *
     * @param dataSource data source handed to {@code EFMScaledData} (presumably a file path —
     *     TODO confirm)
     * @param archiveSize number of evolved features kept in the archive
     * @param newFeatureCount number of evolved features regenerated each iteration
     * @param maxSize exclusive upper bound on the size of a newly created feature
     * @param maxModelFeatures bound passed to the lasso fit of the reported model
     * @throws IOException if the data cannot be loaded
     */
    public RegressionEFM(String dataSource, int archiveSize, int newFeatureCount, int maxSize,
            int maxModelFeatures) throws IOException {
        this.numberOfArchiveFeatures = archiveSize;
        this.numberOfNewFeatures = newFeatureCount;
        this.maxFinalFeatures = maxModelFeatures;
        this.maxFeatureSize = maxSize;
        this.dataMatrix = new EFMScaledData(this.numberOfArchiveFeatures, this.numberOfNewFeatures, dataSource);
        this.numberOfOriginalFeatures = this.dataMatrix.getNumberOfOriginalFeatures();
        this.featureScores = new double[totalFeatures()];
        this.indicesArchive = new int[this.numberOfArchiveFeatures];
        for (int k = 0; k < this.numberOfArchiveFeatures; k++) {
            // Initially the archive is the first block of evolved columns.
            this.indicesArchive[k] = this.numberOfOriginalFeatures + k;
        }
        this.lassoIntercept = 0.0d;
        this.tournamentSize = 2;
        this.binary_recomb_rate = 0.5d;
        long seed = System.currentTimeMillis();
        System.out.println(seed);
        this.r = new Random(seed);
        this.BEST_MSE = Double.MAX_VALUE;
        this.bestFeatures = new ArrayList<>();
        this.STALL_ITERATIONS = 0;
        this.startTime = System.currentTimeMillis();
    }

    /**
     * Runs the evolutionary loop until {@link #stopCriteria()} fires, then writes the best
     * feature set and model to {@code features.txt} and {@code model.txt}.
     *
     * @param timeoutSeconds wall-clock budget measured from construction. A value {@code <= 0}
     *     keeps any preset {@code TIMEOUT}, or else means "no time limit". Previously an unset
     *     deadline of 0 was always in the past, so the run stopped after a single iteration.
     * @throws Exception propagated from lasso fitting or file output
     */
    public void runEFM(double timeoutSeconds) throws Exception {
        if (timeoutSeconds > 0.0d) {
            this.TIMEOUT = this.startTime + (timeoutSeconds * 1000.0d);
        } else if (this.TIMEOUT <= 0.0d) {
            this.TIMEOUT = Double.POSITIVE_INFINITY;
        }
        this.dataMatrix.fillInitialArchiveandNewFeatures(this.r);
        this.indexIteration = 0;
        evalAllFeatures();
        boolean done = false;
        while (!done) {
            generateNewFeatures();
            evalAllFeatures();
            getCurrentModelWeights();
            if (this.VERBOSE) {
                saveCurrentFeatureSet();
                saveCurrentModel();
            }
            done = stopCriteria();
            this.indexIteration++;
        }
        saveBestFeatureSet(true);
        saveBestModel(true);
    }

    /**
     * Returns {@code true} (after printing the reason and the best MSE) when the deadline has
     * passed or the stall counter exceeds {@code MAX_STALL_ITERATIONS}.
     */
    public boolean stopCriteria() {
        if (System.currentTimeMillis() >= this.TIMEOUT) {
            System.out.println("Timout exceeded, exiting. BEST MSE IS: " + this.BEST_MSE);
            return true;
        }
        if (this.STALL_ITERATIONS > this.MAX_STALL_ITERATIONS) {
            System.out.println("Progress Stalled, exiting. BEST MSE IS: " + this.BEST_MSE);
            return true;
        }
        return false;
    }

    // ---------------------------------------------------------------- column and model helpers

    /** Total number of columns: original + archive + new. */
    private int totalFeatures() {
        return this.numberOfOriginalFeatures + this.numberOfArchiveFeatures + this.numberOfNewFeatures;
    }

    /** Mask selecting every column. */
    private boolean[] allColumns() {
        boolean[] mask = new boolean[totalFeatures()];
        for (int col = 0; col < mask.length; col++) {
            mask[col] = true;
        }
        return mask;
    }

    /** Mask selecting the columns of the reported model: all original columns plus the archive. */
    private boolean[] modelColumns() {
        boolean[] mask = new boolean[totalFeatures()];
        for (int col = 0; col < this.numberOfOriginalFeatures; col++) {
            mask[col] = true;
        }
        for (int col : this.indicesArchive) {
            mask[col] = true;
        }
        return mask;
    }

    /** Loads the selected columns (as floats) and the target of every fitness case into the generator. */
    private void loadObservations(LassoFitGenerator generator, boolean[] columns, int width) throws Exception {
        double[] targets = this.dataMatrix.getTargetValues();
        int cases = this.dataMatrix.getNumberOfFitnessCases();
        for (int c = 0; c < cases; c++) {
            double[] row = this.dataMatrix.getRow(c);
            float[] observation = new float[width];
            int k = 0;
            for (int col = 0; col < columns.length; col++) {
                if (columns[col]) {
                    observation[k++] = (float) row[col];
                }
            }
            generator.setObservationValues(c, observation);
            generator.setTarget(c, targets[c]);
        }
    }

    /** Linear prediction for one row; the k-th weight applies to the k-th selected column. */
    private static double predict(double[] row, double intercept, double[] weights, boolean[] columns) {
        double y = intercept;
        int w = 0;
        for (int col = 0; col < columns.length; col++) {
            if (columns[col]) {
                y += row[col] * weights[w++];
            }
        }
        return y;
    }

    /** Training MSE over all fitness cases of the linear model restricted to {@code columns}. */
    private double meanSquaredError(double[][] inputs, double[] targets, double intercept,
            double[] weights, boolean[] columns) {
        int cases = this.dataMatrix.getNumberOfFitnessCases();
        double sse = 0.0d;
        for (int c = 0; c < cases; c++) {
            sse += Math.pow(targets[c] - predict(inputs[c], intercept, weights, columns), 2.0d);
        }
        return sse / cases;
    }

    /** Index of the largest R^2 along the path (first occurrence; NaN entries are skipped). */
    private static int indexOfMaxRSquared(LassoFit fit) {
        int best = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int l = 0; l < fit.lambdas.length; l++) {
            if (fit.rsquared[l] > bestValue) {
                best = l;
                bestValue = fit.rsquared[l];
            }
        }
        return best;
    }

    /** Weight vectors of the first {@code count} path solutions, materialised once. */
    private static double[][] weightPath(LassoFit fit, int count) {
        double[][] path = new double[count][];
        for (int l = 0; l < count; l++) {
            path[l] = fit.getWeights(l);
        }
        return path;
    }

    // ---------------------------------------------------------------- feature importance

    /**
     * Scores every column and appends the evolved ones (index >= numberOfOriginalFeatures) to
     * {@code out}. A column that is zero in the solution at {@code selectedIndex} scores 0.
     * Otherwise its score is the sum of {@code contribution[l]} over every path solution
     * {@code l < contribution.length} in which the column's weight is non-zero.
     */
    private void scoreFeatures(LassoFit fit, int selectedIndex, double[] contribution,
            ArrayList<FeatureScore> out) {
        double[] selected = fit.getWeights(selectedIndex);
        double[][] path = weightPath(fit, contribution.length);
        for (int f = 0; f < selected.length; f++) {
            double score = 0.0d;
            if (selected[f] != 0.0d) {
                for (int l = 0; l < path.length; l++) {
                    if (path[l][f] != 0.0d) {
                        score += contribution[l];
                    }
                }
            }
            this.featureScores[f] = score;
            if (f >= this.numberOfOriginalFeatures) {
                out.add(new FeatureScore(f, score));
            }
        }
    }

    /** Score = number of path solutions using the column (reference: max-R^2 solution). */
    private void computeFeatureImportanceVariableCount(LassoFit lassoFit, ArrayList<FeatureScore> arrayList) {
        double[] contribution = new double[lassoFit.nonZeroWeights.length];
        for (int l = 0; l < contribution.length; l++) {
            contribution[l] = 1.0d;
        }
        scoreFeatures(lassoFit, indexOfMaxRSquared(lassoFit), contribution, arrayList);
    }

    /** Score = sum of R^2 of the path solutions using the column (reference: max-R^2 solution). */
    private void computeFeatureImportanceBiasR2(LassoFit lassoFit, ArrayList<FeatureScore> arrayList) {
        double[] contribution = new double[lassoFit.nonZeroWeights.length];
        for (int l = 0; l < contribution.length; l++) {
            contribution[l] = lassoFit.rsquared[l];
        }
        scoreFeatures(lassoFit, indexOfMaxRSquared(lassoFit), contribution, arrayList);
    }

    /**
     * Score = sum of 1/MSE of the path solutions using the column (reference: min-MSE
     * solution). A solution with zero training MSE contributes +Infinity.
     */
    private void computeFeatureImportanceBiasMSE(LassoFit lassoFit, ArrayList<FeatureScore> arrayList) {
        double[][] inputs = this.dataMatrix.getInputValues();
        double[] targets = this.dataMatrix.getTargetValues();
        boolean[] all = allColumns();
        double[] mse = new double[lassoFit.lambdas.length];
        int best = 0;
        double bestMse = Double.MAX_VALUE;
        for (int l = 0; l < mse.length; l++) {
            mse[l] = meanSquaredError(inputs, targets, lassoFit.intercepts[l], lassoFit.getWeights(l), all);
            if (mse[l] < bestMse) {
                bestMse = mse[l];
                best = l;
            }
        }
        double[] contribution = new double[lassoFit.numberOfLambdas];
        for (int l = 0; l < contribution.length; l++) {
            contribution[l] = 1.0d / mse[l];
        }
        scoreFeatures(lassoFit, best, contribution, arrayList);
    }

    /**
     * Fits a lasso path on all columns, scores them per {@code FITNESS_BIAS}, and replaces the
     * archive with the highest-scoring evolved columns.
     */
    private void evalAllFeatures() throws Exception {
        int total = totalFeatures();
        LassoFitGenerator generator = new LassoFitGenerator();
        generator.init(total, this.dataMatrix.getNumberOfFitnessCases());
        loadObservations(generator, allColumns(), total);
        LassoFit fit = generator.fit(total, MAX_LAMBDAS);
        this.alWeights = new ArrayList<>();
        ArrayList<FeatureScore> candidates = new ArrayList<>();
        if (this.FITNESS_BIAS == 0) {
            computeFeatureImportanceVariableCount(fit, candidates);
        } else if (this.FITNESS_BIAS == 1) {
            computeFeatureImportanceBiasR2(fit, candidates);
        } else if (this.FITNESS_BIAS == 2) {
            computeFeatureImportanceBiasMSE(fit, candidates);
        }
        Collections.sort(candidates);
        for (int k = 0; k < this.numberOfArchiveFeatures; k++) {
            this.indicesArchive[k] = candidates.get(k).getIndex();
        }
    }

    /**
     * Index of the path solution with the lowest training MSE. The MSE is evaluated on the
     * model's actual columns (original + archive). Previously the first {@code orig + archive}
     * columns were used, which are not the archived ones once the archive has been reselected.
     */
    private int getIndexLambdaModelSelectionMSE(LassoFit lassoFit) {
        double[][] inputs = this.dataMatrix.getInputValues();
        double[] targets = this.dataMatrix.getTargetValues();
        boolean[] columns = modelColumns();
        double bestMse = Double.MAX_VALUE;
        int best = 0;
        for (int l = 0; l < lassoFit.lambdas.length; l++) {
            double mse = meanSquaredError(inputs, targets, lassoFit.intercepts[l], lassoFit.getWeights(l), columns);
            if (mse < bestMse) {
                bestMse = mse;
                best = l;
            }
        }
        return best;
    }

    /**
     * Fits the reported model on the original + archived columns, picks a lambda per
     * {@code MODEL_SELECTION_BIAS}, prints its training MSE/MAE, and records it as the best model
     * when its MSE strictly improves. A NaN MSE counts as a stall instead of replacing the best.
     */
    private void getCurrentModelWeights() throws Exception {
        boolean[] columns = modelColumns();
        int width = this.numberOfOriginalFeatures + this.numberOfArchiveFeatures;
        int cases = this.dataMatrix.getNumberOfFitnessCases();
        LassoFitGenerator generator = new LassoFitGenerator();
        generator.init(width, cases);
        loadObservations(generator, columns, width);
        LassoFit fit = generator.fit(this.maxFinalFeatures, MAX_LAMBDAS);

        int lambdaIndex = 0;
        if (this.MODEL_SELECTION_BIAS == 1) {
            lambdaIndex = indexOfMaxRSquared(fit);
        } else if (this.MODEL_SELECTION_BIAS == 2) {
            lambdaIndex = getIndexLambdaModelSelectionMSE(fit);
        }

        double[] weights = fit.getWeights(lambdaIndex);
        this.alWeights = new ArrayList<>(weights.length);
        for (double w : weights) {
            this.alWeights.add(Double.toString(w));
        }
        this.lassoIntercept = fit.intercepts[lambdaIndex];

        double[][] inputs = this.dataMatrix.getInputValues();
        double[] targets = this.dataMatrix.getTargetValues();
        double sse = 0.0d;
        double sae = 0.0d;
        for (int c = 0; c < cases; c++) {
            double residual = targets[c] - predict(inputs[c], this.lassoIntercept, weights, columns);
            sse += Math.pow(residual, 2.0d);
            sae += Math.abs(residual);
        }
        double mse = sse / cases;
        System.out.println("TIME IS: " + ((System.currentTimeMillis() - this.startTime) / 1000.0d) + " ; CURRENT MSE IS: " + mse + " ; MAE IS: " + (sae / cases) + " ; BEST MSE IS: " + this.BEST_MSE);
        this.CURRENT_MSE = mse;
        // Written as !(a < b) so that a NaN MSE is treated as "no improvement".
        if (!(this.CURRENT_MSE < this.BEST_MSE)) {
            this.STALL_ITERATIONS++;
            return;
        }
        this.BEST_MSE = this.CURRENT_MSE;
        this.STALL_ITERATIONS = 0;
        this.bestIntercerpt = this.lassoIntercept;
        this.bestFeatures = new ArrayList<>();
        this.bestWeights = new ArrayList<>();
        int w = 0;
        for (int col = 0; col < columns.length; col++) {
            if (columns[col]) {
                this.bestFeatures.add(this.dataMatrix.getFeatureString(col));
                this.bestWeights.add(weights[w++]);
            }
        }
    }

    // ---------------------------------------------------------------- variation operators

    /**
     * Regenerates every non-archived evolved column. Each one is built from a tournament-selected
     * parent, or from two parents with probability {@code binary_recomb_rate}. When the combined
     * parent size would reach {@code maxFeatureSize}, the column is zeroed instead.
     */
    private void generateNewFeatures() {
        int total = totalFeatures();
        for (int target = this.numberOfOriginalFeatures; target < total; target++) {
            if (archiveContains(target)) {
                continue;
            }
            int left = tournamentSelection();
            if (this.r.nextFloat() < this.binary_recomb_rate) {
                int right = tournamentSelection();
                if (this.dataMatrix.getFeatureSize(left) + this.dataMatrix.getFeatureSize(right) < this.maxFeatureSize) {
                    binaryRecombination(target, left, right);
                } else {
                    this.dataMatrix.setFeatureToZero(target);
                }
            } else if (this.dataMatrix.getFeatureSize(left) < this.maxFeatureSize) {
                unaryRecombination(target, left);
            } else {
                this.dataMatrix.setFeatureToZero(target);
            }
        }
    }

    /** Whether column {@code i} is currently in the archive. */
    private boolean archiveContains(int i) {
        for (int archived : this.indicesArchive) {
            if (archived == i) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tournament of size {@code tournamentSize} over original and archived columns. The highest
     * {@code featureScores} entry wins, and ties keep the earlier draw.
     */
    private int tournamentSelection() {
        int winner = randomParentCandidate();
        for (int round = 1; round < this.tournamentSize; round++) {
            int challenger = randomParentCandidate();
            if (this.featureScores[challenger] > this.featureScores[winner]) {
                winner = challenger;
            }
        }
        return winner;
    }

    /** Draws uniformly among the original columns and the archived columns. */
    private int randomParentCandidate() {
        int draw = this.r.nextInt(this.numberOfOriginalFeatures + this.numberOfArchiveFeatures);
        return draw < this.numberOfOriginalFeatures
                ? draw
                : this.indicesArchive[draw - this.numberOfOriginalFeatures];
    }

    /** Writes {@code op(left, right)} into column {@code target}, with op drawn from {@code binaryOps}. */
    private void binaryRecombination(int target, int left, int right) {
        String op = this.binaryOps[this.r.nextInt(this.binaryOps.length)];
        switch (op) {
            case "*":
                this.dataMatrix.multiplication(target, left, right);
                break;
            case "mydiv":
                this.dataMatrix.division(target, left, right);
                break;
            case "+":
                this.dataMatrix.sum(target, left, right);
                break;
            case "-":
                this.dataMatrix.minus(target, left, right);
                break;
            default:
                break;
        }
    }

    /** Writes {@code op(parent)} into column {@code target}, with op drawn from {@code unaryOps}. */
    private void unaryRecombination(int target, int parent) {
        String op = this.unaryOps[this.r.nextInt(this.unaryOps.length)];
        switch (op) {
            case "mylog":
                this.dataMatrix.log(target, parent);
                break;
            case "exp":
                this.dataMatrix.exp(target, parent);
                break;
            case "mysqrt":
                this.dataMatrix.sqrt(target, parent);
                break;
            case "square":
                this.dataMatrix.square(target, parent);
                break;
            case "cube":
                this.dataMatrix.cube(target, parent);
                break;
            case "cos":
                this.dataMatrix.cos(target, parent);
                break;
            case "sin":
                this.dataMatrix.sin(target, parent);
                break;
            default:
                break;
        }
    }

    // ---------------------------------------------------------------- output

    /**
     * Writes {@code text} to {@code path}, appending when {@code append} is true. The writer is
     * always closed, and write/close failures propagate as {@link IOException}.
     */
    private void saveText(String path, String text, Boolean append) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(path, append.booleanValue()))) {
            writer.write(text);
        }
    }

    /** Writes the current model's feature strings, comma-terminated, to {@code features_<iter>.txt}. */
    private void saveCurrentFeatureSet() throws IOException {
        boolean[] columns = modelColumns();
        StringBuilder text = new StringBuilder();
        for (int col = 0; col < columns.length; col++) {
            if (columns[col]) {
                text.append(this.dataMatrix.getFeatureString(col)).append(',');
            }
        }
        saveText("features_" + this.indexIteration + ".txt", text.toString(), false);
    }

    /** Writes the current model (intercept, then one " + w * feature" line each) to {@code model_<iter>.txt}. */
    private void saveCurrentModel() throws IOException {
        System.out.println(this.indexIteration);
        boolean[] columns = modelColumns();
        StringBuilder text = new StringBuilder();
        text.append(this.lassoIntercept + "\n");
        int w = 0;
        for (int col = 0; col < columns.length; col++) {
            if (columns[col]) {
                text.append(" + ").append(this.alWeights.get(w++))
                        .append(" * ").append(this.dataMatrix.getFeatureString(col)).append('\n');
            }
        }
        saveText("model_" + this.indexIteration + ".txt", text.toString(), false);
    }

    /** Writes the best feature set to {@code features.txt} (final) or {@code features_<iter>.txt}. */
    private void saveBestFeatureSet(boolean z) throws IOException {
        String path = z ? "features.txt" : "features_" + this.indexIteration + ".txt";
        StringBuilder text = new StringBuilder();
        for (String feature : this.bestFeatures) {
            text.append(feature).append(',');
        }
        saveText(path, text.toString(), false);
    }

    /**
     * Writes the best model to {@code model.txt} (final) or {@code model_<iter>.txt}. The file
     * contains the target min/max line, the intercept, and one line per non-zero-weight feature.
     */
    private void saveBestModel(boolean z) throws IOException {
        String path = z ? "model.txt" : "model_" + this.indexIteration + ".txt";
        StringBuilder text = new StringBuilder();
        text.append(this.dataMatrix.getTargetMin() + "," + this.dataMatrix.getTargetMax() + "\n");
        text.append(this.bestIntercerpt + "\n");
        for (int k = 0; k < this.bestFeatures.size(); k++) {
            Double weight = this.bestWeights.get(k);
            if (weight.doubleValue() != 0.0d) {
                text.append(" + ").append(weight).append(" * ").append(this.bestFeatures.get(k)).append('\n');
            }
        }
        saveText(path, text.toString(), false);
    }
}
