Source file wrapper.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
(** [derive_id env suffix] is [env]'s identifier with [suffix] appended,
    or [None] when [env] has no identifier. *)
let derive_id env suffix =
  Option.map (fun base_id -> base_id ^ suffix) (Env.id env)

(** [inherit_metadata env] is the metadata a wrapper around [env] reports:
    [env]'s own metadata, passed through unchanged. *)
let inherit_metadata inner = Env.metadata inner

(** [map_observation ~observation_space ~f env] wraps [env] so that every
    observation produced by [reset] and [step] is passed, together with its
    info, through [f] before reaching the caller.

    [observation_space] becomes the wrapper's observation space. Reward,
    [terminated] and [truncated] flags are forwarded unchanged. The action
    space, rng, rendering and closing are all delegated to [env]. The
    wrapper's id is [env]'s id suffixed with ["/ObservationWrapper"] when
    [env] has an id. *)
let map_observation ~(observation_space : 'obs Space.t)
    ~(f : 'inner_obs -> Info.t -> 'obs * Info.t)
    (env : ('inner_obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () =
    let inner_obs, inner_info = Env.reset env ?options () in
    f inner_obs inner_info
  in
  let step _wrapper action =
    let inner = Env.step env action in
    let mapped_obs, mapped_info = f inner.observation inner.info in
    Env.transition ~observation:mapped_obs ~reward:inner.reward
      ~terminated:inner.terminated ~truncated:inner.truncated
      ~info:mapped_info ()
  in
  Env.create
    ?id:(derive_id env "/ObservationWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space ~action_space:(Env.action_space env)
    ~reset ~step ()

(** [map_action ~action_space ~f env] wraps [env] so that callers act in
    [action_space]. Each outer action is converted by [f] before it is
    forwarded to [env]'s [step].

    Observations, rewards, flags and info from [env] are returned
    unchanged. [reset], the observation space, rng, rendering and closing
    are delegated to [env]. The wrapper's id is [env]'s id suffixed with
    ["/ActionWrapper"] when [env] has an id. *)
let map_action ~(action_space : 'act Space.t) ~(f : 'act -> 'inner_act)
    (env : ('obs, 'inner_act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let step _wrapper outer_action =
    let inner_action = f outer_action in
    let result = Env.step env inner_action in
    Env.transition ~observation:result.observation ~reward:result.reward
      ~terminated:result.terminated ~truncated:result.truncated
      ~info:result.info ()
  in
  Env.create
    ?id:(derive_id env "/ActionWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space
    ~reset:(fun _wrapper ?options () -> Env.reset env ?options ())
    ~step ()

(** [map_reward ~f env] wraps [env] so that the reward and info of every
    [step] are replaced by [f ~reward ~info].

    The observation and the [terminated]/[truncated] flags are kept as [env]
    produced them. [reset], both spaces, rng, rendering and closing are
    delegated to [env]. The wrapper's id is [env]'s id suffixed with
    ["/RewardWrapper"] when [env] has an id. *)
let map_reward ~(f : reward:float -> info:Info.t -> float * Info.t)
    (env : ('obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let step _wrapper action =
    let inner = Env.step env action in
    let shaped_reward, shaped_info =
      f ~reward:inner.reward ~info:inner.info
    in
    { inner with reward = shaped_reward; info = shaped_info }
  in
  Env.create
    ?id:(derive_id env "/RewardWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset:(fun _wrapper ?options () -> Env.reset env ?options ())
    ~step ()

(** [time_limit ~max_episode_steps env] wraps [env] so that an episode is
    cut off after [max_episode_steps] steps.

    Steps are counted from the last [reset], or from the end of the previous
    episode if the caller keeps stepping. On the step that reaches the limit,
    the transition is returned with [truncated = true] and the info key
    ["time_limit.truncated"] is set to [Info.bool true]. A transition that
    [env] itself reports as terminated or truncated is passed through
    untouched and restarts the count.

    Spaces, rng, rendering and closing are delegated to [env]. The wrapper's
    id is [env]'s id suffixed with ["/TimeLimit"] when [env] has an id.

    @raise Invalid_argument if [max_episode_steps <= 0]. The check runs
    once, when the wrapper is built. *)
let time_limit ~max_episode_steps env =
  if max_episode_steps <= 0 then
    invalid_arg "Wrapper.time_limit: max_episode_steps must be positive";
  let elapsed = ref 0 in
  let reset _wrapper ?options () =
    elapsed := 0;
    Env.reset env ?options ()
  in
  let step _wrapper action =
    (* Counted before delegating, so the limit-th call is the one that
       gets truncated. *)
    incr elapsed;
    let inner = Env.step env action in
    let episode_over = inner.terminated || inner.truncated in
    if episode_over then begin
      elapsed := 0;
      inner
    end
    else if !elapsed < max_episode_steps then inner
    else begin
      let info =
        Info.set "time_limit.truncated" (Info.bool true) inner.info
      in
      elapsed := 0;
      { inner with truncated = true; info }
    end
  in
  Env.create
    ?id:(derive_id env "/TimeLimit")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset ~step ()

(** [with_metadata ~f env] is a pass-through wrapper around [env] whose
    metadata is [f (Env.metadata env)].

    [f] is applied once, when the wrapper is built. Unlike the other
    wrappers in this module, the id is [env]'s id unchanged, with no suffix
    appended. [reset], [step], both spaces, rng, rendering and closing are
    delegated to [env]. *)
let with_metadata ~f env =
  let replaced_metadata = f (Env.metadata env) in
  let reset _wrapper ?options () = Env.reset env ?options () in
  let step _wrapper action = Env.step env action in
  Env.create ?id:(Env.id env) ~metadata:replaced_metadata ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset ~step ()