(* Source file linExpr.ml *)
(*
 * OSDP (OCaml SDP) is an OCaml frontend library to semi-definite
 * programming (SDP) solvers.
 * Copyright (C) 2012, 2014  P. Roux and P.L. Garoche
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *)

(* Signature of affine expressions, i.e. a constant term plus a sum of
 * identifiers weighted by coefficients.  The per-function descriptions
 * below follow the implementation provided by the functor Make. *)
module type S = sig
  (* Scalars used for the coefficients and for the constant term. *)
  module  Coeff : Scalar.S
  (* Type of affine expressions. *)
  type t
  (* [of_list l c] builds the expression whose linear terms are listed in
   * [l] and whose constant term is [c].  The coefficients of an identifier
   * listed several times are added together. *)
  val of_list : (Ident.t * Coeff.t) list -> Coeff.t -> t
  (* [to_list e] returns the linear terms of [e] paired with its constant
   * term. *)
  val to_list : t -> (Ident.t * Coeff.t) list * Coeff.t
  (* [var id] is the expression made of the single identifier [id], with
   * coefficient [Coeff.one] and constant term [Coeff.zero]. *)
  val var : Ident.t -> t
  (* [const c] is the expression with constant term [c] and no linear
   * term. *)
  val const : Coeff.t -> t
  (* [mult_scalar s e] multiplies the constant term and every coefficient
   * of [e] by [s]. *)
  val mult_scalar : Coeff.t -> t -> t
  (* [add e1 e2] adds constant terms and coefficients pointwise. *)
  val add : t -> t -> t
  (* [sub e1 e2] subtracts constant terms and coefficients pointwise. *)
  val sub : t -> t -> t
  (* [replace e l] substitutes, for each pair [(id, r)] of [l] such that
   * [id] has a term in [e], the expression [r] for [id].  The substitution
   * is simultaneous: identifiers brought in by the replacement expressions
   * are not substituted again.  When [id] is listed several times in [l],
   * only its first pair is used. *)
  val replace : t -> (Ident.t * t) list -> t
  (* [remove e id] drops the linear term of [id] from [e], if any. *)
  val remove : t -> Ident.t -> t
  (* Comparison function: constant terms are compared first, then the
   * linear parts. *)
  val compare : t -> t -> int
  (* [is_var e] returns [Some (id, c)] when [e] has a zero constant term
   * and exactly one linear term, [c] being the coefficient of [id] (not
   * necessarily [Coeff.one]), and [None] otherwise. *)
  val is_var : t -> (Ident.t * Coeff.t) option
  (* [is_const e] returns [Some c] when [e] has no linear term, [c] being
   * its constant term, and [None] otherwise. *)
  val is_const : t -> Coeff.t option
  (* [choose e] returns one linear term of [e], or [None] when [e] has no
   * linear term. *)
  val choose : t -> (Ident.t * Coeff.t) option
  (* Pretty printer for expressions. *)
  val pp : Format.formatter -> t -> unit
end

(* Functor building affine expressions whose coefficients and constant
 * term are scalars of SC.  An expression is stored as its constant term
 * together with a map from identifiers to their coefficients. *)
