-- Concurrent affine MCMC
--------------------------------------------------------------------------------
-- Affine Invariant MCMC
-- (Concurrent implementation via software transactional memory)
--
-- Jared Tobin, 2012
-- The University of Auckland
-- License: BSD
--------------------------------------------------------------------------------

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception (throwIO)
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as V

import Data.IntMap.Strict (IntMap, Key)
import qualified Data.IntMap.Strict as IntMap

import System.Random.MWC
import Data.Word

import Control.Monad
import Control.Monad.Primitive

import Data.Function (fix)
import Data.Maybe (fromJust)

import Data.List.Split

import System.Environment
import System.Exit

import Text.Read (readMaybe)

--------------------------------------------------------------------------------
-- Utilities
--------------------------------------------------------------------------------

-- | Draw an ordered pair of /distinct/ integers, each uniform on @[1, bound]@.
--
--   The first element is drawn directly.  The second is redrawn until it
--   differs from the first (rejection sampling).  When @bound < 2@ no distinct
--   pair exists and the rejection loop could never terminate, so this fails
--   with 'error' instead of spinning forever.
pairGen :: PrimMonad m => Gen (PrimState m) -> Int -> m (Int, Int)
pairGen gen bound
    | bound < 2 =
        error $ "pairGen: need bound >= 2 to draw a distinct pair, got " ++ show bound
    | otherwise = do
        a <- uniformR (1, bound) gen
        fix $ \redraw -> do
            b <- uniformR (1, bound) gen
            if   a == b
            then redraw
            else return (a, b)
{-# INLINE pairGen #-}

-- | Fork one thread per element of @range@.  Each thread atomically replaces
--   the contents of @sourceTVar@ with @operation wid contents@, where @wid@ is
--   its element of @range@.  The caller blocks until every thread has finished.
--   If a thread's transaction throws, that exception is re-raised in the
--   caller.  Previously the caller was left waiting on an 'MVar' the dead
--   thread would never fill.  This is only used to initialize some TVars.
forkAndWait :: TVar a -> (Int -> a -> a) -> [Int] -> IO ()
forkAndWait sourceTVar operation range = mapM fork range >>= mapM_ wait
  where
    -- Start one worker; its MVar receives the outcome (a value or an exception).
    fork wid = do
        done <- newEmptyMVar
        _ <- forkFinally (atomically $ modifyTVar' sourceTVar (operation wid))
                         (putMVar done)
        return done

    -- Block until the worker finishes, re-raising any exception it hit.
    wait done = takeMVar done >>= either throwIO return


{-# INLINE forkAndWait #-}

--------------------------------------------------------------------------------
-- State manipulation 
--------------------------------------------------------------------------------

-- | Perform one Goodman-Weare \"stretch move\" of affine invariant MCMC.
--
--   Two distinct walker ids @a@ and @b@ are drawn from @[1, nw]@ (via
--   'pairGen').  A stretch factor @z = 0.5 * (u + 1)^2@ and a uniform
--   acceptance draw are also taken.  Then, inside one STM transaction, walker
--   @a@ is proposed to move along the line through walker @b@:
--
--   > proposal = z * w_a + (1 - z) * w_b
--
--   The move is accepted when the uniform draw is @<= min 1 (z^(d-1) *
--   prior*likelihood(proposal) / prior*likelihood(w_a))@, computed in log
--   space.  Here @d@ is the walker's vector length.  On acceptance, walker
--   @a@'s position, log-prior and log-likelihood TVars are overwritten and the
--   acceptance counter is incremented.  On rejection nothing is written.
walk :: TVar (IntMap (TVar (Vector Double)))  -- ^ walker positions (main keys them 1..nw)
     -> Int                                   -- ^ number of walkers, nw
     -> TVar (IntMap (TVar Double))           -- ^ cached log prior per walker
     -> TVar (IntMap (TVar Double))           -- ^ cached log likelihood per walker
     -> TVar Int                              -- ^ acceptance counter
     -> (Vector Double -> Double)             -- ^ log prior
     -> (Vector Double -> Double)             -- ^ log likelihood
     -> Gen RealWorld
     -> IO ()
walk conf nw lptv lltv actv logPrior logLikelihood gen = do
    (a, b) <- pairGen gen nw
    z0     <- uniformR (0 :: Double, 1) gen
    z1     <- uniformR (0 :: Double, 1) gen
    -- Inverse-CDF sample of g(z) ~ 1/sqrt(z) on [1/2, 2] (stretch parameter 2).
    let z = 0.5 * (z0 + 1) ^ (2 :: Int)

    atomically $ do
        config <- readTVar conf
        logp   <- readTVar lptv
        logl   <- readTVar lltv

        let posVar = walkerVar a config
            lpVar  = walkerVar a logp
            llVar  = walkerVar a logl

        w1 <- readTVar posVar
        w2 <- readTVar (walkerVar b config)
        lp <- readTVar lpVar
        ll <- readTVar llVar

        let proposal = V.zipWith (\x y -> x * z + y * (1 - z)) w1 w2
            lpProp   = logPrior      proposal
            llProp   = logLikelihood proposal
            -- The Jacobian factor is z^(d-1) with d the dimension of the
            -- parameter space (Goodman & Weare 2010, eq. 7), i.e. the length
            -- of a walker.  It is not the number of walkers.
            dim      = fromIntegral (V.length w1) :: Double
            logRatio = lpProp + llProp - lp - ll + (dim - 1) * log z
            logA     = if logRatio > 0 then 0 else logRatio
            accepted = z1 <= exp logA

        -- The shared maps are only read here.  Only walker a's own TVars (and
        -- the counter) are written, and only on acceptance, so transactions
        -- on disjoint walkers do not invalidate each other through the maps.
        when accepted $ do
            writeTVar posVar $! proposal
            writeTVar lpVar  $! lpProp
            writeTVar llVar  $! llProp
            modifyTVar' actv (+ 1)
  where
    -- Look up a walker's TVar.  A miss means the maps were not populated for
    -- this id, which is a programming error rather than bad input.
    walkerVar :: Key -> IntMap v -> v
    walkerVar k m =
        maybe (error ("walk: no walker with id " ++ show k)) id (IntMap.lookup k m)
{-# INLINE walk #-}


--------------------------------------------------------------------------------
-- Main 
--------------------------------------------------------------------------------

-- | Entry point.  Usage:
--
--   > concurrentAffineMCMC numWalkers numDimensions numEpochs seed
--
--   Initializes @numWalkers@ walkers with @numDimensions@ coordinates drawn
--   with @uniformR (0, 1)@ from a generator seeded by @seed@.  It then runs
--   @numEpochs@ calls of 'walk' in total, shared across one worker thread per
--   capability, and prints the number of accepted proposals.
main :: IO ()
main = do
    conf    <- newTVarIO IntMap.empty
    logP    <- newTVarIO IntMap.empty
    logL    <- newTVarIO IntMap.empty
    accepts <- newTVarIO 0

    -- Parse the four positional arguments; trailing extras are ignored.
    let usage = "Concurrent affine invariant MCMC 0.1 \nUsage: concurrentAffineMCMC numWalkers numDimensions numEpochs seed"

        parseArgs :: [String] -> Maybe (Int, Int, Int, Word32)
        parseArgs (w : d : e : s : _) =
            (,,,) <$> readMaybe w <*> readMaybe d <*> readMaybe e <*> readMaybe s
        parseArgs _ = Nothing

    args <- getArgs
    when (null args) $ putStrLn usage >> exitSuccess

    (nw, nd, epochs, seed) <-
        maybe (putStrLn usage >> exitFailure) return (parseArgs args)

    -- pairGen needs two distinct walkers, and a walker needs a coordinate.
    when (nw < 2) $ putStrLn "numWalkers must be at least 2" >> exitFailure
    when (nd < 1) $ putStrLn "numDimensions must be at least 1" >> exitFailure

    let logPrior, logLikelihood :: Vector Double -> Double
        logPrior xs   = if   V.any (<0) xs || V.any (>1) xs then - 1 / 0 else 0
        logLikelihood = (* (-0.5)) . V.sum . V.map ((^2) . (/0.01) . (+ (-0.5)))
        {-# INLINE logPrior #-}
        {-# INLINE logLikelihood #-}

    -- Generate random doubles to fill the initial data structures.  Walkers are
    -- drawn one after another, coordinate by coordinate, from the master
    -- generator.
    gen         <- initialize (V.singleton seed)
    walkerInits <- replicateM nw (V.replicateM nd (uniformR (0 :: Double, 1) gen))

    let lpInits = map logPrior walkerInits
        llInits = map logLikelihood walkerInits
    wInitTs  <- atomically $ mapM newTVar walkerInits
    lpInitTs <- atomically $ mapM newTVar lpInits
    llInitTs <- atomically $ mapM newTVar llInits

    forkAndWait conf (\wid config -> IntMap.insert wid (wInitTs  !! (wid - 1)) config) [1..nw]
    forkAndWait logP (\wid lp     -> IntMap.insert wid (lpInitTs !! (wid - 1)) lp)     [1..nw]
    forkAndWait logL (\wid ll     -> IntMap.insert wid (llInitTs !! (wid - 1)) ll)     [1..nw]

    -- An MWC 'Gen' is mutable and unsynchronised, so sharing one between
    -- threads is a data race.  Give every worker its own generator, seeded
    -- from the master one.
    nThreads   <- getNumCapabilities
    threadGens <- replicateM nThreads $
        initialize =<< (V.replicateM 256 (uniform gen) :: IO (Vector Word32))

    remaining <- newTVarIO epochs

    -- claimEpoch reads and decrements the counter in ONE transaction, so two
    -- threads can never claim the same epoch (and no more than numEpochs walks
    -- run in total).
    let claimEpoch :: STM Bool
        claimEpoch = do
            left <- readTVar remaining
            if left > 0
                then writeTVar remaining (left - 1) >> return True
                else return False

        worker :: Gen RealWorld -> IO ()
        worker threadGen = fix $ \loop -> do
            claimed <- atomically claimEpoch
            when claimed $ do
                walk conf nw logP logL accepts logPrior logLikelihood threadGen
                loop

    -- Fork one worker per capability; each MVar receives the worker's outcome.
    dones <- forM threadGens $ \threadGen -> do
        done <- newEmptyMVar
        _ <- forkFinally (worker threadGen) (putMVar done)
        return done

    -- Wait for every worker, re-raising the first failure rather than blocking
    -- forever on an MVar that a crashed worker would never fill.
    mapM_ (takeMVar >=> either throwIO return) dones

    -- FIXME more appropriate output
    nacc <- atomically $ readTVar accepts
    print nacc