{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Neural Network Definition and Training
module Net where

import           Lib
import           GHC.Generics
import qualified Torch            as T
import qualified Torch.Extensions as T
import qualified Torch.NN         as NN

------------------------------------------------------------------------------
-- Neural Network
------------------------------------------------------------------------------

-- | Size specification from which an 'OpNet' is sampled.
data OpNetSpec = OpNetSpec
    { numX :: !Int  -- ^ Number of input neurons (in-features of the first layer)
    , numY :: !Int  -- ^ Number of output neurons (out-features of the last layer)
    } deriving (Show, Eq)

-- | Network architecture: seven fully connected layers, applied in the
-- order @fc0@ .. @fc6@ by 'forward'.
--
-- The layer widths listed per field are the ones chosen by the
-- 'T.Randomizable' instance for 'OpNetSpec'.
--
-- NOTE(review): the Generic-derived 'T.Parameterized' instance presumably
-- flattens parameters in field-declaration order, so reordering fields would
-- likely invalidate previously saved checkpoints — confirm before changing.
data OpNet = OpNet
    { fc0 :: !T.Linear  -- ^ numX -> 32
    , fc1 :: !T.Linear  -- ^ 32   -> 128
    , fc2 :: !T.Linear  -- ^ 128  -> 512
    , fc3 :: !T.Linear  -- ^ 512  -> 128
    , fc4 :: !T.Linear  -- ^ 128  -> 64
    , fc5 :: !T.Linear  -- ^ 64   -> 32
    , fc6 :: !T.Linear  -- ^ 32   -> numY (output layer)
    } deriving (Generic, Show, T.Parameterized)

-- | Random weight initialisation for 'OpNet'.
--
-- Layer widths: numX -> 32 -> 128 -> 512 -> 128 -> 64 -> 32 -> numY.
-- Every layer is drawn with 'T.sample' on a 'T.LinearSpec', so the initial
-- weights follow whatever initialisation that instance performs. Layers are
-- sampled in field order, @fc0@ first.
instance T.Randomizable OpNetSpec OpNet where
    sample OpNetSpec{..} =
        OpNet <$> layer numX 32
              <*> layer 32   128
              <*> layer 128  512
              <*> layer 512  128
              <*> layer 128  64
              <*> layer 64   32
              <*> layer 32   numY
      where
        -- Sample one fully connected layer with the given in/out feature counts.
        layer :: Int -> Int -> IO T.Linear
        layer nIn nOut = T.sample (T.LinearSpec nIn nOut)

-- | Forward pass through an 'OpNet'.
--
-- Layers @fc0@ to @fc5@ are each followed by a ReLU; the final layer @fc6@
-- has no activation, so the output is a raw linear projection. The input is
-- expected to be already scaled — no scaling or unscaling happens here.
forward :: OpNet -> T.Tensor -> T.Tensor
forward OpNet{..} input = T.linear fc6 h5
  where
    -- One hidden step: affine transform followed by ReLU.
    dense layer = T.relu . T.linear layer
    h0 = dense fc0 input
    h1 = dense fc1 h0
    h2 = dense fc2 h1
    h3 = dense fc3 h2
    h4 = dense fc4 h3
    h5 = dense fc5 h4

------------------------------------------------------------------------------
-- Serialization
------------------------------------------------------------------------------

-- | Strip gradient tracking from a model, for tracing / scripting.
--
-- Each parameter is read out as a plain tensor, moved to the CPU, and
-- re-wrapped as a fresh independent tensor with gradient tracking disabled.
-- The model's structure is kept; only its parameters are replaced.
noGrad :: (NN.Parameterized f) => f -> IO f
noGrad net = NN.replaceParameters net <$> traverse freeze (NN.flattenParameters net)
  where
    -- Move one parameter to the CPU and rewrap it with requires-grad = False.
    freeze param =
        T.makeIndependentWithRequiresGrad (T.toDevice T.cpu (T.toDependent param)) False

------------------------------------------------------------------------------
-- Saving and Loading
------------------------------------------------------------------------------

-- | Save a model and optimizer checkpoint into the directory @path@.
--
-- Writes three files: @model.pt@ (network parameters via 'T.saveParams'),
-- and @M1.pt@ / @M2.pt@ (the Adam @m1@ / @m2@ tensor lists via 'T.save').
-- The Adam betas and iteration counter are NOT written; 'loadCheckPoint'
-- hard-codes the betas and takes the iteration count as an argument.
-- The directory is not created here and is assumed to exist.
saveCheckPoint :: FilePath -> OpNet -> T.Adam -> IO ()
saveCheckPoint path net opt = do
    T.saveParams net  modelFile
    T.save (T.m1 opt) m1File
    T.save (T.m2 opt) m2File
  where
    modelFile = path ++ "/model.pt"
    m1File    = path ++ "/M1.pt"
    m2File    = path ++ "/M2.pt"

-- | Load a checkpoint written by 'saveCheckPoint' from the directory @path@.
--
-- A network is first sampled from @spec@ (to obtain the right shape) and
-- its parameters are then overwritten from @model.pt@. The optimizer is
-- rebuilt as Adam with betas 0.9 / 0.999 (hard-coded, not read from disk),
-- @m1@ / @m2@ loaded from @M1.pt@ / @M2.pt@, and iteration counter @iter@.
loadCheckPoint :: FilePath -> OpNetSpec -> Int -> IO (OpNet, T.Adam)
loadCheckPoint path spec iter = do
    template <- T.sample spec
    net      <- T.loadParams template (path ++ "/model.pt")
    opt      <- T.Adam 0.9 0.999
                    <$> T.load (path ++ "/M1.pt")
                    <*> T.load (path ++ "/M2.pt")
                    <*> pure iter
    pure (net, opt)

-- | Trace @predict@ into a TorchScript module.
--
-- The function is traced as method @"forward"@ using a single random example
-- tensor of shape @[10, num]@; @predict@ is mapped over the traced input
-- list. Two extra TorchScript methods, @inputs@ and @outputs@, are defined
-- on the module and return @xs@ and @ys@ rendered with Haskell's 'show'.
-- The module is named @show pdk ++ "_" ++ show dev@.
--
-- NOTE(review): 'show' on a 'String' yields a double-quoted Haskell literal;
-- this presumably matches TorchScript string syntax only for plain ASCII
-- names (non-ASCII characters become Haskell numeric escapes) — confirm.
traceModel :: Device -> PDK -> Int -> [String] -> [String]
           -> (T.Tensor -> T.Tensor) -> IO T.ScriptModule
traceModel dev pdk num xs ys predict = do
    example <- T.randnIO' [10, num]
    -- The bang forces the traced RawModule to WHNF immediately.
    !raw <- T.trace moduleName "forward" (pure . map predict) [example]
    mapM_ (T.define raw)
        [ method "def inputs(self,x):\n\treturn "  xs
        , method "def outputs(self,x):\n\treturn " ys
        ]
    T.toScriptModule raw
  where
    moduleName = show pdk ++ "_" ++ show dev
    -- TorchScript source for a method whose body returns the given name list.
    method header names = header ++ show names ++ "\n"

-- | Turn a traced 'T.ScriptModule' back into a plain tensor function by
-- calling its @"forward"@ method with a single tensor argument.
--
-- Calls 'error' with a descriptive message if @"forward"@ returns anything
-- other than a tensor.
unTraceModel :: T.ScriptModule -> (T.Tensor -> T.Tensor)
unTraceModel model' x =
    -- runMethod1 returns an untyped IValue; match it explicitly rather than
    -- with a partial pattern binding so a mismatch gives a clear message.
    case T.runMethod1 model' "forward" (T.IVTensor x) of
        T.IVTensor y -> y
        _            -> error "unTraceModel: method \"forward\" did not return a tensor"

-- | Write a traced 'T.ScriptModule' to the file @path@ via 'T.saveScript'.
saveInferenceModel :: FilePath -> T.ScriptModule -> IO ()
saveInferenceModel = flip T.saveScript

-- | Read a traced 'T.ScriptModule' from the file @path@ via 'T.loadScript',
-- using load mode 'T.WithoutRequiredGrad' (presumably so the loaded
-- parameters do not track gradients — see the hasktorch docs).
loadInferenceModel :: FilePath -> IO T.ScriptModule
loadInferenceModel path = T.loadScript T.WithoutRequiredGrad path