(** Reusable Error Monads *)

(**

This whole module is about using

type ('a, 'b) t = [
  | `Ok of 'a
  | `Error of 'b
]

*)



(** The basic error monad signature. *)

module type ERROR_MONAD = sig
  type ('a, 'b) t
  val return: '-> ('a, 'b) t
  val bind: ('a, 'b) t -> ('-> ('c, 'b) t) -> ('c, 'b) t
  val (>>=): ('a, 'b) t -> ('-> ('c, 'b) t) -> ('c, 'b) t
  val fail: '-> ('a, 'b) t
  val map: ('a, 'b) t -> ('-> 'c) -> ('c, 'b) t
  val (>>|): ('a, 'b) t -> ('-> 'c) -> ('c, 'b) t
  val destruct:
    ('a, 'b) t -> ([> `Ok of '| `Error of 'b] -> ('c, 'd) t) -> ('c, 'd) t
  val (>><):
    ('a, 'b) t -> ([> `Ok of '| `Error of 'b] -> ('c, 'd) t) -> ('c, 'd) t
end

(** The signature of the Result module: ERROR_MONAD + exposed result type. *)

module type RESULT = sig
  include ERROR_MONAD with type ('a, 'b) t = [
    | `Ok of 'a
    | `Error of 'b
  ]
end

(** Implementation of RESULT *)

module Result : RESULT = struct

  type ('a, 'b) t = [
    | `Ok of 'a
    | `Error of 'b
  ]

  let return a  = `Ok a
  let fail b = `Error b

  let bind x f =
    match x with
    | `Ok x -> f x
    | `Error e -> fail e

  let (>>=) = bind

  let map x f =
    match x with
    | `Ok x -> return (f x)
    | `Error e -> fail e

  let (>>|) = map

  let destruct (#as t) f = f t
  let (>><) = destruct
end

(** The signature of a basic “thread” module called Deferred (like Lwt). *)

module type DEFERRED = sig

  type 'a t
  val bind: 'a t -> ('-> 'b t) -> 'b t
  val return: '-> 'a t
  val catch: (unit -> 'a t) -> (exn -> 'a t) -> 'a t

end

(** The result of the functor application: With_deferred(Deferred). *)

module type DEFERRED_RESULT = sig


  type 'a deferred
  include ERROR_MONAD with type ('a, 'b) t = ('a, 'b) Result.t deferred

  val of_result: ('a, 'b) Result.t -> ('a, 'b) t

  val catch_deferred : (unit -> 'a deferred) -> ('a, exn) t

  val wrap_deferred : on_exn:(exn -> 'a) -> (unit -> 'b deferred) -> ('b, 'a) t

  val map_option: 'a option -> f:('-> ('r, 'b) t) ->
    ('r option, 'b) t

  val some: or_fail:'error -> 'a option -> ('a, 'error) t

end

module With_deferred (DeferredDEFERRED) :
  DEFERRED_RESULT with type 'a deferred = 'Deferred.t
struct

  type 'a deferred = 'Deferred.t
  (* type ('a, 'b) result = ('a, 'b) Result.t *)
  type ('a, 'b) t = ('a, 'b) Result.t Deferred.t

  let return x : (_, _) t = Deferred.return (`Ok x)
  let bind x f =
    Deferred.bind x (function
      | `Error e -> Deferred.return (`Error e)
      | `Ok o -> f o)

  let fail x = Deferred.return (Result.fail x)

  let (>>=) = bind
  let (>><) x f : (_, _) t = Deferred.bind x f

  let map m f =
    m >>= fun x ->
    return (f x)
  let (>>|) = map
  let of_result = Deferred.return


  let catch_deferred f : (_, _) t =
    Deferred.catch
      (fun () ->
         let a_exn_m : 'Deferred.t = f () in
         Deferred.bind a_exn_m (fun x -> Deferred.return (`Ok x)))
      (fun e -> Deferred.return (`Error e))

  let wrap_deferred ~on_exn f =
    let caught = catch_deferred f in
    caught >>< function
    | `Ok o -> return o
    | `Error e -> fail (on_exn e)

  let map_option o ~f =
    begin match o with
    | None -> return None
    | Some s ->
      f s
      >>< begin function
      | `Ok o -> return (Some o)
      | `Error e -> fail e
      end
    end

  let some ~or_fail = function
  | None -> fail or_fail
  | Some s -> return s
  let destruct t f =
    Deferred.bind t (function `Ok o -> f (`Ok o) | `Error e -> f (`Error e))
  let (>><) = destruct

end