Affine Invariant MCMC

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import Data.IntMap.Strict              (IntMap)
import Data.List                       (delete)
import Data.Maybe                      (fromJust)
import Data.Sequence                   (Seq, index)
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Sequence      as Seq

import Control.Monad
import Control.Monad.State

import System.Random

-- | Complete sampler state threaded through 'simKernel' by the 'State' monad.
--
-- Only the two 'Int' counts are strict fields; every other field may hold a
-- thunk. 'IntMap' and 'Seq' are themselves spine-strict containers.
data SimConfig = SimConfig
    { numDimensions :: !Int
      -- ^ length of each walker's position vector
    , numWalkers    :: !Int
      -- ^ number of walkers in the ensemble
    , simArray      :: IntMap [Double]
      -- ^ walker positions, keyed by walker index
    , logP          :: Seq Double
      -- ^ cached log-prior of each walker, at 0-based position (key - 1)
    , logL          :: Seq Double
      -- ^ cached log-likelihood of each walker, at 0-based position (key - 1)
    , pairStream    :: [(Int, Int)]
      -- ^ remaining (walker, partner) index pairs; expected to be effectively infinite
    , doubleStream  :: [Double]
      -- ^ remaining random draws; expected to be effectively infinite
    } deriving (Show)

-- | Cut a stream into consecutive, non-overlapping chunks and number them.
--
-- @streamToAssocList nrec recsize xs@ pairs each element @i@ of
-- @[1 .. nrec]@ with the next @recsize@ elements of @xs@. Once @xs@ runs out
-- the remaining chunks are shorter, possibly empty. @xs@ may be infinite;
-- only as much of it as the chunks need is consumed.
--
-- >>> streamToAssocList (3 :: Int) 2 [1 .. 5 :: Int]
-- [(1,[1,2]),(2,[3,4]),(3,[5])]
streamToAssocList :: (Enum a, Num a, Ord a) => a -> Int -> [b] -> [(a, [b])]
streamToAssocList nrec recsize xs = zip [1 .. nrec] (chunks xs)
    where
        -- Walk the stream once with 'splitAt' rather than re-'drop'ping from
        -- the head for every record (which made the total work quadratic in
        -- the number of records). The chunk list is infinite; 'zip' stops it
        -- after one chunk per index.
        chunks ys = let (chunk, rest) = splitAt recsize ys
                    in  chunk : chunks rest

-- | Random stream of ordered index pairs @(a, b)@ with @1 <= a, b <= n@ and
-- @a /= b@, drawn independently and uniformly with equal-index pairs
-- rejected. For @n >= 2@ the stream is infinite.
--
-- Returns @[]@ when @n < 2@. No two distinct indices exist in that case:
-- rejection would search forever for @n == 1@, and for @n <= 0@ the range
-- @(1, n)@ is reversed, which 'randomRs' does not specify.
consPairStream :: RandomGen g => Int -> g -> [(Int, Int)]
consPairStream n gen
    | n < 2     = []
    | otherwise = filter (uncurry (/=)) $ zip firsts seconds
    where
        -- Draw only from the two halves of one split. 'split' promises
        -- independence between its two results, not between @gen@ and a
        -- generator split off it.
        (genA, genB) = split gen
        firsts       = randomRs (1, n) genA
        seconds      = randomRs (1, n) genB

-- | Log prior, uniform on the unit hypercube: negative infinity if any
-- coordinate lies below 0 or above 1, and 0 otherwise.
logPrior :: [Double] -> Double
logPrior xs
    | belowZero || aboveOne = negate (1 / 0)
    | otherwise             = 0
    where
        belowZero = any (< 0) xs
        aboveOne  = any (> 1) xs

-- | Log-likelihood, up to an additive constant, of an isotropic Gaussian
-- with mean 0.5 and standard deviation 0.01 in every coordinate:
-- @-0.5 * sum (((x - 0.5) / 0.01)^2)@.
logLikelihood :: [Double] -> Double
logLikelihood xs = sum [ ((x - 0.5) / 0.01) ^ (2 :: Int) | x <- xs ] * (-0.5)


