package evogpj.postprocessing;

import evogpj.evaluation.java.CSVDataJava;
import evogpj.evaluation.java.DataJava;
import evogpj.gp.Population;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:evogpj/postprocessing/ModelFuserARM.class */
/**
 * Computes fusion weights for a population of models using an ARM-style (adaptive regression by
 * mixing) scheme.
 *
 * <p>Each of {@code numIters} rounds randomly permutes the rows of the fusion training set and
 * splits them into two halves. It estimates every model's residual standard deviation (sigma) on
 * the first half, and scores the model by the Gaussian log-likelihood, up to an additive
 * constant, of its residuals on the second half. The scores are normalized in log space into
 * per-round weights. The final weight of a model is the average of its per-round weights.
 */
public class ModelFuserARM {
    /**
     * Lower bound applied to sigma.
     *
     * <p>It equals 2^-511, so sigma * sigma is still a positive normal double. This keeps
     * {@code log(sigma)} finite and {@code 0 / sigma^2} equal to 0 rather than NaN.
     */
    private static final double MIN_SIGMA = Math.sqrt(Double.MIN_NORMAL);

    /** Path of the CSV file used to fit the fusion weights. */
    String fusionTrainingSet;
    /** Data loaded from {@link #fusionTrainingSet}. */
    DataJava ad;
    /** Number of random split rounds averaged by {@link #arm_weights()}. */
    int numIters;
    /** Number of fitness cases (rows) in the fusion training set. */
    int numberOfLines;
    /** Models being fused. */
    Population models;
    /** Number of models, i.e. {@code models.size()} at construction time. */
    int numModels;
    /** Rows {@code [0, indexSplit)} of the shuffled data estimate sigma; the remaining rows are scored. */
    int indexSplit;
    /** Number of input features; stored but not read by this class. */
    int numberOfFeatures;
    /** Flag forwarded to {@code EvalScaledModels.eval}. */
    boolean round;
    /** Model outputs, indexed {@code [row][model]}. */
    double[][] predictions;
    /** {@link #predictions} with its rows permuted for the current round. */
    double[][] predictionsShuffled;
    /** Target value of each row. */
    double[] target;
    /** {@link #target} permuted consistently with {@link #predictionsShuffled}. */
    double[] targetShuffled;
    /** Unnormalized log-weight of each model for the current round. */
    double[] expEjs;
    /** Normalized log-weights, indexed {@code [round][model]}. */
    double[][] lnWjs;
    /** Final fusion weight of each model. */
    double[] Wjs;

    /**
     * Loads the fusion training set and allocates the working buffers.
     *
     * @param fusionTrainingSet path of the CSV file holding the fusion training data
     * @param numberOfFeatures number of input features (stored only)
     * @param models models to fuse
     * @param numIters number of random split rounds; must be at least 1
     * @param round flag forwarded to the model evaluator
     * @throws IllegalArgumentException if {@code numIters < 1} or the data set has fewer than
     *     two rows (sigma could not be estimated and the weights would all be NaN)
     */
    public ModelFuserARM(String fusionTrainingSet, int numberOfFeatures, Population models,
            int numIters, boolean round) {
        if (numIters < 1) {
            throw new IllegalArgumentException("numIters must be at least 1, got " + numIters);
        }
        this.fusionTrainingSet = fusionTrainingSet;
        this.ad = new CSVDataJava(this.fusionTrainingSet);
        this.numberOfLines = this.ad.getNumberOfFitnessCases();
        if (this.numberOfLines < 2) {
            throw new IllegalArgumentException("ARM fusion needs at least 2 rows, got "
                    + this.numberOfLines + " in " + fusionTrainingSet);
        }
        this.numberOfFeatures = numberOfFeatures;
        // Keep the loaded targets: residuals are measured against these values.
        this.target = this.ad.getTargetValues();
        this.numIters = numIters;
        this.models = models;
        this.numModels = this.models.size();
        this.round = round;
        this.targetShuffled = new double[this.numberOfLines];
        this.predictions = new double[this.numberOfLines][this.numModels];
        this.predictionsShuffled = new double[this.numberOfLines][this.numModels];
        this.expEjs = new double[this.numModels];
        this.lnWjs = new double[this.numIters][this.numModels];
        this.Wjs = new double[this.numModels];
    }

    /**
     * Returns the population standard deviation of the first {@code count} entries of
     * {@code values}.
     */
    private static double std(double[] values, int count) {
        double sum = 0.0;
        for (int k = 0; k < count; k++) {
            sum += values[k];
        }
        double mean = sum / count;
        double sumSquares = 0.0;
        for (int k = 0; k < count; k++) {
            double deviation = values[k] - mean;
            sumSquares += deviation * deviation;
        }
        return Math.sqrt(sumSquares / count);
    }