module Make (SC : Scalar.S) : S with module Coeff = SC = struct
  module Coeff = SC

  module IM = Ident.Map

  (* type invariant: lin is sorted by Ident.compare
   * and doesn't contain any zero coefficient *)
  type t = { const : Coeff.t; lin : Coeff.t IM.t }

  (* [add_term id c m] adds [c] to the coefficient of [id] in [m] (an
   * unbound identifier counts as zero).  The binding is dropped when the
   * sum is zero, which is what keeps the type invariant. *)
  let add_term id c m =
    let c = try Coeff.(c + IM.find id m) with Not_found -> c in
    if Coeff.(c <> zero) then IM.add id c m else IM.remove id m

  (* [of_list l c]: expression with linear terms [l] and constant term [c].
   * Coefficients of a repeated identifier are summed; the zero test is
   * performed on the sum (not only on each input coefficient) so that
   * terms cancelling each other, e.g. x and -x, leave no zero binding. *)
  let of_list l c =
    { const = c;
      lin = List.fold_left (fun m (id, c) -> add_term id c m) IM.empty l }

  (* [to_list a]: bindings of the linear part, paired with the constant. *)
  let to_list a = IM.bindings a.lin, a.const

  (* [var id]: the expression 1 * id + 0. *)
  let var id = { const = Coeff.zero; lin = IM.singleton id Coeff.one }

  (* [const c]: the expression reduced to the constant [c]. *)
  let const c = { const = c; lin = IM.empty }

  (* [mult_scalar s a]: multiplies the constant and every coefficient of
   * [a] by [s].  A zero [s] yields the constant zero expression.  Products
   * that evaluate to zero (possible with inexact scalars) are dropped in
   * order to preserve the type invariant. *)
  let mult_scalar s a =
    if Coeff.(s = zero) then const Coeff.zero else
      let scale id c m =
        let c = Coeff.(s * c) in
        if Coeff.(c <> zero) then IM.add id c m else m in
      { const = Coeff.(s * a.const);
        lin = IM.fold scale a.lin IM.empty }

  (* [merge_lin f m1 m2] applies [f] pointwise to the coefficients of [m1]
   * and [m2], a missing coefficient being read as zero, and keeps only the
   * non-zero results. *)
  let merge_lin f m1 m2 =
    let opt s = if Coeff.(s <> zero) then Some s else None in
    IM.merge
      (fun _ c1 c2 ->
       match c1, c2 with
       | None, None -> None
       | Some c1, None -> opt (f c1 Coeff.zero)
       | None, Some c2 -> opt (f Coeff.zero c2)
       | Some c1, Some c2 -> opt (f c1 c2))
      m1 m2

  (* [map2 f a1 a2] applies [f] to the constant terms and, pointwise, to
   * the linear parts of [a1] and [a2]. *)
  let map2 f a1 a2 =
    { const = f a1.const a2.const; lin = merge_lin f a1.lin a2.lin }
  let add = map2 Coeff.add
  let sub = map2 Coeff.sub

  (* [replace l ll]: simultaneous substitution in [l] of the identifiers
   * bound in [ll] by their associated expressions. *)
  let replace l ll =
    (* First pass: detach from [l] the term of each replaced identifier and
     * remember its coefficient next to the replacement expression.  An
     * identifier absent from [l], or already detached by an earlier pair,
     * is skipped. *)
    let m, ll =
      List.fold_left
        (fun (m, ll) (id, l') ->
           try
             let c = IM.find id m in
             IM.remove id m, (c, l') :: ll
           with Not_found -> m, ll)
        (l.lin, []) ll in
    (* Second pass: add c * l' to what is left, for each detached term.
     * Detaching everything beforehand is what makes the substitution
     * simultaneous. *)
    List.fold_left
      (fun l (c, l') ->
         let m =
           IM.fold
             (fun id c' m -> add_term id Coeff.(c * c') m)
             l'.lin l.lin in
         { const = Coeff.(l.const + c * l'.const); lin = m })
      { const = l.const; lin = m } ll

  (* [remove l i]: [l] without its term in [i]. *)
  let remove l i = { const = l.const; lin = IM.remove i l.lin }

  (* Compares the constant terms first, then the linear parts. *)
  let compare a1 a2 =
    let c = Coeff.compare a1.const a2.const in
    if c <> 0 then c else IM.compare Coeff.compare a1.lin a2.lin

  (* [is_var a]: [Some (id, c)] when [a] is exactly c * id, that is a zero
   * constant and a single linear term; [None] otherwise. *)
  let is_var a =
    if Coeff.(equal a.const zero) then
      match IM.bindings a.lin with
        | [id_c] -> Some id_c
        | _ -> None
    else None

  (* [is_const a]: [Some c] when [a] has no linear term, [c] being its
   * constant; [None] otherwise. *)
  let is_const a = if IM.is_empty a.lin then Some a.const else None

  (* [choose l]: some linear term of [l], [None] when there is none. *)
  let choose l = try Some (IM.choose l.lin) with Not_found -> None

  (* Pretty printer.  A coefficient equal to one or minus one is not
   * printed, and a zero constant is omitted unless the expression has no
   * linear term. *)
  let pp fmt a =
    let pp_coeff fmt (x, a) =
      if Coeff.(a = one) then
        Format.fprintf fmt "%a" Ident.pp x
      else if Coeff.(a = minus_one) then
        Format.fprintf fmt "-%a" Ident.pp x
      else
        Format.fprintf fmt "%a %a" Coeff.pp a Ident.pp x in
    (* The emptiness of the linear part is tested directly rather than by
     * comparing an option of scalar with the polymorphic equality. *)
    if IM.is_empty a.lin then
      Format.fprintf fmt "%a" Coeff.pp a.const
    else if Coeff.(a.const = zero) then
      Format.fprintf fmt "@[%a@]"
                     (Utils.pp_list ~sep:"@ + " pp_coeff)
                     (IM.bindings a.lin)
    else
      Format.fprintf fmt "@[%a@ + %a@]"
                     (Utils.pp_list ~sep:"@ + " pp_coeff)
                     (IM.bindings a.lin)
                     Coeff.pp a.const
end

(* Affine expressions whose coefficients are scalars of Scalar.Q. *)
module Q = Make (Scalar.Q)

(* Affine expressions whose coefficients are scalars of Scalar.Float. *)
module Float = Make (Scalar.Float)

(* Raised by the multiplication of MakeScalar when neither operand is a
 * constant expression, i.e. when the product would not be linear. *)
exception Not_linear

(* Turns the affine expressions of L into scalars, through Scalar.Make.
 * Scalars of L.Coeff are embedded as constant expressions.  The product is
 * only defined when at least one operand is constant; conversions back to
 * float or rational, as well as division, are not supported and fail on
 * an assertion. *)
module MakeScalar (L : S) : Scalar.S with type t = L.t = Scalar.Make (struct
  type t = L.t

  let compare = L.compare

  (* Neutral elements and injections go through constant expressions. *)
  let zero = L.const L.Coeff.zero
  let one = L.const L.Coeff.one
  let of_float x = L.const (L.Coeff.of_float x)
  let of_q q = L.const (L.Coeff.of_q q)

  (* Not meaningful on expressions: should never be called. *)
  let to_float _ = assert false
  let to_q _ = assert false
  let div _ _ = assert false

  let add = L.add
  let sub = L.sub

  (* Product of two expressions, scaling one by the other when the latter
   * is constant (the left operand being checked first).
   * Raises Not_linear when both operands have linear terms. *)
  let mult lhs rhs =
    match L.is_const lhs with
    | Some k -> L.mult_scalar k rhs
    | None ->
       begin
         match L.is_const rhs with
         | Some k -> L.mult_scalar k lhs
         | None -> raise Not_linear
       end

  (* Prints the expression, between parentheses as soon as it is made of
   * more than one summand (linear terms plus a non-zero constant). *)
  let pp fmt e =
    let terms, cst = L.to_list e in
    let single_summand =
      match terms with
      | [] -> true
      | [_] -> L.Coeff.(cst = zero)
      | _ :: _ :: _ -> false in
    if single_summand then Format.fprintf fmt "%a" L.pp e
    else Format.fprintf fmt "(%a)" L.pp e
end)