Domáca č. 2 - Učiaci algoritmus perceptrónu

Created: 2008-10-15 - 09:07

Načítavanie vstupov prebieha cez súbor:


0,2 -0,4 0,9
0,3 0,9 -0,4

PereptronTester:

import java.io.File;


public class PerceptronTest {
	public static void main(String[] args) {
		File subor = new File("t.txt");
		Perceptron p = new Perceptron(0.5,subor);
		p.natrenuj();
	}
}


Perceptron:

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;

/**
 * @author Pavol Rajzák
 */
public class Perceptron {
	/**
	 * Pocet prvkov treninngovej mnoziny
	 */
	private int m;

	/**
	 * Ulozenie treningovej mnoziny - zoznam zoznamov prvkov V kazdom zozname sa
	 * nachádza zoznam súradníc vektora x^i a hodnota d^i (napr. {2,6,1}
	 * znamena, ze x^i = (2,6) a d^i = 1
	 */
	private ArrayList<ArrayList<Double>> T;

	/**
	 * Určiaci pomer
	 */
	private double ro;

	/**
	 * Počiatočné nastavenie váh, ktoré sa generujú automaticky
	 */
	private ArrayList<Double> w0;

	/**
	 * Konštanta x0 = -1
	 */
	final int x0 = -1;

	/**
	 * Konstruktor nacitava zadanie treningovej mnoziny zo suboru, cita po
	 * riadkoch a uklada parametre do prislusnych premennych. Hodnotu m získa z
	 * počtu vstupných údajov (riadkov).
	 * 
	 * @param ro -
	 *            určiaci pomer
	 * @param subor -
	 *            vstupný súbor, kde vstupné hodnoty sú zapísané po riadkoch v
	 *            tvare x^i_0 x^i_1 d^i, napr.: -6 -3 1 2 4 -1 atď...
	 */
	public Perceptron(double ro, File subor) {
		T = new ArrayList<ArrayList<Double>>();
		w0 = new ArrayList<Double>();
		nacitajTrenovaciuMnozinu(subor);
		this.ro = ro;
	}

	/**
	 * Načítava trénovaciu množinu zo súboru
	 */
	private void nacitajTrenovaciuMnozinu(File subor) {
		try {
			Scanner sc = new Scanner(subor);
			int j = 0;
			while (sc.hasNextLine()) {
				String riadok = sc.nextLine();
				String[] podelene = riadok.split(" ");
				ArrayList<Double> tmp = new ArrayList<Double>();
				for (String string : podelene) {
					tmp.add(Double.parseDouble(string));
				}
				T.add(tmp);
				j++;
			}
			this.m = j;
		} catch (FileNotFoundException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	@Override
	public String toString() {
		String s = T.toString();
		return s;
	}

	/**
	 * Generovanie počiatočných váh
	 */
	private void generujW0() {
		w0.add((double) ((int) (Math.random() * 10 - 5)));
		w0.add((double) ((int) (Math.random() * 10 - 5)));
		w0.add((double) ((int) (Math.random() * 10 - 5)));
	}

	/**
	 * Samotný algoritmus trénovania na začiatku vygeneruje počiatočné váhy.
	 * Ďalej iteruje cez tréningovú množinu a nastavuje nové hodnoty váh. Ak sa
	 * váhy zmenia pokračuje, ak nie, zastaví sa.
	 */
	public void natrenuj() {
		generujW0();
		boolean zmenilaSaTrenovaciaMnozina = true;
		double w00 = w0.get(0);
		double w01 = w0.get(1);
		double w02 = w0.get(2);
		System.out.println("Vstup: w0 = (" + w00 + "," + w01 + "," + w02 + ")");
		System.out.println("");
		int k = 0; //počet iterácií
		while (zmenilaSaTrenovaciaMnozina) {
			if(k>100000){
				System.out.println("atď...");
				break;
			}
			double tmp0 = w00;
			double tmp1 = w01;
			double tmp2 = w02;
			for (int i = 0; i < m; i++) {
				k++;
				ArrayList<Double> tmp = T.get(i); 
				double t0 = tmp.get(0);
				double t1 = tmp.get(1);
				double t2 = tmp.get(2);
				double y = Math.signum(w00 * x0 + w01 * t0 + w02 * t1);
				w00 = w00 + ro * (t2 - y) * x0;
				w01 = w01 + ro * (t2 - y) * t0;
				w02 = w02 + ro * (t2 - y) * t1;
				System.out.println("w" + k + " = (" + w00 + "," + w01 + ","
						+ w02 + ")");
			}
			if (tmp0 == w00 && tmp1 == w01 && tmp2 == w02) {
				zmenilaSaTrenovaciaMnozina = false;
				System.out.println("Koniec");
			} else
				System.out.println("Zmena, pokracujem v iteracii");
		}
	}
}