    /**
     * Runs the scoring part of one round.
     *
     * <p>It shuffles the rows, sets {@link #indexSplit}, and stores each model's unnormalized
     * log-weight in {@link #expEjs}.
     */
    private void steps0TO3() {
        shuffleRows();
        this.indexSplit = this.numberOfLines / 2;
        double[] residuals = new double[this.numberOfLines];
        for (int model = 0; model < this.numModels; model++) {
            this.expEjs[model] = logWeight(model, residuals);
        }
    }

    /** Copies target and prediction rows into the shuffled arrays in a random order. */
    private void shuffleRows() {
        List<Integer> order = new ArrayList<>(this.numberOfLines);
        for (int row = 0; row < this.numberOfLines; row++) {
            order.add(row);
        }
        Collections.shuffle(order);
        for (int row = 0; row < this.numberOfLines; row++) {
            int source = order.get(row);
            this.targetShuffled[row] = this.target[source];
            System.arraycopy(this.predictions[source], 0, this.predictionsShuffled[row], 0,
                    this.numModels);
        }
    }

    /**
     * Scores one model on the current shuffled split.
     *
     * <p>Sigma is the population standard deviation of the model's residuals on rows
     * {@code [0, indexSplit)}. The score is {@code -m * ln(sigma) - SSE / (2 * sigma^2)}, where:
     * <ul>
     *   <li>{@code m} is the number of remaining rows;</li>
     *   <li>{@code SSE} is the sum of squared residuals on those rows.</li>
     * </ul>
     *
     * @param model column index into the prediction arrays
     * @param residuals scratch buffer of length {@code numberOfLines}; overwritten
     * @return the unnormalized log-weight of the model
     */
    private double logWeight(int model, double[] residuals) {
        for (int row = 0; row < this.numberOfLines; row++) {
            residuals[row] = this.targetShuffled[row] - this.predictionsShuffled[row][model];
        }
        // A model with zero residual spread on the first half would otherwise produce log(0)
        // and 0 * Infinity = NaN. Through the normalization, that NaN would poison the weights
        // of every model.
        double sigma = Math.max(std(residuals, this.indexSplit), MIN_SIGMA);
        double sse = 0.0;
        for (int row = this.indexSplit; row < this.numberOfLines; row++) {
            sse += residuals[row] * residuals[row];
        }
        // One sigma factor per held-out row in the Gaussian likelihood.
        int heldOut = this.numberOfLines - this.indexSplit;
        return -heldOut * Math.log(sigma) - sse / (2.0 * sigma * sigma);
    }

    /**
     * Normalizes {@link #expEjs} in log space and stores the result as round {@code iter} of
     * {@link #lnWjs}. When the scores are finite, the exponentials of the stored values sum to
     * one.
     */
    private void step04(int iter) {
        double logTotal = logSumExpExpEjs();
        for (int model = 0; model < this.numModels; model++) {
            this.lnWjs[iter][model] = this.expEjs[model] - logTotal;
        }
    }

    /**
     * Computes {@code ln(sum_j exp(expEjs[j]))} stably by factoring out the largest score.
     *
     * <p>Entries equal to negative infinity contribute nothing. NaN entries are ignored when
     * choosing the largest score.
     */
    private double logSumExpExpEjs() {
        double max = -Double.MAX_VALUE;
        for (double score : this.expEjs) {
            if (score > max) {
                max = score;
            }
        }
        double sum = 0.0;
        for (double score : this.expEjs) {
            if (score != Double.NEGATIVE_INFINITY) {
                sum += Math.exp(score - max);
            }
        }
        return max + Math.log(sum);
    }

    /** Averages each model's per-round weights {@code exp(lnWjs[iter][model])} into {@link #Wjs}. */
    private void step5() {
        for (int model = 0; model < this.numModels; model++) {
            double total = 0.0;
            for (int iter = 0; iter < this.numIters; iter++) {
                total += Math.exp(this.lnWjs[iter][model]);
            }
            this.Wjs[model] = total / this.numIters;
        }
    }

    /**
     * Evaluates every model on the fusion training set and returns the fusion weights.
     *
     * @return one weight per model, where index {@code i} belongs to {@code models.get(i)}. When
     *     all scores are finite the weights sum to approximately one. The array is this object's
     *     internal {@link #Wjs} buffer and is overwritten by later calls.
     */
    public double[] arm_weights() {
        EvalScaledModels evaluator = new EvalScaledModels(this.ad);
        for (int model = 0; model < this.models.size(); model++) {
            evaluator.eval(this.models.get(model), model, this.predictions, this.round);
        }
        for (int iter = 0; iter < this.numIters; iter++) {
            steps0TO3();
            step04(iter);
        }
        step5();
        return this.Wjs;
    }
}
