(** [train_in_place alpha ~decay () net (input, example)] performs one
    backpropagation step on [net] for a single training pair, mutating the
    synapse (weight) arrays returned by [get_synapses] in place.

    [net] is a list of layers, first layer first; [input] is fed to the
    head of the list.  Each layer is an array of neurons whose synapse
    array holds one weight per layer input, index 0 being the weight of a
    constant bias input of [1.0].  For each weight [w] with input [a] and
    neuron error term [delta] the new weight is
    [w +. (alpha *. delta *. a -. decay *. w)]; the decay term is not
    scaled by [alpha].  The output-layer error is
    [example.(i) -. output.(i)].  An empty [net] is a no-op.

    @raise Invalid_argument ["Net size does not match input"] if a layer is
    empty or if any neuron's synapse count differs from the size of that
    layer's input plus one.  These checks run for every layer before any
    weight is modified.  If [example] is shorter than the network output,
    [Invalid_argument] (index out of bounds) is raised, also before any
    update; extra entries of [example] are ignored. *)
let train_in_place alpha ?(decay=0.0) () net (input, example) =
  (* Check that every neuron of the layer has one synapse per input plus
     one for the bias, then return the input with the bias value 1.0
     prepended.  Validating every neuron (not just the first) is what makes
     the unchecked accesses in [update_layer] sound. *)
  let augment neurons activation =
    if Array.length neurons = 0 then
      invalid_arg "Net size does not match input";
    let expected = Array.length activation + 1 in
    Array.iter
      (fun neuron ->
        if Array.length (get_synapses neuron) <> expected then
          invalid_arg "Net size does not match input")
      neurons;
    Array.append [|1.0|] activation
  in
  (* Apply the weight update to one layer.  [activation] is the layer's
     bias-augmented input and [deltas.(i)] the error term of neuron [i].
     When [propagate] is true, also return the error sums for the previous
     layer (bias entry dropped), accumulated from the weights as they were
     before this update; otherwise return [[||]]. *)
  let update_layer neurons activation deltas ~propagate =
    let m = Array.length activation in
    let back = if propagate then Array.make m 0.0 else [||] in
    Array.iteri
      (fun i neuron ->
        let weights = get_synapses neuron in
        let delta = deltas.(i) in
        for j = 0 to m - 1 do
          (* Safe: [augment] guaranteed length weights = m, and
             [back] has length m whenever it is written. *)
          let w = Array.unsafe_get weights j in
          let change =
            alpha *. delta *. Array.unsafe_get activation j -. decay *. w
          in
          if propagate then
            Array.unsafe_set back j (Array.unsafe_get back j +. delta *. w);
          Array.unsafe_set weights j (w +. change)
        done)
      neurons;
    if propagate then Array.sub back 1 (m - 1) else [||]
  in
  (* Forward through [layers] from [activation], then backward.  Every
     layer is checked on the way down, before any weight is touched;
     updates happen on the way back up, deepest layer first. *)
  let rec train layers activation ~propagate =
    match layers with
    | [] -> Array.mapi (fun i x -> example.(i) -. x) activation
    | neurons :: rest ->
      let activation = augment neurons activation in
      let output = neurons |*| activation in
      let delta_sums = train rest output ~propagate:true in
      (* NOTE(review): [derive] is given the neuron's output, so it
         presumably expresses the transfer derivative in terms of the
         output value — confirm. *)
      let deltas =
        Array.mapi
          (fun i x -> (derive (get_transfer neurons.(i)) output.(i)) *. x)
          delta_sums
      in
      update_layer neurons activation deltas ~propagate
  in
  match net with
  | [] -> ()
  | _ -> ignore (train net input ~propagate:false : float array)