Looking for a clever way of stating relationships between two functions

For some project I am working on, I am trying to relate a large library of functions written in two different styles.

The two styles are related by a type class:

(* A pair of conversion functions between a type [A] (the "concrete" side)
   and a type [B] (the "abstract" side).  No laws are imposed here. *)
Class Embedding A B :=
  {
    (* Concrete-to-abstract conversion. *)
    toAbstract : A -> B;
    (* Abstract-to-concrete conversion. *)
    toConcrete : B -> A;
  }.

with two additional type classes giving properties of “good” embeddings:

(* Law for a "good" embedding: converting a value with [toAbstract] and then
   back with [toConcrete] returns the original value. *)
Class ProperEmbedding {A B} `(Embedding A B) :=
  {
    roundtrip : forall a, toConcrete (toAbstract a) = a;
  }.

(* Extends a proper embedding [E] with the converse round trip, so that
   [toAbstract] and [toConcrete] are inverse to each other in both
   directions. *)
Class Isomorphism {A B} `(E : Embedding A B) `{! ProperEmbedding E} :=
  {
    roundtrip' : forall c, toAbstract (toConcrete c) = c;
  }.

Now my end goal is to rewrite calls to functions like concreteFun c1 c2 into toConcrete (abstractFun (toAbstract c1) (toAbstract c2)), whenever concreteFun and abstractFun are correctly related.

I have a type class:

(* Relates a unary concrete function to a unary abstract function.
   [CI]/[AI] are the concrete/abstract input types and [CO]/[AO] the
   concrete/abstract output types; both pairs come with a proper embedding. *)
Class CorrectTranslation
      {CI AI CO AO}
      `{ProperEmbedding CI AI}
      `{ProperEmbedding CO AO}
      (concreteFun : CI -> CO) (abstractFun : AI -> AO)
  :=
    {
      (* If [ai] is the abstraction of [ci], then the abstraction of
         [concreteFun ci] is [abstractFun ai].  The results are named through
         equations ([co], [ao]) rather than substituted in place. *)
      correctTranslation :
        forall ci ai co ao,
          toAbstract ci = ai ->
          concreteFun ci = co ->
          abstractFun ai = ao ->
          toAbstract co = ao;
    }.

and proved a rewriting theorem:

(* Rewriting principle for unary functions: a call to [concreteFun] equals
   abstracting the argument, running [abstractFun], and converting the result
   back with [toConcrete].  Only the statement is shown here; the proof is not
   part of this excerpt. *)
Theorem byCorrectTranslation
        {CI AI CO AO}
        (concreteFun : CI -> CO) (abstractFun : AI -> AO)
        `{ProperEmbedding CI AI}
        `{ProperEmbedding CO AO}
        `{CT : ! CorrectTranslation concreteFun abstractFun}
  : forall ci, concreteFun ci = toConcrete (abstractFun (toAbstract ci)).

which has been working well for unary functions.

However, when dealing with curried n-ary functions, this is quite impractical.

This also feels like something that should exist already in some shape or form.
So, are there nicer ways of dealing with such classes of rewrites?

Not sure if this is what you are looking for but here’s my attempt at solving the problem. It’s quite a mouthful. Note that with constructions like this there are many trade-offs; especially when it comes to the user interface. More concretely, you can’t simply call rewrite anymore unless you know how many arguments the function you are looking for has.

My choice of tactics (see the very end of this post) is not great but it’s very little code for two different modes of use: 1) ssreflect rewrite patterns to select the function and 2) just looking for the first function that fits. I’ve done very little testing so it’s possible that this design is actually crap, though. :slight_smile:
(In fact, the second mode of use will certainly be too slow for big goals since it traverses the whole goal type multiple times.)

As far as selecting the right occurrence is concerned, I think it is possible to add another Class that maps functions to their number of arguments, so that the user would only have to specify the function to rewrite without any _ after it. I’ll think about that some other time.

In the end, I am not sure that infrastructure like this is actually worth it. Duplicating the original Classes might work better in the long term. (I am a bit of a hypocrite: I myself always avoid this kind of duplication. I am not patient enough to change multiple copies of essentially the same thing when something needs fixing. But that’s for small projects that nobody else really looks at.)

(* Definitions from here on are universe polymorphic.
   NOTE(review): presumably needed because [list1] is instantiated below with
   elements of type [Type * Type] - confirm. *)
