Source file env.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
open Errors

(** Outcome of one environment step.

    Only ['obs] appears in the fields; ['act] and ['render] are carried as type
    parameters without being used by any field of this record. *)
type ('obs, 'act, 'render) transition = {
  observation : 'obs;  (** Observation returned by the step. *)
  reward : float;  (** Reward attached to the step. *)
  terminated : bool;
      (** When [true], an environment built with [create] marks itself as
          needing a reset after the step. *)
  truncated : bool;
      (** Handled like [terminated] by [create]: when [true], the environment
          is marked as needing a reset. *)
  info : Info.t;  (** Additional [Info.t] payload attached to the step. *)
}

(** [transition ~observation ()] assembles a transition record.

    Optional arguments that are left out fall back to [reward = 0.],
    [terminated = false], [truncated = false] and [info = Info.empty]. The
    trailing [unit] lets the optional arguments be omitted. *)
let transition ?(reward = 0.) ?(terminated = false) ?(truncated = false)
    ?(info = Info.empty) ~observation () =
  {
    info = info;
    truncated = truncated;
    terminated = terminated;
    reward = reward;
    observation = observation;
  }

(** Render modes an environment can be created with. [`Custom name] carries a
    caller-chosen name; [render_mode_to_string] returns that name unchanged. *)
type render_mode = [ `Human | `Rgb_array | `Ansi | `Svg | `Custom of string ]

(** [render_mode_to_string mode] is the string key for [mode]: the built-in
    modes map to ["human"], ["rgb_array"], ["ansi"] and ["svg"], and
    [`Custom name] maps to [name] itself. *)
let render_mode_to_string mode =
  match mode with
  | `Custom label -> label
  | `Svg -> "svg"
  | `Ansi -> "ansi"
  | `Rgb_array -> "rgb_array"
  | `Human -> "human"

(** Environment state plus the functions that implement its operations.

    Every [*_impl] function receives the environment itself as its first
    argument; the public wrappers ([reset], [step], [render], [close]) pass
    [env] back into the corresponding field. *)
type ('obs, 'act, 'render) t = {
  id : string option;  (** Optional identifier supplied at creation. *)
  mutable metadata : Metadata.t;
      (** Current metadata; replaced by [set_metadata]. *)
  render_mode : render_mode option;
      (** Render mode chosen at creation, if any. [create] and [set_metadata]
          check it against the metadata with [Metadata.supports_render_mode]. *)
  observation_space : 'obs Space.t;
      (** Space used by [create] to check observations from reset and step. *)
  action_space : 'act Space.t;
      (** Space used by [create] to check actions passed to step. *)
  mutable rng : Rune.Rng.key;
      (** Current random key; updated by [set_rng], [take_rng] and
          [split_rng]. *)
  mutable closed : bool;
      (** Set to [true] by [close]; [ensure_open] rejects operations once it
          is set. *)
  mutable needs_reset : bool;
      (** Starts as [true]; cleared by a successful reset and set again by a
          terminated or truncated step, by [set_rng] and by [close]. *)
  reset_impl :
    ('obs, 'act, 'render) t -> ?options:Info.t -> unit -> 'obs * Info.t;
      (** Implementation invoked by [reset]. *)
  step_impl :
    ('obs, 'act, 'render) t -> 'act -> ('obs, 'act, 'render) transition;
      (** Implementation invoked by [step]. *)
  render_impl : ('obs, 'act, 'render) t -> 'render option;
      (** Implementation invoked by [render]. *)
  close_impl : ('obs, 'act, 'render) t -> unit;
      (** Implementation invoked by [close] the first time it is called. *)
}

(** [ensure_open env ~operation] does nothing while [env] is open. Once
    [env.closed] is set it calls [raise_error] with a [Closed_environment]
    error whose message names [operation]. *)
let ensure_open env ~operation =
  match env.closed with
  | false -> ()
  | true ->
      let message =
        Printf.sprintf "Operation '%s' on a closed environment" operation
      in
      raise_error (Closed_environment message)

(** [ensure_reset env ~operation] does nothing when [env] has been reset.
    While [env.needs_reset] is set it calls [raise_error] with a
    [Reset_needed] error whose message names [operation]. *)
let ensure_reset env ~operation =
  match env.needs_reset with
  | false -> ()
  | true ->
      let message =
        Printf.sprintf "Operation '%s' requires calling reset first" operation
      in
      raise_error (Reset_needed message)

(** [create ~rng ~observation_space ~action_space ~reset ~step ()] builds an
    environment around user-supplied handlers.

    The new environment is open and flagged as needing a reset.

    - [?metadata] defaults to [Metadata.default].
    - [?render_mode], when given, must be accepted by
      [Metadata.supports_render_mode] for the metadata; otherwise
      [raise_error] is called with [Invalid_metadata].
    - [?render] defaults to a function returning [None]; [?close] defaults to
      a no-op.
    - [?validate_transition], when given, is applied to every transition that
      passed the observation check.

    The stored reset implementation checks that the environment is open, runs
    [reset], checks the observation with [Space.contains] (reporting
    [Invalid_metadata] on failure), renders in [`Human] mode and finally clears
    the needs-reset flag.

    The stored step implementation checks that the environment is open and has
    been reset, checks the action (reporting [Invalid_action]), runs [step],
    checks the resulting observation (reporting [Invalid_metadata]), runs the
    optional validator, flags the environment for reset when the transition is
    terminated or truncated, and renders in [`Human] mode. *)
let create ?id ?(metadata = Metadata.default) ?render_mode ?validate_transition
    ~rng ~observation_space ~action_space ~reset:reset_handler
    ~step:step_handler ?render ?close () =
  (* The requested render mode has to be one the metadata declares. *)
  Option.iter
    (fun mode ->
      let mode_key = render_mode_to_string mode in
      if not (Metadata.supports_render_mode mode_key metadata) then
        raise_error
          (Invalid_metadata
             (Printf.sprintf
                "Render mode '%s' is not declared in metadata.render_modes"
                mode_key)))
    render_mode;
  let render_impl =
    match render with Some user_render -> user_render | None -> fun _ -> None
  in
  let close_impl =
    match close with Some user_close -> user_close | None -> fun _ -> ()
  in
  (* Textual form of a value, only computed when an error is reported. *)
  let describe space value = Space.pack space value |> Space.Value.to_string in
  let maybe_human_render env =
    match env.render_mode with
    | Some `Human ->
        (* The result is dropped: in human mode rendering is performed for its
           side effects. *)
        ignore (env.render_impl env)
    | Some (`Rgb_array | `Ansi | `Svg | `Custom _) | None -> ()
  in
  let reset_impl env ?options () =
    ensure_open env ~operation:"reset";
    let observation, info = reset_handler env ?options () in
    if Space.contains env.observation_space observation then (
      maybe_human_render env;
      env.needs_reset <- false;
      (observation, info))
    else
      raise_error
        (Invalid_metadata
           (Printf.sprintf
              "Reset produced an observation outside observation_space \
               (value=%s)"
              (describe env.observation_space observation)))
  in
  let step_impl env action =
    ensure_open env ~operation:"step";
    ensure_reset env ~operation:"step";
    if not (Space.contains env.action_space action) then
      raise_error
        (Invalid_action
           (Printf.sprintf "Action outside of action_space (value=%s)"
              (describe env.action_space action)));
    let outcome = step_handler env action in
    if not (Space.contains env.observation_space outcome.observation) then
      raise_error
        (Invalid_metadata
           (Printf.sprintf
              "Step produced an observation outside observation_space \
               (value=%s)"
              (describe env.observation_space outcome.observation)));
    (match validate_transition with
    | Some validate -> validate outcome
    | None -> ());
    (* The flag is raised before rendering, as in the step order above. *)
    if outcome.terminated || outcome.truncated then env.needs_reset <- true;
    maybe_human_render env;
    outcome
  in
  {
    id;
    metadata;
    render_mode;
    observation_space;
    action_space;
    rng;
    closed = false;
    needs_reset = true;
    reset_impl;
    step_impl;
    render_impl;
    close_impl;
  }

(** [id env] is the optional identifier stored in [env]. *)
let id { id; _ } = id

(** [metadata env] is the metadata currently stored in [env]. *)
let metadata { metadata; _ } = metadata

(** [render_mode env] is the render mode [env] was created with, if any. *)
let render_mode { render_mode; _ } = render_mode

(** [set_metadata env metadata] stores [metadata] in [env].

    When [env] has a render mode, the new metadata is first checked with
    [Metadata.supports_render_mode]; if the mode is not supported,
    [raise_error] is called with [Invalid_metadata] before the assignment is
    reached. *)
let set_metadata env metadata =
  (match env.render_mode with
  | None -> ()
  | Some mode ->
      let mode_key = render_mode_to_string mode in
      if not (Metadata.supports_render_mode mode_key metadata) then
        raise_error
          (Invalid_metadata
             (Printf.sprintf
                "Render mode '%s' is not declared in metadata.render_modes"
                mode_key)));
  env.metadata <- metadata

(** [rng env] is the random key currently stored in [env]. *)
let rng { rng; _ } = rng

(** [set_rng env rng] replaces the random key of [env] and flags [env] as
    needing a reset. *)
let set_rng env rng =
  env.needs_reset <- true;
  env.rng <- rng

(** [take_rng env] splits the key of [env], stores the first resulting key
    back into [env] and returns the second one.

    NOTE(review): relies on [Rune.Rng.split] yielding at least two keys when
    called without [~n] -- confirm against the Rune API. *)
let take_rng env =
  let derived = Rune.Rng.split env.rng in
  env.rng <- derived.(0);
  derived.(1)

(** [split_rng env ~n] splits the key of [env] into [n + 1] keys, stores the
    first one back into [env] and returns the remaining [n] as an array.

    NOTE(review): assumes [Rune.Rng.split ~n:k] returns an array of [k] keys
    -- confirm against the Rune API.

    @raise Invalid_argument if [n] is not positive. *)
let split_rng env ~n =
  if n < 1 then invalid_arg "Env.split_rng: n must be positive";
  let derived = Rune.Rng.split ~n:(succ n) env.rng in
  env.rng <- derived.(0);
  Array.sub derived 1 n

(** [observation_space env] is the observation space stored in [env]. *)
let observation_space { observation_space; _ } = observation_space

(** [action_space env] is the action space stored in [env]. *)
let action_space { action_space; _ } = action_space

(** [reset env ?options ()] runs the reset implementation stored in [env],
    passing [env] itself and forwarding [options]. *)
let reset ({ reset_impl; _ } as env) ?options () = reset_impl env ?options ()

(** [step env action] runs the step implementation stored in [env], passing
    [env] itself and [action]. *)
let step ({ step_impl; _ } as env) action = step_impl env action

(** [render env] checks that [env] is open with [ensure_open] and then returns
    the result of the render implementation stored in [env]. *)
let render env =
  let () = ensure_open env ~operation:"render" in
  env.render_impl env

(** [close env] runs the close implementation of [env], then marks [env] as
    closed and as needing a reset. Calling it on an already closed environment
    does nothing. The flags are only updated after the close implementation
    has returned. *)
let close env =
  if env.closed then ()
  else begin
    env.close_impl env;
    env.closed <- true;
    env.needs_reset <- true
  end

(** [closed env] is [true] once [env] has been marked closed. *)
let closed { closed; _ } = closed