Source file lite3_database.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
open Links_core
open Utility
module S3 = Sqlite3
module Rc = Sqlite3.Rc
module Data = Sqlite3.Data
(* Render an SQLite return code as a short lower-case tag (e.g. [Rc.BUSY]
   becomes "busy").  Codes the binding does not know are rendered as
   "unknown: <n>", where <n> is the raw integer code. *)
let error_as_string =
  Rc.(function
    | OK -> "ok"
    | ERROR -> "error"
    | INTERNAL -> "internal"
    | PERM -> "perm"
    | ABORT -> "abort"
    | BUSY -> "busy"
    | LOCKED -> "locked"
    | NOMEM -> "nomem"
    | READONLY -> "readonly"
    | INTERRUPT -> "interrupt"
    | IOERR -> "ioerr"
    | CORRUPT -> "corrupt"
    | NOTFOUND -> "notfound"
    | FULL -> "full"
    | CANTOPEN -> "cantopen"
    | PROTOCOL -> "protocol"
    | EMPTY -> "empty"
    | SCHEMA -> "schema"
    | TOOBIG -> "toobig"
    | CONSTRAINT -> "constraint"
    | MISMATCH -> "mismatch"
    | MISUSE -> "misuse"
    | NOFLS -> "nofls"
    | AUTH -> "auth"
    | FORMAT -> "format"
    | RANGE -> "range"
    | NOTADB -> "notadb"
    | ROW -> "row"
    | DONE -> "done"
    | UNKNOWN code -> "unknown: " ^ string_of_int (int_of_unknown code))
(* Convert one SQLite column value to text.  SQL NULL and the NONE marker
   both become the empty string; integers and floats are printed; TEXT and
   BLOB payloads are returned unchanged. *)
let data_to_string = function
  | Data.NONE | Data.NULL -> ""
  | Data.INT n -> Int64.to_string n
  | Data.FLOAT x -> string_of_float' x
  | Data.TEXT str | Data.BLOB str -> str
(** Result set of a single prepared SQLite statement.

    Constructing the object steps [stmt] to completion immediately, converting
    every row to strings with [data_to_string] and recording whether the run
    ended with [Rc.DONE] ([`QueryOk]) or with another return code
    ([`QueryError]).  Row accessors read from that buffer; column metadata
    ([nfields], [fname]) is still read from [stmt]. *)
class lite3_result (stmt: S3.stmt) = object
  inherit Value.dbvalue

  (* (rows, status).  Rows are kept as arrays so that [getvalue] is a
     constant-time index rather than an O(column) [List.nth] walk. *)
  val result_buf_and_status =
    let result_buf = PolyBuffer.init 1 1024 [||] in
    let rec get_results () =
      match S3.step stmt with
      | Rc.OK | Rc.ROW ->
        PolyBuffer.append result_buf
          (Array.map data_to_string (S3.row_data stmt));
        get_results ()
      | Rc.DONE -> `QueryOk
      | e -> `QueryError (error_as_string e)
    in
    (result_buf, get_results ())

  method status : Value.db_status = snd result_buf_and_status

  method nfields : int = S3.column_count stmt

  method ntuples : int = PolyBuffer.length (fst result_buf_and_status)

  method fname n : string = S3.column_name stmt n

  (* [getvalue row col] *)
  method getvalue : int -> int -> string = fun n i ->
    (PolyBuffer.get (fst result_buf_and_status) n).(i)

  (* Return a copy so callers cannot mutate the buffered row. *)
  method gettuple : int -> string array = fun n ->
    Array.copy (PolyBuffer.get (fst result_buf_and_status) n)

  method error : string =
    match snd result_buf_and_status with
    | `QueryError msg -> msg
    | `QueryOk -> "OK"
end
(** A Links database backed by an SQLite file; [file] is passed to
    [Sqlite3.db_open] when the object is created. *)
class lite3_database file = object(self)
  inherit Value.database

  (* Cached answer of [supports_shredding]; [None] until it is computed. *)
  val mutable _supports_shredding : bool option = None
  val connection = S3.db_open file

  (** Run [query], which may contain several ';'-separated statements, and
      return the result of the last one.  Statements run in order; if one
      of them fails, execution stops and that statement's error result is
      returned instead of being silently discarded.  Preparation failures
      are not caught here ([S3.prepare] / [S3.prepare_tail] raise). *)
  method exec query : Value.dbvalue =
    Debug.print query;
    let rec run stmt =
      (* Execute [stmt] before preparing the next statement: preparing the
         tail first fails when a later statement refers to something an
         earlier one creates (e.g. CREATE TABLE followed by INSERT). *)
      let result = new lite3_result stmt in
      match result#status with
      | `QueryError _ -> result
      | `QueryOk ->
        begin match S3.prepare_tail stmt with
        | None -> result
        | Some next -> run next
        end
    in
    run (S3.prepare connection query)

  (* Escape a value for use inside a single-quoted SQL string literal. *)
  method escape_string = Str.global_replace (Str.regexp_string "'") "''"

  (* Quote an identifier.  Inside a double-quoted identifier it is the double
     quote that must be doubled; applying the single-quote rule for string
     literals would instead change the identifier's name. *)
  method quote_field s =
    "\"" ^ Str.global_replace (Str.regexp_string "\"") "\"\"" s ^ "\""

  method driver_name () = "sqlite3"

  (** Whether the connected SQLite library is version 3.25 or newer, which
      is the threshold this driver uses to enable shredding (presumably
      because shredded queries need window functions, added in SQLite
      3.25.0 — confirm).  A definite answer is cached; if the version query
      fails, [false] is returned without caching so the check is retried. *)
  method supports_shredding () =
    match _supports_shredding with
    | Some result -> result
    | None ->
      let result =
        new lite3_result (S3.prepare connection "SELECT sqlite_version()") in
      begin match result#status with
      | `QueryOk when result#ntuples > 0 ->
        let is_number str =
          Str.string_match (Str.regexp "^[0-9]+$") str 0
        in
        (* The version looks like "MAJOR.MINOR.PATCH"; split on the dots
           only, so that no digits are consumed by the separator. *)
        let supported =
          match Str.split (Str.regexp_string ".") (result#getvalue 0 0) with
          | major :: minor :: _ when is_number major && is_number minor ->
            let nmajor = int_of_string major
            and nminor = int_of_string minor in
            (* Lexicographic comparison, so e.g. 4.0 also qualifies. *)
            nmajor > 3 || (nmajor = 3 && nminor >= 25)
          | _ -> false
        in
        _supports_shredding <- Some supported;
        supported
      | _ -> false
      end

  (* Emulate INSERT ... RETURNING with two statements: the insert itself,
     then a select of [returning] from the row whose rowid equals
     last_insert_rowid().  Only [Sql.Insert] queries are expected here. *)
  method! make_insert_returning_query : string -> Sql.query -> string list =
    fun returning q ->
      match q with
      | Sql.Insert ins ->
        [ self#string_of_query q;
          Printf.sprintf
            "select %s from %s where rowid = last_insert_rowid()"
            returning ins.ins_table ]
      | _ -> assert false
end
let driver_name = "sqlite3"

(* Register this module under [driver_name].  The constructor receives the
   driver argument string (used as the file name for [lite3_database]) and
   returns the database object paired with its reconstructed connection
   string. *)
let _ =
  let make_database args =
    (new lite3_database args, Value.reconstruct_db_string (driver_name, args))
  in
  Value.register_driver (driver_name, make_database)