1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
(** [derive_id env suffix] is [Some (id ^ suffix)] when [Env.id env] is
    [Some id], and [None] when the environment carries no id. *)
let derive_id env suffix =
  Option.map (fun base -> base ^ suffix) (Env.id env)
(** [inherit_metadata env] is the metadata of [env], as returned by
    [Env.metadata]. Wrappers use it to forward the wrapped environment's
    metadata unchanged. *)
let inherit_metadata env = env |> Env.metadata
(** [map_observation ~observation_space ~f env] wraps [env] so that every
    observation/info pair it produces is passed through [f].

    - On reset, the wrapper resets [env] and returns [f observation info].
    - On step, the wrapper steps [env], applies [f] to the resulting
      observation and info, and keeps the reward, [terminated] and
      [truncated] values of the inner transition.

    The wrapper advertises [observation_space] as its observation space and
    reuses the action space, metadata and rng of [env]. Rendering and closing
    are delegated to [env]. When [env] has an id, the wrapper's id is that id
    followed by ["/ObservationWrapper"]. *)
let map_observation ~(observation_space : 'obs Space.t)
    ~(f : 'inner_obs -> Info.t -> 'obs * Info.t)
    (env : ('inner_obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () =
    let inner_obs, inner_info = Env.reset env ?options () in
    f inner_obs inner_info
  in
  let step _wrapper action =
    let inner = Env.step env action in
    let observation, info = f inner.observation inner.info in
    Env.transition ~observation ~reward:inner.reward
      ~terminated:inner.terminated ~truncated:inner.truncated ~info ()
  in
  Env.create
    ?id:(derive_id env "/ObservationWrapper")
    ~metadata:(inherit_metadata env)
    ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space
    ~action_space:(Env.action_space env)
    ~reset ~step ()
(** [map_action ~action_space ~f env] wraps [env] so that every action given
    to the wrapper is converted with [f] before being passed to [env].

    The wrapper advertises [action_space] as its action space and reuses the
    observation space, metadata and rng of [env]. Reset, render and close are
    delegated to [env]; the transition returned by a step carries the same
    observation, reward, [terminated], [truncated] and info as the inner
    transition. When [env] has an id, the wrapper's id is that id followed by
    ["/ActionWrapper"]. *)
let map_action ~(action_space : 'act Space.t) ~(f : 'act -> 'inner_act)
    (env : ('obs, 'inner_act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let step _wrapper outer_action =
    (* Translate the wrapper-level action, then rebuild the transition from
       the inner one field by field. *)
    let inner = Env.step env (f outer_action) in
    Env.transition ~observation:inner.observation ~reward:inner.reward
      ~terminated:inner.terminated ~truncated:inner.truncated
      ~info:inner.info ()
  in
  Env.create
    ?id:(derive_id env "/ActionWrapper")
    ~metadata:(inherit_metadata env)
    ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space
    ~reset:(fun _wrapper ?options () -> Env.reset env ?options ())
    ~step ()
(** [map_reward ~f env] wraps [env] so that the reward and info of every step
    are replaced by the pair returned by [f ~reward ~info], where [reward]
    and [info] come from the inner transition. All other transition fields
    are kept as produced by [env].

    Spaces, metadata and rng are taken from [env]; reset, render and close
    are delegated to it. When [env] has an id, the wrapper's id is that id
    followed by ["/RewardWrapper"]. *)
let map_reward ~(f : reward:float -> info:Info.t -> float * Info.t)
    (env : ('obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let step _wrapper action =
    let inner = Env.step env action in
    let reward, info = f ~reward:inner.reward ~info:inner.info in
    { inner with reward; info }
  in
  Env.create
    ?id:(derive_id env "/RewardWrapper")
    ~metadata:(inherit_metadata env)
    ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset:(fun _wrapper ?options () -> Env.reset env ?options ())
    ~step ()
(** [time_limit ~max_episode_steps env] wraps [env] with a step counter that
    truncates an episode once [max_episode_steps] steps have been taken.

    The counter is reset to zero on every reset of the wrapper and whenever a
    step ends the episode. For each step:
    - if the inner transition is already terminated or truncated, it is
      returned unchanged;
    - otherwise, if the counter has reached [max_episode_steps], the
      transition is returned with [truncated = true] and with the key
      ["time_limit.truncated"] set to [Info.bool true] in its info;
    - otherwise the inner transition is returned unchanged.

    Spaces, metadata and rng are taken from [env]; render and close are
    delegated to it. When [env] has an id, the wrapper's id is that id
    followed by ["/TimeLimit"].

    @raise Invalid_argument if [max_episode_steps <= 0]. *)
let time_limit ~max_episode_steps env =
  if max_episode_steps <= 0 then
    invalid_arg "Wrapper.time_limit: max_episode_steps must be positive";
  (* Number of steps taken in the current episode. *)
  let elapsed = ref 0 in
  let reset _wrapper ?options () =
    elapsed := 0;
    Env.reset env ?options ()
  in
  let step _wrapper action =
    (* The counter is bumped before delegating to the inner environment. *)
    incr elapsed;
    let result = Env.step env action in
    let inner_done = result.terminated || result.truncated in
    match (inner_done, !elapsed >= max_episode_steps) with
    | true, _ ->
        (* The inner environment ended the episode on its own. *)
        elapsed := 0;
        result
    | false, true ->
        (* Step budget exhausted: flag the truncation in the info. *)
        let info =
          Info.set "time_limit.truncated" (Info.bool true) result.info
        in
        elapsed := 0;
        { result with truncated = true; info }
    | false, false -> result
  in
  Env.create
    ?id:(derive_id env "/TimeLimit")
    ~metadata:(inherit_metadata env)
    ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset ~step ()
(** [with_metadata ~f env] wraps [env] in an environment whose metadata is
    [f (Env.metadata env)], computed once when the wrapper is built.

    The id, rng and both spaces are those of [env], and reset, step, render
    and close are all delegated to [env] unchanged. *)
let with_metadata ~f env =
  (* Apply [f] up front, before any other argument of [Env.create] is
     evaluated. *)
  let updated = f (Env.metadata env) in
  Env.create
    ?id:(Env.id env)
    ~metadata:updated
    ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env)
    ~reset:(fun _self ?options () -> Env.reset env ?options ())
    ~step:(fun _self act -> Env.step env act)
    ()