Set Universe Polymorphism.
(* Non-empty lists: [cons0 t] is the one-element list, [cons1 t l] puts [t]
   in front of the non-empty list [l]. *)
Inductive list1 {T : Type} : Type :=
| cons0 : T -> list1
| cons1 : T -> list1 -> list1.
(* Make the element type an explicit argument of the type former, so that one
   writes [list1 T]. *)
Arguments list1 : clear implicits.
(* List-literal notations: [+ x1 ; .. ; xn ] (two or more elements) and
   [+ x ] (a single element). *)
Notation "'[+' x1 ; .. ; x2 ; x3 ]" := (cons1 x1 .. (cons1 x2 (cons0 x3)) ..) (at level 0).
Notation "'[+' x1 ]" := (cons0 x1) (at level 0).

(* [last l] is the final element of the non-empty list [l], i.e. the element
   stored in its innermost [cons0]. *)
Fixpoint last {T} (l : list1 T) : T :=
  match l with
  | cons0 x => x
  | cons1 _ rest => last rest
  end.

(* [map f l] applies [f] to every element of [l]; the result has the same
   shape (same constructors, same order) as [l]. *)
Fixpoint map {S T} (f : S -> T) (l : list1 S) : list1 T :=
  match l with
  | cons0 x => cons0 (f x)
  | cons1 x rest => cons1 (f x) (map f rest)
  end.
(* [reduce f l] folds the non-empty list [l] from the right with [f]:
   a singleton reduces to its element, and [cons1 x rest] reduces to
   [f x (reduce f rest)].  No separate seed value is needed. *)
Fixpoint reduce {T} (f : T -> T -> T) (l : list1 T) : T :=
  match l with
  | cons0 x => x
  | cons1 x rest => f x (reduce f rest)
  end.

(* [unzip l] splits a non-empty list of pairs into the pair of the list of
   first components and the list of second components. *)
Fixpoint unzip {T1 T2} (l : list1 (T1 * T2)) : list1 T1 * list1 T2 :=
  match l with
  | cons0 (a, b) => (cons0 a, cons0 b)
  | cons1 (a, b) rest =>
    let (firsts, seconds) := unzip rest in
    (cons1 a firsts, cons1 b seconds)
  end.

(* A bundle of proper embeddings, one per (concrete type, abstract type) pair
   in [CAs].  Each pair [(C, A)] is mapped to the dependent pair
   "an [Embedding C A] together with a proof that it is a [ProperEmbedding]",
   and the resulting types are combined with [prod] via [reduce].  For a
   singleton list this is just the dependent pair; for longer lists it is a
   right-nested product. *)
Class ProperEmbeddings (CAs : list1 (Type * Type)) := proper_embeddings :
  reduce (fun A B => prod A B)
    (map
       (fun '(C, A) =>
          { E : Embedding C A & ProperEmbedding E}
       )
       CAs
    ).

(* Solves a goal of the form [ProperEmbeddings l] by walking the list [l] and
   building one dependent pair per entry.  Fails if the goal has another
   shape or if one of the entries cannot be solved. *)
Ltac solve_proper_embeddings :=
  (* Build a single [existT]; the bracketed [apply _] requires exactly one
     goal to remain and hands it to typeclass resolution.
     NOTE(review): presumably that goal is the [ProperEmbedding] proof and
     solving it also fixes the [Embedding] component - confirm. *)
  let solve_one := refine (existT _ _ _); [apply _] in
  match goal with
  (* Last entry of the list: one dependent pair is all that is needed. *)
  | |- ProperEmbeddings (cons0 (?A,?B)) => solve_one
  (* Head entry plus tail: build a pair, solve the head with [solve_one] and
     fold the second goal back into [ProperEmbeddings l] before recursing. *)
  | |- ProperEmbeddings (cons1 (?A,?B) ?l) =>
    refine (pair _ _); [solve_one|change (ProperEmbeddings l); [solve_proper_embeddings]]
  end.

(* Test fixtures: identity embeddings on [unit], [nat] and [bool], their
   round-trip proofs, and a check that [solve_proper_embeddings] can assemble
   a three-entry [ProperEmbeddings] bundle from them. *)
