{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Convenience extensions and syntactic sugar for the @Torch@ bindings
module Torch.Extensions where

import           GHC.Float                       (float2Double)
import qualified Torch                     as T
import qualified Torch.Functional.Internal as T (nan_to_num, powScalar', mse_loss, meanAll)

------------------------------------------------------------------------------
-- Convenience / Syntactic Sugar
------------------------------------------------------------------------------

-- | CUDA device handle with device index @1@.
--
-- NOTE(review): the index is 1, not 0 — confirm that this particular
-- device is the one intended on the target machine.
gpu :: T.Device
gpu = T.Device T.CUDA 1

-- | CPU device handle (device index @0@).
cpu :: T.Device
cpu = T.Device T.CPU 0

-- | The inverse of @log10@: applies 'T.powScalar'' with the scalar @10.0@
-- to the given tensor.
pow10 :: T.Tensor -> T.Tensor
pow10 = T.powScalar' 10.0

-- | camelCase wrapper around 'T.nan_to_num' that takes 'Float' replacement
-- values (this project works with 'Float' rather than 'Double').
--
-- Arguments, in order: the replacement for NaN, the replacement for
-- positive infinity, the replacement for negative infinity, and the tensor
-- to process. Each 'Float' is widened with 'float2Double' because
-- 'T.nan_to_num' expects 'Double's.
nanToNum :: Float -> Float -> Float -> T.Tensor -> T.Tensor
nanToNum nan' posinf' neginf' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan    = float2Double nan'
    posinf = float2Double posinf'
    neginf = float2Double neginf'

-- | 'T.nan_to_num' with default limits: NaN is replaced by @0.0@,
-- positive infinity by @2.0e32@ and negative infinity by @-2.0e32@.
--
-- The infinity limits are written as 'Float' literals and then widened
-- with 'float2Double', so the 'Double's handed to 'T.nan_to_num' are the
-- 'Float'-rounded values of @2.0e32@ / @-2.0e32@.
nanToNum' :: T.Tensor -> T.Tensor
nanToNum' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan    = 0.0
    posinf = float2Double (2.0e32 :: Float)
    neginf = float2Double (-2.0e32 :: Float)

-- | 'T.nan_to_num' with every replacement value set to @0.0@: NaN,
-- positive infinity and negative infinity are all replaced by zero.
nanToNum'' :: T.Tensor -> T.Tensor
nanToNum'' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan    = 0.0
    posinf = 0.0
    neginf = 0.0

-- | Mean squared error with an explicit reduction mode.
--
-- Dispatches to 'T.mse_loss', passing the reduction as the integer code
-- it expects: @0@ for 'T.ReduceNone', @1@ for 'T.ReduceMean' and @2@ for
-- 'T.ReduceSum'.
--
-- NOTE(review): the codes presumably mirror the reduction enum of the
-- underlying torch library — confirm against the bindings in use.
mseLoss' :: T.Reduction -> T.Tensor -> T.Tensor -> T.Tensor
mseLoss' T.ReduceNone x y = T.mse_loss x y 0
mseLoss' T.ReduceMean x y = T.mse_loss x y 1
mseLoss' T.ReduceSum  x y = T.mse_loss x y 2

-- | Mean over all dimensions.
--
-- Calls 'T.meanAll' with the input tensor's own dtype ('T.dtype'), so no
-- separate dtype argument has to be supplied by the caller.
meanAll :: T.Tensor -> T.Tensor
meanAll x = T.meanAll x (T.dtype x)