Source file grid_world.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
open Fehu
(** Observation tensor. [make_obs] builds it with shape [[| 2 |]], holding the
    agent's row followed by its column as int32 values. *)
type obs = (int32, Nx.int32_elt) Nx.t
(** Action tensor. The step function inside [make] decodes it to an integer
    with [Space.Discrete.to_int]. *)
type act = (int32, Nx.int32_elt) Nx.t
(** Value handed back by the render callback built in [make]: either a textual
    drawing of the grid or an image frame. *)
type render = Text of string | Image of Render.image
(** Side length of the square grid; valid rows and columns are
    [0 .. grid_size - 1] (see [is_valid]). *)
let grid_size = 5
(** Step count at which the step function in [make] flags an unfinished
    episode as truncated. *)
let max_steps = 200
(** Observation space: a multi-discrete space created from
    [[| grid_size; grid_size |]], one entry per coordinate (row, column). *)
let observation_space = Space.Multi_discrete.create [| grid_size; grid_size |]
(** Action space: a discrete space created with 4. The step function in [make]
    maps 0, 1, 2, 3 to row - 1, row + 1, column - 1, column + 1. *)
let action_space = Space.Discrete.create 4
(** [is_goal row col] is [true] exactly when [(row, col)] is the last cell of
    the grid, i.e. both coordinates equal [grid_size - 1]. *)
let is_goal row col =
  let last = grid_size - 1 in
  row = last && col = last
(** [is_obstacle row col] is [true] only for the single blocked cell at
    row 2, column 2. *)
let is_obstacle row col =
  match (row, col) with
  | 2, 2 -> true
  | _ -> false
(** [is_valid row col] tells whether the agent may occupy [(row, col)]: both
    coordinates must lie in [0 .. grid_size - 1] and the cell must not be the
    obstacle. *)
let is_valid row col =
  let in_range i = 0 <= i && i < grid_size in
  in_range row && in_range col && not (is_obstacle row col)
(** [make_obs row col] packs the position into an int32 tensor of shape
    [[| 2 |]], row first and column second. *)
let make_obs row col =
  let coords = Array.map Int32.of_int [| row; col |] in
  Nx.create Nx.int32 [| 2 |] coords
(** [render_text row col] draws the grid as text.

    The result starts with a ["Position: (row, col)"] line, followed by
    [grid_size] lines of [grid_size] characters each: ['.'] for an empty cell,
    ['A'] for the agent, ['G'] for the goal and ['#'] for the obstacle.

    Markers are written in the order agent, goal, obstacle, so the agent's
    ['A'] is replaced when the agent stands on the goal or obstacle cell.

    @raise Invalid_argument
      if [(row * grid_size) + col] falls outside the cell buffer (raised by
      [Bytes.set]). *)
let render_text row col =
  let cells = Bytes.make (grid_size * grid_size) '.' in
  (* Cells are stored row-major in one flat byte buffer. *)
  let mark r c glyph = Bytes.set cells ((r * grid_size) + c) glyph in
  mark row col 'A';
  mark (grid_size - 1) (grid_size - 1) 'G';
  mark 2 2 '#';
  let line r = Bytes.sub_string cells (r * grid_size) grid_size in
  let newline fmt () = Format.fprintf fmt "@." in
  let pp_lines = Format.pp_print_list ~pp_sep:newline Format.pp_print_string in
  Format.asprintf "Position: (%d, %d)@.%a" row col pp_lines
    (List.init grid_size line)
(** Edge length, in pixels, of one grid cell in the rendered image. *)
let cell_size = 32
(** Width of the rendered frame in pixels: [grid_size] cells of [cell_size]
    pixels each. *)
let frame_width = grid_size * cell_size
(** Height of the rendered frame in pixels: [grid_size] cells of [cell_size]
    pixels each. *)
let frame_height = grid_size * cell_size
(** [fill_rect data ~x0 ~y0 ~w ~h ~r ~g ~b] paints a solid rectangle into
    [data].

    [data] is addressed as a row-major pixel buffer with [frame_width] pixels
    per row and three consecutive channel values per pixel, written as
    [r], [g], [b].

    - [x0], [y0]: pixel coordinates of the rectangle's first column and row.
    - [w], [h]: width and height in pixels; non-positive sizes paint nothing.

    The rectangle is clipped to the frame width and to the number of complete
    rows [data] actually holds, so a rectangle that reaches outside the buffer
    is truncated rather than written out of bounds. *)
let fill_rect data ~x0 ~y0 ~w ~h ~r ~g ~b =
  let stride = frame_width * 3 in
  (* Only complete scanlines are addressable. *)
  let rows_available = Bigarray.Array1.dim data / stride in
  let x_first = max 0 x0 in
  let x_stop = min frame_width (x0 + w) in
  let y_first = max 0 y0 in
  let y_stop = min rows_available (y0 + h) in
  for y = y_first to y_stop - 1 do
    let row_offset = y * stride in
    for x = x_first to x_stop - 1 do
      let base = row_offset + (x * 3) in
      (* Clipping above keeps [base + 2] below [rows_available * stride],
         so the unchecked writes stay inside [data]. *)
      Bigarray.Array1.unsafe_set data base r;
      Bigarray.Array1.unsafe_set data (base + 1) g;
      Bigarray.Array1.unsafe_set data (base + 2) b
    done
  done
(** [render_image row col] rasterises the grid into a
    [frame_width] x [frame_height] image with three unsigned 8-bit channels per
    pixel.

    Every cell is drawn as a border-coloured square with a lighter interior
    inset by one pixel. The agent at [(row, col)], the goal in the last cell
    and the obstacle at (2, 2) are then painted, in that order, as coloured
    squares inset by two pixels; a later marker covers an earlier one drawn on
    the same cell. *)
let render_image row col =
  let data =
    Bigarray.Array1.create Bigarray.int8_unsigned Bigarray.c_layout
      (frame_width * frame_height * 3)
  in
  (* Paint the square of cell (cr, cc) shrunk by [margin] pixels per side. *)
  let inset margin cr cc ~r ~g ~b =
    let side = cell_size - (2 * margin) in
    fill_rect data
      ~x0:((cc * cell_size) + margin)
      ~y0:((cr * cell_size) + margin)
      ~w:side ~h:side ~r ~g ~b
  in
  (* Frame background. *)
  fill_rect data ~x0:0 ~y0:0 ~w:frame_width ~h:frame_height ~r:30 ~g:33 ~b:36;
  (* Cell borders and interiors, visited in row-major order. *)
  for cell = 0 to (grid_size * grid_size) - 1 do
    let gr = cell / grid_size in
    let gc = cell mod grid_size in
    inset 0 gr gc ~r:44 ~g:48 ~b:52;
    inset 1 gr gc ~r:54 ~g:60 ~b:65
  done;
  (* Markers: agent, then goal, then obstacle. *)
  inset 2 row col ~r:78 ~g:162 ~b:196;
  inset 2 (grid_size - 1) (grid_size - 1) ~r:76 ~g:175 ~b:80;
  inset 2 2 2 ~r:200 ~g:80 ~b:80;
  Render.image ~width:frame_width ~height:frame_height data
(** [make ?render_mode ()] builds the ["GridWorld-v0"] environment.

    The environment keeps the agent position and a step counter in private
    mutable state:
    - reset puts the agent back at (0, 0), clears the counter and returns the
      observation for (0, 0) with empty info;
    - step decodes the action (0: row - 1, 1: row + 1, 2: column - 1,
      3: column + 1, anything else: no move), ignores moves rejected by
      [is_valid], and increments the counter. Reaching the goal terminates the
      episode with reward 10.0; every other step yields -1.0, and the episode
      is truncated once [max_steps] steps have been taken without reaching the
      goal. The info carries the current count under ["steps"];
    - render returns an [Image] when [render_mode] is [Some `Rgb_array] and a
      [Text] drawing in every other case, including when no mode was given. *)
let make ?render_mode () =
  let row = ref 0 and col = ref 0 and steps = ref 0 in
  let reset _env ?options:_ () =
    steps := 0;
    col := 0;
    row := 0;
    (make_obs 0 0, Info.empty)
  in
  let step _env action =
    let d_row, d_col =
      match Space.Discrete.to_int action with
      | 0 -> (-1, 0)
      | 1 -> (1, 0)
      | 2 -> (0, -1)
      | 3 -> (0, 1)
      | _ -> (0, 0)
    in
    let target_row = !row + d_row and target_col = !col + d_col in
    (* A rejected move leaves the agent where it was. *)
    if is_valid target_row target_col then begin
      row := target_row;
      col := target_col
    end;
    incr steps;
    let terminated = is_goal !row !col in
    let truncated = (not terminated) && !steps >= max_steps in
    let reward = if terminated then 10.0 else -1.0 in
    let info = Info.set "steps" (Info.int !steps) Info.empty in
    Env.step_result ~observation:(make_obs !row !col) ~reward ~terminated
      ~truncated ~info ()
  in
  let render () =
    match render_mode with
    | Some `Rgb_array -> Some (Image (render_image !row !col))
    | _ -> Some (Text (render_text !row !col))
  in
  Env.create ?render_mode ~render_modes:[ "ansi"; "rgb_array" ]
    ~id:"GridWorld-v0" ~observation_space ~action_space ~reset ~step ~render ()