{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Torch.Extensions where
import GHC.Float (float2Double)
import qualified Torch as T
import qualified Torch.Functional.Internal as T (nan_to_num, powScalar', mse_loss, meanAll)
-- | The CUDA device with index 1.
--
-- NOTE(review): CUDA device indices are 0-based, so this selects the
-- /second/ GPU; on a single-GPU machine that device likely does not exist.
-- Confirm index 1 is intended ('cpu' uses index 0).
gpu :: T.Device
gpu = T.Device T.CUDA 1
-- | The CPU device (device index 0).
cpu :: T.Device
cpu = T.Device T.CPU 0
-- | Elementwise power with base 10.
--
-- Passes the scalar @10.0@ as the first argument of 'T.powScalar'', i.e.
-- presumably computes @10 ** x@ for every element @x@ of the input
-- (scalar-base @pow@ overload) — confirm against the binding.
pow10 :: T.Tensor -> T.Tensor
pow10 = T.powScalar' 10.0
-- | Replace NaN and infinite values in a tensor.
--
-- Thin wrapper over 'T.nan_to_num' that accepts the replacement values as
-- 'Float' and widens them to 'Double', which is what the binding expects.
nanToNum
  :: Float     -- ^ replacement for NaN
  -> Float     -- ^ replacement for positive infinity
  -> Float     -- ^ replacement for negative infinity
  -> T.Tensor  -- ^ input tensor
  -> T.Tensor
nanToNum nan' posinf' neginf' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan = float2Double nan'
    posinf = float2Double posinf'
    neginf = float2Double neginf'
-- | Replace NaN with 0 and infinities with @±2.0e32@.
--
-- The bounds are written as 'Float' literals and then widened with
-- 'float2Double'. The values actually passed are therefore the nearest
-- single-precision numbers to ±2.0e32, not the exact 'Double' literals.
-- NOTE(review): the magnitude was presumably chosen to stay well inside the
-- float32 range (~3.4e38) — confirm.
nanToNum' :: T.Tensor -> T.Tensor
nanToNum' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan = 0.0
    posinf = float2Double (2.0e32 :: Float)
    neginf = float2Double (-2.0e32 :: Float)
-- | Replace every NaN, positive infinity and negative infinity with 0.
nanToNum'' :: T.Tensor -> T.Tensor
nanToNum'' self = T.nan_to_num self nan posinf neginf
  where
    nan, posinf, neginf :: Double
    nan = 0.0
    posinf = 0.0
    neginf = 0.0
-- | Mean squared error between two tensors, with an explicit reduction.
--
-- The first tensor is passed to 'T.mse_loss' as its first argument and the
-- second tensor as its second. 'T.mse_loss' takes the reduction as an integer
-- code (none = 0, mean = 1, sum = 2). This mapping presumably mirrors ATen's
-- @Reduction@ enum.
mseLoss' :: T.Reduction -> T.Tensor -> T.Tensor -> T.Tensor
mseLoss' reduction x y = T.mse_loss x y (reductionCode reduction)
  where
    reductionCode :: T.Reduction -> Int
    reductionCode T.ReduceNone = 0
    reductionCode T.ReduceMean = 1
    reductionCode T.ReduceSum = 2
-- | Mean over all elements of a tensor.
--
-- Requests the input tensor's own dtype for the result via 'T.dtype'.
meanAll :: T.Tensor -> T.Tensor
meanAll x = T.meanAll x (T.dtype x)