-- | One serial Goodman–Weare "stretch move".
--
-- Consumes one pair @(a, b)@ from 'pairStream' and two uniforms @u1@, @u2@
-- from 'doubleStream'. Walker @a@ is proposed a move along the line through
-- walker @b@:
--
-- > z = (u1 + 1)^2 / 2          -- in [1/2, 2] when u1 is in [0, 1]
-- > Y = z * X_a + (1 - z) * X_b
--
-- The move is accepted when @u2 <= min 1 (z^(d-1) * p(Y) / p(X_a))@, where
-- @d@ is 'numDimensions' and @log p = logPrior + logLikelihood@. A proposal
-- with zero density is never accepted. On acceptance, walker @a@ and its
-- cached 'logP' / 'logL' entries (0-based position @a - 1@) are replaced.
-- Both streams are advanced either way.
--
-- Calls 'error' if a stream is exhausted or a walker key is missing from
-- 'simArray'.
simKernel :: State SimConfig ()
simKernel = do
    config <- get
    let arr = simArray      config
        d   = numDimensions config
        lp  = logP          config
        ll  = logL          config
        walker k = maybe (error ("simKernel: no walker with key " ++ show k)) id
                         (IntMap.lookup k arr)
    case (pairStream config, doubleStream config) of
        ((a, b) : pairs, u1 : u2 : us) -> do
            -- z = (u1 + 1)^2 / 2 is the inverse-CDF sample of g(z) ~ 1/sqrt z
            -- on [1/2, 2].
            let z        = 0.5 * (u1 + 1) ^ (2 :: Int)
                proposal = zipWith (\xa xb -> xa * z + xb * (1 - z))
                                   (walker a) (walker b)
                lpProp   = logPrior proposal
                llProp   = logLikelihood proposal
                -- The Jacobian factor is z^(d-1), where d is the dimension of
                -- the parameter space, not the number of walkers.
                logRatio = lpProp + llProp - index lp (a - 1) - index ll (a - 1)
                         + fromIntegral (d - 1) * log z
                logA     = if logRatio > 0 then 0 else logRatio
                -- Reject a -Infinity log-ratio outright: with the original
                -- 'u2 <= exp logA' test, u2 == 0 would accept an out-of-support
                -- point and turn later ratios into NaN/Infinity.
                accepted = not (isInfinite logA) && u2 <= exp logA
                arr' | accepted  = IntMap.insert a proposal arr
                     | otherwise = arr
                lp'  | accepted  = Seq.update (a - 1) lpProp lp
                     | otherwise = lp
                ll'  | accepted  = Seq.update (a - 1) llProp ll
                     | otherwise = ll
            -- Force the acceptance decision and the new containers before
            -- storing them. Otherwise every step leaves a chain of 'if'
            -- thunks in the state, and those chains are only collapsed when
            -- the final state is printed.
            arr' `seq` lp' `seq` ll' `seq`
                put config { simArray     = arr'
                           , logP         = lp'
                           , logL         = ll'
                           , pairStream   = pairs
                           , doubleStream = us
                           }
        _ -> error "simKernel: pairStream or doubleStream exhausted"


-- | Run 100000 stretch-move steps on 5 walkers in 4 dimensions, then print
-- the final cached log-likelihood of every walker.
--
-- The global 'StdGen' is seeded with 42, so runs are reproducible.
--
-- * The first generator supplies the initial walker positions and, after
--   them, the 'doubleStream' draws.
-- * The walker-pair stream uses the global generator as it stands after
--   'newStdGen' has split it.
main :: IO ()
main = do
    let (nw, nd) = (5, 4)
        nSteps   = 100000 :: Int

    setStdGen (mkStdGen 42)
    gen <- getStdGen
    let randomList = randomRs (0 :: Double, 1 :: Double) gen

    _ <- newStdGen
    pairGen <- getStdGen
    let pairList = consPairStream nw pairGen

    -- The keys are exactly 1 .. nw, so 'IntMap.elems' yields the walkers in
    -- key order, with no partial lookups.
    let arr     = IntMap.fromList $ streamToAssocList nw nd randomList
        walkers = IntMap.elems arr

    let initConfig = SimConfig { numWalkers    = nw
                               , numDimensions = nd
                               , simArray      = arr
                               , logL          = Seq.fromList (map logLikelihood walkers)
                               , logP          = Seq.fromList (map logPrior walkers)
                               , pairStream    = pairList
                               , doubleStream  = drop (nw * nd) randomList
                               }

    -- Drive the chain with a strict loop, forcing each step's state before
    -- the next. 'replicateM' in the lazy 'State' monad instead built a list
    -- of 100000 results and deferred every step until 'print'.
    let run k cfg
            | k <= 0    = cfg
            | otherwise = run (k - 1) $! execState simKernel cfg
        sim = logL (run nSteps initConfig)

    print sim