Source file env.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
open Errors

(** Outcome of a single environment step.

    The ['act] and ['render] parameters do not occur in any field; they are
    phantom parameters that line up with the parameters of the environment
    type whose step function produces this record. *)
type ('obs, 'act, 'render) transition = {
  observation : 'obs;  (** Observation returned by the step. *)
  reward : float;  (** Scalar reward attached to the step. *)
  terminated : bool;
      (** When [true], the step wrapper built by [create] marks the
          environment as needing a reset. *)
  truncated : bool;
      (** Treated like [terminated] by the step wrapper: [true] forces a reset
          before the next step. *)
  info : Info.t;  (** Auxiliary [Info.t] payload attached to the step. *)
}

(** [transition ?reward ?terminated ?truncated ?info ~observation ()] builds a
    transition record around [observation].

    Omitted optional arguments fall back to [0.] for [reward], [false] for
    both [terminated] and [truncated], and [Info.empty] for [info]. *)
let transition ?reward ?terminated ?truncated ?info ~observation () =
  {
    observation;
    reward = Option.value reward ~default:0.;
    terminated = Option.value terminated ~default:false;
    truncated = Option.value truncated ~default:false;
    info = Option.value info ~default:Info.empty;
  }

(** Environment state together with the closures that implement it.

    ['obs] is the observation type, ['act] the action type and ['render] the
    type of values produced by rendering. Every closure receives the
    environment itself as its first argument. *)
type ('obs, 'act, 'render) t = {
  id : string option;  (** Optional identifier supplied at creation. *)
  mutable metadata : Metadata.t;  (** Replaced by [set_metadata]. *)
  observation_space : 'obs Space.t;
      (** Space against which observations from reset and step are checked. *)
  action_space : 'act Space.t;
      (** Space against which actions passed to step are checked. *)
  mutable rng : Rune.Rng.key;
      (** Current random key; replaced by [set_rng] and advanced by
          [take_rng] and [split_rng]. *)
  mutable closed : bool;
      (** Set to [true] by [close]; starts as [false] in [create]. *)
  mutable needs_reset : bool;
      (** [true] while a reset is required before stepping. Starts as [true]
          in [create], is cleared by a successful reset, and is set again by
          a terminated or truncated step, by [set_rng] and by [close]. *)
  reset_impl :
    ('obs, 'act, 'render) t -> ?options:Info.t -> unit -> 'obs * Info.t;
      (** Reset closure invoked by [reset]. *)
  step_impl :
    ('obs, 'act, 'render) t -> 'act -> ('obs, 'act, 'render) transition;
      (** Step closure invoked by [step]. *)
  render_impl : ('obs, 'act, 'render) t -> 'render option;
      (** Render closure invoked by [render]. *)
  close_impl : ('obs, 'act, 'render) t -> unit;
      (** Cleanup closure invoked by [close] the first time it is called. *)
}

(** [ensure_open env ~operation] returns [()] when [env] is still open.

    When [env.closed] is [true] it reports a [Closed_environment] error
    through [raise_error]; [operation] is interpolated into the message. *)
let ensure_open env ~operation =
  match env.closed with
  | false -> ()
  | true ->
      let message =
        Printf.sprintf "Operation '%s' on a closed environment" operation
      in
      raise_error (Closed_environment message)

(** [ensure_reset env ~operation] returns [()] unless a reset is pending.

    When [env.needs_reset] is [true] it reports a [Reset_needed] error
    through [raise_error]; [operation] is interpolated into the message. *)
let ensure_reset env ~operation =
  if env.needs_reset then begin
    let message =
      Printf.sprintf "Operation '%s' requires calling reset first" operation
    in
    raise_error (Reset_needed message)
  end

(** [create ?id ?metadata ~rng ~observation_space ~action_space ~reset ~step
    ?render ?close ()] assembles an environment from user callbacks.

    The [reset] and [step] callbacks are wrapped with lifecycle and space
    checks before being stored. [metadata] defaults to [Metadata.default],
    [render] to a function returning [None] and [close] to a no-op. The new
    environment is open and requires a reset before its first step.

    Errors reported through [raise_error] by the wrapped callbacks:
    - [Closed_environment] when reset or step runs on a closed environment;
    - [Reset_needed] when step runs while a reset is pending;
    - [Invalid_action] when the action is not contained in [action_space];
    - [Invalid_metadata] when a callback yields an observation that is not
      contained in [observation_space]. *)
let create ?id ?(metadata = Metadata.default) ~rng ~observation_space
    ~action_space ~reset:reset_handler ~step:step_handler ?render ?close () =
  (* Reset wrapper: refuse closed environments, validate the first
     observation, and only then clear the pending-reset flag. *)
  let reset_impl env ?options () =
    ensure_open env ~operation:"reset";
    let observation, info = reset_handler env ?options () in
    if Space.contains env.observation_space observation then begin
      env.needs_reset <- false;
      (observation, info)
    end
    else
      raise_error
        (Invalid_metadata
           "Reset produced an observation outside observation_space")
  in
  (* Step wrapper: lifecycle checks first, then the action, then the
     observation produced by the user callback. *)
  let step_impl env action =
    ensure_open env ~operation:"step";
    ensure_reset env ~operation:"step";
    let action_is_valid = Space.contains env.action_space action in
    if not action_is_valid then
      raise_error (Invalid_action "Action outside of action_space");
    let outcome = step_handler env action in
    let observation_is_valid =
      Space.contains env.observation_space outcome.observation
    in
    if not observation_is_valid then
      raise_error
        (Invalid_metadata
           "Step produced an observation outside observation_space");
    (* An episode that ended either way must be reset before the next step. *)
    let episode_over = outcome.terminated || outcome.truncated in
    if episode_over then env.needs_reset <- true;
    outcome
  in
  {
    id;
    metadata;
    observation_space;
    action_space;
    rng;
    closed = false;
    needs_reset = true;
    reset_impl;
    step_impl;
    render_impl = Option.value render ~default:(fun _ -> None);
    close_impl = Option.value close ~default:(fun _ -> ());
  }

(** [id env] is the optional identifier stored in [env]. *)
let id { id; _ } = id

(** [metadata env] is the metadata currently stored in [env]. *)
let metadata { metadata; _ } = metadata

(** [set_metadata env updated] replaces the metadata stored in [env]. *)
let set_metadata env updated = env.metadata <- updated

(** [rng env] is the random key currently stored in [env]. *)
let rng { rng; _ } = rng

(** [set_rng env key] stores [key] as the random key of [env] and flags the
    environment as needing a reset. *)
let set_rng env key =
  env.needs_reset <- true;
  env.rng <- key

(** [take_rng env] splits the key stored in [env] with [Rune.Rng.split],
    stores element [0] of the result back into [env] and returns element
    [1]. *)
let take_rng env =
  let pair = Rune.Rng.split env.rng in
  let () = env.rng <- pair.(0) in
  pair.(1)

(** [split_rng env ~n] splits the key stored in [env] into [n + 1] keys,
    stores the first one back into [env] and returns the remaining [n] as an
    array.

    @raise Invalid_argument if [n <= 0]. *)
let split_rng env ~n =
  if n <= 0 then invalid_arg "Env.split_rng: n must be positive"
  else begin
    let derived = Rune.Rng.split ~n:(succ n) env.rng in
    env.rng <- derived.(0);
    Array.sub derived 1 n
  end

(** [observation_space env] is the observation space stored in [env]. *)
let observation_space { observation_space; _ } = observation_space

(** [action_space env] is the action space stored in [env]. *)
let action_space { action_space; _ } = action_space
(** [reset env ?options ()] calls the reset closure stored in [env], passing
    it [env] itself and the optional [options], and returns its result. *)
let reset ({ reset_impl; _ } as env) ?options () = reset_impl env ?options ()

(** [step env action] calls the step closure stored in [env], passing it
    [env] itself and [action], and returns the resulting transition. *)
let step ({ step_impl; _ } as env) action = step_impl env action

(** [render env] calls the render closure stored in [env] and returns its
    result.

    Reports a [Closed_environment] error (via [ensure_open]) when [env] has
    been closed. *)
let render env =
  let () = ensure_open env ~operation:"render" in
  env.render_impl env

(** [close env] runs the close closure stored in [env], then marks the
    environment as closed and as needing a reset.

    Calling it on an already closed environment does nothing. The flags are
    only updated after the close closure has returned. *)
let close env =
  if env.closed then ()
  else begin
    env.close_impl env;
    env.needs_reset <- true;
    env.closed <- true
  end

(** [closed env] is [true] once [close] has completed on [env]. *)
let closed { closed; _ } = closed