1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
open Fehu
(** Observations are float32 tensors; [reset] and [step] build them with shape
    [[|4|]] as [[|x; x_dot; theta; theta_dot|]]. *)
type observation = (float, Rune.float32_elt) Rune.t
(** Actions are int32 tensors; [step] reads only their first element. *)
type action = (int32, Rune.int32_elt) Rune.t
(** Rendering yields a single line of plain text (see [render]). *)
type render = string
(** Mutable simulation state, shared by the reset/step/render callbacks that
    [make] installs. *)
type state = {
mutable x : float; (** Cart position; the episode terminates once [|x| > x_threshold]. *)
mutable x_dot : float; (** Rate of change of [x]. *)
mutable theta : float; (** Pole angle in radians ([render] shows it in degrees). *)
mutable theta_dot : float; (** Rate of change of [theta]. *)
mutable steps : int; (** Steps since the last [reset]; [step] truncates at 500. *)
rng : Rune.Rng.key ref; (** RNG key; each [reset] splits it and stores a fresh key back. *)
}
(* ---- Physical parameters of the simulated cart-pole system ---- *)

(** Gravitational acceleration used in the pole dynamics. *)
let gravity = 9.8

(** Mass of the cart. *)
let masscart = 1.0

(** Mass of the pole. *)
let masspole = 0.1

(** Combined mass of cart and pole. *)
let total_mass = masspole +. masscart

