-- |
-- Module      : System.Random.Shuffle
-- Copyright   : (c) 2009 Oleg Kiselyov, Manlio Perillo
-- License     : BSD3 (see LICENSE file)
--
-- http://okmij.org/ftp/Haskell/perfect-shuffle.txt
--
{-# OPTIONS_GHC -funbox-strict-fields #-}

module System.Random.Shuffle
    (
     shuffle
    , shuffle'
    , shuffleM
    ) where

import Data.Function (fix)
import System.Random (RandomGen, randomR)
import Control.Monad (liftM,liftM2)
import Control.Monad.Random (MonadRandom, getRandomR)


-- | A complete binary tree holding the elements being shuffled.
--
-- * @Leaf x@ stores a single element @x@.
--
-- * @Node card l r@ is an internal node; @card@ is the number of
--   leaves found under it.
--
-- Invariant: @card >= 2@, and every internal node is full (both
-- subtrees present). All constructor fields are strict.
data Tree a
  = Leaf !a
  | Node !Int !(Tree a) !(Tree a)
  deriving (Show)


-- Convert a non-empty sequence (e1...en) to a complete binary tree
-- whose leaves are e1...en, left to right.
--
-- Adjacent trees are joined pairwise, one level per round, until a
-- single tree remains; an odd tree at the end of a level is carried
-- up unchanged.
--
-- A tree cannot have zero leaves, so an empty input is rejected with
-- 'error'. Without this check the pairing loop would recurse forever
-- on an empty level.
buildTree :: [a] -> Tree a
buildTree [] = error "[buildTree] called with an empty list"
buildTree elements = fix collapse (map Leaf elements)
  where
    -- Stop once a single tree is left; otherwise pair up this level
    -- and recurse (via 'fix') on the next, shorter level.
    collapse :: ([Tree b] -> Tree b) -> [Tree b] -> Tree b
    collapse _ [tree] = tree
    collapse self trees = self (pairUp trees)

    -- Join neighbours (t1,t2), (t3,t4), ... of one level. Both
    -- operands are forced to WHNF as soon as the resulting list cell
    -- is demanded.
    pairUp :: [Tree b] -> [Tree b]
    pairUp (t1 : t2 : ts) = t1 `seq` t2 `seq` join t1 t2 : pairUp ts
    pairUp ts = ts

    -- An internal node records the total leaf count of its subtrees.
    join :: Tree b -> Tree b -> Tree b
    join l r = Node (size l + size r) l r

    -- Number of leaves under a tree.
    size :: Tree b -> Int
    size (Leaf _) = 1
    size (Node c _ _) = c


-- |Given a sequence (e1,...en) to shuffle, and a sequence
-- (r1,...r[n-1]) of numbers such that r[i] is an independent sample
-- from a uniform random distribution [0..n-i], compute the
-- corresponding permutation of the input sequence.
--
-- An empty input sequence yields @[]@; the random sequence is not
-- inspected in that case. If the two sequences have inconsistent
-- lengths, 'error' is called once the mismatch is reached.
shuffle :: [a] -> [Int] -> [a]
shuffle [] = const []
shuffle elements = shuffleTree (buildTree elements)
  where
    -- Emit the r-th remaining element for each r, removing it from
    -- the tree; the final single leaf is emitted when the random
    -- sequence runs out.
    shuffleTree :: Tree b -> [Int] -> [b]
    shuffleTree (Leaf e) [] = [e]
    shuffleTree tree (r : rs) =
      let (picked, rest) = extractTree r tree
      in picked : shuffleTree rest rs
    shuffleTree _ _ = error "[shuffle] called with lists of different lengths"

    -- Extracts the n-th (0-based) element from the tree and returns
    -- that element, paired with a tree with the element deleted.
    -- The function maintains the invariant of the completeness of the
    -- tree: all internal nodes are always full. A node left with a
    -- single child is replaced by that child, and the leaf count of
    -- every node on the path drops by one.
    extractTree :: Int -> Tree b -> (b, Tree b)
    -- Index 0 with a leaf on the left: that leaf is the answer.
    extractTree 0 (Node _ (Leaf e) rest) = (e, rest)
    -- Second element of a two-leaf node.
    extractTree 1 (Node 2 (Leaf x) (Leaf y)) = (y, Leaf x)
    -- Leaf on the left but index > 0: skip it and search the right.
    extractTree n (Node c (Leaf x) right) =
      let (e, right') = extractTree (n - 1) right
      in (e, Node (c - 1) (Leaf x) right')
    -- Last index with a leaf on the right: that leaf is the answer.
    -- If the guard fails, matching falls through to the next equation.
    extractTree n (Node c left (Leaf e))
      | n + 1 == c = (e, left)
    -- Internal left subtree: descend into whichever side holds index n.
    extractTree n (Node c left@(Node cl _ _) right)
      | n < cl =
          let (e, left') = extractTree n left
          in (e, Node (c - 1) left' right)
      | otherwise =
          let (e, right') = extractTree (n - cl) right
          in (e, Node (c - 1) left right')
    extractTree _ _ = error "[extractTree] impossible"

-- |Given a sequence (e1,...en) to shuffle, its length, and a random
-- generator, compute the corresponding permutation of the input
-- sequence.
--
-- @len@ must be the length of the sequence. Otherwise the number of
-- random draws does not match and 'shuffle' calls 'error' once the
-- mismatch is reached. An empty sequence yields @[]@ regardless of
-- the other arguments.
shuffle' :: RandomGen gen => [a] -> Int -> gen -> [a]
shuffle' [] _ = const []
shuffle' elements len = shuffle elements . rseq len
  where
    -- The sequence (r1,...r[n-1]) of numbers such that r[i] is an
    -- independent sample from a uniform random distribution
    -- [0..n-i]. Each draw uses the generator returned by the
    -- previous one.
    rseq :: RandomGen g => Int -> g -> [Int]
    rseq n = draws (n - 1)

    -- Draw from [0..i], [0..i-1], ..., [0..1]. A non-positive bound
    -- ends the list, so n <= 1 gives [] instead of an endless list.
    draws :: RandomGen g => Int -> g -> [Int]
    draws i gen
      | i <= 0 = []
      | otherwise =
          let (j, gen') = randomR (0, i) gen
          in j : draws (i - 1) gen'

-- |shuffle' wrapped in a random monad: draws the random indices with
-- 'getRandomR' (from @[0..n-1]@ down to @[0..1]@) and feeds them to
-- 'shuffle'. An empty input returns @[]@ without drawing anything.
shuffleM :: (MonadRandom m) => [a] -> m [a]
shuffleM elements
  | null elements = return []
  | otherwise = do
      indices <- mapM drawUpTo [lastIndex, lastIndex - 1 .. 1]
      return (shuffle elements indices)
  where
    lastIndex = length elements - 1

    -- One uniform sample from [0..i].
    drawUpTo :: (MonadRandom n) => Int -> n Int
    drawUpTo i = getRandomR (0, i)