{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Net where
import Lib
import GHC.Generics
import qualified Torch as T
import qualified Torch.Extensions as T
import qualified Torch.NN as NN
-- | Size specification used to sample an 'OpNet'.
data OpNetSpec = OpNetSpec
    { numX :: !Int -- ^ input width of the first layer (@fc0@)
    , numY :: !Int -- ^ output width of the last layer (@fc6@)
    } deriving (Show, Eq)
-- | A fully connected network made of seven 'T.Linear' layers.
--
-- 'forward' applies the layers in order @fc0@ .. @fc6@, with a ReLU after
-- every layer except @fc6@. The 'T.Randomizable' instance for 'OpNetSpec'
-- samples the widths numX -> 32 -> 128 -> 512 -> 128 -> 64 -> 32 -> numY.
-- 'T.Parameterized' is derived through 'Generic' (DeriveAnyClass).
-- NOTE(review): the flattened parameter order presumably follows the field
-- order below; confirm before relying on it in checkpoints.
data OpNet = OpNet
    { fc0 :: !T.Linear -- ^ first layer (input side)
    , fc1 :: !T.Linear
    , fc2 :: !T.Linear
    , fc3 :: !T.Linear
    , fc4 :: !T.Linear
    , fc5 :: !T.Linear
    , fc6 :: !T.Linear -- ^ last layer (output side, no activation)
    } deriving (Generic, Show, T.Parameterized)
-- | Sample a freshly initialised 'OpNet'.
--
-- Layers are drawn one after another, from @fc0@ to @fc6@, with widths
-- @numX -> 32 -> 128 -> 512 -> 128 -> 64 -> 32 -> numY@.
instance T.Randomizable OpNetSpec OpNet where
    sample OpNetSpec { numX = nIn, numY = nOut } = do
        layer0 <- denseLayer nIn 32
        layer1 <- denseLayer 32 128
        layer2 <- denseLayer 128 512
        layer3 <- denseLayer 512 128
        layer4 <- denseLayer 128 64
        layer5 <- denseLayer 64 32
        layer6 <- denseLayer 32 nOut
        pure OpNet
            { fc0 = layer0
            , fc1 = layer1
            , fc2 = layer2
            , fc3 = layer3
            , fc4 = layer4
            , fc5 = layer5
            , fc6 = layer6
            }
      where
        -- Sample one linear layer with the given input/output widths.
        denseLayer :: Int -> Int -> IO T.Linear
        denseLayer inputs outputs = T.sample (T.LinearSpec inputs outputs)
-- | Run the network on an input tensor.
--
-- Layers @fc0@ .. @fc5@ are each followed by a ReLU. The final layer @fc6@
-- has no activation, so the result is its raw linear output.
forward :: OpNet -> T.Tensor -> T.Tensor
forward OpNet{..} =
    T.linear fc6
        . hiddenLayer fc5
        . hiddenLayer fc4
        . hiddenLayer fc3
        . hiddenLayer fc2
        . hiddenLayer fc1
        . hiddenLayer fc0
  where
    -- A linear layer followed by a ReLU activation.
    hiddenLayer layer = T.relu . T.linear layer
-- | Return a copy of @net@ whose parameters do not require gradients.
--
-- Each parameter is read out as a tensor ('T.toDependent'), moved to the CPU
-- ('T.toDevice' 'T.cpu') and re-wrapped with
-- 'T.makeIndependentWithRequiresGrad' using @False@. The results are then put
-- back into @net@'s structure with 'NN.replaceParameters'.
--
-- NOTE(review): besides disabling gradients, this also moves every parameter
-- to the CPU, which the name does not suggest. Confirm callers rely on that.
noGrad :: (NN.Parameterized f) => f -> IO f
noGrad net =
    NN.replaceParameters net <$> traverse freeze (NN.flattenParameters net)
  where
    freeze param =
        T.makeIndependentWithRequiresGrad (T.toDevice T.cpu (T.toDependent param)) False
-- | Write a training checkpoint into the directory @path@.
--
-- Files written:
--
-- * @model.pt@: the network's parameters ('T.saveParams')
-- * @M1.pt@: the optimizer's 'T.m1' tensors
-- * @M2.pt@: the optimizer's 'T.m2' tensors
--
-- The Adam betas and iteration counter are not stored.
-- NOTE(review): @path@ presumably must already exist; nothing here creates it.
saveCheckPoint :: FilePath -> OpNet -> T.Adam -> IO ()
saveCheckPoint path net opt = do
    T.saveParams net (path ++ "/model.pt")
    mapM_ saveMoment [(T.m1 opt, "/M1.pt"), (T.m2 opt, "/M2.pt")]
  where
    saveMoment (tensors, suffix) = T.save tensors (path ++ suffix)
-- | Restore a checkpoint from the directory @path@.
--
-- A network is first sampled from @spec@, then 'T.loadParams' loads its
-- parameters from @model.pt@. The Adam state is rebuilt from the tensors in
-- @M1.pt@ and @M2.pt@, with beta1 = 0.9, beta2 = 0.999 and iteration counter
-- @iter@.
--
-- NOTE(review): the betas are hard-coded here rather than read from disk.
-- They must match the values used when the checkpoint was trained.
loadCheckPoint :: FilePath -> OpNetSpec -> Int -> IO (OpNet, T.Adam)
loadCheckPoint path spec iter = do
    template <- T.sample spec
    net <- T.loadParams template (path ++ "/model.pt")
    opt <- T.Adam 0.9 0.999
               <$> T.load (path ++ "/M1.pt")
               <*> T.load (path ++ "/M2.pt")
               <*> pure iter
    pure (net, opt)
-- | Trace @predict@ into a TorchScript module and attach metadata methods.
--
-- * The module is named @show pdk ++ "_" ++ show dev@.
-- * Tracing calls @predict@ on one random normal example of shape
--   @[10, num]@. @num@ is presumably the model's input width; confirm at
--   call sites.
-- * Two extra TorchScript methods are defined: @inputs@, returning @xs@, and
--   @outputs@, returning @ys@. Both lists are rendered with Haskell 'show'.
--   NOTE(review): 'show' uses Haskell string escapes, so names containing
--   non-ASCII or special characters may not be valid TorchScript literals.
traceModel :: Device -> PDK -> Int -> [String] -> [String]
           -> (T.Tensor -> T.Tensor) -> IO T.ScriptModule
traceModel dev pdk num xs ys predict = do
    example <- T.randnIO' [10, num]
    !rm <- T.trace moduleName "forward" (pure . map predict) [example]
    T.define rm (listMethod "def inputs(self,x):\n\treturn " xs)
    T.define rm (listMethod "def outputs(self,x):\n\treturn " ys)
    T.toScriptModule rm
  where
    moduleName = show pdk ++ "_" ++ show dev
    -- Source of a TorchScript method that returns the given list of names.
    listMethod header names = header ++ show names ++ "\n"
-- | View a 'T.ScriptModule' as a plain tensor function.
--
-- The module's @"forward"@ method is called with @x@ wrapped in
-- 'T.IVTensor'. The result must itself be a single 'T.IVTensor'. Any other
-- result (for example a tuple or a list) is reported with a descriptive
-- 'error', instead of the opaque failure the previous irrefutable pattern
-- binding produced. It also avoids @-Wincomplete-uni-patterns@ under
-- @-Wall@.
unTraceModel :: T.ScriptModule -> (T.Tensor -> T.Tensor)
unTraceModel model' x =
    case T.runMethod1 model' "forward" (T.IVTensor x) of
        T.IVTensor y -> y
        _ -> error "Net.unTraceModel: \"forward\" did not return a single tensor (IVTensor)"
-- | Write a TorchScript module to the file @path@.
-- This is 'T.saveScript' with its arguments flipped: path first, then module.
saveInferenceModel :: FilePath -> T.ScriptModule -> IO ()
saveInferenceModel = flip T.saveScript
-- | Load a TorchScript module from the file @path@, using the
-- 'T.WithoutRequiredGrad' load mode.
loadInferenceModel :: FilePath -> IO T.ScriptModule
loadInferenceModel path = T.loadScript T.WithoutRequiredGrad path