Source file Comparison.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

(* This file is free software, part of Logtk. See file "license" for more details. *)

(** {1 Partial Ordering values} *)

(** Outcome of comparing two values under a partial order. *)
type t =
  | Lt  (** the first operand is below the second *)
  | Eq  (** the two operands are equal *)
  | Gt  (** the first operand is above the second *)
  | Incomparable  (** none of [Lt], [Eq], [Gt] could be established *)

(** Alias of [t], usable where [t] denotes some other type
    (e.g. inside the [PARTIAL_ORD] signature below, whose own [t] shadows it). *)
type comparison = t
(** [equal a b] is [true] iff [a] and [b] are the same constructor.
    Written as an explicit match instead of [Pervasives.(=)]: the
    [Pervasives] module is deprecated since OCaml 4.08 and removed in 5.0,
    and a direct match also avoids the polymorphic comparison primitive. *)
let equal (a : t) (b : t) : bool =
  match a, b with
  | Lt, Lt | Eq, Eq | Gt, Gt | Incomparable, Incomparable -> true
  | (Lt | Eq | Gt | Incomparable), _ -> false

(** Render a comparison as a short three-character token. *)
let to_string (cmp : t) : string =
  match cmp with
  | Incomparable -> "=?="
  | Eq -> "==="
  | Gt -> "=>="
  | Lt -> "=<="

(** Print [c] on the formatter [out], using the token from [to_string]. *)
let pp out c =
  let token = to_string c in
  CCFormat.string out token

(** [combine cmp1 cmp2] merges two partial comparisons of the same pair of
    values. [Incomparable] acts as "no information", so the other argument
    wins. Two comparisons that agree ([Eq]/[Eq], [Lt]/[Lt], [Gt]/[Gt])
    combine to their common value. Previously [Lt]/[Lt] and [Gt]/[Gt] fell
    into the catch-all and were wrongly rejected as inconsistent.
    @raise Invalid_argument if the two comparisons order the values
    differently (e.g. [Lt] with [Gt], or [Eq] with [Lt]). *)
let combine cmp1 cmp2 = match cmp1, cmp2 with
  | Eq, Eq
  | Eq, Incomparable | Incomparable, Eq -> Eq
  | Lt, Lt
  | Lt, Incomparable | Incomparable, Lt -> Lt
  | Gt, Gt
  | Gt, Incomparable | Incomparable, Gt -> Gt
  | Incomparable, Incomparable -> Incomparable
  (* the only remaining pairs are genuinely conflicting orders; listing them
     keeps the match exhaustive without a catch-all *)
  | Lt, (Eq | Gt) | Eq, (Lt | Gt) | Gt, (Lt | Eq) ->
    invalid_arg "inconsistent comparisons"

(** Swap [Lt] and [Gt]; [Eq] and [Incomparable] are returned unchanged. *)
let opp = function
  | Lt -> Gt
  | Gt -> Lt
  | (Eq | Incomparable) as unchanged -> unchanged

(** Map to an integer in the style of [compare]: [Lt] is [-1], [Gt] is [1],
    and both [Eq] and [Incomparable] collapse to [0]. *)
let to_total = function
  | Eq | Incomparable -> 0
  | Gt -> 1
  | Lt -> -1

(** Convert an integer in the style of [compare]: [0] gives [Eq], a positive
    value gives [Gt], a negative value gives [Lt]. Never yields [Incomparable]. *)
let of_total ord =
  if ord = 0 then Eq
  else if ord > 0 then Gt
  else Lt

(** Lexicographic step: keep [a] unless it is [Incomparable], in which case
    fall back to [b]. [Eq] is kept as-is; it does not defer to [b].
    NOTE(review): this differs from [(@>>)], which continues on [Eq]. *)
let lexico a b =
  match a with
  | Lt | Eq | Gt -> a
  | Incomparable -> b

(** A function partially comparing two values of type ['elt]. *)
type 'elt comparator = 'elt -> 'elt -> t

(** Infix form of [lexico]: [a ++ b] is [lexico a b]. *)
let (++) a b = lexico a b

(** [(f @>> g) x y] compares [x] and [y] with [f]. [g] is consulted on the
    same pair only when [f] answers [Eq]; otherwise [f]'s answer is final. *)
let (@>>) f g x y =
  match f x y with
  | (Lt | Gt | Incomparable) as decided -> decided
  | Eq -> g x y

(** A chain of comparators built with [>>>] and terminated by [last].
    Each link consumes its own pair of arguments. *)
type ('elt, 'res) combination = {
  call : 'elt -> 'elt -> 'res;
  (** compare a pair, then continue down the chain *)
  ignore : t -> 'elt -> 'elt -> 'res;
  (** skip a pair without inspecting it, propagating the given result *)
}

(** Terminate a chain with comparator [f]. Running it calls [f] directly;
    skipping it returns the already-decided result unchanged. *)
let last f =
  let skip res _ _ = res in
  { call = f; ignore = skip }

(** Prepend comparator [f] to chain [g]. Running the result compares its
    first pair with [f]:
    - if the answer is [Incomparable], it returns [g]'s runner for the next pair;
    - otherwise, the answer is forwarded through [g]'s skipping path.

    Skipping the result skips [f]'s pair and forwards through [g]. *)
let (>>>) f g =
  let forward res = g.ignore res in
  {
    call = (fun x y ->
      match f x y with
      | Incomparable -> g.call
      | (Lt | Eq | Gt) as res -> forward res);
    ignore = (fun res _ _ -> forward res);
  }

(** Run a combination chain on its first pair of arguments. *)
let call { call = run; _ } x y = run x y

(** [dominates f l1 l2] holds when every element [y] of [l2] has some
    element [x] of [l1] with [f x y] equal to [Gt]. It is vacuously [true]
    when [l2] is empty. Both lists are scanned left to right, stopping as
    soon as the answer is known. *)
let dominates f l1 l2 =
  let beats y x =
    match f x y with
    | Gt -> true
    | Lt | Eq | Incomparable -> false
  in
  List.for_all (fun y -> List.exists (beats y) l1) l2

(** Signature of a type equipped with a partial comparison. *)
module type PARTIAL_ORD = sig
  type t

  val partial_cmp : t -> t -> comparison
  (** Compare two values of [t]; the result type allows [Incomparable]. *)
end