(*
 * lablai - An ML Artificial Inteligence library
 * Copyright (C) 2006  Till Crueger
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *)


(* File $RCSfile$ *)
(* last edited by $Author: till_crueger $ *)
(* $Date: 2007-12-15 19:53:03 +0100 (Sa, 15 Dez 2007) $, $Revision: 29 $ *)

open Types;;

(* A node contains a value describing the position and a list of associated examples *)
type 'a node = { mutable value : 'a ; mutable content : 'a list};;


type dist_coord = {dist : float; coord : int coord};;

class ['a] som (size' : int coord) dist' (move' : float -> '-> '-> 'a) (init : int coord -> 'a)  =
   (* Initialize the array of nodes with values picked by the init function *)
   let nodes' = Array.init size'.x (fun x -> Array.init size'.y (fun y -> {value = init {x=x; y=y}; content = []})) in
   object (self)
      (* complete node table *)
      val nodes = nodes'
      val size : int coord = size'
      
      (* functions to work with the search space *)
      val dist  : '-> '-> float = dist'
      val move = move'
      
      val mutable generation = 0
      
      val mutable radius = 20
      val mutable alpha = 0.2
      
      (* The policies overwrite the fixed values *)
      val mutable radius_policy : (int -> int) option = None
      val mutable alpha_policy : (int -> float) option = None
      
      method set_radius radius' = 
         radius <- radius' ;
         radius_policy <-None
      method get_radius = 
         match radius_policy with
            None -> Some radius
         |  Some _ -> None
      
      method set_alpha alpha' =
         alpha <- alpha';
         alpha_policy <-None
      method get_alpha = 
         match alpha_policy with
            None ->  Some alpha
         |  Some _ -> None
      
      method set_radius_policy radius_policy' = radius_policy <- Some radius_policy'
      method get_radius_policy = radius_policy
      
      method set_alpha_policy alpha_policy' = alpha_policy <- Some alpha_policy'
      method get_alpha_policy = alpha_policy
      
      method get_value coord =
         nodes.(coord.x).(coord.y).value
   
      (* find the closest node to a given example *)
      method find_node example = 
         (* calculate the distance for all nodes *)
         let mapy x = Array.mapi (fun y -> fun node -> {dist= dist example node.value; coord= {x=x;y=y}}) in
         let evaluated = Array.mapi mapy nodes in
         (* find the one with minimum distance by folding twice *)
         let folder current best = if current.dist < best.dist then current else best in
         (* first fold all all columns (get the best in each column)*)
         let folder2 x arr = Array.fold_right folder arr evaluated.(x).(0) in
         let folded1 = Array.mapi folder2 evaluated in
         (* then fold the resulting array *)
         let folded2 = Array.fold_left folder folded1.(0) folded1 in
         (* we just need the coord *)
         folded2.coord
      
      (* move a node closer to another position*)
      method move_node x y alpha pos =
         (* make sure we don't hit the Array bounds *)
         if (x < size.x) && (x >=0) then
            if (y < size.y) && (y >=0) then
               nodes.(x).(y).value<- move alpha nodes.(x).(y).value pos
            else 
            ()
         else 
         ()
      
      (* Train the node with a given example *)
      method train example = 
         (* find out which trainign radius to use *)
         let radius' =
            (* when a policy is set, that is used *)
            match radius_policy with
               Some policy -> policy generation
            | None ->radius 
         in
         let alpha' =
            (* Same as above for the learning ratio *)
            match alpha_policy with
               Some policy -> policy generation
            | None -> alpha
         in
         
         (* find the best matching node for the example *)
         let best = self#find_node example in
         
         (* move all the nodes in the neighborhood *)
         self#move_node best.x best.y alpha' example;
         for distance = 1 to radius' do
            let ratio = (1.0-. (float_of_int distance) /. (float_of_int radius')) **2.0in
            for x = 0 to distance-1  do
               self#move_node (best.x-x) (best.y-distance+x) (alpha' *. ratio) example;
               self#move_node (best.x-distance+x) (best.y+x) (alpha' *. ratio) example;
               self#move_node (best.x+x) (best.y+distance-x) (alpha' *. ratio) example;
               self#move_node (best.x+distance-x) (best.y-x) (alpha' *. ratio) example;
            done
         done;
         
         generation <- generation +1
      
      (* Batch learn a whole set of examples *)
      method learn examples =
         (* try to learn as long as the function can provide us with examples *)
         let rec loop () = 
            let example = examples self in
            match example with
               None -> ()
            |  Some example -> self#train example; loop ()
         in loop ()
      
      (* can be used for dumping the network on screen for example *)
      method dump (func : int -> int -> '-> unit) (func2 : unit -> unit) = 
         (* go through all lines and call the passed functions *)
         for x = 0 to size.x -1 do
            for y = 0 to size.y-1 do
               func x y nodes.(x).(y).value
            done;
            func2 ()
         done
         
   end;;

(* 
 * $Log$
 * Revision 1.2  2007/12/15 18:52:58  till_crueger
 * - Updated documentation
 * - Moved Log-Tags to a better position in the sources
 *
 *)