Domáca č. 3 - Backpropagation na konkrétnej sieti a príklade

Created: 2008-10-15 - 09:11

Nechcelo sa mi to počítať ručne, tak som si pomocou „ctrl+c a ctrl+v“ programovania vytvoril pomôcku:

Vstupy sa zadávajú ručne (priamo v zdrojovom kóde).
Backpropagation:


import java.util.ArrayList;

/**
 * One-step backpropagation calculator for a fixed 2-2-1 network.
 *
 * <p>The network has two inputs, two hidden neurons and one output neuron, all
 * using the logistic sigmoid. Every neuron additionally receives the constant
 * input {@code xi = -1} (weight index 0). The constructor draws random initial
 * weights and prints them; {@link #evaluate()} prints the forward pass, the
 * deltas, the weight changes and the resulting weights for the single
 * hard-coded training sample. The stored weights themselves are never modified.
 *
 * @author Pavol Rajzák "rapasoft"
 * @see <a href="http://rapasoft.netkosice.sk">http://rapasoft.netkosice.sk</a>
 */
public class Backpropagation {
	// Weights are indexed [neuron][input] with neurons numbered from 1 so that
	// the code mirrors the usual textbook notation; row 0 is unused.
	private double[][] w1; // first-layer weights
	private double[][] w2; // second-layer weights
	private double ni; // learning rate (multiplies every weight change)
	final double xi = -1.0; // constant input paired with weight index 0 of every neuron
	private ArrayList< ArrayList<Double> > t; // training set; each row is {x1, x2, d}

	// Neuron outputs VIJ, where I is the layer and J the neuron index within it.
	private double V11;
	private double V12;
	private double V21;

	// Error terms (deltas) of the individual neurons.
	private double delta21;
	private double delta11;
	private double delta12;

	/**
	 * Activation function: logistic sigmoid.
	 *
	 * @param x net input of the neuron
	 * @return {@code 1 / (1 + e^-x)}
	 */
	private double f(double x){
		return (1/(1+Math.exp(-x)));
	}

	/**
	 * Derivative of the activation function.
	 *
	 * <p>Written as {@code f(x) * (1 - f(x))} instead of
	 * {@code e^-x / (1 + e^-x)^2}: the two forms are mathematically equal, but
	 * the latter evaluates to Infinity/Infinity (NaN) once {@code e^-x}
	 * overflows for very negative {@code x}.
	 *
	 * @param x net input of the neuron
	 * @return slope of the sigmoid at {@code x}
	 */
	private double fD(double x){
		double s = f(x);
		return s*(1-s);
	}

	/**
	 * Creates the network with the given learning rate, the hard-coded training
	 * sample (x1 = 0.9, x2 = 0.4, d = -0.5) and random initial weights, which
	 * are printed to standard output as they are generated.
	 *
	 * @param ni learning rate used when computing weight changes
	 */
	public Backpropagation(double ni){
		this.ni = ni;

		// The training sample is entered by hand here; nothing is read from a file.
		t = new ArrayList<ArrayList<Double>>();
		ArrayList<Double> sample = new ArrayList<Double>();
		sample.add(0.9); //x1
		sample.add(0.4); //x2
		sample.add(-0.5); //d
		t.add(sample);

		// Generate the weights: two hidden neurons and one output neuron,
		// each with three weights.
		w1 = new double[3][3];
		w2 = new double[2][3];
		randomizeLayer("w1", w1);
		randomizeLayer("w2", w2);
	}

	/** Returns a random weight drawn from the multiples of 0.1 in [-0.4, 0.4]. */
	private static double randomWeight(){
		return ((int)(Math.random()*10-5))/10.0;
	}

	/** Fills rows 1.. of {@code layer} with random weights and prints each one. */
	private static void randomizeLayer(String label, double[][] layer){
		for (int neuron = 1; neuron < layer.length; neuron++) {
			for (int input = 0; input < layer[neuron].length; input++) {
				layer[neuron][input] = randomWeight();
				System.out.println(label+"["+neuron+"]["+input+"]: "+ layer[neuron][input]);
			}
		}
	}

