Source file vector_env.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
open Errors

(** Policy applied by [step] to a sub-environment whose transition is
    terminated or truncated.

    - [Next_step]: the sub-environment is reset inside the same [step] call;
      the observation returned for it is the one produced by the reset, and the
      final observation is recorded in the info under
      ["vector.final_observation"].
    - [Disabled]: [step] performs no reset and returns the transition as-is. *)
type autoreset_mode = Next_step | Disabled

(** Result of stepping every sub-environment once.

    Each array holds one entry per sub-environment, in the same order as the
    environments held by the vector env. The ['act] and ['render] parameters do
    not appear in any field. *)
type ('obs, 'act, 'render) step = {
  observations : 'obs array;
      (** Observation for each sub-environment (the reset observation when an
          autoreset happened during the step). *)
  rewards : float array;
      (** Reward reported by each sub-environment's transition. *)
  terminations : bool array;
      (** [terminated] flag of each sub-environment's transition. *)
  truncations : bool array;
      (** [truncated] flag of each sub-environment's transition. *)
  infos : Info.t array;
      (** Info attached to each sub-environment's transition. *)
}

(** A vector environment: a fixed set of sub-environments that are reset,
    stepped and closed together. *)
type ('obs, 'act, 'render) t = {
  envs : ('obs, 'act, 'render) Env.t array;
      (** The wrapped sub-environments. *)
  autoreset_mode : autoreset_mode;
      (** Reset policy consulted by [step] for finished sub-environments. *)
  observation_space : Space.packed;
      (** Observation space of the first sub-environment, packed
          (see [create_sync]). *)
  action_space : Space.packed;
      (** Action space of the first sub-environment, packed
          (see [create_sync]). *)
  metadata : Metadata.t;
      (** Metadata of the first sub-environment (see [create_sync]). *)
}

(** [ensure_non_empty envs] checks that [envs] has at least one element.

    @raise Invalid_argument if [envs] is the empty list. *)
let ensure_non_empty = function
  | _ :: _ -> ()
  | [] -> invalid_arg "Vector_env.create_sync: env list cannot be empty"

(** [ensure_consistent_spaces envs] checks that every environment in [envs]
    has observation and action spaces whose [Space.shape] equals that of the
    first environment. Lists with zero or one element are accepted without any
    check.

    Only the shapes are compared; nothing else about the spaces is inspected
    here.

    Raises, through [raise_error], an [Invalid_metadata] error naming the
    offending kind of space (observation is checked before action). *)
let ensure_consistent_spaces envs =
  match envs with
  | [] | [ _ ] -> ()
  | reference :: others ->
      let reference_obs = Env.observation_space reference in
      let reference_act = Env.action_space reference in
      (* Shapes are compared structurally. *)
      let shapes_differ lhs rhs = Space.shape lhs <> Space.shape rhs in
      let check candidate =
        let candidate_obs = Env.observation_space candidate in
        let candidate_act = Env.action_space candidate in
        if shapes_differ reference_obs candidate_obs then
          raise_error
            (Invalid_metadata
               "Vector env requires homogeneous observation spaces");
        if shapes_differ reference_act candidate_act then
          raise_error
            (Invalid_metadata "Vector env requires homogeneous action spaces")
      in
      List.iter check others

(** [create_sync ?autoreset_mode ~envs ()] builds a vector environment over
    [envs], kept in list order.

    The vector env's observation space, action space and metadata are taken
    from the first environment of [envs].

    @param autoreset_mode reset policy used by [step]; defaults to [Next_step].
    @raise Invalid_argument if [envs] is empty (via [ensure_non_empty]).

    Also fails through [raise_error] with [Invalid_metadata] when the
    environments' space shapes differ (via [ensure_consistent_spaces]). *)
let create_sync ?(autoreset_mode = Next_step) ~envs () =
  ensure_non_empty envs;
  ensure_consistent_spaces envs;
  let env_array = Array.of_list envs in
  (* Safe: the list was checked to be non-empty above. *)
  let reference = env_array.(0) in
  let observation_space = Space.Pack (Env.observation_space reference) in
  let action_space = Space.Pack (Env.action_space reference) in
  let metadata = Env.metadata reference in
  {
    envs = env_array;
    autoreset_mode;
    observation_space;
    action_space;
    metadata;
  }

(** [num_envs vector_env] is the number of sub-environments held by
    [vector_env]. *)
let num_envs { envs; _ } = Array.length envs
(** [observation_space vector_env] is the packed observation space stored in
    [vector_env]. *)
let observation_space { observation_space = space; _ } = space
(** [action_space vector_env] is the packed action space stored in
    [vector_env]. *)
let action_space { action_space = space; _ } = space
(** [metadata vector_env] is the metadata stored in [vector_env]. *)
let metadata { metadata = stored; _ } = stored

(** [reset vector_env ()] resets every sub-environment, in index order, with
    [Env.reset env ()].

    @return [(observations, infos)], two arrays with one entry per
    sub-environment, in the same order as the sub-environments. *)
let reset vector_env () =
  let outcomes = Array.map (fun env -> Env.reset env ()) vector_env.envs in
  let first_observations = Array.map (fun (obs, _) -> obs) outcomes in
  let reset_infos = Array.map (fun (_, info) -> info) outcomes in
  (first_observations, reset_infos)

(** [step vector_env actions] steps every sub-environment once, in index
    order, feeding [actions.(i)] to the [i]-th sub-environment.

    With [Next_step] autoreset, a sub-environment whose transition is
    terminated or truncated is reset immediately: its slot in [observations]
    holds the reset observation, its reward and flags are those of the final
    transition, and its info is the final transition's info, extended with
    ["vector.final_observation"] (the packed final observation rendered as a
    string), merged with the reset info. With [Disabled], transitions are
    returned untouched.

    @raise Invalid_argument if [actions] does not have exactly one entry per
    sub-environment. *)
let step vector_env actions =
  let expected = num_envs vector_env in
  if Array.length actions <> expected then
    invalid_arg "Vector_env.step: action array length mismatch";
  let advance idx env =
    let transition = Env.step env actions.(idx) in
    let finished = transition.terminated || transition.truncated in
    let observation, info =
      match (vector_env.autoreset_mode, finished) with
      | Disabled, _ | Next_step, false ->
          (transition.observation, transition.info)
      | Next_step, true ->
          (* Record the final observation before the reset replaces it. *)
          let packed =
            Space.pack (Env.observation_space env) transition.observation
          in
          let final_info =
            Info.set "vector.final_observation"
              (Info.string (Space.Value.to_string packed))
              transition.info
          in
          let fresh_observation, fresh_info = Env.reset env () in
          (fresh_observation, Info.merge final_info fresh_info)
    in
    ( observation,
      transition.reward,
      transition.terminated,
      transition.truncated,
      info )
  in
  let outcomes = Array.mapi advance vector_env.envs in
  (* Split the array of tuples into one array per component. *)
  let pick project = Array.map project outcomes in
  let observations = pick (fun (obs, _, _, _, _) -> obs) in
  let rewards = pick (fun (_, reward, _, _, _) -> reward) in
  let terminations = pick (fun (_, _, terminated, _, _) -> terminated) in
  let truncations = pick (fun (_, _, _, truncated, _) -> truncated) in
  let infos = pick (fun (_, _, _, _, info) -> info) in
  { observations; rewards; terminations; truncations; infos }

(** [close vector_env] closes every sub-environment, in index order, with
    [Env.close].

    All sub-environments are attempted even if closing one of them raises, so
    that a single failure does not leave the remaining environments open. If
    at least one [Env.close] call raised, the first such exception is re-raised
    (with its original backtrace) once every environment has been attempted;
    later exceptions are dropped. *)
let close vector_env =
  let first_failure = ref None in
  Array.iter
    (fun env ->
      try Env.close env
      with exn -> (
        (* Capture the backtrace before any other call can overwrite it. *)
        let backtrace = Printexc.get_raw_backtrace () in
        match !first_failure with
        | None -> first_failure := Some (exn, backtrace)
        | Some _ -> ()))
    vector_env.envs;
  match !first_failure with
  | None -> ()
  | Some (exn, backtrace) -> Printexc.raise_with_backtrace exn backtrace