Source file main.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
open Ppxlib
open Backend.Compiler_modules
open Core_kernel
open Expect_test_common.Std
open Expect_test_matcher.Std
open Mlt_parser

module Clflags  = Ocaml_common.Clflags
module Compmisc = Ocaml_common.Compmisc
module Printast = Ocaml_common.Printast
module Warnings = Ocaml_common.Warnings

(* Parse [contents] with [Parse.use_file].  The lexer position is initialised
   to line 1, offset 0 of [fname], and the compiler's [Location.input_name] is
   set to [fname] as a side effect. *)
let parse_contents ~fname contents =
  let lexbuf = Lexing.from_string contents in
  let initial_position =
    { Lexing.pos_fname = fname; pos_lnum = 1; pos_bol = 0; pos_cnum = 0 }
  in
  lexbuf.lex_curr_p <- initial_position;
  Ocaml_common.Location.input_name := fname;
  Parse.use_file lexbuf
;;

(* State behind the [#reset_line_numbers] toplevel directive: the directive
   raises the [reset_line_numbers] flag, and [line_numbers_delta] is the offset
   that [exec_phrase] adds to the line numbers of the phrases it runs. *)
let reset_line_numbers = ref false
let line_numbers_delta = ref 0
let () =
  let request_reset () = reset_line_numbers := true in
  Caml.Hashtbl.add Toploop.directive_table
    "reset_line_numbers"
    (Directive_none request_reset)
;;

(* Whether real line numbers are shown in reported locations.  Toggled by the
   [#print_line_numbers] toplevel directive; off by default. *)
let print_line_numbers = ref false
let () =
  let set_print_line_numbers enabled = print_line_numbers := enabled in
  Caml.Hashtbl.add Toploop.directive_table
    "print_line_numbers"
    (Directive_bool set_print_line_numbers)
;;

(* Print [line] on [ppf] when [print_line_numbers] is set, and the placeholder
   "_" otherwise. *)
let print_line_number ppf line =
  match !print_line_numbers with
  | true  -> Format.fprintf ppf "%d" line
  | false -> Format.pp_print_string ppf "_"
;;

[%%if ocaml_version < (4, 08, 0)]
(* Print the header "Line N, characters A-B:" for [loc] on [ppf], followed by a
   newline and a flush.  The line number goes through [print_line_number], and
   the character range is omitted when the start column is negative. *)
let print_loc ppf (loc : Location.t) =
  let first = loc.loc_start in
  let start_col = first.pos_cnum - first.pos_bol in
  let end_col = loc.loc_end.pos_cnum - first.pos_cnum + start_col in
  Format.fprintf ppf "Line %a" print_line_number first.pos_lnum;
  if start_col >= 0 then
    Format.fprintf ppf ", characters %d-%d" start_col end_col;
  Format.fprintf ppf ":@."
;;

(* Print a compiler error on [ppf]: its location, then "Error: <msg>", then
   each sub-error on a new line inside an indented box. *)
let rec error_reporter ppf ({loc; msg; sub; if_highlight=_} : Ocaml_common.Location.error) =
  let report_sub_error sub_error =
    Format.fprintf ppf "@\n@[<2>%a@]" error_reporter sub_error
  in
  print_loc ppf loc;
  Format.fprintf ppf "Error: %s" msg;
  List.iter sub ~f:report_sub_error
;;
[%%endif]

[%%if ocaml_version < (4, 06, 0)]
(* Print warning [w] located at [loc] on [ppf]; inactive warnings print
   nothing. *)
let warning_printer loc ppf w =
  match Warnings.is_active w with
  | false -> ()
  | true ->
    print_loc ppf loc;
    Format.fprintf ppf "Warning %a@." Warnings.print w
[%%elif ocaml_version < (4, 08, 0)]
(* Print warning [w] located at [loc] on [ppf]; inactive warnings print
   nothing.  Warnings reported as errors get an "Error (Warning N)" prefix. *)
let warning_printer loc ppf w =
  match Warnings.report w with
  | `Active { Warnings. number; message; is_error; sub_locs = _ } ->
    print_loc ppf loc;
    if is_error then
      Format.fprintf ppf "Error (Warning %d): %s@." number message
    else
      Format.fprintf ppf "Warning %d: %s@." number message
  | `Inactive -> ()
[%%elif ocaml_version >= (4, 08, 0)]
(* From 4.08 on, the compiler's default warning and alert reporters are used
   unchanged. *)
let warning_reporter = Ocaml_common.Location.default_warning_reporter
let alert_reporter = Ocaml_common.Location.default_alert_reporter
[%%endif]
;;

[%%if ocaml_version >= (4, 08, 0)]
(* Build a report printer identical to the compiler's default one except for
   the way main and sub-message locations are printed: as
   "Line N, characters A-B:" where the line number goes through
   [print_line_number] and the character range is omitted when the start
   column is negative. *)
let report_printer () =
  let print_loc _ _report ppf loc =
    let first = loc.loc_start in
    let start_col = first.pos_cnum - first.pos_bol in
    let end_col = loc.loc_end.pos_cnum - first.pos_cnum + start_col in
    Format.fprintf ppf "Line %a" print_line_number first.pos_lnum;
    if start_col >= 0 then
      Format.fprintf ppf ", characters %d-%d" start_col end_col;
    Format.fprintf ppf ":@."
  in
  let default_printer = Ocaml_common.Location.default_report_printer () in
  { default_printer with
    Ocaml_common.Location.pp_main_loc = print_loc
  ; pp_submsg_loc = print_loc
  }
[%%endif]

(* A reference packed together with a value of the type it holds. *)
type var_and_value = V : 'a ref * 'a -> var_and_value

(* [protect_vars vars ~f] assigns to each reference in [vars] the value paired
   with it, runs [f], and puts the previous contents of the references back
   through [protect ~finally]. *)
let protect_vars vars ~f =
  let assign (V (cell, value)) = cell := value in
  let saved = List.map vars ~f:(fun (V (cell, _)) -> V (cell, !cell)) in
  List.iter vars ~f:assign;
  protect ~f ~finally:(fun () -> List.iter saved ~f:assign)
;;

[%%if ocaml_version < (4, 08, 0)]
(* Run [f] with the compiler's warning formatter pointing at [ppf] and with
   this file's warning printer and error reporter installed; the previous
   settings are put back by [protect_vars]. *)
let capture_compiler_stuff ppf ~f =
  let overrides =
    [ V (Ocaml_common.Location.formatter_for_warnings, ppf)
    ; V (Ocaml_common.Location.warning_printer, warning_printer)
    ; V (Ocaml_common.Location.error_reporter, error_reporter)
    ]
  in
  protect_vars overrides ~f
[%%else]
(* Run [f] with the compiler's warning formatter pointing at [ppf] and with
   this file's warning reporter, report printer and alert reporter installed;
   the previous settings are put back by [protect_vars]. *)
let capture_compiler_stuff ppf ~f =
  let overrides =
    [ V (Ocaml_common.Location.formatter_for_warnings, ppf)
    ; V (Ocaml_common.Location.warning_reporter, warning_reporter)
    ; V (Ocaml_common.Location.report_printer, report_printer)
    ; V (Ocaml_common.Location.alert_reporter, alert_reporter)
    ]
  in
  protect_vars overrides ~f
[%%endif]
;;

(* Run the registered ppx rewriters ([Driver.map_structure]) over a definition
   phrase and migrate the result to the selected AST version.  Directive
   phrases are returned untouched. *)
let apply_rewriters phrase =
  match phrase with
  | Ptop_def structure ->
    let rewritten =
      Driver.map_structure structure
      |> Migrate_parsetree.Driver.migrate_some_structure
           (module Ppxlib_ast.Selected_ast)
    in
    Ptop_def rewritten
  | Ptop_dir _ -> phrase
;;

(* Flag passed as first argument to [Toploop.execute_phrase] by [exec_phrase].
   Toggled by the [#verbose] toplevel directive; off by default. *)
let verbose = ref false
let () =
  let set_verbose enabled = verbose := enabled in
  Caml.Hashtbl.add Toploop.directive_table
    "verbose"
    (Directive_bool set_verbose)
;;

(* AST mapper whose context is an integer offset: every position it visits has
   that offset added to its line number. *)
let shift_line_numbers = object
  inherit [int] Ast_traverse.map_with_context
  method! position offset p =
    { p with pos_lnum = offset + p.pos_lnum }
end

(* Execute one toplevel [phrase], sending the toplevel's output to [ppf], and
   return the result of [Toploop.execute_phrase].

   Steps, in order:
   - if a line-number reset was requested and [phrase] is a non-empty
     definition, clear the request and set [line_numbers_delta] so that the
     first item of [phrase] lands on line 1;
   - shift the line numbers of [phrase] by [line_numbers_delta] when non-zero;
   - apply the ppx rewriters and convert to the compiler's AST;
   - honour [Clflags.dump_parsetree] / [Clflags.dump_source];
   - run the phrase. *)
let exec_phrase ppf phrase =
  (match !reset_line_numbers, phrase with
   | true, Ptop_def (first_item :: _) ->
     reset_line_numbers := false;
     line_numbers_delta := 1 - first_item.pstr_loc.loc_start.pos_lnum
   | _, _ -> ());
  let delta = !line_numbers_delta in
  let shifted =
    if delta = 0
    then phrase
    else shift_line_numbers#toplevel_phrase delta phrase
  in
  let rewritten = apply_rewriters shifted in
  let module Js = Ppxlib_ast.Selected_ast in
  let ocaml_phrase = Js.to_ocaml Toplevel_phrase rewritten in
  if !Clflags.dump_parsetree then Printast.top_phrase ppf ocaml_phrase;
  if !Clflags.dump_source then Pprintast.top_phrase ppf rewritten;
  Toploop.execute_phrase !verbose ppf ocaml_phrase
;;

(* Number of newline characters an expectation body accounts for: 0 for
   [Unreachable] and [Output], the newlines of the string for [Exact], and for
   [Pretty] the newlines found in the surrounding spaces plus the line breaks
   between the lines of a multi-line body. *)
let count_newlines : _ Cst.t Expectation.Body.t -> int =
  let newlines_in text = String.count text ~f:(Char.(=) '\n') in
  function
  | Unreachable | Output -> 0
  | Exact text -> newlines_in text
  | Pretty cst ->
    begin match cst with
    | Empty spaces -> newlines_in spaces
    | Single_line line -> newlines_in line.trailing_spaces
    | Multi_lines block ->
      let line_breaks = List.length block.lines - 1 in
      line_breaks
      + newlines_in block.leading_spaces
      + newlines_in block.trailing_spaces
    end
;;

(* Put a CST into the layout used for corrections: an empty body becomes a
   single newline, and any other body becomes a multi-line block with no
   indentation, delimited by one leading and one trailing newline.  A
   single-line body keeps its [orig] and [data] but loses its trailing blanks;
   the lines of a multi-line body go through [Cst.Line.strip]. *)
let canonicalize_cst : 'a Cst.t -> 'a Cst.t =
  let canonical_block lines : _ Cst.t =
    Multi_lines
      { leading_spaces  = "\n"
      ; trailing_spaces = "\n"
      ; indentation     = ""
      ; lines
      }
  in
  function
  | Empty _ -> Empty "\n"
  | Single_line line ->
    canonical_block
      [ Not_blank { trailing_blanks = ""; orig = line.orig; data = line.data } ]
  | Multi_lines block ->
    canonical_block (List.map block.lines ~f:Cst.Line.strip)
;;

(* Compare the [actual] output with the [expect]ed body using
   [Reconcile.expectation_body] (no default indentation, no single-line
   padding).  A correction is returned with its pretty body canonicalized by
   [canonicalize_cst]. *)
let reconcile ~actual ~expect ~allow_output_patterns : _ Reconcile.Result.t =
  let outcome =
    Reconcile.expectation_body
      ~expect
      ~actual
      ~default_indent:0
      ~pad_single_line:false
      ~allow_output_patterns
  in
  match outcome with
  | Correction body ->
    Correction (Expectation.Body.map_pretty body ~f:canonicalize_cst)
  | Match -> Match
;;

(* Run [f ~capture] with the process' stdout and stderr file descriptors both
   redirected to a fresh temporary file.

   [capture buf] flushes the [stdout] and [stderr] channels and appends to
   [buf] the bytes written to the temporary file since the previous call.

   The cleanup (restoring the original descriptors, closing and removing the
   temporary file) is registered through [protect ~finally]. *)
let redirect ~f =
  let saved_stdout = Unix.dup Unix.stdout in
  let saved_stderr = Unix.dup Unix.stderr in
  let tmp_path = Caml.Filename.temp_file "expect-test" "stdout" in
  let sink = Unix.openfile tmp_path [O_WRONLY; O_CREAT; O_TRUNC] 0o600 in
  Unix.dup2 sink Unix.stdout;
  Unix.dup2 sink Unix.stderr;
  let reader = In_channel.create tmp_path in
  (* Offset in the temporary file up to which output was already captured *)
  let consumed = ref 0 in
  let capture buf =
    Out_channel.flush stdout;
    Out_channel.flush stderr;
    let written = Unix.lseek sink 0 SEEK_CUR in
    let fresh = written - !consumed in
    consumed := written;
    Caml.Buffer.add_channel buf reader fresh
  in
  let restore () =
    In_channel.close reader;
    Unix.close sink;
    Unix.dup2 saved_stdout Unix.stdout;
    Unix.dup2 saved_stderr Unix.stderr;
    Unix.close saved_stdout;
    Unix.close saved_stderr;
    Sys.remove tmp_path
  in
  protect ~finally:restore ~f:(fun () -> f ~capture)
;;

(* Outcome of comparing the output of one chunk with its expectation. *)
type chunk_result =
  (* The output agreed with the expectation *)
  | Matched
  (* The output differed; carries the corrected expectation body *)
  | Didn't_match of Fmt.t Cst.t Expectation.Body.t

(* Run every chunk of the expect-test file [fname] (whose text is
   [file_contents]) through the toplevel and compare what each chunk printed
   with its expectation.

   [capture buf] is called once per group of phrases and is expected to append
   to [buf] what was written to stdout/stderr since the previous call (see
   [redirect]).

   Returns [(results, trailing)]:
   - [results] holds one [(chunk, actual_output, chunk_result)] triple per
     chunk;
   - [trailing] is [None] when [split_chunks] returns no trailing code, and
     otherwise [Some (pos_start, actual_output, reconcile_result, part)], where
     the output of the trailing code is reconciled against an empty
     expectation.

   Side effects: resets [reset_line_numbers] and [line_numbers_delta] before
   running, and updates [line_numbers_delta] after each mismatching chunk by
   the difference in newlines between the correction and the original
   expectation. *)
let eval_expect_file fname ~file_contents ~capture ~allow_output_patterns =
  (* 4.03: Warnings.reset_fatal (); *)
  let chunks, trailing_code =
    parse_contents ~fname file_contents
    |> split_chunks ~fname ~allow_output_patterns
  in
  let buf = Buffer.create 1024 in
  let ppf = Format.formatter_of_buffer buf in
  reset_line_numbers := false;
  line_numbers_delta := 0;
  (* Terminate the current contents of [buf] with a newline unless they already
     end with one.  The length is read afresh on every call so that the check
     always looks at the actual last character of the buffer. *)
  let ensure_trailing_newline () =
    let len = Buffer.length buf in
    if len > 0 && Buffer.nth buf (len - 1) <> '\n' then
      Buffer.add_char buf '\n'
  in
  (* Execute [phrases] and return everything they produced: first what the
     toplevel printed on [ppf], then what [capture] collected. *)
  let exec_phrases phrases =
    (* So that [%expect_exact] nodes look nice *)
    Buffer.add_char buf '\n';
    List.iter phrases ~f:(fun phrase ->
      (* On failure, report the exception and undo the type-checker changes
         made since the snapshot *)
      let snap = Ocaml_common.Btype.snapshot () in
      match exec_phrase ppf phrase with
      | (_ : bool) -> ()
      | exception exn ->
        Location.report_exception ppf exn;
        Ocaml_common.Btype.backtrack snap
    );
    Format.pp_print_flush ppf ();
    (* So that [%expect_exact] nodes look nice *)
    ensure_trailing_newline ();
    capture buf;
    (* The captured output may itself lack a final newline, so check again
       against the new end of the buffer. *)
    ensure_trailing_newline ();
    let s = Buffer.contents buf in
    Buffer.clear buf;
    s
  in
  let results =
    capture_compiler_stuff ppf ~f:(fun () ->
      List.map chunks ~f:(fun chunk ->
        let actual = exec_phrases chunk.phrases in
        match
          reconcile ~actual ~expect:chunk.expectation.body ~allow_output_patterns
        with
        | Match -> (chunk, actual, Matched)
        | Correction correction ->
          (* The corrected expectation may span a different number of lines
             than the original one; account for it in later locations. *)
          line_numbers_delta :=
            !line_numbers_delta +
            count_newlines correction -
            count_newlines chunk.expectation.body;
          (chunk, actual, Didn't_match correction)))
  in
  let trailing =
    match trailing_code with
    | None -> None
    | Some (phrases, pos_start, part) ->
      let actual, result =
        capture_compiler_stuff ppf ~f:(fun () ->
          let actual = exec_phrases phrases in
          (actual, reconcile ~actual ~expect:(Pretty Cst.empty)
                     ~allow_output_patterns))
      in
      Some (pos_start, actual, result, part)
  in
  (results, trailing)
;;

(* Turn the output of [eval_expect_file] into a [Matcher.Test_correction]
   covering the whole of [fname]: one node correction per chunk that did not
   match, plus the reconciliation result of the trailing code ([Match] when
   there is none). *)
let interpret_results_for_diffing ~fname ~file_contents (results, trailing) =
  let whole_file =
    { filename    = File.Name.of_string fname
    ; line_number = 1
    ; line_start  = 0
    ; start_pos   = 0
    ; end_pos     = String.length file_contents
    }
  in
  let corrections =
    List.filter_map results ~f:(fun (chunk, _, result) ->
      match result with
      | Didn't_match correction ->
        Some (chunk.expectation,
              Matcher.Test_correction.Node_correction.Correction correction)
      | Matched -> None)
  in
  let trailing_output =
    match trailing with
    | Some (_, _, correction, _) -> correction
    | None -> Reconcile.Result.Match
  in
  Matcher.Test_correction.make
    ~location:whole_file
    ~corrections
    ~trailing_output
    ~uncaught_exn:Match
;;

module T = Toplevel_expect_test_types

(* Extract [file_contents] between offsets [start] (inclusive) and [stop]
   (exclusive), after skipping any leading run of spaces, tabs, newlines and
   ";;" tokens. *)
let sub_file file_contents ~start ~stop =
  let is_blank c = c = ' ' || c = '\t' || c = '\n' in
  let at_double_semicolon pos =
    pos + 1 < stop
    && file_contents.[pos] = ';'
    && file_contents.[pos + 1] = ';'
  in
  let rec skip_prefix pos =
    if pos >= stop then pos
    else if is_blank file_contents.[pos] then skip_prefix (pos + 1)
    else if at_double_semicolon pos then skip_prefix (pos + 2)
    else pos
  in
  let first = skip_prefix start in
  String.sub file_contents ~pos:first ~len:(stop - first)
;;

(* Build the document printed in s-expression mode from the output of
   [eval_expect_file].

   Every chunk, followed by the trailing code if any, yields an entry pairing
   its source text (extracted from [file_contents] with [sub_file]) with the
   toplevel's response.  Consecutive entries carrying the same part are grouped
   into one [T.Part.t], whose name is "" when the part is [None].  [matched] is
   true when every chunk matched and the trailing code, if any, reconciled to
   [Match]. *)
let generate_doc_for_sexp_output ~fname:_ ~file_contents (results, trailing) =
  let entry part ~start ~stop toplevel_response =
    (part,
     { T.Chunk.
       ocaml_code = sub_file file_contents ~start ~stop
     ; toplevel_response
     })
  in
  let chunk_entries =
    List.map results ~f:(fun (chunk, resp, _) ->
      let loc = chunk.phrases_loc in
      entry chunk.part
        ~start:loc.loc_start.pos_cnum
        ~stop:loc.loc_end.pos_cnum
        resp)
  in
  let trailing_entries =
    match trailing with
    | None -> []
    | Some (pos_start, resp, _, part) ->
      [ entry part
          ~start:pos_start.Lexing.pos_cnum
          ~stop:(String.length file_contents)
          resp
      ]
  in
  let parts =
    List.group (chunk_entries @ trailing_entries)
      ~break:(fun (a, _) (b, _) -> a <> b)
    |> List.map ~f:(fun group ->
      let name =
        match group with
        | (Some name, _) :: _ -> name
        | (None, _) :: _ | [] -> ""
      in
      { T.Part. name; chunks = List.map group ~f:snd })
  in
  let all_chunks_matched =
    List.for_all results ~f:(fun (_, _, outcome) ->
      match outcome with
      | Matched -> true
      | Didn't_match _ -> false)
  in
  let trailing_matched =
    match trailing with
    | None | Some (_, _, Reconcile.Result.Match, _) -> true
    | Some (_, _, Reconcile.Result.Correction _, _) -> false
  in
  { T.Document. parts; matched = all_chunks_matched && trailing_matched }
;;

(* Diff command set by the [-diff-cmd] option and passed on to
   [Ppxlib_print_diff.print]; [None] until the option is given. *)
let diff_command : string option ref = ref None

(* Run the expect test in [fname] and act on the result.  Returns [true] when
   every expectation matched, or when the file was corrected in place, and
   [false] otherwise.

   - With [sexp_output], the results are first printed on stdout as an
     s-expression.
   - When a correction is needed it is written to [fname] itself if [in_place]
     is set, and to [fname ^ ".corrected"] otherwise; in the latter case a diff
     between the two files is printed unless [sexp_output] is set.
   - A ".corrected" file left behind is removed when the file was corrected in
     place, and when everything matched without [in_place].
   - [use_absolute_path] makes the diff refer to the files by paths prefixed
     with the working directory in effect when this function was entered. *)
let process_expect_file fname ~use_color ~in_place ~sexp_output ~use_absolute_path
      ~allow_output_patterns =
  (* The user code run below may change the working directory: record it now *)
  let cwd = Sys.getcwd () in
  let file_contents = In_channel.read_all fname in
  let result =
    redirect ~f:(eval_expect_file fname ~file_contents ~allow_output_patterns)
  in
  if sexp_output then begin
    let doc = generate_doc_for_sexp_output ~fname ~file_contents result in
    Format.printf "%a@." Sexp.pp_hum (T.Document.sexp_of_t doc)
  end;
  let corrected_fname = fname ^ ".corrected" in
  let remove_corrected () =
    if Sys.file_exists corrected_fname then
      Sys.remove corrected_fname
  in
  let print_diff () =
    let shown_path file =
      if use_absolute_path then Filename.concat cwd file else file
    in
    Ppxlib_print_diff.print ()
      ~file1:(shown_path fname)
      ~file2:(shown_path corrected_fname)
      ~use_color
      ?diff_command:!diff_command
  in
  match interpret_results_for_diffing ~fname ~file_contents result with
  | Match ->
    if not in_place then remove_corrected ();
    true
  | Correction correction ->
    let target = if in_place then fname else corrected_fname in
    Matcher.write_corrected [correction]
      ~file:target
      ~file_contents ~mode:Toplevel_expect_test;
    if in_place then begin
      remove_corrected ();
      true
    end else begin
      if not sexp_output then print_diff ();
      false
    end
;;

(* Set a fixed list of environment variables (locale, time zone, terminal,
   proxies, ...) in the current process. *)
let setup_env () =
  (* Same as what run-tests.py does, to get repeatable output *)
  let fixed_environment =
    [ "LANG"        , "C"
    ; "LC_ALL"      , "C"
    ; "LANGUAGE"    , "C"
    ; "TZ"          , "GMT"
    ; "EMAIL"       , "Foo Bar <foo.bar@example.com>"
    ; "CDPATH"      , ""
    ; "COLUMNS"     , "80"
    ; "GREP_OPTIONS", ""
    ; "http_proxy"  , ""
    ; "no_proxy"    , ""
    ; "NO_PROXY"    , ""
    ; "TERM"        , "xterm"
    ]
  in
  List.iter fixed_environment ~f:(fun (name, value) -> Unix.putenv name value)

(* Warning specification handed to [Warnings.parse_options] by [setup_config].
   The variant used from OCaml 4.08 on additionally lists warning 66. *)
[%%if ocaml_version < (4, 08, 0)]
let warnings = "@a-4-29-40-41-42-44-45-48-58"
[%%else]
let warnings = "@a-4-29-40-41-42-44-45-48-58-66"
[%%endif]

(* Configure the compiler flags used while running tests: real paths off,
   strict sequences and strict formats on, [unsafe_string] as reported by the
   backend, and the warning set described by [warnings]. *)
let setup_config () =
  Clflags.real_paths := false;
  Clflags.strict_sequence := true;
  Clflags.strict_formats := true;
  let unsafe_string = Backend.unsafe_string () in
  Clflags.unsafe_string := unsafe_string;
  Warnings.parse_options false warnings
;;

(* Command-line settings, updated by the options listed in [args] and read by
   [main]. *)
(* Cleared by [-no-color] *)
let use_color   = ref true
(* Set by [-in-place] *)
let in_place    = ref false
(* Set by [-sexp] *)
let sexp_output = ref false
(* Set by [-absolute-path] *)
let use_absolute_path = ref false
(* Set by [-allow-output-patterns] *)
let allow_output_patterns = ref false

(* Version-independent wrapper around [Compmisc.init_path], which is called
   with [true] before OCaml 4.09 and with [()] from 4.09 on. *)
[%%if ocaml_version < (4, 09, 0)]
let init_path () = Compmisc.init_path true
[%%else]
let init_path () = Compmisc.init_path ()
[%%endif]

(* Run the expect test contained in [fname] and terminate the process: exit
   code 0 when [process_expect_file] reports success, 1 otherwise.

   Before running, the environment and compiler flags are set up, [Sys.argv] is
   overridden with the command line starting at index [!Arg.current], and the
   toplevel (paths, initial environment, backend) is initialised. *)
let main fname =
  let first_arg = !Arg.current in
  let cmd_line =
    Array.sub Sys.argv ~pos:first_arg ~len:(Array.length Sys.argv - first_arg)
  in
  setup_env ();
  setup_config ();
  Core.Sys.override_argv cmd_line;
  Toploop.set_paths ();
  init_path ();
  Toploop.toplevel_env := Compmisc.initial_env ();
  Sys.interactive := false;
  Backend.init ();
  let success =
    process_expect_file fname
      ~use_color:!use_color
      ~in_place:!in_place
      ~sexp_output:!sexp_output
      ~use_absolute_path:!use_absolute_path
      ~allow_output_patterns:!allow_output_patterns
  in
  if success then exit 0 else exit 1
;;

(* Command-line options, aligned for display by [Arg.align].  Each option
   updates one of the setting references defined in this file ([use_color],
   [in_place], [diff_command], [sexp_output], [use_absolute_path],
   [allow_output_patterns]). *)
let args =
  Arg.align
    [ (* [-no-color] clears [use_color], so its help text must say that it
         turns colors off *)
      "-no-color", Clear use_color, " Don't produce colored diffs"
    ; "-in-place", Set in_place,    " Overwrite file in place"
    ; "-diff-cmd", String (fun s -> diff_command := Some s), " Diff command"
    ; "-sexp"    , Set sexp_output, " Output the result as a s-expression instead of diffing"
    ; "-absolute-path", Set use_absolute_path, " Use absolute path in diff-error message"
    ; "-allow-output-patterns", Set allow_output_patterns,
      " Allow output patterns in tests expectations";
    ]

(* Program entry point: parse the command line, running the per-file [main]
   defined earlier on each anonymous argument (that function exits the process
   itself).  If argument parsing returns, the usage line is printed on stderr
   and the process exits with code 2.  Any exception is reported with
   [Location.report_exception] on stderr and also leads to exit code 2. *)
let main () =
  let program = Filename.basename Sys.argv.(0) in
  let usage = Printf.sprintf "Usage: %s [OPTIONS] FILE [ARGS]\n" program in
  let options_header = usage ^ "\nOptions are:" in
  try
    Arg.parse args main options_header;
    Out_channel.output_string Out_channel.stderr usage;
    exit 2
  with exn ->
    Location.report_exception Format.err_formatter exn;
    exit 2
;;