(* Source file: policy.ml *)

(** A policy: a function from an observation of type ['obs] to a triple made
    of an action of type ['act] and two optional floats.

    The policy constructors visible alongside this definition always return
    [None] in both optional positions. NOTE(review): the two [float option]
    slots look like per-step extras (e.g. a log-probability and a value
    estimate) — confirm against the interface file before relying on that. *)
type ('obs, 'act) t = 'obs -> 'act * float option * float option

(** [deterministic f] turns the plain function [f] into a policy: the
    resulting policy applies [f] to the observation and returns the result as
    the action, with both optional float components set to [None]. [f] is
    only invoked when the policy is applied to an observation. *)
let deterministic f =
  fun obs ->
    let action = f obs in
    (action, None, None)

(** [random ?rng env] builds a policy that ignores its observation and draws
    an action from the action space of [env] with [Space.sample].

    The starting key is [rng] when supplied; otherwise it is obtained from
    [Env.take_rng env] (called only in that case, and after the action space
    has been read). The key is kept in a private mutable cell: every call of
    the returned policy samples with the current key and then stores the key
    returned by [Space.sample] for the next call.

    Both optional float components of the result are always [None]. *)
let random ?rng env =
  let space = Env.action_space env in
  let state =
    ref (match rng with None -> Env.take_rng env | Some key -> key)
  in
  (* Sample once and advance the stored key. *)
  let draw () =
    let action, next_key = Space.sample ~rng:!state space in
    state := next_key;
    action
  in
  fun _ -> (draw (), None, None)

(** [greedy_discrete env ~score] builds a policy that picks the action whose
    score is largest.

    For each observation, [score obs] must return a non-empty array. The
    index of its greatest element is selected; comparison uses the
    polymorphic [>] and the earliest index wins on ties. That index is added
    to an offset — the first element of [Space.boundary_values] for the
    action space of [env] when it is a [Space.Value.Int], [0] otherwise — and
    the sum is converted to an action with [Space.unpack].

    Both optional float components of the result are always [None].

    @raise Invalid_argument
      if [score] returns an empty array, or if [Space.unpack] reports an
      error for the chosen value. *)
let greedy_discrete env ~score =
  let space = Env.action_space env in
  (* Offset of the first discrete action; computed once, up front. *)
  let offset =
    match Space.boundary_values space with
    | Space.Value.Int lowest :: _ -> lowest
    | _ -> 0
  in
  let decode n =
    match Space.unpack space (Space.Value.Int n) with
    | Error msg -> invalid_arg ("Policy.greedy_discrete: " ^ msg)
    | Ok action -> action
  in
  fun obs ->
    let scores = score obs in
    let count = Array.length scores in
    if count = 0 then
      invalid_arg "Policy.greedy_discrete: score returned empty array";
    (* Left-to-right scan; strict [>] keeps the earliest maximum. *)
    let rec argmax best best_score i =
      if i >= count then best
      else
        let candidate = scores.(i) in
        if candidate > best_score then argmax i candidate (i + 1)
        else argmax best best_score (i + 1)
    in
    let winner = argmax 0 scores.(0) 1 in
    (decode (offset + winner), None, None)