	/** Weighted sum of the constant input {@code xi} and the two given inputs. */
	private double netInput(double[] weights, double in1, double in2){
		return xi*weights[0]+in1*weights[1]+in2*weights[2];
	}

	/** Weight changes {@code ni * delta * input} for the inputs {xi, in1, in2}. */
	private double[] weightChanges(double delta, double in1, double in2){
		return new double[] { ni * delta * xi, ni * delta * in1, ni * delta * in2 };
	}

	/** Element-wise sum of a neuron's weights and their changes. */
	private static double[] updated(double[] weights, double[] changes){
		double[] result = new double[changes.length];
		for (int i = 0; i < changes.length; i++) {
			result[i] = weights[i]+changes[i];
		}
		return result;
	}

	/** Prints one line per value in the form {@code prefix[neuron][i]: value}. */
	private static void printRow(String prefix, int neuron, double[] values){
		for (int i = 0; i < values.length; i++) {
			System.out.println(prefix+"["+neuron+"]["+i+"]: "+values[i]);
		}
	}

	/**
	 * Runs one forward and one backward pass for the first training sample and
	 * prints every intermediate value: net inputs, neuron outputs, deltas,
	 * weight changes and the weights that would result from applying them.
	 * The weights stored in this object are left unchanged.
	 */
	public void evaluate(){
		// Pick x1, x2 and the desired output d.
		ArrayList<Double> sample = t.get(0);
		double x1 = sample.get(0);
		double x2 = sample.get(1);
		double d1 = sample.get(2);

		// First neuron of the first layer.
		double h11 = netInput(w1[1], x1, x2);
		System.out.println("h11: "+h11);
		V11 = f(h11);
		System.out.println("V11: "+V11);

		// Second neuron of the first layer.
		double h12 = netInput(w1[2], x1, x2);
		System.out.println("h12: "+h12);
		V12 = f(h12);
		System.out.println("V12: "+V12);

		// The only neuron of the second layer; its inputs are the hidden outputs.
		double h21 = netInput(w2[1], V11, V12);
		System.out.println("h21: "+h21);
		V21 = f(h21);
		System.out.println("V21: "+V21);

		// Output delta first, then propagate it back through the second-layer weights.
		delta21 = fD(h21)*(d1-V21);
		System.out.println("delta21: "+delta21);
		delta11 = fD(h11)*(w2[1][1]*delta21);
		System.out.println("delta11: "+delta11);
		delta12 = fD(h12)*(w2[1][2]*delta21);
		System.out.println("delta12: "+delta12);
		System.out.println();

		double[] deltaW11 = weightChanges(delta11, x1, x2);
		printRow("delta w1", 1, deltaW11);

		System.out.println();

		double[] deltaW12 = weightChanges(delta12, x1, x2);
		printRow("delta w1", 2, deltaW12);

		double[] deltaW21 = weightChanges(delta21, V11, V12);
		printRow("delta w2", 1, deltaW21);

		System.out.println();

		System.out.println("Nove hodnoty:");

		// Each line is labelled with the weight it actually reports.
		printRow("w1", 1, updated(w1[1], deltaW11));
		printRow("w1", 2, updated(w1[2], deltaW12));
		printRow("w2", 1, updated(w2[1], deltaW21));
	}
}

BackpropagationTester

/**
 * Command-line launcher: builds a {@link Backpropagation} network and prints
 * one backpropagation step for its built-in training sample.
 */
public class BackpropagationTester {
	/**
	 * Runs a single evaluation with learning rate 0.9.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		// Backpropagation only declares a (double) constructor and hard-codes
		// its training sample, so no input file is passed.
		Backpropagation b = new Backpropagation(0.9);
		b.evaluate();
	}
}