Section PETest.
  (* Identity embeddings.  They end with [Defined] so that they stay
     transparent: the round-trip proofs below hold by unfolding them. *)
  Global Instance embed_unit_id : Embedding unit unit. refine ({| toAbstract x := x; toConcrete x := x|}). Defined.
  Global Instance embed_nat_id : Embedding nat nat. refine ({| toAbstract x := x; toConcrete x := x|}). Defined.
  Global Instance embed_bool_id : Embedding bool bool. refine ({| toAbstract x := x; toConcrete x := x|}). Defined.

  (* For an identity embedding, [toConcrete (toAbstract a)] computes to [a],
     so the law is proved by [reflexivity] instead of being admitted as an
     axiom. *)
  Global Instance embed_proper_unit_id : ProperEmbedding embed_unit_id.
  Proof. constructor; intros; reflexivity. Qed.
  Global Instance embed_proper_nat_id : ProperEmbedding embed_nat_id.
  Proof. constructor; intros; reflexivity. Qed.
  Global Instance embed_proper_bool_id : ProperEmbedding embed_bool_id.
  Proof. constructor; intros; reflexivity. Qed.

  Goal ProperEmbeddings [+ (nat : Type, nat : Type); (bool : Type, bool : Type); (unit : Type, unit : Type)].
    (* Coq, go home! You are drunk! *)
    solve_proper_embeddings.
  Qed.
End PETest.

(* Let typeclass resolution discharge goals of the form [ProperEmbeddings _]
   by running [solve_proper_embeddings].  The locality is spelled out as
   [Global] because declaring a hint without an explicit locality is
   deprecated in recent Coq versions. *)
Global Hint Extern 0 (ProperEmbeddings _) => solve_proper_embeddings : typeclass_instances.


(* [fun_of l T] is the curried function type that takes one argument for each
   type in the non-empty list [l], in order, and returns a [T].
   For example, [fun_of [+ A ; B ] T] is [A -> B -> T]. *)
Fixpoint fun_of l T :=
  match l with
  | cons0 A => A -> T
  | cons1 A rest => A -> fun_of rest T
  end.


(* Combines the elements of a non-empty list with [prod], via [reduce]
   (right-nested); a singleton list yields its only element.
   NOTE(review): only referenced by [unfold_stuff] in the visible code. *)
Definition args_of := reduce prod.

From Coq Require List.          (* No "Import" to avoid shadowing our list1 constructions. *)

(* N-ary generalisation of [CorrectTranslation].
   - [CAis] lists the (concrete, abstract) type pair of every argument.
   - [PEis] supplies a proper embedding for each of those pairs.
   - [Co]/[Ao] are the concrete/abstract result types, related by [PE].
   - [concreteFun]/[abstractFun] are curried functions over the first/second
     components of [CAis].
   The property is computed by recursion on [CAis]: for each argument it
   quantifies over a concrete value and its abstraction, and after the last
   argument it states that the abstraction of the concrete result equals the
   abstract result. *)
Class CorrectTranslation'
      {Co Ao : Type}
      {CAis : list1 (Type * Type)}
      {PEis : ProperEmbeddings CAis}
      `{PE : ProperEmbedding Co Ao}
      (concreteFun : fun_of (map fst CAis) Co)
      (abstractFun : fun_of (map snd CAis) Ao)
  := correctTranslation' :
       (fix go CAis : ProperEmbeddings CAis -> fun_of (map fst CAis) Co -> fun_of (map snd CAis) Ao -> Prop :=
          match CAis as CAis return ProperEmbeddings CAis -> fun_of (map fst CAis) Co -> fun_of (map snd CAis) Ao -> Prop with
          (* Last argument: same shape as the unary [correctTranslation].
             Matching on [existT] brings this argument's embedding into
             scope. *)
          | cons0 (Ci, Ai) =>
            fun '(existT _ _ _) concreteFun abstractFun =>
              forall ci ai ,
                toAbstract ci = ai ->
                forall co ao,
                concreteFun ci = co ->
                abstractFun ai = ao ->
                toAbstract co = ao
          (* One more argument: relate [ci] and [ai], apply both functions to
             them and recurse on the remaining arguments with the tail of the
             embedding bundle. *)
          | cons1 (Ci, Ai) CAis =>
            fun '(existT _ _ _, PEis) concreteFun abstractFun =>
              forall ci ai, toAbstract ci = ai -> go CAis PEis (concreteFun ci) (abstractFun ai)
          end) CAis PEis concreteFun abstractFun.

(* Named reduction strategy: [cbv] restricted to unfolding the listed
   constants.  It is used in [CTTest] to simplify a goal produced from
   [CorrectTranslation']. *)
Declare Reduction unfold_stuff := cbv [reduce args_of fun_of map fst snd].

(* Test fixture: two identical binary functions [f] and [g] on [nat] and
   [bool], and a [CorrectTranslation'] instance relating them. *)
Section CTTest.
  Definition f (n : nat) (b : bool) := if b then n else 0.
  Definition g (n : nat) (b : bool) := if b then n else 0.
  (* [CAis] is given explicitly, with two entries: one per argument. *)
  Global Instance translate_nat_bool_unit_id :
    CorrectTranslation' (CAis := [+ (nat:Type, nat:Type) ; (bool:Type, bool:Type) ]) f g.
  (* Unfold the class to the underlying proposition. *)
  red.
  (* pretty good already but it could use some more reduction. *)
  match goal with |- ?g => let g := eval unfold_stuff in g in change g end.
  (* perfect. *)
  unfold f.
  cbn.
  (* Case analysis on both booleans; each case is closed by [discriminate]
     or [congruence]. *)
  intros n n' ? [|] [|]; cbn; (try discriminate); congruence.
  Qed.
End CTTest.

(* Extracts the [Embedding] from a [ProperEmbeddings] bundle over a singleton
   list.  Ends with [Defined] so that the result can be unfolded.
   NOTE(review): the script refers to the auto-generated hypothesis name [H]
   for the bundle, which is fragile - confirm that the name is stable. *)
Definition embeddings_from_proper0 {t12} `{ProperEmbeddings (cons0 t12)} : Embedding (fst t12) (snd t12).
Proof. destruct t12. apply H. Defined.

