"""
    LavaWorld

Gridworld MDP with `GWPos` states and `Symbol` actions, wrapping a `SimpleGridWorld`
and adding lava cells. Constructed through the `@with_kw` keyword constructor; every
field except `gridworld` has a default.
"""
@with_kw struct LavaWorld <: MDP{GWPos, Symbol}
    # Underlying grid dynamics (size, transition probability, discount).
    gridworld::SimpleGridWorld
    # Lava cells as a list of positions, or a Symbol such as `:random`
    # (presumably resolved to concrete cells elsewhere — confirm).
    lava::Union{Symbol, Array{GWPos}} = [GWPos(1,1)]
    # NOTE(review): meaning of `numtiles` / `lavasquares` is not shown here — TODO document.
    numtiles = 1
    lavasquares = :all
    # Goal cell; must be a value, hence the constructor call `GWPos(7,5)`.
    goal = GWPos(7,5)
    # Presumably: re-randomize lava on every reset — confirm.
    randomlava_onreset = false
    # Presumably the lava penalty and goal reward — confirm against the reward function.
    penalty = -1.
    reward = 1
    # NOTE(review): accepted values of `observationtype` are not shown here — TODO document.
    observationtype = :all
    # Random number generator used by this MDP (typed field, double colon).
    rng::AbstractRNG = Random.GLOBAL_RNG
    initstate = :random
end
"""
    LavaWorldMDP(; size=(7,5), tprob=1.0, discount=0.95, rewards=nothing, lavatiles=1, lava, goal, randlava_onreset, ...)

Keyword constructor that builds a `LavaWorld` around a `SimpleGridWorld` created from
`size`, `tprob` and `discount`. It calls `randomize_lava!` on the result when
`lava == :random`, then returns the constructed `LavaWorld`.

NOTE(review): parts of the signature and of the `SimpleGridWorld`/`LavaWorld` calls are
elided in this view. The following observations still need to be confirmed:
- `lava`, `goal` and `randlava_onreset` have no defaults, so they are required keywords.
- `discount` is passed positionally after keyword arguments; presumably
  `discount = discount` was intended.
- `rewards`, `lavatiles`, `lava`, `goal` and `randlava_onreset` are not visibly forwarded
  to `LavaWorld`, whose corresponding fields are named e.g. `numtiles`, `goal`,
  `randomlava_onreset`.
"""
function LavaWorldMDP(;size = (7,5), tprob = 1.0, discount = 0.95, rewards = nothing, lavatiles = 1, lava, goal, randlava_onreset,....,)
mdp = LavaWorld(gridworld = SimpleGridWorld(size = size, tprob = tprob, discount, ...))
# Only the Symbol `:random` triggers in-place lava randomization.
lava == :random && randomize_lava!(mdp)
mdp
end
"""
    POMDPs.gen(mdp::LavaWorld, s, a, rng::AbstractRNG)

Generative step for `LavaWorld`. Samples the next state from the wrapped gridworld's
transition distribution for `(s, a)` and returns it as `(sp = ..., r = ...)` together
with the gridworld's reward for `(s, a)`.

Dispatching on `LavaWorld` keeps this method from claiming `POMDPs.gen` for every
argument type (type piracy).

NOTE(review): the reward is taken from `mdp.gridworld`, not from `mdp`, so LavaWorld's
own `penalty`/`reward` fields are not consulted here. Confirm this is intended.
"""
function POMDPs.gen(mdp::LavaWorld, s, a, rng::AbstractRNG)
    sp = rand(rng, transition(mdp.gridworld, s, a))
    r = POMDPs.reward(mdp.gridworld, s, a)
    return (sp = sp, r = r)
end
"""
    random_lava(mdp)

NOTE(review): the body is elided in this view. By its name it presumably produces a
random lava layout for `mdp` without mutating it (no `!` suffix). Confirm.
"""
function random_lava(mdp)
...
end
"""
    randomize_lava!(mdp)

Called by `LavaWorldMDP` when its `lava` keyword is `:random`. The body is elided in
this view; the `!` suffix marks it as mutating `mdp`.

NOTE(review): `LavaWorld` is declared as an immutable `struct`, so this can only change
the contents of mutable field values (e.g. a lava array), not rebind fields. Confirm
against the implementation.
"""
function randomize_lava!(mdp)
...
end
"""
    POMDPs.initialstate(mdp::LavaWorld)

Initial-state distribution of `mdp`. Dispatching on `LavaWorld` keeps this method from
claiming `POMDPs.initialstate` for every argument type (type piracy).

NOTE(review): no distribution is constructed yet, so this returns `nothing`. The
commented-out line shows a deterministic start at `GWPos(1,5)`. The struct's
`initstate` field (default `:random`) is never read here. Confirm the intended
behavior and implement it.
"""
function POMDPs.initialstate(mdp::LavaWorld)
    # return Deterministic(GWPos(1,5))
    return nothing