(** Pole length parameter of the dynamics. NOTE(review): as in Gymnasium's
    CartPole this presumably denotes half the pole's length — confirm. *)
let length = 0.5

(** Product [masspole *. length], precomputed for the dynamics. *)
let polemass_length = length *. masspole

(** Magnitude of the horizontal force each action applies to the cart. *)
let force_mag = 10.0

(** Integration time step: every [step] advances the state by [tau]. *)
let tau = 0.02

(* ---- Termination thresholds (symmetric around zero) ---- *)

(** Largest absolute pole angle, in radians (12 degrees), before termination. *)
let theta_threshold_radians = Float.pi *. 12. /. 180.

(** Largest absolute cart position before termination. *)
let x_threshold = 2.4
(** Box bounds for observations [[|x; x_dot; theta; theta_dot|]].

    Position and angle are bounded at twice their termination thresholds, so
    an observation taken just after crossing a threshold still lies inside the
    box. Velocities are left effectively unbounded. Both limits are derived
    from the thresholds rather than hard-coded, so they stay consistent if a
    threshold changes. *)
let observation_space =
  let x_bound = x_threshold *. 2. in
  let theta_bound = theta_threshold_radians *. 2. in
  Space.Box.create
    ~low:[| -.x_bound; -.Float.max_float; -.theta_bound; -.Float.max_float |]
    ~high:[| x_bound; Float.max_float; theta_bound; Float.max_float |]

(** Two discrete actions: [step] maps action [1] to [+force_mag] and action
    [0] to [-force_mag] applied to the cart. *)
let action_space = Space.Discrete.create 2
(** Metadata passed to [Env.create] by [make]: it advertises the "ansi" render
    mode (served by [render]), plus a description, author and version. *)
let metadata =
  let m = Metadata.default in
  let m = Metadata.add_render_mode "ansi" m in
  let m =
    Metadata.with_description (Some "Classic cart-pole balancing problem") m
  in
  let m = Metadata.add_author "Fehu" m in
  Metadata.with_version (Some "0.1.0") m
(** [reset env ?options () state] starts a new episode and returns the initial
    observation [[|x; x_dot; theta; theta_dot|]] with empty info.

    The stored RNG key is split into five keys:
    - the first replaces the stored key, so the next reset draws new values;
    - the other four each seed one state component.

    Each component is set to [(u -. 0.5) *. 0.1], where [u] is a single float32
    draw from [Rune.Rng.uniform]. NOTE(review): assuming [u] lies in [0, 1),
    this gives values in roughly [-0.05, 0.05) — confirm against Rune.

    The step counter is reset to zero. [env] and [options] are ignored. *)
let reset _env ?options:_ () state =
  let keys = Rune.Rng.split !(state.rng) ~n:5 in
  state.rng := keys.(0);
  let draw key =
    let sample = Rune.to_array (Rune.Rng.uniform key Rune.float32 [| 1 |]) in
    (sample.(0) -. 0.5) *. 0.1
  in
  state.x <- draw keys.(1);
  state.x_dot <- draw keys.(2);
  state.theta <- draw keys.(3);
  state.theta_dot <- draw keys.(4);
  state.steps <- 0;
  ( Rune.create Rune.float32 [| 4 |]
      [| state.x; state.x_dot; state.theta; state.theta_dot |],
    Info.empty )
(** [step env action state] advances the simulation by one explicit-Euler step
    of length [tau] and returns the resulting transition.

    Only the first element of [action] is read: [1] applies [+force_mag] to the
    cart and [0] applies [-force_mag]. The accelerations use the cart-pole
    equations from Gymnasium's CartPole.

    - [terminated] is [true] once [|x| > x_threshold] or
      [|theta| > theta_threshold_radians].
    - [truncated] is [true] once 500 steps have been taken since [reset].
    - [reward] is [1.0] for every step of a live episode, including the step on
      which it terminates, and [0.0] for steps taken after termination. This
      matches Gymnasium's CartPole-v1, whose id this environment uses.
    - [info] carries the step count under the key ["steps"].

    @raise Invalid_argument if the action value is neither [0] nor [1]. *)
let step _env action state =
  (* Whether the state (as currently stored) lies outside the allowed region. *)
  let out_of_bounds () =
    state.x < -.x_threshold
    || state.x > x_threshold
    || state.theta < -.theta_threshold_radians
    || state.theta > theta_threshold_radians
  in
  (* Sampled before integrating: the step on which the pole first falls still
     earns +1; only steps taken after the episode already ended earn 0. *)
  let already_terminated = out_of_bounds () in
  let force =
    let arr : Int32.t array = Rune.to_array action in
    match Int32.to_int arr.(0) with
    | 1 -> force_mag
    | 0 -> -.force_mag
    | a ->
        (* Previously any value other than 1 silently pushed left. *)
        invalid_arg
          (Printf.sprintf "CartPole.step: invalid action %d (expected 0 or 1)"
             a)
  in
  let costheta = cos state.theta in
  let sintheta = sin state.theta in
  let temp =
    (force
    +. (polemass_length *. state.theta_dot *. state.theta_dot *. sintheta))
    /. total_mass
  in
  let thetaacc =
    ((gravity *. sintheta) -. (costheta *. temp))
    /. (length
       *. ((4.0 /. 3.0) -. (masspole *. costheta *. costheta /. total_mass)))
  in
  let xacc = temp -. (polemass_length *. thetaacc *. costheta /. total_mass) in
  (* Explicit Euler: each position is advanced with the velocity from before
     this step, then each velocity with the new acceleration. *)
  state.x <- state.x +. (tau *. state.x_dot);
  state.x_dot <- state.x_dot +. (tau *. xacc);
  state.theta <- state.theta +. (tau *. state.theta_dot);
  state.theta_dot <- state.theta_dot +. (tau *. thetaacc);
  state.steps <- state.steps + 1;
  let terminated = out_of_bounds () in
  let truncated = state.steps >= 500 in
  let reward = if already_terminated then 0.0 else 1.0 in
  let observation =
    Rune.create Rune.float32 [| 4 |]
      [| state.x; state.x_dot; state.theta; state.theta_dot |]
  in
  let info = Info.set "steps" (Info.int state.steps) Info.empty in
  Env.transition ~observation ~reward ~terminated ~truncated ~info ()
(** [render state] formats the current state as one line of text, showing the
    pole angle in degrees rather than radians. *)
let render state =
  let { x; x_dot; theta; theta_dot; steps; _ } = state in
  let theta_degrees = theta *. 180. /. Float.pi in
  Printf.sprintf
    "CartPole: x=%.3f, x_dot=%.3f, theta=%.3f°, theta_dot=%.3f, steps=%d" x
    x_dot theta_degrees theta_dot steps
(** [make ~rng ()] builds a CartPole-v1 environment via [Env.create].

    A single mutable [state] record is shared by the reset, step and render
    callbacks. It starts with every component at zero and is randomized by
    [reset]. [rng] is passed to [Env.create] and is also stored in the record's
    key cell, which [reset] consumes. Closing the environment does nothing. *)
let make ~rng () =
  let shared =
    {
      x = 0.0;
      x_dot = 0.0;
      theta = 0.0;
      theta_dot = 0.0;
      steps = 0;
      rng = ref rng;
    }
  in
  let reset_cb env ?options () = reset env ?options () shared in
  let step_cb env action = step env action shared in
  let render_cb _ = Some (render shared) in
  let close_cb _ = () in
  Env.create ~id:"CartPole-v1" ~metadata ~rng ~observation_space ~action_space
    ~reset:reset_cb ~step:step_cb ~render:render_cb ~close:close_cb ()