(* Extracts the [Embedding] for the head entry [t12] from a [ProperEmbeddings]
   bundle over the list [cons1 t12 l].  Ends with [Defined] so that the result
   can be unfolded.
   NOTE(review): relies on the auto-generated hypothesis name [H]. *)
Definition embeddings_from_proper1 {t12} {l} `{ProperEmbeddings (cons1 t12 l)} : Embedding (fst t12) (snd t12).
Proof. destruct t12. apply H. Defined.

(* Drops the head entry of a [ProperEmbeddings] bundle over [cons1 t12 l],
   returning the bundle for the tail [l].  Ends with [Defined] so that the
   result can be unfolded.
   NOTE(review): relies on the auto-generated hypothesis name [H]. *)
Definition proper_embedding_tail {t12} {l} `{ProperEmbeddings (cons1 t12 l)} : ProperEmbeddings l.
Proof. destruct H as [_ ?]. auto. Defined.

(* N-ary rewriting principle derived from [CorrectTranslation'].
   The statement is computed by recursion on [CAis]: it quantifies over one
   concrete argument per entry, and for the last argument states
     concreteFun c1 .. cn
       = toConcrete (abstractFun (toAbstract c1) .. (toAbstract cn)),
   where each [toAbstract] uses the embedding extracted from [PEis] for that
   argument.
   NOTE(review): the proof is incomplete. The single-argument case ends in
   [admit] and the theorem is closed with [Admitted], so it is accepted as an
   axiom. *)
Theorem byCorrectTranslation'
      {Co Ao : Type}
      {CAis : list1 (Type * Type)}
      {PEis : ProperEmbeddings CAis}
      `{PE : ProperEmbedding Co Ao}
      (concreteFun : fun_of (map fst CAis) Co)
      (abstractFun : fun_of (map snd CAis) Ao)
      `{CT : ! CorrectTranslation' concreteFun abstractFun}
  :
    (fix go CAis :
    forall (PEis : ProperEmbeddings CAis)
           (concreteFun : fun_of (map fst CAis) Co)
           (abstractFun : fun_of (map snd CAis) Ao),
      Prop
     :=
       match CAis as CAis return
             forall (PEis : ProperEmbeddings CAis)
                    (concreteFun : fun_of (map fst CAis) Co)
                    (abstractFun : fun_of (map snd CAis) Ao),
               Prop
       with
          (* Last argument: the actual equation. *)
          | cons0 (Ci, Ai) =>
            fun PEis concreteFun abstractFun =>
              forall ci, concreteFun ci = toConcrete (abstractFun (@toAbstract _ _ (embeddings_from_proper0) ci))
          (* One more argument: feed [ci] and its abstraction to the two
             functions and recurse with the tail of the embedding bundle. *)
          | cons1 (Ci, Ai) CAis =>
            fun PEis concreteFun abstractFun =>
              forall ci, go CAis (proper_embedding_tail) (concreteFun ci) (abstractFun (@toAbstract _ _ (embeddings_from_proper1) ci))
       end
    )
      CAis PEis concreteFun abstractFun
.
Proof.
  induction CAis as [(Ci,Ai)|(Ci,Ai)].
  (* Single-argument case: left unproved. *)
  - destruct PEis. cbn.
    intros.
    red in CT. cbn in CT.
    admit.
  (* Several arguments: peel off the first one and use the induction
     hypothesis, instantiating [CT] at [ci] and its abstraction. *)
  - destruct PEis as [[? ?] PEis].
    cbn. intros ci.
    apply IHCAis.
    apply CT.
    reflexivity.
Admitted.

(* [translate' haystack tac] looks inside the term [haystack] for an
   application of some function [f] to n arguments, trying n = 4, 3, 2, 1 in
   that order.  For a candidate [f] it builds the (untyped) term
     byCorrectTranslation' (CAis := <n placeholder pairs>) f _ (CT := _)
   and passes it to the continuation [tac].  Since this is a backtracking
   [match], a failure of [tac] makes the search move on to the next
   occurrence and then to the next arity.  Fails with "no match found" if
   nothing works.

   The number of argument placeholders in each pattern must equal the number
   of entries in the corresponding [CAis]; in particular the first arm uses
   four placeholders for its four-entry list. *)
Ltac translate' haystack tac :=
  (* Build the rewriting equation for [f] with the given [CAis] shape.  The
     terms are [uconstr]s, so they are only type-checked (and the
     [CorrectTranslation'] instance only resolved) when [tac] uses them. *)
  let rew f CAis :=
      let CT := uconstr:(_ : CorrectTranslation' (CAis := CAis) f _) in
      let eq := uconstr:(byCorrectTranslation' (CAis := CAis) f _ (CT:=CT)) in
      tac eq
  in
  match haystack with
  | context [?f _ _ _ _] =>
    let CAis := uconstr:([+(_,_);(_,_);(_,_);(_,_)]) in
    rew f CAis
  | context [?f _ _ _] =>
    let CAis := uconstr:([+(_,_);(_,_);(_,_)]) in
    rew f CAis
  | context [?f _ _] =>
    let CAis := uconstr:([+(_,_);(_,_)]) in
    rew f CAis
  | context [?f _] =>
    let CAis := uconstr:([+(_,_)]) in
    rew f CAis
  | _ => fail "no match found"
  end.

From Coq Require Import ssrmatching ssreflect.

(* [translate]: search the whole goal for a function application that has a
   [CorrectTranslation'] instance and rewrite with the corresponding
   [byCorrectTranslation'] equation. *)
Tactic Notation "translate" :=
  match goal with |- ?g => translate' g ltac:(fun eq => rewrite eq) end.
(* [translate at [pat]]: like [translate], but only considers the subterm
   selected by the ssreflect pattern [pat].
   NOTE(review): this assumes [ssrpattern] turns the goal into a [let]
   binding of the selected subterm, which [intros] then moves into the
   context as a local definition - confirm against the ssreflect
   documentation. *)
Tactic Notation "translate" "at" "[" ssrpatternarg(pat) "]" :=
  ssrpattern pat;
  let name := fresh in
  intros name;
  (* Recover the body of the local definition: this is the selected term. *)
  let t := eval unfold name in name in
      translate' t ltac:(fun eq =>
                           let def_name := fresh in
                           (* Start from [name = name], unfold the left-hand
                              side to the selected term and translate it,
                              giving [<translated term> = name]. *)
                           pose proof (eq_refl name) as def_name;
                           unfold name at 1 in def_name;
                           rewrite eq in def_name;
                           (* Replace [name] in the goal by the translated
                              term, then remove the helper hypotheses. *)
                           rewrite <- def_name;
                           clear name def_name
                        ).

(* Exercises both forms of the [translate] tactic on the fixtures [f] and [g]
   from [CTTest].  [Restart] resets the proof between experiments; only the
   last experiment is the one closed by [Qed]. *)
Lemma test n b : f n b = g n b.
  progress translate at [f _ _]. (* do something with f *)
  Fail progress translate at [f _ _]. (* make sure that we have done everything for [f] *)
  Restart.
  Fail progress translate at [f]. (* needs all arguments :( *)
  Restart.
  progress translate.           (* can also search for a function *)
  reflexivity.                  (* NB: this would work without [translate]. *)
Qed.