end
# Action space of the MDP (definition elided in this view).
# NOTE(review): `mdp` is untyped, so this adds a method to `POMDPs.actions` for every
# argument type (type piracy); presumably this should dispatch on `mdp::LavaWorld`.
POMDPs.actions(mdp) = ...
# NOTE(review): the four lines below are placeholders and are not valid Julia as written.
# They presumably stand in for elided `states`, `reward`, `isterminal` and `discount`
# definitions for `LavaWorld`. Restore or remove them before this file is loaded.
"" .states ""
"" .reward ""
"" .isterminal ""
"" .discount ""
"""
    POMDPs.convert_s(::Type{V}, s::GWPos, mdp::LavaWorld) where {V<:AbstractArray}

Convert the gridworld position `s` into an array representation of type `V` for `mdp`.
The encoding logic is elided in this view; the visible `return` indexes into `s` with
an elided expression.

NOTE(review): the `goal` function in this file reshapes a state array to
`(mdp.gridworld.size..., :)` and reads its third slice. That suggests a per-cell,
multi-channel encoding; confirm against the elided body.
"""
function POMDPs.convert_s(::Type{V}, s::GWPos, mdp::LavaWorld) where {V<:AbstractArray}
...
return s[...]
end
"""
    POMDPs.convert_s(::Type{GWPos}, v::V, mdp::LavaWorld) where {V<:AbstractArray}

Convert the array-encoded state `v` back to a `GWPos` for `mdp`. The decoding logic is
elided in this view.

NOTE(review): the visible code ends with `return v`. Unless the elided lines rebind `v`
to a `GWPos`, this returns the input array rather than a position. Confirm.
"""
function POMDPs.convert_s(::Type{GWPos}, v::V, mdp::LavaWorld) where {V<:AbstractArray}
...
return v
end
"""
    goal(mdp::LavaWorld, s::AbstractArray)

Recover the goal cell from the array-encoded state `s`.

`s` is reshaped to `(mdp.gridworld.size..., :)`. The third slice along the last axis is
searched for the first entry equal to `1.0`, and that entry's Cartesian index is
returned as a `GWPos`. The encoded state must be passed explicitly, since the method
has no other source for it.

Throws `ArgumentError` if no entry of that slice equals `1.0`.
"""
function goal(mdp::LavaWorld, s::AbstractArray)
    # A view avoids copying the slice; findfirst still yields a CartesianIndex.
    goalslice = view(reshape(s, mdp.gridworld.size..., :), :, :, 3)
    idx = findfirst(==(1.0), goalslice)
    idx === nothing && throw(ArgumentError(
        "no goal cell (value 1.0) found in the third channel of the encoded state"))
    return GWPos(idx.I)
end
"""
    gen_occupancy(buffer, mdp)

NOTE(review): the body is elided in this view; it ends by returning a local
`occupancy`. By its name it presumably computes state-occupancy information from
`buffer` for `mdp`. Confirm.
"""
function gen_occupancy(buffer, mdp)
...
occupancy
end
"""
    render(mdp::LavaWorld, s=GWPos(7,5), a=nothing,
           color=s->10.0*POMDPs.reward(mdp,s), policy=nothing, return_compose=false)

Render `mdp` by calling `POMDPs.render` on the wrapped gridworld.

If `return_compose` is true, the value produced by `POMDPs.render` is returned
directly. Otherwise it is written as a PNG sized `1cm` per grid cell in each
dimension, and the file is loaded back and returned.

NOTE(review): the `POMDPs.render` arguments and the temporary file name are elided
placeholders in this view. Confirm they are filled in.
"""
function render(mdp::LavaWorld, s=GWPos(7,5), a=nothing, color=s->10.0*POMDPs.reward(mdp,s), policy=nothing, return_compose=false)
    img = POMDPs.render(mdp.gridworld, ....)
    return_compose && return img
    tmpfilename = "..."
    # (width, height) = 1cm times the grid size, splatted into PNG's size arguments.
    img |> PNG(tmpfilename, (1cm .* mdp.gridworld.size)...)
    return load(tmpfilename)
end