{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE RecordWildCards #-}

-- | Compressive Sampling Matched Pursuit (CoSaMP), after "CoSaMP: Iterative signal
-- recovery from incomplete and inaccurate samples" by Deanna Needell & Joel Tropp
module CoSaMP ( cosamp
              , dctOrtho
              , idctOrtho
              ) where

import Numeric.LinearAlgebra
import Numeric.LinearAlgebra.Devel        (ti)
import Prelude                     hiding ((<>), reverse)
import Data.List                          (nub)
import Control.Monad.State
import Statistics.Transform               (dct, idct)

-- | CoSaMP Algorithm State
--
-- Threaded through the 'State' monad by @cosamp'@; one value describes the
-- algorithm between two iterations.
data CoSaMPState = CoSaMPState
    { p :: Double        -- ^ Precision: magnitudes not exceeding this are never selected into a support
    , i :: Int           -- ^ Iteration counter (decremented once per iteration)
    , e :: Double        -- ^ Tolerance on the relative residual @norm v / norm u@
    , φ :: Matrix Double -- ^ Sampling Matrix
    , u :: Vector Double -- ^ Noisy sampling vector
    , a :: Vector Double -- ^ s-sparse approximation of the target signal at the current iteration
    , v :: Vector Double -- ^ Updated sample vector (residual of @u@ after removing the current approximation)
    , s :: Int           -- ^ Sparsity level
    } deriving (Show)

-- | Compressive Sampling Matched Pursuit (CoSaMP)
--
-- Builds the initial algorithm state and runs @cosamp'@ until it halts.
-- The initial approximation is the zero vector with one entry per column of
-- the sampling matrix, the initial residual is the sample vector itself, and
-- the precision floor is fixed at @1.0e-12@.
cosamp :: Matrix Double -- ^ sampling matrix Φ
       -> Vector Double -- ^ noisy sample vector u
       -> Int           -- ^ sparsity level s
       -> Int           -- ^ Number of Iterations (e.g. 1000)
       -> Double        -- ^ Tolerance on the relative residual (e.g. 1.0e-10)
       -> Vector Double -- ^ s-sparse approximation of the target signal
cosamp φ' u' s' i' e' = evalState (cosamp' False) st
  where
    -- Entries whose magnitude does not exceed this are ignored during selection.
    p' = 1.0e-12 :: Double
    -- Zero vector with as many entries as Φ has columns.
    a' = konst (0 :: Double) (snd (size φ'))
    st = CoSaMPState { p = p'
                     , i = i'
                     , e = e'
                     , φ = φ'
                     , u = u'
                     , a = a'
                     , v = u'
                     , s = s' }

-- | CoSaMP Algorithm in the CoSaMPState Monad
--
-- The 'Bool' argument is the halting flag: @True@ returns the current
-- approximation, @False@ performs one more iteration and then recurses with
-- the freshly computed flag.  Iteration stops once the relative residual
-- @norm v / norm u@ drops below the tolerance or the iteration counter
-- reaches zero.  At least one iteration is always performed when entered
-- with @False@.
cosamp' :: Bool -> State CoSaMPState (Vector Double)
cosamp' True  = gets a
cosamp' False = do
    st@CoSaMPState{..} <- get
    let -- Signal proxy: magnitude of the correlation of every column of φ
        -- with the current residual.
        y  = abs $ tr φ #> v
        -- Columns holding the 2s largest proxy entries.
        ω  = largest p (2 * s) y
        -- Merge with the support of the current approximation.
        t  = nub $ ω ++ find (/= 0) a :: [Int]
        -- Least-squares estimate restricted to the merged support.
        b  = pinv (φ ?? (All, Pos (idxs t))) #> u
        -- Positions (within t) of the s largest estimated coefficients.
        j  = idxs $ largest p s (abs b)
        -- Pruned support (column indices of φ) and matching coefficients.
        t' = flatten $ asRow (idxs t) ?? (All, Pos j)
        b' = flatten $ asRow b ?? (All, Pos j)
        -- The new approximation is built from a zero vector so that
        -- coefficients dropped by the pruning step do not linger from the
        -- previous iteration.
        a' = accum (konst (0 :: Double) (size a)) const
                 $ zip (map ti $ toList t') (toList b')
        -- Residual of the samples with respect to the new approximation.
        v' = u - (φ ?? (All, Pos t') #> b')
        i' = i - 1
        -- Halt on the residual just computed and on the updated counter.
        h  = (norm_2 v' / norm_2 u) < e || i' <= 0
    put $ st { i = i', v = v', a = a' }
    cosamp' h
  where
    -- Indices of the (at most) k largest entries of a vector, ignoring
    -- entries that do not exceed the precision floor.  Ties with the
    -- (k+1)-th largest value are excluded, so fewer than k indices may be
    -- returned.  When the vector has no more than k entries there is
    -- nothing to prune and only the precision floor applies.
    largest :: Double -> Int -> Vector Double -> [Int]
    largest tiny k xs
      | k <= 0    = []
      | k >= len  = find (> tiny) xs
      | otherwise = find (\val -> val > cutoff && val > tiny) xs
      where
        len    = size xs
        cutoff = sortVector xs ! (len - 1 - k)

-- | DCT Type II where corresponding matrix coefficients are made orthonormal
--
-- 'dct' is applied to every column of the input.  Row 0 of the result is
-- then scaled by @1 / (2 * sqrt n)@ and every other row by
-- @sqrt (2 / n) / 2@, where @n@ is the number of rows of the input.
dctOrtho :: Matrix Double -> Matrix Double
dctOrtho x = r0 === rs
  where
    n  = fst (size x)
    n' = fromIntegral n :: Double
    -- Column-wise transform, before normalisation.
    t  = fromColumns (map dct (toColumns x))
    -- First row and remaining rows carry different normalisation factors.
    r0 = scale (1.0 / (2.0 * sqrt n')) (t ?? (Pos (idxs [0]), All))
    rs = scale (sqrt (2.0 / n') / 2.0) (t ?? (Pos (idxs [1 .. n - 1]), All))

-- | Inverse DCT Type II where corresponding matrix coefficients are made orthonormal
--
-- Intended as the inverse of 'dctOrtho' on a single coefficient vector.
-- 'dctOrtho' scales the first (DC) coefficient by @1 / (2 * sqrt n)@ but all
-- other coefficients by @sqrt (2 / n) / 2@, i.e. the DC coefficient ends up
-- smaller by a factor of @sqrt 2@ relative to the rest.  That imbalance is
-- undone here by multiplying coefficient 0 by @sqrt 2@ before 'idct' is
-- applied; the uniform factor @0.5 * sqrt (2 / n)@ then restores the overall
-- scale.  An empty input is returned unchanged.
--
-- NOTE(review): relies on 'idct' being the inverse of 'dct' up to a scalar
-- factor, as documented by the statistics package — confirm against the
-- version in use.
idctOrtho :: Vector Double -> Vector Double
idctOrtho y
  | size y == 0 = y
  | otherwise   = scale (0.5 * sqrt (2.0 / n)) (idct y')
  where
    n  = fromIntegral (size y) :: Double
    -- Coefficient vector with the DC term re-weighted by sqrt 2.
    y' = accum y (*) [(0, sqrt 2.0)]